From 94afb8a3ae5f5e76034855c5abdff459305a89c3 Mon Sep 17 00:00:00 2001 From: Galen Lynch Date: Mon, 16 Feb 2026 15:11:31 -0800 Subject: [PATCH 1/2] fix: formatting Clears ~15,000 flake8 warnings --- .flake8 | 3 + allensdk/__init__.py | 9 +- allensdk/api/__init__.py | 4 +- allensdk/api/api.py | 162 +- allensdk/api/cloud_cache/cloud_cache.py | 334 ++--- allensdk/api/cloud_cache/file_attributes.py | 25 +- allensdk/api/cloud_cache/manifest.py | 79 +- allensdk/api/cloud_cache/utils.py | 8 +- allensdk/api/queries/__init__.py | 2 +- .../annotated_section_data_sets_api.py | 143 +- allensdk/api/queries/biophysical_api.py | 377 ++--- allensdk/api/queries/brain_observatory_api.py | 202 +-- allensdk/api/queries/cell_types_api.py | 228 ++- allensdk/api/queries/connected_services.py | 1319 +++++----------- allensdk/api/queries/glif_api.py | 169 +-- allensdk/api/queries/grid_data_api.py | 160 +- allensdk/api/queries/image_download_api.py | 354 +++-- allensdk/api/queries/mouse_atlas_api.py | 94 +- .../api/queries/mouse_connectivity_api.py | 341 ++--- allensdk/api/queries/ontologies_api.py | 352 ++--- allensdk/api/queries/reference_space_api.py | 178 +-- allensdk/api/queries/rma_api.py | 261 ++-- allensdk/api/queries/rma_pager.py | 45 +- allensdk/api/queries/rma_template.py | 111 +- allensdk/api/queries/svg_api.py | 29 +- allensdk/api/queries/synchronization_api.py | 150 +- allensdk/api/queries/tree_search_api.py | 36 +- allensdk/api/warehouse_cache/cache.py | 293 ++-- .../api/warehouse_cache/caching_utilities.py | 38 +- allensdk/brain_observatory/__init__.py | 10 +- .../brain_observatory/argschema_utilities.py | 43 +- .../brain_observatory/behavior/__init__.py | 18 +- .../behavior/behavior_ophys_analysis.py | 60 +- .../behavior/behavior_ophys_experiment.py | 123 +- .../behavior/behavior_ophys_session.py | 7 +- .../behavior_project_cache/__init__.py | 11 +- .../behavior_neuropixels_project_cache.py | 42 +- .../behavior_project_cache.py | 48 +- .../project_apis/abcs/__init__.py | 4 +- .../abcs/behavior_project_base.py | 14 +- .../project_apis/data_io/__init__.py | 20 +- .../behavior_neuropixels_project_cloud_api.py | 19 +- .../data_io/behavior_project_cloud_api.py | 81 +- .../data_io/behavior_project_lims_api.py | 106 +- .../data_io/natural_movie_one_cache.py | 4 +- .../data_io/project_cloud_api_base.py | 72 +- .../project_cache_base.py | 104 +- .../project_metadata_writer/__main__.py | 5 +- .../behavior_project_metadata_writer.py | 148 +- .../project_metadata_writer/schemas.py | 44 +- .../tables/experiments_table.py | 28 +- .../tables/metadata_table_schemas.py | 65 +- .../tables/ophys_mixin.py | 29 +- .../tables/ophys_sessions_table.py | 29 +- .../tables/project_table.py | 7 +- .../tables/sessions_table.py | 73 +- .../tables/util/experiments_table_utils.py | 19 +- .../tables/util/image_presentation_utils.py | 8 +- .../tables/util/prior_exposure_processing.py | 39 +- .../behavior/behavior_project_cache/utils.py | 17 +- .../behavior/behavior_session.py | 216 +-- .../brain_observatory/behavior/criteria.py | 134 +- .../behavior/data_files/__init__.py | 3 +- .../behavior/data_files/demix_file.py | 8 +- .../behavior/data_files/dff_file.py | 15 +- .../data_files/event_detection_file.py | 30 +- .../behavior/data_files/eye_tracking_file.py | 8 +- .../data_files/eye_tracking_metadata_file.py | 9 +- .../behavior/data_files/eye_tracking_video.py | 10 +- .../data_files/neuropil_corrected_file.py | 8 +- .../behavior/data_files/neuropil_file.py | 8 +- .../data_files/rigid_motion_transform_file.py | 7 +- .../behavior/data_files/stimulus_file.py | 60 +- .../behavior/data_files/sync_file.py | 19 +- .../behavior/data_objects/__init__.py | 10 +- .../cell_specimens/cell_specimens.py | 249 +-- .../data_objects/cell_specimens/events.py | 161 +- .../data_objects/cell_specimens/rois_mixin.py | 26 +- .../traces/corrected_fluorescence_traces.py | 45 +- .../cell_specimens/traces/demixed_traces.py | 27 +- .../cell_specimens/traces/dff_traces.py | 64 +- .../cell_specimens/traces/neuropil_traces.py | 27 +- .../eye_tracking/eye_tracking_table.py | 279 ++-- .../data_objects/eye_tracking/rig_geometry.py | 180 +-- .../behavior/data_objects/licks.py | 65 +- .../behavior_metadata/behavior_metadata.py | 209 ++- .../behavior_metadata/behavior_session_id.py | 26 +- .../behavior_session_uuid.py | 18 +- .../behavior_metadata/date_of_acquisition.py | 28 +- .../metadata/behavior_metadata/equipment.py | 28 +- .../metadata/behavior_metadata/foraging_id.py | 7 +- .../behavior_metadata/project_code.py | 12 +- .../behavior_metadata/session_type.py | 19 +- .../behavior_metadata/stimulus_frame_rate.py | 30 +- .../metadata/behavior_ophys_metadata.py | 56 +- .../field_of_view_shape.py | 24 +- .../imaging_depth.py | 16 +- .../imaging_plane.py | 93 +- .../imaging_plane_group.py | 37 +- .../multi_plane_metadata.py | 70 +- .../ophys_container_id.py | 19 +- .../ophys_experiment_metadata.py | 36 +- .../ophys_project_code.py | 6 +- .../ophys_session_id.py | 19 +- .../targeted_imaging_depth.py | 12 +- .../metadata/subject_metadata/age.py | 21 +- .../metadata/subject_metadata/driver_line.py | 20 +- .../subject_metadata/full_genotype.py | 19 +- .../metadata/subject_metadata/mouse_id.py | 4 +- .../subject_metadata/reporter_line.py | 48 +- .../metadata/subject_metadata/sex.py | 10 +- .../subject_metadata/subject_metadata.py | 114 +- .../data_objects/motion_correction.py | 43 +- .../behavior/data_objects/projections.py | 81 +- .../behavior/data_objects/rewards.py | 63 +- .../multi_stim_running_processing.py | 157 +- .../running_speed/running_acquisition.py | 126 +- .../running_speed/running_processing.py | 135 +- .../running_speed/running_speed.py | 137 +- .../stimuli/fingerprint_stimulus.py | 72 +- .../data_objects/stimuli/presentations.py | 191 +-- .../behavior/data_objects/stimuli/stimuli.py | 14 +- .../stimuli/stimulus_templates.py | 108 +- .../data_objects/stimuli/templates.py | 184 +-- .../behavior/data_objects/stimuli/util.py | 39 +- .../behavior/data_objects/task_parameters.py | 129 +- .../timestamps/ophys_timestamps.py | 38 +- .../stimulus_timestamps.py | 207 +-- .../timestamps_processing.py | 17 +- .../behavior/data_objects/trials/trial.py | 223 ++- .../behavior/data_objects/trials/trials.py | 86 +- allensdk/brain_observatory/behavior/dprime.py | 73 +- .../behavior/event_detection.py | 10 +- .../behavior/eye_tracking_processing.py | 54 +- .../brain_observatory/behavior/image_api.py | 11 +- allensdk/brain_observatory/behavior/mtrain.py | 86 +- .../behavior/rewards_processing.py | 3 +- .../brain_observatory/behavior/schemas.py | 234 ++- .../behavior/session_metrics.py | 8 +- .../behavior/stimulus_processing.py | 153 +- .../behavior/swdb/analysis_tools.py | 40 +- .../behavior/swdb/behavior_project_cache.py | 386 +++-- .../behavior/swdb/create_multi_session_df.py | 33 +- .../behavior/swdb/run_multi_session_df.py | 30 +- ...save_extended_stimulus_presentations_df.py | 41 +- .../swdb/run_save_flash_response_df.py | 41 +- .../swdb/run_save_trial_response_df.py | 41 +- .../behavior/swdb/run_summary_figures.py | 43 +- ...save_extended_stimulus_presentations_df.py | 79 +- .../behavior/swdb/save_flash_response_df.py | 477 +++--- .../behavior/swdb/save_trial_response_df.py | 377 +++-- .../behavior/swdb/summary_figures.py | 387 +++-- .../behavior/swdb/utilities.py | 344 +++-- .../behavior/sync/__init__.py | 138 +- .../behavior/sync/process_sync.py | 47 +- .../brain_observatory/behavior/trial_masks.py | 16 +- .../behavior/trials_processing.py | 196 ++- .../behavior/utils/metadata_parsers.py | 7 +- .../behavior/write_nwb/behavior/__main__.py | 8 +- .../behavior/write_nwb/behavior/schemas.py | 81 +- .../event_detection/extension_builder.py | 40 +- .../event_detection/ndx_ophys_events.py | 8 +- .../stimulus_template/extension_builder.py | 47 +- .../ndx_stimulus_template.py | 8 +- .../behavior/write_nwb/nwb_writer_utils.py | 5 +- .../behavior/write_nwb/ophys/__main__.py | 11 +- .../behavior/write_nwb/ophys/schemas.py | 39 +- .../brain_observatory_exceptions.py | 11 +- .../brain_observatory_plotting.py | 884 ++++++----- .../chisquare_categorical.py | 38 +- allensdk/brain_observatory/circle_plots.py | 451 +++--- .../brain_observatory/comparison_utils.py | 5 +- .../metadata_utils/id_generator.py | 11 +- .../metadata_utils/utils.py | 55 +- allensdk/brain_observatory/demixer.py | 135 +- allensdk/brain_observatory/dff.py | 68 +- .../brain_observatory/drifting_gratings.py | 392 +++-- .../brain_observatory/ecephys/__init__.py | 15 +- .../ecephys/_behavior_ecephys_metadata.py | 73 +- .../brain_observatory/ecephys/_channel.py | 40 +- .../brain_observatory/ecephys/_channels.py | 120 +- .../ecephys/_current_source_density.py | 26 +- allensdk/brain_observatory/ecephys/_lfp.py | 54 +- allensdk/brain_observatory/ecephys/_probe.py | 239 ++- allensdk/brain_observatory/ecephys/_unit.py | 75 +- allensdk/brain_observatory/ecephys/_units.py | 96 +- .../ecephys/align_timestamps/__main__.py | 53 +- .../ecephys/align_timestamps/_schemas.py | 20 +- .../ecephys/align_timestamps/barcode.py | 43 +- .../align_timestamps/barcode_sync_dataset.py | 12 +- .../align_timestamps/channel_states.py | 25 +- .../align_timestamps/probe_synchronizer.py | 36 +- .../ecephys/behavior_ecephys_session.py | 79 +- .../ecephys/copy_utility/__main__.py | 72 +- .../ecephys/copy_utility/_schemas.py | 80 +- .../current_source_density/__main__.py | 239 ++- .../_current_source_density.py | 74 +- .../current_source_density/_filter_utils.py | 29 +- .../_interpolation_utils.py | 75 +- .../current_source_density/_schemas.py | 152 +- .../ecephys/data_objects/trials.py | 55 +- .../ecephys_project_api.py | 52 +- .../ecephys_project_fixed_api.py | 1 - .../ecephys_project_lims_api.py | 239 ++- .../ecephys_project_warehouse_api.py | 28 +- .../ecephys_project_api/http_engine.py | 110 +- .../ecephys/ecephys_project_api/rma_engine.py | 36 +- .../ecephys/ecephys_project_api/utilities.py | 5 +- .../ecephys/ecephys_project_cache.py | 380 ++--- .../ecephys/ecephys_session.py | 707 ++++----- .../ecephys/ecephys_session_api/__init__.py | 2 +- .../ecephys_nwb1_session_api.py | 216 +-- .../ecephys_nwb_session_api.py | 241 ++- .../ecephys_session_api.py | 1 - .../ecephys/file_io/continuous_file.py | 99 +- .../ecephys/file_io/ecephys_sync_dataset.py | 134 +- .../ecephys/file_io/stim_file.py | 42 +- .../ecephys/lfp_subsampling/__main__.py | 100 +- .../ecephys/lfp_subsampling/_schemas.py | 118 +- .../ecephys/lfp_subsampling/subsampling.py | 73 +- .../brain_observatory/ecephys/nwb/__init__.py | 12 +- .../nwb/ecephys_nwb_extension_builder.py | 192 ++- .../brain_observatory/ecephys/nwb_util.py | 70 +- .../brain_observatory/ecephys/optotagging.py | 37 +- .../ecephys/optotagging_table/__main__.py | 34 +- .../ecephys/optotagging_table/_schemas.py | 36 +- allensdk/brain_observatory/ecephys/probes.py | 129 +- .../ecephys/stimulus_analysis/__init__.py | 2 +- .../ecephys/stimulus_analysis/__main__.py | 82 +- .../ecephys/stimulus_analysis/_schemas.py | 59 +- .../ecephys/stimulus_analysis/dot_motion.py | 120 +- .../stimulus_analysis/drifting_gratings.py | 379 +++-- .../ecephys/stimulus_analysis/flashes.py | 125 +- .../stimulus_analysis/natural_movies.py | 42 +- .../stimulus_analysis/natural_scenes.py | 75 +- .../receptive_field_mapping.py | 181 +-- .../stimulus_analysis/static_gratings.py | 352 +++-- .../stimulus_analysis/stimulus_analysis.py | 297 ++-- .../ecephys/stimulus_sync.py | 144 +- .../ecephys/stimulus_table/__main__.py | 70 +- .../ecephys/stimulus_table/_schemas.py | 46 +- ...ecephys_visual_coding_time_alignment.ipynb | 30 +- .../stimulus_table/ephys_pre_spikes.py | 79 +- .../stimulus_table/naming_utilities.py | 25 +- .../stimulus_table/output_validation.py | 14 +- .../stimulus_parameter_extraction.py | 11 +- .../visualization/view_blocks.py | 13 +- allensdk/brain_observatory/ecephys/utils.py | 20 +- .../ecephys/visualization/__init__.py | 64 +- .../ecephys/write_nwb/__main__.py | 255 ++-- .../ecephys/write_nwb/nwb_writer.py | 75 +- .../ecephys/write_nwb/schemas.py | 44 +- .../ecephys/write_nwb/vbn/__main__.py | 25 +- .../ecephys/write_nwb/vbn/_schemas.py | 48 +- .../extract_running_speed/__main__.py | 41 +- .../extract_running_speed/_schemas.py | 15 +- .../extract_running_speed/examples.ipynb | 56 +- .../eye_tracking/__main__.py | 4 +- .../eye_tracking/_schemas.py | 33 +- .../brain_observatory/eye_tracking/build.py | 251 ++-- .../eye_tracking/stage_1/DLC_Eye_Tracking.py | 36 +- .../stage_2/DLC_Ellipse_Fitting.py | 152 +- .../eye_tracking/stage_3/DLC_Labeled_Video.py | 28 +- .../eye_tracking/stage_4/DLC_Ellipse_Video.py | 75 +- allensdk/brain_observatory/findlevel.py | 6 +- .../gaze_mapping/__main__.py | 212 +-- .../gaze_mapping/_filter_utils.py | 23 +- .../gaze_mapping/_gaze_mapper.py | 91 +- .../gaze_mapping/_schemas.py | 135 +- .../brain_observatory/locally_sparse_noise.py | 277 ++-- .../multi_stimulus_running_speed/__main__.py | 7 +- .../multi_stimulus_running_speed/_schemas.py | 27 +- .../multi_stimulus_running_speed.py | 61 +- allensdk/brain_observatory/natural_movie.py | 68 +- allensdk/brain_observatory/natural_scenes.py | 212 ++- allensdk/brain_observatory/nwb/__init__.py | 681 ++++----- .../behavior_ophys_nwb_extension_builder.py | 19 +- .../nwb/eye_tracking/extension_builder.py | 91 +- .../eye_tracking/ndx_ellipse_eye_tracking.py | 9 +- allensdk/brain_observatory/nwb/metadata.py | 67 +- allensdk/brain_observatory/nwb/nwb_api.py | 36 +- allensdk/brain_observatory/nwb/nwb_utils.py | 57 +- allensdk/brain_observatory/nwb/schemas.py | 2 +- .../brain_observatory/observatory_plots.py | 410 ++--- .../ophys/project_constants.py | 3 +- .../ophys/trace_extraction/__init__.py | 14 +- .../ophys/trace_extraction/__main__.py | 69 +- .../ophys/trace_extraction/_schemas.py | 71 +- allensdk/brain_observatory/r_neuropil.py | 72 +- .../receptive_field_analysis/__init__.py | 2 +- .../receptive_field_analysis/chisquarerf.py | 86 +- .../eventdetection.py | 25 +- .../fit_parameters.py | 37 +- .../receptive_field_analysis/fitgaussian2D.py | 65 +- .../postprocessing.py | 67 +- .../receptive_field.py | 176 +-- .../receptive_field_analysis/tools.py | 7 +- .../receptive_field_analysis/utilities.py | 68 +- .../receptive_field_analysis/visualization.py | 130 +- allensdk/brain_observatory/roi_masks.py | 125 +- allensdk/brain_observatory/running_speed.py | 4 +- .../brain_observatory/session_analysis.py | 354 +++-- .../brain_observatory/session_api_utils.py | 63 +- allensdk/brain_observatory/static_gratings.py | 469 +++--- .../brain_observatory/stimulus_analysis.py | 340 ++--- allensdk/brain_observatory/stimulus_info.py | 272 +--- allensdk/brain_observatory/sync_dataset.py | 341 ++--- .../brain_observatory/sync_stim_aligner.py | 123 +- .../sync_utilities/__init__.py | 17 +- .../vbn_2022/input_json_writer/__main__.py | 4 +- .../input_json_writer/input_json_writer.py | 31 +- .../vbn_2022/input_json_writer/schemas.py | 84 +- .../vbn_2022/input_json_writer/utils.py | 739 +++++---- .../vbn_2022/metadata_writer/__main__.py | 4 +- .../dataframe_manipulations.py | 48 +- .../vbn_2022/metadata_writer/lims_queries.py | 69 +- .../metadata_writer/metadata_writer.py | 42 +- .../vbn_2022/metadata_writer/schemas.py | 25 +- .../vbn_2022/utils/schemas.py | 20 +- .../visualization/__init__.py | 30 +- allensdk/config/__init__.py | 25 +- allensdk/config/app/__init__.py | 4 +- allensdk/config/app/application_config.py | 126 +- allensdk/config/manifest.py | 178 +-- allensdk/config/manifest_builder.py | 30 +- allensdk/config/model/__init__.py | 2 +- allensdk/config/model/description.py | 25 +- allensdk/config/model/description_parser.py | 16 +- allensdk/config/model/formats/__init__.py | 2 +- allensdk/config/model/formats/hdf5_util.py | 21 +- .../model/formats/json_description_parser.py | 24 +- .../model/formats/pycfg_description_parser.py | 34 +- allensdk/core/__init__.py | 14 +- .../core/_data_object_base/data_object.py | 24 +- .../_data_object_base/readable_interfaces.py | 10 +- .../_data_object_base/writable_interfaces.py | 1 + allensdk/core/auth_config.py | 8 +- allensdk/core/authentication.py | 22 +- allensdk/core/brain_observatory_cache.py | 127 +- .../core/brain_observatory_nwb_data_set.py | 612 ++++---- allensdk/core/cell_types_cache.py | 128 +- allensdk/core/dat_utilities.py | 5 +- allensdk/core/dataframe_utils.py | 23 +- allensdk/core/exceptions.py | 6 +- allensdk/core/h5_utilities.py | 43 +- allensdk/core/json_utilities.py | 21 +- allensdk/core/lazy_property/__init__.py | 1 - allensdk/core/lazy_property/lazy_property.py | 6 +- .../core/lazy_property/lazy_property_mixin.py | 9 +- allensdk/core/mouse_connectivity_cache.py | 134 +- allensdk/core/nwb_data_set.py | 172 +-- allensdk/core/obj_utilities.py | 33 +- allensdk/core/ontology.py | 24 +- .../ophys_experiment_session_id_mapping.py | 2 +- allensdk/core/pickle_utils.py | 38 +- allensdk/core/reference_space.py | 313 ++-- allensdk/core/reference_space_cache.py | 187 +-- allensdk/core/simple_tree.py | 270 ++-- allensdk/core/sitk_utilities.py | 84 +- allensdk/core/structure_tree.py | 398 +++-- allensdk/core/swc.py | 193 ++- allensdk/core/typing.py | 3 +- allensdk/core/utilities.py | 12 +- allensdk/deprecated.py | 73 +- allensdk/ephys/__init__.py | 2 +- allensdk/ephys/ephys_extractor.py | 467 +++--- allensdk/ephys/ephys_features.py | 229 +-- allensdk/ephys/extract_cell_features.py | 106 +- allensdk/ephys/feature_extractor.py | 126 +- allensdk/internal/api/__init__.py | 49 +- allensdk/internal/api/api_prerelease.py | 7 +- allensdk/internal/api/lims_api.py | 24 +- allensdk/internal/api/mtrain_api.py | 178 +-- .../api/queries/behavior_lims_queries.py | 32 +- .../api/queries/biophysical_module_api.py | 148 +- .../api/queries/biophysical_module_reader.py | 410 +++-- .../api/queries/compound_lims_queries.py | 15 +- .../api/queries/ecephys_lims_queries.py | 17 +- .../api/queries/equipment_lims_queries.py | 15 +- .../api/queries/grid_data_api_prerelease.py | 39 +- .../mouse_connectivity_api_prerelease.py | 102 +- .../api/queries/optimize_config_reader.py | 389 ++--- allensdk/internal/api/queries/pre_release.py | 207 ++- allensdk/internal/api/queries/utils.py | 28 +- .../internal/api/queries/wkf_lims_queries.py | 43 +- .../annotated_region_metrics.py | 68 +- .../brain_observatory/demix_report.py | 194 +-- .../internal/brain_observatory/demixer.py | 208 +-- .../brain_observatory/eye_calibration.py | 170 +-- .../internal/brain_observatory/fit_ellipse.py | 205 +-- .../brain_observatory/frame_stream.py | 130 +- .../internal/brain_observatory/itracker.py | 428 +++--- .../brain_observatory/itracker_utils.py | 162 +- .../internal/brain_observatory/mask_set.py | 45 +- allensdk/internal/brain_observatory/mouse.py | 78 +- .../ophys_session_decomposition.py | 62 +- .../internal/brain_observatory/roi_filter.py | 128 +- .../brain_observatory/roi_filter_utils.py | 95 +- .../brain_observatory/run_itracker.py | 154 +- .../internal/brain_observatory/time_sync.py | 153 +- .../util/multi_session_utils.py | 95 +- allensdk/internal/core/__init__.py | 2 +- allensdk/internal/core/_data_file.py | 10 +- .../internal/core/lims_pipeline_module.py | 88 +- allensdk/internal/core/lims_utilities.py | 165 +- .../mouse_connectivity_cache_prerelease.py | 142 +- allensdk/internal/core/simpletree.py | 20 +- allensdk/internal/core/swc.py | 57 +- .../internal/ephys/core_feature_extract.py | 171 ++- allensdk/internal/ephys/plot_qc_figures.py | 540 ++++--- allensdk/internal/ephys/plot_qc_figures3.py | 598 ++++---- allensdk/internal/model/AIC.py | 17 +- allensdk/internal/model/GLM.py | 191 +-- .../model/biophysical/biophysical_archiver.py | 102 +- .../model/biophysical/check_fi_shift.py | 15 +- .../internal/model/biophysical/deap_utils.py | 45 +- .../internal/model/biophysical/ephys_utils.py | 27 +- .../internal/model/biophysical/fit_stage_1.py | 215 +-- .../internal/model/biophysical/fit_stage_2.py | 66 +- .../model/biophysical/make_deap_fit_json.py | 106 +- .../model/biophysical/neuron_parallel.py | 20 +- .../internal/model/biophysical/optimize.py | 57 +- .../passive_fitting/neuron_passive_fit.py | 47 +- .../passive_fitting/neuron_passive_fit2.py | 30 +- .../neuron_passive_fit_elec.py | 32 +- .../passive_fitting/neuron_utils.py | 17 +- .../passive_fitting/output_grabber.py | 5 +- .../biophysical/passive_fitting/preprocess.py | 35 +- .../model/biophysical/run_optimize.py | 161 +- .../biophysical/run_optimize_workflow.py | 8 +- .../model/biophysical/run_passive_fit.py | 128 +- .../model/biophysical/run_simulate_lims.py | 106 +- .../biophysical/run_simulate_workflow.py | 8 +- allensdk/internal/model/data_access.py | 53 +- allensdk/internal/model/glif/ASGLM.py | 342 +++-- allensdk/internal/model/glif/MLIN.py | 228 +-- .../glif/are_two_lists_of_arrays_the_same.py | 17 +- .../internal/model/glif/configure_model.py | 541 +++---- .../internal/model/glif/error_functions.py | 392 ++--- allensdk/internal/model/glif/find_spikes.py | 103 +- allensdk/internal/model/glif/find_sweeps.py | 118 +- .../internal/model/glif/glif_experiment.py | 129 +- .../internal/model/glif/glif_optimizer.py | 267 ++-- .../model/glif/glif_optimizer_neuron.py | 868 ++++++----- .../internal/model/glif/optimize_neuron.py | 111 +- allensdk/internal/model/glif/plotting.py | 205 +-- .../internal/model/glif/preprocess_neuron.py | 746 +++++---- allensdk/internal/model/glif/rc.py | 45 +- allensdk/internal/model/glif/spike_cutting.py | 322 ++-- .../model/glif/threshold_adaptation.py | 668 +++++---- allensdk/internal/morphology/compartment.py | 2 +- allensdk/internal/morphology/morphology.py | 325 ++-- allensdk/internal/morphology/morphvis.py | 314 ++-- allensdk/internal/morphology/node.py | 66 +- allensdk/internal/morphology/validate_swc.py | 71 +- .../interval_unionize/cav_unionize.py | 19 +- .../interval_unionize/cav_unionizer.py | 30 +- .../interval_unionize/data_utilities.py | 16 +- .../interval_unionize/interval_unionizer.py | 184 ++- .../run_tissuecyte_unionize_cav.py | 68 +- .../run_tissuecyte_unionize_classic.py | 119 +- .../tissuecyte_unionize_record.py | 119 +- .../interval_unionize/tissuecyte_unionizer.py | 108 +- .../interval_unionize/unionize_record.py | 26 +- .../generate_projection_strip.py | 42 +- .../projection_thumbnail/image_sheet.py | 12 +- .../projection_functions.py | 3 - .../visualization_utilities.py | 43 +- .../projection_thumbnail/volume_projector.py | 38 +- .../projection_thumbnail/volume_utilities.py | 23 +- .../tissuecyte_stitching/stitcher.py | 64 +- .../tissuecyte_stitching/tile.py | 66 +- .../internal/notebooks/execute_notebooks.py | 80 +- .../IVSCC/ephys_nwb/convert_igor_nwb.py | 12 +- .../IVSCC/ephys_nwb/extract_nwb_data.py | 184 ++- .../ephys_nwb/feature_extraction_module.py | 35 +- .../IVSCC/ephys_nwb/lab_notebook_reader.py | 12 +- .../IVSCC/ephys_nwb/nwb_publish.py | 267 ++-- .../pipeline_modules/IVSCC/ephys_nwb/qc.py | 47 +- .../IVSCC/ephys_nwb/qc_support.py | 69 +- .../IVSCC/ephys_nwb/resource_file.py | 65 +- .../morphology/calculate_features.py | 10 +- .../cell_types/morphology/cortical_layers.py | 173 ++- .../morphology/surrogate_strategy.py | 28 +- .../morphology/upright_transform.py | 158 +- .../gbm/generate_gbm_analysis_run_records.py | 22 +- .../gbm/generate_gbm_heatmap.py | 74 +- .../gbm/generate_gbm_sample_metadata.py | 40 +- .../run_annotated_region_metrics.py | 37 +- .../internal/pipeline_modules/run_demixing.py | 134 +- .../pipeline_modules/run_eye_tracking.py | 88 +- .../run_neuropil_correction.py | 109 +- .../run_observatory_analysis.py | 84 +- .../run_observatory_container_thumbnails.py | 53 +- .../run_observatory_thumbnails.py | 286 ++-- .../run_ophys_eye_calibration.py | 170 +-- .../run_ophys_session_decomposition.py | 87 +- .../pipeline_modules/run_ophys_time_sync.py | 126 +- .../pipeline_modules/run_roi_filter.py | 121 +- ...ssuecyte_projection_thumbnail_from_json.py | 67 +- .../run_tissuecyte_stitching_classic.py | 61 +- .../run_tissuecyte_unionize_cav_from_json.py | 8 +- ...ecyte_unionize_classic_counts_from_json.py | 8 +- ...n_tissuecyte_unionize_classic_from_json.py | 8 +- allensdk/model/__init__.py | 2 +- allensdk/model/biophys_sim/__init__.py | 2 +- allensdk/model/biophys_sim/bps_command.py | 32 +- allensdk/model/biophys_sim/config.py | 52 +- allensdk/model/biophys_sim/neuron/__init__.py | 2 +- .../model/biophys_sim/neuron/hoc_utils.py | 29 +- .../model/biophys_sim/scripts/__init__.py | 2 +- allensdk/model/biophysical/__init__.py | 2 +- allensdk/model/biophysical/run_simulate.py | 59 +- allensdk/model/biophysical/runner.py | 91 +- allensdk/model/biophysical/utils.py | 69 +- allensdk/model/glif/__init__.py | 2 +- allensdk/model/glif/glif_neuron.py | 403 ++--- allensdk/model/glif/glif_neuron_methods.py | 397 ++--- allensdk/model/glif/simulate_neuron.py | 63 +- allensdk/morphology/__init__.py | 2 +- allensdk/morphology/validate_swc.py | 27 +- allensdk/mouse_connectivity/grid/__init__.py | 18 +- allensdk/mouse_connectivity/grid/__main__.py | 123 +- allensdk/mouse_connectivity/grid/_schemas.py | 83 +- .../grid/image_series_gridder.py | 134 +- .../grid/subimage/__init__.py | 17 +- .../grid/subimage/base_subimage.py | 196 +-- .../grid/subimage/cav_subimage.py | 19 +- .../grid/subimage/classic_subimage.py | 141 +- .../grid/subimage/count_subimage.py | 86 +- .../grid/utilities/downsampling_utilities.py | 39 +- .../grid/utilities/image_utilities.py | 92 +- .../grid/writers/__init__.py | 222 ++- allensdk/test/api/__init__.py | 10 +- allensdk/test/api/cloud_cache/conftest.py | 177 ++- allensdk/test/api/cloud_cache/test_cache.py | 639 ++++---- .../test/api/cloud_cache/test_change_log.py | 293 ++-- .../api/cloud_cache/test_file_attributes.py | 54 +- .../test/api/cloud_cache/test_full_process.py | 230 ++- .../test/api/cloud_cache/test_local_cache.py | 53 +- .../test/api/cloud_cache/test_manifest.py | 139 +- .../api/cloud_cache/test_smart_download.py | 377 +++-- .../cloud_cache/test_static_local_cache.py | 91 +- allensdk/test/api/cloud_cache/test_utils.py | 16 +- .../cloud_cache/test_windows_isilon_paths.py | 55 +- allensdk/test/api/cloud_cache/utils.py | 86 +- allensdk/test/api/queries/test_utils.py | 14 +- .../test_annotated_section_data_set_api.py | 33 +- allensdk/test/api/test_api.py | 96 +- allensdk/test/api/test_biophysical_api.py | 60 +- .../test/api/test_brain_observatory_api.py | 199 +-- allensdk/test/api/test_cache.py | 58 +- allensdk/test/api/test_cacheable.py | 269 ++-- allensdk/test/api/test_caching_utilities.py | 42 +- allensdk/test/api/test_cell_types_api.py | 169 +-- allensdk/test/api/test_file_download.py | 229 ++- allensdk/test/api/test_glif_api.py | 46 +- allensdk/test/api/test_grid_data_api.py | 97 +- allensdk/test/api/test_image_download_api.py | 490 +++--- allensdk/test/api/test_mouse_atlas_api.py | 49 +- .../test/api/test_mouse_connectivity_api.py | 369 ++--- allensdk/test/api/test_ontologies_api.py | 66 +- allensdk/test/api/test_pager.py | 282 ++-- allensdk/test/api/test_reference_space_api.py | 248 ++- allensdk/test/api/test_rma_template.py | 241 +-- allensdk/test/api/test_svg_api.py | 53 +- allensdk/test/api/test_synchronization_api.py | 47 +- allensdk/test/api/test_tree_search_api.py | 45 +- .../behavior_project_cache/conftest.py | 6 +- .../tables/test_ophys_mixin.py | 63 +- ..._behavior_neuropixels_project_cloud_api.py | 56 +- .../test_behavior_project_cloud_api.py | 43 +- .../test_behavior_project_lims_api.py | 125 +- .../test_behavior_project_metadata_writer.py | 191 +-- .../test_experiments_table_utils.py | 45 +- .../test_metadata_parsers.py | 11 +- .../test_natural_movie_cache.py | 8 +- .../test_pandas_compat.py | 32 +- .../test_vbn_from_s3.py | 204 +-- .../test_vbo_from_s3.py | 205 +-- .../behavior/behavior_project_cache/utils.py | 94 +- .../conftest.py | 525 +++---- .../test_behavior_project_cache.py | 77 +- .../brain_observatory/behavior/conftest.py | 69 +- .../behavior/data_files/conftest.py | 52 +- .../test_eye_tracking_metadata_file.py | 21 +- .../behavior/data_files/test_stimulus_file.py | 16 +- .../behavior/data_files/test_sync_file.py | 45 +- .../data_objects/base/test_data_object.py | 37 +- .../eye_tracking/test_eye_tracking_table.py | 175 ++- .../eye_tracking/test_eye_tracking_utils.py | 47 +- .../eye_tracking/test_rig_geometry.py | 63 +- .../behavior/data_objects/lims_util.py | 15 +- .../test_behavior_metadata.py | 134 +- .../metadata/test_behavior_ophys_metadata.py | 53 +- .../behavior/data_objects/nwb_input_json.py | 8 +- .../data_objects/running_speed/conftest.py | 329 ++-- .../test_multi_stim_running_processing.py | 196 ++- .../running_speed/test_running_acquisition.py | 105 +- .../running_speed/test_running_processing.py | 268 ++-- .../running_speed/test_running_speed.py | 165 +- .../test_running_speed_from_multi_stim.py | 99 +- .../test_stimulus_timestamps.py | 240 ++- .../test_timestamps_processing.py | 149 +- .../data_objects/test_cell_specimens.py | 130 +- .../behavior/data_objects/test_licks.py | 171 +-- .../data_objects/test_motion_correction.py | 81 +- .../data_objects/test_ophys_timestamps.py | 69 +- .../behavior/data_objects/test_projections.py | 62 +- .../behavior/data_objects/test_rewards.py | 78 +- .../behavior/data_objects/test_stimuli.py | 78 +- .../test_stimulus_presentations.py | 139 +- .../data_objects/test_task_parameters.py | 33 +- .../behavior/data_objects/test_templates.py | 15 +- .../behavior/data_objects/test_trial_obj.py | 263 ++-- .../behavior/data_objects/test_trial_table.py | 147 +- .../behavior/test_behavior_metadata_legacy.py | 416 +++--- .../test_behavior_ophys_experiment.py | 50 +- .../behavior/test_behavior_session.py | 89 +- .../behavior/test_criteria.py | 718 +++++++-- .../brain_observatory/behavior/test_dprime.py | 174 +-- .../behavior/test_event_detection.py | 3 +- .../behavior/test_eye_tracking_processing.py | 375 +++-- .../behavior/test_incomplete_data_objects.py | 68 +- .../behavior/test_mtrain_annotate.py | 9 +- .../test_prior_exposure_count_processing.py | 71 +- .../behavior/test_rewards_processing.py | 33 +- .../behavior/test_session_metrics.py | 49 +- .../behavior/test_stimulus_processing.py | 123 +- .../behavior/test_sync_processing.py | 79 +- .../behavior/test_trial_masks.py | 86 +- .../behavior/test_trials_processing.py | 1330 +++++++++++++---- allensdk/test/brain_observatory/conftest.py | 40 +- .../metadata_utils/conftest.py | 32 +- .../metadata_utils/test_id_generator.py | 24 +- .../metadata_utils/test_utils.py | 72 +- .../test_align_timestamps_module.py | 73 +- .../ecephys/align_timestamps/test_barcode.py | 18 +- .../test_barcode_sync_dataset.py | 11 +- .../align_timestamps/test_channel_states.py | 6 +- .../test_probe_synchronizer.py | 9 +- .../brain_observatory/ecephys/conftest.py | 4 +- .../test_ecephys_nwb1_session_api.py | 48 +- .../ecephys/stimulus_analysis/conftest.py | 80 +- .../stimulus_analysis/test_dot_motion.py | 112 +- .../test_drifting_gratings.py | 348 +++-- .../ecephys/stimulus_analysis/test_flashes.py | 92 +- .../stimulus_analysis/test_natural_movies.py | 30 +- .../stimulus_analysis/test_natural_scenes.py | 106 +- .../test_receptive_field_mapping.py | 52 +- .../stimulus_analysis/test_static_gratings.py | 215 ++- .../test_stimulus_analysis.py | 181 +-- .../stimulus_table/test_ephys_pre_spikes.py | 35 +- .../stimulus_table/test_naming_utilities.py | 16 +- .../test_stimulus_parameter_extraction.py | 2 - .../test_stimulus_table_module.py | 7 +- .../ecephys/test_behavior_ecephys_metadata.py | 43 +- .../ecephys/test_behavior_ecephys_session.py | 76 +- .../ecephys/test_copy_utility.py | 134 +- .../ecephys/test_current_source_density.py | 422 +++--- .../ecephys/test_ecephys_project_cache.py | 351 ++--- .../ecephys/test_ecephys_project_fixed_api.py | 2 +- .../ecephys/test_ecephys_project_lims_api.py | 316 ++-- .../test_ecephys_project_warehouse_api.py | 13 +- .../ecephys/test_ecephys_session.py | 472 +++--- .../ecephys/test_ecephys_session_nwb_api.py | 37 +- .../ecephys/test_ecephys_sync_dataset.py | 64 +- .../ecephys/test_ecephys_utils.py | 19 +- .../ecephys/test_http_engine.py | 47 +- .../ecephys/test_lfp_subsampling.py | 75 +- .../ecephys/test_optotagging_table.py | 19 +- .../brain_observatory/ecephys/test_probes.py | 63 +- .../ecephys/test_rma_engine.py | 24 +- .../ecephys/test_stim_file.py | 53 +- .../ecephys/test_stimulus_sync.py | 242 +-- .../ecephys/test_visualization.py | 22 +- .../ecephys/test_write_nwb.py | 254 +--- .../test_extract_running_speed_module.py | 50 +- .../gaze_mapping/test_gaze_mapping.py | 580 +++---- .../gaze_mapping/test_main.py | 194 +-- .../test_multi_stimulus_running_speed.py | 116 +- .../test/brain_observatory/nwb/conftest.py | 4 +- .../test/brain_observatory/nwb/test_nwb.py | 39 +- .../brain_observatory/nwb/test_nwb_api.py | 4 +- .../brain_observatory/nwb/test_nwb_utils.py | 8 +- .../test_chisquarerf.py | 93 +- .../test_fitgaussian2D.py | 166 +- .../sync_utilities/conftest.py | 77 +- .../test_sync_stim_alignment.py | 137 +- .../test_sync_stim_get_start_frames.py | 261 ++-- .../sync_utilities/test_sync_utilities.py | 168 +-- .../brain_observatory/test_circle_plots.py | 53 +- .../test/brain_observatory/test_demixer.py | 55 +- allensdk/test/brain_observatory/test_dff.py | 45 +- .../test_drifting_gratings.py | 82 +- .../test_locally_sparse_noise.py | 105 +- allensdk/test/brain_observatory/test_mouse.py | 67 +- .../brain_observatory/test_natural_movie.py | 97 +- .../brain_observatory/test_natural_scenes.py | 82 +- .../test/brain_observatory/test_notebook.py | 224 +-- .../test_observatory_plots.py | 169 ++- .../test/brain_observatory/test_roi_masks.py | 59 +- .../test_session_analysis.py | 71 +- .../test_session_analysis_regression.py | 167 +-- .../test_session_api_utils.py | 585 +++++--- .../brain_observatory/test_static_gratings.py | 82 +- .../test_stimulus_analysis.py | 65 +- .../brain_observatory/test_stimulus_info.py | 468 +++--- .../input_json_writer/test_json_writer_cli.py | 43 +- .../vbn_2022/metadata_writer/conftest.py | 76 +- .../vbn_2022/metadata_writer/test_cli.py | 119 +- .../test_dataframe_manipulations.py | 30 +- ...t_vbn_2022_metadata_writer_lims_queries.py | 100 +- .../test_vbn_prior_omissions.py | 173 +-- .../config/test_config_single_file_json.py | 16 +- allensdk/test/config/test_json_comments.py | 157 +- allensdk/test/config/test_manifest.py | 46 +- .../test/config/test_multi_file_config.py | 120 +- allensdk/test/config/test_pyconfig_parser.py | 126 +- allensdk/test/conftest.py | 17 +- allensdk/test/core/test_authentication.py | 62 +- .../test/core/test_brain_observatory_cache.py | 197 ++- .../test_brain_observatory_nwb_data_set.py | 118 +- allensdk/test/core/test_cell_filters.py | 308 ++-- .../test/core/test_cell_types_cache_unit.py | 543 +++---- allensdk/test/core/test_datafame_utils.py | 12 +- allensdk/test/core/test_h5_utilities.py | 54 +- allensdk/test/core/test_json_utilities.py | 34 +- allensdk/test/core/test_lazy_property.py | 13 +- .../core/test_mouse_connectivity_cache.py | 501 ++++--- .../core/test_mouse_connectivity_notebook.py | 120 +- allensdk/test/core/test_nwb_data_set.py | 125 +- allensdk/test/core/test_obj_utilities.py | 26 +- allensdk/test/core/test_pickle_utils.py | 143 +- allensdk/test/core/test_reference_space.py | 142 +- .../test/core/test_reference_space_cache.py | 129 +- .../core/test_reference_space_notebook.py | 64 +- allensdk/test/core/test_simple_tree.py | 195 ++- allensdk/test/core/test_sitk_utilities.py | 81 +- allensdk/test/core/test_structure_tree.py | 292 ++-- allensdk/test/ephys/test_extractor.py | 23 +- allensdk/test/ephys/test_features.py | 6 +- allensdk/test/glif_tests.py | 45 +- .../test/internal/api/test_api_prerelease.py | 4 +- .../api/test_grid_data_api_prerelease.py | 51 +- .../test_mouse_connectivity_api_prerelease.py | 96 +- .../test/internal/api/test_pre_release.py | 82 +- .../test/internal/biophysical/conftest.py | 2 +- .../internal/biophysical/test_ephys_utils.py | 17 +- .../internal/biophysical/test_optimize_run.py | 72 +- .../internal/biophysical/test_simulate_run.py | 46 +- .../test_roi_filter_utils.py | 42 +- .../test_run_ophys_time_sync.py | 82 +- .../brain_observatory/test_time_sync.py | 456 +++--- allensdk/test/internal/conftest.py | 6 +- ...est_mouse_connectivity_cache_prerelease.py | 141 +- .../internal/gbm/test_generate_gbm_heatmap.py | 192 ++- .../internal/morphology/test_apply_affine.py | 26 +- .../test_interval_unionizer.py | 109 +- .../test_projection_functions.py | 12 +- .../test_visualization_utilities.py | 15 +- .../test_volume_projector.py | 37 +- .../test_volume_utilities.py | 37 +- .../test_tissuecyte_unionize_record.py | 180 ++- .../test_unionize_record.py | 9 +- .../internal/test_annotated_region_metrics.py | 40 +- .../test/internal/test_biophysical_modules.py | 49 +- .../internal/test_core_feature_extract.py | 208 ++- .../test/internal/test_eye_calibration.py | 115 +- allensdk/test/internal/test_internal.py | 9 +- allensdk/test/internal/test_mtrain_api.py | 164 +- .../internal/test_optimize_config_reader.py | 62 +- .../test/internal/test_optimize_manifest.py | 112 +- allensdk/test/internal/test_roi_filter.py | 215 ++- .../test/internal/test_simulate_manifest.py | 120 +- .../internal/test_simulate_update_output.py | 60 +- .../tissuecyte_stitching/test_stitcher.py | 74 +- .../tissuecyte_stitching/test_tile.py | 65 +- .../aa_model/test_biophysical_all_active.py | 21 +- allensdk/test/model/check_parser.py | 4 +- .../model/peri_model/test_biophysical_peri.py | 21 +- .../model/test_biophysical_perisomatic.py | 27 +- allensdk/test/model/test_glif.py | 36 +- allensdk/test/model/test_runner.py | 11 +- .../grid/test_base_subimage.py | 33 +- .../grid/test_cav_subimage.py | 33 +- .../grid/test_classic_subimage.py | 131 +- .../grid/test_image_series_gridder.py | 218 +-- .../grid/test_image_utilities.py | 17 +- allensdk/test/test_argschema_utilities.py | 26 +- allensdk/test/test_deprecated.py | 26 +- allensdk/test/test_inline_examples.py | 18 +- allensdk/test/test_lims_queries.py | 119 +- allensdk/test/test_temp_dir.py | 14 +- allensdk/test_utilities/__init__.py | 2 +- allensdk/test_utilities/custom_comparators.py | 73 +- allensdk/test_utilities/regression_fixture.py | 13 +- allensdk/test_utilities/temp_dir.py | 11 +- conftest.py | 51 +- doc_template/conf.py | 191 ++- .../examples_root/examples/biophysical_ex1.py | 8 +- .../biophysical_sim/biophysical_sim.ipynb | 113 +- .../biophysical_sim/biophysical_sim.py | 89 +- .../examples_root/examples/cell_types_ex.py | 53 +- .../examples_root/examples/connectivity_ex.py | 17 +- .../examples/data_api_client_ex.py | 153 +- .../examples_root/examples/glif_ex.py | 137 +- .../Lims Behavior Project Cache.ipynb | 20 +- ...ct cache to access eye tracking data.ipynb | 37 +- .../examples_root/examples/multicell/multi.py | 12 +- .../examples/multicell/multicell_diff.py | 19 +- .../examples_root/examples/multicell/utils.py | 44 +- ..._with_the_stimulus_and_trials_tables.ipynb | 190 +-- .../updated_vs_legacy_api.ipynb | 2 +- .../examples/nb/behavior_ophys_session.ipynb | 29 +- .../examples/nb/brain_observatory.ipynb | 161 +- .../nb/brain_observatory_analysis.ipynb | 69 +- .../nb/brain_observatory_monitor.ipynb | 133 +- .../nb/brain_observatory_stimuli.ipynb | 77 +- .../examples/nb/cell_specimen_mapping.ipynb | 37 +- .../examples/nb/cell_types.ipynb | 134 +- .../examples/nb/download_data_via_api.ipynb | 185 +-- .../examples/nb/ecephys_data_access.ipynb | 111 +- .../examples/nb/ecephys_lfp_analysis.ipynb | 139 +- .../examples/nb/ecephys_optotagging.ipynb | 163 +- .../examples/nb/ecephys_quality_metrics.ipynb | 227 +-- .../examples/nb/ecephys_quickstart.ipynb | 33 +- .../nb/ecephys_receptive_fields.ipynb | 34 +- .../examples/nb/ecephys_session.ipynb | 124 +- .../experiment_detail_example.ipynb | 90 +- .../examples/nb/image_download.ipynb | 61 +- .../examples/nb/mouse_connectivity.ipynb | 85 +- .../examples/nb/neuron/pulse_stimulus.ipynb | 37 +- .../examples/nb/receptive_fields.ipynb | 46 +- .../examples/nb/reference_space.ipynb | 60 +- .../experiment_detail_example.ipynb | 82 +- ..._behavior_compare_across_trial_types.ipynb | 225 +-- .../nb/visual_behavior_mouse_history.ipynb | 139 +- ...al_behavior_neuropixels_LFP_analysis.ipynb | 203 +-- ...opixels_analyzing_behavior_only_data.ipynb | 119 +- ...ual_behavior_neuropixels_data_access.ipynb | 32 +- ...ehavior_neuropixels_dataset_manifest.ipynb | 112 +- ...behavior_neuropixels_quality_metrics.ipynb | 276 ++-- ...sual_behavior_neuropixels_quickstart.ipynb | 197 ++- .../visual_behavior_ophys_data_access.ipynb | 33 +- ...sual_behavior_ophys_dataset_manifest.ipynb | 291 ++-- .../examples_root/examples/simple/simple.py | 10 +- .../examples_root/examples/simple/utils.py | 34 +- pyproject.toml | 3 + .../brain_observatory/create_input_json.py | 291 ++-- ...deploy_visual_coding_ophys_eye_tracking.py | 60 +- .../run_ecephys_nwb_packaging.py | 86 +- .../run_ecephys_stimulus_analysis.py | 116 +- .../brain_observatory/run_session_analysis.py | 4 +- 854 files changed, 43292 insertions(+), 49144 deletions(-) diff --git a/.flake8 b/.flake8 index 1a6365b8dd..a5929244c9 100644 --- a/.flake8 +++ b/.flake8 @@ -2,8 +2,11 @@ max-line-length = 120 extend-ignore = E203, + E501, E402, E741, F403, F405, + W291, + W293, W503 diff --git a/allensdk/__init__.py b/allensdk/__init__.py index 88236aa8c6..9fed4c3ab8 100644 --- a/allensdk/__init__.py +++ b/allensdk/__init__.py @@ -54,8 +54,7 @@ def one(x): except TypeError: return x if xlen != 1: - raise OneResultExpectedError("Expected length one result, received: " - f"{x} results from query") + raise OneResultExpectedError(f"Expected length one result, received: {x} results from query") if isinstance(x, set): return list(x)[0] else: @@ -65,11 +64,9 @@ def one(x): logging.getLogger(__name__).addHandler(logging.NullHandler()) if True: - file_download_log = logging.getLogger( - 'allensdk.api.api.retrieve_file_over_http') + file_download_log = logging.getLogger("allensdk.api.api.retrieve_file_over_http") file_download_log.setLevel(logging.INFO) console = logging.StreamHandler() - formatter = logging.Formatter("%(asctime)s %(name)-12s " - "%(levelname)-8s %(message)s") + formatter = logging.Formatter("%(asctime)s %(name)-12s %(levelname)-8s %(message)s") console.setFormatter(formatter) file_download_log.addHandler(console) diff --git a/allensdk/api/__init__.py b/allensdk/api/__init__.py index 361bcf102b..c932ddebba 100644 --- a/allensdk/api/__init__.py +++ b/allensdk/api/__init__.py @@ -33,6 +33,6 @@ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. # -''' Subclasses of allensdk.api.api.Api to implement specific queries to the +"""Subclasses of allensdk.api.api.Api to implement specific queries to the `Allen Brain Atlas Data Portal `_. -''' +""" diff --git a/allensdk/api/api.py b/allensdk/api/api.py index 4c391e9c37..82ddbb0119 100644 --- a/allensdk/api/api.py +++ b/allensdk/api/api.py @@ -51,10 +51,10 @@ class Api(object): - _log = logging.getLogger('allensdk.api.api') - _file_download_log = logging.getLogger('allensdk.api.api.retrieve_file_over_http') - default_api_url = 'http://api.brain-map.org' - download_url = 'http://download.alleninstitute.org' + _log = logging.getLogger("allensdk.api.api") + _file_download_log = logging.getLogger("allensdk.api.api.retrieve_file_over_http") + default_api_url = "http://api.brain-map.org" + download_url = "http://download.alleninstitute.org" def __init__(self, api_base_url_string=None): if api_base_url_string is None: @@ -64,73 +64,67 @@ def __init__(self, api_base_url_string=None): self.default_working_directory = os.getcwd() def set_api_urls(self, api_base_url_string): - '''Set the internal RMA and well known file download endpoint urls + """Set the internal RMA and well known file download endpoint urls based on a api server endpoint. Parameters ---------- api_base_url_string : string url of the api to point to - ''' + """ self.api_url = api_base_url_string # http://help.brain-map.org/display/api/Downloading+a+WellKnownFile - self.well_known_file_endpoint = api_base_url_string + \ - '/api/v2/well_known_file_download' + self.well_known_file_endpoint = api_base_url_string + "/api/v2/well_known_file_download" # http://help.brain-map.org/display/api/Downloading+3-D+Expression+Grid+Data - self.grid_data_endpoint = api_base_url_string + '/grid_data' + self.grid_data_endpoint = api_base_url_string + "/grid_data" # http://help.brain-map.org/display/api/Downloading+and+Displaying+SVG - self.svg_endpoint = api_base_url_string + '/api/v2/svg' - self.svg_download_endpoint = api_base_url_string + '/api/v2/svg_download' + self.svg_endpoint = api_base_url_string + "/api/v2/svg" + self.svg_download_endpoint = api_base_url_string + "/api/v2/svg_download" # http://help.brain-map.org/display/api/Downloading+an+Ontology%27s+Structure+Graph - self.structure_graph_endpoint = api_base_url_string + \ - '/api/v2/structure_graph_download' + self.structure_graph_endpoint = api_base_url_string + "/api/v2/structure_graph_download" # http://help.brain-map.org/display/api/Searching+a+Specimen+or+Structure+Tree - self.tree_search_endpoint = api_base_url_string + '/api/v2/tree_search' + self.tree_search_endpoint = api_base_url_string + "/api/v2/tree_search" # http://help.brain-map.org/display/api/Searching+Annotated+SectionDataSets - self.annotated_section_data_sets_endpoint = api_base_url_string + \ - '/api/v2/annotated_section_data_sets' - self.compound_annotated_section_data_sets_endpoint = api_base_url_string + \ - '/api/v2/compound_annotated_section_data_sets' + self.annotated_section_data_sets_endpoint = api_base_url_string + "/api/v2/annotated_section_data_sets" + self.compound_annotated_section_data_sets_endpoint = ( + api_base_url_string + "/api/v2/compound_annotated_section_data_sets" + ) # http://help.brain-map.org/display/api/Image-to-Image+Synchronization#Image-to-ImageSynchronization-ImagetoImage - self.image_to_atlas_endpoint = api_base_url_string + '/api/v2/image_to_atlas' - self.image_to_image_endpoint = api_base_url_string + '/api/v2/image_to_image' - self.image_to_image_2d_endpoint = api_base_url_string + '/api/v2/image_to_image_2d' - self.reference_to_image_endpoint = api_base_url_string + '/api/v2/reference_to_image' - self.image_to_reference_endpoint = api_base_url_string + '/api/v2/image_to_reference' - self.structure_to_image_endpoint = api_base_url_string + '/api/v2/structure_to_image' + self.image_to_atlas_endpoint = api_base_url_string + "/api/v2/image_to_atlas" + self.image_to_image_endpoint = api_base_url_string + "/api/v2/image_to_image" + self.image_to_image_2d_endpoint = api_base_url_string + "/api/v2/image_to_image_2d" + self.reference_to_image_endpoint = api_base_url_string + "/api/v2/reference_to_image" + self.image_to_reference_endpoint = api_base_url_string + "/api/v2/image_to_reference" + self.structure_to_image_endpoint = api_base_url_string + "/api/v2/structure_to_image" # http://help.brain-map.org/display/mouseconnectivity/API - self.section_image_download_endpoint = api_base_url_string + \ - '/api/v2/section_image_download' - self.atlas_image_download_endpoint = api_base_url_string + \ - '/api/v2/atlas_image_download' - self.projection_image_download_endpoint = api_base_url_string + \ - '/api/v2/projection_image_download' - self.image_download_endpoint = api_base_url_string + \ - '/api/v2/image_download' - self.informatics_archive_endpoint = Api.download_url + '/informatics-archive' - - self.rma_endpoint = api_base_url_string + '/api/v2/data' + self.section_image_download_endpoint = api_base_url_string + "/api/v2/section_image_download" + self.atlas_image_download_endpoint = api_base_url_string + "/api/v2/atlas_image_download" + self.projection_image_download_endpoint = api_base_url_string + "/api/v2/projection_image_download" + self.image_download_endpoint = api_base_url_string + "/api/v2/image_download" + self.informatics_archive_endpoint = Api.download_url + "/informatics-archive" + + self.rma_endpoint = api_base_url_string + "/api/v2/data" def set_default_working_directory(self, working_directory): - '''Set the working directory where files will be saved. + """Set the working directory where files will be saved. Parameters ---------- working_directory : string the absolute path string of the working directory. - ''' + """ self.default_working_directory = working_directory def read_data(self, parsed_json): - '''Return the message data from the parsed query. + """Return the message data from the parsed query. Parameters ---------- @@ -141,11 +135,11 @@ def read_data(self, parsed_json): ----- See `API Response Formats - Response Envelope `_ for additional documentation. - ''' - return parsed_json['msg'] + """ + return parsed_json["msg"] def json_msg_query(self, url, dataframe=False): - ''' Common case where the url is fully constructed + """Common case where the url is fully constructed and the response data is stored in the 'msg' field. Parameters @@ -159,10 +153,9 @@ def json_msg_query(self, url, dataframe=False): ------- dict or DataFrame returned data; type depends on dataframe option - ''' + """ - data = self.do_query(lambda *a, **k: url, - self.read_data) + data = self.do_query(lambda *a, **k: url, self.read_data) if dataframe is True: warnings.warn("dataframe argument is deprecated", DeprecationWarning) @@ -171,7 +164,7 @@ def json_msg_query(self, url, dataframe=False): return data def do_query(self, url_builder_fn, json_traversal_fn, *args, **kwargs): - '''Bundle an query url construction function + """Bundle an query url construction function with a corresponding response json traversal function. Parameters @@ -196,17 +189,17 @@ def do_query(self, url_builder_fn, json_traversal_fn, *args, **kwargs): -------- `A simple Api subclass example `_. - ''' + """ api_url = url_builder_fn(*args, **kwargs) - post = kwargs.get('post', False) + post = kwargs.get("post", False) json_parsed_data = self.retrieve_parsed_json_over_http(api_url, post) return json_traversal_fn(json_parsed_data) def do_rma_query(self, rma_builder_fn, json_traversal_fn, *args, **kwargs): - '''Bundle an RMA query url construction function + """Bundle an RMA query url construction function with a corresponding response json traversal function. ..note:: Deprecated in AllenSDK 0.9.2 @@ -233,11 +226,11 @@ def do_rma_query(self, rma_builder_fn, json_traversal_fn, *args, **kwargs): -------- `A simple Api subclass example `_. - ''' + """ return self.do_query(rma_builder_fn, json_traversal_fn, *args, **kwargs) def load_api_schema(self): - '''Download the RMA schema from the current RMA endpoint + """Download the RMA schema from the current RMA endpoint Returns ------- @@ -252,15 +245,14 @@ def load_api_schema(self): `Class Hierarchy `_ and `Class List `_. - ''' - schema_url = self.rma_endpoint + '/enumerate.json' - json_parsed_schema_data = self.retrieve_parsed_json_over_http( - schema_url) + """ + schema_url = self.rma_endpoint + "/enumerate.json" + json_parsed_schema_data = self.retrieve_parsed_json_over_http(schema_url) return json_parsed_schema_data def construct_well_known_file_download_url(self, well_known_file_id): - '''Join data api endpoint and id. + """Join data api endpoint and id. Parameters ---------- @@ -275,16 +267,16 @@ def construct_well_known_file_download_url(self, well_known_file_id): See Also -------- retrieve_file_over_http: Can be used to retrieve the file from the url. - ''' - return self.well_known_file_endpoint + '/' + str(well_known_file_id) + """ + return self.well_known_file_endpoint + "/" + str(well_known_file_id) def cleanup_truncated_file(self, file_path): - '''Helper for removing files. + """Helper for removing files. Parameters ---------- file_path : string - Absolute path including the file name to remove.''' + Absolute path including the file name to remove.""" try: os.remove(file_path) except OSError as e: @@ -292,7 +284,7 @@ def cleanup_truncated_file(self, file_path): raise def retrieve_file_over_http(self, url, file_path, zipped=False): - '''Get a file from the data api and save it. + """Get a file from the data api and save it. Parameters ---------- @@ -301,8 +293,8 @@ def retrieve_file_over_http(self, url, file_path, zipped=False): file_path : string Absolute path including the file name to save. zipped : bool, optional - If true, assume that the response is a zipped directory and attempt - to extract contained files into the directory containing file_path. + If true, assume that the response is a zipped directory and attempt + to extract contained files into the directory containing file_path. Default is False. See Also @@ -312,7 +304,7 @@ def retrieve_file_over_http(self, url, file_path, zipped=False): References ---------- .. [1] Allen Brain Atlas Data Portal: `Downloading a WellKnownFile `_. - ''' + """ self._file_download_log.info("Downloading URL: %s", url) @@ -323,22 +315,22 @@ def retrieve_file_over_http(self, url, file_path, zipped=False): stream_file_over_http(url, file_path) except exceptions.StreamingError: - self._file_download_log.error("Couldn't retrieve file %s from %s (streaming)." % (file_path,url)) + self._file_download_log.error("Couldn't retrieve file %s from %s (streaming)." % (file_path, url)) self.cleanup_truncated_file(file_path) raise except requests.exceptions.ConnectionError: - self._file_download_log.error("Couldn't retrieve file %s from %s (connection)." % (file_path,url)) + self._file_download_log.error("Couldn't retrieve file %s from %s (connection)." % (file_path, url)) self.cleanup_truncated_file(file_path) raise except requests.exceptions.ReadTimeout: - self._file_download_log.error("Couldn't retrieve file %s from %s (timeout)." % (file_path,url)) + self._file_download_log.error("Couldn't retrieve file %s from %s (timeout)." % (file_path, url)) self.cleanup_truncated_file(file_path) raise except requests.exceptions.RequestException: - self._file_download_log.error("Couldn't retrieve file %s from %s (request)." % (file_path,url)) + self._file_download_log.error("Couldn't retrieve file %s from %s (request)." % (file_path, url)) self.cleanup_truncated_file(file_path) raise @@ -347,9 +339,8 @@ def retrieve_file_over_http(self, url, file_path, zipped=False): self.cleanup_truncated_file(file_path) raise - def retrieve_parsed_json_over_http(self, url, post=False): - '''Get the document and put it in a Python data structure + """Get the document and put it in a Python data structure Parameters ---------- @@ -362,20 +353,18 @@ def retrieve_parsed_json_over_http(self, url, post=False): ------- dict Result document as parsed by the JSON library. - ''' + """ self._log.info("Downloading URL: %s", url) - + if post is False: - data = json_utilities.read_url_get( - requests.utils.quote(url, - ';/?:@&=+$,')) + data = json_utilities.read_url_get(requests.utils.quote(url, ";/?:@&=+$,")) else: data = json_utilities.read_url_post(url) return data def retrieve_xml_over_http(self, url): - '''Get the document and put it in a Python data structure + """Get the document and put it in a Python data structure Parameters ---------- @@ -386,16 +375,16 @@ def retrieve_xml_over_http(self, url): ------- string Unparsed xml string. - ''' + """ self._log.info("Downloading URL: %s", url) - + response = requests.get(url) return response.content def stream_zip_directory_over_http(url, directory, members=None, timeout=(9.05, 31.1)): - ''' Supply an http get request and stream the response to a file. + """Supply an http get request and stream the response to a file. Parameters ---------- @@ -406,15 +395,15 @@ def stream_zip_directory_over_http(url, directory, members=None, timeout=(9.05, members : list of str, optional Extract only these files timeout : float or tuple of float, optional - Specify a timeout for the request. If a tuple, specify seperate connect + Specify a timeout for the request. If a tuple, specify seperate connect and read timeouts. - ''' + """ buf = io.BytesIO() - with closing( requests.get(url, stream=True, timeout=timeout) ) as request: - stream.stream_response_to_file( request, buf ) + with closing(requests.get(url, stream=True, timeout=timeout)) as request: + stream.stream_response_to_file(request, buf) zipper = zipfile.ZipFile(buf) zipper.extractall(path=directory, members=members) @@ -422,7 +411,7 @@ def stream_zip_directory_over_http(url, directory, members=None, timeout=(9.05, def stream_file_over_http(url, file_path, timeout=(9.05, 31.1)): - ''' Supply an http get request and stream the response to a file. + """Supply an http get request and stream the response to a file. Parameters ---------- @@ -431,13 +420,12 @@ def stream_file_over_http(url, file_path, timeout=(9.05, 31.1)): file_path : str Stream the response to this path timeout : float or tuple of float, optional - Specify a timeout for the request. If a tuple, specify seperate connect + Specify a timeout for the request. If a tuple, specify seperate connect and read timeouts. - ''' + """ with closing(requests.get(url, stream=True, timeout=timeout)) as response: - response.raise_for_status() - with open(file_path, 'wb') as fil: + with open(file_path, "wb") as fil: stream.stream_response_to_file(response, path=fil) diff --git a/allensdk/api/cloud_cache/cloud_cache.py b/allensdk/api/cloud_cache/cloud_cache.py index 0724e85286..344c6006b3 100644 --- a/allensdk/api/cloud_cache/cloud_cache.py +++ b/allensdk/api/cloud_cache/cloud_cache.py @@ -54,12 +54,7 @@ class BasicLocalCache(ABC): functionality (used to populate helpful error messages) """ - def __init__( - self, - cache_dir: Union[str, Path], - project_name: str, - ui_class_name: Optional[str] = None - ): + def __init__(self, cache_dir: Union[str, Path], project_name: str, ui_class_name: Optional[str] = None): os.makedirs(cache_dir, exist_ok=True) # the class users are actually interacting with @@ -96,12 +91,11 @@ def project_name(self) -> str: @property def manifest_prefix(self) -> str: """On-line prefix for manifest files""" - return f'{self.project_name}/manifests/' + return f"{self.project_name}/manifests/" @property def file_id_column(self) -> str: - """The col name in metadata files used to uniquely identify data files - """ + """The col name in metadata files used to uniquely identify data files""" return self._manifest.file_id_column @property @@ -116,8 +110,7 @@ def metadata_file_names(self) -> list: @property def manifest_file_names(self) -> list: - """Sorted list of manifest file names associated with this dataset - """ + """Sorted list of manifest file names associated with this dataset""" return self._manifest_file_names @property @@ -160,8 +153,7 @@ def list_all_downloaded_manifests(self) -> list: Return a list of all of the manifest files that have been downloaded for this dataset """ - output = [x for x in os.listdir(self._cache_dir) - if re.fullmatch(".*_manifest_v.*.json", x)] + output = [x for x in os.listdir(self._cache_dir) if re.fullmatch(".*_manifest_v.*.json", x)] output.sort() return output @@ -173,17 +165,12 @@ def _find_latest_file(self, file_name_list: List[str]) -> str: and return the one with the latest version """ - vstrs = [s.split(".json")[0].split("_v")[-1] - for s in file_name_list] + vstrs = [s.split(".json")[0].split("_v")[-1] for s in file_name_list] versions = [semver.VersionInfo.parse(v) for v in vstrs] imax = versions.index(max(versions)) return file_name_list[imax] - def _load_manifest( - self, - manifest_name: str, - use_static_project_dir: bool = False - ) -> Manifest: + def _load_manifest(self, manifest_name: str, use_static_project_dir: bool = False) -> Manifest: """ Load and return a manifest from this dataset. @@ -217,17 +204,13 @@ def _load_manifest( ) if use_static_project_dir: - manifest_path = os.path.join( - self._cache_dir, self.project_name, "manifests", manifest_name - ) + manifest_path = os.path.join(self._cache_dir, self.project_name, "manifests", manifest_name) else: manifest_path = os.path.join(self._cache_dir, manifest_name) with open(manifest_path, "r") as f: local_manifest = Manifest( - cache_dir=self._cache_dir, - json_input=f, - use_static_project_dir=use_static_project_dir + cache_dir=self._cache_dir, json_input=f, use_static_project_dir=use_static_project_dir ) return local_manifest @@ -270,9 +253,7 @@ def _file_exists(self, file_attributes: CacheFileAttributes) -> bool: if file_attributes.local_path.exists(): if not file_attributes.local_path.is_file(): - raise RuntimeError(f"{file_attributes.local_path}\n" - "exists, but is not a file;\n" - "unsure how to proceed") + raise RuntimeError(f"{file_attributes.local_path}\nexists, but is not a file;\nunsure how to proceed") file_exists = True @@ -308,9 +289,7 @@ def metadata_path(self, fname: str) -> dict: file_attributes = self._manifest.metadata_file_attributes(fname) exists = self._file_exists(file_attributes) local_path = file_attributes.local_path - output = {'local_path': local_path, - 'exists': exists, - 'file_attributes': file_attributes} + output = {"local_path": local_path, "exists": exists, "file_attributes": file_attributes} return output @@ -344,9 +323,7 @@ def data_path(self, file_id) -> dict: file_attributes = self.get_file_attributes(file_id) exists = self._file_exists(file_attributes) local_path = file_attributes.local_path - output = {'local_path': local_path, - 'exists': exists, - 'file_attributes': file_attributes} + output = {"local_path": local_path, "exists": exists, "file_attributes": file_attributes} return output @@ -388,8 +365,7 @@ class CloudCacheBase(BasicLocalCache): _bucket_name = None def __init__(self, cache_dir, project_name, ui_class_name=None): - super().__init__(cache_dir=cache_dir, project_name=project_name, - ui_class_name=ui_class_name) + super().__init__(cache_dir=cache_dir, project_name=project_name, ui_class_name=ui_class_name) # what latest_manifest was the last time an OutdatedManifestWarning # was emitted @@ -399,43 +375,43 @@ def __init__(self, cache_dir, project_name, ui_class_name=None): # self._manifest_last_used contains the name of the manifest # last loaded from this cache dir (if applicable) - self._manifest_last_used = c_path / '_manifest_last_used.txt' + self._manifest_last_used = c_path / "_manifest_last_used.txt" # self._downloaded_data_path is where we will keep a JSONized # dict mapping paths to downloaded files to their file_hashes; # this will be used when determining if a downloaded file # can instead be a symlink - self._downloaded_data_path = c_path / '_downloaded_data.json' + self._downloaded_data_path = c_path / "_downloaded_data.json" # if the local manifest is missing but there are # data files in cache_dir, emit a warning # suggesting that the user run # self.construct_local_manifest if not self._downloaded_data_path.exists(): - file_list = c_path.glob('**/*') + file_list = c_path.glob("**/*") has_files = False for fname in file_list: if fname.is_file(): - if 'json' not in fname.name: + if "json" not in fname.name: has_files = True break if has_files: - msg = 'This cache directory appears to ' - msg += 'contain data files, but it has no ' - msg += 'record of what those files are. ' - msg += 'You might want to consider running\n\n' - msg += f'{self.ui}.construct_local_manifest()\n\n' - msg += 'to avoid needlessly downloading duplicates ' - msg += 'of data files that did not change between ' - msg += 'data releases. NOTE: running this method ' - msg += 'will require hashing every data file you ' - msg += 'have currently downloaded and could be ' - msg += 'very time consuming.\n\n' - msg += 'To avoid this warning in the future, make ' - msg += 'sure that\n\n' - msg += f'{str(self._downloaded_data_path.resolve())}\n\n' - msg += 'is not deleted between instantiations of this ' - msg += 'cache' + msg = "This cache directory appears to " + msg += "contain data files, but it has no " + msg += "record of what those files are. " + msg += "You might want to consider running\n\n" + msg += f"{self.ui}.construct_local_manifest()\n\n" + msg += "to avoid needlessly downloading duplicates " + msg += "of data files that did not change between " + msg += "data releases. NOTE: running this method " + msg += "will require hashing every data file you " + msg += "have currently downloaded and could be " + msg += "very time consuming.\n\n" + msg += "To avoid this warning in the future, make " + msg += "sure that\n\n" + msg += f"{str(self._downloaded_data_path.resolve())}\n\n" + msg += "is not deleted between instantiations of this " + msg += "cache" warnings.warn(msg, MissingLocalManifestWarning) def construct_local_manifest(self) -> None: @@ -446,22 +422,19 @@ def construct_local_manifest(self) -> None: lookup = {} files_to_hash = set() c_dir = pathlib.Path(self._cache_dir) - file_iterator = c_dir.glob('**/*') + file_iterator = c_dir.glob("**/*") for file_name in file_iterator: if file_name.is_file(): - if 'json' not in file_name.name: + if "json" not in file_name.name: if file_name != self._manifest_last_used: files_to_hash.add(file_name.resolve()) - with tqdm.tqdm(files_to_hash, - total=len(files_to_hash), - unit='(files hashed)') as pbar: - + with tqdm.tqdm(files_to_hash, total=len(files_to_hash), unit="(files hashed)") as pbar: for local_path in pbar: hsh = file_hash_from_path(local_path) lookup[str(local_path.absolute())] = hsh - with open(self._downloaded_data_path, 'w') as out_file: + with open(self._downloaded_data_path, "w") as out_file: out_file.write(json.dumps(lookup, indent=2, sort_keys=True)) def _warn_of_outdated_manifest(self, manifest_name: str) -> None: @@ -474,14 +447,14 @@ def _warn_of_outdated_manifest(self, manifest_name: str) -> None: self._manifest_last_warned_on = self.latest_manifest_file - msg = '\n\n' - msg += 'The manifest file you are loading is not the ' - msg += 'most up to date manifest file available for ' - msg += 'this dataset. The most up to data manifest file ' - msg += 'available for this dataset is \n\n' - msg += f'{self.latest_manifest_file}\n\n' - msg += 'To see the differences between these manifests,' - msg += 'run\n\n' + msg = "\n\n" + msg += "The manifest file you are loading is not the " + msg += "most up to date manifest file available for " + msg += "this dataset. The most up to data manifest file " + msg += "available for this dataset is \n\n" + msg += f"{self.latest_manifest_file}\n\n" + msg += "To see the differences between these manifests," + msg += "run\n\n" msg += f"{self.ui}.compare_manifests('{manifest_name}', " msg += f"'{self.latest_manifest_file}')\n\n" msg += "To see all of the manifest files currently downloaded " @@ -506,7 +479,7 @@ def latest_downloaded_manifest_file(self) -> str: """ file_list = self.list_all_downloaded_manifests() if len(file_list) == 0: - return '' + return "" return self._find_latest_file(self.list_all_downloaded_manifests()) def load_last_manifest(self): @@ -519,17 +492,17 @@ def load_last_manifest(self): self.load_latest_manifest() return None - with open(self._manifest_last_used, 'r') as in_file: + with open(self._manifest_last_used, "r") as in_file: to_load = in_file.read() latest = self.latest_manifest_file if to_load not in self.manifest_file_names: - msg = 'The manifest version recorded as last used ' - msg += f'for this cache -- {to_load}-- ' - msg += 'is not a valid manifest for this dataset. ' - msg += f'Loading latest version -- {latest} -- ' - msg += 'instead.' + msg = "The manifest version recorded as last used " + msg += f"for this cache -- {to_load}-- " + msg += "is not a valid manifest for this dataset. " + msg += f"Loading latest version -- {latest} -- " + msg += "instead." warnings.warn(msg, UserWarning) self.load_latest_manifest() return None @@ -552,23 +525,22 @@ def load_latest_manifest(self): latest_downloaded = self.latest_downloaded_manifest_file latest = self.latest_manifest_file if latest != latest_downloaded: - if latest_downloaded != '': - msg = f'You are loading\n{self.latest_manifest_file}\n' - msg += 'which is newer than the most recent manifest ' - msg += 'file you have previously been working with\n' - msg += f'{latest_downloaded}\n' - msg += 'It is possible that some data files have changed ' - msg += 'between these two data releases, which will ' - msg += 'force you to re-download those data files ' - msg += '(currently downloaded files will not be overwritten).' - msg += f' To continue using {latest_downloaded}, run\n' + if latest_downloaded != "": + msg = f"You are loading\n{self.latest_manifest_file}\n" + msg += "which is newer than the most recent manifest " + msg += "file you have previously been working with\n" + msg += f"{latest_downloaded}\n" + msg += "It is possible that some data files have changed " + msg += "between these two data releases, which will " + msg += "force you to re-download those data files " + msg += "(currently downloaded files will not be overwritten)." + msg += f" To continue using {latest_downloaded}, run\n" msg += f"{self.ui}.load_manifest('{latest_downloaded}')" warnings.warn(msg, OutdatedManifestWarning) self.load_manifest(self.latest_manifest_file) @abstractmethod - def _download_manifest(self, - manifest_name: str): + def _download_manifest(self, manifest_name: str): """ Download a manifest from the dataset @@ -636,14 +608,12 @@ def load_manifest(self, manifest_name: str): self._manifest = self._load_manifest(manifest_name) # Keep track of the newly loaded manifest - with open(self._manifest_last_used, 'w') as out_file: + with open(self._manifest_last_used, "w") as out_file: out_file.write(manifest_name) self._manifest_name = manifest_name - def _update_list_of_downloads(self, - file_attributes: CacheFileAttributes - ) -> None: + def _update_list_of_downloads(self, file_attributes: CacheFileAttributes) -> None: """ Update the local file that keeps track of files that have actually been downloaded to reflect a newly downloaded file. @@ -661,7 +631,7 @@ def _update_list_of_downloads(self, return None if self._downloaded_data_path.exists(): - with open(self._downloaded_data_path, 'rb') as in_file: + with open(self._downloaded_data_path, "rb") as in_file: downloaded_data = json.load(in_file) else: downloaded_data = {} @@ -674,15 +644,11 @@ def _update_list_of_downloads(self, return None downloaded_data[abs_path] = file_attributes.file_hash - with open(self._downloaded_data_path, 'w') as out_file: - out_file.write(json.dumps(downloaded_data, - indent=2, - sort_keys=True)) + with open(self._downloaded_data_path, "w") as out_file: + out_file.write(json.dumps(downloaded_data, indent=2, sort_keys=True)) return None - def _check_for_identical_copy(self, - file_attributes: CacheFileAttributes - ) -> bool: + def _check_for_identical_copy(self, file_attributes: CacheFileAttributes) -> bool: """ Check the manifest of files that have been locally downloaded to see if a file with an identical hash to the requested file has already @@ -704,7 +670,7 @@ def _check_for_identical_copy(self, if not self._downloaded_data_path.exists(): return False - with open(self._downloaded_data_path, 'rb') as in_file: + with open(self._downloaded_data_path, "rb") as in_file: available_files = json.load(in_file) matched_path = None @@ -756,9 +722,7 @@ def _file_exists(self, file_attributes: CacheFileAttributes) -> bool: if file_attributes.local_path.exists(): if not file_attributes.local_path.is_file(): - raise RuntimeError(f"{file_attributes.local_path}\n" - "exists, but is not a file;\n" - "unsure how to proceed") + raise RuntimeError(f"{file_attributes.local_path}\nexists, but is not a file;\nunsure how to proceed") file_exists = True @@ -789,7 +753,7 @@ def download_data(self, file_id) -> pathlib.Path: If the file cannot be downloaded """ super_attributes = self.data_path(file_id) - file_attributes = super_attributes['file_attributes'] + file_attributes = super_attributes["file_attributes"] was_downloaded = self._download_file(file_attributes) if was_downloaded: self._update_list_of_downloads(file_attributes) @@ -817,7 +781,7 @@ def download_metadata(self, fname: str) -> pathlib.Path: If the file cannot be downloaded """ super_attributes = self.metadata_path(fname) - file_attributes = super_attributes['file_attributes'] + file_attributes = super_attributes["file_attributes"] was_downloaded = self._download_file(file_attributes) if was_downloaded: self._update_list_of_downloads(file_attributes) @@ -845,8 +809,7 @@ def get_metadata(self, fname: str) -> pd.DataFrame: local_path = self.download_metadata(fname) return pd.read_csv(local_path) - def _detect_changes(self, - filename_to_hash: dict) -> List[Tuple[str, str]]: + def _detect_changes(self, filename_to_hash: dict) -> List[Tuple[str, str]]: """ Assemble list of changes between two manifests @@ -893,18 +856,18 @@ def _detect_changes(self, h0 = filename_to_hash[0][fname] h1 = filename_to_hash[1][fname] if h0 != h1: - delta = f'{fname} changed' + delta = f"{fname} changed" elif fname in filename_to_hash[0]: h0 = filename_to_hash[0][fname] if h0 in hash_to_filename[1]: f1 = hash_to_filename[1][h0] - delta = f'{fname} renamed {f1}' + delta = f"{fname} renamed {f1}" else: - delta = f'{fname} deleted' + delta = f"{fname} deleted" elif fname in filename_to_hash[1]: h1 = filename_to_hash[1][fname] if h1 not in hash_to_filename[0]: - delta = f'{fname} created' + delta = f"{fname} created" else: raise RuntimeError("should never reach this line") @@ -913,10 +876,7 @@ def _detect_changes(self, return output - def summarize_comparison(self, - manifest_0_name: str, - manifest_1_name: str - ) -> Dict[str, List[Tuple[str, str]]]: + def summarize_comparison(self, manifest_0_name: str, manifest_1_name: str) -> Dict[str, List[Tuple[str, str]]]: """ Compare two manifests from this dataset. Return a dict containing the list of metadata and data files that changed @@ -958,34 +918,27 @@ def summarize_comparison(self, man1 = self._load_manifest(manifest_1_name) result: dict = dict() - for (result_key, - file_id_list, - attr_lookup) in zip(('metadata_changes', 'data_changes'), - ((man0.metadata_file_names, - man1.metadata_file_names), - (man0.file_id_values, - man1.file_id_values)), - ((man0.metadata_file_attributes, - man1.metadata_file_attributes), - (man0.data_file_attributes, - man1.data_file_attributes))): - + for result_key, file_id_list, attr_lookup in zip( + ("metadata_changes", "data_changes"), + ((man0.metadata_file_names, man1.metadata_file_names), (man0.file_id_values, man1.file_id_values)), + ( + (man0.metadata_file_attributes, man1.metadata_file_attributes), + (man0.data_file_attributes, man1.data_file_attributes), + ), + ): filename_to_hash: dict = dict() for version in (0, 1): filename_to_hash[version] = {} for file_id in file_id_list[version]: obj = attr_lookup[version](file_id) file_name = relative_path_from_url(obj.url) - file_name = '/'.join(file_name.split('/')[1:]) + file_name = "/".join(file_name.split("/")[1:]) filename_to_hash[version][file_name] = obj.file_hash changes = self._detect_changes(filename_to_hash) result[result_key] = changes return result - def compare_manifests(self, - manifest_0_name: str, - manifest_1_name: str - ) -> str: + def compare_manifests(self, manifest_0_name: str, manifest_1_name: str) -> str: """ Compare two manifests from this dataset. Return a dict containing the list of metadata and data files that changed @@ -1006,32 +959,31 @@ def compare_manifests(self, manifest_0 to manifest_1 """ - changes = self.summarize_comparison(manifest_0_name, - manifest_1_name) - if len(changes['data_changes']) == 0: - if len(changes['metadata_changes']) == 0: + changes = self.summarize_comparison(manifest_0_name, manifest_1_name) + if len(changes["data_changes"]) == 0: + if len(changes["metadata_changes"]) == 0: return "The two manifests are equivalent" data_change_dict = {} - for delta in changes['data_changes']: + for delta in changes["data_changes"]: data_change_dict[delta[0]] = delta[1] metadata_change_dict = {} - for delta in changes['metadata_changes']: + for delta in changes["metadata_changes"]: metadata_change_dict[delta[0]] = delta[1] - msg = 'Changes going from\n' - msg += f'{manifest_0_name}\n' - msg += 'to\n' - msg += f'{manifest_1_name}\n\n' + msg = "Changes going from\n" + msg += f"{manifest_0_name}\n" + msg += "to\n" + msg += f"{manifest_1_name}\n\n" m_keys = list(metadata_change_dict.keys()) m_keys.sort() for m in m_keys: - msg += f'{metadata_change_dict[m]}\n' + msg += f"{metadata_change_dict[m]}\n" d_keys = list(data_change_dict.keys()) d_keys.sort() for d in d_keys: - msg += f'{data_change_dict[d]}\n' + msg += f"{data_change_dict[d]}\n" return msg @@ -1058,13 +1010,11 @@ class S3CloudCache(CloudCacheBase): functionality (used to populate helpful error messages) """ - def __init__(self, cache_dir, bucket_name, project_name, - ui_class_name=None): + def __init__(self, cache_dir, bucket_name, project_name, ui_class_name=None): self._manifest = None self._bucket_name = bucket_name - super().__init__(cache_dir=cache_dir, project_name=project_name, - ui_class_name=ui_class_name) + super().__init__(cache_dir=cache_dir, project_name=project_name, ui_class_name=ui_class_name) _s3_client = None @@ -1076,8 +1026,7 @@ def bucket_name(self) -> str: def s3_client(self): if self._s3_client is None: s3_config = Config(signature_version=UNSIGNED) - self._s3_client = boto3.client('s3', - config=s3_config) + self._s3_client = boto3.client("s3", config=s3_config) return self._s3_client def _list_all_manifests(self) -> list: @@ -1085,23 +1034,19 @@ def _list_all_manifests(self) -> list: Return a list of all of the file names of the manifests associated with this dataset """ - paginator = self.s3_client.get_paginator('list_objects_v2') - subset_iterator = paginator.paginate( - Bucket=self._bucket_name, - Prefix=self.manifest_prefix - ) + paginator = self.s3_client.get_paginator("list_objects_v2") + subset_iterator = paginator.paginate(Bucket=self._bucket_name, Prefix=self.manifest_prefix) output = [] for subset in subset_iterator: - if 'Contents' in subset: - for obj in subset['Contents']: - output.append(pathlib.Path(obj['Key']).name) + if "Contents" in subset: + for obj in subset["Contents"]: + output.append(pathlib.Path(obj["Key"]).name) output.sort() return output - def _download_manifest(self, - manifest_name: str): + def _download_manifest(self, manifest_name: str): """ Download a manifest from the dataset @@ -1113,13 +1058,12 @@ def _download_manifest(self, """ manifest_key = self.manifest_prefix + manifest_name - response = self.s3_client.get_object(Bucket=self._bucket_name, - Key=manifest_key) + response = self.s3_client.get_object(Bucket=self._bucket_name, Key=manifest_key) filepath = os.path.join(self._cache_dir, manifest_name) - with open(filepath, 'wb') as f: - for chunk in response['Body'].iter_chunks(): + with open(filepath, "wb") as f: + for chunk in response["Body"].iter_chunks(): f.write(chunk) def _download_file(self, file_attributes: CacheFileAttributes) -> bool: @@ -1161,8 +1105,7 @@ def _download_file(self, file_attributes: CacheFileAttributes) -> bool: # returns a str os.makedirs(local_dir, exist_ok=True) if not os.path.isdir(local_dir): - raise RuntimeError(f"{local_dir}\n" - "is not a directory") + raise RuntimeError(f"{local_dir}\nis not a directory") bucket_name = bucket_name_from_url(file_attributes.url) obj_key = relative_path_from_url(file_attributes.url) @@ -1174,25 +1117,23 @@ def _download_file(self, file_attributes: CacheFileAttributes) -> bool: pbar = None if not self._file_exists(file_attributes): - response = self.s3_client.list_object_versions(Bucket=bucket_name, - Prefix=str(obj_key)) - object_info = [i for i in response["Versions"] - if i["VersionId"] == version_id][0] - pbar = tqdm.tqdm(desc=object_info["Key"].split("/")[-1], - total=object_info["Size"], - unit_scale=True, - unit_divisor=1000., - unit="MB") + response = self.s3_client.list_object_versions(Bucket=bucket_name, Prefix=str(obj_key)) + object_info = [i for i in response["Versions"] if i["VersionId"] == version_id][0] + pbar = tqdm.tqdm( + desc=object_info["Key"].split("/")[-1], + total=object_info["Size"], + unit_scale=True, + unit_divisor=1000.0, + unit="MB", + ) while not self._file_exists(file_attributes): was_downloaded = True - response = self.s3_client.get_object(Bucket=bucket_name, - Key=str(obj_key), - VersionId=version_id) + response = self.s3_client.get_object(Bucket=bucket_name, Key=str(obj_key), VersionId=version_id) - if 'Body' in response: - with open(local_path, 'wb') as out_file: - for chunk in response['Body'].iter_chunks(): + if "Body" in response: + with open(local_path, "wb") as out_file: + for chunk in response["Body"].iter_chunks(): out_file.write(chunk) pbar.update(len(chunk)) @@ -1206,9 +1147,7 @@ def _download_file(self, file_attributes: CacheFileAttributes) -> bool: n_iter += 1 if n_iter > max_iter: pbar.close() - raise RuntimeError("Could not download\n" - f"{file_attributes}\n" - "In {max_iter} iterations") + raise RuntimeError(f"Could not download\n{file_attributes}\nIn {{max_iter}} iterations") if pbar is not None: pbar.close() @@ -1232,9 +1171,9 @@ class LocalCache(CloudCacheBase): Name of the class users are actually using to maniuplate this functionality (used to populate helpful error messages) """ + def __init__(self, cache_dir, project_name, ui_class_name=None): - super().__init__(cache_dir=cache_dir, project_name=project_name, - ui_class_name=ui_class_name) + super().__init__(cache_dir=cache_dir, project_name=project_name, ui_class_name=ui_class_name) def _list_all_manifests(self) -> list: return self.list_all_downloaded_manifests() @@ -1272,8 +1211,7 @@ class StaticLocalCache(BasicLocalCache): """ def __init__(self, cache_dir, project_name, ui_class_name=None): - super().__init__(cache_dir=cache_dir, project_name=project_name, - ui_class_name=ui_class_name) + super().__init__(cache_dir=cache_dir, project_name=project_name, ui_class_name=ui_class_name) def _list_all_manifests(self) -> list: """ @@ -1281,9 +1219,7 @@ def _list_all_manifests(self) -> list: with this dataset. For the StaticLocalCache only return only the latest manifest. """ - manifest_dir = os.path.join( - self._cache_dir, self.project_name, "manifests" - ) + manifest_dir = os.path.join(self._cache_dir, self.project_name, "manifests") if not os.path.exists(manifest_dir): raise RuntimeError( f"Expected the provided cache_dir ({self._cache_dir})" @@ -1291,8 +1227,7 @@ def _list_all_manifests(self) -> list: f"{self.project_name}/manifests" ) - output = [x for x in os.listdir(manifest_dir) - if re.fullmatch(".*_manifest_v.*.json", x)] + output = [x for x in os.listdir(manifest_dir) if re.fullmatch(".*_manifest_v.*.json", x)] return [self._find_latest_file(output)] @@ -1317,10 +1252,7 @@ def load_manifest(self, manifest_name: str): The name of the manifest to load. Must be an element in self.manifest_file_names """ - self._manifest = self._load_manifest( - manifest_name, - use_static_project_dir=True - ) + self._manifest = self._load_manifest(manifest_name, use_static_project_dir=True) self._manifest_name = manifest_name def compare_manifests(self, manifest_0_name: str, manifest_1_name: str): diff --git a/allensdk/api/cloud_cache/file_attributes.py b/allensdk/api/cloud_cache/file_attributes.py index 682ceb457d..25aeb69d56 100644 --- a/allensdk/api/cloud_cache/file_attributes.py +++ b/allensdk/api/cloud_cache/file_attributes.py @@ -22,22 +22,15 @@ class CacheFileAttributes(object): (probably computed by the Manifest class) """ - def __init__(self, - url: str, - version_id: str, - file_hash: str, - local_path: pathlib.Path): - + def __init__(self, url: str, version_id: str, file_hash: str, local_path: pathlib.Path): if not isinstance(url, str): raise ValueError(f"url must be str; got {type(url)}") if not isinstance(version_id, str): raise ValueError(f"version_id must be str; got {type(version_id)}") if not isinstance(file_hash, str): - raise ValueError(f"file_hash must be str; " - f"got {type(file_hash)}") + raise ValueError(f"file_hash must be str; got {type(file_hash)}") if not isinstance(local_path, pathlib.Path): - raise ValueError(f"local_path must be pathlib.Path; " - f"got {type(local_path)}") + raise ValueError(f"local_path must be pathlib.Path; got {type(local_path)}") self._url = url self._version_id = version_id @@ -61,9 +54,11 @@ def local_path(self) -> pathlib.Path: return self._local_path def __str__(self): - output = {'url': self.url, - 'version_id': self.version_id, - 'file_hash': self.file_hash, - 'local_path': str(self.local_path)} + output = { + "url": self.url, + "version_id": self.version_id, + "file_hash": self.file_hash, + "local_path": str(self.local_path), + } output = json.dumps(output, indent=2, sort_keys=True) - return f'CacheFileParameters{output}' + return f"CacheFileParameters{output}" diff --git a/allensdk/api/cloud_cache/manifest.py b/allensdk/api/cloud_cache/manifest.py index 6e3160d203..b15aa015d1 100644 --- a/allensdk/api/cloud_cache/manifest.py +++ b/allensdk/api/cloud_cache/manifest.py @@ -35,39 +35,28 @@ class Manifest(object): local paths for remote resources. Defaults to False. """ - def __init__( - self, - cache_dir: Union[str, pathlib.Path], - json_input, - use_static_project_dir: bool = False - ): + def __init__(self, cache_dir: Union[str, pathlib.Path], json_input, use_static_project_dir: bool = False): if isinstance(cache_dir, str): self._cache_dir = pathlib.Path(cache_dir).resolve() elif isinstance(cache_dir, pathlib.Path): self._cache_dir = cache_dir.resolve() else: - raise ValueError("cache_dir must be either a str " - "or a pathlib.Path; " - f"got {type(cache_dir)}") + raise ValueError(f"cache_dir must be either a str or a pathlib.Path; got {type(cache_dir)}") self._use_static_project_dir = use_static_project_dir self._data: Dict[str, Any] = json.load(json_input) if not isinstance(self._data, dict): - raise ValueError("Expected to deserialize manifest into a dict; " - f"instead got {type(self._data)}") + raise ValueError(f"Expected to deserialize manifest into a dict; instead got {type(self._data)}") self._project_name: str = self._data["project_name"] - self._version: str = self._data['manifest_version'] - self._file_id_column: str = self._data['metadata_file_id_column_name'] + self._version: str = self._data["manifest_version"] + self._file_id_column: str = self._data["metadata_file_id_column_name"] self._data_pipeline: str = self._data["data_pipeline"] - self._metadata_file_names: List[str] = [ - file_name for file_name in self._data['metadata_files'] - ] + self._metadata_file_names: List[str] = [file_name for file_name in self._data["metadata_files"]] self._metadata_file_names.sort() - self._file_id_values: List[Any] = [ii for ii in - self._data['data_files'].keys()] + self._file_id_values: List[Any] = [ii for ii in self._data["data_files"].keys()] self._file_id_values.sort() @property @@ -107,10 +96,7 @@ def file_id_values(self): """ return self._file_id_values - def _create_file_attributes(self, - remote_path: str, - version_id: str, - file_hash: str) -> CacheFileAttributes: + def _create_file_attributes(self, remote_path: str, version_id: str, file_hash: str) -> CacheFileAttributes: """ Create the cache_file_attributes describing a file. This method does the work of assigning a local_path for a remote file. @@ -151,19 +137,11 @@ def _create_file_attributes(self, local_path = project_dir / shaved_rel_path - obj = CacheFileAttributes( - remote_path, - version_id, - file_hash, - local_path - ) + obj = CacheFileAttributes(remote_path, version_id, file_hash, local_path) return obj - def metadata_file_attributes( - self, - metadata_file_name: str - ) -> CacheFileAttributes: + def metadata_file_attributes(self, metadata_file_name: str) -> CacheFileAttributes: """ Return the CacheFileAttributes associated with a metadata file @@ -186,19 +164,15 @@ def metadata_file_attributes( If the metadata_file_name is not a valid option """ if self._data is None: - raise RuntimeError("You cannot retrieve " - "metadata_file_attributes;\n" - "you have not yet loaded a manifest.json file") + raise RuntimeError( + "You cannot retrieve metadata_file_attributes;\nyou have not yet loaded a manifest.json file" + ) if metadata_file_name not in self._metadata_file_names: - raise ValueError(f"{metadata_file_name}\n" - "is not in self.metadata_file_names:\n" - f"{self._metadata_file_names}") + raise ValueError(f"{metadata_file_name}\nis not in self.metadata_file_names:\n{self._metadata_file_names}") - file_data = self._data['metadata_files'][metadata_file_name] - return self._create_file_attributes(file_data['url'], - file_data['version_id'], - file_data['file_hash']) + file_data = self._data["metadata_files"][metadata_file_name] + return self._create_file_attributes(file_data["url"], file_data["version_id"], file_data["file_hash"]) def data_file_attributes(self, file_id) -> CacheFileAttributes: """ @@ -224,17 +198,14 @@ def data_file_attributes(self, file_id) -> CacheFileAttributes: If the file_id is not a valid option """ if self._data is None: - raise RuntimeError("You cannot retrieve data_file_attributes;\n" - "you have not yet loaded a manifest.json file") + raise RuntimeError( + "You cannot retrieve data_file_attributes;\nyou have not yet loaded a manifest.json file" + ) - if file_id not in self._data['data_files']: - valid_keys = list(self._data['data_files'].keys()) + if file_id not in self._data["data_files"]: + valid_keys = list(self._data["data_files"].keys()) valid_keys.sort() - raise ValueError(f"file_id: {file_id}\n" - "Is not a data file listed in manifest:\n" - f"{valid_keys}") - - file_data = self._data['data_files'][file_id] - return self._create_file_attributes(file_data['url'], - file_data['version_id'], - file_data['file_hash']) + raise ValueError(f"file_id: {file_id}\nIs not a data file listed in manifest:\n{valid_keys}") + + file_data = self._data["data_files"][file_id] + return self._create_file_attributes(file_data["url"], file_data["version_id"], file_data["file_hash"]) diff --git a/allensdk/api/cloud_cache/utils.py b/allensdk/api/cloud_cache/utils.py index 473161c856..e636299401 100644 --- a/allensdk/api/cloud_cache/utils.py +++ b/allensdk/api/cloud_cache/utils.py @@ -28,7 +28,7 @@ def bucket_name_from_url(url: str) -> Optional[str]: here https://aws.amazon.com/blogs/aws/amazon-s3-path-deprecation-plan-the-rest-of-the-story/ """ - s3_pattern = re.compile('\.s3[\.,a-z,0-9,\-]*\.amazonaws.com') # noqa: W605, E501 + s3_pattern = re.compile("\.s3[\.,a-z,0-9,\-]*\.amazonaws.com") # noqa: W605, E501 url_params = url_parse.urlparse(url) raw_location = url_params.netloc s3_match = s3_pattern.search(raw_location) @@ -37,8 +37,8 @@ def bucket_name_from_url(url: str) -> Optional[str]: warnings.warn(f"{s3_pattern} does not occur in url {url}") return None - s3_match = raw_location[s3_match.start():s3_match.end()] - return url_params.netloc.replace(s3_match, '') + s3_match = raw_location[s3_match.start() : s3_match.end()] + return url_params.netloc.replace(s3_match, "") def relative_path_from_url(url: str) -> str: @@ -81,7 +81,7 @@ def file_hash_from_path(file_path: Union[str, Path]) -> str: The file hash (Blake2b; hexadecimal) of the file """ hasher = hashlib.blake2b() - with open(file_path, 'rb') as in_file: + with open(file_path, "rb") as in_file: chunk = in_file.read(1000000) while len(chunk) > 0: hasher.update(chunk) diff --git a/allensdk/api/queries/__init__.py b/allensdk/api/queries/__init__.py index 92ceaf67c3..8e51ec55db 100644 --- a/allensdk/api/queries/__init__.py +++ b/allensdk/api/queries/__init__.py @@ -32,4 +32,4 @@ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. -# \ No newline at end of file +# diff --git a/allensdk/api/queries/annotated_section_data_sets_api.py b/allensdk/api/queries/annotated_section_data_sets_api.py index 72cefbf874..f1d950ceab 100644 --- a/allensdk/api/queries/annotated_section_data_sets_api.py +++ b/allensdk/api/queries/annotated_section_data_sets_api.py @@ -38,20 +38,17 @@ class AnnotatedSectionDataSetsApi(RmaApi): - '''See: + """See: `Searching Annotated SectionDataSets `_ - ''' + """ def __init__(self, base_uri=None): super(AnnotatedSectionDataSetsApi, self).__init__(base_uri) - def get_annotated_section_data_sets(self, - structures, - intensity_values=None, - density_values=None, - pattern_values=None, - age_names=None): - '''For a list of target structures, find the SectionDataSet + def get_annotated_section_data_sets( + self, structures, intensity_values=None, density_values=None, pattern_values=None, age_names=None + ): + """For a list of target structures, find the SectionDataSet that matches the parameters for intensity_values, density_values, pattern_values, and Age. Parameters @@ -75,41 +72,32 @@ def get_annotated_section_data_sets(self, Notes ----- This method uses the non-RMA Annotated SectionDataSet endpoint. - ''' - params = ['structures=' + ','.join((str(s) for s in structures))] + """ + params = ["structures=" + ",".join((str(s) for s in structures))] if intensity_values is not None and len(intensity_values) > 0: - params.append('intensity_values=' + - ','.join(("'%s'" % (v) for v in intensity_values))) + params.append("intensity_values=" + ",".join(("'%s'" % (v) for v in intensity_values))) if density_values is not None and len(density_values) > 0: - params.append('density_values=' + - ','.join(("'%s'" % (v) for v in density_values))) + params.append("density_values=" + ",".join(("'%s'" % (v) for v in density_values))) if pattern_values is not None and len(pattern_values) > 0: - params.append('pattern_values=' + - ','.join(("'%s'" % (v) for v in pattern_values))) + params.append("pattern_values=" + ",".join(("'%s'" % (v) for v in pattern_values))) if age_names is not None and len(age_names) > 0: - params.append('age_names=' + - ','.join(("'%s'" % (v) for v in age_names))) + params.append("age_names=" + ",".join(("'%s'" % (v) for v in age_names))) - url_params = '?' + '&'.join(params) + url_params = "?" + "&".join(params) - url = ''.join([self.annotated_section_data_sets_endpoint, - '.json', - url_params]) + url = "".join([self.annotated_section_data_sets_endpoint, ".json", url_params]) return self.json_msg_query(url) @cacheable() - def get_annotated_section_data_sets_via_rma(self, - structures, - intensity_values=None, - density_values=None, - pattern_values=None, - age_names=None): - '''For a list of target structures, find the SectionDataSet + def get_annotated_section_data_sets_via_rma( + self, structures, intensity_values=None, density_values=None, pattern_values=None, age_names=None + ): + """For a list of target structures, find the SectionDataSet that matches the parameters for intensity_values, density_values, pattern_values, and Age. Parameters @@ -133,60 +121,47 @@ def get_annotated_section_data_sets_via_rma(self, Notes ----- This method uses the RMA endpoint to search annotated SectionDataSet data. - ''' - age_include_strings = ['age'] + """ + age_include_strings = ["age"] if age_names is not None and len(age_names) > 0: - age_include_strings.append('[name$in') - age_include_strings.append( - ','.join(("'%s'" % (a) for a in age_names))) - age_include_strings.append(']') - age_include = ''.join(age_include_strings) + age_include_strings.append("[name$in") + age_include_strings.append(",".join(("'%s'" % (a) for a in age_names))) + age_include_strings.append("]") + age_include = "".join(age_include_strings) - criteria_strings = ['manual_annotations'] + criteria_strings = ["manual_annotations"] if intensity_values is not None and len(intensity_values) > 0: - criteria_strings.append('[intensity_call$in%s]' % - (','.join(("'%s'" % (v) for v in intensity_values)))) + criteria_strings.append("[intensity_call$in%s]" % (",".join(("'%s'" % (v) for v in intensity_values)))) if density_values is not None and len(density_values) > 0: - criteria_strings.append('[density_call$in%s]' % - (','.join(("'%s'" % (v) for v in density_values)))) + criteria_strings.append("[density_call$in%s]" % (",".join(("'%s'" % (v) for v in density_values)))) if pattern_values is not None and len(pattern_values) > 0: - criteria_strings.append('[pattern_call$in%s]' % - (','.join(("'%s'" % (v) for v in pattern_values)))) + criteria_strings.append("[pattern_call$in%s]" % (",".join(("'%s'" % (v) for v in pattern_values)))) - criteria_strings.append('(structure[id$in%s])' % - (','.join((str(s) for s in structures)))) + criteria_strings.append("(structure[id$in%s])" % (",".join((str(s) for s in structures)))) - criteria_clause = ''.join(criteria_strings) + criteria_clause = "".join(criteria_strings) - include_clause = ''.join(['specimen', - '(donor(', - age_include, - ')),', - 'probes(gene),' - 'plane_of_section']) + include_clause = "".join(["specimen", "(donor(", age_include, ")),", "probes(gene),plane_of_section"]) - order_by_array = ['genes.acronym', - 'ages.embryonic+desc', - 'ages.days', - 'data_sets.id'] + order_by_array = ["genes.acronym", "ages.embryonic+desc", "ages.days", "data_sets.id"] - data = self.model_query('SectionDataSet', - criteria=criteria_clause, - include=include_clause, - start_row=0, - num_rows=50, - order=order_by_array) + data = self.model_query( + "SectionDataSet", + criteria=criteria_clause, + include=include_clause, + start_row=0, + num_rows=50, + order=order_by_array, + ) return data - def get_compound_annotated_section_data_sets(self, - queries, - fmt='json'): - '''Find the SectionDataSet that matches several annotated_section_data_sets queries + def get_compound_annotated_section_data_sets(self, queries, fmt="json"): + """Find the SectionDataSet that matches several annotated_section_data_sets queries linked together with a Boolean 'and' or 'or'. Parameters @@ -200,35 +175,29 @@ def get_compound_annotated_section_data_sets(self, ------- data : dict The parsed JSON repsonse message. - ''' - url_strings = ['?query='] + """ + url_strings = ["?query="] for query in queries: - url_strings.append('[') + url_strings.append("[") - params = ['structures $in ' + - ','.join((str(s) for s in query['structures']))] + params = ["structures $in " + ",".join((str(s) for s in query["structures"]))] - for key in ['intensity_values', 'density_values', 'pattern_values', 'age_names']: + for key in ["intensity_values", "density_values", "pattern_values", "age_names"]: if key in query and len(query[key]) > 0: - params.append('%s $in %s' % - (key, - ','.join(("'%s'" % (v) for v in query['intensity_values'])))) + params.append("%s $in %s" % (key, ",".join(("'%s'" % (v) for v in query["intensity_values"])))) - url_strings.append(' : '.join(params)) + url_strings.append(" : ".join(params)) - url_strings.append(']') + url_strings.append("]") - if 'link' in query and query['link'] == 'or': - url_strings.append(' or ') - if 'link' in query and query['link'] == 'and': - url_strings.append(' and ') + if "link" in query and query["link"] == "or": + url_strings.append(" or ") + if "link" in query and query["link"] == "and": + url_strings.append(" and ") - url_params = ''.join(url_strings) + url_params = "".join(url_strings) - url = ''.join([self.compound_annotated_section_data_sets_endpoint, - '.', - fmt, - url_params]) + url = "".join([self.compound_annotated_section_data_sets_endpoint, ".", fmt, url_params]) return self.json_msg_query(url) diff --git a/allensdk/api/queries/biophysical_api.py b/allensdk/api/queries/biophysical_api.py index 89cd6021f2..ab47fe3551 100644 --- a/allensdk/api/queries/biophysical_api.py +++ b/allensdk/api/queries/biophysical_api.py @@ -42,24 +42,29 @@ class BiophysicalApi(RmaTemplate): - _NWB_file_type = 'NWBDownload' - _SWC_file_type = '3DNeuronReconstruction' - _MOD_file_type = 'BiophysicalModelDescription' - _FIT_file_type = 'NeuronalModelParameters' - _MARKER_file_type = '3DNeuronMarker' - BIOPHYSICAL_MODEL_TYPE_IDS = (491455321, 329230710,) - - rma_templates = \ - {"model_queries": [ - {'name': 'models_by_specimen', - 'description': 'see name', - 'model': 'NeuronalModel', - 'num_rows': 'all', - 'count': False, - 'criteria': '[neuronal_model_template_id$in{{biophysical_model_types}}],[specimen_id$in{{specimen_ids}}]', - 'criteria_params': ['specimen_ids', 'biophysical_model_types'] - }]} - + _NWB_file_type = "NWBDownload" + _SWC_file_type = "3DNeuronReconstruction" + _MOD_file_type = "BiophysicalModelDescription" + _FIT_file_type = "NeuronalModelParameters" + _MARKER_file_type = "3DNeuronMarker" + BIOPHYSICAL_MODEL_TYPE_IDS = ( + 491455321, + 329230710, + ) + + rma_templates = { + "model_queries": [ + { + "name": "models_by_specimen", + "description": "see name", + "model": "NeuronalModel", + "num_rows": "all", + "count": False, + "criteria": "[neuronal_model_template_id$in{{biophysical_model_types}}],[specimen_id$in{{specimen_ids}}]", + "criteria_params": ["specimen_ids", "biophysical_model_types"], + } + ] + } def __init__(self, base_uri=None): super(BiophysicalApi, self).__init__(base_uri, query_manifest=BiophysicalApi.rma_templates) @@ -69,10 +74,9 @@ def __init__(self, base_uri=None): self.manifest = {} self.model_type = None - @cacheable() - def get_neuronal_models(self, specimen_ids, num_rows='all', count=False, model_type_ids=None, **kwargs): - '''Fetch all of the biophysically detailed model records associated with + def get_neuronal_models(self, specimen_ids, num_rows="all", count=False, model_type_ids=None, **kwargs): + """Fetch all of the biophysically detailed model records associated with a particular specimen_id Parameters @@ -84,29 +88,32 @@ def get_neuronal_models(self, specimen_ids, num_rows='all', count=False, model_t count : bool, optional If True, return a count of the lines found by the query. Default is False. model_type_ids : list, optional - One or more integer ids identifying categories of neuronal model. Defaults + One or more integer ids identifying categories of neuronal model. Defaults to all-active and perisomatic biophysical_models. Returns ------- List of dict - Each element is a biophysical model record, containing a unique integer - id, the id of the associated specimen, and the id of the model type to + Each element is a biophysical model record, containing a unique integer + id, the id of the associated specimen, and the id of the model type to which this model belongs. - - ''' + + """ if model_type_ids is None: model_type_ids = self.BIOPHYSICAL_MODEL_TYPE_IDS - return self.template_query('model_queries', 'models_by_specimen', - specimen_ids=specimen_ids, - biophysical_model_types=list(model_type_ids), - num_rows=num_rows, count=count) + return self.template_query( + "model_queries", + "models_by_specimen", + specimen_ids=specimen_ids, + biophysical_model_types=list(model_type_ids), + num_rows=num_rows, + count=count, + ) - - def build_rma(self, neuronal_model_id, fmt='json'): - '''Construct a query to find all files related to a neuronal model. + def build_rma(self, neuronal_model_id, fmt="json"): + """Construct a query to find all files related to a neuronal model. Parameters ---------- @@ -119,29 +126,34 @@ def build_rma(self, neuronal_model_id, fmt='json'): ------- string RMA query url. - ''' - include_associations = ''.join([ - 'neuronal_model_template(well_known_files(well_known_file_type)),', - 'specimen(ephys_result(well_known_files(well_known_file_type)),', - 'neuron_reconstructions(well_known_files(well_known_file_type)),', - 'ephys_sweeps),', - 'well_known_files(well_known_file_type)']) - criteria_associations = ''.join([ - ("[id$eq%d]," % (neuronal_model_id)), - include_associations]) - - return ''.join([self.rma_endpoint, - '/query.', - fmt, - '?q=', - 'model::NeuronalModel,', - 'rma::criteria,', - criteria_associations, - ',rma::include,', - include_associations]) + """ + include_associations = "".join( + [ + "neuronal_model_template(well_known_files(well_known_file_type)),", + "specimen(ephys_result(well_known_files(well_known_file_type)),", + "neuron_reconstructions(well_known_files(well_known_file_type)),", + "ephys_sweeps),", + "well_known_files(well_known_file_type)", + ] + ) + criteria_associations = "".join([("[id$eq%d]," % (neuronal_model_id)), include_associations]) + + return "".join( + [ + self.rma_endpoint, + "/query.", + fmt, + "?q=", + "model::NeuronalModel,", + "rma::criteria,", + criteria_associations, + ",rma::include,", + include_associations, + ] + ) def read_json(self, json_parsed_data): - '''Get the list of well_known_file ids from a response body + """Get the list of well_known_file ids from a response body containing nested sample,microarray_slides,well_known_files. Parameters @@ -153,76 +165,73 @@ def read_json(self, json_parsed_data): ------- list of strings Well known file ids. - ''' - self.ids = { - 'stimulus': {}, - 'morphology': {}, - 'marker': {}, - 'modfiles': {}, - 'fit': {} - } + """ + self.ids = {"stimulus": {}, "morphology": {}, "marker": {}, "modfiles": {}, "fit": {}} self.sweeps = [] - if 'msg' in json_parsed_data: - for neuronal_model in json_parsed_data['msg']: - if 'well_known_files' in neuronal_model: - for well_known_file in neuronal_model['well_known_files']: - if ('id' in well_known_file and - 'path' in well_known_file and - self.is_well_known_file_type(well_known_file, - BiophysicalApi._FIT_file_type)): - self.ids['fit'][str(well_known_file['id'])] = \ - os.path.split(well_known_file['path'])[1] - - if 'neuronal_model_template' in neuronal_model: - neuronal_model_template = neuronal_model[ - 'neuronal_model_template'] - self.model_type = neuronal_model_template['name'] - if 'well_known_files' in neuronal_model_template: - for well_known_file in neuronal_model_template['well_known_files']: - if ('id' in well_known_file and - 'path' in well_known_file and - self.is_well_known_file_type(well_known_file, - BiophysicalApi._MOD_file_type)): - self.ids['modfiles'][str(well_known_file['id'])] = \ - os.path.join('modfiles', - os.path.split(well_known_file['path'])[1]) - - if 'specimen' in neuronal_model: - specimen = neuronal_model['specimen'] - - if 'neuron_reconstructions' in specimen: - for neuron_reconstruction in specimen['neuron_reconstructions']: - if 'well_known_files' in neuron_reconstruction: - for well_known_file in neuron_reconstruction['well_known_files']: - if ('id' in well_known_file and 'path' in well_known_file): + if "msg" in json_parsed_data: + for neuronal_model in json_parsed_data["msg"]: + if "well_known_files" in neuronal_model: + for well_known_file in neuronal_model["well_known_files"]: + if ( + "id" in well_known_file + and "path" in well_known_file + and self.is_well_known_file_type(well_known_file, BiophysicalApi._FIT_file_type) + ): + self.ids["fit"][str(well_known_file["id"])] = os.path.split(well_known_file["path"])[1] + + if "neuronal_model_template" in neuronal_model: + neuronal_model_template = neuronal_model["neuronal_model_template"] + self.model_type = neuronal_model_template["name"] + if "well_known_files" in neuronal_model_template: + for well_known_file in neuronal_model_template["well_known_files"]: + if ( + "id" in well_known_file + and "path" in well_known_file + and self.is_well_known_file_type(well_known_file, BiophysicalApi._MOD_file_type) + ): + self.ids["modfiles"][str(well_known_file["id"])] = os.path.join( + "modfiles", os.path.split(well_known_file["path"])[1] + ) + + if "specimen" in neuronal_model: + specimen = neuronal_model["specimen"] + + if "neuron_reconstructions" in specimen: + for neuron_reconstruction in specimen["neuron_reconstructions"]: + if "well_known_files" in neuron_reconstruction: + for well_known_file in neuron_reconstruction["well_known_files"]: + if "id" in well_known_file and "path" in well_known_file: if self.is_well_known_file_type(well_known_file, BiophysicalApi._SWC_file_type): - self.ids['morphology'][str(well_known_file['id'])] = \ - os.path.split( - well_known_file['path'])[1] - elif self.is_well_known_file_type(well_known_file, BiophysicalApi._MARKER_file_type): - self.ids['marker'][str(well_known_file['id'])] = \ - os.path.split( - well_known_file['path'])[1] - - if 'ephys_result' in specimen: - ephys_result = specimen['ephys_result'] - if 'well_known_files' in ephys_result: - for well_known_file in ephys_result['well_known_files']: - if ('id' in well_known_file and - 'path' in well_known_file and - self.is_well_known_file_type(well_known_file, BiophysicalApi._NWB_file_type)): - self.ids['stimulus'][str(well_known_file['id'])] = \ - "%d.nwb" % (ephys_result['id']) - - self.sweeps = [sweep['sweep_number'] - for sweep in specimen['ephys_sweeps'] - if sweep['stimulus_name'] != 'Test'] + self.ids["morphology"][str(well_known_file["id"])] = os.path.split( + well_known_file["path"] + )[1] + elif self.is_well_known_file_type( + well_known_file, BiophysicalApi._MARKER_file_type + ): + self.ids["marker"][str(well_known_file["id"])] = os.path.split( + well_known_file["path"] + )[1] + + if "ephys_result" in specimen: + ephys_result = specimen["ephys_result"] + if "well_known_files" in ephys_result: + for well_known_file in ephys_result["well_known_files"]: + if ( + "id" in well_known_file + and "path" in well_known_file + and self.is_well_known_file_type(well_known_file, BiophysicalApi._NWB_file_type) + ): + self.ids["stimulus"][str(well_known_file["id"])] = "%d.nwb" % (ephys_result["id"]) + + self.sweeps = [ + sweep["sweep_number"] for sweep in specimen["ephys_sweeps"] if sweep["stimulus_name"] != "Test" + ] return self.ids def is_well_known_file_type(self, wkf, name): - '''Check if a structure has the expected name. + """Check if a structure has the expected name. Parameters ---------- @@ -234,34 +243,30 @@ def is_well_known_file_type(self, wkf, name): See Also -------- read_json: where this helper function is used. - ''' + """ try: - return wkf['well_known_file_type']['name'] == name + return wkf["well_known_file_type"]["name"] == name except Exception: return False def get_well_known_file_ids(self, neuronal_model_id): - '''Query the current RMA endpoint with a neuronal_model id + """Query the current RMA endpoint with a neuronal_model id to get the corresponding well known file ids. Returns ------- list A list of well known file id strings. - ''' + """ rma_builder_fn = self.build_rma json_traversal_fn = self.read_json return self.do_query(rma_builder_fn, json_traversal_fn, neuronal_model_id) - def create_manifest(self, - fit_path='', - model_type='', - stimulus_filename='', - swc_morphology_path='', - marker_path='', - sweeps=[]): - '''Generate a json configuration file with parameters for a + def create_manifest( + self, fit_path="", model_type="", stimulus_filename="", swc_morphology_path="", marker_path="", sweeps=[] + ): + """Generate a json configuration file with parameters for a a biophysical experiment. Parameters @@ -274,64 +279,23 @@ def create_manifest(self, file in SWC format. sweeps : array of integers which sweeps in the stimulus file are to be used. - ''' + """ self.manifest = OrderedDict() - self.manifest['biophys'] = [{ - 'model_file': ['manifest.json', fit_path], - 'model_type': model_type - }] - self.manifest['runs'] = [{ - 'sweeps': sweeps - }] - self.manifest['neuron'] = [{ - 'hoc': ['stdgui.hoc', 'import3d.hoc'] - }] - self.manifest['manifest'] = [ - { - 'type': 'dir', - 'spec': '.', - 'key': 'BASEDIR' - }, - { - 'type': 'dir', - 'spec': 'work', - 'key': 'WORKDIR', - 'parent': 'BASEDIR' - }, - { - 'type': 'file', - 'spec': swc_morphology_path, - 'key': 'MORPHOLOGY' - }, - { - 'type': 'file', - 'spec': marker_path, - 'key': 'MARKER' - }, - { - 'type': 'dir', - 'spec': 'modfiles', - 'key': 'MODFILE_DIR' - }, - { - 'type': 'file', - 'format': 'NWB', - 'spec': stimulus_filename, - 'key': 'stimulus_path' - }, - { - 'parent_key': 'WORKDIR', - 'type': 'file', - 'format': 'NWB', - 'spec': stimulus_filename, - 'key': 'output_path' - } + self.manifest["biophys"] = [{"model_file": ["manifest.json", fit_path], "model_type": model_type}] + self.manifest["runs"] = [{"sweeps": sweeps}] + self.manifest["neuron"] = [{"hoc": ["stdgui.hoc", "import3d.hoc"]}] + self.manifest["manifest"] = [ + {"type": "dir", "spec": ".", "key": "BASEDIR"}, + {"type": "dir", "spec": "work", "key": "WORKDIR", "parent": "BASEDIR"}, + {"type": "file", "spec": swc_morphology_path, "key": "MORPHOLOGY"}, + {"type": "file", "spec": marker_path, "key": "MARKER"}, + {"type": "dir", "spec": "modfiles", "key": "MODFILE_DIR"}, + {"type": "file", "format": "NWB", "spec": stimulus_filename, "key": "stimulus_path"}, + {"parent_key": "WORKDIR", "type": "file", "format": "NWB", "spec": stimulus_filename, "key": "output_path"}, ] - def cache_data(self, - neuronal_model_id, - working_directory=None): - '''Take a an experiment id, query the Api RMA to get well-known-files + def cache_data(self, neuronal_model_id, working_directory=None): + """Take a an experiment id, query the Api RMA to get well-known-files download the files, and store them in the working directory. Parameters @@ -340,51 +304,40 @@ def cache_data(self, found in the neuronal_model table in the api working_directory : string Absolute path name where the downloaded well-known files will be stored. - ''' + """ if working_directory is None: working_directory = self.default_working_directory - well_known_file_id_dict = self.get_well_known_file_ids( - neuronal_model_id) + well_known_file_id_dict = self.get_well_known_file_ids(neuronal_model_id) - if not well_known_file_id_dict or \ - (not any(list(well_known_file_id_dict.values()))): - raise(Exception("No data found for neuronal model id %d" % - (neuronal_model_id))) + if not well_known_file_id_dict or (not any(list(well_known_file_id_dict.values()))): + raise (Exception("No data found for neuronal model id %d" % (neuronal_model_id))) Manifest.safe_mkdir(working_directory) - work_dir = os.path.join(working_directory, 'work') + work_dir = os.path.join(working_directory, "work") Manifest.safe_mkdir(work_dir) - modfile_dir = os.path.join(working_directory, 'modfiles') + modfile_dir = os.path.join(working_directory, "modfiles") Manifest.safe_mkdir(modfile_dir) for key, id_dict in well_known_file_id_dict.items(): - if (not self.cache_stimulus) and (key == 'stimulus'): + if (not self.cache_stimulus) and (key == "stimulus"): continue for well_known_id, filename in id_dict.items(): - well_known_file_url = self.construct_well_known_file_download_url( - well_known_id) + well_known_file_url = self.construct_well_known_file_download_url(well_known_id) cached_file_path = os.path.join(working_directory, filename) - self.retrieve_file_over_http( - well_known_file_url, cached_file_path) - - fit_path = list(self.ids['fit'].values())[0] - stimulus_filename = list(self.ids['stimulus'].values())[0] - swc_morphology_path = list(self.ids['morphology'].values())[0] - marker_path = \ - list(self.ids['marker'].values())[0] if 'marker' in self.ids else '' + self.retrieve_file_over_http(well_known_file_url, cached_file_path) + + fit_path = list(self.ids["fit"].values())[0] + stimulus_filename = list(self.ids["stimulus"].values())[0] + swc_morphology_path = list(self.ids["morphology"].values())[0] + marker_path = list(self.ids["marker"].values())[0] if "marker" in self.ids else "" sweeps = sorted(self.sweeps) - self.create_manifest(fit_path, - self.model_type, - stimulus_filename, - swc_morphology_path, - marker_path, - sweeps) + self.create_manifest(fit_path, self.model_type, stimulus_filename, swc_morphology_path, marker_path, sweeps) - manifest_path = os.path.join(working_directory, 'manifest.json') - with open(manifest_path, 'w') as f: + manifest_path = os.path.join(working_directory, "manifest.json") + with open(manifest_path, "w") as f: json.dump(self.manifest, f, indent=2) diff --git a/allensdk/api/queries/brain_observatory_api.py b/allensdk/api/queries/brain_observatory_api.py index 0fddf91571..2a291b4b54 100644 --- a/allensdk/api/queries/brain_observatory_api.py +++ b/allensdk/api/queries/brain_observatory_api.py @@ -73,7 +73,7 @@ class BrainObservatoryApi(RmaTemplate): "description": "see name", "model": "IsiExperiment", "criteria": "[id$in{{ isi_experiment_ids }}]", - "include": "experiment_container(ophys_experiments,targeted_structure)", # noqa e501 + "include": "experiment_container(ophys_experiments,targeted_structure)", # noqa e501 "num_rows": "all", "count": False, "criteria_params": ["isi_experiment_ids"], @@ -82,7 +82,7 @@ class BrainObservatoryApi(RmaTemplate): "name": "ophys_experiment_by_ids", "description": "see name", "model": "OphysExperiment", - "criteria": "{% if ophys_experiment_ids is defined %}[id$in{{ ophys_experiment_ids }}]{%endif%}", # noqa e501 + "criteria": "{% if ophys_experiment_ids is defined %}[id$in{{ ophys_experiment_ids }}]{%endif%}", # noqa e501 "include": "experiment_container,well_known_files(well_known_file_type),targeted_structure,specimen(donor(age,transgenic_lines))", # noqa e501 "num_rows": "all", "count": False, @@ -92,7 +92,7 @@ class BrainObservatoryApi(RmaTemplate): "name": "ophys_experiment_data", "description": "see name", "model": "WellKnownFile", - "criteria": "[attachable_id$eq{{ ophys_experiment_id }}],well_known_file_type[name$eq%s]" # noqa e501 + "criteria": "[attachable_id$eq{{ ophys_experiment_id }}],well_known_file_type[name$eq%s]" # noqa e501 % NWB_FILE_TYPE, "num_rows": "all", "count": False, @@ -102,7 +102,7 @@ class BrainObservatoryApi(RmaTemplate): "name": "ophys_analysis_file", "description": "see name", "model": "WellKnownFile", - "criteria": "[attachable_id$eq{{ ophys_experiment_id }}],well_known_file_type[name$eq%s]" # noqa e501 + "criteria": "[attachable_id$eq{{ ophys_experiment_id }}],well_known_file_type[name$eq%s]" # noqa e501 % OPHYS_ANALYSIS_FILE_TYPE, "num_rows": "all", "count": False, @@ -112,7 +112,7 @@ class BrainObservatoryApi(RmaTemplate): "name": "ophys_events_file", "description": "see name", "model": "WellKnownFile", - "criteria": "[attachable_id$eq{{ ophys_experiment_id }}],well_known_file_type[name$eq%s]" # noqa e501 + "criteria": "[attachable_id$eq{{ ophys_experiment_id }}],well_known_file_type[name$eq%s]" # noqa e501 % OPHYS_EVENTS_FILE_TYPE, "num_rows": "all", "count": False, @@ -139,7 +139,7 @@ class BrainObservatoryApi(RmaTemplate): "name": "stimulus_mapping", "description": "see name", "model": "ApiCamStimulusMapping", - "criteria": "{% if stimulus_mapping_ids is defined %}[id$in{{ stimulus_mapping_ids }}]{%endif%}", # noqa e501 + "criteria": "{% if stimulus_mapping_ids is defined %}[id$in{{ stimulus_mapping_ids }}]{%endif%}", # noqa e501 "num_rows": "all", "count": False, "criteria_params": ["stimulus_mapping_ids"], @@ -148,7 +148,7 @@ class BrainObservatoryApi(RmaTemplate): "name": "experiment_container", "description": "see name", "model": "ExperimentContainer", - "criteria": "{% if experiment_container_ids is defined %}[id$in{{ experiment_container_ids }}]{%endif%}", # noqa e501 + "criteria": "{% if experiment_container_ids is defined %}[id$in{{ experiment_container_ids }}]{%endif%}", # noqa e501 "include": "ophys_experiments,isi_experiment,specimen(donor(conditions,age,transgenic_lines)),targeted_structure", # noqa e501 "num_rows": "all", "count": False, @@ -158,7 +158,7 @@ class BrainObservatoryApi(RmaTemplate): "name": "experiment_container_metric", "description": "see name", "model": "ApiCamExperimentContainerMetric", - "criteria": "{% if experiment_container_metric_ids is defined %}[id$in{{ experiment_container_metric_ids }}]{%endif%}", # noqa e501 + "criteria": "{% if experiment_container_metric_ids is defined %}[id$in{{ experiment_container_metric_ids }}]{%endif%}", # noqa e501 "num_rows": "all", "count": False, "criteria_params": ["experiment_container_metric_ids"], @@ -167,23 +167,23 @@ class BrainObservatoryApi(RmaTemplate): "name": "cell_metric", "description": "see name", "model": "ApiCamCellMetric", - "criteria": "{% if cell_specimen_ids is defined %}[cell_specimen_id$in{{ cell_specimen_ids }}]{%endif%}", # noqa e501 + "criteria": "{% if cell_specimen_ids is defined %}[cell_specimen_id$in{{ cell_specimen_ids }}]{%endif%}", # noqa e501 "criteria_params": ["cell_specimen_ids"], }, { "name": "cell_specimen_id_mapping_table", "description": "see name", "model": "WellKnownFile", - "criteria": "[id$eq{{ mapping_table_id }}],well_known_file_type[name$eqOphysCellSpecimenIdMapping]", # noqa e501 + "criteria": "[id$eq{{ mapping_table_id }}],well_known_file_type[name$eqOphysCellSpecimenIdMapping]", # noqa e501 "num_rows": "all", "count": False, "criteria_params": ["mapping_table_id"], }, { "name": "eye_gaze_mapping_file", - "description": "h5 file containing mouse eye gaze mapped onto screen coordinates (as well as pupil and eye sizes)", # noqa e501 + "description": "h5 file containing mouse eye gaze mapped onto screen coordinates (as well as pupil and eye sizes)", # noqa e501 "model": "WellKnownFile", - "criteria": "[attachable_id$eq{{ ophys_session_id }}],well_known_file_type[name$eqEyeDlcScreenMapping]", # noqa e501 + "criteria": "[attachable_id$eq{{ ophys_session_id }}],well_known_file_type[name$eqEyeDlcScreenMapping]", # noqa e501 "num_rows": "all", "count": False, "criteria_params": ["ophys_session_id"], @@ -194,7 +194,7 @@ class BrainObservatoryApi(RmaTemplate): # the relationship is added. { "name": "all_eye_mapping_files", - "description": "Get a list of dictionaries for all eye mapping wkfs", # noqa e501 + "description": "Get a list of dictionaries for all eye mapping wkfs", # noqa e501 "model": "WellKnownFile", "criteria": "well_known_file_type[name$eqEyeDlcScreenMapping]", "num_rows": "all", @@ -215,9 +215,7 @@ class BrainObservatoryApi(RmaTemplate): } def __init__(self, base_uri=None, datacube_uri=None): - super(BrainObservatoryApi, self).__init__( - base_uri, query_manifest=BrainObservatoryApi.rma_templates - ) + super(BrainObservatoryApi, self).__init__(base_uri, query_manifest=BrainObservatoryApi.rma_templates) self.datacube_uri = datacube_uri @@ -274,8 +272,7 @@ def list_isi_experiments(self, isi_ids=None): ------- dict : neuronal model metadata """ - data = self.template_query( - "brain_observatory_queries", "list_isi_experiments") + data = self.template_query("brain_observatory_queries", "list_isi_experiments") return data @@ -289,9 +286,7 @@ def list_column_definition_class_names(self): ------- list : api class name strings """ - data = self.template_query( - "brain_observatory_queries", "column_definition_class_names" - ) + data = self.template_query("brain_observatory_queries", "column_definition_class_names") names = list(set([n["api_class_name"] for n in data])) @@ -387,10 +382,7 @@ def get_experiment_containers(self, experiment_container_ids=None): return data - def get_experiment_container_metrics( - self, - experiment_container_metric_ids=None - ): + def get_experiment_container_metrics(self, experiment_container_metric_ids=None): """Get experiment container metrics by id Parameters @@ -412,8 +404,7 @@ def get_experiment_container_metrics( @cacheable( strategy="create", - pathfinder=Cache.pathfinder( - file_name_position=2, path_keyword="file_name"), + pathfinder=Cache.pathfinder(file_name_position=2, path_keyword="file_name"), ) def save_ophys_experiment_data(self, ophys_experiment_id, file_name): data = self.template_query( @@ -425,25 +416,17 @@ def save_ophys_experiment_data(self, ophys_experiment_id, file_name): try: file_url = data[0]["download_link"] except Exception: - raise Exception( - "ophys experiment %d has no data file" % ophys_experiment_id - ) + raise Exception("ophys experiment %d has no data file" % ophys_experiment_id) - self._log.warning( - "Downloading ophys_experiment %d NWB. This can take some time." - % ophys_experiment_id - ) + self._log.warning("Downloading ophys_experiment %d NWB. This can take some time." % ophys_experiment_id) self.retrieve_file_over_http(self.api_url + file_url, file_name) @cacheable( strategy="create", - pathfinder=Cache.pathfinder( - file_name_position=2, path_keyword="file_name"), + pathfinder=Cache.pathfinder(file_name_position=2, path_keyword="file_name"), ) - def save_ophys_experiment_analysis_data( - self, ophys_experiment_id, file_name): - + def save_ophys_experiment_analysis_data(self, ophys_experiment_id, file_name): data = self.template_query( "brain_observatory_queries", "ophys_analysis_file", @@ -453,23 +436,17 @@ def save_ophys_experiment_analysis_data( try: file_url = data[0]["download_link"] except Exception: - raise Exception( - "ophys experiment %d has no analysis file" % - (ophys_experiment_id,) - ) + raise Exception("ophys experiment %d has no analysis file" % (ophys_experiment_id,)) self._log.warning( - "Downloading ophys_experiment %d analysis file. This can take " - "some time." - % (ophys_experiment_id,) + "Downloading ophys_experiment %d analysis file. This can take some time." % (ophys_experiment_id,) ) self.retrieve_file_over_http(self.api_url + file_url, file_name) @cacheable( strategy="create", - pathfinder=Cache.pathfinder(file_name_position=2, - path_keyword="file_name"), + pathfinder=Cache.pathfinder(file_name_position=2, path_keyword="file_name"), ) def save_ophys_experiment_event_data(self, ophys_experiment_id, file_name): data = self.template_query( @@ -480,21 +457,13 @@ def save_ophys_experiment_event_data(self, ophys_experiment_id, file_name): try: file_url = data[0]["download_link"] except Exception: - raise Exception( - "ophys experiment %d has no events file" % ophys_experiment_id - ) - self._log.warning( - "Downloading ophys_experiment %d events file. This can take " - "some time." - % ophys_experiment_id - ) + raise Exception("ophys experiment %d has no events file" % ophys_experiment_id) + self._log.warning("Downloading ophys_experiment %d events file. This can take some time." % ophys_experiment_id) self.retrieve_file_over_http(self.api_url + file_url, file_name) @staticmethod - def save_ophys_experiment_eye_tracking_data( - ophys_experiment_id, cloud_cache: S3CloudCache - ) -> Path: + def save_ophys_experiment_eye_tracking_data(ophys_experiment_id, cloud_cache: S3CloudCache) -> Path: """ Downloads eye tracking data for `ophys_experiment_id` using `S3CloudCache` @@ -519,42 +488,29 @@ def save_ophys_experiment_eye_tracking_data( meta = cloud_cache.get_metadata(fname="metadata") meta = meta.set_index("ophys_experiment_id") if ophys_experiment_id not in meta.index: - raise ValueError( - f"No eye tracking data for ophys experiment id " - f"{ophys_experiment_id}" - ) + raise ValueError(f"No eye tracking data for ophys experiment id {ophys_experiment_id}") file_id = meta.loc[ophys_experiment_id]["file_id"] file_path = cloud_cache.download_data(file_id=str(file_id)) return file_path @cacheable( strategy="create", - pathfinder=Cache.pathfinder(file_name_position=3, - path_keyword="file_name"), + pathfinder=Cache.pathfinder(file_name_position=3, path_keyword="file_name"), ) - def save_ophys_experiment_eye_gaze_data( - self, ophys_experiment_id: int, ophys_session_id: int, file_name: str - ): + def save_ophys_experiment_eye_gaze_data(self, ophys_experiment_id: int, ophys_session_id: int, file_name: str): data = self.template_query( "brain_observatory_queries", "eye_gaze_mapping_file", ophys_session_id=ophys_session_id, ) - experiment_session_string = ( - f"ophys_experiment '{ophys_experiment_id}' (session " - f"'{ophys_session_id}')" - ) + experiment_session_string = f"ophys_experiment '{ophys_experiment_id}' (session '{ophys_session_id}')" try: file_url = data[0]["download_link"] except Exception: - raise Exception(f"{experiment_session_string} has no eye gaze " - f"mapping file") - self._log.warning( - f"Downloading {experiment_session_string} gaze mapping file. " - f"This can take some time." - ) + raise Exception(f"{experiment_session_string} has no eye gaze mapping file") + self._log.warning(f"Downloading {experiment_session_string} gaze mapping file. This can take some time.") self.retrieve_file_over_http(self.api_url + file_url, file_name) @@ -569,7 +525,6 @@ def filter_experiments_and_containers( transgenic_lines=None, include_failed=False, ): - if not include_failed: objs = [o for o in objs if not o.get("failed", False)] @@ -577,11 +532,7 @@ def filter_experiments_and_containers( objs = [o for o in objs if o["id"] in ids] if targeted_structures is not None: - objs = [ - o - for o in objs - if o["targeted_structure"]["acronym"] in targeted_structures - ] + objs = [o for o in objs if o["targeted_structure"]["acronym"] in targeted_structures] if imaging_depths is not None: objs = [o for o in objs if o["imaging_depth"] in imaging_depths] @@ -594,26 +545,14 @@ def filter_experiments_and_containers( if reporter_lines is not None: tls = [tl.lower() for tl in reporter_lines] - obj_tls = [find_specimen_reporter_line(o["specimen"]) - for o in objs] + obj_tls = [find_specimen_reporter_line(o["specimen"]) for o in objs] obj_tls = [o.lower() if o else None for o in obj_tls] objs = [o for i, o in enumerate(objs) if obj_tls[i] in tls] if transgenic_lines is not None: tls = set([tl.lower() for tl in transgenic_lines]) objs = [ - o - for o in objs - if len( - tls - & set( - [ - tl.lower() - for tl in find_specimen_transgenic_lines( - o["specimen"]) - ] - ) - ) + o for o in objs if len(tls & set([tl.lower() for tl in find_specimen_transgenic_lines(o["specimen"])])) ] return objs @@ -630,7 +569,6 @@ def filter_experiment_containers( include_failed=False, simple=False, ): - containers = self.filter_experiments_and_containers( containers, ids=ids, @@ -663,7 +601,6 @@ def filter_ophys_experiments( require_eye_tracking=False, simple=False, ): - experiments = self.filter_experiments_and_containers( experiments, ids=ids, @@ -675,38 +612,21 @@ def filter_ophys_experiments( ) if require_eye_tracking: - experiments = [ - e for e in experiments - if e.get("fail_eye_tracking", None) is False - ] + experiments = [e for e in experiments if e.get("fail_eye_tracking", None) is False] if not include_failed: - experiments = [ - e - for e in experiments - if not e.get("experiment_container", {}).get("failed", False) - ] + experiments = [e for e in experiments if not e.get("experiment_container", {}).get("failed", False)] if experiment_container_ids is not None: - experiments = [ - e - for e in experiments - if e["experiment_container_id"] in experiment_container_ids - ] + experiments = [e for e in experiments if e["experiment_container_id"] in experiment_container_ids] if session_types is not None: - experiments = [ - e for e in experiments if e["stimulus_name"] in session_types - ] + experiments = [e for e in experiments if e["stimulus_name"] in session_types] if stimuli is not None: experiments = [ e for e in experiments - if len( - set(stimuli) - & set(stimulus_info.stimuli_in_session(e["stimulus_name"])) - ) - > 0 + if len(set(stimuli) & set(stimulus_info.stimuli_in_session(e["stimulus_name"]))) > 0 ] if simple: @@ -753,27 +673,16 @@ def filter_cell_specimens( """ if not include_failed: - cell_specimens = [ - c - for c in cell_specimens - if not c.get("failed_experiment_container", False) - ] + cell_specimens = [c for c in cell_specimens if not c.get("failed_experiment_container", False)] if ids is not None: - cell_specimens = [c for c in cell_specimens - if c["cell_specimen_id"] in ids] + cell_specimens = [c for c in cell_specimens if c["cell_specimen_id"] in ids] if experiment_container_ids is not None: - cell_specimens = [ - c - for c in cell_specimens - if c["experiment_container_id"] in experiment_container_ids - ] + cell_specimens = [c for c in cell_specimens if c["experiment_container_id"] in experiment_container_ids] if filters is not None: - cell_specimens = self.dataframe_query( - cell_specimens, filters, "cell_specimen_id" - ) + cell_specimens = self.dataframe_query(cell_specimens, filters, "cell_specimen_id") return cell_specimens @@ -801,9 +710,7 @@ def _filter_clause(op, field, value): return cluster_string - query_string = " & ".join( - _filter_clause(f["op"], f["field"], f["value"]) for f in filters - ) + query_string = " & ".join(_filter_clause(f["op"], f["field"], f["value"]) for f in filters) return query_string @@ -918,8 +825,7 @@ def find_specimen_cre_line(specimen): return next( tl["name"] for tl in specimen["donor"]["transgenic_lines"] - if tl["transgenic_line_type_name"] == "driver" and - "Cre" in tl["name"] + if tl["transgenic_line_type_name"] == "driver" and "Cre" in tl["name"] ) except StopIteration: return None @@ -928,9 +834,7 @@ def find_specimen_cre_line(specimen): def find_specimen_reporter_line(specimen): try: return next( - tl["name"] - for tl in specimen["donor"]["transgenic_lines"] - if tl["transgenic_line_type_name"] == "reporter" + tl["name"] for tl in specimen["donor"]["transgenic_lines"] if tl["transgenic_line_type_name"] == "reporter" ) except StopIteration: return None @@ -942,10 +846,7 @@ def find_specimen_transgenic_lines(specimen): def find_experiment_acquisition_age(exp): try: - return ( - parse_date(exp["date_of_acquisition"]) - - parse_date(exp["specimen"]["donor"]["date_of_birth"]) - ).days + return (parse_date(exp["date_of_acquisition"]) - parse_date(exp["specimen"]["donor"]["date_of_birth"])).days except KeyError: return None @@ -954,5 +855,4 @@ def find_container_tags(container): """Custom logic for extracting tags from donor conditions. Filtering out tissuecyte tags.""" conditions = container["specimen"]["donor"].get("conditions", []) - return [c["name"] for c in conditions if not c["name"].startswith( - "tissuecyte")] + return [c["name"] for c in conditions if not c["name"].startswith("tissuecyte")] diff --git a/allensdk/api/queries/cell_types_api.py b/allensdk/api/queries/cell_types_api.py index b1b7fcae08..fd9993eecb 100644 --- a/allensdk/api/queries/cell_types_api.py +++ b/allensdk/api/queries/cell_types_api.py @@ -41,50 +41,40 @@ class CellTypesApi(RmaApi): - NWB_FILE_TYPE = 'NWBDownload' - SWC_FILE_TYPE = '3DNeuronReconstruction' - MARKER_FILE_TYPE = '3DNeuronMarker' + NWB_FILE_TYPE = "NWBDownload" + SWC_FILE_TYPE = "3DNeuronReconstruction" + MARKER_FILE_TYPE = "3DNeuronMarker" - MOUSE = 'Mus musculus' - HUMAN = 'Homo Sapiens' + MOUSE = "Mus musculus" + HUMAN = "Homo Sapiens" def __init__(self, base_uri=None): super(CellTypesApi, self).__init__(base_uri) - @cacheable() - def list_cells_api(self, - id=None, - require_morphology=False, - require_reconstruction=False, - reporter_status=None, - species=None): - - + def list_cells_api( + self, id=None, require_morphology=False, require_reconstruction=False, reporter_status=None, species=None + ): criteria = None if id: criteria = "[specimen__id$eq%d]" % id - cells = self.model_query( - 'ApiCellTypesSpecimenDetail', criteria=criteria, num_rows='all') - + cells = self.model_query("ApiCellTypesSpecimenDetail", criteria=criteria, num_rows="all") + return cells @deprecated("please use list_cells_api instead") - def list_cells(self, - id=None, - require_morphology=False, - require_reconstruction=False, - reporter_status=None, - species=None): + def list_cells( + self, id=None, require_morphology=False, require_reconstruction=False, reporter_status=None, species=None + ): """ Query the API for a list of all cells in the Cell Types Database. Parameters ---------- id: int - ID of a cell. If not provided returns all matching cells. + ID of a cell. If not provided returns all matching cells. require_morphology: boolean Only return cells that have morphology images. @@ -109,47 +99,46 @@ def list_cells(self, criteria = "[id$eq'%d']" % id else: criteria = "[is_cell_specimen$eq'true'],products[name$in'Mouse Cell Types','Human Cell Types'],ephys_result[failed$eqfalse]" - - include = ('structure,cortex_layer,donor(transgenic_lines,organism,conditions),specimen_tags,cell_soma_locations,' + - 'ephys_features,data_sets,neuron_reconstructions,cell_reporter') - cells = self.model_query( - 'Specimen', criteria=criteria, include=include, num_rows='all') + include = ( + "structure,cortex_layer,donor(transgenic_lines,organism,conditions),specimen_tags,cell_soma_locations," + + "ephys_features,data_sets,neuron_reconstructions,cell_reporter" + ) + + cells = self.model_query("Specimen", criteria=criteria, include=include, num_rows="all") for cell in cells: # specimen tags - for tag in cell['specimen_tags']: - tag_name, tag_value = tag['name'].split(' - ') - tag_name = tag_name.replace(' ', '_') + for tag in cell["specimen_tags"]: + tag_name, tag_value = tag["name"].split(" - ") + tag_name = tag_name.replace(" ", "_") cell[tag_name] = tag_value # morphology and reconstuction - cell['has_reconstruction'] = len( - cell['neuron_reconstructions']) > 0 - cell['has_morphology'] = len(cell['data_sets']) > 0 + cell["has_reconstruction"] = len(cell["neuron_reconstructions"]) > 0 + cell["has_morphology"] = len(cell["data_sets"]) > 0 # transgenic line - cell['transgenic_line'] = None - for tl in cell['donor']['transgenic_lines']: - if tl['transgenic_line_type_name'] == 'driver': - cell['transgenic_line'] = tl['name'] + cell["transgenic_line"] = None + for tl in cell["donor"]["transgenic_lines"]: + if tl["transgenic_line_type_name"] == "driver": + cell["transgenic_line"] = tl["name"] # cell reporter status - cell['reporter_status'] = cell.get('cell_reporter', {}).get('name', None) + cell["reporter_status"] = cell.get("cell_reporter", {}).get("name", None) # species - cell['species'] = cell.get('donor',{}).get('organism',{}).get('name', None) + cell["species"] = cell.get("donor", {}).get("organism", {}).get("name", None) # conditions (whitelist) - condition_types = [ 'disease categories' ] - condition_keys = dict(zip(condition_types, - [ ct.replace(' ', '_') for ct in condition_types ])) + condition_types = ["disease categories"] + condition_keys = dict(zip(condition_types, [ct.replace(" ", "_") for ct in condition_types])) for ct, ck in condition_keys.items(): cell[ck] = [] - conditions = cell.get('donor',{}).get('conditions', []) + conditions = cell.get("donor", {}).get("conditions", []) for condition in conditions: - c_type, c_val = condition['name'].split(' - ') + c_type, c_val = condition["name"].split(" - ") if c_type in condition_keys: cell[condition_keys[c_type]].append(c_val) @@ -158,15 +147,15 @@ def list_cells(self, return result def get_cell(self, id): - ''' + """ Query the API for a one cells in the Cell Types Database. - + Returns ------- list Meta data for one cell. - ''' + """ cells = self.list_cells_api(id=id) cell = None if not cells else cells[0] @@ -187,10 +176,8 @@ def get_ephys_sweeps(self, specimen_id): list: List of sweep dictionaries belonging to a cell """ criteria = "[specimen_id$eq%d]" % specimen_id - sweeps = self.model_query( - 'EphysSweep', criteria=criteria, num_rows='all') - return sorted(sweeps, key=lambda x: x['sweep_number']) - + sweeps = self.model_query("EphysSweep", criteria=criteria, num_rows="all") + return sorted(sweeps, key=lambda x: x["sweep_number"]) @deprecated("please use filter_cells_api") def filter_cells(self, cells, require_morphology, require_reconstruction, reporter_status, species): @@ -219,38 +206,39 @@ def filter_cells(self, cells, require_morphology, require_reconstruction, report """ if require_morphology: - cells = [c for c in cells if c['has_morphology']] + cells = [c for c in cells if c["has_morphology"]] if require_reconstruction: - cells = [c for c in cells if c['has_reconstruction']] + cells = [c for c in cells if c["has_reconstruction"]] if reporter_status: - cells = [c for c in cells if c[ - 'reporter_status'] in reporter_status] + cells = [c for c in cells if c["reporter_status"] in reporter_status] if species: - species_lower = [ s.lower() for s in species ] - cells = [c for c in cells if c['donor']['organism']['name'].lower() in species_lower] + species_lower = [s.lower() for s in species] + cells = [c for c in cells if c["donor"]["organism"]["name"].lower() in species_lower] return cells - def filter_cells_api(self, cells, - require_morphology=False, - require_reconstruction=False, - reporter_status=None, - species=None, - simple=True): - """ - """ + def filter_cells_api( + self, + cells, + require_morphology=False, + require_reconstruction=False, + reporter_status=None, + species=None, + simple=True, + ): + """ """ if require_morphology or require_reconstruction: - cells = [c for c in cells if c.get('nr__reconstruction_type') is not None] + cells = [c for c in cells if c.get("nr__reconstruction_type") is not None] if reporter_status: - cells = [c for c in cells if c.get('cell_reporter_status') in reporter_status] + cells = [c for c in cells if c.get("cell_reporter_status") in reporter_status] if species: - species_lower = [ s.lower() for s in species ] - cells = [c for c in cells if c.get('donor__species',"").lower() in species_lower] + species_lower = [s.lower() for s in species] + cells = [c for c in cells if c.get("donor__species", "").lower() in species_lower] if simple: cells = self.simplify_cells_api(cells) @@ -258,24 +246,27 @@ def filter_cells_api(self, cells, return cells def simplify_cells_api(self, cells): - return [{ - 'reporter_status': cell['cell_reporter_status'], - 'cell_soma_location': [ cell['csl__x'], cell['csl__y'], cell['csl__z'] ], - 'species': cell['donor__species'], - 'id': cell['specimen__id'], - 'name': cell['specimen__name'], - 'structure_layer_name': cell['structure__layer'], - 'structure_area_id': cell['structure_parent__id'], - 'structure_area_abbrev': cell['structure_parent__acronym'], - 'transgenic_line': cell['line_name'], - 'dendrite_type': cell['tag__dendrite_type'], - 'apical': cell['tag__apical'], - 'reconstruction_type': cell['nr__reconstruction_type'], - 'disease_state': cell['donor__disease_state'], - 'donor_id': cell['donor__id'], - 'structure_hemisphere': cell['specimen__hemisphere'], - 'normalized_depth': cell['csl__normalized_depth'] - } for cell in cells ] + return [ + { + "reporter_status": cell["cell_reporter_status"], + "cell_soma_location": [cell["csl__x"], cell["csl__y"], cell["csl__z"]], + "species": cell["donor__species"], + "id": cell["specimen__id"], + "name": cell["specimen__name"], + "structure_layer_name": cell["structure__layer"], + "structure_area_id": cell["structure_parent__id"], + "structure_area_abbrev": cell["structure_parent__acronym"], + "transgenic_line": cell["line_name"], + "dendrite_type": cell["tag__dendrite_type"], + "apical": cell["tag__apical"], + "reconstruction_type": cell["nr__reconstruction_type"], + "disease_state": cell["donor__disease_state"], + "donor_id": cell["donor__id"], + "structure_hemisphere": cell["specimen__hemisphere"], + "normalized_depth": cell["csl__normalized_depth"], + } + for cell in cells + ] @cacheable() def get_ephys_features(self): @@ -283,29 +274,22 @@ def get_ephys_features(self): Query the API for the full table of EphysFeatures for all cells. """ - return self.model_query( - 'EphysFeature', - criteria='specimen(ephys_result[failed$eqfalse])', - num_rows='all') + return self.model_query("EphysFeature", criteria="specimen(ephys_result[failed$eqfalse])", num_rows="all") @cacheable() def get_morphology_features(self): """ Query the API for the full table of morphology features for all cells - + Notes ----- by default the tags column is removed because it isn't useful """ return self.model_query( - 'NeuronReconstruction', - criteria="specimen(ephys_result[failed$eqfalse])", - excpt='tags', - num_rows='all') - - @cacheable(strategy='create', - pathfinder=Cache.pathfinder(file_name_position=2, - path_keyword='file_name')) + "NeuronReconstruction", criteria="specimen(ephys_result[failed$eqfalse])", excpt="tags", num_rows="all" + ) + + @cacheable(strategy="create", pathfinder=Cache.pathfinder(file_name_position=2, path_keyword="file_name")) def save_ephys_data(self, specimen_id, file_name): """ Save the electrophysology recordings for a cell as an NWB file. @@ -318,18 +302,16 @@ def save_ephys_data(self, specimen_id, file_name): file_name: str Path to save the NWB file. """ - criteria = '[id$eq%d],ephys_result(well_known_files(well_known_file_type[name$eq%s]))' % ( - specimen_id, self.NWB_FILE_TYPE) - includes = 'ephys_result(well_known_files(well_known_file_type))' + criteria = "[id$eq%d],ephys_result(well_known_files(well_known_file_type[name$eq%s]))" % ( + specimen_id, + self.NWB_FILE_TYPE, + ) + includes = "ephys_result(well_known_files(well_known_file_type))" - results = self.model_query('Specimen', - criteria=criteria, - include=includes, - num_rows='all') + results = self.model_query("Specimen", criteria=criteria, include=includes, num_rows="all") try: - file_url = results[0]['ephys_result'][ - 'well_known_files'][0]['download_link'] + file_url = results[0]["ephys_result"]["well_known_files"][0]["download_link"] except Exception as _: # noqa: F841 raise Exception("Specimen %d has no ephys data" % specimen_id) @@ -350,17 +332,13 @@ def save_reconstruction(self, specimen_id, file_name): Manifest.safe_make_parent_dirs(file_name) - criteria = '[id$eq%d],neuron_reconstructions(well_known_files)' % specimen_id - includes = 'neuron_reconstructions(well_known_files(well_known_file_type[name$eq\'%s\']))' % self.SWC_FILE_TYPE + criteria = "[id$eq%d],neuron_reconstructions(well_known_files)" % specimen_id + includes = "neuron_reconstructions(well_known_files(well_known_file_type[name$eq'%s']))" % self.SWC_FILE_TYPE - results = self.model_query('Specimen', - criteria=criteria, - include=includes, - num_rows='all') + results = self.model_query("Specimen", criteria=criteria, include=includes, num_rows="all") try: - file_url = results[0]['neuron_reconstructions'][ - 0]['well_known_files'][0]['download_link'] + file_url = results[0]["neuron_reconstructions"][0]["well_known_files"][0]["download_link"] except Exception: raise Exception("Specimen %d has no reconstruction" % specimen_id) @@ -383,17 +361,13 @@ def save_reconstruction_markers(self, specimen_id, file_name): Manifest.safe_make_parent_dirs(file_name) - criteria = '[id$eq%d],neuron_reconstructions(well_known_files)' % specimen_id - includes = 'neuron_reconstructions(well_known_files(well_known_file_type[name$eq\'%s\']))' % self.MARKER_FILE_TYPE + criteria = "[id$eq%d],neuron_reconstructions(well_known_files)" % specimen_id + includes = "neuron_reconstructions(well_known_files(well_known_file_type[name$eq'%s']))" % self.MARKER_FILE_TYPE - results = self.model_query('Specimen', - criteria=criteria, - include=includes, - num_rows='all') + results = self.model_query("Specimen", criteria=criteria, include=includes, num_rows="all") try: - file_url = results[0]['neuron_reconstructions'][ - 0]['well_known_files'][0]['download_link'] + file_url = results[0]["neuron_reconstructions"][0]["well_known_files"][0]["download_link"] except Exception: raise LookupError("Specimen %d has no marker file" % specimen_id) diff --git a/allensdk/api/queries/connected_services.py b/allensdk/api/queries/connected_services.py index 17060012b5..7f8fd61606 100644 --- a/allensdk/api/queries/connected_services.py +++ b/allensdk/api/queries/connected_services.py @@ -37,7 +37,7 @@ class ConnectedServices(object): - ''' + """ A class representing a schema of informatics web services. Notes @@ -50,33 +50,32 @@ class ConnectedServices(object): Connected Services only include API services that are accessed via the RMA endpoint using an rma::services stage. - ''' - ARRAY = 'array' - STRING = 'string' - INTEGER = 'integer' - FLOAT = 'float' - BOOLEAN = 'boolean' + """ + + ARRAY = "array" + STRING = "string" + INTEGER = "integer" + FLOAT = "float" + BOOLEAN = "boolean" def __init__(self): pass def build_url(self, service_name, kwargs): - '''Create a single stage RMA url from a service name and parameters. - ''' + """Create a single stage RMA url from a service name and parameters.""" rma = RmaApi() - fmt = kwargs.get('fmt', 'json') + fmt = kwargs.get("fmt", "json") schema_entry = ConnectedServices._schema[service_name] params = [] - for parameter in schema_entry['parameters']: - value = kwargs.get(parameter['name'], None) + for parameter in schema_entry["parameters"]: + value = kwargs.get(parameter["name"], None) if value is not None: - params.append((parameter['name'], value)) + params.append((parameter["name"], value)) - service_stage = rma.service_stage(service_name, - params) + service_stage = rma.service_stage(service_name, params) url = rma.build_query_url([service_stage], fmt) @@ -85,1035 +84,395 @@ def build_url(self, service_name, kwargs): @classmethod @property def schema(cls): - '''Dictionary of service names and parameters. + """Dictionary of service names and parameters. Notes ----- See `Connected Services and Pipes `_ for a human-readable list of connected services and their parameters. - ''' + """ return cls._schema _schema = { - 'dev_human_correlation': { - 'parameters': [ - {'name': 'set', - 'optional': True, - 'type': STRING, - 'values': ['rna_seq_genes', - 'rna_seq_exons', - 'exon_microarray_genes' - 'exon_microarray_exons'] - }, - {'name': 'donors', - 'optional': True, - 'type': ARRAY - }, - {'name': 'structures', - 'optional': False, - 'type': ARRAY - }, - {'name': 'probes', - 'optional': False, - 'type': INTEGER - }, - {'name': 'sort_order', - 'type': STRING, - 'optional': True, - 'values': ['asc', 'desc'], - 'default': ['desc'] - }, - {'name': 'start_row', - 'type': INTEGER, - 'optional': True, - 'default': 0 - }, - {'name': 'num_rows', - 'type': INTEGER, - 'optional': True, - 'default': 2000 - } + "dev_human_correlation": { + "parameters": [ + { + "name": "set", + "optional": True, + "type": STRING, + "values": ["rna_seq_genes", "rna_seq_exons", "exon_microarray_genesexon_microarray_exons"], + }, + {"name": "donors", "optional": True, "type": ARRAY}, + {"name": "structures", "optional": False, "type": ARRAY}, + {"name": "probes", "optional": False, "type": INTEGER}, + { + "name": "sort_order", + "type": STRING, + "optional": True, + "values": ["asc", "desc"], + "default": ["desc"], + }, + {"name": "start_row", "type": INTEGER, "optional": True, "default": 0}, + {"name": "num_rows", "type": INTEGER, "optional": True, "default": 2000}, ] }, - 'dev_human_differential': { - 'parameters': [ - {'name': 'set', - 'type': STRING, - 'values': ['rna_seq_genes', - 'rna_seq_exons', - 'exon_microarray_genes', - 'exon_microarray_exons'] - }, - {'name': 'donors1', - 'type': ARRAY, - 'array_type': INTEGER, - 'optional': True - }, - {'name': 'structures1', - 'type': ARRAY, - 'array_type': INTEGER - }, - {'name': 'donors2', - 'type': ARRAY, - 'optional': True, - 'array_type': INTEGER - }, - {'name': 'structures2', - 'type': ARRAY, - 'array_type': INTEGER - }, - {'name': 'sort_by', - 'type': STRING, - 'optional': True, - 'values': ['p-value', 'fold-change'], - 'default': 'p-value' - }, - {'name': 'start_row', - 'type': INTEGER, - 'optional': True, - 'default': 0 - }, - {'name': 'num_rows', - 'type': INTEGER, - 'optional': True, - 'default': 2000 - } + "dev_human_differential": { + "parameters": [ + { + "name": "set", + "type": STRING, + "values": ["rna_seq_genes", "rna_seq_exons", "exon_microarray_genes", "exon_microarray_exons"], + }, + {"name": "donors1", "type": ARRAY, "array_type": INTEGER, "optional": True}, + {"name": "structures1", "type": ARRAY, "array_type": INTEGER}, + {"name": "donors2", "type": ARRAY, "optional": True, "array_type": INTEGER}, + {"name": "structures2", "type": ARRAY, "array_type": INTEGER}, + { + "name": "sort_by", + "type": STRING, + "optional": True, + "values": ["p-value", "fold-change"], + "default": "p-value", + }, + {"name": "start_row", "type": INTEGER, "optional": True, "default": 0}, + {"name": "num_rows", "type": INTEGER, "optional": True, "default": 2000}, ] }, - 'dev_human_expression': { - 'parameters': [ - {'name': 'set', - 'type': STRING, - 'values': ['rna_seq_genes', - 'rna_seq_exons', - 'exon_microarray_genes', - 'exon_microarray_exons'] - }, - {'name': 'probes', - 'type': INTEGER - }, - {'name': 'donors', - 'type': ARRAY, - 'array_type': INTEGER, - 'optional': True - }, - {'name': 'structures', - 'type': ARRAY, - 'array_type': INTEGER, - 'optional': True - }, - {'name': 'start_row', - 'type': INTEGER, - 'optional': True, - 'default': 0 - }, - {'name': 'num_rows', - 'type': INTEGER, - 'optional': True, - 'default': 2000 - } + "dev_human_expression": { + "parameters": [ + { + "name": "set", + "type": STRING, + "values": ["rna_seq_genes", "rna_seq_exons", "exon_microarray_genes", "exon_microarray_exons"], + }, + {"name": "probes", "type": INTEGER}, + {"name": "donors", "type": ARRAY, "array_type": INTEGER, "optional": True}, + {"name": "structures", "type": ARRAY, "array_type": INTEGER, "optional": True}, + {"name": "start_row", "type": INTEGER, "optional": True, "default": 0}, + {"name": "num_rows", "type": INTEGER, "optional": True, "default": 2000}, ] }, - 'dev_human_microarray_correlation': { - 'parameters': [ - {'name': 'donors', - 'type': ARRAY, - 'array_type': INTEGER, - 'optional': True - }, - {'name': 'structures', - 'type': ARRAY, - 'array_type': [INTEGER, STRING] - }, - {'name': 'probes', - 'type': INTEGER - }, - {'name': 'sort_order', - 'type': STRING, - 'optional': True, - 'values': ['asc', 'desc'], - 'default': 'desc' - }, - {'name': 'start_row', - 'type': INTEGER, - 'optional': True, - 'default': 0 - }, - {'name': 'num_rows', - 'type': INTEGER, - 'optional': True, - 'default': 2000 - } + "dev_human_microarray_correlation": { + "parameters": [ + {"name": "donors", "type": ARRAY, "array_type": INTEGER, "optional": True}, + {"name": "structures", "type": ARRAY, "array_type": [INTEGER, STRING]}, + {"name": "probes", "type": INTEGER}, + {"name": "sort_order", "type": STRING, "optional": True, "values": ["asc", "desc"], "default": "desc"}, + {"name": "start_row", "type": INTEGER, "optional": True, "default": 0}, + {"name": "num_rows", "type": INTEGER, "optional": True, "default": 2000}, ] }, - 'dev_human_microarray_differential': { - 'parameters': [ - {'name': 'donors1', - 'type': ARRAY, - 'array_type': INTEGER, - 'optional': True - }, - {'name': 'structures1', - 'type': ARRAY, - 'array_type': INTEGER - }, - {'name': 'donors2', - 'type': ARRAY, - 'array_type': INTEGER, - 'optional': True - }, - {'name': 'structures2', - 'type': ARRAY, - 'array_type': INTEGER - }, - {'name': 'sort_by', - 'type': STRING, - 'optional': True, - 'values': ['p-value', 'fold-change'], - 'default': 'p-value' - }, - {'name': 'start_row', - 'type': INTEGER, - 'optional': True, - 'default': 0 - }, - {'name': 'num_rows', - 'type': INTEGER, - 'optional': True, - 'default': 2000 - } + "dev_human_microarray_differential": { + "parameters": [ + {"name": "donors1", "type": ARRAY, "array_type": INTEGER, "optional": True}, + {"name": "structures1", "type": ARRAY, "array_type": INTEGER}, + {"name": "donors2", "type": ARRAY, "array_type": INTEGER, "optional": True}, + {"name": "structures2", "type": ARRAY, "array_type": INTEGER}, + { + "name": "sort_by", + "type": STRING, + "optional": True, + "values": ["p-value", "fold-change"], + "default": "p-value", + }, + {"name": "start_row", "type": INTEGER, "optional": True, "default": 0}, + {"name": "num_rows", "type": INTEGER, "optional": True, "default": 2000}, ] }, - 'dev_human_microarray_expression': { - 'parameters': [ - {'name': 'probes', - 'type': ARRAY, - 'array_type': INTEGER - }, - {'name': 'start_row', - 'type': INTEGER, - 'optional': True, - 'default': 0 - }, - {'name': 'num_rows', - 'type': INTEGER, - 'optional': True, - 'default': 2000 - }, - {'name': 'donors', - 'type': INTEGER, - 'optional': True, - }, - {'name': 'structures', - 'type': INTEGER, - 'optional': True - } + "dev_human_microarray_expression": { + "parameters": [ + {"name": "probes", "type": ARRAY, "array_type": INTEGER}, + {"name": "start_row", "type": INTEGER, "optional": True, "default": 0}, + {"name": "num_rows", "type": INTEGER, "optional": True, "default": 2000}, + { + "name": "donors", + "type": INTEGER, + "optional": True, + }, + {"name": "structures", "type": INTEGER, "optional": True}, ] }, - 'dev_mouse_agea': { - 'parameters': [ - {'name': 'seed_age', - 'type': STRING - }, - {'name': 'map_age', - 'type': STRING - }, - {'name': 'seed_point', - 'type': ARRAY, - 'array_type': INTEGER - }, - {'name': 'seed_threshold', - 'type': ARRAY, - 'array_type': FLOAT - }, - {'name': 'map_threshold', - 'type': ARRAY, - 'array_type': FLOAT - }, - {'name': 'contrast_threshold', - 'type': ARRAY, - 'array_type': INTEGER, - 'optional': True - }, - {'name': 'target_threshold', - 'type': ARRAY, - 'array_type': INTEGER, - 'optional': True - }, - {'name': 'start_row', - 'type': INTEGER, - 'optional': True, - 'default': 0 - }, - {'name': 'num_rows', - 'type': INTEGER, - 'optional': True, - 'default': 2000 - } + "dev_mouse_agea": { + "parameters": [ + {"name": "seed_age", "type": STRING}, + {"name": "map_age", "type": STRING}, + {"name": "seed_point", "type": ARRAY, "array_type": INTEGER}, + {"name": "seed_threshold", "type": ARRAY, "array_type": FLOAT}, + {"name": "map_threshold", "type": ARRAY, "array_type": FLOAT}, + {"name": "contrast_threshold", "type": ARRAY, "array_type": INTEGER, "optional": True}, + {"name": "target_threshold", "type": ARRAY, "array_type": INTEGER, "optional": True}, + {"name": "start_row", "type": INTEGER, "optional": True, "default": 0}, + {"name": "num_rows", "type": INTEGER, "optional": True, "default": 2000}, ] }, - 'dev_mouse_correlation': { - 'parameters': [ - {'name': 'row', - 'type': INTEGER - }, - {'name': 'structures', - 'type': ARRAY, - 'array_type': [INTEGER, STRING], - 'optional': True - }, - {'name': 'ages', - 'type': ARRAY, - 'array_type': STRING, - 'optional': True - }, - {'name': 'sort_order', - 'type': STRING, - 'optional': True, - 'values': ['asc', 'desc'], - 'default': 'desc' - }, - {'name': 'start_row', - 'type': INTEGER, - 'optional': True, - 'default': 0 - }, - {'name': 'num_rows', - 'type': INTEGER, - 'optional': True, - 'default': 2000 - } + "dev_mouse_correlation": { + "parameters": [ + {"name": "row", "type": INTEGER}, + {"name": "structures", "type": ARRAY, "array_type": [INTEGER, STRING], "optional": True}, + {"name": "ages", "type": ARRAY, "array_type": STRING, "optional": True}, + {"name": "sort_order", "type": STRING, "optional": True, "values": ["asc", "desc"], "default": "desc"}, + {"name": "start_row", "type": INTEGER, "optional": True, "default": 0}, + {"name": "num_rows", "type": INTEGER, "optional": True, "default": 2000}, ] }, - 'gbm_correlation': { - 'parameters': [ - {'name': 'donors', - 'type': ARRAY, - 'array_type': INTEGER, - 'optional': True - }, - {'name': 'structures', - 'type': ARRAY, - 'array_type': [INTEGER, STRING] - }, - {'name': 'probes', - 'type': INTEGER - }, - {'name': 'sort_order', - 'type': STRING, - 'optional': True, - 'values': ['asc', 'desc'], - 'default': 'desc' - }, - {'name': 'start_row', - 'type': INTEGER, - 'optional': True, - 'default': 0 - }, - {'name': 'num_rows', - 'type': INTEGER, - 'optional': True, - 'default': 2000 - } + "gbm_correlation": { + "parameters": [ + {"name": "donors", "type": ARRAY, "array_type": INTEGER, "optional": True}, + {"name": "structures", "type": ARRAY, "array_type": [INTEGER, STRING]}, + {"name": "probes", "type": INTEGER}, + {"name": "sort_order", "type": STRING, "optional": True, "values": ["asc", "desc"], "default": "desc"}, + {"name": "start_row", "type": INTEGER, "optional": True, "default": 0}, + {"name": "num_rows", "type": INTEGER, "optional": True, "default": 2000}, ] }, - 'gbm_differential': { - 'parameters': [ - {'name': 'donors1', - 'type': ARRAY, - 'array_type': INTEGER, - 'optional': True - }, - {'name': 'structures1', - 'type': ARRAY, - 'array_type': INTEGER - }, - {'name': 'donors2', - 'type': ARRAY, - 'array_type': INTEGER, - 'optional': True - }, - {'name': 'structures2', - 'type': ARRAY, - 'array_type': INTEGER - }, - {'name': 'sort_by', - 'type': STRING, - 'optional': True, - 'values': ['p-value', 'fold-change'], - 'default': 'p-value' - }, - {'name': 'start_row', - 'type': INTEGER, - 'optional': True, - 'default': 0 - }, - {'name': 'num_rows', - 'type': INTEGER, - 'optional': True, - 'default': 2000 - } + "gbm_differential": { + "parameters": [ + {"name": "donors1", "type": ARRAY, "array_type": INTEGER, "optional": True}, + {"name": "structures1", "type": ARRAY, "array_type": INTEGER}, + {"name": "donors2", "type": ARRAY, "array_type": INTEGER, "optional": True}, + {"name": "structures2", "type": ARRAY, "array_type": INTEGER}, + { + "name": "sort_by", + "type": STRING, + "optional": True, + "values": ["p-value", "fold-change"], + "default": "p-value", + }, + {"name": "start_row", "type": INTEGER, "optional": True, "default": 0}, + {"name": "num_rows", "type": INTEGER, "optional": True, "default": 2000}, ] }, - 'gbm_expression': { - 'parameters': [ - {'name': 'probes', - 'type': INTEGER, - 'array_type': INTEGER - }, - {'name': 'start_row', - 'type': INTEGER, - 'optional': True, - 'default': 0 - }, - {'name': 'num_rows', - 'type': INTEGER, - 'optional': True, - 'default': 2000 - }, - {'name': 'donors', - 'type': INTEGER, - 'optional': True - }, - {'name': 'structures', - 'type': INTEGER, - 'optional': True - } + "gbm_expression": { + "parameters": [ + {"name": "probes", "type": INTEGER, "array_type": INTEGER}, + {"name": "start_row", "type": INTEGER, "optional": True, "default": 0}, + {"name": "num_rows", "type": INTEGER, "optional": True, "default": 2000}, + {"name": "donors", "type": INTEGER, "optional": True}, + {"name": "structures", "type": INTEGER, "optional": True}, ] }, - 'gbm_ish_differential': { - 'parameters': [ - {'name': 'structures1', - 'type': ARRAY, - 'array_type': INTEGER - }, - {'name': 'structures2', - 'type': ARRAY, - 'array_type': INTEGER - }, - {'name': 'threshold1', - 'type': ARRAY, - 'array_type': INTEGER, - 'optional': True - }, - {'name': 'threshold2', - 'type': ARRAY, - 'array_type': INTEGER, - 'optional': True - }, - {'name': 'start_row', - 'type': INTEGER, - 'optional': True, - 'default': 0 - }, - {'name': 'num_rows', - 'type': INTEGER, - 'optional': True, - 'default': 2000 - } + "gbm_ish_differential": { + "parameters": [ + {"name": "structures1", "type": ARRAY, "array_type": INTEGER}, + {"name": "structures2", "type": ARRAY, "array_type": INTEGER}, + {"name": "threshold1", "type": ARRAY, "array_type": INTEGER, "optional": True}, + {"name": "threshold2", "type": ARRAY, "array_type": INTEGER, "optional": True}, + {"name": "start_row", "type": INTEGER, "optional": True, "default": 0}, + {"name": "num_rows", "type": INTEGER, "optional": True, "default": 2000}, ] }, - 'gbm_ish_expression': { - 'parameters': [ - {'name': 'structures', - 'type': ARRAY, - 'array_type': INTEGER - }, - {'name': 'threshold', - 'type': ARRAY, - 'array_type': INTEGER, - 'optional': True - }, - {'name': 'start_row', - 'type': INTEGER, - 'optional': True, - 'default': 0 - }, - {'name': 'num_rows', - 'type': INTEGER, - 'optional': True, - 'default': 2000 - } + "gbm_ish_expression": { + "parameters": [ + {"name": "structures", "type": ARRAY, "array_type": INTEGER}, + {"name": "threshold", "type": ARRAY, "array_type": INTEGER, "optional": True}, + {"name": "start_row", "type": INTEGER, "optional": True, "default": 0}, + {"name": "num_rows", "type": INTEGER, "optional": True, "default": 2000}, ] }, - 'human_microarray_correlation': { - 'parameters': [ - {'name': 'donors', - 'type': ARRAY, - 'array_type': INTEGER, - 'optional': True - }, - {'name': 'structures', - 'type': ARRAY, - 'array_type': [INTEGER, STRING] - }, - {'name': 'probes', - 'type': INTEGER - }, - {'name': 'sort_order', - 'type': STRING, - 'optional': True, - 'values': ['asc', 'desc'], - 'default': 'desc' - }, - {'name': 'start_row', - 'type': INTEGER, - 'optional': True, - 'default': 0 - }, - {'name': 'num_rows', - 'type': INTEGER, - 'optional': True, - 'default': 2000 - } + "human_microarray_correlation": { + "parameters": [ + {"name": "donors", "type": ARRAY, "array_type": INTEGER, "optional": True}, + {"name": "structures", "type": ARRAY, "array_type": [INTEGER, STRING]}, + {"name": "probes", "type": INTEGER}, + {"name": "sort_order", "type": STRING, "optional": True, "values": ["asc", "desc"], "default": "desc"}, + {"name": "start_row", "type": INTEGER, "optional": True, "default": 0}, + {"name": "num_rows", "type": INTEGER, "optional": True, "default": 2000}, ] }, - 'human_microarray_differential': { - 'parameters': [ - {'name': 'donors1', - 'type': ARRAY, - 'array_type': INTEGER, - 'optional': True - }, - {'name': 'structures1', - 'type': ARRAY, - 'array_type': INTEGER - }, - {'name': 'donors2', - 'type': ARRAY, - 'optional': True, - 'array_type': INTEGER - }, - {'name': 'structures2', - 'type': ARRAY, - 'array_type': INTEGER - }, - {'name': 'sort_by', - 'type': STRING, - 'values': ['p-value', 'fold-change'], - 'default': 'p-value' - }, - {'name': 'start_row', - 'type': INTEGER, - 'optional': True, - 'default': 0 - }, - {'name': 'num_rows', - 'type': INTEGER, - 'optional': True, - 'default': 2000 - } + "human_microarray_differential": { + "parameters": [ + {"name": "donors1", "type": ARRAY, "array_type": INTEGER, "optional": True}, + {"name": "structures1", "type": ARRAY, "array_type": INTEGER}, + {"name": "donors2", "type": ARRAY, "optional": True, "array_type": INTEGER}, + {"name": "structures2", "type": ARRAY, "array_type": INTEGER}, + {"name": "sort_by", "type": STRING, "values": ["p-value", "fold-change"], "default": "p-value"}, + {"name": "start_row", "type": INTEGER, "optional": True, "default": 0}, + {"name": "num_rows", "type": INTEGER, "optional": True, "default": 2000}, ] }, - 'human_microarray_expression': { - 'parameters': [ - {'name': 'probes', - 'type': ARRAY, - 'array_type': INTEGER - }, - {'name': 'donors', - 'type': ARRAY, - 'array_type': INTEGER, - 'optional': True - }, - {'name': 'structures', - 'type': ARRAY, - 'array_type': INTEGER, - 'optional': True - }, - {'name': 'start_row', - 'type': INTEGER, - 'optional': True, - 'default': 0 - }, - {'name': 'num_rows', - 'type': INTEGER, - 'optional': True, - 'default': 2000 - } + "human_microarray_expression": { + "parameters": [ + {"name": "probes", "type": ARRAY, "array_type": INTEGER}, + {"name": "donors", "type": ARRAY, "array_type": INTEGER, "optional": True}, + {"name": "structures", "type": ARRAY, "array_type": INTEGER, "optional": True}, + {"name": "start_row", "type": INTEGER, "optional": True, "default": 0}, + {"name": "num_rows", "type": INTEGER, "optional": True, "default": 2000}, ] }, - 'mouse_agea': { - 'parameters': [ - {'name': 'set', - 'type': STRING - }, - {'name': 'seed_age', - 'type': STRING - }, - {'name': 'map_age', - 'type': STRING - }, - {'name': 'seed_point', - 'type': ARRAY, - 'array_type': FLOAT - }, - {'name': 'correlation_threshold1', - 'type': FLOAT, - 'optional': True, - }, - {'name': 'correlation_threshold2', - 'type': FLOAT, - 'optional': True - }, - {'name': 'threshold1', - 'type': ARRAY, - 'array_type': INTEGER, - 'optional': True - }, - {'name': 'threshold2', - 'type': ARRAY, - 'array_type': INTEGER, - 'optional': True - }, - {'name': 'start_row', - 'type': INTEGER, - 'optional': True, - 'default': 0 - }, - {'name': 'num_rows', - 'type': INTEGER, - 'optional': True, - 'default': 2000 - } + "mouse_agea": { + "parameters": [ + {"name": "set", "type": STRING}, + {"name": "seed_age", "type": STRING}, + {"name": "map_age", "type": STRING}, + {"name": "seed_point", "type": ARRAY, "array_type": FLOAT}, + { + "name": "correlation_threshold1", + "type": FLOAT, + "optional": True, + }, + {"name": "correlation_threshold2", "type": FLOAT, "optional": True}, + {"name": "threshold1", "type": ARRAY, "array_type": INTEGER, "optional": True}, + {"name": "threshold2", "type": ARRAY, "array_type": INTEGER, "optional": True}, + {"name": "start_row", "type": INTEGER, "optional": True, "default": 0}, + {"name": "num_rows", "type": INTEGER, "optional": True, "default": 2000}, ] }, - 'mouse_correlation': { - 'parameters': [ - {'name': 'set', - 'type': STRING, - 'values': ['mouse', 'mouse_coronal'] - }, - {'name': 'structures', - 'type': ARRAY, - 'array_type': [INTEGER, STRING], - 'optional': True - }, - {'name': 'row', - 'type': INTEGER - }, - {'name': 'sort_order', - 'type': STRING, - 'optional': True, - 'values': ['asc', 'desc'], - 'default': 'desc' - }, - {'name': 'start_row', - 'type': INTEGER, - 'optional': True, - 'default': 0 - }, - {'name': 'num_rows', - 'type': INTEGER, - 'optional': True, - 'default': 2000 - } + "mouse_correlation": { + "parameters": [ + {"name": "set", "type": STRING, "values": ["mouse", "mouse_coronal"]}, + {"name": "structures", "type": ARRAY, "array_type": [INTEGER, STRING], "optional": True}, + {"name": "row", "type": INTEGER}, + {"name": "sort_order", "type": STRING, "optional": True, "values": ["asc", "desc"], "default": "desc"}, + {"name": "start_row", "type": INTEGER, "optional": True, "default": 0}, + {"name": "num_rows", "type": INTEGER, "optional": True, "default": 2000}, ] }, - 'mouse_differential': { - 'parameters': [ - {'name': 'set', - 'type': STRING, - 'values': ['mouse', 'mouse_coronal'] - }, - {'name': 'structures1', - 'type': ARRAY, - 'array_type': INTEGER - }, - {'name': 'structures2', - 'type': ARRAY, - 'array_type': INTEGER - }, - {'name': 'threshold1', - 'type': ARRAY, - 'array_type': INTEGER, - 'optional': True - }, - {'name': 'threshold2', - 'type': ARRAY, - 'array_type': INTEGER, - 'optional': True - }, - {'name': 'start_row', - 'type': INTEGER, - 'optional': True, - 'default': 0 - }, - {'name': 'num_rows', - 'type': INTEGER, - 'optional': True, - 'default': 2000 - } + "mouse_differential": { + "parameters": [ + {"name": "set", "type": STRING, "values": ["mouse", "mouse_coronal"]}, + {"name": "structures1", "type": ARRAY, "array_type": INTEGER}, + {"name": "structures2", "type": ARRAY, "array_type": INTEGER}, + {"name": "threshold1", "type": ARRAY, "array_type": INTEGER, "optional": True}, + {"name": "threshold2", "type": ARRAY, "array_type": INTEGER, "optional": True}, + {"name": "start_row", "type": INTEGER, "optional": True, "default": 0}, + {"name": "num_rows", "type": INTEGER, "optional": True, "default": 2000}, ] }, - 'mouse_connectivity_correlation': { - 'parameters': [ - {'name': 'row', - 'type': INTEGER - }, - {'name': 'structures', - 'type': ARRAY, - 'array_type': [INTEGER, STRING], - 'optional': True - }, - {'name': 'product_ids', - 'type': ARRAY, - 'array_type': [INTEGER], - 'optional': True - }, - {'name': 'hemisphere', - 'type': STRING, - 'optional': True, - 'values': ['right', 'left'] - }, - {'name': 'transgenic_lines', - 'type': ARRAY, - 'array_type': [INTEGER, STRING], - 'optional': True - }, - {'name': 'injection_structures', - 'type': 'Array', - 'array_type': [INTEGER, STRING], - 'optional': True - }, - {'name': 'primary_structure_only', - 'type': BOOLEAN, - 'optional': True, - }, - {'name': 'sort_order', - 'type': STRING, - 'optional': True, - 'values': ['asc', 'desc'], - 'default': 'desc' - }, - {'name': 'start_row', - 'type': INTEGER, - 'optional': True, - 'default': 0 - }, - {'name': 'num_rows', - 'type': INTEGER, - 'optional': True, - 'default': 2000 - } + "mouse_connectivity_correlation": { + "parameters": [ + {"name": "row", "type": INTEGER}, + {"name": "structures", "type": ARRAY, "array_type": [INTEGER, STRING], "optional": True}, + {"name": "product_ids", "type": ARRAY, "array_type": [INTEGER], "optional": True}, + {"name": "hemisphere", "type": STRING, "optional": True, "values": ["right", "left"]}, + {"name": "transgenic_lines", "type": ARRAY, "array_type": [INTEGER, STRING], "optional": True}, + {"name": "injection_structures", "type": "Array", "array_type": [INTEGER, STRING], "optional": True}, + { + "name": "primary_structure_only", + "type": BOOLEAN, + "optional": True, + }, + {"name": "sort_order", "type": STRING, "optional": True, "values": ["asc", "desc"], "default": "desc"}, + {"name": "start_row", "type": INTEGER, "optional": True, "default": 0}, + {"name": "num_rows", "type": INTEGER, "optional": True, "default": 2000}, ] }, - 'mouse_connectivity_injection_coordinate': { - 'parameters': [ - {'name': 'seed_point', - 'type': ARRAY, - 'array_type': FLOAT, - 'optional': False, - }, - {'name': 'transgenic_lines', - 'type': ARRAY, - 'array_type': [INTEGER, STRING], - 'optional': True - }, - {'name': 'injection_structures', - 'type': ARRAY, - 'array_type': [INTEGER, STRING], - 'optional': True - }, - {'name': 'product_ids', - 'type': ARRAY, - 'array_type': [INTEGER], - 'optional': True - }, - {'name': 'primary_structure_only', - 'optional': True - }, - {'name': 'start_row', - 'type': INTEGER, - 'optional': True, - 'default': 0 - }, - {'name': 'num_rows', - 'type': INTEGER, - 'optional': True, - 'default': 2000 - } + "mouse_connectivity_injection_coordinate": { + "parameters": [ + { + "name": "seed_point", + "type": ARRAY, + "array_type": FLOAT, + "optional": False, + }, + {"name": "transgenic_lines", "type": ARRAY, "array_type": [INTEGER, STRING], "optional": True}, + {"name": "injection_structures", "type": ARRAY, "array_type": [INTEGER, STRING], "optional": True}, + {"name": "product_ids", "type": ARRAY, "array_type": [INTEGER], "optional": True}, + {"name": "primary_structure_only", "optional": True}, + {"name": "start_row", "type": INTEGER, "optional": True, "default": 0}, + {"name": "num_rows", "type": INTEGER, "optional": True, "default": 2000}, ] }, - 'mouse_connectivity_injection_structure': { - 'parameters': [ - {'name': 'injection_structures', - 'type': ARRAY, - 'array_type': [INTEGER, STRING] - }, - {'name': 'target_domain', - 'type': ARRAY, - 'array_type': [INTEGER, STRING], - 'optional': True - }, - {'name': 'injection_hemisphere', - 'type': STRING, - 'optional': True, - 'values': ['right', 'left'] - }, - {'name': 'target_hemisphere', - 'type': STRING, - 'optional': True, - 'values': ['right', 'left'] - }, - {'name': 'transgenic_lines', - 'type': ARRAY, - 'array_type': [INTEGER, STRING], - 'optional': True - }, - {'name': 'injection_domain', - 'type': ARRAY, - 'array_type': [INTEGER, STRING], - 'optional': True - }, - {'name': 'product_ids', - 'type': ARRAY, - 'array_type': [INTEGER], - 'optional': True - }, - {'name': 'primary_structure_only', - 'type': BOOLEAN, - 'optional': True - }, - {'name': 'start_row', - 'type': INTEGER, - 'optional': True, - 'default': 0 - }, - {'name': 'num_rows', - 'type': INTEGER, - 'optional': True, - 'default': 2000 - } + "mouse_connectivity_injection_structure": { + "parameters": [ + {"name": "injection_structures", "type": ARRAY, "array_type": [INTEGER, STRING]}, + {"name": "target_domain", "type": ARRAY, "array_type": [INTEGER, STRING], "optional": True}, + {"name": "injection_hemisphere", "type": STRING, "optional": True, "values": ["right", "left"]}, + {"name": "target_hemisphere", "type": STRING, "optional": True, "values": ["right", "left"]}, + {"name": "transgenic_lines", "type": ARRAY, "array_type": [INTEGER, STRING], "optional": True}, + {"name": "injection_domain", "type": ARRAY, "array_type": [INTEGER, STRING], "optional": True}, + {"name": "product_ids", "type": ARRAY, "array_type": [INTEGER], "optional": True}, + {"name": "primary_structure_only", "type": BOOLEAN, "optional": True}, + {"name": "start_row", "type": INTEGER, "optional": True, "default": 0}, + {"name": "num_rows", "type": INTEGER, "optional": True, "default": 2000}, ] }, - 'mouse_connectivity_target_spatial': { - 'parameters': [ - {'name': 'seed_point', - 'type': ARRAY, - 'array_type': FLOAT, - }, - {'name': 'transgenic_lines', - 'type': ARRAY, - 'array_type': [INTEGER, STRING], - 'optional': True - }, - {'name': 'section_data_set', - 'type': INTEGER, - 'optional': True - }, - {'name': 'injection_structures', - 'type': ARRAY, - 'array_type': [INTEGER, STRING], - 'optional': True - }, - {'name': 'product_ids', - 'type': ARRAY, - 'array_type': [INTEGER], - 'optional': True - }, - {'name': 'primary_structure_only', - 'type': BOOLEAN, - 'optional': True - }, - {'name': 'start_row', - 'type': INTEGER, - 'optional': True, - 'default': 0 - }, - {'name': 'num_rows', - 'type': INTEGER, - 'optional': True, - 'default': 2000 - } + "mouse_connectivity_target_spatial": { + "parameters": [ + { + "name": "seed_point", + "type": ARRAY, + "array_type": FLOAT, + }, + {"name": "transgenic_lines", "type": ARRAY, "array_type": [INTEGER, STRING], "optional": True}, + {"name": "section_data_set", "type": INTEGER, "optional": True}, + {"name": "injection_structures", "type": ARRAY, "array_type": [INTEGER, STRING], "optional": True}, + {"name": "product_ids", "type": ARRAY, "array_type": [INTEGER], "optional": True}, + {"name": "primary_structure_only", "type": BOOLEAN, "optional": True}, + {"name": "start_row", "type": INTEGER, "optional": True, "default": 0}, + {"name": "num_rows", "type": INTEGER, "optional": True, "default": 2000}, ] }, - 'nhp_lmd_microarray_correlation': { - 'parameters': [ - {'name': 'donors', - 'type': ARRAY, - 'array_type': INTEGER, - 'optional': True - }, - {'name': 'structures', - 'type': ARRAY, - 'array_type': INTEGER, - }, - {'name': 'probes', - 'type': ARRAY, - 'array_type': INTEGER - }, - {'name': 'sort_order', - 'type': STRING, - 'optional': True, - 'values': ['asc', 'desc'], - 'default': 'desc' - }, - {'name': 'start_row', - 'type': INTEGER, - 'optional': True, - 'default': 0 - }, - {'name': 'num_rows', - 'type': INTEGER, - 'optional': True, - 'default': 2000 - } + "nhp_lmd_microarray_correlation": { + "parameters": [ + {"name": "donors", "type": ARRAY, "array_type": INTEGER, "optional": True}, + { + "name": "structures", + "type": ARRAY, + "array_type": INTEGER, + }, + {"name": "probes", "type": ARRAY, "array_type": INTEGER}, + {"name": "sort_order", "type": STRING, "optional": True, "values": ["asc", "desc"], "default": "desc"}, + {"name": "start_row", "type": INTEGER, "optional": True, "default": 0}, + {"name": "num_rows", "type": INTEGER, "optional": True, "default": 2000}, ] }, - 'nhp_lmd_microarray_differential': { - 'parameters': [ - {'name': 'donors1', - 'type': ARRAY, - 'array_type': INTEGER, - 'optional': True - }, - {'name': 'structures1', - 'type': ARRAY, - 'array_type': INTEGER - }, - {'name': 'donors2', - 'type': ARRAY, - 'array_type': INTEGER, - 'optional': True - }, - {'name': 'structures2', - 'type': ARRAY, - 'array_type': INTEGER - }, - {'name': 'sort_by', - 'type': STRING, - 'optional': True, - 'values': ['p-value', 'fold-change'], - 'default': 'p-value' - }, - {'name': 'start_row', - 'type': INTEGER, - 'optional': True, - 'default': 0 - }, - {'name': 'num_rows', - 'type': INTEGER, - 'optional': True, - 'default': 2000 - } + "nhp_lmd_microarray_differential": { + "parameters": [ + {"name": "donors1", "type": ARRAY, "array_type": INTEGER, "optional": True}, + {"name": "structures1", "type": ARRAY, "array_type": INTEGER}, + {"name": "donors2", "type": ARRAY, "array_type": INTEGER, "optional": True}, + {"name": "structures2", "type": ARRAY, "array_type": INTEGER}, + { + "name": "sort_by", + "type": STRING, + "optional": True, + "values": ["p-value", "fold-change"], + "default": "p-value", + }, + {"name": "start_row", "type": INTEGER, "optional": True, "default": 0}, + {"name": "num_rows", "type": INTEGER, "optional": True, "default": 2000}, ] }, - 'nhp_lmd_microarray_expression': { - 'parameters': [ - {'name': 'probes', - 'type': ARRAY, - 'array_type': INTEGER - }, - {'name': 'start_row', - 'type': INTEGER, - 'optional': True, - 'default': 0 - }, - {'name': 'num_rows', - 'type': INTEGER, - 'optional': True, - 'default': 2000 - } + "nhp_lmd_microarray_expression": { + "parameters": [ + {"name": "probes", "type": ARRAY, "array_type": INTEGER}, + {"name": "start_row", "type": INTEGER, "optional": True, "default": 0}, + {"name": "num_rows", "type": INTEGER, "optional": True, "default": 2000}, ] }, - 'nhp_macro_microarray_correlation': { - 'parameters': [ - {'name': 'donors', - 'type': ARRAY, - 'array_type': INTEGER, - 'optional': True - }, - {'name': 'structures', - 'type': ARRAY, - 'array_type': INTEGER - }, - {'name': 'probes', - 'type': ARRAY, - 'array_type': INTEGER - }, - {'name': 'sort_order', - 'type': STRING, - 'optional': True, - 'values': ['asc', 'desc'], - 'default': 'desc' - }, - {'name': 'start_row', - 'type': INTEGER, - 'optional': True, - 'default': 0 - }, - {'name': 'num_rows', - 'type': INTEGER, - 'optional': True, - 'default': 2000 - } + "nhp_macro_microarray_correlation": { + "parameters": [ + {"name": "donors", "type": ARRAY, "array_type": INTEGER, "optional": True}, + {"name": "structures", "type": ARRAY, "array_type": INTEGER}, + {"name": "probes", "type": ARRAY, "array_type": INTEGER}, + {"name": "sort_order", "type": STRING, "optional": True, "values": ["asc", "desc"], "default": "desc"}, + {"name": "start_row", "type": INTEGER, "optional": True, "default": 0}, + {"name": "num_rows", "type": INTEGER, "optional": True, "default": 2000}, ] }, - 'nhp_macro_microarray_differential': { - 'parameters': [ - {'name': 'donors1', - 'type': ARRAY, - 'array_type': INTEGER, - 'optional': True - }, - {'name': 'structures1', - 'type': ARRAY, - 'array_type': INTEGER - }, - {'name': 'donors2', - 'type': ARRAY, - 'array_type': INTEGER, - 'optional': True - }, - {'name': 'structures2', - 'array_type': INTEGER, - 'type': ARRAY - }, - {'name': 'sort_by', - 'type': STRING, - 'optional': True, - 'values': ['p-value', 'fold-change'], - 'default': 'p-value' - }, - {'name': 'start_row', - 'type': INTEGER, - 'optional': True, - 'default': 0 - }, - {'name': 'num_rows', - 'type': INTEGER, - 'optional': True, - 'default': 2000 - } + "nhp_macro_microarray_differential": { + "parameters": [ + {"name": "donors1", "type": ARRAY, "array_type": INTEGER, "optional": True}, + {"name": "structures1", "type": ARRAY, "array_type": INTEGER}, + {"name": "donors2", "type": ARRAY, "array_type": INTEGER, "optional": True}, + {"name": "structures2", "array_type": INTEGER, "type": ARRAY}, + { + "name": "sort_by", + "type": STRING, + "optional": True, + "values": ["p-value", "fold-change"], + "default": "p-value", + }, + {"name": "start_row", "type": INTEGER, "optional": True, "default": 0}, + {"name": "num_rows", "type": INTEGER, "optional": True, "default": 2000}, ] }, - 'nhp_macro_microarray_expression': { - 'parameters': [ - {'name': 'probes', - 'type': ARRAY, - 'array_type': INTEGER - }, - {'name': 'start_row', - 'type': INTEGER, - 'optional': True, - 'default': 0 - }, - {'name': 'num_rows', - 'type': INTEGER, - 'optional': True, - 'default': 2000 - } + "nhp_macro_microarray_expression": { + "parameters": [ + {"name": "probes", "type": ARRAY, "array_type": INTEGER}, + {"name": "start_row", "type": INTEGER, "optional": True, "default": 0}, + {"name": "num_rows", "type": INTEGER, "optional": True, "default": 2000}, ] }, - 'text_search': { - 'parameters': [ - {'name': 'query_string', - 'type': STRING - }, - {'name': 'k', - 'type': STRING - } - ] - } + "text_search": {"parameters": [{"name": "query_string", "type": STRING}, {"name": "k", "type": STRING}]}, } diff --git a/allensdk/api/queries/glif_api.py b/allensdk/api/queries/glif_api.py index 5ec4b1028d..dc550126e3 100644 --- a/allensdk/api/queries/glif_api.py +++ b/allensdk/api/queries/glif_api.py @@ -40,109 +40,101 @@ class GlifApi(RmaTemplate): - - _log = logging.getLogger('allensdk.api.queries.glif_api') + _log = logging.getLogger("allensdk.api.queries.glif_api") NWB_FILE_TYPE = None - GLIF_TYPES = [ 395310498, 395310469, 395310475, 395310479, 471355161 ] - - rma_templates = \ - {"glif_queries": [ - {'name': 'neuronal_model_templates', - 'description': 'see name', - 'model': 'NeuronalModelTemplate', - 'num_rows': 'all', - 'count': False, - }, - {'name': 'neuronal_models', - 'description': 'see name', - 'model': 'Specimen', - 'include': 'neuronal_models(well_known_files,neuronal_model_template[id$in' + ','.join(map(str,GLIF_TYPES)) + '],neuronal_model_runs(well_known_files))', - 'criteria':'{% if ephys_experiment_ids is defined %}[id$in{{ ephys_experiment_ids }}]{%endif%}', - 'num_rows': 'all', - 'criteria_params':['ephys_experiment_ids'], - 'count': False, - }, - {'name': 'neuron_config', - 'description': 'see name', - 'model': 'NeuronalModel', - 'include': 'well_known_files(well_known_file_type)', - 'criteria':'[id$in{{ neuronal_model_ids }}]', - 'num_rows': 'all', - 'criteria_params':['neuronal_model_ids'], - 'count': False, - } - ] - } + GLIF_TYPES = [395310498, 395310469, 395310475, 395310479, 471355161] + + rma_templates = { + "glif_queries": [ + { + "name": "neuronal_model_templates", + "description": "see name", + "model": "NeuronalModelTemplate", + "num_rows": "all", + "count": False, + }, + { + "name": "neuronal_models", + "description": "see name", + "model": "Specimen", + "include": "neuronal_models(well_known_files,neuronal_model_template[id$in" + + ",".join(map(str, GLIF_TYPES)) + + "],neuronal_model_runs(well_known_files))", + "criteria": "{% if ephys_experiment_ids is defined %}[id$in{{ ephys_experiment_ids }}]{%endif%}", + "num_rows": "all", + "criteria_params": ["ephys_experiment_ids"], + "count": False, + }, + { + "name": "neuron_config", + "description": "see name", + "model": "NeuronalModel", + "include": "well_known_files(well_known_file_type)", + "criteria": "[id$in{{ neuronal_model_ids }}]", + "num_rows": "all", + "criteria_params": ["neuronal_model_ids"], + "count": False, + }, + ] + } def __init__(self, base_uri=None): super(GlifApi, self).__init__(base_uri, query_manifest=GlifApi.rma_templates) def get_neuronal_model_templates(self): - - return self.template_query('glif_queries', - 'neuronal_model_templates') + return self.template_query("glif_queries", "neuronal_model_templates") def get_neuronal_models(self, ephys_experiment_ids=None): - return self.template_query('glif_queries', - 'neuronal_models', ephys_experiment_ids=ephys_experiment_ids) + return self.template_query("glif_queries", "neuronal_models", ephys_experiment_ids=ephys_experiment_ids) def get_neuronal_models_by_id(self, neuronal_model_ids=None): - return self.template_query('glif_queries', - 'neuron_config', neuronal_model_ids=neuronal_model_ids) + return self.template_query("glif_queries", "neuron_config", neuronal_model_ids=neuronal_model_ids) def get_neuron_configs(self, neuronal_model_ids=None): - - data = self.template_query('glif_queries', - 'neuron_config', neuronal_model_ids=neuronal_model_ids) + data = self.template_query("glif_queries", "neuron_config", neuronal_model_ids=neuronal_model_ids) return_dict = {} for curr_config in data: - neuron_config_url = curr_config['well_known_files'][0]['download_link'] - return_dict[curr_config['id']] = self.retrieve_parsed_json_over_http(self.api_url + - neuron_config_url) + neuron_config_url = curr_config["well_known_files"][0]["download_link"] + return_dict[curr_config["id"]] = self.retrieve_parsed_json_over_http(self.api_url + neuron_config_url) return return_dict - @deprecated() def list_neuronal_models(self): - ''' DEPRECATED Query the API for a list of all GLIF neuronal models. + """DEPRECATED Query the API for a list of all GLIF neuronal models. Returns ------- list Meta data for all GLIF neuronal models. - ''' + """ include = "specimen(ephys_result[failed$eqfalse]),neuronal_model_template[name$il'*LIF*']" - return self.model_query('NeuronalModel', - include=include, - num_rows='all') + return self.model_query("NeuronalModel", include=include, num_rows="all") @deprecated() def get_neuronal_model(self, neuronal_model_id): - '''DEPRECATED Query the current RMA endpoint with a neuronal_model id + """DEPRECATED Query the current RMA endpoint with a neuronal_model id to get the corresponding well known files and meta data. Returns ------- dict A dictionary containing - ''' - + """ - include = ('neuronal_model_template(well_known_files(well_known_file_type)),' + - 'specimen(ephys_sweeps,ephys_result(well_known_files(well_known_file_type))),' + - 'well_known_files(well_known_file_type)') + include = ( + "neuronal_model_template(well_known_files(well_known_file_type))," + + "specimen(ephys_sweeps,ephys_result(well_known_files(well_known_file_type)))," + + "well_known_files(well_known_file_type)" + ) criteria = "[id$eq%d]" % neuronal_model_id - self.neuronal_model = self.model_query('NeuronalModel', - criteria=criteria, - include=include, - num_rows='all')[0] + self.neuronal_model = self.model_query("NeuronalModel", criteria=criteria, include=include, num_rows="all")[0] self.ephys_sweeps = None self.neuron_config_url = None @@ -150,101 +142,98 @@ def get_neuronal_model(self, neuronal_model_id): # sweeps come from the specimen try: - specimen = self.neuronal_model['specimen'] - self.ephys_sweeps = specimen['ephys_sweeps'] + specimen = self.neuronal_model["specimen"] + self.ephys_sweeps = specimen["ephys_sweeps"] except Exception as e: logging.info(e.args) self.ephys_sweeps = None if self.ephys_sweeps is None: - logging.warning( - "Could not find ephys_sweeps for this model (%d)" % self.neuronal_model['id']) + logging.warning("Could not find ephys_sweeps for this model (%d)" % self.neuronal_model["id"]) # neuron config file comes from the neuronal model's well known files try: - for wkf in self.neuronal_model['well_known_files']: - if wkf['path'].endswith('neuron_config.json'): - self.neuron_config_url = wkf['download_link'] + for wkf in self.neuronal_model["well_known_files"]: + if wkf["path"].endswith("neuron_config.json"): + self.neuron_config_url = wkf["download_link"] break except Exception: self.neuron_config_url = None if self.neuron_config_url is None: logging.warning( - "Could not find neuron config well_known_file for this model (%d)" % self.neuronal_model['id']) + "Could not find neuron config well_known_file for this model (%d)" % self.neuronal_model["id"] + ) # NWB file comes from the ephys_result's well known files try: - ephys_result = specimen['ephys_result'] - for wkf in ephys_result['well_known_files']: - if wkf['well_known_file_type']['name'] == 'NWBDownload': - self.stimulus_url = wkf['download_link'] + ephys_result = specimen["ephys_result"] + for wkf in ephys_result["well_known_files"]: + if wkf["well_known_file_type"]["name"] == "NWBDownload": + self.stimulus_url = wkf["download_link"] break except Exception: self.stimulus_url = None if self.stimulus_url is None: - logging.warning( - "Could not find stimulus well_known_file for this model (%d)" % self.neuronal_model['id']) + logging.warning("Could not find stimulus well_known_file for this model (%d)" % self.neuronal_model["id"]) self.metadata = { - 'neuron_config_url': self.neuron_config_url, - 'stimulus_url': self.stimulus_url, - 'ephys_sweeps': self.ephys_sweeps, - 'neuronal_model': self.neuronal_model + "neuron_config_url": self.neuron_config_url, + "stimulus_url": self.stimulus_url, + "ephys_sweeps": self.ephys_sweeps, + "neuronal_model": self.neuronal_model, } return self.metadata @deprecated() def get_ephys_sweeps(self): - ''' DEPRECATED Retrieve ephys sweep information out of downloaded metadata for a neuronal model + """DEPRECATED Retrieve ephys sweep information out of downloaded metadata for a neuronal model Returns ------- list A list of sweeps metadata dictionaries - ''' + """ return self.ephys_sweeps @deprecated() def get_neuron_config(self, output_file_name=None): - ''' DEPRECATED Retrieve a model configuration file from the API, optionally save it to disk, and + """DEPRECATED Retrieve a model configuration file from the API, optionally save it to disk, and return the contents of that file as a dictionary. Parameters ---------- output_file_name: string File name to store the neuron configuration (optional). - ''' + """ if self.neuron_config_url is None: raise Exception("URL for neuron config file is empty.") logging.info(self.api_url + self.neuron_config_url) - neuron_config = self.retrieve_parsed_json_over_http( - self.api_url + self.neuron_config_url) + neuron_config = self.retrieve_parsed_json_over_http(self.api_url + self.neuron_config_url) if output_file_name: - with open(output_file_name, 'wb') as f: + with open(output_file_name, "wb") as f: f.write(json.dumps(neuron_config, indent=2)) return neuron_config @deprecated() def cache_stimulus_file(self, output_file_name): - ''' DEPRECATED Download the NWB file for the current neuronal model and save it to a file. + """DEPRECATED Download the NWB file for the current neuronal model and save it to a file. Parameters ---------- output_file_name: string File name to store the NWB file. - ''' + """ if self.stimulus_url is None: raise Exception("URL for stimulus file is empty.") - self.retrieve_file_over_http( - self.api_url + self.metadata['stimulus_url'], output_file_name) + self.retrieve_file_over_http(self.api_url + self.metadata["stimulus_url"], output_file_name) diff --git a/allensdk/api/queries/grid_data_api.py b/allensdk/api/queries/grid_data_api.py index 9702326a72..73da662ca0 100644 --- a/allensdk/api/queries/grid_data_api.py +++ b/allensdk/api/queries/grid_data_api.py @@ -40,61 +40,51 @@ class GridDataApi(RmaApi): - '''HTTP Client for the Allen 3-D Expression Grid Data Service. + """HTTP Client for the Allen 3-D Expression Grid Data Service. See: `Downloading 3-D Expression Grid Data `_ - ''' - - INJECTION_DENSITY = 'injection_density' - PROJECTION_DENSITY = 'projection_density' - INJECTION_FRACTION = 'injection_fraction' - INJECTION_ENERGY = 'injection_energy' - PROJECTION_ENERGY = 'projection_energy' - DATA_MASK = 'data_mask' - - ENERGY = 'energy' - DENSITY = 'density' - INTENSITY = 'intensity' - - def __init__(self, - resolution=None, - base_uri=None): + """ + + INJECTION_DENSITY = "injection_density" + PROJECTION_DENSITY = "projection_density" + INJECTION_FRACTION = "injection_fraction" + INJECTION_ENERGY = "injection_energy" + PROJECTION_ENERGY = "projection_energy" + DATA_MASK = "data_mask" + + ENERGY = "energy" + DENSITY = "density" + INTENSITY = "intensity" + + def __init__(self, resolution=None, base_uri=None): super(GridDataApi, self).__init__(base_uri) if resolution is None: resolution = 25 self.resolution = resolution - - def download_gene_expression_grid_data(self, - section_data_set_id, - volume_type, - path): - ''' Download a metaimage file containing registered gene expression grid data + def download_gene_expression_grid_data(self, section_data_set_id, volume_type, path): + """Download a metaimage file containing registered gene expression grid data Parameters ---------- section_data_set_id : int Download data from this experiment volume_type : str - Download this type of data (options are GridDataApi.ENERGY, + Download this type of data (options are GridDataApi.ENERGY, GridDataApi.DENSITY, GridDataApi.INTENSITY) path : str Download to this path - ''' + """ - include = '?include={}'.format(volume_type) - url = ''.join([self.grid_data_endpoint, '/download/', str(section_data_set_id), include]) + include = "?include={}".format(volume_type) + url = "".join([self.grid_data_endpoint, "/download/", str(section_data_set_id), include]) self.retrieve_file_over_http(url, path, zipped=True) - - @deprecated(message='Use download_gene_expression_grid_data instead') - def download_expression_grid_data(self, - section_data_set_id, - include=None, - path=None): - '''Download in zipped metaimage format. + @deprecated(message="Use download_gene_expression_grid_data instead") + def download_expression_grid_data(self, section_data_set_id, include=None, path=None): + """Download in zipped metaimage format. Parameters ---------- @@ -104,36 +94,28 @@ def download_expression_grid_data(self, Image volumes. 'energy' (default), 'density', 'intensity'. path : string, optional File name to save as. - + Returns ------- file : 3-D expression grid data packaged into a compressed archive file (.zip). Notes ----- - ''' + """ if include is not None: - include_clause = ''.join(['?include=', - ','.join(include)]) + include_clause = "".join(["?include=", ",".join(include)]) else: - include_clause = '' + include_clause = "" - url = ''.join([self.grid_data_endpoint, - '/download/', - str(section_data_set_id), - include_clause]) + url = "".join([self.grid_data_endpoint, "/download/", str(section_data_set_id), include_clause]) if path is None: - path = str(section_data_set_id) + '.zip' + path = str(section_data_set_id) + ".zip" self.retrieve_file_over_http(url, path) - def download_projection_grid_data(self, - section_data_set_id, - image=None, - resolution=None, - save_file_path=None): - '''Download in NRRD format. + def download_projection_grid_data(self, section_data_set_id, image=None, resolution=None, save_file_path=None): + """Download in NRRD format. Parameters ---------- @@ -150,39 +132,36 @@ def download_projection_grid_data(self, ----- See `Downloading 3-D Projection Grid Data `_ for additional documentation. - ''' + """ params_list = [] if image is not None: - params_list.append('image=' + ','.join(image)) + params_list.append("image=" + ",".join(image)) if resolution is not None: - params_list.append('resolution=%d' % (resolution)) + params_list.append("resolution=%d" % (resolution)) if len(params_list) > 0: - params_clause = '?' + '&'.join(params_list) + params_clause = "?" + "&".join(params_list) else: - params_clause = '' + params_clause = "" - url = ''.join([self.grid_data_endpoint, - '/download_file/', - str(section_data_set_id), - params_clause]) + url = "".join([self.grid_data_endpoint, "/download_file/", str(section_data_set_id), params_clause]) if save_file_path is None: - save_file_path = str(section_data_set_id) + '.nrrd' + save_file_path = str(section_data_set_id) + ".nrrd" self.retrieve_file_over_http(url, save_file_path) - - def download_deformation_field(self, - section_data_set_id, + def download_deformation_field( + self, + section_data_set_id, header_path=None, voxel_path=None, - voxel_type='DeformationFieldVoxels', - header_type='DeformationFieldHeader' + voxel_type="DeformationFieldVoxels", + header_type="DeformationFieldHeader", ): - ''' Download the local alignment parameters for this dataset. This a 3D vector image (3 components) describing + """Download the local alignment parameters for this dataset. This a 3D vector image (3 components) describing a deformable local mapping from CCF voxels to this section data set's affine-aligned image stack. Parameters @@ -197,56 +176,59 @@ def download_deformation_field(self, WellKnownFileType of this dataset's data file header_type : str WellKnownFileType of this dataset's header file - ''' + """ - header_path = '{}_dfmfld.mhd'.format(section_data_set_id) if header_path is None else header_path - voxel_path = '{}_dfmfld.raw'.format(section_data_set_id) if voxel_path is None else voxel_path + header_path = "{}_dfmfld.mhd".format(section_data_set_id) if header_path is None else header_path + voxel_path = "{}_dfmfld.raw".format(section_data_set_id) if voxel_path is None else voxel_path well_known_files = self.model_query( - model='WellKnownFile', - filters={'attachable_id': section_data_set_id}, - criteria='well_known_file_type[name$in\'DeformationFieldHeader\',\'DeformationFieldVoxels\']', - include='well_known_file_type' + model="WellKnownFile", + filters={"attachable_id": section_data_set_id}, + criteria="well_known_file_type[name$in'DeformationFieldHeader','DeformationFieldVoxels']", + include="well_known_file_type", ) well_known_file_urls = { - wkf['well_known_file_type']['name']: - self.construct_well_known_file_download_url(wkf['id']) for wkf in well_known_files + wkf["well_known_file_type"]["name"]: self.construct_well_known_file_download_url(wkf["id"]) + for wkf in well_known_files } - + self.retrieve_file_over_http(well_known_file_urls[header_type], header_path) self.retrieve_file_over_http(well_known_file_urls[voxel_type], voxel_path) - @cacheable() - def download_alignment3d(self, section_data_set_id, num_rows='all', count=False, **kwargs): - ''' Download the parameters of the 3D affine tranformation mapping this section data set's image-space stack to + def download_alignment3d(self, section_data_set_id, num_rows="all", count=False, **kwargs): + """Download the parameters of the 3D affine tranformation mapping this section data set's image-space stack to CCF-space (or vice-versa). Parameters ---------- section_data_set_id : int download the parameters for this data set. - + Returns ------- dict : parameters of this section data set's alignment3d - ''' + """ results = self.model_query( - model='SectionDataSet', - filters={'id': section_data_set_id}, - include='alignment3d', + model="SectionDataSet", + filters={"id": section_data_set_id}, + include="alignment3d", num_rows=num_rows, count=count, - **kwargs + **kwargs, ) - results = [result for result in results if 'alignment3d' in result] + results = [result for result in results if "alignment3d" in result] if len(results) == 0: - raise ValueError('no SectionDataSet with attached alignment3d found for id {}'.format(section_data_set_id)) + raise ValueError("no SectionDataSet with attached alignment3d found for id {}".format(section_data_set_id)) elif len(results) > 1: - raise ValueError('found multiple SectionDataSets with attached alignment3ds for id {}: {}'.format(section_data_set_id, results)) + raise ValueError( + "found multiple SectionDataSets with attached alignment3ds for id {}: {}".format( + section_data_set_id, results + ) + ) - return results[0]['alignment3d'] + return results[0]["alignment3d"] diff --git a/allensdk/api/queries/image_download_api.py b/allensdk/api/queries/image_download_api.py index bf4976652a..96746ca087 100644 --- a/allensdk/api/queries/image_download_api.py +++ b/allensdk/api/queries/image_download_api.py @@ -36,64 +36,71 @@ from .rma_template import RmaTemplate from allensdk.api.warehouse_cache.cache import cacheable + class ImageDownloadApi(RmaTemplate): - '''HTTP Client to download whole or partial two-dimensional images from the Allen Institute + """HTTP Client to download whole or partial two-dimensional images from the Allen Institute with the SectionImage, AtlasImage and ProjectionImage Download Services. See `Downloading an Image `_ for more documentation. - ''' - - _FILTER_TYPES = [ 'range', 'rgb', 'contrast' ] - COLORMAPS = { "gray": 0, - "hotmetal": 1, - "jet": 2, - "redtemp": 3, - "expression": 4, - "red": 5, - "blue": 6, - "green": 7, - "aba": 8, - "aibsmap_alt": 9, - "colormap": 10, - "projection": 11 + """ + + _FILTER_TYPES = ["range", "rgb", "contrast"] + COLORMAPS = { + "gray": 0, + "hotmetal": 1, + "jet": 2, + "redtemp": 3, + "expression": 4, + "red": 5, + "blue": 6, + "green": 7, + "aba": 8, + "aibsmap_alt": 9, + "colormap": 10, + "projection": 11, } - rma_templates = \ - {"image_queries": [ - {'name': 'section_image_ranges', - 'description': 'see name', - 'model': 'Equalization', - 'num_rows': 'all', - 'count': False, - 'only': ['blue_lower', 'blue_upper', 'red_lower', 'red_upper', 'green_lower', 'green_upper'], - 'criteria': 'section_data_set(section_images[id$in{{ section_image_ids }}])', - 'criteria_params': ['section_image_ids'] - }, - {'name': 'section_images_by_data_set_id', - 'description': 'see name', - 'model': 'SectionImage', - 'num_rows': 'all', - 'count': False, - 'criteria': '[data_set_id$eq{{ data_set_id }}]', - 'criteria_params': ['data_set_id'] - }, - {'name': 'section_data_sets_by_product_id', - 'description': 'see name', - 'model': 'SectionDataSet', - 'num_rows': 'all', - 'count': False, - 'criteria': '[failed$in{{failed}}],products[id$in{{ product_ids }}]', - 'criteria_params': ['product_ids', 'failed'] - }]} + rma_templates = { + "image_queries": [ + { + "name": "section_image_ranges", + "description": "see name", + "model": "Equalization", + "num_rows": "all", + "count": False, + "only": ["blue_lower", "blue_upper", "red_lower", "red_upper", "green_lower", "green_upper"], + "criteria": "section_data_set(section_images[id$in{{ section_image_ids }}])", + "criteria_params": ["section_image_ids"], + }, + { + "name": "section_images_by_data_set_id", + "description": "see name", + "model": "SectionImage", + "num_rows": "all", + "count": False, + "criteria": "[data_set_id$eq{{ data_set_id }}]", + "criteria_params": ["data_set_id"], + }, + { + "name": "section_data_sets_by_product_id", + "description": "see name", + "model": "SectionDataSet", + "num_rows": "all", + "count": False, + "criteria": "[failed$in{{failed}}],products[id$in{{ product_ids }}]", + "criteria_params": ["product_ids", "failed"], + }, + ] + } def __init__(self, base_uri=None): super(ImageDownloadApi, self).__init__(base_uri, query_manifest=ImageDownloadApi.rma_templates) @cacheable() - def get_section_image_ranges(self, section_image_ids, num_rows='all', count=False, as_lists=True, **kwargs): - '''Section images from the Mouse Connectivity Atlas are displayed on connectivity.brain-map.org after having been - linearly windowed and leveled. This method obtains parameters defining channelwise upper and lower bounds of the windows used for + def get_section_image_ranges(self, section_image_ids, num_rows="all", count=False, as_lists=True, **kwargs): + """Section images from the Mouse Connectivity Atlas are displayed on connectivity.brain-map.org after having been + linearly windowed and leveled. This method obtains parameters defining channelwise upper and lower bounds of the windows used for one or more images. Parameters @@ -105,34 +112,44 @@ def get_section_image_ranges(self, section_image_ids, num_rows='all', count=Fals count : bool, optional If True, return a count of the lines found by the query. Default is False. as_lists : bool, optional - If True, return the window parameters in a list, rather than a dict - (this is the format of the range parameter on ImageDownloadApi.download_image). + If True, return the window parameters in a list, rather than a dict + (this is the format of the range parameter on ImageDownloadApi.download_image). Default is False. Returns ------- - list of dict or list of list : + list of dict or list of list : For each section image id provided, return the window bounds for each channel. - ''' + """ - dict_ranges = self.template_query('image_queries', 'section_image_ranges', - section_image_ids=section_image_ids, - num_rows=num_rows, count=count) + dict_ranges = self.template_query( + "image_queries", "section_image_ranges", section_image_ids=section_image_ids, num_rows=num_rows, count=count + ) if not as_lists: return dict_ranges list_ranges = [] for rng in dict_ranges: - list_ranges.append([ rng['red_lower'], rng['red_upper'], rng['green_lower'], rng['green_upper'], rng['blue_lower'], rng['blue_upper'] ]) + list_ranges.append( + [ + rng["red_lower"], + rng["red_upper"], + rng["green_lower"], + rng["green_upper"], + rng["blue_lower"], + rng["blue_upper"], + ] + ) return list_ranges - @cacheable() - def get_section_data_sets_by_product(self, product_ids, include_failed=False, num_rows='all', count=False, **kwargs): - '''List all of the section data sets produced as part of one or more products + def get_section_data_sets_by_product( + self, product_ids, include_failed=False, num_rows="all", count=False, **kwargs + ): + """List all of the section data sets produced as part of one or more products Parameters ---------- @@ -147,29 +164,32 @@ def get_section_data_sets_by_product(self, product_ids, include_failed=False, nu Returns ------- - list of dict : + list of dict : Each returned element is a section data set record. Notes ----- See http://api.brain-map.org/api/v2/data/query.json?criteria=model::Product for a list of products. - ''' + """ if include_failed: - failed_crit = "\'false\',\'true\'" + failed_crit = "'false','true'" else: - failed_crit = "\'false\'" - - return self.template_query('image_queries', 'section_data_sets_by_product_id', - product_ids=product_ids, - failed=failed_crit, - num_rows=num_rows, count=count) + failed_crit = "'false'" + return self.template_query( + "image_queries", + "section_data_sets_by_product_id", + product_ids=product_ids, + failed=failed_crit, + num_rows=num_rows, + count=count, + ) @cacheable() - def section_image_query(self, section_data_set_id, num_rows='all', count=False, **kwargs): - '''List section images belonging to a specified section data set + def section_image_query(self, section_data_set_id, num_rows="all", count=False, **kwargs): + """List section images belonging to a specified section data set Parameters ---------- @@ -187,49 +207,31 @@ def section_image_query(self, section_data_set_id, num_rows='all', count=False, Notes ----- - The SectionDataSet model is used to represent single experiments which produce an array of images. + The SectionDataSet model is used to represent single experiments which produce an array of images. This includes Mouse Connectivity and Mouse Brain Atlas experiments, among other projects. - You may see references to the ids of experiments from those projects. + You may see references to the ids of experiments from those projects. These are the same as section data set ids. - ''' - - return self.template_query('image_queries', 'section_images_by_data_set_id', - data_set_id=section_data_set_id, - num_rows=num_rows, count=count) - - def download_section_image(self, - section_image_id, - file_path=None, - **kwargs): - self.download_image(section_image_id, - file_path, - endpoint=self.section_image_download_endpoint, - **kwargs) - - def download_atlas_image(self, - atlas_image_id, - file_path=None, - **kwargs): - self.download_image(atlas_image_id, - file_path, - endpoint=self.atlas_image_download_endpoint, - **kwargs) - - def download_projection_image(self, - projection_image_id, - file_path=None, - **kwargs): - self.download_image(projection_image_id, - file_path, - endpoint=self.projection_image_download_endpoint, - **kwargs) - - def download_image(self, - image_id, - file_path=None, - endpoint=None, - **kwargs): - ''' Download whole or partial two-dimensional images + """ + + return self.template_query( + "image_queries", + "section_images_by_data_set_id", + data_set_id=section_data_set_id, + num_rows=num_rows, + count=count, + ) + + def download_section_image(self, section_image_id, file_path=None, **kwargs): + self.download_image(section_image_id, file_path, endpoint=self.section_image_download_endpoint, **kwargs) + + def download_atlas_image(self, atlas_image_id, file_path=None, **kwargs): + self.download_image(atlas_image_id, file_path, endpoint=self.atlas_image_download_endpoint, **kwargs) + + def download_projection_image(self, projection_image_id, file_path=None, **kwargs): + self.download_image(projection_image_id, file_path, endpoint=self.projection_image_download_endpoint, **kwargs) + + def download_image(self, image_id, file_path=None, endpoint=None, **kwargs): + """Download whole or partial two-dimensional images from the Allen Institute with the SectionImage or AtlasImage service. Parameters @@ -287,7 +289,7 @@ def download_image(self, 'downsample=1' halves the number of pixels of the original image both horizontally and vertically. range_list = kwargs.get('range', None) - + Specifying 'downsample=2' quarters the height and width values. Quality must be an integer from 0, for the lowest quality, @@ -318,138 +320,132 @@ def download_image(self, `Projection Dataset `_ help topic. See: `Image Download Service `_ - ''' + """ params = [] if endpoint is None: endpoint = self.image_download_endpoint - downsample = kwargs.get('downsample', None) + downsample = kwargs.get("downsample", None) if downsample is not None: - params.append('downsample=%d' % (downsample)) + params.append("downsample=%d" % (downsample)) - quality = kwargs.get('quality', None) + quality = kwargs.get("quality", None) if quality is not None: - params.append('quality=%d' % (quality)) + params.append("quality=%d" % (quality)) - tumor_feature_annotation = kwargs.get('tumor_feature_annotation', None) + tumor_feature_annotation = kwargs.get("tumor_feature_annotation", None) if tumor_feature_annotation is not None: if tumor_feature_annotation: - params.append('tumor_feature_annotation=true') + params.append("tumor_feature_annotation=true") else: - params.append('tumor_feature_annotation=false') + params.append("tumor_feature_annotation=false") - tumor_feature_boundary = kwargs.get('tumor_feature_boundary', None) + tumor_feature_boundary = kwargs.get("tumor_feature_boundary", None) if tumor_feature_boundary is not None: if tumor_feature_boundary: - params.append('tumor_feature_boundary=true') + params.append("tumor_feature_boundary=true") else: - params.append('tumor_feature_boundary=false') + params.append("tumor_feature_boundary=false") - annotation = kwargs.get('annotation', None) + annotation = kwargs.get("annotation", None) if annotation is not None: if annotation is True: - params.append('annotation=true') + params.append("annotation=true") else: - params.append('annotation=false') + params.append("annotation=false") - atlas = kwargs.get('atlas', None) + atlas = kwargs.get("atlas", None) if atlas is not None: - params.append('atlas=%d' % (atlas)) + params.append("atlas=%d" % (atlas)) - projection = kwargs.get('projection', None) + projection = kwargs.get("projection", None) if projection is not None: if projection is True: - params.append('projection=true') + params.append("projection=true") else: - params.append('projection=false') + params.append("projection=false") - expression = kwargs.get('expression', None) + expression = kwargs.get("expression", None) if expression is not None: if expression: - params.append('expression=true') + params.append("expression=true") else: - params.append('expression=false') + params.append("expression=false") + + colormap_filter = kwargs.get("colormap", None) - colormap_filter = kwargs.get('colormap', None) - if colormap_filter is not None: if isinstance(colormap_filter, str): - params.append('colormap=%s' % (colormap_filter)) + params.append("colormap=%s" % (colormap_filter)) else: lower_threshold = colormap_filter[0] colormap_id = ImageDownloadApi.COLORMAPS[colormap_filter[1]] - filter_values_list = '0.5,%s,0,256,%d' % (str(lower_threshold), - colormap_id) - params.append('colormap=%s' % (filter_values_list)) + filter_values_list = "0.5,%s,0,256,%d" % (str(lower_threshold), colormap_id) + params.append("colormap=%s" % (filter_values_list)) # see # http://api.brain-map.org/api/v2/data/SectionDataSet/100141599.xml?include=equalization,section_images for filter_type in ImageDownloadApi._FILTER_TYPES: filter_values = kwargs.get(filter_type, None) - + if filter_values is not None: - filter_values_list = ','.join(str(r) for r in filter_values) - params.append('%s=%s' % (filter_type, filter_values_list)) + filter_values_list = ",".join(str(r) for r in filter_values) + params.append("%s=%s" % (filter_type, filter_values_list)) - view = kwargs.get('view', None) + view = kwargs.get("view", None) if view is not None: - if view in ['expression', - 'projection', - 'tumor_feature_annotation', - 'tumor_feature_boundary']: - params.append('view=%s' % (view)) + if view in ["expression", "projection", "tumor_feature_annotation", "tumor_feature_boundary"]: + params.append("view=%s" % (view)) else: - raise ValueError("view argument should be 'expression', 'projection', 'tumor_feature_annotation' or 'tumor_feature_boundary'") + raise ValueError( + "view argument should be 'expression', 'projection', 'tumor_feature_annotation' or 'tumor_feature_boundary'" + ) # region of interest - for roi_key in ['left', 'top', 'width', 'height']: + for roi_key in ["left", "top", "width", "height"]: roi_value = kwargs.get(roi_key, None) if roi_value is not None: - params.append('%s=%d' % (roi_key, roi_value)) + params.append("%s=%d" % (roi_key, roi_value)) - downsample_dimensions = kwargs.get('downsample_dimensions', None) + downsample_dimensions = kwargs.get("downsample_dimensions", None) if downsample_dimensions is not None: if downsample_dimensions: - params.append('downsample_dimensions=true') + params.append("downsample_dimensions=true") else: - params.append('downsample_dimensions=false') + params.append("downsample_dimensions=false") if len(params) > 0: url_params = "?" + "&".join(params) else: - url_params = '' + url_params = "" - image_url = ''.join([endpoint, - '/', - str(image_id), - url_params]) + image_url = "".join([endpoint, "/", str(image_id), url_params]) if file_path is None: - file_path = '%d.jpg' % (image_id) + file_path = "%d.jpg" % (image_id) self.retrieve_file_over_http(image_url, file_path) - def atlas_image_query(self, atlas_id, image_type_name=None): - '''List atlas images belonging to a specified atlas + """List atlas images belonging to a specified atlas Parameters ---------- atlas_id : integer, optional Find images from this atlas. image_type_name : string, optional - Restrict response to images of this type. If not provided, + Restrict response to images of this type. If not provided, the query will get it from the atlas id. Returns @@ -462,37 +458,33 @@ def atlas_image_query(self, atlas_id, image_type_name=None): See `Downloading Atlas Images and Graphics `_ for additional documentation. :py:meth:`allensdk.api.queries.ontologies_api.OntologiesApi.get_atlases` can also be used to list atlases along with their ids. - ''' + """ stages = [] if image_type_name is None: - atlas_stage = self.model_stage('Atlas', - criteria='[id$eq%d]' % (atlas_id), - only=['image_type']) + atlas_stage = self.model_stage("Atlas", criteria="[id$eq%d]" % (atlas_id), only=["image_type"]) stages.append(atlas_stage) - atlas_name_pipe_stage = self.pipe_stage('list', - parameters=[('type_name', - self.IS, - self.quote_string('image_type'))]) + atlas_name_pipe_stage = self.pipe_stage( + "list", parameters=[("type_name", self.IS, self.quote_string("image_type"))] + ) stages.append(atlas_name_pipe_stage) - image_type_name = '$type_name' + image_type_name = "$type_name" else: image_type_name = self.quote_string(image_type_name) - criteria_list = ['[annotated$eqtrue],', - 'atlas_data_set(atlases[id$eq%d]),' % (atlas_id), - "alternate_images[image_type$eq%s]" % (image_type_name)] + criteria_list = [ + "[annotated$eqtrue],", + "atlas_data_set(atlases[id$eq%d])," % (atlas_id), + "alternate_images[image_type$eq%s]" % (image_type_name), + ] - atlas_image_model_stage = self.model_stage('AtlasImage', - criteria=criteria_list, - order=[ - 'sub_images.section_number'], - num_rows='all') + atlas_image_model_stage = self.model_stage( + "AtlasImage", criteria=criteria_list, order=["sub_images.section_number"], num_rows="all" + ) stages.append(atlas_image_model_stage) - return self.json_msg_query( - self.build_query_url(stages)) + return self.json_msg_query(self.build_query_url(stages)) diff --git a/allensdk/api/queries/mouse_atlas_api.py b/allensdk/api/queries/mouse_atlas_api.py index b02dada34f..74dc6c7eeb 100644 --- a/allensdk/api/queries/mouse_atlas_api.py +++ b/allensdk/api/queries/mouse_atlas_api.py @@ -42,10 +42,8 @@ from .rma_pager import pageable - class MouseAtlasApi(ReferenceSpaceApi, GridDataApi): - ''' Downloads Mouse Brain Atlas grid data, reference volumes, and metadata. - ''' + """Downloads Mouse Brain Atlas grid data, reference volumes, and metadata.""" MOUSE_ATLAS_PRODUCTS = (1,) DEVMOUSE_ATLAS_PRODUCTS = (3,) @@ -53,46 +51,42 @@ class MouseAtlasApi(ReferenceSpaceApi, GridDataApi): HUMAN_ORGANISM = (1,) @cacheable() - @pageable(num_rows=2000, total_rows='all') + @pageable(num_rows=2000, total_rows="all") def get_section_data_sets(self, gene_ids=None, product_ids=None, **kwargs): - ''' Download a list of section data sets (experiments) from the Mouse Brain + """Download a list of section data sets (experiments) from the Mouse Brain Atlas project. Parameters ---------- gene_ids : list of int, optional - Filter results based on the genes whose expression was characterized + Filter results based on the genes whose expression was characterized in each experiment. Default is all. product_ids : list of int, optional Filter results to a subset of products. Default is the Mouse Brain Atlas. Returns ------- - list of dict : - Each element is a section data set record, with one or more gene - records nested in a list. + list of dict : + Each element is a section data set record, with one or more gene + records nested in a list. + + """ - ''' - if product_ids is None: product_ids = list(self.MOUSE_ATLAS_PRODUCTS) - criteria = 'products[id$in{}]'.format(','.join(map(str, product_ids))) + criteria = "products[id$in{}]".format(",".join(map(str, product_ids))) if gene_ids is not None: - criteria += ',genes[id$in{}]'.format(','.join(map(str, gene_ids))) + criteria += ",genes[id$in{}]".format(",".join(map(str, gene_ids))) - order = kwargs.pop('order', ['\'id\'']) + order = kwargs.pop("order", ["'id'"]) - return self.model_query(model='SectionDataSet', - criteria=criteria, - include='genes', - order=order, - **kwargs) + return self.model_query(model="SectionDataSet", criteria=criteria, include="genes", order=order, **kwargs) @cacheable() - @pageable(num_rows=2000, total_rows='all') + @pageable(num_rows=2000, total_rows="all") def get_genes(self, organism_ids=None, chromosome_ids=None, **kwargs): - ''' Download a list of genes + """Download a list of genes Parameters ---------- @@ -106,45 +100,39 @@ def get_genes(self, organism_ids=None, chromosome_ids=None, **kwargs): list of dict: Each element is a gene record, with a nested chromosome record (also a dict). - ''' + """ if organism_ids is None: organism_ids = list(self.MOUSE_ORGANISM) - criteria = '[organism_id$in{}]'.format(','.join(map(str, organism_ids))) + criteria = "[organism_id$in{}]".format(",".join(map(str, organism_ids))) if chromosome_ids is not None: - criteria += ',[chromosome_id$in{}]'.format(','.join(map(str, chromosome_ids))) - - order = kwargs.pop('order', ['\'id\'']) - - return self.model_query(model='Gene', - criteria=criteria, - include='chromosome', - order=order, - **kwargs) - - @cacheable(strategy='create', - reader = sitk_utilities.read_ndarray_with_sitk, - pathfinder=Cache.pathfinder(file_name_position=1, - path_keyword='path')) - def download_expression_density(self, path, experiment_id): - self.download_gene_expression_grid_data( - experiment_id, GridDataApi.DENSITY, path) + criteria += ",[chromosome_id$in{}]".format(",".join(map(str, chromosome_ids))) + order = kwargs.pop("order", ["'id'"]) - @cacheable(strategy='create', - reader = sitk_utilities.read_ndarray_with_sitk, - pathfinder=Cache.pathfinder(file_name_position=1, - path_keyword='path')) - def download_expression_energy(self, path, experiment_id): - self.download_gene_expression_grid_data( - experiment_id, GridDataApi.ENERGY, path) + return self.model_query(model="Gene", criteria=criteria, include="chromosome", order=order, **kwargs) + @cacheable( + strategy="create", + reader=sitk_utilities.read_ndarray_with_sitk, + pathfinder=Cache.pathfinder(file_name_position=1, path_keyword="path"), + ) + def download_expression_density(self, path, experiment_id): + self.download_gene_expression_grid_data(experiment_id, GridDataApi.DENSITY, path) + + @cacheable( + strategy="create", + reader=sitk_utilities.read_ndarray_with_sitk, + pathfinder=Cache.pathfinder(file_name_position=1, path_keyword="path"), + ) + def download_expression_energy(self, path, experiment_id): + self.download_gene_expression_grid_data(experiment_id, GridDataApi.ENERGY, path) - @cacheable(strategy='create', - reader = sitk_utilities.read_ndarray_with_sitk, - pathfinder=Cache.pathfinder(file_name_position=1, - path_keyword='path')) + @cacheable( + strategy="create", + reader=sitk_utilities.read_ndarray_with_sitk, + pathfinder=Cache.pathfinder(file_name_position=1, path_keyword="path"), + ) def download_expression_intensity(self, path, experiment_id): - self.download_gene_expression_grid_data( - experiment_id, GridDataApi.INTENSITY, path) + self.download_gene_expression_grid_data(experiment_id, GridDataApi.INTENSITY, path) diff --git a/allensdk/api/queries/mouse_connectivity_api.py b/allensdk/api/queries/mouse_connectivity_api.py index 93ac52025a..0b5ce96d4a 100644 --- a/allensdk/api/queries/mouse_connectivity_api.py +++ b/allensdk/api/queries/mouse_connectivity_api.py @@ -38,23 +38,22 @@ from allensdk.api.warehouse_cache.cache import cacheable, Cache import numpy as np + class MouseConnectivityApi(ReferenceSpaceApi, GridDataApi): - ''' + """ HTTP Client for the Allen Mouse Brain Connectivity Atlas. See: `Mouse Connectivity API `_ - ''' + """ + PRODUCT_IDS = [5, 31] def __init__(self, base_uri=None): super(MouseConnectivityApi, self).__init__(base_uri=base_uri) - @cacheable() - def get_experiments(self, - structure_ids, - **kwargs): - ''' + def get_experiments(self, structure_ids, **kwargs): + """ Fetch experiment metadata from the Mouse Brain Connectivity Atlas. Parameters @@ -66,100 +65,97 @@ def get_experiments(self, ------- url : string The constructed URL - ''' - criteria_list = ['[failed$eqfalse]', - 'products[id$in%s]' % (','.join(str(i) for i in MouseConnectivityApi.PRODUCT_IDS))] + """ + criteria_list = [ + "[failed$eqfalse]", + "products[id$in%s]" % (",".join(str(i) for i in MouseConnectivityApi.PRODUCT_IDS)), + ] if structure_ids is not None: if type(structure_ids) is not list: structure_ids = [structure_ids] - criteria_list.append('[id$in%s]' % ','.join(str(i) - for i in structure_ids)) + criteria_list.append("[id$in%s]" % ",".join(str(i) for i in structure_ids)) - criteria_string = ','.join(criteria_list) + criteria_string = ",".join(criteria_list) - return self.model_query('SectionDataSet', - criteria=criteria_string, - **kwargs) + return self.model_query("SectionDataSet", criteria=criteria_string, **kwargs) @cacheable() def get_experiments_api(self): - ''' + """ Fetch experiment metadata from the Mouse Brain Connectivity Atlas via the ApiConnectivity table. Returns ------- url : string The constructed URL - ''' - return self.model_query('ApiConnectivity', num_rows='all') + """ + return self.model_query("ApiConnectivity", num_rows="all") @cacheable() def get_manual_injection_summary(self, experiment_id): - ''' Retrieve manual injection summary. ''' - - criteria = '[id$in%d]' % (experiment_id) - - include = ['specimen(donor(transgenic_mouse(transgenic_lines)),', - 'injections(structure,age)),', - 'equalization,products'] - - only = ['id', - 'failed', - 'storage_directory', - 'red_lower', - 'red_upper', - 'green_lower', - 'green_upper', - 'blue_lower', - 'blue_upper', - 'products.id', - 'specimen_id', - 'structure_id', - 'reference_space_id', - 'primary_injection_structure_id', - 'registration_point', - 'coordinates_ap', - 'coordinates_dv', - 'coordinates_ml', - 'angle', - 'sex', - 'strain', - 'injection_materials', - 'acronym', - 'structures.name', - 'days', - 'transgenic_mice.name', - 'transgenic_lines.name', - 'transgenic_lines.description', - 'transgenic_lines.id', - 'donors.id'] - - return self.model_query('SectionDataSet', - criteria=criteria, - include=include, - only=only) + """Retrieve manual injection summary.""" + + criteria = "[id$in%d]" % (experiment_id) + + include = [ + "specimen(donor(transgenic_mouse(transgenic_lines)),", + "injections(structure,age)),", + "equalization,products", + ] + + only = [ + "id", + "failed", + "storage_directory", + "red_lower", + "red_upper", + "green_lower", + "green_upper", + "blue_lower", + "blue_upper", + "products.id", + "specimen_id", + "structure_id", + "reference_space_id", + "primary_injection_structure_id", + "registration_point", + "coordinates_ap", + "coordinates_dv", + "coordinates_ml", + "angle", + "sex", + "strain", + "injection_materials", + "acronym", + "structures.name", + "days", + "transgenic_mice.name", + "transgenic_lines.name", + "transgenic_lines.description", + "transgenic_lines.id", + "donors.id", + ] + + return self.model_query("SectionDataSet", criteria=criteria, include=include, only=only) @cacheable() def get_experiment_detail(self, experiment_id): - '''Retrieve the experiments data.''' - - criteria = '[id$eq%d]' % (experiment_id) - include = ['specimen(stereotaxic_injections(primary_injection_structure,structures,stereotaxic_injection_coordinates)),', - 'equalization,', - 'sub_images'] + """Retrieve the experiments data.""" + + criteria = "[id$eq%d]" % (experiment_id) + include = [ + "specimen(stereotaxic_injections(primary_injection_structure,structures,stereotaxic_injection_coordinates)),", + "equalization,", + "sub_images", + ] order = ["'sub_images.section_number$asc'"] - return self.model_query('SectionDataSet', - criteria=criteria, - include=include, - order=order) + return self.model_query("SectionDataSet", criteria=criteria, include=include, order=order) @cacheable() - def get_projection_image_info(self, - experiment_id, - section_number): - '''Fetch meta-information of one projection image. + def get_projection_image_info(self, experiment_id, section_number): + """Fetch meta-information of one projection image. Parameters ---------- @@ -173,36 +169,28 @@ def get_projection_image_info(self, `Experimental Overview and Metadata `_ for additional documentation. Download the image using :py:meth:`allensdk.api.queries.image_download_api.ImageDownloadApi.download_section_image` - ''' + """ - criteria = '[id$eq%d]' % (experiment_id) - include = ['equalization,sub_images[section_number$eq%d]' % - (section_number)] + criteria = "[id$eq%d]" % (experiment_id) + include = ["equalization,sub_images[section_number$eq%d]" % (section_number)] - return self.model_query('SectionDataSet', - criteria=criteria, - include=include) - + return self.model_query("SectionDataSet", criteria=criteria, include=include) - def download_reference_aligned_image_channel_volumes(self, - data_set_id, - save_file_path=None): - ''' + def download_reference_aligned_image_channel_volumes(self, data_set_id, save_file_path=None): + """ Returns ------- The well known file is downloaded - ''' - well_known_file_url = self.get_reference_aligned_image_channel_volumes_url( - data_set_id) + """ + well_known_file_url = self.get_reference_aligned_image_channel_volumes_url(data_set_id) if save_file_path is None: - save_file_path = str(data_set_id) + '.zip' + save_file_path = str(data_set_id) + ".zip" self.retrieve_file_over_http(well_known_file_url, save_file_path) - def build_reference_aligned_image_channel_volumes_url(self, - data_set_id): - '''Construct url to download the red, green, and blue channels + def build_reference_aligned_image_channel_volumes_url(self, data_set_id): + """Construct url to download the red, green, and blue channels aligned to the 25um adult mouse brain reference space volume. Parameters @@ -214,39 +202,40 @@ def build_reference_aligned_image_channel_volumes_url(self, ----- See: `Reference-aligned Image Channel Volumes `_ for additional documentation. - ''' + """ - criteria = ['well_known_file_type', - "[name$eq'ImagesResampledTo25MicronARA']", - "[attachable_id$eq%d]" % (data_set_id)] + criteria = [ + "well_known_file_type", + "[name$eq'ImagesResampledTo25MicronARA']", + "[attachable_id$eq%d]" % (data_set_id), + ] - model_stage = self.model_stage('WellKnownFile', - criteria=criteria) + model_stage = self.model_stage("WellKnownFile", criteria=criteria) url = self.build_query_url([model_stage]) return url - def get_reference_aligned_image_channel_volumes_url(self, - data_set_id): - '''Retrieve the download link for a specific data set.\ + def get_reference_aligned_image_channel_volumes_url(self, data_set_id): + """Retrieve the download link for a specific data set.\ Notes ----- See `Reference-aligned Image Channel Volumes `_ for additional documentation. - ''' - download_link = self.do_query(self.build_reference_aligned_image_channel_volumes_url, - lambda parsed_json: str( - parsed_json['msg'][0]['download_link']), - data_set_id) + """ + download_link = self.do_query( + self.build_reference_aligned_image_channel_volumes_url, + lambda parsed_json: str(parsed_json["msg"][0]["download_link"]), + data_set_id, + ) url = self.api_url + download_link return url def experiment_source_search(self, **kwargs): - '''Search over the whole projection signal statistics dataset + """Search over the whole projection signal statistics dataset to find experiments with specific projection profiles. Parameters @@ -278,12 +267,12 @@ def experiment_source_search(self, **kwargs): and `service::mouse_connectivity_injection_structure `_. - ''' + """ tuples = [(k, v) for k, v in kwargs.items()] - return self.service_query('mouse_connectivity_injection_structure', parameters=tuples) + return self.service_query("mouse_connectivity_injection_structure", parameters=tuples) def experiment_spatial_search(self, **kwargs): - '''Displays all SectionDataSets + """Displays all SectionDataSets with projection signal density >= 0.1 at the seed point. This service also returns the path along the most dense pixels from the seed point @@ -313,13 +302,13 @@ def experiment_spatial_search(self, **kwargs): and `service::mouse_connectivity_target_spatial `_. - ''' + """ tuples = [(k, v) for k, v in kwargs.items()] - return self.service_query('mouse_connectivity_target_spatial', parameters=tuples) + return self.service_query("mouse_connectivity_target_spatial", parameters=tuples) def experiment_injection_coordinate_search(self, **kwargs): - '''User specifies a seed location within the 3D reference space. + """User specifies a seed location within the 3D reference space. The service returns a rank list of experiments by distance of its injection site to the specified seed location. @@ -345,12 +334,12 @@ def experiment_injection_coordinate_search(self, **kwargs): and `service::mouse_connectivity_injection_coordinate `_. - ''' + """ tuples = [(k, v) for k, v in kwargs.items()] - return self.service_query('mouse_connectivity_injection_coordinate', parameters=tuples) + return self.service_query("mouse_connectivity_injection_coordinate", parameters=tuples) def experiment_correlation_search(self, **kwargs): - '''Select a seed experiment and a domain over + """Select a seed experiment and a domain over which the similarity comparison is to be made. @@ -380,99 +369,79 @@ def experiment_correlation_search(self, **kwargs): and `service::mouse_connectivity_correlation `_. - ''' + """ tuples = sorted(kwargs.items()) - return self.service_query('mouse_connectivity_correlation', - parameters=tuples) + return self.service_query("mouse_connectivity_correlation", parameters=tuples) @cacheable() - def get_structure_unionizes(self, - experiment_ids, - is_injection=None, - structure_name=None, - structure_ids=None, - hemisphere_ids=None, - normalized_projection_volume_limit=None, - include=None, - debug=None, - order=None): - - experiment_filter = '[section_data_set_id$in%s]' %\ - ','.join(str(i) for i in experiment_ids) + def get_structure_unionizes( + self, + experiment_ids, + is_injection=None, + structure_name=None, + structure_ids=None, + hemisphere_ids=None, + normalized_projection_volume_limit=None, + include=None, + debug=None, + order=None, + ): + experiment_filter = "[section_data_set_id$in%s]" % ",".join(str(i) for i in experiment_ids) if is_injection is True: - is_injection_filter = '[is_injection$eqtrue]' + is_injection_filter = "[is_injection$eqtrue]" elif is_injection is False: - is_injection_filter = '[is_injection$eqfalse]' + is_injection_filter = "[is_injection$eqfalse]" else: - is_injection_filter = '' + is_injection_filter = "" if normalized_projection_volume_limit is not None: - volume_filter = '[normalized_projection_volume$gt%f]' %\ - (normalized_projection_volume_limit) + volume_filter = "[normalized_projection_volume$gt%f]" % (normalized_projection_volume_limit) else: - volume_filter = '' + volume_filter = "" if hemisphere_ids is not None: - hemisphere_filter = '[hemisphere_id$in%s]' %\ - ','.join(str(h) for h in hemisphere_ids) + hemisphere_filter = "[hemisphere_id$in%s]" % ",".join(str(h) for h in hemisphere_ids) else: - hemisphere_filter = '' + hemisphere_filter = "" if structure_name is not None: structure_filter = ",structure[name$eq'%s']" % (structure_name) elif structure_ids is not None: - structure_filter = '[structure_id$in%s]' %\ - ','.join(str(i) for i in structure_ids) + structure_filter = "[structure_id$in%s]" % ",".join(str(i) for i in structure_ids) else: - structure_filter = '' + structure_filter = "" return self.model_query( - 'ProjectionStructureUnionize', - criteria=''.join([experiment_filter, - is_injection_filter, - volume_filter, - hemisphere_filter, - structure_filter]), + "ProjectionStructureUnionize", + criteria="".join( + [experiment_filter, is_injection_filter, volume_filter, hemisphere_filter, structure_filter] + ), include=include, order=order, - num_rows='all', + num_rows="all", debug=debug, - count=False) + count=False, + ) - @cacheable(strategy='create', - pathfinder=Cache.pathfinder(file_name_position=1, - path_keyword='path')) + @cacheable(strategy="create", pathfinder=Cache.pathfinder(file_name_position=1, path_keyword="path")) def download_injection_density(self, path, experiment_id, resolution): - self.download_projection_grid_data( - experiment_id, [GridDataApi.INJECTION_DENSITY], resolution, path) + self.download_projection_grid_data(experiment_id, [GridDataApi.INJECTION_DENSITY], resolution, path) - @cacheable(strategy='create', - pathfinder=Cache.pathfinder(file_name_position=1, - path_keyword='path')) + @cacheable(strategy="create", pathfinder=Cache.pathfinder(file_name_position=1, path_keyword="path")) def download_projection_density(self, path, experiment_id, resolution): - self.download_projection_grid_data( - experiment_id, [GridDataApi.PROJECTION_DENSITY], resolution, path) + self.download_projection_grid_data(experiment_id, [GridDataApi.PROJECTION_DENSITY], resolution, path) - @cacheable(strategy='create', - pathfinder=Cache.pathfinder(file_name_position=1, - path_keyword='path')) + @cacheable(strategy="create", pathfinder=Cache.pathfinder(file_name_position=1, path_keyword="path")) def download_injection_fraction(self, path, experiment_id, resolution): - self.download_projection_grid_data( - experiment_id, [GridDataApi.INJECTION_FRACTION], resolution, path) + self.download_projection_grid_data(experiment_id, [GridDataApi.INJECTION_FRACTION], resolution, path) - @cacheable(strategy='create', - pathfinder=Cache.pathfinder(file_name_position=1, - path_keyword='path')) + @cacheable(strategy="create", pathfinder=Cache.pathfinder(file_name_position=1, path_keyword="path")) def download_data_mask(self, path, experiment_id, resolution): - self.download_projection_grid_data( - experiment_id, [GridDataApi.DATA_MASK], resolution, path) - - def calculate_injection_centroid(self, - injection_density, - injection_fraction, - resolution=25): - ''' + self.download_projection_grid_data(experiment_id, [GridDataApi.DATA_MASK], resolution, path) + + def calculate_injection_centroid(self, injection_density, injection_fraction, resolution=25): + """ Compute the centroid of an injection site. Parameters @@ -484,18 +453,18 @@ def calculate_injection_centroid(self, injection_fraction: np.ndarray The injection fraction volume of an experiment - ''' + """ # find all voxels with injection_fraction > 0 injection_voxels = np.nonzero(injection_fraction) - injection_density_computed = np.multiply(injection_density[injection_voxels], - injection_fraction[injection_voxels]) + injection_density_computed = np.multiply( + injection_density[injection_voxels], injection_fraction[injection_voxels] + ) sum_density = np.sum(injection_density_computed) # compute centroid in CCF coordinates if sum_density > 0: - centroid = np.dot(injection_density_computed, - list(zip(*injection_voxels))) / sum_density * resolution + centroid = np.dot(injection_density_computed, list(zip(*injection_voxels))) / sum_density * resolution else: centroid = None diff --git a/allensdk/api/queries/ontologies_api.py b/allensdk/api/queries/ontologies_api.py index b7bbaa0ec8..d9c44820b5 100644 --- a/allensdk/api/queries/ontologies_api.py +++ b/allensdk/api/queries/ontologies_api.py @@ -37,122 +37,129 @@ from allensdk.api.warehouse_cache.cache import cacheable - class OntologiesApi(RmaTemplate): - ''' + """ See: `Atlas Drawings and Ontologies `_ - ''' + """ - rma_templates = \ - {"ontology_queries": [ - {'name': 'structures_by_graph_ids', - 'description': 'see name', - 'model': 'Structure', - 'criteria': '[graph_id$in{{ graph_ids }}]', - 'order': ['structures.graph_order'], - 'num_rows': 'all', - 'count': False, - 'criteria_params': ['graph_ids'] - }, - {'name': 'structures_by_graph_names', - 'description': 'see name', - 'model': 'Structure', - 'criteria': 'graph[structure_graphs.name$in{{ graph_names }}]', - 'order': ['structures.graph_order'], - 'num_rows': 'all', - 'count': False, - 'criteria_params': ['graph_names'] - }, - {'name': 'structures_by_set_ids', - 'description': 'see name', - 'model': 'Structure', - 'criteria': '[structure_set_id$in{{ set_ids }}]', - 'order': ['structures.graph_order'], - 'num_rows': 'all', - 'count': False, - 'criteria_params': ['set_ids'] - }, - {'name': 'structures_by_set_names', - 'description': 'see name', - 'model': 'Structure', - 'criteria': 'structure_sets[name$in{{ set_names }}]', - 'order': ['structures.graph_order'], - 'num_rows': 'all', - 'count': False, - 'criteria_params': ['set_names'] - }, - {'name': 'structure_graphs_list', - 'description': 'see name', - 'model': 'StructureGraph', - 'num_rows': 'all', - 'count': False - }, - {'name': 'structure_sets_list', - 'description': 'see name', - 'model': 'StructureSet', - 'num_rows': 'all', - 'count': False - }, - {'name': 'atlases_list', - 'description': 'see name', - 'model': 'Atlas', - 'num_rows': 'all', - 'count': False - }, - {'name': 'atlases_table', - 'description': 'see name', - 'model': 'Atlas', - 'criteria': '{% if atlas_ids is defined %}[id$in{{ atlas_ids }}],{%endif%}structure_graph(ontology),graphic_group_labels', - 'include': 'structure_graph(ontology),graphic_group_labels', - 'only': ['atlases.id', - 'atlases.name', - 'atlases.image_type', - 'ontologies.id', - 'ontologies.name', - 'structure_graphs.id', - 'structure_graphs.name', - 'graphic_group_labels.id', - 'graphic_group_labels.name'], - 'num_rows': 'all', - 'count': False, - 'criteria_params': ['atlas_ids'] - }, - {'name': 'structures_with_sets', - 'description': 'see name', - 'model': 'Structure', - 'include': 'structure_sets', - 'criteria': '[graph_id$in{{ graph_ids }}]', - 'order': ['structures.graph_order'], - 'num_rows': 'all', - 'count': False, - 'criteria_params': ['graph_ids'] - }, - {'name': 'structure_sets_by_id', - 'description': 'see name', - 'model': 'StructureSet', - 'criteria': '[id$in{{ set_ids }}]', - 'num_rows': 'all', - 'count': False, - 'criteria_params': ['set_ids'] - } - ]} + rma_templates = { + "ontology_queries": [ + { + "name": "structures_by_graph_ids", + "description": "see name", + "model": "Structure", + "criteria": "[graph_id$in{{ graph_ids }}]", + "order": ["structures.graph_order"], + "num_rows": "all", + "count": False, + "criteria_params": ["graph_ids"], + }, + { + "name": "structures_by_graph_names", + "description": "see name", + "model": "Structure", + "criteria": "graph[structure_graphs.name$in{{ graph_names }}]", + "order": ["structures.graph_order"], + "num_rows": "all", + "count": False, + "criteria_params": ["graph_names"], + }, + { + "name": "structures_by_set_ids", + "description": "see name", + "model": "Structure", + "criteria": "[structure_set_id$in{{ set_ids }}]", + "order": ["structures.graph_order"], + "num_rows": "all", + "count": False, + "criteria_params": ["set_ids"], + }, + { + "name": "structures_by_set_names", + "description": "see name", + "model": "Structure", + "criteria": "structure_sets[name$in{{ set_names }}]", + "order": ["structures.graph_order"], + "num_rows": "all", + "count": False, + "criteria_params": ["set_names"], + }, + { + "name": "structure_graphs_list", + "description": "see name", + "model": "StructureGraph", + "num_rows": "all", + "count": False, + }, + { + "name": "structure_sets_list", + "description": "see name", + "model": "StructureSet", + "num_rows": "all", + "count": False, + }, + {"name": "atlases_list", "description": "see name", "model": "Atlas", "num_rows": "all", "count": False}, + { + "name": "atlases_table", + "description": "see name", + "model": "Atlas", + "criteria": "{% if atlas_ids is defined %}[id$in{{ atlas_ids }}],{%endif%}structure_graph(ontology),graphic_group_labels", + "include": "structure_graph(ontology),graphic_group_labels", + "only": [ + "atlases.id", + "atlases.name", + "atlases.image_type", + "ontologies.id", + "ontologies.name", + "structure_graphs.id", + "structure_graphs.name", + "graphic_group_labels.id", + "graphic_group_labels.name", + ], + "num_rows": "all", + "count": False, + "criteria_params": ["atlas_ids"], + }, + { + "name": "structures_with_sets", + "description": "see name", + "model": "Structure", + "include": "structure_sets", + "criteria": "[graph_id$in{{ graph_ids }}]", + "order": ["structures.graph_order"], + "num_rows": "all", + "count": False, + "criteria_params": ["graph_ids"], + }, + { + "name": "structure_sets_by_id", + "description": "see name", + "model": "StructureSet", + "criteria": "[id$in{{ set_ids }}]", + "num_rows": "all", + "count": False, + "criteria_params": ["set_ids"], + }, + ] + } def __init__(self, base_uri=None): - super(OntologiesApi, self).__init__(base_uri, - query_manifest=OntologiesApi.rma_templates) + super(OntologiesApi, self).__init__(base_uri, query_manifest=OntologiesApi.rma_templates) @cacheable() - def get_structures(self, - structure_graph_ids=None, - structure_graph_names=None, - structure_set_ids=None, - structure_set_names=None, - order=['structures.graph_order'], - num_rows='all', - count=False, - **kwargs): - '''Retrieve data about anatomical structures. + def get_structures( + self, + structure_graph_ids=None, + structure_graph_names=None, + structure_set_ids=None, + structure_set_names=None, + order=["structures.graph_order"], + num_rows="all", + count=False, + **kwargs, + ): + """Retrieve data about anatomical structures. Parameters ---------- @@ -177,43 +184,51 @@ def get_structures(self, Notes ----- Only one of the methods of limiting the query should be used at a time. - ''' + """ if structure_graph_ids is not None: - data = self.template_query('ontology_queries', - 'structures_by_graph_ids', - graph_ids=structure_graph_ids, - order=order, - num_rows=num_rows, - count=count) + data = self.template_query( + "ontology_queries", + "structures_by_graph_ids", + graph_ids=structure_graph_ids, + order=order, + num_rows=num_rows, + count=count, + ) elif structure_graph_names is not None: - data = self.template_query('ontology_queries', - 'structures_by_graph_names', - graph_names=structure_graph_names, - order=order, - num_rows=num_rows, - count=count) + data = self.template_query( + "ontology_queries", + "structures_by_graph_names", + graph_names=structure_graph_names, + order=order, + num_rows=num_rows, + count=count, + ) elif structure_set_ids is not None: - data = self.template_query('ontology_queries', - 'structures_by_set_ids', - set_ids=structure_set_ids, - order=order, - num_rows=num_rows, - count=count) + data = self.template_query( + "ontology_queries", + "structures_by_set_ids", + set_ids=structure_set_ids, + order=order, + num_rows=num_rows, + count=count, + ) elif structure_set_names is not None: - data = self.template_query('ontology_queries', - 'structures_by_set_names', - set_names=structure_set_names, - order=order, - num_rows=num_rows, - count=count) + data = self.template_query( + "ontology_queries", + "structures_by_set_names", + set_names=structure_set_names, + order=order, + num_rows=num_rows, + count=count, + ) + + return data - return data - - @cacheable() - def get_structures_with_sets(self, structure_graph_ids, order=['structures.graph_order'], - num_rows='all', count=False, **kwargs): - '''Download structures along with the sets to which they belong. + def get_structures_with_sets( + self, structure_graph_ids, order=["structures.graph_order"], num_rows="all", count=False, **kwargs + ): + """Download structures along with the sets to which they belong. Parameters ---------- @@ -223,22 +238,25 @@ def get_structures_with_sets(self, structure_graph_ids, order=['structures.graph list of RMA order clauses for sorting num_rows : int how many records to retrieve - + Returns ------- dict the parsed json response containing data from the API - ''' - - return self.template_query('ontology_queries', 'structures_with_sets', - graph_ids=structure_graph_ids, - order=order, num_rows=num_rows, - count=count) - + """ + + return self.template_query( + "ontology_queries", + "structures_with_sets", + graph_ids=structure_graph_ids, + order=order, + num_rows=num_rows, + count=count, + ) def unpack_structure_set_ancestors(self, structure_dataframe): - '''Convert a slash-separated structure_id_path field to a list. + """Convert a slash-separated structure_id_path field to a list. Parameters ---------- @@ -249,17 +267,14 @@ def unpack_structure_set_ancestors(self, structure_dataframe): ------- None A new column is added to the dataframe containing the ancestor list. - ''' - ancestors = structure_dataframe['structure_id_path'].apply( - lambda e: [int(a) for a in e.split('/')[1:-1]]) - structure_ancestors = [ - [n for n in ancestors_n] for ancestors_n in ancestors - ] - structure_dataframe['structure_set_ancestor'] = structure_ancestors + """ + ancestors = structure_dataframe["structure_id_path"].apply(lambda e: [int(a) for a in e.split("/")[1:-1]]) + structure_ancestors = [[n for n in ancestors_n] for ancestors_n in ancestors] + structure_dataframe["structure_set_ancestor"] = structure_ancestors @cacheable() def get_atlases_table(self, atlas_ids=None, brief=True): - '''List Atlases available through the API + """List Atlases available through the API with associated ontologies and structure graphs. Parameters @@ -278,36 +293,25 @@ def get_atlases_table(self, atlas_ids=None, brief=True): This query is based on the `table of available Atlases `_. See also: `Class: Atlas `_ - ''' + """ if brief is True: - data = self.template_query('ontology_queries', - 'atlases_table', - atlas_ids=atlas_ids) + data = self.template_query("ontology_queries", "atlases_table", atlas_ids=atlas_ids) else: - data = self.template_query('ontology_queries', - 'atlases_table', - atlas_ids=atlas_ids, - only=None) + data = self.template_query("ontology_queries", "atlases_table", atlas_ids=atlas_ids, only=None) return data @cacheable() def get_atlases(self): - return self.template_query('ontology_queries', - 'atlases_list') + return self.template_query("ontology_queries", "atlases_list") @cacheable() def get_structure_graphs(self): - return self.template_query('ontology_queries', - 'structure_graphs_list') + return self.template_query("ontology_queries", "structure_graphs_list") @cacheable() def get_structure_sets(self, structure_set_ids=None): - if structure_set_ids is None: - return self.template_query('ontology_queries', - 'structure_sets_list') + return self.template_query("ontology_queries", "structure_sets_list") else: - return self.template_query('ontology_queries', - 'structure_sets_by_id', - set_ids=list(structure_set_ids)) + return self.template_query("ontology_queries", "structure_sets_by_id", set_ids=list(structure_set_ids)) diff --git a/allensdk/api/queries/reference_space_api.py b/allensdk/api/queries/reference_space_api.py index 22f69c8b8f..cf2685a93d 100644 --- a/allensdk/api/queries/reference_space_api.py +++ b/allensdk/api/queries/reference_space_api.py @@ -39,15 +39,15 @@ import allensdk.core.sitk_utilities as sitk_utilities import nrrd -class ReferenceSpaceApi(RmaApi): - AVERAGE_TEMPLATE = 'average_template' - ARA_NISSL = 'ara_nissl' - MOUSE_2011 = 'annotation/mouse_2011' - DEVMOUSE_2012 = 'annotation/devmouse_2012' - CCF_2015 = 'annotation/ccf_2015' - CCF_2016 = 'annotation/ccf_2016' - CCF_2017 = 'annotation/ccf_2017' +class ReferenceSpaceApi(RmaApi): + AVERAGE_TEMPLATE = "average_template" + ARA_NISSL = "ara_nissl" + MOUSE_2011 = "annotation/mouse_2011" + DEVMOUSE_2012 = "annotation/devmouse_2012" + CCF_2015 = "annotation/ccf_2015" + CCF_2016 = "annotation/ccf_2016" + CCF_2017 = "annotation/ccf_2017" CCF_VERSION_DEFAULT = CCF_2017 VOXEL_RESOLUTION_10_MICRONS = 10 @@ -55,20 +55,14 @@ class ReferenceSpaceApi(RmaApi): VOXEL_RESOLUTION_50_MICRONS = 50 VOXEL_RESOLUTION_100_MICRONS = 100 - def __init__(self, base_uri=None): super(ReferenceSpaceApi, self).__init__(base_uri=base_uri) - - @cacheable(strategy='create', - reader=nrrd.read, - pathfinder=Cache.pathfinder(file_name_position=3, - path_keyword='file_name')) - def download_annotation_volume(self, - ccf_version, - resolution, - file_name): - ''' + @cacheable( + strategy="create", reader=nrrd.read, pathfinder=Cache.pathfinder(file_name_position=3, path_keyword="file_name") + ) + def download_annotation_volume(self, ccf_version, resolution, file_name): + """ Download the annotation volume at a particular resolution. Parameters @@ -80,23 +74,22 @@ def download_annotation_volume(self, Must be 10, 25, 50, or 100. file_name: string Where to save the annotation volume. - + Note: the parameters must be used as positional parameters, not keywords - ''' + """ if ccf_version is None: ccf_version = ReferenceSpaceApi.CCF_VERSION_DEFAULT - self.download_volumetric_data(ccf_version, - 'annotation_%d.nrrd' % resolution, - save_file_path=file_name) - + self.download_volumetric_data(ccf_version, "annotation_%d.nrrd" % resolution, save_file_path=file_name) - @cacheable(strategy='create', reader=sitk_utilities.read_ndarray_with_sitk, - pathfinder=Cache.pathfinder(file_name_position=3, - path_keyword='file_name')) + @cacheable( + strategy="create", + reader=sitk_utilities.read_ndarray_with_sitk, + pathfinder=Cache.pathfinder(file_name_position=3, path_keyword="file_name"), + ) def download_mouse_atlas_volume(self, age, volume_type, file_name): - '''Download a reference volume (annotation, grid annotation, atlas volume) + """Download a reference volume (annotation, grid annotation, atlas volume) from the mouse brain atlas project Parameters @@ -107,22 +100,18 @@ def download_mouse_atlas_volume(self, age, volume_type, file_name): Specify the type of volume to download file_name : str Specify the path to the downloaded volume - ''' + """ - remote_file_name = '{}_{}.zip'.format(age, volume_type) - url = '/'.join([ self.informatics_archive_endpoint, - 'current-release', 'mouse_annotation', - remote_file_name ]) + remote_file_name = "{}_{}.zip".format(age, volume_type) + url = "/".join([self.informatics_archive_endpoint, "current-release", "mouse_annotation", remote_file_name]) self.retrieve_file_over_http(url, file_name, zipped=True) - - @cacheable(strategy='create', - reader=nrrd.read, - pathfinder=Cache.pathfinder(file_name_position=2, - path_keyword='file_name')) + @cacheable( + strategy="create", reader=nrrd.read, pathfinder=Cache.pathfinder(file_name_position=2, path_keyword="file_name") + ) def download_template_volume(self, resolution, file_name): - ''' + """ Download the registration template volume at a particular resolution. Parameters @@ -133,17 +122,16 @@ def download_template_volume(self, resolution, file_name): file_name: string Where to save the registration template volume. - ''' - self.download_volumetric_data(ReferenceSpaceApi.AVERAGE_TEMPLATE, - 'average_template_%d.nrrd' % resolution, - save_file_path=file_name) - - @cacheable(strategy='create', - reader=nrrd.read, - pathfinder=Cache.pathfinder(file_name_position=4, - path_keyword='file_name')) + """ + self.download_volumetric_data( + ReferenceSpaceApi.AVERAGE_TEMPLATE, "average_template_%d.nrrd" % resolution, save_file_path=file_name + ) + + @cacheable( + strategy="create", reader=nrrd.read, pathfinder=Cache.pathfinder(file_name_position=4, path_keyword="file_name") + ) def download_structure_mask(self, structure_id, ccf_version, resolution, file_name): - '''Download an indicator mask for a specific structure. + """Download an indicator mask for a specific structure. Parameters ---------- @@ -156,30 +144,28 @@ def download_structure_mask(self, structure_id, ccf_version, resolution, file_na file_name : string Where to save the downloaded mask. - ''' + """ - if ccf_version is None: + if ccf_version is None: ccf_version = ReferenceSpaceApi.CCF_VERSION_DEFAULT - structure_mask_dir = 'structure_masks_{0}'.format(resolution) - data_path = '{0}/{1}/{2}'.format(ccf_version, 'structure_masks', structure_mask_dir) - remote_file_name = 'structure_{0}.nrrd'.format(structure_id) + structure_mask_dir = "structure_masks_{0}".format(resolution) + data_path = "{0}/{1}/{2}".format(ccf_version, "structure_masks", structure_mask_dir) + remote_file_name = "structure_{0}.nrrd".format(structure_id) try: self.download_volumetric_data(data_path, remote_file_name, save_file_path=file_name) except Exception: - self._file_download_log.error('''We weren't able to download a structure mask for structure {0}. + self._file_download_log.error("""We weren't able to download a structure mask for structure {0}. You can instead build the mask locally using - ReferenceSpace.many_structure_masks''') + ReferenceSpace.many_structure_masks""") raise - - @cacheable(strategy='create', - reader=read_obj, - pathfinder=Cache.pathfinder(file_name_position=3, - path_keyword='file_name')) + @cacheable( + strategy="create", reader=read_obj, pathfinder=Cache.pathfinder(file_name_position=3, path_keyword="file_name") + ) def download_structure_mesh(self, structure_id, ccf_version, file_name): - '''Download a Wavefront obj file containing a triangulated 3d mesh built + """Download a Wavefront obj file containing a triangulated 3d mesh built from an annotated structure. Parameters @@ -191,33 +177,29 @@ def download_structure_mesh(self, structure_id, ccf_version, file_name): file_name : string Where to save the downloaded mask. - ''' + """ - if ccf_version is None: + if ccf_version is None: ccf_version = ReferenceSpaceApi.CCF_VERSION_DEFAULT - data_path = '{0}/{1}'.format(ccf_version, 'structure_meshes') - remote_file_name = '{0}.obj'.format(structure_id) + data_path = "{0}/{1}".format(ccf_version, "structure_meshes") + remote_file_name = "{0}.obj".format(structure_id) try: self.download_volumetric_data(data_path, remote_file_name, save_file_path=file_name) except Exception: - self._file_download_log.error('unable to download a structure mesh for structure {0}.'.format(structure_id)) + self._file_download_log.error("unable to download a structure mesh for structure {0}.".format(structure_id)) raise - - def build_volumetric_data_download_url(self, - data_path, - file_name, - voxel_resolution=None, - release=None, - coordinate_framework=None): - '''Construct url to download 3D reference model in NRRD format. + def build_volumetric_data_download_url( + self, data_path, file_name, voxel_resolution=None, release=None, coordinate_framework=None + ): + """Construct url to download 3D reference model in NRRD format. Parameters ---------- data_path : string - 'average_template', 'ara_nissl', 'annotation/ccf_{year}', + 'average_template', 'ara_nissl', 'annotation/ccf_{year}', 'annotation/mouse_2011', or 'annotation/devmouse_2012' voxel_resolution : int 10, 25, 50 or 100 @@ -228,39 +210,32 @@ def build_volumetric_data_download_url(self, ----- See: `3-D Reference Models `_ for additional documentation. - ''' + """ if voxel_resolution is None: voxel_resolution = ReferenceSpaceApi.VOXEL_RESOLUTION_10_MICRONS if release is None: - release = 'current-release' + release = "current-release" if coordinate_framework is None: - coordinate_framework = 'mouse_ccf' + coordinate_framework = "mouse_ccf" - url = ''.join([self.informatics_archive_endpoint, - '/%s/%s/' % (release, coordinate_framework), - data_path, - '/', - file_name]) + url = "".join( + [self.informatics_archive_endpoint, "/%s/%s/" % (release, coordinate_framework), data_path, "/", file_name] + ) return url - - def download_volumetric_data(self, - data_path, - file_name, - voxel_resolution=None, - save_file_path=None, - release=None, - coordinate_framework=None): - '''Download 3D reference model in NRRD format. + def download_volumetric_data( + self, data_path, file_name, voxel_resolution=None, save_file_path=None, release=None, coordinate_framework=None + ): + """Download 3D reference model in NRRD format. Parameters ---------- data_path : string - 'average_template', 'ara_nissl', 'annotation/ccf_{year}', + 'average_template', 'ara_nissl', 'annotation/ccf_{year}', 'annotation/mouse_2011', or 'annotation/devmouse_2012' file_name : string server-side file name. 'annotation_10.nrrd' for example. @@ -273,18 +248,15 @@ def download_volumetric_data(self, ----- See: `3-D Reference Models `_ for additional documentation. - ''' - url = self.build_volumetric_data_download_url(data_path, - file_name, - voxel_resolution, - release, - coordinate_framework) + """ + url = self.build_volumetric_data_download_url( + data_path, file_name, voxel_resolution, release, coordinate_framework + ) if save_file_path is None: save_file_path = file_name if save_file_path is None: - save_file_path = 'volumetric_data.nrrd' + save_file_path = "volumetric_data.nrrd" self.retrieve_file_over_http(url, save_file_path) - diff --git a/allensdk/api/queries/rma_api.py b/allensdk/api/queries/rma_api.py index f34df152c5..72345d92e9 100644 --- a/allensdk/api/queries/rma_api.py +++ b/allensdk/api/queries/rma_api.py @@ -38,38 +38,37 @@ class RmaApi(Api): - ''' + """ See: `RESTful Model Access (RMA) `_ - ''' - MODEL = 'model::' - PIPE = 'pipe::' - SERVICE = 'service::' - CRITERIA = 'rma::criteria' - INCLUDE = 'rma::include' - OPTIONS = 'rma::options' - ORDER = 'order' - NUM_ROWS = 'num_rows' - ALL = 'all' - START_ROW = 'start_row' - COUNT = 'count' - ONLY = 'only' - EXCEPT = 'except' - EXCPT = 'excpt' - TABULAR = 'tabular' - DEBUG = 'debug' - PREVIEW = 'preview' - TRUE = 'true' - FALSE = 'false' - IS = '$is' - EQ = '$eq' + """ + + MODEL = "model::" + PIPE = "pipe::" + SERVICE = "service::" + CRITERIA = "rma::criteria" + INCLUDE = "rma::include" + OPTIONS = "rma::options" + ORDER = "order" + NUM_ROWS = "num_rows" + ALL = "all" + START_ROW = "start_row" + COUNT = "count" + ONLY = "only" + EXCEPT = "except" + EXCPT = "excpt" + TABULAR = "tabular" + DEBUG = "debug" + PREVIEW = "preview" + TRUE = "true" + FALSE = "false" + IS = "$is" + EQ = "$eq" def __init__(self, base_uri=None): super(RmaApi, self).__init__(base_uri) - def build_query_url(self, - stage_clauses, - fmt='json'): - '''Combine one or more RMA query stages into a single RMA query. + def build_query_url(self, stage_clauses, fmt="json"): + """Combine one or more RMA query stages into a single RMA query. Parameters ---------- @@ -82,23 +81,16 @@ def build_query_url(self, ------- string complete RMA url - ''' + """ if type(stage_clauses) is not list: stage_clauses = [stage_clauses] - url = ''.join([ - self.rma_endpoint, - '/query.', - fmt, - '?q=', - ','.join(stage_clauses)]) + url = "".join([self.rma_endpoint, "/query.", fmt, "?q=", ",".join(stage_clauses)]) return url - def model_stage(self, - model, - **kwargs): - '''Construct a model stage of an RMA query string. + def model_stage(self, model, **kwargs): + """Construct a model stage of an RMA query string. Parameters ---------- @@ -134,63 +126,59 @@ def model_stage(self, used in much of the RMA documentation. Using the &debug=true option with an RMA URL will include debugging information in the response, including the normalized query. - ''' + """ clauses = [RmaApi.MODEL + model] - filters = kwargs.get('filters', None) + filters = kwargs.get("filters", None) if filters is not None: clauses.append(self.filters(filters)) - criteria = kwargs.get('criteria', None) + criteria = kwargs.get("criteria", None) if criteria is not None: - clauses.append(',') + clauses.append(",") clauses.append(RmaApi.CRITERIA) - clauses.append(',') + clauses.append(",") clauses.extend(criteria) - include = kwargs.get('include', None) + include = kwargs.get("include", None) if include is not None: - clauses.append(',') + clauses.append(",") clauses.append(RmaApi.INCLUDE) - clauses.append(',') + clauses.append(",") clauses.extend(include) options_clause = self.options_clause(**kwargs) - if options_clause != '': - clauses.append(',') + if options_clause != "": + clauses.append(",") clauses.append(options_clause) - stage = ''.join(clauses) + stage = "".join(clauses) return stage - def pipe_stage(self, - pipe_name, - parameters): - '''Connect model and service stages via their JSON responses. + def pipe_stage(self, pipe_name, parameters): + """Connect model and service stages via their JSON responses. Notes ----- See: `Service Pipelines `_ and `Connected Services and Pipes `_ - ''' + """ clauses = [RmaApi.PIPE + pipe_name] clauses.append(self.tuple_filters(parameters)) - stage = ''.join(clauses) + stage = "".join(clauses) return stage - def service_stage(self, - service_name, - parameters=None): - '''Construct an RMA query fragment to send a request to a connected service. + def service_stage(self, service_name, parameters=None): + """Construct an RMA query fragment to send a request to a connected service. Parameters ---------- @@ -204,18 +192,18 @@ def service_stage(self, See: `Service Pipelines `_ and `Connected Services and Pipes `_ - ''' + """ clauses = [RmaApi.SERVICE + service_name] if parameters is not None: clauses.append(self.tuple_filters(parameters)) - stage = ''.join(clauses) + stage = "".join(clauses) return stage def model_query(self, *args, **kwargs): - '''Construct and execute a model stage of an RMA query string. + """Construct and execute a model stage of an RMA query string. Parameters ---------- @@ -253,13 +241,11 @@ def model_query(self, *args, **kwargs): used in much of the RMA documentation. Using the &debug=true option with an RMA URL will include debugging information in the response, including the normalized query. - ''' - return self.json_msg_query( - self.build_query_url( - self.model_stage(*args, **kwargs))) + """ + return self.json_msg_query(self.build_query_url(self.model_stage(*args, **kwargs))) def service_query(self, *args, **kwargs): - '''Construct and Execute a single-stage RMA query + """Construct and Execute a single-stage RMA query to send a request to a connected service. Parameters @@ -274,13 +260,11 @@ def service_query(self, *args, **kwargs): See: `Service Pipelines `_ and `Connected Services and Pipes `_ - ''' - return self.json_msg_query( - self.build_query_url( - self.service_stage(*args, **kwargs))) + """ + return self.json_msg_query(self.build_query_url(self.service_stage(*args, **kwargs))) def options_clause(self, **kwargs): - '''build rma:: options clause. + """build rma:: options clause. Parameters ---------- @@ -292,38 +276,31 @@ def options_clause(self, **kwargs): 'true', 'false' or 'preview' num_rows : int or string, optional start_row : int or string, optional - ''' - clause = '' + """ + clause = "" options_params = [] only = kwargs.get(RmaApi.ONLY, None) if only is not None: - options_params.append( - self.only_except_tabular_clause(RmaApi.ONLY, - only)) + options_params.append(self.only_except_tabular_clause(RmaApi.ONLY, only)) # handle alternate 'except' spelling to avoid reserved word conflict excpt = kwargs.get(RmaApi.EXCEPT, None) excpt2 = kwargs.get(RmaApi.EXCPT, None) - + if excpt is not None and excpt2 is not None: - warnings.warn('excpt and except options should not be used together', - Warning) + warnings.warn("excpt and except options should not be used together", Warning) elif excpt2 is not None: - excpt = excpt2 + excpt = excpt2 if excpt is not None: - options_params.append( - self.only_except_tabular_clause(RmaApi.EXCEPT, - excpt)) + options_params.append(self.only_except_tabular_clause(RmaApi.EXCEPT, excpt)) tabular = kwargs.get(RmaApi.TABULAR, None) if tabular is not None: - options_params.append( - self.only_except_tabular_clause(RmaApi.TABULAR, - tabular)) + options_params.append(self.only_except_tabular_clause(RmaApi.TABULAR, tabular)) num_rows = kwargs.get(RmaApi.NUM_ROWS, None) @@ -331,14 +308,12 @@ def options_clause(self, **kwargs): if num_rows == RmaApi.ALL: options_params.append("[%s$eq'all']" % (RmaApi.NUM_ROWS)) else: - options_params.append('[%s$eq%d]' % (RmaApi.NUM_ROWS, - num_rows)) + options_params.append("[%s$eq%d]" % (RmaApi.NUM_ROWS, num_rows)) start_row = kwargs.get(RmaApi.START_ROW, None) if start_row is not None: - options_params.append('[%s$eq%d]' % (RmaApi.START_ROW, - start_row)) + options_params.append("[%s$eq%d]" % (RmaApi.START_ROW, start_row)) order = kwargs.get(RmaApi.ORDER, None) @@ -353,22 +328,20 @@ def options_clause(self, **kwargs): cnt = kwargs.get(RmaApi.COUNT, None) if cnt is not None: - if cnt is True or cnt == 'true': - options_params.append('[%s$eq%s]' % (RmaApi.COUNT, - RmaApi.TRUE)) - elif cnt is False or cnt == 'false': - options_params.append('[%s$eq%s]' % (RmaApi.COUNT, - RmaApi.FALSE)) + if cnt is True or cnt == "true": + options_params.append("[%s$eq%s]" % (RmaApi.COUNT, RmaApi.TRUE)) + elif cnt is False or cnt == "false": + options_params.append("[%s$eq%s]" % (RmaApi.COUNT, RmaApi.FALSE)) else: pass if len(options_params) > 0: - clause = RmaApi.OPTIONS + ''.join(options_params) + clause = RmaApi.OPTIONS + "".join(options_params) return clause def only_except_tabular_clause(self, filter_type, attribute_list): - '''Construct a clause to filter which attributes are returned + """Construct a clause to filter which attributes are returned for use in an rma::options clause. Parameters @@ -394,17 +367,16 @@ def only_except_tabular_clause(self, filter_type, attribute_list): The tabular option does not mask the inner-join behavior of an rma::include clause. The tabular filter is required for .csv format RMA requests. - ''' - clause = '' + """ + clause = "" if attribute_list is not None: - clause = '[%s$eq%s]' % (filter_type, - ','.join(attribute_list)) + clause = "[%s$eq%s]" % (filter_type, ",".join(attribute_list)) return clause def order_clause(self, order_list=None): - '''Construct a debug clause for use in an rma::options clause. + """Construct a debug clause for use in an rma::options clause. Parameters ---------- @@ -420,16 +392,16 @@ def order_clause(self, order_list=None): ----- Optionally adding '+asc' (default) or '+desc' after an attribute will change the sort order. - ''' - clause = '' + """ + clause = "" if order_list is not None: - clause = '[order$eq%s]' % (','.join(order_list)) + clause = "[order$eq%s]" % (",".join(order_list)) return clause def debug_clause(self, debug_value=None): - '''Construct a debug clause for use in an rma::options clause. + """Construct a debug clause for use in an rma::options clause. Parameters ---------- debug_value : string or boolean @@ -447,23 +419,23 @@ def debug_clause(self, debug_value=None): None will return an empty clause. 'preview' will request debugging information without the query being run. - ''' - clause = '' + """ + clause = "" if debug_value is None: - clause = '' - if debug_value is True or debug_value == 'true': - clause = '[debug$eqtrue]' - elif debug_value is False or debug_value == 'false': - clause = '[debug$eqfalse]' - elif debug_value == 'preview': + clause = "" + if debug_value is True or debug_value == "true": + clause = "[debug$eqtrue]" + elif debug_value is False or debug_value == "false": + clause = "[debug$eqfalse]" + elif debug_value == "preview": clause = "[debug$eq'preview']" return clause # TODO: deprecate for something that can preserve order def filters(self, filters): - '''serialize RMA query filter clauses. + """serialize RMA query filter clauses. Parameters ---------- @@ -474,23 +446,23 @@ def filters(self, filters): ------- string filter clause for an RMA query string. - ''' + """ filters_builder = [] - for (key, value) in filters.items(): + for key, value in filters.items(): filters_builder.append(self.filter(key, value)) - return ''.join(filters_builder) + return "".join(filters_builder) # TODO: this needs to be more rigorous. def tuple_filters(self, filters): - '''Construct an RMA filter clause. + """Construct an RMA filter clause. Notes ----- See `RMA Path Syntax - Square Brackets for Filters `_ for additional documentation. - ''' + """ filters_builder = [] for filt in sorted(filters): @@ -505,7 +477,7 @@ def tuple_filters(self, filters): val_array.append(v) else: val_array.append(str(v)) - val = ','.join(val_array) + val = ",".join(val_array) filters_builder.append("[%s$eq%s]" % (filt[0], val)) elif type(val) is int: filters_builder.append("[%s$eq%d]" % (filt[0], val)) @@ -517,14 +489,12 @@ def tuple_filters(self, filters): elif type(val) is str: filters_builder.append("[%s$eq%s]" % (filt[0], filt[1])) elif len(filt) == 3: - filters_builder.append("[%s%s%s]" % (filt[0], - filt[1], - str(filt[2]))) + filters_builder.append("[%s%s%s]" % (filt[0], filt[1], str(filt[2]))) - return ''.join(filters_builder) + return "".join(filters_builder) def quote_string(self, the_string): - '''Wrap a clause in single quotes. + """Wrap a clause in single quotes. Parameters ---------- @@ -535,11 +505,11 @@ def quote_string(self, the_string): ------- string input wrapped in single quotes - ''' - return ''.join(["'", the_string, "'"]) + """ + return "".join(["'", the_string, "'"]) def filter(self, key, value): - '''serialize a single RMA query filter clause. + """serialize a single RMA query filter clause. Parameters ---------- @@ -552,15 +522,11 @@ def filter(self, key, value): ------- string a single filter clause for an RMA query string. - ''' - return "".join(['[', - key, - RmaApi.EQ, - str(value), - ']']) + """ + return "".join(["[", key, RmaApi.EQ, str(value), "]"]) - def build_schema_query(self, clazz=None, fmt='json'): - '''Build the URL that will fetch the data schema. + def build_schema_query(self, clazz=None, fmt="json"): + """Build the URL that will fetch the data schema. Parameters ---------- @@ -578,23 +544,18 @@ def build_schema_query(self, clazz=None, fmt='json'): ----- If a class is specified, only the schema information for that class will be requested, otherwise the url requests the entire schema. - ''' + """ if clazz is not None: - class_clause = '/' + clazz + class_clause = "/" + clazz else: - class_clause = '' + class_clause = "" - url = ''.join([self.rma_endpoint, - class_clause, - '.', - fmt]) + url = "".join([self.rma_endpoint, class_clause, ".", fmt]) return url def get_schema(self, clazz=None): - '''Retrieve schema information.''' - schema_data = self.do_query(self.build_schema_query, - self.read_data, - clazz) + """Retrieve schema information.""" + schema_data = self.do_query(self.build_schema_query, self.read_data, clazz) return schema_data diff --git a/allensdk/api/queries/rma_pager.py b/allensdk/api/queries/rma_pager.py index 0df30251d8..d652154724 100644 --- a/allensdk/api/queries/rma_pager.py +++ b/allensdk/api/queries/rma_pager.py @@ -41,20 +41,18 @@ def __init__(self): pass @staticmethod - def pager(fn, - *args, - **kwargs): - total_rows = kwargs.pop('total_rows', None) - num_rows = kwargs.get('num_rows', None) + def pager(fn, *args, **kwargs): + total_rows = kwargs.pop("total_rows", None) + num_rows = kwargs.get("num_rows", None) - if total_rows == 'all': + if total_rows == "all": start_row = 0 result_count = num_rows kwargs = kwargs - kwargs['count'] = False + kwargs["count"] = False while result_count == num_rows: - kwargs['start_row'] = start_row + kwargs["start_row"] = start_row data = fn(*args, **kwargs) start_row = start_row + num_rows @@ -65,10 +63,10 @@ def pager(fn, else: start_row = 0 kwargs = kwargs - kwargs['count'] = False + kwargs["count"] = False while start_row < total_rows: - kwargs['start_row'] = start_row + kwargs["start_row"] = start_row data = fn(*args, **kwargs) result_count = len(data) @@ -77,23 +75,22 @@ def pager(fn, for r in data: yield r -def pageable(total_rows=None, - num_rows=None): + +def pageable(total_rows=None, num_rows=None): def decor(func): - decor.total_rows=total_rows - decor.num_rows=num_rows + decor.total_rows = total_rows + decor.num_rows = num_rows @functools.wraps(func) - def w(*args, - **kwargs): - if decor.num_rows and 'num_rows' not in kwargs: - kwargs['num_rows'] = decor.num_rows - if decor.total_rows and 'total_rows' not in kwargs: - kwargs['total_rows'] = decor.total_rows - - result = RmaPager.pager(func, - *args, - **kwargs) + def w(*args, **kwargs): + if decor.num_rows and "num_rows" not in kwargs: + kwargs["num_rows"] = decor.num_rows + if decor.total_rows and "total_rows" not in kwargs: + kwargs["total_rows"] = decor.total_rows + + result = RmaPager.pager(func, *args, **kwargs) return result + return w + return decor diff --git a/allensdk/api/queries/rma_template.py b/allensdk/api/queries/rma_template.py index 1b3c44f1c6..c876593c86 100644 --- a/allensdk/api/queries/rma_template.py +++ b/allensdk/api/queries/rma_template.py @@ -38,10 +38,10 @@ class RmaTemplate(RmaApi): - ''' + """ See: `Atlas Drawings and Ontologies `_ - ''' + """ def __init__(self, base_uri=None, query_manifest=None): super(RmaTemplate, self).__init__(base_uri) @@ -49,83 +49,84 @@ def __init__(self, base_uri=None, query_manifest=None): def to_filter_rhs(self, rhs): if type(rhs) == list: - return ','.join(str(r) for r in rhs) + return ",".join(str(r) for r in rhs) return rhs def template_query(self, template_name, entry_name, **kwargs): cb = self.templates[template_name] - templates = [e for e in cb if e['name'] == entry_name] + templates = [e for e in cb if e["name"] == entry_name] if len(templates) > 0: template = templates[0] else: - raise Exception('Entry %s not found.' % (entry_name)) + raise Exception("Entry %s not found." % (entry_name)) - query_args = {'model': template['model']} + query_args = {"model": template["model"]} - if 'criteria' in template: - criteria_template = Template(template['criteria']) + if "criteria" in template: + criteria_template = Template(template["criteria"]) - if 'criteria_params' in template: - criteria_params = {key: self.to_filter_rhs(kwargs.get(key)) - for key in template['criteria_params'] - if key in kwargs and kwargs.get(key) is not None} + if "criteria_params" in template: + criteria_params = { + key: self.to_filter_rhs(kwargs.get(key)) + for key in template["criteria_params"] + if key in kwargs and kwargs.get(key) is not None + } else: criteria_params = {} criteria_str = str(criteria_template.render(**criteria_params)) if criteria_str: - query_args['criteria'] = criteria_str + query_args["criteria"] = criteria_str - if 'include' in template: - include_template = Template(template['include']) + if "include" in template: + include_template = Template(template["include"]) - if 'include_params' in template: - include_params = {key: self.to_filter_rhs(kwargs.get(key)) - for key in template['include_params'] - if key in kwargs and kwargs.get(key) is not None} + if "include_params" in template: + include_params = { + key: self.to_filter_rhs(kwargs.get(key)) + for key in template["include_params"] + if key in kwargs and kwargs.get(key) is not None + } else: include_params = {} include_str = str(include_template.render(**include_params)) if include_str: - query_args['include'] = include_str - - if 'only' in kwargs: - if kwargs.get('only') is not None: - query_args['only'] = [self.quote_string( - ','.join(kwargs.get('only')))] - elif 'only' in template: - query_args['only'] = [ - self.quote_string(','.join(template['only']))] - - if 'except' in kwargs: - if kwargs.get('except') is not None: - query_args['except'] = [self.quote_string( - ','.join(kwargs.get('except')))] - elif 'except' in template: - query_args['except'] = template['except'] - - if 'start_row' in kwargs: - query_args['start_row'] = kwargs.get('start_row') - elif 'start_row' in template: - query_args['start_row'] = template['start_row'] - - if 'num_rows' in kwargs: - query_args['num_rows'] = kwargs.get('num_rows') - elif 'num_rows' in template: - query_args['num_rows'] = template['num_rows'] - - if 'count' in kwargs: - query_args['count'] = kwargs.get('count') - elif 'count' in template: - query_args['count'] = template['count'] - - if 'order' in kwargs: - query_args['order'] = kwargs.get('order') - elif 'order' in template: - query_args['order'] = template['order'] + query_args["include"] = include_str + + if "only" in kwargs: + if kwargs.get("only") is not None: + query_args["only"] = [self.quote_string(",".join(kwargs.get("only")))] + elif "only" in template: + query_args["only"] = [self.quote_string(",".join(template["only"]))] + + if "except" in kwargs: + if kwargs.get("except") is not None: + query_args["except"] = [self.quote_string(",".join(kwargs.get("except")))] + elif "except" in template: + query_args["except"] = template["except"] + + if "start_row" in kwargs: + query_args["start_row"] = kwargs.get("start_row") + elif "start_row" in template: + query_args["start_row"] = template["start_row"] + + if "num_rows" in kwargs: + query_args["num_rows"] = kwargs.get("num_rows") + elif "num_rows" in template: + query_args["num_rows"] = template["num_rows"] + + if "count" in kwargs: + query_args["count"] = kwargs.get("count") + elif "count" in template: + query_args["count"] = template["count"] + + if "order" in kwargs: + query_args["order"] = kwargs.get("order") + elif "order" in template: + query_args["order"] = template["order"] query_args.update(kwargs) diff --git a/allensdk/api/queries/svg_api.py b/allensdk/api/queries/svg_api.py index 29678b9ea7..902916ece1 100644 --- a/allensdk/api/queries/svg_api.py +++ b/allensdk/api/queries/svg_api.py @@ -37,12 +37,11 @@ class SvgApi(Api): - def __init__(self, base_uri=None): super(SvgApi, self).__init__(base_uri) def build_query(self, section_image_id, groups=None, download=False): - '''Build the URL that will fetch meta data for the specified structure. + """Build the URL that will fetch meta data for the specified structure. Parameters ---------- @@ -55,7 +54,7 @@ def build_query(self, section_image_id, groups=None, download=False): ------- url : string The constructed URL - ''' + """ if download is True: endpoint = self.svg_download_endpoint else: @@ -65,32 +64,24 @@ def build_query(self, section_image_id, groups=None, download=False): groups = [] if groups and len(groups) > 0: - url_params = '?groups=' + ','.join([str(g) for g in groups]) + url_params = "?groups=" + ",".join([str(g) for g in groups]) else: - url_params = '' + url_params = "" - url = ''.join([endpoint, - '/', - str(section_image_id), - url_params]) + url = "".join([endpoint, "/", str(section_image_id), url_params]) return url - def download_svg(self, - section_image_id, - groups=None, - file_path=None): - '''Download the svg file''' + def download_svg(self, section_image_id, groups=None, file_path=None): + """Download the svg file""" if file_path is None: - file_path = '%d.svg' % (section_image_id) + file_path = "%d.svg" % (section_image_id) svg_url = self.build_query(section_image_id, groups, download=True) self.retrieve_file_over_http(svg_url, file_path) - def get_svg(self, - section_image_id, - groups=None): - '''Get the svg document.''' + def get_svg(self, section_image_id, groups=None): + """Get the svg document.""" svg_url = self.build_query(section_image_id, groups) return self.retrieve_xml_over_http(svg_url) diff --git a/allensdk/api/queries/synchronization_api.py b/allensdk/api/queries/synchronization_api.py index 5895783789..929d11639b 100644 --- a/allensdk/api/queries/synchronization_api.py +++ b/allensdk/api/queries/synchronization_api.py @@ -37,23 +37,20 @@ class SynchronizationApi(Api): - '''HTTP client for image synchronization services uses the image alignment results from + """HTTP client for image synchronization services uses the image alignment results from the Informatics Data Processing Pipeline. Note: all locations on SectionImages are reported in pixel coordinates and all locations in 3-D ReferenceSpaces are reported in microns. See `Image to Image Synchronization `_ for additional documentation. - ''' + """ def __init__(self, base_uri=None): super(SynchronizationApi, self).__init__(base_uri) - def get_image_to_atlas(self, - section_image_id, - x, y, - atlas_id): - '''For a specified Atlas, find the closest annotated SectionImage + def get_image_to_atlas(self, section_image_id, x, y, atlas_id): + """For a specified Atlas, find the closest annotated SectionImage and (x,y) location as defined by a seed SectionImage and seed (x,y) location. Parameters @@ -71,22 +68,23 @@ def get_image_to_atlas(self, ------- dict The parsed json response - ''' - url = ''.join([self.image_to_atlas_endpoint, - '/', - str(section_image_id), - '.json', - '?x=%f&y=%f' % (x, y), - '&atlas_id=', - str(atlas_id)]) + """ + url = "".join( + [ + self.image_to_atlas_endpoint, + "/", + str(section_image_id), + ".json", + "?x=%f&y=%f" % (x, y), + "&atlas_id=", + str(atlas_id), + ] + ) return self.json_msg_query(url) - def get_image_to_image(self, - section_image_id, - x, y, - section_data_set_ids): - '''For a list of target SectionDataSets, find the closest SectionImage + def get_image_to_image(self, section_image_id, x, y, section_data_set_ids): + """For a list of target SectionDataSets, find the closest SectionImage and (x,y) location as defined by a seed SectionImage and seed (x,y) pixel location. Parameters @@ -104,22 +102,23 @@ def get_image_to_image(self, ------- dict The parsed json response - ''' - url = ''.join([self.image_to_image_endpoint, - '/', - str(section_image_id), - '.json', - '?x=%f&y=%f' % (x, y), - '§ion_data_set_ids=', - ','.join(str(i) for i in section_data_set_ids)]) + """ + url = "".join( + [ + self.image_to_image_endpoint, + "/", + str(section_image_id), + ".json", + "?x=%f&y=%f" % (x, y), + "§ion_data_set_ids=", + ",".join(str(i) for i in section_data_set_ids), + ] + ) return self.json_msg_query(url) - def get_image_to_image_2d(self, - section_image_id, - x, y, - section_image_ids): - '''For a list of target SectionImages, find the closest (x,y) location + def get_image_to_image_2d(self, section_image_id, x, y, section_image_ids): + """For a list of target SectionImages, find the closest (x,y) location as defined by a seed SectionImage and seed (x,y) location. Parameters @@ -137,22 +136,23 @@ def get_image_to_image_2d(self, ------- dict The parsed json response - ''' - url = ''.join([self.image_to_image_2d_endpoint, - '/', - str(section_image_id), - '.json', - '?x=%f&y=%f' % (x, y), - '§ion_image_ids=', - ','.join(str(i) for i in section_image_ids)]) + """ + url = "".join( + [ + self.image_to_image_2d_endpoint, + "/", + str(section_image_id), + ".json", + "?x=%f&y=%f" % (x, y), + "§ion_image_ids=", + ",".join(str(i) for i in section_image_ids), + ] + ) return self.json_msg_query(url) - def get_reference_to_image(self, - reference_space_id, - x, y, z, - section_data_set_ids): - '''For a list of target SectionDataSets, find the closest SectionImage + def get_reference_to_image(self, reference_space_id, x, y, z, section_data_set_ids): + """For a list of target SectionDataSets, find the closest SectionImage and (x,y) location as defined by a (x,y,z) location in a specified ReferenceSpace. Parameters @@ -172,21 +172,23 @@ def get_reference_to_image(self, ------- dict The parsed json response - ''' - url = ''.join([self.reference_to_image_endpoint, - '/', - str(reference_space_id), - '.json', - '?x=%f&y=%f&z=%f' % (x, y, z), - '§ion_data_set_ids=', - ','.join(str(i) for i in section_data_set_ids)]) + """ + url = "".join( + [ + self.reference_to_image_endpoint, + "/", + str(reference_space_id), + ".json", + "?x=%f&y=%f&z=%f" % (x, y, z), + "§ion_data_set_ids=", + ",".join(str(i) for i in section_data_set_ids), + ] + ) return self.json_msg_query(url) - def get_image_to_reference(self, - section_image_id, - x, y): - '''For a specified SectionImage and (x,y) location, + def get_image_to_reference(self, section_image_id, x, y): + """For a specified SectionImage and (x,y) location, return the (x,y,z) location in the ReferenceSpace of the associated SectionDataSet. Parameters @@ -202,19 +204,13 @@ def get_image_to_reference(self, ------- dict The parsed json response - ''' - url = ''.join([self.image_to_reference_endpoint, - '/', - str(section_image_id), - '.json', - '?x=%f&y=%f' % (x, y)]) + """ + url = "".join([self.image_to_reference_endpoint, "/", str(section_image_id), ".json", "?x=%f&y=%f" % (x, y)]) return self.json_msg_query(url) - def get_structure_to_image(self, - section_data_set_id, - structure_ids): - '''For a list of target structures, find the closest SectionImage + def get_structure_to_image(self, section_data_set_id, structure_ids): + """For a list of target structures, find the closest SectionImage and (x,y) location as defined by the centroid of each Structure. Parameters @@ -228,12 +224,16 @@ def get_structure_to_image(self, ------- dict The parsed json response - ''' - url = ''.join([self.structure_to_image_endpoint, - '/', - str(section_data_set_id), - '.json', - '?structure_ids=', - ','.join([str(i) for i in structure_ids])]) + """ + url = "".join( + [ + self.structure_to_image_endpoint, + "/", + str(section_data_set_id), + ".json", + "?structure_ids=", + ",".join([str(i) for i in structure_ids]), + ] + ) return self.json_msg_query(url) diff --git a/allensdk/api/queries/tree_search_api.py b/allensdk/api/queries/tree_search_api.py index 1b84c248f3..3a606598d1 100644 --- a/allensdk/api/queries/tree_search_api.py +++ b/allensdk/api/queries/tree_search_api.py @@ -37,21 +37,17 @@ class TreeSearchApi(Api): - ''' + """ See `Searching a Specimen or Structure Tree `_ for additional documentation. - ''' + """ def __init__(self, base_uri=None): super(TreeSearchApi, self).__init__(base_uri) - def get_tree(self, - kind, - db_id, - ancestors=None, - descendants=None): - '''Fetch meta data for the specified structure or specimen. + def get_tree(self, kind, db_id, ancestors=None, descendants=None): + """Fetch meta data for the specified structure or specimen. Parameters ---------- @@ -68,31 +64,25 @@ def get_tree(self, ------- dict parsed json response data - ''' + """ params = [] - url_params = '' + url_params = "" if ancestors is True: - params.append('ancestors=true') + params.append("ancestors=true") elif ancestors is False: - params.append('ancestors=false') + params.append("ancestors=false") if descendants is True: - params.append('descendants=true') + params.append("descendants=true") elif descendants is False: - params.append('descendants=false') + params.append("descendants=false") if len(params) > 0: - url_params = '?' + '&'.join(params) + url_params = "?" + "&".join(params) else: - url_params = '' + url_params = "" - url = ''.join([self.tree_search_endpoint, - '/', - kind, - '/', - str(db_id), - '.json', - url_params]) + url = "".join([self.tree_search_endpoint, "/", kind, "/", str(db_id), ".json", url_params]) return self.json_msg_query(url) diff --git a/allensdk/api/warehouse_cache/cache.py b/allensdk/api/warehouse_cache/cache.py index 2c12f70895..38c01680e6 100755 --- a/allensdk/api/warehouse_cache/cache.py +++ b/allensdk/api/warehouse_cache/cache.py @@ -61,14 +61,13 @@ def memoize(f): Access the underlying function with f.__wrapped__. """ cache = {} - sentinel = object() # unique object for cache misses - make_key = _make_key # efficient key building from function args + sentinel = object() # unique object for cache misses + make_key = _make_key # efficient key building from function args cache_get = cache.get cache_len = cache.__len__ @wraps(f) def wrapper(*args, **kwargs): - # Don't consider 3.0 and 3 different key = make_key(args, kwargs, typed=False) @@ -92,20 +91,16 @@ def cache_size(): class Cache(object): - _log = logging.getLogger('allensdk.api.cache') + _log = logging.getLogger("allensdk.api.cache") - def __init__(self, - manifest=None, - cache=True, - version=None, - **kwargs): + def __init__(self, manifest=None, cache=True, version=None, **kwargs): self.cache = cache - if version is None and hasattr(self, 'MANIFEST_VERSION'): + if version is None and hasattr(self, "MANIFEST_VERSION"): version = self.MANIFEST_VERSION self.load_manifest(manifest, version) def get_cache_path(self, file_name, manifest_key, *args): - '''Helper method for accessing path specs from manifest keys. + """Helper method for accessing path specs from manifest keys. Parameters ---------- @@ -117,7 +112,7 @@ def get_cache_path(self, file_name, manifest_key, *args): ------- string or None path - ''' + """ if self.cache: if file_name: return file_name @@ -127,7 +122,7 @@ def get_cache_path(self, file_name, manifest_key, *args): return None def load_manifest(self, file_name, version=None): - '''Read a keyed collection of path specifications. + """Read a keyed collection of path specifications. Parameters ---------- @@ -137,10 +132,9 @@ def load_manifest(self, file_name, version=None): Returns ------- Manifest - ''' + """ if file_name is not None: if not os.path.exists(file_name): - # make the directory if it doesn't exist already dirname = os.path.dirname(file_name) if dirname: @@ -149,10 +143,7 @@ def load_manifest(self, file_name, version=None): self.build_manifest(file_name) try: - self.manifest = Manifest( - ju.read(file_name)['manifest'], - os.path.dirname(file_name), - version=version) + self.manifest = Manifest(ju.read(file_name)["manifest"], os.path.dirname(file_name), version=version) except ManifestVersionError as e: if e.outdated is True: intro = "is out of date" @@ -162,24 +153,27 @@ def load_manifest(self, file_name, version=None): intro = "version did not match the expected version" ref_url = "https://github.com/alleninstitute/allensdk/wiki" - raise ManifestVersionError(("Your manifest file (%s) %s" + - " (its version is '%s', but" + - " version '%s' is expected). " + - " Please remove this file" + - " and it will be regenerated for" + - " you the next time you" + - " instantiate this class." + - " WARNING: There may be new data" + - " files available that replace" + - " the ones you already have" + - " downloaded. Read the notes" + - " for this release for more" + - " details on what has changed" + - " (%s).") % - (file_name, intro, - e.found_version, e.version, - ref_url), - e.version, e.found_version) + raise ManifestVersionError( + ( + "Your manifest file (%s) %s" + + " (its version is '%s', but" + + " version '%s' is expected). " + + " Please remove this file" + + " and it will be regenerated for" + + " you the next time you" + + " instantiate this class." + + " WARNING: There may be new data" + + " files available that replace" + + " the ones you already have" + + " downloaded. Read the notes" + + " for this release for more" + + " details on what has changed" + + " (%s)." + ) + % (file_name, intro, e.found_version, e.version, ref_url), + e.version, + e.found_version, + ) self.manifest_path = file_name @@ -187,13 +181,13 @@ def load_manifest(self, file_name, version=None): self.manifest = None def build_manifest(self, file_name): - '''Creation of default path specifications. + """Creation of default path specifications. Parameters ---------- file_name : string where to save it - ''' + """ manifest_builder = ManifestBuilder() manifest_builder.set_version(self.MANIFEST_VERSION) @@ -203,20 +197,18 @@ def build_manifest(self, file_name): manifest_builder.write_json_file(file_name) def add_manifest_paths(self, manifest_builder): - '''Add cache-class specific paths to the manifest. In derived classes, + """Add cache-class specific paths to the manifest. In derived classes, should call super. - ''' - manifest_builder.add_path('BASEDIR', '.') - if hasattr(self, 'MANIFEST_CONFIG'): + """ + manifest_builder.add_path("BASEDIR", ".") + if hasattr(self, "MANIFEST_CONFIG"): for key, config in self.MANIFEST_CONFIG.items(): manifest_builder.add_path(key, **config) return manifest_builder def manifest_dataframe(self): - '''Convenience method to view manifest as a pandas dataframe. - ''' - return pd.DataFrame.from_dict(self.manifest.path_info, - orient='index') + """Convenience method to view manifest as a pandas dataframe.""" + return pd.DataFrame.from_dict(self.manifest.path_info, orient="index") @staticmethod def json_remove_keys(data, keys): @@ -228,8 +220,7 @@ def json_remove_keys(data, keys): @staticmethod def remove_keys(data, keys=None): - ''' DataFrame version - ''' + """DataFrame version""" if keys is None: keys = [] @@ -237,16 +228,15 @@ def remove_keys(data, keys=None): del data[key] @staticmethod - def json_rename_columns(data, - new_old_name_tuples=None): - '''Convenience method to rename columns in a pandas dataframe. + def json_rename_columns(data, new_old_name_tuples=None): + """Convenience method to rename columns in a pandas dataframe. Parameters ---------- data : dataframe edited in place. new_old_name_tuples : list of string tuples (new, old) - ''' + """ if new_old_name_tuples is None: new_old_name_tuples = [] @@ -256,28 +246,23 @@ def json_rename_columns(data, del r[old_name] @staticmethod - def rename_columns(data, - new_old_name_tuples=None): - '''Convenience method to rename columns in a pandas dataframe. + def rename_columns(data, new_old_name_tuples=None): + """Convenience method to rename columns in a pandas dataframe. Parameters ---------- data : dataframe edited in place. new_old_name_tuples : list of string tuples (new, old) - ''' + """ if new_old_name_tuples is None: new_old_name_tuples = [] for new_name, old_name in new_old_name_tuples: - data.columns = [new_name if c == old_name else c - for c in data.columns] + data.columns = [new_name if c == old_name else c for c in data.columns] - def load_csv(self, - path, - rename=None, - index=None): - '''Read a csv file as a pandas dataframe. + def load_csv(self, path, rename=None, index=None): + """Read a csv file as a pandas dataframe. Parameters ---------- @@ -285,7 +270,7 @@ def load_csv(self, columns to rename index : string, optional post-rename column to use as the row label. - ''' + """ data = pd.read_csv(path, parse_dates=True) Cache.rename_columns(data, rename) @@ -295,11 +280,8 @@ def load_csv(self, return data - def load_json(self, - path, - rename=None, - index=None): - '''Read a json file as a pandas dataframe. + def load_json(self, path, rename=None, index=None): + """Read a json file as a pandas dataframe. Parameters ---------- @@ -307,8 +289,8 @@ def load_json(self, columns to rename index : string, optional post-rename column to use as the row label. - ''' - data = pj.read_json(path, orient='records') + """ + data = pj.read_json(path, orient="records") Cache.rename_columns(data, rename) @@ -318,10 +300,8 @@ def load_json(self, return data @staticmethod - def cacher(fn, - *args, - **kwargs): - '''make an rma query, save it and return the dataframe. + def cacher(fn, *args, **kwargs): + """make an rma query, save it and return the dataframe. Parameters ---------- @@ -350,33 +330,32 @@ def cacher(fn, ------- Object or None data type depends on fn, reader and/or post methods. - ''' - path = kwargs.pop('path', None) - strategy = kwargs.pop('strategy', None) - pre = kwargs.pop('pre', lambda d: d) - post = kwargs.pop('post', None) - reader = kwargs.pop('reader', None) - writer = kwargs.pop('writer', None) + """ + path = kwargs.pop("path", None) + strategy = kwargs.pop("strategy", None) + pre = kwargs.pop("pre", lambda d: d) + post = kwargs.pop("post", None) + reader = kwargs.pop("reader", None) + writer = kwargs.pop("writer", None) if strategy is None: if writer or path: - strategy = 'lazy' + strategy = "lazy" else: - strategy = 'pass_through' + strategy = "pass_through" - if strategy not in ['lazy', 'pass_through', - 'file', 'create']: + if strategy not in ["lazy", "pass_through", "file", "create"]: raise ValueError("Unknown query strategy: {}.".format(strategy)) - if 'lazy' == strategy: + if "lazy" == strategy: if os.path.exists(path): - strategy = 'file' + strategy = "file" else: - strategy = 'create' + strategy = "create" - if strategy == 'pass_through': + if strategy == "pass_through": data = fn(*args, **kwargs) - elif strategy in ['create']: + elif strategy in ["create"]: Manifest.safe_make_parent_dirs(path) if writer: @@ -409,75 +388,51 @@ def csv_writer(pth, gen): first_row = True row_count = 1 - with open(pth, 'w') as output: + with open(pth, "w") as output: for row in gen: if first_row: field_names = [str(k) for k in row.keys()] - csv_writer = csv.DictWriter(output, - fieldnames=field_names, - delimiter=',', - quoting=csv.QUOTE_ALL) + csv_writer = csv.DictWriter(output, fieldnames=field_names, delimiter=",", quoting=csv.QUOTE_ALL) csv_writer.writeheader() first_row = False - Cache._log.info('row: {}'.format(row_count)) + Cache._log.info("row: {}".format(row_count)) row_count = row_count + 1 csv_writer.writerow(row) @staticmethod def cache_csv_json(): - def reader(f): - return pd.read_csv(f, parse_dates=True).to_dict('records') + return pd.read_csv(f, parse_dates=True).to_dict("records") - return { - 'writer': Cache.csv_writer, - 'reader': reader - } + return {"writer": Cache.csv_writer, "reader": reader} @staticmethod def cache_csv_dataframe(): - return { - 'writer': Cache.csv_writer, - 'reader': lambda f: pd.read_csv(f, parse_dates=True) - } + return {"writer": Cache.csv_writer, "reader": lambda f: pd.read_csv(f, parse_dates=True)} @staticmethod def nocache_dataframe(): - return { - 'post': pd.DataFrame - } + return {"post": pd.DataFrame} @staticmethod def nocache_json(): - return { - } + return {} @staticmethod def cache_json_dataframe(): - return { - 'writer': ju.write, - 'reader': lambda p: pj.read_json(p, orient='records') - } + return {"writer": ju.write, "reader": lambda p: pj.read_json(p, orient="records")} @staticmethod def cache_json(): - return { - 'writer': ju.write, - 'reader': ju.read - } + return {"writer": ju.write, "reader": ju.read} @staticmethod def cache_csv(): - return { - 'writer': Cache.csv_writer, - 'reader': lambda f: pd.read_csv(f, parse_dates=True) - } + return {"writer": Cache.csv_writer, "reader": lambda f: pd.read_csv(f, parse_dates=True)} @staticmethod - def pathfinder(file_name_position, - secondary_file_name_position=None, - path_keyword=None): - '''helper method to find path argument in legacy methods written + def pathfinder(file_name_position, secondary_file_name_position=None, path_keyword=None): + """helper method to find path argument in legacy methods written prior to the @cacheable decorator. Do not use for new @cacheable methods. @@ -497,7 +452,8 @@ def pathfinder(file_name_position, This method is only intended to provide backward-compatibility for some methods that otherwise do not follow the path conventions of the @cacheable decorator. - ''' + """ + def pf(*args, **kwargs): file_name = None @@ -507,22 +463,16 @@ def pf(*args, **kwargs): if file_name_position < len(args): file_name = args[file_name_position] - if (file_name is None and - secondary_file_name_position and - secondary_file_name_position < len(args)): # noqa E129 + if file_name is None and secondary_file_name_position and secondary_file_name_position < len(args): # noqa E129 file_name = args[secondary_file_name_position] return file_name + return pf @deprecated() - def wrap(self, fn, path, cache, - save_as_json=True, - return_dataframe=False, - index=None, - rename=None, - **kwargs): - '''make an rma query, save it and return the dataframe. + def wrap(self, fn, path, cache, save_as_json=True, return_dataframe=False, index=None, rename=None, **kwargs): + """make an rma query, save it and return the dataframe. Parameters ---------- @@ -552,7 +502,7 @@ def wrap(self, fn, path, cache, Notes ----- Column renaming happens after the file is reloaded for json - ''' + """ if cache is True: json_data = fn(**kwargs) @@ -570,7 +520,7 @@ def wrap(self, fn, path, cache, # read it back in if save_as_json is True: if return_dataframe is True: - data = pj.read_json(path, orient='records') + data = pj.read_json(path, orient="records") Cache.rename_columns(data, rename) if index is not None: data.set_index([index], inplace=True) @@ -579,20 +529,13 @@ def wrap(self, fn, path, cache, elif return_dataframe is True: data = pd.read_csv(path, parse_dates=True) else: - raise ValueError( - 'save_as_json=False cannot be used with ' - 'return_dataframe=False') + raise ValueError("save_as_json=False cannot be used with return_dataframe=False") return data -def cacheable(strategy=None, - pre=None, - writer=None, - reader=None, - post=None, - pathfinder=None): - '''decorator for rma queries, save it and return the dataframe. +def cacheable(strategy=None, pre=None, writer=None, reader=None, post=None, pathfinder=None): + """decorator for rma queries, save it and return the dataframe. Parameters ---------- @@ -625,7 +568,8 @@ def cacheable(strategy=None, Notes ----- Column renaming happens after the file is reloaded for json - ''' + """ + def decor(func): decor.strategy = strategy decor.pre = pre @@ -635,40 +579,35 @@ def decor(func): decor.pathfinder = pathfinder @functools.wraps(func) - def w(*args, - **kwargs): - if decor.pathfinder and 'pathfinder' not in kwargs: + def w(*args, **kwargs): + if decor.pathfinder and "pathfinder" not in kwargs: pathfinder = decor.pathfinder else: - pathfinder = kwargs.pop('pathfinder', None) + pathfinder = kwargs.pop("pathfinder", None) - if pathfinder and 'path' not in kwargs: + if pathfinder and "path" not in kwargs: found_path = pathfinder(*args, **kwargs) if found_path: - kwargs['path'] = found_path - if decor.strategy and 'strategy' not in kwargs: - kwargs['strategy'] = decor.strategy - if decor.pre and 'pre' not in kwargs: - kwargs['pre'] = decor.pre - if decor.writer and 'writer' not in kwargs: - kwargs['writer'] = decor.writer - if decor.reader and 'reader' not in kwargs: - kwargs['reader'] = decor.reader - if decor.post and not 'post in kwargs': - kwargs['post'] = decor.post - - result = Cache.cacher(func, - *args, - **kwargs) + kwargs["path"] = found_path + if decor.strategy and "strategy" not in kwargs: + kwargs["strategy"] = decor.strategy + if decor.pre and "pre" not in kwargs: + kwargs["pre"] = decor.pre + if decor.writer and "writer" not in kwargs: + kwargs["writer"] = decor.writer + if decor.reader and "reader" not in kwargs: + kwargs["reader"] = decor.reader + if decor.post and not "post in kwargs": + kwargs["post"] = decor.post + + result = Cache.cacher(func, *args, **kwargs) return result return w + return decor def get_default_manifest_file(cache_name): - return os.environ.get( - '{}_MANIFEST'.format(cache_name.upper()), - '{}/manifest.json'.format(cache_name.lower()) - ) + return os.environ.get("{}_MANIFEST".format(cache_name.upper()), "{}/manifest.json".format(cache_name.lower())) diff --git a/allensdk/api/warehouse_cache/caching_utilities.py b/allensdk/api/warehouse_cache/caching_utilities.py index 92dfeb8106..88b575a627 100644 --- a/allensdk/api/warehouse_cache/caching_utilities.py +++ b/allensdk/api/warehouse_cache/caching_utilities.py @@ -24,10 +24,9 @@ def call_caching( cleanup: Optional[Callable[[], None]] = None, lazy: bool = True, num_tries: int = 1, - failure_message: str = "" + failure_message: str = "", ) -> P: - """ Case where a reader is provided - """ + """Case where a reader is provided""" @overload @@ -39,10 +38,9 @@ def call_caching( cleanup: Optional[Callable[[], None]] = None, lazy: bool = True, num_tries: int = 1, - failure_message: str = "" + failure_message: str = "", ) -> None: - """ Case where no reader is provided (fetches and writes, but returns nothing) - """ + """Case where no reader is provided (fetches and writes, but returns nothing)""" def call_caching( @@ -53,34 +51,34 @@ def call_caching( cleanup: Optional[Callable[[], None]] = None, lazy: bool = True, num_tries: int = 1, - failure_message: str = "" + failure_message: str = "", ) -> Optional[P]: - """ Access data, caching on a local store for future accesses. + """Access data, caching on a local store for future accesses. Parameters ---------- fetch : Function which pulls data from a remote/expensive source. - write : + write : Function which stores data in a local/inexpensive store. read : Function which pulls data from a local/inexpensive store. pre_write : Function applied to obtained data after fetching, but before writing. cleanup : - Function for fixing a failed fetch. e.g. unlinking a partially - downloaded file. Exceptions raised by cleanup are not themselves + Function for fixing a failed fetch. e.g. unlinking a partially + downloaded file. Exceptions raised by cleanup are not themselves handled lazy : - If True, attempt to read the data from the local/inexpensive store - before fetching it. If False, forcibly fetch from the + If True, attempt to read the data from the local/inexpensive store + before fetching it. If False, forcibly fetch from the remote/expensive store. num_tries : - How many fetches to attempt before (re)raising an exception. A fetch + How many fetches to attempt before (re)raising an exception. A fetch is failed if reading the result raises an exception. failure_message : - Provides additional context in the event of a failed download. Emitted - when retrying, and when a fetch failure occurs after tries are + Provides additional context in the event of a failed download. Emitted + when retrying, and when a fetch failure occurs after tries are exhausted Returns @@ -106,8 +104,7 @@ def call_caching( if isinstance(e, FileNotFoundError): logger.info("No cache file found.") # Pandas throws ValueError rather than FileNotFoundError - elif (isinstance(e, ValueError) - and str(e) == "Expected object or value"): + elif isinstance(e, ValueError) and str(e) == "Expected object or value": logger.info("No cache file found.") if cleanup is not None and not lazy: cleanup() @@ -151,15 +148,16 @@ def one_file_call_caching( num_tries: int = 1, failure_message: str = "", ) -> Optional[P]: - """ A call_caching variant where the local store is a single file. See + """A call_caching variant where the local store is a single file. See call_caching for complete documentation. Parameters ---------- - path : + path : Path at which the data will be stored """ + def safe_unlink(): try: os.unlink(path) diff --git a/allensdk/brain_observatory/__init__.py b/allensdk/brain_observatory/__init__.py index 180adeeb9d..986cb3a93c 100644 --- a/allensdk/brain_observatory/__init__.py +++ b/allensdk/brain_observatory/__init__.py @@ -36,6 +36,7 @@ import numpy as np import sys + if sys.version_info < (3, 3): from collections import Iterable else: @@ -47,9 +48,9 @@ def dict_to_indexed_array(dc, order=None): - ''' Given a dictionary and an ordered arr, build a concatenation of the dictionary's values and an index describing + """Given a dictionary and an ordered arr, build a concatenation of the dictionary's values and an index describing how that concatenation can be unpacked - ''' + """ if order is None: order = dc.keys() @@ -59,7 +60,6 @@ def dict_to_indexed_array(dc, order=None): counter = 0 for key in order: - if isinstance(dc[key], (np.ndarray, list)): extended = dc[key] if isinstance(dc[key], Iterable): @@ -86,9 +86,9 @@ def default(self, o): def hook(json_dict): for key, value in json_dict.items(): - if key == 'experiment_date': + if key == "experiment_date": json_dict[key] = dateutil.parser.parse(value) - elif key == 'behavior_session_uuid': + elif key == "behavior_session_uuid": json_dict[key] = uuid.UUID(value) else: pass diff --git a/allensdk/brain_observatory/argschema_utilities.py b/allensdk/brain_observatory/argschema_utilities.py index e581fdc462..7cd082b53e 100644 --- a/allensdk/brain_observatory/argschema_utilities.py +++ b/allensdk/brain_observatory/argschema_utilities.py @@ -10,8 +10,8 @@ class InputFile(marshmallow.fields.String): """A marshmallow String field subclass which deserializes json str fields - that represent a desired input path to pathlib.Path. - Also performs read access checking. + that represent a desired input path to pathlib.Path. + Also performs read access checking. """ def _deserialize(self, value, attr, obj, **kwargs) -> pathlib.Path: @@ -26,8 +26,8 @@ def _validate(self, value: pathlib.Path): class OutputFile(marshmallow.fields.String): """A marshmallow String field subclass which deserializes json str fields - that represent a desired output file path to a pathlib.Path. - Also performs write access checking. + that represent a desired output file path to a pathlib.Path. + Also performs write access checking. """ def _deserialize(self, value, attr, obj, **kwargs) -> pathlib.Path: @@ -41,8 +41,8 @@ def _validate(self, value: pathlib.Path): def write_or_print_outputs(data, parser): - data.update({'input_parameters': parser.args}) - if 'output_json' in parser.args: + data.update({"input_parameters": parser.args}) + if "output_json" in parser.args: parser.output(data, indent=2) else: print(parser.get_output_json(data)) @@ -50,25 +50,23 @@ def write_or_print_outputs(data, parser): def check_write_access_dir(dirpath): if os.path.exists(dirpath): - test_filepath = pathlib.Path(dirpath, 'test_file.txt') + test_filepath = pathlib.Path(dirpath, "test_file.txt") try: with test_filepath.open() as _: pass os.remove(test_filepath) return True except PermissionError: - raise ValidationError( - f'don\'t have permissions to write in directory {dirpath}') + raise ValidationError(f"don't have permissions to write in directory {dirpath}") else: try: pathlib.Path(dirpath).mkdir(parents=True) pathlib.Path(dirpath).rmdir() return True except PermissionError: - raise ValidationError( - f'Can\'t build path to requested location {dirpath}') + raise ValidationError(f"Can't build path to requested location {dirpath}") - raise RuntimeError('Unhandled case; this should not happen') + raise RuntimeError("Unhandled case; this should not happen") def check_write_access(filepath, allow_exists=False): @@ -78,9 +76,8 @@ def check_write_access(filepath, allow_exists=False): os.remove(filepath) return True except FileExistsError: - if not allow_exists: - raise ValidationError(f'file at {filepath} already exists') + raise ValidationError(f"file at {filepath} already exists") else: return True @@ -90,7 +87,7 @@ def check_write_access(filepath, allow_exists=False): except Exception as e: raise e - raise RuntimeError('Unhandled case; this should not happen') + raise RuntimeError("Unhandled case; this should not happen") def check_write_access_overwrite(path): @@ -99,12 +96,11 @@ def check_write_access_overwrite(path): def check_read_access(path): try: - f = open(path, mode='r') + f = open(path, mode="r") f.close() return True except Exception as err: - raise ValidationError( - f'file at #{path} not readable (#{type(err)}: {err}') + raise ValidationError(f"file at #{path} not readable (#{type(err)}: {err}") class RaisingSchema(DefaultSchema): @@ -113,7 +109,6 @@ class Meta: class ArgSchemaParserPlus(ArgSchemaParser): # pragma: no cover - def __init__(self, *args, **kwargs): parser = argparse.ArgumentParser() [known_args, extra_args] = parser.parse_known_args() @@ -131,15 +126,11 @@ def optional_lims_inputs(argv, input_schema, output_schema, lims_input_getter): lims_parser.add_argument("--host", type=str, default="http://lims2") lims_parser.add_argument("--job_queue", type=str, default=None) lims_parser.add_argument("--strategy", type=str, default=None) - lims_parser.add_argument("--ecephys_session_id", type=int, - default=None) + lims_parser.add_argument("--ecephys_session_id", type=int, default=None) lims_parser.add_argument("--output_root", type=str, default=None) - lims_args, remaining_args = lims_parser.parse_known_args( - remaining_args) - remaining_args = [ - item for item in remaining_args if item != "--get_inputs_from_lims" - ] + lims_args, remaining_args = lims_parser.parse_known_args(remaining_args) + remaining_args = [item for item in remaining_args if item != "--get_inputs_from_lims"] input_data = lims_input_getter(**lims_args.__dict__) try: diff --git a/allensdk/brain_observatory/behavior/__init__.py b/allensdk/brain_observatory/behavior/__init__.py index 0ac0fbbaf3..494ca0dc51 100644 --- a/allensdk/brain_observatory/behavior/__init__.py +++ b/allensdk/brain_observatory/behavior/__init__.py @@ -1,11 +1,11 @@ +IMAGE_SETS = { + "Natural_Images_Lum_Matched_set_ophys_6_2017.07.14": "//allen/programs/braintv/workgroups/nc-ophys/Doug/Stimulus_Code/image_dictionaries/Natural_Images_Lum_Matched_set_ophys_6_2017.07.14.pkl", + "Natural_Images_Lum_Matched_set_training_2017.07.14": "//allen/programs/braintv/workgroups/nc-ophys/Doug/Stimulus_Code/image_dictionaries/Natural_Images_Lum_Matched_set_training_2017.07.14.pkl", + "Natural_Images_Lum_Matched_set_training_2017.07.14_2": "//allen/programs/braintv/workgroups/nc-ophys/visual_behavior/image_dictionaries/Natural_Images_Lum_Matched_set_training_2017.07.14.pkl", + "Natural_Images_Lum_Matched_set_ophys_6_2017.07.14_2": "//allen/programs/braintv/workgroups/nc-ophys/visual_behavior/image_dictionaries/Natural_Images_Lum_Matched_set_ophys_6_2017.07.14.pkl", + "Natural_Images_Lum_Matched_set_ophys_H_2019.05.26": "//allen/programs/braintv/workgroups/nc-ophys/visual_behavior/image_dictionaries/Natural_Images_Lum_Matched_set_ophys_H_2019.05.26.pkl", + "Natural_Images_Lum_Matched_set_ophys_G_2019.05.26": "//allen/programs/braintv/workgroups/nc-ophys/visual_behavior/image_dictionaries/Natural_Images_Lum_Matched_set_ophys_G_2019.05.26.pkl", +} -IMAGE_SETS = {'Natural_Images_Lum_Matched_set_ophys_6_2017.07.14': '//allen/programs/braintv/workgroups/nc-ophys/Doug/Stimulus_Code/image_dictionaries/Natural_Images_Lum_Matched_set_ophys_6_2017.07.14.pkl', - 'Natural_Images_Lum_Matched_set_training_2017.07.14': '//allen/programs/braintv/workgroups/nc-ophys/Doug/Stimulus_Code/image_dictionaries/Natural_Images_Lum_Matched_set_training_2017.07.14.pkl', - 'Natural_Images_Lum_Matched_set_training_2017.07.14_2': '//allen/programs/braintv/workgroups/nc-ophys/visual_behavior/image_dictionaries/Natural_Images_Lum_Matched_set_training_2017.07.14.pkl', - 'Natural_Images_Lum_Matched_set_ophys_6_2017.07.14_2': '//allen/programs/braintv/workgroups/nc-ophys/visual_behavior/image_dictionaries/Natural_Images_Lum_Matched_set_ophys_6_2017.07.14.pkl', - 'Natural_Images_Lum_Matched_set_ophys_H_2019.05.26': '//allen/programs/braintv/workgroups/nc-ophys/visual_behavior/image_dictionaries/Natural_Images_Lum_Matched_set_ophys_H_2019.05.26.pkl', - 'Natural_Images_Lum_Matched_set_ophys_G_2019.05.26': '//allen/programs/braintv/workgroups/nc-ophys/visual_behavior/image_dictionaries/Natural_Images_Lum_Matched_set_ophys_G_2019.05.26.pkl'} - - -assert len(IMAGE_SETS) == len(set(IMAGE_SETS.keys())) == len(set(IMAGE_SETS.values())) \ No newline at end of file +assert len(IMAGE_SETS) == len(set(IMAGE_SETS.keys())) == len(set(IMAGE_SETS.values())) diff --git a/allensdk/brain_observatory/behavior/behavior_ophys_analysis.py b/allensdk/brain_observatory/behavior/behavior_ophys_analysis.py index d4867d7f42..f305af706c 100644 --- a/allensdk/brain_observatory/behavior/behavior_ophys_analysis.py +++ b/allensdk/brain_observatory/behavior/behavior_ophys_analysis.py @@ -3,10 +3,10 @@ import seaborn as sns from allensdk.core.lazy_property import LazyPropertyMixin -from allensdk.brain_observatory.behavior.behavior_ophys_experiment import \ - BehaviorOphysExperiment +from allensdk.brain_observatory.behavior.behavior_ophys_experiment import BehaviorOphysExperiment -def plot_trace(timestamps, trace, ax=None, xlabel='time (seconds)', ylabel='fluorescence', title='roi'): + +def plot_trace(timestamps, trace, ax=None, xlabel="time (seconds)", ylabel="fluorescence", title="roi"): if ax is None: fig, ax = plt.subplots(figsize=(15, 5)) colors = sns.color_palette() @@ -18,9 +18,10 @@ def plot_trace(timestamps, trace, ax=None, xlabel='time (seconds)', ylabel='fluo return ax -def plot_example_traces_and_behavior(dataset, cell_roi_ids, xmin_seconds, length_mins, save_dir=None, - include_running=False, cell_label=False): - suffix = '' +def plot_example_traces_and_behavior( + dataset, cell_roi_ids, xmin_seconds, length_mins, save_dir=None, include_running=False, cell_label=False +): + suffix = "" if include_running: n = 2 else: @@ -36,19 +37,18 @@ def plot_example_traces_and_behavior(dataset, cell_roi_ids, xmin_seconds, length ymins = [] ymaxs = [] for i, cell_roi_id in enumerate(cell_roi_ids): - trace = dataset.dff_traces[dataset.dff_traces['cell_roi_id']==cell_roi_id]['dff'].values[0] - ax[i] = plot_trace(dataset.ophys_timestamps, trace, ax=ax[i], - title='', ylabel=str(cell_roi_id)) + trace = dataset.dff_traces[dataset.dff_traces["cell_roi_id"] == cell_roi_id]["dff"].values[0] + ax[i] = plot_trace(dataset.ophys_timestamps, trace, ax=ax[i], title="", ylabel=str(cell_roi_id)) ax[i] = add_stim_color_span(dataset, ax=ax[i], xlim=xlim) ax[i] = restrict_axes(xmin_seconds, xmax_seconds, interval_seconds, ax=ax[i]) - ax[i].set_xlabel('') + ax[i].set_xlabel("") ymin, ymax = ax[i].get_ylim() ymins.append(ymin) ymaxs.append(ymax) if cell_label: ax[i].set_ylabel(str(cell_index)) else: - ax[i].set_ylabel('dF/F') + ax[i].set_ylabel("dF/F") sns.despine(ax=ax[i]) for i, cell_roi_id in enumerate(cell_roi_ids): @@ -59,9 +59,9 @@ def plot_example_traces_and_behavior(dataset, cell_roi_ids, xmin_seconds, length ax[i] = plot_behavior_events(dataset, ax=ax[i], behavior_only=True) ax[i] = add_stim_color_span(dataset, ax=ax[i], xlim=xlim) ax[i].set_xlim(xlim) - ax[i].set_ylabel('') + ax[i].set_ylabel("") ax[i].axes.get_yaxis().set_visible(False) - ax[i].legend(loc='upper left', fontsize=14) + ax[i].legend(loc="upper left", fontsize=14) sns.despine(ax=ax[i]) if include_running: @@ -69,47 +69,43 @@ def plot_example_traces_and_behavior(dataset, cell_roi_ids, xmin_seconds, length ax[i].plot(dataset.stimulus_timestamps, dataset.running_speed) ax[i] = add_stim_color_span(dataset, ax=ax[i], xlim=xlim) ax[i] = restrict_axes(xmin_seconds, xmax_seconds, interval_seconds, ax=ax[i]) - ax[i].set_ylabel('run speed\n(cm/s)') + ax[i].set_ylabel("run speed\n(cm/s)") # ax[i].axes.get_yaxis().set_visible(False) sns.despine(ax=ax[i]) - ax[i].set_xlabel('time (seconds)') + ax[i].set_xlabel("time (seconds)") ax[0].set_title(dataset.ophys_experiment_id) fig.tight_layout() plt.subplots_adjust(wspace=0, hspace=0) if save_dir is not None: - save_figure(fig, figsize, save_dir, 'example_traces', 'example_traces_' + str(xlim[0]) + suffix) - save_figure(fig, figsize, save_dir, 'example_traces', - str(dataset.ophys_experiment_id) + '_' + str(xlim[0]) + suffix) + save_figure(fig, figsize, save_dir, "example_traces", "example_traces_" + str(xlim[0]) + suffix) + save_figure( + fig, figsize, save_dir, "example_traces", str(dataset.ophys_experiment_id) + "_" + str(xlim[0]) + suffix + ) plt.close() -class BehaviorOphysAnalysis(LazyPropertyMixin): +class BehaviorOphysAnalysis(LazyPropertyMixin): def __init__(self, session, api=None): - self.session = session - self.api = self if api is None else api + self.api = self if api is None else api # self.active_cell_roi_ids = LazyProperty(self.api.get_active_cell_roi_ids, ophys_experiment_id=self.ophys_experiment_id) - - - - def plot_example_traces_and_behavior(self, N=10): dff_traces_df = self.session.dff_traces - dff_traces_df['mean'] = dff_traces_df['dff'].apply(np.mean) - dff_traces_df['std'] = dff_traces_df['dff'].apply(np.std) - dff_traces_df['snr'] = dff_traces_df['mean']/dff_traces_df['std'] - active_cell_roi_ids = dff_traces_df.sort_values('snr', ascending=False)['cell_roi_id'].values[:N] + dff_traces_df["mean"] = dff_traces_df["dff"].apply(np.mean) + dff_traces_df["std"] = dff_traces_df["dff"].apply(np.std) + dff_traces_df["snr"] = dff_traces_df["mean"] / dff_traces_df["std"] + active_cell_roi_ids = dff_traces_df.sort_values("snr", ascending=False)["cell_roi_id"].values[:N] length_mins = 1 for xmin_seconds in np.arange(0, 5000, length_mins * 60): - plot_example_traces_and_behavior(self.session, active_cell_roi_ids, xmin_seconds, length_mins, cell_label=False, include_running=True) + plot_example_traces_and_behavior( + self.session, active_cell_roi_ids, xmin_seconds, length_mins, cell_label=False, include_running=True + ) if __name__ == "__main__": - session = BehaviorOphysExperiment(789359614) analysis = BehaviorOphysAnalysis(session) analysis.plot_example_traces_and_behavior() - diff --git a/allensdk/brain_observatory/behavior/behavior_ophys_experiment.py b/allensdk/brain_observatory/behavior/behavior_ophys_experiment.py index 54283948cc..ae97b4f7e2 100644 --- a/allensdk/brain_observatory/behavior/behavior_ophys_experiment.py +++ b/allensdk/brain_observatory/behavior/behavior_ophys_experiment.py @@ -75,9 +75,7 @@ def __init__( task_parameters=behavior_session._task_parameters, trials=behavior_session._trials, date_of_acquisition=date_of_acquisition, - eye_tracking_rig_geometry=( - behavior_session._eye_tracking_rig_geometry - ), + eye_tracking_rig_geometry=(behavior_session._eye_tracking_rig_geometry), eye_tracking_table=behavior_session._eye_tracking, ) @@ -92,9 +90,7 @@ def to_nwb(self) -> NWBFile: self._metadata.to_nwb(nwbfile=nwbfile) self._projections.to_nwb(nwbfile=nwbfile) - self._cell_specimens.to_nwb( - nwbfile=nwbfile, ophys_timestamps=self._ophys_timestamps - ) + self._cell_specimens.to_nwb(nwbfile=nwbfile, ophys_timestamps=self._ophys_timestamps) self._motion_correction.to_nwb(nwbfile=nwbfile) return nwbfile @@ -136,21 +132,15 @@ def _is_multi_plane_session(): imaging_plane_group_meta = ImagingPlaneGroup.from_lims( ophys_experiment_id=ophys_experiment_id, lims_db=lims_db ) - return cls._is_multi_plane_session( - imaging_plane_group_meta=imaging_plane_group_meta - ) + return cls._is_multi_plane_session(imaging_plane_group_meta=imaging_plane_group_meta) def _get_motion_correction(): rigid_motion_transform_file = RigidMotionTransformFile.from_lims( ophys_experiment_id=ophys_experiment_id, db=lims_db ) - return MotionCorrection.from_data_file( - rigid_motion_transform_file=rigid_motion_transform_file - ) + return MotionCorrection.from_data_file(rigid_motion_transform_file=rigid_motion_transform_file) - lims_db = db_connection_creator( - fallback_credentials=LIMS_DB_CREDENTIAL_MAP - ) + lims_db = db_connection_creator(fallback_credentials=LIMS_DB_CREDENTIAL_MAP) behavior_session_id = BehaviorSessionId.from_lims( db=lims_db, @@ -165,17 +155,11 @@ def _get_motion_correction(): is_multiplane=is_multiplane_session, ) - sync_file = SyncFile.from_lims( - db=lims_db, behavior_session_id=behavior_session_id.value - ) + sync_file = SyncFile.from_lims(db=lims_db, behavior_session_id=behavior_session_id.value) - monitor_delay = calculate_monitor_delay( - sync_file=sync_file, equipment=meta.behavior_metadata.equipment - ) + monitor_delay = calculate_monitor_delay(sync_file=sync_file, equipment=meta.behavior_metadata.equipment) - date_of_acquisition = DateOfAcquisitionOphys.from_lims( - ophys_experiment_id=ophys_experiment_id, lims_db=lims_db - ) + date_of_acquisition = DateOfAcquisitionOphys.from_lims(ophys_experiment_id=ophys_experiment_id, lims_db=lims_db) behavior_session = BehaviorSession.from_lims( lims_db=lims_db, behavior_session_id=behavior_session_id.value, @@ -192,13 +176,9 @@ def _get_motion_correction(): plane_group=meta.ophys_metadata.imaging_plane_group, ) else: - ophys_timestamps = OphysTimestamps.from_sync_file( - sync_file=sync_file - ) + ophys_timestamps = OphysTimestamps.from_sync_file(sync_file=sync_file) - projections = Projections.from_lims( - ophys_experiment_id=ophys_experiment_id, lims_db=lims_db - ) + projections = Projections.from_lims(ophys_experiment_id=ophys_experiment_id, lims_db=lims_db) cell_specimens = CellSpecimens.from_lims( ophys_experiment_id=ophys_experiment_id, lims_db=lims_db, @@ -257,12 +237,8 @@ def from_nwb( """ def _is_multi_plane_session(): - imaging_plane_group_meta = ImagingPlaneGroup.from_nwb( - nwbfile=nwbfile - ) - return cls._is_multi_plane_session( - imaging_plane_group_meta=imaging_plane_group_meta - ) + imaging_plane_group_meta = ImagingPlaneGroup.from_nwb(nwbfile=nwbfile) + return cls._is_multi_plane_session(imaging_plane_group_meta=imaging_plane_group_meta) behavior_session = BehaviorSession.from_nwb(nwbfile=nwbfile) projections = Projections.from_nwb(nwbfile=nwbfile) @@ -277,13 +253,9 @@ def _is_multi_plane_session(): ) motion_correction = MotionCorrection.from_nwb(nwbfile=nwbfile) is_multiplane_session = _is_multi_plane_session() - metadata = BehaviorOphysMetadata.from_nwb( - nwbfile=nwbfile, is_multiplane=is_multiplane_session - ) + metadata = BehaviorOphysMetadata.from_nwb(nwbfile=nwbfile, is_multiplane=is_multiplane_session) if is_multiplane_session: - ophys_timestamps = OphysTimestampsMultiplane.from_nwb( - nwbfile=nwbfile - ) + ophys_timestamps = OphysTimestampsMultiplane.from_nwb(nwbfile=nwbfile) else: ophys_timestamps = OphysTimestamps.from_nwb(nwbfile=nwbfile) date_of_acquisition = DateOfAcquisitionOphys.from_nwb(nwbfile=nwbfile) @@ -327,30 +299,18 @@ def from_json( """ def _is_multi_plane_session(): - imaging_plane_group_meta = ImagingPlaneGroup.from_json( - dict_repr=session_data - ) - return cls._is_multi_plane_session( - imaging_plane_group_meta=imaging_plane_group_meta - ) + imaging_plane_group_meta = ImagingPlaneGroup.from_json(dict_repr=session_data) + return cls._is_multi_plane_session(imaging_plane_group_meta=imaging_plane_group_meta) def _get_motion_correction(): - rigid_motion_transform_file = RigidMotionTransformFile.from_json( - dict_repr=session_data - ) - return MotionCorrection.from_data_file( - rigid_motion_transform_file=rigid_motion_transform_file - ) + rigid_motion_transform_file = RigidMotionTransformFile.from_json(dict_repr=session_data) + return MotionCorrection.from_data_file(rigid_motion_transform_file=rigid_motion_transform_file) sync_file = SyncFile.from_json(dict_repr=session_data) is_multiplane_session = _is_multi_plane_session() - meta = BehaviorOphysMetadata.from_json( - dict_repr=session_data, is_multiplane=is_multiplane_session - ) + meta = BehaviorOphysMetadata.from_json(dict_repr=session_data, is_multiplane=is_multiplane_session) if "monitor_delay" not in session_data: - monitor_delay = calculate_monitor_delay( - sync_file=sync_file, equipment=meta.behavior_metadata.equipment - ) + monitor_delay = calculate_monitor_delay(sync_file=sync_file, equipment=meta.behavior_metadata.equipment) session_data["monitor_delay"] = monitor_delay behavior_session = BehaviorSession.from_json(session_data=session_data) @@ -362,9 +322,7 @@ def _get_motion_correction(): plane_group=meta.ophys_metadata.imaging_plane_group, ) else: - ophys_timestamps = OphysTimestamps.from_sync_file( - sync_file=sync_file - ) + ophys_timestamps = OphysTimestamps.from_sync_file(sync_file=sync_file) projections = Projections.from_json(dict_repr=session_data) cell_specimens = CellSpecimens.from_json( @@ -402,12 +360,8 @@ def update_targeted_imaging_depth(self, ophys_experiment_ids: List[int]): Subset of experiments sharing the same container as the experiment being loaded in this object. """ - lims_db = db_connection_creator( - fallback_credentials=LIMS_DB_CREDENTIAL_MAP - ) - self._metadata.update_targeted_imaging_depth( - ophys_experiment_ids, lims_db=lims_db - ) + lims_db = db_connection_creator(fallback_credentials=LIMS_DB_CREDENTIAL_MAP) + self._metadata.update_targeted_imaging_depth(ophys_experiment_ids, lims_db=lims_db) # ========================= 'get' methods ========================== @@ -427,12 +381,8 @@ def get_dff_traces(self, cell_specimen_ids=None): if cell_specimen_ids is None: cell_specimen_ids = self.get_cell_specimen_ids() - csid_table = self.cell_specimen_table.reset_index()[ - ["cell_specimen_id"] - ] - csid_subtable = csid_table[ - csid_table["cell_specimen_id"].isin(cell_specimen_ids) - ].set_index("cell_specimen_id") + csid_table = self.cell_specimen_table.reset_index()[["cell_specimen_id"]] + csid_subtable = csid_table[csid_table["cell_specimen_id"].isin(cell_specimen_ids)].set_index("cell_specimen_id") dff_table = csid_subtable.join(self.dff_traces, how="left") dff_traces = np.vstack(dff_table["dff"].values) timestamps = self.ophys_timestamps @@ -442,22 +392,14 @@ def get_dff_traces(self, cell_specimen_ids=None): @legacy() def get_cell_specimen_indices(self, cell_specimen_ids): - return [ - self.cell_specimen_table.index.get_loc(csid) - for csid in cell_specimen_ids - ] + return [self.cell_specimen_table.index.get_loc(csid) for csid in cell_specimen_ids] @legacy("Consider using cell_specimen_table['cell_specimen_id'] instead.") def get_cell_specimen_ids(self): cell_specimen_ids = self.cell_specimen_table.index.values - if np.isnan(cell_specimen_ids.astype(float)).sum() == len( - self.cell_specimen_table - ): - raise ValueError( - "cell_specimen_id values not assigned " - f"for {self.ophys_experiment_id}" - ) + if np.isnan(cell_specimen_ids.astype(float)).sum() == len(self.cell_specimen_table): + raise ValueError(f"cell_specimen_id values not assigned for {self.ophys_experiment_id}") return cell_specimen_ids # ====================== properties ======================== @@ -478,9 +420,7 @@ def ophys_session_id(self) -> int: @property def metadata(self): - behavior_meta = super()._get_metadata( - behavior_metadata=self._metadata.behavior_metadata - ) + behavior_meta = super()._get_metadata(behavior_metadata=self._metadata.behavior_metadata) ophys_meta = { "indicator": self._cell_specimens.meta.imaging_plane.indicator, "emission_lambda": self._cell_specimens.meta.emission_lambda, @@ -744,10 +684,7 @@ def _is_multi_plane_session( imaging_plane_group_meta: ImagingPlaneGroup, ) -> bool: """Returns whether this experiment is part of a multiplane session""" - return ( - imaging_plane_group_meta is not None - and imaging_plane_group_meta.plane_group_count > 1 - ) + return imaging_plane_group_meta is not None and imaging_plane_group_meta.plane_group_count > 1 def _get_session_type(self) -> str: return self._metadata.behavior_metadata.session_type diff --git a/allensdk/brain_observatory/behavior/behavior_ophys_session.py b/allensdk/brain_observatory/behavior/behavior_ophys_session.py index 30186cfb48..810cb04711 100644 --- a/allensdk/brain_observatory/behavior/behavior_ophys_session.py +++ b/allensdk/brain_observatory/behavior/behavior_ophys_session.py @@ -1,11 +1,11 @@ import warnings -from allensdk.brain_observatory.behavior.behavior_ophys_experiment import \ - BehaviorOphysExperiment as BOE +from allensdk.brain_observatory.behavior.behavior_ophys_experiment import BehaviorOphysExperiment as BOE # alias as BOE prevents someone becoming comfortable with # import BehaviorOphysExperiment from this to-be-deprecated module + class BehaviorOphysSession(BOE): def __init__(self, **kwargs): warnings.warn( @@ -14,5 +14,6 @@ def __init__(self, **kwargs): "allensdk.brain_observatory.behavior.behavior_ophys_experiment." "BehaviorOphysExperiment.", DeprecationWarning, - stacklevel=3) + stacklevel=3, + ) super().__init__(**kwargs) diff --git a/allensdk/brain_observatory/behavior/behavior_project_cache/__init__.py b/allensdk/brain_observatory/behavior/behavior_project_cache/__init__.py index 3d898c3558..198c6d8a37 100644 --- a/allensdk/brain_observatory/behavior/behavior_project_cache/__init__.py +++ b/allensdk/brain_observatory/behavior/behavior_project_cache/__init__.py @@ -1,6 +1,7 @@ -from allensdk.brain_observatory.behavior.behavior_project_cache.\ - behavior_project_cache import VisualBehaviorOphysProjectCache # noqa F401 +from allensdk.brain_observatory.behavior.behavior_project_cache.behavior_project_cache import ( + VisualBehaviorOphysProjectCache, +) # noqa F401 -from allensdk.brain_observatory.behavior.behavior_project_cache.\ - behavior_neuropixels_project_cache import \ - VisualBehaviorNeuropixelsProjectCache # noqa F401 +from allensdk.brain_observatory.behavior.behavior_project_cache.behavior_neuropixels_project_cache import ( + VisualBehaviorNeuropixelsProjectCache, +) # noqa F401 diff --git a/allensdk/brain_observatory/behavior/behavior_project_cache/behavior_neuropixels_project_cache.py b/allensdk/brain_observatory/behavior/behavior_project_cache/behavior_neuropixels_project_cache.py index 3e0355da51..0141a7bfc4 100644 --- a/allensdk/brain_observatory/behavior/behavior_project_cache/behavior_neuropixels_project_cache.py +++ b/allensdk/brain_observatory/behavior/behavior_project_cache/behavior_neuropixels_project_cache.py @@ -1,18 +1,16 @@ import numpy as np import pandas as pd -from allensdk.brain_observatory.behavior.behavior_project_cache.\ - project_apis.data_io import VisualBehaviorNeuropixelsProjectCloudApi -from allensdk.brain_observatory.behavior.behavior_project_cache.\ - project_cache_base import ProjectCacheBase -from allensdk.brain_observatory.behavior.behavior_session import \ - BehaviorSession -from allensdk.brain_observatory.ecephys.behavior_ecephys_session import \ - BehaviorEcephysSession +from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.data_io import ( + VisualBehaviorNeuropixelsProjectCloudApi, +) +from allensdk.brain_observatory.behavior.behavior_project_cache.project_cache_base import ProjectCacheBase +from allensdk.brain_observatory.behavior.behavior_session import BehaviorSession +from allensdk.brain_observatory.ecephys.behavior_ecephys_session import BehaviorEcephysSession class VisualBehaviorNeuropixelsProjectCache(ProjectCacheBase): - """ Entrypoint for accessing Visual Behavior Neuropixels data. + """Entrypoint for accessing Visual Behavior Neuropixels data. Supports access to metadata tables: get_ecephys_session_table() @@ -35,21 +33,14 @@ class VisualBehaviorNeuropixelsProjectCache(ProjectCacheBase): PROJECT_NAME = "visual-behavior-neuropixels" BUCKET_NAME = "visual-behavior-neuropixels-data" - def __init__( - self, - fetch_api: VisualBehaviorNeuropixelsProjectCloudApi, - fetch_tries: int = 2 - ): + def __init__(self, fetch_api: VisualBehaviorNeuropixelsProjectCloudApi, fetch_tries: int = 2): super().__init__(fetch_api=fetch_api, fetch_tries=fetch_tries) @classmethod def cloud_api_class(cls): return VisualBehaviorNeuropixelsProjectCloudApi - def get_ecephys_session_table( - self, - filter_abnormalities: bool = True - ) -> pd.DataFrame: + def get_ecephys_session_table(self, filter_abnormalities: bool = True) -> pd.DataFrame: """ Parameters ---------- @@ -69,9 +60,8 @@ def get_ecephys_session_table( if filter_abnormalities: sessions_table = sessions_table.loc[ - np.logical_and( - sessions_table.abnormal_histology.isna(), - sessions_table.abnormal_activity.isna())] + np.logical_and(sessions_table.abnormal_histology.isna(), sessions_table.abnormal_activity.isna()) + ] return sessions_table @@ -126,10 +116,7 @@ def get_unit_table(self) -> pd.DataFrame: """ return self.fetch_api.get_unit_table() - def get_ecephys_session( - self, - ecephys_session_id: int - ) -> BehaviorEcephysSession: + def get_ecephys_session(self, ecephys_session_id: int) -> BehaviorEcephysSession: """ Loads all data for `ecephys_session_id` into an `allensdk.ecephys.behavior_ecephys_session.BehaviorEcephysSession` @@ -148,10 +135,7 @@ def get_ecephys_session( """ return self.fetch_api.get_ecephys_session(ecephys_session_id) - def get_behavior_session( - self, - behavior_session_id: int - ) -> BehaviorSession: + def get_behavior_session(self, behavior_session_id: int) -> BehaviorSession: """ Loads all data for `behavior_session_id` into an `allensdk.brain_observatory.behavior.behavior_session.BehaviorSession` diff --git a/allensdk/brain_observatory/behavior/behavior_project_cache/behavior_project_cache.py b/allensdk/brain_observatory/behavior/behavior_project_cache/behavior_project_cache.py index 6b6aa24942..6e651ee02d 100644 --- a/allensdk/brain_observatory/behavior/behavior_project_cache/behavior_project_cache.py +++ b/allensdk/brain_observatory/behavior/behavior_project_cache/behavior_project_cache.py @@ -78,9 +78,7 @@ class VisualBehaviorOphysProjectCache(ProjectCacheBase): def __init__( self, - fetch_api: Optional[ - Union[BehaviorProjectLimsApi, BehaviorProjectCloudApi] - ] = None, + fetch_api: Optional[Union[BehaviorProjectLimsApi, BehaviorProjectCloudApi]] = None, fetch_tries: int = 2, manifest: Optional[Union[str, Path]] = None, version: Optional[str] = None, @@ -128,9 +126,7 @@ def __init__( if not isinstance(self.fetch_api, BehaviorProjectCloudApi): if cache: - self.cache = VBOLimsCache( - manifest=manifest_, version=version, cache=cache - ) + self.cache = VBOLimsCache(manifest=manifest_, version=version, cache=cache) warnings.warn( message="\n\tAs of AllenSDK version 2.16.0, the latest Visual " @@ -183,9 +179,7 @@ def get_ophys_session_table( if isinstance(self.fetch_api, BehaviorProjectCloudApi): return self.fetch_api.get_ophys_session_table() if self.cache is not None: - path = self.cache.get_cache_path( - None, self.cache.OPHYS_SESSIONS_KEY - ) + path = self.cache.get_cache_path(None, self.cache.OPHYS_SESSIONS_KEY) ophys_sessions = one_file_call_caching( path, self.fetch_api.get_ophys_session_table, @@ -210,9 +204,7 @@ def get_ophys_session_table( suffixes=("_behavior", "_ophys"), ) - sessions = BehaviorOphysSessionsTable( - df=ophys_sessions, suppress=suppress, index_column=index_column - ) + sessions = BehaviorOphysSessionsTable(df=ophys_sessions, suppress=suppress, index_column=index_column) return sessions.table if as_df else sessions @@ -230,16 +222,12 @@ def get_ophys_experiment_table( if isinstance(self.fetch_api, BehaviorProjectCloudApi): return self.fetch_api.get_ophys_experiment_table() if self.cache is not None: - path = self.cache.get_cache_path( - None, self.cache.OPHYS_EXPERIMENTS_KEY - ) + path = self.cache.get_cache_path(None, self.cache.OPHYS_EXPERIMENTS_KEY) experiments = one_file_call_caching( path, self.fetch_api.get_ophys_experiment_table, _write_json, - lambda path: _read_json( - path, index_name="ophys_experiment_id" - ), + lambda path: _read_json(path, index_name="ophys_experiment_id"), ) else: experiments = self.fetch_api.get_ophys_experiment_table() @@ -304,16 +292,12 @@ def get_behavior_session_table( if isinstance(self.fetch_api, BehaviorProjectCloudApi): return self.fetch_api.get_behavior_session_table() if self.cache is not None: - path = self.cache.get_cache_path( - None, self.cache.BEHAVIOR_SESSIONS_KEY - ) + path = self.cache.get_cache_path(None, self.cache.BEHAVIOR_SESSIONS_KEY) sessions = one_file_call_caching( path, self.fetch_api.get_behavior_session_table, _write_json, - lambda path: _read_json( - path, index_name="behavior_session_id" - ), + lambda path: _read_json(path, index_name="behavior_session_id"), ) else: sessions = self.fetch_api.get_behavior_session_table() @@ -334,9 +318,7 @@ def get_behavior_session_table( return sessions.table if as_df else sessions - def get_behavior_ophys_experiment( - self, ophys_experiment_id: int - ) -> BehaviorOphysExperiment: + def get_behavior_ophys_experiment(self, ophys_experiment_id: int) -> BehaviorOphysExperiment: """ Gets `BehaviorOphysExperiment` for `ophys_experiment_id` Parameters @@ -347,13 +329,9 @@ def get_behavior_ophys_experiment( ------- BehaviorOphysExperiment """ - return self.fetch_api.get_behavior_ophys_experiment( - ophys_experiment_id=ophys_experiment_id - ) + return self.fetch_api.get_behavior_ophys_experiment(ophys_experiment_id=ophys_experiment_id) - def get_behavior_session( - self, behavior_session_id: int - ) -> BehaviorSession: + def get_behavior_session(self, behavior_session_id: int) -> BehaviorSession: """ Gets `BehaviorSession` for `behavior_session_id` Parameters @@ -364,9 +342,7 @@ def get_behavior_session( ------- BehaviorSession """ - return self.fetch_api.get_behavior_session( - behavior_session_id=behavior_session_id - ) + return self.fetch_api.get_behavior_session(behavior_session_id=behavior_session_id) def get_raw_natural_movie(self) -> np.ndarray: """Download the raw movie data from the cloud and return it as a numpy diff --git a/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/abcs/__init__.py b/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/abcs/__init__.py index 53b4621b76..8a46b05196 100644 --- a/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/abcs/__init__.py +++ b/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/abcs/__init__.py @@ -1 +1,3 @@ -from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.abcs.behavior_project_base import BehaviorProjectBase # noqa: F401, E501 +from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.abcs.behavior_project_base import ( + BehaviorProjectBase, +) # noqa: F401, E501 diff --git a/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/abcs/behavior_project_base.py b/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/abcs/behavior_project_base.py index 7f701aa27a..27b31c93ed 100644 --- a/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/abcs/behavior_project_base.py +++ b/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/abcs/behavior_project_base.py @@ -1,17 +1,14 @@ from abc import ABC, abstractmethod from typing import Iterable -from allensdk.brain_observatory.behavior.behavior_ophys_experiment import ( - BehaviorOphysExperiment) -from allensdk.brain_observatory.behavior.behavior_session import ( - BehaviorSession) +from allensdk.brain_observatory.behavior.behavior_ophys_experiment import BehaviorOphysExperiment +from allensdk.brain_observatory.behavior.behavior_session import BehaviorSession import pandas as pd class BehaviorProjectBase(ABC): @abstractmethod - def get_behavior_ophys_experiment(self, ophys_experiment_id: int - ) -> BehaviorOphysExperiment: + def get_behavior_ophys_experiment(self, ophys_experiment_id: int) -> BehaviorOphysExperiment: """Returns a BehaviorOphysExperiment object that contains methods to analyze a single behavior+ophys session. :param ophys_experiment_id: id that corresponds to an ophys experiment @@ -27,8 +24,7 @@ def get_ophys_session_table(self) -> pd.DataFrame: pass @abstractmethod - def get_behavior_session( - self, behavior_session_id: int) -> BehaviorSession: + def get_behavior_session(self, behavior_session_id: int) -> BehaviorSession: """Returns a BehaviorSession object that contains methods to analyze a single behavior session. :param behavior_session_id: id that corresponds to a behavior session @@ -47,7 +43,7 @@ def get_behavior_session_table(self) -> pd.DataFrame: @abstractmethod def get_natural_movie_template(self, number: int) -> Iterable[bytes]: - """ Download a template for the natural movie stimulus. This is the + """Download a template for the natural movie stimulus. This is the actual movie that was shown during the recording session. :param number: identifier for this scene :type number: int diff --git a/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/data_io/__init__.py b/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/data_io/__init__.py index 907699e591..e72177b134 100644 --- a/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/data_io/__init__.py +++ b/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/data_io/__init__.py @@ -1,5 +1,15 @@ -from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.data_io.behavior_project_lims_api import BehaviorProjectLimsApi # noqa: F401, E501 -from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.data_io.behavior_project_cloud_api import BehaviorProjectCloudApi # noqa: F401, E501 -from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.data_io.behavior_neuropixels_project_cloud_api import VisualBehaviorNeuropixelsProjectCloudApi # noqa: F401, E501 -from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.data_io.behavior_neuropixels_project_cloud_api import ProjectCloudApiBase # noqa: F401, E501 -from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.data_io.natural_movie_one_cache import NaturalMovieOneCache # noqa: F401, E501 +from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.data_io.behavior_project_lims_api import ( + BehaviorProjectLimsApi, +) # noqa: F401, E501 +from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.data_io.behavior_project_cloud_api import ( + BehaviorProjectCloudApi, +) # noqa: F401, E501 +from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.data_io.behavior_neuropixels_project_cloud_api import ( + VisualBehaviorNeuropixelsProjectCloudApi, +) # noqa: F401, E501 +from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.data_io.behavior_neuropixels_project_cloud_api import ( + ProjectCloudApiBase, +) # noqa: F401, E501 +from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.data_io.natural_movie_one_cache import ( + NaturalMovieOneCache, +) # noqa: F401, E501 diff --git a/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/data_io/behavior_neuropixels_project_cloud_api.py b/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/data_io/behavior_neuropixels_project_cloud_api.py index 96b588fa49..7827e91cd1 100644 --- a/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/data_io/behavior_neuropixels_project_cloud_api.py +++ b/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/data_io/behavior_neuropixels_project_cloud_api.py @@ -34,9 +34,7 @@ def _load_manifest_tables(self): self._get_probe_table() self._get_channel_table() - def get_behavior_session( - self, behavior_session_id: int - ) -> BehaviorSession: + def get_behavior_session(self, behavior_session_id: int) -> BehaviorSession: """ Retrieve behavior session data from either the released behavior only nwb or the behavior side of the released ecephys data. @@ -63,9 +61,7 @@ def get_behavior_session( ecephys_session_id = row.ecephys_session_id # If a file_id for the behavior session is not set, attempt to load # an associated ecephys session. - if row[self.cache.file_id_column] < 0 or np.isnan( - row[self.cache.file_id_column] - ): + if row[self.cache.file_id_column] < 0 or np.isnan(row[self.cache.file_id_column]): row = return_one_dataframe_row_only( input_table=self._ecephys_session_table, index_value=ecephys_session_id, @@ -77,9 +73,7 @@ def get_behavior_session( return BehaviorSession.from_nwb_path(str(data_path)) - def get_ecephys_session( - self, ecephys_session_id: int - ) -> BehaviorEcephysSession: + def get_ecephys_session(self, ecephys_session_id: int) -> BehaviorEcephysSession: """get a BehaviorEcephysSession by specifying ecephys_session_id Parameters @@ -98,8 +92,7 @@ def get_ecephys_session( table_name="ecephys_session_table", ) probes_meta = self._probe_table[ - (self._probe_table["ecephys_session_id"] == ecephys_session_id) - & (self._probe_table["has_lfp_data"]) + (self._probe_table["ecephys_session_id"] == ecephys_session_id) & (self._probe_table["has_lfp_data"]) ] session_file_id = str(int(session_meta[self.cache.file_id_column])) session_data_path = self._get_data_path(file_id=session_file_id) @@ -130,9 +123,7 @@ def f(): } else: probe_meta = None - return BehaviorEcephysSession.from_nwb_path( - str(session_data_path), probe_meta=probe_meta - ) + return BehaviorEcephysSession.from_nwb_path(str(session_data_path), probe_meta=probe_meta) def _get_ecephys_session_table(self): session_table_path = self._get_metadata_path(fname="ecephys_sessions") diff --git a/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/data_io/behavior_project_cloud_api.py b/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/data_io/behavior_project_cloud_api.py index 08369aefab..f01e45793e 100644 --- a/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/data_io/behavior_project_cloud_api.py +++ b/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/data_io/behavior_project_cloud_api.py @@ -37,9 +37,7 @@ COL_EVAL_LIST = ["ophys_experiment_id", "ophys_container_id", "driver_line"] -def sanitize_data_columns( - input_csv_path: str, dtype_convert: dict = None -) -> pd.DataFrame: +def sanitize_data_columns(input_csv_path: str, dtype_convert: dict = None) -> pd.DataFrame: """Given an input csv path, parse the data and convert columns. Parameters @@ -70,9 +68,7 @@ def __init__( skip_version_check: bool = False, local: bool = False, ): - super().__init__( - cache=cache, skip_version_check=skip_version_check, local=local - ) + super().__init__(cache=cache, skip_version_check=skip_version_check, local=local) self._load_manifest_tables() if isinstance(cache, S3CloudCache): @@ -108,9 +104,7 @@ def _load_manifest_tables(self): self._get_ophys_experiment_table() self._get_ophys_cells_table() - def get_behavior_session( - self, behavior_session_id: int - ) -> BehaviorSession: + def get_behavior_session(self, behavior_session_id: int) -> BehaviorSession: """get a BehaviorSession by specifying behavior_session_id Parameters @@ -141,10 +135,7 @@ def get_behavior_session( table_name="behavior_session_table", ) row = row.squeeze() - has_file_id = ( - not pd.isna(row[self.cache.file_id_column]) - and row[self.cache.file_id_column] > 0 - ) + has_file_id = not pd.isna(row[self.cache.file_id_column]) and row[self.cache.file_id_column] > 0 if not has_file_id: oeid = row.ophys_experiment_id[0] row = return_one_dataframe_row_only( @@ -156,9 +147,7 @@ def get_behavior_session( data_path = self._get_data_path(file_id=file_id) return BehaviorSession.from_nwb_path(nwb_path=str(data_path)) - def get_behavior_ophys_experiment( - self, ophys_experiment_id: int - ) -> BehaviorOphysExperiment: + def get_behavior_ophys_experiment(self, ophys_experiment_id: int) -> BehaviorOphysExperiment: """get a BehaviorOphysExperiment by specifying ophys_experiment_id Parameters @@ -181,20 +170,12 @@ def get_behavior_ophys_experiment( return BehaviorOphysExperiment.from_nwb_path(str(data_path)) def _get_ophys_session_table(self): - session_table_path = self._get_metadata_path( - fname="ophys_session_table" - ) + session_table_path = self._get_metadata_path(fname="ophys_session_table") df = sanitize_data_columns(session_table_path, {"mouse_id": str}) # Add UTC to match DateOfAcquisition object. - df["date_of_acquisition"] = pd.to_datetime( - df["date_of_acquisition"], format="ISO8601", utc=True - ) - df = enforce_df_int_typing( - input_df=df, int_columns=VBO_INTEGER_COLUMNS, use_pandas_type=True - ) - df = enforce_df_column_order( - input_df=df, column_order=VBO_METADATA_COLUMN_ORDER - ) + df["date_of_acquisition"] = pd.to_datetime(df["date_of_acquisition"], format="ISO8601", utc=True) + df = enforce_df_int_typing(input_df=df, int_columns=VBO_INTEGER_COLUMNS, use_pandas_type=True) + df = enforce_df_column_order(input_df=df, column_order=VBO_METADATA_COLUMN_ORDER) self._ophys_session_table = df.set_index("ophys_session_id") def get_ophys_session_table(self) -> pd.DataFrame: @@ -211,20 +192,12 @@ def get_ophys_session_table(self) -> pd.DataFrame: return self._ophys_session_table def _get_behavior_session_table(self): - session_table_path = self._get_metadata_path( - fname="behavior_session_table" - ) + session_table_path = self._get_metadata_path(fname="behavior_session_table") df = sanitize_data_columns(session_table_path, {"mouse_id": str}) # Add UTC to match DateOfAcquisition object. - df["date_of_acquisition"] = pd.to_datetime( - df["date_of_acquisition"], format="ISO8601", utc=True - ) - df = enforce_df_int_typing( - input_df=df, int_columns=VBO_INTEGER_COLUMNS, use_pandas_type=True - ) - df = enforce_df_column_order( - input_df=df, column_order=VBO_METADATA_COLUMN_ORDER - ) + df["date_of_acquisition"] = pd.to_datetime(df["date_of_acquisition"], format="ISO8601", utc=True) + df = enforce_df_int_typing(input_df=df, int_columns=VBO_INTEGER_COLUMNS, use_pandas_type=True) + df = enforce_df_column_order(input_df=df, column_order=VBO_METADATA_COLUMN_ORDER) self._behavior_session_table = df.set_index("behavior_session_id") @@ -246,31 +219,19 @@ def get_behavior_session_table(self) -> pd.DataFrame: return self._behavior_session_table def _get_ophys_experiment_table(self): - experiment_table_path = self._get_metadata_path( - fname="ophys_experiment_table" - ) + experiment_table_path = self._get_metadata_path(fname="ophys_experiment_table") df = sanitize_data_columns(experiment_table_path, {"mouse_id": str}) # Add UTC to match DateOfAcquisition object. - df["date_of_acquisition"] = pd.to_datetime( - df["date_of_acquisition"], format="ISO8601", utc=True - ) - df = enforce_df_int_typing( - input_df=df, int_columns=VBO_INTEGER_COLUMNS, use_pandas_type=True - ) - df = enforce_df_column_order( - input_df=df, column_order=VBO_METADATA_COLUMN_ORDER - ) + df["date_of_acquisition"] = pd.to_datetime(df["date_of_acquisition"], format="ISO8601", utc=True) + df = enforce_df_int_typing(input_df=df, int_columns=VBO_INTEGER_COLUMNS, use_pandas_type=True) + df = enforce_df_column_order(input_df=df, column_order=VBO_METADATA_COLUMN_ORDER) self._ophys_experiment_table = df.set_index("ophys_experiment_id") def _get_ophys_cells_table(self): - ophys_cells_table_path = self._get_metadata_path( - fname="ophys_cells_table" - ) + ophys_cells_table_path = self._get_metadata_path(fname="ophys_cells_table") df = sanitize_data_columns(ophys_cells_table_path) # NaN's for invalid cells force this to float, push to int - df["cell_specimen_id"] = pd.array( - df["cell_specimen_id"], dtype="Int64" - ) + df["cell_specimen_id"] = pd.array(df["cell_specimen_id"], dtype="Int64") self._ophys_cells_table = df.set_index("cell_roi_id") def get_ophys_cells_table(self): @@ -316,9 +277,7 @@ def get_natural_movie_template(self, n_workers=None) -> pd.DataFrame: ------- processed_movie : pd.DataFrame """ - return self._natural_movie_cache.get_processed_template_movie( - n_workers=n_workers - ) + return self._natural_movie_cache.get_processed_template_movie(n_workers=n_workers) def get_natural_scene_template(self, number: int) -> Iterable[bytes]: """Download a template for the natural scene stimulus. This is the diff --git a/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/data_io/behavior_project_lims_api.py b/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/data_io/behavior_project_lims_api.py index 875b55c8bc..c0edc77271 100644 --- a/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/data_io/behavior_project_lims_api.py +++ b/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/data_io/behavior_project_lims_api.py @@ -285,9 +285,7 @@ def _get_behavior_summary_table(self) -> pd.DataFrame: self.logger.debug(f"get_behavior_session_table query: \n{query}") return self.lims_engine.select(query) - def get_behavior_stage_parameters( - self, foraging_ids: List[str] - ) -> pd.Series: + def get_behavior_stage_parameters(self, foraging_ids: List[str]) -> pd.Series: """Gets the stage parameters for each foraging id from mtrain Parameters @@ -300,9 +298,7 @@ def get_behavior_stage_parameters( --------- Series with index of foraging id and values stage parameters """ - foraging_ids_query = build_in_list_selector_query( - "bs.id", foraging_ids - ) + foraging_ids_query = build_in_list_selector_query("bs.id", foraging_ids) query = f""" SELECT @@ -316,18 +312,14 @@ def get_behavior_stage_parameters( df = df.set_index("foraging_id") return df["stage_parameters"] - def get_behavior_ophys_experiment( - self, ophys_experiment_id: int - ) -> BehaviorOphysExperiment: + def get_behavior_ophys_experiment(self, ophys_experiment_id: int) -> BehaviorOphysExperiment: """Returns a BehaviorOphysExperiment object that contains methods to analyze a single behavior+ophys session. :param ophys_experiment_id: id that corresponds to an ophys experiment :type ophys_experiment_id: int :rtype: BehaviorOphysExperiment """ - return BehaviorOphysExperiment.from_lims( - ophys_experiment_id=ophys_experiment_id - ) + return BehaviorOphysExperiment.from_lims(ophys_experiment_id=ophys_experiment_id) def _get_ophys_experiment_table(self) -> pd.DataFrame: """ @@ -386,19 +378,12 @@ def _get_ophys_experiment_table(self) -> pd.DataFrame: # Hard type targeted_imaging_depth to int to match the data_object # type. targeted_imaging_depth = ( - query_df[["ophys_container_id", "imaging_depth"]] - .groupby("ophys_container_id") - .mean() - .astype(int) + query_df[["ophys_container_id", "imaging_depth"]].groupby("ophys_container_id").mean().astype(int) ) targeted_imaging_depth.columns = ["targeted_imaging_depth"] df = query_df.merge(targeted_imaging_depth, on="ophys_container_id") - df = enforce_df_int_typing( - input_df=df, int_columns=VBO_INTEGER_COLUMNS, use_pandas_type=True - ) - df = enforce_df_column_order( - input_df=df, column_order=VBO_METADATA_COLUMN_ORDER - ) + df = enforce_df_int_typing(input_df=df, int_columns=VBO_INTEGER_COLUMNS, use_pandas_type=True) + df = enforce_df_column_order(input_df=df, column_order=VBO_METADATA_COLUMN_ORDER) return df def _get_ophys_cells_table(self): @@ -442,9 +427,7 @@ def _get_ophys_cells_table(self): df = self.lims_engine.select(query) # NaN's for invalid cells force this to float, push to int - df = enforce_df_int_typing( - input_df=df, int_columns=VBO_INTEGER_COLUMNS, use_pandas_type=True - ) + df = enforce_df_int_typing(input_df=df, int_columns=VBO_INTEGER_COLUMNS, use_pandas_type=True) return df def get_ophys_cells_table(self): @@ -506,13 +489,9 @@ def get_ophys_session_table(self) -> pd.DataFrame: """ # There is one ophys_session_id from 2018 that has multiple behavior # ids, causing duplicates -- drop all dupes for now; # TODO - table = self._get_ophys_session_table().drop_duplicates( - subset=["ophys_session_id"], keep=False - ) + table = self._get_ophys_session_table().drop_duplicates(subset=["ophys_session_id"], keep=False) # Make date time explicitly UTC. - table["date_of_acquisition"] = pd.to_datetime( - table["date_of_acquisition"], format="ISO8601", utc=True - ) + table["date_of_acquisition"] = pd.to_datetime(table["date_of_acquisition"], format="ISO8601", utc=True) # Fill NaN values of imaging_plane_group_count with zero to match # the behavior of the BehaviorOphysExperiment object. @@ -521,23 +500,17 @@ def get_ophys_session_table(self) -> pd.DataFrame: int_columns=VBO_INTEGER_COLUMNS, use_pandas_type=True, ) - table = enforce_df_column_order( - input_df=table, column_order=VBO_METADATA_COLUMN_ORDER - ) + table = enforce_df_column_order(input_df=table, column_order=VBO_METADATA_COLUMN_ORDER) return table.set_index("ophys_session_id") - def get_behavior_session( - self, behavior_session_id: int - ) -> BehaviorSession: + def get_behavior_session(self, behavior_session_id: int) -> BehaviorSession: """Returns a BehaviorSession object that contains methods to analyze a single behavior session. :param behavior_session_id: id that corresponds to a behavior session :type behavior_session_id: int :rtype: BehaviorSession """ - return BehaviorSession.from_lims( - behavior_session_id=behavior_session_id - ) + return BehaviorSession.from_lims(behavior_session_id=behavior_session_id) def get_ophys_experiment_table(self) -> pd.DataFrame: """Return a pd.Dataframe table with all ophys_experiment_ids and @@ -553,17 +526,11 @@ def get_ophys_experiment_table(self) -> pd.DataFrame: :rtype: pd.DataFrame """ df = self._get_ophys_experiment_table() - df["date_of_acquisition"] = pd.to_datetime( - df["date_of_acquisition"], format="ISO8601", utc=True - ) + df["date_of_acquisition"] = pd.to_datetime(df["date_of_acquisition"], format="ISO8601", utc=True) # Set type to pandas.Int64 to enforce integer typing and not revert to # float. - df = enforce_df_int_typing( - input_df=df, int_columns=VBO_INTEGER_COLUMNS, use_pandas_type=True - ) - df = enforce_df_column_order( - input_df=df, column_order=VBO_METADATA_COLUMN_ORDER - ) + df = enforce_df_int_typing(input_df=df, int_columns=VBO_INTEGER_COLUMNS, use_pandas_type=True) + df = enforce_df_column_order(input_df=df, column_order=VBO_METADATA_COLUMN_ORDER) return df.set_index("ophys_experiment_id") @@ -590,9 +557,7 @@ def get_behavior_session_table(self) -> pd.DataFrame: int_columns=VBO_INTEGER_COLUMNS, use_pandas_type=True, ) - summary_tbl = enforce_df_column_order( - input_df=summary_tbl, column_order=VBO_METADATA_COLUMN_ORDER - ) + summary_tbl = enforce_df_column_order(input_df=summary_tbl, column_order=VBO_METADATA_COLUMN_ORDER) return summary_tbl.set_index("behavior_session_id") @@ -655,50 +620,31 @@ def get_release_files(self, file_type="BehaviorNwb") -> pd.DataFrame: """ res = self.lims_engine.select(query) - res["isilon_filepath"] = res["storage_directory"].str.cat( - res["filename"] - ) + res["isilon_filepath"] = res["storage_directory"].str.cat(res["filename"]) res = res.drop(["filename", "storage_directory"], axis=1) return res.set_index(attachable_id_alias) def _get_behavior_session_release_filter(self): # 1) Get release behavior only session ids - behavior_only_release_files = self.get_release_files( - file_type="BehaviorNwb" - ) - release_behavior_only_session_ids = ( - behavior_only_release_files.index.tolist() - ) + behavior_only_release_files = self.get_release_files(file_type="BehaviorNwb") + release_behavior_only_session_ids = behavior_only_release_files.index.tolist() # 2) Get release behavior with ophys session ids - ophys_release_files = self.get_release_files( - file_type="BehaviorOphysNwb" - ) - release_behavior_with_ophys_session_ids = ophys_release_files[ - "behavior_session_id" - ].tolist() + ophys_release_files = self.get_release_files(file_type="BehaviorOphysNwb") + release_behavior_with_ophys_session_ids = ophys_release_files["behavior_session_id"].tolist() # 3) release behavior session ids is combination - release_behavior_session_ids = ( - release_behavior_only_session_ids - + release_behavior_with_ophys_session_ids - ) + release_behavior_session_ids = release_behavior_only_session_ids + release_behavior_with_ophys_session_ids - return build_in_list_selector_query( - "bs.id", release_behavior_session_ids - ) + return build_in_list_selector_query("bs.id", release_behavior_session_ids) def _get_ophys_session_release_filter(self): release_files = self.get_release_files(file_type="BehaviorOphysNwb") - return build_in_list_selector_query( - "bs.id", release_files["behavior_session_id"].tolist() - ) + return build_in_list_selector_query("bs.id", release_files["behavior_session_id"].tolist()) def _get_ophys_experiment_release_filter(self): release_files = self.get_release_files(file_type="BehaviorOphysNwb") - return build_in_list_selector_query( - "oe.id", release_files.index.tolist() - ) + return build_in_list_selector_query("oe.id", release_files.index.tolist()) def get_natural_movie_template(self, number: int) -> Iterable[bytes]: """Download a template for the natural movie stimulus. This is the diff --git a/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/data_io/natural_movie_one_cache.py b/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/data_io/natural_movie_one_cache.py index 4f222c6e85..4c6de409e1 100644 --- a/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/data_io/natural_movie_one_cache.py +++ b/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/data_io/natural_movie_one_cache.py @@ -86,6 +86,4 @@ def get_processed_template_movie(self, n_workers=None): movie_frames=movie_data, n_workers=n_workers, ) - return movie_template.to_dataframe( - index_name="movie_frame_index", index_type="int" - ) + return movie_template.to_dataframe(index_name="movie_frame_index", index_type="int") diff --git a/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/data_io/project_cloud_api_base.py b/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/data_io/project_cloud_api_base.py index 594c97ed38..21e06fa174 100644 --- a/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/data_io/project_cloud_api_base.py +++ b/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/data_io/project_cloud_api_base.py @@ -2,11 +2,9 @@ from pathlib import Path import logging -from allensdk.api.cloud_cache.cloud_cache import ( - S3CloudCache, LocalCache, StaticLocalCache) +from allensdk.api.cloud_cache.cloud_cache import S3CloudCache, LocalCache, StaticLocalCache -from allensdk.brain_observatory.behavior.behavior_project_cache \ - .utils import version_check +from allensdk.brain_observatory.behavior.behavior_project_cache.utils import version_check class ProjectCloudApiBase(object): @@ -26,13 +24,13 @@ class ProjectCloudApiBase(object): Whether to operate in local mode, where no data will be downloaded and instead will be loaded from local """ + def __init__( self, cache: Union[S3CloudCache, LocalCache, StaticLocalCache], skip_version_check: bool = False, - local: bool = False + local: bool = False, ): - self.cache = cache self.skip_version_check = skip_version_check self._local = local @@ -54,32 +52,32 @@ def load_manifest(self, manifest_name: Optional[str] = None): self.cache.load_manifest(manifest_name) if self.cache._manifest.metadata_file_names is None: - raise RuntimeError(f"{type(self.cache)} object has no metadata " - f"file names. Check contents of the loaded " - f"manifest file: {self.cache._manifest_name}") + raise RuntimeError( + f"{type(self.cache)} object has no metadata " + f"file names. Check contents of the loaded " + f"manifest file: {self.cache._manifest_name}" + ) if not self.skip_version_check: - data_sdk_version = [i for i in self.cache._manifest._data_pipeline - if i['name'] == "AllenSDK"][0]["version"] + data_sdk_version = [i for i in self.cache._manifest._data_pipeline if i["name"] == "AllenSDK"][0]["version"] version_check( self.cache._manifest.version, data_sdk_version, cmin=self.MANIFEST_COMPATIBILITY[0], - cmax=self.MANIFEST_COMPATIBILITY[1]) + cmax=self.MANIFEST_COMPATIBILITY[1], + ) self.logger = logging.getLogger(self.__class__.__name__) self._load_manifest_tables() def _load_manifest_tables(self): - raise NotImplementedError @classmethod - def from_s3_cache(cls, cache_dir: Union[str, Path], - bucket_name: str, - project_name: str, - ui_class_name: str) -> "ProjectCloudApiBase": + def from_s3_cache( + cls, cache_dir: Union[str, Path], bucket_name: str, project_name: str, ui_class_name: str + ) -> "ProjectCloudApiBase": """instantiates this object with a connection to an s3 bucket and/or a local cache related to that bucket. @@ -105,19 +103,12 @@ def from_s3_cache(cls, cache_dir: Union[str, Path], BehaviorProjectCloudApi instance """ - cache = S3CloudCache(cache_dir, - bucket_name, - project_name, - ui_class_name=ui_class_name) + cache = S3CloudCache(cache_dir, bucket_name, project_name, ui_class_name=ui_class_name) return cls(cache) @classmethod def from_local_cache( - cls, - cache_dir: Union[str, Path], - project_name: str, - ui_class_name: str, - use_static_cache: bool = False + cls, cache_dir: Union[str, Path], project_name: str, ui_class_name: str, use_static_cache: bool = False ) -> "ProjectCloudApiBase": """instantiates this object with a local cache. @@ -140,17 +131,9 @@ def from_local_cache( """ if use_static_cache: - cache = StaticLocalCache( - cache_dir, - project_name, - ui_class_name=ui_class_name - ) + cache = StaticLocalCache(cache_dir, project_name, ui_class_name=ui_class_name) else: - cache = LocalCache( - cache_dir, - project_name, - ui_class_name=ui_class_name - ) + cache = LocalCache(cache_dir, project_name, ui_class_name=ui_class_name) return cls(cache, local=True) def _get_metadata_path(self, fname: str): @@ -167,23 +150,22 @@ def _get_data_path(self, file_id: str): data_path = self.cache.download_data(file_id=file_id) return data_path - def _get_local_path(self, fname: Optional[str] = None, file_id: - Optional[str] = None): + def _get_local_path(self, fname: Optional[str] = None, file_id: Optional[str] = None): if fname is None and file_id is None: - raise ValueError('Must pass either fname or file_id') + raise ValueError("Must pass either fname or file_id") if fname is not None and file_id is not None: - raise ValueError('Must pass only one of fname or file_id') + raise ValueError("Must pass only one of fname or file_id") if fname is not None: path = self.cache.metadata_path(fname=fname) else: path = self.cache.data_path(file_id=file_id) - exists = path['exists'] - local_path = path['local_path'] + exists = path["exists"] + local_path = path["local_path"] if not exists: - raise FileNotFoundError(f'You started a cache without a ' - f'connection to s3 and {local_path} is ' - 'not already on your system') + raise FileNotFoundError( + f"You started a cache without a connection to s3 and {local_path} is not already on your system" + ) return local_path diff --git a/allensdk/brain_observatory/behavior/behavior_project_cache/project_cache_base.py b/allensdk/brain_observatory/behavior/behavior_project_cache/project_cache_base.py index b1bc38db1c..492b91014d 100644 --- a/allensdk/brain_observatory/behavior/behavior_project_cache/project_cache_base.py +++ b/allensdk/brain_observatory/behavior/behavior_project_cache/project_cache_base.py @@ -2,23 +2,20 @@ from pathlib import Path import logging -from allensdk.brain_observatory.behavior.behavior_project_cache.\ - project_apis.data_io import ProjectCloudApiBase +from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.data_io import ProjectCloudApiBase from allensdk.core.authentication import DbCredentials -from allensdk.brain_observatory.behavior.behavior_project_cache.\ - project_apis.data_io import BehaviorProjectLimsApi +from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.data_io import BehaviorProjectLimsApi class ProjectCacheBase(object): - BUCKET_NAME: str = None PROJECT_NAME: str = None def __init__( - self, - fetch_api: Union[ProjectCloudApiBase, BehaviorProjectLimsApi], - fetch_tries: int = 2, - ): + self, + fetch_api: Union[ProjectCloudApiBase, BehaviorProjectLimsApi], + fetch_tries: int = 2, + ): """ Parameters ========== @@ -41,17 +38,14 @@ def __init__( def manifest(self): if self.cache is None: api_name = type(self.fetch_api).__name__ - raise NotImplementedError(f"A {type(self).__name__} " - f"based on {api_name} " - "does not have an accessible manifest " - "property") + raise NotImplementedError( + f"A {type(self).__name__} based on {api_name} does not have an accessible manifest property" + ) return self.cache.manifest @classmethod def from_s3_cache( - cls, - cache_dir: Union[str, Path], - bucket_name_override: Optional[str] = None + cls, cache_dir: Union[str, Path], bucket_name_override: Optional[str] = None ) -> "ProjectCacheBase": """instantiates this object with a connection to an s3 bucket and/or a local cache related to that bucket. @@ -76,20 +70,15 @@ def from_s3_cache( fetch_api = cls.cloud_api_class().from_s3_cache( cache_dir, - bucket_name=( - bucket_name_override if bucket_name_override is not None - else cls.BUCKET_NAME), + bucket_name=(bucket_name_override if bucket_name_override is not None else cls.BUCKET_NAME), project_name=cls.PROJECT_NAME, - ui_class_name=cls.__name__) + ui_class_name=cls.__name__, + ) return cls(fetch_api=fetch_api) @classmethod - def from_local_cache( - cls, - cache_dir: Union[str, Path], - use_static_cache: bool = False - ) -> "ProjectCacheBase": + def from_local_cache(cls, cache_dir: Union[str, Path], use_static_cache: bool = False) -> "ProjectCacheBase": """instantiates this object with a local cache. Parameters @@ -108,26 +97,25 @@ def from_local_cache( """ fetch_api = cls.cloud_api_class().from_local_cache( - cache_dir, - project_name=cls.PROJECT_NAME, - ui_class_name=cls.__name__, - use_static_cache=use_static_cache + cache_dir, project_name=cls.PROJECT_NAME, ui_class_name=cls.__name__, use_static_cache=use_static_cache ) return cls(fetch_api=fetch_api) @classmethod - def from_lims(cls, manifest: Optional[Union[str, Path]] = None, - version: Optional[str] = None, - cache: bool = False, - fetch_tries: int = 2, - lims_credentials: Optional[DbCredentials] = None, - mtrain_credentials: Optional[DbCredentials] = None, - host: Optional[str] = None, - scheme: Optional[str] = None, - asynchronous: bool = True, - data_release_date: Optional[Union[str, List[str]]] = None, - passed_only: bool = True - ) -> "ProjectCacheBase": + def from_lims( + cls, + manifest: Optional[Union[str, Path]] = None, + version: Optional[str] = None, + cache: bool = False, + fetch_tries: int = 2, + lims_credentials: Optional[DbCredentials] = None, + mtrain_credentials: Optional[DbCredentials] = None, + host: Optional[str] = None, + scheme: Optional[str] = None, + asynchronous: bool = True, + data_release_date: Optional[Union[str, List[str]]] = None, + passed_only: bool = True, + ) -> "ProjectCacheBase": """ Construct a ProjectCacheBase with a lims api. @@ -169,8 +157,7 @@ def from_lims(cls, manifest: Optional[Union[str, Path]] = None, ProjectCacheBase instance with a LIMS fetch API """ if host and scheme: - app_kwargs = {"host": host, "scheme": scheme, - "asynchronous": asynchronous} + app_kwargs = {"host": host, "scheme": scheme, "asynchronous": asynchronous} else: app_kwargs = None fetch_api = cls.lims_api_class().default( @@ -178,10 +165,9 @@ def from_lims(cls, manifest: Optional[Union[str, Path]] = None, mtrain_credentials=mtrain_credentials, data_release_date=data_release_date, app_kwargs=app_kwargs, - passed_only=passed_only + passed_only=passed_only, ) - return cls(fetch_api=fetch_api, manifest=manifest, version=version, - cache=cache, fetch_tries=fetch_tries) + return cls(fetch_api=fetch_api, manifest=manifest, version=version, cache=cache, fetch_tries=fetch_tries) def _cache_not_implemented(self, method_name: str) -> None: """ @@ -202,13 +188,10 @@ def construct_local_manifest(self) -> None: thinks that you need to run this method). """ if not isinstance(self.fetch_api, self.cloud_api_class()): - self._cache_not_implemented('construct_local_manifest') + self._cache_not_implemented("construct_local_manifest") self.fetch_api.cache.construct_local_manifest() - def compare_manifests(self, - manifest_0_name: str, - manifest_1_name: str - ) -> str: + def compare_manifests(self, manifest_0_name: str, manifest_1_name: str) -> str: """ Compare two manifests from this dataset. Return a dict containing the list of metadata and data files that changed @@ -229,9 +212,8 @@ def compare_manifests(self, manifest_0 to manifest_1 """ if not isinstance(self.fetch_api, self.cloud_api_class()): - self._cache_not_implemented('compare_manifests') - return self.fetch_api.cache.compare_manifests(manifest_0_name, - manifest_1_name) + self._cache_not_implemented("compare_manifests") + return self.fetch_api.cache.compare_manifests(manifest_0_name, manifest_1_name) def load_latest_manifest(self) -> None: """ @@ -239,7 +221,7 @@ def load_latest_manifest(self) -> None: version of the dataset. """ if not isinstance(self.fetch_api, self.cloud_api_class()): - self._cache_not_implemented('load_latest_manifest') + self._cache_not_implemented("load_latest_manifest") self.fetch_api.cache.load_latest_manifest() self.load_manifest(self.current_manifest()) @@ -249,7 +231,7 @@ def latest_downloaded_manifest_file(self) -> str: available on your local system. """ if not isinstance(self.fetch_api, self.cloud_api_class()): - self._cache_not_implemented('latest_downloaded_manifest_file') + self._cache_not_implemented("latest_downloaded_manifest_file") return self.fetch_api.cache.latest_downloaded_manifest_file def latest_manifest_file(self) -> str: @@ -259,7 +241,7 @@ def latest_manifest_file(self) -> str: if this is a cloud-backed cache. """ if not isinstance(self.fetch_api, self.cloud_api_class()): - self._cache_not_implemented('latest_manifest_file') + self._cache_not_implemented("latest_manifest_file") return self.fetch_api.cache.latest_manifest_file def load_manifest(self, manifest_name: str): @@ -273,7 +255,7 @@ def load_manifest(self, manifest_name: str): self.manifest_file_names """ if not isinstance(self.fetch_api, self.cloud_api_class()): - self._cache_not_implemented('load_manifest') + self._cache_not_implemented("load_manifest") self.fetch_api.load_manifest(manifest_name) def list_all_downloaded_manifests(self) -> list: @@ -282,7 +264,7 @@ def list_all_downloaded_manifests(self) -> list: that have been downloaded to this cache. """ if not isinstance(self.fetch_api, self.cloud_api_class()): - self._cache_not_implemented('list_all_downloaded_manifests') + self._cache_not_implemented("list_all_downloaded_manifests") return self.fetch_api.cache.list_all_downloaded_manifests() def list_manifest_file_names(self) -> list: @@ -291,7 +273,7 @@ def list_manifest_file_names(self) -> list: associated with this dataset. """ if not isinstance(self.fetch_api, self.cloud_api_class()): - self._cache_not_implemented('list_manifest_file_names') + self._cache_not_implemented("list_manifest_file_names") return self.fetch_api.cache.manifest_file_names def current_manifest(self) -> Union[None, str]: @@ -300,5 +282,5 @@ def current_manifest(self) -> Union[None, str]: used by this cache. """ if not isinstance(self.fetch_api, self.cloud_api_class()): - self._cache_not_implemented('current_manifest') + self._cache_not_implemented("current_manifest") return self.fetch_api.cache.current_manifest diff --git a/allensdk/brain_observatory/behavior/behavior_project_cache/project_metadata_writer/__main__.py b/allensdk/brain_observatory/behavior/behavior_project_cache/project_metadata_writer/__main__.py index 56480d6b16..11bde70980 100644 --- a/allensdk/brain_observatory/behavior/behavior_project_cache/project_metadata_writer/__main__.py +++ b/allensdk/brain_observatory/behavior/behavior_project_cache/project_metadata_writer/__main__.py @@ -1,5 +1,6 @@ - -from allensdk.brain_observatory.behavior.behavior_project_cache.project_metadata_writer.behavior_project_metadata_writer import BehaviorProjectMetadataWriter # noqa: E501 +from allensdk.brain_observatory.behavior.behavior_project_cache.project_metadata_writer.behavior_project_metadata_writer import ( + BehaviorProjectMetadataWriter, +) # noqa: E501 """Module for creating behavior ophys metadata tables and adding NWB file paths from the specified directory to the behavior ophys metadata tables. diff --git a/allensdk/brain_observatory/behavior/behavior_project_cache/project_metadata_writer/behavior_project_metadata_writer.py b/allensdk/brain_observatory/behavior/behavior_project_cache/project_metadata_writer/behavior_project_metadata_writer.py index 462b07974c..1b996408ab 100644 --- a/allensdk/brain_observatory/behavior/behavior_project_cache/project_metadata_writer/behavior_project_metadata_writer.py +++ b/allensdk/brain_observatory/behavior/behavior_project_cache/project_metadata_writer/behavior_project_metadata_writer.py @@ -5,11 +5,10 @@ from pathlib import Path import allensdk -from allensdk.brain_observatory.behavior.behavior_project_cache import \ - VisualBehaviorOphysProjectCache +from allensdk.brain_observatory.behavior.behavior_project_cache import VisualBehaviorOphysProjectCache from allensdk.brain_observatory.behavior.behavior_project_cache.project_metadata_writer.schemas import ( # noqa: E501 BehaviorOphysMetadataInputSchema, - DataReleaseToolsInputSchema + DataReleaseToolsInputSchema, ) from allensdk.brain_observatory.data_release_utils.metadata_utils.id_generator import ( # noqa: E501 FileIDGenerator, @@ -21,27 +20,16 @@ ######### # These columns should be dropped from external-facing metadata ######### -SESSION_SUPPRESS = ( - 'donor_id', - 'foraging_id', - 'session_name', - 'specimen_id' -) -OPHYS_EXPERIMENTS_SUPPRESS = SESSION_SUPPRESS + ( - 'behavior_session_uuid', - 'published_at', - 'isi_experiment_id' -) -OPHYS_EXPERIMENTS_SUPPRESS_FINAL = [ - 'container_workflow_state', - 'experiment_workflow_state'] +SESSION_SUPPRESS = ("donor_id", "foraging_id", "session_name", "specimen_id") +OPHYS_EXPERIMENTS_SUPPRESS = SESSION_SUPPRESS + ("behavior_session_uuid", "published_at", "isi_experiment_id") +OPHYS_EXPERIMENTS_SUPPRESS_FINAL = ["container_workflow_state", "experiment_workflow_state"] ######### OUTPUT_METADATA_FILENAMES = { - 'behavior_session_table': 'behavior_session_table.csv', - 'ophys_session_table': 'ophys_session_table.csv', - 'ophys_experiment_table': 'ophys_experiment_table.csv', - 'ophys_cells_table': 'ophys_cells_table.csv' + "behavior_session_table": "behavior_session_table.csv", + "ophys_session_table": "ophys_session_table.csv", + "ophys_experiment_table": "ophys_experiment_table.csv", + "ophys_cells_table": "ophys_cells_table.csv", } @@ -52,60 +40,49 @@ class BehaviorProjectMetadataWriter(argschema.ArgSchemaParser): default_output_schema = DataReleaseToolsInputSchema def run(self): - """Create metadata tables and add file paths/ids. - """ + """Create metadata tables and add file paths/ids.""" self._initialize_metadata_writer() self.write_metadata() def _initialize_metadata_writer(self): - """Initialize the project cache and release file information. - """ + """Initialize the project cache and release file information.""" self._file_id_generator = FileIDGenerator() - self._behavior_project_cache = \ - VisualBehaviorOphysProjectCache.from_lims( - data_release_date=self.args['data_release_date']) + self._behavior_project_cache = VisualBehaviorOphysProjectCache.from_lims( + data_release_date=self.args["data_release_date"] + ) def write_metadata(self): """Writes metadata to csv""" - os.makedirs(self.args['output_dir'], exist_ok=True) + os.makedirs(self.args["output_dir"], exist_ok=True) - self.logger.info('Writing ophys sessions table') + self.logger.info("Writing ophys sessions table") self._write_ophys_sessions() - self.logger.info('Writing ophys experiments table') + self.logger.info("Writing ophys experiments table") self._write_ophys_experiments() - self.logger.info('Writing behavior sessions table') + self.logger.info("Writing behavior sessions table") self._write_behavior_sessions() - self.logger.info('Writing ophys cells table') + self.logger.info("Writing ophys cells table") self._write_ophys_cells() self._write_manifest() def _write_behavior_sessions( - self, - suppress=SESSION_SUPPRESS, - output_filename=OUTPUT_METADATA_FILENAMES[ - 'behavior_session_table'], - include_trial_metrics: bool = True + self, + suppress=SESSION_SUPPRESS, + output_filename=OUTPUT_METADATA_FILENAMES["behavior_session_table"], + include_trial_metrics: bool = True, ): - behavior_sessions = self._behavior_project_cache. \ - get_behavior_session_table( - suppress=suppress, - as_df=True, - include_trial_metrics=include_trial_metrics) + behavior_sessions = self._behavior_project_cache.get_behavior_session_table( + suppress=suppress, as_df=True, include_trial_metrics=include_trial_metrics + ) # Add release files - ophys_experiments = \ - self._behavior_project_cache.get_ophys_experiment_table( - suppress=suppress, as_df=True) - ophys_session_mask = behavior_sessions.ophys_session_id.isin( - ophys_experiments.ophys_session_id - ) + ophys_experiments = self._behavior_project_cache.get_ophys_experiment_table(suppress=suppress, as_df=True) + ophys_session_mask = behavior_sessions.ophys_session_id.isin(ophys_experiments.ophys_session_id) behavior_session_w_ophys = behavior_sessions[ophys_session_mask] - behavior_session_w_ophys["file_id"] = \ - self._file_id_generator.dummy_value - behavior_session_w_out_ophys = behavior_sessions[ - ~ophys_session_mask] + behavior_session_w_ophys["file_id"] = self._file_id_generator.dummy_value + behavior_session_w_out_ophys = behavior_sessions[~ophys_session_mask] behavior_session_w_out_ophys.reset_index(inplace=True) behavior_session_w_out_ophys = add_file_paths_to_metadata_table( metadata_table=behavior_session_w_out_ophys, @@ -116,39 +93,25 @@ def _write_behavior_sessions( data_dir_col="behavior_session_id", on_missing_file=self.args["on_missing_file"], ) - behavior_session_w_out_ophys.set_index("behavior_session_id", - inplace=True) - behavior_sessions = pd.concat( - [behavior_session_w_out_ophys, behavior_session_w_ophys] - ) + behavior_session_w_out_ophys.set_index("behavior_session_id", inplace=True) + behavior_sessions = pd.concat([behavior_session_w_out_ophys, behavior_session_w_ophys]) + + self._write_metadata_table(df=behavior_sessions, filename=output_filename) + + def _write_ophys_cells(self, output_filename=OUTPUT_METADATA_FILENAMES["ophys_cells_table"]): + ophys_cells = self._behavior_project_cache.get_ophys_cells_table() + self._write_metadata_table(df=ophys_cells, filename=output_filename) - self._write_metadata_table(df=behavior_sessions, - filename=output_filename) - - def _write_ophys_cells(self, - output_filename=OUTPUT_METADATA_FILENAMES[ - 'ophys_cells_table']): - ophys_cells = self._behavior_project_cache. \ - get_ophys_cells_table() - self._write_metadata_table(df=ophys_cells, - filename=output_filename) - - def _write_ophys_sessions(self, suppress=SESSION_SUPPRESS, - output_filename=OUTPUT_METADATA_FILENAMES[ - 'ophys_session_table' - ]): - ophys_sessions = self._behavior_project_cache. \ - get_ophys_session_table(suppress=suppress, as_df=True) - self._write_metadata_table(df=ophys_sessions, - filename=output_filename) - - def _write_ophys_experiments(self, suppress=OPHYS_EXPERIMENTS_SUPPRESS, - output_filename=OUTPUT_METADATA_FILENAMES[ - 'ophys_experiment_table' - ]): - ophys_experiments = \ - self._behavior_project_cache.get_ophys_experiment_table( - suppress=suppress, as_df=True) + def _write_ophys_sessions( + self, suppress=SESSION_SUPPRESS, output_filename=OUTPUT_METADATA_FILENAMES["ophys_session_table"] + ): + ophys_sessions = self._behavior_project_cache.get_ophys_session_table(suppress=suppress, as_df=True) + self._write_metadata_table(df=ophys_sessions, filename=output_filename) + + def _write_ophys_experiments( + self, suppress=OPHYS_EXPERIMENTS_SUPPRESS, output_filename=OUTPUT_METADATA_FILENAMES["ophys_experiment_table"] + ): + ophys_experiments = self._behavior_project_cache.get_ophys_experiment_table(suppress=suppress, as_df=True) # Add release files ophys_experiments.reset_index(inplace=True) @@ -161,16 +124,12 @@ def _write_ophys_experiments(self, suppress=OPHYS_EXPERIMENTS_SUPPRESS, data_dir_col="ophys_experiment_id", on_missing_file=self.args["on_missing_file"], ) - ophys_experiments.set_index('ophys_experiment_id', inplace=True) + ophys_experiments.set_index("ophys_experiment_id", inplace=True) # users don't need to see these - ophys_experiments.drop( - labels=OPHYS_EXPERIMENTS_SUPPRESS_FINAL, - inplace=True, - axis=1) + ophys_experiments.drop(labels=OPHYS_EXPERIMENTS_SUPPRESS_FINAL, inplace=True, axis=1) - self._write_metadata_table(df=ophys_experiments, - filename=output_filename) + self._write_metadata_table(df=ophys_experiments, filename=output_filename) return ophys_experiments @@ -187,15 +146,14 @@ def _write_metadata_table(self, df: pd.DataFrame, filename: str): """ filepath = os.path.join(self.args["output_dir"], filename) - self.logger.info(f'Writing {filepath}') + self.logger.info(f"Writing {filepath}") df = df.reset_index() df.to_csv(filepath, index=False) def _write_manifest(self): def get_abs_path(filename): - return os.path.abspath(os.path.join(self.args["output_dir"], - filename)) + return os.path.abspath(os.path.join(self.args["output_dir"], filename)) metadata_filenames = OUTPUT_METADATA_FILENAMES.values() metadata_files = [get_abs_path(f) for f in metadata_filenames] diff --git a/allensdk/brain_observatory/behavior/behavior_project_cache/project_metadata_writer/schemas.py b/allensdk/brain_observatory/behavior/behavior_project_cache/project_metadata_writer/schemas.py index c084f595b3..1da86d8609 100644 --- a/allensdk/brain_observatory/behavior/behavior_project_cache/project_metadata_writer/schemas.py +++ b/allensdk/brain_observatory/behavior/behavior_project_cache/project_metadata_writer/schemas.py @@ -6,14 +6,11 @@ class BaseMetadataWriterInputSchema(argschema.ArgSchema): - behavior_nwb_dir = argschema.fields.InputDir( required=True, default=None, allow_none=True, - description=( - "The directory where behavior_nwb sessions are to be found." - ), + description=("The directory where behavior_nwb sessions are to be found."), ) behavior_nwb_prefix = argschema.fields.Str( required=False, @@ -27,16 +24,12 @@ class BaseMetadataWriterInputSchema(argschema.ArgSchema): output_dir = argschema.fields.OutputDir( required=True, - description=( - "Directory to output metadata tables." - ), + description=("Directory to output metadata tables."), ) clobber = argschema.fields.Boolean( default=False, - description=( - "If false, throw an error if output files already exist." - ), + description=("If false, throw an error if output files already exist."), ) on_missing_file = argschema.fields.Str( @@ -57,7 +50,6 @@ class BaseMetadataWriterInputSchema(argschema.ArgSchema): class BehaviorOphysMetadataInputSchema(BaseMetadataWriterInputSchema): - data_release_date = argschema.fields.List( argschema.fields.String, required=True, @@ -72,9 +64,7 @@ class BehaviorOphysMetadataInputSchema(BaseMetadataWriterInputSchema): required=True, allow_none=True, default=None, - description=( - "The directory where ophys experiments are to be found." - ), + description=("The directory where ophys experiments are to be found."), ) ophys_nwb_prefix = argschema.fields.Str( required=False, @@ -89,10 +79,10 @@ class BehaviorOphysMetadataInputSchema(BaseMetadataWriterInputSchema): @post_load def validate_paths(self, data, **kwargs): fname_lookup = { - 'behavior_session_table': 'behavior_session_table.csv', - 'ophys_session_table': 'ophys_session_table.csv', - 'ophys_experiment_table': 'ophys_experiment_table.csv', - 'ophys_cells_table': 'ophys_cells_table.csv' + "behavior_session_table": "behavior_session_table.csv", + "ophys_session_table": "ophys_session_table.csv", + "ophys_experiment_table": "ophys_experiment_table.csv", + "ophys_cells_table": "ophys_cells_table.csv", } out_dir = pathlib.Path(data["output_dir"]) @@ -105,15 +95,12 @@ def validate_paths(self, data, **kwargs): if len(msg) > 0: raise RuntimeError( - "The following files already exist\n" - f"{msg}" - "Run with clobber=True if you want to overwrite" + f"The following files already exist\n{msg}Run with clobber=True if you want to overwrite" ) return data class PipelineMetadataSchema(DefaultSchema): - name = argschema.fields.Str( required=True, allow_none=False, @@ -144,24 +131,17 @@ class DataReleaseToolsInputSchema(argschema.ArgSchema): metadata_files = argschema.fields.List( argschema.fields.InputFile, - description=( - "Paths to the metadata '.csv' files written by this modules." - ), + description=("Paths to the metadata '.csv' files written by this modules."), ) data_pipeline_metadata = argschema.fields.Nested( PipelineMetadataSchema, many=True, - description=( - "Metadata about the pipeline used to create this data release." - ), + description=("Metadata about the pipeline used to create this data release."), ) project_name = argschema.fields.Str( required=True, allow_none=False, - description=( - "The project name to be passed along to the data_release_tool " - "when uploading this dataset." - ), + description=("The project name to be passed along to the data_release_tool when uploading this dataset."), ) diff --git a/allensdk/brain_observatory/behavior/behavior_project_cache/tables/experiments_table.py b/allensdk/brain_observatory/behavior/behavior_project_cache/tables/experiments_table.py index 4531f6553b..cfe57e1677 100644 --- a/allensdk/brain_observatory/behavior/behavior_project_cache/tables/experiments_table.py +++ b/allensdk/brain_observatory/behavior/behavior_project_cache/tables/experiments_table.py @@ -2,31 +2,24 @@ import pandas as pd -from allensdk.brain_observatory.behavior.behavior_project_cache.tables\ - .ophys_mixin import \ - OphysMixin -from allensdk.brain_observatory.behavior.behavior_project_cache.tables\ - .project_table import \ - ProjectTable +from allensdk.brain_observatory.behavior.behavior_project_cache.tables.ophys_mixin import OphysMixin +from allensdk.brain_observatory.behavior.behavior_project_cache.tables.project_table import ProjectTable from allensdk.brain_observatory.behavior.behavior_project_cache.tables.util.prior_exposure_processing import ( # noqa: E501 add_experience_level_ophys, ) -from allensdk.brain_observatory.behavior.behavior_project_cache.tables\ - .util.experiments_table_utils import ( - add_passive_flag_to_ophys_experiment_table, - add_image_set_to_experiment_table) -from allensdk.core.dataframe_utils import ( - enforce_df_column_order +from allensdk.brain_observatory.behavior.behavior_project_cache.tables.util.experiments_table_utils import ( + add_passive_flag_to_ophys_experiment_table, + add_image_set_to_experiment_table, ) +from allensdk.core.dataframe_utils import enforce_df_column_order from allensdk.brain_observatory.ophys.project_constants import VBO_METADATA_COLUMN_ORDER # noqa: E501 class ExperimentsTable(ProjectTable, OphysMixin): """Class for storing and manipulating project-level data at the behavior-ophys experiment level""" - def __init__(self, df: pd.DataFrame, - suppress: Optional[List[str]] = None, - passed_only: bool = True): + + def __init__(self, df: pd.DataFrame, suppress: Optional[List[str]] = None, passed_only: bool = True): """ Parameters ---------- @@ -42,10 +35,7 @@ def __init__(self, df: pd.DataFrame, ProjectTable.__init__(self, df=df, suppress=suppress) OphysMixin.__init__(self) self.final_processing() - self._df = enforce_df_column_order( - self._df, - VBO_METADATA_COLUMN_ORDER - ) + self._df = enforce_df_column_order(self._df, VBO_METADATA_COLUMN_ORDER) def postprocess_base(self): """ diff --git a/allensdk/brain_observatory/behavior/behavior_project_cache/tables/metadata_table_schemas.py b/allensdk/brain_observatory/behavior/behavior_project_cache/tables/metadata_table_schemas.py index d0cb5b6910..d57fbfdeb1 100644 --- a/allensdk/brain_observatory/behavior/behavior_project_cache/tables/metadata_table_schemas.py +++ b/allensdk/brain_observatory/behavior/behavior_project_cache/tables/metadata_table_schemas.py @@ -13,23 +13,13 @@ class BehaviorSessionMetadataSchema(RaisingSchema): behavior_session_id = Int( required=False, allow_none=True, - description=( - "Unique identifier for the " - "behavior session to write into " - "NWB format" - ), - ) - cre_line = String( - required=False, - allow_none=True, - description="Genetic cre line of the subject." + description=("Unique identifier for the behavior session to write into NWB format"), ) + cre_line = String(required=False, allow_none=True, description="Genetic cre line of the subject.") date_of_acquisition = String( required=False, allow_none=True, - description=( - "Date of acquisition of " "behavior session, in string " "format" - ), + description=("Date of acquisition of behavior session, in string format"), ) driver_line = List( String, @@ -38,16 +28,8 @@ class BehaviorSessionMetadataSchema(RaisingSchema): cli_as_single_argument=True, description="Genetic driver line(s) of subject", ) - equipment_name = String( - required=False, - allow_none=True, - description=("Name of the equipment used.") - ) - full_genotype = String( - required=False, - allow_none=True, - description="Full genotype of subject" - ) + equipment_name = String(required=False, allow_none=True, description=("Name of the equipment used.")) + full_genotype = String(required=False, allow_none=True, description="Full genotype of subject") mouse_id = String( required=False, allow_none=True, @@ -58,35 +40,19 @@ class BehaviorSessionMetadataSchema(RaisingSchema): allow_none=True, description="LabTracks ID of the subject. aka external_specimen_name.", ) - reporter_line = String( - required=False, - allow_none=True, - description="Genetic reporter line(s) of subject" - ) - session_type = String( - required=False, - allow_none=True, - description="Full name of session type." - ) - sex = String( - required=False, - allow_none=True, - description="Subject sex" - ) + reporter_line = String(required=False, allow_none=True, description="Genetic reporter line(s) of subject") + session_type = String(required=False, allow_none=True, description="Full name of session type.") + sex = String(required=False, allow_none=True, description="Subject sex") @mm.post_load def convert_date_time(self, data, **kwargs): """Change date_of_acquisition to a date time type from string.""" - data["date_of_acquisition"] = pd.to_datetime( - data["date_of_acquisition"], format="ISO8601", utc=True - ) + data["date_of_acquisition"] = pd.to_datetime(data["date_of_acquisition"], format="ISO8601", utc=True) return data class OphysExperimentMetadataSchema(BehaviorSessionMetadataSchema): - imaging_depth = Int( - required=True, description="Imaging depth of the OphysExperiment." - ) + imaging_depth = Int(required=True, description="Imaging depth of the OphysExperiment.") imaging_plane_group = Int( required=True, allow_none=True, @@ -95,12 +61,9 @@ class OphysExperimentMetadataSchema(BehaviorSessionMetadataSchema): indicator = String(required=True, description="String indicator line.") ophys_container_id = Int( required=True, - description="ID of ophys container of which this experiment is a " - "member.", - ) - ophys_experiment_id = Int( - required=True, description="ID of the ophys experiment." + description="ID of ophys container of which this experiment is a member.", ) + ophys_experiment_id = Int(required=True, description="ID of the ophys experiment.") ophys_session_id = Int( required=True, description="ID of the ophys session this experiment is a member of.", @@ -109,6 +72,4 @@ class OphysExperimentMetadataSchema(BehaviorSessionMetadataSchema): required=True, description="Average of all experiments in the container.", ) - targeted_structure = String( - required=True, description="String name of the structure targeted." - ) + targeted_structure = String(required=True, description="String name of the structure targeted.") diff --git a/allensdk/brain_observatory/behavior/behavior_project_cache/tables/ophys_mixin.py b/allensdk/brain_observatory/behavior/behavior_project_cache/tables/ophys_mixin.py index b427dfb3ef..005d3c1a93 100644 --- a/allensdk/brain_observatory/behavior/behavior_project_cache/tables/ophys_mixin.py +++ b/allensdk/brain_observatory/behavior/behavior_project_cache/tables/ophys_mixin.py @@ -7,6 +7,7 @@ class OphysMixin: _df: pd.DataFrame """A mixin class for ophys project data""" + def __init__(self): self._merge_columns() @@ -19,12 +20,12 @@ def _merge_columns(self): columns = self._df.columns to_drop = [] for column in columns: - if column.endswith('_behavior'): - column = column.replace('_behavior', '') - if f'{column}_ophys' in self._df: + if column.endswith("_behavior"): + column = column.replace("_behavior", "") + if f"{column}_ophys" in self._df: self._check_behavior_ophys_equal(column=column) self._df[column] = self._merge_column_values(column=column) - to_drop += [f'{column}_behavior', f'{column}_ophys'] + to_drop += [f"{column}_behavior", f"{column}_ophys"] self._df.drop(to_drop, axis=1, inplace=True) def _check_behavior_ophys_equal(self, column: str): @@ -36,14 +37,15 @@ def _check_behavior_ophys_equal(self, column: str): column Column to check """ - mask = ~self._df[f'{column}_ophys'].isna() + mask = ~self._df[f"{column}_ophys"].isna() - if not self._df[f'{column}_ophys'][mask].equals( - self._df[f'{column}_behavior'][mask]): - warnings.warn("BehaviorSession and OphysSession " - f"{column} do not agree. This is " - "likely due to issues with the data in " - f"LIMS.") + if not self._df[f"{column}_ophys"][mask].equals(self._df[f"{column}_behavior"][mask]): + warnings.warn( + "BehaviorSession and OphysSession " + f"{column} do not agree. This is " + "likely due to issues with the data in " + f"LIMS." + ) def _merge_column_values(self, column: str) -> pd.Series: """Takes the non-null values from ophys and merges with behavior @@ -60,7 +62,6 @@ def _merge_column_values(self, column: str) -> pd.Series: Merged Series """ - values = self._df[f'{column}_ophys'] - values.loc[values.isna()] = \ - self._df[f'{column}_behavior'][values.isna()] + values = self._df[f"{column}_ophys"] + values.loc[values.isna()] = self._df[f"{column}_behavior"][values.isna()] return values diff --git a/allensdk/brain_observatory/behavior/behavior_project_cache/tables/ophys_sessions_table.py b/allensdk/brain_observatory/behavior/behavior_project_cache/tables/ophys_sessions_table.py index 480fe5ac52..b1526b0c4f 100644 --- a/allensdk/brain_observatory/behavior/behavior_project_cache/tables/ophys_sessions_table.py +++ b/allensdk/brain_observatory/behavior/behavior_project_cache/tables/ophys_sessions_table.py @@ -12,9 +12,7 @@ parse_num_cortical_structures, parse_num_depths, ) -from allensdk.core.dataframe_utils import ( - enforce_df_column_order -) +from allensdk.core.dataframe_utils import enforce_df_column_order from allensdk.brain_observatory.ophys.project_constants import VBO_METADATA_COLUMN_ORDER # noqa: E501 @@ -43,26 +41,15 @@ def __init__( self._index_column = index_column ProjectTable.__init__(self, df=df, suppress=suppress) OphysMixin.__init__(self) - self._df = enforce_df_column_order( - self._df, - VBO_METADATA_COLUMN_ORDER - ) + self._df = enforce_df_column_order(self._df, VBO_METADATA_COLUMN_ORDER) def postprocess_additional(self): # Add ophys specific information. - project_code_col = ( - "project_code_ophys" - if "project_code_ophys" in self._df.columns - else "project_code" - ) + project_code_col = "project_code_ophys" if "project_code_ophys" in self._df.columns else "project_code" self._df["num_targeted_structures"] = ( - self._df[project_code_col] - .apply(parse_num_cortical_structures) - .astype("Int64") - ) - self._df["num_depths_per_area"] = ( - self._df[project_code_col].apply(parse_num_depths).astype("Int64") + self._df[project_code_col].apply(parse_num_cortical_structures).astype("Int64") ) + self._df["num_depths_per_area"] = self._df[project_code_col].apply(parse_num_depths).astype("Int64") # Possibly explode and reindex self.__explode() @@ -70,11 +57,7 @@ def __explode(self): if self._index_column == "ophys_session_id": pass elif self._index_column == "ophys_experiment_id": - self._df = ( - self._df.reset_index() - .explode("ophys_experiment_id") - .set_index("ophys_experiment_id") - ) + self._df = self._df.reset_index().explode("ophys_experiment_id").set_index("ophys_experiment_id") else: self._logger.warning( f"Invalid value for `by`, '{self._index_column}', passed to " diff --git a/allensdk/brain_observatory/behavior/behavior_project_cache/tables/project_table.py b/allensdk/brain_observatory/behavior/behavior_project_cache/tables/project_table.py index c14380590f..30f95e9c21 100644 --- a/allensdk/brain_observatory/behavior/behavior_project_cache/tables/project_table.py +++ b/allensdk/brain_observatory/behavior/behavior_project_cache/tables/project_table.py @@ -6,8 +6,8 @@ class ProjectTable(ABC): """Class for storing and manipulating project-level data""" - def __init__(self, df: pd.DataFrame, - suppress: Optional[Iterable[str]] = None): + + def __init__(self, df: pd.DataFrame, suppress: Optional[Iterable[str]] = None): """ Parameters ---------- @@ -40,8 +40,7 @@ def postprocess(self): self.postprocess_additional() if self._suppress: - self._df.drop(columns=self._suppress, inplace=True, - errors="ignore") + self._df.drop(columns=self._suppress, inplace=True, errors="ignore") @abstractmethod def postprocess_additional(self): diff --git a/allensdk/brain_observatory/behavior/behavior_project_cache/tables/sessions_table.py b/allensdk/brain_observatory/behavior/behavior_project_cache/tables/sessions_table.py index cf977e3573..0432d8ea36 100644 --- a/allensdk/brain_observatory/behavior/behavior_project_cache/tables/sessions_table.py +++ b/allensdk/brain_observatory/behavior/behavior_project_cache/tables/sessions_table.py @@ -24,9 +24,7 @@ get_prior_exposures_to_session_type, add_experience_level_ophys, ) -from allensdk.core.dataframe_utils import ( - enforce_df_column_order -) +from allensdk.core.dataframe_utils import enforce_df_column_order from allensdk.brain_observatory.behavior.data_files import BehaviorStimulusFile from allensdk.brain_observatory.behavior.data_objects import StimulusTimestamps from allensdk.brain_observatory.behavior.data_objects.licks import Licks @@ -81,64 +79,39 @@ def __init__( self._include_trial_metrics = include_trial_metrics ProjectTable.__init__(self, df=df, suppress=suppress) OphysMixin.__init__(self) - self._df = enforce_df_column_order( - self._df, - VBO_METADATA_COLUMN_ORDER - ) + self._df = enforce_df_column_order(self._df, VBO_METADATA_COLUMN_ORDER) def postprocess_additional(self): # Add subject metadata - self._df["reporter_line"] = self._df["reporter_line"].apply( - ReporterLine.parse - ) - self._df["cre_line"] = self._df["full_genotype"].apply( - lambda x: FullGenotype(x).parse_cre_line() - ) - self._df["indicator"] = self._df["reporter_line"].apply( - lambda x: ReporterLine(x).parse_indicator() - ) + self._df["reporter_line"] = self._df["reporter_line"].apply(ReporterLine.parse) + self._df["cre_line"] = self._df["full_genotype"].apply(lambda x: FullGenotype(x).parse_cre_line()) + self._df["indicator"] = self._df["reporter_line"].apply(lambda x: ReporterLine(x).parse_indicator()) # add session number self.__add_session_number() # add prior exposure - self._df[ - "prior_exposures_to_session_type" - ] = get_prior_exposures_to_session_type(df=self._df) - self._df[ - "prior_exposures_to_image_set" - ] = get_prior_exposures_to_image_set(df=self._df) - self._df[ - "prior_exposures_to_omissions" - ] = get_prior_exposures_to_omissions( + self._df["prior_exposures_to_session_type"] = get_prior_exposures_to_session_type(df=self._df) + self._df["prior_exposures_to_image_set"] = get_prior_exposures_to_image_set(df=self._df) + self._df["prior_exposures_to_omissions"] = get_prior_exposures_to_omissions( df=self._df, fetch_api=self._fetch_api ) self._df = add_experience_level_ophys(self._df) - self._df["behavior_type"] = self._df["session_type"].apply( - parse_behavior_context - ) - self._df["image_set"] = self._df["session_type"].apply( - parse_stimulus_set - ) + self._df["behavior_type"] = self._df["session_type"].apply(parse_behavior_context) + self._df["image_set"] = self._df["session_type"].apply(parse_stimulus_set) if self._include_trial_metrics: # add trial metrics trial_metrics = multiprocessing_helper( target=self._get_trial_metrics_helper, behavior_session_ids=self._df.index.tolist(), - lims_engine=db_connection_creator( - fallback_credentials=LIMS_DB_CREDENTIAL_MAP - ), + lims_engine=db_connection_creator(fallback_credentials=LIMS_DB_CREDENTIAL_MAP), progress_bar_title="Getting trial metrics for each session", ) - trial_metrics = pd.DataFrame(trial_metrics).set_index( - "behavior_session_id" - ) - self._df = self._df.merge( - trial_metrics, left_index=True, right_index=True - ) + trial_metrics = pd.DataFrame(trial_metrics).set_index("behavior_session_id") + self._df = self._df.merge(trial_metrics, left_index=True, right_index=True) # Add data from ophys session if self._ophys_session_table is not None: @@ -152,9 +125,7 @@ def postprocess_additional(self): self._df = self._df.set_index("behavior_session_id") # Prioritize behavior date_of_acquisition - self._df["date_of_acquisition"] = self._df[ - "date_of_acquisition_behavior" - ] + self._df["date_of_acquisition"] = self._df["date_of_acquisition_behavior"] self._df = self._df.drop( ["date_of_acquisition_behavior", "date_of_acquisition_ophys"], axis=1, @@ -162,9 +133,7 @@ def postprocess_additional(self): # Enforce an integer type on due to there not being a value for # ophys_session_id for every behavior_session. Pandas defaults to # NaN here, changing the type to float unless otherwise fixed. - self._df["ophys_session_id"] = self._df["ophys_session_id"].astype( - "Int64" - ) + self._df["ophys_session_id"] = self._df["ophys_session_id"].astype("Int64") def __add_session_number(self): """Parses session number from session type and and adds to dataframe""" @@ -179,9 +148,7 @@ def parse_session_number(session_type: str): session_type = self._df["session_type"] session_type = session_type[session_type.notnull()] - self._df.loc[ - session_type.index, "session_number" - ] = session_type.apply(parse_session_number).astype('Int64') + self._df.loc[session_type.index, "session_number"] = session_type.apply(parse_session_number).astype("Int64") @staticmethod def _get_trial_metrics_helper(*args) -> Dict: @@ -189,12 +156,8 @@ def _get_trial_metrics_helper(*args) -> Dict: Meant to be called by a multiprocessing worker""" behavior_session_id, db_conn = args[0] - stimulus_file = BehaviorStimulusFile.from_lims( - behavior_session_id=behavior_session_id, db=db_conn - ) - stimulus_timestamps = StimulusTimestamps.from_stimulus_file( - stimulus_file=stimulus_file, monitor_delay=0.0 - ) + stimulus_file = BehaviorStimulusFile.from_lims(behavior_session_id=behavior_session_id, db=db_conn) + stimulus_timestamps = StimulusTimestamps.from_stimulus_file(stimulus_file=stimulus_file, monitor_delay=0.0) trials = Trials.from_stimulus_file( stimulus_file=stimulus_file, diff --git a/allensdk/brain_observatory/behavior/behavior_project_cache/tables/util/experiments_table_utils.py b/allensdk/brain_observatory/behavior/behavior_project_cache/tables/util/experiments_table_utils.py index c5461b372c..d137942411 100644 --- a/allensdk/brain_observatory/behavior/behavior_project_cache/tables/util/experiments_table_utils.py +++ b/allensdk/brain_observatory/behavior/behavior_project_cache/tables/util/experiments_table_utils.py @@ -1,8 +1,7 @@ import pandas as pd -def add_passive_flag_to_ophys_experiment_table( - experiments_table: pd.DataFrame) -> pd.DataFrame: +def add_passive_flag_to_ophys_experiment_table(experiments_table: pd.DataFrame) -> pd.DataFrame: """ adds a column to ophys_experiment_table that contains a Boolean indicating whether a session was passive or not based on session @@ -27,17 +26,16 @@ def add_passive_flag_to_ophys_experiment_table( experiments_table = experiments_table.copy(deep=True) - experiments_table['passive'] = False + experiments_table["passive"] = False session_25 = experiments_table.session_number.isin([2, 5]) passive_indices = experiments_table[session_25].index.values - experiments_table.loc[passive_indices, 'passive'] = True + experiments_table.loc[passive_indices, "passive"] = True return experiments_table -def add_image_set_to_experiment_table( - experiments_table: pd.DataFrame) -> pd.DataFrame: +def add_image_set_to_experiment_table(experiments_table: pd.DataFrame) -> pd.DataFrame: """ Adds a column 'image_set' to the experiment_table, determined based on the image set listed in the session_type column string @@ -61,9 +59,8 @@ def add_image_set_to_experiment_table( experiments_table = experiments_table.copy(deep=True) - experiments_table['image_set'] = [ - session_type[15] - if len(session_type) > 15 else 'N/A' - for session_type - in experiments_table.session_type.values.astype(str)] + experiments_table["image_set"] = [ + session_type[15] if len(session_type) > 15 else "N/A" + for session_type in experiments_table.session_type.values.astype(str) + ] return experiments_table diff --git a/allensdk/brain_observatory/behavior/behavior_project_cache/tables/util/image_presentation_utils.py b/allensdk/brain_observatory/behavior/behavior_project_cache/tables/util/image_presentation_utils.py index 0f73ba4e63..5cbf942017 100644 --- a/allensdk/brain_observatory/behavior/behavior_project_cache/tables/util/image_presentation_utils.py +++ b/allensdk/brain_observatory/behavior/behavior_project_cache/tables/util/image_presentation_utils.py @@ -21,13 +21,13 @@ def get_image_set(df: pd.DataFrame) -> pd.Series: -------- Series with index same as df whose values are image_set """ + def __get_image_set_name(session_type: Optional[str]): - match = re.match(r'.*images_(?P\w)', session_type) + match = re.match(r".*images_(?P\w)", session_type) if match is None: return None - return match.group('image_set') + return match.group("image_set") - session_type = df['session_type'][ - df['session_type'].notnull()] + session_type = df["session_type"][df["session_type"].notnull()] image_set = session_type.apply(__get_image_set_name) return image_set diff --git a/allensdk/brain_observatory/behavior/behavior_project_cache/tables/util/prior_exposure_processing.py b/allensdk/brain_observatory/behavior/behavior_project_cache/tables/util/prior_exposure_processing.py index a355304f43..65645e0d18 100644 --- a/allensdk/brain_observatory/behavior/behavior_project_cache/tables/util/prior_exposure_processing.py +++ b/allensdk/brain_observatory/behavior/behavior_project_cache/tables/util/prior_exposure_processing.py @@ -21,9 +21,7 @@ def get_prior_exposures_to_session_type(df: pd.DataFrame) -> pd.Series: Series with index same as df and values prior exposure counts to session type """ - return __get_prior_exposure_count(df=df, to=df["session_type"]).astype( - "Int64" - ) + return __get_prior_exposure_count(df=df, to=df["session_type"]).astype("Int64") def get_prior_exposures_to_image_set(df: pd.DataFrame) -> pd.Series: @@ -48,9 +46,7 @@ def get_prior_exposures_to_image_set(df: pd.DataFrame) -> pd.Series: return __get_prior_exposure_count(df=df, to=image_set).astype("Int64") -def get_prior_exposures_to_omissions( - df: pd.DataFrame, fetch_api: BehaviorProjectLimsApi -) -> pd.Series: +def get_prior_exposures_to_omissions(df: pd.DataFrame, fetch_api: BehaviorProjectLimsApi) -> pd.Series: """Get prior exposures to omissions Parameters @@ -103,38 +99,27 @@ def __session_contains_omissions( foraging_ids = habituation_sessions["foraging_id"].tolist() foraging_ids = [f"'{x}'" for x in foraging_ids] - mtrain_stage_parameters = fetch_api.get_behavior_stage_parameters( - foraging_ids=foraging_ids - ) + mtrain_stage_parameters = fetch_api.get_behavior_stage_parameters(foraging_ids=foraging_ids) return habituation_sessions.apply( lambda session: __session_contains_omissions( - mtrain_stage_parameters=mtrain_stage_parameters[ - session["foraging_id"] - ] + mtrain_stage_parameters=mtrain_stage_parameters[session["foraging_id"]] ), axis=1, ) habituation_sessions = __get_habituation_sessions(df=df) if not habituation_sessions.empty: - contains_omissions.loc[ - habituation_sessions.index - ] = __get_habituation_sessions_contain_omissions( + contains_omissions.loc[habituation_sessions.index] = __get_habituation_sessions_contain_omissions( habituation_sessions=habituation_sessions, fetch_api=fetch_api ) contains_omissions.loc[ - (df["session_type"].str.lower().str.contains("ophys")) - & (~df.index.isin(habituation_sessions.index)) + (df["session_type"].str.lower().str.contains("ophys")) & (~df.index.isin(habituation_sessions.index)) ] = True - return __get_prior_exposure_count( - df=df, to=contains_omissions, agg_method="cumsum" - ).astype("Int64") + return __get_prior_exposure_count(df=df, to=contains_omissions, agg_method="cumsum").astype("Int64") -def __get_prior_exposure_count( - df: pd.DataFrame, to: pd.Series, agg_method="cumcount" -) -> pd.Series: +def __get_prior_exposure_count(df: pd.DataFrame, to: pd.Series, agg_method="cumcount") -> pd.Series: """Returns prior exposures a subject had to something i.e can be prior exposures to a stimulus type, a image_set or omission @@ -210,9 +195,7 @@ def add_experience_level_ophys(input_df: pd.DataFrame) -> pd.DataFrame: # do not modify in place table = input_df.copy(deep=True) - session_number = ( - "session_number" if "session_number" in table.columns else "session" - ) + session_number = "session_number" if "session_number" in table.columns else "session" # add experience_level column with strings indicating relevant conditions table["experience_level"] = "None" @@ -262,7 +245,5 @@ def add_experience_level_simple(input_df: pd.DataFrame) -> pd.DataFrame: DataFrame with added "experience_level" column. """ tmp_exposures = input_df["prior_exposures_to_image_set"].fillna(0) - input_df["experience_level"] = np.where( - tmp_exposures == 0, "Novel", "Familiar" - ) + input_df["experience_level"] = np.where(tmp_exposures == 0, "Novel", "Familiar") return input_df diff --git a/allensdk/brain_observatory/behavior/behavior_project_cache/utils.py b/allensdk/brain_observatory/behavior/behavior_project_cache/utils.py index 16d1a1b449..d34e17f7fc 100644 --- a/allensdk/brain_observatory/behavior/behavior_project_cache/utils.py +++ b/allensdk/brain_observatory/behavior/behavior_project_cache/utils.py @@ -5,18 +5,17 @@ class BehaviorCloudCacheVersionException(Exception): pass -def version_check(manifest_version: str, - data_pipeline_version: str, - cmin: str, - cmax: str): +def version_check(manifest_version: str, data_pipeline_version: str, cmin: str, cmax: str): mver_parsed = semver.VersionInfo.parse(manifest_version) cmin_parsed = semver.VersionInfo.parse(cmin) cmax_parsed = semver.VersionInfo.parse(cmax) if (mver_parsed < cmin_parsed) | (mver_parsed >= cmax_parsed): - estr = (f"the manifest has manifest_version {manifest_version} but " - "this version of AllenSDK is compatible only with manifest " - f"versions {cmin} <= X < {cmax}. \n" - "Consider using a version of AllenSDK closer to the version " - f"used to release the data: {data_pipeline_version}") + estr = ( + f"the manifest has manifest_version {manifest_version} but " + "this version of AllenSDK is compatible only with manifest " + f"versions {cmin} <= X < {cmax}. \n" + "Consider using a version of AllenSDK closer to the version " + f"used to release the data: {data_pipeline_version}" + ) raise BehaviorCloudCacheVersionException(estr) diff --git a/allensdk/brain_observatory/behavior/behavior_session.py b/allensdk/brain_observatory/behavior/behavior_session.py index 556cf1f4dd..b2632ca4f2 100644 --- a/allensdk/brain_observatory/behavior/behavior_session.py +++ b/allensdk/brain_observatory/behavior/behavior_session.py @@ -108,9 +108,7 @@ def __init__( eye_tracking_table: Optional[EyeTrackingTable] = None, eye_tracking_rig_geometry: Optional[EyeTrackingRigGeometry] = None, ): - super().__init__( - name="behavior_session", value=None, is_value_self=True - ) + super().__init__(name="behavior_session", value=None, is_value_self=True) self._behavior_session_id = behavior_session_id self._licks = licks @@ -181,66 +179,38 @@ def from_json( else: monitor_delay = session_data["monitor_delay"] - behavior_session_id = BehaviorSessionId.from_json( - dict_repr=session_data - ) + behavior_session_id = BehaviorSessionId.from_json(dict_repr=session_data) - stimulus_file_lookup = stimulus_lookup_from_json( - dict_repr=session_data - ) + stimulus_file_lookup = stimulus_lookup_from_json(dict_repr=session_data) if "sync_file" in session_data: - sync_file = SyncFile.from_json( - dict_repr=session_data, permissive=sync_file_permissive - ) + sync_file = SyncFile.from_json(dict_repr=session_data, permissive=sync_file_permissive) else: sync_file = None if running_speed_load_from_multiple_stimulus_files: - running_acquisition = ( - RunningAcquisition.from_multiple_stimulus_files( - behavior_stimulus_file=( - BehaviorStimulusFile.from_json(dict_repr=session_data) - ), - mapping_stimulus_file=MappingStimulusFile.from_json( - dict_repr=session_data - ), - replay_stimulus_file=ReplayStimulusFile.from_json( - dict_repr=session_data - ), - sync_file=SyncFile.from_json(dict_repr=session_data), - ) + running_acquisition = RunningAcquisition.from_multiple_stimulus_files( + behavior_stimulus_file=(BehaviorStimulusFile.from_json(dict_repr=session_data)), + mapping_stimulus_file=MappingStimulusFile.from_json(dict_repr=session_data), + replay_stimulus_file=ReplayStimulusFile.from_json(dict_repr=session_data), + sync_file=SyncFile.from_json(dict_repr=session_data), ) raw_running_speed = RunningSpeed.from_multiple_stimulus_files( - behavior_stimulus_file=( - BehaviorStimulusFile.from_json(dict_repr=session_data) - ), - mapping_stimulus_file=MappingStimulusFile.from_json( - dict_repr=session_data - ), - replay_stimulus_file=ReplayStimulusFile.from_json( - dict_repr=session_data - ), + behavior_stimulus_file=(BehaviorStimulusFile.from_json(dict_repr=session_data)), + mapping_stimulus_file=MappingStimulusFile.from_json(dict_repr=session_data), + replay_stimulus_file=ReplayStimulusFile.from_json(dict_repr=session_data), sync_file=SyncFile.from_json(dict_repr=session_data), filtered=False, ) running_speed = RunningSpeed.from_multiple_stimulus_files( - behavior_stimulus_file=( - BehaviorStimulusFile.from_json(dict_repr=session_data) - ), - mapping_stimulus_file=MappingStimulusFile.from_json( - dict_repr=session_data - ), - replay_stimulus_file=ReplayStimulusFile.from_json( - dict_repr=session_data - ), + behavior_stimulus_file=(BehaviorStimulusFile.from_json(dict_repr=session_data)), + mapping_stimulus_file=MappingStimulusFile.from_json(dict_repr=session_data), + replay_stimulus_file=ReplayStimulusFile.from_json(dict_repr=session_data), sync_file=SyncFile.from_json(dict_repr=session_data), filtered=True, ) else: - behavior_stimulus_file = ( - stimulus_file_lookup.behavior_stimulus_file - ) + behavior_stimulus_file = stimulus_file_lookup.behavior_stimulus_file running_acquisition = RunningAcquisition.from_stimulus_file( behavior_stimulus_file=behavior_stimulus_file, @@ -284,35 +254,25 @@ def from_json( exclude_columns=stimulus_presentation_exclude_columns, trials=trials, ), - templates=Templates.from_stimulus_file( - stimulus_file=stimulus_file_lookup.behavior_stimulus_file - ), + templates=Templates.from_stimulus_file(stimulus_file=stimulus_file_lookup.behavior_stimulus_file), ) - date_of_acquisition = DateOfAcquisition.from_json( - dict_repr=session_data - ).validate( + date_of_acquisition = DateOfAcquisition.from_json(dict_repr=session_data).validate( stimulus_file=stimulus_file_lookup.behavior_stimulus_file, behavior_session_id=behavior_session_id.value, ) try: - eye_tracking_file = EyeTrackingFile.from_json( - dict_repr=session_data - ) + eye_tracking_file = EyeTrackingFile.from_json(dict_repr=session_data) except KeyError: eye_tracking_file = None if eye_tracking_file is None: # Return empty data to match what is returned by from_nwb. - eye_tracking_table = EyeTrackingTable( - eye_tracking=EyeTrackingTable._get_empty_df() - ) + eye_tracking_table = EyeTrackingTable(eye_tracking=EyeTrackingTable._get_empty_df()) eye_tracking_rig_geometry = None else: try: - eye_tracking_metadata_file = EyeTrackingMetadataFile.from_json( - dict_repr=session_data - ) + eye_tracking_metadata_file = EyeTrackingMetadataFile.from_json(dict_repr=session_data) except KeyError: eye_tracking_metadata_file = None @@ -324,9 +284,7 @@ def from_json( dilation_frames=eye_tracking_dilation_frames, ) - eye_tracking_rig_geometry = EyeTrackingRigGeometry.from_json( - dict_repr=session_data - ) + eye_tracking_rig_geometry = EyeTrackingRigGeometry.from_json(dict_repr=session_data) return cls( behavior_session_id=behavior_session_id, @@ -386,18 +344,14 @@ def from_lims( `BehaviorSession` instance """ if lims_db is None: - lims_db = db_connection_creator( - fallback_credentials=LIMS_DB_CREDENTIAL_MAP - ) + lims_db = db_connection_creator(fallback_credentials=LIMS_DB_CREDENTIAL_MAP) if monitor_delay is None: monitor_delay = cls._get_monitor_delay() if sync_file is None: try: - sync_file = SyncFile.from_lims( - db=lims_db, behavior_session_id=behavior_session_id - ) + sync_file = SyncFile.from_lims(db=lims_db, behavior_session_id=behavior_session_id) except OneResultExpectedError: sync_file = None @@ -405,10 +359,8 @@ def from_lims( stimulus_file_lookup = StimulusFileLookup() - stimulus_file_lookup.behavior_stimulus_file = ( - BehaviorStimulusFile.from_lims( - db=lims_db, behavior_session_id=behavior_session_id.value - ) + stimulus_file_lookup.behavior_stimulus_file = BehaviorStimulusFile.from_lims( + db=lims_db, behavior_session_id=behavior_session_id.value ) running_acquisition = RunningAcquisition.from_stimulus_file( @@ -428,9 +380,7 @@ def from_lims( filtered=True, ) - behavior_metadata = BehaviorMetadata.from_lims( - behavior_session_id=behavior_session_id, lims_db=lims_db - ) + behavior_metadata = BehaviorMetadata.from_lims(behavior_session_id=behavior_session_id, lims_db=lims_db) ( stimulus_timestamps, @@ -444,9 +394,7 @@ def from_lims( stimulus_file_lookup=stimulus_file_lookup, sync_file=sync_file, monitor_delay=monitor_delay, - project_code=ProjectCode.from_lims( - behavior_session_id=behavior_session_id.value, lims_db=lims_db - ), + project_code=ProjectCode.from_lims(behavior_session_id=behavior_session_id.value, lims_db=lims_db), ) if date_of_acquisition is None: @@ -458,20 +406,14 @@ def from_lims( behavior_session_id=behavior_session_id.value, ) - eye_tracking_file = EyeTrackingFile.from_lims( - db=lims_db, behavior_session_id=behavior_session_id.value - ) + eye_tracking_file = EyeTrackingFile.from_lims(db=lims_db, behavior_session_id=behavior_session_id.value) if eye_tracking_file is None: # Return empty data to match what is returned by from_nwb. - eye_tracking_table = EyeTrackingTable( - eye_tracking=EyeTrackingTable._get_empty_df() - ) + eye_tracking_table = EyeTrackingTable(eye_tracking=EyeTrackingTable._get_empty_df()) eye_tracking_rig_geometry = None else: - eye_tracking_video = EyeTrackingVideo.from_lims( - db=lims_db, behavior_session_id=behavior_session_id.value - ) + eye_tracking_video = EyeTrackingVideo.from_lims(db=lims_db, behavior_session_id=behavior_session_id.value) eye_tracking_metadata_file = None @@ -544,9 +486,7 @@ def from_nwb( rewards = Rewards.from_nwb(nwbfile=nwbfile) stimuli = Stimuli.from_nwb( nwbfile=nwbfile, - add_is_change_to_presentations_table=( - add_is_change_to_stimulus_presentations_table - ), + add_is_change_to_presentations_table=(add_is_change_to_stimulus_presentations_table), ) task_parameters = TaskParameters.from_nwb(nwbfile=nwbfile) trials = cls._trials_class().from_nwb(nwbfile=nwbfile) @@ -558,9 +498,7 @@ def from_nwb( message="This nwb file with identifier ", category=UserWarning, ) - eye_tracking_rig_geometry = EyeTrackingRigGeometry.from_nwb( - nwbfile=nwbfile - ) + eye_tracking_rig_geometry = EyeTrackingRigGeometry.from_nwb(nwbfile=nwbfile) with warnings.catch_warnings(): warnings.filterwarnings( action="ignore", @@ -629,9 +567,7 @@ def to_nwb( denoting the stimulus name in the presentations table """ if include_experiment_description: - experiment_description = get_expt_description( - session_type=self._get_session_type() - ) + experiment_description = get_expt_description(session_type=self._get_session_type()) else: experiment_description = None @@ -657,9 +593,7 @@ def to_nwb( self._rewards.to_nwb(nwbfile=nwbfile) self._stimuli.to_nwb( nwbfile=nwbfile, - presentations_stimulus_column_name=( - stimulus_presentations_stimulus_column_name - ), + presentations_stimulus_column_name=(stimulus_presentations_stimulus_column_name), ) self._task_parameters.to_nwb(nwbfile=nwbfile) self._trials.to_nwb(nwbfile=nwbfile) @@ -693,11 +627,7 @@ def list_data_attributes_and_methods(self) -> List[str]: attrs_and_methods_to_ignore.update(dir(NwbWritableInterface)) attrs_and_methods_to_ignore.update(dir(DataObject)) class_dir = dir(self) - attrs_and_methods = [ - r - for r in class_dir - if (r not in attrs_and_methods_to_ignore and not r.startswith("_")) - ] + attrs_and_methods = [r for r in class_dir if (r not in attrs_and_methods_to_ignore and not r.startswith("_"))] return attrs_and_methods # ========================= 'get' methods ========================== @@ -753,9 +683,7 @@ def get_rolling_performance_df(self) -> pd.DataFrame: """ return self._trials.rolling_performance - def get_performance_metrics( - self, engaged_trial_reward_rate_threshold: float = 2.0 - ) -> dict: + def get_performance_metrics(self, engaged_trial_reward_rate_threshold: float = 2.0) -> dict: """Get a dictionary containing a subject's behavior response summary data. @@ -855,43 +783,21 @@ def get_performance_metrics( # the 'hit_trial_count'. rpdf = self.get_rolling_performance_df() - engaged_trial_mask = ( - rpdf["reward_rate"] > engaged_trial_reward_rate_threshold - ) - performance_metrics["maximum_reward_rate"] = np.nanmax( - rpdf["reward_rate"].values - ) - performance_metrics[ - "engaged_trial_count" - ] = self._trials.get_engaged_trial_count( - engaged_trial_reward_rate_threshold=( - engaged_trial_reward_rate_threshold - ) + engaged_trial_mask = rpdf["reward_rate"] > engaged_trial_reward_rate_threshold + performance_metrics["maximum_reward_rate"] = np.nanmax(rpdf["reward_rate"].values) + performance_metrics["engaged_trial_count"] = self._trials.get_engaged_trial_count( + engaged_trial_reward_rate_threshold=(engaged_trial_reward_rate_threshold) ) performance_metrics["mean_hit_rate"] = rpdf["hit_rate"].mean() - performance_metrics["mean_hit_rate_uncorrected"] = rpdf[ - "hit_rate_raw" - ].mean() - performance_metrics["mean_hit_rate_engaged"] = rpdf["hit_rate"][ - engaged_trial_mask - ].mean() - performance_metrics["mean_false_alarm_rate"] = rpdf[ - "false_alarm_rate" - ].mean() - performance_metrics["mean_false_alarm_rate_uncorrected"] = rpdf[ - "false_alarm_rate_raw" - ].mean() - performance_metrics["mean_false_alarm_rate_engaged"] = rpdf[ - "false_alarm_rate" - ][engaged_trial_mask].mean() + performance_metrics["mean_hit_rate_uncorrected"] = rpdf["hit_rate_raw"].mean() + performance_metrics["mean_hit_rate_engaged"] = rpdf["hit_rate"][engaged_trial_mask].mean() + performance_metrics["mean_false_alarm_rate"] = rpdf["false_alarm_rate"].mean() + performance_metrics["mean_false_alarm_rate_uncorrected"] = rpdf["false_alarm_rate_raw"].mean() + performance_metrics["mean_false_alarm_rate_engaged"] = rpdf["false_alarm_rate"][engaged_trial_mask].mean() performance_metrics["mean_dprime"] = rpdf["rolling_dprime"].mean() - performance_metrics["mean_dprime_engaged"] = rpdf["rolling_dprime"][ - engaged_trial_mask - ].mean() + performance_metrics["mean_dprime_engaged"] = rpdf["rolling_dprime"][engaged_trial_mask].mean() performance_metrics["max_dprime"] = rpdf["rolling_dprime"].max() - performance_metrics["max_dprime_engaged"] = rpdf["rolling_dprime"][ - engaged_trial_mask - ].max() + performance_metrics["max_dprime_engaged"] = rpdf["rolling_dprime"][engaged_trial_mask].max() return performance_metrics @@ -940,11 +846,7 @@ def eye_tracking(self) -> Optional[pd.DataFrame]: :rtype: pandas.DataFrame """ - return ( - self._eye_tracking.value - if self._eye_tracking is not None - else None - ) + return self._eye_tracking.value if self._eye_tracking is not None else None @property def eye_tracking_rig_geometry(self) -> dict: @@ -1195,9 +1097,7 @@ def stimulus_templates(self) -> Optional[pd.DataFrame]: """ if self._stimuli.templates.image_template_key is not None: - return self._stimuli.templates.value[ - self._stimuli.templates.image_template_key - ].to_dataframe() + return self._stimuli.templates.value[self._stimuli.templates.image_template_key].to_dataframe() else: return None @@ -1504,9 +1404,7 @@ def _read_behavior_stimulus_timestamps( behavior_stimulus_file """ if sync_file is not None: - stimulus_timestamps = StimulusTimestamps.from_sync_file( - sync_file=sync_file, monitor_delay=monitor_delay - ) + stimulus_timestamps = StimulusTimestamps.from_sync_file(sync_file=sync_file, monitor_delay=monitor_delay) else: stimulus_timestamps = StimulusTimestamps.from_stimulus_file( stimulus_file=stimulus_file_lookup.behavior_stimulus_file, @@ -1551,9 +1449,7 @@ def _read_data_from_stimulus_file( monitor_delay=monitor_delay, ) - rewards = cls._read_rewards( - stimulus_file_lookup=stimulus_file_lookup, sync_file=sync_file - ) + rewards = cls._read_rewards(stimulus_file_lookup=stimulus_file_lookup, sync_file=sync_file) session_stimulus_timestamps = cls._read_session_timestamps( stimulus_file_lookup=stimulus_file_lookup, @@ -1582,9 +1478,7 @@ def _read_data_from_stimulus_file( else: stimuli = None - task_parameters = TaskParameters.from_stimulus_file( - stimulus_file=stimulus_file_lookup.behavior_stimulus_file - ) + task_parameters = TaskParameters.from_stimulus_file(stimulus_file=stimulus_file_lookup.behavior_stimulus_file) return ( session_stimulus_timestamps.subtract_monitor_delay(), @@ -1623,9 +1517,7 @@ def _read_eye_tracking_table( trim_after_spike=False, ) - stimulus_timestamps = StimulusTimestamps( - timestamps=frame_times.to_numpy(), monitor_delay=0.0 - ) + stimulus_timestamps = StimulusTimestamps(timestamps=frame_times.to_numpy(), monitor_delay=0.0) return EyeTrackingTable.from_data_file( data_file=eye_tracking_file, diff --git a/allensdk/brain_observatory/behavior/criteria.py b/allensdk/brain_observatory/behavior/criteria.py index 6743e3717d..6ce819744f 100644 --- a/allensdk/brain_observatory/behavior/criteria.py +++ b/allensdk/brain_observatory/behavior/criteria.py @@ -19,27 +19,34 @@ def two_out_of_three_aint_bad(session_summary): ordered ascending by training day, for at least the past 3 days. If dataframe is not properly ordered, criterion may not be correctly calculated. This function does not sort the data to preserve prior behavior (sorting column was not required by mtrain function). - The mtrain implementation created the required columns if they didn't exist, so + The mtrain implementation created the required columns if they didn't exist, so a more informative error is raised here to assist end-users in debugging. Returns: bool: True if criterion is met, False otherwise """ if len(session_summary) < 3: - raise DataFrameIndexError("Not enough data in session_summary frame. " - "Expected >= 3 rows, got {}".format(len(session_summary))) + raise DataFrameIndexError( + "Not enough data in session_summary frame. Expected >= 3 rows, got {}".format(len(session_summary)) + ) try: last_three = session_summary["dprime_peak"][-3:] except KeyError as e: - raise DataFrameKeyError("Failed accessing last three values in colum" - "'dprime_peak'.\n df length={}, df columns={}\n" - .format(len(session_summary), list(session_summary)), e) - logger.info('dprime_peak over last three days: {}'.format(list(last_three))) + raise DataFrameKeyError( + "Failed accessing last three values in colum'dprime_peak'.\n df length={}, df columns={}\n".format( + len(session_summary), list(session_summary) + ), + e, + ) + logger.info("dprime_peak over last three days: {}".format(list(last_three))) criteria = bool( - ((last_three > 2).sum() > 1) # at least two of the last three + ( + (last_three > 2).sum() > 1 # at least two of the last three + ) ) logger.info("'Two out of three ain't bad' criteria met: '{}'".format(criteria)) return criteria + def yesterday_was_good(session_summary): """Returns true if the last day showed a peak d-prime above 2 Args: @@ -47,20 +54,24 @@ def yesterday_was_good(session_summary): ordered ascending by training day, for at least 1 day. If dataframe is not properly ordered, criterion may not be correctly calculated. This function does not sort the data to preserve prior behavior (sorting column was not required by mtrain function). - The mtrain implementation created the required columns if they didn't exist, so + The mtrain implementation created the required columns if they didn't exist, so a more informative error is raised here to assist end-users in debugging. Returns: bool: True if criterion is met, False otherwise """ if len(session_summary) < 1: - raise DataFrameIndexError("Not enough data in session_summary frame. " - "Expected >= 1 row(s), got {}".format(len(session_summary))) + raise DataFrameIndexError( + "Not enough data in session_summary frame. Expected >= 1 row(s), got {}".format(len(session_summary)) + ) try: - last_day = session_summary['dprime_peak'].iloc[-1] + last_day = session_summary["dprime_peak"].iloc[-1] except KeyError as e: - raise DataFrameKeyError("Failed accessing last three values in colum" - "'dprime_peak'.\n df length={}, df columns={}\n" - .format(len(session_summary), list(session_summary)), e) + raise DataFrameKeyError( + "Failed accessing last three values in colum'dprime_peak'.\n df length={}, df columns={}\n".format( + len(session_summary), list(session_summary) + ), + e, + ) criteria = bool(last_day > 2) logger.info("'Yesterday was good' criteria met: {}".format(criteria)) return criteria @@ -74,23 +85,26 @@ def no_response_bias(session_summary): ordered ascending by training day, for at least 1 day. If dataframe is not properly ordered, criterion may not be correctly calculated. This function does not sort the data to preserve prior behavior (sorting column was not required by mtrain function). - The mtrain implementation created the required columns if they didn't exist, so + The mtrain implementation created the required columns if they didn't exist, so a more informative error is raised here to assist end-users in debugging. Returns: bool: True if criterion is met, False otherwise """ if len(session_summary) < 1: - raise DataFrameIndexError("Not enough data in session_summary frame. " - "Expected >= 1 row(s), got {}".format(len(session_summary))) + raise DataFrameIndexError( + "Not enough data in session_summary frame. Expected >= 1 row(s), got {}".format(len(session_summary)) + ) try: - response_bias = session_summary['response_bias'].iloc[-1] + response_bias = session_summary["response_bias"].iloc[-1] except KeyError as e: - raise DataFrameKeyError("Failed accessing last values in colum" - "'response_bias'.\n df length={}, df columns={}\n" - .format(len(session_summary), list(session_summary)), e) + raise DataFrameKeyError( + "Failed accessing last values in colum'response_bias'.\n df length={}, df columns={}\n".format( + len(session_summary), list(session_summary) + ), + e, + ) criteria = (response_bias < 0.9) & (response_bias > 0.1) - logger.info("'No response bias' criteria met: {} (response bias={})" - .format(criteria, response_bias)) + logger.info("'No response bias' criteria met: {} (response bias={})".format(criteria, response_bias)) return criteria @@ -102,22 +116,26 @@ def whole_lotta_trials(session_summary): ordered ascending by training day, for at least 1 day. If dataframe is not properly ordered, criterion may not be correctly calculated. This function does not sort the data to preserve prior behavior (sorting column was not required by mtrain function). - The mtrain implementation created the required columns if they didn't exist, so + The mtrain implementation created the required columns if they didn't exist, so a more informative error is raised here to assist end-users in debugging. Returns: bool: True if criterion is met, False otherwise """ if len(session_summary) < 1: - raise DataFrameIndexError("Not enough data in session_summary frame. " - "Expected >= 1 row(s), got {}".format(len(session_summary))) + raise DataFrameIndexError( + "Not enough data in session_summary frame. Expected >= 1 row(s), got {}".format(len(session_summary)) + ) try: - num_trials = session_summary['num_contingent_trials'].iloc[-1] + num_trials = session_summary["num_contingent_trials"].iloc[-1] except KeyError as e: - raise DataFrameKeyError("Failed accessing last values in colum" - "'num_contingent_trials'.\n df length={}, df columns={}\n" - .format(len(session_summary), list(session_summary)), e) + raise DataFrameKeyError( + "Failed accessing last values in colum'num_contingent_trials'.\n df length={}, df columns={}\n".format( + len(session_summary), list(session_summary) + ), + e, + ) criteria = num_trials > 300 - logger.info("'Trials > 300' criteria met: {} (n trials={})".format(criteria, num_trials)) + logger.info("'Trials > 300' criteria met: {} (n trials={})".format(criteria, num_trials)) return criteria @@ -131,37 +149,39 @@ def mostly_useful(trials): Returns: bool: True if criterion is met, False otherwise """ - if len(trials) == 0: # empty df would return true, but shouldn't + if len(trials) == 0: # empty df would return true, but shouldn't return False - last_day = trials['training_day'].max() - group = trials.groupby('training_day').get_group(last_day) - trial_fractions = group.groupby('trial_type')['trial_length'].sum() \ - / group['trial_length'].sum() - aborted = trial_fractions['aborted'] + last_day = trials["training_day"].max() + group = trials.groupby("training_day").get_group(last_day) + trial_fractions = group.groupby("trial_type")["trial_length"].sum() / group["trial_length"].sum() + aborted = trial_fractions["aborted"] criteria = aborted < 0.5 - logger.info("Fewer than half the trials were aborted on the last training day: {} " - "(% aborted trials={})".format(criteria, aborted)) + logger.info( + "Fewer than half the trials were aborted on the last training day: {} (% aborted trials={})".format( + criteria, aborted + ) + ) return criteria def consistency_is_key(session_summary): - '''need some way to judge consistency of various parameters + """need some way to judge consistency of various parameters - dprime - num trials - hit rate - fa rate - lick timing - ''' + """ raise NotImplementedError def consistent_behavior_within_session(session_summary): - '''need some way to measure consistent performance within a session + """need some way to measure consistent performance within a session - compare peak to overall dprime? - variance in rolling window dprime? - ''' + """ raise NotImplementedError @@ -186,25 +206,29 @@ def meets_engagement_criteria(session_summary): ordered ascending by training day, for at least 3 days. If dataframe is not properly ordered, criterion may not be correctly calculated. This function does not sort the data to preserve prior behavior (sorting column was not required by mtrain function) - The mtrain implementation created the required columns if they didn't exist, so + The mtrain implementation created the required columns if they didn't exist, so a more informative error is raised here to assist end-users in debugging. Returns: bool: True if criterion is met, False otherwise """ criteria = 3 if len(session_summary) < 3: - raise DataFrameIndexError("Not enough data in session_summary frame. " - "Expected >= 3 rows, got {}".format(len(session_summary))) + raise DataFrameIndexError( + "Not enough data in session_summary frame. Expected >= 3 rows, got {}".format(len(session_summary)) + ) try: - session_summary['engagement_criteria'] = ( - (session_summary['dprime_peak'] > 1.0) - & (session_summary['num_engaged_trials'] > 100) + session_summary["engagement_criteria"] = (session_summary["dprime_peak"] > 1.0) & ( + session_summary["num_engaged_trials"] > 100 ) - engaged_days = session_summary['engagement_criteria'].iloc[-3:].sum() + engaged_days = session_summary["engagement_criteria"].iloc[-3:].sum() except KeyError as e: - raise DataFrameKeyError("Failed accessing columns 'dprime_peak' and/or " - "'num_engaged_trials' for 3 days.\n df length={}, df columns={}\n" - .format(len(session_summary), list(session_summary)), e) + raise DataFrameKeyError( + "Failed accessing columns 'dprime_peak' and/or " + "'num_engaged_trials' for 3 days.\n df length={}, df columns={}\n".format( + len(session_summary), list(session_summary) + ), + e, + ) return engaged_days == criteria @@ -213,4 +237,4 @@ def summer_over(trials): Returns true if the maximum value of 'training_day' in the trials dataframe is >= 40, else false. """ - return trials['training_day'].max() >= 40 + return trials["training_day"].max() >= 40 diff --git a/allensdk/brain_observatory/behavior/data_files/__init__.py b/allensdk/brain_observatory/behavior/data_files/__init__.py index 3bf5fb9210..4896e4dcbb 100644 --- a/allensdk/brain_observatory/behavior/data_files/__init__.py +++ b/allensdk/brain_observatory/behavior/data_files/__init__.py @@ -1,6 +1,7 @@ from allensdk.brain_observatory.behavior.data_files.stimulus_file import ( # noqa F401 BehaviorStimulusFile, ReplayStimulusFile, - MappingStimulusFile) + MappingStimulusFile, +) from allensdk.brain_observatory.behavior.data_files.sync_file import SyncFile # noqa E501, F401 diff --git a/allensdk/brain_observatory/behavior/data_files/demix_file.py b/allensdk/brain_observatory/behavior/data_files/demix_file.py index c60baff5e3..4ba8d94898 100644 --- a/allensdk/brain_observatory/behavior/data_files/demix_file.py +++ b/allensdk/brain_observatory/behavior/data_files/demix_file.py @@ -36,9 +36,7 @@ def from_json(cls, dict_repr: dict) -> "DemixFile": @classmethod @cached(cache=LRUCache(maxsize=10), key=from_lims_cache_key) - def from_lims( - cls, db: PostgresQueryMixin, ophys_experiment_id: Union[int, str] - ) -> "DemixFile": + def from_lims(cls, db: PostgresQueryMixin, ophys_experiment_id: Union[int, str]) -> "DemixFile": query = """ SELECT wkf.storage_directory || wkf.filename AS demix_file FROM ophys_experiments oe @@ -48,9 +46,7 @@ def from_lims( WHERE wkf.attachable_type = 'OphysExperiment' AND wkft.name = 'DemixedTracesFile' AND oe.id = {}; - """.format( - ophys_experiment_id - ) + """.format(ophys_experiment_id) filepath = db.fetchone(query, strict=True) return cls(filepath=filepath) diff --git a/allensdk/brain_observatory/behavior/data_files/dff_file.py b/allensdk/brain_observatory/behavior/data_files/dff_file.py index fdd371b7fe..dc7ae5f419 100644 --- a/allensdk/brain_observatory/behavior/data_files/dff_file.py +++ b/allensdk/brain_observatory/behavior/data_files/dff_file.py @@ -37,10 +37,7 @@ def from_json(cls, dict_repr: dict) -> "DFFFile": @classmethod @cached(cache=LRUCache(maxsize=10), key=from_lims_cache_key) - def from_lims( - cls, db: PostgresQueryMixin, - ophys_experiment_id: Union[int, str] - ) -> "DFFFile": + def from_lims(cls, db: PostgresQueryMixin, ophys_experiment_id: Union[int, str]) -> "DFFFile": query = """ SELECT wkf.storage_directory || wkf.filename AS dff_file FROM ophys_experiments oe @@ -55,8 +52,8 @@ def from_lims( @staticmethod def load_data(filepath: Union[str, Path]) -> pd.DataFrame: - with h5py.File(filepath, 'r') as raw_file: - traces = np.asarray(raw_file['data'], dtype=np.float64) - roi_names = np.asarray(raw_file['roi_names']) - idx = pd.Index(roi_names, name='cell_roi_id').astype('int64') - return pd.DataFrame({'dff': [x for x in traces]}, index=idx) + with h5py.File(filepath, "r") as raw_file: + traces = np.asarray(raw_file["data"], dtype=np.float64) + roi_names = np.asarray(raw_file["roi_names"]) + idx = pd.Index(roi_names, name="cell_roi_id").astype("int64") + return pd.DataFrame({"dff": [x for x in traces]}, index=idx) diff --git a/allensdk/brain_observatory/behavior/data_files/event_detection_file.py b/allensdk/brain_observatory/behavior/data_files/event_detection_file.py index 00221ac6a6..810189690e 100644 --- a/allensdk/brain_observatory/behavior/data_files/event_detection_file.py +++ b/allensdk/brain_observatory/behavior/data_files/event_detection_file.py @@ -38,33 +38,25 @@ def from_json(cls, dict_repr: dict) -> "EventDetectionFile": @classmethod @cached(cache=LRUCache(maxsize=10), key=from_lims_cache_key) - def from_lims( - cls, db: PostgresQueryMixin, - ophys_experiment_id: Union[int, str] - ) -> "EventDetectionFile": - query = f''' + def from_lims(cls, db: PostgresQueryMixin, ophys_experiment_id: Union[int, str]) -> "EventDetectionFile": + query = f""" SELECT wkf.storage_directory || wkf.filename AS event_detection_filepath FROM ophys_experiments oe LEFT JOIN well_known_files wkf ON wkf.attachable_id = oe.id JOIN well_known_file_types wkft ON wkf.well_known_file_type_id = wkft.id WHERE wkft.name = 'OphysEventTraceFile' AND oe.id = {ophys_experiment_id}; - ''' # noqa E501 + """ # noqa E501 filepath = safe_system_path(db.fetchone(query, strict=True)) return cls(filepath=filepath) @staticmethod - def load_data(filepath: Union[str, Path]) -> \ - Tuple[np.ndarray, pd.DataFrame]: - with h5py.File(filepath, 'r') as f: - events = f['events'][:] - lambdas = f['lambdas'][:] - noise_stds = f['noise_stds'][:] - roi_ids = f['roi_names'][:] - - df = pd.DataFrame({ - 'lambda': lambdas, - 'noise_std': noise_stds, - 'cell_roi_id': roi_ids - }) + def load_data(filepath: Union[str, Path]) -> Tuple[np.ndarray, pd.DataFrame]: + with h5py.File(filepath, "r") as f: + events = f["events"][:] + lambdas = f["lambdas"][:] + noise_stds = f["noise_stds"][:] + roi_ids = f["roi_names"][:] + + df = pd.DataFrame({"lambda": lambdas, "noise_std": noise_stds, "cell_roi_id": roi_ids}) return events, df diff --git a/allensdk/brain_observatory/behavior/data_files/eye_tracking_file.py b/allensdk/brain_observatory/behavior/data_files/eye_tracking_file.py index 623a941284..45ac32573f 100644 --- a/allensdk/brain_observatory/behavior/data_files/eye_tracking_file.py +++ b/allensdk/brain_observatory/behavior/data_files/eye_tracking_file.py @@ -3,8 +3,7 @@ import pandas as pd -from allensdk.brain_observatory.behavior.eye_tracking_processing import \ - load_eye_tracking_hdf +from allensdk.brain_observatory.behavior.eye_tracking_processing import load_eye_tracking_hdf from allensdk.internal.api import PostgresQueryMixin, OneResultExpectedError from allensdk.internal.core.lims_utilities import safe_system_path from allensdk.internal.core import DataFile @@ -24,10 +23,7 @@ def from_json(cls, dict_repr: dict) -> "EyeTrackingFile": return cls(filepath=filepath) @classmethod - def from_lims( - cls, db: PostgresQueryMixin, - behavior_session_id: Union[int, str] - ) -> "EyeTrackingFile": + def from_lims(cls, db: PostgresQueryMixin, behavior_session_id: Union[int, str]) -> "EyeTrackingFile": query = f""" SELECT wkf.storage_directory || wkf.filename AS eye_tracking_file FROM behavior_sessions bs diff --git a/allensdk/brain_observatory/behavior/data_files/eye_tracking_metadata_file.py b/allensdk/brain_observatory/behavior/data_files/eye_tracking_metadata_file.py index dee0338333..df86a47853 100644 --- a/allensdk/brain_observatory/behavior/data_files/eye_tracking_metadata_file.py +++ b/allensdk/brain_observatory/behavior/data_files/eye_tracking_metadata_file.py @@ -15,7 +15,7 @@ def __init__(self, filepath: Union[str, pathlib.Path]): @staticmethod def load_data(filepath: Union[str, pathlib.Path]) -> dict: - with open(filepath, 'rb') as in_file: + with open(filepath, "rb") as in_file: return json.load(in_file) @classmethod @@ -24,12 +24,9 @@ def file_path_key(cls) -> str: @classmethod def from_lims(cls): - raise NotImplementedError( - "from_lims not yet supported for EyeTrackingMetadataFile") + raise NotImplementedError("from_lims not yet supported for EyeTrackingMetadataFile") @classmethod - def from_json( - cls, - dict_repr: dict) -> "EyeTrackingMetadataFile": + def from_json(cls, dict_repr: dict) -> "EyeTrackingMetadataFile": filepath = dict_repr[cls.file_path_key()] return cls(filepath=filepath) diff --git a/allensdk/brain_observatory/behavior/data_files/eye_tracking_video.py b/allensdk/brain_observatory/behavior/data_files/eye_tracking_video.py index ab803cb2a7..ffb1e33850 100644 --- a/allensdk/brain_observatory/behavior/data_files/eye_tracking_video.py +++ b/allensdk/brain_observatory/behavior/data_files/eye_tracking_video.py @@ -11,10 +11,7 @@ class EyeTrackingVideo(DataFile): @classmethod def from_lims( - cls, - db: PostgresQueryMixin, - behavior_session_id: Union[int, str], - session_type: str = 'OphysSession' + cls, db: PostgresQueryMixin, behavior_session_id: Union[int, str], session_type: str = "OphysSession" ) -> "EyeTrackingVideo": """ @@ -29,10 +26,9 @@ def from_lims( ------- `EyeTrackingVideo` instance """ - valid_session_types = ('OphysSession', 'EcephysSession') + valid_session_types = ("OphysSession", "EcephysSession") if session_type not in valid_session_types: - raise ValueError(f'Session type must be one of ' - f'{valid_session_types}') + raise ValueError(f"Session type must be one of {valid_session_types}") query = f""" SELECT wkf.storage_directory || wkf.filename AS eye_tracking_file FROM behavior_sessions bs diff --git a/allensdk/brain_observatory/behavior/data_files/neuropil_corrected_file.py b/allensdk/brain_observatory/behavior/data_files/neuropil_corrected_file.py index b60a0ed822..c2ee71dc9a 100644 --- a/allensdk/brain_observatory/behavior/data_files/neuropil_corrected_file.py +++ b/allensdk/brain_observatory/behavior/data_files/neuropil_corrected_file.py @@ -36,9 +36,7 @@ def from_json(cls, dict_repr: dict) -> "NeuropilCorrectedFile": @classmethod @cached(cache=LRUCache(maxsize=10), key=from_lims_cache_key) - def from_lims( - cls, db: PostgresQueryMixin, ophys_experiment_id: Union[int, str] - ) -> "NeuropilCorrectedFile": + def from_lims(cls, db: PostgresQueryMixin, ophys_experiment_id: Union[int, str]) -> "NeuropilCorrectedFile": query = """ SELECT wkf.storage_directory || wkf.filename AS \ neuropil_corrected_file @@ -49,9 +47,7 @@ def from_lims( WHERE wkf.attachable_type = 'OphysExperiment' AND wkft.name = 'NeuropilCorrection' AND oe.id = {}; - """.format( - ophys_experiment_id - ) + """.format(ophys_experiment_id) filepath = db.fetchone(query, strict=True) return cls(filepath=filepath) diff --git a/allensdk/brain_observatory/behavior/data_files/neuropil_file.py b/allensdk/brain_observatory/behavior/data_files/neuropil_file.py index 846a76b1cd..787ff39d12 100644 --- a/allensdk/brain_observatory/behavior/data_files/neuropil_file.py +++ b/allensdk/brain_observatory/behavior/data_files/neuropil_file.py @@ -36,9 +36,7 @@ def from_json(cls, dict_repr: dict) -> "NeuropilFile": @classmethod @cached(cache=LRUCache(maxsize=10), key=from_lims_cache_key) - def from_lims( - cls, db: PostgresQueryMixin, ophys_experiment_id: Union[int, str] - ) -> "NeuropilFile": + def from_lims(cls, db: PostgresQueryMixin, ophys_experiment_id: Union[int, str]) -> "NeuropilFile": query = """ SELECT wkf.storage_directory || wkf.filename AS neuropil_file FROM ophys_experiments oe @@ -48,9 +46,7 @@ def from_lims( WHERE wkf.attachable_type = 'OphysExperiment' AND wkft.name = 'OphysNeuropilTraces' AND oe.id = {}; - """.format( - ophys_experiment_id - ) + """.format(ophys_experiment_id) filepath = db.fetchone(query, strict=True) return cls(filepath=filepath) diff --git a/allensdk/brain_observatory/behavior/data_files/rigid_motion_transform_file.py b/allensdk/brain_observatory/behavior/data_files/rigid_motion_transform_file.py index 07f121d556..ad6e09ca37 100644 --- a/allensdk/brain_observatory/behavior/data_files/rigid_motion_transform_file.py +++ b/allensdk/brain_observatory/behavior/data_files/rigid_motion_transform_file.py @@ -36,10 +36,7 @@ def from_json(cls, dict_repr: dict) -> "RigidMotionTransformFile": @classmethod @cached(cache=LRUCache(maxsize=10), key=from_lims_cache_key) - def from_lims( - cls, db: PostgresQueryMixin, - ophys_experiment_id: Union[int, str] - ) -> "RigidMotionTransformFile": + def from_lims(cls, db: PostgresQueryMixin, ophys_experiment_id: Union[int, str]) -> "RigidMotionTransformFile": query = """ SELECT wkf.storage_directory || wkf.filename AS transform_file FROM ophys_experiments oe @@ -56,4 +53,4 @@ def from_lims( @staticmethod def load_data(filepath: Union[str, Path]) -> pd.DataFrame: motion_correction = pd.read_csv(filepath) - return motion_correction[['x', 'y']] + return motion_correction[["x", "y"]] diff --git a/allensdk/brain_observatory/behavior/data_files/stimulus_file.py b/allensdk/brain_observatory/behavior/data_files/stimulus_file.py index b222a4052e..c9e9cfef67 100644 --- a/allensdk/brain_observatory/behavior/data_files/stimulus_file.py +++ b/allensdk/brain_observatory/behavior/data_files/stimulus_file.py @@ -69,9 +69,7 @@ def _from_json(cls, stimulus_file_path: str) -> "_StimulusFile": @classmethod @cached(cache=LRUCache(maxsize=10), key=from_lims_cache_key) - def from_lims( - cls, db: PostgresQueryMixin, behavior_session_id: Union[int, str] - ) -> "_StimulusFile": + def from_lims(cls, db: PostgresQueryMixin, behavior_session_id: Union[int, str]) -> "_StimulusFile": raise NotImplementedError() @staticmethod @@ -97,12 +95,8 @@ def file_path_key(cls) -> str: @classmethod @cached(cache=LRUCache(maxsize=10), key=from_lims_cache_key) - def from_lims( - cls, db: PostgresQueryMixin, behavior_session_id: Union[int, str] - ) -> "BehaviorStimulusFile": - query = BEHAVIOR_STIMULUS_FILE_QUERY_TEMPLATE.format( - behavior_session_id=behavior_session_id - ) + def from_lims(cls, db: PostgresQueryMixin, behavior_session_id: Union[int, str]) -> "BehaviorStimulusFile": + query = BEHAVIOR_STIMULUS_FILE_QUERY_TEMPLATE.format(behavior_session_id=behavior_session_id) filepath = db.fetchone(query, strict=True) return cls(filepath=filepath) @@ -121,10 +115,7 @@ def _validate_frame_data(self): msg += "self.data['items']['behavior'] not present\n" else: if "intervalsms" not in self.data["items"]["behavior"]: - msg += ( - "self.data['items']['behavior']['intervalsms'] " - "not present\n" - ) + msg += "self.data['items']['behavior']['intervalsms'] not present\n" if len(msg) > 0: full_msg = f"When getting num_frames from {type(self)}\n" @@ -161,9 +152,7 @@ def date_of_acquisition(self) -> datetime.datetime: """ assert isinstance(self.data, dict) if "start_time" not in self.data: - raise KeyError( - "No 'start_time' listed in pickle file " f"{self.filepath}" - ) + raise KeyError(f"No 'start_time' listed in pickle file {self.filepath}") return copy.deepcopy(self.data["start_time"]) @@ -222,9 +211,7 @@ def stimulus_name(self) -> str: the mouse. """ try: - stimulus_name = Path( - self.stimuli["images"]["image_set"] - ).stem.split(".")[0] + stimulus_name = Path(self.stimuli["images"]["image_set"]).stem.split(".")[0] except KeyError: # if we can't find the images key in the stimuli, check for the # name ``grating`` as the stimulus. If not add generic @@ -255,21 +242,15 @@ def _retrieve_from_params(self, key_name: str): param_value = None if "params" in self.data["items"]["behavior"]: if key_name in self.data["items"]["behavior"]["params"]: - param_value = self.data["items"]["behavior"]["params"][ - key_name - ] + param_value = self.data["items"]["behavior"]["params"][key_name] cl_value = None if "cl_params" in self.data["items"]["behavior"]: if key_name in self.data["items"]["behavior"]["cl_params"]: - cl_value = self.data["items"]["behavior"]["cl_params"][ - key_name - ] + cl_value = self.data["items"]["behavior"]["cl_params"][key_name] if cl_value is None and param_value is None: - raise RuntimeError( - f"Could not find {key_name} in pickle file " f"{self.filepath}" - ) + raise RuntimeError(f"Could not find {key_name} in pickle file {self.filepath}") if param_value is None: return cl_value @@ -328,8 +309,7 @@ def stimuli(self) -> Dict[str, Tuple[str, Union[str, int], int, int]]: def validate(self) -> "BehaviorStimulusFile": if "items" not in self.data or "behavior" not in self.data["items"]: raise MalformedStimulusFileError( - f'Expected to find key "behavior" in "items" dict. ' - f'Found {self.data["items"].keys()}' + f'Expected to find key "behavior" in "items" dict. Found {self.data["items"].keys()}' ) return self @@ -351,9 +331,7 @@ class StimulusFileReadableInterface(abc.ABC): @classmethod @abc.abstractmethod - def from_stimulus_file( - cls, stimulus_file: BehaviorStimulusFile - ) -> "DataObject": + def from_stimulus_file(cls, stimulus_file: BehaviorStimulusFile) -> "DataObject": """Populate a DataObject from the stimulus file Returns @@ -376,9 +354,7 @@ def __init__(self): @property def behavior_stimulus_file(self) -> BehaviorStimulusFile: if "behavior" not in self._values: - raise ValueError( - "This StimulusFileLookup has no " "BehaviorStimulusFile" - ) + raise ValueError("This StimulusFileLookup has no BehaviorStimulusFile") return self._values["behavior"] @behavior_stimulus_file.setter @@ -394,27 +370,21 @@ def behavior_stimulus_file(self, value: BehaviorStimulusFile): @property def replay_stimulus_file(self) -> ReplayStimulusFile: if "replay" not in self._values: - raise ValueError( - "This StimulusFileLookup has no " "ReplayStimulusFile" - ) + raise ValueError("This StimulusFileLookup has no ReplayStimulusFile") return self._values["replay"] @replay_stimulus_file.setter def replay_stimulus_file(self, value: ReplayStimulusFile): if not isinstance(value, ReplayStimulusFile): raise ValueError( - "Trying to set replay_stimulus_file to " - f"value of type {type(value)}; type should " - "be ReplayStimulusFile" + f"Trying to set replay_stimulus_file to value of type {type(value)}; type should be ReplayStimulusFile" ) self._values["replay"] = value @property def mapping_stimulus_file(self) -> MappingStimulusFile: if "mapping" not in self._values: - raise ValueError( - "This StimulusFileLookup has no " "MappingStimulusFile" - ) + raise ValueError("This StimulusFileLookup has no MappingStimulusFile") return self._values["mapping"] @mapping_stimulus_file.setter diff --git a/allensdk/brain_observatory/behavior/data_files/sync_file.py b/allensdk/brain_observatory/behavior/data_files/sync_file.py index 6099375e53..9b0cdb6abd 100644 --- a/allensdk/brain_observatory/behavior/data_files/sync_file.py +++ b/allensdk/brain_observatory/behavior/data_files/sync_file.py @@ -14,9 +14,7 @@ def _get_sync_file_query_template(behavior_session_id: int): - - """Query returns path to sync timing file associated with behavior session - """ + """Query returns path to sync timing file associated with behavior session""" SYNC_FILE_QUERY_TEMPLATE = f""" SELECT wkf.storage_directory || wkf.filename AS sync_file FROM behavior_sessions bs @@ -65,33 +63,28 @@ def permissive(self) -> bool: # pragma: no cover @classmethod @cached(cache=LRUCache(maxsize=10), key=from_json_cache_key) - def from_json(cls, - dict_repr: dict, - permissive: bool = False) -> "SyncFile": + def from_json(cls, dict_repr: dict, permissive: bool = False) -> "SyncFile": filepath = dict_repr["sync_file"] return cls(filepath=filepath, permissive=permissive) @classmethod @cached(cache=LRUCache(maxsize=10), key=from_lims_cache_key) def from_lims( - cls, db: PostgresQueryMixin, - behavior_session_id: Union[int, str], - permissive: bool = False + cls, db: PostgresQueryMixin, behavior_session_id: Union[int, str], permissive: bool = False ) -> "SyncFile": - query = _get_sync_file_query_template( - behavior_session_id=behavior_session_id) + query = _get_sync_file_query_template(behavior_session_id=behavior_session_id) filepath = db.fetchone(query, strict=True) return cls(filepath=filepath, permissive=permissive) @staticmethod - def load_data(filepath: Union[str, Path], - permissive: bool = False) -> dict: + def load_data(filepath: Union[str, Path], permissive: bool = False) -> dict: filepath = safe_system_path(file_name=filepath) return get_sync_data(sync_path=filepath, permissive=permissive) class SyncFileReadableInterface(abc.ABC): """Marks a data object as readable from sync file""" + @classmethod @abc.abstractmethod def from_sync_file(cls, *args) -> "DataObject": diff --git a/allensdk/brain_observatory/behavior/data_objects/__init__.py b/allensdk/brain_observatory/behavior/data_objects/__init__.py index 8bdb4f4b11..7bcc137971 100644 --- a/allensdk/brain_observatory/behavior/data_objects/__init__.py +++ b/allensdk/brain_observatory/behavior/data_objects/__init__.py @@ -1,6 +1,8 @@ -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .behavior_metadata.behavior_session_id import BehaviorSessionId # noqa: E501, F401 -from allensdk.brain_observatory.behavior.data_objects.timestamps\ - .stimulus_timestamps.stimulus_timestamps import StimulusTimestamps # noqa: E501, F401 +from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.behavior_session_id import ( + BehaviorSessionId, +) # noqa: E501, F401 +from allensdk.brain_observatory.behavior.data_objects.timestamps.stimulus_timestamps.stimulus_timestamps import ( + StimulusTimestamps, +) # noqa: E501, F401 from allensdk.brain_observatory.behavior.data_objects.running_speed.running_speed import RunningSpeed # noqa: E501, F401 from allensdk.brain_observatory.behavior.data_objects.running_speed.running_acquisition import RunningAcquisition # noqa: E501, F401 diff --git a/allensdk/brain_observatory/behavior/data_objects/cell_specimens/cell_specimens.py b/allensdk/brain_observatory/behavior/data_objects/cell_specimens/cell_specimens.py index 243770d662..c4f7b5a7bf 100644 --- a/allensdk/brain_observatory/behavior/data_objects/cell_specimens/cell_specimens.py +++ b/allensdk/brain_observatory/behavior/data_objects/cell_specimens/cell_specimens.py @@ -6,15 +6,13 @@ from pynwb.ophys import OpticalChannel, ImageSegmentation import allensdk.brain_observatory.roi_masks as roi -from allensdk.brain_observatory.behavior.data_files.neuropil_corrected_file \ - import NeuropilCorrectedFile +from allensdk.brain_observatory.behavior.data_files.neuropil_corrected_file import NeuropilCorrectedFile from allensdk.brain_observatory.behavior.data_files.demix_file import DemixFile from allensdk.brain_observatory.behavior.data_files.neuropil_file import ( NeuropilFile, ) from allensdk.brain_observatory.behavior.data_files.dff_file import DFFFile -from allensdk.brain_observatory.behavior.data_files.event_detection_file \ - import EventDetectionFile +from allensdk.brain_observatory.behavior.data_files.event_detection_file import EventDetectionFile from allensdk.core import DataObject from allensdk.core import ( JsonReadableInterface, @@ -22,22 +20,20 @@ NwbReadableInterface, ) from allensdk.core import NwbWritableInterface -from allensdk.brain_observatory.behavior.data_objects.cell_specimens.events \ - import Events -from allensdk.brain_observatory.behavior.data_objects.cell_specimens.traces\ - .corrected_fluorescence_traces import CorrectedFluorescenceTraces -from allensdk.brain_observatory.behavior.data_objects.cell_specimens.traces\ - .demixed_traces import DemixedTraces -from allensdk.brain_observatory.behavior.data_objects.cell_specimens.traces\ - .neuropil_traces import NeuropilTraces -from allensdk.brain_observatory.behavior.data_objects.cell_specimens.traces\ - .dff_traces import DFFTraces -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .ophys_experiment_metadata.field_of_view_shape import FieldOfViewShape -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .ophys_experiment_metadata.imaging_plane import ImagingPlane -from allensdk.brain_observatory.behavior.data_objects.timestamps\ - .ophys_timestamps import OphysTimestamps +from allensdk.brain_observatory.behavior.data_objects.cell_specimens.events import Events +from allensdk.brain_observatory.behavior.data_objects.cell_specimens.traces.corrected_fluorescence_traces import ( + CorrectedFluorescenceTraces, +) +from allensdk.brain_observatory.behavior.data_objects.cell_specimens.traces.demixed_traces import DemixedTraces +from allensdk.brain_observatory.behavior.data_objects.cell_specimens.traces.neuropil_traces import NeuropilTraces +from allensdk.brain_observatory.behavior.data_objects.cell_specimens.traces.dff_traces import DFFTraces +from allensdk.brain_observatory.behavior.data_objects.metadata.ophys_experiment_metadata.field_of_view_shape import ( + FieldOfViewShape, +) +from allensdk.brain_observatory.behavior.data_objects.metadata.ophys_experiment_metadata.imaging_plane import ( + ImagingPlane, +) +from allensdk.brain_observatory.behavior.data_objects.timestamps.ophys_timestamps import OphysTimestamps from allensdk.brain_observatory.behavior.image_api import Image from allensdk.brain_observatory.nwb import CELL_SPECIMEN_COL_DESCRIPTIONS from allensdk.brain_observatory.nwb.nwb_utils import add_image_to_nwb @@ -79,9 +75,7 @@ class CellSpecimenMeta( """Cell specimen metadata""" def __init__(self, imaging_plane: ImagingPlane, emission_lambda=520.0): - super().__init__( - name="cell_specimen_meta", value=None, is_value_self=True - ) + super().__init__(name="cell_specimen_meta", value=None, is_value_self=True) self._emission_lambda = emission_lambda self._imaging_plane = imaging_plane @@ -108,12 +102,8 @@ def from_lims( return cls(imaging_plane=imaging_plane_meta) @classmethod - def from_json( - cls, dict_repr: dict, ophys_timestamps: OphysTimestamps - ) -> "CellSpecimenMeta": - imaging_plane_meta = ImagingPlane.from_json( - dict_repr=dict_repr, ophys_timestamps=ophys_timestamps - ) + def from_json(cls, dict_repr: dict, ophys_timestamps: OphysTimestamps) -> "CellSpecimenMeta": + imaging_plane_meta = ImagingPlane.from_json(dict_repr=dict_repr, ophys_timestamps=ophys_timestamps) return cls(imaging_plane=imaging_plane_meta) @classmethod @@ -128,9 +118,7 @@ def from_nwb(cls, nwbfile: NWBFile) -> "CellSpecimenMeta": emission_lambda = optical_channel.emission_lambda imaging_plane = ImagingPlane.from_nwb(nwbfile=nwbfile) - return CellSpecimenMeta( - emission_lambda=emission_lambda, imaging_plane=imaging_plane - ) + return CellSpecimenMeta(emission_lambda=emission_lambda, imaging_plane=imaging_plane) class CellSpecimens( @@ -185,14 +173,10 @@ def __init__( exclude_invalid_rois Whether to exclude invalid rois """ - super().__init__( - name="cell_specimen_table", value=None, is_value_self=True - ) + super().__init__(name="cell_specimen_table", value=None, is_value_self=True) # Validate ophys timestamps, traces - ophys_timestamps = ophys_timestamps.validate( - number_of_frames=dff_traces.get_number_of_frames() - ) + ophys_timestamps = ophys_timestamps.validate(number_of_frames=dff_traces.get_number_of_frames()) self._validate_traces( ophys_timestamps=ophys_timestamps, dff_traces=dff_traces, @@ -203,26 +187,16 @@ def __init__( ) if exclude_invalid_rois: - cell_specimen_table = cell_specimen_table[ - cell_specimen_table["valid_roi"] - ] + cell_specimen_table = cell_specimen_table[cell_specimen_table["valid_roi"]] # Filter/reorder rois according to cell_specimen_table if dff_traces is not None: - dff_traces.filter_and_reorder( - roi_ids=cell_specimen_table["cell_roi_id"].values - ) + dff_traces.filter_and_reorder(roi_ids=cell_specimen_table["cell_roi_id"].values) if demixed_traces is not None: - demixed_traces.filter_and_reorder( - roi_ids=cell_specimen_table["cell_roi_id"].values - ) + demixed_traces.filter_and_reorder(roi_ids=cell_specimen_table["cell_roi_id"].values) if neuropil_traces is not None: - neuropil_traces.filter_and_reorder( - roi_ids=cell_specimen_table["cell_roi_id"].values - ) - corrected_fluorescence_traces.filter_and_reorder( - roi_ids=cell_specimen_table["cell_roi_id"].values - ) + neuropil_traces.filter_and_reorder(roi_ids=cell_specimen_table["cell_roi_id"].values) + corrected_fluorescence_traces.filter_and_reorder(roi_ids=cell_specimen_table["cell_roi_id"].values) # Note: setting raise_if_rois_missing to False for events, since # there seem to be cases where cell_specimen_table contains rois not in @@ -240,9 +214,7 @@ def __init__( self._neuropil_traces = neuropil_traces self._corrected_fluorescence_traces = corrected_fluorescence_traces self._events = events - self._segmentation_mask_image = self._get_segmentation_mask_image( - spacing=segmentation_mask_image_spacing - ) + self._segmentation_mask_image = self._get_segmentation_mask_image(spacing=segmentation_mask_image_spacing) @property def table(self) -> pd.DataFrame: @@ -280,9 +252,7 @@ def dff_traces(self) -> pd.DataFrame: """ if self._dff_traces is None: return None - df = self.table[["cell_roi_id"]].join( - self._dff_traces.value, on="cell_roi_id" - ) + df = self.table[["cell_roi_id"]].join(self._dff_traces.value, on="cell_roi_id") return df @property @@ -308,9 +278,7 @@ def demixed_traces(self) -> pd.DataFrame: """ if self._demixed_traces is None: return None - df = self.table[["cell_roi_id"]].join( - self._demixed_traces.value, on="cell_roi_id" - ) + df = self.table[["cell_roi_id"]].join(self._demixed_traces.value, on="cell_roi_id") return df @property @@ -337,9 +305,7 @@ def neuropil_traces(self) -> pd.DataFrame: """ if self._neuropil_traces is None: return None - df = self.table[["cell_roi_id"]].join( - self._neuropil_traces.value, on="cell_roi_id" - ) + df = self.table[["cell_roi_id"]].join(self._neuropil_traces.value, on="cell_roi_id") return df @property @@ -368,17 +334,13 @@ def corrected_fluorescence_traces(self) -> pd.DataFrame: r: r values (arbitrary units) """ - df = self.table[["cell_roi_id"]].join( - self._corrected_fluorescence_traces.value, on="cell_roi_id" - ) + df = self.table[["cell_roi_id"]].join(self._corrected_fluorescence_traces.value, on="cell_roi_id") return df @property def events(self) -> pd.DataFrame: df = self.table.reset_index() - df = df[["cell_roi_id", "cell_specimen_id"]].merge( - self._events.value, on="cell_roi_id" - ) + df = df[["cell_roi_id", "cell_specimen_id"]].merge(self._events.value, on="cell_roi_id") df = df.set_index("cell_specimen_id") return df @@ -406,9 +368,7 @@ def _get_ophys_cell_segmentation_run_id() -> int: ON oe.id = oseg.ophys_experiment_id WHERE oseg.current = 't' AND oe.id = {}; - """.format( - ophys_experiment_id - ) + """.format(ophys_experiment_id) return lims_db.fetchone(query, strict=True) def _get_cell_specimen_table(): @@ -417,61 +377,39 @@ def _get_cell_specimen_table(): SELECT * FROM cell_rois cr WHERE cr.ophys_cell_segmentation_run_id = {}; - """.format( - ophys_cell_seg_run_id - ) + """.format(ophys_cell_seg_run_id) initial_cs_table = pd.read_sql(query, lims_db.get_connection()) - cst = initial_cs_table.rename( - columns={"id": "cell_roi_id", "mask_matrix": "roi_mask"} - ) + cst = initial_cs_table.rename(columns={"id": "cell_roi_id", "mask_matrix": "roi_mask"}) cst.drop( ["ophys_experiment_id", "ophys_cell_segmentation_run_id"], inplace=True, axis=1, ) cst = cst.to_dict() - fov_shape = FieldOfViewShape.from_lims( - ophys_experiment_id=ophys_experiment_id, lims_db=lims_db - ) - cst = cls._postprocess( - cell_specimen_table=cst, fov_shape=fov_shape - ) + fov_shape = FieldOfViewShape.from_lims(ophys_experiment_id=ophys_experiment_id, lims_db=lims_db) + cst = cls._postprocess(cell_specimen_table=cst, fov_shape=fov_shape) return cst def _get_dff_traces(): - dff_file = DFFFile.from_lims( - ophys_experiment_id=ophys_experiment_id, db=lims_db - ) + dff_file = DFFFile.from_lims(ophys_experiment_id=ophys_experiment_id, db=lims_db) return DFFTraces.from_data_file(dff_file=dff_file) def _get_demixed_traces(): - demix_file = DemixFile.from_lims( - ophys_experiment_id=ophys_experiment_id, db=lims_db - ) - return DemixedTraces.from_data_file( - demix_file=demix_file - ) + demix_file = DemixFile.from_lims(ophys_experiment_id=ophys_experiment_id, db=lims_db) + return DemixedTraces.from_data_file(demix_file=demix_file) def _get_neuropil_traces(): - neuropil_file = NeuropilFile.from_lims( - ophys_experiment_id=ophys_experiment_id, db=lims_db - ) - return NeuropilTraces.from_data_file( - neuropil_file=neuropil_file - ) + neuropil_file = NeuropilFile.from_lims(ophys_experiment_id=ophys_experiment_id, db=lims_db) + return NeuropilTraces.from_data_file(neuropil_file=neuropil_file) def _get_corrected_fluorescence_traces(): neuropil_corrected_file = NeuropilCorrectedFile.from_lims( ophys_experiment_id=ophys_experiment_id, db=lims_db ) - return CorrectedFluorescenceTraces.from_data_file( - neuropil_corrected_file=neuropil_corrected_file - ) + return CorrectedFluorescenceTraces.from_data_file(neuropil_corrected_file=neuropil_corrected_file) def _get_events(): - events_file = EventDetectionFile.from_lims( - ophys_experiment_id=ophys_experiment_id, db=lims_db - ) + events_file = EventDetectionFile.from_lims(ophys_experiment_id=ophys_experiment_id, db=lims_db) return cls._get_events( events_file=events_file, events_params=events_params, @@ -516,41 +454,25 @@ def from_json( ) -> "CellSpecimens": cell_specimen_table = dict_repr["cell_specimen_table_dict"] fov_shape = FieldOfViewShape.from_json(dict_repr=dict_repr) - cell_specimen_table = cls._postprocess( - cell_specimen_table=cell_specimen_table, fov_shape=fov_shape - ) + cell_specimen_table = cls._postprocess(cell_specimen_table=cell_specimen_table, fov_shape=fov_shape) def _get_dff_traces(): dff_file = DFFFile.from_json(dict_repr=dict_repr) return DFFTraces.from_data_file(dff_file=dff_file) def _get_demixed_traces(): - demix_file = DemixFile.from_json( - dict_repr=dict_repr - ) - return DemixedTraces.from_data_file( - demix_file=demix_file - ) + demix_file = DemixFile.from_json(dict_repr=dict_repr) + return DemixedTraces.from_data_file(demix_file=demix_file) def _get_neuropil_traces(): - neuropil_file = NeuropilFile.from_json( - dict_repr=dict_repr - ) - return NeuropilTraces.from_data_file( - neuropil_file=neuropil_file - ) + neuropil_file = NeuropilFile.from_json(dict_repr=dict_repr) + return NeuropilTraces.from_data_file(neuropil_file=neuropil_file) def _get_corrected_fluorescence_traces(): - neuropil_corrected_file = NeuropilCorrectedFile.from_json( - dict_repr=dict_repr - ) - return CorrectedFluorescenceTraces.from_data_file( - neuropil_corrected_file=neuropil_corrected_file - ) + neuropil_corrected_file = NeuropilCorrectedFile.from_json(dict_repr=dict_repr) + return CorrectedFluorescenceTraces.from_data_file(neuropil_corrected_file=neuropil_corrected_file) - meta = CellSpecimenMeta.from_json( - dict_repr=dict_repr, ophys_timestamps=ophys_timestamps - ) + meta = CellSpecimenMeta.from_json(dict_repr=dict_repr, ophys_timestamps=ophys_timestamps) def _get_events(): events_file = EventDetectionFile.from_json(dict_repr=dict_repr) @@ -596,19 +518,14 @@ def _read_table(cell_specimen_table): df = cell_specimen_table.to_dataframe() # Ensure int64 used instead of int32 - df = df.astype( - {col: "int64" for col in df.select_dtypes("int32").columns} - ) + df = df.astype({col: "int64" for col in df.select_dtypes("int32").columns}) # Because pynwb stores this field as "image_mask", it is renamed # here df = df.rename(columns={"image_mask": "roi_mask"}) df.index.rename("cell_roi_id", inplace=True) - df["cell_specimen_id"] = [ - None if id_ == -1 else id_ - for id_ in df["cell_specimen_id"].values - ] + df["cell_specimen_id"] = [None if id_ == -1 else id_ for id_ in df["cell_specimen_id"].values] df.reset_index(inplace=True) df.set_index("cell_specimen_id", inplace=True) @@ -619,9 +536,7 @@ def _read_table(cell_specimen_table): dff_traces = DFFTraces.from_nwb(nwbfile=nwbfile) demixed_traces = DemixedTraces.from_nwb(nwbfile=nwbfile) neuropil_traces = NeuropilTraces.from_nwb(nwbfile=nwbfile) - corrected_fluorescence_traces = CorrectedFluorescenceTraces.from_nwb( - nwbfile=nwbfile - ) + corrected_fluorescence_traces = CorrectedFluorescenceTraces.from_nwb(nwbfile=nwbfile) def _get_events(): return Events.from_nwb( @@ -647,9 +562,7 @@ def _get_events(): exclude_invalid_rois=exclude_invalid_rois, ) - def to_nwb( - self, nwbfile: NWBFile, ophys_timestamps: OphysTimestamps - ) -> NWBFile: + def to_nwb(self, nwbfile: NWBFile, ophys_timestamps: OphysTimestamps) -> NWBFile: """ :param nwbfile In-memory nwb file object @@ -665,13 +578,10 @@ def to_nwb( # FOV: fov_width = metadata.field_of_view_width fov_height = metadata.field_of_view_height - imaging_plane_description = ( - "{} field of view in {} at depth {} " - "um".format( - (fov_width, fov_height), - self._meta.imaging_plane.targeted_structure, - metadata.imaging_depth, - ) + imaging_plane_description = "{} field of view in {} at depth {} um".format( + (fov_width, fov_height), + self._meta.imaging_plane.targeted_structure, + metadata.imaging_depth, ) # Optical Channel: @@ -725,9 +635,7 @@ def to_nwb( # of column both equal to the column name in the cell_roi_table plane_segmentation.add_column( col_name, - CELL_SPECIMEN_COL_DESCRIPTIONS.get( - col_name, "No Description Available" - ), + CELL_SPECIMEN_COL_DESCRIPTIONS.get(col_name, "No Description Available"), ) # go through each roi and add it to the plan segmentation object @@ -748,9 +656,7 @@ def to_nwb( plane_segmentation.add_roi(image_mask=mask, **table_row.to_dict()) # 2. Add DFF traces - self._dff_traces.to_nwb( - nwbfile=nwbfile, ophys_timestamps=ophys_timestamps - ) + self._dff_traces.to_nwb(nwbfile=nwbfile, ophys_timestamps=ophys_timestamps) # 3. Add demixed traces self._demixed_traces.to_nwb(nwbfile=nwbfile) @@ -793,15 +699,9 @@ def _get_segmentation_mask_image(self, spacing: tuple) -> Image: return mask_image @staticmethod - def _postprocess( - cell_specimen_table: dict, fov_shape: FieldOfViewShape - ) -> pd.DataFrame: + def _postprocess(cell_specimen_table: dict, fov_shape: FieldOfViewShape) -> pd.DataFrame: """Converts raw cell_specimen_table dict to dataframe""" - cell_specimen_table = ( - pd.DataFrame.from_dict(cell_specimen_table) - .set_index("cell_roi_id") - .sort_index() - ) + cell_specimen_table = pd.DataFrame.from_dict(cell_specimen_table).set_index("cell_roi_id").sort_index() fov_width = fov_shape.width fov_height = fov_shape.height @@ -823,9 +723,7 @@ def _postprocess( roi_mask_list.append(curr_roi.get_mask_plane().astype(bool)) cell_specimen_table["roi_mask"] = roi_mask_list - cell_specimen_table = cell_specimen_table[ - sorted(cell_specimen_table.columns) - ] + cell_specimen_table = cell_specimen_table[sorted(cell_specimen_table.columns)] cell_specimen_table.index.rename("cell_roi_id", inplace=True) cell_specimen_table.reset_index(inplace=True) @@ -858,21 +756,12 @@ def _validate_traces( continue # validate traces contain expected roi ids if not np.isin(traces.value.index, cell_roi_ids).all(): - raise RuntimeError( - f"{traces.name} contains ROI IDs that " - f"are not in " - f"cell_specimen_table.cell_roi_id" - ) + raise RuntimeError(f"{traces.name} contains ROI IDs that are not in cell_specimen_table.cell_roi_id") if not np.isin(cell_roi_ids, traces.value.index).all(): - raise RuntimeError( - f"cell_specimen_table contains ROI IDs " - f"that are not in {traces.name}" - ) + raise RuntimeError(f"cell_specimen_table contains ROI IDs that are not in {traces.name}") # validate traces contain expected timepoints - num_trace_timepoints = len( - traces.value.iloc[0][trace_col_map[traces.name]] - ) + num_trace_timepoints = len(traces.value.iloc[0][trace_col_map[traces.name]]) num_ophys_timestamps = ophys_timestamps.value.shape[0] if num_trace_timepoints != num_ophys_timestamps: raise RuntimeError( diff --git a/allensdk/brain_observatory/behavior/data_objects/cell_specimens/events.py b/allensdk/brain_observatory/behavior/data_objects/cell_specimens/events.py index 49c431f66a..cfc39f2e53 100644 --- a/allensdk/brain_observatory/behavior/data_objects/cell_specimens/events.py +++ b/allensdk/brain_observatory/behavior/data_objects/cell_specimens/events.py @@ -4,26 +4,18 @@ from hdmf.backends.hdf5 import H5DataIO from pynwb import NWBFile -from allensdk.brain_observatory.behavior.data_files.event_detection_file \ - import \ - EventDetectionFile +from allensdk.brain_observatory.behavior.data_files.event_detection_file import EventDetectionFile from allensdk.core import DataObject -from allensdk.core import \ - DataFileReadableInterface, NwbReadableInterface -from allensdk.core import \ - NwbWritableInterface -from allensdk.brain_observatory.behavior.data_objects.cell_specimens\ - .rois_mixin import \ - RoisMixin -from allensdk.brain_observatory.behavior.event_detection import \ - filter_events_array -from allensdk.brain_observatory.behavior.write_nwb.extensions\ - .event_detection.ndx_ophys_events import \ - OphysEventDetection - - -class Events(DataObject, RoisMixin, DataFileReadableInterface, - NwbReadableInterface, NwbWritableInterface): +from allensdk.core import DataFileReadableInterface, NwbReadableInterface +from allensdk.core import NwbWritableInterface +from allensdk.brain_observatory.behavior.data_objects.cell_specimens.rois_mixin import RoisMixin +from allensdk.brain_observatory.behavior.event_detection import filter_events_array +from allensdk.brain_observatory.behavior.write_nwb.extensions.event_detection.ndx_ophys_events import ( + OphysEventDetection, +) + + +class Events(DataObject, RoisMixin, DataFileReadableInterface, NwbReadableInterface, NwbWritableInterface): """Events columns: events: np.array @@ -31,12 +23,15 @@ class Events(DataObject, RoisMixin, DataFileReadableInterface, noise_std: float cell_roi_id: int """ - def __init__(self, - events: np.ndarray, - events_meta: pd.DataFrame, - frame_rate_hz: float, - filter_scale_seconds: float = 2.0/31.0, - filter_n_time_steps: int = 20): + + def __init__( + self, + events: np.ndarray, + events_meta: pd.DataFrame, + frame_rate_hz: float, + filter_scale_seconds: float = 2.0 / 31.0, + filter_n_time_steps: int = 20, + ): """ Parameters ---------- @@ -55,44 +50,51 @@ def __init__(self, """ filtered_events = filter_events_array( - arr=events, - scale=filter_scale_seconds*frame_rate_hz, - n_time_steps=filter_n_time_steps) + arr=events, scale=filter_scale_seconds * frame_rate_hz, n_time_steps=filter_n_time_steps + ) # Convert matrix to list of 1d arrays so that it can be stored # in a single column of the dataframe events = [x for x in events] filtered_events = [x for x in filtered_events] - df = pd.DataFrame({ - 'events': events, - 'filtered_events': filtered_events, - 'lambda': events_meta['lambda'], - 'noise_std': events_meta['noise_std'], - 'cell_roi_id': events_meta['cell_roi_id'] - }) - super().__init__(name='events', value=df) + df = pd.DataFrame( + { + "events": events, + "filtered_events": filtered_events, + "lambda": events_meta["lambda"], + "noise_std": events_meta["noise_std"], + "cell_roi_id": events_meta["cell_roi_id"], + } + ) + super().__init__(name="events", value=df) @classmethod - def from_data_file(cls, - events_file: EventDetectionFile, - filter_scale_seconds: float = 2.0/31.0, - filter_n_time_steps: int = 20, - frame_rate_hz: Optional[float] = None) -> "Events": + def from_data_file( + cls, + events_file: EventDetectionFile, + filter_scale_seconds: float = 2.0 / 31.0, + filter_n_time_steps: int = 20, + frame_rate_hz: Optional[float] = None, + ) -> "Events": events, events_meta = events_file.data - return cls(events=events, - events_meta=events_meta, - filter_scale_seconds=filter_scale_seconds, - filter_n_time_steps=filter_n_time_steps, - frame_rate_hz=frame_rate_hz) + return cls( + events=events, + events_meta=events_meta, + filter_scale_seconds=filter_scale_seconds, + filter_n_time_steps=filter_n_time_steps, + frame_rate_hz=frame_rate_hz, + ) @classmethod - def from_nwb(cls, - nwbfile: NWBFile, - filter_scale_seconds: float = 2.0/31.0, - filter_n_time_steps: int = 20, - frame_rate_hz: Optional[float] = None) -> "Events": - event_detection = nwbfile.processing['ophys']['event_detection'] + def from_nwb( + cls, + nwbfile: NWBFile, + filter_scale_seconds: float = 2.0 / 31.0, + filter_n_time_steps: int = 20, + frame_rate_hz: Optional[float] = None, + ) -> "Events": + event_detection = nwbfile.processing["ophys"]["event_detection"] # NOTE: The rois with events are stored in event detection partial_cell_specimen_table = event_detection.rois.to_dataframe() @@ -101,47 +103,48 @@ def from_nwb(cls, # events stored time x roi. Change back to roi x time events = events.T - events_meta = pd.DataFrame({ - 'cell_roi_id': partial_cell_specimen_table.index, - 'lambda': event_detection.lambdas[:], - 'noise_std': event_detection.noise_stds[:] - }) - return cls(events=events, - events_meta=events_meta, - filter_scale_seconds=filter_scale_seconds, - filter_n_time_steps=filter_n_time_steps, - frame_rate_hz=frame_rate_hz) + events_meta = pd.DataFrame( + { + "cell_roi_id": partial_cell_specimen_table.index, + "lambda": event_detection.lambdas[:], + "noise_std": event_detection.noise_stds[:], + } + ) + return cls( + events=events, + events_meta=events_meta, + filter_scale_seconds=filter_scale_seconds, + filter_n_time_steps=filter_n_time_steps, + frame_rate_hz=frame_rate_hz, + ) def to_nwb(self, nwbfile: NWBFile) -> NWBFile: - events = self.value.set_index('cell_roi_id') + events = self.value.set_index("cell_roi_id") - ophys_module = nwbfile.processing['ophys'] - dff_interface = ophys_module.data_interfaces['dff'] - traces = dff_interface.roi_response_series['traces'] - seg_interface = ophys_module.data_interfaces['image_segmentation'] + ophys_module = nwbfile.processing["ophys"] + dff_interface = ophys_module.data_interfaces["dff"] + traces = dff_interface.roi_response_series["traces"] + seg_interface = ophys_module.data_interfaces["image_segmentation"] - cell_specimen_table = ( - seg_interface.plane_segmentations['cell_specimen_table']) + cell_specimen_table = seg_interface.plane_segmentations["cell_specimen_table"] cell_specimen_df = cell_specimen_table.to_dataframe() # We only want to store the subset of rois that have events data - rois_with_events_indices = [cell_specimen_df.index.get_loc(label) - for label in events.index] + rois_with_events_indices = [cell_specimen_df.index.get_loc(label) for label in events.index] roi_table_region = cell_specimen_table.create_roi_table_region( - description="Cells with detected events", - region=rois_with_events_indices) + description="Cells with detected events", region=rois_with_events_indices + ) - events_data = np.vstack(events['events']) + events_data = np.vstack(events["events"]) events = OphysEventDetection( # time x rois instead of rois x time # store using compression since sparse data=H5DataIO(events_data.T, compression=True), - - lambdas=events['lambda'].values, - noise_stds=events['noise_std'].values, - unit='N/A', + lambdas=events["lambda"].values, + noise_stds=events["noise_std"].values, + unit="N/A", rois=roi_table_region, - timestamps=traces.timestamps + timestamps=traces.timestamps, ) ophys_module.add_data_interface(events) diff --git a/allensdk/brain_observatory/behavior/data_objects/cell_specimens/rois_mixin.py b/allensdk/brain_observatory/behavior/data_objects/cell_specimens/rois_mixin.py index dda58a7a67..734d3b1ced 100644 --- a/allensdk/brain_observatory/behavior/data_objects/cell_specimens/rois_mixin.py +++ b/allensdk/brain_observatory/behavior/data_objects/cell_specimens/rois_mixin.py @@ -7,10 +7,10 @@ class RoisMixin: """A mixin for a collection of rois stored as a dataframe (._value is a dataframe)""" + _value: pd.DataFrame - def filter_and_reorder(self, roi_ids: np.ndarray, - raise_if_rois_missing=True): + def filter_and_reorder(self, roi_ids: np.ndarray, raise_if_rois_missing=True): """Orders dataframe according to input roi_ids. Will also filter dataframe to contain only rois given by roi_ids. Use for, ie excluding invalid rois @@ -36,9 +36,9 @@ def filter_and_reorder(self, roi_ids: np.ndarray, RuntimeError if raise_if_rois_missing and there are input roi_ids not in dataframe """ + def handle_rois_in_input_not_in_dataframe(): - msg = f'Input contains roi ids not in ' \ - f'{type(self).__name__}.' + msg = f"Input contains roi ids not in {type(self).__name__}." if raise_if_rois_missing: raise RuntimeError(msg) warnings.warn(msg) @@ -50,7 +50,7 @@ def handle_rois_in_input_not_in_dataframe(): # (adding NaN records coerces int to float) for c in self._value: # Skipping column added due to reset_index - if c == 'index': + if c == "index": continue if self._value[c].dtype != original_dtypes[c]: @@ -60,12 +60,10 @@ def handle_rois_in_input_not_in_dataframe(): original_index_type = self._value.index.dtype original_dtypes = self._value.dtypes if original_index_name is None: - original_index_name = 'index' + original_index_name = "index" - if original_index_name != 'cell_roi_id': - self._value = (self._value - .reset_index() - .set_index('cell_roi_id')) + if original_index_name != "cell_roi_id": + self._value = self._value.reset_index().set_index("cell_roi_id") # Reorders dataframe according to roi_ids self._value = self._value.reindex(roi_ids) @@ -76,14 +74,12 @@ def handle_rois_in_input_not_in_dataframe(): # There are some roi ids in input not in index. handle_rois_in_input_not_in_dataframe() - if original_index_name != 'cell_roi_id': + if original_index_name != "cell_roi_id": # Set it back to the original index - self._value = (self._value - .reset_index() - .set_index(original_index_name)) + self._value = self._value.reset_index().set_index(original_index_name) # Set index back to original dtype # (can get coerced from int to float) self._value.index = self._value.index.astype(original_index_type) - if original_index_name == 'index': + if original_index_name == "index": # Set it back to None self._value.index.name = None diff --git a/allensdk/brain_observatory/behavior/data_objects/cell_specimens/traces/corrected_fluorescence_traces.py b/allensdk/brain_observatory/behavior/data_objects/cell_specimens/traces/corrected_fluorescence_traces.py index 82f8b7e821..6712e42923 100644 --- a/allensdk/brain_observatory/behavior/data_objects/cell_specimens/traces/corrected_fluorescence_traces.py +++ b/allensdk/brain_observatory/behavior/data_objects/cell_specimens/traces/corrected_fluorescence_traces.py @@ -47,33 +47,15 @@ def __init__(self, traces: pd.DataFrame): @classmethod def from_nwb(cls, nwbfile: NWBFile) -> "CorrectedFluorescenceTraces": - corr_fluorescence_traces_nwb = nwbfile.processing[ - "ophys" - ].data_interfaces["corrected_fluorescence"] + corr_fluorescence_traces_nwb = nwbfile.processing["ophys"].data_interfaces["corrected_fluorescence"] # f traces stored as timepoints x rois in NWB # We want rois x timepoints, hence the transpose - f_traces = ( - corr_fluorescence_traces_nwb.roi_response_series["traces"] - .data[:] - .T.copy() - ) - roi_ids = ( - corr_fluorescence_traces_nwb.roi_response_series["traces"] - .rois.table.id[:] - .copy() - ) + f_traces = corr_fluorescence_traces_nwb.roi_response_series["traces"].data[:].T.copy() + roi_ids = corr_fluorescence_traces_nwb.roi_response_series["traces"].rois.table.id[:].copy() # TODO: Remove try/except once VBO released. try: - r_values = ( - corr_fluorescence_traces_nwb.roi_response_series["r"] - .data[:] - .copy() - ) - rmse = ( - corr_fluorescence_traces_nwb.roi_response_series["RMSE"] - .data[:] - .copy() - ) + r_values = corr_fluorescence_traces_nwb.roi_response_series["r"].data[:].copy() + rmse = corr_fluorescence_traces_nwb.roi_response_series["RMSE"].data[:].copy() data_dict = { "corrected_fluorescence": [x for x in f_traces], "r": r_values, @@ -88,9 +70,7 @@ def from_nwb(cls, nwbfile: NWBFile) -> "CorrectedFluorescenceTraces": return CorrectedFluorescenceTraces(traces=df) @classmethod - def from_data_file( - cls, neuropil_corrected_file: NeuropilCorrectedFile - ) -> "CorrectedFluorescenceTraces": + def from_data_file(cls, neuropil_corrected_file: NeuropilCorrectedFile) -> "CorrectedFluorescenceTraces": corrected_fluorescence_traces = neuropil_corrected_file.data return cls(traces=corrected_fluorescence_traces) @@ -105,17 +85,8 @@ def to_nwb(self, nwbfile: NWBFile) -> NWBFile: # Create/Add corrected_fluorescence_traces modules and interfaces: ophys_module = nwbfile.processing["ophys"] - roi_table_region = ( - nwbfile.processing["ophys"] - .data_interfaces["dff"] - .roi_response_series["traces"] - .rois - ) # noqa: E501 - ophys_timestamps = ( - ophys_module.get_data_interface("dff") - .roi_response_series["traces"] - .timestamps - ) + roi_table_region = nwbfile.processing["ophys"].data_interfaces["dff"].roi_response_series["traces"].rois # noqa: E501 + ophys_timestamps = ophys_module.get_data_interface("dff").roi_response_series["traces"].timestamps f_interface = Fluorescence(name="corrected_fluorescence") ophys_module.add_data_interface(f_interface) diff --git a/allensdk/brain_observatory/behavior/data_objects/cell_specimens/traces/demixed_traces.py b/allensdk/brain_observatory/behavior/data_objects/cell_specimens/traces/demixed_traces.py index 41b25bd8ae..e4b63b7b42 100644 --- a/allensdk/brain_observatory/behavior/data_objects/cell_specimens/traces/demixed_traces.py +++ b/allensdk/brain_observatory/behavior/data_objects/cell_specimens/traces/demixed_traces.py @@ -5,13 +5,9 @@ from allensdk.brain_observatory.behavior.data_files.demix_file import DemixFile from allensdk.core import DataObject -from allensdk.core import \ - DataFileReadableInterface, NwbReadableInterface -from allensdk.core import \ - NwbWritableInterface -from allensdk.brain_observatory.behavior.data_objects.cell_specimens\ - .rois_mixin import \ - RoisMixin +from allensdk.core import DataFileReadableInterface, NwbReadableInterface +from allensdk.core import NwbWritableInterface +from allensdk.brain_observatory.behavior.data_objects.cell_specimens.rois_mixin import RoisMixin class DemixedTraces( @@ -43,9 +39,7 @@ def from_nwb(cls, nwbfile: NWBFile) -> "DemixedTraces": # TODO Remove try/except once VBO released. try: demixed_traces_nwb = ( - nwbfile.processing["ophys"] - .data_interfaces["demixed_trace"] - .roi_response_series["traces"] + nwbfile.processing["ophys"].data_interfaces["demixed_trace"].roi_response_series["traces"] ) # f traces stored as timepoints x rois in NWB # We want rois x timepoints, hence the transpose @@ -74,17 +68,8 @@ def to_nwb(self, nwbfile: NWBFile) -> NWBFile: # Create/Add demixed_traces modules and interfaces: ophys_module = nwbfile.processing["ophys"] - roi_table_region = ( - nwbfile.processing["ophys"] - .data_interfaces["dff"] - .roi_response_series["traces"] - .rois - ) # noqa: E501 - ophys_timestamps = ( - ophys_module.get_data_interface("dff") - .roi_response_series["traces"] - .timestamps - ) + roi_table_region = nwbfile.processing["ophys"].data_interfaces["dff"].roi_response_series["traces"].rois # noqa: E501 + ophys_timestamps = ophys_module.get_data_interface("dff").roi_response_series["traces"].timestamps f_interface = Fluorescence(name="demixed_trace") ophys_module.add_data_interface(f_interface) diff --git a/allensdk/brain_observatory/behavior/data_objects/cell_specimens/traces/dff_traces.py b/allensdk/brain_observatory/behavior/data_objects/cell_specimens/traces/dff_traces.py index 3501e82426..d2ed3525dc 100644 --- a/allensdk/brain_observatory/behavior/data_objects/cell_specimens/traces/dff_traces.py +++ b/allensdk/brain_observatory/behavior/data_objects/cell_specimens/traces/dff_traces.py @@ -5,21 +5,13 @@ from allensdk.brain_observatory.behavior.data_files.dff_file import DFFFile from allensdk.core import DataObject -from allensdk.core import \ - DataFileReadableInterface, NwbReadableInterface -from allensdk.core import \ - NwbWritableInterface -from allensdk.brain_observatory.behavior.data_objects.cell_specimens\ - .rois_mixin import \ - RoisMixin -from allensdk.brain_observatory.behavior.data_objects.timestamps\ - .ophys_timestamps import \ - OphysTimestamps +from allensdk.core import DataFileReadableInterface, NwbReadableInterface +from allensdk.core import NwbWritableInterface +from allensdk.brain_observatory.behavior.data_objects.cell_specimens.rois_mixin import RoisMixin +from allensdk.brain_observatory.behavior.data_objects.timestamps.ophys_timestamps import OphysTimestamps -class DFFTraces(DataObject, RoisMixin, - DataFileReadableInterface, NwbReadableInterface, - NwbWritableInterface): +class DFFTraces(DataObject, RoisMixin, DataFileReadableInterface, NwbReadableInterface, NwbWritableInterface): def __init__(self, traces: pd.DataFrame): """ Parameters @@ -29,50 +21,48 @@ def __init__(self, traces: pd.DataFrame): columns: dff: List of float """ - super().__init__(name='dff_traces', value=traces) + super().__init__(name="dff_traces", value=traces) - def to_nwb(self, nwbfile: NWBFile, - ophys_timestamps: OphysTimestamps) -> NWBFile: - dff_traces = self.value[['dff']] + def to_nwb(self, nwbfile: NWBFile, ophys_timestamps: OphysTimestamps) -> NWBFile: + dff_traces = self.value[["dff"]] - ophys_module = nwbfile.processing['ophys'] + ophys_module = nwbfile.processing["ophys"] # trace data in the form of rois x timepoints - trace_data = np.array([dff_traces.loc[cell_roi_id].dff - for cell_roi_id in dff_traces.index.values]) + trace_data = np.array([dff_traces.loc[cell_roi_id].dff for cell_roi_id in dff_traces.index.values]) - cell_specimen_table = nwbfile.processing['ophys'].data_interfaces[ - 'image_segmentation'].plane_segmentations[ - 'cell_specimen_table'] # noqa: E501 + cell_specimen_table = ( + nwbfile.processing["ophys"].data_interfaces["image_segmentation"].plane_segmentations["cell_specimen_table"] + ) # noqa: E501 roi_table_region = cell_specimen_table.create_roi_table_region( - description="segmented cells labeled by cell_specimen_id", - region=slice(len(dff_traces))) + description="segmented cells labeled by cell_specimen_id", region=slice(len(dff_traces)) + ) # Create/Add dff modules and interfaces: - assert dff_traces.index.name == 'cell_roi_id' - dff_interface = DfOverF(name='dff') + assert dff_traces.index.name == "cell_roi_id" + dff_interface = DfOverF(name="dff") ophys_module.add_data_interface(dff_interface) dff_interface.create_roi_response_series( - name='traces', + name="traces", data=trace_data.T, # Should be stored as timepoints x rois - unit='NA', + unit="NA", rois=roi_table_region, - timestamps=ophys_timestamps.value) + timestamps=ophys_timestamps.value, + ) return nwbfile @classmethod def from_nwb(cls, nwbfile: NWBFile) -> "DFFTraces": try: - dff_nwb = nwbfile.processing[ - 'ophys'].data_interfaces['dff'].roi_response_series['traces'] + dff_nwb = nwbfile.processing["ophys"].data_interfaces["dff"].roi_response_series["traces"] # dff traces stored as timepoints x rois in NWB # We want rois x timepoints, hence the transpose dff_traces = dff_nwb.data[:].T - df = pd.DataFrame({'dff': [x for x in dff_traces]}, - index=pd.Index(data=dff_nwb.rois.table.id[:], - name='cell_roi_id')) + df = pd.DataFrame( + {"dff": [x for x in dff_traces]}, index=pd.Index(data=dff_nwb.rois.table.id[:], name="cell_roi_id") + ) return DFFTraces(traces=df) except KeyError: return None @@ -85,5 +75,5 @@ def from_data_file(cls, dff_file: DFFFile) -> "DFFTraces": def get_number_of_frames(self) -> int: """Returns the number of frames in the movie""" if self.value.empty: - raise RuntimeError('Cannot determine number of frames') - return len(self.value.iloc[0]['dff']) + raise RuntimeError("Cannot determine number of frames") + return len(self.value.iloc[0]["dff"]) diff --git a/allensdk/brain_observatory/behavior/data_objects/cell_specimens/traces/neuropil_traces.py b/allensdk/brain_observatory/behavior/data_objects/cell_specimens/traces/neuropil_traces.py index eddbae83bb..ff85d36c3a 100644 --- a/allensdk/brain_observatory/behavior/data_objects/cell_specimens/traces/neuropil_traces.py +++ b/allensdk/brain_observatory/behavior/data_objects/cell_specimens/traces/neuropil_traces.py @@ -7,13 +7,9 @@ NeuropilFile, ) from allensdk.core import DataObject -from allensdk.core import \ - DataFileReadableInterface, NwbReadableInterface -from allensdk.core import \ - NwbWritableInterface -from allensdk.brain_observatory.behavior.data_objects.cell_specimens\ - .rois_mixin import \ - RoisMixin +from allensdk.core import DataFileReadableInterface, NwbReadableInterface +from allensdk.core import NwbWritableInterface +from allensdk.brain_observatory.behavior.data_objects.cell_specimens.rois_mixin import RoisMixin class NeuropilTraces( @@ -45,9 +41,7 @@ def from_nwb(cls, nwbfile: NWBFile) -> "NeuropilTraces": # TODO Remove try/except once VBO released. try: neuropil_traces_nwb = ( - nwbfile.processing["ophys"] - .data_interfaces["neuropil_trace"] - .roi_response_series["traces"] + nwbfile.processing["ophys"].data_interfaces["neuropil_trace"].roi_response_series["traces"] ) # f traces stored as timepoints x rois in NWB # We want rois x timepoints, hence the transpose @@ -76,17 +70,8 @@ def to_nwb(self, nwbfile: NWBFile) -> NWBFile: # Create/Add neuropil_traces modules and interfaces: ophys_module = nwbfile.processing["ophys"] - roi_table_region = ( - nwbfile.processing["ophys"] - .data_interfaces["dff"] - .roi_response_series["traces"] - .rois - ) # noqa: E501 - ophys_timestamps = ( - ophys_module.get_data_interface("dff") - .roi_response_series["traces"] - .timestamps - ) + roi_table_region = nwbfile.processing["ophys"].data_interfaces["dff"].roi_response_series["traces"].rois # noqa: E501 + ophys_timestamps = ophys_module.get_data_interface("dff").roi_response_series["traces"].timestamps f_interface = Fluorescence(name="neuropil_trace") ophys_module.add_data_interface(f_interface) diff --git a/allensdk/brain_observatory/behavior/data_objects/eye_tracking/eye_tracking_table.py b/allensdk/brain_observatory/behavior/data_objects/eye_tracking/eye_tracking_table.py index 7b9a7b0e13..6fc19b4e2d 100644 --- a/allensdk/brain_observatory/behavior/data_objects/eye_tracking/eye_tracking_table.py +++ b/allensdk/brain_observatory/behavior/data_objects/eye_tracking/eye_tracking_table.py @@ -6,37 +6,32 @@ import pandas as pd from pynwb import NWBFile, TimeSeries -from allensdk.brain_observatory.behavior.data_files.eye_tracking_video import \ - EyeTrackingVideo -from allensdk.brain_observatory.behavior.data_objects import ( - StimulusTimestamps) -from allensdk.brain_observatory.behavior.data_files.eye_tracking_file import \ - EyeTrackingFile -from allensdk.brain_observatory.behavior.\ - data_files.eye_tracking_metadata_file import EyeTrackingMetadataFile +from allensdk.brain_observatory.behavior.data_files.eye_tracking_video import EyeTrackingVideo +from allensdk.brain_observatory.behavior.data_objects import StimulusTimestamps +from allensdk.brain_observatory.behavior.data_files.eye_tracking_file import EyeTrackingFile +from allensdk.brain_observatory.behavior.data_files.eye_tracking_metadata_file import EyeTrackingMetadataFile from allensdk.core import DataObject -from allensdk.core import \ - NwbReadableInterface, DataFileReadableInterface -from allensdk.core import \ - NwbWritableInterface -from allensdk.brain_observatory.behavior.eye_tracking_processing import \ - process_eye_tracking_data, determine_outliers, determine_likely_blinks, \ - filter_on_blinks, EyeTrackingError -from allensdk.brain_observatory.nwb.eye_tracking.ndx_ellipse_eye_tracking \ - import \ - EllipseSeries, EllipseEyeTracking - - -class EyeTrackingTable(DataObject, DataFileReadableInterface, - NwbReadableInterface, NwbWritableInterface): +from allensdk.core import NwbReadableInterface, DataFileReadableInterface +from allensdk.core import NwbWritableInterface +from allensdk.brain_observatory.behavior.eye_tracking_processing import ( + process_eye_tracking_data, + determine_outliers, + determine_likely_blinks, + filter_on_blinks, + EyeTrackingError, +) +from allensdk.brain_observatory.nwb.eye_tracking.ndx_ellipse_eye_tracking import EllipseSeries, EllipseEyeTracking + + +class EyeTrackingTable(DataObject, DataFileReadableInterface, NwbReadableInterface, NwbWritableInterface): """corneal, eye, and pupil ellipse fit data""" + _logger = logging.getLogger(__name__) def __init__(self, eye_tracking: pd.DataFrame): - super().__init__(name='eye_tracking', value=eye_tracking) + super().__init__(name="eye_tracking", value=eye_tracking) def to_nwb(self, nwbfile: NWBFile) -> NWBFile: - # If there is actually no data in this data object, # do not bother writing anything to the NWBFile if self.value.empty: @@ -45,52 +40,54 @@ def to_nwb(self, nwbfile: NWBFile) -> NWBFile: eye_tracking_df = self.value eye_tracking = EllipseSeries( - name='eye_tracking', - reference_frame='nose', - data=eye_tracking_df[['eye_center_x', 'eye_center_y']].values, - area=eye_tracking_df['eye_area'].values, - area_raw=eye_tracking_df['eye_area_raw'].values, - width=eye_tracking_df['eye_width'].values, - height=eye_tracking_df['eye_height'].values, - angle=eye_tracking_df['eye_phi'].values, - timestamps=eye_tracking_df['timestamps'].values + name="eye_tracking", + reference_frame="nose", + data=eye_tracking_df[["eye_center_x", "eye_center_y"]].values, + area=eye_tracking_df["eye_area"].values, + area_raw=eye_tracking_df["eye_area_raw"].values, + width=eye_tracking_df["eye_width"].values, + height=eye_tracking_df["eye_height"].values, + angle=eye_tracking_df["eye_phi"].values, + timestamps=eye_tracking_df["timestamps"].values, ) pupil_tracking = EllipseSeries( - name='pupil_tracking', - reference_frame='nose', - data=eye_tracking_df[['pupil_center_x', 'pupil_center_y']].values, - area=eye_tracking_df['pupil_area'].values, - area_raw=eye_tracking_df['pupil_area_raw'].values, - width=eye_tracking_df['pupil_width'].values, - height=eye_tracking_df['pupil_height'].values, - angle=eye_tracking_df['pupil_phi'].values, - timestamps=eye_tracking + name="pupil_tracking", + reference_frame="nose", + data=eye_tracking_df[["pupil_center_x", "pupil_center_y"]].values, + area=eye_tracking_df["pupil_area"].values, + area_raw=eye_tracking_df["pupil_area_raw"].values, + width=eye_tracking_df["pupil_width"].values, + height=eye_tracking_df["pupil_height"].values, + angle=eye_tracking_df["pupil_phi"].values, + timestamps=eye_tracking, ) corneal_reflection_tracking = EllipseSeries( - name='corneal_reflection_tracking', - reference_frame='nose', - data=eye_tracking_df[['cr_center_x', 'cr_center_y']].values, - area=eye_tracking_df['cr_area'].values, - area_raw=eye_tracking_df['cr_area_raw'].values, - width=eye_tracking_df['cr_width'].values, - height=eye_tracking_df['cr_height'].values, - angle=eye_tracking_df['cr_phi'].values, - timestamps=eye_tracking + name="corneal_reflection_tracking", + reference_frame="nose", + data=eye_tracking_df[["cr_center_x", "cr_center_y"]].values, + area=eye_tracking_df["cr_area"].values, + area_raw=eye_tracking_df["cr_area_raw"].values, + width=eye_tracking_df["cr_width"].values, + height=eye_tracking_df["cr_height"].values, + angle=eye_tracking_df["cr_phi"].values, + timestamps=eye_tracking, ) - likely_blink = TimeSeries(timestamps=eye_tracking, - data=eye_tracking_df['likely_blink'].values, - name='likely_blink', - description='blinks', - unit='N/A') + likely_blink = TimeSeries( + timestamps=eye_tracking, + data=eye_tracking_df["likely_blink"].values, + name="likely_blink", + description="blinks", + unit="N/A", + ) ellipse_eye_tracking = EllipseEyeTracking( eye_tracking=eye_tracking, pupil_tracking=pupil_tracking, corneal_reflection_tracking=corneal_reflection_tracking, - likely_blink=likely_blink + likely_blink=likely_blink, ) nwbfile.add_acquisition(ellipse_eye_tracking) @@ -103,24 +100,40 @@ def _get_empty_df(cls) -> pd.DataFrame: names, but no data """ empty_data = dict() - for colname in ['timestamps', 'cr_area', 'eye_area', - 'pupil_area', 'likely_blink', 'pupil_area_raw', - 'cr_area_raw', 'eye_area_raw', 'cr_center_x', - 'cr_center_y', 'cr_width', 'cr_height', 'cr_phi', - 'eye_center_x', 'eye_center_y', 'eye_width', - 'eye_height', 'eye_phi', 'pupil_center_x', - 'pupil_center_y', 'pupil_width', 'pupil_height', - 'pupil_phi']: + for colname in [ + "timestamps", + "cr_area", + "eye_area", + "pupil_area", + "likely_blink", + "pupil_area_raw", + "cr_area_raw", + "eye_area_raw", + "cr_center_x", + "cr_center_y", + "cr_width", + "cr_height", + "cr_phi", + "eye_center_x", + "eye_center_y", + "eye_width", + "eye_height", + "eye_phi", + "pupil_center_x", + "pupil_center_y", + "pupil_width", + "pupil_height", + "pupil_phi", + ]: empty_data[colname] = [] - eye_tracking_data = pd.DataFrame(empty_data, - index=pd.Index([], name='frame')) + eye_tracking_data = pd.DataFrame(empty_data, index=pd.Index([], name="frame")) return eye_tracking_data @classmethod - def from_nwb(cls, nwbfile: NWBFile, - z_threshold: float = 3.0, - dilation_frames: int = 2) -> Optional["EyeTrackingTable"]: + def from_nwb( + cls, nwbfile: NWBFile, z_threshold: float = 3.0, dilation_frames: int = 2 + ) -> Optional["EyeTrackingTable"]: """ Parameters ----------- @@ -131,18 +144,17 @@ def from_nwb(cls, nwbfile: NWBFile, See from_lims for description """ try: - eye_tracking_acquisition = nwbfile.acquisition['EyeTracking'] + eye_tracking_acquisition = nwbfile.acquisition["EyeTracking"] except KeyError as e: - warnings.warn("This nwb file with identifier " - f"'{int(nwbfile.identifier)}' has no eye " - f"tracking data. (NWB error: {e})") + warnings.warn( + f"This nwb file with identifier '{int(nwbfile.identifier)}' has no eye tracking data. (NWB error: {e})" + ) eye_tracking_data = cls._get_empty_df() return EyeTrackingTable(eye_tracking=eye_tracking_data) eye_tracking = eye_tracking_acquisition.eye_tracking pupil_tracking = eye_tracking_acquisition.pupil_tracking - corneal_reflection_tracking = \ - eye_tracking_acquisition.corneal_reflection_tracking + corneal_reflection_tracking = eye_tracking_acquisition.corneal_reflection_tracking eye_tracking_dict = { "timestamps": eye_tracking.timestamps[:], @@ -150,42 +162,38 @@ def from_nwb(cls, nwbfile: NWBFile, "eye_area": eye_tracking.area_raw[:], "pupil_area": pupil_tracking.area_raw[:], "likely_blink": eye_tracking_acquisition.likely_blink.data[:], - "pupil_area_raw": pupil_tracking.area_raw[:], "cr_area_raw": corneal_reflection_tracking.area_raw[:], "eye_area_raw": eye_tracking.area_raw[:], - "cr_center_x": corneal_reflection_tracking.data[:, 0], "cr_center_y": corneal_reflection_tracking.data[:, 1], "cr_width": corneal_reflection_tracking.width[:], "cr_height": corneal_reflection_tracking.height[:], "cr_phi": corneal_reflection_tracking.angle[:], - "eye_center_x": eye_tracking.data[:, 0], "eye_center_y": eye_tracking.data[:, 1], "eye_width": eye_tracking.width[:], "eye_height": eye_tracking.height[:], "eye_phi": eye_tracking.angle[:], - "pupil_center_x": pupil_tracking.data[:, 0], "pupil_center_y": pupil_tracking.data[:, 1], "pupil_width": pupil_tracking.width[:], "pupil_height": pupil_tracking.height[:], "pupil_phi": pupil_tracking.angle[:], - } eye_tracking_data = pd.DataFrame(eye_tracking_dict) - eye_tracking_data.index = eye_tracking_data.index.rename('frame') + eye_tracking_data.index = eye_tracking_data.index.rename("frame") # re-calculate likely blinks for new z_threshold and dilate_frames - area_df = eye_tracking_data[['eye_area_raw', 'pupil_area_raw']] + area_df = eye_tracking_data[["eye_area_raw", "pupil_area_raw"]] outliers = determine_outliers(area_df, z_threshold=z_threshold) likely_blinks = determine_likely_blinks( - eye_tracking_data['eye_area_raw'], - eye_tracking_data['pupil_area_raw'], + eye_tracking_data["eye_area_raw"], + eye_tracking_data["pupil_area_raw"], outliers, - dilation_frames=dilation_frames) + dilation_frames=dilation_frames, + ) eye_tracking_data["likely_blink"] = likely_blinks filter_on_blinks(eye_tracking_data) @@ -194,14 +202,15 @@ def from_nwb(cls, nwbfile: NWBFile, @classmethod def from_data_file( - cls, - data_file: EyeTrackingFile, - stimulus_timestamps: StimulusTimestamps, - metadata_file: Optional[EyeTrackingMetadataFile] = None, - video: Optional[EyeTrackingVideo] = None, - z_threshold: float = 3.0, - dilation_frames: int = 2, - empty_on_fail: bool = False) -> "EyeTrackingTable": + cls, + data_file: EyeTrackingFile, + stimulus_timestamps: StimulusTimestamps, + metadata_file: Optional[EyeTrackingMetadataFile] = None, + video: Optional[EyeTrackingVideo] = None, + z_threshold: float = 3.0, + dilation_frames: int = 2, + empty_on_fail: bool = False, + ) -> "EyeTrackingTable": """ Parameters ---------- @@ -222,9 +231,9 @@ def from_data_file( video: EyeTrackingVideo. Used for detecting if video is MVR. Either this or metadata_file must be given. """ - cls._logger.info(f"Getting eye_tracking_data with " - f"'z_threshold={z_threshold}', " - f"'dilation_frames={dilation_frames}'") + cls._logger.info( + f"Getting eye_tracking_data with 'z_threshold={z_threshold}', 'dilation_frames={dilation_frames}'" + ) # TODO currently the only codepath that doesn't pass metadata file or # video is BehaviorSession.from_json. Once we add metadata file or @@ -232,15 +241,16 @@ def from_data_file( # `if metadata_file is not None or video is not None else False` clause # to always check if metadata frame is present is_metadata_frame_present = ( - _is_metadata_frame_present( - metadata_file=metadata_file, - video=video - ) if metadata_file is not None or video is not None else False) + _is_metadata_frame_present(metadata_file=metadata_file, video=video) + if metadata_file is not None or video is not None + else False + ) try: frames, stimulus_timestamps = cls._validate_frame_time_alignment( - frames=data_file.data.index.values, times=stimulus_timestamps, - is_metadata_frame_present=is_metadata_frame_present + frames=data_file.data.index.values, + times=stimulus_timestamps, + is_metadata_frame_present=is_metadata_frame_present, ) eye_data = data_file.data.loc[frames] @@ -249,10 +259,8 @@ def from_data_file( eye_data.index -= 1 eye_tracking_data = process_eye_tracking_data( - eye_data, - stimulus_timestamps.value, - z_threshold, - dilation_frames) + eye_data, stimulus_timestamps.value, z_threshold, dilation_frames + ) except EyeTrackingError as err: if empty_on_fail: msg = f"{str(err)}\n" @@ -266,10 +274,7 @@ def from_data_file( @classmethod def _validate_frame_time_alignment( - cls, - frames: np.ndarray, - times: StimulusTimestamps, - is_metadata_frame_present: bool = False + cls, frames: np.ndarray, times: StimulusTimestamps, is_metadata_frame_present: bool = False ) -> Tuple[np.ndarray, StimulusTimestamps]: """ Checks whether frames or timestamps need to be modified in order to be @@ -295,9 +300,10 @@ def _validate_frame_time_alignment( if is_metadata_frame_present: # Remove the metadata frame cls._logger.info( - f'Number of eye tracking timestamps: {len(times.value)}. ' - f'Number of eye tracking frames: {len(frames)}. ' - f'Removing metadata frame') + f"Number of eye tracking timestamps: {len(times.value)}. " + f"Number of eye tracking frames: {len(frames)}. " + f"Removing metadata frame" + ) frames = frames[1:] if len(times) > len(frames): @@ -306,21 +312,21 @@ def _validate_frame_time_alignment( # See discussion in https://github.com/AllenInstitute/AllenSDK/issues/2376 # noqa # Truncate timestamps to match the number of frames cls._logger.info( - f'Number of eye tracking timestamps: {len(times.value)}. ' - f'Number of eye tracking frames: {len(frames)}. ' - f'Truncating timestamps') - times = times.update_timestamps( - timestamps=times.value[:len(frames)]) + f"Number of eye tracking timestamps: {len(times.value)}. " + f"Number of eye tracking frames: {len(frames)}. " + f"Truncating timestamps" + ) + times = times.update_timestamps(timestamps=times.value[: len(frames)]) elif len(frames) > len(times): raise EyeTrackingError( - f'Number of eye tracking timestamps: {len(times.value)}. ' - f'Number of eye tracking frames: {len(frames)}. ' - f'We expect these to be equal') + f"Number of eye tracking timestamps: {len(times.value)}. " + f"Number of eye tracking frames: {len(frames)}. " + f"We expect these to be equal" + ) return frames, times -def get_lost_frames( - eye_tracking_metadata: EyeTrackingMetadataFile) -> np.ndarray: +def get_lost_frames(eye_tracking_metadata: EyeTrackingMetadataFile) -> np.ndarray: """ Get lost frames from the video metadata json Must subtract one since the json starts indexing at 1 @@ -348,28 +354,26 @@ def get_lost_frames( camera_metadata = eye_tracking_metadata.data - lost_count = camera_metadata['RecordingReport']['FramesLostCount'] + lost_count = camera_metadata["RecordingReport"]["FramesLostCount"] if lost_count == 0: return [] - lost_string = camera_metadata['RecordingReport']['LostFrames'][0] - lost_spans = lost_string.split(',') + lost_string = camera_metadata["RecordingReport"]["LostFrames"][0] + lost_spans = lost_string.split(",") lost_frames = [] for span in lost_spans: - start_end = span.split('-') + start_end = span.split("-") if len(start_end) == 1: lost_frames.append(int(start_end[0])) else: - lost_frames.extend(np.arange(int(start_end[0]), - int(start_end[1])+1)) + lost_frames.extend(np.arange(int(start_end[0]), int(start_end[1]) + 1)) - return np.array(lost_frames)-1 + return np.array(lost_frames) - 1 def _is_metadata_frame_present( - metadata_file: Optional[EyeTrackingMetadataFile] = None, - video: Optional[EyeTrackingVideo] = None + metadata_file: Optional[EyeTrackingMetadataFile] = None, video: Optional[EyeTrackingVideo] = None ) -> bool: """Return whether a metadata frame was placed at the front of the eye tracking movie. Tries to determine this by using the fact that the MVR @@ -385,14 +389,11 @@ def _is_metadata_frame_present( ValueError if neither metadata_file or video is given """ if metadata_file is not None: - video_file_name = \ - metadata_file.data['RecordingReport']['VideoOutputFileName']\ - .lower() + video_file_name = metadata_file.data["RecordingReport"]["VideoOutputFileName"].lower() elif video is not None: video_file_name = video.filepath.lower() else: - raise ValueError('Either metadata_file or video must be given') + raise ValueError("Either metadata_file or video must be given") video_file_name = Path(video_file_name) - return video_file_name.suffix == '.mp4' or \ - 'mvr' in video_file_name.name + return video_file_name.suffix == ".mp4" or "mvr" in video_file_name.name diff --git a/allensdk/brain_observatory/behavior/data_objects/eye_tracking/rig_geometry.py b/allensdk/brain_observatory/behavior/data_objects/eye_tracking/rig_geometry.py index d61d3b6938..6d4ab6991a 100644 --- a/allensdk/brain_observatory/behavior/data_objects/eye_tracking/rig_geometry.py +++ b/allensdk/brain_observatory/behavior/data_objects/eye_tracking/rig_geometry.py @@ -8,18 +8,16 @@ from pynwb import NWBFile from allensdk.core import DataObject -from allensdk.core import \ - LimsReadableInterface, JsonReadableInterface, NwbReadableInterface -from allensdk.core import \ - NwbWritableInterface -from allensdk.brain_observatory.behavior.schemas import \ - OphysEyeTrackingRigMetadataSchema +from allensdk.core import LimsReadableInterface, JsonReadableInterface, NwbReadableInterface +from allensdk.core import NwbWritableInterface +from allensdk.brain_observatory.behavior.schemas import OphysEyeTrackingRigMetadataSchema from allensdk.brain_observatory.nwb import load_pynwb_extension from allensdk.internal.api import PostgresQueryMixin class Coordinates: """Represents coordinates in 3d space""" + def __init__(self, x: float, y: float, z: float): self._x = x self._y = y @@ -44,30 +42,27 @@ def __iter__(self): def __eq__(self, other): if type(other) not in (type(self), list): - raise ValueError(f'Do not know how to compare with type ' - f'{type(other)}') + raise ValueError(f"Do not know how to compare with type {type(other)}") if isinstance(other, list): - return self._x == other[0] and \ - self._y == other[1] and \ - self._z == other[2] + return self._x == other[0] and self._y == other[1] and self._z == other[2] else: - return self._x == other.x and \ - self._y == other.y and \ - self._z == other.z + return self._x == other.x and self._y == other.y and self._z == other.z def __str__(self): - return f'[{self._x}, {self._y}, {self._z}]' - - -class RigGeometry(DataObject, LimsReadableInterface, JsonReadableInterface, - NwbReadableInterface, NwbWritableInterface): - def __init__(self, equipment: str, - monitor_position_mm: Coordinates, - monitor_rotation_deg: Coordinates, - camera_position_mm: Coordinates, - camera_rotation_deg: Coordinates, - led_position: Coordinates): - super().__init__(name='rig_geometry', value=None, is_value_self=True) + return f"[{self._x}, {self._y}, {self._z}]" + + +class RigGeometry(DataObject, LimsReadableInterface, JsonReadableInterface, NwbReadableInterface, NwbWritableInterface): + def __init__( + self, + equipment: str, + monitor_position_mm: Coordinates, + monitor_rotation_deg: Coordinates, + camera_position_mm: Coordinates, + camera_rotation_deg: Coordinates, + led_position: Coordinates, + ): + super().__init__(name="rig_geometry", value=None, is_value_self=True) self._monitor_position_mm = monitor_position_mm self._monitor_rotation_deg = monitor_rotation_deg self._camera_position_mm = camera_position_mm @@ -101,11 +96,10 @@ def equipment(self): def to_nwb(self, nwbfile: NWBFile) -> NWBFile: eye_tracking_rig_mod = pynwb.ProcessingModule( - name='eye_tracking_rig_metadata', - description='Eye tracking rig metadata module') + name="eye_tracking_rig_metadata", description="Eye tracking rig metadata module" + ) - nwb_extension = load_pynwb_extension( - OphysEyeTrackingRigMetadataSchema, 'ndx-aibs-behavior-ophys') + nwb_extension = load_pynwb_extension(OphysEyeTrackingRigMetadataSchema, "ndx-aibs-behavior-ophys") rig_metadata = nwb_extension( name="eye_tracking_rig_metadata", @@ -119,7 +113,7 @@ def to_nwb(self, nwbfile: NWBFile) -> NWBFile: monitor_rotation=list(self._monitor_rotation_deg), monitor_rotation__unit_of_measurement="deg", camera_rotation=list(self._camera_rotation_deg), - camera_rotation__unit_of_measurement="deg" + camera_rotation__unit_of_measurement="deg", ) eye_tracking_rig_mod.add_data_interface(rig_metadata) @@ -129,40 +123,31 @@ def to_nwb(self, nwbfile: NWBFile) -> NWBFile: @classmethod def from_nwb(cls, nwbfile: NWBFile) -> Optional["RigGeometry"]: try: - et_mod = \ - nwbfile.get_processing_module("eye_tracking_rig_metadata") + et_mod = nwbfile.get_processing_module("eye_tracking_rig_metadata") except KeyError as e: - warnings.warn("This nwb file with identifier " - f"'{int(nwbfile.identifier)}' has no eye " - f"tracking rig metadata. (NWB error: {e})") + warnings.warn( + "This nwb file with identifier " + f"'{int(nwbfile.identifier)}' has no eye " + f"tracking rig metadata. (NWB error: {e})" + ) return None meta = et_mod.get_data_interface("eye_tracking_rig_metadata") monitor_position = meta.monitor_position[:] - monitor_position = (monitor_position.tolist() - if isinstance(monitor_position, np.ndarray) - else monitor_position) + monitor_position = monitor_position.tolist() if isinstance(monitor_position, np.ndarray) else monitor_position monitor_rotation = meta.monitor_rotation[:] - monitor_rotation = (monitor_rotation.tolist() - if isinstance(monitor_rotation, np.ndarray) - else monitor_rotation) + monitor_rotation = monitor_rotation.tolist() if isinstance(monitor_rotation, np.ndarray) else monitor_rotation camera_position = meta.camera_position[:] - camera_position = (camera_position.tolist() - if isinstance(camera_position, np.ndarray) - else camera_position) + camera_position = camera_position.tolist() if isinstance(camera_position, np.ndarray) else camera_position camera_rotation = meta.camera_rotation[:] - camera_rotation = (camera_rotation.tolist() - if isinstance(camera_rotation, np.ndarray) - else camera_rotation) + camera_rotation = camera_rotation.tolist() if isinstance(camera_rotation, np.ndarray) else camera_rotation led_position = meta.led_position[:] - led_position = (led_position.tolist() - if isinstance(led_position, np.ndarray) - else led_position) + led_position = led_position.tolist() if isinstance(led_position, np.ndarray) else led_position return RigGeometry( equipment=meta.equipment, @@ -170,36 +155,34 @@ def from_nwb(cls, nwbfile: NWBFile) -> Optional["RigGeometry"]: camera_position_mm=Coordinates(*camera_position), led_position=Coordinates(*led_position), monitor_rotation_deg=Coordinates(*monitor_rotation), - camera_rotation_deg=Coordinates(*camera_rotation) + camera_rotation_deg=Coordinates(*camera_rotation), ) @classmethod def from_json(cls, dict_repr: dict) -> "RigGeometry": - rg = dict_repr['eye_tracking_rig_geometry'] + rg = dict_repr["eye_tracking_rig_geometry"] return RigGeometry( - equipment=rg['equipment'], - monitor_position_mm=Coordinates(*rg['monitor_position_mm']), - monitor_rotation_deg=Coordinates(*rg['monitor_rotation_deg']), - camera_position_mm=Coordinates(*rg['camera_position_mm']), - camera_rotation_deg=Coordinates(*rg['camera_rotation_deg']), - led_position=Coordinates(*rg['led_position']) + equipment=rg["equipment"], + monitor_position_mm=Coordinates(*rg["monitor_position_mm"]), + monitor_rotation_deg=Coordinates(*rg["monitor_rotation_deg"]), + camera_position_mm=Coordinates(*rg["camera_position_mm"]), + camera_rotation_deg=Coordinates(*rg["camera_rotation_deg"]), + led_position=Coordinates(*rg["led_position"]), ) @classmethod def from_lims( - cls, - lims_db: PostgresQueryMixin, - behavior_session_id: Optional[int] = None, - ophys_experiment_id: Optional[int] = None + cls, + lims_db: PostgresQueryMixin, + behavior_session_id: Optional[int] = None, + ophys_experiment_id: Optional[int] = None, ) -> Optional["RigGeometry"]: if behavior_session_id is None and ophys_experiment_id is None: - raise ValueError('Must provide either behavior_session_id or ' - 'ophys_experiment_id') + raise ValueError("Must provide either behavior_session_id or ophys_experiment_id") if behavior_session_id is not None and ophys_experiment_id is not None: - raise ValueError('Supply either ophys_experiment_id or ' - 'behavior_session_id') + raise ValueError("Supply either ophys_experiment_id or behavior_session_id") if ophys_experiment_id is not None: - query = f''' + query = f""" SELECT oec.*, oect.name as config_type, equipment.name as equipment_name FROM ophys_experiments oe @@ -212,9 +195,9 @@ def from_lims( WHERE oe.id = {ophys_experiment_id} AND oec.active_date <= os.date_of_acquisition AND oect.name IN ('eye camera position', 'led position', 'screen position') - ''' # noqa E501 + """ # noqa E501 else: - query = f''' + query = f""" SELECT oec.*, oect.name as config_type, equipment.name as equipment_name FROM behavior_sessions bs @@ -228,7 +211,7 @@ def from_lims( WHERE bs.id = {behavior_session_id} AND oec.active_date <= os.date_of_acquisition AND oect.name IN ('eye camera position', 'led position', 'screen position') - ''' # noqa E501 + """ # noqa E501 # Get the raw data rig_geometry = pd.read_sql(query, lims_db.get_connection()) @@ -238,53 +221,35 @@ def from_lims( # Map the config types to new names rig_geometry_config_type_map = { - 'eye camera position': 'camera', - 'screen position': 'monitor', - 'led position': 'led' + "eye camera position": "camera", + "screen position": "monitor", + "led position": "led", } - rig_geometry['config_type'] = rig_geometry['config_type'] \ - .map(rig_geometry_config_type_map) + rig_geometry["config_type"] = rig_geometry["config_type"].map(rig_geometry_config_type_map) - rig_geometry = cls._select_most_recent_geometry( - rig_geometry=rig_geometry) + rig_geometry = cls._select_most_recent_geometry(rig_geometry=rig_geometry) # Construct dictionary for positions - position = rig_geometry[['center_x_mm', 'center_y_mm', 'center_z_mm']] - position.index = [ - f'{v}_position_mm' if v != 'led' - else f'{v}_position' for v in position.index] - position = position.to_dict(orient='index') + position = rig_geometry[["center_x_mm", "center_y_mm", "center_z_mm"]] + position.index = [f"{v}_position_mm" if v != "led" else f"{v}_position" for v in position.index] + position = position.to_dict(orient="index") position = { - config_type: - Coordinates( - values['center_x_mm'], - values['center_y_mm'], - values['center_z_mm']) + config_type: Coordinates(values["center_x_mm"], values["center_y_mm"], values["center_z_mm"]) for config_type, values in position.items() } # Construct dictionary for rotations - rotation = rig_geometry[['rotation_x_deg', 'rotation_y_deg', - 'rotation_z_deg']] - rotation = rotation[rotation.index != 'led'] - rotation.index = [f'{v}_rotation_deg' for v in rotation.index] - rotation = rotation.to_dict(orient='index') + rotation = rig_geometry[["rotation_x_deg", "rotation_y_deg", "rotation_z_deg"]] + rotation = rotation[rotation.index != "led"] + rotation.index = [f"{v}_rotation_deg" for v in rotation.index] + rotation = rotation.to_dict(orient="index") rotation = { - config_type: - Coordinates( - values['rotation_x_deg'], - values['rotation_y_deg'], - values['rotation_z_deg'] - ) - for config_type, values in rotation.items() + config_type: Coordinates(values["rotation_x_deg"], values["rotation_y_deg"], values["rotation_z_deg"]) + for config_type, values in rotation.items() } # Combine the dictionaries - rig_geometry = { - **position, - **rotation, - 'equipment': rig_geometry['equipment_name'].iloc[0] - } + rig_geometry = {**position, **rotation, "equipment": rig_geometry["equipment_name"].iloc[0]} return RigGeometry(**rig_geometry) @staticmethod @@ -303,7 +268,6 @@ def _select_most_recent_geometry(rig_geometry: pd.DataFrame): date_of_acquisition of the session (only relevant for retrieving from LIMS) """ - rig_geometry = rig_geometry.sort_values('active_date', ascending=False) - rig_geometry = rig_geometry.groupby('config_type') \ - .apply(lambda x: x.iloc[0]) + rig_geometry = rig_geometry.sort_values("active_date", ascending=False) + rig_geometry = rig_geometry.groupby("config_type").apply(lambda x: x.iloc[0]) return rig_geometry diff --git a/allensdk/brain_observatory/behavior/data_objects/licks.py b/allensdk/brain_observatory/behavior/data_objects/licks.py index 2561d8bd9c..af52ebecad 100644 --- a/allensdk/brain_observatory/behavior/data_objects/licks.py +++ b/allensdk/brain_observatory/behavior/data_objects/licks.py @@ -8,16 +8,12 @@ from allensdk.brain_observatory.behavior.data_files import BehaviorStimulusFile from allensdk.core import DataObject from allensdk.brain_observatory.behavior.data_objects import StimulusTimestamps -from allensdk.core import \ - NwbReadableInterface -from allensdk.brain_observatory.behavior.data_files.stimulus_file import \ - StimulusFileReadableInterface -from allensdk.core import \ - NwbWritableInterface +from allensdk.core import NwbReadableInterface +from allensdk.brain_observatory.behavior.data_files.stimulus_file import StimulusFileReadableInterface +from allensdk.core import NwbWritableInterface -class Licks(DataObject, StimulusFileReadableInterface, NwbReadableInterface, - NwbWritableInterface): +class Licks(DataObject, StimulusFileReadableInterface, NwbReadableInterface, NwbWritableInterface): _logger = logging.getLogger(__name__) def __init__(self, licks: pd.DataFrame): @@ -29,14 +25,12 @@ def __init__(self, licks: pd.DataFrame): - frame: int frame number in which there was a lick """ - super().__init__(name='licks', value=licks) + super().__init__(name="licks", value=licks) @classmethod def from_stimulus_file( - cls, - stimulus_file: BehaviorStimulusFile, - stimulus_timestamps: Union[StimulusTimestamps, np.ndarray] - ) -> "Licks": + cls, stimulus_file: BehaviorStimulusFile, stimulus_timestamps: Union[StimulusTimestamps, np.ndarray] + ) -> "Licks": """Get lick data from pkl file. This function assumes that the first sensor in the list of lick_sensors is the desired lick sensor. @@ -64,15 +58,16 @@ def from_stimulus_file( """ data = stimulus_file.data - lick_frames = (data["items"]["behavior"]["lick_sensors"][0] - ["lick_events"]) + lick_frames = data["items"]["behavior"]["lick_sensors"][0]["lick_events"] if isinstance(stimulus_timestamps, StimulusTimestamps): if not np.isclose(stimulus_timestamps.monitor_delay, 0.0): - msg = ("Instantiating licks with monitor_delay = " - f"{stimulus_timestamps.monitor_delay: .2e}; " - "monitor_delay should be zero for Licks " - "data object") + msg = ( + "Instantiating licks with monitor_delay = " + f"{stimulus_timestamps.monitor_delay: .2e}; " + "monitor_delay should be zero for Licks " + "data object" + ) raise RuntimeError(msg) lick_times = stimulus_timestamps.value @@ -93,9 +88,7 @@ def from_stimulus_file( if len(lick_frames) > 0: if lick_frames[-1] == len(lick_times): lick_frames = lick_frames[:-1] - cls._logger.error('removed last lick - ' - 'it fell outside of stimulus_timestamps ' - 'range') + cls._logger.error("removed last lick - it fell outside of stimulus_timestamps range") if isinstance(stimulus_timestamps, StimulusTimestamps): lick_times = np.array([lick_times[frame] for frame in lick_frames]) @@ -110,42 +103,36 @@ def from_stimulus_file( @classmethod def from_nwb(cls, nwbfile: NWBFile) -> Optional["Licks"]: - if 'licking' in nwbfile.processing: - lick_module = nwbfile.processing['licking'] - licks = lick_module.get_data_interface('licks') + if "licking" in nwbfile.processing: + lick_module = nwbfile.processing["licking"] + licks = lick_module.get_data_interface("licks") timestamps = licks.timestamps[:] frame = licks.data[:] else: timestamps = [] frame = [] - df = pd.DataFrame({ - 'timestamps': timestamps, - 'frame': frame - }) + df = pd.DataFrame({"timestamps": timestamps, "frame": frame}) return cls(licks=df) def to_nwb(self, nwbfile: NWBFile) -> NWBFile: - # If there is no lick data, do not write # anything to the NWB file (this is # expected for passive sessions) - if len(self.value['frame']) == 0: + if len(self.value["frame"]) == 0: return nwbfile lick_timeseries = TimeSeries( - name='licks', - data=self.value['frame'].values, - timestamps=self.value['timestamps'].values, - description=('Timestamps and stimulus presentation ' - 'frame indices for lick events'), - unit='N/A' + name="licks", + data=self.value["frame"].values, + timestamps=self.value["timestamps"].values, + description=("Timestamps and stimulus presentation frame indices for lick events"), + unit="N/A", ) # Add lick interface to nwb file, by way of a processing module: - licks_mod = ProcessingModule('licking', - 'Licking behavior processing module') + licks_mod = ProcessingModule("licking", "Licking behavior processing module") licks_mod.add_data_interface(lick_timeseries) nwbfile.add_processing_module(licks_mod) diff --git a/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/behavior_metadata.py b/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/behavior_metadata.py index 9c0b95c56d..a5ef5a3176 100644 --- a/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/behavior_metadata.py +++ b/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/behavior_metadata.py @@ -6,38 +6,25 @@ from pynwb import NWBFile from allensdk.brain_observatory.behavior.data_files import BehaviorStimulusFile -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .behavior_metadata.date_of_acquisition import \ - DateOfAcquisition +from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.date_of_acquisition import ( + DateOfAcquisition, +) from allensdk.core import DataObject from allensdk.brain_observatory.behavior.data_objects import BehaviorSessionId -from allensdk.core import \ - JsonReadableInterface, NwbReadableInterface, \ - LimsReadableInterface -from allensdk.core import \ - NwbWritableInterface -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .behavior_metadata.behavior_session_uuid import \ - BehaviorSessionUUID -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .behavior_metadata.equipment import \ - Equipment -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .behavior_metadata.foraging_id import \ - ForagingId -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .behavior_metadata.project_code import \ - ProjectCode -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .behavior_metadata.session_type import \ - SessionType -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .behavior_metadata.stimulus_frame_rate import \ - StimulusFrameRate -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .subject_metadata.subject_metadata import \ - SubjectMetadata +from allensdk.core import JsonReadableInterface, NwbReadableInterface, LimsReadableInterface +from allensdk.core import NwbWritableInterface +from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.behavior_session_uuid import ( + BehaviorSessionUUID, +) +from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.equipment import Equipment +from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.foraging_id import ForagingId +from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.project_code import ProjectCode +from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.session_type import SessionType +from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.stimulus_frame_rate import ( + StimulusFrameRate, +) +from allensdk.brain_observatory.behavior.data_objects.metadata.subject_metadata.subject_metadata import SubjectMetadata from allensdk.brain_observatory.behavior.schemas import BehaviorMetadataSchema from allensdk.brain_observatory.nwb import load_pynwb_extension from allensdk.internal.api import PostgresQueryMixin @@ -54,8 +41,8 @@ r"\ATRAINING_2_gratings": "An operant behavior training session where a mouse must lick following a change in stimulus identity to earn rewards. Stimuli consist of full-field, square-wave static gratings with a spatial frequency of 0.04 cycles per degree. Gratings of 0 or 90 degrees are presented for 250 ms with a 500 ms intervening gray period. Delivered rewards are 10ul in volume, and the session lasts 60 minutes.", # noqa: E501 r"\ATRAINING_3_images": "An operant behavior training session where a mouse must lick following a change in stimulus identity to earn rewards. Stimuli consist of 8 natural scene images, for a total of 64 possible pairwise transitions. Images are shown for 250 ms with a 500 ms intervening gray period. Delivered rewards are 10ul in volume, and the session lasts for 60 minutes", # noqa: E501 r"\ATRAINING_4_images": "An operant behavior training session where a mouse must lick a spout following a change in stimulus identity to earn rewards. Stimuli consist of 8 natural scene images, for a total of 64 possible pairwise transitions. Images are shown for 250 ms with a 500 ms intervening gray period. Delivered rewards are 7ul in volume, and the session lasts for 60 minutes", # noqa: E501 - r"\ATRAINING_5_images": "An operant behavior training session where a mouse must lick a spout following a change in stimulus identity to earn rewards. Stimuli consist of 8 natural scene images, for a total of 64 possible pairwise transitions. Images are shown for 250 ms with a 500 ms intervening gray period. Delivered rewards are 7ul in volume. The session is 75 minutes long, with 5 minutes of gray screen before and after 60 minutes of behavior, followed by 10 repeats of a 30 second natural movie stimulus at the end of the session." # noqa: E501 - } + r"\ATRAINING_5_images": "An operant behavior training session where a mouse must lick a spout following a change in stimulus identity to earn rewards. Stimuli consist of 8 natural scene images, for a total of 64 possible pairwise transitions. Images are shown for 250 ms with a 500 ms intervening gray period. Delivered rewards are 7ul in volume. The session is 75 minutes long, with 5 minutes of gray screen before and after 60 minutes of behavior, followed by 10 repeats of a 30 second natural movie stimulus at the end of the session.", # noqa: E501 +} def get_expt_description(session_type: str) -> str: @@ -86,9 +73,11 @@ def get_expt_description(session_type: str) -> str: match.update({k: v}) if len(match) != 1: - emsg = (f"session type should match one and only one possible pattern " - f"template. '{session_type}' matched {len(match)} pattern " - "templates.") + emsg = ( + f"session type should match one and only one possible pattern " + f"template. '{session_type}' matched {len(match)} pattern " + "templates." + ) if len(match) > 1: emsg += f"{list(match.keys())}" emsg += f"the regex pattern templates are {list(description_dict)}" @@ -114,19 +103,18 @@ def get_task_parameters(data: Dict) -> Dict: A dict containing the task_parameters associated with this session. """ behavior = data["items"]["behavior"] - stimuli = behavior['stimuli'] + stimuli = behavior["stimuli"] config = behavior["config"] doc = config["DoC"] task_parameters = {} - task_parameters['blank_duration_sec'] = \ - [float(x) for x in doc['blank_duration_range']] + task_parameters["blank_duration_sec"] = [float(x) for x in doc["blank_duration_range"]] - if 'images' in stimuli: - stim_key = 'images' - elif 'grating' in stimuli: - stim_key = 'grating' + if "images" in stimuli: + stim_key = "images" + elif "grating" in stimuli: + stim_key = "grating" else: msg = "Cannot get stimulus_duration_sec\n" msg += "'images' and/or 'grating' not a valid " @@ -135,7 +123,7 @@ def get_task_parameters(data: Dict) -> Dict: msg += f"keys: {list(stimuli.keys())}" raise RuntimeError(msg) - stim_duration = stimuli[stim_key]['flash_interval_sec'] + stim_duration = stimuli[stim_key]["flash_interval_sec"] # from discussion in # https://github.com/AllenInstitute/AllenSDK/issues/1572 @@ -153,21 +141,19 @@ def get_task_parameters(data: Dict) -> Dict: else: stim_duration = stim_duration[0] - task_parameters['stimulus_duration_sec'] = stim_duration - - task_parameters['omitted_flash_fraction'] = \ - behavior['params'].get('flash_omit_probability', float('nan')) - task_parameters['response_window_sec'] = \ - [float(x) for x in doc["response_window"]] - task_parameters['reward_volume'] = config["reward"]["reward_volume"] - task_parameters['auto_reward_volume'] = doc['auto_reward_volume'] - task_parameters['session_type'] = behavior["params"]["stage"] - task_parameters['stimulus'] = next(iter(behavior["stimuli"])) - task_parameters['stimulus_distribution'] = doc["change_time_dist"] - - task_id = config['behavior']['task_id'] - if 'DoC' in task_id: - task_parameters['task'] = 'change detection' + task_parameters["stimulus_duration_sec"] = stim_duration + + task_parameters["omitted_flash_fraction"] = behavior["params"].get("flash_omit_probability", float("nan")) + task_parameters["response_window_sec"] = [float(x) for x in doc["response_window"]] + task_parameters["reward_volume"] = config["reward"]["reward_volume"] + task_parameters["auto_reward_volume"] = doc["auto_reward_volume"] + task_parameters["session_type"] = behavior["params"]["stage"] + task_parameters["stimulus"] = next(iter(behavior["stimuli"])) + task_parameters["stimulus_distribution"] = doc["change_time_dist"] + + task_id = config["behavior"]["task_id"] + if "DoC" in task_id: + task_parameters["task"] = "change detection" else: msg = "metadata.get_task_parameters does not " msg += f"know how to parse 'task_id' = {task_id}" @@ -176,29 +162,29 @@ def get_task_parameters(data: Dict) -> Dict: n_stimulus_frames = 0 for stim_type, stim_table in behavior["stimuli"].items(): n_stimulus_frames += sum(stim_table.get("draw_log", [])) - task_parameters['n_stimulus_frames'] = n_stimulus_frames + task_parameters["n_stimulus_frames"] = n_stimulus_frames return task_parameters -class BehaviorMetadata(DataObject, LimsReadableInterface, - JsonReadableInterface, - NwbReadableInterface, - NwbWritableInterface): +class BehaviorMetadata( + DataObject, LimsReadableInterface, JsonReadableInterface, NwbReadableInterface, NwbWritableInterface +): """Container class for behavior metadata""" - def __init__(self, - date_of_acquisition: DateOfAcquisition, - subject_metadata: SubjectMetadata, - behavior_session_id: BehaviorSessionId, - equipment: Equipment, - stimulus_frame_rate: StimulusFrameRate, - session_type: SessionType, - behavior_session_uuid: BehaviorSessionUUID, - project_code: ProjectCode = ProjectCode(), - session_duration: Optional[float] = None - ): - super().__init__(name='behavior_metadata', value=None, - is_value_self=True) + + def __init__( + self, + date_of_acquisition: DateOfAcquisition, + subject_metadata: SubjectMetadata, + behavior_session_id: BehaviorSessionId, + equipment: Equipment, + stimulus_frame_rate: StimulusFrameRate, + session_type: SessionType, + behavior_session_uuid: BehaviorSessionUUID, + project_code: ProjectCode = ProjectCode(), + session_duration: Optional[float] = None, + ): + super().__init__(name="behavior_metadata", value=None, is_value_self=True) self._date_of_acquisition = date_of_acquisition self._subject_metadata = subject_metadata self._behavior_session_id = behavior_session_id @@ -213,39 +199,27 @@ def __init__(self, @classmethod def from_lims( - cls, - behavior_session_id: BehaviorSessionId, - lims_db: PostgresQueryMixin, + cls, + behavior_session_id: BehaviorSessionId, + lims_db: PostgresQueryMixin, ) -> "BehaviorMetadata": - subject_metadata = SubjectMetadata.from_lims( - behavior_session_id=behavior_session_id, - lims_db=lims_db - ) - equipment = Equipment.from_lims( - behavior_session_id=behavior_session_id.value, lims_db=lims_db) + subject_metadata = SubjectMetadata.from_lims(behavior_session_id=behavior_session_id, lims_db=lims_db) + equipment = Equipment.from_lims(behavior_session_id=behavior_session_id.value, lims_db=lims_db) stimulus_file = BehaviorStimulusFile.from_lims( - db=lims_db, behavior_session_id=behavior_session_id.value)\ - .validate() - date_of_acquisition = DateOfAcquisition.from_stimulus_file( - stimulus_file=stimulus_file) - - stimulus_frame_rate = StimulusFrameRate.from_stimulus_file( - stimulus_file=stimulus_file) - session_type = SessionType.from_stimulus_file( - stimulus_file=stimulus_file) - - foraging_id = ForagingId.from_lims( - behavior_session_id=behavior_session_id.value, lims_db=lims_db) - behavior_session_uuid = BehaviorSessionUUID.from_stimulus_file( - stimulus_file=stimulus_file)\ - .validate(behavior_session_id=behavior_session_id.value, - foraging_id=foraging_id.value, - stimulus_file=stimulus_file) - - project_code = ProjectCode.from_lims( - behavior_session_id=behavior_session_id.value, - lims_db=lims_db) + db=lims_db, behavior_session_id=behavior_session_id.value + ).validate() + date_of_acquisition = DateOfAcquisition.from_stimulus_file(stimulus_file=stimulus_file) + + stimulus_frame_rate = StimulusFrameRate.from_stimulus_file(stimulus_file=stimulus_file) + session_type = SessionType.from_stimulus_file(stimulus_file=stimulus_file) + + foraging_id = ForagingId.from_lims(behavior_session_id=behavior_session_id.value, lims_db=lims_db) + behavior_session_uuid = BehaviorSessionUUID.from_stimulus_file(stimulus_file=stimulus_file).validate( + behavior_session_id=behavior_session_id.value, foraging_id=foraging_id.value, stimulus_file=stimulus_file + ) + + project_code = ProjectCode.from_lims(behavior_session_id=behavior_session_id.value, lims_db=lims_db) return BehaviorMetadata( date_of_acquisition=date_of_acquisition, @@ -256,7 +230,7 @@ def from_lims( session_type=session_type, behavior_session_uuid=behavior_session_uuid, project_code=project_code, - session_duration=stimulus_file.session_duration + session_duration=stimulus_file.session_duration, ) @classmethod @@ -266,14 +240,10 @@ def from_json(cls, dict_repr: dict) -> "BehaviorMetadata": equipment = Equipment.from_json(dict_repr=dict_repr) stimulus_file = BehaviorStimulusFile.from_json(dict_repr=dict_repr) - date_of_acquisition = DateOfAcquisition.from_stimulus_file( - stimulus_file=stimulus_file) - stimulus_frame_rate = StimulusFrameRate.from_stimulus_file( - stimulus_file=stimulus_file) - session_type = SessionType.from_stimulus_file( - stimulus_file=stimulus_file) - session_uuid = BehaviorSessionUUID.from_stimulus_file( - stimulus_file=stimulus_file) + date_of_acquisition = DateOfAcquisition.from_stimulus_file(stimulus_file=stimulus_file) + stimulus_frame_rate = StimulusFrameRate.from_stimulus_file(stimulus_file=stimulus_file) + session_type = SessionType.from_stimulus_file(stimulus_file=stimulus_file) + session_uuid = BehaviorSessionUUID.from_stimulus_file(stimulus_file=stimulus_file) return BehaviorMetadata( date_of_acquisition=date_of_acquisition, @@ -344,25 +314,24 @@ def subject_metadata(self): @property def is_pretest(self): - return self.session_type.lower().startswith('pretest') + return self.session_type.lower().startswith("pretest") @property def is_training(self): - return self.session_type.lower().startswith('training_0') + return self.session_type.lower().startswith("training_0") def to_nwb(self, nwbfile: NWBFile) -> NWBFile: self._subject_metadata.to_nwb(nwbfile=nwbfile) self._equipment.to_nwb(nwbfile=nwbfile) - extension = load_pynwb_extension(BehaviorMetadataSchema, - 'ndx-aibs-behavior-ophys') + extension = load_pynwb_extension(BehaviorMetadataSchema, "ndx-aibs-behavior-ophys") nwb_metadata = extension( - name='metadata', + name="metadata", behavior_session_id=self.behavior_session_id, behavior_session_uuid=str(self.behavior_session_uuid), stimulus_frame_rate=self.stimulus_frame_rate, session_type=self.session_type, equipment_name=self.equipment.value, - project_code=self.project_code + project_code=self.project_code, ) nwbfile.add_lab_meta_data(nwb_metadata) diff --git a/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/behavior_session_id.py b/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/behavior_session_id.py index ba28c0b9df..733512d309 100644 --- a/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/behavior_session_id.py +++ b/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/behavior_session_id.py @@ -3,8 +3,7 @@ from cachetools import cached, LRUCache from cachetools.keys import hashkey -from allensdk.core import \ - JsonReadableInterface, LimsReadableInterface, NwbReadableInterface +from allensdk.core import JsonReadableInterface, LimsReadableInterface, NwbReadableInterface from allensdk.internal.api import PostgresQueryMixin from allensdk.core import DataObject @@ -13,10 +12,12 @@ def from_lims_cache_key(cls, db, ophys_experiment_id: int): return hashkey(ophys_experiment_id) -class BehaviorSessionId(DataObject, LimsReadableInterface, - JsonReadableInterface, - NwbReadableInterface, - ): +class BehaviorSessionId( + DataObject, + LimsReadableInterface, + JsonReadableInterface, + NwbReadableInterface, +): def __init__(self, behavior_session_id: int): super().__init__(name="behavior_session_id", value=behavior_session_id) @@ -27,10 +28,7 @@ def from_json(cls, dict_repr: dict) -> "BehaviorSessionId": @classmethod @cached(cache=LRUCache(maxsize=10), key=from_lims_cache_key) # TODO should be from_ophys_experiment_id - def from_lims( - cls, db: PostgresQueryMixin, - ophys_experiment_id: int - ) -> "BehaviorSessionId": + def from_lims(cls, db: PostgresQueryMixin, ophys_experiment_id: int) -> "BehaviorSessionId": query = f""" SELECT bs.id FROM ophys_experiments oe @@ -43,11 +41,7 @@ def from_lims( return cls(behavior_session_id=behavior_session_id) @classmethod - def from_ecephys_session_id( - cls, - db: PostgresQueryMixin, - ecephys_session_id: int - ) -> "BehaviorSessionId": + def from_ecephys_session_id(cls, db: PostgresQueryMixin, ecephys_session_id: int) -> "BehaviorSessionId": query = f""" SELECT bs.id FROM behavior_sessions bs @@ -58,5 +52,5 @@ def from_ecephys_session_id( @classmethod def from_nwb(cls, nwbfile: NWBFile) -> "BehaviorSessionId": - metadata = nwbfile.lab_meta_data['metadata'] + metadata = nwbfile.lab_meta_data["metadata"] return cls(behavior_session_id=metadata.behavior_session_id) diff --git a/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/behavior_session_uuid.py b/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/behavior_session_uuid.py index 1d8408341c..b7c44a63f9 100644 --- a/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/behavior_session_uuid.py +++ b/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/behavior_session_uuid.py @@ -9,20 +9,14 @@ from pynwb import NWBFile -class BehaviorSessionUUID( - DataObject, StimulusFileReadableInterface, NwbReadableInterface -): +class BehaviorSessionUUID(DataObject, StimulusFileReadableInterface, NwbReadableInterface): """the universally unique identifier (UUID)""" def __init__(self, behavior_session_uuid: Optional[uuid.UUID]): - super().__init__( - name="behavior_session_uuid", value=behavior_session_uuid - ) + super().__init__(name="behavior_session_uuid", value=behavior_session_uuid) @classmethod - def from_stimulus_file( - cls, stimulus_file: BehaviorStimulusFile - ) -> "BehaviorSessionUUID": + def from_stimulus_file(cls, stimulus_file: BehaviorStimulusFile) -> "BehaviorSessionUUID": bs_uuid = stimulus_file.behavior_session_uuid return cls(behavior_session_uuid=bs_uuid) @@ -30,11 +24,7 @@ def from_stimulus_file( def from_nwb(cls, nwbfile: NWBFile) -> "BehaviorSessionUUID": metadata = nwbfile.lab_meta_data["metadata"] behavior_session_uuid = metadata.behavior_session_uuid - behavior_session_uuid = ( - uuid.UUID(behavior_session_uuid) - if behavior_session_uuid != "None" - else None - ) + behavior_session_uuid = uuid.UUID(behavior_session_uuid) if behavior_session_uuid != "None" else None return cls(behavior_session_uuid=behavior_session_uuid) def validate( diff --git a/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/date_of_acquisition.py b/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/date_of_acquisition.py index 831d70c3ca..f5e6ea0c3e 100644 --- a/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/date_of_acquisition.py +++ b/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/date_of_acquisition.py @@ -7,14 +7,13 @@ from allensdk.brain_observatory.behavior.data_files import BehaviorStimulusFile from allensdk.core import DataObject -from allensdk.core import \ - JsonReadableInterface, LimsReadableInterface, NwbReadableInterface +from allensdk.core import JsonReadableInterface, LimsReadableInterface, NwbReadableInterface from allensdk.internal.api import PostgresQueryMixin -class DateOfAcquisition(DataObject, LimsReadableInterface, - JsonReadableInterface, NwbReadableInterface): +class DateOfAcquisition(DataObject, LimsReadableInterface, JsonReadableInterface, NwbReadableInterface): """timestamp for when experiment was started in UTC""" + def __init__(self, date_of_acquisition: datetime): if date_of_acquisition.tzinfo is None: # Add UTC tzinfo if not already set @@ -23,7 +22,7 @@ def __init__(self, date_of_acquisition: datetime): @classmethod def from_json(cls, dict_repr: dict) -> "DateOfAcquisition": - doa = dict_repr['date_of_acquisition'] + doa = dict_repr["date_of_acquisition"] doa = datetime_parser.parse(doa) if doa.tzinfo is None: @@ -38,9 +37,7 @@ def from_json(cls, dict_repr: dict) -> "DateOfAcquisition": return cls(date_of_acquisition=doa) @classmethod - def from_lims( - cls, behavior_session_id: int, - lims_db: PostgresQueryMixin) -> "DateOfAcquisition": + def from_lims(cls, behavior_session_id: int, lims_db: PostgresQueryMixin) -> "DateOfAcquisition": query = """ SELECT bs.date_of_acquisition FROM behavior_sessions bs @@ -51,10 +48,7 @@ def from_lims( return cls(date_of_acquisition=experiment_date) @classmethod - def from_stimulus_file( - cls, - stimulus_file: BehaviorStimulusFile - ) -> "DateOfAcquisition": + def from_stimulus_file(cls, stimulus_file: BehaviorStimulusFile) -> "DateOfAcquisition": return cls(date_of_acquisition=stimulus_file.date_of_acquisition) @classmethod @@ -62,8 +56,7 @@ def from_nwb(cls, nwbfile: NWBFile) -> "DateOfAcquisition": date_of_acquisition = nwbfile.session_start_time return cls(date_of_acquisition=date_of_acquisition) - def validate(self, stimulus_file: BehaviorStimulusFile, - behavior_session_id: int) -> "DateOfAcquisition": + def validate(self, stimulus_file: BehaviorStimulusFile, behavior_session_id: int) -> "DateOfAcquisition": """raise a warning if the date differs too much from the datetime obtained from the behavior stimulus (*.pkl) file.""" pkl_data = stimulus_file.data @@ -85,8 +78,7 @@ def validate(self, stimulus_file: BehaviorStimulusFile, ) if pkl_acq_date: - acq_start_diff = ( - self.value - pkl_acq_date).total_seconds() + acq_start_diff = (self.value - pkl_acq_date).total_seconds() # If acquisition dates differ by more than an hour if abs(acq_start_diff) > 3600: session_id = behavior_session_id @@ -106,9 +98,7 @@ class DateOfAcquisitionOphys(DateOfAcquisition): table in LIMS instead of the behavior_sessions table""" @classmethod - def from_lims( - cls, ophys_experiment_id: int, - lims_db: PostgresQueryMixin) -> "DateOfAcquisitionOphys": + def from_lims(cls, ophys_experiment_id: int, lims_db: PostgresQueryMixin) -> "DateOfAcquisitionOphys": query = f""" SELECT os.date_of_acquisition FROM ophys_experiments oe diff --git a/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/equipment.py b/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/equipment.py index 31f0f95538..349d33ecdc 100644 --- a/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/equipment.py +++ b/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/equipment.py @@ -3,20 +3,19 @@ from pynwb import NWBFile from allensdk.core import DataObject -from allensdk.core import \ - JsonReadableInterface, LimsReadableInterface, NwbReadableInterface +from allensdk.core import JsonReadableInterface, LimsReadableInterface, NwbReadableInterface from allensdk.core import NwbWritableInterface from allensdk.internal.api import PostgresQueryMixin class EquipmentType(Enum): - MESOSCOPE = 'MESOSCOPE' - OTHER = 'OTHER' + MESOSCOPE = "MESOSCOPE" + OTHER = "OTHER" -class Equipment(DataObject, JsonReadableInterface, LimsReadableInterface, - NwbReadableInterface, NwbWritableInterface): +class Equipment(DataObject, JsonReadableInterface, LimsReadableInterface, NwbReadableInterface, NwbWritableInterface): """the name of the experimental rig.""" + def __init__(self, equipment_name: str): super().__init__(name="equipment_name", value=equipment_name) @@ -25,8 +24,7 @@ def from_json(cls, dict_repr: dict) -> "Equipment": return cls(equipment_name=dict_repr["rig_name"]) @classmethod - def from_lims(cls, behavior_session_id: int, - lims_db: PostgresQueryMixin) -> "Equipment": + def from_lims(cls, behavior_session_id: int, lims_db: PostgresQueryMixin) -> "Equipment": query = f""" SELECT e.name AS device_name FROM behavior_sessions bs @@ -38,28 +36,24 @@ def from_lims(cls, behavior_session_id: int, @classmethod def from_nwb(cls, nwbfile: NWBFile) -> "Equipment": - metadata = nwbfile.lab_meta_data['metadata'] + metadata = nwbfile.lab_meta_data["metadata"] return cls(equipment_name=metadata.equipment_name) def to_nwb(self, nwbfile: NWBFile) -> NWBFile: if self.type == EquipmentType.MESOSCOPE: - device_config = { - "name": self.value, - "description": "Allen Brain Observatory - Mesoscope 2P Rig" - } + device_config = {"name": self.value, "description": "Allen Brain Observatory - Mesoscope 2P Rig"} else: device_config = { "name": self.value, - "description": "Allen Brain Observatory - Scientifica 2P " - "Rig", - "manufacturer": "Scientifica" + "description": "Allen Brain Observatory - Scientifica 2P Rig", + "manufacturer": "Scientifica", } nwbfile.create_device(**device_config) return nwbfile @property def type(self): - if self.value.startswith('MESO'): + if self.value.startswith("MESO"): et = EquipmentType.MESOSCOPE else: et = EquipmentType.OTHER diff --git a/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/foraging_id.py b/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/foraging_id.py index 51753733a2..3f75810da1 100644 --- a/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/foraging_id.py +++ b/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/foraging_id.py @@ -1,13 +1,13 @@ import uuid from allensdk.core import DataObject -from allensdk.core import \ - JsonReadableInterface, LimsReadableInterface +from allensdk.core import JsonReadableInterface, LimsReadableInterface from allensdk.internal.api import PostgresQueryMixin class ForagingId(DataObject, LimsReadableInterface, JsonReadableInterface): """Foraging id""" + def __init__(self, foraging_id: uuid.UUID): super().__init__(name="foraging_id", value=foraging_id) @@ -16,8 +16,7 @@ def from_json(cls, dict_repr: dict) -> "ForagingId": pass @classmethod - def from_lims(cls, behavior_session_id: int, - lims_db: PostgresQueryMixin) -> "ForagingId": + def from_lims(cls, behavior_session_id: int, lims_db: PostgresQueryMixin) -> "ForagingId": query = f""" SELECT foraging_id diff --git a/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/project_code.py b/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/project_code.py index 475104b8c7..7c8ec07056 100644 --- a/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/project_code.py +++ b/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/project_code.py @@ -2,8 +2,7 @@ from typing import Optional from allensdk.core import DataObject -from allensdk.core import \ - LimsReadableInterface +from allensdk.core import LimsReadableInterface from allensdk.internal.api import PostgresQueryMixin from allensdk.core import NwbReadableInterface @@ -17,12 +16,11 @@ class ProjectCode(DataObject, LimsReadableInterface, NwbReadableInterface): def __init__(self, project_code: Optional[str] = None): if project_code is None: - project_code = 'Not Available' - super().__init__(name='project_code', value=project_code) + project_code = "Not Available" + super().__init__(name="project_code", value=project_code) @classmethod - def from_lims(cls, behavior_session_id: int, - lims_db: PostgresQueryMixin) -> "ProjectCode": + def from_lims(cls, behavior_session_id: int, lims_db: PostgresQueryMixin) -> "ProjectCode": query = f""" SELECT ps.code AS project_code FROM behavior_sessions bs @@ -35,7 +33,7 @@ def from_lims(cls, behavior_session_id: int, @classmethod def from_nwb(cls, nwbfile: NWBFile) -> "ProjectCode": try: - metadata = nwbfile.lab_meta_data['metadata'] + metadata = nwbfile.lab_meta_data["metadata"] return cls(project_code=metadata.project_code) except AttributeError: # Return values for NWBs without the project code set/available. diff --git a/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/session_type.py b/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/session_type.py index e297264981..d9cdf908f8 100644 --- a/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/session_type.py +++ b/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/session_type.py @@ -2,25 +2,20 @@ from allensdk.brain_observatory.behavior.data_files import BehaviorStimulusFile from allensdk.core import DataObject -from allensdk.core import \ - NwbReadableInterface -from allensdk.brain_observatory.behavior.data_files.stimulus_file import \ - StimulusFileReadableInterface +from allensdk.core import NwbReadableInterface +from allensdk.brain_observatory.behavior.data_files.stimulus_file import StimulusFileReadableInterface -class SessionType(DataObject, StimulusFileReadableInterface, - NwbReadableInterface): +class SessionType(DataObject, StimulusFileReadableInterface, NwbReadableInterface): """the stimulus set used""" + def __init__(self, session_type: str): super().__init__(name="session_type", value=session_type) @classmethod - def from_stimulus_file( - cls, - stimulus_file: BehaviorStimulusFile) -> "SessionType": + def from_stimulus_file(cls, stimulus_file: BehaviorStimulusFile) -> "SessionType": try: - stimulus_name = \ - stimulus_file.data["items"]["behavior"]["cl_params"]["stage"] + stimulus_name = stimulus_file.data["items"]["behavior"]["cl_params"]["stage"] except KeyError: raise RuntimeError( f"Could not obtain stimulus_name/stage information from " @@ -33,5 +28,5 @@ def from_stimulus_file( @classmethod def from_nwb(cls, nwbfile: NWBFile) -> "SessionType": - metadata = nwbfile.lab_meta_data['metadata'] + metadata = nwbfile.lab_meta_data["metadata"] return cls(session_type=metadata.session_type) diff --git a/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/stimulus_frame_rate.py b/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/stimulus_frame_rate.py index 96643555fd..c5f976568f 100644 --- a/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/stimulus_frame_rate.py +++ b/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_metadata/stimulus_frame_rate.py @@ -2,39 +2,31 @@ from allensdk.brain_observatory.behavior.data_files import BehaviorStimulusFile from allensdk.core import DataObject -from allensdk.core import \ - NwbReadableInterface -from allensdk.brain_observatory.behavior.data_files.stimulus_file import \ - StimulusFileReadableInterface -from allensdk.brain_observatory.behavior.data_objects.timestamps\ - .stimulus_timestamps.stimulus_timestamps import \ - StimulusTimestamps -from allensdk.brain_observatory.behavior.data_objects.timestamps.util import \ - calc_frame_rate +from allensdk.core import NwbReadableInterface +from allensdk.brain_observatory.behavior.data_files.stimulus_file import StimulusFileReadableInterface +from allensdk.brain_observatory.behavior.data_objects.timestamps.stimulus_timestamps.stimulus_timestamps import ( + StimulusTimestamps, +) +from allensdk.brain_observatory.behavior.data_objects.timestamps.util import calc_frame_rate -class StimulusFrameRate(DataObject, StimulusFileReadableInterface, - NwbReadableInterface): +class StimulusFrameRate(DataObject, StimulusFileReadableInterface, NwbReadableInterface): """Stimulus frame rate""" + def __init__(self, stimulus_frame_rate: float): super().__init__(name="stimulus_frame_rate", value=stimulus_frame_rate) @classmethod - def from_stimulus_file( - cls, - stimulus_file: BehaviorStimulusFile) -> "StimulusFrameRate": - + def from_stimulus_file(cls, stimulus_file: BehaviorStimulusFile) -> "StimulusFrameRate": # in this data object, we only care about the difference between # timestamps, so we can set the monitor_delay to any # value without affecting the result - stimulus_timestamps = StimulusTimestamps.from_stimulus_file( - stimulus_file=stimulus_file, - monitor_delay=0.0) + stimulus_timestamps = StimulusTimestamps.from_stimulus_file(stimulus_file=stimulus_file, monitor_delay=0.0) frame_rate = calc_frame_rate(timestamps=stimulus_timestamps.value) return cls(stimulus_frame_rate=frame_rate) @classmethod def from_nwb(cls, nwbfile: NWBFile) -> "StimulusFrameRate": - metadata = nwbfile.lab_meta_data['metadata'] + metadata = nwbfile.lab_meta_data["metadata"] return cls(stimulus_frame_rate=metadata.stimulus_frame_rate) diff --git a/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_ophys_metadata.py b/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_ophys_metadata.py index de69d2d084..8eb6bbc7a1 100644 --- a/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_ophys_metadata.py +++ b/allensdk/brain_observatory/behavior/data_objects/metadata/behavior_ophys_metadata.py @@ -38,9 +38,7 @@ def __init__( behavior_metadata: BehaviorMetadata, ophys_metadata: Union[OphysExperimentMetadata, MultiplaneMetadata], ): - super().__init__( - name="behavior_ophys_metadata", value=None, is_value_self=True - ) + super().__init__(name="behavior_ophys_metadata", value=None, is_value_self=True) self._behavior_metadata = behavior_metadata self._ophys_metadata = ophys_metadata @@ -72,22 +70,14 @@ def from_lims( Whether to fetch metadata for an experiment that is part of a container containing multiple imaging planes """ - behavior_session_id = BehaviorSessionId.from_lims( - ophys_experiment_id=ophys_experiment_id, db=lims_db - ) + behavior_session_id = BehaviorSessionId.from_lims(ophys_experiment_id=ophys_experiment_id, db=lims_db) - behavior_metadata = BehaviorMetadata.from_lims( - behavior_session_id=behavior_session_id, lims_db=lims_db - ) + behavior_metadata = BehaviorMetadata.from_lims(behavior_session_id=behavior_session_id, lims_db=lims_db) if is_multiplane: - ophys_metadata = MultiplaneMetadata.from_lims( - ophys_experiment_id=ophys_experiment_id, lims_db=lims_db - ) + ophys_metadata = MultiplaneMetadata.from_lims(ophys_experiment_id=ophys_experiment_id, lims_db=lims_db) else: - ophys_metadata = OphysExperimentMetadata.from_lims( - ophys_experiment_id=ophys_experiment_id, lims_db=lims_db - ) + ophys_metadata = OphysExperimentMetadata.from_lims(ophys_experiment_id=ophys_experiment_id, lims_db=lims_db) if ophys_metadata.project_code != behavior_metadata.project_code: raise warnings.warn( @@ -97,13 +87,9 @@ def from_lims( f"behavior_session_id={behavior_session_id}." ) - return cls( - behavior_metadata=behavior_metadata, ophys_metadata=ophys_metadata - ) + return cls(behavior_metadata=behavior_metadata, ophys_metadata=ophys_metadata) - def update_targeted_imaging_depth( - self, ophys_experiment_ids: List[int], lims_db: PostgresQueryMixin - ): + def update_targeted_imaging_depth(self, ophys_experiment_ids: List[int], lims_db: PostgresQueryMixin): """Update the value for targeted imaging depth given a set of experiments to be published. @@ -118,14 +104,10 @@ def update_targeted_imaging_depth( lims_db : PostgresQueryMixin Connection to the LIMS2 database. """ - self._ophys_metadata.update_targeted_imaging_depth( - ophys_experiment_ids, lims_db - ) + self._ophys_metadata.update_targeted_imaging_depth(ophys_experiment_ids, lims_db) @classmethod - def from_json( - cls, dict_repr: dict, is_multiplane=False - ) -> "BehaviorOphysMetadata": + def from_json(cls, dict_repr: dict, is_multiplane=False) -> "BehaviorOphysMetadata": """ Parameters @@ -144,18 +126,12 @@ def from_json( if is_multiplane: ophys_metadata = MultiplaneMetadata.from_json(dict_repr=dict_repr) else: - ophys_metadata = OphysExperimentMetadata.from_json( - dict_repr=dict_repr - ) + ophys_metadata = OphysExperimentMetadata.from_json(dict_repr=dict_repr) - return cls( - behavior_metadata=behavior_metadata, ophys_metadata=ophys_metadata - ) + return cls(behavior_metadata=behavior_metadata, ophys_metadata=ophys_metadata) @classmethod - def from_nwb( - cls, nwbfile: NWBFile, is_multiplane=False - ) -> "BehaviorOphysMetadata": + def from_nwb(cls, nwbfile: NWBFile, is_multiplane=False) -> "BehaviorOphysMetadata": """ Parameters @@ -172,17 +148,13 @@ def from_nwb( else: ophys_metadata = OphysExperimentMetadata.from_nwb(nwbfile=nwbfile) - return cls( - behavior_metadata=behavior_metadata, ophys_metadata=ophys_metadata - ) + return cls(behavior_metadata=behavior_metadata, ophys_metadata=ophys_metadata) def to_nwb(self, nwbfile: NWBFile) -> NWBFile: self._behavior_metadata.subject_metadata.to_nwb(nwbfile=nwbfile) self._behavior_metadata.equipment.to_nwb(nwbfile=nwbfile) - nwb_extension = load_pynwb_extension( - OphysBehaviorMetadataSchema, "ndx-aibs-behavior-ophys" - ) + nwb_extension = load_pynwb_extension(OphysBehaviorMetadataSchema, "ndx-aibs-behavior-ophys") behavior_meta = self._behavior_metadata ophys_meta = self._ophys_metadata diff --git a/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/field_of_view_shape.py b/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/field_of_view_shape.py index 9c7a6cfdae..cf57e52ad4 100644 --- a/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/field_of_view_shape.py +++ b/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/field_of_view_shape.py @@ -1,16 +1,13 @@ from pynwb import NWBFile from allensdk.core import DataObject -from allensdk.core import \ - JsonReadableInterface, LimsReadableInterface, NwbReadableInterface +from allensdk.core import JsonReadableInterface, LimsReadableInterface, NwbReadableInterface from allensdk.internal.api import PostgresQueryMixin -class FieldOfViewShape(DataObject, LimsReadableInterface, - NwbReadableInterface, JsonReadableInterface): +class FieldOfViewShape(DataObject, LimsReadableInterface, NwbReadableInterface, JsonReadableInterface): def __init__(self, height: int, width: int): - super().__init__(name='field_of_view_shape', value=None, - is_value_self=True) + super().__init__(name="field_of_view_shape", value=None, is_value_self=True) self._height = height self._width = width @@ -24,25 +21,22 @@ def width(self): return self._width @classmethod - def from_lims(cls, ophys_experiment_id: int, - lims_db: PostgresQueryMixin) -> "FieldOfViewShape": + def from_lims(cls, ophys_experiment_id: int, lims_db: PostgresQueryMixin) -> "FieldOfViewShape": query = f""" SELECT oe.movie_width as width, oe.movie_height as height FROM ophys_experiments oe WHERE oe.id = {ophys_experiment_id}; """ df = lims_db.select(query=query) - height = df.iloc[0]['height'] - width = df.iloc[0]['width'] + height = df.iloc[0]["height"] + width = df.iloc[0]["width"] return cls(height=height, width=width) @classmethod def from_nwb(cls, nwbfile: NWBFile) -> "FieldOfViewShape": - metadata = nwbfile.lab_meta_data['metadata'] - return cls(height=metadata.field_of_view_height, - width=metadata.field_of_view_width) + metadata = nwbfile.lab_meta_data["metadata"] + return cls(height=metadata.field_of_view_height, width=metadata.field_of_view_width) @classmethod def from_json(cls, dict_repr: dict) -> "FieldOfViewShape": - return cls(height=dict_repr['movie_height'], - width=dict_repr['movie_width']) + return cls(height=dict_repr["movie_height"], width=dict_repr["movie_width"]) diff --git a/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/imaging_depth.py b/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/imaging_depth.py index 023870f075..ccb12af70f 100644 --- a/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/imaging_depth.py +++ b/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/imaging_depth.py @@ -1,23 +1,21 @@ from pynwb import NWBFile from allensdk.core import DataObject -from allensdk.core import \ - JsonReadableInterface, LimsReadableInterface, NwbReadableInterface +from allensdk.core import JsonReadableInterface, LimsReadableInterface, NwbReadableInterface from allensdk.internal.api import PostgresQueryMixin -class ImagingDepth(DataObject, LimsReadableInterface, NwbReadableInterface, - JsonReadableInterface): +class ImagingDepth(DataObject, LimsReadableInterface, NwbReadableInterface, JsonReadableInterface): """Data object loads and stores the imaging_depth (microns) for an experiments. This is the calculated difference between measured z-depths of the surface and imaging_depth. """ + def __init__(self, imaging_depth: int): - super().__init__(name='imaging_depth', value=imaging_depth) + super().__init__(name="imaging_depth", value=imaging_depth) @classmethod - def from_lims(cls, ophys_experiment_id: int, - lims_db: PostgresQueryMixin) -> "ImagingDepth": + def from_lims(cls, ophys_experiment_id: int, lims_db: PostgresQueryMixin) -> "ImagingDepth": query = """ SELECT imd.depth FROM ophys_experiments oe @@ -32,9 +30,9 @@ def from_lims(cls, ophys_experiment_id: int, def from_json(cls, dict_repr: dict) -> "ImagingDepth": # TODO remove all of the from_json loading and validation step # ticket 2607 - return cls(imaging_depth=dict_repr['targeted_depth']) + return cls(imaging_depth=dict_repr["targeted_depth"]) @classmethod def from_nwb(cls, nwbfile: NWBFile) -> "ImagingDepth": - metadata = nwbfile.lab_meta_data['metadata'] + metadata = nwbfile.lab_meta_data["metadata"] return cls(imaging_depth=metadata.imaging_depth) diff --git a/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/imaging_plane.py b/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/imaging_plane.py index fc22171fd2..e81d1f0bce 100644 --- a/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/imaging_plane.py +++ b/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/imaging_plane.py @@ -4,79 +4,75 @@ from allensdk.core import DataObject from allensdk.brain_observatory.behavior.data_objects import BehaviorSessionId -from allensdk.core import \ - JsonReadableInterface, NwbReadableInterface, \ - LimsReadableInterface -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .subject_metadata.reporter_line import \ - ReporterLine -from allensdk.brain_observatory.behavior.data_objects.timestamps \ - .ophys_timestamps import OphysTimestamps -from allensdk.brain_observatory.behavior.data_objects.timestamps.util import \ - calc_frame_rate +from allensdk.core import JsonReadableInterface, NwbReadableInterface, LimsReadableInterface +from allensdk.brain_observatory.behavior.data_objects.metadata.subject_metadata.reporter_line import ReporterLine +from allensdk.brain_observatory.behavior.data_objects.timestamps.ophys_timestamps import OphysTimestamps +from allensdk.brain_observatory.behavior.data_objects.timestamps.util import calc_frame_rate from allensdk.internal.api import PostgresQueryMixin -class ImagingPlane(DataObject, LimsReadableInterface, - JsonReadableInterface, NwbReadableInterface): - def __init__(self, ophys_frame_rate: float, - targeted_structure: str, - excitation_lambda: float, - indicator: Optional[str]): - super().__init__(name='imaging_plane', value=None, - is_value_self=True) +class ImagingPlane(DataObject, LimsReadableInterface, JsonReadableInterface, NwbReadableInterface): + def __init__( + self, ophys_frame_rate: float, targeted_structure: str, excitation_lambda: float, indicator: Optional[str] + ): + super().__init__(name="imaging_plane", value=None, is_value_self=True) self._ophys_frame_rate = ophys_frame_rate self._targeted_structure = targeted_structure self._excitation_lambda = excitation_lambda self._indicator = indicator @classmethod - def from_lims(cls, ophys_experiment_id: int, - lims_db: PostgresQueryMixin, - ophys_timestamps: OphysTimestamps, - excitation_lambda=910.0) -> "ImagingPlane": - behavior_session_id = BehaviorSessionId.from_lims( - db=lims_db, ophys_experiment_id=ophys_experiment_id) + def from_lims( + cls, + ophys_experiment_id: int, + lims_db: PostgresQueryMixin, + ophys_timestamps: OphysTimestamps, + excitation_lambda=910.0, + ) -> "ImagingPlane": + behavior_session_id = BehaviorSessionId.from_lims(db=lims_db, ophys_experiment_id=ophys_experiment_id) ophys_frame_rate = calc_frame_rate(timestamps=ophys_timestamps.value) targeted_structure = cls._get_targeted_structure_from_lims( - ophys_experiment_id=ophys_experiment_id, lims_db=lims_db) - reporter_line = ReporterLine.from_lims( - behavior_session_id=behavior_session_id.value, lims_db=lims_db) + ophys_experiment_id=ophys_experiment_id, lims_db=lims_db + ) + reporter_line = ReporterLine.from_lims(behavior_session_id=behavior_session_id.value, lims_db=lims_db) indicator = reporter_line.parse_indicator(warn=True) - return cls(ophys_frame_rate=ophys_frame_rate, - targeted_structure=targeted_structure, - excitation_lambda=excitation_lambda, - indicator=indicator) + return cls( + ophys_frame_rate=ophys_frame_rate, + targeted_structure=targeted_structure, + excitation_lambda=excitation_lambda, + indicator=indicator, + ) @classmethod - def from_json(cls, dict_repr: dict, - ophys_timestamps: OphysTimestamps, - excitation_lambda=910.0) -> "ImagingPlane": - targeted_structure = dict_repr['targeted_structure'] + def from_json(cls, dict_repr: dict, ophys_timestamps: OphysTimestamps, excitation_lambda=910.0) -> "ImagingPlane": + targeted_structure = dict_repr["targeted_structure"] ophys_fame_rate = calc_frame_rate(timestamps=ophys_timestamps.value) reporter_line = ReporterLine.from_json(dict_repr=dict_repr) indicator = reporter_line.parse_indicator(warn=True) - return cls(targeted_structure=targeted_structure, - ophys_frame_rate=ophys_fame_rate, - excitation_lambda=excitation_lambda, - indicator=indicator) + return cls( + targeted_structure=targeted_structure, + ophys_frame_rate=ophys_fame_rate, + excitation_lambda=excitation_lambda, + indicator=indicator, + ) @classmethod def from_nwb(cls, nwbfile: NWBFile) -> "ImagingPlane": - ophys_module = nwbfile.processing['ophys'] - image_seg = ophys_module.data_interfaces['image_segmentation'] - imaging_plane = image_seg.plane_segmentations[ - 'cell_specimen_table'].imaging_plane + ophys_module = nwbfile.processing["ophys"] + image_seg = ophys_module.data_interfaces["image_segmentation"] + imaging_plane = image_seg.plane_segmentations["cell_specimen_table"].imaging_plane ophys_frame_rate = imaging_plane.imaging_rate targeted_structure = imaging_plane.location excitation_lambda = imaging_plane.excitation_lambda reporter_line = ReporterLine.from_nwb(nwbfile=nwbfile) indicator = reporter_line.parse_indicator(warn=True) - return cls(ophys_frame_rate=ophys_frame_rate, - targeted_structure=targeted_structure, - excitation_lambda=excitation_lambda, - indicator=indicator) + return cls( + ophys_frame_rate=ophys_frame_rate, + targeted_structure=targeted_structure, + excitation_lambda=excitation_lambda, + indicator=indicator, + ) @property def ophys_frame_rate(self) -> float: @@ -95,8 +91,7 @@ def indicator(self) -> Optional[str]: return self._indicator @staticmethod - def _get_targeted_structure_from_lims(ophys_experiment_id: int, - lims_db: PostgresQueryMixin) -> str: + def _get_targeted_structure_from_lims(ophys_experiment_id: int, lims_db: PostgresQueryMixin) -> str: query = """ SELECT st.acronym FROM ophys_experiments oe diff --git a/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/multi_plane_metadata/imaging_plane_group.py b/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/multi_plane_metadata/imaging_plane_group.py index 2c3a10f252..5195b48161 100644 --- a/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/multi_plane_metadata/imaging_plane_group.py +++ b/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/multi_plane_metadata/imaging_plane_group.py @@ -3,15 +3,13 @@ from pynwb import NWBFile from allensdk.core import DataObject -from allensdk.core import \ - JsonReadableInterface, LimsReadableInterface, NwbReadableInterface +from allensdk.core import JsonReadableInterface, LimsReadableInterface, NwbReadableInterface from allensdk.internal.api import PostgresQueryMixin -class ImagingPlaneGroup(DataObject, LimsReadableInterface, - JsonReadableInterface, NwbReadableInterface): +class ImagingPlaneGroup(DataObject, LimsReadableInterface, JsonReadableInterface, NwbReadableInterface): def __init__(self, plane_group: int, plane_group_count: int): - super().__init__(name='plane_group', value=None, is_value_self=True) + super().__init__(name="plane_group", value=None, is_value_self=True) self._plane_group = plane_group self._plane_group_count = plane_group_count @@ -24,9 +22,7 @@ def plane_group_count(self): return self._plane_group_count @classmethod - def from_lims(cls, ophys_experiment_id: int, - lims_db: PostgresQueryMixin) -> \ - Optional["ImagingPlaneGroup"]: + def from_lims(cls, ophys_experiment_id: int, lims_db: PostgresQueryMixin) -> Optional["ImagingPlaneGroup"]: """ Parameters @@ -41,7 +37,7 @@ def from_lims(cls, ophys_experiment_id: int, else None """ - query = f''' + query = f""" SELECT oe.id as ophys_experiment_id, pg.group_order AS plane_group FROM ophys_experiments oe JOIN ophys_sessions os ON oe.ophys_session_id = os.id @@ -52,25 +48,22 @@ def from_lims(cls, ophys_experiment_id: int, FROM ophys_experiments oe WHERE oe.id = {ophys_experiment_id} ) - ''' + """ df = lims_db.select(query=query) if df.empty: return None - df = df.set_index('ophys_experiment_id') - plane_group = df.loc[ophys_experiment_id, 'plane_group'] - plane_group_count = df['plane_group'].nunique() - return cls(plane_group=plane_group, - plane_group_count=plane_group_count) + df = df.set_index("ophys_experiment_id") + plane_group = df.loc[ophys_experiment_id, "plane_group"] + plane_group_count = df["plane_group"].nunique() + return cls(plane_group=plane_group, plane_group_count=plane_group_count) @classmethod def from_json(cls, dict_repr: dict) -> "ImagingPlaneGroup": - plane_group = dict_repr['imaging_plane_group'] - plane_group_count = dict_repr['plane_group_count'] - return cls(plane_group=plane_group, - plane_group_count=plane_group_count) + plane_group = dict_repr["imaging_plane_group"] + plane_group_count = dict_repr["plane_group_count"] + return cls(plane_group=plane_group, plane_group_count=plane_group_count) @classmethod def from_nwb(cls, nwbfile: NWBFile) -> "ImagingPlaneGroup": - metadata = nwbfile.lab_meta_data['metadata'] - return cls(plane_group=metadata.imaging_plane_group, - plane_group_count=metadata.imaging_plane_group_count) + metadata = nwbfile.lab_meta_data["metadata"] + return cls(plane_group=metadata.imaging_plane_group, plane_group_count=metadata.imaging_plane_group_count) diff --git a/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/multi_plane_metadata/multi_plane_metadata.py b/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/multi_plane_metadata/multi_plane_metadata.py index 04606c412b..deb5f8052d 100644 --- a/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/multi_plane_metadata/multi_plane_metadata.py +++ b/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/multi_plane_metadata/multi_plane_metadata.py @@ -1,26 +1,44 @@ from pynwb import NWBFile -from allensdk.brain_observatory.behavior.data_objects.metadata.ophys_experiment_metadata.field_of_view_shape import FieldOfViewShape # NOQA -from allensdk.brain_observatory.behavior.data_objects.metadata.ophys_experiment_metadata.imaging_depth import ImagingDepth # NOQA -from allensdk.brain_observatory.behavior.data_objects.metadata.ophys_experiment_metadata.multi_plane_metadata.imaging_plane_group import ImagingPlaneGroup # NOQA -from allensdk.brain_observatory.behavior.data_objects.metadata.ophys_experiment_metadata.ophys_container_id import OphysContainerId # NOQA -from allensdk.brain_observatory.behavior.data_objects.metadata.ophys_experiment_metadata.ophys_experiment_metadata import OphysExperimentMetadata # NOQA -from allensdk.brain_observatory.behavior.data_objects.metadata.ophys_experiment_metadata.ophys_session_id import OphysSessionId # NOQA -from allensdk.brain_observatory.behavior.data_objects.metadata.ophys_experiment_metadata.ophys_project_code import OphysProjectCode # NOQA -from allensdk.brain_observatory.behavior.data_objects.metadata.ophys_experiment_metadata.targeted_imaging_depth import TargetedImagingDepth # NOQA +from allensdk.brain_observatory.behavior.data_objects.metadata.ophys_experiment_metadata.field_of_view_shape import ( + FieldOfViewShape, +) # NOQA +from allensdk.brain_observatory.behavior.data_objects.metadata.ophys_experiment_metadata.imaging_depth import ( + ImagingDepth, +) # NOQA +from allensdk.brain_observatory.behavior.data_objects.metadata.ophys_experiment_metadata.multi_plane_metadata.imaging_plane_group import ( + ImagingPlaneGroup, +) # NOQA +from allensdk.brain_observatory.behavior.data_objects.metadata.ophys_experiment_metadata.ophys_container_id import ( + OphysContainerId, +) # NOQA +from allensdk.brain_observatory.behavior.data_objects.metadata.ophys_experiment_metadata.ophys_experiment_metadata import ( + OphysExperimentMetadata, +) # NOQA +from allensdk.brain_observatory.behavior.data_objects.metadata.ophys_experiment_metadata.ophys_session_id import ( + OphysSessionId, +) # NOQA +from allensdk.brain_observatory.behavior.data_objects.metadata.ophys_experiment_metadata.ophys_project_code import ( + OphysProjectCode, +) # NOQA +from allensdk.brain_observatory.behavior.data_objects.metadata.ophys_experiment_metadata.targeted_imaging_depth import ( + TargetedImagingDepth, +) # NOQA from allensdk.internal.api import PostgresQueryMixin class MultiplaneMetadata(OphysExperimentMetadata): - def __init__(self, - ophys_experiment_id: int, - ophys_session_id: OphysSessionId, - ophys_container_id: OphysContainerId, - field_of_view_shape: FieldOfViewShape, - imaging_depth: ImagingDepth, - targeted_imaging_depth: TargetedImagingDepth, - imaging_plane_group: ImagingPlaneGroup, - project_code: OphysProjectCode = OphysProjectCode()): + def __init__( + self, + ophys_experiment_id: int, + ophys_session_id: OphysSessionId, + ophys_container_id: OphysContainerId, + field_of_view_shape: FieldOfViewShape, + imaging_depth: ImagingDepth, + targeted_imaging_depth: TargetedImagingDepth, + imaging_plane_group: ImagingPlaneGroup, + project_code: OphysProjectCode = OphysProjectCode(), + ): super().__init__( ophys_experiment_id=ophys_experiment_id, ophys_session_id=ophys_session_id, @@ -28,18 +46,16 @@ def __init__(self, field_of_view_shape=field_of_view_shape, imaging_depth=imaging_depth, targeted_imaging_depth=targeted_imaging_depth, - project_code=project_code + project_code=project_code, ) self._imaging_plane_group = imaging_plane_group @classmethod - def from_lims( - cls, ophys_experiment_id: int, - lims_db: PostgresQueryMixin) -> "MultiplaneMetadata": + def from_lims(cls, ophys_experiment_id: int, lims_db: PostgresQueryMixin) -> "MultiplaneMetadata": ophys_experiment_metadata = OphysExperimentMetadata.from_lims( - ophys_experiment_id=ophys_experiment_id, lims_db=lims_db) - imaging_plane_group = ImagingPlaneGroup.from_lims( - ophys_experiment_id=ophys_experiment_id, lims_db=lims_db) + ophys_experiment_id=ophys_experiment_id, lims_db=lims_db + ) + imaging_plane_group = ImagingPlaneGroup.from_lims(ophys_experiment_id=ophys_experiment_id, lims_db=lims_db) return cls( ophys_experiment_id=ophys_experiment_metadata.ophys_experiment_id, ophys_session_id=ophys_experiment_metadata._ophys_session_id, @@ -48,7 +64,7 @@ def from_lims( imaging_depth=ophys_experiment_metadata._imaging_depth, targeted_imaging_depth=ophys_experiment_metadata._targeted_imaging_depth, # noqa E501 project_code=ophys_experiment_metadata._project_code, - imaging_plane_group=imaging_plane_group + imaging_plane_group=imaging_plane_group, ) @classmethod @@ -63,7 +79,7 @@ def from_json(cls, dict_repr: dict) -> "MultiplaneMetadata": imaging_depth=ophys_experiment_metadata._imaging_depth, targeted_imaging_depth=ophys_experiment_metadata._targeted_imaging_depth, # noqa E501 project_code=ophys_experiment_metadata._project_code, - imaging_plane_group=imaging_plane_group + imaging_plane_group=imaging_plane_group, ) @classmethod @@ -78,7 +94,7 @@ def from_nwb(cls, nwbfile: NWBFile) -> "MultiplaneMetadata": imaging_depth=ophys_experiment_metadata._imaging_depth, targeted_imaging_depth=ophys_experiment_metadata._targeted_imaging_depth, # noqa E501 project_code=ophys_experiment_metadata._project_code, - imaging_plane_group=imaging_plane_group + imaging_plane_group=imaging_plane_group, ) @property diff --git a/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/ophys_container_id.py b/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/ophys_container_id.py index b4791dccf1..9dabe63f22 100644 --- a/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/ophys_container_id.py +++ b/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/ophys_container_id.py @@ -1,21 +1,18 @@ from pynwb import NWBFile from allensdk.core import DataObject -from allensdk.core import \ - JsonReadableInterface, LimsReadableInterface, NwbReadableInterface +from allensdk.core import JsonReadableInterface, LimsReadableInterface, NwbReadableInterface from allensdk.internal.api import PostgresQueryMixin -class OphysContainerId(DataObject, LimsReadableInterface, - JsonReadableInterface, NwbReadableInterface): - """"experiment container id""" +class OphysContainerId(DataObject, LimsReadableInterface, JsonReadableInterface, NwbReadableInterface): + """ "experiment container id""" + def __init__(self, ophys_container_id: int): - super().__init__(name='ophys_container_id', - value=ophys_container_id) + super().__init__(name="ophys_container_id", value=ophys_container_id) @classmethod - def from_lims(cls, ophys_experiment_id: int, - lims_db: PostgresQueryMixin) -> "ExperimentContainerId": + def from_lims(cls, ophys_experiment_id: int, lims_db: PostgresQueryMixin) -> "ExperimentContainerId": query = """ SELECT visual_behavior_experiment_container_id FROM ophys_experiments_visual_behavior_experiment_containers @@ -26,12 +23,12 @@ def from_lims(cls, ophys_experiment_id: int, @classmethod def from_json(cls, dict_repr: dict) -> "ExperimentContainerId": - return cls(ophys_container_id=dict_repr['container_id']) + return cls(ophys_container_id=dict_repr["container_id"]) @classmethod def from_nwb(cls, nwbfile: NWBFile) -> "ExperimentContainerId": try: - metadata = nwbfile.lab_meta_data['metadata'] + metadata = nwbfile.lab_meta_data["metadata"] return cls(ophys_container_id=metadata.ophys_container_id) except AttributeError: return None diff --git a/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/ophys_experiment_metadata.py b/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/ophys_experiment_metadata.py index d7c67f9df5..11e05c53d1 100644 --- a/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/ophys_experiment_metadata.py +++ b/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/ophys_experiment_metadata.py @@ -46,9 +46,7 @@ def __init__( targeted_imaging_depth: TargetedImagingDepth, project_code: OphysProjectCode = OphysProjectCode(), ): - super().__init__( - name="ophys_experiment_metadata", value=None, is_value_self=True - ) + super().__init__(name="ophys_experiment_metadata", value=None, is_value_self=True) self._ophys_experiment_id = ophys_experiment_id self._ophys_session_id = ophys_session_id self._ophys_container_id = ophys_container_id @@ -58,27 +56,15 @@ def __init__( self._project_code = project_code @classmethod - def from_lims( - cls, ophys_experiment_id: int, lims_db: PostgresQueryMixin - ) -> "OphysExperimentMetadata": - ophys_session_id = OphysSessionId.from_lims( - ophys_experiment_id=ophys_experiment_id, lims_db=lims_db - ) - ophys_container_id = OphysContainerId.from_lims( - ophys_experiment_id=ophys_experiment_id, lims_db=lims_db - ) - field_of_view_shape = FieldOfViewShape.from_lims( - ophys_experiment_id=ophys_experiment_id, lims_db=lims_db - ) - imaging_depth = ImagingDepth.from_lims( - ophys_experiment_id=ophys_experiment_id, lims_db=lims_db - ) + def from_lims(cls, ophys_experiment_id: int, lims_db: PostgresQueryMixin) -> "OphysExperimentMetadata": + ophys_session_id = OphysSessionId.from_lims(ophys_experiment_id=ophys_experiment_id, lims_db=lims_db) + ophys_container_id = OphysContainerId.from_lims(ophys_experiment_id=ophys_experiment_id, lims_db=lims_db) + field_of_view_shape = FieldOfViewShape.from_lims(ophys_experiment_id=ophys_experiment_id, lims_db=lims_db) + imaging_depth = ImagingDepth.from_lims(ophys_experiment_id=ophys_experiment_id, lims_db=lims_db) targeted_imaging_depth = TargetedImagingDepth.from_lims( ophys_experiment_id=ophys_experiment_id, lims_db=lims_db ) - project_code = OphysProjectCode.from_lims( - ophys_experiment_id=ophys_experiment_id, lims_db=lims_db - ) + project_code = OphysProjectCode.from_lims(ophys_experiment_id=ophys_experiment_id, lims_db=lims_db) return cls( ophys_experiment_id=ophys_experiment_id, @@ -97,9 +83,7 @@ def from_json(cls, dict_repr: dict) -> "OphysExperimentMetadata": ophys_experiment_id = dict_repr["ophys_experiment_id"] field_of_view_shape = FieldOfViewShape.from_json(dict_repr=dict_repr) imaging_depth = ImagingDepth.from_json(dict_repr=dict_repr) - targeted_imaging_depth = TargetedImagingDepth.from_json( - dict_repr=dict_repr - ) + targeted_imaging_depth = TargetedImagingDepth.from_json(dict_repr=dict_repr) return OphysExperimentMetadata( ophys_experiment_id=ophys_experiment_id, @@ -155,9 +139,7 @@ def targeted_imaging_depth(self) -> int: return None return self._targeted_imaging_depth.value - def update_targeted_imaging_depth( - self, ophys_experiment_ids: List[int], lims_db: PostgresQueryMixin - ): + def update_targeted_imaging_depth(self, ophys_experiment_ids: List[int], lims_db: PostgresQueryMixin): """Update the value for targeted imaging depth given a set of experiments to be published. diff --git a/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/ophys_project_code.py b/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/ophys_project_code.py index b3de8ab208..aeecdcc6da 100644 --- a/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/ophys_project_code.py +++ b/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/ophys_project_code.py @@ -1,5 +1,4 @@ -from allensdk.brain_observatory.behavior.data_objects.metadata.\ - behavior_metadata.project_code import ProjectCode +from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.project_code import ProjectCode from allensdk.internal.api import PostgresQueryMixin @@ -11,8 +10,7 @@ class OphysProjectCode(ProjectCode): """ @classmethod - def from_lims(cls, ophys_experiment_id: int, - lims_db: PostgresQueryMixin) -> "OphysProjectCode": + def from_lims(cls, ophys_experiment_id: int, lims_db: PostgresQueryMixin) -> "OphysProjectCode": query = f""" SELECT ps.code AS project_code FROM ophys_sessions os diff --git a/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/ophys_session_id.py b/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/ophys_session_id.py index 725db2cbc0..69103be0a5 100644 --- a/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/ophys_session_id.py +++ b/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/ophys_session_id.py @@ -1,21 +1,18 @@ from pynwb import NWBFile from allensdk.core import DataObject -from allensdk.core import \ - JsonReadableInterface, LimsReadableInterface, NwbReadableInterface +from allensdk.core import JsonReadableInterface, LimsReadableInterface, NwbReadableInterface from allensdk.internal.api import PostgresQueryMixin -class OphysSessionId(DataObject, LimsReadableInterface, - JsonReadableInterface, NwbReadableInterface): - """"Ophys session id""" +class OphysSessionId(DataObject, LimsReadableInterface, JsonReadableInterface, NwbReadableInterface): + """ "Ophys session id""" + def __init__(self, session_id: int): - super().__init__(name='session_id', - value=session_id) + super().__init__(name="session_id", value=session_id) @classmethod - def from_lims(cls, ophys_experiment_id: int, - lims_db: PostgresQueryMixin) -> "OphysSessionId": + def from_lims(cls, ophys_experiment_id: int, lims_db: PostgresQueryMixin) -> "OphysSessionId": query = """ SELECT oe.ophys_session_id FROM ophys_experiments oe @@ -26,9 +23,9 @@ def from_lims(cls, ophys_experiment_id: int, @classmethod def from_json(cls, dict_repr: dict) -> "OphysSessionId": - return cls(session_id=dict_repr['ophys_session_id']) + return cls(session_id=dict_repr["ophys_session_id"]) @classmethod def from_nwb(cls, nwbfile: NWBFile) -> "OphysSessionId": - metadata = nwbfile.lab_meta_data['metadata'] + metadata = nwbfile.lab_meta_data["metadata"] return cls(session_id=metadata.ophys_session_id) diff --git a/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/targeted_imaging_depth.py b/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/targeted_imaging_depth.py index 1841aba087..579feed999 100644 --- a/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/targeted_imaging_depth.py +++ b/allensdk/brain_observatory/behavior/data_objects/metadata/ophys_experiment_metadata/targeted_imaging_depth.py @@ -22,9 +22,7 @@ class TargetedImagingDepth( """ def __init__(self, targeted_imaging_depth: int): - super().__init__( - name="targeted_imaging_depth", value=targeted_imaging_depth - ) + super().__init__(name="targeted_imaging_depth", value=targeted_imaging_depth) @classmethod def from_lims( @@ -50,9 +48,7 @@ def from_lims( SELECT visual_behavior_experiment_container_id FROM ophys_experiments_visual_behavior_experiment_containers WHERE ophys_experiment_id = {} - """.format( - ophys_experiment_id - ) + """.format(ophys_experiment_id) container_id = lims_db.fetchone(query_container_id, strict=True) @@ -62,9 +58,7 @@ def from_lims( JOIN ophys_experiments oe ON oe.id = ec.ophys_experiment_id LEFT JOIN imaging_depths imd ON imd.id = oe.imaging_depth_id WHERE ec.visual_behavior_experiment_container_id = {}; - """.format( - container_id - ) + """.format(container_id) depths = lims_db.select(query_depths).set_index("ophys_experiment_id") if ophys_experiment_ids is not None: if ophys_experiment_id not in ophys_experiment_ids: diff --git a/allensdk/brain_observatory/behavior/data_objects/metadata/subject_metadata/age.py b/allensdk/brain_observatory/behavior/data_objects/metadata/subject_metadata/age.py index 7fb93c4aea..b85c9792a8 100644 --- a/allensdk/brain_observatory/behavior/data_objects/metadata/subject_metadata/age.py +++ b/allensdk/brain_observatory/behavior/data_objects/metadata/subject_metadata/age.py @@ -35,12 +35,9 @@ def from_json(cls, dict_repr: dict) -> "Age": return cls(age=age) @classmethod - def from_lims( - cls, behavior_session_id: int, lims_db: PostgresQueryMixin - ) -> "Age": + def from_lims(cls, behavior_session_id: int, lims_db: PostgresQueryMixin) -> "Age": date_of_acquisition = DateOfAcquisition.from_lims( - behavior_session_id=behavior_session_id, - lims_db=lims_db + behavior_session_id=behavior_session_id, lims_db=lims_db ).value query = f""" @@ -49,9 +46,7 @@ def from_lims( JOIN donors d ON d.id = bs.donor_id WHERE bs.id = {behavior_session_id}; """ - date_of_birth = cls._check_timezone( - lims_db.fetchone(query, strict=True) - ) + date_of_birth = cls._check_timezone(lims_db.fetchone(query, strict=True)) age = (date_of_acquisition - date_of_birth).days return cls(age=age) @@ -87,20 +82,14 @@ def _age_code_to_days(age: str, warn=False) -> Optional[int]: """ if not age.startswith("P"): if warn: - warnings.warn( - "Could not parse numeric age from age code " - '(age code does not start with "P")' - ) + warnings.warn('Could not parse numeric age from age code (age code does not start with "P")') return None match = re.search(r"\d+", age) if match is None: if warn: - warnings.warn( - "Could not parse numeric age from age code " - "(no numeric values found in age code)" - ) + warnings.warn("Could not parse numeric age from age code (no numeric values found in age code)") return None start, end = match.span() diff --git a/allensdk/brain_observatory/behavior/data_objects/metadata/subject_metadata/driver_line.py b/allensdk/brain_observatory/behavior/data_objects/metadata/subject_metadata/driver_line.py index ca13d2f3ed..2164caf476 100644 --- a/allensdk/brain_observatory/behavior/data_objects/metadata/subject_metadata/driver_line.py +++ b/allensdk/brain_observatory/behavior/data_objects/metadata/subject_metadata/driver_line.py @@ -3,26 +3,22 @@ from pynwb import NWBFile from allensdk.core import DataObject -from allensdk.core import \ - JsonReadableInterface, LimsReadableInterface, NwbReadableInterface -from allensdk.internal.api import PostgresQueryMixin, \ - OneOrMoreResultExpectedError +from allensdk.core import JsonReadableInterface, LimsReadableInterface, NwbReadableInterface +from allensdk.internal.api import PostgresQueryMixin, OneOrMoreResultExpectedError -class DriverLine(DataObject, LimsReadableInterface, JsonReadableInterface, - NwbReadableInterface): +class DriverLine(DataObject, LimsReadableInterface, JsonReadableInterface, NwbReadableInterface): """the genotype name(s) of the driver line(s)""" + def __init__(self, driver_line: Optional[List[str]]): super().__init__(name="driver_line", value=driver_line) @classmethod def from_json(cls, dict_repr: dict) -> "DriverLine": - return cls(driver_line=dict_repr['driver_line']) + return cls(driver_line=dict_repr["driver_line"]) @classmethod - def from_lims(cls, behavior_session_id: int, - lims_db: PostgresQueryMixin, - allow_none: bool = True) -> "DriverLine": + def from_lims(cls, behavior_session_id: int, lims_db: PostgresQueryMixin, allow_none: bool = True) -> "DriverLine": """ Parameters ---------- @@ -54,9 +50,7 @@ def from_lims(cls, behavior_session_id: int, if allow_none: return cls(driver_line=None) - raise OneOrMoreResultExpectedError( - f"Expected one or more, but received: '{result}' " - f"from query:\n'{query}'") + raise OneOrMoreResultExpectedError(f"Expected one or more, but received: '{result}' from query:\n'{query}'") driver_line = sorted(result) return cls(driver_line=driver_line) diff --git a/allensdk/brain_observatory/behavior/data_objects/metadata/subject_metadata/full_genotype.py b/allensdk/brain_observatory/behavior/data_objects/metadata/subject_metadata/full_genotype.py index e49a1c340c..c5649f949d 100644 --- a/allensdk/brain_observatory/behavior/data_objects/metadata/subject_metadata/full_genotype.py +++ b/allensdk/brain_observatory/behavior/data_objects/metadata/subject_metadata/full_genotype.py @@ -4,27 +4,24 @@ from pynwb import NWBFile from allensdk.core import DataObject -from allensdk.core import \ - JsonReadableInterface, LimsReadableInterface, NwbReadableInterface +from allensdk.core import JsonReadableInterface, LimsReadableInterface, NwbReadableInterface from allensdk.internal.api import PostgresQueryMixin -class FullGenotype(DataObject, LimsReadableInterface, JsonReadableInterface, - NwbReadableInterface): +class FullGenotype(DataObject, LimsReadableInterface, JsonReadableInterface, NwbReadableInterface): """the name of the subject's genotype""" - def __init__(self, full_genotype: Optional[str]): + def __init__(self, full_genotype: Optional[str]): # casting full_genotype into a str because there are instances # in LIMS of full_genotype == NULL super().__init__(name="full_genotype", value=str(full_genotype)) @classmethod def from_json(cls, dict_repr: dict) -> "FullGenotype": - return cls(full_genotype=dict_repr['full_genotype']) + return cls(full_genotype=dict_repr["full_genotype"]) @classmethod - def from_lims(cls, behavior_session_id: int, - lims_db: PostgresQueryMixin) -> "FullGenotype": + def from_lims(cls, behavior_session_id: int, lims_db: PostgresQueryMixin) -> "FullGenotype": query = f""" SELECT d.full_genotype FROM behavior_sessions bs @@ -52,8 +49,8 @@ def parse_cre_line(self, warn=False) -> Optional[str]: parse """ full_genotype = self.value - if ';' not in full_genotype: + if ";" not in full_genotype: if warn: - warnings.warn('Unable to parse cre_line from full_genotype') + warnings.warn("Unable to parse cre_line from full_genotype") return None - return full_genotype.split(';')[0].replace('/wt', '') + return full_genotype.split(";")[0].replace("/wt", "") diff --git a/allensdk/brain_observatory/behavior/data_objects/metadata/subject_metadata/mouse_id.py b/allensdk/brain_observatory/behavior/data_objects/metadata/subject_metadata/mouse_id.py index 58842673af..28e2a989f2 100644 --- a/allensdk/brain_observatory/behavior/data_objects/metadata/subject_metadata/mouse_id.py +++ b/allensdk/brain_observatory/behavior/data_objects/metadata/subject_metadata/mouse_id.py @@ -29,9 +29,7 @@ def from_json(cls, dict_repr: dict) -> "MouseId": return cls(mouse_id=mouse_id) @classmethod - def from_lims( - cls, behavior_session_id: int, lims_db: PostgresQueryMixin - ) -> "MouseId": + def from_lims(cls, behavior_session_id: int, lims_db: PostgresQueryMixin) -> "MouseId": # TODO: Should this even be included? # Found sometimes there were entries with NONE which is # why they are filtered out; also many entries in the table diff --git a/allensdk/brain_observatory/behavior/data_objects/metadata/subject_metadata/reporter_line.py b/allensdk/brain_observatory/behavior/data_objects/metadata/subject_metadata/reporter_line.py index df9dc8fce6..71c07417c8 100644 --- a/allensdk/brain_observatory/behavior/data_objects/metadata/subject_metadata/reporter_line.py +++ b/allensdk/brain_observatory/behavior/data_objects/metadata/subject_metadata/reporter_line.py @@ -4,28 +4,26 @@ from pynwb import NWBFile from allensdk.core import DataObject -from allensdk.core import \ - JsonReadableInterface, LimsReadableInterface, NwbReadableInterface -from allensdk.internal.api import PostgresQueryMixin, \ - OneOrMoreResultExpectedError +from allensdk.core import JsonReadableInterface, LimsReadableInterface, NwbReadableInterface +from allensdk.internal.api import PostgresQueryMixin, OneOrMoreResultExpectedError -class ReporterLine(DataObject, LimsReadableInterface, JsonReadableInterface, - NwbReadableInterface): +class ReporterLine(DataObject, LimsReadableInterface, JsonReadableInterface, NwbReadableInterface): """the genotype name(s) of the reporter line(s)""" + def __init__(self, reporter_line: Optional[str]): super().__init__(name="reporter_line", value=reporter_line) @classmethod def from_json(cls, dict_repr: dict) -> "ReporterLine": - reporter_line = dict_repr['reporter_line'] + reporter_line = dict_repr["reporter_line"] reporter_line = cls.parse(reporter_line=reporter_line, warn=True) return cls(reporter_line=reporter_line) @classmethod - def from_lims(cls, behavior_session_id: int, - lims_db: PostgresQueryMixin, - allow_none: bool = True) -> "ReporterLine": + def from_lims( + cls, behavior_session_id: int, lims_db: PostgresQueryMixin, allow_none: bool = True + ) -> "ReporterLine": """ Parameters ---------- @@ -57,9 +55,7 @@ def from_lims(cls, behavior_session_id: int, if allow_none: return cls(reporter_line=None) - raise OneOrMoreResultExpectedError( - f"Expected one or more, but received: '{result}' " - f"from query:\n'{query}'") + raise OneOrMoreResultExpectedError(f"Expected one or more, but received: '{result}' from query:\n'{query}'") reporter_line = cls.parse(reporter_line=result, warn=True) return cls(reporter_line=reporter_line) @@ -72,8 +68,7 @@ def from_nwb(cls, nwbfile: NWBFile) -> "ReporterLine": return cls(reporter_line=reporter_line) @staticmethod - def parse(reporter_line: Union[Optional[List[str]], str], - warn=False) -> Optional[str]: + def parse(reporter_line: Union[Optional[List[str]], str], warn=False) -> Optional[str]: """There can be multiple reporter lines, so it is returned from LIMS as a list. But there shouldn't be more than 1 for behavior. This tries to convert to str @@ -91,13 +86,12 @@ def parse(reporter_line: Union[Optional[List[str]], str], """ if reporter_line is None: if warn: - warnings.warn('Error parsing reporter line. It is null.') + warnings.warn("Error parsing reporter line. It is null.") return None if len(reporter_line) == 0: if warn: - warnings.warn('Error parsing reporter line. ' - 'The array is empty') + warnings.warn("Error parsing reporter line. The array is empty") return None if isinstance(reporter_line, str): @@ -105,24 +99,17 @@ def parse(reporter_line: Union[Optional[List[str]], str], if len(reporter_line) > 1: if warn: - warnings.warn('More than 1 reporter line. Returning the first ' - 'one') + warnings.warn("More than 1 reporter line. Returning the first one") return reporter_line[0] def parse_indicator(self, warn=False) -> Optional[str]: """Parses indicator from reporter""" reporter_line = self._value - reporter_substring_indicator_map = { - 'GCaMP6f': 'GCaMP6f', - 'GC6f': 'GCaMP6f', - 'GCaMP6s': 'GCaMP6s' - } + reporter_substring_indicator_map = {"GCaMP6f": "GCaMP6f", "GC6f": "GCaMP6f", "GCaMP6s": "GCaMP6s"} if reporter_line is None: if warn: - warnings.warn( - 'Could not parse indicator from reporter because ' - 'there is no reporter') + warnings.warn("Could not parse indicator from reporter because there is no reporter") return None for substr, indicator in reporter_substring_indicator_map.items(): @@ -131,6 +118,7 @@ def parse_indicator(self, warn=False) -> Optional[str]: if warn: warnings.warn( - 'Could not parse indicator from reporter because none' - 'of the expected substrings were found in the reporter') + "Could not parse indicator from reporter because none" + "of the expected substrings were found in the reporter" + ) return None diff --git a/allensdk/brain_observatory/behavior/data_objects/metadata/subject_metadata/sex.py b/allensdk/brain_observatory/behavior/data_objects/metadata/subject_metadata/sex.py index c0c66822ff..80446d6337 100644 --- a/allensdk/brain_observatory/behavior/data_objects/metadata/subject_metadata/sex.py +++ b/allensdk/brain_observatory/behavior/data_objects/metadata/subject_metadata/sex.py @@ -1,14 +1,13 @@ from pynwb import NWBFile from allensdk.core import DataObject -from allensdk.core import \ - JsonReadableInterface, LimsReadableInterface, NwbReadableInterface +from allensdk.core import JsonReadableInterface, LimsReadableInterface, NwbReadableInterface from allensdk.internal.api import PostgresQueryMixin -class Sex(DataObject, LimsReadableInterface, JsonReadableInterface, - NwbReadableInterface): +class Sex(DataObject, LimsReadableInterface, JsonReadableInterface, NwbReadableInterface): """sex of the animal (M/F)""" + def __init__(self, sex: str): super().__init__(name="sex", value=sex) @@ -17,8 +16,7 @@ def from_json(cls, dict_repr: dict) -> "Sex": return cls(sex=dict_repr["sex"]) @classmethod - def from_lims(cls, behavior_session_id: int, - lims_db: PostgresQueryMixin) -> "Sex": + def from_lims(cls, behavior_session_id: int, lims_db: PostgresQueryMixin) -> "Sex": query = f""" SELECT g.name AS sex FROM behavior_sessions bs diff --git a/allensdk/brain_observatory/behavior/data_objects/metadata/subject_metadata/subject_metadata.py b/allensdk/brain_observatory/behavior/data_objects/metadata/subject_metadata/subject_metadata.py index 163f438229..387ded1c43 100644 --- a/allensdk/brain_observatory/behavior/data_objects/metadata/subject_metadata/subject_metadata.py +++ b/allensdk/brain_observatory/behavior/data_objects/metadata/subject_metadata/subject_metadata.py @@ -6,48 +6,40 @@ from allensdk.core import DataObject from allensdk.brain_observatory.behavior.data_objects import BehaviorSessionId -from allensdk.core import \ - JsonReadableInterface, LimsReadableInterface, NwbReadableInterface +from allensdk.core import JsonReadableInterface, LimsReadableInterface, NwbReadableInterface from allensdk.core import NwbWritableInterface -from allensdk.brain_observatory.behavior.data_objects.metadata \ - .subject_metadata.age import \ - Age -from allensdk.brain_observatory.behavior.data_objects.metadata \ - .subject_metadata.driver_line import \ - DriverLine -from allensdk.brain_observatory.behavior.data_objects.metadata \ - .subject_metadata.full_genotype import \ - FullGenotype -from allensdk.brain_observatory.behavior.data_objects.metadata \ - .subject_metadata.mouse_id import \ - MouseId -from allensdk.brain_observatory.behavior.data_objects.metadata \ - .subject_metadata.reporter_line import \ - ReporterLine -from allensdk.brain_observatory.behavior.data_objects.metadata \ - .subject_metadata.sex import \ - Sex +from allensdk.brain_observatory.behavior.data_objects.metadata.subject_metadata.age import Age +from allensdk.brain_observatory.behavior.data_objects.metadata.subject_metadata.driver_line import DriverLine +from allensdk.brain_observatory.behavior.data_objects.metadata.subject_metadata.full_genotype import FullGenotype +from allensdk.brain_observatory.behavior.data_objects.metadata.subject_metadata.mouse_id import MouseId +from allensdk.brain_observatory.behavior.data_objects.metadata.subject_metadata.reporter_line import ReporterLine +from allensdk.brain_observatory.behavior.data_objects.metadata.subject_metadata.sex import Sex from allensdk.brain_observatory.behavior.schemas import SubjectMetadataSchema from allensdk.brain_observatory.nwb import load_pynwb_extension from allensdk.core.auth_config import LIMS_DB_CREDENTIAL_MAP from allensdk.internal.api import PostgresQueryMixin, db_connection_creator -class SubjectMetadata(DataObject, LimsReadableInterface, NwbReadableInterface, - NwbWritableInterface, JsonReadableInterface, - ): +class SubjectMetadata( + DataObject, + LimsReadableInterface, + NwbReadableInterface, + NwbWritableInterface, + JsonReadableInterface, +): """Subject metadata""" - def __init__(self, - sex: Sex, - age: Age, - reporter_line: ReporterLine, - full_genotype: FullGenotype, - driver_line: DriverLine, - mouse_id: MouseId, - death_on: Optional[datetime] = None): - super().__init__(name='subject_metadata', value=None, - is_value_self=True) + def __init__( + self, + sex: Sex, + age: Age, + reporter_line: ReporterLine, + full_genotype: FullGenotype, + driver_line: DriverLine, + mouse_id: MouseId, + death_on: Optional[datetime] = None, + ): + super().__init__(name="subject_metadata", value=None, is_value_self=True) if death_on is not None and death_on.tzinfo is None: # Add UTC tzinfo if not already set death_on = pytz.utc.localize(death_on) @@ -60,28 +52,14 @@ def __init__(self, self._death_on = death_on @classmethod - def from_lims( - cls, - behavior_session_id: BehaviorSessionId, - lims_db: PostgresQueryMixin - ) -> "SubjectMetadata": - sex = Sex.from_lims(behavior_session_id=behavior_session_id.value, - lims_db=lims_db) - age = Age.from_lims(behavior_session_id=behavior_session_id.value, - lims_db=lims_db) - reporter_line = ReporterLine.from_lims( - behavior_session_id=behavior_session_id.value, - lims_db=lims_db - ) - full_genotype = FullGenotype.from_lims( - behavior_session_id=behavior_session_id.value, lims_db=lims_db) - driver_line = DriverLine.from_lims( - behavior_session_id=behavior_session_id.value, lims_db=lims_db) - mouse_id = MouseId.from_lims( - behavior_session_id=behavior_session_id.value, - lims_db=lims_db) - death_on = cls._get_death_date_from_lims( - mouse_id=mouse_id.value, lims_db=lims_db) + def from_lims(cls, behavior_session_id: BehaviorSessionId, lims_db: PostgresQueryMixin) -> "SubjectMetadata": + sex = Sex.from_lims(behavior_session_id=behavior_session_id.value, lims_db=lims_db) + age = Age.from_lims(behavior_session_id=behavior_session_id.value, lims_db=lims_db) + reporter_line = ReporterLine.from_lims(behavior_session_id=behavior_session_id.value, lims_db=lims_db) + full_genotype = FullGenotype.from_lims(behavior_session_id=behavior_session_id.value, lims_db=lims_db) + driver_line = DriverLine.from_lims(behavior_session_id=behavior_session_id.value, lims_db=lims_db) + mouse_id = MouseId.from_lims(behavior_session_id=behavior_session_id.value, lims_db=lims_db) + death_on = cls._get_death_date_from_lims(mouse_id=mouse_id.value, lims_db=lims_db) return cls( sex=sex, age=age, @@ -89,7 +67,8 @@ def from_lims( driver_line=driver_line, mouse_id=mouse_id, reporter_line=reporter_line, - death_on=death_on) + death_on=death_on, + ) @classmethod def from_json(cls, dict_repr: dict) -> "SubjectMetadata": @@ -100,9 +79,8 @@ def from_json(cls, dict_repr: dict) -> "SubjectMetadata": driver_line = DriverLine.from_json(dict_repr=dict_repr) mouse_id = MouseId.from_json(dict_repr=dict_repr) death_on = cls._get_death_date_from_lims( - mouse_id=mouse_id.value, - lims_db=db_connection_creator( - fallback_credentials=LIMS_DB_CREDENTIAL_MAP)) + mouse_id=mouse_id.value, lims_db=db_connection_creator(fallback_credentials=LIMS_DB_CREDENTIAL_MAP) + ) return cls( sex=sex, @@ -111,7 +89,7 @@ def from_json(cls, dict_repr: dict) -> "SubjectMetadata": driver_line=driver_line, mouse_id=mouse_id, reporter_line=reporter_line, - death_on=death_on + death_on=death_on, ) @classmethod @@ -129,12 +107,11 @@ def from_nwb(cls, nwbfile: NWBFile) -> "SubjectMetadata": age=age, reporter_line=reporter_line, driver_line=driver_line, - full_genotype=genotype + full_genotype=genotype, ) def to_nwb(self, nwbfile: NWBFile) -> NWBFile: - BehaviorSubject = load_pynwb_extension(SubjectMetadataSchema, - 'ndx-aibs-behavior-ophys') + BehaviorSubject = load_pynwb_extension(SubjectMetadataSchema, "ndx-aibs-behavior-ophys") nwb_subject = BehaviorSubject( description="A visual behavior subject with a LabTracks ID", age=Age.to_iso8601(age=self.age_in_days), @@ -143,7 +120,8 @@ def to_nwb(self, nwbfile: NWBFile) -> NWBFile: subject_id=str(self.mouse_id), reporter_line=str(self.reporter_line), sex=self.sex, - species='Mus musculus') + species="Mus musculus", + ) nwbfile.subject = nwb_subject return nwbfile @@ -185,11 +163,7 @@ def get_death_date(self) -> Optional[datetime]: return self._death_on @classmethod - def _get_death_date_from_lims( - cls, - mouse_id: int, - lims_db: PostgresQueryMixin - ): + def _get_death_date_from_lims(cls, mouse_id: int, lims_db: PostgresQueryMixin): query = f""" SELECT death_on FROM donors @@ -199,7 +173,7 @@ def _get_death_date_from_lims( res = res[0] if res is not None: # convert to datetime.datetime - res = res.astype('datetime64[s]').astype(datetime) + res = res.astype("datetime64[s]").astype(datetime) res = pytz.utc.localize(res) return res diff --git a/allensdk/brain_observatory/behavior/data_objects/motion_correction.py b/allensdk/brain_observatory/behavior/data_objects/motion_correction.py index dbd4a83d0f..de9688a201 100644 --- a/allensdk/brain_observatory/behavior/data_objects/motion_correction.py +++ b/allensdk/brain_observatory/behavior/data_objects/motion_correction.py @@ -1,19 +1,15 @@ import pandas as pd from pynwb import NWBFile, TimeSeries -from allensdk.brain_observatory.behavior.data_files\ - .rigid_motion_transform_file import \ - RigidMotionTransformFile +from allensdk.brain_observatory.behavior.data_files.rigid_motion_transform_file import RigidMotionTransformFile from allensdk.core import DataObject -from allensdk.core import \ - DataFileReadableInterface, NwbReadableInterface -from allensdk.core import \ - NwbWritableInterface +from allensdk.core import DataFileReadableInterface, NwbReadableInterface +from allensdk.core import NwbWritableInterface -class MotionCorrection(DataObject, DataFileReadableInterface, - NwbReadableInterface, NwbWritableInterface): +class MotionCorrection(DataObject, DataFileReadableInterface, NwbReadableInterface, NwbWritableInterface): """motion correction output""" + def __init__(self, motion_correction: pd.DataFrame): """ :param motion_correction @@ -21,46 +17,35 @@ def __init__(self, motion_correction: pd.DataFrame): x: float y: float """ - super().__init__(name='motion_correction', value=motion_correction) + super().__init__(name="motion_correction", value=motion_correction) @classmethod - def from_data_file( - cls, rigid_motion_transform_file: RigidMotionTransformFile) \ - -> "MotionCorrection": + def from_data_file(cls, rigid_motion_transform_file: RigidMotionTransformFile) -> "MotionCorrection": df = rigid_motion_transform_file.data return cls(motion_correction=df) @classmethod def from_nwb(cls, nwbfile: NWBFile) -> "MotionCorrection": - ophys_module = nwbfile.processing['ophys'] + ophys_module = nwbfile.processing["ophys"] motion_correction_data = { - 'x': ophys_module.get_data_interface( - 'ophys_motion_correction_x').data[:], - 'y': ophys_module.get_data_interface( - 'ophys_motion_correction_y').data[:] + "x": ophys_module.get_data_interface("ophys_motion_correction_x").data[:], + "y": ophys_module.get_data_interface("ophys_motion_correction_y").data[:], } df = pd.DataFrame(motion_correction_data) return cls(motion_correction=df) def to_nwb(self, nwbfile: NWBFile) -> NWBFile: - ophys_module = nwbfile.processing['ophys'] - ophys_timestamps = ophys_module.get_data_interface( - 'dff').roi_response_series['traces'].timestamps + ophys_module = nwbfile.processing["ophys"] + ophys_timestamps = ophys_module.get_data_interface("dff").roi_response_series["traces"].timestamps t1 = TimeSeries( - name='ophys_motion_correction_x', - data=self.value['x'].values, - timestamps=ophys_timestamps, - unit='pixels' + name="ophys_motion_correction_x", data=self.value["x"].values, timestamps=ophys_timestamps, unit="pixels" ) t2 = TimeSeries( - name='ophys_motion_correction_y', - data=self.value['y'].values, - timestamps=ophys_timestamps, - unit='pixels' + name="ophys_motion_correction_y", data=self.value["y"].values, timestamps=ophys_timestamps, unit="pixels" ) ophys_module.add_data_interface(t1) diff --git a/allensdk/brain_observatory/behavior/data_objects/projections.py b/allensdk/brain_observatory/behavior/data_objects/projections.py index 08a553629b..5a3dcbfbf7 100644 --- a/allensdk/brain_observatory/behavior/data_objects/projections.py +++ b/allensdk/brain_observatory/behavior/data_objects/projections.py @@ -2,22 +2,17 @@ from pynwb import NWBFile from allensdk.core import DataObject -from allensdk.core import \ - JsonReadableInterface, NwbReadableInterface, \ - LimsReadableInterface -from allensdk.core import \ - NwbWritableInterface +from allensdk.core import JsonReadableInterface, NwbReadableInterface, LimsReadableInterface +from allensdk.core import NwbWritableInterface from allensdk.brain_observatory.behavior.image_api import ImageApi, Image -from allensdk.brain_observatory.nwb.nwb_utils import get_image, \ - add_image_to_nwb +from allensdk.brain_observatory.nwb.nwb_utils import get_image, add_image_to_nwb from allensdk.internal.api import PostgresQueryMixin from allensdk.internal.core.lims_utilities import safe_system_path -class Projections(DataObject, LimsReadableInterface, JsonReadableInterface, - NwbReadableInterface, NwbWritableInterface): +class Projections(DataObject, LimsReadableInterface, JsonReadableInterface, NwbReadableInterface, NwbWritableInterface): def __init__(self, max_projection: Image, avg_projection: Image): - super().__init__(name='projections', value=None, is_value_self=True) + super().__init__(name="projections", value=None, is_value_self=True) self._max_projection = max_projection self._avg_projection = avg_projection @@ -30,8 +25,7 @@ def avg_projection(self) -> Image: return self._avg_projection @classmethod - def from_lims(cls, ophys_experiment_id: int, - lims_db: PostgresQueryMixin) -> "Projections": + def from_lims(cls, ophys_experiment_id: int, lims_db: PostgresQueryMixin) -> "Projections": def _get_filepaths(): """ Note @@ -61,9 +55,10 @@ def _get_filepaths(): # Check if the projections are attached to motion correction/the # ophys_experiment. If not, this is an older experiment and # need to load the projections from the segmentation. - if 'OphysMaxIntImage' not in res['wkfn'].to_list() \ - or 'OphysAverageIntensityProjectionImage' \ - not in res['wkfn'].to_list(): + if ( + "OphysMaxIntImage" not in res["wkfn"].to_list() + or "OphysAverageIntensityProjectionImage" not in res["wkfn"].to_list() + ): query = """ SELECT wkf.storage_directory || wkf.filename AS filepath, @@ -81,7 +76,7 @@ def _get_filepaths(): AND oe.id = {}; """.format(ophys_experiment_id) res = lims_db.select(query=query) - res['filepath'] = res['filepath'].apply(safe_system_path) + res["filepath"] = res["filepath"].apply(safe_system_path) return res def _get_pixel_size(): @@ -96,51 +91,34 @@ def _get_pixel_size(): res = _get_filepaths() pixel_size = _get_pixel_size() - max_projection_filepath = \ - res[res['wkfn'] == 'OphysMaxIntImage'].iloc[0]['filepath'] - max_projection = cls._from_filepath(filepath=max_projection_filepath, - pixel_size=pixel_size) + max_projection_filepath = res[res["wkfn"] == "OphysMaxIntImage"].iloc[0]["filepath"] + max_projection = cls._from_filepath(filepath=max_projection_filepath, pixel_size=pixel_size) - avg_projection_filepath = \ - (res[res['wkfn'] == 'OphysAverageIntensityProjectionImage'].iloc[0] - ['filepath']) - avg_projection = cls._from_filepath(filepath=avg_projection_filepath, - pixel_size=pixel_size) - return Projections(max_projection=max_projection, - avg_projection=avg_projection) + avg_projection_filepath = res[res["wkfn"] == "OphysAverageIntensityProjectionImage"].iloc[0]["filepath"] + avg_projection = cls._from_filepath(filepath=avg_projection_filepath, pixel_size=pixel_size) + return Projections(max_projection=max_projection, avg_projection=avg_projection) @classmethod def from_nwb(cls, nwbfile: NWBFile) -> "Projections": - max_projection = get_image(nwbfile=nwbfile, name='max_projection', - module='ophys') - avg_projection = get_image(nwbfile=nwbfile, name='average_image', - module='ophys') - return Projections(max_projection=max_projection, - avg_projection=avg_projection) + max_projection = get_image(nwbfile=nwbfile, name="max_projection", module="ophys") + avg_projection = get_image(nwbfile=nwbfile, name="average_image", module="ophys") + return Projections(max_projection=max_projection, avg_projection=avg_projection) def to_nwb(self, nwbfile: NWBFile) -> NWBFile: - add_image_to_nwb(nwbfile=nwbfile, - image_data=self._max_projection, - image_name='max_projection') - add_image_to_nwb(nwbfile=nwbfile, - image_data=self._avg_projection, - image_name='average_image') + add_image_to_nwb(nwbfile=nwbfile, image_data=self._max_projection, image_name="max_projection") + add_image_to_nwb(nwbfile=nwbfile, image_data=self._avg_projection, image_name="average_image") return nwbfile @classmethod def from_json(cls, dict_repr: dict) -> "Projections": - max_projection_filepath = dict_repr['max_projection_file'] - avg_projection_filepath = \ - dict_repr['average_intensity_projection_image_file'] - pixel_size = dict_repr['surface_2p_pixel_size_um'] - - max_projection = cls._from_filepath(filepath=max_projection_filepath, - pixel_size=pixel_size) - avg_projection = cls._from_filepath(filepath=avg_projection_filepath, - pixel_size=pixel_size) - return Projections(max_projection=max_projection, - avg_projection=avg_projection) + max_projection_filepath = dict_repr["max_projection_file"] + avg_projection_filepath = dict_repr["average_intensity_projection_image_file"] + pixel_size = dict_repr["surface_2p_pixel_size_um"] + + max_projection = cls._from_filepath(filepath=max_projection_filepath, pixel_size=pixel_size) + avg_projection = cls._from_filepath(filepath=avg_projection_filepath, pixel_size=pixel_size) + return Projections(max_projection=max_projection, avg_projection=avg_projection) @staticmethod def _from_filepath(filepath: str, pixel_size: float) -> Image: @@ -151,7 +129,6 @@ def _from_filepath(filepath: str, pixel_size: float) -> Image: pixel size in um """ img = PILImage.open(filepath) - img = ImageApi.serialize(img, [pixel_size / 1000., - pixel_size / 1000.], 'mm') + img = ImageApi.serialize(img, [pixel_size / 1000.0, pixel_size / 1000.0], "mm") img = ImageApi.deserialize(img=img) return img diff --git a/allensdk/brain_observatory/behavior/data_objects/rewards.py b/allensdk/brain_observatory/behavior/data_objects/rewards.py index 1dcc0a745c..2eae30297d 100644 --- a/allensdk/brain_observatory/behavior/data_objects/rewards.py +++ b/allensdk/brain_observatory/behavior/data_objects/rewards.py @@ -7,32 +7,30 @@ from allensdk.brain_observatory.behavior.data_files import BehaviorStimulusFile from allensdk.core import DataObject from allensdk.brain_observatory.behavior.data_objects import StimulusTimestamps -from allensdk.core import \ - NwbReadableInterface -from allensdk.brain_observatory.behavior.data_files.stimulus_file import \ - StimulusFileReadableInterface -from allensdk.core import \ - NwbWritableInterface +from allensdk.core import NwbReadableInterface +from allensdk.brain_observatory.behavior.data_files.stimulus_file import StimulusFileReadableInterface +from allensdk.core import NwbWritableInterface -class Rewards(DataObject, StimulusFileReadableInterface, NwbReadableInterface, - NwbWritableInterface): +class Rewards(DataObject, StimulusFileReadableInterface, NwbReadableInterface, NwbWritableInterface): def __init__(self, rewards: pd.DataFrame): - super().__init__(name='rewards', value=rewards) + super().__init__(name="rewards", value=rewards) @classmethod def from_stimulus_file( - cls, stimulus_file: BehaviorStimulusFile, - stimulus_timestamps: StimulusTimestamps) -> "Rewards": + cls, stimulus_file: BehaviorStimulusFile, stimulus_timestamps: StimulusTimestamps + ) -> "Rewards": """Get reward data from pkl file, based on timestamps (not sync file). """ if not np.isclose(stimulus_timestamps.monitor_delay, 0.0): - msg = ("Instantiating rewards with monitor_delay = " - f"{stimulus_timestamps.monitor_delay: .2e}; " - "monitor_delay should be zero for Rewards " - "data object") + msg = ( + "Instantiating rewards with monitor_delay = " + f"{stimulus_timestamps.monitor_delay: .2e}; " + "monitor_delay should be zero for Rewards " + "data object" + ) raise RuntimeError(msg) data = stimulus_file.data @@ -44,8 +42,7 @@ def from_stimulus_file( # as i write this there can only ever be one reward per trial if rewards: rewards_dict["volume"].append(rewards[0][0]) - rewards_dict["timestamps"].append( - stimulus_timestamps.value[rewards[0][2]]) + rewards_dict["timestamps"].append(stimulus_timestamps.value[rewards[0][2]]) auto_rwrd = trial["trial_params"]["auto_reward"] rewards_dict["auto_rewarded"].append(auto_rwrd) @@ -54,46 +51,38 @@ def from_stimulus_file( @classmethod def from_nwb(cls, nwbfile: NWBFile) -> Optional["Rewards"]: - if 'rewards' in nwbfile.processing: - rewards = nwbfile.processing['rewards'] - time = rewards.get_data_interface('autorewarded').timestamps[:] - autorewarded = rewards.get_data_interface('autorewarded').data[:] - volume = rewards.get_data_interface('volume').data[:] + if "rewards" in nwbfile.processing: + rewards = nwbfile.processing["rewards"] + time = rewards.get_data_interface("autorewarded").timestamps[:] + autorewarded = rewards.get_data_interface("autorewarded").data[:] + volume = rewards.get_data_interface("volume").data[:] else: volume = [] time = [] autorewarded = [] - df = pd.DataFrame({ - 'volume': volume, - 'timestamps': time, - 'auto_rewarded': autorewarded}) + df = pd.DataFrame({"volume": volume, "timestamps": time, "auto_rewarded": autorewarded}) return cls(rewards=df) def to_nwb(self, nwbfile: NWBFile) -> NWBFile: - # If there is no rewards data, do not # write anything to the NWB file (this # is expected for passive sessions) - if len(self.value['timestamps']) == 0: + if len(self.value["timestamps"]) == 0: return nwbfile reward_volume_ts = TimeSeries( - name='volume', - data=self.value['volume'].values, - timestamps=self.value['timestamps'].values, - unit='mL' + name="volume", data=self.value["volume"].values, timestamps=self.value["timestamps"].values, unit="mL" ) autorewarded_ts = TimeSeries( - name='autorewarded', - data=self.value['auto_rewarded'].values, + name="autorewarded", + data=self.value["auto_rewarded"].values, timestamps=reward_volume_ts.timestamps, - unit='mL' + unit="mL", ) - rewards_mod = ProcessingModule('rewards', - 'Licking behavior processing module') + rewards_mod = ProcessingModule("rewards", "Licking behavior processing module") rewards_mod.add_data_interface(reward_volume_ts) rewards_mod.add_data_interface(autorewarded_ts) nwbfile.add_processing_module(rewards_mod) diff --git a/allensdk/brain_observatory/behavior/data_objects/running_speed/multi_stim_running_processing.py b/allensdk/brain_observatory/behavior/data_objects/running_speed/multi_stim_running_processing.py index 3b17143bd3..99d7dfabaa 100644 --- a/allensdk/brain_observatory/behavior/data_objects/running_speed/multi_stim_running_processing.py +++ b/allensdk/brain_observatory/behavior/data_objects/running_speed/multi_stim_running_processing.py @@ -6,22 +6,19 @@ BehaviorStimulusFile, MappingStimulusFile, ReplayStimulusFile, - _StimulusFile) + _StimulusFile, +) -from allensdk.brain_observatory.sync_stim_aligner import ( - get_stim_timestamps_from_stimulus_blocks) +from allensdk.brain_observatory.sync_stim_aligner import get_stim_timestamps_from_stimulus_blocks -from allensdk.brain_observatory.behavior.data_objects.\ - running_speed.running_processing import ( - get_running_df - ) +from allensdk.brain_observatory.behavior.data_objects.running_speed.running_processing import get_running_df def _extract_dx_info( - frame_times: np.ndarray, - stimulus_file: _StimulusFile, - zscore_threshold: float = 10.0, - use_lowpass_filter: bool = True + frame_times: np.ndarray, + stimulus_file: _StimulusFile, + zscore_threshold: float = 10.0, + use_lowpass_filter: bool = True, ) -> pd.core.frame.DataFrame: """ Extract all of the running speed data @@ -54,12 +51,7 @@ def _extract_dx_info( stim_file = stimulus_file.data - velocities = get_running_df( - stim_file, - frame_times, - use_lowpass_filter, - zscore_threshold - ) + velocities = get_running_df(stim_file, frame_times, use_lowpass_filter, zscore_threshold) return velocities @@ -69,7 +61,7 @@ def _merge_dx_data( behavior_velocities: pd.core.frame.DataFrame, replay_velocities: pd.core.frame.DataFrame, frame_times: np.ndarray, - behavior_start_frame: int + behavior_start_frame: int, ) -> Tuple[pd.DataFrame, pd.DataFrame]: """ Concatenate all of the running speed data @@ -95,49 +87,23 @@ def _merge_dx_data( """ speed = np.concatenate( - ( - behavior_velocities['speed'], - mapping_velocities['speed'], - replay_velocities['speed']), - axis=None - ) - - dx = np.concatenate( - ( - behavior_velocities['dx'], - mapping_velocities['dx'], - replay_velocities['dx']), - axis=None + (behavior_velocities["speed"], mapping_velocities["speed"], replay_velocities["speed"]), axis=None ) + dx = np.concatenate((behavior_velocities["dx"], mapping_velocities["dx"], replay_velocities["dx"]), axis=None) + vsig = np.concatenate( - ( - behavior_velocities['v_sig'], - mapping_velocities['v_sig'], - replay_velocities['v_sig']), - axis=None + (behavior_velocities["v_sig"], mapping_velocities["v_sig"], replay_velocities["v_sig"]), axis=None ) vin = np.concatenate( - ( - behavior_velocities['v_in'], - mapping_velocities['v_in'], - replay_velocities['v_in']), - axis=None + (behavior_velocities["v_in"], mapping_velocities["v_in"], replay_velocities["v_in"]), axis=None ) - frame_indexes = list( - range(behavior_start_frame, - behavior_start_frame+len(frame_times)) - ) + frame_indexes = list(range(behavior_start_frame, behavior_start_frame + len(frame_times))) velocities = pd.DataFrame( - { - "velocity": speed, - "net_rotation": dx, - "frame_indexes": frame_indexes, - "frame_time": frame_times - } + {"velocity": speed, "net_rotation": dx, "frame_indexes": frame_indexes, "frame_time": frame_times} ) # Warning - the 'isclose' line below needs to be refactored @@ -149,9 +115,7 @@ def _merge_dx_data( # there may be exact zeros in the velocity. velocities = velocities[~(np.isclose(velocities["net_rotation"], 0.0))] - raw_data = pd.DataFrame( - {"vsig": vsig, "vin": vin, "frame_time": frame_times, "dx": dx} - ) + raw_data = pd.DataFrame({"vsig": vsig, "vin": vin, "frame_time": frame_times, "dx": dx}) return (velocities, raw_data) @@ -163,7 +127,7 @@ def multi_stim_running_df_from_raw_data( replay_stimulus_file: ReplayStimulusFile, use_lowpass_filter: bool, zscore_threshold: float, - behavior_start_frame: int + behavior_start_frame: int, ) -> Tuple[pd.DataFrame, pd.DataFrame]: """ Derive running speed data frames from sync file and @@ -221,15 +185,12 @@ def multi_stim_running_df_from_raw_data( """ timestamp_results = get_stim_timestamps_from_stimulus_blocks( - stimulus_files=[behavior_stimulus_file, - mapping_stimulus_file, - replay_stimulus_file], - sync_file=sync_path, - raw_frame_time_lines=['frames', - 'stim_vsync', - 'vsync_stim'], - raw_frame_time_direction='rising', - frame_count_tolerance=0.0) + stimulus_files=[behavior_stimulus_file, mapping_stimulus_file, replay_stimulus_file], + sync_file=sync_path, + raw_frame_time_lines=["frames", "stim_vsync", "vsync_stim"], + raw_frame_time_direction="rising", + frame_count_tolerance=0.0, + ) behavior_timestamps = timestamp_results["timestamps"][0] mapping_timestamps = timestamp_results["timestamps"][1] @@ -241,46 +202,44 @@ def multi_stim_running_df_from_raw_data( frame_times=behavior_timestamps, stimulus_file=behavior_stimulus_file, zscore_threshold=zscore_threshold, - use_lowpass_filter=use_lowpass_filter + use_lowpass_filter=use_lowpass_filter, ) mapping_velocities = _extract_dx_info( frame_times=mapping_timestamps, stimulus_file=mapping_stimulus_file, zscore_threshold=zscore_threshold, - use_lowpass_filter=use_lowpass_filter + use_lowpass_filter=use_lowpass_filter, ) replay_velocities = _extract_dx_info( frame_times=replay_timestamps, stimulus_file=replay_stimulus_file, zscore_threshold=zscore_threshold, - use_lowpass_filter=use_lowpass_filter + use_lowpass_filter=use_lowpass_filter, ) - all_frame_times = np.concatenate( - [behavior_timestamps, - mapping_timestamps, - replay_timestamps]) + all_frame_times = np.concatenate([behavior_timestamps, mapping_timestamps, replay_timestamps]) velocities, raw_data = _merge_dx_data( mapping_velocities=mapping_velocities, behavior_velocities=behavior_velocities, replay_velocities=replay_velocities, frame_times=all_frame_times, - behavior_start_frame=start_frame + behavior_start_frame=start_frame, ) return (velocities, raw_data) def _get_multi_stim_running_df( - sync_path: str, - behavior_stimulus_file: BehaviorStimulusFile, - mapping_stimulus_file: MappingStimulusFile, - replay_stimulus_file: ReplayStimulusFile, - use_lowpass_filter: bool, - zscore_threshold: float) -> Dict[str, pd.DataFrame]: + sync_path: str, + behavior_stimulus_file: BehaviorStimulusFile, + mapping_stimulus_file: MappingStimulusFile, + replay_stimulus_file: ReplayStimulusFile, + use_lowpass_filter: bool, + zscore_threshold: float, +) -> Dict[str, pd.DataFrame]: """ Parameters ---------- @@ -308,29 +267,27 @@ def _get_multi_stim_running_df( 'running_acquisition': A dataframe mapping time to raw data coming off the running wheel """ - (velocity_data, - acq_data) = multi_stim_running_df_from_raw_data( - sync_path=sync_path, - behavior_stimulus_file=behavior_stimulus_file, - mapping_stimulus_file=mapping_stimulus_file, - replay_stimulus_file=replay_stimulus_file, - use_lowpass_filter=use_lowpass_filter, - zscore_threshold=zscore_threshold, - behavior_start_frame=0) + (velocity_data, acq_data) = multi_stim_running_df_from_raw_data( + sync_path=sync_path, + behavior_stimulus_file=behavior_stimulus_file, + mapping_stimulus_file=mapping_stimulus_file, + replay_stimulus_file=replay_stimulus_file, + use_lowpass_filter=use_lowpass_filter, + zscore_threshold=zscore_threshold, + behavior_start_frame=0, + ) running_speed = pd.DataFrame( - data={ - 'timestamps': velocity_data.frame_time.values, - 'speed': velocity_data.velocity.values - }) + data={"timestamps": velocity_data.frame_time.values, "speed": velocity_data.velocity.values} + ) running_acq = pd.DataFrame( - data={ - 'dx': acq_data.dx.values, - 'timestamps': acq_data.frame_time.values, - 'v_in': acq_data.vin.values, - 'v_sig': acq_data.vsig.values - }).set_index('timestamps') - - return {'running_speed': running_speed, - 'running_acquisition': running_acq} + data={ + "dx": acq_data.dx.values, + "timestamps": acq_data.frame_time.values, + "v_in": acq_data.vin.values, + "v_sig": acq_data.vsig.values, + } + ).set_index("timestamps") + + return {"running_speed": running_speed, "running_acquisition": running_acq} diff --git a/allensdk/brain_observatory/behavior/data_objects/running_speed/running_acquisition.py b/allensdk/brain_observatory/behavior/data_objects/running_speed/running_acquisition.py index c554206519..5721193ebb 100644 --- a/allensdk/brain_observatory/behavior/data_objects/running_speed/running_acquisition.py +++ b/allensdk/brain_observatory/behavior/data_objects/running_speed/running_acquisition.py @@ -1,4 +1,3 @@ - from typing import Optional import pandas as pd @@ -12,23 +11,17 @@ from allensdk.core import DataObject from allensdk.brain_observatory.behavior.data_objects import StimulusTimestamps from allensdk.brain_observatory.behavior.data_files import SyncFile -from allensdk.brain_observatory.behavior.data_files import ( - BehaviorStimulusFile, - ReplayStimulusFile, - MappingStimulusFile -) +from allensdk.brain_observatory.behavior.data_files import BehaviorStimulusFile, ReplayStimulusFile, MappingStimulusFile from allensdk.brain_observatory.behavior.data_objects.running_speed.running_processing import ( # noqa: E501 - get_running_df + get_running_df, ) -from allensdk.brain_observatory.behavior.data_objects.\ - running_speed.multi_stim_running_processing import ( - _get_multi_stim_running_df) +from allensdk.brain_observatory.behavior.data_objects.running_speed.multi_stim_running_processing import ( + _get_multi_stim_running_df, +) -class RunningAcquisition(DataObject, - NwbReadableInterface, - NwbWritableInterface): +class RunningAcquisition(DataObject, NwbReadableInterface, NwbWritableInterface): """A DataObject which contains properties and methods to load, process, and represent running acquisition data. @@ -59,29 +52,27 @@ def __init__( "Running acquisition timestamps have montior delay " f"{stimulus_timestamps.monitor_delay}; there " "should be no monitor delay applied to the timestamps " - "associated with running acquisition") + "associated with running acquisition" + ) self._stimulus_file = stimulus_file self._stimulus_timestamps = stimulus_timestamps @classmethod def from_stimulus_file( - cls, - behavior_stimulus_file: BehaviorStimulusFile, - sync_file: Optional[SyncFile] = None) -> "RunningAcquisition": + cls, behavior_stimulus_file: BehaviorStimulusFile, sync_file: Optional[SyncFile] = None + ) -> "RunningAcquisition": """ sync_file is used for generating timestamps. If None, timestamps will be generated from the stimulus file. """ if sync_file is not None: - stimulus_timestamps = StimulusTimestamps.from_sync_file( - sync_file=sync_file, - monitor_delay=0.0) + stimulus_timestamps = StimulusTimestamps.from_sync_file(sync_file=sync_file, monitor_delay=0.0) else: stimulus_timestamps = StimulusTimestamps.from_stimulus_file( - stimulus_file=behavior_stimulus_file, - monitor_delay=0.0) + stimulus_file=behavior_stimulus_file, monitor_delay=0.0 + ) running_acq_df = get_running_df( data=behavior_stimulus_file.data, @@ -97,11 +88,12 @@ def from_stimulus_file( @classmethod def from_multiple_stimulus_files( - cls, - behavior_stimulus_file: BehaviorStimulusFile, - mapping_stimulus_file: MappingStimulusFile, - replay_stimulus_file: ReplayStimulusFile, - sync_file: SyncFile) -> "RunningAcquisition": + cls, + behavior_stimulus_file: BehaviorStimulusFile, + mapping_stimulus_file: MappingStimulusFile, + replay_stimulus_file: ReplayStimulusFile, + sync_file: SyncFile, + ) -> "RunningAcquisition": """ sync_file is used for generating timestamps. @@ -112,38 +104,28 @@ def from_multiple_stimulus_files( """ df = _get_multi_stim_running_df( - sync_path=sync_file.filepath, - behavior_stimulus_file=behavior_stimulus_file, - mapping_stimulus_file=mapping_stimulus_file, - replay_stimulus_file=replay_stimulus_file, - use_lowpass_filter=False, - zscore_threshold=10.0)['running_acquisition'] + sync_path=sync_file.filepath, + behavior_stimulus_file=behavior_stimulus_file, + mapping_stimulus_file=mapping_stimulus_file, + replay_stimulus_file=replay_stimulus_file, + use_lowpass_filter=False, + zscore_threshold=10.0, + )["running_acquisition"] - return cls( - running_acquisition=df, - stimulus_file=None, - stimulus_timestamps=None) + return cls(running_acquisition=df, stimulus_file=None, stimulus_timestamps=None) @classmethod - def from_nwb( - cls, - nwbfile: NWBFile - ) -> "RunningAcquisition": - running_module = nwbfile.processing['running'] - dx_interface = running_module.get_data_interface('dx') + def from_nwb(cls, nwbfile: NWBFile) -> "RunningAcquisition": + running_module = nwbfile.processing["running"] + dx_interface = running_module.get_data_interface("dx") dx = dx_interface.data - v_in = nwbfile.get_acquisition('v_in').data - v_sig = nwbfile.get_acquisition('v_sig').data + v_in = nwbfile.get_acquisition("v_in").data + v_sig = nwbfile.get_acquisition("v_sig").data timestamps = dx_interface.timestamps[:] running_acq_df = pd.DataFrame( - { - 'dx': dx, - 'v_in': v_in, - 'v_sig': v_sig - }, - index=pd.Index(timestamps, name='timestamps') + {"dx": dx, "v_in": v_in, "v_sig": v_sig}, index=pd.Index(timestamps, name="timestamps") ) return cls(running_acquisition=running_acq_df) @@ -151,40 +133,38 @@ def to_nwb(self, nwbfile: NWBFile) -> NWBFile: running_acquisition_df: pd.DataFrame = self.value running_dx_series = TimeSeries( - name='dx', - data=running_acquisition_df['dx'].values, + name="dx", + data=running_acquisition_df["dx"].values, timestamps=running_acquisition_df.index.values, - unit='cm', - description=( - 'Running wheel angular change, computed during data collection' - ) + unit="cm", + description=("Running wheel angular change, computed during data collection"), ) v_sig = TimeSeries( - name='v_sig', - data=running_acquisition_df['v_sig'].values, + name="v_sig", + data=running_acquisition_df["v_sig"].values, timestamps=running_acquisition_df.index.values, - unit='V', - description='Voltage signal from the running wheel encoder' + unit="V", + description="Voltage signal from the running wheel encoder", ) v_in = TimeSeries( - name='v_in', - data=running_acquisition_df['v_in'].values, + name="v_in", + data=running_acquisition_df["v_in"].values, timestamps=running_acquisition_df.index.values, - unit='V', + unit="V", description=( - 'The theoretical maximum voltage that the running wheel ' + "The theoretical maximum voltage that the running wheel " 'encoder will reach prior to "wrapping". This should ' - 'theoretically be 5V (after crossing 5V goes to 0V, or ' - 'vice versa). In practice the encoder does not always ' - 'reach this value before wrapping, which can cause ' - 'transient spikes in speed at the voltage "wraps".') + "theoretically be 5V (after crossing 5V goes to 0V, or " + "vice versa). In practice the encoder does not always " + "reach this value before wrapping, which can cause " + 'transient spikes in speed at the voltage "wraps".' + ), ) - if 'running' in nwbfile.processing: - running_mod = nwbfile.processing['running'] + if "running" in nwbfile.processing: + running_mod = nwbfile.processing["running"] else: - running_mod = ProcessingModule('running', - 'Running speed processing module') + running_mod = ProcessingModule("running", "Running speed processing module") nwbfile.add_processing_module(running_mod) running_mod.add_data_interface(running_dx_series) diff --git a/allensdk/brain_observatory/behavior/data_objects/running_speed/running_processing.py b/allensdk/brain_observatory/behavior/data_objects/running_speed/running_processing.py index 7222cb7d10..1e2b5b4b24 100644 --- a/allensdk/brain_observatory/behavior/data_objects/running_speed/running_processing.py +++ b/allensdk/brain_observatory/behavior/data_objects/running_speed/running_processing.py @@ -12,8 +12,7 @@ def calc_deriv(x, time): return dx / dt -def _angular_change(summed_voltage: np.ndarray, - vmax: Union[np.ndarray, float]) -> np.ndarray: +def _angular_change(summed_voltage: np.ndarray, vmax: Union[np.ndarray, float]) -> np.ndarray: """ Compute the change in degrees in radians at each point from the summed voltage encoder data. @@ -37,10 +36,7 @@ def _angular_change(summed_voltage: np.ndarray, return delta_theta -def _shift( - arr: Iterable, - periods: int = 1, - fill_value: float = np.nan) -> np.ndarray: +def _shift(arr: Iterable, periods: int = 1, fill_value: float = np.nan) -> np.ndarray: """ Shift index of an iterable (array-like) by desired number of periods with an optional fill value (default = NaN). @@ -89,14 +85,13 @@ def deg_to_dist(angular_speed: np.ndarray) -> np.ndarray: wheel_diameter = 6.5 * 2.54 # 6.5" wheel diameter, 2.54 = cm/in running_radius = 0.5 * ( # assume the animal runs at 2/3 the distance from the wheel center - 2.0 * wheel_diameter / 3.0) + 2.0 * wheel_diameter / 3.0 + ) running_speed_cm_per_sec = angular_speed * running_radius return running_speed_cm_per_sec -def _identify_wraps(vsig: Iterable, *, - min_threshold: float = 1.5, - max_threshold: float = 3.5): +def _identify_wraps(vsig: Iterable, *, min_threshold: float = 1.5, max_threshold: float = 3.5): """ Identify "wraps" in the voltage signal. In practice, this is when the encoder voltage signal crosses 5V and wraps to 0V, or @@ -125,13 +120,9 @@ def _identify_wraps(vsig: Iterable, *, if not isinstance(vsig, np.ndarray): vsig = np.array(vsig) # Suppress warnings for when comparing to nan values - with np.errstate(invalid='ignore'): - pos_wraps = np.asarray( - np.logical_and(vsig < min_threshold, shifted_vsig > max_threshold) - ).nonzero()[0] - neg_wraps = np.asarray( - np.logical_and(vsig > max_threshold, shifted_vsig < min_threshold) - ).nonzero()[0] + with np.errstate(invalid="ignore"): + pos_wraps = np.asarray(np.logical_and(vsig < min_threshold, shifted_vsig > max_threshold)).nonzero()[0] + neg_wraps = np.asarray(np.logical_and(vsig > max_threshold, shifted_vsig < min_threshold)).nonzero()[0] return pos_wraps, neg_wraps @@ -160,8 +151,9 @@ def _local_boundaries(time, index, span: float = 0.25) -> tuple: ``` """ if np.diff(time[~np.isnan(time)]).min() < 0: - raise ValueError("Data do not monotonically increase. This probably " - "means there is an error in your time series.") + raise ValueError( + "Data do not monotonically increase. This probably means there is an error in your time series." + ) t_val = time[index] max_val = t_val + abs(span) min_val = t_val - abs(span) @@ -169,13 +161,15 @@ def _local_boundaries(time, index, span: float = 0.25) -> tuple: max_ix = eligible_indices.max() min_ix = eligible_indices.min() if (min_ix == index) or (max_ix == index): - warnings.warn("Unable to find two data points around index " - f"for span={span} that do not include the index. " - "This could mean that your time span is too small for " - "the time data sampling rate, the data are not " - "monotonically increasing, or that you are trying " - "to find a neighborhood at the beginning/end of the " - "data stream.") + warnings.warn( + "Unable to find two data points around index " + f"for span={span} that do not include the index. " + "This could mean that your time span is too small for " + "the time data sampling rate, the data are not " + "monotonically increasing, or that you are trying " + "to find a neighborhood at the beginning/end of the " + "data stream." + ) return min_ix, max_ix @@ -191,21 +185,22 @@ def _clip_speed_wraps(speed, time, wrap_indices, t_span: float = 0.25): corrected_speed = speed.copy() for wrap in wrap_indices: start_ix, end_ix = _local_boundaries(time, wrap, t_span) - local_slice = np.concatenate( # Remove the wrap point - (speed[start_ix:wrap], speed[wrap+1:end_ix+1])) - corrected_speed[wrap] = np.clip( - speed[wrap], np.nanmin(local_slice), np.nanmax(local_slice)) + local_slice = np.concatenate( # Remove the wrap point + (speed[start_ix:wrap], speed[wrap + 1 : end_ix + 1]) + ) + corrected_speed[wrap] = np.clip(speed[wrap], np.nanmin(local_slice), np.nanmax(local_slice)) return corrected_speed def _unwrap_voltage_signal( - vsig: Iterable, - pos_wrap_ix: Iterable, - neg_wrap_ix: Iterable, - *, - vmax: Optional[float] = None, - max_threshold: float = 5.1, - max_diff: float = 1.0) -> np.ndarray: + vsig: Iterable, + pos_wrap_ix: Iterable, + neg_wrap_ix: Iterable, + *, + vmax: Optional[float] = None, + max_threshold: float = 5.1, + max_diff: float = 1.0, +) -> np.ndarray: """ Calculate the change in voltage at each timestamp. 'Unwraps' the @@ -250,12 +245,10 @@ def _unwrap_voltage_signal( vsig_last = _shift(vsig) if len(pos_wrap_ix): # positive wraps: subtract from the previous value and add vmax - unwrapped_diff[pos_wrap_ix] = ( - (vsig[pos_wrap_ix] + vmax) - vsig_last[pos_wrap_ix]) + unwrapped_diff[pos_wrap_ix] = (vsig[pos_wrap_ix] + vmax) - vsig_last[pos_wrap_ix] # negative: subtract vmax and the previous value if len(neg_wrap_ix): - unwrapped_diff[neg_wrap_ix] = ( - vsig[neg_wrap_ix] - (vsig_last[neg_wrap_ix] + vmax)) + unwrapped_diff[neg_wrap_ix] = vsig[neg_wrap_ix] - (vsig_last[neg_wrap_ix] + vmax) # Other indices, just compute straight diff from previous value wrap_ix = np.concatenate((pos_wrap_ix, neg_wrap_ix)) other_ix = np.array(list(set(range(len(vsig_last))).difference(wrap_ix))) @@ -263,19 +256,17 @@ def _unwrap_voltage_signal( # Correct for wrap artifacts based on allowed `max_diff` value # (fill with nan) # Suppress warnings when comparing with nan values to reduce noise - with np.errstate(invalid='ignore'): - unwrapped_diff = np.where( - np.abs(unwrapped_diff) <= max_diff, unwrapped_diff, np.nan) + with np.errstate(invalid="ignore"): + unwrapped_diff = np.where(np.abs(unwrapped_diff) <= max_diff, unwrapped_diff, np.nan) # Get nan indices to propogate to the cumulative sum (otherwise # treated as 0) unwrapped_nans = np.array(np.isnan(unwrapped_diff)).nonzero() - summed_diff = np.nancumsum(unwrapped_diff) + vsig[0] # Add the baseline + summed_diff = np.nancumsum(unwrapped_diff) + vsig[0] # Add the baseline summed_diff[unwrapped_nans] = np.nan return summed_diff -def _zscore_threshold_1d(data: np.ndarray, - threshold: float = 5.0) -> np.ndarray: +def _zscore_threshold_1d(data: np.ndarray, threshold: float = 5.0) -> np.ndarray: """ Replace values in 1d array `data` that exceed `threshold` number of SDs from the mean with NaN. @@ -294,14 +285,12 @@ def _zscore_threshold_1d(data: np.ndarray, corrected_data = data.copy().astype("float") scores = zscore(data, nan_policy="omit") # Suppress warnings when comparing to nan values to reduce noise - with np.errstate(invalid='ignore'): + with np.errstate(invalid="ignore"): corrected_data[np.abs(scores) > threshold] = np.nan return corrected_data -def get_running_df( - data, time: np.ndarray, lowpass: bool = True, zscore_threshold=10.0 -): +def get_running_df(data, time: np.ndarray, lowpass: bool = True, zscore_threshold=10.0): """ Given the data from the behavior 'pkl' file object and a 1d array of timestamps, compute the running speed. Returns a @@ -355,17 +344,14 @@ def get_running_df( v_in = data["items"]["behavior"]["encoders"][0]["vin"] if len(v_in) > len(time) + 1: - error_string = ("length of v_in ({}) cannot be longer than length of " - "time ({}) + 1, they are off by {}").format( - len(v_in), - len(time), - abs(len(v_in) - len(time)) + error_string = ("length of v_in ({}) cannot be longer than length of time ({}) + 1, they are off by {}").format( + len(v_in), len(time), abs(len(v_in) - len(time)) ) raise ValueError(error_string) if len(v_in) == len(time) + 1: warnings.warn( - "Time array is 1 value shorter than encoder array. Last encoder " - "value removed\n", UserWarning, stacklevel=1) + "Time array is 1 value shorter than encoder array. Last encoder value removed\n", UserWarning, stacklevel=1 + ) v_in = v_in[:-1] v_sig = v_sig[:-1] @@ -375,11 +361,9 @@ def get_running_df( dx_raw = data["items"]["behavior"]["encoders"][0]["dx"] # Identify "wraps" in the voltage signal that need to be unwrapped # This is where the encoder switches from 0V to 5V or vice versa - pos_wraps, neg_wraps = _identify_wraps( - v_sig, min_threshold=1.5, max_threshold=3.5) + pos_wraps, neg_wraps = _identify_wraps(v_sig, min_threshold=1.5, max_threshold=3.5) # Unwrap the voltage signal and apply correction for transient spikes - unwrapped_vsig = _unwrap_voltage_signal( - v_sig, pos_wraps, neg_wraps, max_threshold=5.1, max_diff=1.0) + unwrapped_vsig = _unwrap_voltage_signal(v_sig, pos_wraps, neg_wraps, max_threshold=5.1, max_diff=1.0) angular_change_point = _angular_change(unwrapped_vsig, v_in) angular_change = np.nancumsum(angular_change_point) # Add the nans back in (get turned to 0 in nancumsum) @@ -388,20 +372,21 @@ def get_running_df( linear_speed = deg_to_dist(angular_speed) # Artifact correction to speed data wrap_corrected_linear_speed = _clip_speed_wraps( - linear_speed, time, np.concatenate([pos_wraps, neg_wraps]), - t_span=0.25) - outlier_corrected_linear_speed = _zscore_threshold_1d( - wrap_corrected_linear_speed, threshold=zscore_threshold) + linear_speed, time, np.concatenate([pos_wraps, neg_wraps]), t_span=0.25 + ) + outlier_corrected_linear_speed = _zscore_threshold_1d(wrap_corrected_linear_speed, threshold=zscore_threshold) # Final filtering (optional) for smoothing out the speed data if lowpass: b, a = signal.butter(3, Wn=4, fs=60, btype="lowpass") - outlier_corrected_linear_speed = signal.filtfilt( - b, a, np.nan_to_num(outlier_corrected_linear_speed)) - - return pd.DataFrame({ - 'speed': outlier_corrected_linear_speed[:len(time)], - 'dx': dx_raw[:len(time)], - 'v_sig': v_sig[:len(time)], - 'v_in': v_in[:len(time)], - }, index=pd.Index(time, name='timestamps')) + outlier_corrected_linear_speed = signal.filtfilt(b, a, np.nan_to_num(outlier_corrected_linear_speed)) + + return pd.DataFrame( + { + "speed": outlier_corrected_linear_speed[: len(time)], + "dx": dx_raw[: len(time)], + "v_sig": v_sig[: len(time)], + "v_in": v_in[: len(time)], + }, + index=pd.Index(time, name="timestamps"), + ) diff --git a/allensdk/brain_observatory/behavior/data_objects/running_speed/running_speed.py b/allensdk/brain_observatory/behavior/data_objects/running_speed/running_speed.py index c2640ef2e7..3a7fa59d06 100644 --- a/allensdk/brain_observatory/behavior/data_objects/running_speed/running_speed.py +++ b/allensdk/brain_observatory/behavior/data_objects/running_speed/running_speed.py @@ -12,24 +12,18 @@ from allensdk.core import DataObject from allensdk.brain_observatory.behavior.data_files import SyncFile from allensdk.brain_observatory.behavior.data_objects import StimulusTimestamps -from allensdk.brain_observatory.behavior.data_files import ( - BehaviorStimulusFile, - ReplayStimulusFile, - MappingStimulusFile -) +from allensdk.brain_observatory.behavior.data_files import BehaviorStimulusFile, ReplayStimulusFile, MappingStimulusFile from allensdk.brain_observatory.behavior.data_objects.running_speed.running_processing import ( # noqa: E501 - get_running_df + get_running_df, ) -from allensdk.brain_observatory.behavior.data_objects.\ - running_speed.multi_stim_running_processing import ( - _get_multi_stim_running_df) +from allensdk.brain_observatory.behavior.data_objects.running_speed.multi_stim_running_processing import ( + _get_multi_stim_running_df, +) -class RunningSpeed(DataObject, - NwbReadableInterface, - NwbWritableInterface): +class RunningSpeed(DataObject, NwbReadableInterface, NwbWritableInterface): """A DataObject which contains properties and methods to load, process, and represent running speed data. @@ -46,10 +40,10 @@ def __init__( stimulus_file: Optional[BehaviorStimulusFile] = None, sync_file: Optional[SyncFile] = None, stimulus_timestamps: Optional[StimulusTimestamps] = None, - filtered: bool = True + filtered: bool = True, ): running_speed = self._fix_polarity(running_speed) - super().__init__(name='running_speed', value=running_speed) + super().__init__(name="running_speed", value=running_speed) if stimulus_timestamps is not None: if not np.isclose(stimulus_timestamps.monitor_delay, 0.0): @@ -57,7 +51,8 @@ def __init__( "Running speed timestamps have monitor delay " f"{stimulus_timestamps.monitor_delay}; there " "should be no monitor delay applied to the timestamps " - "associated with running speed") + "associated with running speed" + ) self._stimulus_file = stimulus_file self._sync_file = sync_file @@ -88,9 +83,9 @@ def _fix_polarity(running_speed: pd.DataFrame) -> pd.DataFrame: DataFrame with potentially flipped value of running speed if the average speed is < -1 cm/s """ - mean_speed = running_speed['speed'].mean() + mean_speed = running_speed["speed"].mean() if mean_speed < -1: - running_speed['speed'] = -1 * running_speed['speed'] + running_speed["speed"] = -1 * running_speed["speed"] return running_speed @staticmethod @@ -98,66 +93,62 @@ def _get_running_speed_df( stimulus_file: BehaviorStimulusFile, stimulus_timestamps: StimulusTimestamps, filtered: bool = True, - zscore_threshold: float = 1.0 + zscore_threshold: float = 1.0, ) -> pd.DataFrame: running_data_df = get_running_df( - data=stimulus_file.data, time=stimulus_timestamps.value, - lowpass=filtered, zscore_threshold=zscore_threshold + data=stimulus_file.data, time=stimulus_timestamps.value, lowpass=filtered, zscore_threshold=zscore_threshold ) if running_data_df.index.name != "timestamps": raise DataFrameIndexError( f"Expected running_data_df index to be named 'timestamps' " f"But instead got: '{running_data_df.index.name}'" ) - running_speed = pd.DataFrame({ - "timestamps": running_data_df.index.values, - "speed": running_data_df.speed.values - }) + running_speed = pd.DataFrame( + {"timestamps": running_data_df.index.values, "speed": running_data_df.speed.values} + ) return running_speed @classmethod def from_stimulus_file( - cls, - behavior_stimulus_file: BehaviorStimulusFile, - sync_file: Optional[SyncFile] = None, - filtered: bool = True, - zscore_threshold: float = 10.0) -> "RunningSpeed": + cls, + behavior_stimulus_file: BehaviorStimulusFile, + sync_file: Optional[SyncFile] = None, + filtered: bool = True, + zscore_threshold: float = 10.0, + ) -> "RunningSpeed": """ sync_file is used for generating timestamps. If None, timestamps will be generated from the stimulus file. """ if sync_file is not None: - stimulus_timestamps = StimulusTimestamps.from_sync_file( - sync_file=sync_file, - monitor_delay=0.0) + stimulus_timestamps = StimulusTimestamps.from_sync_file(sync_file=sync_file, monitor_delay=0.0) else: stimulus_timestamps = StimulusTimestamps.from_stimulus_file( - stimulus_file=behavior_stimulus_file, - monitor_delay=0.0) + stimulus_file=behavior_stimulus_file, monitor_delay=0.0 + ) running_speed = cls._get_running_speed_df( - behavior_stimulus_file, - stimulus_timestamps, - filtered, - zscore_threshold + behavior_stimulus_file, stimulus_timestamps, filtered, zscore_threshold ) return cls( running_speed=running_speed, stimulus_file=behavior_stimulus_file, stimulus_timestamps=stimulus_timestamps, - filtered=filtered) + filtered=filtered, + ) @classmethod def from_multiple_stimulus_files( - cls, - behavior_stimulus_file: BehaviorStimulusFile, - mapping_stimulus_file: MappingStimulusFile, - replay_stimulus_file: ReplayStimulusFile, - sync_file: SyncFile, - filtered: bool = True, - zscore_threshold: float = 10.0) -> "RunningSpeed": + cls, + behavior_stimulus_file: BehaviorStimulusFile, + mapping_stimulus_file: MappingStimulusFile, + replay_stimulus_file: ReplayStimulusFile, + sync_file: SyncFile, + filtered: bool = True, + zscore_threshold: float = 10.0, + ) -> "RunningSpeed": """ sync_file is used for generating timestamps. @@ -168,62 +159,44 @@ def from_multiple_stimulus_files( """ df = _get_multi_stim_running_df( - sync_path=sync_file.filepath, - behavior_stimulus_file=behavior_stimulus_file, - mapping_stimulus_file=mapping_stimulus_file, - replay_stimulus_file=replay_stimulus_file, - use_lowpass_filter=filtered, - zscore_threshold=zscore_threshold)['running_speed'] + sync_path=sync_file.filepath, + behavior_stimulus_file=behavior_stimulus_file, + mapping_stimulus_file=mapping_stimulus_file, + replay_stimulus_file=replay_stimulus_file, + use_lowpass_filter=filtered, + zscore_threshold=zscore_threshold, + )["running_speed"] - return cls( - running_speed=df, - filtered=filtered, - sync_file=None, - stimulus_file=None, - stimulus_timestamps=None) + return cls(running_speed=df, filtered=filtered, sync_file=None, stimulus_file=None, stimulus_timestamps=None) @classmethod - def from_nwb( - cls, - nwbfile: NWBFile, - filtered=True - ) -> "RunningSpeed": - running_module = nwbfile.processing['running'] - interface_name = 'speed' if filtered else 'speed_unfiltered' + def from_nwb(cls, nwbfile: NWBFile, filtered=True) -> "RunningSpeed": + running_module = nwbfile.processing["running"] + interface_name = "speed" if filtered else "speed_unfiltered" running_interface = running_module.get_data_interface(interface_name) timestamps = running_interface.timestamps[:] values = running_interface.data[:] - running_speed = pd.DataFrame( - { - "timestamps": timestamps, - "speed": values - } - ) + running_speed = pd.DataFrame({"timestamps": timestamps, "speed": values}) return cls(running_speed=running_speed, filtered=filtered) def to_nwb(self, nwbfile: NWBFile) -> NWBFile: running_speed: pd.DataFrame = self.value - data = running_speed['speed'].values - timestamps = running_speed['timestamps'].values + data = running_speed["speed"].values + timestamps = running_speed["timestamps"].values if self._filtered: data_interface_name = "speed" else: data_interface_name = "speed_unfiltered" - running_speed_series = TimeSeries( - name=data_interface_name, - data=data, - timestamps=timestamps, - unit='cm/s') + running_speed_series = TimeSeries(name=data_interface_name, data=data, timestamps=timestamps, unit="cm/s") - if 'running' in nwbfile.processing: - running_mod = nwbfile.processing['running'] + if "running" in nwbfile.processing: + running_mod = nwbfile.processing["running"] else: - running_mod = ProcessingModule('running', - 'Running speed processing module') + running_mod = ProcessingModule("running", "Running speed processing module") nwbfile.add_processing_module(running_mod) running_mod.add_data_interface(running_speed_series) diff --git a/allensdk/brain_observatory/behavior/data_objects/stimuli/fingerprint_stimulus.py b/allensdk/brain_observatory/behavior/data_objects/stimuli/fingerprint_stimulus.py index 13414d461b..f9ed3176d1 100644 --- a/allensdk/brain_observatory/behavior/data_objects/stimuli/fingerprint_stimulus.py +++ b/allensdk/brain_observatory/behavior/data_objects/stimuli/fingerprint_stimulus.py @@ -7,7 +7,8 @@ class FingerprintStimulus: """The fingerprint stimulus is a movie used to trigger many neurons - and is used to improve cell matching.""" + and is used to improve cell matching.""" + def __init__(self, table: pd.DataFrame): self._table = table @@ -18,10 +19,10 @@ def table(self) -> pd.DataFrame: @classmethod def from_stimulus_file( - cls, - stimulus_presentations: pd.DataFrame, - stimulus_file: BehaviorStimulusFile, - stimulus_timestamps: StimulusTimestamps + cls, + stimulus_presentations: pd.DataFrame, + stimulus_file: BehaviorStimulusFile, + stimulus_timestamps: StimulusTimestamps, ) -> "FingerprintStimulus": """ Instantiates `FingerprintStimulus` from stimulus file @@ -40,34 +41,29 @@ def from_stimulus_file( `FingerprintStimulus` Instantiated FingerprintStimulus """ - fingerprint_stim = ( - stimulus_file.data['items']['behavior']['items']['fingerprint'] - ['static_stimulus']) + fingerprint_stim = stimulus_file.data["items"]["behavior"]["items"]["fingerprint"]["static_stimulus"] - n_repeats = fingerprint_stim['runs'] + n_repeats = fingerprint_stim["runs"] # spontaneous + fingerprint indices relative to start of session stimulus_session_frame_indices = np.array( - stimulus_file.data['items']['behavior']['items'] - ['fingerprint']['frame_indices']) + stimulus_file.data["items"]["behavior"]["items"]["fingerprint"]["frame_indices"] + ) - movie_length = int(len(fingerprint_stim['sweep_frames']) / n_repeats) + movie_length = int(len(fingerprint_stim["sweep_frames"]) / n_repeats) # Start index within the spontaneous + fingerprint block - movie_start_index = (fingerprint_stim['frame_list'] == -1).sum() + movie_start_index = (fingerprint_stim["frame_list"] == -1).sum() res = [] for repeat in range(n_repeats): for frame in range(movie_length): # 0-indexed frame indices relative to start of fingerprint # movie - stimulus_frame_indices = \ - np.array(fingerprint_stim['sweep_frames'] - [frame + (repeat * movie_length)]) - start_frame, end_frame = stimulus_session_frame_indices[ - stimulus_frame_indices + movie_start_index] - start_time, stop_time = \ - stimulus_timestamps.value[[ + stimulus_frame_indices = np.array(fingerprint_stim["sweep_frames"][frame + (repeat * movie_length)]) + start_frame, end_frame = stimulus_session_frame_indices[stimulus_frame_indices + movie_start_index] + start_time, stop_time = stimulus_timestamps.value[ + [ start_frame, # Sometimes stimulus timestamps gets truncated too # early. There should be 2 extra frames after last @@ -77,25 +73,27 @@ def from_stimulus_file( # index out of bounds. This results in the last # frame's duration being too short TODO this is # probably a bug somewhere in timestamp creation - min(end_frame + 1, - len(stimulus_timestamps.value) - 1)]] - res.append({ - 'movie_frame_index': frame, - 'start_time': start_time, - 'stop_time': stop_time, - 'start_frame': start_frame, - 'end_frame': end_frame, - 'movie_repeat': repeat, - 'duration': stop_time - start_time - }) + min(end_frame + 1, len(stimulus_timestamps.value) - 1), + ] + ] + res.append( + { + "movie_frame_index": frame, + "start_time": start_time, + "stop_time": stop_time, + "start_frame": start_frame, + "end_frame": end_frame, + "movie_repeat": repeat, + "duration": stop_time - start_time, + } + ) table = pd.DataFrame(res) - table['stimulus_block'] = \ - stimulus_presentations['stimulus_block'].max() \ - + 2 # + 2 since there is a gap before this stimulus - table['stimulus_name'] = 'natural_movie_one' + table["stimulus_block"] = ( + stimulus_presentations["stimulus_block"].max() + 2 + ) # + 2 since there is a gap before this stimulus + table["stimulus_name"] = "natural_movie_one" - table = table.astype( - {c: 'int64' for c in table.select_dtypes(include='int')}) + table = table.astype({c: "int64" for c in table.select_dtypes(include="int")}) return FingerprintStimulus(table=table) diff --git a/allensdk/brain_observatory/behavior/data_objects/stimuli/presentations.py b/allensdk/brain_observatory/behavior/data_objects/stimuli/presentations.py index f511b03d07..2cbcdef734 100644 --- a/allensdk/brain_observatory/behavior/data_objects/stimuli/presentations.py +++ b/allensdk/brain_observatory/behavior/data_objects/stimuli/presentations.py @@ -107,15 +107,11 @@ def __init__( if "trials_id" not in presentations.columns: # Add trials_id to presentations df to allow for joining of the # two tables. - presentations["trials_id"] = compute_trials_id_for_stimulus( - presentations, trials.data - ) + presentations["trials_id"] = compute_trials_id_for_stimulus(presentations, trials.data) if "is_sham_change" not in presentations.columns: # Mark changes in active and replay stimulus that are # #sham-changes - presentations = compute_is_sham_change( - presentations, trials.data - ) + presentations = compute_is_sham_change(presentations, trials.data) if sort_columns: presentations = enforce_df_column_order( presentations, @@ -141,9 +137,7 @@ def __init__( ) super().__init__(name="presentations", value=presentations) - def to_nwb( - self, nwbfile: NWBFile, stimulus_name_column="stimulus_name" - ) -> NWBFile: + def to_nwb(self, nwbfile: NWBFile, stimulus_name_column="stimulus_name") -> NWBFile: """Adds a stimulus table (defining stimulus characteristics for each time point in a session) to an nwbfile as TimeIntervals. @@ -159,9 +153,7 @@ def to_nwb( stimulus_names = stimulus_table[stimulus_name_column].unique() for stim_name in sorted(stimulus_names): - specific_stimulus_table = stimulus_table[ - stimulus_table[stimulus_name_column] == stim_name - ] + specific_stimulus_table = stimulus_table[stimulus_table[stimulus_name_column] == stim_name] # Drop columns where all values in column are NaN cleaned_table = specific_stimulus_table.dropna(axis=1, how="all") # For columns with mixed strings and NaNs, fill NaNs with 'N/A' @@ -193,9 +185,7 @@ def to_nwb( for row in cleaned_table.itertuples(index=False): row = row._asdict() - presentation_interval.add_interval( - **row, tags="stimulus_time_interval", timeseries=ts - ) + presentation_interval.add_interval(**row, tags="stimulus_time_interval", timeseries=ts) nwbfile.add_time_intervals(presentation_interval) @@ -241,9 +231,7 @@ def from_nwb( presentation_dfs.append(df) table = pd.concat(presentation_dfs, sort=False) - table = table.astype( - {c: "int64" for c in table.select_dtypes(include="int")} - ) + table = table.astype({c: "int64" for c in table.select_dtypes(include="int")}) # coercing bool columns with nans to boolean. "Boolean" dtype # allows null values, while "bool" does not (see comment in to_nwb) @@ -263,16 +251,12 @@ def from_nwb( if add_is_change: table["is_change"] = is_change_event(stimulus_presentations=table) - table["flashes_since_change"] = get_flashes_since_change( - stimulus_presentations=table - ) + table["flashes_since_change"] = get_flashes_since_change(stimulus_presentations=table) trials = None if add_trials_dependent_values and nwbfile.trials is not None: trials = Trials.from_nwb(nwbfile) - return Presentations( - presentations=table, column_list=column_list, trials=trials - ) + return Presentations(presentations=table, column_list=column_list, trials=trials) @classmethod def from_stimulus_file( @@ -318,13 +302,9 @@ def from_stimulus_file( and whose columns are presentation characteristics. """ data = stimulus_file.data - raw_stim_pres_df = get_stimulus_presentations( - data, stimulus_timestamps.value - ) + raw_stim_pres_df = get_stimulus_presentations(data, stimulus_timestamps.value) raw_stim_pres_df = raw_stim_pres_df.drop(columns=["index"]) - raw_stim_pres_df = cls._check_for_errant_omitted_stimulus( - input_df=raw_stim_pres_df - ) + raw_stim_pres_df = cls._check_for_errant_omitted_stimulus(input_df=raw_stim_pres_df) # Fill in nulls for image_name # This makes two assumptions: @@ -333,13 +313,9 @@ def from_stimulus_file( # values for `image_name` are null. if pd.isnull(raw_stim_pres_df["image_name"]).all(): if ~pd.isnull(raw_stim_pres_df["orientation"]).all(): - raw_stim_pres_df["image_name"] = raw_stim_pres_df[ - "orientation" - ].apply(lambda x: f"gratings_{x}") + raw_stim_pres_df["image_name"] = raw_stim_pres_df["orientation"].apply(lambda x: f"gratings_{x}") else: - raise ValueError( - "All values for 'orientation' and " "'image_name are null." - ) + raise ValueError("All values for 'orientation' and 'image_name are null.") stimulus_metadata_df = get_stimulus_metadata(data) @@ -363,9 +339,7 @@ def from_stimulus_file( .sort_index() .set_index("timestamps", drop=True) ) - stimulus_index_df["image_index"] = stimulus_index_df[ - "image_index" - ].astype("int") + stimulus_index_df["image_index"] = stimulus_index_df["image_index"].astype("int") stim_pres_df = raw_stim_pres_df.merge( stimulus_index_df, left_on="start_time", @@ -379,24 +353,14 @@ def from_stimulus_file( f" {len(stim_pres_df)}." ) - stim_pres_df["is_change"] = is_change_event( - stimulus_presentations=stim_pres_df - ) - stim_pres_df["flashes_since_change"] = get_flashes_since_change( - stimulus_presentations=stim_pres_df - ) + stim_pres_df["is_change"] = is_change_event(stimulus_presentations=stim_pres_df) + stim_pres_df["flashes_since_change"] = get_flashes_since_change(stimulus_presentations=stim_pres_df) # Sort columns then drop columns which contain only all NaN values - stim_pres_df = stim_pres_df[sorted(stim_pres_df)].dropna( - axis=1, how="all" - ) + stim_pres_df = stim_pres_df[sorted(stim_pres_df)].dropna(axis=1, how="all") if limit_to_images is not None: - stim_pres_df = stim_pres_df[ - stim_pres_df["image_name"].isin(limit_to_images) - ] - stim_pres_df.index = pd.Index( - range(stim_pres_df.shape[0]), name=stim_pres_df.index.name - ) + stim_pres_df = stim_pres_df[stim_pres_df["image_name"].isin(limit_to_images)] + stim_pres_df.index = pd.Index(range(stim_pres_df.shape[0]), name=stim_pres_df.index.name) stim_pres_df["stimulus_block"] = 0 stim_pres_df["stimulus_name"] = stimulus_file.stimulus_name @@ -408,9 +372,7 @@ def from_stimulus_file( behavior_session_id=behavior_session_id, ) - has_fingerprint_stimulus = ( - "fingerprint" in stimulus_file.data["items"]["behavior"]["items"] - ) + has_fingerprint_stimulus = "fingerprint" in stimulus_file.data["items"]["behavior"]["items"] if has_fingerprint_stimulus: stim_pres_df = cls._add_fingerprint_stimulus( stimulus_presentations=stim_pres_df, @@ -423,13 +385,9 @@ def from_stimulus_file( coerce_bool_to_boolean=True, ) if project_code is not None: - stim_pres_df = produce_stimulus_block_names( - stim_pres_df, stimulus_file.session_type, project_code.value - ) + stim_pres_df = produce_stimulus_block_names(stim_pres_df, stimulus_file.session_type, project_code.value) - return Presentations( - presentations=stim_pres_df, column_list=column_list, trials=trials - ) + return Presentations(presentations=stim_pres_df, column_list=column_list, trials=trials) @classmethod def from_path( @@ -460,9 +418,7 @@ def from_path( """ path = Path(path) df = pd.read_csv(path) - cls._add_is_image_novel( - stimulus_presentations=df, behavior_session_id=behavior_session_id - ) + cls._add_is_image_novel(stimulus_presentations=df, behavior_session_id=behavior_session_id) exclude_columns = exclude_columns if exclude_columns else [] df = df[[c for c in df if c not in exclude_columns]] df = cls._postprocess( @@ -497,26 +453,15 @@ def _get_is_image_novel( ------- Dict mapping image name to is_novel """ - mouse = Mouse.from_behavior_session_id( - behavior_session_id=behavior_session_id - ) - prior_images_shown = mouse.get_images_shown( - up_to_behavior_session_id=behavior_session_id - ) + mouse = Mouse.from_behavior_session_id(behavior_session_id=behavior_session_id) + prior_images_shown = mouse.get_images_shown(up_to_behavior_session_id=behavior_session_id) - image_names = set( - [x for x in image_names if x != "omitted" and type(x) is str] - ) - is_novel = { - f"{image_name}": image_name not in prior_images_shown - for image_name in image_names - } + image_names = set([x for x in image_names if x != "omitted" and type(x) is str]) + is_novel = {f"{image_name}": image_name not in prior_images_shown for image_name in image_names} return is_novel @classmethod - def _add_is_image_novel( - cls, stimulus_presentations: pd.DataFrame, behavior_session_id: int - ): + def _add_is_image_novel(cls, stimulus_presentations: pd.DataFrame, behavior_session_id: int): """Adds a column 'is_image_novel' to `stimulus_presentations` Parameters @@ -525,9 +470,7 @@ def _add_is_image_novel( behavior_session_id: LIMS id of behavior session """ - stimulus_presentations["is_image_novel"] = stimulus_presentations[ - "image_name" - ].map( + stimulus_presentations["is_image_novel"] = stimulus_presentations["image_name"].map( cls._get_is_image_novel( image_names=stimulus_presentations["image_name"].tolist(), behavior_session_id=behavior_session_id, @@ -558,17 +501,13 @@ def _postprocess( Amount of time a stimuli is omitted for in seconds""" df = presentations if fill_omitted_values: - cls._fill_missing_values_for_omitted_flashes( - df=df, omitted_time_duration=omitted_time_duration - ) + cls._fill_missing_values_for_omitted_flashes(df=df, omitted_time_duration=omitted_time_duration) if coerce_bool_to_boolean: df = df.astype( { c: "boolean" for c in df.select_dtypes("O") - if set(df[c][~df[c].isna()].unique()).issubset( - {True, False} - ) + if set(df[c][~df[c].isna()].unique()).issubset({True, False}) } ) df = cls._check_for_errant_omitted_stimulus(input_df=df) @@ -597,11 +536,9 @@ def _check_for_errant_omitted_stimulus( found, return input_df unmodified. """ - def safe_omitted_check(input_df: pd.Series, - stimulus_block: Optional[int]): + def safe_omitted_check(input_df: pd.Series, stimulus_block: Optional[int]): if stimulus_block is not None: - first_row = input_df[ - input_df['stimulus_block'] == stim_block].iloc[0] + first_row = input_df[input_df["stimulus_block"] == stim_block].iloc[0] else: first_row = input_df.iloc[0] @@ -612,18 +549,14 @@ def safe_omitted_check(input_df: pd.Series, if "omitted" in input_df.columns and len(input_df) > 0: if "stimulus_block" in input_df.columns: - for stim_block in input_df['stimulus_block'].unique(): - input_df = safe_omitted_check(input_df=input_df, - stimulus_block=stim_block) + for stim_block in input_df["stimulus_block"].unique(): + input_df = safe_omitted_check(input_df=input_df, stimulus_block=stim_block) else: - input_df = safe_omitted_check(input_df=input_df, - stimulus_block=None) + input_df = safe_omitted_check(input_df=input_df, stimulus_block=None) return input_df @staticmethod - def _fill_missing_values_for_omitted_flashes( - df: pd.DataFrame, omitted_time_duration: float = 0.25 - ) -> pd.DataFrame: + def _fill_missing_values_for_omitted_flashes(df: pd.DataFrame, omitted_time_duration: float = 0.25) -> pd.DataFrame: """ This function sets the stop time for a row that is an omitted stimulus. An omitted stimulus is a stimulus where a mouse is @@ -639,16 +572,12 @@ def _fill_missing_values_for_omitted_flashes( Amount of time a stimulus is omitted for in seconds """ omitted = df["omitted"].fillna(False) - df.loc[omitted, "stop_time"] = ( - df.loc[omitted, "start_time"] + omitted_time_duration - ) + df.loc[omitted, "stop_time"] = df.loc[omitted, "start_time"] + omitted_time_duration df.loc[omitted, "duration"] = omitted_time_duration return df @classmethod - def _get_spontaneous_stimulus( - cls, stimulus_presentations_table: pd.DataFrame - ) -> pd.DataFrame: + def _get_spontaneous_stimulus(cls, stimulus_presentations_table: pd.DataFrame) -> pd.DataFrame: """The spontaneous stimulus is a gray screen shown in between different stimulus blocks. This method finds any gaps in the stimulus presentations. These gaps are assumed to be spontaneous stimulus. @@ -680,17 +609,11 @@ def _get_spontaneous_stimulus( ): res.append( { - "duration": stimulus_presentations_table.iloc[0][ - "start_time" - ], + "duration": stimulus_presentations_table.iloc[0]["start_time"], "start_time": 0, - "stop_time": stimulus_presentations_table.iloc[0][ - "start_time" - ], + "stop_time": stimulus_presentations_table.iloc[0]["start_time"], "start_frame": 0, - "end_frame": stimulus_presentations_table.iloc[0][ - "start_frame" - ], + "end_frame": stimulus_presentations_table.iloc[0]["start_frame"], "stimulus_block": 0, "stimulus_name": "spontaneous", } @@ -700,27 +623,21 @@ def _get_spontaneous_stimulus( stimulus_presentations_table["stimulus_block"] += 1 spontaneous_stimulus_blocks = get_spontaneous_block_indices( - stimulus_blocks=( - stimulus_presentations_table["stimulus_block"].values - ) + stimulus_blocks=(stimulus_presentations_table["stimulus_block"].values) ) for spontaneous_block in spontaneous_stimulus_blocks: prev_stop_time = stimulus_presentations_table[ - stimulus_presentations_table["stimulus_block"] - == spontaneous_block - 1 + stimulus_presentations_table["stimulus_block"] == spontaneous_block - 1 ]["stop_time"].max() prev_end_frame = stimulus_presentations_table[ - stimulus_presentations_table["stimulus_block"] - == spontaneous_block - 1 + stimulus_presentations_table["stimulus_block"] == spontaneous_block - 1 ]["end_frame"].max() next_start_time = stimulus_presentations_table[ - stimulus_presentations_table["stimulus_block"] - == spontaneous_block + 1 + stimulus_presentations_table["stimulus_block"] == spontaneous_block + 1 ]["start_time"].min() next_start_frame = stimulus_presentations_table[ - stimulus_presentations_table["stimulus_block"] - == spontaneous_block + 1 + stimulus_presentations_table["stimulus_block"] == spontaneous_block + 1 ]["start_frame"].min() res.append( { @@ -736,9 +653,7 @@ def _get_spontaneous_stimulus( res = pd.DataFrame(res) - return pd.concat([stimulus_presentations_table, res]).sort_values( - "start_frame" - ) + return pd.concat([stimulus_presentations_table, res]).sort_values("start_frame") @classmethod def _add_fingerprint_stimulus( @@ -763,12 +678,8 @@ def _add_fingerprint_stimulus( except MalformedStimulusFileError: return stimulus_presentations - stimulus_presentations = pd.concat( - [stimulus_presentations, fingerprint_stimulus.table] - ) - stimulus_presentations = cls._get_spontaneous_stimulus( - stimulus_presentations_table=stimulus_presentations - ) + stimulus_presentations = pd.concat([stimulus_presentations, fingerprint_stimulus.table]) + stimulus_presentations = cls._get_spontaneous_stimulus(stimulus_presentations_table=stimulus_presentations) # reset index to go from 0...end stimulus_presentations.index = pd.Index( @@ -801,9 +712,7 @@ def get_spontaneous_block_indices(stimulus_blocks: np.ndarray) -> np.ndarray: block_diffs = np.diff(blocks) if (block_diffs > 2).any(): raise RuntimeError( - f"There should not be any stimulus block " - f"diffs greater than 2. The stimulus " - f"blocks are {blocks}" + f"There should not be any stimulus block diffs greater than 2. The stimulus blocks are {blocks}" ) # i.e. if the current blocks are [0, 2], then block_diffs will diff --git a/allensdk/brain_observatory/behavior/data_objects/stimuli/stimuli.py b/allensdk/brain_observatory/behavior/data_objects/stimuli/stimuli.py index 262c7d22f3..955cb3704e 100644 --- a/allensdk/brain_observatory/behavior/data_objects/stimuli/stimuli.py +++ b/allensdk/brain_observatory/behavior/data_objects/stimuli/stimuli.py @@ -31,9 +31,7 @@ class Stimuli( NwbReadableInterface, NwbWritableInterface, ): - def __init__(self, - presentations: Presentations, - templates: Templates): + def __init__(self, presentations: Presentations, templates: Templates): super().__init__(name="stimuli", value=None, is_value_self=True) self._presentations = presentations self._templates = templates @@ -72,7 +70,7 @@ def from_stimulus_file( presentation_columns: Optional[List[str]] = None, presentation_fill_omitted_values: bool = True, project_code: Optional[ProjectCode] = None, - load_stimulus_movie: bool = False + load_stimulus_movie: bool = False, ) -> "Stimuli": """ @@ -115,9 +113,7 @@ def from_stimulus_file( trials=trials, ) t = Templates.from_stimulus_file( - stimulus_file=stimulus_file, - limit_to_images=limit_to_images, - load_stimulus_movie=load_stimulus_movie + stimulus_file=stimulus_file, limit_to_images=limit_to_images, load_stimulus_movie=load_stimulus_movie ) return Stimuli(presentations=p, templates=t) @@ -138,9 +134,7 @@ def to_nwb( ------- NWBFile """ - nwbfile = self._templates.to_nwb( - nwbfile=nwbfile, stimulus_presentations=self._presentations - ) + nwbfile = self._templates.to_nwb(nwbfile=nwbfile, stimulus_presentations=self._presentations) nwbfile = self._presentations.to_nwb( nwbfile=nwbfile, stimulus_name_column=presentations_stimulus_column_name, diff --git a/allensdk/brain_observatory/behavior/data_objects/stimuli/stimulus_templates.py b/allensdk/brain_observatory/behavior/data_objects/stimuli/stimulus_templates.py index 93b2c97597..cc65e907f4 100644 --- a/allensdk/brain_observatory/behavior/data_objects/stimuli/stimulus_templates.py +++ b/allensdk/brain_observatory/behavior/data_objects/stimuli/stimulus_templates.py @@ -6,8 +6,7 @@ import numpy as np import pandas as pd -from allensdk.brain_observatory.behavior.data_objects.stimuli.util import \ - convert_filepath_caseinsensitive +from allensdk.brain_observatory.behavior.data_objects.stimuli.util import convert_filepath_caseinsensitive from allensdk.brain_observatory.stimulus_info import BrainObservatoryMonitor @@ -36,10 +35,10 @@ def name(self): class StimulusImageFactory: """Factory for StimulusImage""" + _monitor = BrainObservatoryMonitor() - def from_unprocessed(self, input_array: np.ndarray, - name: str) -> StimulusImage: + def from_unprocessed(self, input_array: np.ndarray, name: str) -> StimulusImage: """Creates a StimulusImage from unprocessed input (usually pkl). Image needs to be warped and preprocessed""" resized, unwarped = self._get_unwarped(arr=input_array) @@ -48,8 +47,7 @@ def from_unprocessed(self, input_array: np.ndarray, return image @staticmethod - def from_processed(warped: np.ndarray, unwarped: np.ndarray, - name: str) -> StimulusImage: + def from_processed(warped: np.ndarray, unwarped: np.ndarray, name: str) -> StimulusImage: """Creates a StimulusImage from processed input (usually nwb). Image has already been warped and preprocessed""" image = StimulusImage(name=name, warped=warped, unwarped=unwarped) @@ -65,8 +63,7 @@ def _get_unwarped(self, arr: np.ndarray): """This produces the pixels that would be visible in the unwarped image post-warping""" # 1. Resize image to the same size as the monitor - resized_array = self._monitor.natural_scene_image_to_screen( - arr, origin='upper') + resized_array = self._monitor.natural_scene_image_to_screen(arr, origin="upper") # 2. Remove unseen pixels arr = self._exclude_unseen_pixels(arr=resized_array) @@ -76,7 +73,7 @@ def _exclude_unseen_pixels(self, arr: np.ndarray): """After warping, some pixels are not visible on the screen. This sets those pixels to nan to make downstream analysis easier.""" mask = self._monitor.get_mask() - arr = arr.astype('float') + arr = arr.astype("float") arr *= mask arr[mask == 0] = np.nan return arr @@ -89,13 +86,11 @@ def _warp(self, arr: np.ndarray) -> np.ndarray: class StimulusMovieFrameFactory(StimulusImageFactory): - def _get_unwarped(self, arr: np.ndarray): """This produces the pixels that would be visible in the unwarped movie frame post-warping""" # 1. Resize image to the same size as the monitor - resized_array = self._monitor.natural_movie_image_to_screen( - arr, origin='upper') + resized_array = self._monitor.natural_movie_image_to_screen(arr, origin="upper") # 2. Remove unseen pixels arr = self._exclude_unseen_pixels(arr=resized_array) @@ -116,8 +111,7 @@ def __init__(self, image_set_name: str, images: List[StimulusImage]): """ self._image_set_name = image_set_name - image_set_name = convert_filepath_caseinsensitive( - image_set_name) + image_set_name = convert_filepath_caseinsensitive(image_set_name) self._image_set_filepath = image_set_name self._images: Dict[str, StimulusImage] = {} @@ -146,9 +140,7 @@ def values(self): def items(self): return self._images.items() - def to_dataframe(self, - index_name: str = 'image_name', - index_type: str = 'str') -> pd.DataFrame: + def to_dataframe(self, index_name: str = "image_name", index_type: str = "str") -> pd.DataFrame: """ Convert the collection of stimulus templates to a dataframe. @@ -167,17 +159,14 @@ def to_dataframe(self, (un)warped columns return as a list of lists representing the displayed image/frame. """ - index = pd.Index(np.array(list(self.image_names), dtype=index_type), - name=index_name) + index = pd.Index(np.array(list(self.image_names), dtype=index_type), name=index_name) warped = [img.warped for img in self.images] unwarped = [img.unwarped for img in self.images] - df = pd.DataFrame({'unwarped': unwarped, 'warped': warped}, - index=index) + df = pd.DataFrame({"unwarped": unwarped, "warped": warped}, index=index) df.name = self._image_set_name return df - def __add_image(self, warped_values: np.ndarray, - unwarped_values: np.ndarray, name: str): + def __add_image(self, warped_values: np.ndarray, unwarped_values: np.ndarray, name: str): """ Parameters ---------- @@ -190,9 +179,7 @@ def __add_image(self, warped_values: np.ndarray, The image array corresponding to the 'unwarped' version of the stimuli. """ - image = StimulusImage(warped=warped_values, - unwarped=unwarped_values, - name=name) + image = StimulusImage(warped=warped_values, unwarped=unwarped_values, name=name) self._images[name] = image def __getitem__(self, item) -> StimulusImage: @@ -208,7 +195,7 @@ def __iter__(self): yield from self._images def __repr__(self): - return f'{self._images}' + return f"{self._images}" def __eq__(self, other: object): if isinstance(other, StimulusTemplate): @@ -218,29 +205,25 @@ def __eq__(self, other: object): if sorted(self.image_names) != sorted(other.image_names): return False - for (img_name, self_img) in self.items(): + for img_name, self_img in self.items(): other_img = other._images[img_name] - warped_equal = np.array_equal( - self_img.warped, other_img.warped) - unwarped_equal = np.allclose(self_img.unwarped, - other_img.unwarped, - equal_nan=True) + warped_equal = np.array_equal(self_img.warped, other_img.warped) + unwarped_equal = np.allclose(self_img.unwarped, other_img.unwarped, equal_nan=True) if not (warped_equal and unwarped_equal): return False return True else: - raise NotImplementedError( - "Cannot compare a StimulusTemplate with an object of type: " - f"{type(other)}!") + raise NotImplementedError(f"Cannot compare a StimulusTemplate with an object of type: {type(other)}!") class StimulusTemplateFactory: """Factory for StimulusTemplate""" @staticmethod - def from_unprocessed(image_set_name: str, image_attributes: List[dict], - images: List[np.ndarray]) -> StimulusTemplate: + def from_unprocessed( + image_set_name: str, image_attributes: List[dict], images: List[np.ndarray] + ) -> StimulusTemplate: """Create StimulusTemplate from pkl or unprocessed input. Stimulus templates created this way need to be processed to acquire unwarped versions of the images presented. @@ -273,17 +256,15 @@ def from_unprocessed(image_set_name: str, image_attributes: List[dict], """ stimulus_images = [] for i, image in enumerate(images): - name = image_attributes[i]['image_name'] - stimulus_image = StimulusImageFactory().from_unprocessed( - name=name, input_array=image) + name = image_attributes[i]["image_name"] + stimulus_image = StimulusImageFactory().from_unprocessed(name=name, input_array=image) stimulus_images.append(stimulus_image) - return StimulusTemplate(image_set_name=image_set_name, - images=stimulus_images) + return StimulusTemplate(image_set_name=image_set_name, images=stimulus_images) @staticmethod - def from_processed(image_set_name: str, image_attributes: List[dict], - unwarped: List[np.ndarray], - warped: List[np.ndarray]) -> StimulusTemplate: + def from_processed( + image_set_name: str, image_attributes: List[dict], unwarped: List[np.ndarray], warped: List[np.ndarray] + ) -> StimulusTemplate: """Create StimulusTemplate from nwb or other processed input. Stimulus templates created this way DO NOT need to be processed to acquire unwarped versions of the images presented. @@ -320,12 +301,12 @@ def from_processed(image_set_name: str, image_attributes: List[dict], for i, attrs in enumerate(image_attributes): warped_image = warped[i] unwarped_image = unwarped[i] - name = attrs['image_name'] + name = attrs["image_name"] stimulus_image = StimulusImageFactory.from_processed( - name=name, warped=warped_image, unwarped=unwarped_image) + name=name, warped=warped_image, unwarped=unwarped_image + ) stimulus_images.append(stimulus_image) - return StimulusTemplate(image_set_name=image_set_name, - images=stimulus_images) + return StimulusTemplate(image_set_name=image_set_name, images=stimulus_images) class StimulusMovieTemplateFactory(StimulusTemplateFactory): @@ -334,9 +315,9 @@ class StimulusMovieTemplateFactory(StimulusTemplateFactory): """ @staticmethod - def from_unprocessed(movie_name: str, - movie_frames: List[np.ndarray], - n_workers: Optional[int] = None) -> StimulusTemplate: + def from_unprocessed( + movie_name: str, movie_frames: List[np.ndarray], n_workers: Optional[int] = None + ) -> StimulusTemplate: """Create StimulusTemplate from pkl or unprocessed input. Stimulus templates created this way need to be processed to acquire unwarped versions of the movie frames presented. @@ -365,17 +346,15 @@ def from_unprocessed(movie_name: str, n_workers = os.cpu_count() with Pool(n_workers) as worker_pool: - stimulus_images = list(tqdm( - worker_pool.imap( - _movie_warper_helper, - [(idx, frame) for idx, frame in enumerate(movie_frames)] - ), - total=len(movie_frames), - desc="Warping natural movie frames" - )) + stimulus_images = list( + tqdm( + worker_pool.imap(_movie_warper_helper, [(idx, frame) for idx, frame in enumerate(movie_frames)]), + total=len(movie_frames), + desc="Warping natural movie frames", + ) + ) - return StimulusTemplate(image_set_name=movie_name, - images=stimulus_images) + return StimulusTemplate(image_set_name=movie_name, images=stimulus_images) def _movie_warper_helper(*args): @@ -383,5 +362,4 @@ def _movie_warper_helper(*args): Simple helper wrapping the stimulus movie frame factory. """ name, frame = args[0] - return StimulusMovieFrameFactory().from_unprocessed( - name=name, input_array=frame) + return StimulusMovieFrameFactory().from_unprocessed(name=name, input_array=frame) diff --git a/allensdk/brain_observatory/behavior/data_objects/stimuli/templates.py b/allensdk/brain_observatory/behavior/data_objects/stimuli/templates.py index 66e699acc1..cf02fe7dd0 100644 --- a/allensdk/brain_observatory/behavior/data_objects/stimuli/templates.py +++ b/allensdk/brain_observatory/behavior/data_objects/stimuli/templates.py @@ -9,69 +9,59 @@ from allensdk.brain_observatory.behavior.data_files import BehaviorStimulusFile from allensdk.core import DataObject -from allensdk.core import \ - NwbReadableInterface -from allensdk.brain_observatory.behavior.data_files.stimulus_file import \ - StimulusFileReadableInterface -from allensdk.core import \ - NwbWritableInterface -from allensdk.brain_observatory.behavior.data_objects.stimuli.presentations \ - import \ - Presentations -from allensdk.brain_observatory.behavior.stimulus_processing import \ - get_stimulus_templates -from allensdk.brain_observatory.behavior.data_objects.stimuli \ - .stimulus_templates import \ - StimulusTemplate, StimulusTemplateFactory, StimulusMovieTemplateFactory -from allensdk.brain_observatory.behavior.write_nwb.extensions\ - .stimulus_template.ndx_stimulus_template import \ - StimulusTemplateExtension +from allensdk.core import NwbReadableInterface +from allensdk.brain_observatory.behavior.data_files.stimulus_file import StimulusFileReadableInterface +from allensdk.core import NwbWritableInterface +from allensdk.brain_observatory.behavior.data_objects.stimuli.presentations import Presentations +from allensdk.brain_observatory.behavior.stimulus_processing import get_stimulus_templates +from allensdk.brain_observatory.behavior.data_objects.stimuli.stimulus_templates import ( + StimulusTemplate, + StimulusTemplateFactory, + StimulusMovieTemplateFactory, +) +from allensdk.brain_observatory.behavior.write_nwb.extensions.stimulus_template.ndx_stimulus_template import ( + StimulusTemplateExtension, +) from allensdk.internal.core.lims_utilities import safe_system_path -class Templates(DataObject, StimulusFileReadableInterface, - NwbReadableInterface, NwbWritableInterface): +class Templates(DataObject, StimulusFileReadableInterface, NwbReadableInterface, NwbWritableInterface): def __init__(self, templates: Dict[str, StimulusTemplate]): - super().__init__(name='stimulus_templates', value=templates) + super().__init__(name="stimulus_templates", value=templates) # Grab the keys from the input dictionary. The "images" key is assumed # to be the key in the dictionary that does not have "movie" in its key # name. For VBO and VBN releases, there should only be at most 2 # keys in the dictionary. - image_template_keys = [ - key for key in templates.keys() - if 'movie' not in key.lower()] + image_template_keys = [key for key in templates.keys() if "movie" not in key.lower()] self._image_template_key = None error_message = "" if len(image_template_keys) == 1: self._image_template_key = image_template_keys[0] elif len(image_template_keys) > 1: - error_message += ( - "Found multiple image StimulusTemplates " - f"{image_template_keys}. ") + error_message += f"Found multiple image StimulusTemplates {image_template_keys}. " - movie_template_keys = [ - key for key in templates.keys() - if 'movie' in key.lower()] + movie_template_keys = [key for key in templates.keys() if "movie" in key.lower()] self._fingerprint_movie_template_key = None if len(movie_template_keys) == 1: self._fingerprint_movie_template_key = movie_template_keys[0] elif len(movie_template_keys) > 1: - error_message += ( - "Found multiple fingerprint movie StimulusTemplates " - f"{movie_template_keys}. ") + error_message += f"Found multiple fingerprint movie StimulusTemplates {movie_template_keys}. " if len(error_message) > 0: error_message += ( "This is not currently supported. " "Please limit input to one image template and/or one " - "fingerprint movie template.") + "fingerprint movie template." + ) raise NotImplementedError(error_message) @classmethod def from_stimulus_file( - cls, stimulus_file: BehaviorStimulusFile, - limit_to_images: Optional[List] = None, - load_stimulus_movie: bool = False) -> "Templates": + cls, + stimulus_file: BehaviorStimulusFile, + limit_to_images: Optional[List] = None, + load_stimulus_movie: bool = False, + ) -> "Templates": """Get stimulus templates (movies, scenes) for behavior session.""" # TODO: Eventually the `grating_images_dict` should be provided by the @@ -79,57 +69,52 @@ def from_stimulus_file( # - NJM 2021/2/23 gratings_dir = "/allen/programs/braintv/production/visualbehavior" - gratings_dir = os.path.join(gratings_dir, - "prod5/project_VisualBehavior") + gratings_dir = os.path.join(gratings_dir, "prod5/project_VisualBehavior") grating_images_dict = { "gratings_0.0": { - "warped": np.asarray(imageio.imread( - safe_system_path(os.path.join(gratings_dir, - "warped_grating_0.png")))), - "unwarped": np.asarray(imageio.imread( - safe_system_path(os.path.join( - gratings_dir, "masked_unwarped_grating_0.png")))) + "warped": np.asarray( + imageio.imread(safe_system_path(os.path.join(gratings_dir, "warped_grating_0.png"))) + ), + "unwarped": np.asarray( + imageio.imread(safe_system_path(os.path.join(gratings_dir, "masked_unwarped_grating_0.png"))) + ), }, "gratings_90.0": { - "warped": np.asarray(imageio.imread( - safe_system_path(os.path.join(gratings_dir, - "warped_grating_90.png")))), - "unwarped": np.asarray(imageio.imread( - safe_system_path(os.path.join( - gratings_dir, "masked_unwarped_grating_90.png")))) + "warped": np.asarray( + imageio.imread(safe_system_path(os.path.join(gratings_dir, "warped_grating_90.png"))) + ), + "unwarped": np.asarray( + imageio.imread(safe_system_path(os.path.join(gratings_dir, "masked_unwarped_grating_90.png"))) + ), }, "gratings_180.0": { - "warped": np.asarray(imageio.imread( - safe_system_path(os.path.join(gratings_dir, - "warped_grating_180.png")))), - "unwarped": np.asarray(imageio.imread( - safe_system_path(os.path.join( - gratings_dir, "masked_unwarped_grating_180.png")))) + "warped": np.asarray( + imageio.imread(safe_system_path(os.path.join(gratings_dir, "warped_grating_180.png"))) + ), + "unwarped": np.asarray( + imageio.imread(safe_system_path(os.path.join(gratings_dir, "masked_unwarped_grating_180.png"))) + ), }, "gratings_270.0": { - "warped": np.asarray(imageio.imread( - safe_system_path(os.path.join(gratings_dir, - "warped_grating_270.png")))), - "unwarped": np.asarray(imageio.imread( - safe_system_path(os.path.join( - gratings_dir, "masked_unwarped_grating_270.png")))) - } + "warped": np.asarray( + imageio.imread(safe_system_path(os.path.join(gratings_dir, "warped_grating_270.png"))) + ), + "unwarped": np.asarray( + imageio.imread(safe_system_path(os.path.join(gratings_dir, "masked_unwarped_grating_270.png"))) + ), + }, } pkl = stimulus_file.data stim_template = get_stimulus_templates( - pkl=pkl, - grating_images_dict=grating_images_dict, - limit_to_images=limit_to_images) + pkl=pkl, grating_images_dict=grating_images_dict, limit_to_images=limit_to_images + ) t = {stim_template.image_set_name: stim_template} - has_fingerprint_stimulus = ( - "fingerprint" in pkl["items"]["behavior"]["items"] - ) + has_fingerprint_stimulus = "fingerprint" in pkl["items"]["behavior"]["items"] if has_fingerprint_stimulus and load_stimulus_movie: movie_data = np.load( - Path(pkl['items']['behavior']['items'][ - 'fingerprint']['static_stimulus']['movie_path']) + Path(pkl["items"]["behavior"]["items"]["fingerprint"]["static_stimulus"]["movie_path"]) ) movie_template = StimulusMovieTemplateFactory.from_unprocessed( movie_name="natural_movie_one", @@ -143,23 +128,17 @@ def from_stimulus_file( def from_nwb(cls, nwbfile: NWBFile) -> "Templates": templates = {} for image_set_name, image_data in nwbfile.stimulus_template.items(): - - image_attributes = [ - {'image_name': image_name} - for image_name in image_data.control_description - ] + image_attributes = [{"image_name": image_name} for image_name in image_data.control_description] templates[image_set_name] = StimulusTemplateFactory.from_processed( image_set_name=image_set_name, image_attributes=image_attributes, warped=image_data.data[:], - unwarped=image_data.unwarped[:] + unwarped=image_data.unwarped[:], ) return Templates(templates=templates) - def to_nwb(self, nwbfile: NWBFile, - stimulus_presentations: Presentations) -> NWBFile: + def to_nwb(self, nwbfile: NWBFile, stimulus_presentations: Presentations) -> NWBFile: for key, stimulus_templates in self.value.items(): - unwarped_images = [] warped_images = [] image_names = [] @@ -171,46 +150,41 @@ def to_nwb(self, nwbfile: NWBFile, image_index = np.zeros(len(image_names)) image_index[:] = np.nan - visual_stimulus_image_series = \ - StimulusTemplateExtension( - name=stimulus_templates.image_set_name, - data=warped_images, - unwarped=unwarped_images, - control=list(range(len(image_names))), - control_description=image_names, - unit='NA', - format='raw', - timestamps=image_index) + visual_stimulus_image_series = StimulusTemplateExtension( + name=stimulus_templates.image_set_name, + data=warped_images, + unwarped=unwarped_images, + control=list(range(len(image_names))), + control_description=image_names, + unit="NA", + format="raw", + timestamps=image_index, + ) nwbfile.add_stimulus_template(visual_stimulus_image_series) - if 'image_index' in stimulus_presentations.value \ - and self._image_template_key is not None: - nwbfile = self._add_image_index_to_nwb( - nwbfile=nwbfile, presentations=stimulus_presentations) + if "image_index" in stimulus_presentations.value and self._image_template_key is not None: + nwbfile = self._add_image_index_to_nwb(nwbfile=nwbfile, presentations=stimulus_presentations) return nwbfile - def _add_image_index_to_nwb( - self, nwbfile: NWBFile, presentations: Presentations): + def _add_image_index_to_nwb(self, nwbfile: NWBFile, presentations: Presentations): """Adds the image index and start_time for all stimulus templates to NWB""" stimulus_templates = self.value[self._image_template_key] presentations = presentations.value - nwb_template = nwbfile.stimulus_template[ - stimulus_templates.image_set_name] - stimulus_name = 'image_set' \ - if 'image_set' in presentations else 'stimulus_name' - stimulus_index = presentations[ - presentations[stimulus_name] == nwb_template.name] + nwb_template = nwbfile.stimulus_template[stimulus_templates.image_set_name] + stimulus_name = "image_set" if "image_set" in presentations else "stimulus_name" + stimulus_index = presentations[presentations[stimulus_name] == nwb_template.name] image_index = IndexSeries( name=nwb_template.name, - data=stimulus_index['image_index'].values, - unit='N/A', + data=stimulus_index["image_index"].values, + unit="N/A", indexed_timeseries=nwb_template, - timestamps=stimulus_index['start_time'].values) + timestamps=stimulus_index["start_time"].values, + ) nwbfile.add_stimulus(image_index) return nwbfile diff --git a/allensdk/brain_observatory/behavior/data_objects/stimuli/util.py b/allensdk/brain_observatory/behavior/data_objects/stimuli/util.py index 02a6d07976..49ae57762b 100644 --- a/allensdk/brain_observatory/behavior/data_objects/stimuli/util.py +++ b/allensdk/brain_observatory/behavior/data_objects/stimuli/util.py @@ -2,14 +2,12 @@ from pathlib import Path from allensdk.brain_observatory.behavior.data_files import SyncFile -from allensdk.brain_observatory.behavior.data_objects.metadata \ - .behavior_metadata.equipment import \ - Equipment +from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.equipment import Equipment from allensdk.internal.brain_observatory.time_sync import OphysTimeAligner def convert_filepath_caseinsensitive(filename_in): - return filename_in.replace('TRAINING', 'training') + return filename_in.replace("TRAINING", "training") def get_image_set_name(image_set_path: str) -> str: @@ -19,8 +17,7 @@ def get_image_set_name(image_set_path: str) -> str: return Path(image_set_path).stem -def calculate_monitor_delay(sync_file: SyncFile, - equipment: Equipment) -> float: +def calculate_monitor_delay(sync_file: SyncFile, equipment: Equipment) -> float: """Calculates monitor delay using sync file. If that fails, looks up monitor delay from known values for equipment. @@ -36,30 +33,32 @@ def calculate_monitor_delay(sync_file: SyncFile, except ValueError as ee: equipment_name = equipment.value - warning_msg = 'Monitory delay calculation failed ' - warning_msg += 'with ValueError\n' + warning_msg = "Monitory delay calculation failed " + warning_msg += "with ValueError\n" warning_msg += f' "{ee}"' - warning_msg += '\nlooking monitor delay up from table ' - warning_msg += f'for rig: {equipment_name} ' + warning_msg += "\nlooking monitor delay up from table " + warning_msg += f"for rig: {equipment_name} " # see # https://github.com/AllenInstitute/AllenSDK/issues/1318 # https://github.com/AllenInstitute/AllenSDK/issues/1916 - delay_lookup = {'CAM2P.1': 0.020842, - 'CAM2P.2': 0.037566, - 'CAM2P.3': 0.021390, - 'CAM2P.4': 0.021102, - 'CAM2P.5': 0.021192, - 'MESO.1': 0.03613, - # TODO update with actual monitor delay once calculated - 'MESO.2': 0.03613} + delay_lookup = { + "CAM2P.1": 0.020842, + "CAM2P.2": 0.037566, + "CAM2P.3": 0.021390, + "CAM2P.4": 0.021102, + "CAM2P.5": 0.021192, + "MESO.1": 0.03613, + # TODO update with actual monitor delay once calculated + "MESO.2": 0.03613, + } if equipment_name not in delay_lookup: msg = warning_msg - msg += f'\nequipment_name {equipment_name} not in lookup table' + msg += f"\nequipment_name {equipment_name} not in lookup table" raise RuntimeError(msg) delay = delay_lookup[equipment_name] - warning_msg += f'\ndelay: {delay} seconds' + warning_msg += f"\ndelay: {delay} seconds" warnings.warn(warning_msg) return delay diff --git a/allensdk/brain_observatory/behavior/data_objects/task_parameters.py b/allensdk/brain_observatory/behavior/data_objects/task_parameters.py index c7b7a02d1a..5cfc9185a2 100644 --- a/allensdk/brain_observatory/behavior/data_objects/task_parameters.py +++ b/allensdk/brain_observatory/behavior/data_objects/task_parameters.py @@ -6,51 +6,47 @@ from allensdk.brain_observatory.behavior.data_files import BehaviorStimulusFile from allensdk.core import DataObject -from allensdk.core import \ - NwbReadableInterface -from allensdk.brain_observatory.behavior.data_files.stimulus_file import \ - StimulusFileReadableInterface -from allensdk.core import \ - NwbWritableInterface -from allensdk.brain_observatory.behavior.schemas import \ - BehaviorTaskParametersSchema +from allensdk.core import NwbReadableInterface +from allensdk.brain_observatory.behavior.data_files.stimulus_file import StimulusFileReadableInterface +from allensdk.core import NwbWritableInterface +from allensdk.brain_observatory.behavior.schemas import BehaviorTaskParametersSchema from allensdk.brain_observatory.nwb import load_pynwb_extension from allensdk.brain_observatory.behavior.utils.metadata_parsers import ( # noqa E501 - parse_stimulus_set + parse_stimulus_set, ) class BehaviorStimulusType(Enum): - IMAGES = 'images' - GRATING = 'grating' + IMAGES = "images" + GRATING = "grating" class StimulusDistribution(Enum): - EXPONENTIAL = 'exponential' - GEOMETRIC = 'geometric' + EXPONENTIAL = "exponential" + GEOMETRIC = "geometric" class TaskType(Enum): - CHANGE_DETECTION = 'change detection' - - -class TaskParameters(DataObject, StimulusFileReadableInterface, - NwbReadableInterface, NwbWritableInterface): - def __init__(self, - blank_duration_sec: List[float], - stimulus_duration_sec: float, - omitted_flash_fraction: float, - response_window_sec: List[float], - reward_volume: float, - auto_reward_volume: float, - session_type: str, - stimulus: str, - stimulus_distribution: StimulusDistribution, - task_type: TaskType, - n_stimulus_frames: int, - stimulus_name: Optional[str] = None): - super().__init__(name='task_parameters', value=None, - is_value_self=True) + CHANGE_DETECTION = "change detection" + + +class TaskParameters(DataObject, StimulusFileReadableInterface, NwbReadableInterface, NwbWritableInterface): + def __init__( + self, + blank_duration_sec: List[float], + stimulus_duration_sec: float, + omitted_flash_fraction: float, + response_window_sec: List[float], + reward_volume: float, + auto_reward_volume: float, + session_type: str, + stimulus: str, + stimulus_distribution: StimulusDistribution, + task_type: TaskType, + n_stimulus_frames: int, + stimulus_name: Optional[str] = None, + ): + super().__init__(name="task_parameters", value=None, is_value_self=True) self._blank_duration_sec = blank_duration_sec self._stimulus_duration_sec = stimulus_duration_sec self._omitted_flash_fraction = omitted_flash_fraction @@ -59,8 +55,7 @@ def __init__(self, self._auto_reward_volume = auto_reward_volume self._session_type = session_type self._stimulus = BehaviorStimulusType(stimulus) - self._stimulus_distribution = StimulusDistribution( - stimulus_distribution) + self._stimulus_distribution = StimulusDistribution(stimulus_distribution) self._task = TaskType(task_type) self._n_stimulus_frames = n_stimulus_frames self._stimulus_name = stimulus_name @@ -119,13 +114,9 @@ def image_set(self) -> str: return self._image_set def to_nwb(self, nwbfile: NWBFile) -> NWBFile: - nwb_extension = load_pynwb_extension( - BehaviorTaskParametersSchema, 'ndx-aibs-behavior-ophys' - ) - task_parameters = self.to_dict()['task_parameters'] - task_parameters_clean = BehaviorTaskParametersSchema().dump( - task_parameters - ) + nwb_extension = load_pynwb_extension(BehaviorTaskParametersSchema, "ndx-aibs-behavior-ophys") + task_parameters = self.to_dict()["task_parameters"] + task_parameters_clean = BehaviorTaskParametersSchema().dump(task_parameters) new_task_parameters_dict = {} for key, val in task_parameters_clean.items(): @@ -133,44 +124,38 @@ def to_nwb(self, nwbfile: NWBFile) -> NWBFile: new_task_parameters_dict[key] = np.array(val) else: new_task_parameters_dict[key] = val - nwb_task_parameters = nwb_extension( - name='task_parameters', **new_task_parameters_dict) + nwb_task_parameters = nwb_extension(name="task_parameters", **new_task_parameters_dict) nwbfile.add_lab_meta_data(nwb_task_parameters) return nwbfile @classmethod def from_nwb(cls, nwbfile: NWBFile) -> "TaskParameters": - metadata_nwb_obj = nwbfile.lab_meta_data['task_parameters'] + metadata_nwb_obj = nwbfile.lab_meta_data["task_parameters"] data = BehaviorTaskParametersSchema().dump(metadata_nwb_obj) - data['task_type'] = data['task'] - del data['task'] + data["task_type"] = data["task"] + del data["task"] return TaskParameters(**data) @classmethod - def from_stimulus_file( - cls, - stimulus_file: BehaviorStimulusFile) -> "TaskParameters": + def from_stimulus_file(cls, stimulus_file: BehaviorStimulusFile) -> "TaskParameters": data = stimulus_file.data behavior = data["items"]["behavior"] config = behavior["config"] doc = config["DoC"] - blank_duration_sec = [float(x) for x in doc['blank_duration_range']] - stim_duration = cls._calculate_stimulus_duration( - stimulus_file=stimulus_file) - omitted_flash_fraction = \ - behavior['params'].get('flash_omit_probability', float('nan')) + blank_duration_sec = [float(x) for x in doc["blank_duration_range"]] + stim_duration = cls._calculate_stimulus_duration(stimulus_file=stimulus_file) + omitted_flash_fraction = behavior["params"].get("flash_omit_probability", float("nan")) response_window_sec = [float(x) for x in doc["response_window"]] reward_volume = config["reward"]["reward_volume"] - auto_reward_volume = doc['auto_reward_volume'] + auto_reward_volume = doc["auto_reward_volume"] session_type = behavior["params"]["stage"] stimulus = next(iter(behavior["stimuli"])) stimulus_name = stimulus_file.stimulus_name stimulus_distribution = doc["change_time_dist"] task = cls._parse_task(stimulus_file=stimulus_file) - n_stimulus_frames = cls._calculuate_n_stimulus_frames( - stimulus_file=stimulus_file) + n_stimulus_frames = cls._calculuate_n_stimulus_frames(stimulus_file=stimulus_file) return TaskParameters( blank_duration_sec=blank_duration_sec, stimulus_duration_sec=stim_duration, @@ -183,22 +168,21 @@ def from_stimulus_file( stimulus_distribution=stimulus_distribution, task_type=task, n_stimulus_frames=n_stimulus_frames, - stimulus_name=stimulus_name + stimulus_name=stimulus_name, ) @staticmethod - def _calculate_stimulus_duration( - stimulus_file: BehaviorStimulusFile) -> float: + def _calculate_stimulus_duration(stimulus_file: BehaviorStimulusFile) -> float: data = stimulus_file.data behavior = data["items"]["behavior"] - stimuli = behavior['stimuli'] + stimuli = behavior["stimuli"] def _parse_stimulus_key(): - if 'images' in stimuli: - stim_key = 'images' - elif 'grating' in stimuli: - stim_key = 'grating' + if "images" in stimuli: + stim_key = "images" + elif "grating" in stimuli: + stim_key = "grating" else: msg = "Cannot get stimulus_duration_sec\n" msg += "'images' and/or 'grating' not a valid " @@ -208,8 +192,9 @@ def _parse_stimulus_key(): raise RuntimeError(msg) return stim_key + stim_key = _parse_stimulus_key() - stim_duration = stimuli[stim_key]['flash_interval_sec'] + stim_duration = stimuli[stim_key]["flash_interval_sec"] # from discussion in # https://github.com/AllenInstitute/AllenSDK/issues/1572 @@ -229,13 +214,12 @@ def _parse_stimulus_key(): return stim_duration @staticmethod - def _parse_task( - stimulus_file: BehaviorStimulusFile) -> TaskType: + def _parse_task(stimulus_file: BehaviorStimulusFile) -> TaskType: data = stimulus_file.data config = data["items"]["behavior"]["config"] - task_id = config['behavior']['task_id'] - if 'DoC' in task_id: + task_id = config["behavior"]["task_id"] + if "DoC" in task_id: task = TaskType.CHANGE_DETECTION else: msg = "metadata.get_task_parameters does not " @@ -244,8 +228,7 @@ def _parse_task( return task @staticmethod - def _calculuate_n_stimulus_frames( - stimulus_file: BehaviorStimulusFile) -> int: + def _calculuate_n_stimulus_frames(stimulus_file: BehaviorStimulusFile) -> int: data = stimulus_file.data behavior = data["items"]["behavior"] diff --git a/allensdk/brain_observatory/behavior/data_objects/timestamps/ophys_timestamps.py b/allensdk/brain_observatory/behavior/data_objects/timestamps/ophys_timestamps.py index 3b0665a00a..4f08a91e11 100644 --- a/allensdk/brain_observatory/behavior/data_objects/timestamps/ophys_timestamps.py +++ b/allensdk/brain_observatory/behavior/data_objects/timestamps/ophys_timestamps.py @@ -5,14 +5,11 @@ from allensdk.brain_observatory.behavior.data_files import SyncFile from allensdk.core import DataObject -from allensdk.core import \ - NwbReadableInterface -from allensdk.brain_observatory.behavior.data_files.sync_file import \ - SyncFileReadableInterface +from allensdk.core import NwbReadableInterface +from allensdk.brain_observatory.behavior.data_files.sync_file import SyncFileReadableInterface -class OphysTimestamps(DataObject, SyncFileReadableInterface, - NwbReadableInterface): +class OphysTimestamps(DataObject, SyncFileReadableInterface, NwbReadableInterface): _logger = logging.getLogger(__name__) def __init__(self, timestamps: np.ndarray): @@ -20,18 +17,16 @@ def __init__(self, timestamps: np.ndarray): :param timestamps ophys timestamps """ - super().__init__(name='ophys_timestamps', value=timestamps) + super().__init__(name="ophys_timestamps", value=timestamps) @classmethod def from_sync_file(cls, sync_file: SyncFile) -> "OphysTimestamps": - ophys_timestamps = sync_file.data['ophys_frames'] + ophys_timestamps = sync_file.data["ophys_frames"] return cls(timestamps=ophys_timestamps) @classmethod def from_nwb(cls, nwbfile: NWBFile) -> "OphysTimestamps": - ts = nwbfile.processing[ - 'ophys'].get_data_interface('dff').roi_response_series[ - 'traces'].timestamps[:] + ts = nwbfile.processing["ophys"].get_data_interface("dff").roi_response_series["traces"].timestamps[:] return cls(timestamps=ts) def validate(self, number_of_frames: int) -> "OphysTimestamps": @@ -58,12 +53,13 @@ def validate(self, number_of_frames: int) -> "OphysTimestamps": self._logger.info( "Truncating acquisition frames ('ophys_frames') " f"(len={num_of_timestamps}) to the number of frames " - f"in the df/f trace ({number_of_frames}).") + f"in the df/f trace ({number_of_frames})." + ) self._value = ophys_timestamps[:number_of_frames] elif number_of_frames > num_of_timestamps: raise RuntimeError( - f"dff_frames (len={number_of_frames}) is longer " - f"than timestamps (len={num_of_timestamps}).") + f"dff_frames (len={number_of_frames}) is longer than timestamps (len={num_of_timestamps})." + ) return self @@ -72,17 +68,16 @@ def __init__(self, timestamps: np.ndarray): super().__init__(timestamps=timestamps) @classmethod - def from_sync_file(cls, sync_file: SyncFile, - group_count: int, - plane_group: int) -> "OphysTimestampsMultiplane": + def from_sync_file(cls, sync_file: SyncFile, group_count: int, plane_group: int) -> "OphysTimestampsMultiplane": if group_count == 0: - raise ValueError('Group count cannot be 0') + raise ValueError("Group count cannot be 0") - ophys_timestamps = sync_file.data['ophys_frames'] + ophys_timestamps = sync_file.data["ophys_frames"] cls._logger.info( "Mesoscope data detected. Splitting timestamps " f"(len={len(ophys_timestamps)} over {group_count} " - "plane group(s).") + "plane group(s)." + ) # Resample if collecting multiple concurrent planes # because the frames are interleaved @@ -101,5 +96,6 @@ def validate(self, number_of_frames: int) -> "OphysTimestampsMultiplane": if number_of_frames != num_of_timestamps: raise RuntimeError( f"dff_frames (len={number_of_frames}) is not equal to " - f"number of split timestamps (len={num_of_timestamps}).") + f"number of split timestamps (len={num_of_timestamps})." + ) return self diff --git a/allensdk/brain_observatory/behavior/data_objects/timestamps/stimulus_timestamps/stimulus_timestamps.py b/allensdk/brain_observatory/behavior/data_objects/timestamps/stimulus_timestamps/stimulus_timestamps.py index 09a6541837..16242580fa 100644 --- a/allensdk/brain_observatory/behavior/data_objects/timestamps/stimulus_timestamps/stimulus_timestamps.py +++ b/allensdk/brain_observatory/behavior/data_objects/timestamps/stimulus_timestamps/stimulus_timestamps.py @@ -4,35 +4,34 @@ from pynwb import NWBFile, ProcessingModule from pynwb.base import TimeSeries -from allensdk.core import \ - LimsReadableInterface, NwbReadableInterface, JsonReadableInterface -from allensdk.brain_observatory.behavior.data_files.sync_file import \ - SyncFileReadableInterface -from allensdk.brain_observatory.behavior.data_files.stimulus_file import \ - StimulusFileReadableInterface +from allensdk.core import LimsReadableInterface, NwbReadableInterface, JsonReadableInterface +from allensdk.brain_observatory.behavior.data_files.sync_file import SyncFileReadableInterface +from allensdk.brain_observatory.behavior.data_files.stimulus_file import StimulusFileReadableInterface from allensdk.core import DataObject from allensdk.brain_observatory.behavior.data_files import ( BehaviorStimulusFile, MappingStimulusFile, ReplayStimulusFile, - SyncFile + SyncFile, ) from allensdk.core import NwbWritableInterface -from allensdk.brain_observatory.behavior.data_objects.timestamps\ - .stimulus_timestamps.timestamps_processing import ( - get_behavior_stimulus_timestamps, get_ophys_stimulus_timestamps) +from allensdk.brain_observatory.behavior.data_objects.timestamps.stimulus_timestamps.timestamps_processing import ( + get_behavior_stimulus_timestamps, + get_ophys_stimulus_timestamps, +) from allensdk.internal.api import PostgresQueryMixin -from allensdk.brain_observatory.sync_stim_aligner import ( - get_stim_timestamps_from_stimulus_blocks) - - -class StimulusTimestamps(DataObject, - StimulusFileReadableInterface, - SyncFileReadableInterface, - NwbReadableInterface, - LimsReadableInterface, - NwbWritableInterface, - JsonReadableInterface): +from allensdk.brain_observatory.sync_stim_aligner import get_stim_timestamps_from_stimulus_blocks + + +class StimulusTimestamps( + DataObject, + StimulusFileReadableInterface, + SyncFileReadableInterface, + NwbReadableInterface, + LimsReadableInterface, + NwbWritableInterface, + JsonReadableInterface, +): """A DataObject which contains properties and methods to load, process, and represent visual behavior stimulus timestamp data. @@ -47,18 +46,14 @@ def __init__( timestamps: np.ndarray, monitor_delay: float, stimulus_file: Optional[BehaviorStimulusFile] = None, - sync_file: Optional[SyncFile] = None + sync_file: Optional[SyncFile] = None, ): - super().__init__(name="stimulus_timestamps", - value=timestamps+monitor_delay) + super().__init__(name="stimulus_timestamps", value=timestamps + monitor_delay) self._stimulus_file = stimulus_file self._sync_file = sync_file self._monitor_delay = monitor_delay - def update_timestamps( - self, - timestamps: np.ndarray - ) -> "StimulusTimestamps": + def update_timestamps(self, timestamps: np.ndarray) -> "StimulusTimestamps": """ Returns newly instantiated `StimulusTimestamps` with `timestamps` @@ -74,7 +69,7 @@ def update_timestamps( timestamps=timestamps, monitor_delay=self._monitor_delay, stimulus_file=self._stimulus_file, - sync_file=self._sync_file + sync_file=self._sync_file, ) def subtract_monitor_delay(self) -> "StimulusTimestamps": @@ -83,50 +78,26 @@ def subtract_monitor_delay(self) -> "StimulusTimestamps": monitor_delay = 0 by subtracting self.monitor_delay from self.value """ - new_value = self.value-self.monitor_delay - return StimulusTimestamps( - timestamps=new_value, - monitor_delay=0.0) + new_value = self.value - self.monitor_delay + return StimulusTimestamps(timestamps=new_value, monitor_delay=0.0) @property def monitor_delay(self) -> float: return self._monitor_delay @classmethod - def from_stimulus_file( - cls, - stimulus_file: BehaviorStimulusFile, - monitor_delay: float) -> "StimulusTimestamps": - stimulus_timestamps = get_behavior_stimulus_timestamps( - stimulus_pkl=stimulus_file.data - ) + def from_stimulus_file(cls, stimulus_file: BehaviorStimulusFile, monitor_delay: float) -> "StimulusTimestamps": + stimulus_timestamps = get_behavior_stimulus_timestamps(stimulus_pkl=stimulus_file.data) - return cls( - timestamps=stimulus_timestamps, - monitor_delay=monitor_delay, - stimulus_file=stimulus_file - ) + return cls(timestamps=stimulus_timestamps, monitor_delay=monitor_delay, stimulus_file=stimulus_file) @classmethod - def from_sync_file( - cls, - sync_file: SyncFile, - monitor_delay: float) -> "StimulusTimestamps": - stimulus_timestamps = get_ophys_stimulus_timestamps( - sync_path=sync_file.filepath - ) - return cls( - timestamps=stimulus_timestamps, - monitor_delay=monitor_delay, - sync_file=sync_file - ) + def from_sync_file(cls, sync_file: SyncFile, monitor_delay: float) -> "StimulusTimestamps": + stimulus_timestamps = get_ophys_stimulus_timestamps(sync_path=sync_file.filepath) + return cls(timestamps=stimulus_timestamps, monitor_delay=monitor_delay, sync_file=sync_file) @classmethod - def from_json( - cls, - dict_repr: dict, - monitor_delay=0.0 - ) -> "StimulusTimestamps": + def from_json(cls, dict_repr: dict, monitor_delay=0.0) -> "StimulusTimestamps": """ Reads timestamps from stimulus file or sync file. Note that `from_multiple_stimulus_blocks` method of constructing @@ -141,29 +112,24 @@ def from_json( ------- StimulusTimestamps """ - if 'sync_file' in dict_repr: + if "sync_file" in dict_repr: sync_file = SyncFile.from_json(dict_repr=dict_repr) - return cls.from_sync_file( - sync_file=sync_file, - monitor_delay=monitor_delay) + return cls.from_sync_file(sync_file=sync_file, monitor_delay=monitor_delay) else: stim_file = BehaviorStimulusFile.from_json(dict_repr=dict_repr) - return cls.from_stimulus_file( - stimulus_file=stim_file, - monitor_delay=monitor_delay) + return cls.from_stimulus_file(stimulus_file=stim_file, monitor_delay=monitor_delay) @classmethod def from_multiple_stimulus_blocks( - cls, - sync_file: SyncFile, - list_of_stims: List[Union[BehaviorStimulusFile, - MappingStimulusFile, - ReplayStimulusFile]], - stims_of_interest: Optional[List[int]] = None, - monitor_delay: float = 0.0, - frame_time_lines: Union[str, List[str]] = 'vsync_stim', - frame_time_line_direction: str = 'rising', - frame_count_tolerance: float = 0.0) -> "StimulusTimestamps": + cls, + sync_file: SyncFile, + list_of_stims: List[Union[BehaviorStimulusFile, MappingStimulusFile, ReplayStimulusFile]], + stims_of_interest: Optional[List[int]] = None, + monitor_delay: float = 0.0, + frame_time_lines: Union[str, List[str]] = "vsync_stim", + frame_time_line_direction: str = "rising", + frame_count_tolerance: float = 0.0, + ) -> "StimulusTimestamps": """ Construct a StimulusTimestamps instance by registering multiple stimulus blocks to one sync file and concatenating the results @@ -200,70 +166,65 @@ def from_multiple_stimulus_blocks( The tolerance to within two blocks of frame counts are considered equal """ - behavior_stimulus_files = [x for x in list_of_stims - if isinstance(x, BehaviorStimulusFile)] + behavior_stimulus_files = [x for x in list_of_stims if isinstance(x, BehaviorStimulusFile)] if len(behavior_stimulus_files) == 0: - raise ValueError( - 'One of the values in `list_of_stims` must be a ' - '`BehaviorStimulusFile`') + raise ValueError("One of the values in `list_of_stims` must be a `BehaviorStimulusFile`") elif len(behavior_stimulus_files) > 1: - raise ValueError('You passed multiple `BehaviorStimulusFile` to ' - '`list_of_stims`. Please pass only 1.') + raise ValueError("You passed multiple `BehaviorStimulusFile` to `list_of_stims`. Please pass only 1.") if stims_of_interest: if len(stims_of_interest) > len(list_of_stims): raise ValueError( - f'stims_of_interest has length {len(stims_of_interest)} ' - f'but list_of_stims has length {len(list_of_stims)}. ' - f'len(stims_of_interest) should be <= len(list_of_stims)') + f"stims_of_interest has length {len(stims_of_interest)} " + f"but list_of_stims has length {len(list_of_stims)}. " + f"len(stims_of_interest) should be <= len(list_of_stims)" + ) if any([x < 0 for x in stims_of_interest]): - raise ValueError('stims_of_interest should not be negative') + raise ValueError("stims_of_interest should not be negative") if any([x >= len(list_of_stims) for x in stims_of_interest]): - raise ValueError('stims_of_interest contains an index ' - 'greater than the number of elements in ' - 'list_of_stims') + raise ValueError( + "stims_of_interest contains an index greater than the number of elements in list_of_stims" + ) stimulus_times = get_stim_timestamps_from_stimulus_blocks( - stimulus_files=list_of_stims, - sync_file=sync_file.filepath, - raw_frame_time_lines=frame_time_lines, - raw_frame_time_direction=frame_time_line_direction, - frame_count_tolerance=frame_count_tolerance) + stimulus_files=list_of_stims, + sync_file=sync_file.filepath, + raw_frame_time_lines=frame_time_lines, + raw_frame_time_direction=frame_time_line_direction, + frame_count_tolerance=frame_count_tolerance, + ) - to_concatenate = \ - [t for t in stimulus_times["timestamps"]] \ - if stims_of_interest is None else \ - [stimulus_times["timestamps"][idx] for idx in stims_of_interest] + to_concatenate = ( + [t for t in stimulus_times["timestamps"]] + if stims_of_interest is None + else [stimulus_times["timestamps"][idx] for idx in stims_of_interest] + ) timestamps = np.concatenate(to_concatenate) - return cls(timestamps=timestamps, - monitor_delay=monitor_delay, - sync_file=sync_file, - stimulus_file=behavior_stimulus_files[0]) + return cls( + timestamps=timestamps, + monitor_delay=monitor_delay, + sync_file=sync_file, + stimulus_file=behavior_stimulus_files[0], + ) def from_lims( cls, db: PostgresQueryMixin, monitor_delay: float, behavior_session_id: int, - ophys_experiment_id: Optional[int] = None + ophys_experiment_id: Optional[int] = None, ) -> "StimulusTimestamps": - stimulus_file = BehaviorStimulusFile.from_lims( - db, - behavior_session_id) + stimulus_file = BehaviorStimulusFile.from_lims(db, behavior_session_id) if ophys_experiment_id: - sync_file = SyncFile.from_lims( - db=db, ophys_experiment_id=ophys_experiment_id) - return cls.from_sync_file(sync_file=sync_file, - monitor_delay=monitor_delay) + sync_file = SyncFile.from_lims(db=db, ophys_experiment_id=ophys_experiment_id) + return cls.from_sync_file(sync_file=sync_file, monitor_delay=monitor_delay) else: - return cls.from_stimulus_file(stimulus_file=stimulus_file, - monitor_delay=monitor_delay) + return cls.from_stimulus_file(stimulus_file=stimulus_file, monitor_delay=monitor_delay) @classmethod - def from_nwb(cls, - nwbfile: NWBFile) -> "StimulusTimestamps": + def from_nwb(cls, nwbfile: NWBFile) -> "StimulusTimestamps": stim_module = nwbfile.processing["stimulus"] stim_ts_interface = stim_module.get_data_interface("timestamps") stim_timestamps = stim_ts_interface.timestamps[:] @@ -271,16 +232,10 @@ def from_nwb(cls, # Because the monitor delay was already applied when # saving the stimulus timestamps to the NWB file, # we set it to zero here. - return cls(timestamps=stim_timestamps, - monitor_delay=0.0) + return cls(timestamps=stim_timestamps, monitor_delay=0.0) def to_nwb(self, nwbfile: NWBFile) -> NWBFile: - stimulus_ts = TimeSeries( - data=self._value, - name="timestamps", - timestamps=self._value, - unit="s" - ) + stimulus_ts = TimeSeries(data=self._value, name="timestamps", timestamps=self._value, unit="s") stim_mod = ProcessingModule("stimulus", "Stimulus Times processing") stim_mod.add_data_interface(stimulus_ts) diff --git a/allensdk/brain_observatory/behavior/data_objects/timestamps/stimulus_timestamps/timestamps_processing.py b/allensdk/brain_observatory/behavior/data_objects/timestamps/stimulus_timestamps/timestamps_processing.py index a4d91677eb..7de125857d 100644 --- a/allensdk/brain_observatory/behavior/data_objects/timestamps/stimulus_timestamps/timestamps_processing.py +++ b/allensdk/brain_observatory/behavior/data_objects/timestamps/stimulus_timestamps/timestamps_processing.py @@ -49,9 +49,7 @@ def get_ophys_stimulus_timestamps(sync_path: Union[str, Path]) -> np.ndarray: return stimulus_timestamps -def get_frame_indices( - frame_timestamps: np.ndarray, - event_timestamps: np.ndarray) -> np.ndarray: +def get_frame_indices(frame_timestamps: np.ndarray, event_timestamps: np.ndarray) -> np.ndarray: """ Given an array of timestamps corresponding to stimulus frames and an array of timestamps corresponding to some event (i.e. @@ -85,21 +83,18 @@ def get_frame_indices( n_frames = len(frame_timestamps) - event_indices = np.searchsorted( - frame_timestamps, - event_timestamps, - side='left') + event_indices = np.searchsorted(frame_timestamps, event_timestamps, side="left") - event_indices = np.clip(event_indices, None, n_frames-1) + event_indices = np.clip(event_indices, None, n_frames - 1) # correct for fact that searchsorted will select as # frame index the first frame time that is larger # than lick_times; we want the lick_times associated # with the last frame that is smaller than the lick_time event_frame_times = frame_timestamps[event_indices] - delta = event_timestamps-event_frame_times - to_decrement = (delta < -1.0e-6) + delta = event_timestamps - event_frame_times + to_decrement = delta < -1.0e-6 event_indices[to_decrement] -= 1 - event_indices = np.clip(event_indices, 0, n_frames-1) + event_indices = np.clip(event_indices, 0, n_frames - 1) return event_indices diff --git a/allensdk/brain_observatory/behavior/data_objects/trials/trial.py b/allensdk/brain_observatory/behavior/data_objects/trials/trial.py index 97e32370d2..b6a615f687 100644 --- a/allensdk/brain_observatory/behavior/data_objects/trials/trial.py +++ b/allensdk/brain_observatory/behavior/data_objects/trials/trial.py @@ -12,17 +12,18 @@ class Trial: def __init__( - self, - trial: dict, - start: float, - end: float, - behavior_stimulus_file: BehaviorStimulusFile, - index: int, - stimulus_timestamps: StimulusTimestamps, - licks: Licks, - rewards: Rewards, - stimuli: dict, - sync_file: Optional[SyncFile] = None): + self, + trial: dict, + start: float, + end: float, + behavior_stimulus_file: BehaviorStimulusFile, + index: int, + stimulus_timestamps: StimulusTimestamps, + licks: Licks, + rewards: Rewards, + stimuli: dict, + sync_file: Optional[SyncFile] = None, + ): """ sync_file is an argument that will be used by sub-classes that have a more subtle way of handling @@ -31,27 +32,21 @@ def __init__( self._trial = trial self._start = start - self._end = self._calculate_trial_end( - trial_end=end, behavior_stimulus_file=behavior_stimulus_file) + self._end = self._calculate_trial_end(trial_end=end, behavior_stimulus_file=behavior_stimulus_file) self._index = index self._stimulus_timestamps = stimulus_timestamps self._sync_file = sync_file self._data = self._match_to_sync_timestamps( - raw_stimulus_timestamps=stimulus_timestamps, - licks=licks, - rewards=rewards, - stimuli=stimuli) + raw_stimulus_timestamps=stimulus_timestamps, licks=licks, rewards=rewards, stimuli=stimuli + ) @property def data(self): return self._data def _match_to_sync_timestamps( - self, - raw_stimulus_timestamps: StimulusTimestamps, - licks: Licks, - rewards: Rewards, - stimuli: dict) -> Dict[str, Any]: + self, raw_stimulus_timestamps: StimulusTimestamps, licks: Licks, rewards: Rewards, stimuli: dict + ) -> Dict[str, Any]: """ raw_stimulus_timestamps include monitor_delay """ @@ -61,15 +56,13 @@ def _match_to_sync_timestamps( stimulus_timestamps = raw_stimulus_timestamps.subtract_monitor_delay() event_dict = { - (e[0], e[1]): { - 'timestamp': stimulus_timestamps.value[e[3]], - 'frame': e[3]} for e in self._trial['events'] + (e[0], e[1]): {"timestamp": stimulus_timestamps.value[e[3]], "frame": e[3]} for e in self._trial["events"] } tr_data = {"trial": self._trial["index"]} - lick_frames = licks.value['frame'].values + lick_frames = licks.value["frame"].values timestamps = stimulus_timestamps.value - reward_times = rewards.value['timestamps'].values + reward_times = rewards.value["timestamps"].values # this block of code is trying to mimic # https://github.com/AllenInstitute/visual_behavior_analysis @@ -93,8 +86,7 @@ def _match_to_sync_timestamps( # licks on the boundary get assigned to the trial that is ending, # rather than the trial that is starting if self._end > 0: - valid_idx = np.where(np.logical_and(lick_frames > self._start, - lick_frames <= self._end)) + valid_idx = np.where(np.logical_and(lick_frames > self._start, lick_frames <= self._end)) else: valid_idx = np.where(lick_frames > self._start) @@ -105,21 +97,21 @@ def _match_to_sync_timestamps( tr_data["lick_times"] = np.array([], dtype=float) tr_data["reward_time"] = self._get_reward_time( - reward_times, - event_dict[('trial_start', '')]['timestamp'], - event_dict[('trial_end', '')]['timestamp'] + reward_times, event_dict[("trial_start", "")]["timestamp"], event_dict[("trial_end", "")]["timestamp"] ) tr_data.update(self._get_trial_data()) - tr_data.update(self._get_trial_timing( - event_dict, - tr_data['lick_times'], - tr_data['go'], - tr_data['catch'], - tr_data['auto_rewarded'], - tr_data['hit'], - tr_data['false_alarm'], - tr_data["aborted"], - )) + tr_data.update( + self._get_trial_timing( + event_dict, + tr_data["lick_times"], + tr_data["go"], + tr_data["catch"], + tr_data["auto_rewarded"], + tr_data["hit"], + tr_data["false_alarm"], + tr_data["aborted"], + ) + ) tr_data.update(self._get_trial_image_names(stimuli)) self._validate_trial_condition_exclusivity(tr_data=tr_data) @@ -127,25 +119,19 @@ def _match_to_sync_timestamps( return tr_data @staticmethod - def _get_reward_time(rebased_reward_times, - start_time, - stop_time) -> float: + def _get_reward_time(rebased_reward_times, start_time, stop_time) -> float: """extract reward times in time range""" - reward_times = rebased_reward_times[np.where(np.logical_and( - rebased_reward_times >= start_time, - rebased_reward_times <= stop_time - ))] - return float('nan') if len(reward_times) == 0 else one( - reward_times) + reward_times = rebased_reward_times[ + np.where(np.logical_and(rebased_reward_times >= start_time, rebased_reward_times <= stop_time)) + ] + return float("nan") if len(reward_times) == 0 else one(reward_times) @staticmethod - def _calculate_trial_end( - trial_end, - behavior_stimulus_file: BehaviorStimulusFile) -> int: + def _calculate_trial_end(trial_end, behavior_stimulus_file: BehaviorStimulusFile) -> int: if trial_end < 0: - bhv = behavior_stimulus_file.data['items']['behavior']['items'] - if 'fingerprint' in bhv.keys(): - trial_end = bhv['fingerprint']['starting_frame'] + bhv = behavior_stimulus_file.data["items"]["behavior"]["items"] + if "fingerprint" in bhv.keys(): + trial_end = bhv["fingerprint"]["starting_frame"] return trial_end def _get_trial_data(self) -> Dict[str, Any]: @@ -178,13 +164,13 @@ def _get_trial_data(self) -> Dict[str, Any]: This will bias the animals choice and should not be categorized as hit/miss) """ - trial_event_names = [val[0] for val in self._trial['events']] - hit = 'hit' in trial_event_names - false_alarm = 'false_alarm' in trial_event_names - miss = 'miss' in trial_event_names - sham_change = 'sham_change' in trial_event_names - stimulus_change = 'stimulus_changed' in trial_event_names - aborted = 'abort' in trial_event_names + trial_event_names = [val[0] for val in self._trial["events"]] + hit = "hit" in trial_event_names + false_alarm = "false_alarm" in trial_event_names + miss = "miss" in trial_event_names + sham_change = "sham_change" in trial_event_names + stimulus_change = "stimulus_changed" in trial_event_names + aborted = "abort" in trial_event_names if aborted: go = catch = auto_rewarded = False @@ -199,8 +185,7 @@ def _get_trial_data(self) -> Dict[str, Any]: hit = miss = correct_reject = false_alarm = False return { - "reward_volume": sum([ - r[0] for r in self._trial.get("rewards", [])]), + "reward_volume": sum([r[0] for r in self._trial.get("rewards", [])]), "hit": hit, "false_alarm": false_alarm, "miss": miss, @@ -214,15 +199,16 @@ def _get_trial_data(self) -> Dict[str, Any]: } def _get_trial_timing( - self, - event_dict: dict, - licks: List[float], - go: bool, - catch: bool, - auto_rewarded: bool, - hit: bool, - false_alarm: bool, - aborted: bool) -> Dict[str, Any]: + self, + event_dict: dict, + licks: List[float], + go: bool, + catch: bool, + auto_rewarded: bool, + hit: bool, + false_alarm: bool, + aborted: bool, + ) -> Dict[str, Any]: """ Extract a dictionary of trial timing data. See trial_data_from_log for a description of the trial types. @@ -281,18 +267,15 @@ def _get_trial_timing( hit, miss, false_alarm, aborted, auto_rewarded """ assert not (aborted and (hit or false_alarm or auto_rewarded)), ( - "'aborted' trials cannot be 'hit', 'false_alarm', " - "or 'auto_rewarded'") + "'aborted' trials cannot be 'hit', 'false_alarm', or 'auto_rewarded'" + ) assert not (hit and false_alarm), ( - "both `hit` and `false_alarm` cannot be True, they are mutually " - "exclusive categories") - assert not (go and catch), ( - "both `go` and `catch` cannot be True, they are mutually " - "exclusive " - "categories") + "both `hit` and `false_alarm` cannot be True, they are mutually exclusive categories" + ) + assert not (go and catch), "both `go` and `catch` cannot be True, they are mutually exclusive categories" assert not (go and auto_rewarded), ( - "both `go` and `auto_rewarded` cannot be True, they are mutually " - "exclusive categories") + "both `go` and `auto_rewarded` cannot be True, they are mutually exclusive categories" + ) def _get_response_time(licks: List[float], aborted: bool) -> float: """ @@ -307,23 +290,21 @@ def _get_response_time(licks: List[float], aborted: bool) -> float: else: return float("nan") - start_time = event_dict["trial_start", ""]['timestamp'] - stop_time = event_dict["trial_end", ""]['timestamp'] + start_time = event_dict["trial_start", ""]["timestamp"] + stop_time = event_dict["trial_end", ""]["timestamp"] response_time = _get_response_time(licks, aborted) change_frame = self.calculate_change_frame( - event_dict=event_dict, - go=go, - catch=catch, - auto_rewarded=auto_rewarded) + event_dict=event_dict, go=go, catch=catch, auto_rewarded=auto_rewarded + ) result = { "start_time": start_time, "stop_time": stop_time, "trial_length": stop_time - start_time, "response_time": response_time, - "change_frame": change_frame + "change_frame": change_frame, } result, change_time = self.add_change_time(result) @@ -339,13 +320,7 @@ def _get_response_time(licks: List[float], aborted: bool) -> float: return result - def calculate_change_frame( - self, - event_dict: dict, - go: bool, - catch: bool, - auto_rewarded: bool) -> Union[int, float]: - + def calculate_change_frame(self, event_dict: dict, go: bool, catch: bool, auto_rewarded: bool) -> Union[int, float]: """ Calculate the frame index of a stimulus change associated with a specific event. @@ -375,9 +350,9 @@ def calculate_change_frame( """ if go or auto_rewarded: - change_frame = event_dict.get(('stimulus_changed', ''))['frame'] + change_frame = event_dict.get(("stimulus_changed", ""))["frame"] elif catch: - change_frame = event_dict.get(('sham_change', ''))['frame'] + change_frame = event_dict.get(("sham_change", ""))["frame"] else: change_frame = float("nan") @@ -413,14 +388,14 @@ def add_change_time(self, trial_dict: dict) -> Tuple[dict, float]: ---- Modified trial_dict in-place, in addition to returning it """ - change_frame = trial_dict['change_frame'] + change_frame = trial_dict["change_frame"] if np.isnan(change_frame): change_time = np.nan else: change_frame = int(change_frame) change_time = self._stimulus_timestamps.value[change_frame] - trial_dict['change_time'] = change_time + trial_dict["change_time"] = change_time return trial_dict, change_time def _get_trial_image_names(self, stimuli) -> Dict[str, str]: @@ -439,30 +414,23 @@ def _get_trial_image_names(self, stimuli) -> Dict[str, str]: changed to. """ - grating_oris = {'horizontal', 'vertical'} + grating_oris = {"horizontal", "vertical"} trial_start_frame = self._trial["events"][0][3] - initial_image_category_name, _, initial_image_name = \ - self._resolve_initial_image( - stimuli, trial_start_frame) + initial_image_category_name, _, initial_image_name = self._resolve_initial_image(stimuli, trial_start_frame) if len(self._trial["stimulus_changes"]) == 0: change_image_name = initial_image_name else: - ((from_set, from_name), - (to_set, to_name), - _, _) = self._trial["stimulus_changes"][0] + ((from_set, from_name), (to_set, to_name), _, _) = self._trial["stimulus_changes"][0] # do this to fix names if the stimuli is a grating if from_set in grating_oris: - from_name = f'gratings_{from_name}' + from_name = f"gratings_{from_name}" if to_set in grating_oris: - to_name = f'gratings_{to_name}' + to_name = f"gratings_{to_name}" assert from_name == initial_image_name change_image_name = to_name - return { - "initial_image_name": initial_image_name, - "change_image_name": change_image_name - } + return {"initial_image_name": initial_image_name, "change_image_name": change_image_name} @staticmethod def _resolve_initial_image(stimuli, start_frame) -> Tuple[str, str, str]: @@ -486,9 +454,9 @@ def _resolve_initial_image(stimuli, start_frame) -> Tuple[str, str, str]: name of the initial image """ max_frame = float("-inf") - initial_image_group = '' - initial_image_name = '' - initial_image_category_name = '' + initial_image_group = "" + initial_image_name = "" + initial_image_category_name = "" for stim_category_name, stim_dict in stimuli.items(): for set_event in stim_dict["set_log"]: @@ -498,23 +466,17 @@ def _resolve_initial_image(stimuli, start_frame) -> Tuple[str, str, str]: # only initial_image_name is present for natual_scenes initial_image_group = initial_image_name = set_event[1] initial_image_category_name = stim_category_name - if initial_image_category_name == 'grating': - initial_image_name = f'gratings_{initial_image_name}' + if initial_image_category_name == "grating": + initial_image_name = f"gratings_{initial_image_name}" max_frame = set_frame - return initial_image_category_name, initial_image_group, \ - initial_image_name + return initial_image_category_name, initial_image_group, initial_image_name def _validate_trial_condition_exclusivity(self, tr_data: dict) -> None: """ensure that only one of N possible mutually exclusive trial conditions is True""" trial_conditions = {} - for key in ['hit', - 'miss', - 'false_alarm', - 'correct_reject', - 'auto_rewarded', - 'aborted']: + for key in ["hit", "miss", "false_alarm", "correct_reject", "auto_rewarded", "aborted"]: trial_conditions[key] = tr_data[key] on = [] @@ -524,7 +486,6 @@ def _validate_trial_condition_exclusivity(self, tr_data: dict) -> None: if len(on) != 1: all_conditions = list(trial_conditions.keys()) - msg = f"expected exactly 1 trial condition out of " \ - f"{all_conditions} " + msg = f"expected exactly 1 trial condition out of {all_conditions} " msg += f"to be True, instead {on} were True (trial {self._index})" raise AssertionError(msg) diff --git a/allensdk/brain_observatory/behavior/data_objects/trials/trials.py b/allensdk/brain_observatory/behavior/data_objects/trials/trials.py index 176d86d492..6ec2f8e263 100644 --- a/allensdk/brain_observatory/behavior/data_objects/trials/trials.py +++ b/allensdk/brain_observatory/behavior/data_objects/trials/trials.py @@ -128,9 +128,7 @@ def to_nwb(self, nwbfile: NWBFile) -> NWBFile: index=index, ) else: - nwbfile.add_trial_column( - name=c, description="NOT IMPLEMENTED: %s" % c, data=data - ) + nwbfile.add_trial_column(name=c, description="NOT IMPLEMENTED: %s" % c, data=data) return nwbfile @classmethod @@ -141,9 +139,7 @@ def from_nwb(cls, nwbfile: NWBFile) -> "Trials": trials.index = trials.index.rename("trials_id") return cls( trials=trials, - response_window_start=TaskParameters.from_nwb( - nwbfile=nwbfile - ).response_window_sec[0], + response_window_start=TaskParameters.from_nwb(nwbfile=nwbfile).response_window_sec[0], ) @classmethod @@ -217,9 +213,7 @@ def from_stimulus_file( return cls( trials=trials, - response_window_start=TaskParameters.from_stimulus_file( - stimulus_file=stimulus_file - ).response_window_sec[0], + response_window_start=TaskParameters.from_stimulus_file(stimulus_file=stimulus_file).response_window_sec[0], ) @staticmethod @@ -355,25 +349,15 @@ def rolling_performance(self) -> pd.DataFrame: performance_metrics_df = pd.DataFrame(index=trials_index) # Reward rate: - performance_metrics_df["reward_rate"] = pd.Series( - reward_rate, index=self.data.index - ) + performance_metrics_df["reward_rate"] = pd.Series(reward_rate, index=self.data.index) # Hit rate raw: - hit_rate_raw = get_hit_rate( - hit=self.hit, miss=self.miss, aborted=self.aborted - ) - performance_metrics_df["hit_rate_raw"] = pd.Series( - hit_rate_raw, index=not_aborted_index - ) + hit_rate_raw = get_hit_rate(hit=self.hit, miss=self.miss, aborted=self.aborted) + performance_metrics_df["hit_rate_raw"] = pd.Series(hit_rate_raw, index=not_aborted_index) # Hit rate with trial count correction: - hit_rate = get_trial_count_corrected_hit_rate( - hit=self.hit, miss=self.miss, aborted=self.aborted - ) - performance_metrics_df["hit_rate"] = pd.Series( - hit_rate, index=not_aborted_index - ) + hit_rate = get_trial_count_corrected_hit_rate(hit=self.hit, miss=self.miss, aborted=self.aborted) + performance_metrics_df["hit_rate"] = pd.Series(hit_rate, index=not_aborted_index) # False-alarm rate raw: false_alarm_rate_raw = get_false_alarm_rate( @@ -381,9 +365,7 @@ def rolling_performance(self) -> pd.DataFrame: correct_reject=self.correct_reject, aborted=self.aborted, ) - performance_metrics_df["false_alarm_rate_raw"] = pd.Series( - false_alarm_rate_raw, index=not_aborted_index - ) + performance_metrics_df["false_alarm_rate_raw"] = pd.Series(false_alarm_rate_raw, index=not_aborted_index) # False-alarm rate with trial count correction: false_alarm_rate = get_trial_count_corrected_false_alarm_rate( @@ -391,9 +373,7 @@ def rolling_performance(self) -> pd.DataFrame: correct_reject=self.correct_reject, aborted=self.aborted, ) - performance_metrics_df["false_alarm_rate"] = pd.Series( - false_alarm_rate, index=not_aborted_index - ) + performance_metrics_df["false_alarm_rate"] = pd.Series(false_alarm_rate, index=not_aborted_index) # Rolling-dprime: is_passive_session = (self.data["reward_volume"] == 0).all() and ( @@ -405,9 +385,7 @@ def rolling_performance(self) -> pd.DataFrame: rolling_dprime = np.zeros(len(hit_rate)) else: rolling_dprime = get_rolling_dprime(hit_rate, false_alarm_rate) - performance_metrics_df["rolling_dprime"] = pd.Series( - rolling_dprime, index=not_aborted_index - ) + performance_metrics_df["rolling_dprime"] = pd.Series(rolling_dprime, index=not_aborted_index) return performance_metrics_df @@ -432,36 +410,25 @@ def _calculate_response_latency_list(self) -> List: (the two instance of monitor delay cancel out in the difference). """ - df = pd.DataFrame( - {"lick_times": self.lick_times, "change_time": self.change_time} - ) + df = pd.DataFrame({"lick_times": self.lick_times, "change_time": self.change_time}) df["valid_response_licks"] = df.apply( - lambda trial: [ - lt - for lt in trial["lick_times"] - if lt - trial["change_time"] > self._response_window_start - ], + lambda trial: [lt for lt in trial["lick_times"] if lt - trial["change_time"] > self._response_window_start], axis=1, ) response_latency = df.apply( - lambda trial: trial["valid_response_licks"][0] - - trial["change_time"] + lambda trial: trial["valid_response_licks"][0] - trial["change_time"] if len(trial["valid_response_licks"]) > 0 else float("inf"), axis=1, ) return response_latency.tolist() - def calculate_reward_rate( - self, window=0.75, trial_window=25, initial_trials=10 - ): + def calculate_reward_rate(self, window=0.75, trial_window=25, initial_trials=10): response_latency = self._calculate_response_latency_list() starttime = self.start_time.values assert len(response_latency) == len(starttime) - df = pd.DataFrame( - {"response_latency": response_latency, "starttime": starttime} - ) + df = pd.DataFrame({"response_latency": response_latency, "starttime": starttime}) # adds a column called reward_rate to the input dataframe # the reward_rate column contains a rolling average of rewards/min @@ -485,9 +452,7 @@ def calculate_reward_rate( correct = len(df_roll[df_roll.response_latency < window]) # get the time elapsed over the trials - time_elapsed = ( - df_roll.starttime.iloc[-1] - df_roll.starttime.iloc[0] - ) + time_elapsed = df_roll.starttime.iloc[-1] - df_roll.starttime.iloc[0] # calculate the reward rate, rewards/min reward_rate_on_this_lap = correct / time_elapsed * 60 @@ -497,9 +462,7 @@ def calculate_reward_rate( reward_rate[np.isinf(reward_rate)] = float("nan") return reward_rate - def _get_engaged_trials( - self, engaged_trial_reward_rate_threshold: float = 2.0 - ) -> pd.Series: + def _get_engaged_trials(self, engaged_trial_reward_rate_threshold: float = 2.0) -> pd.Series: """ Gets `Series` where each trial that is considered "engaged" is set to `True` @@ -515,15 +478,10 @@ def _get_engaged_trials( `pd.Series` """ rolling_performance = self.rolling_performance - engaged_trial_mask = ( - rolling_performance["reward_rate"] - > engaged_trial_reward_rate_threshold - ) + engaged_trial_mask = rolling_performance["reward_rate"] > engaged_trial_reward_rate_threshold return engaged_trial_mask - def get_engaged_trial_count( - self, engaged_trial_reward_rate_threshold: float = 2.0 - ) -> int: + def get_engaged_trial_count(self, engaged_trial_reward_rate_threshold: float = 2.0) -> int: """Gets count of trials considered "engaged" Parameters @@ -537,8 +495,6 @@ def get_engaged_trial_count( count of trials considered "engaged" """ engaged_trials = self._get_engaged_trials( - engaged_trial_reward_rate_threshold=( - engaged_trial_reward_rate_threshold - ) + engaged_trial_reward_rate_threshold=(engaged_trial_reward_rate_threshold) ) return engaged_trials.sum() diff --git a/allensdk/brain_observatory/behavior/dprime.py b/allensdk/brain_observatory/behavior/dprime.py index 0c6e41ec3d..6dad8f2f73 100644 --- a/allensdk/brain_observatory/behavior/dprime.py +++ b/allensdk/brain_observatory/behavior/dprime.py @@ -21,37 +21,17 @@ def get_go_responses(hit=None, miss=None, aborted=None): return go_responses -def get_hit_rate( - hit=None, miss=None, aborted=None, sliding_window=SLIDING_WINDOW -): +def get_hit_rate(hit=None, miss=None, aborted=None, sliding_window=SLIDING_WINDOW): go_responses = get_go_responses(hit=hit, miss=miss, aborted=aborted) - hit_rate = ( - pd.Series(go_responses) - .rolling(window=sliding_window, min_periods=0) - .mean() - .values - ) + hit_rate = pd.Series(go_responses).rolling(window=sliding_window, min_periods=0).mean().values return hit_rate -def get_trial_count_corrected_hit_rate( - hit=None, miss=None, aborted=None, sliding_window=SLIDING_WINDOW -): +def get_trial_count_corrected_hit_rate(hit=None, miss=None, aborted=None, sliding_window=SLIDING_WINDOW): go_responses = get_go_responses(hit=hit, miss=miss, aborted=aborted) - go_responses_count = ( - pd.Series(go_responses) - .rolling(window=sliding_window, min_periods=0) - .count() - ) - hit_rate = ( - pd.Series(go_responses) - .rolling(window=sliding_window, min_periods=0) - .mean() - .values - ) - trial_count_corrected_hit_rate = np.vectorize(trial_number_limit)( - hit_rate, go_responses_count - ) + go_responses_count = pd.Series(go_responses).rolling(window=sliding_window, min_periods=0).count() + hit_rate = pd.Series(go_responses).rolling(window=sliding_window, min_periods=0).mean().values + trial_count_corrected_hit_rate = np.vectorize(trial_number_limit)(hit_rate, go_responses_count) return trial_count_corrected_hit_rate @@ -77,15 +57,8 @@ def get_false_alarm_rate( aborted=None, sliding_window=SLIDING_WINDOW, ): - catch_responses = get_catch_responses( - correct_reject=correct_reject, false_alarm=false_alarm, aborted=aborted - ) - false_alarm_rate = ( - pd.Series(catch_responses) - .rolling(window=sliding_window, min_periods=0) - .mean() - .values - ) + catch_responses = get_catch_responses(correct_reject=correct_reject, false_alarm=false_alarm, aborted=aborted) + false_alarm_rate = pd.Series(catch_responses).rolling(window=sliding_window, min_periods=0).mean().values return false_alarm_rate @@ -95,34 +68,16 @@ def get_trial_count_corrected_false_alarm_rate( aborted=None, sliding_window=SLIDING_WINDOW, ): - catch_responses = get_catch_responses( - correct_reject=correct_reject, false_alarm=false_alarm, aborted=aborted - ) - catch_responses_count = ( - pd.Series(catch_responses) - .rolling(window=sliding_window, min_periods=0) - .count() - ) - false_alarm_rate = ( - pd.Series(catch_responses) - .rolling(window=sliding_window, min_periods=0) - .mean() - .values - ) - trial_count_corrected_false_alarm_rate = np.vectorize(trial_number_limit)( - false_alarm_rate, catch_responses_count - ) + catch_responses = get_catch_responses(correct_reject=correct_reject, false_alarm=false_alarm, aborted=aborted) + catch_responses_count = pd.Series(catch_responses).rolling(window=sliding_window, min_periods=0).count() + false_alarm_rate = pd.Series(catch_responses).rolling(window=sliding_window, min_periods=0).mean().values + trial_count_corrected_false_alarm_rate = np.vectorize(trial_number_limit)(false_alarm_rate, catch_responses_count) return trial_count_corrected_false_alarm_rate -def get_rolling_dprime( - rolling_hit_rate, rolling_fa_rate, sliding_window=SLIDING_WINDOW -): +def get_rolling_dprime(rolling_hit_rate, rolling_fa_rate, sliding_window=SLIDING_WINDOW): return np.array( - [ - get_dprime(hr, far, sliding_window=SLIDING_WINDOW) - for hr, far in zip(rolling_hit_rate, rolling_fa_rate) - ] + [get_dprime(hr, far, sliding_window=SLIDING_WINDOW) for hr, far in zip(rolling_hit_rate, rolling_fa_rate)] ) diff --git a/allensdk/brain_observatory/behavior/event_detection.py b/allensdk/brain_observatory/behavior/event_detection.py index 118d8ae632..d8fd07c70b 100644 --- a/allensdk/brain_observatory/behavior/event_detection.py +++ b/allensdk/brain_observatory/behavior/event_detection.py @@ -2,8 +2,7 @@ from scipy import stats -def filter_events_array(arr: np.ndarray, scale: float = 2, - n_time_steps: int = 20) -> np.ndarray: +def filter_events_array(arr: np.ndarray, scale: float = 2, n_time_steps: int = 20) -> np.ndarray: """ Convolve the trace array with a 1d causal half-gaussian filter to smooth it for visualization @@ -27,15 +26,14 @@ def filter_events_array(arr: np.ndarray, scale: float = 2, Output of the convolution operation """ if len(arr.shape) == 1: - raise ValueError('Expected a 2d array but received a 1d array') + raise ValueError("Expected a 2d array but received a 1d array") if n_time_steps < 1: - raise ValueError(f'n_time_steps must be a minimum of 1 but received ' - f'{n_time_steps}') + raise ValueError(f"n_time_steps must be a minimum of 1 but received {n_time_steps}") filt = stats.halfnorm(loc=0, scale=scale).pdf(np.arange(n_time_steps)) filt = filt / np.sum(filt) # normalize filter filtered_arr = np.zeros(arr.shape) for i, trace in enumerate(arr): - filtered_arr[i] = np.convolve(arr[i], filt)[:len(arr[i])] + filtered_arr[i] = np.convolve(arr[i], filt)[: len(arr[i])] return filtered_arr diff --git a/allensdk/brain_observatory/behavior/eye_tracking_processing.py b/allensdk/brain_observatory/behavior/eye_tracking_processing.py index 91772b1847..f9fda52283 100644 --- a/allensdk/brain_observatory/behavior/eye_tracking_processing.py +++ b/allensdk/brain_observatory/behavior/eye_tracking_processing.py @@ -38,13 +38,12 @@ def load_eye_tracking_hdf(eye_tracking_file: Path) -> pd.DataFrame: eye_tracking_dfs = [] for field_name in eye_tracking_fields: field_data = pd.read_hdf(eye_tracking_file, key=field_name) - new_col_name_map = {col_name: f"{field_name}_{col_name}" - for col_name in field_data.columns} + new_col_name_map = {col_name: f"{field_name}_{col_name}" for col_name in field_data.columns} field_data.rename(new_col_name_map, axis=1, inplace=True) eye_tracking_dfs.append(field_data) eye_tracking_data = pd.concat(eye_tracking_dfs, axis=1) - eye_tracking_data.index.name = 'frame' + eye_tracking_data.index.name = "frame" # Values in the hdf5 may be complex (likely an artifact of the ellipse # fitting process). Take only the real component. @@ -53,8 +52,7 @@ def load_eye_tracking_hdf(eye_tracking_file: Path) -> pd.DataFrame: return eye_tracking_data.astype(float) -def determine_outliers(data_df: pd.DataFrame, - z_threshold: float) -> pd.Series: +def determine_outliers(data_df: pd.DataFrame, z_threshold: float) -> pd.Series: """Given a dataframe and some z-score threshold return a pandas boolean Series where each entry indicates whether a given row contains at least one outlier (where outliers are calculated along columns). @@ -74,8 +72,7 @@ def determine_outliers(data_df: pd.DataFrame, True denotes that a row in the data_df contains at least one outlier. """ - outliers = data_df.apply(stats.zscore, - nan_policy='omit').apply(np.abs) > z_threshold + outliers = data_df.apply(stats.zscore, nan_policy="omit").apply(np.abs) > z_threshold return pd.Series(outliers.any(axis=1)) @@ -122,10 +119,9 @@ def compute_elliptical_area(df_row: pd.Series) -> float: return np.pi * df_row.iloc[0] * df_row.iloc[1] -def determine_likely_blinks(eye_areas: pd.Series, - pupil_areas: pd.Series, - outliers: pd.Series, - dilation_frames: int = 2) -> pd.Series: +def determine_likely_blinks( + eye_areas: pd.Series, pupil_areas: pd.Series, outliers: pd.Series, dilation_frames: int = 2 +) -> pd.Series: """Determine eye tracking frames which contain likely blinks or outliers Parameters @@ -148,17 +144,15 @@ def determine_likely_blinks(eye_areas: pd.Series, """ blinks = pd.isnull(eye_areas) | pd.isnull(pupil_areas) | outliers if dilation_frames > 0: - likely_blinks = ndimage.binary_dilation(blinks, - iterations=dilation_frames) + likely_blinks = ndimage.binary_dilation(blinks, iterations=dilation_frames) else: likely_blinks = blinks return pd.Series(likely_blinks, index=eye_areas.index) -def process_eye_tracking_data(eye_data: pd.DataFrame, - frame_times: pd.Series, - z_threshold: float = 3.0, - dilation_frames: int = 2) -> pd.DataFrame: +def process_eye_tracking_data( + eye_data: pd.DataFrame, frame_times: pd.Series, z_threshold: float = 3.0, dilation_frames: int = 2 +) -> pd.DataFrame: """Processes and refines raw eye tracking data by adding additional computed feature columns. @@ -205,26 +199,22 @@ def process_eye_tracking_data(eye_data: pd.DataFrame, n_sync = len(frame_times) if n_sync != n_eye_frames: - raise EyeTrackingError(f"Error! The number of sync file frame times " - f"({len(frame_times)}) does not match the " - f"number of eye tracking frames " - f"({len(eye_data.index)})!") - - cr_areas = (eye_data[["cr_width", "cr_height"]] - .apply(compute_elliptical_area, axis=1)) - eye_areas = (eye_data[["eye_width", "eye_height"]] - .apply(compute_elliptical_area, axis=1)) - pupil_areas = (eye_data[["pupil_width", "pupil_height"]] - .apply(compute_circular_area, axis=1)) + raise EyeTrackingError( + f"Error! The number of sync file frame times " + f"({len(frame_times)}) does not match the " + f"number of eye tracking frames " + f"({len(eye_data.index)})!" + ) + + cr_areas = eye_data[["cr_width", "cr_height"]].apply(compute_elliptical_area, axis=1) + eye_areas = eye_data[["eye_width", "eye_height"]].apply(compute_elliptical_area, axis=1) + pupil_areas = eye_data[["pupil_width", "pupil_height"]].apply(compute_circular_area, axis=1) # only use eye and pupil areas for outlier detection area_df = pd.concat([eye_areas, pupil_areas], axis=1) outliers = determine_outliers(area_df, z_threshold=z_threshold) - likely_blinks = determine_likely_blinks(eye_areas, - pupil_areas, - outliers, - dilation_frames=dilation_frames) + likely_blinks = determine_likely_blinks(eye_areas, pupil_areas, outliers, dilation_frames=dilation_frames) # remove outliers/likely blinks `pupil_area`, `cr_area`, `eye_area` pupil_areas_raw = pupil_areas.copy() diff --git a/allensdk/brain_observatory/behavior/image_api.py b/allensdk/brain_observatory/behavior/image_api.py index 8418d1bcdf..e206e7f98c 100644 --- a/allensdk/brain_observatory/behavior/image_api.py +++ b/allensdk/brain_observatory/behavior/image_api.py @@ -4,7 +4,7 @@ class Image(NamedTuple): - ''' Describes a 2D Image + """Describes a 2D Image data : np.ndarray Image data points @@ -12,11 +12,11 @@ class Image(NamedTuple): Spacing describes the physical size of each pixel unit : str Physical unit of the spacing (currently constrained to be isotropic) - ''' + """ data: np.ndarray spacing: tuple - unit: str = 'mm' + unit: str = "mm" def __eq__(self, other): a = np.array_equal(self.data, other.data) @@ -29,17 +29,16 @@ def __array__(self): class ImageApi: - @staticmethod def serialize(data, spacing, unit): img = sitk.GetImageFromArray(data) img.SetSpacing(np.array(spacing, dtype=np.double)) - img.SetMetaData('unit', unit) + img.SetMetaData("unit", unit) return img @staticmethod def deserialize(img): data = sitk.GetArrayFromImage(img) spacing = img.GetSpacing() - unit = img.GetMetaData('unit') + unit = img.GetMetaData("unit") return Image(data, spacing, unit) diff --git a/allensdk/brain_observatory/behavior/mtrain.py b/allensdk/brain_observatory/behavior/mtrain.py index 002e4bf8b1..1de3b0fe18 100644 --- a/allensdk/brain_observatory/behavior/mtrain.py +++ b/allensdk/brain_observatory/behavior/mtrain.py @@ -4,7 +4,7 @@ def annotate_change_detect(trials): - """ adds `change` and `detect` columns to dataframe + """adds `change` and `detect` columns to dataframe Parameters ---------- @@ -18,14 +18,14 @@ def annotate_change_detect(trials): io.load_trials """ - trials['change'] = trials['trial_type'] == 'go' - trials['detect'] = trials['response'] == 1.0 + trials["change"] = trials["trial_type"] == "go" + trials["detect"] = trials["response"] == 1.0 return trials def assign_session_id(trials): - """ adds a column with a unique ID for the session defined as + """adds a column with a unique ID for the session defined as a combination of the mouse ID and startdatetime Parameters @@ -39,16 +39,13 @@ def assign_session_id(trials): -------- io.load_trials """ - trials['session_id'] = (trials['mouse_id'] - + '_' - + trials['startdatetime'].map( - lambda x: x.isoformat())) + trials["session_id"] = trials["mouse_id"] + "_" + trials["startdatetime"].map(lambda x: x.isoformat()) return trials def fix_change_time(trials): - """ forces `None` values in the `change_time` column to numpy NaN + """forces `None` values in the `change_time` column to numpy NaN Parameters ---------- @@ -61,14 +58,13 @@ def fix_change_time(trials): -------- io.load_trials """ - trials['change_time'] = trials['change_time'].map( - lambda x: np.nan if x is None else x) + trials["change_time"] = trials["change_time"].map(lambda x: np.nan if x is None else x) return trials def explode_response_window(trials): - """ explodes the `response_window` column in lower & upper columns + """explodes the `response_window` column in lower & upper columns Parameters ---------- @@ -81,16 +77,14 @@ def explode_response_window(trials): -------- io.load_trials """ - trials['response_window_lower'] = \ - trials['response_window'].map(lambda x: x[0]) - trials['response_window_upper'] = \ - trials['response_window'].map(lambda x: x[1]) + trials["response_window_lower"] = trials["response_window"].map(lambda x: x[0]) + trials["response_window_upper"] = trials["response_window"].map(lambda x: x[1]) return trials def annotate_trials(trials): - """ performs multiple annotatations: + """performs multiple annotatations: - annotate_change_detect - fix_change_time @@ -144,23 +138,23 @@ class ExtendedTrialSchema(Schema): """ index = fields.Int( - description='Trial number in this session', + description="Trial number in this session", required=True, ) startframe = fields.Int( - description='frame when this trial starts', + description="frame when this trial starts", required=True, ) starttime = fields.Float( - description='time in seconds when this trial starts', + description="time in seconds when this trial starts", required=True, ) endframe = fields.Int( - description='frame when this trial ends', + description="frame when this trial ends", required=True, ) endtime = fields.Float( - description='time in seconds when this trial ends', + description="time in seconds when this trial ends", required=True, ) trial_length = fields.Float( @@ -169,69 +163,66 @@ class ExtendedTrialSchema(Schema): # timing paramters change_frame = fields.Float( - description='The stimulus frame when the change occured on this trial', + description="The stimulus frame when the change occured on this trial", required=True, allow_nan=True, ) scheduled_change_time = fields.Float( - description=("The time when the change was scheduled to occur " - "on this trial"), + description=("The time when the change was scheduled to occur on this trial"), required=True, ) change_time = fields.Float( - description='The time when the change occured on this trial', + description="The time when the change occured on this trial", required=True, allow_nan=True, ) # image parameters initial_image_category = fields.String( - description='The category of the initial images on this trial', + description="The category of the initial images on this trial", required=True, allow_none=True, ) initial_image_name = fields.String( - description=("The name of the last initial image before the " - "change on this trial"), + description=("The name of the last initial image before the change on this trial"), required=True, allow_none=True, ) change_image_category = fields.String( - description='The category of the change images on this trial', + description="The category of the change images on this trial", required=True, allow_none=True, ) change_image_name = fields.String( - description='The name of the first change image on this trial', + description="The name of the first change image on this trial", required=True, allow_none=True, ) # oriented gratings paramters initial_contrast = fields.Float( - description='The contrast of the initial orientation on this trial', + description="The contrast of the initial orientation on this trial", required=True, allow_none=True, ) change_contrast = fields.Float( - description='The contrast of the change orientation on this trial', + description="The contrast of the change orientation on this trial", required=True, allow_none=True, ) initial_ori = fields.Float( - description='The orientation of the initial orientation on this trial', + description="The orientation of the initial orientation on this trial", required=True, allow_none=True, ) change_ori = fields.Float( - description='The orientation of the change orientation on this trial', + description="The orientation of the change orientation on this trial", required=True, allow_none=True, allow_nan=True, ) delta_ori = fields.Float( - description=("The difference between the initial and change " - "orientations on this trial"), + description=("The difference between the initial and change orientations on this trial"), required=True, allow_none=True, ) @@ -239,18 +230,17 @@ class ExtendedTrialSchema(Schema): # licks lick_times = fields.List( fields.Float, - description='times of licks on this trial', + description="times of licks on this trial", required=True, ) response_latency = fields.Float( - description=("The latency between the change and the first lick " - "on this trial"), + description=("The latency between the change and the first lick on this trial"), required=True, allow_nan=True, ) response_time = fields.List( fields.Float, - description='need to check this with Doug', + description="need to check this with Doug", required=True, ) reward_frames = fields.List( @@ -269,24 +259,22 @@ class ExtendedTrialSchema(Schema): ) auto_rewarded = fields.Bool( - description='whether this trial was an auto_rewarded trial', + description="whether this trial was an auto_rewarded trial", required=True, allow_none=True, ) cumulative_reward_number = fields.Int( - description=("the cumulative number of rewards in the session at " - "trial end"), + description=("the cumulative number of rewards in the session at trial end"), required=True, ) cumulative_volume = fields.Float( - description='the total volume of rewards in the session at trial end', + description="the total volume of rewards in the session at trial end", required=True, ) # optogenetics optogenetics = fields.Bool( - description=("whether optogenetic stimulation was applied " - "on this trial"), + description=("whether optogenetic stimulation was applied on this trial"), required=True, ) @@ -385,9 +373,7 @@ class ExtendedTrialSchema(Schema): date = FriendlyDate( required=True, ) - year = fields.Integer( - strict=True - ) + year = fields.Integer(strict=True) month = fields.Integer( required=True, strict=True, diff --git a/allensdk/brain_observatory/behavior/rewards_processing.py b/allensdk/brain_observatory/behavior/rewards_processing.py index 81b66bb9dc..cc8b3b722f 100644 --- a/allensdk/brain_observatory/behavior/rewards_processing.py +++ b/allensdk/brain_observatory/behavior/rewards_processing.py @@ -3,8 +3,7 @@ import pandas as pd -def get_rewards(data: Dict, - timestamps: np.ndarray) -> pd.DataFrame: +def get_rewards(data: Dict, timestamps: np.ndarray) -> pd.DataFrame: """ Construct and return a pandas DataFrame containing reward data for this session diff --git a/allensdk/brain_observatory/behavior/schemas.py b/allensdk/brain_observatory/behavior/schemas.py index c37b05ce76..b290eb3c1b 100644 --- a/allensdk/brain_observatory/behavior/schemas.py +++ b/allensdk/brain_observatory/behavior/schemas.py @@ -1,11 +1,22 @@ import numpy as np from marshmallow import RAISE, Schema, fields -STYPE_DICT = {fields.Float: 'float', fields.Int: 'int', - fields.String: 'text', fields.List: 'text', - fields.DateTime: 'text', fields.UUID: 'text'} -TYPE_DICT = {fields.Float: float, fields.Int: int, fields.String: str, - fields.List: np.ndarray, fields.DateTime: str, fields.UUID: str} +STYPE_DICT = { + fields.Float: "float", + fields.Int: "int", + fields.String: "text", + fields.List: "text", + fields.DateTime: "text", + fields.UUID: "text", +} +TYPE_DICT = { + fields.Float: float, + fields.Int: int, + fields.String: str, + fields.List: np.ndarray, + fields.DateTime: str, + fields.UUID: str, +} class RaisingSchema(Schema): @@ -18,15 +29,15 @@ class SubjectMetadataSchema(RaisingSchema): behavior or behavior + ophys experiment. """ - neurodata_type = 'BehaviorSubject' - neurodata_type_inc = 'Subject' + neurodata_type = "BehaviorSubject" + neurodata_type_inc = "Subject" neurodata_doc = "Metadata for an AIBS behavior or behavior + ophys subject" # Fields to skip converting to extension # In this case they already exist in the 'Subject' builtin pyNWB class neurodata_skip = {"age_in_days", "genotype", "sex", "subject_id"} age_in_days = fields.String( - doc='Age of the specimen donor/subject (in days)', + doc="Age of the specimen donor/subject (in days)", required=True, ) driver_line = fields.List( @@ -38,12 +49,12 @@ class SubjectMetadataSchema(RaisingSchema): ) # 'full_genotype' will be stored in pynwb Subject 'genotype' attr genotype = fields.String( - doc='full genotype of subject', + doc="full genotype of subject", required=True, ) # 'mouse_id' will be stored in pynwb Subject 'subject_id' attr subject_id = fields.Int( - doc='Mouse ID of subject', + doc="Mouse ID of subject", required=True, ) reporter_line = fields.String( @@ -51,49 +62,45 @@ class SubjectMetadataSchema(RaisingSchema): required=True, ) sex = fields.String( - doc='Sex of the specimen donor/subject', + doc="Sex of the specimen donor/subject", required=True, ) class BehaviorMetadataSchema(RaisingSchema): - """This schema contains metadata pertaining to behavior. - """ - neurodata_type = 'BehaviorMetadata' - neurodata_type_inc = 'LabMetaData' + """This schema contains metadata pertaining to behavior.""" + + neurodata_type = "BehaviorMetadata" + neurodata_type_inc = "LabMetaData" neurodata_doc = "Metadata for behavior and behavior + ophys experiments" neurodata_skip = {"date_of_acquisition"} - behavior_session_id = fields.Int( - doc='The unique ID for the behavior session', - required=True - ) + behavior_session_id = fields.Int(doc="The unique ID for the behavior session", required=True) behavior_session_uuid = fields.UUID( - doc='MTrain record for session, also called foraging_id', + doc="MTrain record for session, also called foraging_id", required=True, ) stimulus_frame_rate = fields.Float( - doc=('Frame rate (frames/second) of the ' - 'visual_stimulus from the monitor'), + doc=("Frame rate (frames/second) of the visual_stimulus from the monitor"), required=True, ) session_type = fields.String( - doc='Experimental session description', + doc="Experimental session description", allow_none=True, required=True, ) # 'date_of_acquisition' will be stored in # pynwb NWBFile 'session_start_time' attr date_of_acquisition = fields.DateTime( - doc='Date of the experiment (UTC, as string)', + doc="Date of the experiment (UTC, as string)", required=True, ) equipment_name = fields.String( - doc='Name of behavior or optical physiology experiment rig', + doc="Name of behavior or optical physiology experiment rig", required=True, ) project_code = fields.String( - doc='String Id of project associated with session.', + doc="String Id of project associated with session.", allow_none=True, required=True, ) @@ -102,55 +109,49 @@ class BehaviorMetadataSchema(RaisingSchema): class NwbOphysMetadataSchema(RaisingSchema): """This schema contains fields that will be stored in pyNWB base classes pertaining to optical physiology.""" + # 'emission_lambda' will be stored in # pyNWB OpticalChannel 'emission_lambda' attr emission_lambda = fields.Float( - doc='Emission lambda of fluorescent indicator', + doc="Emission lambda of fluorescent indicator", required=True, ) # 'excitation_lambda' will be stored in the pyNWB ImagingPlane # 'excitation_lambda' attr excitation_lambda = fields.Float( - doc='Excitation lambda of fluorescent indicator', + doc="Excitation lambda of fluorescent indicator", required=True, ) # 'indicator' will be stored in the pyNWB ImagingPlane 'indicator' attr indicator = fields.String( - doc='Name of optical physiology fluorescent indicator', + doc="Name of optical physiology fluorescent indicator", required=True, ) # 'targeted_structure' will be stored in the pyNWB # ImagingPlane 'location' attr targeted_structure = fields.String( - doc='Anatomical structure targeted for two-photon acquisition', + doc="Anatomical structure targeted for two-photon acquisition", required=True, ) # 'ophys_frame_rate' will be stored in the pyNWB ImagingPlane # 'imaging_rate' attr ophys_frame_rate = fields.Float( - doc='Frame rate (frames/second) of the two-photon microscope', + doc="Frame rate (frames/second) of the two-photon microscope", required=True, ) class OphysMetadataSchema(NwbOphysMetadataSchema): - """This schema contains metadata pertaining to optical physiology (ophys). - """ - ophys_experiment_id = fields.Int( - doc='Unique ID for the ophys experiment (aka imaging plane)', - required=True - ) - ophys_session_id = fields.Int( - doc='Unique ID for the ophys session', - required=True - ) + """This schema contains metadata pertaining to optical physiology (ophys).""" + + ophys_experiment_id = fields.Int(doc="Unique ID for the ophys experiment (aka imaging plane)", required=True) + ophys_session_id = fields.Int(doc="Unique ID for the ophys session", required=True) ophys_container_id = fields.Int( - doc='Container ID for the container that contains this ophys session', + doc="Container ID for the container that contains this ophys session", required=True, ) imaging_depth = fields.Int( - doc=('Depth (microns) below the cortical surface ' - 'targeted for two-photon acquisition'), + doc=("Depth (microns) below the cortical surface targeted for two-photon acquisition"), required=True, ) targeted_imaging_depth = fields.Int( @@ -162,53 +163,63 @@ class OphysMetadataSchema(NwbOphysMetadataSchema): required=True, ) field_of_view_width = fields.Int( - doc='Width of optical physiology imaging plane in pixels', + doc="Width of optical physiology imaging plane in pixels", required=True, ) field_of_view_height = fields.Int( - doc='Height of optical physiology imaging plane in pixels', + doc="Height of optical physiology imaging plane in pixels", required=True, ) imaging_plane_group = fields.Int( - doc=('A numeric index which indicates the order that an imaging plane ' - 'was acquired for a mesoscope experiment. Will be -1 for ' - 'non-mesoscope data'), - required=True + doc=( + "A numeric index which indicates the order that an imaging plane " + "was acquired for a mesoscope experiment. Will be -1 for " + "non-mesoscope data" + ), + required=True, ) imaging_plane_group_count = fields.Int( - doc=('The total number of plane groups collected in a session ' - 'for a mesoscope experiment. Will be 0 if the scope did not ' - 'capture multiple concurrent imaging planes.'), - required=True + doc=( + "The total number of plane groups collected in a session " + "for a mesoscope experiment. Will be 0 if the scope did not " + "capture multiple concurrent imaging planes." + ), + required=True, ) project_code = fields.String( - doc='String Id of project associated with session.', + doc="String Id of project associated with session.", allow_none=True, required=True, ) class OphysBehaviorMetadataSchema(BehaviorMetadataSchema, OphysMetadataSchema): - """ This schema contains fields pertaining to ophys+behavior. It is used + """This schema contains fields pertaining to ophys+behavior. It is used as a template for generating our custom NWB behavior + ophys extension. """ - neurodata_type = 'OphysBehaviorMetadata' - neurodata_type_inc = 'BehaviorMetadata' + + neurodata_type = "OphysBehaviorMetadata" + neurodata_type_inc = "BehaviorMetadata" neurodata_doc = "Metadata for behavior + ophys experiments" # Fields to skip converting to extension # They already exist as attributes for the following pyNWB classes: # OpticalChannel, ImagingPlane, NWBFile - neurodata_skip = {"emission_lambda", "excitation_lambda", "indicator", - "targeted_structure", "date_of_acquisition", - "ophys_frame_rate"} + neurodata_skip = { + "emission_lambda", + "excitation_lambda", + "indicator", + "targeted_structure", + "date_of_acquisition", + "ophys_frame_rate", + } -class CompleteOphysBehaviorMetadataSchema(OphysBehaviorMetadataSchema, - SubjectMetadataSchema): +class CompleteOphysBehaviorMetadataSchema(OphysBehaviorMetadataSchema, SubjectMetadataSchema): """This schema combines fields from behavior, ophys, and subject schemas. Metadata info is passed by the behavior+ophys session in a combined lump containing all the field types. """ + pass @@ -216,115 +227,84 @@ class BehaviorTaskParametersSchema(RaisingSchema): """This schema encompasses task parameters used for behavior or ophys + behavior. """ - neurodata_type = 'BehaviorTaskParameters' - neurodata_type_inc = 'LabMetaData' + + neurodata_type = "BehaviorTaskParameters" + neurodata_type_inc = "LabMetaData" neurodata_doc = "Metadata for behavior or behavior + ophys task parameters" blank_duration_sec = fields.List( fields.Float, - doc=('The lower and upper bound (in seconds) for a randomly chosen ' - 'inter-stimulus interval duration for a trial'), + doc=( + "The lower and upper bound (in seconds) for a randomly chosen inter-stimulus interval duration for a trial" + ), required=True, shape=(2,), ) stimulus_duration_sec = fields.Float( - doc='Duration of each stimulus presentation in seconds', - required=True, - allow_nan=True + doc="Duration of each stimulus presentation in seconds", required=True, allow_nan=True ) omitted_flash_fraction = fields.Float( - doc='Fraction of flashes/image presentations that were omitted', + doc="Fraction of flashes/image presentations that were omitted", required=True, allow_nan=True, ) response_window_sec = fields.List( fields.Float, - doc=('The lower and upper bound (in seconds) for a randomly chosen ' - 'time window where subject response influences trial outcome'), + doc=( + "The lower and upper bound (in seconds) for a randomly chosen " + "time window where subject response influences trial outcome" + ), required=True, shape=(2,), ) reward_volume = fields.Float( - doc='Volume of water (in mL) delivered as reward', + doc="Volume of water (in mL) delivered as reward", required=True, ) auto_reward_volume = fields.Float( - doc='Volume of water (in mL) delivered as an automatic reward', + doc="Volume of water (in mL) delivered as an automatic reward", required=True, ) session_type = fields.String( - doc='Stage of behavioral task', + doc="Stage of behavioral task", required=True, ) stimulus = fields.String( - doc='Stimulus type', + doc="Stimulus type", required=True, ) stimulus_distribution = fields.String( - doc=("Distribution type of drawing change times " - "(e.g. 'geometric', 'exponential')"), + doc=("Distribution type of drawing change times (e.g. 'geometric', 'exponential')"), required=True, ) task = fields.String( - doc='The name of the behavioral task', + doc="The name of the behavioral task", required=True, ) n_stimulus_frames = fields.Int( - doc='Total number of stimuli frames', + doc="Total number of stimuli frames", required=True, ) - stimulus_name = fields.String( - doc="Name of the stimulus file set presented", - required=False - ) + stimulus_name = fields.String(doc="Name of the stimulus file set presented", required=False) class EyeTrackingRigGeometry(RaisingSchema): """Eye tracking rig geometry""" - values = fields.Float( - doc='position/rotation with respect to (x, y, z)', - required=True, - shape=(3,) - ) - unit_of_measurement = fields.Str( - doc='Unit of measurement for the data', - required=True - ) + + values = fields.Float(doc="position/rotation with respect to (x, y, z)", required=True, shape=(3,)) + unit_of_measurement = fields.Str(doc="Unit of measurement for the data", required=True) class OphysEyeTrackingRigMetadataSchema(RaisingSchema): - """This schema encompasses metadata for ophys experiment rig - """ - neurodata_type = 'OphysEyeTrackingRigMetadata' - neurodata_type_inc = 'NWBDataInterface' + """This schema encompasses metadata for ophys experiment rig""" + + neurodata_type = "OphysEyeTrackingRigMetadata" + neurodata_type_inc = "NWBDataInterface" neurodata_doc = "Metadata for ophys experiment rig" - equipment = fields.Str( - doc='Description of rig', - required=True - ) - monitor_position = fields.Nested( - EyeTrackingRigGeometry, - doc='position of monitor (x, y, z)', - required=True - ) - camera_position = fields.Nested( - EyeTrackingRigGeometry, - doc='position of camera (x, y, z)', - required=True - ) - led_position = fields.Nested( - EyeTrackingRigGeometry, - doc='position of LED (x, y, z)', - required=True - ) - monitor_rotation = fields.Nested( - EyeTrackingRigGeometry, - doc='rotation of monitor (x, y, z)', - required=True - ) - camera_rotation = fields.Nested( - EyeTrackingRigGeometry, - doc='rotation of camera (x, y, z)', - required=True - ) + equipment = fields.Str(doc="Description of rig", required=True) + monitor_position = fields.Nested(EyeTrackingRigGeometry, doc="position of monitor (x, y, z)", required=True) + camera_position = fields.Nested(EyeTrackingRigGeometry, doc="position of camera (x, y, z)", required=True) + led_position = fields.Nested(EyeTrackingRigGeometry, doc="position of LED (x, y, z)", required=True) + monitor_rotation = fields.Nested(EyeTrackingRigGeometry, doc="rotation of monitor (x, y, z)", required=True) + camera_rotation = fields.Nested(EyeTrackingRigGeometry, doc="rotation of camera (x, y, z)", required=True) diff --git a/allensdk/brain_observatory/behavior/session_metrics.py b/allensdk/brain_observatory/behavior/session_metrics.py index 96d4287190..d500171607 100644 --- a/allensdk/brain_observatory/behavior/session_metrics.py +++ b/allensdk/brain_observatory/behavior/session_metrics.py @@ -1,5 +1,6 @@ from allensdk.brain_observatory.behavior import trial_masks as masks + def response_bias(trials, detect_col, trial_types=("go", "catch")): """ Calculate the response bias for a subset of trial types from a behavioral @@ -8,8 +9,8 @@ def response_bias(trials, detect_col, trial_types=("go", "catch")): trials (pandas.DataFrame): Dataframe containing trial-level information from a behavioral training session. Required columns: "trial_type", `detect_col`. - detect_col (str): Name of column containing boolean - or numeric codings (0/1) for whether or not the mouse had a + detect_col (str): Name of column containing boolean + or numeric codings (0/1) for whether or not the mouse had a response. trial_types (iterable): Iterable containing string trial types to check for the response bias. Trials of types not included in this @@ -27,10 +28,9 @@ def num_contingent_trials(session_trials): Returns the number of "go" and "catch" trials in a training session dataframe. Args: - session_trials (pandas.DataFrame): a pandas.DataFrame describing + session_trials (pandas.DataFrame): a pandas.DataFrame describing behavior training trials, with the string column "trial_type" describing the type of trial. Returns (int): Number of "go" and "catch" trials """ return session_trials["trial_type"].isin(["go", "catch"]).sum() - diff --git a/allensdk/brain_observatory/behavior/stimulus_processing.py b/allensdk/brain_observatory/behavior/stimulus_processing.py index 7ea2afd966..b44d82e38d 100644 --- a/allensdk/brain_observatory/behavior/stimulus_processing.py +++ b/allensdk/brain_observatory/behavior/stimulus_processing.py @@ -39,9 +39,7 @@ def get_stimulus_presentations(data, stimulus_timestamps) -> pd.DataFrame: stimulus_table = get_visual_stimuli_df(data, stimulus_timestamps) # workaround to rename columns to harmonize with visual # coding and rebase timestamps to sync time - stimulus_table.insert( - loc=0, column="flash_number", value=np.arange(0, len(stimulus_table)) - ) + stimulus_table.insert(loc=0, column="flash_number", value=np.arange(0, len(stimulus_table))) stimulus_table = stimulus_table.rename( columns={ "frame": "start_frame", @@ -50,8 +48,7 @@ def get_stimulus_presentations(data, stimulus_timestamps) -> pd.DataFrame: } ) stimulus_table.start_time = [ - stimulus_timestamps[int(start_frame)] - for start_frame in stimulus_table.start_frame.values + stimulus_timestamps[int(start_frame)] for start_frame in stimulus_table.start_frame.values ] end_time = [] for end_frame in stimulus_table.end_frame.values: @@ -226,11 +223,7 @@ def get_stimulus_templates( attrs = images["image_attributes"] image_values = images["images"] if limit_to_images is not None: - keep_idxs = [ - i - for i in range(len(images)) - if attrs[i]["image_name"] in limit_to_images - ] + keep_idxs = [i for i in range(len(images)) if attrs[i]["image_name"] in limit_to_images] attrs = [attrs[i] for i in keep_idxs] image_values = [image_values[i] for i in keep_idxs] @@ -247,9 +240,7 @@ def get_stimulus_templates( "because this pkl data contains " "gratings presentations." ) - gratings_metadata = get_gratings_metadata(pkl_stimuli).to_dict( - orient="records" - ) + gratings_metadata = get_gratings_metadata(pkl_stimuli).to_dict(orient="records") unwarped_images = [] warped_images = [] @@ -299,12 +290,8 @@ def get_stimulus_metadata(pkl) -> pd.DataFrame: if "images" in stimuli: images = get_images_dict(pkl) stimulus_index_df = pd.DataFrame(images["image_attributes"]) - image_set_filename = convert_filepath_caseinsensitive( - images["metadata"]["image_set"] - ) - stimulus_index_df["image_set"] = get_image_set_name( - image_set_path=image_set_filename - ) + image_set_filename = convert_filepath_caseinsensitive(images["metadata"]["image_set"]) + stimulus_index_df["image_set"] = get_image_set_name(image_set_path=image_set_filename) else: stimulus_index_df = pd.DataFrame( columns=[ @@ -328,12 +315,8 @@ def get_stimulus_metadata(pkl) -> pd.DataFrame: ) # get the grating metadata will be empty if gratings are absent - grating_df = get_gratings_metadata( - stimuli, start_idx=len(stimulus_index_df) - ) - stimulus_index_df = pd.concat( - [stimulus_index_df, grating_df], ignore_index=True, sort=False - ) + grating_df = get_gratings_metadata(stimuli, start_idx=len(stimulus_index_df)) + stimulus_index_df = pd.concat([stimulus_index_df, grating_df], ignore_index=True, sort=False) # Add an entry for omitted stimuli omitted_df = pd.DataFrame( @@ -347,9 +330,7 @@ def get_stimulus_metadata(pkl) -> pd.DataFrame: "image_index": len(stimulus_index_df), } ) - stimulus_index_df = pd.concat( - [stimulus_index_df, omitted_df], ignore_index=True, sort=False - ) + stimulus_index_df = pd.concat([stimulus_index_df, omitted_df], ignore_index=True, sort=False) stimulus_index_df.set_index(["image_index"], inplace=True, drop=True) return stimulus_index_df @@ -406,9 +387,7 @@ def _get_stimulus_epoch( return start_frame, next_set_event[3] # end frame isn't inclusive -def _get_draw_epochs( - draw_log: List[int], start_frame: int, stop_frame: int -) -> List[Tuple[int, int]]: +def _get_draw_epochs(draw_log: List[int], start_frame: int, stop_frame: int) -> List[Tuple[int, int]]: """ Gets the frame numbers of the active frames within a stimulus window. Stimulus epochs come in the form [0, 0, 1, 1, 0, 0] where the stimulus is @@ -494,9 +473,7 @@ def get_visual_stimuli_df(data, time) -> pd.DataFrame: n_frames = len(time) visual_stimuli_data = [] for stim_dict in stimuli.values(): - for idx, (attr_name, attr_value, _, frame) in enumerate( - stim_dict["set_log"] - ): + for idx, (attr_name, attr_value, _, frame) in enumerate(stim_dict["set_log"]): orientation = attr_value if attr_name.lower() == "ori" else np.nan image_name = attr_value if attr_name.lower() == "image" else np.nan @@ -506,9 +483,7 @@ def get_visual_stimuli_df(data, time) -> pd.DataFrame: frame, n_frames, ) - draw_epochs = _get_draw_epochs( - stim_dict["draw_log"], *stimulus_epoch - ) + draw_epochs = _get_draw_epochs(stim_dict["draw_log"], *stimulus_epoch) for epoch_start, epoch_end in draw_epochs: visual_stimuli_data.append( @@ -529,9 +504,7 @@ def get_visual_stimuli_df(data, time) -> pd.DataFrame: # Add omitted flash info: try: - omitted_flash_frame_log = data["items"]["behavior"][ - "omitted_flash_frame_log" - ] + omitted_flash_frame_log = data["items"]["behavior"]["omitted_flash_frame_log"] except KeyError: # For sessions for which there were no omitted flashes omitted_flash_frame_log = dict() @@ -545,9 +518,7 @@ def get_visual_stimuli_df(data, time) -> pd.DataFrame: # to see if they are in the stim log offsets = np.arange(-3, 4) offset_arr = np.add( - np.repeat( - omitted_flash_frames[:, np.newaxis], offsets.shape[0], axis=1 - ), + np.repeat(omitted_flash_frames[:, np.newaxis], offsets.shape[0], axis=1), offsets, ) matched_any_offset = np.any(np.isin(offset_arr, stim_frames), axis=1) @@ -570,11 +541,7 @@ def get_visual_stimuli_df(data, time) -> pd.DataFrame: } ) - df = ( - pd.concat((visual_stimuli_df, omitted_df), sort=False) - .sort_values("frame") - .reset_index() - ) + df = pd.concat((visual_stimuli_df, omitted_df), sort=False).sort_values("frame").reset_index() return df @@ -616,9 +583,7 @@ def is_change_event(stimulus_presentations: pd.DataFrame) -> pd.Series: is_change = stimuli != prev_stimuli # reset back to original index - is_change = is_change.reindex(stimulus_presentations.index).rename( - "is_change" - ) + is_change = is_change.reindex(stimulus_presentations.index).rename("is_change") # Excluded stimuli are not change events is_change = is_change.fillna(False) @@ -658,15 +623,11 @@ def get_flashes_since_change( if row["is_change"] or idx == 0: flashes_since_change.iloc[idx] = 0 else: - flashes_since_change.iloc[idx] = ( - flashes_since_change.iloc[idx - 1] + 1 - ) + flashes_since_change.iloc[idx] = flashes_since_change.iloc[idx - 1] + 1 return flashes_since_change -def add_active_flag( - stim_pres_table: pd.DataFrame, trials: pd.DataFrame -) -> pd.DataFrame: +def add_active_flag(stim_pres_table: pd.DataFrame, trials: pd.DataFrame) -> pd.DataFrame: """Mark the active stimuli by lining up the stimulus times with the trials times. @@ -708,9 +669,7 @@ def add_active_flag( return stim_pres_table -def compute_trials_id_for_stimulus( - stim_pres_table: pd.DataFrame, trials_table: pd.DataFrame -) -> pd.Series: +def compute_trials_id_for_stimulus(stim_pres_table: pd.DataFrame, trials_table: pd.DataFrame) -> pd.Series: """Add an id to allow for merging of the stimulus presentations table with the trials table. @@ -757,10 +716,7 @@ def compute_trials_id_for_stimulus( trials_ids[stim_mask] = idx # Return input frame if the stimulus_block or active is not available. - if ( - "stimulus_block" not in stim_pres_table.columns - or "active" not in stim_pres_table.columns - ): + if "stimulus_block" not in stim_pres_table.columns or "active" not in stim_pres_table.columns: return trials_ids active_sorted = stim_pres_table.active @@ -775,9 +731,7 @@ def compute_trials_id_for_stimulus( active_stim_blocks = stim_blocks[active_sorted].unique() # Find passive blocks that show images for potential copying of the active # into a passive stimulus block. - passive_stim_blocks = stim_blocks[ - np.logical_and(~active_sorted, ~stim_image_names.isna()) - ].unique() + passive_stim_blocks = stim_blocks[np.logical_and(~active_sorted, ~stim_image_names.isna())].unique() # Copy the trials_id into the passive block if it exists. if len(passive_stim_blocks) > 0: @@ -786,12 +740,8 @@ def compute_trials_id_for_stimulus( active_images = stim_image_names[active_block_mask].values for passive_stim_block in passive_stim_blocks: passive_block_mask = stim_blocks == passive_stim_block - if np.array_equal( - active_images, stim_image_names[passive_block_mask].values - ): - trials_ids.loc[passive_block_mask] = trials_ids[ - active_block_mask - ].values + if np.array_equal(active_images, stim_image_names[passive_block_mask].values): + trials_ids.loc[passive_block_mask] = trials_ids[active_block_mask].values return trials_ids.sort_index() @@ -812,16 +762,9 @@ def fix_omitted_end_frame(stim_pres_table: pd.DataFrame) -> pd.DataFrame: Copy of input DataFrame with filled omitted, ``end_frame`` values and fixed typing. """ - median_stim_frame_duration = np.nanmedian( - stim_pres_table["end_frame"] - stim_pres_table["start_frame"] - ) - omitted_end_frames = ( - stim_pres_table[stim_pres_table["omitted"]]["start_frame"] - + median_stim_frame_duration - ) - stim_pres_table.loc[ - stim_pres_table["omitted"], "end_frame" - ] = omitted_end_frames + median_stim_frame_duration = np.nanmedian(stim_pres_table["end_frame"] - stim_pres_table["start_frame"]) + omitted_end_frames = stim_pres_table[stim_pres_table["omitted"]]["start_frame"] + median_stim_frame_duration + stim_pres_table.loc[stim_pres_table["omitted"], "end_frame"] = omitted_end_frames stim_dtypes = stim_pres_table.dtypes.to_dict() stim_dtypes["start_frame"] = int @@ -830,9 +773,7 @@ def fix_omitted_end_frame(stim_pres_table: pd.DataFrame) -> pd.DataFrame: return stim_pres_table.astype(stim_dtypes) -def produce_stimulus_block_names( - stim_df: pd.DataFrame, session_type: str, project_code: str -) -> pd.DataFrame: +def produce_stimulus_block_names(stim_df: pd.DataFrame, session_type: str, project_code: str) -> pd.DataFrame: """Add a column stimulus_block_name to explicitly reference the kind of stimulus block in addition to the numbered blocks. @@ -869,16 +810,12 @@ def produce_stimulus_block_names( block_id = stim_block if len(stim_df.stimulus_block.unique()) == 1: block_id += 1 - stim_df.loc[ - stim_df["stimulus_block"] == stim_block, "stimulus_block_name" - ] = vbo_map[block_id] + stim_df.loc[stim_df["stimulus_block"] == stim_block, "stimulus_block_name"] = vbo_map[block_id] return stim_df -def compute_is_sham_change( - stim_df: pd.DataFrame, trials: pd.DataFrame -) -> pd.DataFrame: +def compute_is_sham_change(stim_df: pd.DataFrame, trials: pd.DataFrame) -> pd.DataFrame: """Add is_sham_change to stimulus presentation table. Parameters @@ -893,23 +830,13 @@ def compute_is_sham_change( stimulus_presentations : pandas.DataFrame Input ``stim_df`` DataFrame with the is_sham_change column added. """ - if ( - "trials_id" not in stim_df.columns - or "active" not in stim_df.columns - or "stimulus_block" not in stim_df.columns - ): + if "trials_id" not in stim_df.columns or "active" not in stim_df.columns or "stimulus_block" not in stim_df.columns: return stim_df - stim_trials = stim_df.merge( - trials, left_on="trials_id", right_index=True, how="left" - ) - catch_frames = stim_trials[stim_trials["catch"].fillna(False)][ - "change_frame" - ].unique() + stim_trials = stim_df.merge(trials, left_on="trials_id", right_index=True, how="left") + catch_frames = stim_trials[stim_trials["catch"].fillna(False)]["change_frame"].unique() stim_df["is_sham_change"] = False - catch_flashes = stim_df[ - stim_df["start_frame"].isin(catch_frames) - ].index.values + catch_flashes = stim_df[stim_df["start_frame"].isin(catch_frames)].index.values stim_df.loc[catch_flashes, "is_sham_change"] = True stim_blocks = stim_df.stimulus_block @@ -917,9 +844,7 @@ def compute_is_sham_change( active_stim_blocks = stim_blocks[stim_df.active].unique() # Find passive blocks that show images for potential copying of the active # into a passive stimulus block. - passive_stim_blocks = stim_blocks[ - np.logical_and(~stim_df.active, ~stim_image_names.isna()) - ].unique() + passive_stim_blocks = stim_blocks[np.logical_and(~stim_df.active, ~stim_image_names.isna())].unique() # Copy the trials_id into the passive block if it exists. if len(passive_stim_blocks) > 0: @@ -928,11 +853,9 @@ def compute_is_sham_change( active_images = stim_image_names[active_block_mask].values for passive_stim_block in passive_stim_blocks: passive_block_mask = stim_blocks == passive_stim_block - if np.array_equal( - active_images, stim_image_names[passive_block_mask].values - ): - stim_df.loc[ - passive_block_mask, "is_sham_change" - ] = stim_df[active_block_mask]["is_sham_change"].values + if np.array_equal(active_images, stim_image_names[passive_block_mask].values): + stim_df.loc[passive_block_mask, "is_sham_change"] = stim_df[active_block_mask][ + "is_sham_change" + ].values return stim_df.sort_index() diff --git a/allensdk/brain_observatory/behavior/swdb/analysis_tools.py b/allensdk/brain_observatory/behavior/swdb/analysis_tools.py index 8915cddcc8..9d61d593be 100644 --- a/allensdk/brain_observatory/behavior/swdb/analysis_tools.py +++ b/allensdk/brain_observatory/behavior/swdb/analysis_tools.py @@ -1,13 +1,14 @@ import numpy as np import bisect + def get_nearest_frame(timepoint, timestamps): - ''' + """ Get the nearest frame timestamp for any time point This is kinda not true. This returns the index at which you would insert the timepoint to retain the sort order of the list, so if you - use the index on the list of timestamps you will always get the smallest + use the index on the list of timestamps you will always get the smallest timestamps that is larger than your input timestamp (not alway the closest) Args: @@ -16,13 +17,13 @@ def get_nearest_frame(timepoint, timestamps): Returns: nearest_frame (int): The index of the next frame in time - ''' + """ nearest_frame = bisect.bisect_left(timestamps, timepoint) return nearest_frame -def get_trace_around_timepoint(trace, timepoint, timestamps, - window_around_timepoint_seconds, frame_rate): - ''' + +def get_trace_around_timepoint(trace, timepoint, timestamps, window_around_timepoint_seconds, frame_rate): + """ Return the values around a timepoint using a window defined in seconds Args: @@ -31,9 +32,9 @@ def get_trace_around_timepoint(trace, timepoint, timestamps, timestamps (np.array): Timestamp in seconds for each point in the trace window_around_timepoint_seconds (list with len==2): [-2, 3] for a window that starts 2 seconds before the timepoint and - ends 3 seconds after. + ends 3 seconds after. frame_rate (float): The frame rate at which the trace is collected. - ''' + """ assert trace.shape == timestamps.shape @@ -45,12 +46,14 @@ def get_trace_around_timepoint(trace, timepoint, timestamps, timepoints = np.array(timestamps[lower_frame:upper_frame]) return trace, timepoints + def get_mean_in_window(trace, window_after_trace_start_seconds, frame_rate): window = window_after_trace_start_seconds.copy() - mean = np.nanmean(trace[int(window[0] * frame_rate): int(window[1] * frame_rate)]) + mean = np.nanmean(trace[int(window[0] * frame_rate) : int(window[1] * frame_rate)]) return mean -if __name__=="__main__": + +if __name__ == "__main__": trace = np.arange(100, dtype=float) timestamps = np.arange(100, dtype=float) @@ -61,29 +64,20 @@ def get_mean_in_window(trace, window_after_trace_start_seconds, frame_rate): assert timestamps[b] == 51.0 window_around_timepoint_seconds = [-5, 5] - t_vals, t_ts = get_trace_around_timepoint(trace, 49.9, timestamps, - window_around_timepoint_seconds, - frame_rate=1) + t_vals, t_ts = get_trace_around_timepoint(trace, 49.9, timestamps, window_around_timepoint_seconds, frame_rate=1) - assert np.all(t_vals == np.array([45., 46., 47., 48., 49., 50., 51., 52., 53., 54.])) + assert np.all(t_vals == np.array([45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0])) # def traces_around_timepoints(trace_values, trace_timestamps, event_times, window): # ''' # Get peri-event slices of a trace. -# +# # Args: # trace_values (1d np.array): Trace for one cell # trace_timestamps (1d np.array): Timestamps for each trace value # event_times (np.array): The times of events you want traces for # window (2-tuple): Time range around event times -# +# # Returns # eventlocked_traces (np.array with shape (n_events, n_samples_in_window)) # ''' - - - - - - - diff --git a/allensdk/brain_observatory/behavior/swdb/behavior_project_cache.py b/allensdk/brain_observatory/behavior/swdb/behavior_project_cache.py index d01ee48b6f..7b3181acaa 100644 --- a/allensdk/brain_observatory/behavior/swdb/behavior_project_cache.py +++ b/allensdk/brain_observatory/behavior/swdb/behavior_project_cache.py @@ -5,33 +5,28 @@ import re from allensdk import one -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .behavior_metadata.behavior_metadata import \ - BehaviorMetadata -from allensdk.brain_observatory.behavior.session_apis.data_io import ( - BehaviorOphysNwbApi) -from allensdk.brain_observatory.behavior.behavior_ophys_experiment import \ - BehaviorOphysExperiment +from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.behavior_metadata import ( + BehaviorMetadata, +) +from allensdk.brain_observatory.behavior.session_apis.data_io import BehaviorOphysNwbApi +from allensdk.brain_observatory.behavior.behavior_ophys_experiment import BehaviorOphysExperiment from allensdk.core.lazy_property import LazyProperty -from allensdk.brain_observatory.behavior.data_objects.trials.trials import \ - calculate_reward_rate +from allensdk.brain_observatory.behavior.data_objects.trials.trials import calculate_reward_rate from allensdk.deprecated import deprecated -csv_io = { - 'reader': lambda path: pd.read_csv(path, index_col='Unnamed: 0'), - 'writer': lambda path, df: df.to_csv(path) -} +csv_io = {"reader": lambda path: pd.read_csv(path, index_col="Unnamed: 0"), "writer": lambda path, df: df.to_csv(path)} -cache_path_example = '/allen/programs/braintv/workgroups/nc-ophys/' \ - 'visual_behavior/SWDB_2019/cache_20190813' +cache_path_example = "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/cache_20190813" -@deprecated("swdb.behavior_project_cache.BehaviorProjectCache is deprecated " - "and will be removed in version 1.3. Please use brain_observatory." - "behavior.behavior_project_cache.BehaviorProjectCache.") +@deprecated( + "swdb.behavior_project_cache.BehaviorProjectCache is deprecated " + "and will be removed in version 1.3. Please use brain_observatory." + "behavior.behavior_project_cache.BehaviorProjectCache." +) class BehaviorProjectCache(object): def __init__(self, cache_base): - ''' + """ A cache-level object for the behavior/ophys data. Provides access to the manifest of ophys/behavior containers, as well as pre-computed analysis files @@ -57,118 +52,94 @@ def __init__(self, cache_base): Returns a dictionary with behavior stages as keys and the corresponding session object from that container, that stage as the value. - ''' + """ self.cache_paths = { - 'manifest_path': os.path.join(cache_base, - 'visual_behavior_data_manifest.csv'), - 'nwb_base_dir': os.path.join(cache_base, 'nwb_files'), - 'analysis_files_base_dir': os.path.join(cache_base, - 'analysis_files'), - 'analysis_files_metadata_path': os.path.join( - cache_base, 'analysis_files_metadata.json'), + "manifest_path": os.path.join(cache_base, "visual_behavior_data_manifest.csv"), + "nwb_base_dir": os.path.join(cache_base, "nwb_files"), + "analysis_files_base_dir": os.path.join(cache_base, "analysis_files"), + "analysis_files_metadata_path": os.path.join(cache_base, "analysis_files_metadata.json"), } - self.experiment_table = csv_io['reader']( - self.cache_paths['manifest_path']) - - self.experiment_table['cre_line'] = self.experiment_table[ - 'full_genotype'].apply(BehaviorMetadata.parse_cre_line) - self.experiment_table['passive_session'] = self.experiment_table[ - 'stage_name'].apply(parse_passive) - self.experiment_table['image_set'] = self.experiment_table[ - 'stage_name'].apply(parse_image_set) - - self.experiment_table = self.experiment_table[[ - 'ophys_experiment_id', - 'container_id', - 'full_genotype', - 'cre_line', - 'imaging_depth', - 'targeted_structure', - 'image_set', - 'stage_name', - 'passive_session', - 'animal_name', - 'sex', - 'date_of_acquisition', - 'retake_number' - ]] - - self.nwb_base_dir = self.cache_paths['nwb_base_dir'] - self.analysis_files_base_dir = self.cache_paths[ - 'analysis_files_base_dir'] + self.experiment_table = csv_io["reader"](self.cache_paths["manifest_path"]) + + self.experiment_table["cre_line"] = self.experiment_table["full_genotype"].apply( + BehaviorMetadata.parse_cre_line + ) + self.experiment_table["passive_session"] = self.experiment_table["stage_name"].apply(parse_passive) + self.experiment_table["image_set"] = self.experiment_table["stage_name"].apply(parse_image_set) + + self.experiment_table = self.experiment_table[ + [ + "ophys_experiment_id", + "container_id", + "full_genotype", + "cre_line", + "imaging_depth", + "targeted_structure", + "image_set", + "stage_name", + "passive_session", + "animal_name", + "sex", + "date_of_acquisition", + "retake_number", + ] + ] + + self.nwb_base_dir = self.cache_paths["nwb_base_dir"] + self.analysis_files_base_dir = self.cache_paths["analysis_files_base_dir"] self.analysis_files_metadata = self.get_analysis_files_metadata( - self.cache_paths['analysis_files_metadata_path'] + self.cache_paths["analysis_files_metadata_path"] ) def get_analysis_files_metadata(self, path): - with open(path, 'r') as metadata_path: + with open(path, "r") as metadata_path: metadata = json.load(metadata_path) return metadata def get_nwb_filepath(self, experiment_id): - return os.path.join( - self.nwb_base_dir, - 'behavior_ophys_session_{}.nwb'.format(experiment_id) - ) + return os.path.join(self.nwb_base_dir, "behavior_ophys_session_{}.nwb".format(experiment_id)) def get_trial_response_df_path(self, experiment_id): - return os.path.join( - self.analysis_files_base_dir, - 'trial_response_df_{}.h5'.format(experiment_id) - ) + return os.path.join(self.analysis_files_base_dir, "trial_response_df_{}.h5".format(experiment_id)) def get_flash_response_df_path(self, experiment_id): - return os.path.join( - self.analysis_files_base_dir, - 'flash_response_df_{}.h5'.format(experiment_id) - ) + return os.path.join(self.analysis_files_base_dir, "flash_response_df_{}.h5".format(experiment_id)) def get_extended_stimulus_presentations_df(self, experiment_id): return os.path.join( - self.analysis_files_base_dir, - 'extended_stimulus_presentations_df_{}.h5'.format(experiment_id) + self.analysis_files_base_dir, "extended_stimulus_presentations_df_{}.h5".format(experiment_id) ) def get_session(self, experiment_id): - ''' + """ Return a BehaviorOphysExperiment object given an ophys_experiment_id. - ''' + """ nwb_path = self.get_nwb_filepath(experiment_id) trial_response_df_path = self.get_trial_response_df_path(experiment_id) flash_response_df_path = self.get_flash_response_df_path(experiment_id) - extended_stim_df_path = self.get_extended_stimulus_presentations_df( - experiment_id) - api = ExtendedNwbApi( - nwb_path, - trial_response_df_path, - flash_response_df_path, - extended_stim_df_path - ) + extended_stim_df_path = self.get_extended_stimulus_presentations_df(experiment_id) + api = ExtendedNwbApi(nwb_path, trial_response_df_path, flash_response_df_path, extended_stim_df_path) session = ExtendedBehaviorOphysExperiment(api) return session def get_container_sessions(self, container_id): container_stages = {} - container_experiments = self.experiment_table.groupby( - 'container_id').get_group(container_id) + container_experiments = self.experiment_table.groupby("container_id").get_group(container_id) for ind_row, row in container_experiments.iterrows(): - container_stages.update( - {row['stage_name']: self.get_session( - row['ophys_experiment_id'])} - ) + container_stages.update({row["stage_name"]: self.get_session(row["ophys_experiment_id"])}) return container_stages def parse_passive(behavior_stage): - ''' + """ Args: behavior_stage (str): the stage string, e.g. OPHYS_1_images_A or OPHYS_1_images_A_passive Returns: passive (bool): whether or not the session was a passive session - ''' + """ r = re.compile(".*_passive") if r.match(behavior_stage): return True @@ -177,62 +148,58 @@ def parse_passive(behavior_stage): def parse_image_set(behavior_stage): - ''' + """ Args: behavior_stage (str): the stage string, e.g. OPHYS_1_images_A or OPHYS_1_images_A_passive Returns: image_set (str): which image set is designated by the stage name - ''' + """ r = re.compile(".*images_(?P[AB]).*") - image_set = r.match(behavior_stage).groups('image_set')[0] + image_set = r.match(behavior_stage).groups("image_set")[0] return image_set class ExtendedNwbApi(BehaviorOphysNwbApi): - def __init__(self, nwb_path, trial_response_df_path, - flash_response_df_path, - extended_stimulus_presentations_df_path): - ''' + def __init__( + self, nwb_path, trial_response_df_path, flash_response_df_path, extended_stimulus_presentations_df_path + ): + """ Api to read data from an NWB file and associated analysis HDF5 files. - ''' - super(ExtendedNwbApi, self).__init__(path=nwb_path, - filter_invalid_rois=True) + """ + super(ExtendedNwbApi, self).__init__(path=nwb_path, filter_invalid_rois=True) self.trial_response_df_path = trial_response_df_path self.flash_response_df_path = flash_response_df_path - self.extended_stimulus_presentations_df_path = \ - extended_stimulus_presentations_df_path + self.extended_stimulus_presentations_df_path = extended_stimulus_presentations_df_path def get_trial_response_df(self): - tdf = pd.read_hdf(self.trial_response_df_path, key='df') + tdf = pd.read_hdf(self.trial_response_df_path, key="df") tdf.reset_index(inplace=True) - tdf.drop(columns=['cell_roi_id'], inplace=True) + tdf.drop(columns=["cell_roi_id"], inplace=True) return tdf def get_flash_response_df(self): - fdf = pd.read_hdf(self.flash_response_df_path, key='df') + fdf = pd.read_hdf(self.flash_response_df_path, key="df") fdf.reset_index(inplace=True) - fdf.drop(columns=['image_name', 'cell_roi_id'], inplace=True) - fdf = fdf.join(self.get_stimulus_presentations(), on='flash_id', - how='left') + fdf.drop(columns=["image_name", "cell_roi_id"], inplace=True) + fdf = fdf.join(self.get_stimulus_presentations(), on="flash_id", how="left") return fdf def get_extended_stimulus_presentations_df(self): - return pd.read_hdf(self.extended_stimulus_presentations_df_path, - key='df') + return pd.read_hdf(self.extended_stimulus_presentations_df_path, key="df") def get_task_parameters(self): - ''' + """ The task parameters are incorrect. See: https://github.com/AllenInstitute/AllenSDK/issues/637 We need to hard-code the omitted flash fraction and stimulus duration here. - ''' + """ task_parameters = super(ExtendedNwbApi, self).get_task_parameters() - task_parameters['omitted_flash_fraction'] = 0.05 - task_parameters['stimulus_duration_sec'] = 0.25 - task_parameters['blank_duration_sec'] = 0.5 - task_parameters.pop('task') + task_parameters["omitted_flash_fraction"] = 0.05 + task_parameters["stimulus_duration_sec"] = 0.25 + task_parameters["blank_duration_sec"] = 0.5 + task_parameters.pop("task") return task_parameters def get_metadata(self): @@ -240,18 +207,18 @@ def get_metadata(self): # We want stage name in metadata for easy access by the students task_parameters = self.get_task_parameters() - metadata['stage'] = task_parameters['stage'] + metadata["stage"] = task_parameters["stage"] # metadata should not include 'session_type' because it is 'Unknown' - metadata.pop('session_type') + metadata.pop("session_type") # For SWDB only # metadata should not include 'behavior_session_uuid' because it is # not useful to students and confusing - metadata.pop('behavior_session_uuid') + metadata.pop("behavior_session_uuid") # Rename LabTracks_ID to mouse_id to reduce student confusion - metadata['mouse_id'] = metadata.pop('LabTracks_ID') + metadata["mouse_id"] = metadata.pop("LabTracks_ID") return metadata @@ -263,13 +230,11 @@ def get_running_speed(self): # have columns for both 'timestamps' and 'values' of things, # since this is more intuitive for students running_speed = super(ExtendedNwbApi, self).get_running_speed() - return pd.DataFrame({'speed': running_speed.speed, - 'timestamps': running_speed.timestamps}) + return pd.DataFrame({"speed": running_speed.speed, "timestamps": running_speed.timestamps}) def get_trials(self, filter_aborted_trials=True): trials = super(ExtendedNwbApi, self).get_trials() - stimulus_presentations = super(ExtendedNwbApi, - self).get_stimulus_presentations() + stimulus_presentations = super(ExtendedNwbApi, self).get_stimulus_presentations() # Note: everything between dashed lines is a patch to deal with # timing issues in @@ -281,14 +246,13 @@ def get_trials(self, filter_aborted_trials=True): # gets start_time of next stimulus after timestamp in # stimulus_presentations def get_next_flash(timestamp): - query = stimulus_presentations.query('start_time >= @timestamp') + query = stimulus_presentations.query("start_time >= @timestamp") if len(query) > 0: - return query.iloc[0]['start_time'] + return query.iloc[0]["start_time"] else: return None - trials['change_time'] = trials['change_time'].map( - lambda x: get_next_flash(x)) + trials["change_time"] = trials["change_time"].map(lambda x: get_next_flash(x)) # This method can lead to a NaN change time for any trials at the # end of the session. @@ -298,138 +262,129 @@ def get_next_flash(timestamp): # covered by the # stimulus_presentations # Using start time in case last stim is omitted - last_stimulus_presentation = stimulus_presentations.iloc[-1][ - 'start_time'] - trials = trials[ - np.logical_not(trials['stop_time'] > last_stimulus_presentation)] + last_stimulus_presentation = stimulus_presentations.iloc[-1]["start_time"] + trials = trials[np.logical_not(trials["stop_time"] > last_stimulus_presentation)] # recalculates response latency based on corrected change time and # first lick time def recalculate_response_latency(row): - if len(row['lick_times'] > 0) and not pd.isnull( - row['change_time']): - return row['lick_times'][0] - row['change_time'] + if len(row["lick_times"] > 0) and not pd.isnull(row["change_time"]): + return row["lick_times"][0] - row["change_time"] else: return np.nan - trials['response_latency'] = trials.apply(recalculate_response_latency, - axis=1) + trials["response_latency"] = trials.apply(recalculate_response_latency, axis=1) # ------------------------------------------------------------------------------- # asserts that every change time exists in the # stimulus_presentations table - for change_time in ( - trials[trials['change_time'].notna()]['change_time']): - assert change_time in stimulus_presentations['start_time'].values + for change_time in trials[trials["change_time"].notna()]["change_time"]: + assert change_time in stimulus_presentations["start_time"].values # Return only non-aborted trials from this API by default if filter_aborted_trials: - trials = trials.query('not aborted') + trials = trials.query("not aborted") # Reorder / drop some columns to make more sense to students - trials = trials[[ - 'initial_image_name', - 'change_image_name', - 'change_time', - 'lick_times', - 'response_latency', - 'reward_time', - 'go', - 'catch', - 'hit', - 'miss', - 'false_alarm', - 'correct_reject', - 'aborted', - 'auto_rewarded', - 'reward_volume', - 'start_time', - 'stop_time', - 'trial_length' - ]] + trials = trials[ + [ + "initial_image_name", + "change_image_name", + "change_time", + "lick_times", + "response_latency", + "reward_time", + "go", + "catch", + "hit", + "miss", + "false_alarm", + "correct_reject", + "aborted", + "auto_rewarded", + "reward_volume", + "start_time", + "stop_time", + "trial_length", + ] + ] # Calculate reward rate per trial - trials['reward_rate'] = calculate_reward_rate( + trials["reward_rate"] = calculate_reward_rate( response_latency=trials.response_latency, starttime=trials.start_time, - window=.75, + window=0.75, trial_window=25, - initial_trials=10 + initial_trials=10, ) # Response_binary is just whether or not they responded - e.g. true # for hit or FA. - hit = trials['hit'].values - fa = trials['false_alarm'].values - trials['response_binary'] = np.logical_or(hit, fa) + hit = trials["hit"].values + fa = trials["false_alarm"].values + trials["response_binary"] = np.logical_or(hit, fa) return trials def get_stimulus_presentations(self): - stimulus_presentations = super(ExtendedNwbApi, - self).get_stimulus_presentations() - extended_stimulus_presentations = \ - self.get_extended_stimulus_presentations_df() - extended_stimulus_presentations = extended_stimulus_presentations.drop( - columns=['omitted']) - stimulus_presentations = stimulus_presentations.join( - extended_stimulus_presentations) + stimulus_presentations = super(ExtendedNwbApi, self).get_stimulus_presentations() + extended_stimulus_presentations = self.get_extended_stimulus_presentations_df() + extended_stimulus_presentations = extended_stimulus_presentations.drop(columns=["omitted"]) + stimulus_presentations = stimulus_presentations.join(extended_stimulus_presentations) # Reorder the columns returned to make more sense to students - stimulus_presentations = stimulus_presentations[[ - 'image_name', - 'image_index', - 'start_time', - 'stop_time', - 'omitted', - 'change', - 'duration', - 'licks', - 'rewards', - 'running_speed', - 'index', - 'time_from_last_lick', - 'time_from_last_reward', - 'time_from_last_change', - 'block_index', - 'image_block_repetition', - 'repeat_within_block', - 'image_set' - ]] + stimulus_presentations = stimulus_presentations[ + [ + "image_name", + "image_index", + "start_time", + "stop_time", + "omitted", + "change", + "duration", + "licks", + "rewards", + "running_speed", + "index", + "time_from_last_lick", + "time_from_last_reward", + "time_from_last_change", + "block_index", + "image_block_repetition", + "repeat_within_block", + "image_set", + ] + ] # Rename some columns to make more sense to students stimulus_presentations = stimulus_presentations.rename( - columns={'index': 'absolute_flash_number', - 'running_speed': 'mean_running_speed'}) + columns={"index": "absolute_flash_number", "running_speed": "mean_running_speed"} + ) # Replace image set with A/B - stimulus_presentations['image_set'] = \ - self.get_task_parameters()['stage'][15] + stimulus_presentations["image_set"] = self.get_task_parameters()["stage"][15] # Change index name for easier merge with flash_response_df - stimulus_presentations.index.rename('flash_id', inplace=True) + stimulus_presentations.index.rename("flash_id", inplace=True) return stimulus_presentations def get_stimulus_templates(self): # super stim templates is a dict with one annoyingly-long key, # so pop the val out - stimulus_templates = super(ExtendedNwbApi, - self).get_stimulus_templates() - stimulus_template_array = stimulus_templates[ - list(stimulus_templates.keys())[0]] + stimulus_templates = super(ExtendedNwbApi, self).get_stimulus_templates() + stimulus_template_array = stimulus_templates[list(stimulus_templates.keys())[0]] # What we really want is a dict with image_name as key template_dict = {} image_index_names = self.get_image_index_names() for image_index, image_name in image_index_names.items(): - if image_name != 'omitted': - template_dict.update( - {image_name: stimulus_template_array[image_index, :, :]}) + if image_name != "omitted": + template_dict.update({image_name: stimulus_template_array[image_index, :, :]}) return template_dict def get_licks(self): # Licks column 'time' should be 'timestamps' to be consistent with # rest of session licks = super(ExtendedNwbApi, self).get_licks() - licks = licks.rename(columns={'time': 'timestamps'}) + licks = licks.rename(columns={"time": "timestamps"}) return licks def get_rewards(self): @@ -447,13 +402,14 @@ def get_dff_traces(self): # This is just for Friday Harbor, not for eventual inclusion in the # LIMS api. dff_traces = super(ExtendedNwbApi, self).get_dff_traces() - dff_traces = dff_traces.drop(columns=['cell_roi_id']) + dff_traces = dff_traces.drop(columns=["cell_roi_id"]) return dff_traces def get_image_index_names(self): - image_index_names = self.get_stimulus_presentations().groupby( - 'image_index').apply( - lambda group: one(group['image_name'].unique()) + image_index_names = ( + self.get_stimulus_presentations() + .groupby("image_index") + .apply(lambda group: one(group["image_name"].unique())) ) return image_index_names @@ -533,8 +489,7 @@ def __init__(self, api): def get_roi_masks(self): masks = super(ExtendedBehaviorOphysExperiment, self).get_roi_masks() return { - cell_specimen_id: masks.loc[ - {"cell_specimen_id": cell_specimen_id}].data + cell_specimen_id: masks.loc[{"cell_specimen_id": cell_specimen_id}].data for cell_specimen_id in masks["cell_specimen_id"].data } @@ -545,5 +500,4 @@ def get_segmentation_mask_image(self): if __name__ == "__main__": cache = BehaviorProjectCache(cache_path_example) - session = cache.get_session( - cache.experiment_table.iloc[0]['ophys_experiment_id']) + session = cache.get_session(cache.experiment_table.iloc[0]["ophys_experiment_id"]) diff --git a/allensdk/brain_observatory/behavior/swdb/create_multi_session_df.py b/allensdk/brain_observatory/behavior/swdb/create_multi_session_df.py index 72c84546a8..d3e6c74193 100644 --- a/allensdk/brain_observatory/behavior/swdb/create_multi_session_df.py +++ b/allensdk/brain_observatory/behavior/swdb/create_multi_session_df.py @@ -3,22 +3,27 @@ from allensdk.brain_observatory.behavior.swdb import utilities as ut -cache_json = {'manifest_path': '/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/visual_behavior_data_manifest.csv', - 'nwb_base_dir': '/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/nwb_files', - 'analysis_files_base_dir': '/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/analysis_files', - 'analysis_files_metadata_path': '/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/analysis_files_metadata.json', - } +cache_json = { + "manifest_path": "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/visual_behavior_data_manifest.csv", + "nwb_base_dir": "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/nwb_files", + "analysis_files_base_dir": "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/analysis_files", + "analysis_files_metadata_path": "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/analysis_files_metadata.json", +} cache = bpc.BehaviorProjectCache(cache_json) manifest = cache.manifest experiment_ids = manifest.ophys_experiment_id.unique() -print('generating mega_trial_mdf') -mega_trial_mdf = ut.create_multi_session_mean_df(cache, experiment_ids, conditions=['cell_specimen_id','change_image_name']) -save_dir = r'/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019' -mega_trial_mdf.to_hdf(os.path.join(save_dir, 'multi_session_mean_trials_df.h5'), key='df') -print('done with trials, creating mega_flash_mdf') -mega_flash_mdf = ut.create_multi_session_mean_df(cache, experiment_ids, flashes=True, conditions=['cell_specimen_id','image_name']) -save_dir = r'/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019' -mega_flash_mdf.to_hdf(os.path.join(save_dir, 'multi_session_mean_flashes_df.h5'), key='df') -print('done with flash df') \ No newline at end of file +print("generating mega_trial_mdf") +mega_trial_mdf = ut.create_multi_session_mean_df( + cache, experiment_ids, conditions=["cell_specimen_id", "change_image_name"] +) +save_dir = r"/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019" +mega_trial_mdf.to_hdf(os.path.join(save_dir, "multi_session_mean_trials_df.h5"), key="df") +print("done with trials, creating mega_flash_mdf") +mega_flash_mdf = ut.create_multi_session_mean_df( + cache, experiment_ids, flashes=True, conditions=["cell_specimen_id", "image_name"] +) +save_dir = r"/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019" +mega_flash_mdf.to_hdf(os.path.join(save_dir, "multi_session_mean_flashes_df.h5"), key="df") +print("done with flash df") diff --git a/allensdk/brain_observatory/behavior/swdb/run_multi_session_df.py b/allensdk/brain_observatory/behavior/swdb/run_multi_session_df.py index c9b191fa48..f3cc622a8a 100644 --- a/allensdk/brain_observatory/behavior/swdb/run_multi_session_df.py +++ b/allensdk/brain_observatory/behavior/swdb/run_multi_session_df.py @@ -1,25 +1,27 @@ import sys -sys.path.append('/allen/programs/braintv/workgroups/nc-ophys/nick.ponvert/src/pbstools') + +sys.path.append("/allen/programs/braintv/workgroups/nc-ophys/nick.ponvert/src/pbstools") from pbstools import PythonJob # python_file = r"/allen/programs/braintv/workgroups/nc-ophys/nick.ponvert/src/AllenSDK/allensdk/brain_observatory/behavior/swdb/summary_figures.py" python_file = r"/home/marinag/AllenSDK/allensdk/brain_observatory/behavior/swdb/create_multi_session_df.py" -jobdir = '/allen/programs/braintv/workgroups/nc-ophys/nick.ponvert/cluster_jobs/visb_swdb_summary_figures' +jobdir = "/allen/programs/braintv/workgroups/nc-ophys/nick.ponvert/cluster_jobs/visb_swdb_summary_figures" -job_settings = {'queue': 'braintv', - 'mem': '100g', - 'walltime': '2:00:00', - 'ppn':1, - 'jobdir': jobdir, - } +job_settings = { + "queue": "braintv", + "mem": "100g", + "walltime": "2:00:00", + "ppn": 1, + "jobdir": jobdir, +} PythonJob( python_file, - python_executable = '/home/marinag/anaconda2/envs/visual_behavior_sdk/bin/python', - python_args = None, - conda_env = None, - jobname = 'multi_session_dfs', - **job_settings - ).run(dryrun=False) + python_executable="/home/marinag/anaconda2/envs/visual_behavior_sdk/bin/python", + python_args=None, + conda_env=None, + jobname="multi_session_dfs", + **job_settings, +).run(dryrun=False) diff --git a/allensdk/brain_observatory/behavior/swdb/run_save_extended_stimulus_presentations_df.py b/allensdk/brain_observatory/behavior/swdb/run_save_extended_stimulus_presentations_df.py index 19332a63c8..537378d276 100644 --- a/allensdk/brain_observatory/behavior/swdb/run_save_extended_stimulus_presentations_df.py +++ b/allensdk/brain_observatory/behavior/swdb/run_save_extended_stimulus_presentations_df.py @@ -1,34 +1,37 @@ import sys -sys.path.append('/allen/programs/braintv/workgroups/nc-ophys/nick.ponvert/src/pbstools') -from pbstools import PythonJob + +sys.path.append("/allen/programs/braintv/workgroups/nc-ophys/nick.ponvert/src/pbstools") +from pbstools import PythonJob import behavior_project_cache as bpc python_file = r"/allen/programs/braintv/workgroups/nc-ophys/nick.ponvert/src/AllenSDK/allensdk/brain_observatory/behavior/swdb/save_extended_stimulus_presentations_df.py" -jobdir = '/allen/programs/braintv/workgroups/nc-ophys/nick.ponvert/cluster_jobs/20190813_save_extended_stim' +jobdir = "/allen/programs/braintv/workgroups/nc-ophys/nick.ponvert/cluster_jobs/20190813_save_extended_stim" -job_settings = {'queue': 'braintv', - 'mem': '15g', - 'walltime': '0:30:00', - 'ppn':1, - 'jobdir': jobdir, - } +job_settings = { + "queue": "braintv", + "mem": "15g", + "walltime": "0:30:00", + "ppn": 1, + "jobdir": jobdir, +} -cache_json = {'manifest_path': '/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/visual_behavior_data_manifest.csv', - 'nwb_base_dir': '/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/nwb_files', - 'analysis_files_base_dir': '/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/extra_files_final' - } +cache_json = { + "manifest_path": "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/visual_behavior_data_manifest.csv", + "nwb_base_dir": "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/nwb_files", + "analysis_files_base_dir": "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/extra_files_final", +} cache = bpc.BehaviorProjectCache(cache_json) -experiment_ids = cache.manifest['ophys_experiment_id'].values +experiment_ids = cache.manifest["ophys_experiment_id"].values for experiment_id in experiment_ids: PythonJob( python_file, - python_executable = '/home/nick.ponvert/anaconda3/envs/allen/bin/python', - python_args = experiment_id, - conda_env = None, - jobname = 'extended_stimulus_df_{}'.format(experiment_id), - **job_settings + python_executable="/home/nick.ponvert/anaconda3/envs/allen/bin/python", + python_args=experiment_id, + conda_env=None, + jobname="extended_stimulus_df_{}".format(experiment_id), + **job_settings, ).run(dryrun=False) diff --git a/allensdk/brain_observatory/behavior/swdb/run_save_flash_response_df.py b/allensdk/brain_observatory/behavior/swdb/run_save_flash_response_df.py index cc3747608c..3f1e99e0a2 100644 --- a/allensdk/brain_observatory/behavior/swdb/run_save_flash_response_df.py +++ b/allensdk/brain_observatory/behavior/swdb/run_save_flash_response_df.py @@ -1,34 +1,37 @@ import sys -sys.path.append('/allen/programs/braintv/workgroups/nc-ophys/nick.ponvert/src/pbstools') -from pbstools import PythonJob + +sys.path.append("/allen/programs/braintv/workgroups/nc-ophys/nick.ponvert/src/pbstools") +from pbstools import PythonJob import behavior_project_cache as bpc python_file = r"/allen/programs/braintv/workgroups/nc-ophys/nick.ponvert/src/AllenSDK/allensdk/brain_observatory/behavior/swdb/save_flash_response_df.py" -jobdir = '/allen/programs/braintv/workgroups/nc-ophys/nick.ponvert/cluster_jobs/20190810_save_flash_response_df' +jobdir = "/allen/programs/braintv/workgroups/nc-ophys/nick.ponvert/cluster_jobs/20190810_save_flash_response_df" -job_settings = {'queue': 'braintv', - 'mem': '24g', - 'walltime': '12:00:00', - 'ppn':1, - 'jobdir': jobdir, - } +job_settings = { + "queue": "braintv", + "mem": "24g", + "walltime": "12:00:00", + "ppn": 1, + "jobdir": jobdir, +} -cache_json = {'manifest_path': '/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/visual_behavior_data_manifest.csv', - 'nwb_base_dir': '/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/nwb_files', - 'analysis_files_base_dir': '/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/extra_files' - } +cache_json = { + "manifest_path": "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/visual_behavior_data_manifest.csv", + "nwb_base_dir": "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/nwb_files", + "analysis_files_base_dir": "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/extra_files", +} cache = bpc.BehaviorProjectCache(cache_json) -experiment_ids = cache.manifest['ophys_experiment_id'].values +experiment_ids = cache.manifest["ophys_experiment_id"].values for experiment_id in experiment_ids: PythonJob( python_file, - python_executable = '/home/nick.ponvert/anaconda3/envs/allen/bin/python', - python_args = experiment_id, - conda_env = None, - jobname = 'flash_response_df_{}'.format(experiment_id), - **job_settings + python_executable="/home/nick.ponvert/anaconda3/envs/allen/bin/python", + python_args=experiment_id, + conda_env=None, + jobname="flash_response_df_{}".format(experiment_id), + **job_settings, ).run(dryrun=False) diff --git a/allensdk/brain_observatory/behavior/swdb/run_save_trial_response_df.py b/allensdk/brain_observatory/behavior/swdb/run_save_trial_response_df.py index 026da87553..f2239fc450 100644 --- a/allensdk/brain_observatory/behavior/swdb/run_save_trial_response_df.py +++ b/allensdk/brain_observatory/behavior/swdb/run_save_trial_response_df.py @@ -1,34 +1,37 @@ import sys -sys.path.append('/allen/programs/braintv/workgroups/nc-ophys/nick.ponvert/src/pbstools') -from pbstools import PythonJob + +sys.path.append("/allen/programs/braintv/workgroups/nc-ophys/nick.ponvert/src/pbstools") +from pbstools import PythonJob import behavior_project_cache as bpc python_file = r"/allen/programs/braintv/workgroups/nc-ophys/nick.ponvert/src/AllenSDK/allensdk/brain_observatory/behavior/swdb/save_trial_response_df.py" -jobdir = '/allen/programs/braintv/workgroups/nc-ophys/nick.ponvert/cluster_jobs/20190810_save_trial_response_df' +jobdir = "/allen/programs/braintv/workgroups/nc-ophys/nick.ponvert/cluster_jobs/20190810_save_trial_response_df" -job_settings = {'queue': 'braintv', - 'mem': '15g', - 'walltime': '0:30:00', - 'ppn':1, - 'jobdir': jobdir, - } +job_settings = { + "queue": "braintv", + "mem": "15g", + "walltime": "0:30:00", + "ppn": 1, + "jobdir": jobdir, +} -cache_json = {'manifest_path': '/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/visual_behavior_data_manifest.csv', - 'nwb_base_dir': '/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/nwb_files', - 'analysis_files_base_dir': '/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/extra_files' - } +cache_json = { + "manifest_path": "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/visual_behavior_data_manifest.csv", + "nwb_base_dir": "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/nwb_files", + "analysis_files_base_dir": "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/extra_files", +} cache = bpc.BehaviorProjectCache(cache_json) -experiment_ids = cache.manifest['ophys_experiment_id'].values +experiment_ids = cache.manifest["ophys_experiment_id"].values for experiment_id in experiment_ids: PythonJob( python_file, - python_executable = '/home/nick.ponvert/anaconda3/envs/allen/bin/python', - python_args = experiment_id, - conda_env = None, - jobname = 'trial_response_df_{}'.format(experiment_id), - **job_settings + python_executable="/home/nick.ponvert/anaconda3/envs/allen/bin/python", + python_args=experiment_id, + conda_env=None, + jobname="trial_response_df_{}".format(experiment_id), + **job_settings, ).run(dryrun=False) diff --git a/allensdk/brain_observatory/behavior/swdb/run_summary_figures.py b/allensdk/brain_observatory/behavior/swdb/run_summary_figures.py index 6d7c04d2c7..7047b86722 100644 --- a/allensdk/brain_observatory/behavior/swdb/run_summary_figures.py +++ b/allensdk/brain_observatory/behavior/swdb/run_summary_figures.py @@ -1,6 +1,7 @@ import sys -sys.path.append('/allen/programs/braintv/workgroups/nc-ophys/nick.ponvert/src/pbstools') -from pbstools import PythonJob + +sys.path.append("/allen/programs/braintv/workgroups/nc-ophys/nick.ponvert/src/pbstools") +from pbstools import PythonJob import behavior_project_cache as bpc # python_file = r"/allen/programs/braintv/workgroups/nc-ophys/nick.ponvert/src/AllenSDK/allensdk/brain_observatory/behavior/swdb/summary_figures.py" @@ -8,32 +9,34 @@ python_file = r"/home/marinag/AllenSDK/allensdk/brain_observatory/behavior/swdb/summary_figures.py" # python_file = r"/home/nick.ponvert/src/AllenSDK/allensdk/brain_observatory/behavior/swdb/summary_figures.py" -jobdir = '/allen/programs/braintv/workgroups/nc-ophys/nick.ponvert/cluster_jobs/visb_swdb_summary_figures' +jobdir = "/allen/programs/braintv/workgroups/nc-ophys/nick.ponvert/cluster_jobs/visb_swdb_summary_figures" -job_settings = {'queue': 'braintv', - 'mem': '15g', - 'walltime': '0:30:00', - 'ppn':1, - 'jobdir': jobdir, - } +job_settings = { + "queue": "braintv", + "mem": "15g", + "walltime": "0:30:00", + "ppn": 1, + "jobdir": jobdir, +} -cache_json = {'manifest_path': '/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/cache_20190813/visual_behavior_data_manifest.csv', - 'nwb_base_dir': '/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/cache_20190813/nwb_files', - 'analysis_files_base_dir': '/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/cache_20190813/analysis_files', - 'analysis_files_metadata_path': '/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/cache_20190813/analysis_files_metadata.json', - } +cache_json = { + "manifest_path": "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/cache_20190813/visual_behavior_data_manifest.csv", + "nwb_base_dir": "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/cache_20190813/nwb_files", + "analysis_files_base_dir": "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/cache_20190813/analysis_files", + "analysis_files_metadata_path": "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/cache_20190813/analysis_files_metadata.json", +} cache = bpc.BehaviorProjectCache(cache_json) -experiment_ids = cache.manifest['ophys_experiment_id'].values +experiment_ids = cache.manifest["ophys_experiment_id"].values for experiment_id in experiment_ids: PythonJob( python_file, - python_args = experiment_id, - python_executable = '/home/marinag/anaconda2/envs/visual_behavior_sdk/bin/python', - conda_env = None, - jobname = 'trial_response_df_{}'.format(experiment_id), - **job_settings + python_args=experiment_id, + python_executable="/home/marinag/anaconda2/envs/visual_behavior_sdk/bin/python", + conda_env=None, + jobname="trial_response_df_{}".format(experiment_id), + **job_settings, ).run(dryrun=False) # python_executable = '/home/nick.ponvert/anaconda3/envs/allen/bin/python', diff --git a/allensdk/brain_observatory/behavior/swdb/save_extended_stimulus_presentations_df.py b/allensdk/brain_observatory/behavior/swdb/save_extended_stimulus_presentations_df.py index acf6f85f90..9cf0c30e7f 100644 --- a/allensdk/brain_observatory/behavior/swdb/save_extended_stimulus_presentations_df.py +++ b/allensdk/brain_observatory/behavior/swdb/save_extended_stimulus_presentations_df.py @@ -2,10 +2,8 @@ import os import numpy as np -from allensdk.brain_observatory.behavior.behavior_ophys_experiment import ( - BehaviorOphysExperiment) -from allensdk.brain_observatory.behavior.session_apis.data_io import ( - BehaviorOphysNwbApi) +from allensdk.brain_observatory.behavior.behavior_ophys_experiment import BehaviorOphysExperiment +from allensdk.brain_observatory.behavior.session_apis.data_io import BehaviorOphysNwbApi import behavior_project_cache as bpc from importlib import reload @@ -24,8 +22,7 @@ def time_from_last(flash_times, other_times): def trace_average(values, timestamps, start_time, stop_time): - values_this_range = values[ - ((timestamps >= start_time) & (timestamps < stop_time))] + values_this_range = values[((timestamps >= start_time) & (timestamps < stop_time))] return values_this_range.mean() @@ -37,7 +34,7 @@ def trace_average(values, timestamps, start_time, stop_time): def find_change(image_index, omitted_index): - ''' + """ Args: image_index (pd.Series): The index of the presented image for each flash @@ -45,11 +42,10 @@ def find_change(image_index, omitted_index): Returns: change (np.array of bool): Whether each flash was a change flash - ''' + """ change = np.diff(image_index) != 0 - change = np.concatenate( - [np.array([False]), change]) # First flash not a change + change = np.concatenate([np.array([False]), change]) # First flash not a change omitted = image_index == omitted_index omitted_inds = np.flatnonzero(omitted) change[omitted_inds] = False @@ -91,9 +87,9 @@ def get_extended_stimulus_presentations(session): intermediate_df["time_from_last_change"] = time_from_last_change # Was the flash a change flash? - omitted_index = intermediate_df.groupby("image_name").apply( - lambda group: group["image_index"].unique()[0] - )["omitted"] + omitted_index = intermediate_df.groupby("image_name").apply(lambda group: group["image_index"].unique()[0])[ + "omitted" + ] changes = find_change(intermediate_df["image_index"], omitted_index) omitted = intermediate_df["image_index"] == omitted_index @@ -105,15 +101,12 @@ def get_extended_stimulus_presentations(session): changes_including_first[0] = True change_indices = np.flatnonzero(changes_including_first) flash_inds = np.arange(len(intermediate_df)) - block_inds = np.searchsorted(a=change_indices, v=flash_inds, - side="right") - 1 + block_inds = np.searchsorted(a=change_indices, v=flash_inds, side="right") - 1 intermediate_df["block_index"] = block_inds # Block repetition number - blocks_per_image = intermediate_df.groupby("image_name").apply( - lambda group: np.unique(group["block_index"]) - ) + blocks_per_image = intermediate_df.groupby("image_name").apply(lambda group: np.unique(group["block_index"])) block_repetition_number = np.copy(block_inds) for image_name, image_blocks in blocks_per_image.items(): @@ -121,8 +114,7 @@ def get_extended_stimulus_presentations(session): for ind_block, block_number in enumerate(image_blocks): # block_rep_number starts as a copy of block_inds, so we can # go write over the index number with the rep number - block_repetition_number[ - block_repetition_number == block_number] = ind_block + block_repetition_number[block_repetition_number == block_number] = ind_block intermediate_df["image_block_repetition"] = block_repetition_number @@ -143,19 +135,11 @@ def get_extended_stimulus_presentations(session): # Lists of licks/rewards on each flash licks_each_flash = intermediate_df.apply( - lambda row: lick_times[ - ((lick_times > row["start_time"]) & ( - lick_times < row["start_time"] + 0.75)) - ], + lambda row: lick_times[((lick_times > row["start_time"]) & (lick_times < row["start_time"] + 0.75))], axis=1, ) rewards_each_flash = intermediate_df.apply( - lambda row: reward_times[ - ( - (reward_times > row["start_time"]) - & (reward_times < row["start_time"] + 0.75) - ) - ], + lambda row: reward_times[((reward_times > row["start_time"]) & (reward_times < row["start_time"] + 0.75))], axis=1, ) @@ -196,24 +180,19 @@ def get_extended_stimulus_presentations(session): if __name__ == "__main__": - case = 0 cache_json = { "manifest_path": "/allen/programs/braintv/workgroups/nc-ophys" - "/visual_behavior/SWDB_2019/" - "visual_behavior_data_manifest.csv", - "nwb_base_dir": "/allen/programs/braintv/workgroups/nc-ophys" - "/visual_behavior/SWDB_2019/nwb_files", - "analysis_files_base_dir": - "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior" - "/SWDB_2019/extra_files", + "/visual_behavior/SWDB_2019/" + "visual_behavior_data_manifest.csv", + "nwb_base_dir": "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/nwb_files", + "analysis_files_base_dir": "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/extra_files", } cache = bpc.BehaviorProjectCache(cache_json) if case == 0: - experiment_id = sys.argv[1] # experiment_id = cache.manifest.iloc[5]['ophys_experiment_id'] nwb_path = cache.get_nwb_filepath(experiment_id) @@ -222,22 +201,15 @@ def get_extended_stimulus_presentations(session): # output_path = "/allen/programs/braintv/workgroups/nc-ophys # /visual_behavior/SWDB_2019/extra_files_final" - output_path = "/allen/programs/braintv/workgroups/nc-ophys" \ - "/visual_behavior/SWDB_2019/corrected_extended_stim" - - extended_stimulus_presentations_df = \ - get_extended_stimulus_presentations(session) - - output_fn = os.path.join( - output_path, - "extended_stimulus_presentations_df_{}.h5".format(experiment_id) - ) - print("Writing extended_stimulus_presentations_df to {}".format( - output_fn)) + output_path = "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/corrected_extended_stim" + + extended_stimulus_presentations_df = get_extended_stimulus_presentations(session) + + output_fn = os.path.join(output_path, "extended_stimulus_presentations_df_{}.h5".format(experiment_id)) + print("Writing extended_stimulus_presentations_df to {}".format(output_fn)) extended_stimulus_presentations_df.to_hdf(output_fn, key="df") elif case == 1: - failed_oeid = 825623170 success_oeid = 826585773 @@ -246,5 +218,4 @@ def get_extended_stimulus_presentations(session): api = BehaviorOphysNwbApi(nwb_path, filter_invalid_rois=True) session = BehaviorOphysExperiment(api) - extended_stimulus_presentations_df = \ - get_extended_stimulus_presentations(session) + extended_stimulus_presentations_df = get_extended_stimulus_presentations(session) diff --git a/allensdk/brain_observatory/behavior/swdb/save_flash_response_df.py b/allensdk/brain_observatory/behavior/swdb/save_flash_response_df.py index 30fe156265..4b79697214 100644 --- a/allensdk/brain_observatory/behavior/swdb/save_flash_response_df.py +++ b/allensdk/brain_observatory/behavior/swdb/save_flash_response_df.py @@ -4,132 +4,116 @@ import pandas as pd import itertools -from allensdk.brain_observatory.behavior.behavior_ophys_experiment import \ - BehaviorOphysExperiment -from allensdk.brain_observatory.behavior.session_apis.data_io import ( - BehaviorOphysNwbApi) -from allensdk.brain_observatory.behavior.swdb import \ - behavior_project_cache as bpc -from allensdk.brain_observatory.behavior.swdb.analysis_tools import \ - get_trace_around_timepoint, get_mean_in_window - -''' +from allensdk.brain_observatory.behavior.behavior_ophys_experiment import BehaviorOphysExperiment +from allensdk.brain_observatory.behavior.session_apis.data_io import BehaviorOphysNwbApi +from allensdk.brain_observatory.behavior.swdb import behavior_project_cache as bpc +from allensdk.brain_observatory.behavior.swdb.analysis_tools import get_trace_around_timepoint, get_mean_in_window + +""" This script computes the flash_response_df for a BehaviorOphysExperiment object -''' +""" def get_flash_response_df(session, response_analysis_params): - ''' - Builds the flash response dataframe for - - INPUTS: - BehaviorOphysExperiment to build the flash response - dataframe for - A dictionary with the following keys - 'window_around_timepoint_seconds' is the time window to save - out the dff_trace around the flash onset. - 'response_window_duration_seconds' is the length of time - after the flash onset to compute the mean_response - 'baseline_window_duration_seconds' is the length of time - before the flash onset to compute the baseline response - - OUTPUTS: - A dataframe with index: (cell_specimen_id, flash_id) - and columns: - cell_roi_id, the cell's roi id for that session - mean_response, the mean df/f in the response_window - baseline_response, the mean df/f in the baseline_window - dff_trace, the dff trace in the window_around_timepoint_seconds - dff_trace_timestamps, the timestamps for the dff_trace - - ''' - frame_rate = 31. # Shouldn't hard code this here + """ + Builds the flash response dataframe for + + INPUTS: + BehaviorOphysExperiment to build the flash response + dataframe for + A dictionary with the following keys + 'window_around_timepoint_seconds' is the time window to save + out the dff_trace around the flash onset. + 'response_window_duration_seconds' is the length of time + after the flash onset to compute the mean_response + 'baseline_window_duration_seconds' is the length of time + before the flash onset to compute the baseline response + + OUTPUTS: + A dataframe with index: (cell_specimen_id, flash_id) + and columns: + cell_roi_id, the cell's roi id for that session + mean_response, the mean df/f in the response_window + baseline_response, the mean df/f in the baseline_window + dff_trace, the dff trace in the window_around_timepoint_seconds + dff_trace_timestamps, the timestamps for the dff_trace + + """ + frame_rate = 31.0 # Shouldn't hard code this here # get data to analyze dff_traces = session.dff_traces.copy() flashes = session.stimulus_presentations.copy() # get params to define response window, in seconds - window_around_timepoint_seconds = response_analysis_params[ - 'window_around_timepoint_seconds'] - response_window_duration_seconds = response_analysis_params[ - 'response_window_duration_seconds'] - baseline_window_duration_seconds = response_analysis_params[ - 'baseline_window_duration_seconds'] - mean_response_window_seconds = [np.abs(window_around_timepoint_seconds[0]), - np.abs(window_around_timepoint_seconds[ - 0]) + - response_window_duration_seconds] - baseline_window_seconds = [np.abs( - window_around_timepoint_seconds[0]) - baseline_window_duration_seconds, - np.abs(window_around_timepoint_seconds[0])] + window_around_timepoint_seconds = response_analysis_params["window_around_timepoint_seconds"] + response_window_duration_seconds = response_analysis_params["response_window_duration_seconds"] + baseline_window_duration_seconds = response_analysis_params["baseline_window_duration_seconds"] + mean_response_window_seconds = [ + np.abs(window_around_timepoint_seconds[0]), + np.abs(window_around_timepoint_seconds[0]) + response_window_duration_seconds, + ] + baseline_window_seconds = [ + np.abs(window_around_timepoint_seconds[0]) - baseline_window_duration_seconds, + np.abs(window_around_timepoint_seconds[0]), + ] # Build a dataframe with multiindex defined as product of cell_id X # flash_id - cell_flash_combinations = itertools.product(dff_traces.index, - flashes.index) - index = pd.MultiIndex.from_tuples(cell_flash_combinations, - names=['cell_specimen_id', 'flash_id']) + cell_flash_combinations = itertools.product(dff_traces.index, flashes.index) + index = pd.MultiIndex.from_tuples(cell_flash_combinations, names=["cell_specimen_id", "flash_id"]) df = pd.DataFrame(index=index) traces_list = [] trace_timestamps_list = [] # Iterate though cell/flash pairs and build table - for cell_specimen_id, flash_id in itertools.product(dff_traces.index, - flashes.index): - timepoint = flashes.loc[flash_id]['start_time'] - cell_roi_id = dff_traces.loc[cell_specimen_id]['cell_roi_id'] - full_cell_trace = dff_traces.loc[cell_specimen_id, 'dff'] + for cell_specimen_id, flash_id in itertools.product(dff_traces.index, flashes.index): + timepoint = flashes.loc[flash_id]["start_time"] + cell_roi_id = dff_traces.loc[cell_specimen_id]["cell_roi_id"] + full_cell_trace = dff_traces.loc[cell_specimen_id, "dff"] trace, trace_timestamps = get_trace_around_timepoint( - full_cell_trace, - timepoint, - session.ophys_timestamps, - window_around_timepoint_seconds, - frame_rate) - mean_response = get_mean_in_window(trace, mean_response_window_seconds, - frame_rate) - baseline_response = get_mean_in_window(trace, baseline_window_seconds, - frame_rate) + full_cell_trace, timepoint, session.ophys_timestamps, window_around_timepoint_seconds, frame_rate + ) + mean_response = get_mean_in_window(trace, mean_response_window_seconds, frame_rate) + baseline_response = get_mean_in_window(trace, baseline_window_seconds, frame_rate) traces_list.append(trace) trace_timestamps_list.append(trace_timestamps) - df.loc[(cell_specimen_id, flash_id), 'cell_roi_id'] = int(cell_roi_id) - df.loc[(cell_specimen_id, flash_id), 'mean_response'] = mean_response - df.loc[(cell_specimen_id, - flash_id), 'baseline_response'] = baseline_response - df.insert(loc=1, column='dff_trace', value=traces_list) - df.insert(loc=2, column='dff_trace_timestamps', - value=trace_timestamps_list) + df.loc[(cell_specimen_id, flash_id), "cell_roi_id"] = int(cell_roi_id) + df.loc[(cell_specimen_id, flash_id), "mean_response"] = mean_response + df.loc[(cell_specimen_id, flash_id), "baseline_response"] = baseline_response + df.insert(loc=1, column="dff_trace", value=traces_list) + df.insert(loc=2, column="dff_trace_timestamps", value=trace_timestamps_list) return df -def get_p_values_from_shuffled_spontaneous(session, flash_response_df, - response_window_duration=0.5, - number_of_shuffles=10000): - ''' - Computes the P values for each cell/flash. The P value is the - probability of observing a response of that - magnitude in the spontaneous window. The algorithm is copied from VBA - - INPUTS: - a BehaviorOphysExperiment object - the flash_response_df for this session - is the duration of the - response_window that was used to compute the mean_response in - the flash_response_df. This is used here to extract an - equivalent duration df/f trace from the spontaneous timepoint - the number of shuffles of spontaneous - activity used to compute the pvalue - - OUTPUTS: - fdf, a copy of the flash_response_df with a new column appended - 'p_value' which is the per-flash X per-cell p-value - - ASSERTS: - each p value is bounded by 0 and 1, and does not include any NaNs - - ''' +def get_p_values_from_shuffled_spontaneous( + session, flash_response_df, response_window_duration=0.5, number_of_shuffles=10000 +): + """ + Computes the P values for each cell/flash. The P value is the + probability of observing a response of that + magnitude in the spontaneous window. The algorithm is copied from VBA + + INPUTS: + a BehaviorOphysExperiment object + the flash_response_df for this session + is the duration of the + response_window that was used to compute the mean_response in + the flash_response_df. This is used here to extract an + equivalent duration df/f trace from the spontaneous timepoint + the number of shuffles of spontaneous + activity used to compute the pvalue + + OUTPUTS: + fdf, a copy of the flash_response_df with a new column appended + 'p_value' which is the per-flash X per-cell p-value + + ASSERTS: + each p value is bounded by 0 and 1, and does not include any NaNs + + """ # Organize Data fdf = flash_response_df.copy() st = session.stimulus_presentations.copy() @@ -141,16 +125,14 @@ def get_p_values_from_shuffled_spontaneous(session, flash_response_df, # Compute the number of response_window frames ophys_frame_rate = 31 # Shouldn't hard code this here - n_mean_response_window_frames = int( - np.round(response_window_duration * ophys_frame_rate, 0)) + n_mean_response_window_frames = int(np.round(response_window_duration * ophys_frame_rate, 0)) cell_ids = np.unique(fdf.index.get_level_values(0)) n_cells = len(cell_ids) # Get Shuffled responses from spontaneous frames # get mean response for shuffles of the spontaneous activity frames # in a window the same size as the stim response window duration - shuffled_responses = np.empty( - (n_cells, number_of_shuffles, n_mean_response_window_frames)) + shuffled_responses = np.empty((n_cells, number_of_shuffles, n_mean_response_window_frames)) idx = np.random.choice(spontaneous_frames, number_of_shuffles) dff_traces = np.stack(session.dff_traces.to_numpy()[:, 1], axis=0) for i in range(n_mean_response_window_frames): @@ -160,42 +142,38 @@ def get_p_values_from_shuffled_spontaneous(session, flash_response_df, # compare flash responses to shuffled values and make a dataframe of # p_value for cell_id X flash_id iterables = [cell_ids, st.index.values] - flash_p_values = pd.DataFrame(index=pd.MultiIndex.from_product( - iterables, - names=[ - 'cell_specimen_id', - 'flash_id'])) + flash_p_values = pd.DataFrame(index=pd.MultiIndex.from_product(iterables, names=["cell_specimen_id", "flash_id"])) for i, cell_index in enumerate(cell_ids): responses = fdf.loc[cell_index].mean_response.values null_dist_mat = np.tile(shuffled_mean[i, :], reps=(len(responses), 1)) actual_is_less = responses.reshape(len(responses), 1) <= null_dist_mat p_values = np.mean(actual_is_less, axis=1) for j in range(0, len(p_values)): - flash_p_values.at[(cell_index, j), 'p_value'] = p_values[j] + flash_p_values.at[(cell_index, j), "p_value"] = p_values[j] fdf = pd.concat([fdf, flash_p_values], axis=1) # Test to ensure p values are bounded between 0 and 1, and dont include # NaNs - assert np.all(fdf['p_value'].values <= 1) - assert np.all(fdf['p_value'].values >= 0) - assert np.all(~np.isnan(fdf['p_value'].values)) + assert np.all(fdf["p_value"].values <= 1) + assert np.all(fdf["p_value"].values >= 0) + assert np.all(~np.isnan(fdf["p_value"].values)) return fdf def get_spontaneous_frames(session): - ''' - Returns a list of the frames that occur during the before and after - spontaneous windows. This is copied from VBA. Does not use the full - spontaneous period because that is what VBA did. It only uses 4 - minutes of the before and after spontaneous period. - - INPUTS: - a BehaviorOphysExperiment object to get all the - spontaneous frames - - OUTPUTS: a list of the frames during the spontaneous period - ''' + """ + Returns a list of the frames that occur during the before and after + spontaneous windows. This is copied from VBA. Does not use the full + spontaneous period because that is what VBA did. It only uses 4 + minutes of the before and after spontaneous period. + + INPUTS: + a BehaviorOphysExperiment object to get all the + spontaneous frames + + OUTPUTS: a list of the frames during the spontaneous period + """ st = session.stimulus_presentations.copy() # dont use full 5 mins to avoid fingerprint and countdown # spont_duration_frames = 4 * 60 * 60 # 4 mins * * 60s/min * 60Hz @@ -205,35 +183,28 @@ def get_spontaneous_frames(session): behavior_start_time = st.iloc[0].start_time spontaneous_start_time_pre = behavior_start_time - spont_duration spontaneous_end_time_pre = behavior_start_time - spontaneous_start_frame_pre = get_successive_frame_list( - spontaneous_start_time_pre, session.ophys_timestamps) - spontaneous_end_frame_pre = get_successive_frame_list( - spontaneous_end_time_pre, session.ophys_timestamps) - spontaneous_frames_pre = np.arange(spontaneous_start_frame_pre, - spontaneous_end_frame_pre, 1) + spontaneous_start_frame_pre = get_successive_frame_list(spontaneous_start_time_pre, session.ophys_timestamps) + spontaneous_end_frame_pre = get_successive_frame_list(spontaneous_end_time_pre, session.ophys_timestamps) + spontaneous_frames_pre = np.arange(spontaneous_start_frame_pre, spontaneous_end_frame_pre, 1) # for spontaneous epoch at end of session behavior_end_time = st.iloc[-1].stop_time spontaneous_start_time_post = behavior_end_time + 0.5 spontaneous_end_time_post = behavior_end_time + spont_duration - spontaneous_start_frame_post = get_successive_frame_list( - spontaneous_start_time_post, session.ophys_timestamps) - spontaneous_end_frame_post = get_successive_frame_list( - spontaneous_end_time_post, session.ophys_timestamps) - spontaneous_frames_post = np.arange(spontaneous_start_frame_post, - spontaneous_end_frame_post, 1) + spontaneous_start_frame_post = get_successive_frame_list(spontaneous_start_time_post, session.ophys_timestamps) + spontaneous_end_frame_post = get_successive_frame_list(spontaneous_end_time_post, session.ophys_timestamps) + spontaneous_frames_post = np.arange(spontaneous_start_frame_post, spontaneous_end_frame_post, 1) # add them together - spontaneous_frames = list(spontaneous_frames_pre) + ( - list(spontaneous_frames_post)) + spontaneous_frames = list(spontaneous_frames_pre) + (list(spontaneous_frames_post)) return spontaneous_frames def get_successive_frame_list(timepoints_array, timestamps): - ''' - Returns the next frame after timestamps in timepoints_array - copied from VBA - ''' + """ + Returns the next frame after timestamps in timepoints_array + copied from VBA + """ # This is a modification of get_nearest_frame for speedup # This implementation looks for the first 2p frame consecutive to the stim successive_frames = np.searchsorted(timestamps, timepoints_array) @@ -241,112 +212,105 @@ def get_successive_frame_list(timepoints_array, timestamps): def add_image_name(session, fdf): - ''' - Adds a column to flash_response_df with the image_name taken from - the stimulus_presentations table - Slow to run, could probably be improved with some more intelligent - use of pandas + """ + Adds a column to flash_response_df with the image_name taken from + the stimulus_presentations table + Slow to run, could probably be improved with some more intelligent + use of pandas - INPUTS: - a BehaviorOphysExperiment object - a flash_response_df for this session + INPUTS: + a BehaviorOphysExperiment object + a flash_response_df for this session - OUTPUTS: - fdf, with a new column appended 'image_name' which gives the image - identity (like 'im066') for each flash. - ''' + OUTPUTS: + fdf, with a new column appended 'image_name' which gives the image + identity (like 'im066') for each flash. + """ fdf = fdf.reset_index() - fdf = fdf.set_index('flash_id') - fdf['image_name'] = '' + fdf = fdf.set_index("flash_id") + fdf["image_name"] = "" # So slow!!! for stim_id in np.unique(fdf.index.values): - fdf.loc[stim_id, 'image_name'] = session.stimulus_presentations.loc[ - stim_id].image_name + fdf.loc[stim_id, "image_name"] = session.stimulus_presentations.loc[stim_id].image_name fdf = fdf.reset_index() - fdf = fdf.set_index(['cell_specimen_id', 'flash_id']) + fdf = fdf.set_index(["cell_specimen_id", "flash_id"]) return fdf def annotate_flash_response_df_with_pref_stim(fdf): - ''' - Adds a column to flash_response_df with a boolean value of whether - that flash was that cells pref image. - Computes preferred image by looking for the image that on average - evokes the largest response. - Slow to run, could probably be improved with more intelligent pandas - use + """ + Adds a column to flash_response_df with a boolean value of whether + that flash was that cells pref image. + Computes preferred image by looking for the image that on average + evokes the largest response. + Slow to run, could probably be improved with more intelligent pandas + use - INPUTS: - fdf, a flash_response_dataframe + INPUTS: + fdf, a flash_response_dataframe - RETURNS: - fdf, appended with 'pref_stim' column + RETURNS: + fdf, appended with 'pref_stim' column - ASSERTS: - each cell has one unique preferred_stimulus + ASSERTS: + each cell has one unique preferred_stimulus - ''' + """ # Prepare dataframe fdf = fdf.reset_index() - if 'cell_specimen_id' in fdf.keys(): - cell_key = 'cell_specimen_id' + if "cell_specimen_id" in fdf.keys(): + cell_key = "cell_specimen_id" else: - cell_key = 'cell' + cell_key = "cell" # Set up empty column - fdf['pref_stim'] = False + fdf["pref_stim"] = False # Compute average response for each image - mean_response = fdf.groupby([cell_key, 'image_name']).apply(get_mean_sem) + mean_response = fdf.groupby([cell_key, "image_name"]).apply(get_mean_sem) m = mean_response.unstack() # Iterate through each cell and find which image evoked the largest # average response for cell in m.index: - temp = np.where(m.loc[cell]['mean_response'].values == np.nanmax( - m.loc[cell]['mean_response'].values))[0] + temp = np.where(m.loc[cell]["mean_response"].values == np.nanmax(m.loc[cell]["mean_response"].values))[0] # If the mean_response was NaN, then temp is empty, so we have this # check here if len(temp) > 0: image_index = temp[0] - pref_image = m.loc[cell]['mean_response'].index[image_index] + pref_image = m.loc[cell]["mean_response"].index[image_index] # find all repeats of that cell X pref_image, and set # 'pref_stim' to True - cell_flash_pairs = fdf[ - (fdf[cell_key] == cell) & (fdf.image_name == pref_image)].index - fdf.loc[cell_flash_pairs, 'pref_stim'] = True + cell_flash_pairs = fdf[(fdf[cell_key] == cell) & (fdf.image_name == pref_image)].index + fdf.loc[cell_flash_pairs, "pref_stim"] = True # Test to ensure preferred stimulus is unique for each cell - for cell in fdf['cell_specimen_id'].unique(): - assert len(fdf.set_index('cell_specimen_id').loc[cell].query( - 'pref_stim').image_name.unique()) == 1 + for cell in fdf["cell_specimen_id"].unique(): + assert len(fdf.set_index("cell_specimen_id").loc[cell].query("pref_stim").image_name.unique()) == 1 # Reset the df index - fdf = fdf.set_index(['cell_specimen_id', 'flash_id']) + fdf = fdf.set_index(["cell_specimen_id", "flash_id"]) return fdf def get_mean_sem(group): - ''' - Returns the mean and sem of the mean_response values for all entries - in the group. Copied from VBA + """ + Returns the mean and sem of the mean_response values for all entries + in the group. Copied from VBA - INPUTS: - group is a pandas group + INPUTS: + group is a pandas group - Output, a pandas series with the average 'mean_response' from the - group, and the sem 'mean_response' from the group - ''' - mean_response = np.mean(group['mean_response']) - sem_response = np.std(group['mean_response'].values) / np.sqrt( - len(group['mean_response'].values)) - return pd.Series( - {'mean_response': mean_response, 'sem_response': sem_response}) + Output, a pandas series with the average 'mean_response' from the + group, and the sem 'mean_response' from the group + """ + mean_response = np.mean(group["mean_response"]) + sem_response = np.std(group["mean_response"].values) / np.sqrt(len(group["mean_response"].values)) + return pd.Series({"mean_response": mean_response, "sem_response": sem_response}) -if __name__ == '__main__': - +if __name__ == "__main__": case = 0 if case == 0: @@ -357,14 +321,12 @@ def get_mean_sem(group): # Define the cache cache_json = { - 'manifest_path': '/allen/programs/braintv/workgroups/nc-ophys' - '/visual_behavior/SWDB_2019/' - 'visual_behavior_data_manifest.csv', - 'nwb_base_dir': '/allen/programs/braintv/workgroups/nc-ophys' - '/visual_behavior/SWDB_2019/nwb_files', - 'analysis_files_base_dir': - '/allen/programs/braintv/workgroups/nc-ophys/visual_behavior' - '/SWDB_2019/extra_files' + "manifest_path": "/allen/programs/braintv/workgroups/nc-ophys" + "/visual_behavior/SWDB_2019/" + "visual_behavior_data_manifest.csv", + "nwb_base_dir": "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/nwb_files", + "analysis_files_base_dir": "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior" + "/SWDB_2019/extra_files", } # load the session @@ -374,73 +336,78 @@ def get_mean_sem(group): session = BehaviorOphysExperiment(api) # Where to save the results - output_path = '/allen/programs/braintv/workgroups/nc-ophys' \ - '/visual_behavior/SWDB_2019/' \ - 'flash_response_500msec_response' + output_path = ( + "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/flash_response_500msec_response" + ) # Define parameters for dff_trace, and response_window response_analysis_params = { - 'window_around_timepoint_seconds': [-.5, .75], # -500ms, 750ms - 'response_window_duration_seconds': 0.5, - 'baseline_window_duration_seconds': 0.5} + "window_around_timepoint_seconds": [-0.5, 0.75], # -500ms, 750ms + "response_window_duration_seconds": 0.5, + "baseline_window_duration_seconds": 0.5, + } # compute the base flash_response_df - flash_response_df = get_flash_response_df(session, - response_analysis_params) + flash_response_df = get_flash_response_df(session, response_analysis_params) # Add p_value, image_name, and pref_stim - flash_response_df = get_p_values_from_shuffled_spontaneous( - session, - flash_response_df) + flash_response_df = get_p_values_from_shuffled_spontaneous(session, flash_response_df) flash_response_df = add_image_name(session, flash_response_df) - flash_response_df = annotate_flash_response_df_with_pref_stim( - flash_response_df) + flash_response_df = annotate_flash_response_df_with_pref_stim(flash_response_df) # Test columns in flash_response_df - for new_key in ['cell_roi_id', 'mean_response', 'baseline_response', - 'dff_trace', 'dff_trace_timestamps', 'p_value', - 'image_name', 'pref_stim']: + for new_key in [ + "cell_roi_id", + "mean_response", + "baseline_response", + "dff_trace", + "dff_trace_timestamps", + "p_value", + "image_name", + "pref_stim", + ]: assert new_key in flash_response_df.keys() # Save the flash_response_df to file - output_fn = os.path.join(output_path, 'flash_response_df_{}.h5'.format( - experiment_id)) - print('Writing flash response df to {}'.format(output_fn)) - flash_response_df.to_hdf(output_fn, key='df', complib='bzip2', - complevel=9) + output_fn = os.path.join(output_path, "flash_response_df_{}.h5".format(experiment_id)) + print("Writing flash response df to {}".format(output_fn)) + flash_response_df.to_hdf(output_fn, key="df", complib="bzip2", complevel=9) elif case == 1: # This case is just for debugging. It computes the flash_response_df # on a truncated portion of the data. - nwb_path = '/allen/programs/braintv/workgroups/nc-ophys' \ - '/visual_behavior/SWDB_2019/nwb_files' \ - '/behavior_ophys_session_880961028.nwb' + nwb_path = ( + "/allen/programs/braintv/workgroups/nc-ophys" + "/visual_behavior/SWDB_2019/nwb_files" + "/behavior_ophys_session_880961028.nwb" + ) api = BehaviorOphysNwbApi(nwb_path, filter_invalid_rois=True) session = BehaviorOphysExperiment(api) # Small data for testing - session.__dict__['dff_traces'].value = session.dff_traces.iloc[:5] - session.__dict__[ - 'stimulus_presentations'].value = \ - session.stimulus_presentations.iloc[ - :20] + session.__dict__["dff_traces"].value = session.dff_traces.iloc[:5] + session.__dict__["stimulus_presentations"].value = session.stimulus_presentations.iloc[:20] response_analysis_params = { - 'window_around_timepoint_seconds': [-.5, .75], # -500ms, 750ms - 'response_window_duration_seconds': 0.5, - 'baseline_window_duration_seconds': 0.5} - - flash_response_df = get_flash_response_df(session, - response_analysis_params) - flash_response_df = get_p_values_from_shuffled_spontaneous( - session, - flash_response_df) + "window_around_timepoint_seconds": [-0.5, 0.75], # -500ms, 750ms + "response_window_duration_seconds": 0.5, + "baseline_window_duration_seconds": 0.5, + } + + flash_response_df = get_flash_response_df(session, response_analysis_params) + flash_response_df = get_p_values_from_shuffled_spontaneous(session, flash_response_df) flash_response_df = add_image_name(session, flash_response_df) - flash_response_df = annotate_flash_response_df_with_pref_stim( - flash_response_df) + flash_response_df = annotate_flash_response_df_with_pref_stim(flash_response_df) # Test columns in flash_response_df - for new_key in ['cell_roi_id', 'mean_response', 'baseline_response', - 'dff_trace', 'dff_trace_timestamps', 'p_value', - 'image_name', 'pref_stim']: + for new_key in [ + "cell_roi_id", + "mean_response", + "baseline_response", + "dff_trace", + "dff_trace_timestamps", + "p_value", + "image_name", + "pref_stim", + ]: assert new_key in flash_response_df.keys() diff --git a/allensdk/brain_observatory/behavior/swdb/save_trial_response_df.py b/allensdk/brain_observatory/behavior/swdb/save_trial_response_df.py index faa750f7c7..f541108c92 100644 --- a/allensdk/brain_observatory/behavior/swdb/save_trial_response_df.py +++ b/allensdk/brain_observatory/behavior/swdb/save_trial_response_df.py @@ -5,196 +5,192 @@ from scipy import stats import itertools -from allensdk.brain_observatory.behavior.swdb import \ - behavior_project_cache as bpc +from allensdk.brain_observatory.behavior.swdb import behavior_project_cache as bpc from importlib import reload -from allensdk.brain_observatory.behavior.swdb.analysis_tools import \ - get_trace_around_timepoint, get_mean_in_window +from allensdk.brain_observatory.behavior.swdb.analysis_tools import get_trace_around_timepoint, get_mean_in_window reload(bpc) -''' +""" This file contains functions and a script for computing the trial_response_df dataframe. This file was hastily constructed before friday harbor. Places where there are known issues are flagged with PROBLEM -''' +""" def add_p_vals_tr(tr, response_window=[4, 4.5]): - ''' - Computes the p value for each cell's response on each trial. The - p-value is computed using the function 'get_p_val' - - INPUT: - tr, trial_response_dataframe - response_window, the time points in the dff trace to use for - computing the p-value. - PROBLEM: The default value here assumes that the - dff_trace starts 4 seconds before the change time. - This should be set up with more care and flexibility. - - OUTPUTS: - tr, the same trial_response_dataframe, with a new column 'p_value' - appended. - - ASSERTS: - tr['p_value'] is inclusively bounded between 0 and 1, and does not - include NaNs - ''' + """ + Computes the p value for each cell's response on each trial. The + p-value is computed using the function 'get_p_val' + + INPUT: + tr, trial_response_dataframe + response_window, the time points in the dff trace to use for + computing the p-value. + PROBLEM: The default value here assumes that the + dff_trace starts 4 seconds before the change time. + This should be set up with more care and flexibility. + + OUTPUTS: + tr, the same trial_response_dataframe, with a new column 'p_value' + appended. + + ASSERTS: + tr['p_value'] is inclusively bounded between 0 and 1, and does not + include NaNs + """ # Set up empty column - tr['p_value'] = 1. - ophys_frame_rate = 31. # Shouldn't hard code this PROBLEM + tr["p_value"] = 1.0 + ophys_frame_rate = 31.0 # Shouldn't hard code this PROBLEM # Iterate over trial/cell pairs, and compute p-value for index, row in tr.iterrows(): - tr.at[index, 'p_value'] = get_p_val(row.dff_trace, response_window, - ophys_frame_rate) + tr.at[index, "p_value"] = get_p_val(row.dff_trace, response_window, ophys_frame_rate) # Test to ensure p values are bounded between 0 and 1, and dont include # NaNs - assert np.all(tr['p_value'].values <= 1) - assert np.all(tr['p_value'].values >= 0) - assert np.all(~np.isnan(tr['p_value'].values)) + assert np.all(tr["p_value"].values <= 1) + assert np.all(tr["p_value"].values >= 0) + assert np.all(~np.isnan(tr["p_value"].values)) return tr def get_p_val(trace, response_window, frame_rate): - ''' - Computes a p-value for the trace by comparing the dff in the - response_window to the same sized trace before the response_window. - PROBLEM: This should be computed by comparing to spontaneous - activity to be consistent with the flash_response_df - - INPUTS: - trace, the dff trace for this cell/trial - response_window, [start_time, end_time] the time in seconds from the - start of trace to asses whether the activity is significant - frame_rate, the number of samples in trace per second. - - OUTPUTS: - a p-value - ''' + """ + Computes a p-value for the trace by comparing the dff in the + response_window to the same sized trace before the response_window. + PROBLEM: This should be computed by comparing to spontaneous + activity to be consistent with the flash_response_df + + INPUTS: + trace, the dff trace for this cell/trial + response_window, [start_time, end_time] the time in seconds from the + start of trace to asses whether the activity is significant + frame_rate, the number of samples in trace per second. + + OUTPUTS: + a p-value + """ response_window_duration = response_window[1] - response_window[0] baseline_end = int(response_window[0] * frame_rate) - baseline_start = int( - (response_window[0] - response_window_duration) * frame_rate) + baseline_start = int((response_window[0] - response_window_duration) * frame_rate) stim_start = int(response_window[0] * frame_rate) - stim_end = int( - (response_window[0] + response_window_duration) * frame_rate) - (_, p) = stats.f_oneway(trace[baseline_start:baseline_end], - trace[stim_start:stim_end]) + stim_end = int((response_window[0] + response_window_duration) * frame_rate) + (_, p) = stats.f_oneway(trace[baseline_start:baseline_end], trace[stim_start:stim_end]) return p def annotate_trial_response_df_with_pref_stim(trial_response_df): - ''' - Computes the preferred stimulus for each cell/trial combination. - Preferred image is computed by seeing which image evoked the largest - average mean_response across all change_images. + """ + Computes the preferred stimulus for each cell/trial combination. + Preferred image is computed by seeing which image evoked the largest + average mean_response across all change_images. - INPUTS: - trial_response_df, the trial_response_df to be annotated + INPUTS: + trial_response_df, the trial_response_df to be annotated - OUTPUTS: - a copy of trial_response_df with a new column appended 'pref_stim' - which is a boolean TRUE/FALSE for whether that change_image was that - cell's preferred image. + OUTPUTS: + a copy of trial_response_df with a new column appended 'pref_stim' + which is a boolean TRUE/FALSE for whether that change_image was that + cell's preferred image. - ASSERTS: - Each cell has one unique preferred stimulus - ''' + ASSERTS: + Each cell has one unique preferred stimulus + """ # Copy the trial_response_df rdf = trial_response_df.copy() # Set up empty column - rdf['pref_stim'] = False + rdf["pref_stim"] = False # get average mean_response for each cell X change_image - mean_response = rdf.groupby( - ['cell_specimen_id', 'change_image_name']).apply(get_mean_sem_trace) + mean_response = rdf.groupby(["cell_specimen_id", "change_image_name"]).apply(get_mean_sem_trace) m = mean_response.unstack() # set index to be cell/image pairs rdf = rdf.reset_index() - rdf = rdf.set_index(['cell_specimen_id', 'change_image_name']) + rdf = rdf.set_index(["cell_specimen_id", "change_image_name"]) # Iterate through cells, and determine which change_image evoked the # largest response for cell in m.index: - image_index = np.where(m.loc[cell]['mean_response'].values == np.max( - m.loc[cell]['mean_response'].values))[0][0] - pref_image = m.loc[cell]['mean_response'].index[image_index] + image_index = np.where(m.loc[cell]["mean_response"].values == np.max(m.loc[cell]["mean_response"].values))[0][0] + pref_image = m.loc[cell]["mean_response"].index[image_index] # Update the cell X change_image pairs to have the pref_stim set to # True - rdf.at[(cell, pref_image), 'pref_stim'] = True + rdf.at[(cell, pref_image), "pref_stim"] = True # Test to ensure preferred stimulus is unique for each cell - for cell in rdf.reset_index()['cell_specimen_id'].unique(): - assert len( - rdf.reset_index().set_index('cell_specimen_id').loc[cell].query( - 'pref_stim').change_image_name.unique()) == 1 + for cell in rdf.reset_index()["cell_specimen_id"].unique(): + assert ( + len(rdf.reset_index().set_index("cell_specimen_id").loc[cell].query("pref_stim").change_image_name.unique()) + == 1 + ) # Reset index to be cell/trial pairs rdf = rdf.reset_index() - rdf = rdf.set_index(['cell_specimen_id', 'trial_id']) + rdf = rdf.set_index(["cell_specimen_id", "trial_id"]) return rdf def get_mean_sem_trace(group): - ''' - Computes the average and sem of the mean_response column - - INPUTS: - group, a pandas group - - OUTPUT: - a pandas series with the mean_response, sem_response, mean_trace, - sem_trace, and mean_responses computed for the group. - ''' - mean_response = np.mean(group['mean_response']) - mean_responses = group['mean_response'].values - sem_response = np.std(group['mean_response'].values) / np.sqrt( - len(group['mean_response'].values)) - mean_trace = np.mean(group['dff_trace']) - sem_trace = np.std(group['dff_trace'].values) / np.sqrt( - len(group['dff_trace'].values)) + """ + Computes the average and sem of the mean_response column + + INPUTS: + group, a pandas group + + OUTPUT: + a pandas series with the mean_response, sem_response, mean_trace, + sem_trace, and mean_responses computed for the group. + """ + mean_response = np.mean(group["mean_response"]) + mean_responses = group["mean_response"].values + sem_response = np.std(group["mean_response"].values) / np.sqrt(len(group["mean_response"].values)) + mean_trace = np.mean(group["dff_trace"]) + sem_trace = np.std(group["dff_trace"].values) / np.sqrt(len(group["dff_trace"].values)) return pd.Series( - {'mean_response': mean_response, 'sem_response': sem_response, - 'mean_trace': mean_trace, 'sem_trace': sem_trace, - 'mean_responses': mean_responses}) + { + "mean_response": mean_response, + "sem_response": sem_response, + "mean_trace": mean_trace, + "sem_trace": sem_trace, + "mean_responses": mean_responses, + } + ) def get_trial_response_df(session, response_analysis_params): - ''' - Computes the trial_response_df for the session - PROBLEM: Ignores aborted trials - - INPUTS: - session, a behaviorOphysSession object to be analyzed - response_analysis_params, a dictionary with keys: - 'window_around_timepoint_seconds' The window around the - change_time to use in the dff trace - 'response_window_duration_seconds' The duration after the - change time to use in the mean_response - 'baseline_window_duration_seconds' The duration before the - change time to use as the baseline_response - - OUTPUTS: - trial_response_df, a pandas dataframe with multi-index ( - cell_specimen_id/trial_id), and columns: - cell_roi_id, this sessions roi_id - mean_response, the average dff in the response_window - baseline_response, the average dff in the baseline window - dff_trace, the dff_trace in the window_around_timepoint_seconds - dff_trace_timestamps, the timestamps for the dff_trace - ''' - frame_rate = 31. # PROBLEM, shouldnt hard code this here + """ + Computes the trial_response_df for the session + PROBLEM: Ignores aborted trials + + INPUTS: + session, a behaviorOphysSession object to be analyzed + response_analysis_params, a dictionary with keys: + 'window_around_timepoint_seconds' The window around the + change_time to use in the dff trace + 'response_window_duration_seconds' The duration after the + change time to use in the mean_response + 'baseline_window_duration_seconds' The duration before the + change time to use as the baseline_response + + OUTPUTS: + trial_response_df, a pandas dataframe with multi-index ( + cell_specimen_id/trial_id), and columns: + cell_roi_id, this sessions roi_id + mean_response, the average dff in the response_window + baseline_response, the average dff in the baseline window + dff_trace, the dff_trace in the window_around_timepoint_seconds + dff_trace_timestamps, the timestamps for the dff_trace + """ + frame_rate = 31.0 # PROBLEM, shouldnt hard code this here # get data to analyze dff_traces = session.dff_traces.copy() @@ -202,68 +198,55 @@ def get_trial_response_df(session, response_analysis_params): trials = trials[~trials.aborted] # PROBLEM # get params to define response window, in seconds - window_around_timepoint_seconds = response_analysis_params[ - 'window_around_timepoint_seconds'] - response_window_duration_seconds = response_analysis_params[ - 'response_window_duration_seconds'] - baseline_window_duration_seconds = response_analysis_params[ - 'baseline_window_duration_seconds'] - mean_response_window_seconds = [np.abs(window_around_timepoint_seconds[0]), - np.abs(window_around_timepoint_seconds[ - 0]) + - response_window_duration_seconds] - baseline_window_seconds = [np.abs( - window_around_timepoint_seconds[0]) - baseline_window_duration_seconds, - np.abs(window_around_timepoint_seconds[0])] + window_around_timepoint_seconds = response_analysis_params["window_around_timepoint_seconds"] + response_window_duration_seconds = response_analysis_params["response_window_duration_seconds"] + baseline_window_duration_seconds = response_analysis_params["baseline_window_duration_seconds"] + mean_response_window_seconds = [ + np.abs(window_around_timepoint_seconds[0]), + np.abs(window_around_timepoint_seconds[0]) + response_window_duration_seconds, + ] + baseline_window_seconds = [ + np.abs(window_around_timepoint_seconds[0]) - baseline_window_duration_seconds, + np.abs(window_around_timepoint_seconds[0]), + ] # Set up multi-index dataframe cell_trial_combinations = itertools.product(dff_traces.index, trials.index) - index = pd.MultiIndex.from_tuples(cell_trial_combinations, - names=['cell_specimen_id', 'trial_id']) + index = pd.MultiIndex.from_tuples(cell_trial_combinations, names=["cell_specimen_id", "trial_id"]) df = pd.DataFrame(index=index) # Iterate through cell/trial pairs, and construct the columns traces_list = [] trace_timestamps_list = [] - for cell_specimen_id, trial_id in itertools.product(dff_traces.index, - trials.index): - timepoint = trials.loc[trial_id]['change_time'] - cell_roi_id = dff_traces.loc[cell_specimen_id]['cell_roi_id'] - full_cell_trace = dff_traces.loc[cell_specimen_id, 'dff'] + for cell_specimen_id, trial_id in itertools.product(dff_traces.index, trials.index): + timepoint = trials.loc[trial_id]["change_time"] + cell_roi_id = dff_traces.loc[cell_specimen_id]["cell_roi_id"] + full_cell_trace = dff_traces.loc[cell_specimen_id, "dff"] trace, trace_timestamps = get_trace_around_timepoint( - full_cell_trace, - timepoint, - session.ophys_timestamps, - window_around_timepoint_seconds, - frame_rate) - mean_response = get_mean_in_window(trace, mean_response_window_seconds, - frame_rate) - baseline_response = get_mean_in_window(trace, baseline_window_seconds, - frame_rate) + full_cell_trace, timepoint, session.ophys_timestamps, window_around_timepoint_seconds, frame_rate + ) + mean_response = get_mean_in_window(trace, mean_response_window_seconds, frame_rate) + baseline_response = get_mean_in_window(trace, baseline_window_seconds, frame_rate) traces_list.append(trace) trace_timestamps_list.append(trace_timestamps) - df.loc[(cell_specimen_id, trial_id), 'cell_roi_id'] = int(cell_roi_id) - df.loc[(cell_specimen_id, trial_id), 'mean_response'] = mean_response - df.loc[(cell_specimen_id, - trial_id), 'baseline_response'] = baseline_response - df.insert(loc=1, column='dff_trace', value=traces_list) - df.insert(loc=2, column='dff_trace_timestamps', - value=trace_timestamps_list) + df.loc[(cell_specimen_id, trial_id), "cell_roi_id"] = int(cell_roi_id) + df.loc[(cell_specimen_id, trial_id), "mean_response"] = mean_response + df.loc[(cell_specimen_id, trial_id), "baseline_response"] = baseline_response + df.insert(loc=1, column="dff_trace", value=traces_list) + df.insert(loc=2, column="dff_trace_timestamps", value=trace_timestamps_list) return df -if __name__ == '__main__': +if __name__ == "__main__": # Load cache cache_json = { - 'manifest_path': '/allen/programs/braintv/workgroups/nc-ophys' - '/visual_behavior/SWDB_2019/' - 'visual_behavior_data_manifest.csv', - 'nwb_base_dir': '/allen/programs/braintv/workgroups/nc-ophys' - '/visual_behavior/SWDB_2019/nwb_files', - 'analysis_files_base_dir': - '/allen/programs/braintv/workgroups/nc-ophys/visual_behavior' - '/SWDB_2019/extra_files_final' + "manifest_path": "/allen/programs/braintv/workgroups/nc-ophys" + "/visual_behavior/SWDB_2019/" + "visual_behavior_data_manifest.csv", + "nwb_base_dir": "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/nwb_files", + "analysis_files_base_dir": "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior" + "/SWDB_2019/extra_files_final", } cache = bpc.BehaviorProjectCache(cache_json) @@ -281,33 +264,29 @@ def get_trial_response_df(session, response_analysis_params): # Get the session using the cache so that the change time fix is # applied session = cache.get_session(experiment_id) - change_times = session.trials['change_time'][ - ~pd.isnull(session.trials['change_time'])].values - flash_times = session.stimulus_presentations['start_time'].values + change_times = session.trials["change_time"][~pd.isnull(session.trials["change_time"])].values + flash_times = session.stimulus_presentations["start_time"].values assert np.all(np.isin(change_times, flash_times)) - output_path = '/allen/programs/braintv/workgroups/nc-ophys' \ - '/visual_behavior/SWDB_2019/extra_files_final' + output_path = "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/extra_files_final" - response_analysis_params = {'window_around_timepoint_seconds': [-4, 8], - 'response_window_duration_seconds': 0.5, - 'baseline_window_duration_seconds': 0.5} + response_analysis_params = { + "window_around_timepoint_seconds": [-4, 8], + "response_window_duration_seconds": 0.5, + "baseline_window_duration_seconds": 0.5, + } - trial_response_df = get_trial_response_df(session, - response_analysis_params) + trial_response_df = get_trial_response_df(session, response_analysis_params) trial_metadata = session.trials.copy() - trial_metadata.index.names = ['trial_id'] + trial_metadata.index.names = ["trial_id"] trial_response_df = trial_response_df.join(trial_metadata) trial_response_df = add_p_vals_tr(trial_response_df) - trial_response_df = annotate_trial_response_df_with_pref_stim( - trial_response_df) + trial_response_df = annotate_trial_response_df_with_pref_stim(trial_response_df) - output_fn = os.path.join(output_path, 'trial_response_df_{}.h5'.format( - experiment_id)) - print('Writing trial response df to {}'.format(output_fn)) - trial_response_df.to_hdf(output_fn, key='df', complib='bzip2', - complevel=9) + output_fn = os.path.join(output_path, "trial_response_df_{}.h5".format(experiment_id)) + print("Writing trial response df to {}".format(output_fn)) + trial_response_df.to_hdf(output_fn, key="df", complib="bzip2", complevel=9) elif case == 1: # This is a debugging case @@ -315,24 +294,22 @@ def get_trial_response_df(session, response_analysis_params): session = cache.get_session(experiment_id) - change_times = session.trials['change_time'][ - ~pd.isnull(session.trials['change_time'])].values - flash_times = session.stimulus_presentations['start_time'].values + change_times = session.trials["change_time"][~pd.isnull(session.trials["change_time"])].values + flash_times = session.stimulus_presentations["start_time"].values assert np.all(np.isin(change_times, flash_times)) - output_path = '/allen/programs/braintv/workgroups/nc-ophys' \ - '/visual_behavior/SWDB_2019/extra_files_final' + output_path = "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/extra_files_final" - response_analysis_params = {'window_around_timepoint_seconds': [-4, 8], - 'response_window_duration_seconds': 0.5, - 'baseline_window_duration_seconds': 0.5} + response_analysis_params = { + "window_around_timepoint_seconds": [-4, 8], + "response_window_duration_seconds": 0.5, + "baseline_window_duration_seconds": 0.5, + } - trial_response_df = get_trial_response_df(session, - response_analysis_params) + trial_response_df = get_trial_response_df(session, response_analysis_params) trial_metadata = session.trials.copy() - trial_metadata.index.names = ['trial_id'] + trial_metadata.index.names = ["trial_id"] trial_response_df = trial_response_df.join(trial_metadata) trial_response_df = add_p_vals_tr(trial_response_df) - trial_response_df = annotate_trial_response_df_with_pref_stim( - trial_response_df) + trial_response_df = annotate_trial_response_df_with_pref_stim(trial_response_df) diff --git a/allensdk/brain_observatory/behavior/swdb/summary_figures.py b/allensdk/brain_observatory/behavior/swdb/summary_figures.py index 57cd522a45..319c929f7b 100644 --- a/allensdk/brain_observatory/behavior/swdb/summary_figures.py +++ b/allensdk/brain_observatory/behavior/swdb/summary_figures.py @@ -3,9 +3,9 @@ import matplotlib.pyplot as plt import seaborn as sns -sns.set_context('notebook', font_scale=1.5, rc={'lines.markeredgewidth': 2}) -sns.set_style('white') -sns.set_palette('deep') +sns.set_context("notebook", font_scale=1.5, rc={"lines.markeredgewidth": 2}) +sns.set_style("white") +sns.set_palette("deep") from allensdk.brain_observatory.behavior.swdb import behavior_project_cache as bpc from allensdk.brain_observatory.behavior.swdb import utilities as ut @@ -13,18 +13,18 @@ def get_color_for_image_name(session, image_name): images = np.sort(session.stimulus_presentations.image_name.unique()) - images = images[images != 'omitted'] + images = images[images != "omitted"] colors = sns.color_palette("hls", len(images)) image_index = np.where(images == image_name)[0][0] color = colors[image_index] return color -def addSpan(ax, amin, amax, color='k', alpha=0.3, axtype='x', zorder=1): - if axtype == 'x': - ax.axvspan(amin, amax, facecolor=color, edgecolor='none', alpha=alpha, linewidth=0, zorder=zorder) - if axtype == 'y': - ax.axhspan(amin, amax, facecolor=color, edgecolor='none', alpha=alpha, linewidth=0, zorder=zorder) +def addSpan(ax, amin, amax, color="k", alpha=0.3, axtype="x", zorder=1): + if axtype == "x": + ax.axvspan(amin, amax, facecolor=color, edgecolor="none", alpha=alpha, linewidth=0, zorder=zorder) + if axtype == "y": + ax.axhspan(amin, amax, facecolor=color, edgecolor="none", alpha=alpha, linewidth=0, zorder=zorder) def add_stim_color_span(session, ax, xlim=None): @@ -34,12 +34,12 @@ def add_stim_color_span(session, ax, xlim=None): else: stim_table = session.stimulus_presentations.copy() stim_table = stim_table[(stim_table.start_time >= xlim[0]) & (stim_table.stop_time <= xlim[1])] - if 'omitted' in stim_table.keys(): + if "omitted" in stim_table.keys(): stim_table = stim_table[~stim_table.omitted].copy() for idx in stim_table.index: - start_time = stim_table.loc[idx]['start_time'] - end_time = stim_table.loc[idx]['stop_time'] - image_name = stim_table.loc[idx]['image_name'] + start_time = stim_table.loc[idx]["start_time"] + end_time = stim_table.loc[idx]["stop_time"] + image_name = stim_table.loc[idx]["image_name"] color = get_color_for_image_name(session, image_name) addSpan(ax, start_time, end_time, color=color) return ax @@ -60,9 +60,16 @@ def plot_behavior_events(session, ax, behavior_only=False): lick_y_array[:] = lick_y reward_y_array = np.empty(len(reward_times)) reward_y_array[:] = reward_y - ax.plot(lick_times, lick_y_array, '|', color='g', markeredgewidth=1, label='licks') - ax.plot(reward_times, reward_y_array, 'o', markerfacecolor='purple', markeredgecolor='purple', markeredgewidth=0.1, - label='rewards') + ax.plot(lick_times, lick_y_array, "|", color="g", markeredgewidth=1, label="licks") + ax.plot( + reward_times, + reward_y_array, + "o", + markerfacecolor="purple", + markeredgecolor="purple", + markeredgewidth=0.1, + label="rewards", + ) return ax @@ -83,12 +90,13 @@ def plot_behavior_events_trace(session, xmin=360, length=3, ax=None, save_dir=No ax = add_stim_color_span(session, ax, xlim=[xmin, xmax]) ax = plot_behavior_events(session, ax) ax = restrict_axes(xmin, xmax, interval, ax) - ax.set_ylabel('running speed (cm/s)') - ax.set_xlabel('time (sec)') + ax.set_ylabel("running speed (cm/s)") + ax.set_xlabel("time (sec)") if save_dir: fig.tight_layout() - ut.save_figure(fig, figsize, save_dir, 'behavior_events', - str(session.metadata['ophys_experiment_id']) + '_' + str(xmin)) + ut.save_figure( + fig, figsize, save_dir, "behavior_events", str(session.metadata["ophys_experiment_id"]) + "_" + str(xmin) + ) plt.close() return ax @@ -98,14 +106,14 @@ def plot_traces_heatmap(session, ax=None): dff_traces_array = np.vstack(dff_traces.dff.values) if ax is None: fig, ax = plt.subplots(figsize=(20, 5)) - cax = ax.pcolormesh(dff_traces_array, cmap='magma', vmin=0, vmax=np.percentile(dff_traces_array, 99)) + cax = ax.pcolormesh(dff_traces_array, cmap="magma", vmin=0, vmax=np.percentile(dff_traces_array, 99)) ax.set_yticks(np.arange(0, len(dff_traces_array)), 10) - ax.set_ylabel('cells') - ax.set_xlabel('time (sec)') - ax.set_xticks(np.arange(0, len(session.ophys_timestamps), 10*60*31.)) - ax.set_xticklabels(np.arange(0, session.ophys_timestamps[-1], 10*60)) + ax.set_ylabel("cells") + ax.set_xlabel("time (sec)") + ax.set_xticks(np.arange(0, len(session.ophys_timestamps), 10 * 60 * 31.0)) + ax.set_xticklabels(np.arange(0, session.ophys_timestamps[-1], 10 * 60)) cb = plt.colorbar(cax, pad=0.015) - cb.set_label('dF/F', labelpad=3) + cb.set_label("dF/F", labelpad=3) return ax @@ -113,19 +121,19 @@ def plot_behavior_segment(session, xlims=[620, 640], ax=None): if ax is None: fig, ax = plt.subplots() ax.plot(session.running_speed.timestamps, session.running_speed.speed) - ax.set_ylabel('running speed\ncm/s') - ax.set_xlabel('time (s)') + ax.set_ylabel("running speed\ncm/s") + ax.set_xlabel("time (s)") ax.set_xlim(xlims) ax.set_ylim(-15, 60) - ax.plot(session.rewards.index.values, -10 * np.ones(np.shape(session.rewards.index.values)), 'ro') + ax.plot(session.rewards.index.values, -10 * np.ones(np.shape(session.rewards.index.values)), "ro") ax.vlines(session.licks.timestamps.values, ymin=-10, ymax=-5) image_index = -1 last_omitted = False for index, row in session.stimulus_presentations.iterrows(): if row.omitted is False: - ax.axvspan(row.start_time, row.stop_time, alpha=0.3, facecolor='gray') + ax.axvspan(row.start_time, row.stop_time, alpha=0.3, facecolor="gray") if not (row.image_index == image_index) and (not last_omitted): - ax.axvspan(row.start_time, row.stop_time, alpha=0.3, facecolor='blue') + ax.axvspan(row.start_time, row.stop_time, alpha=0.3, facecolor="blue") image_index = row.image_index last_omitted = row.omitted return ax @@ -142,25 +150,32 @@ def plot_lick_raster(trials, ax=None): reward_time = [(t - trial_data.change_time) for t in [trial_data.reward_time]] # plot reward times if len(reward_time) > 0: - ax.plot(reward_time[0], trial_index + 0.5, '.', color='b', label='reward', markersize=6) + ax.plot(reward_time[0], trial_index + 0.5, ".", color="b", label="reward", markersize=6) # plot lick times - ax.vlines(lick_times, trial_index, trial_index + 1, color='k', linewidth=1) + ax.vlines(lick_times, trial_index, trial_index + 1, color="k", linewidth=1) # put a line at the change time - ax.vlines(0, trial_index, trial_index + 1, color=[.5, .5, .5], linewidth=1) + ax.vlines(0, trial_index, trial_index + 1, color=[0.5, 0.5, 0.5], linewidth=1) # gray bar for response window - ax.axvspan(0.15, 0.75, facecolor='gray', alpha=.3, edgecolor='none') + ax.axvspan(0.15, 0.75, facecolor="gray", alpha=0.3, edgecolor="none") ax.grid(False) ax.set_ylim(0, len(trials)) ax.set_xlim([-1, 4]) - ax.set_ylabel('trials') - ax.set_xlabel('time (sec)') - ax.set_title('lick raster') + ax.set_ylabel("trials") + ax.set_xlabel("time (sec)") + ax.set_title("lick raster") plt.gca().invert_yaxis() return ax -def plot_trace(timestamps, trace, ax=None, xlabel='time (seconds)', ylabel='fluorescence', title='roi', - color=sns.color_palette()[0]): +def plot_trace( + timestamps, + trace, + ax=None, + xlabel="time (seconds)", + ylabel="fluorescence", + title="roi", + color=sns.color_palette()[0], +): if ax is None: fig, ax = plt.subplots(figsize=(15, 5)) ax.plot(timestamps, trace, color=color, linewidth=2) @@ -171,8 +186,15 @@ def plot_trace(timestamps, trace, ax=None, xlabel='time (seconds)', ylabel='fluo return ax -def plot_trace(timestamps, trace, ax=None, xlabel='time (seconds)', ylabel='fluorescence', title='roi', - color=sns.color_palette()[0]): +def plot_trace( + timestamps, + trace, + ax=None, + xlabel="time (seconds)", + ylabel="fluorescence", + title="roi", + color=sns.color_palette()[0], +): if ax is None: fig, ax = plt.subplots(figsize=(15, 5)) ax.plot(timestamps, trace, color=color, linewidth=2) @@ -198,36 +220,61 @@ def plot_example_traces_and_behavior(session, xmin_seconds, length_mins, cell_la ymins = [] ymaxs = [] for i, cell_index in enumerate(cell_indices): - ax[i].tick_params(reset=True, which='both', bottom='off', top='off', right='off', left='off', - labeltop='off', labelright='off', labelleft='off', labelbottom='off') - ax[i] = plot_trace(session.ophys_timestamps, traces[cell_index, :], ax=ax[i], - title='', ylabel=str(cell_index), color=[.5, .5, .5]) + ax[i].tick_params( + reset=True, + which="both", + bottom="off", + top="off", + right="off", + left="off", + labeltop="off", + labelright="off", + labelleft="off", + labelbottom="off", + ) + ax[i] = plot_trace( + session.ophys_timestamps, + traces[cell_index, :], + ax=ax[i], + title="", + ylabel=str(cell_index), + color=[0.5, 0.5, 0.5], + ) ax[i] = add_stim_color_span(session, ax=ax[i], xlim=xlim) ax[i] = restrict_axes(xmin_seconds, xmax_seconds, interval_seconds, ax=ax[i]) ax[i].set_xticks([]) - ax[i].set_xlabel('') + ax[i].set_xlabel("") ax[i].set_xlim(xlim) ymin, ymax = ax[i].get_ylim() ymins.append(ymin) ymaxs.append(ymax) if cell_label: - ax[i].set_ylabel('cell ' + str(i), fontsize=12) + ax[i].set_ylabel("cell " + str(i), fontsize=12) else: - ax[i].set_ylabel('') + ax[i].set_ylabel("") ax[i].set_yticks([]) sns.despine(ax=ax[i], left=True, bottom=True) ymin, ymax = ax[i].get_ylim() - if 'Vip' in session.metadata['full_genotype']: + if "Vip" in session.metadata["full_genotype"]: ax[i].vlines(x=xmin_seconds, ymin=0, ymax=2, linewidth=4) ax[i].set_ylim(ymin=-0.5, ymax=5) - elif 'Slc' in session.metadata['full_genotype']: + elif "Slc" in session.metadata["full_genotype"]: ax[i].vlines(x=xmin_seconds, ymin=0, ymax=1, linewidth=4) ax[i].set_ylim(ymin=-0.5, ymax=3) ax[i].get_xaxis().set_ticks([]) ax[i].get_yaxis().set_ticks([]) - ax[i].tick_params(which='both', bottom='off', top='off', right='off', left='off', - labeltop='off', labelright='off', labelleft='off', labelbottom='off') - ax[i].set_xticklabels('') + ax[i].tick_params( + which="both", + bottom="off", + top="off", + right="off", + left="off", + labeltop="off", + labelright="off", + labelleft="off", + labelbottom="off", + ) + ax[i].set_xticklabels("") i += 1 ax[i].tick_params(axis="x", bottom=True, top=False, labelbottom=True, labeltop=False) @@ -236,42 +283,52 @@ def plot_example_traces_and_behavior(session, xmin_seconds, length_mins, cell_la ax[i] = add_stim_color_span(session, ax=ax[i], xlim=xlim) ax[i] = restrict_axes(xmin_seconds, xmax_seconds, interval_seconds, ax=ax[i]) ax[i].set_xlim(xlim) - ax[i].set_ylabel('run speed\n(cm/s)', fontsize=12) + ax[i].set_ylabel("run speed\n(cm/s)", fontsize=12) sns.despine(ax=ax[i], left=True, bottom=True) - ax[i].set_yticklabels('') + ax[i].set_yticklabels("") xticks = np.arange(xmin_seconds, xmax_seconds, interval_seconds) ax[i].set_xticks(xticks) ax[i].set_xticklabels(xticks) - ax[i].set_xlabel('time (seconds)') + ax[i].set_xlabel("time (seconds)") ax[0].set_title( - str(session.metadata['ophys_experiment_id']) + '_' + session.metadata['full_genotype'].split('-')[0]) + str(session.metadata["ophys_experiment_id"]) + "_" + session.metadata["full_genotype"].split("-")[0] + ) plt.subplots_adjust(wspace=0, hspace=0) plt.subplots_adjust(bottom=0.2) if save_dir: - ut.save_figure(fig, figsize, save_dir, 'example_traces', - str(session.metadata['ophys_experiment_id']) + '_' + str(xlim[0])) + ut.save_figure( + fig, figsize, save_dir, "example_traces", str(session.metadata["ophys_experiment_id"]) + "_" + str(xlim[0]) + ) def plot_transitions_response_heatmap(trials, ax=None): trials = trials[~trials.aborted] - trials['response_binary'] = [1 if response_latency < 0.75 else 0 for response_latency in - trials.response_latency.values] + trials["response_binary"] = [ + 1 if response_latency < 0.75 else 0 for response_latency in trials.response_latency.values + ] - response_matrix = pd.pivot_table(trials, - values='response_binary', - index=['initial_image_name'], - columns=['change_image_name']) + response_matrix = pd.pivot_table( + trials, values="response_binary", index=["initial_image_name"], columns=["change_image_name"] + ) if ax is None: fig, ax = plt.subplots(figsize=(5, 5)) - ax = sns.heatmap(response_matrix, cmap='magma', square=True, annot=False, - annot_kws={"fontsize": 10}, vmin=0, vmax=1, - robust=True, cbar_kws={"drawedges": False, "shrink": 0.7, "label": 'response probability'}, ax=ax) + ax = sns.heatmap( + response_matrix, + cmap="magma", + square=True, + annot=False, + annot_kws={"fontsize": 10}, + vmin=0, + vmax=1, + robust=True, + cbar_kws={"drawedges": False, "shrink": 0.7, "label": "response probability"}, + ax=ax, + ) return ax - def plot_mean_trace_heatmap(mean_df, ax=None, save_dir=None, window=[-4, 8], interval_sec=2): """ There must be only one row per cell in the input df. @@ -295,8 +352,8 @@ def plot_mean_trace_heatmap(mean_df, ax=None, save_dir=None, window=[-4, 8], int trace[:] = np.nan response_array[x, :] = trace - sns.heatmap(data=response_array, vmin=0, vmax=np.percentile(response_array, 99), ax=ax, cmap='magma', cbar=False) - xticks, xticklabels = ut.get_xticks_xticklabels(trace, 31., interval_sec=interval_sec, window=window) + sns.heatmap(data=response_array, vmin=0, vmax=np.percentile(response_array, 99), ax=ax, cmap="magma", cbar=False) + xticks, xticklabels = ut.get_xticks_xticklabels(trace, 31.0, interval_sec=interval_sec, window=window) ax.set_xticks(xticks) ax.set_xticklabels([int(x) for x in xticklabels]) if response_array.shape[0] < 50: @@ -305,21 +362,21 @@ def plot_mean_trace_heatmap(mean_df, ax=None, save_dir=None, window=[-4, 8], int interval = 50 ax.set_yticks(np.arange(0, response_array.shape[0], interval)) ax.set_yticklabels(np.arange(0, response_array.shape[0], interval)) - ax.set_xlabel('time after change (s)', fontsize=16) - ax.set_ylabel('cells') + ax.set_xlabel("time after change (s)", fontsize=16) + ax.set_ylabel("cells") if save_dir: fig.tight_layout() - ut.save_figure(fig, figsize, save_dir, 'experiment_summary', 'mean_trace_heatmap_' + condition + suffix) + ut.save_figure(fig, figsize, save_dir, "experiment_summary", "mean_trace_heatmap_" + condition + suffix) return ax def plot_mean_image_response_heatmap(mean_df, title=None, ax=None, save_dir=None): df = mean_df.copy() - if 'change_image_name' in df.keys(): - image_key = 'change_image_name' + if "change_image_name" in df.keys(): + image_key = "change_image_name" else: - image_key = 'image_name' + image_key = "image_name" images = np.sort(df[image_key].unique()) cell_list = [] for image in images: @@ -341,16 +398,25 @@ def plot_mean_image_response_heatmap(mean_df, title=None, ax=None, save_dir=None fig, ax = plt.subplots(figsize=figsize) vmax = 0.3 - label = 'mean dF/F' - ax = sns.heatmap(response_matrix, cmap='magma', linewidths=0, linecolor='white', square=False, - vmin=0, vmax=vmax, robust=True, - cbar_kws={"drawedges": False, "shrink": 1, "label": label}, ax=ax) + label = "mean dF/F" + ax = sns.heatmap( + response_matrix, + cmap="magma", + linewidths=0, + linecolor="white", + square=False, + vmin=0, + vmax=vmax, + robust=True, + cbar_kws={"drawedges": False, "shrink": 1, "label": label}, + ax=ax, + ) if title is None: - title = 'mean response by image' - ax.set_title(title, va='bottom', ha='center') + title = "mean response by image" + ax.set_title(title, va="bottom", ha="center") ax.set_xticklabels(images, rotation=90) - ax.set_ylabel('cells') + ax.set_ylabel("cells") if response_matrix.shape[0] < 50: interval = 10 else: @@ -359,38 +425,38 @@ def plot_mean_image_response_heatmap(mean_df, title=None, ax=None, save_dir=None ax.set_yticklabels(np.arange(0, response_matrix.shape[0], interval)) if save_dir: fig.tight_layout() - ut.save_figure(fig, figsize, save_dir, 'experiment_summary', 'mean_image_response_heatmap' + suffix) + ut.save_figure(fig, figsize, save_dir, "experiment_summary", "mean_image_response_heatmap" + suffix) def plot_max_proj_and_roi_masks(session, save_dir=None): figsize = (15, 5) - fig, ax = plt.subplots(1,3,figsize=figsize) + fig, ax = plt.subplots(1, 3, figsize=figsize) ax = ax.ravel() - ax[0].imshow(session.max_projection, cmap='gray', vmin=0, vmax=np.amax(session.max_projection)) - ax[0].axis('off') - ax[0].set_title('max intensity projection') + ax[0].imshow(session.max_projection, cmap="gray", vmin=0, vmax=np.amax(session.max_projection)) + ax[0].axis("off") + ax[0].set_title("max intensity projection") - ax[1].imshow(session.segmentation_mask_image, cmap='gray') - ax[1].set_title('roi masks') - ax[1].axis('off') + ax[1].imshow(session.segmentation_mask_image, cmap="gray") + ax[1].set_title("roi masks") + ax[1].axis("off") - ax[2].imshow(session.max_projection, cmap='gray', vmin=0, vmax=np.amax(session.max_projection)) - ax[2].axis('off') - ax[2].set_title(str(session.metadata['ophys_experiment_id'])) + ax[2].imshow(session.max_projection, cmap="gray", vmin=0, vmax=np.amax(session.max_projection)) + ax[2].axis("off") + ax[2].set_title(str(session.metadata["ophys_experiment_id"])) tmp = session.segmentation_mask_image.data.copy() - mask = np.empty(session.segmentation_mask_image.data.shape, dtype='float') + mask = np.empty(session.segmentation_mask_image.data.shape, dtype="float") mask[:] = np.nan mask[tmp > 0] = 1 - ax[2].imshow(mask, cmap='hsv', alpha=0.4, vmin=0, vmax=1) + ax[2].imshow(mask, cmap="hsv", alpha=0.4, vmin=0, vmax=1) if save_dir: - ut.save_figure(fig, figsize, save_dir, 'roi_masks', str(session.metadata['ophys_experiment_id'])) + ut.save_figure(fig, figsize, save_dir, "roi_masks", str(session.metadata["ophys_experiment_id"])) def placeAxesOnGrid(fig, dim=[1, 1], xspan=[0, 1], yspan=[0, 1], wspace=None, hspace=None, sharex=False, sharey=False): - ''' + """ Takes a figure with a gridspec defined and places an array of sub-axes on a portion of the gridspec Takes as arguments: @@ -402,20 +468,27 @@ def placeAxesOnGrid(fig, dim=[1, 1], xspan=[0, 1], yspan=[0, 1], wspace=None, hs returns: subaxes handles - ''' + """ import matplotlib.gridspec as gridspec outer_grid = gridspec.GridSpec(100, 100) - inner_grid = gridspec.GridSpecFromSubplotSpec(dim[0], dim[1], - subplot_spec=outer_grid[int(100 * yspan[0]):int(100 * yspan[1]), - # flake8: noqa: E999 - int(100 * xspan[0]):int(100 * xspan[1])], wspace=wspace, - hspace=hspace) # flake8: noqa: E999 + inner_grid = gridspec.GridSpecFromSubplotSpec( + dim[0], + dim[1], + subplot_spec=outer_grid[ + int(100 * yspan[0]) : int(100 * yspan[1]), + # flake8: noqa: E999 + int(100 * xspan[0]) : int(100 * xspan[1]), + ], + wspace=wspace, + hspace=hspace, + ) # flake8: noqa: E999 # NOTE: A cleaner way to do this is with list comprehension: # inner_ax = [[0 for ii in range(dim[1])] for ii in range(dim[0])] - inner_ax = dim[0] * [dim[1] * [ - fig]] # filling the list with figure objects prevents an error when it they are later replaced by axis handles + inner_ax = dim[0] * [ + dim[1] * [fig] + ] # filling the list with figure objects prevents an error when it they are later replaced by axis handles inner_ax = np.array(inner_ax) idx = 0 for row in range(dim[0]): @@ -442,87 +515,95 @@ def plot_experiment_summary_figure(session, save_dir=None): import allensdk.brain_observatory.behavior.swdb.utilities as ut meta = session.metadata - title = meta['driver_line'][0] + ', ' + meta['targeted_structure'] + ', ' + str(meta['imaging_depth']) + ', ' + \ - session.task_parameters['stage'] + title = ( + meta["driver_line"][0] + + ", " + + meta["targeted_structure"] + + ", " + + str(meta["imaging_depth"]) + + ", " + + session.task_parameters["stage"] + ) figsize = [2 * 11, 2 * 8.5] - fig = plt.figure(figsize=figsize, facecolor='white') + fig = plt.figure(figsize=figsize, facecolor="white") - ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.0, .2), yspan=(0, .2)) - ax.imshow(session.max_projection, cmap='gray') - ax.set_title('max intensity projection') - ax.axis('off') + ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(0.0, 0.2), yspan=(0, 0.2)) + ax.imshow(session.max_projection, cmap="gray") + ax.set_title("max intensity projection") + ax.axis("off") - ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(0, .18), yspan=(.24, .4)) + ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(0, 0.18), yspan=(0.24, 0.4)) trials = session.trials.copy() trials = trials[trials.reward_rate > 1] plot_transitions_response_heatmap(trials, ax=ax) - ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.24, .86), yspan=(0, .26)) + ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(0.24, 0.86), yspan=(0, 0.26)) ax = plot_traces_heatmap(session, ax=ax) ax.set_title(title) - ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.28, .92), yspan=(.32, .44)) + ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(0.28, 0.92), yspan=(0.32, 0.44)) ax.plot(session.running_speed.timestamps, session.running_speed.speed) - ax.set_xlabel('time (seconds)') - ax.set_ylabel('running speed\n(cm/s)') + ax.set_xlabel("time (seconds)") + ax.set_ylabel("running speed\n(cm/s)") - ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.86, 1.), yspan=(0, .2)) + ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(0.86, 1.0), yspan=(0, 0.2)) image_index = 0 - ax.imshow(session.stimulus_templates[image_index, :, :], cmap='gray') + ax.imshow(session.stimulus_templates[image_index, :, :], cmap="gray") st = session.stimulus_presentations.copy() - image_name = st[st.image_index==image_index].image_name.values[0] + image_name = st[st.image_index == image_index].image_name.values[0] ax.set_title(image_name) - ax.axis('off') + ax.axis("off") - ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.0, .17), yspan=(.54, .99)) + ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(0.0, 0.17), yspan=(0.54, 0.99)) ax = plot_lick_raster(session.trials, ax=ax) - ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.24, .42), yspan=(.54, .99)) + ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(0.24, 0.42), yspan=(0.54, 0.99)) fr = session.flash_response_df - mdf = ut.get_mean_df(fr, conditions=['cell_specimen_id', 'image_name']) + mdf = ut.get_mean_df(fr, conditions=["cell_specimen_id", "image_name"]) plot_mean_image_response_heatmap(mdf, title=None, ax=ax) - ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.52, .68), yspan=(.54, .99)) + ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(0.52, 0.68), yspan=(0.54, 0.99)) tr = session.trial_response_df.copy() - mdf = ut.get_mean_df(tr[tr.go], conditions=['cell_specimen_id']) - mdf['pref_stim'] = True + mdf = ut.get_mean_df(tr[tr.go], conditions=["cell_specimen_id"]) + mdf["pref_stim"] = True ax = plot_mean_trace_heatmap(mdf, ax=ax, window=[-4, 8], interval_sec=2) - ax.set_title('mean trace for pref image') - ax.set_ylabel('cells') + ax.set_title("mean trace for pref image") + ax.set_ylabel("cells") - ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.76, .98), yspan=(.5, .62)) + ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(0.76, 0.98), yspan=(0.5, 0.62)) ax.plot(session.trials.reward_rate) - ax.set_ylabel('reward rate') - ax.set_xlabel('trials') + ax.set_ylabel("reward rate") + ax.set_xlabel("trials") - ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.76, 0.98), yspan=(.68, .8)) + ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(0.76, 0.98), yspan=(0.68, 0.8)) plot_behavior_segment(session, xlims=[620, 640], ax=ax) - ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(.76, .98), yspan=(.86, .99)) + ax = placeAxesOnGrid(fig, dim=(1, 1), xspan=(0.76, 0.98), yspan=(0.86, 0.99)) traces = tr[tr.go].dff_trace.values ax = ut.plot_mean_trace(traces, window=[-4, 8], ax=ax) ax = ut.plot_flashes_on_trace(ax, window=[-4, 8], go_trials_only=True) - ax.set_xlabel('time after change (sec)') - ax.set_ylabel('mean dF/F') + ax.set_xlabel("time after change (sec)") + ax.set_ylabel("mean dF/F") fig.tight_layout() if save_dir: fig.tight_layout() - ut.save_figure(fig, figsize, save_dir, 'experiment_summary', str(experiment_id)) + ut.save_figure(fig, figsize, save_dir, "experiment_summary", str(experiment_id)) -if __name__ == '__main__': +if __name__ == "__main__": import sys + experiment_id = sys.argv[1] cache_json = { - 'manifest_path': '/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/cache_20190813/visual_behavior_data_manifest.csv', - 'nwb_base_dir': '/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/cache_20190813/nwb_files', - 'analysis_files_base_dir': '/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/cache_20190813/analysis_files', - 'analysis_files_metadata_path': '/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/cache_20190813/analysis_files_metadata.json', - } + "manifest_path": "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/cache_20190813/visual_behavior_data_manifest.csv", + "nwb_base_dir": "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/cache_20190813/nwb_files", + "analysis_files_base_dir": "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/cache_20190813/analysis_files", + "analysis_files_metadata_path": "/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/cache_20190813/analysis_files_metadata.json", + } # cache_json = { # 'manifest_path': r'\\allen\programs\braintv\workgroups\nc-ophys\visual_behavior\SWDB_2019\visual_behavior_data_manifest.csv', @@ -539,18 +620,18 @@ def plot_experiment_summary_figure(session, save_dir=None): # experiment_id = manifest.ophys_experiment_id.values[16] # save_dir = r'\\allen\programs\braintv\workgroups\nc-ophys\visual_behavior\SWDB_2019\summary_figures' - save_dir = r'/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/summary_figures' - print('loading session') + save_dir = r"/allen/programs/braintv/workgroups/nc-ophys/visual_behavior/SWDB_2019/summary_figures" + print("loading session") session = cache.get_session(experiment_id) - print('plotting experiment summary') + print("plotting experiment summary") plot_experiment_summary_figure(session, save_dir=save_dir) plot_max_proj_and_roi_masks(session, save_dir=save_dir) - print('plotting example traces') + print("plotting example traces") for xmin_seconds in np.arange(500, 1000, 60): plot_example_traces_and_behavior(session, xmin_seconds=xmin_seconds, length_mins=1, save_dir=save_dir) for xmin_seconds in np.arange(1600, 1800, 18): - plot_example_traces_and_behavior(session, xmin_seconds=xmin_seconds, length_mins=.3, save_dir=save_dir) - print('plotting behavior events') + plot_example_traces_and_behavior(session, xmin_seconds=xmin_seconds, length_mins=0.3, save_dir=save_dir) + print("plotting behavior events") for xmin in np.arange(0, 1200, 30): plot_behavior_events_trace(session, xmin=xmin, length=0.5, ax=None, save_dir=save_dir) - print('done') + print("done") diff --git a/allensdk/brain_observatory/behavior/swdb/utilities.py b/allensdk/brain_observatory/behavior/swdb/utilities.py index 36fd9e3f7f..2f7098b45e 100644 --- a/allensdk/brain_observatory/behavior/swdb/utilities.py +++ b/allensdk/brain_observatory/behavior/swdb/utilities.py @@ -4,168 +4,184 @@ import seaborn as sns import matplotlib as mpl -''' +""" This file contains a set of functions that are useful in analyzing visual behavior data -''' - - -def save_figure(fig, figsize, save_dir, folder, filename, formats=['.png']): - ''' - Function for saving a figure - - INPUTS: - fig: a figure object - figsize: tuple of desired figure size - save_dir: string, the directory to save the figure - folder: string, the sub-folder to save the figure in. if the folder does not exist, it will be created - filename: string, the desired name of the saved figure - formats: a list of file formats as strings to save the figure as, ex: ['.png','.pdf'] - ''' +""" + + +def save_figure(fig, figsize, save_dir, folder, filename, formats=[".png"]): + """ + Function for saving a figure + + INPUTS: + fig: a figure object + figsize: tuple of desired figure size + save_dir: string, the directory to save the figure + folder: string, the sub-folder to save the figure in. if the folder does not exist, it will be created + filename: string, the desired name of the saved figure + formats: a list of file formats as strings to save the figure as, ex: ['.png','.pdf'] + """ fig_dir = os.path.join(save_dir, folder) if not os.path.exists(fig_dir): os.mkdir(fig_dir) - mpl.rcParams['pdf.fonttype'] = 42 + mpl.rcParams["pdf.fonttype"] = 42 fig.set_size_inches(figsize) for f in formats: - fig.savefig(os.path.join(fig_dir, fig_title + f), transparent=True, orientation='landscape') + fig.savefig(os.path.join(fig_dir, fig_title + f), transparent=True, orientation="landscape") def get_dff_matrix(session): - ''' - Returns the dff_trace of a session as a numpy matrix - - INPUTS: - session: a behaviorOphysSession object - - OUTPUTS: - dff: a matrix of cells x dff_trace for the entire session - ''' - dff = np.stack(session.dff_traces.dff, axis=0) - return dff + """ + Returns the dff_trace of a session as a numpy matrix + INPUTS: + session: a behaviorOphysSession object -def get_mean_df(response_df, conditions=['cell_specimen_id', 'image_name']): - ''' - Computes an analysis on a selection of responses (either flashes or trials). Computes mean_response, sem_response, the pref_stim, fraction_active_responses. + OUTPUTS: + dff: a matrix of cells x dff_trace for the entire session + """ + dff = np.stack(session.dff_traces.dff, axis=0) + return dff - INPUTS - response_df: the dataframe to group - conditions: the conditions to group by, the first entry should be 'cell_specimen_id', the second could be 'image_name' or 'change_image_name' - OUTPUTS: - mdf: a dataframe with the following columns: - mean_response: the average mean_response for each condition - sem_response: the sem of the mean_response - mean_trace: the average dff trace for each condition - sem_trace: the sem of the mean_trace - mean_responses: the list of mean_responses for each element of each group - pref_stim: if conditions includes image_name or change_image_name, sets a boolean column for whether that was the cell's preferred stimulus - fraction_significant_responses: the fraction of individual image presentations or trials that were significant (p_value > 0.05) - ''' +def get_mean_df(response_df, conditions=["cell_specimen_id", "image_name"]): + """ + Computes an analysis on a selection of responses (either flashes or trials). Computes mean_response, sem_response, the pref_stim, fraction_active_responses. + + INPUTS + response_df: the dataframe to group + conditions: the conditions to group by, the first entry should be 'cell_specimen_id', the second could be 'image_name' or 'change_image_name' + + OUTPUTS: + mdf: a dataframe with the following columns: + mean_response: the average mean_response for each condition + sem_response: the sem of the mean_response + mean_trace: the average dff trace for each condition + sem_trace: the sem of the mean_trace + mean_responses: the list of mean_responses for each element of each group + pref_stim: if conditions includes image_name or change_image_name, sets a boolean column for whether that was the cell's preferred stimulus + fraction_significant_responses: the fraction of individual image presentations or trials that were significant (p_value > 0.05) + """ # Group by conditions rdf = response_df.copy() mdf = rdf.groupby(conditions).apply(get_mean_sem_trace) - mdf = mdf[['mean_response', 'sem_response', 'mean_trace', 'sem_trace', 'mean_responses']] + mdf = mdf[["mean_response", "sem_response", "mean_trace", "sem_trace", "mean_responses"]] mdf = mdf.reset_index() # Add preferred stimulus if we can - if ('image_name' in conditions) or ('change_image_name' in conditions): + if ("image_name" in conditions) or ("change_image_name" in conditions): mdf = annotate_mean_df_with_pref_stim(mdf) # What fraction of individual responses were significant? fraction_significant_responses = rdf.groupby(conditions).apply(get_fraction_significant_responses) fraction_significant_responses = fraction_significant_responses.reset_index() - mdf['fraction_significant_responses'] = fraction_significant_responses.fraction_significant_responses + mdf["fraction_significant_responses"] = fraction_significant_responses.fraction_significant_responses - if 'index' in mdf.keys(): - mdf = mdf.drop(columns=['index']) + if "index" in mdf.keys(): + mdf = mdf.drop(columns=["index"]) return mdf def get_mean_sem_trace(group): - ''' - Computes the average and sem of the mean_response column - - INPUTS: - group: a pandas groupby object - - OUTPUT: - a pandas series with the mean_response, sem_response, mean_trace, sem_trace, and mean_responses computed for the group. - ''' - mean_response = np.mean(group['mean_response']) - mean_responses = group['mean_response'].values - sem_response = np.std(group['mean_response'].values) / np.sqrt(len(group['mean_response'].values)) - mean_trace = np.mean(group['dff_trace']) - sem_trace = np.std(group['dff_trace'].values) / np.sqrt(len(group['dff_trace'].values)) - return pd.Series({'mean_response': mean_response, 'sem_response': sem_response, - 'mean_trace': mean_trace, 'sem_trace': sem_trace, - 'mean_responses': mean_responses}) + """ + Computes the average and sem of the mean_response column + + INPUTS: + group: a pandas groupby object + + OUTPUT: + a pandas series with the mean_response, sem_response, mean_trace, sem_trace, and mean_responses computed for the group. + """ + mean_response = np.mean(group["mean_response"]) + mean_responses = group["mean_response"].values + sem_response = np.std(group["mean_response"].values) / np.sqrt(len(group["mean_response"].values)) + mean_trace = np.mean(group["dff_trace"]) + sem_trace = np.std(group["dff_trace"].values) / np.sqrt(len(group["dff_trace"].values)) + return pd.Series( + { + "mean_response": mean_response, + "sem_response": sem_response, + "mean_trace": mean_trace, + "sem_trace": sem_trace, + "mean_responses": mean_responses, + } + ) def annotate_mean_df_with_pref_stim(mean_df): - ''' - Computes the preferred stimulus for each cell/trial or cell/flash combination. Preferred image is computed by seeing which image evoked the largest average mean_response across all images. + """ + Computes the preferred stimulus for each cell/trial or cell/flash combination. Preferred image is computed by seeing which image evoked the largest average mean_response across all images. - INPUTS: - mean_df: the mean_df to be annotated + INPUTS: + mean_df: the mean_df to be annotated - OUTPUTS: - mean_df with a new column appended 'pref_stim' which is a boolean TRUE/FALSE for whether that image was that cell's preferred image. - - ASSERTS: - Each cell has one unique preferred stimulus - ''' + OUTPUTS: + mean_df with a new column appended 'pref_stim' which is a boolean TRUE/FALSE for whether that image was that cell's preferred image. + + ASSERTS: + Each cell has one unique preferred stimulus + """ # Are we dealing with flash_response or trial_response - if 'image_name' in mean_df.keys(): - image_name = 'image_name' + if "image_name" in mean_df.keys(): + image_name = "image_name" else: - image_name = 'change_image_name' + image_name = "change_image_name" # set up dataframe mdf = mean_df.reset_index() - mdf['pref_stim'] = False + mdf["pref_stim"] = False - # Iterate through cells in df - for cell in mdf['cell_specimen_id'].unique(): - mc = mdf[(mdf['cell_specimen_id'] == cell)] - mc = mc[mc[image_name] != 'omitted'] + # Iterate through cells in df + for cell in mdf["cell_specimen_id"].unique(): + mc = mdf[(mdf["cell_specimen_id"] == cell)] + mc = mc[mc[image_name] != "omitted"] temp = mc[(mc.mean_response == np.max(mc.mean_response.values))][image_name].values if len(temp) > 0: # need this test if the mean_response was nan pref_image = temp[0] # PROBLEM, this is slow, and sets on slice, better to use mdf.at[test, 'pref_stim'] - row = mdf[(mdf['cell_specimen_id'] == cell) & (mdf[image_name] == pref_image)].index - mdf.loc[row, 'pref_stim'] = True + row = mdf[(mdf["cell_specimen_id"] == cell) & (mdf[image_name] == pref_image)].index + mdf.loc[row, "pref_stim"] = True # Test to ensure preferred stimulus is unique for each cell - for cell in mdf.reset_index()['cell_specimen_id'].unique(): - if image_name == 'image_name': - assert len( - mdf.reset_index().set_index('cell_specimen_id').loc[cell].query('pref_stim').image_name.unique()) == 1 + for cell in mdf.reset_index()["cell_specimen_id"].unique(): + if image_name == "image_name": + assert ( + len(mdf.reset_index().set_index("cell_specimen_id").loc[cell].query("pref_stim").image_name.unique()) + == 1 + ) else: - assert len(mdf.reset_index().set_index('cell_specimen_id').loc[cell].query( - 'pref_stim').change_image_name.unique()) == 1 + assert ( + len( + mdf.reset_index() + .set_index("cell_specimen_id") + .loc[cell] + .query("pref_stim") + .change_image_name.unique() + ) + == 1 + ) return mdf def get_fraction_significant_responses(group, threshold=0.05): - ''' - Calculates the fraction of trials or flashes that have a p_value below threshold - Note that this function does not handle multiple comparisons - - INPUT: - group: a pandas groupby object - threshold: the p_value threshold for significance for an individual response - - OUTPUT: - a pandas series with column 'fraction_significant_responses' - ''' + """ + Calculates the fraction of trials or flashes that have a p_value below threshold + Note that this function does not handle multiple comparisons + + INPUT: + group: a pandas groupby object + threshold: the p_value threshold for significance for an individual response + + OUTPUT: + a pandas series with column 'fraction_significant_responses' + """ fraction_significant_responses = len(group[group.p_value < threshold]) / float(len(group)) - return pd.Series({'fraction_significant_responses': fraction_significant_responses}) + return pd.Series({"fraction_significant_responses": fraction_significant_responses}) -def get_xticks_xticklabels(trace, ophys_frame_rate=31., interval_sec=1, window=[-4, 8]): +def get_xticks_xticklabels(trace, ophys_frame_rate=31.0, interval_sec=1, window=[-4, 8]): """ Function that accepts a timeseries, evaluates the number of points in the trace, and converts from acquisition frames to timestamps relative to a given window of time covered by the trace. @@ -190,7 +206,7 @@ def get_xticks_xticklabels(trace, ophys_frame_rate=31., interval_sec=1, window=[ return xticks, xticklabels -def plot_mean_trace(traces, window=[-4, 8], interval_sec=1, ylabel='dF/F', legend_label=None, color='k', ax=None): +def plot_mean_trace(traces, window=[-4, 8], interval_sec=1, ylabel="dF/F", legend_label=None, color="k", ax=None): """ Function that accepts an array of single trial traces and plots the mean and SEM of the trace, with xticklabels in seconds @@ -205,7 +221,7 @@ def plot_mean_trace(traces, window=[-4, 8], interval_sec=1, ylabel='dF/F', legen :return: axis handle """ - ophys_frame_rate = 31. # PROBLEM, shouldn't hard code this here + ophys_frame_rate = 31.0 # PROBLEM, shouldn't hard code this here if ax is None: fig, ax = plt.subplots() if len(traces) > 0: @@ -222,14 +238,15 @@ def plot_mean_trace(traces, window=[-4, 8], interval_sec=1, ylabel='dF/F', legen else: ax.set_xticklabels([int(x) for x in xticklabels]) ax.set_xlim(0, len(trace)) - ax.set_xlabel('time (sec)') + ax.set_xlabel("time (sec)") ax.set_ylabel(ylabel) sns.despine(ax=ax) return ax -def plot_flashes_on_trace(ax, window=[-4, 8], go_trials_only=False, omitted=False, flashes=False, alpha=0.25, - facecolor='gray'): +def plot_flashes_on_trace( + ax, window=[-4, 8], go_trials_only=False, omitted=False, flashes=False, alpha=0.25, facecolor="gray" +): """ Function to create transparent gray bars spanning the duration of visual stimulus presentations to overlay on existing figure @@ -242,9 +259,9 @@ def plot_flashes_on_trace(ax, window=[-4, 8], go_trials_only=False, omitted=Fals :return: axis handle """ # PROBLEM: shouldn't hard code these things here - frame_rate = 31. - stim_duration = .25 - blank_duration = .5 + frame_rate = 31.0 + stim_duration = 0.25 + blank_duration = 0.5 change_frame = np.abs(window[0]) * frame_rate end_frame = (window[1] + np.abs(window[0])) * frame_rate interval = blank_duration + stim_duration @@ -256,7 +273,7 @@ def plot_flashes_on_trace(ax, window=[-4, 8], go_trials_only=False, omitted=Fals for i, vals in enumerate(array): amin = array[i] amax = array[i] + (stim_duration * frame_rate) - ax.axvspan(amin, amax, facecolor=facecolor, edgecolor='none', alpha=alpha, linewidth=0, zorder=1) + ax.axvspan(amin, amax, facecolor=facecolor, edgecolor="none", alpha=alpha, linewidth=0, zorder=1) if go_trials_only: alpha = alpha * 3 else: @@ -265,32 +282,33 @@ def plot_flashes_on_trace(ax, window=[-4, 8], go_trials_only=False, omitted=Fals for i, vals in enumerate(array): amin = array[i] amax = array[i] - (stim_duration * frame_rate) - ax.axvspan(amin, amax, facecolor=facecolor, edgecolor='none', alpha=alpha, linewidth=0, zorder=1) + ax.axvspan(amin, amax, facecolor=facecolor, edgecolor="none", alpha=alpha, linewidth=0, zorder=1) return ax -def create_multi_session_mean_df(cache, experiment_ids, conditions=['cell_specimen_id', 'change_image_name'], - flashes=False): - ''' - Creates a mean response dataframe by combining multiple sessions. - - INPUTS: - cache: the cache object for the dataset - experiment_ids: a list of experiment_ids for sessions to merge - conditions: the set of conditions to group by. The first entry should be 'cell_specimen_id' - flashes: if TRUE, uses the flash_response_df to merge, otherwise uses the trial_response_df - - OUTPUTS - mega_mdf, a dataframe with index given by the session experiment ids. This allows for easy analysis like: - mega_mdf.groupby('experiment_id').mean_response.mean() - ''' +def create_multi_session_mean_df( + cache, experiment_ids, conditions=["cell_specimen_id", "change_image_name"], flashes=False +): + """ + Creates a mean response dataframe by combining multiple sessions. + + INPUTS: + cache: the cache object for the dataset + experiment_ids: a list of experiment_ids for sessions to merge + conditions: the set of conditions to group by. The first entry should be 'cell_specimen_id' + flashes: if TRUE, uses the flash_response_df to merge, otherwise uses the trial_response_df + + OUTPUTS + mega_mdf, a dataframe with index given by the session experiment ids. This allows for easy analysis like: + mega_mdf.groupby('experiment_id').mean_response.mean() + """ manifest = cache.experiment_table mega_mdf = pd.DataFrame() # Iterate through experiments for experiment_id in experiment_ids: # load the session object session = cache.get_session(experiment_id) - print(session.metadata['ophys_experiment_id']) + print(session.metadata["ophys_experiment_id"]) # Get the individual session mean_df if flashes: mdf = get_mean_df(session.flash_response_df, conditions=conditions) @@ -298,52 +316,53 @@ def create_multi_session_mean_df(cache, experiment_ids, conditions=['cell_specim mdf = get_mean_df(session.trial_response_df, conditions=conditions) # Append metadata - mdf['experiment_id'] = session.metadata['ophys_experiment_id'] - mdf['experiment_container_id'] = session.metadata['experiment_container_id'] - stage = manifest[manifest.ophys_experiment_id == session.metadata['ophys_experiment_id']].stage_name.values[0] - mdf['stage_name'] = stage - mdf['passive'] = parse_stage_for_passive(stage) - mdf['image_set'] = parse_stage_for_image_set(stage) - mdf['targeted_structure'] = session.metadata['targeted_structure'] - mdf['imaging_depth'] = session.metadata['imaging_depth'] - mdf['full_genotype'] = session.metadata['full_genotype'] - mdf['cre_line'] = session.metadata['full_genotype'].split('/')[0] - mdf['retake_number'] = \ - manifest[manifest.ophys_experiment_id == session.metadata['ophys_experiment_id']].retake_number.values[0] + mdf["experiment_id"] = session.metadata["ophys_experiment_id"] + mdf["experiment_container_id"] = session.metadata["experiment_container_id"] + stage = manifest[manifest.ophys_experiment_id == session.metadata["ophys_experiment_id"]].stage_name.values[0] + mdf["stage_name"] = stage + mdf["passive"] = parse_stage_for_passive(stage) + mdf["image_set"] = parse_stage_for_image_set(stage) + mdf["targeted_structure"] = session.metadata["targeted_structure"] + mdf["imaging_depth"] = session.metadata["imaging_depth"] + mdf["full_genotype"] = session.metadata["full_genotype"] + mdf["cre_line"] = session.metadata["full_genotype"].split("/")[0] + mdf["retake_number"] = manifest[ + manifest.ophys_experiment_id == session.metadata["ophys_experiment_id"] + ].retake_number.values[0] # Concatenate this session to the other sessions mega_mdf = pd.concat([mega_mdf, mdf]) # Clean up indexes mega_mdf = mega_mdf.reset_index() - mega_mdf = mega_mdf.set_index('experiment_id') - if 'index' in mega_mdf.keys(): - mega_mdf = mega_mdf.drop(columns=['index']) - if 'level_0' in mega_mdf.keys(): - mega_mdf = mega_mdf.drop(columns=['level_0']) + mega_mdf = mega_mdf.set_index("experiment_id") + if "index" in mega_mdf.keys(): + mega_mdf = mega_mdf.drop(columns=["index"]) + if "level_0" in mega_mdf.keys(): + mega_mdf = mega_mdf.drop(columns=["level_0"]) return mega_mdf def parse_stage_for_passive(stage): - ''' - Returns TRUE if the stage_name indicates a passive sessions - ''' - return 'passive' in stage + """ + Returns TRUE if the stage_name indicates a passive sessions + """ + return "passive" in stage def parse_stage_for_image_set(stage): - ''' - Returns the character for the image_set, for example 'A' - ''' + """ + Returns the character for the image_set, for example 'A' + """ return stage[15] def get_active_cell_indices(dff_traces): - ''' - Returns the ten most active cells. - Computes active cells by SNR = mean/std over all timepoints. - ''' + """ + Returns the ten most active cells. + Computes active cells by SNR = mean/std over all timepoints. + """ snr_values = [] for i, trace in enumerate(dff_traces): mean = np.mean(trace, axis=0) @@ -360,6 +379,7 @@ def compute_lifetime_sparseness(image_responses): # N = number of images # after Vinje & Gallant, 2000; Froudarakis et al., 2014 N = float(len(image_responses)) - ls = ((1 - (1 / N) * ((np.power(image_responses.sum(axis=0), 2)) / (np.power(image_responses, 2).sum(axis=0)))) / ( - 1 - (1 / N))) + ls = (1 - (1 / N) * ((np.power(image_responses.sum(axis=0), 2)) / (np.power(image_responses, 2).sum(axis=0)))) / ( + 1 - (1 / N) + ) return ls diff --git a/allensdk/brain_observatory/behavior/sync/__init__.py b/allensdk/brain_observatory/behavior/sync/__init__.py index c0677b80e4..51d34b1d33 100644 --- a/allensdk/brain_observatory/behavior/sync/__init__.py +++ b/allensdk/brain_observatory/behavior/sync/__init__.py @@ -3,17 +3,14 @@ @author: marinag """ + from typing import Dict, Optional, List, Union -from allensdk.brain_observatory.sync_dataset import \ - Dataset as SyncDataset +from allensdk.brain_observatory.sync_dataset import Dataset as SyncDataset import numpy as np -def get_raw_stimulus_frames( - dataset: SyncDataset, - permissive: bool = False -) -> np.ndarray: - """ Report the raw timestamps of each stimulus frame. This corresponds to +def get_raw_stimulus_frames(dataset: SyncDataset, permissive: bool = False) -> np.ndarray: + """Report the raw timestamps of each stimulus frame. This corresponds to the time at which the psychopy window's flip method returned, but not necessarily to the time at which the stimulus frame was displayed. @@ -30,20 +27,15 @@ def get_raw_stimulus_frames( """ try: - return dataset.get_edges(kind="falling", - keys=["vsync_stim", "stim_vsync"], - units="seconds") + return dataset.get_edges(kind="falling", keys=["vsync_stim", "stim_vsync"], units="seconds") except KeyError: if not permissive: raise return -def get_ophys_frames( - dataset: SyncDataset, - permissive: bool = False -) -> np.ndarray: - """ Report the timestamps of each optical physiology video frame +def get_ophys_frames(dataset: SyncDataset, permissive: bool = False) -> np.ndarray: + """Report the timestamps of each optical physiology video frame Parameters ---------- @@ -64,20 +56,15 @@ def get_ophys_frames( """ try: - return dataset.get_edges(kind="rising", - keys=["vsync_2p", "2p_vsync"], - units="seconds") + return dataset.get_edges(kind="rising", keys=["vsync_2p", "2p_vsync"], units="seconds") except KeyError: if not permissive: raise return -def get_lick_times( - dataset: SyncDataset, - permissive: bool = False -) -> Optional[np.ndarray]: - """ Report the timestamps of each detected lick +def get_lick_times(dataset: SyncDataset, permissive: bool = False) -> Optional[np.ndarray]: + """Report the timestamps of each detected lick Parameters ---------- @@ -93,15 +80,11 @@ def get_lick_times( dataset. """ - return dataset.get_edges( - "rising", ["lick_times", "lick_sensor"], "seconds", permissive) + return dataset.get_edges("rising", ["lick_times", "lick_sensor"], "seconds", permissive) -def get_stim_photodiode( - dataset: SyncDataset, - permissive: bool = False -) -> Optional[List[float]]: - """ Report the timestamps of each detected sync square transition (both +def get_stim_photodiode(dataset: SyncDataset, permissive: bool = False) -> Optional[List[float]]: + """Report the timestamps of each detected sync square transition (both black -> white and white -> black) in this experiment. Parameters @@ -117,48 +100,37 @@ def get_stim_photodiode( dataset. """ - return dataset.get_edges( - "all", ["stim_photodiode", "photodiode"], "seconds", permissive) + return dataset.get_edges("all", ["stim_photodiode", "photodiode"], "seconds", permissive) -def get_trigger( - dataset: SyncDataset, - permissive: bool = False -) -> Optional[np.ndarray]: - """ Returns (as a 1-element array) the time at which optical physiology - acquisition was started. +def get_trigger(dataset: SyncDataset, permissive: bool = False) -> Optional[np.ndarray]: + """Returns (as a 1-element array) the time at which optical physiology + acquisition was started. - Parameters - ---------- - dataset : describes experiment timing - permissive : If True, None will be returned if timestamps are not found. - If False, a KeyError will be raised + Parameters + ---------- + dataset : describes experiment timing + permissive : If True, None will be returned if timestamps are not found. + If False, a KeyError will be raised - Returns - ------- - timestamps (floating point; seconds; relative to experiment start) - or None. If None, no timestamps were found in this sync dataset. + Returns + ------- + timestamps (floating point; seconds; relative to experiment start) + or None. If None, no timestamps were found in this sync dataset. - Notes - ----- - Ophys frame timestamps can be recorded before acquisition start when - experimenters are setting up the recording session. These do not - correspond to acquired ophys frames. + Notes + ----- + Ophys frame timestamps can be recorded before acquisition start when + experimenters are setting up the recording session. These do not + correspond to acquired ophys frames. """ - keys = ["2p_trigger", "acq_trigger", "2p_acq_trigger", "2p_acquiring", - "stim_running"] - return dataset.get_edges(kind="rising", - keys=keys, - units="seconds", - permissive=permissive) + keys = ["2p_trigger", "acq_trigger", "2p_acq_trigger", "2p_acquiring", "stim_running"] + return dataset.get_edges(kind="rising", keys=keys, units="seconds", permissive=permissive) -def get_eye_tracking( - dataset: SyncDataset, - permissive: bool = False -) -> Optional[np.ndarray]: - """ Report the timestamps of each frame of the eye tracking video +def get_eye_tracking(dataset: SyncDataset, permissive: bool = False) -> Optional[np.ndarray]: + """Report the timestamps of each frame of the eye tracking video Parameters ---------- @@ -174,17 +146,11 @@ def get_eye_tracking( """ keys = ["cam2_exposure", "eye_tracking", "eye_frame_received"] - return dataset.get_edges(kind="rising", - keys=keys, - units="seconds", - permissive=permissive) + return dataset.get_edges(kind="rising", keys=keys, units="seconds", permissive=permissive) -def get_behavior_monitoring( - dataset: SyncDataset, - permissive: bool = False -) -> Optional[np.ndarray]: - """ Report the timestamps of each frame of the behavior +def get_behavior_monitoring(dataset: SyncDataset, permissive: bool = False) -> Optional[np.ndarray]: + """Report the timestamps of each frame of the behavior monitoring video Parameters @@ -201,17 +167,11 @@ def get_behavior_monitoring( """ keys = ["cam1_exposure", "behavior_monitoring", "beh_frame_received"] - return dataset.get_edges(kind="rising", - keys=keys, - units="seconds", - permissive=permissive) + return dataset.get_edges(kind="rising", keys=keys, units="seconds", permissive=permissive) -def get_sync_data( - sync_path: str, - permissive: bool = False -) -> Dict[str, Union[List, np.ndarray, None]]: - """ Convenience function for extracting several timestamp arrays from a +def get_sync_data(sync_path: str, permissive: bool = False) -> Dict[str, Union[List, np.ndarray, None]]: + """Convenience function for extracting several timestamp arrays from a sync file. Parameters @@ -240,13 +200,11 @@ def get_sync_data( sync_dataset = SyncDataset(sync_path) return { - 'ophys_frames': get_ophys_frames(sync_dataset, permissive), - 'lick_times': get_lick_times(sync_dataset, permissive), - 'ophys_trigger': get_trigger(sync_dataset, permissive), - 'eye_tracking': get_eye_tracking(sync_dataset, permissive), - 'behavior_monitoring': get_behavior_monitoring(sync_dataset, - permissive), - 'stim_photodiode': get_stim_photodiode(sync_dataset, permissive), - 'stimulus_times_no_delay': get_raw_stimulus_frames(sync_dataset, - permissive) + "ophys_frames": get_ophys_frames(sync_dataset, permissive), + "lick_times": get_lick_times(sync_dataset, permissive), + "ophys_trigger": get_trigger(sync_dataset, permissive), + "eye_tracking": get_eye_tracking(sync_dataset, permissive), + "behavior_monitoring": get_behavior_monitoring(sync_dataset, permissive), + "stim_photodiode": get_stim_photodiode(sync_dataset, permissive), + "stimulus_times_no_delay": get_raw_stimulus_frames(sync_dataset, permissive), } diff --git a/allensdk/brain_observatory/behavior/sync/process_sync.py b/allensdk/brain_observatory/behavior/sync/process_sync.py index f836842a97..180e45681d 100644 --- a/allensdk/brain_observatory/behavior/sync/process_sync.py +++ b/allensdk/brain_observatory/behavior/sync/process_sync.py @@ -1,9 +1,8 @@ - - import numpy as np import logging + logger = logging.getLogger(__name__) @@ -38,11 +37,11 @@ def calculate_delay(sync_data, stim_vsync_fall, sample_frequency): ROUND_PRECISION = 4 ONE = 1 - logger.info('calculating monitor delay') + logger.info("calculating monitor delay") # try: # photodiode transitions - photodiode_rise = sync_data.get_rising_edges('stim_photodiode') / sample_frequency + photodiode_rise = sync_data.get_rising_edges("stim_photodiode") / sample_frequency # Find start and stop of stimulus # test and correct for photodiode transition errors @@ -53,12 +52,16 @@ def calculate_delay(sync_data, stim_vsync_fall, sample_frequency): max_medium_photodiode_rise = 1.5 # find the short and medium length photodiode rises - short_rise_indexes = np.where(np.logical_and(photodiode_rise_diff > min_short_photodiode_rise, - photodiode_rise_diff < max_short_photodiode_rise))[ - FIRST_ELEMENT_INDEX] - medium_rise_indexes = np.where(np.logical_and(photodiode_rise_diff > min_medium_photodiode_rise, - photodiode_rise_diff < max_medium_photodiode_rise))[ - FIRST_ELEMENT_INDEX] + short_rise_indexes = np.where( + np.logical_and( + photodiode_rise_diff > min_short_photodiode_rise, photodiode_rise_diff < max_short_photodiode_rise + ) + )[FIRST_ELEMENT_INDEX] + medium_rise_indexes = np.where( + np.logical_and( + photodiode_rise_diff > min_medium_photodiode_rise, photodiode_rise_diff < max_medium_photodiode_rise + ) + )[FIRST_ELEMENT_INDEX] short_set = set(short_rise_indexes) @@ -86,8 +89,10 @@ def calculate_delay(sync_data, stim_vsync_fall, sample_frequency): # iterate until all of the errors have been corrected while any(photodiode_rise_diff[ptd_start:ptd_end] < photodiode_rise_error_threshold): - error_frames = np.where(photodiode_rise_diff[ptd_start:ptd_end] < photodiode_rise_error_threshold)[ - FIRST_ELEMENT_INDEX] + ptd_start + error_frames = ( + np.where(photodiode_rise_diff[ptd_start:ptd_end] < photodiode_rise_error_threshold)[FIRST_ELEMENT_INDEX] + + ptd_start + ) # remove the bad photodiode event photodiode_rise = np.delete(photodiode_rise, error_frames[last_frame_index]) ptd_end -= 1 @@ -102,16 +107,24 @@ def calculate_delay(sync_data, stim_vsync_fall, sample_frequency): delay_rise = np.empty(number_of_photodiode_rises) for photodiode_rise_index in range(number_of_photodiode_rises): - delay_rise[photodiode_rise_index] = photodiode_rise[photodiode_rise_index + first_pulse] - \ - stim_vsync_fall[(photodiode_rise_index * vsync_fall_events_per_photodiode_rise) + half_vsync_fall_events_per_photodiode_rise] + delay_rise[photodiode_rise_index] = ( + photodiode_rise[photodiode_rise_index + first_pulse] + - stim_vsync_fall[ + (photodiode_rise_index * vsync_fall_events_per_photodiode_rise) + + half_vsync_fall_events_per_photodiode_rise + ] + ) # get a single delay value by finding the mean of all of the delays - skip the last element in the array (the end of the experimenet) delay = np.mean(delay_rise[:last_frame_index]) delay_std = np.std(delay_rise[:last_frame_index]) - if (delay_std > DELAY_THRESHOLD or np.isnan(delay)): - - logger.error("Sync photodiode error needs to be fixed. Using assumed monitor delay: {}".format(round(delay, ROUND_PRECISION))) + if delay_std > DELAY_THRESHOLD or np.isnan(delay): + logger.error( + "Sync photodiode error needs to be fixed. Using assumed monitor delay: {}".format( + round(delay, ROUND_PRECISION) + ) + ) raise # assume delay diff --git a/allensdk/brain_observatory/behavior/trial_masks.py b/allensdk/brain_observatory/behavior/trial_masks.py index 363b10b436..7b55b8be2d 100644 --- a/allensdk/brain_observatory/behavior/trial_masks.py +++ b/allensdk/brain_observatory/behavior/trial_masks.py @@ -1,8 +1,9 @@ import numpy as np import pandas as pd + def trial_types(trials, trial_types): - """ only include trials of certain trial types + """only include trials of certain trial types Parameters ---------- @@ -18,14 +19,13 @@ def trial_types(trials, trial_types): """ if trial_types is not None and len(trial_types) > 0: - return trials['trial_type'].isin(trial_types) + return trials["trial_type"].isin(trial_types) else: - return pd.Series(np.ones((len(trials), ), dtype=bool), - name="trial_type", index=trials.index) + return pd.Series(np.ones((len(trials),), dtype=bool), name="trial_type", index=trials.index) def contingent_trials(trials): - """ GO & CATCH trials only + """GO & CATCH trials only Parameters ---------- @@ -37,11 +37,11 @@ def contingent_trials(trials): mask : pandas Series of booleans, indexed to trials DataFrame """ - return trial_types(trials, ('go', 'catch')) + return trial_types(trials, ("go", "catch")) def reward_rate(trials, thresh=2.0): - """ masks trials where the reward rate (per minute) is below some threshold. + """masks trials where the reward rate (per minute) is below some threshold. This de facto omits trials in which the animal was not licking for extended periods or periods when they were licking indiscriminantly. @@ -59,5 +59,5 @@ def reward_rate(trials, thresh=2.0): """ - mask = trials['reward_rate'] > thresh + mask = trials["reward_rate"] > thresh return mask diff --git a/allensdk/brain_observatory/behavior/trials_processing.py b/allensdk/brain_observatory/behavior/trials_processing.py index a6baea7bc9..bcfa5bddcc 100644 --- a/allensdk/brain_observatory/behavior/trials_processing.py +++ b/allensdk/brain_observatory/behavior/trials_processing.py @@ -1,79 +1,125 @@ -EDF_COLUMNS = ['index', 'lick_times', 'auto_rewarded', 'cumulative_volume', - 'cumulative_reward_number', 'reward_volume', 'reward_times', - 'reward_frames', 'rewarded', 'optogenetics', 'response_type', - 'response_time', 'change_time', 'change_frame', - 'response_latency', 'starttime', 'startframe', 'trial_length', - 'scheduled_change_time', 'endtime', 'endframe', - 'initial_image_category', 'initial_image_name', - 'change_image_name', 'change_image_category', 'change_ori', - 'change_contrast', 'initial_ori', 'initial_contrast', - 'delta_ori', 'mouse_id', 'response_window', 'task', 'stage', - 'session_duration', 'user_id', 'LDT_mode', - 'blank_screen_timeout', 'stim_duration', - 'blank_duration_range', 'prechange_minimum', - 'stimulus_distribution', 'stimulus', 'distribution_mean', - 'computer_name', 'behavior_session_uuid', 'startdatetime', - 'date', 'year', 'month', 'day', 'hour', 'dayofweek', - 'number_of_rewards', 'rig_id', 'trial_type', - 'lick_frames', 'reward_licks', 'reward_lick_count', - 'reward_lick_latency', 'reward_rate', 'response', 'color'] +EDF_COLUMNS = [ + "index", + "lick_times", + "auto_rewarded", + "cumulative_volume", + "cumulative_reward_number", + "reward_volume", + "reward_times", + "reward_frames", + "rewarded", + "optogenetics", + "response_type", + "response_time", + "change_time", + "change_frame", + "response_latency", + "starttime", + "startframe", + "trial_length", + "scheduled_change_time", + "endtime", + "endframe", + "initial_image_category", + "initial_image_name", + "change_image_name", + "change_image_category", + "change_ori", + "change_contrast", + "initial_ori", + "initial_contrast", + "delta_ori", + "mouse_id", + "response_window", + "task", + "stage", + "session_duration", + "user_id", + "LDT_mode", + "blank_screen_timeout", + "stim_duration", + "blank_duration_range", + "prechange_minimum", + "stimulus_distribution", + "stimulus", + "distribution_mean", + "computer_name", + "behavior_session_uuid", + "startdatetime", + "date", + "year", + "month", + "day", + "hour", + "dayofweek", + "number_of_rewards", + "rig_id", + "trial_type", + "lick_frames", + "reward_licks", + "reward_lick_count", + "reward_lick_latency", + "reward_rate", + "response", + "color", +] RIG_NAME = { - 'W7DTMJ19R2F': 'A1', - 'W7DTMJ35Y0T': 'A2', - 'W7DTMJ03J70R': 'Dome', - 'W7VS-SYSLOGIC2': 'A3', - 'W7VS-SYSLOGIC3': 'A4', - 'W7VS-SYSLOGIC4': 'A5', - 'W7VS-SYSLOGIC5': 'A6', - 'W7VS-SYSLOGIC7': 'B1', - 'W7VS-SYSLOGIC8': 'B2', - 'W7VS-SYSLOGIC9': 'B3', - 'W7VS-SYSLOGIC10': 'B4', - 'W7VS-SYSLOGIC11': 'B5', - 'W7VS-SYSLOGIC12': 'B6', - 'W7VS-SYSLOGIC13': 'C1', - 'W7VS-SYSLOGIC14': 'C2', - 'W7VS-SYSLOGIC15': 'C3', - 'W7VS-SYSLOGIC16': 'C4', - 'W7VS-SYSLOGIC17': 'C5', - 'W7VS-SYSLOGIC18': 'C6', - 'W7VS-SYSLOGIC19': 'D1', - 'W7VS-SYSLOGIC20': 'D2', - 'W7VS-SYSLOGIC21': 'D3', - 'W7VS-SYSLOGIC22': 'D4', - 'W7VS-SYSLOGIC23': 'D5', - 'W7VS-SYSLOGIC24': 'D6', - 'W7VS-SYSLOGIC31': 'E1', - 'W7VS-SYSLOGIC32': 'E2', - 'W7VS-SYSLOGIC33': 'E3', - 'W7VS-SYSLOGIC34': 'E4', - 'W7VS-SYSLOGIC35': 'E5', - 'W7VS-SYSLOGIC36': 'E6', - 'W7DT102905': 'F1', - 'W10DT102905': 'F1', - 'W7DT102904': 'F2', - 'W7DT102903': 'F3', - 'W7DT102914': 'F4', - 'W7DT102913': 'F5', - 'W7DT12497': 'F6', - 'W7DT102906': 'G1', - 'W7DT102907': 'G2', - 'W7DT102908': 'G3', - 'W7DT102909': 'G4', - 'W7DT102910': 'G5', - 'W7DT102911': 'G6', - 'W7VS-SYSLOGIC26': 'Widefield-329', - 'OSXLTTF6T6.local': 'DougLaptop', - 'W7DTMJ026LUL': 'DougPC', - 'W7DTMJ036PSL': 'Marina2P_Sutter', - 'W7DT2PNC1STIM': '2P6', - 'W7DTMJ234MG': 'peterl_2p', - 'W7DT2P3STiM': '2P3', - 'W7DT2P4STIM': '2P4', - 'W7DT2P5STIM': '2P5', - 'W10DTSM118296': 'NP3', - 'meso1stim': 'MS1', - 'localhost': 'localhost' + "W7DTMJ19R2F": "A1", + "W7DTMJ35Y0T": "A2", + "W7DTMJ03J70R": "Dome", + "W7VS-SYSLOGIC2": "A3", + "W7VS-SYSLOGIC3": "A4", + "W7VS-SYSLOGIC4": "A5", + "W7VS-SYSLOGIC5": "A6", + "W7VS-SYSLOGIC7": "B1", + "W7VS-SYSLOGIC8": "B2", + "W7VS-SYSLOGIC9": "B3", + "W7VS-SYSLOGIC10": "B4", + "W7VS-SYSLOGIC11": "B5", + "W7VS-SYSLOGIC12": "B6", + "W7VS-SYSLOGIC13": "C1", + "W7VS-SYSLOGIC14": "C2", + "W7VS-SYSLOGIC15": "C3", + "W7VS-SYSLOGIC16": "C4", + "W7VS-SYSLOGIC17": "C5", + "W7VS-SYSLOGIC18": "C6", + "W7VS-SYSLOGIC19": "D1", + "W7VS-SYSLOGIC20": "D2", + "W7VS-SYSLOGIC21": "D3", + "W7VS-SYSLOGIC22": "D4", + "W7VS-SYSLOGIC23": "D5", + "W7VS-SYSLOGIC24": "D6", + "W7VS-SYSLOGIC31": "E1", + "W7VS-SYSLOGIC32": "E2", + "W7VS-SYSLOGIC33": "E3", + "W7VS-SYSLOGIC34": "E4", + "W7VS-SYSLOGIC35": "E5", + "W7VS-SYSLOGIC36": "E6", + "W7DT102905": "F1", + "W10DT102905": "F1", + "W7DT102904": "F2", + "W7DT102903": "F3", + "W7DT102914": "F4", + "W7DT102913": "F5", + "W7DT12497": "F6", + "W7DT102906": "G1", + "W7DT102907": "G2", + "W7DT102908": "G3", + "W7DT102909": "G4", + "W7DT102910": "G5", + "W7DT102911": "G6", + "W7VS-SYSLOGIC26": "Widefield-329", + "OSXLTTF6T6.local": "DougLaptop", + "W7DTMJ026LUL": "DougPC", + "W7DTMJ036PSL": "Marina2P_Sutter", + "W7DT2PNC1STIM": "2P6", + "W7DTMJ234MG": "peterl_2p", + "W7DT2P3STiM": "2P3", + "W7DT2P4STIM": "2P4", + "W7DT2P5STIM": "2P5", + "W10DTSM118296": "NP3", + "meso1stim": "MS1", + "localhost": "localhost", } RIG_NAME = {k.lower(): v for k, v in RIG_NAME.items()} diff --git a/allensdk/brain_observatory/behavior/utils/metadata_parsers.py b/allensdk/brain_observatory/behavior/utils/metadata_parsers.py index e6bc9ddf8d..2fcba056e0 100644 --- a/allensdk/brain_observatory/behavior/utils/metadata_parsers.py +++ b/allensdk/brain_observatory/behavior/utils/metadata_parsers.py @@ -1,7 +1,6 @@ from typing import Optional -from allensdk.brain_observatory.ophys.project_constants import ( - NUM_DEPTHS_DICT, NUM_STRUCTURES_DICT) +from allensdk.brain_observatory.ophys.project_constants import NUM_DEPTHS_DICT, NUM_STRUCTURES_DICT ############################ @@ -53,9 +52,7 @@ def parse_stimulus_set(session_type: str) -> str: elif stimulus_type == "gratings": stim_set = "gratings" else: - raise ValueError( - f"Session_type {session_type} not formatted as " "expected." - ) + raise ValueError(f"Session_type {session_type} not formatted as expected.") return stim_set diff --git a/allensdk/brain_observatory/behavior/write_nwb/behavior/__main__.py b/allensdk/brain_observatory/behavior/write_nwb/behavior/__main__.py index fb13a434dc..e6db3ea0e7 100644 --- a/allensdk/brain_observatory/behavior/write_nwb/behavior/__main__.py +++ b/allensdk/brain_observatory/behavior/write_nwb/behavior/__main__.py @@ -27,9 +27,7 @@ def run(self): nwb_filepath=bs_id_dir / f"behavior_session_{bs_id}.nwb", skip_metadata=self.args["skip_metadata_key"], skip_stim=self.args["skip_stimulus_file_key"], - include_experiment_description=self.args[ - 'include_experiment_description' - ] + include_experiment_description=self.args["include_experiment_description"], ) logging.info("File successfully created") @@ -46,7 +44,7 @@ def write_behavior_nwb( nwb_filepath: Path, skip_metadata: List[str], skip_stim: List[str], - include_experiment_description=True + include_experiment_description=True, ) -> str: """Load and write a BehaviorSession as NWB. @@ -88,7 +86,7 @@ def write_behavior_nwb( nwb_writer.write_nwb( skip_metadata=skip_metadata, skip_stim=skip_stim, - include_experiment_description=include_experiment_description + include_experiment_description=include_experiment_description, ) return str(nwb_filepath) diff --git a/allensdk/brain_observatory/behavior/write_nwb/behavior/schemas.py b/allensdk/brain_observatory/behavior/write_nwb/behavior/schemas.py index 843e6a5050..8c44a59f33 100644 --- a/allensdk/brain_observatory/behavior/write_nwb/behavior/schemas.py +++ b/allensdk/brain_observatory/behavior/write_nwb/behavior/schemas.py @@ -12,25 +12,14 @@ BehaviorSessionMetadataSchema, ) from argschema import ArgSchema -from argschema.fields import ( - Int, - List, - LogLevel, - Nested, - OutputDir, - OutputFile, - String, - Bool -) +from argschema.fields import Int, List, LogLevel, Nested, OutputDir, OutputFile, String, Bool class BaseInputSchema(ArgSchema): class Meta: unknown = mm.RAISE - log_level = LogLevel( - default="INFO", description="Logging level of the module" - ) + log_level = LogLevel(default="INFO", description="Logging level of the module") metadata_table = InputFile( required=True, description="CSV file containing rows of BehaviorSession or " @@ -58,49 +47,26 @@ class Meta: description="Path of output.json to be written", ) include_experiment_description = Bool( - required=False, - description="If True, include experiment description in NWB file.", - default=True + required=False, description="If True, include experiment description in NWB file.", default=True ) def _get_behavior_metadata(self, bs_row): """ """ behavior_session_metadata = {} - behavior_session_metadata["age_in_days"] = self._retrieve_value( - bs_row=bs_row, column_name="age_in_days" - ) - behavior_session_metadata["cre_line"] = self._retrieve_value( - bs_row=bs_row, column_name="cre_line" - ) + behavior_session_metadata["age_in_days"] = self._retrieve_value(bs_row=bs_row, column_name="age_in_days") + behavior_session_metadata["cre_line"] = self._retrieve_value(bs_row=bs_row, column_name="cre_line") behavior_session_metadata["date_of_acquisition"] = self._retrieve_value( # noqa: E501 bs_row=bs_row, column_name="date_of_acquisition" ) - behavior_session_metadata["driver_line"] = self._retrieve_value( - bs_row=bs_row, column_name="driver_line" - ) - behavior_session_metadata["equipment_name"] = self._retrieve_value( - bs_row=bs_row, column_name="equipment_name" - ) - behavior_session_metadata["full_genotype"] = self._retrieve_value( - bs_row=bs_row, column_name="full_genotype" - ) - behavior_session_metadata["mouse_id"] = self._retrieve_value( - bs_row=bs_row, column_name="mouse_id" - ) - behavior_session_metadata["project_code"] = self._retrieve_value( - bs_row=bs_row, column_name="project_code" - ) - behavior_session_metadata["reporter_line"] = self._retrieve_value( - bs_row=bs_row, column_name="reporter_line" - ) - behavior_session_metadata["session_type"] = self._retrieve_value( - bs_row=bs_row, column_name="session_type" - ) - behavior_session_metadata["sex"] = self._retrieve_value( - bs_row=bs_row, - column_name="sex" - ) + behavior_session_metadata["driver_line"] = self._retrieve_value(bs_row=bs_row, column_name="driver_line") + behavior_session_metadata["equipment_name"] = self._retrieve_value(bs_row=bs_row, column_name="equipment_name") + behavior_session_metadata["full_genotype"] = self._retrieve_value(bs_row=bs_row, column_name="full_genotype") + behavior_session_metadata["mouse_id"] = self._retrieve_value(bs_row=bs_row, column_name="mouse_id") + behavior_session_metadata["project_code"] = self._retrieve_value(bs_row=bs_row, column_name="project_code") + behavior_session_metadata["reporter_line"] = self._retrieve_value(bs_row=bs_row, column_name="reporter_line") + behavior_session_metadata["session_type"] = self._retrieve_value(bs_row=bs_row, column_name="session_type") + behavior_session_metadata["sex"] = self._retrieve_value(bs_row=bs_row, column_name="sex") return behavior_session_metadata @@ -121,10 +87,12 @@ def _retrieve_value(self, bs_row: pd.Series, column_name: str): bs_row """ if column_name not in bs_row.index: - warn(f"Warning, {column_name} not in metadata table. Unless this " - "has been added to the inputs skip_metadata_key or " - "skip_stimulus_file_key, creating the NWB file " - "may fail.") + warn( + f"Warning, {column_name} not in metadata table. Unless this " + "has been added to the inputs skip_metadata_key or " + "skip_stimulus_file_key, creating the NWB file " + "may fail." + ) return None else: value = bs_row[column_name] @@ -134,9 +102,7 @@ def _retrieve_value(self, bs_row: pd.Series, column_name: str): class BehaviorInputSchema(BaseInputSchema): - behavior_session_id = Int( - required=True, description="Id of BehaviorSession to create." - ) + behavior_session_id = Int(required=True, description="Id of BehaviorSession to create.") behavior_session_metadata = Nested( BehaviorSessionMetadataSchema, @@ -149,16 +115,13 @@ def retreive_metadata(self, data, **kwargs): """Load the data from csv using Pandas the same as the project_cloud api. """ - df = sanitize_data_columns( - data["metadata_table"], dtype_convert={"mouse_id": str} - ) + df = sanitize_data_columns(data["metadata_table"], dtype_convert={"mouse_id": str}) df.set_index("behavior_session_id", inplace=True) try: bs_row = df.loc[int(data["behavior_session_id"])] except KeyError: raise KeyError( - f"Behavior session id {data['behavior_session_id']} " - "not in input BehaviorSessionTable. Exiting." + f"Behavior session id {data['behavior_session_id']} not in input BehaviorSessionTable. Exiting." ) data["behavior_session_metadata"] = self._get_behavior_metadata(bs_row) diff --git a/allensdk/brain_observatory/behavior/write_nwb/extensions/event_detection/extension_builder.py b/allensdk/brain_observatory/behavior/write_nwb/extensions/event_detection/extension_builder.py index e04eaa82e8..8d4779dd7f 100644 --- a/allensdk/brain_observatory/behavior/write_nwb/extensions/event_detection/extension_builder.py +++ b/allensdk/brain_observatory/behavior/write_nwb/extensions/event_detection/extension_builder.py @@ -1,45 +1,33 @@ import os.path -from pynwb.spec import NWBNamespaceBuilder, export_spec, NWBGroupSpec, \ - NWBDatasetSpec +from pynwb.spec import NWBNamespaceBuilder, export_spec, NWBGroupSpec, NWBDatasetSpec -NAMESPACE = 'ndx-aibs-ophys-event-detection' +NAMESPACE = "ndx-aibs-ophys-event-detection" def main(): - ns_builder = NWBNamespaceBuilder( doc="Detected events from optical physiology ROI fluorescence traces", name=f"""{NAMESPACE}""", version="""0.1.0""", author="""Allen Institute for Brain Science""", - contact="""waynew@alleninstitute.org""" + contact="""waynew@alleninstitute.org""", ) - ns_builder.include_type('RoiResponseSeries', namespace='core') - ns_builder.include_type('DynamicTableRegion', namespace='core') - ns_builder.include_type('TimeSeries', namespace='core') - ns_builder.include_type('NWBDataInterface', namespace='core') + ns_builder.include_type("RoiResponseSeries", namespace="core") + ns_builder.include_type("DynamicTableRegion", namespace="core") + ns_builder.include_type("TimeSeries", namespace="core") + ns_builder.include_type("NWBDataInterface", namespace="core") ophys_events_spec = NWBGroupSpec( - neurodata_type_def='OphysEventDetection', - neurodata_type_inc='RoiResponseSeries', - name='event_detection', - doc='Stores event detection output', + neurodata_type_def="OphysEventDetection", + neurodata_type_inc="RoiResponseSeries", + name="event_detection", + doc="Stores event detection output", datasets=[ - NWBDatasetSpec( - name='lambdas', - dtype='float', - doc='calculated regularization weights', - shape=(None,) - ), - NWBDatasetSpec( - name='noise_stds', - dtype='float', - doc='calculated noise std deviations', - shape=(None,) - ) - ] + NWBDatasetSpec(name="lambdas", dtype="float", doc="calculated regularization weights", shape=(None,)), + NWBDatasetSpec(name="noise_stds", dtype="float", doc="calculated noise std deviations", shape=(None,)), + ], ) new_data_types = [ophys_events_spec] diff --git a/allensdk/brain_observatory/behavior/write_nwb/extensions/event_detection/ndx_ophys_events.py b/allensdk/brain_observatory/behavior/write_nwb/extensions/event_detection/ndx_ophys_events.py index 142c234173..ede77336ee 100644 --- a/allensdk/brain_observatory/behavior/write_nwb/extensions/event_detection/ndx_ophys_events.py +++ b/allensdk/brain_observatory/behavior/write_nwb/extensions/event_detection/ndx_ophys_events.py @@ -2,14 +2,10 @@ from pynwb import load_namespaces, get_class # Set path of the namespace.yaml file to the expected install location -ndx_ophys_events_specpath = os.path.join( - os.path.dirname(__file__), - 'ndx-aibs-ophys-event-detection.namespace.yaml' -) +ndx_ophys_events_specpath = os.path.join(os.path.dirname(__file__), "ndx-aibs-ophys-event-detection.namespace.yaml") # Load the namespace load_namespaces(ndx_ophys_events_specpath) -OphysEventDetection = get_class('OphysEventDetection', - 'ndx-aibs-ophys-event-detection') +OphysEventDetection = get_class("OphysEventDetection", "ndx-aibs-ophys-event-detection") diff --git a/allensdk/brain_observatory/behavior/write_nwb/extensions/stimulus_template/extension_builder.py b/allensdk/brain_observatory/behavior/write_nwb/extensions/stimulus_template/extension_builder.py index 954a2ac25f..3972333dc1 100644 --- a/allensdk/brain_observatory/behavior/write_nwb/extensions/stimulus_template/extension_builder.py +++ b/allensdk/brain_observatory/behavior/write_nwb/extensions/stimulus_template/extension_builder.py @@ -1,46 +1,43 @@ import os.path -from pynwb.spec import NWBNamespaceBuilder, export_spec, NWBGroupSpec, \ - NWBDatasetSpec +from pynwb.spec import NWBNamespaceBuilder, export_spec, NWBGroupSpec, NWBDatasetSpec -NAMESPACE = 'ndx-aibs-stimulus-template' +NAMESPACE = "ndx-aibs-stimulus-template" def main(): - ns_builder = NWBNamespaceBuilder( doc="Stimulus images", name=f"""{NAMESPACE}""", version="""0.1.0""", author="""Allen Institute for Brain Science""", - contact="""waynew@alleninstitute.org""" + contact="""waynew@alleninstitute.org""", ) - ns_builder.include_type('ImageSeries', namespace='core') - ns_builder.include_type('TimeSeries', namespace='core') - ns_builder.include_type('NWBDataInterface', namespace='core') + ns_builder.include_type("ImageSeries", namespace="core") + ns_builder.include_type("TimeSeries", namespace="core") + ns_builder.include_type("NWBDataInterface", namespace="core") stimulus_template_spec = NWBGroupSpec( - neurodata_type_def='StimulusTemplate', - neurodata_type_inc='ImageSeries', - doc='Note: image names in control_description are referenced by ' - 'stimulus/presentation table as well as intervals ' - '\n' - 'Each image shown to the animals is warped to account for ' - 'distance and eye position relative to the monitor. This ' - 'extension stores the warped images that were shown to the animal ' - 'as well as an unwarped version of each image in which a mask has ' - 'been applied such that only the pixels visible after warping are ' - 'included', + neurodata_type_def="StimulusTemplate", + neurodata_type_inc="ImageSeries", + doc="Note: image names in control_description are referenced by " + "stimulus/presentation table as well as intervals " + "\n" + "Each image shown to the animals is warped to account for " + "distance and eye position relative to the monitor. This " + "extension stores the warped images that were shown to the animal " + "as well as an unwarped version of each image in which a mask has " + "been applied such that only the pixels visible after warping are " + "included", datasets=[ NWBDatasetSpec( - name='unwarped', - dtype='float', - doc='Original image with mask applied such that only the ' - 'pixels visible after warping are included', - shape=(None, None, None) + name="unwarped", + dtype="float", + doc="Original image with mask applied such that only the pixels visible after warping are included", + shape=(None, None, None), ) - ] + ], ) new_data_types = [stimulus_template_spec] diff --git a/allensdk/brain_observatory/behavior/write_nwb/extensions/stimulus_template/ndx_stimulus_template.py b/allensdk/brain_observatory/behavior/write_nwb/extensions/stimulus_template/ndx_stimulus_template.py index 1e90c78da9..9bd4ab5139 100644 --- a/allensdk/brain_observatory/behavior/write_nwb/extensions/stimulus_template/ndx_stimulus_template.py +++ b/allensdk/brain_observatory/behavior/write_nwb/extensions/stimulus_template/ndx_stimulus_template.py @@ -2,14 +2,10 @@ from pynwb import load_namespaces, get_class # Set path of the namespace.yaml file to the expected install location -ndx_stimulus_template_specpath = os.path.join( - os.path.dirname(__file__), - 'ndx-aibs-stimulus-template.namespace.yaml' -) +ndx_stimulus_template_specpath = os.path.join(os.path.dirname(__file__), "ndx-aibs-stimulus-template.namespace.yaml") # Load the namespace load_namespaces(ndx_stimulus_template_specpath) -StimulusTemplateExtension = get_class('StimulusTemplate', - 'ndx-aibs-stimulus-template') +StimulusTemplateExtension = get_class("StimulusTemplate", "ndx-aibs-stimulus-template") diff --git a/allensdk/brain_observatory/behavior/write_nwb/nwb_writer_utils.py b/allensdk/brain_observatory/behavior/write_nwb/nwb_writer_utils.py index 1b1c3836e3..8b30e14e04 100644 --- a/allensdk/brain_observatory/behavior/write_nwb/nwb_writer_utils.py +++ b/allensdk/brain_observatory/behavior/write_nwb/nwb_writer_utils.py @@ -10,10 +10,7 @@ class OphysNwbWriter(NWBWriter): """NWBWriter with additional options to modify targeted_imaging_depth.""" def _update_session( - self, - lims_session: BehaviorSession, - ophys_experiment_ids: Optional[List[int]] = None, - **kwargs + self, lims_session: BehaviorSession, ophys_experiment_ids: Optional[List[int]] = None, **kwargs ) -> BehaviorSession: """Call session methods to update certain values within the session. diff --git a/allensdk/brain_observatory/behavior/write_nwb/ophys/__main__.py b/allensdk/brain_observatory/behavior/write_nwb/ophys/__main__.py index 77e92c6928..a7c14e04f1 100644 --- a/allensdk/brain_observatory/behavior/write_nwb/ophys/__main__.py +++ b/allensdk/brain_observatory/behavior/write_nwb/ophys/__main__.py @@ -29,14 +29,11 @@ def run(self): nwb_filepath=oe_id_dir / f"behavior_ophys_experiment_{oe_id}.nwb", skip_metadata=self.args["skip_metadata_key"], skip_stim=self.args["skip_stimulus_file_key"], - include_experiment_description=self.args[ - 'include_experiment_description' - ] + include_experiment_description=self.args["include_experiment_description"], ) logging.info("File successfully created") - output_dict = {"output_path": output_file, - "input_parameters": self.args} + output_dict = {"output_path": output_file, "input_parameters": self.args} self.output(output_dict) @@ -46,7 +43,7 @@ def write_experiment_nwb( nwb_filepath: Path, skip_metadata: List[str], skip_stim: List[str], - include_experiment_description=True + include_experiment_description=True, ) -> str: """Load and write a BehaviorOphysExperiment as NWB. @@ -89,7 +86,7 @@ def write_experiment_nwb( ophys_experiment_ids=self.args["ophys_container_experiment_ids"], skip_metadata=skip_metadata, skip_stim=skip_stim, - include_experiment_description=include_experiment_description + include_experiment_description=include_experiment_description, ) return str(nwb_filepath) diff --git a/allensdk/brain_observatory/behavior/write_nwb/ophys/schemas.py b/allensdk/brain_observatory/behavior/write_nwb/ophys/schemas.py index 8cb06ae6e4..d3c76043fb 100644 --- a/allensdk/brain_observatory/behavior/write_nwb/ophys/schemas.py +++ b/allensdk/brain_observatory/behavior/write_nwb/ophys/schemas.py @@ -14,9 +14,7 @@ class OphysExperimentInputSchema(BaseInputSchema): - ophys_experiment_id = Int( - required=True, description="Id of OphysExperiment to create." - ) + ophys_experiment_id = Int(required=True, description="Id of OphysExperiment to create.") ophys_container_experiment_ids = List( Int, @@ -42,9 +40,7 @@ def retreive_metadata(self, data, **kwargs): """Load the data from csv using Pandas the same as the project_cloud api. """ - df = sanitize_data_columns( - data["metadata_table"], dtype_convert={"mouse_id": str} - ) + df = sanitize_data_columns(data["metadata_table"], dtype_convert={"mouse_id": str}) df.set_index("ophys_experiment_id", inplace=True) try: # Enforce type as we haven't enfoced type in the @@ -52,39 +48,24 @@ def retreive_metadata(self, data, **kwargs): oe_row = df.loc[int(data["ophys_experiment_id"])] except KeyError: raise KeyError( - f"Ophys experiment id {data['ophys_experiment_id']} " - "not in input OphysExperimentTable. Exiting." + f"Ophys experiment id {data['ophys_experiment_id']} not in input OphysExperimentTable. Exiting." ) data["ophys_experiment_metadata"] = self._get_behavior_metadata(oe_row) - data["ophys_experiment_metadata"]["behavior_session_id"] = oe_row[ - "behavior_session_id" - ] + data["ophys_experiment_metadata"]["behavior_session_id"] = oe_row["behavior_session_id"] # Ophys Experiment specific data. - data["ophys_experiment_metadata"]["imaging_depth"] = oe_row[ - "imaging_depth" - ] + data["ophys_experiment_metadata"]["imaging_depth"] = oe_row["imaging_depth"] imaging_plane_group = oe_row["imaging_plane_group"] if pd.isna(imaging_plane_group): imaging_plane_group = None - data["ophys_experiment_metadata"][ - "imaging_plane_group" - ] = imaging_plane_group + data["ophys_experiment_metadata"]["imaging_plane_group"] = imaging_plane_group data["ophys_experiment_metadata"]["indicator"] = oe_row["indicator"] - data["ophys_experiment_metadata"]["ophys_container_id"] = oe_row[ - "ophys_container_id" - ] + data["ophys_experiment_metadata"]["ophys_container_id"] = oe_row["ophys_container_id"] data["ophys_experiment_metadata"]["ophys_experiment_id"] = oe_row.name - data["ophys_experiment_metadata"]["ophys_session_id"] = oe_row[ - "ophys_session_id" - ] - data["ophys_experiment_metadata"]["targeted_imaging_depth"] = oe_row[ - "targeted_imaging_depth" - ] - data["ophys_experiment_metadata"]["targeted_structure"] = oe_row[ - "targeted_structure" - ] + data["ophys_experiment_metadata"]["ophys_session_id"] = oe_row["ophys_session_id"] + data["ophys_experiment_metadata"]["targeted_imaging_depth"] = oe_row["targeted_imaging_depth"] + data["ophys_experiment_metadata"]["targeted_structure"] = oe_row["targeted_structure"] data["ophys_container_experiment_ids"] = df[ df["ophys_container_id"] == oe_row["ophys_container_id"] diff --git a/allensdk/brain_observatory/brain_observatory_exceptions.py b/allensdk/brain_observatory/brain_observatory_exceptions.py index 481e09a05c..7b8a2ba9fc 100644 --- a/allensdk/brain_observatory/brain_observatory_exceptions.py +++ b/allensdk/brain_observatory/brain_observatory_exceptions.py @@ -36,16 +36,17 @@ class BrainObservatoryAnalysisException(Exception): pass + class MissingStimulusException(Exception): pass + class NoEyeTrackingException(Exception): pass -class EpochSeparationException(Exception): - def __init__(self, *args, **kwargs): - - self.delta = kwargs.pop('delta') +class EpochSeparationException(Exception): + def __init__(self, *args, **kwargs): + self.delta = kwargs.pop("delta") - super(EpochSeparationException, self).__init__(*args, **kwargs) \ No newline at end of file + super(EpochSeparationException, self).__init__(*args, **kwargs) diff --git a/allensdk/brain_observatory/brain_observatory_plotting.py b/allensdk/brain_observatory/brain_observatory_plotting.py index a69206774f..87271b71fb 100644 --- a/allensdk/brain_observatory/brain_observatory_plotting.py +++ b/allensdk/brain_observatory/brain_observatory_plotting.py @@ -42,76 +42,77 @@ def plot_drifting_grating_traces(dg, save_dir): - '''saves figures with a Ori X TF grid of mean resposes''' + """saves figures with a Ori X TF grid of mean resposes""" logging.info("Plotting Ori and TF mean response for all cells") blank = dg.sweep_response[dg.stim_table.temporal_frequency == 0] for nc in range(dg.numbercells): if np.mod(nc, 20) == 0: logging.info("Cell #%s", str(nc)) - xtime = np.arange(-1 * dg.interlength / dg.acquisition_rate, (dg.sweeplength + - dg.interlength) / dg.acquisition_rate, 1 / dg.acquisition_rate) + xtime = np.arange( + -1 * dg.interlength / dg.acquisition_rate, + (dg.sweeplength + dg.interlength) / dg.acquisition_rate, + 1 / dg.acquisition_rate, + ) plt.figure(nc, figsize=(20, 16)) vmax = 0 vmin = 0 try: - blank_p = blank[str(nc)].mean() + \ - (blank[str(nc)].std() / len(blank[str(nc)])) - blank_n = blank[str(nc)].mean() - \ - (blank[str(nc)].std() / len(blank[str(nc)])) + blank_p = blank[str(nc)].mean() + (blank[str(nc)].std() / len(blank[str(nc)])) + blank_n = blank[str(nc)].mean() - (blank[str(nc)].std() / len(blank[str(nc)])) except Exception: - blank_p = blank.iloc[:, nc].apply( - np.mean) + (blank.iloc[:, nc].apply(np.std) / blank.iloc[:, nc].apply(len)) - blank_n = blank.iloc[:, nc].apply( - np.mean) - (blank.iloc[:, nc].apply(np.std) / blank.iloc[:, nc].apply(len)) + blank_p = blank.iloc[:, nc].apply(np.mean) + ( + blank.iloc[:, nc].apply(np.std) / blank.iloc[:, nc].apply(len) + ) + blank_n = blank.iloc[:, nc].apply(np.mean) - ( + blank.iloc[:, nc].apply(np.std) / blank.iloc[:, nc].apply(len) + ) for ori in dg.orivals: ori_pt = np.where(dg.orivals == ori)[0][0] for tf in dg.tfvals[1:]: tf_pt = np.where(dg.tfvals == tf)[0][0] sp_pt = (5 * ori_pt) + tf_pt subset_response = dg.sweep_response[ - (dg.stim_table.temporal_frequency == tf) & (dg.stim_table.orientation == ori)] + (dg.stim_table.temporal_frequency == tf) & (dg.stim_table.orientation == ori) + ] try: - subset_response_p = subset_response[str(nc)].mean( - ) + (subset_response[str(nc)][:-1].std() / len(subset_response[str(nc)])) - subset_response_n = subset_response[str(nc)].mean( - ) - (subset_response[str(nc)][:-1].std() / len(subset_response[str(nc)])) + subset_response_p = subset_response[str(nc)].mean() + ( + subset_response[str(nc)][:-1].std() / len(subset_response[str(nc)]) + ) + subset_response_n = subset_response[str(nc)].mean() - ( + subset_response[str(nc)][:-1].std() / len(subset_response[str(nc)]) + ) except Exception: - subset_response_p = subset_response.iloc[:, nc].apply( - np.mean) + (subset_response.iloc[:, nc].apply(np.std) / subset_response.iloc[:, nc].apply(len)) - subset_response_n = subset_response.iloc[:, nc].apply( - np.mean) - (subset_response.iloc[:, nc].apply(np.std) / subset_response.iloc[:, nc].apply(len)) + subset_response_p = subset_response.iloc[:, nc].apply(np.mean) + ( + subset_response.iloc[:, nc].apply(np.std) / subset_response.iloc[:, nc].apply(len) + ) + subset_response_n = subset_response.iloc[:, nc].apply(np.mean) - ( + subset_response.iloc[:, nc].apply(np.std) / subset_response.iloc[:, nc].apply(len) + ) ax = plt.subplot(8, 5, sp_pt) while len(xtime) > len(subset_response[str(nc)].mean()): xtime = np.delete(xtime, -1) try: - ax.fill_between(xtime, subset_response_p, - subset_response_n, color='b', alpha=0.5) + ax.fill_between(xtime, subset_response_p, subset_response_n, color="b", alpha=0.5) except Exception: pass try: - ax.fill_between(xtime, blank_p, blank_n, - color='k', alpha=0.5) + ax.fill_between(xtime, blank_p, blank_n, color="k", alpha=0.5) except Exception: pass try: - ax.plot(xtime, subset_response[ - str(nc)].mean(), color='b', lw=2) + ax.plot(xtime, subset_response[str(nc)].mean(), color="b", lw=2) except Exception: pass - ax.plot(xtime, subset_response[ - str(nc)].mean(), color='b', lw=2) + ax.plot(xtime, subset_response[str(nc)].mean(), color="b", lw=2) # TODO: remove the [:119] and [:-1] and the try/except - ax.plot(xtime, blank[str(nc)].mean(), color='k', lw=2) - ax.axvspan(0, dg.sweeplength / dg.acquisition_rate, - ymin=0, ymax=1, facecolor='gray', alpha=0.3) + ax.plot(xtime, blank[str(nc)].mean(), color="k", lw=2) + ax.axvspan(0, dg.sweeplength / dg.acquisition_rate, ymin=0, ymax=1, facecolor="gray", alpha=0.3) ax.set_xlim(-1, 3) ax.set_xticks(range(-1, 4)) ax.yaxis.set_major_locator(MaxNLocator(4)) - vmax = np.where(np.amax(subset_response_p) > - vmax, np.amax(subset_response_p), vmax) - vmin = np.where(np.amin(subset_response_n) < - vmin, np.amin(subset_response_n), vmin) + vmax = np.where(np.amax(subset_response_p) > vmax, np.amax(subset_response_p), vmax) + vmin = np.where(np.amin(subset_response_n) < vmin, np.amin(subset_response_n), vmin) if ori_pt < 7: ax.set_xticks([]) @@ -131,7 +132,7 @@ def plot_drifting_grating_traces(dg, save_dir): plt.tight_layout() plt.suptitle("Cell " + str(nc + 1), fontsize=20) plt.subplots_adjust(top=0.9) - filename = 'Traces DG Cell_' + str(nc + 1) + '.png' + filename = "Traces DG Cell_" + str(nc + 1) + ".png" fullfilename = os.path.join(save_dir, filename) plt.savefig(fullfilename) plt.close() @@ -139,43 +140,41 @@ def plot_drifting_grating_traces(dg, save_dir): def plot_ns_traces(nsa, save_dir): logging.info("Plotting Natural Scene traces for each cell") - xtime = np.arange(-1 * nsa.interlength / nsa.acquisition_rate, (nsa.sweeplength + - nsa.interlength) / nsa.acquisition_rate, 1 / nsa.acquisition_rate) + xtime = np.arange( + -1 * nsa.interlength / nsa.acquisition_rate, + (nsa.sweeplength + nsa.interlength) / nsa.acquisition_rate, + 1 / nsa.acquisition_rate, + ) blank = nsa.sweep_response[nsa.stim_table.frame == -1] for nc in range(nsa.numbercells): if np.mod(nc, 20) == 0: logging.info("Cell #%s", str(nc)) vmax = 0 vmin = 0 - blank_p = blank[str(nc)].mean() + \ - (blank[str(nc)].std() / len(blank[str(nc)])) - blank_n = blank[str(nc)].mean() - \ - (blank[str(nc)].std() / len(blank[str(nc)])) + blank_p = blank[str(nc)].mean() + (blank[str(nc)].std() / len(blank[str(nc)])) + blank_n = blank[str(nc)].mean() - (blank[str(nc)].std() / len(blank[str(nc)])) plt.figure(nc, figsize=(30, 25)) for ns in range(nsa.number_scenes - 1): subset_response = nsa.sweep_response[nsa.stim_table.frame == ns] - subset_response_p = subset_response[str(nc)].mean( - ) + (subset_response[str(nc)][:].std() / len(subset_response[str(nc)])) - subset_response_n = subset_response[str(nc)].mean( - ) - (subset_response[str(nc)][:].std() / len(subset_response[str(nc)])) + subset_response_p = subset_response[str(nc)].mean() + ( + subset_response[str(nc)][:].std() / len(subset_response[str(nc)]) + ) + subset_response_n = subset_response[str(nc)].mean() - ( + subset_response[str(nc)][:].std() / len(subset_response[str(nc)]) + ) ax = plt.subplot(10, 12, ns + 1) try: - ax.fill_between(xtime, subset_response_p, - subset_response_n, color='b', alpha=0.5) + ax.fill_between(xtime, subset_response_p, subset_response_n, color="b", alpha=0.5) except Exception: xtime = xtime[:-1] - ax.fill_between(xtime, subset_response_p, - subset_response_n, color='b', alpha=0.5) - ax.fill_between(xtime, blank_p, blank_n, color='k', alpha=0.5) - ax.plot(xtime, subset_response[str(nc)].mean(), color='b', lw=2) - ax.plot(xtime, blank[str(nc)].mean(), color='k', lw=2) - ax.axvspan(0, nsa.sweeplength / nsa.acquisition_rate, - ymin=0, ymax=1, facecolor='gray', alpha=0.3) + ax.fill_between(xtime, subset_response_p, subset_response_n, color="b", alpha=0.5) + ax.fill_between(xtime, blank_p, blank_n, color="k", alpha=0.5) + ax.plot(xtime, subset_response[str(nc)].mean(), color="b", lw=2) + ax.plot(xtime, blank[str(nc)].mean(), color="k", lw=2) + ax.axvspan(0, nsa.sweeplength / nsa.acquisition_rate, ymin=0, ymax=1, facecolor="gray", alpha=0.3) ax.yaxis.set_major_locator(MaxNLocator(4)) - vmax = np.where(np.amax(subset_response_p) > vmax, - np.amax(subset_response_p), vmax) - vmin = np.where(np.amin(subset_response_n) < vmin, - np.amin(subset_response_n), vmin) + vmax = np.where(np.amax(subset_response_p) > vmax, np.amax(subset_response_p), vmax) + vmin = np.where(np.amin(subset_response_n) < vmin, np.amin(subset_response_n), vmin) if ns < 108: ax.set_xticks([]) if np.mod(ns, 12): @@ -186,7 +185,7 @@ def plot_ns_traces(nsa, save_dir): plt.tight_layout() plt.suptitle("Cell " + str(nc + 1), fontsize=20) plt.subplots_adjust(top=0.9) - filename = 'NS Traces Cell_' + str(nc + 1) + '.png' + filename = "NS Traces Cell_" + str(nc + 1) + ".png" fullfilename = os.path.join(save_dir, filename) plt.savefig(fullfilename) plt.close() @@ -194,18 +193,19 @@ def plot_ns_traces(nsa, save_dir): def plot_sg_traces(sg, save_dir): logging.info("Plotting Static Grating traces for each cell") - xtime = np.arange(-1 * sg.interlength / sg.acquisition_rate, (sg.sweeplength + - sg.interlength) / sg.acquisition_rate, 1 / sg.acquisition_rate) + xtime = np.arange( + -1 * sg.interlength / sg.acquisition_rate, + (sg.sweeplength + sg.interlength) / sg.acquisition_rate, + 1 / sg.acquisition_rate, + ) blank = sg.sweep_response[sg.stim_table.spatial_frequency == 0] for nc in range(sg.numbercells): if np.mod(nc, 20) == 0: logging.info("Cell #%s", str(nc)) vmax = 0 vmin = 0 - blank_p = blank[str(nc)].mean() + \ - (blank[str(nc)].std() / len(blank[str(nc)])) - blank_n = blank[str(nc)].mean() - \ - (blank[str(nc)].std() / len(blank[str(nc)])) + blank_p = blank[str(nc)].mean() + (blank[str(nc)].std() / len(blank[str(nc)])) + blank_n = blank[str(nc)].mean() - (blank[str(nc)].std() / len(blank[str(nc)])) while len(xtime) > len(blank_p): xtime = np.delete(xtime, -1) plt.figure(nc, figsize=(30, 30)) @@ -217,27 +217,26 @@ def plot_sg_traces(sg, save_dir): for phase in sg.phasevals: ph_pt = ph_dict[phase] subplotnum = sf_pt + (ori_pt * 11) + ph_pt - subset_response = sg.sweep_response[(sg.stim_table.spatial_frequency == sf) & ( - sg.stim_table.orientation == ori) & (sg.stim_table.phase == phase)] - subset_response_p = subset_response[str(nc)].mean( - ) + (subset_response[str(nc)][:].std() / len(subset_response[str(nc)])) - subset_response_n = subset_response[str(nc)].mean( - ) - (subset_response[str(nc)][:].std() / len(subset_response[str(nc)])) + subset_response = sg.sweep_response[ + (sg.stim_table.spatial_frequency == sf) + & (sg.stim_table.orientation == ori) + & (sg.stim_table.phase == phase) + ] + subset_response_p = subset_response[str(nc)].mean() + ( + subset_response[str(nc)][:].std() / len(subset_response[str(nc)]) + ) + subset_response_n = subset_response[str(nc)].mean() - ( + subset_response[str(nc)][:].std() / len(subset_response[str(nc)]) + ) ax = plt.subplot(13, 11, subplotnum) - ax.fill_between(xtime, subset_response_p, - subset_response_n, color='b', alpha=0.5) - ax.fill_between(xtime, blank_p, blank_n, - color='k', alpha=0.5) - ax.plot(xtime, subset_response[ - str(nc)].mean(), color='b', lw=2) - ax.plot(xtime, blank[str(nc)].mean(), color='k', lw=2) - ax.axvspan(0, sg.sweeplength / sg.acquisition_rate, - ymin=0, ymax=1, facecolor='gray', alpha=0.3) + ax.fill_between(xtime, subset_response_p, subset_response_n, color="b", alpha=0.5) + ax.fill_between(xtime, blank_p, blank_n, color="k", alpha=0.5) + ax.plot(xtime, subset_response[str(nc)].mean(), color="b", lw=2) + ax.plot(xtime, blank[str(nc)].mean(), color="k", lw=2) + ax.axvspan(0, sg.sweeplength / sg.acquisition_rate, ymin=0, ymax=1, facecolor="gray", alpha=0.3) ax.yaxis.set_major_locator(MaxNLocator(4)) - vmax = np.where(np.amax(subset_response_p) - > vmax, np.amax(subset_response_p), vmax) - vmin = np.where(np.amin(subset_response_n) - < vmin, np.amin(subset_response_n), vmin) + vmax = np.where(np.amax(subset_response_p) > vmax, np.amax(subset_response_p), vmax) + vmin = np.where(np.amin(subset_response_n) < vmin, np.amin(subset_response_n), vmin) if np.mod(subplotnum, 11) != 1: ax.set_yticks([]) else: @@ -260,17 +259,19 @@ def plot_sg_traces(sg, save_dir): plt.tight_layout() plt.suptitle("Cell " + str(nc + 1), fontsize=20) plt.subplots_adjust(top=0.9) - filename = 'SG Traces Cell_' + str(nc + 1) + '.png' + filename = "SG Traces Cell_" + str(nc + 1) + ".png" fullfilename = os.path.join(save_dir, filename) plt.savefig(fullfilename) plt.close() -def plot_lsn_traces(lsn, save_dir, suffix=''): +def plot_lsn_traces(lsn, save_dir, suffix=""): logging.info("Plotting LSN traces for all cells") - xtime = np.arange(-lsn.interlength / lsn.acquisition_rate, - (lsn.interlength + lsn.sweeplength) / lsn.acquisition_rate, - 1.0 / lsn.acquisition_rate) + xtime = np.arange( + -lsn.interlength / lsn.acquisition_rate, + (lsn.interlength + lsn.sweeplength) / lsn.acquisition_rate, + 1.0 / lsn.acquisition_rate, + ) for nc in range(lsn.numbercells): if np.mod(nc, 20) == 0: @@ -293,18 +294,13 @@ def plot_lsn_traces(lsn, save_dir, suffix=''): subset_off_mean = subset_off.mean() ax = plt.subplot(16, 28, sp_pt) - ax.plot(xtime, subset_on_mean, color='r', lw=2) - ax.plot(xtime, subset_off_mean, color='b', lw=2) - ax.axvspan(0, lsn.sweeplength / lsn.acquisition_rate, - ymin=0, ymax=1, facecolor='gray', alpha=0.3) - vmax = np.where(np.amax(subset_on_mean) > vmax, - np.amax(subset_on_mean), vmax) - vmax = np.where(np.amax(subset_off_mean) > vmax, - np.amax(subset_off_mean), vmax) - vmin = np.where(np.amin(subset_on_mean) < vmin, - np.amin(subset_on_mean), vmin) - vmin = np.where(np.amin(subset_off_mean) < vmin, - np.amin(subset_off_mean), vmin) + ax.plot(xtime, subset_on_mean, color="r", lw=2) + ax.plot(xtime, subset_off_mean, color="b", lw=2) + ax.axvspan(0, lsn.sweeplength / lsn.acquisition_rate, ymin=0, ymax=1, facecolor="gray", alpha=0.3) + vmax = np.where(np.amax(subset_on_mean) > vmax, np.amax(subset_on_mean), vmax) + vmax = np.where(np.amax(subset_off_mean) > vmax, np.amax(subset_off_mean), vmax) + vmin = np.where(np.amin(subset_on_mean) < vmin, np.amin(subset_on_mean), vmin) + vmin = np.where(np.amin(subset_off_mean) < vmin, np.amin(subset_off_mean), vmin) ax.set_xticks([]) ax.set_yticks([]) @@ -315,7 +311,7 @@ def plot_lsn_traces(lsn, save_dir, suffix=''): plt.tight_layout() plt.suptitle("Cell " + str(nc + 1), fontsize=20) plt.subplots_adjust(top=0.9) - filename = 'Traces LSN Cell_' + str(nc + 1) + suffix + '.png' + filename = "Traces LSN Cell_" + str(nc + 1) + suffix + ".png" fullfilename = os.path.join(save_dir, filename) plt.savefig(fullfilename) plt.close() @@ -341,57 +337,81 @@ def _plot_3sa(dg, nm1, nm3, save_dir): ax12 = plt.subplot2grid((6, 6), (4, 3), colspan=3) ax13 = plt.subplot2grid((6, 6), (5, 3), colspan=3) - xtime = np.arange(0, np.size(dg.celltraces, 1), 1.) + xtime = np.arange(0, np.size(dg.celltraces, 1), 1.0) xtime /= dg.acquisition_rate - dif = np.ediff1d(dg.stim_table.start.values, - to_begin=8000, to_end=8000) + dif = np.ediff1d(dg.stim_table.start.values, to_begin=8000, to_end=8000) test = np.argwhere(dif > 5000) ax1.plot(xtime, dg.celltraces[nc, :]) for i in range(len(test) - 1): - ax1.axvspan(xmin=(dg.stim_table.start.iloc[test[i]].values / dg.acquisition_rate), xmax=( - dg.stim_table.end.iloc[test[i + 1] - 1].values / dg.acquisition_rate), color='gray', alpha=0.3) - ax1.axvspan(xmin=nm1.stim_table.start.min() / nm1.acquisition_rate, xmax=( - (nm1.stim_table.start.max() + nm1.sweeplength) / nm1.acquisition_rate), color='red', alpha=0.3) - dif = np.ediff1d(nm3.stim_table.start.values, - to_begin=8000, to_end=8000) + ax1.axvspan( + xmin=(dg.stim_table.start.iloc[test[i]].values / dg.acquisition_rate), + xmax=(dg.stim_table.end.iloc[test[i + 1] - 1].values / dg.acquisition_rate), + color="gray", + alpha=0.3, + ) + ax1.axvspan( + xmin=nm1.stim_table.start.min() / nm1.acquisition_rate, + xmax=((nm1.stim_table.start.max() + nm1.sweeplength) / nm1.acquisition_rate), + color="red", + alpha=0.3, + ) + dif = np.ediff1d(nm3.stim_table.start.values, to_begin=8000, to_end=8000) test = np.argwhere(dif > 5000) for i in range(len(test) - 1): - ax1.axvspan(xmin=(nm3.stim_table.start.iloc[test[i]].values / nm3.acquisition_rate), xmax=( - (nm3.stim_table.end.iloc[test[i + 1] - 1].values + nm3.sweeplength) / nm3.acquisition_rate), color='blue', alpha=0.3) + ax1.axvspan( + xmin=(nm3.stim_table.start.iloc[test[i]].values / nm3.acquisition_rate), + xmax=((nm3.stim_table.end.iloc[test[i + 1] - 1].values + nm3.sweeplength) / nm3.acquisition_rate), + color="blue", + alpha=0.3, + ) ax1.set_xlabel("Time (s)", fontsize=20) ax1.set_ylabel("Fluorescence", fontsize=20) ax2.hist(dg.celltraces[nc, :], bins=70) - ax2.set_yscale('log') + ax2.set_yscale("log") ax2.set_xlabel("Fluorescence", fontsize=20) ax2.set_ylabel("Count", fontsize=20) - xtime = np.arange(0, np.size(dg.dxcm), 1.) + xtime = np.arange(0, np.size(dg.dxcm), 1.0) xtime /= dg.acquisition_rate - ax3.plot(xtime, dg.dxcm, color='k') + ax3.plot(xtime, dg.dxcm, color="k") ax3.set_xlabel("Time (s)", fontsize=20) ax3.set_xlabel("Speed (cm/s)", fontsize=20) - smax = nm1.binned_cells_sp[nc, np.argmax(nm1.binned_cells_sp[ - nc, :, 0]), 0] + nm1.binned_cells_sp[nc, np.argmax(nm1.binned_cells_sp[nc, :, 0]), 1] - vmax = nm1.binned_cells_vis[nc, np.argmax(nm1.binned_cells_vis[ - nc, :, 0]), 0] + nm1.binned_cells_vis[nc, np.argmax(nm1.binned_cells_vis[nc, :, 0]), 1] + smax = ( + nm1.binned_cells_sp[nc, np.argmax(nm1.binned_cells_sp[nc, :, 0]), 0] + + nm1.binned_cells_sp[nc, np.argmax(nm1.binned_cells_sp[nc, :, 0]), 1] + ) + vmax = ( + nm1.binned_cells_vis[nc, np.argmax(nm1.binned_cells_vis[nc, :, 0]), 0] + + nm1.binned_cells_vis[nc, np.argmax(nm1.binned_cells_vis[nc, :, 0]), 1] + ) rmax = np.where(smax > vmax, smax, vmax) - smin = nm1.binned_cells_sp[nc, np.argmin(nm1.binned_cells_sp[ - nc, :, 0]), 0] - nm1.binned_cells_sp[nc, np.argmin(nm1.binned_cells_sp[nc, :, 0]), 1] - vmin = nm1.binned_cells_vis[nc, np.argmin(nm1.binned_cells_vis[ - nc, :, 0]), 0] - nm1.binned_cells_vis[nc, np.argmin(nm1.binned_cells_vis[nc, :, 0]), 1] + smin = ( + nm1.binned_cells_sp[nc, np.argmin(nm1.binned_cells_sp[nc, :, 0]), 0] + - nm1.binned_cells_sp[nc, np.argmin(nm1.binned_cells_sp[nc, :, 0]), 1] + ) + vmin = ( + nm1.binned_cells_vis[nc, np.argmin(nm1.binned_cells_vis[nc, :, 0]), 0] + - nm1.binned_cells_vis[nc, np.argmin(nm1.binned_cells_vis[nc, :, 0]), 1] + ) rmin = np.where(smin < vmin, smin, vmin) - ax4.errorbar(nm1.binned_dx_sp[:, 0], nm1.binned_cells_sp[ - nc, :, 0], yerr=nm1.binned_cells_sp[nc, :, 1], fmt='.', color='k') + ax4.errorbar( + nm1.binned_dx_sp[:, 0], + nm1.binned_cells_sp[nc, :, 0], + yerr=nm1.binned_cells_sp[nc, :, 1], + fmt=".", + color="k", + ) ax4.set_ylim(rmin, rmax) ax4.set_xlabel("Speed (cm/s)", fontsize=20) ax4.set_ylabel("DF/F", fontsize=20) ax4.set_title("Spontaneous", fontsize=20) - ax5.errorbar(nm1.binned_dx_vis[:, 0], nm1.binned_cells_vis[ - nc, :, 0], yerr=nm1.binned_cells_vis[nc, :, 1], fmt='.') + ax5.errorbar( + nm1.binned_dx_vis[:, 0], nm1.binned_cells_vis[nc, :, 0], yerr=nm1.binned_cells_vis[nc, :, 1], fmt="." + ) ax5.set_ylim(rmin, rmax) ax5.set_xlabel("Speed (cm/s)", fontsize=20) ax5.set_ylabel("DF/F", fontsize=20) @@ -399,51 +419,63 @@ def _plot_3sa(dg, nm1, nm3, save_dir): peakori = dg.peak.ori_dg[nc] peaktf = dg.peak.tf_dg[nc] - ax6.errorbar(dg.orivals, dg.response[:, peaktf, nc, 0], yerr=dg.response[ - :, peaktf, nc, 1], fmt='bo-', lw=2) - ax6.fill_between(dg.orivals, np.repeat(dg.response[0, 0, nc, 0] + dg.response[0, 0, nc, 1], dg.number_ori), np.repeat( - dg.response[0, 0, nc, 0] - dg.response[0, 0, nc, 1], dg.number_ori), color='gray', alpha=0.5) - ax6.axhline(y=dg.response[0, 0, nc, 0], ls='--', color='k') - ax6.annotate(str(dg.tfvals[peaktf]) + " Hz", - xy=(0, 0.9), xycoords='axes fraction', fontsize=14) + ax6.errorbar(dg.orivals, dg.response[:, peaktf, nc, 0], yerr=dg.response[:, peaktf, nc, 1], fmt="bo-", lw=2) + ax6.fill_between( + dg.orivals, + np.repeat(dg.response[0, 0, nc, 0] + dg.response[0, 0, nc, 1], dg.number_ori), + np.repeat(dg.response[0, 0, nc, 0] - dg.response[0, 0, nc, 1], dg.number_ori), + color="gray", + alpha=0.5, + ) + ax6.axhline(y=dg.response[0, 0, nc, 0], ls="--", color="k") + ax6.annotate(str(dg.tfvals[peaktf]) + " Hz", xy=(0, 0.9), xycoords="axes fraction", fontsize=14) ax6.set_xticks(dg.orivals) ax7.set_xlim(-10, 325) ax6.set_xlabel("Direction (d)", fontsize=20) ax6.set_ylabel("Mean DF/F (%)", fontsize=20) ax6.yaxis.set_major_locator(MaxNLocator(6)) - ax7.errorbar(range(5), dg.response[peakori, 1:, nc, 0], yerr=dg.response[ - peakori, 1:, nc, 1], fmt='bo-', lw=2) - ax7.fill_between(range(5), np.repeat(dg.response[0, 0, nc, 0] + dg.response[0, 0, nc, 1], 5), np.repeat( - dg.response[0, 0, nc, 0] - dg.response[0, 0, nc, 1], 5), color='gray', alpha=0.5) - ax7.axhline(y=dg.response[0, 0, nc, 0], ls='--', color='k', lw=2) - ax7.annotate(str(dg.orivals[peakori]) + " Deg", - xy=(0, 0.9), xycoords='axes fraction', fontsize=14) + ax7.errorbar(range(5), dg.response[peakori, 1:, nc, 0], yerr=dg.response[peakori, 1:, nc, 1], fmt="bo-", lw=2) + ax7.fill_between( + range(5), + np.repeat(dg.response[0, 0, nc, 0] + dg.response[0, 0, nc, 1], 5), + np.repeat(dg.response[0, 0, nc, 0] - dg.response[0, 0, nc, 1], 5), + color="gray", + alpha=0.5, + ) + ax7.axhline(y=dg.response[0, 0, nc, 0], ls="--", color="k", lw=2) + ax7.annotate(str(dg.orivals[peakori]) + " Deg", xy=(0, 0.9), xycoords="axes fraction", fontsize=14) ax7.set_xlim(-0.2, 4.2) ax7.set_xticks(range(5)) ax7.set_xticklabels(dg.tfvals[1:]) ax7.set_xlabel("Temporal frequency (Hz)", fontsize=20) - subset = dg.sweep_response[(dg.stim_table.orientation == dg.orivals[peakori]) & ( - dg.stim_table.temporal_frequency == dg.tfvals[peaktf])] - xtime = np.arange(-1 * dg.interlength / dg.acquisition_rate, (dg.sweeplength + - dg.interlength) / dg.acquisition_rate, 1 / dg.acquisition_rate) + subset = dg.sweep_response[ + (dg.stim_table.orientation == dg.orivals[peakori]) & (dg.stim_table.temporal_frequency == dg.tfvals[peaktf]) + ] + xtime = np.arange( + -1 * dg.interlength / dg.acquisition_rate, + (dg.sweeplength + dg.interlength) / dg.acquisition_rate, + 1 / dg.acquisition_rate, + ) while len(xtime) > len(subset[str(nc)].mean()): xtime = np.delete(xtime, -1) for index, row in subset.iterrows(): ax8.plot(xtime, subset[str(nc)][index], lw=2) ax8.set_xlim(-1, 3) - ax8.annotate(str(dg.orivals[peakori]) + " Deg / " + str( - dg.tfvals[peaktf]) + " Hz", xy=(0, 0.9), xycoords='axes fraction', fontsize=14) + ax8.annotate( + str(dg.orivals[peakori]) + " Deg / " + str(dg.tfvals[peaktf]) + " Hz", + xy=(0, 0.9), + xycoords="axes fraction", + fontsize=14, + ) ax8.set_xlabel("Time (s)", fontsize=20) ax8.set_ylabel("DF/F (%)", fontsize=20) - ax8.axvspan(0, dg.sweeplength / dg.acquisition_rate, - ymin=0, ymax=1, facecolor='gray', alpha=0.3) + ax8.axvspan(0, dg.sweeplength / dg.acquisition_rate, ymin=0, ymax=1, facecolor="gray", alpha=0.3) ax8.yaxis.set_major_locator(MaxNLocator(6)) ax8.set_title("Trial responses to prefered ori/tf", fontsize=20) - im = ax9.imshow(dg.response[:, 1:, nc, 0], - cmap='gray', interpolation='none') + im = ax9.imshow(dg.response[:, 1:, nc, 0], cmap="gray", interpolation="none") ax9.set_ylabel("Direction (d)", fontsize=20) ax9.set_yticks(range(8)) ax9.set_yticklabels(dg.orivals) @@ -451,53 +483,51 @@ def _plot_3sa(dg, nm1, nm3, save_dir): ax9.set_xticks(range(5)) ax9.set_xticklabels(dg.tfvals[1:]) cbar = plt.colorbar(im, ax=ax9) - cbar.ax.set_ylabel('DF/F (%)', fontsize=8) + cbar.ax.set_ylabel("DF/F (%)", fontsize=8) for t in cbar.ax.get_yticklabels(): t.set_fontsize(8) - xtime = np.arange(0, nm1.sweeplength / - nm1.acquisition_rate, 1 / nm1.acquisition_rate) + xtime = np.arange(0, nm1.sweeplength / nm1.acquisition_rate, 1 / nm1.acquisition_rate) while len(xtime) > len(nm1.sweep_response[str(nc)].mean()): xtime = np.delete(xtime, -1) for index, row in nm1.sweep_response.iterrows(): ax10.plot(xtime, nm1.sweep_response[str(nc)][index], lw=2) ax10.set_xlabel("Time (s)", fontsize=20) ax10.set_ylabel("DF/F", fontsize=20) - ax10.set_title("Natural Movie 1", fontsize=20, color='red') + ax10.set_title("Natural Movie 1", fontsize=20, color="red") temp = np.empty((len(nm1.stim_table), nm1.sweeplength)) for i in range(len(nm1.stim_table)): temp[i, :] = nm1.sweep_response[str(nc)].iloc[i] - ax11.imshow(temp, cmap='gray', interpolation='none', aspect=40) + ax11.imshow(temp, cmap="gray", interpolation="none", aspect=40) ax11.set_ylabel("Trials", fontsize=20) ax11.set_xticks([]) - xtime = np.arange(0, nm3.sweeplength / - nm3.acquisition_rate, 1 / nm3.acquisition_rate) + xtime = np.arange(0, nm3.sweeplength / nm3.acquisition_rate, 1 / nm3.acquisition_rate) while len(xtime) > len(nm3.sweep_response[str(nc)].mean()): xtime = np.delete(xtime, -1) for index, row in nm3.sweep_response.iterrows(): ax12.plot(xtime, nm3.sweep_response[str(nc)][index], lw=2) ax12.set_xlabel("Time (s)", fontsize=20) ax12.set_ylabel("DF/F", fontsize=20) - ax12.set_title("Natural Movie Long", fontsize=20, color='blue') + ax12.set_title("Natural Movie Long", fontsize=20, color="blue") temp = np.empty((len(nm3.stim_table), nm3.sweeplength)) for i in range(len(nm3.stim_table)): temp[i, :] = nm3.sweep_response[str(nc)].iloc[i] - ax13.imshow(temp, cmap='gray', interpolation='none', aspect=100) + ax13.imshow(temp, cmap="gray", interpolation="none", aspect=100) ax13.set_ylabel("Trials", fontsize=20) ax13.set_xticks([]) plt.tick_params(labelsize=16) plt.tight_layout() - filename = 'Cell_' + str(nc + 1) + '_3SA.png' + filename = "Cell_" + str(nc + 1) + "_3SA.png" fullfilename = os.path.join(save_dir, filename) plt.savefig(fullfilename) plt.close() -def _plot_3sc(lsn, nm1, nm2, save_dir, suffix=''): +def _plot_3sc(lsn, nm1, nm2, save_dir, suffix=""): logging.info("Plotting for all cells") for nc in range(lsn.numbercells): if np.mod(nc, 20) == 0: @@ -518,136 +548,166 @@ def _plot_3sc(lsn, nm1, nm2, save_dir, suffix=''): ax9 = plt.subplot2grid((6, 6), (5, 0), colspan=3) ax10 = plt.subplot2grid((6, 6), (5, 3), colspan=3) - xtime = np.arange(0, np.size(lsn.celltraces, 1), 1.) + xtime = np.arange(0, np.size(lsn.celltraces, 1), 1.0) xtime /= lsn.acquisition_rate - dif = np.ediff1d(lsn.stim_table.start.values, - to_begin=8000, to_end=8000) + dif = np.ediff1d(lsn.stim_table.start.values, to_begin=8000, to_end=8000) test = np.argwhere(dif > 5000) ax1.plot(xtime, lsn.celltraces[nc, :]) for i in range(len(test) - 1): - ax1.axvspan(xmin=(lsn.stim_table.start.iloc[test[i]].values / lsn.acquisition_rate), xmax=( - lsn.stim_table.end.iloc[test[i + 1] - 1].values / lsn.acquisition_rate), color='gray', alpha=0.3) - ax1.axvspan(nm1.stim_table.start.min() / nm1.acquisition_rate, nm1.stim_table.end.max() / - nm1.acquisition_rate, ymin=0, ymax=1, color='red', alpha=0.3) - ax1.axvspan(nm2.stim_table.start.min() / nm2.acquisition_rate, nm2.stim_table.end.max() / - nm2.acquisition_rate, ymin=0, ymax=1, color='green', alpha=0.3) + ax1.axvspan( + xmin=(lsn.stim_table.start.iloc[test[i]].values / lsn.acquisition_rate), + xmax=(lsn.stim_table.end.iloc[test[i + 1] - 1].values / lsn.acquisition_rate), + color="gray", + alpha=0.3, + ) + ax1.axvspan( + nm1.stim_table.start.min() / nm1.acquisition_rate, + nm1.stim_table.end.max() / nm1.acquisition_rate, + ymin=0, + ymax=1, + color="red", + alpha=0.3, + ) + ax1.axvspan( + nm2.stim_table.start.min() / nm2.acquisition_rate, + nm2.stim_table.end.max() / nm2.acquisition_rate, + ymin=0, + ymax=1, + color="green", + alpha=0.3, + ) ax1.set_xlabel("Time (s)", fontsize=20) ax1.set_ylabel("Fluorescence", fontsize=20) ax2.hist(lsn.celltraces[nc, :], bins=70) - ax2.set_yscale('log') + ax2.set_yscale("log") ax2.set_xlabel("Fluorescence", fontsize=20) ax2.set_ylabel("Count", fontsize=20) - xtime = np.arange(0, np.size(lsn.dxcm), 1.) + xtime = np.arange(0, np.size(lsn.dxcm), 1.0) xtime /= lsn.acquisition_rate - ax11.plot(xtime, lsn.dxcm, color='k') + ax11.plot(xtime, lsn.dxcm, color="k") ax11.set_xlabel("Time (s)", fontsize=20) ax11.set_xlabel("Speed (cm/s)", fontsize=20) - smax = nm1.binned_cells_sp[nc, np.argmax(nm1.binned_cells_sp[ - nc, :, 0]), 0] + nm1.binned_cells_sp[nc, np.argmax(nm1.binned_cells_sp[nc, :, 0]), 1] - vmax = nm1.binned_cells_vis[nc, np.argmax(nm1.binned_cells_vis[ - nc, :, 0]), 0] + nm1.binned_cells_vis[nc, np.argmax(nm1.binned_cells_vis[nc, :, 0]), 1] + smax = ( + nm1.binned_cells_sp[nc, np.argmax(nm1.binned_cells_sp[nc, :, 0]), 0] + + nm1.binned_cells_sp[nc, np.argmax(nm1.binned_cells_sp[nc, :, 0]), 1] + ) + vmax = ( + nm1.binned_cells_vis[nc, np.argmax(nm1.binned_cells_vis[nc, :, 0]), 0] + + nm1.binned_cells_vis[nc, np.argmax(nm1.binned_cells_vis[nc, :, 0]), 1] + ) rmax = np.where(smax > vmax, smax, vmax) - smin = nm1.binned_cells_sp[nc, np.argmin(nm1.binned_cells_sp[ - nc, :, 0]), 0] - nm1.binned_cells_sp[nc, np.argmin(nm1.binned_cells_sp[nc, :, 0]), 1] - vmin = nm1.binned_cells_vis[nc, np.argmin(nm1.binned_cells_vis[ - nc, :, 0]), 0] - nm1.binned_cells_vis[nc, np.argmin(nm1.binned_cells_vis[nc, :, 0]), 1] + smin = ( + nm1.binned_cells_sp[nc, np.argmin(nm1.binned_cells_sp[nc, :, 0]), 0] + - nm1.binned_cells_sp[nc, np.argmin(nm1.binned_cells_sp[nc, :, 0]), 1] + ) + vmin = ( + nm1.binned_cells_vis[nc, np.argmin(nm1.binned_cells_vis[nc, :, 0]), 0] + - nm1.binned_cells_vis[nc, np.argmin(nm1.binned_cells_vis[nc, :, 0]), 1] + ) rmin = np.where(smin < vmin, smin, vmin) - ax12.errorbar(nm1.binned_dx_sp[:, 0], nm1.binned_cells_sp[ - nc, :, 0], yerr=nm1.binned_cells_sp[nc, :, 1], fmt='.', color='k') + ax12.errorbar( + nm1.binned_dx_sp[:, 0], + nm1.binned_cells_sp[nc, :, 0], + yerr=nm1.binned_cells_sp[nc, :, 1], + fmt=".", + color="k", + ) ax12.set_ylim(rmin, rmax) ax12.set_xlabel("Speed (cm/s)", fontsize=20) ax12.set_ylabel("DF/F", fontsize=20) ax12.set_title("Spontaneous", fontsize=20) - ax13.errorbar(nm1.binned_dx_vis[:, 0], nm1.binned_cells_vis[ - nc, :, 0], yerr=nm1.binned_cells_vis[nc, :, 1], fmt='.') + ax13.errorbar( + nm1.binned_dx_vis[:, 0], nm1.binned_cells_vis[nc, :, 0], yerr=nm1.binned_cells_vis[nc, :, 1], fmt="." + ) ax13.set_ylim(rmin, rmax) ax13.set_xlabel("Speed (cm/s)", fontsize=20) ax13.set_ylabel("DF/F", fontsize=20) ax13.set_title("Visual Stimuli", fontsize=20) - xtime = np.arange(0, nm1.sweeplength / - nm1.acquisition_rate, 1 / nm1.acquisition_rate) + xtime = np.arange(0, nm1.sweeplength / nm1.acquisition_rate, 1 / nm1.acquisition_rate) for index, row in nm1.sweep_response.iterrows(): ax3.plot(xtime, nm1.sweep_response[str(nc)][index], lw=2) ax3.set_xlabel("Time (s)", fontsize=20) ax3.set_ylabel("DF/F", fontsize=20) - ax3.set_title("Natural Movie 1", fontsize=20, color='red') + ax3.set_title("Natural Movie 1", fontsize=20, color="red") temp = np.empty((len(nm1.stim_table), nm1.sweeplength)) for i in range(len(nm1.stim_table)): temp[i, :] = nm1.sweep_response[str(nc)].iloc[i] - ax4.imshow(temp, cmap='gray', interpolation='none', aspect=40) + ax4.imshow(temp, cmap="gray", interpolation="none", aspect=40) ax4.set_ylabel("Trials", fontsize=20) ax4.set_xticks([]) - xtime = np.arange(0, nm2.sweeplength / - nm2.acquisition_rate, 1 / nm2.acquisition_rate) + xtime = np.arange(0, nm2.sweeplength / nm2.acquisition_rate, 1 / nm2.acquisition_rate) for index, row in nm2.sweep_response.iterrows(): ax5.plot(xtime, nm2.sweep_response[str(nc)][index], lw=2) ax5.set_xlabel("Time (s)", fontsize=20) ax5.set_ylabel("DF/F", fontsize=20) - ax5.set_title("Natural Movie 2", fontsize=20, color='green') + ax5.set_title("Natural Movie 2", fontsize=20, color="green") temp = np.empty((len(nm2.stim_table), nm2.sweeplength)) for i in range(len(nm2.stim_table)): temp[i, :] = nm2.sweep_response[str(nc)].iloc[i] - ax6.imshow(temp, cmap='gray', interpolation='none', aspect=40) + ax6.imshow(temp, cmap="gray", interpolation="none", aspect=40) ax6.set_ylabel("Trials", fontsize=20) ax6.set_xticks([]) - vMax = np.where(np.amax(lsn.receptive_field[:, :, nc, 0]) > np.amax(lsn.receptive_field[ - :, :, nc, 1]), np.amax(lsn.receptive_field[:, :, nc, 0]), np.amax(lsn.receptive_field[:, :, nc, 1])) - vMin = np.where(np.amin(lsn.receptive_field[:, :, nc, 0]) < np.amin(lsn.receptive_field[ - :, :, nc, 1]), np.amin(lsn.receptive_field[:, :, nc, 0]), np.amin(lsn.receptive_field[:, :, nc, 1])) - - imon = ax7.imshow(lsn.receptive_field[ - :, :, nc, 0], cmap='gray', interpolation='None', vmin=vMin, vmax=vMax) + vMax = np.where( + np.amax(lsn.receptive_field[:, :, nc, 0]) > np.amax(lsn.receptive_field[:, :, nc, 1]), + np.amax(lsn.receptive_field[:, :, nc, 0]), + np.amax(lsn.receptive_field[:, :, nc, 1]), + ) + vMin = np.where( + np.amin(lsn.receptive_field[:, :, nc, 0]) < np.amin(lsn.receptive_field[:, :, nc, 1]), + np.amin(lsn.receptive_field[:, :, nc, 0]), + np.amin(lsn.receptive_field[:, :, nc, 1]), + ) + + imon = ax7.imshow(lsn.receptive_field[:, :, nc, 0], cmap="gray", interpolation="None", vmin=vMin, vmax=vMax) ax7.set_title("ON", fontsize=20) ax7.set_xticks([]) ax7.set_yticks([]) cbar = plt.colorbar(imon, ax=ax7, fraction=0.046, pad=0.04) - cbar.ax.set_ylabel('DF/F (%)', fontsize=10) + cbar.ax.set_ylabel("DF/F (%)", fontsize=10) for t in cbar.ax.get_yticklabels(): t.set_fontsize(8) - imoff = ax8.imshow(lsn.receptive_field[ - :, :, nc, 1], cmap='gray', interpolation='None', vmin=vMin, vmax=vMax) + imoff = ax8.imshow(lsn.receptive_field[:, :, nc, 1], cmap="gray", interpolation="None", vmin=vMin, vmax=vMax) ax8.set_title("OFF", fontsize=20) ax8.set_xticks([]) ax8.set_yticks([]) cbar = plt.colorbar(imoff, ax=ax8, fraction=0.046, pad=0.04) - cbar.ax.set_ylabel('DF/F (%)', fontsize=10) + cbar.ax.set_ylabel("DF/F (%)", fontsize=10) for t in cbar.ax.get_yticklabels(): t.set_fontsize(8) - zon = (lsn.receptive_field[:, :, nc, 0] - np.mean(lsn.receptive_field[ - :, :, nc, 0])) / np.std(lsn.receptive_field[:, :, nc, 0]) + zon = (lsn.receptive_field[:, :, nc, 0] - np.mean(lsn.receptive_field[:, :, nc, 0])) / np.std( + lsn.receptive_field[:, :, nc, 0] + ) zon = np.where(abs(zon) > 2, zon, 0) - Vmax_on = np.where(abs(np.amax(zon)) > abs( - np.amin(zon)), np.amax(zon), -1 * np.amin(zon)) - zoff = (lsn.receptive_field[:, :, nc, 1] - np.mean(lsn.receptive_field[ - :, :, nc, 1])) / np.std(lsn.receptive_field[:, :, nc, 1]) + Vmax_on = np.where(abs(np.amax(zon)) > abs(np.amin(zon)), np.amax(zon), -1 * np.amin(zon)) + zoff = (lsn.receptive_field[:, :, nc, 1] - np.mean(lsn.receptive_field[:, :, nc, 1])) / np.std( + lsn.receptive_field[:, :, nc, 1] + ) zoff = np.where(abs(zoff) > 2, zoff, 0) - Vmax_off = np.where(abs(np.amax(zoff)) > abs( - np.amin(zoff)), np.amax(zoff), -1 * np.amin(zoff)) + Vmax_off = np.where(abs(np.amax(zoff)) > abs(np.amin(zoff)), np.amax(zoff), -1 * np.amin(zoff)) Vmax = np.where(Vmax_on > Vmax_off, Vmax_on, Vmax_off) - imzon = ax9.imshow(zon, cmap='RdBu_r', - interpolation='none', vmin=-1 * Vmax, vmax=Vmax) + imzon = ax9.imshow(zon, cmap="RdBu_r", interpolation="none", vmin=-1 * Vmax, vmax=Vmax) ax9.set_title("On Z-score", fontsize=20) ax9.set_xticks([]) ax9.set_yticks([]) cbar = plt.colorbar(imzon, ax=ax9, fraction=0.046, pad=0.04) - zoff = (lsn.receptive_field[:, :, nc, 1] - np.mean(lsn.receptive_field[ - :, :, nc, 1])) / np.std(lsn.receptive_field[:, :, nc, 1]) + zoff = (lsn.receptive_field[:, :, nc, 1] - np.mean(lsn.receptive_field[:, :, nc, 1])) / np.std( + lsn.receptive_field[:, :, nc, 1] + ) zoff = np.where(abs(zoff) > 2, zoff, 0) - imzoff = ax10.imshow( - zoff, cmap='RdBu', interpolation='none', vmin=-1 * Vmax, vmax=Vmax) + imzoff = ax10.imshow(zoff, cmap="RdBu", interpolation="none", vmin=-1 * Vmax, vmax=Vmax) ax10.set_title("Off Z-score", fontsize=20) ax10.set_xticks([]) ax10.set_yticks([]) @@ -655,7 +715,7 @@ def _plot_3sc(lsn, nm1, nm2, save_dir, suffix=''): plt.tick_params(labelsize=16) plt.tight_layout() - filename = 'Cell_' + str(nc + 1) + '_3SC' + suffix + '.png' + filename = "Cell_" + str(nc + 1) + "_3SC" + suffix + ".png" fullfilename = os.path.join(save_dir, filename) plt.savefig(fullfilename) plt.close() @@ -684,203 +744,275 @@ def _plot_3sb(sg, nm1, ns, save_dir): ax11 = plt.subplot2grid((6, 6), (4, 3), colspan=2) # natural scenes ax12 = plt.subplot2grid((6, 6), (5, 3), colspan=2) - xtime = np.arange(0, np.size(sg.celltraces, 1), 1.) + xtime = np.arange(0, np.size(sg.celltraces, 1), 1.0) xtime /= sg.acquisition_rate - dif = np.ediff1d(sg.stim_table.start.values, - to_begin=8000, to_end=8000) + dif = np.ediff1d(sg.stim_table.start.values, to_begin=8000, to_end=8000) test = np.argwhere(dif > 5000) ax1.plot(xtime, sg.celltraces[nc, :]) for i in range(len(test) - 1): - ax1.axvspan(xmin=(sg.stim_table.start.iloc[test[i]].values / sg.acquisition_rate), xmax=( - sg.stim_table.end.iloc[test[i + 1] - 1].values / sg.acquisition_rate), color='gray', alpha=0.3) - ax1.axvspan(nm1.stim_table.start.min() / nm1.acquisition_rate, nm1.stim_table.end.max() / - nm1.acquisition_rate, ymin=0, ymax=1, color='red', alpha=0.3) - dif = np.ediff1d(ns.stim_table.start.values, - to_begin=8000, to_end=8000) + ax1.axvspan( + xmin=(sg.stim_table.start.iloc[test[i]].values / sg.acquisition_rate), + xmax=(sg.stim_table.end.iloc[test[i + 1] - 1].values / sg.acquisition_rate), + color="gray", + alpha=0.3, + ) + ax1.axvspan( + nm1.stim_table.start.min() / nm1.acquisition_rate, + nm1.stim_table.end.max() / nm1.acquisition_rate, + ymin=0, + ymax=1, + color="red", + alpha=0.3, + ) + dif = np.ediff1d(ns.stim_table.start.values, to_begin=8000, to_end=8000) test = np.argwhere(dif > 5000) for i in range(len(test) - 1): - ax1.axvspan(xmin=(ns.stim_table.start.iloc[test[i]].values / ns.acquisition_rate), xmax=( - ns.stim_table.end.iloc[test[i + 1] - 1].values / ns.acquisition_rate), color='blue', alpha=0.3) + ax1.axvspan( + xmin=(ns.stim_table.start.iloc[test[i]].values / ns.acquisition_rate), + xmax=(ns.stim_table.end.iloc[test[i + 1] - 1].values / ns.acquisition_rate), + color="blue", + alpha=0.3, + ) ax1.set_xlabel("Time (s)", fontsize=20) ax1.set_ylabel("Fluorescence", fontsize=20) ax2.hist(sg.celltraces[nc, :], bins=70) - ax2.set_yscale('log') + ax2.set_yscale("log") ax2.set_xlabel("Fluorescence", fontsize=20) ax2.set_ylabel("Count", fontsize=20) - xtime = np.arange(0, np.size(sg.dxcm), 1.) + xtime = np.arange(0, np.size(sg.dxcm), 1.0) xtime /= sg.acquisition_rate - ax15.plot(xtime, sg.dxcm, color='k') + ax15.plot(xtime, sg.dxcm, color="k") ax15.set_xlabel("Time (s)", fontsize=20) ax15.set_xlabel("Speed (cm/s)", fontsize=20) peakori = sg.peak.ori_sg[nc] peaksf = sg.peak.sf_sg[nc] - ax3.errorbar(sg.orivals, sg.response[:, peaksf, 0, nc, 0], yerr=sg.response[ - :, peaksf, 0, nc, 1], color='blue', fmt='o-', lw=2) - ax3.errorbar(sg.orivals, sg.response[:, peaksf, 1, nc, 0], yerr=sg.response[ - :, peaksf, 1, nc, 1], color='cornflowerblue', fmt='o-', lw=2) - ax3.errorbar(sg.orivals, sg.response[:, peaksf, 2, nc, 0], yerr=sg.response[ - :, peaksf, 2, nc, 1], color='steelblue', fmt='o-', lw=2) - ax3.errorbar(sg.orivals, sg.response[:, peaksf, 3, nc, 0], yerr=sg.response[ - :, peaksf, 3, nc, 1], color='lightskyblue', fmt='o-', lw=2) - ax3.fill_between(sg.orivals, np.repeat(sg.response[0, 0, 0, nc, 0] + sg.response[0, 0, 0, nc, 1], sg.number_ori), np.repeat( - sg.response[0, 0, 0, nc, 0] - sg.response[0, 0, 0, nc, 1], sg.number_ori), color='gray', alpha=0.5) - ax3.axhline(y=sg.response[0, 0, 0, nc, 0], ls='--', color='k', lw=2) + ax3.errorbar( + sg.orivals, + sg.response[:, peaksf, 0, nc, 0], + yerr=sg.response[:, peaksf, 0, nc, 1], + color="blue", + fmt="o-", + lw=2, + ) + ax3.errorbar( + sg.orivals, + sg.response[:, peaksf, 1, nc, 0], + yerr=sg.response[:, peaksf, 1, nc, 1], + color="cornflowerblue", + fmt="o-", + lw=2, + ) + ax3.errorbar( + sg.orivals, + sg.response[:, peaksf, 2, nc, 0], + yerr=sg.response[:, peaksf, 2, nc, 1], + color="steelblue", + fmt="o-", + lw=2, + ) + ax3.errorbar( + sg.orivals, + sg.response[:, peaksf, 3, nc, 0], + yerr=sg.response[:, peaksf, 3, nc, 1], + color="lightskyblue", + fmt="o-", + lw=2, + ) + ax3.fill_between( + sg.orivals, + np.repeat(sg.response[0, 0, 0, nc, 0] + sg.response[0, 0, 0, nc, 1], sg.number_ori), + np.repeat(sg.response[0, 0, 0, nc, 0] - sg.response[0, 0, 0, nc, 1], sg.number_ori), + color="gray", + alpha=0.5, + ) + ax3.axhline(y=sg.response[0, 0, 0, nc, 0], ls="--", color="k", lw=2) ax3.set_xlim(-10, 160) ax3.set_xticks(sg.orivals) ax3.set_xlabel("Orientation (d)", fontsize=20) ax3.set_ylabel("DF/F (%)", fontsize=20) - ax4.errorbar(range(5), sg.response[peakori, 1:, 0, nc, 0], yerr=sg.response[ - peakori, 1:, 0, nc, 1], color='blue', fmt='o-', lw=2) - ax4.errorbar(range(5), sg.response[peakori, 1:, 1, nc, 0], yerr=sg.response[ - peakori, 1:, 1, nc, 1], color='cornflowerblue', fmt='o-', lw=2) - ax4.errorbar(range(5), sg.response[peakori, 1:, 2, nc, 0], yerr=sg.response[ - peakori, 1:, 2, nc, 1], color='steelblue', fmt='o-', lw=2) - ax4.errorbar(range(5), sg.response[peakori, 1:, 3, nc, 0], yerr=sg.response[ - peakori, 1:, 3, nc, 1], color='lightskyblue', fmt='o-', lw=2) - ax4.fill_between(range(5), np.repeat(sg.response[0, 0, 0, nc, 0] + sg.response[0, 0, 0, nc, 1], 5), np.repeat( - sg.response[0, 0, 0, nc, 0] - sg.response[0, 0, 0, nc, 1], 5), color='gray', alpha=0.5) - ax4.axhline(y=sg.response[0, 0, 0, nc, 0], ls='--', color='k', lw=2) + ax4.errorbar( + range(5), + sg.response[peakori, 1:, 0, nc, 0], + yerr=sg.response[peakori, 1:, 0, nc, 1], + color="blue", + fmt="o-", + lw=2, + ) + ax4.errorbar( + range(5), + sg.response[peakori, 1:, 1, nc, 0], + yerr=sg.response[peakori, 1:, 1, nc, 1], + color="cornflowerblue", + fmt="o-", + lw=2, + ) + ax4.errorbar( + range(5), + sg.response[peakori, 1:, 2, nc, 0], + yerr=sg.response[peakori, 1:, 2, nc, 1], + color="steelblue", + fmt="o-", + lw=2, + ) + ax4.errorbar( + range(5), + sg.response[peakori, 1:, 3, nc, 0], + yerr=sg.response[peakori, 1:, 3, nc, 1], + color="lightskyblue", + fmt="o-", + lw=2, + ) + ax4.fill_between( + range(5), + np.repeat(sg.response[0, 0, 0, nc, 0] + sg.response[0, 0, 0, nc, 1], 5), + np.repeat(sg.response[0, 0, 0, nc, 0] - sg.response[0, 0, 0, nc, 1], 5), + color="gray", + alpha=0.5, + ) + ax4.axhline(y=sg.response[0, 0, 0, nc, 0], ls="--", color="k", lw=2) ax4.set_xlim(-0.2, 4.2) ax4.set_xticks(range(5)) ax4.set_xticklabels(sg.sfvals[1:]) ax4.set_xlabel("Spatial frequency (cpd)", fontsize=20) - xtime = np.arange(-1 * sg.interlength / sg.acquisition_rate, (sg.sweeplength + - sg.interlength) / sg.acquisition_rate, 1 / sg.acquisition_rate) + xtime = np.arange( + -1 * sg.interlength / sg.acquisition_rate, + (sg.sweeplength + sg.interlength) / sg.acquisition_rate, + 1 / sg.acquisition_rate, + ) peakori = sg.peak.ori_sg[nc] peaksf = sg.peak.sf_sg[nc] peakphase = sg.peak.phase_sg[nc] - subset = sg.sweep_response[(sg.stim_table.orientation == sg.orivals[peakori]) & ( - sg.stim_table.spatial_frequency == sg.sfvals[peaksf]) & (sg.stim_table.phase == sg.phasevals[peakphase])] - subset_p = subset[str(nc)].mean( - ) + (subset[str(nc)].std() / np.sqrt(len(subset[str(nc)]))) - subset_n = subset[str(nc)].mean( - ) - (subset[str(nc)].std() / np.sqrt(len(subset[str(nc)]))) + subset = sg.sweep_response[ + (sg.stim_table.orientation == sg.orivals[peakori]) + & (sg.stim_table.spatial_frequency == sg.sfvals[peaksf]) + & (sg.stim_table.phase == sg.phasevals[peakphase]) + ] + subset_p = subset[str(nc)].mean() + (subset[str(nc)].std() / np.sqrt(len(subset[str(nc)]))) + subset_n = subset[str(nc)].mean() - (subset[str(nc)].std() / np.sqrt(len(subset[str(nc)]))) try: - ax13.fill_between(xtime, subset_p, subset_n, color='b', alpha=0.5) + ax13.fill_between(xtime, subset_p, subset_n, color="b", alpha=0.5) except Exception: xtime = xtime[:-1] - ax13.fill_between(xtime, subset_p, subset_n, color='b', alpha=0.5) - blank = sg.sweep_response[(sg.stim_table.orientation == 0) & ( - sg.stim_table.spatial_frequency == 0) & (sg.stim_table.phase == 0)] - blank_p = blank[str(nc)].mean() + \ - (blank[str(nc)].std() / np.sqrt(len(blank[str(nc)]))) - blank_n = blank[str(nc)].mean() - \ - (blank[str(nc)].std() / np.sqrt(len(blank[str(nc)]))) - ax13.fill_between(xtime, blank_p, blank_n, color='gray', alpha=0.5) - ax13.plot(xtime, subset[str(nc)].mean(), color='b', lw=2) - ax13.plot(xtime, blank[str(nc)].mean(), color='k', lw=2) - ax13.axvspan(0, sg.sweeplength / sg.acquisition_rate, - ymin=0, ymax=1, facecolor='gray', alpha=0.3) + ax13.fill_between(xtime, subset_p, subset_n, color="b", alpha=0.5) + blank = sg.sweep_response[ + (sg.stim_table.orientation == 0) & (sg.stim_table.spatial_frequency == 0) & (sg.stim_table.phase == 0) + ] + blank_p = blank[str(nc)].mean() + (blank[str(nc)].std() / np.sqrt(len(blank[str(nc)]))) + blank_n = blank[str(nc)].mean() - (blank[str(nc)].std() / np.sqrt(len(blank[str(nc)]))) + ax13.fill_between(xtime, blank_p, blank_n, color="gray", alpha=0.5) + ax13.plot(xtime, subset[str(nc)].mean(), color="b", lw=2) + ax13.plot(xtime, blank[str(nc)].mean(), color="k", lw=2) + ax13.axvspan(0, sg.sweeplength / sg.acquisition_rate, ymin=0, ymax=1, facecolor="gray", alpha=0.3) ax13.yaxis.set_major_locator(MaxNLocator(4)) ax13.set_xlabel("Time (s)", fontsize=20) ax13.set_ylabel("DF/F (%)", fontsize=20) Vmax = sg.response[:, 1:, :, nc, 0].max() - ax5.imshow(sg.response[:, 1:, 0, nc, 0], cmap='gray', - interpolation='none', vmin=0, vmax=Vmax) + ax5.imshow(sg.response[:, 1:, 0, nc, 0], cmap="gray", interpolation="none", vmin=0, vmax=Vmax) ax5.set_ylabel("Orientation (d)", fontsize=20) ax5.set_yticks(range(6)) ax5.set_yticklabels(sg.orivals) ax5.set_xlabel("Spatial frequency (cpd)", fontsize=20) ax5.set_xticks(range(5)) ax5.set_xticklabels(sg.sfvals[1:]) - ax5.set_title("Phase 0.0", color='blue', fontsize=20) + ax5.set_title("Phase 0.0", color="blue", fontsize=20) - ax6.imshow(sg.response[:, 1:, 1, nc, 0], cmap='gray', - interpolation='none', vmin=0, vmax=Vmax) + ax6.imshow(sg.response[:, 1:, 1, nc, 0], cmap="gray", interpolation="none", vmin=0, vmax=Vmax) ax6.set_xlabel("Spatial frequency (cpd)", fontsize=20) ax6.set_xticks(range(5)) ax6.set_xticklabels(sg.sfvals[1:]) ax6.set_yticks(range(6)) ax6.set_yticklabels(sg.orivals) - ax6.set_title("Phase 0.25", color='cornflowerblue', fontsize=20) + ax6.set_title("Phase 0.25", color="cornflowerblue", fontsize=20) - ax7.imshow(sg.response[:, 1:, 2, nc, 0], cmap='gray', - interpolation='none', vmin=0, vmax=Vmax) + ax7.imshow(sg.response[:, 1:, 2, nc, 0], cmap="gray", interpolation="none", vmin=0, vmax=Vmax) ax7.set_xlabel("Spatial frequency (cpd)", fontsize=20) ax7.set_xticks(range(5)) ax7.set_xticklabels(sg.sfvals[1:]) ax7.set_yticks(range(6)) ax7.set_yticklabels(sg.orivals) - ax7.set_title("Phase 0.5", color='steelblue', fontsize=20) + ax7.set_title("Phase 0.5", color="steelblue", fontsize=20) - ax8.imshow(sg.response[:, 1:, 3, nc, 0], cmap='gray', - interpolation='none', vmin=0, vmax=Vmax) + ax8.imshow(sg.response[:, 1:, 3, nc, 0], cmap="gray", interpolation="none", vmin=0, vmax=Vmax) ax8.set_xlabel("Spatial frequency (cpd)", fontsize=20) ax8.set_xticks(range(5)) ax8.set_xticklabels(sg.sfvals[1:]) ax8.set_yticks(range(6)) ax8.set_yticklabels(sg.orivals) - ax8.set_title("Phase 0.75", color='lightskyblue', fontsize=20) + ax8.set_title("Phase 0.75", color="lightskyblue", fontsize=20) - xtime = np.arange(0, nm1.sweeplength / - nm1.acquisition_rate, 1 / nm1.acquisition_rate) + xtime = np.arange(0, nm1.sweeplength / nm1.acquisition_rate, 1 / nm1.acquisition_rate) while len(xtime) > nm1.sweeplength: xtime = np.delete(xtime, -1) for index, row in nm1.sweep_response.iterrows(): ax9.plot(xtime, nm1.sweep_response[str(nc)][index], lw=2) ax9.set_xlabel("Time (s)", fontsize=20) ax9.set_ylabel("DF/F", fontsize=20) - ax9.set_title("Natural Movie 1", fontsize=20, color='red') + ax9.set_title("Natural Movie 1", fontsize=20, color="red") temp = np.empty((len(nm1.stim_table), nm1.sweeplength)) for i in range(len(nm1.stim_table)): temp[i, :] = nm1.sweep_response[str(nc)].iloc[i] - ax10.imshow(temp, cmap='gray', interpolation='none', aspect=40) + ax10.imshow(temp, cmap="gray", interpolation="none", aspect=40) ax10.set_ylabel("Trials", fontsize=20) ax10.set_xticks([]) temp = np.copy(ns.response[1:, nc, :2]) - scene_response = pd.DataFrame(temp, columns=('response', 'error')) - scene_response = scene_response.sort( - columns='response', ascending=False) - ax11.errorbar(range(ns.number_scenes - 1), scene_response.response, - yerr=scene_response.error, fmt='o', color='k') - ax11.fill_between(range(ns.number_scenes - 1), np.repeat(ns.response[0, nc, 0] + ns.response[0, nc, 1], ns.number_scenes - 1), np.repeat( - ns.response[0, nc, 0] - ns.response[0, nc, 1], ns.number_scenes - 1), color='gray', alpha=0.3) - ax11.axhline(y=ns.response[0, nc, 0], ls='--', lw=2, color='k') + scene_response = pd.DataFrame(temp, columns=("response", "error")) + scene_response = scene_response.sort(columns="response", ascending=False) + ax11.errorbar( + range(ns.number_scenes - 1), scene_response.response, yerr=scene_response.error, fmt="o", color="k" + ) + ax11.fill_between( + range(ns.number_scenes - 1), + np.repeat(ns.response[0, nc, 0] + ns.response[0, nc, 1], ns.number_scenes - 1), + np.repeat(ns.response[0, nc, 0] - ns.response[0, nc, 1], ns.number_scenes - 1), + color="gray", + alpha=0.3, + ) + ax11.axhline(y=ns.response[0, nc, 0], ls="--", lw=2, color="k") ax11.set_xlim(-2, 120) - ax11.set_title("Natural Scenes", fontsize=20, color='blue') + ax11.set_title("Natural Scenes", fontsize=20, color="blue") ax11.set_xlabel("Scene", fontsize=20) ax11.set_ylabel("DF/F (%)", fontsize=20) - xtime = np.arange(-1 * ns.interlength / ns.acquisition_rate, (ns.sweeplength + - ns.interlength) / ns.acquisition_rate, 1 / ns.acquisition_rate) + xtime = np.arange( + -1 * ns.interlength / ns.acquisition_rate, + (ns.sweeplength + ns.interlength) / ns.acquisition_rate, + 1 / ns.acquisition_rate, + ) nsp = np.argmax(ns.response[1:, nc, 0]) subset_response = ns.sweep_response[ns.stim_table.frame == nsp] - subset_response_p = subset_response[str(nc)].mean( - ) + (subset_response[str(nc)][:].std() / np.sqrt(len(subset_response[str(nc)]))) - subset_response_n = subset_response[str(nc)].mean( - ) - (subset_response[str(nc)][:].std() / np.sqrt(len(subset_response[str(nc)]))) + subset_response_p = subset_response[str(nc)].mean() + ( + subset_response[str(nc)][:].std() / np.sqrt(len(subset_response[str(nc)])) + ) + subset_response_n = subset_response[str(nc)].mean() - ( + subset_response[str(nc)][:].std() / np.sqrt(len(subset_response[str(nc)])) + ) try: - ax12.fill_between(xtime, subset_response_p, - subset_response_n, color='b', alpha=0.5) + ax12.fill_between(xtime, subset_response_p, subset_response_n, color="b", alpha=0.5) except Exception: xtime = xtime[:-1] - ax12.fill_between(xtime, subset_response_p, - subset_response_n, color='b', alpha=0.5) + ax12.fill_between(xtime, subset_response_p, subset_response_n, color="b", alpha=0.5) blank = ns.sweep_response[ns.stim_table.frame == -1] - blank_p = blank[str(nc)].mean() + \ - (blank[str(nc)].std() / np.sqrt(len(blank[str(nc)]))) - blank_n = blank[str(nc)].mean() - \ - (blank[str(nc)].std() / np.sqrt(len(blank[str(nc)]))) - ax12.fill_between(xtime, blank_p, blank_n, color='gray', alpha=0.5) - ax12.plot(xtime, subset_response[str(nc)].mean(), color='b', lw=2) - ax12.plot(xtime, blank[str(nc)].mean(), color='k', lw=2) - ax12.axvspan(0, ns.sweeplength / ns.acquisition_rate, - ymin=0, ymax=1, facecolor='gray', alpha=0.3) + blank_p = blank[str(nc)].mean() + (blank[str(nc)].std() / np.sqrt(len(blank[str(nc)]))) + blank_n = blank[str(nc)].mean() - (blank[str(nc)].std() / np.sqrt(len(blank[str(nc)]))) + ax12.fill_between(xtime, blank_p, blank_n, color="gray", alpha=0.5) + ax12.plot(xtime, subset_response[str(nc)].mean(), color="b", lw=2) + ax12.plot(xtime, blank[str(nc)].mean(), color="k", lw=2) + ax12.axvspan(0, ns.sweeplength / ns.acquisition_rate, ymin=0, ymax=1, facecolor="gray", alpha=0.3) ax12.yaxis.set_major_locator(MaxNLocator(4)) ax12.set_xlabel("Time (s)", fontsize=20) ax12.set_ylabel("DF/F (%)", fontsize=20) plt.tick_params(labelsize=16) plt.tight_layout() - filename = 'Cell_' + str(nc + 1) + '_3SB.png' + filename = "Cell_" + str(nc + 1) + "_3SB.png" fullfilename = os.path.join(save_dir, filename) plt.savefig(fullfilename) plt.close() @@ -900,58 +1032,74 @@ def plot_running_a(dg, nm1, nm3, save_dir): ax8 = plt.subplot2grid((4, 4), (2, 2), colspan=2) ax9 = plt.subplot2grid((4, 4), (3, 2), colspan=2) - xtime = np.arange(0, np.size(dg.dxcm), 1.) + xtime = np.arange(0, np.size(dg.dxcm), 1.0) xtime /= dg.acquisition_rate dif = np.ediff1d(dg.stim_table.start.values, to_begin=8000, to_end=8000) test = np.argwhere(dif > 5000) - ax1.plot(xtime, dg.dxcm, color='k') + ax1.plot(xtime, dg.dxcm, color="k") for i in range(len(test) - 1): - ax1.axvspan(xmin=(dg.stim_table.start.iloc[test[i]].values / dg.acquisition_rate), xmax=( - dg.stim_table.end.iloc[test[i + 1] - 1].values / dg.acquisition_rate), color='gray', alpha=0.3) - ax1.axvspan(xmin=nm1.stim_table.start.min() / nm1.acquisition_rate, xmax=( - (nm1.stim_table.start.max() + nm1.sweeplength) / nm1.acquisition_rate), color='red', alpha=0.3) + ax1.axvspan( + xmin=(dg.stim_table.start.iloc[test[i]].values / dg.acquisition_rate), + xmax=(dg.stim_table.end.iloc[test[i + 1] - 1].values / dg.acquisition_rate), + color="gray", + alpha=0.3, + ) + ax1.axvspan( + xmin=nm1.stim_table.start.min() / nm1.acquisition_rate, + xmax=((nm1.stim_table.start.max() + nm1.sweeplength) / nm1.acquisition_rate), + color="red", + alpha=0.3, + ) dif = np.ediff1d(nm3.stim_table.start.values, to_begin=8000, to_end=8000) test = np.argwhere(dif > 5000) for i in range(len(test) - 1): - ax1.axvspan(xmin=(nm3.stim_table.start.iloc[test[i]].values / nm3.acquisition_rate), xmax=( - (nm3.stim_table.end.iloc[test[i + 1] - 1].values + nm3.sweeplength) / nm3.acquisition_rate), color='blue', alpha=0.3) + ax1.axvspan( + xmin=(nm3.stim_table.start.iloc[test[i]].values / nm3.acquisition_rate), + xmax=((nm3.stim_table.end.iloc[test[i + 1] - 1].values + nm3.sweeplength) / nm3.acquisition_rate), + color="blue", + alpha=0.3, + ) ax1.set_xlabel("Time (s)", fontsize=20) ax1.set_ylabel("Speed (cm/s)", fontsize=20) dx = dg.dxcm[np.logical_not(np.isnan(dg.dxcm))] - ax2.hist(dx, bins=80, range=(-20, 100), color='gray') + ax2.hist(dx, bins=80, range=(-20, 100), color="gray") ax2.set_xlabel("Speed (cm/s)", fontsize=20) - run_peak = np.where(dg.response[:, 1:, nc, 0] - == np.nanmax(dg.response[:, 1:, nc, 0])) + run_peak = np.where(dg.response[:, 1:, nc, 0] == np.nanmax(dg.response[:, 1:, nc, 0])) peakori = run_peak[0][0] peaktf = run_peak[1][0] + 1 - ax3.errorbar(dg.orivals, dg.response[:, peaktf, nc, 0], yerr=dg.response[ - :, peaktf, nc, 1], fmt='b.-', lw=2) - ax3.fill_between(dg.orivals, np.repeat(dg.response[0, 0, nc, 0] + dg.response[0, 0, nc, 1], dg.number_ori), np.repeat( - dg.response[0, 0, nc, 0] - dg.response[0, 0, nc, 1], dg.number_ori), color='gray', alpha=0.5) - ax3.axhline(y=dg.response[0, 0, nc, 0], ls='--', color='k') - ax3.annotate(str(dg.tfvals[peaktf]) + " Hz", - xy=(0, 0.9), xycoords='axes fraction', fontsize=14) - ax3.set_xtick = (dg.orivals) + ax3.errorbar(dg.orivals, dg.response[:, peaktf, nc, 0], yerr=dg.response[:, peaktf, nc, 1], fmt="b.-", lw=2) + ax3.fill_between( + dg.orivals, + np.repeat(dg.response[0, 0, nc, 0] + dg.response[0, 0, nc, 1], dg.number_ori), + np.repeat(dg.response[0, 0, nc, 0] - dg.response[0, 0, nc, 1], dg.number_ori), + color="gray", + alpha=0.5, + ) + ax3.axhline(y=dg.response[0, 0, nc, 0], ls="--", color="k") + ax3.annotate(str(dg.tfvals[peaktf]) + " Hz", xy=(0, 0.9), xycoords="axes fraction", fontsize=14) + ax3.set_xtick = dg.orivals ax3.set_xlabel("Direction (deg)", fontsize=20) ax3.set_ylabel("Speed (cm/s)", fontsize=20) ax3.yaxis.set_major_locator(MaxNLocator(6)) - ax4.errorbar(dg.tfvals[1:], dg.response[peakori, 1:, nc, 0], yerr=dg.response[ - peakori, 1:, nc, 1], fmt='b.-', lw=2) - ax4.fill_between(dg.tfvals[1:], np.repeat(dg.response[0, 0, nc, 0] + dg.response[0, 0, nc, 1], dg.number_tf - 1), - np.repeat(dg.response[0, 0, nc, 0] - dg.response[0, 0, nc, 1], dg.number_tf - 1), color='gray', alpha=0.5) - ax4.axhline(y=dg.response[0, 0, nc, 0], ls='--', color='k') - ax4.annotate(str(dg.orivals[peakori]) + " Deg", - xy=(0, 0.9), xycoords='axes fraction', fontsize=14) - ax4.set_xticks = (dg.tfvals[1:]) + ax4.errorbar(dg.tfvals[1:], dg.response[peakori, 1:, nc, 0], yerr=dg.response[peakori, 1:, nc, 1], fmt="b.-", lw=2) + ax4.fill_between( + dg.tfvals[1:], + np.repeat(dg.response[0, 0, nc, 0] + dg.response[0, 0, nc, 1], dg.number_tf - 1), + np.repeat(dg.response[0, 0, nc, 0] - dg.response[0, 0, nc, 1], dg.number_tf - 1), + color="gray", + alpha=0.5, + ) + ax4.axhline(y=dg.response[0, 0, nc, 0], ls="--", color="k") + ax4.annotate(str(dg.orivals[peakori]) + " Deg", xy=(0, 0.9), xycoords="axes fraction", fontsize=14) + ax4.set_xticks = dg.tfvals[1:] ax4.set_xlabel("Temporal frequency (Hz)", fontsize=20) ax4.yaxis.set_major_locator(MaxNLocator(6)) - im = ax5.imshow(dg.response[:, 1:, nc, 0], - cmap='gray', interpolation='none') + im = ax5.imshow(dg.response[:, 1:, nc, 0], cmap="gray", interpolation="none") ax5.set_ylabel("Direction", fontsize=16) ax5.set_xlabel("TF", fontsize=16) ax5.set_yticks(range(dg.number_ori)) @@ -959,41 +1107,39 @@ def plot_running_a(dg, nm1, nm3, save_dir): ax5.set_xticks(range(dg.number_tf - 1)) ax5.set_xticklabels(list(dg.tfvals[1:].astype(int).astype(str))) cbar = plt.colorbar(im, ax=ax5) - cbar.ax.set_ylabel('Speed (cm/s)', fontsize=8) + cbar.ax.set_ylabel("Speed (cm/s)", fontsize=8) for t in cbar.ax.get_yticklabels(): t.set_fontsize(8) - xtime = np.arange(0, nm1.sweeplength / - nm1.acquisition_rate, 1 / nm1.acquisition_rate) - while len(xtime) > len(nm1.sweep_response['dx'].mean()): + xtime = np.arange(0, nm1.sweeplength / nm1.acquisition_rate, 1 / nm1.acquisition_rate) + while len(xtime) > len(nm1.sweep_response["dx"].mean()): xtime = np.delete(xtime, -1) for index, row in nm1.sweep_response.iterrows(): - ax6.plot(xtime, nm1.sweep_response['dx'][index], lw=2) + ax6.plot(xtime, nm1.sweep_response["dx"][index], lw=2) ax6.set_xlabel("Time (s)", fontsize=20) ax6.set_ylabel("DF/F", fontsize=20) - ax6.set_title("Natural Movie 1", fontsize=20, color='red') + ax6.set_title("Natural Movie 1", fontsize=20, color="red") temp = np.empty((len(nm1.stim_table), nm1.sweeplength)) for i in range(len(nm1.stim_table)): - temp[i, :] = nm1.sweep_response['dx'].iloc[i][:, 0] - ax7.imshow(temp, cmap='gray', interpolation='none', aspect=40) + temp[i, :] = nm1.sweep_response["dx"].iloc[i][:, 0] + ax7.imshow(temp, cmap="gray", interpolation="none", aspect=40) ax7.set_ylabel("Trials", fontsize=20) ax7.set_xticks([]) - xtime = np.arange(0, nm3.sweeplength / - nm3.acquisition_rate, 1 / nm3.acquisition_rate) - while len(xtime) > len(nm3.sweep_response['dx'].mean()): + xtime = np.arange(0, nm3.sweeplength / nm3.acquisition_rate, 1 / nm3.acquisition_rate) + while len(xtime) > len(nm3.sweep_response["dx"].mean()): xtime = np.delete(xtime, -1) for index, row in nm3.sweep_response.iterrows(): - ax8.plot(xtime, nm3.sweep_response['dx'][index], lw=2) + ax8.plot(xtime, nm3.sweep_response["dx"][index], lw=2) ax8.set_xlabel("Time (s)", fontsize=20) ax8.set_ylabel("DF/F", fontsize=20) - ax8.set_title("Natural Movie Long", fontsize=20, color='blue') + ax8.set_title("Natural Movie Long", fontsize=20, color="blue") temp = np.empty((len(nm3.stim_table), nm3.sweeplength)) for i in range(len(nm3.stim_table)): - temp[i, :] = nm3.sweep_response['dx'].iloc[i][:, 0] - ax9.imshow(temp, cmap='gray', interpolation='none', aspect=100) + temp[i, :] = nm3.sweep_response["dx"].iloc[i][:, 0] + ax9.imshow(temp, cmap="gray", interpolation="none", aspect=100) ax9.set_ylabel("Trials", fontsize=20) ax9.set_xticks([]) @@ -1001,7 +1147,7 @@ def plot_running_a(dg, nm1, nm3, save_dir): plt.tight_layout() plt.suptitle("Running Summary", fontsize=20) plt.subplots_adjust(top=0.9) - filename = 'Running Summary.png' + filename = "Running Summary.png" fullfilename = os.path.join(save_dir, filename) plt.savefig(fullfilename) plt.close() diff --git a/allensdk/brain_observatory/chisquare_categorical.py b/allensdk/brain_observatory/chisquare_categorical.py index debb41f971..6eb1cf1de1 100644 --- a/allensdk/brain_observatory/chisquare_categorical.py +++ b/allensdk/brain_observatory/chisquare_categorical.py @@ -12,27 +12,19 @@ import numpy as np -def chisq_from_stim_table( - stim_table, columns, mean_sweep_events, num_shuffles=1000, verbose=False -): +def chisq_from_stim_table(stim_table, columns, mean_sweep_events, num_shuffles=1000, verbose=False): # stim_table is a pandas DataFrame with len = num_sweeps # columns is a list of column names that define the categories (e.g. # ['Ori','Contrast']) mean_sweep_events is a numpy array with shape # (num_sweeps,num_cells) - sweep_categories = stim_table_to_categories( - stim_table, columns, verbose=verbose - ) - p_vals = compute_chi_shuffle( - mean_sweep_events, sweep_categories, num_shuffles=num_shuffles - ) + sweep_categories = stim_table_to_categories(stim_table, columns, verbose=verbose) + p_vals = compute_chi_shuffle(mean_sweep_events, sweep_categories, num_shuffles=num_shuffles) return p_vals -def compute_chi_shuffle( - mean_sweep_events, sweep_categories, num_shuffles=1000 -): +def compute_chi_shuffle(mean_sweep_events, sweep_categories, num_shuffles=1000): # mean_sweep_events is a numpy array with shape (num_sweeps,num_cells) # sweep_conditions is a numpy array with shape (num_sweeps) # sweep_conditions gives the category label for each sweep @@ -54,12 +46,8 @@ def compute_chi_shuffle( shuffle_sweeps = np.random.choice(num_sweeps, size=(num_sweeps,)) shuffle_sweep_events = mean_sweep_events[shuffle_sweeps] - shuffle_expected = compute_expected( - shuffle_sweep_events, sweep_categories_dummy - ) - shuffle_observed = compute_observed( - shuffle_sweep_events, sweep_categories_dummy - ) + shuffle_expected = compute_expected(shuffle_sweep_events, sweep_categories_dummy) + shuffle_observed = compute_observed(shuffle_sweep_events, sweep_categories_dummy) chi_shuffle[:, ns] = compute_chi(shuffle_observed, shuffle_expected) @@ -111,9 +99,7 @@ def stim_table_to_categories(stim_table, columns, verbose=False): category += 1 # advance the combination - curr_combination = advance_combination( - curr_combination, options_per_column - ) + curr_combination = advance_combination(curr_combination, options_per_column) all_tried = curr_combination[0] == options_per_column[0] if verbose: @@ -158,9 +144,9 @@ def compute_observed(mean_sweep_events, sweep_conditions): (num_sweeps, num_conditions) = np.shape(sweep_conditions) num_cells = np.shape(mean_sweep_events)[1] - observed_mat = (mean_sweep_events.T).reshape( - num_cells, num_sweeps, 1 - ) * sweep_conditions.reshape(1, num_sweeps, num_conditions) + observed_mat = (mean_sweep_events.T).reshape(num_cells, num_sweeps, 1) * sweep_conditions.reshape( + 1, num_sweeps, num_conditions + ) observed = np.sum(observed_mat, axis=1) return observed @@ -173,9 +159,7 @@ def compute_expected(mean_sweep_events, sweep_conditions): sweeps_per_condition = np.sum(sweep_conditions, axis=0) events_per_sweep = np.mean(mean_sweep_events, axis=0) - expected = sweeps_per_condition.reshape( - 1, num_conditions - ) * events_per_sweep.reshape(num_cells, 1) + expected = sweeps_per_condition.reshape(1, num_conditions) * events_per_sweep.reshape(num_cells, 1) return expected diff --git a/allensdk/brain_observatory/circle_plots.py b/allensdk/brain_observatory/circle_plots.py index 5c36e54d47..68374f9feb 100644 --- a/allensdk/brain_observatory/circle_plots.py +++ b/allensdk/brain_observatory/circle_plots.py @@ -48,28 +48,28 @@ import skimage.transform -DEFAULT_COLOR_MAP = LinearSegmentedColormap.from_list('default', [[.7,0,.7,0.0],[.7,0,0,1]]) -DEFAULT_MEAN_RESP_COLOR_MAP = LinearSegmentedColormap.from_list('default', [[0.0,0.0,0.5,0.0],[0.0,0.0,0.5,1]]) +DEFAULT_COLOR_MAP = LinearSegmentedColormap.from_list("default", [[0.7, 0, 0.7, 0.0], [0.7, 0, 0, 1]]) +DEFAULT_MEAN_RESP_COLOR_MAP = LinearSegmentedColormap.from_list("default", [[0.0, 0.0, 0.5, 0.0], [0.0, 0.0, 0.5, 1]]) DEFAULT_AXIS_COLOR = (0.8, 0.8, 0.8) DEFAULT_LABEL_COLOR = (0.8, 0.8, 0.8) -LSN_ON_COLOR_MAP = LinearSegmentedColormap.from_list('default', [[.7,0,.7,0.0],[.7,0,0,1]]) -LSN_OFF_COLOR_MAP = LinearSegmentedColormap.from_list('default', [[0.0,0.7,.7,0.0],[0,0,0.7,1]]) +LSN_ON_COLOR_MAP = LinearSegmentedColormap.from_list("default", [[0.7, 0, 0.7, 0.0], [0.7, 0, 0, 1]]) +LSN_OFF_COLOR_MAP = LinearSegmentedColormap.from_list("default", [[0.0, 0.7, 0.7, 0.0], [0, 0, 0.7, 1]]) HEX_POSITIONS = [] def polar_to_xy(angles, radius): - """ Convert an array of angles (in radians) and a radius in polar coordinates - to an array of x,y coordinates. + """Convert an array of angles (in radians) and a radius in polar coordinates + to an array of x,y coordinates. """ - x = radius*np.cos(angles) - y = radius*np.sin(angles) - return np.array([x,y]).T + x = radius * np.cos(angles) + y = radius * np.sin(angles) + return np.array([x, y]).T def polar_linspace(radius, start_angle, stop_angle, num, endpoint=False, degrees=True): - """ Evenly distributed list of x,y coordinates from an input range of angles - and a radius in polar coordinates. + """Evenly distributed list of x,y coordinates from an input range of angles + and a radius in polar coordinates. """ angles = np.linspace(start_angle, stop_angle, num=num, endpoint=endpoint) @@ -86,10 +86,10 @@ def spiral_trials(radii, x=0.0, y=0.0): if radii.size > 0: spiral = hex_pack(radii[0], len(radii)) - for i,radius in enumerate(radii): + for i, radius in enumerate(radii): circles.append(mpatches.Circle((spiral[i][0], spiral[i][1]), radii[i])) - - pos_xfm = mxfms.Affine2D().translate(x,y) + + pos_xfm = mxfms.Affine2D().translate(x, y) collection = PatchCollection(circles) collection.set_transform(pos_xfm) @@ -99,7 +99,7 @@ def spiral_trials(radii, x=0.0, y=0.0): def spiral_trials_polar(r, theta, radii, offset=None): if offset is None: - offset = [0,0] + offset = [0, 0] collection = spiral_trials(radii, r + offset[0], offset[1]) @@ -122,100 +122,104 @@ def radial_arcs(rs, start_theta, end_theta): arcs = [] for r in rs: - arcs.append(mpatches.Arc((0,0), 2*r, 2*r, - theta1=start_theta*180.0/np.pi, - theta2=end_theta*180.0/np.pi)) - + arcs.append( + mpatches.Arc((0, 0), 2 * r, 2 * r, theta1=start_theta * 180.0 / np.pi, theta2=end_theta * 180.0 / np.pi) + ) return PatchCollection(arcs) + def rings_in_hex_pack(ct): - return np.ceil((-3.0 + np.sqrt(9.0 - 12.0*(1.0 - ct))) / 6.0 + 1.0) + return np.ceil((-3.0 + np.sqrt(9.0 - 12.0 * (1.0 - ct))) / 6.0 + 1.0) + def radial_circles(rs): - circles = [ mpatches.Circle((0,0), r) for r in rs ] + circles = [mpatches.Circle((0, 0), r) for r in rs] return PatchCollection(circles) + def reset_hex_pack(): global HEX_POSITIONS HEX_POSITIONS = [] - + + def hex_pack(radius, n): global HEX_POSITIONS if len(HEX_POSITIONS) < n: HEX_POSITIONS = build_hex_pack(n) - return HEX_POSITIONS[:n]*radius*2.0 + return HEX_POSITIONS[:n] * radius * 2.0 + def build_hex_pack(n): pos = [] sq32 = math.sqrt(3.0) / 2.0 - + N = 1 - - vs = [ [-0.5, -sq32], [-1.0, 0.0], [-0.5, sq32], [0.5, sq32], [1, 0], [0.5, -sq32] ] - pos.append([0,0]) + + vs = [[-0.5, -sq32], [-1.0, 0.0], [-0.5, sq32], [0.5, sq32], [1, 0], [0.5, -sq32]] + pos.append([0, 0]) while len(pos) < n: - layer_pos = [ ] - - for i,v in enumerate(vs): - x = - N * v[1] * sq32 + layer_pos = [] + + for i, v in enumerate(vs): + x = -N * v[1] * sq32 y = N * v[0] * sq32 - + if N % 2 == 1: x -= 0.5 * v[0] y -= 0.5 * v[1] - + layer_pos.append([]) - layer_pos[i].append([x,y]) + layer_pos[i].append([x, y]) mag = 1 sign = 1 - - for j in range(N-1): + + for j in range(N - 1): x += v[0] * mag * sign y += v[1] * mag * sign mag += 1 sign = -sign - layer_pos[i].append([x,y]) - + layer_pos[i].append([x, y]) + for j in range(N): for i in range(len(vs)): if j < len(layer_pos[i]): pos.append(layer_pos[i][j]) - N+=1 + N += 1 return np.array(pos) def polar_line_circles(radii, theta, start_r=0): - circles = [ mpatches.Circle( (0,0), radii[0] ) ] - - line_xfm = mxfms.Affine2D().translate(start_r,0).rotate(theta) - + circles = [mpatches.Circle((0, 0), radii[0])] + + line_xfm = mxfms.Affine2D().translate(start_r, 0).rotate(theta) + x = 0 for ri in range(1, len(radii)): - x += radii[ri-1] + radii[ri] - - circles.append(mpatches.Circle( (x,0), radii[ri] )) - + x += radii[ri - 1] + radii[ri] + + circles.append(mpatches.Circle((x, 0), radii[ri])) + collection = PatchCollection(circles) collection.set_transform(line_xfm) - + return collection def wedge_ring(N, inner_radius, outer_radius, start=0, stop=360): - degs = np.linspace(start, stop, N+1, endpoint=True) + degs = np.linspace(start, stop, N + 1, endpoint=True) wedges = [] if stop > start: - for i in range(len(degs)-1): - wedges.append( mpatches.Wedge( (0,0), outer_radius, degs[i], degs[i+1], width=outer_radius-inner_radius ) ) + for i in range(len(degs) - 1): + wedges.append(mpatches.Wedge((0, 0), outer_radius, degs[i], degs[i + 1], width=outer_radius - inner_radius)) else: - for i in range(1,len(degs)): - wedges.append( mpatches.Wedge( (0,0), outer_radius, degs[i], degs[i-1], width=outer_radius-inner_radius ) ) + for i in range(1, len(degs)): + wedges.append(mpatches.Wedge((0, 0), outer_radius, degs[i], degs[i - 1], width=outer_radius - inner_radius)) return PatchCollection(wedges) @@ -224,14 +228,18 @@ def add_angle_labels(ax, angles, labels, radius, color=None, fontdict=None, offs angle_pos = polar_to_xy(angles, radius) for i in range(len(angle_pos)): - xy = angle_pos[i,:] + xy = angle_pos[i, :] u = xy + xy / np.linalg.norm(xy) * offset - ax.text(u[0], u[1], - labels[i], color=color, - horizontalalignment='center', - verticalalignment='center', - fontdict=fontdict) - + ax.text( + u[0], + u[1], + labels[i], + color=color, + horizontalalignment="center", + verticalalignment="center", + fontdict=fontdict, + ) + def add_arrow(ax, radius, start_angle, end_angle, color=None, width=18.0): if color is None: @@ -246,15 +254,21 @@ def add_arrow(ax, radius, start_angle, end_angle, color=None, width=18.0): start_pos = (radius * np.cos(start_angle), radius * np.sin(start_angle)) end_pos = (radius * np.cos(end_angle), radius * np.sin(end_angle)) - - connstyle = mpatches.ConnectionStyle.Angle3(angleA=0, angleB=(d_angle*180.0/np.pi)) + + connstyle = mpatches.ConnectionStyle.Angle3(angleA=0, angleB=(d_angle * 180.0 / np.pi)) arrowstyle = mpatches.ArrowStyle.Simple(tail_width=0.33, head_length=0.66, head_width=1.0) - ax.add_patch(mpatches.FancyArrowPatch(posA=start_pos, posB=end_pos, - arrowstyle=arrowstyle, - connectionstyle=connstyle, - facecolor=color, - linewidth=0, - mutation_scale=mutation_scale)) + ax.add_patch( + mpatches.FancyArrowPatch( + posA=start_pos, + posB=end_pos, + arrowstyle=arrowstyle, + connectionstyle=connstyle, + facecolor=color, + linewidth=0, + mutation_scale=mutation_scale, + ) + ) + def make_pincushion_plot(data, trials, on, nrows, ncols, clim=None, color_map=None, radius=None): if radius is None: @@ -266,22 +280,22 @@ def make_pincushion_plot(data, trials, on, nrows, ncols, clim=None, color_map=No radius = 0.5 / (2.0 * rings - 1.0) if clim is None: - clim = [ data.min(), data.max() ] + clim = [data.min(), data.max()] if color_map is None: color_map = LSN_ON_COLOR_MAP if on else LSN_OFF_COLOR_MAP ax = plt.gca() - for (col,row,on_state), sweeps in trials.items(): + for (col, row, on_state), sweeps in trials.items(): if on_state != on: continue valid_sweeps = sweeps[0][sweeps[0] < data.size] responses = np.sort(data[valid_sweeps])[::-1] responses = responses[responses >= clim[0]] - + if responses.size > 0: - coll = spiral_trials(np.ones(responses.shape)*radius, col+0.5, row+0.5) + coll = spiral_trials(np.ones(responses.shape) * radius, col + 0.5, row + 0.5) coll.set_transform(coll.get_transform() + ax.transData) coll.set_array(responses) coll.set_cmap(color_map) @@ -289,27 +303,27 @@ def make_pincushion_plot(data, trials, on, nrows, ncols, clim=None, color_map=No coll.set_linewidths(0) ax.add_collection(coll) - ax.set_ylim((0,nrows)) - ax.set_xlim((0,ncols)) + ax.set_ylim((0, nrows)) + ax.set_xlim((0, ncols)) - -class PolarPlotter( object ): +class PolarPlotter(object): DIR_CW = -1 DIR_CCW = 1 - def __init__(self, - direction=DIR_CW, - angle_start=0, - circle_scale=1.1, - inner_radius=None, - plot_center=(0.0,0.0), - plot_scale=0.9): - + def __init__( + self, + direction=DIR_CW, + angle_start=0, + circle_scale=1.1, + inner_radius=None, + plot_center=(0.0, 0.0), + plot_scale=0.9, + ): self.plot_scale = plot_scale self.plot_center = plot_center - self.angle_transform = np.vectorize(lambda x: ((x + angle_start)*direction)*np.pi/180.0) + self.angle_transform = np.vectorize(lambda x: ((x + angle_start) * direction) * np.pi / 180.0) self.inner_radius = inner_radius self.circle_scale = circle_scale @@ -319,24 +333,25 @@ def finalize(self): figsize = fig.get_size_inches() aspect = figsize[0] / figsize[1] - w = 2.0 / self.plot_scale + w = 2.0 / self.plot_scale h = w / aspect - bounds = ( self.plot_center[0] - w*.5, - self.plot_center[0] + w*.5, - self.plot_center[1] - h*.5, - self.plot_center[1] + h*.5 ) + bounds = ( + self.plot_center[0] - w * 0.5, + self.plot_center[0] + w * 0.5, + self.plot_center[1] - h * 0.5, + self.plot_center[1] + h * 0.5, + ) ax.set_xlim(bounds[0], bounds[1]) ax.set_ylim(bounds[2], bounds[3]) - plt.subplots_adjust(left=0,right=1,bottom=0,top=1) + plt.subplots_adjust(left=0, right=1, bottom=0, top=1) @classmethod def _clim(self, clim, data): - if clim is None: - clim = [ data.min(), data.max() ] + clim = [data.min(), data.max()] if clim[0] == clim[1]: clim[0] = 0 @@ -346,44 +361,34 @@ def _clim(self, clim, data): return clim -class TrackPlotter( PolarPlotter ): - def __init__(self, - direction=PolarPlotter.DIR_CW, - angle_start=270.0, - inner_radius=.45, - ring_length=None, - *args, **kwargs): - super(TrackPlotter, self).__init__(direction=direction, - angle_start=angle_start, - inner_radius=inner_radius, - *args, **kwargs) +class TrackPlotter(PolarPlotter): + def __init__( + self, direction=PolarPlotter.DIR_CW, angle_start=270.0, inner_radius=0.45, ring_length=None, *args, **kwargs + ): + super(TrackPlotter, self).__init__( + direction=direction, angle_start=angle_start, inner_radius=inner_radius, *args, **kwargs + ) self.ring_length = ring_length def show_arrow(self, color=None): start, end = self.angle_transform([0.0, 40.0]) - add_arrow(plt.gca(), self.inner_radius * .85, start, end, color) - - def plot(self, data, - clim=None, - cmap=DEFAULT_COLOR_MAP, - mean_cmap=DEFAULT_MEAN_RESP_COLOR_MAP, - norm=None): + add_arrow(plt.gca(), self.inner_radius * 0.85, start, end, color) + def plot(self, data, clim=None, cmap=DEFAULT_COLOR_MAP, mean_cmap=DEFAULT_MEAN_RESP_COLOR_MAP, norm=None): ax = plt.gca() clim = self._clim(clim, data) if self.ring_length: - data = skimage.transform.resize(data.astype(np.float64), - (data.shape[0], self.ring_length), - mode='constant', - anti_aliasing=False) - + data = skimage.transform.resize( + data.astype(np.float64), (data.shape[0], self.ring_length), mode="constant", anti_aliasing=False + ) + data_mean = data.mean(axis=0) data = np.vstack((data, data_mean)) - - radii = np.linspace(self.inner_radius, 1.0, data.shape[0]+2) - start,stop = self.angle_transform([0,360])*180.0/np.pi + + radii = np.linspace(self.inner_radius, 1.0, data.shape[0] + 2) + start, stop = self.angle_transform([0, 360]) * 180.0 / np.pi if norm is None: norm = mcolors.PowerNorm(0.5, vmin=clim[0], vmax=clim[1], clip=True) @@ -392,35 +397,30 @@ def plot(self, data, inner_radius = radii[i] if i < data.shape[0] - 1: - outer_radius = radii[i+1] + outer_radius = radii[i + 1] ring_cmap = cmap else: - outer_radius = radii[i+2] + outer_radius = radii[i + 2] ring_cmap = mean_cmap - - wedges = wedge_ring(len(row_data), - inner_radius, outer_radius, - start=start, stop=stop) + wedges = wedge_ring(len(row_data), inner_radius, outer_radius, start=start, stop=stop) wedges.set_array(row_data) - #wedges.set_clim(clim) + # wedges.set_clim(clim) wedges.set_cmap(ring_cmap) wedges.set_norm(norm) - wedges.set_edgecolors((0,0,0,0)) + wedges.set_edgecolors((0, 0, 0, 0)) ax.add_collection(wedges) self.finalize() - -class CoronaPlotter( PolarPlotter ): - def __init__(self, - angle_start=270, - plot_scale=1.2, - inner_radius=.3, - *args, **kwargs): - super(CoronaPlotter, self).__init__(inner_radius=inner_radius, angle_start=angle_start, plot_scale=plot_scale, *args, **kwargs) + +class CoronaPlotter(PolarPlotter): + def __init__(self, angle_start=270, plot_scale=1.2, inner_radius=0.3, *args, **kwargs): + super(CoronaPlotter, self).__init__( + inner_radius=inner_radius, angle_start=angle_start, plot_scale=plot_scale, *args, **kwargs + ) self.categories = None self.cat_idx_map = None @@ -434,23 +434,19 @@ def set_dims(self, categories): def show_arrow(self, color=None): start, end = self.angle_transform([0.0, 40.0]) - add_arrow(plt.gca(), self.inner_radius * .85, start, end, color) + add_arrow(plt.gca(), self.inner_radius * 0.85, start, end, color) def show_circle(self, color=None): if color is None: color = DEFAULT_LABEL_COLOR ax = plt.gca() collection = radial_circles([0.96 * self.inner_radius]) - collection.set_facecolor((0,0,0,0)) + collection.set_facecolor((0, 0, 0, 0)) collection.set_edgecolor(color) collection.set_zorder(1) ax.add_collection(collection) - def plot(self, category_data, - data=None, - clim=None, - cmap=DEFAULT_COLOR_MAP): - + def plot(self, category_data, data=None, clim=None, cmap=DEFAULT_COLOR_MAP): ax = plt.gca() if self.categories is None: @@ -458,41 +454,40 @@ def plot(self, category_data, if data is None: data = np.ones(len(category_data)) - + clim = self._clim(clim, data) num_cats = len(self.categories) hth = 180.0 / num_cats - degs = np.linspace(hth, 360.0-hth, num_cats) + degs = np.linspace(hth, 360.0 - hth, num_cats) degs = self.angle_transform(degs) - circle_radius = self.inner_radius * abs(np.sin((degs[1] - degs[0]) * .5)) - + circle_radius = self.inner_radius * abs(np.sin((degs[1] - degs[0]) * 0.5)) + radii = np.ones(len(data)) * circle_radius * self.circle_scale - df = pd.DataFrame({ 'category': category_data }) - gb = df.groupby(['category']) + df = pd.DataFrame({"category": category_data}) + gb = df.groupby(["category"]) for category, trials in gb.groups.items(): idx = self.cat_idx_map[category] order = np.argsort(data[trials])[::-1] trial_order = np.array(trials)[order] - circles = polar_line_circles(radii[trial_order], - degs[idx], - self.inner_radius) + circles = polar_line_circles(radii[trial_order], degs[idx], self.inner_radius) circles.set_transform(circles.get_transform() + ax.transData) circles.set_array(data[trial_order]) circles.set_cmap(cmap) circles.set_clim(clim) - circles.set_edgecolors((0,0,0,0)) + circles.set_edgecolors((0, 0, 0, 0)) circles.set_zorder(2) ax.add_collection(circles) self.finalize() -class FanPlotter( PolarPlotter ): + +class FanPlotter(PolarPlotter): def __init__(self, group_scale=0.9, *args, **kwargs): super(FanPlotter, self).__init__(*args, **kwargs) @@ -517,7 +512,7 @@ def infer_dims(self, r_data, angle_data, group_data): angles = np.sort(np.unique(angle_data)) groups = np.sort(np.unique(group_data)) if group_data is not None else None - self.set_dims(rs, angles, groups) + self.set_dims(rs, angles, groups) def set_dims(self, rs, angles, groups): self.angles = angles @@ -529,32 +524,28 @@ def set_dims(self, rs, angles, groups): # map r value to radius if self.inner_radius is None: - self.inner_radius = 1.0 / ( 2 * num_rs ) + self.inner_radius = 1.0 / (2 * num_rs) - hdr = ( 1.0 - self.inner_radius ) / num_rs / 2.0 - self.radii = np.linspace(self.inner_radius + hdr, - 1.0 - hdr, - num_rs) + hdr = (1.0 - self.inner_radius) / num_rs / 2.0 + self.radii = np.linspace(self.inner_radius + hdr, 1.0 - hdr, num_rs) self.r_radius_map = dict(zip(rs, self.radii)) self.group_radius = hdr * self.group_scale - self.groups = groups if groups is not None else [ np.nan ] + self.groups = groups if groups is not None else [np.nan] num_groups = len(self.groups) # map group to group offset if num_groups == 1: - self.group_offsets = [ [ 0, 0 ] ] + self.group_offsets = [[0, 0]] else: offset_radius = self.group_radius * self.circle_scale - self.group_offsets = polar_linspace(offset_radius/np.sqrt(2), - -45, -45-360, num_groups) + self.group_offsets = polar_linspace(offset_radius / np.sqrt(2), -45, -45 - 360, num_groups) self.group_radius = offset_radius * 0.5 self.group_offset_map = dict(zip(self.groups, self.group_offsets)) - def show_axes(self, angles=None, radii=None, closed=False, color=None): ax = plt.gca() @@ -569,7 +560,7 @@ def show_axes(self, angles=None, radii=None, closed=False, color=None): if radii is None: radii = self.radii - + lines = angle_lines(angles, radii[0], radii[-1]) lines.set_zorder(1) lines.set_edgecolors(color) @@ -579,26 +570,24 @@ def show_axes(self, angles=None, radii=None, closed=False, color=None): collection = radial_circles(radii) else: collection = radial_arcs(radii, min(angles), max(angles)) - - collection.set_facecolors((0,0,0,0.0)) + + collection.set_facecolors((0, 0, 0, 0.0)) collection.set_edgecolors(color) collection.set_zorder(1) ax.add_collection(collection) - - def show_angle_labels(self, angles=None, labels=None, color=None, offset=.05, fontdict=None): + def show_angle_labels(self, angles=None, labels=None, color=None, offset=0.05, fontdict=None): if angles is None: angles = self.xangles if labels is None: labels = self.angles.astype(int) - + if color is None: color = DEFAULT_LABEL_COLOR add_angle_labels(plt.gca(), angles, labels, 1.0, offset=offset, color=color, fontdict=fontdict) - def show_group_labels(self, groups=None, color=None, fontdict=None): ax = plt.gca() @@ -608,35 +597,37 @@ def show_group_labels(self, groups=None, color=None, fontdict=None): if color is None: color = DEFAULT_LABEL_COLOR - r = self.inner_radius*.5 + r = self.inner_radius * 0.5 angle = 90.0 - + for group in groups: off = self.group_offset_map[group] - xfm = mxfms.Affine2D().translate(r+off[0]*2.0,off[1]*2.0).rotate(self.angle_transform(angle)) - p = xfm.transform_point([0,0]) - - ax.text(p[0], p[1], - group, color=color, - horizontalalignment='center', - verticalalignment='center', - fontdict=fontdict) - - start_theta = self.angle_transform(angle+20) - end_theta = self.angle_transform(angle-20) + xfm = mxfms.Affine2D().translate(r + off[0] * 2.0, off[1] * 2.0).rotate(self.angle_transform(angle)) + p = xfm.transform_point([0, 0]) - ax.add_patch(mpatches.Arc((0,0), 2*r, 2*r, - theta1=start_theta*180.0/np.pi, - theta2=end_theta*180.0/np.pi, - color=color)) + ax.text( + p[0], + p[1], + group, + color=color, + horizontalalignment="center", + verticalalignment="center", + fontdict=fontdict, + ) - ax.add_collection(LineCollection([[[0, .7*r], [0, 1.3*r]]], color=color)) + start_theta = self.angle_transform(angle + 20) + end_theta = self.angle_transform(angle - 20) + ax.add_patch( + mpatches.Arc( + (0, 0), 2 * r, 2 * r, theta1=start_theta * 180.0 / np.pi, theta2=end_theta * 180.0 / np.pi, color=color + ) + ) + ax.add_collection(LineCollection([[[0, 0.7 * r], [0, 1.3 * r]]], color=color)) - - def show_r_labels(self, radii=None, labels=None, color=None, offset=.1, fontdict=None): + def show_r_labels(self, radii=None, labels=None, color=None, offset=0.1, fontdict=None): ax = plt.gca() if radii is None: @@ -644,7 +635,7 @@ def show_r_labels(self, radii=None, labels=None, color=None, offset=.1, fontdict if labels is None: labels = self.rs - + if color is None: color = DEFAULT_LABEL_COLOR @@ -654,25 +645,30 @@ def show_r_labels(self, radii=None, labels=None, color=None, offset=.1, fontdict line_th = self.xangles[0] line_x = radii * np.cos(line_th) line_y = radii * np.sin(line_th) - for i,(x,y) in enumerate(zip(line_x,line_y)): - ax.text(x, y-offset, - labels[i], color=color, - horizontalalignment='center', - verticalalignment='center', - fontdict=fontdict) - - def plot(self, - r_data, - angle_data, - group_data=None, - data=None, - cmap=DEFAULT_COLOR_MAP, - clim=None, - rmap=None, - rlim=None, - axis_color=None, - label_color=None): - + for i, (x, y) in enumerate(zip(line_x, line_y)): + ax.text( + x, + y - offset, + labels[i], + color=color, + horizontalalignment="center", + verticalalignment="center", + fontdict=fontdict, + ) + + def plot( + self, + r_data, + angle_data, + group_data=None, + data=None, + cmap=DEFAULT_COLOR_MAP, + clim=None, + rmap=None, + rlim=None, + axis_color=None, + label_color=None, + ): ax = plt.gca() if data is None: @@ -695,24 +691,24 @@ def plot(self, num_rs = len(self.rs) num_angles = len(self.angles) - df = pd.DataFrame({ 'group': group_data, - 'angle': angle_data, - 'r': r_data }) + df = pd.DataFrame({"group": group_data, "angle": angle_data, "r": r_data}) # compute circle radius trials_per_group = float(len(df)) / num_groups / num_rs / num_angles rings = rings_in_hex_pack(trials_per_group) - circle_radius = self.group_radius / (2*rings - 1) * self.circle_scale + circle_radius = self.group_radius / (2 * rings - 1) * self.circle_scale - gb = df.groupby(['group', 'angle', 'r']) + gb = df.groupby(["group", "angle", "r"]) for (group, angle, r), trials in gb.groups.items(): responses = np.sort(data[trials])[::-1] - circles = spiral_trials_polar(self.r_radius_map[r], - self.angle_map[angle], - rnorm(responses) * circle_radius, - offset=self.group_offset_map[group]) + circles = spiral_trials_polar( + self.r_radius_map[r], + self.angle_map[angle], + rnorm(responses) * circle_radius, + offset=self.group_offset_map[group], + ) circles.set_transform(circles.get_transform() + ax.transData) circles.set_array(responses) @@ -727,19 +723,10 @@ def plot(self, @staticmethod def for_static_gratings(): - return FanPlotter(angle_start=180, - plot_scale=0.9, - circle_scale=2.0, - group_scale=0.4, - plot_center=[0,.45], - inner_radius=.2) + return FanPlotter( + angle_start=180, plot_scale=0.9, circle_scale=2.0, group_scale=0.4, plot_center=[0, 0.45], inner_radius=0.2 + ) @staticmethod def for_drifting_gratings(): return FanPlotter() - - - - - - diff --git a/allensdk/brain_observatory/comparison_utils.py b/allensdk/brain_observatory/comparison_utils.py index c9afa8b3c5..0b525857c1 100644 --- a/allensdk/brain_observatory/comparison_utils.py +++ b/allensdk/brain_observatory/comparison_utils.py @@ -9,8 +9,7 @@ from pandas.testing import assert_frame_equal -def compare_fields(x1: Any, x2: Any, err_msg="", - ignore_keys: Optional[Set[str]] = None): +def compare_fields(x1: Any, x2: Any, err_msg="", ignore_keys: Optional[Set[str]] = None): """Helper function to compare if two fields (attributes) are equal to one another. @@ -63,7 +62,7 @@ def compare_fields(x1: Any, x2: Any, err_msg="", assert abs(time_delta) < 60 elif isinstance(x1, (float,)): if math.isnan(x1) or math.isnan(x2): - both_nan = (math.isnan(x1) and math.isnan(x2)) + both_nan = math.isnan(x1) and math.isnan(x2) assert both_nan, err_msg else: assert x1 == x2, err_msg diff --git a/allensdk/brain_observatory/data_release_utils/metadata_utils/id_generator.py b/allensdk/brain_observatory/data_release_utils/metadata_utils/id_generator.py index be7b0e8294..cfea7d51f9 100644 --- a/allensdk/brain_observatory/data_release_utils/metadata_utils/id_generator.py +++ b/allensdk/brain_observatory/data_release_utils/metadata_utils/id_generator.py @@ -20,17 +20,18 @@ def dummy_value(self) -> int: """ return self._dummy_value - def id_from_path(self, - file_path: pathlib.Path) -> int: + def id_from_path(self, file_path: pathlib.Path) -> int: """ Get the unique ID for a file path. If the file has already been assigned a unique ID, return that. Otherwise, assign a unique ID to the file path and return it """ if not isinstance(file_path, pathlib.Path): - msg = ("file_path must be a pathlib.Path (this is so " - "we can resolve it into an absolute path). You passed " - f"in a {type(file_path)}") + msg = ( + "file_path must be a pathlib.Path (this is so " + "we can resolve it into an absolute path). You passed " + f"in a {type(file_path)}" + ) raise ValueError(msg) if not file_path.is_file(): diff --git a/allensdk/brain_observatory/data_release_utils/metadata_utils/utils.py b/allensdk/brain_observatory/data_release_utils/metadata_utils/utils.py index 6532932936..098a3238a4 100644 --- a/allensdk/brain_observatory/data_release_utils/metadata_utils/utils.py +++ b/allensdk/brain_observatory/data_release_utils/metadata_utils/utils.py @@ -4,21 +4,19 @@ import pathlib import warnings -from allensdk.brain_observatory.data_release_utils \ - .metadata_utils.id_generator import ( - FileIDGenerator) +from allensdk.brain_observatory.data_release_utils.metadata_utils.id_generator import FileIDGenerator def add_file_paths_to_metadata_table( - metadata_table: pd.DataFrame, - id_generator: FileIDGenerator, - file_dir: pathlib.Path, - file_prefix: Optional[str], - index_col: str, - data_dir_col: Optional[str], - on_missing_file: str, - file_suffix: str = 'nwb', - file_stem: Optional[str] = None + metadata_table: pd.DataFrame, + id_generator: FileIDGenerator, + file_dir: pathlib.Path, + file_prefix: Optional[str], + index_col: str, + data_dir_col: Optional[str], + on_missing_file: str, + file_suffix: str = "nwb", + file_stem: Optional[str] = None, ) -> pd.DataFrame: """ Add file_id and file_path columns to session dataframe. @@ -74,10 +72,8 @@ def add_file_paths_to_metadata_table( {file_dir}/{file_prefix}_{metadata_table.index_col}.{file_suffix} """ - if on_missing_file not in ('error', 'warn', 'skip'): - msg = ("on_missing_file must be one of ('error', " - "'warn', or 'skip'); you passed in " - f"{on_missing_file}") + if on_missing_file not in ("error", "warn", "skip"): + msg = f"on_missing_file must be one of ('error', 'warn', or 'skip'); you passed in {on_missing_file}" raise ValueError(msg) new_data = [] @@ -88,19 +84,16 @@ def add_file_paths_to_metadata_table( data_dir = getattr(row, data_dir_col, row.Index) if file_stem is None: - file_stem_ = \ - f'{file_prefix}_{row.Index}' if file_prefix is not None else \ - f'{row.Index}' + file_stem_ = f"{file_prefix}_{row.Index}" if file_prefix is not None else f"{row.Index}" else: file_stem_ = file_stem if data_dir is None: # If `data_dir` is not given, assume files stored flat - file_path = file_dir / f'{file_stem_}.{file_suffix}' + file_path = file_dir / f"{file_stem_}.{file_suffix}" else: # assume files stored under data_dir - file_path = file_dir / f'{data_dir}' / \ - f'{file_stem_}.{file_suffix}' + file_path = file_dir / f"{data_dir}" / f"{file_stem_}.{file_suffix}" if not file_path.exists(): file_id = id_generator.dummy_value @@ -108,29 +101,23 @@ def add_file_paths_to_metadata_table( else: file_id = id_generator.id_from_path(file_path=file_path) str_path = str(file_path.resolve().absolute()) - new_data.append( - {'file_id': file_id, - 'file_path': str_path, - index_col: row.Index}) + new_data.append({"file_id": file_id, "file_path": str_path, index_col: row.Index}) if len(missing_files) > 0: msg = "The following files do not exist:" for file_path in missing_files: msg += f"\n{file_path}" - if on_missing_file == 'error': + if on_missing_file == "error": raise RuntimeError(msg) else: warnings.warn(msg) new_df = pd.DataFrame(data=new_data) - metadata_table = metadata_table.join( - new_df.set_index(index_col), - on=index_col, - how='left') - if on_missing_file == 'skip' and len(missing_files) > 0: + metadata_table = metadata_table.join(new_df.set_index(index_col), on=index_col, how="left") + if on_missing_file == "skip" and len(missing_files) > 0: metadata_table = metadata_table.drop( - metadata_table.loc[ - metadata_table.file_id == id_generator.dummy_value].index) + metadata_table.loc[metadata_table.file_id == id_generator.dummy_value].index + ) metadata_table = metadata_table.reset_index() diff --git a/allensdk/brain_observatory/demixer.py b/allensdk/brain_observatory/demixer.py index effdfc8d97..39ed26fe4f 100644 --- a/allensdk/brain_observatory/demixer.py +++ b/allensdk/brain_observatory/demixer.py @@ -55,7 +55,7 @@ def identify_valid_masks(mask_array): duplicates = ms.detect_duplicates(overlap_threshold=0.9) if len(duplicates) > 0: valid_masks[duplicates.keys()] = False - + # detect unions, only for remaining valid masks valid_idxs = np.where(valid_masks) ms = mask_set.MaskSet(masks=mask_array[valid_idxs].astype(bool)) @@ -68,9 +68,9 @@ def identify_valid_masks(mask_array): return valid_masks -def _demix_point(source_frame: np.ndarray, mask_traces: np.ndarray, - flat_masks: sparse, - pixels_per_mask: np.ndarray) -> Optional[np.ndarray]: +def _demix_point( + source_frame: np.ndarray, mask_traces: np.ndarray, flat_masks: sparse, pixels_per_mask: np.ndarray +) -> Optional[np.ndarray]: """ Helper function to run demixing for single point in time for a source with overlapping traces. @@ -112,9 +112,9 @@ def _demix_point(source_frame: np.ndarray, mask_traces: np.ndarray, return demix_traces -def demix_time_dep_masks(raw_traces: np.ndarray, stack: np.ndarray, - masks: np.ndarray, - max_block_size: int = 1000) -> Tuple[np.ndarray, list]: +def demix_time_dep_masks( + raw_traces: np.ndarray, stack: np.ndarray, masks: np.ndarray, max_block_size: int = 1000 +) -> Tuple[np.ndarray, list]: """ Demix traces of potentially overlapping masks extraced from a single 2p recording. @@ -139,9 +139,11 @@ def demix_time_dep_masks(raw_traces: np.ndarray, stack: np.ndarray, if max_block_size == -1: max_block_size = T elif max_block_size < 1: - raise ValueError("Invalid maximum block size {}. Must be strictly " - "positive (>= 1), or -1 for full length block " - "size.".format(max_block_size)) + raise ValueError( + "Invalid maximum block size {}. Must be strictly positive (>= 1), or -1 for full length block size.".format( + max_block_size + ) + ) num_pixels_in_mask = np.sum(masks, axis=(1, 2)) @@ -152,15 +154,12 @@ def demix_time_dep_masks(raw_traces: np.ndarray, stack: np.ndarray, demix_traces = np.zeros((N, T)) for t in range(T): - block_t = t % max_block_size if block_t == 0: # load next block into memory and reshape block_T = np.min([(T - t), max_block_size]) - stack_block = stack[t : t+block_T].reshape(block_T, P) + stack_block = stack[t : t + block_T].reshape(block_T, P) - demixed_point = _demix_point( - stack_block[block_t], raw_traces[:, t], flat_masks, - num_pixels_in_mask) + demixed_point = _demix_point(stack_block[block_t], raw_traces[:, t], flat_masks, num_pixels_in_mask) if demixed_point is not None: demix_traces[:, t] = demixed_point drop_frames.append(False) @@ -172,8 +171,8 @@ def demix_time_dep_masks(raw_traces: np.ndarray, stack: np.ndarray, def plot_traces(raw_trace, demix_trace, roi_id, roi_ind, save_file): fig, ax = plt.subplots() - ax.plot(raw_trace, label='Fluoresence') - ax.plot(demix_trace, label='Demixed') + ax.plot(raw_trace, label="Fluoresence") + ax.plot(demix_trace, label="Demixed") ax.set_title("ROI ID(%d) index (%d)" % (roi_id, roi_ind)) ax.legend() plt.savefig(save_file) @@ -183,27 +182,26 @@ def plot_traces(raw_trace, demix_trace, roi_id, roi_ind, save_file): def find_zero_baselines(traces): means = traces.mean(axis=1) stds = traces.std(axis=1) - return np.where((means-stds) < 0) + return np.where((means - stds) < 0) -def plot_negative_baselines(raw_traces, demix_traces, mask_array, - roi_ids_mask, plot_dir, ext='png'): +def plot_negative_baselines(raw_traces, demix_traces, mask_array, roi_ids_mask, plot_dir, ext="png"): N, T = raw_traces.shape _, x, y = mask_array.shape logging.debug("finding negative baselines") neg_inds = find_negative_baselines(demix_traces)[0] - + overlap_inds = set() logging.debug("detected negative baselines: %s", str(neg_inds)) for roi_ind in neg_inds: Manifest.safe_mkdir(plot_dir) - save_file = os.path.join(plot_dir, str(roi_ids_mask[roi_ind]) + '_negative.' + ext) + save_file = os.path.join(plot_dir, str(roi_ids_mask[roi_ind]) + "_negative." + ext) plot_traces(raw_traces[roi_ind], demix_traces[roi_ind], roi_ids_mask[roi_ind], roi_ind, save_file) - ''' plot overlapping masks ''' - save_file = os.path.join(plot_dir, str(roi_ids_mask[roi_ind]) + '_negative_masks.' + ext) + """ plot overlapping masks """ + save_file = os.path.join(plot_dir, str(roi_ids_mask[roi_ind]) + "_negative_masks." + ext) roi_overlap_inds = plot_overlap_masks_lengthOne(roi_ind, mask_array, save_file) overlap_inds.update(roi_overlap_inds) @@ -215,9 +213,7 @@ def plot_negative_baselines(raw_traces, demix_traces, mask_array, return list(overlap_inds) -def plot_negative_transients(raw_traces, demix_traces, valid_roi, mask_array, - roi_ids_mask, plot_dir, ext='png'): - +def plot_negative_transients(raw_traces, demix_traces, valid_roi, mask_array, roi_ids_mask, plot_dir, ext="png"): N, T = raw_traces.shape _, x, y = mask_array.shape @@ -229,24 +225,23 @@ def plot_negative_transients(raw_traces, demix_traces, valid_roi, mask_array, logging.debug("plotting negative transients") - flat_masks = mask_array.reshape(N, x*y) + flat_masks = mask_array.reshape(N, x * y) overlap = flat_masks.dot(flat_masks.T) overlap ^= np.diag(np.diag(overlap)) for roi_ind in rois_with_trans: - - ''' plot biggest negative transient of this roi ''' + """ plot biggest negative transient of this roi """ trans_ind_list = trans_ind_list1[roi_ind] trans_ind_list = trans_ind_list[0] trans_list = [] for i in trans_ind_list: if i > 100 and i < T - 100: - trans_list.append(demix_traces[roi_ind, i - 100:i + 100]) + trans_list.append(demix_traces[roi_ind, i - 100 : i + 100]) elif i > 100 and i >= T - 100: - trans_list.append(demix_traces[roi_ind, i - 100:]) + trans_list.append(demix_traces[roi_ind, i - 100 :]) else: - trans_list.append(demix_traces[roi_ind, :i + 100]) + trans_list.append(demix_traces[roi_ind, : i + 100]) # trans_list = [demix_traces[roi_ind, i-100:i+100] for i in trans_ind_list if i > 100 and i < Nt] Ntrans = len(trans_list) @@ -261,22 +256,20 @@ def plot_negative_transients(raw_traces, demix_traces, valid_roi, mask_array, # trans_list_min = np.where(demix_traces[roi_ind, trans_ind_list] == min(demix_traces[roi_ind, trans_ind_list]))[0] if np.sum(overlap[roi_ind]) > 0: - if valid_roi[roi_ind]: - - savefile = os.path.join(plot_dir, str(roi_ids_mask[roi_ind]) + '_transient_valid.' + ext) + savefile = os.path.join(plot_dir, str(roi_ids_mask[roi_ind]) + "_transient_valid." + ext) plot_transients(roi_ind, trans_ind, mask_array, raw_traces, demix_traces, savefile) - ''' plot overlapping masks ''' - savefile = os.path.join(plot_dir, str(roi_ids_mask[roi_ind]) + '_masks_valid.' + ext) + """ plot overlapping masks """ + savefile = os.path.join(plot_dir, str(roi_ids_mask[roi_ind]) + "_masks_valid." + ext) plot_overlap_masks_lengthOne(roi_ind, mask_array, savefile) # plot_overlap_masks(roi_ind, mask_test, savefile) else: - savefile = os.path.join(plot_dir, str(roi_ids_mask[roi_ind]) + '_transient_invalid.' + ext) + savefile = os.path.join(plot_dir, str(roi_ids_mask[roi_ind]) + "_transient_invalid." + ext) plot_transients(roi_ind, trans_ind, mask_array, raw_traces, demix_traces, savefile) - ''' plot overlapping masks ''' - savefile = os.path.join(plot_dir, str(roi_ids_mask[roi_ind]) + '_masks_invalid.' + ext) + """ plot overlapping masks """ + savefile = os.path.join(plot_dir, str(roi_ids_mask[roi_ind]) + "_masks_invalid." + ext) plot_overlap_masks_lengthOne(roi_ind, mask_array, savefile) # plot_overlap_masks(roi_ind, mask_test, savefile) # @@ -287,15 +280,15 @@ def plot_negative_transients(raw_traces, demix_traces, valid_roi, mask_array, def rolling_window(trace, window=500): - ''' + """ :param trace: :param window: :return: - ''' + """ shape = trace.shape[:-1] + (trace.shape[-1] - window + 1, window) - strides = trace.strides + (trace.strides[-1], ) + strides = trace.strides + (trace.strides[-1],) return np.lib.stride_tricks.as_strided(trace, shape=shape, strides=strides) @@ -303,18 +296,18 @@ def rolling_window(trace, window=500): def find_negative_baselines(trace): means = trace.mean(axis=1) stds = trace.std(axis=1) - return np.where((means+stds) < 0) + return np.where((means + stds) < 0) def find_negative_transients_threshold(trace, window=500, length=10, std_devs=3): - trace = np.pad(trace, pad_width=(window-1, 0), mode='constant', constant_values=[np.mean(trace[:window])]) + trace = np.pad(trace, pad_width=(window - 1, 0), mode="constant", constant_values=[np.mean(trace[:window])]) rolling_mean = np.mean(rolling_window(trace, window), -1) rolling_std = np.std(rolling_window(trace, window), -1) - below_thresh = (trace[window-1:] < rolling_mean - std_devs*rolling_std) - below_thresh = np.pad(below_thresh, pad_width=(window-1, 0), mode='constant') + below_thresh = trace[window - 1 :] < rolling_mean - std_devs * rolling_std + below_thresh = np.pad(below_thresh, pad_width=(window - 1, 0), mode="constant") trans_length = np.sum(rolling_window(below_thresh, length), -1) - trans_length = trans_length[window-length:] + trans_length = trans_length[window - length :] trans_ind = np.where(trans_length == length) @@ -322,14 +315,13 @@ def find_negative_transients_threshold(trace, window=500, length=10, std_devs=3) def plot_overlap_masks_lengthOne(roi_ind, masks, savefile=None, weighted=False): - masks = np.array(masks).astype(float) N, x, y = masks.shape - if np.sum(masks[-1]) == x*y: + if np.sum(masks[-1]) == x * y: masks = masks[:-1] N -= 1 - flat_masks = masks.reshape(N, x*y) + flat_masks = masks.reshape(N, x * y) masks_overlap = flat_masks.dot(flat_masks.T) ind_plot = np.where(masks_overlap[roi_ind, :] > 0)[0] # rois (k) that roi_ind overlaps with @@ -338,31 +330,37 @@ def plot_overlap_masks_lengthOne(roi_ind, masks, savefile=None, weighted=False): ind_plot = np.concatenate((ind_plot, ind_k)) ind_plot = np.unique(ind_plot) - ind_plot = np.concatenate(([roi_ind], ind_plot[ind_plot!=roi_ind])) + ind_plot = np.concatenate(([roi_ind], ind_plot[ind_plot != roi_ind])) plt.figure() - color_list = ['b', 'g', 'r', 'c', 'm', 'y', 'k'] + color_list = ["b", "g", "r", "c", "m", "y", "k"] Ncol = len(color_list) for num, i in enumerate(ind_plot): mask_plot = masks[i] if not weighted: - mask_plot = ((num % Ncol)+1)*np.ma.array(masks[i], mask=(masks[i] == 0)) - plt.imshow(mask_plot, clim=(1., Ncol+1), cmap=colors.ListedColormap(color_list), alpha=0.5, interpolation='nearest') + mask_plot = ((num % Ncol) + 1) * np.ma.array(masks[i], mask=(masks[i] == 0)) + plt.imshow( + mask_plot, + clim=(1.0, Ncol + 1), + cmap=colors.ListedColormap(color_list), + alpha=0.5, + interpolation="nearest", + ) # plt.imshow(mask_plot, clim=(1., len(ind_plot)), alpha=.5) elif weighted: mask_plot = np.ma.array(masks[i], mask=(masks[i] == 0)) - plt.imshow(mask_plot, cmap='gray_r', alpha=.5, interpolation='nearest') + plt.imshow(mask_plot, cmap="gray_r", alpha=0.5, interpolation="nearest") - plt.text(np.mean(np.where(np.sum(mask_plot, axis=0))), np.mean(np.where(np.sum(mask_plot, axis=1))) ,str(i)) + plt.text(np.mean(np.where(np.sum(mask_plot, axis=0))), np.mean(np.where(np.sum(mask_plot, axis=1))), str(i)) mask_tot = np.sum(masks[ind_plot, :, :], axis=0) mask_x = np.sum(mask_tot, axis=0) mask_y = np.sum(mask_tot, axis=1) - plt.xlim((np.amin(np.where(mask_x))-5, np.amax(np.where(mask_x))+5)) - plt.ylim((np.amin(np.where(mask_y))-5, np.amax(np.where(mask_y))+5)) - plt.title('Masks') + plt.xlim((np.amin(np.where(mask_x)) - 5, np.amax(np.where(mask_x)) + 5)) + plt.ylim((np.amin(np.where(mask_y)) - 5, np.amax(np.where(mask_y)) + 5)) + plt.title("Masks") if savefile is not None: plt.savefig(savefile) @@ -372,12 +370,11 @@ def plot_overlap_masks_lengthOne(roi_ind, masks, savefile=None, weighted=False): def plot_transients(roi_ind, t_trans, masks, traces, demix_traces, savefile): - masks = np.array(masks).astype(float) N, x, y = masks.shape _, Nt = traces.shape - flat_masks = masks.reshape(N, x*y) + flat_masks = masks.reshape(N, x * y) masks_overlap = flat_masks.dot(flat_masks.T) ind_plot = np.where(masks_overlap[roi_ind, :] > 0)[0] # rois (k) that roi_ind overlaps with @@ -386,7 +383,7 @@ def plot_transients(roi_ind, t_trans, masks, traces, demix_traces, savefile): ind_plot = np.concatenate((ind_plot, ind_k)) ind_plot = np.unique(ind_plot) - ind_plot = np.concatenate(([roi_ind], ind_plot[ind_plot!=roi_ind])) + ind_plot = np.concatenate(([roi_ind], ind_plot[ind_plot != roi_ind])) if t_trans > 150 and t_trans < Nt - 150: plot_t = range(t_trans - 150, t_trans + 150) @@ -396,17 +393,17 @@ def plot_transients(roi_ind, t_trans, masks, traces, demix_traces, savefile): plot_t = range(0, t_trans + 150) fig, ax = plt.subplots(1, 2, figsize=(12, 6), sharex=True, sharey=True) - color_list = ['b', 'g', 'r', 'c', 'm', 'y', 'k'] + color_list = ["b", "g", "r", "c", "m", "y", "k"] Ncol = len(color_list) for num, i in enumerate(ind_plot): ax[0].plot(plot_t, traces[i, plot_t], label=str(i), color=color_list[(num % Ncol)]) ax[1].plot(plot_t, demix_traces[i, plot_t], label=str(i), color=color_list[(num % Ncol)]) - ax[0].set_title('Raw') - ax[0].set_ylabel('Fluorescence') - ax[1].set_title('Demixed') - ax[1].set_xlabel('Time') + ax[0].set_title("Raw") + ax[0].set_ylabel("Fluorescence") + ax[1].set_title("Demixed") + ax[1].set_xlabel("Time") ax[0].legend(loc=0) plt.savefig(savefile) diff --git a/allensdk/brain_observatory/dff.py b/allensdk/brain_observatory/dff.py index 81746fbe49..e00d3a854f 100644 --- a/allensdk/brain_observatory/dff.py +++ b/allensdk/brain_observatory/dff.py @@ -41,8 +41,7 @@ import numpy as np from scipy.ndimage.filters import median_filter -from allensdk.core.brain_observatory_nwb_data_set import \ - BrainObservatoryNwbDataSet +from allensdk.core.brain_observatory_nwb_data_set import BrainObservatoryNwbDataSet GAUSSIAN_MAD_STD_SCALE = 1.4826 @@ -72,8 +71,7 @@ def movingmode_fast(x, kernelsize, y): # compute a histogram of a half kernel halfsize = int(kernelsize / 2) - histo = np.bincount(np.rint(x[:halfsize]).astype( - np.uint32), minlength=int(maxval + 2)) + histo = np.bincount(np.rint(x[:halfsize]).astype(np.uint32), minlength=int(maxval + 2)) # find the mode of the first half kernel mode = np.argmax(histo) @@ -168,24 +166,22 @@ def plot_onetrace(dff, fc): frames = np.arange(r[0], r[1]) ax = plt.subplot(len(qs), 1, qi + 1) - ax.plot(frames, dff[r[0]:r[1]], 'g') + ax.plot(frames, dff[r[0] : r[1]], "g") ax.set_ylim(dff_min, dff_max) ax.set_xlim(r[0], r[1]) - ax.set_xlabel('frames', fontsize=18) - ax.set_ylabel('DF/F', fontsize=18, color='g') + ax.set_xlabel("frames", fontsize=18) + ax.set_ylabel("DF/F", fontsize=18, color="g") ax = ax.twinx() - ax.plot(frames, fc[r[0]:r[1]], 'b') + ax.plot(frames, fc[r[0] : r[1]], "b") ax.set_ylim(fc_min, fc_max) ax.set_xlim(r[0], r[1]) - ax.set_ylabel('FC', fontsize=18, color='b') + ax.set_ylabel("FC", fontsize=18, color="b") return 0 -def compute_dff_windowed_mode(traces, - mode_kernelsize=5400, - mean_kernelsize=3000): +def compute_dff_windowed_mode(traces, mode_kernelsize=5400, mean_kernelsize=3000): """Compute dF/F of a set of traces using a low-pass windowed-mode operator. The operation is basically: @@ -219,8 +215,7 @@ def compute_dff_windowed_mode(traces, if mode_kernelsize == 0 or mean_kernelsize == 0: raise ValueError("Kernel length is 0!") - logging.debug("trace matrix shape: %d %d" % - (traces.shape[0], traces.shape[1])) + logging.debug("trace matrix shape: %d %d" % (traces.shape[0], traces.shape[1])) modeline = np.zeros(traces.shape[1]) modelineLP = np.zeros(traces.shape[1]) @@ -230,8 +225,7 @@ def compute_dff_windowed_mode(traces, for n in range(0, traces.shape[0]): if np.any(np.isnan(traces[n])): - logging.warning( - "trace for roi %d contains NaNs, setting to NaN", n) + logging.warning("trace for roi %d contains NaNs, setting to NaN", n) dff[n, :] = np.nan continue @@ -244,12 +238,9 @@ def compute_dff_windowed_mode(traces, return dff -def compute_dff_windowed_median(traces, - median_kernel_long=5401, - median_kernel_short=101, - noise_stds=None, - n_small_baseline_frames=None, - **kwargs): +def compute_dff_windowed_median( + traces, median_kernel_long=5401, median_kernel_short=101, noise_stds=None, n_small_baseline_frames=None, **kwargs +): """Compute dF/F of a set of traces with median filter detrending. The operation is basically: @@ -294,7 +285,7 @@ def compute_dff_windowed_median(traces, sigma_f = noise_std(dff, **kwargs) # long timescale median filter for baseline subtraction - tf = median_filter(dff, median_kernel_long, mode='constant') + tf = median_filter(dff, median_kernel_long, mode="constant") dff -= tf dff /= np.maximum(tf, sigma_f) @@ -306,8 +297,8 @@ def compute_dff_windowed_median(traces, noise_stds.append(sigma_dff) # short timescale detrending - tf = median_filter(dff, median_kernel_short, mode='constant') - tf = np.minimum(tf, 2.5*sigma_dff) + tf = median_filter(dff, median_kernel_short, mode="constant") + tf = np.minimum(tf, 2.5 * sigma_dff) dff -= tf return dff_traces @@ -315,23 +306,24 @@ def compute_dff_windowed_median(traces, def _check_kernel(kernel_size, data_size): if kernel_size % 2 == 0 or kernel_size <= 0 or kernel_size >= data_size: - raise ValueError("Invalid kernel length {} for data length {}. Kernel " - "length must be positive and odd, and less than data " - "length.".format(kernel_size, data_size)) + raise ValueError( + "Invalid kernel length {} for data length {}. Kernel " + "length must be positive and odd, and less than data " + "length.".format(kernel_size, data_size) + ) -def noise_std(x, noise_kernel_length=31, positive_peak_scale=1.5, - outlier_std_scale=2.5): +def noise_std(x, noise_kernel_length=31, positive_peak_scale=1.5, outlier_std_scale=2.5): """Robust estimate of the standard deviation of the trace noise.""" _check_kernel(noise_kernel_length, len(x)) if any(np.isnan(x)): return np.nan - x = x - median_filter(x, noise_kernel_length, mode='constant') + x = x - median_filter(x, noise_kernel_length, mode="constant") # first pass removing big pos peak outliers - x = x[x < positive_peak_scale*np.abs(x.min())] + x = x[x < positive_peak_scale * np.abs(x.min())] rstd = robust_std(x) # second pass removing remaining pos and neg peak outliers - x = x[abs(x) < outlier_std_scale*rstd] + x = x[abs(x) < outlier_std_scale * rstd] return robust_std(x) @@ -342,7 +334,7 @@ def robust_std(x): deviation of x. """ median_absolute_deviation = np.median(np.abs(x - np.median(x))) - return GAUSSIAN_MAD_STD_SCALE*median_absolute_deviation + return GAUSSIAN_MAD_STD_SCALE * median_absolute_deviation def calculate_dff(traces, dff_computation_cb=None, save_plot_dir=None): @@ -382,9 +374,8 @@ def calculate_dff(traces, dff_computation_cb=None, save_plot_dir=None): fig = plt.figure(figsize=(150, 40)) plot_onetrace(dff[n, :], traces[n, :]) - plt.title('ROI ' + str(n) + ' ', fontsize=18) - fig.savefig(os.path.join(save_plot_dir, 'dff_%d.png' % - n), orientation='landscape') + plt.title("ROI " + str(n) + " ", fontsize=18) + fig.savefig(os.path.join(save_plot_dir, "dff_%d.png" % n), orientation="landscape") plt.close(fig) return dff @@ -403,8 +394,7 @@ def main(): # read from "data" if args.input_h5.endswith("nwb"): - timestamps, traces = BrainObservatoryNwbDataSet( - args.input_h5).get_corrected_fluorescence_traces() + timestamps, traces = BrainObservatoryNwbDataSet(args.input_h5).get_corrected_fluorescence_traces() else: input_h5 = h5py.File(args.input_h5, "r") traces = input_h5["data"][()] diff --git a/allensdk/brain_observatory/drifting_gratings.py b/allensdk/brain_observatory/drifting_gratings.py index c840a9d0df..0656d414c4 100644 --- a/allensdk/brain_observatory/drifting_gratings.py +++ b/allensdk/brain_observatory/drifting_gratings.py @@ -47,14 +47,14 @@ class DriftingGratings(StimulusAnalysis): - """ Perform tuning analysis specific to drifting gratings stimulus. + """Perform tuning analysis specific to drifting gratings stimulus. Parameters ---------- data_set: BrainObservatoryNwbDataSet object """ - _log = logging.getLogger('allensdk.brain_observatory.drifting_gratings') + _log = logging.getLogger("allensdk.brain_observatory.drifting_gratings") def __init__(self, data_set, **kwargs): super(DriftingGratings, self).__init__(data_set, **kwargs) @@ -97,16 +97,15 @@ def number_tf(self): return self._number_tf def populate_stimulus_table(self): - stimulus_table = self.data_set.get_stimulus_table('drifting_gratings') - self._stim_table = stimulus_table.fillna(value=0.) + stimulus_table = self.data_set.get_stimulus_table("drifting_gratings") + self._stim_table = stimulus_table.fillna(value=0.0) self._orivals = np.unique(self.stim_table.orientation).astype(int) - self._tfvals = np.unique(self.stim_table.temporal_frequency).astype( - int) + self._tfvals = np.unique(self.stim_table.temporal_frequency).astype(int) self._number_ori = len(self.orivals) self._number_tf = len(self.tfvals) def get_response(self): - ''' Computes the mean response for each cell to each stimulus + """Computes the mean response for each cell to each stimulus condition. Return is a (# orientations, # temporal frequencies, # cells, 3) np.ndarray. The final dimension @@ -119,11 +118,10 @@ def get_response(self): Returns ------- Numpy array storing mean responses. - ''' + """ DriftingGratings._log.info("Calculating mean responses") - response = np.empty( - (self.number_ori, self.number_tf, self.numbercells + 1, 3)) + response = np.empty((self.number_ori, self.number_tf, self.numbercells + 1, 3)) def ptest(x): if x.empty: @@ -135,20 +133,18 @@ def ptest(x): for tf in self.tfvals: tf_pt = np.where(self.tfvals == tf)[0][0] subset_response = self.mean_sweep_response[ - (self.stim_table.temporal_frequency == tf) & ( - self.stim_table.orientation == ori)] + (self.stim_table.temporal_frequency == tf) & (self.stim_table.orientation == ori) + ] subset_pval = self.pval[ - (self.stim_table.temporal_frequency == tf) & ( - self.stim_table.orientation == ori)] + (self.stim_table.temporal_frequency == tf) & (self.stim_table.orientation == ori) + ] response[ori_pt, tf_pt, :, 0] = subset_response.mean(axis=0) - response[ori_pt, tf_pt, :, 1] = subset_response.std( - axis=0) / sqrt(len(subset_response)) - response[ori_pt, tf_pt, :, 2] = subset_pval.apply( - ptest, axis=0) + response[ori_pt, tf_pt, :, 1] = subset_response.std(axis=0) / sqrt(len(subset_response)) + response[ori_pt, tf_pt, :, 2] = subset_pval.apply(ptest, axis=0) return response def get_peak(self): - ''' Computes metrics related to each cell's peak response condition. + """Computes metrics related to each cell's peak response condition. Returns ------- @@ -164,22 +160,32 @@ def get_peak(self): * p_run_dg * run_modulation_dg * cv_dg (circular variance) - ''' - DriftingGratings._log.info('Calculating peak response properties') - - peak = pd.DataFrame(index=range(self.numbercells), - columns=('ori_dg', 'tf_dg', 'reliability_dg', - 'osi_dg', 'dsi_dg', 'peak_dff_dg', - 'ptest_dg', 'p_run_dg', - 'run_modulation_dg', - 'cv_os_dg', 'cv_ds_dg', 'tf_index_dg', - 'cell_specimen_id')) + """ + DriftingGratings._log.info("Calculating peak response properties") + + peak = pd.DataFrame( + index=range(self.numbercells), + columns=( + "ori_dg", + "tf_dg", + "reliability_dg", + "osi_dg", + "dsi_dg", + "peak_dff_dg", + "ptest_dg", + "p_run_dg", + "run_modulation_dg", + "cv_os_dg", + "cv_ds_dg", + "tf_index_dg", + "cell_specimen_id", + ), + ) cids = self.data_set.get_cell_specimen_ids() orivals_rad = np.deg2rad(self.orivals) for nc in range(self.numbercells): - cell_peak = np.where(self.response[:, 1:, nc, 0] == np.nanmax( - self.response[:, 1:, nc, 0])) + cell_peak = np.where(self.response[:, 1:, nc, 0] == np.nanmax(self.response[:, 1:, nc, 0])) prefori = cell_peak[0][0] preftf = cell_peak[1][0] + 1 peak.cell_specimen_id.iloc[nc] = cids[nc] @@ -198,8 +204,8 @@ def get_peak(self): CV_top_os = np.empty((8), dtype=np.complex128) CV_top_ds = np.empty((8), dtype=np.complex128) for i in range(8): - CV_top_os[i] = (tuning[i] * np.exp(1j * 2 * orivals_rad[i])) - CV_top_ds[i] = (tuning[i] * np.exp(1j * orivals_rad[i])) + CV_top_os[i] = tuning[i] * np.exp(1j * 2 * orivals_rad[i]) + CV_top_ds[i] = tuning[i] * np.exp(1j * orivals_rad[i]) peak.cv_os_dg.iloc[nc] = np.abs(CV_top_os.sum()) / tuning.sum() peak.cv_ds_dg.iloc[nc] = np.abs(CV_top_ds.sum()) / tuning.sum() @@ -212,39 +218,34 @@ def get_peak(self): for tf in self.tfvals[1:]: groups.append( self.mean_sweep_response[ - (self.stim_table.temporal_frequency == tf) & - (self.stim_table.orientation == ori)][str(nc)]) - groups.append(self.mean_sweep_response[ - self.stim_table.temporal_frequency == 0][ - str(nc)]) + (self.stim_table.temporal_frequency == tf) & (self.stim_table.orientation == ori) + ][str(nc)] + ) + groups.append(self.mean_sweep_response[self.stim_table.temporal_frequency == 0][str(nc)]) _, p = st.f_oneway(*groups) peak.ptest_dg.iloc[nc] = p subset = self.mean_sweep_response[ - (self.stim_table.temporal_frequency == self.tfvals[preftf]) & - (self.stim_table.orientation == self.orivals[prefori])] + (self.stim_table.temporal_frequency == self.tfvals[preftf]) + & (self.stim_table.orientation == self.orivals[prefori]) + ] # running modulation subset_stat = subset[subset.dx < 1] subset_run = subset[subset.dx >= 1] if (len(subset_run) > 2) & (len(subset_stat) > 2): - (_, peak.p_run_dg.iloc[nc]) = st.ttest_ind(subset_run[str(nc)], - subset_stat[ - str(nc)], - equal_var=False) + (_, peak.p_run_dg.iloc[nc]) = st.ttest_ind(subset_run[str(nc)], subset_stat[str(nc)], equal_var=False) if subset_run[str(nc)].mean() > subset_stat[str(nc)].mean(): - peak.run_modulation_dg.iloc[nc] = (subset_run[ - str(nc)].mean() - - subset_stat[ - str(nc)].mean()) \ - / np.abs( - subset_run[str(nc)].mean()) + peak.run_modulation_dg.iloc[nc] = ( + subset_run[str(nc)].mean() - subset_stat[str(nc)].mean() + ) / np.abs(subset_run[str(nc)].mean()) elif subset_run[str(nc)].mean() < subset_stat[str(nc)].mean(): - peak.run_modulation_dg.iloc[nc] = \ - (-1 * (subset_stat[str(nc)].mean() - - subset_run[str(nc)].mean()) / - np.abs(subset_stat[str(nc)].mean())) + peak.run_modulation_dg.iloc[nc] = ( + -1 + * (subset_stat[str(nc)].mean() - subset_run[str(nc)].mean()) + / np.abs(subset_stat[str(nc)].mean()) + ) else: peak.p_run_dg.iloc[nc] = np.nan @@ -252,13 +253,13 @@ def get_peak(self): # reliability subset = self.sweep_response[ - (self.stim_table.temporal_frequency == self.tfvals[preftf]) & - (self.stim_table.orientation == self.orivals[prefori])] + (self.stim_table.temporal_frequency == self.tfvals[preftf]) + & (self.stim_table.orientation == self.orivals[prefori]) + ] corr_matrix = np.empty((len(subset), len(subset))) for i in range(len(subset)): for j in range(len(subset)): - r, p = st.pearsonr(subset[str(nc)].iloc[i][30:90], - subset[str(nc)].iloc[j][30:90]) + r, p = st.pearsonr(subset[str(nc)].iloc[i][30:90], subset[str(nc)].iloc[j][30:90]) corr_matrix[i, j] = r mask = np.ones((len(subset), len(subset))) for i in range(len(subset)): @@ -271,23 +272,19 @@ def get_peak(self): # TF index tf_tuning = self.response[prefori, 1:, nc, 0] trials = self.mean_sweep_response[ - (self.stim_table.temporal_frequency != 0) & - (self.stim_table.orientation == self.orivals[prefori]) + (self.stim_table.temporal_frequency != 0) & (self.stim_table.orientation == self.orivals[prefori]) ][str(nc)].values - SSE_part = np.sqrt( - np.sum((trials - trials.mean()) ** 2) / (len(trials) - 5)) - peak.tf_index_dg.iloc[nc] = (np.ptp(tf_tuning)) / ( - np.ptp(tf_tuning) + 2 * SSE_part) + SSE_part = np.sqrt(np.sum((trials - trials.mean()) ** 2) / (len(trials) - 5)) + peak.tf_index_dg.iloc[nc] = (np.ptp(tf_tuning)) / (np.ptp(tf_tuning) + 2 * SSE_part) return peak - def open_star_plot(self, cell_specimen_id=None, include_labels=False, - cell_index=None): + def open_star_plot(self, cell_specimen_id=None, include_labels=False, cell_index=None): cell_index = self.row_from_cell_id(cell_specimen_id, cell_index) df = self.mean_sweep_response[str(cell_index)] - st = self.data_set.get_stimulus_table('drifting_gratings') - mask = st.dropna(subset=['orientation']).index + st = self.data_set.get_stimulus_table("drifting_gratings") + mask = st.dropna(subset=["orientation"]).index data = df.values @@ -295,264 +292,234 @@ def open_star_plot(self, cell_specimen_id=None, include_labels=False, cmax = max(cmin, data.mean() + data.std() * 3) fp = cplots.FanPlotter.for_drifting_gratings() - fp.plot(r_data=st.temporal_frequency.loc[mask].values, - angle_data=st.orientation.loc[mask].values, - data=df.loc[mask].values, - clim=[cmin, cmax]) + fp.plot( + r_data=st.temporal_frequency.loc[mask].values, + angle_data=st.orientation.loc[mask].values, + data=df.loc[mask].values, + clim=[cmin, cmax], + ) fp.show_axes(closed=True) if include_labels: fp.show_r_labels() fp.show_angle_labels() - def plot_orientation_selectivity(self, - si_range=oplots.SI_RANGE, - n_hist_bins=oplots.N_HIST_BINS, - color=oplots.STIM_COLOR, - p_value_max=oplots.P_VALUE_MAX, - peak_dff_min=oplots.PEAK_DFF_MIN): + def plot_orientation_selectivity( + self, + si_range=oplots.SI_RANGE, + n_hist_bins=oplots.N_HIST_BINS, + color=oplots.STIM_COLOR, + p_value_max=oplots.P_VALUE_MAX, + peak_dff_min=oplots.PEAK_DFF_MIN, + ): # responsive cells - vis_cells = (self.peak.ptest_dg < p_value_max) & ( - self.peak.peak_dff_dg > peak_dff_min) + vis_cells = (self.peak.ptest_dg < p_value_max) & (self.peak.peak_dff_dg > peak_dff_min) # orientation selective cells - osi_cells = vis_cells & (self.peak.osi_dg > si_range[0]) & ( - self.peak.osi_dg < si_range[1]) + osi_cells = vis_cells & (self.peak.osi_dg > si_range[0]) & (self.peak.osi_dg < si_range[1]) peak_osi = self.peak.loc[osi_cells] osis = peak_osi.osi_dg.values - oplots.plot_selectivity_cumulative_histogram(osis, - "orientation " - "selectivity index", - si_range=si_range, - n_hist_bins=n_hist_bins, - color=color) - - def plot_direction_selectivity(self, - si_range=oplots.SI_RANGE, - n_hist_bins=oplots.N_HIST_BINS, - color=oplots.STIM_COLOR, - p_value_max=oplots.P_VALUE_MAX, - peak_dff_min=oplots.PEAK_DFF_MIN): - + oplots.plot_selectivity_cumulative_histogram( + osis, "orientation selectivity index", si_range=si_range, n_hist_bins=n_hist_bins, color=color + ) + + def plot_direction_selectivity( + self, + si_range=oplots.SI_RANGE, + n_hist_bins=oplots.N_HIST_BINS, + color=oplots.STIM_COLOR, + p_value_max=oplots.P_VALUE_MAX, + peak_dff_min=oplots.PEAK_DFF_MIN, + ): # responsive cells - vis_cells = (self.peak.ptest_dg < p_value_max) & ( - self.peak.peak_dff_dg > peak_dff_min) + vis_cells = (self.peak.ptest_dg < p_value_max) & (self.peak.peak_dff_dg > peak_dff_min) # direction selective cells - dsi_cells = vis_cells & (self.peak.dsi_dg > si_range[0]) & ( - self.peak.dsi_dg < si_range[1]) + dsi_cells = vis_cells & (self.peak.dsi_dg > si_range[0]) & (self.peak.dsi_dg < si_range[1]) peak_dsi = self.peak.loc[dsi_cells] dsis = peak_dsi.dsi_dg.values - oplots.plot_selectivity_cumulative_histogram(dsis, - "direction selectivity " - "index", - si_range=si_range, - n_hist_bins=n_hist_bins, - color=color) - - def plot_preferred_direction(self, - include_labels=False, - si_range=oplots.SI_RANGE, - color=oplots.STIM_COLOR, - p_value_max=oplots.P_VALUE_MAX, - peak_dff_min=oplots.PEAK_DFF_MIN): - vis_cells = (self.peak.ptest_dg < p_value_max) & ( - self.peak.peak_dff_dg > peak_dff_min) + oplots.plot_selectivity_cumulative_histogram( + dsis, "direction selectivity index", si_range=si_range, n_hist_bins=n_hist_bins, color=color + ) + + def plot_preferred_direction( + self, + include_labels=False, + si_range=oplots.SI_RANGE, + color=oplots.STIM_COLOR, + p_value_max=oplots.P_VALUE_MAX, + peak_dff_min=oplots.PEAK_DFF_MIN, + ): + vis_cells = (self.peak.ptest_dg < p_value_max) & (self.peak.peak_dff_dg > peak_dff_min) pref_dirs = self.peak.loc[vis_cells].ori_dg.values pref_dirs = [self.orivals[pref_dir] for pref_dir in pref_dirs] angles, counts = np.unique(pref_dirs, return_counts=True) - oplots.plot_radial_histogram(angles, - counts, - include_labels=include_labels, - all_angles=self.orivals, - direction=-1, - offset=0.0, - closed=True, - color=color) - - def plot_preferred_temporal_frequency(self, - si_range=oplots.SI_RANGE, - color=oplots.STIM_COLOR, - p_value_max=oplots.P_VALUE_MAX, - peak_dff_min=oplots.PEAK_DFF_MIN): - - vis_cells = (self.peak.ptest_dg < p_value_max) & ( - self.peak.peak_dff_dg > peak_dff_min) + oplots.plot_radial_histogram( + angles, + counts, + include_labels=include_labels, + all_angles=self.orivals, + direction=-1, + offset=0.0, + closed=True, + color=color, + ) + + def plot_preferred_temporal_frequency( + self, + si_range=oplots.SI_RANGE, + color=oplots.STIM_COLOR, + p_value_max=oplots.P_VALUE_MAX, + peak_dff_min=oplots.PEAK_DFF_MIN, + ): + vis_cells = (self.peak.ptest_dg < p_value_max) & (self.peak.peak_dff_dg > peak_dff_min) pref_tfs = self.peak.loc[vis_cells].tf_dg.values - oplots.plot_condition_histogram(pref_tfs, - self.tfvals[1:], - color=color) + oplots.plot_condition_histogram(pref_tfs, self.tfvals[1:], color=color) plt.xlabel("temporal frequency (Hz)") plt.ylabel("number of cells") def reshape_response_array(self): - ''' + """ :return: response array in cells x stim x repetition for noise correlations - ''' + """ - mean_sweep_response = \ - self.mean_sweep_response.values[:, :self.numbercells] + mean_sweep_response = self.mean_sweep_response.values[:, : self.numbercells] stim_table = self.stim_table tfvals = self.tfvals tfvals = tfvals[tfvals != 0] # blank sweep - response_new = np.zeros( - (self.numbercells, self.number_ori, self.number_tf - 1), - dtype='object') + response_new = np.zeros((self.numbercells, self.number_ori, self.number_tf - 1), dtype="object") for i, ori in enumerate(self.orivals): for j, tf in enumerate(tfvals): - ind = (stim_table.orientation.values == ori) * ( - stim_table.temporal_frequency.values == tf) + ind = (stim_table.orientation.values == ori) * (stim_table.temporal_frequency.values == tf) for c in range(self.numbercells): response_new[c, i, j] = mean_sweep_response[ind, c] - ind = (stim_table.temporal_frequency.values == 0) + ind = stim_table.temporal_frequency.values == 0 response_blank = mean_sweep_response[ind, :].T return response_new, response_blank - def get_signal_correlation(self, corr='spearman'): + def get_signal_correlation(self, corr="spearman"): logging.debug("Calculating signal correlation") # orientation x freq x cell, no blank - response = \ - self.response[:, 1:, :self.numbercells, 0] + response = self.response[:, 1:, : self.numbercells, 0] - response = response.reshape(self.number_ori * (self.number_tf - 1), - self.numbercells).T + response = response.reshape(self.number_ori * (self.number_tf - 1), self.numbercells).T N, Nstim = response.shape signal_corr = np.zeros((N, N)) signal_p = np.empty((N, N)) - if corr == 'pearson': + if corr == "pearson": for i in range(N): for j in range(i, N): # matrix is symmetric - signal_corr[i, j], signal_p[i, j] = st.pearsonr( - response[i], response[j]) + signal_corr[i, j], signal_p[i, j] = st.pearsonr(response[i], response[j]) - elif corr == 'spearman': + elif corr == "spearman": for i in range(N): for j in range(i, N): # matrix is symmetric - signal_corr[i, j], signal_p[i, j] = st.spearmanr( - response[i], response[j]) + signal_corr[i, j], signal_p[i, j] = st.spearmanr(response[i], response[j]) else: - raise Exception('correlation should be pearson or spearman') + raise Exception("correlation should be pearson or spearman") # fill in lower triangle - signal_corr = \ - np.triu(signal_corr) + \ - np.triu(signal_corr, 1).T + signal_corr = np.triu(signal_corr) + np.triu(signal_corr, 1).T # fill in lower triangle - signal_p = \ - np.triu(signal_p) + \ - np.triu(signal_p, 1).T + signal_p = np.triu(signal_p) + np.triu(signal_p, 1).T return signal_corr, signal_p - def get_representational_similarity(self, corr='spearman'): + def get_representational_similarity(self, corr="spearman"): logging.debug("Calculating representational similarity") # orientation x freq x phase x cell, no blank - response = self.response[:, 1:, :self.numbercells, 0] + response = self.response[:, 1:, : self.numbercells, 0] - response = response.reshape(self.number_ori * (self.number_tf - 1), - self.numbercells) + response = response.reshape(self.number_ori * (self.number_tf - 1), self.numbercells) # TODO 25 lines of repeated code!!!!!!!! Nstim, N = response.shape rep_sim = np.zeros((Nstim, Nstim)) rep_sim_p = np.empty((Nstim, Nstim)) - if corr == 'pearson': + if corr == "pearson": for i in range(Nstim): for j in range(i, Nstim): # matrix is symmetric - rep_sim[i, j], rep_sim_p[i, j] = st.pearsonr(response[i], - response[j]) + rep_sim[i, j], rep_sim_p[i, j] = st.pearsonr(response[i], response[j]) - elif corr == 'spearman': + elif corr == "spearman": for i in range(Nstim): for j in range(i, Nstim): # matrix is symmetric - rep_sim[i, j], rep_sim_p[i, j] = st.spearmanr(response[i], - response[j]) + rep_sim[i, j], rep_sim_p[i, j] = st.spearmanr(response[i], response[j]) else: - raise Exception('correlation should be pearson or spearman') + raise Exception("correlation should be pearson or spearman") - rep_sim = np.triu(rep_sim) + np.triu(rep_sim, - 1).T # fill in lower triangle - rep_sim_p = np.triu(rep_sim_p) + np.triu(rep_sim_p, - 1).T # fill in lower triangle + rep_sim = np.triu(rep_sim) + np.triu(rep_sim, 1).T # fill in lower triangle + rep_sim_p = np.triu(rep_sim_p) + np.triu(rep_sim_p, 1).T # fill in lower triangle return rep_sim, rep_sim_p - def get_noise_correlation(self, corr='spearman'): + def get_noise_correlation(self, corr="spearman"): logging.debug("Calculating noise correlations") response, response_blank = self.reshape_response_array() - noise_corr = np.zeros((self.numbercells, self.numbercells, - self.number_ori, self.number_tf - 1)) - noise_corr_p = np.zeros((self.numbercells, self.numbercells, - self.number_ori, self.number_tf - 1)) + noise_corr = np.zeros((self.numbercells, self.numbercells, self.number_ori, self.number_tf - 1)) + noise_corr_p = np.zeros((self.numbercells, self.numbercells, self.number_ori, self.number_tf - 1)) noise_corr_blank = np.zeros((self.numbercells, self.numbercells)) noise_corr_blank_p = np.zeros((self.numbercells, self.numbercells)) - if corr == 'pearson': + if corr == "pearson": for k in range(self.number_ori): - for l in range(self.number_tf - 1): # noqa E741 + for l in range(self.number_tf - 1): # noqa E741 for i in range(self.numbercells): for j in range(i, self.numbercells): - noise_corr[i, j, k, l], noise_corr_p[ - i, j, k, l] = st.pearsonr(response[i, k, l], - response[j, k, l]) + noise_corr[i, j, k, l], noise_corr_p[i, j, k, l] = st.pearsonr( + response[i, k, l], response[j, k, l] + ) - noise_corr[:, :, k, l] = np.triu( - noise_corr[:, :, k, l]) + np.triu( - noise_corr[:, :, k, l], 1).T + noise_corr[:, :, k, l] = np.triu(noise_corr[:, :, k, l]) + np.triu(noise_corr[:, :, k, l], 1).T for i in range(self.numbercells): for j in range(i, self.numbercells): - noise_corr_blank[i, j], noise_corr_blank_p[ - i, j] = st.pearsonr(response_blank[i], - response_blank[j]) + noise_corr_blank[i, j], noise_corr_blank_p[i, j] = st.pearsonr(response_blank[i], response_blank[j]) - elif corr == 'spearman': + elif corr == "spearman": for k in range(self.number_ori): - for l in range(self.number_tf - 1): # noqa E741 + for l in range(self.number_tf - 1): # noqa E741 for i in range(self.numbercells): for j in range(i, self.numbercells): - noise_corr[i, j, k, l], noise_corr_p[ - i, j, k, l] = st.spearmanr(response[i, k, l], - response[j, k, l]) + noise_corr[i, j, k, l], noise_corr_p[i, j, k, l] = st.spearmanr( + response[i, k, l], response[j, k, l] + ) - noise_corr[:, :, k, l] = np.triu( - noise_corr[:, :, k, l]) + np.triu( - noise_corr[:, :, k, l], 1).T + noise_corr[:, :, k, l] = np.triu(noise_corr[:, :, k, l]) + np.triu(noise_corr[:, :, k, l], 1).T for i in range(self.numbercells): for j in range(i, self.numbercells): - noise_corr_blank[i, j], noise_corr_blank_p[ - i, j] = st.spearmanr(response_blank[i], - response_blank[j]) + noise_corr_blank[i, j], noise_corr_blank_p[i, j] = st.spearmanr( + response_blank[i], response_blank[j] + ) else: - raise Exception('correlation should be pearson or spearman') + raise Exception("correlation should be pearson or spearman") - noise_corr_blank[:, :] = np.triu(noise_corr_blank[:, :]) + np.triu( - noise_corr_blank[:, :], 1).T + noise_corr_blank[:, :] = np.triu(noise_corr_blank[:, :]) + np.triu(noise_corr_blank[:, :], 1).T return noise_corr, noise_corr_p, noise_corr_blank, noise_corr_blank_p @@ -563,10 +530,8 @@ def from_analysis_file(data_set, analysis_file): try: dg.populate_stimulus_table() - dg._sweep_response = pd.read_hdf(analysis_file, - "analysis/sweep_response_dg") - dg._mean_sweep_response = pd.read_hdf( - analysis_file, "analysis/mean_sweep_response_dg") + dg._sweep_response = pd.read_hdf(analysis_file, "analysis/sweep_response_dg") + dg._mean_sweep_response = pd.read_hdf(analysis_file, "analysis/mean_sweep_response_dg") dg._peak = pd.read_hdf(analysis_file, "analysis/peak") with h5py.File(analysis_file, "r") as f: @@ -580,8 +545,7 @@ def from_analysis_file(data_set, analysis_file): if "analysis/signal_corr_dg" in f: dg.signal_correlation = f["analysis/signal_corr_dg"][()] if "analysis/rep_similarity_dg" in f: - dg.representational_similarity = f[ - "analysis/rep_similarity_dg"][()] + dg.representational_similarity = f["analysis/rep_similarity_dg"][()] except Exception as e: raise MissingStimulusException(e.args) diff --git a/allensdk/brain_observatory/ecephys/__init__.py b/allensdk/brain_observatory/ecephys/__init__.py index 8e131952a3..b9c96a41e8 100644 --- a/allensdk/brain_observatory/ecephys/__init__.py +++ b/allensdk/brain_observatory/ecephys/__init__.py @@ -2,18 +2,9 @@ UNIT_FILTER_DEFAULTS = { - "amplitude_cutoff_maximum": { - "value": 0.1, - "missing": np.inf - }, - "presence_ratio_minimum": { - "value": 0.95, - "missing": -np.inf - }, - "isi_violations_maximum": { - "value": 0.5, - "missing": np.inf - } + "amplitude_cutoff_maximum": {"value": 0.1, "missing": np.inf}, + "presence_ratio_minimum": {"value": 0.95, "missing": -np.inf}, + "isi_violations_maximum": {"value": 0.5, "missing": np.inf}, } diff --git a/allensdk/brain_observatory/ecephys/_behavior_ecephys_metadata.py b/allensdk/brain_observatory/ecephys/_behavior_ecephys_metadata.py index 066b4407e1..225ea4a9b7 100644 --- a/allensdk/brain_observatory/ecephys/_behavior_ecephys_metadata.py +++ b/allensdk/brain_observatory/ecephys/_behavior_ecephys_metadata.py @@ -1,46 +1,37 @@ from pynwb import NWBFile from allensdk.brain_observatory.behavior.data_objects import BehaviorSessionId -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .behavior_metadata.behavior_metadata import \ - BehaviorMetadata -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .behavior_metadata.behavior_session_uuid import \ - BehaviorSessionUUID -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .behavior_metadata.date_of_acquisition import \ - DateOfAcquisition -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .behavior_metadata.equipment import \ - Equipment -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .behavior_metadata.session_type import \ - SessionType -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .behavior_metadata.project_code import \ - ProjectCode -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .behavior_metadata.stimulus_frame_rate import \ - StimulusFrameRate -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .subject_metadata.subject_metadata import \ - SubjectMetadata +from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.behavior_metadata import ( + BehaviorMetadata, +) +from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.behavior_session_uuid import ( + BehaviorSessionUUID, +) +from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.date_of_acquisition import ( + DateOfAcquisition, +) +from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.equipment import Equipment +from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.session_type import SessionType +from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.project_code import ProjectCode +from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.stimulus_frame_rate import ( + StimulusFrameRate, +) +from allensdk.brain_observatory.behavior.data_objects.metadata.subject_metadata.subject_metadata import SubjectMetadata from allensdk.core import JsonReadableInterface, NwbReadableInterface -class BehaviorEcephysMetadata(BehaviorMetadata, JsonReadableInterface, - NwbReadableInterface): +class BehaviorEcephysMetadata(BehaviorMetadata, JsonReadableInterface, NwbReadableInterface): def __init__( - self, - ecephys_session_id: int, - date_of_acquisition: DateOfAcquisition, - subject_metadata: SubjectMetadata, - behavior_session_id: BehaviorSessionId, - behavior_session_uuid: BehaviorSessionUUID, - equipment: Equipment, - session_type: SessionType, - stimulus_frame_rate: StimulusFrameRate, - project_code: ProjectCode = ProjectCode(), + self, + ecephys_session_id: int, + date_of_acquisition: DateOfAcquisition, + subject_metadata: SubjectMetadata, + behavior_session_id: BehaviorSessionId, + behavior_session_uuid: BehaviorSessionUUID, + equipment: Equipment, + session_type: SessionType, + stimulus_frame_rate: StimulusFrameRate, + project_code: ProjectCode = ProjectCode(), ): super().__init__( date_of_acquisition=date_of_acquisition, @@ -62,16 +53,15 @@ def ecephys_session_id(self) -> int: def from_json(cls, dict_repr: dict) -> "BehaviorEcephysMetadata": behavior_metadata = super().from_json(dict_repr=dict_repr) return BehaviorEcephysMetadata( - ecephys_session_id=dict_repr['ecephys_session_id'], - date_of_acquisition=DateOfAcquisition( - behavior_metadata.date_of_acquisition), + ecephys_session_id=dict_repr["ecephys_session_id"], + date_of_acquisition=DateOfAcquisition(behavior_metadata.date_of_acquisition), subject_metadata=behavior_metadata.subject_metadata, behavior_session_id=behavior_metadata._behavior_session_id, behavior_session_uuid=behavior_metadata._behavior_session_uuid, equipment=behavior_metadata.equipment, session_type=behavior_metadata._session_type, project_code=behavior_metadata._project_code, - stimulus_frame_rate=behavior_metadata._stimulus_frame_rate + stimulus_frame_rate=behavior_metadata._stimulus_frame_rate, ) @classmethod @@ -79,8 +69,7 @@ def from_nwb(cls, nwbfile: NWBFile) -> "BehaviorEcephysMetadata": behavior_metadata = super().from_nwb(nwbfile=nwbfile) return BehaviorEcephysMetadata( ecephys_session_id=int(nwbfile.identifier), - date_of_acquisition=DateOfAcquisition( - behavior_metadata.date_of_acquisition), + date_of_acquisition=DateOfAcquisition(behavior_metadata.date_of_acquisition), behavior_session_id=behavior_metadata._behavior_session_id, behavior_session_uuid=behavior_metadata._behavior_session_uuid, equipment=behavior_metadata.equipment, diff --git a/allensdk/brain_observatory/ecephys/_channel.py b/allensdk/brain_observatory/ecephys/_channel.py index c687fd7cf0..2f0dce3b6d 100644 --- a/allensdk/brain_observatory/ecephys/_channel.py +++ b/allensdk/brain_observatory/ecephys/_channel.py @@ -2,28 +2,27 @@ import numpy as np from allensdk.core import DataObject -from allensdk.brain_observatory.ecephys.utils import ( - strip_substructure_acronym) +from allensdk.brain_observatory.ecephys.utils import strip_substructure_acronym class Channel(DataObject): """Probe channel""" + def __init__( - self, - id: int, - probe_id: int, - valid_data: bool, - probe_channel_number: int, - probe_vertical_position: int, - probe_horizontal_position: int, - structure_acronym: str = '', - anterior_posterior_ccf_coordinate: Optional[float] = None, - dorsal_ventral_ccf_coordinate: Optional[float] = None, - left_right_ccf_coordinate: Optional[float] = None, - impedance: float = np.nan, - filtering: str = 'AP band: 500 Hz high-pass; ' - 'LFP band: 1000 Hz low-pass', - strip_structure_subregion: bool = True + self, + id: int, + probe_id: int, + valid_data: bool, + probe_channel_number: int, + probe_vertical_position: int, + probe_horizontal_position: int, + structure_acronym: str = "", + anterior_posterior_ccf_coordinate: Optional[float] = None, + dorsal_ventral_ccf_coordinate: Optional[float] = None, + left_right_ccf_coordinate: Optional[float] = None, + impedance: float = np.nan, + filtering: str = "AP band: 500 Hz high-pass; LFP band: 1000 Hz low-pass", + strip_structure_subregion: bool = True, ): """ @@ -34,9 +33,7 @@ def __init__( parsed as "LGd". You might want to strip it if the subregion is beyond annotation accuracy. """ - super().__init__(name='channel', - value=None, - is_value_self=True) + super().__init__(name="channel", value=None, is_value_self=True) self._id = id self._probe_id = probe_id self._valid_data = valid_data @@ -44,8 +41,7 @@ def __init__( self._probe_vertical_position = probe_vertical_position self._probe_horizontal_position = probe_horizontal_position self._structure_acronym = structure_acronym - self._anterior_posterior_ccf_coordinate = \ - anterior_posterior_ccf_coordinate + self._anterior_posterior_ccf_coordinate = anterior_posterior_ccf_coordinate self._dorsal_ventral_ccf_coordinate = dorsal_ventral_ccf_coordinate self._left_right_ccf_coordinate = left_right_ccf_coordinate self._impedance = impedance diff --git a/allensdk/brain_observatory/ecephys/_channels.py b/allensdk/brain_observatory/ecephys/_channels.py index 74a3fba9d0..5956a9b8db 100644 --- a/allensdk/brain_observatory/ecephys/_channels.py +++ b/allensdk/brain_observatory/ecephys/_channels.py @@ -5,45 +5,40 @@ from allensdk.brain_observatory.ecephys._channel import Channel from allensdk.brain_observatory.ecephys.utils import clobbering_merge -from allensdk.core import DataObject, NwbReadableInterface, \ - JsonReadableInterface +from allensdk.core import DataObject, NwbReadableInterface, JsonReadableInterface class Channels(DataObject, NwbReadableInterface, JsonReadableInterface): """A set of channels""" def __init__(self, channels: List[Channel]): - super().__init__(name='channels', value=channels) + super().__init__(name="channels", value=channels) @classmethod - def from_json( - cls, - channels: dict - ) -> "Channels": + def from_json(cls, channels: dict) -> "Channels": for channel in channels: - if 'impedence' in channel: + if "impedence" in channel: # Correct misspelling - channel['impedance'] = channel.pop('impedence') - - channels = [Channel( - id=channel['id'], - probe_id=channel['probe_id'], - valid_data=channel['valid_data'], - probe_channel_number=channel['probe_channel_number'], - probe_vertical_position=channel['probe_vertical_position'], - probe_horizontal_position=channel['probe_horizontal_position'], - structure_acronym=channel['structure_acronym'], - anterior_posterior_ccf_coordinate=( - channel['anterior_posterior_ccf_coordinate']), - dorsal_ventral_ccf_coordinate=( - channel['dorsal_ventral_ccf_coordinate']), - left_right_ccf_coordinate=channel['left_right_ccf_coordinate'] - ) - for channel in channels] + channel["impedance"] = channel.pop("impedence") + + channels = [ + Channel( + id=channel["id"], + probe_id=channel["probe_id"], + valid_data=channel["valid_data"], + probe_channel_number=channel["probe_channel_number"], + probe_vertical_position=channel["probe_vertical_position"], + probe_horizontal_position=channel["probe_horizontal_position"], + structure_acronym=channel["structure_acronym"], + anterior_posterior_ccf_coordinate=(channel["anterior_posterior_ccf_coordinate"]), + dorsal_ventral_ccf_coordinate=(channel["dorsal_ventral_ccf_coordinate"]), + left_right_ccf_coordinate=channel["left_right_ccf_coordinate"], + ) + for channel in channels + ] return Channels(channels=channels) - def to_dataframe(self, external_channel_columns=None, - filter_by_validity=True) -> pd.DataFrame: + def to_dataframe(self, external_channel_columns=None, filter_by_validity=True) -> pd.DataFrame: """ Parameters @@ -56,24 +51,22 @@ def to_dataframe(self, external_channel_columns=None, ------- """ - channels = [channel.to_dict()['channel'] for channel in self.value] + channels = [channel.to_dict()["channel"] for channel in self.value] channels = pd.DataFrame(channels) - channels = channels.set_index('id') - channels = channels.drop(columns=['impedance']) + channels = channels.set_index("id") + channels = channels.drop(columns=["impedance"]) if external_channel_columns is not None: external_channel_columns = external_channel_columns() - channels = clobbering_merge(channels, external_channel_columns, - left_index=True, right_index=True) + channels = clobbering_merge(channels, external_channel_columns, left_index=True, right_index=True) if filter_by_validity: - channels = channels[channels['valid_data']] - channels = channels.drop(columns=['valid_data']) + channels = channels[channels["valid_data"]] + channels = channels.drop(columns=["valid_data"]) return channels @classmethod - def from_nwb(cls, nwbfile: NWBFile, - probe_id: Optional[int] = None) -> "Channels": + def from_nwb(cls, nwbfile: NWBFile, probe_id: Optional[int] = None) -> "Channels": """ Parameters @@ -88,41 +81,42 @@ def from_nwb(cls, nwbfile: NWBFile, channels = [] for channel_id, row in nwbfile.electrodes.to_dataframe().iterrows(): if probe_id is not None: - if row['probe_id'] != probe_id: + if row["probe_id"] != probe_id: continue # this block of code is necessary to maintain # backwards compatibility with Visual Coding Neuropixels # NWB files, which used 'local_index' to mean what we # now mean by 'probe_channel_number' - has_local_index = ('local_index' in row.keys()) - has_channel_number = ('probe_channel_number' in row.keys()) + has_local_index = "local_index" in row.keys() + has_channel_number = "probe_channel_number" in row.keys() if has_local_index and has_channel_number: - raise RuntimeError("Unclear how to read channel; " - "has both 'local_index' and " - "'probe_channel_number'") + raise RuntimeError("Unclear how to read channel; has both 'local_index' and 'probe_channel_number'") elif has_local_index: - idx_col = 'local_index' + idx_col = "local_index" elif has_channel_number: - idx_col = 'probe_channel_number' + idx_col = "probe_channel_number" else: - raise RuntimeError("Unclear how to read channel; " - "has neither 'local_index' nor " - "'probe_channel_number'.\n" - f"Columns are {row.keys()}") - - structure_acronym = \ - None if row['location'] in ['None', ''] else row['location'] - channels.append(Channel( - id=channel_id, - probe_channel_number=row[idx_col], - probe_horizontal_position=row['probe_horizontal_position'], - probe_vertical_position=row['probe_vertical_position'], - probe_id=row['probe_id'], - valid_data=row['valid_data'], - structure_acronym=structure_acronym, - anterior_posterior_ccf_coordinate=row['x'], - dorsal_ventral_ccf_coordinate=row['y'], - left_right_ccf_coordinate=row['z'] - )) + raise RuntimeError( + "Unclear how to read channel; " + "has neither 'local_index' nor " + "'probe_channel_number'.\n" + f"Columns are {row.keys()}" + ) + + structure_acronym = None if row["location"] in ["None", ""] else row["location"] + channels.append( + Channel( + id=channel_id, + probe_channel_number=row[idx_col], + probe_horizontal_position=row["probe_horizontal_position"], + probe_vertical_position=row["probe_vertical_position"], + probe_id=row["probe_id"], + valid_data=row["valid_data"], + structure_acronym=structure_acronym, + anterior_posterior_ccf_coordinate=row["x"], + dorsal_ventral_ccf_coordinate=row["y"], + left_right_ccf_coordinate=row["z"], + ) + ) return Channels(channels=channels) diff --git a/allensdk/brain_observatory/ecephys/_current_source_density.py b/allensdk/brain_observatory/ecephys/_current_source_density.py index b319812e67..bfb7d2207d 100644 --- a/allensdk/brain_observatory/ecephys/_current_source_density.py +++ b/allensdk/brain_observatory/ecephys/_current_source_density.py @@ -7,12 +7,8 @@ class CurrentSourceDensity(DataObject, JsonReadableInterface): """Current Source Density""" - def __init__( - self, - data: np.ndarray, - timestamps: np.ndarray, - interpolated_channel_locations: np.ndarray - ): + + def __init__(self, data: np.ndarray, timestamps: np.ndarray, interpolated_channel_locations: np.ndarray): """ Parameters @@ -24,11 +20,7 @@ def __init__( interpolated_channel_locations: Array of interpolated channel indices for CSD """ - super().__init__( - name='current_source_density', - value=None, - is_value_self=True - ) + super().__init__(name="current_source_density", value=None, is_value_self=True) self._data = data self._timestamps = timestamps self._interpolated_channel_locations = interpolated_channel_locations @@ -51,11 +43,11 @@ def channel_locations(self) -> np.ndarray: @classmethod def from_json(cls, probe_meta: dict) -> "CurrentSourceDensity": scale = probe_meta.get("scale_mean_waveform_and_csd", 1) - with h5py.File(probe_meta['csd_path'], "r") as csd_file: + with h5py.File(probe_meta["csd_path"], "r") as csd_file: return CurrentSourceDensity( data=csd_file["current_source_density"][:] / scale, timestamps=csd_file["timestamps"][:], - interpolated_channel_locations=csd_file["csd_locations"][:] + interpolated_channel_locations=csd_file["csd_locations"][:], ) def to_dataarray(self) -> DataArray: @@ -69,9 +61,7 @@ def to_dataarray(self) -> DataArray: coords={ "virtual_channel_index": np.arange(self.data.shape[0]), "time": self.timestamps, - "vertical_position": (("virtual_channel_index",), - y_locs), - "horizontal_position": (("virtual_channel_index",), - x_locs) - } + "vertical_position": (("virtual_channel_index",), y_locs), + "horizontal_position": (("virtual_channel_index",), x_locs), + }, ) diff --git a/allensdk/brain_observatory/ecephys/_lfp.py b/allensdk/brain_observatory/ecephys/_lfp.py index 1f2a00fc06..17aa96de9a 100644 --- a/allensdk/brain_observatory/ecephys/_lfp.py +++ b/allensdk/brain_observatory/ecephys/_lfp.py @@ -1,8 +1,7 @@ import numpy as np from xarray import DataArray -from allensdk.brain_observatory.ecephys.file_io.continuous_file import \ - ContinuousFile +from allensdk.brain_observatory.ecephys.file_io.continuous_file import ContinuousFile from allensdk.core import DataObject, JsonReadableInterface @@ -10,13 +9,8 @@ class LFP(DataObject, JsonReadableInterface): """ Probe LFP """ - def __init__( - self, - data: np.ndarray, - timestamps: np.ndarray, - channels: np.ndarray, - sampling_rate: float - ): + + def __init__(self, data: np.ndarray, timestamps: np.ndarray, channels: np.ndarray, sampling_rate: float): """ Parameters @@ -30,7 +24,7 @@ def __init__( sampling_rate LFP sampling rate """ - super().__init__(name='lfp', value=None, is_value_self=True) + super().__init__(name="lfp", value=None, is_value_self=True) self._data = data self._timestamps = timestamps self._channels = channels @@ -53,11 +47,7 @@ def sampling_rate(self): return self._sampling_rate @classmethod - def from_json( - cls, - probe_meta: dict, - amplitude_scale_factor: float = 0.195e-6 - ) -> "LFP": + def from_json(cls, probe_meta: dict, amplitude_scale_factor: float = 0.195e-6) -> "LFP": """ Parameters @@ -72,35 +62,23 @@ def from_json( ------- `LFP` instance """ - lfp_meta = probe_meta['lfp'] - lfp_channels = np.load(lfp_meta['input_channels_path'], - allow_pickle=False) + lfp_meta = probe_meta["lfp"] + lfp_channels = np.load(lfp_meta["input_channels_path"], allow_pickle=False) lfp_data, lfp_timestamps = ContinuousFile( - data_path=lfp_meta['input_data_path'], - timestamps_path=lfp_meta['input_timestamps_path'], - total_num_channels=len(lfp_channels) + data_path=lfp_meta["input_data_path"], + timestamps_path=lfp_meta["input_timestamps_path"], + total_num_channels=len(lfp_channels), ).load(memmap=False) lfp_data = lfp_data.astype(np.float32) - lfp_data = lfp_data * probe_meta.get("amplitude_scale_factor", - amplitude_scale_factor) - - sampling_rate = ( - probe_meta['lfp_sampling_rate'] / - probe_meta['temporal_subsampling_factor']) - - return cls( - data=lfp_data, - timestamps=lfp_timestamps, - channels=lfp_channels, - sampling_rate=sampling_rate - ) + lfp_data = lfp_data * probe_meta.get("amplitude_scale_factor", amplitude_scale_factor) + + sampling_rate = probe_meta["lfp_sampling_rate"] / probe_meta["temporal_subsampling_factor"] + + return cls(data=lfp_data, timestamps=lfp_timestamps, channels=lfp_channels, sampling_rate=sampling_rate) def to_dataarray(self) -> DataArray: return DataArray( - name="LFP", - data=self._data, - dims=['time', 'channel'], - coords=[self._timestamps, self._channels] + name="LFP", data=self._data, dims=["time", "channel"], coords=[self._timestamps, self._channels] ) diff --git a/allensdk/brain_observatory/ecephys/_probe.py b/allensdk/brain_observatory/ecephys/_probe.py index 2360b33ee1..87b27e033b 100644 --- a/allensdk/brain_observatory/ecephys/_probe.py +++ b/allensdk/brain_observatory/ecephys/_probe.py @@ -10,18 +10,14 @@ from pynwb import NWBFile from xarray import DataArray -from allensdk.brain_observatory.ecephys._behavior_ecephys_metadata import \ - BehaviorEcephysMetadata +from allensdk.brain_observatory.ecephys._behavior_ecephys_metadata import BehaviorEcephysMetadata from allensdk.brain_observatory.ecephys._channels import Channels -from allensdk.brain_observatory.ecephys._current_source_density import \ - CurrentSourceDensity +from allensdk.brain_observatory.ecephys._current_source_density import CurrentSourceDensity from allensdk.brain_observatory.ecephys._units import Units from allensdk.brain_observatory.ecephys._lfp import LFP from allensdk.brain_observatory.ecephys.nwb import EcephysCSD -from allensdk.brain_observatory.ecephys.nwb_util import add_probe_to_nwbfile, \ - add_ecephys_electrodes -from allensdk.core import DataObject, JsonReadableInterface, \ - NwbWritableInterface, NwbReadableInterface +from allensdk.brain_observatory.ecephys.nwb_util import add_probe_to_nwbfile, add_ecephys_electrodes +from allensdk.core import DataObject, JsonReadableInterface, NwbWritableInterface, NwbReadableInterface @dataclasses.dataclass @@ -31,11 +27,11 @@ class ProbeWithLFPMeta: Attributes: - - lfp_csd_filepath --> Either a path to the NWB file containing the LFP - and CSD data or a callable which returns it. The nwb file is loaded - separately from the main session nwb file in order to load the LFP data - on the fly rather than with the main session NWB file. This is to speed - up download of the NWB for users who don't wish to load the LFP data (it + - lfp_csd_filepath --> Either a path to the NWB file containing the LFP + and CSD data or a callable which returns it. The nwb file is loaded + separately from the main session nwb file in order to load the LFP data + on the fly rather than with the main session NWB file. This is to speed + up download of the NWB for users who don't wish to load the LFP data (it is large). - lfp_sampling_rate --> LFP sampling rate """ # noqa E402 @@ -44,21 +40,21 @@ class ProbeWithLFPMeta: lfp_sampling_rate: float -class Probe(DataObject, JsonReadableInterface, NwbWritableInterface, - NwbReadableInterface): +class Probe(DataObject, JsonReadableInterface, NwbWritableInterface, NwbReadableInterface): """A single probe""" + def __init__( - self, - id: int, - name: str, - channels: Channels, - units: Units, - sampling_rate: float = 30000.0, - lfp: Optional[LFP] = None, - lfp_meta: Optional[ProbeWithLFPMeta] = None, - current_source_density: Optional[CurrentSourceDensity] = None, - location: str = 'See electrode locations', - temporal_subsampling_factor: Optional[float] = 2.0 + self, + id: int, + name: str, + channels: Channels, + units: Units, + sampling_rate: float = 30000.0, + lfp: Optional[LFP] = None, + lfp_meta: Optional[ProbeWithLFPMeta] = None, + current_source_density: Optional[CurrentSourceDensity] = None, + location: str = "See electrode locations", + temporal_subsampling_factor: Optional[float] = 2.0, ): """ @@ -95,9 +91,7 @@ def __init__( self._current_source_density = current_source_density self._location = location self._temporal_subsampling_factor = temporal_subsampling_factor - super().__init__(name=name, - value=None, - is_value_self=True) + super().__init__(name=name, value=None, is_value_self=True) @property def id(self) -> int: @@ -151,45 +145,35 @@ def temporal_subsampling_factor(self) -> Optional[float]: @property def units_table(self) -> pd.DataFrame: - df = pd.DataFrame( - [unit.to_dict()['unit'] for unit in self.units.value]) + df = pd.DataFrame([unit.to_dict()["unit"] for unit in self.units.value]) df = df.fillna(np.nan) return df @classmethod def from_json(cls, probe: dict) -> "Probe": - channels = Channels.from_json(channels=probe['channels']) + channels = Channels.from_json(channels=probe["channels"]) units = Units.from_json(probe=probe) - if probe['lfp'] is not None: - lfp = LFP.from_json( - probe_meta=probe - ) - csd = CurrentSourceDensity.from_json( - probe_meta=probe - ) + if probe["lfp"] is not None: + lfp = LFP.from_json(probe_meta=probe) + csd = CurrentSourceDensity.from_json(probe_meta=probe) else: lfp = None csd = None return Probe( - id=probe['id'], - name=probe['name'], + id=probe["id"], + name=probe["name"], channels=channels, units=units, - sampling_rate=probe['sampling_rate'], + sampling_rate=probe["sampling_rate"], lfp=lfp, current_source_density=csd, - temporal_subsampling_factor=probe['temporal_subsampling_factor'] + temporal_subsampling_factor=probe["temporal_subsampling_factor"], ) @classmethod - def from_nwb( - cls, - nwbfile: NWBFile, - probe_name: str, - lfp_meta: Optional[ProbeWithLFPMeta] = None - ) -> "Probe": + def from_nwb(cls, nwbfile: NWBFile, probe_name: str, lfp_meta: Optional[ProbeWithLFPMeta] = None) -> "Probe": """ Parameters @@ -214,13 +198,10 @@ def from_nwb( sampling_rate=probe.device.sampling_rate, channels=channels, units=units, - lfp_meta=lfp_meta + lfp_meta=lfp_meta, ) - def to_nwb( - self, - nwbfile: NWBFile - ) -> Tuple[NWBFile, Optional[NWBFile]]: + def to_nwb(self, nwbfile: NWBFile) -> Tuple[NWBFile, Optional[NWBFile]]: """ Parameters @@ -237,34 +218,27 @@ def to_nwb( if self._lfp is not None: probe_nwbfile = self.add_lfp_to_nwb( session_id=nwbfile.session_id, - session_metadata=BehaviorEcephysMetadata.from_nwb( - nwbfile=nwbfile), - session_start_time=nwbfile.session_start_time + session_metadata=BehaviorEcephysMetadata.from_nwb(nwbfile=nwbfile), + session_start_time=nwbfile.session_start_time, ) else: - logging.info(f'No LFP data found for probe {self._id}') + logging.info(f"No LFP data found for probe {self._id}") probe_nwbfile = None return nwbfile, probe_nwbfile - def _add_probe_to_nwb( - self, nwbfile: NWBFile, - add_only_lfp_channels: bool = False - ): - logging.info(f'found probe {self._id} with name {self._name}') - - nwbfile, probe_nwb_device, probe_nwb_electrode_group = \ - add_probe_to_nwbfile( - nwbfile, - probe_id=self._id, - name=self._name, - sampling_rate=self._sampling_rate, - lfp_sampling_rate=( - self._lfp.sampling_rate if self._lfp is not None - else np.nan), - has_lfp_data=self._lfp is not None - ) + def _add_probe_to_nwb(self, nwbfile: NWBFile, add_only_lfp_channels: bool = False): + logging.info(f"found probe {self._id} with name {self._name}") + + nwbfile, probe_nwb_device, probe_nwb_electrode_group = add_probe_to_nwbfile( + nwbfile, + probe_id=self._id, + name=self._name, + sampling_rate=self._sampling_rate, + lfp_sampling_rate=(self._lfp.sampling_rate if self._lfp is not None else np.nan), + has_lfp_data=self._lfp is not None, + ) - channels = [c.to_dict()['channel'] for c in self._channels.value] + channels = [c.to_dict()["channel"] for c in self._channels.value] if self._lfp is not None and add_only_lfp_channels: channel_number_whitelist = self._lfp.channels @@ -272,47 +246,39 @@ def _add_probe_to_nwb( channel_number_whitelist = None add_ecephys_electrodes( - nwbfile, - channels, - probe_nwb_electrode_group, - channel_number_whitelist=channel_number_whitelist) + nwbfile, channels, probe_nwb_electrode_group, channel_number_whitelist=channel_number_whitelist + ) return nwbfile - def add_lfp_to_nwb( - self, - session_id: str, - session_start_time: datetime, - session_metadata: BehaviorEcephysMetadata - ): - logging.info(f'writing lfp file for probe {self._id}') + def add_lfp_to_nwb(self, session_id: str, session_start_time: datetime, session_metadata: BehaviorEcephysMetadata): + logging.info(f"writing lfp file for probe {self._id}") nwbfile = pynwb.NWBFile( - session_description='LFP data and associated info for one probe', + session_description="LFP data and associated info for one probe", identifier=f"{self._id}", session_id=f"{session_id}", session_start_time=session_start_time, - institution="Allen Institute for Brain Science" + institution="Allen Institute for Brain Science", ) session_metadata.to_nwb(nwbfile=nwbfile) - nwbfile = self._add_probe_to_nwb( - nwbfile=nwbfile, - add_only_lfp_channels=True - ) + nwbfile = self._add_probe_to_nwb(nwbfile=nwbfile, add_only_lfp_channels=True) lfp_nwb = pynwb.ecephys.LFP(name=f"probe_{self._id}_lfp") electrode_table_region = nwbfile.create_electrode_table_region( region=np.arange(len(nwbfile.electrodes)).tolist(), - name='electrodes', - description=f"lfp channels on probe {self._id}" + name="electrodes", + description=f"lfp channels on probe {self._id}", ) - nwbfile.add_acquisition(lfp_nwb.create_electrical_series( - name=f"probe_{self._id}_lfp_data", - data=self._lfp.data, - timestamps=self._lfp.timestamps, - electrodes=electrode_table_region - )) + nwbfile.add_acquisition( + lfp_nwb.create_electrical_series( + name=f"probe_{self._id}_lfp_data", + data=self._lfp.data, + timestamps=self._lfp.timestamps, + electrodes=electrode_table_region, + ) + ) nwbfile.add_acquisition(lfp_nwb) if self._current_source_density is not None: @@ -320,12 +286,7 @@ def add_lfp_to_nwb( return nwbfile - def _add_csd_to_nwb( - self, - nwbfile: NWBFile, - csd_unit: str = 'V/cm^2', - position_unit: str = "um" - ): + def _add_csd_to_nwb(self, nwbfile: NWBFile, csd_unit: str = "V/cm^2", position_unit: str = "um"): """ Parameters @@ -343,9 +304,7 @@ def _add_csd_to_nwb( """ csd = self._current_source_density - csd_mod = pynwb.ProcessingModule("current_source_density", - "Precalculated current source " - "density") + csd_mod = pynwb.ProcessingModule("current_source_density", "Precalculated current source density") nwbfile.add_processing_module(csd_mod) csd_ts = pynwb.base.TimeSeries( @@ -353,19 +312,19 @@ def _add_csd_to_nwb( data=csd.data.T, # TimeSeries should have data in (time x channels) format timestamps=csd.timestamps.T, - unit=csd_unit + unit=csd_unit, ) - x_locs, y_locs = np.split(csd.channel_locations.astype(np.uint64), - 2, - axis=1) + x_locs, y_locs = np.split(csd.channel_locations.astype(np.uint64), 2, axis=1) - csd = EcephysCSD(name="ecephys_csd", - time_series=csd_ts, - virtual_electrode_x_positions=x_locs.flatten(), - virtual_electrode_x_positions__unit=position_unit, - virtual_electrode_y_positions=y_locs.flatten(), - virtual_electrode_y_positions__unit=position_unit) + csd = EcephysCSD( + name="ecephys_csd", + time_series=csd_ts, + virtual_electrode_x_positions=x_locs.flatten(), + virtual_electrode_x_positions__unit=position_unit, + virtual_electrode_y_positions=y_locs.flatten(), + virtual_electrode_y_positions__unit=position_unit, + ) csd_mod.add_data_interface(csd) @@ -373,15 +332,15 @@ def _add_csd_to_nwb( def _read_lfp_from_nwb(self) -> LFP: if isinstance(self._lfp_meta.lfp_csd_filepath, Callable): - logging.info('Fetching LFP NWB file') + logging.info("Fetching LFP NWB file") path = self._lfp_meta.lfp_csd_filepath() else: path = self._lfp_meta.lfp_csd_filepath - with pynwb.NWBHDF5IO(path, 'r', load_namespaces=True) as f: + with pynwb.NWBHDF5IO(path, "r", load_namespaces=True) as f: nwbfile = f.read() probe = nwbfile.electrode_groups[self._name] - lfp = nwbfile.get_acquisition(f'probe_{self._id}_lfp') - series = lfp.get_electrical_series(f'probe_{self._id}_lfp_data') + lfp = nwbfile.get_acquisition(f"probe_{self._id}_lfp") + series = lfp.get_electrical_series(f"probe_{self._id}_lfp_data") electrodes = nwbfile.electrodes.to_dataframe() @@ -389,22 +348,18 @@ def _read_lfp_from_nwb(self) -> LFP: timestamps = series.timestamps[:] return LFP( - data=data, - timestamps=timestamps, - channels=electrodes.index.values, - sampling_rate=probe.lfp_sampling_rate + data=data, timestamps=timestamps, channels=electrodes.index.values, sampling_rate=probe.lfp_sampling_rate ) def _read_csd_data_from_nwb(self) -> CurrentSourceDensity: if isinstance(self._lfp_meta.lfp_csd_filepath, Callable): - logging.info('Fetching LFP NWB file') + logging.info("Fetching LFP NWB file") path = self._lfp_meta.lfp_csd_filepath() else: path = self._lfp_meta.lfp_csd_filepath - with pynwb.NWBHDF5IO(path, 'r', load_namespaces=True) as f: + with pynwb.NWBHDF5IO(path, "r", load_namespaces=True) as f: nwbfile = f.read() - csd_mod = nwbfile.get_processing_module( - "current_source_density") + csd_mod = nwbfile.get_processing_module("current_source_density") nwb_csd = csd_mod["ecephys_csd"] csd_data = nwb_csd.time_series.data[:] @@ -412,13 +367,13 @@ def _read_csd_data_from_nwb(self) -> CurrentSourceDensity: # want (channels x timepoints) csd_data = csd_data.T - channel_locations = np.stack([ - nwb_csd.virtual_electrode_x_positions, - nwb_csd.virtual_electrode_y_positions], axis=1).astype('int') + channel_locations = np.stack( + [nwb_csd.virtual_electrode_x_positions, nwb_csd.virtual_electrode_y_positions], axis=1 + ).astype("int") return CurrentSourceDensity( data=csd_data, timestamps=nwb_csd.time_series.timestamps[:], - interpolated_channel_locations=channel_locations + interpolated_channel_locations=channel_locations, ) def to_dict(self) -> dict: @@ -433,10 +388,10 @@ def to_dict(self) -> dict: has_lfp_data = True return { - 'id': self._id, - 'name': self._name, - 'location': self._location, - 'sampling_rate': self._sampling_rate, - 'lfp_sampling_rate': lfp_sampling_rate, - 'has_lfp_data': has_lfp_data + "id": self._id, + "name": self._name, + "location": self._location, + "sampling_rate": self._sampling_rate, + "lfp_sampling_rate": lfp_sampling_rate, + "has_lfp_data": has_lfp_data, } diff --git a/allensdk/brain_observatory/ecephys/_unit.py b/allensdk/brain_observatory/ecephys/_unit.py index 5ae40dc03b..50e6bc7b9e 100644 --- a/allensdk/brain_observatory/ecephys/_unit.py +++ b/allensdk/brain_observatory/ecephys/_unit.py @@ -11,45 +11,44 @@ class Unit(DataObject): from a single cell", it is called a "unit" rather than a "neuron" """ def __init__( - self, - id: int, - peak_channel_id: int, - local_index: int, - cluster_id: int, - quality: str, - firing_rate: float, - isi_violations: float, - presence_ratio: float, - amplitude_cutoff: float, - mean_waveforms: np.ndarray, - spike_amplitudes: np.ndarray, - spike_times: np.ndarray, - isolation_distance: Optional[float] = None, - l_ratio: Optional[float] = None, - d_prime: Optional[float] = None, - nn_hit_rate: Optional[float] = None, - nn_miss_rate: Optional[float] = None, - max_drift: Optional[float] = None, - cumulative_drift: Optional[float] = None, - silhouette_score: Optional[float] = None, - waveform_duration: Optional[float] = None, - waveform_halfwidth: Optional[float] = None, - PT_ratio: Optional[float] = None, - repolarization_slope: Optional[float] = None, - recovery_slope: Optional[float] = None, - amplitude: Optional[float] = None, - spread: Optional[float] = None, - velocity_above: Optional[float] = None, - velocity_below: Optional[float] = None, - snr: Optional[float] = None, - filter_and_sort_spikes=True + self, + id: int, + peak_channel_id: int, + local_index: int, + cluster_id: int, + quality: str, + firing_rate: float, + isi_violations: float, + presence_ratio: float, + amplitude_cutoff: float, + mean_waveforms: np.ndarray, + spike_amplitudes: np.ndarray, + spike_times: np.ndarray, + isolation_distance: Optional[float] = None, + l_ratio: Optional[float] = None, + d_prime: Optional[float] = None, + nn_hit_rate: Optional[float] = None, + nn_miss_rate: Optional[float] = None, + max_drift: Optional[float] = None, + cumulative_drift: Optional[float] = None, + silhouette_score: Optional[float] = None, + waveform_duration: Optional[float] = None, + waveform_halfwidth: Optional[float] = None, + PT_ratio: Optional[float] = None, + repolarization_slope: Optional[float] = None, + recovery_slope: Optional[float] = None, + amplitude: Optional[float] = None, + spread: Optional[float] = None, + velocity_above: Optional[float] = None, + velocity_below: Optional[float] = None, + snr: Optional[float] = None, + filter_and_sort_spikes=True, ): - super().__init__(name='unit', - value=None, - is_value_self=True) + super().__init__(name="unit", value=None, is_value_self=True) if filter_and_sort_spikes: spike_times, spike_amplitudes = _get_filtered_and_sorted_spikes( - spike_times=spike_times, spike_amplitudes=spike_amplitudes) + spike_times=spike_times, spike_amplitudes=spike_amplitudes + ) self._id = id self._peak_channel_id = peak_channel_id self._local_index = local_index @@ -211,8 +210,8 @@ def snr(self) -> Optional[float]: def _get_filtered_and_sorted_spikes( - spike_times: np.ndarray, spike_amplitudes: np.ndarray) -> \ - Tuple[np.ndarray, np.ndarray]: + spike_times: np.ndarray, spike_amplitudes: np.ndarray +) -> Tuple[np.ndarray, np.ndarray]: """Filter out invalid spike timepoints and sort spike data (times + amplitudes) by times. diff --git a/allensdk/brain_observatory/ecephys/_units.py b/allensdk/brain_observatory/ecephys/_units.py index 7128b88791..f5ef8a0f96 100644 --- a/allensdk/brain_observatory/ecephys/_units.py +++ b/allensdk/brain_observatory/ecephys/_units.py @@ -6,10 +6,8 @@ from allensdk.brain_observatory.ecephys._channels import Channels from allensdk.brain_observatory.ecephys._unit import Unit -from allensdk.brain_observatory.ecephys.utils import load_and_squeeze_npy, \ - scale_amplitudes, group_1d_by_unit -from allensdk.core import DataObject, NwbReadableInterface, \ - JsonReadableInterface +from allensdk.brain_observatory.ecephys.utils import load_and_squeeze_npy, scale_amplitudes, group_1d_by_unit +from allensdk.core import DataObject, NwbReadableInterface, JsonReadableInterface class Units(DataObject, JsonReadableInterface, NwbReadableInterface): @@ -18,14 +16,10 @@ class Units(DataObject, JsonReadableInterface, NwbReadableInterface): """ def __init__(self, units: List[Unit]): - super().__init__(name='units', value=units) + super().__init__(name="units", value=units) @classmethod - def from_json( - cls, - probe: dict, - amplitude_scale_factor=0.195e-6 - ) -> "Units": + def from_json(cls, probe: dict, amplitude_scale_factor=0.195e-6) -> "Units": """ Parameters @@ -38,17 +32,14 @@ def from_json( ------- """ - local_to_global_unit_map = { - unit['cluster_id']: unit['id'] for unit in probe['units']} + local_to_global_unit_map = {unit["cluster_id"]: unit["id"] for unit in probe["units"]} spike_times = _read_spike_times_to_dictionary( - probe['spike_times_path'], - probe['spike_clusters_file'], - local_to_global_unit_map + probe["spike_times_path"], probe["spike_clusters_file"], local_to_global_unit_map ) mean_waveforms = _read_waveforms_to_dictionary( - probe['mean_waveforms_path'], + probe["mean_waveforms_path"], local_to_global_unit_map, - mean_waveform_scale=probe.get('scale_mean_waveform_and_csd', 1) + mean_waveform_scale=probe.get("scale_mean_waveform_and_csd", 1), ) spike_amplitudes = _read_spike_amplitudes_to_dictionary( probe["spike_amplitudes_path"], @@ -57,22 +48,22 @@ def from_json( probe["spike_templates_path"], probe["inverse_whitening_matrix_path"], local_to_global_unit_map=local_to_global_unit_map, - scale_factor=probe.get('amplitude_scale_factor', - amplitude_scale_factor) + scale_factor=probe.get("amplitude_scale_factor", amplitude_scale_factor), ) units = [ - Unit(**unit, - spike_times=spike_times[unit['id']], - spike_amplitudes=spike_amplitudes[unit['id']], - mean_waveforms=mean_waveforms[unit['id']]) - for unit in probe['units'] + Unit( + **unit, + spike_times=spike_times[unit["id"]], + spike_amplitudes=spike_amplitudes[unit["id"]], + mean_waveforms=mean_waveforms[unit["id"]], + ) + for unit in probe["units"] ] units = Units(units=units) return units @classmethod - def from_nwb(cls, nwbfile: NWBFile, - probe_id: Optional[str] = None) -> "Units": + def from_nwb(cls, nwbfile: NWBFile, probe_id: Optional[str] = None) -> "Units": """ Parameters @@ -86,56 +77,51 @@ def from_nwb(cls, nwbfile: NWBFile, """ units = nwbfile.units.to_dataframe() units = units.reset_index() - units = units.rename(columns={'waveform_mean': 'mean_waveforms'}) + units = units.rename(columns={"waveform_mean": "mean_waveforms"}) if probe_id is not None: channels = Channels.from_nwb(nwbfile=nwbfile) - units = units[units['peak_channel_id'].map( - {c.id: c.probe_id for c in channels.value}) == probe_id] + units = units[units["peak_channel_id"].map({c.id: c.probe_id for c in channels.value}) == probe_id] - units = units.to_dict(orient='records') + units = units.to_dict(orient="records") units = [Unit(**unit, filter_and_sort_spikes=False) for unit in units] return Units(units=units) def _read_spike_amplitudes_to_dictionary( - spike_amplitudes_path, spike_units_path, - templates_path, spike_templates_path, inverse_whitening_matrix_path, - local_to_global_unit_map=None, - scale_factor=0.195e-6 + spike_amplitudes_path, + spike_units_path, + templates_path, + spike_templates_path, + inverse_whitening_matrix_path, + local_to_global_unit_map=None, + scale_factor=0.195e-6, ): spike_amplitudes = load_and_squeeze_npy(spike_amplitudes_path) spike_units = load_and_squeeze_npy(spike_units_path) templates = load_and_squeeze_npy(templates_path) spike_templates = load_and_squeeze_npy(spike_templates_path) - inverse_whitening_matrix = \ - load_and_squeeze_npy(inverse_whitening_matrix_path) + inverse_whitening_matrix = load_and_squeeze_npy(inverse_whitening_matrix_path) for temp_idx in range(templates.shape[0]): templates[temp_idx, :, :] = np.dot( - np.ascontiguousarray(templates[temp_idx, :, :]), - np.ascontiguousarray(inverse_whitening_matrix) + np.ascontiguousarray(templates[temp_idx, :, :]), np.ascontiguousarray(inverse_whitening_matrix) ) - scaled_amplitudes = scale_amplitudes(spike_amplitudes, - templates, - spike_templates, - scale_factor=scale_factor) + scaled_amplitudes = scale_amplitudes(spike_amplitudes, templates, spike_templates, scale_factor=scale_factor) - return group_1d_by_unit(scaled_amplitudes, - spike_units, - local_to_global_unit_map) + return group_1d_by_unit(scaled_amplitudes, spike_units, local_to_global_unit_map) def _read_waveforms_to_dictionary( - waveforms_path, - local_to_global_unit_map=None, - peak_channel_map=None, - mean_waveform_scale=1, + waveforms_path, + local_to_global_unit_map=None, + peak_channel_map=None, + mean_waveform_scale=1, ): - """ Builds a lookup table for unitwise waveform data + """Builds a lookup table for unitwise waveform data Parameters ---------- @@ -160,9 +146,7 @@ def _read_waveforms_to_dictionary( waveforms = np.squeeze(np.load(waveforms_path, allow_pickle=False)) output_waveforms = {} - for unit_id, waveform in enumerate( - np.split(waveforms, waveforms.shape[0], axis=0) - ): + for unit_id, waveform in enumerate(np.split(waveforms, waveforms.shape[0], axis=0)): if local_to_global_unit_map is not None: if unit_id not in local_to_global_unit_map: logging.warning( @@ -180,10 +164,8 @@ def _read_waveforms_to_dictionary( return output_waveforms -def _read_spike_times_to_dictionary( - spike_times_path, spike_units_path, local_to_global_unit_map=None -): - """ Reads spike times and assigned units from npy files into a lookup +def _read_spike_times_to_dictionary(spike_times_path, spike_units_path, local_to_global_unit_map=None): + """Reads spike times and assigned units from npy files into a lookup table. Parameters diff --git a/allensdk/brain_observatory/ecephys/align_timestamps/__main__.py b/allensdk/brain_observatory/ecephys/align_timestamps/__main__.py index 116d9cb738..b84bd4f501 100644 --- a/allensdk/brain_observatory/ecephys/align_timestamps/__main__.py +++ b/allensdk/brain_observatory/ecephys/align_timestamps/__main__.py @@ -1,13 +1,9 @@ import numpy as np -from allensdk.brain_observatory.argschema_utilities import \ - ArgSchemaParserPlus, \ - write_or_print_outputs +from allensdk.brain_observatory.argschema_utilities import ArgSchemaParserPlus, write_or_print_outputs from ._schemas import InputParameters, OutputParameters from .barcode_sync_dataset import BarcodeSyncDataset -from .channel_states import extract_barcodes_from_states, \ - extract_splits_from_states, \ - extract_splits_from_barcode_times +from .channel_states import extract_barcodes_from_states, extract_splits_from_states, extract_splits_from_barcode_times from .probe_synchronizer import ProbeSynchronizer @@ -26,13 +22,9 @@ def align_timestamps(args): probe_barcode_times, probe_barcodes = extract_barcodes_from_states( channel_states, timestamps, probe["sampling_rate"] ) - probe_split_times = extract_splits_from_states( - channel_states, timestamps, probe["sampling_rate"] - ) + probe_split_times = extract_splits_from_states(channel_states, timestamps, probe["sampling_rate"]) - barcode_split_times = extract_splits_from_barcode_times( - probe_barcode_times - ) + barcode_split_times = extract_splits_from_barcode_times(probe_barcode_times) probe_split_times = np.union1d(probe_split_times, barcode_split_times) @@ -42,7 +34,6 @@ def align_timestamps(args): synchronizers = [] for idx, split_time in enumerate(probe_split_times): - min_time = probe_split_times[idx] if idx == (len(probe_split_times) - 1): @@ -72,31 +63,17 @@ def align_timestamps(args): for synchronizer in synchronizers: aligned_timestamps = synchronizer(aligned_timestamps) - print( - "total time shift: " + str(synchronizer.total_time_shift)) - print( - "actual sampling rate: " - + str(synchronizer.global_probe_sampling_rate) - ) - - np.save( - timestamp_file["output_path"], aligned_timestamps, - allow_pickle=False - ) - mapped_files[timestamp_file["name"]] = timestamp_file[ - "output_path"] + print("total time shift: " + str(synchronizer.total_time_shift)) + print("actual sampling rate: " + str(synchronizer.global_probe_sampling_rate)) - lfp_sampling_rate = ( - probe["lfp_sampling_rate"] * synchronizer.sampling_rate_scale - ) + np.save(timestamp_file["output_path"], aligned_timestamps, allow_pickle=False) + mapped_files[timestamp_file["name"]] = timestamp_file["output_path"] + + lfp_sampling_rate = probe["lfp_sampling_rate"] * synchronizer.sampling_rate_scale - this_probe_output_info[ - "total_time_shift"] = synchronizer.total_time_shift - this_probe_output_info[ - "global_probe_sampling_rate" - ] = synchronizer.global_probe_sampling_rate - this_probe_output_info[ - "global_probe_lfp_sampling_rate"] = lfp_sampling_rate + this_probe_output_info["total_time_shift"] = synchronizer.total_time_shift + this_probe_output_info["global_probe_sampling_rate"] = synchronizer.global_probe_sampling_rate + this_probe_output_info["global_probe_lfp_sampling_rate"] = lfp_sampling_rate this_probe_output_info["output_paths"] = mapped_files this_probe_output_info["name"] = probe["name"] this_probe_output_info["split_times"] = probe_split_times @@ -107,9 +84,7 @@ def align_timestamps(args): def main(): - mod = ArgSchemaParserPlus( - schema_type=InputParameters, output_schema_type=OutputParameters - ) + mod = ArgSchemaParserPlus(schema_type=InputParameters, output_schema_type=OutputParameters) output = align_timestamps(mod.args) write_or_print_outputs(data=output, parser=mod) diff --git a/allensdk/brain_observatory/ecephys/align_timestamps/_schemas.py b/allensdk/brain_observatory/ecephys/align_timestamps/_schemas.py index df3a5fa4f7..a4eb81e6b2 100644 --- a/allensdk/brain_observatory/ecephys/align_timestamps/_schemas.py +++ b/allensdk/brain_observatory/ecephys/align_timestamps/_schemas.py @@ -16,7 +16,7 @@ class ProbeMappable(DefaultSchema): output_path = String( required=True, help="""Output path for the mapped version of this file. Will write a 1D - timestamps array with values in seconds on the master clock.""" + timestamps array with values in seconds on the master clock.""", ) @@ -28,12 +28,14 @@ class ProbeInputParameters(DefaultSchema): the probe clock.""", ) lfp_sampling_rate = Float( - required=True, help="""The sampling rate of the LFP collected on this - probe.""" + required=True, + help="""The sampling rate of the LFP collected on this + probe.""", ) start_index = Int( - default=0, help="""Sample index of probe recording start time. - Defaults to 0.""" + default=0, + help="""Sample index of probe recording start time. + Defaults to 0.""", ) barcode_channel_states_path = String( required=True, @@ -66,7 +68,7 @@ class InputParameters(ArgSchema): sync_h5_path = String( required=True, help="""path to h5 file containing syncronization - information""" + information""", ) @@ -95,7 +97,7 @@ class ProbeOutputParameters(DefaultSchema): Float(), required=True, help="""Start/stop times of likely dropped data, due to gaps in - recording or irregular barcode intervals""" + recording or irregular barcode intervals""", ) @@ -108,6 +110,4 @@ class OutputSchema(DefaultSchema): class OutputParameters(OutputSchema): - probe_outputs = Nested( - ProbeOutputParameters, many="True", help="Probewise outputs." - ) + probe_outputs = Nested(ProbeOutputParameters, many="True", help="Probewise outputs.") diff --git a/allensdk/brain_observatory/ecephys/align_timestamps/barcode.py b/allensdk/brain_observatory/ecephys/align_timestamps/barcode.py index 31503e54fa..31a2cb7716 100644 --- a/allensdk/brain_observatory/ecephys/align_timestamps/barcode.py +++ b/allensdk/brain_observatory/ecephys/align_timestamps/barcode.py @@ -48,28 +48,15 @@ def extract_barcodes_from_times( barcodes = [] for i, t in enumerate(barcode_start_times): - - oncode = on_times[ - np.where( - np.logical_and(on_times > t, - on_times < t + barcode_duration_ceiling) - )[0] - ] - offcode = off_times[ - np.where( - np.logical_and(off_times > t, - off_times < t + barcode_duration_ceiling) - )[0] - ] + oncode = on_times[np.where(np.logical_and(on_times > t, on_times < t + barcode_duration_ceiling))[0]] + offcode = off_times[np.where(np.logical_and(off_times > t, off_times < t + barcode_duration_ceiling))[0]] if len(offcode) > 0: - currTime = offcode[0] bits = np.zeros((nbits,)) for bit in range(0, nbits): - nextOn = np.where(oncode > currTime)[0] nextOff = np.where(offcode > currTime)[0] @@ -99,9 +86,7 @@ def extract_barcodes_from_times( return barcode_start_times, barcodes -def find_matching_index(master_barcodes, - probe_barcodes, - alignment_type="start"): +def find_matching_index(master_barcodes, probe_barcodes, alignment_type="start"): """Given a set of barcodes for the master clock and the probe clock, find the indices of a matching set, either starting from the beginning or the end of the list. @@ -135,10 +120,7 @@ def find_matching_index(master_barcodes, direction = -1 while not foundMatch and abs(probe_barcode_index) < len(probe_barcodes): - - master_barcode_index = np.where( - master_barcodes == probe_barcodes[probe_barcode_index] - )[0] + master_barcode_index = np.where(master_barcodes == probe_barcodes[probe_barcode_index])[0] assert len(master_barcode_index) < 2 @@ -187,9 +169,7 @@ def match_barcodes(master_times, master_barcodes, probe_times, probe_barcodes): """ - master_start_index, probe_start_index = find_matching_index( - master_barcodes, probe_barcodes, alignment_type="start" - ) + master_start_index, probe_start_index = find_matching_index(master_barcodes, probe_barcodes, alignment_type="start") if master_start_index is not None: t_m_start = master_times[master_start_index] @@ -202,10 +182,7 @@ def match_barcodes(master_times, master_barcodes, probe_times, probe_barcodes): print("Master start index: " + str(master_start_index)) if len(probe_barcodes) > 2: - master_end_index, probe_end_index = \ - find_matching_index(master_barcodes, - probe_barcodes, - alignment_type='end') + master_end_index, probe_end_index = find_matching_index(master_barcodes, probe_barcodes, alignment_type="end") if probe_end_index is not None: print("Probe end index: " + str(probe_end_index)) @@ -303,12 +280,8 @@ def get_probe_time_offset( """ - probe_endpoints, master_endpoints = match_barcodes( - master_times, master_barcodes, probe_times, probe_barcodes - ) - rate_scale, time_offset = linear_transform_from_intervals( - master_endpoints, probe_endpoints - ) + probe_endpoints, master_endpoints = match_barcodes(master_times, master_barcodes, probe_times, probe_barcodes) + rate_scale, time_offset = linear_transform_from_intervals(master_endpoints, probe_endpoints) if time_offset is not None: probe_rate = local_probe_rate * rate_scale diff --git a/allensdk/brain_observatory/ecephys/align_timestamps/barcode_sync_dataset.py b/allensdk/brain_observatory/ecephys/align_timestamps/barcode_sync_dataset.py index 233d021f02..b76fb65be3 100644 --- a/allensdk/brain_observatory/ecephys/align_timestamps/barcode_sync_dataset.py +++ b/allensdk/brain_observatory/ecephys/align_timestamps/barcode_sync_dataset.py @@ -10,9 +10,7 @@ class BarcodeSyncDataset(EcephysSyncDataset): @property def barcode_line(self): - """ Obtain the index of the barcode line for this dataset. - - """ + """Obtain the index of the barcode line for this dataset.""" if "barcode" in self.line_labels: return self.line_labels.index("barcode") @@ -24,7 +22,7 @@ def barcode_line(self): raise ValueError("no barcode line found") def extract_barcodes(self, **barcode_kwargs): - """ Read barcodes and their times from this dataset's barcode line. + """Read barcodes and their times from this dataset's barcode line. Parameters ---------- @@ -49,12 +47,10 @@ def extract_barcodes(self, **barcode_kwargs): on_times = on_events / sample_freq_digital off_times = off_events / sample_freq_digital - return barcode.extract_barcodes_from_times( - on_times, off_times, **barcode_kwargs - ) + return barcode.extract_barcodes_from_times(on_times, off_times, **barcode_kwargs) def get_barcode_table(self, **barcode_kwargs): - """ A convenience method for getting barcode times and codes in a dictionary. + """A convenience method for getting barcode times and codes in a dictionary. Notes ----- diff --git a/allensdk/brain_observatory/ecephys/align_timestamps/channel_states.py b/allensdk/brain_observatory/ecephys/align_timestamps/channel_states.py index 91727b0da0..2c725aa8c4 100644 --- a/allensdk/brain_observatory/ecephys/align_timestamps/channel_states.py +++ b/allensdk/brain_observatory/ecephys/align_timestamps/channel_states.py @@ -3,9 +3,7 @@ from . import barcode -def extract_barcodes_from_states( - channel_states, timestamps, sampling_rate, **barcode_kwargs -): +def extract_barcodes_from_states(channel_states, timestamps, sampling_rate, **barcode_kwargs): """Obtain barcodes from timestamped rising/falling edges. Parameters @@ -31,9 +29,7 @@ def extract_barcodes_from_states( return barcode.extract_barcodes_from_times(T_on, T_off, **barcode_kwargs) -def extract_splits_from_states( - channel_states, timestamps, sampling_rate, **barcode_kwargs -): +def extract_splits_from_states(channel_states, timestamps, sampling_rate, **barcode_kwargs): """Obtain data split times from timestamped rising/falling edges. Parameters @@ -60,10 +56,7 @@ def extract_splits_from_states( return T_split -def extract_splits_from_barcode_times( - barcode_times, - tolerance=0.0001 -): +def extract_splits_from_barcode_times(barcode_times, tolerance=0.0001): """Determine locations of likely dropped data from barcode times Parameters ---------- @@ -77,17 +70,17 @@ def extract_splits_from_barcode_times( median_interval = np.median(barcode_intervals) - irregular_intervals = np.where(np.abs(barcode_intervals - median_interval) - > tolerance * median_interval)[0] + irregular_intervals = np.where(np.abs(barcode_intervals - median_interval) > tolerance * median_interval)[0] T_split = [0] for i in irregular_intervals: + T_split.append(barcode_times[i - 1]) - T_split.append(barcode_times[i-1]) - - if i+1 < len(barcode_times): - T_split.append(barcode_times[i+1]) + if i + 1 < len(barcode_times): + T_split.append(barcode_times[i + 1]) return np.array(T_split) + + # diff --git a/allensdk/brain_observatory/ecephys/align_timestamps/probe_synchronizer.py b/allensdk/brain_observatory/ecephys/align_timestamps/probe_synchronizer.py index 22a239e5a1..a094454352 100644 --- a/allensdk/brain_observatory/ecephys/align_timestamps/probe_synchronizer.py +++ b/allensdk/brain_observatory/ecephys/align_timestamps/probe_synchronizer.py @@ -6,7 +6,7 @@ class ProbeSynchronizer(object): @property def sampling_rate_scale(self): - """ The ratio of the probe's sampling rate assessed on the global clock to the + """The ratio of the probe's sampling rate assessed on the global clock to the probe's locally assessed sampling rate. """ @@ -69,21 +69,14 @@ def __call__(self, samples, sync_condition="master"): )[0] if self.global_probe_sampling_rate > 0: - if sync_condition == "probe": - samples[in_range] = samples[in_range] / \ - self.local_probe_sampling_rate + samples[in_range] = samples[in_range] / self.local_probe_sampling_rate elif sync_condition == "master": - samples[in_range] = ( - samples[in_range] / self.global_probe_sampling_rate - - self.total_time_shift - ) + samples[in_range] = samples[in_range] / self.global_probe_sampling_rate - self.total_time_shift else: - raise ValueError( - "unrecognized sync condition: {}".format(sync_condition) - ) + raise ValueError("unrecognized sync condition: {}".format(sync_condition)) else: samples[in_range] = -1 @@ -138,25 +131,22 @@ def compute( times_array = np.array(probe_barcode_times) barcodes_array = np.array(probe_barcodes) - ok_barcodes = np.where((times_array > min_time) * - (times_array < max_time))[0] + ok_barcodes = np.where((times_array > min_time) * (times_array < max_time))[0] ok_barcodes = ok_barcodes[ok_barcodes < len(barcodes_array)] times_to_align = list(times_array[ok_barcodes]) barcodes_to_align = list(barcodes_array[ok_barcodes]) if len(barcodes_to_align) > 0: - print("Num barcodes: " + str(len(barcodes_to_align))) - total_time_shift, global_probe_sampling_rate, _ = \ - barcode.get_probe_time_offset( - master_barcode_times, - master_barcodes, - times_to_align, - barcodes_to_align, - probe_start_index, - local_probe_sampling_rate, - ) + total_time_shift, global_probe_sampling_rate, _ = barcode.get_probe_time_offset( + master_barcode_times, + master_barcodes, + times_to_align, + barcodes_to_align, + probe_start_index, + local_probe_sampling_rate, + ) else: print("Not enough barcodes...setting sampling rate to 0") diff --git a/allensdk/brain_observatory/ecephys/behavior_ecephys_session.py b/allensdk/brain_observatory/ecephys/behavior_ecephys_session.py index ff28ee83da..1804865fdf 100644 --- a/allensdk/brain_observatory/ecephys/behavior_ecephys_session.py +++ b/allensdk/brain_observatory/ecephys/behavior_ecephys_session.py @@ -67,9 +67,7 @@ def from_lims( eye_tracking_z_threshold: float = 3.0, eye_tracking_dilation_frames: int = 2, ) -> "VBNBehaviorSession": - raise NotImplementedError( - "from_lims is not supported for a VBNBehaviorSession" - ) + raise NotImplementedError("from_lims is not supported for a VBNBehaviorSession") @classmethod def _read_stimuli( @@ -148,15 +146,10 @@ def _read_licks( """ if sync_file is None: - msg = ( - f"{cls}._read_licks requires a sync_file; " - "you passed in sync_file=None" - ) + msg = f"{cls}._read_licks requires a sync_file; you passed in sync_file=None" raise ValueError(msg) - lick_times = StimulusTimestamps( - timestamps=sync_file.data["lick_times"], monitor_delay=0.0 - ) + lick_times = StimulusTimestamps(timestamps=sync_file.data["lick_times"], monitor_delay=0.0) # get the timestamps of the behavior stimulus presentations beh_stim_times = cls._read_behavior_stimulus_timestamps( @@ -175,9 +168,7 @@ def _read_licks( min_time = beh_stim_times_no_monitor.value.min() max_time = beh_stim_times_no_monitor.value.max() - valid = np.logical_and( - lick_times.value >= min_time, lick_times.value <= max_time - ) + valid = np.logical_and(lick_times.value >= min_time, lick_times.value <= max_time) lick_times = lick_times.value[valid] @@ -188,16 +179,10 @@ def _read_licks( # is as close as we can get to the time when it reads the nidaq # buffer" - lick_frames = np.searchsorted( - beh_stim_times_no_monitor.value, lick_times - ) + lick_frames = np.searchsorted(beh_stim_times_no_monitor.value, lick_times) if len(lick_frames) != len(lick_times): - msg = ( - f"{len(lick_frames)} lick frames; " - f"{len(lick_times)} lick timestamps " - "in the Sync file. Should be equal" - ) + msg = f"{len(lick_frames)} lick frames; {len(lick_times)} lick timestamps in the Sync file. Should be equal" raise RuntimeError(msg) df = pd.DataFrame({"timestamps": lick_times, "frame": lick_frames}) @@ -228,9 +213,7 @@ def _read_eye_tracking_table( } camera_line = exposure_sync_line_label_dict[camera_label] - lost_frames = get_lost_frames( - eye_tracking_metadata=eye_tracking_metadata_file - ) + lost_frames = get_lost_frames(eye_tracking_metadata=eye_tracking_metadata_file) frame_times = sync_utilities.get_synchronized_frame_times( session_sync_file=sync_file.filepath, @@ -239,9 +222,7 @@ def _read_eye_tracking_table( trim_after_spike=False, ) - stimulus_timestamps = StimulusTimestamps( - timestamps=frame_times.values, monitor_delay=0.0 - ) + stimulus_timestamps = StimulusTimestamps(timestamps=frame_times.values, monitor_delay=0.0) return EyeTrackingTable.from_data_file( data_file=eye_tracking_file, @@ -313,9 +294,7 @@ def __init__( task_parameters=behavior_session._task_parameters, trials=behavior_session._trials, eye_tracking_table=behavior_session._eye_tracking, - eye_tracking_rig_geometry=( - behavior_session._eye_tracking_rig_geometry - ), + eye_tracking_rig_geometry=(behavior_session._eye_tracking_rig_geometry), ) self._probes = probes self._optotagging_table = optotagging_table @@ -353,9 +332,7 @@ def optotagging_table(self) -> pd.DataFrame: @property def metadata(self) -> dict: behavior_meta = super()._get_metadata(behavior_metadata=self._metadata) - ecephys_meta = { - "ecephys_session_id": self._metadata.ecephys_session_id - } + ecephys_meta = {"ecephys_session_id": self._metadata.ecephys_session_id} return {**behavior_meta, **ecephys_meta} @property @@ -403,12 +380,7 @@ def get_channels(self, filter_by_validity: bool = True) -> pd.DataFrame: ------- `pd.DataFrame` of channels """ - return pd.concat( - [ - p.channels.to_dataframe(filter_by_validity=filter_by_validity) - for p in self._probes.probes - ] - ) + return pd.concat([p.channels.to_dataframe(filter_by_validity=filter_by_validity) for p in self._probes.probes]) def get_units( self, @@ -492,19 +464,12 @@ def from_json( behavior_session = cls.behavior_data_class().from_json( session_data=session_data, read_stimulus_presentations_table_from_file=True, - stimulus_presentation_exclude_columns=( - stimulus_presentation_exclude_columns - ), + stimulus_presentation_exclude_columns=(stimulus_presentation_exclude_columns), sync_file_permissive=True, eye_tracking_drop_frames=True, - running_speed_load_from_multiple_stimulus_files=( - running_speed_load_from_multiple_stimulus_files - ), - ) - probes = Probes.from_json( - probes=session_data["probes"], - skip_probes=skip_probes + running_speed_load_from_multiple_stimulus_files=(running_speed_load_from_multiple_stimulus_files), ) + probes = Probes.from_json(probes=session_data["probes"], skip_probes=skip_probes) optotagging_table = OptotaggingTable.from_json(dict_repr=session_data) return BehaviorEcephysSession( @@ -540,9 +505,7 @@ def to_nwb(self) -> Tuple[NWBFile, Dict[str, Optional[NWBFile]]]: def from_nwb( cls, nwbfile: NWBFile, - probe_meta: Optional[ - Dict[str, ProbeWithLFPMeta] - ] = None, + probe_meta: Optional[Dict[str, ProbeWithLFPMeta]] = None, **kwargs, ) -> "BehaviorEcephysSession": """ @@ -559,14 +522,10 @@ def from_nwb( instantiated `BehaviorEcephysSession` """ kwargs["add_is_change_to_stimulus_presentations_table"] = False - behavior_session = cls.behavior_data_class().from_nwb( - nwbfile=nwbfile, **kwargs - ) + behavior_session = cls.behavior_data_class().from_nwb(nwbfile=nwbfile, **kwargs) return BehaviorEcephysSession( behavior_session=behavior_session, - probes=Probes.from_nwb( - nwbfile=nwbfile, probe_lfp_meta_map=probe_meta - ), + probes=Probes.from_nwb(nwbfile=nwbfile, probe_lfp_meta_map=probe_meta), optotagging_table=OptotaggingTable.from_nwb(nwbfile=nwbfile), metadata=BehaviorEcephysMetadata.from_nwb(nwbfile=nwbfile), ) @@ -580,8 +539,6 @@ def _get_probe(self, probe_id: int): if len(probe) == 0: raise ValueError(f"Could not find probe with id {probe_id}") if len(probe) > 1: - raise RuntimeError( - f"Multiple probes found with probe_id " f"{probe_id}" - ) + raise RuntimeError(f"Multiple probes found with probe_id {probe_id}") probe = probe[0] return probe diff --git a/allensdk/brain_observatory/ecephys/copy_utility/__main__.py b/allensdk/brain_observatory/ecephys/copy_utility/__main__.py index fdc59e5eaf..d953b42a76 100644 --- a/allensdk/brain_observatory/ecephys/copy_utility/__main__.py +++ b/allensdk/brain_observatory/ecephys/copy_utility/__main__.py @@ -8,22 +8,15 @@ import argschema from allensdk.config.manifest import Manifest -from ._schemas import ( - SessionUploadInputSchema, - SessionUploadOutputSchema, - available_hashers -) +from ._schemas import SessionUploadInputSchema, SessionUploadOutputSchema, available_hashers def hash_file(path, hasher_cls, blocks_per_chunk=128): - """ - - """ + """ """ hasher = hasher_cls() - with open(path, 'rb') as f: + with open(path, "rb") as f: # TODO: Update to new assignment syntax if drop < python 3.8 support - for chunk in iter( - lambda: f.read(hasher.block_size*blocks_per_chunk), b""): + for chunk in iter(lambda: f.read(hasher.block_size * blocks_per_chunk), b""): hasher.update(chunk) return hasher.digest() @@ -38,13 +31,12 @@ def walk_fs_tree(root, fn): def copy_file_entry(source, dest, use_rsync, make_parent_dirs, chmod=None): - leftmost = None if make_parent_dirs: leftmost = Manifest.safe_make_parent_dirs(dest) if use_rsync: - sp.check_call(['rsync', '-a', source, dest]) + sp.check_call(["rsync", "-a", source, dest]) else: if Path(source).is_dir(): shutil.copytree(source, dest) @@ -77,16 +69,11 @@ def compare(source, dest, hasher_cls, raise_if_comparison_fails): dest_path = Path(dest) if source_path.is_dir() and dest_path.is_dir(): - return compare_directories( - source, dest, hasher_cls, raise_if_comparison_fails) + return compare_directories(source, dest, hasher_cls, raise_if_comparison_fails) elif (not source_path.is_dir()) and (not dest_path.is_dir()): - return compare_files( - source, dest, hasher_cls, raise_if_comparison_fails) + return compare_files(source, dest, hasher_cls, raise_if_comparison_fails) else: - raise_or_warn( - f"unable to compare files with directories: {source}, {dest}", - raise_if_comparison_fails - ) + raise_or_warn(f"unable to compare files with directories: {source}, {dest}", raise_if_comparison_fails) def compare_files(source, dest, hasher_cls, raise_if_comparison_fails): @@ -95,8 +82,8 @@ def compare_files(source, dest, hasher_cls, raise_if_comparison_fails): if source_hash != dest_hash: raise_or_warn( - f"comparison of {source} and {dest} " - f"using {hasher_cls.__name__} failed", raise_if_comparison_fails) + f"comparison of {source} and {dest} using {hasher_cls.__name__} failed", raise_if_comparison_fails + ) return source_hash, dest_hash @@ -107,9 +94,8 @@ def compare_directories(source, dest, hasher_cls, raise_if_comparison_fails): if len(source_contents) != len(dest_contents): raise_or_warn( - f"{source} contains {len(source_contents)} items " - f"while {dest} contains {len(dest_contents)} items", - raise_if_comparison_fails + f"{source} contains {len(source_contents)} items while {dest} contains {len(dest_contents)} items", + raise_if_comparison_fails, ) for sitem, ditem in zip(source_contents, dest_contents): @@ -117,21 +103,12 @@ def compare_directories(source, dest, hasher_cls, raise_if_comparison_fails): dpath = str(Path(dest, ditem)) if sitem != ditem: - raise_or_warn( - f"mismatch between {spath} and {dpath}", - raise_if_comparison_fails - ) + raise_or_warn(f"mismatch between {spath} and {dpath}", raise_if_comparison_fails) compare(spath, dpath, hasher_cls, raise_if_comparison_fails) def main( - files, - use_rsync=True, - hasher_key=None, - raise_if_comparison_fails=True, - make_parent_dirs=True, - chmod=775, - **kwargs + files, use_rsync=True, hasher_key=None, raise_if_comparison_fails=True, make_parent_dirs=True, chmod=775, **kwargs ): hasher_cls = available_hashers[hasher_key] output = [] @@ -139,28 +116,21 @@ def main( for file_entry in files: record = cp.deepcopy(file_entry) - copy_file_entry( - file_entry['source'], file_entry['destination'], - use_rsync, make_parent_dirs, chmod=chmod - ) + copy_file_entry(file_entry["source"], file_entry["destination"], use_rsync, make_parent_dirs, chmod=chmod) if hasher_cls is not None: - hashes = compare( - file_entry['source'], file_entry['destination'], - hasher_cls, raise_if_comparison_fails - ) + hashes = compare(file_entry["source"], file_entry["destination"], hasher_cls, raise_if_comparison_fails) if hashes is not None: - record['source_hash'] = [int(ii) for ii in hashes[0]] - record['destination_hash'] = [int(ii) for ii in hashes[1]] + record["source_hash"] = [int(ii) for ii in hashes[0]] + record["destination_hash"] = [int(ii) for ii in hashes[1]] output.append(record) - return {'files': output} + return {"files": output} -if __name__ == '__main__': - logging.basicConfig( - format='%(asctime)s - %(process)s - %(levelname)s - %(message)s') +if __name__ == "__main__": + logging.basicConfig(format="%(asctime)s - %(process)s - %(levelname)s - %(message)s") parser = argschema.ArgSchemaParser( schema_type=SessionUploadInputSchema, diff --git a/allensdk/brain_observatory/ecephys/copy_utility/_schemas.py b/allensdk/brain_observatory/ecephys/copy_utility/_schemas.py index ea3c7f7c14..db6dfeab0e 100644 --- a/allensdk/brain_observatory/ecephys/copy_utility/_schemas.py +++ b/allensdk/brain_observatory/ecephys/copy_utility/_schemas.py @@ -1,15 +1,10 @@ import hashlib from argschema import ArgSchema -from argschema.fields import ( - LogLevel, String, Int, Nested, Boolean, List, InputFile) +from argschema.fields import LogLevel, String, Int, Nested, Boolean, List, InputFile from argschema.schemas import DefaultSchema -available_hashers = { - 'sha3_256': hashlib.sha3_256, - 'sha256': hashlib.sha256, - None: None -} +available_hashers = {"sha3_256": hashlib.sha3_256, "sha256": hashlib.sha256, None: None} class FileExists(InputFile): @@ -17,58 +12,47 @@ class FileExists(InputFile): class FileToCopy(DefaultSchema): - source = InputFile( - required=True, - description='copy from here') - destination = String( - required=True, - description='copy to here (full path, not just directory!)') - key = String(required=True, - description='will be passed through to outputs, allowing a ' - 'name or kind to be associated with this file') + source = InputFile(required=True, description="copy from here") + destination = String(required=True, description="copy to here (full path, not just directory!)") + key = String( + required=True, + description="will be passed through to outputs, allowing a name or kind to be associated with this file", + ) class CopiedFile(DefaultSchema): - source = InputFile(required=True, description='copied from here') - destination = FileExists(required=True, description='copied to here') - key = String(required=False, description='passed from inputs') - source_hash = List(Int, - required=False) # int array vs bytes for JSONability + source = InputFile(required=True, description="copied from here") + destination = FileExists(required=True, description="copied to here") + key = String(required=False, description="passed from inputs") + source_hash = List(Int, required=False) # int array vs bytes for JSONability destination_hash = List(Int, required=False) class NonFileParameters(DefaultSchema): - use_rsync = Boolean(default=True, - description='copy files using rsync rather than ' - 'shutil (this is not likely to work if ' - 'you are running windows!)' - ) - hasher_key = String(default='sha256', - validate=lambda st: st in available_hashers, - allow_none=True, - description='select a hash function to compute over ' - 'base64-encoded pre- and post-copy files' - ) - raise_if_comparison_fails = Boolean(default=True, - description='if a hash comparison ' - 'fails, throw an error (' - 'vs. a warning)') - make_parent_dirs = Boolean(default=True, - description='build missing parent directories ' - 'for destination') - chmod = Int(default=775, - description="destination files (and any created parents will " - "have these permissions") + use_rsync = Boolean( + default=True, + description="copy files using rsync rather than " + "shutil (this is not likely to work if " + "you are running windows!)", + ) + hasher_key = String( + default="sha256", + validate=lambda st: st in available_hashers, + allow_none=True, + description="select a hash function to compute over base64-encoded pre- and post-copy files", + ) + raise_if_comparison_fails = Boolean( + default=True, description="if a hash comparison fails, throw an error (vs. a warning)" + ) + make_parent_dirs = Boolean(default=True, description="build missing parent directories for destination") + chmod = Int(default=775, description="destination files (and any created parents will have these permissions") class SessionUploadInputSchema(ArgSchema, NonFileParameters): - log_level = LogLevel(default='INFO', - description='set the logging level of the module') - files = Nested(FileToCopy, many=True, required=True, - description='files to be copied') + log_level = LogLevel(default="INFO", description="set the logging level of the module") + files = Nested(FileToCopy, many=True, required=True, description="files to be copied") class SessionUploadOutputSchema(DefaultSchema): input_parameters = Nested(NonFileParameters) - files = Nested(CopiedFile, many=True, required=True, - description='copied files') + files = Nested(CopiedFile, many=True, required=True, description="copied files") diff --git a/allensdk/brain_observatory/ecephys/current_source_density/__main__.py b/allensdk/brain_observatory/ecephys/current_source_density/__main__.py index d406a8c478..eaeb183a01 100644 --- a/allensdk/brain_observatory/ecephys/current_source_density/__main__.py +++ b/allensdk/brain_observatory/ecephys/current_source_density/__main__.py @@ -13,220 +13,193 @@ from scipy.spatial.qhull import QhullError -from allensdk.brain_observatory.ecephys.current_source_density._schemas \ - import \ - InputParameters, OutputParameters -from allensdk.brain_observatory.ecephys.current_source_density.\ - _current_source_density import ( - accumulate_lfp_data, - compute_csd, - extract_trial_windows - ) -from allensdk.brain_observatory.ecephys.current_source_density._filter_utils \ - import filter_lfp_channels, select_good_channels -from allensdk.brain_observatory.ecephys.current_source_density\ - ._interpolation_utils import ( - interp_channel_locs, - make_actual_channel_locations, - make_interp_channel_locations - ) -from allensdk.brain_observatory.ecephys.file_io.continuous_file import ( - ContinuousFile +from allensdk.brain_observatory.ecephys.current_source_density._schemas import InputParameters, OutputParameters +from allensdk.brain_observatory.ecephys.current_source_density._current_source_density import ( + accumulate_lfp_data, + compute_csd, + extract_trial_windows, +) +from allensdk.brain_observatory.ecephys.current_source_density._filter_utils import ( + filter_lfp_channels, + select_good_channels, ) -from allensdk.brain_observatory.argschema_utilities import ( - write_or_print_outputs, optional_lims_inputs +from allensdk.brain_observatory.ecephys.current_source_density._interpolation_utils import ( + interp_channel_locs, + make_actual_channel_locations, + make_interp_channel_locations, ) +from allensdk.brain_observatory.ecephys.file_io.continuous_file import ContinuousFile +from allensdk.brain_observatory.argschema_utilities import write_or_print_outputs, optional_lims_inputs -from allensdk.brain_observatory.ecephys.lfp_subsampling.subsampling \ - import remove_lfp_noise +from allensdk.brain_observatory.ecephys.lfp_subsampling.subsampling import remove_lfp_noise def get_inputs_from_lims(args) -> dict: - session_id = args.session_id output_root = args.output_root host = args.host - request_str = ''.join(''' + request_str = "".join( + """ {}/input_jsons? strategy_class=EcephysCurrentSourceDensityStrategy& object_id={}& object_class=EcephysSession& job_queue_name=ECEPHYS_CURRENT_SOURCE_DENSITY_QUEUE - '''.format(host, session_id).split()) + """.format(host, session_id).split() + ) response = requests.get(request_str) data = response.json() - if data['num_trials'] == 'null': - data['num_trials'] = None + if data["num_trials"] == "null": + data["num_trials"] = None else: - data['num_trials'] = int(data['num_trials']) + data["num_trials"] = int(data["num_trials"]) - data['pre_stimulus_time'] = float(data['pre_stimulus_time']) - data['post_stimulus_time'] = float(data['post_stimulus_time']) - data['surface_channel_adjustment'] = int( - data['surface_channel_adjustment'] - ) + data["pre_stimulus_time"] = float(data["pre_stimulus_time"]) + data["post_stimulus_time"] = float(data["post_stimulus_time"]) + data["surface_channel_adjustment"] = int(data["surface_channel_adjustment"]) - for probe in data['probes']: - probe['surface_channel_adjustment'] = int( - probe['surface_channel_adjustment'] - ) - probe['csd_output_path'] = os.path.join( - output_root, os.path.split(probe['csd_output_path'])[-1] - ) - probe['phase'] = str(probe['phase']) + for probe in data["probes"]: + probe["surface_channel_adjustment"] = int(probe["surface_channel_adjustment"]) + probe["csd_output_path"] = os.path.join(output_root, os.path.split(probe["csd_output_path"])[-1]) + probe["phase"] = str(probe["phase"]) return data def run_csd(args: dict) -> dict: - - stimulus_table = pd.read_csv(args['stimulus']['stimulus_table_path']) + stimulus_table = pd.read_csv(args["stimulus"]["stimulus_table_path"]) # backwards compatibility - stimulus_table['stimulus_name'] = stimulus_table['stimulus_name'].apply( - lambda x: args['stimulus']['key'] if x == 'flashes' else x) - if args['start_field'] not in stimulus_table: - stimulus_table = stimulus_table.rename( - columns={'Start': args['start_field']}) + stimulus_table["stimulus_name"] = stimulus_table["stimulus_name"].apply( + lambda x: args["stimulus"]["key"] if x == "flashes" else x + ) + if args["start_field"] not in stimulus_table: + stimulus_table = stimulus_table.rename(columns={"Start": args["start_field"]}) probewise_outputs = [] - for probe_idx, probe in enumerate(args['probes']): - logging.info('Processing probe: {} (index: {})'.format(probe['name'], - probe_idx)) + for probe_idx, probe in enumerate(args["probes"]): + logging.info("Processing probe: {} (index: {})".format(probe["name"], probe_idx)) - time_step = 1.0 / probe['sampling_rate'] - logging.info('Calculated time step: {}'.format(time_step)) + time_step = 1.0 / probe["sampling_rate"] + logging.info("Calculated time step: {}".format(time_step)) - logging.info('Extracting trial windows') + logging.info("Extracting trial windows") trial_windows, relative_window = extract_trial_windows( stimulus_table=stimulus_table, - stimulus_name=args['stimulus']['key'], + stimulus_name=args["stimulus"]["key"], time_step=time_step, - pre_stimulus_time=args['pre_stimulus_time'], - post_stimulus_time=args['post_stimulus_time'], - num_trials=args['num_trials'], - stimulus_index=args['stimulus']['index'], - start_field=args['start_field'] + pre_stimulus_time=args["pre_stimulus_time"], + post_stimulus_time=args["post_stimulus_time"], + num_trials=args["num_trials"], + stimulus_index=args["stimulus"]["index"], + start_field=args["start_field"], ) - logging.info('Loading LFP data') - lfp_data_file = ContinuousFile(probe['lfp_data_path'], - probe['lfp_timestamps_path'], - probe['total_channels']) - lfp_raw, timestamps = lfp_data_file.load( - memmap=args['memmap'], - memmap_thresh=args['memmap_thresh'] - ) + logging.info("Loading LFP data") + lfp_data_file = ContinuousFile(probe["lfp_data_path"], probe["lfp_timestamps_path"], probe["total_channels"]) + lfp_raw, timestamps = lfp_data_file.load(memmap=args["memmap"], memmap_thresh=args["memmap_thresh"]) - if probe['phase'].lower() == '3a': + if probe["phase"].lower() == "3a": lfp_channels = lfp_data_file.get_lfp_channel_order() else: - lfp_channels = np.arange(0, probe['total_channels']) + lfp_channels = np.arange(0, probe["total_channels"]) lfp_referenced = remove_lfp_noise( lfp=lfp_raw, - surface_channel=probe['surface_channel'], + surface_channel=probe["surface_channel"], channel_numbers=lfp_channels, - max_out_of_brain_channels=args['max_out_of_brain_channels'] + max_out_of_brain_channels=args["max_out_of_brain_channels"], ) - logging.info('Accumulating LFP data') + logging.info("Accumulating LFP data") accumulated_lfp_data = accumulate_lfp_data( timestamps=timestamps, lfp_raw=lfp_referenced, lfp_channels=lfp_channels, trial_windows=trial_windows, - volts_per_bit=args['volts_per_bit'] + volts_per_bit=args["volts_per_bit"], ) - logging.info('Removing noisy and reference channels') + logging.info("Removing noisy and reference channels") clean_lfp, clean_channels = select_good_channels( lfp=accumulated_lfp_data, - reference_channels=probe['reference_channels'], - noisy_channel_threshold=args['noisy_channel_threshold'] + reference_channels=probe["reference_channels"], + noisy_channel_threshold=args["noisy_channel_threshold"], ) - logging.info('Bandpass filtering LFP channel data') - filt_lfp = filter_lfp_channels(lfp=clean_lfp, - sampling_rate=probe['sampling_rate'], - filter_cuts=args['filter_cuts'], - filter_order=args['filter_order']) - - logging.info('Interpolating LFP channel locations') - actual_locs = make_actual_channel_locations( - 0, - accumulated_lfp_data.shape[1] + logging.info("Bandpass filtering LFP channel data") + filt_lfp = filter_lfp_channels( + lfp=clean_lfp, + sampling_rate=probe["sampling_rate"], + filter_cuts=args["filter_cuts"], + filter_order=args["filter_order"], ) + + logging.info("Interpolating LFP channel locations") + actual_locs = make_actual_channel_locations(0, accumulated_lfp_data.shape[1]) clean_actual_locs = actual_locs[clean_channels, :] - interp_locs = make_interp_channel_locations( - 0, - accumulated_lfp_data.shape[1] - ) + interp_locs = make_interp_channel_locations(0, accumulated_lfp_data.shape[1]) if len(clean_channels) == 0: - logging.error(f'There are no clean channels. Skipping probe ' - f'{probe["name"]}') - probewise_outputs.append({ - 'name': probe['name'], - 'csd_path': None, - 'clean_channels': clean_channels.tolist() - }) + logging.error(f"There are no clean channels. Skipping probe {probe['name']}") + probewise_outputs.append( + {"name": probe["name"], "csd_path": None, "clean_channels": clean_channels.tolist()} + ) continue try: interp_lfp, spacing = interp_channel_locs( - lfp=filt_lfp, - actual_locs=clean_actual_locs, - interp_locs=interp_locs + lfp=filt_lfp, actual_locs=clean_actual_locs, interp_locs=interp_locs ) except QhullError: - logging.error(f'There are only {len(clean_channels)} ' - f'clean channels, which is not enough for ' - f'interpolation. Skipping probe {probe["name"]}') - probewise_outputs.append({ - 'name': probe['name'], - 'csd_path': None, - 'clean_channels': clean_channels.tolist() - }) + logging.error( + f"There are only {len(clean_channels)} " + f"clean channels, which is not enough for " + f"interpolation. Skipping probe {probe['name']}" + ) + probewise_outputs.append( + {"name": probe["name"], "csd_path": None, "clean_channels": clean_channels.tolist()} + ) continue - logging.info('Averaging LFPs over trials') + logging.info("Averaging LFPs over trials") trial_mean_lfp = np.nanmean(interp_lfp, axis=0) - logging.info('Computing CSD') - current_source_density, csd_channels = compute_csd( - trial_mean_lfp=trial_mean_lfp, - spacing=spacing - ) + logging.info("Computing CSD") + current_source_density, csd_channels = compute_csd(trial_mean_lfp=trial_mean_lfp, spacing=spacing) - logging.info('Saving data') + logging.info("Saving data") write_csd_to_h5( path=probe["csd_output_path"], csd=current_source_density, relative_window=relative_window, channels=csd_channels, csd_locations=interp_locs, - stimulus_name=args['stimulus']['key'], + stimulus_name=args["stimulus"]["key"], stimulus_index=args["stimulus"]["index"], - num_trials=args["num_trials"] + num_trials=args["num_trials"], ) - probewise_outputs.append({ - 'name': probe['name'], - 'csd_path': probe['csd_output_path'], - 'clean_channels': clean_channels.tolist() - }) + probewise_outputs.append( + {"name": probe["name"], "csd_path": probe["csd_output_path"], "clean_channels": clean_channels.tolist()} + ) return { - 'probe_outputs': probewise_outputs, + "probe_outputs": probewise_outputs, } -def write_csd_to_h5(path: Path, csd: np.ndarray, relative_window, - channels: np.ndarray, csd_locations: np.ndarray, - stimulus_name: str, stimulus_index: Optional[int], - num_trials: Optional[int]): +def write_csd_to_h5( + path: Path, + csd: np.ndarray, + relative_window, + channels: np.ndarray, + csd_locations: np.ndarray, + stimulus_name: str, + stimulus_index: Optional[int], + num_trials: Optional[int], +): with h5py.File(str(path), "w") as output: output.create_dataset("current_source_density", data=csd) output.create_dataset("timestamps", data=relative_window) @@ -243,12 +216,8 @@ def write_csd_to_h5(path: Path, csd: np.ndarray, relative_window, def main(): - - logging.basicConfig(format=('%(asctime)s:%(funcName)s' - ':%(levelname)s:%(message)s'), - level=logging.INFO) - parser = optional_lims_inputs(sys.argv, InputParameters, - OutputParameters, get_inputs_from_lims) + logging.basicConfig(format=("%(asctime)s:%(funcName)s:%(levelname)s:%(message)s"), level=logging.INFO) + parser = optional_lims_inputs(sys.argv, InputParameters, OutputParameters, get_inputs_from_lims) output = run_csd(parser.args) write_or_print_outputs(output, parser) diff --git a/allensdk/brain_observatory/ecephys/current_source_density/_current_source_density.py b/allensdk/brain_observatory/ecephys/current_source_density/_current_source_density.py index 04e9aad5f3..8273a6537f 100644 --- a/allensdk/brain_observatory/ecephys/current_source_density/_current_source_density.py +++ b/allensdk/brain_observatory/ecephys/current_source_density/_current_source_density.py @@ -16,12 +16,12 @@ def extract_trial_windows( post_stimulus_time: float, num_trials: Optional[int] = None, stimulus_index: Optional[int] = None, - name_field: str = 'stimulus_name', - index_field: str = 'stimulus_index', - start_field: str = 'Start', - end_field: str = 'End' + name_field: str = "stimulus_name", + index_field: str = "stimulus_index", + start_field: str = "Start", + end_field: str = "End", ) -> Tuple[List[np.ndarray], np.ndarray]: - '''Obtains time interval surrounding stimulus sweep onsets + """Obtains time interval surrounding stimulus sweep onsets Parameters ---------- @@ -62,42 +62,38 @@ def extract_trial_windows( trial's onset. relative_times : numpy.ndarray The basic time domain, centered on 0. - ''' + """ if stimulus_index is None: - stimulus_index = np.amin(stimulus_table[stimulus_table[name_field] - == stimulus_name][index_field].values) + stimulus_index = np.amin(stimulus_table[stimulus_table[name_field] == stimulus_name][index_field].values) - stimulus_name_mask = (stimulus_table[name_field] == stimulus_name) - stimulus_index_mask = (stimulus_table[index_field] == stimulus_index) + stimulus_name_mask = stimulus_table[name_field] == stimulus_name + stimulus_index_mask = stimulus_table[index_field] == stimulus_index trials = stimulus_table[stimulus_name_mask & stimulus_index_mask] if num_trials is not None: trials = trials.iloc[:num_trials, :] - trials = trials.to_dict('records') + trials = trials.to_dict("records") - relative_times = np.arange(-pre_stimulus_time, - post_stimulus_time, - time_step) + relative_times = np.arange(-pre_stimulus_time, post_stimulus_time, time_step) trial_windows = [relative_times + trial[start_field] for trial in trials] - msg = 'calculated relative timestamps: {} ({} timestamps per trial)' + msg = "calculated relative timestamps: {} ({} timestamps per trial)" logging.info(msg.format(relative_times, len(relative_times))) - msg = 'setup {} trial windows spanning {} to {}' - logging.info(msg.format(len(trial_windows), - trial_windows[0][0], - trial_windows[-1][-1])) + msg = "setup {} trial windows spanning {} to {}" + logging.info(msg.format(len(trial_windows), trial_windows[0][0], trial_windows[-1][-1])) return (trial_windows, relative_times) -def accumulate_lfp_data(timestamps: np.ndarray, lfp_raw: np.ndarray, - lfp_channels: np.ndarray, - trial_windows: List[np.ndarray], - volts_per_bit: float = 1.0, - extractor_factory: Callable = ( - regular_grid_extractor_factory) - ) -> np.ndarray: - ''' Extracts slices of LFP data at defined channels and times. +def accumulate_lfp_data( + timestamps: np.ndarray, + lfp_raw: np.ndarray, + lfp_channels: np.ndarray, + trial_windows: List[np.ndarray], + volts_per_bit: float = 1.0, + extractor_factory: Callable = (regular_grid_extractor_factory), +) -> np.ndarray: + """Extracts slices of LFP data at defined channels and times. Parameters ---------- @@ -121,17 +117,16 @@ def accumulate_lfp_data(timestamps: np.ndarray, lfp_raw: np.ndarray, accumulated : numpy.ndarray Extracted data. Dimensions are trials X channels X samples - ''' + """ num_samples = min(len(tw) for tw in trial_windows) num_trials = len(trial_windows) num_channels = len(lfp_channels) - accumulated = np.zeros((num_trials, num_channels, num_samples), - dtype=lfp_raw.dtype) + accumulated = np.zeros((num_trials, num_channels, num_samples), dtype=lfp_raw.dtype) for channel_idx, chan in enumerate(lfp_channels): - logging.info('extracting lfp for channel {}'.format(chan)) + logging.info("extracting lfp for channel {}".format(chan)) extractor = extractor_factory(timestamps, lfp_raw, chan) for trial_index, trial_window in enumerate(trial_windows): @@ -141,14 +136,13 @@ def accumulate_lfp_data(timestamps: np.ndarray, lfp_raw: np.ndarray, current = np.around(current).astype(accumulated.dtype) accumulated[trial_index, channel_idx, :] = current - msg = 'extracted lfp data for {} trials, {} channels, and {} samples' + msg = "extracted lfp data for {} trials, {} channels, and {} samples" logging.info(msg.format(*accumulated.shape)) return accumulated * volts_per_bit -def compute_csd(trial_mean_lfp: np.ndarray, - spacing: float) -> Tuple[np.ndarray, np.ndarray]: - '''Compute current source density for real or virtual channels from +def compute_csd(trial_mean_lfp: np.ndarray, spacing: float) -> Tuple[np.ndarray, np.ndarray]: + """Compute current source density for real or virtual channels from a neuropixels probe. Compute a second spatial derivative along the probe length @@ -171,16 +165,12 @@ def compute_csd(trial_mean_lfp: np.ndarray, Current source density. Dimensions are channels X time samples. csd_channels: numpy.ndarray Array of channel indices for CSD. - ''' + """ # Need to pad lfp channels for Laplacian approx. - padded_lfp = np.pad(trial_mean_lfp, - pad_width=((1, 1), (0, 0)), - mode='edge') + padded_lfp = np.pad(trial_mean_lfp, pad_width=((1, 1), (0, 0)), mode="edge") - csd = (1 / (spacing ** 2)) * (padded_lfp[2:, :] - - (2 * padded_lfp[1:-1, :]) - + padded_lfp[:-2, :]) + csd = (1 / (spacing**2)) * (padded_lfp[2:, :] - (2 * padded_lfp[1:-1, :]) + padded_lfp[:-2, :]) csd_channels = np.arange(0, trial_mean_lfp.shape[0]) diff --git a/allensdk/brain_observatory/ecephys/current_source_density/_filter_utils.py b/allensdk/brain_observatory/ecephys/current_source_density/_filter_utils.py index 06e15394a7..0c264ec459 100644 --- a/allensdk/brain_observatory/ecephys/current_source_density/_filter_utils.py +++ b/allensdk/brain_observatory/ecephys/current_source_density/_filter_utils.py @@ -4,10 +4,9 @@ from scipy import signal -def select_good_channels(lfp: np.ndarray, - reference_channels: List[int], - noisy_channel_threshold: float - ) -> Tuple[np.ndarray, np.ndarray]: +def select_good_channels( + lfp: np.ndarray, reference_channels: List[int], noisy_channel_threshold: float +) -> Tuple[np.ndarray, np.ndarray]: """Remove reference channels and channels that are too noisy from lfp data. Parameters @@ -31,12 +30,7 @@ def select_good_channels(lfp: np.ndarray, channel_variance = np.mean(np.std(lfp, 2), 0) noisy_channels = np.where(channel_variance > noisy_channel_threshold)[0] - to_remove = np.concatenate( - ( - np.array(reference_channels), - noisy_channels - ) - ).astype(int) + to_remove = np.concatenate((np.array(reference_channels), noisy_channels)).astype(int) good_indices = np.delete(np.arange(0, lfp.shape[1]), to_remove) # Remove noisy or reference channels (axis=1) @@ -45,11 +39,10 @@ def select_good_channels(lfp: np.ndarray, return (cleaned_lfp, good_indices) -def filter_lfp_channels(lfp: np.ndarray, - sampling_rate: float, - filter_cuts: List[float], - filter_order: int) -> np.ndarray: - '''Bandpass filter lfp channel data. +def filter_lfp_channels( + lfp: np.ndarray, sampling_rate: float, filter_cuts: List[float], filter_order: int +) -> np.ndarray: + """Bandpass filter lfp channel data. Parameters ---------- @@ -68,11 +61,11 @@ def filter_lfp_channels(lfp: np.ndarray, filtered_lfp: numpy.ndarray LFP that has been bandpassed filtered along the sample axis. Still in the form of: trials x channels x time samples - ''' + """ - wn = (sampling_rate / 2) + wn = sampling_rate / 2 filter_cutoffs = np.array(filter_cuts) / wn - b, a = signal.butter(filter_order, filter_cutoffs, 'bandpass') + b, a = signal.butter(filter_order, filter_cutoffs, "bandpass") # Bandpass filter time samples (axis=2) filtered_lfp = signal.filtfilt(b, a, lfp, axis=2) diff --git a/allensdk/brain_observatory/ecephys/current_source_density/_interpolation_utils.py b/allensdk/brain_observatory/ecephys/current_source_density/_interpolation_utils.py index 6f6545af49..909a41b715 100644 --- a/allensdk/brain_observatory/ecephys/current_source_density/_interpolation_utils.py +++ b/allensdk/brain_observatory/ecephys/current_source_density/_interpolation_utils.py @@ -5,11 +5,10 @@ from scipy.interpolate import RegularGridInterpolator, griddata -def regular_grid_extractor_factory(timestamps: np.ndarray, - lfp_raw: np.ndarray, - channel: int, - method: str = 'linear') -> np.ndarray: - '''Builds an LFP data extractor using interpolation on a regular grid +def regular_grid_extractor_factory( + timestamps: np.ndarray, lfp_raw: np.ndarray, channel: int, method: str = "linear" +) -> np.ndarray: + """Builds an LFP data extractor using interpolation on a regular grid Ignores timestamps less than zero (which result from unaligned data segments) @@ -30,20 +29,21 @@ def regular_grid_extractor_factory(timestamps: np.ndarray, ------- numpy.ndarray LFP data that has been interpolated to a regular grid. - ''' + """ - valid_timestamps = (timestamps >= 0) + valid_timestamps = timestamps >= 0 - return RegularGridInterpolator((timestamps[valid_timestamps],), - lfp_raw[valid_timestamps, channel], - method=method, - bounds_error=False, - fill_value=np.nan) + return RegularGridInterpolator( + (timestamps[valid_timestamps],), + lfp_raw[valid_timestamps, channel], + method=method, + bounds_error=False, + fill_value=np.nan, + ) -def make_actual_channel_locations(min_chan: int = 0, - max_chan: int = 384) -> np.ndarray: - '''Generate x/y locations of Neuropixels recording sites. +def make_actual_channel_locations(min_chan: int = 0, max_chan: int = 384) -> np.ndarray: + """Generate x/y locations of Neuropixels recording sites. 0 8 16 24 32 40 48 60 * - - - * - - @@ -66,7 +66,7 @@ def make_actual_channel_locations(min_chan: int = 0, actual_channel_locations: numpy.ndarray column 1 = x positions in microns column 2 = y positions in microns - ''' + """ actual_channel_locations = np.zeros((max_chan, 2)) @@ -79,9 +79,8 @@ def make_actual_channel_locations(min_chan: int = 0, return actual_channel_locations[min_chan:, :] -def make_interp_channel_locations(min_chan: int = 0, - max_chan: int = 384) -> np.ndarray: - '''Generate x/y locations for interpolated Neuropixels recording sites. +def make_interp_channel_locations(min_chan: int = 0, max_chan: int = 384) -> np.ndarray: + """Generate x/y locations for interpolated Neuropixels recording sites. This version just returns the central column of interpolated sites. @@ -108,7 +107,7 @@ def make_interp_channel_locations(min_chan: int = 0, interp_channel_locations: numpy.ndarray column 1 = interpolated x positions in microns column 2 = y positions in microns - ''' + """ interp_channel_locations = np.zeros((max_chan, 2)) @@ -119,11 +118,10 @@ def make_interp_channel_locations(min_chan: int = 0, return interp_channel_locations[min_chan:, :] -def interp_channel_locs(lfp: np.ndarray, - actual_locs: np.ndarray, - interp_locs: np.ndarray, - method: str = 'cubic') -> Tuple[np.ndarray, float]: - '''Interpolates single-trial lfp channel locations to account for +def interp_channel_locs( + lfp: np.ndarray, actual_locs: np.ndarray, interp_locs: np.ndarray, method: str = "cubic" +) -> Tuple[np.ndarray, float]: + """Interpolates single-trial lfp channel locations to account for channel stagger. Parameters @@ -149,27 +147,30 @@ def interp_channel_locs(lfp: np.ndarray, spacing: float Distance between new interpolated virtual channel sites (in millimeters) - ''' + """ if lfp.shape[1] != actual_locs.shape[0]: - e_msg = (f"Number of 'lfp' channels ({lfp.shape[1]}) does not " - f"match number of 'actual_locs' ({actual_locs.shape[0]})!") + e_msg = ( + f"Number of 'lfp' channels ({lfp.shape[1]}) does not " + f"match number of 'actual_locs' ({actual_locs.shape[0]})!" + ) raise RuntimeError(e_msg) spacing = np.mean(np.diff(interp_locs[:, 1])) / 1000 - interp_lfp = np.zeros((lfp.shape[0], # number of interp trials - interp_locs.shape[0], # number of interp channels - lfp.shape[2])) # number of interp samples + interp_lfp = np.zeros( + ( + lfp.shape[0], # number of interp trials + interp_locs.shape[0], # number of interp channels + lfp.shape[2], + ) + ) # number of interp samples for trial in range(lfp.shape[0]): # trials trial_data = lfp[trial, :, :] for t in range(0, lfp.shape[2]): # time samples - interp_lfp[trial, :, t] = griddata(points=actual_locs, - values=trial_data[:, t], - xi=interp_locs, - method=method, - fill_value=0, - rescale=False) + interp_lfp[trial, :, t] = griddata( + points=actual_locs, values=trial_data[:, t], xi=interp_locs, method=method, fill_value=0, rescale=False + ) return (interp_lfp, spacing) diff --git a/allensdk/brain_observatory/ecephys/current_source_density/_schemas.py b/allensdk/brain_observatory/ecephys/current_source_density/_schemas.py index 58808d3808..2e4e174a4c 100644 --- a/allensdk/brain_observatory/ecephys/current_source_density/_schemas.py +++ b/allensdk/brain_observatory/ecephys/current_source_density/_schemas.py @@ -5,113 +5,81 @@ class ProbeInputParameters(DefaultSchema): - name = String(required=True, help='Identifier for this probe.') - lfp_data_path = String(required=True, - help='Path to lfp data for this probe') - lfp_timestamps_path = String(required=True, - help="Path to aligned lfp timestamps for " - "this probe.") - surface_channel = Int(required=True, - help='Estimate of surface (pia boundary) channel ' - 'index') - reference_channels = List(Int, many=True, - help='Indices of reference channels for this ' - 'probe') - csd_output_path = String(required=True, - help='CSD output will be written here.') - sampling_rate = Float(required=True, - help='sampling rate assessed on master clock') - total_channels = Int(default=384, - help='Total channel count for this probe.') - surface_channel_adjustment = Int(default=40, - help='Erring up in the surface channel ' - 'estimate is less dangerous for ' - 'the CSD calculation than erring ' - 'down, so an adjustment is ' - 'provided.') - spacing = Float(default=0.04, - help='distance (in millimiters) between ' - 'lengthwise-adjacent rows of recording sites on ' - 'this probe.') - phase = String(required=True, - help='The probe type (3a or PXI) which determines if ' - 'channels need to be reordered') + name = String(required=True, help="Identifier for this probe.") + lfp_data_path = String(required=True, help="Path to lfp data for this probe") + lfp_timestamps_path = String(required=True, help="Path to aligned lfp timestamps for this probe.") + surface_channel = Int(required=True, help="Estimate of surface (pia boundary) channel index") + reference_channels = List(Int, many=True, help="Indices of reference channels for this probe") + csd_output_path = String(required=True, help="CSD output will be written here.") + sampling_rate = Float(required=True, help="sampling rate assessed on master clock") + total_channels = Int(default=384, help="Total channel count for this probe.") + surface_channel_adjustment = Int( + default=40, + help="Erring up in the surface channel " + "estimate is less dangerous for " + "the CSD calculation than erring " + "down, so an adjustment is " + "provided.", + ) + spacing = Float( + default=0.04, + help="distance (in millimiters) between lengthwise-adjacent rows of recording sites on this probe.", + ) + phase = String(required=True, help="The probe type (3a or PXI) which determines if channels need to be reordered") class StimulusInputParameters(DefaultSchema): - stimulus_table_path = String(required=True, help='Path to stimulus table') - key = String(required=True, - help='CSD is calculated from a specific stimulus, defined (' - 'in part) by this key.') - index = Int(default=None, allow_none=True, - help='CSD is calculated from a specific stimulus, defined (' - 'in part) by this index.') + stimulus_table_path = String(required=True, help="Path to stimulus table") + key = String(required=True, help="CSD is calculated from a specific stimulus, defined (in part) by this key.") + index = Int( + default=None, + allow_none=True, + help="CSD is calculated from a specific stimulus, defined (in part) by this index.", + ) class InputParameters(ArgSchema): - stimulus = Nested(StimulusInputParameters, required=True, - help='Defines the stimulus from which CSD is calculated') - probes = Nested(ProbeInputParameters, many=True, required=True, - help='Probewise parameters.') - pre_stimulus_time = Float(required=True, - help='how much time pre stimulus onset is used ' - 'for CSD calculation ') - post_stimulus_time = Float(required=True, - help='how much time post stimulus onset is ' - 'used for CSD calculation ') - num_trials = Int(default=None, allow_none=True, - help='Number of trials after stimulus onset from which ' - 'to compute CSD') - volts_per_bit = Float(default=1.0, - help='If the data are not in units of volts, ' - 'they must be converted. In the past, ' - 'this value was 0.195') - memmap = Bool(default=False, - help='whether to memory map the data file on disk or load ' - 'it directly to main memory') - memmap_thresh = Float(default=np.inf, - help='files larger than this threshold (bytes) ' - 'will be memmapped, regardless of the memmap ' - 'setting.') - filter_cuts = List(Float, default=[5.0, 150.0], - cli_as_single_argument=True, - help='Cutoff frequencies for bandpass filter') - filter_order = Int(default=5, help='Order for bandpass filter') - reorder_channels = Bool(default=True, - help='Determines whether LFP channels should be ' - 're-ordered') - noisy_channel_threshold = Float(default=1500.0, - help='Threshold for removing noisy ' - 'channels from analysis') + stimulus = Nested(StimulusInputParameters, required=True, help="Defines the stimulus from which CSD is calculated") + probes = Nested(ProbeInputParameters, many=True, required=True, help="Probewise parameters.") + pre_stimulus_time = Float(required=True, help="how much time pre stimulus onset is used for CSD calculation ") + post_stimulus_time = Float(required=True, help="how much time post stimulus onset is used for CSD calculation ") + num_trials = Int( + default=None, allow_none=True, help="Number of trials after stimulus onset from which to compute CSD" + ) + volts_per_bit = Float( + default=1.0, + help="If the data are not in units of volts, they must be converted. In the past, this value was 0.195", + ) + memmap = Bool(default=False, help="whether to memory map the data file on disk or load it directly to main memory") + memmap_thresh = Float( + default=np.inf, + help="files larger than this threshold (bytes) will be memmapped, regardless of the memmap setting.", + ) + filter_cuts = List( + Float, default=[5.0, 150.0], cli_as_single_argument=True, help="Cutoff frequencies for bandpass filter" + ) + filter_order = Int(default=5, help="Order for bandpass filter") + reorder_channels = Bool(default=True, help="Determines whether LFP channels should be re-ordered") + noisy_channel_threshold = Float(default=1500.0, help="Threshold for removing noisy channels from analysis") max_out_of_brain_channels = Int( default=50, - help='Rereferencing can sometimes fail for experiments with shallow ' - 'probe insertions as the uppermost channels are in air and not ' - 'agar. This places a limit on the number of channels to use for ' - 're-referencing.' - ) - start_field = String( - default='Start', - help='Column from which to extract start times.' + help="Rereferencing can sometimes fail for experiments with shallow " + "probe insertions as the uppermost channels are in air and not " + "agar. This places a limit on the number of channels to use for " + "re-referencing.", ) + start_field = String(default="Start", help="Column from which to extract start times.") class ProbeOutputParameters(DefaultSchema): - name = String(required=True, help='Identifier for this probe.') - csd_path = String(required=True, - allow_none=True, - help='Path to current source density file.') - clean_channels = List(Int, required=True, - help='List of channels used in CSD calculation') + name = String(required=True, help="Identifier for this probe.") + csd_path = String(required=True, allow_none=True, help="Path to current source density file.") + clean_channels = List(Int, required=True, help="List of channels used in CSD calculation") class OutputSchema(DefaultSchema): - input_parameters = Nested(InputParameters, - description=("Input parameters the module " - "was run with"), - required=True) + input_parameters = Nested(InputParameters, description=("Input parameters the module was run with"), required=True) class OutputParameters(OutputSchema): - probe_outputs = Nested(ProbeOutputParameters, many=True, required=True, - help='probewise outputs') + probe_outputs = Nested(ProbeOutputParameters, many=True, required=True, help="probewise outputs") diff --git a/allensdk/brain_observatory/ecephys/data_objects/trials.py b/allensdk/brain_observatory/ecephys/data_objects/trials.py index 3b434fea1e..6a9892d03c 100644 --- a/allensdk/brain_observatory/ecephys/data_objects/trials.py +++ b/allensdk/brain_observatory/ecephys/data_objects/trials.py @@ -2,21 +2,12 @@ import numpy as np -from allensdk.brain_observatory.behavior.data_objects.trials.trial import ( - Trial) -from allensdk.brain_observatory.behavior.data_objects.\ - trials.trials import Trials +from allensdk.brain_observatory.behavior.data_objects.trials.trial import Trial +from allensdk.brain_observatory.behavior.data_objects.trials.trials import Trials class VBNTrial(Trial): - - def calculate_change_frame( - self, - event_dict: dict, - go: bool, - catch: bool, - auto_rewarded: bool) -> Union[int, float]: - + def calculate_change_frame(self, event_dict: dict, go: bool, catch: bool, auto_rewarded: bool) -> Union[int, float]: """ Calculate the frame index of a stimulus change associated with a specific event. @@ -46,9 +37,9 @@ def calculate_change_frame( """ if go or auto_rewarded: - change_frame = event_dict.get(('stimulus_changed', ''))['frame'] + change_frame = event_dict.get(("stimulus_changed", ""))["frame"] elif catch: - change_frame = event_dict.get(('sham_change', ''))['frame'] + change_frame = event_dict.get(("sham_change", ""))["frame"] else: change_frame = float("nan") @@ -85,7 +76,7 @@ def add_change_time(self, trial_dict: dict) -> Tuple[dict, float]: ---- Modified trial_dict in-place, in addition to returning it """ - change_frame = trial_dict['change_frame'] + change_frame = trial_dict["change_frame"] if np.isnan(change_frame): change_time = np.nan else: @@ -93,12 +84,11 @@ def add_change_time(self, trial_dict: dict) -> Tuple[dict, float]: change_frame = int(change_frame) change_time = no_delay.value[change_frame] - trial_dict['change_time_no_display_delay'] = change_time + trial_dict["change_time_no_display_delay"] = change_time return trial_dict, change_time class VBNTrials(Trials): - @classmethod def trial_class(cls): """ @@ -111,14 +101,29 @@ def columns_to_output(cls) -> List[str]: """ Return the list of columns to be output in this table """ - return ['initial_image_name', 'change_image_name', - 'stimulus_change', 'change_time_no_display_delay', - 'go', 'catch', 'lick_times', 'response_time', - 'reward_time', 'reward_volume', - 'hit', 'false_alarm', 'miss', 'correct_reject', - 'aborted', 'auto_rewarded', 'change_frame', - 'start_time', 'stop_time', 'trial_length'] + return [ + "initial_image_name", + "change_image_name", + "stimulus_change", + "change_time_no_display_delay", + "go", + "catch", + "lick_times", + "response_time", + "reward_time", + "reward_volume", + "hit", + "false_alarm", + "miss", + "correct_reject", + "aborted", + "auto_rewarded", + "change_frame", + "start_time", + "stop_time", + "trial_length", + ] @property def change_time(self): - return self.data['change_time_no_display_delay'] + return self.data["change_time_no_display_delay"] diff --git a/allensdk/brain_observatory/ecephys/ecephys_project_api/ecephys_project_api.py b/allensdk/brain_observatory/ecephys/ecephys_project_api/ecephys_project_api.py index 6db0e473d3..2942f20139 100644 --- a/allensdk/brain_observatory/ecephys/ecephys_project_api/ecephys_project_api.py +++ b/allensdk/brain_observatory/ecephys/ecephys_project_api/ecephys_project_api.py @@ -4,20 +4,16 @@ import pandas as pd -# TODO: This should be a generic over the type of the values, but there is not -# good support currently for numpy and pandas type annotations +# TODO: This should be a generic over the type of the values, but there is not +# good support currently for numpy and pandas type annotations # we should investigate numpy and pandas typing support and migrate # https://github.com/numpy/numpy-stubs -# https://github.com/pandas-dev/pandas/blob/master/pandas/_typing.py +# https://github.com/pandas-dev/pandas/blob/master/pandas/_typing.py ArrayLike = TypeVar("ArrayLike", list, np.ndarray, pd.Series, tuple) class EcephysProjectApi: - def get_sessions( - self, - session_ids: Optional[ArrayLike] = None, - published_at: Optional[str] = None - ): + def get_sessions(self, session_ids: Optional[ArrayLike] = None, published_at: Optional[str] = None): raise NotImplementedError() def get_session_data(self, session_id: int) -> Iterable: @@ -27,29 +23,29 @@ def get_isi_experiments(self, *args, **kwargs): raise NotImplementedError() def get_units( - self, - unit_ids: Optional[ArrayLike] = None, - channel_ids: Optional[ArrayLike] = None, - probe_ids: Optional[ArrayLike] = None, - session_ids: Optional[ArrayLike] = None, - published_at: Optional[str] = None + self, + unit_ids: Optional[ArrayLike] = None, + channel_ids: Optional[ArrayLike] = None, + probe_ids: Optional[ArrayLike] = None, + session_ids: Optional[ArrayLike] = None, + published_at: Optional[str] = None, ): raise NotImplementedError() def get_channels( - self, - channel_ids: Optional[ArrayLike] = None, - probe_ids: Optional[ArrayLike] = None, - session_ids: Optional[ArrayLike] = None, - published_at: Optional[str] = None + self, + channel_ids: Optional[ArrayLike] = None, + probe_ids: Optional[ArrayLike] = None, + session_ids: Optional[ArrayLike] = None, + published_at: Optional[str] = None, ): raise NotImplementedError() def get_probes( - self, - probe_ids: Optional[ArrayLike] = None, - session_ids: Optional[ArrayLike] = None, - published_at: Optional[str] = None + self, + probe_ids: Optional[ArrayLike] = None, + session_ids: Optional[ArrayLike] = None, + published_at: Optional[str] = None, ): raise NotImplementedError() @@ -63,9 +59,9 @@ def get_natural_scene_template(self, number) -> Iterable: raise NotImplementedError() def get_unit_analysis_metrics( - self, - unit_ids: Optional[ArrayLike] = None, - ecephys_session_ids: Optional[ArrayLike] = None, - session_types: Optional[ArrayLike] = None + self, + unit_ids: Optional[ArrayLike] = None, + ecephys_session_ids: Optional[ArrayLike] = None, + session_types: Optional[ArrayLike] = None, ) -> pd.DataFrame: - raise NotImplementedError() \ No newline at end of file + raise NotImplementedError() diff --git a/allensdk/brain_observatory/ecephys/ecephys_project_api/ecephys_project_fixed_api.py b/allensdk/brain_observatory/ecephys/ecephys_project_api/ecephys_project_fixed_api.py index c67c89c33c..1c2eb15cbc 100644 --- a/allensdk/brain_observatory/ecephys/ecephys_project_api/ecephys_project_fixed_api.py +++ b/allensdk/brain_observatory/ecephys/ecephys_project_api/ecephys_project_fixed_api.py @@ -6,7 +6,6 @@ class MissingDataError(ValueError): class EcephysProjectFixedApi(EcephysProjectApi): - def get_session_data(self, session_id, *args, **kwargs): raise MissingDataError(f"data for session {session_id} not found!") diff --git a/allensdk/brain_observatory/ecephys/ecephys_project_api/ecephys_project_lims_api.py b/allensdk/brain_observatory/ecephys/ecephys_project_api/ecephys_project_lims_api.py index 6eb6a854cf..d8d13041bb 100644 --- a/allensdk/brain_observatory/ecephys/ecephys_project_api/ecephys_project_lims_api.py +++ b/allensdk/brain_observatory/ecephys/ecephys_project_api/ecephys_project_lims_api.py @@ -12,49 +12,47 @@ class EcephysProjectLimsApi(EcephysProjectApi): - STIMULUS_TEMPLATE_NAMESPACE = "brain_observatory_1.1" def __init__(self, postgres_engine, app_engine): - """ Downloads extracellular ephys data from the Allen Institute's - internal Laboratory Information Management System (LIMS). If you are - on our network you can use this class to get bleeding-edge data into + """Downloads extracellular ephys data from the Allen Institute's + internal Laboratory Information Management System (LIMS). If you are + on our network you can use this class to get bleeding-edge data into an EcephysProjectCache. If not, it won't work at all Parameters ---------- - postgres_engine : - used for making queries against the LIMS postgres database. Must + postgres_engine : + used for making queries against the LIMS postgres database. Must implement: - select : takes a postgres query as a string. Returns a pandas + select : takes a postgres query as a string. Returns a pandas dataframe of results - select_one : takes a postgres query as a string. If there is - exactly one record in the response, returns that record as + select_one : takes a postgres query as a string. If there is + exactly one record in the response, returns that record as a dict. Otherwise returns an empty dict. - app_engine : - used for making queries agains the lims web application. Must + app_engine : + used for making queries agains the lims web application. Must implement: - stream : takes a url as a string. Returns an iterable yielding + stream : takes a url as a string. Returns an iterable yielding the response body as bytes. Notes ----- - You almost certainly want to construct this class by calling + You almost certainly want to construct this class by calling EcephysProjectLimsApi.default() rather than this constructor directly. """ - self.postgres_engine = postgres_engine self.app_engine = app_engine def get_session_data(self, session_id: int) -> Iterable[bytes]: - """ Download an NWB file containing detailed data for an ecephys + """Download an NWB file containing detailed data for an ecephys session. Parameters ---------- - session_id : + session_id : Download an NWB file for this session Returns @@ -86,17 +84,15 @@ def get_session_data(self, session_id: int) -> Iterable[bytes]: ) nwb_id = nwb_response.loc[0, "id"] - return self.app_engine.stream( - f"well_known_files/download/{nwb_id}?wkf_id={nwb_id}" - ) + return self.app_engine.stream(f"well_known_files/download/{nwb_id}?wkf_id={nwb_id}") def get_probe_lfp_data(self, probe_id: int) -> Iterable[bytes]: - """ Download an NWB file containing detailed data for the local field + """Download an NWB file containing detailed data for the local field potential recorded from an ecephys probe. Parameters ---------- - probe_id : + probe_id : Download an NWB file for this probe's LFP Returns @@ -119,7 +115,7 @@ def get_probe_lfp_data(self, probe_id: int) -> Iterable[bytes]: and earp.ecephys_probe_id = {{probe_id}} """, engine=self.postgres_engine.select, - probe_id=probe_id + probe_id=probe_id, ) if nwb_response.shape[0] != 1: @@ -129,40 +125,38 @@ def get_probe_lfp_data(self, probe_id: int) -> Iterable[bytes]: ) nwb_id = nwb_response.loc[0, "id"] - return self.app_engine.stream( - f"well_known_files/download/{nwb_id}?wkf_id={nwb_id}" - ) + return self.app_engine.stream(f"well_known_files/download/{nwb_id}?wkf_id={nwb_id}") def get_units( - self, - unit_ids: Optional[ArrayLike] = None, - channel_ids: Optional[ArrayLike] = None, - probe_ids: Optional[ArrayLike] = None, - session_ids: Optional[ArrayLike] = None, - published_at: Optional[str] = None + self, + unit_ids: Optional[ArrayLike] = None, + channel_ids: Optional[ArrayLike] = None, + probe_ids: Optional[ArrayLike] = None, + session_ids: Optional[ArrayLike] = None, + published_at: Optional[str] = None, ) -> pd.DataFrame: - """ Download a table of records describing sorted ecephys units. + """Download a table of records describing sorted ecephys units. Parameters ---------- - unit_ids : - A collection of integer identifiers for sorted ecephys units. If + unit_ids : + A collection of integer identifiers for sorted ecephys units. If provided, only return records describing these units. - channel_ids : - A collection of integer identifiers for ecephys channels. If - provided, results will be filtered to units recorded from these + channel_ids : + A collection of integer identifiers for ecephys channels. If + provided, results will be filtered to units recorded from these channels. - probe_ids : - A collection of integer identifiers for ecephys probes. If - provided, results will be filtered to units recorded from these + probe_ids : + A collection of integer identifiers for ecephys probes. If + provided, results will be filtered to units recorded from these probes. - session_ids : - A collection of integer identifiers for ecephys sessions. If + session_ids : + A collection of integer identifiers for ecephys sessions. If provided, results will be filtered to units recorded during these sessions. - published_at : - A date (rendered as "YYYY-MM-DD"). If provided, only units - recorded during sessions published before this date will be + published_at : + A date (rendered as "YYYY-MM-DD"). If provided, only units + recorded during sessions published before this date will be returned. Returns @@ -225,34 +219,34 @@ def get_units( channel_ids=channel_ids, probe_ids=probe_ids, session_ids=session_ids, - **_split_published_at(published_at)._asdict() + **_split_published_at(published_at)._asdict(), ) return response.set_index("id", inplace=False) def get_channels( - self, - channel_ids: Optional[ArrayLike] = None, - probe_ids: Optional[ArrayLike] = None, - session_ids: Optional[ArrayLike] = None, - published_at: Optional[str] = None + self, + channel_ids: Optional[ArrayLike] = None, + probe_ids: Optional[ArrayLike] = None, + session_ids: Optional[ArrayLike] = None, + published_at: Optional[str] = None, ) -> pd.DataFrame: - """ Download a table of ecephys channel records. + """Download a table of ecephys channel records. Parameters ---------- - channel_ids : - A collection of integer identifiers for ecephys channels. If + channel_ids : + A collection of integer identifiers for ecephys channels. If provided, results will be filtered to these channels. - probe_ids : - A collection of integer identifiers for ecephys probes. If + probe_ids : + A collection of integer identifiers for ecephys probes. If provided, results will be filtered to channels on these probes. - session_ids : - A collection of integer identifiers for ecephys sessions. If + session_ids : + A collection of integer identifiers for ecephys sessions. If provided, results will be filtered to channels recorded from during these sessions. - published_at : - A date (rendered as "YYYY-MM-DD"). If provided, only channels - recorded from during sessions published before this date will be + published_at : + A date (rendered as "YYYY-MM-DD"). If provided, only channels + recorded from during sessions published before this date will be returned. Returns @@ -295,30 +289,30 @@ def get_channels( channel_ids=channel_ids, probe_ids=probe_ids, session_ids=session_ids, - **_split_published_at(published_at)._asdict() + **_split_published_at(published_at)._asdict(), ) return response.set_index("id") def get_probes( - self, - probe_ids: Optional[ArrayLike] = None, - session_ids: Optional[ArrayLike] = None, - published_at: Optional[str] = None + self, + probe_ids: Optional[ArrayLike] = None, + session_ids: Optional[ArrayLike] = None, + published_at: Optional[str] = None, ) -> pd.DataFrame: - """ Download a table of ecephys probe records. + """Download a table of ecephys probe records. Parameters ---------- - probe_ids : - A collection of integer identifiers for ecephys probes. If + probe_ids : + A collection of integer identifiers for ecephys probes. If provided, results will be filtered to these probes. - session_ids : - A collection of integer identifiers for ecephys sessions. If + session_ids : + A collection of integer identifiers for ecephys sessions. If provided, results will be filtered to probes recorded from during these sessions. - published_at : - A date (rendered as "YYYY-MM-DD"). If provided, only probes - recorded from during sessions published before this date will be + published_at : + A date (rendered as "YYYY-MM-DD"). If provided, only probes + recorded from during sessions published before this date will be returned. Returns @@ -356,25 +350,20 @@ def get_probes( engine=self.postgres_engine.select, probe_ids=probe_ids, session_ids=session_ids, - **_split_published_at(published_at)._asdict() + **_split_published_at(published_at)._asdict(), ) return response.set_index("id") - - def get_sessions( - self, - session_ids: Optional[ArrayLike] = None, - published_at: Optional[str] = None - ) -> pd.DataFrame: - """ Download a table of ecephys session records. + def get_sessions(self, session_ids: Optional[ArrayLike] = None, published_at: Optional[str] = None) -> pd.DataFrame: + """Download a table of ecephys session records. Parameters ---------- - session_ids : - A collection of integer identifiers for ecephys sessions. If + session_ids : + A collection of integer identifiers for ecephys sessions. If provided, results will be filtered to these sessions. - published_at : - A date (rendered as "YYYY-MM-DD"). If provided, only sessions + published_at : + A date (rendered as "YYYY-MM-DD"). If provided, only sessions published before this date will be returned. Returns @@ -431,41 +420,40 @@ def get_sessions( base=postgres_macros(), engine=self.postgres_engine.select, session_ids=session_ids, - **_split_published_at(published_at)._asdict() + **_split_published_at(published_at)._asdict(), ) - - response.set_index("id", inplace=True) + + response.set_index("id", inplace=True) response["genotype"].fillna("wt", inplace=True) return response - def get_unit_analysis_metrics( - self, - unit_ids: Optional[ArrayLike] = None, - ecephys_session_ids: Optional[ArrayLike] = None, - session_types: Optional[ArrayLike] = None + self, + unit_ids: Optional[ArrayLike] = None, + ecephys_session_ids: Optional[ArrayLike] = None, + session_types: Optional[ArrayLike] = None, ) -> pd.DataFrame: - """ Fetch analysis metrics (stimulus set-specific characterizations of - unit response patterns) for ecephys units. Note that the metrics + """Fetch analysis metrics (stimulus set-specific characterizations of + unit response patterns) for ecephys units. Note that the metrics returned depend on the stimuli that were presented during recording ( and thus on the session_type) Parameters --------- unit_ids : - integer identifiers for a set of ecephys units. If provided, the + integer identifiers for a set of ecephys units. If provided, the response will only include metrics calculated for these units ecephys_session_ids : - integer identifiers for a set of ecephys sessions. If provided, the - response will only include metrics calculated for units identified + integer identifiers for a set of ecephys sessions. If provided, the + response will only include metrics calculated for units identified during these sessions session_types : - string names identifying ecephys session types (e.g. + string names identifying ecephys session types (e.g. "brain_observatory_1.1" or "functional_connectivity") Returns ------- - a pandas dataframe indexed by ecephys unit id whose columns are + a pandas dataframe indexed by ecephys unit id whose columns are metrics. """ @@ -489,7 +477,7 @@ def get_unit_analysis_metrics( engine=self.postgres_engine.select, unit_ids=unit_ids, ecephys_session_ids=ecephys_session_ids, - session_types=session_types + session_types=session_types, ) data = pd.DataFrame(response.pop("data").values.tolist(), index=response.index) @@ -498,9 +486,8 @@ def get_unit_analysis_metrics( return response - def _get_template(self, name, namespace): - """ Identify the WellKnownFile record associated with a stimulus + """Identify the WellKnownFile record associated with a stimulus template and stream its data if present. """ @@ -516,7 +503,7 @@ def _get_template(self, name, namespace): and sn.name = '{namespace}' """, base=postgres_macros(), - engine=self.postgres_engine.select_one + engine=self.postgres_engine.select_one, ) wkf_id = well_known_file["well_known_file_id"] except (KeyError, IndexError): @@ -525,15 +512,14 @@ def _get_template(self, name, namespace): download_link = f"well_known_files/download/{wkf_id}?wkf_id={wkf_id}" return self.app_engine.stream(download_link) - def get_natural_movie_template(self, number: int) -> Iterable[bytes]: - """ Download a template for the natural movie stimulus. This is the + """Download a template for the natural movie stimulus. This is the actual movie that was shown during the recording session. Parameters ---------- number : - idenfifier for this movie (note that this is an integer, so to get + idenfifier for this movie (note that this is an integer, so to get the template for natural_movie_three you should pass in 3) Returns @@ -542,13 +528,10 @@ def get_natural_movie_template(self, number: int) -> Iterable[bytes]: """ - return self._get_template( - f"natural_movie_{number}", self.STIMULUS_TEMPLATE_NAMESPACE - ) - + return self._get_template(f"natural_movie_{number}", self.STIMULUS_TEMPLATE_NAMESPACE) def get_natural_scene_template(self, number: int) -> Iterable[bytes]: - """ Download a template for the natural scene stimulus. This is the + """Download a template for the natural scene stimulus. This is the actual image that was shown during the recording session. Parameters @@ -561,15 +544,11 @@ def get_natural_scene_template(self, number: int) -> Iterable[bytes]: An iterable yielding a tiff file as bytes. """ - return self._get_template( - f"natural_scene_{int(number)}", self.STIMULUS_TEMPLATE_NAMESPACE - ) - + return self._get_template(f"natural_scene_{int(number)}", self.STIMULUS_TEMPLATE_NAMESPACE) @classmethod - def default(cls, lims_credentials: Optional[DbCredentials] = None, - app_kwargs=None, asynchronous=False): - """ Construct a "straightforward" lims api that can fetch data from + def default(cls, lims_credentials: Optional[DbCredentials] = None, app_kwargs=None, asynchronous=False): + """Construct a "straightforward" lims api that can fetch data from lims2. Parameters @@ -579,8 +558,8 @@ def default(cls, lims_credentials: Optional[DbCredentials] = None, the LIMS database. If left unspecified will attempt to provide credentials from environment variables. app_kwargs : dict - High-level configuration for http requests. See - allensdk.brain_observatory.ecephys.ecephys_project_api.http_engine.HttpEngine + High-level configuration for http requests. See + allensdk.brain_observatory.ecephys.ecephys_project_api.http_engine.HttpEngine and AsyncHttpEngine for details. asynchronous : bool If true, (http) queries will be made asynchronously. @@ -602,13 +581,15 @@ def default(cls, lims_credentials: Optional[DbCredentials] = None, if lims_credentials is not None: pg_engine = PostgresQueryMixin( - dbname=lims_credentials.dbname, user=lims_credentials.user, - host=lims_credentials.host, password=lims_credentials.password, - port=lims_credentials.port) + dbname=lims_credentials.dbname, + user=lims_credentials.user, + host=lims_credentials.host, + password=lims_credentials.password, + port=lims_credentials.port, + ) else: # Currying is equivalent to decorator syntactic sugar - pg_engine = (credential_injector(LIMS_DB_CREDENTIAL_MAP) - (PostgresQueryMixin)()) + pg_engine = credential_injector(LIMS_DB_CREDENTIAL_MAP)(PostgresQueryMixin)() return cls(pg_engine, app_engine) @@ -619,11 +600,11 @@ class SplitPublishedAt(NamedTuple): def _split_published_at(published_at: Optional[str]) -> SplitPublishedAt: - """ LIMS queries that filter on published_at need a couple of + """LIMS queries that filter on published_at need a couple of reformattings of the argued date string. """ return SplitPublishedAt( published_at=f"'{published_at}'" if published_at is not None else None, - published_at_not_null=None if published_at is None else True + published_at_not_null=None if published_at is None else True, ) diff --git a/allensdk/brain_observatory/ecephys/ecephys_project_api/ecephys_project_warehouse_api.py b/allensdk/brain_observatory/ecephys/ecephys_project_api/ecephys_project_warehouse_api.py index 08ffa7e7a4..a8553b86e3 100644 --- a/allensdk/brain_observatory/ecephys/ecephys_project_api/ecephys_project_warehouse_api.py +++ b/allensdk/brain_observatory/ecephys/ecephys_project_api/ecephys_project_warehouse_api.py @@ -40,9 +40,7 @@ def get_session_data(self, session_id, **kwargs): return self.rma_engine.stream(download_link) def get_natural_movie_template(self, number): - well_known_files = self.stimulus_templates[ - self.stimulus_templates["movie_number"] == number - ] + well_known_files = self.stimulus_templates[self.stimulus_templates["movie_number"] == number] if well_known_files.shape[0] != 1: raise ValueError( f"expected exactly one natural movie template with number {number}, found {well_known_files}" # noqa: E501 @@ -52,9 +50,7 @@ def get_natural_movie_template(self, number): return self.rma_engine.stream(download_link) def get_natural_scene_template(self, number): - well_known_files = self.stimulus_templates[ - self.stimulus_templates["scene_number"] == number - ] + well_known_files = self.stimulus_templates[self.stimulus_templates["scene_number"] == number] if well_known_files.shape[0] != 1: raise ValueError( f"expected exactly one natural scene template with number {number}, found {well_known_files}" # noqa: E501 @@ -119,9 +115,7 @@ def get_probe_lfp_data(self, probe_id): download_link = well_known_files.loc[0, "download_link"] return self.rma_engine.stream(download_link) - def get_sessions( - self, session_ids=None, has_eye_tracking=None, stimulus_names=None - ): + def get_sessions(self, session_ids=None, has_eye_tracking=None, stimulus_names=None): response = build_and_execute( ( "{% import 'rma_macros' as rm %}" @@ -171,9 +165,7 @@ def get_sessions( columns=["specimen", "fail_eye_tracking", "well_known_files"], inplace=True, ) - response.rename( - columns={"stimulus_name": "session_type"}, inplace=True - ) + response.rename(columns={"stimulus_name": "session_type"}, inplace=True) return response @@ -257,9 +249,7 @@ def get_units( return response - def get_unit_analysis_metrics( - self, unit_ids=None, ecephys_session_ids=None, session_types=None - ): + def get_unit_analysis_metrics(self, unit_ids=None, ecephys_session_ids=None, session_types=None): """Download analysis metrics - precalculated descriptions of unitwise responses to visual stimulation. @@ -310,9 +300,7 @@ def get_unit_analysis_metrics( for colname in output.columns: try: - output[colname] = output.apply( - lambda row: ast.literal_eval(str(row[colname])), axis=1 - ) + output[colname] = output.apply(lambda row: ast.literal_eval(str(row[colname])), axis=1) except ValueError: pass @@ -322,9 +310,7 @@ def get_unit_analysis_metrics( columns = set(output.columns.values.tolist()) if "p_value_rf" in columns and "on_screen_rf" in columns: pv_is_bool = np.issubdtype(output["p_value_rf"].values[0], bool) - on_screen_is_float = np.issubdtype( - output["on_screen_rf"].values[0].dtype, np.floating - ) + on_screen_is_float = np.issubdtype(output["on_screen_rf"].values[0].dtype, np.floating) # this is not a good test, but it avoids the case where we fix # these in the data for a future release, but diff --git a/allensdk/brain_observatory/ecephys/ecephys_project_api/http_engine.py b/allensdk/brain_observatory/ecephys/ecephys_project_api/http_engine.py index 2d89d556ce..37e958d1da 100644 --- a/allensdk/brain_observatory/ecephys/ecephys_project_api/http_engine.py +++ b/allensdk/brain_observatory/ecephys/ecephys_project_api/http_engine.py @@ -17,29 +17,24 @@ class HttpEngine: def __init__( - self, - scheme: str, - host: str, - timeout: float = DEFAULT_TIMEOUT, - chunksize: int = DEFAULT_CHUNKSIZE, - **kwargs + self, scheme: str, host: str, timeout: float = DEFAULT_TIMEOUT, chunksize: int = DEFAULT_CHUNKSIZE, **kwargs ): - """ Simple tool for making streaming http requests. + """Simple tool for making streaming http requests. Parameters ---------- scheme : e.g "http" or "https" - host : + host : will be used as the base for request urls - timeout : - requests taking longer than this (in seconds) will raise a - `requests.Timeout` error. The clock on this timeout starts running + timeout : + requests taking longer than this (in seconds) will raise a + `requests.Timeout` error. The clock on this timeout starts running when the initial request is made. - chunksize : + chunksize : When streaming data, how many bytes ought to be requested at once. - **kwargs : - unused. Defined here so that parameters can fall through from + **kwargs : + unused. Defined here so that parameters can fall through from subclasses """ @@ -52,7 +47,7 @@ def _build_url(self, route): return f"{self.scheme}://{self.host}/{route}" def stream(self, route): - """ Makes an http request and returns an iterator over the response. + """Makes an http request and returns an iterator over the response. Parameters ---------- @@ -62,19 +57,19 @@ def stream(self, route): """ url = self._build_url(route) - + start_time = time.perf_counter() response = requests.get(url, stream=True) response_b = None if "Content-length" in response.headers: response_b = float(response.headers["Content-length"]) - size_message = f"{response_b / 1024 ** 2:3.3f}MiB" if response_b is not None else "potentially large" + size_message = f"{response_b / 1024**2:3.3f}MiB" if response_b is not None else "potentially large" logging.warning(f"downloading a {size_message} file from {url}") - progress = tqdm( unit="B", total=response_b, unit_scale=True, desc="Downloading") + progress = tqdm(unit="B", total=response_b, unit_scale=True, desc="Downloading") for chunk in response.iter_content(self.chunksize): - if chunk: # filter out keep-alive new chunks + if chunk: # filter out keep-alive new chunks progress.update(len(chunk)) yield chunk @@ -91,25 +86,18 @@ def write_bytes(path: str, stream: Iterable[bytes]): class AsyncHttpEngine(HttpEngine): - - def __init__( - self, - scheme: str, - host: str, - session: Optional[aiohttp.ClientSession] = None, - **kwargs - ): - """ Simple tool for making asynchronous streaming http requests. + def __init__(self, scheme: str, host: str, session: Optional[aiohttp.ClientSession] = None, **kwargs): + """Simple tool for making asynchronous streaming http requests. Parameters ---------- scheme : e.g "http" or "https" - host : + host : will be used as the base for request urls - session : - If provided, this preconstructed session will be used rather than - a new one. Keep in mind that AsyncHttpEngine closes its session + session : + If provided, this preconstructed session will be used rather than + a new one. Keep in mind that AsyncHttpEngine closes its session when it is garbage collected! **kwargs : Will be passed to parent. @@ -121,33 +109,22 @@ def __init__( self._session = None if session: self._session = session - warnings.warn( - "Recieved preconstructed session, ignoring timeout parameter." - ) + warnings.warn("Recieved preconstructed session, ignoring timeout parameter.") @property def session(self): if self._session is None: - self._session = aiohttp.ClientSession( - timeout=aiohttp.client.ClientTimeout(self.timeout) - ) + self._session = aiohttp.ClientSession(timeout=aiohttp.client.ClientTimeout(self.timeout)) return self._session - async def _stream_coroutine( - self, - route: str, - callback: AsyncStreamCallbackType - ): + async def _stream_coroutine(self, route: str, callback: AsyncStreamCallbackType): url = self._build_url(route) async with self.session.get(url) as response: await callback(response.content.iter_chunked(self.chunksize)) - def stream( - self, - route: str - ) -> Callable[[AsyncStreamCallbackType], Awaitable[None]]: - """ Returns a coroutine which + def stream(self, route: str) -> Callable[[AsyncStreamCallbackType], Awaitable[None]]: + """Returns a coroutine which - makes an http request - exposes internally an asynchronous iterator over the response - takes a callback parameter, which should consume the iterator. @@ -180,40 +157,35 @@ def __del__(self): loop.run_until_complete(self._session.close()) @staticmethod - def write_bytes( - path: str, - coroutine: Callable[[AsyncStreamCallbackType], Awaitable[None]]): + def write_bytes(path: str, coroutine: Callable[[AsyncStreamCallbackType], Awaitable[None]]): write_bytes_from_coroutine(path, coroutine) -def write_bytes_from_coroutine( - path: str, - coroutine: Callable[[AsyncStreamCallbackType], Awaitable[None]] -): - """ Utility for streaming http from an asynchronous requester to a file. +def write_bytes_from_coroutine(path: str, coroutine: Callable[[AsyncStreamCallbackType], Awaitable[None]]): + """Utility for streaming http from an asynchronous requester to a file. Parameters ---------- - path : + path : Write to this file - coroutine : + coroutine : The source of the data. Needs to have a specific structure, namely: - the first-position parameter of the coroutine ought to accept a callback. This callback ought to itself be awaitable. - - within the coroutine, this callback ought to be called with a - single argument. That single argument should be an asynchronous + - within the coroutine, this callback ought to be called with a + single argument. That single argument should be an asynchronous iterator. - Please see AsyncHttpEngine.stream (and - AsyncHttpEngine._stream_coroutine) for an example. - + Please see AsyncHttpEngine.stream (and + AsyncHttpEngine._stream_coroutine) for an example. + """ - + os.makedirs(os.path.dirname(path), exist_ok=True) - + async def callback(file_, iterable): async for chunk in iterable: file_.write(chunk) - + async def wrapper(): with open(path, "wb") as file_: callback_ = functools.partial(callback, file_) @@ -225,13 +197,13 @@ async def wrapper(): def write_from_stream(path: str, stream: Iterable[bytes]): - """ Write bytes to a file from an iterator + """Write bytes to a file from an iterator Parameters ---------- - path : + path : write to this file - stream : + stream : iterable yielding bytes to be written """ diff --git a/allensdk/brain_observatory/ecephys/ecephys_project_api/rma_engine.py b/allensdk/brain_observatory/ecephys/ecephys_project_api/rma_engine.py index da9e646778..aa4d671c8d 100644 --- a/allensdk/brain_observatory/ecephys/ecephys_project_api/rma_engine.py +++ b/allensdk/brain_observatory/ecephys/ecephys_project_api/rma_engine.py @@ -13,35 +13,28 @@ class RmaRequestError(Exception): class RmaEngine(HttpEngine): - @property def format_query_string(self): return f"query.{self.rma_format}" def __init__( - self, - scheme, - host, - rma_prefix: str = "api/v2/data", - rma_format: str = "json", - page_size: int = 5000, - **kwargs + self, scheme, host, rma_prefix: str = "api/v2/data", rma_format: str = "json", page_size: int = 5000, **kwargs ): - """ Simple tool for making rma and streaming http requests. + """Simple tool for making rma and streaming http requests. Parameters ---------- scheme : e.g "http" or "https" - host : + host : will be used as the base for request urls rma_prefix : rma request routes will be prefixed with this string - rma_format : + rma_format : Format of reuturned response. e.g. "json", "xml", "csv" page_size : how many rma records to request in one query. - **kwargs : + **kwargs : will be passed to parent """ @@ -56,11 +49,11 @@ def add_page_params(self, url, start, count=None): return f"{url},rma::options[start_row$eq{start}][num_rows$eq{count}][order$eq'id']" def get_rma(self, query: str): - """ Makes a paging rma query + """Makes a paging rma query Parameters ---------- - query : + query : The RMA query parameters """ @@ -84,7 +77,6 @@ def get_rma(self, query: str): logging.debug(f"downloaded {start_row} of {total_rows} records ({time.time() - start_time:.3f} seconds)") yield response_json["msg"] - def get_rma_list(self, query): response = [] for chunk in self.get_rma(query): @@ -101,18 +93,17 @@ def get_rma_tabular(self, query, try_infer_dtypes=True): class AsyncRmaEngine(RmaEngine, AsyncHttpEngine): - def __init__(self, scheme: str, host: str, **kwargs): - """ Simple tool for making rma and asynchronous streaming http + """Simple tool for making rma and asynchronous streaming http requests. Parameters ---------- scheme : e.g "http" or "https" - host : + host : will be used as the base for request urls - **kwargs : + **kwargs : will be passed to parent """ @@ -120,8 +111,7 @@ def __init__(self, scheme: str, host: str, **kwargs): def infer_column_types(dataframe): - """ RMA queries often come back with string-typed columns. This utility tries to infer numeric types. - """ + """RMA queries often come back with string-typed columns. This utility tries to infer numeric types.""" dataframe = dataframe.copy() @@ -130,6 +120,6 @@ def infer_column_types(dataframe): dataframe[colname] = dataframe[colname].apply(ast.literal_eval) except (ValueError, SyntaxError): continue - + dataframe = dataframe.infer_objects() - return dataframe \ No newline at end of file + return dataframe diff --git a/allensdk/brain_observatory/ecephys/ecephys_project_api/utilities.py b/allensdk/brain_observatory/ecephys/ecephys_project_api/utilities.py index ecedeb934d..abef26b161 100644 --- a/allensdk/brain_observatory/ecephys/ecephys_project_api/utilities.py +++ b/allensdk/brain_observatory/ecephys/ecephys_project_api/utilities.py @@ -51,7 +51,7 @@ def postgres_macros(): {% endif %} {% endmacro %} """, - "macros": macros()["macros"] + "macros": macros()["macros"], } @@ -63,11 +63,10 @@ def rma_macros(): {%- if data is not none %}[{{key}}$in{{m.comma_sep(data,quote)}}]{% endif -%} {%- endmacro -%} """, - "macros": macros()["macros"] + "macros": macros()["macros"], } - def build_and_execute(query, base=None, engine=None, **kwargs): env = build_environment({"__tmp__": query}, base=base) return execute_templated(env, "__tmp__", engine=engine, **kwargs) diff --git a/allensdk/brain_observatory/ecephys/ecephys_project_cache.py b/allensdk/brain_observatory/ecephys/ecephys_project_cache.py index 2672c4876d..3194173d88 100644 --- a/allensdk/brain_observatory/ecephys/ecephys_project_cache.py +++ b/allensdk/brain_observatory/ecephys/ecephys_project_cache.py @@ -11,29 +11,29 @@ from allensdk.api.warehouse_cache.cache import Cache from allensdk.core.authentication import DbCredentials from allensdk.brain_observatory.ecephys.ecephys_project_api import ( - EcephysProjectApi, EcephysProjectLimsApi, EcephysProjectWarehouseApi, - EcephysProjectFixedApi + EcephysProjectApi, + EcephysProjectLimsApi, + EcephysProjectWarehouseApi, + EcephysProjectFixedApi, ) from allensdk.brain_observatory.ecephys.ecephys_project_api.http_engine import ( - write_bytes_from_coroutine, write_from_stream -) -from allensdk.brain_observatory.ecephys.ecephys_session_api import ( - EcephysNwbSessionApi + write_bytes_from_coroutine, + write_from_stream, ) +from allensdk.brain_observatory.ecephys.ecephys_session_api import EcephysNwbSessionApi from allensdk.brain_observatory.ecephys.ecephys_session import EcephysSession from allensdk.brain_observatory.ecephys import get_unit_filter_value from allensdk.api.warehouse_cache.caching_utilities import one_file_call_caching class EcephysProjectCache(Cache): + SESSIONS_KEY = "sessions" + PROBES_KEY = "probes" + CHANNELS_KEY = "channels" + UNITS_KEY = "units" - SESSIONS_KEY = 'sessions' - PROBES_KEY = 'probes' - CHANNELS_KEY = 'channels' - UNITS_KEY = 'units' - - SESSION_DIR_KEY = 'session_data' - SESSION_NWB_KEY = 'session_nwb' + SESSION_DIR_KEY = "session_data" + SESSION_NWB_KEY = "session_nwb" PROBE_LFP_NWB_KEY = "probe_lfp_nwb" NATURAL_MOVIE_DIR_KEY = "movie_dir" @@ -45,40 +45,58 @@ class EcephysProjectCache(Cache): SESSION_ANALYSIS_METRICS_KEY = "session_analysis_metrics" TYPEWISE_ANALYSIS_METRICS_KEY = "typewise_analysis_metrics" - MANIFEST_VERSION = '0.3.0' + MANIFEST_VERSION = "0.3.0" - SUPPRESS_FROM_UNITS = ("air_channel_index", - "surface_channel_index", - "has_nwb", - "lfp_temporal_subsampling_factor", - "epoch_name_quality_metrics", - "epoch_name_waveform_metrics", - "isi_experiment_id") + SUPPRESS_FROM_UNITS = ( + "air_channel_index", + "surface_channel_index", + "has_nwb", + "lfp_temporal_subsampling_factor", + "epoch_name_quality_metrics", + "epoch_name_waveform_metrics", + "isi_experiment_id", + ) SUPPRESS_FROM_CHANNELS = ( - "air_channel_index", "surface_channel_index", "name", - "date_of_acquisition", "published_at", "specimen_id", "session_type", "isi_experiment_id", "age_in_days", - "sex", "genotype", "has_nwb", "lfp_temporal_subsampling_factor" + "air_channel_index", + "surface_channel_index", + "name", + "date_of_acquisition", + "published_at", + "specimen_id", + "session_type", + "isi_experiment_id", + "age_in_days", + "sex", + "genotype", + "has_nwb", + "lfp_temporal_subsampling_factor", ) SUPPRESS_FROM_PROBES = ( - "air_channel_index", "surface_channel_index", - "date_of_acquisition", "published_at", "specimen_id", "session_type", "isi_experiment_id", "age_in_days", - "sex", "genotype", "has_nwb", "lfp_temporal_subsampling_factor" - ) - SUPPRESS_FROM_SESSION_TABLE = ( - "has_nwb", + "air_channel_index", + "surface_channel_index", + "date_of_acquisition", + "published_at", + "specimen_id", + "session_type", "isi_experiment_id", - "date_of_acquisition" + "age_in_days", + "sex", + "genotype", + "has_nwb", + "lfp_temporal_subsampling_factor", ) + SUPPRESS_FROM_SESSION_TABLE = ("has_nwb", "isi_experiment_id", "date_of_acquisition") def __init__( - self, - fetch_api: Optional[EcephysProjectApi] = None, - fetch_tries: int = 2, - stream_writer: Optional[Callable] = None, - manifest: Optional[Union[str, Path]] = None, - version: Optional[str] = None, - cache: bool = True): - """ Entrypoint for accessing ecephys (neuropixels) data. Supports + self, + fetch_api: Optional[EcephysProjectApi] = None, + fetch_tries: int = 2, + stream_writer: Optional[Callable] = None, + manifest: Optional[Union[str, Path]] = None, + version: Optional[str] = None, + cache: bool = True, + ): + """Entrypoint for accessing ecephys (neuropixels) data. Supports access to cross-session data (like stimulus templates) and high-level summaries of sessionwise data and provides tools for downloading detailed sessionwise data (such as spike times). @@ -162,21 +180,17 @@ class constructors rather than to initialize this class directly. manifest_ = manifest or "ecephys_project_manifest.json" version_ = version or self.MANIFEST_VERSION - super(EcephysProjectCache, self).__init__(manifest=manifest_, - version=version_, - cache=cache) - self.fetch_api = (EcephysProjectWarehouseApi.default() - if fetch_api is None else fetch_api) + super(EcephysProjectCache, self).__init__(manifest=manifest_, version=version_, cache=cache) + self.fetch_api = EcephysProjectWarehouseApi.default() if fetch_api is None else fetch_api self.fetch_tries = fetch_tries - self.stream_writer = (stream_writer - or self.fetch_api.rma_engine.write_bytes) + self.stream_writer = stream_writer or self.fetch_api.rma_engine.write_bytes if stream_writer is not None: self.stream_writer = stream_writer else: - if hasattr(self.fetch_api, "rma_engine"): # EcephysProjectWarehouseApi # noqa + if hasattr(self.fetch_api, "rma_engine"): # EcephysProjectWarehouseApi # noqa self.stream_writer = self.fetch_api.rma_engine.write_bytes # TODO: Make these names consistent in the different fetch apis - elif hasattr(self.fetch_api, "app_engine"): # EcephysProjectLimsApi # noqa + elif hasattr(self.fetch_api, "app_engine"): # EcephysProjectLimsApi # noqa self.stream_writer = self.fetch_api.app_engine.write_bytes else: raise ValueError( @@ -185,11 +199,14 @@ class constructors rather than to initialize this class directly. "that implements `write_bytes`. See `HttpEngine` and " "`AsyncHttpEngine` from " "allensdk.brain_observatory.ecephys.ecephys_project_api." - "http_engine for examples.") + "http_engine for examples." + ) def _get_sessions(self): path = self.get_cache_path(None, self.SESSIONS_KEY) - response = one_file_call_caching(path, self.fetch_api.get_sessions, write_csv, read_csv, num_tries=self.fetch_tries) + response = one_file_call_caching( + path, self.fetch_api.get_sessions, write_csv, read_csv, num_tries=self.fetch_tries + ) if "structure_acronyms" in response.columns: # unfortunately, structure_acronyms is a list of str response["ecephys_structure_acronyms"] = [ast.literal_eval(item) for item in response["structure_acronyms"]] @@ -201,10 +218,8 @@ def _get_probes(self): path: str = self.get_cache_path(None, self.PROBES_KEY) probes = one_file_call_caching(path, self.fetch_api.get_probes, write_csv, read_csv, num_tries=self.fetch_tries) # Divide the lfp sampling by the subsampling factor for clearer presentation (if provided) - if all(c in list(probes) for c in - ["lfp_sampling_rate", "lfp_temporal_subsampling_factor"]): - probes["lfp_sampling_rate"] = ( - probes["lfp_sampling_rate"] / probes["lfp_temporal_subsampling_factor"]) + if all(c in list(probes) for c in ["lfp_sampling_rate", "lfp_temporal_subsampling_factor"]): + probes["lfp_sampling_rate"] = probes["lfp_sampling_rate"] / probes["lfp_temporal_subsampling_factor"] return probes def _get_channels(self): @@ -215,18 +230,20 @@ def _get_units(self, filter_by_validity: bool = True, **unit_filter_kwargs) -> p path = self.get_cache_path(None, self.UNITS_KEY) units = one_file_call_caching(path, self.fetch_api.get_units, write_csv, read_csv, num_tries=self.fetch_tries) - units = units.rename(columns={ - 'PT_ratio': 'waveform_PT_ratio', - 'amplitude': 'waveform_amplitude', - 'duration': 'waveform_duration', - 'halfwidth': 'waveform_halfwidth', - 'recovery_slope': 'waveform_recovery_slope', - 'repolarization_slope': 'waveform_repolarization_slope', - 'spread': 'waveform_spread', - 'velocity_above': 'waveform_velocity_above', - 'velocity_below': 'waveform_velocity_below', - 'l_ratio': 'L_ratio', - }) + units = units.rename( + columns={ + "PT_ratio": "waveform_PT_ratio", + "amplitude": "waveform_amplitude", + "duration": "waveform_duration", + "halfwidth": "waveform_halfwidth", + "recovery_slope": "waveform_recovery_slope", + "repolarization_slope": "waveform_repolarization_slope", + "spread": "waveform_spread", + "velocity_above": "waveform_velocity_above", + "velocity_below": "waveform_velocity_below", + "l_ratio": "L_ratio", + } + ) units = units[ (units["amplitude_cutoff"] <= get_unit_filter_value("amplitude_cutoff_maximum", **unit_filter_kwargs)) @@ -247,27 +264,33 @@ def _get_annotated_probes(self): sessions = self._get_sessions() probes = self._get_probes() - return pd.merge(probes, sessions, left_on="ecephys_session_id", right_index=True, suffixes=['_probe', '_session']) + return pd.merge( + probes, sessions, left_on="ecephys_session_id", right_index=True, suffixes=["_probe", "_session"] + ) def _get_annotated_channels(self): channels = self._get_channels() probes = self._get_annotated_probes() - return pd.merge(channels, probes, left_on="ecephys_probe_id", right_index=True, suffixes=['_channel', '_probe']) + return pd.merge(channels, probes, left_on="ecephys_probe_id", right_index=True, suffixes=["_channel", "_probe"]) def _get_annotated_units(self, filter_by_validity: bool = True, **unit_filter_kwargs) -> pd.DataFrame: units = self._get_units(filter_by_validity=filter_by_validity, **unit_filter_kwargs) channels = self._get_annotated_channels() - annotated_units = pd.merge(units, channels, left_on='ecephys_channel_id', right_index=True, suffixes=['_unit', '_channel']) - annotated_units = annotated_units.rename(columns={ - 'name': 'probe_name', - 'phase': 'probe_phase', - 'sampling_rate': 'probe_sampling_rate', - 'lfp_sampling_rate': 'probe_lfp_sampling_rate', - 'local_index': 'peak_channel' - }) + annotated_units = pd.merge( + units, channels, left_on="ecephys_channel_id", right_index=True, suffixes=["_unit", "_channel"] + ) + annotated_units = annotated_units.rename( + columns={ + "name": "probe_name", + "phase": "probe_phase", + "sampling_rate": "probe_sampling_rate", + "lfp_sampling_rate": "probe_lfp_sampling_rate", + "local_index": "peak_channel", + } + ) - return pd.merge(units, channels, left_on='ecephys_channel_id', right_index=True, suffixes=['_unit', '_channel']) + return pd.merge(units, channels, left_on="ecephys_channel_id", right_index=True, suffixes=["_unit", "_channel"]) def get_session_table(self, suppress=None) -> pd.DataFrame: sessions = self._get_sessions() @@ -276,12 +299,19 @@ def get_session_table(self, suppress=None) -> pd.DataFrame: count_owned(sessions, self._get_annotated_channels(), "ecephys_session_id", "channel_count", inplace=True) count_owned(sessions, self._get_annotated_probes(), "ecephys_session_id", "probe_count", inplace=True) - get_grouped_uniques(sessions, self._get_annotated_channels(), "ecephys_session_id", "ecephys_structure_acronym", "ecephys_structure_acronyms", inplace=True) + get_grouped_uniques( + sessions, + self._get_annotated_channels(), + "ecephys_session_id", + "ecephys_structure_acronym", + "ecephys_structure_acronyms", + inplace=True, + ) if suppress is None: suppress = list(self.SUPPRESS_FROM_SESSION_TABLE) sessions.drop(columns=suppress, inplace=True, errors="ignore") - sessions = sessions.rename(columns={'genotype': 'full_genotype'}) + sessions = sessions.rename(columns={"genotype": "full_genotype"}) return sessions def get_probes(self, suppress=None): @@ -290,7 +320,14 @@ def get_probes(self, suppress=None): count_owned(probes, self._get_annotated_units(), "ecephys_probe_id", "unit_count", inplace=True) count_owned(probes, self._get_annotated_channels(), "ecephys_probe_id", "channel_count", inplace=True) - get_grouped_uniques(probes, self._get_annotated_channels(), "ecephys_probe_id", "ecephys_structure_acronym", "ecephys_structure_acronyms", inplace=True) + get_grouped_uniques( + probes, + self._get_annotated_channels(), + "ecephys_probe_id", + "ecephys_structure_acronym", + "ecephys_structure_acronyms", + inplace=True, + ) if suppress is None: suppress = list(self.SUPPRESS_FROM_PROBES) @@ -299,8 +336,7 @@ def get_probes(self, suppress=None): return probes def get_channels(self, suppress=None): - """ Load (potentially downloading and caching) a table whose rows are individual channels. - """ + """Load (potentially downloading and caching) a table whose rows are individual channels.""" channels = self._get_annotated_channels() count_owned(channels, self._get_annotated_units(), "ecephys_channel_id", "unit_count", inplace=True) @@ -312,7 +348,9 @@ def get_channels(self, suppress=None): return channels - def get_units(self, suppress: Optional[List[str]] = None, filter_by_validity: bool = True, **unit_filter_kwargs) -> pd.DataFrame: + def get_units( + self, suppress: Optional[List[str]] = None, filter_by_validity: bool = True, **unit_filter_kwargs + ) -> pd.DataFrame: """Reports a table consisting of all sorted units across the entire extracellular electrophysiology project. Parameters @@ -341,8 +379,7 @@ def get_units(self, suppress: Optional[List[str]] = None, filter_by_validity: bo return units def get_session_data(self, session_id: int, filter_by_validity: bool = True, **unit_filter_kwargs): - """ Obtain an EcephysSession object containing detailed data for a single session - """ + """Obtain an EcephysSession object containing detailed data for a single session""" def read(_path): session_api = self._build_nwb_api_for_session(_path, session_id, filter_by_validity, **unit_filter_kwargs) @@ -353,17 +390,16 @@ def read(_path): partial(self.fetch_api.get_session_data, session_id), self.stream_writer, read, - num_tries=self.fetch_tries + num_tries=self.fetch_tries, ) def _build_nwb_api_for_session(self, path, session_id, filter_by_validity, **unit_filter_kwargs): - get_analysis_metrics = partial( self.get_unit_analysis_metrics_for_session, session_id=session_id, annotate=False, filter_by_validity=True, - **unit_filter_kwargs + **unit_filter_kwargs, ) return EcephysNwbSessionApi( @@ -372,7 +408,7 @@ def _build_nwb_api_for_session(self, path, session_id, filter_by_validity, **uni additional_unit_metrics=get_analysis_metrics, external_channel_columns=partial(self._get_substitute_channel_columns, session_id), filter_by_validity=filter_by_validity, - **unit_filter_kwargs + **unit_filter_kwargs, ) def _setup_probe_promises(self, session_id): @@ -386,20 +422,23 @@ def _setup_probe_promises(self, session_id): partial(self.fetch_api.get_probe_lfp_data, probe_id), self.stream_writer, read_nwb, - num_tries=self.fetch_tries + num_tries=self.fetch_tries, ) for probe_id in probe_ids } def _get_substitute_channel_columns(self, session_id): channels = self.get_channels() - return channels.loc[channels["ecephys_session_id"] == session_id, [ - "ecephys_structure_id", - "ecephys_structure_acronym", - "anterior_posterior_ccf_coordinate", - "dorsal_ventral_ccf_coordinate", - "left_right_ccf_coordinate" - ]] + return channels.loc[ + channels["ecephys_session_id"] == session_id, + [ + "ecephys_structure_id", + "ecephys_structure_acronym", + "anterior_posterior_ccf_coordinate", + "dorsal_ventral_ccf_coordinate", + "left_right_ccf_coordinate", + ], + ] def get_natural_movie_template(self, number): return one_file_call_caching( @@ -407,7 +446,7 @@ def get_natural_movie_template(self, number): partial(self.fetch_api.get_natural_movie_template, number=number), self.stream_writer, read_movie, - num_tries=self.fetch_tries + num_tries=self.fetch_tries, ) def get_natural_scene_template(self, number): @@ -416,7 +455,7 @@ def get_natural_scene_template(self, number): partial(self.fetch_api.get_natural_scene_template, number=number), self.stream_writer, read_scene, - num_tries=self.fetch_tries + num_tries=self.fetch_tries, ) def get_all_session_types(self, **session_kwargs): @@ -440,8 +479,10 @@ def _get_all_values(self, key, method=None, **method_kwargs) -> List[Any]: data = method(**method_kwargs) return data[key].unique().tolist() - def get_unit_analysis_metrics_for_session(self, session_id, annotate: bool = True, filter_by_validity: bool = True, **unit_filter_kwargs): - """ Cache and return a table of analysis metrics calculated on each unit from a specified session. See + def get_unit_analysis_metrics_for_session( + self, session_id, annotate: bool = True, filter_by_validity: bool = True, **unit_filter_kwargs + ): + """Cache and return a table of analysis metrics calculated on each unit from a specified session. See get_session_table for a list of sessions. Parameters @@ -465,7 +506,9 @@ def get_unit_analysis_metrics_for_session(self, session_id, annotate: bool = Tru path = self.get_cache_path(None, self.SESSION_ANALYSIS_METRICS_KEY, session_id, session_id) fetch_metrics = partial(self.fetch_api.get_unit_analysis_metrics, ecephys_session_ids=[session_id]) - metrics = one_file_call_caching(path, fetch_metrics, write_metrics_csv, read_metrics_csv, num_tries=self.fetch_tries) + metrics = one_file_call_caching( + path, fetch_metrics, write_metrics_csv, read_metrics_csv, num_tries=self.fetch_tries + ) if annotate: units = self.get_units(filter_by_validity=filter_by_validity, **unit_filter_kwargs) @@ -475,8 +518,10 @@ def get_unit_analysis_metrics_for_session(self, session_id, annotate: bool = Tru return metrics - def get_unit_analysis_metrics_by_session_type(self, session_type, annotate: bool = True, filter_by_validity: bool = True, **unit_filter_kwargs): - """ Cache and return a table of analysis metrics calculated on each unit from a specified session type. See + def get_unit_analysis_metrics_by_session_type( + self, session_type, annotate: bool = True, filter_by_validity: bool = True, **unit_filter_kwargs + ): + """Cache and return a table of analysis metrics calculated on each unit from a specified session type. See get_all_session_types for a list of session types. Parameters @@ -505,11 +550,7 @@ def get_unit_analysis_metrics_by_session_type(self, session_type, annotate: bool fetch_metrics = partial(self.fetch_api.get_unit_analysis_metrics, session_types=[session_type]) metrics = one_file_call_caching( - path, - fetch_metrics, - write_metrics_csv, - read_metrics_csv, - num_tries=self.fetch_tries + path, fetch_metrics, write_metrics_csv, read_metrics_csv, num_tries=self.fetch_tries ) if annotate: @@ -522,36 +563,29 @@ def get_unit_analysis_metrics_by_session_type(self, session_type, annotate: bool def add_manifest_paths(self, manifest_builder): manifest_builder = super(EcephysProjectCache, self).add_manifest_paths(manifest_builder) - manifest_builder.add_path( - self.SESSIONS_KEY, 'sessions.csv', parent_key='BASEDIR', typename='file' - ) + manifest_builder.add_path(self.SESSIONS_KEY, "sessions.csv", parent_key="BASEDIR", typename="file") - manifest_builder.add_path( - self.PROBES_KEY, 'probes.csv', parent_key='BASEDIR', typename='file' - ) + manifest_builder.add_path(self.PROBES_KEY, "probes.csv", parent_key="BASEDIR", typename="file") - manifest_builder.add_path( - self.CHANNELS_KEY, 'channels.csv', parent_key='BASEDIR', typename='file' - ) + manifest_builder.add_path(self.CHANNELS_KEY, "channels.csv", parent_key="BASEDIR", typename="file") - manifest_builder.add_path( - self.UNITS_KEY, 'units.csv', parent_key='BASEDIR', typename='file' - ) + manifest_builder.add_path(self.UNITS_KEY, "units.csv", parent_key="BASEDIR", typename="file") - manifest_builder.add_path( - self.SESSION_DIR_KEY, 'session_%d', parent_key='BASEDIR', typename='dir' - ) + manifest_builder.add_path(self.SESSION_DIR_KEY, "session_%d", parent_key="BASEDIR", typename="dir") manifest_builder.add_path( - self.SESSION_NWB_KEY, 'session_%d.nwb', parent_key=self.SESSION_DIR_KEY, typename='file' + self.SESSION_NWB_KEY, "session_%d.nwb", parent_key=self.SESSION_DIR_KEY, typename="file" ) manifest_builder.add_path( - self.SESSION_ANALYSIS_METRICS_KEY, 'session_%d_analysis_metrics.csv', parent_key=self.SESSION_DIR_KEY, typename='file' + self.SESSION_ANALYSIS_METRICS_KEY, + "session_%d_analysis_metrics.csv", + parent_key=self.SESSION_DIR_KEY, + typename="file", ) manifest_builder.add_path( - self.PROBE_LFP_NWB_KEY, 'probe_%d_lfp.nwb', parent_key=self.SESSION_DIR_KEY, typename='file' + self.PROBE_LFP_NWB_KEY, "probe_%d_lfp.nwb", parent_key=self.SESSION_DIR_KEY, typename="file" ) manifest_builder.add_path( @@ -559,7 +593,7 @@ def add_manifest_paths(self, manifest_builder): ) manifest_builder.add_path( - self.TYPEWISE_ANALYSIS_METRICS_KEY, "%s_analysis_metrics.csv", parent_key='BASEDIR', typename="file" + self.TYPEWISE_ANALYSIS_METRICS_KEY, "%s_analysis_metrics.csv", parent_key="BASEDIR", typename="file" ) manifest_builder.add_path( @@ -578,9 +612,7 @@ def add_manifest_paths(self, manifest_builder): @classmethod def _from_http_source_default(cls, fetch_api_cls, fetch_api_kwargs, **kwargs): - fetch_api_kwargs = { - "asynchronous": True - } if fetch_api_kwargs is None else fetch_api_kwargs + fetch_api_kwargs = {"asynchronous": True} if fetch_api_kwargs is None else fetch_api_kwargs if kwargs.get("stream_writer") is None: if fetch_api_kwargs.get("asynchronous", True): @@ -588,20 +620,20 @@ def _from_http_source_default(cls, fetch_api_cls, fetch_api_kwargs, **kwargs): else: kwargs["stream_writer"] = write_from_stream - return cls( - fetch_api=fetch_api_cls.default(**fetch_api_kwargs), - **kwargs - ) + return cls(fetch_api=fetch_api_cls.default(**fetch_api_kwargs), **kwargs) @classmethod - def from_lims(cls, lims_credentials: Optional[DbCredentials] = None, - scheme: Optional[str] = None, - host: Optional[str] = None, - asynchronous: bool = False, - manifest: Optional[Union[str, Path]] = None, - version: Optional[str] = None, - cache: bool = True, - fetch_tries: int = 2): + def from_lims( + cls, + lims_credentials: Optional[DbCredentials] = None, + scheme: Optional[str] = None, + host: Optional[str] = None, + asynchronous: bool = False, + manifest: Optional[Union[str, Path]] = None, + version: Optional[str] = None, + cache: bool = True, + fetch_tries: int = 2, + ): """ Create an instance of EcephysProjectCache with an EcephysProjectLimsApi. Retrieves bleeding-edge data stored @@ -641,23 +673,29 @@ def from_lims(cls, lims_credentials: Optional[DbCredentials] = None, app_kwargs = None return cls._from_http_source_default( EcephysProjectLimsApi, - {"lims_credentials": lims_credentials, - "app_kwargs": app_kwargs, - "asynchronous": asynchronous, - }, # expects dictionary of kwargs - manifest=manifest, version=version, cache=cache, - fetch_tries=fetch_tries) + { + "lims_credentials": lims_credentials, + "app_kwargs": app_kwargs, + "asynchronous": asynchronous, + }, # expects dictionary of kwargs + manifest=manifest, + version=version, + cache=cache, + fetch_tries=fetch_tries, + ) @classmethod - def from_warehouse(cls, - scheme: Optional[str] = None, - host: Optional[str] = None, - asynchronous: bool = False, - manifest: Optional[Union[str, Path]] = None, - version: Optional[str] = None, - cache: bool = True, - fetch_tries: int = 2, - timeout: int = 1200): + def from_warehouse( + cls, + scheme: Optional[str] = None, + host: Optional[str] = None, + asynchronous: bool = False, + manifest: Optional[Union[str, Path]] = None, + version: Optional[str] = None, + cache: bool = True, + fetch_tries: int = 2, + timeout: int = 1200, + ): """ Create an instance of EcephysProjectCache with an EcephysProjectWarehouseApi. Retrieves released data stored in @@ -691,19 +729,21 @@ def from_warehouse(cls, are not responding quickly. Defaults to 1200 seconds (20 minutes). """ if scheme and host: - app_kwargs = {"scheme": scheme, "host": host, - "asynchronous": asynchronous} + app_kwargs = {"scheme": scheme, "host": host, "asynchronous": asynchronous} else: app_kwargs = {"asynchronous": asynchronous} - app_kwargs['timeout'] = timeout + app_kwargs["timeout"] = timeout return cls._from_http_source_default( - EcephysProjectWarehouseApi, app_kwargs, manifest=manifest, - version=version, cache=cache, fetch_tries=fetch_tries + EcephysProjectWarehouseApi, + app_kwargs, + manifest=manifest, + version=version, + cache=cache, + fetch_tries=fetch_tries, ) @classmethod - def fixed(cls, manifest: Optional[Union[str, Path]] = None, - version: Optional[str] = None): + def fixed(cls, manifest: Optional[Union[str, Path]] = None, version: Optional[str] = None): """ Creates a EcephysProjectCache that refuses to fetch any data - only the existing local cache is accessible. Useful if you @@ -717,8 +757,7 @@ def fixed(cls, manifest: Optional[Union[str, Path]] = None, version of manifest file. If this mismatches the version recorded in the file at manifest, an error will be raised. """ - return cls(fetch_api=EcephysProjectFixedApi(), manifest=manifest, - version=version) + return cls(fetch_api=EcephysProjectFixedApi(), manifest=manifest, version=version) def count_owned(this, other, foreign_key, count_key, inplace=False): @@ -737,8 +776,7 @@ def get_grouped_uniques(this, other, foreign_key, field_key, unique_key, inplace if not inplace: this = this.copy() - uniques = other.groupby(foreign_key)[field_key]\ - .apply(lambda x: x.unique()) + uniques = other.groupby(foreign_key)[field_key].apply(lambda x: x.unique()) this[unique_key] = None this.loc[uniques.index.values, unique_key] = uniques.values @@ -759,7 +797,7 @@ def write_metrics_csv(path, df): def read_metrics_csv(path): - return pd.read_csv(path, index_col='ecephys_unit_id') + return pd.read_csv(path, index_col="ecephys_unit_id") def read_scene(path): @@ -771,7 +809,7 @@ def read_movie(path): def read_nwb(path): - reader = pynwb.NWBHDF5IO(str(path), 'r') + reader = pynwb.NWBHDF5IO(str(path), "r") nwbfile = reader.read() nwbfile.identifier # if the file is corrupt, make sure an exception gets raised during read return nwbfile diff --git a/allensdk/brain_observatory/ecephys/ecephys_session.py b/allensdk/brain_observatory/ecephys/ecephys_session.py index 69cf5e928b..00b0ff5dfb 100644 --- a/allensdk/brain_observatory/ecephys/ecephys_session.py +++ b/allensdk/brain_observatory/ecephys/ecephys_session.py @@ -22,17 +22,11 @@ from allensdk.core.lazy_property import LazyPropertyMixin # stimulus_presentation column names not describing a parameter of a stimulus -NON_STIMULUS_PARAMETERS = tuple([ - 'start_time', - 'stop_time', - 'duration', - 'stimulus_block', - "stimulus_condition_id" -]) +NON_STIMULUS_PARAMETERS = tuple(["start_time", "stop_time", "duration", "stimulus_block", "stimulus_condition_id"]) class EcephysSession(LazyPropertyMixin): - ''' Represents data from a single EcephysSession + """Represents data from a single EcephysSession Attributes ---------- @@ -140,7 +134,7 @@ class EcephysSession(LazyPropertyMixin): two presentations in seconds on the experiment's master clock. - ''' + """ DETAILED_STIMULUS_PARAMETERS = ( "colorSpace", @@ -163,7 +157,7 @@ class EcephysSession(LazyPropertyMixin): "nDots", "dotSize", "dotLife", - "color_triplet" + "color_triplet", ) @property @@ -184,7 +178,7 @@ def num_stimulus_presentations(self): @property def stimulus_names(self): - return self.stimulus_presentations['stimulus_name'].unique().tolist() + return self.stimulus_presentations["stimulus_name"].unique().tolist() @property def stimulus_conditions(self): @@ -227,15 +221,19 @@ def session_type(self): @property def units(self): - return self._units.drop(columns=['width_rf', - 'height_rf', - 'on_screen_rf', - 'time_to_peak_fl', - 'time_to_peak_rf', - 'time_to_peak_sg', - 'sustained_idx_fl', - 'time_to_peak_dg'], - errors='ignore') + return self._units.drop( + columns=[ + "width_rf", + "height_rf", + "on_screen_rf", + "time_to_peak_fl", + "time_to_peak_rf", + "time_to_peak_sg", + "sustained_idx_fl", + "time_to_peak_dg", + ], + errors="ignore", + ) @property def structure_acronyms(self): @@ -261,13 +259,12 @@ def metadata(self): "session_start_time": self.session_start_time, "ecephys_session_id": self.ecephys_session_id, "structure_acronyms": self.structure_acronyms, - "stimulus_names": self.stimulus_names + "stimulus_names": self.stimulus_names, } @property def stimulus_presentations(self): - return self.__class__._remove_detailed_stimulus_parameters( - self._stimulus_presentations) + return self.__class__._remove_detailed_stimulus_parameters(self._stimulus_presentations) @property def spike_times(self): @@ -277,13 +274,8 @@ def spike_times(self): return self._spike_times - def __init__( - self, - api: EcephysSessionApi, - test: bool = False, - **kwargs - ): - """ Construct an EcephysSession object, which provides access to + def __init__(self, api: EcephysSessionApi, test: bool = False, **kwargs): + """Construct an EcephysSession object, which provides access to detailed data for a single extracellular electrophysiology (neuropixels) session. @@ -303,38 +295,25 @@ def __init__( self.api: EcephysSessionApi = api - self.ecephys_session_id = \ - self.LazyProperty(self.api.get_ecephys_session_id) - self.session_start_time = \ - self.LazyProperty(self.api.get_session_start_time) - self.running_speed = \ - self.LazyProperty(self.api.get_running_speed) - self.mean_waveforms = \ - self.LazyProperty(self.api.get_mean_waveforms, - wrappers=[self._build_mean_waveforms]) - self._spike_times = \ - self.LazyProperty(self.api.get_spike_times, - wrappers=[self._build_spike_times]) - self.optogenetic_stimulation_epochs = \ - self.LazyProperty(self.api.get_optogenetic_stimulation) - self.spike_amplitudes = \ - self.LazyProperty(self.api.get_spike_amplitudes) + self.ecephys_session_id = self.LazyProperty(self.api.get_ecephys_session_id) + self.session_start_time = self.LazyProperty(self.api.get_session_start_time) + self.running_speed = self.LazyProperty(self.api.get_running_speed) + self.mean_waveforms = self.LazyProperty(self.api.get_mean_waveforms, wrappers=[self._build_mean_waveforms]) + self._spike_times = self.LazyProperty(self.api.get_spike_times, wrappers=[self._build_spike_times]) + self.optogenetic_stimulation_epochs = self.LazyProperty(self.api.get_optogenetic_stimulation) + self.spike_amplitudes = self.LazyProperty(self.api.get_spike_amplitudes) self.probes = self.LazyProperty(self.api.get_probes) self.channels = self.LazyProperty(self.api.get_channels) - self._stimulus_presentations = \ - self.LazyProperty( - self.api.get_stimulus_presentations, - wrappers=[self._build_stimulus_presentations, - self._mask_invalid_stimulus_presentations]) - self.inter_presentation_intervals = \ - self.LazyProperty(self._build_inter_presentation_intervals) + self._stimulus_presentations = self.LazyProperty( + self.api.get_stimulus_presentations, + wrappers=[self._build_stimulus_presentations, self._mask_invalid_stimulus_presentations], + ) + self.inter_presentation_intervals = self.LazyProperty(self._build_inter_presentation_intervals) self.invalid_times = self.LazyProperty(self.api.get_invalid_times) - self._units = \ - self.LazyProperty(self.api.get_units, - wrappers=[self._build_units_table]) + self._units = self.LazyProperty(self.api.get_units, wrappers=[self._build_units_table]) self._rig_metadata = self.LazyProperty(self.api.get_rig_metadata) self._metadata = self.LazyProperty(self.api.get_metadata) @@ -342,7 +321,7 @@ def __init__( self.api.test() def get_current_source_density(self, probe_id): - """ Obtain current source density (CSD) of trial-averaged response + """Obtain current source density (CSD) of trial-averaged response to a flash stimuli for this probe. See allensdk.brain_observatory.ecephys.current_source_density for details of CSD calculation. @@ -368,7 +347,7 @@ def get_current_source_density(self, probe_id): return self.api.get_current_source_density(probe_id) def get_lfp(self, probe_id, mask_invalid_intervals=True): - ''' Load an xarray DataArray with LFP data from channels on a + """Load an xarray DataArray with LFP data from channels on a single probe Parameters @@ -389,39 +368,30 @@ def get_lfp(self, probe_id, mask_invalid_intervals=True): Unlike many other data access methods on this class. This one does not cache the loaded data in memory due to the large size of the LFP data. - ''' + """ if mask_invalid_intervals: probe_name = self.probes.loc[probe_id]["description"] fail_tags = ["all_probes", probe_name] - invalid_time_intervals = \ - self._filter_invalid_times_by_tags(fail_tags) + invalid_time_intervals = self._filter_invalid_times_by_tags(fail_tags) lfp = self.api.get_lfp(probe_id) time_points = lfp.time - valid_time_points = \ - self._get_valid_time_points(time_points, - invalid_time_intervals) + valid_time_points = self._get_valid_time_points(time_points, invalid_time_intervals) return lfp.where(cond=valid_time_points) else: return self.api.get_lfp(probe_id) def _get_valid_time_points(self, time_points, invalid_time_intevals): - all_time_points = xr.DataArray( - name="time_points", - data=[True] * len(time_points), - dims=['time'], - coords=[time_points] + name="time_points", data=[True] * len(time_points), dims=["time"], coords=[time_points] ) valid_time_points = all_time_points for ix, invalid_time_interval in invalid_time_intevals.iterrows(): - invalid_time_points = \ - ((time_points >= invalid_time_interval['start_time']) - & (time_points <= invalid_time_interval['stop_time'])) - valid_time_points = \ - np.logical_and(valid_time_points, - np.logical_not(invalid_time_points)) + invalid_time_points = (time_points >= invalid_time_interval["start_time"]) & ( + time_points <= invalid_time_interval["stop_time"] + ) + valid_time_points = np.logical_and(valid_time_points, np.logical_not(invalid_time_points)) return valid_time_points @@ -440,14 +410,13 @@ def _filter_invalid_times_by_tags(self, tags): """ invalid_times = self.invalid_times.copy() if not invalid_times.empty: - mask = invalid_times['tags'].apply(lambda x: - any([t in x for t in tags])) + mask = invalid_times["tags"].apply(lambda x: any([t in x for t in tags])) invalid_times = invalid_times[mask] return invalid_times def get_inter_presentation_intervals_for_stimulus(self, stimulus_names): - ''' Get a subset of this session's inter-presentation intervals, + """Get a subset of this session's inter-presentation intervals, filtered by stimulus name. Parameters @@ -461,35 +430,27 @@ def get_inter_presentation_intervals_for_stimulus(self, stimulus_names): inter-presentation intervals, filtered to the requested stimulus names. - ''' - - stimulus_names = \ - coerce_scalar( - stimulus_names, - 'expected stimulus_names to be a collection (list-like), ' - f'but found {type(stimulus_names)}: {stimulus_names}') - filtered_presentations = \ - self.stimulus_presentations[ - self.stimulus_presentations[ - 'stimulus_name' - ].isin(stimulus_names)] + """ + + stimulus_names = coerce_scalar( + stimulus_names, + "expected stimulus_names to be a collection (list-like), " + f"but found {type(stimulus_names)}: {stimulus_names}", + ) + filtered_presentations = self.stimulus_presentations[ + self.stimulus_presentations["stimulus_name"].isin(stimulus_names) + ] filtered_ids = set(filtered_presentations.index.values) return self.inter_presentation_intervals[ - (self.inter_presentation_intervals.index.isin( - filtered_ids, - level='from_presentation_id')) - & (self.inter_presentation_intervals.index.isin( - filtered_ids, - level='to_presentation_id')) + (self.inter_presentation_intervals.index.isin(filtered_ids, level="from_presentation_id")) + & (self.inter_presentation_intervals.index.isin(filtered_ids, level="to_presentation_id")) ] def get_stimulus_table( - self, - stimulus_names=None, - include_detailed_parameters=False, - include_unused_parameters=False): - '''Get a subset of stimulus presentations by name, with irrelevant + self, stimulus_names=None, include_detailed_parameters=False, include_unused_parameters=False + ): + """Get a subset of stimulus presentations by name, with irrelevant parameters filtered off Parameters @@ -503,35 +464,28 @@ def get_stimulus_table( Rows are filtered presentations, columns are the relevant subset of stimulus parameters - ''' + """ if stimulus_names is None: stimulus_names = self.stimulus_names - stimulus_names = \ - coerce_scalar( - stimulus_names, - 'expected stimulus_names to be a collection (list-like), ' - f'but found {type(stimulus_names)}: {stimulus_names}') - presentations = \ - self._stimulus_presentations[ - self._stimulus_presentations[ - 'stimulus_name' - ].isin(stimulus_names)] + stimulus_names = coerce_scalar( + stimulus_names, + "expected stimulus_names to be a collection (list-like), " + f"but found {type(stimulus_names)}: {stimulus_names}", + ) + presentations = self._stimulus_presentations[self._stimulus_presentations["stimulus_name"].isin(stimulus_names)] if not include_detailed_parameters: - presentations = \ - self.__class__._remove_detailed_stimulus_parameters( - presentations) + presentations = self.__class__._remove_detailed_stimulus_parameters(presentations) if not include_unused_parameters: - presentations = removed_unused_stimulus_presentation_columns( - presentations) + presentations = removed_unused_stimulus_presentation_columns(presentations) return presentations def get_stimulus_epochs(self, duration_thresholds=None): - """ Reports continuous periods of time during which a single kind of + """Reports continuous periods of time during which a single kind of stimulus was presented Parameters @@ -553,29 +507,24 @@ def get_stimulus_epochs(self, duration_thresholds=None): epochs = [] for left, right in zip(diff_indices[:-1], diff_indices[1:]): - epochs.append({ - "start_time": presentations.iloc[left]["start_time"], - "stop_time": presentations.iloc[right - 1]["stop_time"], - "stimulus_name": presentations.iloc[left]["stimulus_name"], - "stimulus_block": presentations.iloc[left]["stimulus_block"] - }) + epochs.append( + { + "start_time": presentations.iloc[left]["start_time"], + "stop_time": presentations.iloc[right - 1]["stop_time"], + "stimulus_name": presentations.iloc[left]["stimulus_name"], + "stimulus_block": presentations.iloc[left]["stimulus_block"], + } + ) epochs = pd.DataFrame(epochs) epochs["duration"] = epochs["stop_time"] - epochs["start_time"] for key, threshold in duration_thresholds.items(): - epochs = epochs[ - (epochs["stimulus_name"] != key) - | (epochs["duration"] >= threshold) - ] + epochs = epochs[(epochs["stimulus_name"] != key) | (epochs["duration"] >= threshold)] - return epochs.loc[:, ["start_time", - "stop_time", - "duration", - "stimulus_name", - "stimulus_block"]] + return epochs.loc[:, ["start_time", "stop_time", "duration", "stimulus_name", "stimulus_block"]] def get_invalid_times(self): - """ Report invalid time intervals with tags describing the scope + """Report invalid time intervals with tags describing the scope of invalid data The tags format: [scope,scope_id,label] @@ -601,9 +550,7 @@ def get_invalid_times(self): return self.invalid_times - def get_screen_gaze_data( - self, - include_filtered_data=False) -> Optional[pd.DataFrame]: + def get_screen_gaze_data(self, include_filtered_data=False) -> Optional[pd.DataFrame]: """Return a dataframe with estimated gaze position on screen. Parameters @@ -623,8 +570,7 @@ def get_screen_gaze_data( *_screen_coordinates_spherical_x_deg *_screen_coorindates_spherical_y_deg """ - return self.api.get_screen_gaze_data( - include_filtered_data=include_filtered_data) + return self.api.get_screen_gaze_data(include_filtered_data=include_filtered_data) def get_pupil_data(self) -> Optional[pd.DataFrame]: """Return a dataframe with eye tracking ellipse fit data @@ -665,18 +611,15 @@ def _mask_invalid_stimulus_presentations(self, stimulus_presentations): invalid_times = self._filter_invalid_times_by_tags(fail_tags) for ix_sp, sp in stimulus_presentations.iterrows(): - stim_epoch = sp['start_time'], sp['stop_time'] + stim_epoch = sp["start_time"], sp["stop_time"] for ix_it, it in invalid_times.iterrows(): - invalid_interval = it['start_time'], it['stop_time'] + invalid_interval = it["start_time"], it["stop_time"] if _overlap(stim_epoch, invalid_interval): stimulus_presentations.iloc[ix_sp, :] = np.nan - stimulus_presentations.at[ix_sp, "stimulus_name"] = \ - "invalid_presentation" - stimulus_presentations.at[ix_sp, "start_time"] = \ - stim_epoch[0] - stimulus_presentations.at[ix_sp, "stop_time"] = \ - stim_epoch[1] + stimulus_presentations.at[ix_sp, "stimulus_name"] = "invalid_presentation" + stimulus_presentations.at[ix_sp, "start_time"] = stim_epoch[0] + stimulus_presentations.at[ix_sp, "stop_time"] = stim_epoch[1] return stimulus_presentations @@ -688,9 +631,9 @@ def presentationwise_spike_counts( binarize=False, dtype=None, large_bin_size_threshold=0.001, - time_domain_callback=None + time_domain_callback=None, ): - ''' Build an array of spike counts surrounding stimulus onset per + """Build an array of spike counts surrounding stimulus onset per unit and stimulus frame. Parameters @@ -720,37 +663,34 @@ def presentationwise_spike_counts( Data array whose dimensions are stimulus presentation, unit, and time bin and whose values are spike counts. - ''' + """ - stimulus_presentations = self._filter_owned_df( - 'stimulus_presentations', - ids=stimulus_presentation_ids) - units = self._filter_owned_df('units', ids=unit_ids) + stimulus_presentations = self._filter_owned_df("stimulus_presentations", ids=stimulus_presentation_ids) + units = self._filter_owned_df("units", ids=unit_ids) largest_bin_size = np.amax(np.diff(bin_edges)) if binarize and largest_bin_size > large_bin_size_threshold: warnings.warn( - 'You\'ve elected to binarize spike counts, but your maximum ' - f'bin width is {largest_bin_size:2.5f} seconds. ' - 'Binarizing spike counts with such a large bin width can ' - 'cause significant loss of accuracy! ' - 'Please consider only binarizing spike counts ' - f'when your bins are <= {large_bin_size_threshold} ' - 'seconds wide.' + "You've elected to binarize spike counts, but your maximum " + f"bin width is {largest_bin_size:2.5f} seconds. " + "Binarizing spike counts with such a large bin width can " + "cause significant loss of accuracy! " + "Please consider only binarizing spike counts " + f"when your bins are <= {large_bin_size_threshold} " + "seconds wide." ) bin_edges = np.array(bin_edges) domain = build_time_window_domain( - bin_edges, - stimulus_presentations['start_time'].values, - callback=time_domain_callback) + bin_edges, stimulus_presentations["start_time"].values, callback=time_domain_callback + ) out_of_order = np.where(np.diff(domain, axis=1) < 0) if len(out_of_order[0]) > 0: - out_of_order_time_bins = \ - [(row, col) for row, col in zip(out_of_order)] - raise ValueError("The time domain specified contains out-of-order " - f"bin edges at indices: {out_of_order_time_bins}") + out_of_order_time_bins = [(row, col) for row, col in zip(out_of_order)] + raise ValueError( + f"The time domain specified contains out-of-order bin edges at indices: {out_of_order_time_bins}" + ) ends = domain[:, -1] starts = domain[:, 0] @@ -761,42 +701,32 @@ def presentationwise_spike_counts( # Ignoring intervals that overlaps multiple time bins because # trying to figure that out would take O(n) overlapping = [(s, s + 1) for s in overlapping] - warnings.warn("You've specified some overlapping time intervals " - f"between neighboring rows: {overlapping}, " - "with a maximum overlap of" - f" {np.abs(np.min(time_diffs))} seconds.") - - tiled_data = build_spike_histogram( - domain, - self.spike_times, - units.index.values, - dtype=dtype, - binarize=binarize - ) + warnings.warn( + "You've specified some overlapping time intervals " + f"between neighboring rows: {overlapping}, " + "with a maximum overlap of" + f" {np.abs(np.min(time_diffs))} seconds." + ) + + tiled_data = build_spike_histogram(domain, self.spike_times, units.index.values, dtype=dtype, binarize=binarize) stim_presentation_id = stimulus_presentations.index.values tiled_data = xr.DataArray( - name='spike_counts', + name="spike_counts", data=tiled_data, coords={ - 'stimulus_presentation_id': stim_presentation_id, - 'time_relative_to_stimulus_onset': (bin_edges[:-1] + - np.diff(bin_edges) / 2), - 'unit_id': units.index.values + "stimulus_presentation_id": stim_presentation_id, + "time_relative_to_stimulus_onset": (bin_edges[:-1] + np.diff(bin_edges) / 2), + "unit_id": units.index.values, }, - dims=['stimulus_presentation_id', - 'time_relative_to_stimulus_onset', - 'unit_id'] + dims=["stimulus_presentation_id", "time_relative_to_stimulus_onset", "unit_id"], ) return tiled_data - def presentationwise_spike_times( - self, - stimulus_presentation_ids=None, - unit_ids=None): - ''' Produce a table associating spike times with units and + def presentationwise_spike_times(self, stimulus_presentation_ids=None, unit_ids=None): + """Produce a table associating spike times with units and stimulus presentations Parameters @@ -817,18 +747,14 @@ def presentationwise_spike_times( The stimulus presentation on which this spike occurred. unit_id : int The unit that emitted this spike. - ''' + """ - stimulus_presentations = \ - self._filter_owned_df('stimulus_presentations', - ids=stimulus_presentation_ids) - units = self._filter_owned_df('units', ids=unit_ids) + stimulus_presentations = self._filter_owned_df("stimulus_presentations", ids=stimulus_presentation_ids) + units = self._filter_owned_df("units", ids=unit_ids) presentation_times = np.zeros([stimulus_presentations.shape[0] * 2]) - presentation_times[::2] = \ - np.array(stimulus_presentations['start_time']) - presentation_times[1::2] = \ - np.array(stimulus_presentations['stop_time']) + presentation_times[::2] = np.array(stimulus_presentations["start_time"]) + presentation_times[1::2] = np.array(stimulus_presentations["stop_time"]) all_presentation_ids = np.array(stimulus_presentations.index.values) presentation_ids = [] @@ -840,60 +766,49 @@ def presentationwise_spike_times( indices = np.searchsorted(presentation_times, data) - 1 index_valid = indices % 2 == 0 - presentations = \ - all_presentation_ids[np.floor(indices / 2).astype(int)] + presentations = all_presentation_ids[np.floor(indices / 2).astype(int)] sorder = np.argsort(presentations) presentations = presentations[sorder] index_valid = index_valid[sorder] data = data[sorder] - changes = \ - np.where(np.ediff1d(presentations, to_begin=1, to_end=1))[0] + changes = np.where(np.ediff1d(presentations, to_begin=1, to_end=1))[0] for ii, jj in zip(changes[:-1], changes[1:]): values = data[ii:jj][index_valid[ii:jj]] if values.size == 0: continue unit_ids.append(np.zeros([values.size]) + unit_id) - presentation_ids.append(np.zeros([values.size]) + - presentations[ii]) + presentation_ids.append(np.zeros([values.size]) + presentations[ii]) spike_times.append(values) if not spike_times: # If there are no units firing during the given stimulus return an # empty dataframe - return pd.DataFrame(columns=[ - 'spike_times', - 'stimulus_presentation', - 'unit_id', - 'time_since_stimulus_presentation_onset']) + return pd.DataFrame( + columns=["spike_times", "stimulus_presentation", "unit_id", "time_since_stimulus_presentation_onset"] + ) pres_ids = np.concatenate(presentation_ids).astype(int) - spike_df = pd.DataFrame({ - 'stimulus_presentation_id': pres_ids, - 'unit_id': np.concatenate(unit_ids).astype(int) - }, index=pd.Index(np.concatenate(spike_times), name='spike_time')) + spike_df = pd.DataFrame( + {"stimulus_presentation_id": pres_ids, "unit_id": np.concatenate(unit_ids).astype(int)}, + index=pd.Index(np.concatenate(spike_times), name="spike_time"), + ) # Add time since stimulus presentation onset - onset_times = self._filter_owned_df( - "stimulus_presentations", ids=all_presentation_ids)["start_time"] - spikes_with_onset = spike_df.join(onset_times, - on=["stimulus_presentation_id"]) + onset_times = self._filter_owned_df("stimulus_presentations", ids=all_presentation_ids)["start_time"] + spikes_with_onset = spike_df.join(onset_times, on=["stimulus_presentation_id"]) spikes_with_onset["time_since_stimulus_presentation_onset"] = ( spikes_with_onset.index - spikes_with_onset["start_time"] ) - spikes_with_onset.sort_values('spike_time', axis=0, inplace=True) + spikes_with_onset.sort_values("spike_time", axis=0, inplace=True) spikes_with_onset.drop(columns=["start_time"], inplace=True) return spikes_with_onset - def conditionwise_spike_statistics( - self, - stimulus_presentation_ids=None, - unit_ids=None, - use_rates=False): - """ Produce summary statistics for each distinct stimulus condition + def conditionwise_spike_statistics(self, stimulus_presentation_ids=None, unit_ids=None, use_rates=False): + """Produce summary statistics for each distinct stimulus condition Parameters ---------- @@ -918,49 +833,42 @@ def conditionwise_spike_statistics( # TODO: To use filter_owned_df() make sure to convert the results # from a Series to a Dataframe stimulus_presentation_ids = ( - stimulus_presentation_ids if stimulus_presentation_ids is not None - else self.stimulus_presentations.index.values) # In case + stimulus_presentation_ids + if stimulus_presentation_ids is not None + else self.stimulus_presentations.index.values + ) # In case presentations = self.stimulus_presentations.loc[ stimulus_presentation_ids, ["stimulus_condition_id", "duration"] - ] + ] spikes = self.presentationwise_spike_times( - stimulus_presentation_ids=stimulus_presentation_ids, - unit_ids=unit_ids + stimulus_presentation_ids=stimulus_presentation_ids, unit_ids=unit_ids ) if spikes.empty: # In the case there are no spikes spike_counts = pd.DataFrame( - {'spike_count': 0}, - index=pd.MultiIndex.from_product([ - stimulus_presentation_ids, - unit_ids], - names=['stimulus_presentation_id', 'unit_id'])) + {"spike_count": 0}, + index=pd.MultiIndex.from_product( + [stimulus_presentation_ids, unit_ids], names=["stimulus_presentation_id", "unit_id"] + ), + ) else: spike_counts = spikes.copy() spike_counts["spike_count"] = np.zeros(spike_counts.shape[0]) - spike_counts = \ - spike_counts.groupby(["stimulus_presentation_id", - "unit_id"]).count() + spike_counts = spike_counts.groupby(["stimulus_presentation_id", "unit_id"]).count() # If not explicity stated get unit ids from spikes table. - unit_ids = unit_ids if unit_ids is not None \ - else spikes['unit_id'].unique() - spike_counts = \ - spike_counts.reindex( - pd.MultiIndex.from_product( - [stimulus_presentation_ids, - unit_ids], - names=['stimulus_presentation_id', - 'unit_id']), fill_value=0) - - sp = pd.merge(spike_counts, - presentations, - left_on="stimulus_presentation_id", - right_index=True, - how="left") + unit_ids = unit_ids if unit_ids is not None else spikes["unit_id"].unique() + spike_counts = spike_counts.reindex( + pd.MultiIndex.from_product( + [stimulus_presentation_ids, unit_ids], names=["stimulus_presentation_id", "unit_id"] + ), + fill_value=0, + ) + + sp = pd.merge(spike_counts, presentations, left_on="stimulus_presentation_id", right_index=True, how="left") sp.reset_index(inplace=True) if use_rates: @@ -975,15 +883,10 @@ def conditionwise_spike_statistics( for ind, gr in sp.groupby(["stimulus_condition_id", "unit_id"]): summary.append(extractor(ind, gr)) - return pd.DataFrame(summary).set_index(keys=[ - "unit_id", - "stimulus_condition_id"]) + return pd.DataFrame(summary).set_index(keys=["unit_id", "stimulus_condition_id"]) - def get_parameter_values_for_stimulus( - self, - stimulus_name, - drop_nulls=True): - """ For each stimulus parameter, report the unique values taken + def get_parameter_values_for_stimulus(self, stimulus_name, drop_nulls=True): + """For each stimulus parameter, report the unique values taken on by that parameter while a named stimulus was presented. Parameters @@ -998,17 +901,11 @@ def get_parameter_values_for_stimulus( """ - presentation_ids = \ - self.get_stimulus_table([stimulus_name]).index.values - return self.get_stimulus_parameter_values( - presentation_ids, - drop_nulls=drop_nulls) - - def get_stimulus_parameter_values( - self, - stimulus_presentation_ids=None, - drop_nulls=True): - ''' For each stimulus parameter, report the unique values taken + presentation_ids = self.get_stimulus_table([stimulus_name]).index.values + return self.get_stimulus_parameter_values(presentation_ids, drop_nulls=drop_nulls) + + def get_stimulus_parameter_values(self, stimulus_presentation_ids=None, drop_nulls=True): + """For each stimulus parameter, report the unique values taken on by that parameter throughout the course of the session. Parameters @@ -1022,17 +919,11 @@ def get_stimulus_parameter_values( dict : maps parameters (column names) to their unique values. - ''' + """ - stimulus_presentations = \ - self._filter_owned_df('stimulus_presentations', - ids=stimulus_presentation_ids) - stimulus_presentations = \ - stimulus_presentations.drop( - columns=list(NON_STIMULUS_PARAMETERS) + ['stimulus_name']) - stimulus_presentations = \ - removed_unused_stimulus_presentation_columns( - stimulus_presentations) + stimulus_presentations = self._filter_owned_df("stimulus_presentations", ids=stimulus_presentation_ids) + stimulus_presentations = stimulus_presentations.drop(columns=list(NON_STIMULUS_PARAMETERS) + ["stimulus_name"]) + stimulus_presentations = removed_unused_stimulus_presentation_columns(stimulus_presentations) parameters = {} for colname in stimulus_presentations.columns: @@ -1049,8 +940,7 @@ def get_stimulus_parameter_values( return parameters def channel_structure_intervals(self, channel_ids): - - """ find on a list of channels the intervals of channels inserted + """find on a list of channels the intervals of channels inserted into particular structures Parameters @@ -1078,8 +968,7 @@ def channel_structure_intervals(self, channel_ids): unique_probes = table["probe_id"].unique() if len(unique_probes) > 1: - warnings.warn("Calculating structure boundaries across channels " - "from multiple probes.") + warnings.warn("Calculating structure boundaries across channels from multiple probes.") intervals = nan_intervals(table[structure_id_key].values) labels = table[structure_label_key].iloc[intervals[:-1]].values @@ -1098,58 +987,37 @@ def _build_spike_times(self, spike_times): return output_spike_times - def _build_stimulus_presentations( - self, - stimulus_presentations, - nonapplicable="null"): - stimulus_presentations.index.name = 'stimulus_presentation_id' - stimulus_presentations = \ - stimulus_presentations.drop(columns=['stimulus_index']) + def _build_stimulus_presentations(self, stimulus_presentations, nonapplicable="null"): + stimulus_presentations.index.name = "stimulus_presentation_id" + stimulus_presentations = stimulus_presentations.drop(columns=["stimulus_index"]) # TODO: putting these here for now; after SWDB 2019, will rerun # stimulus table module for all sessions and can remove these - stimulus_presentations = \ - naming_utilities.collapse_columns(stimulus_presentations) - stimulus_presentations = \ - naming_utilities.standardize_movie_numbers(stimulus_presentations) - stimulus_presentations = \ - naming_utilities.add_number_to_shuffled_movie( - stimulus_presentations) - stimulus_presentations = \ - naming_utilities.map_stimulus_names( - stimulus_presentations, default_stimulus_renames) - stimulus_presentations = \ - naming_utilities.map_column_names( - stimulus_presentations, - default_column_renames, - ignore_case=False) + stimulus_presentations = naming_utilities.collapse_columns(stimulus_presentations) + stimulus_presentations = naming_utilities.standardize_movie_numbers(stimulus_presentations) + stimulus_presentations = naming_utilities.add_number_to_shuffled_movie(stimulus_presentations) + stimulus_presentations = naming_utilities.map_stimulus_names(stimulus_presentations, default_stimulus_renames) + stimulus_presentations = naming_utilities.map_column_names( + stimulus_presentations, default_column_renames, ignore_case=False + ) # pandas groupby ops ignore nans, so we need a new "nonapplicable" # value that pandas does not recognize as null ... stimulus_presentations.replace("", nonapplicable, inplace=True) # pandas does not automatically convert boolean cols for fillna - boolean_colnames = stimulus_presentations.dtypes[ - stimulus_presentations.dtypes == "boolean"].index + boolean_colnames = stimulus_presentations.dtypes[stimulus_presentations.dtypes == "boolean"].index col_type_map = {colname: "object" for colname in boolean_colnames} - stimulus_presentations = stimulus_presentations.astype( - col_type_map).fillna(nonapplicable) + stimulus_presentations = stimulus_presentations.astype(col_type_map).fillna(nonapplicable) # eval str(numeric) and str(lists) # convert lists to tuple for hashability # Rationale: pd dataframe reads values as str from nwb files # where they are expected to be float col_list = ["phase, size, spatial_frequency"] - stimulus_presentations = literal_col_eval( - stimulus_presentations, - columns=col_list) - stimulus_presentations = df_list_to_tuple( - stimulus_presentations, - columns=col_list) - stimulus_presentations["duration"] = ( - stimulus_presentations["stop_time"] - - stimulus_presentations["start_time"] - ) + stimulus_presentations = literal_col_eval(stimulus_presentations, columns=col_list) + stimulus_presentations = df_list_to_tuple(stimulus_presentations, columns=col_list) + stimulus_presentations["duration"] = stimulus_presentations["stop_time"] - stimulus_presentations["start_time"] # TODO: database these stimulus_conditions = {} presentation_conditions = [] @@ -1158,13 +1026,8 @@ def _build_stimulus_presentations( # TODO: Can we have parameters on what columns to omit? # If stimulus_block or duration is left in it can affect # how conditionwise_spike_statistics counts spikes - params_only = \ - stimulus_presentations.drop(columns=["start_time", - "stop_time", - "duration", - "stimulus_block"]) + params_only = stimulus_presentations.drop(columns=["start_time", "stop_time", "duration", "stimulus_block"]) for row in params_only.itertuples(index=False): - if row in stimulus_conditions: cid = stimulus_conditions[row] else: @@ -1181,11 +1044,8 @@ def _build_stimulus_presentations( cond_ids.append(ci) cond_vals.append(cv) - self._stimulus_conditions = \ - pd.DataFrame(cond_vals, index=pd.Index(data=cond_ids, - name="stimulus_condition_id")) - stimulus_presentations["stimulus_condition_id"] = \ - presentation_conditions + self._stimulus_conditions = pd.DataFrame(cond_vals, index=pd.Index(data=cond_ids, name="stimulus_condition_id")) + stimulus_presentations["stimulus_condition_id"] = presentation_conditions return stimulus_presentations @@ -1194,40 +1054,34 @@ def _build_units_table(self, units_table): probes = self.probes.copy() self._unmerged_units = units_table.copy() - table = pd.merge(units_table, - channels, - left_on='peak_channel_id', - right_index=True, - suffixes=['_unit', '_channel']) - table = pd.merge(table, - probes, - left_on='probe_id', - right_index=True, - suffixes=['_unit', '_probe']) - - table.index.name = 'unit_id' - table = table.rename(columns={ - 'description': 'probe_description', - 'local_index_channel': 'channel_local_index', - 'PT_ratio': 'waveform_PT_ratio', - 'amplitude': 'waveform_amplitude', - 'duration': 'waveform_duration', - 'halfwidth': 'waveform_halfwidth', - 'recovery_slope': 'waveform_recovery_slope', - 'repolarization_slope': 'waveform_repolarization_slope', - 'spread': 'waveform_spread', - 'velocity_above': 'waveform_velocity_above', - 'velocity_below': 'waveform_velocity_below', - 'sampling_rate': 'probe_sampling_rate', - 'lfp_sampling_rate': 'probe_lfp_sampling_rate', - 'has_lfp_data': 'probe_has_lfp_data', - 'l_ratio': 'L_ratio', - 'pref_images_multi_ns': 'pref_image_multi_ns', - }) - - return table.sort_values(by=['probe_description', - 'probe_vertical_position', - 'probe_horizontal_position']) + table = pd.merge( + units_table, channels, left_on="peak_channel_id", right_index=True, suffixes=["_unit", "_channel"] + ) + table = pd.merge(table, probes, left_on="probe_id", right_index=True, suffixes=["_unit", "_probe"]) + + table.index.name = "unit_id" + table = table.rename( + columns={ + "description": "probe_description", + "local_index_channel": "channel_local_index", + "PT_ratio": "waveform_PT_ratio", + "amplitude": "waveform_amplitude", + "duration": "waveform_duration", + "halfwidth": "waveform_halfwidth", + "recovery_slope": "waveform_recovery_slope", + "repolarization_slope": "waveform_repolarization_slope", + "spread": "waveform_spread", + "velocity_above": "waveform_velocity_above", + "velocity_below": "waveform_velocity_below", + "sampling_rate": "probe_sampling_rate", + "lfp_sampling_rate": "probe_lfp_sampling_rate", + "has_lfp_data": "probe_has_lfp_data", + "l_ratio": "L_ratio", + "pref_images_multi_ns": "pref_image_multi_ns", + } + ) + + return table.sort_values(by=["probe_description", "probe_vertical_position", "probe_horizontal_position"]) def _build_nwb1_waveforms(self, mean_waveforms): # _build_mean_waveforms() assumes every unit has the same number of @@ -1236,21 +1090,17 @@ def _build_nwb1_waveforms(self, mean_waveforms): # ONE channel units_df = self._units output_waveforms = {} - sampling_rate_lu = { - uid: self.probes.loc[ - r['probe_id'] - ]['sampling_rate'] for uid, r in units_df.iterrows() - } + sampling_rate_lu = {uid: self.probes.loc[r["probe_id"]]["sampling_rate"] for uid, r in units_df.iterrows()} for uid in list(mean_waveforms.keys()): data = mean_waveforms.pop(uid) output_waveforms[uid] = xr.DataArray( data=data, - dims=['channel_id', 'time'], + dims=["channel_id", "time"], coords={ - 'channel_id': [units_df.loc[uid]['peak_channel_id']], - 'time': np.arange(data.shape[1]) / sampling_rate_lu[uid] - } + "channel_id": [units_df.loc[uid]["peak_channel_id"]], + "time": np.arange(data.shape[1]) / sampling_rate_lu[uid], + }, ) return output_waveforms @@ -1261,14 +1111,14 @@ def _build_mean_waveforms(self, mean_waveforms): channel_id_lut = defaultdict(lambda: -1) for cid, row in self.channels.iterrows(): - channel_id_lut[( - row["probe_channel_number"], - row["probe_id"], - )] = cid + channel_id_lut[ + ( + row["probe_channel_number"], + row["probe_id"], + ) + ] = cid - probe_id_lut = { - uid: row['probe_id'] for uid, row in self._units.iterrows() - } + probe_id_lut = {uid: row["probe_id"] for uid, row in self._units.iterrows()} output_waveforms = {} for uid in list(mean_waveforms.keys()): @@ -1280,40 +1130,35 @@ def _build_mean_waveforms(self, mean_waveforms): probe_id = probe_id_lut[uid] - time_vals = np.arange(data.shape[1]) / \ - self.probes.loc[probe_id]['sampling_rate'] + time_vals = np.arange(data.shape[1]) / self.probes.loc[probe_id]["sampling_rate"] output_waveforms[uid] = xr.DataArray( data=data, - dims=['channel_id', 'time'], + dims=["channel_id", "time"], coords={ - 'channel_id': [channel_id_lut[(ii, probe_id)] - for ii in range(data.shape[0])], - 'time': time_vals - } + "channel_id": [channel_id_lut[(ii, probe_id)] for ii in range(data.shape[0])], + "time": time_vals, + }, ) - output_waveforms[uid] = \ - output_waveforms[uid][ - output_waveforms[uid]["channel_id"] != -1 - ] + output_waveforms[uid] = output_waveforms[uid][output_waveforms[uid]["channel_id"] != -1] return output_waveforms def _build_inter_presentation_intervals(self): - from_presentation_id = self.stimulus_presentations.index.values[:-1] to_presentation_id = self.stimulus_presentations.index.values[1:] - interval1 = self.stimulus_presentations['start_time'].values[1:] - interval2 = self.stimulus_presentations['stop_time'].values[:-1] - - intervals = pd.DataFrame({ - 'from_presentation_id': from_presentation_id, - 'to_presentation_id': to_presentation_id, - 'interval': interval1 - interval2 - }) - return intervals.set_index(['from_presentation_id', - 'to_presentation_id'], inplace=False) + interval1 = self.stimulus_presentations["start_time"].values[1:] + interval2 = self.stimulus_presentations["stop_time"].values[:-1] + + intervals = pd.DataFrame( + { + "from_presentation_id": from_presentation_id, + "to_presentation_id": to_presentation_id, + "interval": interval1 - interval2, + } + ) + return intervals.set_index(["from_presentation_id", "to_presentation_id"], inplace=False) def _filter_owned_df(self, key, ids=None, copy=True): df = getattr(self, key) @@ -1324,14 +1169,12 @@ def _filter_owned_df(self, key, ids=None, copy=True): if ids is None: return df - ids = coerce_scalar( - ids, f'a scalar ({ids}) was ' - f'provided as ids, filtering to a single row of {key}.') + ids = coerce_scalar(ids, f"a scalar ({ids}) was provided as ids, filtering to a single row of {key}.") df = df.loc[ids] if df.shape[0] == 0: - warnings.warn(f'filtering to an empty set of {key}!') + warnings.warn(f"filtering to an empty set of {key}!") return df @@ -1355,37 +1198,31 @@ def from_nwb_path(cls, path, nwb_version=2, api_kwargs=None, **kwargs): NWBAdaptorCls = EcephysNwb1Api else: - raise Exception(f'specified NWB version {nwb_version} not ' - 'supported. Supported versions are: 2.X, 1.X') + raise Exception(f"specified NWB version {nwb_version} not supported. Supported versions are: 2.X, 1.X") - return cls(api=NWBAdaptorCls.from_path(path=path, - **api_kwargs), **kwargs) + return cls(api=NWBAdaptorCls.from_path(path=path, **api_kwargs), **kwargs) def _warn_invalid_spike_intervals(self): - fail_tags = list(self.probes["description"]) fail_tags.append("all_probes") invalid_time_intervals = self._filter_invalid_times_by_tags(fail_tags) if not invalid_time_intervals.empty: - warnings.warn("Session includes invalid time intervals that could " - "be accessed with the attribute 'invalid_times'," - "Spikes within these intervals are invalid and may " - "need to be excluded from the analysis.") - + warnings.warn( + "Session includes invalid time intervals that could " + "be accessed with the attribute 'invalid_times'," + "Spikes within these intervals are invalid and may " + "need to be excluded from the analysis." + ) -def build_spike_histogram(time_domain, - spike_times, - unit_ids, - dtype=None, - binarize=False): +def build_spike_histogram(time_domain, spike_times, unit_ids, dtype=None, binarize=False): time_domain = np.array(time_domain) unit_ids = np.array(unit_ids) tiled_data = np.zeros( (time_domain.shape[0], time_domain.shape[1] - 1, unit_ids.size), - dtype=(np.uint8 if binarize else np.uint16) if dtype is None else dtype + dtype=(np.uint8 if binarize else np.uint16) if dtype is None else dtype, ) starts = time_domain[:, :-1] @@ -1396,7 +1233,7 @@ def build_spike_histogram(time_domain, start_positions = np.searchsorted(data, starts.flat) end_positions = np.searchsorted(data, ends.flat, side="right") - counts = (end_positions - start_positions) + counts = end_positions - start_positions tiled_data[:, :, ii].flat = counts > 0 if binarize else counts @@ -1415,15 +1252,15 @@ def removed_unused_stimulus_presentation_columns(stimulus_presentations): for cn in stimulus_presentations.columns: if np.all(stimulus_presentations[cn].isna()): to_drop.append(cn) - elif np.all(stimulus_presentations[cn].astype(str).values == ''): + elif np.all(stimulus_presentations[cn].astype(str).values == ""): to_drop.append(cn) - elif np.all(stimulus_presentations[cn].astype(str).values == 'null'): + elif np.all(stimulus_presentations[cn].astype(str).values == "null"): to_drop.append(cn) return stimulus_presentations.drop(columns=to_drop) def nan_intervals(array, nan_like=["null"]): - """ find interval bounds (bounding consecutive identical values) in an + """find interval bounds (bounding consecutive identical values) in an array, which may contain nans Parameters @@ -1461,7 +1298,7 @@ def is_distinct_from(left, right): def array_intervals(array): - """ find interval bounds (bounding consecutive identical values) + """find interval bounds (bounding consecutive identical values) in an array Parameters @@ -1496,7 +1333,7 @@ def _extract_summary_count_statistics(index, group): "stimulus_presentation_count": group.shape[0], "spike_mean": np.mean(group["spike_count"].values), "spike_std": np.std(group["spike_count"].values, ddof=1), - "spike_sem": scipy.stats.sem(group["spike_count"].values) + "spike_sem": scipy.stats.sem(group["spike_count"].values), } @@ -1507,7 +1344,7 @@ def _extract_summary_rate_statistics(index, group): "stimulus_presentation_count": group.shape[0], "spike_mean": np.mean(group["spike_rate"].values), "spike_std": np.std(group["spike_rate"].values, ddof=1), - "spike_sem": scipy.stats.sem(group["spike_rate"].values) + "spike_sem": scipy.stats.sem(group["spike_rate"].values), } diff --git a/allensdk/brain_observatory/ecephys/ecephys_session_api/__init__.py b/allensdk/brain_observatory/ecephys/ecephys_session_api/__init__.py index 02c224d9a7..4eb6111a34 100644 --- a/allensdk/brain_observatory/ecephys/ecephys_session_api/__init__.py +++ b/allensdk/brain_observatory/ecephys/ecephys_session_api/__init__.py @@ -6,4 +6,4 @@ "EcephysSessionApi", "EcephysNwbSessionApi", "EcephysNwb1Api", -] \ No newline at end of file +] diff --git a/allensdk/brain_observatory/ecephys/ecephys_session_api/ecephys_nwb1_session_api.py b/allensdk/brain_observatory/ecephys/ecephys_session_api/ecephys_nwb1_session_api.py index 652a38c835..9923061b22 100644 --- a/allensdk/brain_observatory/ecephys/ecephys_session_api/ecephys_nwb1_session_api.py +++ b/allensdk/brain_observatory/ecephys/ecephys_session_api/ecephys_nwb1_session_api.py @@ -58,16 +58,14 @@ class EcephysNwb1Api(EcephysSessionApi): def __init__(self, path, *args, **kwargs): self._path = path - self._h5_root = h5py.File(self._path, 'r') + self._h5_root = h5py.File(self._path, "r") try: # check file is a valid NWB 1 file - version_str = self._h5_root['nwb_version'][()] + version_str = self._h5_root["nwb_version"][()] if isinstance(version_str, bytes): - version_str = version_str.decode('utf-8') - if not (version_str.startswith('NWB-1.') or - version_str.startswith('1.')): - raise Exception( - '{} is not a valid NWB 1 file path'.format(self._path)) + version_str = version_str.decode("utf-8") + if not (version_str.startswith("NWB-1.") or version_str.startswith("1.")): + raise Exception("{} is not a valid NWB 1 file path".format(self._path)) except Exception: raise @@ -82,50 +80,54 @@ def __init__(self, path, *args, **kwargs): @property def processing_grp(self): - return self._h5_root['/processing'] + return self._h5_root["/processing"] @property def running_speed_grp(self): - return self._h5_root['/acquisition/timeseries/RunningSpeed'] + return self._h5_root["/acquisition/timeseries/RunningSpeed"] def _probe_groups(self): - return [(pname, pgrp) for pname, pgrp in self.processing_grp.items() - if isinstance(pgrp, h5py.Group) and pname.lower().startswith( - 'probe')] + return [ + (pname, pgrp) + for pname, pgrp in self.processing_grp.items() + if isinstance(pgrp, h5py.Group) and pname.lower().startswith("probe") + ] def get_running_speed(self): running_speed_grp = self.running_speed_grp - return pd.DataFrame({ - "start_time": running_speed_grp['timestamps'][:], - "velocity": running_speed_grp['data'][:] - # average velocities over a given interval - }) + return pd.DataFrame( + { + "start_time": running_speed_grp["timestamps"][:], + "velocity": running_speed_grp["data"][:], + # average velocities over a given interval + } + ) __stim_col_map = { # Used for mapping column names from NWB 1.0 features ds to their # appropiate NWB 2.0 name - b'temporal_frequency': 'TF', - b'spatial_frequency': 'SF', - b'pos_x': 'Pos_x', - b'pos_y': 'Pos_y', - b'orientation': 'Ori', - b'color': 'Color', - b'phase': 'Phase', - b'frame': 'Image' + b"temporal_frequency": "TF", + b"spatial_frequency": "SF", + b"pos_x": "Pos_x", + b"pos_y": "Pos_y", + b"orientation": "Ori", + b"color": "Color", + b"phase": "Phase", + b"frame": "Image", } def get_stimulus_presentations(self) -> pd.DataFrame: # TODO: Missing 'stimulus_block', 'stimulus_index, Image, stimulus_presentations_df = None presentation_ids = 0 # make up a id for every stim-presentation - stim_pres_grp = self._h5_root['/stimulus/presentation'] + stim_pres_grp = self._h5_root["/stimulus/presentation"] # Stimulus-presentations are heirarchily grouped by presentation # name. Iterate through all of them and build # a single table. for block_i, (stim_name, stim_grp) in enumerate(stim_pres_grp.items()): - timestamps = stim_grp['timestamps'][()] + timestamps = stim_grp["timestamps"][()] start_times = timestamps[:, 0] if timestamps.shape[1] == 2: stop_times = timestamps[:, 1] @@ -135,48 +137,45 @@ def get_stimulus_presentations(self) -> pd.DataFrame: continue stop_times = np.nan - n_stims = stim_grp['num_samples'][()] + n_stims = stim_grp["num_samples"][()] try: # parse the features/data datasets, map old column names ( # temporal freq->TF, phase-> phase, etc). stim_props = { - self.__stim_col_map.get( - ftr_name, ftr_name): stim_grp['data'][:, i] - for i, ftr_name in enumerate(stim_grp['features'][()])} + self.__stim_col_map.get(ftr_name, ftr_name): stim_grp["data"][:, i] + for i, ftr_name in enumerate(stim_grp["features"][()]) + } except Exception: stim_props = {} - stim_df = pd.DataFrame({ - 'stimulus_presentation_id': np.arange(presentation_ids, - presentation_ids + - n_stims), - 'start_time': start_times, - 'stop_time': stop_times, - 'stimulus_name': stim_name, - 'TF': stim_props.get('TF', np.nan), - 'SF': stim_props.get('SF', np.nan), - 'Ori': stim_props.get('Ori', np.nan), - 'Pos_x': stim_props.get('Pos_x', np.nan), - 'Pos_y': stim_props.get('Pos_y', np.nan), - 'Color': stim_props.get('Color', np.nan), - 'Phase': stim_props.get('Phase', np.nan), - 'Image': stim_props.get('Image', np.nan), - 'stimulus_block': block_i - # Required by conditionwise_spike_counts(), add made-up number - }) + stim_df = pd.DataFrame( + { + "stimulus_presentation_id": np.arange(presentation_ids, presentation_ids + n_stims), + "start_time": start_times, + "stop_time": stop_times, + "stimulus_name": stim_name, + "TF": stim_props.get("TF", np.nan), + "SF": stim_props.get("SF", np.nan), + "Ori": stim_props.get("Ori", np.nan), + "Pos_x": stim_props.get("Pos_x", np.nan), + "Pos_y": stim_props.get("Pos_y", np.nan), + "Color": stim_props.get("Color", np.nan), + "Phase": stim_props.get("Phase", np.nan), + "Image": stim_props.get("Image", np.nan), + "stimulus_block": block_i, + # Required by conditionwise_spike_counts(), add made-up number + } + ) presentation_ids += n_stims if stimulus_presentations_df is None: stimulus_presentations_df = stim_df else: - stimulus_presentations_df = pd.concat( - [stimulus_presentations_df, stim_df]) + stimulus_presentations_df = pd.concat([stimulus_presentations_df, stim_df]) - stimulus_presentations_df[ - 'stimulus_index'] = 0 # I'm not sure what column is, but is + stimulus_presentations_df["stimulus_index"] = 0 # I'm not sure what column is, but is # droped by EcephysSession - stimulus_presentations_df.set_index('stimulus_presentation_id', - inplace=True) + stimulus_presentations_df.set_index("stimulus_presentation_id", inplace=True) return stimulus_presentations_df def get_probes(self) -> pd.DataFrame: @@ -186,15 +185,17 @@ def get_probes(self) -> pd.DataFrame: probe_ids.append(self._probe_ids[prb_name]) locations.append(prb_name) - probes_df = pd.DataFrame({ - 'id': pd.Series(probe_ids, dtype=np.uint64), - 'location': pd.Series(locations, dtype=object), - 'description': "" # TODO: Find description - }) - probes_df.set_index('id', inplace=True) + probes_df = pd.DataFrame( + { + "id": pd.Series(probe_ids, dtype=np.uint64), + "location": pd.Series(locations, dtype=object), + "description": "", # TODO: Find description + } + ) + probes_df.set_index("id", inplace=True) # TODO: calculate real sampling rate for each probe. - probes_df['sampling_rate'] = 30000.0 + probes_df["sampling_rate"] = 30000.0 return probes_df @@ -202,8 +203,7 @@ def get_channels(self) -> pd.DataFrame: # TODO: Missing: structure_id processing_grp = self.processing_grp - max_channels = sum( - len(prb_grp['unit_list']) for prb_grp in processing_grp.values()) + max_channels = sum(len(prb_grp["unit_list"]) for prb_grp in processing_grp.values()) channel_ids = np.zeros(max_channels, dtype=np.uint64) local_channel_indices = np.zeros(max_channels, dtype=np.int64) prb_ids = np.zeros(max_channels, dtype=np.uint64) @@ -218,10 +218,10 @@ def get_channels(self) -> pd.DataFrame: # to get information about all available channels for prb_name, prb_grp in self._probe_groups(): prb_id = self._probe_ids[prb_name] - unit_list = prb_grp['unit_list'][()] + unit_list = prb_grp["unit_list"][()] for indx, uid in enumerate(unit_list): - unit_grp = prb_grp['UnitTimes'][str(uid)] - local_channel_index = unit_grp['channel'][()] + unit_grp = prb_grp["UnitTimes"][str(uid)] + local_channel_index = unit_grp["channel"][()] channel_id = self._channel_ids[(prb_name, local_channel_index)] if channel_id in existing_channels: # If a channel has already been processed (ie it's @@ -232,29 +232,29 @@ def get_channels(self) -> pd.DataFrame: channel_ids[channel_indx] = channel_id local_channel_indices[channel_indx] = local_channel_index prb_ids[channel_indx] = prb_id - prb_hrz_pos[channel_indx] = unit_grp['xpos_probe'][()] - prb_vert_pos[channel_indx] = unit_grp['ypos_probe'][()] + prb_hrz_pos[channel_indx] = unit_grp["xpos_probe"][()] + prb_vert_pos[channel_indx] = unit_grp["ypos_probe"][()] try: - struct_acronyms[channel_indx] = str( - unit_grp['ccf_structure'][()], encoding='ascii') + struct_acronyms[channel_indx] = str(unit_grp["ccf_structure"][()], encoding="ascii") except TypeError: - struct_acronyms[channel_indx] = \ - unit_grp['ccf_structure'][()] + struct_acronyms[channel_indx] = unit_grp["ccf_structure"][()] existing_channels.add(channel_id) channel_indx += 1 n_channels = len(existing_channels) - channels_df = pd.DataFrame({ - 'id': channel_ids[:n_channels], - 'local_index': local_channel_indices[:n_channels], - 'probe_id': prb_ids[:n_channels], - 'probe_horizontal_position': prb_hrz_pos[:n_channels], - 'probe_vertical_position': prb_vert_pos[:n_channels], - 'ecephys_structure_acronym': struct_acronyms[:n_channels], - 'valid_data': True # TODO: Pull out valid table column from NWB - }) - channels_df.set_index('id', inplace=True) + channels_df = pd.DataFrame( + { + "id": channel_ids[:n_channels], + "local_index": local_channel_indices[:n_channels], + "probe_id": prb_ids[:n_channels], + "probe_horizontal_position": prb_hrz_pos[:n_channels], + "probe_vertical_position": prb_vert_pos[:n_channels], + "ecephys_structure_acronym": struct_acronyms[:n_channels], + "valid_data": True, # TODO: Pull out valid table column from NWB + } + ) + channels_df.set_index("id", inplace=True) return channels_df def get_mean_waveforms(self) -> Dict[int, np.ndarray]: @@ -262,22 +262,25 @@ def get_mean_waveforms(self) -> Dict[int, np.ndarray]: for prb_name, prb_grp in self._probe_groups(): # There is one waveform for any given spike, but still calling # it "mean" waveform - for indx, uid in enumerate(prb_grp['unit_list']): - unit_grp = prb_grp['UnitTimes'][str(uid)] + for indx, uid in enumerate(prb_grp["unit_list"]): + unit_grp = prb_grp["UnitTimes"][str(uid)] unit_id = self._unit_ids[(prb_name, uid)] # EcephysSession is expecting an array of waveforms - waveforms[unit_id] = \ - np.array([unit_grp['waveform'][()], ]) + waveforms[unit_id] = np.array( + [ + unit_grp["waveform"][()], + ] + ) return waveforms def get_spike_times(self) -> Dict[int, np.ndarray]: spike_times = {} for prb_name, prb_grp in self._probe_groups(): - for indx, uid in enumerate(prb_grp['unit_list']): - unit_grp = prb_grp['UnitTimes'][str(uid)] + for indx, uid in enumerate(prb_grp["unit_list"]): + unit_grp = prb_grp["UnitTimes"][str(uid)] unit_id = self._unit_ids[(prb_name, uid)] - spike_times[unit_id] = unit_grp['times'][()] + spike_times[unit_id] = unit_grp["times"][()] return spike_times @@ -292,33 +295,34 @@ def get_units(self) -> pd.DataFrame: # visit every /processing/probeN/UnitList/N/ group to build # TODO: Since just visting the tree is so expensive, maybe build # the channels and probes at the same time. - unit_list = prb_grp['unit_list'][()] + unit_list = prb_grp["unit_list"][()] prb_uids = np.zeros(len(unit_list), dtype=np.uint64) prb_channels = np.zeros(len(unit_list), dtype=np.int64) prb_snr = np.zeros(len(unit_list), dtype=np.float64) for indx, uid in enumerate(unit_list): - unit_grp = prb_grp['UnitTimes'][str(uid)] + unit_grp = prb_grp["UnitTimes"][str(uid)] prb_uids[indx] = self._unit_ids[(prb_name, uid)] - prb_channels[indx] = self._channel_ids[ - (prb_name, unit_grp['channel'][()])] - prb_snr[indx] = unit_grp['snr'][()] + prb_channels[indx] = self._channel_ids[(prb_name, unit_grp["channel"][()])] + prb_snr[indx] = unit_grp["snr"][()] unit_ids = np.append(unit_ids, prb_uids) local_indices = np.append(local_indices, unit_list) peak_channel_ids = np.append(peak_channel_ids, prb_channels) snrs = np.append(snrs, prb_snr) - units_df = pd.DataFrame({ - 'unit_id': pd.Series(unit_ids, dtype=np.int64), - 'local_index': local_indices, - 'peak_channel_id': peak_channel_ids, - 'snr': snrs, - 'quality': "good" - # TODO: NWB 1.0 is missing quality table, need to find an - # equivalent - }) - - units_df.set_index('unit_id', inplace=True) + units_df = pd.DataFrame( + { + "unit_id": pd.Series(unit_ids, dtype=np.int64), + "local_index": local_indices, + "peak_channel_id": peak_channel_ids, + "snr": snrs, + "quality": "good", + # TODO: NWB 1.0 is missing quality table, need to find an + # equivalent + } + ) + + units_df.set_index("unit_id", inplace=True) return units_df def get_invalid_times(self) -> pd.DataFrame: diff --git a/allensdk/brain_observatory/ecephys/ecephys_session_api/ecephys_nwb_session_api.py b/allensdk/brain_observatory/ecephys/ecephys_session_api/ecephys_nwb_session_api.py index de79a9bcce..7d1d04ce8c 100644 --- a/allensdk/brain_observatory/ecephys/ecephys_session_api/ecephys_nwb_session_api.py +++ b/allensdk/brain_observatory/ecephys/ecephys_session_api/ecephys_nwb_session_api.py @@ -25,24 +25,19 @@ class EcephysNwbSessionApi(NwbApi, EcephysSessionApi): - - def __init__(self, - path, - probe_lfp_paths: Optional[ - Dict[int, Callable[[], pynwb.NWBFile]]] = None, - additional_unit_metrics=None, - external_channel_columns=None, - **kwargs): - - self.filter_out_of_brain_units = kwargs.pop( - "filter_out_of_brain_units", True) + def __init__( + self, + path, + probe_lfp_paths: Optional[Dict[int, Callable[[], pynwb.NWBFile]]] = None, + additional_unit_metrics=None, + external_channel_columns=None, + **kwargs, + ): + self.filter_out_of_brain_units = kwargs.pop("filter_out_of_brain_units", True) self.filter_by_validity = kwargs.pop("filter_by_validity", True) - self.amplitude_cutoff_maximum = get_unit_filter_value( - "amplitude_cutoff_maximum", **kwargs) - self.presence_ratio_minimum = get_unit_filter_value( - "presence_ratio_minimum", **kwargs) - self.isi_violations_maximum = get_unit_filter_value( - "isi_violations_maximum", **kwargs) + self.amplitude_cutoff_maximum = get_unit_filter_value("amplitude_cutoff_maximum", **kwargs) + self.presence_ratio_minimum = get_unit_filter_value("presence_ratio_minimum", **kwargs) + self.isi_violations_maximum = get_unit_filter_value("isi_violations_maximum", **kwargs) super(EcephysNwbSessionApi, self).__init__(path, **kwargs) self.probe_lfp_paths = probe_lfp_paths @@ -60,10 +55,12 @@ def __init__(self, f"was created by a previous (and incompatible) version of " f"AllenSDK and pynwb. You will need to either 1) use " f"AllenSDK version < 2.0.0 or 2) re-download an updated " - f"version of the nwbfile to access the desired data.")) + f"version of the nwbfile to access the desired data." + ), + ) def test(self): - """ A minimal test to make sure that this API's NWB file exists and is + """A minimal test to make sure that this API's NWB file exists and is readable. Ecephys NWB files use the required session identifier field to store the session id, so this is guaranteed to be present for any uncorrupted NWB file. @@ -76,28 +73,25 @@ def get_session_start_time(self): return self.nwbfile.session_start_time def get_stimulus_presentations(self): - table = Presentations.from_nwb(nwbfile=self.nwbfile, - add_is_change=False) + table = Presentations.from_nwb(nwbfile=self.nwbfile, add_is_change=False) table = table.value if "color" in table.columns: # .loc breaks on nan values so fill with empty string # This is backwards compatible change for older nwb files # Newer ones encode nan value here with empty string - table['color'] = table['color'].fillna('') + table["color"] = table["color"].fillna("") # the color column actually contains two parameters. One is # coded as rgb triplets and the other as -1 or 1 if "color_triplet" not in table.columns: table["color_triplet"] = pd.Series("", index=table.index) rgb_color_match = table["color"].str.match(color_triplet_re) - table.loc[rgb_color_match, "color_triplet"] = table.loc[ - rgb_color_match, "color"] + table.loc[rgb_color_match, "color_triplet"] = table.loc[rgb_color_match, "color"] table.loc[rgb_color_match, "color"] = "" # make sure the color column's values are numeric - table.loc[table["color"] != "", "color"] = table.loc[ - table["color"] != "", "color"].apply(ast.literal_eval) + table.loc[table["color"] != "", "color"] = table.loc[table["color"] != "", "color"].apply(ast.literal_eval) return table @@ -109,32 +103,33 @@ def _probe_nwbfile(self, probe_id: int): "this object was not configured with probe_lfp_paths" ) elif probe_id not in self.probe_lfp_paths: - raise KeyError( - f"no probe lfp file path is recorded for probe {probe_id}") + raise KeyError(f"no probe lfp file path is recorded for probe {probe_id}") return self.probe_lfp_paths[probe_id]() def get_probes(self) -> pd.DataFrame: probes: Union[List, pd.DataFrame] = [] for k, v in self.nwbfile.electrode_groups.items(): - probes.append({ - 'id': v.probe_id, - 'name': v.name, - 'location': v.location, - "sampling_rate": v.device.sampling_rate, - "lfp_sampling_rate": v.lfp_sampling_rate, - "has_lfp_data": v.has_lfp_data - }) + probes.append( + { + "id": v.probe_id, + "name": v.name, + "location": v.location, + "sampling_rate": v.device.sampling_rate, + "lfp_sampling_rate": v.lfp_sampling_rate, + "has_lfp_data": v.has_lfp_data, + } + ) probes = pd.DataFrame(probes) - probes = probes.set_index(keys='id', drop=True) + probes = probes.set_index(keys="id", drop=True) probes = probes.rename(columns={"name": "description"}) return probes def get_channels(self) -> pd.DataFrame: channels = Channels.from_nwb(nwbfile=self.nwbfile) channels = channels.to_dataframe( - external_channel_columns=self.external_channel_columns, - filter_by_validity=self.filter_by_validity) + external_channel_columns=self.external_channel_columns, filter_by_validity=self.filter_by_validity + ) return channels @@ -157,13 +152,13 @@ def get_units(self) -> pd.DataFrame: filter_out_of_brain_units=self.filter_out_of_brain_units, amplitude_cutoff_maximum=self.amplitude_cutoff_maximum, presence_ratio_minimum=self.presence_ratio_minimum, - isi_violations_maximum=self.isi_violations_maximum + isi_violations_maximum=self.isi_violations_maximum, ) def get_lfp(self, probe_id: int) -> xr.DataArray: lfp_file = self._probe_nwbfile(probe_id) - lfp = lfp_file.get_acquisition(f'probe_{probe_id}_lfp') - series = lfp.get_electrical_series(f'probe_{probe_id}_lfp_data') + lfp = lfp_file.get_acquisition(f"probe_{probe_id}_lfp") + series = lfp.get_electrical_series(f"probe_{probe_id}_lfp_data") electrodes = lfp_file.electrodes.to_dataframe() @@ -171,10 +166,7 @@ def get_lfp(self, probe_id: int) -> xr.DataArray: timestamps = series.timestamps[:] return xr.DataArray( - name="LFP", - data=data, - dims=['time', 'channel'], - coords=[timestamps, electrodes.index.values] + name="LFP", data=data, dims=["time", "channel"], coords=[timestamps, electrodes.index.values] ) def get_running_speed(self, include_rotation=False) -> pd.DataFrame: @@ -185,11 +177,13 @@ def get_running_speed(self, include_rotation=False) -> pd.DataFrame: running_speed_end_series = running_module["running_speed_end_times"] running_speed_end_times = running_speed_end_series.timestamps[:] - running = pd.DataFrame({ - "start_time": running_speed_start_times, - "end_time": running_speed_end_times, - "velocity": running_speed_series.data[:] - }) + running = pd.DataFrame( + { + "start_time": running_speed_start_times, + "end_time": running_speed_end_times, + "velocity": running_speed_series.data[:], + } + ) if include_rotation: rotation_series = running_module["running_wheel_rotation"] @@ -198,90 +192,72 @@ def get_running_speed(self, include_rotation=False) -> pd.DataFrame: return running def get_raw_running_data(self): - rotation_series = self.nwbfile.get_acquisition( - "raw_running_wheel_rotation") - signal_voltage_series = self.nwbfile.get_acquisition( - "running_wheel_signal_voltage") - supply_voltage_series = self.nwbfile.get_acquisition( - "running_wheel_supply_voltage") - - return pd.DataFrame({ - "frame_time": rotation_series.timestamps[:], - "net_rotation": rotation_series.data[:], - "signal_voltage": signal_voltage_series.data[:], - "supply_voltage": supply_voltage_series.data[:] - }) + rotation_series = self.nwbfile.get_acquisition("raw_running_wheel_rotation") + signal_voltage_series = self.nwbfile.get_acquisition("running_wheel_signal_voltage") + supply_voltage_series = self.nwbfile.get_acquisition("running_wheel_supply_voltage") + + return pd.DataFrame( + { + "frame_time": rotation_series.timestamps[:], + "net_rotation": rotation_series.data[:], + "signal_voltage": signal_voltage_series.data[:], + "supply_voltage": supply_voltage_series.data[:], + } + ) def get_rig_metadata(self) -> Optional[dict]: try: - et_mod = self.nwbfile.get_processing_module( - "eye_tracking_rig_metadata") + et_mod = self.nwbfile.get_processing_module("eye_tracking_rig_metadata") except KeyError as e: print( f"This ecephys session '{int(self.nwbfile.identifier)}' has " - f"no eye tracking rig metadata. (NWB error: {e})") + f"no eye tracking rig metadata. (NWB error: {e})" + ) return None meta = et_mod.get_data_interface("eye_tracking_rig_metadata") - rig_geometry = pd.DataFrame({ - f"monitor_position_{meta.monitor_position__unit}": - meta.monitor_position, - f"camera_position_{meta.camera_position__unit}": - meta.camera_position, - f"led_position_{meta.led_position__unit}": meta.led_position, - f"monitor_rotation_{meta.monitor_rotation__unit}": - meta.monitor_rotation, - f"camera_rotation_{meta.camera_rotation__unit}": - meta.camera_rotation - }) - - rig_geometry = rig_geometry.rename(index={0: 'x', 1: 'y', 2: 'z'}) - - returned_metadata = { - "geometry": rig_geometry, - "equipment": meta.equipment - } + rig_geometry = pd.DataFrame( + { + f"monitor_position_{meta.monitor_position__unit}": meta.monitor_position, + f"camera_position_{meta.camera_position__unit}": meta.camera_position, + f"led_position_{meta.led_position__unit}": meta.led_position, + f"monitor_rotation_{meta.monitor_rotation__unit}": meta.monitor_rotation, + f"camera_rotation_{meta.camera_rotation__unit}": meta.camera_rotation, + } + ) + + rig_geometry = rig_geometry.rename(index={0: "x", 1: "y", 2: "z"}) + + returned_metadata = {"geometry": rig_geometry, "equipment": meta.equipment} return returned_metadata - def get_screen_gaze_data(self, include_filtered_data=False) -> \ - Optional[pd.DataFrame]: + def get_screen_gaze_data(self, include_filtered_data=False) -> Optional[pd.DataFrame]: try: rgm_mod = self.nwbfile.get_processing_module("raw_gaze_mapping") - fgm_mod = self.nwbfile.get_processing_module( - "filtered_gaze_mapping") + fgm_mod = self.nwbfile.get_processing_module("filtered_gaze_mapping") except KeyError as e: - print( - f"This ecephys session '{int(self.nwbfile.identifier)}' has " - f"no eye tracking data. (NWB error: {e})") + print(f"This ecephys session '{int(self.nwbfile.identifier)}' has no eye tracking data. (NWB error: {e})") return None raw_eye_area_ts = rgm_mod.get_data_interface("eye_area") raw_pupil_area_ts = rgm_mod.get_data_interface("pupil_area") - raw_screen_coordinates_ts = rgm_mod.get_data_interface( - "screen_coordinates") - raw_screen_coordinates_spherical_ts = rgm_mod.get_data_interface( - "screen_coordinates_spherical") + raw_screen_coordinates_ts = rgm_mod.get_data_interface("screen_coordinates") + raw_screen_coordinates_spherical_ts = rgm_mod.get_data_interface("screen_coordinates_spherical") filtered_eye_area_ts = fgm_mod.get_data_interface("eye_area") filtered_pupil_area_ts = fgm_mod.get_data_interface("pupil_area") - filtered_screen_coordinates_ts = fgm_mod.get_data_interface( - "screen_coordinates") - filtered_screen_coordinates_spherical_ts = fgm_mod.get_data_interface( - "screen_coordinates_spherical") + filtered_screen_coordinates_ts = fgm_mod.get_data_interface("screen_coordinates") + filtered_screen_coordinates_spherical_ts = fgm_mod.get_data_interface("screen_coordinates_spherical") gaze_data = { "raw_eye_area": raw_eye_area_ts.data[:], "raw_pupil_area": raw_pupil_area_ts.data[:], - "raw_screen_coordinates_x_cm": - raw_screen_coordinates_ts.data[:, 1], - "raw_screen_coordinates_y_cm": - raw_screen_coordinates_ts.data[:, 0], - "raw_screen_coordinates_spherical_x_deg": - raw_screen_coordinates_spherical_ts.data[:, 1], - "raw_screen_coordinates_spherical_y_deg": - raw_screen_coordinates_spherical_ts.data[:, 0] + "raw_screen_coordinates_x_cm": raw_screen_coordinates_ts.data[:, 1], + "raw_screen_coordinates_y_cm": raw_screen_coordinates_ts.data[:, 0], + "raw_screen_coordinates_spherical_x_deg": raw_screen_coordinates_spherical_ts.data[:, 1], + "raw_screen_coordinates_spherical_y_deg": raw_screen_coordinates_spherical_ts.data[:, 0], } if include_filtered_data: @@ -289,18 +265,10 @@ def get_screen_gaze_data(self, include_filtered_data=False) -> \ { "filtered_eye_area": filtered_eye_area_ts.data[:], "filtered_pupil_area": filtered_pupil_area_ts.data[:], - "filtered_screen_coordinates_x_cm": - filtered_screen_coordinates_ts.data[ - :, 1], - "filtered_screen_coordinates_y_cm": - filtered_screen_coordinates_ts.data[ - :, 0], - "filtered_screen_coordinates_spherical_x_deg": - filtered_screen_coordinates_spherical_ts.data[ - :, 1], - "filtered_screen_coordinates_spherical_y_deg": - filtered_screen_coordinates_spherical_ts.data[ - :, 0] + "filtered_screen_coordinates_x_cm": filtered_screen_coordinates_ts.data[:, 1], + "filtered_screen_coordinates_y_cm": filtered_screen_coordinates_ts.data[:, 0], + "filtered_screen_coordinates_spherical_x_deg": filtered_screen_coordinates_spherical_ts.data[:, 1], + "filtered_screen_coordinates_spherical_y_deg": filtered_screen_coordinates_spherical_ts.data[:, 0], } ) @@ -312,17 +280,12 @@ def get_pupil_data(self) -> Optional[pd.DataFrame]: et_mod = self.nwbfile.get_processing_module("eye_tracking") rgm_mod = self.nwbfile.get_processing_module("raw_gaze_mapping") except KeyError as e: - print( - f"This ecephys session '{int(self.nwbfile.identifier)}' has " - f"no eye tracking data. (NWB error: {e})") + print(f"This ecephys session '{int(self.nwbfile.identifier)}' has no eye tracking data. (NWB error: {e})") return None - cr_ellipse_fits = et_mod.get_data_interface( - "cr_ellipse_fits").to_dataframe() - eye_ellipse_fits = et_mod.get_data_interface( - "eye_ellipse_fits").to_dataframe() - pupil_ellipse_fits = et_mod.get_data_interface( - "pupil_ellipse_fits").to_dataframe() + cr_ellipse_fits = et_mod.get_data_interface("cr_ellipse_fits").to_dataframe() + eye_ellipse_fits = et_mod.get_data_interface("eye_ellipse_fits").to_dataframe() + pupil_ellipse_fits = et_mod.get_data_interface("pupil_ellipse_fits").to_dataframe() # NOTE: ellipse fit "height" and "width" parameters describe the # "half-height" and "half-width" of fitted ellipse. @@ -332,18 +295,16 @@ def get_pupil_data(self) -> Optional[pd.DataFrame]: "corneal_reflection_height": 2 * cr_ellipse_fits["height"].values, "corneal_reflection_width": 2 * cr_ellipse_fits["width"].values, "corneal_reflection_phi": cr_ellipse_fits["phi"].values, - "pupil_center_x": pupil_ellipse_fits["center_x"].values, "pupil_center_y": pupil_ellipse_fits["center_y"].values, "pupil_height": 2 * pupil_ellipse_fits["height"].values, "pupil_width": 2 * pupil_ellipse_fits["width"].values, "pupil_phi": pupil_ellipse_fits["phi"].values, - "eye_center_x": eye_ellipse_fits["center_x"].values, "eye_center_y": eye_ellipse_fits["center_y"].values, "eye_height": 2 * eye_ellipse_fits["height"].values, "eye_width": 2 * eye_ellipse_fits["width"].values, - "eye_phi": eye_ellipse_fits["phi"].values + "eye_phi": eye_ellipse_fits["phi"].values, } timestamps = rgm_mod.get_data_interface("eye_area").timestamps[:] @@ -354,11 +315,9 @@ def get_ecephys_session_id(self) -> int: return int(self.nwbfile.identifier) def get_current_source_density(self, probe_id): - csd_mod = self._probe_nwbfile(probe_id).get_processing_module( - "current_source_density") + csd_mod = self._probe_nwbfile(probe_id).get_processing_module("current_source_density") nwb_csd = csd_mod["ecephys_csd"] - csd_data = nwb_csd.time_series.data[ - :].T # csd data stored as (timepoints x channels) but we + csd_data = nwb_csd.time_series.data[:].T # csd data stored as (timepoints x channels) but we # want (channels x timepoints) csd = xr.DataArray( @@ -368,11 +327,9 @@ def get_current_source_density(self, probe_id): coords={ "virtual_channel_index": np.arange(csd_data.shape[0]), "time": nwb_csd.time_series.timestamps[:], - "vertical_position": (("virtual_channel_index",), - nwb_csd.virtual_electrode_y_positions), - "horizontal_position": (("virtual_channel_index",), - nwb_csd.virtual_electrode_x_positions) - } + "vertical_position": (("virtual_channel_index",), nwb_csd.virtual_electrode_y_positions), + "horizontal_position": (("virtual_channel_index",), nwb_csd.virtual_electrode_x_positions), + }, ) return csd @@ -391,6 +348,6 @@ def get_metadata(self): "stimulus_name": self.nwbfile.stimulus_notes, "subject_id": nwb_subject.subject_id, "age": nwb_subject.age, - "species": nwb_subject.species + "species": nwb_subject.species, } return metadata diff --git a/allensdk/brain_observatory/ecephys/ecephys_session_api/ecephys_session_api.py b/allensdk/brain_observatory/ecephys/ecephys_session_api/ecephys_session_api.py index 5334aa00f0..cf79fea7e1 100644 --- a/allensdk/brain_observatory/ecephys/ecephys_session_api/ecephys_session_api.py +++ b/allensdk/brain_observatory/ecephys/ecephys_session_api/ecephys_session_api.py @@ -9,7 +9,6 @@ class EcephysSessionApi: - session_na = -1 __slots__: tuple = tuple([]) diff --git a/allensdk/brain_observatory/ecephys/file_io/continuous_file.py b/allensdk/brain_observatory/ecephys/file_io/continuous_file.py index fa950f3b8f..cc23948159 100644 --- a/allensdk/brain_observatory/ecephys/file_io/continuous_file.py +++ b/allensdk/brain_observatory/ecephys/file_io/continuous_file.py @@ -38,15 +38,13 @@ import logging -class ContinuousFile(): - +class ContinuousFile: """ Represents a continuous (.dat) file, and its associated timestamps """ def __init__(self, data_path, timestamps_path, total_num_channels=384, dtype=np.int16): - """ data_path : str Path to file containing LFP data. The file is expected to be a raw binary with channels as its fast axis and samples as its slow axis. @@ -63,9 +61,7 @@ def __init__(self, data_path, timestamps_path, total_num_channels=384, dtype=np. self.total_num_channels = total_num_channels self.dtype = dtype - - def load(self, memmap=False, memmap_thresh = 10e9): - + def load(self, memmap=False, memmap_thresh=10e9): """ Reads lfp data and timestamps from the filesystem @@ -73,7 +69,7 @@ def load(self, memmap=False, memmap_thresh = 10e9): ---------- memmap : bool, optional - If True, the returned data array will be a memory map of the file on disk. Default is True. + If True, the returned data array will be a memory map of the file on disk. Default is True. memmap_thresh : float, optional Files above this size in bytes will be memory-mapped, regardless of memmap setting @@ -86,38 +82,40 @@ def load(self, memmap=False, memmap_thresh = 10e9): """ - logging.info('loading timestamps from {}'.format(self.timestamps_path)) + logging.info("loading timestamps from {}".format(self.timestamps_path)) timestamps = np.load(self.timestamps_path, allow_pickle=False) - logging.info('done loading timestamps from {}. Count: {}'.format(self.timestamps_path, timestamps.size)) + logging.info("done loading timestamps from {}. Count: {}".format(self.timestamps_path, timestamps.size)) bytes_per_sample = self.dtype(0).nbytes - num_samples = timestamps.size * self.total_num_channels + num_samples = timestamps.size * self.total_num_channels expected_num_bytes = num_samples * bytes_per_sample - logging.info('calculated LFP filesize: {} bytes'.format(expected_num_bytes)) + logging.info("calculated LFP filesize: {} bytes".format(expected_num_bytes)) num_bytes = Path(self.data_path).stat().st_size if not expected_num_bytes == num_bytes: - raise IOError('expected LFP data filesize to be {} bytes, but its size was {} bytes'.format(expected_num_bytes, num_bytes)) + raise IOError( + "expected LFP data filesize to be {} bytes, but its size was {} bytes".format( + expected_num_bytes, num_bytes + ) + ) shape = (timestamps.size, self.total_num_channels) - logging.info('calculated LFP data shape: {}'.format(shape)) + logging.info("calculated LFP data shape: {}".format(shape)) if memmap or num_bytes > memmap_thresh: - logging.info('memmaping LFP file at {}'.format(self.data_path)) - lfp_raw = np.memmap(self.data_path, dtype=self.dtype, shape=shape, mode='r') - logging.info('done memmaping LFP file at {}'.format(self.data_path)) + logging.info("memmaping LFP file at {}".format(self.data_path)) + lfp_raw = np.memmap(self.data_path, dtype=self.dtype, shape=shape, mode="r") + logging.info("done memmaping LFP file at {}".format(self.data_path)) else: - with open(self.data_path, 'rb') as data_file: - logging.info('reading LFP file at {}'.format(self.data_path)) + with open(self.data_path, "rb") as data_file: + logging.info("reading LFP file at {}".format(self.data_path)) lfp_raw = np.frombuffer(data_file.read(), dtype=self.dtype) - logging.info('done reading LFP file at {}'.format(self.data_path)) + logging.info("done reading LFP file at {}".format(self.data_path)) lfp_raw = lfp_raw.reshape(shape) return lfp_raw, timestamps - def get_lfp_channel_order(self): - """ Returns the channel ordering for LFP data extracted from NPX files. @@ -131,10 +129,59 @@ def get_lfp_channel_order(self): Contains the actual channel ordering. """ - remapping_pattern = np.array([0, 12, 1, 13, 2, 14, 3, 15, 4, 16, 5, 17, 6, 18, 7, 19, - 8, 20, 9, 21, 10, 22, 11, 23, 24, 36, 25, 37, 26, 38, - 27, 39, 28, 40, 29, 41, 30, 42, 31, 43, 32, 44, 33, 45, 34, 46, 35, 47]) - - channel_order = np.concatenate([remapping_pattern + 48*i for i in range(0,8)]) + remapping_pattern = np.array( + [ + 0, + 12, + 1, + 13, + 2, + 14, + 3, + 15, + 4, + 16, + 5, + 17, + 6, + 18, + 7, + 19, + 8, + 20, + 9, + 21, + 10, + 22, + 11, + 23, + 24, + 36, + 25, + 37, + 26, + 38, + 27, + 39, + 28, + 40, + 29, + 41, + 30, + 42, + 31, + 43, + 32, + 44, + 33, + 45, + 34, + 46, + 35, + 47, + ] + ) + + channel_order = np.concatenate([remapping_pattern + 48 * i for i in range(0, 8)]) return channel_order diff --git a/allensdk/brain_observatory/ecephys/file_io/ecephys_sync_dataset.py b/allensdk/brain_observatory/ecephys/file_io/ecephys_sync_dataset.py index 9f182b2859..fed6bdba29 100644 --- a/allensdk/brain_observatory/ecephys/file_io/ecephys_sync_dataset.py +++ b/allensdk/brain_observatory/ecephys/file_io/ecephys_sync_dataset.py @@ -10,19 +10,18 @@ class EcephysSyncDataset(Dataset): - @property def sample_frequency(self): - return self.meta_data['ni_daq']['counter_output_freq'] + return self.meta_data["ni_daq"]["counter_output_freq"] @sample_frequency.setter def sample_frequency(self, value): - if not hasattr(self, 'meta_data'): + if not hasattr(self, "meta_data"): self.meta_data = defaultdict(dict) - self.meta_data['ni_daq']['counter_output_freq'] = value + self.meta_data["ni_daq"]["counter_output_freq"] = value def __init__(self): - '''In-memory representation of a sync h5 file as produced by the sync package. + """In-memory representation of a sync h5 file as produced by the sync package. Notes ----- @@ -32,28 +31,19 @@ def __init__(self): object. To make a new SyncDataset in client code, use the factory classmethod. This is done for ease of testability. - ''' + """ pass - def extract_led_times(self, - keys=Dataset.OPTOGENETIC_STIMULATION_KEYS, - fallback_line=18): - + def extract_led_times(self, keys=Dataset.OPTOGENETIC_STIMULATION_KEYS, fallback_line=18): try: - led_times = self.get_edges( - kind="rising", - keys=keys, - units="seconds" - ) + led_times = self.get_edges(kind="rising", keys=keys, units="seconds") except KeyError: - warnings.warn("unable to find LED times using line labels" + - f"{keys}, returning line {fallback_line}") + warnings.warn("unable to find LED times using line labels" + f"{keys}, returning line {fallback_line}") led_times = self.get_rising_edges(fallback_line, units="seconds") return led_times def remove_zero_frames(self, frame_times): - D = np.diff(frame_times) a = np.where(D < 0.01)[0] @@ -72,65 +62,52 @@ def find_match(b, value): for idx, d in enumerate(a): if c[idx] is not None: if c[idx] > -100: - ft[d+c[idx]] = np.median(D) + ft[d + c[idx]] = np.median(D) ft[d] = np.median(D) - t = np.concatenate(([np.min(frame_times)], - np.cumsum(ft) + np.min(frame_times))) + t = np.concatenate(([np.min(frame_times)], np.cumsum(ft) + np.min(frame_times))) return t def extract_frame_times_from_photodiode( - self, - photodiode_cycle=60, - frame_keys=Dataset.FRAME_KEYS, - photodiode_keys=Dataset.PHOTODIODE_KEYS, - trim_discontiguous_frame_times=True): - - photodiode_times = self.get_edges('all', photodiode_keys) - vsync_times = self.get_edges('falling', frame_keys) + self, + photodiode_cycle=60, + frame_keys=Dataset.FRAME_KEYS, + photodiode_keys=Dataset.PHOTODIODE_KEYS, + trim_discontiguous_frame_times=True, + ): + photodiode_times = self.get_edges("all", photodiode_keys) + vsync_times = self.get_edges("falling", frame_keys) if trim_discontiguous_frame_times: vsync_times = stimulus_sync.trim_discontiguous_vsyncs(vsync_times) - vsync_times_chunked, pd_times_chunked = \ - stimulus_sync.separate_vsyncs_and_photodiode_times( - vsync_times, - photodiode_times, - photodiode_cycle) + vsync_times_chunked, pd_times_chunked = stimulus_sync.separate_vsyncs_and_photodiode_times( + vsync_times, photodiode_times, photodiode_cycle + ) logging.info(f"Total chunks: {len(vsync_times_chunked)}") frame_start_times = np.zeros((0,)) for i in range(len(vsync_times_chunked)): + photodiode_times = stimulus_sync.trim_border_pulses(pd_times_chunked[i], vsync_times_chunked[i]) + photodiode_times = stimulus_sync.correct_on_off_effects(photodiode_times) + photodiode_times = stimulus_sync.fix_unexpected_edges(photodiode_times, cycle=photodiode_cycle) - photodiode_times = stimulus_sync.trim_border_pulses( - pd_times_chunked[i], - vsync_times_chunked[i]) - photodiode_times = stimulus_sync.correct_on_off_effects( - photodiode_times) - photodiode_times = stimulus_sync.fix_unexpected_edges( + frame_duration = stimulus_sync.estimate_frame_duration(photodiode_times, cycle=photodiode_cycle) + irregular_interval_policy = functools.partial( + stimulus_sync.allocate_by_vsync, np.diff(vsync_times_chunked[i]) + ) + frame_indices, frame_starts, frame_end_times = stimulus_sync.compute_frame_times( photodiode_times, - cycle=photodiode_cycle) + frame_duration, + len(vsync_times_chunked[i]), + cycle=photodiode_cycle, + irregular_interval_policy=irregular_interval_policy, + ) - frame_duration = stimulus_sync.estimate_frame_duration( - photodiode_times, - cycle=photodiode_cycle) - irregular_interval_policy = functools.partial( - stimulus_sync.allocate_by_vsync, - np.diff(vsync_times_chunked[i])) - frame_indices, frame_starts, frame_end_times = \ - stimulus_sync.compute_frame_times( - photodiode_times, - frame_duration, - len(vsync_times_chunked[i]), - cycle=photodiode_cycle, - irregular_interval_policy=irregular_interval_policy - ) - - frame_start_times = np.concatenate((frame_start_times, - frame_starts)) + frame_start_times = np.concatenate((frame_start_times, frame_starts)) frame_start_times = self.remove_zero_frames(frame_start_times) @@ -139,40 +116,35 @@ def extract_frame_times_from_photodiode( return frame_start_times def extract_frame_times_from_vsyncs( - self, - photodiode_cycle=60, - frame_keys=Dataset.FRAME_KEYS, photodiode_keys=Dataset.PHOTODIODE_KEYS + self, photodiode_cycle=60, frame_keys=Dataset.FRAME_KEYS, photodiode_keys=Dataset.PHOTODIODE_KEYS ): raise NotImplementedError() def extract_frame_times( - self, - strategy, - photodiode_cycle=60, - frame_keys=Dataset.FRAME_KEYS, - photodiode_keys=Dataset.PHOTODIODE_KEYS, - trim_discontiguous_frame_times=True - ): - - if strategy == 'use_photodiode': + self, + strategy, + photodiode_cycle=60, + frame_keys=Dataset.FRAME_KEYS, + photodiode_keys=Dataset.PHOTODIODE_KEYS, + trim_discontiguous_frame_times=True, + ): + if strategy == "use_photodiode": return self.extract_frame_times_from_photodiode( photodiode_cycle=photodiode_cycle, frame_keys=frame_keys, photodiode_keys=photodiode_keys, - trim_discontiguous_frame_times=trim_discontiguous_frame_times - ) - elif strategy == 'use_vsyncs': + trim_discontiguous_frame_times=trim_discontiguous_frame_times, + ) + elif strategy == "use_vsyncs": return self.extract_frame_times_from_vsyncs( - photodiode_cycle=photodiode_cycle, - frame_keys=frame_keys, - photodiode_keys=photodiode_keys - ) + photodiode_cycle=photodiode_cycle, frame_keys=frame_keys, photodiode_keys=photodiode_keys + ) else: - raise ValueError('unrecognized strategy: {}'.format(strategy)) + raise ValueError("unrecognized strategy: {}".format(strategy)) @classmethod def factory(cls, path): - ''' Build a new SyncDataset. + """Build a new SyncDataset. Parameters ---------- @@ -180,9 +152,11 @@ def factory(cls, path): Filesystem path to the h5 file containing sync information to be loaded. - ''' + """ obj = cls() obj.load(path) return obj + + # diff --git a/allensdk/brain_observatory/ecephys/file_io/stim_file.py b/allensdk/brain_observatory/ecephys/file_io/stim_file.py index a9298169ba..15dad875c7 100644 --- a/allensdk/brain_observatory/ecephys/file_io/stim_file.py +++ b/allensdk/brain_observatory/ecephys/file_io/stim_file.py @@ -3,71 +3,57 @@ class CamStimOnePickleStimFile(object): - - @property def stimuli(self): - '''List of dictionaries containing information about individual stimuli - ''' - return self.data['stimuli'] - + """List of dictionaries containing information about individual stimuli""" + return self.data["stimuli"] @property def frames_per_second(self): - '''Framerate of stimulus presentation - ''' - return self.data['fps'] - + """Framerate of stimulus presentation""" + return self.data["fps"] @property def pre_blank_sec(self): - '''Time (s) before initial stimulus presentation - ''' - return self.data['pre_blank_sec'] - + """Time (s) before initial stimulus presentation""" + return self.data["pre_blank_sec"] @property def angular_wheel_velocity(self): - ''' Extract the mean angular velocity of the running wheel (degrees / s) for each + """Extract the mean angular velocity of the running wheel (degrees / s) for each frame. - ''' + """ return self.frames_per_second * self.angular_wheel_rotation - @property def angular_wheel_rotation(self): - ''' Extract the total rotation of the running wheel on each frame. - ''' + """Extract the total rotation of the running wheel on each frame.""" return self._extract_running_array("dx") - @property def vsig(self): - """Running speed signal voltage - """ + """Running speed signal voltage""" return self._extract_running_array("vsig") @property def vin(self): return self._extract_running_array("vin") - def __init__(self, data, **kwargs): self.data = data - def _extract_running_array(self, key): try: - result = self.data['items']['foraging']['encoders'][0][key] + result = self.data["items"]["foraging"]["encoders"][0][key] except (KeyError, IndexError): try: result = self.data[key] except KeyError: - raise KeyError(f'unable to extract {key} from this stimulus pickle') - + raise KeyError(f"unable to extract {key} from this stimulus pickle") + return np.array(result) @classmethod def factory(cls, path, **kwargs): data = pd.read_pickle(path) - return cls(data, **kwargs) \ No newline at end of file + return cls(data, **kwargs) diff --git a/allensdk/brain_observatory/ecephys/lfp_subsampling/__main__.py b/allensdk/brain_observatory/ecephys/lfp_subsampling/__main__.py index 7ac8c2e245..9e16e7b3c8 100644 --- a/allensdk/brain_observatory/ecephys/lfp_subsampling/__main__.py +++ b/allensdk/brain_observatory/ecephys/lfp_subsampling/__main__.py @@ -37,15 +37,10 @@ import numpy as np -from allensdk.brain_observatory.argschema_utilities import \ - ArgSchemaParserPlus, \ - write_or_print_outputs -from allensdk.brain_observatory.ecephys.file_io.continuous_file import \ - ContinuousFile +from allensdk.brain_observatory.argschema_utilities import ArgSchemaParserPlus, write_or_print_outputs +from allensdk.brain_observatory.ecephys.file_io.continuous_file import ContinuousFile from ._schemas import InputParameters, OutputParameters -from .subsampling import select_channels, subsample_timestamps, \ - subsample_lfp, \ - remove_lfp_offset, remove_lfp_noise +from .subsampling import select_channels, subsample_timestamps, subsample_lfp, remove_lfp_offset, remove_lfp_noise logger = logging.getLogger(__name__) @@ -56,84 +51,83 @@ def subsample(args): :param args: :return: """ - params = args['lfp_subsampling'] + params = args["lfp_subsampling"] probe_outputs = [] - for probe in args['probes']: - logging.info("Sub-sampling LFP for " + probe['name']) - lfp_data_file = ContinuousFile(probe['lfp_input_file_path'], - probe['lfp_timestamps_input_path'], - probe['total_channels']) + for probe in args["probes"]: + logging.info("Sub-sampling LFP for " + probe["name"]) + lfp_data_file = ContinuousFile( + probe["lfp_input_file_path"], probe["lfp_timestamps_input_path"], probe["total_channels"] + ) logging.info("loading lfp data...") lfp_raw, timestamps = lfp_data_file.load() - if params['reorder_channels']: + if params["reorder_channels"]: lfp_channel_order = lfp_data_file.get_lfp_channel_order() else: - lfp_channel_order = np.arange(0, probe['total_channels']) + lfp_channel_order = np.arange(0, probe["total_channels"]) logging.info("selecting channels...") channels_to_save, actual_channels = select_channels( - probe['total_channels'], - probe['surface_channel'], - params['surface_padding'], - params['start_channel_offset'], - params['channel_stride'], + probe["total_channels"], + probe["surface_channel"], + params["surface_padding"], + params["start_channel_offset"], + params["channel_stride"], lfp_channel_order, - probe.get('noisy_channels', []), - params['remove_noisy_channels'], - probe['reference_channels'], - params['remove_reference_channels']) + probe.get("noisy_channels", []), + params["remove_noisy_channels"], + probe["reference_channels"], + params["remove_reference_channels"], + ) - ts_subsampled = subsample_timestamps(timestamps, params[ - 'temporal_subsampling_factor']) + ts_subsampled = subsample_timestamps(timestamps, params["temporal_subsampling_factor"]) logging.info("subsampling data...") - lfp_subsampled = subsample_lfp(lfp_raw, channels_to_save, - params['temporal_subsampling_factor']) + lfp_subsampled = subsample_lfp(lfp_raw, channels_to_save, params["temporal_subsampling_factor"]) del lfp_raw logging.info("removing offset...") - lfp_filtered = remove_lfp_offset(lfp_subsampled, - probe['lfp_sampling_rate'] / params[ - 'temporal_subsampling_factor'], - params['cutoff_frequency'], - params['filter_order']) + lfp_filtered = remove_lfp_offset( + lfp_subsampled, + probe["lfp_sampling_rate"] / params["temporal_subsampling_factor"], + params["cutoff_frequency"], + params["filter_order"], + ) del lfp_subsampled - logging.info("Surface channel: " + str(probe['surface_channel'])) + logging.info("Surface channel: " + str(probe["surface_channel"])) logging.info("removing noise...") - lfp = remove_lfp_noise(lfp_filtered, probe['surface_channel'], - actual_channels) + lfp = remove_lfp_noise(lfp_filtered, probe["surface_channel"], actual_channels) del lfp_filtered - if params['remove_channels_out_of_brain']: - channels_to_keep = actual_channels < ( - probe['surface_channel'] + 10) + if params["remove_channels_out_of_brain"]: + channels_to_keep = actual_channels < (probe["surface_channel"] + 10) actual_channels = actual_channels[channels_to_keep] lfp = lfp[:, channels_to_keep] - logging.info('Writing to disk...') - lfp.tofile(probe['lfp_data_path']) - np.save(probe['lfp_timestamps_path'], ts_subsampled) - np.save(probe['lfp_channel_info_path'], actual_channels) + logging.info("Writing to disk...") + lfp.tofile(probe["lfp_data_path"]) + np.save(probe["lfp_timestamps_path"], ts_subsampled) + np.save(probe["lfp_channel_info_path"], actual_channels) - probe_outputs.append({'name': probe['name'], - 'lfp_data_path': probe['lfp_data_path'], - 'lfp_timestamps_path': probe[ - 'lfp_timestamps_path'], - 'lfp_channel_info_path': probe[ - 'lfp_channel_info_path']}) + probe_outputs.append( + { + "name": probe["name"], + "lfp_data_path": probe["lfp_data_path"], + "lfp_timestamps_path": probe["lfp_timestamps_path"], + "lfp_channel_info_path": probe["lfp_channel_info_path"], + } + ) - return {'probe_outputs': probe_outputs} + return {"probe_outputs": probe_outputs} def main(): - mod = ArgSchemaParserPlus(schema_type=InputParameters, - output_schema_type=OutputParameters) + mod = ArgSchemaParserPlus(schema_type=InputParameters, output_schema_type=OutputParameters) output = subsample(mod.args) write_or_print_outputs(data=output, parser=mod) diff --git a/allensdk/brain_observatory/ecephys/lfp_subsampling/_schemas.py b/allensdk/brain_observatory/ecephys/lfp_subsampling/_schemas.py index 5cdaba7720..37d64463d2 100644 --- a/allensdk/brain_observatory/ecephys/lfp_subsampling/_schemas.py +++ b/allensdk/brain_observatory/ecephys/lfp_subsampling/_schemas.py @@ -39,107 +39,49 @@ class ProbeInputParameters(DefaultSchema): - name = String(required=True, help='Identifier for this probe') - lfp_input_file_path = String( - required=True, - description="path to original LFP .dat file") - lfp_timestamps_input_path = String( - required=True, - description="path to LFP timestamps") - lfp_data_path = String( - required=True, - help="Path to LFP data continuous file") - lfp_timestamps_path = String( - required=True, - help="Path to LFP timestamps aligned to master clock") - lfp_channel_info_path = String( - required=True, - help="Path to LFP channel info") - total_channels = Int( - default=384, - help='Total channel count for this probe.') - surface_channel = Int( - required=True, - help="Probe surface channel") - reference_channels = NumpyArray( - required=False, - help="Probe reference channels") - lfp_sampling_rate = Float( - required=True, - help="Sampling rate of LFP data") - noisy_channels = NumpyArray( - required=False, - help="Noisy channels to remove") + name = String(required=True, help="Identifier for this probe") + lfp_input_file_path = String(required=True, description="path to original LFP .dat file") + lfp_timestamps_input_path = String(required=True, description="path to LFP timestamps") + lfp_data_path = String(required=True, help="Path to LFP data continuous file") + lfp_timestamps_path = String(required=True, help="Path to LFP timestamps aligned to master clock") + lfp_channel_info_path = String(required=True, help="Path to LFP channel info") + total_channels = Int(default=384, help="Total channel count for this probe.") + surface_channel = Int(required=True, help="Probe surface channel") + reference_channels = NumpyArray(required=False, help="Probe reference channels") + lfp_sampling_rate = Float(required=True, help="Sampling rate of LFP data") + noisy_channels = NumpyArray(required=False, help="Noisy channels to remove") class LfpSubsamplingParameters(DefaultSchema): - temporal_subsampling_factor = Int( - default=2, - description="Ratio of input samples to output samples in time") - channel_stride = Int( - default=4, - description="Distance between channels to keep") - surface_padding = Int( - default=40, - description="Number of channels above surface to include") - start_channel_offset = Int( - default=2, - description="Offset of first channel (from bottom of the probe)") - reorder_channels = Boolean( - default=False, - description="Implement channel reordering") - cutoff_frequency = Float( - default=0.1, - description="Cutoff frequency for DC offset filter (Butterworth)") - filter_order = Int( - default=1, - description="Order of DC offset filter (Butterworth)") - remove_reference_channels = Boolean( - default=False, - description="indicates whether references should be removed") + temporal_subsampling_factor = Int(default=2, description="Ratio of input samples to output samples in time") + channel_stride = Int(default=4, description="Distance between channels to keep") + surface_padding = Int(default=40, description="Number of channels above surface to include") + start_channel_offset = Int(default=2, description="Offset of first channel (from bottom of the probe)") + reorder_channels = Boolean(default=False, description="Implement channel reordering") + cutoff_frequency = Float(default=0.1, description="Cutoff frequency for DC offset filter (Butterworth)") + filter_order = Int(default=1, description="Order of DC offset filter (Butterworth)") + remove_reference_channels = Boolean(default=False, description="indicates whether references should be removed") remove_channels_out_of_brain = Boolean( - default=False, - description="indicates whether to remove channels outside the brain") - remove_noisy_channels = Boolean( - default=False, - description="indicates whether noisy channels should be removed") + default=False, description="indicates whether to remove channels outside the brain" + ) + remove_noisy_channels = Boolean(default=False, description="indicates whether noisy channels should be removed") class InputParameters(ArgSchema): - probes = Nested( - ProbeInputParameters, - many=True, - help='Probes for LFP subsampling') - lfp_subsampling = Nested( - LfpSubsamplingParameters, - help='Parameters for this module') + probes = Nested(ProbeInputParameters, many=True, help="Probes for LFP subsampling") + lfp_subsampling = Nested(LfpSubsamplingParameters, help="Parameters for this module") class OutputSchema(DefaultSchema): - input_parameters = Nested( - InputParameters, - description="Input parameters the module was run with", - required=True) + input_parameters = Nested(InputParameters, description="Input parameters the module was run with", required=True) class ProbeOutputParameters(DefaultSchema): - name = String( - equired=True, - help='Identifier for this probe.') - lfp_data_path = String( - required=True, - help='Output subsampled data file.') - lfp_timestamps_path = String( - required=True, - help='Timestamps for subsampled data.') - lfp_channel_info_path = String( - required=True, - help='LFP channels from that was subsampled.') + name = String(equired=True, help="Identifier for this probe.") + lfp_data_path = String(required=True, help="Output subsampled data file.") + lfp_timestamps_path = String(required=True, help="Timestamps for subsampled data.") + lfp_channel_info_path = String(required=True, help="LFP channels from that was subsampled.") class OutputParameters(OutputSchema): - probe_outputs = Nested( - ProbeOutputParameters, - many=True, - required=True, - help='probewise outputs') + probe_outputs = Nested(ProbeOutputParameters, many=True, required=True, help="probewise outputs") diff --git a/allensdk/brain_observatory/ecephys/lfp_subsampling/subsampling.py b/allensdk/brain_observatory/ecephys/lfp_subsampling/subsampling.py index ca96174a92..4e5e198ef1 100644 --- a/allensdk/brain_observatory/ecephys/lfp_subsampling/subsampling.py +++ b/allensdk/brain_observatory/ecephys/lfp_subsampling/subsampling.py @@ -41,16 +41,18 @@ logger = logging.getLogger(__name__) -def select_channels(total_channels, - surface_channel, - surface_padding, - start_channel_offset, - channel_stride, - channel_order, - noisy_channels=np.array([]), - remove_noisy_channels=False, - reference_channels=np.array([]), - remove_references=False): +def select_channels( + total_channels, + surface_channel, + surface_padding, + start_channel_offset, + channel_stride, + channel_order, + noisy_channels=np.array([]), + remove_noisy_channels=False, + reference_channels=np.array([]), + remove_references=False, +): """ Selects a subset of channels for spatial downsampling @@ -91,12 +93,10 @@ def select_channels(total_channels, max_channel = np.min([total_channels, surface_channel + surface_padding]) - selected_channels = channel_order[ - start_channel_offset:max_channel:channel_stride] + selected_channels = channel_order[start_channel_offset:max_channel:channel_stride] actual_channel_numbers = np.arange(total_channels) - actual_channel_numbers = actual_channel_numbers[ - start_channel_offset:max_channel:channel_stride] + actual_channel_numbers = actual_channel_numbers[start_channel_offset:max_channel:channel_stride] if remove_references or remove_noisy_channels: # TODO: Is there a case that reference/noisy channels won't be @@ -105,16 +105,14 @@ def select_channels(total_channels, logger.info("Before:") logger.info(actual_channel_numbers) # create mask to filter out reference channels - mask = (not remove_references) or np.isin(actual_channel_numbers, - reference_channels, - assume_unique=True, - invert=True) + mask = (not remove_references) or np.isin( + actual_channel_numbers, reference_channels, assume_unique=True, invert=True + ) # mask to remove noisy channels - mask &= (not remove_noisy_channels) or np.isin(actual_channel_numbers, - noisy_channels, - assume_unique=True, - invert=True) + mask &= (not remove_noisy_channels) or np.isin( + actual_channel_numbers, noisy_channels, assume_unique=True, invert=True + ) actual_channel_numbers = actual_channel_numbers[mask] selected_channels = selected_channels[mask] logger.info("After:") @@ -165,18 +163,16 @@ def subsample_lfp(lfp_raw, selected_channels, subsampling_factor): """ - num_samples = len(lfp_raw[::subsampling_factor, - 0]) # np.round(lfp_raw.shape[0] / + num_samples = len(lfp_raw[::subsampling_factor, 0]) # np.round(lfp_raw.shape[0] / # subsampling_factor).astype('int') num_channels = selected_channels.size - lfp_subsampled = np.zeros((num_samples, num_channels), dtype='int16') + lfp_subsampled = np.zeros((num_samples, num_channels), dtype="int16") for new_ch, old_ch in enumerate(selected_channels): - tmp = decimate(lfp_raw[:, old_ch], subsampling_factor, ftype='iir', - zero_phase=True) - assert (len(tmp) == num_samples) - lfp_subsampled[:, new_ch] = tmp.astype('int16') + tmp = decimate(lfp_raw[:, old_ch], subsampling_factor, ftype="iir", zero_phase=True) + assert len(tmp) == num_samples + lfp_subsampled[:, new_ch] = tmp.astype("int16") return lfp_subsampled @@ -203,19 +199,19 @@ def remove_lfp_offset(lfp, sampling_frequency, cutoff_frequency, filter_order): New 2D array of LFP values """ - lfp_filtered = np.zeros(lfp.shape, dtype='int16') - b, a = butter(filter_order, cutoff_frequency / (sampling_frequency / 2), - btype='high') + lfp_filtered = np.zeros(lfp.shape, dtype="int16") + b, a = butter(filter_order, cutoff_frequency / (sampling_frequency / 2), btype="high") for ch in range(lfp.shape[1]): tmp = filtfilt(b, a, lfp[:, ch]) - lfp_filtered[:, ch] = tmp.astype('int16') + lfp_filtered[:, ch] = tmp.astype("int16") return lfp_filtered -def remove_lfp_noise(lfp, surface_channel, channel_numbers, channel_max=384, - channel_limit=380, max_out_of_brain_channels=50): +def remove_lfp_noise( + lfp, surface_channel, channel_numbers, channel_max=384, channel_limit=380, max_out_of_brain_channels=50 +): """ Subtract mean of channels out of brain to remove noise @@ -242,10 +238,9 @@ def remove_lfp_noise(lfp, surface_channel, channel_numbers, channel_max=384, """ - lfp_noise_removed = np.zeros(lfp.shape, dtype='int16') + lfp_noise_removed = np.zeros(lfp.shape, dtype="int16") - surface_channel = channel_limit if surface_channel >= channel_max else \ - surface_channel + surface_channel = channel_limit if surface_channel >= channel_max else surface_channel channel_selection = np.where(channel_numbers > surface_channel)[0] if len(channel_selection) > max_out_of_brain_channels: @@ -255,6 +250,6 @@ def remove_lfp_noise(lfp, surface_channel, channel_numbers, channel_max=384, for ch in range(lfp.shape[1]): tmp = lfp[:, ch] - median_signal_out_of_brain - lfp_noise_removed[:, ch] = tmp.astype('int16') + lfp_noise_removed[:, ch] = tmp.astype("int16") return lfp_noise_removed diff --git a/allensdk/brain_observatory/ecephys/nwb/__init__.py b/allensdk/brain_observatory/ecephys/nwb/__init__.py index bc79a7ed5e..285eefe89a 100644 --- a/allensdk/brain_observatory/ecephys/nwb/__init__.py +++ b/allensdk/brain_observatory/ecephys/nwb/__init__.py @@ -7,14 +7,12 @@ namespace_path = (file_dir / "ndx-aibs-ecephys.namespace.yaml").resolve() pynwb.load_namespaces(str(namespace_path)) -EcephysProbe = pynwb.get_class('EcephysProbe', 'ndx-aibs-ecephys') +EcephysProbe = pynwb.get_class("EcephysProbe", "ndx-aibs-ecephys") -EcephysElectrodeGroup = pynwb.get_class('EcephysElectrodeGroup', - 'ndx-aibs-ecephys') +EcephysElectrodeGroup = pynwb.get_class("EcephysElectrodeGroup", "ndx-aibs-ecephys") -EcephysSpecimen = pynwb.get_class('EcephysSpecimen', 'ndx-aibs-ecephys') +EcephysSpecimen = pynwb.get_class("EcephysSpecimen", "ndx-aibs-ecephys") -EcephysEyeTrackingRigMetadata = pynwb.get_class('EcephysEyeTrackingRigMetadata', - 'ndx-aibs-ecephys') +EcephysEyeTrackingRigMetadata = pynwb.get_class("EcephysEyeTrackingRigMetadata", "ndx-aibs-ecephys") -EcephysCSD = pynwb.get_class('EcephysCSD', 'ndx-aibs-ecephys') +EcephysCSD = pynwb.get_class("EcephysCSD", "ndx-aibs-ecephys") diff --git a/allensdk/brain_observatory/ecephys/nwb/ecephys_nwb_extension_builder.py b/allensdk/brain_observatory/ecephys/nwb/ecephys_nwb_extension_builder.py index fbe4520c7a..862020f4ee 100644 --- a/allensdk/brain_observatory/ecephys/nwb/ecephys_nwb_extension_builder.py +++ b/allensdk/brain_observatory/ecephys/nwb/ecephys_nwb_extension_builder.py @@ -1,5 +1,4 @@ -from pynwb.spec import (NWBAttributeSpec, NWBDatasetSpec, - NWBGroupSpec, NWBNamespaceBuilder) +from pynwb.spec import NWBAttributeSpec, NWBDatasetSpec, NWBGroupSpec, NWBNamespaceBuilder # This is the script used to generate the AIBS ecephys NWB extension .yaml # files. It can be run by installing pynwb and executing @@ -7,141 +6,128 @@ # the same directory which the script is run in. For more details see: # https://pynwb.readthedocs.io/en/stable/extensions.html -ns_builder = NWBNamespaceBuilder(doc="Allen Institute Ecephys Extension", - version="0.2.0", - name="ndx-aibs-ecephys", - author="Allen Institute for Brain Science", - contact="waynew@alleninstitute.org") +ns_builder = NWBNamespaceBuilder( + doc="Allen Institute Ecephys Extension", + version="0.2.0", + name="ndx-aibs-ecephys", + author="Allen Institute for Brain Science", + contact="waynew@alleninstitute.org", +) -probe_id_attr = NWBAttributeSpec(name="probe_id", - doc="Unique ID of the neuropixels probe", - dtype="int") +probe_id_attr = NWBAttributeSpec(name="probe_id", doc="Unique ID of the neuropixels probe", dtype="int") # Ecephys probe device extension (inherits from NWB `Device`) -sampling_rate_attr = NWBAttributeSpec(name="sampling_rate", - doc="The sampling rate for the device", - dtype="float64") +sampling_rate_attr = NWBAttributeSpec(name="sampling_rate", doc="The sampling rate for the device", dtype="float64") ecephys_probe_attributes = [sampling_rate_attr, probe_id_attr] -ecephys_probe_ext = NWBGroupSpec(doc="A neuropixels probe device", - attributes=ecephys_probe_attributes, - neurodata_type_def="EcephysProbe", - neurodata_type_inc="Device") +ecephys_probe_ext = NWBGroupSpec( + doc="A neuropixels probe device", + attributes=ecephys_probe_attributes, + neurodata_type_def="EcephysProbe", + neurodata_type_inc="Device", +) # Ecephys electrode group extension (inherits from NWB `ElectrodeGroup`) -has_lfp_data_attr = NWBAttributeSpec(name="has_lfp_data", - doc="Indicates availability of LFP data", - dtype="bool") +has_lfp_data_attr = NWBAttributeSpec(name="has_lfp_data", doc="Indicates availability of LFP data", dtype="bool") -lfp_sampling_rate = NWBAttributeSpec(name="lfp_sampling_rate", - doc=("The sampling rate at which data " - "were acquired on this electrode " - "group's channels"), - dtype="float64") +lfp_sampling_rate = NWBAttributeSpec( + name="lfp_sampling_rate", + doc=("The sampling rate at which data were acquired on this electrode group's channels"), + dtype="float64", +) -ecephys_egroup_attributes = [has_lfp_data_attr, probe_id_attr, - lfp_sampling_rate] +ecephys_egroup_attributes = [has_lfp_data_attr, probe_id_attr, lfp_sampling_rate] -ecephys_egroup_ext = NWBGroupSpec(doc=("A group consisting of the channels " - "on a single neuropixels probe"), - attributes=ecephys_egroup_attributes, - neurodata_type_def="EcephysElectrodeGroup", - neurodata_type_inc="ElectrodeGroup") +ecephys_egroup_ext = NWBGroupSpec( + doc=("A group consisting of the channels on a single neuropixels probe"), + attributes=ecephys_egroup_attributes, + neurodata_type_def="EcephysElectrodeGroup", + neurodata_type_inc="ElectrodeGroup", +) # Ecephys specimen metadata extension (inherits from NWB `Subject`) -specimen_name_attr = NWBAttributeSpec(name="specimen_name", - doc="Full name of specimen", - dtype="text") +specimen_name_attr = NWBAttributeSpec(name="specimen_name", doc="Full name of specimen", dtype="text") -age_in_days_attr = NWBAttributeSpec(name="age_in_days", - doc="Age of specimen in days", - dtype="float") +age_in_days_attr = NWBAttributeSpec(name="age_in_days", doc="Age of specimen in days", dtype="float") -strain_attr = NWBAttributeSpec(name="strain", - doc="Specimen strain", - dtype="text") +strain_attr = NWBAttributeSpec(name="strain", doc="Specimen strain", dtype="text") -ecephys_specimen_attributes = [specimen_name_attr, age_in_days_attr, - strain_attr] +ecephys_specimen_attributes = [specimen_name_attr, age_in_days_attr, strain_attr] -ecephys_specimen_ext = NWBGroupSpec(doc="Metadata for ecephys specimen", - attributes=ecephys_specimen_attributes, - neurodata_type_def="EcephysSpecimen", - neurodata_type_inc="Subject") +ecephys_specimen_ext = NWBGroupSpec( + doc="Metadata for ecephys specimen", + attributes=ecephys_specimen_attributes, + neurodata_type_def="EcephysSpecimen", + neurodata_type_inc="Subject", +) # Ecephys eye tracking rig metadata extension (inherits from `NWBDataInterface`) -rig_equipment_attr = NWBAttributeSpec(name="equipment", - doc="Description of rig", - dtype="text") - -unit_attr = NWBAttributeSpec('unit', 'Unit of measurement for the data', 'text') - -rig_monitor_position_dset = NWBDatasetSpec(name="monitor_position", - doc="position of monitor (x, y, z)", - attributes=[unit_attr], - dtype='float32', - dims=(3,)) - -rig_camera_position_dset = NWBDatasetSpec(name="camera_position", - doc="position of camera (x, y, z)", - attributes=[unit_attr], - dtype='float32', - dims=(3,)) - -rig_led_position_dset = NWBDatasetSpec(name="led_position", - doc="position of LED (x, y, z)", - attributes=[unit_attr], - dtype='float32', - dims=(3,)) - -rig_monitor_rotation_dset = NWBDatasetSpec(name="monitor_rotation", - doc="rotation of monitor (x, y, z)", - attributes=[unit_attr], - dtype='float32', - dims=(3,)) - -rig_camera_rotation_dset = NWBDatasetSpec(name="camera_rotation", - doc="rotation of camera (x, y, z)", - attributes=[unit_attr], - dtype='float32', - dims=(3,)) +rig_equipment_attr = NWBAttributeSpec(name="equipment", doc="Description of rig", dtype="text") + +unit_attr = NWBAttributeSpec("unit", "Unit of measurement for the data", "text") + +rig_monitor_position_dset = NWBDatasetSpec( + name="monitor_position", doc="position of monitor (x, y, z)", attributes=[unit_attr], dtype="float32", dims=(3,) +) + +rig_camera_position_dset = NWBDatasetSpec( + name="camera_position", doc="position of camera (x, y, z)", attributes=[unit_attr], dtype="float32", dims=(3,) +) + +rig_led_position_dset = NWBDatasetSpec( + name="led_position", doc="position of LED (x, y, z)", attributes=[unit_attr], dtype="float32", dims=(3,) +) + +rig_monitor_rotation_dset = NWBDatasetSpec( + name="monitor_rotation", doc="rotation of monitor (x, y, z)", attributes=[unit_attr], dtype="float32", dims=(3,) +) + +rig_camera_rotation_dset = NWBDatasetSpec( + name="camera_rotation", doc="rotation of camera (x, y, z)", attributes=[unit_attr], dtype="float32", dims=(3,) +) ecephys_eye_tracking_rig_metadata_ext = NWBGroupSpec( doc="Metadata for ecephys experiment rig", attributes=[rig_equipment_attr], - datasets=[rig_monitor_position_dset, - rig_camera_position_dset, - rig_led_position_dset, - rig_monitor_rotation_dset, - rig_camera_rotation_dset], + datasets=[ + rig_monitor_position_dset, + rig_camera_position_dset, + rig_led_position_dset, + rig_monitor_rotation_dset, + rig_camera_rotation_dset, + ], neurodata_type_def="EcephysEyeTrackingRigMetadata", - neurodata_type_inc="NWBDataInterface" + neurodata_type_inc="NWBDataInterface", ) # Ecephys CSD extension -csd_timeseries_group = NWBGroupSpec(doc="A timeseries containing current source density (CSD) data", - neurodata_type_inc="TimeSeries") +csd_timeseries_group = NWBGroupSpec( + doc="A timeseries containing current source density (CSD) data", neurodata_type_inc="TimeSeries" +) -csd_virtual_electrode_vertical_positions = NWBDatasetSpec(name="virtual_electrode_y_positions", - doc="Virtual vertical positions of electrodes from which CSD was calculated", - attributes=[unit_attr], - dtype='float32', - shape=(None,)) +csd_virtual_electrode_vertical_positions = NWBDatasetSpec( + name="virtual_electrode_y_positions", + doc="Virtual vertical positions of electrodes from which CSD was calculated", + attributes=[unit_attr], + dtype="float32", + shape=(None,), +) -csd_virtual_electrode_horizontal_positions = NWBDatasetSpec(name="virtual_electrode_x_positions", - doc="Virtual horizontal positions of electrodes from which CSD was calculated", - attributes=[unit_attr], - dtype='float32', - shape=(None,)) +csd_virtual_electrode_horizontal_positions = NWBDatasetSpec( + name="virtual_electrode_x_positions", + doc="Virtual horizontal positions of electrodes from which CSD was calculated", + attributes=[unit_attr], + dtype="float32", + shape=(None,), +) ecephys_csd_ext = NWBGroupSpec( doc="A group containing current source density (CSD) data and virtual electrode locations", groups=[csd_timeseries_group], - datasets=[csd_virtual_electrode_horizontal_positions, - csd_virtual_electrode_vertical_positions], + datasets=[csd_virtual_electrode_horizontal_positions, csd_virtual_electrode_vertical_positions], neurodata_type_def="EcephysCSD", - neurodata_type_inc="NWBDataInterface" + neurodata_type_inc="NWBDataInterface", ) ext_source = "ndx-aibs-ecephys.extension.yaml" diff --git a/allensdk/brain_observatory/ecephys/nwb_util.py b/allensdk/brain_observatory/ecephys/nwb_util.py index 16c2ce7a8e..23bf74d202 100644 --- a/allensdk/brain_observatory/ecephys/nwb_util.py +++ b/allensdk/brain_observatory/ecephys/nwb_util.py @@ -5,25 +5,19 @@ import pynwb from allensdk.brain_observatory import dict_to_indexed_array -from allensdk.brain_observatory.ecephys.nwb import EcephysProbe, \ - EcephysElectrodeGroup +from allensdk.brain_observatory.ecephys.nwb import EcephysProbe, EcephysElectrodeGroup ELECTRODE_TABLE_DEFAULT_COLUMNS = [ - ("probe_vertical_position", - "Length-wise position of electrode/channel on device (microns)"), - ("probe_horizontal_position", - "Width-wise position of electrode/channel on device (microns)"), + ("probe_vertical_position", "Length-wise position of electrode/channel on device (microns)"), + ("probe_horizontal_position", "Width-wise position of electrode/channel on device (microns)"), ("probe_id", "The unique id of this electrode's/channel's device"), - ("probe_channel_number", - "The local index of electrode/channel on device"), - ("valid_data", "Whether data from this electrode/channel is usable") + ("probe_channel_number", "The local index of electrode/channel on device"), + ("valid_data", "Whether data from this electrode/channel is usable"), ] -def add_ragged_data_to_dynamic_table( - table, data, column_name, column_description="" -): - """ Builds the index and data vectors required for writing ragged array +def add_ragged_data_to_dynamic_table(table, data, column_name, column_description=""): + """Builds the index and data vectors required for writing ragged array data to a pynwb dynamic table Parameters @@ -46,18 +40,13 @@ def add_ragged_data_to_dynamic_table( idx, values = dict_to_indexed_array(data, table.id.data) del data - table.add_column( - name=column_name, - description=column_description, - data=values, - index=idx - ) + table.add_column(name=column_name, description=column_description, data=values, index=idx) -def add_probe_to_nwbfile(nwbfile, probe_id, sampling_rate, lfp_sampling_rate, - has_lfp_data, name, - location="See electrode locations"): - """ Creates objects required for representation of a single +def add_probe_to_nwbfile( + nwbfile, probe_id, sampling_rate, lfp_sampling_rate, has_lfp_data, name, location="See electrode locations" +): + """Creates objects required for representation of a single extracellular ephys probe within an NWB file. Parameters @@ -92,11 +81,13 @@ def add_probe_to_nwbfile(nwbfile, probe_id, sampling_rate, lfp_sampling_rate, electrode group object corresponding to this probe """ - probe_nwb_device = EcephysProbe(name=name, - description="Neuropixels 1.0 Probe", - manufacturer="imec", - probe_id=probe_id, - sampling_rate=sampling_rate) + probe_nwb_device = EcephysProbe( + name=name, + description="Neuropixels 1.0 Probe", + manufacturer="imec", + probe_id=probe_id, + sampling_rate=sampling_rate, + ) probe_nwb_electrode_group = EcephysElectrodeGroup( name=name, @@ -105,7 +96,7 @@ def add_probe_to_nwbfile(nwbfile, probe_id, sampling_rate, lfp_sampling_rate, location=location, device=probe_nwb_device, lfp_sampling_rate=lfp_sampling_rate, - has_lfp_data=has_lfp_data + has_lfp_data=has_lfp_data, ) nwbfile.add_device(probe_nwb_device) @@ -115,10 +106,11 @@ def add_probe_to_nwbfile(nwbfile, probe_id, sampling_rate, lfp_sampling_rate, def add_ecephys_electrodes( - nwbfile: pynwb.NWBFile, - channels: List[dict], - electrode_group: EcephysElectrodeGroup, - channel_number_whitelist: Optional[np.ndarray] = None): + nwbfile: pynwb.NWBFile, + channels: List[dict], + electrode_group: EcephysElectrodeGroup, + channel_number_whitelist: Optional[np.ndarray] = None, +): """Add electrode information to an ecephys nwbfile electrode table. Parameters @@ -182,13 +174,11 @@ def add_ecephys_electrodes( group=electrode_group, location=row["structure_acronym"], imp=row.get("impedence", row.get("impedance")), - filtering=row["filtering"] + filtering=row["filtering"], ) -def _add_ecephys_electrode_columns(nwbfile: pynwb.NWBFile, - columns_to_add: - Optional[List[Tuple[str, str]]] = None): +def _add_ecephys_electrode_columns(nwbfile: pynwb.NWBFile, columns_to_add: Optional[List[Tuple[str, str]]] = None): """Add additional columns to ecephys nwbfile electrode table. Parameters @@ -204,7 +194,5 @@ def _add_ecephys_electrode_columns(nwbfile: pynwb.NWBFile, columns_to_add = ELECTRODE_TABLE_DEFAULT_COLUMNS for col_name, col_description in columns_to_add: - if (not nwbfile.electrodes) or \ - (col_name not in nwbfile.electrodes.colnames): - nwbfile.add_electrode_column(name=col_name, - description=col_description) + if (not nwbfile.electrodes) or (col_name not in nwbfile.electrodes.colnames): + nwbfile.add_electrode_column(name=col_name, description=col_description) diff --git a/allensdk/brain_observatory/ecephys/optotagging.py b/allensdk/brain_observatory/ecephys/optotagging.py index a9a3d064fc..e1ddf8e06e 100644 --- a/allensdk/brain_observatory/ecephys/optotagging.py +++ b/allensdk/brain_observatory/ecephys/optotagging.py @@ -3,20 +3,18 @@ from pynwb import NWBFile from allensdk.brain_observatory.nwb import setup_table_for_epochs -from allensdk.core import DataObject, JsonReadableInterface, \ - NwbWritableInterface, NwbReadableInterface +from allensdk.core import DataObject, JsonReadableInterface, NwbWritableInterface, NwbReadableInterface -class OptotaggingTable(DataObject, JsonReadableInterface, - NwbWritableInterface, NwbReadableInterface): +class OptotaggingTable(DataObject, JsonReadableInterface, NwbWritableInterface, NwbReadableInterface): """Optotagging table - optotagging stimulation""" + def __init__(self, table: pd.DataFrame): # "name" is a pynwb reserved column name that older versions of the # pre-processed optotagging_table may use. - table = \ - table.rename(columns={"name": "stimulus_name"}) - table.index = table.index.rename('id') - super().__init__(name='optotaggging_table', value=table) + table = table.rename(columns={"name": "stimulus_name"}) + table.index = table.index.rename("id") + super().__init__(name="optotaggging_table", value=table) @property def value(self) -> pd.DataFrame: @@ -36,16 +34,15 @@ def value(self) -> pd.DataFrame: @classmethod def from_json(cls, dict_repr: dict) -> "OptotaggingTable": - table = pd.read_csv(dict_repr['optotagging_table_path']) - table.index.name = 'id' + table = pd.read_csv(dict_repr["optotagging_table_path"]) + table.index.name = "id" return OptotaggingTable(table=table) @classmethod def from_nwb(cls, nwbfile: NWBFile) -> "OptotaggingTable": - mod = nwbfile.get_processing_module('optotagging') - table = mod.get_data_interface('optogenetic_stimulation')\ - .to_dataframe() - table.drop(columns=['tags', 'timeseries'], inplace=True) + mod = nwbfile.get_processing_module("optotagging") + table = mod.get_data_interface("optogenetic_stimulation").to_dataframe() + table.drop(columns=["tags", "timeseries"], inplace=True) return OptotaggingTable(table=table) def to_nwb(self, nwbfile: NWBFile) -> NWBFile: @@ -55,21 +52,17 @@ def to_nwb(self, nwbfile: NWBFile) -> NWBFile: name="optotagging", timestamps=optotagging_table["start_time"].values, data=optotagging_table["duration"].values, - unit="seconds" + unit="seconds", ) - opto_mod = pynwb.ProcessingModule("optotagging", - "optogenetic stimulution data") + opto_mod = pynwb.ProcessingModule("optotagging", "optogenetic stimulution data") opto_mod.add_data_interface(opto_ts) nwbfile.add_processing_module(opto_mod) - optotagging_table = setup_table_for_epochs(optotagging_table, opto_ts, - 'optical_stimulation') + optotagging_table = setup_table_for_epochs(optotagging_table, opto_ts, "optical_stimulation") if len(optotagging_table) > 0: - container = \ - pynwb.epoch.TimeIntervals.from_dataframe( - optotagging_table, "optogenetic_stimulation") + container = pynwb.epoch.TimeIntervals.from_dataframe(optotagging_table, "optogenetic_stimulation") opto_mod.add_data_interface(container) return nwbfile diff --git a/allensdk/brain_observatory/ecephys/optotagging_table/__main__.py b/allensdk/brain_observatory/ecephys/optotagging_table/__main__.py index 24c5813a22..ae96ac8a4a 100644 --- a/allensdk/brain_observatory/ecephys/optotagging_table/__main__.py +++ b/allensdk/brain_observatory/ecephys/optotagging_table/__main__.py @@ -1,8 +1,6 @@ import pandas as pd -from allensdk.brain_observatory.argschema_utilities import \ - ArgSchemaParserPlus, \ - write_or_print_outputs +from allensdk.brain_observatory.argschema_utilities import ArgSchemaParserPlus, write_or_print_outputs from allensdk.brain_observatory.ecephys.file_io.ecephys_sync_dataset import ( EcephysSyncDataset, ) @@ -10,25 +8,19 @@ def build_opto_table(args): - opto_file = pd.read_pickle(args['opto_pickle_path']) - sync_file = EcephysSyncDataset.factory(args['sync_h5_path']) + opto_file = pd.read_pickle(args["opto_pickle_path"]) + sync_file = EcephysSyncDataset.factory(args["sync_h5_path"]) start_times = sync_file.extract_led_times() - conditions = [str(item) for item in opto_file['opto_conditions']] - levels = opto_file['opto_levels'] + conditions = [str(item) for item in opto_file["opto_conditions"]] + levels = opto_file["opto_levels"] assert len(conditions) == len(levels) if len(start_times) > len(conditions): - raise ValueError( - f"there are {len(start_times) - len(conditions)} extra " - f"optotagging sync times!") + raise ValueError(f"there are {len(start_times) - len(conditions)} extra optotagging sync times!") - optotagging_table = pd.DataFrame({ - 'start_time': start_times, - 'condition': conditions, - 'level': levels - }) - optotagging_table = optotagging_table.sort_values(by='start_time', axis=0) + optotagging_table = pd.DataFrame({"start_time": start_times, "condition": conditions, "level": levels}) + optotagging_table = optotagging_table.sort_values(by="start_time", axis=0) stop_times = [] names = [] @@ -42,16 +34,14 @@ def build_opto_table(args): optotagging_table["stop_time"] = stop_times optotagging_table["stimulus_name"] = names optotagging_table["condition"] = conditions - optotagging_table["duration"] = \ - optotagging_table["stop_time"] - optotagging_table["start_time"] + optotagging_table["duration"] = optotagging_table["stop_time"] - optotagging_table["start_time"] - optotagging_table.to_csv(args['output_opto_table_path'], index=False) - return {'output_opto_table_path': args['output_opto_table_path']} + optotagging_table.to_csv(args["output_opto_table_path"], index=False) + return {"output_opto_table_path": args["output_opto_table_path"]} def main(): - mod = ArgSchemaParserPlus(schema_type=InputParameters, - output_schema_type=OutputParameters) + mod = ArgSchemaParserPlus(schema_type=InputParameters, output_schema_type=OutputParameters) output = build_opto_table(mod.args) write_or_print_outputs(data=output, parser=mod) diff --git a/allensdk/brain_observatory/ecephys/optotagging_table/_schemas.py b/allensdk/brain_observatory/ecephys/optotagging_table/_schemas.py index 0c6898724a..6fdfb8dec0 100644 --- a/allensdk/brain_observatory/ecephys/optotagging_table/_schemas.py +++ b/allensdk/brain_observatory/ecephys/optotagging_table/_schemas.py @@ -1,29 +1,13 @@ -from argschema import ArgSchema +from argschema import ArgSchema from argschema.schemas import DefaultSchema from argschema.fields import Nested, String, Float, Dict known_conditions = { - "0": { - "duration": 1.0, - "name": "fast_pulses", - "condition": "2.5 ms pulses at 10 Hz" - }, - "1": { - "duration": 0.005, - "name": "pulse", - "condition": "a single square pulse" - }, - "2": { - "duration": 0.01, - "name": "pulse", - "condition": "a single square pulse" - }, - "3": { - "duration": 1.0, - "name": "raised_cosine", - "condition": "half-period of a cosine wave" - } + "0": {"duration": 1.0, "name": "fast_pulses", "condition": "2.5 ms pulses at 10 Hz"}, + "1": {"duration": 0.005, "name": "pulse", "condition": "a single square pulse"}, + "2": {"duration": 0.01, "name": "pulse", "condition": "a single square pulse"}, + "3": {"duration": 1.0, "name": "raised_cosine", "condition": "half-period of a cosine wave"}, } @@ -34,15 +18,15 @@ class Condition(DefaultSchema): class InputParameters(ArgSchema): - opto_pickle_path = String(required=True, help='path to file containing optotagging information') - sync_h5_path = String(required=True, help='path to h5 file containing syncronization information') - output_opto_table_path = String(required=True, help='the optotagging stimulation table will be written here') + opto_pickle_path = String(required=True, help="path to file containing optotagging information") + sync_h5_path = String(required=True, help="path to h5 file containing syncronization information") + output_opto_table_path = String(required=True, help="the optotagging stimulation table will be written here") conditions = Dict(String, Nested(Condition), default=known_conditions) class OutputSchema(DefaultSchema): - input_parameters = Nested(InputParameters, description=('Input parameters the module was run with'), required=True) + input_parameters = Nested(InputParameters, description=("Input parameters the module was run with"), required=True) class OutputParameters(OutputSchema): - output_opto_table_path = String(required=True, help='path to optotagging stimulation table') \ No newline at end of file + output_opto_table_path = String(required=True, help="path to optotagging stimulation table") diff --git a/allensdk/brain_observatory/ecephys/probes.py b/allensdk/brain_observatory/ecephys/probes.py index 8ddbea3c34..b7c10417b1 100644 --- a/allensdk/brain_observatory/ecephys/probes.py +++ b/allensdk/brain_observatory/ecephys/probes.py @@ -7,18 +7,14 @@ from pynwb import NWBFile from allensdk.brain_observatory.ecephys._probe import Probe, ProbeWithLFPMeta -from allensdk.brain_observatory.ecephys.nwb_util import \ - add_ragged_data_to_dynamic_table -from allensdk.core import DataObject, JsonReadableInterface, \ - NwbReadableInterface, NwbWritableInterface +from allensdk.brain_observatory.ecephys.nwb_util import add_ragged_data_to_dynamic_table +from allensdk.core import DataObject, JsonReadableInterface, NwbReadableInterface, NwbWritableInterface -class Probes(DataObject, JsonReadableInterface, NwbReadableInterface, - NwbWritableInterface): +class Probes(DataObject, JsonReadableInterface, NwbReadableInterface, NwbWritableInterface): """Probes""" - def __init__(self, - probes: List[Probe]): + def __init__(self, probes: List[Probe]): """ Parameters @@ -26,9 +22,7 @@ def __init__(self, probes: List of Probe """ self._probes = probes - super().__init__(name='probes', - value=None, - is_value_self=True) + super().__init__(name="probes", value=None, is_value_self=True) @property def probes(self): @@ -42,10 +36,7 @@ def spike_times(self) -> Dict[int, np.ndarray]: ------- Dictionary mapping unit id to spike_times for all probes """ - return { - unit.id: unit.spike_times - for probe in self.probes - for unit in probe.units.value} + return {unit.id: unit.spike_times for probe in self.probes for unit in probe.units.value} @property def mean_waveforms(self) -> Dict[int, np.ndarray]: @@ -55,10 +46,7 @@ def mean_waveforms(self) -> Dict[int, np.ndarray]: ------- Dictionary mapping unit id to mean_waveforms for all probes """ - return { - unit.id: unit.mean_waveforms - for probe in self.probes - for unit in probe.units.value} + return {unit.id: unit.mean_waveforms for probe in self.probes for unit in probe.units.value} @property def spike_amplitudes(self) -> Dict[int, np.ndarray]: @@ -68,18 +56,15 @@ def spike_amplitudes(self) -> Dict[int, np.ndarray]: ------- Dictionary mapping unit id to spike_amplitudes for all probes """ - return { - unit.id: unit.spike_amplitudes - for probe in self.probes - for unit in probe.units.value} + return {unit.id: unit.spike_amplitudes for probe in self.probes for unit in probe.units.value} def get_units_table( - self, - filter_by_validity: bool = True, - filter_out_of_brain_units: bool = True, - amplitude_cutoff_maximum: Optional[float] = None, - presence_ratio_minimum: Optional[float] = None, - isi_violations_maximum: Optional[float] = None + self, + filter_by_validity: bool = True, + filter_out_of_brain_units: bool = True, + amplitude_cutoff_maximum: Optional[float] = None, + presence_ratio_minimum: Optional[float] = None, + isi_violations_maximum: Optional[float] = None, ) -> pd.DataFrame: """ Gets a dataframe representing all units detected by all probes @@ -106,47 +91,30 @@ def get_units_table( """ units_table = pd.concat([probe.units_table for probe in self.probes]) - units_table = units_table.set_index(keys='id', drop=True) - units_table = units_table.drop(columns=[ - 'spike_times', 'spike_amplitudes', 'mean_waveforms']) + units_table = units_table.set_index(keys="id", drop=True) + units_table = units_table.drop(columns=["spike_times", "spike_amplitudes", "mean_waveforms"]) if filter_by_validity or filter_out_of_brain_units: - channels = pd.concat([ - p.channels.to_dataframe( - filter_by_validity=filter_by_validity - ) for p in self.probes - ]) + channels = pd.concat([p.channels.to_dataframe(filter_by_validity=filter_by_validity) for p in self.probes]) if filter_out_of_brain_units: - channels = channels[ - ~(channels['structure_acronym'].isna())] + channels = channels[~(channels["structure_acronym"].isna())] # noinspection PyTypeChecker channel_ids = set(channels.index.values.tolist()) - units_table = units_table[ - units_table["peak_channel_id"].isin(channel_ids)] + units_table = units_table[units_table["peak_channel_id"].isin(channel_ids)] if filter_by_validity: units_table = units_table[units_table["quality"] == "good"] units_table.drop(columns=["quality"], inplace=True) - units_table = units_table[ - units_table["amplitude_cutoff"] <= - (amplitude_cutoff_maximum or np.inf)] - units_table = units_table[ - units_table["presence_ratio"] >= - (presence_ratio_minimum or -np.inf)] - units_table = units_table[ - units_table["isi_violations"] <= - (isi_violations_maximum or np.inf)] + units_table = units_table[units_table["amplitude_cutoff"] <= (amplitude_cutoff_maximum or np.inf)] + units_table = units_table[units_table["presence_ratio"] >= (presence_ratio_minimum or -np.inf)] + units_table = units_table[units_table["isi_violations"] <= (isi_violations_maximum or np.inf)] return units_table @classmethod - def from_json( - cls, - probes: List[Dict[str, Any]], - skip_probes: Optional[List[str]] = None - ) -> "Probes": + def from_json(cls, probes: List[Dict[str, Any]], skip_probes: Optional[List[str]] = None) -> "Probes": """ Parameters @@ -159,32 +127,24 @@ def from_json( `Probes` instance """ skip_probes = skip_probes if skip_probes is not None else [] - invalid_skip_probes = set(skip_probes).difference( - [p['name'] for p in probes]) + invalid_skip_probes = set(skip_probes).difference([p["name"] for p in probes]) if invalid_skip_probes: - raise ValueError( - f'You passed invalid probes to skip: {invalid_skip_probes} ' - f'are not valid probe names') + raise ValueError(f"You passed invalid probes to skip: {invalid_skip_probes} are not valid probe names") for probe in skip_probes: - logging.info(f'Skipping {probe}') - probes = [p for p in probes if p['name'] not in skip_probes] - probes = sorted(probes, key=lambda probe: probe['name']) + logging.info(f"Skipping {probe}") + probes = [p for p in probes if p["name"] not in skip_probes] + probes = sorted(probes, key=lambda probe: probe["name"]) probes = [Probe.from_json(probe=probe) for probe in probes] return Probes(probes=probes) def to_dataframe(self): probes = [probe.to_dict() for probe in self.probes] probes = pd.DataFrame(probes) - probes = probes.set_index(keys='id') + probes = probes.set_index(keys="id") return probes @classmethod - def from_nwb( - cls, - nwbfile: NWBFile, - probe_lfp_meta_map: Optional[ - Dict[str, ProbeWithLFPMeta]] = None - ) -> "Probes": + def from_nwb(cls, nwbfile: NWBFile, probe_lfp_meta_map: Optional[Dict[str, ProbeWithLFPMeta]] = None) -> "Probes": """ Parameters @@ -200,18 +160,12 @@ def from_nwb( if probe_lfp_meta_map is None: probe_lfp_meta_map = dict() probes = [ - Probe.from_nwb( - nwbfile=nwbfile, - probe_name=probe_name, - lfp_meta=probe_lfp_meta_map.get(probe_name) - ) - for probe_name in nwbfile.electrode_groups] + Probe.from_nwb(nwbfile=nwbfile, probe_name=probe_name, lfp_meta=probe_lfp_meta_map.get(probe_name)) + for probe_name in nwbfile.electrode_groups + ] return Probes(probes=probes) - def to_nwb( - self, - nwbfile: NWBFile - ) -> Tuple[NWBFile, Dict[str, Optional[NWBFile]]]: + def to_nwb(self, nwbfile: NWBFile) -> Tuple[NWBFile, Dict[str, Optional[NWBFile]]]: """ Adds probes to NWBFile instance @@ -233,16 +187,12 @@ def to_nwb( """ probe_nwbfile_map = dict() for probe in self.probes: - _, probe_nwbfile = probe.to_nwb( - nwbfile=nwbfile - ) + _, probe_nwbfile = probe.to_nwb(nwbfile=nwbfile) probe_nwbfile_map[probe.name] = probe_nwbfile nwbfile.units = pynwb.misc.Units.from_dataframe( - self.get_units_table( - filter_by_validity=False, - filter_out_of_brain_units=False), - name='units') + self.get_units_table(filter_by_validity=False, filter_out_of_brain_units=False), name="units" + ) add_ragged_data_to_dynamic_table( table=nwbfile.units, @@ -255,15 +205,14 @@ def to_nwb( table=nwbfile.units, data=self.spike_amplitudes, column_name="spike_amplitudes", - column_description="amplitude (s) of detected spiking events" + column_description="amplitude (s) of detected spiking events", ) add_ragged_data_to_dynamic_table( table=nwbfile.units, data=self.mean_waveforms, column_name="waveform_mean", - column_description="mean waveforms on peak channels (over " - "samples)", + column_description="mean waveforms on peak channels (over samples)", ) return nwbfile, probe_nwbfile_map diff --git a/allensdk/brain_observatory/ecephys/stimulus_analysis/__init__.py b/allensdk/brain_observatory/ecephys/stimulus_analysis/__init__.py index f0795d6da8..2bd8d365b3 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_analysis/__init__.py +++ b/allensdk/brain_observatory/ecephys/stimulus_analysis/__init__.py @@ -14,4 +14,4 @@ "DotMotion", "NaturalMovies", "ReceptiveFieldMapping", -] \ No newline at end of file +] diff --git a/allensdk/brain_observatory/ecephys/stimulus_analysis/__main__.py b/allensdk/brain_observatory/ecephys/stimulus_analysis/__main__.py index e5197864ef..93b78ae8f9 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_analysis/__main__.py +++ b/allensdk/brain_observatory/ecephys/stimulus_analysis/__main__.py @@ -7,8 +7,7 @@ import pandas as pd from argschema import ArgSchemaParser -from allensdk.brain_observatory.argschema_utilities import \ - write_or_print_outputs +from allensdk.brain_observatory.argschema_utilities import write_or_print_outputs from .dot_motion import DotMotion from .drifting_gratings import DriftingGratings from .flashes import Flashes @@ -30,8 +29,8 @@ # Run without mpi4py installed MPI_rank = 0 MPI_size = 1 - barrier = lambda: None # noqa F841 - gather = lambda data, root: data # noqa F841 + barrier = lambda: None # noqa F841 + gather = lambda data, root: data # noqa F841 logger = logging.getLogger(__name__) @@ -39,13 +38,13 @@ # TODO: Try to order this list by how long each subclass takes to finish. # Helps spread work evenly across cores stim_classes = [ - ('receptive_field_mapping', ReceptiveFieldMapping), - ('drifting_gratings', DriftingGratings), - ('dot_motion', DotMotion), - ('static_gratings', StaticGratings), - ('natural_scenes', NaturalScenes), - ('natural_moves', NaturalMovies), - ('flashes', Flashes), + ("receptive_field_mapping", ReceptiveFieldMapping), + ("drifting_gratings", DriftingGratings), + ("dot_motion", DotMotion), + ("static_gratings", StaticGratings), + ("natural_scenes", NaturalScenes), + ("natural_moves", NaturalMovies), + ("flashes", Flashes), ] @@ -55,13 +54,16 @@ def log_info(message, all_ranks=False): def load_session(nwb_path, stimulus_class, **session_params): - session = EcephysSession.from_nwb_path(nwb_path, api_kwargs={ - "amplitude_cutoff_maximum": np.inf, - "presence_ratio_minimum": -np.inf, - "isi_violations_maximum": np.inf, - "filter_by_validity": False - # actually you probably still want this one - }) + session = EcephysSession.from_nwb_path( + nwb_path, + api_kwargs={ + "amplitude_cutoff_maximum": np.inf, + "presence_ratio_minimum": -np.inf, + "isi_violations_maximum": np.inf, + "filter_by_validity": False, + # actually you probably still want this one + }, + ) return stimulus_class(session, **session_params) @@ -92,11 +94,11 @@ def calculate_stimulus_metrics_ondisk(args): Same as below except pass the metric tables between ranks by writing/reading to a file. """ - log_info('ecephys: stimulus metrics module') + log_info("ecephys: stimulus metrics module") start = time.time() - input_session_nwb = args['input_session_nwb'] - output_file = args['output_file'] + input_session_nwb = args["input_session_nwb"] + output_file = args["output_file"] # For each stimulus class that needs to be processed; calculate and save # the metrics on a different rank (unless @@ -106,17 +108,14 @@ def _temp_csv_file(stim_class): # merged into final output_dir = pathlib.Path(output_file).parents[0] session_name = pathlib.Path(input_session_nwb).stem - return os.path.join(output_dir, - '{}.{}.csv'.format(session_name, stim_class)) + return os.path.join(output_dir, "{}.{}.csv".format(session_name, stim_class)) - relevant_stim_class = [(sc[0], sc[1], _temp_csv_file(sc[0])) - for sc in stim_classes if sc[ - 0] in args] # only stims specified in the + relevant_stim_class = [ + (sc[0], sc[1], _temp_csv_file(sc[0])) for sc in stim_classes if sc[0] in args + ] # only stims specified in the # input json - for sc_name, stim_class, tmp_csv in relevant_stim_class[ - MPI_rank::MPI_size]: - analysis_obj = load_session(input_session_nwb, stim_class, - **args[sc_name]) + for sc_name, stim_class, tmp_csv in relevant_stim_class[MPI_rank::MPI_size]: + analysis_obj = load_session(input_session_nwb, stim_class, **args[sc_name]) # analysis_obj = stim_class(input_session_nwb, **args[sc_name]) analysis_obj.metrics.to_csv(tmp_csv) @@ -128,7 +127,7 @@ def _temp_csv_file(stim_class): final_table = pd.read_csv(relevant_stim_class[0][2]) for _, _, tmp_csv in relevant_stim_class[1:]: tmp_table = pd.read_csv(tmp_csv) - final_table = pd.merge(final_table, tmp_table, on='unit_id') + final_table = pd.merge(final_table, tmp_table, on="unit_id") final_table.to_csv(output_file) @@ -143,7 +142,7 @@ def _temp_csv_file(stim_class): barrier() execution_time = time.time() - start - log_info(f'total time: {str(np.around(execution_time, 2))} seconds') + log_info(f"total time: {str(np.around(execution_time, 2))} seconds") return {"execution_time": execution_time} @@ -153,27 +152,25 @@ def calculate_stimulus_metrics_gather(args): Same as above but uses MPI Gather to send the dataframes across ranks """ - log_info('ecephys: stimulus metrics module') + log_info("ecephys: stimulus metrics module") start = time.time() - input_session_nwb = args['input_session_nwb'] - output_file = args['output_file'] + input_session_nwb = args["input_session_nwb"] + output_file = args["output_file"] # Divide the work across the ranks, calculate each metric and combine # all the result on each rank. combined_df = None - relevant_stim_class = [(sc[0], sc[1]) for sc in stim_classes if - sc[0] in args] # metrics for this rank + relevant_stim_class = [(sc[0], sc[1]) for sc in stim_classes if sc[0] in args] # metrics for this rank if MPI_rank < len(relevant_stim_class): for sc_name, stim_class in relevant_stim_class[MPI_rank::MPI_size]: - analysis_obj = load_session(input_session_nwb, stim_class, - **args[sc_name]) + analysis_obj = load_session(input_session_nwb, stim_class, **args[sc_name]) analysis_df = analysis_obj.metrics if combined_df is None: combined_df = analysis_df else: - combined_df = pd.merge(combined_df, analysis_df, on='unit_id') + combined_df = pd.merge(combined_df, analysis_df, on="unit_id") barrier() @@ -190,7 +187,7 @@ def calculate_stimulus_metrics_gather(args): for df in all_ranks_data[1:]: if df is None: continue - final_df = pd.merge(final_df, df, on='unit_id') + final_df = pd.merge(final_df, df, on="unit_id") final_df.to_csv(output_file) @@ -202,8 +199,7 @@ def calculate_stimulus_metrics_gather(args): def main(): from ._schemas import InputParameters, OutputParameters - mod = ArgSchemaParser(schema_type=InputParameters, - output_schema_type=OutputParameters) + mod = ArgSchemaParser(schema_type=InputParameters, output_schema_type=OutputParameters) # output = calculate_stimulus_metrics_ondisk(mod.args) output = calculate_stimulus_metrics_gather(mod.args) if MPI_rank == 0: diff --git a/allensdk/brain_observatory/ecephys/stimulus_analysis/_schemas.py b/allensdk/brain_observatory/ecephys/stimulus_analysis/_schemas.py index 5403ebfede..b9e4129b97 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_analysis/_schemas.py +++ b/allensdk/brain_observatory/ecephys/stimulus_analysis/_schemas.py @@ -11,52 +11,57 @@ class DriftingGratings(DefaultSchema): - stimulus_key = List(String, default=DriftingGratings.known_stimulus_keys(), help='Key for the drifting gratings stimulus') - trial_duration = Float(default=2.0, help='typical length of a epoch for given stimulus in seconds') - psth_resolution = Float(default=0.001, help='resultion (seconds) for generating PSTH') - + stimulus_key = List( + String, default=DriftingGratings.known_stimulus_keys(), help="Key for the drifting gratings stimulus" + ) + trial_duration = Float(default=2.0, help="typical length of a epoch for given stimulus in seconds") + psth_resolution = Float(default=0.001, help="resultion (seconds) for generating PSTH") class StaticGratings(DefaultSchema): - stimulus_key = List(String, default=StaticGratings.known_stimulus_keys(), help='Key for the static gratings stimulus') - trial_duration = Float(default=0.25, help='typical length of a epoch for given stimulus in seconds') - psth_resolution = Float(default=0.001, help='resultion (seconds) for generating PSTH') + stimulus_key = List( + String, default=StaticGratings.known_stimulus_keys(), help="Key for the static gratings stimulus" + ) + trial_duration = Float(default=0.25, help="typical length of a epoch for given stimulus in seconds") + psth_resolution = Float(default=0.001, help="resultion (seconds) for generating PSTH") class NaturalScenes(DefaultSchema): - stimulus_key = List(String, default=NaturalScenes.known_stimulus_keys(), help='Key for the natural scenes stimulus') - trial_duration = Float(default=0.25, help='typical length of a epoch for given stimulus in seconds') - psth_resolution = Float(default=0.001, help='resultion (seconds) for generating PSTH') + stimulus_key = List(String, default=NaturalScenes.known_stimulus_keys(), help="Key for the natural scenes stimulus") + trial_duration = Float(default=0.25, help="typical length of a epoch for given stimulus in seconds") + psth_resolution = Float(default=0.001, help="resultion (seconds) for generating PSTH") -#class NaturalMovies(DefaultSchema): +# class NaturalMovies(DefaultSchema): # stimulus_key = String(help='Key for the natural movies stimulus') # trial_duration = Float(default=0.25, help='typical length of a epoch for given stimulus in seconds') class DotMotion(DefaultSchema): - stimulus_key = List(String, default=DotMotion.known_stimulus_keys(), help='Key for the dot motion stimulus') - trial_duration = Float(default=1.0, help='typical length of a epoch for given stimulus in seconds') - psth_resolution = Float(default=0.001, help='resultion (seconds) for generating PSTH') + stimulus_key = List(String, default=DotMotion.known_stimulus_keys(), help="Key for the dot motion stimulus") + trial_duration = Float(default=1.0, help="typical length of a epoch for given stimulus in seconds") + psth_resolution = Float(default=0.001, help="resultion (seconds) for generating PSTH") -#class ContrastTuning(DefaultSchema): +# class ContrastTuning(DefaultSchema): # stimulus_key = String(help='Key for the contrast tuning stimulus') # trial_duration = Float(default=0.25, help='typical length of a epoch for given stimulus in seconds') class Flashes(DefaultSchema): - stimulus_key = List(String, default=Flashes.known_stimulus_keys(), help='Key for the flash stimulus') - trial_duration = Float(default=0.25, help='typical length of a epoch for given stimulus in seconds') - psth_resolution = Float(default=0.001, help='resultion (seconds) for generating PSTH') + stimulus_key = List(String, default=Flashes.known_stimulus_keys(), help="Key for the flash stimulus") + trial_duration = Float(default=0.25, help="typical length of a epoch for given stimulus in seconds") + psth_resolution = Float(default=0.001, help="resultion (seconds) for generating PSTH") class ReceptiveFieldMapping(DefaultSchema): - stimulus_key = List(String, default=ReceptiveFieldMapping.known_stimulus_keys(), help='Key for the receptive field mapping stimulus') - trial_duration = Float(default=0.25, help='typical length of a epoch for given stimulus in seconds') - minimum_spike_count = Int(default=10, help='Minimum number of spikes for computing receptive field parameters') - mask_threshold = Float(default=1.0, help='Threshold (as fraction of peak) for computing receptive field mask') - stimulus_step_size = Float(default=10.0, help='Distance between stimulus locations in degrees') + stimulus_key = List( + String, default=ReceptiveFieldMapping.known_stimulus_keys(), help="Key for the receptive field mapping stimulus" + ) + trial_duration = Float(default=0.25, help="typical length of a epoch for given stimulus in seconds") + minimum_spike_count = Int(default=10, help="Minimum number of spikes for computing receptive field parameters") + mask_threshold = Float(default=1.0, help="Threshold (as fraction of peak) for computing receptive field mask") + stimulus_step_size = Float(default=10.0, help="Distance between stimulus locations in degrees") class InputParameters(ArgSchema): @@ -69,14 +74,12 @@ class InputParameters(ArgSchema): flashes = Nested(Flashes) receptive_field_mapping = Nested(ReceptiveFieldMapping) - input_session_nwb = String(required=True, help='Ecephys spiking nwb file for session') - output_file = String(required=True, help='Location for saving output file') + input_session_nwb = String(required=True, help="Ecephys spiking nwb file for session") + output_file = String(required=True, help="Location for saving output file") class OutputSchema(DefaultSchema): - input_parameters = Nested(InputParameters, - description=("Input parameters the module was run with"), - required=True) + input_parameters = Nested(InputParameters, description=("Input parameters the module was run with"), required=True) class OutputParameters(OutputSchema): diff --git a/allensdk/brain_observatory/ecephys/stimulus_analysis/dot_motion.py b/allensdk/brain_observatory/ecephys/stimulus_analysis/dot_motion.py index 1d154a11f9..56de7a77ad 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_analysis/dot_motion.py +++ b/allensdk/brain_observatory/ecephys/stimulus_analysis/dot_motion.py @@ -6,7 +6,7 @@ from .stimulus_analysis import StimulusAnalysis -warnings.simplefilter(action='ignore', category=FutureWarning) +warnings.simplefilter(action="ignore", category=FutureWarning) logger = logging.getLogger(__name__) @@ -33,7 +33,8 @@ class DotMotion(StimulusAnalysis): metrics_table_df = dm_analysis.metrics() """ - def __init__(self, ecephys_session, col_dir='Dir', col_speeds='Speed', trial_duration=1.0, **kwargs): + + def __init__(self, ecephys_session, col_dir="Dir", col_speeds="Speed", trial_duration=1.0, **kwargs): super(DotMotion, self).__init__(ecephys_session, trial_duration=trial_duration, **kwargs) self._dirvals = None @@ -45,14 +46,14 @@ def __init__(self, ecephys_session, col_dir='Dir', col_speeds='Speed', trial_dur self._col_speed = col_speeds if self._params is not None: - self._params = self._params['dot_motion'] - self._stimulus_key = self._params['stimulus_key'] - #else: + self._params = self._params["dot_motion"] + self._stimulus_key = self._params["stimulus_key"] + # else: # self._stimulus_key = 'motion_stimulus' @property def name(self): - return 'Dot Motion' + return "Dot Motion" @property def directions(self): @@ -67,7 +68,7 @@ def number_directions(self): self._get_stim_table_stats() return self._number_dir - + @property def speeds(self): if self._speedvals is None: @@ -84,52 +85,57 @@ def number_speeds(self): @property def known_spontaneous_keys(self): - return ['dot_motion', "spontaneous_activity"] + return ["dot_motion", "spontaneous_activity"] @property def null_condition(self): - """ Stimulus condition ID for null stimulus (not used, so set to -1) """ + """Stimulus condition ID for null stimulus (not used, so set to -1)""" return -1 @property def METRICS_COLUMNS(self): - return [('pref_speed_dm', np.float64), - ('pref_speed_multi_dm', bool), - ('pref_dir_dm', np.float64), - ('pref_dir_multi_dm', bool), - ('firing_rate_dm', np.float64), - ('fano_dm', np.float64), - ('time_to_peak_dm', np.float64), - ('lifetime_sparseness_dm', np.float64), - ('run_mod_dm', np.float64), - ('run_pval_dm', np.float64)] + return [ + ("pref_speed_dm", np.float64), + ("pref_speed_multi_dm", bool), + ("pref_dir_dm", np.float64), + ("pref_dir_multi_dm", bool), + ("firing_rate_dm", np.float64), + ("fano_dm", np.float64), + ("time_to_peak_dm", np.float64), + ("lifetime_sparseness_dm", np.float64), + ("run_mod_dm", np.float64), + ("run_pval_dm", np.float64), + ] @property def metrics(self): if self._metrics is None: - logger.info('Calculating metrics for ' + self.name) + logger.info("Calculating metrics for " + self.name) unit_ids = self.unit_ids metrics_df = self.empty_metrics_table() if len(self.stim_table) > 0: - metrics_df['pref_speed_dm'] = [self._get_pref_speed(unit) for unit in unit_ids] - metrics_df['pref_speed_multi_dm'] = [ + metrics_df["pref_speed_dm"] = [self._get_pref_speed(unit) for unit in unit_ids] + metrics_df["pref_speed_multi_dm"] = [ self._check_multiple_pref_conditions(unit_id, self._col_speed, self.speeds) for unit_id in unit_ids ] - metrics_df['pref_dir_dm'] = [self._get_pref_dir(unit) for unit in unit_ids] - metrics_df['pref_dir_multi_dm'] = [ - self._check_multiple_pref_conditions(unit_id, self._col_dir, self.directions) for unit_id in unit_ids + metrics_df["pref_dir_dm"] = [self._get_pref_dir(unit) for unit in unit_ids] + metrics_df["pref_dir_multi_dm"] = [ + self._check_multiple_pref_conditions(unit_id, self._col_dir, self.directions) + for unit_id in unit_ids + ] + metrics_df["firing_rate_dm"] = [self._get_overall_firing_rate(unit) for unit in unit_ids] + metrics_df["fano_dm"] = [ + self._get_fano_factor(unit, self._get_preferred_condition(unit)) for unit in unit_ids ] - metrics_df['firing_rate_dm'] = [self._get_overall_firing_rate(unit) for unit in unit_ids] - metrics_df['fano_dm'] = [self._get_fano_factor(unit, self._get_preferred_condition(unit)) - for unit in unit_ids] # metrics_df['speed_tuning_idx_dm'] = [self._get_speed_tuning_index(unit) for unit in unit_ids] - metrics_df['time_to_peak_dm'] = [self._get_time_to_peak(unit, self._get_preferred_condition(unit)) for - unit in unit_ids] - metrics_df['lifetime_sparseness_dm'] = [self._get_lifetime_sparseness(unit) for unit in unit_ids] - metrics_df.loc[:, ['run_pval_dm', 'run_mod_dm']] = \ - [self._get_running_modulation(unit, self._get_preferred_condition(unit)) for unit in unit_ids] - + metrics_df["time_to_peak_dm"] = [ + self._get_time_to_peak(unit, self._get_preferred_condition(unit)) for unit in unit_ids + ] + metrics_df["lifetime_sparseness_dm"] = [self._get_lifetime_sparseness(unit) for unit in unit_ids] + metrics_df.loc[:, ["run_pval_dm", "run_mod_dm"]] = [ + self._get_running_modulation(unit, self._get_preferred_condition(unit)) for unit in unit_ids + ] self._metrics = metrics_df @@ -137,20 +143,22 @@ def metrics(self): @classmethod def known_stimulus_keys(cls): - return ['motion_stimulus', 'dot_motion'] + return ["motion_stimulus", "dot_motion"] def _get_stim_table_stats(self): - """ Extract directions and speeds from the stimulus table """ - self._dirvals = np.sort(self.stimulus_conditions.loc[self.stimulus_conditions[self._col_dir] - != 'null'][self._col_dir].unique()) + """Extract directions and speeds from the stimulus table""" + self._dirvals = np.sort( + self.stimulus_conditions.loc[self.stimulus_conditions[self._col_dir] != "null"][self._col_dir].unique() + ) self._number_dir = len(self._dirvals) - self._speedvals = np.sort(self.stimulus_conditions.loc[self.stimulus_conditions[self._col_speed] - != 'null'][self._col_speed].unique()) + self._speedvals = np.sort( + self.stimulus_conditions.loc[self.stimulus_conditions[self._col_speed] != "null"][self._col_speed].unique() + ) self._number_speed = len(self._speedvals) def _get_pref_speed(self, unit_id): - """ Calculate the preferred speed condition for a given unit + """Calculate the preferred speed condition for a given unit Parameters ---------- @@ -163,12 +171,18 @@ def _get_pref_speed(self, unit_id): stimulus speed driving the maximal response """ # TODO: Most of the _get_pref_*() methods can be combined into one method and shared among the classes - similar_conditions = [self.stimulus_conditions.index[self.stimulus_conditions[self._col_speed] - == speed].tolist() for speed in self.speeds] + similar_conditions = [ + self.stimulus_conditions.index[self.stimulus_conditions[self._col_speed] == speed].tolist() + for speed in self.speeds + ] df = pd.DataFrame( index=self.speeds, - data={'spike_mean': [self.conditionwise_statistics.loc[unit_id].loc[condition_inds]['spike_mean'].mean() - for condition_inds in similar_conditions]} + data={ + "spike_mean": [ + self.conditionwise_statistics.loc[unit_id].loc[condition_inds]["spike_mean"].mean() + for condition_inds in similar_conditions + ] + }, ).rename_axis(self._col_speed) return df.idxmax().iloc[0] @@ -186,18 +200,24 @@ def _get_pref_dir(self, unit_id): pref_dir : float stimulus direction driving the maximal response """ - similar_conditions = [self.stimulus_conditions.index[self.stimulus_conditions[self._col_dir] - == direction].tolist() for direction in self.directions] + similar_conditions = [ + self.stimulus_conditions.index[self.stimulus_conditions[self._col_dir] == direction].tolist() + for direction in self.directions + ] df = pd.DataFrame( index=self.directions, - data={'spike_mean': [self.conditionwise_statistics.loc[unit_id].loc[condition_inds]['spike_mean'].mean() - for condition_inds in similar_conditions]} + data={ + "spike_mean": [ + self.conditionwise_statistics.loc[unit_id].loc[condition_inds]["spike_mean"].mean() + for condition_inds in similar_conditions + ] + }, ).rename_axis(self._col_dir) return df.idxmax().iloc[0] def _get_speed_tuning_index(self, unit_id): - """ Calculate the speed tuning for a given unit + """Calculate the speed tuning for a given unit SEE: https://github.com/AllenInstitute/ecephys_analysis_modules/blob/master/ecephys_analysis_modules/modules/tuning/tuning_speed.py diff --git a/allensdk/brain_observatory/ecephys/stimulus_analysis/drifting_gratings.py b/allensdk/brain_observatory/ecephys/stimulus_analysis/drifting_gratings.py index dc54c4d75f..22aefff25c 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_analysis/drifting_gratings.py +++ b/allensdk/brain_observatory/ecephys/stimulus_analysis/drifting_gratings.py @@ -11,7 +11,8 @@ from ...circle_plots import FanPlotter import warnings -warnings.simplefilter(action='ignore', category=FutureWarning) + +warnings.simplefilter(action="ignore", category=FutureWarning) logger = logging.getLogger(__name__) @@ -36,8 +37,16 @@ class DriftingGratings(StimulusAnalysis): metrics_table_df = dg_analysis.metrics() """ - def __init__(self, ecephys_session, col_ori='orientation', col_tf='temporal_frequency', col_contrast='contrast', - trial_duration=2.0, **kwargs): + + def __init__( + self, + ecephys_session, + col_ori="orientation", + col_tf="temporal_frequency", + col_contrast="contrast", + trial_duration=2.0, + **kwargs, + ): super(DriftingGratings, self).__init__(ecephys_session, trial_duration=trial_duration, **kwargs) self._metrics = None @@ -55,26 +64,25 @@ def __init__(self, ecephys_session, col_ori='orientation', col_tf='temporal_freq if self._params is not None: # TODO: Need to make sure - self._params = self._params.get('drifting_gratings', {}) - self._stimulus_key = self._params.get('stimulus_key', None) # Overwrites parent value with argvars + self._params = self._params.get("drifting_gratings", {}) + self._stimulus_key = self._params.get("stimulus_key", None) # Overwrites parent value with argvars else: self._params = {} self._stim_table_contrast = None - #stim_table = self.stim_table - #self._stim_table_contrast = stim_table[stim_table['stimulus_name'] == 'drifting_gratings_contrast'] - #self._stim_table = stim_table[stim_table['stimulus_name'] != 'drifting_gratings_contrast'] + # stim_table = self.stim_table + # self._stim_table_contrast = stim_table[stim_table['stimulus_name'] == 'drifting_gratings_contrast'] + # self._stim_table = stim_table[stim_table['stimulus_name'] != 'drifting_gratings_contrast'] self._conditionwise_statistics_contrast = None self._stimulus_conditions_contrast = None - @property def stim_table_contrast(self): if self._stim_table_contrast is None: stim_table = self.ecephys_session.stimulus_presentations - if 'drifting_gratings_contrast' in stim_table['stimulus_name'].unique(): - self._stim_table_contrast = stim_table[stim_table['stimulus_name'] == 'drifting_gratings_contrast'] + if "drifting_gratings_contrast" in stim_table["stimulus_name"].unique(): + self._stim_table_contrast = stim_table[stim_table["stimulus_name"] == "drifting_gratings_contrast"] else: self._stim_table_contrast = pd.DataFrame() @@ -82,11 +90,11 @@ def stim_table_contrast(self): @property def name(self): - return 'Drifting Gratings' + return "Drifting Gratings" @property def orivals(self): - """ Array of grating orientation conditions """ + """Array of grating orientation conditions""" if self._orivals is None: self._get_stim_table_stats() @@ -94,7 +102,7 @@ def orivals(self): @property def number_ori(self): - """ Number of grating orientation conditions """ + """Number of grating orientation conditions""" if self._number_ori is None: self._get_stim_table_stats() @@ -102,7 +110,7 @@ def number_ori(self): @property def tfvals(self): - """ Array of grating temporal frequency conditions """ + """Array of grating temporal frequency conditions""" if self._tfvals is None: self._get_stim_table_stats() @@ -110,7 +118,7 @@ def tfvals(self): @property def number_tf(self): - """ Number of grating temporal frequency conditions """ + """Number of grating temporal frequency conditions""" if self._tfvals is None: self._get_stim_table_stats() @@ -118,7 +126,7 @@ def number_tf(self): @property def contrastvals(self): - """ Array of grating temporal frequency conditions """ + """Array of grating temporal frequency conditions""" if self._contrastvals is None: self._get_stim_table_stats() @@ -126,7 +134,7 @@ def contrastvals(self): @property def number_contrast(self): - """ Number of grating temporal frequency conditions """ + """Number of grating temporal frequency conditions""" if self._number_contrast is None: self._get_stim_table_stats() @@ -134,12 +142,12 @@ def number_contrast(self): @property def null_condition(self): - """ Stimulus condition ID for null (blank) stimulus """ - return self.stimulus_conditions[self.stimulus_conditions[self._col_tf] == 'null'].index + """Stimulus condition ID for null (blank) stimulus""" + return self.stimulus_conditions[self.stimulus_conditions[self._col_tf] == "null"].index @property def stimulus_conditions_contrast(self): - """ Stimulus conditions for contrast stimulus """ + """Stimulus conditions for contrast stimulus""" if self._stimulus_conditions_contrast is None: # TODO: look into efficiency of using a table intersect instead. contrast_condition_list = self.stim_table_contrast.stimulus_condition_id.unique() @@ -152,67 +160,72 @@ def stimulus_conditions_contrast(self): @property def conditionwise_statistics_contrast(self): - """ Conditionwise statistics for contrast stimulus """ + """Conditionwise statistics for contrast stimulus""" if self._conditionwise_statistics_contrast is None: self._conditionwise_statistics_contrast = self.ecephys_session.conditionwise_spike_statistics( - self.stim_table_contrast.index.values, - self.unit_ids + self.stim_table_contrast.index.values, self.unit_ids ) return self._conditionwise_statistics_contrast @property def METRICS_COLUMNS(self): - return [('pref_ori_dg', np.float64), - ('pref_ori_multi_dg', bool), - ('pref_tf_dg', np.float64), - ('pref_tf_multi_dg', bool), - ('c50_dg', np.float64), - ('f1_f0_dg', np.float64), - ('mod_idx_dg', np.float64), - ('g_osi_dg', np.float64), - ('g_dsi_dg', np.float64), - ('firing_rate_dg', np.float64), - ('fano_dg', np.float64), - ('lifetime_sparseness_dg', np.float64), - ('run_pval_dg', np.float64), - ('run_mod_dg', np.float64)] + return [ + ("pref_ori_dg", np.float64), + ("pref_ori_multi_dg", bool), + ("pref_tf_dg", np.float64), + ("pref_tf_multi_dg", bool), + ("c50_dg", np.float64), + ("f1_f0_dg", np.float64), + ("mod_idx_dg", np.float64), + ("g_osi_dg", np.float64), + ("g_dsi_dg", np.float64), + ("firing_rate_dg", np.float64), + ("fano_dg", np.float64), + ("lifetime_sparseness_dg", np.float64), + ("run_pval_dg", np.float64), + ("run_mod_dg", np.float64), + ] @property def metrics(self): - if self._metrics is None: - logger.info('Calculating metrics for ' + self.name) + logger.info("Calculating metrics for " + self.name) unit_ids = self.unit_ids metrics_df = self.empty_metrics_table() if len(self.stim_table) > 0: - metrics_df['pref_ori_dg'] = [self._get_pref_ori(unit) for unit in unit_ids] - metrics_df['pref_ori_multi_dg'] = [ + metrics_df["pref_ori_dg"] = [self._get_pref_ori(unit) for unit in unit_ids] + metrics_df["pref_ori_multi_dg"] = [ self._check_multiple_pref_conditions(unit_id, self._col_ori, self.orivals) for unit_id in unit_ids ] - metrics_df['pref_tf_dg'] = [self._get_pref_tf(unit) for unit in unit_ids] - metrics_df['pref_tf_multi_dg'] = [ + metrics_df["pref_tf_dg"] = [self._get_pref_tf(unit) for unit in unit_ids] + metrics_df["pref_tf_multi_dg"] = [ self._check_multiple_pref_conditions(unit_id, self._col_tf, self.tfvals) for unit_id in unit_ids ] - metrics_df['f1_f0_dg'] = [self._get_f1_f0(unit, self._get_preferred_condition(unit)) - for unit in unit_ids] - metrics_df['mod_idx_dg'] = [self._get_modulation_index(unit, self._get_preferred_condition(unit)) - for unit in unit_ids] - metrics_df['g_osi_dg'] = [self._get_selectivity(unit, metrics_df.loc[unit]['pref_tf_dg'], 'osi') - for unit in unit_ids] - metrics_df['g_dsi_dg'] = [self._get_selectivity(unit, metrics_df.loc[unit]['pref_tf_dg'], 'dsi') - for unit in unit_ids] - metrics_df['firing_rate_dg'] = [self._get_overall_firing_rate(unit) for unit in unit_ids] - metrics_df['fano_dg'] = [self._get_fano_factor(unit, self._get_preferred_condition(unit)) - for unit in unit_ids] - metrics_df['lifetime_sparseness_dg'] = [self._get_lifetime_sparseness(unit) for unit in unit_ids] - metrics_df.loc[:, ['run_pval_dg', 'run_mod_dg']] = [ - self._get_running_modulation(unit, self._get_preferred_condition(unit)) for unit in unit_ids] + metrics_df["f1_f0_dg"] = [ + self._get_f1_f0(unit, self._get_preferred_condition(unit)) for unit in unit_ids + ] + metrics_df["mod_idx_dg"] = [ + self._get_modulation_index(unit, self._get_preferred_condition(unit)) for unit in unit_ids + ] + metrics_df["g_osi_dg"] = [ + self._get_selectivity(unit, metrics_df.loc[unit]["pref_tf_dg"], "osi") for unit in unit_ids + ] + metrics_df["g_dsi_dg"] = [ + self._get_selectivity(unit, metrics_df.loc[unit]["pref_tf_dg"], "dsi") for unit in unit_ids + ] + metrics_df["firing_rate_dg"] = [self._get_overall_firing_rate(unit) for unit in unit_ids] + metrics_df["fano_dg"] = [ + self._get_fano_factor(unit, self._get_preferred_condition(unit)) for unit in unit_ids + ] + metrics_df["lifetime_sparseness_dg"] = [self._get_lifetime_sparseness(unit) for unit in unit_ids] + metrics_df.loc[:, ["run_pval_dg", "run_mod_dg"]] = [ + self._get_running_modulation(unit, self._get_preferred_condition(unit)) for unit in unit_ids + ] if len(self.stim_table_contrast) > 0: - metrics_df['c50_dg'] = [self._get_c50(unit) for unit in unit_ids] - + metrics_df["c50_dg"] = [self._get_c50(unit) for unit in unit_ids] self._metrics = metrics_df @@ -220,24 +233,29 @@ def metrics(self): @classmethod def known_stimulus_keys(cls): - return ['drifting_gratings', 'drifting_gratings_75_repeats'] + return ["drifting_gratings", "drifting_gratings_75_repeats"] def _get_stim_table_stats(self): - """ Extract orientations and temporal frequencies from the stimulus table """ - self._orivals = np.sort(self.stimulus_conditions.loc[self.stimulus_conditions[self._col_ori] - != 'null'][self._col_ori].unique()) + """Extract orientations and temporal frequencies from the stimulus table""" + self._orivals = np.sort( + self.stimulus_conditions.loc[self.stimulus_conditions[self._col_ori] != "null"][self._col_ori].unique() + ) self._number_ori = len(self._orivals) - self._tfvals = np.sort(self.stimulus_conditions.loc[self.stimulus_conditions[self._col_tf] - != 'null'][self._col_tf].unique()) + self._tfvals = np.sort( + self.stimulus_conditions.loc[self.stimulus_conditions[self._col_tf] != "null"][self._col_tf].unique() + ) self._number_tf = len(self._tfvals) - self._contrastvals = np.sort(self.stimulus_conditions.loc[self.stimulus_conditions[self._col_contrast] - != 'null'][self._col_contrast].unique()) + self._contrastvals = np.sort( + self.stimulus_conditions.loc[self.stimulus_conditions[self._col_contrast] != "null"][ + self._col_contrast + ].unique() + ) self._number_contrast = len(self._contrastvals) def _get_pref_ori(self, unit_id): - """ Calculate the preferred orientation condition for a given unit + """Calculate the preferred orientation condition for a given unit Parameters ---------- @@ -250,18 +268,24 @@ def _get_pref_ori(self, unit_id): stimulus orientation driving the maximal response """ # TODO: Most of the _get_pref_*() methods can be combined into one method and shared among the classes - similar_conditions = [self.stimulus_conditions.index[self.stimulus_conditions[self._col_ori] == ori].tolist() - for ori in self.orivals] + similar_conditions = [ + self.stimulus_conditions.index[self.stimulus_conditions[self._col_ori] == ori].tolist() + for ori in self.orivals + ] df = pd.DataFrame( index=self.orivals, - data={'spike_mean': [self.conditionwise_statistics.loc[unit_id].loc[condition_inds]['spike_mean'].mean() - for condition_inds in similar_conditions]} + data={ + "spike_mean": [ + self.conditionwise_statistics.loc[unit_id].loc[condition_inds]["spike_mean"].mean() + for condition_inds in similar_conditions + ] + }, ).rename_axis(self._col_ori) return df.idxmax().iloc[0] def _get_pref_tf(self, unit_id): - """ Calculate the preferred temporal frequency condition for a given unit + """Calculate the preferred temporal frequency condition for a given unit Params: ------- @@ -273,18 +297,23 @@ def _get_pref_tf(self, unit_id): pref_tf : float stimulus temporal frequency driving the maximal response """ - similar_conditions = [self.stimulus_conditions.index[self.stimulus_conditions[self._col_tf] == tf].tolist() - for tf in self.tfvals] + similar_conditions = [ + self.stimulus_conditions.index[self.stimulus_conditions[self._col_tf] == tf].tolist() for tf in self.tfvals + ] df = pd.DataFrame( index=self.tfvals, - data={'spike_mean': [self.conditionwise_statistics.loc[unit_id].loc[condition_inds]['spike_mean'].mean() - for condition_inds in similar_conditions]} + data={ + "spike_mean": [ + self.conditionwise_statistics.loc[unit_id].loc[condition_inds]["spike_mean"].mean() + for condition_inds in similar_conditions + ] + }, ).rename_axis(self._col_tf) return df.idxmax().iloc[0] - def _get_selectivity(self, unit_id, pref_tf, selectivity_type='osi'): - """ Calculate the orientation or direction selectivity for a given unit + def _get_selectivity(self, unit_id, pref_tf, selectivity_type="osi"): + """Calculate the orientation or direction selectivity for a given unit Params: ------- @@ -297,25 +326,25 @@ def _get_selectivity(self, unit_id, pref_tf, selectivity_type='osi'): selectivity - orientation or direction selectivity value """ - orivals_rad = deg2rad(self.orivals).astype('complex128') + orivals_rad = deg2rad(self.orivals).astype("complex128") condition_inds = self.stimulus_conditions[self.stimulus_conditions[self._col_tf] == pref_tf].index.values df = self.conditionwise_statistics.loc[unit_id].loc[condition_inds] df = df.assign(ori=self.stimulus_conditions.loc[df.index.values][self._col_ori]) - df = df.sort_values(by=['ori']) # do not replace with self._col_ori unless we modify the line above + df = df.sort_values(by=["ori"]) # do not replace with self._col_ori unless we modify the line above - tuning = np.array(df['spike_mean'].values) + tuning = np.array(df["spike_mean"].values) - if selectivity_type == 'osi': + if selectivity_type == "osi": return osi(orivals_rad, tuning) - elif selectivity_type == 'dsi': + elif selectivity_type == "dsi": return dsi(orivals_rad, tuning) else: - warnings.warn(f'unkown selectivity function {selectivity_type}.') + warnings.warn(f"unkown selectivity function {selectivity_type}.") return np.nan def _get_f1_f0(self, unit_id, condition_id): - """ Calculate F1/F0 for a given unit + """Calculate F1/F0 for a given unit A measure of how tightly locked a unit's firing rate is to the cycles of a drifting grating @@ -329,22 +358,24 @@ def _get_f1_f0(self, unit_id, condition_id): f1_f0 - metric """ - presentation_ids = self.stim_table[self.stim_table['stimulus_condition_id'] == condition_id].index.values - + presentation_ids = self.stim_table[self.stim_table["stimulus_condition_id"] == condition_id].index.values + tf = self.stim_table.loc[presentation_ids[0]][self._col_tf] dataset = self.ecephys_session.presentationwise_spike_counts( bin_edges=np.arange(0, self.trial_duration, 0.001), stimulus_presentation_ids=presentation_ids, - unit_ids=[unit_id] - ).drop('unit_id') + unit_ids=[unit_id], + ).drop("unit_id") arr = np.squeeze(dataset.values) - trial_duration = dataset.time_relative_to_stimulus_onset.max() #TODO: If there a reason not to use self.trial_duration? + trial_duration = ( + dataset.time_relative_to_stimulus_onset.max() + ) # TODO: If there a reason not to use self.trial_duration? return f1_f0(arr, tf, trial_duration) def _get_modulation_index(self, unit_id, condition_id): - """ Calculate modulation index for a given unit. + """Calculate modulation index for a given unit. Parameters ---------- @@ -365,7 +396,7 @@ def _get_modulation_index(self, unit_id, condition_id): return modulation_index(data, tf, sample_rate) def _get_c50(self, unit_id): - """ Calculate C50 for a given unit. Only valid if the contrast tuning stimulus is present. Otherwise, + """Calculate C50 for a given unit. Only valid if the contrast tuning stimulus is present. Otherwise, return NaN value Parameters @@ -381,11 +412,18 @@ def _get_c50(self, unit_id): """ contrast_conditions = self.stim_table_contrast[ - (self.stim_table_contrast[self._col_ori] == self._get_pref_ori(unit_id))]['stimulus_condition_id'].unique() + (self.stim_table_contrast[self._col_ori] == self._get_pref_ori(unit_id)) + ]["stimulus_condition_id"].unique() # contrasts = self.stimulus_conditions_contrast.loc[contrast_conditions]['contrast'].values.astype('float') - contrasts = self.stimulus_conditions_contrast.loc[contrast_conditions][self._col_contrast].values.astype('float') - mean_responses = self.conditionwise_statistics_contrast.loc[unit_id].loc[contrast_conditions]['spike_mean'].values.astype('float') + contrasts = self.stimulus_conditions_contrast.loc[contrast_conditions][self._col_contrast].values.astype( + "float" + ) + mean_responses = ( + self.conditionwise_statistics_contrast.loc[unit_id] + .loc[contrast_conditions]["spike_mean"] + .values.astype("float") + ) return c50(contrasts, mean_responses) @@ -514,69 +552,81 @@ def _fit_tf_tuning(self, unit_id, pref_ori, pref_tf): ## VISUALIZATION ## def plot_raster(self, stimulus_condition_id, unit_id): - """ Plot raster for one condition and one unit """ + """Plot raster for one condition and one unit""" idx_tf = np.where(self.tfvals == self.stimulus_conditions.loc[stimulus_condition_id][self._col_tf])[0] idx_ori = np.where(self.orivals == self.stimulus_conditions.loc[stimulus_condition_id][self._col_ori])[0] - + if len(idx_tf) == len(idx_ori) == 1: - presentation_ids = self.presentationwise_statistics.xs(unit_id, level=1)[ - self.presentationwise_statistics.xs(unit_id, level=1)['stimulus_condition_id'] == stimulus_condition_id + self.presentationwise_statistics.xs(unit_id, level=1)["stimulus_condition_id"] == stimulus_condition_id ].index.values - + df = self.presentationwise_spike_times[ - (self.presentationwise_spike_times['stimulus_presentation_id'].isin(presentation_ids)) & - (self.presentationwise_spike_times['unit_id'] == unit_id)] - + (self.presentationwise_spike_times["stimulus_presentation_id"].isin(presentation_ids)) + & (self.presentationwise_spike_times["unit_id"] == unit_id) + ] + x = df.index.values - self.stim_table.loc[df.stimulus_presentation_id].start_time - _, y = np.unique(df.stimulus_presentation_id, return_inverse=True) - - plt.subplot(self.number_tf, self.number_ori, idx_tf*self.number_ori + idx_ori + 1) - plt.scatter(x, y, c='k', s=1, alpha=0.25) - plt.axis('off') + _, y = np.unique(df.stimulus_presentation_id, return_inverse=True) + + plt.subplot(self.number_tf, self.number_ori, idx_tf * self.number_ori + idx_ori + 1) + plt.scatter(x, y, c="k", s=1, alpha=0.25) + plt.axis("off") def plot_response_summary(self, unit_id, bar_thickness=0.25): - """ Plot the spike counts across conditions """ + """Plot the spike counts across conditions""" df = self.stimulus_conditions.drop(index=self.null_condition) - - df['tf_index'] = np.searchsorted(self.tfvals, df[self._col_tf].values) - df['ori_index'] = np.searchsorted(self.orivals, df[self._col_ori].values) - - cond_values = self.presentationwise_statistics.xs(unit_id, level=1)['stimulus_condition_id'] - - x = df.loc[cond_values.values]['tf_index'] + np.random.rand(cond_values.size) * bar_thickness - bar_thickness/2 - y = self.presentationwise_statistics.xs(unit_id, level=1)['spike_counts'] - c = df.loc[cond_values.values]['tf_index'] - + + df["tf_index"] = np.searchsorted(self.tfvals, df[self._col_tf].values) + df["ori_index"] = np.searchsorted(self.orivals, df[self._col_ori].values) + + cond_values = self.presentationwise_statistics.xs(unit_id, level=1)["stimulus_condition_id"] + + x = ( + df.loc[cond_values.values]["tf_index"] + + np.random.rand(cond_values.size) * bar_thickness + - bar_thickness / 2 + ) + y = self.presentationwise_statistics.xs(unit_id, level=1)["spike_counts"] + c = df.loc[cond_values.values]["tf_index"] + plt.subplot(2, 1, 1) - plt.scatter(y, x, c=c, alpha=0.5, cmap='Purples', vmin=-5) + plt.scatter(y, x, c=c, alpha=0.5, cmap="Purples", vmin=-5) locs, labels = plt.yticks(ticks=np.arange(self.number_tf), labels=self.tfvals) - plt.ylabel('Temporal frequency') - plt.xlabel('Spikes per trial') + plt.ylabel("Temporal frequency") + plt.xlabel("Spikes per trial") plt.ylim([self.number_tf, -1]) - x = df.loc[cond_values.values]['ori_index'] + np.random.rand(cond_values.size) * bar_thickness - bar_thickness/2 - y = self.presentationwise_statistics.xs(unit_id, level=1)['spike_counts'] - c = df.loc[cond_values.values]['ori_index'] - + x = ( + df.loc[cond_values.values]["ori_index"] + + np.random.rand(cond_values.size) * bar_thickness + - bar_thickness / 2 + ) + y = self.presentationwise_statistics.xs(unit_id, level=1)["spike_counts"] + c = df.loc[cond_values.values]["ori_index"] + plt.subplot(2, 1, 2) - plt.scatter(x, y, c=c, alpha=0.5, cmap='Spectral') + plt.scatter(x, y, c=c, alpha=0.5, cmap="Spectral") locs, labels = plt.xticks(ticks=np.arange(self.number_ori), labels=self.orivals) - plt.xlabel('Orientation') - plt.ylabel('Spikes per trial') + plt.xlabel("Orientation") + plt.ylabel("Spikes per trial") def make_star_plot(self, unit_id): - """ Make a 2P-style Star Plot based on presentationwise spike counts""" - angle_data = self.stimulus_conditions.loc[self.presentationwise_statistics.xs(unit_id, level=1)['stimulus_condition_id']][self._col_ori].values - r_data = self.stimulus_conditions.loc[self.presentationwise_statistics.xs(unit_id, level=1)['stimulus_condition_id']][self._col_tf].values - data = self.presentationwise_statistics.xs(unit_id, level=1)['spike_counts'].values - - null_trials = np.where(angle_data == 'null')[0] - + """Make a 2P-style Star Plot based on presentationwise spike counts""" + angle_data = self.stimulus_conditions.loc[ + self.presentationwise_statistics.xs(unit_id, level=1)["stimulus_condition_id"] + ][self._col_ori].values + r_data = self.stimulus_conditions.loc[ + self.presentationwise_statistics.xs(unit_id, level=1)["stimulus_condition_id"] + ][self._col_tf].values + data = self.presentationwise_statistics.xs(unit_id, level=1)["spike_counts"].values + + null_trials = np.where(angle_data == "null")[0] + angle_data = np.delete(angle_data, null_trials) r_data = np.delete(r_data, null_trials) data = np.delete(data, null_trials) - + cmin = np.min(data) cmax = np.max(data) @@ -584,8 +634,8 @@ def make_star_plot(self, unit_id): fp.plot(r_data=r_data, angle_data=angle_data, data=data, clim=[cmin, cmax]) fp.show_axes(closed=False) plt.ylim([-5, 5]) - plt.axis('equal') - plt.axis('off') + plt.axis("equal") + plt.axis("off") ### General functions ### @@ -594,23 +644,23 @@ def _gauss_function(x, a, x0, sigma): fit gaussian function at log scale good for fitting band pass, not good at low pass or high pass """ - return a*np.exp(-(x-x0)**2/(2*sigma**2)) + return a * np.exp(-((x - x0) ** 2) / (2 * sigma**2)) def _exp_function(x, a, b, c): - return a*np.exp(-b*x)+c + return a * np.exp(-b * x) + c def _contrast_curve(x, b, c, d, e): """Difference of gaussian. - fit sigmoid function at log scale - not good for fitting band pass - - b: hill slope - - c: min response - - d: max response - - e: EC50 + fit sigmoid function at log scale + not good for fitting band pass + - b: hill slope + - c: min response + - d: max response + - e: EC50 """ - return c+(d-c)/(1+np.exp(b*(np.log(x)-np.log(e)))) + return c + (d - c) / (1 + np.exp(b * (np.log(x) - np.log(e)))) def c50(contrasts, responses): @@ -629,7 +679,7 @@ def c50(contrasts, responses): c50 : float """ if contrasts.size == 0 or contrasts.size != responses.size: - warnings.warn('the contrasts and responses arrays must be of the same length') + warnings.warn("the contrasts and responses arrays must be of the same length") return np.nan try: @@ -639,10 +689,10 @@ def c50(contrasts, responses): except RuntimeError as e: warnings.warn(str(e)) return np.nan - + # Create the constrast curve using the optimized parameters, get the halfway range point on the curve # resids = responses - contrast_curve(contrasts.astype('float'), *fitCoefs) - X = np.linspace(min(contrasts)*0.9, max(contrasts)*1.1, 256) # + X = np.linspace(min(contrasts) * 0.9, max(contrasts) * 1.1, 256) # y_fit = _contrast_curve(X, *fitCoefs) y_middle = (np.max(y_fit) - np.min(y_fit)) / 2 + np.min(y_fit) @@ -657,7 +707,7 @@ def c50(contrasts, responses): except IndexError as e: warnings.warn(str(e)) return np.nan - + return c50 @@ -692,24 +742,24 @@ def f1_f0(arr, tf, trial_duration): # can occur if temp-freq x trial duration is greater than the total trial duration return np.nan - arr = arr[:, :cycles_per_trial*bins_per_cycle].reshape((num_trials, cycles_per_trial, bins_per_cycle)) + arr = arr[:, : cycles_per_trial * bins_per_cycle].reshape((num_trials, cycles_per_trial, bins_per_cycle)) avg_rate = np.mean(arr, 1) - AMP = 2*np.abs(fft(avg_rate, bins_per_cycle)) / bins_per_cycle + AMP = 2 * np.abs(fft(avg_rate, bins_per_cycle)) / bins_per_cycle - f0 = 0.5*AMP[:, 0] + f0 = 0.5 * AMP[:, 0] f1 = AMP[:, 1] selection = f0 > 0.0 if not np.any(selection): # No spikes found return np.nan - return np.nanmean(f1[selection]/f0[selection]) + return np.nanmean(f1[selection] / f0[selection]) def modulation_index(response_psth, tf, sample_rate): """Depth of modulation by each cycle of a drifting grating; similar to F1/F0 - ref: Matteucci et al. (2019) Nonlinear processing of shape information + ref: Matteucci et al. (2019) Nonlinear processing of shape information in rat lateral extrastriate cortex. J Neurosci 39: 1649-1670 Parameters @@ -728,7 +778,7 @@ def modulation_index(response_psth, tf, sample_rate): """ if response_psth.size == 0: - warnings.warn('response_psth is empty') + warnings.warn("response_psth is empty") return np.nan f, psd = signal.welch(response_psth, fs=sample_rate, nperseg=1024) # get freqs. and power spectral density @@ -739,9 +789,10 @@ def modulation_index(response_psth, tf, sample_rate): tf_index = np.searchsorted(f, tf) if not 0 <= tf_index < psd.size: - warnings.warn('specified temporal frequency is not within the singals sampling range. Please adjust tf and/or' - 'sample_rate parameters.') + warnings.warn( + "specified temporal frequency is not within the singals sampling range. Please adjust tf and/or" + "sample_rate parameters." + ) return np.nan - return abs((psd[tf_index] - np.mean(psd))/np.sqrt(np.mean(psd**2)- mean_psd**2)) - + return abs((psd[tf_index] - np.mean(psd)) / np.sqrt(np.mean(psd**2) - mean_psd**2)) diff --git a/allensdk/brain_observatory/ecephys/stimulus_analysis/flashes.py b/allensdk/brain_observatory/ecephys/stimulus_analysis/flashes.py index 2841ef65e3..a200f35a82 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_analysis/flashes.py +++ b/allensdk/brain_observatory/ecephys/stimulus_analysis/flashes.py @@ -7,7 +7,8 @@ from .stimulus_analysis import StimulusAnalysis import warnings -warnings.simplefilter(action='ignore', category=FutureWarning) + +warnings.simplefilter(action="ignore", category=FutureWarning) logger = logging.getLogger(__name__) @@ -33,7 +34,7 @@ class Flashes(StimulusAnalysis): """ - def __init__(self, ecephys_session, col_color='color', trial_duration=0.25, **kwargs): + def __init__(self, ecephys_session, col_color="color", trial_duration=0.25, **kwargs): super(Flashes, self).__init__(ecephys_session, trial_duration=trial_duration, **kwargs) self._metrics = None @@ -41,18 +42,18 @@ def __init__(self, ecephys_session, col_color='color', trial_duration=0.25, **kw self._col_color = col_color if self._params is not None: - self._params = self._params.get('flashes', {}) - self._stimulus_key = self._params.get('stimulus_key', None) # Overwrites parent value with argvars + self._params = self._params.get("flashes", {}) + self._stimulus_key = self._params.get("stimulus_key", None) # Overwrites parent value with argvars else: self._params = {} @property def name(self): - return 'Flashes' + return "Flashes" @property def colors(self): - """ Array of 'color' conditions (black vs. white flash) """ + """Array of 'color' conditions (black vs. white flash)""" if self._colors is None: self._get_stim_table_stats() @@ -60,48 +61,54 @@ def colors(self): @property def number_colors(self): - """ Number of 'color' conditions (black vs. white flash) """ + """Number of 'color' conditions (black vs. white flash)""" if self._colors is None: self._get_stim_table_stats() return len(self._colors) - + @property def null_condition(self): - """ Stimulus condition ID for null stimulus (not used, so set to -1) """ + """Stimulus condition ID for null stimulus (not used, so set to -1)""" # TODO: If null_condition is not used remove it, parent should have it set to 1 return -1 - + @property def METRICS_COLUMNS(self): - return [('on_off_ratio_fl', np.float64), - ('sustained_idx_fl', np.float64), - ('firing_rate_fl', np.float64), - ('time_to_peak_fl', np.float64), - ('fano_fl', np.float64), - ('lifetime_sparseness_fl', np.float64), - ('run_pval_fl', np.float64), - ('run_mod_fl', np.float64)] + return [ + ("on_off_ratio_fl", np.float64), + ("sustained_idx_fl", np.float64), + ("firing_rate_fl", np.float64), + ("time_to_peak_fl", np.float64), + ("fano_fl", np.float64), + ("lifetime_sparseness_fl", np.float64), + ("run_pval_fl", np.float64), + ("run_mod_fl", np.float64), + ] @property def metrics(self): if self._metrics is None: - logger.info('Calculating metrics for ' + self.name) + logger.info("Calculating metrics for " + self.name) unit_ids = self.unit_ids metrics_df = self.empty_metrics_table() - if len(self. stim_table) > 0: - metrics_df['on_off_ratio_fl'] = [self._get_on_off_ratio(unit) for unit in unit_ids] - metrics_df['sustained_idx_fl'] = [self._get_sustained_index(unit, self._get_preferred_condition(unit)) - for unit in unit_ids] - metrics_df['firing_rate_fl'] = [self._get_overall_firing_rate(unit) for unit in unit_ids] - metrics_df['time_to_peak_fl'] = [self._get_time_to_peak(unit, self._get_preferred_condition(unit)) - for unit in unit_ids] - metrics_df['fano_fl'] = [self._get_fano_factor(unit, self._get_preferred_condition(unit)) - for unit in unit_ids] - metrics_df['lifetime_sparseness_fl'] = [self._get_lifetime_sparseness(unit) for unit in unit_ids] - metrics_df.loc[:, ['run_pval_fl', 'run_mod_fl']] = [ - self._get_running_modulation(unit, self._get_preferred_condition(unit)) for unit in unit_ids] + if len(self.stim_table) > 0: + metrics_df["on_off_ratio_fl"] = [self._get_on_off_ratio(unit) for unit in unit_ids] + metrics_df["sustained_idx_fl"] = [ + self._get_sustained_index(unit, self._get_preferred_condition(unit)) for unit in unit_ids + ] + metrics_df["firing_rate_fl"] = [self._get_overall_firing_rate(unit) for unit in unit_ids] + metrics_df["time_to_peak_fl"] = [ + self._get_time_to_peak(unit, self._get_preferred_condition(unit)) for unit in unit_ids + ] + metrics_df["fano_fl"] = [ + self._get_fano_factor(unit, self._get_preferred_condition(unit)) for unit in unit_ids + ] + metrics_df["lifetime_sparseness_fl"] = [self._get_lifetime_sparseness(unit) for unit in unit_ids] + metrics_df.loc[:, ["run_pval_fl", "run_mod_fl"]] = [ + self._get_running_modulation(unit, self._get_preferred_condition(unit)) for unit in unit_ids + ] self._metrics = metrics_df @@ -115,7 +122,7 @@ def _find_stimulus_key(self, stim_table): """ known_keys_lc = [k.lower() for k in self.__class__.known_stimulus_keys()] - for table_key in stim_table['stimulus_name'].unique(): + for table_key in stim_table["stimulus_name"].unique(): table_key_lc = table_key.lower() for known_key in known_keys_lc: if table_key_lc.startswith(known_key): @@ -126,15 +133,16 @@ def _find_stimulus_key(self, stim_table): @classmethod def known_stimulus_keys(cls): - return ['flash', 'flashes'] + return ["flash", "flashes"] def _get_stim_table_stats(self): - """ Extract colors from the stimulus table """ - self._colors = np.sort(self.stimulus_conditions.loc[self.stimulus_conditions[self._col_color] - != 'null'][self._col_color].unique()) + """Extract colors from the stimulus table""" + self._colors = np.sort( + self.stimulus_conditions.loc[self.stimulus_conditions[self._col_color] != "null"][self._col_color].unique() + ) def _get_sustained_index(self, unit_id, condition_id): - """ Calculate the sustained index for a given unit, a measure of the transience of the flash response. + """Calculate the sustained index for a given unit, a measure of the transience of the flash response. Parameters ---------- @@ -149,7 +157,7 @@ def _get_sustained_index(self, unit_id, condition_id): A cell that first continuously throughout the flash will have a sustained index closer to 1 """ psth = self.conditionwise_psth.sel(unit_id=unit_id, stimulus_condition_id=condition_id).data - return np.mean(psth)/np.amax(psth) + return np.mean(psth) / np.amax(psth) def _get_on_off_ratio(self, unit_id): """Gets the ratio of mean spikes for on-stimuli vs off stimuli. @@ -166,8 +174,8 @@ def _get_on_off_ratio(self, unit_id): on_condition_id = self.stimulus_conditions[self.stimulus_conditions[self._col_color] == 1.0].index.values off_condition_id = self.stimulus_conditions[self.stimulus_conditions[self._col_color] == -1.0].index.values - on_mean_spikes = self.conditionwise_statistics.loc[unit_id].loc[on_condition_id]['spike_mean'].values - off_mean_spikes = self.conditionwise_statistics.loc[unit_id].loc[off_condition_id]['spike_mean'].values + on_mean_spikes = self.conditionwise_statistics.loc[unit_id].loc[on_condition_id]["spike_mean"].values + off_mean_spikes = self.conditionwise_statistics.loc[unit_id].loc[off_condition_id]["spike_mean"].values if len(on_mean_spikes) == 0 or len(off_mean_spikes) == 0: return np.nan @@ -179,39 +187,36 @@ def _get_on_off_ratio(self, unit_id): ## VISUALIZATION ## def plot_raster(self, stimulus_condition_id, unit_id): - - """ Plot raster for one condition and one unit """ + """Plot raster for one condition and one unit""" idx_color = np.where(self.colors == self.stimulus_conditions.loc[stimulus_condition_id][self._col_color])[0] if len(idx_color) == 1: - presentation_ids = self.presentationwise_statistics.xs(unit_id, level=1)[ - self.presentationwise_statistics.xs(unit_id, level=1)['stimulus_condition_id'] - == stimulus_condition_id].index.values - + self.presentationwise_statistics.xs(unit_id, level=1)["stimulus_condition_id"] == stimulus_condition_id + ].index.values + df = self.presentationwise_spike_times[ - (self.presentationwise_spike_times['stimulus_presentation_id'].isin(presentation_ids)) & - (self.presentationwise_spike_times['unit_id'] == unit_id) + (self.presentationwise_spike_times["stimulus_presentation_id"].isin(presentation_ids)) + & (self.presentationwise_spike_times["unit_id"] == unit_id) ] - + x = df.index.values - self.stim_table.loc[df.stimulus_presentation_id].start_time - _, y = np.unique(df.stimulus_presentation_id, return_inverse=True) - + _, y = np.unique(df.stimulus_presentation_id, return_inverse=True) + plt.subplot(self.number_colors, 1, idx_color + 1) - plt.scatter(x, y, c='k', s=1, alpha=0.25) - plt.axis('off') + plt.scatter(x, y, c="k", s=1, alpha=0.25) + plt.axis("off") def plot_response(self, unit_id): - """ Plot a histogram for the two conditions """ - plot_colors = ('darkslateblue', 'grey') + """Plot a histogram for the two conditions""" + plot_colors = ("darkslateblue", "grey") for idx, color in enumerate(self.colors): + condition_id = self.stimulus_conditions[self.stimulus_conditions["color"] == color].index.values[0] - condition_id = self.stimulus_conditions[self.stimulus_conditions['color'] == color].index.values[0] - psth = self.conditionwise_psth.sel(unit_id=unit_id, stimulus_condition_id=condition_id).values - - plt.bar(np.arange(len(psth))-0.5, psth, color=plot_colors[idx], alpha=0.5, width=1.0) + + plt.bar(np.arange(len(psth)) - 0.5, psth, color=plot_colors[idx], alpha=0.5, width=1.0) plt.step(np.arange(len(psth)), psth, color=plot_colors[idx]) - plt.axis('off') + plt.axis("off") diff --git a/allensdk/brain_observatory/ecephys/stimulus_analysis/natural_movies.py b/allensdk/brain_observatory/ecephys/stimulus_analysis/natural_movies.py index 440ea77c87..fc8b729f8d 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_analysis/natural_movies.py +++ b/allensdk/brain_observatory/ecephys/stimulus_analysis/natural_movies.py @@ -6,7 +6,8 @@ from .stimulus_analysis import StimulusAnalysis import warnings -warnings.simplefilter(action='ignore', category=FutureWarning) + +warnings.simplefilter(action="ignore", category=FutureWarning) logger = logging.getLogger(__name__) @@ -39,41 +40,44 @@ def __init__(self, ecephys_session, trial_duration=None, **kwargs): self._metrics = None if self._params is not None: - self._params = self._params['natural_movies'] - self._stimulus_key = self._params['stimulus_key'] - #else: + self._params = self._params["natural_movies"] + self._stimulus_key = self._params["stimulus_key"] + # else: # self._stimulus_key = 'natural_movies' @property def name(self): - return 'Natural Movies' + return "Natural Movies" @property def null_condition(self): return -1 - + @property def METRICS_COLUMNS(self): - return [('fano_nm', np.uint64), - ('firing_rate_nm', np.float64), - ('lifetime_sparseness_nm', np.float64), - ('run_pval_ns', np.float64), - ('run_mod_ns', np.float64)] + return [ + ("fano_nm", np.uint64), + ("firing_rate_nm", np.float64), + ("lifetime_sparseness_nm", np.float64), + ("run_pval_ns", np.float64), + ("run_mod_ns", np.float64), + ] @property def metrics(self): if self._metrics is None: - logger.info('Calculating metrics for ' + self.name) + logger.info("Calculating metrics for " + self.name) unit_ids = self.unit_ids metrics_df = self.empty_metrics_table() - metrics_df['fano_nm'] = [self._get_fano_factor(unit, self._get_preferred_condition(unit)) - for unit in unit_ids] - metrics_df['firing_rate_nm'] = [self._get_overall_firing_rate(unit) for unit in unit_ids] - metrics_df['lifetime_sparseness_nm'] = [self._get_lifetime_sparseness(unit) for unit in unit_ids] + metrics_df["fano_nm"] = [ + self._get_fano_factor(unit, self._get_preferred_condition(unit)) for unit in unit_ids + ] + metrics_df["firing_rate_nm"] = [self._get_overall_firing_rate(unit) for unit in unit_ids] + metrics_df["lifetime_sparseness_nm"] = [self._get_lifetime_sparseness(unit) for unit in unit_ids] run_vals = [self._get_running_modulation(unit, self._get_preferred_condition(unit)) for unit in unit_ids] - metrics_df['run_pval_nm'] = [rv[0] for rv in run_vals] - metrics_df['run_mod_nm'] = [rv[1] for rv in run_vals] + metrics_df["run_pval_nm"] = [rv[0] for rv in run_vals] + metrics_df["run_mod_nm"] = [rv[1] for rv in run_vals] self._metrics = metrics_df @@ -81,7 +85,7 @@ def metrics(self): @classmethod def known_stimulus_keys(cls): - return ['natural_movies', 'natural_movie_1', 'natural_movie_3'] + return ["natural_movies", "natural_movie_1", "natural_movie_3"] def _get_stim_table_stats(self): pass diff --git a/allensdk/brain_observatory/ecephys/stimulus_analysis/natural_scenes.py b/allensdk/brain_observatory/ecephys/stimulus_analysis/natural_scenes.py index 5c2014f9e4..c7e0ba5e74 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_analysis/natural_scenes.py +++ b/allensdk/brain_observatory/ecephys/stimulus_analysis/natural_scenes.py @@ -5,7 +5,7 @@ from .stimulus_analysis import StimulusAnalysis -warnings.simplefilter(action='ignore', category=FutureWarning) +warnings.simplefilter(action="ignore", category=FutureWarning) logger = logging.getLogger(__name__) @@ -31,7 +31,7 @@ class NaturalScenes(StimulusAnalysis): """ - def __init__(self, ecephys_session, col_image='frame', trial_duration=0.25, **kwargs): + def __init__(self, ecephys_session, col_image="frame", trial_duration=0.25, **kwargs): super(NaturalScenes, self).__init__(ecephys_session, trial_duration=trial_duration, **kwargs) self._images = None @@ -46,18 +46,18 @@ def __init__(self, ecephys_session, col_image='frame', trial_duration=0.25, **kw self._col_image = col_image if self._params is not None: - self._params = self._params.get('natural_scenes', {}) - self._stimulus_key = self._params.get('stimulus_key', None) # Overwrites parent value with argvars + self._params = self._params.get("natural_scenes", {}) + self._stimulus_key = self._params.get("stimulus_key", None) # Overwrites parent value with argvars else: self._params = {} @property def name(self): - return 'Natural Scenes' + return "Natural Scenes" @property def images(self): - """ Array of iamge labels """ + """Array of iamge labels""" if self._images is None: self._get_stim_table_stats() @@ -77,7 +77,7 @@ def frames(self): @property def number_images(self): - """ Number of images shown """ + """Number of images shown""" if self._images is None: self._get_stim_table_stats() @@ -85,7 +85,7 @@ def number_images(self): @property def number_nonblank(self): - """ Number of images shown (excluding blank condition) """ + """Number of images shown (excluding blank condition)""" if self._number_nonblank is None: self._get_stim_table_stats() @@ -93,45 +93,50 @@ def number_nonblank(self): @property def null_condition(self): - """ Stimulus condition ID for null (blank) stimulus """ + """Stimulus condition ID for null (blank) stimulus""" return self.stimulus_conditions[self.stimulus_conditions[self._col_image] == -1].index @property def METRICS_COLUMNS(self): - return [('pref_image_ns', np.uint64), - ('image_selectivity_ns', np.float64), - ('firing_rate_ns', np.float64), - ('fano_ns', np.float64), - ('time_to_peak_ns', np.float64), - ('lifetime_sparseness_ns', np.float64), - ('run_pval_ns', np.float64), - ('run_mod_ns', np.float64)] + return [ + ("pref_image_ns", np.uint64), + ("image_selectivity_ns", np.float64), + ("firing_rate_ns", np.float64), + ("fano_ns", np.float64), + ("time_to_peak_ns", np.float64), + ("lifetime_sparseness_ns", np.float64), + ("run_pval_ns", np.float64), + ("run_mod_ns", np.float64), + ] @property def metrics(self): if self._metrics is None: - logger.info('Calculating metrics for ' + self.name) + logger.info("Calculating metrics for " + self.name) unit_ids = self.unit_ids metrics_df = self.empty_metrics_table() if len(self.stim_table) > 0: - logger.info('Calculating metrics for ' + self.name) + logger.info("Calculating metrics for " + self.name) - metrics_df['pref_image_ns'] = [self._get_preferred_condition(unit) for unit in unit_ids] - metrics_df['pref_images_multi_ns'] = [ + metrics_df["pref_image_ns"] = [self._get_preferred_condition(unit) for unit in unit_ids] + metrics_df["pref_images_multi_ns"] = [ self._check_multiple_pref_conditions(unit_id, self._col_image, self.images_nonblank) for unit_id in unit_ids ] - metrics_df['image_selectivity_ns'] = [self._get_image_selectivity(unit) for unit in unit_ids] - metrics_df['firing_rate_ns'] = [self._get_overall_firing_rate(unit) for unit in unit_ids] - metrics_df['fano_ns'] = [self._get_fano_factor(unit, self._get_preferred_condition(unit)) - for unit in unit_ids] - metrics_df['time_to_peak_ns'] = [self._get_time_to_peak(unit, self._get_preferred_condition(unit)) - for unit in unit_ids] - metrics_df['lifetime_sparseness_ns'] = [self._get_lifetime_sparseness(unit) for unit in unit_ids] - metrics_df.loc[:, ['run_pval_ns', 'run_mod_ns']] = [ - self._get_running_modulation(unit, self._get_preferred_condition(unit)) for unit in unit_ids] + metrics_df["image_selectivity_ns"] = [self._get_image_selectivity(unit) for unit in unit_ids] + metrics_df["firing_rate_ns"] = [self._get_overall_firing_rate(unit) for unit in unit_ids] + metrics_df["fano_ns"] = [ + self._get_fano_factor(unit, self._get_preferred_condition(unit)) for unit in unit_ids + ] + metrics_df["time_to_peak_ns"] = [ + self._get_time_to_peak(unit, self._get_preferred_condition(unit)) for unit in unit_ids + ] + metrics_df["lifetime_sparseness_ns"] = [self._get_lifetime_sparseness(unit) for unit in unit_ids] + metrics_df.loc[:, ["run_pval_ns", "run_mod_ns"]] = [ + self._get_running_modulation(unit, self._get_preferred_condition(unit)) for unit in unit_ids + ] self._metrics = metrics_df @@ -139,20 +144,20 @@ def metrics(self): @classmethod def known_stimulus_keys(cls): - return ['natural_scenes', 'Natural_Images', 'Natural Images'] + return ["natural_scenes", "Natural_Images", "Natural Images"] def _get_stim_table_stats(self): - """ Extract image labels from the stimulus table """ + """Extract image labels from the stimulus table""" self._images = np.sort(self.stimulus_conditions[self._col_image].unique()).astype(np.int64) self._number_images = len(self._images) self._images_nonblank = self._images[self._images >= 0] self._number_nonblank = len(self._images_nonblank) def _get_image_selectivity(self, unit_id, num_steps=1000): - """ Calculate the image selectivity for a given unit using spike means at every image""" + """Calculate the image selectivity for a given unit using spike means at every image""" unit_stats = self.conditionwise_statistics.loc[unit_id].drop(index=self.null_condition) - return image_selectivity(unit_stats['spike_mean'].values, num_steps=num_steps) + return image_selectivity(unit_stats["spike_mean"].values, num_steps=num_steps) def image_selectivity(spike_means, num_steps=1000): @@ -185,7 +190,7 @@ def image_selectivity(spike_means, num_steps=1000): return 0.0 j = np.arange(num_steps) - thresh = fmin + j*((fmax - fmin) / num_steps) + thresh = fmin + j * ((fmax - fmin) / num_steps) rtj = [np.mean(spike_means > t) for t in thresh] return 1 - (2 * np.mean(rtj)) diff --git a/allensdk/brain_observatory/ecephys/stimulus_analysis/receptive_field_mapping.py b/allensdk/brain_observatory/ecephys/stimulus_analysis/receptive_field_mapping.py index 57ab351d6c..0e2cd2e202 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_analysis/receptive_field_mapping.py +++ b/allensdk/brain_observatory/ecephys/stimulus_analysis/receptive_field_mapping.py @@ -52,11 +52,9 @@ def __init__( trial_duration=0.25, minimum_spike_count=10.0, mask_threshold=0.5, - **kwargs + **kwargs, ): - super(ReceptiveFieldMapping, self).__init__( - ecephys_session, trial_duration=trial_duration, **kwargs - ) + super(ReceptiveFieldMapping, self).__init__(ecephys_session, trial_duration=trial_duration, **kwargs) self._pos_x = None self._pos_y = None @@ -103,9 +101,7 @@ def number_azimuths(self): if self._pos_x is None: self._get_stim_table_stats() - return len( - self._pos_y - ) # TODO: Save this instead of calculating every time. + return len(self._pos_y) # TODO: Save this instead of calculating every time. @property def null_condition(self): @@ -121,20 +117,14 @@ def receptive_fields(self): if self._rf_matrix is None: bin_edges = np.linspace(0, 0.249, 3) - self.stim_table.loc[:, self._col_pos_y] = ( - 40.0 - self.stim_table[self._col_pos_y] - ) - presentationwise_response_matrix = ( - self.ecephys_session.presentationwise_spike_counts( - bin_edges=bin_edges, - stimulus_presentation_ids=self.stim_table.index.values, - unit_ids=self.unit_ids, - ) + self.stim_table.loc[:, self._col_pos_y] = 40.0 - self.stim_table[self._col_pos_y] + presentationwise_response_matrix = self.ecephys_session.presentationwise_spike_counts( + bin_edges=bin_edges, + stimulus_presentation_ids=self.stim_table.index.values, + unit_ids=self.unit_ids, ) - self._rf_matrix = self._response_by_stimulus_position( - presentationwise_response_matrix, self.stim_table - ) + self._rf_matrix = self._response_by_stimulus_position(presentationwise_response_matrix, self.stim_table) return self._rf_matrix @@ -176,29 +166,16 @@ def metrics(self): "on_screen_rf", ], ] = [self._get_rf_stats(unit) for unit in unit_ids] - metrics_df["firing_rate_rf"] = [ - self._get_overall_firing_rate(unit) for unit in unit_ids - ] + metrics_df["firing_rate_rf"] = [self._get_overall_firing_rate(unit) for unit in unit_ids] metrics_df["fano_rf"] = [ - self._get_fano_factor( - unit, self._get_preferred_condition(unit) - ) - for unit in unit_ids + self._get_fano_factor(unit, self._get_preferred_condition(unit)) for unit in unit_ids ] metrics_df["time_to_peak_rf"] = [ - self._get_time_to_peak( - unit, self._get_preferred_condition(unit) - ) - for unit in unit_ids - ] - metrics_df["lifetime_sparseness_rf"] = [ - self._get_lifetime_sparseness(unit) for unit in unit_ids + self._get_time_to_peak(unit, self._get_preferred_condition(unit)) for unit in unit_ids ] + metrics_df["lifetime_sparseness_rf"] = [self._get_lifetime_sparseness(unit) for unit in unit_ids] metrics_df.loc[:, ["run_pval_rf", "run_mod_rf"]] = [ - self._get_running_modulation( - unit, self._get_preferred_condition(unit) - ) - for unit in unit_ids + self._get_running_modulation(unit, self._get_preferred_condition(unit)) for unit in unit_ids ] self._metrics = metrics_df @@ -210,9 +187,7 @@ def known_stimulus_keys(cls): return ["receptive_field_mapping", "gabor", "gabors"] def _find_stimulus_key(self, stim_table): - known_keys_lc = [ - k.lower() for k in self.__class__.known_stimulus_keys() - ] + known_keys_lc = [k.lower() for k in self.__class__.known_stimulus_keys()] for table_key in stim_table["stimulus_name"].unique(): table_key_lc = table_key.lower() @@ -227,14 +202,10 @@ def _get_stim_table_stats(self): """Extract azimuths and elevations from stimulus table.""" self._pos_y = np.sort( - self.stimulus_conditions.loc[ - self.stimulus_conditions[self._col_pos_y] != "null" - ][self._col_pos_y].unique() + self.stimulus_conditions.loc[self.stimulus_conditions[self._col_pos_y] != "null"][self._col_pos_y].unique() ) self._pos_x = np.sort( - self.stimulus_conditions.loc[ - self.stimulus_conditions[self._col_pos_x] != "null" - ][self._col_pos_x].unique() + self.stimulus_conditions.loc[self.stimulus_conditions[self._col_pos_x] != "null"][self._col_pos_x].unique() ) def get_receptive_field(self, unit_id): @@ -285,18 +256,14 @@ def _response_by_stimulus_position( dataset = dataset.drop(time_key) dataset = dataset.assign_coords( - {row_key: ("stimulus_presentation_id", - presentations.loc[:, row_key].to_numpy())}) + {row_key: ("stimulus_presentation_id", presentations.loc[:, row_key].to_numpy())} + ) dataset = dataset.assign_coords( - {column_key: ("stimulus_presentation_id", - presentations.loc[:, column_key].to_numpy())}) + {column_key: ("stimulus_presentation_id", presentations.loc[:, column_key].to_numpy())} + ) dataset = dataset.to_dataframe() - dataset = ( - dataset.reset_index(unit_key) - .groupby([row_key, column_key, unit_key]) - .sum() - ) + dataset = dataset.reset_index(unit_key).groupby([row_key, column_key, unit_key]).sum() return dataset.to_xarray() @@ -331,9 +298,7 @@ def _get_rf_stats(self, unit_id): based on Gaussian fit """ rf = self._get_rf(unit_id) - spikes_per_trial = self.presentationwise_statistics.xs( - unit_id, level=1 - )["spike_counts"].values + spikes_per_trial = self.presentationwise_statistics.xs(unit_id, level=1)["spike_counts"].values if np.sum(spikes_per_trial) < self._minimum_spike_count: return np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, False @@ -346,20 +311,21 @@ def _get_rf_stats(self, unit_id): # print(self._params) # exit() - rf_thresh, azimuth, elevation, area = threshold_rf( - rf, self._mask_threshold - ) + rf_thresh, azimuth, elevation, area = threshold_rf(rf, self._mask_threshold) if is_rf_inverted(rf_thresh): rf = invert_rf(rf) ( - peak_height, - center_y, - center_x, - width_y, - width_x, - ), success = fit_2d_gaussian(rf) + ( + peak_height, + center_y, + center_x, + width_y, + width_x, + ), + success, + ) = fit_2d_gaussian(rf) on_screen = rf_on_screen(rf, center_y, center_x) height_deg = convert_pixels_to_degrees(width_y) @@ -380,51 +346,25 @@ def _get_rf_stats(self, unit_id): # VISUALIZATION def plot_raster(self, stimulus_condition_id, unit_id): - """Plot raster for one condition and one unit""" - idx_elev = np.where( - self.elevations - == self.stimulus_conditions.loc[stimulus_condition_id][ - self._col_pos_y - ] - )[0] - idx_azi = np.where( - self.azimuths - == self.stimulus_conditions.loc[stimulus_condition_id][ - self._col_pos_x - ] - )[0] + idx_elev = np.where(self.elevations == self.stimulus_conditions.loc[stimulus_condition_id][self._col_pos_y])[0] + idx_azi = np.where(self.azimuths == self.stimulus_conditions.loc[stimulus_condition_id][self._col_pos_x])[0] if len(idx_elev) == len(idx_azi) == 1: - - presentation_ids = self.presentationwise_statistics.xs( - unit_id, level=1 - )[ - self.presentationwise_statistics.xs(unit_id, level=1)[ - "stimulus_condition_id" - ] - == stimulus_condition_id + presentation_ids = self.presentationwise_statistics.xs(unit_id, level=1)[ + self.presentationwise_statistics.xs(unit_id, level=1)["stimulus_condition_id"] == stimulus_condition_id ].index.values df = self.presentationwise_spike_times[ - ( - self.presentationwise_spike_times[ - "stimulus_presentation_id" - ].isin(presentation_ids) - ) + (self.presentationwise_spike_times["stimulus_presentation_id"].isin(presentation_ids)) & (self.presentationwise_spike_times["unit_id"] == unit_id) ] - x = ( - df.index.values - - self.stim_table.loc[df.stimulus_presentation_id].start_time - ) + x = df.index.values - self.stim_table.loc[df.stimulus_presentation_id].start_time _, y = np.unique(df.stimulus_presentation_id, return_inverse=True) - idx_elev = ( - self.number_elevations - idx_elev - 1 - ) # reverse the elevation index so it matches the RF + idx_elev = self.number_elevations - idx_elev - 1 # reverse the elevation index so it matches the RF plt.subplot( self.number_elevations, @@ -465,10 +405,7 @@ def _gaussian_function_2d(peak_height, center_y, center_x, width_y, width_x): """ - return lambda y, x: peak_height * np.exp( - -(((center_y - y) / width_y) ** 2 + ((center_x - x) / width_x) ** 2) - / 2 - ) + return lambda y, x: peak_height * np.exp(-(((center_y - y) / width_y) ** 2 + ((center_x - x) / width_x) ** 2) / 2) def gaussian_moments_2d(data): @@ -501,23 +438,14 @@ def gaussian_moments_2d(data): center_y = (Y * data).sum() / total center_x = (X * data).sum() / total - if ( - np.isnan(center_y) - or np.isinf(center_y) - or np.isnan(center_x) - or np.isinf(center_x) - ): + if np.isnan(center_y) or np.isinf(center_y) or np.isnan(center_x) or np.isinf(center_x): return None col = data[:, int(center_x)] row = data[int(center_y), :] - width_y = np.sqrt( - np.abs((np.arange(row.size) - center_y) ** 2 * row).sum() / row.sum() - ) - width_x = np.sqrt( - np.abs((np.arange(col.size) - center_x) ** 2 * col).sum() / col.sum() - ) + width_y = np.sqrt(np.abs((np.arange(row.size) - center_y) ** 2 * row).sum() / row.sum()) + width_x = np.sqrt(np.abs((np.arange(col.size) - center_x) ** 2 * col).sum() / col.sum()) return height, center_y, center_x, width_y, width_x @@ -548,9 +476,7 @@ def fit_2d_gaussian(matrix): return (np.nan, np.nan, np.nan, np.nan, np.nan), False def errorfunction(p): - return np.ravel( - _gaussian_function_2d(*p)(*np.indices(matrix.shape)) - matrix - ) + return np.ravel(_gaussian_function_2d(*p)(*np.indices(matrix.shape)) - matrix) fit_params, ier = leastsq(errorfunction, params) success = True if ier < 5 else False @@ -631,9 +557,7 @@ def threshold_rf(rf, threshold): labels, num_features = ndi.label(rf_thresh) - best_label = np.argmax( - ndi.maximum(rf_filt, labels=labels, index=np.unique(labels)) - ) + best_label = np.argmax(ndi.maximum(rf_filt, labels=labels, index=np.unique(labels))) labels[labels != best_label] = 0 labels[labels > 0] = 1 @@ -650,9 +574,7 @@ def rf_on_screen(rf, center_y, center_x): return 0 < center_y < rf.shape[0] and 0 < center_x < rf.shape[1] -def convert_elevation_to_degrees( - elevation_in_pixels, elevation_offset_degrees=-30 -): +def convert_elevation_to_degrees(elevation_in_pixels, elevation_offset_degrees=-30): """Converts a pixel-based elevation into degrees relative to center of gaze @@ -673,10 +595,7 @@ def convert_elevation_to_degrees( ------- elevation_in_degrees : float """ - elevation_in_degrees = ( - convert_pixels_to_degrees(8 - elevation_in_pixels) - + elevation_offset_degrees - ) + elevation_in_degrees = convert_pixels_to_degrees(8 - elevation_in_pixels) + elevation_offset_degrees return elevation_in_degrees @@ -694,9 +613,7 @@ def convert_azimuth_to_degrees(azimuth_in_pixels, azimuth_offset_degrees=10): ------- azimuth_in_degrees : float """ - azimuth_in_degrees = ( - convert_pixels_to_degrees((azimuth_in_pixels)) + azimuth_offset_degrees - ) + azimuth_in_degrees = convert_pixels_to_degrees((azimuth_in_pixels)) + azimuth_offset_degrees return azimuth_in_degrees diff --git a/allensdk/brain_observatory/ecephys/stimulus_analysis/static_gratings.py b/allensdk/brain_observatory/ecephys/stimulus_analysis/static_gratings.py index acbc39f35f..919a1f3f85 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_analysis/static_gratings.py +++ b/allensdk/brain_observatory/ecephys/stimulus_analysis/static_gratings.py @@ -10,7 +10,8 @@ from ...circle_plots import FanPlotter import warnings -warnings.simplefilter(action='ignore', category=FutureWarning) + +warnings.simplefilter(action="ignore", category=FutureWarning) logger = logging.getLogger(__name__) @@ -36,8 +37,15 @@ class StaticGratings(StimulusAnalysis): """ - def __init__(self, ecephys_session, col_ori='orientation', col_sf='spatial_frequency', col_phase='phase', - trial_duration=0.25, **kwargs): + def __init__( + self, + ecephys_session, + col_ori="orientation", + col_sf="spatial_frequency", + col_phase="phase", + trial_duration=0.25, + **kwargs, + ): super(StaticGratings, self).__init__(ecephys_session, trial_duration=trial_duration, **kwargs) self._orivals = None self._number_ori = None @@ -57,18 +65,18 @@ def __init__(self, ecephys_session, col_ori='orientation', col_sf='spatial_frequ # self._module_name = 'Static Gratings' # TODO: module_name should be a static class variable if self._params is not None: - self._params = self._params.get('static_gratings', {}) - self._stimulus_key = self._params.get('stimulus_key', None) # Overwrites parent value with argvars + self._params = self._params.get("static_gratings", {}) + self._stimulus_key = self._params.get("stimulus_key", None) # Overwrites parent value with argvars else: self._params = {} @property def name(self): - return 'Static Gratings' + return "Static Gratings" @property def orivals(self): - """ Array of grating orientation conditions """ + """Array of grating orientation conditions""" if self._orivals is None: self._get_stim_table_stats() @@ -76,7 +84,7 @@ def orivals(self): @property def number_ori(self): - """ Number of grating orientation conditions """ + """Number of grating orientation conditions""" if self._number_ori is None: self._get_stim_table_stats() @@ -84,7 +92,7 @@ def number_ori(self): @property def sfvals(self): - """ Array of grating spatial frequency conditions """ + """Array of grating spatial frequency conditions""" if self._sfvals is None: self._get_stim_table_stats() @@ -92,7 +100,7 @@ def sfvals(self): @property def number_sf(self): - """ Number of grating orientation conditions """ + """Number of grating orientation conditions""" if self._number_sf is None: self._get_stim_table_stats() @@ -100,7 +108,7 @@ def number_sf(self): @property def phasevals(self): - """ Array of grating phase conditions """ + """Array of grating phase conditions""" if self._phasevals is None: self._get_stim_table_stats() @@ -108,7 +116,7 @@ def phasevals(self): @property def number_phase(self): - """ Number of grating phase conditions """ + """Number of grating phase conditions""" if self._number_phase is None: self._get_stim_table_stats() @@ -116,53 +124,63 @@ def number_phase(self): @property def null_condition(self): - """ Stimulus condition ID for null (blank) stimulus """ - return self.stimulus_conditions[self.stimulus_conditions[self._col_sf] == 'null'].index - + """Stimulus condition ID for null (blank) stimulus""" + return self.stimulus_conditions[self.stimulus_conditions[self._col_sf] == "null"].index @property def METRICS_COLUMNS(self): - return [('pref_sf_sg', np.float64), - ('pref_sf_multi_sg', bool), - ('pref_ori_sg', np.float64), - ('pref_ori_multi_sg', bool), - ('pref_phase_sg', np.float64), - ('pref_phase_multi_sg', bool), - ('g_osi_sg', np.float64), - ('time_to_peak_sg', np.float64), - ('firing_rate_sg', np.float64), - ('fano_sg', np.float64), - ('lifetime_sparseness_sg', np.float64), - ('run_pval_sg', np.float64), - ('run_mod_sg', np.float64)] + return [ + ("pref_sf_sg", np.float64), + ("pref_sf_multi_sg", bool), + ("pref_ori_sg", np.float64), + ("pref_ori_multi_sg", bool), + ("pref_phase_sg", np.float64), + ("pref_phase_multi_sg", bool), + ("g_osi_sg", np.float64), + ("time_to_peak_sg", np.float64), + ("firing_rate_sg", np.float64), + ("fano_sg", np.float64), + ("lifetime_sparseness_sg", np.float64), + ("run_pval_sg", np.float64), + ("run_mod_sg", np.float64), + ] @property def metrics(self): if self._metrics is None: - logger.info('Calculating metrics for ' + self.name) + logger.info("Calculating metrics for " + self.name) unit_ids = self.unit_ids metrics_df = self.empty_metrics_table() if len(self.stim_table) > 0: - metrics_df['pref_sf_sg'] = [self._get_pref_sf(unit) for unit in unit_ids] - metrics_df['pref_sf_multi_sg'] = [ + metrics_df["pref_sf_sg"] = [self._get_pref_sf(unit) for unit in unit_ids] + metrics_df["pref_sf_multi_sg"] = [ self._check_multiple_pref_conditions(unit_id, self._col_sf, self.sfvals) for unit_id in unit_ids ] - metrics_df['pref_ori_sg'] = [self._get_pref_ori(unit) for unit in unit_ids] - metrics_df['pref_ori_multi_sg'] = [ + metrics_df["pref_ori_sg"] = [self._get_pref_ori(unit) for unit in unit_ids] + metrics_df["pref_ori_multi_sg"] = [ self._check_multiple_pref_conditions(unit_id, self._col_ori, self.orivals) for unit_id in unit_ids ] - metrics_df['pref_phase_sg'] = [self._get_pref_phase(unit) for unit in unit_ids] - metrics_df['pref_phase_multi_sg'] = [ - self._check_multiple_pref_conditions(unit_id, self._col_phase, self.phasevals) for unit_id in unit_ids + metrics_df["pref_phase_sg"] = [self._get_pref_phase(unit) for unit in unit_ids] + metrics_df["pref_phase_multi_sg"] = [ + self._check_multiple_pref_conditions(unit_id, self._col_phase, self.phasevals) + for unit_id in unit_ids + ] + metrics_df["g_osi_sg"] = [ + self._get_osi(unit, metrics_df.loc[unit]["pref_sf_sg"], metrics_df.loc[unit]["pref_phase_sg"]) + for unit in unit_ids + ] + metrics_df["time_to_peak_sg"] = [ + self._get_time_to_peak(unit, self._get_preferred_condition(unit)) for unit in unit_ids + ] + metrics_df["firing_rate_sg"] = [self._get_overall_firing_rate(unit) for unit in unit_ids] + metrics_df["fano_sg"] = [ + self._get_fano_factor(unit, self._get_preferred_condition(unit)) for unit in unit_ids + ] + metrics_df["lifetime_sparseness_sg"] = [self._get_lifetime_sparseness(unit) for unit in unit_ids] + metrics_df.loc[:, ["run_pval_sg", "run_mod_sg"]] = [ + self._get_running_modulation(unit, self._get_preferred_condition(unit)) for unit in unit_ids ] - metrics_df['g_osi_sg'] = [self._get_osi(unit, metrics_df.loc[unit]['pref_sf_sg'], metrics_df.loc[unit]['pref_phase_sg']) for unit in unit_ids] - metrics_df['time_to_peak_sg'] = [self._get_time_to_peak(unit, self._get_preferred_condition(unit)) for unit in unit_ids] - metrics_df['firing_rate_sg'] = [self._get_overall_firing_rate(unit) for unit in unit_ids] - metrics_df['fano_sg'] = [self._get_fano_factor(unit, self._get_preferred_condition(unit)) for unit in unit_ids] - metrics_df['lifetime_sparseness_sg'] = [self._get_lifetime_sparseness(unit) for unit in unit_ids] - metrics_df.loc[:, ['run_pval_sg', 'run_mod_sg']] = \ - [self._get_running_modulation(unit, self._get_preferred_condition(unit)) for unit in unit_ids] self._metrics = metrics_df @@ -170,17 +188,23 @@ def metrics(self): @classmethod def known_stimulus_keys(cls): - return ['static_gratings'] + return ["static_gratings"] def _get_stim_table_stats(self): - """ Extract orientations, spatial frequencies, and phases from the stimulus table """ - self._orivals = np.sort(self.stimulus_conditions.loc[self.stimulus_conditions[self._col_ori] != 'null'][self._col_ori].unique()) + """Extract orientations, spatial frequencies, and phases from the stimulus table""" + self._orivals = np.sort( + self.stimulus_conditions.loc[self.stimulus_conditions[self._col_ori] != "null"][self._col_ori].unique() + ) self._number_ori = len(self._orivals) - self._sfvals = np.sort(self.stimulus_conditions.loc[self.stimulus_conditions[self._col_sf] != 'null'][self._col_sf].unique()) + self._sfvals = np.sort( + self.stimulus_conditions.loc[self.stimulus_conditions[self._col_sf] != "null"][self._col_sf].unique() + ) self._number_sf = len(self._sfvals) - self._phasevals = np.sort(self.stimulus_conditions.loc[self.stimulus_conditions[self._col_phase] != 'null'][self._col_phase].unique()) + self._phasevals = np.sort( + self.stimulus_conditions.loc[self.stimulus_conditions[self._col_phase] != "null"][self._col_phase].unique() + ) self._number_phase = len(self._phasevals) def _get_pref_sf(self, unit_id): @@ -199,21 +223,26 @@ def _get_pref_sf(self, unit_id): """ # TODO: Most of the _get_pref_*() methods can be combined into one method and shared among the classes # Combine the stimulus_condition_id values that have the save spatial-frequency - similar_conditions_ids = [self.stimulus_conditions.index[self.stimulus_conditions[self._col_sf] == sf].tolist() - for sf in self.sfvals] + similar_conditions_ids = [ + self.stimulus_conditions.index[self.stimulus_conditions[self._col_sf] == sf].tolist() for sf in self.sfvals + ] # For each spatial frequency average up conditionwise_statistics 'spike_mean' column using the indicies above. # return the sf with the largest spike_mean. df = pd.DataFrame( index=self.sfvals, - data={'spike_mean': [self.conditionwise_statistics.loc[unit_id].loc[condition_inds]['spike_mean'].mean() - for condition_inds in similar_conditions_ids]} + data={ + "spike_mean": [ + self.conditionwise_statistics.loc[unit_id].loc[condition_inds]["spike_mean"].mean() + for condition_inds in similar_conditions_ids + ] + }, ).rename_axis(self._col_sf) return df.idxmax().iloc[0] def _get_pref_ori(self, unit_id): - """ Calculate the preferred orientation condition for a given unit + """Calculate the preferred orientation condition for a given unit Parameters ---------- @@ -227,15 +256,21 @@ def _get_pref_ori(self, unit_id): """ # Combine the stimulus_condition_id values that have the save orientations - similar_conditions = [self.stimulus_conditions.index[self.stimulus_conditions[self._col_ori] == ori].tolist() - for ori in self.orivals] + similar_conditions = [ + self.stimulus_conditions.index[self.stimulus_conditions[self._col_ori] == ori].tolist() + for ori in self.orivals + ] # For each orientations average up conditionwise_statistics 'spike_mean' column using the indicies above. # Return the oris with the largest spike_mean. df = pd.DataFrame( index=self.orivals, - data={'spike_mean': [self.conditionwise_statistics.loc[unit_id].loc[condition_inds]['spike_mean'].mean() - for condition_inds in similar_conditions]} + data={ + "spike_mean": [ + self.conditionwise_statistics.loc[unit_id].loc[condition_inds]["spike_mean"].mean() + for condition_inds in similar_conditions + ] + }, ).rename_axis(self._col_ori) return df.idxmax().iloc[0] @@ -253,18 +288,24 @@ def _get_pref_phase(self, unit_id): pref_phase : float stimulus phase driving the maximal response """ - combined_cond_ids = [self.stimulus_conditions.index[self.stimulus_conditions[self._col_phase] == phase].tolist() - for phase in self.phasevals] + combined_cond_ids = [ + self.stimulus_conditions.index[self.stimulus_conditions[self._col_phase] == phase].tolist() + for phase in self.phasevals + ] df = pd.DataFrame( index=self.phasevals, - data = {'spike_mean': [self.conditionwise_statistics.loc[unit_id].loc[condition_inds]['spike_mean'].mean() - for condition_inds in combined_cond_ids]} + data={ + "spike_mean": [ + self.conditionwise_statistics.loc[unit_id].loc[condition_inds]["spike_mean"].mean() + for condition_inds in combined_cond_ids + ] + }, ).rename_axis(self._col_phase) return df.idxmax().iloc[0] def _get_osi(self, unit_id, pref_sf, pref_phase): - """ Calculate the orientation selectivity for a given unit + """Calculate the orientation selectivity for a given unit Parameters ---------- @@ -280,98 +321,109 @@ def _get_osi(self, unit_id, pref_sf, pref_phase): osi : float orientation selectivity value """ - orivals_rad = deg2rad(self.orivals).astype('complex128') # TODO: can we use numpy deg2rad? + orivals_rad = deg2rad(self.orivals).astype("complex128") # TODO: can we use numpy deg2rad? condition_inds = self.stimulus_conditions[ - (self.stimulus_conditions[self._col_sf] == pref_sf) & - (self.stimulus_conditions[self._col_phase] == pref_phase) + (self.stimulus_conditions[self._col_sf] == pref_sf) + & (self.stimulus_conditions[self._col_phase] == pref_phase) ].index.values df = self.conditionwise_statistics.loc[unit_id].loc[condition_inds] df = df.assign(ori=self.stimulus_conditions.loc[df.index.values][self._col_ori]) - df = df.sort_values(by=['ori']) - tuning = np.array(df['spike_mean'].values) + df = df.sort_values(by=["ori"]) + tuning = np.array(df["spike_mean"].values) return osi(orivals_rad, tuning) ## VISUALIZATION ## def plot_raster(self, stimulus_condition_id, unit_id): - """ Plot raster for one condition and one unit """ + """Plot raster for one condition and one unit""" idx_sf = np.where(self.sfvals == self.stimulus_conditions.loc[stimulus_condition_id][self._col_sf])[0] idx_ori = np.where(self.orivals == self.stimulus_conditions.loc[stimulus_condition_id][self._col_ori])[0] - + if len(idx_sf) == len(idx_ori) == 1: - - presentation_ids = \ - self.presentationwise_statistics.xs(unit_id, level=1)\ - [self.presentationwise_statistics.xs(unit_id, level=1)\ - ['stimulus_condition_id'] == stimulus_condition_id].index.values - - df = self.presentationwise_spike_times[ \ - (self.presentationwise_spike_times['stimulus_presentation_id'].isin(presentation_ids)) & \ - (self.presentationwise_spike_times['unit_id'] == unit_id) ] - + presentation_ids = self.presentationwise_statistics.xs(unit_id, level=1)[ + self.presentationwise_statistics.xs(unit_id, level=1)["stimulus_condition_id"] == stimulus_condition_id + ].index.values + + df = self.presentationwise_spike_times[ + (self.presentationwise_spike_times["stimulus_presentation_id"].isin(presentation_ids)) + & (self.presentationwise_spike_times["unit_id"] == unit_id) + ] + x = df.index.values - self.stim_table.loc[df.stimulus_presentation_id].start_time - _, y = np.unique(df.stimulus_presentation_id, return_inverse=True) - - plt.subplot(self.number_sf, self.number_ori, idx_sf*self.number_ori + idx_ori + 1) - plt.scatter(x, y, c='k', s=1, alpha=0.25) - plt.axis('off') + _, y = np.unique(df.stimulus_presentation_id, return_inverse=True) + plt.subplot(self.number_sf, self.number_ori, idx_sf * self.number_ori + idx_ori + 1) + plt.scatter(x, y, c="k", s=1, alpha=0.25) + plt.axis("off") def plot_response_summary(self, unit_id, bar_thickness=0.25): - - """ Plot the spike counts across conditions """ + """Plot the spike counts across conditions""" df = self.stimulus_conditions.drop(index=self.null_condition) - - df['sf_index'] = np.searchsorted(self.sfvals, df[self._col_sf].values) - df['ori_index'] = np.searchsorted(self.orivals, df[self._col_ori].values) - df['phase_index'] = np.searchsorted(self.phasevals, df[self._col_phase].values) - - cond_values = self.presentationwise_statistics.xs(unit_id, level=1)['stimulus_condition_id'] - - x = df.loc[cond_values.values]['sf_index'] + np.random.rand(cond_values.size) * bar_thickness - bar_thickness/2 - y = self.presentationwise_statistics.xs(unit_id, level=1)['spike_counts'] - c = df.loc[cond_values.values]['phase_index'] - - plt.subplot(2,1,1) - plt.scatter(y,x,c=c,alpha=0.5,cmap='Blues',vmin=-5) + + df["sf_index"] = np.searchsorted(self.sfvals, df[self._col_sf].values) + df["ori_index"] = np.searchsorted(self.orivals, df[self._col_ori].values) + df["phase_index"] = np.searchsorted(self.phasevals, df[self._col_phase].values) + + cond_values = self.presentationwise_statistics.xs(unit_id, level=1)["stimulus_condition_id"] + + x = ( + df.loc[cond_values.values]["sf_index"] + + np.random.rand(cond_values.size) * bar_thickness + - bar_thickness / 2 + ) + y = self.presentationwise_statistics.xs(unit_id, level=1)["spike_counts"] + c = df.loc[cond_values.values]["phase_index"] + + plt.subplot(2, 1, 1) + plt.scatter(y, x, c=c, alpha=0.5, cmap="Blues", vmin=-5) locs, labels = plt.yticks(ticks=np.arange(self.number_sf), labels=self.sfvals) - plt.ylabel('Spatial frequency') - plt.xlabel('Spikes per trial') - plt.ylim([self.number_sf,-1]) - - x = df.loc[cond_values.values]['ori_index'] + np.random.rand(cond_values.size) * bar_thickness - bar_thickness/2 - y = self.presentationwise_statistics.xs(unit_id, level=1)['spike_counts'] - c = df.loc[cond_values.values]['phase_index'] - - plt.subplot(2,1,2) - plt.scatter(x,y,c=c,alpha=0.5,cmap='Spectral') + plt.ylabel("Spatial frequency") + plt.xlabel("Spikes per trial") + plt.ylim([self.number_sf, -1]) + + x = ( + df.loc[cond_values.values]["ori_index"] + + np.random.rand(cond_values.size) * bar_thickness + - bar_thickness / 2 + ) + y = self.presentationwise_statistics.xs(unit_id, level=1)["spike_counts"] + c = df.loc[cond_values.values]["phase_index"] + + plt.subplot(2, 1, 2) + plt.scatter(x, y, c=c, alpha=0.5, cmap="Spectral") locs, labels = plt.xticks(ticks=np.arange(self.number_ori), labels=self.orivals) - plt.xlabel('Orientation') - plt.ylabel('Spikes per trial') + plt.xlabel("Orientation") + plt.ylabel("Spikes per trial") def make_fan_plot(self, unit_id): - """ Make a 2P-style Fan Plot based on presentationwise spike counts""" - - angle_data = self.stimulus_conditions.loc[self.presentationwise_statistics.xs(unit_id, level=1)['stimulus_condition_id']][self._col_ori].values - r_data = self.stimulus_conditions.loc[self.presentationwise_statistics.xs(unit_id, level=1)['stimulus_condition_id']][self._col_sf].values - group_data = self.stimulus_conditions.loc[self.presentationwise_statistics.xs(unit_id, level=1)['stimulus_condition_id']][self._col_phase].values - data = self.presentationwise_statistics.xs(unit_id, level=1)['spike_counts'].values - - null_trials = np.where(angle_data == 'null')[0] - + """Make a 2P-style Fan Plot based on presentationwise spike counts""" + + angle_data = self.stimulus_conditions.loc[ + self.presentationwise_statistics.xs(unit_id, level=1)["stimulus_condition_id"] + ][self._col_ori].values + r_data = self.stimulus_conditions.loc[ + self.presentationwise_statistics.xs(unit_id, level=1)["stimulus_condition_id"] + ][self._col_sf].values + group_data = self.stimulus_conditions.loc[ + self.presentationwise_statistics.xs(unit_id, level=1)["stimulus_condition_id"] + ][self._col_phase].values + data = self.presentationwise_statistics.xs(unit_id, level=1)["spike_counts"].values + + null_trials = np.where(angle_data == "null")[0] + angle_data = np.delete(angle_data, null_trials) r_data = np.delete(r_data, null_trials) group_data = np.delete(group_data, null_trials) data = np.delete(data, null_trials) - + cmin = np.min(data) cmax = np.max(data) fp = FanPlotter.for_static_gratings() - fp.plot(r_data = r_data, angle_data = angle_data, group_data = group_data, data =data, clim=[cmin, cmax]) + fp.plot(r_data=r_data, angle_data=angle_data, group_data=group_data, data=data, clim=[cmin, cmax]) fp.show_axes(closed=False) - plt.axis('off') + plt.axis("off") def fit_sf_tuning(sf_tuning_responses, sf_values, pref_sf_index): @@ -390,22 +442,30 @@ def fit_sf_tuning(sf_tuning_responses, sf_values, pref_sf_index): fit_sf = np.nan sf_low_cutoff = np.nan sf_high_cutoff = np.nan - if pref_sf_index in range(1, len(sf_values)-1): + if pref_sf_index in range(1, len(sf_values) - 1): # If the prefered spatial freq is an interior case try to fit the tunning curve with a gaussian. try: - popt, pcov = curve_fit(gauss_function, np.arange(len(sf_values)), sf_tuning_responses, p0=[np.amax(sf_tuning_responses), - pref_sf_index, 1.], maxfev=2000) - sf_prediction = gauss_function(np.arange(0., 4.1, 0.1), *popt) + popt, pcov = curve_fit( + gauss_function, + np.arange(len(sf_values)), + sf_tuning_responses, + p0=[np.amax(sf_tuning_responses), pref_sf_index, 1.0], + maxfev=2000, + ) + sf_prediction = gauss_function(np.arange(0.0, 4.1, 0.1), *popt) fit_sf_ind = popt[1] - fit_sf = 0.02*np.power(2, popt[1]) - low_cut_ind = np.abs(sf_prediction-(sf_prediction.max()/2.))[:sf_prediction.argmax()].argmin() - high_cut_ind = np.abs(sf_prediction-(sf_prediction.max()/2.))[sf_prediction.argmax():].argmin() + sf_prediction.argmax() + fit_sf = 0.02 * np.power(2, popt[1]) + low_cut_ind = np.abs(sf_prediction - (sf_prediction.max() / 2.0))[: sf_prediction.argmax()].argmin() + high_cut_ind = ( + np.abs(sf_prediction - (sf_prediction.max() / 2.0))[sf_prediction.argmax() :].argmin() + + sf_prediction.argmax() + ) if low_cut_ind > 0: low_cutoff = np.arange(0, 4.1, 0.1)[low_cut_ind] - sf_low_cutoff = 0.02*np.power(2, low_cutoff) + sf_low_cutoff = 0.02 * np.power(2, low_cutoff) elif high_cut_ind < 4: high_cutoff = np.arange(0, 4.1, 0.1)[high_cut_ind] - sf_high_cutoff = 0.02*np.power(2, high_cutoff) + sf_high_cutoff = 0.02 * np.power(2, high_cutoff) except Exception: pass else: @@ -413,17 +473,25 @@ def fit_sf_tuning(sf_tuning_responses, sf_values, pref_sf_index): fit_sf_ind = pref_sf_index fit_sf = sf_values[pref_sf_index] try: - popt, pcov = curve_fit(exp_function, np.arange(len(sf_values)), sf_tuning_responses, - p0=[np.amax(sf_tuning_responses), 2., np.amin(sf_tuning_responses)], maxfev=2000) - sf_prediction = exp_function(np.arange(0., 4.1, 0.1), *popt) + popt, pcov = curve_fit( + exp_function, + np.arange(len(sf_values)), + sf_tuning_responses, + p0=[np.amax(sf_tuning_responses), 2.0, np.amin(sf_tuning_responses)], + maxfev=2000, + ) + sf_prediction = exp_function(np.arange(0.0, 4.1, 0.1), *popt) if pref_sf_index == 0: - high_cut_ind = np.abs(sf_prediction-(sf_prediction.max()/2.))[sf_prediction.argmax():].argmin()+sf_prediction.argmax() + high_cut_ind = ( + np.abs(sf_prediction - (sf_prediction.max() / 2.0))[sf_prediction.argmax() :].argmin() + + sf_prediction.argmax() + ) high_cutoff = np.arange(0, 4.1, 0.1)[high_cut_ind] - sf_high_cutoff = 0.02*np.power(2, high_cutoff) + sf_high_cutoff = 0.02 * np.power(2, high_cutoff) else: - low_cut_ind = np.abs(sf_prediction-(sf_prediction.max()/2.))[:sf_prediction.argmax()].argmin() + low_cut_ind = np.abs(sf_prediction - (sf_prediction.max() / 2.0))[: sf_prediction.argmax()].argmin() low_cutoff = np.arange(0, 4.1, 0.1)[low_cut_ind] - sf_low_cutoff = 0.02*np.power(2, low_cutoff) + sf_low_cutoff = 0.02 * np.power(2, low_cutoff) except Exception: pass @@ -440,13 +508,13 @@ def get_sfdi(sf_tuning_responses, mean_sweeps_trials, bias=5): :return: The sfdi value (float) """ trial_mean = mean_sweeps_trials.mean() - sse_part = np.sqrt(np.sum((mean_sweeps_trials - trial_mean)**2) / (len(mean_sweeps_trials) - bias)) + sse_part = np.sqrt(np.sum((mean_sweeps_trials - trial_mean) ** 2) / (len(mean_sweeps_trials) - bias)) return (np.ptp(sf_tuning_responses)) / (np.ptp(sf_tuning_responses) + 2 * sse_part) def gauss_function(x, a, x0, sigma): - return a*np.exp(-(x-x0)**2/(2*sigma**2)) + return a * np.exp(-((x - x0) ** 2) / (2 * sigma**2)) def exp_function(x, a, b, c): - return a*np.exp(-b*x)+c + return a * np.exp(-b * x) + c diff --git a/allensdk/brain_observatory/ecephys/stimulus_analysis/stimulus_analysis.py b/allensdk/brain_observatory/ecephys/stimulus_analysis/stimulus_analysis.py index 95e1b8a1e0..52a520bc2c 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_analysis/stimulus_analysis.py +++ b/allensdk/brain_observatory/ecephys/stimulus_analysis/stimulus_analysis.py @@ -7,12 +7,11 @@ from scipy.ndimage import gaussian_filter from ..ecephys_session import EcephysSession -from allensdk.brain_observatory.ecephys.ecephys_session_api import \ - EcephysNwbSessionApi +from allensdk.brain_observatory.ecephys.ecephys_session_api import EcephysNwbSessionApi import warnings -warnings.simplefilter(action='ignore', category=RuntimeWarning) +warnings.simplefilter(action="ignore", category=RuntimeWarning) class StimulusAnalysis(object): @@ -25,20 +24,17 @@ def __init__(self, ecephys_session, trial_duration=None, **kwargs): if isinstance(ecephys_session, EcephysSession): self._ecephys_session = ecephys_session elif isinstance(ecephys_session, str): - nwb_version = kwargs.get('nwb_version', 2) - self._ecephys_session = EcephysSession.from_nwb_path( - path=ecephys_session, nwb_version=nwb_version) + nwb_version = kwargs.get("nwb_version", 2) + self._ecephys_session = EcephysSession.from_nwb_path(path=ecephys_session, nwb_version=nwb_version) elif isinstance(ecephys_session, EcephysNwbSessionApi): # nwb_version = kwargs.get('nwb_version', 2) self._ecephys_session = EcephysSession(api=ecephys_session) else: - raise TypeError( - f"Don't know how to make a stimulus analysis object from a " - f"{type(ecephys_session)}") + raise TypeError(f"Don't know how to make a stimulus analysis object from a {type(ecephys_session)}") self._unit_ids = None - self._unit_filter = kwargs.get('filter', None) - self._params = kwargs.get('params', None) + self._unit_filter = kwargs.get("filter", None) + self._params = kwargs.get("params", None) self._unit_count = None self._stim_table = None self._conditionwise_statistics = None @@ -49,7 +45,7 @@ def __init__(self, ecephys_session, trial_duration=None, **kwargs): self._spikes = None self._stim_table_spontaneous = None - self._stimulus_key = kwargs.get('stimulus_key', None) + self._stimulus_key = kwargs.get("stimulus_key", None) self._running_speed = None # self._sweep_events = None # self._mean_sweep_events = None @@ -64,12 +60,11 @@ def __init__(self, ecephys_session, trial_duration=None, **kwargs): # self._module_name = None # TODO: Remove, .name() should be hardcoded - self._psth_resolution = kwargs.get('psth_resolution', 0.001) + self._psth_resolution = kwargs.get("psth_resolution", 0.001) # Duration a spontaneous stimulus should last for before it gets # included in the analysis. - self._spontaneous_threshold = kwargs.get('spontaneous_threshold', - 100.0) + self._spontaneous_threshold = kwargs.get("spontaneous_threshold", 100.0) # Roughly the length of each stimulus duration, used for calculating # spike statistics @@ -87,16 +82,15 @@ def unit_ids(self): """Returns a list of unit IDs for which to apply the analysis""" if self._unit_ids is None: units_df = self.ecephys_session.units - if isinstance(self._unit_filter, - (list, tuple, np.ndarray, pd.Series)): + if isinstance(self._unit_filter, (list, tuple, np.ndarray, pd.Series)): # If the user passes a list/array of ids units_df = units_df.loc[self._unit_filter] elif isinstance(self._unit_filter, dict): - if 'unit_id' in self._unit_filter.keys(): + if "unit_id" in self._unit_filter.keys(): # If user wants to filter by the unit_id column which is # actually the dataframe index - units_df = units_df.loc[self._unit_filter['unit_id']] + units_df = units_df.loc[self._unit_filter["unit_id"]] else: # Create a mask for all units that match the all of @@ -111,7 +105,7 @@ def unit_ids(self): if units_df is None or units_df.empty: # If not units are found don't proceed. - raise Exception('Could not find units for ecephys session.') + raise Exception("Could not find units for ecephys session.") self._unit_ids = units_df.index.values @@ -126,7 +120,7 @@ def unit_count(self): @property def name(self): - """ Return the stimulus name.""" + """Return the stimulus name.""" return self._module_name @property @@ -134,9 +128,7 @@ def trial_duration(self): if self._trial_duration is None or self._trial_duration < 0.0: # TODO: Should we calculate trial_duration from # min(stim_table[duration']) if not set by user/subclass? - raise TypeError( - f'Invalid value {self._trial_duration} for parameter ' - f'"trial_duration".') + raise TypeError(f'Invalid value {self._trial_duration} for parameter "trial_duration".') return self._trial_duration @@ -150,8 +142,7 @@ def spikes(self): if len(self._spikes) > self.unit_count: # if a filter has been applied such that not all the cells # are being used in the analysis - self._spikes = {k: v for k, v in self._spikes.items() if - k in self.unit_ids} + self._spikes = {k: v for k, v in self._spikes.items() if k in self.unit_ids} return self._spikes @@ -165,20 +156,17 @@ def stim_table(self): self._stimulus_key = self._find_stimulus_key(stims_table) if self._stimulus_key is None: raise Exception( - 'Could not find appropriate stimulus_name key for ' - 'current stimulus type. Please ' - 'specify using the stimulus_key parameter.') + "Could not find appropriate stimulus_name key for " + "current stimulus type. Please " + "specify using the stimulus_key parameter." + ) self._stim_table = self.ecephys_session.get_stimulus_table( - [self._stimulus_key] if isinstance(self._stimulus_key, - str) else - self._stimulus_key + [self._stimulus_key] if isinstance(self._stimulus_key, str) else self._stimulus_key ) if self._stim_table.empty: - raise Exception( - f'Could not find stimulus data with "stimulus_key" ' - f'{self._stimulus_key}') + raise Exception(f'Could not find stimulus data with "stimulus_key" {self._stimulus_key}') # TODO: Should we remove columns that are not relevant to the # selected stimulus? If a feature for another @@ -192,9 +180,8 @@ def _find_stimulus_key(self, stim_table): :param stim_table: :return: """ - known_keys_lc = [k.lower() for k in - self.__class__.known_stimulus_keys()] - for table_key in stim_table['stimulus_name'].unique(): + known_keys_lc = [k.lower() for k in self.__class__.known_stimulus_keys()] + for table_key in stim_table["stimulus_name"].unique(): if table_key.lower() in known_keys_lc: return table_key @@ -203,11 +190,11 @@ def _find_stimulus_key(self, stim_table): @property def known_spontaneous_keys(self): - return ['spontaneous', "spontaneous_activity"] + return ["spontaneous", "spontaneous_activity"] @property def total_presentations(self): - """ Total nmber of presentations / trials""" + """Total nmber of presentations / trials""" return len(self.stim_table) @property @@ -230,12 +217,10 @@ def stim_table_spontaneous(self): # TODO: This may not be need anymore? Ask the scientists if # sweep_p_events will be required in the future. if self._stim_table_spontaneous is None: - stim_table = self.ecephys_session.get_stimulus_table( - self.known_spontaneous_keys) + stim_table = self.ecephys_session.get_stimulus_table(self.known_spontaneous_keys) # TODO: If duration does not exists in stim_table create it from # stop and start times - self._stim_table_spontaneous = stim_table[ - stim_table['duration'] > self._spontaneous_threshold] + self._stim_table_spontaneous = stim_table[stim_table["duration"] > self._spontaneous_threshold] return self._stim_table_spontaneous @@ -267,39 +252,31 @@ def conditionwise_psth(self): if self._conditionwise_psth is None: if self._psth_resolution > self.trial_duration: - warnings.warn( - 'parameter "psth_resolution" > "trial_duration", ' - 'PSTH will not be properly created.') + warnings.warn('parameter "psth_resolution" > "trial_duration", PSTH will not be properly created.') # get the spike-counts for every stimulus_presentation_id dataset = self.ecephys_session.presentationwise_spike_counts( - bin_edges=np.arange(0, self.trial_duration, - self._psth_resolution), + bin_edges=np.arange(0, self.trial_duration, self._psth_resolution), stimulus_presentation_ids=self.stim_table.index.values, - unit_ids=self.unit_ids + unit_ids=self.unit_ids, ) # replace the stimulus_presentation_id (which will be unique for # every single stim) with the corresponding # stimulus_condition_id (which will be shared among presenations # with the same conditions. - da = dataset.assign_coords( - stimulus_presentation_id=self.stim_table[ - 'stimulus_condition_id'].values) - da = da.rename( - {'stimulus_presentation_id': 'stimulus_condition_id'}) + da = dataset.assign_coords(stimulus_presentation_id=self.stim_table["stimulus_condition_id"].values) + da = da.rename({"stimulus_presentation_id": "stimulus_condition_id"}) # Average spike counts across each stimulus_condition_id. - n_stimuli = len(da['stimulus_condition_id']) - n_cond_ids = len( - np.unique(da.coords['stimulus_condition_id'].values)) + n_stimuli = len(da["stimulus_condition_id"]) + n_cond_ids = len(np.unique(da.coords["stimulus_condition_id"].values)) if n_stimuli == n_cond_ids: # If every condition_id is unique then calling # groupby().mean() is unnecessary and will raise an error. self._conditionwise_psth = da else: - self._conditionwise_psth = da.groupby( - 'stimulus_condition_id').mean(dim='stimulus_condition_id') + self._conditionwise_psth = da.groupby("stimulus_condition_id").mean(dim="stimulus_condition_id") return self._conditionwise_psth @@ -316,9 +293,9 @@ def conditionwise_statistics(self): spike_std and stimulus_presentation_count information. """ if self._conditionwise_statistics is None: - self._conditionwise_statistics = \ - self.ecephys_session.conditionwise_spike_statistics( - self.stim_table.index.values, self.unit_ids) + self._conditionwise_statistics = self.ecephys_session.conditionwise_spike_statistics( + self.stim_table.index.values, self.unit_ids + ) return self._conditionwise_statistics @@ -336,10 +313,9 @@ def presentationwise_spike_times(self): """ if self._presentationwise_spikes is None: - self._presentationwise_spikes = \ - self.ecephys_session.presentationwise_spike_times( - stimulus_presentation_ids=self.stim_table.index.values, - unit_ids=self.unit_ids) + self._presentationwise_spikes = self.ecephys_session.presentationwise_spike_times( + stimulus_presentation_ids=self.stim_table.index.values, unit_ids=self.unit_ids + ) return self._presentationwise_spikes @@ -360,17 +336,19 @@ def presentationwise_statistics(self): # for each presentation_id and unit_id get the spike_counts # across the entire duration. Since there is only # a single bin we can drop time_relative_to_stimulus_onset. - df = self.ecephys_session.presentationwise_spike_counts( - bin_edges=np.array([0.0, self.trial_duration]), - stimulus_presentation_ids=self.stim_table.index.values, - unit_ids=self.unit_ids - ).to_dataframe().reset_index( - level='time_relative_to_stimulus_onset', drop=True) + df = ( + self.ecephys_session.presentationwise_spike_counts( + bin_edges=np.array([0.0, self.trial_duration]), + stimulus_presentation_ids=self.stim_table.index.values, + unit_ids=self.unit_ids, + ) + .to_dataframe() + .reset_index(level="time_relative_to_stimulus_onset", drop=True) + ) # left join table with stimulus_condition_id and mean # running_speed joined on stimulus_presentation_id - df = df.join(self.stim_table.loc[df.index.levels[0].values][ - 'stimulus_condition_id']) + df = df.join(self.stim_table.loc[df.index.levels[0].values]["stimulus_condition_id"]) self._presentationwise_statistics = df.join(self.running_speed) return self._presentationwise_statistics @@ -388,11 +366,10 @@ def stimulus_conditions(self): """ if self._stimulus_conditions is None: - condition_list = self.stim_table['stimulus_condition_id'].unique() - self._stimulus_conditions = \ - self.ecephys_session.stimulus_conditions[ - self.ecephys_session.stimulus_conditions.index.isin( - condition_list)] + condition_list = self.stim_table["stimulus_condition_id"].unique() + self._stimulus_conditions = self.ecephys_session.stimulus_conditions[ + self.ecephys_session.stimulus_conditions.index.isin(condition_list) + ] return self._stimulus_conditions @@ -409,24 +386,21 @@ def running_speed(self): """ if self._running_speed is None: + def get_velocity(presentation_id): """Helper function for getting avg. velocities for a given presenation_id""" pres_row = self.stim_table.loc[presentation_id] - mask = \ - ((self.ecephys_session.running_speed['start_time'] >= - pres_row['start_time']) & - (self.ecephys_session.running_speed['start_time'] < - pres_row['stop_time'])) + mask = (self.ecephys_session.running_speed["start_time"] >= pres_row["start_time"]) & ( + self.ecephys_session.running_speed["start_time"] < pres_row["stop_time"] + ) - return self.ecephys_session.running_speed[mask][ - 'velocity'].mean() + return self.ecephys_session.running_speed[mask]["velocity"].mean() self._running_speed = pd.DataFrame( index=self.stim_table.index.values, - data={'running_speed': [get_velocity(i) for i in - self.stim_table.index.values] - }).rename_axis('stimulus_presentation_id') + data={"running_speed": [get_velocity(i) for i in self.stim_table.index.values]}, + ).rename_axis("stimulus_presentation_id") return self._running_speed @@ -439,9 +413,9 @@ def metrics(self): def empty_metrics_table(self): empty_array = np.zeros((self.unit_count, len(self.METRICS_COLUMNS))) - df = pd.DataFrame(empty_array, - index=pd.Index(self.unit_ids, name='unit_id'), - columns=[x[0] for x in self.METRICS_COLUMNS]) + df = pd.DataFrame( + empty_array, index=pd.Index(self.unit_ids, name="unit_id"), columns=[x[0] for x in self.METRICS_COLUMNS] + ) df = df.astype(dict(self.METRICS_COLUMNS)) df[df == 0] = np.nan return df @@ -462,65 +436,58 @@ def _get_preferred_condition(self, unit_id): # stimulus_condition_id that gives the highest # value. try: - df = self.conditionwise_statistics.drop( - index=self.null_condition, level=1) + df = self.conditionwise_statistics.drop(index=self.null_condition, level=1) except (IndexError, NotImplementedError, KeyError): df = self.conditionwise_statistics # TODO: Calculated preferred condition_id once for all units and # store in a table. - self._preferred_condition[unit_id] = df.loc[unit_id][ - 'spike_mean'].idxmax() + self._preferred_condition[unit_id] = df.loc[unit_id]["spike_mean"].idxmax() return self._preferred_condition[unit_id] - def _check_multiple_pref_conditions(self, unit_id, stim_cond_col, - valid_conditions): + def _check_multiple_pref_conditions(self, unit_id, stim_cond_col, valid_conditions): # find all stimulus_condition which share the same 'stim_cond_col' ( # eg TF, ORI, etc) value, calculate the avg # spiking - similar_conditions = [self.stimulus_conditions.index[ - self.stimulus_conditions[ - stim_cond_col] == v].tolist() - for v in valid_conditions] + similar_conditions = [ + self.stimulus_conditions.index[self.stimulus_conditions[stim_cond_col] == v].tolist() + for v in valid_conditions + ] spike_means = [ - self.conditionwise_statistics.loc[unit_id].loc[condition_inds][ - 'spike_mean'].mean() - for condition_inds in similar_conditions] + self.conditionwise_statistics.loc[unit_id].loc[condition_inds]["spike_mean"].mean() + for condition_inds in similar_conditions + ] # Check if there is more than one stimulus condition that provokes a # maximum response return len(np.argwhere(spike_means == np.amax(spike_means))) > 1 - def _get_running_modulation(self, unit_id, preferred_condition, - threshold=1.0): + def _get_running_modulation(self, unit_id, preferred_condition, threshold=1.0): """Get running modulation for the preferred condition of a given unit""" subset = self.presentationwise_statistics[ - self.presentationwise_statistics[ - 'stimulus_condition_id'] == preferred_condition - ].xs(unit_id, level='unit_id') + self.presentationwise_statistics["stimulus_condition_id"] == preferred_condition + ].xs(unit_id, level="unit_id") - spike_counts = subset['spike_counts'].values - running_speeds = subset['running_speed'].values + spike_counts = subset["spike_counts"].values + running_speeds = subset["running_speed"].values return running_modulation(spike_counts, running_speeds, threshold) def _get_lifetime_sparseness(self, unit_id): """Computes lifetime sparseness of responses for one unit""" - df = self.conditionwise_statistics.drop(index=self.null_condition, - level=1, errors='ignore') - responses = df.loc[unit_id]['spike_count'].values + df = self.conditionwise_statistics.drop(index=self.null_condition, level=1, errors="ignore") + responses = df.loc[unit_id]["spike_count"].values return lifetime_sparseness(responses) def _get_fano_factor(self, unit_id, preferred_condition): # See: https://en.wikipedia.org/wiki/Fano_factor subset = self.presentationwise_statistics[ - self.presentationwise_statistics[ - 'stimulus_condition_id'] == preferred_condition - ].xs(unit_id, level=1) + self.presentationwise_statistics["stimulus_condition_id"] == preferred_condition + ].xs(unit_id, level=1) - spike_counts = subset['spike_counts'].values + spike_counts = subset["spike_counts"].values return fano_factor(spike_counts) def _get_time_to_peak(self, unit_id, preferred_condition): @@ -529,43 +496,38 @@ def _get_time_to_peak(self, unit_id, preferred_condition): try: # TODO: Try to find a way to generalize that doesn't rely on # conditionwise_psth - psth = self.conditionwise_psth.sel( - unit_id=unit_id, stimulus_condition_id=preferred_condition) - peak_time = psth.where(psth == psth.max(), drop=True)[ - 'time_relative_to_stimulus_onset'][0].values + psth = self.conditionwise_psth.sel(unit_id=unit_id, stimulus_condition_id=preferred_condition) + peak_time = psth.where(psth == psth.max(), drop=True)["time_relative_to_stimulus_onset"][0].values except Exception: peak_time = np.nan return peak_time def _get_overall_firing_rate(self, unit_id): - """ Average firing rate over the entire stimulus interval""" + """Average firing rate over the entire stimulus interval""" if self._block_starts is None: # For the stimulus, create a list of start and stop times for # the given block of trials. Only needs to be # calculated once TODO: see if python allows for private # property variables - start_time_intervals = np.diff(self.stim_table['start_time']) + start_time_intervals = np.diff(self.stim_table["start_time"]) interval_end_inds = np.concatenate( - (np.where(start_time_intervals > self.trial_duration * 2)[0], - np.array([self.total_presentations - 1]))) - interval_start_inds = np.concatenate((np.array([0]), - np.where( - start_time_intervals > - self.trial_duration * 2)[ - 0] + 1)) - - self._block_starts = self.stim_table.iloc[interval_start_inds][ - 'start_time'].values - self._block_stops = self.stim_table.iloc[interval_end_inds][ - 'stop_time'].values + (np.where(start_time_intervals > self.trial_duration * 2)[0], np.array([self.total_presentations - 1])) + ) + interval_start_inds = np.concatenate( + (np.array([0]), np.where(start_time_intervals > self.trial_duration * 2)[0] + 1) + ) + + self._block_starts = self.stim_table.iloc[interval_start_inds]["start_time"].values + self._block_stops = self.stim_table.iloc[interval_end_inds]["stop_time"].values # TODO: Check start and start times that differences are positive return overall_firing_rate( start_times=self._block_starts, stop_times=self._block_stops, - spike_times=self.ecephys_session.spike_times[unit_id]) + spike_times=self.ecephys_session.spike_times[unit_id], + ) def get_intrinsic_timescale(self, unit_ids): """Calculates the intrinsic timescale for a subset of units""" @@ -574,7 +536,7 @@ def get_intrinsic_timescale(self, unit_ids): dataset = self.ecephys_session.presentationwise_spike_counts( bin_edges=np.arange(0, self.trial_duration, 0.025), stimulus_presentation_ids=self.stim_table.index.values, - unit_ids=unit_ids + unit_ids=unit_ids, ) rsc_time_matrix = calculate_time_delayed_correlation(dataset) t, y, y_std, a, intrinsic_timescale, c = fit_exp(rsc_time_matrix) @@ -584,10 +546,9 @@ def get_intrinsic_timescale(self, unit_ids): # VISUALIZATION ############ def plot_conditionwise_raster(self, unit_id): - """ Plot a matrix of rasters for each condition (orientations x - temporal frequencies) """ - _ = [self.plot_raster(cond, unit_id) for cond in - self.stimulus_conditions.index.values] + """Plot a matrix of rasters for each condition (orientations x + temporal frequencies)""" + _ = [self.plot_raster(cond, unit_id) for cond in self.stimulus_conditions.index.values] def plot_raster(self, condition, unit_id): raise NotImplementedError() @@ -627,10 +588,8 @@ def running_modulation(spike_counts, running_speeds, speed_threshold=1.0): run_mod : float or Nan Relative difference between running and stationary mean firing rates. """ - if (len(spike_counts) != len(running_speeds)): - warnings.warn( - 'spike_counts and running_speeds must be arrays of the same ' - 'shape.') + if len(spike_counts) != len(running_speeds): + warnings.warn("spike_counts and running_speeds must be arrays of the same shape.") return np.nan, np.nan # keep track of when the animal is and isn't running @@ -677,14 +636,11 @@ def lifetime_sparseness(responses): """ if len(responses) <= 1: # Unable to calculate, return nan - warnings.warn( - 'responses array must contain at least two or more values to ' - 'calculate.') + warnings.warn("responses array must contain at least two or more values to calculate.") return np.nan coeff = 1.0 / len(responses) - return (1.0 - coeff * ((np.power(np.sum(responses), 2)) / ( - np.sum(np.power(responses, 2))))) / (1.0 - coeff) + return (1.0 - coeff * ((np.power(np.sum(responses), 2)) / (np.sum(np.power(responses, 2))))) / (1.0 - coeff) def fano_factor(spike_counts): @@ -725,8 +681,7 @@ def overall_firing_rate(start_times, stop_times, spike_times): firing_rate : float """ if len(start_times) != len(stop_times): - warnings.warn( - 'start_times and stop_times must be arrays of the same length') + warnings.warn("start_times and stop_times must be arrays of the same length") return np.nan if len(spike_times) == 0: @@ -736,12 +691,10 @@ def overall_firing_rate(start_times, stop_times, spike_times): total_time = np.sum(stop_times - start_times) if total_time <= 0: # Probably start and stop times got inverted. - warnings.warn(f'The total duration was {total_time} seconds.') + warnings.warn(f"The total duration was {total_time} seconds.") return np.nan - return np.sum( - spike_times.searchsorted(stop_times) - spike_times.searchsorted( - start_times)) / total_time + return np.sum(spike_times.searchsorted(stop_times) - spike_times.searchsorted(start_times)) / total_time def get_fr(spikes, num_timestep_second=30, sweep_length=3.1, filter_width=0.1): @@ -773,8 +726,7 @@ def get_fr(spikes, num_timestep_second=30, sweep_length=3.1, filter_width=0.1): return fr -def reliability(unit_sweeps, padding=1.0, num_timestep_second=30, - filter_width=0.1, window_beg=0, window_end=None): +def reliability(unit_sweeps, padding=1.0, num_timestep_second=30, filter_width=0.1, window_beg=0, window_end=None): """Computes the trial-to-trial reliability for a set of sweeps for a given cell @@ -791,12 +743,9 @@ def reliability(unit_sweeps, padding=1.0, num_timestep_second=30, corr_matrix = np.empty((len(unit_sweeps), len(unit_sweeps))) fr_window = slice(window_beg, window_end) for i in range(len(unit_sweeps)): - fri = get_fr(unit_sweeps[i], num_timestep_second=num_timestep_second, - filter_width=filter_width) + fri = get_fr(unit_sweeps[i], num_timestep_second=num_timestep_second, filter_width=filter_width) for j in range(len(unit_sweeps)): - frj = get_fr(unit_sweeps[j], - num_timestep_second=num_timestep_second, - filter_width=filter_width) + frj = get_fr(unit_sweeps[j], num_timestep_second=num_timestep_second, filter_width=filter_width) # Warning: the pearson coefficient is likely to have a # denominator of 0 for some cells/stimulus and give # a divide by 0 warning. @@ -828,7 +777,7 @@ def osi(orivals, tuning): in radians) of the responses. """ if len(orivals) == 0 or len(orivals) != len(tuning): - warnings.warn('orivals and tunings are of different lengths') + warnings.warn("orivals and tunings are of different lengths") return np.nan tuning_sum = tuning.sum() @@ -858,7 +807,7 @@ def dsi(orivals, tuning): in radians) of the responses. """ if len(orivals) == 0 or len(orivals) != len(tuning): - warnings.warn('orivals and tunings are of different lengths') + warnings.warn("orivals and tunings are of different lengths") return np.nan tuning_sum = tuning.sum() @@ -870,7 +819,7 @@ def dsi(orivals, tuning): def deg2rad(arr): - """ Converts array-like input from degrees to radians""" + """Converts array-like input from degrees to radians""" # TODO: Is there any reason not to use np.deg2rad? return arr / 180 * np.pi @@ -883,8 +832,7 @@ def fit_exp(rsc_time_matrix): t = np.arange(len(tmp))[1:] y = gaussian_filter(np.nanmean(tmp, axis=0)[1:], 0.8) - p, amo = curve_fit(lambda t, a, b, c: a * np.exp(-1 / b * t) + c, t, y, - p0=(-4, 2, 1), maxfev=1000000000) + p, amo = curve_fit(lambda t, a, b, c: a * np.exp(-1 / b * t) + c, t, y, p0=(-4, 2, 1), maxfev=1000000000) a = p[0] b = p[1] # this is the intrinsic timescale @@ -901,17 +849,14 @@ def calculate_time_delayed_correlation(dataset): rsc_time_matrix = np.zeros((num_units, nbins, nbins)) * np.nan for unit_idx, unit in enumerate(dataset.unit_id): - spikes_for_unit = dataset.sel(unit_id=unit).data for i in np.arange(nbins - 1): for j in np.arange(i + 1, nbins): # remove zero spike count bins - good_trials = \ - (spikes_for_unit[:, i] * spikes_for_unit[:, j]) > 0 + good_trials = (spikes_for_unit[:, i] * spikes_for_unit[:, j]) > 0 - r, p = st.pearsonr(spikes_for_unit[good_trials, i], - spikes_for_unit[good_trials, j]) + r, p = st.pearsonr(spikes_for_unit[good_trials, i], spikes_for_unit[good_trials, j]) rsc_time_matrix[unit_idx, i, j] = r return rsc_time_matrix diff --git a/allensdk/brain_observatory/ecephys/stimulus_sync.py b/allensdk/brain_observatory/ecephys/stimulus_sync.py index a59ae8da64..71858d2827 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_sync.py +++ b/allensdk/brain_observatory/ecephys/stimulus_sync.py @@ -6,10 +6,7 @@ def trimmed_stats(data, pctiles=(10, 90)): low = np.percentile(data, pctiles[0]) high = np.percentile(data, pctiles[1]) - trimmed = data[np.logical_and( - data <= high, - data >= low - )] + trimmed = data[np.logical_and(data <= high, data >= low)] return np.mean(trimmed), np.std(trimmed) @@ -17,32 +14,43 @@ def trimmed_stats(data, pctiles=(10, 90)): def trim_discontiguous_vsyncs(vs_times, photodiode_cycle=60): vs_times = np.array(vs_times) - breaks = np.where(np.diff(vs_times) > (1/photodiode_cycle)*100)[0] + breaks = np.where(np.diff(vs_times) > (1 / photodiode_cycle) * 100)[0] if len(breaks) > 0: - chunk_sizes = np.diff(np.concatenate((np.array([0, ]), - breaks, - np.array([len(vs_times), ])))) + chunk_sizes = np.diff( + np.concatenate( + ( + np.array( + [ + 0, + ] + ), + breaks, + np.array( + [ + len(vs_times), + ] + ), + ) + ) + ) largest_chunk = np.argmax(chunk_sizes) if largest_chunk == 0: - return vs_times[:np.min(breaks+1)] + return vs_times[: np.min(breaks + 1)] elif largest_chunk == len(breaks): - return vs_times[np.max(breaks+1):] + return vs_times[np.max(breaks + 1) :] else: - return vs_times[breaks[largest_chunk-1]:breaks[largest_chunk]] + return vs_times[breaks[largest_chunk - 1] : breaks[largest_chunk]] else: return vs_times -def separate_vsyncs_and_photodiode_times(vs_times, - pd_times, - photodiode_cycle=60): - +def separate_vsyncs_and_photodiode_times(vs_times, pd_times, photodiode_cycle=60): vs_times = np.array(vs_times) pd_times = np.array(pd_times) - breaks = np.where(np.diff(vs_times) > (1/photodiode_cycle)*100)[0] + breaks = np.where(np.diff(vs_times) > (1 / photodiode_cycle) * 100)[0] shift = 2.0 break_times = [-shift] @@ -53,11 +61,8 @@ def separate_vsyncs_and_photodiode_times(vs_times, pd_times_out = [] for indx, b in enumerate(break_times[:-1]): - - pd_in_range = np.where((pd_times > break_times[indx] + shift) * - (pd_times <= break_times[indx+1] + shift))[0] - vs_in_range = np.where((vs_times > break_times[indx]) * - (vs_times <= break_times[indx+1]))[0] + pd_in_range = np.where((pd_times > break_times[indx] + shift) * (pd_times <= break_times[indx + 1] + shift))[0] + vs_in_range = np.where((vs_times > break_times[indx]) * (vs_times <= break_times[indx + 1]))[0] vs_times_out.append(vs_times[vs_in_range]) pd_times_out.append(pd_times[pd_in_range]) @@ -65,22 +70,19 @@ def separate_vsyncs_and_photodiode_times(vs_times, return vs_times_out, pd_times_out -def trim_border_pulses(pd_times, vs_times, frame_interval=1/60, num_frames=5): +def trim_border_pulses(pd_times, vs_times, frame_interval=1 / 60, num_frames=5): pd_times = np.array(pd_times) - return pd_times[np.logical_and( - pd_times >= vs_times[0], - pd_times <= vs_times[-1] + num_frames * frame_interval - )] + return pd_times[np.logical_and(pd_times >= vs_times[0], pd_times <= vs_times[-1] + num_frames * frame_interval)] def correct_on_off_effects(pd_times): - ''' + """ Notes ----- This cannot (without additional info) determine whether an assymmetric offset is odd-long or even-long. - ''' + """ pd_diff = np.diff(pd_times) odd_diff_mean, odd_diff_std = trimmed_stats(pd_diff[1::2]) @@ -104,14 +106,11 @@ def flag_unexpected_edges(pd_times, ndevs=10): diff_mean, diff_std = trimmed_stats(pd_diff) expected_duration_mask = np.ones(pd_diff.size) - expected_duration_mask[np.logical_or( - pd_diff < diff_mean - ndevs * diff_std, - pd_diff > diff_mean + ndevs * diff_std - )] = 0 - expected_duration_mask[1:] = np.logical_and(expected_duration_mask[:-1], - expected_duration_mask[1:]) - expected_duration_mask = np.concatenate([expected_duration_mask, - [expected_duration_mask[-1]]]) + expected_duration_mask[ + np.logical_or(pd_diff < diff_mean - ndevs * diff_std, pd_diff > diff_mean + ndevs * diff_std) + ] = 0 + expected_duration_mask[1:] = np.logical_and(expected_duration_mask[:-1], expected_duration_mask[1:]) + expected_duration_mask = np.concatenate([expected_duration_mask, [expected_duration_mask[-1]]]) return expected_duration_mask @@ -123,15 +122,11 @@ def fix_unexpected_edges(pd_times, ndevs=10, cycle=60, max_frame_offset=4): frame_interval = diff_mean / cycle bad_edges = np.where(expected_duration_mask == 0)[0] - bad_blocks = np.sort(np.unique(np.concatenate([ - [0], - np.where(np.diff(bad_edges) > 1)[0] + 1, - [len(bad_edges)] - ]))) + bad_blocks = np.sort(np.unique(np.concatenate([[0], np.where(np.diff(bad_edges) > 1)[0] + 1, [len(bad_edges)]]))) output_edges = [] for low, high in zip(bad_blocks[:-1], bad_blocks[1:]): - current_bad_edge_indices = bad_edges[low: high-1] + current_bad_edge_indices = bad_edges[low : high - 1] current_bad_edges = pd_times[current_bad_edge_indices] low_bound = pd_times[current_bad_edge_indices[0]] high_bound = pd_times[current_bad_edge_indices[-1] + 1] @@ -139,21 +134,20 @@ def fix_unexpected_edges(pd_times, ndevs=10, cycle=60, max_frame_offset=4): edges_missing = int(np.around((high_bound - low_bound) / diff_mean)) expected = np.linspace(low_bound, high_bound, edges_missing + 1) - distances = distance.cdist(current_bad_edges[:, None], - expected[:, None]) + distances = distance.cdist(current_bad_edges[:, None], expected[:, None]) distances = np.around(distances / frame_interval).astype(int) min_offsets = np.amin(distances, axis=0) min_offset_indices = np.argmin(distances, axis=0) - output_edges = np.concatenate([ - output_edges, - expected[min_offsets > max_frame_offset], - current_bad_edges[min_offset_indices[min_offsets <= - max_frame_offset]] - ]) + output_edges = np.concatenate( + [ + output_edges, + expected[min_offsets > max_frame_offset], + current_bad_edges[min_offset_indices[min_offsets <= max_frame_offset]], + ] + ) - return np.sort(np.concatenate([output_edges, - pd_times[expected_duration_mask > 0]])) + return np.sort(np.concatenate([output_edges, pd_times[expected_duration_mask > 0]])) def estimate_frame_duration(pd_times, cycle=60): @@ -165,14 +159,8 @@ def assign_to_last(index, starts, ends, frame_duration, irregularity, cycle): return starts, ends -def allocate_by_vsync(vs_diff, - index, - starts, - ends, - frame_duration, - irregularity, - cycle): - current_vs_diff = vs_diff[index * cycle: (index + 1) * cycle] +def allocate_by_vsync(vs_diff, index, starts, ends, frame_duration, irregularity, cycle): + current_vs_diff = vs_diff[index * cycle : (index + 1) * cycle] sign = np.sign(irregularity) if sign > 0: @@ -181,56 +169,42 @@ def allocate_by_vsync(vs_diff, vs_ind = np.argmin(current_vs_diff) ends[vs_ind:] += sign * frame_duration - starts[vs_ind + 1:] += sign * frame_duration + starts[vs_ind + 1 :] += sign * frame_duration return starts, ends -def compute_frame_times(photodiode_times, - frame_duration, - num_frames, - cycle, - irregular_interval_policy=assign_to_last): - +def compute_frame_times(photodiode_times, frame_duration, num_frames, cycle, irregular_interval_policy=assign_to_last): indices = np.arange(num_frames) starts = np.zeros(num_frames, dtype=float) ends = np.zeros(num_frames, dtype=float) num_intervals = len(photodiode_times) - 1 - for start_index, (start_time, end_time) in \ - enumerate(zip(photodiode_times[:-1], photodiode_times[1:])): - + for start_index, (start_time, end_time) in enumerate(zip(photodiode_times[:-1], photodiode_times[1:])): interval_duration = end_time - start_time - irregularity = \ - int(np.around((interval_duration) / frame_duration)) - cycle + irregularity = int(np.around((interval_duration) / frame_duration)) - cycle local_frame_duration = interval_duration / (cycle + irregularity) - durations = \ - np.zeros(cycle + - (start_index == num_intervals - 1)) + local_frame_duration + durations = np.zeros(cycle + (start_index == num_intervals - 1)) + local_frame_duration current_ends = np.cumsum(durations) + start_time current_starts = current_ends - durations while irregularity != 0: current_starts, current_ends = irregular_interval_policy( - start_index, - current_starts, - current_ends, - local_frame_duration, - irregularity, cycle + start_index, current_starts, current_ends, local_frame_duration, irregularity, cycle ) irregularity += -1 * np.sign(irregularity) early_frame = start_index * cycle - late_frame = \ - (start_index + 1) * cycle + (start_index == num_intervals - 1) + late_frame = (start_index + 1) * cycle + (start_index == num_intervals - 1) - remaining = starts[early_frame: late_frame].size - starts[early_frame: late_frame] = current_starts[:remaining] - ends[early_frame: late_frame] = current_ends[:remaining] + remaining = starts[early_frame:late_frame].size + starts[early_frame:late_frame] = current_starts[:remaining] + ends[early_frame:late_frame] = current_ends[:remaining] return indices, starts, ends + # # diff --git a/allensdk/brain_observatory/ecephys/stimulus_table/__main__.py b/allensdk/brain_observatory/ecephys/stimulus_table/__main__.py index 7a354a5a5d..e22c95a0f2 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_table/__main__.py +++ b/allensdk/brain_observatory/ecephys/stimulus_table/__main__.py @@ -2,9 +2,7 @@ import numpy as np -from allensdk.brain_observatory.argschema_utilities import \ - ArgSchemaParserPlus, \ - write_or_print_outputs +from allensdk.brain_observatory.argschema_utilities import ArgSchemaParserPlus, write_or_print_outputs from allensdk.brain_observatory.ecephys.file_io.ecephys_sync_dataset import ( EcephysSyncDataset, ) @@ -18,38 +16,31 @@ def build_stimulus_table( - stimulus_pkl_path, - sync_h5_path, - frame_time_strategy, - minimum_spontaneous_activity_duration, - extract_const_params_from_repr, - drop_const_params, - maximum_expected_spontanous_activity_duration, - stimulus_name_map, - column_name_map, - output_stimulus_table_path, - output_frame_times_path, - fail_on_negative_duration, - **kwargs + stimulus_pkl_path, + sync_h5_path, + frame_time_strategy, + minimum_spontaneous_activity_duration, + extract_const_params_from_repr, + drop_const_params, + maximum_expected_spontanous_activity_duration, + stimulus_name_map, + column_name_map, + output_stimulus_table_path, + output_frame_times_path, + fail_on_negative_duration, + **kwargs, ): stim_file = CamStimOnePickleStimFile.factory(stimulus_pkl_path) sync_dataset = EcephysSyncDataset.factory(sync_h5_path) frame_times = sync_dataset.extract_frame_times( - strategy=frame_time_strategy, - trim_discontiguous_frame_times=kwargs.get( - 'trim_discontiguous_frame_times', - True) - ) + strategy=frame_time_strategy, trim_discontiguous_frame_times=kwargs.get("trim_discontiguous_frame_times", True) + ) def seconds_to_frames(seconds): - return \ - (np.array(seconds) + stim_file.pre_blank_sec) * \ - stim_file.frames_per_second + return (np.array(seconds) + stim_file.pre_blank_sec) * stim_file.frames_per_second - minimum_spontaneous_activity_duration = ( - minimum_spontaneous_activity_duration / stim_file.frames_per_second - ) + minimum_spontaneous_activity_duration = minimum_spontaneous_activity_duration / stim_file.frames_per_second stimulus_tabler = functools.partial( ephys_pre_spikes.build_stimuluswise_table, @@ -62,15 +53,12 @@ def seconds_to_frames(seconds): duration_threshold=minimum_spontaneous_activity_duration, ) - stim_table_full = ephys_pre_spikes.create_stim_table( - stim_file.stimuli, stimulus_tabler, spon_tabler - ) + stim_table_full = ephys_pre_spikes.create_stim_table(stim_file.stimuli, stimulus_tabler, spon_tabler) stim_table_full = ephys_pre_spikes.apply_frame_times( stim_table_full, frame_times, stim_file.frames_per_second, True ) - output_validation.validate_epoch_durations( - stim_table_full, fail_on_negative_durations=fail_on_negative_duration) + output_validation.validate_epoch_durations(stim_table_full, fail_on_negative_durations=fail_on_negative_duration) output_validation.validate_max_spontaneous_epoch_duration( stim_table_full, maximum_expected_spontanous_activity_duration ) @@ -79,20 +67,14 @@ def seconds_to_frames(seconds): stim_table_full = naming_utilities.collapse_columns(stim_table_full) stim_table_full = naming_utilities.drop_empty_columns(stim_table_full) - stim_table_full = naming_utilities.standardize_movie_numbers( - stim_table_full) - stim_table_full = naming_utilities.add_number_to_shuffled_movie( - stim_table_full) - stim_table_full = naming_utilities.map_stimulus_names( - stim_table_full, stimulus_name_map - ) + stim_table_full = naming_utilities.standardize_movie_numbers(stim_table_full) + stim_table_full = naming_utilities.add_number_to_shuffled_movie(stim_table_full) + stim_table_full = naming_utilities.map_stimulus_names(stim_table_full, stimulus_name_map) print(stim_table_full.keys()) print(column_name_map) - stim_table_full = naming_utilities.map_column_names(stim_table_full, - column_name_map, - ignore_case=False) + stim_table_full = naming_utilities.map_column_names(stim_table_full, column_name_map, ignore_case=False) print(stim_table_full.keys()) @@ -105,9 +87,7 @@ def seconds_to_frames(seconds): def main(): - mod = ArgSchemaParserPlus( - schema_type=InputParameters, output_schema_type=OutputSchema - ) + mod = ArgSchemaParserPlus(schema_type=InputParameters, output_schema_type=OutputSchema) output = build_stimulus_table(**mod.args) write_or_print_outputs(data=output, parser=mod) diff --git a/allensdk/brain_observatory/ecephys/stimulus_table/_schemas.py b/allensdk/brain_observatory/ecephys/stimulus_table/_schemas.py index e7ff7228dd..b2ceea38d5 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_table/_schemas.py +++ b/allensdk/brain_observatory/ecephys/stimulus_table/_schemas.py @@ -7,7 +7,6 @@ default_stimulus_renames = { "": "spontaneous", - "natural_movie_1": "natural_movie_one", "natural_movie_3": "natural_movie_three", "Natural Images": "natural_scenes", @@ -15,18 +14,14 @@ "gabor_20_deg_250ms": "gabors", "drifting_gratings": "drifting_gratings", "static_gratings": "static_gratings", - "contrast_response": "drifting_gratings_contrast", - "Natural_Images_Shuffled": "natural_scenes_shuffled", "Natural_Images_Sequential": "natural_scenes_sequential", "natural_movie_1_more_repeats": "natural_movie_one", "natural_movie_shuffled": "natural_movie_one_shuffled", "motion_stimulus": "dot_motion", "drifting_gratings_more_repeats": "drifting_gratings_75_repeats", - "signal_noise_test_0_200_repeats": "test_movie_one", - "signal_noise_test_0": "test_movie_one", "signal_noise_test_1": "test_movie_two", "signal_noise_session_1": "dense_movie_one", @@ -40,33 +35,22 @@ default_column_renames = { "Contrast": "contrast", - "Ori": "orientation", + "Ori": "orientation", "SF": "spatial_frequency", "TF": "temporal_frequency", "Phase": "phase", "Color": "color", "Image": "frame", "Pos_x": "x_position", - "Pos_y": "y_position" + "Pos_y": "y_position", } class InputParameters(ArgSchema): - stimulus_pkl_path = String( - required=True, - help="""path to pkl file containing raw stimulus information""" - ) - sync_h5_path = String( - required=True, - help="""path to h5 file containing syncronization information""" - ) - output_stimulus_table_path = String( - required=True, - help="""the output stimulus table csv will be written here""" - ) - output_frame_times_path = String( - required=True, - help="""output all frame times here""") + stimulus_pkl_path = String(required=True, help="""path to pkl file containing raw stimulus information""") + sync_h5_path = String(required=True, help="""path to h5 file containing syncronization information""") + output_stimulus_table_path = String(required=True, help="""the output stimulus table csv will be written here""") + output_frame_times_path = String(required=True, help="""output all frame times here""") minimum_spontaneous_activity_duration = Float( default=sys.float_info.epsilon, help="""detected spontaneous activity sweeps will be rejected if @@ -85,16 +69,10 @@ class InputParameters(ArgSchema): which is preferred when reliable vsync times are available.""", ) stimulus_name_map = Dict( - keys=String(), - values=String(), - help="optionally rename stimuli", - default=default_stimulus_renames + keys=String(), values=String(), help="optionally rename stimuli", default=default_stimulus_renames ) column_name_map = Dict( - keys=String(), - values=String(), - help="optionally rename stimulus parameters", - default=default_column_renames + keys=String(), values=String(), help="optionally rename stimulus parameters", default=default_column_renames ) extract_const_params_from_repr = Bool(default=True) drop_const_params = List( @@ -106,22 +84,24 @@ class InputParameters(ArgSchema): fail_on_negative_duration = Bool( default=False, help="""Determine if the module should fail if a - stimulus epoch has a negative duration.""" + stimulus epoch has a negative duration.""", ) trim_discontiguous_frame_times = Bool( default=True, help="""set to False if stimulus was shown in chunks, - and discontiguous vsyncs are expected""" + and discontiguous vsyncs are expected""", ) class OutputSchema(DefaultSchema): input_parameters = Nested( InputParameters, - description=("Input parameters the module " "was run with"), + description=("Input parameters the module was run with"), required=True, ) output_path = String(help="Path to output csv file") output_frame_times_path = String(help="output all frame times here") + + # diff --git a/allensdk/brain_observatory/ecephys/stimulus_table/ecephys_visual_coding_time_alignment.ipynb b/allensdk/brain_observatory/ecephys/stimulus_table/ecephys_visual_coding_time_alignment.ipynb index a80d32c384..ebeeb86a59 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_table/ecephys_visual_coding_time_alignment.ipynb +++ b/allensdk/brain_observatory/ecephys/stimulus_table/ecephys_visual_coding_time_alignment.ipynb @@ -123,7 +123,7 @@ "\n", "num_pd_samples = 8\n", "step_vals = np.vstack([np.zeros(num_pd_samples), np.ones(num_pd_samples)]).T.flatten()\n", - "ax.step(ophys_pd[:num_pd_samples*2], step_vals)\n", + "ax.step(ophys_pd[: num_pd_samples * 2], step_vals)\n", "\n", "# validate that the edge directions are oriented correctly\n", "ax.vlines(ophys_pd_falling[:num_pd_samples], ymin=-0.05, ymax=1.05, color=\"black\")\n", @@ -175,7 +175,7 @@ "ax.plot(ophys_pd[1:], pd_diff)\n", "\n", "for ii in range(30):\n", - " ax.hlines([med_diff + frame_dur_exp * ii], xmin=ophys_pd[1]-20, xmax=ophys_pd[-1]+20)\n", + " ax.hlines([med_diff + frame_dur_exp * ii], xmin=ophys_pd[1] - 20, xmax=ophys_pd[-1] + 20)\n", "\n", "plt.ylim(top=1.5, bottom=0)\n", "\n", @@ -243,7 +243,7 @@ "ax.plot(ophys_pd[1:], pd_diff)\n", "\n", "for ii in range(30):\n", - " ax.hlines([med_diff + frame_dur_exp * ii], xmin=ophys_pd[1]-20, xmax=ophys_pd[-1]+20)\n", + " ax.hlines([med_diff + frame_dur_exp * ii], xmin=ophys_pd[1] - 20, xmax=ophys_pd[-1] + 20)\n", "\n", "plt.ylim(top=1.5)\n", "\n", @@ -311,10 +311,10 @@ "ax.plot(ophys_vs[1:], vs_diff)\n", "\n", "for ii in range(1, 31):\n", - " ax.hlines([frame_dur_exp * ii], xmin=ophys_pd[1]-20, xmax=ophys_pd[-1]+20)\n", + " ax.hlines([frame_dur_exp * ii], xmin=ophys_pd[1] - 20, xmax=ophys_pd[-1] + 20)\n", "\n", "plt.ylim(top=0.5)\n", - " \n", + "\n", "plt.show()" ] }, @@ -368,7 +368,9 @@ "metadata": {}, "outputs": [], "source": [ - "ecephys_dir = Path(\"/allen/programs/braintv/production/neuralcoding/prod56/specimen_719828690/ecephys_session_754312389\")\n", + "ecephys_dir = Path(\n", + " \"/allen/programs/braintv/production/neuralcoding/prod56/specimen_719828690/ecephys_session_754312389\"\n", + ")\n", "ecephys_sync_path = ecephys_dir / Path(\"754312389_404570_20180917.sync\")\n", "\n", "ecephys_sync = Dataset(ecephys_sync_path)\n", @@ -401,10 +403,10 @@ "plt.plot(ecephys_vs[1:], np.diff(ecephys_vs))\n", "\n", "for ii in range(1, 10):\n", - " ax.hlines([frame_dur_exp * ii], xmin=ecephys_vs[1]-20, xmax=ecephys_vs[-1]+20)\n", + " ax.hlines([frame_dur_exp * ii], xmin=ecephys_vs[1] - 20, xmax=ecephys_vs[-1] + 20)\n", "\n", "plt.ylim(top=0.5)\n", - " \n", + "\n", "plt.show()" ] }, @@ -445,7 +447,9 @@ "plt.plot(ecephys_pd_trim[1:], ecephys_pd_trim_diff)\n", "\n", "for ii in range(-2, 20):\n", - " ax.hlines([ecephys_pd_trim_diff_med + frame_dur_exp * ii], xmin=ecephys_pd_trim[1]-20, xmax=ecephys_pd_trim[-1]+20)\n", + " ax.hlines(\n", + " [ecephys_pd_trim_diff_med + frame_dur_exp * ii], xmin=ecephys_pd_trim[1] - 20, xmax=ecephys_pd_trim[-1] + 20\n", + " )\n", "\n", "plt.show()" ] @@ -528,14 +532,10 @@ "plt.plot(messy_pd_trim[1:], messy_pd_trim_diff)\n", "\n", "for ii in range(-10, 11):\n", - " ax.hlines(\n", - " [messy_pd_trim_diff_med + frame_dur_exp * ii], \n", - " xmin=messy_pd_trim[1] - 20, \n", - " xmax=messy_pd_trim[-1] + 20\n", - " )\n", + " ax.hlines([messy_pd_trim_diff_med + frame_dur_exp * ii], xmin=messy_pd_trim[1] - 20, xmax=messy_pd_trim[-1] + 20)\n", "\n", "plt.ylim([0.8, 1.2])\n", - " \n", + "\n", "plt.show()" ] }, diff --git a/allensdk/brain_observatory/ecephys/stimulus_table/ephys_pre_spikes.py b/allensdk/brain_observatory/ecephys/stimulus_table/ephys_pre_spikes.py index e8bfdf2cdb..564a5bfb06 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_table/ephys_pre_spikes.py +++ b/allensdk/brain_observatory/ecephys/stimulus_table/ephys_pre_spikes.py @@ -14,7 +14,7 @@ def create_stim_table( block_key="stimulus_block", index_key="stimulus_index", ): - """ Build a full stimulus table + """Build a full stimulus table Parameters ---------- @@ -50,8 +50,7 @@ def create_stim_table( stimulus_tables.extend(current_tables) - stimulus_tables = sorted(stimulus_tables, - key=lambda df: min(df[sort_key].values)) + stimulus_tables = sorted(stimulus_tables, key=lambda df: min(df[sort_key].values)) for ii, stim_table in enumerate(stimulus_tables): stim_table[block_key] = ii @@ -64,10 +63,8 @@ def create_stim_table( return stim_table_full -def make_spontaneous_activity_tables( - stimulus_tables, start_key="Start", end_key="End", duration_threshold=0.0 -): - """ Fills in frame gaps in a set of stimulus tables. Suitable for use as +def make_spontaneous_activity_tables(stimulus_tables, start_key="Start", end_key="End", duration_threshold=0.0): + """Fills in frame gaps in a set of stimulus tables. Suitable for use as the spontaneous_activity_tabler in create_stim_table. Parameters @@ -105,11 +102,7 @@ def make_spontaneous_activity_tables( spon_sweeps = pd.DataFrame({start_key: spon_start, end_key: spon_end}) if duration_threshold is not None: - spon_sweeps = spon_sweeps[ - np.fabs(spon_sweeps[start_key] - - spon_sweeps[end_key]) - > duration_threshold - ] + spon_sweeps = spon_sweeps[np.fabs(spon_sweeps[start_key] - spon_sweeps[end_key]) > duration_threshold] spon_sweeps.reset_index(drop=True, inplace=True) return [spon_sweeps] @@ -122,7 +115,7 @@ def apply_frame_times( extra_frame_time=False, map_columns=("Start", "End"), ): - """ Converts sweep times from frames to seconds. + """Converts sweep times from frames to seconds. Parameters ---------- @@ -155,13 +148,10 @@ def apply_frame_times( if extra_frame_time is True and frames_per_second is not None: extra_frame_time = 1.0 / frames_per_second if extra_frame_time is not False: - frame_times = np.append(frame_times, frame_times[-1] - + extra_frame_time) + frame_times = np.append(frame_times, frame_times[-1] + extra_frame_time) for column in map_columns: - stimulus_table[column] = frame_times[ - np.around(stimulus_table[column]).astype(int) - ] + stimulus_table[column] = frame_times[np.around(stimulus_table[column]).astype(int)] return stimulus_table @@ -174,7 +164,7 @@ def apply_display_sequence( diff_key="dif", block_key="stimulus_block", ): - """ Adjust raw sweep frames for a stimulus based on the display sequence + """Adjust raw sweep frames for a stimulus based on the display sequence for that stimulus. Parameters @@ -204,33 +194,22 @@ def apply_display_sequence( sweep_frames_table = sweep_frames_table.copy() if block_key not in sweep_frames_table.columns.values: - sweep_frames_table[block_key] = np.zeros( - (sweep_frames_table.shape[0]), dtype=int - ) + sweep_frames_table[block_key] = np.zeros((sweep_frames_table.shape[0]), dtype=int) - sweep_frames_table[diff_key] = ( - sweep_frames_table[end_key] - sweep_frames_table[start_key] - ) + sweep_frames_table[diff_key] = sweep_frames_table[end_key] - sweep_frames_table[start_key] sweep_frames_table[start_key] += frame_display_sequence[0, 0] for seg in range(len(frame_display_sequence) - 1): - match_inds = sweep_frames_table[start_key] \ - >= frame_display_sequence[seg, 1] + match_inds = sweep_frames_table[start_key] >= frame_display_sequence[seg, 1] sweep_frames_table.loc[match_inds, start_key] += ( frame_display_sequence[seg + 1, 0] - frame_display_sequence[seg, 1] ) sweep_frames_table.loc[match_inds, block_key] = seg + 1 - sweep_frames_table[end_key] = ( - sweep_frames_table[start_key] + sweep_frames_table[diff_key] - ) - sweep_frames_table = sweep_frames_table[ - sweep_frames_table[end_key] <= frame_display_sequence[-1, 1] - ] - sweep_frames_table = sweep_frames_table[ - sweep_frames_table[start_key] <= frame_display_sequence[-1, 1] - ] + sweep_frames_table[end_key] = sweep_frames_table[start_key] + sweep_frames_table[diff_key] + sweep_frames_table = sweep_frames_table[sweep_frames_table[end_key] <= frame_display_sequence[-1, 1]] + sweep_frames_table = sweep_frames_table[sweep_frames_table[start_key] <= frame_display_sequence[-1, 1]] sweep_frames_table.drop(diff_key, inplace=True, axis=1) return sweep_frames_table @@ -271,7 +250,7 @@ def build_stimuluswise_table( extract_const_params_from_repr=False, drop_const_params=spe.DROP_PARAMS, ): - """ Construct a table of sweeps, including their times on the + """Construct a table of sweeps, including their times on the experiment-global clock and the values of each relevant parameter. Parameters @@ -323,14 +302,9 @@ def build_stimuluswise_table( frame_display_sequence = seconds_to_frames(stimulus["display_sequence"]) - sweep_frames_table = pd.DataFrame( - stimulus["sweep_frames"], columns=(start_key, end_key) - ) - sweep_frames_table[block_key] = np.zeros([sweep_frames_table.shape[0]], - dtype=int) - sweep_frames_table = apply_display_sequence( - sweep_frames_table, frame_display_sequence, block_key=block_key - ) + sweep_frames_table = pd.DataFrame(stimulus["sweep_frames"], columns=(start_key, end_key)) + sweep_frames_table[block_key] = np.zeros([sweep_frames_table.shape[0]], dtype=int) + sweep_frames_table = apply_display_sequence(sweep_frames_table, frame_display_sequence, block_key=block_key) stim_table = pd.DataFrame( { @@ -359,19 +333,15 @@ def build_stimuluswise_table( ) if extract_const_params_from_repr: - const_params = spe.parse_stim_repr( - stimulus["stim"], drop_params=drop_const_params - ) + const_params = spe.parse_stim_repr(stimulus["stim"], drop_params=drop_const_params) existing_columns = set(stim_table.columns) for const_param_key, const_param_value in const_params.items(): - existing_cap = const_param_key.capitalize() in existing_columns existing_upper = const_param_key.upper() in existing_columns existing = const_param_key in existing_columns if not (existing_cap or existing_upper or existing): - stim_table[const_param_key] = [const_param_value] * \ - stim_table.shape[0] + stim_table[const_param_key] = [const_param_value] * stim_table.shape[0] else: logging.info( f"""found sweep_param named: {const_param_key}, @@ -380,14 +350,13 @@ def build_stimuluswise_table( ) unique_indices = np.unique(stim_table[block_key].values) - output = [stim_table.loc[stim_table[block_key] == ii, :] - for ii in unique_indices] + output = [stim_table.loc[stim_table[block_key] == ii, :] for ii in unique_indices] return output def split_column(table, column, new_columns, drop_old=True): - """ Divides a dataframe column into multiple columns. + """Divides a dataframe column into multiple columns. Parameters ---------- @@ -430,7 +399,7 @@ def assign_sweep_values( drop=True, tmp_suffix="_stimtable_todrop", ): - """ Left joins a stimulus table to a sweep table in order to associate + """Left joins a stimulus table to a sweep table in order to associate epochs in time with stimulus characteristics. Parameters diff --git a/allensdk/brain_observatory/ecephys/stimulus_table/naming_utilities.py b/allensdk/brain_observatory/ecephys/stimulus_table/naming_utilities.py index adf75b8f07..8ff6c96651 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_table/naming_utilities.py +++ b/allensdk/brain_observatory/ecephys/stimulus_table/naming_utilities.py @@ -3,9 +3,7 @@ import numpy as np -GABOR_DIAMETER_RE = re.compile( - r"gabor_(\d*\.{0,1}\d*)_{0,1}deg(?:_\d+ms){0,1}" -) +GABOR_DIAMETER_RE = re.compile(r"gabor_(\d*\.{0,1}\d*)_{0,1}deg(?:_\d+ms){0,1}") GENERIC_MOVIE_RE = re.compile( r"natural_movie_" @@ -81,17 +79,12 @@ def add_number_to_shuffled_movie( return table table = table.copy() - table[tmp_colname] = table[stim_colname].str.extract( - natural_movie_re, expand=True - )["number"] + table[tmp_colname] = table[stim_colname].str.extract(natural_movie_re, expand=True)["number"] - unique_numbers = [ - item for item in table[tmp_colname].dropna(inplace=False).unique() - ] + unique_numbers = [item for item in table[tmp_colname].dropna(inplace=False).unique()] if len(unique_numbers) != 1: raise ValueError( - "unable to uniquely determine a movie number for this session. " - + f"Candidates: {unique_numbers}" + "unable to uniquely determine a movie number for this session. " + f"Candidates: {unique_numbers}" ) movie_number = unique_numbers[0] @@ -150,9 +143,9 @@ def replace(match_obj): warnings.filterwarnings("ignore", category=UserWarning) movie_rows = table[stim_colname].str.contains(movie_re, na=False) - table.loc[movie_rows, stim_colname] = table.loc[ - movie_rows, stim_colname - ].str.replace(numeral_re, replace, regex=True) + table.loc[movie_rows, stim_colname] = table.loc[movie_rows, stim_colname].str.replace( + numeral_re, replace, regex=True + ) return table @@ -176,9 +169,7 @@ def map_stimulus_names(table, name_map=None, stim_colname="stimulus_name"): name_map[np.nan] = "spontaneous" - table[stim_colname] = table[stim_colname].replace( - to_replace=name_map, inplace=False - ) + table[stim_colname] = table[stim_colname].replace(to_replace=name_map, inplace=False) name_map.pop(np.nan) diff --git a/allensdk/brain_observatory/ecephys/stimulus_table/output_validation.py b/allensdk/brain_observatory/ecephys/stimulus_table/output_validation.py index b1bcec1b39..075cc8bd76 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_table/output_validation.py +++ b/allensdk/brain_observatory/ecephys/stimulus_table/output_validation.py @@ -2,10 +2,7 @@ import warnings -def validate_epoch_durations(table, - start_key="Start", - end_key="End", - fail_on_negative_durations=False): +def validate_epoch_durations(table, start_key="Start", end_key="End", fail_on_negative_durations=False): durations = table[end_key] - table[start_key] min_duration_index = durations.idxmin() min_duration = durations[min_duration_index] @@ -39,17 +36,14 @@ def validate_max_spontaneous_epoch_duration( end_key="End", ): if get_spontanous_epochs is None: + def get_spontanous_epochs(table): table[np.isnan(table[index_key])] spontaneous_epochs = get_spontanous_epochs(table) if spontaneous_epochs is not None: - - durations = ( - spontaneous_epochs[end_key].values - - spontaneous_epochs[start_key].values - ) + durations = spontaneous_epochs[end_key].values - spontaneous_epochs[start_key].values try: if np.amax(durations) > max_duration: @@ -57,6 +51,6 @@ def get_spontanous_epochs(table): f"""there is a spontaneous activity duration longer than {max_duration}""", UserWarning, - ) + ) except ValueError: warnings.warn("No spontaneous intervals detected.", UserWarning) diff --git a/allensdk/brain_observatory/ecephys/stimulus_table/stimulus_parameter_extraction.py b/allensdk/brain_observatory/ecephys/stimulus_table/stimulus_parameter_extraction.py index 6e30faf8d1..ffdaf56fec 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_table/stimulus_parameter_extraction.py +++ b/allensdk/brain_observatory/ecephys/stimulus_table/stimulus_parameter_extraction.py @@ -21,7 +21,7 @@ def parse_stim_repr( array_re=ARRAY_RE, raise_on_unrecognized=False, ): - """ Read the string representation of a psychopy stimulus and extract + """Read the string representation of a psychopy stimulus and extract stimulus parameters. Parameters @@ -39,9 +39,7 @@ def parse_stim_repr( """ - stim_params = extract_const_params_from_stim_repr( - stim_repr, repr_params_re=repr_params_re, array_re=array_re - ) + stim_params = extract_const_params_from_stim_repr(stim_repr, repr_params_re=repr_params_re, array_re=array_re) for drop_param in drop_params: if drop_param in stim_params: @@ -58,9 +56,7 @@ def extract_stim_class_from_repr(stim_repr, repr_class_re=REPR_CLASS_RE): return match["class_name"] -def extract_const_params_from_stim_repr( - stim_repr, repr_params_re=REPR_PARAMS_RE, array_re=ARRAY_RE -): +def extract_const_params_from_stim_repr(stim_repr, repr_params_re=REPR_PARAMS_RE, array_re=ARRAY_RE): """Parameters which are not set as sweep_params in the stimulus script (usually because they are not varied during the course of the session) are not output in an easily machine-readable format. This function @@ -90,7 +86,6 @@ def extract_const_params_from_stim_repr( k, v = match.split("=") if k not in repr_params: - m = array_re.match(v) if m is not None: v = m["contents"] diff --git a/allensdk/brain_observatory/ecephys/stimulus_table/visualization/view_blocks.py b/allensdk/brain_observatory/ecephys/stimulus_table/visualization/view_blocks.py index 882a378918..96fd45f412 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_table/visualization/view_blocks.py +++ b/allensdk/brain_observatory/ecephys/stimulus_table/visualization/view_blocks.py @@ -7,7 +7,6 @@ def build_colormap(table, existing_map={}, base_colors=sns.color_palette("pastel")): - colormap = {} unique_names = table["stimulus_name"].unique() @@ -36,11 +35,7 @@ def get_blocks(table): recorded_blocks = np.unique(block["stimulus_block"].values) if len(recorded_blocks) > 1: - raise ValueError( - "expected one recorded block per block, found: {}".format( - recorded_blocks - ) - ) + raise ValueError("expected one recorded block per block, found: {}".format(recorded_blocks)) else: pass @@ -76,7 +71,6 @@ def plot_blocks(blocks, colormap): labels = [] for block in blocks: - handle = ax.axvspan( block["start"], block["end"], @@ -100,7 +94,6 @@ def plot_blocks(blocks, colormap): def main(table_csv_path): - table = pd.read_csv(table_csv_path) colormap = build_colormap(table) @@ -112,9 +105,7 @@ def main(table_csv_path): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - "table_csv_path", type=str, help="filesystem path to stimulus table csv" - ) + parser.add_argument("table_csv_path", type=str, help="filesystem path to stimulus table csv") args = parser.parse_args() main(args.table_csv_path) diff --git a/allensdk/brain_observatory/ecephys/utils.py b/allensdk/brain_observatory/ecephys/utils.py index 53124fd88f..b5ae6edc1f 100644 --- a/allensdk/brain_observatory/ecephys/utils.py +++ b/allensdk/brain_observatory/ecephys/utils.py @@ -29,9 +29,7 @@ def group_1d_by_unit(data, data_unit_map, local_to_global_unit_map=None): if local_to_global_unit_map is not None: if local_unit not in local_to_global_unit_map: - logging.warning( - f"unable to find unit at local position {local_unit}" - ) + logging.warning(f"unable to find unit at local position {local_unit}") continue global_id = local_to_global_unit_map[local_unit] output[global_id] = current @@ -41,11 +39,7 @@ def group_1d_by_unit(data, data_unit_map, local_to_global_unit_map=None): return output -def scale_amplitudes(spike_amplitudes, - templates, - spike_templates, - scale_factor=1.0): - +def scale_amplitudes(spike_amplitudes, templates, spike_templates, scale_factor=1.0): template_full_amplitudes = templates.max(axis=1) - templates.min(axis=1) template_amplitudes = template_full_amplitudes.max(axis=1) @@ -68,9 +62,7 @@ def clobbering_merge(to_df, from_df, **kwargs): return pd.merge(to_df, from_df, **kwargs) -def strip_substructure_acronym( - acronym: Optional[Union[str, list, float]] -) -> Optional[Union[str, list]]: +def strip_substructure_acronym(acronym: Optional[Union[str, list, float]]) -> Optional[Union[str, list]]: """ Sanitize a structure acronym or a list of structure acronyms by removing the substructure (e.g. DG-mo becomes DG). @@ -91,7 +83,7 @@ def strip_substructure_acronym( acronym = None if isinstance(acronym, str): - return acronym.split('-')[0] + return acronym.split("-")[0] elif isinstance(acronym, list): new_acronym = set() @@ -106,6 +98,4 @@ def strip_substructure_acronym( elif acronym is None: return None else: - raise RuntimeError( - "acronym must be a list or a str or None; you gave " - f"{acronym} which is a {type(acronym)}") + raise RuntimeError(f"acronym must be a list or a str or None; you gave {acronym} which is a {type(acronym)}") diff --git a/allensdk/brain_observatory/ecephys/visualization/__init__.py b/allensdk/brain_observatory/ecephys/visualization/__init__.py index ecff26ed01..a13dd846cd 100644 --- a/allensdk/brain_observatory/ecephys/visualization/__init__.py +++ b/allensdk/brain_observatory/ecephys/visualization/__init__.py @@ -3,8 +3,8 @@ import numpy as np -def plot_mean_waveforms(mean_waveforms, unit_ids, peak_channels): # pragma: no cover - ''' Utility for plotting mean waveforms on each unit's peak channel +def plot_mean_waveforms(mean_waveforms, unit_ids, peak_channels): # pragma: no cover + """Utility for plotting mean waveforms on each unit's peak channel Parameters ---------- @@ -13,50 +13,47 @@ def plot_mean_waveforms(mean_waveforms, unit_ids, peak_channels): # pragma: no c unit_ids : array-like unique integer identifiers for units to be included - ''' + """ fig, ax = plt.subplots(figsize=(10, 10)) for uid in unit_ids: wf = mean_waveforms[uid] - ax.plot(wf.loc[{'channel_id': peak_channels[uid]}]) + ax.plot(wf.loc[{"channel_id": peak_channels[uid]}]) ax.legend(unit_ids) - ax.set_ylabel('membrane potential (uV)', fontsize=16) - ax.set_xlabel('time (s)', fontsize=16) + ax.set_ylabel("membrane potential (uV)", fontsize=16) + ax.set_xlabel("time (s)", fontsize=16) - ax.set_xticks(np.arange(0, len(wf['time']), 20)) - ax.set_xticklabels([f'{float(ii):1.4f}' for ii in wf['time'][::20]], rotation=45) + ax.set_xticks(np.arange(0, len(wf["time"]), 20)) + ax.set_xticklabels([f"{float(ii):1.4f}" for ii in wf["time"][::20]], rotation=45) return fig - + def plot_spike_counts( - data_array, + data_array, time_coords, - cbar_label, - title, - xlabel='time relative to stimulus onset (s)', - ylabel='unit', - xtick_step=20 -): # pragma: no cover - '''Utility for making a simple spike counts plot. + cbar_label, + title, + xlabel="time relative to stimulus onset (s)", + ylabel="unit", + xtick_step=20, +): # pragma: no cover + """Utility for making a simple spike counts plot. Parameters ---------- data_array : xarray.DataArray 2D data array unitwise values per time bin. See EcephysSession.sweepwise_spike_counts - ''' - + """ + fig, ax = plt.subplots(figsize=(12, 12)) div = make_axes_locatable(ax) cbar_axis = div.append_axes("right", 0.2, pad=0.05) - img = ax.imshow( - data_array.T, - interpolation='none' - ) + img = ax.imshow(data_array.T, interpolation="none") plt.colorbar(img, cax=cbar_axis) cbar_axis.set_ylabel(cbar_label, fontsize=16) @@ -66,7 +63,7 @@ def plot_spike_counts( reltime = np.array(time_coords) ax.set_xticks(np.arange(0, len(reltime), xtick_step)) - ax.set_xticklabels([f'{mp:1.3f}' for mp in reltime[::xtick_step]], rotation=45) + ax.set_xticklabels([f"{mp:1.3f}" for mp in reltime[::xtick_step]], rotation=45) ax.set_xlabel(xlabel, fontsize=16) ax.set_title(title, fontsize=20) @@ -93,19 +90,18 @@ def __call__(self, gb): self.ii += 1 -def raster_plot(spike_times, figsize=(8,8), cmap=plt.cm.tab20, title='spike raster', cycle_colors=False): - +def raster_plot(spike_times, figsize=(8, 8), cmap=plt.cm.tab20, title="spike raster", cycle_colors=False): fig, ax = plt.subplots(figsize=figsize) - plotter = _VlPlotter(ax, num_objects=len(spike_times['unit_id'].unique()), cmap=cmap, cycle_colors=cycle_colors) + plotter = _VlPlotter(ax, num_objects=len(spike_times["unit_id"].unique()), cmap=cmap, cycle_colors=cycle_colors) # aggregate is called on each column, so pass only one (eg the stimulus_presentation_id) # to plot each unit once - spike_times[['stimulus_presentation_id', 'unit_id']].groupby('unit_id').agg(plotter) - - ax.set_xlabel('time (s)', fontsize=16) - ax.set_ylabel('unit', fontsize=16) + spike_times[["stimulus_presentation_id", "unit_id"]].groupby("unit_id").agg(plotter) + + ax.set_xlabel("time (s)", fontsize=16) + ax.set_ylabel("unit", fontsize=16) ax.set_title(title, fontsize=20) - + plt.yticks([]) - plt.axis('tight') - + plt.axis("tight") + return fig diff --git a/allensdk/brain_observatory/ecephys/write_nwb/__main__.py b/allensdk/brain_observatory/ecephys/write_nwb/__main__.py index fe3f6dd7f6..fe09999011 100644 --- a/allensdk/brain_observatory/ecephys/write_nwb/__main__.py +++ b/allensdk/brain_observatory/ecephys/write_nwb/__main__.py @@ -1,4 +1,5 @@ """Module for writing NWB files for the VCN project""" + import logging import sys from typing import Any, Dict, List, Tuple @@ -11,16 +12,12 @@ import pandas as pd import numpy as np -from allensdk.brain_observatory.behavior.data_objects.stimuli.presentations \ - import \ - Presentations -from allensdk.brain_observatory.ecephys._behavior_ecephys_metadata import \ - BehaviorEcephysMetadata +from allensdk.brain_observatory.behavior.data_objects.stimuli.presentations import Presentations +from allensdk.brain_observatory.ecephys._behavior_ecephys_metadata import BehaviorEcephysMetadata from allensdk.brain_observatory.ecephys._probe import Probe from allensdk.brain_observatory.ecephys.optotagging import OptotaggingTable from allensdk.brain_observatory.ecephys.probes import Probes -from allensdk.brain_observatory.ecephys.write_nwb.schemas import \ - VCNInputSchema, OutputSchema +from allensdk.brain_observatory.ecephys.write_nwb.schemas import VCNInputSchema, OutputSchema from allensdk.config.manifest import Manifest from allensdk.brain_observatory.nwb import ( @@ -30,15 +27,11 @@ read_eye_gaze_mappings, add_eye_tracking_ellipse_fit_data_to_nwbfile, add_eye_gaze_mapping_data_to_nwbfile, - eye_tracking_data_is_valid -) -from allensdk.brain_observatory.argschema_utilities import ( - optional_lims_inputs + eye_tracking_data_is_valid, ) +from allensdk.brain_observatory.argschema_utilities import optional_lims_inputs -from allensdk.brain_observatory.ecephys.nwb import ( - EcephysSpecimen, - EcephysEyeTrackingRigMetadata) +from allensdk.brain_observatory.ecephys.nwb import EcephysSpecimen, EcephysEyeTrackingRigMetadata from allensdk.brain_observatory.sync_dataset import Dataset import allensdk.brain_observatory.sync_utilities as su @@ -46,11 +39,7 @@ STIM_TABLE_RENAMES_MAP = {"Start": "start_time", "End": "stop_time"} -def get_inputs_from_lims(host, - ecephys_session_id, - output_root, - job_queue, - strategy): +def get_inputs_from_lims(host, ecephys_session_id, output_root, job_queue, strategy): """ This is a development / testing utility for running this module from the Allen Institute for Brain Science's Laboratory Information Management @@ -74,9 +63,11 @@ def get_inputs_from_lims(host, """ - uri = f"{host}/input_jsons?object_id={ecephys_session_id}" + \ - f"&object_class=EcephysSession&strategy_class={strategy}" + \ - f"&job_queue_name={job_queue}&output_directory={output_root}" + uri = ( + f"{host}/input_jsons?object_id={ecephys_session_id}" + + f"&object_class=EcephysSession&strategy_class={strategy}" + + f"&job_queue_name={job_queue}&output_directory={output_root}" + ) response = requests.get(uri) data = response.json() @@ -87,10 +78,10 @@ def get_inputs_from_lims(host, return data -def read_stimulus_table(path: str, - column_renames_map: Dict[str, str] = None, - columns_to_drop: List[str] = None) -> pd.DataFrame: - """ Loads from a CSV on disk the stimulus table for this session. +def read_stimulus_table( + path: str, column_renames_map: Dict[str, str] = None, columns_to_drop: List[str] = None +) -> pd.DataFrame: + """Loads from a CSV on disk the stimulus table for this session. Optionally renames columns to match NWB epoch specifications. Parameters @@ -122,8 +113,7 @@ def read_stimulus_table(path: str, raise IOError(f"unrecognized stimulus table extension: {ext}") if columns_to_drop: - stimulus_table = stimulus_table.drop(errors='ignore', - columns=columns_to_drop) + stimulus_table = stimulus_table.drop(errors="ignore", columns=columns_to_drop) return stimulus_table.rename(columns=column_renames_map, index={}) @@ -148,7 +138,7 @@ def add_metadata_to_nwbfile(nwbfile, input_metadata): def read_running_speed(path): - """ Reads running speed data and timestamps into a RunningSpeed named tuple + """Reads running speed data and timestamps into a RunningSpeed named tuple Parameters ---------- @@ -164,18 +154,10 @@ def read_running_speed(path): """ - return ( - pd.read_hdf(path, key="running_speed"), - pd.read_hdf(path, key="raw_data") - ) + return (pd.read_hdf(path, key="running_speed"), pd.read_hdf(path, key="raw_data")) -DEFAULT_RUNNING_SPEED_UNITS = { - "velocity": "cm/s", - "vin": "V", - "vsig": "V", - "rotation": "radians" -} +DEFAULT_RUNNING_SPEED_UNITS = {"velocity": "cm/s", "vin": "V", "vsig": "V", "rotation": "radians"} def add_running_speed_to_nwbfile(nwbfile, running_speed, units=None): @@ -189,7 +171,7 @@ def add_running_speed_to_nwbfile(nwbfile, running_speed, units=None): name="running_speed", timestamps=running_speed["start_time"].values, data=running_speed["velocity"].values, - unit=units["velocity"] + unit=units["velocity"], ) # Create an 'empty' timeseries that only stores end times @@ -198,14 +180,14 @@ def add_running_speed_to_nwbfile(nwbfile, running_speed, units=None): name="running_speed_end_times", data=np.full(running_speed["velocity"].shape, np.nan), timestamps=running_speed["end_time"].values, - unit=units["velocity"] + unit=units["velocity"], ) rotation_timeseries = pynwb.base.TimeSeries( name="running_wheel_rotation", timestamps=running_speed_timeseries, data=running_speed["net_rotation"].values, - unit=units["rotation"] + unit=units["rotation"], ) running_mod.add_data_interface(running_speed_timeseries) @@ -223,21 +205,21 @@ def add_raw_running_data_to_nwbfile(nwbfile, raw_running_data, units=None): name="raw_running_wheel_rotation", timestamps=np.array(raw_running_data["frame_time"]), data=raw_running_data["dx"].values, - unit=units["rotation"] + unit=units["rotation"], ) vsig_ts = pynwb.base.TimeSeries( name="running_wheel_signal_voltage", timestamps=raw_rotation_timeseries, data=raw_running_data["vsig"].values, - unit=units["vsig"] + unit=units["vsig"], ) vin_ts = pynwb.base.TimeSeries( name="running_wheel_supply_voltage", timestamps=raw_rotation_timeseries, data=raw_running_data["vin"].values, - unit=units["vin"] + unit=units["vin"], ) nwbfile.add_acquisition(raw_rotation_timeseries) @@ -247,40 +229,36 @@ def add_raw_running_data_to_nwbfile(nwbfile, raw_running_data, units=None): return nwbfile -def write_probe_lfp_file(session_id, session_metadata, session_start_time, - log_level, probe_meta): - """ Writes LFP data (and associated channel information) for one +def write_probe_lfp_file(session_id, session_metadata, session_start_time, log_level, probe_meta): + """Writes LFP data (and associated channel information) for one probe to a standalone nwb file """ - logging.getLogger('').setLevel(log_level) + logging.getLogger("").setLevel(log_level) logging.info(f"writing lfp file for probe {probe_meta['id']}") probe = Probe.from_json(probe=probe_meta) nwbfile = probe.add_lfp_to_nwb( session_id=session_id, session_start_time=session_start_time, - session_metadata=BehaviorEcephysMetadata.from_json( - dict_repr=session_metadata) + session_metadata=BehaviorEcephysMetadata.from_json(dict_repr=session_metadata), ) - with pynwb.NWBHDF5IO(probe_meta['lfp']['output_path'], 'w') as lfp_writer: + with pynwb.NWBHDF5IO(probe_meta["lfp"]["output_path"], "w") as lfp_writer: logging.info(f"writing lfp file to {probe_meta['lfp']['output_path']}") lfp_writer.write(nwbfile, cache_spec=True) - return { - "id": probe_meta["id"], - "nwb_path": probe_meta["lfp"]["output_path"]} - + return {"id": probe_meta["id"], "nwb_path": probe_meta["lfp"]["output_path"]} -def write_probewise_lfp_files(probes, session_id, session_metadata, - session_start_time, pool_size=3): +def write_probewise_lfp_files(probes, session_id, session_metadata, session_start_time, pool_size=3): output_paths = [] pool = mp.Pool(processes=pool_size) - write = partial(write_probe_lfp_file, - session_id, - session_metadata, - session_start_time, - logging.getLogger("").getEffectiveLevel()) + write = partial( + write_probe_lfp_file, + session_id, + session_metadata, + session_start_time, + logging.getLogger("").getEffectiveLevel(), + ) for pout in pool.imap_unordered(write, probes): output_paths.append(pout) @@ -288,10 +266,12 @@ def write_probewise_lfp_files(probes, session_id, session_metadata, return output_paths -ParsedProbeData = Tuple[pd.DataFrame, # unit_tables - Dict[int, np.ndarray], # spike_times - Dict[int, np.ndarray], # spike_amplitudes - Dict[int, np.ndarray]] # mean_waveforms +ParsedProbeData = Tuple[ + pd.DataFrame, # unit_tables + Dict[int, np.ndarray], # spike_times + Dict[int, np.ndarray], # spike_amplitudes + Dict[int, np.ndarray], +] # mean_waveforms def parse_probes_data(probes: List[Dict[str, Any]]) -> ParsedProbeData: @@ -319,15 +299,12 @@ def parse_probes_data(probes: List[Dict[str, Any]]) -> ParsedProbeData: Keys: unit identifiers, Values: mean waveform arrays """ probes = Probes.from_json(probes=probes) - return probes.units_table, \ - probes.spike_times, \ - probes.spike_amplitudes, \ - probes.mean_waveforms + return probes.units_table, probes.spike_times, probes.spike_amplitudes, probes.mean_waveforms def add_probewise_data_to_nwbfile(nwbfile, probes): - """ Adds channel (electrode) and spike data for a single probe to - the session-level nwb file. + """Adds channel (electrode) and spike data for a single probe to + the session-level nwb file. """ probes = Probes.from_json(probes=probes) probes.to_nwb(nwbfile=nwbfile) @@ -336,10 +313,9 @@ def add_probewise_data_to_nwbfile(nwbfile, probes): def add_eye_tracking_rig_geometry_data_to_nwbfile( - nwbfile: pynwb.NWBFile, - eye_tracking_rig_geometry: dict) -> pynwb.NWBFile: - - """ Rig geometry dict should consist of the following fields: + nwbfile: pynwb.NWBFile, eye_tracking_rig_geometry: dict +) -> pynwb.NWBFile: + """Rig geometry dict should consist of the following fields: monitor_position_mm: [x, y, z] monitor_rotation_deg: [x, y, z] camera_position_mm: [x, y, z] @@ -347,23 +323,23 @@ def add_eye_tracking_rig_geometry_data_to_nwbfile( led_position: [x, y, z] equipment: A string describing rig """ - eye_tracking_rig_mod = \ - pynwb.ProcessingModule(name='eye_tracking_rig_metadata', - description='Eye tracking rig metadata module') + eye_tracking_rig_mod = pynwb.ProcessingModule( + name="eye_tracking_rig_metadata", description="Eye tracking rig metadata module" + ) rig_metadata = EcephysEyeTrackingRigMetadata( name="eye_tracking_rig_metadata", - equipment=eye_tracking_rig_geometry['equipment'], - monitor_position=eye_tracking_rig_geometry['monitor_position_mm'], + equipment=eye_tracking_rig_geometry["equipment"], + monitor_position=eye_tracking_rig_geometry["monitor_position_mm"], monitor_position__unit="mm", - camera_position=eye_tracking_rig_geometry['camera_position_mm'], + camera_position=eye_tracking_rig_geometry["camera_position_mm"], camera_position__unit="mm", - led_position=eye_tracking_rig_geometry['led_position'], + led_position=eye_tracking_rig_geometry["led_position"], led_position__unit="mm", - monitor_rotation=eye_tracking_rig_geometry['monitor_rotation_deg'], + monitor_rotation=eye_tracking_rig_geometry["monitor_rotation_deg"], monitor_rotation__unit="deg", - camera_rotation=eye_tracking_rig_geometry['camera_rotation_deg'], - camera_rotation__unit="deg" + camera_rotation=eye_tracking_rig_geometry["camera_rotation_deg"], + camera_rotation__unit="deg", ) eye_tracking_rig_mod.add_data_interface(rig_metadata) @@ -373,30 +349,29 @@ def add_eye_tracking_rig_geometry_data_to_nwbfile( def add_eye_tracking_data_to_nwbfile( - nwbfile: pynwb.NWBFile, - eye_tracking_frame_times: pd.Series, - eye_dlc_tracking_data: Dict[str, pd.DataFrame], - eye_gaze_data: Dict[str, pd.DataFrame]) -> pynwb.NWBFile: - - if eye_tracking_data_is_valid(eye_dlc_tracking_data=eye_dlc_tracking_data, - synced_timestamps=eye_tracking_frame_times): - + nwbfile: pynwb.NWBFile, + eye_tracking_frame_times: pd.Series, + eye_dlc_tracking_data: Dict[str, pd.DataFrame], + eye_gaze_data: Dict[str, pd.DataFrame], +) -> pynwb.NWBFile: + if eye_tracking_data_is_valid( + eye_dlc_tracking_data=eye_dlc_tracking_data, synced_timestamps=eye_tracking_frame_times + ): add_eye_tracking_ellipse_fit_data_to_nwbfile( - nwbfile, - eye_dlc_tracking_data=eye_dlc_tracking_data, - synced_timestamps=eye_tracking_frame_times) + nwbfile, eye_dlc_tracking_data=eye_dlc_tracking_data, synced_timestamps=eye_tracking_frame_times + ) # --- Add gaze mapped positions to nwb file --- if eye_gaze_data: - add_eye_gaze_mapping_data_to_nwbfile(nwbfile, - eye_gaze_data=eye_gaze_data) + add_eye_gaze_mapping_data_to_nwbfile(nwbfile, eye_gaze_data=eye_gaze_data) return nwbfile def write_ecephys_nwb( output_path, - session_id, session_start_time, + session_id, + session_start_time, stimulus_table_path, invalid_epochs, probes, @@ -408,41 +383,46 @@ def write_ecephys_nwb( eye_dlc_ellipses_path=None, eye_gaze_mapping_path=None, session_metadata=None, - **kwargs + **kwargs, ): - nwbfile = pynwb.NWBFile( - session_description='Data and metadata for an Ecephys session', + session_description="Data and metadata for an Ecephys session", identifier=f"{session_id}", session_id=f"{session_id}", session_start_time=session_start_time, - institution="Allen Institute" + institution="Allen Institute", ) if session_metadata is not None: nwbfile = add_metadata_to_nwbfile(nwbfile, session_metadata) stimulus_columns_to_drop = [ - "colorSpace", "depth", "interpolate", "pos", "rgbPedestal", "tex", - "texRes", "flipHoriz", "flipVert", "rgb", "signalDots" + "colorSpace", + "depth", + "interpolate", + "pos", + "rgbPedestal", + "tex", + "texRes", + "flipHoriz", + "flipVert", + "rgb", + "signalDots", ] stimulus_table = Presentations.from_path( path=stimulus_table_path, behavior_session_id=session_id, exclude_columns=stimulus_columns_to_drop, columns_to_rename=STIM_TABLE_RENAMES_MAP, - sort_columns=False + sort_columns=False, ) - nwbfile = \ - add_stimulus_timestamps(nwbfile, - stimulus_table.value['start_time'].values) + nwbfile = add_stimulus_timestamps(nwbfile, stimulus_table.value["start_time"].values) nwbfile = stimulus_table.to_nwb(nwbfile=nwbfile) nwbfile = add_invalid_times(nwbfile, invalid_epochs) if optotagging_table_path is not None: - optotagging_table = OptotaggingTable.from_json( - dict_repr={'optotagging_table_path': optotagging_table_path}) + optotagging_table = OptotaggingTable.from_json(dict_repr={"optotagging_table_path": optotagging_table_path}) nwbfile = optotagging_table.to_nwb(nwbfile=nwbfile) nwbfile = add_probewise_data_to_nwbfile(nwbfile, probes) @@ -452,64 +432,43 @@ def write_ecephys_nwb( add_raw_running_data_to_nwbfile(nwbfile, raw_running_data) if eye_tracking_rig_geometry is not None: - add_eye_tracking_rig_geometry_data_to_nwbfile( - nwbfile, - eye_tracking_rig_geometry - ) + add_eye_tracking_rig_geometry_data_to_nwbfile(nwbfile, eye_tracking_rig_geometry) # Collect eye tracking/gaze mapping data from files if eye_dlc_ellipses_path is not None: - eye_tracking_frame_times = \ - su.get_synchronized_frame_times( - session_sync_file=session_sync_path, - sync_line_label_keys=Dataset.EYE_TRACKING_KEYS - ) - eye_dlc_tracking_data = \ - read_eye_dlc_tracking_ellipses(Path(eye_dlc_ellipses_path)) + eye_tracking_frame_times = su.get_synchronized_frame_times( + session_sync_file=session_sync_path, sync_line_label_keys=Dataset.EYE_TRACKING_KEYS + ) + eye_dlc_tracking_data = read_eye_dlc_tracking_ellipses(Path(eye_dlc_ellipses_path)) if eye_gaze_mapping_path is not None: eye_gaze_data = read_eye_gaze_mappings(Path(eye_gaze_mapping_path)) else: eye_gaze_data = None - add_eye_tracking_data_to_nwbfile(nwbfile, - eye_tracking_frame_times, - eye_dlc_tracking_data, - eye_gaze_data) + add_eye_tracking_data_to_nwbfile(nwbfile, eye_tracking_frame_times, eye_dlc_tracking_data, eye_gaze_data) Manifest.safe_make_parent_dirs(output_path) - with pynwb.NWBHDF5IO(output_path, mode='w') as io: + with pynwb.NWBHDF5IO(output_path, mode="w") as io: logging.info(f"writing session nwb file to {output_path}") io.write(nwbfile, cache_spec=True) probes_with_lfp = [p for p in probes if p["lfp"] is not None] probes_without_lfp = [p for p in probes if p["lfp"] is None] - probe_outputs = write_probewise_lfp_files(probes_with_lfp, session_id, - session_metadata, - session_start_time, - pool_size=pool_size) + probe_outputs = write_probewise_lfp_files( + probes_with_lfp, session_id, session_metadata, session_start_time, pool_size=pool_size + ) - probe_outputs += \ - [{'id': p["id"], "nwb_path": ""} for p in probes_without_lfp] + probe_outputs += [{"id": p["id"], "nwb_path": ""} for p in probes_without_lfp] - return { - 'nwb_path': output_path, - "probe_outputs": probe_outputs - } + return {"nwb_path": output_path, "probe_outputs": probe_outputs} def main(): - logging.basicConfig( - format="%(asctime)s - %(process)s - %(levelname)s - %(message)s" - ) + logging.basicConfig(format="%(asctime)s - %(process)s - %(levelname)s - %(message)s") - parser = optional_lims_inputs( - sys.argv, - VCNInputSchema, - OutputSchema, - get_inputs_from_lims - ) + parser = optional_lims_inputs(sys.argv, VCNInputSchema, OutputSchema, get_inputs_from_lims) write_ecephys_nwb(**parser.args) diff --git a/allensdk/brain_observatory/ecephys/write_nwb/nwb_writer.py b/allensdk/brain_observatory/ecephys/write_nwb/nwb_writer.py index 45c974e2fb..4373b1b062 100644 --- a/allensdk/brain_observatory/ecephys/write_nwb/nwb_writer.py +++ b/allensdk/brain_observatory/ecephys/write_nwb/nwb_writer.py @@ -6,29 +6,22 @@ from pynwb import NWBHDF5IO, NWBFile -from allensdk.brain_observatory.ecephys.behavior_ecephys_session import \ - BehaviorEcephysSession +from allensdk.brain_observatory.ecephys.behavior_ecephys_session import BehaviorEcephysSession from allensdk.brain_observatory.nwb.nwb_utils import NWBWriter -from allensdk.core import JsonReadableInterface, NwbReadableInterface, \ - NwbWritableInterface +from allensdk.core import JsonReadableInterface, NwbReadableInterface, NwbWritableInterface class BehaviorEcephysNwbWriter(NWBWriter): """NWB Writer for behavior ecephys. Same as `NWBWriter` except also - writes probe NWB files separately """ + writes probe NWB files separately""" + def __init__( - self, - session_nwb_filepath: str, - session_data: dict, - serializer: Union[ - JsonReadableInterface, - NwbReadableInterface, - NwbWritableInterface]): - super().__init__( - nwb_filepath=session_nwb_filepath, - session_data=session_data, - serializer=serializer - ) + self, + session_nwb_filepath: str, + session_data: dict, + serializer: Union[JsonReadableInterface, NwbReadableInterface, NwbWritableInterface], + ): + super().__init__(nwb_filepath=session_nwb_filepath, session_data=session_data, serializer=serializer) def write_nwb(self, **kwargs): """Tries to write nwb to disk. If it fails, the filepath has ".error" @@ -40,52 +33,36 @@ def write_nwb(self, **kwargs): """ from_json_kwargs = { - k: v for k, v in kwargs.items() - if k in inspect.signature(self._serializer.from_json).parameters} - json_session = self._serializer.from_json( - session_data=self._session_data, **from_json_kwargs) + k: v for k, v in kwargs.items() if k in inspect.signature(self._serializer.from_json).parameters + } + json_session = self._serializer.from_json(session_data=self._session_data, **from_json_kwargs) try: - nwbfile = self._write_nwb( - session=json_session, **kwargs) - self._compare_sessions(nwbfile=nwbfile, - loaded_session=json_session, - **kwargs) + nwbfile = self._write_nwb(session=json_session, **kwargs) + self._compare_sessions(nwbfile=nwbfile, loaded_session=json_session, **kwargs) os.rename(self.nwb_filepath_inprogress, self._nwb_filepath) except Exception as e: if os.path.isfile(self.nwb_filepath_inprogress): - os.rename(self.nwb_filepath_inprogress, - self._nwb_filepath_error) + os.rename(self.nwb_filepath_inprogress, self._nwb_filepath_error) raise e - def _write_nwb( - self, - session: BehaviorEcephysSession, - **kwargs) -> NWBFile: - to_nwb_kwargs = { - k: v for k, v in kwargs.items() - if k in inspect.signature(self._serializer.to_nwb).parameters} + def _write_nwb(self, session: BehaviorEcephysSession, **kwargs) -> NWBFile: + to_nwb_kwargs = {k: v for k, v in kwargs.items() if k in inspect.signature(self._serializer.to_nwb).parameters} session_nwbfile, probe_nwbfile_map = session.to_nwb(**to_nwb_kwargs) os.makedirs(Path(self.nwb_filepath_inprogress).parent, exist_ok=True) - logging.info(f'Writing session NWB file to ' - f'{self.nwb_filepath_inprogress}') - with NWBHDF5IO(self.nwb_filepath_inprogress, 'w') as nwb_file_writer: + logging.info(f"Writing session NWB file to {self.nwb_filepath_inprogress}") + with NWBHDF5IO(self.nwb_filepath_inprogress, "w") as nwb_file_writer: nwb_file_writer.write(session_nwbfile) - logging.info(f'Wrote session NWB file to ' - f'{self.nwb_filepath_inprogress}') + logging.info(f"Wrote session NWB file to {self.nwb_filepath_inprogress}") for probe_name, probe_nwbfile in probe_nwbfile_map.items(): - probe_id = [p.id for p in session.get_probes_obj() - if p.name == probe_name][0] + probe_id = [p.id for p in session.get_probes_obj() if p.name == probe_name][0] if probe_nwbfile is not None: - probe_nwb_path = Path(self._nwb_filepath).parent / \ - f'lfp_probe_{probe_id}.nwb' - logging.info(f'Writing probe NWB file to ' - f'{probe_nwb_path}') - with NWBHDF5IO(probe_nwb_path, 'w') as nwb_file_writer: + probe_nwb_path = Path(self._nwb_filepath).parent / f"lfp_probe_{probe_id}.nwb" + logging.info(f"Writing probe NWB file to {probe_nwb_path}") + with NWBHDF5IO(probe_nwb_path, "w") as nwb_file_writer: nwb_file_writer.write(probe_nwbfile) - logging.info(f'Wrote probe NWB file to ' - f'{probe_nwb_path}') + logging.info(f"Wrote probe NWB file to {probe_nwb_path}") return session_nwbfile diff --git a/allensdk/brain_observatory/ecephys/write_nwb/schemas.py b/allensdk/brain_observatory/ecephys/write_nwb/schemas.py index 8b0bb0f34a..a9d876d00e 100644 --- a/allensdk/brain_observatory/ecephys/write_nwb/schemas.py +++ b/allensdk/brain_observatory/ecephys/write_nwb/schemas.py @@ -23,11 +23,7 @@ class BaseBehaviorSessionDataSchema(RaisingSchema): behavior_session_id = Int( required=True, - description=( - "Unique identifier for the " - "behavior session to write into " - "NWB format" - ), + description=("Unique identifier for the behavior session to write into NWB format"), ) driver_line = List( String, @@ -41,27 +37,19 @@ class BaseBehaviorSessionDataSchema(RaisingSchema): cli_as_single_argument=True, description="Genetic reporter line(s) of subject", ) - full_genotype = String( - required=True, description="Full genotype of subject" - ) + full_genotype = String(required=True, description="Full genotype of subject") rig_name = String( required=True, - description=( - "Name of experimental rig used for " "the behavior session" - ), + description=("Name of experimental rig used for the behavior session"), ) date_of_acquisition = String( required=True, - description=( - "Date of acquisition of " "behavior session, in string " "format" - ), - ) - external_specimen_name = Int( - required=True, description="LabTracks ID of the subject" + description=("Date of acquisition of behavior session, in string format"), ) + external_specimen_name = Int(required=True, description="LabTracks ID of the subject") behavior_stimulus_file = argschema.fields.InputFile( required=True, - description=("Path of behavior_stimulus " "camstim *.pkl file"), + description=("Path of behavior_stimulus camstim *.pkl file"), ) date_of_birth = String(required=True, description="Subject date of birth") sex = String(required=True, description="Subject sex") @@ -72,9 +60,7 @@ class Channel(RaisingSchema): @mm.pre_load def set_field_defaults(self, data, **kwargs): if data.get("filtering") is None: - data["filtering"] = ( - "AP band: 500 Hz high-pass; " "LFP band: 1000 Hz low-pass" - ) + data["filtering"] = "AP band: 500 Hz high-pass; LFP band: 1000 Hz low-pass" if data.get("structure_acronym") is None: data["structure_acronym"] = "" return data @@ -212,7 +198,7 @@ class Probe(RaisingSchema): The values for unit amplitudes were changed in the input_json file and do not use this scale. If the data in LIMS for these sessions is updated, this scaling - is not needed. Default is 1""" + is not needed. Default is 1""", ) @@ -268,17 +254,13 @@ class VCNInputSchema(BaseNeuropixelsSchema): class Meta: unknown = mm.RAISE - log_level = LogLevel( - default="INFO", help="set the logging level of the module" - ) + log_level = LogLevel(default="INFO", help="set the logging level of the module") output_path = String( required=True, validate=check_write_access, help="write outputs to here", ) - session_id = Int( - required=True, help="unique identifier for this ecephys session" - ) + session_id = Int(required=True, help="unique identifier for this ecephys session") session_start_time = DateTime( required=True, help="the date and time (iso8601) at which the session started", @@ -288,9 +270,7 @@ class Meta: validate=check_read_access, help="path to stimulus table file", ) - invalid_epochs = Nested( - InvalidEpoch, many=True, required=True, help="epochs with invalid data" - ) + invalid_epochs = Nested(InvalidEpoch, many=True, required=True, help="epochs with invalid data") session_sync_path = String( required=True, validate=check_read_access, @@ -319,7 +299,7 @@ class Meta: SessionMetadata, allow_none=True, required=False, - help="miscellaneous information describing this session" "", + help="miscellaneous information describing this session", ) running_speed_path = String( required=True, diff --git a/allensdk/brain_observatory/ecephys/write_nwb/vbn/__main__.py b/allensdk/brain_observatory/ecephys/write_nwb/vbn/__main__.py index d7592c8268..4b34bf31c5 100644 --- a/allensdk/brain_observatory/ecephys/write_nwb/vbn/__main__.py +++ b/allensdk/brain_observatory/ecephys/write_nwb/vbn/__main__.py @@ -4,12 +4,9 @@ import sys import argschema import marshmallow -from allensdk.brain_observatory.ecephys.behavior_ecephys_session import \ - BehaviorEcephysSession -from allensdk.brain_observatory.ecephys.write_nwb.nwb_writer import \ - BehaviorEcephysNwbWriter -from allensdk.brain_observatory.ecephys.write_nwb.vbn._schemas import \ - VBNInputSchema, OutputSchema +from allensdk.brain_observatory.ecephys.behavior_ecephys_session import BehaviorEcephysSession +from allensdk.brain_observatory.ecephys.write_nwb.nwb_writer import BehaviorEcephysNwbWriter +from allensdk.brain_observatory.ecephys.write_nwb.vbn._schemas import VBNInputSchema, OutputSchema def main(): @@ -20,23 +17,23 @@ def main(): schema_type=VBNInputSchema, output_schema_type=OutputSchema, ) - logging.info('Input successfully parsed') + logging.info("Input successfully parsed") except marshmallow.exceptions.ValidationError as err: - logging.error('Parsing failure') + logging.error("Parsing failure") logging.error(err) raise err nwb_writer = BehaviorEcephysNwbWriter( - session_nwb_filepath=parser.args['output_path'], - session_data=parser.args['session_data'], - serializer=BehaviorEcephysSession + session_nwb_filepath=parser.args["output_path"], + session_data=parser.args["session_data"], + serializer=BehaviorEcephysSession, ) try: - nwb_writer.write_nwb(skip_probes=parser.args['skip_probes']) - logging.info('File successfully created') + nwb_writer.write_nwb(skip_probes=parser.args["skip_probes"]) + logging.info("File successfully created") except Exception as err: - logging.error('NWB write failure') + logging.error("NWB write failure") logging.error(err) raise err diff --git a/allensdk/brain_observatory/ecephys/write_nwb/vbn/_schemas.py b/allensdk/brain_observatory/ecephys/write_nwb/vbn/_schemas.py index fc2951cbb4..6f18efa275 100644 --- a/allensdk/brain_observatory/ecephys/write_nwb/vbn/_schemas.py +++ b/allensdk/brain_observatory/ecephys/write_nwb/vbn/_schemas.py @@ -8,48 +8,28 @@ from argschema.fields import LogLevel -class _VBNSessionDataSchema( - BaseBehaviorSessionDataSchema, BaseNeuropixelsSchema -): - mapping_stimulus_file = argschema.fields.InputFile( - required=True, description="path to mapping_stimulus_file" - ) - replay_stimulus_file = argschema.fields.InputFile( - required=True, description="path to replay_stimulus_file" - ) - stim_table_file = argschema.fields.InputFile( - required=True, description="path to stimulus presentations csv file" - ) +class _VBNSessionDataSchema(BaseBehaviorSessionDataSchema, BaseNeuropixelsSchema): + mapping_stimulus_file = argschema.fields.InputFile(required=True, description="path to mapping_stimulus_file") + replay_stimulus_file = argschema.fields.InputFile(required=True, description="path to replay_stimulus_file") + stim_table_file = argschema.fields.InputFile(required=True, description="path to stimulus presentations csv file") raw_eye_tracking_video_meta_data = argschema.fields.InputFile( required=True, description="path to eye tracking metadata" ) - eye_dlc_file = argschema.fields.InputFile( - required=True, description="path to deeplabcut eye tracking h5 file" - ) - side_dlc_file = argschema.fields.InputFile( - required=True, description="path to deeplabcut side tracking h5 file" - ) - face_dlc_file = argschema.fields.InputFile( - required=True, description="path to deeplabcut face tracking h5 file" - ) + eye_dlc_file = argschema.fields.InputFile(required=True, description="path to deeplabcut eye tracking h5 file") + side_dlc_file = argschema.fields.InputFile(required=True, description="path to deeplabcut side tracking h5 file") + face_dlc_file = argschema.fields.InputFile(required=True, description="path to deeplabcut face tracking h5 file") eye_tracking_filepath = argschema.fields.InputFile( required=True, description="h5 filepath containing eye tracking ellipses", ) - sync_file = argschema.fields.InputFile( - required=True, description="path to sync file" - ) - ecephys_session_id = argschema.fields.Int( - required=True, description="ecephys session id" - ) + sync_file = argschema.fields.InputFile(required=True, description="path to sync file") + ecephys_session_id = argschema.fields.Int(required=True, description="ecephys session id") class VBNInputSchema(ArgSchema): """Input schema for visual behavior neuropixels""" - log_level = LogLevel( - default="INFO", description="Logging level of the module" - ) + log_level = LogLevel(default="INFO", description="Logging level of the module") session_data = argschema.fields.Nested( _VBNSessionDataSchema, required=True, @@ -61,13 +41,9 @@ class VBNInputSchema(ArgSchema): default=None, allow_none=True, ) - output_path = argschema.fields.OutputFile( - required=True, description="Path of output.json to be written" - ) + output_path = argschema.fields.OutputFile(required=True, description="Path of output.json to be written") class OutputSchema(RaisingSchema): input_parameters = argschema.fields.Nested(VBNInputSchema) - output_path = argschema.fields.OutputFile( - required=True, description="write outputs to here" - ) + output_path = argschema.fields.OutputFile(required=True, description="write outputs to here") diff --git a/allensdk/brain_observatory/extract_running_speed/__main__.py b/allensdk/brain_observatory/extract_running_speed/__main__.py index 62935ac9ad..b38a816f02 100644 --- a/allensdk/brain_observatory/extract_running_speed/__main__.py +++ b/allensdk/brain_observatory/extract_running_speed/__main__.py @@ -4,9 +4,7 @@ import pandas as pd from allensdk.brain_observatory import sync_utilities -from allensdk.brain_observatory.argschema_utilities import \ - ArgSchemaParserPlus, \ - write_or_print_outputs +from allensdk.brain_observatory.argschema_utilities import ArgSchemaParserPlus, write_or_print_outputs from allensdk.brain_observatory.sync_dataset import Dataset from ._schemas import InputParameters, OutputParameters @@ -24,13 +22,9 @@ def check_encoder(parent, key): def running_from_stim_file(stim_file, key, expected_length): - if "behavior" in stim_file["items"] and check_encoder( - stim_file["items"]["behavior"], key - ): + if "behavior" in stim_file["items"] and check_encoder(stim_file["items"]["behavior"], key): return stim_file["items"]["behavior"]["encoders"][0][key][:] - if "foraging" in stim_file["items"] and check_encoder( - stim_file["items"]["foraging"], key - ): + if "foraging" in stim_file["items"] and check_encoder(stim_file["items"]["foraging"], key): return stim_file["items"]["foraging"]["encoders"][0][key][:] if key in stim_file: return stim_file[key][:] @@ -47,10 +41,7 @@ def angular_to_linear_velocity(angular_velocity, radius): return np.multiply(angular_velocity, radius) -def extract_running_speeds( - frame_times, dx_deg, vsig, vin, wheel_radius, subject_position, - use_median_duration=False -): +def extract_running_speeds(frame_times, dx_deg, vsig, vin, wheel_radius, subject_position, use_median_duration=False): # the first interval does not have a known start time, so we can't compute # an average velocity from dx dx_rad = degrees_to_radians(dx_deg[1:]) @@ -84,10 +75,7 @@ def extract_running_speeds( return df -def main( - stimulus_pkl_path, sync_h5_path, output_path, wheel_radius, - subject_position, use_median_duration, **kwargs -): +def main(stimulus_pkl_path, sync_h5_path, output_path, wheel_radius, subject_position, use_median_duration, **kwargs): stim_file = pd.read_pickle(stimulus_pkl_path) sync_dataset = Dataset(sync_h5_path) @@ -96,13 +84,11 @@ def main( # 2. updates the "items", causing a running speed sample to be acquired # 3. sets the vsync line high # 4. flips the buffer - frame_times = sync_dataset.get_edges( - "rising", Dataset.FRAME_KEYS, units="seconds" - ) + frame_times = sync_dataset.get_edges("rising", Dataset.FRAME_KEYS, units="seconds") # occasionally an extra set of frame times are acquired after the rest of # the signals. We detect and remove these - if kwargs.get('trim_discontiguous_frame_times', True): + if kwargs.get("trim_discontiguous_frame_times", True): frame_times = sync_utilities.trim_discontiguous_times(frame_times) num_raw_timestamps = len(frame_times) @@ -111,8 +97,7 @@ def main( if num_raw_timestamps != len(dx_deg): raise ValueError( - f"found {num_raw_timestamps} rising edges on the vsync line, " - f"but only {len(dx_deg)} rotation samples" + f"found {num_raw_timestamps} rising edges on the vsync line, but only {len(dx_deg)} rotation samples" ) vsig = running_from_stim_file(stim_file, "vsig", num_raw_timestamps) @@ -131,12 +116,10 @@ def main( vin=vin, wheel_radius=wheel_radius, subject_position=subject_position, - use_median_duration=use_median_duration + use_median_duration=use_median_duration, ) - raw_data = pd.DataFrame( - {"vsig": vsig, "vin": vin, "frame_time": frame_times, "dx": dx_deg} - ) + raw_data = pd.DataFrame({"vsig": vsig, "vin": vin, "frame_time": frame_times, "dx": dx_deg}) store = pd.HDFStore(output_path) store.put("running_speed", velocities) @@ -147,9 +130,7 @@ def main( if __name__ == "__main__": - mod = ArgSchemaParserPlus( - schema_type=InputParameters, output_schema_type=OutputParameters - ) + mod = ArgSchemaParserPlus(schema_type=InputParameters, output_schema_type=OutputParameters) output = main(**mod.args) write_or_print_outputs(data=output, parser=mod) diff --git a/allensdk/brain_observatory/extract_running_speed/_schemas.py b/allensdk/brain_observatory/extract_running_speed/_schemas.py index e040ce49a9..e55e8c4fa5 100644 --- a/allensdk/brain_observatory/extract_running_speed/_schemas.py +++ b/allensdk/brain_observatory/extract_running_speed/_schemas.py @@ -16,27 +16,24 @@ class InputParameters(ArgSchema): wheel_radius = Float(default=8.255, help="radius, in cm, of running wheel") subject_position = Float( default=2 / 3, - help="normalized distance of the subject from the center " + - "of the running wheel (1 is rim, 0 is center)", + help="normalized distance of the subject from the center " + "of the running wheel (1 is rim, 0 is center)", ) use_median_duration = Boolean( default=True, - help="frame timestamps are often too noisy to use as the " + - "denominator in the velocity calculation. " + - "Can instead use the median frame duration." + help="frame timestamps are often too noisy to use as the " + + "denominator in the velocity calculation. " + + "Can instead use the median frame duration.", ) trim_discontiguous_frame_times = Boolean( - default=True, - help="set to False if stimulus was shown in chunks, " + - "and discontiguous vsyncs are expected." + default=True, help="set to False if stimulus was shown in chunks, " + "and discontiguous vsyncs are expected." ) class OutputSchema(DefaultSchema): input_parameters = Nested( InputParameters, - description=("Input parameters the module " "was run with"), + description=("Input parameters the module was run with"), required=True, ) diff --git a/allensdk/brain_observatory/extract_running_speed/examples.ipynb b/allensdk/brain_observatory/extract_running_speed/examples.ipynb index be3e479d35..5a188a0579 100644 --- a/allensdk/brain_observatory/extract_running_speed/examples.ipynb +++ b/allensdk/brain_observatory/extract_running_speed/examples.ipynb @@ -33,12 +33,14 @@ "outputs": [], "source": [ "def roundtrip(inputs):\n", - " output_path = main(**inputs)[\"output_path\"] \n", + " output_path = main(**inputs)[\"output_path\"]\n", " return pd.read_hdf(output_path, key=\"running_speed\")\n", - " \n", + "\n", + "\n", "def midpoints(obt):\n", " return obt[\"start_time\"] + (obt[\"end_time\"] - obt[\"start_time\"]) / 2\n", "\n", + "\n", "def remove_outliers(data, filter_width=5, percentile=99.9999):\n", " data = np.array(data)\n", " filtered_data = median_filter(data, size=filter_width)\n", @@ -84,8 +86,15 @@ ], "source": [ "ec_storage = Path(\n", - " \"/\", \"allen\", \"programs\", \"braintv\", \"production\", \"neuralcoding\", \n", - " \"prod0\", \"specimen_717038288\", \"ecephys_session_732592105\"\n", + " \"/\",\n", + " \"allen\",\n", + " \"programs\",\n", + " \"braintv\",\n", + " \"production\",\n", + " \"neuralcoding\",\n", + " \"prod0\",\n", + " \"specimen_717038288\",\n", + " \"ecephys_session_732592105\",\n", ")\n", "\n", "ec_inputs = {\n", @@ -93,7 +102,7 @@ " \"stimulus_pkl_path\": ec_storage / Path(\"732592105_404553_20180808.stim.pkl\"),\n", " \"output_path\": local_dir / Path(\"ecephys_running_speed.h5\"),\n", " \"wheel_radius\": wheel_radius,\n", - " \"subject_position\": subject_position\n", + " \"subject_position\": subject_position,\n", "}\n", "\n", "ec_obt = roundtrip(ec_inputs)" @@ -1017,16 +1026,14 @@ }, "outputs": [], "source": [ - "vb_storage = Path(\n", - " \"/allen/programs/braintv/production/visualbehavior/prod0/specimen_843387586/ophys_session_884528231/\"\n", - ")\n", + "vb_storage = Path(\"/allen/programs/braintv/production/visualbehavior/prod0/specimen_843387586/ophys_session_884528231/\")\n", "\n", "vb_inputs = {\n", " \"sync_h5_path\": vb_storage / Path(\"884528231_sync.h5\"),\n", " \"stimulus_pkl_path\": vb_storage / Path(\"884528231_stim.pkl\"),\n", " \"output_path\": local_dir / Path(\"vb_running_speed.h5\"),\n", " \"wheel_radius\": wheel_radius,\n", - " \"subject_position\": subject_position\n", + " \"subject_position\": subject_position,\n", "}\n", "\n", "vb_obt = roundtrip(vb_inputs)" @@ -1969,31 +1976,28 @@ "metadata": {}, "outputs": [], "source": [ - "def running_in_intervals(\n", - " running, intervals, start_key=\"start_time\", end_key=\"end_time\", data_key=\"velocity\"\n", - "):\n", + "def running_in_intervals(running, intervals, start_key=\"start_time\", end_key=\"end_time\", data_key=\"velocity\"):\n", " # TODO: handle bounds\n", - " \n", - " intervals = np.array(intervals) # N X (start, end)\n", - " \n", + "\n", + " intervals = np.array(intervals) # N X (start, end)\n", + "\n", " start_indices = np.searchsorted(running[end_key], intervals[:, 0])\n", " end_indices = np.searchsorted(running[end_key], intervals[:, 1])\n", - " \n", + "\n", " durations = running[end_key] - running[start_key]\n", "\n", " values = []\n", " for ii, (start_index, end_index) in enumerate(zip(start_indices, end_indices)):\n", - " \n", " start = intervals[ii, 0]\n", " end = intervals[ii, 1]\n", - " \n", - " raw_weights = durations[start_index:end_index+1].values\n", + "\n", + " raw_weights = durations[start_index : end_index + 1].values\n", " raw_weights[0] -= start - running.loc[start_index, start_key]\n", " raw_weights[-1] -= running.loc[end_index, end_key] - end\n", - " \n", + "\n", " weights = raw_weights / raw_weights.sum()\n", " values.append(np.multiply(weights, running.loc[start_index:end_index, data_key]).sum())\n", - " \n", + "\n", " return values" ] }, @@ -2015,12 +2019,10 @@ ], "source": [ "running_in_intervals(\n", - " pd.DataFrame({\n", - " \"start_time\": np.array([0, 2, 4, 5, 6]),\n", - " \"end_time\": np.array([1.5, 4, 5, 6, 8]),\n", - " \"velocity\": np.arange(5)\n", - " }), \n", - " [(4.0, 8)]\n", + " pd.DataFrame(\n", + " {\"start_time\": np.array([0, 2, 4, 5, 6]), \"end_time\": np.array([1.5, 4, 5, 6, 8]), \"velocity\": np.arange(5)}\n", + " ),\n", + " [(4.0, 8)],\n", ")" ] } diff --git a/allensdk/brain_observatory/eye_tracking/__main__.py b/allensdk/brain_observatory/eye_tracking/__main__.py index eecac6e59b..f977a83506 100644 --- a/allensdk/brain_observatory/eye_tracking/__main__.py +++ b/allensdk/brain_observatory/eye_tracking/__main__.py @@ -1,6 +1,4 @@ - - -raise NotImplementedError('refactoring in progress') +raise NotImplementedError("refactoring in progress") # def run_rule(rule, **kwargs): diff --git a/allensdk/brain_observatory/eye_tracking/_schemas.py b/allensdk/brain_observatory/eye_tracking/_schemas.py index 162ef14c39..f664875069 100644 --- a/allensdk/brain_observatory/eye_tracking/_schemas.py +++ b/allensdk/brain_observatory/eye_tracking/_schemas.py @@ -2,19 +2,32 @@ from argschema.fields import LogLevel, String from marshmallow import RAISE -from allensdk.brain_observatory.argschema_utilities import check_read_access, check_write_access_overwrite, RaisingSchema +from allensdk.brain_observatory.argschema_utilities import ( + check_read_access, + check_write_access_overwrite, + RaisingSchema, +) + class InputSchema(ArgSchema): class Meta: unknown = RAISE - log_level = LogLevel(default='INFO', description='set the logging level of the module') - rule = String(default='run', required=False) - dockerfile = String(required=True, validate=check_read_access, description='Dockerfile for image') - modelfile = String(required=True, validate=check_read_access, description='Zip file for model') - video_input_file = String(required=True, validate=check_read_access, description='Eye tracking movie') - ellipse_output_data_file = String(required=True, validate=check_write_access_overwrite, description='write outputs to here') - ellipse_output_video_file = String(required=False, validate=check_write_access_overwrite, description='write outputs to here') - points_output_video_file = String(required=False, validate=check_write_access_overwrite, description='write outputs to here') + + log_level = LogLevel(default="INFO", description="set the logging level of the module") + rule = String(default="run", required=False) + dockerfile = String(required=True, validate=check_read_access, description="Dockerfile for image") + modelfile = String(required=True, validate=check_read_access, description="Zip file for model") + video_input_file = String(required=True, validate=check_read_access, description="Eye tracking movie") + ellipse_output_data_file = String( + required=True, validate=check_write_access_overwrite, description="write outputs to here" + ) + ellipse_output_video_file = String( + required=False, validate=check_write_access_overwrite, description="write outputs to here" + ) + points_output_video_file = String( + required=False, validate=check_write_access_overwrite, description="write outputs to here" + ) + class OutputSchema(RaisingSchema): - output_path = String(required=True, description='write outputs to here') \ No newline at end of file + output_path = String(required=True, description="write outputs to here") diff --git a/allensdk/brain_observatory/eye_tracking/build.py b/allensdk/brain_observatory/eye_tracking/build.py index d2bc8b75ce..9902e5ff52 100644 --- a/allensdk/brain_observatory/eye_tracking/build.py +++ b/allensdk/brain_observatory/eye_tracking/build.py @@ -6,31 +6,27 @@ CURR_FILE_DIR = os.path.dirname(os.path.abspath(__file__)) -DOCKERFILE_STAGE_1 = os.path.join(CURR_FILE_DIR, 'stage_1', 'Dockerfile') -DOCKERFILE_STAGE_2 = os.path.join(CURR_FILE_DIR, 'stage_2', 'Dockerfile') -DOCKERFILE_STAGE_3 = os.path.join(CURR_FILE_DIR, 'stage_3', 'Dockerfile') -DOCKERFILE_STAGE_4 = os.path.join(CURR_FILE_DIR, 'stage_4', 'Dockerfile') -MODELFILE_CACHE_LOC = os.path.join(CURR_FILE_DIR, 'stage_1', 'modelfile.zip') +DOCKERFILE_STAGE_1 = os.path.join(CURR_FILE_DIR, "stage_1", "Dockerfile") +DOCKERFILE_STAGE_2 = os.path.join(CURR_FILE_DIR, "stage_2", "Dockerfile") +DOCKERFILE_STAGE_3 = os.path.join(CURR_FILE_DIR, "stage_3", "Dockerfile") +DOCKERFILE_STAGE_4 = os.path.join(CURR_FILE_DIR, "stage_4", "Dockerfile") +MODELFILE_CACHE_LOC = os.path.join(CURR_FILE_DIR, "stage_1", "modelfile.zip") def run_rule(rule, **kwargs): - - if rule == 'clean': - + if rule == "clean": with contextlib.suppress(FileNotFoundError): os.remove(MODELFILE_CACHE_LOC) - elif rule == 'build:stage-1': - + elif rule == "build:stage-1": # Download and cache modelfile: - modelfile = kwargs.get('modelfile') + modelfile = kwargs.get("modelfile") if not modelfile: - if not os.path.exists(MODELFILE_CACHE_LOC): from google.cloud import storage - source_bucketname = 'dlc-eye-tracking-models' - source_filename = 'universal_eye_tracking-peterl-2019-07-10.zip' + source_bucketname = "dlc-eye-tracking-models" + source_filename = "universal_eye_tracking-peterl-2019-07-10.zip" target_filename = MODELFILE_CACHE_LOC client = storage.Client() @@ -40,29 +36,39 @@ def run_rule(rule, **kwargs): modelfile = MODELFILE_CACHE_LOC assert os.path.exists(modelfile) - subprocess.check_call(['docker', 'build', - '--build-arg', 'MODELFILE={modelfile}'.format(modelfile='modelfile.zip'), - '-t', 'dlc-eye-tracking:stage-1', '-f', DOCKERFILE_STAGE_1, 'stage_1']) - - elif rule == 'build:stage-2': - subprocess.check_call(['docker', 'build', - '-t', 'dlc-eye-tracking:stage-2', '-f', DOCKERFILE_STAGE_2, 'stage_2']) - - elif rule == 'build:stage-4': - subprocess.check_call(['docker', 'build', - '-t', 'dlc-eye-tracking:stage-4', '-f', DOCKERFILE_STAGE_4, 'stage_4']) - - elif rule == 'build:stage-3': - + subprocess.check_call( + [ + "docker", + "build", + "--build-arg", + "MODELFILE={modelfile}".format(modelfile="modelfile.zip"), + "-t", + "dlc-eye-tracking:stage-1", + "-f", + DOCKERFILE_STAGE_1, + "stage_1", + ] + ) + + elif rule == "build:stage-2": + subprocess.check_call( + ["docker", "build", "-t", "dlc-eye-tracking:stage-2", "-f", DOCKERFILE_STAGE_2, "stage_2"] + ) + + elif rule == "build:stage-4": + subprocess.check_call( + ["docker", "build", "-t", "dlc-eye-tracking:stage-4", "-f", DOCKERFILE_STAGE_4, "stage_4"] + ) + + elif rule == "build:stage-3": # Download and cache modelfile: - modelfile = kwargs.get('modelfile') + modelfile = kwargs.get("modelfile") if not modelfile: - if not os.path.exists(MODELFILE_CACHE_LOC): from google.cloud import storage - source_bucketname = 'dlc-eye-tracking-models' - source_filename = 'universal_eye_tracking-peterl-2019-07-10.zip' + source_bucketname = "dlc-eye-tracking-models" + source_filename = "universal_eye_tracking-peterl-2019-07-10.zip" target_filename = MODELFILE_CACHE_LOC client = storage.Client() @@ -71,80 +77,135 @@ def run_rule(rule, **kwargs): blob.download_to_filename(target_filename) modelfile = MODELFILE_CACHE_LOC assert os.path.exists(modelfile) - shutil.copyfile(modelfile, os.path.join(CURR_FILE_DIR, 'stage_3', 'modelfile.zip')) - - subprocess.check_call(['docker', 'build', - '--build-arg', 'MODELFILE={modelfile}'.format(modelfile='modelfile.zip'), - '-t', 'dlc-eye-tracking:stage-3', '-f', DOCKERFILE_STAGE_3, 'stage_3']) - - elif rule in ['run:stage-1', 'run:stage-2']: - assert kwargs['modelfile'] is None - video_input_file = kwargs.get('video_input_file') - stage = rule.split(':')[-1] - - subprocess.check_call(['docker', 'run', - '--runtime=nvidia', - '-e', 'VIDEO_INPUT_FILE={}'.format(video_input_file), 'dlc-eye-tracking:{}'.format(stage)]) - - elif rule in ['tag:stage-1', 'tag:stage-2', 'tag:stage-3', 'tag:stage-4']: - stage = rule.split(':')[-1] - subprocess.check_call(['docker', 'tag', 'dlc-eye-tracking:{}'.format(stage), 'us.gcr.io/aibs-pilot/dlc-eye-tracking:{}'.format(stage)]) - - elif rule in ['tag-aibs:stage-1', 'tag-aibs:stage-2', 'tag-aibs:stage-3']: - stage = rule.split(':')[-1] - subprocess.check_call(['docker', 'tag', 'dlc-eye-tracking:{}'.format(stage), 'docker.aibs-artifactory.corp.alleninstitute.org/dlc-eye-tracking:{}'.format(stage)]) - - elif rule in ['push:stage-1', 'push:stage-2', 'push:stage-3', 'push:stage-4']: - stage = rule.split(':')[-1] - subprocess.check_call(['docker', 'push', 'us.gcr.io/aibs-pilot/dlc-eye-tracking:{}'.format(stage)]) - - elif rule in ['push-aibs:stage-1', 'push-aibs:stage-2', 'push-aibs:stage-3']: - stage = rule.split(':')[-1] - subprocess.check_call(['docker', 'push', 'docker.aibs-artifactory.corp.alleninstitute.org/dlc-eye-tracking:{}'.format(stage)]) - - elif rule == 'build:all': - run_rule('build:stage-1', **kwargs) - run_rule('build:stage-2', **kwargs) - run_rule('build:stage-3', **kwargs) - run_rule('build:stage-4', **kwargs) - - elif rule == 'tag:all': - run_rule('tag:stage-1', **kwargs) - run_rule('tag:stage-2', **kwargs) - run_rule('tag:stage-3', **kwargs) - run_rule('tag:stage-4', **kwargs) - - elif rule == 'push:all': - run_rule('push:stage-1', **kwargs) - run_rule('push:stage-2', **kwargs) - run_rule('push:stage-3', **kwargs) - run_rule('push:stage-4', **kwargs) - - elif rule == 'all': - run_rule('build:all', **kwargs) - run_rule('tag:all', **kwargs) - run_rule('push:all', **kwargs) + shutil.copyfile(modelfile, os.path.join(CURR_FILE_DIR, "stage_3", "modelfile.zip")) + + subprocess.check_call( + [ + "docker", + "build", + "--build-arg", + "MODELFILE={modelfile}".format(modelfile="modelfile.zip"), + "-t", + "dlc-eye-tracking:stage-3", + "-f", + DOCKERFILE_STAGE_3, + "stage_3", + ] + ) + + elif rule in ["run:stage-1", "run:stage-2"]: + assert kwargs["modelfile"] is None + video_input_file = kwargs.get("video_input_file") + stage = rule.split(":")[-1] + + subprocess.check_call( + [ + "docker", + "run", + "--runtime=nvidia", + "-e", + "VIDEO_INPUT_FILE={}".format(video_input_file), + "dlc-eye-tracking:{}".format(stage), + ] + ) + + elif rule in ["tag:stage-1", "tag:stage-2", "tag:stage-3", "tag:stage-4"]: + stage = rule.split(":")[-1] + subprocess.check_call( + [ + "docker", + "tag", + "dlc-eye-tracking:{}".format(stage), + "us.gcr.io/aibs-pilot/dlc-eye-tracking:{}".format(stage), + ] + ) + + elif rule in ["tag-aibs:stage-1", "tag-aibs:stage-2", "tag-aibs:stage-3"]: + stage = rule.split(":")[-1] + subprocess.check_call( + [ + "docker", + "tag", + "dlc-eye-tracking:{}".format(stage), + "docker.aibs-artifactory.corp.alleninstitute.org/dlc-eye-tracking:{}".format(stage), + ] + ) + + elif rule in ["push:stage-1", "push:stage-2", "push:stage-3", "push:stage-4"]: + stage = rule.split(":")[-1] + subprocess.check_call(["docker", "push", "us.gcr.io/aibs-pilot/dlc-eye-tracking:{}".format(stage)]) + + elif rule in ["push-aibs:stage-1", "push-aibs:stage-2", "push-aibs:stage-3"]: + stage = rule.split(":")[-1] + subprocess.check_call( + ["docker", "push", "docker.aibs-artifactory.corp.alleninstitute.org/dlc-eye-tracking:{}".format(stage)] + ) + + elif rule == "build:all": + run_rule("build:stage-1", **kwargs) + run_rule("build:stage-2", **kwargs) + run_rule("build:stage-3", **kwargs) + run_rule("build:stage-4", **kwargs) + + elif rule == "tag:all": + run_rule("tag:stage-1", **kwargs) + run_rule("tag:stage-2", **kwargs) + run_rule("tag:stage-3", **kwargs) + run_rule("tag:stage-4", **kwargs) + + elif rule == "push:all": + run_rule("push:stage-1", **kwargs) + run_rule("push:stage-2", **kwargs) + run_rule("push:stage-3", **kwargs) + run_rule("push:stage-4", **kwargs) + + elif rule == "all": + run_rule("build:all", **kwargs) + run_rule("tag:all", **kwargs) + run_rule("push:all", **kwargs) else: - - raise RuntimeError('Invalid rule: {}'.format(rule)) + raise RuntimeError("Invalid rule: {}".format(rule)) if __name__ == "__main__": - # Sanity check: for filename in [DOCKERFILE_STAGE_1, DOCKERFILE_STAGE_2]: assert os.path.exists(filename) parser = argparse.ArgumentParser() - parser.add_argument("rule", help="Rule to run", choices=['build:stage-1', 'run:stage-1', 'tag:stage-1', 'push:stage-1', - 'build:stage-2', 'run:stage-2', 'tag:stage-2', 'push:stage-2', - 'build:stage-3', 'tag:stage-3', 'push:stage-3', - 'build:stage-4', 'tag:stage-4', 'push:stage-4', - 'tag-aibs:stage-1', 'tag-aibs:stage-2', 'push-aibs:stage-1', 'push-aibs:stage-2', - 'clean', 'build:all', 'tag:all', 'push:all', 'all'], nargs='+') + parser.add_argument( + "rule", + help="Rule to run", + choices=[ + "build:stage-1", + "run:stage-1", + "tag:stage-1", + "push:stage-1", + "build:stage-2", + "run:stage-2", + "tag:stage-2", + "push:stage-2", + "build:stage-3", + "tag:stage-3", + "push:stage-3", + "build:stage-4", + "tag:stage-4", + "push:stage-4", + "tag-aibs:stage-1", + "tag-aibs:stage-2", + "push-aibs:stage-1", + "push-aibs:stage-2", + "clean", + "build:all", + "tag:all", + "push:all", + "all", + ], + nargs="+", + ) parser.add_argument("--modelfile", help="DLC model zip file location", type=str) parser.add_argument("--video_input_file", help="input artifact", type=str) args = vars(parser.parse_args()) - for rule in args.pop('rule'): + for rule in args.pop("rule"): run_rule(rule, **args) diff --git a/allensdk/brain_observatory/eye_tracking/stage_1/DLC_Eye_Tracking.py b/allensdk/brain_observatory/eye_tracking/stage_1/DLC_Eye_Tracking.py index 57afc5c4d4..379c2fd068 100644 --- a/allensdk/brain_observatory/eye_tracking/stage_1/DLC_Eye_Tracking.py +++ b/allensdk/brain_observatory/eye_tracking/stage_1/DLC_Eye_Tracking.py @@ -1,7 +1,9 @@ import time + t0 = time.time() import os -os.environ["DLClight"]="True" + +os.environ["DLClight"] = "True" import deeplabcut from moviepy.editor import * @@ -10,9 +12,9 @@ ch = logging.StreamHandler() -formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") ch.setFormatter(formatter) -logger = logging.getLogger('dlc-eye-tracking') +logger = logging.getLogger("dlc-eye-tracking") logger.setLevel(logging.INFO) logger.addHandler(ch) logger.propagate = False @@ -24,17 +26,18 @@ args = parser.parse_args() video_file_path = args.video_input_file -bucket_data_blobname = video_file_path[:-4] + 'DeepCut_resnet50_universal_eye_trackingJul10shuffle1_1030000.h5' -output_data_file = '/workdir/{}'.format(bucket_data_blobname) +bucket_data_blobname = video_file_path[:-4] + "DeepCut_resnet50_universal_eye_trackingJul10shuffle1_1030000.h5" +output_data_file = "/workdir/{}".format(bucket_data_blobname) from google.cloud import storage + client = storage.Client() -src_bucket = client.get_bucket('brain-observatory-eye-videos') -tgt_bucket = client.get_bucket('brain-observatory-dlc-eye-tracking') +src_bucket = client.get_bucket("brain-observatory-eye-videos") +tgt_bucket = client.get_bucket("brain-observatory-dlc-eye-tracking") blob = src_bucket.get_blob(video_file_path) blob.download_to_filename(video_file_path) -path_config_file = '/workdir/model/config.yaml' +path_config_file = "/workdir/model/config.yaml" # ### Track points in video and generate h5 file: @@ -48,19 +51,18 @@ blob2.upload_from_filename(filename=output_data_file) -logger.info('Initialization Time: {}'.format(initialization_time)) -logger.info('DLC Analysis Time: {}'.format(dlc_analysis_time)) -logger.info('Total Walltime: {}'.format(time.time()-t0)) +logger.info("Initialization Time: {}".format(initialization_time)) +logger.info("DLC Analysis Time: {}".format(dlc_analysis_time)) +logger.info("Total Walltime: {}".format(time.time() - t0)) -#optional: display video in notebook -#animation.ipython_display(fps=fps) +# optional: display video in notebook +# animation.ipython_display(fps=fps) -#optional: plot some of the ellipse parameters over time +# optional: plot some of the ellipse parameters over time # %matplotlib inline # import seaborn as sns # sns.set() -# raw = pd.read_hdf(flat_file, key='flat') -# raw.plot(y=['pupil_area', 'pupil_center_x', 'pupil_center_y', 'reflection_x', 'reflection_y'], subplots = True, layout=(1,5), figsize=[25,5], ls='', marker='.', ms=5, title = video_file_path) - +# raw = pd.read_hdf(flat_file, key='flat') +# raw.plot(y=['pupil_area', 'pupil_center_x', 'pupil_center_y', 'reflection_x', 'reflection_y'], subplots = True, layout=(1,5), figsize=[25,5], ls='', marker='.', ms=5, title = video_file_path) diff --git a/allensdk/brain_observatory/eye_tracking/stage_2/DLC_Ellipse_Fitting.py b/allensdk/brain_observatory/eye_tracking/stage_2/DLC_Ellipse_Fitting.py index 18a0032e11..05afcee2a0 100644 --- a/allensdk/brain_observatory/eye_tracking/stage_2/DLC_Ellipse_Fitting.py +++ b/allensdk/brain_observatory/eye_tracking/stage_2/DLC_Ellipse_Fitting.py @@ -1,4 +1,5 @@ import time + t0 = time.time() import os @@ -12,9 +13,9 @@ ch = logging.StreamHandler() -formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") ch.setFormatter(formatter) -logger = logging.getLogger('dlc-ellipse-fitting') +logger = logging.getLogger("dlc-ellipse-fitting") logger.setLevel(logging.INFO) logger.addHandler(ch) logger.propagate = False @@ -24,117 +25,148 @@ args = parser.parse_args() video_file_path = args.video_input_file -h5file_path = video_file_path[:-4] + 'DeepCut_resnet50_universal_eye_trackingJul10shuffle1_1030000.h5' +h5file_path = video_file_path[:-4] + "DeepCut_resnet50_universal_eye_trackingJul10shuffle1_1030000.h5" -ellipse_bucket_data_blobname = '{}.h5'.format(os.path.splitext(video_file_path)[0]) -ellipse_output_data_file = '/workdir/{}'.format(ellipse_bucket_data_blobname) +ellipse_bucket_data_blobname = "{}.h5".format(os.path.splitext(video_file_path)[0]) +ellipse_output_data_file = "/workdir/{}".format(ellipse_bucket_data_blobname) client = storage.Client() -src_bucket = client.get_bucket('brain-observatory-dlc-eye-tracking') -tgt_bucket = client.get_bucket('dlc-ellipse-fitting') +src_bucket = client.get_bucket("brain-observatory-dlc-eye-tracking") +tgt_bucket = client.get_bucket("dlc-ellipse-fitting") blob = src_bucket.get_blob(h5file_path) blob.download_to_filename(h5file_path) -path_config_file = '/workdir/model/config.yaml' +path_config_file = "/workdir/model/config.yaml" def fit_ellipse(h5name): - df = pd.read_hdf(h5name).DeepCut_resnet50_universal_eye_trackingJul10shuffle1_1030000 - l_threshold = 0.8 #increased likelihood threshold for points that are allowed in fit + l_threshold = 0.8 # increased likelihood threshold for points that are allowed in fit min_num_points = 6 # uses https://github.com/bdhammel/least-squares-ellipse-fitting # based on the publication Halir, R., Flusser, J.: 'Numerically Stable Direct Least Squares Fitting of Ellipses' - cr = [] - eye = [] - pupil = [] - - #new for loop + cr = [] + eye = [] + pupil = [] + + # new for loop loop_t0 = time.time() last_loop_time = time.time() for j in range(len(df)): - - #fit ellipses to the pupil & eye points in 4/25 - - frac_completed = max(1,float(j))/len(df) - tot_time_est = (time.time() - loop_t0)/frac_completed - progress_str = "{:10.2f} {:10.2f} {:5s} {:10.2f}".format(time.time()-last_loop_time, time.time()-loop_t0, "{0:.0%}".format(frac_completed), tot_time_est) - logger.info('Ellipse fit: {}'.format(progress_str)) + # fit ellipses to the pupil & eye points in 4/25 + + frac_completed = max(1, float(j)) / len(df) + tot_time_est = (time.time() - loop_t0) / frac_completed + progress_str = "{:10.2f} {:10.2f} {:5s} {:10.2f}".format( + time.time() - last_loop_time, time.time() - loop_t0, "{0:.0%}".format(frac_completed), tot_time_est + ) + logger.info("Ellipse fit: {}".format(progress_str)) last_loop_time = time.time() - + x_data = df.filter(regex=("cr*")).iloc[j].values[0::3] y_data = df.filter(regex=("cr*")).iloc[j].values[1::3] l = df.filter(regex=("cr*")).iloc[j].values[2::3] try: - if len(l[l>l_threshold]) >= min_num_points: #at least 6 tracked points for annotation quality data - lsqe = LSqEllipse() #make fitting object - lsqe.fit([x_data[l>l_threshold], y_data[l>l_threshold]]) + if len(l[l > l_threshold]) >= min_num_points: # at least 6 tracked points for annotation quality data + lsqe = LSqEllipse() # make fitting object + lsqe.fit([x_data[l > l_threshold], y_data[l > l_threshold]]) center, width, height, phi = lsqe.parameters() - ellipse_dict = {'center_x' : center[0], 'center_y' : center[1], 'width' : width, 'height' : height, 'phi' : phi} + ellipse_dict = { + "center_x": center[0], + "center_y": center[1], + "width": width, + "height": height, + "phi": phi, + } else: - ellipse_dict = {'center_x' : np.nan, 'center_y' : np.nan, 'width' : np.nan, 'height' : np.nan, 'phi' : np.nan} + ellipse_dict = { + "center_x": np.nan, + "center_y": np.nan, + "width": np.nan, + "height": np.nan, + "phi": np.nan, + } except Exception as e: - ellipse_dict = {'center_x' : np.nan, 'center_y' : np.nan, 'width' : np.nan, 'height' : np.nan, 'phi' : np.nan} + ellipse_dict = {"center_x": np.nan, "center_y": np.nan, "width": np.nan, "height": np.nan, "phi": np.nan} print(e) cr.append(ellipse_dict) - #eye + # eye x_data = df.filter(regex=("eye*")).iloc[j].values[0::3] y_data = df.filter(regex=("eye*")).iloc[j].values[1::3] l = df.filter(regex=("eye*")).iloc[j].values[2::3] try: - if len(l[l>l_threshold]) >= min_num_points: #at least 6 tracked points for annotation quality data - lsqe = LSqEllipse() #make fitting object - lsqe.fit([x_data[l>l_threshold], y_data[l>l_threshold]]) + if len(l[l > l_threshold]) >= min_num_points: # at least 6 tracked points for annotation quality data + lsqe = LSqEllipse() # make fitting object + lsqe.fit([x_data[l > l_threshold], y_data[l > l_threshold]]) center, width, height, phi = lsqe.parameters() - ellipse_dict = {'center_x' : center[0], 'center_y' : center[1], 'width' : width, 'height' : height, 'phi' : phi} + ellipse_dict = { + "center_x": center[0], + "center_y": center[1], + "width": width, + "height": height, + "phi": phi, + } else: - ellipse_dict = {'center_x' : np.nan, 'center_y' : np.nan, 'width' : np.nan, 'height' : np.nan, 'phi' : np.nan} + ellipse_dict = { + "center_x": np.nan, + "center_y": np.nan, + "width": np.nan, + "height": np.nan, + "phi": np.nan, + } except Exception as e: - ellipse_dict = {'center_x' : np.nan, 'center_y' : np.nan, 'width' : np.nan, 'height' : np.nan, 'phi' : np.nan} + ellipse_dict = {"center_x": np.nan, "center_y": np.nan, "width": np.nan, "height": np.nan, "phi": np.nan} print(e) - eye.append(ellipse_dict) + eye.append(ellipse_dict) - - #pupil + # pupil x_data = df.filter(regex=("pupil*")).iloc[j].values[0::3] y_data = df.filter(regex=("pupil*")).iloc[j].values[1::3] l = df.filter(regex=("pupil*")).iloc[j].values[2::3] try: - if len(l[l>l_threshold]) >= min_num_points: #at least 6 tracked points for annotation quality data - lsqe = LSqEllipse() #make fitting object - lsqe.fit([x_data[l>l_threshold], y_data[l>l_threshold]]) + if len(l[l > l_threshold]) >= min_num_points: # at least 6 tracked points for annotation quality data + lsqe = LSqEllipse() # make fitting object + lsqe.fit([x_data[l > l_threshold], y_data[l > l_threshold]]) center, width, height, phi = lsqe.parameters() - ellipse_dict = {'center_x' : center[0], 'center_y' : center[1], 'width' : width, 'height' : height, 'phi' : phi} + ellipse_dict = { + "center_x": center[0], + "center_y": center[1], + "width": width, + "height": height, + "phi": phi, + } else: - ellipse_dict = {'center_x' : np.nan, 'center_y' : np.nan, 'width' : np.nan, 'height' : np.nan, 'phi' : np.nan} + ellipse_dict = { + "center_x": np.nan, + "center_y": np.nan, + "width": np.nan, + "height": np.nan, + "phi": np.nan, + } except Exception as e: - ellipse_dict = {'center_x' : np.nan, 'center_y' : np.nan, 'width' : np.nan, 'height' : np.nan, 'phi' : np.nan} + ellipse_dict = {"center_x": np.nan, "center_y": np.nan, "width": np.nan, "height": np.nan, "phi": np.nan} print(e) - pupil.append(ellipse_dict) - - pd.DataFrame(cr).to_hdf(ellipse_output_data_file, key='cr', mode='w') #overwrite file - pd.DataFrame(eye).to_hdf(ellipse_output_data_file, key='eye', mode='a') - pd.DataFrame(pupil).to_hdf(ellipse_output_data_file, key='pupil', mode='a') + pupil.append(ellipse_dict) + pd.DataFrame(cr).to_hdf(ellipse_output_data_file, key="cr", mode="w") # overwrite file + pd.DataFrame(eye).to_hdf(ellipse_output_data_file, key="eye", mode="a") + pd.DataFrame(pupil).to_hdf(ellipse_output_data_file, key="pupil", mode="a") blob2 = tgt_bucket.blob(ellipse_bucket_data_blobname) blob2.upload_from_filename(filename=ellipse_output_data_file) - - + return cr, eye, pupil - + + initialization_time = time.time() - t0 ellipse_fit_t0 = time.time() -cr, eye, pupil = fit_ellipse(h5file_path) +cr, eye, pupil = fit_ellipse(h5file_path) ellipse_fit_time = time.time() - ellipse_fit_t0 -logger.info('Initialization Time: {}'.format(initialization_time)) -logger.info('Ellipse Fit Time: {}'.format(ellipse_fit_time)) -logger.info('Total Walltime: {}'.format(time.time()-t0)) - - - +logger.info("Initialization Time: {}".format(initialization_time)) +logger.info("Ellipse Fit Time: {}".format(ellipse_fit_time)) +logger.info("Total Walltime: {}".format(time.time() - t0)) diff --git a/allensdk/brain_observatory/eye_tracking/stage_3/DLC_Labeled_Video.py b/allensdk/brain_observatory/eye_tracking/stage_3/DLC_Labeled_Video.py index d41f73b67c..37a0a025cb 100644 --- a/allensdk/brain_observatory/eye_tracking/stage_3/DLC_Labeled_Video.py +++ b/allensdk/brain_observatory/eye_tracking/stage_3/DLC_Labeled_Video.py @@ -1,7 +1,9 @@ import time + t0 = time.time() import os -os.environ["DLClight"]="True" + +os.environ["DLClight"] = "True" import deeplabcut from moviepy.editor import * @@ -10,9 +12,9 @@ ch = logging.StreamHandler() -formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") ch.setFormatter(formatter) -logger = logging.getLogger('dlc-eye-tracking') +logger = logging.getLogger("dlc-eye-tracking") logger.setLevel(logging.INFO) logger.addHandler(ch) logger.propagate = False @@ -24,16 +26,17 @@ args = parser.parse_args() video_file_path = args.video_input_file -bucket_data_blobname = video_file_path[:-4] + 'DeepCut_resnet50_universal_eye_trackingJul10shuffle1_1030000_labeled.mp4' -output_data_file = '/workdir/{}'.format(bucket_data_blobname) +bucket_data_blobname = video_file_path[:-4] + "DeepCut_resnet50_universal_eye_trackingJul10shuffle1_1030000_labeled.mp4" +output_data_file = "/workdir/{}".format(bucket_data_blobname) from google.cloud import storage + client = storage.Client() -src_bucket = client.get_bucket('brain-observatory-eye-videos') -tgt_bucket = client.get_bucket('dlc-labeled-videos') +src_bucket = client.get_bucket("brain-observatory-eye-videos") +tgt_bucket = client.get_bucket("dlc-labeled-videos") blob = src_bucket.get_blob(video_file_path) blob.download_to_filename(video_file_path) -path_config_file = '/workdir/model/config.yaml' +path_config_file = "/workdir/model/config.yaml" initialization_time = time.time() - t0 dlc_analysis_t0 = time.time() @@ -45,12 +48,11 @@ dlc_movie_time = time.time() - dlc_movie_t0 - blob2 = tgt_bucket.blob(bucket_data_blobname) blob2.upload_from_filename(filename=output_data_file) -logger.info('Initialization Time: {}'.format(initialization_time)) -logger.info('DLC Analysis Time: {}'.format(dlc_analysis_time)) -logger.info('DLC Movie Generation Time: {}'.format(dlc_movie_time)) -logger.info('Total Walltime: {}'.format(time.time()-t0)) +logger.info("Initialization Time: {}".format(initialization_time)) +logger.info("DLC Analysis Time: {}".format(dlc_analysis_time)) +logger.info("DLC Movie Generation Time: {}".format(dlc_movie_time)) +logger.info("Total Walltime: {}".format(time.time() - t0)) diff --git a/allensdk/brain_observatory/eye_tracking/stage_4/DLC_Ellipse_Video.py b/allensdk/brain_observatory/eye_tracking/stage_4/DLC_Ellipse_Video.py index a9624396fb..317f371df1 100644 --- a/allensdk/brain_observatory/eye_tracking/stage_4/DLC_Ellipse_Video.py +++ b/allensdk/brain_observatory/eye_tracking/stage_4/DLC_Ellipse_Video.py @@ -1,4 +1,5 @@ import time + t0 = time.time() import os @@ -14,9 +15,9 @@ ch = logging.StreamHandler() -formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") ch.setFormatter(formatter) -logger = logging.getLogger('dlc-ellipse-fitting') +logger = logging.getLogger("dlc-ellipse-fitting") logger.setLevel(logging.INFO) logger.addHandler(ch) logger.propagate = False @@ -26,16 +27,16 @@ args = parser.parse_args() video_file_path = args.video_input_file -ellipse_bucket_data_blobname = '{}.h5'.format(os.path.splitext(video_file_path)[0]) -source_ellipse_data_file = '/workdir/{}'.format(ellipse_bucket_data_blobname) +ellipse_bucket_data_blobname = "{}.h5".format(os.path.splitext(video_file_path)[0]) +source_ellipse_data_file = "/workdir/{}".format(ellipse_bucket_data_blobname) source_ellipse_data_file = ellipse_bucket_data_blobname client = storage.Client() -fit_src_bucket = client.get_bucket('dlc-ellipse-fitting') +fit_src_bucket = client.get_bucket("dlc-ellipse-fitting") blob = fit_src_bucket.get_blob(ellipse_bucket_data_blobname) blob.download_to_filename(source_ellipse_data_file) -movie_src_bucket = client.get_bucket('brain-observatory-eye-videos') +movie_src_bucket = client.get_bucket("brain-observatory-eye-videos") blob = movie_src_bucket.get_blob(video_file_path) blob.download_to_filename(video_file_path) @@ -43,32 +44,62 @@ ellipse_output_video_file = "/workdir/{}".format(ellipse_output_blob_name) ellipse_output_video_file = ellipse_output_blob_name -cr = pd.read_hdf(source_ellipse_data_file, key='cr') -eye = pd.read_hdf(source_ellipse_data_file, key='eye') -pupil = pd.read_hdf(source_ellipse_data_file, key='pupil') +cr = pd.read_hdf(source_ellipse_data_file, key="cr") +eye = pd.read_hdf(source_ellipse_data_file, key="eye") +pupil = pd.read_hdf(source_ellipse_data_file, key="pupil") -def make_frame(t): - fi = int(np.round(t*fps)) +def make_frame(t): + fi = int(np.round(t * fps)) ax.clear() ax.imshow(clip.get_frame(t)) - #that is the pupi; ellipse in red + # that is the pupi; ellipse in red try: - ellipse = Ellipse((cr.loc[fi]['center_x'], cr.loc[fi]['center_y']), 2*cr.loc[fi]['width'], 2*cr.loc[fi]['height'], np.rad2deg(cr.loc[fi]['phi']), alpha=0.8, ec='r', fc=None, lw=2, fill=False) + ellipse = Ellipse( + (cr.loc[fi]["center_x"], cr.loc[fi]["center_y"]), + 2 * cr.loc[fi]["width"], + 2 * cr.loc[fi]["height"], + np.rad2deg(cr.loc[fi]["phi"]), + alpha=0.8, + ec="r", + fc=None, + lw=2, + fill=False, + ) ax.add_patch(ellipse) except Exception as e: print(e) - #that is the eye ellipse in green + # that is the eye ellipse in green try: - ellipse = Ellipse((eye.loc[fi]['center_x'], eye.loc[fi]['center_y']), 2*eye.loc[fi]['width'], 2*eye.loc[fi]['height'], np.rad2deg(eye.loc[fi]['phi']), alpha=0.8, ec='g', fc=None, lw=2, fill=False) + ellipse = Ellipse( + (eye.loc[fi]["center_x"], eye.loc[fi]["center_y"]), + 2 * eye.loc[fi]["width"], + 2 * eye.loc[fi]["height"], + np.rad2deg(eye.loc[fi]["phi"]), + alpha=0.8, + ec="g", + fc=None, + lw=2, + fill=False, + ) ax.add_patch(ellipse) except Exception as e: print(e) - #Corneal reflection in blue + # Corneal reflection in blue try: - ellipse = Ellipse((pupil.loc[fi]['center_x'], pupil.loc[fi]['center_y']), 2*pupil.loc[fi]['width'], 2*pupil.loc[fi]['height'], np.rad2deg(pupil.loc[fi]['phi']), alpha=0.8, ec='b', fc=None, lw=2, fill=False) + ellipse = Ellipse( + (pupil.loc[fi]["center_x"], pupil.loc[fi]["center_y"]), + 2 * pupil.loc[fi]["width"], + 2 * pupil.loc[fi]["height"], + np.rad2deg(pupil.loc[fi]["phi"]), + alpha=0.8, + ec="b", + fc=None, + lw=2, + fill=False, + ) ax.add_patch(ellipse) ax.set_axis_off() except Exception as e: @@ -76,6 +107,7 @@ def make_frame(t): return mplfig_to_npimage(fig) + initialization_time = time.time() - t0 ellipse_video_t0 = time.time() @@ -90,15 +122,12 @@ def make_frame(t): animation = VideoClip(make_frame, duration=clip.duration).resize(newsize=clip.size) animation.write_videofile(ellipse_output_video_file, fps=fps) -tgt_bucket = client.get_bucket('dlc-ellipse-videos') +tgt_bucket = client.get_bucket("dlc-ellipse-videos") blob2 = tgt_bucket.blob(ellipse_output_blob_name) blob2.upload_from_filename(filename=ellipse_output_video_file) ellipse_video_time = time.time() - ellipse_video_t0 -logger.info('Initialization Time: {}'.format(initialization_time)) +logger.info("Initialization Time: {}".format(initialization_time)) # logger.info('Ellipse Video Time: {}'.format(ellipse_video_time)) -logger.info('Total Walltime: {}'.format(time.time()-t0)) - - - +logger.info("Total Walltime: {}".format(time.time() - t0)) diff --git a/allensdk/brain_observatory/findlevel.py b/allensdk/brain_observatory/findlevel.py index 09d3c28c2f..29416ffd46 100644 --- a/allensdk/brain_observatory/findlevel.py +++ b/allensdk/brain_observatory/findlevel.py @@ -36,11 +36,11 @@ import numpy as np -def findlevel(inwave, threshold, direction='both'): +def findlevel(inwave, threshold, direction="both"): temp = inwave - threshold - if (direction.find("up") + 1): + if direction.find("up") + 1: crossings = np.nonzero(np.ediff1d(np.sign(temp), to_begin=0.0) > 0) - elif (direction.find("down") + 1): + elif direction.find("down") + 1: crossings = np.nonzero(np.ediff1d(np.sign(temp), to_begin=0.0) < 0) else: crossings = np.nonzero(np.ediff1d(np.sign(temp), to_begin=0.0)) diff --git a/allensdk/brain_observatory/gaze_mapping/__main__.py b/allensdk/brain_observatory/gaze_mapping/__main__.py index 71d17fb946..9cd236460d 100644 --- a/allensdk/brain_observatory/gaze_mapping/__main__.py +++ b/allensdk/brain_observatory/gaze_mapping/__main__.py @@ -9,17 +9,12 @@ import allensdk -from allensdk.brain_observatory.argschema_utilities import ( - write_or_print_outputs -) -from allensdk.brain_observatory.gaze_mapping._schemas import ( - InputSchema, - OutputSchema -) +from allensdk.brain_observatory.argschema_utilities import write_or_print_outputs +from allensdk.brain_observatory.gaze_mapping._schemas import InputSchema, OutputSchema from allensdk.brain_observatory.gaze_mapping._gaze_mapper import ( compute_circular_areas, compute_elliptical_areas, - GazeMapper + GazeMapper, ) from allensdk.brain_observatory.gaze_mapping._filter_utils import ( post_process_areas, @@ -58,18 +53,17 @@ def load_ellipse_fit_params(input_file: Path) -> Dict[str, pd.DataFrame]: cr_params = pd.read_hdf(input_file, key="cr").astype(float) eye_params = pd.read_hdf(input_file, key="eye").astype(float) - num_frames_match = ((pupil_params.shape[0] == cr_params.shape[0]) - and (cr_params.shape[0] == eye_params.shape[0])) + num_frames_match = (pupil_params.shape[0] == cr_params.shape[0]) and (cr_params.shape[0] == eye_params.shape[0]) if not num_frames_match: - raise RuntimeError("The number of frames for ellipse fits don't " - "match when they should: " - f"pupil_params ({pupil_params.shape[0]}), " - f"cr_params ({cr_params.shape[0]}), " - f"eye_params ({eye_params.shape[0]}).") + raise RuntimeError( + "The number of frames for ellipse fits don't " + "match when they should: " + f"pupil_params ({pupil_params.shape[0]}), " + f"cr_params ({cr_params.shape[0]}), " + f"eye_params ({eye_params.shape[0]})." + ) - return {"pupil_params": pupil_params, - "cr_params": cr_params, - "eye_params": eye_params} + return {"pupil_params": pupil_params, "cr_params": cr_params, "eye_params": eye_params} def preprocess_input_args(parser_args: dict) -> dict: @@ -95,29 +89,52 @@ def preprocess_input_args(parser_args: dict) -> dict: new_args["session_sync_file"] = parser_args["session_sync_file"] new_args["output_file"] = parser_args["output_file"] - monitor_position = np.array([parser_args["monitor_position_x_mm"], - parser_args["monitor_position_y_mm"], - parser_args["monitor_position_z_mm"]]) / 10 + monitor_position = ( + np.array( + [ + parser_args["monitor_position_x_mm"], + parser_args["monitor_position_y_mm"], + parser_args["monitor_position_z_mm"], + ] + ) + / 10 + ) new_args["monitor_position"] = monitor_position - monitor_rotations_deg = np.array([parser_args["monitor_rotation_x_deg"], - parser_args["monitor_rotation_y_deg"], - parser_args["monitor_rotation_z_deg"]]) + monitor_rotations_deg = np.array( + [ + parser_args["monitor_rotation_x_deg"], + parser_args["monitor_rotation_y_deg"], + parser_args["monitor_rotation_z_deg"], + ] + ) new_args["monitor_rotations"] = np.radians(monitor_rotations_deg) - camera_position = np.array([parser_args["camera_position_x_mm"], - parser_args["camera_position_y_mm"], - parser_args["camera_position_z_mm"]]) / 10 + camera_position = ( + np.array( + [ + parser_args["camera_position_x_mm"], + parser_args["camera_position_y_mm"], + parser_args["camera_position_z_mm"], + ] + ) + / 10 + ) new_args["camera_position"] = camera_position - camera_rotations_deg = np.array([parser_args["camera_rotation_x_deg"], - parser_args["camera_rotation_y_deg"], - parser_args["camera_rotation_z_deg"]]) + camera_rotations_deg = np.array( + [ + parser_args["camera_rotation_x_deg"], + parser_args["camera_rotation_y_deg"], + parser_args["camera_rotation_z_deg"], + ] + ) new_args["camera_rotations"] = np.radians(camera_rotations_deg) - led_position = np.array([parser_args["led_position_x_mm"], - parser_args["led_position_y_mm"], - parser_args["led_position_z_mm"]]) / 10 + led_position = ( + np.array([parser_args["led_position_x_mm"], parser_args["led_position_y_mm"], parser_args["led_position_z_mm"]]) + / 10 + ) new_args["led_position"] = led_position new_args["eye_radius_cm"] = parser_args["eye_radius_cm"] new_args["cm_per_pixel"] = parser_args["cm_per_pixel"] @@ -128,16 +145,18 @@ def preprocess_input_args(parser_args: dict) -> dict: return new_args -def run_gaze_mapping(pupil_parameters: pd.DataFrame, - cr_parameters: pd.DataFrame, - eye_parameters: pd.DataFrame, - monitor_position: np.ndarray, - monitor_rotations: np.ndarray, - camera_position: np.ndarray, - camera_rotations: np.ndarray, - led_position: np.ndarray, - eye_radius_cm: float, - cm_per_pixel: float) -> dict: +def run_gaze_mapping( + pupil_parameters: pd.DataFrame, + cr_parameters: pd.DataFrame, + eye_parameters: pd.DataFrame, + monitor_position: np.ndarray, + monitor_rotations: np.ndarray, + camera_position: np.ndarray, + camera_rotations: np.ndarray, + led_position: np.ndarray, + eye_radius_cm: float, + cm_per_pixel: float, +) -> dict: """Map gaze positions onto monitor coordinates and calculate eye/pupil areas @@ -178,13 +197,15 @@ def run_gaze_mapping(pupil_parameters: pd.DataFrame, """ output = {} - gaze_mapper = GazeMapper(monitor_position=monitor_position, - monitor_rotations=monitor_rotations, - led_position=led_position, - camera_position=camera_position, - camera_rotations=camera_rotations, - eye_radius=eye_radius_cm, - cm_per_pixel=cm_per_pixel) + gaze_mapper = GazeMapper( + monitor_position=monitor_position, + monitor_rotations=monitor_rotations, + led_position=led_position, + camera_position=camera_position, + camera_rotations=camera_rotations, + eye_radius=eye_radius_cm, + cm_per_pixel=cm_per_pixel, + ) pupil_params_in_cm = pupil_parameters * cm_per_pixel raw_pupil_areas = compute_circular_areas(pupil_params_in_cm) @@ -194,7 +215,7 @@ def run_gaze_mapping(pupil_parameters: pd.DataFrame, raw_pupil_on_monitor_cm = gaze_mapper.pupil_position_on_monitor_in_cm( cam_pupil_params=pupil_parameters[["center_x", "center_y"]].values, - cam_cr_params=cr_parameters[["center_x", "center_y"]].values + cam_cr_params=cr_parameters[["center_x", "center_y"]].values, ) raw_pupil_on_monitor_deg = gaze_mapper.pupil_position_on_monitor_in_degrees( @@ -203,9 +224,7 @@ def run_gaze_mapping(pupil_parameters: pd.DataFrame, # Make bool mask for all time indices where # pupil_area or eye_area or pupil_on_monitor_* is np.nan - raw_nan_mask = (raw_pupil_areas.isna() - | raw_eye_areas.isna() - | np.isnan(raw_pupil_on_monitor_deg.T[0])) + raw_nan_mask = raw_pupil_areas.isna() | raw_eye_areas.isna() | np.isnan(raw_pupil_on_monitor_deg.T[0]) raw_pupil_areas[raw_nan_mask] = np.nan raw_eye_areas[raw_nan_mask] = np.nan raw_pupil_on_monitor_cm[raw_nan_mask, :] = np.nan @@ -224,15 +243,9 @@ def run_gaze_mapping(pupil_parameters: pd.DataFrame, new_pupil_areas = post_process_areas(new_pupil_areas.values) new_eye_areas = post_process_areas(new_eye_areas.values) - _, filtered_pos_indices = post_process_cr(cr_parameters[["center_x", - "center_y", - "phi", - "width", - "height"]].values) - - new_nan_mask = (np.isnan(new_pupil_areas) - | np.isnan(new_eye_areas) - | filtered_pos_indices) + _, filtered_pos_indices = post_process_cr(cr_parameters[["center_x", "center_y", "phi", "width", "height"]].values) + + new_nan_mask = np.isnan(new_pupil_areas) | np.isnan(new_eye_areas) | filtered_pos_indices new_pupil_areas[new_nan_mask] = np.nan new_eye_areas[new_nan_mask] = np.nan new_pupil_on_monitor_cm[new_nan_mask, :] = np.nan @@ -246,8 +259,7 @@ def run_gaze_mapping(pupil_parameters: pd.DataFrame, return output -def write_gaze_mapping_output_to_h5(output_savepath: Path, - gaze_map_output: dict): +def write_gaze_mapping_output_to_h5(output_savepath: Path, gaze_map_output: dict): """Write output of gaze mapping to an h5 file. Args: @@ -260,12 +272,16 @@ def write_gaze_mapping_output_to_h5(output_savepath: Path, gaze_map_output["raw_eye_areas"].to_hdf(output_savepath, key="raw_eye_areas", mode="w") gaze_map_output["raw_pupil_areas"].to_hdf(output_savepath, key="raw_pupil_areas", mode="a") gaze_map_output["raw_pupil_on_monitor_cm"].to_hdf(output_savepath, key="raw_screen_coordinates", mode="a") - gaze_map_output["raw_pupil_on_monitor_deg"].to_hdf(output_savepath, key="raw_screen_coordinates_spherical", mode="a") + gaze_map_output["raw_pupil_on_monitor_deg"].to_hdf( + output_savepath, key="raw_screen_coordinates_spherical", mode="a" + ) gaze_map_output["new_eye_areas"].to_hdf(output_savepath, key="new_eye_areas", mode="a") gaze_map_output["new_pupil_areas"].to_hdf(output_savepath, key="new_pupil_areas", mode="a") gaze_map_output["new_pupil_on_monitor_cm"].to_hdf(output_savepath, key="new_screen_coordinates", mode="a") - gaze_map_output["new_pupil_on_monitor_deg"].to_hdf(output_savepath, key="new_screen_coordinates_spherical", mode="a") + gaze_map_output["new_pupil_on_monitor_deg"].to_hdf( + output_savepath, key="new_screen_coordinates_spherical", mode="a" + ) gaze_map_output["synced_frame_timestamps_sec"].to_hdf(output_savepath, key="synced_frame_timestamps", mode="a") @@ -273,9 +289,7 @@ def write_gaze_mapping_output_to_h5(output_savepath: Path, version.to_hdf(output_savepath, key="version", mode="a") -def load_sync_file_timings(sync_file: Path, - pupil_params_rows: int, - truncate_timestamps: bool) -> pd.Series: +def load_sync_file_timings(sync_file: Path, pupil_params_rows: int, truncate_timestamps: bool) -> pd.Series: """Load sync file timings from .h5 file. Parameters @@ -300,42 +314,44 @@ def load_sync_file_timings(sync_file: Path, up with number of new frame times from the sync file. """ # Add synchronized frame times - frame_times = su.get_synchronized_frame_times(session_sync_file=sync_file, - sync_line_label_keys=Dataset.EYE_TRACKING_KEYS, - trim_after_spike=truncate_timestamps) - if (pupil_params_rows != len(frame_times)): - raise RuntimeError("The number of camera sync pulses in the " - f"sync file ({len(frame_times)}) do not match " - "with the number of eye tracking frames " - f"({pupil_params_rows})!!!") + frame_times = su.get_synchronized_frame_times( + session_sync_file=sync_file, + sync_line_label_keys=Dataset.EYE_TRACKING_KEYS, + trim_after_spike=truncate_timestamps, + ) + if pupil_params_rows != len(frame_times): + raise RuntimeError( + "The number of camera sync pulses in the " + f"sync file ({len(frame_times)}) do not match " + "with the number of eye tracking frames " + f"({pupil_params_rows})!!!" + ) return frame_times def main(): + logging.basicConfig(format=("%(asctime)s:%(funcName)s:%(levelname)s:%(message)s")) - logging.basicConfig(format=('%(asctime)s:%(funcName)s' - ':%(levelname)s:%(message)s')) - - parser = ArgSchemaParser(args=sys.argv[1:], - schema_type=InputSchema, - output_schema_type=OutputSchema) + parser = ArgSchemaParser(args=sys.argv[1:], schema_type=InputSchema, output_schema_type=OutputSchema) args = preprocess_input_args(parser.args) - output = run_gaze_mapping(pupil_parameters=args["pupil_params"], - cr_parameters=args["cr_params"], - eye_parameters=args["eye_params"], - monitor_position=args["monitor_position"], - monitor_rotations=args["monitor_rotations"], - camera_position=args["camera_position"], - camera_rotations=args["camera_rotations"], - led_position=args["led_position"], - eye_radius_cm=args["eye_radius_cm"], - cm_per_pixel=args["cm_per_pixel"]) - - output["synced_frame_timestamps_sec"] = load_sync_file_timings(args["session_sync_file"], - args["pupil_params"].shape[0], - parser.args["truncate_timestamps"]) + output = run_gaze_mapping( + pupil_parameters=args["pupil_params"], + cr_parameters=args["cr_params"], + eye_parameters=args["eye_params"], + monitor_position=args["monitor_position"], + monitor_rotations=args["monitor_rotations"], + camera_position=args["camera_position"], + camera_rotations=args["camera_rotations"], + led_position=args["led_position"], + eye_radius_cm=args["eye_radius_cm"], + cm_per_pixel=args["cm_per_pixel"], + ) + + output["synced_frame_timestamps_sec"] = load_sync_file_timings( + args["session_sync_file"], args["pupil_params"].shape[0], parser.args["truncate_timestamps"] + ) write_gaze_mapping_output_to_h5(args["output_file"], output) module_output = {"screen_mapping_file": str(args["output_file"])} diff --git a/allensdk/brain_observatory/gaze_mapping/_filter_utils.py b/allensdk/brain_observatory/gaze_mapping/_filter_utils.py index 669d410015..32f3b4b1c4 100644 --- a/allensdk/brain_observatory/gaze_mapping/_filter_utils.py +++ b/allensdk/brain_observatory/gaze_mapping/_filter_utils.py @@ -3,13 +3,13 @@ def medfilt_custom(x, kernel_size=3): - '''This median filter returns 'nan' whenever any value in the kernal width - is 'nan' and the median otherwise''' + """This median filter returns 'nan' whenever any value in the kernal width + is 'nan' and the median otherwise""" T = x.shape[0] delta = kernel_size // 2 x_med = np.zeros(x.shape) - window = x[0:delta + 1] + window = x[0 : delta + 1] if np.any(np.isnan(window)): x_med[0] = np.nan else: @@ -17,7 +17,7 @@ def medfilt_custom(x, kernel_size=3): # print window for t in range(1, T): - window = x[t - delta:t + delta + 1] + window = x[t - delta : t + delta + 1] # print window if np.any(np.isnan(window)): x_med[t] = np.nan @@ -28,7 +28,7 @@ def medfilt_custom(x, kernel_size=3): def median_absolute_deviation(a, consistency_constant=1.4826): - '''Calculate the median absolute deviation of a univariate dataset. + """Calculate the median absolute deviation of a univariate dataset. Parameters ---------- @@ -42,7 +42,7 @@ def median_absolute_deviation(a, consistency_constant=1.4826): ------- float Median absolute deviation of the data. - ''' + """ return consistency_constant * np.nanmedian(np.abs(a - np.nanmedian(a))) @@ -69,8 +69,7 @@ def post_process_cr(cr_params): # compute a threshold on the area of the cr ellipse dev = median_absolute_deviation(area) if dev == 0: - logging.warning("Median absolute deviation is 0," - "falling back to standard deviation.") + logging.warning("Median absolute deviation is 0,falling back to standard deviation.") dev = np.nanstd(area) threshold = np.nanmedian(area) + 3 * dev @@ -97,8 +96,8 @@ def post_process_cr(cr_params): std_y = np.std(y_center_med[y_mask_finite]) # set these extreme values to nan - x_center_med[np.abs(x_center_med - mean_x) > 3*std_x] = np.nan - y_center_med[np.abs(y_center_med - mean_y) > 3*std_y] = np.nan + x_center_med[np.abs(x_center_med - mean_x) > 3 * std_x] = np.nan + y_center_med[np.abs(y_center_med - mean_y) > 3 * std_y] = np.nan either_nan_mask = np.isnan(x_center_med) | np.isnan(y_center_med) x_center_med[either_nan_mask] = np.nan @@ -112,7 +111,7 @@ def post_process_cr(cr_params): def post_process_areas(areas: np.ndarray, percent_thresh: int = 99): - '''Filter pupil or eye area data by replacing outliers with nan + """Filter pupil or eye area data by replacing outliers with nan Parameters ---------- @@ -125,7 +124,7 @@ def post_process_areas(areas: np.ndarray, percent_thresh: int = 99): ------- numpy.ndarray Eye/pupil areas with outliers replaced with nan - ''' + """ threshold = np.percentile(areas[np.isfinite(areas)], percent_thresh) outlier_indices = areas > threshold areas[outlier_indices] = np.nan diff --git a/allensdk/brain_observatory/gaze_mapping/_gaze_mapper.py b/allensdk/brain_observatory/gaze_mapping/_gaze_mapper.py index d3c51fd6ae..512b3aebf0 100644 --- a/allensdk/brain_observatory/gaze_mapping/_gaze_mapper.py +++ b/allensdk/brain_observatory/gaze_mapping/_gaze_mapper.py @@ -17,9 +17,8 @@ class EyeTrackingRigObject(object): in the rig object's own coordinate system """ - def __init__(self, - position_in_eye_coord_frame: np.ndarray, - rotations_in_self_coord_frame: np.ndarray): + + def __init__(self, position_in_eye_coord_frame: np.ndarray, rotations_in_self_coord_frame: np.ndarray): self.position = position_in_eye_coord_frame self.rotations = rotations_in_self_coord_frame @@ -64,13 +63,13 @@ def generate_self_to_eye_frame_xform(self) -> Rotation: # Determine rotation in ECS needed to rotate obj_norm vector so that # its x-axis aligns with the ECS x-axis. theta_z = -(np.pi / 2 + np.arctan2(obj_norm[1], obj_norm[0])) - rz = Rotation.from_euler('z', theta_z, degrees=False) + rz = Rotation.from_euler("z", theta_z, degrees=False) obj_norm_prime = rz.apply(obj_norm) # Determine rotation in ECS needed to rotate transformed obj_norm # vector so that its z-axis aligns with the ECS z-axis theta_x = np.pi / 2 - np.arctan2(obj_norm_prime[2], obj_norm_prime[1]) - rx = Rotation.from_euler('x', theta_x, degrees=False) + rx = Rotation.from_euler("x", theta_x, degrees=False) # Compose rotations, note the order! eye_to_object_xform = rx * rz @@ -101,21 +100,26 @@ class GazeMapper(object): cm_per_pixel : float Pixel size of eye-tracking camera. """ - def __init__(self, - monitor_position: np.ndarray, - monitor_rotations: np.ndarray, - led_position: np.ndarray, - camera_position: np.ndarray, - camera_rotations: np.ndarray, - eye_radius: float, - cm_per_pixel: float): + + def __init__( + self, + monitor_position: np.ndarray, + monitor_rotations: np.ndarray, + led_position: np.ndarray, + camera_position: np.ndarray, + camera_rotations: np.ndarray, + eye_radius: float, + cm_per_pixel: float, + ): self.eye_radius = eye_radius self.cm_per_pixel = cm_per_pixel self.led_pos = led_position - self.monitor = EyeTrackingRigObject(position_in_eye_coord_frame=monitor_position, - rotations_in_self_coord_frame=monitor_rotations) - self.camera = EyeTrackingRigObject(position_in_eye_coord_frame=camera_position, - rotations_in_self_coord_frame=camera_rotations) + self.monitor = EyeTrackingRigObject( + position_in_eye_coord_frame=monitor_position, rotations_in_self_coord_frame=monitor_rotations + ) + self.camera = EyeTrackingRigObject( + position_in_eye_coord_frame=camera_position, rotations_in_self_coord_frame=camera_rotations + ) self.cr = self.compute_cr_coordinate() def compute_cr_coordinate(self) -> np.ndarray: @@ -158,16 +162,13 @@ def compute_cr_coordinate(self) -> np.ndarray: # Undo mirror pole offset image_dist_from_origin = self.eye_radius + image_dist image_height = -(image_dist / object_dist) * object_height - image_dist_from_origin_mag = np.linalg.norm([image_height, - image_dist_from_origin]) + image_dist_from_origin_mag = np.linalg.norm([image_height, image_dist_from_origin]) # To get full 3D position of virtual image we multiply the LED unit # position vector with magnitude of the image distance from origin. - led_unit_position_vec = (self.led_pos / np.linalg.norm(self.led_pos)) + led_unit_position_vec = self.led_pos / np.linalg.norm(self.led_pos) return led_unit_position_vec * image_dist_from_origin_mag - def pupil_pos_in_eye_coords(self, - cam_pupil_params: np.ndarray, - cam_cr_params: np.ndarray) -> np.ndarray: + def pupil_pos_in_eye_coords(self, cam_pupil_params: np.ndarray, cam_cr_params: np.ndarray) -> np.ndarray: """Compute the 3D pupil position in eye coordinates. Parameters @@ -198,13 +199,13 @@ def pupil_pos_in_eye_coords(self, py_cam = cr_pos_in_cam_coord_frame[1] + delta_py # np.sqrt(np.array([-5, 25])) will result in np.array([np.nan, 5.]) # and an 'invalid' value RuntimeWarning which is fine - with np.errstate(invalid='ignore'): + with np.errstate(invalid="ignore"): pz_cam = np.sqrt(self.eye_radius**2 - px_cam**2 - py_cam**2) # Find and assign np.nan to pupil positions which land outside of eyeball radius. # An operation like: np.array([np.nan, 5, 1]) > 2 will result in array([False, True, False]) # and an 'invalid' value RuntimeWarning which is fine - with np.errstate(invalid='ignore'): + with np.errstate(invalid="ignore"): bad_idx = np.linalg.norm([px_cam, py_cam], axis=0) > self.eye_radius px_cam[bad_idx] = np.nan py_cam[bad_idx] = np.nan @@ -218,9 +219,7 @@ def pupil_pos_in_eye_coords(self, cam_to_eye_xform = R_eye_to_cam.inv() * R_cam.inv() return cam_to_eye_xform.apply(pupil_pos_cam) - def pupil_position_on_monitor_in_cm(self, - cam_pupil_params: np.ndarray, - cam_cr_params: np.ndarray) -> np.ndarray: + def pupil_position_on_monitor_in_cm(self, cam_pupil_params: np.ndarray, cam_cr_params: np.ndarray) -> np.ndarray: """Compute the pupil position on the monitor in cm. General strategy: @@ -249,16 +248,17 @@ def pupil_position_on_monitor_in_cm(self, coordinates (in centimeters). Estimate values will have the center of the monitor as the (0, 0) origin. """ - pupil_positions = self.pupil_pos_in_eye_coords(cam_pupil_params, - cam_cr_params) + pupil_positions = self.pupil_pos_in_eye_coords(cam_pupil_params, cam_cr_params) monitor_normal = self.monitor.compute_unit_normal_in_eye_coord_frame() # Project pupil locations from origin of eye coordinate system line_points = np.tile([0, 0, 0], (pupil_positions.shape[0], 1)) - projected_positions = project_to_plane(plane_normal=monitor_normal, - plane_point=self.monitor.position, - line_vectors=pupil_positions, - line_points=line_points) + projected_positions = project_to_plane( + plane_normal=monitor_normal, + plane_point=self.monitor.position, + line_vectors=pupil_positions, + line_points=line_points, + ) monitor_positions = projected_positions - self.monitor.position @@ -270,8 +270,7 @@ def pupil_position_on_monitor_in_cm(self, # Discard z component of monitor locs as it's orthogonal to viewing plane return np.delete(result, 2, axis=1) - def pupil_position_on_monitor_in_degrees(self, - pupil_pos_on_monitor_in_cm: np.ndarray) -> np.ndarray: + def pupil_position_on_monitor_in_degrees(self, pupil_pos_on_monitor_in_cm: np.ndarray) -> np.ndarray: """Get pupil position on monitor measured in visual degrees. Parameters @@ -289,8 +288,7 @@ def pupil_position_on_monitor_in_degrees(self, mag = np.linalg.norm(self.monitor.position) meridian = np.degrees(np.arctan(x / mag)) - elevation = np.degrees(np.arctan(y / np.linalg.norm( - np.vstack([x, np.full_like(x, mag, dtype=float)]), axis=0))) + elevation = np.degrees(np.arctan(y / np.linalg.norm(np.vstack([x, np.full_like(x, mag, dtype=float)]), axis=0))) angles = np.vstack([meridian, elevation]).T @@ -345,10 +343,9 @@ def compute_elliptical_areas(ellipse_params: pd.DataFrame) -> pd.Series: return np.pi * ellipse_params["height"] * ellipse_params["width"] -def project_to_plane(plane_normal: np.ndarray, - plane_point: np.ndarray, - line_vectors: np.ndarray, - line_points: np.ndarray) -> np.ndarray: +def project_to_plane( + plane_normal: np.ndarray, plane_point: np.ndarray, line_vectors: np.ndarray, line_points: np.ndarray +) -> np.ndarray: """Find the points of intersection between a plane and a series of lines. See: https://en.wikipedia.org/wiki/Line–plane_intersection @@ -377,9 +374,7 @@ def project_to_plane(plane_normal: np.ndarray, return factors * line_vectors + line_points -def generate_object_rotation_xform(x_rotation: float, - y_rotation: float, - z_rotation: float) -> Rotation: +def generate_object_rotation_xform(x_rotation: float, y_rotation: float, z_rotation: float) -> Rotation: """Generate a matrix for rotating an object in place. Parameters @@ -396,9 +391,9 @@ def generate_object_rotation_xform(x_rotation: float, A rotation instance. See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.transform.Rotation.html """ - rx = Rotation.from_euler('x', x_rotation, degrees=False) - ry = Rotation.from_euler('y', y_rotation, degrees=False) - rz = Rotation.from_euler('z', z_rotation, degrees=False) + rx = Rotation.from_euler("x", x_rotation, degrees=False) + ry = Rotation.from_euler("y", y_rotation, degrees=False) + rz = Rotation.from_euler("z", z_rotation, degrees=False) # Compose rotations with * operator. Note the order! return rz * ry * rx diff --git a/allensdk/brain_observatory/gaze_mapping/_schemas.py b/allensdk/brain_observatory/gaze_mapping/_schemas.py index 67829c7bb0..ebd6ada427 100644 --- a/allensdk/brain_observatory/gaze_mapping/_schemas.py +++ b/allensdk/brain_observatory/gaze_mapping/_schemas.py @@ -1,109 +1,74 @@ from argschema import ArgSchema from argschema.fields import Float, LogLevel, String, Boolean, Nested -from allensdk.brain_observatory.argschema_utilities import ( - InputFile, - OutputFile, - RaisingSchema -) +from allensdk.brain_observatory.argschema_utilities import InputFile, OutputFile, RaisingSchema class InputSchema(ArgSchema): # ============== Required fields ============== input_file = InputFile( - required=True, - description=('An h5 file containing ellipses fits for ' - 'eye, pupil, and corneal reflections.') + required=True, description=("An h5 file containing ellipses fits for eye, pupil, and corneal reflections.") ) session_sync_file = InputFile( required=True, - description=('An h5 file containing timestamps to synchronize ' - 'eye tracking video frames with rest of ephys ' - 'session events.') + description=( + "An h5 file containing timestamps to synchronize " + "eye tracking video frames with rest of ephys " + "session events." + ), ) output_file = OutputFile( - required=True, - description=('Full save path of output h5 file that ' - 'will be created by this module.') + required=True, description=("Full save path of output h5 file that will be created by this module.") ) - monitor_position_x_mm = Float(required=True, - description=("Monitor center X position in " - "'global' coordinates " - "(millimeters).")) - monitor_position_y_mm = Float(required=True, - description=("Monitor center Y position in " - "'global' coordinates " - "(millimeters).")) - monitor_position_z_mm = Float(required=True, - description=("Monitor center Z position in " - "'global' coordinates " - "(millimeters).")) - monitor_rotation_x_deg = Float(required=True, - description="Monitor X rotation in degrees") - monitor_rotation_y_deg = Float(required=True, - description="Monitor Y rotation in degrees") - monitor_rotation_z_deg = Float(required=True, - description="Monitor Z rotation in degrees") - camera_position_x_mm = Float(required=True, - description=("Camera center X position in " - "'global' coordinates " - "(millimeters)")) - camera_position_y_mm = Float(required=True, - description=("Camera center Y position in " - "'global' coordinates " - "(millimeters)")) - camera_position_z_mm = Float(required=True, - description=("Camera center Z position in " - "'global' coordinates " - "(millimeters)")) - camera_rotation_x_deg = Float(required=True, - description="Camera X rotation in degrees") - camera_rotation_y_deg = Float(required=True, - description="Camera Y rotation in degrees") - camera_rotation_z_deg = Float(required=True, - description="Camera Z rotation in degrees") - led_position_x_mm = Float(required=True, - description=("LED X position in 'global' " - "coordinates (millimeters)")) - led_position_y_mm = Float(required=True, - description=("LED Y position in 'global' " - "coordinates (millimeters)")) - led_position_z_mm = Float(required=True, - description=("LED Z position in 'global' " - "coordinates (millimeters)")) - equipment = String(required=True, - description=('String describing equipment setup used ' - 'to acquire eye tracking videos.')) - date_of_acquisition = String(required=True, - description='Acquisition datetime string.') - eye_video_file = InputFile(required=True, - description=('Full path to raw eye video ' - 'file (*.avi).')) + monitor_position_x_mm = Float( + required=True, description=("Monitor center X position in 'global' coordinates (millimeters).") + ) + monitor_position_y_mm = Float( + required=True, description=("Monitor center Y position in 'global' coordinates (millimeters).") + ) + monitor_position_z_mm = Float( + required=True, description=("Monitor center Z position in 'global' coordinates (millimeters).") + ) + monitor_rotation_x_deg = Float(required=True, description="Monitor X rotation in degrees") + monitor_rotation_y_deg = Float(required=True, description="Monitor Y rotation in degrees") + monitor_rotation_z_deg = Float(required=True, description="Monitor Z rotation in degrees") + camera_position_x_mm = Float( + required=True, description=("Camera center X position in 'global' coordinates (millimeters)") + ) + camera_position_y_mm = Float( + required=True, description=("Camera center Y position in 'global' coordinates (millimeters)") + ) + camera_position_z_mm = Float( + required=True, description=("Camera center Z position in 'global' coordinates (millimeters)") + ) + camera_rotation_x_deg = Float(required=True, description="Camera X rotation in degrees") + camera_rotation_y_deg = Float(required=True, description="Camera Y rotation in degrees") + camera_rotation_z_deg = Float(required=True, description="Camera Z rotation in degrees") + led_position_x_mm = Float(required=True, description=("LED X position in 'global' coordinates (millimeters)")) + led_position_y_mm = Float(required=True, description=("LED Y position in 'global' coordinates (millimeters)")) + led_position_z_mm = Float(required=True, description=("LED Z position in 'global' coordinates (millimeters)")) + equipment = String( + required=True, description=("String describing equipment setup used to acquire eye tracking videos.") + ) + date_of_acquisition = String(required=True, description="Acquisition datetime string.") + eye_video_file = InputFile(required=True, description=("Full path to raw eye video file (*.avi).")) # ============== Optional fields ============== - eye_radius_cm = Float(default=0.1682, - description=('Radius of tracked eye(s) in ' - 'centimeters.')) - cm_per_pixel = Float(default=(10.2 / 10000.0), - description=('Centimeter per pixel conversion ' - 'ratio.')) - log_level = LogLevel(default='INFO', - description='Set the logging level of the module.') + eye_radius_cm = Float(default=0.1682, description=("Radius of tracked eye(s) in centimeters.")) + cm_per_pixel = Float(default=(10.2 / 10000.0), description=("Centimeter per pixel conversion ratio.")) + log_level = LogLevel(default="INFO", description="Set the logging level of the module.") - truncate_timestamps = Boolean(default=True, - description=('If True, truncate sync ' - 'timestamps whenever unusually ' - 'large gapes occur; ' - 'Default=True')) + truncate_timestamps = Boolean( + default=True, + description=("If True, truncate sync timestamps whenever unusually large gapes occur; Default=True"), + ) class OutputSchema(RaisingSchema): input_parameters = Nested(InputSchema) - screen_mapping_file = OutputFile(required=True, - description=( - 'Full save path of output h5 ' - 'file that will be created ' - 'by this module.')) + screen_mapping_file = OutputFile( + required=True, description=("Full save path of output h5 file that will be created by this module.") + ) diff --git a/allensdk/brain_observatory/locally_sparse_noise.py b/allensdk/brain_observatory/locally_sparse_noise.py index 474e23b0ee..af0d17867e 100644 --- a/allensdk/brain_observatory/locally_sparse_noise.py +++ b/allensdk/brain_observatory/locally_sparse_noise.py @@ -39,8 +39,7 @@ import numpy as np import pandas as pd import scipy.ndimage -from .receptive_field_analysis.receptive_field import \ - compute_receptive_field_with_postprocessing +from .receptive_field_analysis.receptive_field import compute_receptive_field_with_postprocessing from .receptive_field_analysis.visualization import plot_receptive_field_data from . import circle_plots as cplots @@ -51,7 +50,7 @@ class LocallySparseNoise(StimulusAnalysis): - """ Perform tuning analysis specific to the locally sparse noise stimulus. + """Perform tuning analysis specific to the locally sparse noise stimulus. Parameters ---------- @@ -81,8 +80,7 @@ def __init__(self, data_set, stimulus=None, **kwargs): self.stimulus = stimulus try: - lsn_dims = stimulus_info.LOCALLY_SPARSE_NOISE_DIMENSIONS[ - self.stimulus] + lsn_dims = stimulus_info.LOCALLY_SPARSE_NOISE_DIMENSIONS[self.stimulus] except KeyError: raise KeyError("Unknown stimulus name: %s" % self.stimulus) @@ -96,8 +94,7 @@ def __init__(self, data_set, stimulus=None, **kwargs): self._extralength = LocallySparseNoise._PRELOAD self._mean_response = LocallySparseNoise._PRELOAD self._receptive_field = LocallySparseNoise._PRELOAD - self._cell_index_receptive_field_analysis_data = \ - LocallySparseNoise._PRELOAD + self._cell_index_receptive_field_analysis_data = LocallySparseNoise._PRELOAD @property def LSN(self): @@ -143,10 +140,8 @@ def receptive_field(self): @property def cell_index_receptive_field_analysis_data(self): - if self._cell_index_receptive_field_analysis_data is \ - LocallySparseNoise._PRELOAD: - self._cell_index_receptive_field_analysis_data = \ - self.get_receptive_field_analysis_data() + if self._cell_index_receptive_field_analysis_data is LocallySparseNoise._PRELOAD: + self._cell_index_receptive_field_analysis_data = self.get_receptive_field_analysis_data() return self._cell_index_receptive_field_analysis_data @@ -158,28 +153,35 @@ def mean_response(self): return self._mean_response def get_peak(self): - LocallySparseNoise._log.info('Calculating peak response properties') - - peak = pd.DataFrame(index=range(self.numbercells), columns=( - 'rf_center_on_x_lsn', 'rf_center_on_y_lsn', - 'rf_center_off_x_lsn', 'rf_center_off_y_lsn', - 'rf_area_on_lsn', 'rf_area_off_lsn', - 'rf_distance_lsn', 'rf_overlap_index_lsn', - 'rf_chi2_lsn', - 'cell_specimen_id')) + LocallySparseNoise._log.info("Calculating peak response properties") + + peak = pd.DataFrame( + index=range(self.numbercells), + columns=( + "rf_center_on_x_lsn", + "rf_center_on_y_lsn", + "rf_center_off_x_lsn", + "rf_center_off_y_lsn", + "rf_area_on_lsn", + "rf_area_off_lsn", + "rf_distance_lsn", + "rf_overlap_index_lsn", + "rf_chi2_lsn", + "cell_specimen_id", + ), + ) csids = self.data_set.get_cell_specimen_ids() df = self.get_receptive_field_attribute_df() peak.cell_specimen_id = csids for nc in range(self.numbercells): - peak['rf_chi2_lsn'].iloc[nc] = \ - df['chi_squared_analysis/min_p'].iloc[nc] + peak["rf_chi2_lsn"].iloc[nc] = df["chi_squared_analysis/min_p"].iloc[nc] # find the index of the largest on subunit, if it exists on_i = None - if 'on/gaussian_fit/area' in df.columns: - area_on = df['on/gaussian_fit/area'].iloc[nc] + if "on/gaussian_fit/area" in df.columns: + area_on = df["on/gaussian_fit/area"].iloc[nc] # watch out for NaNs and Nones if isinstance(area_on, np.ndarray): @@ -190,21 +192,18 @@ def get_peak(self): on_i = None if on_i is None: - peak['rf_area_on_lsn'].iloc[nc] = np.nan - peak['rf_center_on_x_lsn'].iloc[nc] = np.nan - peak['rf_center_on_y_lsn'].iloc[nc] = np.nan + peak["rf_area_on_lsn"].iloc[nc] = np.nan + peak["rf_center_on_x_lsn"].iloc[nc] = np.nan + peak["rf_center_on_y_lsn"].iloc[nc] = np.nan else: - peak['rf_area_on_lsn'].iloc[nc] = \ - df['on/gaussian_fit/area'].iloc[nc][on_i] - peak['rf_center_on_x_lsn'].iloc[nc] = \ - df['on/gaussian_fit/center_x'].iloc[nc][on_i] - peak['rf_center_on_y_lsn'].iloc[nc] = \ - df['on/gaussian_fit/center_y'].iloc[nc][on_i] + peak["rf_area_on_lsn"].iloc[nc] = df["on/gaussian_fit/area"].iloc[nc][on_i] + peak["rf_center_on_x_lsn"].iloc[nc] = df["on/gaussian_fit/center_x"].iloc[nc][on_i] + peak["rf_center_on_y_lsn"].iloc[nc] = df["on/gaussian_fit/center_y"].iloc[nc][on_i] # find the index of the largest off subunit, if it exists off_i = None - if 'off/gaussian_fit/area' in df.columns: - area_off = df['off/gaussian_fit/area'].iloc[nc] + if "off/gaussian_fit/area" in df.columns: + area_off = df["off/gaussian_fit/area"].iloc[nc] # watch out for NaNs and Nones if isinstance(area_off, np.ndarray): @@ -215,88 +214,70 @@ def get_peak(self): off_i = None if off_i is None: - peak['rf_area_off_lsn'].iloc[nc] = np.nan - peak['rf_center_off_x_lsn'].iloc[nc] = np.nan - peak['rf_center_off_y_lsn'].iloc[nc] = np.nan + peak["rf_area_off_lsn"].iloc[nc] = np.nan + peak["rf_center_off_x_lsn"].iloc[nc] = np.nan + peak["rf_center_off_y_lsn"].iloc[nc] = np.nan else: - peak['rf_area_off_lsn'].iloc[nc] = \ - df['off/gaussian_fit/area'].iloc[nc][off_i] - peak['rf_center_off_x_lsn'].iloc[nc] = \ - df['off/gaussian_fit/center_x'].iloc[nc][off_i] - peak['rf_center_off_y_lsn'].iloc[nc] = \ - df['off/gaussian_fit/center_y'].iloc[nc][off_i] + peak["rf_area_off_lsn"].iloc[nc] = df["off/gaussian_fit/area"].iloc[nc][off_i] + peak["rf_center_off_x_lsn"].iloc[nc] = df["off/gaussian_fit/center_x"].iloc[nc][off_i] + peak["rf_center_off_y_lsn"].iloc[nc] = df["off/gaussian_fit/center_y"].iloc[nc][off_i] if on_i is not None and off_i is not None: - peak['rf_distance_lsn'].iloc[nc] = \ - df['on/gaussian_fit/distance'].iloc[nc][on_i][off_i] - peak['rf_overlap_index_lsn'].iloc[nc] = \ - df['on/gaussian_fit/overlap'].iloc[nc][on_i][off_i] + peak["rf_distance_lsn"].iloc[nc] = df["on/gaussian_fit/distance"].iloc[nc][on_i][off_i] + peak["rf_overlap_index_lsn"].iloc[nc] = df["on/gaussian_fit/overlap"].iloc[nc][on_i][off_i] else: - peak['rf_distance_lsn'].iloc[nc] = np.nan - peak['rf_overlap_index_lsn'].iloc[nc] = np.nan + peak["rf_distance_lsn"].iloc[nc] = np.nan + peak["rf_overlap_index_lsn"].iloc[nc] = np.nan return peak def populate_stimulus_table(self): self._stim_table = self.data_set.get_stimulus_table(self.stimulus) - self._LSN, self._LSN_mask = \ - self.data_set.get_locally_sparse_noise_stimulus_template( - self.stimulus, mask_off_screen=False) - self._sweeplength = ( - self._stim_table['end'][1] - - self._stim_table['start'][1]) + self._LSN, self._LSN_mask = self.data_set.get_locally_sparse_noise_stimulus_template( + self.stimulus, mask_off_screen=False + ) + self._sweeplength = self._stim_table["end"][1] - self._stim_table["start"][1] self._interlength = 4 * self._sweeplength self._extralength = self._sweeplength def get_mean_response(self): logging.debug("Calculating mean responses") - mean_response = np.empty( - (self.nrows, self.ncols, self.numbercells + 1, 2)) + mean_response = np.empty((self.nrows, self.ncols, self.numbercells + 1, 2)) for xp in range(self.nrows): for yp in range(self.ncols): on_frame = np.where(self.LSN[:, xp, yp] == self.LSN_ON)[0] off_frame = np.where(self.LSN[:, xp, yp] == self.LSN_OFF)[0] - subset_on = self.mean_sweep_response[ - self.stim_table.frame.isin(on_frame)] - subset_off = self.mean_sweep_response[ - self.stim_table.frame.isin(off_frame)] + subset_on = self.mean_sweep_response[self.stim_table.frame.isin(on_frame)] + subset_off = self.mean_sweep_response[self.stim_table.frame.isin(off_frame)] mean_response[xp, yp, :, 0] = subset_on.mean(axis=0) mean_response[xp, yp, :, 1] = subset_off.mean(axis=0) return mean_response def get_receptive_field(self): - ''' Calculates receptive fields for each cell - ''' - - receptive_field = np.zeros( - (self.nrows, self.ncols, self.numbercells, 2)) - - for cell_index in range( - len(self.cell_index_receptive_field_analysis_data)): - curr_rf = self.cell_index_receptive_field_analysis_data[ - str(cell_index)] - rf_on = curr_rf['on']['rts_convolution']['data'].copy() - rf_off = curr_rf['off']['rts_convolution']['data'].copy() - rf_on[np.logical_not( - curr_rf['on']['fdr_mask']['data'].sum(axis=0))] = np.nan - rf_off[np.logical_not( - curr_rf['off']['fdr_mask']['data'].sum(axis=0))] = np.nan + """Calculates receptive fields for each cell""" + + receptive_field = np.zeros((self.nrows, self.ncols, self.numbercells, 2)) + + for cell_index in range(len(self.cell_index_receptive_field_analysis_data)): + curr_rf = self.cell_index_receptive_field_analysis_data[str(cell_index)] + rf_on = curr_rf["on"]["rts_convolution"]["data"].copy() + rf_off = curr_rf["off"]["rts_convolution"]["data"].copy() + rf_on[np.logical_not(curr_rf["on"]["fdr_mask"]["data"].sum(axis=0))] = np.nan + rf_off[np.logical_not(curr_rf["off"]["fdr_mask"]["data"].sum(axis=0))] = np.nan receptive_field[:, :, cell_index, 0] = rf_on receptive_field[:, :, cell_index, 1] = rf_off return receptive_field def get_receptive_field_analysis_data(self): - ''' Calculates receptive fields for each cell - ''' + """Calculates receptive fields for each cell""" csid_rf = {} for cell_index in range(self.data_set.number_of_cells): - csid_rf[str(cell_index)] = \ - compute_receptive_field_with_postprocessing( - self.data_set, cell_index, self.stimulus, alpha=.05, - number_of_shuffles=10000) + csid_rf[str(cell_index)] = compute_receptive_field_with_postprocessing( + self.data_set, cell_index, self.stimulus, alpha=0.05, number_of_shuffles=10000 + ) return csid_rf @@ -305,38 +286,33 @@ def plot_receptive_field_analysis_data(self, cell_index, **kwargs): return plot_receptive_field_data(rf, self, **kwargs) def get_receptive_field_attribute_df(self): - df_list = [] - for cell_index_as_str, rf in \ - self.cell_index_receptive_field_analysis_data.items(): - + for cell_index_as_str, rf in self.cell_index_receptive_field_analysis_data.items(): attribute_dict = {} for x in dict_generator(rf): - if x[-3] == 'attrs': + if x[-3] == "attrs": if len(x[:-3]) == 0: key = x[-2] else: - key = '/'.join(['/'.join(x[:-3]), x[-2]]) + key = "/".join(["/".join(x[:-3]), x[-2]]) attribute_dict[key] = x[-1] massaged_dict = {} for key, val in attribute_dict.items(): massaged_dict[key] = [val] - massaged_dict['oeid'] = self.data_set.get_metadata()[ - 'ophys_experiment_id'] + massaged_dict["oeid"] = self.data_set.get_metadata()["ophys_experiment_id"] curr_df = pd.DataFrame.from_dict(massaged_dict) df_list.append(curr_df) attribute_df = pd.concat(df_list, sort=True) - return attribute_df.sort_values('cell_index') + return attribute_df.sort_values("cell_index") @staticmethod def merge_mean_response(rc1, rc2): - """ Move out of this class, to session analysis - """ + """Move out of this class, to session analysis""" # make sure that rc1 is the larger one if rc2.shape[0] > rc1.shape[0]: @@ -348,44 +324,31 @@ def merge_mean_response(rc1, rc2): return rc1 + rc2_zoom - def plot_cell_receptive_field(self, on, cell_specimen_id=None, - color_map=None, clim=None, mask=None, - cell_index=None, scalebar=True): + def plot_cell_receptive_field( + self, on, cell_specimen_id=None, color_map=None, clim=None, mask=None, cell_index=None, scalebar=True + ): if color_map is None: - color_map = 'Reds' if on else 'Blues' + color_map = "Reds" if on else "Blues" - onst = 'on' if on else 'off' + onst = "on" if on else "off" cell_idx = self.row_from_cell_id(cell_specimen_id, cell_index) rf = self.cell_index_receptive_field_analysis_data[str(cell_idx)] - rts = rf[onst]['rts']['data'] - rts[np.logical_not(rf[onst]['fdr_mask']['data'].sum(axis=0))] = np.nan + rts = rf[onst]["rts"]["data"] + rts[np.logical_not(rf[onst]["fdr_mask"]["data"].sum(axis=0))] = np.nan - oplots.plot_receptive_field(rts, - color_map=color_map, - clim=clim, - mask=mask, - scalebar=scalebar) + oplots.plot_receptive_field(rts, color_map=color_map, clim=clim, mask=mask, scalebar=scalebar) - def plot_population_receptive_field(self, color_map='RdPu', clim=None, - mask=None, scalebar=True): + def plot_population_receptive_field(self, color_map="RdPu", clim=None, mask=None, scalebar=True): rf = np.nansum(self.receptive_field, axis=(2, 3)) - oplots.plot_receptive_field(rf, - color_map=color_map, - clim=clim, - mask=mask, - scalebar=scalebar) + oplots.plot_receptive_field(rf, color_map=color_map, clim=clim, mask=mask, scalebar=scalebar) def sort_trials(self): ds = self.data_set - lsn_movie, lsn_mask = ds.get_locally_sparse_noise_stimulus_template( - self.stimulus, - mask_off_screen=False) + lsn_movie, lsn_mask = ds.get_locally_sparse_noise_stimulus_template(self.stimulus, mask_off_screen=False) - baseline_trials = np.unique( - np.where(lsn_movie[:, -5:, -1] != LocallySparseNoise.LSN_GREY)[0]) - valid_indices = pd.Index(set(baseline_trials) & - set(self.mean_sweep_response.index.tolist())) + baseline_trials = np.unique(np.where(lsn_movie[:, -5:, -1] != LocallySparseNoise.LSN_GREY)[0]) + valid_indices = pd.Index(set(baseline_trials) & set(self.mean_sweep_response.index.tolist())) baseline_df = self.mean_sweep_response.loc[valid_indices] cell_baselines = np.nanmean(baseline_df.values, axis=0) @@ -394,29 +357,30 @@ def sort_trials(self): trials = {} for row in range(self.nrows): for col in range(self.ncols): - on_trials = np.where( - lsn_movie[:, row, col] == LocallySparseNoise.LSN_ON) - off_trials = np.where( - lsn_movie[:, row, col] == LocallySparseNoise.LSN_OFF) + on_trials = np.where(lsn_movie[:, row, col] == LocallySparseNoise.LSN_ON) + off_trials = np.where(lsn_movie[:, row, col] == LocallySparseNoise.LSN_OFF) trials[(col, row, True)] = on_trials trials[(col, row, False)] = off_trials return trials, cell_baselines - def open_pincushion_plot(self, on, cell_specimen_id=None, color_map=None, - cell_index=None): + def open_pincushion_plot(self, on, cell_specimen_id=None, color_map=None, cell_index=None): cell_index = self.row_from_cell_id(cell_specimen_id, cell_index) trials, baselines = self.sort_trials() data = self.mean_sweep_response[str(cell_index)].values - cplots.make_pincushion_plot(data, trials, on, - self.nrows, self.ncols, - clim=[baselines[cell_index], - data.mean() + data.std() * 3], - color_map=color_map, - radius=1.0 / 16.0) + cplots.make_pincushion_plot( + data, + trials, + on, + self.nrows, + self.ncols, + clim=[baselines[cell_index], data.mean() + data.std() * 3], + color_map=color_map, + radius=1.0 / 16.0, + ) @staticmethod def from_analysis_file(data_set, analysis_file, stimulus): @@ -432,23 +396,18 @@ def from_analysis_file(data_set, analysis_file, stimulus): stimulus_suffix = stimulus_info.LOCALLY_SPARSE_NOISE_8DEG_SHORT try: - with h5py.File(analysis_file, "r") as f: k = "analysis/mean_response_%s" % stimulus_suffix if k in f: lsn._mean_response = f[k][()] - lsn._sweep_response = pd.read_hdf(analysis_file, - "analysis/sweep_response_%s" % - stimulus_suffix) - lsn._mean_sweep_response = pd.read_hdf( - analysis_file, "analysis/mean_sweep_response_%s" % - stimulus_suffix) + lsn._sweep_response = pd.read_hdf(analysis_file, "analysis/sweep_response_%s" % stimulus_suffix) + lsn._mean_sweep_response = pd.read_hdf(analysis_file, "analysis/mean_sweep_response_%s" % stimulus_suffix) with h5py.File(analysis_file, "r") as f: - lsn._cell_index_receptive_field_analysis_data = \ - LocallySparseNoise.\ - read_cell_index_receptive_field_analysis(f, stimulus) + lsn._cell_index_receptive_field_analysis_data = ( + LocallySparseNoise.read_cell_index_receptive_field_analysis(f, stimulus) + ) except Exception as e: raise MissingStimulusException(e.args) @@ -456,34 +415,31 @@ def from_analysis_file(data_set, analysis_file, stimulus): return lsn @staticmethod - def save_cell_index_receptive_field_analysis( - cell_index_receptive_field_analysis_data, new_nwb, prefix): - + def save_cell_index_receptive_field_analysis(cell_index_receptive_field_analysis_data, new_nwb, prefix): attr_list = [] - file_handle = h5py.File(new_nwb.nwb_file, 'a') - if prefix in file_handle['analysis']: - del file_handle['analysis'][prefix] - f = file_handle.create_group('analysis/%s' % prefix) + file_handle = h5py.File(new_nwb.nwb_file, "a") + if prefix in file_handle["analysis"]: + del file_handle["analysis"][prefix] + f = file_handle.create_group("analysis/%s" % prefix) for x in dict_generator(cell_index_receptive_field_analysis_data): - if x[-2] == 'data': - f['/'.join(x[:-1])] = x[-1] - elif x[-3] == 'attrs': + if x[-2] == "data": + f["/".join(x[:-1])] = x[-1] + elif x[-3] == "attrs": attr_list.append(x) else: raise Exception for x in attr_list: - # replace None => nan before writing # set array type to float for ii, item in enumerate(x): if isinstance(item, np.ndarray): - if item.dtype == np.dtype('O'): - item[item == None] = np.nan # noqa E711 + if item.dtype == np.dtype("O"): + item[item == None] = np.nan # noqa E711 x[ii] = np.array(item, dtype=float) if len(x) > 3: - f['/'.join(x[:-3])].attrs[x[-2]] = x[-1] + f["/".join(x[:-3])].attrs[x[-2]] = x[-1] else: assert len(x) == 3 @@ -495,11 +451,10 @@ def save_cell_index_receptive_field_analysis( file_handle.close() @staticmethod - def read_cell_index_receptive_field_analysis(file_handle, prefix, - path=None): - k = 'analysis/%s' % prefix + def read_cell_index_receptive_field_analysis(file_handle, prefix, path=None): + k = "analysis/%s" % prefix if k in file_handle: - f = file_handle['analysis/%s' % prefix] + f = file_handle["analysis/%s" % prefix] if path is None: rf = read_h5_group(f) else: diff --git a/allensdk/brain_observatory/multi_stimulus_running_speed/__main__.py b/allensdk/brain_observatory/multi_stimulus_running_speed/__main__.py index d62a1ca930..f747fa3b4c 100644 --- a/allensdk/brain_observatory/multi_stimulus_running_speed/__main__.py +++ b/allensdk/brain_observatory/multi_stimulus_running_speed/__main__.py @@ -1,7 +1,6 @@ -from allensdk.brain_observatory.\ - multi_stimulus_running_speed.multi_stimulus_running_speed import ( - MultiStimulusRunningSpeed - ) +from allensdk.brain_observatory.multi_stimulus_running_speed.multi_stimulus_running_speed import ( + MultiStimulusRunningSpeed, +) if __name__ == "__main__": multi_stimulus_running_speed = MultiStimulusRunningSpeed() diff --git a/allensdk/brain_observatory/multi_stimulus_running_speed/_schemas.py b/allensdk/brain_observatory/multi_stimulus_running_speed/_schemas.py index 3e041a1e85..b714bdca1b 100644 --- a/allensdk/brain_observatory/multi_stimulus_running_speed/_schemas.py +++ b/allensdk/brain_observatory/multi_stimulus_running_speed/_schemas.py @@ -3,10 +3,7 @@ class MultiStimulusRunningSpeedInputParameters(argschema.ArgSchema): - output_path = argschema.fields.OutputFile( - required=True, - description="The location to write the output file" - ) + output_path = argschema.fields.OutputFile(required=True, description="The location to write the output file") mapping_pkl_path = argschema.fields.InputFile( required=True, @@ -26,20 +23,15 @@ class MultiStimulusRunningSpeedInputParameters(argschema.ArgSchema): ) use_lowpass_filter = argschema.fields.Bool( - required=True, - default=True, - description=( - "apply a low pass filter to the running speed results" - ) - ) + required=True, default=True, description=("apply a low pass filter to the running speed results") + ) zscore_threshold = argschema.fields.Float( required=True, default=10.0, description=( - "The threshold to use for removing outlier " - "running speeds which might be noise and not true signal" - ) + "The threshold to use for removing outlier running speeds which might be noise and not true signal" + ), ) @@ -51,10 +43,5 @@ class MultiStimulusRunningSpeedOutputSchema(argschema.schemas.DefaultSchema): ) -class MultiStimulusRunningSpeedOutputParameters( - MultiStimulusRunningSpeedOutputSchema -): - output_path = argschema.fields.OutputFile( - required=True, - help="Filtered running speed hdf5 output file." - ) +class MultiStimulusRunningSpeedOutputParameters(MultiStimulusRunningSpeedOutputSchema): + output_path = argschema.fields.OutputFile(required=True, help="Filtered running speed hdf5 output file.") diff --git a/allensdk/brain_observatory/multi_stimulus_running_speed/multi_stimulus_running_speed.py b/allensdk/brain_observatory/multi_stimulus_running_speed/multi_stimulus_running_speed.py index 3da6a0c95d..7558b8d960 100644 --- a/allensdk/brain_observatory/multi_stimulus_running_speed/multi_stimulus_running_speed.py +++ b/allensdk/brain_observatory/multi_stimulus_running_speed/multi_stimulus_running_speed.py @@ -12,16 +12,17 @@ from allensdk.brain_observatory.behavior.data_files.stimulus_file import ( BehaviorStimulusFile, MappingStimulusFile, - ReplayStimulusFile) + ReplayStimulusFile, +) from allensdk.brain_observatory.multi_stimulus_running_speed._schemas import ( MultiStimulusRunningSpeedInputParameters, - MultiStimulusRunningSpeedOutputParameters + MultiStimulusRunningSpeedOutputParameters, ) -from allensdk.brain_observatory.behavior.data_objects.\ - running_speed.multi_stim_running_processing import ( - multi_stim_running_df_from_raw_data) +from allensdk.brain_observatory.behavior.data_objects.running_speed.multi_stim_running_processing import ( + multi_stim_running_df_from_raw_data, +) class MultiStimulusRunningSpeed(argschema.ArgSchemaParser): @@ -36,42 +37,34 @@ def _write_output_json(self): """ ouput_data = {} - ouput_data['output_path'] = self.args['output_path'] - ouput_data['input_parameters'] = self.args + ouput_data["output_path"] = self.args["output_path"] + ouput_data["input_parameters"] = self.args - with open(self.args['output_json'], 'w') as output_file: + with open(self.args["output_json"], "w") as output_file: json.dump(ouput_data, output_file, indent=2) - def process( - self - ): + def process(self): """ Process an experiment with a three stimulus sessions """ - bstim = BehaviorStimulusFile.from_json( - dict_repr={'behavior_stimulus_file': - self.args['behavior_pkl_path']}) - - mstim = MappingStimulusFile.from_json( - dict_repr={'mapping_stimulus_file': - self.args['mapping_pkl_path']}) - - rstim = ReplayStimulusFile.from_json( - dict_repr={'replay_stimulus_file': - self.args['replay_pkl_path']}) - - (velocities, - raw_data) = multi_stim_running_df_from_raw_data( - sync_path=self.args['sync_h5_path'], - behavior_stimulus_file=bstim, - mapping_stimulus_file=mstim, - replay_stimulus_file=rstim, - use_lowpass_filter=self.args['use_lowpass_filter'], - zscore_threshold=self.args['zscore_threshold'], - behavior_start_frame=MultiStimulusRunningSpeed.START_FRAME) - - store = pd.HDFStore(self.args['output_path']) + bstim = BehaviorStimulusFile.from_json(dict_repr={"behavior_stimulus_file": self.args["behavior_pkl_path"]}) + + mstim = MappingStimulusFile.from_json(dict_repr={"mapping_stimulus_file": self.args["mapping_pkl_path"]}) + + rstim = ReplayStimulusFile.from_json(dict_repr={"replay_stimulus_file": self.args["replay_pkl_path"]}) + + (velocities, raw_data) = multi_stim_running_df_from_raw_data( + sync_path=self.args["sync_h5_path"], + behavior_stimulus_file=bstim, + mapping_stimulus_file=mstim, + replay_stimulus_file=rstim, + use_lowpass_filter=self.args["use_lowpass_filter"], + zscore_threshold=self.args["zscore_threshold"], + behavior_start_frame=MultiStimulusRunningSpeed.START_FRAME, + ) + + store = pd.HDFStore(self.args["output_path"]) store.put("running_speed", velocities) store.put("raw_data", raw_data) store.close() diff --git a/allensdk/brain_observatory/natural_movie.py b/allensdk/brain_observatory/natural_movie.py index 100d498116..a8d25eff79 100644 --- a/allensdk/brain_observatory/natural_movie.py +++ b/allensdk/brain_observatory/natural_movie.py @@ -44,7 +44,7 @@ class NaturalMovie(StimulusAnalysis): - """ Perform tuning analysis specific to natural movie stimulus. + """Perform tuning analysis specific to natural movie stimulus. Parameters ---------- @@ -80,19 +80,18 @@ def sweep_response(self): def populate_stimulus_table(self): stimulus_table = self.data_set.get_stimulus_table(self.movie_name) self._stim_table = stimulus_table[stimulus_table.frame == 0] - self._sweeplength = \ - self.stim_table.start.iloc[1] - self.stim_table.start.iloc[0] + self._sweeplength = self.stim_table.start.iloc[1] - self.stim_table.start.iloc[0] def get_sweep_response(self): - ''' Returns the dF/F response for each cell + """Returns the dF/F response for each cell Returns ------- Numpy array - ''' - sweep_response = pd.DataFrame(index=self.stim_table.index.values, - columns=np.array( - range(self.numbercells)).astype(str)) + """ + sweep_response = pd.DataFrame( + index=self.stim_table.index.values, columns=np.array(range(self.numbercells)).astype(str) + ) for index, row in self.stim_table.iterrows(): start = row.start end = start + self.sweeplength @@ -101,7 +100,7 @@ def get_sweep_response(self): return sweep_response def get_peak(self): - ''' Computes properties of the peak response condition for each cell. + """Computes properties of the peak response condition for each cell. Returns ------- @@ -110,9 +109,10 @@ def get_peak(self): on which of three movie clips was presented. * peak_nm1 (frame with peak response) * response_variability_nm1 - ''' - peak_movie = pd.DataFrame(index=range(self.numbercells), columns=( - 'peak', 'response_reliability', 'cell_specimen_id')) + """ + peak_movie = pd.DataFrame( + index=range(self.numbercells), columns=("peak", "response_reliability", "cell_specimen_id") + ) cids = self.data_set.get_cell_specimen_ids() mask = np.ones((10, 10)) @@ -150,8 +150,7 @@ def get_peak(self): corr_matrix = np.empty((10, 10)) for i in range(10): for j in range(10): - r, p = st.pearsonr(self.sweep_response[str(nc)].iloc[i], - self.sweep_response[str(nc)].iloc[j]) + r, p = st.pearsonr(self.sweep_response[str(nc)].iloc[i], self.sweep_response[str(nc)].iloc[j]) corr_matrix[i, j] = r corr_matrix *= mask peak_movie.response_reliability.iloc[nc] = np.nanmean(corr_matrix) @@ -159,24 +158,27 @@ def get_peak(self): if self.movie_name == stiminfo.NATURAL_MOVIE_ONE: peak_movie.rename( columns={ - 'peak': 'peak_' + stiminfo.NATURAL_MOVIE_ONE_SHORT, - 'response_reliability': 'response_reliability_' + - stiminfo.NATURAL_MOVIE_ONE_SHORT}, - inplace=True) + "peak": "peak_" + stiminfo.NATURAL_MOVIE_ONE_SHORT, + "response_reliability": "response_reliability_" + stiminfo.NATURAL_MOVIE_ONE_SHORT, + }, + inplace=True, + ) elif self.movie_name == stiminfo.NATURAL_MOVIE_TWO: peak_movie.rename( columns={ - 'peak': 'peak_' + stiminfo.NATURAL_MOVIE_TWO_SHORT, - 'response_reliability': 'response_reliability_' + - stiminfo.NATURAL_MOVIE_TWO_SHORT}, - inplace=True) + "peak": "peak_" + stiminfo.NATURAL_MOVIE_TWO_SHORT, + "response_reliability": "response_reliability_" + stiminfo.NATURAL_MOVIE_TWO_SHORT, + }, + inplace=True, + ) elif self.movie_name == stiminfo.NATURAL_MOVIE_THREE: peak_movie.rename( columns={ - 'peak': 'peak_' + stiminfo.NATURAL_MOVIE_THREE_SHORT, - 'response_reliability': 'response_reliability_' + - stiminfo.NATURAL_MOVIE_THREE_SHORT - }, inplace=True) + "peak": "peak_" + stiminfo.NATURAL_MOVIE_THREE_SHORT, + "response_reliability": "response_reliability_" + stiminfo.NATURAL_MOVIE_THREE_SHORT, + }, + inplace=True, + ) return peak_movie @@ -191,8 +193,7 @@ def open_track_plot(self, cell_specimen_id=None, cell_index=None): data = np.vstack(data) tp = cplots.TrackPlotter(ring_length=360) - tp.plot(data, - clim=[0, data.mean() + data.std() * 3]) + tp.plot(data, clim=[0, data.mean() + data.std() * 3]) tp.show_arrow() @staticmethod @@ -202,18 +203,15 @@ def from_analysis_file(data_set, analysis_file, movie_name): # TODO: deal with this properly suffix_map = { - stiminfo.NATURAL_MOVIE_ONE: '_' + stiminfo.NATURAL_MOVIE_ONE_SHORT, - stiminfo.NATURAL_MOVIE_TWO: '_' + stiminfo.NATURAL_MOVIE_TWO_SHORT, - stiminfo.NATURAL_MOVIE_THREE: '_' + - stiminfo.NATURAL_MOVIE_THREE_SHORT + stiminfo.NATURAL_MOVIE_ONE: "_" + stiminfo.NATURAL_MOVIE_ONE_SHORT, + stiminfo.NATURAL_MOVIE_TWO: "_" + stiminfo.NATURAL_MOVIE_TWO_SHORT, + stiminfo.NATURAL_MOVIE_THREE: "_" + stiminfo.NATURAL_MOVIE_THREE_SHORT, } try: suffix = suffix_map[movie_name] - nm._sweep_response = pd.read_hdf(analysis_file, - "analysis/sweep_response" + - suffix) + nm._sweep_response = pd.read_hdf(analysis_file, "analysis/sweep_response" + suffix) nm._peak = pd.read_hdf(analysis_file, "analysis/peak") with h5py.File(analysis_file, "r") as f: diff --git a/allensdk/brain_observatory/natural_scenes.py b/allensdk/brain_observatory/natural_scenes.py index 0d7164303c..b2c5a4f86e 100644 --- a/allensdk/brain_observatory/natural_scenes.py +++ b/allensdk/brain_observatory/natural_scenes.py @@ -45,14 +45,14 @@ class NaturalScenes(StimulusAnalysis): - """ Perform tuning analysis specific to natural scenes stimulus. + """Perform tuning analysis specific to natural scenes stimulus. Parameters ---------- data_set: BrainObservatoryNwbDataSet object """ - _log = logging.getLogger('allensdk.brain_observatory.natural_scenes') + _log = logging.getLogger("allensdk.brain_observatory.natural_scenes") def __init__(self, data_set, **kwargs): super(NaturalScenes, self).__init__(data_set, **kwargs) @@ -91,16 +91,14 @@ def extralength(self): return self._extralength def populate_stimulus_table(self): - self._stim_table = self.data_set.get_stimulus_table('natural_scenes') + self._stim_table = self.data_set.get_stimulus_table("natural_scenes") self._number_scenes = len(np.unique(self._stim_table.frame)) - self._sweeplength = ( - self._stim_table.end.iloc[1] - - self._stim_table.start.iloc[1]) + self._sweeplength = self._stim_table.end.iloc[1] - self._stim_table.start.iloc[1] self._interlength = 4 * self._sweeplength self._extralength = self._sweeplength def get_response(self): - ''' Computes the mean response for each cell to each stimulus + """Computes the mean response for each cell to each stimulus condition. Return is a (# scenes, # cells, 3) np.ndarray. The final dimension contains the mean response to the condition (index 0), standard @@ -112,7 +110,7 @@ def get_response(self): Returns ------- Numpy array storing mean responses. - ''' + """ NaturalScenes._log.info("Calculating mean responses") response = np.empty((self.number_scenes, self.numbercells + 1, 3)) @@ -121,18 +119,16 @@ def ptest(x): return len(np.where(x < (0.05 / (self.number_scenes - 1)))[0]) for ns in range(self.number_scenes): - subset_response = self.mean_sweep_response[ - self.stim_table.frame == (ns - 1)] + subset_response = self.mean_sweep_response[self.stim_table.frame == (ns - 1)] subset_pval = self.pval[self.stim_table.frame == (ns - 1)] response[ns, :, 0] = subset_response.mean(axis=0) - response[ns, :, 1] = subset_response.std( - axis=0) / np.sqrt(len(subset_response)) + response[ns, :, 1] = subset_response.std(axis=0) / np.sqrt(len(subset_response)) response[ns, :, 2] = subset_pval.apply(ptest, axis=0) return response def get_peak(self): - ''' Computes metrics about peak response condition for each cell. + """Computes metrics about peak response condition for each cell. Returns ------- @@ -145,13 +141,22 @@ def get_peak(self): * p_run_ns * run_modulation_ns * time_to_peak_ns - ''' - NaturalScenes._log.info('Calculating peak response properties') - peak = pd.DataFrame(index=range(self.numbercells), columns=( - 'scene_ns', 'reliability_ns', 'peak_dff_ns', - 'ptest_ns', 'p_run_ns', 'run_modulation_ns', - 'time_to_peak_ns', - 'cell_specimen_id', 'image_selectivity_ns')) + """ + NaturalScenes._log.info("Calculating peak response properties") + peak = pd.DataFrame( + index=range(self.numbercells), + columns=( + "scene_ns", + "reliability_ns", + "peak_dff_ns", + "ptest_ns", + "p_run_ns", + "run_modulation_ns", + "time_to_peak_ns", + "cell_specimen_id", + "image_selectivity_ns", + ), + ) cids = self.data_set.get_cell_specimen_ids() for nc in range(self.numbercells): @@ -175,37 +180,27 @@ def get_peak(self): # peak.run_modulation_ns[nc] = np.nan groups = [] for im in range(self.number_scenes): - subset = self.mean_sweep_response[ - self.stim_table.frame == (im - 1)] + subset = self.mean_sweep_response[self.stim_table.frame == (im - 1)] groups.append(subset[str(nc)].values) (_, peak.ptest_ns[nc]) = st.f_oneway(*groups) - test = self.sweep_response[ - self.stim_table.frame == nsp][str(nc)].mean() - peak.time_to_peak_ns[nc] = \ - (np.argmax(test) - self.interlength) / self.acquisition_rate + test = self.sweep_response[self.stim_table.frame == nsp][str(nc)].mean() + peak.time_to_peak_ns[nc] = (np.argmax(test) - self.interlength) / self.acquisition_rate # running modulation subset = self.mean_sweep_response[self.stim_table.frame == nsp] subset_run = subset[subset.dx >= 1] subset_stat = subset[subset.dx < 1] if (len(subset_run) > 4) & (len(subset_stat) > 4): - (_, peak.p_run_ns.iloc[nc]) = st.ttest_ind(subset_run[str(nc)], - subset_stat[ - str(nc)], - equal_var=False) + (_, peak.p_run_ns.iloc[nc]) = st.ttest_ind(subset_run[str(nc)], subset_stat[str(nc)], equal_var=False) if subset_run[str(nc)].mean() > subset_stat[str(nc)].mean(): - peak.run_modulation_ns.iloc[nc] = (subset_run[ - str(nc)].mean() - - subset_stat[ - str(nc)].mean()) \ - / np.abs( - subset_run[str(nc)].mean()) + peak.run_modulation_ns.iloc[nc] = ( + subset_run[str(nc)].mean() - subset_stat[str(nc)].mean() + ) / np.abs(subset_run[str(nc)].mean()) elif subset_run[str(nc)].mean() < subset_stat[str(nc)].mean(): - peak.run_modulation_ns.iloc[nc] = \ - (-1 * ((subset_stat[str(nc)].mean() - - subset_run[str(nc)].mean()) / - np.abs(subset_stat[str(nc)].mean()))) + peak.run_modulation_ns.iloc[nc] = -1 * ( + (subset_stat[str(nc)].mean() - subset_run[str(nc)].mean()) / np.abs(subset_stat[str(nc)].mean()) + ) else: peak.p_run_ns.iloc[nc] = np.nan peak.run_modulation_ns.iloc[nc] = np.nan @@ -215,8 +210,7 @@ def get_peak(self): corr_matrix = np.empty((len(subset), len(subset))) for i in range(len(subset)): for j in range(len(subset)): - r, p = st.pearsonr(subset[str(nc)].iloc[i][28:42], - subset[str(nc)].iloc[j][28:42]) + r, p = st.pearsonr(subset[str(nc)].iloc[i][28:42], subset[str(nc)].iloc[j][28:42]) corr_matrix[i, j] = r mask = np.ones((len(subset), len(subset))) for i in range(len(subset)): @@ -231,7 +225,7 @@ def get_peak(self): fmax = self.response[1:, nc, 0].max() rtj = np.empty((1000, 1)) for j in range(1000): - thresh = fmin + j * ((fmax - fmin) / 1000.) + thresh = fmin + j * ((fmax - fmin) / 1000.0) theta = np.empty((118, 1)) for im in range(118): # im+1 to only look at @@ -248,29 +242,22 @@ def get_peak(self): return peak - def plot_time_to_peak(self, - p_value_max=oplots.P_VALUE_MAX, - color_map=oplots.STIMULUS_COLOR_MAP): - stimulus_table = self.data_set.get_stimulus_table('natural_scenes') + def plot_time_to_peak(self, p_value_max=oplots.P_VALUE_MAX, color_map=oplots.STIMULUS_COLOR_MAP): + stimulus_table = self.data_set.get_stimulus_table("natural_scenes") resps = [] for index, row in self.peak.iterrows(): - mean_response = \ - self.sweep_response.loc[stimulus_table.frame == row.scene_ns][ - str(index)].mean() - resps.append( - (mean_response - mean_response.mean() / mean_response.std())) + mean_response = self.sweep_response.loc[stimulus_table.frame == row.scene_ns][str(index)].mean() + resps.append((mean_response - mean_response.mean() / mean_response.std())) mean_responses = np.array(resps) - sorted_table = self.peak[self.peak.ptest_ns < p_value_max].sort_values( - 'time_to_peak_ns') + sorted_table = self.peak[self.peak.ptest_ns < p_value_max].sort_values("time_to_peak_ns") cell_order = sorted_table.index # time to peak is relative to stimulus start in seconds - ttps = sorted_table.time_to_peak_ns.values + self.interlength / \ - self.acquisition_rate + ttps = sorted_table.time_to_peak_ns.values + self.interlength / self.acquisition_rate msrs_sorted = mean_responses[cell_order, :] oplots.plot_time_to_peak( @@ -280,7 +267,8 @@ def plot_time_to_peak(self, (2 * self.interlength + self.sweeplength) / self.acquisition_rate, self.interlength / self.acquisition_rate, (self.interlength + self.sweeplength) / self.acquisition_rate, - color_map) + color_map, + ) def open_corona_plot(self, cell_specimen_id=None, cell_index=None): cell_index = self.row_from_cell_id(cell_specimen_id, cell_index) @@ -288,39 +276,34 @@ def open_corona_plot(self, cell_specimen_id=None, cell_index=None): df = self.mean_sweep_response[str(cell_index)] data = df.values - st = self.data_set.get_stimulus_table('natural_scenes') + st = self.data_set.get_stimulus_table("natural_scenes") mask = st[st.frame >= 0].index cmin = self.response[0, cell_index, 0] cmax = max(cmin, data.mean() + data.std() * 3) cp = cplots.CoronaPlotter() - cp.plot(st.frame.loc[mask].values, - data=df.loc[mask].values, - clim=[cmin, cmax]) + cp.plot(st.frame.loc[mask].values, data=df.loc[mask].values, clim=[cmin, cmax]) cp.show_arrow() cp.show_circle() def reshape_response_array(self): - ''' + """ :return: response array in cells x stim x repetition for noise correlations - ''' + """ - mean_sweep_response = \ - self.mean_sweep_response.values[:, :self.numbercells] + mean_sweep_response = self.mean_sweep_response.values[:, : self.numbercells] stim_table = self.stim_table frames = np.unique(stim_table.frame.values) - reps = [len(np.where(stim_table.frame.values == frame)[0]) for frame in - frames] + reps = [len(np.where(stim_table.frame.values == frame)[0]) for frame in frames] # just in case there are different numbers of repetitions Nreps = min(reps) - response_new = np.zeros((self.numbercells, self.number_scenes), - dtype='object') + response_new = np.zeros((self.numbercells, self.number_scenes), dtype="object") for i, frame in enumerate(frames): ind = np.where(stim_table.frame.values == frame)[0][:Nreps] for c in range(self.numbercells): @@ -328,112 +311,90 @@ def reshape_response_array(self): return response_new - def get_signal_correlation(self, corr='spearman'): + def get_signal_correlation(self, corr="spearman"): logging.debug("Calculating signal correlations") response = self.response[:, :, 0].T - response = response[:self.numbercells, :] + response = response[: self.numbercells, :] N, Nstim = response.shape signal_corr = np.zeros((N, N)) signal_p = np.empty((N, N)) - if corr == 'pearson': + if corr == "pearson": for i in range(N): for j in range(i, N): # matrix is symmetric - signal_corr[i, j], signal_p[i, j] = st.pearsonr( - response[i], response[j]) + signal_corr[i, j], signal_p[i, j] = st.pearsonr(response[i], response[j]) - elif corr == 'spearman': + elif corr == "spearman": for i in range(N): for j in range(i, N): # matrix is symmetric - signal_corr[i, j], signal_p[i, j] = st.spearmanr( - response[i], response[j]) + signal_corr[i, j], signal_p[i, j] = st.spearmanr(response[i], response[j]) else: - raise Exception('correlation should be pearson or spearman') + raise Exception("correlation should be pearson or spearman") # fill in lower triangle - signal_corr = ( - np.triu(signal_corr) + - np.triu(signal_corr, 1).T) + signal_corr = np.triu(signal_corr) + np.triu(signal_corr, 1).T # fill in lower triangle - signal_p = ( - np.triu(signal_p) + - np.triu(signal_p, 1).T) + signal_p = np.triu(signal_p) + np.triu(signal_p, 1).T return signal_corr, signal_p - def get_representational_similarity(self, corr='spearman'): + def get_representational_similarity(self, corr="spearman"): logging.debug("Calculating representational similarity") response = self.response[:, :, 0] - response = response[:, :self.numbercells] + response = response[:, : self.numbercells] Nstim, N = response.shape rep_sim = np.zeros((Nstim, Nstim)) rep_sim_p = np.empty((Nstim, Nstim)) - if corr == 'pearson': + if corr == "pearson": for i in range(Nstim): for j in range(i, Nstim): # matrix is symmetric - rep_sim[i, j], rep_sim_p[i, j] = st.pearsonr(response[i], - response[j]) + rep_sim[i, j], rep_sim_p[i, j] = st.pearsonr(response[i], response[j]) - elif corr == 'spearman': + elif corr == "spearman": for i in range(Nstim): for j in range(i, Nstim): # matrix is symmetric - rep_sim[i, j], rep_sim_p[i, j] = st.spearmanr(response[i], - response[j]) + rep_sim[i, j], rep_sim_p[i, j] = st.spearmanr(response[i], response[j]) else: - raise Exception('correlation should be pearson or spearman') + raise Exception("correlation should be pearson or spearman") - rep_sim = np.triu(rep_sim) + np.triu(rep_sim, - 1).T # fill in lower triangle - rep_sim_p = np.triu(rep_sim_p) + np.triu(rep_sim_p, - 1).T # fill in lower triangle + rep_sim = np.triu(rep_sim) + np.triu(rep_sim, 1).T # fill in lower triangle + rep_sim_p = np.triu(rep_sim_p) + np.triu(rep_sim_p, 1).T # fill in lower triangle return rep_sim, rep_sim_p - def get_noise_correlation(self, corr='spearman'): + def get_noise_correlation(self, corr="spearman"): logging.debug("Calculating noise correlations") response = self.reshape_response_array() - noise_corr = np.zeros( - (self.numbercells, self.numbercells, self.number_scenes)) - noise_corr_p = np.zeros( - (self.numbercells, self.numbercells, self.number_scenes)) + noise_corr = np.zeros((self.numbercells, self.numbercells, self.number_scenes)) + noise_corr_p = np.zeros((self.numbercells, self.numbercells, self.number_scenes)) - if corr == 'pearson': + if corr == "pearson": for k in range(self.number_scenes): for i in range(self.numbercells): for j in range(i, self.numbercells): - noise_corr[i, j, k], noise_corr_p[ - i, j, k] = st.pearsonr(response[i, k], - response[j, k]) + noise_corr[i, j, k], noise_corr_p[i, j, k] = st.pearsonr(response[i, k], response[j, k]) - noise_corr[:, :, k] = np.triu(noise_corr[:, :, k]) + np.triu( - noise_corr[:, :, k], 1).T - noise_corr_p[:, :, k] = np.triu( - noise_corr_p[:, :, k]) + np.triu(noise_corr_p[:, :, k], - 1).T + noise_corr[:, :, k] = np.triu(noise_corr[:, :, k]) + np.triu(noise_corr[:, :, k], 1).T + noise_corr_p[:, :, k] = np.triu(noise_corr_p[:, :, k]) + np.triu(noise_corr_p[:, :, k], 1).T - elif corr == 'spearman': + elif corr == "spearman": for k in range(self.number_scenes): for i in range(self.numbercells): for j in range(i, self.numbercells): - noise_corr[i, j, k], noise_corr_p[ - i, j, k] = st.spearmanr(response[i, k], - response[j, k]) + noise_corr[i, j, k], noise_corr_p[i, j, k] = st.spearmanr(response[i, k], response[j, k]) - noise_corr[:, :, k] = np.triu(noise_corr[:, :, k]) + np.triu( - noise_corr[:, :, k], 1).T - noise_corr_p[:, :, k] = np.triu( - noise_corr_p[:, :, k]) + np.triu(noise_corr_p[:, :, k], - 1).T + noise_corr[:, :, k] = np.triu(noise_corr[:, :, k]) + np.triu(noise_corr[:, :, k], 1).T + noise_corr_p[:, :, k] = np.triu(noise_corr_p[:, :, k]) + np.triu(noise_corr_p[:, :, k], 1).T else: - raise Exception('correlation should be pearson or spearman') + raise Exception("correlation should be pearson or spearman") return noise_corr, noise_corr_p @@ -443,10 +404,8 @@ def from_analysis_file(data_set, analysis_file): ns.populate_stimulus_table() try: - ns._sweep_response = pd.read_hdf(analysis_file, - "analysis/sweep_response_ns") - ns._mean_sweep_response = pd.read_hdf( - analysis_file, "analysis/mean_sweep_response_ns") + ns._sweep_response = pd.read_hdf(analysis_file, "analysis/sweep_response_ns") + ns._mean_sweep_response = pd.read_hdf(analysis_file, "analysis/mean_sweep_response_ns") ns._peak = pd.read_hdf(analysis_file, "analysis/peak") with h5py.File(analysis_file, "r") as f: @@ -461,8 +420,7 @@ def from_analysis_file(data_set, analysis_file): if "analysis/signal_corr_ns" in f: ns.signal_correlation = f["analysis/signal_corr_ns"][()] if "analysis/rep_similarity_ns" in f: - ns.representational_similarity = f[ - "analysis/rep_similarity_ns"][()] + ns.representational_similarity = f["analysis/rep_similarity_ns"][()] except Exception as e: raise MissingStimulusException(e.args) diff --git a/allensdk/brain_observatory/nwb/__init__.py b/allensdk/brain_observatory/nwb/__init__.py index f9576d5a58..98a8bce10e 100644 --- a/allensdk/brain_observatory/nwb/__init__.py +++ b/allensdk/brain_observatory/nwb/__init__.py @@ -14,16 +14,18 @@ from pynwb.base import TimeSeries, Images from pynwb import ProcessingModule, NWBFile from pynwb.image import GrayscaleImage -from pynwb.ophys import ( - DfOverF, ImageSegmentation, OpticalChannel, Fluorescence) +from pynwb.ophys import DfOverF, ImageSegmentation, OpticalChannel, Fluorescence from allensdk.brain_observatory import dict_to_indexed_array from allensdk.brain_observatory.behavior.image_api import Image from allensdk.brain_observatory.behavior.image_api import ImageApi from allensdk.brain_observatory.behavior.schemas import ( - CompleteOphysBehaviorMetadataSchema, NwbOphysMetadataSchema, - BehaviorMetadataSchema, OphysBehaviorMetadataSchema, - BehaviorTaskParametersSchema, SubjectMetadataSchema + CompleteOphysBehaviorMetadataSchema, + NwbOphysMetadataSchema, + BehaviorMetadataSchema, + OphysBehaviorMetadataSchema, + BehaviorTaskParametersSchema, + SubjectMetadataSchema, ) from allensdk.brain_observatory.nwb.metadata import load_pynwb_extension @@ -31,28 +33,23 @@ log = logging.getLogger("allensdk.brain_observatory.nwb") CELL_SPECIMEN_COL_DESCRIPTIONS = { - 'cell_specimen_id': 'Unified id of segmented cell across experiments ' - '(after cell matching)', - 'height': 'Height of ROI in pixels', - 'width': 'Width of ROI in pixels', - 'mask_image_plane': 'Which image plane an ROI resides on. Overlapping ' - 'ROIs are stored on different mask image planes.', - 'max_correction_down': 'Max motion correction in down direction in pixels', - 'max_correction_left': 'Max motion correction in left direction in pixels', - 'max_correction_up': 'Max motion correction in up direction in pixels', - 'max_correction_right': 'Max motion correction in right direction in ' - 'pixels', - 'valid_roi': 'Indicates if cell classification found the ROI to be a cell ' - 'or not', - 'x': 'x position of ROI in Image Plane in pixels (top left corner)', - 'y': 'y position of ROI in Image Plane in pixels (top left corner)' + "cell_specimen_id": "Unified id of segmented cell across experiments (after cell matching)", + "height": "Height of ROI in pixels", + "width": "Width of ROI in pixels", + "mask_image_plane": "Which image plane an ROI resides on. Overlapping " + "ROIs are stored on different mask image planes.", + "max_correction_down": "Max motion correction in down direction in pixels", + "max_correction_left": "Max motion correction in left direction in pixels", + "max_correction_up": "Max motion correction in up direction in pixels", + "max_correction_right": "Max motion correction in right direction in pixels", + "valid_roi": "Indicates if cell classification found the ROI to be a cell or not", + "x": "x position of ROI in Image Plane in pixels (top left corner)", + "y": "y position of ROI in Image Plane in pixels (top left corner)", } -def check_nwbfile_version(nwbfile_path: str, - desired_minimum_version: str, - warning_msg: str): - with h5py.File(nwbfile_path, 'r') as f: +def check_nwbfile_version(nwbfile_path: str, desired_minimum_version: str, warning_msg: str): + with h5py.File(nwbfile_path, "r") as f: # nwb 2.x files store version as an attribute try: nwb_version = str(f.attrs["nwb_version"]).split(".") @@ -66,10 +63,12 @@ def check_nwbfile_version(nwbfile_path: str, nwb_version = None if nwb_version is None: - warnings.warn(f"'{nwbfile_path}' doesn't appear to be a valid " - f"Neurodata Without Borders (*.nwb) format file as " - f"neither a 'nwb_version' field nor dataset could " - f"be found!") + warnings.warn( + f"'{nwbfile_path}' doesn't appear to be a valid " + f"Neurodata Without Borders (*.nwb) format file as " + f"neither a 'nwb_version' field nor dataset could " + f"be found!" + ) else: if tuple(nwb_version) < tuple(desired_minimum_version.split(".")): warnings.warn(warning_msg) @@ -123,110 +122,94 @@ def read_eye_gaze_mappings(input_path: Path) -> dict: """ eye_gaze_data = {} - eye_gaze_data["raw_eye_areas"] = \ - pd.read_hdf(input_path, key="raw_eye_areas") - eye_gaze_data["raw_pupil_areas"] = \ - pd.read_hdf(input_path, key="raw_pupil_areas") - eye_gaze_data["raw_screen_coordinates"] = \ - pd.read_hdf(input_path, key="raw_screen_coordinates") - eye_gaze_data["raw_screen_coordinates_spherical"] = \ - pd.read_hdf(input_path, key="raw_screen_coordinates_spherical") - eye_gaze_data["new_eye_areas"] = \ - pd.read_hdf(input_path, key="new_eye_areas") - eye_gaze_data["new_pupil_areas"] = \ - pd.read_hdf(input_path, key="new_pupil_areas") - eye_gaze_data["new_screen_coordinates"] = \ - pd.read_hdf(input_path, key="new_screen_coordinates") - eye_gaze_data["new_screen_coordinates_spherical"] = \ - pd.read_hdf(input_path, key="new_screen_coordinates_spherical") - eye_gaze_data["synced_frame_timestamps"] = \ - pd.read_hdf(input_path, key="synced_frame_timestamps") + eye_gaze_data["raw_eye_areas"] = pd.read_hdf(input_path, key="raw_eye_areas") + eye_gaze_data["raw_pupil_areas"] = pd.read_hdf(input_path, key="raw_pupil_areas") + eye_gaze_data["raw_screen_coordinates"] = pd.read_hdf(input_path, key="raw_screen_coordinates") + eye_gaze_data["raw_screen_coordinates_spherical"] = pd.read_hdf(input_path, key="raw_screen_coordinates_spherical") + eye_gaze_data["new_eye_areas"] = pd.read_hdf(input_path, key="new_eye_areas") + eye_gaze_data["new_pupil_areas"] = pd.read_hdf(input_path, key="new_pupil_areas") + eye_gaze_data["new_screen_coordinates"] = pd.read_hdf(input_path, key="new_screen_coordinates") + eye_gaze_data["new_screen_coordinates_spherical"] = pd.read_hdf(input_path, key="new_screen_coordinates_spherical") + eye_gaze_data["synced_frame_timestamps"] = pd.read_hdf(input_path, key="synced_frame_timestamps") return eye_gaze_data def create_eye_gaze_mapping_dataframe(eye_gaze_data: dict) -> pd.DataFrame: - - eye_gaze_mapping_df = pd.DataFrame({ - "raw_eye_area": eye_gaze_data["raw_eye_areas"].values, - "raw_pupil_area": eye_gaze_data["raw_pupil_areas"].values, - "raw_screen_coordinates_x_cm": - eye_gaze_data["raw_screen_coordinates"]["x_pos_cm"].values, - "raw_screen_coordinates_y_cm": - eye_gaze_data["raw_screen_coordinates"]["y_pos_cm"].values, - "raw_screen_coordinates_spherical_x_deg": - eye_gaze_data["raw_screen_coordinates_spherical"]["x_pos_deg"].values, - "raw_screen_coordinates_spherical_y_deg": - eye_gaze_data["raw_screen_coordinates_spherical"]["y_pos_deg"].values, - "filtered_eye_area": eye_gaze_data["new_eye_areas"].values, - "filtered_pupil_area": eye_gaze_data["new_pupil_areas"].values, - "filtered_screen_coordinates_x_cm": - eye_gaze_data["new_screen_coordinates"]["x_pos_cm"].values, - "filtered_screen_coordinates_y_cm": - eye_gaze_data["new_screen_coordinates"]["y_pos_cm"].values, - "filtered_screen_coordinates_spherical_x_deg": - eye_gaze_data["new_screen_coordinates_spherical"]["x_pos_deg"].values, - "filtered_screen_coordinates_spherical_y_deg": - eye_gaze_data["new_screen_coordinates_spherical"]["y_pos_deg"].values + eye_gaze_mapping_df = pd.DataFrame( + { + "raw_eye_area": eye_gaze_data["raw_eye_areas"].values, + "raw_pupil_area": eye_gaze_data["raw_pupil_areas"].values, + "raw_screen_coordinates_x_cm": eye_gaze_data["raw_screen_coordinates"]["x_pos_cm"].values, + "raw_screen_coordinates_y_cm": eye_gaze_data["raw_screen_coordinates"]["y_pos_cm"].values, + "raw_screen_coordinates_spherical_x_deg": eye_gaze_data["raw_screen_coordinates_spherical"][ + "x_pos_deg" + ].values, + "raw_screen_coordinates_spherical_y_deg": eye_gaze_data["raw_screen_coordinates_spherical"][ + "y_pos_deg" + ].values, + "filtered_eye_area": eye_gaze_data["new_eye_areas"].values, + "filtered_pupil_area": eye_gaze_data["new_pupil_areas"].values, + "filtered_screen_coordinates_x_cm": eye_gaze_data["new_screen_coordinates"]["x_pos_cm"].values, + "filtered_screen_coordinates_y_cm": eye_gaze_data["new_screen_coordinates"]["y_pos_cm"].values, + "filtered_screen_coordinates_spherical_x_deg": eye_gaze_data["new_screen_coordinates_spherical"][ + "x_pos_deg" + ].values, + "filtered_screen_coordinates_spherical_y_deg": eye_gaze_data["new_screen_coordinates_spherical"][ + "y_pos_deg" + ].values, }, - index=eye_gaze_data["synced_frame_timestamps"].values + index=eye_gaze_data["synced_frame_timestamps"].values, ) return eye_gaze_mapping_df -def eye_tracking_data_is_valid(eye_dlc_tracking_data: dict, - synced_timestamps: pd.Series) -> bool: +def eye_tracking_data_is_valid(eye_dlc_tracking_data: dict, synced_timestamps: pd.Series) -> bool: is_valid = True pupil_params = eye_dlc_tracking_data["pupil_params"] cr_params = eye_dlc_tracking_data["cr_params"] eye_params = eye_dlc_tracking_data["eye_params"] - num_frames_match = ((pupil_params.shape[0] == cr_params.shape[0]) - and (cr_params.shape[0] == eye_params.shape[0])) + num_frames_match = (pupil_params.shape[0] == cr_params.shape[0]) and (cr_params.shape[0] == eye_params.shape[0]) if not num_frames_match: - log.warn("The number of frames for ellipse fits don't " - "match when they should. No ellipse fits will be written! " - f"pupil_params ({pupil_params.shape[0]}), " - f"cr_params ({cr_params.shape[0]}), " - f"eye_params ({eye_params.shape[0]})") + log.warn( + "The number of frames for ellipse fits don't " + "match when they should. No ellipse fits will be written! " + f"pupil_params ({pupil_params.shape[0]}), " + f"cr_params ({cr_params.shape[0]}), " + f"eye_params ({eye_params.shape[0]})" + ) is_valid = False - if (pupil_params.shape[0] != len(synced_timestamps)): - log.warn("The number of camera sync pulses in the " - f"sync file ({len(synced_timestamps)}) do not match " - "with the number of eye tracking frames " - f"({pupil_params.shape[0]})! No ellipse fits will be " - "written!") + if pupil_params.shape[0] != len(synced_timestamps): + log.warn( + "The number of camera sync pulses in the " + f"sync file ({len(synced_timestamps)}) do not match " + "with the number of eye tracking frames " + f"({pupil_params.shape[0]})! No ellipse fits will be " + "written!" + ) is_valid = False return is_valid -def create_eye_tracking_nwb_processing_module(eye_dlc_tracking_data: dict, - synced_timestamps: pd.Series - ) -> pynwb.ProcessingModule: - +def create_eye_tracking_nwb_processing_module( + eye_dlc_tracking_data: dict, synced_timestamps: pd.Series +) -> pynwb.ProcessingModule: # Top level container for eye tracking processed data - eye_tracking_mod = pynwb.ProcessingModule( - name='eye_tracking', - description='Eye tracking processing module') + eye_tracking_mod = pynwb.ProcessingModule(name="eye_tracking", description="Eye tracking processing module") # Data interfaces of dlc_fits_container - pupil_fits = eye_dlc_tracking_data["pupil_params"].assign( - timestamps=synced_timestamps) - pupil_params = pynwb.core.DynamicTable.from_dataframe( - df=pupil_fits, name="pupil_ellipse_fits") + pupil_fits = eye_dlc_tracking_data["pupil_params"].assign(timestamps=synced_timestamps) + pupil_params = pynwb.core.DynamicTable.from_dataframe(df=pupil_fits, name="pupil_ellipse_fits") - cr_fits = eye_dlc_tracking_data["cr_params"].assign( - timestamps=synced_timestamps) - cr_params = pynwb.core.DynamicTable.from_dataframe(df=cr_fits, - name="cr_ellipse_fits") + cr_fits = eye_dlc_tracking_data["cr_params"].assign(timestamps=synced_timestamps) + cr_params = pynwb.core.DynamicTable.from_dataframe(df=cr_fits, name="cr_ellipse_fits") - eye_fits = eye_dlc_tracking_data["eye_params"].assign( - timestamps=synced_timestamps) - eye_params = pynwb.core.DynamicTable.from_dataframe( - df=eye_fits, name="eye_ellipse_fits") + eye_fits = eye_dlc_tracking_data["eye_params"].assign(timestamps=synced_timestamps) + eye_params = pynwb.core.DynamicTable.from_dataframe(df=eye_fits, name="eye_ellipse_fits") eye_tracking_mod.add_data_interface(pupil_params) eye_tracking_mod.add_data_interface(cr_params) @@ -235,40 +218,34 @@ def create_eye_tracking_nwb_processing_module(eye_dlc_tracking_data: dict, return eye_tracking_mod -def add_eye_gaze_data_interfaces(pynwb_container: pynwb.NWBContainer, - pupil_areas: pd.Series, - eye_areas: pd.Series, - screen_coordinates: pd.DataFrame, - screen_coordinates_spherical: pd.DataFrame, - synced_timestamps: pd.Series - ) -> pynwb.NWBContainer: - +def add_eye_gaze_data_interfaces( + pynwb_container: pynwb.NWBContainer, + pupil_areas: pd.Series, + eye_areas: pd.Series, + screen_coordinates: pd.DataFrame, + screen_coordinates_spherical: pd.DataFrame, + synced_timestamps: pd.Series, +) -> pynwb.NWBContainer: pupil_area_ts = pynwb.base.TimeSeries( - name="pupil_area", - data=pupil_areas.values, - timestamps=synced_timestamps.values, - unit="Pixels ^ 2" + name="pupil_area", data=pupil_areas.values, timestamps=synced_timestamps.values, unit="Pixels ^ 2" ) eye_area_ts = pynwb.base.TimeSeries( - name="eye_area", - data=eye_areas.values, - timestamps=synced_timestamps.values, - unit="Pixels ^ 2" + name="eye_area", data=eye_areas.values, timestamps=synced_timestamps.values, unit="Pixels ^ 2" ) screen_coord_ts = pynwb.base.TimeSeries( name="screen_coordinates", data=screen_coordinates.values, timestamps=synced_timestamps.values, - unit="Centimeters" + unit="Centimeters", ) screen_coord_spherical_ts = pynwb.base.TimeSeries( name="screen_coordinates_spherical", data=screen_coordinates_spherical.values, timestamps=synced_timestamps.values, - unit="Degrees" + unit="Degrees", ) pynwb_container.add_data_interface(pupil_area_ts) @@ -282,21 +259,22 @@ def add_eye_gaze_data_interfaces(pynwb_container: pynwb.NWBContainer, def create_gaze_mapping_nwb_processing_modules(eye_gaze_data: dict): # Container for raw gaze mapped data raw_gaze_mapping_mod = pynwb.ProcessingModule( - name='raw_gaze_mapping', - description='Gaze mapping processing module raw outputs') + name="raw_gaze_mapping", description="Gaze mapping processing module raw outputs" + ) raw_gaze_mapping_mod = add_eye_gaze_data_interfaces( - raw_gaze_mapping_mod, - pupil_areas=eye_gaze_data["raw_pupil_areas"], - eye_areas=eye_gaze_data["raw_eye_areas"], - screen_coordinates=eye_gaze_data["raw_screen_coordinates"], - screen_coordinates_spherical=eye_gaze_data["raw_screen_coordinates_spherical"], # noqa: E501 - synced_timestamps=eye_gaze_data["synced_frame_timestamps"]) + raw_gaze_mapping_mod, + pupil_areas=eye_gaze_data["raw_pupil_areas"], + eye_areas=eye_gaze_data["raw_eye_areas"], + screen_coordinates=eye_gaze_data["raw_screen_coordinates"], + screen_coordinates_spherical=eye_gaze_data["raw_screen_coordinates_spherical"], # noqa: E501 + synced_timestamps=eye_gaze_data["synced_frame_timestamps"], + ) # Container for filtered gaze mapped data filt_gaze_mapping_mod = pynwb.ProcessingModule( - name='filtered_gaze_mapping', - description='Gaze mapping processing module filtered outputs') + name="filtered_gaze_mapping", description="Gaze mapping processing module filtered outputs" + ) filt_gaze_mapping_mod = add_eye_gaze_data_interfaces( filt_gaze_mapping_mod, @@ -304,71 +282,65 @@ def create_gaze_mapping_nwb_processing_modules(eye_gaze_data: dict): eye_areas=eye_gaze_data["new_eye_areas"], screen_coordinates=eye_gaze_data["new_screen_coordinates"], screen_coordinates_spherical=eye_gaze_data["new_screen_coordinates_spherical"], # noqa: E501 - synced_timestamps=eye_gaze_data["synced_frame_timestamps"]) + synced_timestamps=eye_gaze_data["synced_frame_timestamps"], + ) return (raw_gaze_mapping_mod, filt_gaze_mapping_mod) -def add_eye_tracking_ellipse_fit_data_to_nwbfile(nwbfile: pynwb.NWBFile, - eye_dlc_tracking_data: dict, - synced_timestamps: pd.Series - ) -> pynwb.NWBFile: - eye_tracking_mod = create_eye_tracking_nwb_processing_module( - eye_dlc_tracking_data, synced_timestamps) +def add_eye_tracking_ellipse_fit_data_to_nwbfile( + nwbfile: pynwb.NWBFile, eye_dlc_tracking_data: dict, synced_timestamps: pd.Series +) -> pynwb.NWBFile: + eye_tracking_mod = create_eye_tracking_nwb_processing_module(eye_dlc_tracking_data, synced_timestamps) nwbfile.add_processing_module(eye_tracking_mod) return nwbfile -def add_eye_gaze_mapping_data_to_nwbfile(nwbfile: pynwb.NWBFile, - eye_gaze_data: dict) -> pynwb.NWBFile: - raw_gaze_mapping_mod, filt_gaze_mapping_mod = \ - create_gaze_mapping_nwb_processing_modules(eye_gaze_data) +def add_eye_gaze_mapping_data_to_nwbfile(nwbfile: pynwb.NWBFile, eye_gaze_data: dict) -> pynwb.NWBFile: + raw_gaze_mapping_mod, filt_gaze_mapping_mod = create_gaze_mapping_nwb_processing_modules(eye_gaze_data) nwbfile.add_processing_module(raw_gaze_mapping_mod) nwbfile.add_processing_module(filt_gaze_mapping_mod) return nwbfile -def add_running_acquisition_to_nwbfile(nwbfile, - running_acquisition_df: pd.DataFrame): - +def add_running_acquisition_to_nwbfile(nwbfile, running_acquisition_df: pd.DataFrame): running_dx_series = TimeSeries( - name='dx', - data=running_acquisition_df['dx'].values, + name="dx", + data=running_acquisition_df["dx"].values, timestamps=running_acquisition_df.index.values, - unit='cm', - description=( - 'Running wheel angular change, computed during data collection') + unit="cm", + description=("Running wheel angular change, computed during data collection"), ) v_sig = TimeSeries( - name='v_sig', - data=running_acquisition_df['v_sig'].values, + name="v_sig", + data=running_acquisition_df["v_sig"].values, timestamps=running_acquisition_df.index.values, - unit='V', - description='Voltage signal from the running wheel encoder' + unit="V", + description="Voltage signal from the running wheel encoder", ) v_in = TimeSeries( - name='v_in', - data=running_acquisition_df['v_in'].values, + name="v_in", + data=running_acquisition_df["v_in"].values, timestamps=running_acquisition_df.index.values, - unit='V', + unit="V", description=( - 'The theoretical maximum voltage that the running wheel encoder ' + "The theoretical maximum voltage that the running wheel encoder " 'will reach prior to "wrapping". This should ' - 'theoretically be 5V (after crossing 5V goes to 0V, or ' - 'vice versa). In practice the encoder does not always ' - 'reach this value before wrapping, which can cause ' - 'transient spikes in speed at the voltage "wraps".') + "theoretically be 5V (after crossing 5V goes to 0V, or " + "vice versa). In practice the encoder does not always " + "reach this value before wrapping, which can cause " + 'transient spikes in speed at the voltage "wraps".' + ), ) - if 'running' in nwbfile.processing: - running_mod = nwbfile.processing['running'] + if "running" in nwbfile.processing: + running_mod = nwbfile.processing["running"] else: - running_mod = ProcessingModule('running', - 'Running speed processing module') + running_mod = ProcessingModule("running", "Running speed processing module") nwbfile.add_processing_module(running_mod) running_mod.add_data_interface(running_dx_series) @@ -378,10 +350,8 @@ def add_running_acquisition_to_nwbfile(nwbfile, return nwbfile -def add_running_speed_to_nwbfile(nwbfile, running_speed, - name='speed', unit='cm/s', - from_dataframe=False): - ''' Adds running speed data to an NWBFile as a timeseries in acquisition +def add_running_speed_to_nwbfile(nwbfile, running_speed, name="speed", unit="cm/s", from_dataframe=False): + """Adds running speed data to an NWBFile as a timeseries in acquisition Parameters ---------- @@ -401,26 +371,21 @@ def add_running_speed_to_nwbfile(nwbfile, running_speed, ------- nwbfile : pynwb.NWBFile - ''' + """ if from_dataframe: - data = running_speed['speed'].values - timestamps = running_speed['timestamps'].values + data = running_speed["speed"].values + timestamps = running_speed["timestamps"].values else: data = running_speed.values timestamps = running_speed.timestamps - running_speed_series = pynwb.base.TimeSeries( - name=name, - data=data, - timestamps=timestamps, - unit=unit) + running_speed_series = pynwb.base.TimeSeries(name=name, data=data, timestamps=timestamps, unit=unit) - if 'running' in nwbfile.processing: - running_mod = nwbfile.processing['running'] + if "running" in nwbfile.processing: + running_mod = nwbfile.processing["running"] else: - running_mod = ProcessingModule('running', - 'Running speed processing module') + running_mod = ProcessingModule("running", "Running speed processing module") nwbfile.add_processing_module(running_mod) running_mod.add_data_interface(running_speed_series) @@ -429,12 +394,11 @@ def add_running_speed_to_nwbfile(nwbfile, running_speed, def create_stimulus_presentation_time_interval( - name: str, description: str, - columns_to_add: Iterable) -> pynwb.epoch.TimeIntervals: + name: str, description: str, columns_to_add: Iterable +) -> pynwb.epoch.TimeIntervals: column_descriptions = { "stimulus_name": "Name of stimulus", - "stimulus_block": ("Index of contiguous presentations of " - "one stimulus type"), + "stimulus_block": ("Index of contiguous presentations of one stimulus type"), "temporal_frequency": "Temporal frequency of stimulus", "x_position": "Horizontal position of stimulus on screen", "y_position": "Vertical position of stimulus on screen", @@ -456,18 +420,16 @@ def create_stimulus_presentation_time_interval( "fieldPos": "Position of moving dot field", "fieldShape": "Shape of moving dot field", "fieldSize": "Size of moving dot field", - "nDots": "Number of dots in moving dot field" + "nDots": "Number of dots in moving dot field", } - columns_to_ignore = {'start_time', 'stop_time', 'tags', 'timeseries'} + columns_to_ignore = {"start_time", "stop_time", "tags", "timeseries"} - interval = pynwb.epoch.TimeIntervals(name=name, - description=description) + interval = pynwb.epoch.TimeIntervals(name=name, description=description) for column_name in columns_to_add: if column_name not in columns_to_ignore: - description = column_descriptions.get( - column_name, "No description") + description = column_descriptions.get(column_name, "No description") interval.add_column(name=column_name, description=description) return interval @@ -489,14 +451,14 @@ def add_invalid_times(nwbfile, epochs): table = setup_table_for_invalid_times(epochs) if not table.empty: - container = pynwb.epoch.TimeIntervals('invalid_times') + container = pynwb.epoch.TimeIntervals("invalid_times") for index, row in table.iterrows(): - - container.add_interval(start_time=row['start_time'], - stop_time=row['stop_time'], - tags=row['tags'], - ) + container.add_interval( + start_time=row["start_time"], + stop_time=row["stop_time"], + tags=row["tags"], + ) nwbfile.invalid_times = container @@ -521,17 +483,12 @@ def setup_table_for_invalid_times(invalid_epochs): if invalid_epochs: df = pd.DataFrame.from_dict(invalid_epochs) - start_time = df['start_time'].values - stop_time = df['end_time'].values - tags = [[_type, str(_id), label] - for _type, _id, label - in zip(df['type'], df['id'], df['label'])] + start_time = df["start_time"].values + stop_time = df["end_time"].values + tags = [[_type, str(_id), label] for _type, _id, label in zip(df["type"], df["id"], df["label"])] - table = pd.DataFrame({'start_time': start_time, - 'stop_time': stop_time, - 'tags': tags} - ) - table.index.name = 'id' + table = pd.DataFrame({"start_time": start_time, "stop_time": stop_time, "tags": tags}) + table.index.name = "id" else: table = pd.DataFrame() @@ -541,30 +498,21 @@ def setup_table_for_invalid_times(invalid_epochs): def setup_table_for_epochs(table, timeseries, tag): table = table.copy() - indices = np.searchsorted(timeseries.timestamps[:], - table['start_time'].values) + indices = np.searchsorted(timeseries.timestamps[:], table["start_time"].values) if len(indices > 0): - diffs = np.concatenate([np.diff(indices), - [table.shape[0] - indices[-1]]]) + diffs = np.concatenate([np.diff(indices), [table.shape[0] - indices[-1]]]) else: diffs = [] - table['tags'] = [(tag,)] * table.shape[0] - table['timeseries'] = [[[indices[ii], diffs[ii], timeseries]] - for ii in range(table.shape[0])] + table["tags"] = [(tag,)] * table.shape[0] + table["timeseries"] = [[[indices[ii], diffs[ii], timeseries]] for ii in range(table.shape[0])] return table -def add_stimulus_timestamps(nwbfile, stimulus_timestamps, - module_name='stimulus'): - stimulus_ts = TimeSeries( - data=stimulus_timestamps, - name='timestamps', - timestamps=stimulus_timestamps, - unit='s' - ) +def add_stimulus_timestamps(nwbfile, stimulus_timestamps, module_name="stimulus"): + stimulus_ts = TimeSeries(data=stimulus_timestamps, name="timestamps", timestamps=stimulus_timestamps, unit="s") - stim_mod = ProcessingModule(module_name, 'Stimulus Times processing') + stim_mod = ProcessingModule(module_name, "Stimulus Times processing") nwbfile.add_processing_module(stim_mod) stim_mod.add_data_interface(stimulus_ts) @@ -574,47 +522,37 @@ def add_stimulus_timestamps(nwbfile, stimulus_timestamps, def add_trials(nwbfile, trials, description_dict={}): order = list(trials.index) - for _, row in trials[['start_time', 'stop_time']].iterrows(): + for _, row in trials[["start_time", "stop_time"]].iterrows(): row_dict = row.to_dict() nwbfile.add_trial(**row_dict) for c in trials.columns: - if c in ['start_time', 'stop_time']: + if c in ["start_time", "stop_time"]: continue index, data = dict_to_indexed_array(trials[c].to_dict(), order) - if data.dtype == ' Optional[str]: """Convert numeric age_in_days to ISO 8601""" if age_in_days is None: - return 'null' - return f'P{age_in_days}D' + return "null" + return f"P{age_in_days}D" nwb_subject = BehaviorSubject( description="A visual behavior subject with a LabTracks ID", - age=_get_age(age_in_days=subject_metadata['age_in_days']), + age=_get_age(age_in_days=subject_metadata["age_in_days"]), driver_line=subject_metadata["driver_line"], genotype=subject_metadata["genotype"], subject_id=str(subject_metadata["subject_id"]), reporter_line=subject_metadata["reporter_line"], sex=subject_metadata["sex"], - species='Mus musculus') + species="Mus musculus", + ) nwbfile.subject = nwb_subject # Remove metadata that will go into pyNWB base classes @@ -774,25 +687,17 @@ def _get_age(age_in_days: Optional[int]) -> Optional[str]: new_metadata_dict[key] = val if behavior_only: - BehaviorMetadata = load_pynwb_extension(BehaviorMetadataSchema, - 'ndx-aibs-behavior-ophys') - nwb_metadata = BehaviorMetadata(name='metadata', **new_metadata_dict) + BehaviorMetadata = load_pynwb_extension(BehaviorMetadataSchema, "ndx-aibs-behavior-ophys") + nwb_metadata = BehaviorMetadata(name="metadata", **new_metadata_dict) else: - OphysBehaviorMetadata = load_pynwb_extension( - OphysBehaviorMetadataSchema, 'ndx-aibs-behavior-ophys') - nwb_metadata = OphysBehaviorMetadata(name='metadata', - **new_metadata_dict) + OphysBehaviorMetadata = load_pynwb_extension(OphysBehaviorMetadataSchema, "ndx-aibs-behavior-ophys") + nwb_metadata = OphysBehaviorMetadata(name="metadata", **new_metadata_dict) nwbfile.add_lab_meta_data(nwb_metadata) def add_task_parameters(nwbfile, task_parameters): - - OphysBehaviorTaskParameters = load_pynwb_extension( - BehaviorTaskParametersSchema, 'ndx-aibs-behavior-ophys' - ) - task_parameters_clean = BehaviorTaskParametersSchema().dump( - task_parameters - ) + OphysBehaviorTaskParameters = load_pynwb_extension(BehaviorTaskParametersSchema, "ndx-aibs-behavior-ophys") + task_parameters_clean = BehaviorTaskParametersSchema().dump(task_parameters) new_task_parameters_dict = {} for key, val in task_parameters_clean.items(): @@ -800,14 +705,11 @@ def add_task_parameters(nwbfile, task_parameters): new_task_parameters_dict[key] = np.array(val) else: new_task_parameters_dict[key] = val - nwb_task_parameters = OphysBehaviorTaskParameters( - name='task_parameters', **new_task_parameters_dict) + nwb_task_parameters = OphysBehaviorTaskParameters(name="task_parameters", **new_task_parameters_dict) nwbfile.add_lab_meta_data(nwb_task_parameters) -def add_cell_specimen_table(nwbfile: NWBFile, - cell_specimen_table: pd.DataFrame, - session_metadata: dict): +def add_cell_specimen_table(nwbfile: NWBFile, cell_specimen_table: pd.DataFrame, session_metadata: dict): """ This function takes the cell specimen table and writes the ROIs contained within. It writes these to a new NWB imaging plane @@ -834,175 +736,164 @@ def add_cell_specimen_table(nwbfile: NWBFile, nwbfile: NWBFile The altered in memory NWBFile object that now has a specimen table """ - cell_specimen_metadata = NwbOphysMetadataSchema().load( - session_metadata, unknown=marshmallow.EXCLUDE) - cell_roi_table = cell_specimen_table.reset_index().set_index('cell_roi_id') + cell_specimen_metadata = NwbOphysMetadataSchema().load(session_metadata, unknown=marshmallow.EXCLUDE) + cell_roi_table = cell_specimen_table.reset_index().set_index("cell_roi_id") # Device: - device_name: str = nwbfile.lab_meta_data['metadata'].equipment_name + device_name: str = nwbfile.lab_meta_data["metadata"].equipment_name if device_name.startswith("MESO"): - device_config = { - "name": device_name, - "description": "Allen Brain Observatory - Mesoscope 2P Rig" - } + device_config = {"name": device_name, "description": "Allen Brain Observatory - Mesoscope 2P Rig"} else: device_config = { "name": device_name, "description": "Allen Brain Observatory - Scientifica 2P Rig", - "manufacturer": "Scientifica" + "manufacturer": "Scientifica", } nwbfile.create_device(**device_config) device = nwbfile.get_device(device_name) # FOV: - fov_width = nwbfile.lab_meta_data['metadata'].field_of_view_width - fov_height = nwbfile.lab_meta_data['metadata'].field_of_view_height + fov_width = nwbfile.lab_meta_data["metadata"].field_of_view_width + fov_height = nwbfile.lab_meta_data["metadata"].field_of_view_height imaging_plane_description = "{} field of view in {} at depth {} um".format( (fov_width, fov_height), - cell_specimen_metadata['targeted_structure'], - nwbfile.lab_meta_data['metadata'].imaging_depth) + cell_specimen_metadata["targeted_structure"], + nwbfile.lab_meta_data["metadata"].imaging_depth, + ) # Optical Channel: optical_channel = OpticalChannel( - name='channel_1', - description='2P Optical Channel', - emission_lambda=cell_specimen_metadata['emission_lambda']) + name="channel_1", description="2P Optical Channel", emission_lambda=cell_specimen_metadata["emission_lambda"] + ) # Imaging Plane: imaging_plane = nwbfile.create_imaging_plane( - name='imaging_plane_1', + name="imaging_plane_1", optical_channel=optical_channel, description=imaging_plane_description, device=device, - excitation_lambda=cell_specimen_metadata['excitation_lambda'], - imaging_rate=cell_specimen_metadata['ophys_frame_rate'], - indicator=cell_specimen_metadata['indicator'], - location=cell_specimen_metadata['targeted_structure']) + excitation_lambda=cell_specimen_metadata["excitation_lambda"], + imaging_rate=cell_specimen_metadata["ophys_frame_rate"], + indicator=cell_specimen_metadata["indicator"], + location=cell_specimen_metadata["targeted_structure"], + ) # Image Segmentation: image_segmentation = ImageSegmentation(name="image_segmentation") - if 'ophys' not in nwbfile.processing: - ophys_module = ProcessingModule('ophys', 'Ophys processing module') + if "ophys" not in nwbfile.processing: + ophys_module = ProcessingModule("ophys", "Ophys processing module") nwbfile.add_processing_module(ophys_module) else: - ophys_module = nwbfile.processing['ophys'] + ophys_module = nwbfile.processing["ophys"] ophys_module.add_data_interface(image_segmentation) # Plane Segmentation: plane_segmentation = image_segmentation.create_plane_segmentation( - name='cell_specimen_table', - description="Segmented rois", - imaging_plane=imaging_plane) + name="cell_specimen_table", description="Segmented rois", imaging_plane=imaging_plane + ) for col_name in cell_roi_table.columns: # the columns 'roi_mask', 'pixel_mask', and 'voxel_mask' are # already defined in the nwb.ophys::PlaneSegmentation Object - if col_name not in ['id', 'mask_matrix', 'roi_mask', - 'pixel_mask', 'voxel_mask']: + if col_name not in ["id", "mask_matrix", "roi_mask", "pixel_mask", "voxel_mask"]: # This builds the columns with name of column and description # of column both equal to the column name in the cell_roi_table plane_segmentation.add_column( - col_name, - CELL_SPECIMEN_COL_DESCRIPTIONS.get( - col_name, - "No Description Available")) + col_name, CELL_SPECIMEN_COL_DESCRIPTIONS.get(col_name, "No Description Available") + ) # go through each roi and add it to the plan segmentation object for cell_roi_id, table_row in cell_roi_table.iterrows(): - # NOTE: The 'roi_mask' in this cell_roi_table has already been # processing by the function from # allensdk.brain_observatory.behavior.session_apis.data_io.ophys_lims_api # get_cell_specimen_table() method. As a result, the ROI is stored in # an array that is the same shape as the FULL field of view of the # experiment (e.g. 512 x 512). - mask = table_row.pop('roi_mask') + mask = table_row.pop("roi_mask") - csid = table_row.pop('cell_specimen_id') - table_row['cell_specimen_id'] = -1 if csid is None else csid - table_row['id'] = cell_roi_id + csid = table_row.pop("cell_specimen_id") + table_row["cell_specimen_id"] = -1 if csid is None else csid + table_row["id"] = cell_roi_id plane_segmentation.add_roi(image_mask=mask, **table_row.to_dict()) return nwbfile def add_dff_traces(nwbfile, dff_traces, ophys_timestamps): - dff_traces = dff_traces.reset_index().set_index('cell_roi_id')[['dff']] + dff_traces = dff_traces.reset_index().set_index("cell_roi_id")[["dff"]] - ophys_module = nwbfile.processing['ophys'] + ophys_module = nwbfile.processing["ophys"] # trace data in the form of rois x timepoints - trace_data = np.array([dff_traces.loc[cell_roi_id].dff - for cell_roi_id in dff_traces.index.values]) + trace_data = np.array([dff_traces.loc[cell_roi_id].dff for cell_roi_id in dff_traces.index.values]) - cell_specimen_table = nwbfile.processing['ophys'].data_interfaces['image_segmentation'].plane_segmentations['cell_specimen_table'] # noqa: E501 + cell_specimen_table = ( + nwbfile.processing["ophys"].data_interfaces["image_segmentation"].plane_segmentations["cell_specimen_table"] + ) # noqa: E501 roi_table_region = cell_specimen_table.create_roi_table_region( - description="segmented cells labeled by cell_specimen_id", - region=slice(len(dff_traces))) + description="segmented cells labeled by cell_specimen_id", region=slice(len(dff_traces)) + ) # Create/Add dff modules and interfaces: - assert dff_traces.index.name == 'cell_roi_id' - dff_interface = DfOverF(name='dff') + assert dff_traces.index.name == "cell_roi_id" + dff_interface = DfOverF(name="dff") ophys_module.add_data_interface(dff_interface) dff_interface.create_roi_response_series( - name='traces', + name="traces", data=trace_data.T, # Should be stored as timepoints x rois - unit='NA', + unit="NA", rois=roi_table_region, - timestamps=ophys_timestamps) + timestamps=ophys_timestamps, + ) return nwbfile def add_corrected_fluorescence_traces(nwbfile, corrected_fluorescence_traces): - corrected_fluorescence_traces = \ - corrected_fluorescence_traces.reset_index().set_index( - 'cell_roi_id')[['corrected_fluorescence']] + corrected_fluorescence_traces = corrected_fluorescence_traces.reset_index().set_index("cell_roi_id")[ + ["corrected_fluorescence"] + ] # Create/Add corrected_fluorescence_traces modules and interfaces: - assert corrected_fluorescence_traces.index.name == 'cell_roi_id' - ophys_module = nwbfile.processing['ophys'] + assert corrected_fluorescence_traces.index.name == "cell_roi_id" + ophys_module = nwbfile.processing["ophys"] # trace data in the form of rois x timepoints f_trace_data = np.array( - [corrected_fluorescence_traces.loc[cell_roi_id].corrected_fluorescence - for cell_roi_id in corrected_fluorescence_traces.index.values]) + [ + corrected_fluorescence_traces.loc[cell_roi_id].corrected_fluorescence + for cell_roi_id in corrected_fluorescence_traces.index.values + ] + ) - roi_table_region = nwbfile.processing['ophys'].data_interfaces['dff'].roi_response_series['traces'].rois # noqa: E501 - ophys_timestamps = ophys_module.get_data_interface( - 'dff').roi_response_series['traces'].timestamps - f_interface = Fluorescence(name='corrected_fluorescence') + roi_table_region = nwbfile.processing["ophys"].data_interfaces["dff"].roi_response_series["traces"].rois # noqa: E501 + ophys_timestamps = ophys_module.get_data_interface("dff").roi_response_series["traces"].timestamps + f_interface = Fluorescence(name="corrected_fluorescence") ophys_module.add_data_interface(f_interface) f_interface.create_roi_response_series( - name='traces', + name="traces", data=f_trace_data.T, # Should be stored as timepoints x rois - unit='NA', + unit="NA", rois=roi_table_region, - timestamps=ophys_timestamps) + timestamps=ophys_timestamps, + ) return nwbfile def add_motion_correction(nwbfile, motion_correction): - - ophys_module = nwbfile.processing['ophys'] - ophys_timestamps = ophys_module.get_data_interface( - 'dff').roi_response_series['traces'].timestamps + ophys_module = nwbfile.processing["ophys"] + ophys_timestamps = ophys_module.get_data_interface("dff").roi_response_series["traces"].timestamps t1 = TimeSeries( - name='ophys_motion_correction_x', - data=motion_correction['x'].values, - timestamps=ophys_timestamps, - unit='pixels' + name="ophys_motion_correction_x", data=motion_correction["x"].values, timestamps=ophys_timestamps, unit="pixels" ) t2 = TimeSeries( - name='ophys_motion_correction_y', - data=motion_correction['y'].values, - timestamps=ophys_timestamps, - unit='pixels' + name="ophys_motion_correction_y", data=motion_correction["y"].values, timestamps=ophys_timestamps, unit="pixels" ) ophys_module.add_data_interface(t1) diff --git a/allensdk/brain_observatory/nwb/behavior_ophys_nwb_extension_builder.py b/allensdk/brain_observatory/nwb/behavior_ophys_nwb_extension_builder.py index 3857c59d66..399bca926f 100644 --- a/allensdk/brain_observatory/nwb/behavior_ophys_nwb_extension_builder.py +++ b/allensdk/brain_observatory/nwb/behavior_ophys_nwb_extension_builder.py @@ -4,20 +4,21 @@ BehaviorMetadataSchema, OphysBehaviorMetadataSchema, BehaviorTaskParametersSchema, - SubjectMetadataSchema, OphysEyeTrackingRigMetadataSchema -) -from allensdk.brain_observatory.nwb.metadata import ( - create_pynwb_extension_from_schemas + SubjectMetadataSchema, + OphysEyeTrackingRigMetadataSchema, ) +from allensdk.brain_observatory.nwb.metadata import create_pynwb_extension_from_schemas if __name__ == "__main__": - # Run this module to regenerate the extension yaml files into this dir: - prefix = 'ndx-aibs-behavior-ophys' + prefix = "ndx-aibs-behavior-ophys" schemas = [ - BehaviorTaskParametersSchema, SubjectMetadataSchema, - BehaviorMetadataSchema, OphysBehaviorMetadataSchema, - OphysEyeTrackingRigMetadataSchema] + BehaviorTaskParametersSchema, + SubjectMetadataSchema, + BehaviorMetadataSchema, + OphysBehaviorMetadataSchema, + OphysEyeTrackingRigMetadataSchema, + ] curr_dir = os.path.abspath(os.path.dirname(__file__)) create_pynwb_extension_from_schemas(schemas, prefix, save_dir=curr_dir) diff --git a/allensdk/brain_observatory/nwb/eye_tracking/extension_builder.py b/allensdk/brain_observatory/nwb/eye_tracking/extension_builder.py index 4b28c824d3..e0faabf6a9 100644 --- a/allensdk/brain_observatory/nwb/eye_tracking/extension_builder.py +++ b/allensdk/brain_observatory/nwb/eye_tracking/extension_builder.py @@ -1,9 +1,8 @@ import os.path -from pynwb.spec import (NWBNamespaceBuilder, export_spec, - NWBGroupSpec, NWBDatasetSpec) +from pynwb.spec import NWBNamespaceBuilder, export_spec, NWBGroupSpec, NWBDatasetSpec -NAMESPACE = 'ndx-ellipse-eye-tracking' +NAMESPACE = "ndx-ellipse-eye-tracking" def main(): @@ -12,82 +11,58 @@ def main(): doc="""Store the elliptical eye tracking output of DeepLabCut""", name=f"""{NAMESPACE}""", version="""0.1.0""", - author=list(map(str.strip, """Ben Dichter""".split(','))), - contact=list(map(str.strip, """bdichter@lbl.gov""".split(','))) + author=list(map(str.strip, """Ben Dichter""".split(","))), + contact=list(map(str.strip, """bdichter@lbl.gov""".split(","))), ) - ns_builder.include_type('SpatialSeries', namespace='core') - ns_builder.include_type('EyeTracking', namespace='core') - ns_builder.include_type('TimeSeries', namespace='core') + ns_builder.include_type("SpatialSeries", namespace="core") + ns_builder.include_type("EyeTracking", namespace="core") + ns_builder.include_type("TimeSeries", namespace="core") ellipse_series_spec = NWBGroupSpec( - neurodata_type_def='EllipseSeries', - neurodata_type_inc='SpatialSeries', - doc='Information about an ellipse moving over time', + neurodata_type_def="EllipseSeries", + neurodata_type_inc="SpatialSeries", + doc="Information about an ellipse moving over time", datasets=[ NWBDatasetSpec( - name='data', # override SpatialSeries 'data' dataset to be more explicit - dtype='numeric', - doc='The (x, y) coordinates of the center of the ellipse at each time point.', - dims=('num_times', 'x, y'), + name="data", # override SpatialSeries 'data' dataset to be more explicit + dtype="numeric", + doc="The (x, y) coordinates of the center of the ellipse at each time point.", + dims=("num_times", "x, y"), shape=(None, 2), ), NWBDatasetSpec( - name='area', - dtype='float', - doc='ellipse area, with nan values in likely blink times', - shape=(None, ) + name="area", dtype="float", doc="ellipse area, with nan values in likely blink times", shape=(None,) ), NWBDatasetSpec( - name='area_raw', - dtype='float', - doc='ellipse area, with no regard to likely blink times', - shape=(None, ) + name="area_raw", dtype="float", doc="ellipse area, with no regard to likely blink times", shape=(None,) ), - NWBDatasetSpec( - name='width', - dtype='float', - doc='width of ellipse', - shape=(None, ) - ), - NWBDatasetSpec( - name='height', - dtype='float', - doc='height of ellipse', - shape=(None, ) - ), - NWBDatasetSpec( - name='angle', - dtype='float', - doc='angle that ellipse is rotated by (phi)', - shape=(None, ) - ) - ] + NWBDatasetSpec(name="width", dtype="float", doc="width of ellipse", shape=(None,)), + NWBDatasetSpec(name="height", dtype="float", doc="height of ellipse", shape=(None,)), + NWBDatasetSpec(name="angle", dtype="float", doc="angle that ellipse is rotated by (phi)", shape=(None,)), + ], ) ellipse_eye_tracking_spec = NWBGroupSpec( - neurodata_type_def='EllipseEyeTracking', - neurodata_type_inc='EyeTracking', + neurodata_type_def="EllipseEyeTracking", + neurodata_type_inc="EyeTracking", name=None, - default_name='EyeTracking', - doc='Stores detailed eye tracking information output from DeepLabCut', + default_name="EyeTracking", + doc="Stores detailed eye tracking information output from DeepLabCut", groups=[ + NWBGroupSpec(neurodata_type_inc=ellipse_series_spec, name=x, doc=x.replace("_", " ")) + for x in ("eye_tracking", "pupil_tracking", "corneal_reflection_tracking") + ] + + [ NWBGroupSpec( - neurodata_type_inc=ellipse_series_spec, - name=x, - doc=x.replace('_', ' ') - ) for x in ('eye_tracking', 'pupil_tracking', 'corneal_reflection_tracking') - ] + [ - NWBGroupSpec( - neurodata_type_inc='TimeSeries', - name='likely_blink', - doc='Indicator of whether there was a probable blink for this frame' + neurodata_type_inc="TimeSeries", + name="likely_blink", + doc="Indicator of whether there was a probable blink for this frame", ) - ] - + ], ) - new_data_types = [ellipse_series_spec, ellipse_eye_tracking_spec] + new_data_types = [ellipse_series_spec, ellipse_eye_tracking_spec] # export the spec to yaml files in the spec folder output_dir = os.path.abspath(os.path.join(os.path.dirname(__file__))) diff --git a/allensdk/brain_observatory/nwb/eye_tracking/ndx_ellipse_eye_tracking.py b/allensdk/brain_observatory/nwb/eye_tracking/ndx_ellipse_eye_tracking.py index 3127c9640f..692b603b53 100644 --- a/allensdk/brain_observatory/nwb/eye_tracking/ndx_ellipse_eye_tracking.py +++ b/allensdk/brain_observatory/nwb/eye_tracking/ndx_ellipse_eye_tracking.py @@ -3,15 +3,12 @@ # import ndx_events # Set path of the namespace.yaml file to the expected install location -ndx_ellipse_eye_tracking_specpath = os.path.join( - os.path.dirname(__file__), - 'ndx-ellipse-eye-tracking.namespace.yaml' -) +ndx_ellipse_eye_tracking_specpath = os.path.join(os.path.dirname(__file__), "ndx-ellipse-eye-tracking.namespace.yaml") # Load the namespace # load_namespaces(ndx_events.ndx_events_specpath) load_namespaces(ndx_ellipse_eye_tracking_specpath) -EllipseSeries = get_class('EllipseSeries', 'ndx-ellipse-eye-tracking') -EllipseEyeTracking = get_class('EllipseEyeTracking', 'ndx-ellipse-eye-tracking') \ No newline at end of file +EllipseSeries = get_class("EllipseSeries", "ndx-ellipse-eye-tracking") +EllipseEyeTracking = get_class("EllipseEyeTracking", "ndx-ellipse-eye-tracking") diff --git a/allensdk/brain_observatory/nwb/metadata.py b/allensdk/brain_observatory/nwb/metadata.py index a23fe55ae8..11c8340921 100644 --- a/allensdk/brain_observatory/nwb/metadata.py +++ b/allensdk/brain_observatory/nwb/metadata.py @@ -2,28 +2,25 @@ from marshmallow import fields import pynwb -from pynwb.spec import \ - NWBNamespaceBuilder, NWBGroupSpec, NWBAttributeSpec, NWBDatasetSpec +from pynwb.spec import NWBNamespaceBuilder, NWBGroupSpec, NWBAttributeSpec, NWBDatasetSpec from allensdk.brain_observatory.behavior.schemas import STYPE_DICT, TYPE_DICT def extract_from_schema(schema): - if hasattr(schema, 'neurodata_skip'): + if hasattr(schema, "neurodata_skip"): fields_to_skip = schema.neurodata_skip else: fields_to_skip = set() # Extract fields from Schema: - docval_list = [{'name': 'name', 'type': str, 'doc': 'name'}] + docval_list = [{"name": "name", "type": str, "doc": "name"}] - attributes = _extract_attributes(attributes=schema().fields, - fields_to_skip=fields_to_skip) + attributes = _extract_attributes(attributes=schema().fields, fields_to_skip=fields_to_skip) datasets = [] nwbfields_list = [] for name, val in schema().fields.items(): - if name in fields_to_skip: continue @@ -32,9 +29,7 @@ def extract_from_schema(schema): datasets.append(dataset) continue - docval_list.append({'name': name, - 'type': TYPE_DICT[type(val)], - 'doc': val.metadata['doc']}) + docval_list.append({"name": name, "type": TYPE_DICT[type(val)], "doc": val.metadata["doc"]}) nwbfields_list.append(name) return docval_list, attributes, nwbfields_list, datasets @@ -43,7 +38,7 @@ def extract_from_schema(schema): def load_pynwb_extension(schema, prefix: str): neurodata_type = schema.neurodata_type outdir = os.path.abspath(os.path.dirname(__file__)) - ns_path = f'{prefix}.namespace.yaml' + ns_path = f"{prefix}.namespace.yaml" # Read spec and load namespace: ns_abs_path = os.path.join(outdir, ns_path) @@ -52,29 +47,27 @@ def load_pynwb_extension(schema, prefix: str): return pynwb.get_class(neurodata_type, prefix) -def create_pynwb_extension_from_schemas(schema_list, prefix: str, - save_dir: str): +def create_pynwb_extension_from_schemas(schema_list, prefix: str, save_dir: str): # Initializations: - ext_source = f'{prefix}.extension.yaml' - ns_path = f'{prefix}.namespace.yaml' + ext_source = f"{prefix}.extension.yaml" + ns_path = f"{prefix}.namespace.yaml" print(f"Saving extensions to: {save_dir}") - extension_doc = ("Allen Institute behavior and optical " - "physiology extensions") + extension_doc = "Allen Institute behavior and optical physiology extensions" ns_builder = NWBNamespaceBuilder( doc=extension_doc, name=prefix, version="0.2.0", author="Allen Institute for Brain Science", - contact="waynew@alleninstitute.org") + contact="waynew@alleninstitute.org", + ) # Loops through and create NWB custom group specs for schemas found in: # allensdk.brain_observatory.behavior.schemas for schema in schema_list: - docval_list, attributes, nwbfields_list, datasets = \ - extract_from_schema(schema) + docval_list, attributes, nwbfields_list, datasets = extract_from_schema(schema) # Build the spec: ext_group_spec = NWBGroupSpec( @@ -82,7 +75,8 @@ def create_pynwb_extension_from_schemas(schema_list, prefix: str, neurodata_type_inc=schema.neurodata_type_inc, doc=schema.neurodata_doc, attributes=attributes, - datasets=datasets) + datasets=datasets, + ) # Add spec to builder: ns_builder.add_spec(ext_source, ext_group_spec) @@ -93,19 +87,18 @@ def create_pynwb_extension_from_schemas(schema_list, prefix: str, def _extract_dataset(val): if val.many: - raise NotImplementedError('many not supported') - if 'values' not in val.schema.fields: + raise NotImplementedError("many not supported") + if "values" not in val.schema.fields: raise ValueError('A dataset must contain an attribute called "values"') - values = val.schema.fields['values'] - attributes = _extract_attributes(attributes=val.schema.fields, - fields_to_skip=['values']) + values = val.schema.fields["values"] + attributes = _extract_attributes(attributes=val.schema.fields, fields_to_skip=["values"]) return NWBDatasetSpec( name=val.name, attributes=attributes, - doc=val.metadata['doc'], + doc=val.metadata["doc"], dtype=STYPE_DICT[type(values)], - dims=values.metadata['shape'] + dims=values.metadata["shape"], ) @@ -116,15 +109,17 @@ def _extract_attributes(attributes, fields_to_skip=None): continue if type(val) == fields.List: - res.append(NWBAttributeSpec(name=name, - dtype=STYPE_DICT[type(val)], - doc=val.metadata['doc'], - shape=val.metadata['shape'], - required=val.required)) + res.append( + NWBAttributeSpec( + name=name, + dtype=STYPE_DICT[type(val)], + doc=val.metadata["doc"], + shape=val.metadata["shape"], + required=val.required, + ) + ) elif type(val) == fields.Nested: continue else: - res.append(NWBAttributeSpec(name=name, - dtype=STYPE_DICT[type(val)], - doc=val.metadata['doc'])) + res.append(NWBAttributeSpec(name=name, dtype=STYPE_DICT[type(val)], doc=val.metadata["doc"])) return res diff --git a/allensdk/brain_observatory/nwb/nwb_api.py b/allensdk/brain_observatory/nwb/nwb_api.py index ace7604a8f..c8178491e0 100644 --- a/allensdk/brain_observatory/nwb/nwb_api.py +++ b/allensdk/brain_observatory/nwb/nwb_api.py @@ -2,35 +2,31 @@ import pynwb import SimpleITK as sitk -from allensdk.brain_observatory.behavior.data_objects.stimuli.presentations \ - import \ - Presentations +from allensdk.brain_observatory.behavior.data_objects.stimuli.presentations import Presentations from allensdk.brain_observatory.running_speed import RunningSpeed from allensdk.brain_observatory.behavior.image_api import ImageApi class NwbApi: - - __slots__ = ('path', '_nwbfile') + __slots__ = ("path", "_nwbfile") @property def nwbfile(self): - if hasattr(self, '_nwbfile'): + if hasattr(self, "_nwbfile"): return self._nwbfile - io = pynwb.NWBHDF5IO(self.path, 'r', load_namespaces=True) + io = pynwb.NWBHDF5IO(self.path, "r", load_namespaces=True) return io.read() def __init__(self, path, **kwargs): - ''' Reads data for a single Brain Observatory session from an NWB 2.0 + """Reads data for a single Brain Observatory session from an NWB 2.0 file - ''' + """ self.path = path @classmethod def from_nwbfile(cls, nwbfile, **kwargs): - obj = cls(path=None, **kwargs) obj._nwbfile = nwbfile @@ -38,7 +34,7 @@ def from_nwbfile(cls, nwbfile, **kwargs): @classmethod def from_path(cls, path, **kwargs): - with open(path, 'r'): + with open(path, "r"): pass return cls(path=path, **kwargs) @@ -58,11 +54,9 @@ def get_running_speed(self, lowpass=True) -> RunningSpeed: The running speed """ - interface_name = 'speed' if lowpass else 'speed_unfiltered' - values = self.nwbfile.processing['running'].get_data_interface( - interface_name).data[:] - timestamps = self.nwbfile.processing['running'].get_data_interface( - interface_name).timestamps[:] + interface_name = "speed" if lowpass else "speed_unfiltered" + values = self.nwbfile.processing["running"].get_data_interface(interface_name).data[:] + timestamps = self.nwbfile.processing["running"].get_data_interface(interface_name).timestamps[:] return RunningSpeed( timestamps=timestamps, @@ -70,12 +64,10 @@ def get_running_speed(self, lowpass=True) -> RunningSpeed: ) def get_stimulus_presentations(self) -> pd.DataFrame: - presentations = Presentations.from_nwb(nwbfile=self.nwbfile, - add_is_change=False) + presentations = Presentations.from_nwb(nwbfile=self.nwbfile, add_is_change=False) return presentations.value def get_invalid_times(self) -> pd.DataFrame: - container = self.nwbfile.invalid_times if container: return container.to_dataframe() @@ -83,14 +75,12 @@ def get_invalid_times(self) -> pd.DataFrame: return pd.DataFrame() def get_image(self, name, module, image_api=None) -> sitk.Image: - if image_api is None: image_api = ImageApi - nwb_img = self.nwbfile.processing[module].get_data_interface( - 'images')[name] + nwb_img = self.nwbfile.processing[module].get_data_interface("images")[name] data = nwb_img.data resolution = nwb_img.resolution # px/cm spacing = [resolution * 10, resolution * 10] - return ImageApi.serialize(data, spacing, 'mm') + return ImageApi.serialize(data, spacing, "mm") diff --git a/allensdk/brain_observatory/nwb/nwb_utils.py b/allensdk/brain_observatory/nwb/nwb_utils.py index b37582f15b..15ff07de3d 100644 --- a/allensdk/brain_observatory/nwb/nwb_utils.py +++ b/allensdk/brain_observatory/nwb/nwb_utils.py @@ -45,10 +45,7 @@ def get_column_name(table_cols: list, possible_names: set) -> str: column_set = set(table_cols) column_names = list(column_set.intersection(possible_names)) if not len(column_names) == 1: - raise KeyError( - "Table expected one name column in intersection, found:" - f" {column_names}" - ) + raise KeyError(f"Table expected one name column in intersection, found: {column_names}") return column_names[0] @@ -93,9 +90,7 @@ def add_image_to_nwb(nwbfile: NWBFile, image_data: Image, image_name: str): else: ophys_mod = nwbfile.processing[module_name] - image = GrayscaleImage( - image_name, data, resolution=spacing[0] / 10, description=description - ) + image = GrayscaleImage(image_name, data, resolution=spacing[0] / 10, description=description) if "images" not in ophys_mod.containers: images = Images(name="images") @@ -112,9 +107,7 @@ def __init__( self, nwb_filepath: str, session_data: dict, - serializer: Union[ - JsonReadableInterface, NwbReadableInterface, NwbWritableInterface - ], + serializer: Union[JsonReadableInterface, NwbReadableInterface, NwbWritableInterface], ): """ @@ -174,13 +167,9 @@ def write_nwb( """ from_lims_kwargs = { - k: v - for k, v in kwargs.items() - if k in inspect.signature(self._serializer.from_lims).parameters + k: v for k, v in kwargs.items() if k in inspect.signature(self._serializer.from_lims).parameters } - lims_session = self._serializer.from_lims( - self._session_data[id_column_name], **from_lims_kwargs - ) + lims_session = self._serializer.from_lims(self._session_data[id_column_name], **from_lims_kwargs) lims_session = self._update_session(lims_session, **kwargs) try: @@ -195,20 +184,14 @@ def write_nwb( input_session=lims_session, skip_stim=skip_stim, ) - self._compare_sessions( - nwbfile=nwbfile, loaded_session=lims_session, **kwargs - ) + self._compare_sessions(nwbfile=nwbfile, loaded_session=lims_session, **kwargs) os.rename(self.nwb_filepath_inprogress, self._nwb_filepath) except Exception as e: if os.path.isfile(self.nwb_filepath_inprogress): - os.rename( - self.nwb_filepath_inprogress, self._nwb_filepath_error - ) + os.rename(self.nwb_filepath_inprogress, self._nwb_filepath_error) raise e - def _update_session( - self, lims_session: BehaviorSession, **kwargs - ) -> BehaviorSession: + def _update_session(self, lims_session: BehaviorSession, **kwargs) -> BehaviorSession: """Call session methods to update certain values within the session. Should be used as part of a datarelease to resolve known data issues. @@ -227,11 +210,7 @@ def _write_nwb(self, session: BehaviorSession, **kwargs) -> NWBFile: ------- """ - to_nwb_kwargs = { - k: v - for k, v in kwargs.items() - if k in inspect.signature(self._serializer.to_nwb).parameters - } + to_nwb_kwargs = {k: v for k, v in kwargs.items() if k in inspect.signature(self._serializer.to_nwb).parameters} nwbfile = session.to_nwb(**to_nwb_kwargs) with NWBHDF5IO(self.nwb_filepath_inprogress, "w") as nwb_file_writer: @@ -299,12 +278,8 @@ def _compare_stimulus_file( skip_stim = [] error_message = "" behavior_session_id = input_session.behavior_session_id - db_conn = db_connection_creator( - fallback_credentials=LIMS_DB_CREDENTIAL_MAP - ) - stimulus_file = BehaviorStimulusFile.from_lims( - db=db_conn, behavior_session_id=behavior_session_id - ).validate() + db_conn = db_connection_creator(fallback_credentials=LIMS_DB_CREDENTIAL_MAP) + stimulus_file = BehaviorStimulusFile.from_lims(db=db_conn, behavior_session_id=behavior_session_id).validate() stim_file_methods = dir(stimulus_file) for key, bs_val in input_session.metadata.items(): if key in skip_stim: @@ -325,13 +300,7 @@ def _compare_stimulus_file( if len(error_message) > 0: raise ValueError(error_message) - def _compare_sessions( - self, nwbfile: NWBFile, loaded_session: DataObject, **kwargs - ): - kwargs = { - k: v - for k, v in kwargs.items() - if k in inspect.signature(self._serializer.from_nwb).parameters - } + def _compare_sessions(self, nwbfile: NWBFile, loaded_session: DataObject, **kwargs): + kwargs = {k: v for k, v in kwargs.items() if k in inspect.signature(self._serializer.from_nwb).parameters} nwb_session = self._serializer.from_nwb(nwbfile, **kwargs) assert sessions_are_equal(loaded_session, nwb_session, reraise=True) diff --git a/allensdk/brain_observatory/nwb/schemas.py b/allensdk/brain_observatory/nwb/schemas.py index 110be05cd1..be3ba39793 100644 --- a/allensdk/brain_observatory/nwb/schemas.py +++ b/allensdk/brain_observatory/nwb/schemas.py @@ -5,4 +5,4 @@ class RunningSpeedPathsSchema(RaisingSchema): running_speed_path = String(required=True, validate=check_read_access) - running_speed_timestamps_path = String(required=True, validate=check_read_access) \ No newline at end of file + running_speed_timestamps_path = String(required=True, validate=check_read_access) diff --git a/allensdk/brain_observatory/observatory_plots.py b/allensdk/brain_observatory/observatory_plots.py index 90e21d0345..8ce1569628 100644 --- a/allensdk/brain_observatory/observatory_plots.py +++ b/allensdk/brain_observatory/observatory_plots.py @@ -48,22 +48,27 @@ import numpy as np -SI_RANGE = [ 0, 1.5 ] +SI_RANGE = [0, 1.5] P_VALUE_MAX = 0.05 PEAK_DFF_MIN = 3 N_HIST_BINS = 50 STIM_COLOR = "#ccccdd" -STIMULUS_COLOR_MAP = LinearSegmentedColormap.from_list('default',[ [1.0,1.0,1.0,0.0], [.6,.6,.85,1.0] ]) +STIMULUS_COLOR_MAP = LinearSegmentedColormap.from_list("default", [[1.0, 1.0, 1.0, 0.0], [0.6, 0.6, 0.85, 1.0]]) PUPIL_COLOR_MAP = LinearSegmentedColormap.from_list( - 'custom_plasma', [[0.050383, 0.029803, 0.527975], - [0.417642, 0.000564, 0.658390], - [0.692840, 0.165141, 0.564522], - [0.881443, 0.392529, 0.383229], - [0.988260, 0.652325, 0.211364], - [0.940015, 0.975158, 0.131326]]) + "custom_plasma", + [ + [0.050383, 0.029803, 0.527975], + [0.417642, 0.000564, 0.658390], + [0.692840, 0.165141, 0.564522], + [0.881443, 0.392529, 0.383229], + [0.988260, 0.652325, 0.211364], + [0.940015, 0.975158, 0.131326], + ], +) EVOKED_COLOR = "#b30000" SPONTANEOUS_COLOR = "#0000b3" + def plot_cell_correlation(sig_corrs, labels, colors, scale=15): if len(sig_corrs) > 1: alpha = 1.0 / (len(sig_corrs) + 1) @@ -72,62 +77,81 @@ def plot_cell_correlation(sig_corrs, labels, colors, scale=15): ax = plt.gca() for sig_corr, color, label in zip(sig_corrs, colors, labels): - ax.hist(sig_corr, bins=30, range=[-1,1], - histtype='stepfilled', - facecolor=(.6,.6,.6,alpha), - edgecolor=color, - linewidth=1.5, - label=label) - + ax.hist( + sig_corr, + bins=30, + range=[-1, 1], + histtype="stepfilled", + facecolor=(0.6, 0.6, 0.6, alpha), + edgecolor=color, + linewidth=1.5, + label=label, + ) + ax.set_xlabel("signal correlation") ax.set_ylabel("cell count") ax.xaxis.grid(True) - - leg = ax.legend(loc='upper left', frameon=False) + + leg = ax.legend(loc="upper left", frameon=False) for i, t in enumerate(leg.get_texts()): t.set_color(colors[i]) - - plt.text(.125, .5, u'\u2014', transform=ax.transAxes, - horizontalalignment='center', verticalalignment='center', - weight='bold', size='xx-large') - plt.text(.875, .5, '+', transform=ax.transAxes, - horizontalalignment='center', verticalalignment='center', - weight='bold', size='xx-large') + + plt.text( + 0.125, + 0.5, + "\u2014", + transform=ax.transAxes, + horizontalalignment="center", + verticalalignment="center", + weight="bold", + size="xx-large", + ) + plt.text( + 0.875, + 0.5, + "+", + transform=ax.transAxes, + horizontalalignment="center", + verticalalignment="center", + weight="bold", + size="xx-large", + ) + def population_correlation_scatter(sig_corrs, noise_corrs, labels, colors, scale=15): - alpha = max(0.85 - 0.15 * (len(sig_corrs)-1), 0.2) + alpha = max(0.85 - 0.15 * (len(sig_corrs) - 1), 0.2) ax = plt.gca() for sig_corr, noise_corr, color, label in zip(sig_corrs, noise_corrs, colors, labels): inds = np.tril_indices(len(sig_corr)) - ax.scatter(sig_corr[inds], noise_corr[inds], - s=scale, - color=color, - linewidth=0.5, edgecolor='#333333', - label=label, - alpha=alpha) + ax.scatter( + sig_corr[inds], + noise_corr[inds], + s=scale, + color=color, + linewidth=0.5, + edgecolor="#333333", + label=label, + alpha=alpha, + ) ax.set_xlabel("signal correlation") ax.set_ylabel("noise correlation") - ax.set_xlim([-1,1]) - ax.set_ylim([-1,1]) - leg = ax.legend(loc='upper left', frameon=False) + ax.set_xlim([-1, 1]) + ax.set_ylim([-1, 1]) + leg = ax.legend(loc="upper left", frameon=False) for i, t in enumerate(leg.get_texts()): t.set_color(colors[i]) -def plot_mask_outline(mask, ax, color='k'): - pim = np.pad(mask, 1, 'constant', constant_values=(0,0)) +def plot_mask_outline(mask, ax, color="k"): + pim = np.pad(mask, 1, "constant", constant_values=(0, 0)) hedges = np.argwhere(np.diff(pim, axis=0)) vedges = np.argwhere(np.diff(pim, axis=1)) - hlines = [ [ [r-.5, c-1.5], [r-.5, c-.5] ] for r,c in hedges ] - vlines = [ [ [r-1.5, c-.5], [r-.5, c-.5] ] for r,c in vedges ] - - for p1,p2 in hlines + vlines: - ax.add_line(mlines.Line2D([ p1[1], p2[1] ], - [ p1[0], p2[0] ], - linewidth=3, - color=color, - clip_on=False)) - + hlines = [[[r - 0.5, c - 1.5], [r - 0.5, c - 0.5]] for r, c in hedges] + vlines = [[[r - 1.5, c - 0.5], [r - 0.5, c - 0.5]] for r, c in vedges] + + for p1, p2 in hlines + vlines: + ax.add_line(mlines.Line2D([p1[1], p2[1]], [p1[0], p2[0]], linewidth=3, color=color, clip_on=False)) + class DimensionPatchHandler(object): def __init__(self, vals, start_color, end_color, *args, **kwargs): @@ -145,9 +169,7 @@ def legend_artist(self, legend, orig_handle, fontsize, handlebox): x = x0 for i in range(len(self.vals)): rgb = self.dim_color(i) - r = mpatches.Rectangle((x+i*sub_width, y0), - sub_width, y0+height, - facecolor=rgb, linewidth=0) + r = mpatches.Rectangle((x + i * sub_width, y0), sub_width, y0 + height, facecolor=rgb, linewidth=0) r.set_clip_on(False) handlebox.add_artist(r) @@ -156,10 +178,11 @@ def legend_artist(self, legend, orig_handle, fontsize, handlebox): def dim_color(self, index): rgb1 = np.array(mcolors.colorConverter.to_rgb(self.start_color)) rgb2 = np.array(mcolors.colorConverter.to_rgb(self.end_color)) - t = float(index) / (len(self.vals)+1) + t = float(index) / (len(self.vals) + 1) rgb = t * rgb2 + (1.0 - t) * rgb1 return rgb + def float_label(n): if isinstance(n, int): return str(n) @@ -168,40 +191,38 @@ def float_label(n): else: return "%.2f" % n + def plot_representational_similarity(rs, dims=None, dim_labels=None, colors=None, dim_order=None, labels=True): if np.all(np.isnan(rs)): - return # if rs is all NaN (happens with only 1 cell), there is nothing to plot + return # if rs is all NaN (happens with only 1 cell), there is nothing to plot if dim_order is not None: - rsr = np.arange(len(rs)).reshape(*map(len,dims)) + rsr = np.arange(len(rs)).reshape(*map(len, dims)) rsrt = rsr.transpose(dim_order) ri = rsrt.flatten() - rs = rs[ri,:][:,ri] + rs = rs[ri, :][:, ri] dims = np.array(dims)[dim_order] colors = np.array(colors)[dim_order] dim_labels = np.array(dim_labels)[dim_order] - + # force the color map to be centered at zero - clim = np.nanpercentile(rs, [5.0,95.0], axis=None) + clim = np.nanpercentile(rs, [5.0, 95.0], axis=None) vrange = max(abs(clim[0]), abs(clim[1])) rs = rs.copy() np.fill_diagonal(rs, np.nan) - + if labels: - grid = ImageGrid(plt.gcf(), 111, - nrows_ncols=(1,1), - cbar_location="right", - cbar_mode="single", - cbar_size="7%", - cbar_pad=0.05) - + grid = ImageGrid( + plt.gcf(), 111, nrows_ncols=(1, 1), cbar_location="right", cbar_mode="single", cbar_size="7%", cbar_pad=0.05 + ) + for ax in grid: pass else: ax = plt.gca() - im = ax.imshow(rs, interpolation='nearest', cmap='RdBu_r', vmin=-vrange, vmax=vrange) + im = ax.imshow(rs, interpolation="nearest", cmap="RdBu_r", vmin=-vrange, vmax=vrange) ax.set_xticklabels([]) ax.set_yticklabels([]) ax.set_xticks([]) @@ -209,103 +230,88 @@ def plot_representational_similarity(rs, dims=None, dim_labels=None, colors=None if labels: cbar = ax.cax.colorbar(im) - cbar.set_label_text('stimulus correlation') - + cbar.set_label_text("stimulus correlation") + if dims is not None: - dim_labels = ["%s(%s)" % (dim_labels[i],', '.join(map(float_label, dims[i].tolist()))) for i in range(len(dims)) ] - dim_handlers = [ DimensionPatchHandler(dims[i], colors[i], 'w') for i in range(len(dims)) ] + dim_labels = [ + "%s(%s)" % (dim_labels[i], ", ".join(map(float_label, dims[i].tolist()))) for i in range(len(dims)) + ] + dim_handlers = [DimensionPatchHandler(dims[i], colors[i], "w") for i in range(len(dims))] n = len(rs) for cell_i in range(n): idx = np.unravel_index(cell_i, map(len, dims)) - start = -(len(dims))*2 + start = -(len(dims)) * 2 width = 1.8 for dim_i, color in enumerate(colors): v_i = idx[dim_i] rgb = dim_handlers[dim_i].dim_color(v_i) - r = mpatches.Rectangle((start + dim_i * width, cell_i-.5), - width, 1.2, - facecolor=rgb, linewidth=0) + r = mpatches.Rectangle((start + dim_i * width, cell_i - 0.5), width, 1.2, facecolor=rgb, linewidth=0) r.set_clip_on(False) ax.add_patch(r) - r = mpatches.Rectangle((cell_i-.5, start + dim_i * width), - 1.2, width, - facecolor=rgb, linewidth=0) + r = mpatches.Rectangle((cell_i - 0.5, start + dim_i * width), 1.2, width, facecolor=rgb, linewidth=0) r.set_clip_on(False) ax.add_patch(r) if labels: - patches = [ mpatches.Patch(label=dim_labels[i]) for i in range(len(dims)) ] - ax.legend(handles=patches, - handler_map=dict(zip(patches,dim_handlers)), - loc='upper left', - bbox_to_anchor=(0,0), - ncol=2, - fontsize=9, - frameon=False) + patches = [mpatches.Patch(label=dim_labels[i]) for i in range(len(dims))] + ax.legend( + handles=patches, + handler_map=dict(zip(patches, dim_handlers)), + loc="upper left", + bbox_to_anchor=(0, 0), + ncol=2, + fontsize=9, + frameon=False, + ) if labels: - plt.subplots_adjust(left=0.07, - right=.88, - wspace=0.0, hspace=0.0) + plt.subplots_adjust(left=0.07, right=0.88, wspace=0.0, hspace=0.0) + def plot_condition_histogram(vals, bins, color=STIM_COLOR): plt.grid() if len(vals) > 1: vals = [np.array(vals).flatten()] # matplotlib >= 2.1 needs this if len(vals) > 0: - n, hbins, patches = plt.hist(vals, - bins=np.arange(len(bins)+1)+1, - align='left', - density=False, - rwidth=.8, - color=color, - zorder=3) + n, hbins, patches = plt.hist( + vals, bins=np.arange(len(bins) + 1) + 1, align="left", density=False, rwidth=0.8, color=color, zorder=3 + ) else: - hbins = np.arange(len(bins)+1)+1 + hbins = np.arange(len(bins) + 1) + 1 plt.xticks(hbins[:-1], np.round(bins, 2)) - -def plot_selectivity_cumulative_histogram(sis, - xlabel, - si_range=SI_RANGE, - n_hist_bins=N_HIST_BINS, - color=STIM_COLOR): + +def plot_selectivity_cumulative_histogram(sis, xlabel, si_range=SI_RANGE, n_hist_bins=N_HIST_BINS, color=STIM_COLOR): if len(sis) > 1: sis = [np.array(sis).flatten()] # matplotlib >= 2.1 needs this bins = np.linspace(si_range[0], si_range[1], n_hist_bins) - yticks = np.linspace(0,1,5) + yticks = np.linspace(0, 1, 5) xticks = np.linspace(si_range[0], si_range[1], 4) - + yscale = 1.0 # this is for normalizing to total # cells, not just significant cells # yscale = float(num_cells) / len(osis) # orientation selectivity cumulative histogram if len(sis) > 0: - n, bins, patches = plt.hist(sis, density=True, bins=bins, - cumulative=True, histtype='stepfilled', - color=color) + n, bins, patches = plt.hist(sis, density=True, bins=bins, cumulative=True, histtype="stepfilled", color=color) plt.xlim(si_range) - plt.ylim([0,yscale]) - plt.yticks(yticks*yscale, yticks) + plt.ylim([0, yscale]) + plt.yticks(yticks * yscale, yticks) plt.xticks(xticks) - + plt.xlabel(xlabel) plt.ylabel("fraction of cells") plt.grid() -def plot_radial_histogram(angles, - counts, - all_angles=None, - include_labels=False, - offset=180.0, - direction=-1, - closed=False, - color=STIM_COLOR): + +def plot_radial_histogram( + angles, counts, all_angles=None, include_labels=False, offset=180.0, direction=-1, closed=False, color=STIM_COLOR +): if all_angles is None: if len(angles) < 2: all_angles = np.linspace(0, 315, 8) @@ -321,27 +327,25 @@ def plot_radial_histogram(angles, wedges = [] for count, angle in zip(counts, angles): - angle = angle*direction + offset - wedge = mpatches.Wedge((0,0), count, angle-dth, angle+dth) + angle = angle * direction + offset + wedge = mpatches.Wedge((0, 0), count, angle - dth, angle + dth) wedges.append(wedge) wedge_coll = PatchCollection(wedges) wedge_coll.set_facecolor(color) wedge_coll.set_zorder(2) - angles_rad = (all_angles*direction + offset)*np.pi/180.0 + angles_rad = (all_angles * direction + offset) * np.pi / 180.0 if closed: border_coll = cplots.radial_circles([max_count]) else: - border_coll = cplots.radial_arcs([max_count], - min(angles_rad), - max(angles_rad)) - border_coll.set_facecolor((0,0,0,0)) + border_coll = cplots.radial_arcs([max_count], min(angles_rad), max(angles_rad)) + border_coll.set_facecolor((0, 0, 0, 0)) border_coll.set_zorder(1) line_coll = cplots.angle_lines(angles_rad, 0, max_count) - line_coll.set_edgecolor((0,0,0,1)) + line_coll.set_edgecolor((0, 0, 0, 1)) line_coll.set_linestyle(":") line_coll.set_zorder(1) @@ -351,58 +355,59 @@ def plot_radial_histogram(angles, ax.add_collection(line_coll) if include_labels: - cplots.add_angle_labels(ax, angles_rad, all_angles.astype(int), max_count, (0,0,0,1), offset=max_count*0.1) - ax.set(xlim=(-max_count*1.2, max_count*1.2), - ylim=(-max_count*1.2, max_count*1.2), - aspect=1.0) + cplots.add_angle_labels(ax, angles_rad, all_angles.astype(int), max_count, (0, 0, 0, 1), offset=max_count * 0.1) + ax.set(xlim=(-max_count * 1.2, max_count * 1.2), ylim=(-max_count * 1.2, max_count * 1.2), aspect=1.0) else: - ax.set(xlim=(-max_count*1.05, max_count*1.05), - ylim=(-max_count*1.05, max_count*1.05), - aspect=1.0) - + ax.set(xlim=(-max_count * 1.05, max_count * 1.05), ylim=(-max_count * 1.05, max_count * 1.05), aspect=1.0) + + def plot_time_to_peak(msrs, ttps, t_start, t_end, stim_start, stim_end, cmap): - plt.plot(ttps, np.arange(msrs.shape[0],0,-1)-0.5, color='black') + plt.plot(ttps, np.arange(msrs.shape[0], 0, -1) - 0.5, color="black") if msrs.shape[0] > 0: - plt.imshow(msrs, - cmap=cmap, clim=[0,3], - aspect=float((t_end-t_start) / msrs.shape[0]), # float to get rid of MPL error - extent=[t_start, t_end, 0, msrs.shape[0]], interpolation='nearest') - plt.ylim([0,msrs.shape[0]]) + plt.imshow( + msrs, + cmap=cmap, + clim=[0, 3], + aspect=float((t_end - t_start) / msrs.shape[0]), # float to get rid of MPL error + extent=[t_start, t_end, 0, msrs.shape[0]], + interpolation="nearest", + ) + plt.ylim([0, msrs.shape[0]]) else: plt.ylim([0, 1]) plt.xlim([t_start, t_end]) - plt.axvline(stim_start, linestyle=':', color='black') - plt.axvline(stim_end, linestyle=':', color='black') + plt.axvline(stim_start, linestyle=":", color="black") + plt.axvline(stim_end, linestyle=":", color="black") - xticks = np.array([ t_start, stim_start, stim_end, t_end ]) + xticks = np.array([t_start, stim_start, stim_end, t_end]) plt.xticks(xticks, np.round(xticks - stim_start, 2)) plt.xlabel("time from stimulus start (s)") yticks, _ = plt.yticks() plt.ylabel("cell number") + @contextmanager def figure_in_px(w, h, file_name, dpi=96.0, transparent=False): - fig = plt.figure(figsize=(w/dpi, h/dpi), dpi=dpi) + fig = plt.figure(figsize=(w / dpi, h / dpi), dpi=dpi) yield fig - + plt.savefig(file_name, dpi=dpi, transparent=transparent) plt.close() + def finalize_no_axes(pad=0.0): - plt.axis('off') - plt.subplots_adjust(left=pad, - right=1.0-pad, - bottom=pad, - top=1.0-pad, - wspace=0.0, hspace=0.0) - -def finalize_with_axes(pad=.3): + plt.axis("off") + plt.subplots_adjust(left=pad, right=1.0 - pad, bottom=pad, top=1.0 - pad, wspace=0.0, hspace=0.0) + + +def finalize_with_axes(pad=0.3): plt.tight_layout(pad=pad) -def finalize_no_labels(pad=.3, legend=False): + +def finalize_no_labels(pad=0.3, legend=False): ax = plt.gca() ax.set_xlabel("") ax.set_ylabel("") @@ -412,96 +417,95 @@ def finalize_no_labels(pad=.3, legend=False): ax.legend_.remove() plt.tight_layout(pad=pad) -def plot_combined_speed(binned_resp_vis, binned_dx_vis, binned_resp_sp, binned_dx_sp, - evoked_color, spont_color): + +def plot_combined_speed(binned_resp_vis, binned_dx_vis, binned_resp_sp, binned_dx_sp, evoked_color, spont_color): ax = plt.gca() num_bins = max(binned_dx_vis.shape[0], binned_dx_sp.shape[0]) - + plot_speed(binned_resp_vis, binned_dx_vis, num_bins, evoked_color) plot_speed(binned_resp_sp, binned_dx_sp, num_bins, spont_color) - - xmin = min(binned_dx_vis[:,0].min(), binned_dx_sp[:,0].min()) - xmax = max(binned_dx_vis[:,0].max(), binned_dx_sp[:,0].max()) + xmin = min(binned_dx_vis[:, 0].min(), binned_dx_sp[:, 0].min()) + xmax = max(binned_dx_vis[:, 0].max(), binned_dx_sp[:, 0].max()) - ymin = min(binned_resp_vis[:,0].min(), binned_resp_sp[:,0].min()) - ymax = max(binned_resp_vis[:,0].max(), binned_resp_sp[:,0].max()) + ymin = min(binned_resp_vis[:, 0].min(), binned_resp_sp[:, 0].min()) + ymax = max(binned_resp_vis[:, 0].max(), binned_resp_sp[:, 0].max()) - xpadding = (xmax-xmin)*.05 - ypadding = (ymax-ymin)*.20 + xpadding = (xmax - xmin) * 0.05 + ypadding = (ymax - ymin) * 0.20 ax.set_xlim([xmin - xpadding, xmax + xpadding]) ax.set_ylim([ymin - ypadding, ymax + ypadding]) -def plot_speed(binned_resp, binned_dx, num_bins, color): +def plot_speed(binned_resp, binned_dx, num_bins, color): ax = plt.gca() # plot the zero bin as a dot with whiskers - ax.errorbar([ binned_dx[0,0] ], [ binned_resp[0,0] ], yerr=[ binned_resp[0,1] ], fmt='o', color=color) + ax.errorbar([binned_dx[0, 0]], [binned_resp[0, 0]], yerr=[binned_resp[0, 1]], fmt="o", color=color) # if there's only one bin, drop out - if len(binned_dx[:,0]) <= 1: + if len(binned_dx[:, 0]) <= 1: return - f = si.interp1d(binned_dx[:,0], binned_resp[:,0]) - x = np.linspace(min(binned_dx[:,0]), max(binned_dx[:,0]), num=num_bins, endpoint=True) + f = si.interp1d(binned_dx[:, 0], binned_resp[:, 0]) + x = np.linspace(min(binned_dx[:, 0]), max(binned_dx[:, 0]), num=num_bins, endpoint=True) y = f(x) - - f_up = si.interp1d(binned_dx[:,0], binned_resp[:,0] + binned_resp[:,1]) + + f_up = si.interp1d(binned_dx[:, 0], binned_resp[:, 0] + binned_resp[:, 1]) y_up = f_up(x) - - f_down = si.interp1d(binned_dx[:,0], binned_resp[:,0] - binned_resp[:,1]) + + f_down = si.interp1d(binned_dx[:, 0], binned_resp[:, 0] - binned_resp[:, 1]) y_down = f_down(x) - - ax.plot(x, y, color=color) + + ax.plot(x, y, color=color) ax.fill_between(x, y_down, y_up, facecolor=color, alpha=0.1) -def plot_receptive_field(rf, color_map=None, clim=None, - mask=None, outline_color='#cccccc', - scalebar=True): +def plot_receptive_field(rf, color_map=None, clim=None, mask=None, outline_color="#cccccc", scalebar=True): if mask is not None: rf = np.ma.array(rf, mask=~mask) if clim is None: - clim = np.nanpercentile(rf, [1.0,99.0], axis=None) + clim = np.nanpercentile(rf, [1.0, 99.0], axis=None) - plt.imshow(rf, interpolation='nearest', - cmap=color_map, - clim=clim, - origin='bottom') + plt.imshow(rf, interpolation="nearest", cmap=color_map, clim=clim, origin="bottom") if mask is not None: plot_mask_outline(mask, plt.gca(), outline_color) if scalebar: - scale_dims = np.array([ 28.0, 16.0 ]) - scale_p = [ 26.8, 14.8 ] - text_p = [ scale_p[0]+0.5, scale_p[1]-0.5 ] - - + scale_dims = np.array([28.0, 16.0]) + scale_p = [26.8, 14.8] + text_p = [scale_p[0] + 0.5, scale_p[1] - 0.5] + ax = plt.gca() - ax.add_patch(mpatches.Rectangle(scale_p / scale_dims, - 1.0/scale_dims[0], 1.0/scale_dims[1], - facecolor='w', - transform=ax.transAxes, - linewidth=1.0, - edgecolor=outline_color)) - plt.text(text_p[0] / scale_dims[0], text_p[1] / scale_dims[1], "4deg", - horizontalalignment='center', - verticalalignment='center', - transform=ax.transAxes) - - - -def plot_pupil_location(xy_deg, s=1, c=None, cmap=PUPIL_COLOR_MAP, - edgecolor='', include_labels=True): + ax.add_patch( + mpatches.Rectangle( + scale_p / scale_dims, + 1.0 / scale_dims[0], + 1.0 / scale_dims[1], + facecolor="w", + transform=ax.transAxes, + linewidth=1.0, + edgecolor=outline_color, + ) + ) + plt.text( + text_p[0] / scale_dims[0], + text_p[1] / scale_dims[1], + "4deg", + horizontalalignment="center", + verticalalignment="center", + transform=ax.transAxes, + ) + + +def plot_pupil_location(xy_deg, s=1, c=None, cmap=PUPIL_COLOR_MAP, edgecolor="", include_labels=True): if c is None: xy_deg = xy_deg[~np.isnan(xy_deg).any(axis=1)] c = gaussian_kde(xy_deg.T)(xy_deg.T) - plt.scatter(xy_deg[:,0], xy_deg[:,1], s=s, c=c, cmap=cmap, - edgecolor=edgecolor) + plt.scatter(xy_deg[:, 0], xy_deg[:, 1], s=s, c=c, cmap=cmap, edgecolor=edgecolor) plt.xlim(-70, 70) plt.ylim(-70, 70) diff --git a/allensdk/brain_observatory/ophys/project_constants.py b/allensdk/brain_observatory/ophys/project_constants.py index e9448feeee..45168295de 100644 --- a/allensdk/brain_observatory/ophys/project_constants.py +++ b/allensdk/brain_observatory/ophys/project_constants.py @@ -1,5 +1,4 @@ -"""Collection of specific project metadata not readily available on LIMS. -""" +"""Collection of specific project metadata not readily available on LIMS.""" ##################################################################### diff --git a/allensdk/brain_observatory/ophys/trace_extraction/__init__.py b/allensdk/brain_observatory/ophys/trace_extraction/__init__.py index 1a36940705..bc0e4b8c2d 100644 --- a/allensdk/brain_observatory/ophys/trace_extraction/__init__.py +++ b/allensdk/brain_observatory/ophys/trace_extraction/__init__.py @@ -1,8 +1,10 @@ import warnings -warnings.warn("trace_extraction functionality has been moved from AllenSDK " - "to https://github.com/AllenInstitute/ophys_etl_pipelines ." - "The functionality in this AllenSDK package will be removed " - "in v3.0.0.", - category=DeprecationWarning, - stacklevel=2) +warnings.warn( + "trace_extraction functionality has been moved from AllenSDK " + "to https://github.com/AllenInstitute/ophys_etl_pipelines ." + "The functionality in this AllenSDK package will be removed " + "in v3.0.0.", + category=DeprecationWarning, + stacklevel=2, +) diff --git a/allensdk/brain_observatory/ophys/trace_extraction/__main__.py b/allensdk/brain_observatory/ophys/trace_extraction/__main__.py index c0da66dcbc..5b272a5c43 100644 --- a/allensdk/brain_observatory/ophys/trace_extraction/__main__.py +++ b/allensdk/brain_observatory/ophys/trace_extraction/__main__.py @@ -21,13 +21,12 @@ def create_roi_masks(rois, w, h, motion_border): for roi in rois: mask = np.array(roi["mask"], dtype=bool) px = np.argwhere(mask) - px[:,0] += roi["y"] - px[:,1] += roi["x"] + px[:, 0] += roi["y"] + px[:, 1] += roi["x"] - mask = roi_masks.create_roi_mask(w, h, motion_border, - pix_list=px[:,[1,0]], - label=str(roi["id"]), - mask_group=roi.get("mask_page",-1)) + mask = roi_masks.create_roi_mask( + w, h, motion_border, pix_list=px[:, [1, 0]], label=str(roi["id"]), mask_group=roi.get("mask_page", -1) + ) roi_list.append(mask) @@ -37,11 +36,8 @@ def create_roi_masks(rois, w, h, motion_border): return roi_list -def get_inputs_from_lims( - host, ophys_experiment_id, output_root, - job_queue, strategy -): - ''' This is a development / testing utility for running this module from the Allen Institute for Brain Science's +def get_inputs_from_lims(host, ophys_experiment_id, output_root, job_queue, strategy): + """This is a development / testing utility for running this module from the Allen Institute for Brain Science's Laboratory Information Management System (LIMS). It will only work if you are on our internal network. Parameters @@ -60,16 +56,16 @@ def get_inputs_from_lims( data : dict Response from LIMS. Should meet the schema defined in _schemas.py - ''' - - uri = f'{host}/input_jsons?object_id={ophys_experiment_id}&object_class=OphysExperiment&strategy_class={strategy}&job_queue_name={job_queue}&output_directory={output_root}' + """ + + uri = f"{host}/input_jsons?object_id={ophys_experiment_id}&object_class=OphysExperiment&strategy_class={strategy}&job_queue_name={job_queue}&output_directory={output_root}" response = requests.get(uri) data = response.json() - if len(data) == 1 and 'error' in data: - raise ValueError('bad request uri: {} ({})'.format(uri, data['error'])) + if len(data) == 1 and "error" in data: + raise ValueError("bad request uri: {} ({})".format(uri, data["error"])) - return data + return data def write_trace_file(data, names, path): @@ -82,13 +78,12 @@ def write_trace_file(data, names, path): else: raise TypeError("unable to create a variable length h5 string dtype in python version: {}", sys.version_info) - with h5py.File(path, 'w') as fil: + with h5py.File(path, "w") as fil: fil["data"] = data fil.create_dataset("roi_names", data=np.array(names).astype(np.bytes_), dtype=utf_dtype) def extract_traces(motion_corrected_stack, motion_border, storage_directory, rois, log_0, **kwargs): - # find width and height of movie with h5py.File(motion_corrected_stack, "r") as f: d = f["data"] @@ -96,16 +91,11 @@ def extract_traces(motion_corrected_stack, motion_border, storage_directory, roi w = d.shape[2] # motion border - border = [ - motion_border["x0"], - motion_border["x1"], - motion_border["y0"], - motion_border["y1"] - ] + border = [motion_border["x0"], motion_border["x1"], motion_border["y0"], motion_border["y1"]] # create roi mask objects roi_mask_list = create_roi_masks(rois, w, h, border) - roi_names = [ roi.label for roi in roi_mask_list ] + roi_names = [roi.label for roi in roi_mask_list] # extract traces roi_traces, neuropil_traces, exclusions = roi_masks.calculate_roi_and_neuropil_traces( @@ -117,32 +107,27 @@ def extract_traces(motion_corrected_stack, motion_border, storage_directory, roi np_file = os.path.abspath(os.path.join(storage_directory, "neuropil_traces.h5")) write_trace_file(neuropil_traces, roi_names, np_file) - - return { - 'neuropil_trace_file': np_file, - 'roi_trace_file': roi_file, - 'exclusion_labels': exclusions - } + + return {"neuropil_trace_file": np_file, "roi_trace_file": roi_file, "exclusion_labels": exclusions} def main(): - logging.basicConfig(format='%(asctime)s - %(process)s - %(levelname)s - %(message)s') + logging.basicConfig(format="%(asctime)s - %(process)s - %(levelname)s - %(message)s") remaining_args = sys.argv[1:] input_data = {} - if '--get_inputs_from_lims' in sys.argv: + if "--get_inputs_from_lims" in sys.argv: lims_parser = argparse.ArgumentParser(add_help=False) - lims_parser.add_argument('--host', type=str, default='http://lims2') - lims_parser.add_argument('--job_queue', type=str, default='OPHYS_EXTRACT_TRACES_QUEUE') - lims_parser.add_argument('--strategy', type=str,default='ExtractTracesStrategy') - lims_parser.add_argument('--ophys_experiment_id', type=int, default=None) - lims_parser.add_argument('--output_root', type=str, default= None) + lims_parser.add_argument("--host", type=str, default="http://lims2") + lims_parser.add_argument("--job_queue", type=str, default="OPHYS_EXTRACT_TRACES_QUEUE") + lims_parser.add_argument("--strategy", type=str, default="ExtractTracesStrategy") + lims_parser.add_argument("--ophys_experiment_id", type=int, default=None) + lims_parser.add_argument("--output_root", type=str, default=None) lims_args, remaining_args = lims_parser.parse_known_args(remaining_args) - remaining_args = [item for item in remaining_args if item != '--get_inputs_from_lims'] + remaining_args = [item for item in remaining_args if item != "--get_inputs_from_lims"] input_data = get_inputs_from_lims(**lims_args.__dict__) - try: parser = argschema.ArgSchemaParser( args=remaining_args, @@ -158,5 +143,5 @@ def main(): write_or_print_outputs(output, parser) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/allensdk/brain_observatory/ophys/trace_extraction/_schemas.py b/allensdk/brain_observatory/ophys/trace_extraction/_schemas.py index 89c6528b7d..52d54d8ba8 100644 --- a/allensdk/brain_observatory/ophys/trace_extraction/_schemas.py +++ b/allensdk/brain_observatory/ophys/trace_extraction/_schemas.py @@ -1,35 +1,27 @@ from argschema import ArgSchema -from argschema.fields import LogLevel, String, Nested, Boolean, Float, List, \ - Integer +from argschema.fields import LogLevel, String, Nested, Boolean, Float, List, Integer from marshmallow import RAISE from allensdk.brain_observatory.argschema_utilities import RaisingSchema class MotionBorder(RaisingSchema): - x0 = Float(default=0.0, - description='') # TODO: be really certain about how these + x0 = Float(default=0.0, description="") # TODO: be really certain about how these # relate to physical space and then write it here - x1 = Float(default=0.0, description='') - y0 = Float(default=0.0, description='') - y1 = Float(default=0.0, description='') + x1 = Float(default=0.0, description="") + y0 = Float(default=0.0, description="") + y1 = Float(default=0.0, description="") class Roi(RaisingSchema): - mask = List(List(Boolean), required=True, description='raster mask') - y = Integer(required=True, - description='y position (pixels) of mask\'s bounding box') - x = Integer(required=True, - description='x position (pixels) of mask\'s bounding box') - width = Integer(required=True, - description='width (pixels)of mask\'s bounding box') - height = Integer(required=True, - description='height (pixels) of mask\'s bounding box') - valid = Boolean(default=True, description='Is this Roi known to be valid?') - id = Integer(required=True, - description='unique integer identifier for this Roi') - mask_page = Integer(default=-1, - description='') # TODO: this isn't in the examples + mask = List(List(Boolean), required=True, description="raster mask") + y = Integer(required=True, description="y position (pixels) of mask's bounding box") + x = Integer(required=True, description="x position (pixels) of mask's bounding box") + width = Integer(required=True, description="width (pixels)of mask's bounding box") + height = Integer(required=True, description="height (pixels) of mask's bounding box") + valid = Boolean(default=True, description="Is this Roi known to be valid?") + id = Integer(required=True, description="unique integer identifier for this Roi") + mask_page = Integer(default=-1, description="") # TODO: this isn't in the examples # I'm looking at. What is it? @@ -42,33 +34,24 @@ class InputSchema(ArgSchema): class Meta: unknown = RAISE - log_level = LogLevel(default='INFO', - description='set the logging level of the module') - motion_border = Nested(MotionBorder, required=True, - description='border widths - pixels outside the ' - 'border are considered invalid') - storage_directory = String(required=True, - description='used to set output directory') - motion_corrected_stack = String(required=True, - description='path to h5 file containing ' - 'motion corrected image stack') - rois = Nested(Roi, many=True, - description='specifications of individual regions of ' - 'interest') - log_0 = String(required=True, - description='path to motion correction output csv') # + log_level = LogLevel(default="INFO", description="set the logging level of the module") + motion_border = Nested( + MotionBorder, required=True, description="border widths - pixels outside the border are considered invalid" + ) + storage_directory = String(required=True, description="used to set output directory") + motion_corrected_stack = String( + required=True, description="path to h5 file containing motion corrected image stack" + ) + rois = Nested(Roi, many=True, description="specifications of individual regions of interest") + log_0 = String(required=True, description="path to motion correction output csv") # # TODO: is this redundant with motion border? class OutputSchema(RaisingSchema): input_parameters = Nested(InputSchema) - neuropil_trace_file = String( - required=True, - description='path to output h5 file containing neuropil traces') # + neuropil_trace_file = String(required=True, description="path to output h5 file containing neuropil traces") # # TODO rename these to _path - roi_trace_file = String( - required=True, - description='path to output h5 file containing roi traces') + roi_trace_file = String(required=True, description="path to output h5 file containing roi traces") exclusion_labels = Nested( - ExclusionLabel, many=True, - description='a report of roi-wise problems detected during extraction') + ExclusionLabel, many=True, description="a report of roi-wise problems detected during extraction" + ) diff --git a/allensdk/brain_observatory/r_neuropil.py b/allensdk/brain_observatory/r_neuropil.py index 7619b31217..6a578ebfba 100644 --- a/allensdk/brain_observatory/r_neuropil.py +++ b/allensdk/brain_observatory/r_neuropil.py @@ -40,7 +40,7 @@ def get_diagonals_from_sparse(mat): - ''' Returns a dictionary of diagonals keyed by offsets + """Returns a dictionary of diagonals keyed by offsets Parameters ---------- @@ -49,7 +49,7 @@ def get_diagonals_from_sparse(mat): Returns ------- dictionary: diagonals keyed by offsets - ''' + """ mat_dia = mat.todia() # make sure the matrix is in diagonal format @@ -65,7 +65,7 @@ def get_diagonals_from_sparse(mat): def ab_from_diagonals(mat_dict): - ''' Constructs value for scipy.linalg.solve_banded + """Constructs value for scipy.linalg.solve_banded Parameters ---------- @@ -74,7 +74,7 @@ def ab_from_diagonals(mat_dict): Returns ------- ab: value for scipy.linalg.solve_banded - ''' + """ offsets = list(mat_dict.keys()) l = -np.min(offsets) u = np.max(offsets) @@ -91,28 +91,26 @@ def ab_from_diagonals(mat_dict): def error_calc(F_M, F_N, F_C, r): - er = np.sqrt(np.mean(np.square(F_C - (F_M - r * F_N)))) / np.mean(F_M) return er def error_calc_outlier(F_M, F_N, F_C, r): - std_F_M = np.std(F_M) mean_F_M = np.mean(F_M) - ind_outlier = np.where(F_M > mean_F_M + 2. * std_F_M) + ind_outlier = np.where(F_M > mean_F_M + 2.0 * std_F_M) - er = np.sqrt(np.mean(np.square( - F_C[ind_outlier] - (F_M[ind_outlier] - r * F_N[ind_outlier])))) / np.mean(F_M[ind_outlier]) + er = np.sqrt(np.mean(np.square(F_C[ind_outlier] - (F_M[ind_outlier] - r * F_N[ind_outlier])))) / np.mean( + F_M[ind_outlier] + ) return er def ab_from_T(T, lam, dt): # using csr because multiplication is fast - Ls = -sparse.eye(T - 1, T, format='csr') + \ - sparse.eye(T - 1, T, 1, format='csr') + Ls = -sparse.eye(T - 1, T, format="csr") + sparse.eye(T - 1, T, 1, format="csr") Ls /= dt Ls2 = Ls.T.dot(Ls) @@ -138,7 +136,7 @@ def alpha_filter(A=1.0, alpha=0.05, beta=0.25, T=100): def validate_with_synthetic_F(T, N): - """ Compute N synthetic traces of length T with known values of r, then estimate r. + """Compute N synthetic traces of length T with known values of r, then estimate r. TODO: docs """ af1 = alpha_filter() @@ -151,7 +149,7 @@ def validate_with_synthetic_F(T, N): F_M_truth, F_N_truth, F_C_truth, r_truth = synthesize_F(T, af1, af2) results = estimate_contamination_ratios(F_M_truth, F_N_truth) - r_est = results['r'] + r_est = results["r"] r_truth_vals.append(r_truth) r_est_vals.append(r_est) @@ -160,14 +158,14 @@ def validate_with_synthetic_F(T, N): def synthesize_F(T, af1, af2, p1=0.05, p2=0.1): - """ Build a synthetic F_C, F_M, F_N, and r of length T + """Build a synthetic F_C, F_M, F_N, and r of length T TODO: docs """ x1 = np.random.random(T) < p1 - F_C = np.convolve(af1, x1, mode='full')[:T] + F_C = np.convolve(af1, x1, mode="full")[:T] x2 = np.random.random(T) < p2 - F_N = np.convolve(af2, x2, mode='full')[:T] + F_N = np.convolve(af2, x2, mode="full")[:T] r = 2.0 * np.random.random() @@ -177,8 +175,7 @@ def synthesize_F(T, af1, af2, p1=0.05, p2=0.1): class NeuropilSubtract(object): - """ TODO: docs - """ + """TODO: docs""" def __init__(self, lam=0.05, dt=1.0, folds=4): self.lam = lam @@ -198,7 +195,7 @@ def __init__(self, lam=0.05, dt=1.0, folds=4): self.error = None def set_F(self, F_M, F_N): - """ Break the F_M and F_N traces into the number of folds specified + """Break the F_M and F_N traces into the number of folds specified in the class constructor and normalize each fold of F_M and R_N relative to F_N. """ @@ -206,8 +203,7 @@ def set_F(self, F_M, F_N): F_N_len = len(F_N) if F_M_len != F_N_len: - raise Exception( - "F_M and F_N must have the same length (%d vs %d)" % (F_M_len, F_N_len)) + raise Exception("F_M and F_N must have the same length (%d vs %d)" % (F_M_len, F_N_len)) if self.T != F_M_len: logging.debug("updating ab matrix for new T=%d", F_M_len) @@ -221,8 +217,8 @@ def set_F(self, F_M, F_N): for fi in range(self.folds): # F_M_i_s, F_N_i_s = normalize_F(F_M[fi*self.T_f:(fi+1)*self.T_f], # F_N[fi*self.T_f:(fi+1)*self.T_f]) - self.F_M.append(F_M[fi * self.T_f:(fi + 1) * self.T_f]) - self.F_N.append(F_N[fi * self.T_f:(fi + 1) * self.T_f]) + self.F_M.append(F_M[fi * self.T_f : (fi + 1) * self.T_f]) + self.F_N.append(F_N[fi * self.T_f : (fi + 1) * self.T_f]) def fit_block_coordinate_desc(self, r_init=5.0, min_delta_r=0.00000001): F_M = np.concatenate(self.F_M) @@ -238,7 +234,7 @@ def fit_block_coordinate_desc(self, r_init=5.0, min_delta_r=0.00000001): ab = ab_from_T(self.T, self.lam, self.dt) while delta_r is None or delta_r > min_delta_r: F_C = solve_banded((1, 1), ab, F_M - r * F_N) - new_r = - np.sum((F_C - F_M) * F_N) / np.sum(np.square(F_N)) + new_r = -np.sum((F_C - F_M) * F_N) / np.sum(np.square(F_N)) error = self.estimate_error(new_r) error_vals.append(error) @@ -256,7 +252,7 @@ def fit_block_coordinate_desc(self, r_init=5.0, min_delta_r=0.00000001): self.error = error_vals.min() def fit(self, r_range=[0.0, 2.0], iterations=3, dr=0.1, dr_factor=0.1): - """ Estimate error values for a range of r values. Identify a new r range + """Estimate error values for a range of r values. Identify a new r range around the minimum error values and repeat multiple times. TODO: docs """ @@ -292,19 +288,16 @@ def fit(self, r_range=[0.0, 2.0], iterations=3, dr=0.1, dr_factor=0.1): global_min_error = min_error global_min_r = rs[min_i] - logging.debug("iteration %d, r=%0.4f, e=%.6e", - it, global_min_r, global_min_error) + logging.debug("iteration %d, r=%0.4f, e=%.6e", it, global_min_r, global_min_error) # if the minimum error is on the upper boundary, # extend the boundary and redo this iteration if min_i == len(it_errors) - 1: - logging.debug( - "minimum error found on upper r bound, extending range") + logging.debug("minimum error found on upper r bound, extending range") it_range = [rs[-1], rs[-1] + (rs[-1] - rs[0])] else: # error is somewhere on either side of the minimum error index - it_range = [rs[max(min_i - 1, 0)], - rs[min(min_i + 1, len(rs) - 1)]] + it_range = [rs[max(min_i - 1, 0)], rs[min(min_i + 1, len(rs) - 1)]] it_dr *= dr_factor it += 1 @@ -314,7 +307,7 @@ def fit(self, r_range=[0.0, 2.0], iterations=3, dr=0.1, dr_factor=0.1): self.error = global_min_error def estimate_error(self, r): - """ Estimate error values for a given r for each fold and return the mean. """ + """Estimate error values for a given r for each fold and return the mean.""" errors = np.zeros(self.folds) for fi in range(self.folds): @@ -326,10 +319,8 @@ def estimate_error(self, r): return np.mean(errors) -def estimate_contamination_ratios(F_M, F_N, - lam=0.05, folds=4, iterations=3, - r_range=[0.0, 2.0], dr=0.1, dr_factor=0.1): - ''' Calculates neuropil contamination of ROI +def estimate_contamination_ratios(F_M, F_N, lam=0.05, folds=4, iterations=3, r_range=[0.0, 2.0], dr=0.1, dr_factor=0.1): + """Calculates neuropil contamination of ROI Parameters ---------- @@ -343,16 +334,13 @@ def estimate_contamination_ratios(F_M, F_N, * 'err': RMS error * 'min_error': minimum error * 'bounds_error': boolean. True if error or R are outside tolerance - ''' + """ ns = NeuropilSubtract(lam=lam, folds=folds) ns.set_F(F_M, F_N) - ns.fit(r_range=r_range, - iterations=iterations, - dr=dr, - dr_factor=dr_factor) + ns.fit(r_range=r_range, iterations=iterations, dr=dr, dr_factor=dr_factor) # ns.fit_block_coordinate_desc() @@ -366,5 +354,5 @@ def estimate_contamination_ratios(F_M, F_N, "err": ns.error, "err_vals": ns.error_vals, "min_error": ns.error, - "it": len(ns.r_vals) + "it": len(ns.r_vals), } diff --git a/allensdk/brain_observatory/receptive_field_analysis/__init__.py b/allensdk/brain_observatory/receptive_field_analysis/__init__.py index 92ceaf67c3..8e51ec55db 100644 --- a/allensdk/brain_observatory/receptive_field_analysis/__init__.py +++ b/allensdk/brain_observatory/receptive_field_analysis/__init__.py @@ -32,4 +32,4 @@ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. -# \ No newline at end of file +# diff --git a/allensdk/brain_observatory/receptive_field_analysis/chisquarerf.py b/allensdk/brain_observatory/receptive_field_analysis/chisquarerf.py index 867f67c8c2..330c69a1f8 100644 --- a/allensdk/brain_observatory/receptive_field_analysis/chisquarerf.py +++ b/allensdk/brain_observatory/receptive_field_analysis/chisquarerf.py @@ -83,20 +83,14 @@ def chi_square_binary(events, LSN_template): # smooth stimulus-triggered average spatially with a gaussian for n in range(num_cells): for on_off in range(2): - events_per_pixel[n, :, :, on_off] = smooth_STA( - events_per_pixel[n, :, :, on_off] - ) + events_per_pixel[n, :, :, on_off] = smooth_STA(events_per_pixel[n, :, :, on_off]) # calculate the p_value for each exclusion region chi_square_grid = np.zeros((num_cells, num_y, num_x)) for y in range(num_y): for x in range(num_x): - exclusion_mask = np.ones((num_y, num_x, 2)) * disc_masks[ - y, x, :, : - ].reshape(num_y, num_x, 1) - p_vals, __ = chi_square_within_mask( - exclusion_mask, events_per_pixel, trials_per_pixel - ) + exclusion_mask = np.ones((num_y, num_x, 2)) * disc_masks[y, x, :, :].reshape(num_y, num_x, 1) + p_vals, __ = chi_square_within_mask(exclusion_mask, events_per_pixel, trials_per_pixel) chi_square_grid[:, y, x] = p_vals return chi_square_grid @@ -128,29 +122,21 @@ def get_peak_significance(chi_square_grid_NLL, LSN_template, alpha=0.05): # find the smallest p-value and determine if it's significant significant_cells = np.zeros((num_cells)).astype(bool) best_p = np.zeros((num_cells)) - p_value_correction_factor_per_pixel = ( - 1.0 * num_y * num_x / pixels_per_mask_per_pixel - ) + p_value_correction_factor_per_pixel = 1.0 * num_y * num_x / pixels_per_mask_per_pixel best_exclusion_region_list = [] corrected_p_value_array_list = [] for n in range(num_cells): # Sidak correction: - p_value_corrected_per_pixel = 1 - np.power( - (1 - chi_square_grid[n, :, :]), p_value_correction_factor_per_pixel - ) + p_value_corrected_per_pixel = 1 - np.power((1 - chi_square_grid[n, :, :]), p_value_correction_factor_per_pixel) corrected_p_value_array_list.append(p_value_corrected_per_pixel) - y, x = np.unravel_index( - p_value_corrected_per_pixel.argmin(), (num_y, num_x) - ) + y, x = np.unravel_index(p_value_corrected_per_pixel.argmin(), (num_y, num_x)) # if more than one p-value that maxes out, use the median location if np.sum(p_value_corrected_per_pixel == 0.0) > 1: y, x = np.unravel_index( - np.argwhere(p_value_corrected_per_pixel.flatten() == 0.0)[ - :, 0 - ], + np.argwhere(p_value_corrected_per_pixel.flatten() == 0.0)[:, 0], (num_y, num_x), ) center_y, center_x = locate_median(y, x) @@ -158,15 +144,9 @@ def get_peak_significance(chi_square_grid_NLL, LSN_template, alpha=0.05): best_p[n] = p_value_corrected_per_pixel[y, x] if best_p[n] < alpha: significant_cells[n] = True - best_exclusion_region_list.append( - disc_masks[y, x, :, :].astype(bool) - ) + best_exclusion_region_list.append(disc_masks[y, x, :, :].astype(bool)) else: - best_exclusion_region_list.append( - np.zeros( - (disc_masks.shape[0], disc_masks.shape[1]), dtype=bool - ) - ) + best_exclusion_region_list.append(np.zeros((disc_masks.shape[0], disc_masks.shape[1]), dtype=bool)) return ( significant_cells, @@ -233,9 +213,7 @@ def get_events_per_pixel(responses_np, trial_matrix): for x in range(num_x): for on_off in range(2): frames = np.argwhere(trial_matrix[y, x, on_off, :])[:, 0] - events_per_pixel[:, y, x, on_off] = np.sum( - responses_np[frames, :], axis=0 - ) + events_per_pixel[:, y, x, on_off] = np.sum(responses_np[frames, :], axis=0) return events_per_pixel @@ -262,12 +240,8 @@ def smooth_STA(STA, gauss_std=0.75, total_degrees=64): deg_per_pnt = total_degrees // STA.shape[0] STA_interpolated = interpolate_RF(STA, deg_per_pnt) - STA_interpolated_smoothed = filt.gaussian_filter( - STA_interpolated, gauss_std - ) - STA_smoothed = deinterpolate_RF( - STA_interpolated_smoothed, STA.shape[1], STA.shape[0], deg_per_pnt - ) + STA_interpolated_smoothed = filt.gaussian_filter(STA_interpolated, gauss_std) + STA_smoothed = deinterpolate_RF(STA_interpolated_smoothed, STA.shape[1], STA.shape[0], deg_per_pnt) return STA_smoothed @@ -396,12 +370,8 @@ def chi_square_within_mask(exclusion_mask, events_per_pixel, trials_per_pixel): degrees_of_freedom = int(np.sum(exclusion_mask)) - 1 # observed_by_pixel has shape (num_cells,num_y,num_x,2) - expected_by_pixel = get_expected_events_by_pixel( - exclusion_mask, events_per_pixel, trials_per_pixel - ) - observed_by_pixel = ( - events_per_pixel * exclusion_mask.reshape(1, num_y, num_x, 2) - ).astype(float) + expected_by_pixel = get_expected_events_by_pixel(exclusion_mask, events_per_pixel, trials_per_pixel) + observed_by_pixel = (events_per_pixel * exclusion_mask.reshape(1, num_y, num_x, 2)).astype(float) # calculate test statistic given observed and expected residual_by_pixel = observed_by_pixel - expected_by_pixel @@ -414,9 +384,7 @@ def chi_square_within_mask(exclusion_mask, events_per_pixel, trials_per_pixel): return p_vals, chi -def get_expected_events_by_pixel( - exclusion_mask, events_per_pixel, trials_per_pixel -): +def get_expected_events_by_pixel(exclusion_mask, events_per_pixel, trials_per_pixel): """Calculate expected number of events per pixel Parameters @@ -453,14 +421,10 @@ def get_expected_events_by_pixel( total_events_by_cell = np.sum(masked_events, axis=(1, 2, 3)).astype(float) expected_by_cell_per_trial = total_events_by_cell / total_trials - return masked_trials * expected_by_cell_per_trial.reshape( - num_cells, 1, 1, 1 - ) + return masked_trials * expected_by_cell_per_trial.reshape(num_cells, 1, 1, 1) -def build_trial_matrix( - LSN_template, num_trials, on_off_luminance=(ON_LUMINANCE, OFF_LUMINANCE) -): +def build_trial_matrix(LSN_template, num_trials, on_off_luminance=(ON_LUMINANCE, OFF_LUMINANCE)): """Construct indicator arrays for on/off pixels across trials. Parameters @@ -489,9 +453,7 @@ def build_trial_matrix( for y in range(num_y): for x in range(num_x): for oo, on_off in enumerate(on_off_luminance): - frame = np.argwhere(LSN_template[:num_trials, y, x] == on_off)[ - :, 0 - ] + frame = np.argwhere(LSN_template[:num_trials, y, x] == on_off)[:, 0] trial_mat[y, x, oo, frame] = True return trial_mat @@ -544,13 +506,9 @@ def get_disc_masks( for y in range(num_y): for x in range(num_x): trials_not_gray = np.argwhere(LSN_binary[:, y, x] > 0)[:, 0] - raw_mask = np.divide( - LSN_binary[trials_not_gray, :, :].sum(axis=0), on_trials - ) + raw_mask = np.divide(LSN_binary[trials_not_gray, :, :].sum(axis=0), on_trials) - center_y, center_x = np.unravel_index( - raw_mask.argmax(), (num_y, num_x) - ) + center_y, center_x = np.unravel_index(raw_mask.argmax(), (num_y, num_x)) # include center pixel in mask raw_mask[center_y, center_x] = 0.0 @@ -571,9 +529,7 @@ def get_disc_masks( # don't include far away pixels that just happen # to not have any trials in common with center pixel clean_mask = np.ones(np.shape(raw_mask)) - clean_mask[y_min:y_max, x_min:x_max] = raw_mask[ - y_min:y_max, x_min:x_max - ] + clean_mask[y_min:y_max, x_min:x_max] = raw_mask[y_min:y_max, x_min:x_max] masks[y, x, :, :] = clean_mask diff --git a/allensdk/brain_observatory/receptive_field_analysis/eventdetection.py b/allensdk/brain_observatory/receptive_field_analysis/eventdetection.py index a63f9145f7..af9a95a019 100644 --- a/allensdk/brain_observatory/receptive_field_analysis/eventdetection.py +++ b/allensdk/brain_observatory/receptive_field_analysis/eventdetection.py @@ -52,19 +52,13 @@ def detect_events(data, cell_index, stimulus, debug_plots=False): var_dict = {} debug_dict = {} for ii, fi in enumerate(stimulus_table["start"].values): - if ( - ii > 0 - and stimulus_table.iloc[ii].start - == stimulus_table.iloc[ii - 1].end - ): + if ii > 0 and stimulus_table.iloc[ii].start == stimulus_table.iloc[ii - 1].end: offset = 1 else: offset = 0 if fi + k_min >= 0 and fi + k_max <= len(dff_trace): - trace = dff_trace[ - fi + k_min + 1 + offset : fi + k_max + 1 + offset - ] + trace = dff_trace[fi + k_min + 1 + offset : fi + k_max + 1 + offset] xx = (trace - trace[0])[delta] - (trace - trace[0])[0] yy = max( @@ -122,9 +116,7 @@ def detect_events(data, cell_index, stimulus, debug_plots=False): # ========================================================================= - noise_threshold = max( - allowed_sigma * std_x + mu_x, allowed_sigma * std_y + mu_y - ) + noise_threshold = max(allowed_sigma * std_x + mu_x, allowed_sigma * std_y + mu_y) mu_array = np.array([mu_x, mu_y]) yes_set, no_set = set(), set() for ii, (t0, tf, xx, yy) in var_dict.items(): @@ -136,12 +128,7 @@ def detect_events(data, cell_index, stimulus, debug_plots=False): # 3) Change evoked by this trial, not previous # 4) At end of trace, ended up outside of noise floor - if ( - np.sqrt(xi_z**2 + yi_z**2) > 4 - and yy > 0.05 - and xx < yy - and tf > noise_threshold / 2 - ): + if np.sqrt(xi_z**2 + yi_z**2) > 4 and yy > 0.05 and xx < yy and tf > noise_threshold / 2: yes_set.add(ii) else: no_set.add(ii) @@ -161,9 +148,7 @@ def detect_events(data, cell_index, stimulus, debug_plots=False): if ti in no_set: ax[0].plot(np.arange(key, key + len(trace)), trace, "b") elif ti in yes_set: - ax[0].plot( - np.arange(key, key + len(trace)), trace, "r", linewidth=2 - ) + ax[0].plot(np.arange(key, key + len(trace)), trace, "r", linewidth=2) else: raise Exception diff --git a/allensdk/brain_observatory/receptive_field_analysis/fit_parameters.py b/allensdk/brain_observatory/receptive_field_analysis/fit_parameters.py index 02b0f59556..44a1416260 100644 --- a/allensdk/brain_observatory/receptive_field_analysis/fit_parameters.py +++ b/allensdk/brain_observatory/receptive_field_analysis/fit_parameters.py @@ -37,47 +37,44 @@ import numpy as np import warnings -def add_to_fit_parameters_dict_single(fit_parameters_dict, p): - fit_parameters_dict['height'].append(p[0]) - fit_parameters_dict['center_y'].append(p[1]) - fit_parameters_dict['center_x'].append(p[2]) - fit_parameters_dict['width_y'].append(p[3]) - fit_parameters_dict['width_x'].append(p[4]) - fit_parameters_dict['rotation'].append(p[5]) +def add_to_fit_parameters_dict_single(fit_parameters_dict, p): + fit_parameters_dict["height"].append(p[0]) + fit_parameters_dict["center_y"].append(p[1]) + fit_parameters_dict["center_x"].append(p[2]) + fit_parameters_dict["width_y"].append(p[3]) + fit_parameters_dict["width_x"].append(p[4]) + fit_parameters_dict["rotation"].append(p[5]) if (p[3] is None) or (p[4] is None): - fit_parameters_dict['area'].append(None) + fit_parameters_dict["area"].append(None) else: - fit_parameters_dict['area'].append(np.pi * (3./2) ** 2 * np.abs(p[3]) * np.abs(p[4])) + fit_parameters_dict["area"].append(np.pi * (3.0 / 2) ** 2 * np.abs(p[3]) * np.abs(p[4])) -def get_gaussian_fit_single_channel(rf, fit_parameters_dict): +def get_gaussian_fit_single_channel(rf, fit_parameters_dict): try: p_fit = fitgaussian2D(rf) add_to_fit_parameters_dict_single(fit_parameters_dict, p_fit) data_fitted_on = gaussian2D(*p_fit)(*np.indices(rf.shape)) - fit_parameters_dict['data'].append(data_fitted_on) + fit_parameters_dict["data"].append(data_fitted_on) except GaussianFitError: - warnings.warn('GaussianFitError (on subfield) caught') - add_to_fit_parameters_dict_single(fit_parameters_dict, [None]*6) - fit_parameters_dict['data'].append(np.zeros_like(rf)) + warnings.warn("GaussianFitError (on subfield) caught") + add_to_fit_parameters_dict_single(fit_parameters_dict, [None] * 6) + fit_parameters_dict["data"].append(np.zeros_like(rf)) -def compute_distance(center_on, center_off): +def compute_distance(center_on, center_off): center_x_on, center_y_on = center_on center_x_off, center_y_off = center_off if (center_x_on is None) or (center_y_on is None) or (center_x_off is None) or (center_y_off is None): return None else: - return np.sqrt((center_x_off-center_x_on)**2+(center_y_off-center_y_on)**2) + return np.sqrt((center_x_off - center_x_on) ** 2 + (center_y_off - center_y_on) ** 2) -def compute_overlap(data_fitted_on, data_fitted_off): +def compute_overlap(data_fitted_on, data_fitted_off): on_bin = np.where(data_fitted_on > 0.001, 1, 0) off_bin = np.where(data_fitted_off > 0.001, 1, 0) return float((np.multiply(on_bin, off_bin)).sum()) / (np.sqrt(on_bin.sum()) * np.sqrt(off_bin.sum())) - - - diff --git a/allensdk/brain_observatory/receptive_field_analysis/fitgaussian2D.py b/allensdk/brain_observatory/receptive_field_analysis/fitgaussian2D.py index 2278b2a016..19e74be7c6 100644 --- a/allensdk/brain_observatory/receptive_field_analysis/fitgaussian2D.py +++ b/allensdk/brain_observatory/receptive_field_analysis/fitgaussian2D.py @@ -42,7 +42,7 @@ class GaussianFitError(RuntimeError): def gaussian2D(height, center_x, center_y, width_x, width_y, rotation): - '''Build a function which evaluates a scaled 2d gaussian pdf + """Build a function which evaluates a scaled 2d gaussian pdf Parameters ---------- @@ -62,29 +62,30 @@ def gaussian2D(height, center_x, center_y, width_x, width_y, rotation): Returns ------- rotgauss: fn - parameters are x and y positions (row/column semantics are set by your - inputs to this function). Return value is the scaled gaussian pdf + parameters are x and y positions (row/column semantics are set by your + inputs to this function). Return value is the scaled gaussian pdf evaluated at the argued point. - ''' + """ width_x = float(width_x) width_y = float(width_y) - + rotation = np.deg2rad(rotation) - center_xp = center_x*np.cos(rotation) - center_y*np.sin(rotation) - center_yp = center_x*np.sin(rotation) + center_y*np.cos(rotation) - - def rotgauss(x,y): - xp = x*np.cos(rotation) - y*np.sin(rotation) - yp = x*np.sin(rotation) + y*np.cos(rotation) - g = height*np.exp(-((center_xp-xp)/width_x)**2/2.0 - ((center_yp-yp)/width_y)**2/2.) + center_xp = center_x * np.cos(rotation) - center_y * np.sin(rotation) + center_yp = center_x * np.sin(rotation) + center_y * np.cos(rotation) + + def rotgauss(x, y): + xp = x * np.cos(rotation) - y * np.sin(rotation) + yp = x * np.sin(rotation) + y * np.cos(rotation) + g = height * np.exp(-(((center_xp - xp) / width_x) ** 2) / 2.0 - ((center_yp - yp) / width_y) ** 2 / 2.0) return g + return rotgauss - + def moments2(data): - '''Treating input image data as an independent multivariate gaussian, + """Treating input image data as an independent multivariate gaussian, estimate mean and standard deviations Parameters @@ -101,37 +102,37 @@ def moments2(data): x : float Mean column index width_y : float - The standard deviation along the mean row + The standard deviation along the mean row width_x : float The standard deviation along the mean column - None : + None : This function returns an instance of None. Notes ----- uses original method from website for finding center - ''' + """ total = data.sum() - Y,X = np.indices(data.shape) - x = ( X * data ).sum() / total - y = ( Y * data ).sum() / total + Y, X = np.indices(data.shape) + x = (X * data).sum() / total + y = (Y * data).sum() / total col = data[:, int(np.around(x))] - width_x = np.sqrt( abs( ( np.arange(col.size) - y ) ** 2 * col ).sum() / col.sum() ) + width_x = np.sqrt(abs((np.arange(col.size) - y) ** 2 * col).sum() / col.sum()) row = data[int(np.around(y)), :] - width_y = np.sqrt( abs( ( np.arange(row.size) - x ) ** 2 * row ).sum() / row.sum() ) + width_y = np.sqrt(abs((np.arange(row.size) - x) ** 2 * row).sum() / row.sum()) height = data.max() return height, y, x, width_y, width_x, None - + def fitgaussian2D(data): - '''Fit a 2D gaussian to an image + """Fit a 2D gaussian to an image Parameters ---------- @@ -147,30 +148,34 @@ def fitgaussian2D(data): row standard deviation column standard deviation rotation - + Notes ----- see gaussian2D for details about output values - ''' + """ params = moments2(data) + def errorfunction(p): p2 = np.array([p[0], params[1], params[2], np.abs(p[1]), np.abs(p[2]), p[3]]) - val = np.ravel(gaussian2D(*p2)(*np.indices(data.shape)) - data) return (val**2).sum() - res = optimize.minimize(errorfunction, [ params[0], params[3], params[4], 0.0 ], method='Nelder-Mead', options={'maxfev':2500}) + res = optimize.minimize( + errorfunction, [params[0], params[3], params[4], 0.0], method="Nelder-Mead", options={"maxfev": 2500} + ) p = res.x p2 = np.array([p[0], params[1], params[2], np.abs(p[1]), np.abs(p[2]), p[3]]) success = res.success - if not success and res.status != 2: # Status 2 is loss of precision; might need to handle this separately instead of passing... + if ( + not success and res.status != 2 + ): # Status 2 is loss of precision; might need to handle this separately instead of passing... print(success) print(res.message) print(res.status) - raise GaussianFitError('Gaussian optimization failed to converge:\n%s' % res.message) + raise GaussianFitError("Gaussian optimization failed to converge:\n%s" % res.message) return p2 diff --git a/allensdk/brain_observatory/receptive_field_analysis/postprocessing.py b/allensdk/brain_observatory/receptive_field_analysis/postprocessing.py index de12f29d93..6b9e41080e 100644 --- a/allensdk/brain_observatory/receptive_field_analysis/postprocessing.py +++ b/allensdk/brain_observatory/receptive_field_analysis/postprocessing.py @@ -58,43 +58,24 @@ def get_gaussian_fit(rf): counter = {"on": 0, "off": 0} for on_off_key in ["on", "off"]: fit_parameters_dict = fit_parameters_dict_combined[on_off_key] - for ci in range( - rf[on_off_key]["fdr_mask"]["attrs"]["number_of_components"] - ): + for ci in range(rf[on_off_key]["fdr_mask"]["attrs"]["number_of_components"]): curr_component_mask = ( - upsample_image_to_degrees( - np.logical_not( - rf[on_off_key]["fdr_mask"]["data"][ci, :, :] - ) - ) - > 0.5 - ) - rf_response = upsample_image_to_degrees( - rf[on_off_key]["rts_convolution"]["data"].copy() + upsample_image_to_degrees(np.logical_not(rf[on_off_key]["fdr_mask"]["data"][ci, :, :])) > 0.5 ) + rf_response = upsample_image_to_degrees(rf[on_off_key]["rts_convolution"]["data"].copy()) rf_response[curr_component_mask] = 0 if rf_response.sum() > 0: - get_gaussian_fit_single_channel( - rf_response, fit_parameters_dict - ) + get_gaussian_fit_single_channel(rf_response, fit_parameters_dict) counter[on_off_key] += 1 for ii_off in range(counter["on"]): - fit_parameters_dict_combined["on"]["distance"].append( - [None] * counter["off"] - ) - fit_parameters_dict_combined["on"]["overlap"].append( - [None] * counter["off"] - ) + fit_parameters_dict_combined["on"]["distance"].append([None] * counter["off"]) + fit_parameters_dict_combined["on"]["overlap"].append([None] * counter["off"]) for ii_off in range(counter["off"]): - fit_parameters_dict_combined["off"]["distance"].append( - [None] * counter["on"] - ) - fit_parameters_dict_combined["off"]["overlap"].append( - [None] * counter["on"] - ) + fit_parameters_dict_combined["off"]["distance"].append([None] * counter["on"]) + fit_parameters_dict_combined["off"]["overlap"].append([None] * counter["on"]) for ii_on in range(counter["on"]): for ii_off in range(counter["off"]): @@ -107,22 +88,14 @@ def get_gaussian_fit(rf): fit_parameters_dict_combined["off"]["center_y"][ii_off], ) curr_distance = compute_distance(center_on, center_off) - fit_parameters_dict_combined["on"]["distance"][ii_on][ - ii_off - ] = curr_distance - fit_parameters_dict_combined["off"]["distance"][ii_off][ - ii_on - ] = curr_distance + fit_parameters_dict_combined["on"]["distance"][ii_on][ii_off] = curr_distance + fit_parameters_dict_combined["off"]["distance"][ii_off][ii_on] = curr_distance data_on = fit_parameters_dict_combined["on"]["data"][ii_on] data_off = fit_parameters_dict_combined["off"]["data"][ii_off] curr_overlap = compute_overlap(data_on, data_off) - fit_parameters_dict_combined["on"]["overlap"][ii_on][ - ii_off - ] = curr_overlap - fit_parameters_dict_combined["off"]["overlap"][ii_off][ - ii_on - ] = curr_overlap + fit_parameters_dict_combined["on"]["overlap"][ii_on][ii_off] = curr_overlap + fit_parameters_dict_combined["off"]["overlap"][ii_off][ii_on] = curr_overlap return fit_parameters_dict_combined, counter @@ -143,29 +116,21 @@ def run_postprocessing(data, rf): if key == "data": rf[on_off_key]["gaussian_fit"]["data"] = np.array(val) else: - rf[on_off_key]["gaussian_fit"]["attrs"][key] = np.array( - val - ) + rf[on_off_key]["gaussian_fit"]["attrs"][key] = np.array(val) # Chi squared test statistic postprocessing: # cell_index = rf["attrs"]["cell_index"] locally_sparse_noise_template = data.get_stimulus_template(stimulus) - event_array = np.zeros( - (rf["event_vector"]["data"].shape[0], 1), dtype=bool - ) + event_array = np.zeros((rf["event_vector"]["data"].shape[0], 1), dtype=bool) event_array[:, 0] = rf["event_vector"]["data"] - chi_squared_grid = chi_square_binary( - event_array, locally_sparse_noise_template - ) + chi_squared_grid = chi_square_binary(event_array, locally_sparse_noise_template) alpha = rf["on"]["fdr_mask"]["attrs"]["alpha"] assert rf["off"]["fdr_mask"]["attrs"]["alpha"] == alpha chi_square_grid_NLL = pvalue_to_NLL(chi_squared_grid) - peak_significance = get_peak_significance( - chi_square_grid_NLL, locally_sparse_noise_template, alpha=alpha - ) + peak_significance = get_peak_significance(chi_square_grid_NLL, locally_sparse_noise_template, alpha=alpha) significant = peak_significance[0][0] min_p = peak_significance[1][0] pvalues_chi_square = peak_significance[2][0] diff --git a/allensdk/brain_observatory/receptive_field_analysis/receptive_field.py b/allensdk/brain_observatory/receptive_field_analysis/receptive_field.py index 4c646e13af..3c3d3b4e8d 100644 --- a/allensdk/brain_observatory/receptive_field_analysis/receptive_field.py +++ b/allensdk/brain_observatory/receptive_field_analysis/receptive_field.py @@ -36,16 +36,14 @@ from .eventdetection import detect_events from statsmodels.sandbox.stats.multicomp import multipletests import numpy as np -from .utilities import get_A, get_A_blur, get_shuffle_matrix, get_components, \ - dict_generator +from .utilities import get_A, get_A_blur, get_shuffle_matrix, get_components, dict_generator from .postprocessing import run_postprocessing import h5py -def events_to_pvalues_no_fdr_correction(data, event_vector, A, - number_of_shuffles=5000, - response_detection_error_std_dev=.1, - seed=1): +def events_to_pvalues_no_fdr_correction( + data, event_vector, A, number_of_shuffles=5000, response_detection_error_std_dev=0.1, seed=1 +): number_of_pixels = A.shape[0] // 2 # Initializations: @@ -53,52 +51,48 @@ def events_to_pvalues_no_fdr_correction(data, event_vector, A, np.random.seed(seed) shuffle_data = get_shuffle_matrix( - data, event_vector, A, + data, + event_vector, + A, number_of_shuffles=number_of_shuffles, - response_detection_error_std_dev=response_detection_error_std_dev) + response_detection_error_std_dev=response_detection_error_std_dev, + ) # Build list of p-values: response_triggered_stimulus_vector = A.dot(event_vector) / number_of_events p_value_list = [] for pi in range(2 * number_of_pixels): - curr_p_value = \ - 1 - (shuffle_data[pi, :] < - response_triggered_stimulus_vector[pi]).sum() * \ - 1. / number_of_shuffles + curr_p_value = ( + 1 - (shuffle_data[pi, :] < response_triggered_stimulus_vector[pi]).sum() * 1.0 / number_of_shuffles + ) p_value_list.append(curr_p_value) return np.array(p_value_list) def compute_receptive_field(data, cell_index, stimulus, **kwargs): - alpha = kwargs.pop('alpha') + alpha = kwargs.pop("alpha") event_vector = detect_events(data, cell_index, stimulus) A_blur = get_A_blur(data, stimulus) number_of_pixels = A_blur.shape[0] // 2 - pvalues = events_to_pvalues_no_fdr_correction(data, event_vector, A_blur, - **kwargs) + pvalues = events_to_pvalues_no_fdr_correction(data, event_vector, A_blur, **kwargs) stimulus_table = data.get_stimulus_table(stimulus) - stimulus_template = data.get_stimulus_template(stimulus)[ - stimulus_table['frame'].values, :, :] + stimulus_template = data.get_stimulus_template(stimulus)[stimulus_table["frame"].values, :, :] s1, s2 = stimulus_template.shape[1], stimulus_template.shape[2] - pvalues_on, pvalues_off = \ - pvalues[:number_of_pixels]\ - .reshape(s1, s2), pvalues[number_of_pixels:].reshape(s1, s2) + pvalues_on, pvalues_off = pvalues[:number_of_pixels].reshape(s1, s2), pvalues[number_of_pixels:].reshape(s1, s2) fdr_corrected_pvalues = multipletests(pvalues, alpha=alpha)[1] - fdr_corrected_pvalues_on = fdr_corrected_pvalues[ - :number_of_pixels].reshape(s1, s2) + fdr_corrected_pvalues_on = fdr_corrected_pvalues[:number_of_pixels].reshape(s1, s2) _fdr_mask_on = np.zeros_like(pvalues_on, dtype=bool) _fdr_mask_on[fdr_corrected_pvalues_on < alpha] = True components_on, number_of_components_on = get_components(_fdr_mask_on) - fdr_corrected_pvalues_off = fdr_corrected_pvalues[ - number_of_pixels:].reshape(s1, s2) + fdr_corrected_pvalues_off = fdr_corrected_pvalues[number_of_pixels:].reshape(s1, s2) _fdr_mask_off = np.zeros_like(pvalues_off, dtype=bool) _fdr_mask_off[fdr_corrected_pvalues_off < alpha] = True components_off, number_of_components_off = get_components(_fdr_mask_off) @@ -107,66 +101,63 @@ def compute_receptive_field(data, cell_index, stimulus, **kwargs): A_blur = get_A_blur(data, stimulus) response_triggered_stimulus_field = A.dot(event_vector) - response_triggered_stimulus_field_on = response_triggered_stimulus_field[ - :number_of_pixels].reshape(s1, s2) - response_triggered_stimulus_field_off = response_triggered_stimulus_field[ - number_of_pixels:].reshape(s1, s2) + response_triggered_stimulus_field_on = response_triggered_stimulus_field[:number_of_pixels].reshape(s1, s2) + response_triggered_stimulus_field_off = response_triggered_stimulus_field[number_of_pixels:].reshape(s1, s2) response_triggered_stimulus_field_convolution = A_blur.dot(event_vector) - response_triggered_stimulus_field_convolution_on = \ - response_triggered_stimulus_field_convolution[:number_of_pixels]\ - .reshape(s1, s2) - response_triggered_stimulus_field_convolution_off = \ - response_triggered_stimulus_field_convolution[number_of_pixels:]\ - .reshape(s1, s2) - - on_dict = {'pvalues': {'data': pvalues_on}, - 'fdr_corrected': {'data': fdr_corrected_pvalues_on, - 'attrs': { - 'alpha': alpha, - 'min_p': fdr_corrected_pvalues_on.min()}}, - 'fdr_mask': { - 'data': components_on, - 'attrs': { - 'alpha': alpha, - 'number_of_components': number_of_components_on, - 'number_of_pixels': components_on - .sum(axis=1) - .sum(axis=1)}}, - 'rts_convolution': { - 'data': response_triggered_stimulus_field_convolution_on}, - 'rts': {'data': response_triggered_stimulus_field_on} - } - off_dict = {'pvalues': {'data': pvalues_off}, - 'fdr_corrected': {'data': fdr_corrected_pvalues_off, - 'attrs': { - 'alpha': alpha, - 'min_p': - fdr_corrected_pvalues_off.min()}}, - 'fdr_mask': { - 'data': components_off, - 'attrs': { - 'alpha': alpha, - 'number_of_components': number_of_components_off, - 'number_of_pixels': components_off - .sum(axis=1) - .sum(axis=1)}}, - 'rts_convolution': { - 'data': response_triggered_stimulus_field_convolution_off}, - 'rts': {'data': response_triggered_stimulus_field_off} - } - - result_dict = {'event_vector': {'data': event_vector, 'attrs': { - 'number_of_events': event_vector.sum()}}, - 'on': on_dict, - 'off': off_dict, - 'attrs': {'cell_index': cell_index, 'stimulus': stimulus}} + response_triggered_stimulus_field_convolution_on = response_triggered_stimulus_field_convolution[ + :number_of_pixels + ].reshape(s1, s2) + response_triggered_stimulus_field_convolution_off = response_triggered_stimulus_field_convolution[ + number_of_pixels: + ].reshape(s1, s2) + + on_dict = { + "pvalues": {"data": pvalues_on}, + "fdr_corrected": { + "data": fdr_corrected_pvalues_on, + "attrs": {"alpha": alpha, "min_p": fdr_corrected_pvalues_on.min()}, + }, + "fdr_mask": { + "data": components_on, + "attrs": { + "alpha": alpha, + "number_of_components": number_of_components_on, + "number_of_pixels": components_on.sum(axis=1).sum(axis=1), + }, + }, + "rts_convolution": {"data": response_triggered_stimulus_field_convolution_on}, + "rts": {"data": response_triggered_stimulus_field_on}, + } + off_dict = { + "pvalues": {"data": pvalues_off}, + "fdr_corrected": { + "data": fdr_corrected_pvalues_off, + "attrs": {"alpha": alpha, "min_p": fdr_corrected_pvalues_off.min()}, + }, + "fdr_mask": { + "data": components_off, + "attrs": { + "alpha": alpha, + "number_of_components": number_of_components_off, + "number_of_pixels": components_off.sum(axis=1).sum(axis=1), + }, + }, + "rts_convolution": {"data": response_triggered_stimulus_field_convolution_off}, + "rts": {"data": response_triggered_stimulus_field_off}, + } + + result_dict = { + "event_vector": {"data": event_vector, "attrs": {"number_of_events": event_vector.sum()}}, + "on": on_dict, + "off": off_dict, + "attrs": {"cell_index": cell_index, "stimulus": stimulus}, + } return result_dict -def compute_receptive_field_with_postprocessing(data, cell_index, stimulus, - **kwargs): +def compute_receptive_field_with_postprocessing(data, cell_index, stimulus, **kwargs): rf = compute_receptive_field(data, cell_index, stimulus, **kwargs) rf = run_postprocessing(data, rf) @@ -176,41 +167,38 @@ def compute_receptive_field_with_postprocessing(data, cell_index, stimulus, def get_attribute_dict(rf): attribute_dict = {} for x in dict_generator(rf): - if x[-3] == 'attrs': + if x[-3] == "attrs": if len(x[:-3]) == 0: key = x[-2] else: - key = '/'.join(['/'.join(x[:-3]), x[-2]]) + key = "/".join(["/".join(x[:-3]), x[-2]]) attribute_dict[key] = x[-1] return attribute_dict def print_summary(rf): - for key_val in sorted(get_attribute_dict(rf).items(), - key=lambda x: x[0]): + for key_val in sorted(get_attribute_dict(rf).items(), key=lambda x: x[0]): print("%s : %s" % key_val) -def write_receptive_field_to_h5(rf, file_name, prefix=''): +def write_receptive_field_to_h5(rf, file_name, prefix=""): attr_list = [] - f = h5py.File(file_name, 'a') + f = h5py.File(file_name, "a") for x in dict_generator(rf): - - if x[-2] == 'data': - f['/'.join([prefix] + x[:-1])] = x[-1] - elif x[-3] == 'attrs': + if x[-2] == "data": + f["/".join([prefix] + x[:-1])] = x[-1] + elif x[-3] == "attrs": attr_list.append(x) else: raise Exception for x in attr_list: if len(x) > 3: - f['/'.join([prefix] + x[:-3])].attrs[x[-2]] = x[-1] + f["/".join([prefix] + x[:-3])].attrs[x[-2]] = x[-1] else: assert len(x) == 3 - if prefix == '': - + if prefix == "": if x[-1] is None: f.attrs[x[-2]] = np.nan else: @@ -227,9 +215,9 @@ def write_receptive_field_to_h5(rf, file_name, prefix=''): def read_h5_group(g): return_dict = {} if len(g.attrs) > 0: - return_dict['attrs'] = dict(g.attrs) + return_dict["attrs"] = dict(g.attrs) for key in g: - if key == 'data': + if key == "data": return_dict[key] = g[key][()] else: return_dict[key] = read_h5_group(g[key]) @@ -238,7 +226,7 @@ def read_h5_group(g): def read_receptive_field_from_h5(file_name, path=None): - f = h5py.File(file_name, 'r') + f = h5py.File(file_name, "r") if path is None: rf = read_h5_group(f) else: diff --git a/allensdk/brain_observatory/receptive_field_analysis/tools.py b/allensdk/brain_observatory/receptive_field_analysis/tools.py index 1ba9efa035..dac0be9b64 100644 --- a/allensdk/brain_observatory/receptive_field_analysis/tools.py +++ b/allensdk/brain_observatory/receptive_field_analysis/tools.py @@ -34,8 +34,7 @@ # POSSIBILITY OF SUCH DAMAGE. # def list_of_dicts_to_dict_of_lists(list_of_dicts): - return {key: [item[key] for item in list_of_dicts] for key in - list_of_dicts[0].keys()} + return {key: [item[key] for item in list_of_dicts] for key in list_of_dicts[0].keys()} def dict_generator(indict, pre=None): @@ -58,9 +57,9 @@ def dict_generator(indict, pre=None): def read_h5_group(g): return_dict = {} if len(g.attrs) > 0: - return_dict['attrs'] = dict(g.attrs) + return_dict["attrs"] = dict(g.attrs) for key in g: - if key == 'data': + if key == "data": return_dict[key] = g[key][()] else: return_dict[key] = read_h5_group(g[key]) diff --git a/allensdk/brain_observatory/receptive_field_analysis/utilities.py b/allensdk/brain_observatory/receptive_field_analysis/utilities.py index bbd7b19628..eb4e412d1c 100644 --- a/allensdk/brain_observatory/receptive_field_analysis/utilities.py +++ b/allensdk/brain_observatory/receptive_field_analysis/utilities.py @@ -68,9 +68,7 @@ def convolve(img, sigma=4): return img img_pad = np.zeros((3 * img.shape[0], 3 * img.shape[1])) - img_pad[ - img.shape[0] : 2 * img.shape[0], img.shape[1] : 2 * img.shape[1] - ] = img + img_pad[img.shape[0] : 2 * img.shape[0], img.shape[1] : 2 * img.shape[1]] = img x = np.arange(3 * img.shape[0]) y = np.arange(3 * img.shape[1]) @@ -93,9 +91,7 @@ def convolve(img, sigma=4): z_on_new = block_reduce(ZZ_on_f, (upsample, upsample)) z_on_new = z_on_new / z_on_new.sum() * img.sum() - z_on_new = z_on_new[ - img.shape[0] : 2 * img.shape[0], img.shape[1] : 2 * img.shape[1] - ] + z_on_new = z_on_new[img.shape[0] : 2 * img.shape[0], img.shape[1] : 2 * img.shape[1]] return z_on_new @@ -103,20 +99,14 @@ def convolve(img, sigma=4): @memoize def get_A(data, stimulus): stimulus_table = data.get_stimulus_table(stimulus) - stimulus_template = data.get_stimulus_template(stimulus)[ - stimulus_table["frame"].values, :, : - ] + stimulus_template = data.get_stimulus_template(stimulus)[stimulus_table["frame"].values, :, :] number_of_pixels = stimulus_template.shape[1] * stimulus_template.shape[2] A = np.zeros((2 * number_of_pixels, stimulus_template.shape[0])) for fi in range(stimulus_template.shape[0]): - A[:number_of_pixels, fi] = ( - stimulus_template[fi, :, :].flatten() > 127 - ).astype(float) - A[number_of_pixels:, fi] = ( - stimulus_template[fi, :, :].flatten() < 127 - ).astype(float) + A[:number_of_pixels, fi] = (stimulus_template[fi, :, :].flatten() > 127).astype(float) + A[number_of_pixels:, fi] = (stimulus_template[fi, :, :].flatten() < 127).astype(float) return A @@ -124,23 +114,17 @@ def get_A(data, stimulus): @memoize def get_A_blur(data, stimulus): stimulus_table = data.get_stimulus_table(stimulus) - stimulus_template = data.get_stimulus_template(stimulus)[ - stimulus_table["frame"].values, :, : - ] + stimulus_template = data.get_stimulus_template(stimulus)[stimulus_table["frame"].values, :, :] A = get_A(data, stimulus).copy() number_of_pixels = A.shape[0] // 2 for fi in range(A.shape[1]): A[:number_of_pixels, fi] = convolve( - A[:number_of_pixels, fi].reshape( - stimulus_template.shape[1], stimulus_template.shape[2] - ) + A[:number_of_pixels, fi].reshape(stimulus_template.shape[1], stimulus_template.shape[2]) ).flatten() A[number_of_pixels:, fi] = convolve( - A[number_of_pixels:, fi].reshape( - stimulus_template.shape[1], stimulus_template.shape[2] - ) + A[number_of_pixels:, fi].reshape(stimulus_template.shape[1], stimulus_template.shape[2]) ).flatten() return A @@ -158,13 +142,7 @@ def get_shuffle_matrix( shuffle_data = np.zeros((2 * number_of_pixels, number_of_shuffles)) evr = range(len(event_vector)) for ii in range(number_of_shuffles): - size = number_of_events + int( - np.round( - response_detection_error_std_dev - * number_of_events - * np.random.randn() - ) - ) + size = number_of_events + int(np.round(response_detection_error_std_dev * number_of_events * np.random.randn())) shuffled_event_inds = np.random.choice(evr, size=size, replace=False) b_tmp = np.zeros(len(event_vector), dtype=bool) @@ -174,9 +152,7 @@ def get_shuffle_matrix( return shuffle_data -def get_sparse_noise_epoch_mask_list( - st, number_of_acquisition_frames, threshold=7 -): +def get_sparse_noise_epoch_mask_list(st, number_of_acquisition_frames, threshold=7): delta = st.start.values[1:] - st.end.values[:-1] cut_inds = np.where(delta > threshold)[0] + 1 @@ -248,10 +224,7 @@ def smooth(x, window_len=11, window="hanning", mode="valid"): return x if window not in ["flat", "hanning", "hamming", "bartlett", "blackman"]: - raise ValueError( - "Window is on of 'flat', 'hanning', 'hamming', 'bartlett', " - "'blackman'" - ) + raise ValueError("Window is on of 'flat', 'hanning', 'hamming', 'bartlett', 'blackman'") s = np.r_[x[window_len - 1 : 0 : -1], x, x[-1:-window_len:-1]] # print(len(s)) @@ -269,8 +242,7 @@ def get_components(receptive_field_data): candidate_pixel_list = np.where(receptive_field_data.flatten())[0] pixel_coord_dict = dict( - (px, (int(px / s2), (px - s2 * int(px / s2)), px % (s1 * s2) == px)) - for px in candidate_pixel_list + (px, (int(px / s2), (px - s2 * int(px / s2)), px % (s1 * s2) == px)) for px in candidate_pixel_list ) component_list = [] @@ -305,13 +277,9 @@ def get_components(receptive_field_data): component_list = new_component_list if len(component_list) == 0: - return np.zeros( - (1, receptive_field_data.shape[0], receptive_field_data.shape[1]) - ), len(component_list) + return np.zeros((1, receptive_field_data.shape[0], receptive_field_data.shape[1])), len(component_list) elif len(component_list) == 1: - return_array = np.zeros( - (1, receptive_field_data.shape[0], receptive_field_data.shape[1]) - ) + return_array = np.zeros((1, receptive_field_data.shape[0], receptive_field_data.shape[1])) else: return_array = np.zeros( ( @@ -322,13 +290,9 @@ def get_components(receptive_field_data): ) for ii, component in enumerate(component_list): - curr_component_mask = np.zeros_like( - receptive_field_data, dtype=bool - ).flatten() + curr_component_mask = np.zeros_like(receptive_field_data, dtype=bool).flatten() curr_component_mask[component] = True - return_array[ii, :, :] = curr_component_mask.reshape( - receptive_field_data.shape - ) + return_array[ii, :, :] = curr_component_mask.reshape(receptive_field_data.shape) return return_array, len(component_list) diff --git a/allensdk/brain_observatory/receptive_field_analysis/visualization.py b/allensdk/brain_observatory/receptive_field_analysis/visualization.py index 2094ccabd2..18a8479e21 100644 --- a/allensdk/brain_observatory/receptive_field_analysis/visualization.py +++ b/allensdk/brain_observatory/receptive_field_analysis/visualization.py @@ -39,33 +39,35 @@ import matplotlib.gridspec as gridspec import matplotlib.patches as mpatches -DEFAULT_CMAP = 'magma' +DEFAULT_CMAP = "magma" -def plot_ellipses(gaussian_fit_dict, ax=None, show=True, close=True, save_file_name=None, color='b'): - '''Example Usage: + +def plot_ellipses(gaussian_fit_dict, ax=None, show=True, close=True, save_file_name=None, color="b"): + """Example Usage: oeid, cell_index, stimulus = 512176430, 12, 'locally_sparse_noise' brain_observatory_cache = BrainObservatoryCache() data_set = brain_observatory_cache.get_ophys_experiment_data(oeid) lsn = LocallySparseNoise(data_set, stimulus) result = compute_receptive_field_with_postprocessing(data_set, cell_index, stimulus, alpha=.05, number_of_shuffles=5000) plot_ellipses(result['off']['gaussian_fit'], color='r') - ''' + """ if ax is None: fig, ax = plt.subplots(1) ax.set_xlim(0, 130) ax.set_ylim(0, 74) - plt.axis('off') + plt.axis("off") - on_comp = len(gaussian_fit_dict['attrs']['center_x']) + on_comp = len(gaussian_fit_dict["attrs"]["center_x"]) for i in range(on_comp): - xy = (gaussian_fit_dict['attrs']['center_x'][i], gaussian_fit_dict['attrs']['center_y'][i]) - width = 3 * np.abs(gaussian_fit_dict['attrs']['width_x'][i]) - height = 3 * np.abs(gaussian_fit_dict['attrs']['width_y'][i]) - angle = gaussian_fit_dict['attrs']['rotation'][i] + xy = (gaussian_fit_dict["attrs"]["center_x"][i], gaussian_fit_dict["attrs"]["center_y"][i]) + width = 3 * np.abs(gaussian_fit_dict["attrs"]["width_x"][i]) + height = 3 * np.abs(gaussian_fit_dict["attrs"]["width_y"][i]) + angle = gaussian_fit_dict["attrs"]["rotation"][i] if np.logical_not(any(np.isnan(xy))): - ellipse = mpatches.Ellipse(xy, width=width, height=height, angle=angle, lw=2, edgecolor=color, - facecolor='none') + ellipse = mpatches.Ellipse( + xy, width=width, height=height, angle=angle, lw=2, edgecolor=color, facecolor="none" + ) ax.add_artist(ellipse) if save_file_name is not None: @@ -79,18 +81,19 @@ def plot_ellipses(gaussian_fit_dict, ax=None, show=True, close=True, save_file_n return ax -def pvalue_to_NLL(p_values, - max_NLL=10.0): + +def pvalue_to_NLL(p_values, max_NLL=10.0): return np.where(p_values == 0.0, max_NLL, -np.log10(p_values)) + def plot_chi_square_summary(rf_data, ax=None, cax=None, cmap=DEFAULT_CMAP): if ax is None: ax = plt.gca() - chi_squared_grid = rf_data['chi_squared_analysis']['pvalues']['data'] + chi_squared_grid = rf_data["chi_squared_analysis"]["pvalues"]["data"] chi_square_grid_NLL = pvalue_to_NLL(chi_squared_grid) - clim = (0, max(2,chi_square_grid_NLL.max())) - img = ax.imshow(chi_square_grid_NLL, interpolation='none', origin='lower', clim=clim, cmap=cmap) + clim = (0, max(2, chi_square_grid_NLL.max())) + img = ax.imshow(chi_square_grid_NLL, interpolation="none", origin="lower", clim=clim, cmap=cmap) if cax is None: cb = ax.figure.colorbar(img, ax=ax, ticks=clim) @@ -102,27 +105,37 @@ def plot_chi_square_summary(rf_data, ax=None, cax=None, cmap=DEFAULT_CMAP): cb.update_ticks() ax.axes.get_xaxis().set_visible(False) ax.axes.get_yaxis().set_visible(False) - ax.set_title('Significant: %s (min_p=%s)' % (rf_data['chi_squared_analysis']['attrs']['significant'], - rf_data['chi_squared_analysis']['attrs']['min_p']) ) + ax.set_title( + "Significant: %s (min_p=%s)" + % (rf_data["chi_squared_analysis"]["attrs"]["significant"], rf_data["chi_squared_analysis"]["attrs"]["min_p"]) + ) + def plot_msr_summary(lsn, cell_index, ax_on, ax_off, ax_cbar=None, cmap=None): - min_clim = lsn.mean_response[:, :, cell_index,:].min() - max_clim = lsn.mean_response[:, :, cell_index,:].max() - plot_fields(lsn.mean_response[:, :, cell_index, 0], - lsn.mean_response[:, :, cell_index, 1], - ax_on, ax_off, clim=(min_clim, max_clim), cmap=cmap, cbar_axes=ax_cbar) + min_clim = lsn.mean_response[:, :, cell_index, :].min() + max_clim = lsn.mean_response[:, :, cell_index, :].max() + plot_fields( + lsn.mean_response[:, :, cell_index, 0], + lsn.mean_response[:, :, cell_index, 1], + ax_on, + ax_off, + clim=(min_clim, max_clim), + cmap=cmap, + cbar_axes=ax_cbar, + ) + def plot_fields(on_data, off_data, on_axes, off_axes, cbar_axes=None, clim=None, cmap=DEFAULT_CMAP): if cbar_axes is None: on_axes.figure.subplots_adjust(right=0.9) - cbar_axes = on_axes.figure.add_axes([0.93, 0.37, 0.02, .28]) - + cbar_axes = on_axes.figure.add_axes([0.93, 0.37, 0.02, 0.28]) + if clim is None: clim_max = max(np.nanmax(on_data), np.nanmax(off_data)) - clim = (0,clim_max) - on_axes.imshow(on_data, clim=clim, cmap=cmap, interpolation='none', origin='lower') + clim = (0, clim_max) + on_axes.imshow(on_data, clim=clim, cmap=cmap, interpolation="none", origin="lower") on_axes.set_title("on") - img = off_axes.imshow(off_data, clim=clim, cmap=cmap, interpolation='none', origin='lower') + img = off_axes.imshow(off_data, clim=clim, cmap=cmap, interpolation="none", origin="lower") off_axes.set_title("off") cb = cbar_axes.figure.colorbar(img, cax=cbar_axes, ticks=clim) tick_locator = ticker.MaxNLocator(nbins=5) @@ -132,44 +145,48 @@ def plot_fields(on_data, off_data, on_axes, off_axes, cbar_axes=None, clim=None, frame.axes.get_xaxis().set_visible(False) frame.axes.get_yaxis().set_visible(False) + def plot_rts_summary(rf_data, ax_on, ax_off, ax_cbar=None, cmap=DEFAULT_CMAP): - rts_on = rf_data['on']['rts']['data'] - rts_off = rf_data['off']['rts']['data'] + rts_on = rf_data["on"]["rts"]["data"] + rts_off = rf_data["off"]["rts"]["data"] plot_fields(rts_on, rts_off, ax_on, ax_off, cbar_axes=ax_cbar, cmap=cmap) + def plot_rts_blur_summary(rf_data, ax_on, ax_off, ax_cbar=None, cmap=DEFAULT_CMAP): - rts_on_blur = rf_data['on']['rts_convolution']['data'] - rts_off_blur = rf_data['off']['rts_convolution']['data'] + rts_on_blur = rf_data["on"]["rts_convolution"]["data"] + rts_off_blur = rf_data["off"]["rts_convolution"]["data"] plot_fields(rts_on_blur, rts_off_blur, ax_on, ax_off, cbar_axes=ax_cbar, cmap=cmap) + def plot_p_values(rf_data, ax_on, ax_off, ax_cbar=None, cmap=DEFAULT_CMAP): - pvalues_on = rf_data['on']['pvalues']['data'] - pvalues_off = rf_data['off']['pvalues']['data'] + pvalues_on = rf_data["on"]["pvalues"]["data"] + pvalues_off = rf_data["off"]["pvalues"]["data"] clim_max = max(pvalues_on.max(), pvalues_off.max()) - plot_fields(pvalues_on, pvalues_off, ax_on, ax_off, cbar_axes=ax_cbar, clim=(0, clim_max/2), cmap=cmap) + plot_fields(pvalues_on, pvalues_off, ax_on, ax_off, cbar_axes=ax_cbar, clim=(0, clim_max / 2), cmap=cmap) + def plot_mask(rf_data, ax_on, ax_off, ax_cbar=None, cmap=DEFAULT_CMAP): - pvalues_on = rf_data['on']['pvalues']['data'] - pvalues_off = rf_data['off']['pvalues']['data'] + pvalues_on = rf_data["on"]["pvalues"]["data"] + pvalues_off = rf_data["off"]["pvalues"]["data"] rf_on = pvalues_on.copy() rf_off = pvalues_off.copy() - rf_on[np.logical_not(rf_data['on']['fdr_mask']['data'].sum(axis=0))] = np.nan - rf_off[np.logical_not(rf_data['off']['fdr_mask']['data'].sum(axis=0))] = np.nan + rf_on[np.logical_not(rf_data["on"]["fdr_mask"]["data"].sum(axis=0))] = np.nan + rf_off[np.logical_not(rf_data["off"]["fdr_mask"]["data"].sum(axis=0))] = np.nan plot_fields(rf_on, rf_off, ax_on, ax_off, cbar_axes=ax_cbar, cmap=cmap) -def plot_gaussian_fit(rf_data, ax_on, ax_off, ax_cbar=None, cmap=DEFAULT_CMAP): - gf_on_exists = 'gaussian_fit' in rf_data['on'] - gf_off_exists = 'gaussian_fit' in rf_data['off'] +def plot_gaussian_fit(rf_data, ax_on, ax_off, ax_cbar=None, cmap=DEFAULT_CMAP): + gf_on_exists = "gaussian_fit" in rf_data["on"] + gf_off_exists = "gaussian_fit" in rf_data["off"] if not gf_on_exists and not gf_off_exists: return - img_data_on = rf_data['on']['gaussian_fit']['data'].sum(axis=0) if gf_on_exists else None - img_data_off = rf_data['off']['gaussian_fit']['data'].sum(axis=0) if gf_off_exists else None + img_data_on = rf_data["on"]["gaussian_fit"]["data"].sum(axis=0) if gf_on_exists else None + img_data_off = rf_data["off"]["gaussian_fit"]["data"].sum(axis=0) if gf_off_exists else None if gf_on_exists and gf_off_exists: plot_fields(img_data_on, img_data_off, ax_on, ax_off, cbar_axes=ax_cbar, cmap=cmap) @@ -181,21 +198,22 @@ def plot_gaussian_fit(rf_data, ax_on, ax_off, ax_cbar=None, cmap=DEFAULT_CMAP): plot_fields(img_data_on, img_data_off, ax_on, ax_off, cbar_axes=ax_cbar, cmap=cmap) + def plot_receptive_field_data(rf, lsn, show=True, save_file_name=None, close=True, cmap=DEFAULT_CMAP): - cell_index = rf['attrs']['cell_index'] + cell_index = rf["attrs"]["cell_index"] # Prepare plotting figure:n number_of_major_rows = 7 if lsn else 6 pwidth = 1.7 pheight = 1.0 - fig = plt.figure(figsize=(pwidth*2.3, pheight*number_of_major_rows)) - gsp = gridspec.GridSpec(number_of_major_rows, 3, width_ratios=[1,1,.1], right=0.9) + fig = plt.figure(figsize=(pwidth * 2.3, pheight * number_of_major_rows)) + gsp = gridspec.GridSpec(number_of_major_rows, 3, width_ratios=[1, 1, 0.1], right=0.9) ax_list = [] # Plot chi-square summary: row = 0 - curr_axes = fig.add_subplot(gsp[row,:2]) - cbar_axes = fig.add_subplot(gsp[row,-1]) + curr_axes = fig.add_subplot(gsp[row, :2]) + cbar_axes = fig.add_subplot(gsp[row, -1]) ax_list += [curr_axes] plot_chi_square_summary(rf, ax=curr_axes, cax=cbar_axes, cmap=cmap) @@ -212,7 +230,7 @@ def plot_receptive_field_data(rf, lsn, show=True, save_file_name=None, close=Tru row += 1 curr_on_axes = fig.add_subplot(gsp[row, 0]) curr_off_axes = fig.add_subplot(gsp[row, 1]) - cbar_axes = fig.add_subplot(gsp[row,2]) + cbar_axes = fig.add_subplot(gsp[row, 2]) ax_list += [curr_on_axes, curr_off_axes] plot_rts_summary(rf, curr_on_axes, curr_off_axes, cbar_axes, cmap=cmap) @@ -220,7 +238,7 @@ def plot_receptive_field_data(rf, lsn, show=True, save_file_name=None, close=Tru row += 1 curr_on_axes = fig.add_subplot(gsp[row, 0]) curr_off_axes = fig.add_subplot(gsp[row, 1]) - cbar_axes = fig.add_subplot(gsp[row,2]) + cbar_axes = fig.add_subplot(gsp[row, 2]) ax_list += [curr_on_axes, curr_off_axes] plot_rts_blur_summary(rf, curr_on_axes, curr_off_axes, cbar_axes, cmap=cmap) @@ -228,7 +246,7 @@ def plot_receptive_field_data(rf, lsn, show=True, save_file_name=None, close=Tru row += 1 curr_on_axes = fig.add_subplot(gsp[row, 0]) curr_off_axes = fig.add_subplot(gsp[row, 1]) - cbar_axes = fig.add_subplot(gsp[row,2]) + cbar_axes = fig.add_subplot(gsp[row, 2]) ax_list += [curr_on_axes, curr_off_axes] plot_p_values(rf, curr_on_axes, curr_off_axes, cbar_axes, cmap=cmap) @@ -236,7 +254,7 @@ def plot_receptive_field_data(rf, lsn, show=True, save_file_name=None, close=Tru row += 1 curr_on_axes = fig.add_subplot(gsp[row, 0]) curr_off_axes = fig.add_subplot(gsp[row, 1]) - cbar_axes = fig.add_subplot(gsp[row,2]) + cbar_axes = fig.add_subplot(gsp[row, 2]) ax_list += [curr_on_axes, curr_off_axes] plot_mask(rf, curr_on_axes, curr_off_axes, cbar_axes, cmap=cmap) @@ -244,7 +262,7 @@ def plot_receptive_field_data(rf, lsn, show=True, save_file_name=None, close=Tru row += 1 curr_on_axes = fig.add_subplot(gsp[row, 0]) curr_off_axes = fig.add_subplot(gsp[row, 1]) - cbar_axes = fig.add_subplot(gsp[row,2]) + cbar_axes = fig.add_subplot(gsp[row, 2]) ax_list += [curr_on_axes, curr_off_axes] plot_gaussian_fit(rf, curr_on_axes, curr_off_axes, cbar_axes, cmap=cmap) @@ -253,7 +271,7 @@ def plot_receptive_field_data(rf, lsn, show=True, save_file_name=None, close=Tru plt.subplots_adjust(top=0.95) for ax in ax_list: - ax.set_adjustable('box-forced') + ax.set_adjustable("box-forced") if save_file_name is not None: fig.savefig(save_file_name) diff --git a/allensdk/brain_observatory/roi_masks.py b/allensdk/brain_observatory/roi_masks.py index 39d8873ec5..f3cae5744e 100644 --- a/allensdk/brain_observatory/roi_masks.py +++ b/allensdk/brain_observatory/roi_masks.py @@ -47,7 +47,7 @@ class Mask(object): - ''' + """ Abstract class to represent image segmentation mask. Its two main subclasses are RoiMask and NeuropilMask. The former represents the mask of a region of interest (ROI), such as a cell observed in @@ -70,18 +70,18 @@ class Mask(object): mask_group: integer User-defined number to help put masks into different categories - ''' + """ @property def overlaps_motion_border(self): # flags like this are now in self.flags, patch for backwards compatibility - return 'overlaps_motion_border' in self.flags + return "overlaps_motion_border" in self.flags def __init__(self, image_w, image_h, label, mask_group): - ''' + """ Mask class constructor. The Mask class is designed to be abstract and it should not be instantiated directly. - ''' + """ self.img_rows = image_h self.img_cols = image_w @@ -104,7 +104,7 @@ def __str__(self): return "%s: TL=%d,%d w,h=%d,%d\n%s" % (self.label, self.x, self.y, self.width, self.height, str(self.mask)) def init_by_pixels(self, border, pix_list): - ''' + """ Initialize mask using a list of mask pixels Parameters @@ -114,7 +114,7 @@ def init_by_pixels(self, border, pix_list): pix_list: integer[][2] List of pixel coordinates (x,y) that define the mask - ''' + """ assert pix_list.shape[1] == 2, "Pixel list not properly formed" array = np.zeros((self.img_rows, self.img_cols), dtype=bool) @@ -124,20 +124,20 @@ def init_by_pixels(self, border, pix_list): self.init_by_mask(border, array) def get_mask_plane(self): - ''' + """ Returns mask content on full-size image plane Returns ------- numpy 2D array [img_rows][img_cols] - ''' + """ mask = np.zeros((self.img_rows, self.img_cols)) - mask[self.y:self.y + self.height, self.x:self.x + self.width] = self.mask + mask[self.y : self.y + self.height, self.x : self.x + self.width] = self.mask return mask def create_roi_mask(image_w, image_h, border, pix_list=None, roi_mask=None, label=None, mask_group=-1): - ''' + """ Conveninece function to create and initializes an RoiMask Parameters @@ -183,7 +183,7 @@ def create_roi_mask(image_w, image_h, border, pix_list=None, roi_mask=None, labe Returns ------- RoiMask object - ''' + """ m = RoiMask(image_w, image_h, label, mask_group) if pix_list is not None: m.init_by_pixels(border, pix_list) @@ -195,9 +195,8 @@ def create_roi_mask(image_w, image_h, border, pix_list=None, roi_mask=None, labe class RoiMask(Mask): - def __init__(self, image_w, image_h, label, mask_group): - ''' + """ RoiMask class constructor Parameters @@ -213,11 +212,11 @@ def __init__(self, image_w, image_h, label, mask_group): mask_group: integer User-defined number to help put masks into different categories - ''' + """ super(RoiMask, self).__init__(image_w, image_h, label, mask_group) def init_by_mask(self, border, array): - ''' + """ Initialize mask using spatial mask Parameters @@ -228,11 +227,11 @@ def init_by_mask(self, border, array): roi_mask: integer[image height][image width] Image-sized array that describes the mask. Active parts of the mask should have values >0. Background pixels must be zero - ''' + """ px = np.argwhere(array) if len(px) == 0: - self.flags.add('zero_pixels') + self.flags.add("zero_pixels") return (top, left), (bottom, right) = px.min(0), px.max(0) @@ -246,20 +245,20 @@ def init_by_mask(self, border, array): # if ROI crosses border, it's considered invalid if left < l_inset or right > r_inset: - self.flags.add('overlaps_motion_border') + self.flags.add("overlaps_motion_border") if top < t_inset or bottom > b_inset: - self.flags.add('overlaps_motion_border') + self.flags.add("overlaps_motion_border") # self.x = left self.width = right - left + 1 self.y = top self.height = bottom - top + 1 # make copy of mask - self.mask = array[top:bottom + 1, left:right + 1] + self.mask = array[top : bottom + 1, left : right + 1] def create_neuropil_mask(roi, border, combined_binary_mask, label=None): - ''' + """ Conveninece function to create and initializes a Neuropil mask. Neuropil masks are defined as the region around an ROI, up to 13 pixels out, that does not include other ROIs @@ -271,7 +270,7 @@ def create_neuropil_mask(roi, border, combined_binary_mask, label=None): The ROI that the neuropil masks will be based on border: float[4] - Border widths on the [right, left, down, up] sides. The resulting + Border widths on the [right, left, down, up] sides. The resulting neuropil mask will not include pixels falling into a border. combined_binary_mask @@ -289,28 +288,25 @@ def create_neuropil_mask(roi, border, combined_binary_mask, label=None): Returns ------- NeuropilMask object - ''' + """ # combined_binary_mask is a bitmap union of ALL ROI masks # create a binary mask of the ROI binary_mask = np.zeros((roi.img_rows, roi.img_cols)) - binary_mask[roi.y:roi.y + roi.height, roi.x:roi.x + roi.width] = roi.mask + binary_mask[roi.y : roi.y + roi.height, roi.x : roi.x + roi.width] = roi.mask binary_mask = binary_mask > 0 # dilate the mask - binary_mask_dilated = morphology.binary_dilation( - binary_mask, structure=np.ones((3, 3)), iterations=13) # T/F + binary_mask_dilated = morphology.binary_dilation(binary_mask, structure=np.ones((3, 3)), iterations=13) # T/F # eliminate ROIs from the dilation binary_mask_dilated = binary_mask_dilated > combined_binary_mask # create mask from binary dilation - m = NeuropilMask(w=roi.img_cols, h=roi.img_rows, - label=label, mask_group=roi.mask_group) + m = NeuropilMask(w=roi.img_cols, h=roi.img_rows, label=label, mask_group=roi.mask_group) m.init_by_mask(border, binary_mask_dilated) return m class NeuropilMask(Mask): - def __init__(self, w, h, label, mask_group): - ''' + """ NeuropilMask class constructor. This class should be created by calling create_neuropil_mask() @@ -321,26 +317,26 @@ def __init__(self, w, h, label, mask_group): mask_group: integer User-defined number to help put masks into different categories - ''' + """ super(NeuropilMask, self).__init__(w, h, label, mask_group) def init_by_mask(self, border, array): - ''' + """ Initialize mask using spatial mask Parameters ---------- border: float[4] - Border widths on the [right, left, down, up] sides. The resulting + Border widths on the [right, left, down, up] sides. The resulting neuropil mask will not include pixels falling into a border. array: integer[image height][image width] Image-sized array that describes the mask. Active parts of the mask should have values >0. Background pixels must be zero - ''' + """ px = np.argwhere(array) if len(px) == 0: - self.flags.add('zero_pixels') + self.flags.add("zero_pixels") return (top, left), (bottom, right) = px.min(0), px.max(0) @@ -375,39 +371,32 @@ def init_by_mask(self, border, array): self.y = top self.height = bottom - top + 1 # make copy of mask - self.mask = array[top:bottom + 1, left:right + 1] + self.mask = array[top : bottom + 1, left : right + 1] + def validate_mask(mask): - '''Check a given roi or neuropil mask for (a subset of) disqualifying problems. - ''' + """Check a given roi or neuropil mask for (a subset of) disqualifying problems.""" exclusions = [] - if 'zero_pixels' in mask.flags or mask.mask.sum() == 0: - + if "zero_pixels" in mask.flags or mask.mask.sum() == 0: if isinstance(mask, NeuropilMask): - label = 'empty_neuropil_mask' + label = "empty_neuropil_mask" elif isinstance(mask, RoiMask): - label = 'empty_roi_mask' + label = "empty_roi_mask" else: - label = 'zero_pixels' + label = "zero_pixels" - exclusions.append({ - 'roi_id': mask.label, - 'exclusion_label_name': label - }) + exclusions.append({"roi_id": mask.label, "exclusion_label_name": label}) - if 'overlaps_motion_border' in mask.flags: - exclusions.append({ - 'roi_id': mask.label, - 'exclusion_label_name': 'motion_border' - }) + if "overlaps_motion_border" in mask.flags: + exclusions.append({"roi_id": mask.label, "exclusion_label_name": "motion_border"}) return exclusions - + def calculate_traces(stack, mask_list, block_size=1000): - ''' + """ Calculates the average response of the specified masks in the image stack @@ -423,7 +412,7 @@ def calculate_traces(stack, mask_list, block_size=1000): ------- float[number masks][number frames] This is the average response for each Mask in each image frame - ''' + """ traces = np.zeros((len(mask_list), stack.shape[0]), dtype=float) num_frames = stack.shape[0] @@ -434,14 +423,13 @@ def calculate_traces(stack, mask_list, block_size=1000): exclusions = [] for i, mask in enumerate(mask_list): - current_exclusions = validate_mask(mask) if len(current_exclusions) > 0: - traces[i,:] = np.nan + traces[i, :] = np.nan valid_masks[i] = False exclusions.extend(current_exclusions) reasons = ", ".join([item["exclusion_label_name"] for item in current_exclusions]) - logging.warning("unable to extract traces for mask \"{}\": {} ".format(mask.label, reasons)) + logging.warning('unable to extract traces for mask "{}": {} '.format(mask.label, reasons)) continue if not isinstance(mask.mask, np.ndarray): @@ -452,25 +440,25 @@ def calculate_traces(stack, mask_list, block_size=1000): for frame_num in range(0, num_frames, block_size): if frame_num % block_size == 0: logging.debug("frame " + str(frame_num) + " of " + str(num_frames)) - frames = stack[frame_num:frame_num+block_size] + frames = stack[frame_num : frame_num + block_size] for i in range(len(mask_list)): if not valid_masks[i]: continue mask = mask_list[i] - subframe = frames[:,mask.y:mask.y + mask.height, - mask.x:mask.x + mask.width] + subframe = frames[:, mask.y : mask.y + mask.height, mask.x : mask.x + mask.width] total = subframe[:, mask.mask].sum(axis=1) - traces[i, frame_num:frame_num+block_size] = total / mask_areas[i] + traces[i, frame_num : frame_num + block_size] = total / mask_areas[i] return traces, exclusions + def calculate_roi_and_neuropil_traces(movie_h5, roi_mask_list, motion_border): - """ get roi and neuropil masks """ + """get roi and neuropil masks""" - # a combined binary mask for all ROIs (this is used to + # a combined binary mask for all ROIs (this is used to # subtracted ROIs from annuli mask_array = create_roi_mask_array(roi_mask_list) combined_mask = mask_array.max(axis=0) @@ -484,7 +472,7 @@ def calculate_roi_and_neuropil_traces(movie_h5, roi_mask_list, motion_border): neuropil_masks.append(nmask) num_rois = len(roi_mask_list) - combined_list = roi_mask_list + neuropil_masks # read the large image stack only once + combined_list = roi_mask_list + neuropil_masks # read the large image stack only once with h5py.File(movie_h5, "r") as movie_f: stack_frames = movie_f["data"] @@ -499,7 +487,7 @@ def calculate_roi_and_neuropil_traces(movie_h5, roi_mask_list, motion_border): def create_roi_mask_array(rois): - '''Create full image mask array from list of RoiMasks. + """Create full image mask array from list of RoiMasks. Parameters ---------- @@ -510,7 +498,7 @@ def create_roi_mask_array(rois): ------- np.ndarray: NxWxH array Boolean array of of len(rois) image masks. - ''' + """ if rois: height = rois[0].img_rows width = rois[0].img_cols @@ -520,4 +508,3 @@ def create_roi_mask_array(rois): else: masks = None return masks - diff --git a/allensdk/brain_observatory/running_speed.py b/allensdk/brain_observatory/running_speed.py index c579a15af5..8b351fe098 100644 --- a/allensdk/brain_observatory/running_speed.py +++ b/allensdk/brain_observatory/running_speed.py @@ -4,14 +4,14 @@ class RunningSpeed(NamedTuple): - ''' Describes the rate at which an experimental subject ran during a session. + """Describes the rate at which an experimental subject ran during a session. values : np.ndarray running speed (cm/s) at each sample point timestamps : np.ndarray The time at which each sample was collected (s). - ''' + """ timestamps: np.ndarray values: np.ndarray diff --git a/allensdk/brain_observatory/session_analysis.py b/allensdk/brain_observatory/session_analysis.py index 8ff933aa79..dc44039cab 100644 --- a/allensdk/brain_observatory/session_analysis.py +++ b/allensdk/brain_observatory/session_analysis.py @@ -40,11 +40,9 @@ from .drifting_gratings import DriftingGratings from .natural_movie import NaturalMovie -from allensdk.core.brain_observatory_nwb_data_set \ - import BrainObservatoryNwbDataSet +from allensdk.core.brain_observatory_nwb_data_set import BrainObservatoryNwbDataSet from . import stimulus_info -from allensdk.brain_observatory.brain_observatory_exceptions \ - import BrainObservatoryAnalysisException +from allensdk.brain_observatory.brain_observatory_exceptions import BrainObservatoryAnalysisException from . import brain_observatory_plotting as cp import argparse import logging @@ -53,28 +51,26 @@ from allensdk.deprecated import deprecated - def multi_dataframe_merge(dfs): - """ merge a number of pd.DataFrames into a single dataframe on their index columns. - If any columns are duplicated, prefer the first occuring instance of the column """ + """merge a number of pd.DataFrames into a single dataframe on their index columns. + If any columns are duplicated, prefer the first occuring instance of the column""" out_df = None for _, df in enumerate(dfs): if out_df is None: out_df = df else: - out_df = out_df.merge(df, left_index=True, - right_index=True, suffixes=['', '_deleteme']) + out_df = out_df.merge(df, left_index=True, right_index=True, suffixes=["", "_deleteme"]) - bad_columns = set([c for c in out_df.columns if c.endswith('deleteme')]) + bad_columns = set([c for c in out_df.columns if c.endswith("deleteme")]) out_df.drop(list(bad_columns), axis=1, inplace=True) return out_df class SessionAnalysis(object): - """ - Run all of the stimulus-specific analyses associated with a single experiment session. + """ + Run all of the stimulus-specific analyses associated with a single experiment session. Parameters ---------- @@ -83,27 +79,27 @@ class SessionAnalysis(object): save_path: string, path to HDF5 file to store outputs. Recommended NOT to modify the NWB file. """ - _log = logging.getLogger('allensdk.brain_observatory.session_analysis') + _log = logging.getLogger("allensdk.brain_observatory.session_analysis") def __init__(self, nwb_path, save_path): self.nwb = BrainObservatoryNwbDataSet(nwb_path) self.save_path = save_path self.save_dir = os.path.dirname(save_path) - self.metrics_a = dict(cell={},experiment={}) - self.metrics_b = dict(cell={},experiment={}) - self.metrics_c = dict(cell={},experiment={}) + self.metrics_a = dict(cell={}, experiment={}) + self.metrics_b = dict(cell={}, experiment={}) + self.metrics_c = dict(cell={}, experiment={}) self.metadata = self.nwb.get_metadata() def append_metadata(self, df): - """ Append the metadata fields from the NWB file as columns to a pd.DataFrame """ + """Append the metadata fields from the NWB file as columns to a pd.DataFrame""" for k, v in self.metadata.items(): df[k] = v def save_session_a(self, dg, nm1, nm3, peak): - """ Save the output of session A analysis to self.save_path. + """Save the output of session A analysis to self.save_path. Parameters ---------- @@ -124,28 +120,28 @@ def save_session_a(self, dg, nm1, nm3, peak): nwb = BrainObservatoryNwbDataSet(self.save_path) nwb.save_analysis_dataframes( - ('stim_table_dg', dg.stim_table), - ('sweep_response_dg', dg.sweep_response), - ('mean_sweep_response_dg', dg.mean_sweep_response), - ('peak', peak), - ('sweep_response_nm1', nm1.sweep_response), - ('stim_table_nm1', nm1.stim_table), - ('sweep_response_nm3', nm3.sweep_response)) + ("stim_table_dg", dg.stim_table), + ("sweep_response_dg", dg.sweep_response), + ("mean_sweep_response_dg", dg.mean_sweep_response), + ("peak", peak), + ("sweep_response_nm1", nm1.sweep_response), + ("stim_table_nm1", nm1.stim_table), + ("sweep_response_nm3", nm3.sweep_response), + ) nwb.save_analysis_arrays( - ('response_dg', dg.response), - ('binned_cells_sp', nm1.binned_cells_sp), - ('binned_cells_vis', nm1.binned_cells_vis), - ('binned_dx_sp', nm1.binned_dx_sp), - ('binned_dx_vis', nm1.binned_dx_vis), - ('noise_corr_dg', dg.noise_correlation), - ('signal_corr_dg', dg.signal_correlation), - ('rep_similarity_dg', dg.representational_similarity) - ) - + ("response_dg", dg.response), + ("binned_cells_sp", nm1.binned_cells_sp), + ("binned_cells_vis", nm1.binned_cells_vis), + ("binned_dx_sp", nm1.binned_dx_sp), + ("binned_dx_vis", nm1.binned_dx_vis), + ("noise_corr_dg", dg.noise_correlation), + ("signal_corr_dg", dg.signal_correlation), + ("rep_similarity_dg", dg.representational_similarity), + ) def save_session_b(self, sg, nm1, ns, peak): - """ Save the output of session B analysis to self.save_path. + """Save the output of session B analysis to self.save_path. Parameters ---------- @@ -164,33 +160,34 @@ def save_session_b(self, sg, nm1, ns, peak): nwb = BrainObservatoryNwbDataSet(self.save_path) nwb.save_analysis_dataframes( - ('stim_table_sg', sg.stim_table), - ('sweep_response_sg', sg.sweep_response), - ('mean_sweep_response_sg', sg.mean_sweep_response), - ('sweep_response_nm1', nm1.sweep_response), - ('stim_table_nm1', nm1.stim_table), - ('sweep_response_ns', ns.sweep_response), - ('stim_table_ns', ns.stim_table), - ('mean_sweep_response_ns', ns.mean_sweep_response), - ('peak', peak)) + ("stim_table_sg", sg.stim_table), + ("sweep_response_sg", sg.sweep_response), + ("mean_sweep_response_sg", sg.mean_sweep_response), + ("sweep_response_nm1", nm1.sweep_response), + ("stim_table_nm1", nm1.stim_table), + ("sweep_response_ns", ns.sweep_response), + ("stim_table_ns", ns.stim_table), + ("mean_sweep_response_ns", ns.mean_sweep_response), + ("peak", peak), + ) nwb.save_analysis_arrays( - ('response_sg', sg.response), - ('response_ns', ns.response), - ('binned_cells_sp', nm1.binned_cells_sp), - ('binned_cells_vis', nm1.binned_cells_vis), - ('binned_dx_sp', nm1.binned_dx_sp), - ('binned_dx_vis', nm1.binned_dx_vis), - ('noise_corr_sg', sg.noise_correlation), - ('signal_corr_sg', sg.signal_correlation), - ('rep_similarity_sg', sg.representational_similarity), - ('noise_corr_ns', ns.noise_correlation), - ('signal_corr_ns', ns.signal_correlation), - ('rep_similarity_ns', ns.representational_similarity) - ) + ("response_sg", sg.response), + ("response_ns", ns.response), + ("binned_cells_sp", nm1.binned_cells_sp), + ("binned_cells_vis", nm1.binned_cells_vis), + ("binned_dx_sp", nm1.binned_dx_sp), + ("binned_dx_vis", nm1.binned_dx_vis), + ("noise_corr_sg", sg.noise_correlation), + ("signal_corr_sg", sg.signal_correlation), + ("rep_similarity_sg", sg.representational_similarity), + ("noise_corr_ns", ns.noise_correlation), + ("signal_corr_ns", ns.signal_correlation), + ("rep_similarity_ns", ns.representational_similarity), + ) def save_session_c(self, lsn, nm1, nm2, peak): - """ Save the output of session C analysis to self.save_path. + """Save the output of session C analysis to self.save_path. Parameters ---------- @@ -211,34 +208,38 @@ def save_session_c(self, lsn, nm1, nm2, peak): nwb = BrainObservatoryNwbDataSet(self.save_path) nwb.save_analysis_dataframes( - ('stim_table_lsn', lsn.stim_table), - ('sweep_response_nm1', nm1.sweep_response), - ('peak', peak), - ('sweep_response_nm2', nm2.sweep_response), - ('sweep_response_lsn', lsn.sweep_response), - ('mean_sweep_response_lsn', lsn.mean_sweep_response)) + ("stim_table_lsn", lsn.stim_table), + ("sweep_response_nm1", nm1.sweep_response), + ("peak", peak), + ("sweep_response_nm2", nm2.sweep_response), + ("sweep_response_lsn", lsn.sweep_response), + ("mean_sweep_response_lsn", lsn.mean_sweep_response), + ) nwb.save_analysis_arrays( - ('receptive_field_lsn', lsn.receptive_field), - ('mean_response_lsn', lsn.mean_response), - ('binned_dx_sp', nm1.binned_dx_sp), - ('binned_dx_vis', nm1.binned_dx_vis), - ('binned_cells_sp', nm1.binned_cells_sp), - ('binned_cells_vis', nm1.binned_cells_vis)) + ("receptive_field_lsn", lsn.receptive_field), + ("mean_response_lsn", lsn.mean_response), + ("binned_dx_sp", nm1.binned_dx_sp), + ("binned_dx_vis", nm1.binned_dx_vis), + ("binned_cells_sp", nm1.binned_cells_sp), + ("binned_cells_vis", nm1.binned_cells_vis), + ) - LocallySparseNoise.save_cell_index_receptive_field_analysis(lsn.cell_index_receptive_field_analysis_data, nwb, stimulus_info.LOCALLY_SPARSE_NOISE) + LocallySparseNoise.save_cell_index_receptive_field_analysis( + lsn.cell_index_receptive_field_analysis_data, nwb, stimulus_info.LOCALLY_SPARSE_NOISE + ) - def save_session_c2(self, lsn4, lsn8, nm1, nm2, peak): - """ Save the output of session C2 analysis to self.save_path. + def save_session_c2(self, lsn4, lsn8, nm1, nm2, peak): + """Save the output of session C2 analysis to self.save_path. Parameters ---------- lsn4: LocallySparseNoise instance - This LocallySparseNoise instance should have been created with + This LocallySparseNoise instance should have been created with self.stimulus = stimulus_info.LOCALLY_SPARSE_NOISE_4DEG. lsn8: LocallySparseNoise instance - This LocallySparseNoise instance should have been created with + This LocallySparseNoise instance should have been created with self.stimulus = stimulus_info.LOCALLY_SPARSE_NOISE_8DEG. nm1: NaturalMovie instance @@ -256,41 +257,44 @@ def save_session_c2(self, lsn4, lsn8, nm1, nm2, peak): nwb = BrainObservatoryNwbDataSet(self.save_path) nwb.save_analysis_dataframes( - ('stim_table_lsn4', lsn4.stim_table), - ('stim_table_lsn8', lsn8.stim_table), - ('sweep_response_nm1', nm1.sweep_response), - ('peak', peak), - ('sweep_response_nm2', nm2.sweep_response), - ('sweep_response_lsn4', lsn4.sweep_response), - ('sweep_response_lsn8', lsn8.sweep_response), - ('mean_sweep_response_lsn4', lsn4.mean_sweep_response), - ('mean_sweep_response_lsn8', lsn8.mean_sweep_response)) - - merge_mean_response = LocallySparseNoise.merge_mean_response( - lsn4.mean_response, - lsn8.mean_response) + ("stim_table_lsn4", lsn4.stim_table), + ("stim_table_lsn8", lsn8.stim_table), + ("sweep_response_nm1", nm1.sweep_response), + ("peak", peak), + ("sweep_response_nm2", nm2.sweep_response), + ("sweep_response_lsn4", lsn4.sweep_response), + ("sweep_response_lsn8", lsn8.sweep_response), + ("mean_sweep_response_lsn4", lsn4.mean_sweep_response), + ("mean_sweep_response_lsn8", lsn8.mean_sweep_response), + ) + + merge_mean_response = LocallySparseNoise.merge_mean_response(lsn4.mean_response, lsn8.mean_response) nwb.save_analysis_arrays( - ('mean_response_lsn4', lsn4.mean_response), - ('mean_response_lsn8', lsn8.mean_response), - ('receptive_field_lsn4', lsn4.receptive_field), - ('receptive_field_lsn8', lsn8.receptive_field), - ('merge_mean_response', merge_mean_response), - ('binned_dx_sp', nm1.binned_dx_sp), - ('binned_dx_vis', nm1.binned_dx_vis), - ('binned_cells_sp', nm1.binned_cells_sp), - ('binned_cells_vis', nm1.binned_cells_vis)) - - LocallySparseNoise.save_cell_index_receptive_field_analysis(lsn4.cell_index_receptive_field_analysis_data, nwb, stimulus_info.LOCALLY_SPARSE_NOISE_4DEG) - LocallySparseNoise.save_cell_index_receptive_field_analysis(lsn8.cell_index_receptive_field_analysis_data, nwb, stimulus_info.LOCALLY_SPARSE_NOISE_8DEG) + ("mean_response_lsn4", lsn4.mean_response), + ("mean_response_lsn8", lsn8.mean_response), + ("receptive_field_lsn4", lsn4.receptive_field), + ("receptive_field_lsn8", lsn8.receptive_field), + ("merge_mean_response", merge_mean_response), + ("binned_dx_sp", nm1.binned_dx_sp), + ("binned_dx_vis", nm1.binned_dx_vis), + ("binned_cells_sp", nm1.binned_cells_sp), + ("binned_cells_vis", nm1.binned_cells_vis), + ) + + LocallySparseNoise.save_cell_index_receptive_field_analysis( + lsn4.cell_index_receptive_field_analysis_data, nwb, stimulus_info.LOCALLY_SPARSE_NOISE_4DEG + ) + LocallySparseNoise.save_cell_index_receptive_field_analysis( + lsn8.cell_index_receptive_field_analysis_data, nwb, stimulus_info.LOCALLY_SPARSE_NOISE_8DEG + ) def append_metrics_drifting_grating(self, metrics, dg): - """ Extract metrics from the DriftingGratings peak response table into a dictionary. """ + """Extract metrics from the DriftingGratings peak response table into a dictionary.""" metrics["osi_dg"] = dg.peak["osi_dg"] metrics["dsi_dg"] = dg.peak["dsi_dg"] - metrics["pref_dir_dg"] = [dg.orivals[i] - for i in dg.peak["ori_dg"].values] + metrics["pref_dir_dg"] = [dg.orivals[i] for i in dg.peak["ori_dg"].values] metrics["pref_tf_dg"] = [dg.tfvals[i] for i in dg.peak["tf_dg"].values] metrics["p_dg"] = dg.peak["ptest_dg"] metrics["g_osi_dg"] = dg.peak["cv_os_dg"] @@ -302,14 +306,12 @@ def append_metrics_drifting_grating(self, metrics, dg): metrics["peak_dff_dg"] = dg.peak["peak_dff_dg"] def append_metrics_static_grating(self, metrics, sg): - """ Extract metrics from the StaticGratings peak response table into a dictionary. """ + """Extract metrics from the StaticGratings peak response table into a dictionary.""" metrics["osi_sg"] = sg.peak["osi_sg"] - metrics["pref_ori_sg"] = [sg.orivals[i] - for i in sg.peak["ori_sg"].values] + metrics["pref_ori_sg"] = [sg.orivals[i] for i in sg.peak["ori_sg"].values] metrics["pref_sf_sg"] = [sg.sfvals[i] for i in sg.peak["sf_sg"].values] - metrics["pref_phase_sg"] = [sg.phasevals[i] - for i in sg.peak["phase_sg"].values] + metrics["pref_phase_sg"] = [sg.phasevals[i] for i in sg.peak["phase_sg"].values] metrics["p_sg"] = sg.peak["ptest_sg"] metrics["time_to_peak_sg"] = sg.peak["time_to_peak_sg"] metrics["run_mod_sg"] = sg.peak["run_modulation_sg"] @@ -320,7 +322,7 @@ def append_metrics_static_grating(self, metrics, sg): metrics["reliability_sg"] = sg.peak["reliability_sg"] def append_metrics_natural_scene(self, metrics, ns): - """ Extract metrics from the NaturalScenes peak response table into a dictionary. """ + """Extract metrics from the NaturalScenes peak response table into a dictionary.""" metrics["pref_image_ns"] = ns.peak["scene_ns"] metrics["p_ns"] = ns.peak["ptest_ns"] @@ -332,49 +334,47 @@ def append_metrics_natural_scene(self, metrics, ns): metrics["peak_dff_ns"] = ns.peak["peak_dff_ns"] def append_metrics_locally_sparse_noise(self, metrics, lsn): - """ Extract metrics from the LocallySparseNoise peak response table into a dictionary. """ - - metrics['rf_chi2_lsn'] = lsn.peak['rf_chi2_lsn'] - metrics['rf_area_on_lsn'] = lsn.peak['rf_area_on_lsn'] - metrics['rf_center_on_x_lsn'] = lsn.peak['rf_center_on_x_lsn'] - metrics['rf_center_on_y_lsn'] = lsn.peak['rf_center_on_y_lsn'] - metrics['rf_area_off_lsn'] = lsn.peak['rf_area_off_lsn'] - metrics['rf_center_off_x_lsn'] = lsn.peak['rf_center_off_x_lsn'] - metrics['rf_center_off_y_lsn'] = lsn.peak['rf_center_off_y_lsn'] - metrics['rf_distance_lsn'] = lsn.peak['rf_distance_lsn'] - metrics['rf_overlap_index_lsn'] = lsn.peak['rf_overlap_index_lsn'] + """Extract metrics from the LocallySparseNoise peak response table into a dictionary.""" + + metrics["rf_chi2_lsn"] = lsn.peak["rf_chi2_lsn"] + metrics["rf_area_on_lsn"] = lsn.peak["rf_area_on_lsn"] + metrics["rf_center_on_x_lsn"] = lsn.peak["rf_center_on_x_lsn"] + metrics["rf_center_on_y_lsn"] = lsn.peak["rf_center_on_y_lsn"] + metrics["rf_area_off_lsn"] = lsn.peak["rf_area_off_lsn"] + metrics["rf_center_off_x_lsn"] = lsn.peak["rf_center_off_x_lsn"] + metrics["rf_center_off_y_lsn"] = lsn.peak["rf_center_off_y_lsn"] + metrics["rf_distance_lsn"] = lsn.peak["rf_distance_lsn"] + metrics["rf_overlap_index_lsn"] = lsn.peak["rf_overlap_index_lsn"] def append_metrics_natural_movie_one(self, metrics, nma): - """ Extract metrics from the NaturalMovie(stimulus_info.NATURAL_MOVIE_ONE) peak response table into a dictionary. """ - metrics['reliability_nm1'] = nma.peak['response_reliability_nm1'] + """Extract metrics from the NaturalMovie(stimulus_info.NATURAL_MOVIE_ONE) peak response table into a dictionary.""" + metrics["reliability_nm1"] = nma.peak["response_reliability_nm1"] def append_metrics_natural_movie_two(self, metrics, nma): - """ Extract metrics from the NaturalMovie(stimulus_info.NATURAL_MOVIE_TWO) peak response table into a dictionary. """ - metrics['reliability_nm2'] = nma.peak['response_reliability_nm2'] + """Extract metrics from the NaturalMovie(stimulus_info.NATURAL_MOVIE_TWO) peak response table into a dictionary.""" + metrics["reliability_nm2"] = nma.peak["response_reliability_nm2"] def append_metrics_natural_movie_three(self, metrics, nma): - """ Extract metrics from the NaturalMovie(stimulus_info.NATURAL_MOVIE_THREE) peak response table into a dictionary. """ - metrics['reliability_nm3'] = nma.peak['response_reliability_nm3'] + """Extract metrics from the NaturalMovie(stimulus_info.NATURAL_MOVIE_THREE) peak response table into a dictionary.""" + metrics["reliability_nm3"] = nma.peak["response_reliability_nm3"] def append_experiment_metrics(self, metrics): - """ Extract stimulus-agnostic metrics from an experiment into a dictionary """ + """Extract stimulus-agnostic metrics from an experiment into a dictionary""" dxcm, dxtime = self.nwb.get_running_speed() - metrics['mean_running_speed'] = np.nanmean(dxcm) + metrics["mean_running_speed"] = np.nanmean(dxcm) def verify_roi_lists_equal(self, roi1, roi2): - """ TODO: replace this with simpler numpy comparisons """ + """TODO: replace this with simpler numpy comparisons""" if len(roi1) != len(roi2): - raise BrainObservatoryAnalysisException( - "Error -- ROI lists are of different length") + raise BrainObservatoryAnalysisException("Error -- ROI lists are of different length") for i in range(len(roi1)): if roi1[i] != roi2[i]: - raise BrainObservatoryAnalysisException( - "Error -- ROI lists have different entries") + raise BrainObservatoryAnalysisException("Error -- ROI lists have different entries") def session_a(self, plot_flag=False, save_flag=True): - """ Run stimulus-specific analysis for natural movie one, natural movie three, and drifting gratings. + """Run stimulus-specific analysis for natural movie one, natural movie three, and drifting gratings. The input NWB be for a stimulus_info.THREE_SESSION_A experiment. Parameters @@ -386,8 +386,8 @@ def session_a(self, plot_flag=False, save_flag=True): Whether to save the output of analysis to self.save_path upon completion. """ - nm1 = NaturalMovie(self.nwb, 'natural_movie_one') - nm3 = NaturalMovie(self.nwb, 'natural_movie_three') + nm1 = NaturalMovie(self.nwb, "natural_movie_one") + nm3 = NaturalMovie(self.nwb, "natural_movie_three") dg = DriftingGratings(self.nwb) dg.noise_correlation, _, _, _ = dg.get_noise_correlation() @@ -395,14 +395,13 @@ def session_a(self, plot_flag=False, save_flag=True): dg.representational_similarity, _ = dg.get_representational_similarity() SessionAnalysis._log.info("Session A analyzed") - peak = multi_dataframe_merge( - [nm1.peak_run, dg.peak, nm1.peak, nm3.peak]) + peak = multi_dataframe_merge([nm1.peak_run, dg.peak, nm1.peak, nm3.peak]) - self.append_metrics_drifting_grating(self.metrics_a['cell'], dg) - self.append_metrics_natural_movie_one(self.metrics_a['cell'], nm1) - self.append_metrics_natural_movie_three(self.metrics_a['cell'], nm3) - self.append_experiment_metrics(self.metrics_a['experiment']) - self.metrics_a['cell']['roi_id'] = dg.roi_id + self.append_metrics_drifting_grating(self.metrics_a["cell"], dg) + self.append_metrics_natural_movie_one(self.metrics_a["cell"], nm1) + self.append_metrics_natural_movie_three(self.metrics_a["cell"], nm3) + self.append_experiment_metrics(self.metrics_a["experiment"]) + self.metrics_a["cell"]["roi_id"] = dg.roi_id self.append_metadata(peak) @@ -414,7 +413,7 @@ def session_a(self, plot_flag=False, save_flag=True): cp.plot_drifting_grating_traces(dg, self.save_dir) def session_b(self, plot_flag=False, save_flag=True): - """ Run stimulus-specific analysis for natural scenes, static gratings, and natural movie one. + """Run stimulus-specific analysis for natural scenes, static gratings, and natural movie one. The input NWB be for a stimulus_info.THREE_SESSION_B experiment. Parameters @@ -428,18 +427,17 @@ def session_b(self, plot_flag=False, save_flag=True): ns = NaturalScenes(self.nwb) sg = StaticGratings(self.nwb) - nm1 = NaturalMovie(self.nwb, 'natural_movie_one') + nm1 = NaturalMovie(self.nwb, "natural_movie_one") SessionAnalysis._log.info("Session B analyzed") - peak = multi_dataframe_merge( - [nm1.peak_run, sg.peak, ns.peak, nm1.peak]) + peak = multi_dataframe_merge([nm1.peak_run, sg.peak, ns.peak, nm1.peak]) self.append_metadata(peak) - self.append_metrics_static_grating(self.metrics_b['cell'], sg) - self.append_metrics_natural_scene(self.metrics_b['cell'], ns) - self.append_metrics_natural_movie_one(self.metrics_b['cell'], nm1) - self.append_experiment_metrics(self.metrics_b['experiment']) + self.append_metrics_static_grating(self.metrics_b["cell"], sg) + self.append_metrics_natural_scene(self.metrics_b["cell"], ns) + self.append_metrics_natural_movie_one(self.metrics_b["cell"], nm1) + self.append_experiment_metrics(self.metrics_b["experiment"]) self.verify_roi_lists_equal(sg.roi_id, ns.roi_id) - self.metrics_b['cell']['roi_id'] = sg.roi_id + self.metrics_b["cell"]["roi_id"] = sg.roi_id sg.noise_correlation, _, _, _ = sg.get_noise_correlation() sg.signal_correlation, _ = sg.get_signal_correlation() @@ -458,7 +456,7 @@ def session_b(self, plot_flag=False, save_flag=True): cp.plot_sg_traces(sg, self.save_dir) def session_c(self, plot_flag=False, save_flag=True): - """ Run stimulus-specific analysis for natural movie one, natural movie two, and locally sparse noise. + """Run stimulus-specific analysis for natural movie one, natural movie two, and locally sparse noise. The input NWB be for a stimulus_info.THREE_SESSION_C experiment. Parameters @@ -471,17 +469,17 @@ def session_c(self, plot_flag=False, save_flag=True): """ lsn = LocallySparseNoise(self.nwb, stimulus_info.LOCALLY_SPARSE_NOISE) - nm2 = NaturalMovie(self.nwb, 'natural_movie_two') - nm1 = NaturalMovie(self.nwb, 'natural_movie_one') + nm2 = NaturalMovie(self.nwb, "natural_movie_two") + nm1 = NaturalMovie(self.nwb, "natural_movie_one") SessionAnalysis._log.info("Session C analyzed") peak = multi_dataframe_merge([nm1.peak_run, nm1.peak, nm2.peak, lsn.peak]) self.append_metadata(peak) - self.append_metrics_locally_sparse_noise(self.metrics_c['cell'], lsn) - self.append_metrics_natural_movie_one(self.metrics_c['cell'], nm1) - self.append_metrics_natural_movie_two(self.metrics_c['cell'], nm2) - self.append_experiment_metrics(self.metrics_c['experiment']) - self.metrics_c['cell']['roi_id'] = nm1.roi_id + self.append_metrics_locally_sparse_noise(self.metrics_c["cell"], lsn) + self.append_metrics_natural_movie_one(self.metrics_c["cell"], nm1) + self.append_metrics_natural_movie_two(self.metrics_c["cell"], nm2) + self.append_experiment_metrics(self.metrics_c["experiment"]) + self.metrics_c["cell"]["roi_id"] = nm1.roi_id if save_flag: self.save_session_c(lsn, nm1, nm2, peak) @@ -491,7 +489,7 @@ def session_c(self, plot_flag=False, save_flag=True): cp.plot_lsn_traces(lsn, self.save_dir) def session_c2(self, plot_flag=False, save_flag=True): - """ Run stimulus-specific analysis for locally sparse noise (4 deg.), locally sparse noise (8 deg.), + """Run stimulus-specific analysis for locally sparse noise (4 deg.), locally sparse noise (8 deg.), natural movie one, and natural movie two. The input NWB be for a stimulus_info.THREE_SESSION_C2 experiment. Parameters @@ -506,11 +504,11 @@ def session_c2(self, plot_flag=False, save_flag=True): lsn4 = LocallySparseNoise(self.nwb, stimulus_info.LOCALLY_SPARSE_NOISE_4DEG) lsn8 = LocallySparseNoise(self.nwb, stimulus_info.LOCALLY_SPARSE_NOISE_8DEG) - nm2 = NaturalMovie(self.nwb, 'natural_movie_two') - nm1 = NaturalMovie(self.nwb, 'natural_movie_one') + nm2 = NaturalMovie(self.nwb, "natural_movie_two") + nm1 = NaturalMovie(self.nwb, "natural_movie_one") SessionAnalysis._log.info("Session C2 analyzed") - if self.nwb.get_metadata()['targeted_structure'] == 'VISp': + if self.nwb.get_metadata()["targeted_structure"] == "VISp": lsn_peak = lsn4 else: lsn_peak = lsn8 @@ -518,24 +516,24 @@ def session_c2(self, plot_flag=False, save_flag=True): peak = multi_dataframe_merge([nm1.peak_run, nm1.peak, nm2.peak, lsn_peak.peak]) self.append_metadata(peak) - self.append_metrics_locally_sparse_noise(self.metrics_c['cell'], lsn_peak) - self.append_metrics_natural_movie_one(self.metrics_c['cell'], nm1) - self.append_metrics_natural_movie_two(self.metrics_c['cell'], nm2) - self.append_experiment_metrics(self.metrics_c['experiment']) - self.metrics_c['cell']['roi_id'] = nm1.roi_id + self.append_metrics_locally_sparse_noise(self.metrics_c["cell"], lsn_peak) + self.append_metrics_natural_movie_one(self.metrics_c["cell"], nm1) + self.append_metrics_natural_movie_two(self.metrics_c["cell"], nm2) + self.append_experiment_metrics(self.metrics_c["experiment"]) + self.metrics_c["cell"]["roi_id"] = nm1.roi_id if save_flag: self.save_session_c2(lsn4, lsn8, nm1, nm2, peak) if plot_flag: - cp._plot_3sc(lsn4, nm1, nm2, self.save_dir, '_4deg') - cp._plot_3sc(lsn8, nm1, nm2, self.save_dir, '_8deg') - cp.plot_lsn_traces(lsn4, self.save_dir, '_4deg') - cp.plot_lsn_traces(lsn4, self.save_dir, '_8deg') + cp._plot_3sc(lsn4, nm1, nm2, self.save_dir, "_4deg") + cp._plot_3sc(lsn8, nm1, nm2, self.save_dir, "_8deg") + cp.plot_lsn_traces(lsn4, self.save_dir, "_4deg") + cp.plot_lsn_traces(lsn4, self.save_dir, "_8deg") def run_session_analysis(nwb_path, save_path, plot_flag=False, save_flag=True): - """ Inspect an NWB file to determine which experiment session was run + """Inspect an NWB file to determine which experiment session was run and compute all stimulus-specific analyses. Parameters @@ -580,13 +578,13 @@ def run_session_analysis(nwb_path, save_path, plot_flag=False, save_flag=True): return metrics -@deprecated('use the standalone version in bin/brain_observatory') +@deprecated("use the standalone version in bin/brain_observatory") def main(): parser = argparse.ArgumentParser() parser.add_argument("input_nwb") parser.add_argument("output_h5") - parser.add_argument("--plot", action='store_true') + parser.add_argument("--plot", action="store_true") args = parser.parse_args() logging.basicConfig() @@ -595,5 +593,5 @@ def main(): run_session_analysis(args.input_nwb, args.output_h5, args.plot) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/allensdk/brain_observatory/session_api_utils.py b/allensdk/brain_observatory/session_api_utils.py index 347c8d5e1b..6ec6f76f68 100644 --- a/allensdk/brain_observatory/session_api_utils.py +++ b/allensdk/brain_observatory/session_api_utils.py @@ -38,8 +38,7 @@ def is_equal(a: Any, b: Any) -> bool: return False return True elif isinstance(a, dict): - for (a_k, a_v), (b_k, b_v) in zip_longest(sorted(a.items()), - sorted(b.items())): + for (a_k, a_v), (b_k, b_v) in zip_longest(sorted(a.items()), sorted(b.items())): if (a_k != b_k) or (not is_equal(a_v, b_v)): return False return True @@ -76,20 +75,19 @@ def __init__(self, param_to_ignore, a_param_1: int, a_param_2: float, 'needs_data_refresh', and 'clear_updated_params' will be available. """ - def __init__(self, ignore: set = {'api'}): + def __init__(self, ignore: set = {"api"}): self._updated_params: set = set() self._ignore = ignore @classmethod def _get_param_signatures(cls) -> List[inspect.Parameter]: - init = getattr(cls, '__init__') + init = getattr(cls, "__init__") if init is object.__init__: # Class has a default __init__ and thus no params return [] init_signature = inspect.signature(init) # Filter out 'self' and '**kwargs' params - parameters = [p for p in init_signature.parameters.values() - if (p.name != 'self') and (p.kind != p.VAR_KEYWORD)] + parameters = [p for p in init_signature.parameters.values() if (p.name != "self") and (p.kind != p.VAR_KEYWORD)] return parameters @classmethod @@ -127,18 +125,20 @@ def set_params(self, **params): setattr(self, f"_{param}", value) self._updated_params.add(param) else: - warnings.warn(f"The value ({value}) for parameter " - f"'{param}' should be of type " - f"'{param_types[param]}' but is instead " - f"{type(value)}. It will remain as: " - f"{current_value} " - f"({type(current_value)}).", - stacklevel=2) + warnings.warn( + f"The value ({value}) for parameter " + f"'{param}' should be of type " + f"'{param_types[param]}' but is instead " + f"{type(value)}. It will remain as: " + f"{current_value} " + f"({type(current_value)}).", + stacklevel=2, + ) else: - warnings.warn(f"The parameter '{param}' is not valid " - f"and is being ignored! " - f"Possible params are: {valid_params}", - stacklevel=2) + warnings.warn( + f"The parameter '{param}' is not valid and is being ignored! Possible params are: {valid_params}", + stacklevel=2, + ) def needs_data_refresh(self, data_params: set) -> bool: """Check if specific params have been updated via `set_params()`""" @@ -149,11 +149,14 @@ def clear_updated_params(self, data_params: set): self._updated_params -= data_params -def sessions_are_equal(A, B, reraise=False, - ignore_keys: Optional[Dict[str, Set[str]]] = None, - skip_fields: Optional[Iterable] = None, - test_methods=False) \ - -> bool: +def sessions_are_equal( + A, + B, + reraise=False, + ignore_keys: Optional[Dict[str, Set[str]]] = None, + skip_fields: Optional[Iterable] = None, + test_methods=False, +) -> bool: """Check if two Session objects are equal (have same property and get method values). @@ -182,7 +185,6 @@ def sessions_are_equal(A, B, reraise=False, if ignore_keys is None: ignore_keys = dict() if skip_fields is None: - skip_fields = set() A_data_attrs_and_methods = A.list_data_attributes_and_methods() @@ -205,22 +207,17 @@ def sessions_are_equal(A, B, reraise=False, else: continue - err_msg = (f"{field} on {A} did not equal {field} " - f"on {B} (\n{x1} vs\n{x2}\n)") + err_msg = f"{field} on {A} did not equal {field} on {B} (\n{x1} vs\n{x2}\n)" if isinstance(x1, DataObject): x1 = x1.value if isinstance(x2, DataObject): x2 = x2.value - compare_fields(x1, x2, err_msg, - ignore_keys=ignore_keys.get(field, None)) + compare_fields(x1, x2, err_msg, ignore_keys=ignore_keys.get(field, None)) except NotImplementedError: - A_implements_get_field = hasattr( - A.api, getattr(type(A), field).getter_name) - B_implements_get_field = hasattr( - B.api, getattr(type(B), field).getter_name) - assert ((A_implements_get_field is False) - and (B_implements_get_field is False)) + A_implements_get_field = hasattr(A.api, getattr(type(A), field).getter_name) + B_implements_get_field = hasattr(B.api, getattr(type(B), field).getter_name) + assert (A_implements_get_field is False) and (B_implements_get_field is False) except (AssertionError, AttributeError): if reraise: diff --git a/allensdk/brain_observatory/static_gratings.py b/allensdk/brain_observatory/static_gratings.py index 06f1f4c6bc..fcb0c4fa04 100644 --- a/allensdk/brain_observatory/static_gratings.py +++ b/allensdk/brain_observatory/static_gratings.py @@ -39,8 +39,7 @@ from math import sqrt import logging from .stimulus_analysis import StimulusAnalysis -from .brain_observatory_exceptions import BrainObservatoryAnalysisException, \ - MissingStimulusException +from .brain_observatory_exceptions import BrainObservatoryAnalysisException, MissingStimulusException from . import observatory_plots as oplots from . import circle_plots as cplots import h5py @@ -48,14 +47,14 @@ class StaticGratings(StimulusAnalysis): - """ Perform tuning analysis specific to static gratings stimulus. + """Perform tuning analysis specific to static gratings stimulus. Parameters ---------- data_set: BrainObservatoryNwbDataSet object """ - _log = logging.getLogger('allensdk.brain_observatory.static_gratings') + _log = logging.getLogger("allensdk.brain_observatory.static_gratings") def __init__(self, data_set, **kwargs): super(StaticGratings, self).__init__(data_set, **kwargs) @@ -134,10 +133,9 @@ def number_phase(self): return self._number_phase def populate_stimulus_table(self): - stimulus_table = self.data_set.get_stimulus_table('static_gratings') - self._stim_table = stimulus_table.fillna(value=0.) - self._sweeplength = (self.stim_table['end'].iloc[1] - - self.stim_table['start'].iloc[1]) + stimulus_table = self.data_set.get_stimulus_table("static_gratings") + self._stim_table = stimulus_table.fillna(value=0.0) + self._sweeplength = self.stim_table["end"].iloc[1] - self.stim_table["start"].iloc[1] self._interlength = 4 * self._sweeplength self._extralength = self._sweeplength self._orivals = np.unique(self._stim_table.orientation.dropna()) @@ -148,7 +146,7 @@ def populate_stimulus_table(self): self._number_phase = len(self._phasevals) def get_response(self): - ''' Computes the mean response for each cell to each stimulus + """Computes the mean response for each cell to each stimulus condition. Return is a (# orientations, # spatial frequencies, # phases, # cells, 3) np.ndarray. The final dimension @@ -161,17 +159,15 @@ def get_response(self): Returns ------- Numpy array storing mean responses. - ''' + """ StaticGratings._log.info("Calculating mean responses") - response = np.empty((self.number_ori, self.number_sf, - self.number_phase, self.numbercells + 1, 3)) + response = np.empty((self.number_ori, self.number_sf, self.number_phase, self.numbercells + 1, 3)) def ptest(x): if x.empty: return np.nan - return len(np.where( - x < (0.05 / (self.number_ori * (self.number_sf - 1))))[0]) + return len(np.where(x < (0.05 / (self.number_ori * (self.number_sf - 1))))[0]) for ori in self.orivals: ori_pt = np.where(self.orivals == ori)[0][0] @@ -182,25 +178,23 @@ def ptest(x): for phase in self.phasevals: phase_pt = np.where(self.phasevals == phase)[0][0] subset_response = self.mean_sweep_response[ - (self.stim_table.spatial_frequency == sf) & ( - self.stim_table.orientation == ori) & ( - self.stim_table.phase == phase)] + (self.stim_table.spatial_frequency == sf) + & (self.stim_table.orientation == ori) + & (self.stim_table.phase == phase) + ] subset_pval = self.pval[ - (self.stim_table.spatial_frequency == sf) & ( - self.stim_table.orientation == ori) & ( - self.stim_table.phase == phase)] - response[ori_pt, sf_pt, phase_pt, :, 0] = \ - subset_response.mean(axis=0) - response[ori_pt, sf_pt, phase_pt, :, 1] = \ - subset_response.std( - axis=0) / sqrt(len(subset_response)) - response[ori_pt, sf_pt, phase_pt, :, 2] = \ - subset_pval.apply(ptest, axis=0) + (self.stim_table.spatial_frequency == sf) + & (self.stim_table.orientation == ori) + & (self.stim_table.phase == phase) + ] + response[ori_pt, sf_pt, phase_pt, :, 0] = subset_response.mean(axis=0) + response[ori_pt, sf_pt, phase_pt, :, 1] = subset_response.std(axis=0) / sqrt(len(subset_response)) + response[ori_pt, sf_pt, phase_pt, :, 2] = subset_pval.apply(ptest, axis=0) return response def get_peak(self): - """ Computes metrics related to each cell's peak response condition. + """Computes metrics related to each cell's peak response condition. Returns ------- @@ -215,19 +209,31 @@ def get_peak(self): * ptest_sg * time_to_peak_sg """ - StaticGratings._log.info('Calculating peak response properties') - - peak = pd.DataFrame(index=range(self.numbercells), columns=( - 'ori_sg', 'sf_sg', 'phase_sg', 'reliability_sg', - 'osi_sg', 'peak_dff_sg', 'ptest_sg', 'time_to_peak_sg', - 'cell_specimen_id', 'p_run_sg', 'cv_os_sg', - 'run_modulation_sg', 'sf_index_sg')) + StaticGratings._log.info("Calculating peak response properties") + + peak = pd.DataFrame( + index=range(self.numbercells), + columns=( + "ori_sg", + "sf_sg", + "phase_sg", + "reliability_sg", + "osi_sg", + "peak_dff_sg", + "ptest_sg", + "time_to_peak_sg", + "cell_specimen_id", + "p_run_sg", + "cv_os_sg", + "run_modulation_sg", + "sf_index_sg", + ), + ) cids = self.data_set.get_cell_specimen_ids() orivals_rad = np.deg2rad(self.orivals) for nc in range(self.numbercells): - cell_peak = np.where(self.response[:, 1:, :, nc, 0] == np.nanmax( - self.response[:, 1:, :, nc, 0])) + cell_peak = np.where(self.response[:, 1:, :, nc, 0] == np.nanmax(self.response[:, 1:, :, nc, 0])) pref_ori = cell_peak[0][0] pref_sf = cell_peak[1][0] + 1 pref_phase = cell_peak[2][0] @@ -241,13 +247,12 @@ def get_peak(self): # TODO: check number of trials pref = self.response[pref_ori, pref_sf, pref_phase, nc, 0] - orth = self.response[ - np.mod(pref_ori + 3, 6), pref_sf, pref_phase, nc, 0] + orth = self.response[np.mod(pref_ori + 3, 6), pref_sf, pref_phase, nc, 0] tuning = self.response[:, pref_sf, pref_phase, nc, 0] tuning = np.where(tuning > 0, tuning, 0) CV_top_os = np.empty((6), dtype=np.complex128) for i in range(6): - CV_top_os[i] = (tuning[i] * np.exp(1j * 2 * orivals_rad[i])) + CV_top_os[i] = tuning[i] * np.exp(1j * 2 * orivals_rad[i]) peak.cv_os_sg.iloc[nc] = np.abs(CV_top_os.sum()) / tuning.sum() peak.osi_sg[nc] = (pref - orth) / (pref + orth) @@ -259,74 +264,74 @@ def get_peak(self): for phase in self.phasevals: groups.append( self.mean_sweep_response[ - (self.stim_table.spatial_frequency == sf) & - (self.stim_table.orientation == ori) & - (self.stim_table.phase == phase)][str(nc)]) - groups.append(self.mean_sweep_response[ - self.stim_table.spatial_frequency == 0][str(nc)]) + (self.stim_table.spatial_frequency == sf) + & (self.stim_table.orientation == ori) + & (self.stim_table.phase == phase) + ][str(nc)] + ) + groups.append(self.mean_sweep_response[self.stim_table.spatial_frequency == 0][str(nc)]) _, p = st.f_oneway(*groups) peak.ptest_sg[nc] = p - test_rows = \ - (self.stim_table.orientation == self.orivals[pref_ori]) & \ - (self.stim_table.spatial_frequency == self.sfvals[pref_sf]) & \ - (self.stim_table.phase == self.phasevals[pref_phase]) + test_rows = ( + (self.stim_table.orientation == self.orivals[pref_ori]) + & (self.stim_table.spatial_frequency == self.sfvals[pref_sf]) + & (self.stim_table.phase == self.phasevals[pref_phase]) + ) if len(test_rows) < 2: - msg = "Static grating p value requires at least 2 trials at " \ - "the preferred orientation/spatial frequency/phase. " \ - "Cell %d (%f, %f, %f) has %d." % \ - (int(nc), self.orivals[pref_ori], self.sfvals[pref_sf], - self.phasevals[pref_phase], len(test_rows)) + msg = ( + "Static grating p value requires at least 2 trials at " + "the preferred orientation/spatial frequency/phase. " + "Cell %d (%f, %f, %f) has %d." + % ( + int(nc), + self.orivals[pref_ori], + self.sfvals[pref_sf], + self.phasevals[pref_phase], + len(test_rows), + ) + ) raise BrainObservatoryAnalysisException(msg) test = self.sweep_response[test_rows][str(nc)].mean() - peak.time_to_peak_sg[nc] = \ - ((np.argmax(test) - self.interlength) / self.acquisition_rate) + peak.time_to_peak_sg[nc] = (np.argmax(test) - self.interlength) / self.acquisition_rate # running modulation subset = self.mean_sweep_response[ - (self.stim_table.spatial_frequency == self.sfvals[pref_sf]) & - (self.stim_table.orientation == self.orivals[pref_ori]) & - (self.stim_table.phase == self.phasevals[pref_phase])] + (self.stim_table.spatial_frequency == self.sfvals[pref_sf]) + & (self.stim_table.orientation == self.orivals[pref_ori]) + & (self.stim_table.phase == self.phasevals[pref_phase]) + ] subset_run = subset[subset.dx >= 1] subset_stat = subset[subset.dx < 1] if (len(subset_run) > 4) & (len(subset_stat) > 4): - (_, peak.p_run_sg.iloc[nc]) = st.ttest_ind(subset_run[str(nc)], - subset_stat[ - str(nc)], - equal_var=False) + (_, peak.p_run_sg.iloc[nc]) = st.ttest_ind(subset_run[str(nc)], subset_stat[str(nc)], equal_var=False) if subset_run[str(nc)].mean() > subset_stat[str(nc)].mean(): - peak.run_modulation_sg.iloc[nc] = (subset_run[ - str(nc)].mean() - - subset_stat[ - str(nc)].mean()) \ - / np.abs( - subset_run[str(nc)].mean()) + peak.run_modulation_sg.iloc[nc] = ( + subset_run[str(nc)].mean() - subset_stat[str(nc)].mean() + ) / np.abs(subset_run[str(nc)].mean()) elif subset_run[str(nc)].mean() < subset_stat[str(nc)].mean(): - peak.run_modulation_sg.iloc[nc] = \ - (-1 * ((subset_stat[str(nc)].mean() - - subset_run[str(nc)].mean()) / - np.abs(subset_stat[str(nc)].mean()))) + peak.run_modulation_sg.iloc[nc] = -1 * ( + (subset_stat[str(nc)].mean() - subset_run[str(nc)].mean()) / np.abs(subset_stat[str(nc)].mean()) + ) else: peak.p_run_sg.iloc[nc] = np.nan peak.run_modulation_sg.iloc[nc] = np.nan # reliability - subset = \ - self.sweep_response[ - (self.stim_table.spatial_frequency == - self.sfvals[pref_sf]) & - (self.stim_table.orientation == self.orivals[pref_ori]) & - (self.stim_table.phase == self.phasevals[pref_phase])] + subset = self.sweep_response[ + (self.stim_table.spatial_frequency == self.sfvals[pref_sf]) + & (self.stim_table.orientation == self.orivals[pref_ori]) + & (self.stim_table.phase == self.phasevals[pref_phase]) + ] corr_matrix = np.empty((len(subset), len(subset))) for i in range(len(subset)): for j in range(len(subset)): - r, p = st.pearsonr(subset[str(nc)].iloc[i][28:42], - subset[str(nc)].iloc[j][28:42]) + r, p = st.pearsonr(subset[str(nc)].iloc[i][28:42], subset[str(nc)].iloc[j][28:42]) corr_matrix[i, j] = r mask = np.ones((len(subset), len(subset))) for i in range(len(subset)): @@ -339,44 +344,37 @@ def get_peak(self): # SF index sf_tuning = self.response[pref_ori, 1:, pref_phase, nc, 0] trials = self.mean_sweep_response[ - (self.stim_table.spatial_frequency != 0) & - (self.stim_table.orientation == self.orivals[pref_ori]) & - (self.stim_table.phase == self.phasevals[pref_phase]) + (self.stim_table.spatial_frequency != 0) + & (self.stim_table.orientation == self.orivals[pref_ori]) + & (self.stim_table.phase == self.phasevals[pref_phase]) ][str(nc)].values - SSE_part = np.sqrt( - np.sum((trials - trials.mean()) ** 2) / (len(trials) - 5)) - peak.sf_index_sg.iloc[nc] = (np.ptp(sf_tuning)) / ( - np.ptp(sf_tuning) + 2 * SSE_part) + SSE_part = np.sqrt(np.sum((trials - trials.mean()) ** 2) / (len(trials) - 5)) + peak.sf_index_sg.iloc[nc] = (np.ptp(sf_tuning)) / (np.ptp(sf_tuning) + 2 * SSE_part) return peak - def plot_time_to_peak(self, - p_value_max=oplots.P_VALUE_MAX, - color_map=oplots.STIMULUS_COLOR_MAP): - stimulus_table = self.data_set.get_stimulus_table('static_gratings') + def plot_time_to_peak(self, p_value_max=oplots.P_VALUE_MAX, color_map=oplots.STIMULUS_COLOR_MAP): + stimulus_table = self.data_set.get_stimulus_table("static_gratings") resps = [] for index, row in self.peak.iterrows(): - pref_rows = (stimulus_table.orientation == self.orivals[ - row.ori_sg]) & \ - (stimulus_table.spatial_frequency == self.sfvals[ - row.sf_sg]) & \ - (stimulus_table.phase == self.phasevals[row.phase_sg]) + pref_rows = ( + (stimulus_table.orientation == self.orivals[row.ori_sg]) + & (stimulus_table.spatial_frequency == self.sfvals[row.sf_sg]) + & (stimulus_table.phase == self.phasevals[row.phase_sg]) + ) mean_response = self.sweep_response[pref_rows][str(index)].mean() - resps.append( - (mean_response - mean_response.mean() / mean_response.std())) + resps.append((mean_response - mean_response.mean() / mean_response.std())) mean_responses = np.array(resps) - sorted_table = self.peak[self.peak.ptest_sg < p_value_max].sort_values( - 'time_to_peak_sg') + sorted_table = self.peak[self.peak.ptest_sg < p_value_max].sort_values("time_to_peak_sg") cell_order = sorted_table.index # time to peak is relative to stimulus start in seconds - ttps = (sorted_table.time_to_peak_sg.values + - self.interlength / self.acquisition_rate) + ttps = sorted_table.time_to_peak_sg.values + self.interlength / self.acquisition_rate msrs_sorted = mean_responses[cell_order, :] oplots.plot_time_to_peak( @@ -386,54 +384,53 @@ def plot_time_to_peak(self, (2 * self.interlength + self.sweeplength) / self.acquisition_rate, self.interlength / self.acquisition_rate, (self.interlength + self.sweeplength) / self.acquisition_rate, - color_map) - - def plot_orientation_selectivity(self, - si_range=oplots.SI_RANGE, - n_hist_bins=oplots.N_HIST_BINS, - color=oplots.STIM_COLOR, - p_value_max=oplots.P_VALUE_MAX, - peak_dff_min=oplots.PEAK_DFF_MIN): - + color_map, + ) + + def plot_orientation_selectivity( + self, + si_range=oplots.SI_RANGE, + n_hist_bins=oplots.N_HIST_BINS, + color=oplots.STIM_COLOR, + p_value_max=oplots.P_VALUE_MAX, + peak_dff_min=oplots.PEAK_DFF_MIN, + ): # responsive cells - vis_cells = (self.peak.ptest_sg < p_value_max) & ( - self.peak.peak_dff_sg > peak_dff_min) + vis_cells = (self.peak.ptest_sg < p_value_max) & (self.peak.peak_dff_sg > peak_dff_min) # orientation selective cells - osi_cells = vis_cells & (self.peak.osi_sg > si_range[0]) & ( - self.peak.osi_sg < si_range[1]) + osi_cells = vis_cells & (self.peak.osi_sg > si_range[0]) & (self.peak.osi_sg < si_range[1]) peak_osi = self.peak.loc[osi_cells] osis = peak_osi.osi_sg.values - oplots.plot_selectivity_cumulative_histogram(osis, - "orientation " - "selectivity index", - si_range=si_range, - n_hist_bins=n_hist_bins, - color=color) - - def plot_preferred_orientation(self, - include_labels=False, - si_range=oplots.SI_RANGE, - color=oplots.STIM_COLOR, - p_value_max=oplots.P_VALUE_MAX, - peak_dff_min=oplots.PEAK_DFF_MIN): - - vis_cells = (self.peak.ptest_sg < p_value_max) & ( - self.peak.peak_dff_sg > peak_dff_min) + oplots.plot_selectivity_cumulative_histogram( + osis, "orientation selectivity index", si_range=si_range, n_hist_bins=n_hist_bins, color=color + ) + + def plot_preferred_orientation( + self, + include_labels=False, + si_range=oplots.SI_RANGE, + color=oplots.STIM_COLOR, + p_value_max=oplots.P_VALUE_MAX, + peak_dff_min=oplots.PEAK_DFF_MIN, + ): + vis_cells = (self.peak.ptest_sg < p_value_max) & (self.peak.peak_dff_sg > peak_dff_min) pref_oris = self.peak.loc[vis_cells].ori_sg.values pref_oris = [self.orivals[pref_ori] for pref_ori in pref_oris] angles, counts = np.unique(pref_oris, return_counts=True) - oplots.plot_radial_histogram(angles, - counts, - include_labels=include_labels, - all_angles=self.orivals, - direction=-1, - offset=180.0, - color=color) + oplots.plot_radial_histogram( + angles, + counts, + include_labels=include_labels, + all_angles=self.orivals, + direction=-1, + offset=180.0, + color=color, + ) if len(counts) == 0: max_count = 1 @@ -447,34 +444,31 @@ def plot_preferred_orientation(self, h = 1.6 * max_count w = 2.4 * max_count - plt.gca().set(xlim=(center_x - w * 0.5, center_x + w * 0.5), - ylim=(center_y - h * 0.5, center_y + h * 0.5), - aspect=1.0) - - def plot_preferred_spatial_frequency(self, - si_range=oplots.SI_RANGE, - color=oplots.STIM_COLOR, - p_value_max=oplots.P_VALUE_MAX, - peak_dff_min=oplots.PEAK_DFF_MIN): - - vis_cells = (self.peak.ptest_sg < p_value_max) & ( - self.peak.peak_dff_sg > peak_dff_min) + plt.gca().set( + xlim=(center_x - w * 0.5, center_x + w * 0.5), ylim=(center_y - h * 0.5, center_y + h * 0.5), aspect=1.0 + ) + + def plot_preferred_spatial_frequency( + self, + si_range=oplots.SI_RANGE, + color=oplots.STIM_COLOR, + p_value_max=oplots.P_VALUE_MAX, + peak_dff_min=oplots.PEAK_DFF_MIN, + ): + vis_cells = (self.peak.ptest_sg < p_value_max) & (self.peak.peak_dff_sg > peak_dff_min) pref_sfs = self.peak.loc[vis_cells].sf_sg.values - oplots.plot_condition_histogram(pref_sfs, - self.sfvals[1:], - color=color) + oplots.plot_condition_histogram(pref_sfs, self.sfvals[1:], color=color) plt.xlabel("spatial frequency (cycles/deg)") plt.ylabel("number of cells") - def open_fan_plot(self, cell_specimen_id=None, include_labels=False, - cell_index=None): + def open_fan_plot(self, cell_specimen_id=None, include_labels=False, cell_index=None): cell_index = self.row_from_cell_id(cell_specimen_id, cell_index) df = self.mean_sweep_response[str(cell_index)] - st = self.data_set.get_stimulus_table('static_gratings') - mask = st.dropna(subset=['orientation']).index + st = self.data_set.get_stimulus_table("static_gratings") + mask = st.dropna(subset=["orientation"]).index data = df.values @@ -482,11 +476,13 @@ def open_fan_plot(self, cell_specimen_id=None, include_labels=False, cmax = max(cmin, data.mean() + data.std() * 3) fp = cplots.FanPlotter.for_static_gratings() - fp.plot(r_data=st.spatial_frequency.loc[mask].values, - angle_data=st.orientation.loc[mask].values, - group_data=st.phase.loc[mask].values, - data=df.loc[mask].values, - clim=[cmin, cmax]) + fp.plot( + r_data=st.spatial_frequency.loc[mask].values, + angle_data=st.orientation.loc[mask].values, + group_data=st.phase.loc[mask].values, + data=df.loc[mask].values, + clim=[cmin, cmax], + ) fp.show_axes(closed=False) if include_labels: @@ -494,170 +490,154 @@ def open_fan_plot(self, cell_specimen_id=None, include_labels=False, fp.show_angle_labels() def reshape_response_array(self): - ''' + """ :return: response array in cells x stim conditions x repetition for noise correlations this is a re-organization of the mean sweep response table - ''' + """ - mean_sweep_response = \ - self.mean_sweep_response.values[:, :self.numbercells] + mean_sweep_response = self.mean_sweep_response.values[:, : self.numbercells] stim_table = self.stim_table sfvals = self.sfvals sfvals = sfvals[sfvals != 0] # blank sweep - response_new = np.zeros((self.numbercells, self.number_ori, - self.number_sf - 1, self.number_phase), - dtype='object') + response_new = np.zeros( + (self.numbercells, self.number_ori, self.number_sf - 1, self.number_phase), dtype="object" + ) for i, ori in enumerate(self.orivals): for j, sf in enumerate(sfvals): for k, phase in enumerate(self.phasevals): - ind = (stim_table.orientation.values == ori) * ( - stim_table.spatial_frequency.values == sf) * ( - stim_table.phase.values == phase) + ind = ( + (stim_table.orientation.values == ori) + * (stim_table.spatial_frequency.values == sf) + * (stim_table.phase.values == phase) + ) for c in range(self.numbercells): response_new[c, i, j, k] = mean_sweep_response[ind, c] - ind = (stim_table.spatial_frequency.values == 0) + ind = stim_table.spatial_frequency.values == 0 response_blank = mean_sweep_response[ind, :].T return response_new, response_blank - def get_signal_correlation(self, corr='spearman'): + def get_signal_correlation(self, corr="spearman"): logging.debug("Calculating signal correlation") # orientation x freq x phase x cell, no blank - response = self.response[:, 1:, :, :self.numbercells, 0] + response = self.response[:, 1:, :, : self.numbercells, 0] - response = response.reshape( - self.number_ori * (self.number_sf - 1) * self.number_phase, - self.numbercells).T + response = response.reshape(self.number_ori * (self.number_sf - 1) * self.number_phase, self.numbercells).T N, Nstim = response.shape signal_corr = np.zeros((N, N)) signal_p = np.empty((N, N)) - if corr == 'pearson': + if corr == "pearson": for i in range(N): for j in range(i, N): # matrix is symmetric - signal_corr[i, j], signal_p[i, j] = st.pearsonr( - response[i], response[j]) + signal_corr[i, j], signal_p[i, j] = st.pearsonr(response[i], response[j]) - elif corr == 'spearman': + elif corr == "spearman": for i in range(N): for j in range(i, N): # matrix is symmetric - signal_corr[i, j], signal_p[i, j] = st.spearmanr( - response[i], response[j]) + signal_corr[i, j], signal_p[i, j] = st.spearmanr(response[i], response[j]) else: - raise Exception('correlation should be pearson or spearman') + raise Exception("correlation should be pearson or spearman") # fill in lower triangle - signal_corr = ( - np.triu(signal_corr) + - np.triu(signal_corr, 1).T) + signal_corr = np.triu(signal_corr) + np.triu(signal_corr, 1).T # fill in lower triangle - signal_p = ( - np.triu(signal_p) + - np.triu(signal_p, 1).T) + signal_p = np.triu(signal_p) + np.triu(signal_p, 1).T return signal_corr, signal_p - def get_representational_similarity(self, corr='spearman'): + def get_representational_similarity(self, corr="spearman"): logging.debug("Calculating representational similarity") # orientation x freq x phase x cell - response = self.response[:, 1:, :, :self.numbercells, 0] - response = response.reshape( - self.number_ori * (self.number_sf - 1) * self.number_phase, - self.numbercells) + response = self.response[:, 1:, :, : self.numbercells, 0] + response = response.reshape(self.number_ori * (self.number_sf - 1) * self.number_phase, self.numbercells) Nstim, N = response.shape rep_sim = np.zeros((Nstim, Nstim)) rep_sim_p = np.empty((Nstim, Nstim)) - if corr == 'pearson': + if corr == "pearson": for i in range(Nstim): for j in range(i, Nstim): # matrix is symmetric - rep_sim[i, j], rep_sim_p[i, j] = st.pearsonr(response[i], - response[j]) + rep_sim[i, j], rep_sim_p[i, j] = st.pearsonr(response[i], response[j]) - elif corr == 'spearman': + elif corr == "spearman": for i in range(Nstim): for j in range(i, Nstim): # matrix is symmetric - rep_sim[i, j], rep_sim_p[i, j] = st.spearmanr(response[i], - response[j]) + rep_sim[i, j], rep_sim_p[i, j] = st.spearmanr(response[i], response[j]) else: - raise Exception('correlation should be pearson or spearman') + raise Exception("correlation should be pearson or spearman") - rep_sim = np.triu(rep_sim) + np.triu(rep_sim, - 1).T # fill in lower triangle - rep_sim_p = np.triu(rep_sim_p) + np.triu(rep_sim_p, - 1).T # fill in lower triangle + rep_sim = np.triu(rep_sim) + np.triu(rep_sim, 1).T # fill in lower triangle + rep_sim_p = np.triu(rep_sim_p) + np.triu(rep_sim_p, 1).T # fill in lower triangle return rep_sim, rep_sim_p - def get_noise_correlation(self, corr='spearman'): + def get_noise_correlation(self, corr="spearman"): logging.debug("Calculating noise correlation") response, response_blank = self.reshape_response_array() - noise_corr = np.zeros((self.numbercells, self.numbercells, - self.number_ori, self.number_sf - 1, - self.number_phase)) - noise_corr_p = np.zeros((self.numbercells, self.numbercells, - self.number_ori, self.number_sf - 1, - self.number_phase)) + noise_corr = np.zeros( + (self.numbercells, self.numbercells, self.number_ori, self.number_sf - 1, self.number_phase) + ) + noise_corr_p = np.zeros( + (self.numbercells, self.numbercells, self.number_ori, self.number_sf - 1, self.number_phase) + ) noise_corr_blank = np.zeros((self.numbercells, self.numbercells)) noise_corr_blank_p = np.zeros((self.numbercells, self.numbercells)) - if corr == 'pearson': + if corr == "pearson": for k in range(self.number_ori): - for l in range(self.number_sf - 1): # noqa E741 + for l in range(self.number_sf - 1): # noqa E741 for m in range(self.number_phase): for i in range(self.numbercells): for j in range(i, self.numbercells): - noise_corr[i, j, k, l, m], noise_corr_p[ - i, j, k, l, m] = st.pearsonr( - response[i, k, l, m], response[j, k, l, m]) + noise_corr[i, j, k, l, m], noise_corr_p[i, j, k, l, m] = st.pearsonr( + response[i, k, l, m], response[j, k, l, m] + ) - noise_corr[:, :, k, l, m] = np.triu( - noise_corr[:, :, k, l, m]) + np.triu( - noise_corr[:, :, k, l, m], 1).T + noise_corr[:, :, k, l, m] = ( + np.triu(noise_corr[:, :, k, l, m]) + np.triu(noise_corr[:, :, k, l, m], 1).T + ) for i in range(self.numbercells): for j in range(i, self.numbercells): - noise_corr_blank[i, j], noise_corr_blank_p[ - i, j] = st.pearsonr(response_blank[i], - response_blank[j]) + noise_corr_blank[i, j], noise_corr_blank_p[i, j] = st.pearsonr(response_blank[i], response_blank[j]) - elif corr == 'spearman': + elif corr == "spearman": for k in range(self.number_ori): - for l in range(self.number_sf - 1): # noqa E741 + for l in range(self.number_sf - 1): # noqa E741 for m in range(self.number_phase): for i in range(self.numbercells): for j in range(i, self.numbercells): - noise_corr[i, j, k, l, m], noise_corr_p[ - i, j, k, l, m] = st.spearmanr( - response[i, k, l, m], response[j, k, l, m]) + noise_corr[i, j, k, l, m], noise_corr_p[i, j, k, l, m] = st.spearmanr( + response[i, k, l, m], response[j, k, l, m] + ) - noise_corr[:, :, k, l, m] = np.triu( - noise_corr[:, :, k, l, m]) + np.triu( - noise_corr[:, :, k, l, m], 1).T + noise_corr[:, :, k, l, m] = ( + np.triu(noise_corr[:, :, k, l, m]) + np.triu(noise_corr[:, :, k, l, m], 1).T + ) for i in range(self.numbercells): for j in range(i, self.numbercells): - noise_corr_blank[i, j], noise_corr_blank_p[ - i, j] = st.spearmanr(response_blank[i], - response_blank[j]) + noise_corr_blank[i, j], noise_corr_blank_p[i, j] = st.spearmanr( + response_blank[i], response_blank[j] + ) else: - raise Exception('correlation should be pearson or spearman') + raise Exception("correlation should be pearson or spearman") - noise_corr_blank[:, :] = np.triu(noise_corr_blank[:, :]) + np.triu( - noise_corr_blank[:, :], 1).T + noise_corr_blank[:, :] = np.triu(noise_corr_blank[:, :]) + np.triu(noise_corr_blank[:, :], 1).T return noise_corr, noise_corr_p, noise_corr_blank, noise_corr_blank_p @@ -668,10 +648,8 @@ def from_analysis_file(data_set, analysis_file): try: sg.populate_stimulus_table() - sg._sweep_response = pd.read_hdf(analysis_file, - "analysis/sweep_response_sg") - sg._mean_sweep_response = pd.read_hdf( - analysis_file, "analysis/mean_sweep_response_sg") + sg._sweep_response = pd.read_hdf(analysis_file, "analysis/sweep_response_sg") + sg._mean_sweep_response = pd.read_hdf(analysis_file, "analysis/mean_sweep_response_sg") sg._peak = pd.read_hdf(analysis_file, "analysis/peak") with h5py.File(analysis_file, "r") as f: @@ -686,8 +664,7 @@ def from_analysis_file(data_set, analysis_file): if "analysis/signal_corr_sg" in f: sg.signal_correlation = f["analysis/signal_corr_sg"][()] if "analysis/rep_similarity_sg" in f: - sg.representational_similarity = f[ - "analysis/rep_similarity_sg"][()] + sg.representational_similarity = f["analysis/rep_similarity_sg"][()] except Exception as e: raise MissingStimulusException(e.args) diff --git a/allensdk/brain_observatory/stimulus_analysis.py b/allensdk/brain_observatory/stimulus_analysis.py index 5d85131ce7..e140d6e627 100644 --- a/allensdk/brain_observatory/stimulus_analysis.py +++ b/allensdk/brain_observatory/stimulus_analysis.py @@ -46,7 +46,7 @@ class StimulusAnalysis(object): - """ Base class for all response analysis code. Subclasses are responsible + """Base class for all response analysis code. Subclasses are responsible for computing metrics and traces relevant to a particular stimulus. The base class contains methods for organizing sweep responses row of a stimulus stable (get_sweep_response). Subclasses implement the @@ -61,7 +61,8 @@ class StimulusAnalysis(object): Whether or not to compute speed tuning histograms """ - _log = logging.getLogger('allensdk.brain_observatory.stimulus_analysis') + + _log = logging.getLogger("allensdk.brain_observatory.stimulus_analysis") _PRELOAD = "PRELOAD" def __init__(self, data_set): @@ -103,24 +104,21 @@ def stim_table(self): @property def sweep_response(self): if self._sweep_response is StimulusAnalysis._PRELOAD: - self._sweep_response, self._mean_sweep_response, self._pval = \ - self.get_sweep_response() + self._sweep_response, self._mean_sweep_response, self._pval = self.get_sweep_response() return self._sweep_response @property def mean_sweep_response(self): if self._mean_sweep_response is StimulusAnalysis._PRELOAD: - self._sweep_response, self._mean_sweep_response, self._pval = \ - self.get_sweep_response() + self._sweep_response, self._mean_sweep_response, self._pval = self.get_sweep_response() return self._mean_sweep_response @property def pval(self): if self._pval is StimulusAnalysis._PRELOAD: - self._sweep_response, self._mean_sweep_response, self._pval = \ - self.get_sweep_response() + self._sweep_response, self._mean_sweep_response, self._pval = self.get_sweep_response() return self._pval @@ -140,8 +138,7 @@ def peak(self): def get_fluorescence(self): # get fluorescence - self._timestamps, self._celltraces = \ - self.data_set.get_corrected_fluorescence_traces() + self._timestamps, self._celltraces = self.data_set.get_corrected_fluorescence_traces() self._acquisition_rate = 1 / (self.timestamps[1] - self.timestamps[0]) self._numbercells = len(self.celltraces) # number of cells in dataset @@ -211,63 +208,62 @@ def dxtime(self): @property def binned_dx_sp(self): if self._binned_dx_sp is StimulusAnalysis._PRELOAD: - (self._binned_dx_sp, self._binned_cells_sp, self._binned_dx_vis, - self._binned_cells_vis, self._peak_run) = \ + (self._binned_dx_sp, self._binned_cells_sp, self._binned_dx_vis, self._binned_cells_vis, self._peak_run) = ( self.get_speed_tuning(binsize=self._binsize) + ) return self._binned_dx_sp @property def binned_cells_sp(self): if self._binned_cells_sp is StimulusAnalysis._PRELOAD: - (self._binned_dx_sp, self._binned_cells_sp, self._binned_dx_vis, - self._binned_cells_vis, self._peak_run) = \ + (self._binned_dx_sp, self._binned_cells_sp, self._binned_dx_vis, self._binned_cells_vis, self._peak_run) = ( self.get_speed_tuning(binsize=self._binsize) + ) return self._binned_cells_sp @property def binned_dx_vis(self): if self._binned_dx_vis is StimulusAnalysis._PRELOAD: - (self._binned_dx_sp, self._binned_cells_sp, self._binned_dx_vis, - self._binned_cells_vis, self._peak_run) = \ + (self._binned_dx_sp, self._binned_cells_sp, self._binned_dx_vis, self._binned_cells_vis, self._peak_run) = ( self.get_speed_tuning(binsize=self._binsize) + ) return self._binned_dx_vis @property def binned_cells_vis(self): if self._binned_cells_vis is StimulusAnalysis._PRELOAD: - (self._binned_dx_sp, self._binned_cells_sp, self._binned_dx_vis, - self._binned_cells_vis, self._peak_run) = \ + (self._binned_dx_sp, self._binned_cells_sp, self._binned_dx_vis, self._binned_cells_vis, self._peak_run) = ( self.get_speed_tuning(binsize=self._binsize) + ) return self._binned_cells_vis @property def peak_run(self): if self._peak_run is StimulusAnalysis._PRELOAD: - (self._binned_dx_sp, self._binned_cells_sp, self._binned_dx_vis, - self._binned_cells_vis, self._peak_run) = \ + (self._binned_dx_sp, self._binned_cells_sp, self._binned_dx_vis, self._binned_cells_vis, self._peak_run) = ( self.get_speed_tuning(binsize=self._binsize) + ) return self._peak_run def populate_stimulus_table(self): - """ Implemented by subclasses. """ - raise BrainObservatoryAnalysisException( - "populate_stimulus_table not implemented") + """Implemented by subclasses.""" + raise BrainObservatoryAnalysisException("populate_stimulus_table not implemented") def get_response(self): - """ Implemented by subclasses. """ + """Implemented by subclasses.""" raise BrainObservatoryAnalysisException("get_response not implemented") def get_peak(self): - """ Implemented by subclasses. """ + """Implemented by subclasses.""" raise BrainObservatoryAnalysisException("get_peak not implemented") def get_speed_tuning(self, binsize): - """ Calculates speed tuning, spontaneous versus visually driven. + """Calculates speed tuning, spontaneous versus visually driven. The return is a 5-tuple of speed and dF/F histograms. @@ -313,39 +309,42 @@ def get_speed_tuning(self, binsize): "outputs obtained under recent scipy versions!" ) - StimulusAnalysis._log.info( - 'Calculating speed tuning, spontaneous vs visually driven') + StimulusAnalysis._log.info("Calculating speed tuning, spontaneous vs visually driven") - celltraces_trimmed = np.delete(self.dfftraces, range( - len(self.dxcm), np.size(self.dfftraces, 1)), axis=1) + celltraces_trimmed = np.delete(self.dfftraces, range(len(self.dxcm), np.size(self.dfftraces, 1)), axis=1) # pull out spontaneous epoch(s) - spontaneous = self.data_set.get_stimulus_table('spontaneous') - - peak_run = pd.DataFrame(index=range(self.numbercells), columns=( - 'speed_max_sp', 'speed_min_sp', 'ptest_sp', 'mod_sp', - 'speed_max_vis', 'speed_min_vis', 'ptest_vis', 'mod_vis')) - - dx_sp = self.dxcm[spontaneous.start.iloc[-1]:spontaneous.end.iloc[-1]] - celltraces_sp = celltraces_trimmed[ - :, spontaneous.start.iloc[-1]:spontaneous.end.iloc[-1]] - dx_vis = np.delete(self.dxcm, np.arange( - spontaneous.start.iloc[-1], spontaneous.end.iloc[-1])) - celltraces_vis = np.delete(celltraces_trimmed, np.arange( - spontaneous.start.iloc[-1], spontaneous.end.iloc[-1]), axis=1) + spontaneous = self.data_set.get_stimulus_table("spontaneous") + + peak_run = pd.DataFrame( + index=range(self.numbercells), + columns=( + "speed_max_sp", + "speed_min_sp", + "ptest_sp", + "mod_sp", + "speed_max_vis", + "speed_min_vis", + "ptest_vis", + "mod_vis", + ), + ) + + dx_sp = self.dxcm[spontaneous.start.iloc[-1] : spontaneous.end.iloc[-1]] + celltraces_sp = celltraces_trimmed[:, spontaneous.start.iloc[-1] : spontaneous.end.iloc[-1]] + dx_vis = np.delete(self.dxcm, np.arange(spontaneous.start.iloc[-1], spontaneous.end.iloc[-1])) + celltraces_vis = np.delete( + celltraces_trimmed, np.arange(spontaneous.start.iloc[-1], spontaneous.end.iloc[-1]), axis=1 + ) if len(spontaneous) > 1: - dx_sp = np.append( - dx_sp, - self.dxcm[spontaneous.start.iloc[-2]:spontaneous.end.iloc[-2]], - axis=0) + dx_sp = np.append(dx_sp, self.dxcm[spontaneous.start.iloc[-2] : spontaneous.end.iloc[-2]], axis=0) celltraces_sp = np.append( - celltraces_sp, - celltraces_trimmed[:, spontaneous.start.iloc[-2]: - spontaneous.end.iloc[-2]], axis=1) - dx_vis = np.delete(dx_vis, np.arange( - spontaneous.start.iloc[-2], spontaneous.end.iloc[-2])) - celltraces_vis = np.delete(celltraces_vis, np.arange( - spontaneous.start.iloc[-2], spontaneous.end.iloc[-2]), axis=1) + celltraces_sp, celltraces_trimmed[:, spontaneous.start.iloc[-2] : spontaneous.end.iloc[-2]], axis=1 + ) + dx_vis = np.delete(dx_vis, np.arange(spontaneous.start.iloc[-2], spontaneous.end.iloc[-2])) + celltraces_vis = np.delete( + celltraces_vis, np.arange(spontaneous.start.iloc[-2], spontaneous.end.iloc[-2]), axis=1 + ) celltraces_vis = celltraces_vis[:, ~np.isnan(dx_vis)] dx_vis = dx_vis[~np.isnan(dx_vis)] @@ -356,64 +355,50 @@ def get_speed_tuning(self, binsize): binned_dx_sp = np.zeros((nbins, 2)) for i in range(nbins): if np.all(np.isnan(dx_sorted)): - raise BrainObservatoryAnalysisException( - "dx is filled with NaNs") + raise BrainObservatoryAnalysisException("dx is filled with NaNs") - offset = findlevel(dx_sorted, 1, 'up') + offset = findlevel(dx_sorted, 1, "up") if offset is None: - StimulusAnalysis._log.info( - "dx never crosses 1, all speed data going into single bin") + StimulusAnalysis._log.info("dx never crosses 1, all speed data going into single bin") offset = len(dx_sorted) if i == 0: binned_dx_sp[i, 0] = np.mean(dx_sorted[:offset]) - binned_dx_sp[i, 1] = np.std( - dx_sorted[:offset]) / np.sqrt(offset) - binned_cells_sp[:, i, 0] = np.mean( - celltraces_sorted_sp[:, :offset], axis=1) - binned_cells_sp[:, i, 1] = np.std( - celltraces_sorted_sp[:, :offset], axis=1) / np.sqrt(offset) + binned_dx_sp[i, 1] = np.std(dx_sorted[:offset]) / np.sqrt(offset) + binned_cells_sp[:, i, 0] = np.mean(celltraces_sorted_sp[:, :offset], axis=1) + binned_cells_sp[:, i, 1] = np.std(celltraces_sorted_sp[:, :offset], axis=1) / np.sqrt(offset) else: start = offset + (i - 1) * binsize - binned_dx_sp[i, 0] = np.mean(dx_sorted[start:start + binsize]) - binned_dx_sp[i, 1] = np.std( - dx_sorted[start:start + binsize]) / np.sqrt(binsize) - binned_cells_sp[:, i, 0] = np.mean( - celltraces_sorted_sp[:, start:start + binsize], axis=1) - binned_cells_sp[:, i, 1] = np.std( - celltraces_sorted_sp[:, start:start + binsize], - axis=1) / np.sqrt(binsize) + binned_dx_sp[i, 0] = np.mean(dx_sorted[start : start + binsize]) + binned_dx_sp[i, 1] = np.std(dx_sorted[start : start + binsize]) / np.sqrt(binsize) + binned_cells_sp[:, i, 0] = np.mean(celltraces_sorted_sp[:, start : start + binsize], axis=1) + binned_cells_sp[:, i, 1] = np.std(celltraces_sorted_sp[:, start : start + binsize], axis=1) / np.sqrt( + binsize + ) binned_cells_shuffled_sp = np.empty((self.numbercells, nbins, 2, 200)) for shuf in range(200): - celltraces_shuffled = \ - celltraces_sp[:, - np.random.permutation(np.size(celltraces_sp, 1))] - celltraces_shuffled_sorted = celltraces_shuffled[ - :, np.argsort(dx_sp)] + celltraces_shuffled = celltraces_sp[:, np.random.permutation(np.size(celltraces_sp, 1))] + celltraces_shuffled_sorted = celltraces_shuffled[:, np.argsort(dx_sp)] for i in range(nbins): - offset = findlevel(dx_sorted, 1, 'up') + offset = findlevel(dx_sorted, 1, "up") if offset is None: - StimulusAnalysis._log.info( - "dx never crosses 1, all speed data going into " - "single bin") + StimulusAnalysis._log.info("dx never crosses 1, all speed data going into single bin") offset = celltraces_shuffled_sorted.shape[1] if i == 0: - binned_cells_shuffled_sp[:, i, 0, shuf] = np.mean( - celltraces_shuffled_sorted[:, :offset], axis=1) - binned_cells_shuffled_sp[:, i, 1, shuf] = np.std( - celltraces_shuffled_sorted[:, :offset], axis=1) + binned_cells_shuffled_sp[:, i, 0, shuf] = np.mean(celltraces_shuffled_sorted[:, :offset], axis=1) + binned_cells_shuffled_sp[:, i, 1, shuf] = np.std(celltraces_shuffled_sorted[:, :offset], axis=1) else: start = offset + (i - 1) * binsize binned_cells_shuffled_sp[:, i, 0, shuf] = np.mean( - celltraces_shuffled_sorted[:, start:start + binsize], - axis=1) + celltraces_shuffled_sorted[:, start : start + binsize], axis=1 + ) binned_cells_shuffled_sp[:, i, 1, shuf] = np.std( - celltraces_shuffled_sorted[:, start:start + binsize], - axis=1) + celltraces_shuffled_sorted[:, start : start + binsize], axis=1 + ) nbins = 1 + len(np.where(dx_vis >= 1)[0]) // binsize dx_sorted = dx_vis[np.argsort(dx_vis)] @@ -421,75 +406,56 @@ def get_speed_tuning(self, binsize): binned_cells_vis = np.zeros((self.numbercells, nbins, 2)) binned_dx_vis = np.zeros((nbins, 2)) for i in range(nbins): - offset = findlevel(dx_sorted, 1, 'up') + offset = findlevel(dx_sorted, 1, "up") if offset is None: - StimulusAnalysis._log.info( - "dx never crosses 1, all speed data going into single bin") + StimulusAnalysis._log.info("dx never crosses 1, all speed data going into single bin") offset = len(dx_sorted) if i == 0: binned_dx_vis[i, 0] = np.mean(dx_sorted[:offset]) - binned_dx_vis[i, 1] = np.std( - dx_sorted[:offset]) / np.sqrt(offset) - binned_cells_vis[:, i, 0] = np.mean( - celltraces_sorted_vis[:, :offset], axis=1) - binned_cells_vis[:, i, 1] = ( - np.std(celltraces_sorted_vis[:, :offset], axis=1) / - np.sqrt(offset)) + binned_dx_vis[i, 1] = np.std(dx_sorted[:offset]) / np.sqrt(offset) + binned_cells_vis[:, i, 0] = np.mean(celltraces_sorted_vis[:, :offset], axis=1) + binned_cells_vis[:, i, 1] = np.std(celltraces_sorted_vis[:, :offset], axis=1) / np.sqrt(offset) else: # TODO 9 lines of repeated code!!!!!!!!!!!! start = offset + (i - 1) * binsize - binned_dx_vis[i, 0] = np.mean(dx_sorted[start:start + binsize]) - binned_dx_vis[i, 1] = np.std( - dx_sorted[start:start + binsize]) / np.sqrt(binsize) - binned_cells_vis[:, i, 0] = np.mean( - celltraces_sorted_vis[:, start:start + binsize], axis=1) - binned_cells_vis[:, i, 1] = np.std( - celltraces_sorted_vis[:, start:start + binsize], - axis=1) / np.sqrt(binsize) + binned_dx_vis[i, 0] = np.mean(dx_sorted[start : start + binsize]) + binned_dx_vis[i, 1] = np.std(dx_sorted[start : start + binsize]) / np.sqrt(binsize) + binned_cells_vis[:, i, 0] = np.mean(celltraces_sorted_vis[:, start : start + binsize], axis=1) + binned_cells_vis[:, i, 1] = np.std(celltraces_sorted_vis[:, start : start + binsize], axis=1) / np.sqrt( + binsize + ) binned_cells_shuffled_vis = np.empty((self.numbercells, nbins, 2, 200)) for shuf in range(200): - celltraces_shuffled = \ - celltraces_vis[:, - np.random.permutation( - np.size(celltraces_vis, 1))] - celltraces_shuffled_sorted = celltraces_shuffled[ - :, np.argsort(dx_vis)] + celltraces_shuffled = celltraces_vis[:, np.random.permutation(np.size(celltraces_vis, 1))] + celltraces_shuffled_sorted = celltraces_shuffled[:, np.argsort(dx_vis)] for i in range(nbins): - offset = findlevel(dx_sorted, 1, 'up') + offset = findlevel(dx_sorted, 1, "up") if offset is None: - StimulusAnalysis._log.info( - "dx never crosses 1, all speed data going into " - "single bin") + StimulusAnalysis._log.info("dx never crosses 1, all speed data going into single bin") offset = len(dx_sorted) if i == 0: - binned_cells_shuffled_vis[:, i, 0, shuf] = np.mean( - celltraces_shuffled_sorted[:, :offset], axis=1) - binned_cells_shuffled_vis[:, i, 1, shuf] = np.std( - celltraces_shuffled_sorted[:, :offset], axis=1) + binned_cells_shuffled_vis[:, i, 0, shuf] = np.mean(celltraces_shuffled_sorted[:, :offset], axis=1) + binned_cells_shuffled_vis[:, i, 1, shuf] = np.std(celltraces_shuffled_sorted[:, :offset], axis=1) else: start = offset + (i - 1) * binsize binned_cells_shuffled_vis[:, i, 0, shuf] = np.mean( - celltraces_shuffled_sorted[:, start:start + binsize], - axis=1) + celltraces_shuffled_sorted[:, start : start + binsize], axis=1 + ) binned_cells_shuffled_vis[:, i, 1, shuf] = np.std( - celltraces_shuffled_sorted[:, start:start + binsize], - axis=1) + celltraces_shuffled_sorted[:, start : start + binsize], axis=1 + ) - shuffled_variance_sp = binned_cells_shuffled_sp[ - :, :, 0, :].std(axis=1) ** 2 - variance_threshold_sp = np.percentile( - shuffled_variance_sp, 99.9, axis=1) + shuffled_variance_sp = binned_cells_shuffled_sp[:, :, 0, :].std(axis=1) ** 2 + variance_threshold_sp = np.percentile(shuffled_variance_sp, 99.9, axis=1) response_variance_sp = binned_cells_sp[:, :, 0].std(axis=1) ** 2 - shuffled_variance_vis = binned_cells_shuffled_vis[ - :, :, 0, :].std(axis=1) ** 2 - variance_threshold_vis = np.percentile( - shuffled_variance_vis, 99.9, axis=1) + shuffled_variance_vis = binned_cells_shuffled_vis[:, :, 0, :].std(axis=1) ** 2 + variance_threshold_vis = np.percentile(shuffled_variance_vis, 99.9, axis=1) response_variance_vis = binned_cells_vis[:, :, 0].std(axis=1) ** 2 for nc in range(self.numbercells): @@ -507,48 +473,37 @@ def get_speed_tuning(self, binsize): start_min = temp.argmin() peak_run.speed_min_sp[nc] = binned_dx_sp[start_min, 0] if peak_run.speed_max_sp[nc] > peak_run.speed_min_sp[nc]: - test_values = celltraces_sorted_sp[ - nc, - start_max * binsize:(start_max + 1) * binsize] - other_values = np.delete(celltraces_sorted_sp[nc, :], range( - start_max * binsize, (start_max + 1) * binsize)) - (_, peak_run.ptest_sp[nc]) = nonraising_ks_2samp( - test_values, other_values) + test_values = celltraces_sorted_sp[nc, start_max * binsize : (start_max + 1) * binsize] + other_values = np.delete( + celltraces_sorted_sp[nc, :], range(start_max * binsize, (start_max + 1) * binsize) + ) + (_, peak_run.ptest_sp[nc]) = nonraising_ks_2samp(test_values, other_values) else: - test_values = celltraces_sorted_sp[ - nc, - start_min * binsize:(start_min + 1) * binsize] - ind_max = min(celltraces_sorted_sp[nc, :].size, - (start_min + 1) * binsize) - other_values = np.delete(celltraces_sorted_sp[nc, :], range( - start_min * binsize, ind_max)) - (_, peak_run.ptest_sp[nc]) = nonraising_ks_2samp( - test_values, other_values) + test_values = celltraces_sorted_sp[nc, start_min * binsize : (start_min + 1) * binsize] + ind_max = min(celltraces_sorted_sp[nc, :].size, (start_min + 1) * binsize) + other_values = np.delete(celltraces_sorted_sp[nc, :], range(start_min * binsize, ind_max)) + (_, peak_run.ptest_sp[nc]) = nonraising_ks_2samp(test_values, other_values) temp = binned_cells_vis[nc, :, 0] start_max = temp.argmax() peak_run.speed_max_vis[nc] = binned_dx_vis[start_max, 0] start_min = temp.argmin() peak_run.speed_min_vis[nc] = binned_dx_vis[start_min, 0] if peak_run.speed_max_vis[nc] > peak_run.speed_min_vis[nc]: - test_values = celltraces_sorted_vis[ - nc, - start_max * binsize:(start_max + 1) * binsize] - other_values = np.delete(celltraces_sorted_vis[nc, :], range( - start_max * binsize, (start_max + 1) * binsize)) + test_values = celltraces_sorted_vis[nc, start_max * binsize : (start_max + 1) * binsize] + other_values = np.delete( + celltraces_sorted_vis[nc, :], range(start_max * binsize, (start_max + 1) * binsize) + ) else: - test_values = celltraces_sorted_vis[ - nc, - start_min * binsize:(start_min + 1) * binsize] - other_values = np.delete(celltraces_sorted_vis[nc, :], range( - start_min * binsize, (start_min + 1) * binsize)) - (_, peak_run.ptest_vis[nc]) = nonraising_ks_2samp( - test_values, other_values) + test_values = celltraces_sorted_vis[nc, start_min * binsize : (start_min + 1) * binsize] + other_values = np.delete( + celltraces_sorted_vis[nc, :], range(start_min * binsize, (start_min + 1) * binsize) + ) + (_, peak_run.ptest_vis[nc]) = nonraising_ks_2samp(test_values, other_values) - return binned_dx_sp, binned_cells_sp, binned_dx_vis, \ - binned_cells_vis, peak_run + return binned_dx_sp, binned_cells_sp, binned_dx_vis, binned_cells_vis, peak_run def get_sweep_response(self): - """ Calculates the response to each sweep in the stimulus table for + """Calculates the response to each sweep in the stimulus table for each cell and the mean response. The return is a 3-tuple of: @@ -568,35 +523,29 @@ def get_sweep_response(self): def do_mean(x): # +1]) - return np.mean( - x[self.interlength: - self.interlength + self.sweeplength + self.extralength]) + return np.mean(x[self.interlength : self.interlength + self.sweeplength + self.extralength]) def do_p_value(x): - (_, p) = \ - st.f_oneway( - x[:self.interlength], - x[self.interlength: - self.interlength + self.sweeplength + self.extralength]) + (_, p) = st.f_oneway( + x[: self.interlength], x[self.interlength : self.interlength + self.sweeplength + self.extralength] + ) return p - StimulusAnalysis._log.info('Calculating responses for each sweep') - sweep_response = pd.DataFrame(index=self.stim_table.index.values, - columns=list(map(str, range( - self.numbercells + 1)))) + StimulusAnalysis._log.info("Calculating responses for each sweep") + sweep_response = pd.DataFrame( + index=self.stim_table.index.values, columns=list(map(str, range(self.numbercells + 1))) + ) - sweep_response.rename( - columns={str(self.numbercells): 'dx'}, inplace=True) + sweep_response.rename(columns={str(self.numbercells): "dx"}, inplace=True) for index, row in self.stim_table.iterrows(): - start = int(row['start'] - self.interlength) - end = int(row['start'] + self.sweeplength + self.interlength) + start = int(row["start"] - self.interlength) + end = int(row["start"] + self.sweeplength + self.interlength) for nc in range(self.numbercells): temp = self.celltraces[int(nc), start:end] - sweep_response[str(nc)][index] = \ - 100 * ((temp / np.mean(temp[:self.interlength])) - 1) - sweep_response['dx'][index] = self.dxcm[start:end] + sweep_response[str(nc)][index] = 100 * ((temp / np.mean(temp[: self.interlength])) - 1) + sweep_response["dx"][index] = self.dxcm[start:end] mean_sweep_response = sweep_response.applymap(do_mean) @@ -608,7 +557,7 @@ def plot_representational_similarity(self, repsim, stimulus=False): pass ax = plt.gca() - ax.imshow(repsim, interpolation='nearest', cmap='plasma') + ax.imshow(repsim, interpolation="nearest", cmap="plasma") def plot_running_speed_histogram(self, xlim=None, nbins=None): if xlim is None: @@ -622,10 +571,13 @@ def plot_running_speed_histogram(self, xlim=None, nbins=None): plt.xlabel("running speed (cm/s)") plt.ylabel("time points") - def plot_speed_tuning(self, cell_specimen_id=None, - cell_index=None, - evoked_color=oplots.EVOKED_COLOR, - spontaneous_color=oplots.SPONTANEOUS_COLOR): + def plot_speed_tuning( + self, + cell_specimen_id=None, + cell_index=None, + evoked_color=oplots.EVOKED_COLOR, + spontaneous_color=oplots.SPONTANEOUS_COLOR, + ): cell_index = self.row_from_cell_id(cell_specimen_id, cell_index) oplots.plot_combined_speed( @@ -633,24 +585,24 @@ def plot_speed_tuning(self, cell_specimen_id=None, self.binned_dx_vis[:, :], self.binned_cells_sp[cell_index, :, :] * 100, self.binned_dx_sp[:, :], - evoked_color, spontaneous_color) + evoked_color, + spontaneous_color, + ) plt.xlabel("running speed (cm/s)") plt.ylabel("percent dF/F") def row_from_cell_id(self, csid=None, idx=None): - if csid is not None and not np.isnan(csid): return self.data_set.get_cell_specimen_ids().tolist().index(csid) elif idx is not None: return idx else: - raise Exception("Could not find row for csid(%s) idx(%s)" - % (str(csid), str(idx))) + raise Exception("Could not find row for csid(%s) idx(%s)" % (str(csid), str(idx))) def nonraising_ks_2samp(data1, data2, **kwargs): - """ scipy.stats.ks_2samp now raises a ValueError if one of the input arrays + """scipy.stats.ks_2samp now raises a ValueError if one of the input arrays is of length 0. Previously it signaled this case by returning nans. This function restores the prior behavior. """ diff --git a/allensdk/brain_observatory/stimulus_info.py b/allensdk/brain_observatory/stimulus_info.py index 8d0f04abaf..848c280f20 100755 --- a/allensdk/brain_observatory/stimulus_info.py +++ b/allensdk/brain_observatory/stimulus_info.py @@ -202,9 +202,7 @@ def stimuli_in_session(session, allow_unknown=True): def all_stimuli(): """Return a list of all stimuli in the data set""" - return set( - [v for k, vl in SESSION_STIMULUS_MAP.items() for v in vl] - ) + return set([v for k, vl in SESSION_STIMULUS_MAP.items() for v in vl]) class BinaryIntervalSearchTree(object): @@ -262,9 +260,7 @@ def search(self, fi, tmp=None): if tmp is None: tmp = [] - if (self.data[tuple(tmp)][0] <= fi) and ( - fi <= self.data[tuple(tmp)][1] - ): + if (self.data[tuple(tmp)][0] <= fi) and (fi <= self.data[tuple(tmp)][1]): return_val = self.data[tuple(tmp)] elif fi < self.data[tuple(tmp)][1]: return_val = self.search(fi, tmp=tmp + [0]) @@ -311,9 +307,7 @@ def search(self, fi): def rotate(X, Y, theta): x = np.array([X, Y]) - M = np.array( - [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]] - ) + M = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) if len(x.shape) in [1, 2]: assert x.shape[0] == 2 return M.dot(x) @@ -353,25 +347,19 @@ def get_spatial_grating( img = np.cos(2.0 * np.pi * Xp * sf + ph) - return (p2p_amp / 2.0) * spndi.zoom( - img, height / float(_height_prime) - ) + baseline + return (p2p_amp / 2.0) * spndi.zoom(img, height / float(_height_prime)) + baseline # def grating_to_screen(self, phase, spatial_frequency, orientation, **kwargs): def get_spatio_temporal_grating(t, temporal_frequency=None, **kwargs): - kwargs["phase"] = ( - kwargs.pop("phase", 0) + (float(t) * temporal_frequency) % 1 - ) + kwargs["phase"] = kwargs.pop("phase", 0) + (float(t) * temporal_frequency) % 1 return get_spatial_grating(**kwargs) -def map_template_coordinate_to_monitor_coordinate( - template_coord, monitor_shape, template_shape -): +def map_template_coordinate_to_monitor_coordinate(template_coord, monitor_shape, template_shape): rx, cx = template_coord n_pixels_r, n_pixels_c = monitor_shape tr, tc = template_shape @@ -382,9 +370,7 @@ def map_template_coordinate_to_monitor_coordinate( return rx_new, cx_new -def map_monitor_coordinate_to_template_coordinate( - monitor_coord, monitor_shape, template_shape -): +def map_monitor_coordinate_to_template_coordinate(monitor_coord, monitor_shape, template_shape): rx, cx = monitor_coord n_pixels_r, n_pixels_c = monitor_shape tr, tc = template_shape @@ -395,9 +381,7 @@ def map_monitor_coordinate_to_template_coordinate( return rx_new, cx_new -def lsn_coordinate_to_monitor_coordinate( - lsn_coordinate, monitor_shape, stimulus_type -): +def lsn_coordinate_to_monitor_coordinate(lsn_coordinate, monitor_shape, stimulus_type): template_shape = LOCALLY_SPARSE_NOISE_DIMENSIONS[stimulus_type] pixels_per_patch = LOCALLY_SPARSE_NOISE_PIXELS[stimulus_type] @@ -411,9 +395,7 @@ def lsn_coordinate_to_monitor_coordinate( ) -def monitor_coordinate_to_lsn_coordinate( - monitor_coordinate, monitor_shape, stimulus_type -): +def monitor_coordinate_to_lsn_coordinate(monitor_coordinate, monitor_shape, stimulus_type): pixels_per_patch = LOCALLY_SPARSE_NOISE_PIXELS[stimulus_type] tr, tc = LOCALLY_SPARSE_NOISE_DIMENSIONS[stimulus_type] @@ -426,50 +408,24 @@ def monitor_coordinate_to_lsn_coordinate( return (rx / pixels_per_patch, cx / pixels_per_patch) -def natural_scene_coordinate_to_monitor_coordinate( - natural_scene_coordinate, monitor_shape -): - return map_template_coordinate_to_monitor_coordinate( - natural_scene_coordinate, monitor_shape, NATURAL_SCENES_PIXELS - ) +def natural_scene_coordinate_to_monitor_coordinate(natural_scene_coordinate, monitor_shape): + return map_template_coordinate_to_monitor_coordinate(natural_scene_coordinate, monitor_shape, NATURAL_SCENES_PIXELS) -def natural_movie_coordinate_to_monitor_coordinate( - natural_movie_coordinate, monitor_shape -): - local_y = ( - 1.0 - * NATURAL_MOVIE_PIXELS[0] - * natural_movie_coordinate[0] - / NATURAL_MOVIE_DIMENSIONS[0] - ) - local_x = ( - 1.0 - * NATURAL_MOVIE_PIXELS[1] - * natural_movie_coordinate[1] - / NATURAL_MOVIE_DIMENSIONS[1] - ) +def natural_movie_coordinate_to_monitor_coordinate(natural_movie_coordinate, monitor_shape): + local_y = 1.0 * NATURAL_MOVIE_PIXELS[0] * natural_movie_coordinate[0] / NATURAL_MOVIE_DIMENSIONS[0] + local_x = 1.0 * NATURAL_MOVIE_PIXELS[1] * natural_movie_coordinate[1] / NATURAL_MOVIE_DIMENSIONS[1] - return map_template_coordinate_to_monitor_coordinate( - (local_y, local_x), monitor_shape, NATURAL_MOVIE_PIXELS - ) + return map_template_coordinate_to_monitor_coordinate((local_y, local_x), monitor_shape, NATURAL_MOVIE_PIXELS) -def map_stimulus_coordinate_to_monitor_coordinate( - template_coordinate, monitor_shape, stimulus_type -): +def map_stimulus_coordinate_to_monitor_coordinate(template_coordinate, monitor_shape, stimulus_type): if stimulus_type in LOCALLY_SPARSE_NOISE_STIMULUS_TYPES: - return lsn_coordinate_to_monitor_coordinate( - template_coordinate, monitor_shape, stimulus_type - ) + return lsn_coordinate_to_monitor_coordinate(template_coordinate, monitor_shape, stimulus_type) elif stimulus_type in NATURAL_MOVIE_STIMULUS_TYPES: - return natural_movie_coordinate_to_monitor_coordinate( - template_coordinate, monitor_shape - ) + return natural_movie_coordinate_to_monitor_coordinate(template_coordinate, monitor_shape) elif stimulus_type == NATURAL_SCENES: - return natural_scene_coordinate_to_monitor_coordinate( - template_coordinate, monitor_shape - ) + return natural_scene_coordinate_to_monitor_coordinate(template_coordinate, monitor_shape) elif stimulus_type in [ DRIFTING_GRATINGS, STATIC_GRATINGS, @@ -480,9 +436,7 @@ def map_stimulus_coordinate_to_monitor_coordinate( raise NotImplementedError # pragma: no cover -def monitor_coordinate_to_natural_movie_coordinate( - monitor_coordinate, monitor_shape -): +def monitor_coordinate_to_natural_movie_coordinate(monitor_coordinate, monitor_shape): local_y, local_x = map_monitor_coordinate_to_template_coordinate( monitor_coordinate, monitor_shape, NATURAL_MOVIE_PIXELS ) @@ -493,21 +447,13 @@ def monitor_coordinate_to_natural_movie_coordinate( ) -def map_monitor_coordinate_to_stimulus_coordinate( - monitor_coordinate, monitor_shape, stimulus_type -): +def map_monitor_coordinate_to_stimulus_coordinate(monitor_coordinate, monitor_shape, stimulus_type): if stimulus_type in LOCALLY_SPARSE_NOISE_STIMULUS_TYPES: - return monitor_coordinate_to_lsn_coordinate( - monitor_coordinate, monitor_shape, stimulus_type - ) + return monitor_coordinate_to_lsn_coordinate(monitor_coordinate, monitor_shape, stimulus_type) elif stimulus_type == NATURAL_SCENES: - return map_monitor_coordinate_to_template_coordinate( - monitor_coordinate, monitor_shape, NATURAL_SCENES_PIXELS - ) + return map_monitor_coordinate_to_template_coordinate(monitor_coordinate, monitor_shape, NATURAL_SCENES_PIXELS) elif stimulus_type in NATURAL_MOVIE_STIMULUS_TYPES: - return monitor_coordinate_to_natural_movie_coordinate( - monitor_coordinate, monitor_shape - ) + return monitor_coordinate_to_natural_movie_coordinate(monitor_coordinate, monitor_shape) elif stimulus_type in [ DRIFTING_GRATINGS, STATIC_GRATINGS, @@ -524,12 +470,8 @@ def map_stimulus( target_stimulus_type, monitor_shape, ): - mc = map_stimulus_coordinate_to_monitor_coordinate( - source_stimulus_coordinate, monitor_shape, source_stimulus_type - ) - return map_monitor_coordinate_to_stimulus_coordinate( - mc, monitor_shape, target_stimulus_type - ) + mc = map_stimulus_coordinate_to_monitor_coordinate(source_stimulus_coordinate, monitor_shape, source_stimulus_type) + return map_monitor_coordinate_to_stimulus_coordinate(mc, monitor_shape, target_stimulus_type) def translate_image_and_fill(img, translation=(0, 0)): @@ -581,9 +523,7 @@ def aspect_ratio(self): @property def height(self): - return self.spatial_conversion_factor * np.sqrt( - self.panel_size**2 / (1 + self.aspect_ratio**2) - ) + return self.spatial_conversion_factor * np.sqrt(self.panel_size**2 / (1 + self.aspect_ratio**2)) @property def width(self): @@ -604,32 +544,17 @@ def set_spatial_unit(self, new_unit): def pixel_size(self): return float(self.width) / self.n_pixels_c - def pixels_to_visual_degrees( - self, n, distance_from_monitor, small_angle_approximation=True - ): + def pixels_to_visual_degrees(self, n, distance_from_monitor, small_angle_approximation=True): if small_angle_approximation: - return ( - n - * self.pixel_size - / distance_from_monitor - * RADIANS_TO_DEGREES - ) # radians to degrees + return n * self.pixel_size / distance_from_monitor * RADIANS_TO_DEGREES # radians to degrees else: return ( - 2 - * np.arctan( - n * 1.0 / 2 * self.pixel_size / distance_from_monitor - ) - * RADIANS_TO_DEGREES + 2 * np.arctan(n * 1.0 / 2 * self.pixel_size / distance_from_monitor) * RADIANS_TO_DEGREES ) # radians to degrees - def visual_degrees_to_pixels( - self, vd, distance_from_monitor, small_angle_approximation=True - ): + def visual_degrees_to_pixels(self, vd, distance_from_monitor, small_angle_approximation=True): if small_angle_approximation: - return vd * ( - distance_from_monitor / self.pixel_size / RADIANS_TO_DEGREES - ) + return vd * (distance_from_monitor / self.pixel_size / RADIANS_TO_DEGREES) else: raise NotImplementedError @@ -650,24 +575,14 @@ def lsn_image_to_screen( ) pixels_per_patch = float(LOCALLY_SPARSE_NOISE_PIXELS[stimulus_type]) - target_size = tuple( - int(pixels_per_patch * dimsize) for dimsize in img.shape[::-1] - ) - img_full_res = np.array( - Image.fromarray(img).resize(target_size, 0) - ) # 0 -> nearest neighbor interpolator + target_size = tuple(int(pixels_per_patch * dimsize) for dimsize in img.shape[::-1]) + img_full_res = np.array(Image.fromarray(img).resize(target_size, 0)) # 0 -> nearest neighbor interpolator - mr, mc = lsn_coordinate_to_monitor_coordinate( - (0, 0), (self.n_pixels_r, self.n_pixels_c), stimulus_type - ) - Mr, Mc = lsn_coordinate_to_monitor_coordinate( - img.shape, (self.n_pixels_r, self.n_pixels_c), stimulus_type - ) + mr, mc = lsn_coordinate_to_monitor_coordinate((0, 0), (self.n_pixels_r, self.n_pixels_c), stimulus_type) + Mr, Mc = lsn_coordinate_to_monitor_coordinate(img.shape, (self.n_pixels_r, self.n_pixels_c), stimulus_type) full_image[int(mr) : int(Mr), int(mc) : int(Mc)] = img_full_res - full_image = translate_image_and_fill( - full_image, translation=translation - ) + full_image = translate_image_and_fill(full_image, translation=translation) if origin == "lower": return full_image @@ -678,23 +593,15 @@ def lsn_image_to_screen( return full_image - def natural_scene_image_to_screen( - self, img, origin="lower", translation=(0, 0) - ): - full_image = np.full( - (self.n_pixels_r, self.n_pixels_c), 127, dtype=np.uint8 - ) - mr, mc = natural_scene_coordinate_to_monitor_coordinate( - (0, 0), (self.n_pixels_r, self.n_pixels_c) - ) + def natural_scene_image_to_screen(self, img, origin="lower", translation=(0, 0)): + full_image = np.full((self.n_pixels_r, self.n_pixels_c), 127, dtype=np.uint8) + mr, mc = natural_scene_coordinate_to_monitor_coordinate((0, 0), (self.n_pixels_r, self.n_pixels_c)) Mr, Mc = natural_scene_coordinate_to_monitor_coordinate( (img.shape[0], img.shape[1]), (self.n_pixels_r, self.n_pixels_c) ) full_image[int(mr) : int(Mr), int(mc) : int(Mc)] = img - full_image = translate_image_and_fill( - full_image, translation=translation - ) + full_image = translate_image_and_fill(full_image, translation=translation) if origin == "lower": return np.flipud(full_image) @@ -703,20 +610,14 @@ def natural_scene_image_to_screen( else: raise Exception - def natural_movie_image_to_screen( - self, img, origin="lower", translation=(0, 0) - ): - img = np.array( - Image.fromarray(img).resize(NATURAL_MOVIE_PIXELS[::-1], 2) - ).astype( + def natural_movie_image_to_screen(self, img, origin="lower", translation=(0, 0)): + img = np.array(Image.fromarray(img).resize(NATURAL_MOVIE_PIXELS[::-1], 2)).astype( np.uint8 ) # 2 -> bilinear interpolator assert img.dtype == np.uint8 - full_image = np.full( - (self.n_pixels_r, self.n_pixels_c), 127, dtype=np.uint8 - ) + full_image = np.full((self.n_pixels_r, self.n_pixels_c), 127, dtype=np.uint8) mr, mc = map_template_coordinate_to_monitor_coordinate( (0, 0), (self.n_pixels_r, self.n_pixels_c), NATURAL_MOVIE_PIXELS ) @@ -728,9 +629,7 @@ def natural_movie_image_to_screen( full_image[int(mr) : int(Mr), int(mc) : int(Mc)] = img - full_image = translate_image_and_fill( - full_image, translation=translation - ) + full_image = translate_image_and_fill(full_image, translation=translation) if origin == "lower": return np.flipud(full_image) @@ -739,15 +638,9 @@ def natural_movie_image_to_screen( else: raise Exception - def spatial_frequency_to_pix_per_cycle( - self, spatial_frequency, distance_from_monitor - ): + def spatial_frequency_to_pix_per_cycle(self, spatial_frequency, distance_from_monitor): # How many cycles do I want to see post warp: - number_of_cycles = ( - spatial_frequency - * 2 - * np.degrees(np.arctan(self.width / 2.0 / distance_from_monitor)) - ) + number_of_cycles = spatial_frequency * 2 * np.degrees(np.arctan(self.width / 2.0 / distance_from_monitor)) # How many pixels to I have pre-warp to place my cycles on: _, m_col = np.where(self.mask != 0) @@ -765,9 +658,7 @@ def grating_to_screen( baseline=127, translation=(0, 0), ): - pix_per_cycle = self.spatial_frequency_to_pix_per_cycle( - spatial_frequency, distance_from_monitor - ) + pix_per_cycle = self.spatial_frequency_to_pix_per_cycle(spatial_frequency, distance_from_monitor) full_image = get_spatial_grating( height=self.n_pixels_r, @@ -779,24 +670,18 @@ def grating_to_screen( baseline=baseline, ) - full_image = translate_image_and_fill( - full_image, translation=translation - ) + full_image = translate_image_and_fill(full_image, translation=translation) return full_image def get_mask(self): - mask = make_display_mask( - display_shape=(self.n_pixels_c, self.n_pixels_r) - ).T + mask = make_display_mask(display_shape=(self.n_pixels_c, self.n_pixels_r)).T assert mask.shape[0] == self.n_pixels_r assert mask.shape[1] == self.n_pixels_c return mask - def show_image( - self, img, ax=None, show=True, mask=False, warp=False, origin="lower" - ): + def show_image(self, img, ax=None, show=True, mask=False, warp=False, origin="lower"): import matplotlib.pyplot as plt assert img.shape == ( @@ -816,9 +701,7 @@ def show_image( ax.imshow(img, origin=origin, cmap=plt.cm.gray, interpolation="none") if mask: - mask = make_display_mask( - display_shape=(self.n_pixels_c, self.n_pixels_r) - ).T + mask = make_display_mask(display_shape=(self.n_pixels_c, self.n_pixels_r)).T alpha_mask = np.zeros((mask.shape[0], mask.shape[1], 4)) alpha_mask[:, :, 2] = 1 - mask alpha_mask[:, :, 3] = 0.4 @@ -854,9 +737,7 @@ def map_stimulus( class ExperimentGeometry(object): - def __init__( - self, distance, mon_height_cm, mon_width_cm, mon_res, eyepoint - ): + def __init__(self, distance, mon_height_cm, mon_width_cm, mon_res, eyepoint): self.distance = distance self.mon_height_cm = mon_height_cm self.mon_width_cm = mon_width_cm @@ -902,9 +783,7 @@ class BrainObservatoryMonitor(Monitor): def __init__(self, experiment_geometry=None): height, width = MONITOR_DIMENSIONS - super(BrainObservatoryMonitor, self).__init__( - height, width, 61.214, "cm" - ) + super(BrainObservatoryMonitor, self).__init__(height, width, 61.214, "cm") if experiment_geometry is None: self.experiment_geometry = ExperimentGeometry( @@ -918,24 +797,12 @@ def __init__(self, experiment_geometry=None): self.experiment_geometry = experiment_geometry def lsn_image_to_screen(self, img, **kwargs): - if img.shape == tuple( - LOCALLY_SPARSE_NOISE_DIMENSIONS[LOCALLY_SPARSE_NOISE] - ): - return super(BrainObservatoryMonitor, self).lsn_image_to_screen( - img, LOCALLY_SPARSE_NOISE, **kwargs - ) - elif img.shape == tuple( - LOCALLY_SPARSE_NOISE_DIMENSIONS[LOCALLY_SPARSE_NOISE_4DEG] - ): - return super(BrainObservatoryMonitor, self).lsn_image_to_screen( - img, LOCALLY_SPARSE_NOISE_4DEG, **kwargs - ) - elif img.shape == tuple( - LOCALLY_SPARSE_NOISE_DIMENSIONS[LOCALLY_SPARSE_NOISE_8DEG] - ): - return super(BrainObservatoryMonitor, self).lsn_image_to_screen( - img, LOCALLY_SPARSE_NOISE_8DEG, **kwargs - ) + if img.shape == tuple(LOCALLY_SPARSE_NOISE_DIMENSIONS[LOCALLY_SPARSE_NOISE]): + return super(BrainObservatoryMonitor, self).lsn_image_to_screen(img, LOCALLY_SPARSE_NOISE, **kwargs) + elif img.shape == tuple(LOCALLY_SPARSE_NOISE_DIMENSIONS[LOCALLY_SPARSE_NOISE_4DEG]): + return super(BrainObservatoryMonitor, self).lsn_image_to_screen(img, LOCALLY_SPARSE_NOISE_4DEG, **kwargs) + elif img.shape == tuple(LOCALLY_SPARSE_NOISE_DIMENSIONS[LOCALLY_SPARSE_NOISE_8DEG]): + return super(BrainObservatoryMonitor, self).lsn_image_to_screen(img, LOCALLY_SPARSE_NOISE_8DEG, **kwargs) else: # pragma: no cover raise RuntimeError # pragma: no cover @@ -943,13 +810,11 @@ def warp_image(self, img, **kwargs): assert img.shape == (self.n_pixels_r, self.n_pixels_c) assert self.spatial_unit == "cm" - return spndi.map_coordinates( - img, self.experiment_geometry.warp_coordinates.T - ).reshape((self.n_pixels_r, self.n_pixels_c)) + return spndi.map_coordinates(img, self.experiment_geometry.warp_coordinates.T).reshape( + (self.n_pixels_r, self.n_pixels_c) + ) - def grating_to_screen( - self, phase, spatial_frequency, orientation, **kwargs - ): + def grating_to_screen(self, phase, spatial_frequency, orientation, **kwargs): return super(BrainObservatoryMonitor, self).grating_to_screen( phase, spatial_frequency, @@ -1097,9 +962,7 @@ def make_display_mask(display_shape=(1920, 1200)): return mask -def mask_stimulus_template( - template_display_coords, template_shape, display_mask=None, threshold=1.0 -): +def mask_stimulus_template(template_display_coords, template_shape, display_mask=None, threshold=1.0): """Build a mask for a stimulus template of a given shape and display coordinates that indicates which part of the template is on screen after warping. @@ -1131,10 +994,7 @@ def mask_stimulus_template( mask = np.zeros(template_shape, dtype=bool) for y in range(template_shape[1]): for x in range(template_shape[0]): - tdcm = np.where( - (template_display_coords[0, :, :] == x) - & (template_display_coords[1, :, :] == y) - ) + tdcm = np.where((template_display_coords[0, :, :] == x) & (template_display_coords[1, :, :] == y)) v = display_mask[tdcm] f = np.sum(v) / len(v) frac[x, y] = f diff --git a/allensdk/brain_observatory/sync_dataset.py b/allensdk/brain_observatory/sync_dataset.py index eab4140d14..1eef020c42 100644 --- a/allensdk/brain_observatory/sync_dataset.py +++ b/allensdk/brain_observatory/sync_dataset.py @@ -14,6 +14,7 @@ h5py http://www.h5py.org/ """ + import collections from typing import Union, Sequence, Optional @@ -22,12 +23,13 @@ import warnings import logging + logger = logging.getLogger(__name__) dset_version = 1.04 -def unpack_uint32(uint32_array, endian='L'): +def unpack_uint32(uint32_array, endian="L"): """ Unpacks an array of 32-bit unsigned integers into bits. @@ -45,7 +47,7 @@ def unpack_uint32(uint32_array, endian='L'): uint8_array = np.frombuffer(buff, dtype=np.uint8) uint8_array = np.fliplr(uint8_array.reshape(-1, 4)) bits = np.unpackbits(uint8_array).reshape(-1, 32) - if endian.upper() == 'B': + if endian.upper() == "B": bits = np.fliplr(bits) return bits @@ -62,7 +64,7 @@ def get_bit(uint_array, bit): The bit to extract. """ - return np.bitwise_and(uint_array, 2 ** bit).astype(bool).astype(np.uint8) + return np.bitwise_and(uint_array, 2**bit).astype(bool).astype(np.uint8) class Dataset(object): @@ -93,23 +95,28 @@ class Dataset(object): """ - FRAME_KEYS = ('frames', 'stim_vsync', 'vsync_stim') - PHOTODIODE_KEYS = ('photodiode', 'stim_photodiode') + + FRAME_KEYS = ("frames", "stim_vsync", "vsync_stim") + PHOTODIODE_KEYS = ("photodiode", "stim_photodiode") OPTOGENETIC_STIMULATION_KEYS = ("LED_sync", "opto_trial") - EYE_TRACKING_KEYS = ("eye_frame_received", # Expected eye tracking - # line label after 3/27/2020 - # clocks eye tracking frame pulses (port 0, line 9) - "cam2_exposure", - # previous line label for eye tracking - # (prior to ~ Oct. 2018) - "eyetracking", - "eye_cam_exposing", - "eye_tracking") # An undocumented, but possible eye tracking line label # NOQA E114 - BEHAVIOR_TRACKING_KEYS = ("beh_frame_received", # Expected behavior line label after 3/27/2020 # NOQA E127 - # clocks behavior tracking frame # NOQA E127 - # pulses (port 0, line 8) - "cam1_exposure", - "behavior_monitoring") + EYE_TRACKING_KEYS = ( + "eye_frame_received", # Expected eye tracking + # line label after 3/27/2020 + # clocks eye tracking frame pulses (port 0, line 9) + "cam2_exposure", + # previous line label for eye tracking + # (prior to ~ Oct. 2018) + "eyetracking", + "eye_cam_exposing", + "eye_tracking", + ) # An undocumented, but possible eye tracking line label # NOQA E114 + BEHAVIOR_TRACKING_KEYS = ( + "beh_frame_received", # Expected behavior line label after 3/27/2020 # NOQA E127 + # clocks behavior tracking frame # NOQA E127 + # pulses (port 0, line 8) + "cam1_exposure", + "behavior_monitoring", + ) DEPRECATED_KEYS = set() @@ -121,13 +128,17 @@ def _check_line_labels(self): if hasattr(self, "line_labels"): deprecated_keys = set(self.line_labels) & self.DEPRECATED_KEYS if deprecated_keys: - warnings.warn((f"The loaded sync file contains the " - f"following deprecated line label keys: " - f"{deprecated_keys}. Consider updating the " - f"sync file line labels."), stacklevel=2) + warnings.warn( + ( + f"The loaded sync file contains the " + f"following deprecated line label keys: " + f"{deprecated_keys}. Consider updating the " + f"sync file line labels." + ), + stacklevel=2, + ) else: - warnings.warn(("The loaded sync file has no line labels and may " - "not be valid."), stacklevel=2) + warnings.warn(("The loaded sync file has no line labels and may not be valid."), stacklevel=2) def _process_times(self): """ @@ -156,18 +167,19 @@ def load(self, path): """ self.dfile = h5.File( - path, 'r') # MG edit 3/15 removed 'r' because some sync files were unable to load # NOQA E501 - self.meta_data = eval(self.dfile['meta'][()]) - self.line_labels = self.meta_data['line_labels'] + path, "r" + ) # MG edit 3/15 removed 'r' because some sync files were unable to load # NOQA E501 + self.meta_data = eval(self.dfile["meta"][()]) + self.line_labels = self.meta_data["line_labels"] self.times = self._process_times() return self.dfile @property def sample_freq(self): try: - return float(self.meta_data['ni_daq']['sample_freq']) + return float(self.meta_data["ni_daq"]["sample_freq"]) except KeyError: - return float(self.meta_data['ni_daq']['counter_output_freq']) + return float(self.meta_data["ni_daq"]["counter_output_freq"]) def get_bit(self, bit): """ @@ -226,9 +238,9 @@ def get_all_bits(self): Returns the data for all bits. """ - return self.dfile['data'][()][:, -1] + return self.dfile["data"][()][:, -1] - def get_all_times(self, units='samples'): + def get_all_times(self, units="samples"): """ Returns all counter values. @@ -238,14 +250,14 @@ def get_all_times(self, units='samples'): Return times in 'samples' or 'seconds' """ - if self.meta_data['ni_daq']['counter_bits'] == 32: + if self.meta_data["ni_daq"]["counter_bits"] == 32: times = self.get_all_events()[:, 0] else: times = self.times units = units.lower() - if units == 'samples': + if units == "samples": return times - elif units in ['seconds', 'sec', 'secs']: + elif units in ["seconds", "sec", "secs"]: freq = self.sample_freq return times / freq else: @@ -255,9 +267,9 @@ def get_all_events(self): """ Returns all counter values and their cooresponding IO state. """ - return self.dfile['data'][()] + return self.dfile["data"][()] - def get_events_by_bit(self, bit, units='samples'): + def get_events_by_bit(self, bit, units="samples"): """ Returns all counter values for transitions (both rising and falling) for a specific bit. @@ -271,7 +283,7 @@ def get_events_by_bit(self, bit, units='samples'): changes = self.get_bit_changes(bit) return self.get_all_times(units)[np.where(changes != 0)] - def get_events_by_line(self, line, units='samples'): + def get_events_by_line(self, line, units="samples"): """ Returns all counter values for transitions (both rising and falling) for a specific line. @@ -314,7 +326,7 @@ def _bit_to_line(self, bit): """ return self.line_labels[bit] - def get_rising_edges(self, line, units='samples'): + def get_rising_edges(self, line, units="samples"): """ Returns the counter values for the rizing edges for a specific bit or line. @@ -330,13 +342,9 @@ def get_rising_edges(self, line, units='samples'): return self.get_all_times(units)[np.where(changes == 1)] def get_edges( - self, - kind: str, - keys: Union[str, Sequence[str]], - units: str = "seconds", - permissive: bool = False + self, kind: str, keys: Union[str, Sequence[str]], units: str = "seconds", permissive: bool = False ) -> Optional[np.ndarray]: - """ Utility function for extracting edge times from a line + """Utility function for extracting edge times from a line Parameters ---------- @@ -361,15 +369,14 @@ def get_edges( line labels """ - if kind == 'falling': + if kind == "falling": fn = self.get_falling_edges - elif kind == 'rising': + elif kind == "rising": fn = self.get_rising_edges - elif kind == 'all': - return np.sort(np.concatenate([ - self.get_edges('rising', keys, units), - self.get_edges('falling', keys, units) - ])) + elif kind == "all": + return np.sort( + np.concatenate([self.get_edges("rising", keys, units), self.get_edges("falling", keys, units)]) + ) if isinstance(keys, str): keys = [keys] @@ -381,10 +388,9 @@ def get_edges( continue if not permissive: - raise KeyError( - f"none of {keys} were found in this dataset's line labels") + raise KeyError(f"none of {keys} were found in this dataset's line labels") - def get_falling_edges(self, line, units='samples'): + def get_falling_edges(self, line, units="samples"): """ Returns the counter values for the falling edges for a specific bit or line. @@ -399,14 +405,15 @@ def get_falling_edges(self, line, units='samples'): changes = self.get_bit_changes(bit) return self.get_all_times(units)[np.where(changes == 255)] - def get_nearest(self, - source, - target, - source_edge="rising", - target_edge="rising", - direction="previous", - units='indices', - ): + def get_nearest( + self, + source, + target, + source_edge="rising", + target_edge="rising", + direction="previous", + units="indices", + ): """ For all values of the source line, finds the nearest edge from the target line. @@ -423,30 +430,23 @@ def get_nearest(self, units (str): "indices" """ - source_edges = getattr(self, - "get_{}_edges".format(source_edge.lower()))(source.lower(), units="samples") # NOQA E501 - target_edges = getattr(self, - "get_{}_edges".format(target_edge.lower()))(target.lower(), units="samples") # NOQA E501 + source_edges = getattr(self, "get_{}_edges".format(source_edge.lower()))(source.lower(), units="samples") # NOQA E501 + target_edges = getattr(self, "get_{}_edges".format(target_edge.lower()))(target.lower(), units="samples") # NOQA E501 indices = np.searchsorted(target_edges, source_edges, side="right") if direction.lower() == "previous": indices[np.where(indices != 0)] -= 1 elif direction.lower() == "next": indices[np.where(indices == len(target_edges))] = -1 - if units in ["indices", 'index']: + if units in ["indices", "index"]: return indices elif units == "samples": return target_edges[indices] - elif units in ['sec', 'seconds', 'second']: + elif units in ["sec", "seconds", "second"]: return target_edges[indices] / self.sample_freq else: - raise KeyError( - "Invalid units. Try 'seconds', 'samples' or 'indices'") + raise KeyError("Invalid units. Try 'seconds', 'samples' or 'indices'") - def get_analog_channel(self, - channel, - start_time=0.0, - stop_time=None, - downsample=1): + def get_analog_channel(self, channel, start_time=0.0, stop_time=None, downsample=1): """ Returns the data from the specified analog channel between the timepoints. @@ -465,15 +465,13 @@ def get_analog_channel(self, """ if isinstance(channel, str): - channel_index = self.analog_meta_data['analog_labels'].index( - channel) - channel = self.analog_meta_data['analog_channels'].index( - channel_index) + channel_index = self.analog_meta_data["analog_labels"].index(channel) + channel = self.analog_meta_data["analog_channels"].index(channel_index) if "analog_data" in self.dfile.keys(): - dset = self.dfile['analog_data'] + dset = self.dfile["analog_data"] analog_meta = self.get_analog_meta() - sample_rate = analog_meta['analog_sample_rate'] + sample_rate = analog_meta["analog_sample_rate"] start = int(start_time * sample_rate) if stop_time: stop = int(stop_time * sample_rate) @@ -488,7 +486,7 @@ def get_analog_meta(self): Returns the metadata for the analog data. """ if "analog_meta" in self.dfile.keys(): - return eval(self.dfile['analog_meta'].value) + return eval(self.dfile["analog_meta"].value) else: raise KeyError("No analog data was saved.") @@ -536,22 +534,21 @@ def line_stats(self, line, print_results=True): logger.info("Falling: %s" % total_falling) logger.info("*" * 70) return { - 'line': line, - 'bit': bit, - 'total_rising': total_rising, - 'total_falling': total_falling, - 'avg_freq': None, - 'duty_cycle': None, + "line": line, + "bit": bit, + "total_rising": total_rising, + "total_falling": total_falling, + "avg_freq": None, + "duty_cycle": None, } else: - # period period = self.period(line) - avg_period = period['avg'] - max_period = period['max'] - min_period = period['min'] - period_sd = period['sd'] + avg_period = period["avg"] + max_period = period["max"] + min_period = period["min"] + period_sd = period["sd"] # freq avg_freq = self.frequency(line) @@ -578,18 +575,18 @@ def line_stats(self, line, print_results=True): logger.info("*" * 70) return { - 'line': line, - 'bit': bit, - 'total_data_points': total_data_points, - 'total_events': total_events, - 'total_rising': total_rising, - 'total_falling': total_falling, - 'avg_period': avg_period, - 'min_period': min_period, - 'max_period': max_period, - 'period_sd': period_sd, - 'avg_freq': avg_freq, - 'duty_cycle': duty_cycle, + "line": line, + "bit": bit, + "total_data_points": total_data_points, + "total_events": total_events, + "total_rising": total_rising, + "total_falling": total_falling, + "avg_period": avg_period, + "min_period": min_period, + "max_period": max_period, + "period_sd": period_sd, + "avg_freq": avg_freq, + "duty_cycle": duty_cycle, } def period(self, line, edge="rising"): @@ -604,8 +601,7 @@ def period(self, line, edge="rising"): edges = self.get_falling_edges(bit) if len(edges) > 2: - - timebase_freq = self.meta_data['ni_daq']['counter_output_freq'] + timebase_freq = self.meta_data["ni_daq"]["counter_output_freq"] avg_period = np.mean(np.ediff1d(edges[1:])) / timebase_freq max_period = np.max(np.ediff1d(edges[1:])) / timebase_freq min_period = np.min(np.ediff1d(edges[1:])) / timebase_freq @@ -615,10 +611,10 @@ def period(self, line, edge="rising"): raise IndexError("Not enough edges for period: %i" % len(edges)) return { - 'avg': avg_period, - 'max': max_period, - 'min': min_period, - 'sd': period_sd, + "avg": avg_period, + "max": max_period, + "min": min_period, + "sd": period_sd, } def frequency(self, line, edge="rising"): @@ -627,7 +623,7 @@ def frequency(self, line, edge="rising"): """ period = self.period(line, edge) - return 1.0 / period['avg'] + return 1.0 / period["avg"] def duty_cycle(self, line): """ @@ -657,8 +653,7 @@ def duty_cycle(self, line): high = falling - rising else: # line starts high - high = np.concatenate(falling, self.get_all_events()[-1, 0]) - \ - np.concatenate(0, rising) + high = np.concatenate(falling, self.get_all_events()[-1, 0]) - np.concatenate(0, rising) total_high_time = np.sum(high) all_events = self.get_events_by_bit(bit) @@ -677,20 +672,21 @@ def stats(self): logger.info("Active bits: ", len(active_bits)) for bit in active_bits: logger.info("*" * 70) - logger.info("Bit: %i" % bit['bit']) - logger.info("Label: %s" % self.line_labels[bit['bit']]) - logger.info("Rising edges: %i" % bit['total_rising']) + logger.info("Bit: %i" % bit["bit"]) + logger.info("Label: %s" % self.line_labels[bit["bit"]]) + logger.info("Rising edges: %i" % bit["total_rising"]) logger.info("Falling edges: %i" % bit["total_falling"]) - logger.info("Average freq: %s" % bit['avg_freq']) - logger.info("Duty cycle: %s" % bit['duty_cycle']) + logger.info("Average freq: %s" % bit["avg_freq"]) + logger.info("Duty cycle: %s" % bit["duty_cycle"]) logger.info("*" * 70) return active_bits - def plot_all(self, - start_time, - stop_time, - auto_show=True, - ): + def plot_all( + self, + start_time, + stop_time, + auto_show=True, + ): """ Plot all active bits. @@ -698,21 +694,25 @@ def plot_all(self, """ import matplotlib.pyplot as plt + for bit in range(32): if len(self.get_events_by_bit(bit)) > 0: - self.plot_bit(bit, - start_time, - stop_time, - auto_show=False, ) + self.plot_bit( + bit, + start_time, + stop_time, + auto_show=False, + ) if auto_show: plt.show() - def plot_bits(self, - bits, - start_time=0.0, - end_time=None, - auto_show=True, - ): + def plot_bits( + self, + bits, + start_time=0.0, + end_time=None, + auto_show=True, + ): """ Plots a list of bits. """ @@ -724,11 +724,7 @@ def plot_bits(self, axes = [axes] for bit, ax in zip(bits, axes): - self.plot_bit(bit, - start_time, - end_time, - auto_show=False, - axes=ax) + self.plot_bit(bit, start_time, end_time, auto_show=False, axes=ax) # f.set_size_inches(18, 10, forward=True) f.subplots_adjust(hspace=0) @@ -737,22 +733,23 @@ def plot_bits(self, return f, axes - def plot_bit(self, - bit, - start_time=0.0, - end_time=None, - auto_show=True, - axes=None, - name="", - ): + def plot_bit( + self, + bit, + start_time=0.0, + end_time=None, + auto_show=True, + axes=None, + name="", + ): """ Plots a specific bit at a specific time period. """ import matplotlib.pyplot as plt - times = self.get_all_times(units='sec') + times = self.get_all_times(units="sec") if not end_time: - end_time = 2 ** 32 + end_time = 2**32 window = (times < end_time) & (times > start_time) @@ -767,7 +764,7 @@ def plot_bit(self, name = str(bit) bit = self.get_bit(bit) - ax.step(times[window], bit[window], where='post') + ax.step(times[window], bit[window], where="post") if hasattr(ax, "set_ylim"): ax.set_ylim(-0.1, 1.1) else: @@ -776,7 +773,7 @@ def plot_bit(self, # ax.set_ylabel('Logic State') # ax.yaxis.set_ticks_position('none') plt.setp(ax.get_yticklabels(), visible=False) - ax.set_xlabel('time (seconds)') + ax.set_xlabel("time (seconds)") ax.legend([name]) if auto_show: @@ -784,16 +781,18 @@ def plot_bit(self, return plt.gcf() - def plot_line(self, - line, - start_time=0.0, - end_time=None, - auto_show=True, - ): + def plot_line( + self, + line, + start_time=0.0, + end_time=None, + auto_show=True, + ): """ Plots a specific line at a specific time period. """ import matplotlib.pyplot as plt + bit = self._line_to_bit(line) self.plot_bit(bit, start_time, end_time, auto_show=False) @@ -801,23 +800,27 @@ def plot_line(self, if auto_show: plt.show() - def plot_lines(self, - lines, - start_time=0.0, - end_time=None, - auto_show=True, - ): + def plot_lines( + self, + lines, + start_time=0.0, + end_time=None, + auto_show=True, + ): """ Plots specific lines at a specific time period. """ import matplotlib.pyplot as plt + bits = [] for line in lines: bits.append(self._line_to_bit(line)) - f, axes = self.plot_bits(bits, - start_time, - end_time, - auto_show=False, ) + f, axes = self.plot_bits( + bits, + start_time, + end_time, + auto_show=False, + ) plt.subplots_adjust(left=0.025, right=0.975, bottom=0.05, top=0.95) if auto_show: @@ -850,5 +853,5 @@ def __exit__(self, type, value, traceback): self.close() -if __name__ == '__main__': +if __name__ == "__main__": pass diff --git a/allensdk/brain_observatory/sync_stim_aligner.py b/allensdk/brain_observatory/sync_stim_aligner.py index 669fffbd28..dfb4d92cba 100644 --- a/allensdk/brain_observatory/sync_stim_aligner.py +++ b/allensdk/brain_observatory/sync_stim_aligner.py @@ -7,13 +7,10 @@ import pathlib from allensdk.brain_observatory import sync_dataset from allensdk.internal.core.lims_utilities import safe_system_path -from allensdk.brain_observatory.behavior.data_files.stimulus_file import ( - _StimulusFile) +from allensdk.brain_observatory.behavior.data_files.stimulus_file import _StimulusFile -def _choose_line( - data: sync_dataset.Dataset, - sync_lines: Union[str, Tuple[str]]) -> str: +def _choose_line(data: sync_dataset.Dataset, sync_lines: Union[str, Tuple[str]]) -> str: """ Scan through sync_lines in order. Select the first one that is present in the sync file. Raise an exception if @@ -32,7 +29,7 @@ def _choose_line( the sync file. """ if isinstance(sync_lines, str): - sync_lines = (sync_lines, ) + sync_lines = (sync_lines,) chosen_line = None for this_line in sync_lines: @@ -41,17 +38,13 @@ def _choose_line( break if chosen_line is None: - msg = ("Could not find one of " - f"{sync_lines} in sync dataset. " - f"available lines:\n{data.line_labels}") + msg = f"Could not find one of {sync_lines} in sync dataset. available lines:\n{data.line_labels}" raise RuntimeError(msg) return chosen_line -def _get_rising_times( - data: sync_dataset.Dataset, - sync_lines: Union[str, Tuple[str]]): +def _get_rising_times(data: sync_dataset.Dataset, sync_lines: Union[str, Tuple[str]]): """ Get the timestamps, in seconds, associated with the rising edges in a specific line in a sync file @@ -73,20 +66,14 @@ def _get_rising_times( The times, in seconds, associated with the rising edges of the chosen line. """ - chosen_line = _choose_line( - data=data, - sync_lines=sync_lines) + chosen_line = _choose_line(data=data, sync_lines=sync_lines) - timestamps = data.get_rising_edges( - line=chosen_line, - units='seconds') + timestamps = data.get_rising_edges(line=chosen_line, units="seconds") return timestamps -def _get_falling_times( - data: sync_dataset.Dataset, - sync_lines: Union[str, Tuple[str]]): +def _get_falling_times(data: sync_dataset.Dataset, sync_lines: Union[str, Tuple[str]]): """ Get the timestamps, in seconds, associated with the falling edges in a specific line in a sync file. @@ -112,25 +99,19 @@ def _get_falling_times( of the chosen line. """ - chosen_line = _choose_line( - data=data, - sync_lines=sync_lines) + chosen_line = _choose_line(data=data, sync_lines=sync_lines) - rising_edges = data.get_rising_edges( - line=chosen_line, - units='seconds') + rising_edges = data.get_rising_edges(line=chosen_line, units="seconds") - falling_edges = data.get_falling_edges( - line=chosen_line, - units='seconds') + falling_edges = data.get_falling_edges(line=chosen_line, units="seconds") - valid = (falling_edges > rising_edges[0]) + valid = falling_edges > rising_edges[0] return falling_edges[valid] def _get_line_starts_and_ends( - data: sync_dataset.Dataset, - sync_lines: Union[str, Tuple[str]]) -> Tuple[np.ndarray, np.ndarray]: + data: sync_dataset.Dataset, sync_lines: Union[str, Tuple[str]] +) -> Tuple[np.ndarray, np.ndarray]: """ Parameters ---------- @@ -149,22 +130,16 @@ def _get_line_starts_and_ends( np.ndarrays of times (in seconds) that the given line turns on (rises) and turns off (falls). """ - start_times = _get_rising_times( - data=data, - sync_lines=sync_lines) + start_times = _get_rising_times(data=data, sync_lines=sync_lines) - end_times = _get_falling_times( - data=data, - sync_lines=sync_lines) + end_times = _get_falling_times(data=data, sync_lines=sync_lines) return (start_times, end_times) def _get_start_frames( - data: sync_dataset.Dataset, - raw_frame_times: np.ndarray, - stimulus_frame_counts: List[int], - tolerance: float) -> List[int]: + data: sync_dataset.Dataset, raw_frame_times: np.ndarray, stimulus_frame_counts: List[int], tolerance: float +) -> List[int]: """ Find the start frames for a series of stimuli that need to be registered to a single sync file. @@ -226,9 +201,7 @@ def _get_start_frames( frame_count_arr = np.array(stimulus_frame_counts) - stim_starts, stim_ends = _get_line_starts_and_ends( - data=data, - sync_lines=('stim_running', 'sweep')) + stim_starts, stim_ends = _get_line_starts_and_ends(data=data, sync_lines=("stim_running", "sweep")) # break raw_frame_times into epochs based on stim_starts and stim_ends epoch_frame_counts = [] @@ -236,8 +209,7 @@ def _get_start_frames( for start, end in zip(stim_starts, stim_ends): # Inner expression returns a bool array where conditions are True # np.where evaluates bool array to return indices where bool array True - epoch_frames = np.where((raw_frame_times >= start) - & (raw_frame_times < end))[0] + epoch_frames = np.where((raw_frame_times >= start) & (raw_frame_times < end))[0] epoch_frame_counts.append(len(epoch_frames)) epoch_start_frames.append(epoch_frames[0]) @@ -266,23 +238,17 @@ def _get_start_frames( start_frames = [] for stim_idx, fc in enumerate(frame_count_arr): - logging.info(f"Finding stim start for stim with index: {stim_idx}") # Get index of stimulus whose frame counts most closely match # the expected number of frames - best_match = int( - np.argmin([np.abs(efc - fc) for efc in epoch_frame_counts]) - ) + best_match = int(np.argmin([np.abs(efc - fc) for efc in epoch_frame_counts])) lower_tol = fc * (1.0 - tolerance) upper_tol = fc * (1.0 + tolerance) if lower_tol <= epoch_frame_counts[best_match] <= upper_tol: _ = epoch_frame_counts.pop(best_match) start_frame = epoch_start_frames.pop(best_match) start_frames.append(start_frame) - logging.info( - f"Found stim start for stim with index ({stim_idx})" - f"at vsync ({start_frame})" - ) + logging.info(f"Found stim start for stim with index ({stim_idx})at vsync ({start_frame})") else: raise RuntimeError( "Could not find matching sync frames " @@ -302,11 +268,12 @@ def _get_start_frames( def get_stim_timestamps_from_stimulus_blocks( - stimulus_files: Union[_StimulusFile, List[_StimulusFile]], - sync_file: Union[str, pathlib.Path], - raw_frame_time_lines: Union[str, List[str]], - raw_frame_time_direction: str, - frame_count_tolerance: float) -> Dict[str, Any]: + stimulus_files: Union[_StimulusFile, List[_StimulusFile]], + sync_file: Union[str, pathlib.Path], + raw_frame_time_lines: Union[str, List[str]], + raw_frame_time_direction: str, + frame_count_tolerance: float, +) -> Dict[str, Any]: """ Find the timestamps associated a set of stimulus blocks that have to be aligned with a single sync file @@ -350,18 +317,22 @@ def get_stim_timestamps_from_stimulus_blocks( stimulus_block.num_frames to each stimulus block. """ - if raw_frame_time_direction == 'rising': + if raw_frame_time_direction == "rising": frame_time_fn = _get_rising_times - elif raw_frame_time_direction == 'falling': + elif raw_frame_time_direction == "falling": frame_time_fn = _get_falling_times else: - msg = ("Cannot parse raw_frame_time_direction = " - f"'{raw_frame_time_direction}'\n" - "must be either 'rising' or 'falling'") + msg = ( + "Cannot parse raw_frame_time_direction = " + f"'{raw_frame_time_direction}'\n" + "must be either 'rising' or 'falling'" + ) raise ValueError(msg) if not isinstance(stimulus_files, list): - stimulus_files = [stimulus_files, ] + stimulus_files = [ + stimulus_files, + ] if isinstance(sync_file, pathlib.Path): str_path = str(sync_file.resolve().absolute()) @@ -371,20 +342,18 @@ def get_stim_timestamps_from_stimulus_blocks( list_of_timestamps = [] with sync_dataset.Dataset(safe_sync_path) as sync_data: - raw_frame_times = frame_time_fn( - data=sync_data, - sync_lines=raw_frame_time_lines) + raw_frame_times = frame_time_fn(data=sync_data, sync_lines=raw_frame_time_lines) frame_count_list = [s.num_frames for s in stimulus_files] start_frames = _get_start_frames( - data=sync_data, - raw_frame_times=raw_frame_times, - stimulus_frame_counts=frame_count_list, - tolerance=frame_count_tolerance) + data=sync_data, + raw_frame_times=raw_frame_times, + stimulus_frame_counts=frame_count_list, + tolerance=frame_count_tolerance, + ) for f0, nf in zip(start_frames, frame_count_list): - this_array = raw_frame_times[f0:f0+nf] + this_array = raw_frame_times[f0 : f0 + nf] list_of_timestamps.append(this_array) - return {"timestamps": list_of_timestamps, - "start_frames": start_frames} + return {"timestamps": list_of_timestamps, "start_frames": start_frames} diff --git a/allensdk/brain_observatory/sync_utilities/__init__.py b/allensdk/brain_observatory/sync_utilities/__init__.py index fb8350f8cb..4d2b65c2b2 100644 --- a/allensdk/brain_observatory/sync_utilities/__init__.py +++ b/allensdk/brain_observatory/sync_utilities/__init__.py @@ -35,14 +35,15 @@ def trim_discontiguous_times(times: np.ndarray, threshold=100) -> np.ndarray: if len(gap_indices) == 0: return times - return times[:gap_indices[0] + 1] + return times[: gap_indices[0] + 1] -def get_synchronized_frame_times(session_sync_file: Path, - sync_line_label_keys: Tuple[str, ...], - drop_frames: Optional[List[int]] = None, - trim_after_spike: bool = True, - ) -> pd.Series: +def get_synchronized_frame_times( + session_sync_file: Path, + sync_line_label_keys: Tuple[str, ...], + drop_frames: Optional[List[int]] = None, + trim_after_spike: bool = True, +) -> pd.Series: """Get experimental frame times from an experiment session sync file. 1. Get rising edges from the sync dataset @@ -76,9 +77,7 @@ def get_synchronized_frame_times(session_sync_file: Path, """ sync_dataset = Dataset(str(session_sync_file)) - times = sync_dataset.get_edges( - "rising", sync_line_label_keys, units="seconds" - ) + times = sync_dataset.get_edges("rising", sync_line_label_keys, units="seconds") times = trim_discontiguous_times(times) if trim_after_spike else times if drop_frames is not None: diff --git a/allensdk/brain_observatory/vbn_2022/input_json_writer/__main__.py b/allensdk/brain_observatory/vbn_2022/input_json_writer/__main__.py index 490acb402c..6bd290844b 100644 --- a/allensdk/brain_observatory/vbn_2022/input_json_writer/__main__.py +++ b/allensdk/brain_observatory/vbn_2022/input_json_writer/__main__.py @@ -1,6 +1,4 @@ -from allensdk.brain_observatory.vbn_2022.\ - input_json_writer.input_json_writer import ( - VBN2022InputJsonWriter) +from allensdk.brain_observatory.vbn_2022.input_json_writer.input_json_writer import VBN2022InputJsonWriter def main(): diff --git a/allensdk/brain_observatory/vbn_2022/input_json_writer/input_json_writer.py b/allensdk/brain_observatory/vbn_2022/input_json_writer/input_json_writer.py index eb77a86e63..d12988cf7d 100644 --- a/allensdk/brain_observatory/vbn_2022/input_json_writer/input_json_writer.py +++ b/allensdk/brain_observatory/vbn_2022/input_json_writer/input_json_writer.py @@ -4,40 +4,35 @@ import argschema import json -from allensdk.brain_observatory.vbn_2022.input_json_writer.schemas import ( - VBN2022InputJsonWriterSchema) +from allensdk.brain_observatory.vbn_2022.input_json_writer.schemas import VBN2022InputJsonWriterSchema -from allensdk.brain_observatory.vbn_2022.input_json_writer.utils import ( - vbn_nwb_config_from_ecephys_session_id_list) +from allensdk.brain_observatory.vbn_2022.input_json_writer.utils import vbn_nwb_config_from_ecephys_session_id_list class VBN2022InputJsonWriter(argschema.ArgSchemaParser): default_schema = VBN2022InputJsonWriterSchema def run(self): - results = vbn_nwb_config_from_ecephys_session_id_list( - ecephys_session_id_list=self.args['ecephys_session_id_list'], - probes_to_skip=self.args['probes_to_skip'] + ecephys_session_id_list=self.args["ecephys_session_id_list"], probes_to_skip=self.args["probes_to_skip"] ) - session_specs = results['sessions'] - msg = results['log'] - self.logger.info("\n\nIrregularities:\n" - f"==============\n{msg}") + session_specs = results["sessions"] + msg = results["log"] + self.logger.info(f"\n\nIrregularities:\n==============\n{msg}") for session in session_specs: - session_id = session['ecephys_session_id'] - json_path = self.args['json_path_lookup'][session_id] + session_id = session["ecephys_session_id"] + json_path = self.args["json_path_lookup"][session_id] config = dict() - config['log_level'] = "INFO" + config["log_level"] = "INFO" - str_nwb_path = self.args['nwb_path_lookup'][session_id] + str_nwb_path = self.args["nwb_path_lookup"][session_id] str_nwb_path = str(str_nwb_path.resolve().absolute()) os.makedirs(Path(str_nwb_path).parent, exist_ok=True) - config['output_path'] = str_nwb_path - config['session_data'] = session - with open(json_path, 'w') as out_file: + config["output_path"] = str_nwb_path + config["session_data"] = session + with open(json_path, "w") as out_file: out_file.write(json.dumps(config, indent=2)) self.logger.info(f"wrote {json_path.resolve().absolute()}") diff --git a/allensdk/brain_observatory/vbn_2022/input_json_writer/schemas.py b/allensdk/brain_observatory/vbn_2022/input_json_writer/schemas.py index 0a0571b0a1..deb72bd3e1 100644 --- a/allensdk/brain_observatory/vbn_2022/input_json_writer/schemas.py +++ b/allensdk/brain_observatory/vbn_2022/input_json_writer/schemas.py @@ -2,53 +2,49 @@ import pathlib from marshmallow import post_load -from allensdk.brain_observatory.vbn_2022.utils.schemas import ( - ProbeToSkip) +from allensdk.brain_observatory.vbn_2022.utils.schemas import ProbeToSkip class VBN2022InputJsonWriterSchema(argschema.ArgSchema): - ecephys_session_id_list = argschema.fields.List( - argschema.fields.Int, - required=True, - description=("List of ecephys_sessions.id values " - "of sessions to be released")) + argschema.fields.Int, + required=True, + description=("List of ecephys_sessions.id values of sessions to be released"), + ) probes_to_skip = argschema.fields.List( - argschema.fields.Nested(ProbeToSkip), - required=False, - default=None, - allow_none=True, - description=("List of probes to skip")) + argschema.fields.Nested(ProbeToSkip), + required=False, + default=None, + allow_none=True, + description=("List of probes to skip"), + ) json_output_dir = argschema.fields.OutputDir( - required=True, - description=("Directory where input JSONs will be written")) + required=True, description=("Directory where input JSONs will be written") + ) nwb_output_dir = argschema.fields.OutputDir( - required=True, - description=("Directory where NWB files will be written")) + required=True, description=("Directory where NWB files will be written") + ) clobber = argschema.fields.Boolean( - default=False, - description=("If false, throw an error if output files " - "already exist")) + default=False, description=("If false, throw an error if output files already exist") + ) json_prefix = argschema.fields.Str( - required=False, - default='vbn_ecephys_session', - allow_none=False, - description=('The files written by this module will be ' - 'named like ' - '{json_prefix}_{session_id}_input.json')) + required=False, + default="vbn_ecephys_session", + allow_none=False, + description=("The files written by this module will be named like {json_prefix}_{session_id}_input.json"), + ) nwb_prefix = argschema.fields.Str( - required=False, - default='ecepys', - allow_none=False, - description=('The NWB files specified in the input JSONs ' - 'will be named like ' - '{nwb_prefix}_{session_id}.nwb')) + required=False, + default="ecepys", + allow_none=False, + description=("The NWB files specified in the input JSONs will be named like {nwb_prefix}_{session_id}.nwb"), + ) @post_load def create_path_lookup(self, data, **kwargs): @@ -62,32 +58,30 @@ def create_path_lookup(self, data, **kwargs): unq_json = set() unq_nwb = set() - json_dir_path = pathlib.Path(data['json_output_dir']) - nwb_dir_path = pathlib.Path(data['nwb_output_dir']) + json_dir_path = pathlib.Path(data["json_output_dir"]) + nwb_dir_path = pathlib.Path(data["nwb_output_dir"]) - for ecephys_id in data['ecephys_session_id_list']: + for ecephys_id in data["ecephys_session_id_list"]: json_name = f"{data['json_prefix']}_{ecephys_id}_input.json" if json_name in unq_json: - raise RuntimeError("This configuration would write " - f"{json_name} more than once") + raise RuntimeError(f"This configuration would write {json_name} more than once") unq_json.add(json_name) json_path = json_dir_path / json_name - if not data['clobber'] and json_path.is_file(): - raise RuntimeError(f"{json_path.resolve().absolute()} " - "already exists; " - "run with clobber=True to overwrite") + if not data["clobber"] and json_path.is_file(): + raise RuntimeError( + f"{json_path.resolve().absolute()} already exists; run with clobber=True to overwrite" + ) json_lookup[ecephys_id] = json_path nwb_name = f"{data['nwb_prefix']}_{ecephys_id}.nwb" if nwb_name in unq_nwb: - raise RuntimeError("This configuration would write " - f"{nwb_name} more than once") + raise RuntimeError(f"This configuration would write {nwb_name} more than once") unq_nwb.add(nwb_name) - nwb_path = nwb_dir_path / f'{ecephys_id}' / nwb_name + nwb_path = nwb_dir_path / f"{ecephys_id}" / nwb_name nwb_lookup[ecephys_id] = nwb_path - data['json_path_lookup'] = json_lookup - data['nwb_path_lookup'] = nwb_lookup + data["json_path_lookup"] = json_lookup + data["nwb_path_lookup"] = nwb_lookup return data diff --git a/allensdk/brain_observatory/vbn_2022/input_json_writer/utils.py b/allensdk/brain_observatory/vbn_2022/input_json_writer/utils.py index 49cbbfa57b..a460acfe70 100644 --- a/allensdk/brain_observatory/vbn_2022/input_json_writer/utils.py +++ b/allensdk/brain_observatory/vbn_2022/input_json_writer/utils.py @@ -4,37 +4,32 @@ import numpy as np import numbers -from allensdk.internal.api.queries.wkf_lims_queries import ( - wkf_path_from_attachable) +from allensdk.internal.api.queries.wkf_lims_queries import wkf_path_from_attachable -from allensdk.internal.api.queries.equipment_lims_queries import ( - experiment_configs_from_equipment_id_and_type) +from allensdk.internal.api.queries.equipment_lims_queries import experiment_configs_from_equipment_id_and_type from allensdk.internal.api import PostgresQueryMixin from allensdk import OneResultExpectedError -from allensdk.internal.api.queries.utils import ( - build_in_list_selector_query) +from allensdk.internal.api.queries.utils import build_in_list_selector_query -from allensdk.brain_observatory.behavior.data_objects.\ - metadata.subject_metadata.reporter_line import ReporterLine +from allensdk.brain_observatory.behavior.data_objects.metadata.subject_metadata.reporter_line import ReporterLine -from allensdk.brain_observatory.behavior.data_objects.\ - metadata.subject_metadata.driver_line import DriverLine +from allensdk.brain_observatory.behavior.data_objects.metadata.subject_metadata.driver_line import DriverLine from allensdk.brain_observatory.vbn_2022.metadata_writer.lims_queries import ( _ecephys_summary_table_from_ecephys_session_id_list, probes_table_from_ecephys_session_id_list, channels_table_from_ecephys_session_id_list, units_table_from_ecephys_session_id_list, - get_list_of_bad_probe_ids) + get_list_of_bad_probe_ids, +) -from allensdk.brain_observatory.vbn_2022.metadata_writer.\ - dataframe_manipulations import ( - _add_age_in_days, - _patch_date_and_stage_from_pickle_file) +from allensdk.brain_observatory.vbn_2022.metadata_writer.dataframe_manipulations import ( + _add_age_in_days, + _patch_date_and_stage_from_pickle_file, +) -from allensdk.core.auth_config import ( - LIMS_DB_CREDENTIAL_MAP) +from allensdk.core.auth_config import LIMS_DB_CREDENTIAL_MAP from allensdk.internal.api import db_connection_creator @@ -49,9 +44,7 @@ class NwbConfigErrorLog(object): def __init__(self): self._messages = dict() - def log(self, - ecephys_session_id: Union[int, str], - msg: str) -> None: + def log(self, ecephys_session_id: Union[int, str], msg: str) -> None: """ Log an irregularity associated with a specifiec ecephys session @@ -93,8 +86,7 @@ def write(self) -> str: def vbn_nwb_config_from_ecephys_session_id_list( - ecephys_session_id_list: List[int], - probes_to_skip: Optional[List[dict]] + ecephys_session_id_list: List[int], probes_to_skip: Optional[List[dict]] ) -> dict: """ Return a list of dicts. Each dict the specification for @@ -129,14 +121,11 @@ def vbn_nwb_config_from_ecephys_session_id_list( error_log = NwbConfigErrorLog() - lims_connection = db_connection_creator( - fallback_credentials=LIMS_DB_CREDENTIAL_MAP) + lims_connection = db_connection_creator(fallback_credentials=LIMS_DB_CREDENTIAL_MAP) # convert probes_to_skip into a list of ecephys_probe_ids if probes_to_skip is not None: - probe_ids_to_skip = get_list_of_bad_probe_ids( - lims_connection=lims_connection, - probes_to_skip=probes_to_skip) + probe_ids_to_skip = get_list_of_bad_probe_ids(lims_connection=lims_connection, probes_to_skip=probes_to_skip) else: probe_ids_to_skip = None @@ -144,76 +133,72 @@ def vbn_nwb_config_from_ecephys_session_id_list( # data for each session excluding the lists of probes, # channels, and units) session_list = session_input_from_ecephys_session_id_list( - ecephys_session_id_list=ecephys_session_id_list, - lims_connection=lims_connection, - error_log=error_log) + ecephys_session_id_list=ecephys_session_id_list, lims_connection=lims_connection, error_log=error_log + ) # iterate over each session, adding the probes, channels, # and units as appropriate for session in session_list: - session_id = session['ecephys_session_id'] + session_id = session["ecephys_session_id"] probe_list = probe_input_from_ecephys_session_id( ecephys_session_id=session_id, probe_ids_to_skip=probe_ids_to_skip, lims_connection=lims_connection, - error_log=error_log + error_log=error_log, ) - session['probes'] = probe_list + session["probes"] = probe_list channel_input = channel_input_from_ecephys_session_id( - ecephys_session_id=session_id, - probe_ids_to_skip=probe_ids_to_skip, - lims_connection=lims_connection, - error_log=error_log) + ecephys_session_id=session_id, + probe_ids_to_skip=probe_ids_to_skip, + lims_connection=lims_connection, + error_log=error_log, + ) # bad_probe_list keeps track of any probes that did not have # channels attached to it; these probes will be excluded # from the final configuration, and a message will be logged bad_probe_list = [] - for idx, probe in enumerate(session['probes']): - probe_id = probe['id'] + for idx, probe in enumerate(session["probes"]): + probe_id = probe["id"] if probe_id in channel_input: channels = channel_input[probe_id] - probe['channels'] = channels + probe["channels"] = channels else: bad_probe_list.append(idx) - msg = (f"could not find channels for probe {probe_id}; " - "not listing in the input.json") - error_log.log(ecephys_session_id=session_id, - msg=msg) + msg = f"could not find channels for probe {probe_id}; not listing in the input.json" + error_log.log(ecephys_session_id=session_id, msg=msg) bad_probe_list.reverse() for idx in bad_probe_list: - this_probe = session['probes'].pop(idx) - assert 'channels' not in this_probe + this_probe = session["probes"].pop(idx) + assert "channels" not in this_probe unit_input = unit_input_from_ecephys_session_id( - ecephys_session_id=session_id, - probe_ids_to_skip=probe_ids_to_skip, - lims_connection=lims_connection, - error_log=error_log) + ecephys_session_id=session_id, + probe_ids_to_skip=probe_ids_to_skip, + lims_connection=lims_connection, + error_log=error_log, + ) - for probe in session['probes']: - probe_id = probe['id'] + for probe in session["probes"]: + probe_id = probe["id"] if probe_id in unit_input: units = unit_input[probe_id] - probe['units'] = units + probe["units"] = units else: msg = f"could not find units for probe {probe_id}" - error_log.log(ecephys_session_id=session_id, - msg=msg) + error_log.log(ecephys_session_id=session_id, msg=msg) - return {'sessions': session_list, - 'log': error_log.write()} + return {"sessions": session_list, "log": error_log.write()} def session_input_from_ecephys_session_id_list( - ecephys_session_id_list: List[int], - lims_connection: PostgresQueryMixin, - error_log: NwbConfigErrorLog) -> List[dict]: + ecephys_session_id_list: List[int], lims_connection: PostgresQueryMixin, error_log: NwbConfigErrorLog +) -> List[dict]: """ Return a list of dicts, each dict representing the configuration data necessary for writing an NWB file for a session, excluding @@ -239,68 +224,64 @@ def session_input_from_ecephys_session_id_list( """ session_table = _ecephys_summary_table_from_ecephys_session_id_list( - lims_connection=lims_connection, - ecephys_session_id_list=ecephys_session_id_list, - failed_ecephys_session_id_list=None) + lims_connection=lims_connection, + ecephys_session_id_list=ecephys_session_id_list, + failed_ecephys_session_id_list=None, + ) # get date_of_acquisition from the pickle file by nulling out the # dates of acqusition from any sessions with behavior_session_ids, # then filling the values back in from the pickle file. - session_table.loc[ - np.logical_not(session_table.behavior_session_id.isna()), - 'date_of_acquisition'] = None + session_table.loc[np.logical_not(session_table.behavior_session_id.isna()), "date_of_acquisition"] = None session_table = _patch_date_and_stage_from_pickle_file( - lims_connection=lims_connection, - behavior_df=session_table, - flag_columns=['date_of_acquisition'], - columns_to_patch=['date_of_acquisition']) + lims_connection=lims_connection, + behavior_df=session_table, + flag_columns=["date_of_acquisition"], + columns_to_patch=["date_of_acquisition"], + ) # clip fractions of a second off the date of acquisition # (DateTime data object will fail deserialization if you do not) - session_table.date_of_acquisition = \ - session_table.date_of_acquisition.dt.floor('S') + session_table.date_of_acquisition = session_table.date_of_acquisition.dt.floor("S") - session_table = _add_age_in_days( - df=session_table, - index_column='ecephys_session_id') + session_table = _add_age_in_days(df=session_table, index_column="ecephys_session_id") - session_table.age_in_days = session_table.age_in_days.apply( - lambda x: f'P{int(x):d}') + session_table.age_in_days = session_table.age_in_days.apply(lambda x: f"P{int(x):d}") # apply naming conventions from the NWB writer's schema session_table.rename( - columns={'equipment_name': 'rig_name', - 'mouse_id': 'external_specimen_name', - 'genotype': 'full_genotype', - 'age_in_days': 'age'}, - inplace=True) + columns={ + "equipment_name": "rig_name", + "mouse_id": "external_specimen_name", + "genotype": "full_genotype", + "age_in_days": "age", + }, + inplace=True, + ) - session_table.external_specimen_name = \ - session_table.external_specimen_name.astype(int) + session_table.external_specimen_name = session_table.external_specimen_name.astype(int) - session_table.drop( - labels=['session_type', 'project_code'], - axis='columns', - inplace=True) + session_table.drop(labels=["session_type", "project_code"], axis="columns", inplace=True) - session_table = session_table.set_index( - 'ecephys_session_id') + session_table = session_table.set_index("ecephys_session_id") - session_table = session_table.to_dict(orient='index') + session_table = session_table.to_dict(orient="index") # get lookup tables mapping ecephys_session_id to the # ecephys_analysis_run_ids for the optotagging and # stimulus table files optotagging_run_lookup = _analysis_run_from_session_id( - lims_connection=lims_connection, - ecephys_session_id_list=ecephys_session_id_list, - strategy_class='EcephysOptotaggingTableStrategy') + lims_connection=lims_connection, + ecephys_session_id_list=ecephys_session_id_list, + strategy_class="EcephysOptotaggingTableStrategy", + ) stim_table_run_lookup = _analysis_run_from_session_id( - lims_connection=lims_connection, - ecephys_session_id_list=ecephys_session_id_list, - strategy_class='VbnCreateStimTableStrategy') + lims_connection=lims_connection, + ecephys_session_id_list=ecephys_session_id_list, + strategy_class="VbnCreateStimTableStrategy", + ) # A list of tuples associating fields in the final specification # returned by this method with the names of files in the @@ -308,106 +289,110 @@ def session_input_from_ecephys_session_id_list( # tuple is the field in the returned specification; the first # element is the associated well_known_file_types.name input_from_wkf_session = [ - ('behavior_stimulus_file', 'StimulusPickle'), - ('mapping_stimulus_file', 'MappingPickle'), - ('replay_stimulus_file', 'EcephysReplayStimulus'), - ('raw_eye_tracking_video_meta_data', - 'RawEyeTrackingVideoMetadata'), - ('eye_dlc_file', 'EyeDlcOutputFile'), - ('face_dlc_file', 'FaceDlcOutputFile'), - ('side_dlc_file', 'SideDlcOutputFile'), - ('eye_tracking_filepath', 'EyeTracking Ellipses'), - ('sync_file', 'EcephysRigSync')] + ("behavior_stimulus_file", "StimulusPickle"), + ("mapping_stimulus_file", "MappingPickle"), + ("replay_stimulus_file", "EcephysReplayStimulus"), + ("raw_eye_tracking_video_meta_data", "RawEyeTrackingVideoMetadata"), + ("eye_dlc_file", "EyeDlcOutputFile"), + ("face_dlc_file", "FaceDlcOutputFile"), + ("side_dlc_file", "SideDlcOutputFile"), + ("eye_tracking_filepath", "EyeTracking Ellipses"), + ("sync_file", "EcephysRigSync"), + ] # cast to a SQL-safe string - wkf_types_to_query_session = [f"'{el[1]}'" - for el in input_from_wkf_session] + wkf_types_to_query_session = [f"'{el[1]}'" for el in input_from_wkf_session] result = [] for session_id in ecephys_session_id_list: session_id = int(session_id) if session_id not in session_table: - error_log.log(ecephys_session_id=session_id, - msg="No session data was found at all; skipping") + error_log.log(ecephys_session_id=session_id, msg="No session data was found at all; skipping") continue data = session_table[session_id] - data['ecephys_session_id'] = session_id + data["ecephys_session_id"] = session_id # lookup all of the well known files we need for this # specification wkf_path_lookup = wkf_path_from_attachable( - lims_connection=lims_connection, - wkf_type_name=wkf_types_to_query_session, - attachable_type='EcephysSession', - attachable_id=session_id) + lims_connection=lims_connection, + wkf_type_name=wkf_types_to_query_session, + attachable_type="EcephysSession", + attachable_id=session_id, + ) for key_pair in input_from_wkf_session: this_path = wkf_path_lookup.get(key_pair[1], None) if this_path is None: - msg = (f"Could not find {key_pair[1]} " - f"for ecephys_session {session_id}") - error_log.log(ecephys_session_id=session_id, - msg=msg) + msg = f"Could not find {key_pair[1]} for ecephys_session {session_id}" + error_log.log(ecephys_session_id=session_id, msg=msg) data[key_pair[0]] = this_path # get stimulus_table stim_path_lookup = wkf_path_from_attachable( - lims_connection=lims_connection, - wkf_type_name=["'EcephysStimulusTable'", ], - attachable_type="EcephysAnalysisRun", - attachable_id=stim_table_run_lookup[session_id]) + lims_connection=lims_connection, + wkf_type_name=[ + "'EcephysStimulusTable'", + ], + attachable_type="EcephysAnalysisRun", + attachable_id=stim_table_run_lookup[session_id], + ) - data['stim_table_file'] = stim_path_lookup['EcephysStimulusTable'] + data["stim_table_file"] = stim_path_lookup["EcephysStimulusTable"] # get optotagging_table optotagging_path_lookup = wkf_path_from_attachable( - lims_connection=lims_connection, - wkf_type_name=["'EcephysOptotaggingTable'", ], - attachable_type="EcephysAnalysisRun", - attachable_id=optotagging_run_lookup[session_id]) + lims_connection=lims_connection, + wkf_type_name=[ + "'EcephysOptotaggingTable'", + ], + attachable_type="EcephysAnalysisRun", + attachable_id=optotagging_run_lookup[session_id], + ) - data['optotagging_table_path'] = optotagging_path_lookup[ - "EcephysOptotaggingTable"] + data["optotagging_table_path"] = optotagging_path_lookup["EcephysOptotaggingTable"] driver_line = DriverLine.from_lims( - lims_db=lims_connection, - behavior_session_id=data['behavior_session_id'], - allow_none=True).value + lims_db=lims_connection, behavior_session_id=data["behavior_session_id"], allow_none=True + ).value if driver_line is not None: if isinstance(driver_line, list): - data['driver_line'] = driver_line + data["driver_line"] = driver_line else: - data['driver_line'] = [driver_line, ] + data["driver_line"] = [ + driver_line, + ] else: - data['driver_line'] = [] + data["driver_line"] = [] reporter_line = ReporterLine.from_lims( - lims_db=lims_connection, - behavior_session_id=data['behavior_session_id'], - allow_none=True).value + lims_db=lims_connection, behavior_session_id=data["behavior_session_id"], allow_none=True + ).value if reporter_line is not None: if isinstance(reporter_line, list): - data['reporter_line'] = reporter_line + data["reporter_line"] = reporter_line else: - data['reporter_line'] = [reporter_line, ] + data["reporter_line"] = [ + reporter_line, + ] else: - data['reporter_line'] = [] + data["reporter_line"] = [] eye_geometry = eye_tracking_geometry_from_equipment_id( - equipment_id=data.pop('equipment_id'), - date_of_acquisition=data['date_of_acquisition'], - lims_connection=lims_connection) + equipment_id=data.pop("equipment_id"), + date_of_acquisition=data["date_of_acquisition"], + lims_connection=lims_connection, + ) - eye_geometry['equipment'] = data['rig_name'] + eye_geometry["equipment"] = data["rig_name"] - data['eye_tracking_rig_geometry'] = eye_geometry + data["eye_tracking_rig_geometry"] = eye_geometry - for k in ('date_of_acquisition', - 'date_of_birth'): + for k in ("date_of_acquisition", "date_of_birth"): data[k] = str(data[k]) result.append(data) @@ -415,31 +400,22 @@ def session_input_from_ecephys_session_id_list( return result -def _get_probe_analysis_run_from_probe_id( - lims_connection: PostgresQueryMixin, - probe_id: int, - lims_strategy: str -): - query = f''' +def _get_probe_analysis_run_from_probe_id(lims_connection: PostgresQueryMixin, probe_id: int, lims_strategy: str): + query = f""" SELECT earp.id FROM ecephys_analysis_run_probes earp JOIN ecephys_analysis_runs ear on ear.id = earp.ecephys_analysis_run_id WHERE earp.ecephys_probe_id = {probe_id} and job_strategy_class = '{lims_strategy}' and ear.current - ''' + """ res = lims_connection.select_one(query) if not res: - raise OneResultExpectedError( - f'Expected to find one analysis probe run for probe ' - f'{probe_id}') - return res['id'] + raise OneResultExpectedError(f"Expected to find one analysis probe run for probe {probe_id}") + return res["id"] -def _get_probe_lfp_meta( - lims_connection: PostgresQueryMixin, - probe_id: int -): +def _get_probe_lfp_meta(lims_connection: PostgresQueryMixin, probe_id: int): """Gets filepaths for files needed to build LFP data Parameters @@ -449,56 +425,44 @@ def _get_probe_lfp_meta( """ lfp_subsampling_run_well_known_files = [ - 'EcephysSubsampledLfpContinuous', - 'EcephysSubsampledLfpTimestamps', - 'EcephysSubsampledChannelStates' - ] - current_source_density_well_known_files = [ - 'EcephysCurrentSourceDensity' + "EcephysSubsampledLfpContinuous", + "EcephysSubsampledLfpTimestamps", + "EcephysSubsampledChannelStates", ] + current_source_density_well_known_files = ["EcephysCurrentSourceDensity"] probe_lfp_subsampling_run_id = _get_probe_analysis_run_from_probe_id( - lims_connection=lims_connection, - probe_id=probe_id, - lims_strategy='EcephysLfpSubsamplingStrategy' + lims_connection=lims_connection, probe_id=probe_id, lims_strategy="EcephysLfpSubsamplingStrategy" + ) + probe_current_source_density_run_id = _get_probe_analysis_run_from_probe_id( + lims_connection=lims_connection, probe_id=probe_id, lims_strategy="EcephysCurrentSourceDensityStrategy" ) - probe_current_source_density_run_id = \ - _get_probe_analysis_run_from_probe_id( - lims_connection=lims_connection, - probe_id=probe_id, - lims_strategy='EcephysCurrentSourceDensityStrategy' - ) probe_lfp_well_known_files = wkf_path_from_attachable( lims_connection=lims_connection, wkf_type_name=lfp_subsampling_run_well_known_files, - attachable_type='EcephysAnalysisRunProbe', - attachable_id=probe_lfp_subsampling_run_id) + attachable_type="EcephysAnalysisRunProbe", + attachable_id=probe_lfp_subsampling_run_id, + ) probe_csd_well_known_files = wkf_path_from_attachable( lims_connection=lims_connection, wkf_type_name=current_source_density_well_known_files, - attachable_type='EcephysAnalysisRunProbe', - attachable_id=probe_current_source_density_run_id) + attachable_type="EcephysAnalysisRunProbe", + attachable_id=probe_current_source_density_run_id, + ) lfp = { - 'input_data_path': - probe_lfp_well_known_files.get( - 'EcephysSubsampledLfpContinuous'), - 'input_timestamps_path': - probe_lfp_well_known_files.get( - 'EcephysSubsampledLfpTimestamps'), - 'input_channels_path': - probe_lfp_well_known_files.get( - 'EcephysSubsampledChannelStates'), - 'csd_path': probe_csd_well_known_files.get( - 'EcephysCurrentSourceDensity') + "input_data_path": probe_lfp_well_known_files.get("EcephysSubsampledLfpContinuous"), + "input_timestamps_path": probe_lfp_well_known_files.get("EcephysSubsampledLfpTimestamps"), + "input_channels_path": probe_lfp_well_known_files.get("EcephysSubsampledChannelStates"), + "csd_path": probe_csd_well_known_files.get("EcephysCurrentSourceDensity"), } return lfp def probe_input_from_ecephys_session_id( - ecephys_session_id: int, - probe_ids_to_skip: Optional[List[int]], - lims_connection: PostgresQueryMixin, - error_log: NwbConfigErrorLog, + ecephys_session_id: int, + probe_ids_to_skip: Optional[List[int]], + lims_connection: PostgresQueryMixin, + error_log: NwbConfigErrorLog, ) -> List[dict]: """ Get the list of probe specifications, excluding the lists @@ -530,22 +494,22 @@ def probe_input_from_ecephys_session_id( """ probes_table = probes_table_from_ecephys_session_id_list( - lims_connection=lims_connection, - ecephys_session_id_list=[ecephys_session_id, ], - probe_ids_to_skip=probe_ids_to_skip) + lims_connection=lims_connection, + ecephys_session_id_list=[ + ecephys_session_id, + ], + probe_ids_to_skip=probe_ids_to_skip, + ) - probes_table = probes_table.set_index('ecephys_probe_id') + probes_table = probes_table.set_index("ecephys_probe_id") probes_table.drop( - labels=['ecephys_session_id', - 'phase', - 'unit_count', - 'channel_count', - 'structure_acronyms'], - axis='columns', - inplace=True) + labels=["ecephys_session_id", "phase", "unit_count", "channel_count", "structure_acronyms"], + axis="columns", + inplace=True, + ) - probes_table = probes_table.to_dict(orient='index') + probes_table = probes_table.to_dict(orient="index") # A list of tuples associating fields in the final specification # returned by this method with the names of files in the @@ -553,15 +517,15 @@ def probe_input_from_ecephys_session_id( # tuple is the field in the returned specification; the first # element is the associated well_known_file_types.name input_from_wkf_probe = [ - ('inverse_whitening_matrix_path', 'EcephysSortedWhiteningMatInv'), - ('mean_waveforms_path', 'EcephysSortedMeanWaveforms'), - ('spike_amplitudes_path', 'EcephysSortedAmplitudes'), - ('spike_clusters_file', 'EcephysSortedSpikeClusters'), - ('spike_templates_path', 'EcephysSortedSpikeTemplates'), - ('templates_path', 'EcephysSortedTemplates')] + ("inverse_whitening_matrix_path", "EcephysSortedWhiteningMatInv"), + ("mean_waveforms_path", "EcephysSortedMeanWaveforms"), + ("spike_amplitudes_path", "EcephysSortedAmplitudes"), + ("spike_clusters_file", "EcephysSortedSpikeClusters"), + ("spike_templates_path", "EcephysSortedSpikeTemplates"), + ("templates_path", "EcephysSortedTemplates"), + ] - wkf_to_query = [f"'{el[1]}'" - for el in input_from_wkf_probe] + wkf_to_query = [f"'{el[1]}'" for el in input_from_wkf_probe] probe_list = [] probe_id_list = list(probes_table.keys()) @@ -569,42 +533,42 @@ def probe_input_from_ecephys_session_id( for probe_id in probe_id_list: data = probes_table[probe_id] - has_lfp = data.pop('has_lfp_data') - data['id'] = probe_id + has_lfp = data.pop("has_lfp_data") + data["id"] = probe_id wkf_path_lookup = wkf_path_from_attachable( - lims_connection=lims_connection, - wkf_type_name=wkf_to_query, - attachable_type='EcephysProbe', - attachable_id=probe_id) + lims_connection=lims_connection, + wkf_type_name=wkf_to_query, + attachable_type="EcephysProbe", + attachable_id=probe_id, + ) for key_pair in input_from_wkf_probe: data[key_pair[0]] = wkf_path_lookup.get(key_pair[1], None) if has_lfp: - lfp_meta = _get_probe_lfp_meta( - lims_connection=lims_connection, - probe_id=probe_id - ) - data['csd_path'] = lfp_meta.pop('csd_path') - data['lfp'] = lfp_meta + lfp_meta = _get_probe_lfp_meta(lims_connection=lims_connection, probe_id=probe_id) + data["csd_path"] = lfp_meta.pop("csd_path") + data["lfp"] = lfp_meta else: - data['lfp'] = None + data["lfp"] = None probe_list.append(_nan_to_none(data)) probe_list = _add_spike_times_path( - probe_list=probe_list, - ecephys_session_id=ecephys_session_id, - lims_connection=lims_connection, - error_log=error_log) + probe_list=probe_list, + ecephys_session_id=ecephys_session_id, + lims_connection=lims_connection, + error_log=error_log, + ) return probe_list def channel_input_from_ecephys_session_id( - ecephys_session_id: int, - probe_ids_to_skip: Optional[List[int]], - lims_connection: PostgresQueryMixin, - error_log: NwbConfigErrorLog) -> Dict[int, list]: + ecephys_session_id: int, + probe_ids_to_skip: Optional[List[int]], + lims_connection: PostgresQueryMixin, + error_log: NwbConfigErrorLog, +) -> Dict[int, list]: """ Get a dict mapping probe_id to the list of channel specifications for a given ecephys_session_id @@ -632,45 +596,49 @@ def channel_input_from_ecephys_session_id( """ raw_channels_table = channels_table_from_ecephys_session_id_list( - ecephys_session_id_list=[ecephys_session_id, ], - probe_ids_to_skip=probe_ids_to_skip, - lims_connection=lims_connection) - - raw_channels_table.rename( - columns={'ecephys_channel_id': 'id', - 'ecephys_probe_id': 'probe_id'}, - inplace=True) - - raw_channels_table = raw_channels_table[[ - 'id', - 'probe_id', - 'probe_channel_number', - 'structure_id', - 'structure_acronym', - 'anterior_posterior_ccf_coordinate', - 'dorsal_ventral_ccf_coordinate', - 'left_right_ccf_coordinate', - 'probe_horizontal_position', - 'probe_vertical_position', - 'valid_data']] - raw_channels_table = raw_channels_table.set_index('id') - raw_channels_table = raw_channels_table.to_dict(orient='index') + ecephys_session_id_list=[ + ecephys_session_id, + ], + probe_ids_to_skip=probe_ids_to_skip, + lims_connection=lims_connection, + ) + + raw_channels_table.rename(columns={"ecephys_channel_id": "id", "ecephys_probe_id": "probe_id"}, inplace=True) + + raw_channels_table = raw_channels_table[ + [ + "id", + "probe_id", + "probe_channel_number", + "structure_id", + "structure_acronym", + "anterior_posterior_ccf_coordinate", + "dorsal_ventral_ccf_coordinate", + "left_right_ccf_coordinate", + "probe_horizontal_position", + "probe_vertical_position", + "valid_data", + ] + ] + raw_channels_table = raw_channels_table.set_index("id") + raw_channels_table = raw_channels_table.to_dict(orient="index") probe_id_to_channels = dict() for channel_id in raw_channels_table.keys(): this_channel = raw_channels_table[channel_id] - probe_id = this_channel['probe_id'] + probe_id = this_channel["probe_id"] if probe_id not in probe_id_to_channels: probe_id_to_channels[probe_id] = [] - this_channel['id'] = channel_id + this_channel["id"] = channel_id probe_id_to_channels[probe_id].append(_nan_to_none(this_channel)) return probe_id_to_channels def unit_input_from_ecephys_session_id( - ecephys_session_id: int, - probe_ids_to_skip: Optional[List[int]], - lims_connection: PostgresQueryMixin, - error_log: NwbConfigErrorLog) -> Dict[int, list]: + ecephys_session_id: int, + probe_ids_to_skip: Optional[List[int]], + lims_connection: PostgresQueryMixin, + error_log: NwbConfigErrorLog, +) -> Dict[int, list]: """ Get a dict mapping probe_id to the list of unit specifications for a given ecephys_session_id @@ -697,41 +665,43 @@ def unit_input_from_ecephys_session_id( that needs to be written to the input.json. """ raw_unit_table = units_table_from_ecephys_session_id_list( - ecephys_session_id_list=[ecephys_session_id, ], - probe_ids_to_skip=probe_ids_to_skip, - lims_connection=lims_connection) + ecephys_session_id_list=[ + ecephys_session_id, + ], + probe_ids_to_skip=probe_ids_to_skip, + lims_connection=lims_connection, + ) - raw_unit_table.rename( - columns={'unit_id': 'id', - 'ecephys_channel_id': 'peak_channel_id'}, - inplace=True) + raw_unit_table.rename(columns={"unit_id": "id", "ecephys_channel_id": "peak_channel_id"}, inplace=True) if len(raw_unit_table) == 0: msg = f"could not find units for session {ecephys_session_id}" - error_log.log(ecephys_session_id=ecephys_session_id, - msg=msg) + error_log.log(ecephys_session_id=ecephys_session_id, msg=msg) return dict() raw_unit_table.drop( - labels=['ecephys_session_id', - 'probe_vertical_position', - 'probe_horizontal_position', - 'anterior_posterior_ccf_coordinate', - 'dorsal_ventral_ccf_coordinate', - 'left_right_ccf_coordinate', - 'structure_id', - 'structure_acronym', - 'valid_data'], - axis='columns', - inplace=True) - - raw_unit_table = raw_unit_table.set_index('id') - raw_unit_table = raw_unit_table.to_dict(orient='index') + labels=[ + "ecephys_session_id", + "probe_vertical_position", + "probe_horizontal_position", + "anterior_posterior_ccf_coordinate", + "dorsal_ventral_ccf_coordinate", + "left_right_ccf_coordinate", + "structure_id", + "structure_acronym", + "valid_data", + ], + axis="columns", + inplace=True, + ) + + raw_unit_table = raw_unit_table.set_index("id") + raw_unit_table = raw_unit_table.to_dict(orient="index") probe_id_to_units = dict() for unit_id in raw_unit_table.keys(): this_unit = raw_unit_table[unit_id] - this_unit['id'] = unit_id - probe_id = this_unit.pop('ecephys_probe_id') + this_unit["id"] = unit_id + probe_id = this_unit.pop("ecephys_probe_id") if probe_id not in probe_id_to_units: probe_id_to_units[probe_id] = [] probe_id_to_units[probe_id].append(_nan_to_none(this_unit)) @@ -739,10 +709,8 @@ def unit_input_from_ecephys_session_id( def eye_tracking_geometry_from_equipment_id( - equipment_id: int, - date_of_acquisition: pd.Timestamp, - lims_connection: PostgresQueryMixin) -> dict: - + equipment_id: int, date_of_acquisition: pd.Timestamp, lims_connection: PostgresQueryMixin +) -> dict: """ Return eye_tracking_rig_geometry given a specified equipment_id and date_of_acquisition @@ -773,43 +741,46 @@ def eye_tracking_geometry_from_equipment_id( before date_of_acquisition. """ raw_eye_geometry = _raw_eye_tracking_geometry_from_equipment_id( - equipment_id=equipment_id, - date_of_acquisition=date_of_acquisition, - lims_connection=lims_connection) + equipment_id=equipment_id, date_of_acquisition=date_of_acquisition, lims_connection=lims_connection + ) eye_geometry = dict() - eye_geometry['led_position'] = [ - raw_eye_geometry['led position']['center_x_mm'], - raw_eye_geometry['led position']['center_y_mm'], - raw_eye_geometry['led position']['center_z_mm']] - - eye_geometry['monitor_position_mm'] = [ - raw_eye_geometry['screen position']['center_x_mm'], - raw_eye_geometry['screen position']['center_y_mm'], - raw_eye_geometry['screen position']['center_z_mm']] - - eye_geometry['monitor_rotation_deg'] = [ - raw_eye_geometry['screen position']['rotation_x_deg'], - raw_eye_geometry['screen position']['rotation_y_deg'], - raw_eye_geometry['screen position']['rotation_z_deg']] - - eye_geometry['camera_position_mm'] = [ - raw_eye_geometry['eye camera position']['center_x_mm'], - raw_eye_geometry['eye camera position']['center_y_mm'], - raw_eye_geometry['eye camera position']['center_z_mm']] - - eye_geometry['camera_rotation_deg'] = [ - raw_eye_geometry['eye camera position']['rotation_x_deg'], - raw_eye_geometry['eye camera position']['rotation_y_deg'], - raw_eye_geometry['eye camera position']['rotation_z_deg']] + eye_geometry["led_position"] = [ + raw_eye_geometry["led position"]["center_x_mm"], + raw_eye_geometry["led position"]["center_y_mm"], + raw_eye_geometry["led position"]["center_z_mm"], + ] + + eye_geometry["monitor_position_mm"] = [ + raw_eye_geometry["screen position"]["center_x_mm"], + raw_eye_geometry["screen position"]["center_y_mm"], + raw_eye_geometry["screen position"]["center_z_mm"], + ] + + eye_geometry["monitor_rotation_deg"] = [ + raw_eye_geometry["screen position"]["rotation_x_deg"], + raw_eye_geometry["screen position"]["rotation_y_deg"], + raw_eye_geometry["screen position"]["rotation_z_deg"], + ] + + eye_geometry["camera_position_mm"] = [ + raw_eye_geometry["eye camera position"]["center_x_mm"], + raw_eye_geometry["eye camera position"]["center_y_mm"], + raw_eye_geometry["eye camera position"]["center_z_mm"], + ] + + eye_geometry["camera_rotation_deg"] = [ + raw_eye_geometry["eye camera position"]["rotation_x_deg"], + raw_eye_geometry["eye camera position"]["rotation_y_deg"], + raw_eye_geometry["eye camera position"]["rotation_z_deg"], + ] return eye_geometry def _raw_eye_tracking_geometry_from_equipment_id( - equipment_id: int, - date_of_acquisition: pd.Timestamp, - lims_connection: PostgresQueryMixin) -> dict: + equipment_id: int, date_of_acquisition: pd.Timestamp, lims_connection: PostgresQueryMixin +) -> dict: """ Return eye_tracking_rig_geometry given a specified equipment_id and date_of_acquisition @@ -848,31 +819,27 @@ def _raw_eye_tracking_geometry_from_equipment_id( active_date that is before date_of_acquisition. """ config = dict() - for name in ('led position', 'behavior camera position', - 'eye camera position', 'screen position'): + for name in ("led position", "behavior camera position", "eye camera position", "screen position"): this_df = experiment_configs_from_equipment_id_and_type( - equipment_id=equipment_id, - config_type=name, - lims_connection=lims_connection) - this_df = this_df.loc[ - this_df.active_date.dt.date <= date_of_acquisition] + equipment_id=equipment_id, config_type=name, lims_connection=lims_connection + ) + this_df = this_df.loc[this_df.active_date.dt.date <= date_of_acquisition] this_df = this_df.iloc[this_df.active_date.idxmax()] this_config = dict() - this_config['center_x_mm'] = this_df.center_x_mm - this_config['center_y_mm'] = this_df.center_y_mm - this_config['center_z_mm'] = this_df.center_z_mm - this_config['rotation_x_deg'] = this_df.rotation_x_deg - this_config['rotation_y_deg'] = this_df.rotation_y_deg - this_config['rotation_z_deg'] = this_df.rotation_z_deg + this_config["center_x_mm"] = this_df.center_x_mm + this_config["center_y_mm"] = this_df.center_y_mm + this_config["center_z_mm"] = this_df.center_z_mm + this_config["rotation_x_deg"] = this_df.rotation_x_deg + this_config["rotation_y_deg"] = this_df.rotation_y_deg + this_config["rotation_z_deg"] = this_df.rotation_z_deg config[name] = this_config return config def _analysis_run_from_session_id( - lims_connection: PostgresQueryMixin, - ecephys_session_id_list: List[int], - strategy_class: str) -> Dict[int, int]: + lims_connection: PostgresQueryMixin, ecephys_session_id_list: List[int], strategy_class: str +) -> Dict[int, int]: """ Get a dict mapping ecephys_session_id to ecephys_analysis_runs.id for a specific job strategy class ('VbnCreateStimTableStrategy', @@ -911,26 +878,19 @@ def _analysis_run_from_session_id( """ query += build_in_list_selector_query( - col="ecephys_analysis_runs.ecephys_session_id", - valid_list=ecephys_session_id_list, - operator="AND", - valid=True) + col="ecephys_analysis_runs.ecephys_session_id", valid_list=ecephys_session_id_list, operator="AND", valid=True + ) query += build_in_list_selector_query( - col="ecephys_analysis_runs.job_strategy_class", - valid_list=[f"'{strategy_class}'"], - operator="AND", - valid=True) + col="ecephys_analysis_runs.job_strategy_class", valid_list=[f"'{strategy_class}'"], operator="AND", valid=True + ) query_result = lims_connection.select(query) analysis_run_map = dict() msg = "" - for session_id, run_id in zip(query_result.ecephys_session_id, - query_result.ecephys_analysis_run_id): - + for session_id, run_id in zip(query_result.ecephys_session_id, query_result.ecephys_analysis_run_id): if session_id in analysis_run_map: - msg += ("More than one analysis run returned for " - f"ecephys_session_id={session_id}\n") + msg += f"More than one analysis run returned for ecephys_session_id={session_id}\n" analysis_run_map[session_id] = run_id if len(msg) > 0: @@ -940,10 +900,8 @@ def _analysis_run_from_session_id( def _add_spike_times_path( - probe_list: List[dict], - ecephys_session_id: int, - lims_connection: PostgresQueryMixin, - error_log: NwbConfigErrorLog) -> List[dict]: + probe_list: List[dict], ecephys_session_id: int, lims_connection: PostgresQueryMixin, error_log: NwbConfigErrorLog +) -> List[dict]: """ Add the 'spike_times_path' entry to a list of probe specifications. @@ -966,12 +924,15 @@ def _add_spike_times_path( Will alter probe_list in place """ - probe_id_list = [this_probe['id'] for this_probe in probe_list] + probe_id_list = [this_probe["id"] for this_probe in probe_list] timestamp_run_lookup = _analysis_run_from_session_id( - lims_connection=lims_connection, - ecephys_session_id_list=[ecephys_session_id, ], - strategy_class='EcephysAlignTimestampsStrategy') + lims_connection=lims_connection, + ecephys_session_id_list=[ + ecephys_session_id, + ], + strategy_class="EcephysAlignTimestampsStrategy", + ) # get mapping from probe_id to # ecephys_analysis_run_probes.id @@ -984,43 +945,41 @@ def _add_spike_times_path( """ query += build_in_list_selector_query( - col="ecephys_analysis_run_probes.ecephys_probe_id", - valid_list=probe_id_list, - operator="WHERE", - valid=True) + col="ecephys_analysis_run_probes.ecephys_probe_id", valid_list=probe_id_list, operator="WHERE", valid=True + ) query += build_in_list_selector_query( - col="ecephys_analysis_run_probes.ecephys_analysis_run_id", - valid_list=[ - timestamp_run_lookup[ecephys_session_id], ], - operator="AND", - valid=True) + col="ecephys_analysis_run_probes.ecephys_analysis_run_id", + valid_list=[ + timestamp_run_lookup[ecephys_session_id], + ], + operator="AND", + valid=True, + ) query_result = lims_connection.select(query) probe_run_lookup = dict() - for p_id, r_id in zip(query_result.probe_id, - query_result.ecephys_analysis_run_probe_id): + for p_id, r_id in zip(query_result.probe_id, query_result.ecephys_analysis_run_probe_id): probe_run_lookup[int(p_id)] = int(r_id) for this_probe in probe_list: - probe_id = this_probe['id'] + probe_id = this_probe["id"] if probe_id in probe_run_lookup: timestamp_lookup = wkf_path_from_attachable( - lims_connection=lims_connection, - wkf_type_name=[ - "'EcephysAlignedEventTimestamps'", ], - attachable_type='EcephysAnalysisRunProbe', - attachable_id=probe_run_lookup[probe_id]) - - this_probe['spike_times_path'] = \ - timestamp_lookup['EcephysAlignedEventTimestamps'] + lims_connection=lims_connection, + wkf_type_name=[ + "'EcephysAlignedEventTimestamps'", + ], + attachable_type="EcephysAnalysisRunProbe", + attachable_id=probe_run_lookup[probe_id], + ) + + this_probe["spike_times_path"] = timestamp_lookup["EcephysAlignedEventTimestamps"] else: - msg = ("could not find EcephysAlignedEventTimestamps for " - f"probe {probe_id}") - error_log.log(ecephys_session_id=ecephys_session_id, - msg=msg) + msg = f"could not find EcephysAlignedEventTimestamps for probe {probe_id}" + error_log.log(ecephys_session_id=ecephys_session_id, msg=msg) - this_probe['spike_times_path'] = None + this_probe["spike_times_path"] = None return probe_list diff --git a/allensdk/brain_observatory/vbn_2022/metadata_writer/__main__.py b/allensdk/brain_observatory/vbn_2022/metadata_writer/__main__.py index daabe25392..4e70fb6fc4 100644 --- a/allensdk/brain_observatory/vbn_2022/metadata_writer/__main__.py +++ b/allensdk/brain_observatory/vbn_2022/metadata_writer/__main__.py @@ -1,6 +1,4 @@ -from allensdk.brain_observatory.\ - vbn_2022.metadata_writer.metadata_writer import ( - VBN2022MetadataWriterClass) +from allensdk.brain_observatory.vbn_2022.metadata_writer.metadata_writer import VBN2022MetadataWriterClass def main(): diff --git a/allensdk/brain_observatory/vbn_2022/metadata_writer/dataframe_manipulations.py b/allensdk/brain_observatory/vbn_2022/metadata_writer/dataframe_manipulations.py index 027fc31bb6..e56d536bba 100644 --- a/allensdk/brain_observatory/vbn_2022/metadata_writer/dataframe_manipulations.py +++ b/allensdk/brain_observatory/vbn_2022/metadata_writer/dataframe_manipulations.py @@ -26,9 +26,7 @@ ) -def _add_session_number( - sessions_df: pd.DataFrame, index_col: str -) -> pd.DataFrame: +def _add_session_number(sessions_df: pd.DataFrame, index_col: str) -> pd.DataFrame: """ For each mouse: order sessions by date_of_acquisition. Add a session_number column corresponding to where that session falls in the mouse's history. @@ -77,9 +75,7 @@ def _add_session_number( new_data.append(element) new_df = pd.DataFrame(data=new_data) - sessions_df = sessions_df.join( - new_df.set_index(index_col), on=index_col, how="left" - ) + sessions_df = sessions_df.join(new_df.set_index(index_col), on=index_col, how="left") return sessions_df @@ -207,13 +203,10 @@ def _add_prior_omissions( # add prior_exposures_to_omissions to the full history data frame contains_omissions = pd.Series(False, index=full_history_df.index) contains_omissions.loc[ - (full_history_df.session_type.notnull()) - & (full_history_df.session_type.str.lower().str.contains("ephys")) + (full_history_df.session_type.notnull()) & (full_history_df.session_type.str.lower().str.contains("ephys")) ] = True - full_history_df[ - "prior_exposures_to_omissions" - ] = __get_prior_exposure_count( + full_history_df["prior_exposures_to_omissions"] = __get_prior_exposure_count( df=full_history_df, to=contains_omissions, agg_method="cumsum" ) @@ -293,10 +286,7 @@ def _patch_date_and_stage_from_pickle_file( for col in columns_to_patch: msg = "" if col not in ("date_of_acquisition", "session_type"): - msg += ( - "can only patch 'date_of_acquisition' " - "and 'session_type'; you asked for '{col}'\n" - ) + msg += "can only patch 'date_of_acquisition' and 'session_type'; you asked for '{col}'\n" if len(msg) > 0: raise ValueError(msg) @@ -305,7 +295,7 @@ def _patch_date_and_stage_from_pickle_file( invalid_rows = np.zeros(len(behavior_df), dtype=bool) for col_name in flag_columns: if col_name not in behavior_df.columns: - raise ValueError("dataframe does not contain column " "{col_name}") + raise ValueError("dataframe does not contain column {col_name}") invalid_rows[behavior_df[col_name].isna()] = True invalid_beh = behavior_df.iloc[invalid_rows].behavior_session_id.values @@ -320,13 +310,9 @@ def _patch_date_and_stage_from_pickle_file( t0 = time.time() n_to_log = max(1, n_to_patch // 10) - for beh_ct, (beh_id, pkl_path) in enumerate( - zip(pickle_path_df.behavior_session_id, pickle_path_df.pkl_path) - ): + for beh_ct, (beh_id, pkl_path) in enumerate(zip(pickle_path_df.behavior_session_id, pickle_path_df.pkl_path)): stim_file = BehaviorStimulusFile(filepath=pkl_path) - new_date = DateOfAcquisition.from_stimulus_file( - stimulus_file=stim_file - ).value + new_date = DateOfAcquisition.from_stimulus_file(stimulus_file=stim_file).value new_session_type = stim_file.session_type new_vals = { @@ -338,9 +324,7 @@ def _patch_date_and_stage_from_pickle_file( if len(new_row) == 1: new_row = new_row[0] - behavior_df.loc[ - behavior_df.behavior_session_id == beh_id, columns_to_patch - ] = new_row + behavior_df.loc[behavior_df.behavior_session_id == beh_id, columns_to_patch] = new_row if (beh_ct + 1) % n_to_log == 0 and logger is not None: duration = time.time() - t0 @@ -348,7 +332,7 @@ def _patch_date_and_stage_from_pickle_file( pred = n_to_patch * per remaining = pred - duration logger.info( - f"Patched {beh_ct+1} of {n_to_patch} " + f"Patched {beh_ct + 1} of {n_to_patch} " f"in {duration:.2e} seconds; " f"predict {remaining:.2e} seconds more" ) @@ -378,17 +362,13 @@ def _add_age_in_days(df: pd.DataFrame, index_column: str) -> pd.DataFrame: df: pd.DataFrame Same as input, but with age_in_days added """ - age_in_days = ( - df["date_of_acquisition"].dt.date - df["date_of_birth"].dt.date - ) + age_in_days = df["date_of_acquisition"].dt.date - df["date_of_birth"].dt.date age_in_days = age_in_days.apply(lambda x: x.days) df["age_in_days"] = age_in_days return df -def _add_images_from_behavior( - ecephys_table: pd.DataFrame, behavior_table: pd.DataFrame -) -> pd.DataFrame: +def _add_images_from_behavior(ecephys_table: pd.DataFrame, behavior_table: pd.DataFrame) -> pd.DataFrame: """ Use the behavior sessions table to add image_set and prior_exposures_to_image_set to ecephys table. @@ -431,9 +411,7 @@ def _add_images_from_behavior( return ecephys_table -def strip_substructure_acronym_df( - df: pd.DataFrame, col_name: str -) -> pd.DataFrame: +def strip_substructure_acronym_df(df: pd.DataFrame, col_name: str) -> pd.DataFrame: """ Take the structure_acronym(s) column of a dataframe and remove the substructure (i.e. convert DG-mo to DG). diff --git a/allensdk/brain_observatory/vbn_2022/metadata_writer/lims_queries.py b/allensdk/brain_observatory/vbn_2022/metadata_writer/lims_queries.py index c0acef7fe1..1ffa1c2f8a 100644 --- a/allensdk/brain_observatory/vbn_2022/metadata_writer/lims_queries.py +++ b/allensdk/brain_observatory/vbn_2022/metadata_writer/lims_queries.py @@ -30,9 +30,7 @@ ) -def get_list_of_bad_probe_ids( - lims_connection: PostgresQueryMixin, probes_to_skip: List[Dict[str, Any]] -) -> List[int]: +def get_list_of_bad_probe_ids(lims_connection: PostgresQueryMixin, probes_to_skip: List[Dict[str, Any]]) -> List[int]: """ Given a list of probes to skip,each of the form @@ -438,9 +436,7 @@ def _merge_ecephys_id_and_failed( session_id_list=failed_ecephys_session_id_list, ) to_keep = [] - for session_id, donor_id in zip( - failed_donor_lookup.ecephys_session_id, failed_donor_lookup.donor_id - ): + for session_id, donor_id in zip(failed_donor_lookup.ecephys_session_id, failed_donor_lookup.donor_id): if donor_id in passed_donor_ids: to_keep.append(int(session_id)) @@ -553,9 +549,7 @@ def _ecephys_summary_table_from_ecephys_session_id_list( summary_table = lims_connection.select(query) # Add UTC tz - summary_table["date_of_acquisition"] = summary_table[ - "date_of_acquisition" - ].dt.tz_localize("UTC") + summary_table["date_of_acquisition"] = summary_table["date_of_acquisition"].dt.tz_localize("UTC") return summary_table @@ -758,43 +752,27 @@ def _behavior_session_table_from_ecephys_session_id_list( ) behavior_sessions = get_session_metadata_multiprocessing( behavior_session_ids=behavior_session_df["behavior_session_id"], - lims_engine=db_connection_creator( - fallback_credentials=LIMS_DB_CREDENTIAL_MAP - ), + lims_engine=db_connection_creator(fallback_credentials=LIMS_DB_CREDENTIAL_MAP), n_workers=n_workers, ) if exclude_invalid_sessions: - behavior_sessions = remove_invalid_sessions( - behavior_sessions=behavior_sessions - ) + behavior_sessions = remove_invalid_sessions(behavior_sessions=behavior_sessions) behavior_session_df = behavior_session_df[ - behavior_session_df["behavior_session_id"].isin( - [x.behavior_session_id for x in behavior_sessions] - ) + behavior_session_df["behavior_session_id"].isin([x.behavior_session_id for x in behavior_sessions]) ] # Add timezone information to behavior daq. Matches ecephys table. - behavior_session_df["date_of_acquisition"] = behavior_session_df[ - "date_of_acquisition" - ].dt.tz_localize("UTC") + behavior_session_df["date_of_acquisition"] = behavior_session_df["date_of_acquisition"].dt.tz_localize("UTC") behavior_session_df["image_set"] = get_image_set(df=behavior_session_df) - behavior_session_df[ - "prior_exposures_to_session_type" - ] = get_prior_exposures_to_session_type(df=behavior_session_df) + behavior_session_df["prior_exposures_to_session_type"] = get_prior_exposures_to_session_type(df=behavior_session_df) - behavior_session_df[ - "prior_exposures_to_image_set" - ] = get_prior_exposures_to_image_set(df=behavior_session_df) + behavior_session_df["prior_exposures_to_image_set"] = get_prior_exposures_to_image_set(df=behavior_session_df) - behavior_session_df = _add_age_in_days( - df=behavior_session_df, index_column="behavior_session_id" - ) + behavior_session_df = _add_age_in_days(df=behavior_session_df, index_column="behavior_session_id") - behavior_session_df = _add_session_number( - sessions_df=behavior_session_df, index_col="behavior_session_id" - ) + behavior_session_df = _add_session_number(sessions_df=behavior_session_df, index_col="behavior_session_id") return behavior_session_df @@ -892,13 +870,9 @@ def session_tables_from_ecephys_session_id_list( # since we had to read date_of_acquisition from the pickle file, # we now need to calculate age_in_days - summary_tbl = _add_age_in_days( - df=summary_tbl, index_column="ecephys_session_id" - ) + summary_tbl = _add_age_in_days(df=summary_tbl, index_column="ecephys_session_id") - summary_tbl.drop( - labels=["date_of_birth", "equipment_id"], axis="columns", inplace=True - ) + summary_tbl.drop(labels=["date_of_birth", "equipment_id"], axis="columns", inplace=True) ct_tbl = _ecephys_counts_per_session_from_ecephys_session_id_list( lims_connection=lims_connection, @@ -924,18 +898,12 @@ def session_tables_from_ecephys_session_id_list( how="left", ) - summary_tbl = _add_images_from_behavior( - ecephys_table=summary_tbl, behavior_table=beh_table - ) + summary_tbl = _add_images_from_behavior(ecephys_table=summary_tbl, behavior_table=beh_table) - sessions_table = _add_session_number( - sessions_df=summary_tbl, index_col="ecephys_session_id" - ) + sessions_table = _add_session_number(sessions_df=summary_tbl, index_col="ecephys_session_id") sessions_table = add_experience_level_simple(input_df=sessions_table) - omission_results = _add_prior_omissions( - behavior_sessions_df=beh_table, ecephys_sessions_df=sessions_table - ) + omission_results = _add_prior_omissions(behavior_sessions_df=beh_table, ecephys_sessions_df=sessions_table) beh_table = omission_results["behavior"] sessions_table = omission_results["ecephys"] @@ -962,10 +930,7 @@ def session_tables_from_ecephys_session_id_list( # pare back down to only passed sessions if failed_ecephys_session_id_list is not None: sessions_table = sessions_table[ - [ - eid in set(ecephys_session_id_list) - for eid in sessions_table.ecephys_session_id - ] + [eid in set(ecephys_session_id_list) for eid in sessions_table.ecephys_session_id] ] return sessions_table, beh_table diff --git a/allensdk/brain_observatory/vbn_2022/metadata_writer/metadata_writer.py b/allensdk/brain_observatory/vbn_2022/metadata_writer/metadata_writer.py index 7112d261f3..e15b2e5d36 100644 --- a/allensdk/brain_observatory/vbn_2022/metadata_writer/metadata_writer.py +++ b/allensdk/brain_observatory/vbn_2022/metadata_writer/metadata_writer.py @@ -42,18 +42,14 @@ def write_df(self, df: pd.DataFrame, output_path: str) -> None: """ df.to_csv(output_path, index=False) self.files_written.append(output_path) - self.logger.info( - f"Wrote {output_path} after " f"{time.time()-self.t0: .2e} seconds" - ) + self.logger.info(f"Wrote {output_path} after {time.time() - self.t0: .2e} seconds") def run(self): self.t0 = time.time() file_id_generator = FileIDGenerator() - lims_connection = db_connection_creator( - fallback_credentials=LIMS_DB_CREDENTIAL_MAP - ) + lims_connection = db_connection_creator(fallback_credentials=LIMS_DB_CREDENTIAL_MAP) if self.args["probes_to_skip"] is not None: probe_ids_to_skip = get_list_of_bad_probe_ids( @@ -72,9 +68,7 @@ def run(self): probe_ids_to_skip=probe_ids_to_skip, ) - units_table = strip_substructure_acronym_df( - df=units_table, col_name="structure_acronym" - ) + units_table = strip_substructure_acronym_df(df=units_table, col_name="structure_acronym") units_table = units_table[ [ @@ -125,9 +119,7 @@ def run(self): probe_ids_to_skip=probe_ids_to_skip, ) - probes_table = strip_substructure_acronym_df( - df=probes_table, col_name="structure_acronyms" - ) + probes_table = strip_substructure_acronym_df(df=probes_table, col_name="structure_acronyms") probes_table.drop( labels=["temporal_subsampling_factor"], @@ -161,17 +153,11 @@ def run(self): probe_ids_to_skip=probe_ids_to_skip, ) - channels_table = strip_substructure_acronym_df( - df=channels_table, col_name="structure_acronym" - ) + channels_table = strip_substructure_acronym_df(df=channels_table, col_name="structure_acronym") - channels_table.drop( - labels=["structure_id"], axis="columns", inplace=True - ) + channels_table.drop(labels=["structure_id"], axis="columns", inplace=True) - self.write_df( - df=channels_table, output_path=self.args["channels_path"] - ) + self.write_df(df=channels_table, output_path=self.args["channels_path"]) failed_session_list = self.args["failed_ecephys_session_id_list"] @@ -203,12 +189,8 @@ def run(self): ) ecephys_session_ids = ecephys_session_table["ecephys_session_id"] - behavior_ecephys_session_ids = behavior_session_table[ - "ecephys_session_id" - ] - ecephys_session_mask = behavior_ecephys_session_ids.isin( - ecephys_session_ids - ) + behavior_ecephys_session_ids = behavior_session_table["ecephys_session_id"] + ecephys_session_mask = behavior_ecephys_session_ids.isin(ecephys_session_ids) behavior_only_table = behavior_session_table[~ecephys_session_mask] behavior_w_ecephy_table = behavior_session_table[ecephys_session_mask] @@ -228,9 +210,7 @@ def run(self): data_dir_col="behavior_session_id", on_missing_file=self.args["on_missing_file"], ) - behavior_session_table = pd.concat( - [behavior_only_table, behavior_w_ecephy_table] - ) + behavior_session_table = pd.concat([behavior_only_table, behavior_w_ecephy_table]) # add supplemental columns to the ecephys_sessions # column @@ -281,5 +261,5 @@ def run(self): self.logger.info( f"Wrote {self.args['ecephys_sessions_path']}\n" f"and {self.args['behavior_sessions_path']}\n" - f"after {time.time()-self.t0:.2e} seconds" + f"after {time.time() - self.t0:.2e} seconds" ) diff --git a/allensdk/brain_observatory/vbn_2022/metadata_writer/schemas.py b/allensdk/brain_observatory/vbn_2022/metadata_writer/schemas.py index f00a257b77..2a9437db78 100644 --- a/allensdk/brain_observatory/vbn_2022/metadata_writer/schemas.py +++ b/allensdk/brain_observatory/vbn_2022/metadata_writer/schemas.py @@ -2,18 +2,17 @@ import argschema from allensdk.brain_observatory.vbn_2022.utils.schemas import ProbeToSkip -from allensdk.brain_observatory.behavior.behavior_project_cache.project_metadata_writer.schemas import BaseMetadataWriterInputSchema # noqa: E501 +from allensdk.brain_observatory.behavior.behavior_project_cache.project_metadata_writer.schemas import ( + BaseMetadataWriterInputSchema, +) # noqa: E501 from marshmallow import post_load class VBN2022MetadataWriterInputSchema(BaseMetadataWriterInputSchema): - ecephys_session_id_list = argschema.fields.List( argschema.fields.Int, required=True, - description=( - "List of ecephys_sessions.id values " "of sessions to be released" - ), + description=("List of ecephys_sessions.id values of sessions to be released"), ) failed_ecephys_session_id_list = argschema.fields.List( @@ -41,9 +40,7 @@ class VBN2022MetadataWriterInputSchema(BaseMetadataWriterInputSchema): ecephys_nwb_dir = argschema.fields.InputDir( required=True, allow_none=False, - description=( - "The directory where ecephys_nwb sessions are " "to be found" - ), + description=("The directory where ecephys_nwb sessions are to be found"), ) ecephys_nwb_prefix = argschema.fields.Str( @@ -72,10 +69,10 @@ class VBN2022MetadataWriterInputSchema(BaseMetadataWriterInputSchema): n_workers = argschema.fields.Int( default=8, allow_none=True, - description='Number of workers for reading from pkl file. ' - 'Default=8 due to issues with making too many ' - 'requests to the database. Increase if too slow, decrease ' - 'if the database rejects the connection' + description="Number of workers for reading from pkl file. " + "Default=8 due to issues with making too many " + "requests to the database. Increase if too slow, decrease " + "if the database rejects the connection", ) @post_load @@ -98,8 +95,6 @@ def validate_paths(self, data, **kwargs): if len(msg) > 0: raise RuntimeError( - "The following files already exist\n" - f"{msg}" - "Run with clobber=True if you want to overwrite" + f"The following files already exist\n{msg}Run with clobber=True if you want to overwrite" ) return data diff --git a/allensdk/brain_observatory/vbn_2022/utils/schemas.py b/allensdk/brain_observatory/vbn_2022/utils/schemas.py index 40ed0d2572..0c4ea1727a 100644 --- a/allensdk/brain_observatory/vbn_2022/utils/schemas.py +++ b/allensdk/brain_observatory/vbn_2022/utils/schemas.py @@ -4,22 +4,14 @@ class ProbeToSkip(argschema.ArgSchema): + session = argschema.fields.Int(required=True, description=("The ecephys_session_id associated with the bad probe")) - session = argschema.fields.Int( - required=True, - description=("The ecephys_session_id associated with " - "the bad probe")) - - probe = argschema.fields.Str( - required=True, - description=("The name of the bad probe, e.g. 'probeA'")) + probe = argschema.fields.Str(required=True, description=("The name of the bad probe, e.g. 'probeA'")) @post_load def validate_probe_names(self, data, **kwargs): - pattern = re.compile('probe[A-Z]') - match = pattern.match(data['probe']) - if match is None or len(data['probe']) != 6: - raise ValueError( - f"{data['probe']} is not a valid probe name; " - "must be like 'probe[A-Z]'") + pattern = re.compile("probe[A-Z]") + match = pattern.match(data["probe"]) + if match is None or len(data["probe"]) != 6: + raise ValueError(f"{data['probe']} is not a valid probe name; must be like 'probe[A-Z]'") return data diff --git a/allensdk/brain_observatory/visualization/__init__.py b/allensdk/brain_observatory/visualization/__init__.py index 53bda7c47e..5cfbf5d9b4 100644 --- a/allensdk/brain_observatory/visualization/__init__.py +++ b/allensdk/brain_observatory/visualization/__init__.py @@ -1,13 +1,17 @@ import matplotlib.pyplot as plt + def plot_running_speed( - timestamps, values, - start_index=0, stop_index=None, step=1, - ylabel='running speed (cm/s)', - xlabel='time (s)', - title=None -): # pragma: no cover - ''' Make a simple plot of a running speed trace + timestamps, + values, + start_index=0, + stop_index=None, + step=1, + ylabel="running speed (cm/s)", + xlabel="time (s)", + title=None, +): # pragma: no cover + """Make a simple plot of a running speed trace Parameters ---------- @@ -15,22 +19,22 @@ def plot_running_speed( Times at which running speed samples were collected values : numpy.ndarray Running speed values (by default: linear cm / s with negative values indicating backwards movement) - - ''' + + """ stop_index = len(timestamps) if stop_index is None else stop_index if title is None: - title = f'running speed from {timestamps[start_index]:2.2f} to {timestamps[stop_index-1]:2.2f} seconds' + title = f"running speed from {timestamps[start_index]:2.2f} to {timestamps[stop_index - 1]:2.2f} seconds" fig, ax = plt.subplots(figsize=(8, 8)) plt.plot( - timestamps[start_index:stop_index:step], + timestamps[start_index:stop_index:step], values[start_index:stop_index:step], ) ax.set_ylabel(ylabel, fontsize=16) ax.set_xlabel(xlabel, fontsize=16) ax.set_title(title, fontsize=20) - plt.axis('tight') + plt.axis("tight") - return fig \ No newline at end of file + return fig diff --git a/allensdk/config/__init__.py b/allensdk/config/__init__.py index 66d0b42cfd..6d80230400 100644 --- a/allensdk/config/__init__.py +++ b/allensdk/config/__init__.py @@ -38,24 +38,25 @@ _console_handler = logging.StreamHandler(sys.stdout) + def enable_console_log(level=None): - '''configure allensdk logging to output to the console. + """configure allensdk logging to output to the console. + + Parameters + ---------- + level : int + logging level 0-50 (logging.INFO, logging.DEBUG, etc.) - Parameters - ---------- - level : int - logging level 0-50 (logging.INFO, logging.DEBUG, etc.) + Notes + ----- + See: `Logging Cookbook `_ + """ - Notes - ----- - See: `Logging Cookbook `_ - ''' - - sdk_logger = logging.getLogger('allensdk') + sdk_logger = logging.getLogger("allensdk") if level is None: sdk_logger.setLevel(logging.DEBUG) else: sdk_logger.setLevel(level) - sdk_logger.addHandler(_console_handler) \ No newline at end of file + sdk_logger.addHandler(_console_handler) diff --git a/allensdk/config/app/__init__.py b/allensdk/config/app/__init__.py index 6177de1ae7..ec4853361e 100644 --- a/allensdk/config/app/__init__.py +++ b/allensdk/config/app/__init__.py @@ -33,8 +33,8 @@ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. # -''' +""" allensdk.config.app is a package that assists in configuring application software, as opposed to domain-specific configuration. -''' +""" diff --git a/allensdk/config/app/application_config.py b/allensdk/config/app/application_config.py index 4bc5e3d67c..f3db87ebe0 100644 --- a/allensdk/config/app/application_config.py +++ b/allensdk/config/app/application_config.py @@ -43,21 +43,15 @@ class ApplicationConfig(object): - ''' Convenience class that handles of application configuration + """Convenience class that handles of application configuration from environment variables, .conf files and the command line using Python standard libraries and formats. - ''' + """ _log = logging.getLogger(__name__) - _DEFAULT_LOG_CONFIG = os.getenv( - 'LOG_CFG', str(files(__package__).joinpath('logging.conf'))) - - def __init__(self, - defaults, - name="app", - halp="Run application.", - default_log_config=None): + _DEFAULT_LOG_CONFIG = os.getenv("LOG_CFG", str(files(__package__).joinpath("logging.conf"))) + def __init__(self, defaults, name="app", halp="Run application.", default_log_config=None): self.application_name = name self.help = halp self.debug_enabled = False @@ -66,18 +60,11 @@ def __init__(self, default_log_config = ApplicationConfig._DEFAULT_LOG_CONFIG lc.fileConfig(_DEFAULT_LOG_CONFIG) - ApplicationConfig._log.info( - "default log config: %s" % (default_log_config)) + ApplicationConfig._log.info("default log config: %s" % (default_log_config)) self.defaults = { - 'config_file_path': { - 'default': "%s.conf" % (self.application_name), - 'help': 'configuration file path' - }, - 'log_config_path': { - 'default': default_log_config, - 'help': 'logging configuration path' - } + "config_file_path": {"default": "%s.conf" % (self.application_name), "help": "configuration file path"}, + "log_config_path": {"default": default_log_config, "help": "logging configuration path"}, } self.defaults.update(defaults) @@ -87,10 +74,10 @@ def __init__(self, self.argparser = self.create_argparser() for key, value in self.defaults.items(): - setattr(self, key, value['default']) + setattr(self, key, value["default"]) def load(self, command_line_args, disable_existing_loggers=True): - ''' Load application configuration options, first from the environment, + """Load application configuration options, first from the environment, then from the configuration file, then from the command line. Each stage of loading can override the previous stage. @@ -106,7 +93,7 @@ def load(self, command_line_args, disable_existing_loggers=True): ------- fileConfig Configuration object with all levels applied - ''' + """ # read and apply options from the environment self.apply_configuration_from_environment() @@ -125,24 +112,22 @@ def load(self, command_line_args, disable_existing_loggers=True): # apply the remaining command line options self.apply_configuration_from_command_line(parsed_args) except Exception as e: - ApplicationConfig._log.error("Could not load configuration file: %s\n%s" % - (parsed_args.config_file_path, - e)) + ApplicationConfig._log.error( + "Could not load configuration file: %s\n%s" % (parsed_args.config_file_path, e) + ) raise if parsed_args.log_config_path: try: - lc.fileConfig(self.log_config_path, - disable_existing_loggers=disable_existing_loggers) + lc.fileConfig(self.log_config_path, disable_existing_loggers=disable_existing_loggers) except Exception: - logging.error("Could not load log configuration file: %s" % - (parsed_args.log_config_path)) + logging.error("Could not load log configuration file: %s" % (parsed_args.log_config_path)) else: # TODO: configure default logging pass def create_argparser(self): - '''Initialization for the command-line parsing stage. + """Initialization for the command-line parsing stage. An application specific prefix is applied to argument names. @@ -162,21 +147,18 @@ def create_argparser(self): ----- Defaults are set at the first environment reading. Command line args only override them when present - ''' - parser = argparse.ArgumentParser(prog=self.application_name, - description=self.help) + """ + parser = argparse.ArgumentParser(prog=self.application_name, description=self.help) for key, value in self.defaults.items(): - if key == 'config_file_path': - parser.add_argument( - "%s" % (key), default=None, help=value['help']) + if key == "config_file_path": + parser.add_argument("%s" % (key), default=None, help=value["help"]) else: - parser.add_argument("--%s" % - (key), default=None, help=value['help']) + parser.add_argument("--%s" % (key), default=None, help=value["help"]) return parser def parse_command_line_args(self, args): - '''Simply call the internal argparser object. + """Simply call the internal argparser object. Parameters ---------- @@ -187,11 +169,11 @@ def parse_command_line_args(self, args): ------- Namespace Parsed paramenters. - ''' + """ return self.argparser.parse_args(args) def apply_configuration_from_command_line(self, parsed_args): - '''Read application configuration variables from the command line. + """Read application configuration variables from the command line. Unassigned variables are left unchanged if previously assigned, set to their default values, @@ -205,8 +187,8 @@ def apply_configuration_from_command_line(self, parsed_args): parsed_args : dict the arguments as parsed from the command line. - ''' - logging.info('command_line args: %s' % (parsed_args)) + """ + logging.info("command_line args: %s" % (parsed_args)) for key in self.defaults: parsed_value = getattr(parsed_args, key) @@ -214,22 +196,21 @@ def apply_configuration_from_command_line(self, parsed_args): setattr(self, key, parsed_value) def apply_configuration_from_environment(self): - '''Read application configuration variables from the environment. + """Read application configuration variables from the environment. The variable names are upper case and have a prefix defined by the application. See: https://docs.python.org/2/library/os.html - ''' + """ for key in self.defaults: - environment_variable = "%s_%s" % ( - self.application_name.upper(), key.upper()) + environment_variable = "%s_%s" % (self.application_name.upper(), key.upper()) environment_value = os.environ.get(environment_variable) if environment_value: setattr(self, key, environment_value) def from_json_file(self, json_path): - '''Read an application configuration from a JSON format file. + """Read an application configuration from a JSON format file. Parameters ---------- @@ -241,13 +222,13 @@ def from_json_file(self, json_path): string An application configuration in INI format - ''' + """ description = JsonComments.read_file(json_path) return self.to_config_string(description) def from_json_string(self, json_string): - '''Read a configuration from a JSON format string. + """Read a configuration from a JSON format string. Parameters ---------- @@ -258,13 +239,13 @@ def from_json_string(self, json_string): ------- string An application configuration in INI format - ''' + """ description = JsonComments.read_string(json_string) return self.to_config_string(description) def to_config_string(self, description): - '''Create a configuration string from a dict. + """Create a configuration string from a dict. Parameters ---------- @@ -279,25 +260,23 @@ def to_config_string(self, description): Notes ----- The Python configparser library natively supports this functionality in Python 3. - ''' - if 'biophys' not in description: - bps_config_string = '[biophys]\n\n' + """ + if "biophys" not in description: + bps_config_string = "[biophys]\n\n" return bps_config_string - bps_config = description['biophys'][0] + bps_config = description["biophys"][0] - cfg_array = ['[biophys]'] + cfg_array = ["[biophys]"] - if 'log_config_path' in bps_config: - cfg_array.append(str('log_config_path: %s' % - bps_config['log_config_path'])) + if "log_config_path" in bps_config: + cfg_array.append(str("log_config_path: %s" % bps_config["log_config_path"])) - if 'debug' in bps_config: - cfg_array.append(str('debug: %s' % bps_config['debug'])) + if "debug" in bps_config: + cfg_array.append(str("debug: %s" % bps_config["debug"])) - if 'model_file' in bps_config: - cfg_array.append(str('model_file: %s' % - ','.join(bps_config['model_file']))) + if "model_file" in bps_config: + cfg_array.append(str("model_file: %s" % ",".join(bps_config["model_file"]))) cfg_array.append("\n") @@ -307,7 +286,7 @@ def to_config_string(self, description): return bps_cfg_string def apply_configuration_from_file(self, config_file_path): - ''' Read application configuration variables from a .conf file. + """Read application configuration variables from a .conf file. Unassigned variables are set to their default values or None if no default is specified at init time. @@ -322,7 +301,7 @@ def apply_configuration_from_file(self, config_file_path): ------- see: https://docs.python.org/2/library/configparser.html - ''' + """ none_defaults = {} # defaults are set in environment @@ -335,14 +314,12 @@ def apply_configuration_from_file(self, config_file_path): config = None try: - config = ConfigParser(defaults=none_defaults, - allow_no_value=True) + config = ConfigParser(defaults=none_defaults, allow_no_value=True) except Exception: - logging.warn( - "This python installation does not support configuration defaults.") + logging.warn("This python installation does not support configuration defaults.") config = ConfigParser() - if config_file_path.endswith('.json'): + if config_file_path.endswith(".json"): cfg_string = self.from_json_file(config_file_path) config.read_string(cfg_string) else: @@ -355,5 +332,4 @@ def apply_configuration_from_file(self, config_file_path): logging.info("setting %s to %s" % (key, file_value)) setattr(self, key, file_value) except Exception: - logging.info("Configuration option not specified: %s" % - (key)) + logging.info("Configuration option not specified: %s" % (key)) diff --git a/allensdk/config/manifest.py b/allensdk/config/manifest.py index 1736cd428a..93f1117ca8 100644 --- a/allensdk/config/manifest.py +++ b/allensdk/config/manifest.py @@ -42,12 +42,11 @@ from pathlib import Path -class ManifestVersionError(Exception): - +class ManifestVersionError(Exception): @property def outdated(self): try: - return self.found_version < self.version + return self.found_version < self.version except TypeError: return @@ -55,20 +54,20 @@ def __init__(self, message, version, found_version): super(ManifestVersionError, self).__init__(message) self.found_version = found_version self.version = version - + class Manifest(object): """Manages the location of external files - referenced in an Allen SDK configuration """ + referenced in an Allen SDK configuration""" - DIR = 'dir' - FILE = 'file' - DIRNAME = 'dir_name' - VERSION = 'manifest_version' + DIR = "dir" + FILE = "file" + DIRNAME = "dir_name" + VERSION = "manifest_version" log = logging.getLogger(__name__) - def __init__(self, config=None, relative_base_dir='.', version=None): + def __init__(self, config=None, relative_base_dir=".", version=None): self.path_info = {} self.relative_base_dir = relative_base_dir @@ -76,61 +75,50 @@ def __init__(self, config=None, relative_base_dir='.', version=None): self.load_config(config, version=version) def load_config(self, config, version=None): - ''' Load paths into the manifest from an Allen SDK config section. + """Load paths into the manifest from an Allen SDK config section. Parameters ---------- config : Config Manifest section of an Allen SDK config. - ''' + """ found_version = None for path_info in config: - path_type = path_info['type'] + path_type = path_info["type"] path_format = None - if 'format' in path_info: - path_format = path_info['format'] + if "format" in path_info: + path_format = path_info["format"] - if path_type == 'file': + if path_type == "file": try: - parent_key = path_info['parent_key'] + parent_key = path_info["parent_key"] except Exception: parent_key = None - self.add_file(path_info['key'], - path_info['spec'], - parent_key, - path_format) - elif path_type == 'dir': + self.add_file(path_info["key"], path_info["spec"], parent_key, path_format) + elif path_type == "dir": try: - parent_key = path_info['parent_key'] + parent_key = path_info["parent_key"] except Exception: parent_key = None - spec = path_info['spec'] + spec = path_info["spec"] absolute = False - if spec[0] == '/': + if spec[0] == "/": absolute = True - self.add_path(path_info['key'], - path_info['spec'], - path_type, - absolute, - path_format, - parent_key) + self.add_path(path_info["key"], path_info["spec"], path_type, absolute, path_format, parent_key) elif path_type == self.VERSION: - found_version = path_info['value'] + found_version = path_info["value"] else: - Manifest.log.warning("Unknown path type in manifest: %s" % - (path_type)) - + Manifest.log.warning("Unknown path type in manifest: %s" % (path_type)) if found_version != version: raise ManifestVersionError("", version, found_version) self.version = version - def add_path(self, key, path, path_type=DIR, - absolute=True, path_format=None, parent_key=None): - '''Insert a new entry. + def add_path(self, key, path, path_type=DIR, absolute=True, path_format=None, parent_key=None): + """Insert a new entry. Parameters ---------- @@ -146,18 +134,17 @@ def add_path(self, key, path, path_type=DIR, Indicate a known file type for further parsing. parent_key : string Refer to another entry. - ''' + """ if parent_key: path_args = [] try: - parent_path = self.path_info[parent_key]['spec'] + parent_path = self.path_info[parent_key]["spec"] path_args.append(parent_path) except Exception: - Manifest.log.error( - "cannot resolve directory key %s" % (parent_key)) + Manifest.log.error("cannot resolve directory key %s" % (parent_key)) raise - path_args.extend(path.split('/')) + path_args.extend(path.split("/")) path = os.path.join(*path_args) # TODO: relative paths need to be considered better @@ -169,43 +156,35 @@ def add_path(self, key, path, path_type=DIR, if path_type == Manifest.DIRNAME: path = os.path.dirname(path) - self.path_info[key] = {'type': path_type, - 'spec': path} + self.path_info[key] = {"type": path_type, "spec": path} if path_type == Manifest.FILE and path_format is not None: - self.path_info[key]['format'] = path_format + self.path_info[key]["format"] = path_format def add_paths(self, path_info): - ''' add information about paths stored in the manifest. + """add information about paths stored in the manifest. Parameters path_info : dict Information about the new paths - ''' + """ for path_key, path_data in path_info.items(): path_format = None - if 'format' in path_data: - path_format = path_data['format'] + if "format" in path_data: + path_format = path_data["format"] - Manifest.log.info("Adding path. type: %s, format: %s, spec: %s" % - (path_data['type'], - path_data['spec'], - path_format)) - entry = {'type': path_data['type'], - 'spec': path_data['spec'] - } + Manifest.log.info( + "Adding path. type: %s, format: %s, spec: %s" % (path_data["type"], path_data["spec"], path_format) + ) + entry = {"type": path_data["type"], "spec": path_data["spec"]} if path_format is not None: - entry['format'] = path_format + entry["format"] = path_format self.path_info[path_key] = entry - def add_file(self, - file_key, - file_name, - dir_key=None, - path_format=None): - '''Insert a new file entry. + def add_file(self, file_key, file_name, dir_key=None, path_format=None): + """Insert a new file entry. Parameters ---------- @@ -217,33 +196,31 @@ def add_file(self, Reference to the parent directory entry. path_format : string, optional File type for further parsing. - ''' + """ path_args = [] if dir_key: try: - dir_path = self.path_info[dir_key]['spec'] + dir_path = self.path_info[dir_key]["spec"] path_args.append(dir_path) except Exception: - Manifest.log.error( - "cannot resolve directory key %s" % (dir_key)) + Manifest.log.error("cannot resolve directory key %s" % (dir_key)) raise - elif not file_name.startswith('/'): + elif not file_name.startswith("/"): path_args.append(os.curdir) else: path_args.append(os.path.sep) - path_args.extend(file_name.split('/')) + path_args.extend(file_name.split("/")) file_path = os.path.join(*path_args) - self.path_info[file_key] = {'type': Manifest.FILE, - 'spec': file_path} + self.path_info[file_key] = {"type": Manifest.FILE, "spec": file_path} if path_format: - self.path_info[file_key]['format'] = path_format + self.path_info[file_key]["format"] = path_format def get_path(self, path_key, *args): - '''Retrieve an entry with substitutions. + """Retrieve an entry with substitutions. Parameters ---------- @@ -256,8 +233,8 @@ def get_path(self, path_key, *args): ------- string Path with parent structure and substitutions applied. - ''' - path_spec = self.path_info[path_key]['spec'] + """ + path_spec = self.path_info[path_key]["spec"] if args is not None and len(args) != 0: path = path_spec % args @@ -267,7 +244,7 @@ def get_path(self, path_key, *args): return path def get_format(self, path_key): - '''Retrieve the type of a path entry. + """Retrieve the type of a path entry. Parameters ---------- @@ -278,18 +255,18 @@ def get_format(self, path_key): ------- string File type. - ''' + """ path_entry = self.path_info[path_key] path_format = None - if 'format' in path_entry: - path_format = path_entry['format'] + if "format" in path_entry: + path_format = path_entry["format"] return path_format @classmethod def safe_make_parent_dirs(cls, file_name): - ''' Create a parent directories for file. + """Create a parent directories for file. Parameters ---------- @@ -297,10 +274,10 @@ def safe_make_parent_dirs(cls, file_name): Returns ------- - leftmost : string + leftmost : string most rootward directory created - ''' + """ dirname = os.path.dirname(file_name) @@ -312,7 +289,7 @@ def safe_make_parent_dirs(cls, file_name): @classmethod def safe_mkdir(cls, directory): - '''Create path if not already there. + """Create path if not already there. Parameters ---------- @@ -321,10 +298,10 @@ def safe_mkdir(cls, directory): Returns ------- - leftmost : string + leftmost : string most rootward directory created - ''' + """ parts = Path(directory).parts sub_paths = [Path(parts[0])] @@ -339,16 +316,14 @@ def safe_mkdir(cls, directory): try: os.makedirs(directory) except OSError as e: - if ((sys.platform == "darwin") and (e.errno == errno.EISDIR) and \ - (e.filename == "/")): + if (sys.platform == "darwin") and (e.errno == errno.EISDIR) and (e.filename == "/"): # undocumented behavior of mkdir on OSX where for / it raises # EISDIR and not EEXIST # https://bugs.python.org/issue24231 (old but still holds true) pass elif sys.platform == "win32" and e.errno == errno.EACCES: root_path = os.path.abspath(os.sep) - if e.filename == root_path or \ - e.filename == root_path.replace("\\", "/"): + if e.filename == root_path or e.filename == root_path.replace("\\", "/"): # When attempting to os.makedirs the root drive letter on # Windows, EACCES is raised, not EEXIST pass @@ -361,20 +336,19 @@ def safe_mkdir(cls, directory): return leftmost - def create_dir(self, path_key): - '''Make a directory for an entry. + """Make a directory for an entry. Parameters ---------- path_key : string Reference to the entry. - ''' + """ dir_path = self.get_path(path_key) Manifest.safe_mkdir(dir_path) def check_dir(self, path_key, do_exit=False): - '''Verify a directories existence or optionally exit. + """Verify a directories existence or optionally exit. Parameters ---------- @@ -382,17 +356,16 @@ def check_dir(self, path_key, do_exit=False): Reference to the entry. do_exit : boolean What to do if the directory is not present. - ''' + """ dir_path = self.get_path(path_key) if not os.path.exists(dir_path): - Manifest.log.fatal('Directory %s does not exist; exiting.' % - (dir_path)) + Manifest.log.fatal("Directory %s does not exist; exiting." % (dir_path)) if do_exit is True: quit() - def resolve_paths(self, description_dict, suffix='_key'): - '''Walk input items and expand those that refer to a manifest entry. + def resolve_paths(self, description_dict, suffix="_key"): + """Walk input items and expand those that refer to a manifest entry. Parameters ---------- @@ -400,8 +373,8 @@ def resolve_paths(self, description_dict, suffix='_key'): Any entries with key names ending in suffix will be expanded. suffix : string Indicates the entries to be expanded. - ''' - key_pattern = re.compile('(.*)%s$' % (suffix)) + """ + key_pattern = re.compile("(.*)%s$" % (suffix)) for description_key, manifest_key in description_dict.items(): m = key_pattern.match(description_key) @@ -412,5 +385,4 @@ def resolve_paths(self, description_dict, suffix='_key'): del description_dict[description_key] def as_dataframe(self): - return pd.DataFrame.from_dict(self.path_info, - orient='index') + return pd.DataFrame.from_dict(self.path_info, orient="index") diff --git a/allensdk/config/manifest_builder.py b/allensdk/config/manifest_builder.py index aab6a843bd..ac433aa5f0 100644 --- a/allensdk/config/manifest_builder.py +++ b/allensdk/config/manifest_builder.py @@ -38,8 +38,9 @@ from allensdk.config.manifest import Manifest import pandas as pd + class ManifestBuilder(object): - df_columns = ['key', 'parent_key', 'spec', 'type', 'format'] + df_columns = ["key", "parent_key", "spec", "type", "format"] def __init__(self): self._log = logging.getLogger(__name__) @@ -47,22 +48,16 @@ def __init__(self): self.sections = {} def set_version(self, value): - self.path_info.append({'type': Manifest.VERSION, 'value': value}) + self.path_info.append({"type": Manifest.VERSION, "value": value}) - def add_path(self, key, spec, - typename='dir', - parent_key=None, - format=None): - entry = { - 'key': key, - 'type': typename, - 'spec': spec} + def add_path(self, key, spec, typename="dir", parent_key=None, format=None): + entry = {"key": key, "type": typename, "spec": spec} if format is not None: - entry['format'] = format + entry["format"] = format if parent_key is not None: - entry['parent_key'] = parent_key + entry["parent_key"] = parent_key self.path_info.append(entry) @@ -70,18 +65,18 @@ def add_section(self, name, contents): self.sections[name] = contents def write_json_file(self, path, overwrite=False): - mode = 'wb' + mode = "wb" if overwrite is True: - mode = 'wb+' + mode = "wb+" json_string = self.write_json_string() with open(path, mode) as f: try: - f.write(json_string) # Python 2.7 + f.write(json_string) # Python 2.7 except TypeError: - f.write(bytes(json_string, 'utf-8')) # Python 3 + f.write(bytes(json_string, "utf-8")) # Python 3 def get_config(self): wrapper = {"manifest": self.path_info} @@ -98,8 +93,7 @@ def write_json_string(self): return ju.write_string(config) def as_dataframe(self): - return pd.DataFrame(self.path_info, - columns=ManifestBuilder.df_columns) + return pd.DataFrame(self.path_info, columns=ManifestBuilder.df_columns) def from_dataframe(self, df): self.path_info = {} diff --git a/allensdk/config/model/__init__.py b/allensdk/config/model/__init__.py index 92ceaf67c3..8e51ec55db 100644 --- a/allensdk/config/model/__init__.py +++ b/allensdk/config/model/__init__.py @@ -32,4 +32,4 @@ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. -# \ No newline at end of file +# diff --git a/allensdk/config/model/description.py b/allensdk/config/model/description.py index 3de0d2fe3f..658a5a991f 100644 --- a/allensdk/config/model/description.py +++ b/allensdk/config/model/description.py @@ -47,7 +47,7 @@ def __init__(self): self.manifest = Manifest() def update_data(self, data, section=None): - '''Merge configuration data possibly from multiple files. + """Merge configuration data possibly from multiple files. Parameters ---------- @@ -55,9 +55,9 @@ def update_data(self, data, section=None): Configuration structure to add. section : string, optional What configuration section to read it into if the file does not specify. - ''' + """ if section is None: - for (section, entries) in data.items(): + for section, entries in data.items(): if section not in self.data: self.data[section] = entries else: @@ -69,20 +69,20 @@ def update_data(self, data, section=None): self.data[section].append(data) def is_empty(self): - '''Check if anything is in the object. + """Check if anything is in the object. Returns ------- boolean true if self.data is missing or empty - ''' + """ if self.data: return False return True def unpack(self, data, section=None): - '''Read the manifest and other stand-alone configuration structure, + """Read the manifest and other stand-alone configuration structure, or insert a configuration object into a section of an existing configuration. Parameters @@ -92,7 +92,7 @@ def unpack(self, data, section=None): or an configuration object to be placed within a section. section : string, optional. If this is present, place data within an existing section array. - ''' + """ if section is None: self.unpack_manifest(data) self.update_data(data) @@ -100,27 +100,27 @@ def unpack(self, data, section=None): self.update_data(data, section) def unpack_manifest(self, data): - '''Pull the manifest configuration section into a separate place. + """Pull the manifest configuration section into a separate place. Parameters ---------- data : dict A configuration structure that still has a manifest section. - ''' + """ data_manifest = data.pop("manifest", {}) reserved_data = {"manifest": data_manifest} self.reserved_data.append(reserved_data) self.manifest.load_config(data_manifest) def fix_unary_sections(self, section_names=None): - ''' Wrap section contents that don't have the proper + """Wrap section contents that don't have the proper array surrounding them in an array. Parameters ---------- section_names : list of strings, optional Keys of sections that might not be in array form. - ''' + """ if section_names is None: section_names = [] @@ -128,5 +128,4 @@ def fix_unary_sections(self, section_names=None): if section in self.data: if type(self.data[section]) is dict: self.data[section] = [self.data[section]] - Description._log.warn( - "wrapped description section %s in an array." % (section)) + Description._log.warn("wrapped description section %s in an array." % (section)) diff --git a/allensdk/config/model/description_parser.py b/allensdk/config/model/description_parser.py index 2a2ae864dd..909c21a91f 100644 --- a/allensdk/config/model/description_parser.py +++ b/allensdk/config/model/description_parser.py @@ -44,7 +44,7 @@ def __init__(self): pass def read(self, file_path, description=None, section=None, **kwargs): - '''Parse data needed for a simulation. + """Parse data needed for a simulation. Parameters ---------- @@ -52,7 +52,7 @@ def read(self, file_path, description=None, section=None, **kwargs): Configuration from parsing previous files. section : string, optional What configuration section to read it into if the file does not specify. - ''' + """ if description is None: description = Description() @@ -62,7 +62,7 @@ def read(self, file_path, description=None, section=None, **kwargs): return description def read_string(self, data_string, description=None, section=None, header=None): - '''Parse data needed for a simulation from a string.''' + """Parse data needed for a simulation from a string.""" raise Exception("Not implemented, use a sub class") def write(self, filename, description): @@ -78,7 +78,7 @@ def write(self, filename, description): writer.write(filename, description) def parser_for_extension(self, filename): - '''Choose a subclass that can read the format. + """Choose a subclass that can read the format. Parameters ---------- @@ -89,18 +89,18 @@ def parser_for_extension(self, filename): ------- DescriptionParser Appropriate subclass. - ''' + """ # Circular imports from allensdk.config.model.formats.json_description_parser import JsonDescriptionParser from allensdk.config.model.formats.pycfg_description_parser import PycfgDescriptionParser parser = None - if filename.endswith('.json'): + if filename.endswith(".json"): parser = JsonDescriptionParser() - elif filename.endswith('.pycfg'): + elif filename.endswith(".pycfg"): parser = PycfgDescriptionParser() else: - raise Exception('could not determine file format') + raise Exception("could not determine file format") return parser diff --git a/allensdk/config/model/formats/__init__.py b/allensdk/config/model/formats/__init__.py index 92ceaf67c3..8e51ec55db 100644 --- a/allensdk/config/model/formats/__init__.py +++ b/allensdk/config/model/formats/__init__.py @@ -32,4 +32,4 @@ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. -# \ No newline at end of file +# diff --git a/allensdk/config/model/formats/hdf5_util.py b/allensdk/config/model/formats/hdf5_util.py index 2842b04be0..7f9dd690eb 100644 --- a/allensdk/config/model/formats/hdf5_util.py +++ b/allensdk/config/model/formats/hdf5_util.py @@ -40,28 +40,23 @@ class Hdf5Util(object): - def __init__(self): self.log = logging.getLogger(__name__) def read(self, file_path): try: - with h5py.File(file_path, 'r') as csr: - return csr_matrix((csr['data'][...], - csr['indices'][...], - csr['indptr'][...])) + with h5py.File(file_path, "r") as csr: + return csr_matrix((csr["data"][...], csr["indices"][...], csr["indptr"][...])) except Exception: - self.log.error( - "Couldn't read AllenSDK HDF5 CSR configuration: %s" % file_path) + self.log.error("Couldn't read AllenSDK HDF5 CSR configuration: %s" % file_path) raise def write(self, file_path, m): try: - with h5py.File(file_path, 'w') as csr: - csr.create_dataset('data', data=m.data, dtype=np.uint8) - csr.create_dataset('indices', data=m.indices, dtype=np.uint32) - csr.create_dataset('indptr', data=m.indptr, dtype=np.uint32) + with h5py.File(file_path, "w") as csr: + csr.create_dataset("data", data=m.data, dtype=np.uint8) + csr.create_dataset("indices", data=m.indices, dtype=np.uint32) + csr.create_dataset("indptr", data=m.indptr, dtype=np.uint32) except Exception: - self.log.warn( - "Couldn't write AllenSDK HDF5 CSR configuration: %s" % file_path) + self.log.warn("Couldn't write AllenSDK HDF5 CSR configuration: %s" % file_path) raise diff --git a/allensdk/config/model/formats/json_description_parser.py b/allensdk/config/model/formats/json_description_parser.py index cda7448143..f4f3a117e2 100644 --- a/allensdk/config/model/formats/json_description_parser.py +++ b/allensdk/config/model/formats/json_description_parser.py @@ -47,7 +47,7 @@ def __init__(self): super(JsonDescriptionParser, self).__init__() def read(self, file_path, description=None, section=None, **kwargs): - '''Parse a complete or partial configuration. + """Parse a complete or partial configuration. Parameters ---------- @@ -64,7 +64,7 @@ def read(self, file_path, description=None, section=None, **kwargs): The input description with parsed configuration added. Section is only specified for "bare" objects that are to be added to a section array. - ''' + """ if description is None: description = Description() @@ -74,7 +74,7 @@ def read(self, file_path, description=None, section=None, **kwargs): return description def read_string(self, json_string, description=None, section=None, **kwargs): - '''Parse a complete or partial configuration. + """Parse a complete or partial configuration. Parameters ---------- @@ -91,7 +91,7 @@ def read_string(self, json_string, description=None, section=None, **kwargs): The input description with parsed configuration added. Section is only specified for "bare" objects that are to be added to a section array. - ''' + """ if description is None: description = Description() @@ -102,26 +102,25 @@ def read_string(self, json_string, description=None, section=None, **kwargs): return description def write(self, filename, description): - '''Write the description to a JSON file. + """Write the description to a JSON file. Parameters ---------- description : Description Object to write. - ''' + """ try: - with open(filename, 'w') as f: + with open(filename, "w") as f: dump(description.data, f, indent=2) except Exception: - self.log.warn( - "Couldn't write allensdk json description: %s" % filename) + self.log.warn("Couldn't write allensdk json description: %s" % filename) raise return def write_string(self, description): - '''Write the description to a JSON string. + """Write the description to a JSON string. Parameters ---------- @@ -132,10 +131,9 @@ def write_string(self, description): ------- string JSON serialization of the input. - ''' + """ try: - json_string = dumps(description.data, - indent=2) + json_string = dumps(description.data, indent=2) return json_string except Exception: self.log.warn("Couldn't write allensdk json description: %s") diff --git a/allensdk/config/model/formats/pycfg_description_parser.py b/allensdk/config/model/formats/pycfg_description_parser.py index e84c96cd32..5089d95b70 100644 --- a/allensdk/config/model/formats/pycfg_description_parser.py +++ b/allensdk/config/model/formats/pycfg_description_parser.py @@ -46,7 +46,7 @@ def __init__(self): super(PycfgDescriptionParser, self).__init__() def read(self, pycfg_file_path, description=None, section=None, **kwargs): - '''Read a serialized description from a Python (.pycfg) file. + """Read a serialized description from a Python (.pycfg) file. Parameters ---------- @@ -57,14 +57,14 @@ def read(self, pycfg_file_path, description=None, section=None, **kwargs): ------- Description Configuration object. - ''' - header = kwargs.get('prefix', '') + """ + header = kwargs.get("prefix", "") - with open(pycfg_file_path, 'r') as f: + with open(pycfg_file_path, "r") as f: return self.read_string(f.read(), description, section, header=header) def read_string(self, python_string, description=None, section=None, **kwargs): - '''Read a serialized description from a Python (.pycfg) string. + """Read a serialized description from a Python (.pycfg) string. Parameters ---------- @@ -75,51 +75,49 @@ def read_string(self, python_string, description=None, section=None, **kwargs): ------- Description Configuration object. - ''' + """ if description is None: description = Description() - header = kwargs.get('header', '') + header = kwargs.get("header", "") - python_string = "%s\n\nallensdk_description = %s" % ( - header, python_string) + python_string = "%s\n\nallensdk_description = %s" % (header, python_string) ns = {} - code = compile(python_string, 'string', 'exec') + code = compile(python_string, "string", "exec") exec(code, ns) - data = ns['allensdk_description'] + data = ns["allensdk_description"] description.unpack(data, section) return description def write(self, filename, description): - '''Write the description to a Python (.pycfg) file. + """Write the description to a Python (.pycfg) file. Parameters ---------- filename : string Name of the file to write. - ''' + """ try: - with open(filename, 'w') as f: + with open(filename, "w") as f: pprint(description.data, f, indent=2) except Exception: - self.log.warn( - "Couldn't write allensdk python description: %s" % filename) + self.log.warn("Couldn't write allensdk python description: %s" % filename) raise return def write_string(self, description): - '''Write the description to a pretty-printed Python string. + """Write the description to a pretty-printed Python string. Parameters ---------- description : Description Configuration object to write. - ''' + """ pycfg_string = pformat(description.data, indent=2) return pycfg_string diff --git a/allensdk/core/__init__.py b/allensdk/core/__init__.py index 81885dc2d5..41151778fa 100644 --- a/allensdk/core/__init__.py +++ b/allensdk/core/__init__.py @@ -34,11 +34,13 @@ # POSSIBILITY OF SUCH DAMAGE. # -from ._data_object_base.data_object import DataObject # noqa F401 -from ._data_object_base.readable_interfaces import ( # noqa F401 - JsonReadableInterface, LimsReadableInterface, NwbReadableInterface, - DataFileReadableInterface +from ._data_object_base.data_object import DataObject # noqa F401 +from ._data_object_base.readable_interfaces import ( # noqa F401 + JsonReadableInterface, + LimsReadableInterface, + NwbReadableInterface, + DataFileReadableInterface, ) -from ._data_object_base.writable_interfaces import ( # noqa F401 - NwbWritableInterface +from ._data_object_base.writable_interfaces import ( # noqa F401 + NwbWritableInterface, ) diff --git a/allensdk/core/_data_object_base/data_object.py b/allensdk/core/_data_object_base/data_object.py index e13206baa8..76a7ada069 100644 --- a/allensdk/core/_data_object_base/data_object.py +++ b/allensdk/core/_data_object_base/data_object.py @@ -14,9 +14,9 @@ class DataObject(abc.ABC): data sources and sinks (e.g. LIMS, JSON, NWB). """ - def __init__(self, name: str, value: Any, - exclude_from_equals: Optional[Set[str]] = None, - is_value_self: bool = False): + def __init__( + self, name: str, value: Any, exclude_from_equals: Optional[Set[str]] = None, is_value_self: bool = False + ): """ :param name Name @@ -31,10 +31,9 @@ def __init__(self, name: str, value: Any, array or dataframe """ if value is self: - raise ValueError('Passing value of self is not supported') + raise ValueError("Passing value of self is not supported") if is_value_self and value is not None: - raise ValueError('If passing is_value_self=True, then value ' - 'should be None') + raise ValueError("If passing is_value_self=True, then value should be None") self._name = name self._value = value self._is_value_self = is_value_self @@ -107,7 +106,7 @@ def _get_keys_and_values(base_value: DataObject): # skip properties that return self # (leads to infinite recursion) continue - if name == 'name': + if name == "name": # The name is the key continue @@ -119,6 +118,7 @@ def _get_keys_and_values(base_value: DataObject): pass properties.append((name, value, newpath)) return properties + properties = _get_keys_and_values(base_value=value) # Find the nested dict @@ -126,8 +126,7 @@ def _get_keys_and_values(base_value: DataObject): for p in path: cur = cur[p] - if isinstance(value._value, DataObject) or \ - value._is_value_self: + if isinstance(value._value, DataObject) or value._is_value_self: # it's nested cur[value._name] = dict() for p in properties: @@ -150,14 +149,16 @@ def _get_keys_and_values(base_value: DataObject): def _get_properties(self): """Returns all property names and values""" + def is_prop(attr): return isinstance(getattr(type(self), attr, None), property) + props = [attr for attr in dir(self) if is_prop(attr)] return {name: getattr(self, name) for name in props} def __eq__(self, other: "DataObject"): if type(self) != type(other): - msg = f'Do not know how to compare with type {type(other)}' + msg = f"Do not know how to compare with type {type(other)}" raise NotImplementedError(msg) d_self = self.to_dict() @@ -170,8 +171,7 @@ def __eq__(self, other: "DataObject"): x2 = d_other[p] try: - compare_fields(x1=x1, x2=x2, - ignore_keys=self._exclude_from_equals) + compare_fields(x1=x1, x2=x2, ignore_keys=self._exclude_from_equals) except AssertionError: return False return True diff --git a/allensdk/core/_data_object_base/readable_interfaces.py b/allensdk/core/_data_object_base/readable_interfaces.py index d17df9869c..10ca317efb 100644 --- a/allensdk/core/_data_object_base/readable_interfaces.py +++ b/allensdk/core/_data_object_base/readable_interfaces.py @@ -7,6 +7,7 @@ class JsonReadableInterface(abc.ABC): """Marks a data object as readable from json""" + @classmethod @abc.abstractmethod def from_json(cls, dict_repr: dict) -> "DataObject": # pragma: no cover @@ -23,6 +24,7 @@ def from_json(cls, dict_repr: dict) -> "DataObject": # pragma: no cover class LimsReadableInterface(abc.ABC): """Marks a data object as readable from LIMS""" + @classmethod @abc.abstractmethod def from_lims(cls, *args) -> "DataObject": # pragma: no cover @@ -40,13 +42,10 @@ def from_lims(cls, *args) -> "DataObject": # pragma: no cover class NwbReadableInterface(abc.ABC): """Marks a data object as readable from NWB""" + @classmethod @abc.abstractmethod - def from_nwb( - cls, - nwbfile: NWBFile, - **kwargs - ) -> "DataObject": # pragma: no cover + def from_nwb(cls, nwbfile: NWBFile, **kwargs) -> "DataObject": # pragma: no cover """Populate a DataObject from a pyNWB file object. Parameters @@ -65,6 +64,7 @@ def from_nwb( class DataFileReadableInterface(abc.ABC): """Marks a data object as readable from various data files, not covered by existing interfaces""" + @classmethod @abc.abstractmethod def from_data_file(cls, *args) -> "DataObject": diff --git a/allensdk/core/_data_object_base/writable_interfaces.py b/allensdk/core/_data_object_base/writable_interfaces.py index 131e68ebc2..66900dfb8b 100644 --- a/allensdk/core/_data_object_base/writable_interfaces.py +++ b/allensdk/core/_data_object_base/writable_interfaces.py @@ -5,6 +5,7 @@ class NwbWritableInterface(abc.ABC): """Marks a data object as writable to NWB""" + @abc.abstractmethod def to_nwb(self, nwbfile: NWBFile) -> NWBFile: # pragma: no cover """Given an already populated DataObject, return an pyNWB file object diff --git a/allensdk/core/auth_config.py b/allensdk/core/auth_config.py index 845090065d..f940a45aed 100644 --- a/allensdk/core/auth_config.py +++ b/allensdk/core/auth_config.py @@ -9,7 +9,7 @@ ("MTRAIN_USER", None), ("MTRAIN_HOST", None), ("MTRAIN_PORT", 5432), - ("MTRAIN_PASSWORD", None) + ("MTRAIN_PASSWORD", None), ] # For PostgresQueryMixin @@ -18,7 +18,7 @@ "user": "LIMS_USER", "host": "LIMS_HOST", "password": "LIMS_PASSWORD", - "port": "LIMS_PORT" + "port": "LIMS_PORT", } # For PostgresQueryMixin @@ -27,5 +27,5 @@ "user": "MTRAIN_USER", "host": "MTRAIN_HOST", "password": "MTRAIN_PASSWORD", - "port": "MTRAIN_PORT" -} \ No newline at end of file + "port": "MTRAIN_PORT", +} diff --git a/allensdk/core/authentication.py b/allensdk/core/authentication.py index f85e41120f..64edeadbb2 100644 --- a/allensdk/core/authentication.py +++ b/allensdk/core/authentication.py @@ -9,8 +9,7 @@ logger = logging.getLogger(__name__) -DbCredentials = namedtuple("DbCredentials", - ["dbname", "user", "host", "port", "password"]) +DbCredentials = namedtuple("DbCredentials", ["dbname", "user", "host", "port", "password"]) class CredentialProvider(ABC): @@ -26,6 +25,7 @@ class EnvCredentialProvider(CredentialProvider): Provides credentials from environment variables for variables listed in CREDENTIAL_KEYS. """ + METHOD = "env" def __init__(self, environ: Optional[Dict[str, Any]] = None): @@ -39,8 +39,7 @@ def __init__(self, environ: Optional[Dict[str, Any]] = None): """ if environ is None: environ = os.environ - self.credentials = dict((k[0], environ.get(k[0], k[1])) - for k in CREDENTIAL_KEYS) + self.credentials = dict((k[0], environ.get(k[0], k[1])) for k in CREDENTIAL_KEYS) def provide(self, credential): return self.credentials.get(credential) @@ -59,8 +58,7 @@ def get_credential_provider(): return CREDENTIAL_PROVIDER -def credential_injector(credential_map: Dict[str, Any], - provider: Optional[CredentialProvider] = None): +def credential_injector(credential_map: Dict[str, Any], provider: Optional[CredentialProvider] = None): """ Decorator used to inject credentials from another source if not explicitly provided in the function call. This function will only supply @@ -97,17 +95,17 @@ def injector_decorator(func): def wrapper(*args, **kwargs): for kw, credential in credential_map.items(): if kw not in kwargs.keys(): - logger.debug(f"No explicit value provided for {kw}. " - "Searching credential provider.") + logger.debug(f"No explicit value provided for {kw}. Searching credential provider.") secret = provider.provide(credential) if secret is not None: - logger.debug("Found value in credential provider, " - f"from '{provider.METHOD}' method.") + logger.debug(f"Found value in credential provider, from '{provider.METHOD}' method.") kwargs.update({kw: provider.provide(credential)}) else: logger.warning( - f"Value for {kw} was neither explicitly provided " - "nor found in credential provider.") + f"Value for {kw} was neither explicitly provided nor found in credential provider." + ) return func(*args, **kwargs) + return wrapper + return injector_decorator diff --git a/allensdk/core/brain_observatory_cache.py b/allensdk/core/brain_observatory_cache.py index c3d0171682..6e22e29f1b 100644 --- a/allensdk/core/brain_observatory_cache.py +++ b/allensdk/core/brain_observatory_cache.py @@ -59,8 +59,7 @@ # NOTE: This is a really ugly hack to get around the fact that warehouse does # not have Ophys session ids associated with experiment ids. -from .ophys_experiment_session_id_mapping import \ - ophys_experiment_session_id_map +from .ophys_experiment_session_id_mapping import ophys_experiment_session_id_map from ..api.cloud_cache.cloud_cache import S3CloudCache @@ -119,15 +118,11 @@ class BrainObservatoryCache(Cache): EYE_GAZE_DATA_KEY = "EYE_GAZE_DATA" MANIFEST_VERSION = "1.3" - def __init__(self, cache=True, manifest_file=None, base_uri=None, - api=None): - + def __init__(self, cache=True, manifest_file=None, base_uri=None, api=None): if manifest_file is None: manifest_file = get_default_manifest_file("brain_observatory") - super(BrainObservatoryCache, self).__init__( - manifest=manifest_file, cache=cache, version=self.MANIFEST_VERSION - ) + super(BrainObservatoryCache, self).__init__(manifest=manifest_file, cache=cache, version=self.MANIFEST_VERSION) if api is None: self.api = BrainObservatoryApi(base_uri=base_uri) @@ -137,9 +132,7 @@ def __init__(self, cache=True, manifest_file=None, base_uri=None, def get_all_targeted_structures(self): """Return a list of all targeted structures in the data set.""" containers = self.get_experiment_containers(simple=False) - targeted_structures = set( - [c["targeted_structure"]["acronym"] for c in containers] - ) + targeted_structures = set([c["targeted_structure"]["acronym"] for c in containers]) return sorted(list(targeted_structures)) def get_all_cre_lines(self): @@ -233,12 +226,9 @@ def get_experiment_containers( _assert_not_string(reporter_lines, "reporter_lines") _assert_not_string(transgenic_lines, "transgenic_lines") - file_name = self.get_cache_path(file_name, - self.EXPERIMENT_CONTAINERS_KEY) + file_name = self.get_cache_path(file_name, self.EXPERIMENT_CONTAINERS_KEY) - containers = self.api.get_experiment_containers( - path=file_name, strategy="lazy", **Cache.cache_json() - ) + containers = self.api.get_experiment_containers(path=file_name, strategy="lazy", **Cache.cache_json()) containers = self.api.filter_experiment_containers( containers, @@ -356,22 +346,16 @@ def get_ophys_experiments( file_name = self.get_cache_path(file_name, self.EXPERIMENTS_KEY) - exps = self.api.get_ophys_experiments( - path=file_name, strategy="lazy", **Cache.cache_json() - ) + exps = self.api.get_ophys_experiments(path=file_name, strategy="lazy", **Cache.cache_json()) # NOTE: Ugly hack to update the 'fail_eye_tracking' field # which is using True/False values for the previous eye mapping # implementation. This will also need to be fixed in warehouse. # ----- Start of ugly hack ----- - response = self.api.template_query( - "brain_observatory_queries", "all_eye_mapping_files" - ) + response = self.api.template_query("brain_observatory_queries", "all_eye_mapping_files") session_ids_with_eye_tracking: set = { - entry["attachable_id"] - for entry in response - if entry["attachable_type"] == "OphysSession" + entry["attachable_id"] for entry in response if entry["attachable_type"] == "OphysSession" } for indx, exp in enumerate(exps): @@ -387,13 +371,9 @@ def get_ophys_experiments( if cell_specimen_ids is not None: cells = self.get_cell_specimens(ids=cell_specimen_ids) - cell_container_ids = set( - [cell["experiment_container_id"] for cell in cells] - ) + cell_container_ids = set([cell["experiment_container_id"] for cell in cells]) if experiment_container_ids is not None: - experiment_container_ids = list( - set(experiment_container_ids) - cell_container_ids - ) + experiment_container_ids = list(set(experiment_container_ids) - cell_container_ids) else: experiment_container_ids = list(cell_container_ids) @@ -421,9 +401,7 @@ def _get_stimulus_mappings(self, file_name=None): file_name = self.get_cache_path(file_name, self.STIMULUS_MAPPINGS_KEY) - mappings = self.api.get_stimulus_mappings( - path=file_name, strategy="lazy", **Cache.cache_json() - ) + mappings = self.api.get_stimulus_mappings(path=file_name, strategy="lazy", **Cache.cache_json()) return mappings @@ -504,11 +482,7 @@ def get_cell_specimens( # drop the thumbnail columns if simple: mappings = self._get_stimulus_mappings() - thumbnails = [ - m["item"] - for m in mappings - if m["item_type"] == "T" and m["level"] == "R" - ] + thumbnails = [m["item"] for m in mappings if m["item_type"] == "T" and m["level"] == "R"] for cs in cell_specimens: for t in thumbnails: del cs[t] @@ -516,9 +490,7 @@ def get_cell_specimens( return cell_specimens def get_nwb_filepath(self, ophys_experiment_id=None): - cache_nwb_filepath = self.get_cache_path( - None, self.EXPERIMENT_DATA_KEY, ophys_experiment_id - ) + cache_nwb_filepath = self.get_cache_path(None, self.EXPERIMENT_DATA_KEY, ophys_experiment_id) if os.path.exists(cache_nwb_filepath): return cache_nwb_filepath else: @@ -543,19 +515,13 @@ def get_ophys_experiment_data(self, ophys_experiment_id, file_name=None): ------- BrainObservatoryNwbDataSet """ - file_name = self.get_cache_path( - file_name, self.EXPERIMENT_DATA_KEY, ophys_experiment_id - ) + file_name = self.get_cache_path(file_name, self.EXPERIMENT_DATA_KEY, ophys_experiment_id) - self.api.save_ophys_experiment_data( - ophys_experiment_id, file_name, strategy="lazy" - ) + self.api.save_ophys_experiment_data(ophys_experiment_id, file_name, strategy="lazy") return BrainObservatoryNwbDataSet(file_name) - def get_ophys_experiment_analysis( - self, ophys_experiment_id, stimulus_type, file_name=None - ): + def get_ophys_experiment_analysis(self, ophys_experiment_id, stimulus_type, file_name=None): """Download the h5 analysis file for a stimulus set, for a particular ophys_experiment (if it hasn't already been downloaded) and return a data accessor @@ -578,42 +544,27 @@ def get_ophys_experiment_analysis( ------- BrainObservatoryNwbDataSet """ - data_set = self.get_ophys_experiment_data(ophys_experiment_id, - file_name=None) + data_set = self.get_ophys_experiment_data(ophys_experiment_id, file_name=None) session_type = data_set.get_session_type() if stimulus_type not in stim_info.SESSION_STIMULUS_MAP[session_type]: raise RuntimeError( "Stimulus %s not available session type: %s" - % (stimulus_type, - stim_info.SESSION_STIMULUS_MAP[stimulus_type]) + % (stimulus_type, stim_info.SESSION_STIMULUS_MAP[stimulus_type]) ) # Use manifest to figure out where to cache the file: - file_name = self.get_cache_path( - file_name, self.ANALYSIS_DATA_KEY, ophys_experiment_id, - session_type - ) + file_name = self.get_cache_path(file_name, self.ANALYSIS_DATA_KEY, ophys_experiment_id, session_type) # Cache the analsis file from an RMA query: - self.api.save_ophys_experiment_analysis_data( - ophys_experiment_id, file_name, strategy="lazy" - ) + self.api.save_ophys_experiment_analysis_data(ophys_experiment_id, file_name, strategy="lazy") # Get the analysis class from ANALYSIS_CLASS_DICT, and build # from the static method: - if ( - stimulus_type - in stim_info.LOCALLY_SPARSE_NOISE_STIMULUS_TYPES - + stim_info.NATURAL_MOVIE_STIMULUS_TYPES - ): - return ANALYSIS_CLASS_DICT[stimulus_type].from_analysis_file( - data_set, file_name, stimulus_type - ) + if stimulus_type in stim_info.LOCALLY_SPARSE_NOISE_STIMULUS_TYPES + stim_info.NATURAL_MOVIE_STIMULUS_TYPES: + return ANALYSIS_CLASS_DICT[stimulus_type].from_analysis_file(data_set, file_name, stimulus_type) else: - return ANALYSIS_CLASS_DICT[stimulus_type].from_analysis_file( - data_set, file_name - ) + return ANALYSIS_CLASS_DICT[stimulus_type].from_analysis_file(data_set, file_name) def get_ophys_experiment_events(self, ophys_experiment_id, file_name=None): """Download the npz events file for an ophys_experiment if it hasn't @@ -632,13 +583,9 @@ def get_ophys_experiment_events(self, ophys_experiment_id, file_name=None): events: numpy.ndarray [N_cells,N_times] array of events. """ - file_name = self.get_cache_path( - file_name, self.EVENTS_DATA_KEY, ophys_experiment_id - ) + file_name = self.get_cache_path(file_name, self.EVENTS_DATA_KEY, ophys_experiment_id) - self.api.save_ophys_experiment_event_data( - ophys_experiment_id, file_name, strategy="lazy" - ) + self.api.save_ophys_experiment_event_data(ophys_experiment_id, file_name, strategy="lazy") return np.load(file_name, allow_pickle=False)["ev"] @@ -723,18 +670,12 @@ def get_ophys_pupil_data( # experiment ids. # ----- Start of ugly hack ----- try: - ophys_session_id = \ - ophys_experiment_session_id_map[ophys_experiment_id] + ophys_session_id = ophys_experiment_session_id_map[ophys_experiment_id] except KeyError: - raise RuntimeError( - f"Experiment id '{ophys_experiment_id}' has no " - f"associated session!" - ) + raise RuntimeError(f"Experiment id '{ophys_experiment_id}' has no associated session!") # ----- End of ugly hack ----- - file_name = self.get_cache_path( - file_name, self.EYE_GAZE_DATA_KEY, ophys_session_id - ) + file_name = self.get_cache_path(file_name, self.EYE_GAZE_DATA_KEY, ophys_session_id) if not file_name: raise RuntimeError( @@ -745,9 +686,7 @@ def get_ophys_pupil_data( # NOTE: `save_ophys_experiment_eye_gaze_data` will also need to be # updated to remove ophy_session_id param when ugly hack is removed. - self.api.save_ophys_experiment_eye_gaze_data( - ophys_experiment_id, ophys_session_id, file_name, strategy="lazy" - ) + self.api.save_ophys_experiment_eye_gaze_data(ophys_experiment_id, ophys_session_id, file_name, strategy="lazy") gaze_mapping_data = read_eye_gaze_mappings(Path(file_name)) @@ -822,8 +761,4 @@ def build_manifest(self, file_name): def _assert_not_string(arg, name): if isinstance(arg, str): - raise TypeError( - "Argument '%s' with value '%s' is a string type, but " - "should be a list." - % (name, arg) - ) + raise TypeError("Argument '%s' with value '%s' is a string type, but should be a list." % (name, arg)) diff --git a/allensdk/core/brain_observatory_nwb_data_set.py b/allensdk/core/brain_observatory_nwb_data_set.py index 2f2a117420..be992d0aaa 100755 --- a/allensdk/core/brain_observatory_nwb_data_set.py +++ b/allensdk/core/brain_observatory_nwb_data_set.py @@ -48,111 +48,118 @@ from allensdk.brain_observatory.locally_sparse_noise import LocallySparseNoise import allensdk.brain_observatory.stimulus_info as si -from allensdk.brain_observatory.brain_observatory_exceptions import (MissingStimulusException, - NoEyeTrackingException) +from allensdk.brain_observatory.brain_observatory_exceptions import MissingStimulusException, NoEyeTrackingException from allensdk.api.warehouse_cache.cache import memoize -from allensdk.core import h5_utilities +from allensdk.core import h5_utilities from allensdk.brain_observatory.stimulus_info import mask_stimulus_template as si_mask_stimulus_template from allensdk.brain_observatory.brain_observatory_exceptions import EpochSeparationException -_STIMULUS_PRESENTATION_PATH = 'stimulus/presentation' -_STIMULUS_PRESENTATION_PATTERNS = ('{}', '{}_stimulus',) +_STIMULUS_PRESENTATION_PATH = "stimulus/presentation" +_STIMULUS_PRESENTATION_PATTERNS = ( + "{}", + "{}_stimulus", +) def get_epoch_mask_list(st, threshold, max_cuts=2): - '''Convenience function to cut a stim table into multiple epochs + """Convenience function to cut a stim table into multiple epochs :param st: input stimtable :param threshold: threshold on the max duration of a subepoch :param max_cuts: maximum number of allowed epochs to cut into :return: epoch_mask_list, a list of indices that define the start and end of sub-epochs - ''' + """ if threshold is None: - raise NotImplementedError('threshold not set for this type of session') + raise NotImplementedError("threshold not set for this type of session") - delta = (st.start.values[1:] - st.end.values[:-1]) + delta = st.start.values[1:] - st.end.values[:-1] cut_inds = np.where(delta > threshold)[0] + 1 epoch_mask_list = [] if len(cut_inds) > max_cuts: - # See: https://gist.github.com/nicain/bce66cd073e422f07cf337b476c63be7 # https://github.com/AllenInstitute/AllenSDK/issues/66 - raise EpochSeparationException('more than 2 epochs cut', delta=delta) - - for ii in range(len(cut_inds)+1): + raise EpochSeparationException("more than 2 epochs cut", delta=delta) + for ii in range(len(cut_inds) + 1): if ii == 0: first_ind = st.iloc[0].start else: - first_ind = st.iloc[cut_inds[ii-1]].start + first_ind = st.iloc[cut_inds[ii - 1]].start if ii == len(cut_inds): last_ind_inclusive = st.iloc[-1].end else: - last_ind_inclusive = st.iloc[cut_inds[ii]-1].end + last_ind_inclusive = st.iloc[cut_inds[ii] - 1].end - epoch_mask_list.append((first_ind,last_ind_inclusive)) + epoch_mask_list.append((first_ind, last_ind_inclusive)) return epoch_mask_list class BrainObservatoryNwbDataSet(object): - PIPELINE_DATASET = 'brain_observatory_pipeline' + PIPELINE_DATASET = "brain_observatory_pipeline" SUPPORTED_PIPELINE_VERSION = "3.0" FILE_METADATA_MAPPING = { - 'age': 'general/subject/age', - 'sex': 'general/subject/sex', - 'imaging_depth': 'general/optophysiology/imaging_plane_1/imaging depth', - 'targeted_structure': 'general/optophysiology/imaging_plane_1/location', - 'ophys_experiment_id': 'general/session_id', - 'experiment_container_id': 'general/experiment_container_id', - 'device_string': 'general/devices/2-photon microscope', - 'excitation_lambda': 'general/optophysiology/imaging_plane_1/excitation_lambda', - 'indicator': 'general/optophysiology/imaging_plane_1/indicator', - 'fov': 'general/fov', - 'genotype': 'general/subject/genotype', - 'session_start_time': 'session_start_time', - 'session_type': 'general/session_type', - 'specimen_name': 'general/specimen_name', - 'generated_by': 'general/generated_by' + "age": "general/subject/age", + "sex": "general/subject/sex", + "imaging_depth": "general/optophysiology/imaging_plane_1/imaging depth", + "targeted_structure": "general/optophysiology/imaging_plane_1/location", + "ophys_experiment_id": "general/session_id", + "experiment_container_id": "general/experiment_container_id", + "device_string": "general/devices/2-photon microscope", + "excitation_lambda": "general/optophysiology/imaging_plane_1/excitation_lambda", + "indicator": "general/optophysiology/imaging_plane_1/indicator", + "fov": "general/fov", + "genotype": "general/subject/genotype", + "session_start_time": "session_start_time", + "session_type": "general/session_type", + "specimen_name": "general/specimen_name", + "generated_by": "general/generated_by", } STIMULUS_TABLE_TYPES = { - 'abstract_feature_series': [si.DRIFTING_GRATINGS, si.STATIC_GRATINGS], - 'indexed_time_series': [si.NATURAL_SCENES, si.LOCALLY_SPARSE_NOISE, - si.LOCALLY_SPARSE_NOISE_4DEG, si.LOCALLY_SPARSE_NOISE_8DEG], - 'repeated_indexed_time_series':[si.NATURAL_MOVIE_ONE, si.NATURAL_MOVIE_TWO, si.NATURAL_MOVIE_THREE] - + "abstract_feature_series": [si.DRIFTING_GRATINGS, si.STATIC_GRATINGS], + "indexed_time_series": [ + si.NATURAL_SCENES, + si.LOCALLY_SPARSE_NOISE, + si.LOCALLY_SPARSE_NOISE_4DEG, + si.LOCALLY_SPARSE_NOISE_8DEG, + ], + "repeated_indexed_time_series": [si.NATURAL_MOVIE_ONE, si.NATURAL_MOVIE_TWO, si.NATURAL_MOVIE_THREE], } # this array was moved before file versioning was in place - MOTION_CORRECTION_DATASETS = [ "MotionCorrection/2p_image_series/xy_translations", - "MotionCorrection/2p_image_series/xy_translation" ] + MOTION_CORRECTION_DATASETS = [ + "MotionCorrection/2p_image_series/xy_translations", + "MotionCorrection/2p_image_series/xy_translation", + ] def __init__(self, nwb_file): - self.nwb_file = nwb_file self.pipeline_version = None if os.path.exists(self.nwb_file): meta = self.get_metadata() - if meta and 'pipeline_version' in meta: - pipeline_version_str = meta['pipeline_version'] + if meta and "pipeline_version" in meta: + pipeline_version_str = meta["pipeline_version"] self.pipeline_version = Version(pipeline_version_str) if self.pipeline_version > Version(self.SUPPORTED_PIPELINE_VERSION): - logging.warning("File %s has a pipeline version newer than the version supported by this class (%s vs %s)." - " Please update your AllenSDK." % (nwb_file, pipeline_version_str, self.SUPPORTED_PIPELINE_VERSION)) + logging.warning( + "File %s has a pipeline version newer than the version supported by this class (%s vs %s)." + " Please update your AllenSDK." + % (nwb_file, pipeline_version_str, self.SUPPORTED_PIPELINE_VERSION) + ) self._stimulus_search = None def get_stimulus_epoch_table(self): - '''Returns a pandas dataframe that summarizes the stimulus epoch duration for each acquisition time index in + """Returns a pandas dataframe that summarizes the stimulus epoch duration for each acquisition time index in the experiment Parameters @@ -166,68 +173,73 @@ def get_stimulus_epoch_table(self): traces: 2D numpy array Fluorescence traces for each cell - ''' - + """ # These are thresholds used by get_epoch_mask_list to set a maximum limit on the delta aqusistion frames to # count as different trials (rows in the stim table). This helps account for dropped frames, so that they dont # cause the cutting of an entire experiment into too many stimulus epochs. If these thresholds are too low, # the assert statment in get_epoch_mask_list will halt execution. In that case, make a bug report!. - threshold_dict = {si.THREE_SESSION_A:32+7, - si.THREE_SESSION_B:15, - si.THREE_SESSION_C:7, - si.THREE_SESSION_C2:7} + threshold_dict = { + si.THREE_SESSION_A: 32 + 7, + si.THREE_SESSION_B: 15, + si.THREE_SESSION_C: 7, + si.THREE_SESSION_C2: 7, + } stimulus_table_dict = {} for stimulus in self.list_stimuli(): - stimulus_table_dict[stimulus] = self.get_stimulus_table(stimulus) if stimulus == si.SPONTANEOUS_ACTIVITY: - stimulus_table_dict[stimulus]['frame'] = 0 + stimulus_table_dict[stimulus]["frame"] = 0 interval_list = [] interval_stimulus_dict = {} for stimulus in self.list_stimuli(): - stimulus_interval_list = get_epoch_mask_list(stimulus_table_dict[stimulus], threshold=threshold_dict.get(self.get_session_type(), None)) + stimulus_interval_list = get_epoch_mask_list( + stimulus_table_dict[stimulus], threshold=threshold_dict.get(self.get_session_type(), None) + ) for stimulus_interval in stimulus_interval_list: interval_stimulus_dict[stimulus_interval] = stimulus interval_list += stimulus_interval_list interval_list.sort(key=lambda x: x[0]) - stimulus_signature_list = ['gap'] + stimulus_signature_list = ["gap"] duration_signature_list = [int(interval_list[0][0])] - interval_signature_list = [(0,int(interval_list[0][0]))] + interval_signature_list = [(0, int(interval_list[0][0]))] for ii, interval in enumerate(interval_list): stimulus_signature_list.append(interval_stimulus_dict[interval]) duration_signature_list.append(int(interval[1] - interval[0])) interval_signature_list.append((int(interval[0]), int(interval[1]))) - if ii != len(interval_list)-1: - stimulus_signature_list.append('gap') - duration_signature_list.append((int(interval_list[ii+1][0] - interval_list[ii][1]))) - interval_signature_list.append((int(interval_list[ii][1]), int(interval_list[ii+1][0]))) + if ii != len(interval_list) - 1: + stimulus_signature_list.append("gap") + duration_signature_list.append((int(interval_list[ii + 1][0] - interval_list[ii][1]))) + interval_signature_list.append((int(interval_list[ii][1]), int(interval_list[ii + 1][0]))) - stimulus_signature_list.append('gap') + stimulus_signature_list.append("gap") interval_signature_list.append((int(interval_list[-1][1]), len(self.get_fluorescence_timestamps()))) - duration_signature_list.append(interval_signature_list[-1][1]-interval_signature_list[-1][0]) - - interval_df = pd.DataFrame({'stimulus':stimulus_signature_list, - 'duration':duration_signature_list, - 'interval':interval_signature_list}) + duration_signature_list.append(interval_signature_list[-1][1] - interval_signature_list[-1][0]) + + interval_df = pd.DataFrame( + { + "stimulus": stimulus_signature_list, + "duration": duration_signature_list, + "interval": interval_signature_list, + } + ) # Gaps are uninformative; remove them: - interval_df = interval_df[interval_df.stimulus != 'gap'] - interval_df['start'] = [x[0] for x in interval_df['interval'].values] - interval_df['end'] = [x[1] for x in interval_df['interval'].values] + interval_df = interval_df[interval_df.stimulus != "gap"] + interval_df["start"] = [x[0] for x in interval_df["interval"].values] + interval_df["end"] = [x[1] for x in interval_df["interval"].values] interval_df.reset_index(inplace=True, drop=True) - interval_df.drop(['interval', 'duration'], axis=1, inplace=True) + interval_df.drop(["interval", "duration"], axis=1, inplace=True) return interval_df - def get_fluorescence_traces(self, cell_specimen_ids=None): - ''' Returns an array of fluorescence traces for all ROI and + """Returns an array of fluorescence traces for all ROI and the timestamps for each datapoint Parameters @@ -243,11 +255,10 @@ def get_fluorescence_traces(self, cell_specimen_ids=None): traces: 2D numpy array Fluorescence traces for each cell - ''' + """ timestamps = self.get_fluorescence_timestamps() - with h5py.File(self.nwb_file, 'r') as f: - ds = f['processing'][self.PIPELINE_DATASET][ - 'Fluorescence']['imaging_plane_1']['data'] + with h5py.File(self.nwb_file, "r") as f: + ds = f["processing"][self.PIPELINE_DATASET]["Fluorescence"]["imaging_plane_1"]["data"] if cell_specimen_ids is None: cell_traces = ds[()] @@ -258,15 +269,14 @@ def get_fluorescence_traces(self, cell_specimen_ids=None): return timestamps, cell_traces def get_fluorescence_timestamps(self): - ''' Returns an array of timestamps in seconds for the fluorescence traces ''' + """Returns an array of timestamps in seconds for the fluorescence traces""" - with h5py.File(self.nwb_file, 'r') as f: - timestamps = f['processing'][self.PIPELINE_DATASET][ - 'Fluorescence']['imaging_plane_1']['timestamps'][()] + with h5py.File(self.nwb_file, "r") as f: + timestamps = f["processing"][self.PIPELINE_DATASET]["Fluorescence"]["imaging_plane_1"]["timestamps"][()] return timestamps def get_neuropil_traces(self, cell_specimen_ids=None): - ''' Returns an array of neuropil fluorescence traces for all ROIs + """Returns an array of neuropil fluorescence traces for all ROIs and the timestamps for each datapoint Parameters @@ -282,17 +292,15 @@ def get_neuropil_traces(self, cell_specimen_ids=None): traces: 2D numpy array Neuropil fluorescence traces for each cell - ''' + """ timestamps = self.get_fluorescence_timestamps() - with h5py.File(self.nwb_file, 'r') as f: + with h5py.File(self.nwb_file, "r") as f: if self.pipeline_version >= Version("2.0"): - ds = f['processing'][self.PIPELINE_DATASET][ - 'Fluorescence']['imaging_plane_1_neuropil_response']['data'] + ds = f["processing"][self.PIPELINE_DATASET]["Fluorescence"]["imaging_plane_1_neuropil_response"]["data"] else: - ds = f['processing'][self.PIPELINE_DATASET][ - 'Fluorescence']['imaging_plane_1']['neuropil_traces'] + ds = f["processing"][self.PIPELINE_DATASET]["Fluorescence"]["imaging_plane_1"]["neuropil_traces"] if cell_specimen_ids is None: np_traces = ds[()] @@ -302,9 +310,8 @@ def get_neuropil_traces(self, cell_specimen_ids=None): return timestamps, np_traces - def get_neuropil_r(self, cell_specimen_ids=None): - ''' Returns a scalar value of r for neuropil correction of flourescence traces + """Returns a scalar value of r for neuropil correction of flourescence traces Parameters ---------- @@ -316,15 +323,13 @@ def get_neuropil_r(self, cell_specimen_ids=None): ------- r: 1D numpy array, len(r)=len(cell_specimen_ids) Scalar for neuropil subtraction for each cell - ''' + """ - with h5py.File(self.nwb_file, 'r') as f: + with h5py.File(self.nwb_file, "r") as f: if self.pipeline_version >= Version("2.0"): - r_ds = f['processing'][self.PIPELINE_DATASET][ - 'Fluorescence']['imaging_plane_1_neuropil_response']['r'] + r_ds = f["processing"][self.PIPELINE_DATASET]["Fluorescence"]["imaging_plane_1_neuropil_response"]["r"] else: - r_ds = f['processing'][self.PIPELINE_DATASET][ - 'Fluorescence']['imaging_plane_1']['r'] + r_ds = f["processing"][self.PIPELINE_DATASET]["Fluorescence"]["imaging_plane_1"]["r"] if cell_specimen_ids is None: r = r_ds[()] @@ -335,7 +340,7 @@ def get_neuropil_r(self, cell_specimen_ids=None): return r def get_demixed_traces(self, cell_specimen_ids=None): - ''' Returns an array of demixed fluorescence traces for all ROIs + """Returns an array of demixed fluorescence traces for all ROIs and the timestamps for each datapoint Parameters @@ -351,13 +356,12 @@ def get_demixed_traces(self, cell_specimen_ids=None): traces: 2D numpy array Demixed fluorescence traces for each cell - ''' + """ timestamps = self.get_fluorescence_timestamps() - with h5py.File(self.nwb_file, 'r') as f: - ds = f['processing'][self.PIPELINE_DATASET][ - 'Fluorescence']['imaging_plane_1_demixed_signal']['data'] + with h5py.File(self.nwb_file, "r") as f: + ds = f["processing"][self.PIPELINE_DATASET]["Fluorescence"]["imaging_plane_1_demixed_signal"]["data"] if cell_specimen_ids is None: traces = ds[()] else: @@ -367,7 +371,7 @@ def get_demixed_traces(self, cell_specimen_ids=None): return timestamps, traces def get_corrected_fluorescence_traces(self, cell_specimen_ids=None): - ''' Returns an array of demixed and neuropil-corrected fluorescence traces + """Returns an array of demixed and neuropil-corrected fluorescence traces for all ROIs and the timestamps for each datapoint Parameters @@ -383,7 +387,7 @@ def get_corrected_fluorescence_traces(self, cell_specimen_ids=None): traces: 2D numpy array Corrected fluorescence traces for each cell - ''' + """ # starting in version 2.0, neuropil correction follows trace demixing if self.pipeline_version >= Version("2.0"): @@ -400,26 +404,25 @@ def get_corrected_fluorescence_traces(self, cell_specimen_ids=None): return timestamps, fc def get_cell_specimen_indices(self, cell_specimen_ids): - ''' Given a list of cell specimen ids, return their index based on their order in this file. + """Given a list of cell specimen ids, return their index based on their order in this file. Parameters ---------- cell_specimen_ids: list of cell specimen ids - ''' + """ all_cell_specimen_ids = list(self.get_cell_specimen_ids()) - + try: - inds = [list(all_cell_specimen_ids).index(i) - for i in cell_specimen_ids] + inds = [list(all_cell_specimen_ids).index(i) for i in cell_specimen_ids] except ValueError as e: raise ValueError("Cell specimen not found (%s)" % str(e)) return inds def get_dff_traces(self, cell_specimen_ids=None): - ''' Returns an array of dF/F traces for all ROIs and + """Returns an array of dF/F traces for all ROIs and the timestamps for each datapoint Parameters @@ -435,87 +438,84 @@ def get_dff_traces(self, cell_specimen_ids=None): dF/F: 2D numpy array dF/F values for each cell - ''' - with h5py.File(self.nwb_file, 'r') as f: - dff_ds = f['processing'][self.PIPELINE_DATASET][ - 'DfOverF']['imaging_plane_1'] + """ + with h5py.File(self.nwb_file, "r") as f: + dff_ds = f["processing"][self.PIPELINE_DATASET]["DfOverF"]["imaging_plane_1"] - timestamps = dff_ds['timestamps'][()] + timestamps = dff_ds["timestamps"][()] if cell_specimen_ids is None: - cell_traces = dff_ds['data'][()] + cell_traces = dff_ds["data"][()] else: inds = self.get_cell_specimen_indices(cell_specimen_ids) - cell_traces = dff_ds['data'][inds, :] + cell_traces = dff_ds["data"][inds, :] return timestamps, cell_traces def get_roi_ids(self): - ''' Returns an array of IDs for all ROIs in the file + """Returns an array of IDs for all ROIs in the file Returns ------- ROI IDs: list - ''' - with h5py.File(self.nwb_file, 'r') as f: - roi_id = f['processing'][self.PIPELINE_DATASET][ - 'ImageSegmentation']['roi_ids'][()] + """ + with h5py.File(self.nwb_file, "r") as f: + roi_id = f["processing"][self.PIPELINE_DATASET]["ImageSegmentation"]["roi_ids"][()] return roi_id def get_cell_specimen_ids(self): - ''' Returns an array of cell IDs for all cells in the file + """Returns an array of cell IDs for all cells in the file Returns ------- cell specimen IDs: list - ''' - with h5py.File(self.nwb_file, 'r') as f: - cell_id = f['processing'][self.PIPELINE_DATASET][ - 'ImageSegmentation']['cell_specimen_ids'][()] + """ + with h5py.File(self.nwb_file, "r") as f: + cell_id = f["processing"][self.PIPELINE_DATASET]["ImageSegmentation"]["cell_specimen_ids"][()] return cell_id def get_session_type(self): - ''' Returns the type of experimental session, presently one of the + """Returns the type of experimental session, presently one of the following: three_session_A, three_session_B, three_session_C Returns ------- session type: string - ''' - with h5py.File(self.nwb_file, 'r') as f: - session_type = f['general/session_type'][()] - return session_type.decode('utf-8') + """ + with h5py.File(self.nwb_file, "r") as f: + session_type = f["general/session_type"][()] + return session_type.decode("utf-8") def get_max_projection(self): - '''Returns the maximum projection image for the 2P movie. + """Returns the maximum projection image for the 2P movie. Returns ------- max projection: np.ndarray - ''' + """ - with h5py.File(self.nwb_file, 'r') as f: - max_projection = f['processing'][self.PIPELINE_DATASET]['ImageSegmentation'][ - 'imaging_plane_1']['reference_images']['maximum_intensity_projection_image']['data'][()] + with h5py.File(self.nwb_file, "r") as f: + max_projection = f["processing"][self.PIPELINE_DATASET]["ImageSegmentation"]["imaging_plane_1"][ + "reference_images" + ]["maximum_intensity_projection_image"]["data"][()] return max_projection def list_stimuli(self): - ''' Return a list of the stimuli presented in the experiment. + """Return a list of the stimuli presented in the experiment. Returns ------- stimuli: list of strings - ''' + """ - with h5py.File(self.nwb_file, 'r') as f: + with h5py.File(self.nwb_file, "r") as f: keys = list(f["stimulus/presentation/"].keys()) - return [ k.replace('_stimulus', '') for k in keys ] - + return [k.replace("_stimulus", "") for k in keys] def _get_master_stimulus_table(self): - ''' Builds a table for all stimuli by concatenating (vertically) the + """Builds a table for all stimuli by concatenating (vertically) the sub-tables describing presentation of each stimulus - ''' + """ epoch_table = self.get_stimulus_epoch_table() @@ -527,12 +527,12 @@ def _get_master_stimulus_table(self): for stimulus in self.list_stimuli(): curr_stimtable = stimulus_table_dict[stimulus] - for _, row in epoch_table[epoch_table['stimulus'] == stimulus].iterrows(): - - epoch_start_ind, epoch_end_ind = row['start'], row['end'] - curr_subtable = curr_stimtable[(epoch_start_ind <= curr_stimtable['start']) & - (curr_stimtable['end'] <= epoch_end_ind)].copy() - curr_subtable['stimulus'] = stimulus + for _, row in epoch_table[epoch_table["stimulus"] == stimulus].iterrows(): + epoch_start_ind, epoch_end_ind = row["start"], row["end"] + curr_subtable = curr_stimtable[ + (epoch_start_ind <= curr_stimtable["start"]) & (curr_stimtable["end"] <= epoch_end_ind) + ].copy() + curr_subtable["stimulus"] = stimulus table_list.append(curr_subtable) new_table = pd.concat(table_list, sort=True) @@ -540,48 +540,47 @@ def _get_master_stimulus_table(self): return new_table - def get_stimulus_table(self, stimulus_name): - ''' Return a stimulus table given a stimulus name - + """Return a stimulus table given a stimulus name + Notes ----- For more information, see: - http://help.brain-map.org/display/observatory/Documentation?preview=/10616846/10813485/VisualCoding_VisualStimuli.pdf + http://help.brain-map.org/display/observatory/Documentation?preview=/10616846/10813485/VisualCoding_VisualStimuli.pdf - ''' + """ - if stimulus_name == 'master': + if stimulus_name == "master": return self._get_master_stimulus_table() - with h5py.File(self.nwb_file, 'r') as nwb_file: - + with h5py.File(self.nwb_file, "r") as nwb_file: stimulus_group = _find_stimulus_presentation_group(nwb_file, stimulus_name) - if stimulus_name in self.STIMULUS_TABLE_TYPES['abstract_feature_series']: + if stimulus_name in self.STIMULUS_TABLE_TYPES["abstract_feature_series"]: datasets = h5_utilities.load_datasets_by_relnames( - ['data', 'features', 'frame_duration'], nwb_file, stimulus_group) + ["data", "features", "frame_duration"], nwb_file, stimulus_group + ) return _make_abstract_feature_series_stimulus_table( - datasets['data'], h5_utilities.decode_bytes(datasets['features']), datasets['frame_duration']) + datasets["data"], h5_utilities.decode_bytes(datasets["features"]), datasets["frame_duration"] + ) - if stimulus_name in self.STIMULUS_TABLE_TYPES['indexed_time_series']: - datasets = h5_utilities.load_datasets_by_relnames(['data', 'frame_duration'], nwb_file, stimulus_group) - return _make_indexed_time_series_stimulus_table(datasets['data'], datasets['frame_duration']) + if stimulus_name in self.STIMULUS_TABLE_TYPES["indexed_time_series"]: + datasets = h5_utilities.load_datasets_by_relnames(["data", "frame_duration"], nwb_file, stimulus_group) + return _make_indexed_time_series_stimulus_table(datasets["data"], datasets["frame_duration"]) - if stimulus_name in self.STIMULUS_TABLE_TYPES['repeated_indexed_time_series']: - datasets = h5_utilities.load_datasets_by_relnames(['data', 'frame_duration'], nwb_file, stimulus_group) - return _make_repeated_indexed_time_series_stimulus_table(datasets['data'], datasets['frame_duration']) + if stimulus_name in self.STIMULUS_TABLE_TYPES["repeated_indexed_time_series"]: + datasets = h5_utilities.load_datasets_by_relnames(["data", "frame_duration"], nwb_file, stimulus_group) + return _make_repeated_indexed_time_series_stimulus_table(datasets["data"], datasets["frame_duration"]) - if stimulus_name == 'spontaneous': - datasets = h5_utilities.load_datasets_by_relnames(['data', 'frame_duration'], nwb_file, stimulus_group) - return _make_spontaneous_activity_stimulus_table(datasets['data'], datasets['frame_duration']) + if stimulus_name == "spontaneous": + datasets = h5_utilities.load_datasets_by_relnames(["data", "frame_duration"], nwb_file, stimulus_group) + return _make_spontaneous_activity_stimulus_table(datasets["data"], datasets["frame_duration"]) raise IOError("Could not find a stimulus table named '%s'" % stimulus_name) - @memoize def get_stimulus_template(self, stimulus_name): - ''' Return an array of the stimulus template for the specified stimulus. + """Return an array of the stimulus template for the specified stimulus. Parameters ---------- @@ -591,16 +590,14 @@ def get_stimulus_template(self, stimulus_name): Returns ------- stimulus table: pd.DataFrame - ''' + """ stim_name = stimulus_name + "_image_stack" - with h5py.File(self.nwb_file, 'r') as f: - image_stack = f['stimulus']['templates'][stim_name]['data'][()] + with h5py.File(self.nwb_file, "r") as f: + image_stack = f["stimulus"]["templates"][stim_name]["data"][()] return image_stack - def get_locally_sparse_noise_stimulus_template(self, - stimulus, - mask_off_screen=True): - ''' Return an array of the stimulus template for the specified stimulus. + def get_locally_sparse_noise_stimulus_template(self, stimulus, mask_off_screen=True): + """Return an array of the stimulus template for the specified stimulus. Parameters ---------- @@ -616,7 +613,7 @@ def get_locally_sparse_noise_stimulus_template(self, Returns ------- tuple: (template, off-screen mask) - ''' + """ if stimulus not in si.LOCALLY_SPARSE_NOISE_DIMENSIONS: raise KeyError("%s is not a known locally sparse noise stimulus" % stimulus) @@ -625,30 +622,28 @@ def get_locally_sparse_noise_stimulus_template(self, # build mapping from template coordinates to display coordinates template_shape = si.LOCALLY_SPARSE_NOISE_DIMENSIONS[stimulus] - template_shape = [ template_shape[1], template_shape[0] ] + template_shape = [template_shape[1], template_shape[0]] template_display_shape = (1260, 720) display_shape = (1920, 1200) scale = [ float(template_shape[0]) / float(template_display_shape[0]), - float(template_shape[1]) / float(template_display_shape[1]) + float(template_shape[1]) / float(template_display_shape[1]), ] offset = [ -(display_shape[0] - template_display_shape[0]) * 0.5, - -(display_shape[1] - template_display_shape[1]) * 0.5 + -(display_shape[1] - template_display_shape[1]) * 0.5, ] - x, y = np.meshgrid(np.arange(display_shape[0]), np.arange( - display_shape[1]), indexing='ij') - template_display_coords = np.array([(x + offset[0]) * scale[0] - 0.5, - (y + offset[1]) * scale[1] - 0.5], - dtype=float) + x, y = np.meshgrid(np.arange(display_shape[0]), np.arange(display_shape[1]), indexing="ij") + template_display_coords = np.array( + [(x + offset[0]) * scale[0] - 0.5, (y + offset[1]) * scale[1] - 0.5], dtype=float + ) template_display_coords = np.rint(template_display_coords).astype(int) # build mask - template_mask, template_frac = si_mask_stimulus_template( - template_display_coords, template_shape) + template_mask, template_frac = si_mask_stimulus_template(template_display_coords, template_shape) if mask_off_screen: template[:, ~template_mask.T] = LocallySparseNoise.LSN_OFF_SCREEN @@ -656,7 +651,7 @@ def get_locally_sparse_noise_stimulus_template(self, return template, template_mask.T def get_roi_mask_array(self, cell_specimen_ids=None): - ''' Return a numpy array containing all of the ROI masks for requested cells. + """Return a numpy array containing all of the ROI masks for requested cells. If cell_specimen_ids is omitted, return all masks. Parameters @@ -667,7 +662,7 @@ def get_roi_mask_array(self, cell_specimen_ids=None): Returns ------- np.ndarray: NxWxH array, where N is number of cells - ''' + """ roi_masks = self.get_roi_mask(cell_specimen_ids) @@ -679,7 +674,7 @@ def get_roi_mask_array(self, cell_specimen_ids=None): return roi_arr def get_roi_mask(self, cell_specimen_ids=None): - ''' Returns an array of all the ROI masks + """Returns an array of all the ROI masks Parameters ---------- @@ -690,13 +685,11 @@ def get_roi_mask(self, cell_specimen_ids=None): Returns ------- List of ROI_Mask objects - ''' + """ - with h5py.File(self.nwb_file, 'r') as f: - mask_loc = f['processing'][self.PIPELINE_DATASET][ - 'ImageSegmentation']['imaging_plane_1'] - roi_list = f['processing'][self.PIPELINE_DATASET][ - 'ImageSegmentation']['imaging_plane_1']['roi_list'][()] + with h5py.File(self.nwb_file, "r") as f: + mask_loc = f["processing"][self.PIPELINE_DATASET]["ImageSegmentation"]["imaging_plane_1"] + roi_list = f["processing"][self.PIPELINE_DATASET]["ImageSegmentation"]["imaging_plane_1"]["roi_list"][()] inds = None if cell_specimen_ids is None: @@ -708,33 +701,31 @@ def get_roi_mask(self, cell_specimen_ids=None): for i in inds: v = roi_list[i] roi_mask = mask_loc[v]["img_mask"][()] - m = roi.create_roi_mask(roi_mask.shape[1], roi_mask.shape[0], - [0, 0, 0, 0], roi_mask=roi_mask, label=v) + m = roi.create_roi_mask(roi_mask.shape[1], roi_mask.shape[0], [0, 0, 0, 0], roi_mask=roi_mask, label=v) roi_array.append(m) return roi_array @property def number_of_cells(self): - '''Number of cells in the experiment''' + """Number of cells in the experiment""" # Replace here is there is a better way to get this info: return len(self.get_cell_specimen_ids()) - def get_metadata(self): - ''' Returns a dictionary of meta data associated with each + """Returns a dictionary of meta data associated with each experiment, including Cre line, specimen number, visual area imaged, imaging depth Returns ------- metadata: dictionary - ''' + """ meta = {} - with h5py.File(self.nwb_file, 'r') as f: + with h5py.File(self.nwb_file, "r") as f: for memory_key, disk_key in BrainObservatoryNwbDataSet.FILE_METADATA_MAPPING.items(): try: v = f[disk_key][()] @@ -742,9 +733,9 @@ def get_metadata(self): # convert numpy strings to python strings if v.dtype.type is np.bytes_: if len(v.shape) == 0: - v = v.decode('UTF-8') + v = v.decode("UTF-8") elif len(v.shape) == 1: - v = [ s.decode('UTF-8') for s in v ] + v = [s.decode("UTF-8") for s in v] else: raise Exception("Unrecognized metadata formatting for field %s" % disk_key) @@ -753,59 +744,56 @@ def get_metadata(self): logging.warning("could not find key %s", disk_key) # extract cre line from genotype string - genotype = meta.get('genotype') - meta['cre_line'] = meta['genotype'].split(';')[0] if genotype else None + genotype = meta.get("genotype") + meta["cre_line"] = meta["genotype"].split(";")[0] if genotype else None - imaging_depth = meta.pop('imaging_depth', None) - meta['imaging_depth_um'] = int(imaging_depth.split()[0]) if imaging_depth else None + imaging_depth = meta.pop("imaging_depth", None) + meta["imaging_depth_um"] = int(imaging_depth.split()[0]) if imaging_depth else None - ophys_experiment_id = meta.get('ophys_experiment_id') - meta['ophys_experiment_id'] = int(ophys_experiment_id) if ophys_experiment_id else None + ophys_experiment_id = meta.get("ophys_experiment_id") + meta["ophys_experiment_id"] = int(ophys_experiment_id) if ophys_experiment_id else None - experiment_container_id = meta.get('experiment_container_id') - meta['experiment_container_id'] = int(experiment_container_id) if experiment_container_id else None + experiment_container_id = meta.get("experiment_container_id") + meta["experiment_container_id"] = int(experiment_container_id) if experiment_container_id else None # convert start time to a date object - session_start_time = meta.get('session_start_time') - if isinstance( session_start_time, str ): - meta['session_start_time'] = dateutil.parser.parse(session_start_time) + session_start_time = meta.get("session_start_time") + if isinstance(session_start_time, str): + meta["session_start_time"] = dateutil.parser.parse(session_start_time) - age = meta.pop('age', None) + age = meta.pop("age", None) if age: # parse the age in days m = re.match("(.*?) days", age) if m: - meta['age_days'] = int(m.groups()[0]) + meta["age_days"] = int(m.groups()[0]) else: raise IOError("Could not parse age.") - # parse the device string (ugly, sorry) - device_string = meta.pop('device_string', None) + device_string = meta.pop("device_string", None) if device_string: m = re.match("(.*?)\.\s(.*?)\sPlease*", device_string) if m: device, device_name = m.groups() - meta['device'] = device - meta['device_name'] = device_name + meta["device"] = device + meta["device_name"] = device_name else: raise IOError("Could not parse device string.") # file version - generated_by = meta.pop('generated_by', None) + generated_by = meta.pop("generated_by", None) version = generated_by[-1] if generated_by else "0.9" meta["pipeline_version"] = version return meta def get_running_speed(self): - ''' Returns the mouse running speed in cm/s - ''' - with h5py.File(self.nwb_file, 'r') as f: - dx_ds = f['processing'][self.PIPELINE_DATASET][ - 'BehavioralTimeSeries']['running_speed'] - dxcm = dx_ds['data'][()] - dxtime = dx_ds['timestamps'][()] + """Returns the mouse running speed in cm/s""" + with h5py.File(self.nwb_file, "r") as f: + dx_ds = f["processing"][self.PIPELINE_DATASET]["BehavioralTimeSeries"]["running_speed"] + dxcm = dx_ds["data"][()] + dxtime = dx_ds["timestamps"][()] timestamps = self.get_fluorescence_timestamps() @@ -818,7 +806,7 @@ def get_running_speed(self): return dxcm, dxtime def get_pupil_location(self, as_spherical=True): - '''Returns the x, y pupil location. + """Returns the x, y pupil location. Parameters ---------- @@ -833,49 +821,46 @@ def get_pupil_location(self, as_spherical=True): (timestamps, location) Timestamps is an (Nx1) array of timestamps in seconds. Location is an (Nx2) array of spatial location. - ''' + """ if as_spherical: location_key = "pupil_location_spherical" else: location_key = "pupil_location" try: - with h5py.File(self.nwb_file, 'r') as f: - eye_tracking = f['processing'][self.PIPELINE_DATASET][ - 'EyeTracking'][location_key] - pupil_location = eye_tracking['data'][()] - pupil_times = eye_tracking['timestamps'][()] + with h5py.File(self.nwb_file, "r") as f: + eye_tracking = f["processing"][self.PIPELINE_DATASET]["EyeTracking"][location_key] + pupil_location = eye_tracking["data"][()] + pupil_times = eye_tracking["timestamps"][()] except KeyError: raise NoEyeTrackingException("No eye tracking for this experiment.") return pupil_times, pupil_location def get_pupil_size(self): - '''Returns the pupil area in pixels. + """Returns the pupil area in pixels. Returns ------- (timestamps, areas) Timestamps is an (Nx1) array of timestamps in seconds. Areas is an (Nx1) array of pupil areas in pixels. - ''' + """ try: - with h5py.File(self.nwb_file, 'r') as f: - pupil_tracking = f['processing'][self.PIPELINE_DATASET][ - 'PupilTracking']['pupil_size'] - pupil_size = pupil_tracking['data'][()] - pupil_times = pupil_tracking['timestamps'][()] + with h5py.File(self.nwb_file, "r") as f: + pupil_tracking = f["processing"][self.PIPELINE_DATASET]["PupilTracking"]["pupil_size"] + pupil_size = pupil_tracking["data"][()] + pupil_times = pupil_tracking["timestamps"][()] except KeyError: raise NoEyeTrackingException("No pupil tracking for this experiment.") return pupil_times, pupil_size def get_motion_correction(self): - ''' Returns a Panda DataFrame containing the x- and y- translation of each image used for image alignment - ''' + """Returns a Panda DataFrame containing the x- and y- translation of each image used for image alignment""" motion_correction = None - with h5py.File(self.nwb_file, 'r') as f: - pipeline_ds = f['processing'][self.PIPELINE_DATASET] + with h5py.File(self.nwb_file, "r") as f: + pipeline_ds = f["processing"][self.PIPELINE_DATASET] # pipeline 0.9 stores this in xy_translations # pipeline 1.0 stores this in xy_translation @@ -883,12 +868,12 @@ def get_motion_correction(self): try: mc_ds = pipeline_ds[mc_ds_name] - motion_log = mc_ds['data'][()] - motion_time = mc_ds['timestamps'][()] - motion_names = mc_ds['feature_description'][()] + motion_log = mc_ds["data"][()] + motion_time = mc_ds["timestamps"][()] + motion_names = mc_ds["feature_description"][()] motion_correction = pd.DataFrame(motion_log, columns=motion_names) - motion_correction['timestamp'] = motion_time + motion_correction["timestamp"] = motion_time # break out if we found it break @@ -908,109 +893,109 @@ def get_motion_correction(self): return motion_correction def save_analysis_dataframes(self, *tables): - store = pd.HDFStore(self.nwb_file, mode='a') + store = pd.HDFStore(self.nwb_file, mode="a") for k, v in tables: - store.put('analysis/%s' % (k), v) + store.put("analysis/%s" % (k), v) store.close() def save_analysis_arrays(self, *datasets): - with h5py.File(self.nwb_file, 'a') as f: + with h5py.File(self.nwb_file, "a") as f: for k, v in datasets: - if k in f['analysis']: - del f['analysis'][k] - f.create_dataset('analysis/%s' % k, data=v) + if k in f["analysis"]: + del f["analysis"][k] + f.create_dataset("analysis/%s" % k, data=v) @property def stimulus_search(self): - if self._stimulus_search is None: self._stimulus_search = si.StimulusSearch(self) return self._stimulus_search def get_stimulus(self, frame_ind): - search_result = self.stimulus_search.search(frame_ind) - if search_result is None or search_result[2]['stimulus'] == si.SPONTANEOUS_ACTIVITY: + if search_result is None or search_result[2]["stimulus"] == si.SPONTANEOUS_ACTIVITY: return None, None else: - - curr_stimulus = search_result[2]['stimulus'] - if curr_stimulus in si.LOCALLY_SPARSE_NOISE_STIMULUS_TYPES + si.NATURAL_MOVIE_STIMULUS_TYPES + [si.NATURAL_SCENES]: - curr_frame = search_result[2]['frame'] + curr_stimulus = search_result[2]["stimulus"] + if curr_stimulus in si.LOCALLY_SPARSE_NOISE_STIMULUS_TYPES + si.NATURAL_MOVIE_STIMULUS_TYPES + [ + si.NATURAL_SCENES + ]: + curr_frame = search_result[2]["frame"] return search_result, self.get_stimulus_template(curr_stimulus)[int(curr_frame), :, :] elif curr_stimulus == si.STATIC_GRATINGS or curr_stimulus == si.DRIFTING_GRATINGS: return search_result, None -def _find_stimulus_presentation_group(nwb_file, - stimulus_name, - base_path=_STIMULUS_PRESENTATION_PATH, - group_patterns=_STIMULUS_PRESENTATION_PATTERNS): - ''' Searches an NWB file for a stimulus presentation group. +def _find_stimulus_presentation_group( + nwb_file, stimulus_name, base_path=_STIMULUS_PRESENTATION_PATH, group_patterns=_STIMULUS_PRESENTATION_PATTERNS +): + """Searches an NWB file for a stimulus presentation group. Parameters ---------- nwb_file : h5py.File File to search stimulus_name : str - Identifier for this stimulus. Corresponds to the relative name of its h5 + Identifier for this stimulus. Corresponds to the relative name of its h5 group. base_path : str, optional Begin the search from here. Defaults to 'stimulus/presentation' group_patterns : array-like of str, optional - Patterns for the relative name of the stimulus' h5 group. Defaults to + Patterns for the relative name of the stimulus' h5 group. Defaults to the name, and the name suffixed by '_stimulus' Returns ------- - h5py.Group, h5py.Dataset : + h5py.Group, h5py.Dataset : h5 object found - ''' + """ - group_candidates = [ pattern.format(stimulus_name) for pattern in group_patterns ] + group_candidates = [pattern.format(stimulus_name) for pattern in group_patterns] matcher = functools.partial(h5_utilities.h5_object_matcher_relname_in, group_candidates) matches = h5_utilities.locate_h5_objects(matcher, nwb_file, base_path) if len(matches) == 0: raise MissingStimulusException( - 'Unable to locate stimulus: {}. ' - 'Looked for this stimulus under the names: {} '.format(stimulus_name, group_candidates) + "Unable to locate stimulus: {}. Looked for this stimulus under the names: {} ".format( + stimulus_name, group_candidates ) + ) if len(matches) > 1: raise MissingStimulusException( - 'Unable to locate stimulus: {}. ' - 'Found multiple matching stimuli: {}'.format(stimulus_name, [match.name for match in matches]) + "Unable to locate stimulus: {}. Found multiple matching stimuli: {}".format( + stimulus_name, [match.name for match in matches] + ) ) return matches[0] def align_running_speed(dxcm, dxtime, timestamps): - ''' If running speed timestamps differ from fluorescence + """If running speed timestamps differ from fluorescence timestamps, adjust by inserting NaNs to running speed. Returns ------- tuple: dxcm, dxtime - ''' + """ if dxtime[0] != timestamps[0]: adjust = np.where(timestamps == dxtime[0])[0][0] dxtime = np.insert(dxtime, 0, timestamps[:adjust]) dxcm = np.insert(dxcm, 0, np.repeat(np.nan, adjust)) adjust = len(timestamps) - len(dxtime) if adjust > 0: - dxtime = np.append(dxtime, timestamps[(-1 * adjust):]) + dxtime = np.append(dxtime, timestamps[(-1 * adjust) :]) dxcm = np.append(dxcm, np.repeat(np.nan, adjust)) return dxcm, dxtime def _make_abstract_feature_series_stimulus_table(stim_data, features, frame_dur): - ''' Return the a stimulus table for an abstract feature series. + """Return the a stimulus table for an abstract feature series. Parameters ---------- @@ -1029,24 +1014,24 @@ def _make_abstract_feature_series_stimulus_table(stim_data, features, frame_dur) Notes ----- For more information, see: - http://help.brain-map.org/display/observatory/Documentation?preview=/10616846/10813485/VisualCoding_VisualStimuli.pdf + http://help.brain-map.org/display/observatory/Documentation?preview=/10616846/10813485/VisualCoding_VisualStimuli.pdf - ''' + """ stimulus_table = pd.DataFrame(stim_data, columns=features) - stimulus_table.loc[:, 'start'] = frame_dur[:, 0].astype(int) - stimulus_table.loc[:, 'end'] = frame_dur[:, 1].astype(int) + stimulus_table.loc[:, "start"] = frame_dur[:, 0].astype(int) + stimulus_table.loc[:, "end"] = frame_dur[:, 1].astype(int) - stimulus_table = stimulus_table.sort_values(['start', 'end']) + stimulus_table = stimulus_table.sort_values(["start", "end"]) return stimulus_table def _make_indexed_time_series_stimulus_table(inds, frame_dur): - ''' Return the a stimulus table for an indexed time series. + """Return the a stimulus table for an indexed time series. Parameters ---------- - inds : + inds : frame_durations : np.ndarray start and stop times (s) of frames @@ -1058,34 +1043,33 @@ def _make_indexed_time_series_stimulus_table(inds, frame_dur): Notes ----- For more information, see: - http://help.brain-map.org/display/observatory/Documentation?preview=/10616846/10813485/VisualCoding_VisualStimuli.pdf + http://help.brain-map.org/display/observatory/Documentation?preview=/10616846/10813485/VisualCoding_VisualStimuli.pdf + + """ - ''' + stimulus_table = pd.DataFrame(inds, columns=["frame"]) + stimulus_table.loc[:, "start"] = frame_dur[:, 0].astype(int) + stimulus_table.loc[:, "end"] = frame_dur[:, 1].astype(int) - stimulus_table = pd.DataFrame(inds, columns=['frame']) - stimulus_table.loc[:, 'start'] = frame_dur[:, 0].astype(int) - stimulus_table.loc[:, 'end'] = frame_dur[:, 1].astype(int) - - stimulus_table = stimulus_table.sort_values(['start', 'end']) + stimulus_table = stimulus_table.sort_values(["start", "end"]) return stimulus_table def _make_repeated_indexed_time_series_stimulus_table(inds, frame_dur): - stimulus_table = _make_indexed_time_series_stimulus_table(inds, frame_dur) - a = stimulus_table.groupby(by='frame') + a = stimulus_table.groupby(by="frame") # If this ever occurs, the repeat counter cant be trusted! - assert np.floor(len(stimulus_table))/len(a) == int(len(stimulus_table))/len(a) + assert np.floor(len(stimulus_table)) / len(a) == int(len(stimulus_table)) / len(a) - stimulus_table['repeat'] = np.repeat(range(len(stimulus_table)//len(a)), len(a)) + stimulus_table["repeat"] = np.repeat(range(len(stimulus_table) // len(a)), len(a)) return stimulus_table def _make_spontaneous_activity_stimulus_table(events, frame_durations): - ''' Builds a table describing the start and end times of the spontaneous viewing - intervals. + """Builds a table describing the start and end times of the spontaneous viewing + intervals. Parameters ---------- @@ -1096,31 +1080,25 @@ def _make_spontaneous_activity_stimulus_table(events, frame_durations): Returns ------- - pd.DataFrame : + pd.DataFrame : Each row describes an interval of spontaneous viewing. Columns are start and end times. Notes ----- For more information, see: - http://help.brain-map.org/display/observatory/Documentation?preview=/10616846/10813485/VisualCoding_VisualStimuli.pdf + http://help.brain-map.org/display/observatory/Documentation?preview=/10616846/10813485/VisualCoding_VisualStimuli.pdf - ''' + """ start_inds = np.where(events == 1) stop_inds = np.where(events == -1) if len(start_inds) != len(stop_inds): - raise Exception( - "inconsistent start and time times in spontaneous activity stimulus table") + raise Exception("inconsistent start and time times in spontaneous activity stimulus table") - stim_data = np.column_stack([ - frame_durations[start_inds, 0].T, - frame_durations[stop_inds, 0].T] - ).astype(int) + stim_data = np.column_stack([frame_durations[start_inds, 0].T, frame_durations[stop_inds, 0].T]).astype(int) - stimulus_table = pd.DataFrame(stim_data, columns=['start', 'end']) - stimulus_table = stimulus_table.sort_values(['start', 'end']) + stimulus_table = pd.DataFrame(stim_data, columns=["start", "end"]) + stimulus_table = stimulus_table.sort_values(["start", "end"]) return stimulus_table - - diff --git a/allensdk/core/cell_types_cache.py b/allensdk/core/cell_types_cache.py index 15d125971e..4da5b62f66 100644 --- a/allensdk/core/cell_types_cache.py +++ b/allensdk/core/cell_types_cache.py @@ -41,7 +41,7 @@ from . import json_utilities as json_utilities from .nwb_data_set import NwbDataSet -from . import swc +from . import swc import logging import warnings @@ -75,30 +75,31 @@ class CellTypesCache(Cache): """ # manifest keys - CELLS_KEY = 'CELLS' - EPHYS_FEATURES_KEY = 'EPHYS_FEATURES' - MORPHOLOGY_FEATURES_KEY = 'MORPHOLOGY_FEATURES' - EPHYS_DATA_KEY = 'EPHYS_DATA' - EPHYS_SWEEPS_KEY = 'EPHYS_SWEEPS' - RECONSTRUCTION_KEY = 'RECONSTRUCTION' - MARKER_KEY = 'MARKER' + CELLS_KEY = "CELLS" + EPHYS_FEATURES_KEY = "EPHYS_FEATURES" + MORPHOLOGY_FEATURES_KEY = "MORPHOLOGY_FEATURES" + EPHYS_DATA_KEY = "EPHYS_DATA" + EPHYS_SWEEPS_KEY = "EPHYS_SWEEPS" + RECONSTRUCTION_KEY = "RECONSTRUCTION" + MARKER_KEY = "MARKER" MANIFEST_VERSION = "1.1" def __init__(self, cache=True, manifest_file=None, base_uri=None): - if manifest_file is None: - manifest_file = get_default_manifest_file('cell_types') + manifest_file = get_default_manifest_file("cell_types") - super(CellTypesCache, self).__init__( - manifest=manifest_file, cache=cache, version=self.MANIFEST_VERSION) + super(CellTypesCache, self).__init__(manifest=manifest_file, cache=cache, version=self.MANIFEST_VERSION) self.api = CellTypesApi(base_uri=base_uri) - def get_cells(self, file_name=None, - require_morphology=False, - require_reconstruction=False, - reporter_status=None, - species=None, - simple=True): + def get_cells( + self, + file_name=None, + require_morphology=False, + require_reconstruction=False, + reporter_status=None, + species=None, + simple=True, + ): """ Download metadata for all cells in the database and optionally return a subset filtered by whether or not they have a morphology or reconstruction. @@ -127,27 +128,18 @@ def get_cells(self, file_name=None, file_name = self.get_cache_path(file_name, self.CELLS_KEY) - cells = self.api.list_cells_api(path=file_name, - strategy='lazy', - **Cache.cache_json()) + cells = self.api.list_cells_api(path=file_name, strategy="lazy", **Cache.cache_json()) if isinstance(reporter_status, str): reporter_status = [reporter_status] # filter the cells on the way out - cells = self.api.filter_cells_api(cells, - require_morphology, - require_reconstruction, - reporter_status, - species, - simple) - + cells = self.api.filter_cells_api( + cells, require_morphology, require_reconstruction, reporter_status, species, simple + ) return cells - - - def get_ephys_sweeps(self, specimen_id, file_name=None): """ Download sweep metadata for a single cell specimen. @@ -159,13 +151,9 @@ def get_ephys_sweeps(self, specimen_id, file_name=None): ID of a cell. """ - file_name = self.get_cache_path( - file_name, self.EPHYS_SWEEPS_KEY, specimen_id) + file_name = self.get_cache_path(file_name, self.EPHYS_SWEEPS_KEY, specimen_id) - sweeps = self.api.get_ephys_sweeps(specimen_id, - strategy='lazy', - path=file_name, - **Cache.cache_json()) + sweeps = self.api.get_ephys_sweeps(specimen_id, strategy="lazy", path=file_name, **Cache.cache_json()) return sweeps @@ -194,16 +182,14 @@ def get_ephys_features(self, dataframe=False, file_name=None): args = Cache.cache_csv_dataframe() else: args = Cache.cache_csv_json() - args['strategy'] = 'lazy' + args["strategy"] = "lazy" else: args = Cache.nocache_json() - features_df = self.api.get_ephys_features(path=file_name, - **args) + features_df = self.api.get_ephys_features(path=file_name, **args) return features_df - def get_morphology_features(self, dataframe=False, file_name=None): """ Download morphology features for all cells with reconstructions in the database. @@ -222,8 +208,7 @@ def get_morphology_features(self, dataframe=False, file_name=None): a list of dictionaries. """ - file_name = self.get_cache_path( - file_name, self.MORPHOLOGY_FEATURES_KEY) + file_name = self.get_cache_path(file_name, self.MORPHOLOGY_FEATURES_KEY) if self.cache: if dataframe: @@ -234,12 +219,11 @@ def get_morphology_features(self, dataframe=False, file_name=None): else: args = Cache.nocache_json() - args['strategy'] = 'lazy' - args['path'] = file_name + args["strategy"] = "lazy" + args["path"] = file_name return self.api.get_morphology_features(**args) - def get_all_features(self, dataframe=False, require_reconstruction=True): """ Download morphology and electrophysiology features for all cells and merge them @@ -260,17 +244,15 @@ def get_all_features(self, dataframe=False, require_reconstruction=True): ephys_features = pd.DataFrame(self.get_ephys_features()) morphology_features = pd.DataFrame(self.get_morphology_features()) - how = 'inner' if require_reconstruction else 'outer' + how = "inner" if require_reconstruction else "outer" - all_features = ephys_features.merge(morphology_features, - how=how, - on='specimen_id') + all_features = ephys_features.merge(morphology_features, how=how, on="specimen_id") if dataframe: warnings.warn("dataframe argument is deprecated.") return all_features else: - return all_features.to_dict('records') + return all_features.to_dict("records") def get_ephys_data(self, specimen_id, file_name=None): """ @@ -295,10 +277,9 @@ def get_ephys_data(self, specimen_id, file_name=None): and response traces out of an NWB file. """ - file_name = self.get_cache_path( - file_name, self.EPHYS_DATA_KEY, specimen_id) + file_name = self.get_cache_path(file_name, self.EPHYS_DATA_KEY, specimen_id) - self.api.save_ephys_data(specimen_id, file_name, strategy='lazy') + self.api.save_ephys_data(specimen_id, file_name, strategy="lazy") return NwbDataSet(file_name) @@ -324,12 +305,10 @@ def get_reconstruction(self, specimen_id, file_name=None): A class instance with methods for accessing morphology compartments. """ - file_name = self.get_cache_path( - file_name, self.RECONSTRUCTION_KEY, specimen_id) + file_name = self.get_cache_path(file_name, self.RECONSTRUCTION_KEY, specimen_id) if file_name is None: - raise Exception( - "Please enable caching (CellTypes.cache = True) or specify a save_file_name.") + raise Exception("Please enable caching (CellTypes.cache = True) or specify a save_file_name.") if not os.path.exists(file_name): self.api.save_reconstruction(specimen_id, file_name) @@ -358,12 +337,10 @@ def get_reconstruction_markers(self, specimen_id, file_name=None): A class instance with methods for accessing morphology compartments. """ - file_name = self.get_cache_path( - file_name, self.MARKER_KEY, specimen_id) + file_name = self.get_cache_path(file_name, self.MARKER_KEY, specimen_id) if file_name is None: - raise Exception( - "Please enable caching (CellTypes.cache = True) or specify a save_file_name.") + raise Exception("Please enable caching (CellTypes.cache = True) or specify a save_file_name.") if not os.path.exists(file_name): try: @@ -388,21 +365,14 @@ def build_manifest(self, file_name): mb = ManifestBuilder() mb.set_version(self.MANIFEST_VERSION) - mb.add_path('BASEDIR', '.') - mb.add_path(self.CELLS_KEY, 'cells.json', - typename='file', parent_key='BASEDIR') - mb.add_path(self.EPHYS_DATA_KEY, 'specimen_%d/ephys.nwb', - typename='file', parent_key='BASEDIR') - mb.add_path(self.EPHYS_FEATURES_KEY, 'ephys_features.csv', - typename='file', parent_key='BASEDIR') - mb.add_path(self.MORPHOLOGY_FEATURES_KEY, 'morphology_features.csv', - typename='file', parent_key='BASEDIR') - mb.add_path(self.RECONSTRUCTION_KEY, 'specimen_%d/reconstruction.swc', - typename='file', parent_key='BASEDIR') - mb.add_path(self.MARKER_KEY, 'specimen_%d/reconstruction.marker', - typename='file', parent_key='BASEDIR') - mb.add_path(self.EPHYS_SWEEPS_KEY, 'specimen_%d/ephys_sweeps.json', - typename='file', parent_key='BASEDIR') + mb.add_path("BASEDIR", ".") + mb.add_path(self.CELLS_KEY, "cells.json", typename="file", parent_key="BASEDIR") + mb.add_path(self.EPHYS_DATA_KEY, "specimen_%d/ephys.nwb", typename="file", parent_key="BASEDIR") + mb.add_path(self.EPHYS_FEATURES_KEY, "ephys_features.csv", typename="file", parent_key="BASEDIR") + mb.add_path(self.MORPHOLOGY_FEATURES_KEY, "morphology_features.csv", typename="file", parent_key="BASEDIR") + mb.add_path(self.RECONSTRUCTION_KEY, "specimen_%d/reconstruction.swc", typename="file", parent_key="BASEDIR") + mb.add_path(self.MARKER_KEY, "specimen_%d/reconstruction.marker", typename="file", parent_key="BASEDIR") + mb.add_path(self.EPHYS_SWEEPS_KEY, "specimen_%d/ephys_sweeps.json", typename="file", parent_key="BASEDIR") mb.write_json_file(file_name) @@ -412,7 +382,7 @@ class ReporterStatus: Valid strings for filtering by cell reporter status. """ - POSITIVE = 'positive' - NEGATIVE = 'negative' + POSITIVE = "positive" + NEGATIVE = "negative" NA = None INDETERMINATE = None diff --git a/allensdk/core/dat_utilities.py b/allensdk/core/dat_utilities.py index 36f8d4d8fe..0f90293d0f 100644 --- a/allensdk/core/dat_utilities.py +++ b/allensdk/core/dat_utilities.py @@ -37,10 +37,9 @@ class DatUtilities(object): - @classmethod def save_voltage(cls, output_path, v, t): - '''Save a single voltage output result into a simple text format. + """Save a single voltage output result into a simple text format. The output file is one t v pair per line. @@ -52,7 +51,7 @@ def save_voltage(cls, output_path, v, t): voltage t : numpy array time - ''' + """ data = numpy.transpose(numpy.vstack((t, v))) with open(output_path, "w") as f: numpy.savetxt(f, data) diff --git a/allensdk/core/dataframe_utils.py b/allensdk/core/dataframe_utils.py index 7dffde2086..68ad96c404 100644 --- a/allensdk/core/dataframe_utils.py +++ b/allensdk/core/dataframe_utils.py @@ -72,11 +72,7 @@ def patch_df_from_other( target_df[column] = None if index_column in columns_to_patch: - msg += ( - f"{index_column} is in the list of " - f"columns to patch {columns_to_patch}; " - "unsure how to handle that case\n" - ) + msg += f"{index_column} is in the list of columns to patch {columns_to_patch}; unsure how to handle that case\n" if len(msg) > 0: msg = f"failures in patch_df_from_other:\n{msg}" @@ -95,10 +91,7 @@ def patch_df_from_other( return target_df -def enforce_df_column_order( - input_df: pd.DataFrame, - column_order: List[str] -) -> pd.DataFrame: +def enforce_df_column_order(input_df: pd.DataFrame, column_order: List[str]) -> pd.DataFrame: """Return the data frame but with columns ordered. Parameters @@ -122,16 +115,12 @@ def enforce_df_column_order( pruned_order.append(col) # Get the full list of columns in the data frame with our ordered columns # first. - pruned_order.extend( - list(set(input_df.columns).difference(set(pruned_order))) - ) + pruned_order.extend(list(set(input_df.columns).difference(set(pruned_order)))) return input_df[pruned_order] def enforce_df_int_typing( - input_df: pd.DataFrame, - int_columns: List[str], - use_pandas_type: object = False + input_df: pd.DataFrame, int_columns: List[str], use_pandas_type: object = False ) -> pd.DataFrame: """Enforce integer typing for columns that may have lost int typing when combined into the final DataFrame. @@ -164,9 +153,7 @@ def enforce_df_int_typing( return input_df -def return_one_dataframe_row_only( - input_table: pd.DataFrame, index_value: int, table_name: str -) -> pd.Series: +def return_one_dataframe_row_only(input_table: pd.DataFrame, index_value: int, table_name: str) -> pd.Series: """Lookup and return one and only one row from the DataFrame returning an informative error if no or multiple rows are returned for a given index. diff --git a/allensdk/core/exceptions.py b/allensdk/core/exceptions.py index 94dc8f028c..4fcb198904 100644 --- a/allensdk/core/exceptions.py +++ b/allensdk/core/exceptions.py @@ -1,7 +1,8 @@ class DataFrameKeyError(LookupError): - """More verbose method for accessing invalid rows or columns + """More verbose method for accessing invalid rows or columns in a dataframe. Should be used when a keyerror is thrown on a dataframe. """ + def __init__(self, msg, caught_exception=None): if caught_exception: error_string = "{}\nCaught Exception: {}".format(msg, caught_exception) @@ -11,9 +12,10 @@ def __init__(self, msg, caught_exception=None): class DataFrameIndexError(LookupError): - """More verbose method for accessing invalid rows or columns + """More verbose method for accessing invalid rows or columns in a dataframe. Should be used when an index error is thrown on a dataframe. """ + def __init__(self, msg, caught_exception=None): if caught_exception: error_string = "{}\nCaught Exception: {}".format(msg, caught_exception) diff --git a/allensdk/core/h5_utilities.py b/allensdk/core/h5_utilities.py index 6c72bcb178..88d8a83a06 100644 --- a/allensdk/core/h5_utilities.py +++ b/allensdk/core/h5_utilities.py @@ -37,30 +37,25 @@ import functools +def decode_bytes(bytes_dataset, encoding="UTF-8"): + """Convert the elements of a dataset of bytes to str""" -def decode_bytes(bytes_dataset, encoding='UTF-8'): - ''' Convert the elements of a dataset of bytes to str - ''' - - return [ item.decode(encoding) for item in bytes_dataset[:].flat ] + return [item.decode(encoding) for item in bytes_dataset[:].flat] def load_datasets_by_relnames(relnames, h5_file, start_node): - ''' A convenience function for finding and loading into memory one or more + """A convenience function for finding and loading into memory one or more datasets from an h5 file - ''' + """ - matcher_cbs = { - relname: functools.partial(h5_object_matcher_relname_in, [relname]) - for relname in relnames - } + matcher_cbs = {relname: functools.partial(h5_object_matcher_relname_in, [relname]) for relname in relnames} matches = keyed_locate_h5_objects(matcher_cbs, h5_file, start_node=start_node) - return { key: value[:] for key, value in matches.items() } + return {key: value[:] for key, value in matches.items()} def h5_object_matcher_relname_in(relnames, h5_object_name, h5_object): - ''' Asks if an h5 object's relative name (the final section of its absolute name) + """Asks if an h5 object's relative name (the final section of its absolute name) is contained within a provided array Parameters @@ -74,22 +69,23 @@ def h5_object_matcher_relname_in(relnames, h5_object_name, h5_object): Returns ------- - bool : + bool : whether the match succeeded h5_object : h5py.group, h5py.Dataset the argued object - ''' + """ - return h5_object_name.split('/')[-1] in relnames, h5_object + return h5_object_name.split("/")[-1] in relnames, h5_object def keyed_locate_h5_objects(matcher_cbs, h5_file, start_node=None): - ''' Traverse an h5 file and build up a dictionary mapping supplied keys to + """Traverse an h5 file and build up a dictionary mapping supplied keys to located objects - ''' + """ matches = {} + def matcher(obj_name, obj): for key, matcher_cb in matcher_cbs.items(): match, _ = matcher_cb(obj_name, obj) @@ -101,10 +97,10 @@ def matcher(obj_name, obj): def locate_h5_objects(matcher_cb, h5_file, start_node=None): - ''' Traverse an h5 file and return objects matching supplied criteria - ''' + """Traverse an h5 file and return objects matching supplied criteria""" matches = [] + def matcher(h5_object_name, h5_object): match, _ = matcher_cb(h5_object_name, h5_object) if match: @@ -115,12 +111,11 @@ def matcher(h5_object_name, h5_object): def traverse_h5_file(callback, h5_file, start_node=None): - ''' Traverse an h5 file and apply a callback to each node - ''' + """Traverse an h5 file and apply a callback to each node""" if start_node is None: - start_node = h5_file['/'] + start_node = h5_file["/"] elif isinstance(start_node, str): start_node = h5_file[start_node] - start_node.visititems(callback) \ No newline at end of file + start_node.visititems(callback) diff --git a/allensdk/core/json_utilities.py b/allensdk/core/json_utilities.py index 6c26196f0e..c0d5f6f2a1 100644 --- a/allensdk/core/json_utilities.py +++ b/allensdk/core/json_utilities.py @@ -136,9 +136,7 @@ def read_url_post(url): the output will be of the corresponding type. """ urlp = urlparse.urlparse(url) - main_url = urlparse.urlunsplit( - (urlp.scheme, urlp.netloc, urlp.path, "", "") - ) + main_url = urlparse.urlunsplit((urlp.scheme, urlp.netloc, urlp.path, "", "")) data = json.dumps(dict(urlparse.parse_qsl(urlp.query))) handler = urllib_request.HTTPHandler() @@ -177,10 +175,7 @@ def json_handler(obj): elif hasattr(obj, "isoformat"): return obj.isoformat() else: - raise TypeError( - "Object of type %s with value of %s is not JSON serializable" - % (type(obj), repr(obj)) - ) + raise TypeError("Object of type %s with value of %s is not JSON serializable" % (type(obj), repr(obj))) class JsonComments(object): @@ -204,9 +199,7 @@ def read_file(cls, file_name): return json_object except ValueError: - ju_logger.error( - "Could not load json object from file: %s" % (file_name) - ) + ju_logger.error("Could not load json object from file: %s" % (file_name)) raise @classmethod @@ -247,18 +240,14 @@ def remove_multiline_comments(cls, json_string): Copy of the input without the comments. """ new_json = [] - start_iter = JsonComments._multiline_comment_start.finditer( - json_string - ) + start_iter = JsonComments._multiline_comment_start.finditer(json_string) json_slice_start = 0 for comment_start in start_iter: json_slice_end = comment_start.start() new_json.append(json_string[json_slice_start:json_slice_end]) search_start = comment_start.end() - comment_end = JsonComments._multiline_comment_end.search( - json_string[search_start:] - ) + comment_end = JsonComments._multiline_comment_end.search(json_string[search_start:]) if comment_end is None: break else: diff --git a/allensdk/core/lazy_property/__init__.py b/allensdk/core/lazy_property/__init__.py index ec7169740e..a436f91e43 100644 --- a/allensdk/core/lazy_property/__init__.py +++ b/allensdk/core/lazy_property/__init__.py @@ -2,4 +2,3 @@ from .lazy_property_mixin import LazyPropertyMixin __all__ = ["LazyProperty", "LazyPropertyMixin"] - diff --git a/allensdk/core/lazy_property/lazy_property.py b/allensdk/core/lazy_property/lazy_property.py index a001bbfe88..ad1f97eb4d 100644 --- a/allensdk/core/lazy_property/lazy_property.py +++ b/allensdk/core/lazy_property/lazy_property.py @@ -1,10 +1,8 @@ from typing import Callable, Iterable -class LazyProperty(object): - - def __init__(self, api_method: Callable, wrappers: Iterable = tuple(), - settable: bool = False, *args, **kwargs): +class LazyProperty(object): + def __init__(self, api_method: Callable, wrappers: Iterable = tuple(), settable: bool = False, *args, **kwargs): self.api_method = api_method self.wrappers = wrappers self.settable = settable diff --git a/allensdk/core/lazy_property/lazy_property_mixin.py b/allensdk/core/lazy_property/lazy_property_mixin.py index d21f2eb9d2..7494b391d0 100644 --- a/allensdk/core/lazy_property/lazy_property_mixin.py +++ b/allensdk/core/lazy_property/lazy_property_mixin.py @@ -2,28 +2,25 @@ class LazyPropertyMixin(object): - @property def LazyProperty(self): return LazyProperty def __getattribute__(self, name): - - lazy_class = super(LazyPropertyMixin, self).__getattribute__('LazyProperty') + lazy_class = super(LazyPropertyMixin, self).__getattribute__("LazyProperty") curr_attr = super(LazyPropertyMixin, self).__getattribute__(name) if isinstance(curr_attr, lazy_class): return curr_attr.__get__(curr_attr) else: return super(LazyPropertyMixin, self).__getattribute__(name) - def __setattr__(self, name, value): if not hasattr(self, name): super(LazyPropertyMixin, self).__setattr__(name, value) else: curr_attr = super(LazyPropertyMixin, self).__getattribute__(name) - lazy_class = super(LazyPropertyMixin, self).__getattribute__('LazyProperty') + lazy_class = super(LazyPropertyMixin, self).__getattribute__("LazyProperty") if isinstance(curr_attr, lazy_class): curr_attr.__set__(curr_attr, value) else: - super(LazyPropertyMixin, self).__setattr__(name, value) \ No newline at end of file + super(LazyPropertyMixin, self).__setattr__(name, value) diff --git a/allensdk/core/mouse_connectivity_cache.py b/allensdk/core/mouse_connectivity_cache.py index 16be982bfc..2ff58b14b2 100644 --- a/allensdk/core/mouse_connectivity_cache.py +++ b/allensdk/core/mouse_connectivity_cache.py @@ -118,12 +118,8 @@ class MouseConnectivityCache(ReferenceSpaceCache): def default_structure_ids(self): if not hasattr(self, "_default_structure_ids"): tree = self.get_structure_tree() - default_structures = tree.get_structures_by_set_id( - MouseConnectivityCache.DEFAULT_STRUCTURE_SET_IDS - ) - self._default_structure_ids = [ - st["id"] for st in default_structures - ] + default_structures = tree.get_structures_by_set_id(MouseConnectivityCache.DEFAULT_STRUCTURE_SET_IDS) + self._default_structure_ids = [st["id"] for st in default_structures] return self._default_structure_ids @@ -185,9 +181,7 @@ def get_projection_density(self, experiment_id, file_name=None): self.resolution, ) - self.api.download_projection_density( - file_name, experiment_id, self.resolution, strategy="lazy" - ) + self.api.download_projection_density(file_name, experiment_id, self.resolution, strategy="lazy") return nrrd.read(file_name) @@ -218,9 +212,7 @@ def get_injection_density(self, experiment_id, file_name=None): experiment_id, self.resolution, ) - self.api.download_injection_density( - file_name, experiment_id, self.resolution, strategy="lazy" - ) + self.api.download_injection_density(file_name, experiment_id, self.resolution, strategy="lazy") return nrrd.read(file_name) @@ -250,9 +242,7 @@ def get_injection_fraction(self, experiment_id, file_name=None): experiment_id, self.resolution, ) - self.api.download_injection_fraction( - file_name, experiment_id, self.resolution, strategy="lazy" - ) + self.api.download_injection_fraction(file_name, experiment_id, self.resolution, strategy="lazy") return nrrd.read(file_name) @@ -276,12 +266,8 @@ def get_data_mask(self, experiment_id, file_name=None): """ - file_name = self.get_cache_path( - file_name, self.DATA_MASK_KEY, experiment_id, self.resolution - ) - self.api.download_data_mask( - file_name, experiment_id, self.resolution, strategy="lazy" - ) + file_name = self.get_cache_path(file_name, self.DATA_MASK_KEY, experiment_id, self.resolution) + self.api.download_data_mask(file_name, experiment_id, self.resolution, strategy="lazy") return nrrd.read(file_name) @@ -326,9 +312,7 @@ def get_experiments( file_name = self.get_cache_path(file_name, self.EXPERIMENTS_KEY) - experiments = self.api.get_experiments_api( - path=file_name, strategy="lazy", **Cache.cache_json() - ) + experiments = self.api.get_experiments_api(path=file_name, strategy="lazy", **Cache.cache_json()) for e in experiments: # renaming id @@ -349,9 +333,7 @@ def get_experiments( del e["storage_directory"] # filter the read/downloaded list of experiments - experiments = self.filter_experiments( - experiments, cre, injection_structure_ids - ) + experiments = self.filter_experiments(experiments, cre, injection_structure_ids) if dataframe: experiments = pd.DataFrame(experiments) @@ -359,9 +341,7 @@ def get_experiments( return experiments - def filter_experiments( - self, experiments, cre=None, injection_structure_ids=None - ): + def filter_experiments(self, experiments, cre=None, injection_structure_ids=None): """ Take a list of experiments and filter them by cre status and injection structure. @@ -390,25 +370,18 @@ def filter_experiments( elif cre is not None: cre = [c.lower() for c in cre] experiments = [ - e - for e in experiments - if e["transgenic_line"] is not None - and e["transgenic_line"].lower() in cre + e for e in experiments if e["transgenic_line"] is not None and e["transgenic_line"].lower() in cre ] if injection_structure_ids is not None: descendant_ids = set( reduce( op.add, - self.get_structure_tree().descendant_ids( - injection_structure_ids - ), + self.get_structure_tree().descendant_ids(injection_structure_ids), ) ) - experiments = [ - e for e in experiments if e["structure_id"] in descendant_ids - ] + experiments = [e for e in experiments if e["structure_id"] in descendant_ids] return experiments @@ -462,9 +435,7 @@ def get_experiment_structure_unionizes( """ - file_name = self.get_cache_path( - file_name, self.STRUCTURE_UNIONIZES_KEY, experiment_id - ) + file_name = self.get_cache_path(file_name, self.STRUCTURE_UNIONIZES_KEY, experiment_id) filter_fn = functools.partial( self.filter_structure_unionizes, @@ -475,9 +446,7 @@ def get_experiment_structure_unionizes( ) def col_rn(x): - return pd.DataFrame(x).rename( - columns={"section_data_set_id": "experiment_id"} - ) + return pd.DataFrame(x).rename(columns={"section_data_set_id": "experiment_id"}) return self.api.get_structure_unionizes( [experiment_id], @@ -563,25 +532,13 @@ def rank_structures( results = [] for eid in experiment_ids: - this_experiment_unionizes = unionizes[ - unionizes["experiment_id"] == eid - ] - this_experiment_unionizes = this_experiment_unionizes.sort_values( - by=rank_on, ascending=False - ) - this_experiment_unionizes = this_experiment_unionizes.loc[ - :, output_keys - ] + this_experiment_unionizes = unionizes[unionizes["experiment_id"] == eid] + this_experiment_unionizes = this_experiment_unionizes.sort_values(by=rank_on, ascending=False) + this_experiment_unionizes = this_experiment_unionizes.loc[:, output_keys] - this_experiment_unionizes = unionizes[ - unionizes["experiment_id"] == eid - ] - this_experiment_unionizes = this_experiment_unionizes.sort_values( - by=rank_on, ascending=False - ) - this_experiment_unionizes = this_experiment_unionizes.loc[ - :, output_keys - ] + this_experiment_unionizes = unionizes[unionizes["experiment_id"] == eid] + this_experiment_unionizes = this_experiment_unionizes.sort_values(by=rank_on, ascending=False) + this_experiment_unionizes = this_experiment_unionizes.loc[:, output_keys] records = this_experiment_unionizes.to_dict("records") if len(records) > n: @@ -631,9 +588,7 @@ def filter_structure_unionizes( unionizes = unionizes[unionizes.is_injection == is_injection] if structure_ids is not None: - structure_ids = MouseConnectivityCache.validate_structure_ids( - structure_ids - ) + structure_ids = MouseConnectivityCache.validate_structure_ids(structure_ids) if include_descendants: structure_ids = reduce( @@ -643,14 +598,10 @@ def filter_structure_unionizes( else: structure_ids = set(structure_ids) - unionizes = unionizes[ - unionizes["structure_id"].isin(structure_ids) - ] + unionizes = unionizes[unionizes["structure_id"].isin(structure_ids)] if hemisphere_ids is not None: - unionizes = unionizes[ - unionizes["hemisphere_id"].isin(hemisphere_ids) - ] + unionizes = unionizes[unionizes["hemisphere_id"].isin(hemisphere_ids)] return unionizes @@ -746,17 +697,13 @@ def get_projection_matrix( cidx = 0 hlabel = {1: "-L", 2: "-R", 3: ""} - acronym_map = self.get_structure_tree().value_map( - lambda x: x["id"], lambda x: x["acronym"] - ) + acronym_map = self.get_structure_tree().value_map(lambda x: x["id"], lambda x: x["acronym"]) for hid in hemisphere_ids: for sid in projection_structure_ids: column_lookup[(hid, sid)] = cidx label = acronym_map[sid] + hlabel[hid] - columns.append( - {"hemisphere_id": hid, "structure_id": sid, "label": label} - ) + columns.append({"hemisphere_id": hid, "structure_id": sid, "label": label}) cidx += 1 for _, row in unionizes.iterrows(): @@ -781,9 +728,7 @@ def get_projection_matrix( "columns": columns, } - def get_deformation_field( - self, section_data_set_id, header_path=None, voxel_path=None - ): + def get_deformation_field(self, section_data_set_id, header_path=None, voxel_path=None): """Extract the local alignment parameters for this dataset. This a 3D vector image (3 components) describing a deformable local mapping from CCF voxels to this section data set's @@ -814,17 +759,11 @@ def get_deformation_field( warnings.warn( "deformation fields are only available at {} isometric" "resolutions, but this is a " - "{}-micron cache".format( - self.DFMFLD_RESOLUTIONS, self.resolution - ) + "{}-micron cache".format(self.DFMFLD_RESOLUTIONS, self.resolution) ) - header_path = self.get_cache_path( - header_path, self.DEFORMATION_FIELD_HEADER_KEY, section_data_set_id - ) - voxel_path = self.get_cache_path( - voxel_path, self.DEFORMATION_FIELD_VOXEL_KEY, section_data_set_id - ) + header_path = self.get_cache_path(header_path, self.DEFORMATION_FIELD_HEADER_KEY, section_data_set_id) + voxel_path = self.get_cache_path(voxel_path, self.DEFORMATION_FIELD_VOXEL_KEY, section_data_set_id) if not (os.path.exists(header_path) and os.path.exists(voxel_path)): Manifest.safe_make_parent_dirs(header_path) @@ -839,9 +778,7 @@ def get_deformation_field( sitk.ReadImage(str(header_path)) ) # TODO the str call here is only necessary in 2.7 - def get_affine_parameters( - self, section_data_set_id, direction="trv", file_name=None - ): + def get_affine_parameters(self, section_data_set_id, direction="trv", file_name=None): """Extract the parameters of the 3D affine tranformation mapping this section data set's image-space stack to CCF-space (or vice-versa). @@ -872,10 +809,7 @@ def get_affine_parameters( """ if direction not in ("trv", "tvr"): - raise ValueError( - "invalid direction: {}. direction must be one of tvr," - "trv".format(direction) - ) + raise ValueError("invalid direction: {}. direction must be one of tvr,trv".format(direction)) file_name = self.get_cache_path(file_name, self.ALIGNMENT3D_KEY) @@ -908,9 +842,7 @@ def add_manifest_paths(self, manifest_builder): """ - manifest_builder = super( - MouseConnectivityCache, self - ).add_manifest_paths(manifest_builder) + manifest_builder = super(MouseConnectivityCache, self).add_manifest_paths(manifest_builder) manifest_builder.add_path( self.EXPERIMENTS_KEY, diff --git a/allensdk/core/nwb_data_set.py b/allensdk/core/nwb_data_set.py index 7afc84a961..9f384e0b05 100644 --- a/allensdk/core/nwb_data_set.py +++ b/allensdk/core/nwb_data_set.py @@ -38,14 +38,15 @@ class NwbDataSet(object): - """ A very simple interface for exracting electrophysiology data + """A very simple interface for exracting electrophysiology data from an NWB file. """ + SPIKE_TIMES = "spike_times" DEPRECATED_SPIKE_TIMES = "aibs_spike_times" def __init__(self, file_name, spike_time_key=None): - """ Initialize the NwbDataSet instance with a file name + """Initialize the NwbDataSet instance with a file name Parameters ---------- @@ -59,7 +60,7 @@ def __init__(self, file_name, spike_time_key=None): self.spike_time_key = spike_time_key def get_sweep(self, sweep_number): - """ Retrieve the stimulus, response, index_range, and sampling rate + """Retrieve the stimulus, response, index_range, and sampling rate for a particular sweep. This method hides the NWB file's distinction between a "Sweep" and an "Experiment". An experiment is a subset of of a sweep that excludes the initial test pulse. It also excludes @@ -83,9 +84,8 @@ def get_sweep(self, sweep_number): the first element indicates the end of the test pulse and the second index is the end of valid response data. """ - with h5py.File(self.file_name, 'r') as f: - - swp = f['epochs']['Sweep_%d' % sweep_number] + with h5py.File(self.file_name, "r") as f: + swp = f["epochs"]["Sweep_%d" % sweep_number] # fetch data from file and convert to correct SI unit # this operation depends on file version. early versions of @@ -96,34 +96,33 @@ def get_sweep(self, sweep_number): major, minor = self.get_pipeline_version() if (major == 1 and minor > 0) or major > 1: # stimulus - stimulus_dataset = swp['stimulus']['timeseries']['data'] + stimulus_dataset = swp["stimulus"]["timeseries"]["data"] conversion = float(stimulus_dataset.attrs["conversion"]) stimulus = stimulus_dataset[()] * conversion # acquisition - response_dataset = swp['response']['timeseries']['data'] + response_dataset = swp["response"]["timeseries"]["data"] conversion = float(response_dataset.attrs["conversion"]) response = response_dataset[()] * conversion else: # old file version - stimulus_dataset = swp['stimulus']['timeseries']['data'] + stimulus_dataset = swp["stimulus"]["timeseries"]["data"] stimulus = stimulus_dataset[()] - response = swp['response']['timeseries']['data'][()] + response = swp["response"]["timeseries"]["data"][()] - if 'unit' in stimulus_dataset.attrs: - unit = stimulus_dataset.attrs["unit"].decode('UTF-8') + if "unit" in stimulus_dataset.attrs: + unit = stimulus_dataset.attrs["unit"].decode("UTF-8") unit_str = None - if unit.startswith('A'): + if unit.startswith("A"): unit_str = "Amps" - elif unit.startswith('V'): + elif unit.startswith("V"): unit_str = "Volts" - assert unit_str is not None, Exception( - "Stimulus time series unit not recognized") + assert unit_str is not None, Exception("Stimulus time series unit not recognized") else: unit = None - unit_str = 'Unknown' + unit_str = "Unknown" - swp_idx_start = swp['stimulus']['idx_start'][()] - swp_length = swp['stimulus']['count'][()] + swp_idx_start = swp["stimulus"]["idx_start"][()] + swp_length = swp["stimulus"]["count"][()] swp_idx_stop = swp_idx_start + swp_length - 1 sweep_index_range = (swp_idx_start, swp_idx_stop) @@ -131,9 +130,9 @@ def get_sweep(self, sweep_number): # if the sweep has an experiment, extract the experiment's index # range try: - exp = f['epochs']['Experiment_%d' % sweep_number] - exp_idx_start = exp['stimulus']['idx_start'][()] - exp_length = exp['stimulus']['count'][()] + exp = f["epochs"]["Experiment_%d" % sweep_number] + exp_idx_start = exp["stimulus"]["idx_start"][()] + exp_length = exp["stimulus"]["count"][()] exp_idx_stop = exp_idx_start + exp_length - 1 experiment_index_range = (exp_idx_start, exp_idx_stop) except KeyError: @@ -141,20 +140,18 @@ def get_sweep(self, sweep_number): # entire sweep. experiment_index_range = sweep_index_range - assert sweep_index_range[0] == 0, Exception( - "index range of the full sweep does not start at 0.") + assert sweep_index_range[0] == 0, Exception("index range of the full sweep does not start at 0.") return { - 'stimulus': stimulus, - 'response': response, - 'stimulus_unit': unit_str, - 'index_range': experiment_index_range, - 'sampling_rate': 1.0 * swp['stimulus']['timeseries'][ - 'starting_time'].attrs['rate'] + "stimulus": stimulus, + "response": response, + "stimulus_unit": unit_str, + "index_range": experiment_index_range, + "sampling_rate": 1.0 * swp["stimulus"]["timeseries"]["starting_time"].attrs["rate"], } def set_sweep(self, sweep_number, stimulus, response): - """ Overwrite the stimulus or response of an NWB file. + """Overwrite the stimulus or response of an NWB file. If the supplied arrays are shorter than stored arrays, they are padded with zeros to match the original data size. @@ -172,17 +169,17 @@ def set_sweep(self, sweep_number, stimulus, response): unchanged. """ - with h5py.File(self.file_name, 'r+') as f: - swp = f['epochs']['Sweep_%d' % sweep_number] + with h5py.File(self.file_name, "r+") as f: + swp = f["epochs"]["Sweep_%d" % sweep_number] # this is the length of the entire sweep data, including test # pulse and # whatever might be in front of it # TODO: remove deprecated 'idx_stop' - if 'idx_stop' in swp['stimulus']: - sweep_length = swp['stimulus']['idx_stop'][()] + 1 + if "idx_stop" in swp["stimulus"]: + sweep_length = swp["stimulus"]["idx_stop"][()] + 1 else: - sweep_length = swp['stimulus']['count'][()] + sweep_length = swp["stimulus"]["count"][()] if stimulus is not None: # if the data is shorter than the sweep, pad it with zeros @@ -190,7 +187,7 @@ def set_sweep(self, sweep_number, stimulus, response): if missing_data > 0: stimulus = np.append(stimulus, np.zeros(missing_data)) - swp['stimulus']['timeseries']['data'][...] = stimulus + swp["stimulus"]["timeseries"]["data"][...] = stimulus if response is not None: # if the data is shorter than the sweep, pad it with zeros @@ -198,29 +195,29 @@ def set_sweep(self, sweep_number, stimulus, response): if missing_data > 0: response = np.append(response, np.zeros(missing_data)) - swp['response']['timeseries']['data'][...] = response + swp["response"]["timeseries"]["data"][...] = response def get_pipeline_version(self): - """ Returns the AI pipeline version number, stored in the - metadata field 'generated_by'. If that field is - missing, version 0.0 is returned. + """Returns the AI pipeline version number, stored in the + metadata field 'generated_by'. If that field is + missing, version 0.0 is returned. - Returns - ------- - int tuple: (major, minor) + Returns + ------- + int tuple: (major, minor) """ try: - with h5py.File(self.file_name, 'r') as f: - if 'generated_by' in f["general"]: + with h5py.File(self.file_name, "r") as f: + if "generated_by" in f["general"]: info = f["general/generated_by"] # generated_by stores array of keys and values # keys are even numbered, corresponding values are in # odd indices for i in range(len(info)): - if info[i] == 'version': + if info[i] == "version": version = info[i + 1] break - toks = version.split('.') + toks = version.split(".") if len(toks) >= 2: major = int(toks[0]) minor = int(toks[1]) @@ -230,7 +227,7 @@ def get_pipeline_version(self): return major, minor def get_spike_times(self, sweep_number, key=None): - """ Return any spike times stored in the NWB file for a sweep. + """Return any spike times stored in the NWB file for a sweep. Parameters ---------- @@ -249,10 +246,11 @@ def get_spike_times(self, sweep_number, key=None): if key is None: key = self.spike_time_key - with h5py.File(self.file_name, 'r') as f: - datasets = ["analysis/%s/Sweep_%d" % (key, sweep_number), - "analysis/%s/Sweep_%d" % ( - self.DEPRECATED_SPIKE_TIMES, sweep_number)] + with h5py.File(self.file_name, "r") as f: + datasets = [ + "analysis/%s/Sweep_%d" % (key, sweep_number), + "analysis/%s/Sweep_%d" % (self.DEPRECATED_SPIKE_TIMES, sweep_number), + ] for ds in datasets: if ds in f: @@ -260,7 +258,7 @@ def get_spike_times(self, sweep_number, key=None): return [] def set_spike_times(self, sweep_number, spike_times, key=None): - """ Set or overwrite the spikes times for a sweep. + """Set or overwrite the spikes times for a sweep. Parameters ---------- @@ -276,7 +274,7 @@ def set_spike_times(self, sweep_number, spike_times, key=None): if key is None: key = self.spike_time_key - with h5py.File(self.file_name, 'r+') as f: + with h5py.File(self.file_name, "r+") as f: # make sure expected directory structure is in place if "analysis" not in f.keys(): f.create_group("analysis") @@ -299,31 +297,25 @@ def set_spike_times(self, sweep_number, spike_times, key=None): # rewriting data -- delete old dataset del spike_dir[sweep_name] - spike_dir.create_dataset( - sweep_name, data=spike_times, dtype='f8', maxshape=(None,)) + spike_dir.create_dataset(sweep_name, data=spike_times, dtype="f8", maxshape=(None,)) def get_sweep_numbers(self): - """ Get all of the sweep numbers in the file, including test sweeps. - """ + """Get all of the sweep numbers in the file, including test sweeps.""" - with h5py.File(self.file_name, 'r') as f: - sweeps = [int(e.split('_')[1]) - for e in f['epochs'].keys() if e.startswith('Sweep_')] + with h5py.File(self.file_name, "r") as f: + sweeps = [int(e.split("_")[1]) for e in f["epochs"].keys() if e.startswith("Sweep_")] return sweeps def get_experiment_sweep_numbers(self): - """ Get all of the sweep numbers for experiment epochs in the file, - not including test sweeps. """ + """Get all of the sweep numbers for experiment epochs in the file, + not including test sweeps.""" - with h5py.File(self.file_name, 'r') as f: - sweeps = [int(e.split('_')[1]) - for e in f['epochs'].keys() if - e.startswith('Experiment_')] + with h5py.File(self.file_name, "r") as f: + sweeps = [int(e.split("_")[1]) for e in f["epochs"].keys() if e.startswith("Experiment_")] return sweeps - def fill_sweep_responses(self, fill_value=0.0, sweep_numbers=None, - extend_experiment=False): - """ Fill sweep response arrays with a single value. + def fill_sweep_responses(self, fill_value=0.0, sweep_numbers=None, extend_experiment=False): + """Fill sweep response arrays with a single value. Parameters ---------- @@ -339,30 +331,26 @@ def fill_sweep_responses(self, fill_value=0.0, sweep_numbers=None, """ - with h5py.File(self.file_name, 'a') as f: + with h5py.File(self.file_name, "a") as f: if sweep_numbers is None: sweep_numbers = self.get_sweep_numbers() for sweep_number in sweep_numbers: epoch = "Sweep_%d" % sweep_number - if epoch in f['epochs']: - f['epochs'][epoch]['response'][ - 'timeseries']['data'][...] = fill_value + if epoch in f["epochs"]: + f["epochs"][epoch]["response"]["timeseries"]["data"][...] = fill_value if extend_experiment: epoch = "Experiment_%d" % sweep_number - if epoch in f['epochs']: - idx_start = \ - f['epochs'][epoch]['stimulus']['idx_start'][()] - count = f['epochs'][epoch]['stimulus']['timeseries'][ - 'data'].shape[0] + if epoch in f["epochs"]: + idx_start = f["epochs"][epoch]["stimulus"]["idx_start"][()] + count = f["epochs"][epoch]["stimulus"]["timeseries"]["data"].shape[0] - del f['epochs'][epoch]['stimulus']['count'] - f['epochs'][epoch]['stimulus'][ - 'count'] = count - idx_start + del f["epochs"][epoch]["stimulus"]["count"] + f["epochs"][epoch]["stimulus"]["count"] = count - idx_start def get_sweep_metadata(self, sweep_number): - """ Retrieve the sweep level metadata associated with each sweep. + """Retrieve the sweep level metadata associated with each sweep. Includes information on stimulus parameters like its name and amplitude as well as recording quality metadata, like access resistance and seal quality. @@ -380,20 +368,22 @@ def get_sweep_metadata(self, sweep_number): specific fields are ones encoded in the original AIBS in vitro .nwb files. """ - with h5py.File(self.file_name, 'r') as f: - + with h5py.File(self.file_name, "r") as f: sweep_metadata = {} # the sweep level metadata is stored in # stimulus/presentation/Sweep_XX in the .nwb file # indicates which metadata fields to return - metadata_fields = ['aibs_stimulus_amplitude_pa', - 'aibs_stimulus_name', - 'gain', 'initial_access_resistance', 'seal'] + metadata_fields = [ + "aibs_stimulus_amplitude_pa", + "aibs_stimulus_name", + "gain", + "initial_access_resistance", + "seal", + ] try: - stim_details = f['stimulus']['presentation'][ - 'Sweep_%d' % sweep_number] + stim_details = f["stimulus"]["presentation"]["Sweep_%d" % sweep_number] for field in metadata_fields: # check if sweep contains the specific metadata field if field in stim_details.keys(): diff --git a/allensdk/core/obj_utilities.py b/allensdk/core/obj_utilities.py index de603c745a..6ffc286b54 100644 --- a/allensdk/core/obj_utilities.py +++ b/allensdk/core/obj_utilities.py @@ -39,14 +39,14 @@ def read_obj(path): - with open(path, 'r') as obj_file: - lines = obj_file.read().split('\n') + with open(path, "r") as obj_file: + lines = obj_file.read().split("\n") output = parse_obj(lines) return output def parse_obj(lines): - '''Parse a wavefront obj file into a triplet of vertices, normals, and faces. + """Parse a wavefront obj file into a triplet of vertices, normals, and faces. This parser is specific to obj files generated from our annotation volumes Parameters @@ -62,17 +62,17 @@ def parse_obj(lines): vertex_normals : np.ndarray Dimensions are (nSample, nElements=3). Vectors normal to vertices. face_vertices : np.ndarray - Dimensions are (sample, nVertices=3). References are given in indices + Dimensions are (sample, nVertices=3). References are given in indices (0-indexed here, but 1-indexed in the file) of vertices that make up each face. face_normals : np.ndarray - Dimensions are (sample, nNormals=3). References are given in indices + Dimensions are (sample, nNormals=3). References are given in indices (0-indexed here, but 1-indexed in the file) of vertex normals that make up each face. Notes ----- - This parser is specialized to the obj files that the Allen Institute for Brain Science + This parser is specialized to the obj files that the Allen Institute for Brain Science generates from our own structure annotations. - ''' + """ vertices = [] vertex_normals = [] @@ -80,19 +80,18 @@ def parse_obj(lines): face_normals = [] for line in lines: - - if line[:2] == 'v ': - vertices.append( line.split()[1:] ) + if line[:2] == "v ": + vertices.append(line.split()[1:]) - elif line[:3] == 'vn ': - vertex_normals.append( line.split()[1:] ) + elif line[:3] == "vn ": + vertex_normals.append(line.split()[1:]) - elif line[:2] == 'f ': - line = line.replace('//', ' ').split()[1:] + elif line[:2] == "f ": + line = line.replace("//", " ").split()[1:] + + face_vertices.append(line[::2]) + face_normals.append(line[1::2]) - face_vertices.append( line[::2] ) - face_normals.append( line[1::2] ) - vertices = np.array(vertices).astype(float) vertex_normals = np.array(vertex_normals).astype(float) face_vertices = np.array(face_vertices).astype(int) - 1 diff --git a/allensdk/core/ontology.py b/allensdk/core/ontology.py index 30b15c5dba..a1b5411fd8 100644 --- a/allensdk/core/ontology.py +++ b/allensdk/core/ontology.py @@ -40,14 +40,14 @@ from allensdk.deprecated import class_deprecated -@class_deprecated('Use StructureTree instead.') +@class_deprecated("Use StructureTree instead.") class Ontology(object): - ''' + """ .. note:: Deprecated from 0.12.5 `Ontology` has been replaced by `StructureTree`. - ''' + """ def __init__(self, df): self.df = df @@ -57,12 +57,12 @@ def __init__(self, df): for _, s in df.iterrows(): sid = s.name - parent_id = s['parent_structure_id'] + parent_id = s["parent_structure_id"] if np.isfinite(parent_id): parent_id = int(parent_id) child_ids[parent_id].add(sid) - parent_id_list = map(int, s['structure_id_path'].split('/')[1:-1]) + parent_id_list = map(int, s["structure_id_path"].split("/")[1:-1]) for parent_id in parent_id_list: descendant_ids[parent_id].add(sid) @@ -94,7 +94,7 @@ def __getitem__(self, structures): # multiple arguments (e.g. ontology[315,997]), that gets passed through as a # tuple. This normalizes the arguments so that everything is iterable. if not isinstance(structures, tuple) and not isinstance(structures, list) and not isinstance(structures, set): - structures = structures, + structures = (structures,) # this is the final set of structure ids used to filter structure_ids = set() @@ -115,9 +115,8 @@ def __getitem__(self, structures): # convert the string arguments to rows if len(string_strs): - # pull out the rows that match these acronyms - string_strs = self.df[self.df['acronym'].isin(string_strs)] + string_strs = self.df[self.df["acronym"].isin(string_strs)] # if there are no other structure ids, just return this dataframe if len(structure_ids) == 0: @@ -126,7 +125,7 @@ def __getitem__(self, structures): # otherwise pull out the ids and add them to the set structure_ids.update(string_strs.id.tolist()) - return self.df.loc[structure_ids].dropna(axis=0, how='all') + return self.df.loc[structure_ids].dropna(axis=0, how="all") def get_descendant_ids(self, structure_ids): """ @@ -149,8 +148,7 @@ def get_descendant_ids(self, structure_ids): else: descendants = set() for structure_id in structure_ids: - descendants.update(self.descendant_ids.get( - int(structure_id), set())) + descendants.update(self.descendant_ids.get(int(structure_id), set())) return descendants def get_child_ids(self, structure_ids): @@ -220,7 +218,7 @@ def structure_descends_from(self, child_id, parent_id): child = self[child_id] if child is not None: - parent_str = '/%d/' % parent_id - return child['structure_id_path'].values[0].find(parent_str) >= 0 + parent_str = "/%d/" % parent_id + return child["structure_id_path"].values[0].find(parent_str) >= 0 return False diff --git a/allensdk/core/ophys_experiment_session_id_mapping.py b/allensdk/core/ophys_experiment_session_id_mapping.py index eee314ac5f..0dca74cbf5 100644 --- a/allensdk/core/ophys_experiment_session_id_mapping.py +++ b/allensdk/core/ophys_experiment_session_id_mapping.py @@ -1373,5 +1373,5 @@ 591640135: 610504038, 591823992: 610504052, 591780793: 610504045, - 500855614: 610491429 + 500855614: 610491429, } diff --git a/allensdk/core/pickle_utils.py b/allensdk/core/pickle_utils.py index 3b150801ba..51f503d651 100644 --- a/allensdk/core/pickle_utils.py +++ b/allensdk/core/pickle_utils.py @@ -4,8 +4,7 @@ import pathlib -def load_and_sanitize_pickle( - pickle_path: Union[str, pathlib.Path]) -> Any: +def load_and_sanitize_pickle(pickle_path: Union[str, pathlib.Path]) -> Any: """ Load the data from a pickle file and pass it through sanitize_pickle_data, so that all bytes in the data are @@ -31,21 +30,19 @@ def load_and_sanitize_pickle( if isinstance(pickle_path, str): pickle_path = pathlib.Path(pickle_path) - if pickle_path.name.endswith('gz'): + if pickle_path.name.endswith("gz"): open_method = gzip.open - elif pickle_path.name.endswith('pkl'): + elif pickle_path.name.endswith("pkl"): open_method = open else: - raise ValueError("Can open .pkl and .gz files; " - f"you gave {pickle_path.resolve().absolute()}") + raise ValueError(f"Can open .pkl and .gz files; you gave {pickle_path.resolve().absolute()}") - with open_method(pickle_path, 'rb') as in_file: - raw_data = pickle.load(in_file, encoding='bytes') + with open_method(pickle_path, "rb") as in_file: + raw_data = pickle.load(in_file, encoding="bytes") return _sanitize_pickle_data(raw_data) -def _sanitize_pickle_data( - raw_data: Union[list, dict]) -> Union[list, dict]: +def _sanitize_pickle_data(raw_data: Union[list, dict]) -> Union[list, dict]: """ Sometimes data read from the pickle file comes with keys that are strings; sometimes it comes with keys that are bytes. @@ -66,8 +63,7 @@ def _sanitize_pickle_data( return raw_data -def _sanitize_list( - raw_data: list) -> list: +def _sanitize_list(raw_data: list) -> list: """ Sanitize a list read from the pickle file, casting bytes into str and returning the sanitized list. @@ -82,15 +78,14 @@ def _sanitize_list( elif isinstance(element, dict): raw_data[idx] = _sanitize_dict(element) elif isinstance(element, bytes): - raw_data[idx] = element.decode('utf-8') + raw_data[idx] = element.decode("utf-8") else: pass return raw_data -def _sanitize_tuple( - raw_data: tuple) -> tuple: +def _sanitize_tuple(raw_data: tuple) -> tuple: """ Sanitize a list read from the pickle file, casting bytes into str and returning the sanitized list. @@ -101,8 +96,7 @@ def _sanitize_tuple( return output -def _sanitize_list_or_tuple( - raw_data: Union[list, tuple]) -> Union[list, tuple]: +def _sanitize_list_or_tuple(raw_data: Union[list, tuple]) -> Union[list, tuple]: """ Sanitize a list or tuple read from the pickle file, casting bytes into str and returning the sanitized iterable. @@ -117,12 +111,10 @@ def _sanitize_list_or_tuple( elif isinstance(raw_data, tuple): return _sanitize_tuple(raw_data) - raise ValueError("Can only process lists or tuples; " - f"you gave {type(raw_data)}") + raise ValueError(f"Can only process lists or tuples; you gave {type(raw_data)}") -def _sanitize_dict( - raw_data: dict) -> dict: +def _sanitize_dict(raw_data: dict) -> dict: """ Sanitize a dict read from the pickle file, casting bytes into str and returning the sanitized dict. @@ -138,14 +130,14 @@ def _sanitize_dict( this_value = raw_data.pop(this_key) if isinstance(this_key, bytes): - this_key = this_key.decode('utf-8') + this_key = this_key.decode("utf-8") if isinstance(this_value, list) or isinstance(this_value, tuple): this_value = _sanitize_list_or_tuple(this_value) elif isinstance(this_value, dict): this_value = _sanitize_dict(this_value) elif isinstance(this_value, bytes): - this_value = this_value.decode('utf-8') + this_value = this_value.decode("utf-8") raw_data[this_key] = this_value return raw_data diff --git a/allensdk/core/reference_space.py b/allensdk/core/reference_space.py index e8346cee1f..fe697b8bd4 100644 --- a/allensdk/core/reference_space.py +++ b/allensdk/core/reference_space.py @@ -47,30 +47,29 @@ class ReferenceSpace(object): - @property def direct_voxel_map(self): - if not hasattr(self, '_direct_voxel_map'): + if not hasattr(self, "_direct_voxel_map"): self.direct_voxel_counts() - return self._direct_voxel_map - + return self._direct_voxel_map + @direct_voxel_map.setter def direct_voxel_map(self, data): self._direct_voxel_map = data - + @property def total_voxel_map(self): - if not hasattr(self, '_total_voxel_map'): + if not hasattr(self, "_total_voxel_map"): self.total_voxel_counts() return self._total_voxel_map - + @total_voxel_map.setter def total_voxel_map(self, data): self._total_voxel_map = data - + def __init__(self, structure_tree, annotation, resolution): - '''Handles brain structures in a 3d reference space - + """Handles brain structures in a 3d reference space + Parameters ---------- structure_tree : StructureTree @@ -79,262 +78,248 @@ def __init__(self, structure_tree, annotation, resolution): 3d volume whose elements are structure ids. resolution : length-3 tuple of numeric Resolution of annotation voxels along each dimension. - - ''' - + + """ + self.structure_tree = structure_tree self.resolution = resolution - + self.annotation = np.ascontiguousarray(annotation) - + def direct_voxel_counts(self): - '''Determines the number of voxels directly assigned to one or more + """Determines the number of voxels directly assigned to one or more structures. - + Returns ------- - dict : - Keys are structure ids, values are the number of voxels directly + dict : + Keys are structure ids, values are the number of voxels directly assigned to those structures. - - ''' + + """ uniques = np.unique(self.annotation, return_counts=True) found = {k: v for k, v in zip(*uniques) if k != 0} - self._direct_voxel_map = {k: (found[k] if k in found else 0) for k - in self.structure_tree.node_ids()} - + self._direct_voxel_map = {k: (found[k] if k in found else 0) for k in self.structure_tree.node_ids()} + def total_voxel_counts(self): - '''Determines the number of voxels assigned to a structure or its + """Determines the number of voxels assigned to a structure or its descendants - + Returns ------- - dict : - Keys are structure ids, values are the number of voxels assigned + dict : + Keys are structure ids, values are the number of voxels assigned to structures' descendants. - - ''' + + """ self._total_voxel_map = {} for stid in self.structure_tree.node_ids(): - desc_ids = self.structure_tree.descendant_ids([stid])[0] - self._total_voxel_map[stid] = sum([self.direct_voxel_map[dscid] - for dscid in desc_ids]) - + self._total_voxel_map[stid] = sum([self.direct_voxel_map[dscid] for dscid in desc_ids]) + def remove_unassigned(self, update_self=True): - '''Obtains a structure tree consisting only of structures that have + """Obtains a structure tree consisting only of structures that have at least one voxel in the annotation. - + Parameters ---------- update_self : bool, optional If True, the contained structure tree will be replaced, - + Returns ------- - list of dict : + list of dict : elements are filtered structures - - ''' - - structures = self.structure_tree.filter_nodes( - lambda x: self.total_voxel_map[x['id']] > 0) - + + """ + + structures = self.structure_tree.filter_nodes(lambda x: self.total_voxel_map[x["id"]] > 0) + if update_self: self.structure_tree = StructureTree(structures) - + return structures - + def make_structure_mask(self, structure_ids, direct_only=False): - '''Return an indicator array for one or more structures - + """Return an indicator array for one or more structures + Parameters ---------- structure_ids : list of int Make a mask that indicates the union of these structures' voxels direct_only : bool, optional - If True, only include voxels directly assigned to a structure in + If True, only include voxels directly assigned to a structure in the mask. Otherwise include voxels assigned to descendants. - + Returns ------- numpy ndarray : Same shape as annotation. 1 inside mask, 0 outside. - - ''' - + + """ + if direct_only: - mask = np.zeros(self.annotation.shape, dtype=np.uint8, order='C') + mask = np.zeros(self.annotation.shape, dtype=np.uint8, order="C") for stid in structure_ids: - if self.direct_voxel_map[stid] == 0: continue - + mask[self.annotation == stid] = True - + return mask - + else: structure_ids = self.structure_tree.descendant_ids(structure_ids) structure_ids = set(functools.reduce(op.add, structure_ids)) return self.make_structure_mask(structure_ids, direct_only=True) - - def many_structure_masks(self, structure_ids, output_cb=None, - direct_only=False): - '''Build one or more structure masks and do something with them - + + def many_structure_masks(self, structure_ids, output_cb=None, direct_only=False): + """Build one or more structure masks and do something with them + Parameters ---------- structure_ids : list of int Specify structures to be masked output_cb : function, optional - Must have the following signature: output_cb(structure_id, fn). - On each requested id, fn will be curried to make a mask for that + Must have the following signature: output_cb(structure_id, fn). + On each requested id, fn will be curried to make a mask for that id. Defaults to returning the structure id and mask. direct_only : bool, optional - If True, only include voxels directly assigned to a structure in + If True, only include voxels directly assigned to a structure in the mask. Otherwise include voxels assigned to descendants. - + Yields ------- - Return values of output_cb called on each structure_id, structure_mask + Return values of output_cb called on each structure_id, structure_mask pair. - + Notes ----- - output_cb is called on every yield, so any side-effects (such as - writing to a file) will be carried out regardless of what you do with - the return values. You do actually have to iterate through the output, + output_cb is called on every yield, so any side-effects (such as + writing to a file) will be carried out regardless of what you do with + the return values. You do actually have to iterate through the output, though. - - ''' - + + """ + if output_cb is None: output_cb = ReferenceSpace.return_mask_cb - - for stid in structure_ids: - yield output_cb(stid, functools.partial(self.make_structure_mask, - [stid], direct_only)) + for stid in structure_ids: + yield output_cb(stid, functools.partial(self.make_structure_mask, [stid], direct_only)) def check_coverage(self, structure_ids, domain_mask): - '''Determines whether a spatial domain is completely covered by + """Determines whether a spatial domain is completely covered by structures in a set. - + Parameters ---------- - structure_ids : list of int + structure_ids : list of int Specifies the set of structures to check. domain_mask : numpy ndarray - Same shape as annotation. 1 inside the mask, 0 out. Specifies + Same shape as annotation. 1 inside the mask, 0 out. Specifies spatial domain. - + Returns ------- - numpy ndarray : - 1 where voxels are missing from the candidate, 0 where the + numpy ndarray : + 1 where voxels are missing from the candidate, 0 where the candidate exceeds the domain - - ''' - + + """ + candidate_mask = self.make_structure_mask(structure_ids) return domain_mask - candidate_mask - + def validate_structures(self, structure_ids, domain_mask): - '''Determines whether a set of structures produces an exact and + """Determines whether a set of structures produces an exact and nonoverlapping tiling of a spatial domain - + Parameters ---------- - structure_ids : list of int + structure_ids : list of int Specifies the set of structures to check. domain_mask : numpy ndarray - Same shape as annotation. 1 inside the mask, 0 out. Specifies + Same shape as annotation. 1 inside the mask, 0 out. Specifies spatial domain. - + Returns ------- - set : - Ids of structures that are the ancestors of other structures in + set : + Ids of structures that are the ancestors of other structures in the supplied set. - numpy ndarray : + numpy ndarray : Indicator for missing voxels. - - ''' - - return [self.structure_tree.has_overlaps(structure_ids), - self.check_coverage(structure_ids, domain_mask)] - - + + """ + + return [self.structure_tree.has_overlaps(structure_ids), self.check_coverage(structure_ids, domain_mask)] + def downsample(self, target_resolution): - '''Obtain a smaller reference space by downsampling - + """Obtain a smaller reference space by downsampling + Parameters ---------- target_resolution : tuple of numeric Resolution in microns of the output space. interpolator : string - Method used to interpolate the volume. Currently only 'nearest' + Method used to interpolate the volume. Currently only 'nearest' is supported - + Returns ------- - ReferenceSpace : - A new ReferenceSpace with the same structure tree and a + ReferenceSpace : + A new ReferenceSpace with the same structure tree and a downsampled annotation. - - ''' - - factors = [ float(ii / jj) for ii, jj in zip(self.resolution, - target_resolution)] - + + """ + + factors = [float(ii / jj) for ii, jj in zip(self.resolution, target_resolution)] + target = zoom(self.annotation, factors, order=0) - + return ReferenceSpace(self.structure_tree, target, target_resolution) - - + def get_slice_image(self, axis, position, cmap=None): - '''Produce a AxBx3 RGB image from a slice in the annotation - + """Produce a AxBx3 RGB image from a slice in the annotation + Parameters ---------- axis : int - Along which to slice the annotation volume. 0 is coronal, 1 is + Along which to slice the annotation volume. 0 is coronal, 1 is horizontal, and 2 is sagittal. - position : int + position : int In microns. Take the slice from this far along the specified axis. cmap : dict, optional - Keys are structure ids, values are rgb triplets. Defaults to - structure rgb_triplets. - + Keys are structure ids, values are rgb triplets. Defaults to + structure rgb_triplets. + Returns ------- - np.ndarray : - RGB image array. - + np.ndarray : + RGB image array. + Notes ----- - If you assign a custom colormap, make sure that you take care of the + If you assign a custom colormap, make sure that you take care of the background in addition to the structures. - - ''' - + + """ + if cmap is None: cmap = self.structure_tree.get_colormap() cmap[0] = [0, 0, 0] - + position = int(np.around(position / self.resolution[axis])) image = np.squeeze(self.annotation.take([position], axis=axis)) - - return np.reshape([cmap[point] for point in image.flat], - list(image.shape) + [3]).astype(np.uint8) - - + + return np.reshape([cmap[point] for point in image.flat], list(image.shape) + [3]).astype(np.uint8) + def export_itksnap_labels(self, id_type=np.uint16, label_description_kwargs=None): - '''Produces itksnap labels, remapping large ids if needed. + """Produces itksnap labels, remapping large ids if needed. Parameters ---------- @@ -345,36 +330,35 @@ def export_itksnap_labels(self, id_type=np.uint16, label_description_kwargs=None Returns ------- - np.ndarray : + np.ndarray : Annotation volume, remapped if needed pd.DataFrame label_description dataframe - ''' + """ if label_description_kwargs is None: label_description_kwargs = {} label_description = self.structure_tree.export_label_description(**label_description_kwargs) - if np.any(label_description['IDX'].values > np.iinfo(id_type).max): - label_description = label_description.sort_values(by='LABEL') + if np.any(label_description["IDX"].values > np.iinfo(id_type).max): + label_description = label_description.sort_values(by="LABEL") label_description = label_description.reset_index(drop=True) new_annotation = np.zeros(self.annotation.shape, dtype=id_type) id_map = {} - for ii, idx in enumerate(label_description['IDX'].values): + for ii, idx in enumerate(label_description["IDX"].values): id_map[idx] = ii + 1 new_annotation[self.annotation == idx] = ii + 1 - label_description['IDX'] = label_description.apply(lambda row: id_map[row['IDX']], axis=1) + label_description["IDX"] = label_description.apply(lambda row: id_map[row["IDX"]], axis=1) return new_annotation, label_description return self.annotation, label_description - def write_itksnap_labels(self, annotation_path, label_path, **kwargs): - '''Generate a label file (nrrd) and a label_description file (csv) for use with ITKSnap + """Generate a label file (nrrd) and a label_description file (csv) for use with ITKSnap Parameters ---------- @@ -382,35 +366,30 @@ def write_itksnap_labels(self, annotation_path, label_path, **kwargs): write generated label file here label_path : str write generated label_description file here - **kwargs : + **kwargs : will be passed to self.export_itksnap_labels - ''' + """ annotation, labels = self.export_itksnap_labels(**kwargs) - nrrd.write(annotation_path, annotation, header={'spacings': self.resolution}) - labels.to_csv(label_path, sep=' ', index=False, header=False, quoting=csv.QUOTE_NONNUMERIC) - + nrrd.write(annotation_path, annotation, header={"spacings": self.resolution}) + labels.to_csv(label_path, sep=" ", index=False, header=False, quoting=csv.QUOTE_NONNUMERIC) @staticmethod def return_mask_cb(structure_id, fn): - '''A basic callback for many_structure_masks - ''' - + """A basic callback for many_structure_masks""" + return structure_id, fn() - - + @staticmethod def check_and_write(base_dir, structure_id, fn): - '''A many_structure_masks callback that writes the mask to a nrrd file + """A many_structure_masks callback that writes the mask to a nrrd file if the file does not already exist. - ''' - - mask_path = os.path.join(base_dir, - 'structure_{0}.nrrd'.format(structure_id)) - + """ + + mask_path = os.path.join(base_dir, "structure_{0}.nrrd".format(structure_id)) + if not os.path.exists(mask_path): nrrd.write(mask_path, fn()) - - return structure_id + return structure_id diff --git a/allensdk/core/reference_space_cache.py b/allensdk/core/reference_space_cache.py index 488b916b93..6889bc3be4 100644 --- a/allensdk/core/reference_space_cache.py +++ b/allensdk/core/reference_space_cache.py @@ -42,36 +42,30 @@ class ReferenceSpaceCache(Cache): - - REFERENCE_SPACE_VERSION_KEY = 'REFERENCE_SPACE_VERSION' - ANNOTATION_KEY = 'ANNOTATION' - TEMPLATE_KEY = 'TEMPLATE' - STRUCTURES_KEY = 'STRUCTURES' - STRUCTURE_TREE_KEY = 'STRUCTURE_TREE' - STRUCTURE_MASK_KEY = 'STRUCTURE_MASK' - STRUCTURE_MESH_KEY = 'STRUCTURE_MESH' + REFERENCE_SPACE_VERSION_KEY = "REFERENCE_SPACE_VERSION" + ANNOTATION_KEY = "ANNOTATION" + TEMPLATE_KEY = "TEMPLATE" + STRUCTURES_KEY = "STRUCTURES" + STRUCTURE_TREE_KEY = "STRUCTURE_TREE" + STRUCTURE_MASK_KEY = "STRUCTURE_MASK" + STRUCTURE_MESH_KEY = "STRUCTURE_MESH" MANIFEST_VERSION = 1.2 - def __init__(self, - resolution, - reference_space_key, - **kwargs): - - if 'version' not in kwargs: - kwargs['version'] = self.MANIFEST_VERSION + def __init__(self, resolution, reference_space_key, **kwargs): + if "version" not in kwargs: + kwargs["version"] = self.MANIFEST_VERSION - if 'base_uri' not in kwargs: - kwargs['base_uri'] = None + if "base_uri" not in kwargs: + kwargs["base_uri"] = None super(ReferenceSpaceCache, self).__init__(**kwargs) self.resolution = resolution - self.reference_space_key = reference_space_key - - self.api = ReferenceSpaceApi(base_uri=kwargs['base_uri']) + self.reference_space_key = reference_space_key + + self.api = ReferenceSpaceApi(base_uri=kwargs["base_uri"]) - def get_annotation_volume(self, file_name=None): """ Read the annotation volume. Download it first if it doesn't exist. @@ -86,18 +80,14 @@ def get_annotation_volume(self, file_name=None): """ - file_name = self.get_cache_path( - file_name, self.ANNOTATION_KEY, self.reference_space_key, self.resolution) + file_name = self.get_cache_path(file_name, self.ANNOTATION_KEY, self.reference_space_key, self.resolution) annotation, info = self.api.download_annotation_volume( - self.reference_space_key, - self.resolution, - file_name, - strategy='lazy') + self.reference_space_key, self.resolution, file_name, strategy="lazy" + ) return annotation, info - def get_template_volume(self, file_name=None): """ Read the template volume. Download it first if it doesn't exist. @@ -112,19 +102,15 @@ def get_template_volume(self, file_name=None): """ - file_name = self.get_cache_path( - file_name, self.TEMPLATE_KEY, self.resolution) + file_name = self.get_cache_path(file_name, self.TEMPLATE_KEY, self.resolution) - template, info = self.api.download_template_volume(self.resolution, - file_name, - strategy='lazy') + template, info = self.api.download_template_volume(self.resolution, file_name, strategy="lazy") return template, info - def get_structure_tree(self, file_name=None, structure_graph_id=1): """ - Read the list of adult mouse structures and return an StructureTree + Read the list of adult mouse structures and return an StructureTree instance. Parameters @@ -137,43 +123,44 @@ def get_structure_tree(self, file_name=None, structure_graph_id=1): structure_graph_id: int Build a tree using structure only from the identified structure graph. """ - + file_name = self.get_cache_path(file_name, self.STRUCTURE_TREE_KEY) return OntologiesApi(self.api.api_url).get_structures_with_sets( - strategy='lazy', + strategy="lazy", path=file_name, pre=StructureTree.clean_structures, - post=lambda x: StructureTree(StructureTree.clean_structures(x)), + post=lambda x: StructureTree(StructureTree.clean_structures(x)), structure_graph_ids=structure_graph_id, - **Cache.cache_json()) + **Cache.cache_json(), + ) - - def get_reference_space(self, structure_file_name=None, - annotation_file_name=None): + def get_reference_space(self, structure_file_name=None, annotation_file_name=None): """ - Build a ReferenceSpace from this cache's annotation volume and - structure tree. The ReferenceSpace does operations that relate brain + Build a ReferenceSpace from this cache's annotation volume and + structure tree. The ReferenceSpace does operations that relate brain structures to spatial domains. - + Parameters ---------- - + structure_file_name: string File name to save/read the structures table. If file_name is None, the file_name will be pulled out of the manifest. If caching is disabled, no file will be saved. Default is None. - + annotation_file_name: string File name to store the annotation volume. If it already exists, it will be read from this file. If file_name is None, the file_name will be pulled out of the manifest. Default is None. - + """ - - return ReferenceSpace(self.get_structure_tree(structure_file_name), - self.get_annotation_volume(annotation_file_name)[0], - [self.resolution] * 3) + + return ReferenceSpace( + self.get_structure_tree(structure_file_name), + self.get_annotation_volume(annotation_file_name)[0], + [self.resolution] * 3, + ) def get_structure_mask(self, structure_id, file_name=None, annotation_file_name=None): """ @@ -182,9 +169,9 @@ def get_structure_mask(self, structure_id, file_name=None, annotation_file_name= Notes ----- - This method downloads structure masks from the Allen Institute. To make your own locally, see + This method downloads structure masks from the Allen Institute. To make your own locally, see ReferenceSpace.many_structure_masks. - + Parameters ---------- @@ -204,19 +191,16 @@ def get_structure_mask(self, structure_id, file_name=None, annotation_file_name= structure_id = ReferenceSpaceCache.validate_structure_id(structure_id) file_name = self.get_cache_path( - file_name, self.STRUCTURE_MASK_KEY, self.reference_space_key, - self.resolution, structure_id) - - return self.api.download_structure_mask(structure_id, - self.reference_space_key, - self.resolution, - file_name, - strategy='lazy') + file_name, self.STRUCTURE_MASK_KEY, self.reference_space_key, self.resolution, structure_id + ) + return self.api.download_structure_mask( + structure_id, self.reference_space_key, self.resolution, file_name, strategy="lazy" + ) def get_structure_mesh(self, structure_id, file_name=None): """Obtain a 3D mesh specifying the surface of an annotated structure. - + Parameters ----------- structure_id: int @@ -234,29 +218,24 @@ def get_structure_mesh(self, structure_id, file_name=None): vertex_normals : np.ndarray Dimensions are (nSample, nElements=3). Vectors normal to vertices. face_vertices : np.ndarray - Dimensions are (sample, nVertices=3). References are given in indices + Dimensions are (sample, nVertices=3). References are given in indices (0-indexed here, but 1-indexed in the file) of vertices that make up each face. face_normals : np.ndarray - Dimensions are (sample, nNormals=3). References are given in indices + Dimensions are (sample, nNormals=3). References are given in indices (0-indexed here, but 1-indexed in the file) of vertex normals that make up each face. Notes ----- - These meshes are meant for 3D visualization and as such have been smoothed. - If you are interested in performing quantative analyses, we recommend that you + These meshes are meant for 3D visualization and as such have been smoothed. + If you are interested in performing quantative analyses, we recommend that you use the structure masks instead. """ structure_id = ReferenceSpaceCache.validate_structure_id(structure_id) - file_name = self.get_cache_path( - file_name, self.STRUCTURE_MESH_KEY, self.reference_space_key, structure_id) - - return self.api.download_structure_mesh(structure_id, - self.reference_space_key, - file_name, - strategy='lazy') + file_name = self.get_cache_path(file_name, self.STRUCTURE_MESH_KEY, self.reference_space_key, structure_id) + return self.api.download_structure_mesh(structure_id, self.reference_space_key, file_name, strategy="lazy") def add_manifest_paths(self, manifest_builder): """ @@ -271,45 +250,35 @@ def add_manifest_paths(self, manifest_builder): """ manifest_builder = super(ReferenceSpaceCache, self).add_manifest_paths(manifest_builder) - - manifest_builder.add_path(self.STRUCTURE_TREE_KEY, - 'structures.json', - parent_key='BASEDIR', - typename='file') - - manifest_builder.add_path(self.REFERENCE_SPACE_VERSION_KEY, - '%s', - parent_key='BASEDIR', - typename='dir') - - manifest_builder.add_path(self.ANNOTATION_KEY, - 'annotation_%d.nrrd', - parent_key=self.REFERENCE_SPACE_VERSION_KEY, - typename='file') - - manifest_builder.add_path(self.TEMPLATE_KEY, - 'average_template_%d.nrrd', - parent_key='BASEDIR', - typename='file') - - manifest_builder.add_path(self.STRUCTURE_MASK_KEY, - 'structure_masks/resolution_%d/structure_%d.nrrd', - parent_key=self.REFERENCE_SPACE_VERSION_KEY, - typename='file') - - manifest_builder.add_path(self.STRUCTURE_MESH_KEY, - 'structure_meshes/structure_%d.obj', - parent_key=self.REFERENCE_SPACE_VERSION_KEY, - typename='file') - return manifest_builder + manifest_builder.add_path(self.STRUCTURE_TREE_KEY, "structures.json", parent_key="BASEDIR", typename="file") + + manifest_builder.add_path(self.REFERENCE_SPACE_VERSION_KEY, "%s", parent_key="BASEDIR", typename="dir") + + manifest_builder.add_path( + self.ANNOTATION_KEY, "annotation_%d.nrrd", parent_key=self.REFERENCE_SPACE_VERSION_KEY, typename="file" + ) + manifest_builder.add_path(self.TEMPLATE_KEY, "average_template_%d.nrrd", parent_key="BASEDIR", typename="file") + + manifest_builder.add_path( + self.STRUCTURE_MASK_KEY, + "structure_masks/resolution_%d/structure_%d.nrrd", + parent_key=self.REFERENCE_SPACE_VERSION_KEY, + typename="file", + ) + + manifest_builder.add_path( + self.STRUCTURE_MESH_KEY, + "structure_meshes/structure_%d.obj", + parent_key=self.REFERENCE_SPACE_VERSION_KEY, + typename="file", + ) + + return manifest_builder - - @classmethod def validate_structure_id(cls, structure_id): - try: structure_id = int(structure_id) except ValueError: @@ -317,10 +286,8 @@ def validate_structure_id(cls, structure_id): return structure_id - @classmethod def validate_structure_ids(cls, structure_ids): - for ii, sid in enumerate(structure_ids): structure_ids[ii] = cls.validate_structure_id(sid) diff --git a/allensdk/core/simple_tree.py b/allensdk/core/simple_tree.py index 1618e4f7f0..891c7764f4 100644 --- a/allensdk/core/simple_tree.py +++ b/allensdk/core/simple_tree.py @@ -39,40 +39,38 @@ from allensdk.deprecated import deprecated -class SimpleTree( object ): - def __init__(self, nodes, - node_id_cb, - parent_id_cb): - '''A tree structure - +class SimpleTree(object): + def __init__(self, nodes, node_id_cb, parent_id_cb): + """A tree structure + Parameters ---------- nodes : list of dict - Each dict is a node in the tree. The keys of the dict name the + Each dict is a node in the tree. The keys of the dict name the properties of the node and should be consistent across nodes. node_id_cb : function | node dict -> node id - Calling node_id_cb on a node dictionary ought to produce a unique - identifier for that node (we call this the node's id). The type - of the node id is up to you, but ought to be consistent across + Calling node_id_cb on a node dictionary ought to produce a unique + identifier for that node (we call this the node's id). The type + of the node id is up to you, but ought to be consistent across nodes and must be hashable. parent_id_cb : function | node_dict => parent node's id As node_id_cb, but returns the id of the node's parent. - + Notes ----- - It is easy to pass a pandas DataFrame as the nodes. Just use the + It is easy to pass a pandas DataFrame as the nodes. Just use the to_dict method of the dataframe like so: list_of_dict = your_dataframe.to_dict('record') your_tree = SimpleTree(list_of_dict, ...) - Converting a list of dictionaries to a pandas DataFrame is also very + Converting a list of dictionaries to a pandas DataFrame is also very easy. The DataFrame constructor does it for you: your_dataframe = pandas.DataFrame(list_of_dict) - - ''' - self._nodes = { node_id_cb(n):n for n in nodes } - self._parent_ids = { nid:parent_id_cb(n) for nid,n in self._nodes.items() } - self._child_ids = { nid:[] for nid in self._nodes } + """ + + self._nodes = {node_id_cb(n): n for n in nodes} + self._parent_ids = {nid: parent_id_cb(n) for nid, n in self._nodes.items()} + self._child_ids = {nid: [] for nid in self._nodes} for nid in self._parent_ids: pid = self._parent_ids[nid] @@ -82,317 +80,301 @@ def __init__(self, nodes, self.node_id_cb = node_id_cb self.parent_id_cb = parent_id_cb - def filter_nodes(self, criterion): - '''Obtain a list of nodes filtered by some criterion - + """Obtain a list of nodes filtered by some criterion + Parameters ---------- criterion : function | node dict => bool Only nodes for which criterion returns true will be returned. - + Returns ------- list of dict : Items are node dictionaries that passed the filter. - - ''' - + + """ + return list(filter(criterion, self._nodes.values())) - def value_map(self, from_fn, to_fn): - '''Obtain a look-up table relating a pair of node properties across + """Obtain a look-up table relating a pair of node properties across nodes - + Parameters ---------- from_fn : function | node dict => hashable value - The keys of the output dictionary will be obtained by calling + The keys of the output dictionary will be obtained by calling from_fn on each node. Should be unique. to_fn : function | node_dict => value - The values of the output function will be obtained by calling + The values of the output function will be obtained by calling to_fn on each node. - + Returns ------- dict : - Maps the node property defined by from_fn to the node property + Maps the node property defined by from_fn to the node property defined by to_fn across nodes. - ''' - + """ + vm = {} for node in self._nodes.values(): key = from_fn(node) value = to_fn(node) - + if key in vm: - raise RuntimeError('from_fn is not unique across nodes. ' - 'Collision between {0} and {1}.'.format(value, vm[key])) + raise RuntimeError( + "from_fn is not unique across nodes. Collision between {0} and {1}.".format(value, vm[key]) + ) vm[key] = value - - return vm + return vm def nodes_by_property(self, key, values, to_fn=None): - '''Get nodes by a specified property + """Get nodes by a specified property Parameters ---------- key : hashable or function - The property used for lookup. Should be unique. If a function, will + The property used for lookup. Should be unique. If a function, will be invoked on each node. values : list Select matching elements from the lookup. to_fn : function, optional - Defines the outputs, on a per-node basis. Defaults to returning + Defines the outputs, on a per-node basis. Defaults to returning the whole node. - + Returns ------- - list : + list : outputs, 1 for each input value. - ''' + """ if to_fn is None: + def to_fn(x): return x - if not callable( key ): + if not callable(key): + def from_fn(x): return x[key] else: from_fn = key - value_map = self.value_map( from_fn, to_fn ) - return [ value_map[vv] for vv in values ] - + value_map = self.value_map(from_fn, to_fn) + return [value_map[vv] for vv in values] def node_ids(self): - '''Obtain the node ids of each node in the tree - + """Obtain the node ids of each node in the tree + Returns ------- list : - elements are node ids - - ''' - + elements are node ids + + """ + return list(self._nodes) - @deprecated("Use SimpleTree.parent_ids instead.") def parent_id(self, node_ids): return self.parent_ids(node_ids) - def parent_ids(self, node_ids): - '''Obtain the ids of one or more nodes' parents - + """Obtain the ids of one or more nodes' parents + Parameters ---------- node_ids : list of hashable Items are ids of nodes whose parents you wish to find. - + Returns ------- - list of hashable : + list of hashable : Items are ids of input nodes' parents in order. - - ''' - - return [ self._parent_ids[nid] for nid in node_ids ] - + """ + + return [self._parent_ids[nid] for nid in node_ids] + def child_ids(self, node_ids): - '''Obtain the ids of one or more nodes' children - + """Obtain the ids of one or more nodes' children + Parameters ---------- node_ids : list of hashable Items are ids of nodes whose children you wish to find. - + Returns ------- - list of list of hashable : + list of list of hashable : Items are lists of input nodes' children's ids. - - ''' - - return [ self._child_ids[nid] for nid in node_ids ] + """ + + return [self._child_ids[nid] for nid in node_ids] def ancestor_ids(self, node_ids): - '''Obtain the ids of one or more nodes' ancestors - + """Obtain the ids of one or more nodes' ancestors + Parameters ---------- node_ids : list of hashable Items are ids of nodes whose ancestors you wish to find. - + Returns ------- - list of list of hashable : + list of list of hashable : Items are lists of input nodes' ancestors' ids. - + Notes ----- Given the tree: A -> B -> C `-> D - - The ancestors of C are [C, B, A]. The ancestors of A are [A]. The + + The ancestors of C are [C, B, A]. The ancestors of A are [A]. The ancestors of D are [D, A] - - ''' - + + """ + out = [] for nid in node_ids: - current = [nid] while current[-1] is not None: current.extend(self.parent_ids([current[-1]])) out.append(current[:-1]) - + return out - - + def descendant_ids(self, node_ids): - '''Obtain the ids of one or more nodes' descendants - + """Obtain the ids of one or more nodes' descendants + Parameters ---------- node_ids : list of hashable Items are ids of nodes whose descendants you wish to find. - + Returns ------- - list of list of hashable : + list of list of hashable : Items are lists of input nodes' descendants' ids. - + Notes ----- Given the tree: A -> B -> C `-> D - + The descendants of A are [B, C, D]. The descendants of C are []. - - ''' - + + """ + out = [] for ii, nid in enumerate(node_ids): - current = [nid] children = self.child_ids([nid])[0] - + if children: - current.extend(functools.reduce(op.add, map(list, - self.descendant_ids(children)))) - + current.extend(functools.reduce(op.add, map(list, self.descendant_ids(children)))) + out.append(current) return out - @deprecated("Use SimpleTree.nodes instead") def node(self, node_ids=None): return self.nodes(node_ids) - def nodes(self, node_ids=None): - '''Get one or more nodes' full dictionaries from their ids. - + """Get one or more nodes' full dictionaries from their ids. + Parameters ---------- node_ids : list of hashable Items are ids of nodes to be returned. Default is all. - + Returns ------- - list of dict : + list of dict : Items are nodes corresponding to argued ids. - ''' - + """ + if node_ids is None: node_ids = self.node_ids() - - return [ self._nodes[nid] if nid in self._nodes else None for nid in node_ids] + return [self._nodes[nid] if nid in self._nodes else None for nid in node_ids] @deprecated("Use SimpleTree.parents instead") def parent(self, node_ids): return self.parents(node_ids) - def parents(self, node_ids): - '''Get one or mode nodes' parent nodes - + """Get one or mode nodes' parent nodes + Parameters ---------- node_ids : list of hashable Items are ids of nodes whose parents will be found. - + Returns ------- - list of dict : + list of dict : Items are parents of nodes corresponding to argued ids. - - ''' - - return self.nodes([self._parent_ids[nid] for nid in node_ids]) + """ + + return self.nodes([self._parent_ids[nid] for nid in node_ids]) def children(self, node_ids): - '''Get one or mode nodes' child nodes - + """Get one or mode nodes' child nodes + Parameters ---------- node_ids : list of hashable Items are ids of nodes whose children will be found. - + Returns ------- - list of list of dict : + list of list of dict : Items are lists of child nodes corresponding to argued ids. - - ''' - - return list(map(self.nodes, self.child_ids(node_ids))) + """ + + return list(map(self.nodes, self.child_ids(node_ids))) def descendants(self, node_ids): - '''Get one or mode nodes' descendant nodes - + """Get one or mode nodes' descendant nodes + Parameters ---------- node_ids : list of hashable Items are ids of nodes whose descendants will be found. - + Returns ------- - list of list of dict : + list of list of dict : Items are lists of descendant nodes corresponding to argued ids. - - ''' - + + """ + return list(map(self.nodes, self.descendant_ids(node_ids))) - def ancestors(self, node_ids): - '''Get one or mode nodes' ancestor nodes - + """Get one or mode nodes' ancestor nodes + Parameters ---------- node_ids : list of hashable Items are ids of nodes whose ancestors will be found. - + Returns ------- - list of list of dict : + list of list of dict : Items are lists of ancestor nodes corresponding to argued ids. - - ''' - + + """ + return list(map(self.nodes, self.ancestor_ids(node_ids))) diff --git a/allensdk/core/sitk_utilities.py b/allensdk/core/sitk_utilities.py index 598f377f53..01295df022 100644 --- a/allensdk/core/sitk_utilities.py +++ b/allensdk/core/sitk_utilities.py @@ -41,7 +41,7 @@ def get_sitk_image_information(image): - ''' Extract information about a SimpleITK image + """Extract information about a SimpleITK image Parameters ---------- @@ -50,71 +50,73 @@ def get_sitk_image_information(image): Returns ------- - dict : - Extracted information. Includes spacing, origin, size, direction, and + dict : + Extracted information. Includes spacing, origin, size, direction, and number of components per pixel - ''' + """ - return {'spacing': image.GetSpacing(), - 'origin': image.GetOrigin(), - 'size': image.GetSize(), - 'direction': image.GetDirection(), - 'ncomponents': image.GetNumberOfComponentsPerPixel()} + return { + "spacing": image.GetSpacing(), + "origin": image.GetOrigin(), + "size": image.GetSize(), + "direction": image.GetDirection(), + "ncomponents": image.GetNumberOfComponentsPerPixel(), + } def set_sitk_image_information(image, information): - ''' Set information on a SimpleITK image + """Set information on a SimpleITK image Parameters ---------- image : sitk.Image Set information on this image. information : dict - Stores information to be set. Supports spacing, origin, direction. Also + Stores information to be set. Supports spacing, origin, direction. Also checks (but cannot set) size and number of components per pixel - ''' + """ - if 'spacing' in information: - image.SetSpacing(information.pop('spacing')) - if 'origin' in information: - image.SetOrigin(information.pop('origin')) - if 'direction' in information: - image.SetDirection(information.pop('direction')) + if "spacing" in information: + image.SetSpacing(information.pop("spacing")) + if "origin" in information: + image.SetOrigin(information.pop("origin")) + if "direction" in information: + image.SetDirection(information.pop("direction")) + + if "size" in information: + assert np.array_equal(information.pop("size"), image.GetSize()) + if "ncomponents" in information: + assert information.pop("ncomponents") == image.GetNumberOfComponentsPerPixel() - if 'size' in information: - assert(np.array_equal( information.pop('size'), image.GetSize() )) - if 'ncomponents' in information: - assert( information.pop('ncomponents') == image.GetNumberOfComponentsPerPixel() ) - if not len(information) == 0: - warnings.warn('unwritten keys: {}'.format(','.join(information.keys()))) + warnings.warn("unwritten keys: {}".format(",".join(information.keys()))) def fix_array_dimensions(array, ncomponents=1): - ''' Convenience function that reorders ndarray dimensions for io with SimpleITK + """Convenience function that reorders ndarray dimensions for io with SimpleITK Parameters ---------- array : np.ndarray The array to be reordered ncomponents : int, optional - Number of components per pixel, default 1. + Number of components per pixel, default 1. Returns ------- - np.ndarray : + np.ndarray : Reordered array - ''' + """ act_size = list(array.shape) ndims = len(act_size) multicomponent = ncomponents > 1 - from_order = list(range( ndims - multicomponent )) - to_order = list(range( ndims - multicomponent ))[::-1] + from_order = list(range(ndims - multicomponent)) + to_order = list(range(ndims - multicomponent))[::-1] if multicomponent: from_order += [-1] @@ -124,13 +126,13 @@ def fix_array_dimensions(array, ncomponents=1): def read_ndarray_with_sitk(path): - ''' Read a numpy array from a file using SimpleITK + """Read a numpy array from a file using SimpleITK Parameters ---------- path : str Read from this path - + Returns ------- image : np.ndarray @@ -138,18 +140,18 @@ def read_ndarray_with_sitk(path): information : dict Additional information about the array - ''' + """ image = sitk.ReadImage(str(path)) information = get_sitk_image_information(image) image = sitk.GetArrayFromImage(image) - image = fix_array_dimensions(image, information['ncomponents']) + image = fix_array_dimensions(image, information["ncomponents"]) return image, information def write_ndarray_with_sitk(array, path, **information): - ''' Write a numpy array to a file using SimpleITK + """Write a numpy array to a file using SimpleITK Parameters ---------- @@ -158,17 +160,17 @@ def write_ndarray_with_sitk(array, path, **information): path : str Write to here **information : dict - Contains additional information to be stored in the image file. + Contains additional information to be stored in the image file. See set_sitk_image_information for more information. - ''' + """ - if 'ncomponents' not in information: - information['ncomponents'] = 1 - ncomponents = information.pop('ncomponents') + if "ncomponents" not in information: + information["ncomponents"] = 1 + ncomponents = information.pop("ncomponents") array = fix_array_dimensions(array, ncomponents) - + array = sitk.GetImageFromArray(array, ncomponents > 1) set_sitk_image_information(array, information) diff --git a/allensdk/core/structure_tree.py b/allensdk/core/structure_tree.py index 5832ba7b56..ff3f8f685f 100644 --- a/allensdk/core/structure_tree.py +++ b/allensdk/core/structure_tree.py @@ -43,17 +43,16 @@ from .simple_tree import SimpleTree -class StructureTree( SimpleTree ): - +class StructureTree(SimpleTree): def __init__(self, nodes): - '''A tree whose nodes are brain structures and whose edges indicate + """A tree whose nodes are brain structures and whose edges indicate physical containment. - + Parameters ---------- nodes : list of dict Each specifies a structure. Fields are: - + 'acronym' : str Abbreviated name for the structure. 'rgb_triplet' : str @@ -67,217 +66,200 @@ def __init__(self, nodes): 'name' : str Full name of structure. 'structure_id_path' : list of int - This structure's ancestors (inclusive) from the root of the + This structure's ancestors (inclusive) from the root of the tree. 'structure_set_ids' : list of int - Unique identifiers of structure sets to which this structure + Unique identifiers of structure sets to which this structure belongs. - - ''' - - super(StructureTree, self).__init__(nodes, - lambda s: int(s['id']), - lambda s: s['structure_id_path'][-2] \ - if len(s['structure_id_path']) > 1 \ - and s['structure_id_path'] is not None \ - and np.isfinite(s['structure_id_path'][-2]) \ - else None) - - + + """ + + super(StructureTree, self).__init__( + nodes, + lambda s: int(s["id"]), + lambda s: s["structure_id_path"][-2] + if len(s["structure_id_path"]) > 1 + and s["structure_id_path"] is not None + and np.isfinite(s["structure_id_path"][-2]) + else None, + ) + def get_structures_by_id(self, structure_ids): - '''Obtain a list of brain structures from their structure ids - + """Obtain a list of brain structures from their structure ids + Parameters ---------- structure_ids : list of int Get structures corresponding to these ids. - + Returns ------- - list of dict : + list of dict : Each item describes a structure. - - ''' - + + """ + return self.nodes(structure_ids) - - + def get_structures_by_name(self, names): - '''Obtain a list of brain structures from their names, - + """Obtain a list of brain structures from their names, + Parameters ---------- names : list of str Get structures corresponding to these names. - + Returns ------- - list of dict : + list of dict : Each item describes a structure. - - ''' - - return self.nodes_by_property('name', names) - - + + """ + + return self.nodes_by_property("name", names) + def get_structures_by_acronym(self, acronyms): - '''Obtain a list of brain structures from their acronyms - + """Obtain a list of brain structures from their acronyms + Parameters ---------- names : list of str Get structures corresponding to these acronyms. - + Returns ------- - list of dict : + list of dict : Each item describes a structure. - - ''' - - return self.nodes_by_property('acronym', acronyms) - - + + """ + + return self.nodes_by_property("acronym", acronyms) + def get_structures_by_set_id(self, structure_set_ids): - '''Obtain a list of brain structures from by the sets that contain + """Obtain a list of brain structures from by the sets that contain them. - + Parameters ---------- structure_set_ids : list of int Get structures belonging to these structure sets. - + Returns ------- - list of dict : + list of dict : Each item describes a structure. - - ''' - + + """ + def overlap(x): - return (set(structure_set_ids) & set(x['structure_set_ids'])) + return set(structure_set_ids) & set(x["structure_set_ids"]) + return self.filter_nodes(overlap) - - + def get_colormap(self): - '''Get a dictionary mapping structure ids to colors across all nodes. - + """Get a dictionary mapping structure ids to colors across all nodes. + Returns ------- - dict : + dict : Keys are structure ids. Values are RGB lists of integers. - - ''' - - return self.value_map(lambda x: x['id'], - lambda y: y['rgb_triplet']) - - - + + """ + + return self.value_map(lambda x: x["id"], lambda y: y["rgb_triplet"]) + def get_name_map(self): - '''Get a dictionary mapping structure ids to names across all nodes. - + """Get a dictionary mapping structure ids to names across all nodes. + Returns ------- - dict : + dict : Keys are structure ids. Values are structure name strings. - - ''' - - return self.value_map(lambda x: x['id'], - lambda y: y['name']) - - + + """ + + return self.value_map(lambda x: x["id"], lambda y: y["name"]) + def get_id_acronym_map(self): - '''Get a dictionary mapping structure acronyms to ids across all nodes. - + """Get a dictionary mapping structure acronyms to ids across all nodes. + Returns ------- - dict : + dict : Keys are structure acronyms. Values are structure ids. - - ''' - - return self.value_map(lambda x: x['acronym'], - lambda y: y['id']) - - + + """ + + return self.value_map(lambda x: x["acronym"], lambda y: y["id"]) + def get_ancestor_id_map(self): - '''Get a dictionary mapping structure ids to ancestor ids across all - nodes. - + """Get a dictionary mapping structure ids to ancestor ids across all + nodes. + Returns ------- - dict : + dict : Keys are structure ids. Values are lists of ancestor ids. - - ''' - return self.value_map(lambda x: x['id'], - lambda y: self.ancestor_ids([y['id']])[0]) - - + """ + + return self.value_map(lambda x: x["id"], lambda y: self.ancestor_ids([y["id"]])[0]) + def structure_descends_from(self, child_id, parent_id): - '''Tests whether one structure descends from another. - + """Tests whether one structure descends from another. + Parameters ---------- child_id : int Id of the putative child structure. parent_id : int Id of the putative parent structure. - + Returns ------- bool : - True if the structure specified by child_id is a descendant of + True if the structure specified by child_id is a descendant of the one specified by parent_id. Otherwise False. - - ''' - + + """ + return parent_id in self.ancestor_ids([child_id])[0] - - + def get_structure_sets(self): - '''Lists all unique structure sets that are assigned to at least one + """Lists all unique structure sets that are assigned to at least one structure in the tree. - + Returns ------- - list of int : + list of int : Elements are ids of structure sets. - - ''' - - return set(functools.reduce(op.add, map(lambda x: x['structure_set_ids'], - self.nodes()))) - - + + """ + + return set(functools.reduce(op.add, map(lambda x: x["structure_set_ids"], self.nodes()))) + def has_overlaps(self, structure_ids): - '''Determine if a list of structures contains structures along with + """Determine if a list of structures contains structures along with their ancestors - + Parameters ---------- structure_ids : list of int Check this set of structures for overlaps - + Returns ------- - set : - Ids of structures that are the ancestors of other structures in + set : + Ids of structures that are the ancestors of other structures in the supplied set. - - ''' - - ancestor_ids = functools.reduce(op.add, - map(lambda x: x[1:], - self.ancestor_ids(structure_ids))) - return (set(ancestor_ids) & set(structure_ids)) - - - def export_label_description(self, alphas=None, exclude_label_vis=None, exclude_mesh_vis=None, label_key='acronym'): - '''Produces an itksnap label_description table from this structure tree + + """ + + ancestor_ids = functools.reduce(op.add, map(lambda x: x[1:], self.ancestor_ids(structure_ids))) + return set(ancestor_ids) & set(structure_ids) + + def export_label_description(self, alphas=None, exclude_label_vis=None, exclude_mesh_vis=None, label_key="acronym"): + """Produces an itksnap label_description table from this structure tree Parameters ---------- @@ -292,10 +274,10 @@ def export_label_description(self, alphas=None, exclude_label_vis=None, exclude_ Returns ------- - pd.DataFrame : + pd.DataFrame : Contains data needed for loading as an ITKSnap label description file. - ''' + """ if alphas is None: alphas = {} @@ -304,50 +286,51 @@ def export_label_description(self, alphas=None, exclude_label_vis=None, exclude_ if exclude_mesh_vis is None: exclude_mesh_vis = set([]) - df = pd.DataFrame([ - { - 'IDX': node['id'], - '-R-': node['rgb_triplet'][0], - '-G-': node['rgb_triplet'][1], - '-B-': node['rgb_triplet'][2], - '-A-': alphas.get(node['id'], 1.0), - 'VIS': 1 if node['id'] not in exclude_label_vis else 0, - 'MSH': 1 if node['id'] not in exclude_mesh_vis else 0, - 'LABEL': node[label_key] - } - for node in self.nodes() - ]).loc[:, ('IDX', '-R-', '-G-', '-B-', '-A-', 'VIS', 'MSH', 'LABEL')] + df = pd.DataFrame( + [ + { + "IDX": node["id"], + "-R-": node["rgb_triplet"][0], + "-G-": node["rgb_triplet"][1], + "-B-": node["rgb_triplet"][2], + "-A-": alphas.get(node["id"], 1.0), + "VIS": 1 if node["id"] not in exclude_label_vis else 0, + "MSH": 1 if node["id"] not in exclude_mesh_vis else 0, + "LABEL": node[label_key], + } + for node in self.nodes() + ] + ).loc[:, ("IDX", "-R-", "-G-", "-B-", "-A-", "VIS", "MSH", "LABEL")] return df - @staticmethod def clean_structures(structures, whitelist=None, data_transforms=None, renames=None): - '''Convert structures_with_sets query results into a form that can be + """Convert structures_with_sets query results into a form that can be used to construct a StructureTree - + Parameters ---------- structures : list of dict - Each element describes a structure. Should have a structure id path + Each element describes a structure. Should have a structure id path field (str values) and a structure_sets field (list of dict). whitelist : list of str, optional - Only these fields will be included in the final structure record. Default is + Only these fields will be included in the final structure record. Default is the output of StructureTree.whitelist. data_transforms : dict, optional - Keys are str field names. Values are functions which will be applied to the - data associated with those fields. Default is to map colors from hex to rgb and + Keys are str field names. Values are functions which will be applied to the + data associated with those fields. Default is to map colors from hex to rgb and convert the structure id path to a list of int. renames : dict, optional - Controls the field names that appear in the output structure records. Default is + Controls the field names that appear in the output structure records. Default is to map 'color_hex_triplet' to 'rgb_triplet'. - + Returns ------- - list of dict : - structures, after conversion of structure_id_path and structure_sets - - ''' + list of dict : + structures, after conversion of structure_id_path and structure_sets + + """ if whitelist is None: whitelist = StructureTree.whitelist() @@ -355,21 +338,19 @@ def clean_structures(structures, whitelist=None, data_transforms=None, renames=N if data_transforms is None: data_transforms = StructureTree.data_transforms() - if renames is None: + if renames is None: renames = StructureTree.renames() whitelist.extend(renames.values()) for ii, st in enumerate(structures): - StructureTree.collect_sets(st) record = {} for name in whitelist: - if name not in st: continue data = st[name] - + if name in data_transforms: data = data_transforms[name](data) @@ -381,77 +362,78 @@ def clean_structures(structures, whitelist=None, data_transforms=None, renames=N structures[ii] = record return structures - + @staticmethod def data_transforms(): - return {'color_hex_triplet': StructureTree.hex_to_rgb, - 'structure_id_path': StructureTree.path_to_list} - + return {"color_hex_triplet": StructureTree.hex_to_rgb, "structure_id_path": StructureTree.path_to_list} @staticmethod def renames(): - return {'color_hex_triplet': 'rgb_triplet'} + return {"color_hex_triplet": "rgb_triplet"} @staticmethod def whitelist(): - return ['acronym', 'color_hex_triplet', 'graph_id', 'graph_order', 'id', - 'name', 'structure_id_path', 'structure_set_ids'] - - + return [ + "acronym", + "color_hex_triplet", + "graph_id", + "graph_order", + "id", + "name", + "structure_id_path", + "structure_set_ids", + ] + @staticmethod def hex_to_rgb(hex_color): - '''Convert a hexadecimal color string to a uint8 triplet - + """Convert a hexadecimal color string to a uint8 triplet + Parameters ---------- - hex_color : string - Must be 6 characters long, unless it is 7 long and the first - character is #. If hex_color is a triplet of int, it will be + hex_color : string + Must be 6 characters long, unless it is 7 long and the first + character is #. If hex_color is a triplet of int, it will be returned unchanged. - + Returns ------- - list of int : + list of int : 3 characters long - 1 per two characters in the input string. - - ''' - + + """ + if not isinstance(hex_color, str): return list(hex_color) - if hex_color[0] == '#': + if hex_color[0] == "#": hex_color = hex_color[1:] - - return [int(hex_color[a * 2: a*2 + 2], 16) for a in range(3)] - + + return [int(hex_color[a * 2 : a * 2 + 2], 16) for a in range(3)] @staticmethod def path_to_list(path): - '''Structure id paths are sometimes formatted as "/"-seperated strings. + """Structure id paths are sometimes formatted as "/"-seperated strings. This method converts them to a list of integers, if needed. - ''' + """ if not isinstance(path, str): return list(path) - return [int(stid) for stid in path.split('/') if stid != ''] - + return [int(stid) for stid in path.split("/") if stid != ""] @staticmethod def collect_sets(structure): - '''Structure sets may be specified by full records or id. This method - collects all of the structure set records/ids in a structure record and + """Structure sets may be specified by full records or id. This method + collects all of the structure set records/ids in a structure record and replaces them with a single list of id records. - ''' - - if 'structure_sets' not in structure: - structure['structure_sets'] = [] - if 'structure_set_ids' not in structure: - structure['structure_set_ids'] = [] - - structure['structure_set_ids'].extend([sts['id'] for sts - in structure['structure_sets']]) - structure['structure_set_ids'] = list(set(structure['structure_set_ids'])) - - del structure['structure_sets'] - + """ + + if "structure_sets" not in structure: + structure["structure_sets"] = [] + if "structure_set_ids" not in structure: + structure["structure_set_ids"] = [] + + structure["structure_set_ids"].extend([sts["id"] for sts in structure["structure_sets"]]) + structure["structure_set_ids"] = list(set(structure["structure_set_ids"])) + + del structure["structure_sets"] diff --git a/allensdk/core/swc.py b/allensdk/core/swc.py index 34f13740f2..fd2eed31b3 100644 --- a/allensdk/core/swc.py +++ b/allensdk/core/swc.py @@ -38,17 +38,17 @@ import math # Morphology nodes have the following fields. SWC fields are numeric. -NODE_ID = 'id' -NODE_TYPE = 'type' -NODE_X = 'x' -NODE_Y = 'y' -NODE_Z = 'z' -NODE_R = 'radius' -NODE_PN = 'parent' +NODE_ID = "id" +NODE_TYPE = "type" +NODE_X = "x" +NODE_Y = "y" +NODE_Z = "z" +NODE_R = "radius" +NODE_PN = "parent" SWC_COLUMNS = [NODE_ID, NODE_TYPE, NODE_X, NODE_Y, NODE_Z, NODE_R, NODE_PN] -NODE_TREE_ID = 'tree_id' -NODE_CHILDREN = 'children' +NODE_TREE_ID = "tree_id" +NODE_CHILDREN = "children" # shorthand for dictionary entries, to shorten sometimes long code lines _N = NODE_ID @@ -83,21 +83,23 @@ def read_swc(file_name, columns="NOT_USED", numeric_columns="NOT_USED"): with open(file_name, "r") as f: for line in f: # remove comments - if line.lstrip().startswith('#'): + if line.lstrip().startswith("#"): continue # read values. expected SWC format is: # ID, type, x, y, z, rad, parent # x, y, z and rad are floats. the others are ints - toks = line.split(' ') - vals = Compartment({ - NODE_ID: int(toks[0]), - NODE_TYPE: int(toks[1]), - NODE_X: float(toks[2]), - NODE_Y: float(toks[3]), - NODE_Z: float(toks[4]), - NODE_R: float(toks[5]), - NODE_PN: int(toks[6].rstrip()) - }) + toks = line.split(" ") + vals = Compartment( + { + NODE_ID: int(toks[0]), + NODE_TYPE: int(toks[1]), + NODE_X: float(toks[2]), + NODE_Y: float(toks[3]), + NODE_Z: float(toks[4]), + NODE_R: float(toks[5]), + NODE_PN: int(toks[6].rstrip()), + } + ) # store this compartment compartments.append(vals) # increment line number (used for error reporting only) @@ -121,15 +123,16 @@ class Compartment(dict): def __init__(self, *args, **kwargs): super(Compartment, self).__init__(*args, **kwargs) - if (NODE_ID not in self or - NODE_TYPE not in self or - NODE_X not in self or - NODE_Y not in self or - NODE_Z not in self or - NODE_R not in self or - NODE_PN not in self): - raise ValueError( - "Compartment was not initialized with requisite fields") + if ( + NODE_ID not in self + or NODE_TYPE not in self + or NODE_X not in self + or NODE_Y not in self + or NODE_Z not in self + or NODE_R not in self + or NODE_PN not in self + ): + raise ValueError("Compartment was not initialized with requisite fields") # Each unconnected graph has its own ID. This is the ID # of graph that the node resides in self[NODE_TREE_ID] = -1 @@ -138,9 +141,11 @@ def __init__(self, *args, **kwargs): self[NODE_CHILDREN] = [] def print_node(self): - """ print out compartment information with field names """ - print("%d %d %.4f %.4f %.4f %.4f %d %s %d" % (self[_N], self[_TYP], self[ - _X], self[_Y], self[_Z], self[_R], self[_P], str(self[_C]), self[_TID])) + """print out compartment information with field names""" + print( + "%d %d %.4f %.4f %.4f %.4f %d %s %d" + % (self[_N], self[_TYP], self[_X], self[_Y], self[_Z], self[_R], self[_P], str(self[_C]), self[_TID]) + ) class Morphology(object): @@ -209,36 +214,36 @@ def __init__(self, compartment_list=None, compartment_index=None): @property def compartment_list(self): - """ Return the compartment list. This is a property to ensure that the - compartment list and compartment index are in sync. """ + """Return the compartment list. This is a property to ensure that the + compartment list and compartment index are in sync.""" return self._compartment_list @compartment_list.setter def compartment_list(self, compartment_list): - """ Update the compartment list. Update the compartment index. """ + """Update the compartment list. Update the compartment index.""" self._set_compartments(compartment_list) @property def compartment_index(self): - """ Return the compartment index. This is a property to ensure that the - compartment list and compartment index are in sync. """ + """Return the compartment index. This is a property to ensure that the + compartment list and compartment index are in sync.""" return self._compartment_index @compartment_index.setter def compartment_index(self, compartment_index): - """ Update the compartment index. Update the compartment list. """ + """Update the compartment index. Update the compartment list.""" self._set_compartments(compartment_index.values()) @property def num_trees(self): - """ Return the number of trees in the morphology. A tree is - defined as everything following from a single root compartment. """ + """Return the number of trees in the morphology. A tree is + defined as everything following from a single root compartment.""" return len(self._tree_list) # TODO add filter for number of nodes of a particular type @property def num_nodes(self): - """ Return the number of compartments in the morphology. """ + """Return the number of compartments in the morphology.""" return len(self.compartment_list) # internal function @@ -263,12 +268,12 @@ def _set_compartments(self, compartment_list): @property def soma(self): - """ Returns root node of soma, if present""" + """Returns root node of soma, if present""" return self._soma @property def root(self): - """ [deprecated] Returns root node of soma, if present. Use 'soma' instead of 'root'""" + """[deprecated] Returns root node of soma, if present. Use 'soma' instead of 'root'""" return self._soma #################################################################### @@ -314,7 +319,7 @@ def node(self, n): return self._resolve_node_type(n) def parent_of(self, seg): - """ Returns parent of the specified node. + """Returns parent of the specified node. Parameters ---------- @@ -336,7 +341,7 @@ def parent_of(self, seg): return None def children_of(self, seg): - """ Returns a list of the children of the specified node + """Returns a list of the children of the specified node Parameters ---------- @@ -368,12 +373,11 @@ def _resolve_node_type(self, seg): return None seg = self._compartment_list[seg] except ValueError: - raise TypeError( - "Object not recognized as morphology node or index") + raise TypeError("Object not recognized as morphology node or index") return seg def change_parent(self, child, parent): - """ Change the parent of a node. The child node is adjusted to + """Change the parent of a node. The child node is adjusted to point to the new parent, the child is taken off of the previous parent's child list, and it is added to the new parent's child list. @@ -400,7 +404,7 @@ def change_parent(self, child, parent): # returns a list of nodes located within dist of x,y,z def find(self, x, y, z, dist, node_type=None): - """ Returns a list of Morphology Objects located within 'dist' + """Returns a list of Morphology Objects located within 'dist' of coordinate (x,y,z). If node_type is specified, the search will be constrained to return only nodes of that type. @@ -431,7 +435,7 @@ def find(self, x, y, z, dist, node_type=None): return found def compartment_list_by_type(self, compartment_type): - """ Return an list of all compartments having the specified + """Return an list of all compartments having the specified compartment type. Parameters @@ -446,7 +450,7 @@ def compartment_list_by_type(self, compartment_type): return [x for x in self._compartment_list if x[NODE_TYPE] == compartment_type] def compartment_index_by_type(self, compartment_type): - """ Return an dictionary of compartments indexed by id that all have + """Return an dictionary of compartments indexed by id that all have a particular compartment type. Parameters @@ -461,7 +465,7 @@ def compartment_index_by_type(self, compartment_type): return {c[NODE_ID]: c for c in self._compartment_list if c[NODE_TYPE] == compartment_type} def save(self, file_name): - """ Write this morphology out to an SWC file + """Write this morphology out to an SWC file Parameters ---------- @@ -484,7 +488,7 @@ def write(self, file_name): self.save(file_name) def sparsify(self, modulo, compress_ids=False): - """ Return a new Morphology object that has a given number of non-leaf, + """Return a new Morphology object that has a given number of non-leaf, non-root nodes removed. IDs can be reassigned so as to be continuous. Parameters @@ -528,8 +532,7 @@ def sparsify(self, modulo, compress_ids=False): compartments[child_id][NODE_PN] = parent_id # filter out the orphans - sparsified_compartments = {k: v for k, - v in compartments.items() if keep[k]} + sparsified_compartments = {k: v for k, v in compartments.items() if keep[k]} if compress_ids: ids = sorted(sparsified_compartments.keys(), key=lambda x: int(x)) id_hash = {fid: str(i + 1) for i, fid in enumerate(ids)} @@ -584,14 +587,12 @@ def _reconstruct(self): for seg in self.compartment_list: par_num = seg[NODE_PN] if par_num >= 0: - self.compartment_list[par_num][ - NODE_CHILDREN].append(seg[NODE_ID]) + self.compartment_list[par_num][NODE_CHILDREN].append(seg[NODE_ID]) # update tree lists self._separate_trees() ############################ # Rebuild internal index and links between parents and children - self._compartment_index = { - c[NODE_ID]: c for c in self.compartment_list} + self._compartment_index = {c[NODE_ID]: c for c in self.compartment_list} # compartment list is complete and sequential so don't need index # to resolve relationships # for each node, reset children array @@ -600,17 +601,15 @@ def _reconstruct(self): seg[NODE_CHILDREN] = [] for seg in self._compartment_list: if seg[NODE_PN] >= 0: - self._compartment_list[seg[NODE_PN]][ - NODE_CHILDREN].append(seg[NODE_ID]) + self._compartment_list[seg[NODE_PN]][NODE_CHILDREN].append(seg[NODE_ID]) # verify that each node ID is the same as its position in the # compartment list for i in range(len(self.compartment_list)): if i != self.node(i)[NODE_ID]: - raise RuntimeError( - "Internal error detected -- compartment list not properly formed") + raise RuntimeError("Internal error detected -- compartment list not properly formed") def append(self, node_list): - """ Add additional nodes to this Morphology. Those nodes must + """Add additional nodes to this Morphology. Those nodes must originate from another morphology object. Parameters @@ -639,7 +638,7 @@ def append(self, node_list): self._reconstruct() def stumpify_axon(self, count=10): - """ Remove all axon compartments except the first 'count' + """Remove all axon compartments except the first 'count' nodes, as counted from the connected axon root. Parameters @@ -677,7 +676,7 @@ def stumpify_axon(self, count=10): # strip out everything but the soma and the specified SWC type def strip_all_other_types(self, node_type, keep_soma=True): - """ Strips everything from the morphology except for the + """Strips everything from the morphology except for the specified type. Parent and child relationships are updated accordingly, creating new roots when necessary. @@ -718,7 +717,7 @@ def strip_all_other_types(self, node_type, keep_soma=True): # strip out the specified SWC type def strip_type(self, node_type): - """ Strips all compartments of the specified type from the + """Strips all compartments of the specified type from the morphology. Parent and child relationships are updated accordingly, creating new roots when necessary. @@ -750,7 +749,7 @@ def strip_type(self, node_type): # strip out the specified SWC type def convert_type(self, old_type, new_type): - """ Converts all compartments from one type to another. + """Converts all compartments from one type to another. Nodes of the original type are not affected so this procedure can also be used as a merge procedure. @@ -771,7 +770,7 @@ def convert_type(self, old_type, new_type): seg[NODE_TYPE] = new_type def apply_affine(self, aff, scale=None): - """ Apply an affine transform to all compartments in this + """Apply an affine transform to all compartments in this morphology. Node radius is adjusted as well. Format of the affine matrix is: @@ -815,20 +814,17 @@ def apply_affine(self, aff, scale=None): det_scale = math.pow(abs(det), 1.0 / 3.0) # measure scale along each axis # keep this code here in case - #scale_x = abs(aff[0] + aff[3] + aff[6]) - #scale_y = abs(aff[1] + aff[4] + aff[7]) - #scale_z = abs(aff[2] + aff[5] + aff[8]) - #avg_scale = (scale_x + scale_y + scale_z) / 3.0; + # scale_x = abs(aff[0] + aff[3] + aff[6]) + # scale_y = abs(aff[1] + aff[4] + aff[7]) + # scale_z = abs(aff[2] + aff[5] + aff[8]) + # avg_scale = (scale_x + scale_y + scale_z) / 3.0; # # use determinant for scaling for now as it's most simple scale = det_scale for seg in self.compartment_list: - x = seg[NODE_X] * aff[0] + seg[NODE_Y] * \ - aff[1] + seg[NODE_Z] * aff[2] + aff[9] - y = seg[NODE_X] * aff[3] + seg[NODE_Y] * \ - aff[4] + seg[NODE_Z] * aff[5] + aff[10] - z = seg[NODE_X] * aff[6] + seg[NODE_Y] * \ - aff[7] + seg[NODE_Z] * aff[8] + aff[11] + x = seg[NODE_X] * aff[0] + seg[NODE_Y] * aff[1] + seg[NODE_Z] * aff[2] + aff[9] + y = seg[NODE_X] * aff[3] + seg[NODE_Y] * aff[4] + seg[NODE_Z] * aff[5] + aff[10] + z = seg[NODE_X] * aff[6] + seg[NODE_Y] * aff[7] + seg[NODE_Z] * aff[8] + aff[11] seg[NODE_X] = x seg[NODE_Y] = y seg[NODE_Z] = z @@ -849,8 +845,7 @@ def _separate_trees(self): # see what trees this node is adjacent to local_trees = [] if seg[NODE_PN] >= 0 and self.compartment_list[seg[NODE_PN]][NODE_TREE_ID] >= 0: - local_trees.append(self.compartment_list[ - seg[NODE_PN]][NODE_TREE_ID]) + local_trees.append(self.compartment_list[seg[NODE_PN]][NODE_TREE_ID]) for child_id in seg[NODE_CHILDREN]: child = self.compartment_list[child_id] if child[NODE_TREE_ID] >= 0: @@ -860,7 +855,7 @@ def _separate_trees(self): if len(local_trees) == 0: tree_num = len(trees) # create new tree elif len(local_trees) == 1: - tree_num = local_trees[0] # use existing tree + tree_num = local_trees[0] # use existing tree elif len(local_trees) > 1: # this node is an intersection of multiple trees # merge all trees into the first one found @@ -918,8 +913,7 @@ def _check_consistency(self): for seg in self.compartment_list: if seg[NODE_PN] >= 0: if seg[NODE_PN] >= n: - print("Parent for node %d is invalid (%d)" % - (seg[NODE_ID], seg[NODE_PN])) + print("Parent for node %d is invalid (%d)" % (seg[NODE_ID], seg[NODE_PN])) errs += 1 # make sure that each tree has exactly one root for i in range(self.num_trees): @@ -970,7 +964,7 @@ def _find_type_boundary(self): # remove tree from swc's "forest" def delete_tree(self, n): - """ Delete tree, and all of its compartments, from the morphology. + """Delete tree, and all of its compartments, from the morphology. Parameters ---------- @@ -997,11 +991,12 @@ def _print_all_nodes(self): for node in self.compartment_list: print(node) + ######################################################################## class Marker(dict): - """ Simple dictionary class for handling reconstruction marker objects. """ + """Simple dictionary class for handling reconstruction marker objects.""" - SPACING = [.1144, .1144, .28] + SPACING = [0.1144, 0.1144, 0.28] CUT_DENDRITE = 10 NO_RECONSTRUCTION = 20 @@ -1011,20 +1006,20 @@ def __init__(self, *args, **kwargs): # marker file x,y,z coordinates are offset by a single image-space # pixel - self['x'] -= self.SPACING[0] - self['y'] -= self.SPACING[1] - self['z'] -= self.SPACING[2] + self["x"] -= self.SPACING[0] + self["y"] -= self.SPACING[1] + self["z"] -= self.SPACING[2] def read_marker_file(file_name): - """ read in a marker file and return a list of dictionaries """ + """read in a marker file and return a list of dictionaries""" - with open(file_name, 'r') as f: - rows = csv.DictReader((r for r in f if not r.startswith('#')), - fieldnames=['x', 'y', 'z', 'radius', 'shape', 'name', 'comment', - 'color_r', 'color_g', 'color_b']) + with open(file_name, "r") as f: + rows = csv.DictReader( + (r for r in f if not r.startswith("#")), + fieldnames=["x", "y", "z", "radius", "shape", "name", "comment", "color_r", "color_g", "color_b"], + ) - return [Marker({'x': float(r['x']), - 'y': float(r['y']), - 'z': float(r['z']), - 'name': int(r['name'])}) for r in rows] + return [ + Marker({"x": float(r["x"]), "y": float(r["y"]), "z": float(r["z"]), "name": int(r["name"])}) for r in rows + ] diff --git a/allensdk/core/typing.py b/allensdk/core/typing.py index 86a7353c6a..e145aac3b7 100644 --- a/allensdk/core/typing.py +++ b/allensdk/core/typing.py @@ -4,12 +4,13 @@ except ImportError: # for Python 3.7 and before from typing import _Protocol as Protocol - + from abc import abstractmethod class SupportsStr(Protocol): """Classes that support the __str__ method""" + @abstractmethod def __str__(self) -> str: pass diff --git a/allensdk/core/utilities.py b/allensdk/core/utilities.py index d5eb2d2efa..8a5adc3d5c 100644 --- a/allensdk/core/utilities.py +++ b/allensdk/core/utilities.py @@ -10,9 +10,9 @@ def literal_col_eval(df: pd.DataFrame, columns: List[str]) -> pd.DataFrame: for column in columns: if column in df.columns: - df.loc[df[column].notnull(), column] = df[column][ - df[column].notnull() - ].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x) + df.loc[df[column].notnull(), column] = df[column][df[column].notnull()].apply( + lambda x: ast.literal_eval(x) if isinstance(x, str) else x + ) return df @@ -21,7 +21,7 @@ def df_list_to_tuple(df: pd.DataFrame, columns: List[str]) -> pd.DataFrame: for column in columns: if column in df.columns: - df.loc[df[column].notnull(), column] = df[column][ - df[column].notnull() - ].apply(lambda x: tuple(x) if isinstance(x, list) else x) + df.loc[df[column].notnull(), column] = df[column][df[column].notnull()].apply( + lambda x: tuple(x) if isinstance(x, list) else x + ) return df diff --git a/allensdk/deprecated.py b/allensdk/deprecated.py index 748946d932..26c117eb15 100644 --- a/allensdk/deprecated.py +++ b/allensdk/deprecated.py @@ -36,71 +36,72 @@ import copy import warnings import functools + try: from numpy import VisibleDeprecationWarning except ImportError: VisibleDeprecationWarning = DeprecationWarning - -def deprecated(message=None): +def deprecated(message=None): if message is None: - message = '' - + message = "" + def output_decorator(fn): - @functools.wraps(fn) def wrapper(*args, **kwargs): - - warnings.warn("Function {0} is deprecated. {1}".format( - fn.__name__, message), - category=VisibleDeprecationWarning, stacklevel=2) - + warnings.warn( + "Function {0} is deprecated. {1}".format(fn.__name__, message), + category=VisibleDeprecationWarning, + stacklevel=2, + ) + return fn(*args, **kwargs) - + return wrapper - + return output_decorator - - -def class_deprecated(message=None): + +def class_deprecated(message=None): if message is None: - message = '' - + message = "" + def output_class_decorator(cls): - fn_copy = copy.deepcopy(cls.__init__) - + @functools.wraps(cls.__init__) def wrapper(*args, **kwargs): - warnings.warn("Class {0} is deprecated. {1}".format( - cls.__name__, message), - category=VisibleDeprecationWarning, stacklevel=2) + warnings.warn( + "Class {0} is deprecated. {1}".format(cls.__name__, message), + category=VisibleDeprecationWarning, + stacklevel=2, + ) fn_copy(*args, **kwargs) - + cls.__init__ = wrapper return cls - + return output_class_decorator def legacy(message=None): - if message is None: - message = '' - + message = "" + def output_decorator(fn): - @functools.wraps(fn) def wrapper(*args, **kwargs): - - warnings.warn("Function {0} is provided for backward-compatibilty with a legacy API, and may be removed in the future. {1}".format( - fn.__name__, message), - category=VisibleDeprecationWarning, stacklevel=2) - + warnings.warn( + "Function {0} is provided for backward-compatibilty with a legacy API, and may be removed in the future. {1}".format( + fn.__name__, message + ), + category=VisibleDeprecationWarning, + stacklevel=2, + ) + return fn(*args, **kwargs) - + return wrapper - - return output_decorator \ No newline at end of file + + return output_decorator diff --git a/allensdk/ephys/__init__.py b/allensdk/ephys/__init__.py index 92ceaf67c3..8e51ec55db 100644 --- a/allensdk/ephys/__init__.py +++ b/allensdk/ephys/__init__.py @@ -32,4 +32,4 @@ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. -# \ No newline at end of file +# diff --git a/allensdk/ephys/ephys_extractor.py b/allensdk/ephys/ephys_extractor.py index a3685c8761..497db44a3d 100644 --- a/allensdk/ephys/ephys_extractor.py +++ b/allensdk/ephys/ephys_extractor.py @@ -54,13 +54,23 @@ class EphysSweepFeatureExtractor: """Feature calculation for a sweep (voltage and/or current time series).""" - def __init__(self, t=None, v=None, i=None, start=None, end=None, - filter=10., - dv_cutoff=20., max_interval=0.005, min_height=2., - min_peak=-30., - thresh_frac=0.05, baseline_interval=0.1, - baseline_detect_thresh=0.3, - id=None): + def __init__( + self, + t=None, + v=None, + i=None, + start=None, + end=None, + filter=10.0, + dv_cutoff=20.0, + max_interval=0.005, + min_height=2.0, + min_peak=-30.0, + thresh_frac=0.05, + baseline_interval=0.1, + baseline_detect_thresh=0.3, + id=None, + ): """Initialize SweepFeatures object. Parameters @@ -117,51 +127,35 @@ def _process_individual_spikes(self): dvdt = ft.calculate_dvdt(v, t, self.filter) # Basic features of spikes - putative_spikes = ft.detect_putative_spikes(v, t, self.start, self.end, - self.filter, - self.dv_cutoff) + putative_spikes = ft.detect_putative_spikes(v, t, self.start, self.end, self.filter, self.dv_cutoff) peaks = ft.find_peak_indexes(v, t, putative_spikes, self.end) - putative_spikes, peaks = ft.filter_putative_spikes(v, t, - putative_spikes, - peaks, - self.min_height, - self.min_peak, - dvdt=dvdt, - filter=self.filter) + putative_spikes, peaks = ft.filter_putative_spikes( + v, t, putative_spikes, peaks, self.min_height, self.min_peak, dvdt=dvdt, filter=self.filter + ) if not putative_spikes.size: # Save time if no spikes detected self._spikes_df = DataFrame() return - upstrokes = ft.find_upstroke_indexes(v, t, putative_spikes, peaks, - self.filter, dvdt) - thresholds = ft.refine_threshold_indexes(v, t, upstrokes, - self.thresh_frac, - self.filter, dvdt) + upstrokes = ft.find_upstroke_indexes(v, t, putative_spikes, peaks, self.filter, dvdt) + thresholds = ft.refine_threshold_indexes(v, t, upstrokes, self.thresh_frac, self.filter, dvdt) thresholds, peaks, upstrokes, clipped = ft.check_thresholds_and_peaks( - v, t, thresholds, peaks, - upstrokes, self.end, self.max_interval, - dvdt=dvdt, filter=self.filter) + v, t, thresholds, peaks, upstrokes, self.end, self.max_interval, dvdt=dvdt, filter=self.filter + ) if not thresholds.size: # Save time if no spikes detected self._spikes_df = DataFrame() return # Spike list and thresholds have been refined - now find other features - upstrokes = ft.find_upstroke_indexes(v, t, thresholds, peaks, - self.filter, dvdt) - troughs = ft.find_trough_indexes(v, t, thresholds, peaks, clipped, - self.end) - downstrokes = ft.find_downstroke_indexes(v, t, peaks, troughs, clipped, - self.filter, dvdt) - trough_details, clipped = ft.analyze_trough_details(v, t, thresholds, - peaks, clipped, - self.end, - self.filter, - dvdt=dvdt) - widths = ft.find_widths(v, t, thresholds, peaks, trough_details[1], - clipped) + upstrokes = ft.find_upstroke_indexes(v, t, thresholds, peaks, self.filter, dvdt) + troughs = ft.find_trough_indexes(v, t, thresholds, peaks, clipped, self.end) + downstrokes = ft.find_downstroke_indexes(v, t, peaks, troughs, clipped, self.filter, dvdt) + trough_details, clipped = ft.analyze_trough_details( + v, t, thresholds, peaks, clipped, self.end, self.filter, dvdt=dvdt + ) + widths = ft.find_widths(v, t, thresholds, peaks, trough_details[1], clipped) base_clipped_list = [] @@ -174,16 +168,12 @@ def _process_individual_spikes(self): base_clipped_list += ["trough"] # Points where we care about t and dv/dt - dvdt_data_indexes = { - "upstroke": upstrokes, - "downstroke": downstrokes - } + dvdt_data_indexes = {"upstroke": upstrokes, "downstroke": downstrokes} base_clipped_list += ["downstroke"] # Trough details isi_types = trough_details[0] - trough_detail_indexes = dict( - zip(["fast_trough", "adp", "slow_trough"], trough_details[1:])) + trough_detail_indexes = dict(zip(["fast_trough", "adp", "slow_trough"], trough_details[1:])) base_clipped_list += ["fast_trough", "adp", "slow_trough"] # Redundant, but ensures that DataFrame has right number of rows @@ -265,8 +255,7 @@ def _process_individual_spikes(self): spikes_df["width"] = widths self._affected_by_clipping += ["width"] - spikes_df["upstroke_downstroke_ratio"] = \ - (spikes_df["upstroke"] / -spikes_df["downstroke"]) + spikes_df["upstroke_downstroke_ratio"] = spikes_df["upstroke"] / -spikes_df["downstroke"] self._affected_by_clipping += ["upstroke_downstroke_ratio"] self._spikes_df = spikes_df @@ -282,19 +271,16 @@ def _process_spike_related_features(self): isis = ft.get_isis(t, thresholds) with warnings.catch_warnings(): # ignore mean of empty slice warnings here - warnings.filterwarnings("ignore", category=RuntimeWarning, - module="numpy") + warnings.filterwarnings("ignore", category=RuntimeWarning, module="numpy") sweep_level_features = { "adapt": ft.adaptation_index(isis), "latency": ft.latency(t, thresholds, self.start), - "isi_cv": (isis.std() / isis.mean()) if len( - isis) >= 1 else np.nan, + "isi_cv": (isis.std() / isis.mean()) if len(isis) >= 1 else np.nan, "mean_isi": isis.mean() if len(isis) > 0 else np.nan, "median_isi": np.median(isis), "first_isi": isis[0] if len(isis) >= 1 else np.nan, - "avg_rate": ft.average_rate(t, thresholds, self.start, - self.end), + "avg_rate": ft.average_rate(t, thresholds, self.start, self.end), } for k, v in sweep_level_features.items(): @@ -335,7 +321,7 @@ def pause_metrics(self): pause_list = self._process_pauses(weight) if len(pause_list) == 0: - return 0, 0. + return 0, 0.0 n_pauses = len(pause_list) pause_frac = isis[pause_list].sum() @@ -355,9 +341,7 @@ def _process_bursts(self, tol=0.5, pause_cost=1.0): slow_tr_t = self._spikes_df["slow_trough_t"].values thr_v = self._spikes_df["threshold_v"].values - bursts = ft.detect_bursts(isis, isi_types, fast_tr_v, fast_tr_t, - slow_tr_v, slow_tr_t, - thr_v, tol, pause_cost) + bursts = ft.detect_bursts(isis, isi_types, fast_tr_v, fast_tr_t, slow_tr_v, slow_tr_t, thr_v, tol, pause_cost) return np.array(bursts) @@ -376,7 +360,7 @@ def burst_metrics(self): if burst_info.shape[0] > 0: return burst_info[:, 0].max(), burst_info.shape[0] else: - return 0., 0 + return 0.0, 0 def delay_metrics(self): """Calculates ratio of latency to dominant time constant of rise @@ -390,7 +374,7 @@ def delay_metrics(self): if len(self._spikes_df) == 0: logging.info("No spikes available for delay calculation") - return 0., 0. + return 0.0, 0.0 start = self.start spike_time = self._spikes_df["threshold_t"].values[0] @@ -403,33 +387,23 @@ def delay_metrics(self): def _get_baseline_voltage(self): v = self.v t = self.t - filter_frequency = 1. # in kHz + filter_frequency = 1.0 # in kHz # Look at baseline interval before start if start is defined if self.start is not None: - return ft.average_voltage(v, t, - self.start - self.baseline_interval, - self.start) + return ft.average_voltage(v, t, self.start - self.baseline_interval, self.start) # Otherwise try to find an interval where things are pretty flat dv = ft.calculate_dvdt(v, t, filter_frequency) - non_flat_points = np.flatnonzero( - np.abs(dv >= self.baseline_detect_thresh)) + non_flat_points = np.flatnonzero(np.abs(dv >= self.baseline_detect_thresh)) flat_intervals = t[non_flat_points[1:]] - t[non_flat_points[:-1]] - long_flat_intervals = np.flatnonzero( - flat_intervals >= self.baseline_interval) + long_flat_intervals = np.flatnonzero(flat_intervals >= self.baseline_interval) if long_flat_intervals.size > 0: interval_index = long_flat_intervals[0] + 1 baseline_end_time = t[non_flat_points[interval_index]] - return ft.average_voltage(v, t, - baseline_end_time - - self.baseline_interval, - baseline_end_time) + return ft.average_voltage(v, t, baseline_end_time - self.baseline_interval, baseline_end_time) else: - logging.info( - "Could not find sufficiently flat interval for automatic " - "baseline voltage", - RuntimeWarning) + logging.info("Could not find sufficiently flat interval for automatic baseline voltage", RuntimeWarning) return np.nan def voltage_deflection(self, deflect_type=None): @@ -467,8 +441,7 @@ def voltage_deflection(self, deflect_type=None): if deflect_type is None: if self.i is not None: - halfway_index = ft.find_time_index(self.t, - (end - start) / 2. + start) + halfway_index = ft.find_time_index(self.t, (end - start) / 2.0 + start) if self.i[halfway_index] >= 0: deflect_type = "max" else: @@ -510,18 +483,15 @@ def estimate_time_constant(self): start_index = 0 frac = 0.1 - search_result = np.flatnonzero( - self.v[start_index:] <= frac * (v_peak - v_baseline) + v_baseline) + search_result = np.flatnonzero(self.v[start_index:] <= frac * (v_peak - v_baseline) + v_baseline) if not search_result.size: - raise ft.FeatureError( - "could not find interval for time constant estimate") + raise ft.FeatureError("could not find interval for time constant estimate") fit_start = self.t[search_result[0] + start_index] fit_end = self.t[peak_index] - a, inv_tau, y0 = ft.fit_membrane_time_constant(self.v, self.t, - fit_start, fit_end) + a, inv_tau, y0 = ft.fit_membrane_time_constant(self.v, self.t, fit_start, fit_end) - return 1. / inv_tau + return 1.0 / inv_tau def estimate_sag(self, peak_width=0.005): """Calculate the sag in a hyperpolarizing voltage response. @@ -548,21 +518,19 @@ def estimate_sag(self, peak_width=0.005): end = self.t[-1] v_peak, peak_index = self.voltage_deflection("min") - v_peak_avg = ft.average_voltage(v, t, - start=t[peak_index] - peak_width / 2., - end=t[peak_index] + peak_width / 2.) + v_peak_avg = ft.average_voltage( + v, t, start=t[peak_index] - peak_width / 2.0, end=t[peak_index] + peak_width / 2.0 + ) v_baseline = self.sweep_feature("v_baseline") - v_steady = ft.average_voltage(v, t, start=end - self.baseline_interval, - end=end) + v_steady = ft.average_voltage(v, t, start=end - self.baseline_interval, end=end) sag = (v_peak_avg - v_steady) / (v_peak_avg - v_baseline) return sag def spikes(self): """Get all features for each spike as a list of records.""" - return self._spikes_df.to_dict('records') + return self._spikes_df.to_dict("records") - def spike_feature(self, key, include_clipped=False, - force_exclude_clipped=False): + def spike_feature(self, key, include_clipped=False, force_exclude_clipped=False): """Get specified feature for every spike. Parameters @@ -579,24 +547,21 @@ def spike_feature(self, key, include_clipped=False, if not hasattr(self, "_spikes_df"): raise AttributeError( "EphysSweepFeatureExtractor instance attribute with spike " - "information does not exist yet - have spikes been processed?") + "information does not exist yet - have spikes been processed?" + ) if len(self._spikes_df) == 0: return np.array([]) if key not in self._spikes_df.columns: - raise KeyError( - "requested feature '{:s}' not available".format(key)) + raise KeyError("requested feature '{:s}' not available".format(key)) values = self._spikes_df[key].values if include_clipped and force_exclude_clipped: - raise ValueError( - "include_clipped and force_exclude_clipped cannot both be " - "true") + raise ValueError("include_clipped and force_exclude_clipped cannot both be true") - if not include_clipped and self.is_spike_feature_affected_by_clipping( - key): + if not include_clipped and self.is_spike_feature_affected_by_clipping(key): values = values[~self._spikes_df["clipped"].values] elif force_exclude_clipped: values = values[~self._spikes_df["clipped"].values] @@ -632,36 +597,30 @@ def sweep_feature(self, key, allow_missing=False): "stim_amp": self.stimulus_amplitude, } - if allow_missing and key not in self._sweep_features and key not in \ - on_request_dispatch: + if allow_missing and key not in self._sweep_features and key not in on_request_dispatch: return np.nan - elif key not in self._sweep_features and key not in \ - on_request_dispatch: - raise KeyError( - "requested feature '{:s}' not available".format(key)) + elif key not in self._sweep_features and key not in on_request_dispatch: + raise KeyError("requested feature '{:s}' not available".format(key)) if key not in self._sweep_features and key in on_request_dispatch: fn = on_request_dispatch[key] if fn is not None: self._sweep_features[key] = fn() else: - raise KeyError( - "requested feature '{:s}' not defined".format(key)) + raise KeyError("requested feature '{:s}' not defined".format(key)) return self._sweep_features[key] - def process_new_spike_feature(self, feature_name, feature_func, - affected_by_clipping=False): + def process_new_spike_feature(self, feature_name, feature_func, affected_by_clipping=False): """Add new spike-level feature calculation function - The function should take this sweep extractor as its argument. - Its results can be accessed by calling the method - spike_feature(). + The function should take this sweep extractor as its argument. + Its results can be accessed by calling the method + spike_feature(). """ if feature_name in self._spikes_df.columns: - raise KeyError( - "Feature {:s} already exists for sweep".format(feature_name)) + raise KeyError("Feature {:s} already exists for sweep".format(feature_name)) self._spikes_df[feature_name] = feature_func(self) @@ -671,14 +630,13 @@ def process_new_spike_feature(self, feature_name, feature_func, def process_new_sweep_feature(self, feature_name, feature_func): """Add new sweep-level feature calculation function - The function should take this sweep extractor as its argument. - Its results - can be accessed by calling the method sweep_feature(). + The function should take this sweep extractor as its argument. + Its results + can be accessed by calling the method sweep_feature(). """ if feature_name in self._sweep_features: - raise KeyError( - "Feature {:s} already exists for sweep".format(feature_name)) + raise KeyError("Feature {:s} already exists for sweep".format(feature_name)) self._sweep_features[feature_name] = feature_func(self) @@ -699,11 +657,23 @@ def as_dict(self): class EphysSweepSetFeatureExtractor: - def __init__(self, t_set=None, v_set=None, i_set=None, start=None, - end=None, - filter=10., dv_cutoff=20., max_interval=0.005, min_height=2., - min_peak=-30., thresh_frac=0.05, baseline_interval=0.1, - baseline_detect_thresh=0.3, id_set=None): + def __init__( + self, + t_set=None, + v_set=None, + i_set=None, + start=None, + end=None, + filter=10.0, + dv_cutoff=20.0, + max_interval=0.005, + min_height=2.0, + min_peak=-30.0, + thresh_frac=0.05, + baseline_interval=0.1, + baseline_detect_thresh=0.3, + id_set=None, + ): """Initialize EphysSweepSetFeatureExtractor object. Parameters @@ -733,11 +703,22 @@ def __init__(self, t_set=None, v_set=None, i_set=None, start=None, """ if t_set is not None and v_set is not None: - self._set_sweeps(t_set, v_set, i_set, start, end, filter, - dv_cutoff, max_interval, - min_height, min_peak, thresh_frac, - baseline_interval, - baseline_detect_thresh, id_set) + self._set_sweeps( + t_set, + v_set, + i_set, + start, + end, + filter, + dv_cutoff, + max_interval, + min_height, + min_peak, + thresh_frac, + baseline_interval, + baseline_detect_thresh, + id_set, + ) else: self._sweeps = None @@ -752,10 +733,23 @@ def from_sweeps(cls, sweep_list): obj._sweeps = sweep_list return obj - def _set_sweeps(self, t_set, v_set, i_set, start, end, filter, dv_cutoff, - max_interval, - min_height, min_peak, thresh_frac, baseline_interval, - baseline_detect_thresh, id_set): + def _set_sweeps( + self, + t_set, + v_set, + i_set, + start, + end, + filter, + dv_cutoff, + max_interval, + min_height, + min_peak, + thresh_frac, + baseline_interval, + baseline_detect_thresh, + id_set, + ): if type(t_set) != list: raise ValueError("t_set must be a list") @@ -766,18 +760,15 @@ def _set_sweeps(self, t_set, v_set, i_set, start, end, filter, dv_cutoff, raise ValueError("i_set must be a list") if len(t_set) != len(v_set): - raise ValueError( - "t_set and v_set must have the same number of items") + raise ValueError("t_set and v_set must have the same number of items") if i_set and len(t_set) != len(i_set): - raise ValueError( - "t_set and i_set must have the same number of items") + raise ValueError("t_set and i_set must have the same number of items") if id_set is None: id_set = range(len(t_set)) if len(id_set) != len(t_set): - raise ValueError( - "t_set and id_set must have the same number of items") + raise ValueError("t_set and id_set must have the same number of items") sweeps = [] if i_set is None: @@ -789,7 +780,11 @@ def _set_sweeps(self, t_set, v_set, i_set, start, end, filter, dv_cutoff, sweeps = [ EphysSweepFeatureExtractor( - t, v, i, start, end, + t, + v, + i, + start, + end, filter=filter, dv_cutoff=dv_cutoff, max_interval=max_interval, @@ -798,9 +793,10 @@ def _set_sweeps(self, t_set, v_set, i_set, start, end, filter, dv_cutoff, thresh_frac=thresh_frac, baseline_interval=baseline_interval, baseline_detect_thresh=baseline_detect_thresh, - id=sid) - for t, v, i, start, end, sid in zip(t_set, v_set, i_set, start, - end, id_set)] + id=sid, + ) + for t, v, i, start, end, sid in zip(t_set, v_set, i_set, start, end, id_set) + ] self._sweeps = sweeps @@ -827,22 +823,19 @@ def sweep_features(self, key, allow_missing=False): sweep_feature : nparray of sweep-level feature values """ - return np.array( - [swp.sweep_feature(key, allow_missing) for swp in self._sweeps]) + return np.array([swp.sweep_feature(key, allow_missing) for swp in self._sweeps]) def spike_feature_averages(self, key): """Get nparray of average spike-level feature (`key`) for all sweeps""" - return np.array( - [swp.spike_feature(key).mean() for swp in self._sweeps]) + return np.array([swp.spike_feature(key).mean() for swp in self._sweeps]) class EphysCellFeatureExtractor: # Class constants for specific processing SUBTHRESH_MAX_AMP = 0 - SAG_TARGET = -100. + SAG_TARGET = -100.0 - def __init__(self, ramps_ext, short_squares_ext, long_squares_ext, - subthresh_min_amp=-100): + def __init__(self, ramps_ext, short_squares_ext, long_squares_ext, subthresh_min_amp=-100): """Initialize EphysCellFeatureExtractor object from EphysSweepSetExtractors for ramp, short square, and long square sweeps. @@ -900,8 +893,7 @@ def _analyze_ramps(self): self._all_ramps_ext = ext # pull out the spiking sweeps - spiking_sweeps = [sweep for sweep in self._ramps_ext.sweeps() - if sweep.sweep_feature("avg_rate") > 0] + spiking_sweeps = [sweep for sweep in self._ramps_ext.sweeps() if sweep.sweep_feature("avg_rate") > 0] ext = EphysSweepSetFeatureExtractor.from_sweeps(spiking_sweeps) self._ramps_ext = ext @@ -919,16 +911,12 @@ def _analyze_short_squares(self): # Need to count how many had spikes at each amplitude; find most; # ties go to lower amplitude - spiking_sweeps = [sweep for sweep in ext.sweeps() - if sweep.sweep_feature("avg_rate") > 0] + spiking_sweeps = [sweep for sweep in ext.sweeps() if sweep.sweep_feature("avg_rate") > 0] if len(spiking_sweeps) == 0: - raise ft.FeatureError( - "No spiking short square sweeps, cannot compute cell " - "features.") + raise ft.FeatureError("No spiking short square sweeps, cannot compute cell features.") - most_common = Counter( - map(_short_step_stim_amp, spiking_sweeps)).most_common() + most_common = Counter(map(_short_step_stim_amp, spiking_sweeps)).most_common() common_amp, common_count = most_common[0] for c in most_common[1:]: if c[1] < common_count: @@ -938,8 +926,8 @@ def _analyze_short_squares(self): self._features["short_squares"]["stimulus_amplitude"] = common_amp ext = EphysSweepSetFeatureExtractor.from_sweeps( - [sweep for sweep in spiking_sweeps - if _short_step_stim_amp(sweep) == common_amp]) + [sweep for sweep in spiking_sweeps if _short_step_stim_amp(sweep) == common_amp] + ) self._short_squares_ext = ext self._features["short_squares"]["common_amp_sweeps"] = ext.sweeps() @@ -966,81 +954,61 @@ def _analyze_long_squares_spiking(self, force_reprocess=False): spiking_indexes = np.flatnonzero(ext.sweep_features("avg_rate")) if len(spiking_indexes) == 0: - raise ft.FeatureError( - "No spiking long square sweeps, cannot compute cell features.") + raise ft.FeatureError("No spiking long square sweeps, cannot compute cell features.") amps = ext.sweep_features("stim_amp") # self.long_squares_stim_amps() min_index = np.argmin(amps[spiking_indexes]) rheobase_index = spiking_indexes[min_index] rheobase_i = _step_stim_amp(ext.sweeps()[rheobase_index]) - self._features["long_squares"][ - "rheobase_extractor_index"] = rheobase_index + self._features["long_squares"]["rheobase_extractor_index"] = rheobase_index self._features["long_squares"]["rheobase_i"] = rheobase_i - self._features["long_squares"]["rheobase_sweep"] = ext.sweeps()[ - rheobase_index] - spiking_sweeps = [sweep for sweep in ext.sweeps() - if sweep.sweep_feature("avg_rate") > 0] - self._spiking_long_squares_ext = \ - EphysSweepSetFeatureExtractor.from_sweeps( - spiking_sweeps) - self._features["long_squares"][ - "spiking_sweeps"] = self._spiking_long_squares_ext.sweeps() - - self._features["long_squares"]["fi_fit_slope"] = fit_fi_slope( - self._spiking_long_squares_ext) + self._features["long_squares"]["rheobase_sweep"] = ext.sweeps()[rheobase_index] + spiking_sweeps = [sweep for sweep in ext.sweeps() if sweep.sweep_feature("avg_rate") > 0] + self._spiking_long_squares_ext = EphysSweepSetFeatureExtractor.from_sweeps(spiking_sweeps) + self._features["long_squares"]["spiking_sweeps"] = self._spiking_long_squares_ext.sweeps() + + self._features["long_squares"]["fi_fit_slope"] = fit_fi_slope(self._spiking_long_squares_ext) def _analyze_long_squares_subthreshold(self): ext = self._long_squares_ext - subthresh_sweeps = [sweep for sweep in ext.sweeps() - if sweep.sweep_feature("avg_rate") == 0] - subthresh_ext = EphysSweepSetFeatureExtractor.from_sweeps( - subthresh_sweeps) + subthresh_sweeps = [sweep for sweep in ext.sweeps() if sweep.sweep_feature("avg_rate") == 0] + subthresh_ext = EphysSweepSetFeatureExtractor.from_sweeps(subthresh_sweeps) self._subthreshold_long_squares_ext = subthresh_ext if len(subthresh_ext.sweeps()) == 0: - raise ft.FeatureError( - "No subthreshold long square sweeps, cannot evaluate cell " - "features.") + raise ft.FeatureError("No subthreshold long square sweeps, cannot evaluate cell features.") sags = subthresh_ext.sweep_features("sag") - sag_eval_levels = np.array([sweep.voltage_deflection()[0] for sweep in - subthresh_ext.sweeps()]) + sag_eval_levels = np.array([sweep.voltage_deflection()[0] for sweep in subthresh_ext.sweeps()]) target_level = self.SAG_TARGET closest_index = np.argmin(np.abs(sag_eval_levels - target_level)) self._features["long_squares"]["sag"] = sags[closest_index] - self._features["long_squares"]["vm_for_sag"] = sag_eval_levels[ - closest_index] - self._features["long_squares"][ - "subthreshold_sweeps"] = subthresh_ext.sweeps() + self._features["long_squares"]["vm_for_sag"] = sag_eval_levels[closest_index] + self._features["long_squares"]["subthreshold_sweeps"] = subthresh_ext.sweeps() for s in self._features["long_squares"]["subthreshold_sweeps"]: s.set_stimulus_amplitude_calculator(_step_stim_amp) logging.debug("subthresh_sweeps: %d", len(subthresh_sweeps)) - calc_subthresh_sweeps = \ - [sweep for sweep in subthresh_sweeps - if self._subthresh_min_amp < sweep.sweep_feature("stim_amp") < self.SUBTHRESH_MAX_AMP] # noqa F501 + calc_subthresh_sweeps = [ + sweep + for sweep in subthresh_sweeps + if self._subthresh_min_amp < sweep.sweep_feature("stim_amp") < self.SUBTHRESH_MAX_AMP + ] # noqa F501 logging.debug("calc_subthresh_sweeps: %d", len(calc_subthresh_sweeps)) - calc_subthresh_ext = EphysSweepSetFeatureExtractor.from_sweeps( - calc_subthresh_sweeps) + calc_subthresh_ext = EphysSweepSetFeatureExtractor.from_sweeps(calc_subthresh_sweeps) self._subthreshold_membrane_property_ext = calc_subthresh_ext - self._features["long_squares"][ - "subthreshold_membrane_property_sweeps"] = \ - calc_subthresh_ext.sweeps() - self._features["long_squares"]["input_resistance"] = input_resistance( - calc_subthresh_ext) - self._features["long_squares"]["tau"] = membrane_time_constant( - calc_subthresh_ext) - self._features["long_squares"]["v_baseline"] = np.nanmean( - ext.sweep_features("v_baseline")) + self._features["long_squares"]["subthreshold_membrane_property_sweeps"] = calc_subthresh_ext.sweeps() + self._features["long_squares"]["input_resistance"] = input_resistance(calc_subthresh_ext) + self._features["long_squares"]["tau"] = membrane_time_constant(calc_subthresh_ext) + self._features["long_squares"]["v_baseline"] = np.nanmean(ext.sweep_features("v_baseline")) def long_squares_features(self, option=None): option_table = { "spiking": self._spiking_long_squares_ext, "subthreshold": self._subthreshold_long_squares_ext, - "subthreshold_membrane_property": - self._subthreshold_membrane_property_ext, + "subthreshold_membrane_property": self._subthreshold_membrane_property_ext, } if option: return option_table[option] @@ -1051,8 +1019,7 @@ def long_squares_stim_amps(self, option=None): option_table = { "spiking": self._spiking_long_squares_ext, "subthreshold": self._subthreshold_long_squares_ext, - "subthreshold_membrane_property": - self._subthreshold_membrane_property_ext, + "subthreshold_membrane_property": self._subthreshold_membrane_property_ext, } if option: ext = option_table[option] @@ -1076,22 +1043,17 @@ def as_dict(self): # convert feature extractor lists to sweep dictionarsweep extract lists ls_sweeps = [s.as_dict() for s in out["long_squares"]["sweeps"]] - ls_spike_sweeps = [s.as_dict() for s in - out["long_squares"]["spiking_sweeps"]] + ls_spike_sweeps = [s.as_dict() for s in out["long_squares"]["spiking_sweeps"]] rheo_sweep = out["long_squares"]["rheobase_sweep"].as_dict() - ls_sub_sweeps = [s.as_dict() for s in - out["long_squares"]["subthreshold_sweeps"]] - ls_sub_mem_sweeps = [s.as_dict() for s in out["long_squares"][ - "subthreshold_membrane_property_sweeps"]] - ss_sweeps = [s.as_dict() for s in - out["short_squares"]["common_amp_sweeps"]] + ls_sub_sweeps = [s.as_dict() for s in out["long_squares"]["subthreshold_sweeps"]] + ls_sub_mem_sweeps = [s.as_dict() for s in out["long_squares"]["subthreshold_membrane_property_sweeps"]] + ss_sweeps = [s.as_dict() for s in out["short_squares"]["common_amp_sweeps"]] ramp_sweeps = [s.as_dict() for s in out["ramps"]["spiking_sweeps"]] out["long_squares"]["sweeps"] = ls_sweeps out["long_squares"]["spiking_sweeps"] = ls_spike_sweeps out["long_squares"]["subthreshold_sweeps"] = ls_sub_sweeps - out["long_squares"][ - "subthreshold_membrane_property_sweeps"] = ls_sub_mem_sweeps + out["long_squares"]["subthreshold_membrane_property_sweeps"] = ls_sub_mem_sweeps out["long_squares"]["rheobase_sweep"] = rheo_sweep out["short_squares"]["common_amp_sweeps"] = ss_sweeps out["ramps"]["spiking_sweeps"] = ramp_sweeps @@ -1106,17 +1068,15 @@ def input_resistance(ext): sweeps = ext.sweeps() if not sweeps: - raise ft.FeatureError( - "no sweeps available for input resistance calculation") + raise ft.FeatureError("no sweeps available for input resistance calculation") v_vals = [] i_vals = [] for sweep in sweeps: if sweep.i is None: - raise ft.FeatureError( - "cannot calculate input resistance: i not defined for a sweep") + raise ft.FeatureError("cannot calculate input resistance: i not defined for a sweep") - v_peak, min_index = sweep.voltage_deflection('min') + v_peak, min_index = sweep.voltage_deflection("min") v_vals.append(v_peak) i_vals.append(sweep.i[min_index]) @@ -1127,7 +1087,7 @@ def input_resistance(ext): # If there's just one sweep, we'll have to use its own baseline to # estimate the input resistance v = np.append(v, sweeps[0].sweep_feature("v_baseline")) - i = np.append(i, 0.) + i = np.append(i, 0.0) A = np.vstack([i, np.ones_like(i)]).T m, c = np.linalg.lstsq(A, v)[0] @@ -1140,8 +1100,7 @@ def membrane_time_constant(ext): in passed extractor.""" with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=RuntimeWarning, - module="numpy") + warnings.filterwarnings("ignore", category=RuntimeWarning, module="numpy") avg_tau = np.nanmean(ext.sweep_features("tau")) return avg_tau @@ -1150,9 +1109,7 @@ def fit_fi_slope(ext): """Fit the rate and stimulus amplitude to a line and return the slope of the fit.""" if len(ext.sweeps()) < 2: - raise ft.FeatureError( - "Cannot fit f-I curve slope with less than two suprathreshold " - "sweeps") + raise ft.FeatureError("Cannot fit f-I curve slope with less than two suprathreshold sweeps") x = np.array(list(map(_step_stim_amp, ext.sweeps()))) y = ext.sweep_features("avg_rate") @@ -1170,8 +1127,7 @@ def reset_long_squares_start(when): LONG_SQUARES_END = when + delta -def cell_extractor_for_nwb(dataset, ramps, short_squares, long_squares, - subthresh_min_amp=-100): +def cell_extractor_for_nwb(dataset, ramps, short_squares, long_squares, subthresh_min_amp=-100): """Initialize EphysCellFeatureExtractor object from NWB data set Parameters @@ -1189,33 +1145,28 @@ def cell_extractor_for_nwb(dataset, ramps, short_squares, long_squares, if len(long_squares) == 0: raise ft.FeatureError("no long_square sweep numbers provided") - ramps_ext = extractor_for_nwb_sweeps(dataset, ramps, - fixed_start=RAMPS_START) + ramps_ext = extractor_for_nwb_sweeps(dataset, ramps, fixed_start=RAMPS_START) temp_short_sq_ext = extractor_for_nwb_sweeps(dataset, short_squares) t_set = [s.t for s in temp_short_sq_ext.sweeps()] v_set = [s.v for s in temp_short_sq_ext.sweeps()] - cutoff, thresh_frac = \ - ft.estimate_adjusted_detection_parameters(v_set, t_set, - SHORT_SQUARES_WINDOW_START, - SHORT_SQUARES_WINDOW_END) + cutoff, thresh_frac = ft.estimate_adjusted_detection_parameters( + v_set, t_set, SHORT_SQUARES_WINDOW_START, SHORT_SQUARES_WINDOW_END + ) thresh_frac = max(thresh_frac, 0.1) - short_squares_ext = extractor_for_nwb_sweeps(dataset, short_squares, - dv_cutoff=cutoff, - thresh_frac=thresh_frac) - long_squares_ext = extractor_for_nwb_sweeps(dataset, long_squares, - fixed_start=LONG_SQUARES_START, - fixed_end=LONG_SQUARES_END) + short_squares_ext = extractor_for_nwb_sweeps(dataset, short_squares, dv_cutoff=cutoff, thresh_frac=thresh_frac) + long_squares_ext = extractor_for_nwb_sweeps( + dataset, long_squares, fixed_start=LONG_SQUARES_START, fixed_end=LONG_SQUARES_END + ) - return EphysCellFeatureExtractor(ramps_ext, short_squares_ext, - long_squares_ext, subthresh_min_amp) + return EphysCellFeatureExtractor(ramps_ext, short_squares_ext, long_squares_ext, subthresh_min_amp) -def extractor_for_nwb_sweeps(dataset, sweep_numbers, - fixed_start=None, fixed_end=None, - dv_cutoff=20., thresh_frac=0.05): +def extractor_for_nwb_sweeps( + dataset, sweep_numbers, fixed_start=None, fixed_end=None, dv_cutoff=20.0, thresh_frac=0.05 +): v_set = [] t_set = [] i_set = [] @@ -1225,13 +1176,13 @@ def extractor_for_nwb_sweeps(dataset, sweep_numbers, for sweep_number in sweep_numbers: data = dataset.get_sweep(sweep_number) - v = data['response'] * 1e3 # mV - i = data['stimulus'] * 1e12 # pA - hz = data['sampling_rate'] - dt = 1. / hz + v = data["response"] * 1e3 # mV + i = data["stimulus"] * 1e12 # pA + hz = data["sampling_rate"] + dt = 1.0 / hz t = np.arange(0, len(v)) * dt # sec - s, e = dt * np.array(data['index_range']) + s, e = dt * np.array(data["index_range"]) v_set.append(v) i_set.append(i) t_set.append(t) @@ -1244,11 +1195,9 @@ def extractor_for_nwb_sweeps(dataset, sweep_numbers, start = fixed_start end = fixed_end - return EphysSweepSetFeatureExtractor(t_set, v_set, i_set, start=start, - end=end, - dv_cutoff=dv_cutoff, - thresh_frac=thresh_frac, - id_set=sweep_numbers) + return EphysSweepSetFeatureExtractor( + t_set, v_set, i_set, start=start, end=end, dv_cutoff=dv_cutoff, thresh_frac=thresh_frac, id_set=sweep_numbers + ) def _step_stim_amp(sweep): @@ -1258,4 +1207,4 @@ def _step_stim_amp(sweep): def _short_step_stim_amp(sweep): t_index = ft.find_time_index(sweep.t, sweep.start) - return sweep.i[t_index + 1:].max() + return sweep.i[t_index + 1 :].max() diff --git a/allensdk/ephys/ephys_features.py b/allensdk/ephys/ephys_features.py index de1984578f..f0af02ed04 100644 --- a/allensdk/ephys/ephys_features.py +++ b/allensdk/ephys/ephys_features.py @@ -40,7 +40,8 @@ from scipy.optimize import curve_fit from functools import partial -def detect_putative_spikes(v, t, start=None, end=None, filter=10., dv_cutoff=20.): + +def detect_putative_spikes(v, t, start=None, end=None, filter=10.0, dv_cutoff=20.0): """Perform initial detection of spikes and return their indexes. Parameters @@ -75,8 +76,8 @@ def detect_putative_spikes(v, t, start=None, end=None, filter=10., dv_cutoff=20. start_index = find_time_index(t, start) end_index = find_time_index(t, end) - v_window = v[start_index:end_index + 1] - t_window = t[start_index:end_index + 1] + v_window = v[start_index : end_index + 1] + t_window = t[start_index : end_index + 1] dvdt = calculate_dvdt(v_window, t_window, filter) @@ -88,8 +89,9 @@ def detect_putative_spikes(v, t, start=None, end=None, filter=10., dv_cutoff=20. return np.array(putative_spikes) + start_index # Only keep spike times if dV/dt has dropped all the way to zero between putative spikes - putative_spikes = [putative_spikes[0]] + [s for i, s in enumerate(putative_spikes[1:]) - if np.any(dvdt[putative_spikes[i]:s] < 0)] + putative_spikes = [putative_spikes[0]] + [ + s for i, s in enumerate(putative_spikes[1:]) if np.any(dvdt[putative_spikes[i] : s] < 0) + ] # Set back to original index space (not just window) return np.array(putative_spikes) + start_index @@ -111,14 +113,12 @@ def find_peak_indexes(v, t, spike_indexes, end=None): end_index = find_time_index(t, end) spks_and_end = np.append(spike_indexes, end_index) - peak_indexes = [np.argmax(v[spk:next]) + spk for spk, next in - zip(spks_and_end[:-1], spks_and_end[1:])] + peak_indexes = [np.argmax(v[spk:next]) + spk for spk, next in zip(spks_and_end[:-1], spks_and_end[1:])] return np.array(peak_indexes) -def filter_putative_spikes(v, t, spike_indexes, peak_indexes, min_height=2., - min_peak=-30., filter=10., dvdt=None): +def filter_putative_spikes(v, t, spike_indexes, peak_indexes, min_height=2.0, min_peak=-30.0, filter=10.0, dvdt=None): """Filter out events that are unlikely to be spikes based on: * Voltage failing to go down between peak and the next spike's threshold * Height (threshold to peak) @@ -147,9 +147,9 @@ def filter_putative_spikes(v, t, spike_indexes, peak_indexes, min_height=2., if dvdt is None: dvdt = calculate_dvdt(v, t, filter) - diff_mask = [np.any(dvdt[peak_ind:spike_ind] < 0) - for peak_ind, spike_ind - in zip(peak_indexes[:-1], spike_indexes[1:])] + diff_mask = [ + np.any(dvdt[peak_ind:spike_ind] < 0) for peak_ind, spike_ind in zip(peak_indexes[:-1], spike_indexes[1:]) + ] peak_indexes = peak_indexes[np.array(diff_mask + [True])] spike_indexes = spike_indexes[np.array([True] + diff_mask)] @@ -164,7 +164,7 @@ def filter_putative_spikes(v, t, spike_indexes, peak_indexes, min_height=2., return spike_indexes, peak_indexes -def find_upstroke_indexes(v, t, spike_indexes, peak_indexes, filter=10., dvdt=None): +def find_upstroke_indexes(v, t, spike_indexes, peak_indexes, filter=10.0, dvdt=None): """Find indexes of maximum upstroke of spike. Parameters @@ -185,13 +185,12 @@ def find_upstroke_indexes(v, t, spike_indexes, peak_indexes, filter=10., dvdt=No if dvdt is None: dvdt = calculate_dvdt(v, t, filter) - upstroke_indexes = [np.argmax(dvdt[spike:peak]) + spike for spike, peak in - zip(spike_indexes, peak_indexes)] + upstroke_indexes = [np.argmax(dvdt[spike:peak]) + spike for spike, peak in zip(spike_indexes, peak_indexes)] return np.array(upstroke_indexes) -def refine_threshold_indexes(v, t, upstroke_indexes, thresh_frac=0.05, filter=10., dvdt=None): +def refine_threshold_indexes(v, t, upstroke_indexes, thresh_frac=0.05, filter=10.0, dvdt=None): """Refine threshold detection of previously-found spikes. Parameters @@ -231,9 +230,19 @@ def refine_threshold_indexes(v, t, upstroke_indexes, thresh_frac=0.05, filter=10 return np.array(threshold_indexes) -def check_thresholds_and_peaks(v, t, spike_indexes, peak_indexes, upstroke_indexes, end=None, - max_interval=0.005, thresh_frac=0.05, filter=10., dvdt=None, - tol=1.0): +def check_thresholds_and_peaks( + v, + t, + spike_indexes, + peak_indexes, + upstroke_indexes, + end=None, + max_interval=0.005, + thresh_frac=0.05, + filter=10.0, + dvdt=None, + tol=1.0, +): """Validate thresholds and peaks for set of spikes Check that peaks and thresholds for consecutive spikes do not overlap @@ -284,7 +293,11 @@ def check_thresholds_and_peaks(v, t, spike_indexes, peak_indexes, upstroke_index too_long_spikes = [] for i, (spk, peak) in enumerate(zip(spike_indexes, peak_indexes)): if t[peak] - t[spk] >= max_interval: - logging.info("Need to recalculate threshold-peak pair that exceeds maximum allowed interval ({:f} s)".format(max_interval)) + logging.info( + "Need to recalculate threshold-peak pair that exceeds maximum allowed interval ({:f} s)".format( + max_interval + ) + ) too_long_spikes.append(i) if too_long_spikes: @@ -297,7 +310,7 @@ def check_thresholds_and_peaks(v, t, spike_indexes, peak_indexes, upstroke_index # First guessing that threshold is wrong and peak is right peak = peak_indexes[i] t_0 = find_time_index(t, t[peak] - max_interval) - below_target = np.flatnonzero(dvdt[upstroke_indexes[i]:t_0:-1] <= target) + below_target = np.flatnonzero(dvdt[upstroke_indexes[i] : t_0 : -1] <= target) if not below_target.size: # Now try to see if threshold was right but peak was wrong @@ -308,18 +321,18 @@ def check_thresholds_and_peaks(v, t, spike_indexes, peak_indexes, upstroke_index # If that peak is okay (not outside the allowed window, not past the next spike) # then keep it - if t[new_peak] - t[spike] < max_interval and \ - (i == len(spike_indexes) - 1 or t[new_peak] < t[spike_indexes[i + 1]]): + if t[new_peak] - t[spike] < max_interval and ( + i == len(spike_indexes) - 1 or t[new_peak] < t[spike_indexes[i + 1]] + ): peak_indexes[i] = new_peak else: # Otherwise, log and get rid of the spike logging.info("Could not redetermine threshold-peak pair - dropping that pair") drop_spikes.append(i) -# raise FeatureError("Could not redetermine threshold") + # raise FeatureError("Could not redetermine threshold") else: spike_indexes[i] = upstroke_indexes[i] - below_target[0] - if drop_spikes: spike_indexes = np.delete(spike_indexes, drop_spikes) peak_indexes = np.delete(peak_indexes, drop_spikes) @@ -330,8 +343,12 @@ def check_thresholds_and_peaks(v, t, spike_indexes, peak_indexes, upstroke_index # voltage - otherwise, drop it clipped = np.zeros_like(spike_indexes, dtype=bool) end_index = find_time_index(t, end) - if len(spike_indexes) > 0 and not np.any(v[peak_indexes[-1]:end_index + 1] <= v[spike_indexes[-1]] + tol): - logging.debug("Failed to return to threshold voltage + tolerance (%.2f) after last spike (min %.2f) - marking last spike as clipped", v[spike_indexes[-1]] + tol, v[peak_indexes[-1]:end_index + 1].min()) + if len(spike_indexes) > 0 and not np.any(v[peak_indexes[-1] : end_index + 1] <= v[spike_indexes[-1]] + tol): + logging.debug( + "Failed to return to threshold voltage + tolerance (%.2f) after last spike (min %.2f) - marking last spike as clipped", + v[spike_indexes[-1]] + tol, + v[peak_indexes[-1] : end_index + 1].min(), + ) clipped[-1] = True return spike_indexes, peak_indexes, upstroke_indexes, clipped @@ -365,24 +382,23 @@ def find_trough_indexes(v, t, spike_indexes, peak_indexes, clipped=None, end=Non end_index = find_time_index(t, end) trough_indexes = np.zeros_like(spike_indexes, dtype=float) - trough_indexes[:-1] = [v[peak:spk].argmin() + peak for peak, spk - in zip(peak_indexes[:-1], spike_indexes[1:])] + trough_indexes[:-1] = [v[peak:spk].argmin() + peak for peak, spk in zip(peak_indexes[:-1], spike_indexes[1:])] if clipped[-1]: # If last spike is cut off by the end of the window, trough is undefined trough_indexes[-1] = np.nan else: - trough_indexes[-1] = v[peak_indexes[-1]:end_index].argmin() + peak_indexes[-1] + trough_indexes[-1] = v[peak_indexes[-1] : end_index].argmin() + peak_indexes[-1] # nwg - trying to remove this next part for now - can't figure out if this will be needed with new "clipped" method # If peak is the same point as the trough, drop that point -# trough_indexes = trough_indexes[np.where(peak_indexes[:len(trough_indexes)] != trough_indexes)] + # trough_indexes = trough_indexes[np.where(peak_indexes[:len(trough_indexes)] != trough_indexes)] return trough_indexes -def find_downstroke_indexes(v, t, peak_indexes, trough_indexes, clipped=None, filter=10., dvdt=None): +def find_downstroke_indexes(v, t, peak_indexes, trough_indexes, clipped=None, filter=10.0, dvdt=None): """Find indexes of minimum voltage (troughs) between spikes. Parameters @@ -411,15 +427,16 @@ def find_downstroke_indexes(v, t, peak_indexes, trough_indexes, clipped=None, fi if len(peak_indexes) < len(trough_indexes): raise FeatureError("Cannot have more troughs than peaks") -# Taking this out...with clipped info, should always have the same number of points -# peak_indexes = peak_indexes[:len(trough_indexes)] + # Taking this out...with clipped info, should always have the same number of points + # peak_indexes = peak_indexes[:len(trough_indexes)] valid_peak_indexes = peak_indexes[~clipped].astype(int) valid_trough_indexes = trough_indexes[~clipped].astype(int) downstroke_indexes = np.zeros_like(peak_indexes) * np.nan - downstroke_index_values = [np.argmin(dvdt[peak:trough]) + peak for peak, trough - in zip(valid_peak_indexes, valid_trough_indexes)] + downstroke_index_values = [ + np.argmin(dvdt[peak:trough]) + peak for peak, trough in zip(valid_peak_indexes, valid_trough_indexes) + ] downstroke_indexes[~clipped] = downstroke_index_values return downstroke_indexes @@ -459,39 +476,61 @@ def find_widths(v, t, spike_indexes, peak_indexes, trough_indexes, clipped=None) heights[use_indexes] = v[peak_indexes[use_indexes]] - v[trough_indexes[use_indexes].astype(int)] width_levels = np.zeros_like(trough_indexes) * np.nan - width_levels[use_indexes] = heights[use_indexes] / 2. + v[trough_indexes[use_indexes].astype(int)] + width_levels[use_indexes] = heights[use_indexes] / 2.0 + v[trough_indexes[use_indexes].astype(int)] thresh_to_peak_levels = np.zeros_like(trough_indexes) * np.nan - thresh_to_peak_levels[use_indexes] = (v[peak_indexes[use_indexes]] - v[spike_indexes[use_indexes]]) / 2. + v[spike_indexes[use_indexes]] + thresh_to_peak_levels[use_indexes] = (v[peak_indexes[use_indexes]] - v[spike_indexes[use_indexes]]) / 2.0 + v[ + spike_indexes[use_indexes] + ] # Some spikes in burst may have deep trough but short height, so can't use same # definition for width - width_levels[width_levels < v[spike_indexes]] = \ - thresh_to_peak_levels[width_levels < v[spike_indexes]] + width_levels[width_levels < v[spike_indexes]] = thresh_to_peak_levels[width_levels < v[spike_indexes]] width_starts = np.zeros_like(trough_indexes) * np.nan - width_starts[use_indexes] = np.array([pk - np.flatnonzero(v[pk:spk:-1] <= wl)[0] if - np.flatnonzero(v[pk:spk:-1] <= wl).size > 0 else np.nan for pk, spk, wl - in zip(peak_indexes[use_indexes], spike_indexes[use_indexes], width_levels[use_indexes])]) + width_starts[use_indexes] = np.array( + [ + pk - np.flatnonzero(v[pk:spk:-1] <= wl)[0] if np.flatnonzero(v[pk:spk:-1] <= wl).size > 0 else np.nan + for pk, spk, wl in zip(peak_indexes[use_indexes], spike_indexes[use_indexes], width_levels[use_indexes]) + ] + ) width_ends = np.zeros_like(trough_indexes) * np.nan - width_ends[use_indexes] = np.array([pk + np.flatnonzero(v[pk:tr] <= wl)[0] if - np.flatnonzero(v[pk:tr] <= wl).size > 0 else np.nan for pk, tr, wl - in zip(peak_indexes[use_indexes], trough_indexes[use_indexes].astype(int), width_levels[use_indexes])]) + width_ends[use_indexes] = np.array( + [ + pk + np.flatnonzero(v[pk:tr] <= wl)[0] if np.flatnonzero(v[pk:tr] <= wl).size > 0 else np.nan + for pk, tr, wl in zip( + peak_indexes[use_indexes], trough_indexes[use_indexes].astype(int), width_levels[use_indexes] + ) + ] + ) missing_widths = np.isnan(width_starts) | np.isnan(width_ends) widths = np.zeros_like(width_starts, dtype=np.float64) - widths[~missing_widths] = t[width_ends[~missing_widths].astype(int)] - \ - t[width_starts[~missing_widths].astype(int)] + widths[~missing_widths] = t[width_ends[~missing_widths].astype(int)] - t[width_starts[~missing_widths].astype(int)] if any(missing_widths): widths[missing_widths] = np.nan return widths -def analyze_trough_details(v, t, spike_indexes, peak_indexes, clipped=None, end=None, filter=10., - heavy_filter=1., term_frac=0.01, adp_thresh=0.5, tol=0.5, - flat_interval=0.002, adp_max_delta_t=0.005, adp_max_delta_v=10., dvdt=None): +def analyze_trough_details( + v, + t, + spike_indexes, + peak_indexes, + clipped=None, + end=None, + filter=10.0, + heavy_filter=1.0, + term_frac=0.01, + adp_thresh=0.5, + tol=0.5, + flat_interval=0.002, + adp_max_delta_t=0.005, + adp_max_delta_v=10.0, + dvdt=None, +): """Analyze trough to determine if an ADP exists and whether the reset is a 'detour' or 'direct' Parameters @@ -576,9 +615,11 @@ def analyze_trough_details(v, t, spike_indexes, peak_indexes, clipped=None, end= if zero_return_vals.size: putative_adp_index = zero_return_vals[0] + cross min_index = v[putative_adp_index:next_spk].argmin() + putative_adp_index - if (v[putative_adp_index] - v[min_index] >= tol and - v[putative_adp_index] - v[terminated] <= adp_max_delta_v and - t[putative_adp_index] - t[terminated] <= adp_max_delta_t): + if ( + v[putative_adp_index] - v[min_index] >= tol + and v[putative_adp_index] - v[terminated] <= adp_max_delta_v + and t[putative_adp_index] - t[terminated] <= adp_max_delta_t + ): adp_index = putative_adp_index slow_phase_min_index = min_index isi_type = "detour" @@ -619,8 +660,8 @@ def analyze_trough_details(v, t, spike_indexes, peak_indexes, clipped=None, end= # magnitude of the difference will be less than in many of the erroneous # cases seen otherwise - output[2][-1] = np.nan # ADP - output[3][-1] = np.nan # slow trough + output[2][-1] = np.nan # ADP + output[3][-1] = np.nan # slow trough clipped[~clipped] = update_clipped return output, clipped @@ -652,10 +693,14 @@ def calculate_dvdt(v, t, filter=None): if has_fixed_dt(t) and filter: delta_t = t[1] - t[0] - sample_freq = 1. / delta_t - filt_coeff = (filter * 1e3) / (sample_freq / 2.) # filter kHz -> Hz, then get fraction of Nyquist frequency + sample_freq = 1.0 / delta_t + filt_coeff = (filter * 1e3) / (sample_freq / 2.0) # filter kHz -> Hz, then get fraction of Nyquist frequency if filt_coeff < 0 or filt_coeff >= 1: - raise ValueError("bessel coeff ({:f}) is outside of valid range [0,1); cannot filter sampling frequency {:.1f} kHz with cutoff frequency {:.1f} kHz.".format(filt_coeff, sample_freq / 1e3, filter)) + raise ValueError( + "bessel coeff ({:f}) is outside of valid range [0,1); cannot filter sampling frequency {:.1f} kHz with cutoff frequency {:.1f} kHz.".format( + filt_coeff, sample_freq / 1e3, filter + ) + ) b, a = signal.bessel(4, filt_coeff, "low") v_filt = signal.filtfilt(b, a, v, axis=0) dv = np.diff(v_filt) @@ -663,7 +708,7 @@ def calculate_dvdt(v, t, filter=None): dv = np.diff(v) dt = np.diff(t) - dvdt = 1e-3 * dv / dt # in V/s = mV/ms + dvdt = 1e-3 * dv / dt # in V/s = mV/ms # Remove nan values (in case any dt values == 0) dvdt = dvdt[~np.isnan(dvdt)] @@ -760,10 +805,10 @@ def norm_diff(a): return np.nan a = a.astype(float) - if np.allclose((a[1:] + a[:-1]), 0.): - return 0. + if np.allclose((a[1:] + a[:-1]), 0.0): + return 0.0 norm_diffs = (a[1:] - a[:-1]) / (a[1:] + a[:-1]) - norm_diffs[(a[1:] == 0) & (a[:-1] == 0)] = 0. + norm_diffs[(a[1:] == 0) & (a[:-1] == 0)] = 0.0 with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=RuntimeWarning, module="numpy") avg = np.nanmean(norm_diffs) @@ -807,7 +852,7 @@ def fit_membrane_time_constant(v, t, start, end, min_rsme=1e-4): start_index = find_time_index(t, start) end_index = find_time_index(t, end) - guess = (v[start_index] - v[end_index], 50., v[end_index]) + guess = (v[start_index] - v[end_index], 50.0, v[end_index]) t_window = (t[start_index:end_index] - t[start_index]).astype(np.float64) v_window = v[start_index:end_index].astype(np.float64) try: @@ -853,7 +898,9 @@ def detect_pauses(isis, isi_types, cost_weight=1.0): detour_candidates = [i for i, isi_type in enumerate(isi_types) if isi_type == "detour"] median_direct = np.median(isis[isi_types == "direct"]) - direct_candidates = [i for i, isi_type in enumerate(isi_types) if isi_type == "direct" and isis[i] > 3 * median_direct] + direct_candidates = [ + i for i, isi_type in enumerate(isi_types) if isi_type == "direct" and isis[i] > 3 * median_direct + ] candidates = detour_candidates + direct_candidates if not candidates: @@ -885,8 +932,7 @@ def detect_pauses(isis, isi_types, cost_weight=1.0): return np.sort(pause_list) -def detect_bursts(isis, isi_types, fast_tr_v, fast_tr_t, slow_tr_v, slow_tr_t, - thr_v, tol=0.5, pause_cost=1.0): +def detect_bursts(isis, isi_types, fast_tr_v, fast_tr_t, slow_tr_v, slow_tr_t, thr_v, tol=0.5, pause_cost=1.0): """Detect bursts in spike train. Parameters @@ -913,7 +959,7 @@ def detect_bursts(isis, isi_types, fast_tr_v, fast_tr_t, slow_tr_v, slow_tr_t, if len(isis) != len(isi_types): raise FeatureError("Wrong number of ISIs") - if len(isis) < 2: # can't determine burstiness for a single ISI + if len(isis) < 2: # can't determine burstiness for a single ISI return np.array([]) fast_tr_v = fast_tr_v[:-1] @@ -921,7 +967,7 @@ def detect_bursts(isis, isi_types, fast_tr_v, fast_tr_t, slow_tr_v, slow_tr_t, slow_tr_v = slow_tr_v[:-1] slow_tr_t = slow_tr_t[:-1] - isi_types = np.array(isi_types) # don't want to change the actual isi types data + isi_types = np.array(isi_types) # don't want to change the actual isi types data # Burst transitions can't be at "pause"-like ISIs pauses = detect_pauses(isis, isi_types, cost_weight=pause_cost).astype(int) @@ -936,17 +982,21 @@ def detect_bursts(isis, isi_types, fast_tr_v, fast_tr_t, slow_tr_v, slow_tr_t, isi_types[(thr_v[:-1] < (slow_tr_v + tol)) & (isi_types == "detour")] = "midburst" # Find transitions from direct -> detour and vice versa for burst boundaries - into_burst = np.array([i + 1 for i, (prev, cur) in - enumerate(zip(isi_types[:-1], isi_types[1:])) if - cur == "direct" and prev == "detour"], - dtype=int) + into_burst = np.array( + [ + i + 1 + for i, (prev, cur) in enumerate(zip(isi_types[:-1], isi_types[1:])) + if cur == "direct" and prev == "detour" + ], + dtype=int, + ) if isi_types[0] == "direct": into_burst = np.append(np.array([0]), into_burst) drop_into = [] out_of_burst = [] for j, (into, next) in enumerate(zip(into_burst, np.append(into_burst[1:], len(isis)))): - for i, isi in enumerate(isi_types[into + 1:next]): + for i, isi in enumerate(isi_types[into + 1 : next]): if isi == "detour": out_of_burst.append(i + into + 1) break @@ -990,17 +1040,17 @@ def detect_bursts(isis, isi_types, fast_tr_v, fast_tr_t, slow_tr_v, slow_tr_t, bursts = [] for i, (into, outof) in enumerate(inout_pairs): - if i == len(inout_pairs) - 1: # last burst to evaluate - if outof <= len(isis) - 1: # are there spikes left after the burst? + if i == len(inout_pairs) - 1: # last burst to evaluate + if outof <= len(isis) - 1: # are there spikes left after the burst? metric = _burstiness_index(isis[into:outof], isis[outof:]) - elif i == 0: # was this the first one (and there weren't spikes after)? + elif i == 0: # was this the first one (and there weren't spikes after)? metric = _burstiness_index(isis[into:outof], isis[:into]) else: prev_burst = inout_pairs[i - 1] - metric = _burstiness_index(isis[into:outof], isis[prev_burst[1]:into]) + metric = _burstiness_index(isis[into:outof], isis[prev_burst[1] : into]) else: next_burst = inout_pairs[i + 1] - metric = _burstiness_index(isis[into:outof], isis[outof:next_burst[0]]) + metric = _burstiness_index(isis[into:outof], isis[outof : next_burst[0]]) bursts.append((metric, into, outof)) return bursts @@ -1035,7 +1085,7 @@ def fit_prespike_time_constant(v, t, start, spike_time, dv_limit=-0.001, tau_lim t_slice = t[start_index:end_index] # Solve linear version with single exponential first to guess at the time constant - y0 = v_slice.max() + 5e-6 # set y0 slightly above v_slice maximum + y0 = v_slice.max() + 5e-6 # set y0 slightly above v_slice maximum y = -v_slice + y0 y = np.log(y) @@ -1043,7 +1093,7 @@ def fit_prespike_time_constant(v, t, start, spike_time, dv_limit=-0.001, tau_lim # End the fit interval if the voltage starts dropping new_end_indexes = np.flatnonzero(dy <= dv_limit) - cross_limit = 0.0005 # sec + cross_limit = 0.0005 # sec if not new_end_indexes.size or t_slice[new_end_indexes[0]] - t_slice[0] < cross_limit: # either never crosses or crosses too early new_end_index = len(v_slice) @@ -1070,7 +1120,7 @@ def fit_prespike_time_constant(v, t, start, spike_time, dv_limit=-0.001, tau_lim # These are all empirical values if np.abs(faster_weight) > np.abs(slower_weight): tau = faster_tau - elif (slower_tau - faster_tau) / slower_tau <= 0.1: # close enough; just use slower + elif (slower_tau - faster_tau) / slower_tau <= 0.1: # close enough; just use slower tau = slower_tau elif slower_tau > tau_limit and slower_weight / faster_weight < 2.0: tau = faster_tau @@ -1148,7 +1198,7 @@ def estimate_adjusted_detection_parameters(v_set, t_set, interval_start, interva def _score_burst_set(bursts, isis, delta_t, c_n=0.1, c_tx=0.01): in_burst = np.zeros_like(isis, dtype=bool) for b in bursts: - in_burst[b[0]:b[1]] = True + in_burst[b[0] : b[1]] = True # If all ISIs are part of a burst, give it a bad score if len(isis[~in_burst]) == 0: @@ -1158,20 +1208,20 @@ def _score_burst_set(bursts, isis, delta_t, c_n=0.1, c_tx=0.01): scores = [] for b in bursts: - score = _burstiness_index(isis[b[0]:b[1]], isis[~in_burst]) # base score + score = _burstiness_index(isis[b[0] : b[1]], isis[~in_burst]) # base score if b[1] < len(delta_t): - score -= c_tx * (1. / (delta_frac[b[1]])) # cost for starting a burst + score -= c_tx * (1.0 / (delta_frac[b[1]])) # cost for starting a burst if b[0] > 0: - score -= c_tx * (1. / delta_frac[b[0] - 1]) # cost for ending a burst - score -= c_n * (b[1] - b[0] - 1) # cost for extending a burst + score -= c_tx * (1.0 / delta_frac[b[0] - 1]) # cost for ending a burst + score -= c_n * (b[1] - b[0] - 1) # cost for extending a burst scores.append(score) return scores def _burstiness_index(in_burst_isis, out_burst_isis): - burst_rate = 1. / in_burst_isis.min() - out_rate = 1. / out_burst_isis.min() + burst_rate = 1.0 / in_burst_isis.min() + out_rate = 1.0 / out_burst_isis.min() return (burst_rate - out_rate) / (burst_rate + out_rate) @@ -1188,4 +1238,5 @@ def _dbl_exp_fit(y0, x, A1, tau1, A2, tau2): class FeatureError(Exception): """Generic Python-exception-derived object raised by feature detection functions.""" + pass diff --git a/allensdk/ephys/extract_cell_features.py b/allensdk/ephys/extract_cell_features.py index ef5d5b5b96..a06d2c64b9 100755 --- a/allensdk/ephys/extract_cell_features.py +++ b/allensdk/ephys/extract_cell_features.py @@ -42,17 +42,32 @@ HERO_MIN_AMP_OFFSET = 39.0 HERO_MAX_AMP_OFFSET = 61.0 -SHORT_SQUARE_TYPES = ["Short Square", - "Short Square - Triple", - "Short Square - Hold -60mv", - "Short Square - Hold -70mv", - "Short Square - Hold -80mv"] +SHORT_SQUARE_TYPES = [ + "Short Square", + "Short Square - Triple", + "Short Square - Hold -60mv", + "Short Square - Hold -70mv", + "Short Square - Hold -80mv", +] SHORT_SQUARE_THRESH_FRAC_FLOOR = 0.1 -MEAN_FEATURES = [ "upstroke_downstroke_ratio", "peak_v", "peak_t", "trough_v", "trough_t", - "fast_trough_v", "fast_trough_t", "slow_trough_v", "slow_trough_t", - "threshold_v", "threshold_i", "threshold_t", "peak_v", "peak_t" ] +MEAN_FEATURES = [ + "upstroke_downstroke_ratio", + "peak_v", + "peak_t", + "trough_v", + "trough_t", + "fast_trough_v", + "fast_trough_t", + "slow_trough_v", + "slow_trough_t", + "threshold_v", + "threshold_i", + "threshold_t", + "peak_v", + "peak_t", +] def extract_sweep_features(data_set, sweeps_by_type): @@ -60,7 +75,7 @@ def extract_sweep_features(data_set, sweeps_by_type): sweep_features = {} for stimulus_type, sweep_numbers in sweeps_by_type.items(): - logging.debug("%s:%s" % (stimulus_type, ','.join(map(str, sweep_numbers)))) + logging.debug("%s:%s" % (stimulus_type, ",".join(map(str, sweep_numbers)))) if stimulus_type == "Short Square - Triple": tmp_ext = efex.extractor_for_nwb_sweeps(data_set, sweep_numbers) @@ -71,12 +86,10 @@ def extract_sweep_features(data_set, sweeps_by_type): # triple-sweeps to use different window win_start = efex.SHORT_SQUARE_TRIPLE_WINDOW_START win_end = efex.SHORT_SQUARE_TRIPLE_WINDOW_END - cutoff, thresh_frac = ft.estimate_adjusted_detection_parameters( - v_set, t_set, win_start, win_end) + cutoff, thresh_frac = ft.estimate_adjusted_detection_parameters(v_set, t_set, win_start, win_end) thresh_frac = max(SHORT_SQUARE_THRESH_FRAC_FLOOR, thresh_frac) - fex = efex.extractor_for_nwb_sweeps(data_set, sweep_numbers, - dv_cutoff=cutoff, thresh_frac=thresh_frac) + fex = efex.extractor_for_nwb_sweeps(data_set, sweep_numbers, dv_cutoff=cutoff, thresh_frac=thresh_frac) elif stimulus_type in SHORT_SQUARE_TYPES: tmp_ext = efex.extractor_for_nwb_sweeps(data_set, sweep_numbers) t_set = [s.t for s in tmp_ext.sweeps()] @@ -84,47 +97,40 @@ def extract_sweep_features(data_set, sweeps_by_type): win_start = efex.SHORT_SQUARES_WINDOW_START win_end = efex.SHORT_SQUARES_WINDOW_END - cutoff, thresh_frac = ft.estimate_adjusted_detection_parameters( - v_set, t_set, win_start, win_end) + cutoff, thresh_frac = ft.estimate_adjusted_detection_parameters(v_set, t_set, win_start, win_end) thresh_frac = max(SHORT_SQUARE_THRESH_FRAC_FLOOR, thresh_frac) - fex = efex.extractor_for_nwb_sweeps(data_set, sweep_numbers, - dv_cutoff=cutoff, thresh_frac=thresh_frac) + fex = efex.extractor_for_nwb_sweeps(data_set, sweep_numbers, dv_cutoff=cutoff, thresh_frac=thresh_frac) else: fex = efex.extractor_for_nwb_sweeps(data_set, sweep_numbers) fex.process_spikes() - sweep_features.update({ f.id:f.as_dict() for f in fex.sweeps() }) + sweep_features.update({f.id: f.as_dict() for f in fex.sweeps()}) return sweep_features + # if subthreshold minimum amplitude is known (e.g., for human cells) then # specify it. otherwise the default value will be used -def extract_cell_features(data_set, - ramp_sweep_numbers, - short_square_sweep_numbers, - long_square_sweep_numbers, - subthresh_min_amp = None): - +def extract_cell_features( + data_set, ramp_sweep_numbers, short_square_sweep_numbers, long_square_sweep_numbers, subthresh_min_amp=None +): if subthresh_min_amp is None: - fex = efex.cell_extractor_for_nwb(data_set, - ramp_sweep_numbers, - short_square_sweep_numbers, - long_square_sweep_numbers) + fex = efex.cell_extractor_for_nwb( + data_set, ramp_sweep_numbers, short_square_sweep_numbers, long_square_sweep_numbers + ) else: - fex = efex.cell_extractor_for_nwb(data_set, - ramp_sweep_numbers, - short_square_sweep_numbers, - long_square_sweep_numbers, - subthresh_min_amp) + fex = efex.cell_extractor_for_nwb( + data_set, ramp_sweep_numbers, short_square_sweep_numbers, long_square_sweep_numbers, subthresh_min_amp + ) fex.process() cell_features = fex.as_dict() # find hero sweep - rheo_amp = cell_features['long_squares']['rheobase_i'] + rheo_amp = cell_features["long_squares"]["rheobase_i"] hero_min, hero_max = rheo_amp + HERO_MIN_AMP_OFFSET, rheo_amp + HERO_MAX_AMP_OFFSET hero_amp = float("inf") hero_sweep = None @@ -148,30 +154,32 @@ def extract_cell_features(data_set, ss_ms0 = mean_features_spike_zero(fex.short_squares_features().sweeps()) # compute baseline from all long square sweeps - v_baseline = np.mean(fex.long_squares_features().sweep_features('v_baseline')) + v_baseline = np.mean(fex.long_squares_features().sweep_features("v_baseline")) - cell_features['long_squares']['v_baseline'] = v_baseline - cell_features['long_squares']['hero_sweep'] = hero_sweep.as_dict() if hero_sweep else None + cell_features["long_squares"]["v_baseline"] = v_baseline + cell_features["long_squares"]["hero_sweep"] = hero_sweep.as_dict() if hero_sweep else None cell_features["ramps"]["mean_spike_0"] = ramps_ms0 cell_features["short_squares"]["mean_spike_0"] = ss_ms0 return cell_features + def mean_features_spike_zero(sweeps): - """ Compute mean feature values for the first spike in list of extractors """ + """Compute mean feature values for the first spike in list of extractors""" output = {} for mf in MEAN_FEATURES: - mfd = [ sweep.spikes()[0][mf] for sweep in sweeps if sweep.sweep_feature("avg_rate") > 0 ] + mfd = [sweep.spikes()[0][mf] for sweep in sweeps if sweep.sweep_feature("avg_rate") > 0] output[mf] = np.mean(mfd) return output + def get_stim_characteristics(i, t, no_test_pulse=False): - ''' + """ Identify the start time, duration, amplitude, start index, and end index of a general stimulus. This assumes that there is a test pulse followed by the stimulus square. - ''' + """ di = np.diff(i) diff_idx = np.flatnonzero(di != 0) @@ -192,22 +200,24 @@ def get_stim_characteristics(i, t, no_test_pulse=False): return (stim_start, stim_dur, stim_amp, start_idx, end_idx) + def get_ramp_stim_characteristics(i, t): - ''' Identify the start time and start index of a ramp sweep. ''' + """Identify the start time and start index of a ramp sweep.""" # Assumes that there is a test pulse followed by the stimulus ramp di = np.diff(i) up_idx = np.flatnonzero(di > 0) - start_idx = up_idx[1] + 1 # shift by one to compensate for diff() + start_idx = up_idx[1] + 1 # shift by one to compensate for diff() return (t[start_idx], start_idx) + def get_square_stim_characteristics(i, t, no_test_pulse=False): - ''' + """ Identify the start time, duration, amplitude, start index, and end index of a square stimulus. This assumes that there is a test pulse followed by the stimulus square. - ''' + """ di = np.diff(i) up_idx = np.flatnonzero(di > 0) @@ -216,10 +226,10 @@ def get_square_stim_characteristics(i, t, no_test_pulse=False): idx = 0 if no_test_pulse else 1 # second square is the stimulus - if up_idx[idx] < down_idx[idx]: # positive square - start_idx = up_idx[idx] + 1 # shift by one to compensate for diff() + if up_idx[idx] < down_idx[idx]: # positive square + start_idx = up_idx[idx] + 1 # shift by one to compensate for diff() end_idx = down_idx[idx] + 1 - else: # negative square + else: # negative square start_idx = down_idx[idx] + 1 end_idx = up_idx[idx] + 1 diff --git a/allensdk/ephys/feature_extractor.py b/allensdk/ephys/feature_extractor.py index 796cad740d..974fb13e1f 100644 --- a/allensdk/ephys/feature_extractor.py +++ b/allensdk/ephys/feature_extractor.py @@ -45,8 +45,9 @@ # must have all features in the file. If one is absent, a penalty # of TODO ??? will be assessed + # set of features -class EphysFeatures( object ): +class EphysFeatures(object): def __init__(self, name): # feature mean and standard deviations self.mean = {} @@ -67,15 +68,15 @@ def __init__(self, name): self.name = name - ################################################################ + ################################################################ # ignore scores - ignore_score = { "hit": "ignore" } + ignore_score = {"hit": "ignore"} self.glossary["n_spikes"] = "Number of spikes" self.scoring["n_spikes"] = ignore_score ################################################################ # ignore misses - ignore_miss = { "hit":"stdev", "miss":"const", "const":0 } + ignore_miss = {"hit": "stdev", "miss": "const", "const": 0} self.glossary["adapt"] = "Adaptation index" self.scoring["adapt"] = ignore_miss self.glossary["latency"] = "Time to first spike (ms)" @@ -83,13 +84,13 @@ def __init__(self, name): ################################################################ # base miss off mean - mean_score = { "hit":"stdev", "miss":"mean_mult", "mean_mult":2 } + mean_score = {"hit": "stdev", "miss": "mean_mult", "mean_mult": 2} self.glossary["ISICV"] = "ISI-CV" self.scoring["ISICV"] = mean_score ################################################################ # normal scoring - normal_score = { "hit":"stdev", "miss":"const", "const":20 } + normal_score = {"hit": "stdev", "miss": "const", "const": 20} self.glossary["isi_avg"] = "Average ISI (ms)" self.scoring["isi_avg"] = ignore_score self.glossary["doublet"] = "Doublet ISI (ms)" @@ -102,13 +103,13 @@ def __init__(self, name): self.scoring["f_slow_ahp_time"] = normal_score self.glossary["base_v"] = "Baseline voltage (mV)" self.scoring["base_v"] = normal_score - #self.glossary["base_v2"] = "Baseline voltage 2 (mV)" - #self.scoring["base_v2"] = normal_score - #self.glossary["base_v3"] = "Baseline voltage 3 (mV)" - #self.scoring["base_v3"] = normal_score + # self.glossary["base_v2"] = "Baseline voltage 2 (mV)" + # self.scoring["base_v2"] = normal_score + # self.glossary["base_v3"] = "Baseline voltage 3 (mV)" + # self.scoring["base_v3"] = normal_score ################################################################ # per spike scoring - perspike_score = { "hit":"perspike", "miss":"const", "const":20, "skip_last_n":0 } + perspike_score = {"hit": "perspike", "miss": "const", "const": 20, "skip_last_n": 0} self.glossary["f_peak"] = "Spike height (mV)" self.scoring["f_peak"] = perspike_score.copy() self.glossary["f_trough"] = "Spike depth (mV)" @@ -132,10 +133,9 @@ def __init__(self, name): self.glossary["thresh_ramp"] = "Change in dv/dt over first 5 mV past threshold (mV/ms)" self.scoring["thresh_ramp"] = perspike_score.copy() - ################################################################ # heavily penalize when there are no spikes - spike_score = { "hit":"stdev", "miss":"const", "const":250 } + spike_score = {"hit": "stdev", "miss": "const", "const": 250} self.glossary["rate"] = "Firing rate (Hz)" self.scoring["rate"] = spike_score @@ -158,7 +158,8 @@ def clone(self, param_dict): self.mean[k] = param_dict[k]["mean"] self.stdev[k] = param_dict[k]["stdev"] -class EphysFeatureExtractor( object ): + +class EphysFeatureExtractor(object): def __init__(self): # list of feature set instances self.feature_list = [] @@ -178,8 +179,8 @@ def process_instance(self, name, v, curr, t, onset, dur, stim_name): start = onset stop = onset + dur # detect spikes for all of sweep - #start = 0 - #stop = t[-1] + # start = 0 + # stop = t[-1] ################################################################ # pull out spike times @@ -195,21 +196,23 @@ def process_instance(self, name, v, curr, t, onset, dur, stim_name): dv = np.diff(smooth_v) else: dv = np.diff(v_target) - dvdt = dv / (np.diff(t[start_idx:stop_idx]) * 1e3) # in mV/ms + dvdt = dv / (np.diff(t[start_idx:stop_idx]) * 1e3) # in mV/ms dv_cutoff = 20 thresh_pct = 0.05 spikes = [] - temp_spk_idxs = np.where(np.diff(np.greater_equal(dvdt, dv_cutoff).astype(int)) == 1)[0] # find positive-going crossings of 100 mV/ms + temp_spk_idxs = np.where(np.diff(np.greater_equal(dvdt, dv_cutoff).astype(int)) == 1)[ + 0 + ] # find positive-going crossings of 100 mV/ms spk_idxs = [] for i, temp in enumerate(temp_spk_idxs): if i == 0: spk_idxs.append(temp) - elif np.any(dvdt[temp_spk_idxs[i - 1]:temp] < 0): + elif np.any(dvdt[temp_spk_idxs[i - 1] : temp] < 0): # check if the dvdt has gone back down below zero between presumed spike times # sometimes the dvdt bobbles around detection threshold and produces spurious guesses at spike times spk_idxs.append(temp) - spk_idxs += start_idx # set back to the "index space" of the original trace + spk_idxs += start_idx # set back to the "index space" of the original trace # recalculate full dv/dt for feature analysis (vs spike detection) if np.abs(t[1] - t[0] - 5e-6) < 1e-7 and np.var(np.diff(t)) < 1e-6: @@ -218,7 +221,7 @@ def process_instance(self, name, v, curr, t, onset, dur, stim_name): dv = np.diff(smooth_v) else: dv = np.diff(v) - dvdt = dv / (np.diff(t) * 1e3) # in mV/ms + dvdt = dv / (np.diff(t) * 1e3) # in mV/ms # First time through, accumulate upstrokes to calculate average threshold target for spk_n, spk_idx in enumerate(spk_idxs): @@ -244,18 +247,20 @@ def process_instance(self, name, v, curr, t, onset, dur, stim_name): spk["f_peak_t"] = t[peak_idx] # Check if end of stimulus interval cuts off spike - if so, don't process spike - if spk_n == len(spk_idxs) - 1 and peak_idx == next_idx-1: + if spk_n == len(spk_idxs) - 1 and peak_idx == next_idx - 1: continue if spk_idx == peak_idx: - continue # this was bugfix, but why? ramp? + continue # this was bugfix, but why? ramp? # Determine maximum upstroke of spike upstroke_idx = np.argmax(dvdt[spk_idx:peak_idx]) + spk_idx spk["upstroke"] = dvdt[upstroke_idx] - if np.isnan(spk["upstroke"]): # sometimes dvdt will be NaN because of multiple cvode points at same time step + if np.isnan( + spk["upstroke"] + ): # sometimes dvdt will be NaN because of multiple cvode points at same time step close_idx = upstroke_idx + 1 - while (np.isnan(dvdt[close_idx])): + while np.isnan(dvdt[close_idx]): close_idx += 1 spk["upstroke_idx"] = close_idx spk["upstroke"] = dvdt[close_idx] @@ -271,7 +276,7 @@ def process_instance(self, name, v, curr, t, onset, dur, stim_name): # Preliminarily define threshold where dvdt = 5% * max upstroke thresh_pct = 0.05 find_thresh_idxs = np.where(dvdt[prev_idx:upstroke_idx] <= thresh_pct * spk["upstroke"])[0] - if len(find_thresh_idxs) < 1: # Can't find a good threshold value - probably a bad simulation case + if len(find_thresh_idxs) < 1: # Can't find a good threshold value - probably a bad simulation case # Fall back to the upstroke value threshold_idx = upstroke_idx else: @@ -324,7 +329,7 @@ def process_instance(self, name, v, curr, t, onset, dur, stim_name): # Restore variables from before # peak_idx = spk['peak_idx'] - peak_idx = np.argmax(v[spk['threshold_idx']:next_idx]) + spk['threshold_idx'] + peak_idx = np.argmax(v[spk["threshold_idx"] : next_idx]) + spk["threshold_idx"] spk["peak_idx"] = peak_idx spk["f_peak"] = v[peak_idx] @@ -333,12 +338,14 @@ def process_instance(self, name, v, curr, t, onset, dur, stim_name): # Determine maximum upstroke of spike # upstroke_idx = spk['upstroke_idx'] - upstroke_idx = np.argmax(dvdt[spk['threshold_idx']:peak_idx]) + spk['threshold_idx'] + upstroke_idx = np.argmax(dvdt[spk["threshold_idx"] : peak_idx]) + spk["threshold_idx"] spk["upstroke"] = dvdt[upstroke_idx] - if np.isnan(spk["upstroke"]): # sometimes dvdt will be NaN because of multiple cvode points at same time step + if np.isnan( + spk["upstroke"] + ): # sometimes dvdt will be NaN because of multiple cvode points at same time step close_idx = upstroke_idx + 1 - while (np.isnan(dvdt[close_idx])): + while np.isnan(dvdt[close_idx]): close_idx += 1 spk["upstroke_idx"] = close_idx spk["upstroke"] = dvdt[close_idx] @@ -353,7 +360,7 @@ def process_instance(self, name, v, curr, t, onset, dur, stim_name): # Find threshold based on average target find_thresh_idxs = np.where(dvdt[prev_idx:upstroke_idx] <= threshold_target)[0] - if len(find_thresh_idxs) < 1: # Can't find a good threshold value - probably a bad simulation case + if len(find_thresh_idxs) < 1: # Can't find a good threshold value - probably a bad simulation case # Fall back to the upstroke value threshold_idx = upstroke_idx else: @@ -372,7 +379,7 @@ def process_instance(self, name, v, curr, t, onset, dur, stim_name): overn30_idxs = np.where(v[threshold_idx:peak_idx] >= -30)[0] if len(overn30_idxs) > 0: spk["t_idx_n30"] = overn30_idxs[0] + threshold_idx - else: # fall back to threshold definition if spike doesn't cross -30 mV + else: # fall back to threshold definition if spike doesn't cross -30 mV spk["t_idx_n30"] = threshold_idx spk["t_n30"] = t[spk["t_idx_n30"]] @@ -391,13 +398,17 @@ def process_instance(self, name, v, curr, t, onset, dur, stim_name): spk["downstroke_i"] = curr[downstroke_idx] spk["downstroke_t"] = t[downstroke_idx] spk["downstroke"] = dvdt[downstroke_idx] - if np.isnan(spk["downstroke"]): # sometimes dvdt will be NaN because of multiple cvode points at same time step + if np.isnan( + spk["downstroke"] + ): # sometimes dvdt will be NaN because of multiple cvode points at same time step close_idx = downstroke_idx + 1 - while (np.isnan(dvdt[close_idx])): + while np.isnan(dvdt[close_idx]): close_idx += 1 spk["downstroke"] = dvdt[close_idx] - feature.mean["base_v"] = v[np.where((t > onset - 0.1) & (t < onset - 0.001))].mean() # baseline voltage, 100ms before stim + feature.mean["base_v"] = v[ + np.where((t > onset - 0.1) & (t < onset - 0.001)) + ].mean() # baseline voltage, 100ms before stim feature.mean["spikes"] = spikes isi_cv = self.isicv(spikes) if isi_cv is not None: @@ -413,18 +424,20 @@ def process_instance(self, name, v, curr, t, onset, dur, stim_name): idx_next = spikes[i + 1]["t_idx"] if i < len(spikes) - 1 else stop_idx self.calculate_trough(spk, v, curr, t, idx_next) half_max_v = (spk["f_peak"] - spk["f_trough"]) / 2.0 + spk["f_trough"] - over_half_max_v_idxs = np.where(v[spk["t_idx"]:spk["trough_idx"]] > half_max_v)[0] + over_half_max_v_idxs = np.where(v[spk["t_idx"] : spk["trough_idx"]] > half_max_v)[0] if len(over_half_max_v_idxs) > 0: - spk["width"] = 1000. * (t[over_half_max_v_idxs[-1] + spk["t_idx"]] - t[over_half_max_v_idxs[0] + spk["t_idx"]]) - feature.mean["latency"] = 1000. * (spikes[0]["t"] - onset) - feature.mean["latency_n30"] = 1000. * (spikes[0]["t_n30"] - onset) + spk["width"] = 1000.0 * ( + t[over_half_max_v_idxs[-1] + spk["t_idx"]] - t[over_half_max_v_idxs[0] + spk["t_idx"]] + ) + feature.mean["latency"] = 1000.0 * (spikes[0]["t"] - onset) + feature.mean["latency_n30"] = 1000.0 * (spikes[0]["t_n30"] - onset) # extract properties for each spike isicnt = 0 isitot = 0 - for i in range(0, len(spikes)-1): + for i in range(0, len(spikes) - 1): spk = spikes[i] - idx_next = spikes[i+1]["t_idx"] - isitot += spikes[i+1]["t"] - spikes[i]["t"] + idx_next = spikes[i + 1]["t_idx"] + isitot += spikes[i + 1]["t"] - spikes[i]["t"] isicnt += 1 if isicnt > 0: feature.mean["isi_avg"] = 1000 * isitot / isicnt @@ -470,19 +483,19 @@ def isicv(self, spikes): isi_mean = 0 lst = [] for i in range(len(spikes) - 1): - isi = spikes[i+1]["t"] - spikes[i]["t"] - #print("\t%g" % isi) + isi = spikes[i + 1]["t"] - spikes[i]["t"] + # print("\t%g" % isi) isi_mean += isi lst.append(isi) isi_mean /= 1.0 * len(lst) - #print(isi_mean) + # print(isi_mean) var = 0 for i in range(len(lst)): dif = isi_mean - lst[i] var += dif * dif var /= len(lst) - #var /= len(lst) - 1 - #print(math.sqrt(var)) + # var /= len(lst) - 1 + # print(math.sqrt(var)) if isi_mean > 0: return math.sqrt(var) / isi_mean return None @@ -493,14 +506,14 @@ def adaptation_index(self, spikes, stim_end): adi = 0 cnt = 0 isi = [] - for i in range(len(spikes)-1): - isi.append(spikes[i+1]["t"] - spikes[i]["t"]) + for i in range(len(spikes) - 1): + isi.append(spikes[i + 1]["t"] - spikes[i]["t"]) # act as though time between last spike and stim end is another ISI per Etay's code # l = stim_end - spikes[-1]["t"] # if l > 0 and l > isi[-1]: # isi.append(l) - for i in range(len(isi)-1): - adi += 1.0 * (isi[i+1] - isi[i]) / (isi[i+1] + isi[i]) + for i in range(len(isi) - 1): + adi += 1.0 * (isi[i + 1] - isi[i]) / (isi[i + 1] + isi[i]) cnt += 1 adi /= cnt return adi @@ -516,7 +529,7 @@ def calculate_trough(self, spike, v, curr, t, next_idx): peak_idx = spike["peak_idx"] if peak_idx >= next_idx: - logging.warning("next index (%d) before peak index (%d) calculating trough" % ( next_idx, peak_idx )) + logging.warning("next index (%d) before peak index (%d) calculating trough" % (next_idx, peak_idx)) trough_idx = next_idx else: trough_idx = np.argmin(v[peak_idx:next_idx]) + peak_idx @@ -582,7 +595,7 @@ def score_feature_set(self, set_num): val = cand.mean[k] inc = abs(mean - val) / stdev scores.append(inc) -# print("Hit %s, %g+/-%g (%g) = %g" % (k, mean, stdev, val, inc)) + # print("Hit %s, %g+/-%g (%g) = %g" % (k, mean, stdev, val, inc)) else: resp = cand.scoring[k]["miss"] if resp == "const": @@ -593,7 +606,7 @@ def score_feature_set(self, set_num): assert False miss = float(miss) scores.append(miss) -# print("Missed %s, penalty = %g" % (k, miss)) + # print("Missed %s, penalty = %g" % (k, miss)) elif response == "perspike": mean = self.summary.mean[k] stdev = self.summary.stdev[k] @@ -602,7 +615,7 @@ def score_feature_set(self, set_num): val = 0 n_spikes = len(cand.mean["spikes"]) skip_last_n = self.summary.scoring[k]["skip_last_n"] - for spike in cand.mean["spikes"][:n_spikes-skip_last_n]: + for spike in cand.mean["spikes"][: n_spikes - skip_last_n]: val += abs(spike[k] - mean) val /= n_spikes - skip_last_n inc = val / stdev @@ -617,7 +630,7 @@ def score_feature_set(self, set_num): assert False miss = float(miss) scores.append(miss) -# print("Missed %s, penalty = %g" % (k, miss)) + # print("Missed %s, penalty = %g" % (k, miss)) else: assert False if abs(sum(scores)) > 1e10: @@ -690,4 +703,3 @@ def summarize(self, summary): val /= 1.0 * len(self.feature_list) self.summary.stdev[k] = math.sqrt(val) return self - diff --git a/allensdk/internal/api/__init__.py b/allensdk/internal/api/__init__.py index 74442a7864..4be3c32abe 100644 --- a/allensdk/internal/api/__init__.py +++ b/allensdk/internal/api/__init__.py @@ -13,11 +13,13 @@ class OneOrMoreResultExpectedError(RuntimeError): def psycopg2_select(query, database, host, port, username, password): - connection = psycopg2.connect( - host=host, port=port, dbname=database, - user=username, password=password, - cursor_factory=psycopg2.extras.RealDictCursor + host=host, + port=port, + dbname=database, + user=username, + password=password, + cursor_factory=psycopg2.extras.RealDictCursor, ) cursor = connection.cursor() @@ -33,7 +35,6 @@ def psycopg2_select(query, database, host, port, username, password): class PostgresQueryMixin(object): def __init__(self, *, dbname, user, host, password, port): - self.dbname = dbname self.user = user self.host = host @@ -44,9 +45,9 @@ def get_cursor(self): return self.get_connection().cursor() def get_connection(self): - return psycopg2.connect(dbname=self.dbname, user=self.user, - host=self.host, password=self.password, - port=self.port) + return psycopg2.connect( + dbname=self.dbname, user=self.user, host=self.host, password=self.password, port=self.port + ) def fetchone(self, query, strict=True): response = one(list(self.select(query).to_dict().values())) @@ -60,24 +61,20 @@ def fetchall(self, query, strict=True): def select(self, query): return psycopg2_select( - query, - database=self.dbname, - host=self.host, - port=self.port, - username=self.user, - password=self.password + query, database=self.dbname, host=self.host, port=self.port, username=self.user, password=self.password ) def select_one(self, query): - data = self.select(query).to_dict('records') + data = self.select(query).to_dict("records") if len(data) == 1: return data[0] return {} -def db_connection_creator(credentials: Optional[DbCredentials] = None, - fallback_credentials: Optional[dict] = None, - ) -> PostgresQueryMixin: +def db_connection_creator( + credentials: Optional[DbCredentials] = None, + fallback_credentials: Optional[dict] = None, +) -> PostgresQueryMixin: """Create a db connection using credentials. If credentials are not provided then use fallback credentials (which attempt to read from shell environment variables). @@ -112,15 +109,17 @@ def db_connection_creator(credentials: Optional[DbCredentials] = None, """ if credentials: db_conn = PostgresQueryMixin( - dbname=credentials.dbname, user=credentials.user, - host=credentials.host, port=credentials.port, - password=credentials.password) + dbname=credentials.dbname, + user=credentials.user, + host=credentials.host, + port=credentials.port, + password=credentials.password, + ) elif fallback_credentials: - db_conn = (credential_injector(fallback_credentials) - (PostgresQueryMixin)()) + db_conn = credential_injector(fallback_credentials)(PostgresQueryMixin)() else: raise RuntimeError( - "Must provide either credentials or fallback credentials in " - "order to create a db connection!") + "Must provide either credentials or fallback credentials in order to create a db connection!" + ) return db_conn diff --git a/allensdk/internal/api/api_prerelease.py b/allensdk/internal/api/api_prerelease.py index 8f971e4934..95d3a3de84 100644 --- a/allensdk/internal/api/api_prerelease.py +++ b/allensdk/internal/api/api_prerelease.py @@ -4,11 +4,10 @@ class ApiPrerelease(Api): - '''Extends allensdk.api.api to copy files 'locally' from shared storage. - ''' + """Extends allensdk.api.api to copy files 'locally' from shared storage.""" def retrieve_file_from_storage(self, storage_path, save_file_path): - '''Copy data from path to file_name. + """Copy data from path to file_name. Parameters ---------- @@ -16,7 +15,7 @@ def retrieve_file_from_storage(self, storage_path, save_file_path): path to file in shared directory (copy source) save_file_name : string path to file destination (copy target) - ''' + """ self._file_download_log.info("Downloading PATH: %s", storage_path) self._file_download_log.debug("To PATH: %s", save_file_path) diff --git a/allensdk/internal/api/lims_api.py b/allensdk/internal/api/lims_api.py index d650d5f84a..f1d9fe6f51 100644 --- a/allensdk/internal/api/lims_api.py +++ b/allensdk/internal/api/lims_api.py @@ -6,39 +6,39 @@ from allensdk.core.authentication import credential_injector, DbCredentials -class LimsApi(): - +class LimsApi: def __init__(self, lims_credentials: Optional[DbCredentials] = None): if lims_credentials: self.lims_db = PostgresQueryMixin( - dbname=lims_credentials.dbname, user=lims_credentials.user, - host=lims_credentials.host, password=lims_credentials.password, - port=lims_credentials.port) + dbname=lims_credentials.dbname, + user=lims_credentials.user, + host=lims_credentials.host, + password=lims_credentials.password, + port=lims_credentials.port, + ) else: # Currying is equivalent to decorator syntactic sugar - self.lims_db = (credential_injector(LIMS_DB_CREDENTIAL_MAP) - (PostgresQueryMixin)()) + self.lims_db = credential_injector(LIMS_DB_CREDENTIAL_MAP)(PostgresQueryMixin)() def get_experiment_id(self): return self.experiment_id def get_behavior_tracking_video_filepath_df(self): - query = ''' + query = """ SELECT wkf.storage_directory || wkf.filename AS raw_behavior_tracking_video_filepath, attachable_type FROM well_known_files wkf WHERE wkf.well_known_file_type_id IN (SELECT id FROM well_known_file_types WHERE name = 'RawBehaviorTrackingVideo') - ''' + """ return pd.read_sql(query, self.lims_db.get_connection()) def get_eye_tracking_video_filepath_df(self): - query = ''' + query = """ SELECT wkf.storage_directory || wkf.filename AS raw_behavior_tracking_video_filepath, attachable_type FROM well_known_files wkf WHERE wkf.well_known_file_type_id IN (SELECT id FROM well_known_file_types WHERE name = 'RawEyeTrackingVideo') - ''' + """ return pd.read_sql(query, self.lims_db.get_connection()) if __name__ == "__main__": - api = LimsApi() for ii in range(5): print(api.get_eye_tracking_video_filepath_df().loc[ii].raw_behavior_tracking_video_filepath) diff --git a/allensdk/internal/api/mtrain_api.py b/allensdk/internal/api/mtrain_api.py index c876490426..deedc8c89b 100644 --- a/allensdk/internal/api/mtrain_api.py +++ b/allensdk/internal/api/mtrain_api.py @@ -8,167 +8,145 @@ from . import PostgresQueryMixin, db_connection_creator from allensdk.brain_observatory.behavior.trials_processing import EDF_COLUMNS -from allensdk.core.auth_config import MTRAIN_DB_CREDENTIAL_MAP, \ - LIMS_DB_CREDENTIAL_MAP +from allensdk.core.auth_config import MTRAIN_DB_CREDENTIAL_MAP, LIMS_DB_CREDENTIAL_MAP from allensdk.core.authentication import credential_injector -from allensdk.brain_observatory.behavior.data_objects \ - import BehaviorSessionId -from allensdk.brain_observatory.behavior.data_objects.metadata.\ - behavior_metadata.behavior_metadata import BehaviorMetadata +from allensdk.brain_observatory.behavior.data_objects import BehaviorSessionId +from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.behavior_metadata import ( + BehaviorMetadata, +) class MtrainApi: - - def __init__(self, api_base='http://mtrain:5000'): + def __init__(self, api_base="http://mtrain:5000"): self.api_base = api_base def get_page(self, table_name, get_obj=None, filters=[], **kwargs): - if get_obj is None: get_obj = requests - data = {'total_pages': '--'} + data = {"total_pages": "--"} for ii in itertools.count(1): sys.stdout.flush() - uri = '/'.join([self.api_base, - "api/v1/%s?page=%i&q={\"filters\":%s}" % ( - table_name, ii, json.dumps(filters))]) + uri = "/".join( + [self.api_base, 'api/v1/%s?page=%i&q={"filters":%s}' % (table_name, ii, json.dumps(filters))] + ) tmp = get_obj.get(uri, **kwargs) try: data = tmp.json() except TypeError: data = tmp.json - if 'message' not in data: + if "message" not in data: df = pd.DataFrame(data["objects"]) sys.stdout.flush() yield df - if 'total_pages' not in data or data['total_pages'] == ii: + if "total_pages" not in data or data["total_pages"] == ii: return def get_df(self, table_name, get_obj=None, **kwargs): - return pd.concat([df for df in - self.get_page(table_name, get_obj=get_obj, - **kwargs)], axis=0) + return pd.concat([df for df in self.get_page(table_name, get_obj=get_obj, **kwargs)], axis=0) def get_subjects(self): - return self.get_df('subjects').LabTracks_ID.values - - def get_session(self, behavior_session_uuid=None, - behavior_session_id=None): - assert not all(v is None for v in [ - behavior_session_uuid, - behavior_session_id]), 'must enter either a ' \ - 'behavior_session_uuid or a ' \ - 'behavior_session_id' + return self.get_df("subjects").LabTracks_ID.values + + def get_session(self, behavior_session_uuid=None, behavior_session_id=None): + assert not all(v is None for v in [behavior_session_uuid, behavior_session_id]), ( + "must enter either a behavior_session_uuid or a behavior_session_id" + ) if behavior_session_id is not None: + def _get_behavior_metadata(): - lims_db = db_connection_creator( - fallback_credentials=LIMS_DB_CREDENTIAL_MAP - ) - behavior_session_id_ = BehaviorSessionId( - behavior_session_id=behavior_session_id) - bm = BehaviorMetadata.from_lims( - behavior_session_id=behavior_session_id_, lims_db=lims_db) + lims_db = db_connection_creator(fallback_credentials=LIMS_DB_CREDENTIAL_MAP) + behavior_session_id_ = BehaviorSessionId(behavior_session_id=behavior_session_id) + bm = BehaviorMetadata.from_lims(behavior_session_id=behavior_session_id_, lims_db=lims_db) return bm + bm = _get_behavior_metadata() if behavior_session_uuid is not None: # if both a behavior session uuid and a lims id are entered, # ensure that they match - assert behavior_session_uuid == \ - str(bm.behavior_session_uuid), \ - 'behavior_session {} does not match ' \ - 'behavior_session_id {}'.format( - behavior_session_uuid, bm.behavior_session_uuid) + assert behavior_session_uuid == str(bm.behavior_session_uuid), ( + "behavior_session {} does not match behavior_session_id {}".format( + behavior_session_uuid, bm.behavior_session_uuid + ) + ) else: # get a behavior session uuid if a lims ID was entered behavior_session_uuid = str(bm.behavior_session_uuid) filters = [{"name": "id", "op": "eq", "val": behavior_session_uuid}] - behavior_df = self.get_df('behavior_sessions', filters=filters).rename( - columns={'id': 'behavior_session_uuid'}) - state_df = self.get_df('states').rename(columns={'id': 'state_id'}) - regimen_df = self.get_df('regimens').rename( - columns={'id': 'regimen_id', 'name': 'regimen_name'}).drop( - ['states', 'active'], axis=1) - stage_df = self.get_df('stages').rename( - columns={'id': 'stage_id'}).drop(['states'], axis=1) - - behavior_df = pd.merge(behavior_df, state_df, how='left', - on='state_id') - behavior_df = pd.merge(behavior_df, stage_df, how='left', - on='stage_id') - behavior_df = pd.merge(behavior_df, regimen_df, how='left', - on='regimen_id') - behavior_df.drop(['state_id', 'stage_id', 'regimen_id'], inplace=True, - axis=1) + behavior_df = self.get_df("behavior_sessions", filters=filters).rename(columns={"id": "behavior_session_uuid"}) + state_df = self.get_df("states").rename(columns={"id": "state_id"}) + regimen_df = ( + self.get_df("regimens") + .rename(columns={"id": "regimen_id", "name": "regimen_name"}) + .drop(["states", "active"], axis=1) + ) + stage_df = self.get_df("stages").rename(columns={"id": "stage_id"}).drop(["states"], axis=1) + + behavior_df = pd.merge(behavior_df, state_df, how="left", on="state_id") + behavior_df = pd.merge(behavior_df, stage_df, how="left", on="stage_id") + behavior_df = pd.merge(behavior_df, regimen_df, how="left", on="regimen_id") + behavior_df.drop(["state_id", "stage_id", "regimen_id"], inplace=True, axis=1) if len(behavior_df) == 0: raise RuntimeError("Session not found %s:" % behavior_session_uuid) assert len(behavior_df) == 1 session_dict = behavior_df.iloc[0].to_dict() - filters = [{"name": "behavior_session_uuid", "op": "eq", - "val": behavior_session_uuid}] - trials_df = self.get_df('trials', filters=filters).sort_values( - 'index').drop(['id', 'behavior_session'], axis=1).set_index( - 'index', drop=False) - trials_df['behavior_session_uuid'] = trials_df[ - 'behavior_session_uuid'].map(uuid.UUID) + filters = [{"name": "behavior_session_uuid", "op": "eq", "val": behavior_session_uuid}] + trials_df = ( + self.get_df("trials", filters=filters) + .sort_values("index") + .drop(["id", "behavior_session"], axis=1) + .set_index("index", drop=False) + ) + trials_df["behavior_session_uuid"] = trials_df["behavior_session_uuid"].map(uuid.UUID) trials_df.index.name = None - session_dict['trials'] = trials_df[EDF_COLUMNS] + session_dict["trials"] = trials_df[EDF_COLUMNS] return session_dict def get_behavior_training_df(self, LabTracks_ID=None): if LabTracks_ID is not None: - filters = [ - {"name": "LabTracks_ID", "op": "eq", "val": LabTracks_ID}] + filters = [{"name": "LabTracks_ID", "op": "eq", "val": LabTracks_ID}] else: filters = [] - behavior_df = self.get_df('behavior_sessions', filters=filters).rename( - columns={'id': 'behavior_session_uuid'}) - - state_df = self.get_df('states').rename(columns={'id': 'state_id'}) - regimen_df = self.get_df('regimens').rename( - columns={'id': 'regimen_id', 'name': 'regimen_name'}).drop( - ['states', 'active'], axis=1) - stage_df = self.get_df('stages').rename( - columns={'id': 'stage_id', 'name': 'stage_name'}).drop(['states'], - axis=1) - - behavior_df = pd.merge(behavior_df, state_df, how='left', - on='state_id') - behavior_df = pd.merge(behavior_df, stage_df, how='left', - on='stage_id') - behavior_df = pd.merge(behavior_df, regimen_df, how='left', - on='regimen_id') + behavior_df = self.get_df("behavior_sessions", filters=filters).rename(columns={"id": "behavior_session_uuid"}) + + state_df = self.get_df("states").rename(columns={"id": "state_id"}) + regimen_df = ( + self.get_df("regimens") + .rename(columns={"id": "regimen_id", "name": "regimen_name"}) + .drop(["states", "active"], axis=1) + ) + stage_df = ( + self.get_df("stages").rename(columns={"id": "stage_id", "name": "stage_name"}).drop(["states"], axis=1) + ) + + behavior_df = pd.merge(behavior_df, state_df, how="left", on="state_id") + behavior_df = pd.merge(behavior_df, stage_df, how="left", on="stage_id") + behavior_df = pd.merge(behavior_df, regimen_df, how="left", on="regimen_id") return behavior_df def get_current_stage(self, LabTracks_ID): sess = requests.Session() - state_response = sess.get(os.path.join(self.api_base, 'get_script/'), - data=json.dumps({ - 'LabTracks_ID': LabTracks_ID})) # + state_response = sess.get( + os.path.join(self.api_base, "get_script/"), data=json.dumps({"LabTracks_ID": LabTracks_ID}) + ) # # .json()#['objects']).keys() - return state_response.json()['data']['parameters']['stage'] + return state_response.json()["data"]["parameters"]["stage"] class MtrainSqlApi: - - def __init__(self, dbname=None, user=None, host=None, password=None, - port=None): + def __init__(self, dbname=None, user=None, host=None, password=None, port=None): if any(map(lambda x: x is None, [dbname, user, host, password, port])): # Currying is equivalent to decorator syntactic sugar - self.mtrain_db = ( - credential_injector(MTRAIN_DB_CREDENTIAL_MAP) - (PostgresQueryMixin)()) + self.mtrain_db = credential_injector(MTRAIN_DB_CREDENTIAL_MAP)(PostgresQueryMixin)() else: - self.mtrain_db = PostgresQueryMixin( - dbname=dbname, user=user, host=host, password=password, - port=port) + self.mtrain_db = PostgresQueryMixin(dbname=dbname, user=user, host=host, password=password, port=port) def get_subjects(self): query = 'SELECT "LabTracks_ID" FROM subjects' @@ -177,12 +155,14 @@ def get_subjects(self): def get_behavior_training_df(self, LabTracks_ID): connection = self.mtrain_db.get_connection() dataframe = pd.read_sql( - '''SELECT stages.name as stage_name, regimens.name as + """SELECT stages.name as stage_name, regimens.name as regimen_name, bs.date, bs.id as behavior_session_id FROM behavior_sessions bs LEFT JOIN states ON states.id = bs.state_id LEFT JOIN regimens ON regimens.id = states.regimen_id LEFT JOIN stages ON stages.id = states.stage_id WHERE "LabTracks_ID"={} - '''.format(LabTracks_ID), connection) - return dataframe.sort_values(by='date') + """.format(LabTracks_ID), + connection, + ) + return dataframe.sort_values(by="date") diff --git a/allensdk/internal/api/queries/behavior_lims_queries.py b/allensdk/internal/api/queries/behavior_lims_queries.py index 04d583f329..8640f5f481 100644 --- a/allensdk/internal/api/queries/behavior_lims_queries.py +++ b/allensdk/internal/api/queries/behavior_lims_queries.py @@ -3,14 +3,12 @@ from allensdk.internal.api import PostgresQueryMixin import logging -from allensdk.internal.api.queries.utils import ( - build_in_list_selector_query) +from allensdk.internal.api.queries.utils import build_in_list_selector_query def foraging_id_map_from_behavior_session_id( - lims_engine: PostgresQueryMixin, - behavior_session_ids: List[int], - logger: Optional[logging.RootLogger] = None) -> pd.DataFrame: + lims_engine: PostgresQueryMixin, behavior_session_ids: List[int], logger: Optional[logging.RootLogger] = None +) -> pd.DataFrame: """ Returns DataFrame with two columns: foraging_id @@ -26,9 +24,7 @@ def foraging_id_map_from_behavior_session_id( the foraging_id """ - behav_ids = build_in_list_selector_query("id", - behavior_session_ids, - operator="AND") + behav_ids = build_in_list_selector_query("id", behavior_session_ids, operator="AND") forag_ids_query = f""" SELECT foraging_id, id as behavior_session_id FROM behavior_sessions @@ -36,20 +32,21 @@ def foraging_id_map_from_behavior_session_id( {behav_ids}; """ if logger is not None: - logger.debug("get_foraging_ids_from_behavior_session query: \n" - f"{forag_ids_query}") + logger.debug(f"get_foraging_ids_from_behavior_session query: \n{forag_ids_query}") foraging_id_map = lims_engine.select(forag_ids_query) if logger is not None: - logger.debug(f"Retrieved {len(foraging_id_map)} foraging ids for" - " behavior stage query. " - f"Ids = {foraging_id_map.foraging_id}") + logger.debug( + f"Retrieved {len(foraging_id_map)} foraging ids for" + " behavior stage query. " + f"Ids = {foraging_id_map.foraging_id}" + ) return foraging_id_map def stimulus_pickle_paths_from_behavior_session_ids( - lims_connection: PostgresQueryMixin, - behavior_session_id_list: List[int]) -> pd.DataFrame: + lims_connection: PostgresQueryMixin, behavior_session_id_list: List[int] +) -> pd.DataFrame: """ Get a DataFrame mapping behavior_session_id to stimulus_pickle_path @@ -82,10 +79,7 @@ def stimulus_pickle_paths_from_behavior_session_ids( wkft.name = 'StimulusPickle' AND wkf.attachable_type = 'BehaviorSession' - {build_in_list_selector_query( - operator='AND', - col='beh.id', - valid_list=behavior_session_id_list)} + {build_in_list_selector_query(operator="AND", col="beh.id", valid_list=behavior_session_id_list)} """ beh_to_path = lims_connection.select(query) diff --git a/allensdk/internal/api/queries/biophysical_module_api.py b/allensdk/internal/api/queries/biophysical_module_api.py index 80c24ecf33..f37d5571fd 100644 --- a/allensdk/internal/api/queries/biophysical_module_api.py +++ b/allensdk/internal/api/queries/biophysical_module_api.py @@ -15,97 +15,97 @@ from allensdk.api.queries.rma_template import RmaTemplate + class BiophysicalModuleApi(RmaTemplate): - '''''' - - rma_templates = \ - {"biophysical_lims_queries": [ - {'name': 'neuronal_model_runs_by_ids', - 'description': 'see name', - 'model': 'NeuronalModelRun', - 'criteria': '[id$in{{ neuronal_model_run_ids }}]', - 'include': 'well_known_files(well_known_file_type),' - 'neuronal_model(well_known_files(well_known_file_type),' - 'specimen(project,specimen_tags,' - 'ephys_roi_result' - '(ephys_qc_criteria,' - 'well_known_files(well_known_file_type)),' - 'neuron_reconstructions' - '(well_known_files(well_known_file_type)),' - 'ephys_sweeps' - '(ephys_sweep_tags,' - 'ephys_stimulus(ephys_stimulus_type))),' - 'neuronal_model_template' - '(neuronal_model_template_type,' - 'well_known_files(well_known_file_type)))', - 'num_rows': 'all', - 'count': False, - 'criteria_params': ['neuronal_model_run_ids'] + """""" + + rma_templates = { + "biophysical_lims_queries": [ + { + "name": "neuronal_model_runs_by_ids", + "description": "see name", + "model": "NeuronalModelRun", + "criteria": "[id$in{{ neuronal_model_run_ids }}]", + "include": "well_known_files(well_known_file_type)," + "neuronal_model(well_known_files(well_known_file_type)," + "specimen(project,specimen_tags," + "ephys_roi_result" + "(ephys_qc_criteria," + "well_known_files(well_known_file_type))," + "neuron_reconstructions" + "(well_known_files(well_known_file_type))," + "ephys_sweeps" + "(ephys_sweep_tags," + "ephys_stimulus(ephys_stimulus_type)))," + "neuronal_model_template" + "(neuronal_model_template_type," + "well_known_files(well_known_file_type)))", + "num_rows": "all", + "count": False, + "criteria_params": ["neuronal_model_run_ids"], }, - {'name': 'neuronal_models_by_ids', - 'description': 'see name', - 'model': 'NeuronalModel', - 'criteria': '[id$in{{ neuronal_model_ids }}]', - 'include': 'well_known_files(well_known_file_type),' - 'specimen(project,specimen_tags,' - 'ephys_roi_result' - '(ephys_qc_criteria,' - 'well_known_files(well_known_file_type)),' - 'neuron_reconstructions' - '(well_known_files(well_known_file_type)),' - 'ephys_sweeps' - '(ephys_sweep_tags,' - 'ephys_stimulus(ephys_stimulus_type))),' - 'neuronal_model_template' - '(neuronal_model_template_type,' - 'well_known_files(well_known_file_type))', - 'num_rows': 'all', - 'count': False, - 'criteria_params': ['neuronal_model_ids'] - } - ]} - - + { + "name": "neuronal_models_by_ids", + "description": "see name", + "model": "NeuronalModel", + "criteria": "[id$in{{ neuronal_model_ids }}]", + "include": "well_known_files(well_known_file_type)," + "specimen(project,specimen_tags," + "ephys_roi_result" + "(ephys_qc_criteria," + "well_known_files(well_known_file_type))," + "neuron_reconstructions" + "(well_known_files(well_known_file_type))," + "ephys_sweeps" + "(ephys_sweep_tags," + "ephys_stimulus(ephys_stimulus_type)))," + "neuronal_model_template" + "(neuronal_model_template_type," + "well_known_files(well_known_file_type))", + "num_rows": "all", + "count": False, + "criteria_params": ["neuronal_model_ids"], + }, + ] + } + def __init__(self, base_uri=None): - super(BiophysicalModuleApi, self).__init__(base_uri, - query_manifest=BiophysicalModuleApi.rma_templates) - - + super(BiophysicalModuleApi, self).__init__(base_uri, query_manifest=BiophysicalModuleApi.rma_templates) + def get_neuronal_model_runs(self, neuronal_model_run_ids=None): - '''List Neuronal Model Rusn available through LIMS + """List Neuronal Model Rusn available through LIMS with associated info needed to run in NEURON. - + Parameters ---------- neuronal_model_run_ids : integer or list of integers, optional only select specific neuronal_model_runs. - + Returns ------- - dict : neuronal model run metadata - ''' - data = self.template_query('biophysical_lims_queries', - 'neuronal_model_runs_by_ids', - neuronal_model_run_ids=neuronal_model_run_ids) - + dict : neuronal model run metadata + """ + data = self.template_query( + "biophysical_lims_queries", "neuronal_model_runs_by_ids", neuronal_model_run_ids=neuronal_model_run_ids + ) + return data - - + def get_neuronal_models(self, neuronal_model_ids=None): - '''List Neuronal Models available through LIMS + """List Neuronal Models available through LIMS with associated info needed to run in NEURON. - + Parameters ---------- neuronal_model_ids : integer or list of integers, optional only select specific neuronal_models. - + Returns ------- - dict : neuronal model metadata - ''' - data = self.template_query('biophysical_lims_queries', - 'neuronal_models_by_ids', - neuronal_model_ids=neuronal_model_ids) - - return data \ No newline at end of file + dict : neuronal model metadata + """ + data = self.template_query( + "biophysical_lims_queries", "neuronal_models_by_ids", neuronal_model_ids=neuronal_model_ids + ) + + return data diff --git a/allensdk/internal/api/queries/biophysical_module_reader.py b/allensdk/internal/api/queries/biophysical_module_reader.py index 709663e3f1..0cb9699638 100644 --- a/allensdk/internal/api/queries/biophysical_module_reader.py +++ b/allensdk/internal/api/queries/biophysical_module_reader.py @@ -44,7 +44,7 @@ def read_lims_file(self, lims_path): self.lims_update_data = dict(self.lims_data) def read_json(self, path): - with open(path, 'rb') as f: + with open(path, "rb") as f: self.read_json_string(f.read()) def read_json_string(self, json_string): @@ -52,80 +52,79 @@ def read_json_string(self, json_string): self.lims_update_data = dict(self.lims_data) def stimulus_file_entries(self): - ''' read the well known file path from the lims result - corresponding to the stimulus file - :return: well_known_file entries - :rtype: array of dicts - ''' - neuronal_model = self.lims_data['neuronal_model'] - specimen = neuronal_model['specimen'] - roi_result = specimen['ephys_roi_result'] - well_known_files = roi_result['well_known_files'] + """read the well known file path from the lims result + corresponding to the stimulus file + :return: well_known_file entries + :rtype: array of dicts + """ + neuronal_model = self.lims_data["neuronal_model"] + specimen = neuronal_model["specimen"] + roi_result = specimen["ephys_roi_result"] + well_known_files = roi_result["well_known_files"] stimulus_file_entries = [] for well_known_file in well_known_files: try: - file_type_id = well_known_file['well_known_file_type_id'] - + file_type_id = well_known_file["well_known_file_type_id"] + if file_type_id == lims_utilities.NWB_FILE_TYPE_ID: stimulus_file_entries.append(well_known_file) except Exception: - self._log.warn('skipping well known file record with no well known file type.') + self._log.warn("skipping well known file record with no well known file type.") return stimulus_file_entries def stimulus_path(self): - ''' Get the path to the stimulus file from the lims result. - :return: path to stimulus file - :rtype: string - ''' + """Get the path to the stimulus file from the lims result. + :return: path to stimulus file + :rtype: string + """ file_entries = self.stimulus_file_entries() if len(file_entries) > 1: - self._log.warning('More than one stimulus file found.') + self._log.warning("More than one stimulus file found.") file_entry = file_entries[0] - stimulus_path = os.path.join(file_entry['storage_directory'], - file_entry['filename']) + stimulus_path = os.path.join(file_entry["storage_directory"], file_entry["filename"]) return stimulus_path def lims_working_directory(self): - ''' While this is the same directory as the neuronal_model_run - directory, it can be mocked out for testing if the - other directory is read only. - ''' + """While this is the same directory as the neuronal_model_run + directory, it can be mocked out for testing if the + other directory is read only. + """ return self.neuronal_model_run_dir() def neuronal_model_run_dir(self): - ''' read the directory path where - output goes from the lims optimization config json - - Parameters - ---------- - - Returns - ------- - string: - directory path - ''' - return self.lims_data['storage_directory'] + """read the directory path where + output goes from the lims optimization config json + + Parameters + ---------- + + Returns + ------- + string: + directory path + """ + return self.lims_data["storage_directory"] def fit_parameters_file_entries(self): - ''' read the fit_parameter file path from the lims result - corresponding to the stimulus file - :return: well_known_file entries - :rtype: array of dicts - ''' - neuronal_model = self.lims_data['neuronal_model'] - well_known_files = neuronal_model['well_known_files'] + """read the fit_parameter file path from the lims result + corresponding to the stimulus file + :return: well_known_file entries + :rtype: array of dicts + """ + neuronal_model = self.lims_data["neuronal_model"] + well_known_files = neuronal_model["well_known_files"] fit_parameter_file_entries = [] for well_known_file in well_known_files: - file_type_id = well_known_file['well_known_file_type_id'] + file_type_id = well_known_file["well_known_file_type_id"] if file_type_id == lims_utilities.MODEL_PARAMETERS_FILE_TYPE_ID: fit_parameter_file_entries.append(well_known_file) @@ -133,283 +132,266 @@ def fit_parameters_file_entries(self): return fit_parameter_file_entries def fit_parameters_path(self): - ''' Get the path to the fit parameters file from the lims result. - :return: path to file - :rtype: string - ''' + """Get the path to the fit parameters file from the lims result. + :return: path to file + :rtype: string + """ file_entries = self.fit_parameters_file_entries() if len(file_entries) > 1: - self._log.warning('More than one fit parameter file found.') + self._log.warning("More than one fit parameter file found.") file_entry = file_entries[0] - fit_parameter_path = os.path.join(file_entry['storage_directory'], - file_entry['filename']) + fit_parameter_path = os.path.join(file_entry["storage_directory"], file_entry["filename"]) return fit_parameter_path def model_type(self): - ''' TODO: comment - ''' - return self.lims_data['neuronal_model']['neuronal_model_template']['name'] + """TODO: comment""" + return self.lims_data["neuronal_model"]["neuronal_model_template"]["name"] def morphology_file_entries(self): - ''' read the well known file paths - from the lims result corresponding to the morphology - - Returns - ------- - arrary of dicts: - well known file entries - ''' - neuronal_model = self.lims_data['neuronal_model'] - specimen = neuronal_model['specimen'] - reconstructions = specimen['neuron_reconstructions'] + """read the well known file paths + from the lims result corresponding to the morphology + + Returns + ------- + arrary of dicts: + well known file entries + """ + neuronal_model = self.lims_data["neuronal_model"] + specimen = neuronal_model["specimen"] + reconstructions = specimen["neuron_reconstructions"] morphology_file_entries = [] for reconstruction in reconstructions: - superseded = reconstruction['superseded'] - manual = reconstruction['manual'] + superseded = reconstruction["superseded"] + manual = reconstruction["manual"] if manual and not superseded: - well_known_files = reconstruction['well_known_files'] + well_known_files = reconstruction["well_known_files"] for well_known_file in well_known_files: - file_type_id = well_known_file['well_known_file_type_id'] - + file_type_id = well_known_file["well_known_file_type_id"] + if file_type_id == BiophysicalModuleReader.MORPHOLOGY_TYPE_ID: morphology_file_entries.append(well_known_file) return morphology_file_entries def morphology_path(self): - ''' Get the path to the morphology file from the lims result. - :return: path to morphology file - :rtype: string - ''' + """Get the path to the morphology file from the lims result. + :return: path to morphology file + :rtype: string + """ file_entries = self.morphology_file_entries() if len(file_entries) > 1: - self._log.warning('More than one morphology file found.') + self._log.warning("More than one morphology file found.") file_entry = file_entries[0] - morphology_path = os.path.join(file_entry['storage_directory'], - file_entry['filename']) + morphology_path = os.path.join(file_entry["storage_directory"], file_entry["filename"]) return morphology_path def sweep_entries(self): - ''' read the sweep entries - from the lims result corresponding to the stimulus - :return: stimulus sweep entries - :rtype: array of dicts - ''' - neuronal_model = self.lims_data['neuronal_model'] - specimen = neuronal_model['specimen'] - sweeps = specimen['ephys_sweeps'] + """read the sweep entries + from the lims result corresponding to the stimulus + :return: stimulus sweep entries + :rtype: array of dicts + """ + neuronal_model = self.lims_data["neuronal_model"] + specimen = neuronal_model["specimen"] + sweeps = specimen["ephys_sweeps"] return sweeps def sweep_numbers_by_type(self): sweeps = self.sweep_entries() - d = {s['ephys_stimulus']['ephys_stimulus_type']['name']: [] for s in sweeps} + d = {s["ephys_stimulus"]["ephys_stimulus_type"]["name"]: [] for s in sweeps} for n, s in enumerate(sweeps): - t = s['ephys_stimulus']['ephys_stimulus_type']['name'] - d[t].append(s['sweep_number']) + t = s["ephys_stimulus"]["ephys_stimulus_type"]["name"] + d[t].append(s["sweep_number"]) return d def sweep_numbers(self): - ''' Get the stimulus sweep numbers from the lims result - :return: list of sweep numbers - :rtype: array of ints - ''' + """Get the stimulus sweep numbers from the lims result + :return: list of sweep numbers + :rtype: array of ints + """ sweep_entries = self.sweep_entries() if not sweep_entries or len(sweep_entries) < 1: - self._log.warning('No sweeps found.') + self._log.warning("No sweeps found.") - sweeps = [sweep_entry['sweep_number'] \ - for sweep_entry in sweep_entries \ - if sweep_entry['workflow_state'] == 'auto_passed' or \ - sweep_entry['workflow_state'] == 'manual_passed' ] + sweeps = [ + sweep_entry["sweep_number"] + for sweep_entry in sweep_entries + if sweep_entry["workflow_state"] == "auto_passed" or sweep_entry["workflow_state"] == "manual_passed" + ] return list(set(sweeps)) def mod_file_entries(self): - ''' read the NERUON .mod file entries - from the lims result corresponding to the NeuronModel - :return: well known file entries - :rtype: array of dicts - ''' - neuronal_model = self.lims_data['neuronal_model'] - model_template = neuronal_model['neuronal_model_template'] - well_known_files = model_template['well_known_files'] + """read the NERUON .mod file entries + from the lims result corresponding to the NeuronModel + :return: well known file entries + :rtype: array of dicts + """ + neuronal_model = self.lims_data["neuronal_model"] + model_template = neuronal_model["neuronal_model_template"] + well_known_files = model_template["well_known_files"] mod_file_entries = [] for well_known_file in well_known_files: - file_type_id = well_known_file['well_known_file_type_id'] - + file_type_id = well_known_file["well_known_file_type_id"] + if file_type_id == BiophysicalModuleReader.MOD_FILE_TYPE_ID: mod_file_entries.append(well_known_file) - + return mod_file_entries def mod_file_paths(self): - ''' Get the paths to the mod files from the lims result. - :return: paths to mod files - :rtype: array of strings - ''' + """Get the paths to the mod files from the lims result. + :return: paths to mod files + :rtype: array of strings + """ file_entries = self.mod_file_entries() if not file_entries or len(file_entries) < 1: - self._log.warning('No mod files found.') + self._log.warning("No mod files found.") mod_file_paths = [] for file_entry in file_entries: - mod_path = os.path.join(file_entry['storage_directory'], - file_entry['filename']) + mod_path = os.path.join(file_entry["storage_directory"], file_entry["filename"]) self._log.info(mod_path) mod_file_paths.append(mod_path) return mod_file_paths - def update_well_known_file(self, - path, - well_known_file_type_id=None): + def update_well_known_file(self, path, well_known_file_type_id=None): if well_known_file_type_id is None: - well_known_file_type_id = \ - lims_utilities.NWB_UNCOMPRESSED_FILE_TYPE_ID - well_known_files = self.lims_data['well_known_files'] + well_known_file_type_id = lims_utilities.NWB_UNCOMPRESSED_FILE_TYPE_ID + well_known_files = self.lims_data["well_known_files"] (dirname, filename) = os.path.split(os.path.abspath(path)) def get_nwb_file_id(f): - if ('well_known_file_type_id' in f and - f['well_known_file_type_id'] == well_known_file_type_id and - os.path.normpath(f['storage_directory']) == - os.path.normpath(dirname) and - f['filename'] == filename): - return f['id'] + if ( + "well_known_file_type_id" in f + and f["well_known_file_type_id"] == well_known_file_type_id + and os.path.normpath(f["storage_directory"]) == os.path.normpath(dirname) + and f["filename"] == filename + ): + return f["id"] else: return None def not_nwb_file(f): - if ('well_known_file_type_id' in f and - f['well_known_file_type_id'] == well_known_file_type_id): + if "well_known_file_type_id" in f and f["well_known_file_type_id"] == well_known_file_type_id: return False else: return True try: - existing_file_id = \ - next(wkf_id for wkf_id - in (get_nwb_file_id(f2) - for f2 in well_known_files) - if wkf_id) + existing_file_id = next(wkf_id for wkf_id in (get_nwb_file_id(f2) for f2 in well_known_files) if wkf_id) # existing parameter file found - self.lims_update_data['well_known_files'] = \ - [f for f in well_known_files if not_nwb_file(f)] - - self.lims_update_data['well_known_files'] += [{ - 'id': existing_file_id, - 'content_type': None, - 'filename': filename, - 'storage_directory': dirname, - 'well_known_file_type_id': well_known_file_type_id - }] + self.lims_update_data["well_known_files"] = [f for f in well_known_files if not_nwb_file(f)] + + self.lims_update_data["well_known_files"] += [ + { + "id": existing_file_id, + "content_type": None, + "filename": filename, + "storage_directory": dirname, + "well_known_file_type_id": well_known_file_type_id, + } + ] except StopIteration: # no matching nwb files found, remove possible unmatching - self.lims_update_data['well_known_files'] = \ - [f for f in well_known_files if not_nwb_file(f)] - + self.lims_update_data["well_known_files"] = [f for f in well_known_files if not_nwb_file(f)] + (dirname, filename) = os.path.split(os.path.abspath(path)) - self.lims_update_data['well_known_files'] += [{ - 'content_type': None, - 'filename': filename, - 'storage_directory': dirname, - 'well_known_file_type_id': well_known_file_type_id - }] + self.lims_update_data["well_known_files"] += [ + { + "content_type": None, + "filename": filename, + "storage_directory": dirname, + "well_known_file_type_id": well_known_file_type_id, + } + ] def set_workflow_state(self, state): - self.lims_update_data['workflow_state'] = state + self.lims_update_data["workflow_state"] = state def write_file(self, path): - with open(path, 'wb') as f: + with open(path, "wb") as f: f.write(json.dumps(self.lims_update_data, indent=2)) def to_manifest(self, manifest_path=None): b = ManifestBuilder() - b.add_path('BASEDIR', os.path.realpath(os.curdir)) + b.add_path("BASEDIR", os.path.realpath(os.curdir)) - b.add_path('WORKDIR', - os.path.realpath(os.curdir)) + b.add_path("WORKDIR", os.path.realpath(os.curdir)) - b.add_path('MORPHOLOGY', - self.morphology_path(), - typename='file') - b.add_path('CODE_DIR', 'templates') - b.add_path('MODFILE_DIR', 'modfiles') + b.add_path("MORPHOLOGY", self.morphology_path(), typename="file") + b.add_path("CODE_DIR", "templates") + b.add_path("MODFILE_DIR", "modfiles") for modfile in self.mod_file_entries(): - b.add_path('MOD_FILE_%s' % (os.path.splitext(modfile['filename'])[0]), - os.path.join(modfile['storage_directory'], - modfile['filename']), - typename='file', - format='MODFILE') - - b.add_path('neuronal_model_run_data', - self.lims_path, - typename='file') - - b.add_path('stimulus_path', - self.stimulus_path(), - typename='file', - format='NWB') - - b.add_path('manifest', - os.path.join(os.path.realpath(os.curdir), - manifest_path), - typename='file') - - neuronal_model_run_id = self.lims_data['id'] - - nwb_file_name, extension = \ - os.path.splitext(os.path.basename(self.stimulus_path())) - b.add_path('output_path', - '%d_virtual_experiment%s' % (neuronal_model_run_id, - extension), - typename='file', - parent_key='WORKDIR', - format='NWB') - - b.add_path('fit_parameters', - self.fit_parameters_path()) - - b.add_section('biophys', - {"biophys": [{"model_file": [ manifest_path, - self.fit_parameters_path() ], - "model_type": self.model_type()}]}) - - b.add_section('stimulus_conf', - {"runs": [{"neuronal_model_run_id": - neuronal_model_run_id, - "sweeps": self.sweep_numbers(), - "sweeps_by_type": self.sweep_numbers_by_type() - }]}) - - b.add_section('hoc_conf', - {"neuron" : [{"hoc": ["stdgui.hoc", - "import3d.hoc", - "cell.hoc" ] - }]}) + b.add_path( + "MOD_FILE_%s" % (os.path.splitext(modfile["filename"])[0]), + os.path.join(modfile["storage_directory"], modfile["filename"]), + typename="file", + format="MODFILE", + ) + + b.add_path("neuronal_model_run_data", self.lims_path, typename="file") + + b.add_path("stimulus_path", self.stimulus_path(), typename="file", format="NWB") + + b.add_path("manifest", os.path.join(os.path.realpath(os.curdir), manifest_path), typename="file") + + neuronal_model_run_id = self.lims_data["id"] + + nwb_file_name, extension = os.path.splitext(os.path.basename(self.stimulus_path())) + b.add_path( + "output_path", + "%d_virtual_experiment%s" % (neuronal_model_run_id, extension), + typename="file", + parent_key="WORKDIR", + format="NWB", + ) + + b.add_path("fit_parameters", self.fit_parameters_path()) + + b.add_section( + "biophys", + {"biophys": [{"model_file": [manifest_path, self.fit_parameters_path()], "model_type": self.model_type()}]}, + ) + + b.add_section( + "stimulus_conf", + { + "runs": [ + { + "neuronal_model_run_id": neuronal_model_run_id, + "sweeps": self.sweep_numbers(), + "sweeps_by_type": self.sweep_numbers_by_type(), + } + ] + }, + ) + + b.add_section("hoc_conf", {"neuron": [{"hoc": ["stdgui.hoc", "import3d.hoc", "cell.hoc"]}]}) m = Manifest(config=b.path_info) diff --git a/allensdk/internal/api/queries/compound_lims_queries.py b/allensdk/internal/api/queries/compound_lims_queries.py index 6479082484..c3dc3f1d2e 100644 --- a/allensdk/internal/api/queries/compound_lims_queries.py +++ b/allensdk/internal/api/queries/compound_lims_queries.py @@ -1,14 +1,12 @@ from typing import List import pandas as pd from allensdk.internal.api import PostgresQueryMixin -from allensdk.internal.api.queries.ecephys_lims_queries import ( - donor_id_list_from_ecephys_session_ids) +from allensdk.internal.api.queries.ecephys_lims_queries import donor_id_list_from_ecephys_session_ids from allensdk.internal.api.queries.utils import build_in_list_selector_query def behavior_sessions_from_ecephys_session_ids( - lims_connection: PostgresQueryMixin, - ecephys_session_id_list: List[int] + lims_connection: PostgresQueryMixin, ecephys_session_id_list: List[int] ) -> pd.DataFrame: """ Get a DataFrame listing all of the behavior sessions that @@ -38,8 +36,8 @@ def behavior_sessions_from_ecephys_session_ids( listing every behavior session the mice in question went through """ donor_id_list = donor_id_list_from_ecephys_session_ids( - lims_connection=lims_connection, - session_id_list=ecephys_session_id_list) + lims_connection=lims_connection, session_id_list=ecephys_session_id_list + ) query = f""" SELECT @@ -59,10 +57,7 @@ def behavior_sessions_from_ecephys_session_ids( ON genders.id = donors.gender_id JOIN equipment ON equipment.id = behavior.equipment_id - {build_in_list_selector_query( - col='donors.id', - valid_list=donor_id_list - )} + {build_in_list_selector_query(col="donors.id", valid_list=donor_id_list)} """ mouse_to_behavior = lims_connection.select(query) diff --git a/allensdk/internal/api/queries/ecephys_lims_queries.py b/allensdk/internal/api/queries/ecephys_lims_queries.py index 5a1ab6787c..11c2e93712 100644 --- a/allensdk/internal/api/queries/ecephys_lims_queries.py +++ b/allensdk/internal/api/queries/ecephys_lims_queries.py @@ -6,8 +6,8 @@ def donor_id_lookup_from_ecephys_session_ids( - lims_connection: PostgresQueryMixin, - session_id_list: List[int]) -> pd.DataFrame: + lims_connection: PostgresQueryMixin, session_id_list: List[int] +) -> pd.DataFrame: """ Return a dataframe with columns ecephys_session_id @@ -23,24 +23,19 @@ def donor_id_lookup_from_ecephys_session_ids( specimens.donor_id = donors.id JOIN ecephys_sessions ON ecephys_sessions.specimen_id = specimens.id - {build_in_list_selector_query( - col='ecephys_sessions.id', - valid_list=session_id_list - )} + {build_in_list_selector_query(col="ecephys_sessions.id", valid_list=session_id_list)} """ result = lims_connection.select(query) return result def donor_id_list_from_ecephys_session_ids( - lims_connection: PostgresQueryMixin, - session_id_list: List[int]) -> List[int]: + lims_connection: PostgresQueryMixin, session_id_list: List[int] +) -> List[int]: """ Get the list of donor IDs associated with a list of ecephys_session_ids """ - lookup = donor_id_lookup_from_ecephys_session_ids( - lims_connection=lims_connection, - session_id_list=session_id_list) + lookup = donor_id_lookup_from_ecephys_session_ids(lims_connection=lims_connection, session_id_list=session_id_list) return list(np.unique(lookup.donor_id)) diff --git a/allensdk/internal/api/queries/equipment_lims_queries.py b/allensdk/internal/api/queries/equipment_lims_queries.py index e8f1db8841..8cecb5b3ed 100644 --- a/allensdk/internal/api/queries/equipment_lims_queries.py +++ b/allensdk/internal/api/queries/equipment_lims_queries.py @@ -3,9 +3,8 @@ def experiment_configs_from_equipment_id_and_type( - equipment_id: int, - config_type: str, - lims_connection: PostgresQueryMixin) -> pd.DataFrame: + equipment_id: int, config_type: str, lims_connection: PostgresQueryMixin +) -> pd.DataFrame: """ Return the configuration of a piece of experimental equipment as a function of time. @@ -45,15 +44,13 @@ def experiment_configs_from_equipment_id_and_type( """ config_id = lims_connection.fetchone(query) return experiment_configs_from_equipment_id( - equipment_id=equipment_id, - config_type_id=config_id, - lims_connection=lims_connection) + equipment_id=equipment_id, config_type_id=config_id, lims_connection=lims_connection + ) def experiment_configs_from_equipment_id( - equipment_id: int, - config_type_id: int, - lims_connection: PostgresQueryMixin) -> pd.DataFrame: + equipment_id: int, config_type_id: int, lims_connection: PostgresQueryMixin +) -> pd.DataFrame: """ Return the configuration of a piece of experimental equipment as a function of time. diff --git a/allensdk/internal/api/queries/grid_data_api_prerelease.py b/allensdk/internal/api/queries/grid_data_api_prerelease.py index ff74632b21..e8b2cda1c1 100644 --- a/allensdk/internal/api/queries/grid_data_api_prerelease.py +++ b/allensdk/internal/api/queries/grid_data_api_prerelease.py @@ -9,12 +9,12 @@ from ...core import lims_utilities as lu -_STORAGE_DIRECTORY_QUERY = ''' +_STORAGE_DIRECTORY_QUERY = """ select iser.id, iser.storage_directory from image_series as iser where iser.storage_directory is not null -''' +""" @cacheable() @@ -23,19 +23,20 @@ def _get_grid_storage_directories(grid_data_directory): storage_directories = dict() for row in query_result: - path = lu.safe_system_path(row[b'storage_directory']) + path = lu.safe_system_path(row[b"storage_directory"]) # NOTE: hacky, but grid directory contains files without having # injection_density_*.nrrd, projection_density_*.nrrd, ... - grid_example = os.path.join(path, grid_data_directory, 'data_mask_100.nrrd') + grid_example = os.path.join(path, grid_data_directory, "data_mask_100.nrrd") if os.path.exists(grid_example): - storage_directories[str(row[b'id'])] = path + storage_directories[str(row[b"id"])] = path return storage_directories + class GridDataApiPrerelease(GridDataApi): - '''Client for retrieving prereleased mouse connectivity data from lims. + """Client for retrieving prereleased mouse connectivity data from lims. Parameters ---------- @@ -44,12 +45,13 @@ class GridDataApiPrerelease(GridDataApi): file_name : string, optional File name to save/read storage_directories dict. Passed to GridDataApiPrerelease constructor. - ''' - GRID_DATA_DIRECTORY = 'grid' + """ + + GRID_DATA_DIRECTORY = "grid" @classmethod def from_file_name(cls, file_name, cache=True, **kwargs): - '''Alternative constructor using cache path file_name. + """Alternative constructor using cache path file_name. Parameters ---------- @@ -61,7 +63,7 @@ def from_file_name(cls, file_name, cache=True, **kwargs): Returns ------- cls : instance of GridDataApiPrerelease - ''' + """ if os.path.exists(file_name): storage_directories = json_utilities.read(file_name) else: @@ -74,13 +76,12 @@ def from_file_name(cls, file_name, cache=True, **kwargs): return cls(storage_directories, **kwargs) def __init__(self, storage_directories, resolution=None, base_uri=None): - super(GridDataApiPrerelease, self).__init__(resolution=resolution, - base_uri=base_uri) + super(GridDataApiPrerelease, self).__init__(resolution=resolution, base_uri=base_uri) self.storage_directories = storage_directories self.api = ApiPrerelease() def download_projection_grid_data(self, path, experiment_id, file_name): - '''Copy data from path to file_name. + """Copy data from path to file_name. Parameters ---------- @@ -90,11 +91,12 @@ def download_projection_grid_data(self, path, experiment_id, file_name): image series id. file_name : string path to file destination (copy target) - ''' + """ try: storage_path = self.storage_directories[str(experiment_id)] except KeyError as e: - error = ''' + error = ( + """ experiment %s is not in the storage_directories dictionary this can be a result of one or more of: * an invalid experiment id @@ -102,13 +104,14 @@ def download_projection_grid_data(self, path, experiment_id, file_name): try either removing the storage_directories_prerelease.json manifest from you manifest directory, or passing an updated storage_directories dict to the GridDataApiPrerelase constructor. - ''' % experiment_id + """ + % experiment_id + ) self._file_download_log.error(error) self.cleanup_truncated_file(path) raise ValueError(error) from e - storage_path = os.path.join( - storage_path, self.GRID_DATA_DIRECTORY, file_name) + storage_path = os.path.join(storage_path, self.GRID_DATA_DIRECTORY, file_name) self.api.retrieve_file_from_storage(storage_path, path) diff --git a/allensdk/internal/api/queries/mouse_connectivity_api_prerelease.py b/allensdk/internal/api/queries/mouse_connectivity_api_prerelease.py index d26fdcd09f..f20cff0dea 100644 --- a/allensdk/internal/api/queries/mouse_connectivity_api_prerelease.py +++ b/allensdk/internal/api/queries/mouse_connectivity_api_prerelease.py @@ -10,7 +10,7 @@ _STRUCTURE_TREE_ROOT_NAME = "root" _STRUCTURE_TREE_ROOT_ACRONYM = "root" -_EXPERIMENT_QUERY = ''' +_EXPERIMENT_QUERY = """ with specimens_concat_workflows as ( select sp.id, string_agg(w.name, '|') as workflows @@ -62,7 +62,8 @@ left join injections_concat_structures as ics on ics.injection_id = inj.id -- only image series we can pull and ensure mice (should all be mice already) -- where iser.storage_directory is not null and d.organism_id = 2 -''' +""" + def _experiment_dict(row): # use empty strings instead of null @@ -71,50 +72,50 @@ def null_fill(s): exp = dict() - exp['id'] = row[b'id'] + exp["id"] = row[b"id"] - exp['age'] = null_fill(row[b'age']) - exp['gender'] = null_fill(row[b'gender']) - exp['project_code'] = null_fill(row[b'project_code']) - exp['specimen_name'] = null_fill(row[b'specimen_name']) - exp['transgenic_line'] = null_fill(row[b'transgenic_line']) - exp['workflow_state'] = null_fill(row[b'workflow_state']) + exp["age"] = null_fill(row[b"age"]) + exp["gender"] = null_fill(row[b"gender"]) + exp["project_code"] = null_fill(row[b"project_code"]) + exp["specimen_name"] = null_fill(row[b"specimen_name"]) + exp["transgenic_line"] = null_fill(row[b"transgenic_line"]) + exp["workflow_state"] = null_fill(row[b"workflow_state"]) # list : [''] or ['workflow1', 'workflow2', ... ] - exp['workflows'] = null_fill(row[b'workflows']) - exp['workflows'] = exp['workflows'].split('|') + exp["workflows"] = null_fill(row[b"workflows"]) + exp["workflows"] = exp["workflows"].split("|") - if row[b'structure_id'] is not None: - exp['structure_id'] = row[b'structure_id'] - exp['structure_name'] = row[b'structure_name'] - exp['structure_abbrev'] = row[b'structure_acronym'] + if row[b"structure_id"] is not None: + exp["structure_id"] = row[b"structure_id"] + exp["structure_name"] = row[b"structure_name"] + exp["structure_abbrev"] = row[b"structure_acronym"] else: # use root structure for compatibility with structure tree - exp['structure_id'] = _STRUCTURE_TREE_ROOT_ID - exp['structure_name'] = _STRUCTURE_TREE_ROOT_NAME - exp['structure_abbrev'] = _STRUCTURE_TREE_ROOT_ACRONYM - - if row[b'injection_structures_id'] is not None: - ids = row[b'injection_structures_id'].split('|') - names = row[b'injection_structures_name'].split('|') - acronyms = row[b'injection_structures_acronym'].split('|') + exp["structure_id"] = _STRUCTURE_TREE_ROOT_ID + exp["structure_name"] = _STRUCTURE_TREE_ROOT_NAME + exp["structure_abbrev"] = _STRUCTURE_TREE_ROOT_ACRONYM + + if row[b"injection_structures_id"] is not None: + ids = row[b"injection_structures_id"].split("|") + names = row[b"injection_structures_name"].split("|") + acronyms = row[b"injection_structures_acronym"].split("|") else: # have at least prim. inj. struct. in structures - ids = (exp['structure_id'], ) - names = (exp['structure_name'], ) - acronyms = (exp['structure_abbrev'], ) + ids = (exp["structure_id"],) + names = (exp["structure_name"],) + acronyms = (exp["structure_abbrev"],) - keys = 'id', 'name', 'abbreviation' + keys = "id", "name", "abbreviation" values = zip(ids, names, acronyms) structures = map(lambda s: dict(zip(keys, s)), values) - exp['injection_structures'] = list(structures) + exp["injection_structures"] = list(structures) return exp class MouseConnectivityApiPrerelease(MouseConnectivityApi): - '''Client for retrieving prereleased mouse connectivity data from lims. + """Client for retrieving prereleased mouse connectivity data from lims. Parameters ---------- @@ -123,15 +124,13 @@ class MouseConnectivityApiPrerelease(MouseConnectivityApi): file_name : string, optional File name to save/read storage_directories dict. Passed to GridDataApiPrerelease constructor. - ''' + """ - def __init__(self, - storage_directories_file_name, - cache_storage_directories=True, - base_uri=None): + def __init__(self, storage_directories_file_name, cache_storage_directories=True, base_uri=None): super(MouseConnectivityApiPrerelease, self).__init__(base_uri=base_uri) self.grid_data_api = GridDataApiPrerelease.from_file_name( - storage_directories_file_name, cache=cache_storage_directories) + storage_directories_file_name, cache=cache_storage_directories + ) @cacheable() def get_experiments(self): @@ -139,49 +138,36 @@ def get_experiments(self): experiments = [] for row in query_result: - if str(row[b'id']) in self.grid_data_api.storage_directories: - + if str(row[b"id"]) in self.grid_data_api.storage_directories: exp_dict = _experiment_dict(row) experiments.append(exp_dict) return experiments - #@cacheable() + # @cacheable() def get_structure_unionizes(self): raise NotImplementedError() - @cacheable(strategy='create', - pathfinder=Cache.pathfinder(file_name_position=1, - path_keyword='path')) + @cacheable(strategy="create", pathfinder=Cache.pathfinder(file_name_position=1, path_keyword="path")) def download_injection_density(self, path, experiment_id, resolution): file_name = "%s_%s.nrrd" % (GridDataApi.INJECTION_DENSITY, resolution) - self.grid_data_api.download_projection_grid_data( - path, experiment_id, file_name) + self.grid_data_api.download_projection_grid_data(path, experiment_id, file_name) - @cacheable(strategy='create', - pathfinder=Cache.pathfinder(file_name_position=1, - path_keyword='path')) + @cacheable(strategy="create", pathfinder=Cache.pathfinder(file_name_position=1, path_keyword="path")) def download_projection_density(self, path, experiment_id, resolution): file_name = "%s_%s.nrrd" % (GridDataApi.PROJECTION_DENSITY, resolution) - self.grid_data_api.download_projection_grid_data( - path, experiment_id, file_name) + self.grid_data_api.download_projection_grid_data(path, experiment_id, file_name) - @cacheable(strategy='create', - pathfinder=Cache.pathfinder(file_name_position=1, - path_keyword='path')) + @cacheable(strategy="create", pathfinder=Cache.pathfinder(file_name_position=1, path_keyword="path")) def download_injection_fraction(self, path, experiment_id, resolution): file_name = "%s_%s.nrrd" % (GridDataApi.INJECTION_FRACTION, resolution) - self.grid_data_api.download_projection_grid_data( - path, experiment_id, file_name) + self.grid_data_api.download_projection_grid_data(path, experiment_id, file_name) - @cacheable(strategy='create', - pathfinder=Cache.pathfinder(file_name_position=1, - path_keyword='path')) + @cacheable(strategy="create", pathfinder=Cache.pathfinder(file_name_position=1, path_keyword="path")) def download_data_mask(self, path, experiment_id, resolution): file_name = "%s_%s.nrrd" % (GridDataApi.DATA_MASK, resolution) - self.grid_data_api.download_projection_grid_data( - path, experiment_id, file_name) + self.grid_data_api.download_projection_grid_data(path, experiment_id, file_name) diff --git a/allensdk/internal/api/queries/optimize_config_reader.py b/allensdk/internal/api/queries/optimize_config_reader.py index 4b02502cec..4082a78d0b 100644 --- a/allensdk/internal/api/queries/optimize_config_reader.py +++ b/allensdk/internal/api/queries/optimize_config_reader.py @@ -23,12 +23,12 @@ class OptimizeConfigReader(object): - _log = logging.getLogger('allensdk.internal.api.queries.lims.optimize_config_reader') - + _log = logging.getLogger("allensdk.internal.api.queries.lims.optimize_config_reader") + STIMULUS_CONTENT_TYPE = None MORPHOLOGY_TYPE_ID = 303941301 MOD_FILE_TYPE_ID = 292178729 - NEURONAL_MODEL_PARAMETERS = 329230374 # fit.json file. + NEURONAL_MODEL_PARAMETERS = 329230374 # fit.json file. def __init__(self): self.lims_path = None @@ -47,11 +47,11 @@ def read_lims_file(self, lims_path): def read_json(self, path): self.lims_path = os.path.realpath(path) - + with open(self.lims_path) as f: json_string = f.read() self.read_json_string(json_string) - + return self.lims_data def read_json_string(self, json_string): @@ -59,164 +59,163 @@ def read_json_string(self, json_string): self.lims_update_data = dict(self.lims_data) def write_file(self, path): - with open(path, 'wb') as f: + with open(path, "wb") as f: f.write(json.dumps(self.lims_update_data, indent=2)) def stimulus_file_entries(self): - ''' read the well known file path from the lims result - corresponding to the stimulus file - :return: well_known_file entries - :rtype: array of dicts - ''' + """read the well known file path from the lims result + corresponding to the stimulus file + :return: well_known_file entries + :rtype: array of dicts + """ neuronal_model = self.lims_data - specimen = neuronal_model['specimen'] - roi_result = specimen['ephys_roi_result'] - well_known_files = roi_result['well_known_files'] + specimen = neuronal_model["specimen"] + roi_result = specimen["ephys_roi_result"] + well_known_files = roi_result["well_known_files"] stimulus_file_entries = [] for well_known_file in well_known_files: try: - file_type_id = well_known_file['well_known_file_type_id'] - + file_type_id = well_known_file["well_known_file_type_id"] + if file_type_id == lims_utilities.NWB_FILE_TYPE_ID: stimulus_file_entries.append(well_known_file) except Exception: - OptimizeConfigReader._log.warn('skipping well known file record with no well known file type.') + OptimizeConfigReader._log.warn("skipping well known file record with no well known file type.") return stimulus_file_entries def lims_working_directory(self): - ''' While this is the same directory as the optimize - directory, it can be mocked out for testing if the - optimize directory is write only. - ''' + """While this is the same directory as the optimize + directory, it can be mocked out for testing if the + optimize directory is write only. + """ return self.neuronal_model_optimize_dir() def output_directory(self): - return os.path.join(self.lims_working_directory(), 'work') + return os.path.join(self.lims_working_directory(), "work") def stimulus_path(self): - ''' Get the path to the stimulus file from the lims result. - :return: path to stimulus file - :rtype: string - ''' + """Get the path to the stimulus file from the lims result. + :return: path to stimulus file + :rtype: string + """ file_entries = self.stimulus_file_entries() if len(file_entries) > 1: - OptimizeConfigReader._log.warning('More than one stimulus file found.') + OptimizeConfigReader._log.warning("More than one stimulus file found.") file_entry = file_entries[0] - stimulus_path = os.path.join(file_entry['storage_directory'], - file_entry['filename']) + stimulus_path = os.path.join(file_entry["storage_directory"], file_entry["filename"]) return stimulus_path def neuronal_model_optimize_dir(self): - ''' read the directory path where - output goes from the lims optimization config json - - Parameters - ---------- - - Returns - ------- - string: - directory path - ''' - return self.lims_data['storage_directory'] + """read the directory path where + output goes from the lims optimization config json + + Parameters + ---------- + + Returns + ------- + string: + directory path + """ + return self.lims_data["storage_directory"] def morphology_file_entries(self): - ''' read the well known file paths - from the lims result corresponding to the morphology - - Returns - ------- - arrary of dicts: - well known file entries - ''' + """read the well known file paths + from the lims result corresponding to the morphology + + Returns + ------- + arrary of dicts: + well known file entries + """ neuronal_model = self.lims_data - specimen = neuronal_model['specimen'] - reconstructions = specimen['neuron_reconstructions'] + specimen = neuronal_model["specimen"] + reconstructions = specimen["neuron_reconstructions"] morphology_file_entries = [] for reconstruction in reconstructions: - superseded = reconstruction['superseded'] - manual = reconstruction['manual'] - + superseded = reconstruction["superseded"] + manual = reconstruction["manual"] + if manual and not superseded: - well_known_files = reconstruction['well_known_files'] - + well_known_files = reconstruction["well_known_files"] + for well_known_file in well_known_files: - file_type_id = well_known_file['well_known_file_type_id'] - + file_type_id = well_known_file["well_known_file_type_id"] + if file_type_id == OptimizeConfigReader.MORPHOLOGY_TYPE_ID: morphology_file_entries.append(well_known_file) return morphology_file_entries def morphology_path(self): - ''' Get the path to the morphology file from the lims result. - :return: path to morphology file - :rtype: string - ''' + """Get the path to the morphology file from the lims result. + :return: path to morphology file + :rtype: string + """ file_entries = self.morphology_file_entries() - + if len(file_entries) > 1: - OptimizeConfigReader._log.warning('More than one morphology file found.') - + OptimizeConfigReader._log.warning("More than one morphology file found.") + file_entry = file_entries[0] - - morphology_path = os.path.join(file_entry['storage_directory'], - file_entry['filename']) - + + morphology_path = os.path.join(file_entry["storage_directory"], file_entry["filename"]) + return morphology_path def sweep_entries(self): - ''' read the sweep entries - from the lims result corresponding to the stimulus - :return: stimulus sweep entries - :rtype: array of dicts - ''' + """read the sweep entries + from the lims result corresponding to the stimulus + :return: stimulus sweep entries + :rtype: array of dicts + """ neuronal_model = self.lims_data - specimen = neuronal_model['specimen'] - sweeps = specimen['ephys_sweeps'] - + specimen = neuronal_model["specimen"] + sweeps = specimen["ephys_sweeps"] + return sweeps def sweep_numbers(self): - ''' Get the stimulus sweep numbers from the lims result - :return: list of sweep numbers - :rtype: array of ints - ''' + """Get the stimulus sweep numbers from the lims result + :return: list of sweep numbers + :rtype: array of ints + """ sweep_entries = self.sweep_entries() if not sweep_entries or len(sweep_entries) < 1: - OptimizeConfigReader._log.warning('No sweeps found.') + OptimizeConfigReader._log.warning("No sweeps found.") - sweeps = [sweep_entry['sweep_number'] \ - for sweep_entry in sweep_entries \ - if sweep_entry['workflow_state'] == 'auto_passed' or \ - sweep_entry['workflow_state'] == 'manual_passed' ] + sweeps = [ + sweep_entry["sweep_number"] + for sweep_entry in sweep_entries + if sweep_entry["workflow_state"] == "auto_passed" or sweep_entry["workflow_state"] == "manual_passed" + ] return list(set(sweeps)) def mod_file_entries(self): - ''' read the NERUON .mod file entries - from the lims result corresponding to the NeuronModel - :return: well known file entries - :rtype: array of dicts - ''' + """read the NERUON .mod file entries + from the lims result corresponding to the NeuronModel + :return: well known file entries + :rtype: array of dicts + """ neuronal_model = self.lims_data - model_template = neuronal_model['neuronal_model_template'] - well_known_files = model_template['well_known_files'] + model_template = neuronal_model["neuronal_model_template"] + well_known_files = model_template["well_known_files"] mod_file_entries = [] for well_known_file in well_known_files: - file_type_id = well_known_file['well_known_file_type_id'] + file_type_id = well_known_file["well_known_file_type_id"] if file_type_id == OptimizeConfigReader.MOD_FILE_TYPE_ID: mod_file_entries.append(well_known_file) @@ -224,176 +223,120 @@ def mod_file_entries(self): return mod_file_entries def mod_file_paths(self): - ''' Get the paths to the mod files from the lims result. - :return: paths to mod files - :rtype: array of strings - ''' + """Get the paths to the mod files from the lims result. + :return: paths to mod files + :rtype: array of strings + """ file_entries = self.mod_file_entries() if not file_entries or len(file_entries) < 1: - OptimizeConfigReader._log.warning('No mod files found.') + OptimizeConfigReader._log.warning("No mod files found.") mod_file_paths = [] for file_entry in file_entries: - mod_path = os.path.join(file_entry['storage_directory'], - file_entry['filename']) + mod_path = os.path.join(file_entry["storage_directory"], file_entry["filename"]) OptimizeConfigReader._log.info(mod_path) mod_file_paths.append(mod_path) return mod_file_paths - def update_well_known_file(self, - path, - well_known_file_type_id=None): + def update_well_known_file(self, path, well_known_file_type_id=None): if well_known_file_type_id is None: well_known_file_type_id = lims_utilities.MODEL_PARAMETERS_FILE_TYPE_ID - well_known_files = self.lims_data['well_known_files'] + well_known_files = self.lims_data["well_known_files"] def get_model_parameter_file_id(f): - if ('well_known_file_type_id' in f and - f['well_known_file_type_id'] == well_known_file_type_id): - return f['id'] + if "well_known_file_type_id" in f and f["well_known_file_type_id"] == well_known_file_type_id: + return f["id"] else: return None def not_fit_param(f): - if ('well_known_file_type_id' in f and - f['well_known_file_type_id'] == well_known_file_type_id): + if "well_known_file_type_id" in f and f["well_known_file_type_id"] == well_known_file_type_id: return False else: return True try: - existing_file_id = \ - next(wkf_id for wkf_id in - (get_model_parameter_file_id(f2) - for f2 in well_known_files) if wkf_id) + existing_file_id = next( + wkf_id for wkf_id in (get_model_parameter_file_id(f2) for f2 in well_known_files) if wkf_id + ) # existing parameter file found - self.lims_update_data['well_known_files'] = [f for f in well_known_files if not_fit_param(f)] + self.lims_update_data["well_known_files"] = [f for f in well_known_files if not_fit_param(f)] (dirname, filename) = os.path.split(os.path.abspath(path)) - self.lims_update_data['well_known_files'] += [{ - 'id': existing_file_id, - 'filename': filename, - 'storage_directory': dirname, - 'well_known_file_type_id': well_known_file_type_id - }] + self.lims_update_data["well_known_files"] += [ + { + "id": existing_file_id, + "filename": filename, + "storage_directory": dirname, + "well_known_file_type_id": well_known_file_type_id, + } + ] except StopIteration: # no parameter files found (dirname, filename) = os.path.split(os.path.abspath(path)) - self.lims_update_data['well_known_files'] += [{ - 'content_type': 'application/json', - 'filename': filename, - 'storage_directory': dirname, - 'well_known_file_type_id': well_known_file_type_id - }] + self.lims_update_data["well_known_files"] += [ + { + "content_type": "application/json", + "filename": filename, + "storage_directory": dirname, + "well_known_file_type_id": well_known_file_type_id, + } + ] def build_manifest(self, manifest_path=None): b = ManifestBuilder() - b.add_path('BASEDIR', os.path.realpath(os.curdir)) + b.add_path("BASEDIR", os.path.realpath(os.curdir)) - b.add_path('WORKDIR', - self.output_directory()) + b.add_path("WORKDIR", self.output_directory()) - b.add_path('MORPHOLOGY', - self.morphology_path(), - typename='file') + b.add_path("MORPHOLOGY", self.morphology_path(), typename="file") - b.add_path('MODFILE_DIR', 'modfiles') + b.add_path("MODFILE_DIR", "modfiles") for modfile in self.mod_file_entries(): - b.add_path('MOD_FILE_%s' % (os.path.splitext(modfile['filename'])[0]), - os.path.join(modfile['storage_directory'], - modfile['filename']), - typename='file', - format='MODFILE') - - b.add_path('stimulus_path', - self.stimulus_path(), - typename='file', - format='NWB') - - b.add_path('manifest', - os.path.join(os.path.realpath(os.curdir), - manifest_path), - typename='file') - - b.add_path('output', - os.path.basename(self.stimulus_path()), - typename='file', - parent_key='WORKDIR', - format='NWB') - - b.add_path('neuronal_model_data', - self.lims_path, - typename='file') - - b.add_path('upfile', - 'upbase.dat', - typename='file', - parent_key='WORKDIR') - b.add_path('downfile', - 'downbase.dat', - typename='file', - parent_key='WORKDIR') - b.add_path('passive_fit_data', - 'passive_fit_data.json', - typename='file', - parent_key='WORKDIR') - b.add_path('stage_1_jobs', - 'stage_1_jobs.json', - typename='file', - parent_key='WORKDIR') - b.add_path('fit_1_file', - 'fit_1_data.json', - typename='file', - parent_key='WORKDIR') - b.add_path('fit_2_file', - 'fit_2_data.json', - typename='file', - parent_key='WORKDIR') - b.add_path('fit_3_file', - 'fit_3_data.json', - typename='file', - parent_key='WORKDIR') - b.add_path('fit_type_path', - typename='file', - spec='%s', - parent_key='WORKDIR') - b.add_path('target_path', - typename='file', - spec='target.json', - parent_key='WORKDIR') - b.add_path('fit_config_json', - typename='file', - spec='%s/config.json', - parent_key='WORKDIR') - b.add_path('final_hof_fit', - typename='file', - spec='%s/s%d/final_hof_fit.txt', - parent_key='WORKDIR') - b.add_path('final_hof', - typename='file', - spec='%s/s%d/final_hof.txt', - parent_key='WORKDIR') - b.add_path('output_fit_file', - typename='file', - spec='fit_%s_%s.json') - - b.add_section('biophys', {"biophys": [ - {"model_file": [ manifest_path ] }]}) - - b.add_section('stimulus_conf', - {"runs": [{"sweeps": self.sweep_numbers(), - "specimen_id": self.lims_data['specimen_id'] - }]}) - - b.add_section('hoc_conf', - {"neuron" : [{"hoc": [ "stdgui.hoc", "import3d.hoc", "cell.hoc" ] - }]}) + b.add_path( + "MOD_FILE_%s" % (os.path.splitext(modfile["filename"])[0]), + os.path.join(modfile["storage_directory"], modfile["filename"]), + typename="file", + format="MODFILE", + ) + + b.add_path("stimulus_path", self.stimulus_path(), typename="file", format="NWB") + + b.add_path("manifest", os.path.join(os.path.realpath(os.curdir), manifest_path), typename="file") + + b.add_path( + "output", os.path.basename(self.stimulus_path()), typename="file", parent_key="WORKDIR", format="NWB" + ) + + b.add_path("neuronal_model_data", self.lims_path, typename="file") + + b.add_path("upfile", "upbase.dat", typename="file", parent_key="WORKDIR") + b.add_path("downfile", "downbase.dat", typename="file", parent_key="WORKDIR") + b.add_path("passive_fit_data", "passive_fit_data.json", typename="file", parent_key="WORKDIR") + b.add_path("stage_1_jobs", "stage_1_jobs.json", typename="file", parent_key="WORKDIR") + b.add_path("fit_1_file", "fit_1_data.json", typename="file", parent_key="WORKDIR") + b.add_path("fit_2_file", "fit_2_data.json", typename="file", parent_key="WORKDIR") + b.add_path("fit_3_file", "fit_3_data.json", typename="file", parent_key="WORKDIR") + b.add_path("fit_type_path", typename="file", spec="%s", parent_key="WORKDIR") + b.add_path("target_path", typename="file", spec="target.json", parent_key="WORKDIR") + b.add_path("fit_config_json", typename="file", spec="%s/config.json", parent_key="WORKDIR") + b.add_path("final_hof_fit", typename="file", spec="%s/s%d/final_hof_fit.txt", parent_key="WORKDIR") + b.add_path("final_hof", typename="file", spec="%s/s%d/final_hof.txt", parent_key="WORKDIR") + b.add_path("output_fit_file", typename="file", spec="fit_%s_%s.json") + + b.add_section("biophys", {"biophys": [{"model_file": [manifest_path]}]}) + + b.add_section( + "stimulus_conf", {"runs": [{"sweeps": self.sweep_numbers(), "specimen_id": self.lims_data["specimen_id"]}]} + ) + + b.add_section("hoc_conf", {"neuron": [{"hoc": ["stdgui.hoc", "import3d.hoc", "cell.hoc"]}]}) return b diff --git a/allensdk/internal/api/queries/pre_release.py b/allensdk/internal/api/queries/pre_release.py index 52058f94b8..02816318e5 100644 --- a/allensdk/internal/api/queries/pre_release.py +++ b/allensdk/internal/api/queries/pre_release.py @@ -4,26 +4,24 @@ import os import collections -sql_query_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'pre_release_sql') +sql_query_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "pre_release_sql") -with open(os.path.join(sql_query_dir, 'experiment_pre_release_query.sql'), 'r') as f: +with open(os.path.join(sql_query_dir, "experiment_pre_release_query.sql"), "r") as f: experiment_pre_release_query = f.read() -with open(os.path.join(sql_query_dir, 'container_pre_release_query.sql'), 'r') as f: +with open(os.path.join(sql_query_dir, "container_pre_release_query.sql"), "r") as f: container_pre_release_query = f.read() -with open(os.path.join(sql_query_dir, 'cell_specimens_pre_release_query.sql'), 'r') as f: +with open(os.path.join(sql_query_dir, "cell_specimens_pre_release_query.sql"), "r") as f: cell_specimens_pre_release_query = f.read() -class BrainObservatoryApiPreRelease(BrainObservatoryApi): +class BrainObservatoryApiPreRelease(BrainObservatoryApi): @cacheable() def get_experiment_containers(self): - query_result = lu.query(container_pre_release_query) container_list = [] for q in query_result: - # # For development: print key/val pairs generated from LIMS query: # for key, val in sorted(q.items(), key=lambda x: x[0]): # print(key, val) @@ -31,24 +29,26 @@ def get_experiment_containers(self): c = collections.defaultdict(collections.defaultdict) - c['id'] = q['ec_id'] - c['targeted_structure']['acronym'] = q['acronym'] - c['specimen']['donor'] = collections.defaultdict(collections.defaultdict) - c['specimen']['donor']['external_donor_name'] = q['external_donor_name'] - c['specimen']['donor']['transgenic_lines'] = [collections.defaultdict(collections.defaultdict), collections.defaultdict(collections.defaultdict)] - c['specimen']['donor']['transgenic_lines'][0]['transgenic_line_type_name'] = 'driver' - c['specimen']['donor']['transgenic_lines'][0]['name'] = q['driver'] - c['specimen']['donor']['transgenic_lines'][1]['transgenic_line_type_name'] = 'reporter' - c['specimen']['donor']['transgenic_lines'][1]['name'] = q['reporter'] - c['specimen']['name'] = q['specimen'] - c['imaging_depth'] = q['depth'] - c['failed'] = q['oa_state'] == 'failed' - - - if q['donor_tags'] == u'Epileptiform Events': - c['specimen']['donor']['conditions'] = [collections.defaultdict(collections.defaultdict)] - c['specimen']['donor']['conditions'][0]['name'] = u'Epileptiform Events' - elif q['donor_tags'] == u'': + c["id"] = q["ec_id"] + c["targeted_structure"]["acronym"] = q["acronym"] + c["specimen"]["donor"] = collections.defaultdict(collections.defaultdict) + c["specimen"]["donor"]["external_donor_name"] = q["external_donor_name"] + c["specimen"]["donor"]["transgenic_lines"] = [ + collections.defaultdict(collections.defaultdict), + collections.defaultdict(collections.defaultdict), + ] + c["specimen"]["donor"]["transgenic_lines"][0]["transgenic_line_type_name"] = "driver" + c["specimen"]["donor"]["transgenic_lines"][0]["name"] = q["driver"] + c["specimen"]["donor"]["transgenic_lines"][1]["transgenic_line_type_name"] = "reporter" + c["specimen"]["donor"]["transgenic_lines"][1]["name"] = q["reporter"] + c["specimen"]["name"] = q["specimen"] + c["imaging_depth"] = q["depth"] + c["failed"] = q["oa_state"] == "failed" + + if q["donor_tags"] == "Epileptiform Events": + c["specimen"]["donor"]["conditions"] = [collections.defaultdict(collections.defaultdict)] + c["specimen"]["donor"]["conditions"][0]["name"] = "Epileptiform Events" + elif q["donor_tags"] == "": pass else: raise @@ -56,10 +56,8 @@ def get_experiment_containers(self): container_list.append(c) return container_list - @cacheable() def get_ophys_experiments(self): - query_result = lu.query(experiment_pre_release_query) experiment_list = [] for q in query_result: @@ -70,23 +68,26 @@ def get_ophys_experiments(self): # print(key, val) # raise - c['id'] = q['o_id'] - c['imaging_depth'] = q['depth'] - c['targeted_structure']['acronym'] = q['acronym'] - c['specimen']['donor'] = collections.defaultdict(collections.defaultdict) - c['specimen']['donor']['external_donor_name'] = q['acronym'] - c['specimen']['donor']['transgenic_lines'] = [collections.defaultdict(collections.defaultdict), collections.defaultdict(collections.defaultdict)] - c['specimen']['donor']['transgenic_lines'][0]['transgenic_line_type_name'] = 'driver' - c['specimen']['donor']['transgenic_lines'][0]['name'] = q['driver'] - c['specimen']['donor']['transgenic_lines'][1]['transgenic_line_type_name'] = 'reporter' - c['specimen']['donor']['transgenic_lines'][1]['name'] = q['reporter'] - c['date_of_acquisition'] = q['date_of_acquisition'] - c['specimen']['donor']['date_of_birth'] = q['date_of_birth'] - c['experiment_container_id'] = q['ec_id'] - c['stimulus_name'] = q['stimulus_name'] - c['specimen']['donor']['external_donor_name'] = q['external_donor_name'] - c['specimen']['name'] = q['specimen'] - c['fail_eye_tracking'] = q['fail_eye_tracking'] + c["id"] = q["o_id"] + c["imaging_depth"] = q["depth"] + c["targeted_structure"]["acronym"] = q["acronym"] + c["specimen"]["donor"] = collections.defaultdict(collections.defaultdict) + c["specimen"]["donor"]["external_donor_name"] = q["acronym"] + c["specimen"]["donor"]["transgenic_lines"] = [ + collections.defaultdict(collections.defaultdict), + collections.defaultdict(collections.defaultdict), + ] + c["specimen"]["donor"]["transgenic_lines"][0]["transgenic_line_type_name"] = "driver" + c["specimen"]["donor"]["transgenic_lines"][0]["name"] = q["driver"] + c["specimen"]["donor"]["transgenic_lines"][1]["transgenic_line_type_name"] = "reporter" + c["specimen"]["donor"]["transgenic_lines"][1]["name"] = q["reporter"] + c["date_of_acquisition"] = q["date_of_acquisition"] + c["specimen"]["donor"]["date_of_birth"] = q["date_of_birth"] + c["experiment_container_id"] = q["ec_id"] + c["stimulus_name"] = q["stimulus_name"] + c["specimen"]["donor"]["external_donor_name"] = q["external_donor_name"] + c["specimen"]["name"] = q["specimen"] + c["fail_eye_tracking"] = q["fail_eye_tracking"] experiment_list.append(c) return experiment_list @@ -96,58 +97,114 @@ def get_cell_metrics(self): query_result = lu.query(cell_specimens_pre_release_query) mappings = self.get_stimulus_mappings() - thumbnails = [m['item'] for m in mappings if m['item_type'] == 'T' and m['level'] == 'R'] + thumbnails = [m["item"] for m in mappings if m["item_type"] == "T" and m["level"] == "R"] cell_list = [] for q in query_result: c = collections.defaultdict(collections.defaultdict) - c['all_stim'] = q['a_valid'] and q['b_valid'] and q['c_valid'] - for key in ['cell_specimen_id', 'area', 'donor_full_genotype', 'experiment_container_id', 'imaging_depth', 'specimen_id', 'tld1_id', - 'tld1_name', 'tld2_id', 'tld2_name', 'tlr1_id', 'tlr1_name']: + c["all_stim"] = q["a_valid"] and q["b_valid"] and q["c_valid"] + for key in [ + "cell_specimen_id", + "area", + "donor_full_genotype", + "experiment_container_id", + "imaging_depth", + "specimen_id", + "tld1_id", + "tld1_name", + "tld2_id", + "tld2_name", + "tlr1_id", + "tlr1_name", + ]: c[key] = q[key] - if q['failed_experiment_container'] == 't': - c['failed_experiment_container'] = True - elif q['failed_experiment_container'] == 'f': - c['failed_experiment_container'] = False + if q["failed_experiment_container"] == "t": + c["failed_experiment_container"] = True + elif q["failed_experiment_container"] == "f": + c["failed_experiment_container"] = False else: - raise RuntimeError('Unexpected value: {} not in ("t", "f")'.format(q['failed_experiment_container'])) - + raise RuntimeError('Unexpected value: {} not in ("t", "f")'.format(q["failed_experiment_container"])) # Session A metrics: - for key in ['dsi_dg', 'g_dsi_dg', 'g_osi_dg', 'osi_dg', 'p_dg', 'p_run_mod_dg', 'peak_dff_dg', 'pref_dir_dg', 'pref_tf_dg', 'reliability_dg', - 'reliability_nm3', 'run_mod_dg', 'tfdi_dg', 'tfdi_dg']: - if q['crara_data'] is None: + for key in [ + "dsi_dg", + "g_dsi_dg", + "g_osi_dg", + "osi_dg", + "p_dg", + "p_run_mod_dg", + "peak_dff_dg", + "pref_dir_dg", + "pref_tf_dg", + "reliability_dg", + "reliability_nm3", + "run_mod_dg", + "tfdi_dg", + "tfdi_dg", + ]: + if q["crara_data"] is None: c[key] = None else: - c[key] = q['crara_data']['roi_cell_metrics'].get(key,None) + c[key] = q["crara_data"]["roi_cell_metrics"].get(key, None) # Session B metrics: - for key in ['g_osi_sg', 'image_sel_ns', 'osi_sg', 'p_ns', 'p_run_mod_ns', 'p_run_mod_sg', 'p_sg', 'peak_dff_ns', 'peak_dff_sg', 'pref_image_ns', - 'pref_ori_sg', 'pref_phase_sg', 'pref_sf_sg', 'pref_image_ns', 'pref_ori_sg', 'reliability_ns', 'reliability_sg', - 'run_mod_ns','run_mod_sg','sfdi_sg', 'time_to_peak_ns', 'time_to_peak_sg']: - - if q['crarb_data'] is None: + for key in [ + "g_osi_sg", + "image_sel_ns", + "osi_sg", + "p_ns", + "p_run_mod_ns", + "p_run_mod_sg", + "p_sg", + "peak_dff_ns", + "peak_dff_sg", + "pref_image_ns", + "pref_ori_sg", + "pref_phase_sg", + "pref_sf_sg", + "pref_image_ns", + "pref_ori_sg", + "reliability_ns", + "reliability_sg", + "run_mod_ns", + "run_mod_sg", + "sfdi_sg", + "time_to_peak_ns", + "time_to_peak_sg", + ]: + if q["crarb_data"] is None: c[key] = None else: - c[key] = q['crarb_data']['roi_cell_metrics'].get(key,None) + c[key] = q["crarb_data"]["roi_cell_metrics"].get(key, None) # Session C metrics: - for key in ['reliability_nm2', 'rf_area_off_lsn', 'rf_area_on_lsn', 'rf_center_off_x_lsn', 'rf_center_off_y_lsn', - 'rf_center_on_x_lsn', 'rf_center_on_y_lsn', 'rf_chi2_lsn', 'rf_distance_lsn', 'rf_overlap_index_lsn', - ]: - if q['crarc_data'] is None: + for key in [ + "reliability_nm2", + "rf_area_off_lsn", + "rf_area_on_lsn", + "rf_center_off_x_lsn", + "rf_center_off_y_lsn", + "rf_center_on_x_lsn", + "rf_center_on_y_lsn", + "rf_chi2_lsn", + "rf_distance_lsn", + "rf_overlap_index_lsn", + ]: + if q["crarc_data"] is None: c[key] = None else: - c[key] = q['crarc_data']['roi_cell_metrics'].get(key,None) - - for suffix in ['a', 'b', 'c']: - if q['crar%s_data' % suffix] is not None: - c['reliability_nm1_%s' % suffix] = q['crar%s_data' % suffix]['roi_cell_metrics'].get('reliability_nm1',None) + c[key] = q["crarc_data"]["roi_cell_metrics"].get(key, None) + + for suffix in ["a", "b", "c"]: + if q["crar%s_data" % suffix] is not None: + c["reliability_nm1_%s" % suffix] = q["crar%s_data" % suffix]["roi_cell_metrics"].get( + "reliability_nm1", None + ) else: - c['reliability_nm1_%s' % suffix] = None - + c["reliability_nm1_%s" % suffix] = None + # Fake in thumbnail images: for t in thumbnails: c[t] = None diff --git a/allensdk/internal/api/queries/utils.py b/allensdk/internal/api/queries/utils.py index bb552c648a..9d674ae32f 100644 --- a/allensdk/internal/api/queries/utils.py +++ b/allensdk/internal/api/queries/utils.py @@ -4,10 +4,8 @@ def build_in_list_selector_query( - col: str, - valid_list: Optional[List[SupportsStr]] = None, - operator: str = "WHERE", - valid: bool = True) -> str: + col: str, valid_list: Optional[List[SupportsStr]] = None, operator: str = "WHERE", valid: bool = True +) -> str: """ Filter for rows where the value of a column is contained in a list (or, if valid=False, where the value is not contained in that list). @@ -34,34 +32,30 @@ def build_in_list_selector_query( The clause performing the request filter """ if operator not in ("AND", "OR", "WHERE"): - msg = ("Operator must be 'AND', 'OR', or 'WHERE'; " - f"you gave '{operator}'") + msg = f"Operator must be 'AND', 'OR', or 'WHERE'; you gave '{operator}'" raise ValueError(msg) if not valid_list: return "" if type(valid_list[0]) is str: - valid_list = _convert_list_of_string_to_sql_safe_string( - strings=valid_list) + valid_list = _convert_list_of_string_to_sql_safe_string(strings=valid_list) if valid: relation = "IN" else: relation = "NOT IN" - session_query = ( - f"""{operator} {col} {relation} ({",".join( - sorted(set(map(str, valid_list))))})""") + session_query = f"""{operator} {col} {relation} ({",".join(sorted(set(map(str, valid_list))))})""" return session_query def build_where_clause(clauses: List[str]): if not clauses: - return '' - where_clause = ' AND '.join(clauses) - if not where_clause[:5].lower() == 'where': - where_clause = f'WHERE {where_clause}' + return "" + where_clause = " AND ".join(clauses) + if not where_clause[:5].lower() == "where": + where_clause = f"WHERE {where_clause}" return where_clause @@ -98,9 +92,7 @@ def _sanitize_uuid_list(uuid_list: List[str]) -> List[str]: return sanitized_list -def _convert_list_of_string_to_sql_safe_string( - strings: List[str] -) -> List[str]: +def _convert_list_of_string_to_sql_safe_string(strings: List[str]) -> List[str]: """ Given list of string ["A", "B"] converts to ["'A'", "'B'"] diff --git a/allensdk/internal/api/queries/wkf_lims_queries.py b/allensdk/internal/api/queries/wkf_lims_queries.py index 3eb2395910..98f4d5f813 100644 --- a/allensdk/internal/api/queries/wkf_lims_queries.py +++ b/allensdk/internal/api/queries/wkf_lims_queries.py @@ -1,15 +1,12 @@ from typing import List, Dict from allensdk.internal.api import PostgresQueryMixin -from allensdk.internal.api.queries.utils import ( - build_in_list_selector_query) +from allensdk.internal.api.queries.utils import build_in_list_selector_query from allensdk import OneResultExpectedError def wkf_path_from_attachable( - lims_connection: PostgresQueryMixin, - wkf_type_name: List[str], - attachable_type: str, - attachable_id: int) -> Dict[str, str]: + lims_connection: PostgresQueryMixin, wkf_type_name: List[str], attachable_type: str, attachable_id: int +) -> Dict[str, str]: """ Get the path to well known files, selecting files of a specific type with a specified attachable ID and attachable_type. @@ -50,23 +47,25 @@ def wkf_path_from_attachable( wkf.well_known_file_type_id=wkft.id """ - query += build_in_list_selector_query( - col="wkft.name", - valid_list=wkf_type_name, - operator="WHERE", - valid=True) + query += build_in_list_selector_query(col="wkft.name", valid_list=wkf_type_name, operator="WHERE", valid=True) query += build_in_list_selector_query( - col="wkf.attachable_type", - valid_list=[f"'{attachable_type}'", ], - operator="AND", - valid=True) + col="wkf.attachable_type", + valid_list=[ + f"'{attachable_type}'", + ], + operator="AND", + valid=True, + ) query += build_in_list_selector_query( - col="wkf.attachable_id", - valid_list=[f"'{attachable_id}'", ], - operator="AND", - valid=True) + col="wkf.attachable_id", + valid_list=[ + f"'{attachable_id}'", + ], + operator="AND", + valid=True, + ) query_result = lims_connection.select(query) @@ -75,11 +74,9 @@ def wkf_path_from_attachable( if len(query_result) == 0: return wkf_path_lookup - for type_name, filepath in zip(query_result.type_name, - query_result.filepath): + for type_name, filepath in zip(query_result.type_name, query_result.filepath): if type_name in wkf_path_lookup: - raise OneResultExpectedError( - f"More than one result returned for {type_name}") + raise OneResultExpectedError(f"More than one result returned for {type_name}") wkf_path_lookup[type_name] = filepath return wkf_path_lookup diff --git a/allensdk/internal/brain_observatory/annotated_region_metrics.py b/allensdk/internal/brain_observatory/annotated_region_metrics.py index b0b010d981..6423fe36c9 100644 --- a/allensdk/internal/brain_observatory/annotated_region_metrics.py +++ b/allensdk/internal/brain_observatory/annotated_region_metrics.py @@ -1,4 +1,5 @@ """Module for calculating annotated region metrics from ISI data""" + import numpy as np # These scaling factors are derived from experimental geometry using @@ -10,7 +11,7 @@ def eccentricity(az, alt, az_center, alt_center): """Compute eccentricity - + Parameters ---------- az : numpy.ndarray @@ -29,14 +30,13 @@ def eccentricity(az, alt, az_center, alt_center): """ daz = az - az_center dalt = alt - alt_center - ecc = np.arctan(np.sqrt(np.square(np.tan(dalt)) + - np.square(np.tan(daz))/np.square(np.cos(dalt)))) + ecc = np.arctan(np.sqrt(np.square(np.tan(dalt)) + np.square(np.tan(daz)) / np.square(np.cos(dalt)))) return ecc def retinotopy_metric(mask, isi_map): """Compute retinotopic metrics for a responding area - + Parameters ---------- mask : numpy.ndarray @@ -49,7 +49,7 @@ def retinotopy_metric(mask, isi_map): (float, float, float, float) tuple min, max, range, bias of retinotopic map over masked region """ - ind = np.where( mask > 0 ) + ind = np.where(mask > 0) vals = isi_map[ind] maxv = np.degrees(np.max(vals)) minv = np.degrees(np.min(vals)) @@ -75,56 +75,60 @@ def create_region_mask(image_shape, x, y, width, height, mask): height of region mask mask : list region mask as a list of lists - + Returns ------- numpy.ndarray Region mask """ - bb = np.zeros((height,width), dtype=np.uint8) + bb = np.zeros((height, width), dtype=np.uint8) bb[np.asarray(mask)] = 1 region_mask = np.zeros(image_shape, dtype=np.uint8) - region_mask[y:y + height,x:x + width] = bb + region_mask[y : y + height, x : x + width] = bb return region_mask -def get_metrics(altitude_phase, azimuth_phase, x=None, y=None, width=None, - height=None, mask=None, altitude_scale=ALTITUDE_SCALE, - azimuth_scale=AZIMUTH_SCALE): +def get_metrics( + altitude_phase, + azimuth_phase, + x=None, + y=None, + width=None, + height=None, + mask=None, + altitude_scale=ALTITUDE_SCALE, + azimuth_scale=AZIMUTH_SCALE, +): """Calculate annotated region metrics""" altitude = altitude_phase * altitude_scale azimuth = azimuth_phase * azimuth_scale - eccentricity_ret_zero = np.degrees( - eccentricity(azimuth, altitude, 0.0, 0.0)) + eccentricity_ret_zero = np.degrees(eccentricity(azimuth, altitude, 0.0, 0.0)) result = {} - region_mask = create_region_mask(altitude.shape, x, y, width, height, - mask) + region_mask = create_region_mask(altitude.shape, x, y, width, height, mask) # compute centroid centroid = [np.mean(x_value) for x_value in np.where(region_mask)] - result['y_centroid'] = centroid[0] - result['x_centroid'] = centroid[1] + result["y_centroid"] = centroid[0] + result["x_centroid"] = centroid[1] # compute azimuth/altitude max,min,range and bias - az_min, az_max, az_range, az_bias = retinotopy_metric(region_mask, - azimuth) - alt_min, alt_max, alt_range, alt_bias = retinotopy_metric(region_mask, - altitude) - result['azimuth_min'] = az_min - result['azimuth_max'] = az_max - result['azimuth_range'] = az_range - result['azimuth_bias'] = az_bias - result['altitude_min'] = alt_min - result['altitude_max'] = alt_max - result['altitude_range'] = alt_range - result['altitude_bias'] = alt_bias + az_min, az_max, az_range, az_bias = retinotopy_metric(region_mask, azimuth) + alt_min, alt_max, alt_range, alt_bias = retinotopy_metric(region_mask, altitude) + result["azimuth_min"] = az_min + result["azimuth_max"] = az_max + result["azimuth_range"] = az_range + result["azimuth_bias"] = az_bias + result["altitude_min"] = alt_min + result["altitude_max"] = alt_max + result["altitude_range"] = alt_range + result["altitude_bias"] = alt_bias # eccentricity at centroid - cy = int(round(result['y_centroid'])) - cx = int(round(result['x_centroid'])) - result['eccentricity_at_centroid'] = float(eccentricity_ret_zero[cy,cx]) + cy = int(round(result["y_centroid"])) + cx = int(round(result["x_centroid"])) + result["eccentricity_at_centroid"] = float(eccentricity_ret_zero[cy, cx]) return result diff --git a/allensdk/internal/brain_observatory/demix_report.py b/allensdk/internal/brain_observatory/demix_report.py index 9afc45cd0a..6c9bfcf67f 100644 --- a/allensdk/internal/brain_observatory/demix_report.py +++ b/allensdk/internal/brain_observatory/demix_report.py @@ -2,83 +2,85 @@ import numpy as np import h5py -#import matplotlib -#matplotlib.use('agg') +# import matplotlib +# matplotlib.use('agg') import matplotlib.pyplot as plt import os -def background_trace(trace, save_dir, data_set=None): - fig,ax = plt.subplots(1) +def background_trace(trace, save_dir, data_set=None): + fig, ax = plt.subplots(1) ax.plot(trace) if data_set is not None: _add_stim_epochs(trace, ax, data_set) - save_file = os.path.join(save_dir, 'background_trace.pdf') + save_file = os.path.join(save_dir, "background_trace.pdf") fig.savefig(save_file) logging.info("Background Trace saved to %s", save_file) plt.close(fig) + def correlation_report(dm, save_dir, without_masks=True): - ''' - parameters: - dm: [DeMix object] - without_masks: boolean - ''' + """ + parameters: + dm: [DeMix object] + without_masks: boolean + """ logging.info("Generating Correlation Report") if without_masks: cor, cor_demix = compute_correlations_without_masks(dm) - no_diagonal_mask = (1.0 - np.eye(cor.shape[0])).astype('bool') + no_diagonal_mask = (1.0 - np.eye(cor.shape[0])).astype("bool") fig, ax = plt.subplots(1) - ax.plot(dm.mask_overlap[:-1,:-1][no_diagonal_mask],cor[no_diagonal_mask]-cor_demix[no_diagonal_mask],'o') - ax.set_xlim([-1,np.max(dm.mask_overlap[:-1,:-1][no_diagonal_mask])]) - ax.set_title('Delta Correlation vs. mask overlap') + ax.plot(dm.mask_overlap[:-1, :-1][no_diagonal_mask], cor[no_diagonal_mask] - cor_demix[no_diagonal_mask], "o") + ax.set_xlim([-1, np.max(dm.mask_overlap[:-1, :-1][no_diagonal_mask])]) + ax.set_title("Delta Correlation vs. mask overlap") - save_file = os.path.join(save_dir,'cor_vs_overlap.pdf') + save_file = os.path.join(save_dir, "cor_vs_overlap.pdf") fig.savefig(save_file) logging.info("\tCorrelation overlap saved to %s", save_file) plt.close(fig) - fig, ax = plt.subplots(1,3) - delta_cor = cor- cor_demix - ax[0].hist(delta_cor[dm.mask_overlap[:-1,:-1]==0],bins=100) - ax[0].set_title('Delta Correlation') - ax[1].hist(cor[no_diagonal_mask],bins=100) - ax[1].set_title('Pre-demix Correlation') - ax[2].hist(cor_demix[no_diagonal_mask],bins=100) - ax[2].set_title('Post-demix Correlation') + fig, ax = plt.subplots(1, 3) + delta_cor = cor - cor_demix + ax[0].hist(delta_cor[dm.mask_overlap[:-1, :-1] == 0], bins=100) + ax[0].set_title("Delta Correlation") + ax[1].hist(cor[no_diagonal_mask], bins=100) + ax[1].set_title("Pre-demix Correlation") + ax[2].hist(cor_demix[no_diagonal_mask], bins=100) + ax[2].set_title("Post-demix Correlation") - save_file = os.path.join(save_dir,'cor_hist.pdf') + save_file = os.path.join(save_dir, "cor_hist.pdf") fig.savefig(save_file) logging.info("\tCorrelation histograms saved to %s", save_file) plt.close(fig) - else: - raise Exception('without_masks=False not yet implemented') + raise Exception("without_masks=False not yet implemented") -def plot_masks(dm, save_dir, movie_file, movie_dataset, window=150, add_background=True): +def plot_masks(dm, save_dir, movie_file, movie_dataset, window=150, add_background=True): logging.info("Plotting masks") - overlap_pairs = [(x,y) for (x,y) in zip(*np.where(dm.mask_overlap >0)) if x>y and x!=dm.mask_overlap.shape[0]-1] - movie_data = h5py.File(movie_file,'r') + overlap_pairs = [ + (x, y) for (x, y) in zip(*np.where(dm.mask_overlap > 0)) if x > y and x != dm.mask_overlap.shape[0] - 1 + ] + movie_data = h5py.File(movie_file, "r") bg_traces = dm.get_traces_with_background() - for i,pair in enumerate(overlap_pairs): - fig_pair, ax_pair = plt.subplots(2,2) - + for i, pair in enumerate(overlap_pairs): + fig_pair, ax_pair = plt.subplots(2, 2) + mask = np.zeros(dm.masks[0].shape) - rgb_shape = (dm.masks[0].shape[0],dm.masks[0].shape[1],3) - rgb_mask = np.zeros(rgb_shape,dtype=np.uint8) - #for p in pair: - rgb_mask[:,:,2] = 255*dm.masks[pair[0]] - rgb_mask[:,:,1] = 255*dm.masks[pair[1]] + rgb_shape = (dm.masks[0].shape[0], dm.masks[0].shape[1], 3) + rgb_mask = np.zeros(rgb_shape, dtype=np.uint8) + # for p in pair: + rgb_mask[:, :, 2] = 255 * dm.masks[pair[0]] + rgb_mask[:, :, 1] = 255 * dm.masks[pair[1]] for p in pair: mask += dm.masks[p] non_zeros = np.where(mask) @@ -87,8 +89,8 @@ def plot_masks(dm, save_dir, movie_file, movie_dataset, window=150, add_backgrou xlower = np.min(non_zeros[1]) xupper = np.max(non_zeros[1]) - #ax_pair[0,0].imshow(mask[ylower:yupper,xlower:xupper]) - ax_pair[0,0].imshow(rgb_mask[ylower:yupper,xlower:xupper]) + # ax_pair[0,0].imshow(mask[ylower:yupper,xlower:xupper]) + ax_pair[0, 0].imshow(rgb_mask[ylower:yupper, xlower:xupper]) trace1 = dm.traces[pair[0]] trace2 = dm.traces[pair[1]] @@ -100,86 +102,91 @@ def plot_masks(dm, save_dir, movie_file, movie_dataset, window=150, add_backgrou trace1_demix = dm.traces_demix[pair[0]] trace2_demix = dm.traces_demix[pair[1]] - center = np.where(trace1==np.max(trace1))[0][0] + center = np.where(trace1 == np.max(trace1))[0][0] - ax_pair[1,0].plot(trace1[(center-window):(center+window)],label=str(pair[0])) - ax_pair[1,0].plot(trace2[(center-window):(center+window)],label=str(pair[1])) - ax_pair[1,0].legend() + ax_pair[1, 0].plot(trace1[(center - window) : (center + window)], label=str(pair[0])) + ax_pair[1, 0].plot(trace2[(center - window) : (center + window)], label=str(pair[1])) + ax_pair[1, 0].legend() - ax_pair[1,1].plot(trace1_demix[center-window:center+window],label=str(pair[0])) - ax_pair[1,1].plot(trace2_demix[center-window:center+window],label=str(pair[1])) - ax_pair[1,1].legend() + ax_pair[1, 1].plot(trace1_demix[center - window : center + window], label=str(pair[0])) + ax_pair[1, 1].plot(trace2_demix[center - window : center + window], label=str(pair[1])) + ax_pair[1, 1].legend() - ax_pair[0,1].imshow(movie_data[movie_dataset][center,ylower:yupper,xlower:xupper]) + ax_pair[0, 1].imshow(movie_data[movie_dataset][center, ylower:yupper, xlower:xupper]) - save_file = os.path.join(save_dir,'masks_'+str(pair[0])+'_'+str(pair[1])+'.pdf') + save_file = os.path.join(save_dir, "masks_" + str(pair[0]) + "_" + str(pair[1]) + ".pdf") fig_pair.savefig(save_file) plt.close(fig_pair) logging.info("\tMask saved to %s", save_file) - #print(overlap_pairs) - #print(np.unique(dm.mask_overlap[no_diagonal_mask][dm.mask_overlap[no_diagonal_mask]>0])) + # print(overlap_pairs) + # print(np.unique(dm.mask_overlap[no_diagonal_mask][dm.mask_overlap[no_diagonal_mask]>0])) def _get_epoch_windows(stim_table): - start = np.array(stim_table.start) end = np.array(stim_table.end) - windows = zip(start,end) + windows = zip(start, end) window_list = [[start[0]]] - for i,w in enumerate(windows[1:]): - if start[i+1] - end[i]>1: + for i, w in enumerate(windows[1:]): + if start[i + 1] - end[i] > 1: window_list[-1].append(end[i]) - window_list.append([start[i+1]]) - #window_list += [end[i],start[i+1]] + window_list.append([start[i + 1]]) + # window_list += [end[i],start[i+1]] window_list[-1].append(end[-1]) - #window_list = [start[0]] - #window_list += [ start[x+1] for x in list(np.where(np.abs(start[1:] - end[:-1]) > 1)[0])] - #window_list += [end[-1]] + # window_list = [start[0]] + # window_list += [ start[x+1] for x in list(np.where(np.abs(start[1:] - end[:-1]) > 1)[0])] + # window_list += [end[-1]] - #print(window_list) + # print(window_list) return window_list -def _add_stim_epochs(trace,ax,data_set): - stim_colors_dict = {'locally_sparse_noise':'green','drifting_gratings':'yellow','natural_movie_one':'magenta','natural_movie_two':'magenta','natural_movie_three':'red','natural_scenes':'orange','spontaneous':'grey','static_gratings':'blue'} +def _add_stim_epochs(trace, ax, data_set): + stim_colors_dict = { + "locally_sparse_noise": "green", + "drifting_gratings": "yellow", + "natural_movie_one": "magenta", + "natural_movie_two": "magenta", + "natural_movie_three": "red", + "natural_scenes": "orange", + "spontaneous": "grey", + "static_gratings": "blue", + } from allensdk.brain_observatory.stimulus_info import stimuli_in_session - stim_types = stimuli_in_session(data_set.get_metadata()['session_type']) - + stim_types = stimuli_in_session(data_set.get_metadata()["session_type"]) for stim in stim_types: - #print(stim) + # print(stim) stim_table = data_set.get_stimulus_table(stim) window_list = _get_epoch_windows(stim_table) for w in window_list: - #ax.fill_betweenx(np.arange(trace.shape[0]),w[0],w[1],facecolor=stim_colors_dict[stim],alpha=0.2) - #ax.axvspan(w[0],w[1],np.min(trace),np.max(trace),facecolor=stim_colors_dict[stim],alpha=0.2) - ax.axvspan(w[0],w[1],facecolor=stim_colors_dict[stim],alpha=0.2) - ax.set_ylim([np.min(trace),np.max(trace)]) + # ax.fill_betweenx(np.arange(trace.shape[0]),w[0],w[1],facecolor=stim_colors_dict[stim],alpha=0.2) + # ax.axvspan(w[0],w[1],np.min(trace),np.max(trace),facecolor=stim_colors_dict[stim],alpha=0.2) + ax.axvspan(w[0], w[1], facecolor=stim_colors_dict[stim], alpha=0.2) + ax.set_ylim([np.min(trace), np.max(trace)]) def compute_non_overlap_masks(dm): - no_masks = np.zeros(dm.masks.shape).astype(int) overlap_val = np.zeros(dm.masks.shape[0]) for i, m in enumerate(dm.masks): - - overlap_1 = np.sum(dm.masks[:i],axis=0) - overlap_2 = np.sum(dm.masks[i+1:],axis=0) + overlap_1 = np.sum(dm.masks[:i], axis=0) + overlap_2 = np.sum(dm.masks[i + 1 :], axis=0) overlap = overlap_1 + overlap_2 overlap_val[i] = np.sum(overlap) no1 = overlap == 0 - #no2 = overlap_2 == 0 + # no2 = overlap_2 == 0 no_masks[i] = np.logical_and(no1, m) @@ -188,17 +195,18 @@ def compute_non_overlap_masks(dm): return dm.no_masks + def compute_non_overlap_traces(dm, movie_path, movie_dataset): - no_traces_shape = (dm.traces.shape[0],dm.traces.shape[1]) + no_traces_shape = (dm.traces.shape[0], dm.traces.shape[1]) no_traces = np.zeros(no_traces_shape) N, T = no_traces.shape chunk_size = 1000 - num_chunks = int(np.ceil(T/float(chunk_size))) + num_chunks = int(np.ceil(T / float(chunk_size))) - normalized_flat_masks = dm.no_masks.reshape(N,-1).T # shape (pixels, N) - normalized_flat_masks /= np.sum(normalized_flat_masks,axis=0) # shape(pixels, N) + normalized_flat_masks = dm.no_masks.reshape(N, -1).T # shape (pixels, N) + normalized_flat_masks /= np.sum(normalized_flat_masks, axis=0) # shape(pixels, N) movie_f = h5py.File(movie_path) movie = movie_f[movie_dataset] @@ -206,10 +214,10 @@ def compute_non_overlap_traces(dm, movie_path, movie_dataset): logging.debug("Getting traces for %d chunks", num_chunks) for n in range(num_chunks): print("Chunk = ", n) - data = movie[n*chunk_size:(n+1)*chunk_size] - data = data.reshape(chunk_size,-1) # This line throws an error + data = movie[n * chunk_size : (n + 1) * chunk_size] + data = data.reshape(chunk_size, -1) # This line throws an error - no_traces[:,n*chunk_size:(n+1)*chunk_size] = np.dot(data,normalized_flat_masks).T + no_traces[:, n * chunk_size : (n + 1) * chunk_size] = np.dot(data, normalized_flat_masks).T movie_f.close() @@ -218,8 +226,8 @@ def compute_non_overlap_traces(dm, movie_path, movie_dataset): return dm.no_traces -def compute_correlations(dm, movie_path, movie_dataset): +def compute_correlations(dm, movie_path, movie_dataset): compute_non_overlap_masks(dm) compute_non_overlap_traces(dm, movie_path, movie_dataset) @@ -227,23 +235,23 @@ def compute_correlations(dm, movie_path, movie_dataset): dm_mean = np.mean(dm.traces_demix) t_mean = np.mean(dm.traces) - C_no_dm = np.mean( (dm.no_traces-no_mean)*(dm.traces_demix-dm_mean), axis=1) - C_t_dm = np.mean( (dm.traces-t_mean)*(dm.traces_demix-dm_mean), axis=1) + C_no_dm = np.mean((dm.no_traces - no_mean) * (dm.traces_demix - dm_mean), axis=1) + C_t_dm = np.mean((dm.traces - t_mean) * (dm.traces_demix - dm_mean), axis=1) return C_no_dm, C_t_dm + def compute_correlations_without_masks(dm): N, T = dm.traces.shape - N=N -1 - - traces = (dm.traces.T - np.mean(dm.traces.T,axis=0)) # shape (T,N) - traces_demix = (dm.traces_demix.T - np.mean(dm.traces_demix.T,axis=0)) # shape (T,N) + N = N - 1 - traces /= np.std(traces,axis=0) - traces_demix /= np.std(traces_demix,axis=0) + traces = dm.traces.T - np.mean(dm.traces.T, axis=0) # shape (T,N) + traces_demix = dm.traces_demix.T - np.mean(dm.traces_demix.T, axis=0) # shape (T,N) - cor = np.dot(traces.T,traces)/T - cor_demix = np.dot(traces_demix.T,traces_demix)/T + traces /= np.std(traces, axis=0) + traces_demix /= np.std(traces_demix, axis=0) - return cor[:N,:N], cor_demix[:N,:N] + cor = np.dot(traces.T, traces) / T + cor_demix = np.dot(traces_demix.T, traces_demix) / T + return cor[:N, :N], cor_demix[:N, :N] diff --git a/allensdk/internal/brain_observatory/demixer.py b/allensdk/internal/brain_observatory/demixer.py index 7984d95601..162f202064 100644 --- a/allensdk/internal/brain_observatory/demixer.py +++ b/allensdk/internal/brain_observatory/demixer.py @@ -10,9 +10,11 @@ from allensdk.deprecated import deprecated -@deprecated("The internal demixer module is deprecated and will be removed. " - "Please use allensdk.brain_observatory.demixer." - "identify_valid_masks instead.") +@deprecated( + "The internal demixer module is deprecated and will be removed. " + "Please use allensdk.brain_observatory.demixer." + "identify_valid_masks instead." +) def identify_valid_masks(mask_array): ms = mask_set.MaskSet(masks=mask_array.astype(bool)) valid_masks = np.ones(mask_array.shape[0]).astype(bool) @@ -21,7 +23,7 @@ def identify_valid_masks(mask_array): duplicates = ms.detect_duplicates(overlap_threshold=0.9) if len(duplicates) > 0: valid_masks[duplicates.keys()] = False - + # detect unions, only for remaining valid masks valid_idxs = np.where(valid_masks) ms = mask_set.MaskSet(masks=mask_array[valid_idxs].astype(bool)) @@ -34,17 +36,19 @@ def identify_valid_masks(mask_array): return valid_masks -@deprecated("The internal demixer module is deprecated and will be removed. " - "Please use allensdk.brain_observatory.demixer." - "demix_time_dep_masks instead.") +@deprecated( + "The internal demixer module is deprecated and will be removed. " + "Please use allensdk.brain_observatory.demixer." + "demix_time_dep_masks instead." +) def demix_time_dep_masks(raw_traces, stack, masks): - ''' + """ :param raw_traces: extracted traces :param stack: movie (same length as traces) :param masks: binary roi masks :return: demixed traces - ''' + """ N, T = raw_traces.shape _, x, y = masks.shape P = x * y @@ -63,9 +67,8 @@ def demix_time_dep_masks(raw_traces, stack, masks): demix_traces = np.zeros((N, T)) for t in range(T): - weighted_mask_sum = F[:, t] - drop_test = (weighted_mask_sum == 0) + drop_test = weighted_mask_sum == 0 if np.sum(drop_test == 0): norm_mat = sparse.diags(num_pixels_in_mask / weighted_mask_sum, offsets=0) @@ -73,7 +76,9 @@ def demix_time_dep_masks(raw_traces, stack, masks): flat_weighted_masks = norm_mat.dot(flat_masks.dot(stack_t)) - overlap = flat_masks.dot(flat_weighted_masks.T).toarray() # cast to dense numpy array for linear solver because solution is dense + overlap = flat_masks.dot( + flat_weighted_masks.T + ).toarray() # cast to dense numpy array for linear solver because solution is dense try: demix_traces[:, t] = linalg.solve(overlap, F[:, t]) except linalg.LinAlgError: @@ -89,49 +94,55 @@ def demix_time_dep_masks(raw_traces, stack, masks): return demix_traces, drop_frames -@deprecated("The internal demixer module is deprecated and will be removed. " - "Please use allensdk.brain_observatory.demixer." - "plot_traces instead.") +@deprecated( + "The internal demixer module is deprecated and will be removed. " + "Please use allensdk.brain_observatory.demixer." + "plot_traces instead." +) def plot_traces(raw_trace, demix_trace, roi_id, roi_ind, save_file): fig, ax = plt.subplots() - ax.plot(raw_trace, label='Fluoresence') - ax.plot(demix_trace, label='Demixed') + ax.plot(raw_trace, label="Fluoresence") + ax.plot(demix_trace, label="Demixed") ax.set_title("ROI ID(%d) index (%d)" % (roi_id, roi_ind)) ax.legend() plt.savefig(save_file) plt.close(fig) -@deprecated("The internal demixer module is deprecated and will be removed. " - "Please use allensdk.brain_observatory.demixer." - "find_zero_baselines instead.") +@deprecated( + "The internal demixer module is deprecated and will be removed. " + "Please use allensdk.brain_observatory.demixer." + "find_zero_baselines instead." +) def find_zero_baselines(traces): means = traces.mean(axis=1) - stds = traces.std(axis=1) - return np.where((means-stds) < 0) + stds = traces.std(axis=1) + return np.where((means - stds) < 0) -@deprecated("The internal demixer module is deprecated and will be removed. " - "Please use allensdk.brain_observatory.demixer." - "plot_negative_baselines instead.") -def plot_negative_baselines(raw_traces, demix_traces, mask_array, roi_ids_mask, plot_dir, ext='png'): +@deprecated( + "The internal demixer module is deprecated and will be removed. " + "Please use allensdk.brain_observatory.demixer." + "plot_negative_baselines instead." +) +def plot_negative_baselines(raw_traces, demix_traces, mask_array, roi_ids_mask, plot_dir, ext="png"): N, T = raw_traces.shape _, x, y = mask_array.shape logging.debug("finding negative baselines") neg_inds = find_negative_baselines(demix_traces)[0] - + overlap_inds = set() logging.debug("detected negative baselines: %s", str(neg_inds)) for roi_ind in neg_inds: Manifest.safe_mkdir(plot_dir) - save_file = os.path.join(plot_dir, str(roi_ids_mask[roi_ind]) + '_negative.' + ext) + save_file = os.path.join(plot_dir, str(roi_ids_mask[roi_ind]) + "_negative." + ext) plot_traces(raw_traces[roi_ind], demix_traces[roi_ind], roi_ids_mask[roi_ind], roi_ind, save_file) - ''' plot overlapping masks ''' - save_file = os.path.join(plot_dir, str(roi_ids_mask[roi_ind]) + '_negative_masks.' + ext) + """ plot overlapping masks """ + save_file = os.path.join(plot_dir, str(roi_ids_mask[roi_ind]) + "_negative_masks." + ext) roi_overlap_inds = plot_overlap_masks_lengthOne(roi_ind, mask_array, save_file) overlap_inds.update(roi_overlap_inds) @@ -143,11 +154,12 @@ def plot_negative_baselines(raw_traces, demix_traces, mask_array, roi_ids_mask, return list(overlap_inds) -@deprecated("The internal demixer module is deprecated and will be removed. " - "Please use allensdk.brain_observatory.demixer." - "plot_negative_transients instead.") -def plot_negative_transients(raw_traces, demix_traces, valid_roi, mask_array, roi_ids_mask, plot_dir, ext='png'): - +@deprecated( + "The internal demixer module is deprecated and will be removed. " + "Please use allensdk.brain_observatory.demixer." + "plot_negative_transients instead." +) +def plot_negative_transients(raw_traces, demix_traces, valid_roi, mask_array, roi_ids_mask, plot_dir, ext="png"): N, T = raw_traces.shape _, x, y = mask_array.shape @@ -159,24 +171,23 @@ def plot_negative_transients(raw_traces, demix_traces, valid_roi, mask_array, ro logging.debug("plotting negative transients") - flat_masks = mask_array.reshape(N, x*y) + flat_masks = mask_array.reshape(N, x * y) overlap = flat_masks.dot(flat_masks.T) overlap ^= np.diag(np.diag(overlap)) for roi_ind in rois_with_trans: - - ''' plot biggest negative transient of this roi ''' + """ plot biggest negative transient of this roi """ trans_ind_list = trans_ind_list1[roi_ind] trans_ind_list = trans_ind_list[0] trans_list = [] for i in trans_ind_list: if i > 100 and i < T - 100: - trans_list.append(demix_traces[roi_ind, i - 100:i + 100]) + trans_list.append(demix_traces[roi_ind, i - 100 : i + 100]) elif i > 100 and i >= T - 100: - trans_list.append(demix_traces[roi_ind, i - 100:]) + trans_list.append(demix_traces[roi_ind, i - 100 :]) else: - trans_list.append(demix_traces[roi_ind, :i + 100]) + trans_list.append(demix_traces[roi_ind, : i + 100]) # trans_list = [demix_traces[roi_ind, i-100:i+100] for i in trans_ind_list if i > 100 and i < Nt] Ntrans = len(trans_list) @@ -191,22 +202,20 @@ def plot_negative_transients(raw_traces, demix_traces, valid_roi, mask_array, ro # trans_list_min = np.where(demix_traces[roi_ind, trans_ind_list] == min(demix_traces[roi_ind, trans_ind_list]))[0] if np.sum(overlap[roi_ind]) > 0: - if valid_roi[roi_ind]: - - savefile = os.path.join(plot_dir, str(roi_ids_mask[roi_ind]) + '_transient_valid.' + ext) + savefile = os.path.join(plot_dir, str(roi_ids_mask[roi_ind]) + "_transient_valid." + ext) plot_transients(roi_ind, trans_ind, mask_array, raw_traces, demix_traces, savefile) - ''' plot overlapping masks ''' - savefile = os.path.join(plot_dir, str(roi_ids_mask[roi_ind]) + '_masks_valid.' + ext) + """ plot overlapping masks """ + savefile = os.path.join(plot_dir, str(roi_ids_mask[roi_ind]) + "_masks_valid." + ext) plot_overlap_masks_lengthOne(roi_ind, mask_array, savefile) # plot_overlap_masks(roi_ind, mask_test, savefile) else: - savefile = os.path.join(plot_dir, str(roi_ids_mask[roi_ind]) + '_transient_invalid.' + ext) + savefile = os.path.join(plot_dir, str(roi_ids_mask[roi_ind]) + "_transient_invalid." + ext) plot_transients(roi_ind, trans_ind, mask_array, raw_traces, demix_traces, savefile) - ''' plot overlapping masks ''' - savefile = os.path.join(plot_dir, str(roi_ids_mask[roi_ind]) + '_masks_invalid.' + ext) + """ plot overlapping masks """ + savefile = os.path.join(plot_dir, str(roi_ids_mask[roi_ind]) + "_masks_invalid." + ext) plot_overlap_masks_lengthOne(roi_ind, mask_array, savefile) # plot_overlap_masks(roi_ind, mask_test, savefile) # @@ -216,62 +225,69 @@ def plot_negative_transients(raw_traces, demix_traces, valid_roi, mask_array, ro return rois_with_trans -@deprecated("The internal demixer module is deprecated and will be removed. " - "Please use allensdk.brain_observatory.demixer." - "rolling_window instead.") +@deprecated( + "The internal demixer module is deprecated and will be removed. " + "Please use allensdk.brain_observatory.demixer." + "rolling_window instead." +) def rolling_window(trace, window=500): - ''' + """ :param trace: :param window: :return: - ''' + """ shape = trace.shape[:-1] + (trace.shape[-1] - window + 1, window) - strides = trace.strides + (trace.strides[-1], ) + strides = trace.strides + (trace.strides[-1],) return np.lib.stride_tricks.as_strided(trace, shape=shape, strides=strides) -@deprecated("The internal demixer module is deprecated and will be removed. " - "Please use allensdk.brain_observatory.demixer." - "find_negative_baselines instead.") +@deprecated( + "The internal demixer module is deprecated and will be removed. " + "Please use allensdk.brain_observatory.demixer." + "find_negative_baselines instead." +) def find_negative_baselines(trace): means = trace.mean(axis=1) stds = trace.std(axis=1) - return np.where((means+stds) < 0) + return np.where((means + stds) < 0) -@deprecated("The internal demixer module is deprecated and will be removed. " - "Please use allensdk.brain_observatory.demixer." - "find_negative_transients_threshold instead.") +@deprecated( + "The internal demixer module is deprecated and will be removed. " + "Please use allensdk.brain_observatory.demixer." + "find_negative_transients_threshold instead." +) def find_negative_transients_threshold(trace, window=500, length=10, std_devs=3): - trace = np.pad(trace, pad_width=(window-1, 0), mode='constant', constant_values=[np.mean(trace[:window])]) + trace = np.pad(trace, pad_width=(window - 1, 0), mode="constant", constant_values=[np.mean(trace[:window])]) rolling_mean = np.mean(rolling_window(trace, window), -1) rolling_std = np.std(rolling_window(trace, window), -1) - below_thresh = (trace[window-1:] < rolling_mean - std_devs*rolling_std) - below_thresh = np.pad(below_thresh, pad_width=(window-1, 0), mode='constant') + below_thresh = trace[window - 1 :] < rolling_mean - std_devs * rolling_std + below_thresh = np.pad(below_thresh, pad_width=(window - 1, 0), mode="constant") trans_length = np.sum(rolling_window(below_thresh, length), -1) - trans_length = trans_length[window-length:] + trans_length = trans_length[window - length :] trans_ind = np.where(trans_length == length) return trans_ind -@deprecated("The internal demixer module is deprecated and will be removed. " - "Please use allensdk.brain_observatory.demixer." - "plot_overlap_masks_lengthOne instead.") +@deprecated( + "The internal demixer module is deprecated and will be removed. " + "Please use allensdk.brain_observatory.demixer." + "plot_overlap_masks_lengthOne instead." +) def plot_overlap_masks_lengthOne(roi_ind, masks, savefile=None, weighted=False): - masks = np.array(masks).astype(float) N, x, y = masks.shape - if np.sum(masks[-1]) == x*y: + if np.sum(masks[-1]) == x * y: masks = masks[:-1] N -= 1 - flat_masks = masks.reshape(N, x*y) + flat_masks = masks.reshape(N, x * y) masks_overlap = flat_masks.dot(flat_masks.T) ind_plot = np.where(masks_overlap[roi_ind, :] > 0)[0] # rois (k) that roi_ind overlaps with @@ -280,31 +296,37 @@ def plot_overlap_masks_lengthOne(roi_ind, masks, savefile=None, weighted=False): ind_plot = np.concatenate((ind_plot, ind_k)) ind_plot = np.unique(ind_plot) - ind_plot = np.concatenate(([roi_ind], ind_plot[ind_plot!=roi_ind])) + ind_plot = np.concatenate(([roi_ind], ind_plot[ind_plot != roi_ind])) plt.figure() - color_list = ['b', 'g', 'r', 'c', 'm', 'y', 'k'] + color_list = ["b", "g", "r", "c", "m", "y", "k"] Ncol = len(color_list) for num, i in enumerate(ind_plot): mask_plot = masks[i] if not weighted: - mask_plot = ((num % Ncol)+1)*np.ma.array(masks[i], mask=(masks[i] == 0)) - plt.imshow(mask_plot, clim=(1., Ncol+1), cmap=colors.ListedColormap(color_list), alpha=0.5, interpolation='nearest') + mask_plot = ((num % Ncol) + 1) * np.ma.array(masks[i], mask=(masks[i] == 0)) + plt.imshow( + mask_plot, + clim=(1.0, Ncol + 1), + cmap=colors.ListedColormap(color_list), + alpha=0.5, + interpolation="nearest", + ) # plt.imshow(mask_plot, clim=(1., len(ind_plot)), alpha=.5) elif weighted: mask_plot = np.ma.array(masks[i], mask=(masks[i] == 0)) - plt.imshow(mask_plot, cmap='gray_r', alpha=.5, interpolation='nearest') + plt.imshow(mask_plot, cmap="gray_r", alpha=0.5, interpolation="nearest") - plt.text(np.mean(np.where(np.sum(mask_plot, axis=0))), np.mean(np.where(np.sum(mask_plot, axis=1))) ,str(i)) + plt.text(np.mean(np.where(np.sum(mask_plot, axis=0))), np.mean(np.where(np.sum(mask_plot, axis=1))), str(i)) mask_tot = np.sum(masks[ind_plot, :, :], axis=0) mask_x = np.sum(mask_tot, axis=0) mask_y = np.sum(mask_tot, axis=1) - plt.xlim((np.amin(np.where(mask_x))-5, np.amax(np.where(mask_x))+5)) - plt.ylim((np.amin(np.where(mask_y))-5, np.amax(np.where(mask_y))+5)) - plt.title('Masks') + plt.xlim((np.amin(np.where(mask_x)) - 5, np.amax(np.where(mask_x)) + 5)) + plt.ylim((np.amin(np.where(mask_y)) - 5, np.amax(np.where(mask_y)) + 5)) + plt.title("Masks") if savefile is not None: plt.savefig(savefile) @@ -313,16 +335,17 @@ def plot_overlap_masks_lengthOne(roi_ind, masks, savefile=None, weighted=False): return ind_plot -@deprecated("The internal demixer module is deprecated and will be removed. " - "Please use allensdk.brain_observatory.demixer." - "plot_transients instead.") +@deprecated( + "The internal demixer module is deprecated and will be removed. " + "Please use allensdk.brain_observatory.demixer." + "plot_transients instead." +) def plot_transients(roi_ind, t_trans, masks, traces, demix_traces, savefile): - masks = np.array(masks).astype(float) N, x, y = masks.shape _, Nt = traces.shape - flat_masks = masks.reshape(N, x*y) + flat_masks = masks.reshape(N, x * y) masks_overlap = flat_masks.dot(flat_masks.T) ind_plot = np.where(masks_overlap[roi_ind, :] > 0)[0] # rois (k) that roi_ind overlaps with @@ -331,7 +354,7 @@ def plot_transients(roi_ind, t_trans, masks, traces, demix_traces, savefile): ind_plot = np.concatenate((ind_plot, ind_k)) ind_plot = np.unique(ind_plot) - ind_plot = np.concatenate(([roi_ind], ind_plot[ind_plot!=roi_ind])) + ind_plot = np.concatenate(([roi_ind], ind_plot[ind_plot != roi_ind])) if t_trans > 150 and t_trans < Nt - 150: plot_t = range(t_trans - 150, t_trans + 150) @@ -341,19 +364,18 @@ def plot_transients(roi_ind, t_trans, masks, traces, demix_traces, savefile): plot_t = range(0, t_trans + 150) fig, ax = plt.subplots(1, 2, figsize=(12, 6), sharex=True, sharey=True) - color_list = ['b', 'g', 'r', 'c', 'm', 'y', 'k'] + color_list = ["b", "g", "r", "c", "m", "y", "k"] Ncol = len(color_list) for num, i in enumerate(ind_plot): ax[0].plot(plot_t, traces[i, plot_t], label=str(i), color=color_list[(num % Ncol)]) ax[1].plot(plot_t, demix_traces[i, plot_t], label=str(i), color=color_list[(num % Ncol)]) - ax[0].set_title('Raw') - ax[0].set_ylabel('Fluorescence') - ax[1].set_title('Demixed') - ax[1].set_xlabel('Time') + ax[0].set_title("Raw") + ax[0].set_ylabel("Fluorescence") + ax[1].set_title("Demixed") + ax[1].set_xlabel("Time") ax[0].legend(loc=0) plt.savefig(savefile) plt.close(fig) - diff --git a/allensdk/internal/brain_observatory/eye_calibration.py b/allensdk/internal/brain_observatory/eye_calibration.py index df9a09767b..56862e8455 100755 --- a/allensdk/internal/brain_observatory/eye_calibration.py +++ b/allensdk/internal/brain_observatory/eye_calibration.py @@ -7,18 +7,19 @@ CAMERA_POSITION_OLD = np.array([13.0, 0, 0]) CAMERA_POSITION_NEW = np.array([10.28, 7.47, 2.74]) -CAMERA_ROTATIONS_OLD = np.array([0.0, 0.0, 13.1*np.pi/180]) -CAMERA_ROTATIONS_NEW = np.array([0.0, 0.0, 2.8*np.pi/180]) +CAMERA_ROTATIONS_OLD = np.array([0.0, 0.0, 13.1 * np.pi / 180]) +CAMERA_ROTATIONS_NEW = np.array([0.0, 0.0, 2.8 * np.pi / 180]) LED_POSITION_ORIGINAL = np.array([26.51, -3.93, 0.1]) LED_POSITION_OLD = np.array([25.89, -6.12, 3.21]) LED_POSITION_NEW = np.array([24.6, 9.23, 5.26]) EYE_RADIUS = 0.1682 # in cm -CM_PER_PIXEL = 10.2/10000.0 +CM_PER_PIXEL = 10.2 / 10000.0 + class EyeCalibration(object): - '''Class for performing eye-tracking calibration. + """Class for performing eye-tracking calibration. Provides methods for estimating the position of the pupil in 3D space and projecting the gaze onto the monitor in both @@ -40,15 +41,19 @@ class EyeCalibration(object): Radius of the eye in cm. cm_per_pixel : float Pixel size of eye-tracking camera. - ''' - def __init__(self, monitor_position=MONITOR_POSITION_NEW, - monitor_rotations=MONITOR_ROTATIONS, - led_position=LED_POSITION_OLD, - camera_position=CAMERA_POSITION_OLD, - camera_rotations=CAMERA_ROTATIONS_OLD, - eye_radius=EYE_RADIUS, - cm_per_pixel=CM_PER_PIXEL): - '''Constructor.''' + """ + + def __init__( + self, + monitor_position=MONITOR_POSITION_NEW, + monitor_rotations=MONITOR_ROTATIONS, + led_position=LED_POSITION_OLD, + camera_position=CAMERA_POSITION_OLD, + camera_rotations=CAMERA_ROTATIONS_OLD, + eye_radius=EYE_RADIUS, + cm_per_pixel=CM_PER_PIXEL, + ): + """Constructor.""" self.eye_radius = eye_radius self.cm_per_pixel = cm_per_pixel @@ -57,18 +62,15 @@ def __init__(self, monitor_position=MONITOR_POSITION_NEW, self.led_position = led_position self.camera_position = camera_position - self.cr = self.cr_position_in_mouse_eye_coordinates(led_position, - eye_radius) + self.cr = self.cr_position_in_mouse_eye_coordinates(led_position, eye_radius) self.monitor_rotations = monitor_rotations if camera_rotations[0] != 0 or camera_rotations[1] != 0: - logging.warning("Got nonzero x=%s,y=%s rotations for camera", - camera_rotations[0], camera_rotations[1]) + logging.warning("Got nonzero x=%s,y=%s rotations for camera", camera_rotations[0], camera_rotations[1]) self.camera_rotation = camera_rotations[2] - def pupil_position_in_mouse_eye_coordinates(self, pupil_parameters, - cr_parameters): - '''Compute the 3D pupil position in mouse eye coordinates. + def pupil_position_in_mouse_eye_coordinates(self, pupil_parameters, cr_parameters): + """Compute the 3D pupil position in mouse eye coordinates. Parameters ---------- @@ -82,15 +84,12 @@ def pupil_position_in_mouse_eye_coordinates(self, pupil_parameters, ------- numpy.ndarray Pupil position estimates in eye coordinates. - ''' + """ # x, y are in screen coordinates, with y increasing towards the top - delta_px = (pupil_parameters.T[0] - cr_parameters.T[0]) * \ - self.cm_per_pixel - delta_py = (cr_parameters.T[1] - pupil_parameters.T[1]) * \ - self.cm_per_pixel # +y is down on image + delta_px = (pupil_parameters.T[0] - cr_parameters.T[0]) * self.cm_per_pixel + delta_py = (cr_parameters.T[1] - pupil_parameters.T[1]) * self.cm_per_pixel # +y is down on image - R_cam_to_eye = base_object_to_eye_rotation_matrix( - self.camera_position) + R_cam_to_eye = base_object_to_eye_rotation_matrix(self.camera_position) # camera frame is passed to us pointed at the eye, but the image # appears as if the camera were rotated 180 degrees about its y-axis R_cam = object_rotation_matrix(0, np.pi, self.camera_rotation) @@ -108,12 +107,12 @@ def pupil_position_in_mouse_eye_coordinates(self, pupil_parameters, p_cam = np.vstack([px_cam, py_cam, pz_cam]) - # rotate estimates + # rotate estimates return np.dot(R_cam_to_eye, np.dot(R_cam.T, p_cam)).T @staticmethod def cr_position_in_mouse_eye_coordinates(led_position, eye_radius): - '''Determine the 3D position of the corneal reflection. + """Determine the 3D position of the corneal reflection. The eye is modeled as a spherical mirror, so the reflection appears to be half the radius of the eye from the origin along @@ -130,12 +129,11 @@ def cr_position_in_mouse_eye_coordinates(led_position, eye_radius): ------- numpy.ndarray [x,y,z] location of the corneal reflection in eye coordinates. - ''' - return (eye_radius/(2*np.linalg.norm(led_position))) * led_position + """ + return (eye_radius / (2 * np.linalg.norm(led_position))) * led_position - def pupil_position_on_monitor_in_cm(self, pupil_parameters, - cr_parameters): - '''Compute the pupil position on the monitor in cm. + def pupil_position_on_monitor_in_cm(self, pupil_parameters, cr_parameters): + """Compute the pupil position on the monitor in cm. Parameters ---------- @@ -149,33 +147,27 @@ def pupil_position_on_monitor_in_cm(self, pupil_parameters, ------- numpy.ndarray Pupil position estimates in eye coordinates. - ''' - pupil_positions = self.pupil_position_in_mouse_eye_coordinates( - pupil_parameters, cr_parameters) + """ + pupil_positions = self.pupil_position_in_mouse_eye_coordinates(pupil_parameters, cr_parameters) monitor_normal = object_norm_eye_coordinates( - self.monitor_position, self.monitor_rotations[0], - self.monitor_rotations[1], self.monitor_rotations[2]) + self.monitor_position, self.monitor_rotations[0], self.monitor_rotations[1], self.monitor_rotations[2] + ) - projected_positions = project_to_plane(monitor_normal, - self.monitor_position, - pupil_positions) + projected_positions = project_to_plane(monitor_normal, self.monitor_position, pupil_positions) monitor_positions = projected_positions - self.monitor_position - R_monitor_to_eye = base_object_to_eye_rotation_matrix( - self.monitor_position) - R_monitor = object_rotation_matrix(self.monitor_rotations[0], - self.monitor_rotations[1], - self.monitor_rotations[2]) + R_monitor_to_eye = base_object_to_eye_rotation_matrix(self.monitor_position) + R_monitor = object_rotation_matrix( + self.monitor_rotations[0], self.monitor_rotations[1], self.monitor_rotations[2] + ) - result = np.dot(R_monitor.T, - np.dot(R_monitor_to_eye.T, monitor_positions.T)) + result = np.dot(R_monitor.T, np.dot(R_monitor_to_eye.T, monitor_positions.T)) return result[:2].T - def pupil_position_on_monitor_in_degrees(self, pupil_parameters, - cr_parameters): - '''Get pupil position on monitor measured in visual degrees. + def pupil_position_on_monitor_in_degrees(self, pupil_parameters, cr_parameters): + """Get pupil position on monitor measured in visual degrees. Parameters ---------- @@ -189,25 +181,24 @@ def pupil_position_on_monitor_in_degrees(self, pupil_parameters, ------- numpy.ndarray Pupil position estimate in visual degrees. - ''' + """ mag = np.sqrt(np.sum(self.monitor_position**2)) - pupil_pos = self.pupil_position_on_monitor_in_cm(pupil_parameters, - cr_parameters) + pupil_pos = self.pupil_position_on_monitor_in_cm(pupil_parameters, cr_parameters) x = pupil_pos.T[0] y = pupil_pos.T[1] - meridian = np.arctan(x/mag)*180/np.pi - elevation = np.arctan(y/np.sqrt(mag**2 + x**2))*180/np.pi + meridian = np.arctan(x / mag) * 180 / np.pi + elevation = np.arctan(y / np.sqrt(mag**2 + x**2)) * 180 / np.pi angles = np.vstack([meridian, elevation]).T return angles def compute_area(self, pupil_parameters): - '''Compute the area of the pupil. + """Compute the area of the pupil. Assume the pupil is a circle, and that as it moves off-axis with the camera the observed ellipse major axis remains the @@ -222,13 +213,13 @@ def compute_area(self, pupil_parameters): ------- numpy.ndarray [nx1] array of pupil areas in estimated pixels. - ''' + """ r = np.maximum(pupil_parameters.T[3], pupil_parameters.T[4]) - return np.pi*r*r + return np.pi * r * r def project_to_plane(plane_normal, plane_point, points): - '''Project from the origin through points onto a plane. + """Project from the origin through points onto a plane. Parameters ---------- @@ -243,15 +234,13 @@ def project_to_plane(plane_normal, plane_point, points): ------- numpy.ndarray [nx3] points projected on the plane. - ''' - factor = np.sum(plane_normal*plane_point) / \ - np.sum(plane_normal*points, axis=1) - return (factor*points.T).T + """ + factor = np.sum(plane_normal * plane_point) / np.sum(plane_normal * points, axis=1) + return (factor * points.T).T -def object_norm_eye_coordinates(object_position, x_rotation, - y_rotation, z_rotation): - '''Get the normal vector for the object plane in eye coordinates. +def object_norm_eye_coordinates(object_position, x_rotation, y_rotation, z_rotation): + """Get the normal vector for the object plane in eye coordinates. Parameters ---------- @@ -268,15 +257,14 @@ def object_norm_eye_coordinates(object_position, x_rotation, ------- numpy.ndarray Endpoint of the object plane vector in eye coordinates. - ''' + """ R_object_to_eye = base_object_to_eye_rotation_matrix(object_position) - R_object_frame = object_rotation_matrix(x_rotation, y_rotation, - z_rotation) + R_object_frame = object_rotation_matrix(x_rotation, y_rotation, z_rotation) return np.dot(R_object_to_eye, np.dot(R_object_frame, [0, 0, 1])) def base_object_to_eye_rotation_matrix(object_position): - '''Rotation matrix to rotate base object frame to eye coordinates. + """Rotation matrix to rotate base object frame to eye coordinates. By convention, any other object's coordinate frame before rotations is set with positive Z pointing from the object's position back @@ -292,28 +280,24 @@ def base_object_to_eye_rotation_matrix(object_position): ------- numpy.ndarray [3x3] rotation matrix. - ''' - eye_norm = -object_position/np.linalg.norm(object_position) + """ + eye_norm = -object_position / np.linalg.norm(object_position) # rotate about eye-z to align eye-x to object-x - theta_z = -(np.pi/2 + np.arctan2(eye_norm[1], eye_norm[0])) - Rz = np.array([[np.cos(theta_z), -np.sin(theta_z), 0], - [np.sin(theta_z), np.cos(theta_z), 0], - [0, 0, 1]]) + theta_z = -(np.pi / 2 + np.arctan2(eye_norm[1], eye_norm[0])) + Rz = np.array([[np.cos(theta_z), -np.sin(theta_z), 0], [np.sin(theta_z), np.cos(theta_z), 0], [0, 0, 1]]) eye_norm_about_z = np.dot(Rz, eye_norm) # rotate about x' to align eye-z to object-z - theta_x = np.pi/2 - np.arctan2(eye_norm_about_z[2], eye_norm_about_z[1]) - Rx = np.array([[1, 0, 0], - [0, np.cos(theta_x), -np.sin(theta_x)], - [0, np.sin(theta_x), np.cos(theta_x)]]) + theta_x = np.pi / 2 - np.arctan2(eye_norm_about_z[2], eye_norm_about_z[1]) + Rx = np.array([[1, 0, 0], [0, np.cos(theta_x), -np.sin(theta_x)], [0, np.sin(theta_x), np.cos(theta_x)]]) R = np.dot(Rx, Rz).T return R def object_rotation_matrix(x_rotation, y_rotation, z_rotation): - '''Rotation matrix in object coordinate frame. + """Rotation matrix in object coordinate frame. The rotation matrix for rotating the object coordinate frame from the initial position. This is done by rotating around x, then @@ -332,15 +316,15 @@ def object_rotation_matrix(x_rotation, y_rotation, z_rotation): ------- numpy.ndarray [3x3] rotation matrix. - ''' - Rx = np.array([[1, 0, 0], - [0, np.cos(x_rotation), -np.sin(x_rotation)], - [0, np.sin(x_rotation), np.cos(x_rotation)]]) - Ry = np.array([[np.cos(y_rotation), 0, np.sin(y_rotation)], - [0, 1, 0], - [-np.sin(y_rotation), 0, np.cos(y_rotation)]]) - Rz = np.array([[np.cos(z_rotation), -np.sin(z_rotation), 0], - [np.sin(z_rotation), np.cos(z_rotation), 0], - [0, 0, 1]]) + """ + Rx = np.array( + [[1, 0, 0], [0, np.cos(x_rotation), -np.sin(x_rotation)], [0, np.sin(x_rotation), np.cos(x_rotation)]] + ) + Ry = np.array( + [[np.cos(y_rotation), 0, np.sin(y_rotation)], [0, 1, 0], [-np.sin(y_rotation), 0, np.cos(y_rotation)]] + ) + Rz = np.array( + [[np.cos(z_rotation), -np.sin(z_rotation), 0], [np.sin(z_rotation), np.cos(z_rotation), 0], [0, 0, 1]] + ) result = np.dot(Rz, np.dot(Ry, Rx)) return result diff --git a/allensdk/internal/brain_observatory/fit_ellipse.py b/allensdk/internal/brain_observatory/fit_ellipse.py index 3fa5f75b1d..1aaedc5040 100644 --- a/allensdk/internal/brain_observatory/fit_ellipse.py +++ b/allensdk/internal/brain_observatory/fit_ellipse.py @@ -1,19 +1,18 @@ import numpy as np -class FitEllipse (object): - - def __init__(self,min_points,max_iter,threshold,num_close): +class FitEllipse(object): + def __init__(self, min_points, max_iter, threshold, num_close): # points = np.array(candidate_points) # y,x = points.T - C = np.zeros([6,6]) - C[0,2]= 2.0 - C[2,0]= 2.0 - C[1,1]= -1.0 + C = np.zeros([6, 6]) + C[0, 2] = 2.0 + C[2, 0] = 2.0 + C[1, 1] = -1.0 - #self.x = x - #self.y = y + # self.x = x + # self.y = y self.C = C self.min_points = min_points @@ -25,43 +24,44 @@ def __init__(self,min_points,max_iter,threshold,num_close): self.best_params_set = False self.besterror = np.inf - def ransac_fit(self,candidate_points): - - #points = np.array(candidate_points) + def ransac_fit(self, candidate_points): + # points = np.array(candidate_points) for i in range(self.max_iter): - inlier_points, outlier_points = self.choose_inliers(candidate_points) params, error = self.fit_ellipse(inlier_points) - if len(outlier_points)>0: - cost = self.outlier_cost(outlier_points,params) + if len(outlier_points) > 0: + cost = self.outlier_cost(outlier_points, params) also_in = 0 - for j,c in enumerate(cost): + for j, c in enumerate(cost): point = outlier_points[j] - if cost[j] self.num_close: params, error = self.fit_ellipse(inlier_points) - if (error < self.besterror): + if error < self.besterror: self.best_params = params self.best_params_set = True self.besterror = error if self.best_params_set: - return ellipse_center(self.best_params), ellipse_angle_of_rotation(self.best_params)*180./np.pi, ellipse_axis_length(self.best_params) + return ( + ellipse_center(self.best_params), + ellipse_angle_of_rotation(self.best_params) * 180.0 / np.pi, + ellipse_axis_length(self.best_params), + ) else: return None def choose_inliers(self, candidate_points): - - #cannot take a larger sample than population - if(len(candidate_points) > self.min_points): - inlier_index = np.random.choice(np.arange(len(candidate_points)),self.min_points,replace=False) + # cannot take a larger sample than population + if len(candidate_points) > self.min_points: + inlier_index = np.random.choice(np.arange(len(candidate_points)), self.min_points, replace=False) else: - #TODO check this + # TODO check this inlier_index = np.arange(self.min_points) inlier_points = [] @@ -75,163 +75,164 @@ def choose_inliers(self, candidate_points): return inlier_points, outlier_points - def outlier_cost(self,outlier_points,params): + def outlier_cost(self, outlier_points, params): + y, x = np.array(outlier_points).T - y,x = np.array(outlier_points).T + D = np.vstack([x * x, x * y, y * y, x, y, np.ones(len(y))]) + # S = np.dot(D, D.T) - D = np.vstack([x*x, x*y, y*y, x, y, np.ones(len(y))]) - #S = np.dot(D, D.T) - - cost = (np.dot(params,D))**2 + cost = (np.dot(params, D)) ** 2 return cost - def fit_ellipse(self,inlier_points): + def fit_ellipse(self, inlier_points): try: inlier_points = np.array(inlier_points) points = np.array(inlier_points) - y,x = points.T + y, x = points.T - D = np.vstack([x*x, x*y, y*y, x, y, np.ones(len(y))]) + D = np.vstack([x * x, x * y, y * y, x, y, np.ones(len(y))]) S = np.dot(D, D.T) - M = np.dot(np.linalg.inv(S),self.C) - U,s,V=np.linalg.svd(M) + M = np.dot(np.linalg.inv(S), self.C) + U, s, V = np.linalg.svd(M) params = U.T[0] - error = np.dot(params, np.dot(S,params))/len(inlier_points) + error = np.dot(params, np.dot(S, params)) / len(inlier_points) except Exception: - #TODO - check if this is correct - params = None #WBW error handling - error = 0.00000001 #WBW error handling + # TODO - check if this is correct + params = None # WBW error handling + error = 0.00000001 # WBW error handling return params, error def ellipse_center(a): - b,c,d,f,_g,a = a[1]/2, a[2], a[3]/2, a[4]/2, a[5], a[0] # noqa: F841 - num = b*b-a*c - x0=(c*d-b*f)/num - y0=(a*f-b*d)/num - return np.array([x0,y0]) - -def ellipse_angle_of_rotation( a ): - b,c,_d,_f,_g,a = a[1]/2, a[2], a[3]/2, a[4]/2, a[5], a[0] # noqa: F841 - return 0.5*np.arctan(2*b/(a-c)) - -def ellipse_angle_of_rotation2( a ): - b,c,_d,_f,_g,a = a[1]/2, a[2], a[3]/2, a[4]/2, a[5], a[0] # noqa: F841 + b, c, d, f, _g, a = a[1] / 2, a[2], a[3] / 2, a[4] / 2, a[5], a[0] # noqa: F841 + num = b * b - a * c + x0 = (c * d - b * f) / num + y0 = (a * f - b * d) / num + return np.array([x0, y0]) + + +def ellipse_angle_of_rotation(a): + b, c, _d, _f, _g, a = a[1] / 2, a[2], a[3] / 2, a[4] / 2, a[5], a[0] # noqa: F841 + return 0.5 * np.arctan(2 * b / (a - c)) + + +def ellipse_angle_of_rotation2(a): + b, c, _d, _f, _g, a = a[1] / 2, a[2], a[3] / 2, a[4] / 2, a[5], a[0] # noqa: F841 if b == 0: if a > c: return 0 else: - return np.pi/2 + return np.pi / 2 else: if a > c: - return np.arctan(2*b/(a-c))/2 + return np.arctan(2 * b / (a - c)) / 2 else: - return np.pi/2 + np.arctan(2*b/(a-c))/2 + return np.pi / 2 + np.arctan(2 * b / (a - c)) / 2 -def ellipse_axis_length( a ): - b,c,d,f,g,a = a[1]/2, a[2], a[3]/2, a[4]/2, a[5], a[0] - up = 2*(a*f*f+c*d*d+g*b*b-2*b*d*f-a*c*g) - down1=(b*b-a*c)*( (c-a)*np.sqrt(1+4*b*b/((a-c)*(a-c)))-(c+a)) - down2=(b*b-a*c)*( (a-c)*np.sqrt(1+4*b*b/((a-c)*(a-c)))-(c+a)) +def ellipse_axis_length(a): + b, c, d, f, g, a = a[1] / 2, a[2], a[3] / 2, a[4] / 2, a[5], a[0] + up = 2 * (a * f * f + c * d * d + g * b * b - 2 * b * d * f - a * c * g) + down1 = (b * b - a * c) * ((c - a) * np.sqrt(1 + 4 * b * b / ((a - c) * (a - c))) - (c + a)) + down2 = (b * b - a * c) * ((a - c) * np.sqrt(1 + 4 * b * b / ((a - c) * (a - c))) - (c + a)) - #TODO check this - cannot divide by 0 so just use a small number instead - if(down1 == 0): - down1 = .0000000001 + # TODO check this - cannot divide by 0 so just use a small number instead + if down1 == 0: + down1 = 0.0000000001 - if(down2 == 0): - down2 = .0000000001 + if down2 == 0: + down2 = 0.0000000001 - res1=np.sqrt(up/down1) - res2=np.sqrt(up/down2) + res1 = np.sqrt(up / down1) + res2 = np.sqrt(up / down2) return np.array([res1, res2]) -def fit_ellipse(candidate_points): +def fit_ellipse(candidate_points): # method from http://nicky.vanforeest.com/misc/fitEllipse/fitEllipse.html points = np.array(candidate_points) - y,x = points.T - D = np.vstack([x*x, x*y, y*y, x, y, np.ones(len(y))]) + y, x = points.T + D = np.vstack([x * x, x * y, y * y, x, y, np.ones(len(y))]) S = np.dot(D, D.T) - C = np.zeros([6,6]) - C[0,2]= 2.0 - C[2,0]= 2.0 - C[1,1]= -1.0 + C = np.zeros([6, 6]) + C[0, 2] = 2.0 + C[2, 0] = 2.0 + C[1, 1] = -1.0 - M = np.dot(np.linalg.inv(S),C) + M = np.dot(np.linalg.inv(S), C) - U,s,V=np.linalg.svd(M) + U, s, V = np.linalg.svd(M) params = U.T[0] - return ellipse_center(params), ellipse_angle_of_rotation(params)*180./np.pi, ellipse_axis_length(params) + return ellipse_center(params), ellipse_angle_of_rotation(params) * 180.0 / np.pi, ellipse_axis_length(params) -def rotate_vector(y,x,theta): - xp = x*np.cos(theta) - y*np.sin(theta) - yp = x*np.sin(theta) + y*np.cos(theta) +def rotate_vector(y, x, theta): + xp = x * np.cos(theta) - y * np.sin(theta) + yp = x * np.sin(theta) + y * np.cos(theta) - return yp,xp + return yp, xp -def test_fit(): +def test_fit(): import matplotlib - matplotlib.use('Agg') + + matplotlib.use("Agg") import matplotlib.pyplot as plt - x = np.linspace(-3.0,3.0,1000) - yp = np.sqrt(4.0 - (4.0/9.0)*(x**2)) + x = np.linspace(-3.0, 3.0, 1000) + yp = np.sqrt(4.0 - (4.0 / 9.0) * (x**2)) ym = -yp - yp += 0.1*np.random.normal(size=len(yp)) - ym += 0.1*np.random.normal(size=len(yp)) - - yp, x1 = rotate_vector(yp,x,np.pi/8) - ym, x2 = rotate_vector(ym,x,np.pi/8) + yp += 0.1 * np.random.normal(size=len(yp)) + ym += 0.1 * np.random.normal(size=len(yp)) - y_outlier = np.random.random(size=100)*4.0 - 2.0 - x_outlier = np.random.random(size=100)*6.0 - 3.0 + yp, x1 = rotate_vector(yp, x, np.pi / 8) + ym, x2 = rotate_vector(ym, x, np.pi / 8) + y_outlier = np.random.random(size=100) * 4.0 - 2.0 + x_outlier = np.random.random(size=100) * 6.0 - 3.0 outlier_points = np.vstack([y_outlier, x_outlier]).T - candidate_points = np.vstack([np.hstack([yp,ym]), np.hstack([x,x])]).T + candidate_points = np.vstack([np.hstack([yp, ym]), np.hstack([x, x])]).T candidate_points = np.vstack([outlier_points, candidate_points]) yt, xt = candidate_points.T print(xt) - #center, angle, (axis1,axis2) = fit_ellipse(candidate_points) + # center, angle, (axis1,axis2) = fit_ellipse(candidate_points) - fe=FitEllipse(40,100,0.0001,40) + fe = FitEllipse(40, 100, 0.0001, 40) result = fe.ransac_fit(candidate_points) if result is not None: - center, angle, (axis1,axis2) = fe.ransac_fit(candidate_points) + center, angle, (axis1, axis2) = fe.ransac_fit(candidate_points) print("center = ", center) print("angle = ", angle) print("axis1 = ", axis1) print("axis2 = ", axis2) - fig,ax=plt.subplots(1) - ax.plot(x,yp,'bo') - ax.plot(x,ym,'bo') - ax.plot(x_outlier, y_outlier, 'rx') + fig, ax = plt.subplots(1) + ax.plot(x, yp, "bo") + ax.plot(x, ym, "bo") + ax.plot(x_outlier, y_outlier, "rx") from matplotlib.patches import Ellipse - el = Ellipse(center,width=2.0*axis1,height=2.0*axis2,angle=angle,fill=False,linewidth=3,color='r') - ax.add_artist(el) + el = Ellipse(center, width=2.0 * axis1, height=2.0 * axis2, angle=angle, fill=False, linewidth=3, color="r") + ax.add_artist(el) plt.show() -if __name__=='__main__': +if __name__ == "__main__": test_fit() diff --git a/allensdk/internal/brain_observatory/frame_stream.py b/allensdk/internal/brain_observatory/frame_stream.py index efd2b26acd..e1017b0a19 100644 --- a/allensdk/internal/brain_observatory/frame_stream.py +++ b/allensdk/internal/brain_observatory/frame_stream.py @@ -7,13 +7,14 @@ import traceback import signal -class FrameInputStream( object ): + +class FrameInputStream(object): def __init__(self, movie_path, num_frames=None, block_size=1, cache_frames=False, process_frame_cb=None): self.movie_path = movie_path self.num_frames = num_frames self.block_size = block_size self.cache_frames = cache_frames - self.process_frame_cb = process_frame_cb if process_frame_cb else lambda f: f[:,:,0].copy() + self.process_frame_cb = process_frame_cb if process_frame_cb else lambda f: f[:, :, 0].copy() self.frames_read = 0 self.frame_cache = [] @@ -32,7 +33,7 @@ def _error(self): def _process_frame(self, frame): return self.process_frame_cb(frame) - + def _read_iter(self): pass @@ -53,7 +54,7 @@ def __iter__(self): for frame in self._read_iter(): self.frame_cache.append(self._process_frame(frame)) self.frames_read += 1 - + if (self.frames_read % 100) == 0: logging.debug("Read frames %d", self.frames_read) @@ -62,7 +63,7 @@ def __iter__(self): if self.block_size == 1: yield self.frame_cache[-1] elif (self.frames_read % self.block_size) == 0: - for i in range(-self.block_size,0): + for i in range(-self.block_size, 0): yield self.frame_cache[i] if not self.cache_frames: @@ -76,7 +77,6 @@ def __iter__(self): if not self.cache_frames: self.frame_cache = [] - def __exit__(self, exc_type, exc_value, tb): if exc_value: traceback.print_tb(tb) @@ -85,35 +85,56 @@ def __exit__(self, exc_type, exc_value, tb): def create_images(self, output_directory, image_type): for i, frame in enumerate(self): - file_name = os.path.join(output_directory, "input_frame-%06d." % i + image_type) + file_name = os.path.join(output_directory, "input_frame-%06d." % i + image_type) scipy.misc.imsave(file_name, frame) - -class FfmpegInputStream( FrameInputStream ): - def __init__(self, movie_path, frame_shape, ffmpeg_bin='ffmpeg', num_frames=None, block_size=1, cache_frames=False, process_frame_cb=None): - super(FfmpegInputStream, self).__init__(movie_path=movie_path, num_frames=num_frames, block_size=block_size, cache_frames=cache_frames, process_frame_cb=process_frame_cb) + +class FfmpegInputStream(FrameInputStream): + def __init__( + self, + movie_path, + frame_shape, + ffmpeg_bin="ffmpeg", + num_frames=None, + block_size=1, + cache_frames=False, + process_frame_cb=None, + ): + super(FfmpegInputStream, self).__init__( + movie_path=movie_path, + num_frames=num_frames, + block_size=block_size, + cache_frames=cache_frames, + process_frame_cb=process_frame_cb, + ) self.ffmpeg_bin = ffmpeg_bin self.frame_shape = frame_shape self.pipe = None - + def open(self): super(FfmpegInputStream, self).open() if self.pipe: raise IOError("pipe is open already") - command = [ self.ffmpeg_bin, - '-i', self.movie_path, - '-f', 'image2pipe', - '-pix_fmt', 'rgb24', - '-vcodec', 'rawvideo'] + command = [ + self.ffmpeg_bin, + "-i", + self.movie_path, + "-f", + "image2pipe", + "-pix_fmt", + "rgb24", + "-vcodec", + "rawvideo", + ] if self.num_frames is not None: - command += ['-vframes', str(self.num_frames)] + command += ["-vframes", str(self.num_frames)] - command += ['-'] + command += ["-"] self.pipe = sp.Popen(command, stdout=sp.PIPE, bufsize=0) logging.debug("opened pipe") @@ -128,12 +149,11 @@ def close(self): super(FfmpegInputStream, self).close() - rc = self.pipe.wait() logging.debug("closed input pipe") - + if rc: - raise Exception("input pipe returned with error code %d" % rc) + raise Exception("input pipe returned with error code %d" % rc) self.pipe = None @@ -168,22 +188,23 @@ def _error(self): self.pipe = None def create_images(self, output_directory, image_type): - cmd = self.ffmpeg_bin + ' -i ' + self.movie_path + ' ' + output_directory + '/input_frame-%06d.' + image_type + cmd = self.ffmpeg_bin + " -i " + self.movie_path + " " + output_directory + "/input_frame-%06d." + image_type logging.debug("Calling ffmpeg with the command:") - logging.debug("\t"+cmd) + logging.debug("\t" + cmd) retcode = sp.call(cmd, shell=True) if retcode != 0: logging.debug(retcode) - raise Exception('Something went wrong with image creation') - + raise Exception("Something went wrong with image creation") -class CvInputStream( object): +class CvInputStream(object): def __init__(self, movie_path, num_frames=None, block_size=1, cache_frames=False): - super(FfmpegInputStream, self).__init__(movie_path=movie_path, num_frames=num_frames, block_size=block_size, cache_frames=cache_frames) + super(FfmpegInputStream, self).__init__( + movie_path=movie_path, num_frames=num_frames, block_size=block_size, cache_frames=cache_frames + ) self.cap = None - + def open(self): super(FfmpegInputStream, self).open() @@ -193,6 +214,7 @@ def open(self): self.frames_read = 0 import cv2 + self.cap = cv2.VideoCapture(self.movie_path) logging.debug("opened capture") @@ -220,7 +242,8 @@ def _error(self): self.cap.release() self.cap = None -class FrameOutputStream( object ): + +class FrameOutputStream(object): def __init__(self, block_size=1): self.frames_processed = 0 self.block_frames = [] @@ -258,15 +281,16 @@ def __exit__(self, exc_type, exc_value, tb): raise exc_value self.close() -class ImageOutputStream( FrameOutputStream ): + +class ImageOutputStream(FrameOutputStream): def _write_frames(frames): for i, frame in enumerate(frames): file_name = self.movie_path % i scipy.misc.imsave(file_name, frame) - -class FfmpegOutputStream( FrameOutputStream ): - def __init__(self, frame_shape, ffmpeg_bin='ffmpeg', block_size=1): + +class FfmpegOutputStream(FrameOutputStream): + def __init__(self, frame_shape, ffmpeg_bin="ffmpeg", block_size=1): super(FfmpegOutputStream, self).__init__(block_size) self.ffmpeg_bin = ffmpeg_bin @@ -281,24 +305,32 @@ def open(self, movie_path): logging.warning("pipe is already open!") return - command = [ self.ffmpeg_bin, - '-y', - '-f', 'rawvideo', - '-vcodec', 'rawvideo', - '-s', '%dx%d' % (self.frame_shape[1], self.frame_shape[0]), - '-pix_fmt', 'rgb24', - '-r', '30', - '-i', '-', - '-an', - '-vcodec', 'libx264', - self.movie_path] + command = [ + self.ffmpeg_bin, + "-y", + "-f", + "rawvideo", + "-vcodec", + "rawvideo", + "-s", + "%dx%d" % (self.frame_shape[1], self.frame_shape[0]), + "-pix_fmt", + "rgb24", + "-r", + "30", + "-i", + "-", + "-an", + "-vcodec", + "libx264", + self.movie_path, + ] self.pipe = sp.Popen(command, stdin=sp.PIPE) os.kill(self.pipe.pid, signal.SIGSTOP) self.stopped = True logging.debug("opened output pipe") - def _write_frames(self, frames): if self.pipe is None: self.open(self.movie_path) @@ -308,8 +340,8 @@ def _write_frames(self, frames): for frame in frames: sys.stdout.flush() - self.pipe.stdin.write( frame.tostring() ) - + self.pipe.stdin.write(frame.tostring()) + def close(self): super(FfmpegOutputStream, self).close() if self.pipe is None: @@ -329,5 +361,3 @@ def __exit__(self, exc_type, exc_value, tb): self.pipe.kill() raise exc_value self.close() - - diff --git a/allensdk/internal/brain_observatory/itracker.py b/allensdk/internal/brain_observatory/itracker.py index 0c99ee2c51..fac605a4d4 100644 --- a/allensdk/internal/brain_observatory/itracker.py +++ b/allensdk/internal/brain_observatory/itracker.py @@ -12,17 +12,22 @@ # import cv2 -color_list = ['b','g','r','c','m','y','k'] - -class iTracker (object): - def __init__(self, output_folder, - im_shape, num_frames, - input_stream, - threshold_factor=1.3, auto=True, - cutoff_pixels=10, - bbox_pupil=None, - bbox_cr=None): - +color_list = ["b", "g", "r", "c", "m", "y", "k"] + + +class iTracker(object): + def __init__( + self, + output_folder, + im_shape, + num_frames, + input_stream, + threshold_factor=1.3, + auto=True, + cutoff_pixels=10, + bbox_pupil=None, + bbox_cr=None, + ): self.im_shape = im_shape self.num_frames = num_frames self.movie_shape = (num_frames, im_shape[0], im_shape[1]) @@ -40,29 +45,31 @@ def __init__(self, output_folder, if not os.path.exists(self.folder): os.mkdir(self.folder) - self.run_params_file = os.path.join(self.folder, 'run_params.json') - self.run_params = { 'threshold_factor': threshold_factor, - 'auto': auto, - 'cutoff_pixels': cutoff_pixels, - 'movie_shape': self.movie_shape, - 'im_shape': im_shape, - 'bbox_pupil': bbox_pupil, - 'bbox_cr': bbox_cr } - with open(self.run_params_file, 'w') as f: + self.run_params_file = os.path.join(self.folder, "run_params.json") + self.run_params = { + "threshold_factor": threshold_factor, + "auto": auto, + "cutoff_pixels": cutoff_pixels, + "movie_shape": self.movie_shape, + "im_shape": im_shape, + "bbox_pupil": bbox_pupil, + "bbox_cr": bbox_cr, + } + with open(self.run_params_file, "w") as f: f.write(json.dumps(self.run_params)) - self.movie_path_storage_file = os.path.join(self.folder, 'movie_path.txt') + self.movie_path_storage_file = os.path.join(self.folder, "movie_path.txt") if os.path.exists(self.movie_path_storage_file): - with open(self.movie_path_storage_file, 'r') as f: + with open(self.movie_path_storage_file, "r") as f: self.movie_path = f.read() else: self.movie_path = None - self.input_image_folder = os.path.join(self.folder, 'input_images') + self.input_image_folder = os.path.join(self.folder, "input_images") if not os.path.exists(self.input_image_folder): os.mkdir(self.input_image_folder) - self.results_folder = os.path.join(self.folder,'results') + self.results_folder = os.path.join(self.folder, "results") if not os.path.exists(self.results_folder): os.mkdir(self.results_folder) @@ -71,43 +78,42 @@ def __init__(self, output_folder, # if not os.path.exists(self.rays_folder): # os.mkdir(self.rays_folder) - self.qc_folder = os.path.join(self.results_folder,'qc') + self.qc_folder = os.path.join(self.results_folder, "qc") if not os.path.exists(self.qc_folder): os.mkdir(self.qc_folder) - self.frames_folder = os.path.join(self.results_folder,'output_frames') + self.frames_folder = os.path.join(self.results_folder, "output_frames") if not os.path.exists(self.frames_folder): os.mkdir(self.frames_folder) - self.pupil_file = os.path.join(self.results_folder, 'pupil_params.npy') - self.cr_file = os.path.join(self.results_folder, 'cr_params.npy') - self.mean_frame_file = os.path.join(self.results_folder, 'mean_frame.npy') - self.annotated_movie_file = os.path.join(self.results_folder, 'annotated_movie.mp4') + self.pupil_file = os.path.join(self.results_folder, "pupil_params.npy") + self.cr_file = os.path.join(self.results_folder, "cr_params.npy") + self.mean_frame_file = os.path.join(self.results_folder, "mean_frame.npy") + self.annotated_movie_file = os.path.join(self.results_folder, "annotated_movie.mp4") # add variables to determine whether to provide diagnostic, QC and other output # method to regnerate image frames, with or without results? - # fix this so it sets an absolute path def set_movie(self, file_path): logging.debug("Setting movie_path to: %s", file_path) self.movie_path = file_path - with open(self.movie_path_storage_file, 'w') as f: + with open(self.movie_path_storage_file, "w") as f: f.write(self.movie_path) def set_bbox_pupil(self, bbox): self.bbox_pupil = bbox - self.run_params['bbox_pupil']=self.bbox_pupil - with open(self.run_params_file, 'w') as f: + self.run_params["bbox_pupil"] = self.bbox_pupil + with open(self.run_params_file, "w") as f: f.write(json.dumps(self.run_params)) def set_bbox_cr(self, bbox): self.bbox_cr = bbox - self.run_params['bbox_cr']=self.bbox_cr - with open(self.run_params_file, 'w') as f: + self.run_params["bbox_cr"] = self.bbox_cr + with open(self.run_params_file, "w") as f: f.write(json.dumps(self.run_params)) - def create_input_images(self, image_type='png'): + def create_input_images(self, image_type="png"): self.input_stream.create_images(self.input_image_folder, image_type) @property @@ -120,7 +126,7 @@ def mean_frame(self): def mean_frame(self, mean_frame): self._mean_frame = mean_frame - def estimate_bbox_from_mean_frame(self, margin=75, image_type='png'): + def estimate_bbox_from_mean_frame(self, margin=75, image_type="png"): try: import keras # noqa: F401 except ImportError: @@ -131,61 +137,65 @@ def estimate_bbox_from_mean_frame(self, margin=75, image_type='png'): logging.debug("Estimating bbox parameters from 'mean_frame'") # compute the representation for mean_frame - model = InceptionV3(include_top=False, weights='imagenet') + model = InceptionV3(include_top=False, weights="imagenet") # print(self.mean_frame.dtype, self.mean_frame.shape) mp_temp = self.mean_frame.astype(np.float32) mp_temp -= 128 mp_temp /= 128 - rep = model.predict(mp_temp.reshape((1,)+mp_temp.shape)) # shape (1,13,18,2048) - rep[rep<0]=0 - rep = np.mean(rep, axis=(0,1,2)) # shape (2048,) + rep = model.predict(mp_temp.reshape((1,) + mp_temp.shape)) # shape (1,13,18,2048) + rep[rep < 0] = 0 + rep = np.mean(rep, axis=(0, 1, 2)) # shape (2048,) # load regression weights module_folder = os.path.dirname(os.path.abspath(__file__)) - W_pupil = np.load(os.path.join(module_folder,'resources','pupil_weights.npy')) # shape (2048, 5) - W_cr = np.load(os.path.join(module_folder,'resources','cr_weights.npy')) # shape (2048, 5) + W_pupil = np.load(os.path.join(module_folder, "resources", "pupil_weights.npy")) # shape (2048, 5) + W_cr = np.load(os.path.join(module_folder, "resources", "cr_weights.npy")) # shape (2048, 5) estimated_pupil_point = np.dot(rep, W_pupil) # shape (5,) - estimated_cr_point = np.dot(rep, W_cr) # shape (5,) + estimated_cr_point = np.dot(rep, W_cr) # shape (5,) - x_pupil, y_pupil = estimated_pupil_point[:2]*np.array([640,480]) + x_pupil, y_pupil = estimated_pupil_point[:2] * np.array([640, 480]) x_pupil = int(x_pupil) y_pupil = int(y_pupil) - logging.debug("estimated pupil point is ({0},{1})".format(x_pupil,y_pupil)) - print("estimated pupil point is ({0},{1})".format(x_pupil,y_pupil)) + logging.debug("estimated pupil point is ({0},{1})".format(x_pupil, y_pupil)) + print("estimated pupil point is ({0},{1})".format(x_pupil, y_pupil)) # bbox is xmin, xmax, ymin, ymax # x, y = 320, 240 - bbox_pupil = [x_pupil-margin, x_pupil+margin, y_pupil-margin, y_pupil+margin] + bbox_pupil = [x_pupil - margin, x_pupil + margin, y_pupil - margin, y_pupil + margin] x_cr, y_cr = estimated_cr_point[:2] x_cr = int(x_cr) y_cr = int(y_cr) - logging.debug("estimated cr point is ({0},{1})".format(x_cr,y_cr)) - print("estimated cr point is ({0},{1})".format(x_cr,y_cr)) + logging.debug("estimated cr point is ({0},{1})".format(x_cr, y_cr)) + print("estimated cr point is ({0},{1})".format(x_cr, y_cr)) # bbox is xmin, xmax, ymin, ymax - bbox_cr = [x_cr-margin, x_cr+margin, y_cr-margin, y_cr+margin] + bbox_cr = [x_cr - margin, x_cr + margin, y_cr - margin, y_cr + margin] # bbox_cr = None # plot bbox on mean_frame for QC check mean_frame_annotated = np.dstack([self.mean_frame, self.mean_frame, self.mean_frame]) - mean_frame_annotated = self.annotate_frame_with_bbox(mean_frame_annotated,pupil_bbox=bbox_pupil,cr_bbox=bbox_cr) - mean_frame_annotated = self.annotate_frame_with_point(mean_frame_annotated,pupil=(x_pupil, y_pupil),cr=(x_cr, y_cr)) + mean_frame_annotated = self.annotate_frame_with_bbox( + mean_frame_annotated, pupil_bbox=bbox_pupil, cr_bbox=bbox_cr + ) + mean_frame_annotated = self.annotate_frame_with_point( + mean_frame_annotated, pupil=(x_pupil, y_pupil), cr=(x_cr, y_cr) + ) dpi = 100.0 - fig, ax = plt.subplots(figsize=(mean_frame.shape[1]/dpi, mean_frame.shape[0]/dpi)) - fig.subplots_adjust(left=0,right=1,bottom=0,top=1) + fig, ax = plt.subplots(figsize=(mean_frame.shape[1] / dpi, mean_frame.shape[0] / dpi)) + fig.subplots_adjust(left=0, right=1, bottom=0, top=1) - ax.imshow(mean_frame_annotated, aspect='normal') - ax.axis('off') - fig.savefig(os.path.join(self.qc_folder, 'mean_frame_annotated.'+image_type), dpi=dpi) + ax.imshow(mean_frame_annotated, aspect="normal") + ax.axis("off") + fig.savefig(os.path.join(self.qc_folder, "mean_frame_annotated." + image_type), dpi=dpi) self.bbox_pupil = bbox_pupil self.bbox_cr = bbox_cr return bbox_pupil, bbox_cr - def compute_mean_frame(self, image_file_type='png'): + def compute_mean_frame(self, image_file_type="png"): logging.debug("computing mean frame") mean_frame = np.zeros(self.im_shape) @@ -201,30 +211,29 @@ def compute_mean_frame(self, image_file_type='png'): np.save(self.mean_frame_file, mean_frame) dpi = 100.0 - fig, ax = plt.subplots(figsize=(mean_frame.shape[1]/dpi, mean_frame.shape[0]/dpi)) - fig.subplots_adjust(left=0,right=1,bottom=0,top=1) - ax.imshow(mean_frame, aspect='normal', cmap='gray') - ax.axis('off') - fig.savefig(os.path.join(self.qc_folder, 'mean_frame.' + image_file_type), dpi=dpi) + fig, ax = plt.subplots(figsize=(mean_frame.shape[1] / dpi, mean_frame.shape[0] / dpi)) + fig.subplots_adjust(left=0, right=1, bottom=0, top=1) + ax.imshow(mean_frame, aspect="normal", cmap="gray") + ax.axis("off") + fig.savefig(os.path.join(self.qc_folder, "mean_frame." + image_file_type), dpi=dpi) plt.close() if self.bbox_cr and self.bbox_pupil: - mean_frame_annotated = np.dstack([mean_frame,mean_frame,mean_frame]) - mean_frame_annotated = self.annotate_frame_with_bbox(mean_frame_annotated,pupil_bbox=self.bbox_pupil,cr_bbox=self.bbox_cr) - - fig, ax = plt.subplots(figsize=(mean_frame.shape[1]/dpi, mean_frame.shape[0]/dpi)) - fig.subplots_adjust(left=0,right=1,bottom=0,top=1) - ax.imshow(mean_frame_annotated, aspect='normal') - ax.axis('off') - fig.savefig(os.path.join(self.qc_folder, 'mean_frame_bbox.' + image_file_type), dpi=dpi) + mean_frame_annotated = np.dstack([mean_frame, mean_frame, mean_frame]) + mean_frame_annotated = self.annotate_frame_with_bbox( + mean_frame_annotated, pupil_bbox=self.bbox_pupil, cr_bbox=self.bbox_cr + ) + + fig, ax = plt.subplots(figsize=(mean_frame.shape[1] / dpi, mean_frame.shape[0] / dpi)) + fig.subplots_adjust(left=0, right=1, bottom=0, top=1) + ax.imshow(mean_frame_annotated, aspect="normal") + ax.axis("off") + fig.savefig(os.path.join(self.qc_folder, "mean_frame_bbox." + image_file_type), dpi=dpi) plt.close() return mean_frame - - def detect_eye_closed(self): - try: import keras # noqa: F401 except ImportError: @@ -235,21 +244,22 @@ def detect_eye_closed(self): logging.debug("Detecting eye closed frames") # compute the representation for mean_frame - model = InceptionV3(include_top=False, weights='imagenet') + model = InceptionV3(include_top=False, weights="imagenet") # print(self.mean_frame.dtype, self.mean_frame.shape) # get pre-trained svm from sklearn.externals import joblib + module_folder = os.path.dirname(os.path.abspath(__file__)) - svm = joblib.load(os.path.join(module_folder, 'resources','svm_trained.pkl')) + svm = joblib.load(os.path.join(module_folder, "resources", "svm_trained.pkl")) def compute_rep(frame): mp_temp = frame.astype(np.float32) mp_temp -= 128 mp_temp /= 128 - rep = model.predict(mp_temp.reshape((1,)+mp_temp.shape)) # shape (1,13,18,2048) - rep[rep<0]=0 - rep = np.mean(rep, axis=(0,1,2)) # shape (2048,) + rep = model.predict(mp_temp.reshape((1,) + mp_temp.shape)) # shape (1,13,18,2048) + rep[rep < 0] = 0 + rep = np.mean(rep, axis=(0, 1, 2)) # shape (2048,) return rep @@ -257,10 +267,10 @@ def compute_rep(frame): for input_frame in self.input_stream: rep = compute_rep(input_frame) - is_closed[i] = svm.predict(rep.reshape(-1,len(rep)))[0] + is_closed[i] = svm.predict(rep.reshape(-1, len(rep)))[0] self.is_closed = is_closed - save_path = os.path.join(self.results_folder, 'is_closed.npy') + save_path = os.path.join(self.results_folder, "is_closed.npy") logging.debug("Saving is_closed to:") logging.debug("\t%s", save_path) @@ -269,11 +279,9 @@ def compute_rep(frame): return is_closed - def process_movie(self, movie_output_stream=None, - output_frames=False, - output_annotation_frames=False, - image_file_type = 'jpg' ): - + def process_movie( + self, movie_output_stream=None, output_frames=False, output_annotation_frames=False, image_file_type="jpg" + ): # these aren't really used yet. # self.pupil_loc = (0,0) # self.cr_loc = (0,0) @@ -288,13 +296,15 @@ def process_movie(self, movie_output_stream=None, if output_frames: frame_output_stream = ImageOutputStream() - frame_output_stream.open(os.path.join(self.input_image_folder, 'input_frame-%06d.'+image_file_type)) + frame_output_stream.open(os.path.join(self.input_image_folder, "input_frame-%06d." + image_file_type)) else: frame_output_stream = None if output_annotation_frames: annotation_frame_output_stream = ImageOutputStream() - annotation_frame_output_stream.open(os.path.join(self.frames_folder, 'output_frame-%06d.'+image_file_type)) + annotation_frame_output_stream.open( + os.path.join(self.frames_folder, "output_frame-%06d." + image_file_type) + ) else: annotation_frame_output_stream = None @@ -302,26 +312,26 @@ def process_movie(self, movie_output_stream=None, # get pupil and corneal reflection parameters, this line is the actual eye tracking algorithm pupil, cr = self.process_image(input_frame, bbox_pupil=self.bbox_pupil, bbox_cr=self.bbox_cr) - pupil_params = (pupil[0][0],pupil[0][1],pupil[1],pupil[2][0],pupil[2][1]) - cr_params = (cr[0][0],cr[0][1],cr[1],cr[2][0],cr[2][1]) + pupil_params = (pupil[0][0], pupil[0][1], pupil[1], pupil[2][0], pupil[2][1]) + cr_params = (cr[0][0], cr[0][1], cr[1], cr[2][0], cr[2][1]) if frame_output_stream: frame_output_stream.write(input_frame) if movie_output_stream or annotation_frame_output_stream: - annotated_frame = self.annotate_frame(np.dstack([input_frame,input_frame,input_frame]), - pupil_params, - cr_params) + annotated_frame = self.annotate_frame( + np.dstack([input_frame, input_frame, input_frame]), pupil_params, cr_params + ) if movie_output_stream: - movie_output_stream.write( annotated_frame ) + movie_output_stream.write(annotated_frame) if annotation_frame_output_stream: - annotation_frame_output_stream.write( annotated_frame ) + annotation_frame_output_stream.write(annotated_frame) # save results in arrays - self.pupil_params[i] = (pupil[0][0],pupil[0][1],pupil[1],pupil[2][0],pupil[2][1]) - self.cr_params[i] = (cr[0][0],cr[0][1],cr[1],cr[2][0],cr[2][1]) + self.pupil_params[i] = (pupil[0][0], pupil[0][1], pupil[1], pupil[2][0], pupil[2][1]) + self.cr_params[i] = (cr[0][0], cr[0][1], cr[1], cr[2][0], cr[2][1]) if i % 100 == 0: logging.debug("tracked frame %d", i) @@ -346,7 +356,7 @@ def process_movie(self, movie_output_stream=None, def clear_input_images(self): logging.debug("Deleting input image folder") - shutil.rmtree(os.path.join(self.folder, 'input_images')) + shutil.rmtree(os.path.join(self.folder, "input_images")) def process_image(self, im, bbox_pupil=None, bbox_cr=None): # let's try median filtering the image first @@ -357,22 +367,23 @@ def process_image(self, im, bbox_pupil=None, bbox_cr=None): self.pupil_loc = initial_pupil_point(im, bbox=bbox_pupil) self.cr_loc = initial_cr_point(im, bbox=bbox_cr) - # find rays projecting from seed point - pupil_rays, pupil_ray_values = generate_rays(im,self.pupil_loc) + pupil_rays, pupil_ray_values = generate_rays(im, self.pupil_loc) # save values for analysis self.pupil_rays = pupil_rays self.pupil_ray_values = pupil_ray_values # code for finding pupil ellipse, start with candidate points from rays - pupil_candidate_points = self.get_candidate_points(self.pupil_rays,self.pupil_ray_values,self.threshold_factor,above_threshold=True) + pupil_candidate_points = self.get_candidate_points( + self.pupil_rays, self.pupil_ray_values, self.threshold_factor, above_threshold=True + ) # fit pupil ellipse with all candidate points - #pupil_params = fit_ellipse(pupil_candidate_points) + # pupil_params = fit_ellipse(pupil_candidate_points) # fit pupil ellipse with ransac algorithm - fe=FitEllipse(10,10,0.0001,4) + fe = FitEllipse(10, 10, 0.0001, 4) result = fe.ransac_fit(pupil_candidate_points) # if np.any(np.isnan(result)): #should use np.any(np.isnan(result)) @@ -380,63 +391,59 @@ def process_image(self, im, bbox_pupil=None, bbox_cr=None): # else: # pupil_params = ((np.nan,np.nan),np.nan,(np.nan,np.nan)) # np.nan*np.ones(5) - - if result is not None: #should use np.any(np.isnan(result)) - pupil_params = result #fe.ransac_fit(pupil_candidate_points) + if result is not None: # should use np.any(np.isnan(result)) + pupil_params = result # fe.ransac_fit(pupil_candidate_points) else: logging.debug("No good fit found") - pupil_params = ((np.nan,np.nan),np.nan,(np.nan,np.nan)) # np.nan*np.ones(5) - + pupil_params = ((np.nan, np.nan), np.nan, (np.nan, np.nan)) # np.nan*np.ones(5) # code for finding corneal reflection, start with finding rays from center of cr - cr_rays, cr_ray_values = generate_rays(im,self.cr_loc) + cr_rays, cr_ray_values = generate_rays(im, self.cr_loc) self.cr_rays = cr_rays self.cr_ray_values = cr_ray_values - cr_candidate_points = self.get_candidate_points(self.cr_rays,self.cr_ray_values,0.75,above_threshold=False) + cr_candidate_points = self.get_candidate_points(self.cr_rays, self.cr_ray_values, 0.75, above_threshold=False) try: - #cr_params = fit_ellipse(cr_candidate_points) - fe=FitEllipse(10,10,0.0001,4) + # cr_params = fit_ellipse(cr_candidate_points) + fe = FitEllipse(10, 10, 0.0001, 4) result = fe.ransac_fit(cr_candidate_points) if result is not None: - cr_params = result #fe.ransac_fit(cr_candidate_points) + cr_params = result # fe.ransac_fit(cr_candidate_points) else: logging.debug("No good fit found") - cr_params = ((np.nan,np.nan),np.nan,(np.nan,np.nan)) + cr_params = ((np.nan, np.nan), np.nan, (np.nan, np.nan)) except Exception as e: logging.error("Error during fit: %s", e.message) - cr_params = ((np.nan,np.nan),np.nan,(np.nan,np.nan)) + cr_params = ((np.nan, np.nan), np.nan, (np.nan, np.nan)) # update instance variables - if not np.isnan(pupil_params[0][0]): #should use np.any(np.isnan(result)) - self.pupil_loc = (int(pupil_params[0][1]),int(pupil_params[0][0])) + if not np.isnan(pupil_params[0][0]): # should use np.any(np.isnan(result)) + self.pupil_loc = (int(pupil_params[0][1]), int(pupil_params[0][0])) self.pupil_candidate_points = pupil_candidate_points self.cr_candidate_points = cr_candidate_points return pupil_params, cr_params - def get_candidate_points(self,rays,ray_values,threshold_f,above_threshold=True): - + def get_candidate_points(self, rays, ray_values, threshold_f, above_threshold=True): candidate_points = [] # find candidate points for ellipse from threshold crossing of the image over the rays for i, ray in enumerate(rays): + sample_ray = ray_values[i][: self.cutoff_pixels] + threshold = threshold_f * np.mean(sample_ray) - sample_ray = ray_values[i][:self.cutoff_pixels] - threshold = threshold_f*np.mean(sample_ray) - - for t,g in enumerate(ray_values[i][self.cutoff_pixels:]): + for t, g in enumerate(ray_values[i][self.cutoff_pixels :]): if above_threshold: if g > threshold: - new_point = ray.T[t+self.cutoff_pixels] + new_point = ray.T[t + self.cutoff_pixels] candidate_points += [new_point] break else: if g < threshold: - new_point = ray.T[t+self.cutoff_pixels] + new_point = ray.T[t + self.cutoff_pixels] candidate_points += [new_point] break @@ -449,11 +456,11 @@ def set_seed_points(self, initial_pupil_x, initial_pupil_y, initial_cr_x, initia self.initial_cr_y = initial_cr_y def process_all_images(self): - """ deprecated """ + """deprecated""" # these aren't really used yet. - self.pupil_loc = (0,0) - self.cr_loc = (0,0) + self.pupil_loc = (0, 0) + self.cr_loc = (0, 0) frame_list = os.listdir(self.input_image_folder) num_frames = len(frame_list) @@ -461,24 +468,23 @@ def process_all_images(self): self.pupil_params = np.zeros([num_frames, 5]) self.cr_params = np.zeros([num_frames, 5]) - for i,frame in enumerate(frame_list): + for i, frame in enumerate(frame_list): logging.debug("Processing frame %d", i) - if frame[-4:]!='.jpg' and frame[-4:]!='.png': + if frame[-4:] != ".jpg" and frame[-4:] != ".png": continue # just in case some OS specific files snuck in (like in OS X) - frame_path = os.path.join(self.input_image_folder,frame) + frame_path = os.path.join(self.input_image_folder, frame) # open Image, convert to gray scale and then to numpy array im = Image.open(frame_path) - im = im.convert('L') + im = im.convert("L") im = np.array(im) - # get pupil and corneal reflection parameters, this line is the actual eye tracking algorithm pupil, cr = self.process_image(im) # save results in arrays - self.pupil_params[i] = (pupil[0][0],pupil[0][1],pupil[1],pupil[2][0],pupil[2][1]) - self.cr_params[i] = (cr[0][0],cr[0][1],cr[1],cr[2][0],cr[2][1]) + self.pupil_params[i] = (pupil[0][0], pupil[0][1], pupil[1], pupil[2][0], pupil[2][1]) + self.cr_params[i] = (cr[0][0], cr[0][1], cr[1], cr[2][0], cr[2][1]) logging.debug("Saving pupil and cr parameters to:") logging.debug("\t%s", self.pupil_file) @@ -487,37 +493,33 @@ def process_all_images(self): np.save(self.pupil_file, self.pupil_params) np.save(self.cr_file, self.cr_params) - @staticmethod def rotate(X, Y, center_x, center_y, theta): - - Xp = (X-center_x)*np.cos(theta) - (Y-center_y)*np.sin(theta) + center_x - Yp = (X-center_x)*np.sin(theta) + (Y-center_y)*np.cos(theta) + center_y + Xp = (X - center_x) * np.cos(theta) - (Y - center_y) * np.sin(theta) + center_x + Yp = (X - center_x) * np.sin(theta) + (Y - center_y) * np.cos(theta) + center_y return Xp, Yp @staticmethod def get_ellipse_mask(X, Y, params): - center_x, center_y, theta, axis1, axis2 = params dX = X - center_x dY = Y - center_y - theta = theta*np.pi/180. + theta = theta * np.pi / 180.0 - Xp = dX*np.cos(theta) + dY*np.sin(theta) - Yp = -dX*np.sin(theta) + dY*np.cos(theta) + Xp = dX * np.cos(theta) + dY * np.sin(theta) + Yp = -dX * np.sin(theta) + dY * np.cos(theta) - mask1 = (Xp/axis1)**2 + (Yp/axis2)**2 < 1 + 0.1 - mask2 = (Xp/axis1)**2 + (Yp/axis2)**2 > 1 - 0.1 + mask1 = (Xp / axis1) ** 2 + (Yp / axis2) ** 2 < 1 + 0.1 + mask2 = (Xp / axis1) ** 2 + (Yp / axis2) ** 2 > 1 - 0.1 mask = np.logical_and(mask1, mask2) return mask def annotate_frame_old(self, im, pupil=None, cr=None): - y, x, c = im.shape X, Y = np.meshgrid(np.arange(x), np.arange(y)) @@ -541,17 +543,16 @@ def annotate_frame_old(self, im, pupil=None, cr=None): @classmethod def ellipse_points_from_params(cls, params): center_x, center_y, theta, axis1, axis2 = params - theta = theta*np.pi/180. # convert to radians + theta = theta * np.pi / 180.0 # convert to radians - points_x = np.array([ axis1*np.cos(phi) + center_x for phi in np.linspace(0,2*np.pi, 1000)]) - points_y = np.array([ axis2*np.sin(phi) + center_y for phi in np.linspace(0,2*np.pi, 1000)]) + points_x = np.array([axis1 * np.cos(phi) + center_x for phi in np.linspace(0, 2 * np.pi, 1000)]) + points_y = np.array([axis2 * np.sin(phi) + center_y for phi in np.linspace(0, 2 * np.pi, 1000)]) points_x, points_y = cls.rotate(points_x, points_y, center_x, center_y, theta) return points_x, points_y def annotate_frame(self, im, pupil=None, cr=None): - y, x, c = im.shape im_pil = Image.fromarray(im) @@ -560,13 +561,13 @@ def annotate_frame(self, im, pupil=None, cr=None): if pupil is not None: points_x, points_y = self.ellipse_points_from_params(pupil) - draw.point(zip(points_x, points_y), fill=(255,0,0)) + draw.point(zip(points_x, points_y), fill=(255, 0, 0)) # for i, px in enumerate(points_x): # im[int(points_y[i]), int(px), 0] = 255 if cr is not None: points_x, points_y = self.ellipse_points_from_params(cr) - draw.point(zip(points_x, points_y), fill=(0,0,255)) + draw.point(zip(points_x, points_y), fill=(0, 0, 255)) # for i, px in enumerate(points_x): # im[int(points_y[i]), int(px), 0] = 255 @@ -574,7 +575,6 @@ def annotate_frame(self, im, pupil=None, cr=None): return np.array(im_pil) def annotate_frame_with_bbox(self, im, pupil_bbox=None, cr_bbox=None): - # y, x, c = im.shape im_pil = Image.fromarray(im) @@ -584,7 +584,7 @@ def annotate_frame_with_bbox(self, im, pupil_bbox=None, cr_bbox=None): if pupil_bbox is not None: xmin, xmax, ymin, ymax = pupil_bbox # print(pupil_bbox) - draw.rectangle([xmin,ymin,xmax,ymax],outline=(255,0,0)) + draw.rectangle([xmin, ymin, xmax, ymax], outline=(255, 0, 0)) # for i, px in enumerate(points_x): # im[int(points_y[i]), int(px), 0] = 255 @@ -592,7 +592,7 @@ def annotate_frame_with_bbox(self, im, pupil_bbox=None, cr_bbox=None): if cr_bbox is not None: # print(cr_bbox) xmin, xmax, ymin, ymax = cr_bbox - draw.rectangle([xmin,ymin,xmax,ymax],outline=(0,0,255)) + draw.rectangle([xmin, ymin, xmax, ymax], outline=(0, 0, 255)) # for i, px in enumerate(points_x): # im[int(points_y[i]), int(px), 0] = 255 @@ -600,7 +600,6 @@ def annotate_frame_with_bbox(self, im, pupil_bbox=None, cr_bbox=None): return np.array(im_pil) def annotate_frame_with_point(self, im, pupil=None, cr=None): - # y, x, c = im.shape # print(im.shape, im.dtype) @@ -612,14 +611,14 @@ def annotate_frame_with_point(self, im, pupil=None, cr=None): if pupil is not None: # points_x, points_y = ellipse_points_from_params(pupil) # draw.point(pupil, fill=(255,0,0)) - draw.ellipse([pupil[0]-5,pupil[1]-5,pupil[0]+5,pupil[1]+5],fill=(255,0,0)) + draw.ellipse([pupil[0] - 5, pupil[1] - 5, pupil[0] + 5, pupil[1] + 5], fill=(255, 0, 0)) # for i, px in enumerate(points_x): # im[int(points_y[i]), int(px), 0] = 255 if cr is not None: # points_x, points_y = ellipse_points_from_params(cr) # draw.point(cr, fill=(0,0,255)) - draw.ellipse([cr[0]-5,cr[1]-5,cr[0]+5,cr[1]+5],fill=(0,0,255)) + draw.ellipse([cr[0] - 5, cr[1] - 5, cr[0] + 5, cr[1] + 5], fill=(0, 0, 255)) # for i, px in enumerate(points_x): # im[int(points_y[i]), int(px), 0] = 255 @@ -628,8 +627,7 @@ def annotate_frame_with_point(self, im, pupil=None, cr=None): @staticmethod def get_frame_index(frame_name): - - return int(frame_name[12:-4])-1 # change 7 to 12 + return int(frame_name[12:-4]) - 1 # change 7 to 12 # def annotate_frame(self, frame, im, pupil, cr): # # this function is not done yet @@ -647,10 +645,10 @@ def get_frame_index(frame_name): def output_annotation(self, frames_to_output=None): """generate a the series of images with eyetracking results superimposed""" - fig, ax = plt.subplots(figsize=(4,3)) #, frameon=False) - fig.subplots_adjust(left=0,right=1,bottom=0,top=1) + fig, ax = plt.subplots(figsize=(4, 3)) # , frameon=False) + fig.subplots_adjust(left=0, right=1, bottom=0, top=1) - ax.axis('off') + ax.axis("off") # ax.axis('tight') self.pupil_params = np.load(self.pupil_file) @@ -665,22 +663,21 @@ def output_annotation(self, frames_to_output=None): im = Image.open(os.path.join(self.input_image_folder, first_frame)) im = np.array(im) - new_im = np.dstack([im,im,im]) + new_im = np.dstack([im, im, im]) pupil = self.pupil_params[frame_index] cr = self.cr_params[frame_index] new_im = self.annotate_frame(new_im, pupil, cr) - im_fig = ax.imshow(new_im, aspect='normal') #extent=(0,1,1,0) + im_fig = ax.imshow(new_im, aspect="normal") # extent=(0,1,1,0) fig.savefig(os.path.join(self.frames_folder, first_frame), dpi=100) for input_frame in frames_to_output[1:]: - frame_index = self.get_frame_index(input_frame) im = Image.open(os.path.join(self.input_image_folder, input_frame)) im = np.array(im) - new_im = np.dstack([im,im,im]) + new_im = np.dstack([im, im, im]) pupil = self.pupil_params[frame_index] cr = self.cr_params[frame_index] new_im = self.annotate_frame(new_im, pupil, cr) @@ -689,7 +686,7 @@ def output_annotation(self, frames_to_output=None): fig.savefig(os.path.join(self.frames_folder, input_frame), dpi=100) - def output_QC(self, image_type='png'): + def output_QC(self, image_type="png"): """generate a set of summary statistics and plots for QC purposes""" logging.debug("saving QC images") @@ -698,63 +695,62 @@ def output_QC(self, image_type='png'): logging.debug("saving pupil position") fig, ax = plt.subplots(1) - ax.plot(self.pupil_params.T[0], label='pupil x') - ax.plot(self.pupil_params.T[1], label='pupil y') - ax.set_xlabel('frame index') - ax.set_title('pupil position') + ax.plot(self.pupil_params.T[0], label="pupil x") + ax.plot(self.pupil_params.T[1], label="pupil y") + ax.set_xlabel("frame index") + ax.set_title("pupil position") ax.legend() - fig.savefig(os.path.join(self.qc_folder, 'pupil_position.'+image_type)) + fig.savefig(os.path.join(self.qc_folder, "pupil_position." + image_type)) logging.debug("saving cr position") fig, ax = plt.subplots(1) - ax.plot(self.cr_params.T[0], label='cr x') - ax.plot(self.cr_params.T[1], label='cr y') - ax.set_xlabel('frame index') - ax.set_title('CR position') + ax.plot(self.cr_params.T[0], label="cr x") + ax.plot(self.cr_params.T[1], label="cr y") + ax.set_xlabel("frame index") + ax.set_title("CR position") ax.legend() - fig.savefig(os.path.join(self.qc_folder, 'cr_position.'+image_type)) + fig.savefig(os.path.join(self.qc_folder, "cr_position." + image_type)) logging.debug("saving pupil axes") fig, ax = plt.subplots(1) - ax.plot(self.pupil_params.T[3], label='pupil axis 1') - ax.plot(self.pupil_params.T[4], label='pupil axis 2') - ax.set_xlabel('frame index') - ax.set_title('Pupil major and minor axis size') + ax.plot(self.pupil_params.T[3], label="pupil axis 1") + ax.plot(self.pupil_params.T[4], label="pupil axis 2") + ax.set_xlabel("frame index") + ax.set_title("Pupil major and minor axis size") ax.legend() - fig.savefig(os.path.join(self.qc_folder, 'pupil_axes.'+image_type)) + fig.savefig(os.path.join(self.qc_folder, "pupil_axes." + image_type)) logging.debug("saving cr major/minor axis") fig, ax = plt.subplots(1) - ax.plot(self.cr_params.T[3], label='cr axis 1') - ax.plot(self.cr_params.T[4], label='cr axis 2') - ax.set_xlabel('frame index') - ax.set_title('CR major and minor axis size') + ax.plot(self.cr_params.T[3], label="cr axis 1") + ax.plot(self.cr_params.T[4], label="cr axis 2") + ax.set_xlabel("frame index") + ax.set_title("CR major and minor axis size") ax.legend() - fig.savefig(os.path.join(self.qc_folder, 'cr_axes.'+image_type)) + fig.savefig(os.path.join(self.qc_folder, "cr_axes." + image_type)) logging.debug("saving pupil angle") fig, ax = plt.subplots(1) - ax.plot(self.pupil_params.T[2], label='pupil angle') - ax.set_xlabel('frame index') - ax.set_title('pupil major axis angle') + ax.plot(self.pupil_params.T[2], label="pupil angle") + ax.set_xlabel("frame index") + ax.set_title("pupil major axis angle") ax.legend() - fig.savefig(os.path.join(self.qc_folder, 'pupil_angle.'+image_type)) + fig.savefig(os.path.join(self.qc_folder, "pupil_angle." + image_type)) logging.debug("saving cr angle") fig, ax = plt.subplots(1) - ax.plot(self.cr_params.T[2], label='cr angle') - ax.set_xlabel('frame index') - ax.set_title('corneal reflection major axis angle') + ax.plot(self.cr_params.T[2], label="cr angle") + ax.set_xlabel("frame index") + ax.set_title("corneal reflection major axis angle") ax.legend() - fig.savefig(os.path.join(self.qc_folder, 'cr_angle.'+image_type)) - + fig.savefig(os.path.join(self.qc_folder, "cr_angle." + image_type)) logging.debug("computing density") # the remainder of these take a *very* long time T = self.pupil_params.shape[0] y, x = self.im_shape - mean_frame = np.dstack([self.mean_frame,self.mean_frame,self.mean_frame]) + mean_frame = np.dstack([self.mean_frame, self.mean_frame, self.mean_frame]) pupil_density = np.zeros((y, x, 3)) # pupil_all = 255*np.ones(pupil_density.shape, np.uint8) # pupil_all = np.stack([mean_frame, mean_frame, mean_frame], axis=2) @@ -764,44 +760,42 @@ def output_QC(self, image_type='png'): # cr_all = 255*np.ones(cr_density.shape, np.uint8) # cr_all = np.stack([mean_frame, mean_frame, mean_frame], axis=2) cr_all = mean_frame.copy() - temp = np.zeros((y,x,3), dtype=np.uint8) + temp = np.zeros((y, x, 3), dtype=np.uint8) for t in range(T): if t % 100 == 0: logging.debug("finished %d frames", t) ptemp = self.annotate_frame(temp.copy(), self.pupil_params[t]) - pupil_density += ptemp #, self.cr_params[t]) + pupil_density += ptemp # , self.cr_params[t]) pupil_all = self.annotate_frame(pupil_all, self.pupil_params[t]) crtemp = self.annotate_frame(temp.copy(), cr=self.cr_params[t]) - cr_density += crtemp #, self.cr_params[t]) + cr_density += crtemp # , self.cr_params[t]) cr_all = self.annotate_frame(cr_all, cr=self.cr_params[t]) logging.debug("plotting pupil density") fig, ax = plt.subplots(1) - ax.imshow(np.log(1+pupil_density[:,:,0]), cmap='Greys', interpolation='nearest') + ax.imshow(np.log(1 + pupil_density[:, :, 0]), cmap="Greys", interpolation="nearest") # ax.axis('off') - ax.set_title('Pupil ellipse density') - fig.savefig(os.path.join(self.qc_folder, 'pupil_density.'+image_type)) - + ax.set_title("Pupil ellipse density") + fig.savefig(os.path.join(self.qc_folder, "pupil_density." + image_type)) logging.debug("plotting pupil all") fig, ax = plt.subplots(1) - ax.imshow(pupil_all, cmap='Greys', interpolation='nearest') + ax.imshow(pupil_all, cmap="Greys", interpolation="nearest") # ax.axis('off') - ax.set_title('All pupil ellipses combined') - fig.savefig(os.path.join(self.qc_folder, 'pupil_all_plot.'+image_type)) + ax.set_title("All pupil ellipses combined") + fig.savefig(os.path.join(self.qc_folder, "pupil_all_plot." + image_type)) logging.debug("plotting cr density") fig, ax = plt.subplots(1) - ax.imshow(np.log(1+cr_density[:,:,2]), cmap='Greys', interpolation='nearest') + ax.imshow(np.log(1 + cr_density[:, :, 2]), cmap="Greys", interpolation="nearest") # ax.axis('off') - ax.set_title('CR ellipse density') - fig.savefig(os.path.join(self.qc_folder, 'cr_density.'+image_type)) + ax.set_title("CR ellipse density") + fig.savefig(os.path.join(self.qc_folder, "cr_density." + image_type)) logging.debug("plotting cr all") fig, ax = plt.subplots(1) - ax.imshow(cr_all, cmap='Greys', interpolation='nearest') - ax.set_title('All CR ellipses combined') + ax.imshow(cr_all, cmap="Greys", interpolation="nearest") + ax.set_title("All CR ellipses combined") # ax.axis('off') - fig.savefig(os.path.join(self.qc_folder, 'cr_all_plot.'+image_type)) - + fig.savefig(os.path.join(self.qc_folder, "cr_all_plot." + image_type)) diff --git a/allensdk/internal/brain_observatory/itracker_utils.py b/allensdk/internal/brain_observatory/itracker_utils.py index fce3b347fa..d2dd26d39e 100644 --- a/allensdk/internal/brain_observatory/itracker_utils.py +++ b/allensdk/internal/brain_observatory/itracker_utils.py @@ -3,33 +3,33 @@ from scipy.ndimage.filters import sobel import logging + def default_ray(n): + y = np.zeros(n, dtype=np.int64) + x = np.arange(n, dtype=np.int64) - y = np.zeros(n,dtype=np.int64) - x = np.arange(n,dtype=np.int64) + return np.vstack([y, x]) - return np.vstack([y,x]) -def rotate_ray(ray,theta): +def rotate_ray(ray, theta): + y, x = ray.astype(np.float64) - y,x = ray.astype(np.float64) + xp = x * np.cos(theta) + y * np.sin(theta) + yp = -x * np.sin(theta) + y * np.cos(theta) - xp = x*np.cos(theta) + y*np.sin(theta) - yp = -x*np.sin(theta) + y*np.cos(theta) + return np.vstack([yp.astype(np.int64), xp.astype(np.int64)]) - return np.vstack([yp.astype(np.int64),xp.astype(np.int64)]) def generate_rays(image_array, seed_pixel): + N = 18 # 200 - N = 18 #200 - - #mag, grad_x, grad_y = sobel_grad(image_array.astype('float')) + # mag, grad_x, grad_y = sobel_grad(image_array.astype('float')) shape = image_array.shape - Y,X = np.mgrid[:shape[1],:shape[0]] + Y, X = np.mgrid[: shape[1], : shape[0]] - n = int(np.sqrt(shape[0]**2 + shape[1]**2)) - angles = np.arange(N)*2.0*np.pi/N + n = int(np.sqrt(shape[0] ** 2 + shape[1] ** 2)) + angles = np.arange(N) * 2.0 * np.pi / N rays = [] tangents = [] @@ -37,34 +37,33 @@ def generate_rays(image_array, seed_pixel): ray_grads = [] def good_coords_mask(y, x): - return np.logical_and(np.logical_and(y>=0,y=0,x= 0, y < shape[0]), np.logical_and(x >= 0, x < shape[1])) for theta in angles: - new_ray = rotate_ray(default_ray(n),theta) + new_ray = rotate_ray(default_ray(n), theta) new_ray = new_ray.T + seed_pixel new_ray = new_ray.T - - mask = good_coords_mask(new_ray[0],new_ray[1]) + mask = good_coords_mask(new_ray[0], new_ray[1]) ym = new_ray[0][mask] xm = new_ray[1][mask] - rays += [np.vstack([ym,xm])] + rays += [np.vstack([ym, xm])] - t = np.array([np.sin(theta),np.cos(theta)]) + t = np.array([np.sin(theta), np.cos(theta)]) tangents += [t] - #rg = t[1]*grad_x[ym,xm] + t[0]*grad_y[ym,xm] - #rg[rg<0] = 0.0 - rg = image_array[ym,xm] - #rg = rg[1:].astype(np.float64) - rg[:-1].astype(np.float64) + # rg = t[1]*grad_x[ym,xm] + t[0]*grad_y[ym,xm] + # rg[rg<0] = 0.0 + rg = image_array[ym, xm] + # rg = rg[1:].astype(np.float64) - rg[:-1].astype(np.float64) # rg[rg<0] = 0.0 ray_grads += [rg] - return rays, ray_grads + def initial_pupil_point(image_array, bbox=None): """bbox is a tuple of (xmin, xmax, ymin, ymax)""" @@ -74,24 +73,25 @@ def initial_pupil_point(image_array, bbox=None): else: shape = image_array.shape crop_distance = 50 - crop_im = image_array[crop_distance:shape[0]-crop_distance, crop_distance:shape[1]-crop_distance] + crop_im = image_array[crop_distance : shape[0] - crop_distance, crop_distance : shape[1] - crop_distance] m = np.max(crop_im) - dark_square = m*np.ones([30,30]) - #c = correlate2d(m-crop_im,dark_square,mode='same') - c = fftconvolve(m-crop_im,dark_square[::-1,::-1],mode='same') - y,x = np.where(c==np.max(c)) + dark_square = m * np.ones([30, 30]) + # c = correlate2d(m-crop_im,dark_square,mode='same') + c = fftconvolve(m - crop_im, dark_square[::-1, ::-1], mode="same") + y, x = np.where(c == np.max(c)) if bbox is not None: - ybar=int(np.mean(y))+ymin - xbar=int(np.mean(x))+xmin + ybar = int(np.mean(y)) + ymin + xbar = int(np.mean(x)) + xmin else: - ybar=int(np.mean(y))+crop_distance - xbar=int(np.mean(x))+crop_distance + ybar = int(np.mean(y)) + crop_distance + xbar = int(np.mean(x)) + crop_distance return ybar, xbar + def initial_cr_point(image_array, bbox=None): """bbox is a tuple of (xmin, xmax, ymin, ymax)""" @@ -101,56 +101,57 @@ def initial_cr_point(image_array, bbox=None): else: shape = image_array.shape crop_distance = 50 - crop_im = image_array[crop_distance:shape[0]-crop_distance, crop_distance:shape[1]-crop_distance] + crop_im = image_array[crop_distance : shape[0] - crop_distance, crop_distance : shape[1] - crop_distance] m = np.max(crop_im) mean = np.mean(crop_im) - Y,X = np.meshgrid(np.arange(-20,20),np.arange(-20,20)) - bright_circle = np.zeros([40,40]) - mask = X**2 + Y**2 < 100. + Y, X = np.meshgrid(np.arange(-20, 20), np.arange(-20, 20)) + bright_circle = np.zeros([40, 40]) + mask = X**2 + Y**2 < 100.0 bright_circle[mask] = m bright_circle -= np.mean(bright_circle) - #c = correlate2d(crop_im-mean,bright_circle,mode='same') - c = fftconvolve(crop_im-mean,bright_circle[::-1,::-1],mode='same') - y,x = np.where(c==np.max(c)) + # c = correlate2d(crop_im-mean,bright_circle,mode='same') + c = fftconvolve(crop_im - mean, bright_circle[::-1, ::-1], mode="same") + y, x = np.where(c == np.max(c)) if bbox is not None: - ybar=int(np.mean(y))+ymin - xbar=int(np.mean(x))+xmin + ybar = int(np.mean(y)) + ymin + xbar = int(np.mean(x)) + xmin else: - ybar=int(np.mean(y))+crop_distance - xbar=int(np.mean(x))+crop_distance + ybar = int(np.mean(y)) + crop_distance + xbar = int(np.mean(x)) + crop_distance return ybar, xbar -def sobel_grad(image_array): - grad_y = sobel(image_array.astype(np.float64),0) - grad_x = sobel(image_array.astype(np.float64),1) +def sobel_grad(image_array): + grad_y = sobel(image_array.astype(np.float64), 0) + grad_x = sobel(image_array.astype(np.float64), 1) - #print "grad_x dtype = ", grad_x.dtype + # print "grad_x dtype = ", grad_x.dtype mag = np.sqrt(grad_y**2 + grad_x**2) + 1e-16 return mag, grad_x, grad_y + def medfilt_custom(x, kernel_size=3): - '''This median filter returns 'nan' whenever any value in the kernal width is 'nan' and the median otherwise''' + """This median filter returns 'nan' whenever any value in the kernal width is 'nan' and the median otherwise""" T = x.shape[0] - delta = kernel_size/2 + delta = kernel_size / 2 x_med = np.zeros(x.shape) - window = x[0:delta+1] + window = x[0 : delta + 1] if np.any(np.isnan(window)): x_med[0] = np.nan else: x_med[0] = np.median(window) # print window - for t in range(1,T): - window = x[t-delta:t+delta+1] + for t in range(1, T): + window = x[t - delta : t + delta + 1] # print window if np.any(np.isnan(window)): x_med[t] = np.nan @@ -159,12 +160,13 @@ def medfilt_custom(x, kernel_size=3): return x_med + def eccentricity(a1, a2): + return np.sqrt(1.0 - (np.minimum(a1, a2) ** 2) / (np.maximum(a1, a2) ** 2)) - return np.sqrt(1.0 - (np.minimum(a1,a2)**2)/(np.maximum(a1,a2)**2)) def median_absolute_deviation(a, consistency_constant=1.4826): - '''Calculate the median absolute deviation of a univariate dataset. + """Calculate the median absolute deviation of a univariate dataset. Parameters ---------- @@ -178,36 +180,36 @@ def median_absolute_deviation(a, consistency_constant=1.4826): ------- float Median absolute deviation of the data. - ''' + """ return consistency_constant * np.nanmedian(np.abs(a - np.nanmedian(a))) + def post_process_cr(cr_params): """This will replace questionable values of the CR x and y position with 'nan' - 1) threshold ellipse area by 99th percentile area distribution - 2) median filter using custom median filter - 3) remove deviations from discontinuous jumps + 1) threshold ellipse area by 99th percentile area distribution + 2) median filter using custom median filter + 3) remove deviations from discontinuous jumps - The 'nan' values likely represent obscured CRs, secondary reflections, merges - with the secondary reflection, or visual distortions due to the whisker or - deformations of the eye""" + The 'nan' values likely represent obscured CRs, secondary reflections, merges + with the secondary reflection, or visual distortions due to the whisker or + deformations of the eye""" - area = np.pi*cr_params.T[3]*cr_params.T[4] + area = np.pi * cr_params.T[3] * cr_params.T[4] # compute a threshold on the area of the cr ellipse dev = median_absolute_deviation(area) if dev == 0: - logging.warning("Median absolute deviation is 0," - "falling back to standard deviation.") + logging.warning("Median absolute deviation is 0,falling back to standard deviation.") dev = np.nanstd(area) - threshold = np.nanmedian(area) + 3*dev + threshold = np.nanmedian(area) + 3 * dev x_center = cr_params.T[0] y_center = cr_params.T[1] # set x,y where area is over threshold to nan - x_center[area>threshold] = np.nan - y_center[area>threshold] = np.nan + x_center[area > threshold] = np.nan + y_center[area > threshold] = np.nan # median filter x_center_med = medfilt_custom(x_center, kernel_size=3) @@ -225,12 +227,12 @@ def post_process_cr(cr_params): std_y = np.std(y_center_med[y_mask_finite]) # set these extreme values to nan - #x_center_med[x_center_med < mean_x - 3*std_x] = np.nan - #y_center_med[y_center_med > mean_y + 3*std_y] = np.nan - x_center_med[np.abs(x_center_med - mean_x) > 3*std_x] = np.nan - y_center_med[np.abs(y_center_med - mean_y) > 3*std_y] = np.nan + # x_center_med[x_center_med < mean_x - 3*std_x] = np.nan + # y_center_med[y_center_med > mean_y + 3*std_y] = np.nan + x_center_med[np.abs(x_center_med - mean_x) > 3 * std_x] = np.nan + y_center_med[np.abs(y_center_med - mean_y) > 3 * std_y] = np.nan - either_nan_mask = np.logical_and(np.isnan(x_center_med),np.isnan(y_center_med)) + either_nan_mask = np.logical_and(np.isnan(x_center_med), np.isnan(y_center_med)) x_center_med[either_nan_mask] = np.nan y_center_med[either_nan_mask] = np.nan @@ -242,7 +244,7 @@ def post_process_cr(cr_params): def post_process_pupil(pupil_params): - '''Filter pupil parameters to replace outliers with nan + """Filter pupil parameters to replace outliers with nan Parameters ---------- @@ -253,8 +255,8 @@ def post_process_pupil(pupil_params): ------- numpy.ndarray Pupil parameters with outliers replaced with nan - ''' - area = np.pi*pupil_params.T[3]*pupil_params.T[4] + """ + area = np.pi * pupil_params.T[3] * pupil_params.T[4] threshold = np.percentile(area[np.isfinite(area)], 99) outlier_index = area > threshold pupil_params[outlier_index, :] = np.nan @@ -262,7 +264,7 @@ def post_process_pupil(pupil_params): def filter_bad_params(params, frame_width, frame_height): - '''Replace positions outside image with nan''' - params[(params[:,0] > frame_width) | (params[:,0] < 0), :] = np.nan - params[(params[:,1] > frame_height) | (params[:,1] < 0), :] = np.nan + """Replace positions outside image with nan""" + params[(params[:, 0] > frame_width) | (params[:, 0] < 0), :] = np.nan + params[(params[:, 1] > frame_height) | (params[:, 1] < 0), :] = np.nan return params diff --git a/allensdk/internal/brain_observatory/mask_set.py b/allensdk/internal/brain_observatory/mask_set.py index 4dc5f89e4c..623dd0a40c 100644 --- a/allensdk/internal/brain_observatory/mask_set.py +++ b/allensdk/internal/brain_observatory/mask_set.py @@ -2,7 +2,8 @@ import numpy as np import logging -class MaskSet( object ): + +class MaskSet(object): def __init__(self, masks): self.masks = masks self.bbs = make_bbs(self.masks) @@ -21,14 +22,14 @@ def count(self): return len(self.bbs) def distance(self, mask_idxs): - return max(self.mask_dist[i,j] for (i,j) in itertools.combinations(mask_idxs, 2)) + return max(self.mask_dist[i, j] for (i, j) in itertools.combinations(mask_idxs, 2)) def close(self, mask_idxs, max_dist): - return not any(self.mask_dist[i,j] > max_dist for (i,j) in itertools.combinations(mask_idxs, 2)) + return not any(self.mask_dist[i, j] > max_dist for (i, j) in itertools.combinations(mask_idxs, 2)) def close_sets(self, set_size, max_dist): mask_sets = itertools.combinations(range(len(self.bbs)), set_size) - return (ms for ms in mask_sets if self.close(ms, max_dist)) + return (ms for ms in mask_sets if self.close(ms, max_dist)) def _idx_key(self, idxs): return tuple(sorted(set(idxs))) @@ -38,7 +39,7 @@ def mask(self, mask_idx): def union(self, mask_idxs): mask_idxs = self._idx_key(mask_idxs) - + if mask_idxs in self.cached_unions: return self.cached_unions[mask_idxs] @@ -47,7 +48,7 @@ def union(self, mask_idxs): i0 = mask_idxs[0] union = self.masks[i0].copy() - + if len(mask_idxs) == 1: return union @@ -59,18 +60,18 @@ def union(self, mask_idxs): return union def overlap_fraction(self, idx0, idx1): - union_size = self.union_size([idx0,idx1]) - overlap_size = self.intersection_size([idx0,idx1]) + union_size = self.union_size([idx0, idx1]) + overlap_size = self.intersection_size([idx0, idx1]) return float(overlap_size) / float(union_size) def detect_duplicates(self, overlap_threshold): duplicate_masks = set() - for idx0,idx1 in self.close_sets(set_size=2, max_dist=0): + for idx0, idx1 in self.close_sets(set_size=2, max_dist=0): overlap_frac = self.overlap_fraction(idx0, idx1) if overlap_frac > overlap_threshold: - duplicate_masks.add(tuple(sorted([idx0,idx1]))) + duplicate_masks.add(tuple(sorted([idx0, idx1]))) return duplicate_masks @@ -83,7 +84,6 @@ def mask_is_union_of_set(self, mask_idx, set_idxs, threshold): if overlap_size < threshold * set_mask_size: return False - # does this mask cover more than the union of the individual set elements? set_union = self.union(set_idxs) mask = self.mask(mask_idx) @@ -102,9 +102,9 @@ def detect_unions(self, set_size=2, max_dist=10, threshold=0.7): if mask_idx in set_idxs: continue elif not self.close([mask_idx] + list(set_idxs), max_dist): - continue + continue - if self.mask_is_union_of_set(mask_idx, set_idxs, threshold): + if self.mask_is_union_of_set(mask_idx, set_idxs, threshold): if mask_idx in union_masks: logging.warning("already detected this mask as a union") union_masks[mask_idx] = set_idxs @@ -116,15 +116,15 @@ def union_size(self, mask_idxs): if mask_idxs in self.cached_union_sizes: return self.cached_union_sizes[mask_idxs] - + s = self.union(mask_idxs).sum() self.cached_union_sizes[mask_idxs] = s return s - def intersection(self, mask_idxs): + def intersection(self, mask_idxs): mask_idxs = self._idx_key(mask_idxs) - + if mask_idxs in self.cached_intersections: return self.cached_intersections[mask_idxs] @@ -134,7 +134,7 @@ def intersection(self, mask_idxs): # don't cache the empty ones if not self.close(mask_idxs, 0): return np.zeros(self.masks[0].shape) - + i0 = mask_idxs[0] intersection = self.masks[i0].copy() @@ -153,7 +153,7 @@ def intersection_size(self, mask_idxs): if mask_idxs in self.cached_intersection_sizes: return self.cached_intersection_sizes[mask_idxs] - + s = self.intersection(mask_idxs).sum() self.cached_intersection_sizes[mask_idxs] = s @@ -168,15 +168,16 @@ def make_bbs(masks): for i in range(len(masks)): m = np.where(masks[i]) - bbs.append([[m[0].min(), m[0].max()],[m[1].min(), m[1].max()]]) + bbs.append([[m[0].min(), m[0].max()], [m[1].min(), m[1].max()]]) return bbs + def bb_dist(bbs): num_bbs = len(bbs) dist = np.zeros((num_bbs, num_bbs)) - for i,j in itertools.combinations(range(num_bbs), 2): + for i, j in itertools.combinations(range(num_bbs), 2): bbi = bbs[i] bbj = bbs[j] @@ -190,7 +191,7 @@ def bb_dist(bbs): else: disty = bbi[1][0] - bbj[1][1] - dist[i,j] = max(distx,disty) - dist[j,i] = dist[i,j] + dist[i, j] = max(distx, disty) + dist[j, i] = dist[i, j] return dist diff --git a/allensdk/internal/brain_observatory/mouse.py b/allensdk/internal/brain_observatory/mouse.py index b095660dfb..f9efb9eb6a 100644 --- a/allensdk/internal/brain_observatory/mouse.py +++ b/allensdk/internal/brain_observatory/mouse.py @@ -1,20 +1,21 @@ from typing import Optional, List -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .behavior_metadata.behavior_metadata import \ - BehaviorMetadata -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .subject_metadata.mouse_id import \ - MouseId +from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.behavior_metadata import ( + BehaviorMetadata, +) +from allensdk.brain_observatory.behavior.data_objects.metadata.subject_metadata.mouse_id import MouseId from allensdk.core.auth_config import LIMS_DB_CREDENTIAL_MAP from allensdk.internal.api import db_connection_creator -from allensdk.internal.brain_observatory.util.multi_session_utils import \ - get_session_metadata_multiprocessing, get_images_shown, \ - remove_invalid_sessions +from allensdk.internal.brain_observatory.util.multi_session_utils import ( + get_session_metadata_multiprocessing, + get_images_shown, + remove_invalid_sessions, +) class Mouse: """A mouse""" + def __init__(self, mouse_id: str): self._mouse_id = mouse_id @@ -22,10 +23,7 @@ def __init__(self, mouse_id: str): def mouse_id(self) -> str: return self._mouse_id - def get_behavior_sessions( - self, - exclude_invalid_sessions: bool = True - ) -> List[BehaviorMetadata]: + def get_behavior_sessions(self, exclude_invalid_sessions: bool = True) -> List[BehaviorMetadata]: """ Gets all behavior sessions for mouse @@ -38,9 +36,7 @@ def get_behavior_sessions( ------- List[BehaviorMetadata] """ - lims_db = db_connection_creator( - fallback_credentials=LIMS_DB_CREDENTIAL_MAP - ) + lims_db = db_connection_creator(fallback_credentials=LIMS_DB_CREDENTIAL_MAP) query = f""" SELECT bs.id @@ -50,35 +46,20 @@ def get_behavior_sessions( """ behavior_session_ids = lims_db.fetchall(query=query) behavior_sessions = get_session_metadata_multiprocessing( - behavior_session_ids=behavior_session_ids, - lims_engine=lims_db + behavior_session_ids=behavior_session_ids, lims_engine=lims_db ) if exclude_invalid_sessions: - behavior_sessions = remove_invalid_sessions( - behavior_sessions=behavior_sessions - ) + behavior_sessions = remove_invalid_sessions(behavior_sessions=behavior_sessions) return behavior_sessions @classmethod - def from_behavior_session_id( - cls, - behavior_session_id: int - ) -> "Mouse": + def from_behavior_session_id(cls, behavior_session_id: int) -> "Mouse": """Instantiates `Mouse` from `behavior_session_id`""" - lims_db = db_connection_creator( - fallback_credentials=LIMS_DB_CREDENTIAL_MAP - ) - mouse_id = MouseId.from_lims( - behavior_session_id=behavior_session_id, - lims_db=lims_db - ) + lims_db = db_connection_creator(fallback_credentials=LIMS_DB_CREDENTIAL_MAP) + mouse_id = MouseId.from_lims(behavior_session_id=behavior_session_id, lims_db=lims_db) return Mouse(mouse_id=mouse_id.value) - def get_images_shown( - self, - up_to_behavior_session_id: Optional[int] = None, - n_workers: Optional[int] = None - ): + def get_images_shown(self, up_to_behavior_session_id: Optional[int] = None, n_workers: Optional[int] = None): """Gets all images presented to mouse up to (not including) `up_to_behavior_session_id` if provided @@ -91,27 +72,22 @@ def get_images_shown( Number of processes to spawn for reading image names from stimulus files """ - lims_db = db_connection_creator( - fallback_credentials=LIMS_DB_CREDENTIAL_MAP - ) + lims_db = db_connection_creator(fallback_credentials=LIMS_DB_CREDENTIAL_MAP) behavior_sessions = self.get_behavior_sessions() if up_to_behavior_session_id is not None: this_date_of_acquisition = [ - x.date_of_acquisition for x in behavior_sessions - if x.behavior_session_id == up_to_behavior_session_id][0] + x.date_of_acquisition for x in behavior_sessions if x.behavior_session_id == up_to_behavior_session_id + ][0] - prior_behavior_session_ids = set([ - x.behavior_session_id for x in behavior_sessions - if x.date_of_acquisition < this_date_of_acquisition]) - behavior_sessions = [ - x for x in behavior_sessions - if x.behavior_session_id in prior_behavior_session_ids] + prior_behavior_session_ids = set( + [x.behavior_session_id for x in behavior_sessions if x.date_of_acquisition < this_date_of_acquisition] + ) + behavior_sessions = [x for x in behavior_sessions if x.behavior_session_id in prior_behavior_session_ids] images_shown = get_images_shown( - behavior_session_ids=( - [x.behavior_session_id for x in behavior_sessions]), + behavior_session_ids=([x.behavior_session_id for x in behavior_sessions]), lims_engine=lims_db, - n_workers=n_workers + n_workers=n_workers, ) return images_shown diff --git a/allensdk/internal/brain_observatory/ophys_session_decomposition.py b/allensdk/internal/brain_observatory/ophys_session_decomposition.py index 66db20caa6..67d7d7ffa9 100644 --- a/allensdk/internal/brain_observatory/ophys_session_decomposition.py +++ b/allensdk/internal/brain_observatory/ophys_session_decomposition.py @@ -2,9 +2,8 @@ import h5py -def open_view_on_binary(file_like, dtype=np.uint8, mode="r", offset=0, - shape=None, order="C", strides=None): - '''Open a view into a memory-mapped binary file. +def open_view_on_binary(file_like, dtype=np.uint8, mode="r", offset=0, shape=None, order="C", strides=None): + """Open a view into a memory-mapped binary file. Parameters ---------- @@ -28,17 +27,16 @@ def open_view_on_binary(file_like, dtype=np.uint8, mode="r", offset=0, ------- numpy.memmap Strided view into memory-mapped array. - ''' + """ mapped = np.memmap(file_like, dtype, mode, offset, order=order) - return np.lib.stride_tricks.as_strided(mapped, shape=shape, - strides=strides) + return np.lib.stride_tricks.as_strided(mapped, shape=shape, strides=strides) def read_strided(filename, dtype, offset, shape, strides): - '''Load a frame without memory-mapping.''' + """Load a frame without memory-mapping.""" frame_size = np.dtype(dtype).itemsize arr = np.empty(shape, dtype=dtype) - frame_size = arr.dtype.itemsize*np.prod(shape[1:]) + frame_size = arr.dtype.itemsize * np.prod(shape[1:]) step = strides[0] - frame_size with open(filename, "rb") as f: f.seek(offset) @@ -50,29 +48,35 @@ def read_strided(filename, dtype, offset, shape, strides): def load_frame(raw_filename, json_meta, use_memmap=False): - '''Load a frame of a multi-frame raw file.''' + """Load a frame of a multi-frame raw file.""" if use_memmap: - arr = open_view_on_binary(raw_filename, dtype=json_meta["dtype"], - offset=json_meta["byte_offset"], - shape=json_meta["shape"], - strides=json_meta["strides"]) + arr = open_view_on_binary( + raw_filename, + dtype=json_meta["dtype"], + offset=json_meta["byte_offset"], + shape=json_meta["shape"], + strides=json_meta["strides"], + ) else: - arr = read_strided(raw_filename, dtype=json_meta["dtype"], - offset=json_meta["byte_offset"], - shape=json_meta["shape"], - strides=json_meta["strides"]) + arr = read_strided( + raw_filename, + dtype=json_meta["dtype"], + offset=json_meta["byte_offset"], + shape=json_meta["shape"], + strides=json_meta["strides"], + ) return arr -def export_frame_to_hdf5(raw_filename, data_hdf5_filename, - auxiliary_hdf5_filename, frame_meta, - compression="gzip", compression_opts=9): - '''Export a frame from raw to hdf5. - +def export_frame_to_hdf5( + raw_filename, data_hdf5_filename, auxiliary_hdf5_filename, frame_meta, compression="gzip", compression_opts=9 +): + """Export a frame from raw to hdf5. + Data with the channel_description `data` is stored in the data_hdf5_filename, while any other data is stored in the - auxiliary_hdf5_filename - ''' + auxiliary_hdf5_filename + """ data_created = False aux_created = False for json_meta in frame_meta: @@ -92,6 +96,10 @@ def export_frame_to_hdf5(raw_filename, data_hdf5_filename, data = load_frame(raw_filename, json_meta) chunks = (1, json_meta["shape"][1], json_meta["shape"][2]) with h5py.File(filename, mode) as f: - f.create_dataset(json_meta["channel_description"], data=data, - chunks=chunks, compression=compression, - compression_opts=compression_opts) + f.create_dataset( + json_meta["channel_description"], + data=data, + chunks=chunks, + compression=compression, + compression_opts=compression_opts, + ) diff --git a/allensdk/internal/brain_observatory/roi_filter.py b/allensdk/internal/brain_observatory/roi_filter.py index 6c75cad43c..68dd81d530 100644 --- a/allensdk/internal/brain_observatory/roi_filter.py +++ b/allensdk/internal/brain_observatory/roi_filter.py @@ -14,7 +14,7 @@ class ROIClassifier(object): - '''Wrapper for machine learning classifier. + """Wrapper for machine learning classifier. Provides an underlying classifier model implementing `fit`, `score`, and `predict`. Tracks additional information for @@ -35,79 +35,75 @@ class ROIClassifier(object): `reporters`: Reporter set used for training. `other_appended_labels`: Labels appended outside model. `cross_validation_scores`: Cross validation if generated. - ''' + """ + def __init__(self, model_data=None): - '''Constructor.''' + """Constructor.""" if model_data is None: model_data = {} self.sklearn_version = sklearn_version model_sklearn = model_data.get("sklearn_version", None) if sklearn_version != model_sklearn: - logging.warning("Using sklearn %s, model trained using %s", - sklearn_version, model_sklearn) + logging.warning("Using sklearn %s, model trained using %s", sklearn_version, model_sklearn) self.model = model_data.get("model", None) - self.training_features = model_data.get("training_features", - pd.DataFrame()) - self.training_labels = model_data.get("training_labels", - pd.DataFrame()) + self.training_features = model_data.get("training_features", pd.DataFrame()) + self.training_labels = model_data.get("training_labels", pd.DataFrame()) self.trimmed_features = model_data.get("trimmed_features", []) self.structure_ids = model_data.get("structure_ids", []) self.drivers = model_data.get("drivers", []) self.reporters = model_data.get("reporters", []) - self.other_appended_labels = model_data.get("other_appended_labels", - []) + self.other_appended_labels = model_data.get("other_appended_labels", []) # this is a harsh score for multilabel because it requires ALL # labels predicted - self.cross_validation_scores = model_data.get( - "cross_validation_scores", None) + self.cross_validation_scores = model_data.get("cross_validation_scores", None) self.unexpected_features = [] @property def model_data(self): - '''The classifier properties as a dictionary.''' - data = {"model": self.model, - "training_features": self.training_features, - "training_labels": self.training_labels, - "trimmed_features": self.trimmed_features, - "structure_ids": self.structure_ids, - "drivers": self.drivers, - "reporters": self.reporters, - "other_appended_labels": self.other_appended_labels, - "sklearn_version": self.sklearn_version, - "cross_validation_scores": self.cross_validation_scores} + """The classifier properties as a dictionary.""" + data = { + "model": self.model, + "training_features": self.training_features, + "training_labels": self.training_labels, + "trimmed_features": self.trimmed_features, + "structure_ids": self.structure_ids, + "drivers": self.drivers, + "reporters": self.reporters, + "other_appended_labels": self.other_appended_labels, + "sklearn_version": self.sklearn_version, + "cross_validation_scores": self.cross_validation_scores, + } return data @property def label_names(self): - '''Return label names for the classifier.''' + """Return label names for the classifier.""" return self.training_labels.columns - def create_feature_array(self, object_data, depth, structure_id, drivers, - reporters): - '''Creates feature array from input data. + def create_feature_array(self, object_data, depth, structure_id, drivers, reporters): + """Creates feature array from input data. See Also -------- create_feature_array : Create a feature array given model and inputs - ''' + """ pass - def get_labels(self, object_data, depth, structure_id, drivers, - reporters): - '''Generate labels from input data. + def get_labels(self, object_data, depth, structure_id, drivers, reporters): + """Generate labels from input data. See Also -------- ROIClassifier.create_feature_array - ''' - features = create_feature_array(self.model_data, object_data, depth, - structure_id, drivers, reporters) + """ + features = create_feature_array(self.model_data, object_data, depth, structure_id, drivers, reporters) self.unexpected_features = get_unexpected_features( - self.model_data, object_data, structure_id, drivers, reporters) + self.model_data, object_data, structure_id, drivers, reporters + ) return self.predict(features) def fit(self, features, labels): - '''Fit model to data. + """Fit model to data. Parameters ---------- @@ -115,21 +111,21 @@ def fit(self, features, labels): Training feature set. labels : pandas.DataFrame Training labels. - ''' + """ self.training_features = features self.training_labels = labels self.model.fit(features, labels) def score(self, features, labels): - '''Calculate classifier score on data.''' + """Calculate classifier score on data.""" return self.model.score(features, labels) def predict(self, features): - '''Generate classification labels given features.''' + """Generate classification labels given features.""" return self.model.predict(features) def cross_validate(self, features, labels, n_folds=5, n_jobs=1): - '''Generate cross-validation scores for the classifier. + """Generate cross-validation scores for the classifier. Parameters ---------- @@ -146,25 +142,24 @@ def cross_validate(self, features, labels, n_folds=5, n_jobs=1): ------- numpy.ndarray `n_folds` cross-validation scores. - ''' - self.cross_validation_scores = cross_val_score( - self.model, features, labels, cv=n_folds, n_jobs=n_jobs) + """ + self.cross_validation_scores = cross_val_score(self.model, features, labels, cv=n_folds, n_jobs=n_jobs) return self.cross_validation_scores def save(self, filename): - '''Save the classifier to file by pickling.''' + """Save the classifier to file by pickling.""" with open(filename, "wb") as f: pickle.dump(self.model_data, f) @staticmethod def from_file(filename): - '''Load an ROIClassifier from file.''' + """Load an ROIClassifier from file.""" with open(filename, "rb") as f: return ROIClassifier(pickle.load(f)) def mean_gray_to_sigma(meanInt0, snpoffsetstdv): - '''Calculate intensity variation used in prior code. + """Calculate intensity variation used in prior code. Parameters ---------- @@ -177,15 +172,14 @@ def mean_gray_to_sigma(meanInt0, snpoffsetstdv): ------- pandas.Series meanInt0/snpoffsetstdv, preventing Inf (returns as 0). - ''' + """ mean_gray_to_sigma = meanInt0 / snpoffsetstdv.astype(float) mean_gray_to_sigma[snpoffsetstdv == 0.0] = 0 return mean_gray_to_sigma -def create_feature_array(model_data, object_data, depth, structure_id, - drivers, reporters): - '''Create feature array from input data. +def create_feature_array(model_data, object_data, depth, structure_id, drivers, reporters): + """Create feature array from input data. This creates the feature array with column ordering matching what the classifier was trained on. @@ -205,12 +199,11 @@ def create_feature_array(model_data, object_data, depth, structure_id, List of drivers for the mouse. reporters : list List of reporters for the mouse. - ''' + """ training_features = model_data["training_features"].columns if np.isnan(depth): depth = 0 - meanGrayToSigma = mean_gray_to_sigma( - object_data["meanInt0"], object_data["snpoffsetstdv"]) + meanGrayToSigma = mean_gray_to_sigma(object_data["meanInt0"], object_data["snpoffsetstdv"]) features = pd.DataFrame() for column in training_features: if column == "depth": @@ -228,14 +221,12 @@ def create_feature_array(model_data, object_data, depth, structure_id, features[column] = object_data[column] else: logging.error("Feature %s missing from input data", column) - raise KeyError( - "Feature {} missing from input data".format(column)) + raise KeyError("Feature {} missing from input data".format(column)) return features -def get_unexpected_features(model_data, object_data, structure_id, drivers, - reporters): - '''Get list of incoming features that weren't in traning data. +def get_unexpected_features(model_data, object_data, structure_id, drivers, reporters): + """Get list of incoming features that weren't in traning data. Parameters ---------- @@ -250,21 +241,19 @@ def get_unexpected_features(model_data, object_data, structure_id, drivers, List of drivers for the mouse. reporters : list List of reporters for the mouse. - ''' + """ training_features = model_data["training_features"].columns trimmed_features = model_data["trimmed_features"] - inputs = list(itertools.chain(object_data.columns, [structure_id], - drivers, reporters)) + inputs = list(itertools.chain(object_data.columns, [structure_id], drivers, reporters)) unexpected_features = [] for feature in inputs: - if (feature not in training_features) and \ - (feature not in trimmed_features): + if (feature not in training_features) and (feature not in trimmed_features): unexpected_features.append(feature) return unexpected_features def label_unions_and_duplicates(rois, overlap_threshold): - '''Detect unions and duplicates and label ROIs.''' + """Detect unions and duplicates and label ROIs.""" masks = create_roi_mask_array(rois) valid_masks = np.ones(masks.shape[0]).astype(bool) ms = mask_set.MaskSet(masks=masks) @@ -292,7 +281,7 @@ def label_unions_and_duplicates(rois, overlap_threshold): def apply_labels(rois, label_array, label_names): - '''Apply labels to rois. + """Apply labels to rois. Parameters ---------- @@ -307,10 +296,9 @@ def apply_labels(rois, label_array, label_names): ------- list List of ROIs with labels appended. - ''' + """ label_df = pd.DataFrame(data=label_array, columns=label_names) - label_lists = label_df.apply(_column_match).apply( - _compress_to_list, args=(label_df.columns,), axis=1) + label_lists = label_df.apply(_column_match).apply(_compress_to_list, args=(label_df.columns,), axis=1) for i, roi in enumerate(rois): roi.labels.extend(label_lists[i]) return rois @@ -321,5 +309,5 @@ def _column_match(column): def _compress_to_list(row, names): - '''Get names that have value 1 in row.''' + """Get names that have value 1 in row.""" return list(names[row.values]) diff --git a/allensdk/internal/brain_observatory/roi_filter_utils.py b/allensdk/internal/brain_observatory/roi_filter_utils.py index 0cb0fb5dec..1392f6d42b 100644 --- a/allensdk/internal/brain_observatory/roi_filter_utils.py +++ b/allensdk/internal/brain_observatory/roi_filter_utils.py @@ -7,9 +7,9 @@ import pandas as pd import numpy as np -CRITERIA_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), - "resources", - "roi_filter_training_criteria.json") +CRITERIA_FILE = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "resources", "roi_filter_training_criteria.json" +) _CRITERIA = None @@ -22,7 +22,7 @@ def CRITERIA(): class TrainingLabelClassifier(object): - '''Very basic threshold_based classifier. + """Very basic threshold_based classifier. Has a decision function that is just the number of distinct criteria met by the classifier. Criteria are defined as a list @@ -32,16 +32,17 @@ class TrainingLabelClassifier(object): ---------- criteria : list List of evaluation strings. - ''' + """ + def __init__(self, criteria): - '''Constructor.''' + """Constructor.""" if criteria is None: self.criteria = [] else: self.criteria = criteria def decision_function(self, X): - '''Get the distance from the decision boundary. + """Get the distance from the decision boundary. Parameters ---------- @@ -52,7 +53,7 @@ def decision_function(self, X): ------- T : array-like Distance for each sample from the decision boundary. - ''' + """ T = np.zeros((X.shape[0],), dtype=int) for crit in self.criteria: T[X.eval(crit).as_matrix()] += 1 @@ -60,7 +61,7 @@ def decision_function(self, X): class TrainingMultiLabelClassifier(object): - '''Multilabel classifier using groups of TrainingLabelClassifiers. + """Multilabel classifier using groups of TrainingLabelClassifiers. This was used to generate labeling for training the original SVM for classification. @@ -69,9 +70,10 @@ class TrainingMultiLabelClassifier(object): ---------- criteria : dictionary Label names and criteria for each label. - ''' + """ + def __init__(self, criteria=None): - '''Constructor.''' + """Constructor.""" if criteria is None: criteria = CRITERIA() i = 0 @@ -85,7 +87,7 @@ def __init__(self, criteria=None): i += 1 def _labels_as_columns(self, label_codes): - '''Convert label series to boolean columns for each label. + """Convert label series to boolean columns for each label. Parameters ---------- @@ -97,7 +99,7 @@ def _labels_as_columns(self, label_codes): pandas.DataFrame Dataframe where each column is a label, and values are True for labeled or False otherwise. - ''' + """ output = pd.DataFrame() for name in self.labels: number = self._codes[name] @@ -113,7 +115,7 @@ def _map_code_to_list(self, label_code): return output def get_eXcluded(self, X): - '''Get the calculated value of the eXcluded column. + """Get the calculated value of the eXcluded column. This is useful for comparison with the original classifier implementation. @@ -127,7 +129,7 @@ def get_eXcluded(self, X): ------- numpy.ndarray Calculated eXcluded score from the classifier. - ''' + """ eXcluded = np.zeros((X.shape[0],), dtype=X["eXcluded"].dtype) for classifier in self._classifiers.values(): eXcluded += classifier.decision_function(X) @@ -138,7 +140,7 @@ def get_eXcluded(self, X): return eXcluded.as_matrix() def label_data(self, X, as_columns=True): - '''Generate labels for each row in X. + """Generate labels for each row in X. Parameters ---------- @@ -150,7 +152,7 @@ def label_data(self, X, as_columns=True): numpy.ndarray Array of label codes representing the combination of labels found for each row. - ''' + """ labels = np.zeros((X.shape[0],), dtype=int) for label, classifier in self._classifiers.items(): labels[classifier.decision_function(X) > 0] += label @@ -161,7 +163,7 @@ def label_data(self, X, as_columns=True): def calculate_max_border(motion_df, max_shift): - '''Calculate motion boundary from frame offsets. + """Calculate motion boundary from frame offsets. When the motion correction algorithm fails to find sufficient matches, it generates very large frame offsets. The use of @@ -180,16 +182,14 @@ def calculate_max_border(motion_df, max_shift): ------- list [right_shift, left_shift, down_shift, up_shift] - ''' + """ # strip outliers - x_no_outliers = motion_df["x"][(motion_df["x"] >= -max_shift) - & (motion_df["x"] <= max_shift)] - y_no_outliers = motion_df["y"][(motion_df["y"] >= -max_shift) - & (motion_df["y"] <= max_shift)] + x_no_outliers = motion_df["x"][(motion_df["x"] >= -max_shift) & (motion_df["x"] <= max_shift)] + y_no_outliers = motion_df["y"][(motion_df["y"] >= -max_shift) & (motion_df["y"] <= max_shift)] - right_shift = np.max(-1*x_no_outliers.min(), 0) + right_shift = np.max(-1 * x_no_outliers.min(), 0) left_shift = np.max(x_no_outliers.max(), 0) - down_shift = np.max(-1*y_no_outliers.min(), 0) + down_shift = np.max(-1 * y_no_outliers.min(), 0) up_shift = np.max(y_no_outliers.max(), 0) border = [right_shift, left_shift, down_shift, up_shift] @@ -201,7 +201,7 @@ def calculate_max_border(motion_df, max_shift): def order_rois_by_object_list(object_data, rois): - '''Reorder rois by matching bounding boxes to object list. + """Reorder rois by matching bounding boxes to object list. Parameters ---------- @@ -214,30 +214,22 @@ def order_rois_by_object_list(object_data, rois): ------- list The list of rois reordered to index the same as object_data. - ''' - object_points = object_data[["minx", - "miny", - "maxx", - "maxy", - "area"]].copy() + """ + object_points = object_data[["minx", "miny", "maxx", "maxy", "area"]].copy() object_points["maxx"] += 1 object_points["maxy"] += 1 roi_points = [] for roi in rois: - roi_points.append([roi.x, roi.y, roi.x+roi.width, roi.y+roi.height, - roi.mask.sum()]) - reorder_index = get_indices_by_distance(object_points, - np.array(roi_points)) + roi_points.append([roi.x, roi.y, roi.x + roi.width, roi.y + roi.height, roi.mask.sum()]) + reorder_index = get_indices_by_distance(object_points, np.array(roi_points)) multi_mapped = set() if len(set(reorder_index)) != reorder_index.shape[0]: unique, counts = np.unique(reorder_index, return_counts=True) multi_mapped = set(unique[counts > 1]) - not_mapped = set(np.setdiff1d(np.arange(reorder_index.shape[0]), - reorder_index)) + not_mapped = set(np.setdiff1d(np.arange(reorder_index.shape[0]), reorder_index)) logging.warning("ROIs don't uniquely map to object_list") - for idx in (multi_mapped | not_mapped): - logging.warning( - "%s has ambiguous mapping to object list" % rois[idx].label) + for idx in multi_mapped | not_mapped: + logging.warning("%s has ambiguous mapping to object list" % rois[idx].label) out_rois = [] for i in reorder_index: roi = rois[i] @@ -248,7 +240,7 @@ def order_rois_by_object_list(object_data, rois): def get_rois(segmentation_stack, border=None): - '''Extract a list of rois from the segmentation data array. + """Extract a list of rois from the segmentation data array. Parameters ---------- @@ -262,7 +254,7 @@ def get_rois(segmentation_stack, border=None): ------- list List of RoiMask objects. - ''' + """ rois = [] if border is None: border = [0, 0, 0, 0] @@ -270,14 +262,12 @@ def get_rois(segmentation_stack, border=None): width = segmentation_stack.shape[2] for i in range(segmentation_stack.shape[0]): page = segmentation_stack[i, :, :] - label_mask, num_labels = measurements.label( - page, structure=[[1, 1, 1], [1, 1, 1], [1, 1, 1]]) + label_mask, num_labels = measurements.label(page, structure=[[1, 1, 1], [1, 1, 1], [1, 1, 1]]) for label in range(1, num_labels + 1): img_mask = label_mask == label - mask = create_roi_mask(width, height, border, - roi_mask=img_mask, - label="ROI {}:{}".format(i, label), - mask_group=i) + mask = create_roi_mask( + width, height, border, roi_mask=img_mask, label="ROI {}:{}".format(i, label), mask_group=i + ) mask.labels = [] if mask.overlaps_motion_border: mask.labels.append("motion_border") @@ -286,14 +276,13 @@ def get_rois(segmentation_stack, border=None): def get_indices_by_distance(object_list_points, mask_points): - '''Find indices of nearest neighbor matches. + """Find indices of nearest neighbor matches. Require a distance of 0 (perfect match) and a unique match between masks and object_list entries. - ''' + """ if np.array(mask_points).ndim != 2: - raise ValueError("number of dimensions is incorrect. Expected 2 " - f"got {np.array(mask_points).ndim}") + raise ValueError(f"number of dimensions is incorrect. Expected 2 got {np.array(mask_points).ndim}") tree = cKDTree(mask_points) distance, indices = tree.query(object_list_points) if distance.max() > 0: diff --git a/allensdk/internal/brain_observatory/run_itracker.py b/allensdk/internal/brain_observatory/run_itracker.py index 26af30b575..200f415ecd 100644 --- a/allensdk/internal/brain_observatory/run_itracker.py +++ b/allensdk/internal/brain_observatory/run_itracker.py @@ -4,9 +4,7 @@ import logging from allensdk.config.manifest import Manifest from allensdk.internal.brain_observatory.itracker import iTracker -from allensdk.internal.brain_observatory.frame_stream import \ - FfmpegInputStream, \ - FfmpegOutputStream +from allensdk.internal.brain_observatory.frame_stream import FfmpegInputStream, FfmpegOutputStream import h5py import ast import sys @@ -14,9 +12,9 @@ DEFAULT_THRESHOLD_FACTOR = 1.6 -if sys.platform == 'linux2': +if sys.platform == "linux2": FFMPEG_BIN = "/shared/utils.x86_64/ffmpeg/bin/ffmpeg" -elif sys.platform == 'darwin': +elif sys.platform == "darwin": FFMPEG_BIN = "/usr/local/bin/ffmpeg" @@ -24,8 +22,7 @@ def compute_bounding_box(points): if not points: return None points = np.array(points) - return [points[:, 0].min(), points[:, 0].max(), - points[:, 1].min(), points[:, 1].max()] + return [points[:, 0].min(), points[:, 0].max(), points[:, 1].min(), points[:, 1].max()] def get_polygon(experiment_id, group_name): @@ -41,8 +38,7 @@ def get_polygon(experiment_id, group_name): """ % (group_name, experiment_id) try: - path = np.array( - [int(v) for v in lu.query(query)[0]['path'].split(',')]) + path = np.array([int(v) for v in lu.query(query)[0]["path"].split(",")]) except KeyError: return [] except IndexError: @@ -54,27 +50,27 @@ def get_polygon(experiment_id, group_name): def get_experiment_info(experiment_id): - logging.info("Downloading paths/metadata for experiment ID: %d", - experiment_id) - query = "select storage_directory, id from ophys_sessions where id = " + \ - str(experiment_id) + logging.info("Downloading paths/metadata for experiment ID: %d", experiment_id) + query = "select storage_directory, id from ophys_sessions where id = " + str(experiment_id) - storage_directory = lu.query(query)[0]['storage_directory'] + storage_directory = lu.query(query)[0]["storage_directory"] logging.info("\tStorage directory: %s", storage_directory) - movie_file = glob.glob(storage_directory + '*video-1.avi')[0] - metadata_file = glob.glob(storage_directory + '*video-1.h5')[0] + movie_file = glob.glob(storage_directory + "*video-1.avi")[0] + metadata_file = glob.glob(storage_directory + "*video-1.h5")[0] - cr_points = get_polygon(experiment_id, 'Corneal Reflection Bounding Box') - pupil_points = get_polygon(experiment_id, 'Pupil Bounding Box') + cr_points = get_polygon(experiment_id, "Corneal Reflection Bounding Box") + pupil_points = get_polygon(experiment_id, "Pupil Bounding Box") logging.info("\tmovie file: %s", movie_file) logging.info("\tmetadata file: %s", metadata_file) - return dict(movie_file=movie_file, - metadata_file=metadata_file, - corneal_reflection_points=cr_points, - pupil_points=pupil_points) + return dict( + movie_file=movie_file, + metadata_file=metadata_file, + corneal_reflection_points=cr_points, + pupil_points=pupil_points, + ) def get_movie_shape_from_metadata(metadata_file): @@ -87,32 +83,33 @@ def get_movie_shape_from_metadata(metadata_file): # 'width'], 3) # in the metadata file from lims, the 'width' and 'height' variables are # swapped, hopefully this is the same for every single experiment. - movie_shape = ( - metadata['frames'], metadata['width'], metadata['height'], 3) + movie_shape = (metadata["frames"], metadata["width"], metadata["height"], 3) logging.info("movie_shape from metadata_file = %s", str(movie_shape)) return movie_shape -def run_itracker(movie_file, output_directory, - output_frames=False, - output_annotation_frames=False, - output_annotated_movie=True, - output_annotated_movie_block_size=1, - estimate_bbox=False, - num_frames=None, - output_QC=True, - image_type='png', - cache_input_frames=False, - input_block_size=1, - metadata_file=None, - movie_shape=None, - **kwargs): +def run_itracker( + movie_file, + output_directory, + output_frames=False, + output_annotation_frames=False, + output_annotated_movie=True, + output_annotated_movie_block_size=1, + estimate_bbox=False, + num_frames=None, + output_QC=True, + image_type="png", + cache_input_frames=False, + input_block_size=1, + metadata_file=None, + movie_shape=None, + **kwargs, +): if output_directory is not None: Manifest.safe_mkdir(output_directory) - assert metadata_file is not None and movie_shape is not None, \ - "Must provide either metadata_file or movie_shape" + assert metadata_file is not None and movie_shape is not None, "Must provide either metadata_file or movie_shape" if metadata_file: movie_shape = get_movie_shape_from_metadata(metadata_file) @@ -122,22 +119,28 @@ def run_itracker(movie_file, output_directory, if num_frames is None: num_frames = movie_shape[0] - input_stream = FfmpegInputStream(movie_file, frame_shape, - ffmpeg_bin=FFMPEG_BIN, - num_frames=num_frames, - cache_frames=cache_input_frames, - block_size=input_block_size) - - movie_output_stream = FfmpegOutputStream( + input_stream = FfmpegInputStream( + movie_file, frame_shape, - block_size=output_annotated_movie_block_size, - ffmpeg_bin=FFMPEG_BIN) if \ - output_annotated_movie else None + ffmpeg_bin=FFMPEG_BIN, + num_frames=num_frames, + cache_frames=cache_input_frames, + block_size=input_block_size, + ) - itracker = iTracker(output_directory, input_stream=input_stream, - im_shape=(movie_shape[1], movie_shape[2]), - num_frames=num_frames, - **kwargs) + movie_output_stream = ( + FfmpegOutputStream(frame_shape, block_size=output_annotated_movie_block_size, ffmpeg_bin=FFMPEG_BIN) + if output_annotated_movie + else None + ) + + itracker = iTracker( + output_directory, + input_stream=input_stream, + im_shape=(movie_shape[1], movie_shape[2]), + num_frames=num_frames, + **kwargs, + ) # open this early to avoid duplicating massive memory movie_output_stream.open(itracker.annotated_movie_file) @@ -149,9 +152,11 @@ def run_itracker(movie_file, output_directory, if estimate_bbox: bbox_pupil, bbox_cr = itracker.estimate_bbox_from_mean_frame() - itracker.process_movie(movie_output_stream=movie_output_stream, - output_frames=output_frames, - output_annotation_frames=output_annotation_frames) + itracker.process_movie( + movie_output_stream=movie_output_stream, + output_frames=output_frames, + output_annotation_frames=output_annotation_frames, + ) if output_QC: itracker.output_QC(image_type=image_type) @@ -161,14 +166,14 @@ def run_itracker(movie_file, output_directory, def main(): parser = argparse.ArgumentParser() - parser.add_argument('--experiment_id', default=None, type=int) - parser.add_argument('--movie_file', default=None) - parser.add_argument('--metadata_file', default=None) - parser.add_argument('--output_directory', default='.') - parser.add_argument('--estimate_bbox', action='store_true') - parser.add_argument('--num_frames', default=None, type=int) - parser.add_argument('--threshold_factor', default=DEFAULT_THRESHOLD_FACTOR) - parser.add_argument('--log_level', default=logging.DEBUG) + parser.add_argument("--experiment_id", default=None, type=int) + parser.add_argument("--movie_file", default=None) + parser.add_argument("--metadata_file", default=None) + parser.add_argument("--output_directory", default=".") + parser.add_argument("--estimate_bbox", action="store_true") + parser.add_argument("--num_frames", default=None, type=int) + parser.add_argument("--threshold_factor", default=DEFAULT_THRESHOLD_FACTOR) + parser.add_argument("--log_level", default=logging.DEBUG) args = parser.parse_args() logging.getLogger().setLevel(args.log_level) @@ -177,23 +182,22 @@ def main(): threshold_factor=args.threshold_factor, output_directory=args.output_directory, num_frames=args.num_frames, - estimate_bbox=args.estimate_bbox + estimate_bbox=args.estimate_bbox, ) if args.experiment_id: info = get_experiment_info(args.experiment_id) - data['movie_file'] = info['movie_file'] - data['metadata_file'] = info['metadata_file'] + data["movie_file"] = info["movie_file"] + data["metadata_file"] = info["metadata_file"] - if info.get('pupil_points', None): - data['bbox_pupil'] = compute_bounding_box(info['pupil_points']) - if info.get('corneal_reflection_points', None): - data['bbox_cr'] = compute_bounding_box( - info['corneal_reflection_points']) + if info.get("pupil_points", None): + data["bbox_pupil"] = compute_bounding_box(info["pupil_points"]) + if info.get("corneal_reflection_points", None): + data["bbox_cr"] = compute_bounding_box(info["corneal_reflection_points"]) else: - data['movie_file'] = args.movie_file - data['metadata_file'] = args.metdata_file + data["movie_file"] = args.movie_file + data["metadata_file"] = args.metdata_file run_itracker(**data) diff --git a/allensdk/internal/brain_observatory/time_sync.py b/allensdk/internal/brain_observatory/time_sync.py index 5bd79c9b36..e049f4b07e 100644 --- a/allensdk/internal/brain_observatory/time_sync.py +++ b/allensdk/internal/brain_observatory/time_sync.py @@ -6,17 +6,18 @@ from allensdk.brain_observatory.sync_dataset import Dataset import pandas as pd import logging + try: import cv2 except ImportError: cv2 = None TRANSITION_FRAME_INTERVAL = 60 -REG_PHOTODIODE_INTERVAL = 1.0 # seconds -REG_PHOTODIODE_STD = 0.05 # seconds -PHOTODIODE_ANOMALY_THRESHOLD = 0.5 # seconds -LONG_STIM_THRESHOLD = 0.2 # seconds -MAX_MONITOR_DELAY = 0.07 # seconds +REG_PHOTODIODE_INTERVAL = 1.0 # seconds +REG_PHOTODIODE_STD = 0.05 # seconds +PHOTODIODE_ANOMALY_THRESHOLD = 0.5 # seconds +LONG_STIM_THRESHOLD = 0.2 # seconds +MAX_MONITOR_DELAY = 0.07 # seconds def get_keys(sync_dset: Dataset) -> dict: @@ -36,16 +37,14 @@ def get_keys(sync_dset: Dataset) -> dict: # and value is the possible data for each category existing in sync dataset # line labels key_dict = { - "photodiode": ["stim_photodiode", "photodiode"], - "2p": ["2p_vsync"], - "stimulus": ["stim_vsync", "vsync_stim"], - "eye_camera": ["cam2_exposure", "eye_tracking", - "eye_frame_received"], - "behavior_camera": ["cam1_exposure", "behavior_monitoring", - "beh_frame_received"], - "acquiring": ["2p_acquiring", "acq_trigger"], - "lick_sensor": ["lick_1", "lick_sensor"] - } + "photodiode": ["stim_photodiode", "photodiode"], + "2p": ["2p_vsync"], + "stimulus": ["stim_vsync", "vsync_stim"], + "eye_camera": ["cam2_exposure", "eye_tracking", "eye_frame_received"], + "behavior_camera": ["cam1_exposure", "behavior_monitoring", "beh_frame_received"], + "acquiring": ["2p_acquiring", "acq_trigger"], + "lick_sensor": ["lick_1", "lick_sensor"], + } label_set = set(sync_dset.line_labels) remove_keys = [] for key, value in key_dict.items(): @@ -63,37 +62,40 @@ def get_keys(sync_dset: Dataset) -> dict: # the contents of the `remove_keys` list is printed to the console # as a user warning if len(remove_keys) > 0: - logging.warning("Could not find valid lines for the following data " - "sources") + logging.warning("Could not find valid lines for the following data sources") for key in remove_keys: logging.warning(f"{key} (valid line label(s) = {key_dict[key]}") key_dict.pop(key) return key_dict -def calculate_monitor_delay(sync_dset, stim_times, photodiode_key, - transition_frame_interval=TRANSITION_FRAME_INTERVAL, # noqa: E501 - max_monitor_delay=MAX_MONITOR_DELAY): +def calculate_monitor_delay( + sync_dset, + stim_times, + photodiode_key, + transition_frame_interval=TRANSITION_FRAME_INTERVAL, # noqa: E501 + max_monitor_delay=MAX_MONITOR_DELAY, +): """Calculate monitor delay.""" transitions = stim_times[::transition_frame_interval] photodiode_events = get_real_photodiode_events(sync_dset, photodiode_key) - transition_events = photodiode_events[0:len(transitions)] + transition_events = photodiode_events[0 : len(transitions)] delays = transition_events - transitions delay = np.mean(delays) - logging.info(f"Calculated monitor delay: {delay}. \n " - f"Max monitor delay: {np.max(delays)}. \n " - f"Min monitor delay: {np.min(delays)}.\n " - f"Std monitor delay: {np.std(delays)}.") + logging.info( + f"Calculated monitor delay: {delay}. \n " + f"Max monitor delay: {np.max(delays)}. \n " + f"Min monitor delay: {np.min(delays)}.\n " + f"Std monitor delay: {np.std(delays)}." + ) if delay < 0 or delay > max_monitor_delay: - raise ValueError(f"Delay ({delay}s) falls outside expected value " - f"range (0-{MAX_MONITOR_DELAY}s).") + raise ValueError(f"Delay ({delay}s) falls outside expected value range (0-{MAX_MONITOR_DELAY}s).") return delay -def _find_last_n(arr: np.ndarray, n: int, - cond: Callable[[Any], bool]) -> Optional[int]: +def _find_last_n(arr: np.ndarray, n: int, cond: Callable[[Any], bool]) -> Optional[int]: """ Find the final index where the prior `n` values in an array meet the condition `cond` (inclusive). @@ -111,8 +113,7 @@ def _find_last_n(arr: np.ndarray, n: int, return reversed_ix -def _find_n(arr: np.ndarray, n: int, - cond: Callable[[Any], bool]) -> Optional[int]: +def _find_n(arr: np.ndarray, n: int, cond: Callable[[Any], bool]) -> Optional[int]: """ Find the index where the next `n` values in an array meet the condition `cond` (inclusive). @@ -131,7 +132,7 @@ def _find_n(arr: np.ndarray, n: int, while queue.count(True) < n: try: i += 1 - queue.append(cond(arr[i+n-1])) + queue.append(cond(arr[i + n - 1])) except IndexError: return None return i @@ -157,24 +158,18 @@ def get_photodiode_events(sync_dset, photodiode_key): min_interval = REG_PHOTODIODE_INTERVAL - REG_PHOTODIODE_STD max_interval = REG_PHOTODIODE_INTERVAL + REG_PHOTODIODE_STD if not len(all_events): - raise ValueError("No photodiode events found. Please check " - "the input data for errors. ") - first_valid_index = _find_n( - all_events_diff_next, 2, - lambda x: (x >= min_interval) & (x <= max_interval)) - last_valid_index = _find_last_n( - all_events_diff_prev, 2, - lambda x: (x >= min_interval) & (x <= max_interval)) + raise ValueError("No photodiode events found. Please check the input data for errors. ") + first_valid_index = _find_n(all_events_diff_next, 2, lambda x: (x >= min_interval) & (x <= max_interval)) + last_valid_index = _find_last_n(all_events_diff_prev, 2, lambda x: (x >= min_interval) & (x <= max_interval)) if first_valid_index is None: raise ValueError("Can't find valid start event") if last_valid_index is None: raise ValueError("Can't find valid end event") - pd_events = all_events[first_valid_index:last_valid_index+1] + pd_events = all_events[first_valid_index : last_valid_index + 1] return pd_events -def get_real_photodiode_events(sync_dset, photodiode_key, - anomaly_threshold=PHOTODIODE_ANOMALY_THRESHOLD): +def get_real_photodiode_events(sync_dset, photodiode_key, anomaly_threshold=PHOTODIODE_ANOMALY_THRESHOLD): """Gets the photodiode events with the anomalies removed.""" events = get_photodiode_events(sync_dset, photodiode_key) anomalies = np.where(np.diff(events) < anomaly_threshold) @@ -182,9 +177,8 @@ def get_real_photodiode_events(sync_dset, photodiode_key, def get_alignment_array(ref, other, int_method=np.floor): - """Generate an alignment array """ - return int_method(np.interp(other, ref, np.arange(len(ref)), left=np.nan, - right=np.nan)) + """Generate an alignment array""" + return int_method(np.interp(other, ref, np.arange(len(ref)), left=np.nan, right=np.nan)) def get_video_length(filename): @@ -193,8 +187,7 @@ def get_video_length(filename): capture = cv2.VideoCapture(filename) return int(capture.get(cv2.CAP_PROP_FRAME_COUNT)) except AttributeError: - logging.warning("Could not get length for %s, opencv out of date", - filename) + logging.warning("Could not get length for %s, opencv out of date", filename) else: logging.warning("Could not get length for %s", filename) @@ -235,8 +228,7 @@ def corrected_video_timestamps(video_name, timestamps, data_length): if data_length is not None: delta = len(timestamps) - data_length if delta != 0: - logging.info("%s data of length %s has timestamps of length " - "%s", video_name, data_length, len(timestamps)) + logging.info("%s data of length %s has timestamps of length %s", video_name, data_length, len(timestamps)) else: logging.info("No data length provided for %s", video_name) @@ -244,9 +236,16 @@ def corrected_video_timestamps(video_name, timestamps, data_length): class OphysTimeAligner(object): - def __init__(self, sync_file, scanner=None, dff_file=None, - stimulus_pkl=None, eye_video=None, behavior_video=None, - long_stim_threshold=LONG_STIM_THRESHOLD): + def __init__( + self, + sync_file, + scanner=None, + dff_file=None, + stimulus_pkl=None, + eye_video=None, + behavior_video=None, + long_stim_threshold=LONG_STIM_THRESHOLD, + ): self.scanner = scanner if scanner is not None else "SCIVIVO" self._dataset = Dataset(sync_file) self._keys = get_keys(self._dataset) @@ -288,10 +287,8 @@ def ophys_timestamps(self): elif self.scanner == "NIKONA1RMP": # Nikon has a signal that indicates when it started writing to disk acquiring_key = self._keys["acquiring"] - acquisition_start = self._dataset.get_rising_edges( - acquiring_key, units="seconds")[0] - ophys_times = self._dataset.get_falling_edges( - ophys_key, units="seconds") + acquisition_start = self._dataset.get_rising_edges(acquiring_key, units="seconds")[0] + ophys_times = self._dataset.get_falling_edges(ophys_key, units="seconds") times = ophys_times[ophys_times >= acquisition_start] else: raise ValueError("Invalid scanner: {}".format(self.scanner)) @@ -306,12 +303,14 @@ def corrected_ophys_timestamps(self): if self.ophys_data_length is not None: if len(times) < self.ophys_data_length: raise ValueError( - "Got too few timestamps ({}) for ophys data length " - "({})".format(len(times), self.ophys_data_length)) + "Got too few timestamps ({}) for ophys data length ({})".format(len(times), self.ophys_data_length) + ) elif len(times) > self.ophys_data_length: - logging.info("Ophys data of length %s has timestamps of " - "length %s, truncating timestamps", - self.ophys_data_length, len(times)) + logging.info( + "Ophys data of length %s has timestamps of length %s, truncating timestamps", + self.ophys_data_length, + len(times), + ) delta = len(times) - self.ophys_data_length times = times[:-delta] else: @@ -329,23 +328,21 @@ def _get_clipped_stim_timestamps(self): timestamps = self.stim_timestamps delta = 0 - if self.stim_data_length is not None and \ - self.stim_data_length < len(timestamps): + if self.stim_data_length is not None and self.stim_data_length < len(timestamps): stim_key = self._keys["stimulus"] rising = self.dataset.get_rising_edges(stim_key, units="seconds") # Some versions of camstim caused a spike when the DAQ is first # initialized. Remove it. if rising[1] - rising[0] > self.long_stim_threshold: - logging.info("Initial DAQ spike detected from stimulus, " - "removing it") + logging.info("Initial DAQ spike detected from stimulus, removing it") timestamps = timestamps[1:] delta = len(timestamps) - self.stim_data_length if delta != 0: - logging.info("Stim data of length %s has timestamps of " - "length %s", - self.stim_data_length, len(timestamps)) + logging.info( + "Stim data of length %s has timestamps of length %s", self.stim_data_length, len(timestamps) + ) elif self.stim_data_length is None: logging.info("No data length provided for stim stream") @@ -370,18 +367,14 @@ def clipped_stim_timestamps(self): len(timestamps) - len(pkl_file['items']['behavior']['intervalsms'] """ if self._clipped_stim_ts_delta is None: - (self._clipped_stim_timestamp_values, - self._clipped_stim_ts_delta) = self._get_clipped_stim_timestamps() + (self._clipped_stim_timestamp_values, self._clipped_stim_ts_delta) = self._get_clipped_stim_timestamps() - return (self._clipped_stim_timestamp_values, - self._clipped_stim_ts_delta) + return (self._clipped_stim_timestamp_values, self._clipped_stim_ts_delta) def _get_monitor_delay(self): timestamps, delta = self.clipped_stim_timestamps photodiode_key = self._keys["photodiode"] - delay = calculate_monitor_delay(self.dataset, - timestamps, - photodiode_key) + delay = calculate_monitor_delay(self.dataset, timestamps, photodiode_key) return delay @property @@ -426,9 +419,7 @@ def behavior_video_timestamps(self): @property def corrected_behavior_video_timestamps(self): - return corrected_video_timestamps("Behavior video", - self.behavior_video_timestamps, - self.behavior_data_length) + return corrected_video_timestamps("Behavior video", self.behavior_video_timestamps, self.behavior_data_length) @property def eye_video_timestamps(self): @@ -438,6 +429,4 @@ def eye_video_timestamps(self): @property def corrected_eye_video_timestamps(self): - return corrected_video_timestamps("Eye video", - self.eye_video_timestamps, - self.eye_data_length) + return corrected_video_timestamps("Eye video", self.eye_video_timestamps, self.eye_data_length) diff --git a/allensdk/internal/brain_observatory/util/multi_session_utils.py b/allensdk/internal/brain_observatory/util/multi_session_utils.py index 9b9c86ad3d..7c5e86e0a5 100644 --- a/allensdk/internal/brain_observatory/util/multi_session_utils.py +++ b/allensdk/internal/brain_observatory/util/multi_session_utils.py @@ -1,4 +1,5 @@ """Utilities for accessing data across multiple sessions""" + import os from multiprocessing import Pool from typing import List, Optional, Set, Callable @@ -6,21 +7,19 @@ from tqdm import tqdm from allensdk.brain_observatory.behavior.data_files import BehaviorStimulusFile -from allensdk.brain_observatory.behavior.data_files.stimulus_file import \ - MalformedStimulusFileError +from allensdk.brain_observatory.behavior.data_files.stimulus_file import MalformedStimulusFileError from allensdk.brain_observatory.behavior.data_objects import BehaviorSessionId -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .behavior_metadata.behavior_metadata import \ - BehaviorMetadata -from allensdk.brain_observatory.behavior.stimulus_processing import \ - get_image_names +from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.behavior_metadata import ( + BehaviorMetadata, +) +from allensdk.brain_observatory.behavior.stimulus_processing import get_image_names from allensdk.internal.api import PostgresQueryMixin def get_session_metadata_multiprocessing( - behavior_session_ids: List[int], - lims_engine: PostgresQueryMixin, - n_workers: Optional[int] = None, + behavior_session_ids: List[int], + lims_engine: PostgresQueryMixin, + n_workers: Optional[int] = None, ) -> List[BehaviorMetadata]: """Gets session metadata for `behavior_session_ids`. Uses multiprocessing to speed up reading @@ -43,8 +42,8 @@ def get_session_metadata_multiprocessing( target=_get_session_metadata, behavior_session_ids=behavior_session_ids, lims_engine=lims_engine, - progress_bar_title='Reading session metadata from pkl file', - n_workers=n_workers + progress_bar_title="Reading session metadata from pkl file", + n_workers=n_workers, ) session_metadata = [x for x in session_metadata if x is not None] @@ -52,9 +51,7 @@ def get_session_metadata_multiprocessing( def get_images_shown( - behavior_session_ids: List[int], - lims_engine: PostgresQueryMixin, - n_workers: Optional[int] = None + behavior_session_ids: List[int], lims_engine: PostgresQueryMixin, n_workers: Optional[int] = None ) -> Set[str]: """ Gets images shown to mouse during `behavior_session_ids` @@ -75,13 +72,14 @@ def get_images_shown( target=_get_image_names, behavior_session_ids=behavior_session_ids, lims_engine=lims_engine, - progress_bar_title='Reading image_names from pkl file', - n_workers=n_workers + progress_bar_title="Reading image_names from pkl file", + n_workers=n_workers, ) else: # single process - image_names = [_get_image_names([behavior_session_id, lims_engine]) - for behavior_session_id in behavior_session_ids] + image_names = [ + _get_image_names([behavior_session_id, lims_engine]) for behavior_session_id in behavior_session_ids + ] res = set() for image_name_set in image_names: for image_name in image_name_set: @@ -90,24 +88,23 @@ def get_images_shown( def multiprocessing_helper( - target: Callable, - progress_bar_title: str, - behavior_session_ids: List[int], - lims_engine: PostgresQueryMixin, - n_workers: Optional[int] = None + target: Callable, + progress_bar_title: str, + behavior_session_ids: List[int], + lims_engine: PostgresQueryMixin, + n_workers: Optional[int] = None, ): if n_workers is None: n_workers = os.cpu_count() with Pool(n_workers) as p: - res = list(tqdm( - p.imap(target, - zip( - behavior_session_ids, - [lims_engine] * len(behavior_session_ids)) - ), - total=len(behavior_session_ids), - desc=progress_bar_title)) + res = list( + tqdm( + p.imap(target, zip(behavior_session_ids, [lims_engine] * len(behavior_session_ids))), + total=len(behavior_session_ids), + desc=progress_bar_title, + ) + ) return res @@ -117,9 +114,7 @@ def _get_session_metadata(*args) -> Optional[BehaviorMetadata]: """ behavior_session_id, db_conn = args[0] try: - meta = BehaviorMetadata.from_lims( - behavior_session_id=BehaviorSessionId(behavior_session_id), - lims_db=db_conn) + meta = BehaviorMetadata.from_lims(behavior_session_id=BehaviorSessionId(behavior_session_id), lims_db=db_conn) except MalformedStimulusFileError: meta = None return meta @@ -130,10 +125,8 @@ def _get_image_names(*args) -> Set[str]: Helper function to get image names from behavior stimulus file """ behavior_session_id, db_conn = args[0] - behavior_stimulus_file = BehaviorStimulusFile.from_lims( - behavior_session_id=behavior_session_id, db=db_conn) - image_names = get_image_names( - behavior_stimulus_file=behavior_stimulus_file) + behavior_stimulus_file = BehaviorStimulusFile.from_lims(behavior_session_id=behavior_session_id, db=db_conn) + image_names = get_image_names(behavior_stimulus_file=behavior_stimulus_file) return image_names @@ -143,7 +136,7 @@ def remove_invalid_sessions( remove_sessions_after_mouse_death_date: bool = True, remove_aborted_sessions: bool = True, expected_training_duration: int = 15 * 60, - expected_duration: int = 60 * 60 + expected_duration: int = 60 * 60, ) -> List[BehaviorMetadata]: """ Removes any invalid sessions from `behavior_sessions` @@ -182,16 +175,20 @@ def remove_invalid_sessions( if remove_sessions_after_mouse_death_date: behavior_sessions = [ - x for x in behavior_sessions - if (x.subject_metadata.get_death_date() is None or - x.date_of_acquisition <= x.subject_metadata.get_death_date())] + x + for x in behavior_sessions + if ( + x.subject_metadata.get_death_date() is None + or x.date_of_acquisition <= x.subject_metadata.get_death_date() + ) + ] if remove_aborted_sessions: - training_sessions = \ - [x for x in behavior_sessions if x.is_training and - x.get_session_duration() > expected_training_duration] - nontraining_sessions = \ - [x for x in behavior_sessions if not x.is_training and - x.get_session_duration() > expected_duration] + training_sessions = [ + x for x in behavior_sessions if x.is_training and x.get_session_duration() > expected_training_duration + ] + nontraining_sessions = [ + x for x in behavior_sessions if not x.is_training and x.get_session_duration() > expected_duration + ] behavior_sessions = training_sessions + nontraining_sessions return behavior_sessions diff --git a/allensdk/internal/core/__init__.py b/allensdk/internal/core/__init__.py index 49b76dd3d8..a9e368b4f7 100644 --- a/allensdk/internal/core/__init__.py +++ b/allensdk/internal/core/__init__.py @@ -3,4 +3,4 @@ with on-prem resources """ -from ._data_file import DataFile # noqa F401 +from ._data_file import DataFile # noqa F401 diff --git a/allensdk/internal/core/_data_file.py b/allensdk/internal/core/_data_file.py index 501c321128..4bb0336c43 100644 --- a/allensdk/internal/core/_data_file.py +++ b/allensdk/internal/core/_data_file.py @@ -21,9 +21,7 @@ class DataFile(abc.ABC): Path to file. """ - def __init__(self, - filepath: Union[str, Path], - **kwargs): # pragma: no cover + def __init__(self, filepath: Union[str, Path], **kwargs): # pragma: no cover self._filepath: str = safe_system_path(str(filepath)) self._data = self.load_data(filepath=self._filepath, **kwargs) @@ -37,8 +35,7 @@ def filepath(self) -> str: # pragma: no cover @classmethod @abc.abstractmethod - def from_json(cls, - dict_repr: dict) -> "DataFile": # pragma: no cover + def from_json(cls, dict_repr: dict) -> "DataFile": # pragma: no cover """Populates a DataFile from a JSON compatible dict (likely parsed by argschema) @@ -70,8 +67,7 @@ def from_lims(cls) -> "DataFile": # pragma: no cover @staticmethod @abc.abstractmethod - def load_data(filepath: Union[str, Path], - **kwargs) -> Any: # pragma: no cover + def load_data(filepath: Union[str, Path], **kwargs) -> Any: # pragma: no cover """Given a filepath (that is meant to by read by the DataFile type), load the contents of the file into a Python type. (dict, DataFrame, list, etc...) diff --git a/allensdk/internal/core/lims_pipeline_module.py b/allensdk/internal/core/lims_pipeline_module.py index fd0d87947b..f7309e64d8 100644 --- a/allensdk/internal/core/lims_pipeline_module.py +++ b/allensdk/internal/core/lims_pipeline_module.py @@ -10,13 +10,14 @@ SHARED_SDK = "/shared/bioapps/infoapps/lims2_modules/lib/allensdk" RUN_PYTHON = "/shared/bioapps/infoapps/lims2_modules/lib/python/run_python.sh" -class PipelineModule( object ): + +class PipelineModule(object): def __init__(self, description="", parser=None): if parser is None: self.parser = default_argument_parser(description) else: self.parser = parser - + self._args = None @property @@ -33,32 +34,35 @@ def input_data(self): except Exception as e: logging.error("could not read input json: %s", self.args.input_json) raise e - + def write_output_data(self, data): try: ju.write(self.args.output_json, data) except Exception as e: logging.error("could not write output json: %s", self.args.output_json) - raise e + raise e def default_argument_parser(description=""): parser = argparse.ArgumentParser(description) - parser.add_argument('input_json') - parser.add_argument('output_json') - parser.add_argument('--log-level', default=logging.DEBUG) + parser.add_argument("input_json") + parser.add_argument("output_json") + parser.add_argument("--log-level", default=logging.DEBUG) return parser -def run_module(module, input_data, storage_directory, - optional_args=None, - python=SHARED_PYTHON, - sdk_path=SHARED_SDK, - local=False, - pbs=None): - - PBS_TEMPLATE=""" +def run_module( + module, + input_data, + storage_directory, + optional_args=None, + python=SHARED_PYTHON, + sdk_path=SHARED_SDK, + local=False, + pbs=None, +): + PBS_TEMPLATE = """ export PYTHONPATH=%(sdk_path)s:$PYTHONPATH PYTHON=%(python)s SCRIPT="%(module)s" @@ -74,51 +78,43 @@ def run_module(module, input_data, storage_directory, Manifest.safe_mkdir(storage_directory) - pbs_headers = [ ('-j oe'), - ('-o %s' % os.path.join(storage_directory, "run.log")) ] + pbs_headers = [("-j oe"), ("-o %s" % os.path.join(storage_directory, "run.log"))] pbs = pbs if pbs is not None else {} - queue = pbs.get('queue', 'braintv') - pbs_headers.append('-q %s' % queue) + queue = pbs.get("queue", "braintv") + pbs_headers.append("-q %s" % queue) + + walltime = pbs.get("walltime", "3:00:00") + pbs_headers.append("-l walltime=%s" % walltime) - walltime = pbs.get('walltime', '3:00:00') - pbs_headers.append('-l walltime=%s' % walltime) + vmem = pbs.get("vmem", 16) + pbs_headers.append("-l vmem=%dgb" % vmem) - vmem = pbs.get('vmem', 16) - pbs_headers.append('-l vmem=%dgb' % vmem) + if "job_name" in pbs: + pbs_headers.append("-N %s" % pbs["job_name"]) - if 'job_name' in pbs: - pbs_headers.append('-N %s' % pbs['job_name']) - - if 'ncpus' in pbs: - pbs_headers.append('-l ncpus=%d' % pbs['ncpus']) + if "ncpus" in pbs: + pbs_headers.append("-l ncpus=%d" % pbs["ncpus"]) - pbs_headers = [ '#PBS %s' % s for s in pbs_headers ] + pbs_headers = ["#PBS %s" % s for s in pbs_headers] - with open(pbs_file,"w") as f: - f.write('\n'.join(pbs_headers) + PBS_TEMPLATE % { + with open(pbs_file, "w") as f: + f.write( + "\n".join(pbs_headers) + + PBS_TEMPLATE + % { "python": python, "sdk_path": sdk_path, "module": module, "input_json": input_json, "output_json": output_json, - "optional_args": " ".join(optional_args) - }) - - + "optional_args": " ".join(optional_args), + } + ) ju.write(input_json, input_data) if local: - subprocess.call(['sh', pbs_file]) + subprocess.call(["sh", pbs_file]) else: - subprocess.call(['qsub', pbs_file]) - - - - - - - - - + subprocess.call(["qsub", pbs_file]) diff --git a/allensdk/internal/core/lims_utilities.py b/allensdk/internal/core/lims_utilities.py index ef9cc42691..f31bff498a 100644 --- a/allensdk/internal/core/lims_utilities.py +++ b/allensdk/internal/core/lims_utilities.py @@ -15,13 +15,14 @@ def get_well_known_files_by_type(wkfs, wkf_type_id): - out = [os.path.join(wkf['storage_directory'], wkf['filename']) - for wkf in wkfs - if wkf.get('well_known_file_type_id', None) == wkf_type_id] + out = [ + os.path.join(wkf["storage_directory"], wkf["filename"]) + for wkf in wkfs + if wkf.get("well_known_file_type_id", None) == wkf_type_id + ] if len(out) == 0: - raise IOError( - "Could not find well known files with type %d." % wkf_type_id) + raise IOError("Could not find well known files with type %d." % wkf_type_id) return out @@ -31,21 +32,16 @@ def get_well_known_file_by_type(wkfs, wkf_type_id): nout = len(out) if nout != 1: - raise IOError( - "Expected single well known file with type %d. Got %d." % ( - wkf_type_id, nout)) + raise IOError("Expected single well known file with type %d. Got %d." % (wkf_type_id, nout)) return out[0] def get_well_known_files_by_name(wkfs, filename): - out = [os.path.join(wkf['storage_directory'], wkf['filename']) - for wkf in wkfs - if wkf['filename'] == filename] + out = [os.path.join(wkf["storage_directory"], wkf["filename"]) for wkf in wkfs if wkf["filename"] == filename] if len(out) == 0: - raise IOError( - "Could not find well known files with name %s." % filename) + raise IOError("Could not find well known files with name %s." % filename) return out @@ -55,73 +51,56 @@ def get_well_known_file_by_name(wkfs, filename): nout = len(out) if nout != 1: - raise IOError( - "Expected single well known file with name %s. Got %d." % ( - filename, nout)) + raise IOError("Expected single well known file with name %s. Got %d." % (filename, nout)) return out[0] def append_well_known_file(wkfs, path, wkf_type_id=None, content_type=None): - record = { - 'filename': os.path.basename(path), - 'storage_directory': os.path.dirname(path) - } + record = {"filename": os.path.basename(path), "storage_directory": os.path.dirname(path)} if wkf_type_id is not None: - record['well_known_file_type_id'] = wkf_type_id + record["well_known_file_type_id"] = wkf_type_id if content_type is not None: - record['content_type'] = content_type + record["content_type"] = content_type for wkf in wkfs: - if wkf['filename'] == record['filename']: - logging.debug( - "found existing well known file record for %s, updating", path) + if wkf["filename"] == record["filename"]: + logging.debug("found existing well known file record for %s, updating", path) wkf.update(record) return - logging.debug( - "could not find existing well known file record for %s, appending", - path) + logging.debug("could not find existing well known file record for %s, appending", path) wkfs.append(record) -def _connect(user="limsreader", host="limsdb2", database="lims2", - password="limsro", port=5432): +def _connect(user="limsreader", host="limsdb2", database="lims2", password="limsro", port=5432): import pg8000 - conn = pg8000.connect(user=user, host=host, database=database, - password=password, port=port) + conn = pg8000.connect(user=user, host=host, database=database, password=password, port=port) return conn, conn.cursor() def _select(cursor, query): cursor.execute(query) - columns = [d[0].decode("utf-8") if isinstance(d[0], bytes) else d[0] for d - in cursor.description] + columns = [d[0].decode("utf-8") if isinstance(d[0], bytes) else d[0] for d in cursor.description] return [dict(zip(columns, c)) for c in cursor.fetchall()] def select(cursor, query): - raise DeprecationWarning( - "lims_utilities.select is deprecated. Please use " - "lims_utilities.query instead.") + raise DeprecationWarning("lims_utilities.select is deprecated. Please use lims_utilities.query instead.") -def connect(user="limsreader", host="limsdb2", database="lims2", - password="limsro", port=5432): - raise DeprecationWarning( - "lims_utilities.connect is deprecated. Please use " - "lims_utilities.query instead.") +def connect(user="limsreader", host="limsdb2", database="lims2", password="limsro", port=5432): + raise DeprecationWarning("lims_utilities.connect is deprecated. Please use lims_utilities.query instead.") -def query(query, user="limsreader", host="limsdb2", database="lims2", - password="limsro", port=5432): +def query(query, user="limsreader", host="limsdb2", database="lims2", password="limsro", port=5432): conn, cursor = _connect(user, host, database, password, port) # Guard against non-ascii characters in query - query = ''.join([i if ord(i) < 128 else ' ' for i in query]) + query = "".join([i if ord(i) < 128 else " " for i in query]) try: results = _select(cursor, query) @@ -142,28 +121,30 @@ def convert_from_titan_linux(file_name): # Lookup table mapping project to program project_to_program = { "neuralcoding": "braintv", - '0378': "celltypes", - 'conn': "celltypes", - 'ctyconn': "celltypes", - 'humancelltypes': "celltypes", - 'mousecelltypes': "celltypes", - 'shotconn': "celltypes", - 'synapticphys': "celltypes", - 'whbi': "celltypes", - 'wijem': "celltypes" + "0378": "celltypes", + "conn": "celltypes", + "ctyconn": "celltypes", + "humancelltypes": "celltypes", + "mousecelltypes": "celltypes", + "shotconn": "celltypes", + "synapticphys": "celltypes", + "whbi": "celltypes", + "wijem": "celltypes", } # Tough intermediary state where we have old paths # being translated to new paths - m = re.match('/projects/([^/]+)/vol1/(.*)', file_name) + m = re.match("/projects/([^/]+)/vol1/(.*)", file_name) if m: - newpath = os.path.normpath(os.path.join( - '/allen', - 'programs', - project_to_program.get(m.group(1), 'undefined'), - 'production', - m.group(1), - m.group(2) - )) + newpath = os.path.normpath( + os.path.join( + "/allen", + "programs", + project_to_program.get(m.group(1), "undefined"), + "production", + m.group(1), + m.group(2), + ) + ) return newpath return file_name @@ -172,55 +153,53 @@ def linux_to_windows(file_name): # Lookup table mapping project to program project_to_program = { "neuralcoding": "braintv", - '0378': "celltypes", - 'conn': "celltypes", - 'ctyconn': "celltypes", - 'humancelltypes': "celltypes", - 'mousecelltypes': "celltypes", - 'shotconn': "celltypes", - 'synapticphys': "celltypes", - 'whbi': "celltypes", - 'wijem': "celltypes" + "0378": "celltypes", + "conn": "celltypes", + "ctyconn": "celltypes", + "humancelltypes": "celltypes", + "mousecelltypes": "celltypes", + "shotconn": "celltypes", + "synapticphys": "celltypes", + "whbi": "celltypes", + "wijem": "celltypes", } # Simple case for new world order - m = re.match('/allen', file_name) + m = re.match("/allen", file_name) if m: - return "\\" + file_name.replace('/', '\\') + return "\\" + file_name.replace("/", "\\") # /data/ paths are being retained (for now) # this will need to be extended to map directories to # /allen/{programs,aibs}/workgroups/foo - m = re.match('/data/([^/]+)/(.*)', file_name) + m = re.match("/data/([^/]+)/(.*)", file_name) if m: - return os.path.normpath( - os.path.join('\\\\aibsdata', m.group(1), m.group(2))) + return os.path.normpath(os.path.join("\\\\aibsdata", m.group(1), m.group(2))) # Tough intermediary state where we have old paths # being translated to new paths - m = re.match('/projects/([^/]+)/vol1/(.*)', file_name) + m = re.match("/projects/([^/]+)/vol1/(.*)", file_name) if m: - newpath = os.path.normpath(os.path.join( - '\\\\allen', - 'programs', - project_to_program.get(m.group(1), 'undefined'), - 'production', - m.group(1), - m.group(2) - )) + newpath = os.path.normpath( + os.path.join( + "\\\\allen", + "programs", + project_to_program.get(m.group(1), "undefined"), + "production", + m.group(1), + m.group(2), + ) + ) return newpath # No matches found. Clean up and return path given to us return os.path.normpath(file_name) -def get_input_json(object_id, object_class, strategy_class, host="lims2", - **kwargs): - query_string = ("http://{}/InputJsons?strategy_class={}" - "&object_class={}&object_id={}").format(host, - strategy_class, - object_class, - object_id) +def get_input_json(object_id, object_class, strategy_class, host="lims2", **kwargs): + query_string = ("http://{}/InputJsons?strategy_class={}&object_class={}&object_id={}").format( + host, strategy_class, object_class, object_id + ) for key, value in kwargs.items(): query_string += "&{}={}".format(key, value) diff --git a/allensdk/internal/core/mouse_connectivity_cache_prerelease.py b/allensdk/internal/core/mouse_connectivity_cache_prerelease.py index d3c229ceda..b768e00d9e 100644 --- a/allensdk/internal/core/mouse_connectivity_cache_prerelease.py +++ b/allensdk/internal/core/mouse_connectivity_cache_prerelease.py @@ -5,8 +5,7 @@ from allensdk.core import json_utilities from allensdk.core.mouse_connectivity_cache import MouseConnectivityCache -from ..api.queries.mouse_connectivity_api_prerelease \ - import MouseConnectivityApiPrerelease +from ..api.queries.mouse_connectivity_api_prerelease import MouseConnectivityApiPrerelease class MouseConnectivityCachePrerelease(MouseConnectivityCache): @@ -46,40 +45,41 @@ class MouseConnectivityCachePrerelease(MouseConnectivityCache): """ - EXPERIMENTS_PRERELEASE_KEY = 'EXPERIMENTS_PRERELEASE' - STORAGE_DIRECTORIES_PRERELEASE_KEY = 'STORAGE_DIRECTORIES_PRERELEASE' + EXPERIMENTS_PRERELEASE_KEY = "EXPERIMENTS_PRERELEASE" + STORAGE_DIRECTORIES_PRERELEASE_KEY = "STORAGE_DIRECTORIES_PRERELEASE" # allows user to pass 'male', 'female' instead of only 'm', 'f' - _GENDER_DICT=dict(male='m', female='f') - - def __init__(self, - resolution=None, - cache=True, - manifest_file='mouse_connectivity_manifest_prerelease.json', - ccf_version=None, - version=None, - cache_storage_directories=True, - storage_directories_file_name=None): - + _GENDER_DICT = dict(male="m", female="f") + + def __init__( + self, + resolution=None, + cache=True, + manifest_file="mouse_connectivity_manifest_prerelease.json", + ccf_version=None, + version=None, + cache_storage_directories=True, + storage_directories_file_name=None, + ): super(MouseConnectivityCachePrerelease, self).__init__( - resolution=resolution, cache=cache, manifest_file=manifest_file, - ccf_version=ccf_version, version=version) - - file_name = self.get_cache_path(storage_directories_file_name, - self.STORAGE_DIRECTORIES_PRERELEASE_KEY) - self.api = MouseConnectivityApiPrerelease( - file_name, cache_storage_directories=cache_storage_directories) - - def get_experiments(self, - dataframe=False, - file_name=None, - cre=None, - injection_structure_ids=None, - age=None, - gender=None, - workflow_state=None, - workflows=None, - project_code=None): + resolution=resolution, cache=cache, manifest_file=manifest_file, ccf_version=ccf_version, version=version + ) + + file_name = self.get_cache_path(storage_directories_file_name, self.STORAGE_DIRECTORIES_PRERELEASE_KEY) + self.api = MouseConnectivityApiPrerelease(file_name, cache_storage_directories=cache_storage_directories) + + def get_experiments( + self, + dataframe=False, + file_name=None, + cre=None, + injection_structure_ids=None, + age=None, + gender=None, + workflow_state=None, + workflows=None, + project_code=None, + ): """Read a list of experiments. If caching is enabled, this will save the whole (unfiltered) list of @@ -96,8 +96,7 @@ def get_experiments(self, the file_name will be pulled out of the manifest. If caching is disabled, no file will be saved. Default is None. """ - file_name = self.get_cache_path(file_name, - self.EXPERIMENTS_PRERELEASE_KEY) + file_name = self.get_cache_path(file_name, self.EXPERIMENTS_PRERELEASE_KEY) if os.path.exists(file_name): experiments = json_utilities.read(file_name) @@ -109,30 +108,27 @@ def get_experiments(self, json_utilities.write(file_name, experiments) # filter the read/downloaded list of experiments - experiments = self.filter_experiments(experiments, - cre, - injection_structure_ids, - age, - gender, - workflow_state, - workflows, - project_code) + experiments = self.filter_experiments( + experiments, cre, injection_structure_ids, age, gender, workflow_state, workflows, project_code + ) if dataframe: experiments = pd.DataFrame(experiments) - experiments.set_index(['id'], inplace=True, drop=False) + experiments.set_index(["id"], inplace=True, drop=False) return experiments - def filter_experiments(self, - experiments, - cre=None, - injection_structure_ids=None, - age=None, - gender=None, - workflow_state=None, - workflows=None, - project_code=None): + def filter_experiments( + self, + experiments, + cre=None, + injection_structure_ids=None, + age=None, + gender=None, + workflow_state=None, + workflows=None, + project_code=None, + ): """ Take a list of experiments and filter them by cre status and injection structure. @@ -153,33 +149,33 @@ def filter_experiments(self, If None, returna all experiments. Default None. """ experiments = super(MouseConnectivityCachePrerelease, self).filter_experiments( - experiments, cre=cre, injection_structure_ids=injection_structure_ids) + experiments, cre=cre, injection_structure_ids=injection_structure_ids + ) # all kwargs == None base case conditions = [lambda d: True] if age is not None: age = [a.lower() for a in age] - conditions.append(lambda d: d['age'].lower() in age) + conditions.append(lambda d: d["age"].lower() in age) if gender is not None: # TODO: pass a string instead of an iterable? gender = [self._GENDER_DICT.get(g.lower(), g.lower()) for g in gender] - conditions.append(lambda d: d['gender'].lower() in gender) + conditions.append(lambda d: d["gender"].lower() in gender) if workflow_state is not None: - #workflow_state = map(str.lower, workflow_state) + # workflow_state = map(str.lower, workflow_state) workflow_state = [ws.lower() for ws in workflow_state] - conditions.append(lambda d: d['workflow_state'].lower() in workflow_state) + conditions.append(lambda d: d["workflow_state"].lower() in workflow_state) if workflows is not None: workflows = [w.lower() for w in workflows] - conditions.append(lambda d: any([w.lower() in workflows - for w in d['workflows']])) + conditions.append(lambda d: any([w.lower() in workflows for w in d["workflows"]])) if project_code is not None: project_code = [pc.lower() for pc in project_code] - conditions.append(lambda d: d['project_code'].lower() in project_code) + conditions.append(lambda d: d["project_code"].lower() in project_code) return [e for e in experiments if all(f(e) for f in conditions)] @@ -192,17 +188,17 @@ def add_manifest_paths(self, manifest_builder): file_name: string File location to save the manifest. """ - manifest_builder = super(MouseConnectivityCachePrerelease, self)\ - .add_manifest_paths(manifest_builder) - - manifest_builder.add_path(self.EXPERIMENTS_PRERELEASE_KEY, - 'experiments_prerelease.json', - parent_key='BASEDIR', - typename='file') - - manifest_builder.add_path(self.STORAGE_DIRECTORIES_PRERELEASE_KEY, - 'storage_directories_prerelease.json', - parent_key='BASEDIR', - typename='file') + manifest_builder = super(MouseConnectivityCachePrerelease, self).add_manifest_paths(manifest_builder) + + manifest_builder.add_path( + self.EXPERIMENTS_PRERELEASE_KEY, "experiments_prerelease.json", parent_key="BASEDIR", typename="file" + ) + + manifest_builder.add_path( + self.STORAGE_DIRECTORIES_PRERELEASE_KEY, + "storage_directories_prerelease.json", + parent_key="BASEDIR", + typename="file", + ) return manifest_builder diff --git a/allensdk/internal/core/simpletree.py b/allensdk/internal/core/simpletree.py index dad54adf0e..21cae3864c 100644 --- a/allensdk/internal/core/simpletree.py +++ b/allensdk/internal/core/simpletree.py @@ -1,14 +1,10 @@ - -class SimpleTree( object ): - def __init__(self, nodes, - node_id_cb, - parent_id_cb): - +class SimpleTree(object): + def __init__(self, nodes, node_id_cb, parent_id_cb): self.node_list = nodes - self._nodes = { node_id_cb(n):n for n in nodes } - self._parent_ids = { nid:parent_id_cb(n) for nid,n in self._nodes.items() } - self._child_ids = { nid:[] for nid in self._nodes } + self._nodes = {node_id_cb(n): n for n in nodes} + self._parent_ids = {nid: parent_id_cb(n) for nid, n in self._nodes.items()} + self._child_ids = {nid: [] for nid in self._nodes} for nid in self._parent_ids: pid = self._parent_ids[nid] @@ -39,7 +35,7 @@ def ancestor_ids(self, nid): pid = self.parent_id(pid) except Exception: raise KeyError("Could not find ancestors for node %s" % str(nid)) - + def descendant_ids(self, nid): ids = [nid] try: @@ -74,7 +70,3 @@ def descendants(self, nid): def ancestors(self, nid): for node in self.nodes(self.ancestor_ids(nid)): yield node - - - - diff --git a/allensdk/internal/core/swc.py b/allensdk/internal/core/swc.py index 5353bd9162..10ac6b1dd6 100644 --- a/allensdk/internal/core/swc.py +++ b/allensdk/internal/core/swc.py @@ -17,9 +17,10 @@ from allensdk.internal.morphology.morphology import * from allensdk.internal.morphology.node import Node + ######################################################################## def read_swc(file_name): - """ + """ Read in an SWC file and return a Morphology object. Parameters @@ -38,21 +39,21 @@ def read_swc(file_name): with open(file_name, "r") as f: for line in f: # remove comments - if line.lstrip().startswith('#'): + if line.lstrip().startswith("#"): continue # read values. expected SWC format is: # ID, type, x, y, z, rad, parent # x, y, z and rad are floats. the others are ints toks = line.split() vals = Node( - n = int(toks[0]), - t = int(toks[1]), - x = float(toks[2]), - y = float(toks[3]), - z = float(toks[4]), - r = float(toks[5]), - pn = int(toks[6].rstrip()) - ) + n=int(toks[0]), + t=int(toks[1]), + x=float(toks[2]), + y=float(toks[3]), + z=float(toks[4]), + r=float(toks[5]), + pn=int(toks[6].rstrip()), + ) # store this node nodes.append(vals) # increment line number (used for error reporting only) @@ -64,38 +65,36 @@ def read_swc(file_name): err += "Content: '%s'\n" % line raise IOError(err) - return Morphology(node_list=nodes) + return Morphology(node_list=nodes) ######################################################################## -class Marker( dict ): - """ Simple dictionary class for handling reconstruction marker objects. """ +class Marker(dict): + """Simple dictionary class for handling reconstruction marker objects.""" - SPACING = [ .1144, .1144, .28 ] + SPACING = [0.1144, 0.1144, 0.28] - CUT_DENDRITE = 10 + CUT_DENDRITE = 10 NO_RECONSTRUCTION = 20 def __init__(self, *args, **kwargs): super(Marker, self).__init__(*args, **kwargs) # marker file x,y,z coordinates are offset by a single image-space pixel - self['x'] -= self.SPACING[0] - self['y'] -= self.SPACING[1] - self['z'] -= self.SPACING[2] - + self["x"] -= self.SPACING[0] + self["y"] -= self.SPACING[1] + self["z"] -= self.SPACING[2] def read_marker_file(file_name): - """ read in a marker file and return a list of dictionaries """ - - with open(file_name, 'r') as f: - rows = csv.DictReader((r for r in f if not r.startswith('#')), - fieldnames=['x','y','z','radius','shape','name','comment', - 'color_r','color_g','color_b']) + """read in a marker file and return a list of dictionaries""" - return [ Marker({ 'x': float(r['x']), - 'y': float(r['y']), - 'z': float(r['z']), - 'name': int(r['name']) }) for r in rows ] + with open(file_name, "r") as f: + rows = csv.DictReader( + (r for r in f if not r.startswith("#")), + fieldnames=["x", "y", "z", "radius", "shape", "name", "comment", "color_r", "color_g", "color_b"], + ) + return [ + Marker({"x": float(r["x"]), "y": float(r["y"]), "z": float(r["z"]), "name": int(r["name"])}) for r in rows + ] diff --git a/allensdk/internal/ephys/core_feature_extract.py b/allensdk/internal/ephys/core_feature_extract.py index 1eb0a17367..04a49a5b1f 100644 --- a/allensdk/internal/ephys/core_feature_extract.py +++ b/allensdk/internal/ephys/core_feature_extract.py @@ -17,32 +17,34 @@ TEST_PULSE_DURATION_SEC = 0.4 -LONG_SQUARE_COARSE = 'C1LSCOARSE' -LONG_SQUARE_FINE = 'C1LSFINEST' -SHORT_SQUARE = 'C1SSFINEST' -RAMP = 'C1RP25PR1S' -PASSED_SWEEP_STATES = [ 'manual_passed', 'auto_passed' ] -ICLAMP_UNITS = [ 'Amps', 'pA' ] +LONG_SQUARE_COARSE = "C1LSCOARSE" +LONG_SQUARE_FINE = "C1LSFINEST" +SHORT_SQUARE = "C1SSFINEST" +RAMP = "C1RP25PR1S" +PASSED_SWEEP_STATES = ["manual_passed", "auto_passed"] +ICLAMP_UNITS = ["Amps", "pA"] + def filter_sweeps(sweeps, types=None, passed_only=True, iclamp_only=True): if passed_only: - sweeps = [ s for s in sweeps if s.get('workflow_state', None) in PASSED_SWEEP_STATES ] + sweeps = [s for s in sweeps if s.get("workflow_state", None) in PASSED_SWEEP_STATES] if iclamp_only: - sweeps = [ s for s in sweeps if s['stimulus_units'] in ICLAMP_UNITS ] + sweeps = [s for s in sweeps if s["stimulus_units"] in ICLAMP_UNITS] if types: - sweeps = [ s for s in sweeps for t in types - if s['ephys_stimulus']['description'].startswith(t) ] - - return sorted(sweeps, key=lambda x: x['sweep_number']) + sweeps = [s for s in sweeps for t in types if s["ephys_stimulus"]["description"].startswith(t)] + + return sorted(sweeps, key=lambda x: x["sweep_number"]) + def filtered_sweep_numbers(sweeps, types=None, passed_only=True, iclamp_only=True): - return [ s['sweep_number'] for s in filter_sweeps(sweeps, types, passed_only, iclamp_only) ] + return [s["sweep_number"] for s in filter_sweeps(sweeps, types, passed_only, iclamp_only)] + def find_stim_start(stim, idx0=0): - """ - Find the index of the first nonzero positive or negative jump in an array. + """ + Find the index of the first nonzero positive or negative jump in an array. Parameters ---------- @@ -60,56 +62,61 @@ def find_stim_start(stim, idx0=0): di = np.diff(stim) idxs = np.flatnonzero(di) idxs = idxs[idxs >= idx0] - + if len(idxs) == 0: return -1 - - return idxs[0]+1 + + return idxs[0] + 1 + def find_sweep_stim_start(data_set, sweep_number): sweep = data_set.get_sweep(sweep_number) - sr = sweep['sampling_rate'] - stim_start = find_stim_start(sweep['stimulus'], TEST_PULSE_DURATION_SEC * sr) / sr + sr = sweep["sampling_rate"] + stim_start = find_stim_start(sweep["stimulus"], TEST_PULSE_DURATION_SEC * sr) / sr logging.info("Long square stims start at time %f", stim_start) return stim_start - + + def find_coarse_long_square_amp_delta(sweeps, decimals=0): - """ Find the delta between amplitudes of coarse long square sweeps. Includes failed sweeps. """ - sweeps = filter_sweeps(sweeps, types=[ LONG_SQUARE_COARSE ], passed_only = False) + """Find the delta between amplitudes of coarse long square sweeps. Includes failed sweeps.""" + sweeps = filter_sweeps(sweeps, types=[LONG_SQUARE_COARSE], passed_only=False) - amps = sorted([s['stimulus_amplitude'] for s in sweeps]) + amps = sorted([s["stimulus_amplitude"] for s in sweeps]) amps_diff = np.round(np.diff(amps), decimals=decimals) - amps_diff = amps_diff[amps_diff > 0] # repeats are okay - deltas = sorted(np.unique(amps_diff)) # unique nonzero deltas + amps_diff = amps_diff[amps_diff > 0] # repeats are okay + deltas = sorted(np.unique(amps_diff)) # unique nonzero deltas if len(deltas) == 0: return 0 - + delta = deltas[0] - + if len(deltas) != 1: - logging.warning("Found multiple coarse long square amplitude step differences: %s. Using: %f" % (str(deltas), delta)) + logging.warning( + "Found multiple coarse long square amplitude step differences: %s. Using: %f" % (str(deltas), delta) + ) return delta + def update_output_sweep_features(cell_features, sweep_features, sweep_index): # add peak deflection for subthreshold long squares for sweep_number, sweep in sweep_index.items(): - pd = sweep_features.get(sweep_number,{}).get('peak_deflect', None) + pd = sweep_features.get(sweep_number, {}).get("peak_deflect", None) if pd is not None: - sweep['peak_deflection'] = pd[0] - + sweep["peak_deflection"] = pd[0] + # update num_spikes for sweep_num in sweep_features: - num_spikes = len(sweep_features[sweep_num]['spikes']) + num_spikes = len(sweep_features[sweep_num]["spikes"]) if num_spikes == 0: num_spikes = None - sweep_index[sweep_num]['num_spikes'] = num_spikes + sweep_index[sweep_num]["num_spikes"] = num_spikes def nan_get(obj, key): - """ Return a value from a dictionary. If it does not exist, return None. If it is NaN, return None """ + """Return a value from a dictionary. If it does not exist, return None. If it is NaN, return None""" v = obj.get(key, None) if v is None: @@ -117,12 +124,13 @@ def nan_get(obj, key): else: return None if np.isnan(v) else v + def generate_output_cell_features(cell_features, sweep_features, sweep_index): ephys_features = {} # find hero and rheo sweeps in sweep table rheo_sweep_num = cell_features["long_squares"]["rheobase_sweep"]["id"] - rheo_sweep_id = sweep_index.get(rheo_sweep_num, {}).get('id', None) + rheo_sweep_id = sweep_index.get(rheo_sweep_num, {}).get("id", None) if rheo_sweep_id is None: raise Exception("Could not find id of rheobase sweep number %d." % rheo_sweep_num) @@ -130,14 +138,14 @@ def generate_output_cell_features(cell_features, sweep_features, sweep_index): hero_sweep = cell_features["long_squares"]["hero_sweep"] if hero_sweep is None: raise Exception("Could not find hero sweep") - + hero_sweep_num = hero_sweep["id"] - hero_sweep_id = sweep_index.get(hero_sweep_num, {}).get('id', None) + hero_sweep_id = sweep_index.get(hero_sweep_num, {}).get("id", None) if hero_sweep_id is None: raise Exception("Could not find id of hero sweep number %d." % hero_sweep_num) - - # create a table of values + + # create a table of values # this is a dictionary of ephys_features base = cell_features["long_squares"] ephys_features["rheobase_sweep_id"] = rheo_sweep_id @@ -158,7 +166,7 @@ def generate_output_cell_features(cell_features, sweep_features, sweep_index): # now grab the rheo spike base = cell_features["long_squares"]["rheobase_sweep"]["spikes"][0] - ephys_features["upstroke_downstroke_ratio_long_square"] = nan_get(base, "upstroke_downstroke_ratio") + ephys_features["upstroke_downstroke_ratio_long_square"] = nan_get(base, "upstroke_downstroke_ratio") ephys_features["peak_v_long_square"] = nan_get(base, "peak_v") ephys_features["peak_t_long_square"] = nan_get(base, "peak_t") ephys_features["trough_v_long_square"] = nan_get(base, "trough_v") @@ -177,15 +185,15 @@ def generate_output_cell_features(cell_features, sweep_features, sweep_index): ephys_features["sag"] = nan_get(base, "sag") # convert to ms tau = nan_get(base, "tau") - ephys_features["tau"] = (tau * 1e3) if tau is not None else None + ephys_features["tau"] = (tau * 1e3) if tau is not None else None ephys_features["vm_for_sag"] = nan_get(base, "vm_for_sag") - ephys_features["has_burst"] = None#base.get("has_burst", None) - ephys_features["has_pause"] = None#base.get("has_pause", None) - ephys_features["has_delay"] = None#base.get("has_delay", None) + ephys_features["has_burst"] = None # base.get("has_burst", None) + ephys_features["has_pause"] = None # base.get("has_pause", None) + ephys_features["has_delay"] = None # base.get("has_delay", None) ephys_features["f_i_curve_slope"] = nan_get(base, "fi_fit_slope") # change the base to ramp - base = cell_features["ramps"]["mean_spike_0"] # mean feature of first spike for all of these + base = cell_features["ramps"]["mean_spike_0"] # mean feature of first spike for all of these ephys_features["upstroke_downstroke_ratio_ramp"] = nan_get(base, "upstroke_downstroke_ratio") ephys_features["peak_v_ramp"] = nan_get(base, "peak_v") ephys_features["peak_t_ramp"] = nan_get(base, "peak_t") @@ -201,7 +209,7 @@ def generate_output_cell_features(cell_features, sweep_features, sweep_index): ephys_features["threshold_t_ramp"] = nan_get(base, "threshold_t") # change the base to short_square - base = cell_features["short_squares"]["mean_spike_0"] # mean feature of first spike for all of these + base = cell_features["short_squares"]["mean_spike_0"] # mean feature of first spike for all of these ephys_features["upstroke_downstroke_ratio_short_square"] = nan_get(base, "upstroke_downstroke_ratio") ephys_features["peak_v_short_square"] = nan_get(base, "peak_v") ephys_features["peak_t_short_square"] = nan_get(base, "peak_t") @@ -216,44 +224,50 @@ def generate_output_cell_features(cell_features, sweep_features, sweep_index): ephys_features["slow_trough_t_short_square"] = nan_get(base, "slow_trough_t") ephys_features["threshold_v_short_square"] = nan_get(base, "threshold_v") - #ephys_features["threshold_i_short_square"] = nan_get(base, "threshold_i") + # ephys_features["threshold_i_short_square"] = nan_get(base, "threshold_i") ephys_features["threshold_t_short_square"] = nan_get(base, "threshold_t") ephys_features["threshold_i_short_square"] = nan_get(cell_features["short_squares"], "stimulus_amplitude") return ephys_features + def extract_data(data, nwb_file): ########################################################## #### alings with ephys_sweep_qc_tool extract_features #### - cell_specimen = data['specimens'][0] - sweep_list = cell_specimen['ephys_sweeps'] - sweep_index = { s['sweep_number']:s for s in sweep_list } + cell_specimen = data["specimens"][0] + sweep_list = cell_specimen["ephys_sweeps"] + sweep_index = {s["sweep_number"]: s for s in sweep_list} data_set = NwbDataSet(nwb_file) - # extract sweep-level features logging.debug("Computing sweep features") iclamp_sweep_list = filter_sweeps(sweep_list, iclamp_only=True, passed_only=False) iclamp_sweeps = defaultdict(list) for s in iclamp_sweep_list: try: - stimulus_type_name = s['ephys_stimulus']['ephys_stimulus_type']['name'] + stimulus_type_name = s["ephys_stimulus"]["ephys_stimulus_type"]["name"] except KeyError: - raise Exception("Sweep %d has no ephys stimulus record in features JSON file: %s" % (s['sweep_number'], json.dumps(s, indent=3, default=json_handler))) + raise Exception( + "Sweep %d has no ephys stimulus record in features JSON file: %s" + % (s["sweep_number"], json.dumps(s, indent=3, default=json_handler)) + ) if stimulus_type_name == "Unknown": - raise Exception(("Sweep %d (%s) has 'Unknown' stimulus type." + - "Please update the EpysStimuli and EphysRawStimulusNames associations in LIMS.") % (s['sweep_number'], s['ephys_stimulus']['description'])) + raise Exception( + ( + "Sweep %d (%s) has 'Unknown' stimulus type." + + "Please update the EpysStimuli and EphysRawStimulusNames associations in LIMS." + ) + % (s["sweep_number"], s["ephys_stimulus"]["description"]) + ) - iclamp_sweeps[stimulus_type_name].append(s['sweep_number']) + iclamp_sweeps[stimulus_type_name].append(s["sweep_number"]) passed_iclamp_sweep_list = filter_sweeps(sweep_list, iclamp_only=True, passed_only=True) num_passed_sweeps = len(passed_iclamp_sweep_list) - logging.info("%d of %d sweeps passed QC", - num_passed_sweeps, - len(iclamp_sweep_list)) + logging.info("%d of %d sweeps passed QC", num_passed_sweeps, len(iclamp_sweep_list)) if num_passed_sweeps == 0: raise FeatureError("There are no QC-passed sweeps available to analyze") @@ -261,58 +275,59 @@ def extract_data(data, nwb_file): # compute sweep features logging.info("Computing sweep features") sweep_features = extract_sweep_features(data_set, iclamp_sweeps) - cell_specimen['sweep_ephys_features'] = sweep_features + cell_specimen["sweep_ephys_features"] = sweep_features # extract cell-level features logging.info("Computing cell features") - long_square_sweep_numbers = filtered_sweep_numbers(sweep_list, [ LONG_SQUARE_COARSE, LONG_SQUARE_FINE ]) - short_square_sweep_numbers = filtered_sweep_numbers(sweep_list, [ SHORT_SQUARE ]) - ramp_sweep_numbers = filtered_sweep_numbers(sweep_list, [ RAMP ]) + long_square_sweep_numbers = filtered_sweep_numbers(sweep_list, [LONG_SQUARE_COARSE, LONG_SQUARE_FINE]) + short_square_sweep_numbers = filtered_sweep_numbers(sweep_list, [SHORT_SQUARE]) + ramp_sweep_numbers = filtered_sweep_numbers(sweep_list, [RAMP]) logging.debug("long square sweeps: %s", str(long_square_sweep_numbers)) logging.debug("short square sweeps: %s", str(short_square_sweep_numbers)) logging.debug("ramp sweeps: %s", str(ramp_sweep_numbers)) # PBS-262 -- have variable subthreshold minimum for human cells - subthresh_min_amp = None # None means default (mouse) behavior + subthresh_min_amp = None # None means default (mouse) behavior long_square_amp_delta = find_coarse_long_square_amp_delta(sweep_list) - + if long_square_amp_delta != 20.0: subthresh_min_amp = -200 - logging.info("Long squares using %fpA step size. Using subthreshold minimum amplitude of %s.", - long_square_amp_delta, - str(subthresh_min_amp) if subthresh_min_amp is not None else "[default]") + logging.info( + "Long squares using %fpA step size. Using subthreshold minimum amplitude of %s.", + long_square_amp_delta, + str(subthresh_min_amp) if subthresh_min_amp is not None else "[default]", + ) stim_start = find_sweep_stim_start(data_set, long_square_sweep_numbers[0]) if stim_start > 0: logging.info("resetting long square start time to: %f", stim_start) reset_long_squares_start(stim_start) - cell_features = extract_cell_features(data_set, - ramp_sweep_numbers, - short_square_sweep_numbers, - long_square_sweep_numbers, - subthresh_min_amp) + cell_features = extract_cell_features( + data_set, ramp_sweep_numbers, short_square_sweep_numbers, long_square_sweep_numbers, subthresh_min_amp + ) # shuffle peak deflection for the subthreshold long squares for s in cell_features["long_squares"]["subthreshold_sweeps"]: - sweep_features[s['id']]['peak_deflect'] = s['peak_deflect'] + sweep_features[s["id"]]["peak_deflect"] = s["peak_deflect"] - cell_specimen['cell_ephys_features'] = cell_features + cell_specimen["cell_ephys_features"] = cell_features update_output_sweep_features(cell_features, sweep_features, sweep_index) ephys_features = generate_output_cell_features(cell_features, sweep_features, sweep_index) try: - out_ephys_features = cell_specimen.get('ephys_features',[])[0] + out_ephys_features = cell_specimen.get("ephys_features", [])[0] out_ephys_features.update(ephys_features) except IndexError: - cell_specimen['ephys_features'] = [ ephys_features ] + cell_specimen["ephys_features"] = [ephys_features] #### breaks with ephys_sweep_qc_tool extract_features #### ########################################################## return sweep_list, sweep_features + def save_qc_figures(qc_fig_dir, nwb_file, output_data, plot_cell_figures): if os.path.exists(qc_fig_dir): logging.warning("removing existing qc figures directory: %s", qc_fig_dir) diff --git a/allensdk/internal/ephys/plot_qc_figures.py b/allensdk/internal/ephys/plot_qc_figures.py index 913df6dd61..6d34f9c72f 100644 --- a/allensdk/internal/ephys/plot_qc_figures.py +++ b/allensdk/internal/ephys/plot_qc_figures.py @@ -1,6 +1,6 @@ import matplotlib -matplotlib.use('agg') +matplotlib.use("agg") import logging @@ -20,36 +20,41 @@ import datetime import matplotlib.pyplot as plt -#import seaborn as sns +# import seaborn as sns + +AXIS_Y_RANGE = [-110, 60] -AXIS_Y_RANGE = [ -110, 60 ] def get_time_string(): return datetime.datetime.now().strftime("%I:%M%p %B %d, %Y") + def get_spikes(sweep_features, sweep_number): return get_features(sweep_features, sweep_number)["spikes"] + def get_features(sweep_features, sweep_number): - try: + try: return sweep_features[int(sweep_number)] except KeyError: return sweep_features[str(sweep_number)] + def load_experiment(file_name, sweep_number): ds = NwbDataSet(file_name) sweep = ds.get_sweep(sweep_number) - - r = sweep['index_range'] - v = sweep['response'] * 1e3 - i = sweep['stimulus'] * 1e12 - dt = 1.0 / sweep['sampling_rate'] + + r = sweep["index_range"] + v = sweep["response"] * 1e3 + i = sweep["stimulus"] * 1e12 + dt = 1.0 / sweep["sampling_rate"] t = np.arange(0, len(v)) * dt return (v, i, t, r, dt) + def plot_single_ap_values(nwb_file, sweep_numbers, lims_features, sweep_features, cell_features, type_name): - figs = [ plt.figure() for f in range(3+len(sweep_numbers)) ] + figs = [plt.figure() for f in range(3 + len(sweep_numbers))] v, i, t, r, dt = load_experiment(nwb_file, sweep_numbers[0]) if type_name == "short_square" or type_name == "long_square": @@ -64,7 +69,7 @@ def plot_single_ap_values(nwb_file, sweep_numbers, lims_features, sweep_features for sn in sweep_numbers: spikes = get_spikes(sweep_features, sn) - if (len(spikes) < 1): + if len(spikes) < 1: logging.warning("no spikes in sweep %d" % sn) continue @@ -74,42 +79,55 @@ def plot_single_ap_values(nwb_file, sweep_numbers, lims_features, sweep_features else: rheo_sn = cell_features["long_squares"]["rheobase_sweep"]["id"] rheo_spike = get_spikes(sweep_features, rheo_sn)[0] - voltages = [ rheo_spike[f] for f in voltage_features] - times = [ rheo_spike[f] for f in time_features] + voltages = [rheo_spike[f] for f in voltage_features] + times = [rheo_spike[f] for f in time_features] plt.figure(figs[0].number) - plt.scatter(range(len(voltages)), voltages, color='gray') + plt.scatter(range(len(voltages)), voltages, color="gray") plt.tight_layout() - plt.figure(figs[1].number) - plt.scatter(range(len(times)), times, color='gray') + plt.scatter(range(len(times)), times, color="gray") plt.tight_layout() plt.figure(figs[2].number) - plt.scatter([0], [spikes[0]['upstroke'] / (-spikes[0]['downstroke'])], color='gray') + plt.scatter([0], [spikes[0]["upstroke"] / (-spikes[0]["downstroke"])], color="gray") plt.tight_layout() - plt.figure(figs[0].number) - - yvals = [float(lims_features[k + "_v_" + type_name]) for k in gen_features if lims_features[k + "_v_" + type_name] is not None] + + yvals = [ + float(lims_features[k + "_v_" + type_name]) + for k in gen_features + if lims_features[k + "_v_" + type_name] is not None + ] xvals = range(len(yvals)) - - plt.scatter(xvals, yvals, color='blue', marker='_', s=40, zorder=100) - plt.xticks(xvals, ['thr', 'pk', 'tr', 'ftr', 'str']) + + plt.scatter(xvals, yvals, color="blue", marker="_", s=40, zorder=100) + plt.xticks(xvals, ["thr", "pk", "tr", "ftr", "str"]) plt.title(type_name + ": voltages") plt.figure(figs[1].number) - yvals = [float(lims_features[k + "_t_" + type_name]) for k in gen_features if lims_features[k + "_t_" + type_name] is not None] + yvals = [ + float(lims_features[k + "_t_" + type_name]) + for k in gen_features + if lims_features[k + "_t_" + type_name] is not None + ] xvals = range(len(yvals)) - plt.scatter(xvals, yvals, color='blue', marker='_', s=40, zorder=100) - plt.xticks(xvals, ['thr', 'pk', 'tr', 'ftr', 'str']) + plt.scatter(xvals, yvals, color="blue", marker="_", s=40, zorder=100) + plt.xticks(xvals, ["thr", "pk", "tr", "ftr", "str"]) plt.title(type_name + ": times") - + plt.figure(figs[2].number) if lims_features["upstroke_downstroke_ratio_" + type_name] is not None: - plt.scatter([0], [float(lims_features["upstroke_downstroke_ratio_" + type_name])], color='blue', marker='_', s=40, zorder=100) + plt.scatter( + [0], + [float(lims_features["upstroke_downstroke_ratio_" + type_name])], + color="blue", + marker="_", + s=40, + zorder=100, + ) plt.xticks([]) plt.title(type_name + ": up/down") @@ -117,7 +135,7 @@ def plot_single_ap_values(nwb_file, sweep_numbers, lims_features, sweep_features plt.figure(figs[3 + index].number) v, i, t, r, dt = load_experiment(nwb_file, sn) - plt.plot(t, v, color='black') + plt.plot(t, v, color="black") plt.title(str(sn)) spikes = get_spikes(sweep_features, sn) @@ -134,22 +152,31 @@ def plot_single_ap_values(nwb_file, sweep_numbers, lims_features, sweep_features else: rheo_sn = cell_features["long_squares"]["rheobase_sweep"]["id"] rheo_spike = get_spikes(sweep_features, rheo_sn)[0] - voltages = [ rheo_spike[f] for f in voltage_features ] - times = [ rheo_spike[f] for f in time_features ] + voltages = [rheo_spike[f] for f in voltage_features] + times = [rheo_spike[f] for f in time_features] - plt.scatter(times, voltages, color='red', zorder=20) - + plt.scatter(times, voltages, color="red", zorder=20) delta_v = 5.0 if nspikes: - plt.plot([spikes[0]['upstroke_t'] - 1e-3 * (delta_v / spikes[0]['upstroke']), - spikes[0]['upstroke_t'] + 1e-3 * (delta_v / spikes[0]['upstroke'])], - [spikes[0]['upstroke_v'] - delta_v, spikes[0]['upstroke_v'] + delta_v], color='red') - - if 'downstroke_t' in spikes[0]: - plt.plot([spikes[0]['downstroke_t'] - 1e-3 * (delta_v / spikes[0]['downstroke']), - spikes[0]['downstroke_t'] + 1e-3 * (delta_v / spikes[0]['downstroke'])], - [spikes[0]['downstroke_v'] - delta_v, spikes[0]['downstroke_v'] + delta_v], color='red') + plt.plot( + [ + spikes[0]["upstroke_t"] - 1e-3 * (delta_v / spikes[0]["upstroke"]), + spikes[0]["upstroke_t"] + 1e-3 * (delta_v / spikes[0]["upstroke"]), + ], + [spikes[0]["upstroke_v"] - delta_v, spikes[0]["upstroke_v"] + delta_v], + color="red", + ) + + if "downstroke_t" in spikes[0]: + plt.plot( + [ + spikes[0]["downstroke_t"] - 1e-3 * (delta_v / spikes[0]["downstroke"]), + spikes[0]["downstroke_t"] + 1e-3 * (delta_v / spikes[0]["downstroke"]), + ], + [spikes[0]["downstroke_v"] - delta_v, spikes[0]["downstroke_v"] + delta_v], + color="red", + ) else: logging.warning("spike has no downstroke time, clipped") @@ -159,72 +186,72 @@ def plot_single_ap_values(nwb_file, sweep_numbers, lims_features, sweep_features elif type_name == "short_square": plt.xlim(stim_start - 0.002, stim_start + stim_dur + 0.01) elif type_name == "long_square": - plt.xlim(times[0]- 0.002, times[-2] + 0.002) + plt.xlim(times[0] - 0.002, times[-2] + 0.002) plt.tight_layout() - return figs + def plot_sweep_figures(nwb_file, ephys_roi_result, image_dir, sizes): sweeps = ephys_roi_result["specimens"][0]["ephys_sweeps"] - vclamp_sweep_numbers = sorted([ s['sweep_number'] for s in sweeps if s['stimulus_units'] == 'Amps' or s['stimulus_units'] == 'pA' ]) + vclamp_sweep_numbers = sorted( + [s["sweep_number"] for s in sweeps if s["stimulus_units"] == "Amps" or s["stimulus_units"] == "pA"] + ) image_file_sets = {} - - tp_len = 0.035 tp_steps = int(tp_len * 200000) b, a = sg.bessel(4, 0.1, "low") for i, sweep_number in enumerate(vclamp_sweep_numbers): - logging.info("plotting sweep %d" % sweep_number) + logging.info("plotting sweep %d" % sweep_number) if i == 0: - v_init, i_init, t_init, r_init, dt_init = load_experiment(nwb_file, sweep_number) + v_init, i_init, t_init, r_init, dt_init = load_experiment(nwb_file, sweep_number) tp_fig = plt.figure() axTP = plt.gca() axTP.set_yticklabels([]) axTP.set_xticklabels([]) axTP.set_xlabel(str(sweep_number)) - axTP.set_ylabel('') + axTP.set_ylabel("") xTP = t_init[0:tp_steps] yTP = v_init[0:tp_steps] axTP.plot(xTP, yTP, linewidth=1) axTP.set_xlim(0, tp_len) -# sns.despine() - + # sns.despine() + exp_fig = plt.figure() axDP = plt.gca() axDP.set_yticklabels([]) axDP.set_xticklabels([]) axDP.set_xlabel(str(sweep_number)) - axDP.set_ylabel('') - v_exp = v_init[r_init[0]:] - t_exp = t_init[r_init[0]:] + axDP.set_ylabel("") + v_exp = v_init[r_init[0] :] + t_exp = t_init[r_init[0] :] yDP = sg.filtfilt(b, a, v_exp, axis=0) xDP = t_exp baseline = yDP[5000:9000] baselineMean = np.mean(baseline) - baselineV = (np.ones(len(xDP))) * baselineMean + baselineV = (np.ones(len(xDP))) * baselineMean axDP.plot(xDP, yDP, linewidth=1) axDP.plot(xDP, baselineV, linewidth=1) axDP.set_xlim(t_exp[0], t_exp[-1]) -# sns.despine() + # sns.despine() v_prev, _i_prev, _t_prev, _r_prev = v_init, i_init, t_init, r_init # noqa: F841 else: - v, i, t, r, dt = load_experiment(nwb_file, sweep_number) + v, i, t, r, dt = load_experiment(nwb_file, sweep_number) tp_fig = plt.figure() axTP = plt.gca() axTP.set_yticklabels([]) axTP.set_xticklabels([]) axTP.set_xlabel(str(sweep_number)) - axTP.set_ylabel('') + axTP.set_ylabel("") yTP = v[:tp_steps] xTP = t[:tp_steps] TPBL = np.mean(yTP[0:100]) @@ -239,42 +266,42 @@ def plot_sweep_figures(nwb_file, ephys_roi_result, image_dir, sizes): axTP.plot(xTP, yTPpN, linewidth=1) axTP.plot(xTP, yTPN, linewidth=1) axTP.set_xlim(0, tp_len) -# sns.despine() + # sns.despine() exp_fig = plt.figure() axDP = plt.gca() axDP.set_yticklabels([]) axDP.set_xticklabels([]) axDP.set_xlabel(str(sweep_number)) - axDP.set_ylabel('') - v_exp = v[r[0]:] - t_exp = t[r[0]:] + axDP.set_ylabel("") + v_exp = v[r[0] :] + t_exp = t[r[0] :] yDP = sg.filtfilt(b, a, v_exp, axis=0) xDP = t_exp baseline = yDP[5000:9000] baselineMean = np.mean(baseline) - baselineV = (np.ones(len(xDP))) * baselineMean + baselineV = (np.ones(len(xDP))) * baselineMean axDP.plot(xDP, yDP, linewidth=1) axDP.plot(xDP, baselineV, linewidth=1) axDP.set_xlim(t_exp[0], t_exp[-1]) -# sns.despine() + # sns.despine() v_prev, _i_prev, _t_prev, _r_prev = v, i, t, r # noqa: F841 - - save_figure(tp_fig, 'test_pulse_%d' % sweep_number, 'test_pulses', image_dir, sizes, image_file_sets) - save_figure(exp_fig, 'experiment_%d' % sweep_number, 'experiments', image_dir, sizes, image_file_sets) + save_figure(tp_fig, "test_pulse_%d" % sweep_number, "test_pulses", image_dir, sizes, image_file_sets) + save_figure(exp_fig, "experiment_%d" % sweep_number, "experiments", image_dir, sizes, image_file_sets) return image_file_sets -def save_figure(fig, image_name, image_set_name, image_dir, sizes, image_sets, scalew=1, scaleh=1, ext='jpg'): + +def save_figure(fig, image_name, image_set_name, image_dir, sizes, image_sets, scalew=1, scaleh=1, ext="jpg"): plt.figure(fig.number) if image_set_name not in image_sets: - image_sets[image_set_name] = { size_name: [] for size_name in sizes } + image_sets[image_set_name] = {size_name: [] for size_name in sizes} for size_name, size in sizes.items(): - fig.set_size_inches(size*scalew, size*scaleh) + fig.set_size_inches(size * scalew, size * scaleh) image_file = os.path.join(image_dir, "%s_%s.%s" % (image_name, size_name, ext)) @@ -286,22 +313,22 @@ def save_figure(fig, image_name, image_set_name, image_dir, sizes, image_sets, s def plot_images(ephys_roi_result, image_dir, sizes, image_sets): - wkfs = [ f for f in ephys_roi_result['well_known_files'] if f['filename'].endswith('tif') ] - - paths = [ os.path.join(f['storage_directory'], f['filename']) for f in wkfs ] + wkfs = [f for f in ephys_roi_result["well_known_files"] if f["filename"].endswith("tif")] - paths = [ lims_utilities.safe_system_path(p) for p in paths ] + paths = [os.path.join(f["storage_directory"], f["filename"]) for f in wkfs] + + paths = [lims_utilities.safe_system_path(p) for p in paths] image_set_name = "images" - image_sets[image_set_name] = { size_name: [] for size_name in sizes } + image_sets[image_set_name] = {size_name: [] for size_name in sizes} for i, path in enumerate(paths): image_data = plt.imread(path) image_data = np.array(image_data, dtype=np.float32) - + vmin = image_data.min() vmax = image_data.max() - + image_data = np.array((image_data - vmin) / (vmax - vmin) * 255.0, dtype=np.uint8) for size_name, size in sizes.items(): @@ -312,24 +339,25 @@ def plot_images(ephys_roi_result, image_dir, sizes, image_sets): else: sdata = image_data - filename = os.path.join(image_dir, "image_%d_%s.jpg" % (i, size_name)) scipy.misc.imsave(filename, sdata) - image_sets['images'][size_name].append(filename) - + image_sets["images"][size_name].append(filename) -def plot_subthreshold_long_square_figures(nwb_file, cell_features, lims_features, sweep_features, image_dir, sizes, cell_image_files): + +def plot_subthreshold_long_square_figures( + nwb_file, cell_features, lims_features, sweep_features, image_dir, sizes, cell_image_files +): sub_sweeps = cell_features["long_squares"]["subthreshold_sweeps"] tau_sweeps = cell_features["long_squares"]["subthreshold_membrane_property_sweeps"] # 0a - Plot VI curve and linear fit, along with vrest - x = np.array([ s['stim_amp'] for s in sub_sweeps ]) - y = np.array([ s['peak_deflect'][0] for s in sub_sweeps ]) - i = np.array([ s['stim_amp'] for s in tau_sweeps ]) + x = np.array([s["stim_amp"] for s in sub_sweeps]) + y = np.array([s["peak_deflect"][0] for s in sub_sweeps]) + i = np.array([s["stim_amp"] for s in tau_sweeps]) fig = plt.figure() - plt.scatter(x, y, color='black') + plt.scatter(x, y, color="black") plt.plot([x.min(), x.max()], [lims_features["vrest"], lims_features["vrest"]], color="blue", linewidth=2) plt.plot(i, i * 1e-3 * lims_features["ri"] + lims_features["vrest"], color="red", linewidth=2) plt.xlabel("pA") @@ -337,35 +365,39 @@ def plot_subthreshold_long_square_figures(nwb_file, cell_features, lims_features plt.title("ri = {:.1f}, vrest = {:.1f}".format(lims_features["ri"], lims_features["vrest"])) plt.tight_layout() - save_figure(fig, 'VI_curve', 'subthreshold_long_squares', image_dir, sizes, cell_image_files) - + save_figure(fig, "VI_curve", "subthreshold_long_squares", image_dir, sizes, cell_image_files) + # 0b - Plot tau curve and average fig = plt.figure() - x = np.array([ s['stim_amp'] for s in tau_sweeps ]) - y = np.array([ s['tau'] for s in tau_sweeps ]) - plt.scatter(x, y, color='black') - i = np.array([ s['stim_amp'] for s in tau_sweeps ]) - plt.plot([i.min(), i.max()], [cell_features["long_squares"]["tau"], cell_features["long_squares"]["tau"]], color="red", linewidth=2) + x = np.array([s["stim_amp"] for s in tau_sweeps]) + y = np.array([s["tau"] for s in tau_sweeps]) + plt.scatter(x, y, color="black") + i = np.array([s["stim_amp"] for s in tau_sweeps]) + plt.plot( + [i.min(), i.max()], + [cell_features["long_squares"]["tau"], cell_features["long_squares"]["tau"]], + color="red", + linewidth=2, + ) plt.xlabel("pA") ylim = plt.ylim() plt.ylim(0, ylim[1]) plt.ylabel("tau (s)") plt.tight_layout() + save_figure(fig, "tau_curve", "subthreshold_long_squares", image_dir, sizes, cell_image_files) - save_figure(fig, 'tau_curve', 'subthreshold_long_squares', image_dir, sizes, cell_image_files) - - subthresh_dict = {s['id']:s for s in tau_sweeps} + subthresh_dict = {s["id"]: s for s in tau_sweeps} # 0c - Plot the subthreshold squares - tau_sweeps = [ s['id'] for s in tau_sweeps ] - tau_figs = [ plt.figure() for i in range(len(tau_sweeps)) ] + tau_sweeps = [s["id"] for s in tau_sweeps] + tau_figs = [plt.figure() for i in range(len(tau_sweeps))] for index, s in enumerate(tau_sweeps): v, i, t, r, dt = load_experiment(nwb_file, s) - + plt.figure(tau_figs[index].number) - + plt.plot(t, v, color="black") if index == 0: @@ -379,13 +411,12 @@ def plot_subthreshold_long_square_figures(nwb_file, cell_features, lims_features stim_start, stim_dur, stim_amp, start_idx, end_idx = get_square_stim_characteristics(i, t) plt.xlim(stim_start - 0.05, stim_start + stim_dur + 0.05) - peak_idx = subthresh_dict[s]['peak_deflect'][1] - peak_t = peak_idx*dt - plt.scatter([peak_t], [subthresh_dict[s]['peak_deflect'][0]], color='red', zorder=10) + peak_idx = subthresh_dict[s]["peak_deflect"][1] + peak_t = peak_idx * dt + plt.scatter([peak_t], [subthresh_dict[s]["peak_deflect"][0]], color="red", zorder=10) popt = ft.fit_membrane_time_constant(v, t, stim_start, peak_t) plt.title(str(s)) - plt.plot(t[start_idx:peak_idx], exp_curve(t[start_idx:peak_idx] - t[start_idx], *popt), color='blue') - + plt.plot(t[start_idx:peak_idx], exp_curve(t[start_idx:peak_idx] - t[start_idx], *popt), color="blue") for index, s in enumerate(tau_sweeps): plt.figure(tau_figs[index].number) @@ -393,32 +424,37 @@ def plot_subthreshold_long_square_figures(nwb_file, cell_features, lims_features plt.tight_layout() for index, tau_fig in enumerate(tau_figs): - save_figure(tau_figs[index], 'tau_%d' % index, 'subthreshold_long_squares', image_dir, sizes, cell_image_files) + save_figure(tau_figs[index], "tau_%d" % index, "subthreshold_long_squares", image_dir, sizes, cell_image_files) + -def plot_short_square_figures(nwb_file, cell_features, lims_features, sweep_features, image_dir, sizes, cell_image_files): +def plot_short_square_figures( + nwb_file, cell_features, lims_features, sweep_features, image_dir, sizes, cell_image_files +): repeat_amp = cell_features["short_squares"].get("stimulus_amplitude", None) if repeat_amp is not None: - short_square_sweep_nums = [ s['id'] for s in cell_features["short_squares"]["common_amp_sweeps"] ] + short_square_sweep_nums = [s["id"] for s in cell_features["short_squares"]["common_amp_sweeps"]] - figs = plot_single_ap_values(nwb_file, short_square_sweep_nums, - lims_features, sweep_features, cell_features, - "short_square") + figs = plot_single_ap_values( + nwb_file, short_square_sweep_nums, lims_features, sweep_features, cell_features, "short_square" + ) for index, fig in enumerate(figs): - save_figure(fig, 'short_squares_%d' % index, 'short_squares', image_dir, sizes, cell_image_files) + save_figure(fig, "short_squares_%d" % index, "short_squares", image_dir, sizes, cell_image_files) - fig = plot_instantaneous_threshold_thumbnail(nwb_file, short_square_sweep_nums, - cell_features, lims_features, sweep_features) + fig = plot_instantaneous_threshold_thumbnail( + nwb_file, short_square_sweep_nums, cell_features, lims_features, sweep_features + ) - save_figure(fig, 'instantaneous_threshold_thumbnail', 'short_squares', image_dir, sizes, cell_image_files) - + save_figure(fig, "instantaneous_threshold_thumbnail", "short_squares", image_dir, sizes, cell_image_files) else: logging.warning("No short square figures to plot.") -def plot_instantaneous_threshold_thumbnail(nwb_file, sweep_numbers, cell_features, lims_features, sweep_features, color='red'): +def plot_instantaneous_threshold_thumbnail( + nwb_file, sweep_numbers, cell_features, lims_features, sweep_features, color="red" +): min_sweep_number = None for sn in sorted(sweep_numbers): spikes = get_spikes(sweep_features, sn) @@ -427,13 +463,13 @@ def plot_instantaneous_threshold_thumbnail(nwb_file, sweep_numbers, cell_feature min_sweep_number = sn if min_sweep_number is None else min(min_sweep_number, sn) fig = plt.figure(frameon=False) - ax = plt.Axes(fig, [0., 0., 1., 1.]) + ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0]) ax.set_axis_off() - fig.add_axes(ax) + fig.add_axes(ax) ax.set_yticklabels([]) ax.set_xticklabels([]) - ax.set_xlabel('') - ax.set_ylabel('') + ax.set_xlabel("") + ax.set_ylabel("") v, i, t, r, dt = load_experiment(nwb_file, sn) stim_start, stim_dur, stim_amp, start_idx, end_idx = get_square_stim_characteristics(i, t) @@ -442,44 +478,52 @@ def plot_instantaneous_threshold_thumbnail(nwb_file, sweep_numbers, cell_feature tend = stim_start + stim_dur + 0.005 plt.plot(t, v, linewidth=1, color=color) - + plt.ylim(AXIS_Y_RANGE[0], AXIS_Y_RANGE[1]) plt.xlim(tstart, tend) return fig -def plot_ramp_figures(nwb_file, cell_specimen, cell_features, lims_features, sweep_features, image_dir, sizes, cell_image_files): - sweeps = cell_specimen['ephys_sweeps'] - ramps_sweeps = [ s["sweep_number"] for s in sweeps if s["workflow_state"].endswith("passed") and s["ephys_stimulus"]["description"][:10] == "C1RP25PR1S"] +def plot_ramp_figures( + nwb_file, cell_specimen, cell_features, lims_features, sweep_features, image_dir, sizes, cell_image_files +): + sweeps = cell_specimen["ephys_sweeps"] + ramps_sweeps = [ + s["sweep_number"] + for s in sweeps + if s["workflow_state"].endswith("passed") and s["ephys_stimulus"]["description"][:10] == "C1RP25PR1S" + ] figs = [] if len(ramps_sweeps) > 0: figs = plot_single_ap_values(nwb_file, ramps_sweeps, lims_features, sweep_features, cell_features, "ramp") for index, fig in enumerate(figs): - save_figure(fig, 'ramps_%d' % index, 'ramps', image_dir, sizes, cell_image_files) + save_figure(fig, "ramps_%d" % index, "ramps", image_dir, sizes, cell_image_files) + def plot_rheo_figures(nwb_file, cell_features, lims_features, sweep_features, image_dir, sizes, cell_image_files): - rheo_sweeps = [ lims_features["rheobase_sweep_num"] ] + rheo_sweeps = [lims_features["rheobase_sweep_num"]] figs = plot_single_ap_values(nwb_file, rheo_sweeps, lims_features, sweep_features, cell_features, "long_square") for index, fig in enumerate(figs): - save_figure(fig, 'rheo_%d' % index, 'rheo', image_dir, sizes, cell_image_files) + save_figure(fig, "rheo_%d" % index, "rheo", image_dir, sizes, cell_image_files) + def plot_hero_figures(nwb_file, cell_features, lims_features, sweep_features, image_dir, sizes, cell_image_files): fig = plt.figure() v, i, t, r, dt = load_experiment(nwb_file, int(lims_features["thumbnail_sweep_num"])) - plt.plot(t, v, color='black') + plt.plot(t, v, color="black") stim_start, stim_dur, stim_amp, start_idx, end_idx = get_square_stim_characteristics(i, t) plt.xlim(stim_start - 0.05, stim_start + stim_dur + 0.05) plt.ylim(-110, 50) - spike_times = [spk['threshold_t'] for spk in get_spikes(sweep_features, lims_features["thumbnail_sweep_num"])] + spike_times = [spk["threshold_t"] for spk in get_spikes(sweep_features, lims_features["thumbnail_sweep_num"])] isis = np.diff(np.array(spike_times)) plt.title("thumbnail {:d}, amp = {:.1f}".format(lims_features["thumbnail_sweep_num"], stim_amp)) plt.tight_layout() - - save_figure(fig, 'thumbnail_0', 'thumbnail', image_dir, sizes, cell_image_files, scalew=2) + + save_figure(fig, "thumbnail_0", "thumbnail", image_dir, sizes, cell_image_files, scalew=2) fig = plt.figure() plt.plot(range(len(isis)), isis) @@ -488,14 +532,14 @@ def plot_hero_figures(nwb_file, cell_features, lims_features, sweep_features, im plt.title("adapt = {:.3g}".format(lims_features["adaptation"])) else: plt.title("adapt = not defined") - + for k in ["has_delay", "has_burst", "has_pause"]: if lims_features.get(k, None) is None: lims_features[k] = False plt.tight_layout() - save_figure(fig, 'thumbnail_1', 'thumbnail', image_dir, sizes, cell_image_files) - + save_figure(fig, "thumbnail_1", "thumbnail", image_dir, sizes, cell_image_files) + yvals = [ float(lims_features["has_delay"]), float(lims_features["has_burst"]), @@ -504,22 +548,24 @@ def plot_hero_figures(nwb_file, cell_features, lims_features, sweep_features, im xvals = range(len(yvals)) fig = plt.figure() - plt.scatter(xvals, yvals, color='red') - plt.xticks(xvals, ['Delay', 'Burst', 'Pause']) + plt.scatter(xvals, yvals, color="red") + plt.xticks(xvals, ["Delay", "Burst", "Pause"]) plt.title("flags") plt.tight_layout() - save_figure(fig, 'thumbnail_2', 'thumbnail', image_dir, sizes, cell_image_files) + save_figure(fig, "thumbnail_2", "thumbnail", image_dir, sizes, cell_image_files) summary_fig = plot_long_square_summary(nwb_file, cell_features, lims_features, sweep_features) - save_figure(summary_fig, 'ephys_summary', 'thumbnail', image_dir, sizes, cell_image_files, scalew=2) + save_figure(summary_fig, "ephys_summary", "thumbnail", image_dir, sizes, cell_image_files, scalew=2) def plot_long_square_summary(nwb_file, cell_features, lims_features, sweep_features): - long_square_sweeps = cell_features['long_squares']['sweeps'] - long_square_sweep_numbers = [ int(s['id']) for s in long_square_sweeps ] - - thumbnail_summary_fig = plot_sweep_set_summary(nwb_file, int(lims_features['thumbnail_sweep_num']), long_square_sweep_numbers) + long_square_sweeps = cell_features["long_squares"]["sweeps"] + long_square_sweep_numbers = [int(s["id"]) for s in long_square_sweeps] + + thumbnail_summary_fig = plot_sweep_set_summary( + nwb_file, int(lims_features["thumbnail_sweep_num"]), long_square_sweep_numbers + ) plt.figure(thumbnail_summary_fig.number) return thumbnail_summary_fig @@ -527,12 +573,16 @@ def plot_long_square_summary(nwb_file, cell_features, lims_features, sweep_featu def plot_fi_curve_figures(nwb_file, cell_features, lims_features, sweep_features, image_dir, sizes, cell_image_files): fig = plt.figure() - fi_sorted = sorted(cell_features["long_squares"]["spiking_sweeps"], key=lambda s:s['stim_amp']) - x = [d['stim_amp'] for d in fi_sorted] - y = [d['avg_rate'] for d in fi_sorted] - last_zero_idx = np.nonzero(y)[0][0] - 1 - plt.scatter(x, y, color='black') - plt.plot(x[last_zero_idx:], cell_features["long_squares"]["fi_fit_slope"] * (np.array(x[last_zero_idx:]) - x[last_zero_idx]), color='red') + fi_sorted = sorted(cell_features["long_squares"]["spiking_sweeps"], key=lambda s: s["stim_amp"]) + x = [d["stim_amp"] for d in fi_sorted] + y = [d["avg_rate"] for d in fi_sorted] + last_zero_idx = np.nonzero(y)[0][0] - 1 + plt.scatter(x, y, color="black") + plt.plot( + x[last_zero_idx:], + cell_features["long_squares"]["fi_fit_slope"] * (np.array(x[last_zero_idx:]) - x[last_zero_idx]), + color="red", + ) plt.xlabel("pA") plt.ylabel("spikes/sec") plt.title("slope = {:.3g}".format(lims_features["f_i_curve_slope"])) @@ -542,70 +592,77 @@ def plot_fi_curve_figures(nwb_file, cell_features, lims_features, sweep_features v, i, t, r, dt = load_experiment(nwb_file, s) stim_start, stim_dur, stim_amp, start_idx, end_idx = get_square_stim_characteristics(i, t) rheo_hero_x.append(stim_amp) - rheo_hero_y = [ len(get_spikes(sweep_features, s)) for s in rheo_hero_sweeps ] + rheo_hero_y = [len(get_spikes(sweep_features, s)) for s in rheo_hero_sweeps] plt.scatter(rheo_hero_x, rheo_hero_y, zorder=20) plt.tight_layout() - save_figure(fig, 'fi_curve', 'fi_curve', image_dir, sizes, cell_image_files, scalew=2) + save_figure(fig, "fi_curve", "fi_curve", image_dir, sizes, cell_image_files, scalew=2) + def plot_sag_figures(nwb_file, cell_features, lims_features, sweep_features, image_dir, sizes, cell_image_files): fig = plt.figure() for d in cell_features["long_squares"]["subthreshold_sweeps"]: - if d['peak_deflect'][0] == lims_features["vm_for_sag"]: - v, i, t, r, dt = load_experiment(nwb_file, int(d['id'])) + if d["peak_deflect"][0] == lims_features["vm_for_sag"]: + v, i, t, r, dt = load_experiment(nwb_file, int(d["id"])) stim_start, stim_dur, stim_amp, start_idx, end_idx = get_square_stim_characteristics(i, t) - plt.plot(t, v, color='black') - plt.scatter(d['peak_deflect'][1], d['peak_deflect'][0], color='red', zorder=10) - #plt.plot([stim_start + stim_dur - 0.1, stim_start + stim_dur], [d['steady'], d['steady']], color='red', zorder=10) + plt.plot(t, v, color="black") + plt.scatter(d["peak_deflect"][1], d["peak_deflect"][0], color="red", zorder=10) + # plt.plot([stim_start + stim_dur - 0.1, stim_start + stim_dur], [d['steady'], d['steady']], color='red', zorder=10) plt.xlim(stim_start - 0.25, stim_start + stim_dur + 0.25) - plt.title("sag = {:.3g}".format(lims_features['sag'])) + plt.title("sag = {:.3g}".format(lims_features["sag"])) plt.tight_layout() - save_figure(fig, 'sag', 'sag', image_dir, sizes, cell_image_files, scalew=2) + save_figure(fig, "sag", "sag", image_dir, sizes, cell_image_files, scalew=2) + def mask_nulls(data): - data[0, np.equal(data[0,:], None) | np.equal(data[0,:],0)] = np.nan + data[0, np.equal(data[0, :], None) | np.equal(data[0, :], 0)] = np.nan + def plot_sweep_value_figures(cell_specimen, image_dir, sizes, cell_image_files): - sweeps = sorted(cell_specimen['ephys_sweeps'], key=lambda s: s['sweep_number'] ) - + sweeps = sorted(cell_specimen["ephys_sweeps"], key=lambda s: s["sweep_number"]) + # plot bridge balance - data = np.array([ [ s['bridge_balance_mohm'], s['sweep_number'] ] for s in sweeps ]).T + data = np.array([[s["bridge_balance_mohm"], s["sweep_number"]] for s in sweeps]).T mask_nulls(data) fig = plt.figure() - plt.title('bridge balance') - plt.plot(data[1,:], data[0,:], marker='.') - - save_figure(fig, 'bridge_balance', 'sweep_values', image_dir, sizes, cell_image_files, scalew=2) + plt.title("bridge balance") + plt.plot(data[1, :], data[0, :], marker=".") + + save_figure(fig, "bridge_balance", "sweep_values", image_dir, sizes, cell_image_files, scalew=2) # plot pre_vm_mv, no blowout sweep - data = np.array([ [ s['pre_vm_mv'], s['sweep_number'] ] - for s in sweeps - if not s['ephys_stimulus']['description'].startswith('EXTPBLWOUT')]).T + data = np.array( + [ + [s["pre_vm_mv"], s["sweep_number"]] + for s in sweeps + if not s["ephys_stimulus"]["description"].startswith("EXTPBLWOUT") + ] + ).T mask_nulls(data) fig = plt.figure() - plt.title('pre vm') - plt.plot(data[1,:], data[0,:], marker='.') - - save_figure(fig, 'pre_vm_mv', 'sweep_values', image_dir, sizes, cell_image_files, scalew=2) + plt.title("pre vm") + plt.plot(data[1, :], data[0, :], marker=".") + + save_figure(fig, "pre_vm_mv", "sweep_values", image_dir, sizes, cell_image_files, scalew=2) # plot bias current - data = np.array([ [ s['leak_pa'], s['sweep_number'] ] for s in sweeps ]).T + data = np.array([[s["leak_pa"], s["sweep_number"]] for s in sweeps]).T mask_nulls(data) - + fig = plt.figure() - plt.title('leak') - plt.plot(data[1,:], data[0,:], marker='.') - - save_figure(fig, 'leak', 'sweep_values', image_dir, sizes, cell_image_files, scalew=2) + plt.title("leak") + plt.plot(data[1, :], data[0, :], marker=".") + + save_figure(fig, "leak", "sweep_values", image_dir, sizes, cell_image_files, scalew=2) + def plot_cell_figures(nwb_file, ephys_roi_result, image_dir, sizes): - cell_image_files = {} - plt.style.use('ggplot') + plt.style.use("ggplot") cell_specimen = ephys_roi_result["specimens"][0] cell_features = cell_specimen["cell_ephys_features"] @@ -616,13 +673,19 @@ def plot_cell_figures(nwb_file, ephys_roi_result, image_dir, sizes): plot_sweep_value_figures(cell_specimen, image_dir, sizes, cell_image_files) logging.info("saving tau and vi figs") - plot_subthreshold_long_square_figures(nwb_file, cell_features, lims_features, sweep_features, image_dir, sizes, cell_image_files) - + plot_subthreshold_long_square_figures( + nwb_file, cell_features, lims_features, sweep_features, image_dir, sizes, cell_image_files + ) + logging.info("saving short square figs") - plot_short_square_figures(nwb_file, cell_features, lims_features, sweep_features, image_dir, sizes, cell_image_files) + plot_short_square_figures( + nwb_file, cell_features, lims_features, sweep_features, image_dir, sizes, cell_image_files + ) logging.info("saving ramps") - plot_ramp_figures(nwb_file, cell_specimen, cell_features, lims_features, sweep_features, image_dir, sizes, cell_image_files) + plot_ramp_figures( + nwb_file, cell_specimen, cell_features, lims_features, sweep_features, image_dir, sizes, cell_image_files + ) logging.info("saving rheo figs") plot_rheo_figures(nwb_file, cell_features, lims_features, sweep_features, image_dir, sizes, cell_image_files) @@ -638,17 +701,18 @@ def plot_cell_figures(nwb_file, ephys_roi_result, image_dir, sizes): return cell_image_files -def plot_sweep_set_summary(nwb_file, highlight_sweep_number, sweep_numbers, - highlight_color='#0779BE', background_color='#dddddd'): +def plot_sweep_set_summary( + nwb_file, highlight_sweep_number, sweep_numbers, highlight_color="#0779BE", background_color="#dddddd" +): fig = plt.figure(frameon=False) - ax = plt.Axes(fig, [0., 0., 1., 1.]) + ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0]) ax.set_axis_off() - fig.add_axes(ax) + fig.add_axes(ax) ax.set_yticklabels([]) ax.set_xticklabels([]) - ax.set_xlabel('') - ax.set_ylabel('') + ax.set_xlabel("") + ax.set_ylabel("") for sn in sweep_numbers: v, i, t, r, dt = load_experiment(nwb_file, sn) @@ -658,7 +722,7 @@ def plot_sweep_set_summary(nwb_file, highlight_sweep_number, sweep_numbers, plt.plot(t, v, linewidth=1, color=highlight_color) stim_start, stim_dur, stim_amp, start_idx, end_idx = get_square_stim_characteristics(i, t) - + tstart = stim_start - 0.05 tend = stim_start + stim_dur + 0.25 @@ -667,6 +731,7 @@ def plot_sweep_set_summary(nwb_file, highlight_sweep_number, sweep_numbers, return fig + def make_sweep_html(sweep_files, file_name): html = "" html += "Cell QC Figures" @@ -674,47 +739,49 @@ def make_sweep_html(sweep_files, file_name): html += "

page created at: %s

" % get_time_string() html += "
" - if 'test_pulses' in sweep_files: - for small_img, large_img in zip(sweep_files['test_pulses']['small'], - sweep_files['test_pulses']['large']): - html += "" % ( os.path.basename(large_img), - os.path.basename(small_img) ) + if "test_pulses" in sweep_files: + for small_img, large_img in zip(sweep_files["test_pulses"]["small"], sweep_files["test_pulses"]["large"]): + html += "" % ( + os.path.basename(large_img), + os.path.basename(small_img), + ) html += "
" html += "
" - if 'experiments' in sweep_files: - for small_img, large_img in zip(sweep_files['experiments']['small'], - sweep_files['experiments']['large']): - html += "" % ( os.path.basename(large_img), - os.path.basename(small_img) ) + if "experiments" in sweep_files: + for small_img, large_img in zip(sweep_files["experiments"]["small"], sweep_files["experiments"]["large"]): + html += "" % ( + os.path.basename(large_img), + os.path.basename(small_img), + ) html += "
" - + html += "" - with open(file_name, 'w') as f: + with open(file_name, "w") as f: f.write(html) -def make_cell_html(image_files, ephys_roi_result, file_name, relative_sweep_link): +def make_cell_html(image_files, ephys_roi_result, file_name, relative_sweep_link): html = "" - specimen = ephys_roi_result['specimens'][0] + specimen = ephys_roi_result["specimens"][0] - html += "

Specimen %d: %s

" % ( specimen['id'], specimen['name'] ) + html += "

Specimen %d: %s

" % (specimen["id"], specimen["name"]) html += "

page created at: %s

" % get_time_string() if relative_sweep_link: html += "

Sweep QC Figures

" else: - sweep_qc_link = '/'.join([ephys_roi_result['storage_directory'], 'qc_figures', 'sweep.html']) + sweep_qc_link = "/".join([ephys_roi_result["storage_directory"], "qc_figures", "sweep.html"]) sweep_qc_link = lims_utilities.safe_system_path(sweep_qc_link) html += "

Sweep QC Figures

" % sweep_qc_link - fields_to_show = [ 'electrode_0_pa', 'seal_gohm', 'initial_access_resistance_mohm', 'input_resistance_mohm' ] + fields_to_show = ["electrode_0_pa", "seal_gohm", "initial_access_resistance_mohm", "input_resistance_mohm"] html += "" for field in fields_to_show: - html += "" % (field, ephys_roi_result.get(field,None)) + html += "" % (field, ephys_roi_result.get(field, None)) html += "
%s%s
%s%s
" for image_file_set_name in image_files: @@ -722,57 +789,57 @@ def make_cell_html(image_files, ephys_roi_result, file_name, relative_sweep_link image_set_files = image_files[image_file_set_name] - for small_img, large_img in zip(image_set_files['small'], image_set_files['large']): - html += "" % ( os.path.basename(large_img), - os.path.basename(small_img) ) - html += ("") + for small_img, large_img in zip(image_set_files["small"], image_set_files["large"]): + html += "" % ( + os.path.basename(large_img), + os.path.basename(small_img), + ) + html += "" - with open(file_name, 'w') as f: + with open(file_name, "w") as f: f.write(html) + def make_sweep_page(nwb_file, ephys_roi_result, working_dir): - sizes = { 'small': 2.0, 'large': 6.0 } + sizes = {"small": 2.0, "large": 6.0} sweep_files = plot_sweep_figures(nwb_file, ephys_roi_result, working_dir, sizes) - make_sweep_html(sweep_files, - os.path.join(working_dir, 'sweep.html')) + make_sweep_html(sweep_files, os.path.join(working_dir, "sweep.html")) -def make_cell_page(nwb_file, ephys_roi_result, working_dir, save_cell_plots=True): +def make_cell_page(nwb_file, ephys_roi_result, working_dir, save_cell_plots=True): if save_cell_plots: - sizes = { 'small': 2.0, 'large': 6.0 } + sizes = {"small": 2.0, "large": 6.0} cell_files = plot_cell_figures(nwb_file, ephys_roi_result, working_dir, sizes) else: cell_files = {} - + logging.info("saving images") - sizes = { 'small': 200, 'large': None } + sizes = {"small": 200, "large": None} plot_images(ephys_roi_result, working_dir, sizes, cell_files) - sweep_page = os.path.join(working_dir, 'sweep.html') + sweep_page = os.path.join(working_dir, "sweep.html") relative_sweep_link = os.path.exists(sweep_page) if not relative_sweep_link: logging.info("sweep page doesn't exist, point to production sweep page") - - make_cell_html(cell_files, ephys_roi_result, - os.path.join(working_dir, 'index.html'), - relative_sweep_link) + + make_cell_html(cell_files, ephys_roi_result, os.path.join(working_dir, "index.html"), relative_sweep_link) + def exp_curve(x, a, inv_tau, y0): - ''' Function used for tau curve fitting ''' + """Function used for tau curve fitting""" return y0 + a * np.exp(-inv_tau * x) def main(): - parser = argparse.ArgumentParser(description='analyze specimens for cell-wide features') - parser.add_argument('nwb_file') - parser.add_argument('feature_json') - parser.add_argument('--output_directory', default='.') - parser.add_argument('--no-sweep-page', action='store_false', dest='sweep_page') - parser.add_argument('--no-cell-page', action='store_false', dest='cell_page') - parser.add_argument('--log_level') - + parser = argparse.ArgumentParser(description="analyze specimens for cell-wide features") + parser.add_argument("nwb_file") + parser.add_argument("feature_json") + parser.add_argument("--output_directory", default=".") + parser.add_argument("--no-sweep-page", action="store_false", dest="sweep_page") + parser.add_argument("--no-cell-page", action="store_false", dest="cell_page") + parser.add_argument("--log_level") args = parser.parse_args() @@ -790,6 +857,5 @@ def main(): make_cell_page(args.nwb_file, ephys_roi_result, args.output_directory, True) - -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/allensdk/internal/ephys/plot_qc_figures3.py b/allensdk/internal/ephys/plot_qc_figures3.py index 3b7f03e94c..63468a1b33 100644 --- a/allensdk/internal/ephys/plot_qc_figures3.py +++ b/allensdk/internal/ephys/plot_qc_figures3.py @@ -1,6 +1,6 @@ import matplotlib -matplotlib.use('agg') +matplotlib.use("agg") import logging @@ -18,36 +18,41 @@ import datetime import matplotlib.pyplot as plt -#import seaborn as sns +# import seaborn as sns + +AXIS_Y_RANGE = [-110, 60] -AXIS_Y_RANGE = [ -110, 60 ] def get_time_string(): return datetime.datetime.now().strftime("%I:%M%p %B %d, %Y") + def get_spikes(sweep_features, sweep_number): return get_features(sweep_features, sweep_number)["spikes"] + def get_features(sweep_features, sweep_number): - try: + try: return sweep_features[int(sweep_number)] except KeyError: return sweep_features[str(sweep_number)] + def load_experiment(file_name, sweep_number): ds = NwbDataSet(file_name) sweep = ds.get_sweep(sweep_number) - - r = sweep['index_range'] - v = sweep['response'] * 1e3 - i = sweep['stimulus'] * 1e12 - dt = 1.0 / sweep['sampling_rate'] + + r = sweep["index_range"] + v = sweep["response"] * 1e3 + i = sweep["stimulus"] * 1e12 + dt = 1.0 / sweep["sampling_rate"] t = np.arange(0, len(v)) * dt return (v, i, t, r, dt) + def plot_single_ap_values(nwb_file, sweep_numbers, rheo_features, sweep_features, cell_features, type_name): - figs = [ plt.figure() for f in range(3+len(sweep_numbers)) ] + figs = [plt.figure() for f in range(3 + len(sweep_numbers))] v, i, t, r, dt = load_experiment(nwb_file, sweep_numbers[0]) if type_name == "short_square" or type_name == "long_square": @@ -62,7 +67,7 @@ def plot_single_ap_values(nwb_file, sweep_numbers, rheo_features, sweep_features for sn in sweep_numbers: spikes = get_spikes(sweep_features, sn) - if (len(spikes) < 1): + if len(spikes) < 1: logging.warning("no spikes in sweep %d" % sn) continue @@ -72,42 +77,55 @@ def plot_single_ap_values(nwb_file, sweep_numbers, rheo_features, sweep_features else: rheo_sn = cell_features["long_squares"]["rheobase_sweep"]["id"] rheo_spike = get_spikes(sweep_features, rheo_sn)[0] - voltages = [ rheo_spike[f] for f in voltage_features] - times = [ rheo_spike[f] for f in time_features] + voltages = [rheo_spike[f] for f in voltage_features] + times = [rheo_spike[f] for f in time_features] plt.figure(figs[0].number) - plt.scatter(range(len(voltages)), voltages, color='gray') + plt.scatter(range(len(voltages)), voltages, color="gray") plt.tight_layout() - plt.figure(figs[1].number) - plt.scatter(range(len(times)), times, color='gray') + plt.scatter(range(len(times)), times, color="gray") plt.tight_layout() plt.figure(figs[2].number) - plt.scatter([0], [spikes[0]['upstroke'] / (-spikes[0]['downstroke'])], color='gray') + plt.scatter([0], [spikes[0]["upstroke"] / (-spikes[0]["downstroke"])], color="gray") plt.tight_layout() - plt.figure(figs[0].number) - - yvals = [float(rheo_features[k + "_v_" + type_name]) for k in gen_features if rheo_features[k + "_v_" + type_name] is not None] + + yvals = [ + float(rheo_features[k + "_v_" + type_name]) + for k in gen_features + if rheo_features[k + "_v_" + type_name] is not None + ] xvals = range(len(yvals)) - - plt.scatter(xvals, yvals, color='blue', marker='_', s=40, zorder=100) - plt.xticks(xvals, ['thr', 'pk', 'tr', 'ftr', 'str']) + + plt.scatter(xvals, yvals, color="blue", marker="_", s=40, zorder=100) + plt.xticks(xvals, ["thr", "pk", "tr", "ftr", "str"]) plt.title(type_name + ": voltages") plt.figure(figs[1].number) - yvals = [float(rheo_features[k + "_t_" + type_name]) for k in gen_features if rheo_features[k + "_t_" + type_name] is not None] + yvals = [ + float(rheo_features[k + "_t_" + type_name]) + for k in gen_features + if rheo_features[k + "_t_" + type_name] is not None + ] xvals = range(len(yvals)) - plt.scatter(xvals, yvals, color='blue', marker='_', s=40, zorder=100) - plt.xticks(xvals, ['thr', 'pk', 'tr', 'ftr', 'str']) + plt.scatter(xvals, yvals, color="blue", marker="_", s=40, zorder=100) + plt.xticks(xvals, ["thr", "pk", "tr", "ftr", "str"]) plt.title(type_name + ": times") - + plt.figure(figs[2].number) if rheo_features["upstroke_downstroke_ratio_" + type_name] is not None: - plt.scatter([0], [float(rheo_features["upstroke_downstroke_ratio_" + type_name])], color='blue', marker='_', s=40, zorder=100) + plt.scatter( + [0], + [float(rheo_features["upstroke_downstroke_ratio_" + type_name])], + color="blue", + marker="_", + s=40, + zorder=100, + ) plt.xticks([]) plt.title(type_name + ": up/down") @@ -115,7 +133,7 @@ def plot_single_ap_values(nwb_file, sweep_numbers, rheo_features, sweep_features plt.figure(figs[3 + index].number) v, i, t, r, dt = load_experiment(nwb_file, sn) - plt.plot(t, v, color='black') + plt.plot(t, v, color="black") plt.title(str(sn)) spikes = get_spikes(sweep_features, sn) @@ -132,21 +150,30 @@ def plot_single_ap_values(nwb_file, sweep_numbers, rheo_features, sweep_features else: rheo_sn = cell_features["long_squares"]["rheobase_sweep"]["id"] rheo_spike = get_spikes(sweep_features, rheo_sn)[0] - voltages = [ rheo_spike[f] for f in voltage_features ] - times = [ rheo_spike[f] for f in time_features ] + voltages = [rheo_spike[f] for f in voltage_features] + times = [rheo_spike[f] for f in time_features] - plt.scatter(times, voltages, color='red', zorder=20) - + plt.scatter(times, voltages, color="red", zorder=20) delta_v = 5.0 if nspikes: - plt.plot([spikes[0]['upstroke_t'] - 1e-3 * (delta_v / spikes[0]['upstroke']), - spikes[0]['upstroke_t'] + 1e-3 * (delta_v / spikes[0]['upstroke'])], - [spikes[0]['upstroke_v'] - delta_v, spikes[0]['upstroke_v'] + delta_v], color='red') - - plt.plot([spikes[0]['downstroke_t'] - 1e-3 * (delta_v / spikes[0]['downstroke']), - spikes[0]['downstroke_t'] + 1e-3 * (delta_v / spikes[0]['downstroke'])], - [spikes[0]['downstroke_v'] - delta_v, spikes[0]['downstroke_v'] + delta_v], color='red') + plt.plot( + [ + spikes[0]["upstroke_t"] - 1e-3 * (delta_v / spikes[0]["upstroke"]), + spikes[0]["upstroke_t"] + 1e-3 * (delta_v / spikes[0]["upstroke"]), + ], + [spikes[0]["upstroke_v"] - delta_v, spikes[0]["upstroke_v"] + delta_v], + color="red", + ) + + plt.plot( + [ + spikes[0]["downstroke_t"] - 1e-3 * (delta_v / spikes[0]["downstroke"]), + spikes[0]["downstroke_t"] + 1e-3 * (delta_v / spikes[0]["downstroke"]), + ], + [spikes[0]["downstroke_v"] - delta_v, spikes[0]["downstroke_v"] + delta_v], + color="red", + ) if type_name == "ramp": if nspikes: @@ -154,76 +181,74 @@ def plot_single_ap_values(nwb_file, sweep_numbers, rheo_features, sweep_features elif type_name == "short_square": plt.xlim(stim_start - 0.002, stim_start + stim_dur + 0.01) elif type_name == "long_square": - plt.xlim(times[0]- 0.002, times[-2] + 0.002) + plt.xlim(times[0] - 0.002, times[-2] + 0.002) plt.tight_layout() - return figs -#def plot_sweep_figures(nwb_file, ephys_roi_result, image_dir, sizes): + +# def plot_sweep_figures(nwb_file, ephys_roi_result, image_dir, sizes): def plot_sweep_figures(nwb_file, sweep_data, image_dir, sizes): -# try: -# sweeps = ephys_roi_result["specimens"][0]["ephys_sweeps"] -# except: -# sweeps = ephys_roi_result["specimens"]["ephys_sweeps"] - vclamp_sweep_numbers = sorted([ s['sweep_number'] for s in sweep_data if s['stimulus_units'] == 'Amps' ]) + # try: + # sweeps = ephys_roi_result["specimens"][0]["ephys_sweeps"] + # except: + # sweeps = ephys_roi_result["specimens"]["ephys_sweeps"] + vclamp_sweep_numbers = sorted([s["sweep_number"] for s in sweep_data if s["stimulus_units"] == "Amps"]) image_file_sets = {} - - tp_len = 0.035 tp_steps = int(tp_len * 200000) b, a = sg.bessel(4, 0.1, "low") for i, sweep_number in enumerate(vclamp_sweep_numbers): - logging.info("plotting sweep %d" % sweep_number) + logging.info("plotting sweep %d" % sweep_number) if i == 0: - v_init, i_init, t_init, r_init, dt_init = load_experiment(nwb_file, sweep_number) + v_init, i_init, t_init, r_init, dt_init = load_experiment(nwb_file, sweep_number) tp_fig = plt.figure() axTP = plt.gca() axTP.set_yticklabels([]) axTP.set_xticklabels([]) axTP.set_xlabel(str(sweep_number)) - axTP.set_ylabel('') + axTP.set_ylabel("") xTP = t_init[0:tp_steps] yTP = v_init[0:tp_steps] axTP.plot(xTP, yTP, linewidth=1) axTP.set_xlim(0, tp_len) -# sns.despine() - + # sns.despine() + exp_fig = plt.figure() axDP = plt.gca() axDP.set_yticklabels([]) axDP.set_xticklabels([]) axDP.set_xlabel(str(sweep_number)) - axDP.set_ylabel('') - v_exp = v_init[r_init[0]:] - t_exp = t_init[r_init[0]:] + axDP.set_ylabel("") + v_exp = v_init[r_init[0] :] + t_exp = t_init[r_init[0] :] yDP = sg.filtfilt(b, a, v_exp, axis=0) xDP = t_exp baseline = yDP[5000:9000] baselineMean = np.mean(baseline) - baselineV = (np.ones(len(xDP))) * baselineMean + baselineV = (np.ones(len(xDP))) * baselineMean axDP.plot(xDP, yDP, linewidth=1) axDP.plot(xDP, baselineV, linewidth=1) axDP.set_xlim(t_exp[0], t_exp[-1]) -# sns.despine() + # sns.despine() v_prev, _i_prev, _t_prev, _r_prev = v_init, i_init, t_init, r_init # noqa: F841 else: - v, i, t, r, dt = load_experiment(nwb_file, sweep_number) + v, i, t, r, dt = load_experiment(nwb_file, sweep_number) tp_fig = plt.figure() axTP = plt.gca() axTP.set_yticklabels([]) axTP.set_xticklabels([]) axTP.set_xlabel(str(sweep_number)) - axTP.set_ylabel('') + axTP.set_ylabel("") yTP = v[:tp_steps] xTP = t[:tp_steps] TPBL = np.mean(yTP[0:100]) @@ -238,42 +263,42 @@ def plot_sweep_figures(nwb_file, sweep_data, image_dir, sizes): axTP.plot(xTP, yTPpN, linewidth=1) axTP.plot(xTP, yTPN, linewidth=1) axTP.set_xlim(0, tp_len) -# sns.despine() + # sns.despine() exp_fig = plt.figure() axDP = plt.gca() axDP.set_yticklabels([]) axDP.set_xticklabels([]) axDP.set_xlabel(str(sweep_number)) - axDP.set_ylabel('') - v_exp = v[r[0]:] - t_exp = t[r[0]:] + axDP.set_ylabel("") + v_exp = v[r[0] :] + t_exp = t[r[0] :] yDP = sg.filtfilt(b, a, v_exp, axis=0) xDP = t_exp baseline = yDP[5000:9000] baselineMean = np.mean(baseline) - baselineV = (np.ones(len(xDP))) * baselineMean + baselineV = (np.ones(len(xDP))) * baselineMean axDP.plot(xDP, yDP, linewidth=1) axDP.plot(xDP, baselineV, linewidth=1) axDP.set_xlim(t_exp[0], t_exp[-1]) -# sns.despine() + # sns.despine() v_prev, _i_prev, _t_prev, _r_prev = v, i, t, r # noqa: F841 - - save_figure(tp_fig, 'test_pulse_%d' % sweep_number, 'test_pulses', image_dir, sizes, image_file_sets) - save_figure(exp_fig, 'experiment_%d' % sweep_number, 'experiments', image_dir, sizes, image_file_sets) + save_figure(tp_fig, "test_pulse_%d" % sweep_number, "test_pulses", image_dir, sizes, image_file_sets) + save_figure(exp_fig, "experiment_%d" % sweep_number, "experiments", image_dir, sizes, image_file_sets) return image_file_sets -def save_figure(fig, image_name, image_set_name, image_dir, sizes, image_sets, scalew=1, scaleh=1, ext='jpg'): + +def save_figure(fig, image_name, image_set_name, image_dir, sizes, image_sets, scalew=1, scaleh=1, ext="jpg"): plt.figure(fig.number) if image_set_name not in image_sets: - image_sets[image_set_name] = { size_name: [] for size_name in sizes } + image_sets[image_set_name] = {size_name: [] for size_name in sizes} for size_name, size in sizes.items(): - fig.set_size_inches(size*scalew, size*scaleh) + fig.set_size_inches(size * scalew, size * scaleh) image_file = os.path.join(image_dir, "%s_%s.%s" % (image_name, size_name, ext)) @@ -285,22 +310,22 @@ def save_figure(fig, image_name, image_set_name, image_dir, sizes, image_sets, s def plot_images(well_known_files, image_dir, sizes, image_sets): - wkfs = [ f for f in well_known_files if f['filename'].endswith('tif') ] - - paths = [ os.path.join(f['storage_directory'], f['filename']) for f in wkfs ] + wkfs = [f for f in well_known_files if f["filename"].endswith("tif")] - paths = [ lims_utilities.safe_system_path(p) for p in paths ] + paths = [os.path.join(f["storage_directory"], f["filename"]) for f in wkfs] + + paths = [lims_utilities.safe_system_path(p) for p in paths] image_set_name = "images" - image_sets[image_set_name] = { size_name: [] for size_name in sizes } + image_sets[image_set_name] = {size_name: [] for size_name in sizes} for i, path in enumerate(paths): image_data = plt.imread(path) image_data = np.array(image_data, dtype=np.float32) - + vmin = image_data.min() vmax = image_data.max() - + image_data = np.array((image_data - vmin) / (vmax - vmin) * 255.0, dtype=np.uint8) for size_name, size in sizes.items(): @@ -311,24 +336,25 @@ def plot_images(well_known_files, image_dir, sizes, image_sets): else: sdata = image_data - filename = os.path.join(image_dir, "image_%d_%s.jpg" % (i, size_name)) scipy.misc.imsave(filename, sdata) - image_sets['images'][size_name].append(filename) - + image_sets["images"][size_name].append(filename) -def plot_subthreshold_long_square_figures(nwb_file, cell_features, rheo_features, sweep_features, image_dir, sizes, cell_image_files): + +def plot_subthreshold_long_square_figures( + nwb_file, cell_features, rheo_features, sweep_features, image_dir, sizes, cell_image_files +): sub_sweeps = cell_features["long_squares"]["subthreshold_sweeps"] tau_sweeps = cell_features["long_squares"]["subthreshold_membrane_property_sweeps"] # 0a - Plot VI curve and linear fit, along with vrest - x = np.array([ s['stim_amp'] for s in sub_sweeps ]) - y = np.array([ s['peak_deflect'][0] for s in sub_sweeps ]) - i = np.array([ s['stim_amp'] for s in tau_sweeps ]) + x = np.array([s["stim_amp"] for s in sub_sweeps]) + y = np.array([s["peak_deflect"][0] for s in sub_sweeps]) + i = np.array([s["stim_amp"] for s in tau_sweeps]) fig = plt.figure() - plt.scatter(x, y, color='black') + plt.scatter(x, y, color="black") plt.plot([x.min(), x.max()], [rheo_features["vrest"], rheo_features["vrest"]], color="blue", linewidth=2) plt.plot(i, i * 1e-3 * rheo_features["ri"] + rheo_features["vrest"], color="red", linewidth=2) plt.xlabel("pA") @@ -336,35 +362,39 @@ def plot_subthreshold_long_square_figures(nwb_file, cell_features, rheo_features plt.title("ri = {:.1f}, vrest = {:.1f}".format(rheo_features["ri"], rheo_features["vrest"])) plt.tight_layout() - save_figure(fig, 'VI_curve', 'subthreshold_long_squares', image_dir, sizes, cell_image_files) - + save_figure(fig, "VI_curve", "subthreshold_long_squares", image_dir, sizes, cell_image_files) + # 0b - Plot tau curve and average fig = plt.figure() - x = np.array([ s['stim_amp'] for s in tau_sweeps ]) - y = np.array([ s['tau'] for s in tau_sweeps ]) - plt.scatter(x, y, color='black') - i = np.array([ s['stim_amp'] for s in tau_sweeps ]) - plt.plot([i.min(), i.max()], [cell_features["long_squares"]["tau"], cell_features["long_squares"]["tau"]], color="red", linewidth=2) + x = np.array([s["stim_amp"] for s in tau_sweeps]) + y = np.array([s["tau"] for s in tau_sweeps]) + plt.scatter(x, y, color="black") + i = np.array([s["stim_amp"] for s in tau_sweeps]) + plt.plot( + [i.min(), i.max()], + [cell_features["long_squares"]["tau"], cell_features["long_squares"]["tau"]], + color="red", + linewidth=2, + ) plt.xlabel("pA") ylim = plt.ylim() plt.ylim(0, ylim[1]) plt.ylabel("tau (s)") plt.tight_layout() + save_figure(fig, "tau_curve", "subthreshold_long_squares", image_dir, sizes, cell_image_files) - save_figure(fig, 'tau_curve', 'subthreshold_long_squares', image_dir, sizes, cell_image_files) - - subthresh_dict = {s['id']:s for s in tau_sweeps} + subthresh_dict = {s["id"]: s for s in tau_sweeps} # 0c - Plot the subthreshold squares - tau_sweeps = [ s['id'] for s in tau_sweeps ] - tau_figs = [ plt.figure() for i in range(len(tau_sweeps)) ] + tau_sweeps = [s["id"] for s in tau_sweeps] + tau_figs = [plt.figure() for i in range(len(tau_sweeps))] for index, s in enumerate(tau_sweeps): v, i, t, r, dt = load_experiment(nwb_file, s) - + plt.figure(tau_figs[index].number) - + plt.plot(t, v, color="black") if index == 0: @@ -378,13 +408,12 @@ def plot_subthreshold_long_square_figures(nwb_file, cell_features, rheo_features stim_start, stim_dur, stim_amp, start_idx, end_idx = get_square_stim_characteristics(i, t) plt.xlim(stim_start - 0.05, stim_start + stim_dur + 0.05) - peak_idx = subthresh_dict[s]['peak_deflect'][1] - peak_t = peak_idx*dt - plt.scatter([peak_t], [subthresh_dict[s]['peak_deflect'][0]], color='red', zorder=10) + peak_idx = subthresh_dict[s]["peak_deflect"][1] + peak_t = peak_idx * dt + plt.scatter([peak_t], [subthresh_dict[s]["peak_deflect"][0]], color="red", zorder=10) popt = ft.fit_membrane_time_constant(v, t, stim_start, peak_t) plt.title(str(s)) - plt.plot(t[start_idx:peak_idx], exp_curve(t[start_idx:peak_idx] - t[start_idx], *popt), color='blue') - + plt.plot(t[start_idx:peak_idx], exp_curve(t[start_idx:peak_idx] - t[start_idx], *popt), color="blue") for index, s in enumerate(tau_sweeps): plt.figure(tau_figs[index].number) @@ -392,32 +421,37 @@ def plot_subthreshold_long_square_figures(nwb_file, cell_features, rheo_features plt.tight_layout() for index, tau_fig in enumerate(tau_figs): - save_figure(tau_figs[index], 'tau_%d' % index, 'subthreshold_long_squares', image_dir, sizes, cell_image_files) + save_figure(tau_figs[index], "tau_%d" % index, "subthreshold_long_squares", image_dir, sizes, cell_image_files) + -def plot_short_square_figures(nwb_file, cell_features, rheo_features, sweep_features, image_dir, sizes, cell_image_files): +def plot_short_square_figures( + nwb_file, cell_features, rheo_features, sweep_features, image_dir, sizes, cell_image_files +): repeat_amp = cell_features["short_squares"].get("stimulus_amplitude", None) if repeat_amp is not None: - short_square_sweep_nums = [ s['id'] for s in cell_features["short_squares"]["common_amp_sweeps"] ] + short_square_sweep_nums = [s["id"] for s in cell_features["short_squares"]["common_amp_sweeps"]] - figs = plot_single_ap_values(nwb_file, short_square_sweep_nums, - rheo_features, sweep_features, cell_features, - "short_square") + figs = plot_single_ap_values( + nwb_file, short_square_sweep_nums, rheo_features, sweep_features, cell_features, "short_square" + ) for index, fig in enumerate(figs): - save_figure(fig, 'short_squares_%d' % index, 'short_squares', image_dir, sizes, cell_image_files) + save_figure(fig, "short_squares_%d" % index, "short_squares", image_dir, sizes, cell_image_files) - fig = plot_instantaneous_threshold_thumbnail(nwb_file, short_square_sweep_nums, - cell_features, rheo_features, sweep_features) + fig = plot_instantaneous_threshold_thumbnail( + nwb_file, short_square_sweep_nums, cell_features, rheo_features, sweep_features + ) - save_figure(fig, 'instantaneous_threshold_thumbnail', 'short_squares', image_dir, sizes, cell_image_files) - + save_figure(fig, "instantaneous_threshold_thumbnail", "short_squares", image_dir, sizes, cell_image_files) else: logging.warning("No short square figures to plot.") -def plot_instantaneous_threshold_thumbnail(nwb_file, sweep_numbers, cell_features, rheo_features, sweep_features, color='red'): +def plot_instantaneous_threshold_thumbnail( + nwb_file, sweep_numbers, cell_features, rheo_features, sweep_features, color="red" +): min_sweep_number = None for sn in sorted(sweep_numbers): spikes = get_spikes(sweep_features, sn) @@ -426,13 +460,13 @@ def plot_instantaneous_threshold_thumbnail(nwb_file, sweep_numbers, cell_feature min_sweep_number = sn if min_sweep_number is None else min(min_sweep_number, sn) fig = plt.figure(frameon=False) - ax = plt.Axes(fig, [0., 0., 1., 1.]) + ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0]) ax.set_axis_off() - fig.add_axes(ax) + fig.add_axes(ax) ax.set_yticklabels([]) ax.set_xticklabels([]) - ax.set_xlabel('') - ax.set_ylabel('') + ax.set_xlabel("") + ax.set_ylabel("") v, i, t, r, dt = load_experiment(nwb_file, sn) stim_start, stim_dur, stim_amp, start_idx, end_idx = get_square_stim_characteristics(i, t) @@ -441,44 +475,52 @@ def plot_instantaneous_threshold_thumbnail(nwb_file, sweep_numbers, cell_feature tend = stim_start + stim_dur + 0.005 plt.plot(t, v, linewidth=1, color=color) - + plt.ylim(AXIS_Y_RANGE[0], AXIS_Y_RANGE[1]) plt.xlim(tstart, tend) return fig -def plot_ramp_figures(nwb_file, sweep_info, cell_features, rheo_features, sweep_features, image_dir, sizes, cell_image_files): +def plot_ramp_figures( + nwb_file, sweep_info, cell_features, rheo_features, sweep_features, image_dir, sizes, cell_image_files +): sweeps = sweep_info - ramps_sweeps = [ s["sweep_number"] for s in sweeps if s["workflow_state"].endswith("passed") and s["ephys_stimulus"]["description"][:10] == "C1RP25PR1S"] + ramps_sweeps = [ + s["sweep_number"] + for s in sweeps + if s["workflow_state"].endswith("passed") and s["ephys_stimulus"]["description"][:10] == "C1RP25PR1S" + ] figs = [] if len(ramps_sweeps) > 0: figs = plot_single_ap_values(nwb_file, ramps_sweeps, rheo_features, sweep_features, cell_features, "ramp") for index, fig in enumerate(figs): - save_figure(fig, 'ramps_%d' % index, 'ramps', image_dir, sizes, cell_image_files) + save_figure(fig, "ramps_%d" % index, "ramps", image_dir, sizes, cell_image_files) + def plot_rheo_figures(nwb_file, cell_features, rheo_features, sweep_features, image_dir, sizes, cell_image_files): - rheo_sweeps = [ rheo_features["rheobase_sweep_num"] ] + rheo_sweeps = [rheo_features["rheobase_sweep_num"]] figs = plot_single_ap_values(nwb_file, rheo_sweeps, rheo_features, sweep_features, cell_features, "long_square") for index, fig in enumerate(figs): - save_figure(fig, 'rheo_%d' % index, 'rheo', image_dir, sizes, cell_image_files) + save_figure(fig, "rheo_%d" % index, "rheo", image_dir, sizes, cell_image_files) + def plot_hero_figures(nwb_file, cell_features, rheo_features, sweep_features, image_dir, sizes, cell_image_files): fig = plt.figure() v, i, t, r, dt = load_experiment(nwb_file, int(rheo_features["thumbnail_sweep_num"])) - plt.plot(t, v, color='black') + plt.plot(t, v, color="black") stim_start, stim_dur, stim_amp, start_idx, end_idx = get_square_stim_characteristics(i, t) plt.xlim(stim_start - 0.05, stim_start + stim_dur + 0.05) plt.ylim(-110, 50) - spike_times = [spk['threshold_t'] for spk in get_spikes(sweep_features, rheo_features["thumbnail_sweep_num"])] + spike_times = [spk["threshold_t"] for spk in get_spikes(sweep_features, rheo_features["thumbnail_sweep_num"])] isis = np.diff(np.array(spike_times)) plt.title("thumbnail {:d}, amp = {:.1f}".format(rheo_features["thumbnail_sweep_num"], stim_amp)) plt.tight_layout() - - save_figure(fig, 'thumbnail_0', 'thumbnail', image_dir, sizes, cell_image_files, scalew=2) + + save_figure(fig, "thumbnail_0", "thumbnail", image_dir, sizes, cell_image_files, scalew=2) fig = plt.figure() plt.plot(range(len(isis)), isis) @@ -487,14 +529,14 @@ def plot_hero_figures(nwb_file, cell_features, rheo_features, sweep_features, im plt.title("adapt = {:.3g}".format(rheo_features["adaptation"])) else: plt.title("adapt = not defined") - + for k in ["has_delay", "has_burst", "has_pause"]: if rheo_features.get(k, None) is None: rheo_features[k] = False plt.tight_layout() - save_figure(fig, 'thumbnail_1', 'thumbnail', image_dir, sizes, cell_image_files) - + save_figure(fig, "thumbnail_1", "thumbnail", image_dir, sizes, cell_image_files) + yvals = [ float(rheo_features["has_delay"]), float(rheo_features["has_burst"]), @@ -503,22 +545,24 @@ def plot_hero_figures(nwb_file, cell_features, rheo_features, sweep_features, im xvals = range(len(yvals)) fig = plt.figure() - plt.scatter(xvals, yvals, color='red') - plt.xticks(xvals, ['Delay', 'Burst', 'Pause']) + plt.scatter(xvals, yvals, color="red") + plt.xticks(xvals, ["Delay", "Burst", "Pause"]) plt.title("flags") plt.tight_layout() - save_figure(fig, 'thumbnail_2', 'thumbnail', image_dir, sizes, cell_image_files) + save_figure(fig, "thumbnail_2", "thumbnail", image_dir, sizes, cell_image_files) summary_fig = plot_long_square_summary(nwb_file, cell_features, rheo_features, sweep_features) - save_figure(summary_fig, 'ephys_summary', 'thumbnail', image_dir, sizes, cell_image_files, scalew=2) + save_figure(summary_fig, "ephys_summary", "thumbnail", image_dir, sizes, cell_image_files, scalew=2) def plot_long_square_summary(nwb_file, cell_features, rheo_features, sweep_features): - long_square_sweeps = cell_features['long_squares']['sweeps'] - long_square_sweep_numbers = [ int(s['id']) for s in long_square_sweeps ] - - thumbnail_summary_fig = plot_sweep_set_summary(nwb_file, int(rheo_features['thumbnail_sweep_num']), long_square_sweep_numbers) + long_square_sweeps = cell_features["long_squares"]["sweeps"] + long_square_sweep_numbers = [int(s["id"]) for s in long_square_sweeps] + + thumbnail_summary_fig = plot_sweep_set_summary( + nwb_file, int(rheo_features["thumbnail_sweep_num"]), long_square_sweep_numbers + ) plt.figure(thumbnail_summary_fig.number) return thumbnail_summary_fig @@ -526,12 +570,16 @@ def plot_long_square_summary(nwb_file, cell_features, rheo_features, sweep_featu def plot_fi_curve_figures(nwb_file, cell_features, rheo_features, sweep_features, image_dir, sizes, cell_image_files): fig = plt.figure() - fi_sorted = sorted(cell_features["long_squares"]["spiking_sweeps"], key=lambda s:s['stim_amp']) - x = [d['stim_amp'] for d in fi_sorted] - y = [d['avg_rate'] for d in fi_sorted] - last_zero_idx = np.nonzero(y)[0][0] - 1 - plt.scatter(x, y, color='black') - plt.plot(x[last_zero_idx:], cell_features["long_squares"]["fi_fit_slope"] * (np.array(x[last_zero_idx:]) - x[last_zero_idx]), color='red') + fi_sorted = sorted(cell_features["long_squares"]["spiking_sweeps"], key=lambda s: s["stim_amp"]) + x = [d["stim_amp"] for d in fi_sorted] + y = [d["avg_rate"] for d in fi_sorted] + last_zero_idx = np.nonzero(y)[0][0] - 1 + plt.scatter(x, y, color="black") + plt.plot( + x[last_zero_idx:], + cell_features["long_squares"]["fi_fit_slope"] * (np.array(x[last_zero_idx:]) - x[last_zero_idx]), + color="red", + ) plt.xlabel("pA") plt.ylabel("spikes/sec") plt.title("slope = {:.3g}".format(rheo_features["f_i_curve_slope"])) @@ -541,89 +589,96 @@ def plot_fi_curve_figures(nwb_file, cell_features, rheo_features, sweep_features v, i, t, r, dt = load_experiment(nwb_file, s) stim_start, stim_dur, stim_amp, start_idx, end_idx = get_square_stim_characteristics(i, t) rheo_hero_x.append(stim_amp) - rheo_hero_y = [ len(get_spikes(sweep_features, s)) for s in rheo_hero_sweeps ] + rheo_hero_y = [len(get_spikes(sweep_features, s)) for s in rheo_hero_sweeps] plt.scatter(rheo_hero_x, rheo_hero_y, zorder=20) plt.tight_layout() - save_figure(fig, 'fi_curve', 'fi_curve', image_dir, sizes, cell_image_files, scalew=2) + save_figure(fig, "fi_curve", "fi_curve", image_dir, sizes, cell_image_files, scalew=2) + def plot_sag_figures(nwb_file, cell_features, rheo_features, sweep_features, image_dir, sizes, cell_image_files): fig = plt.figure() for d in cell_features["long_squares"]["subthreshold_sweeps"]: - if d['peak_deflect'][0] == rheo_features["vm_for_sag"]: - v, i, t, r, dt = load_experiment(nwb_file, int(d['id'])) + if d["peak_deflect"][0] == rheo_features["vm_for_sag"]: + v, i, t, r, dt = load_experiment(nwb_file, int(d["id"])) stim_start, stim_dur, stim_amp, start_idx, end_idx = get_square_stim_characteristics(i, t) - plt.plot(t, v, color='black') - plt.scatter(d['peak_deflect'][1], d['peak_deflect'][0], color='red', zorder=10) - #plt.plot([stim_start + stim_dur - 0.1, stim_start + stim_dur], [d['steady'], d['steady']], color='red', zorder=10) + plt.plot(t, v, color="black") + plt.scatter(d["peak_deflect"][1], d["peak_deflect"][0], color="red", zorder=10) + # plt.plot([stim_start + stim_dur - 0.1, stim_start + stim_dur], [d['steady'], d['steady']], color='red', zorder=10) plt.xlim(stim_start - 0.25, stim_start + stim_dur + 0.25) - plt.title("sag = {:.3g}".format(rheo_features['sag'])) + plt.title("sag = {:.3g}".format(rheo_features["sag"])) plt.tight_layout() - save_figure(fig, 'sag', 'sag', image_dir, sizes, cell_image_files, scalew=2) + save_figure(fig, "sag", "sag", image_dir, sizes, cell_image_files, scalew=2) + def mask_nulls(data): - data[0, np.equal(data[0,:], None) | np.equal(data[0,:],0)] = np.nan + data[0, np.equal(data[0, :], None) | np.equal(data[0, :], 0)] = np.nan + def plot_sweep_value_figures(sweep_info, image_dir, sizes, cell_image_files): - sweeps = sorted(sweep_info, key=lambda s: s['sweep_number'] ) - + sweeps = sorted(sweep_info, key=lambda s: s["sweep_number"]) + # plot bridge balance - data = np.array([ [ s['bridge_balance_mohm'], s['sweep_number'] ] for s in sweeps ]).T + data = np.array([[s["bridge_balance_mohm"], s["sweep_number"]] for s in sweeps]).T mask_nulls(data) fig = plt.figure() - plt.title('bridge balance') - plt.plot(data[1,:], data[0,:], marker='.') - - save_figure(fig, 'bridge_balance', 'sweep_values', image_dir, sizes, cell_image_files, scalew=2) + plt.title("bridge balance") + plt.plot(data[1, :], data[0, :], marker=".") + + save_figure(fig, "bridge_balance", "sweep_values", image_dir, sizes, cell_image_files, scalew=2) # plot pre_vm_mv, no blowout sweep - data = np.array([ [ s['pre_vm_mv'], s['sweep_number'] ] - for s in sweeps - if not s['ephys_stimulus']['description'].startswith('EXTPBLWOUT')]).T + data = np.array( + [ + [s["pre_vm_mv"], s["sweep_number"]] + for s in sweeps + if not s["ephys_stimulus"]["description"].startswith("EXTPBLWOUT") + ] + ).T mask_nulls(data) fig = plt.figure() - plt.title('pre vm') - plt.plot(data[1,:], data[0,:], marker='.') - - save_figure(fig, 'pre_vm_mv', 'sweep_values', image_dir, sizes, cell_image_files, scalew=2) + plt.title("pre vm") + plt.plot(data[1, :], data[0, :], marker=".") + + save_figure(fig, "pre_vm_mv", "sweep_values", image_dir, sizes, cell_image_files, scalew=2) # plot bias current - data = np.array([ [ s['leak_pa'], s['sweep_number'] ] for s in sweeps ]).T + data = np.array([[s["leak_pa"], s["sweep_number"]] for s in sweeps]).T mask_nulls(data) - + fig = plt.figure() - plt.title('leak') - plt.plot(data[1,:], data[0,:], marker='.') - - save_figure(fig, 'leak', 'sweep_values', image_dir, sizes, cell_image_files, scalew=2) - -#def plot_cell_figures(nwb_file, ephys_roi_result, image_dir, sizes): -def plot_cell_figures(nwb_file, - cell_features, - sweep_features, - rheo_features, - image_dir, - sweep_info, - sizes): - + plt.title("leak") + plt.plot(data[1, :], data[0, :], marker=".") + + save_figure(fig, "leak", "sweep_values", image_dir, sizes, cell_image_files, scalew=2) + + +# def plot_cell_figures(nwb_file, ephys_roi_result, image_dir, sizes): +def plot_cell_figures(nwb_file, cell_features, sweep_features, rheo_features, image_dir, sweep_info, sizes): cell_image_files = {} - plt.style.use('ggplot') + plt.style.use("ggplot") logging.info("saving sweep feature figures") plot_sweep_value_figures(sweep_info, image_dir, sizes, cell_image_files) logging.info("saving tau and vi figs") - plot_subthreshold_long_square_figures(nwb_file, cell_features, rheo_features, sweep_features, image_dir, sizes, cell_image_files) - + plot_subthreshold_long_square_figures( + nwb_file, cell_features, rheo_features, sweep_features, image_dir, sizes, cell_image_files + ) + logging.info("saving short square figs") - plot_short_square_figures(nwb_file, cell_features, rheo_features, sweep_features, image_dir, sizes, cell_image_files) + plot_short_square_figures( + nwb_file, cell_features, rheo_features, sweep_features, image_dir, sizes, cell_image_files + ) logging.info("saving ramps") - plot_ramp_figures(nwb_file, sweep_info, cell_features, rheo_features, sweep_features, image_dir, sizes, cell_image_files) + plot_ramp_figures( + nwb_file, sweep_info, cell_features, rheo_features, sweep_features, image_dir, sizes, cell_image_files + ) logging.info("saving rheo figs") plot_rheo_figures(nwb_file, cell_features, rheo_features, sweep_features, image_dir, sizes, cell_image_files) @@ -639,17 +694,18 @@ def plot_cell_figures(nwb_file, return cell_image_files -def plot_sweep_set_summary(nwb_file, highlight_sweep_number, sweep_numbers, - highlight_color='#0779BE', background_color='#dddddd'): +def plot_sweep_set_summary( + nwb_file, highlight_sweep_number, sweep_numbers, highlight_color="#0779BE", background_color="#dddddd" +): fig = plt.figure(frameon=False) - ax = plt.Axes(fig, [0., 0., 1., 1.]) + ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0]) ax.set_axis_off() - fig.add_axes(ax) + fig.add_axes(ax) ax.set_yticklabels([]) ax.set_xticklabels([]) - ax.set_xlabel('') - ax.set_ylabel('') + ax.set_xlabel("") + ax.set_ylabel("") for sn in sweep_numbers: v, i, t, r, dt = load_experiment(nwb_file, sn) @@ -659,7 +715,7 @@ def plot_sweep_set_summary(nwb_file, highlight_sweep_number, sweep_numbers, plt.plot(t, v, linewidth=1, color=highlight_color) stim_start, stim_dur, stim_amp, start_idx, end_idx = get_square_stim_characteristics(i, t) - + tstart = stim_start - 0.05 tend = stim_start + stim_dur + 0.25 @@ -668,6 +724,7 @@ def plot_sweep_set_summary(nwb_file, highlight_sweep_number, sweep_numbers, return fig + def make_sweep_html(sweep_files, file_name): html = "" html += "Cell QC Figures" @@ -675,43 +732,44 @@ def make_sweep_html(sweep_files, file_name): html += "

page created at: %s

" % get_time_string() html += "
" - if 'test_pulses' in sweep_files: - for small_img, large_img in zip(sweep_files['test_pulses']['small'], - sweep_files['test_pulses']['large']): - html += "" % ( os.path.basename(large_img), - os.path.basename(small_img) ) + if "test_pulses" in sweep_files: + for small_img, large_img in zip(sweep_files["test_pulses"]["small"], sweep_files["test_pulses"]["large"]): + html += "" % ( + os.path.basename(large_img), + os.path.basename(small_img), + ) html += "
" html += "
" - if 'experiments' in sweep_files: - for small_img, large_img in zip(sweep_files['experiments']['small'], - sweep_files['experiments']['large']): - html += "" % ( os.path.basename(large_img), - os.path.basename(small_img) ) + if "experiments" in sweep_files: + for small_img, large_img in zip(sweep_files["experiments"]["small"], sweep_files["experiments"]["large"]): + html += "" % ( + os.path.basename(large_img), + os.path.basename(small_img), + ) html += "
" - + html += "" - with open(file_name, 'w') as f: + with open(file_name, "w") as f: f.write(html) -def make_cell_html(image_files, file_name, relative_sweep_link, specimen_info, fields): +def make_cell_html(image_files, file_name, relative_sweep_link, specimen_info, fields): html = "" - html += "

Specimen %d: %s

" % ( specimen_info["id"], specimen_info["name"] ) + html += "

Specimen %d: %s

" % (specimen_info["id"], specimen_info["name"]) html += "

page created at: %s

" % get_time_string() if relative_sweep_link: html += "

Sweep QC Figures

" else: - sweep_qc_link = '/'.join([specimen_info['storage_directory'], 'qc_figures', 'sweep.html']) + sweep_qc_link = "/".join([specimen_info["storage_directory"], "qc_figures", "sweep.html"]) sweep_qc_link = lims_utilities.safe_system_path(sweep_qc_link) html += "

Sweep QC Figures

" % sweep_qc_link - html += "" - for k,v in fields.items(): + for k, v in fields.items(): html += "" % (k, v) html += "
%s%s
" @@ -720,82 +778,92 @@ def make_cell_html(image_files, file_name, relative_sweep_link, specimen_info, f image_set_files = image_files[image_file_set_name] - for small_img, large_img in zip(image_set_files['small'], image_set_files['large']): - html += "" % ( os.path.basename(large_img), - os.path.basename(small_img) ) - html += ("") + for small_img, large_img in zip(image_set_files["small"], image_set_files["large"]): + html += "" % ( + os.path.basename(large_img), + os.path.basename(small_img), + ) + html += "" - with open(file_name, 'w') as f: + with open(file_name, "w") as f: f.write(html) + def make_sweep_page(nwb_file, working_dir, sweep_data): - sizes = { 'small': 2.0, 'large': 6.0 } + sizes = {"small": 2.0, "large": 6.0} + + sweep_files = plot_sweep_figures(nwb_file=nwb_file, sweep_data=sweep_data, image_dir=working_dir, sizes=sizes) - sweep_files = plot_sweep_figures( - nwb_file=nwb_file, - sweep_data=sweep_data, - image_dir=working_dir, - sizes=sizes) + make_sweep_html(sweep_files, os.path.join(working_dir, "sweep.html")) - make_sweep_html(sweep_files, - os.path.join(working_dir, 'sweep.html')) -#def make_cell_page(nwb_file, ephys_roi_result, working_dir, save_cell_plots=True): -def make_cell_page(nwb_file, cell_features, rheo_features, sweep_features, sweep_info, well_known_files, specimen_info, working_dir, fields_to_show, save_cell_plots=True): - """ nwb_file: name of nwb file (string) +# def make_cell_page(nwb_file, ephys_roi_result, working_dir, save_cell_plots=True): +def make_cell_page( + nwb_file, + cell_features, + rheo_features, + sweep_features, + sweep_info, + well_known_files, + specimen_info, + working_dir, + fields_to_show, + save_cell_plots=True, +): + """nwb_file: name of nwb file (string) - cell_features: + cell_features: - rheo_features: dict containing extracted features from rheobase sweep + rheo_features: dict containing extracted features from rheobase sweep - sweep_features: + sweep_features: - sweep_info: + sweep_info: - well_known_files: LIMS-output information containing graphics - file names + well_known_files: LIMS-output information containing graphics + file names - working_dir: + working_dir: - save_cell_plots: + save_cell_plots: """ if save_cell_plots: - sizes = { 'small': 2.0, 'large': 6.0 } + sizes = {"small": 2.0, "large": 6.0} cell_files = plot_cell_figures( - nwb_file = nwb_file, - cell_features = cell_features, - rheo_features = rheo_features, - sweep_features = sweep_features, - sweep_info = sweep_info, - image_dir = working_dir, - sizes = sizes) + nwb_file=nwb_file, + cell_features=cell_features, + rheo_features=rheo_features, + sweep_features=sweep_features, + sweep_info=sweep_info, + image_dir=working_dir, + sizes=sizes, + ) else: cell_files = {} - + logging.info("saving images") - sizes = { 'small': 200, 'large': None } + sizes = {"small": 200, "large": None} plot_images(well_known_files, working_dir, sizes, cell_files) - sweep_page = os.path.join(working_dir, 'sweep.html') + sweep_page = os.path.join(working_dir, "sweep.html") relative_sweep_link = os.path.exists(sweep_page) if not relative_sweep_link: logging.info("sweep page doesn't exist, point to production sweep page") - - make_cell_html(cell_files, - os.path.join(working_dir, 'index.html'), - relative_sweep_link, - specimen_info, - fields_to_show) + + make_cell_html( + cell_files, os.path.join(working_dir, "index.html"), relative_sweep_link, specimen_info, fields_to_show + ) + def exp_curve(x, a, inv_tau, y0): - ''' Function used for tau curve fitting ''' + """Function used for tau curve fitting""" return y0 + a * np.exp(-inv_tau * x) -#def main(): +# def main(): # parser = argparse.ArgumentParser(description='analyze specimens for cell-wide features') # parser.add_argument('nwb_file') # parser.add_argument('feature_json') @@ -803,7 +871,7 @@ def exp_curve(x, a, inv_tau, y0): # parser.add_argument('--no-sweep-page', action='store_false', dest='sweep_page') # parser.add_argument('--no-cell-page', action='store_false', dest='cell_page') # parser.add_argument('--log_level') -# +# # # args = parser.parse_args() # @@ -822,4 +890,4 @@ def exp_curve(x, a, inv_tau, y0): # # # -#if __name__ == '__main__': main() +# if __name__ == '__main__': main() diff --git a/allensdk/internal/model/AIC.py b/allensdk/internal/model/AIC.py index 70698b9a83..c99281e0a0 100644 --- a/allensdk/internal/model/AIC.py +++ b/allensdk/internal/model/AIC.py @@ -5,29 +5,32 @@ TODO: comment style """ + def AIC(RSS, k, n): """ Computes the Akaike Information Criterion. - + RSS-residual sum of squares of the fitting errors. k - number of fitted parameters. n - number of observations. """ - AIC = 2 * k + n * np.log( RSS/n) + AIC = 2 * k + n * np.log(RSS / n) return AIC - + + def AICc(RSS, k, n): """ Corrected AIC. formula from Wikipedia. """ retval = AIC(RSS, k, n) - if n-k-1 != 0: - retval += 2.0 *k* (k+1)/ (n-k-1) + if n - k - 1 != 0: + retval += 2.0 * k * (k + 1) / (n - k - 1) return retval - + + def BIC(RSS, k, n): """ Bayesian information criterion or Schwartz information criterion. Formula from wikipedia. """ - return n * np.log(RSS/n) + k * np.log(n) + return n * np.log(RSS / n) + k * np.log(n) diff --git a/allensdk/internal/model/GLM.py b/allensdk/internal/model/GLM.py index c303d0d3be..5c67f9875d 100644 --- a/allensdk/internal/model/GLM.py +++ b/allensdk/internal/model/GLM.py @@ -5,143 +5,144 @@ # TODO: normalize function call names # TODO: document functions -def create_basis_IPSP(neye,ncos,kpeaks,ks,DTsim,t0,I_stim,nkt,flag_exp,npcut): - + +def create_basis_IPSP(neye, ncos, kpeaks, ks, DTsim, t0, I_stim, nkt, flag_exp, npcut): kbasprs = {} - kbasprs['neye'] = neye #No of 'identity' basis vectors near time of spike - kbasprs['ncos'] = ncos #No of raised-cosines to use - kbasprs['kpeaks'] = kpeaks #Position of first and last bump - kbasprs['b'] = 0.1 #Offset for non-linear scaling - kbasprs['ks'] = ks - - gg0 = makeFitStruct_GLM(DTsim,kbasprs,nkt,flag_exp) - - #Create spike-stim with which to convolve post spike filter + kbasprs["neye"] = neye # No of 'identity' basis vectors near time of spike + kbasprs["ncos"] = ncos # No of raised-cosines to use + kbasprs["kpeaks"] = kpeaks # Position of first and last bump + kbasprs["b"] = 0.1 # Offset for non-linear scaling + kbasprs["ks"] = ks + + gg0 = makeFitStruct_GLM(DTsim, kbasprs, nkt, flag_exp) + + # Create spike-stim with which to convolve post spike filter spike_stim = np.zeros(np.shape(I_stim)) for kk in range(len(t0)): spind = int(t0[kk]) - #print int(t0[kk]), spind-190000 - spike_stim[spind]=1.0 - + # print int(t0[kk]), spind-190000 + spike_stim[spind] = 1.0 + ##Convolve temporal basis functions with spike-stim - c = np.zeros((len(spike_stim),ncos)) + c = np.zeros((len(spike_stim), ncos)) for jj in range(ncos): - basisfilt = gg0['ktbas'][:,jj] - bconv = np.convolve(spike_stim,np.flipud(basisfilt),'full') - c[:,jj] = bconv[range(len(spike_stim))] - + basisfilt = gg0["ktbas"][:, jj] + bconv = np.convolve(spike_stim, np.flipud(basisfilt), "full") + c[:, jj] = bconv[range(len(spike_stim))] + basis_IPSP = c - + return basis_IPSP, gg0 -def makeFitStruct_GLM(dtsim,kbasprs,nkt,flag_exp): - + +def makeFitStruct_GLM(dtsim, kbasprs, nkt, flag_exp): gg = {} - gg['k'] = [] - gg['dc'] = 0 - gg['kt'] = np.zeros((nkt,1)) - gg['ktbas'] = [] - gg['kbasprs'] = kbasprs - gg['dt'] = dtsim - + gg["k"] = [] + gg["dc"] = 0 + gg["kt"] = np.zeros((nkt, 1)) + gg["ktbas"] = [] + gg["kbasprs"] = kbasprs + gg["dt"] = dtsim + nkt = nkt - if flag_exp==0: - ktbas = makeBasis_StimKernel(kbasprs,nkt) + if flag_exp == 0: + ktbas = makeBasis_StimKernel(kbasprs, nkt) else: - ktbas = makeBasis_StimKernel_exp(kbasprs,nkt) - - gg['ktbas'] = ktbas - gg['k'] = gg['ktbas']*gg['kt'] - + ktbas = makeBasis_StimKernel_exp(kbasprs, nkt) + + gg["ktbas"] = ktbas + gg["k"] = gg["ktbas"] * gg["kt"] + return gg -def makeBasis_StimKernel(kbasprs,nkt): - - neye = kbasprs['neye'] - ncos = kbasprs['ncos'] - kpeaks = kbasprs['kpeaks'] + +def makeBasis_StimKernel(kbasprs, nkt): + neye = kbasprs["neye"] + ncos = kbasprs["ncos"] + kpeaks = kbasprs["kpeaks"] kdt = 1 - b = kbasprs['b'] - - yrnge = nlin(kpeaks + b*np.ones(np.shape(kpeaks))) - #db = np.diff(yrnge)/(ncos-1) - db = (yrnge[-1]-yrnge[0])/(ncos-1) + b = kbasprs["b"] + + yrnge = nlin(kpeaks + b * np.ones(np.shape(kpeaks))) + # db = np.diff(yrnge)/(ncos-1) + db = (yrnge[-1] - yrnge[0]) / (ncos - 1) ctrs = yrnge - mxt = invnl(yrnge[ncos-1]+2*db)-b + mxt = invnl(yrnge[ncos - 1] + 2 * db) - b print(mxt) - kt0 = np.arange(0,mxt,kdt) + kt0 = np.arange(0, mxt, kdt) nt = len(kt0) - e1 = np.tile(nlin(kt0+b*np.ones(np.shape(kt0))),(ncos,1)) - e2 = np.transpose(e1) - e3 = np.tile(ctrs,(nt,1)) + e1 = np.tile(nlin(kt0 + b * np.ones(np.shape(kt0))), (ncos, 1)) + e2 = np.transpose(e1) + e3 = np.tile(ctrs, (nt, 1)) kbasis0 = [] - for kk in range(ncos): - kbasis0.append(ff(e2[:,kk],e3[:,kk],db)) - - - #Concatenate identity vectors - nkt0 = np.size(kt0,0) - a1 = np.concatenate((np.eye(neye), np.zeros((nkt0,neye))),axis=0) - a2 = np.concatenate((np.zeros((neye,ncos)),np.array(kbasis0).T),axis=0) - kbasis = np.concatenate((a1,a2),axis=1) + for kk in range(ncos): + kbasis0.append(ff(e2[:, kk], e3[:, kk], db)) + + # Concatenate identity vectors + nkt0 = np.size(kt0, 0) + a1 = np.concatenate((np.eye(neye), np.zeros((nkt0, neye))), axis=0) + a2 = np.concatenate((np.zeros((neye, ncos)), np.array(kbasis0).T), axis=0) + kbasis = np.concatenate((a1, a2), axis=1) kbasis = np.flipud(kbasis) - nkt0 = np.size(kbasis,0) - + nkt0 = np.size(kbasis, 0) + if nkt0 < nkt: - kbasis = np.concatenate((np.zeros((nkt-nkt0,ncos+neye)),kbasis),axis=0) + kbasis = np.concatenate((np.zeros((nkt - nkt0, ncos + neye)), kbasis), axis=0) elif nkt0 > nkt: - kbasis = kbasis[-1-nkt:-1,:] + kbasis = kbasis[-1 - nkt : -1, :] kbasis = normalizecols(kbasis) - return kbasis + return kbasis -def makeBasis_StimKernel_exp(kbasprs,nkt): - ks = kbasprs['ks'] - x0 = np.arange(0,nkt) - kbasis = np.zeros((nkt,len(ks))) +def makeBasis_StimKernel_exp(kbasprs, nkt): + ks = kbasprs["ks"] + x0 = np.arange(0, nkt) + kbasis = np.zeros((nkt, len(ks))) for ii in range(len(ks)): - kbasis[:,ii] = invnl(-ks[ii]*x0) #(1.0/ks[ii])* - - kbasis = np.flipud(kbasis) - return kbasis - + kbasis[:, ii] = invnl(-ks[ii] * x0) # (1.0/ks[ii])* + + kbasis = np.flipud(kbasis) + return kbasis + + def nlin(x): eps = 1e-20 - return np.log(x+eps) - + return np.log(x + eps) + + def invnl(x): eps = 1e-20 - return np.exp(x)-eps - -def ff(x,c,dc): - rowsize = np.size(x,0) + return np.exp(x) - eps + + +def ff(x, c, dc): + rowsize = np.size(x, 0) m = [] - for i in range(rowsize): + for i in range(rowsize): xi = x[i] ci = c[i] - val=(np.cos(np.max([-pi,np.min([pi,(xi-ci)*pi/dc/2])]))+1)/2 + val = (np.cos(np.max([-pi, np.min([pi, (xi - ci) * pi / dc / 2])])) + 1) / 2 m.append(val) - + return np.array(m) - + + def normalizecols(A): - - B = A/np.tile(np.sqrt(sum(A**2,0)),(np.size(A,0),1)) - + B = A / np.tile(np.sqrt(sum(A**2, 0)), (np.size(A, 0), 1)) + return B - -def sameconv(A,B): - + + +def sameconv(A, B): am = np.size(A) bm = np.size(B) - nn = am+bm-1 - - q = npft.fft(A,nn)*npft.fft(np.flipud(B),nn) + nn = am + bm - 1 + + q = npft.fft(A, nn) * npft.fft(np.flipud(B), nn) p = q G = npft.ifft(p) G = G[range(am)] - - return G + return G diff --git a/allensdk/internal/model/biophysical/biophysical_archiver.py b/allensdk/internal/model/biophysical/biophysical_archiver.py index 3649368bb4..fe80cb15b7 100644 --- a/allensdk/internal/model/biophysical/biophysical_archiver.py +++ b/allensdk/internal/model/biophysical/biophysical_archiver.py @@ -1,5 +1,4 @@ -from allensdk.api.queries.biophysical_api import \ - BiophysicalApi +from allensdk.api.queries.biophysical_api import BiophysicalApi from allensdk.api.queries.cell_types_api import CellTypesApi from allensdk.api.queries.rma_api import RmaApi import os @@ -7,8 +6,8 @@ import shutil -#bp = BiophysicalApi('http://api.brain-map.org') -#bp.cache_stimulus = True # change to False to not download the large stimulus NWB file +# bp = BiophysicalApi('http://api.brain-map.org') +# bp.cache_stimulus = True # change to False to not download the large stimulus NWB file # neuronal_model_id = 472451419 # get this from the web site as above # bp.cache_data(neuronal_model_id, working_directory='neuronal_model') @@ -17,75 +16,80 @@ # Note, am I supposed to be only archiving Biophysical models or also GLIFs? + class BiophysicalArchiver(object): def __init__(self, archive_dir=None): - self.bp = BiophysicalApi('http://api.brain-map.org') - self.bp.cache_stimulus = True # change to False to not download the large stimulus NWB file + self.bp = BiophysicalApi("http://api.brain-map.org") + self.bp.cache_stimulus = True # change to False to not download the large stimulus NWB file self.cta = CellTypesApi() self.rma = RmaApi() - self.neuronal_model_download_endpoint = 'http://celltypes.brain-map.org/neuronal_model/download/' + self.neuronal_model_download_endpoint = "http://celltypes.brain-map.org/neuronal_model/download/" self.template_names = {} self.nwb_list = [] - + if archive_dir is None: - archive_dir = '.' + archive_dir = "." self.archive_dir = archive_dir - + def get_template_names(self): - template_response = self.rma.model_query('NeuronalModelTemplate') - self.template_names = { t['id']: str(t['name']).replace(' ', '_') for t in template_response} - + template_response = self.rma.model_query("NeuronalModelTemplate") + self.template_names = {t["id"]: str(t["name"]).replace(" ", "_") for t in template_response} + def get_cells(self): return self.cta.list_cells(True, True) - + def get_neuronal_models(self, specimen_ids): - return self.rma.model_query('NeuronalModel', - criteria='specimen[id$in%s]' % ','.join(str(i) for i in specimen_ids), - include='specimen', - num_rows='all') - + return self.rma.model_query( + "NeuronalModel", + criteria="specimen[id$in%s]" % ",".join(str(i) for i in specimen_ids), + include="specimen", + num_rows="all", + ) + def get_stimulus_file(self, neuronal_model_id): - result = self.rma.model_query('NeuronalModel', - criteria='[id$eq%d]' % (neuronal_model_id), - include="specimen(ephys_result(well_known_files(well_known_file_type[name$il'NWB*'])))", - tabular=['path']) - - stimulus_filename = result[0]['path'] - + result = self.rma.model_query( + "NeuronalModel", + criteria="[id$eq%d]" % (neuronal_model_id), + include="specimen(ephys_result(well_known_files(well_known_file_type[name$il'NWB*'])))", + tabular=["path"], + ) + + stimulus_filename = result[0]["path"] + return stimulus_filename - - - stimulus_filename = os.path.basename(result[0]['path']) - + + stimulus_filename = os.path.basename(result[0]["path"]) + return stimulus_filename - + def archive_cell(self, ephys_result_id, specimen_id, template, neuronal_model_id): url = self.neuronal_model_download_endpoint + "/%d" % (neuronal_model_id) - file_name = os.path.join(self.archive_dir, 'ephys_result_%d_specimen_%d_%s_neuronal_model_%d.zip' % (ephys_result_id, - specimen_id, - template, - neuronal_model_id)) + file_name = os.path.join( + self.archive_dir, + "ephys_result_%d_specimen_%d_%s_neuronal_model_%d.zip" + % (ephys_result_id, specimen_id, template, neuronal_model_id), + ) self.rma.retrieve_file_over_http(url, file_name) nwb_file = self.get_stimulus_file(neuronal_model_id) - shutil.copy(nwb_file, self.archive_dir) - self.nwb_list.append("%s\t%s" % (os.path.basename(nwb_file), - file_name)) - -if __name__ == '__main__': - archive_dir = sys.argv[-1] # /data/informatics/mousecelltypes/model_cache_may_2015 + shutil.copy(nwb_file, self.archive_dir) + self.nwb_list.append("%s\t%s" % (os.path.basename(nwb_file), file_name)) + + +if __name__ == "__main__": + archive_dir = sys.argv[-1] # /data/informatics/mousecelltypes/model_cache_may_2015 ba = BiophysicalArchiver(archive_dir) ba.get_template_names() cells = ba.get_cells() - - specimen_ids = (cell['id'] for cell in cells) + + specimen_ids = (cell["id"] for cell in cells) neuronal_models = ba.get_neuronal_models(specimen_ids) for nm in neuronal_models: - ephys_result_id = nm['specimen']['ephys_result_id'] - template_id = nm['neuronal_model_template_id'] + ephys_result_id = nm["specimen"]["ephys_result_id"] + template_id = nm["neuronal_model_template_id"] if template_id in ba.template_names: template = ba.template_names[template_id] else: - template = 'unknown' - ba.archive_cell(ephys_result_id, nm['specimen_id'], template, nm['id']) - with open(os.path.join(ba.archive_dir, 'STIMULUS.csv'), 'w') as f: - f.write("\n".join(ba.nwb_list)) \ No newline at end of file + template = "unknown" + ba.archive_cell(ephys_result_id, nm["specimen_id"], template, nm["id"]) + with open(os.path.join(ba.archive_dir, "STIMULUS.csv"), "w") as f: + f.write("\n".join(ba.nwb_list)) diff --git a/allensdk/internal/model/biophysical/check_fi_shift.py b/allensdk/internal/model/biophysical/check_fi_shift.py index 3b4b4d5f90..3bcca9f9c6 100755 --- a/allensdk/internal/model/biophysical/check_fi_shift.py +++ b/allensdk/internal/model/biophysical/check_fi_shift.py @@ -3,8 +3,8 @@ from allensdk.ephys.feature_extractor import EphysFeatureExtractor import allensdk.internal.model.biophysical.ephys_utils as ephys_utils -def calculate_fi_curves(data_set, sweeps): +def calculate_fi_curves(data_set, sweeps): sweep_type = "C1LSCOARSE" _, sweep_numbers, statuses = ephys_utils.get_sweeps_of_type(sweep_type, sweeps) features = EphysFeatureExtractor() @@ -13,7 +13,7 @@ def calculate_fi_curves(data_set, sweeps): sweep_status = dict(zip(sweep_numbers, statuses)) for s in sweep_numbers: - if sweep_status[s] in [ 'auto_failed', 'manual_failed' ]: + if sweep_status[s] in ["auto_failed", "manual_failed"]: continue v, i, t = ephys_utils.get_sweep_v_i_t_from_set(data_set, s) @@ -31,7 +31,7 @@ def calculate_fi_curves(data_set, sweeps): core2_amps = {} amp_list = [] for s in sweep_numbers: - if sweep_status[s] in [ 'auto_failed', 'manual_failed' ]: + if sweep_status[s] in ["auto_failed", "manual_failed"]: continue v, i, t = ephys_utils.get_sweep_v_i_t_from_set(data_set, s) @@ -55,10 +55,13 @@ def calculate_fi_curves(data_set, sweeps): stim_start, stim_dur, stim_amp, start_idx, end_idx = ephys_utils.get_step_stim_characteristics(i, t) features.process_instance(s, v, i, t, stim_start, stim_dur, "") core2_fi_curve.append((amp, features.feature_list[-1].mean["n_spikes"] / stim_dur)) - first_half_spike_count = len([spk for spk in features.feature_list[-1].mean["spikes"] if spk["t"] < stim_start + stim_dur / 2.0]) + first_half_spike_count = len( + [spk for spk in features.feature_list[-1].mean["spikes"] if spk["t"] < stim_start + stim_dur / 2.0] + ) core2_half_fi_curve.append((amp, first_half_spike_count / (stim_dur / 2.0))) - return { "coarse": coarse_fi_curve, "core2": core2_fi_curve, "core2_half": core2_half_fi_curve } + return {"coarse": coarse_fi_curve, "core2": core2_fi_curve, "core2_half": core2_half_fi_curve} + def estimate_fi_shift(data_set, sweeps): curve_data = calculate_fi_curves(data_set, sweeps) @@ -68,7 +71,7 @@ def estimate_fi_shift(data_set, sweeps): x = np.array([d[0] for d in coarse_fi_sorted], dtype=np.float64) y = np.array([d[1] for d in coarse_fi_sorted], dtype=np.float64) - if len(np.flatnonzero(y)) == 0: # original curve is all zero, so can't figure out shift + if len(np.flatnonzero(y)) == 0: # original curve is all zero, so can't figure out shift return np.nan, 0 last_zero_index = np.flatnonzero(y)[0] - 1 diff --git a/allensdk/internal/model/biophysical/deap_utils.py b/allensdk/internal/model/biophysical/deap_utils.py index 55f3e6addb..e9f75d0775 100644 --- a/allensdk/internal/model/biophysical/deap_utils.py +++ b/allensdk/internal/model/biophysical/deap_utils.py @@ -4,6 +4,7 @@ import numpy as np + class Utils(HocUtils): _log = logging.getLogger(__name__) @@ -41,16 +42,16 @@ def generate_morphology(self, morph_filename): def load_cell_parameters(self): cell = self.cell - passive = self.description.data['passive'][0] - conditions = self.description.data['conditions'][0] - channels = self.description.data['channels'] - addl_params = self.description.data['addl_params'] + passive = self.description.data["passive"][0] + conditions = self.description.data["conditions"][0] + channels = self.description.data["channels"] + addl_params = self.description.data["addl_params"] # Set passive properties for sec in cell.all: - sec.Ra = passive['ra'] - sec.cm = passive['cm'][sec.name().split(".")[1][:4]] - sec.insert('pas') + sec.Ra = passive["ra"] + sec.cm = passive["cm"][sec.name().split(".")[1][:4]] + sec.insert("pas") for seg in sec: seg.pas.e = passive["e_pas"] self.h.v_init = passive["e_pas"] @@ -69,14 +70,14 @@ def load_cell_parameters(self): sec.insert(ap["mechanism"]) # Set reversal potentials - for erev in conditions['erev']: + for erev in conditions["erev"]: sections = [s for s in cell.all if s.name().split(".")[1][:4] == erev["section"]] for sec in sections: sec.ena = erev["ena"] sec.ek = erev["ek"] def set_normalized_parameters(self, params): - channels_and_others = self.description.data['channels'] + self.description.data['addl_params'] + channels_and_others = self.description.data["channels"] + self.description.data["addl_params"] for i, p in enumerate(params): c = channels_and_others[i] value = p * (c["max"] - c["min"]) + c["min"] @@ -88,7 +89,7 @@ def set_normalized_parameters(self, params): setattr(sec, param_name, value) def set_actual_parameters(self, params): - channels_and_others = self.description.data['channels'] + self.description.data['addl_params'] + channels_and_others = self.description.data["channels"] + self.description.data["addl_params"] for i, p in enumerate(params): c = channels_and_others[i] sections = [s for s in self.cell.all if s.name().split(".")[1][:4] == c["section"]] @@ -100,7 +101,7 @@ def set_actual_parameters(self, params): def normalize_actual_parameters(self, params): params_array = np.array(params) - channels_and_others = self.description.data['channels'] + self.description.data['addl_params'] + channels_and_others = self.description.data["channels"] + self.description.data["addl_params"] max_vals = np.array([c["max"] for c in channels_and_others]) min_vals = np.array([c["min"] for c in channels_and_others]) @@ -109,7 +110,7 @@ def normalize_actual_parameters(self, params): def actual_parameters_from_normalized(self, params): actual_params = [] - channels_and_others = self.description.data['channels'] + self.description.data['addl_params'] + channels_and_others = self.description.data["channels"] + self.description.data["addl_params"] for i, p in enumerate(params): c = channels_and_others[i] value = p * (c["max"] - c["min"]) + c["min"] @@ -137,7 +138,7 @@ def calculate_feature_errors(self, t_ms, v, i): delay = self.stim.delay * 1e-3 duration = self.stim.dur * 1e-3 t = t_ms * 1e-3 - feature_names = self.description.data['features'] + feature_names = self.description.data["features"] # penalize for failing to return to rest start_index = np.flatnonzero(t >= delay)[0] @@ -149,18 +150,20 @@ def calculate_feature_errors(self, t_ms, v, i): if swp.sweep_feature("avg_rate") > 0: fail_trace = True - target_features = self.description.data['target_features'] + target_features = self.description.data["target_features"] target_features_dict = {f["name"]: {"mean": f["mean"], "stdev": f["stdev"]} for f in target_features} if not fail_trace: swp = EphysSweepFeatureExtractor(t, v, i, start=delay, end=(delay + duration), filter=None) swp.process_spikes() - if len(swp.spikes()) < minimum_num_spikes: # Enough spikes? + if len(swp.spikes()) < minimum_num_spikes: # Enough spikes? fail_trace = True else: - avg_per_spike_peak_error = np.mean([abs(spk["peak_v"] - target_features_dict["peak_v"]["mean"]) for spk in swp.spikes()]) + avg_per_spike_peak_error = np.mean( + [abs(spk["peak_v"] - target_features_dict["peak_v"]["mean"]) for spk in swp.spikes()] + ) avg_overall_error = abs(target_features_dict["peak_v"]["mean"] - swp.spike_feature("peak_v").mean()) - if avg_per_spike_peak_error > 3.0 * avg_overall_error: # Weird bi-modality of spikes; 3.0 is arbitrary + if avg_per_spike_peak_error > 3.0 * avg_overall_error: # Weird bi-modality of spikes; 3.0 is arbitrary fail_trace = True if fail_trace: @@ -180,18 +183,18 @@ def calculate_feature_errors(self, t_ms, v, i): slow_trough_t = swp.spike_feature("slow_trough_t") delta_t = slow_trough_t - fast_trough_t - delta_t[np.isnan(delta_t)] = 0. + delta_t[np.isnan(delta_t)] = 0.0 other_features["slow_trough_delta_time"] = np.mean(delta_t[:-1] / np.diff(threshold_t)) fast_trough_v = swp.spike_feature("fast_trough_v") slow_trough_v = swp.spike_feature("slow_trough_v") delta_v = fast_trough_v - slow_trough_v - delta_v[np.isnan(delta_v)] = 0. + delta_v[np.isnan(delta_v)] = 0.0 other_features["slow_trough_delta_v"] = delta_v.mean() for f in feature_names: - target_mean = target_features_dict[f]['mean'] - target_stdev = target_features_dict[f]['stdev'] + target_mean = target_features_dict[f]["mean"] + target_stdev = target_features_dict[f]["stdev"] if target_stdev == 0: print("Feature with 0 stdev: ", f) diff --git a/allensdk/internal/model/biophysical/ephys_utils.py b/allensdk/internal/model/biophysical/ephys_utils.py index 1f02307d76..0e176f6e50 100644 --- a/allensdk/internal/model/biophysical/ephys_utils.py +++ b/allensdk/internal/model/biophysical/ephys_utils.py @@ -1,22 +1,25 @@ import numpy as np + def get_sweep_v_i_t_from_set(data_set, sweep_number): sweep_data = data_set.get_sweep(sweep_number) - i = sweep_data["stimulus"] # in A - v = sweep_data["response"] # in V - i *= 1e12 # to pA - v *= 1e3 # to mV - sampling_rate = sweep_data["sampling_rate"] # in Hz + i = sweep_data["stimulus"] # in A + v = sweep_data["response"] # in V + i *= 1e12 # to pA + v *= 1e3 # to mV + sampling_rate = sweep_data["sampling_rate"] # in Hz t = np.arange(0, len(v)) * (1.0 / sampling_rate) return v, i, t + def get_sweeps_of_type(sweep_type, sweeps): - sweeps = [ s for s in sweeps if s['ephys_stimulus']['description'].startswith( sweep_type )] - sweep_numbers = [ s['sweep_number'] for s in sweeps ] - statuses = [ s['workflow_state'] for s in sweeps ] - + sweeps = [s for s in sweeps if s["ephys_stimulus"]["description"].startswith(sweep_type)] + sweep_numbers = [s["sweep_number"] for s in sweeps] + statuses = [s["workflow_state"] for s in sweeps] + return sweeps, sweep_numbers, statuses + def get_step_stim_characteristics(i, t): # Assumes that there is a test pulse followed by the stimulus step di = np.diff(i) @@ -27,10 +30,10 @@ def get_step_stim_characteristics(i, t): if len(up_idx) < 2 or len(down_idx) < 2: return (np.nan, np.nan, np.nan, np.nan, np.nan) - if up_idx[1] < down_idx[1]: # positive step - start_idx = up_idx[1] + 1 # shift by one to compensate for diff() + if up_idx[1] < down_idx[1]: # positive step + start_idx = up_idx[1] + 1 # shift by one to compensate for diff() end_idx = down_idx[1] + 1 - else: # negative step + else: # negative step start_idx = down_idx[1] + 1 end_idx = up_idx[1] + 1 stim_start = float(t[start_idx]) diff --git a/allensdk/internal/model/biophysical/fit_stage_1.py b/allensdk/internal/model/biophysical/fit_stage_1.py index 291ee8b224..0479dd0c33 100644 --- a/allensdk/internal/model/biophysical/fit_stage_1.py +++ b/allensdk/internal/model/biophysical/fit_stage_1.py @@ -7,8 +7,7 @@ from collections import Counter import subprocess -from allensdk.ephys.ephys_extractor \ - import EphysSweepFeatureExtractor, EphysSweepSetFeatureExtractor +from allensdk.ephys.ephys_extractor import EphysSweepFeatureExtractor, EphysSweepSetFeatureExtractor import allensdk.core.json_utilities as ju from allensdk.core.nwb_data_set import NwbDataSet import allensdk.internal.model.biophysical.optimize as optimize @@ -17,10 +16,11 @@ SEEDS = [1234, 1001, 4321, 1024, 2048] FIT_BASE_DIR = os.path.join(os.path.dirname(__file__), "fits") APICAL_DENDRITE_TYPE = 4 -MPIEXEC = 'mpiexec' +MPIEXEC = "mpiexec" DEFAULT_NUM_PROCESSES = 240 -_fit_stage_1_log = logging.getLogger('allensdk.model.biophysical.fit_stage_1') +_fit_stage_1_log = logging.getLogger("allensdk.model.biophysical.fit_stage_1") + def find_core1_trace(data_set, all_sweeps): sweep_type = "C1LSCOARSE" @@ -32,28 +32,37 @@ def find_core1_trace(data_set, all_sweeps): if sweep_status[s][-6:] == "failed": continue v, i, t = ephys_utils.get_sweep_v_i_t_from_set(data_set, s) - if np.all(v[-100:] == 0): # Check for early termination of sweep + if np.all(v[-100:] == 0): # Check for early termination of sweep continue stim_start, stim_dur, stim_amp, start_idx, end_idx = ephys_utils.get_step_stim_characteristics(i, t) swp = EphysSweepFeatureExtractor(t, v, i, start=stim_start, end=(stim_start + stim_dur)) swp.process_spikes() isi_cv = swp.sweep_feature("isi_cv", allow_missing=True) - sweep_info[s] = {"amp": stim_amp, - "n_spikes": len(swp.spikes()), - "quality": is_trace_good_quality(v, i, t), - "isi_cv": isi_cv} - + sweep_info[s] = { + "amp": stim_amp, + "n_spikes": len(swp.spikes()), + "quality": is_trace_good_quality(v, i, t), + "isi_cv": isi_cv, + } + rheobase_amp = 1e12 for s in sweep_info: if sweep_info[s]["amp"] < rheobase_amp and sweep_info[s]["n_spikes"] > 0: rheobase_amp = sweep_info[s]["amp"] sweep_to_use_amp = 1e12 - sweep_to_use_isi_cv = 1e121 + sweep_to_use_isi_cv = 1e121 sweep_to_use = -1 for s in sweep_info: - if sweep_info[s]["quality"] and sweep_info[s]["amp"] >= 39.0 + rheobase_amp and sweep_info[s]["isi_cv"] < 1.2 * sweep_to_use_isi_cv: + if ( + sweep_info[s]["quality"] + and sweep_info[s]["amp"] >= 39.0 + rheobase_amp + and sweep_info[s]["isi_cv"] < 1.2 * sweep_to_use_isi_cv + ): use_new_sweep = False - if sweep_to_use_isi_cv > 0.3 and ((sweep_to_use_isi_cv - sweep_info[s]["isi_cv"]) / sweep_to_use_isi_cv) >= 0.2: + if ( + sweep_to_use_isi_cv > 0.3 + and ((sweep_to_use_isi_cv - sweep_info[s]["isi_cv"]) / sweep_to_use_isi_cv) >= 0.2 + ): use_new_sweep = True elif sweep_info[s]["amp"] < sweep_to_use_amp: use_new_sweep = True @@ -62,13 +71,14 @@ def find_core1_trace(data_set, all_sweeps): sweep_to_use = s sweep_to_use_amp = sweep_info[s]["amp"] sweep_to_use_isi_cv = sweep_info[s]["isi_cv"] - + if sweep_to_use == -1: _fit_stage_1_log.warn("Could not find appropriate core 1 sweep!") return [] else: return [sweep_to_use] + def find_core2_trace(data_set, all_sweeps): sweep_type = "C2SQRHELNG" _, sweeps, statuses = ephys_utils.get_sweeps_of_type(sweep_type, all_sweeps) @@ -112,6 +122,7 @@ def find_core2_trace(data_set, all_sweeps): sweeps_to_fit.append(k) return sweeps_to_fit + def is_trace_good_quality(v, i, t): stim_start, stim_dur, stim_amp, start_idx, end_idx = ephys_utils.get_step_stim_characteristics(i, t) swp = EphysSweepFeatureExtractor(t, v, i, start=stim_start, end=(stim_start + stim_dur)) @@ -124,8 +135,10 @@ def is_trace_good_quality(v, i, t): return False time_to_end = stim_start + stim_dur - spikes[-1]["threshold_t"] - avg_end_isi = (((spikes[-1]["threshold_t"] - spikes[-2]["threshold_t"]) + - (spikes[-2]["threshold_t"] - spikes[-3]["threshold_t"])) / 2.0) + avg_end_isi = ( + (spikes[-1]["threshold_t"] - spikes[-2]["threshold_t"]) + + (spikes[-2]["threshold_t"] - spikes[-3]["threshold_t"]) + ) / 2.0 if time_to_end > 2 * avg_end_isi: return False @@ -140,7 +153,7 @@ def is_trace_good_quality(v, i, t): def check_for_pause(isis): if len(isis) <= 2: return False - + for i, isi in enumerate(isis[1:-1]): if isi > 3 * isis[i + 1 - 1] and isi > 3 * isis[i + 1 + 1]: return True @@ -149,27 +162,27 @@ def check_for_pause(isis): def collect_target_features(ft): min_std_dict = { - 'avg_rate': 0.5, - 'adapt': 0.001, - 'peak_v': 2.0, - 'trough_v': 2.0, - 'fast_trough_v': 2.0, - 'slow_trough_delta_v': 2.0, - 'slow_trough_delta_time': 0.05, - 'latency': 5.0, - 'isi_cv': 0.1, - 'mean_isi': 0.5, - 'first_isi': 1.0, - 'time_to_end': 50.0, - 'v_baseline': 2.0, - 'width': 0.0001, - 'upstroke': 50.0, - 'downstroke': 50.0, - 'upstroke_v': 2.0, - 'downstroke_v': 2.0, - 'threshold_v': 2.0, - 'peak_to_fast_tr_time': 0.0005, - 'phase_slope': 5.0, + "avg_rate": 0.5, + "adapt": 0.001, + "peak_v": 2.0, + "trough_v": 2.0, + "fast_trough_v": 2.0, + "slow_trough_delta_v": 2.0, + "slow_trough_delta_time": 0.05, + "latency": 5.0, + "isi_cv": 0.1, + "mean_isi": 0.5, + "first_isi": 1.0, + "time_to_end": 50.0, + "v_baseline": 2.0, + "width": 0.0001, + "upstroke": 50.0, + "downstroke": 50.0, + "upstroke_v": 2.0, + "downstroke_v": 2.0, + "threshold_v": 2.0, + "peak_to_fast_tr_time": 0.0005, + "phase_slope": 5.0, } target_features = [] @@ -182,20 +195,20 @@ def collect_target_features(ft): def prepare_stage_1(description, passive_fit_data): - output_directory = description.manifest.get_path('WORKDIR') - neuronal_model_data = ju.read(description.manifest.get_path('neuronal_model_data')) - specimen_data = neuronal_model_data['specimen'] - is_spiny = not any(t['name'] == u'dendrite type - aspiny' for t in specimen_data['specimen_tags']) - all_sweeps = specimen_data['ephys_sweeps'] - data_set = NwbDataSet(description.manifest.get_path('stimulus_path')) - swc_path = description.manifest.get_path('MORPHOLOGY') - + output_directory = description.manifest.get_path("WORKDIR") + neuronal_model_data = ju.read(description.manifest.get_path("neuronal_model_data")) + specimen_data = neuronal_model_data["specimen"] + is_spiny = not any(t["name"] == "dendrite type - aspiny" for t in specimen_data["specimen_tags"]) + all_sweeps = specimen_data["ephys_sweeps"] + data_set = NwbDataSet(description.manifest.get_path("stimulus_path")) + swc_path = description.manifest.get_path("MORPHOLOGY") + if not os.path.exists(output_directory): os.makedirs(output_directory) - ra = passive_fit_data['ra'] - cm1 = passive_fit_data['cm1'] - cm2 = passive_fit_data['cm2'] + ra = passive_fit_data["ra"] + cm1 = passive_fit_data["cm1"] + cm2 = passive_fit_data["cm2"] # Check for fi curve shift to decide to use core1 or core2 fi_shift, n_core2 = check_fi_shift.estimate_fi_shift(data_set, all_sweeps) @@ -247,19 +260,23 @@ def prepare_stage_1(description, passive_fit_data): slow_trough_t = swp.spike_feature("slow_trough_t") delta_t = slow_trough_t - fast_trough_t - delta_t[np.isnan(delta_t)] = 0. + delta_t[np.isnan(delta_t)] = 0.0 sweep_avg_slow_trough_delta_time.append(np.mean(delta_t[:-1] / np.diff(threshold_t))) fast_trough_v = swp.spike_feature("fast_trough_v") slow_trough_v = swp.spike_feature("slow_trough_v") delta_v = fast_trough_v - slow_trough_v - delta_v[np.isnan(delta_v)] = 0. + delta_v[np.isnan(delta_v)] = 0.0 sweep_avg_slow_trough_delta_v.append(delta_v.mean()) - ft["slow_trough_delta_time"] = {"mean": float(np.mean(sweep_avg_slow_trough_delta_time)), - "stdev": float(np.std(sweep_avg_slow_trough_delta_time))} - ft["slow_trough_delta_v"] = {"mean": float(np.mean(sweep_avg_slow_trough_delta_v)), - "stdev": float(np.std(sweep_avg_slow_trough_delta_v))} + ft["slow_trough_delta_time"] = { + "mean": float(np.mean(sweep_avg_slow_trough_delta_time)), + "stdev": float(np.std(sweep_avg_slow_trough_delta_time)), + } + ft["slow_trough_delta_v"] = { + "mean": float(np.mean(sweep_avg_slow_trough_delta_v)), + "stdev": float(np.std(sweep_avg_slow_trough_delta_v)), + } baseline_v = float(ext.sweep_features("v_baseline").mean()) passive_fit_data["e_pas"] = baseline_v @@ -277,13 +294,13 @@ def prepare_stage_1(description, passive_fit_data): max_i = 0 for s in all_sweeps: try: - v, i, t = ephys_utils.get_sweep_v_i_t_from_set(data_set, s['sweep_number']) + v, i, t = ephys_utils.get_sweep_v_i_t_from_set(data_set, s["sweep_number"]) except Exception: pass if np.max(i) > max_i: max_i = np.max(i) - max_i += 10 # add 10 pA - max_i *= 1e-3 # convert to nA + max_i += 10 # add 10 pA + max_i *= 1e-3 # convert to nA # ----------- Generate output and submit jobs --------------- @@ -307,13 +324,9 @@ def prepare_stage_1(description, passive_fit_data): # Collect and save data for target.json file target_dict = {} - target_dict["passive"] = [{ - "ra": ra, - "cm": { "soma": cm1, "axon": cm1, "dend": cm2 }, - "e_pas": baseline_v - }] + target_dict["passive"] = [{"ra": ra, "cm": {"soma": cm1, "axon": cm1, "dend": cm2}, "e_pas": baseline_v}] - swc_data = pd.read_table(swc_path, sep='\s', comment='#', header=None) + swc_data = pd.read_table(swc_path, sep="\s", comment="#", header=None) has_apic = False if APICAL_DENDRITE_TYPE in pd.unique(swc_data[1]): has_apic = True @@ -324,31 +337,27 @@ def prepare_stage_1(description, passive_fit_data): if has_apic: target_dict["passive"][0]["cm"]["apic"] = cm2 - target_dict["fitting"] = [{ - "junction_potential": jxn, - "sweeps": sweeps_to_fit, - "passive_fit_info": passive_fit_data, - "max_stim_test_na": max_i, - }] + target_dict["fitting"] = [ + { + "junction_potential": jxn, + "sweeps": sweeps_to_fit, + "passive_fit_info": passive_fit_data, + "max_stim_test_na": max_i, + } + ] - target_dict["stimulus"] = [{ - "amplitude": 1e-3 * stim_amp, - "delay": 1000.0, - "duration": 1e3 * stim_dur - }] + target_dict["stimulus"] = [{"amplitude": 1e-3 * stim_amp, "delay": 1000.0, "duration": 1e3 * stim_dur}] target_dict["manifest"] = [] target_dict["manifest"].append({"type": "file", "spec": swc_path, "key": "MORPHOLOGY"}) target_dict["target_features"] = collect_target_features(ft) - target_file = os.path.join(output_directory, 'target.json') + target_file = os.path.join(output_directory, "target.json") ju.write(target_file, target_dict) # Create config.json for each fit type - config_base_data = ju.read(os.path.join(FIT_BASE_DIR, - 'config_base.json')) - + config_base_data = ju.read(os.path.join(FIT_BASE_DIR, "config_base.json")) jobs = [] for fit_type in fit_types: @@ -356,9 +365,9 @@ def prepare_stage_1(description, passive_fit_data): fit_type_dir = os.path.join(output_directory, fit_type) config_path = os.path.join(fit_type_dir, "config.json") - config["biophys"][0]["model_file"] = [ target_file, config_path] + config["biophys"][0]["model_file"] = [target_file, config_path] if has_apic: - fit_style_file = os.path.join(FIT_BASE_DIR, 'fit_styles', '%s_fit_style.json' % (fit_type)) + fit_style_file = os.path.join(FIT_BASE_DIR, "fit_styles", "%s_fit_style.json" % (fit_type)) else: fit_style_file = os.path.join(FIT_BASE_DIR, "fit_styles", "%s_noapic_fit_style.json" % (fit_type)) @@ -367,29 +376,33 @@ def prepare_stage_1(description, passive_fit_data): ju.write(config_path, config) for seed in SEEDS: - logfile = os.path.join(output_directory, fit_type, 's%d' % seed, 'stage_1.log') - jobs.append({ - 'config_path': os.path.abspath(config_path), - 'fit_type': fit_type, - 'log': os.path.abspath(logfile), - 'seed': seed, - 'num_processes': DEFAULT_NUM_PROCESSES - }) + logfile = os.path.join(output_directory, fit_type, "s%d" % seed, "stage_1.log") + jobs.append( + { + "config_path": os.path.abspath(config_path), + "fit_type": fit_type, + "log": os.path.abspath(logfile), + "seed": seed, + "num_processes": DEFAULT_NUM_PROCESSES, + } + ) return jobs def run_stage_1(jobs): for job in jobs: - args = [MPIEXEC, - '-np', - str(job['num_processes']), - sys.executable, - '-m', - optimize.__name__, - str(job['seed']), - job['config_path'], - str(optimize.DEFAULT_NGEN), - str(optimize.DEFAULT_MU)] + args = [ + MPIEXEC, + "-np", + str(job["num_processes"]), + sys.executable, + "-m", + optimize.__name__, + str(job["seed"]), + job["config_path"], + str(optimize.DEFAULT_NGEN), + str(optimize.DEFAULT_MU), + ] _fit_stage_1_log.debug(args) - with open(job['log'], "w") as outfile: - subprocess.call(args, stdout=outfile) \ No newline at end of file + with open(job["log"], "w") as outfile: + subprocess.call(args, stdout=outfile) diff --git a/allensdk/internal/model/biophysical/fit_stage_2.py b/allensdk/internal/model/biophysical/fit_stage_2.py index 380ae9e9f0..5cc2a04680 100755 --- a/allensdk/internal/model/biophysical/fit_stage_2.py +++ b/allensdk/internal/model/biophysical/fit_stage_2.py @@ -11,17 +11,18 @@ FIT_TYPES = {"f6": "f9", "f12": "f13"} DEFAULT_NUM_PROCESSES = 240 -_fit_stage_2_log = logging.getLogger('allensdk.model.biophysical.fit_stage_2') +_fit_stage_2_log = logging.getLogger("allensdk.model.biophysical.fit_stage_2") + def prepare_stage_2(output_directory): - config_base_data = json_utilities.read(os.path.join(FIT_BASE_DIR, 'config_base.json')) + config_base_data = json_utilities.read(os.path.join(FIT_BASE_DIR, "config_base.json")) jobs = [] for fit_type in FIT_TYPES: best_error = 1e12 best_seed = 0 - + fit_type_dir = os.path.join(output_directory, fit_type) if not os.path.exists(fit_type_dir): @@ -56,13 +57,13 @@ def prepare_stage_2(output_directory): config = config_base_data.copy() config_path = os.path.join(new_fit_type_dir, "config.json") - config["biophys"][0]["model_file"] = [ target_file, config_path] + config["biophys"][0]["model_file"] = [target_file, config_path] if has_apic: fit_style_file = os.path.join(FIT_BASE_DIR, "fit_styles", FIT_TYPES[fit_type] + "_fit_style.json") else: fit_style_file = os.path.join(FIT_BASE_DIR, "fit_styles", FIT_TYPES[fit_type] + "_noapic_fit_style.json") - config["biophys"][0]["model_file"].append( fit_style_file ) + config["biophys"][0]["model_file"].append(fit_style_file) config["manifest"].append({"type": "dir", "spec": new_fit_type_dir, "key": "FITDIR"}) config["manifest"].append({"type": "file", "spec": start_pop_file, "key": "STARTPOP"}) @@ -70,46 +71,51 @@ def prepare_stage_2(output_directory): json_utilities.write(config_path, config) for seed in SEEDS: - logfile = os.path.join(new_fit_type_dir, 's%d' % seed, 'stage_2.log') - - jobs.append({ - 'config_path': os.path.abspath(config_path), - 'fit_type': fit_type, - 'log': os.path.abspath(logfile), - 'seed': seed, - 'num_processes': DEFAULT_NUM_PROCESSES - }) + logfile = os.path.join(new_fit_type_dir, "s%d" % seed, "stage_2.log") + + jobs.append( + { + "config_path": os.path.abspath(config_path), + "fit_type": fit_type, + "log": os.path.abspath(logfile), + "seed": seed, + "num_processes": DEFAULT_NUM_PROCESSES, + } + ) return jobs def run_stage_2(jobs): for job in jobs: - args = [MPIEXEC, - '-np', str(job['num_processes']), - sys.executable, - '-m', - optimize.__name__, - str(job['seed']), - job['config_path'], - str(optimize.DEFAULT_NGEN), - str(optimize.DEFAULT_MU)] + args = [ + MPIEXEC, + "-np", + str(job["num_processes"]), + sys.executable, + "-m", + optimize.__name__, + str(job["seed"]), + job["config_path"], + str(optimize.DEFAULT_NGEN), + str(optimize.DEFAULT_MU), + ] _fit_stage_2_log.debug(args) - with open(job['log'], "w") as outfile: + with open(job["log"], "w") as outfile: subprocess.call(args, stdout=outfile) def main(): - parser = argparse.ArgumentParser(description='Set up DEAP-style fit for second stage') - parser.add_argument('--output_dir', required=True) - parser.add_argument('specimen_id', type=int) + parser = argparse.ArgumentParser(description="Set up DEAP-style fit for second stage") + parser.add_argument("--output_dir", required=True) + parser.add_argument("specimen_id", type=int) args = parser.parse_args() - output_directory = os.path.join(args.output_dir, 'specimen_%d' % args.specimen_id) + output_directory = os.path.join(args.output_dir, "specimen_%d" % args.specimen_id) jobs = prepare_stage_2(output_directory) run_stage_2(jobs) -if __name__ == "__main__": - main() +if __name__ == "__main__": + main() diff --git a/allensdk/internal/model/biophysical/make_deap_fit_json.py b/allensdk/internal/model/biophysical/make_deap_fit_json.py index 84e0754d47..18384c16f4 100755 --- a/allensdk/internal/model/biophysical/make_deap_fit_json.py +++ b/allensdk/internal/model/biophysical/make_deap_fit_json.py @@ -8,59 +8,53 @@ from allensdk.internal.model.biophysical import ephys_utils from allensdk.internal.model.biophysical.deap_utils import Utils + class Report: - _log = logging.getLogger('allensdk.model.biophysical.make_deap_fit_json') + _log = logging.getLogger("allensdk.model.biophysical.make_deap_fit_json") - def __init__(self, - top_level_description, - fit_type): + def __init__(self, top_level_description, fit_type): self.utils = None self.top_level_description = top_level_description self.description = None self.manifest = None - self.specimen_id = str(self.top_level_description.data['runs'][0]['specimen_id']) + self.specimen_id = str(self.top_level_description.data["runs"][0]["specimen_id"]) self.fit_type = fit_type - self.target_path = self.top_level_description.manifest.get_path('target_path') + self.target_path = self.top_level_description.manifest.get_path("target_path") self.target = ju.read(self.target_path) - + self.seeds = [1234, 1001, 4321, 1024, 2048] - self.org_selections = [0, 100, 500, 1000] # Picks thek best, 100th best, etc. organisms as examples + self.org_selections = [0, 100, 500, 1000] # Picks thek best, 100th best, etc. organisms as examples self.trace_colors = ["#1b9e77", "#d95f02", "#7570b3", "#e7298a"] - - self.config_path = self.top_level_description.manifest.get_path('fit_config_json', - self.fit_type) + + self.config_path = self.top_level_description.manifest.get_path("fit_config_json", self.fit_type) self.fit_config = Config().load(self.config_path) - + fit_style_path = self.fit_config.data["biophys"][0]["model_file"][-1] - + self.fit_style_info = ju.read(fit_style_path) self.used_features = self.fit_style_info["features"] self.all_params = self.fit_style_info["channels"] + self.fit_style_info["addl_params"] - - nwb_path = self.top_level_description.manifest.get_path('stimulus_path') + + nwb_path = self.top_level_description.manifest.get_path("stimulus_path") self.data_set = NwbDataSet(nwb_path) - - self.neuronal_model_data = ju.read(self.top_level_description.manifest.get_path('neuronal_model_data')) - self.specimen_data = self.neuronal_model_data['specimen'] - self.all_sweeps = self.specimen_data['ephys_sweeps'] - - - + + self.neuronal_model_data = ju.read(self.top_level_description.manifest.get_path("neuronal_model_data")) + self.specimen_data = self.neuronal_model_data["specimen"] + self.all_sweeps = self.specimen_data["ephys_sweeps"] + def best_fit_value(self): return self.all_hof_fit_errors[self.sorted_indexes[self.org_selections[0]]] - - + def generate_fit_file(self): self.gather_from_seeds() self.setup_model() self.check_org_selections_for_noise_block() self.make_fit_json_file() - - + def make_fit_json_file(self): json_data = {} - + # passive json_data["passive"] = [{}] json_data["passive"][0]["ra"] = self.target["passive"][0]["ra"] @@ -68,16 +62,16 @@ def make_fit_json_file(self): json_data["passive"][0]["cm"] = [] for k in self.target["passive"][0]["cm"]: json_data["passive"][0]["cm"].append({"section": k, "cm": self.target["passive"][0]["cm"][k]}) - + # fitting json_data["fitting"] = [{}] json_data["fitting"][0]["sweeps"] = self.target["fitting"][0]["sweeps"] json_data["fitting"][0]["junction_potential"] = self.target["fitting"][0]["junction_potential"] - + # conditions json_data["conditions"] = self.fit_style_info["conditions"] json_data["conditions"][0]["v_init"] = self.target["passive"][0]["e_pas"] - + # genome json_data["genome"] = [] genome_vals = self.all_hof_fits[self.sorted_indexes[self.org_selections[0]], :] @@ -86,51 +80,41 @@ def make_fit_json_file(self): param_name = p["parameter"] + "_" + p["mechanism"] else: param_name = p["parameter"] - json_data["genome"].append({"value": genome_vals[i], - "section": p["section"], - "name": param_name, - "mechanism": p["mechanism"] - }) - + json_data["genome"].append( + {"value": genome_vals[i], "section": p["section"], "name": param_name, "mechanism": p["mechanism"]} + ) + # write out file - with open(self.top_level_description.manifest.get_path('output_fit_file', - self.specimen_id, - self.fit_type), "w") as f: + with open( + self.top_level_description.manifest.get_path("output_fit_file", self.specimen_id, self.fit_type), "w" + ) as f: json.dump(json_data, f, indent=2) - - + def setup_model(self): - morphology_path = os.path.realpath(self.top_level_description.manifest.get_path('MORPHOLOGY')) + morphology_path = os.path.realpath(self.top_level_description.manifest.get_path("MORPHOLOGY")) self.utils = Utils(self.fit_config) h = self.utils.h self.utils.generate_morphology(morphology_path) self.utils.load_cell_parameters() self.utils.insert_iclamp() self.stim_params = self.fit_config.data["stimulus"][0] - self.utils.set_iclamp_params(self.stim_params["amplitude"], - self.stim_params["delay"], - self.stim_params["duration"]) + self.utils.set_iclamp_params( + self.stim_params["amplitude"], self.stim_params["delay"], self.stim_params["duration"] + ) h.tstop = self.stim_params["delay"] * 2.0 + self.stim_params["duration"] h.cvode_active(1) h.cvode.atolscale("cai", 1e-4) h.cvode.maxstep(10) - - + def gather_from_seeds(self): first_created = False for s in self.seeds: - final_hof_fit_path = \ - self.top_level_description.manifest.get_path('final_hof_fit', - self.fit_type, - s) - final_hof_path = \ - self.top_level_description.manifest.get_path('final_hof', - self.fit_type, - s) + final_hof_fit_path = self.top_level_description.manifest.get_path("final_hof_fit", self.fit_type, s) + final_hof_path = self.top_level_description.manifest.get_path("final_hof", self.fit_type, s) if not os.path.exists(final_hof_fit_path): Report._log.warn("Could not find output file %s for seed %d" % (final_hof_fit_path, s)) continue - + hof_fit_errors = np.loadtxt(final_hof_fit_path) hof_fits = np.loadtxt(final_hof_path) if not first_created: @@ -144,13 +128,12 @@ def gather_from_seeds(self): self.all_hof_fit_errors = all_hof_fit_errors self.sorted_indexes = np.argsort(self.all_hof_fit_errors) - def check_org_selections_for_noise_block(self): h = self.utils.h v_vec, i_vec, t_vec = self.utils.record_values() - depol_block_threshold = -50.0 # mV - block_min_duration = 50.0 # ms + depol_block_threshold = -50.0 # mV + block_min_duration = 50.0 # ms h.cvode_active(0) noise_i_stim = [] @@ -197,8 +180,9 @@ def check_org_selections_for_noise_block(self): use_ii = ii break h.cvode_active(1) - self.utils.set_iclamp_params(self.stim_params["amplitude"], self.stim_params["delay"], - self.stim_params["duration"]) + self.utils.set_iclamp_params( + self.stim_params["amplitude"], self.stim_params["delay"], self.stim_params["duration"] + ) self.utils.h.tstop = self.stim_params["delay"] * 2.0 + self.stim_params["duration"] if use_ii == -1: diff --git a/allensdk/internal/model/biophysical/neuron_parallel.py b/allensdk/internal/model/biophysical/neuron_parallel.py index b3fd0c9312..5222545324 100644 --- a/allensdk/internal/model/biophysical/neuron_parallel.py +++ b/allensdk/internal/model/biophysical/neuron_parallel.py @@ -1,14 +1,15 @@ from neuron import h import logging -_neuron_parallel_log = logging.getLogger('allensdk.model.biophysical.neuron_parallel') +_neuron_parallel_log = logging.getLogger("allensdk.model.biophysical.neuron_parallel") _pc = h.ParallelContext() + def map(func, *iterables): start_time = pc_time() userids = [] - userid = 200 # arbitrary, but needs to be a positive integer + userid = 200 # arbitrary, but needs to be a positive integer for args in zip(*iterables): args2 = (list(a) for a in args) _pc.submit(userid, func, *args2) @@ -19,28 +20,31 @@ def map(func, *iterables): _neuron_parallel_log.debug("Map took %s" % (str(end_time - start_time))) return [results[userid] for userid in userids] + def working(): while _pc.working(): userid = int(_pc.userid()) ret = _pc.pyret() yield userid, ret + def runworker(): _pc.runworker() + def done(): _pc.done() + def pc_time(): return _pc.time() + def reset_neuron_library(): - ''' + """ See Also: https://www.neuron.yale.edu/phpBB/viewtopic.php?f=2&t=2367 - ''' + """ _pc.gid_clear() - + for sec in h.allsec(): - h("%s{delete_section()}" % (sec.name()) ) - - \ No newline at end of file + h("%s{delete_section()}" % (sec.name())) diff --git a/allensdk/internal/model/biophysical/optimize.py b/allensdk/internal/model/biophysical/optimize.py index 90830f143f..79775ea1d6 100755 --- a/allensdk/internal/model/biophysical/optimize.py +++ b/allensdk/internal/model/biophysical/optimize.py @@ -1,4 +1,4 @@ -from mpi4py import MPI # needed for NEURON parallel execution +from mpi4py import MPI # needed for NEURON parallel execution import os from allensdk.internal.model.biophysical.deap_utils import Utils from . import neuron_parallel @@ -17,7 +17,7 @@ DEFAULT_MU = 1200 -_optimize_log = logging.getLogger('allensdk.model.biophysical.optimize') +_optimize_log = logging.getLogger("allensdk.model.biophysical.optimize") utils = None @@ -42,15 +42,13 @@ def eval_param_set(params): if check_for_block(): feature_errors = min_fail_penalty * np.ones_like(feature_errors) # Reset the stimulus back - utils.set_iclamp_params(stim_params["amplitude"], stim_params["delay"], - stim_params["duration"]) + utils.set_iclamp_params(stim_params["amplitude"], stim_params["delay"], stim_params["duration"]) return [np.sum(feature_errors)] def check_for_block(): - utils.set_iclamp_params(max_stim_amp, stim_params["delay"], - stim_params["duration"]) + utils.set_iclamp_params(max_stim_amp, stim_params["delay"], stim_params["duration"]) h.finitialize() h.run() @@ -58,9 +56,9 @@ def check_for_block(): t = t_vec.as_numpy() stim_start_idx = np.flatnonzero(t >= utils.stim.delay)[0] stim_end_idx = np.flatnonzero(t >= utils.stim.delay + utils.stim.dur)[0] - depol_block_threshold = -50.0 # mV - block_min_duration = 50.0 # ms - long_hyperpol_threshold = -75.0 # mV + depol_block_threshold = -50.0 # mV + block_min_duration = 50.0 # ms + long_hyperpol_threshold = -75.0 # mV bool_v = np.array(v > depol_block_threshold, dtype=int) up_indexes = np.flatnonzero(np.diff(bool_v) == 1) @@ -81,7 +79,9 @@ def check_for_block(): down_indexes = np.flatnonzero(np.diff(bool_v) == -1) down_indexes = down_indexes[(down_indexes > stim_start_idx) & (down_indexes < stim_end_idx)] if len(down_indexes) != 0: - up_indexes = up_indexes[(up_indexes > stim_start_idx) & (up_indexes < stim_end_idx) & (up_indexes > down_indexes[0])] + up_indexes = up_indexes[ + (up_indexes > stim_start_idx) & (up_indexes < stim_end_idx) & (up_indexes > down_indexes[0]) + ] if len(up_indexes) < len(down_indexes): up_indexes = np.append(up_indexes, [stim_end_idx]) max_hyperpol_duration = np.max([t[up_indexes[k]] - t[down_idx] for k, down_idx in enumerate(down_indexes)]) @@ -108,22 +108,22 @@ def initPopulation(pcls, ind_init, popfile): def main(): global utils, h, v_vec, i_vec, t_vec, do_block_check, max_stim_amp, stim_params, config, seed - parser = argparse.ArgumentParser(description='Start a DEAP testing run.') - parser.add_argument('seed', type=int) - parser.add_argument('config_path') - parser.add_argument('ngen', type=int) - parser.add_argument('mu', type=int) + parser = argparse.ArgumentParser(description="Start a DEAP testing run.") + parser.add_argument("seed", type=int) + parser.add_argument("config_path") + parser.add_argument("ngen", type=int) + parser.add_argument("mu", type=int) args = parser.parse_args() seed = args.seed # Set up NEURON config = Config().load(args.config_path) - if 'LOG_CFG' in os.environ: - log_config = os.environ['LOG_CFG'] + if "LOG_CFG" in os.environ: + log_config = os.environ["LOG_CFG"] else: - log_config = str(files('allensdk.model.biophysical').joinpath('logging.conf')) - os.environ['LOG_CFG'] = log_config + log_config = str(files("allensdk.model.biophysical").joinpath("logging.conf")) + os.environ["LOG_CFG"] = log_config lc.fileConfig(log_config) stim_params = config.data["stimulus"][0] @@ -140,12 +140,11 @@ def main(): h = utils.h manifest = config.manifest - morphology_path = manifest.get_path('MORPHOLOGY') - utils.generate_morphology(morphology_path.encode('ascii', 'ignore')) + morphology_path = manifest.get_path("MORPHOLOGY") + utils.generate_morphology(morphology_path.encode("ascii", "ignore")) utils.load_cell_parameters() utils.insert_iclamp() - utils.set_iclamp_params(stim_params["amplitude"], stim_params["delay"], - stim_params["duration"]) + utils.set_iclamp_params(stim_params["amplitude"], stim_params["delay"], stim_params["duration"]) h.tstop = stim_params["delay"] * 2.0 + stim_params["duration"] h.cvode_active(1) @@ -154,7 +153,7 @@ def main(): v_vec, i_vec, t_vec = utils.record_values() - try: # Wrapping this all to catch exceptions during NEURON parallel execution + try: # Wrapping this all to catch exceptions during NEURON parallel execution neuron_parallel.runworker() # Set up genetic algorithm @@ -170,7 +169,7 @@ def main(): ndim = len(config.data["channels"]) + len(config.data["addl_params"]) - creator.create("FitnessMin", base.Fitness, weights=(-1.0, )) + creator.create("FitnessMin", base.Fitness, weights=(-1.0,)) creator.create("Individual", list, fitness=creator.FitnessMin) toolbox = base.Toolbox() @@ -180,10 +179,8 @@ def main(): toolbox.register("population", tools.initRepeat, list, toolbox.individual) toolbox.register("evaluate", eval_param_set) - toolbox.register("mate", tools.cxSimulatedBinaryBounded, low=BOUND_LOWER, up=BOUND_UPPER, - eta=eta) - toolbox.register("mutate", tools.mutPolynomialBounded, low=BOUND_LOWER, up=BOUND_UPPER, - eta=eta, indpb=mtpb) + toolbox.register("mate", tools.cxSimulatedBinaryBounded, low=BOUND_LOWER, up=BOUND_UPPER, eta=eta) + toolbox.register("mutate", tools.mutPolynomialBounded, low=BOUND_LOWER, up=BOUND_UPPER, eta=eta, indpb=mtpb) toolbox.register("variate", algorithms.varAnd) toolbox.register("select", tools.selBest) toolbox.register("map", neuron_parallel.map) @@ -243,7 +240,7 @@ def main(): h.quit() except RuntimeError: _optimize_log.critical("Exception encountered during parallel NEURON execution") - MPI.COMM_WORLD.Abort() # Shut down all the processes + MPI.COMM_WORLD.Abort() # Shut down all the processes if __name__ == "__main__": diff --git a/allensdk/internal/model/biophysical/passive_fitting/neuron_passive_fit.py b/allensdk/internal/model/biophysical/passive_fitting/neuron_passive_fit.py index 97458b2b24..c5a86ae9bb 100755 --- a/allensdk/internal/model/biophysical/passive_fitting/neuron_passive_fit.py +++ b/allensdk/internal/model/biophysical/passive_fitting/neuron_passive_fit.py @@ -7,7 +7,7 @@ # Load the morphology -BASEDIR = os.path.dirname(__file__)#"/data/mat/nathang/deap_optimize/passive_fitting" +BASEDIR = os.path.dirname(__file__) # "/data/mat/nathang/deap_optimize/passive_fitting" @neuron_utils.read_neuron_fit_stdout @@ -18,7 +18,7 @@ def neuron_passive_fit(up_data, down_data, swc_path, limit): neuron_utils.load_morphology(swc_path) for sec in h.allsec(): - sec.insert('pas') + sec.insert("pas") for seg in sec: seg.pas.e = 0 @@ -79,21 +79,17 @@ def neuron_passive_fit(up_data, down_data, swc_path, limit): h.region_areas() - return { - 'Ri': fit_Ri, - 'Cm': fit_Cm, - 'Rm': fit_Rm, - 'err': minerr - } + return {"Ri": fit_Ri, "Cm": fit_Cm, "Rm": fit_Rm, "err": minerr} + def arg_parser(): - parser = argparse.ArgumentParser(description='analyze cap check sweep') - parser.add_argument('--up_file') - parser.add_argument('--down_file') - parser.add_argument('--swc_path') - parser.add_argument('--specimen_id', type=int, required=True) - parser.add_argument('--limit', type=float, required=True) - parser.add_argument('--output_file', required=True) + parser = argparse.ArgumentParser(description="analyze cap check sweep") + parser.add_argument("--up_file") + parser.add_argument("--down_file") + parser.add_argument("--swc_path") + parser.add_argument("--specimen_id", type=int, required=True) + parser.add_argument("--limit", type=float, required=True) + parser.add_argument("--output_file", required=True) return parser @@ -102,29 +98,30 @@ def process_inputs(parser): swc_path = args.swc_path up_data = np.loadtxt(args.up_file) down_data = np.loadtxt(args.down_file) - + return args, up_data, down_data, swc_path def main(): import sys - + manifest_path = sys.argv[-1] limit = float(sys.argv[-2]) os.chdir(os.path.dirname(manifest_path)) app_config = Config() description = app_config.load(manifest_path) - - upfile = description.manifest.get_path('upfile') - up_data = np.loadtxt(upfile) - downfile = description.manifest.get_path('downfile') + + upfile = description.manifest.get_path("upfile") + up_data = np.loadtxt(upfile) + downfile = description.manifest.get_path("downfile") down_data = np.loadtxt(downfile) - swc_path = description.manifest.get_path('MORPHOLOGY') - + swc_path = description.manifest.get_path("MORPHOLOGY") + data = neuron_passive_fit(up_data, down_data, swc_path, limit) - output_file = description.manifest.get_path('fit_1_file') - + output_file = description.manifest.get_path("fit_1_file") + json_utilities.write(output_file, data) + if __name__ == "__main__": main() diff --git a/allensdk/internal/model/biophysical/passive_fitting/neuron_passive_fit2.py b/allensdk/internal/model/biophysical/passive_fitting/neuron_passive_fit2.py index d3c2a9d69f..0aea592e4b 100755 --- a/allensdk/internal/model/biophysical/passive_fitting/neuron_passive_fit2.py +++ b/allensdk/internal/model/biophysical/passive_fitting/neuron_passive_fit2.py @@ -9,6 +9,7 @@ BASEDIR = os.path.dirname(__file__) + @neuron_utils.read_neuron_fit_stdout def neuron_passive_fit2(up_data, down_data, swc_path, limit): h = neuron_utils.get_h() @@ -17,7 +18,7 @@ def neuron_passive_fit2(up_data, down_data, swc_path, limit): neuron_utils.load_morphology(swc_path) for sec in h.allsec(): - sec.insert('pas') + sec.insert("pas") for seg in sec: seg.pas.e = 0 @@ -74,32 +75,29 @@ def neuron_passive_fit2(up_data, down_data, swc_path, limit): minerr = mrf.opt.minerr h.region_areas() - return { - 'Ri': fit_Ri, - 'Cm': fit_Cm, - 'Rm': fit_Rm, - 'err': minerr - } + return {"Ri": fit_Ri, "Cm": fit_Cm, "Rm": fit_Rm, "err": minerr} + def main(): import sys - + manifest_path = sys.argv[-1] limit = float(sys.argv[-2]) os.chdir(os.path.dirname(manifest_path)) app_config = Config() description = app_config.load(manifest_path) - - upfile = description.manifest.get_path('upfile') - up_data = np.loadtxt(upfile) - downfile = description.manifest.get_path('downfile') + + upfile = description.manifest.get_path("upfile") + up_data = np.loadtxt(upfile) + downfile = description.manifest.get_path("downfile") down_data = np.loadtxt(downfile) - swc_path = description.manifest.get_path('MORPHOLOGY') - output_file = description.manifest.get_path('fit_2_file') - + swc_path = description.manifest.get_path("MORPHOLOGY") + output_file = description.manifest.get_path("fit_2_file") + data = neuron_passive_fit2(up_data, down_data, swc_path, limit) - + json_utilities.write(output_file, data) + if __name__ == "__main__": main() diff --git a/allensdk/internal/model/biophysical/passive_fitting/neuron_passive_fit_elec.py b/allensdk/internal/model/biophysical/passive_fitting/neuron_passive_fit_elec.py index c138e80507..4880829a94 100755 --- a/allensdk/internal/model/biophysical/passive_fitting/neuron_passive_fit_elec.py +++ b/allensdk/internal/model/biophysical/passive_fitting/neuron_passive_fit_elec.py @@ -10,20 +10,16 @@ BASEDIR = os.path.dirname(__file__) + @neuron_utils.read_neuron_fit_stdout -def neuron_passive_fit_elec(up_data, - down_data, - swc_path, - limit, - bridge, - elec_cap): +def neuron_passive_fit_elec(up_data, down_data, swc_path, limit, bridge, elec_cap): h = neuron_utils.get_h() h.load_file("stdgui.hoc") h.load_file("import3d.hoc") neuron_utils.load_morphology(swc_path) for sec in h.allsec(): - sec.insert('pas') + sec.insert("pas") for seg in sec: seg.pas.e = 0 @@ -88,16 +84,12 @@ def neuron_passive_fit_elec(up_data, minerr = mrf.opt.minerr h.region_areas() - return { - 'Ri': fit_Ri, - 'Cm': fit_Cm, - 'Rm': fit_Rm, - 'err': minerr - } + return {"Ri": fit_Ri, "Cm": fit_Cm, "Rm": fit_Rm, "err": minerr} + def main(): import sys - + manifest_path = sys.argv[-1] elec_cap = float(sys.argv[-2]) bridge = float(sys.argv[-3]) @@ -106,17 +98,17 @@ def main(): app_config = Config() description = app_config.load(manifest_path) - upfile = description.manifest.get_path('upfile') - up_data = np.loadtxt(upfile) - downfile = description.manifest.get_path('downfile') + upfile = description.manifest.get_path("upfile") + up_data = np.loadtxt(upfile) + downfile = description.manifest.get_path("downfile") down_data = np.loadtxt(downfile) - swc_path = description.manifest.get_path('MORPHOLOGY') + swc_path = description.manifest.get_path("MORPHOLOGY") data = neuron_passive_fit_elec(up_data, down_data, swc_path, limit, bridge, elec_cap) - output_file = description.manifest.get_path('fit_3_file') + output_file = description.manifest.get_path("fit_3_file") json_utilities.write(output_file, data) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/allensdk/internal/model/biophysical/passive_fitting/neuron_utils.py b/allensdk/internal/model/biophysical/passive_fitting/neuron_utils.py index 4ef3fdb543..89234e576e 100644 --- a/allensdk/internal/model/biophysical/passive_fitting/neuron_utils.py +++ b/allensdk/internal/model/biophysical/passive_fitting/neuron_utils.py @@ -1,15 +1,19 @@ # in place of global from neuron import h -def get_h(): + +def get_h(): if get_h.h is None: from neuron import h + get_h.h = h return get_h.h - + + get_h.h = None from .output_grabber import OutputGrabber + def load_morphology(filename): h = get_h() swc = h.Import3d_SWC_read() @@ -22,8 +26,8 @@ def load_morphology(filename): def parse_neuron_output(output_str): printed_fields = {} - for line in output_str.split('\n'): - if line.startswith('nquad'): + for line in output_str.split("\n"): + if line.startswith("nquad"): continue toks = line.split() if len(toks) == 2: @@ -32,7 +36,7 @@ def parse_neuron_output(output_str): v = float(v) except Exception: pass - + printed_fields[toks[0].strip()] = v return printed_fields @@ -40,7 +44,6 @@ def parse_neuron_output(output_str): def read_neuron_fit_stdout(func): def call(*args, **kwargs): - g = OutputGrabber() g.start() data = func(*args, **kwargs) @@ -52,5 +55,3 @@ def call(*args, **kwargs): return data return call - - diff --git a/allensdk/internal/model/biophysical/passive_fitting/output_grabber.py b/allensdk/internal/model/biophysical/passive_fitting/output_grabber.py index 75bc112aea..bd9222599a 100644 --- a/allensdk/internal/model/biophysical/passive_fitting/output_grabber.py +++ b/allensdk/internal/model/biophysical/passive_fitting/output_grabber.py @@ -3,10 +3,12 @@ import threading import time + class OutputGrabber(object): """ Class used to grab standard output or another stream. """ + escape_char = "\b" def __init__(self, stream=None, threaded=False): @@ -19,7 +21,6 @@ def __init__(self, stream=None, threaded=False): # Create a pipe so the stream can be captured: self.pipe_out, self.pipe_in = os.pipe() - def start(self): """ Start capturing the stream data. @@ -36,7 +37,6 @@ def start(self): # Make sure that the thread is running and os.read is executed: time.sleep(0.01) - def stop(self): """ Stop capturing the stream data and save the text in `capturedtext`. @@ -58,7 +58,6 @@ def stop(self): # Restore the original stream: os.dup2(self.streamfd, self.origstreamfd) - def readOutput(self): """ Read the stream data (one byte at a time) diff --git a/allensdk/internal/model/biophysical/passive_fitting/preprocess.py b/allensdk/internal/model/biophysical/passive_fitting/preprocess.py index 2a15153ab4..c4093228bf 100644 --- a/allensdk/internal/model/biophysical/passive_fitting/preprocess.py +++ b/allensdk/internal/model/biophysical/passive_fitting/preprocess.py @@ -3,18 +3,17 @@ import numpy as np import pandas as pd -_passive_fit_log = logging.getLogger( - 'allensdk.model.biophysical.passive_fitting.preprocess') +_passive_fit_log = logging.getLogger("allensdk.model.biophysical.passive_fitting.preprocess") + def get_passive_fit_data(cap_check_sweeps, data_set): - bridge_balances = [s['bridge_balance_mohm'] for s in cap_check_sweeps] + bridge_balances = [s["bridge_balance_mohm"] for s in cap_check_sweeps] bridge_avg = np.array(bridge_balances).mean() _passive_fit_log.debug("bridge avg {:.2f}".format(bridge_avg)) initialized = False for idx, s in enumerate(cap_check_sweeps): - v, i, t = ephys_utils.get_sweep_v_i_t_from_set(data_set, - s['sweep_number']) + v, i, t = ephys_utils.get_sweep_v_i_t_from_set(data_set, s["sweep_number"]) if v is None: continue up_idxs, down_idxs = get_cap_check_indices(i) @@ -22,14 +21,14 @@ def get_passive_fit_data(cap_check_sweeps, data_set): skip_count = 0 for j in range(len(up_idxs)): if j == 0: - avg_up = v[(up_idxs[j] - 400):down_idxs[j + 1]] - avg_down = v[(down_idxs[j] - 400):up_idxs[j]] + avg_up = v[(up_idxs[j] - 400) : down_idxs[j + 1]] + avg_down = v[(down_idxs[j] - 400) : up_idxs[j]] elif j == len(up_idxs) - 1: - avg_up = avg_up + v[(up_idxs[j] - 400):-2] - avg_down = avg_down + v[(down_idxs[j] - 400):up_idxs[j]] + avg_up = avg_up + v[(up_idxs[j] - 400) : -2] + avg_down = avg_down + v[(down_idxs[j] - 400) : up_idxs[j]] else: - avg_up = avg_up + v[(up_idxs[j] - 400):down_idxs[j + 1]] - avg_down = avg_down + v[(down_idxs[j] - 400):up_idxs[j]] + avg_up = avg_up + v[(up_idxs[j] - 400) : down_idxs[j + 1]] + avg_down = avg_down + v[(down_idxs[j] - 400) : up_idxs[j]] avg_up /= len(up_idxs) - skip_count avg_down /= len(up_idxs) - skip_count if not initialized: @@ -42,7 +41,7 @@ def get_passive_fit_data(cap_check_sweeps, data_set): grand_up /= len(cap_check_sweeps) grand_down /= len(cap_check_sweeps) - t = 0.005 * np.arange(len(grand_up)) # in ms, assumes 200kHz sampling rate] + t = 0.005 * np.arange(len(grand_up)) # in ms, assumes 200kHz sampling rate] grand_up_data = np.column_stack((t, grand_up)) grand_down_data = np.column_stack((t, grand_down)) @@ -58,12 +57,8 @@ def get_passive_fit_data(cap_check_sweeps, data_set): escape_index = escape_indexes[0] escape_t = t[escape_index] - return { - 'grand_up': grand_up_data, - 'grand_down': grand_down_data, - 'escape_t': escape_t, - 'bridge_avg': bridge_avg - } + return {"grand_up": grand_up_data, "grand_down": grand_down_data, "escape_t": escape_t, "bridge_avg": bridge_avg} + def get_cap_check_indices(i): # Assumes that there is a test pulse followed by the stimulus pulses (downward first) @@ -74,7 +69,9 @@ def get_cap_check_indices(i): return up_idx[2::2], down_idx[1::2] -def main(): +def main(): pass + + if __name__ == "__main__": main() diff --git a/allensdk/internal/model/biophysical/run_optimize.py b/allensdk/internal/model/biophysical/run_optimize.py index e50f54f129..e6797d73b9 100755 --- a/allensdk/internal/model/biophysical/run_optimize.py +++ b/allensdk/internal/model/biophysical/run_optimize.py @@ -14,208 +14,187 @@ class RunOptimize(object): - _log = logging.getLogger('allensdk.internal.model.biophysical.run_optimize') + _log = logging.getLogger("allensdk.internal.model.biophysical.run_optimize") - def __init__(self, - input_json, - output_json): + def __init__(self, input_json, output_json): self.input_json = input_json self.output_json = output_json self.app_config = None self.manifest = None self.data_set = None - def load_manifest(self): self.app_config = Config().load(self.input_json) self.manifest = self.app_config.manifest - self.data_set = NwbDataSet(self.manifest.get_path('stimulus_path')) - + self.data_set = NwbDataSet(self.manifest.get_path("stimulus_path")) def nrnivmodl(self): RunOptimize._log.debug("nrnivmodl") - subprocess.call(['nrnivmodl', './modfiles']) - + subprocess.call(["nrnivmodl", "./modfiles"]) def info(self, lims_json_path): - ''' return a string that a bash script can use + """return a string that a bash script can use to find the working directory, etc. to clean up. - ''' + """ ocr = OptimizeConfigReader() ocr.read_lims_file(lims_json_path) - - print(self.app_config.data['runs'][0]['specimen_id']) - print(self.manifest.get_path('BASEDIR')) + print(self.app_config.data["runs"][0]["specimen_id"]) + print(self.manifest.get_path("BASEDIR")) def copy_local(self): - ''' + """ Note ---- For files that aren't needed for local debugging, use write_manifest instead. - ''' + """ self.load_manifest() - modfile_dir = self.manifest.get_path('MODFILE_DIR') + modfile_dir = self.manifest.get_path("MODFILE_DIR") if not os.path.exists(modfile_dir): os.mkdir(modfile_dir) - output_dir = self.manifest.get_path('WORKDIR') + output_dir = self.manifest.get_path("WORKDIR") if not os.path.exists(output_dir): os.mkdir(output_dir) - modfiles = [self.manifest.get_path(key) for key,info - in self.manifest.path_info.items() - if 'format' in info and info['format'] == 'MODFILE'] + modfiles = [ + self.manifest.get_path(key) + for key, info in self.manifest.path_info.items() + if "format" in info and info["format"] == "MODFILE" + ] for from_file in modfiles: RunOptimize._log.debug("copying %s to %s" % (from_file, modfile_dir)) shutil.copy(from_file, modfile_dir) - shutil.copy(str(files('allensdk.model.biophysical').joinpath('cell.hoc')), - self.manifest.get_path('BASEDIR')) + shutil.copy(str(files("allensdk.model.biophysical").joinpath("cell.hoc")), self.manifest.get_path("BASEDIR")) - - def generate_manifest_rma(self, - neuronal_model_id, - manifest_path, - api_url=None): - ''' + def generate_manifest_rma(self, neuronal_model_id, manifest_path, api_url=None): + """ Note ---- Other necessary files are also written. - ''' + """ import json bma = BiophysicalModuleApi(api_url) data = bma.get_neuronal_models(neuronal_model_id) ocr = OptimizeConfigReader() - ocr.read_lims_message(data, 'lims_message.json') + ocr.read_lims_message(data, "lims_message.json") - with open('lims_message.json', 'w') as f: + with open("lims_message.json", "w") as f: f.write(json.dumps(data[0], sort_keys=True, indent=2)) ocr.to_manifest(manifest_path) - - def generate_manifest_lims(self, - lims_json_path, - manifest_path): - ''' + def generate_manifest_lims(self, lims_json_path, manifest_path): + """ Note ---- Other necessary files are also written. - ''' + """ ocr = OptimizeConfigReader() ocr.read_lims_file(lims_json_path) ocr.to_manifest(manifest_path) - def start_specimen(self): import allensdk.internal.model.biophysical.run_passive_fit as run_passive_fit import allensdk.internal.model.biophysical.fit_stage_1 as fit_stage_1 import allensdk.internal.model.biophysical.fit_stage_2 as fit_stage_2 self.load_manifest() - - self.passive_fit_data = \ - run_passive_fit.run_passive_fit(self.app_config) - ju.write(self.manifest.get_path('passive_fit_data'), - self.passive_fit_data) + self.passive_fit_data = run_passive_fit.run_passive_fit(self.app_config) + + ju.write(self.manifest.get_path("passive_fit_data"), self.passive_fit_data) - self.stage_1_jobs = \ - fit_stage_1.prepare_stage_1(self.app_config, - self.passive_fit_data) + self.stage_1_jobs = fit_stage_1.prepare_stage_1(self.app_config, self.passive_fit_data) - ju.write(self.manifest.get_path('stage_1_jobs'), - self.stage_1_jobs) + ju.write(self.manifest.get_path("stage_1_jobs"), self.stage_1_jobs) fit_stage_1.run_stage_1(self.stage_1_jobs) - output_directory = self.manifest.get_path('WORKDIR') + output_directory = self.manifest.get_path("WORKDIR") stage_2_jobs = fit_stage_2.prepare_stage_2(output_directory) fit_stage_2.run_stage_2(stage_2_jobs) - def make_fit(self): self.load_manifest() - + fit_types = ["f9", "f13"] - best_fit_values = { fit_type: None for fit_type in fit_types } + best_fit_values = {fit_type: None for fit_type in fit_types} - specimen_id = self.app_config.data['runs'][0]['specimen_id'] + specimen_id = self.app_config.data["runs"][0]["specimen_id"] for fit_type in fit_types: - fit_type_dir = self.manifest.get_path('fit_type_path', fit_type) + fit_type_dir = self.manifest.get_path("fit_type_path", fit_type) if os.path.exists(fit_type_dir): - report = Report(self.app_config, - fit_type) + report = Report(self.app_config, fit_type) report.generate_fit_file() if fit_type in best_fit_values.keys(): best_fit_values[fit_type] = report.best_fit_value() - best_fit_type, min_fit_value = reduce(lambda a, b: a if (a[1] < b[1]) else b, - (i for i in best_fit_values.items() if i[1] is not None)) - best_fit_file = self.manifest.get_path('output_fit_file', - specimen_id, - best_fit_type) + best_fit_type, min_fit_value = reduce( + lambda a, b: a if (a[1] < b[1]) else b, (i for i in best_fit_values.items() if i[1] is not None) + ) + best_fit_file = self.manifest.get_path("output_fit_file", specimen_id, best_fit_type) lims_upload_config = OptimizeConfigReader() - lims_upload_config.read_json(self.manifest.get_path('neuronal_model_data')) + lims_upload_config.read_json(self.manifest.get_path("neuronal_model_data")) - lims_upload_config.update_well_known_file(best_fit_file, - well_known_file_type_id=OptimizeConfigReader.NEURONAL_MODEL_PARAMETERS) + lims_upload_config.update_well_known_file( + best_fit_file, well_known_file_type_id=OptimizeConfigReader.NEURONAL_MODEL_PARAMETERS + ) lims_upload_config.write_file(output_json) def main(command, input_json, output_json): - ''' Entry point for module. - :param command: select behavior, nrnivmodl or simulate - :type command: string - :param lims_strategy_json: path to json file output from lims. - :type lims_strategy_json: string - :param lims_response_json: path to json file returned to lims. - :type lims_response_json: string - ''' - - o = RunOptimize(input_json, - output_json) - - if 'LOG_CFG' in os.environ: - log_config = os.environ['LOG_CFG'] + """Entry point for module. + :param command: select behavior, nrnivmodl or simulate + :type command: string + :param lims_strategy_json: path to json file output from lims. + :type lims_strategy_json: string + :param lims_response_json: path to json file returned to lims. + :type lims_response_json: string + """ + + o = RunOptimize(input_json, output_json) + + if "LOG_CFG" in os.environ: + log_config = os.environ["LOG_CFG"] else: - log_config = str(files('allensdk.model.biophysical').joinpath('logging.conf')) - os.environ['LOG_CFG'] = log_config + log_config = str(files("allensdk.model.biophysical").joinpath("logging.conf")) + os.environ["LOG_CFG"] = log_config lc.fileConfig(log_config) - if 'nrnivmodl' == command: + if "nrnivmodl" == command: o.nrnivmodl() - elif 'info' == command: + elif "info" == command: o.info(input_json) - elif 'generate_manifest_rma' == command: + elif "generate_manifest_rma" == command: o.generate_manifest_rma(input_json, output_json) - elif 'generate_manifest_lims' == command: + elif "generate_manifest_lims" == command: o.generate_manifest_lims(input_json, output_json) - elif 'write_manifest' == command: + elif "write_manifest" == command: o.write_manifest() - elif 'copy_local' == command: + elif "copy_local" == command: o.copy_local() - elif 'start_specimen' == command: + elif "start_specimen" == command: o.start_specimen() - elif 'make_fit' == command: + elif "make_fit" == command: o.make_fit() else: RunOptimize._log.error("no command") - print('done') + print("done") -if __name__ == '__main__': +if __name__ == "__main__": import sys command, input_json, output_json = sys.argv[-3:] diff --git a/allensdk/internal/model/biophysical/run_optimize_workflow.py b/allensdk/internal/model/biophysical/run_optimize_workflow.py index 5bd24f0974..b9661f1c24 100644 --- a/allensdk/internal/model/biophysical/run_optimize_workflow.py +++ b/allensdk/internal/model/biophysical/run_optimize_workflow.py @@ -2,11 +2,11 @@ from subprocess import call from importlib.resources import files -the_script = str(files(__package__).joinpath('run_optimize.sh')) +the_script = str(files(__package__).joinpath("run_optimize.sh")) -cmd = ['/bin/bash', the_script] +cmd = ["/bin/bash", the_script] cmd.extend(sys.argv[1:]) -print(' '.join(cmd)) +print(" ".join(cmd)) -call(cmd) \ No newline at end of file +call(cmd) diff --git a/allensdk/internal/model/biophysical/run_passive_fit.py b/allensdk/internal/model/biophysical/run_passive_fit.py index 9d1e3c68f4..84c417c635 100644 --- a/allensdk/internal/model/biophysical/run_passive_fit.py +++ b/allensdk/internal/model/biophysical/run_passive_fit.py @@ -15,65 +15,80 @@ import logging.config as lc -_run_passive_fit_log = logging.getLogger('allensdk.internal.model.biophysical.run_passive_fit') +_run_passive_fit_log = logging.getLogger("allensdk.internal.model.biophysical.run_passive_fit") def run_passive_fit(description): - output_directory = description.manifest.get_path('WORKDIR') - neuronal_model = ju.read(description.manifest.get_path('neuronal_model_data')) - specimen_data = neuronal_model['specimen'] - - is_spiny = not any(t['name'] == u'dendrite type - aspiny' for t in specimen_data['specimen_tags']) - - all_sweeps = specimen_data['ephys_sweeps'] + output_directory = description.manifest.get_path("WORKDIR") + neuronal_model = ju.read(description.manifest.get_path("neuronal_model_data")) + specimen_data = neuronal_model["specimen"] + + is_spiny = not any(t["name"] == "dendrite type - aspiny" for t in specimen_data["specimen_tags"]) + + all_sweeps = specimen_data["ephys_sweeps"] if not os.path.exists(output_directory): os.makedirs(output_directory) - cap_check_sweeps, _, _ = \ - ephys_utils.get_sweeps_of_type('C1SQCAPCHK', - all_sweeps) - + cap_check_sweeps, _, _ = ephys_utils.get_sweeps_of_type("C1SQCAPCHK", all_sweeps) + passive_fit_data = {} if len(cap_check_sweeps) > 0: - data_set = NwbDataSet(description.manifest.get_path('stimulus_path')) + data_set = NwbDataSet(description.manifest.get_path("stimulus_path")) d = passive_prep.get_passive_fit_data(cap_check_sweeps, data_set) - grand_up_file = os.path.join(output_directory, 'upbase.dat') - np.savetxt(grand_up_file, d['grand_up']) - - grand_down_file = os.path.join(output_directory, 'downbase.dat') - np.savetxt(grand_down_file, d['grand_down']) - - passive_fit_data["bridge"] = d['bridge_avg'] - passive_fit_data["escape_time"] = d['escape_t'] - - fit_1_file = description.manifest.get_path('fit_1_file') - subprocess.check_output([sys.executable, - '-m', neuron_passive_fit.__name__, - str(d['escape_t']), - os.path.realpath(description.manifest.get_path('manifest')) ]) - passive_fit_data['fit_1'] = ju.read(fit_1_file) - - fit_2_file = description.manifest.get_path('fit_2_file') - - subprocess.check_output([sys.executable, - '-m', neuron_passive_fit2.__name__, - str(d['escape_t']), - os.path.realpath(description.manifest.get_path('manifest')) ]) - passive_fit_data['fit_2'] = ju.read(fit_2_file) - - fit_3_file = description.manifest.get_path('fit_3_file') - subprocess.check_output([sys.executable, - '-m', neuron_passive_fit_elec.__name__, - str(d['escape_t']), - str(d['bridge_avg']), - str(1.0), - os.path.realpath(description.manifest.get_path('manifest')) ]) - passive_fit_data['fit_3'] = ju.read(fit_3_file) - + grand_up_file = os.path.join(output_directory, "upbase.dat") + np.savetxt(grand_up_file, d["grand_up"]) + + grand_down_file = os.path.join(output_directory, "downbase.dat") + np.savetxt(grand_down_file, d["grand_down"]) + + passive_fit_data["bridge"] = d["bridge_avg"] + passive_fit_data["escape_time"] = d["escape_t"] + + fit_1_file = description.manifest.get_path("fit_1_file") + subprocess.check_output( + [ + sys.executable, + "-m", + neuron_passive_fit.__name__, + str(d["escape_t"]), + os.path.realpath(description.manifest.get_path("manifest")), + ] + ) + passive_fit_data["fit_1"] = ju.read(fit_1_file) + + fit_2_file = description.manifest.get_path("fit_2_file") + + subprocess.check_output( + [ + sys.executable, + "-m", + neuron_passive_fit2.__name__, + str(d["escape_t"]), + os.path.realpath(description.manifest.get_path("manifest")), + ] + ) + passive_fit_data["fit_2"] = ju.read(fit_2_file) + + fit_3_file = description.manifest.get_path("fit_3_file") + subprocess.check_output( + [ + sys.executable, + "-m", + neuron_passive_fit_elec.__name__, + str(d["escape_t"]), + str(d["bridge_avg"]), + str(1.0), + os.path.realpath(description.manifest.get_path("manifest")), + ] + ) + passive_fit_data["fit_3"] = ju.read(fit_3_file) + # Check for potentially problematic outcomes - cm_rel_delta = (passive_fit_data["fit_1"]["Cm"] - passive_fit_data["fit_3"]["Cm"]) / passive_fit_data["fit_1"]["Cm"] + cm_rel_delta = (passive_fit_data["fit_1"]["Cm"] - passive_fit_data["fit_3"]["Cm"]) / passive_fit_data["fit_1"][ + "Cm" + ] if passive_fit_data["fit_2"]["err"] < passive_fit_data["fit_1"]["err"]: _run_passive_fit_log.debug("Fixed Ri gave better results than original") if passive_fit_data["fit_2"]["err"] < passive_fit_data["fit_3"]["err"]: @@ -114,10 +129,10 @@ def run_passive_fit(description): else: cm2 = 1.0 - passive_fit_data['ra'] = ra - passive_fit_data['cm1'] = cm1 - passive_fit_data['cm2'] = cm2 - + passive_fit_data["ra"] = ra + passive_fit_data["cm1"] = cm1 + passive_fit_data["cm2"] = cm2 + return passive_fit_data @@ -125,11 +140,11 @@ def main(limit, manifest_path): app_config = Config() description = app_config.load(manifest_path) - if 'LOG_CFG' in os.environ: - log_config = os.environ['LOG_CFG'] + if "LOG_CFG" in os.environ: + log_config = os.environ["LOG_CFG"] else: - log_config = str(files('allensdk.model.biophysical').joinpath('logging.conf')) - os.environ['LOG_CFG'] = log_config + log_config = str(files("allensdk.model.biophysical").joinpath("logging.conf")) + os.environ["LOG_CFG"] = log_config lc.fileConfig(log_config) run_passive_fit(description) @@ -138,6 +153,5 @@ def main(limit, manifest_path): if __name__ == "__main__": limit = sys.argv[-2] manifest_path = sys.argv[-1] - + main(limit, manifest_path) - diff --git a/allensdk/internal/model/biophysical/run_simulate_lims.py b/allensdk/internal/model/biophysical/run_simulate_lims.py index 6d17b5a433..7738acda72 100644 --- a/allensdk/internal/model/biophysical/run_simulate_lims.py +++ b/allensdk/internal/model/biophysical/run_simulate_lims.py @@ -9,22 +9,17 @@ class RunSimulateLims(RunSimulate): - _log = logging.getLogger('allensdk.internal.model.biophysical.run_simulate_lims') + _log = logging.getLogger("allensdk.internal.model.biophysical.run_simulate_lims") - def __init__(self, - input_json, - output_json): + def __init__(self, input_json, output_json): super(RunSimulateLims, self).__init__(input_json, output_json) - def generate_manifest_rma(self, - neuronal_model_run_id, - manifest_path, - api_url=None): - ''' + def generate_manifest_rma(self, neuronal_model_run_id, manifest_path, api_url=None): + """ Note ---- Other necessary files are also written. - ''' + """ import json from allensdk.internal.api.queries.biophysical_module_api import BiophysicalModuleApi from allensdk.internal.api.queries.biophysical_module_reader import BiophysicalModuleReader @@ -33,97 +28,92 @@ def generate_manifest_rma(self, data = bma.get_neuronal_model_runs(neuronal_model_run_id) lr = BiophysicalModuleReader() - lr.read_lims_message(data, 'lims_message.json') + lr.read_lims_message(data, "lims_message.json") - with open('lims_message.json', 'w') as f: + with open("lims_message.json", "w") as f: f.write(json.dumps(data[0], sort_keys=True, indent=2)) - lr.to_manifest(manifest_path) + lr.to_manifest(manifest_path) - def generate_manifest_lims(self, - lims_data_path, - manifest_path): - ''' + def generate_manifest_lims(self, lims_data_path, manifest_path): + """ Note ---- Other necessary files are also written. - ''' + """ from allensdk.internal.api.queries.biophysical_module_reader import BiophysicalModuleReader self.lims_json = lims_data_path lr = BiophysicalModuleReader() lr.read_lims_file(self.lims_json) - + lr.to_manifest(manifest_path) def copy_local(self): - self.load_manifest() - modfile_dir = self.manifest.get_path('MODFILE_DIR') + modfile_dir = self.manifest.get_path("MODFILE_DIR") if not os.path.exists(modfile_dir): os.mkdir(modfile_dir) - workdir = self.manifest.get_path('WORKDIR') + workdir = self.manifest.get_path("WORKDIR") if not os.path.exists(workdir): os.mkdir(workdir) - modfiles = [self.manifest.get_path(key) for key,info - in self.manifest.path_info.items() - if 'format' in info and info['format'] == 'MODFILE'] - - for from_file in modfiles: + modfiles = [ + self.manifest.get_path(key) + for key, info in self.manifest.path_info.items() + if "format" in info and info["format"] == "MODFILE" + ] + + for from_file in modfiles: RunSimulate._log.debug("copying %s to %s" % (from_file, modfile_dir)) shutil.copy(from_file, modfile_dir) - - shutil.copy(self.manifest.get_path('fit_parameters'), - workdir) - - shutil.copyfile(self.manifest.get_path('stimulus_path'), - self.manifest.get_path('output_path')) - - shutil.copy(str(files('allensdk.model.biophysical').joinpath('cell.hoc')), - os.curdir) - - + + shutil.copy(self.manifest.get_path("fit_parameters"), workdir) + + shutil.copyfile(self.manifest.get_path("stimulus_path"), self.manifest.get_path("output_path")) + + shutil.copy(str(files("allensdk.model.biophysical").joinpath("cell.hoc")), os.curdir) + + def main(command, lims_strategy_json, lims_response_json): - ''' Entry point for module. - :param command: select behavior, nrnivmodl or simulate - :type command: string - :param lims_strategy_json: path to json file output from lims. - :type lims_strategy_json: string - :param lims_response_json: path to json file returned to lims. - :type lims_response_json: string - ''' - rs = RunSimulateLims(lims_strategy_json, - lims_response_json) + """Entry point for module. + :param command: select behavior, nrnivmodl or simulate + :type command: string + :param lims_strategy_json: path to json file output from lims. + :type lims_strategy_json: string + :param lims_response_json: path to json file returned to lims. + :type lims_response_json: string + """ + rs = RunSimulateLims(lims_strategy_json, lims_response_json) RunSimulateLims._log.debug("command: %s" % (command)) RunSimulateLims._log.debug("lims strategy json: %s" % (lims_strategy_json)) RunSimulateLims._log.debug("lims upload json: %s" % (lims_response_json)) - - log_config = str(files('allensdk.model.biophysical').joinpath('logging.conf')) + + log_config = str(files("allensdk.model.biophysical").joinpath("logging.conf")) lc.fileConfig(log_config) - os.environ['LOG_CFG'] = log_config - - if 'nrnivmodl' == command: + os.environ["LOG_CFG"] = log_config + + if "nrnivmodl" == command: rs.nrnivmodl() - elif 'copy_local' == command: + elif "copy_local" == command: rs.copy_local() - elif 'generate_manifest_rma' == command: + elif "generate_manifest_rma" == command: rs.generate_manifest_rma(input_json, output_json) - elif 'generate_manifest_lims' == command: + elif "generate_manifest_lims" == command: rs.generate_manifest_lims(input_json, output_json) - elif 'generate_manifest_lims' == command: + elif "generate_manifest_lims" == command: rs.generate_manifest_lims(input_json, output_json) else: rs.simulate() -if __name__ == '__main__': +if __name__ == "__main__": command, input_json, output_json = sys.argv[-3:] try: diff --git a/allensdk/internal/model/biophysical/run_simulate_workflow.py b/allensdk/internal/model/biophysical/run_simulate_workflow.py index 5bd37096c0..c1ce803954 100644 --- a/allensdk/internal/model/biophysical/run_simulate_workflow.py +++ b/allensdk/internal/model/biophysical/run_simulate_workflow.py @@ -2,11 +2,11 @@ from subprocess import call from importlib.resources import files -the_script = str(files(__package__).joinpath('run_simulate.sh')) +the_script = str(files(__package__).joinpath("run_simulate.sh")) -cmd = ['/bin/bash', the_script] +cmd = ["/bin/bash", the_script] cmd.extend(sys.argv[1:]) -print(' '.join(cmd)) +print(" ".join(cmd)) -call(cmd) \ No newline at end of file +call(cmd) diff --git a/allensdk/internal/model/data_access.py b/allensdk/internal/model/data_access.py index c5face9492..2ddce2db18 100644 --- a/allensdk/internal/model/data_access.py +++ b/allensdk/internal/model/data_access.py @@ -2,14 +2,15 @@ from scipy import signal import numpy as np + def load_sweep(file_name, sweep_number, desired_dt=None, cut=0, bessel=False): - '''load a data sweep and do specified data processing. + """load a data sweep and do specified data processing. Inputs: file_name: string name of .nwb data file - sweep_number: + sweep_number: number specifying the sweep to be loaded - desired_dt: + desired_dt: the size of the time step the data should be subsampled to cut: indicie of which to start reporting data (i.e. cut off data before this indicie) @@ -21,7 +22,7 @@ def load_sweep(file_name, sweep_number, desired_dt=None, cut=0, bessel=False): current: array dt: time step of the returned data start_idx: the index at which the first stimulus starts (excluding the test pulse) - ''' + """ ds = NwbDataSet(file_name) data = ds.get_sweep(sweep_number) @@ -29,13 +30,13 @@ def load_sweep(file_name, sweep_number, desired_dt=None, cut=0, bessel=False): if cut > 0: data["response"] = data["response"][cut:] - data["stimulus"] = data["stimulus"][cut:] + data["stimulus"] = data["stimulus"][cut:] if bessel: - sample_freq = 1. / data["dt"] - filt_coeff = (bessel["freq"]) / (sample_freq / 2.) # filter fraction of Nyquist frequency + sample_freq = 1.0 / data["dt"] + filt_coeff = (bessel["freq"]) / (sample_freq / 2.0) # filter fraction of Nyquist frequency b, a = signal.bessel(bessel["N"], filt_coeff, "low") - data['response'] = signal.filtfilt(b, a, data['response'], axis=0) + data["response"] = signal.filtfilt(b, a, data["response"], axis=0) if desired_dt is not None: if data["dt"] != desired_dt: @@ -47,22 +48,17 @@ def load_sweep(file_name, sweep_number, desired_dt=None, cut=0, bessel=False): if "start_idx" not in data: data["start_idx"] = data["index_range"][0] - return { - "voltage": data["response"], - "current": data["stimulus"], - "dt": data["dt"], - "start_idx": data["start_idx"] - } + return {"voltage": data["response"], "current": data["stimulus"], "dt": data["dt"], "start_idx": data["start_idx"]} def load_sweeps(file_name, sweep_numbers, dt=None, cut=0, bessel=False): - '''load sweeps and do specified data processing. + """load sweeps and do specified data processing. Inputs: file_name: string name of .nwb data file - sweep_numbers: + sweep_numbers: sweep numbers to be loaded - desired_dt: + desired_dt: the size of the time step the data should be subsampled to cut: indicie of which to start reporting data (i.e. cut off data before this indicie) @@ -73,17 +69,17 @@ def load_sweeps(file_name, sweep_numbers, dt=None, cut=0, bessel=False): voltage: list of voltage trace arrays current: list of current trace arrays dt: list of time step corresponding to each array of the returned data - start_idx: list of the indicies at which the first stimulus starts (excluding + start_idx: list of the indicies at which the first stimulus starts (excluding the test pulse) in each returned sweep - ''' - data = [ load_sweep(file_name, sweep_number, dt, cut, bessel) for sweep_number in sweep_numbers ] + """ + data = [load_sweep(file_name, sweep_number, dt, cut, bessel) for sweep_number in sweep_numbers] return { - 'voltage': [ d['voltage'] for d in data ], - 'current': [ d['current'] for d in data ], - 'dt': [ d['dt'] for d in data ], - 'start_idx': [ d['start_idx'] for d in data ], - } + "voltage": [d["voltage"] for d in data], + "current": [d["current"] for d in data], + "dt": [d["dt"] for d in data], + "start_idx": [d["start_idx"] for d in data], + } def subsample_data(data, method, present_time_step, desired_time_step): @@ -93,11 +89,10 @@ def subsample_data(data, method, present_time_step, desired_time_step): # number of elements to average over n = int(desired_time_step / present_time_step) - if method == "mean": # if n does not divide evenly into the length of the array, crop off the end end = n * int(len(data) / n) - - return np.mean(data[:end].reshape(-1,n), 1) - raise Exception("unknown subsample method: %s" % (method)) \ No newline at end of file + return np.mean(data[:end].reshape(-1, n), 1) + + raise Exception("unknown subsample method: %s" % (method)) diff --git a/allensdk/internal/model/glif/ASGLM.py b/allensdk/internal/model/glif/ASGLM.py index 3f5017075e..c337869216 100644 --- a/allensdk/internal/model/glif/ASGLM.py +++ b/allensdk/internal/model/glif/ASGLM.py @@ -5,12 +5,26 @@ import statsmodels.api as sm import matplotlib.pyplot as plt -def ASGLM_pairwise(ks_int, I_stim, voltage, spike_ind, cinit, tauinit, SCL, dt, resting_potential, - SHORT_RUN=False, MAKE_PLOT=False, SHOW_PLOT=False, BLOCK=False): - '''Calculate the resistance and amplitude of the afterspike currents for + +def ASGLM_pairwise( + ks_int, + I_stim, + voltage, + spike_ind, + cinit, + tauinit, + SCL, + dt, + resting_potential, + SHORT_RUN=False, + MAKE_PLOT=False, + SHOW_PLOT=False, + BLOCK=False, +): + """Calculate the resistance and amplitude of the afterspike currents for Parameters ---------- - ks_int: list + ks_int: list initial possible k's (k=1/tau, where tau is the time constant of the exponential decay) I_stim: list of arrays input stimulus traces of sweeps @@ -18,226 +32,240 @@ def ASGLM_pairwise(ks_int, I_stim, voltage, spike_ind, cinit, tauinit, SCL, dt, voltage of cell as a result of I_stim spike_ind: list of arrays each array contains the index of the spikes - cinit: float + cinit: float membrane capacitance tauinit: float time constant of membrane SCL: float number of indicies that should be cut after a spike - dt: float + dt: float size of time step of injected current Returns - ''' - - #Initialize post-spike filter parameters (MOST OF THESE ARE HARD-CODED currently) - nkt = 8000 # arbitrary length of filter WANT FILTER TO COVER A LENGTH OF TIME, WANT FILTER TO BE LONGER THAN LONGEST LENGTH ASC - DTsim = dt #DTsim = dt means filter is nkt*dt = 100 ms long in this case THIS SHOULD INDEED BE THE SAMPLE WIDTH - neye = 0 #no of identity basis vectors DELTA - f = 1e-3/dt #pre-factor for getting time units correct for bases - - taus_int=[1000./kk for kk in ks_int] #converting to ms - taus_filter = [f*j for j in taus_int] #convert ms time to filter time units 1/(dt*ks) - ks_list = [(1.0/(i)) for i in taus_filter] #use peak positions of rcos bumps as time-scales for exponential bases - ncos = 2 #no of bases!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - flag_exp = 1 #flag_exp = 1 means use exponential bases, else use raised-cosine bumps + """ + + # Initialize post-spike filter parameters (MOST OF THESE ARE HARD-CODED currently) + nkt = 8000 # arbitrary length of filter WANT FILTER TO COVER A LENGTH OF TIME, WANT FILTER TO BE LONGER THAN LONGEST LENGTH ASC + DTsim = dt # DTsim = dt means filter is nkt*dt = 100 ms long in this case THIS SHOULD INDEED BE THE SAMPLE WIDTH + neye = 0 # no of identity basis vectors DELTA + f = 1e-3 / dt # pre-factor for getting time units correct for bases + + taus_int = [1000.0 / kk for kk in ks_int] # converting to ms + taus_filter = [f * j for j in taus_int] # convert ms time to filter time units 1/(dt*ks) + ks_list = [ + (1.0 / (i)) for i in taus_filter + ] # use peak positions of rcos bumps as time-scales for exponential bases + ncos = 2 # no of bases!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + flag_exp = 1 # flag_exp = 1 means use exponential bases, else use raised-cosine bumps vL = resting_potential - + # GLM fit with post-spike currents - tst = 0 #190000#355000 #time-step to start + tst = 0 # 190000#355000 #time-step to start npcut = int(SCL) # no of points to cut after each spike-initiation - - # Collect spikes between tst and tend + + # Collect spikes between tst and tend t0_list = [] for mm in range(len(I_stim)): tend = len(I_stim[mm]) - t0=[] - for jj in range(len(spike_ind[mm])): - if spike_ind[mm][jj]>tst and spike_ind[mm][jj] tst and spike_ind[mm][jj] < tend: t0.append(spike_ind[mm][jj]) - + t0_list.append(t0) - - #Create a list of pairs of ks - ks_pairs = list(itertools.combinations(ks_list,ncos)) - ks_pairs_in_SI_units= list(itertools.combinations(ks_int, ncos)) - if len(ks_pairs)!=10: - raise Exception('figure subplots will need to be changed as there is a different number than 10 ks_pairs.') - - #Initialize list to hold charge dump values and amp-vectors - #Iterate over all pairs + + # Create a list of pairs of ks + ks_pairs = list(itertools.combinations(ks_list, ncos)) + ks_pairs_in_SI_units = list(itertools.combinations(ks_int, ncos)) + if len(ks_pairs) != 10: + raise Exception("figure subplots will need to be changed as there is a different number than 10 ks_pairs.") + + # Initialize list to hold charge dump values and amp-vectors + # Iterate over all pairs if SHORT_RUN: logging.warning("You are not doing all the ks pairs in ASGLM_pairwise") - ks_pairs=[ks_pairs[0]] + ks_pairs = [ks_pairs[0]] - R_for_all_ks_pairs=[] - asc_amp_for_all_ks_pairs=[] - llh_for_all_ks_pairs=[] + R_for_all_ks_pairs = [] + asc_amp_for_all_ks_pairs = [] + llh_for_all_ks_pairs = [] for ks_ind, (ks_fit_units, ks_SI_units) in enumerate(zip(ks_pairs, ks_pairs_in_SI_units)): - print('ks_fit_units', ks_fit_units) - #Create basis IPSPs + print("ks_fit_units", ks_fit_units) + # Create basis IPSPs if MAKE_PLOT: - plt.figure(78, figsize=(20,10)) + plt.figure(78, figsize=(20, 10)) basis_IPSP_list = [] - for rr in range(len(I_stim)): #loop over repeats - #find the basis of the entire trace - basis_IPSP, gg0 = GLM.create_basis_IPSP(neye,ncos,taus_filter,ks_fit_units,DTsim,t0_list[rr],I_stim[rr],nkt,flag_exp,npcut) - basis_IPSP_list.append(basis_IPSP) - #--Plot basis IPSPs between si and se - si = t0_list[0][0]-10 #plot start_ind - se = si+nkt+10 #plot end_ind - tvec = dt*np.arange(si-tst,se-tst) #convert time-steps to real time (in sec) + for rr in range(len(I_stim)): # loop over repeats + # find the basis of the entire trace + basis_IPSP, gg0 = GLM.create_basis_IPSP( + neye, ncos, taus_filter, ks_fit_units, DTsim, t0_list[rr], I_stim[rr], nkt, flag_exp, npcut + ) + basis_IPSP_list.append(basis_IPSP) + # --Plot basis IPSPs between si and se + si = t0_list[0][0] - 10 # plot start_ind + se = si + nkt + 10 # plot end_ind + tvec = dt * np.arange(si - tst, se - tst) # convert time-steps to real time (in sec) if MAKE_PLOT: plt.figure(78) - plt.subplot(5,2, ks_ind+1) - plt.plot(1e3*tvec,basis_IPSP[si:se,:], lw=2, label=str(rr)) #1e3 plots time on x-axis in ms - plt.xlabel('time (ms)') - plt.title("k's "+str(ks_SI_units)) + plt.subplot(5, 2, ks_ind + 1) + plt.plot(1e3 * tvec, basis_IPSP[si:se, :], lw=2, label=str(rr)) # 1e3 plots time on x-axis in ms + plt.xlabel("time (ms)") + plt.title("k's " + str(ks_SI_units)) if MAKE_PLOT: - plt.annotate('ASGLM (fit asc and R): AScurrent basis', - xy=(.4, .985), - xycoords='figure fraction', - horizontalalignment='left', verticalalignment='top', - fontsize=20) + plt.annotate( + "ASGLM (fit asc and R): AScurrent basis", + xy=(0.4, 0.985), + xycoords="figure fraction", + horizontalalignment="left", + verticalalignment="top", + fontsize=20, + ) plt.legend() plt.tight_layout() - + if SHOW_PLOT: - plt.show(block=BLOCK) - + plt.show(block=BLOCK) + # cut spikes out of dv, v, i, and b_ipsp and put the different sweeps in lists i_all_swps_list = [] b_ipsp_all_swps_list = [] v_all_swps_list = [] - dv_all_swps_list = [] - for ss in range(len(I_stim)): #loop over repeats + dv_all_swps_list = [] + for ss in range(len(I_stim)): # loop over repeats tend = len(I_stim[ss]) - i = I_stim[ss][tst:tend-1] - b_ipsp = basis_IPSP_list[ss][tst:tend-1] - - v = voltage[ss][tst:tend-1] - vs = voltage[ss][tst+1:tend] - dv = (vs-v)/dt #derivative of voltage - - #delete npcut points after spike from each qty + i = I_stim[ss][tst : tend - 1] + b_ipsp = basis_IPSP_list[ss][tst : tend - 1] + + v = voltage[ss][tst : tend - 1] + vs = voltage[ss][tst + 1 : tend] + dv = (vs - v) / dt # derivative of voltage + + # delete npcut points after spike from each qty delpts = [] for kk in range(len(t0_list[ss])): - delpts.append(range(int(t0_list[ss][kk])-tst,int(t0_list[ss][kk])-tst+npcut)) - - dv = np.delete(dv,delpts,0) - v = np.delete(v,delpts,0) - i = np.delete(i,delpts,0) - b_ipsp = np.delete(b_ipsp,delpts,0) - + delpts.append(range(int(t0_list[ss][kk]) - tst, int(t0_list[ss][kk]) - tst + npcut)) + + dv = np.delete(dv, delpts, 0) + v = np.delete(v, delpts, 0) + i = np.delete(i, delpts, 0) + b_ipsp = np.delete(b_ipsp, delpts, 0) + v_all_swps_list.append(v) dv_all_swps_list.append(dv) i_all_swps_list.append(i) b_ipsp_all_swps_list.append(b_ipsp) - - tvec = dt*np.arange(len(v)) - + + tvec = dt * np.arange(len(v)) + # Compute amplitude each basis AS current using a GLM if MAKE_PLOT: plt.figure(79, figsize=(20, 12)) plt.figure(80, figsize=(20, 12)) - R_for_each_sweep=[] - asc_amp_for_each_sweep=[] - llh_for_each_sweep=[] - for kkk, (v_spike_deleted, dv_spike_deleted, i_spike_deleted, b_ipsp_spikes_deleted) in \ - enumerate(zip(v_all_swps_list, dv_all_swps_list, i_all_swps_list, b_ipsp_all_swps_list)): - - #--fitting afterspike current amplitudes and resistance - inp = np.zeros((len(i_spike_deleted),ncos+1)) - inp[:,range(0,ncos)] = (1/cinit)*b_ipsp_spikes_deleted[:,range(ncos)] - inp[:,ncos] = -(v_spike_deleted-vL)/tauinit - out = dv_spike_deleted -i_spike_deleted/cinit#+ (v-vL)/tauinit - i/cinit + R_for_each_sweep = [] + asc_amp_for_each_sweep = [] + llh_for_each_sweep = [] + for kkk, (v_spike_deleted, dv_spike_deleted, i_spike_deleted, b_ipsp_spikes_deleted) in enumerate( + zip(v_all_swps_list, dv_all_swps_list, i_all_swps_list, b_ipsp_all_swps_list) + ): + # --fitting afterspike current amplitudes and resistance + inp = np.zeros((len(i_spike_deleted), ncos + 1)) + inp[:, range(0, ncos)] = (1 / cinit) * b_ipsp_spikes_deleted[:, range(ncos)] + inp[:, ncos] = -(v_spike_deleted - vL) / tauinit + out = dv_spike_deleted - i_spike_deleted / cinit # + (v-vL)/tauinit - i/cinit try: - glm_fit = sm.GLM(out,inp,family=sm.families.Gaussian(sm.families.links.identity)) + glm_fit = sm.GLM(out, inp, family=sm.families.Gaussian(sm.families.links.identity)) res = glm_fit.fit() - fitprs = res.params #fitprs has [AMP OF ASC, TAU] + fitprs = res.params # fitprs has [AMP OF ASC, TAU] - llh=res.llf - fit_R=tauinit/(fitprs[ncos]*cinit) - fit_asc_amp=fitprs[:ncos] - #Compute and plot post-spike current (essentially multiply basis functions with correct amplitudes from GLM fit) - ipsc = np.sum(b_ipsp_spikes_deleted[:,0:ncos]*fit_asc_amp,1) #THIS IS TOTAL POSTSPIKE CURRENT + llh = res.llf + fit_R = tauinit / (fitprs[ncos] * cinit) + fit_asc_amp = fitprs[:ncos] + # Compute and plot post-spike current (essentially multiply basis functions with correct amplitudes from GLM fit) + ipsc = np.sum(b_ipsp_spikes_deleted[:, 0:ncos] * fit_asc_amp, 1) # THIS IS TOTAL POSTSPIKE CURRENT except Exception as e: logging.warning("fit didn't work: " + str(e)) - llh=np.nan - fit_R=np.nan - fit_asc_amp=np.ones(ncos)*np.nan - ipsc=np.ones(len(b_ipsp_spikes_deleted[:,0]))*np.nan - + llh = np.nan + fit_R = np.nan + fit_asc_amp = np.ones(ncos) * np.nan + ipsc = np.ones(len(b_ipsp_spikes_deleted[:, 0])) * np.nan + R_for_each_sweep.append(fit_R) asc_amp_for_each_sweep.append(fit_asc_amp) llh_for_each_sweep.append(llh) - - #Plot a single instance of AS current as function of time (in ms) + + # Plot a single instance of AS current as function of time (in ms) if MAKE_PLOT: plt.figure(79) - plt.subplot(5,2, ks_ind+1) - plot_inds = np.arange(int(t0_list[0][0])-tst,int(t0_list[0][0])-tst+nkt) #plot just after first spike - tvec = dt*plot_inds - plt.plot(tvec,ipsc[plot_inds], lw=2, label='llh='+str(llh)) - plt.xlabel('time (s)') - plt.ylabel('current (A)') - plt.title("k's "+str(ks_SI_units)) - + plt.subplot(5, 2, ks_ind + 1) + plot_inds = np.arange( + int(t0_list[0][0]) - tst, int(t0_list[0][0]) - tst + nkt + ) # plot just after first spike + tvec = dt * plot_inds + plt.plot(tvec, ipsc[plot_inds], lw=2, label="llh=" + str(llh)) + plt.xlabel("time (s)") + plt.ylabel("current (A)") + plt.title("k's " + str(ks_SI_units)) + plt.figure(80) - plt.subplot(5,2,ks_ind+1) + plt.subplot(5, 2, ks_ind + 1) for ASC, KS in zip(fit_asc_amp, np.array(ks_SI_units)): - #TODO: MAKE SURE I SHOULD BE USING SI UNITS HERE {I think this is correct as it is what I am returning from the function] - #TODO: MAKE SURE THE SIGNS OF K ARE OK - t_plot=np.arange(10000)*dt - single_asc_trace=ASC*np.exp(-KS*t_plot) - plt.plot(t_plot, single_asc_trace, lw=2, label='sweep '+str(kkk)) - plt.xlabel('time (s)') - plt.ylabel('current (A)') - plt.title("k's "+str(ks_SI_units)) + # TODO: MAKE SURE I SHOULD BE USING SI UNITS HERE {I think this is correct as it is what I am returning from the function] + # TODO: MAKE SURE THE SIGNS OF K ARE OK + t_plot = np.arange(10000) * dt + single_asc_trace = ASC * np.exp(-KS * t_plot) + plt.plot(t_plot, single_asc_trace, lw=2, label="sweep " + str(kkk)) + plt.xlabel("time (s)") + plt.ylabel("current (A)") + plt.title("k's " + str(ks_SI_units)) if MAKE_PLOT: plt.figure(79) plt.tight_layout() - plt.annotate('ASGLM (fit asc and R): Sum fit after spike currents', - xy=(.3, .985), - xycoords='figure fraction', - horizontalalignment='left', verticalalignment='top', - fontsize=20) + plt.annotate( + "ASGLM (fit asc and R): Sum fit after spike currents", + xy=(0.3, 0.985), + xycoords="figure fraction", + horizontalalignment="left", + verticalalignment="top", + fontsize=20, + ) plt.legend() - + plt.figure(80) plt.tight_layout() - plt.annotate('ASGLM (fit asc and R): Individual Currents', - xy=(.3, .985), - xycoords='figure fraction', - horizontalalignment='left', verticalalignment='top', - fontsize=20) + plt.annotate( + "ASGLM (fit asc and R): Individual Currents", + xy=(0.3, 0.985), + xycoords="figure fraction", + horizontalalignment="left", + verticalalignment="top", + fontsize=20, + ) plt.legend() if SHOW_PLOT: plt.show(block=False) - - -# #BELOW IS OVER EVERY PAIR + + # #BELOW IS OVER EVERY PAIR R_for_all_ks_pairs.append(R_for_each_sweep) asc_amp_for_all_ks_pairs.append(asc_amp_for_each_sweep) llh_for_all_ks_pairs.append(llh_for_each_sweep) - #!!!!!!!!!!!!!we multiplied ks by dt for SI!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - ave_llh_for_each_pair=np.mean(llh_for_all_ks_pairs, axis=1) - best_ks_pair_ind = np.where(np.max(ave_llh_for_each_pair)==ave_llh_for_each_pair)[0][0] - - best_k_pair=np.array(ks_pairs[best_ks_pair_ind])/dt - best_asc_amp=np.array(asc_amp_for_all_ks_pairs[best_ks_pair_ind]) - best_R=np.array(R_for_all_ks_pairs[best_ks_pair_ind]) - best_llh=np.array(llh_for_all_ks_pairs[best_ks_pair_ind]) - - print('**********from ASGLM_pairwise******************************************') - print('best_ks_pair_ind', best_ks_pair_ind) - print('best_asc_amp', best_asc_amp) - print('best_k_pair', best_k_pair) - print('best_R', best_R) - print('best_llh', best_llh) - - print('**********done with ASGLM_pairwise***********************************') - + #!!!!!!!!!!!!!we multiplied ks by dt for SI!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + ave_llh_for_each_pair = np.mean(llh_for_all_ks_pairs, axis=1) + best_ks_pair_ind = np.where(np.max(ave_llh_for_each_pair) == ave_llh_for_each_pair)[0][0] + + best_k_pair = np.array(ks_pairs[best_ks_pair_ind]) / dt + best_asc_amp = np.array(asc_amp_for_all_ks_pairs[best_ks_pair_ind]) + best_R = np.array(R_for_all_ks_pairs[best_ks_pair_ind]) + best_llh = np.array(llh_for_all_ks_pairs[best_ks_pair_ind]) + + print("**********from ASGLM_pairwise******************************************") + print("best_ks_pair_ind", best_ks_pair_ind) + print("best_asc_amp", best_asc_amp) + print("best_k_pair", best_k_pair) + print("best_R", best_R) + print("best_llh", best_llh) + + print("**********done with ASGLM_pairwise***********************************") + return best_k_pair, best_asc_amp, best_R, best_llh diff --git a/allensdk/internal/model/glif/MLIN.py b/allensdk/internal/model/glif/MLIN.py index d42c657045..d646a07a43 100644 --- a/allensdk/internal/model/glif/MLIN.py +++ b/allensdk/internal/model/glif/MLIN.py @@ -4,140 +4,156 @@ from scipy.optimize import curve_fit import matplotlib.pyplot as plt + def MLIN(voltage, current, res, cap, dt, MAKE_PLOT=False, SHOW_PLOT=False, BLOCK=False, PUBLICATION_PLOT=False): - '''voltage, current + """voltage, current input: - voltage: numpy array of voltage with test pulse cut out - current: numpy array of stimulus with test pulse cut out ''' + voltage: numpy array of voltage with test pulse cut out + current: numpy array of stimulus with test pulse cut out""" t = np.arange(0, len(current)) * dt (_, _, _, start_idx, end_idx) = get_square_stim_characteristics(current, t, no_test_pulse=True) stim_len = end_idx - start_idx - distribution_start_ind=start_idx + int(.5/dt) - distribution_end_ind=start_idx + stim_len - - v_section=voltage[distribution_start_ind:distribution_end_ind] + distribution_start_ind = start_idx + int(0.5 / dt) + distribution_end_ind = start_idx + stim_len + + v_section = voltage[distribution_start_ind:distribution_end_ind] if MAKE_PLOT: - times=np.arange(0, len(voltage))*dt - plt.figure(figsize=(15, 11)) - plt.subplot2grid((7,2), (0,0), colspan=2) + times = np.arange(0, len(voltage)) * dt + plt.figure(figsize=(15, 11)) + plt.subplot2grid((7, 2), (0, 0), colspan=2) plt.plot(times[distribution_start_ind:distribution_end_ind], v_section) - plt.title('voltage for histogram') + plt.title("voltage for histogram") print(v_section) - v_section=v_section-np.mean(v_section) - var_of_section=np.var(v_section) - sv_for_expsymm=np.std(v_section)/np.sqrt(2) - subthreshold_long_square_voltage_distribution=stats.norm(loc=0, scale=np.sqrt(var_of_section)) - - #--autocorrelation - tau_4AC=res*cap - AC=autocorr(v_section-np.mean(v_section)) - ACtime=np.arange(0,len(AC))*dt - - #--fit autocorrelation with decaying exponential - (popt, pcov)= curve_fit(exp_decay, ACtime, AC, p0=[AC[0],tau_4AC]) - tau_from_AC=popt[1] - - if MAKE_PLOT: - plt.subplot2grid((7,2), (1,0), rowspan=3) - plt.hist(v_section, bins=50, normed=True, label='data') - data_grid=np.arange(min(v_section), max(v_section), abs(min(v_section)-max(v_section))/100.) - plt.plot(data_grid, subthreshold_long_square_voltage_distribution.pdf(data_grid), 'r', label='gauss with\nmeasured var') - plt.plot(data_grid, expsymm_pdf(data_grid, sv_for_expsymm), 'm', lw=3, label='expsymm function') - plt.xlabel('voltage (mV)') - plt.title('Mean subtracted voltage hist') + v_section = v_section - np.mean(v_section) + var_of_section = np.var(v_section) + sv_for_expsymm = np.std(v_section) / np.sqrt(2) + subthreshold_long_square_voltage_distribution = stats.norm(loc=0, scale=np.sqrt(var_of_section)) + + # --autocorrelation + tau_4AC = res * cap + AC = autocorr(v_section - np.mean(v_section)) + ACtime = np.arange(0, len(AC)) * dt + + # --fit autocorrelation with decaying exponential + (popt, pcov) = curve_fit(exp_decay, ACtime, AC, p0=[AC[0], tau_4AC]) + tau_from_AC = popt[1] + + if MAKE_PLOT: + plt.subplot2grid((7, 2), (1, 0), rowspan=3) + plt.hist(v_section, bins=50, normed=True, label="data") + data_grid = np.arange(min(v_section), max(v_section), abs(min(v_section) - max(v_section)) / 100.0) + plt.plot( + data_grid, + subthreshold_long_square_voltage_distribution.pdf(data_grid), + "r", + label="gauss with\nmeasured var", + ) + plt.plot(data_grid, expsymm_pdf(data_grid, sv_for_expsymm), "m", lw=3, label="expsymm function") + plt.xlabel("voltage (mV)") + plt.title("Mean subtracted voltage hist") plt.legend() - - #--cumulative density function - (h, edges)=np.histogram(v_section, bins=50) - centers=find_bin_center(edges) - - CDFx=centers - CDFy=np.cumsum(h)/float(len(v_section)) - - plt.subplot2grid((7,2), (4,0), rowspan=3) - plt.plot(CDFx, CDFy, label='data') -# plt.plot(CDFx, sig(CDFx, popt[0], popt[1]), label='fit') - plt.plot(data_grid, subthreshold_long_square_voltage_distribution.cdf(data_grid), 'r', label='gauss with\nmeasured var') - plt.plot(data_grid, expsymm_cdf(data_grid, sv_for_expsymm), 'm', lw=3, label='expsymm func') - plt.title('Normalized cumulative sum') - plt.xlabel('v-mean(v)') + + # --cumulative density function + (h, edges) = np.histogram(v_section, bins=50) + centers = find_bin_center(edges) + + CDFx = centers + CDFy = np.cumsum(h) / float(len(v_section)) + + plt.subplot2grid((7, 2), (4, 0), rowspan=3) + plt.plot(CDFx, CDFy, label="data") + # plt.plot(CDFx, sig(CDFx, popt[0], popt[1]), label='fit') + plt.plot( + data_grid, + subthreshold_long_square_voltage_distribution.cdf(data_grid), + "r", + label="gauss with\nmeasured var", + ) + plt.plot(data_grid, expsymm_cdf(data_grid, sv_for_expsymm), "m", lw=3, label="expsymm func") + plt.title("Normalized cumulative sum") + plt.xlabel("v-mean(v)") plt.legend() - - plt.subplot2grid((7,2), (1,1), rowspan=3) - plt.plot(ACtime, AC, label='data') - plt.xlabel('shift (s)') - plt.title('Auto correlation') - plt.plot(ACtime, exp_decay(ACtime, AC[0], tau_4AC), label='RC') - plt.plot(ACtime, exp_decay(ACtime, popt[0], tau_from_AC), label='fit') + + plt.subplot2grid((7, 2), (1, 1), rowspan=3) + plt.plot(ACtime, AC, label="data") + plt.xlabel("shift (s)") + plt.title("Auto correlation") + plt.plot(ACtime, exp_decay(ACtime, AC[0], tau_4AC), label="RC") + plt.plot(ACtime, exp_decay(ACtime, popt[0], tau_from_AC), label="fit") plt.legend() - + plt.tight_layout() if SHOW_PLOT: - plt.show(block=BLOCK) - - if PUBLICATION_PLOT: - times=np.arange(0, len(voltage))*dt - plt.figure(figsize=(14, 7)) - plt.subplot2grid((3,3), (0,0), colspan=3) - plt.xlabel('time (s)', fontsize=14) - plt.ylabel('(mV)', fontsize=14) - plt.plot(times[distribution_start_ind:distribution_end_ind], v_section*1.e3) - plt.title('Voltage for histogram', fontsize=16) - - plt.subplot2grid((3,3), (1,0), rowspan=2) - plt.hist(v_section*1.e3, bins=50, normed=True, label='data') - data_grid=np.arange(min(v_section), max(v_section), abs(min(v_section)-max(v_section))/100.) -# plt.plot(data_grid, subthreshold_long_square_voltage_distribution.pdf(data_grid), 'r', label='gauss with\nmeasured var') - plt.plot(data_grid*1.e3, 1.e-3*expsymm_pdf(data_grid, sv_for_expsymm), 'm', lw=3, label='expsymm ') - plt.xlabel('voltage (mV)', fontsize=14) - plt.title('Mean subtracted voltage hist', fontsize=16) + plt.show(block=BLOCK) + + if PUBLICATION_PLOT: + times = np.arange(0, len(voltage)) * dt + plt.figure(figsize=(14, 7)) + plt.subplot2grid((3, 3), (0, 0), colspan=3) + plt.xlabel("time (s)", fontsize=14) + plt.ylabel("(mV)", fontsize=14) + plt.plot(times[distribution_start_ind:distribution_end_ind], v_section * 1.0e3) + plt.title("Voltage for histogram", fontsize=16) + + plt.subplot2grid((3, 3), (1, 0), rowspan=2) + plt.hist(v_section * 1.0e3, bins=50, normed=True, label="data") + data_grid = np.arange(min(v_section), max(v_section), abs(min(v_section) - max(v_section)) / 100.0) + # plt.plot(data_grid, subthreshold_long_square_voltage_distribution.pdf(data_grid), 'r', label='gauss with\nmeasured var') + plt.plot(data_grid * 1.0e3, 1.0e-3 * expsymm_pdf(data_grid, sv_for_expsymm), "m", lw=3, label="expsymm ") + plt.xlabel("voltage (mV)", fontsize=14) + plt.title("Mean subtracted voltage hist", fontsize=16) plt.legend(loc=1) - - #--cumulative density function - (h, edges)=np.histogram(v_section, bins=50) - centers=find_bin_center(edges) - - CDFx=centers - CDFy=np.cumsum(h)/float(len(v_section)) - - plt.subplot2grid((3,3), (1,1), rowspan=2) - plt.plot(CDFx*1e3, CDFy, label='data') - # plt.plot(CDFx, sig(CDFx, popt[0], popt[1]), label='fit') -# plt.plot(data_grid, subthreshold_long_square_voltage_distribution.cdf(data_grid), 'r', label='gauss with\nmeasured var') - plt.plot(data_grid*1.e3, expsymm_cdf(data_grid, sv_for_expsymm), 'm', lw=3, label='expsymm') - plt.title('Normalized cumulative sum', fontsize=16) - plt.xlabel('V-mean(V) (mV)', fontsize=16) + + # --cumulative density function + (h, edges) = np.histogram(v_section, bins=50) + centers = find_bin_center(edges) + + CDFx = centers + CDFy = np.cumsum(h) / float(len(v_section)) + + plt.subplot2grid((3, 3), (1, 1), rowspan=2) + plt.plot(CDFx * 1e3, CDFy, label="data") + # plt.plot(CDFx, sig(CDFx, popt[0], popt[1]), label='fit') + # plt.plot(data_grid, subthreshold_long_square_voltage_distribution.cdf(data_grid), 'r', label='gauss with\nmeasured var') + plt.plot(data_grid * 1.0e3, expsymm_cdf(data_grid, sv_for_expsymm), "m", lw=3, label="expsymm") + plt.title("Normalized cumulative sum", fontsize=16) + plt.xlabel("V-mean(V) (mV)", fontsize=16) plt.legend(loc=2, fontsize=14) - - plt.subplot2grid((3,3), (1,2), rowspan=2) - plt.plot(ACtime, AC*1.e3, label='data') - plt.xlabel('shift (s)', fontsize=14) - plt.title('Auto correlation', fontsize=16) -# plt.plot(ACtime, exp_decay(ACtime, AC[0], tau_4AC), label='RC') - plt.plot(ACtime, exp_decay(ACtime, popt[0]*1.e3, tau_from_AC), lw=3, label='fit') + + plt.subplot2grid((3, 3), (1, 2), rowspan=2) + plt.plot(ACtime, AC * 1.0e3, label="data") + plt.xlabel("shift (s)", fontsize=14) + plt.title("Auto correlation", fontsize=16) + # plt.plot(ACtime, exp_decay(ACtime, AC[0], tau_4AC), label='RC') + plt.plot(ACtime, exp_decay(ACtime, popt[0] * 1.0e3, tau_from_AC), lw=3, label="fit") plt.legend(loc=1) plt.tight_layout() - + return var_of_section, sv_for_expsymm, tau_from_AC + def expsymm_pdf(v, dv): - return 1./(2.*dv)*np.exp(-np.absolute(v)/dv) + return 1.0 / (2.0 * dv) * np.exp(-np.absolute(v) / dv) + def expsymm_cdf(v, dv): - return 1./2.+(v*(1-np.exp(-np.absolute(v)/dv)))/(2.*np.absolute(v)) + return 1.0 / 2.0 + (v * (1 - np.exp(-np.absolute(v) / dv))) / (2.0 * np.absolute(v)) + def exp_decay(time, amp, tau): - return amp*np.exp(-time/tau) + return amp * np.exp(-time / tau) + def find_bin_center(edges): - centers=np.zeros(len(edges)-1) - for ii in range(0, len(edges)-1): - centers[ii]=np.mean([edges[ii], edges[ii+1]]) + centers = np.zeros(len(edges) - 1) + for ii in range(0, len(edges) - 1): + centers[ii] = np.mean([edges[ii], edges[ii + 1]]) return centers + def autocorr(x): - result = np.correlate(x, x, mode='full') -# return result - return result[result.size/2:] + result = np.correlate(x, x, mode="full") + # return result + return result[result.size / 2 :] diff --git a/allensdk/internal/model/glif/are_two_lists_of_arrays_the_same.py b/allensdk/internal/model/glif/are_two_lists_of_arrays_the_same.py index dd3b81548a..f13defe9fc 100644 --- a/allensdk/internal/model/glif/are_two_lists_of_arrays_the_same.py +++ b/allensdk/internal/model/glif/are_two_lists_of_arrays_the_same.py @@ -1,16 +1,15 @@ import numpy as np -def are_two_lists_of_arrays_the_same(data1, data2): - '''returns False if to lists of arrays are different. + +def are_two_lists_of_arrays_the_same(data1, data2): + """returns False if to lists of arrays are different. otherwise the function returns True. - ''' - - if len(data1) != len(data2): + """ + + if len(data1) != len(data2): return False - for a,b in zip(data1,data2): + for a, b in zip(data1, data2): if np.any(a != b): return False - + return True - - \ No newline at end of file diff --git a/allensdk/internal/model/glif/configure_model.py b/allensdk/internal/model/glif/configure_model.py index fe3c8da46a..f19e24bab5 100644 --- a/allensdk/internal/model/glif/configure_model.py +++ b/allensdk/internal/model/glif/configure_model.py @@ -1,218 +1,228 @@ -#going to need to take preprocessed dictionaries and model configuration and create -#a preprocessed model configuration and preprocessed_config file +# going to need to take preprocessed dictionaries and model configuration and create +# a preprocessed model configuration and preprocessed_config file # -#something will have to tell it what parameters to take out of the preprocessed dict +# something will have to tell it what parameters to take out of the preprocessed dict import numpy as np import argparse import allensdk.core.json_utilities as ju -class ModelConfigurationException( Exception ): + +class ModelConfigurationException(Exception): pass + DEFAULT_NEURON_PARAMETERS = { - "type": "GLIF", - "dt": 5e-05, + "type": "GLIF", + "dt": 5e-05, "El": 0, - "asc_tau_array": [ 1, 1 ], - "asc_amp_array": [ 0, 0 ], - "init_AScurrents": [ 0.0, 0.0 ], - "init_threshold": 0.02, + "asc_tau_array": [1, 1], + "asc_amp_array": [0, 0], + "init_AScurrents": [0.0, 0.0], + "init_threshold": 0.02, "init_voltage": 0.0, "extrapolation_method_name": "endpoints", - "dt_multiplier": 1 - } + "dt_multiplier": 1, +} -DEFAULT_OPTIMIZER_PARAMETERS = { - "xtol": 1e-05, - "ftol": 1e-05, - "sigma_outer": 0.3, - "sigma_inner": 0.01, - "inner_iterations": 3, - "outer_iterations": 3, +DEFAULT_OPTIMIZER_PARAMETERS = { + "xtol": 1e-05, + "ftol": 1e-05, + "sigma_outer": 0.3, + "sigma_inner": 0.01, + "inner_iterations": 3, + "outer_iterations": 3, "internal_iterations": 10000000, "iteration_info": [], "param_fit_names": [], "cut": 0, - "bessel": { 'N': 4, 'freq': 10000 } - } - + "bessel": {"N": 4, "freq": 10000}, +} + def specify_parameter_groups(dictionary, dict_specifer, neuron_type): - '''Specifies which values from the preprocessor will be used in the model configuration. - This is helpful if the preprocessor calculates many different values. - + """Specifies which values from the preprocessor will be used in the model configuration. + This is helpful if the preprocessor calculates many different values. + Parameters ---------- dictionary: dict dictionary from preprocessor dict_specifier: string - The following are available model levels + The following are available model levels 'LIF' (GLIF1) 'LIF_R' (GLIF2) 'LIF_ASC' (GLIF3) - 'LIF_R_ASC' (GLIF_4) + 'LIF_R_ASC' (GLIF_4) 'LIF_R_ASC_AT' (GLIF_5) neuron_type: string - 'simple_neuron' is the only available option however here would be a good place for + 'simple_neuron' is the only available option however here would be a good place for the user to implement their own configurations. - + Returns ------- output_dict: dict dictionary containing model configuration - ''' - - output_dict={'El_reference':dictionary['El']['El_noise']['measured']['mean'], - 'El':0., - 'dt':dictionary['dt_used_for_preprocessor_calculations'], - 'spike_cut_length':dictionary['spike_cutting']['NOdeltaV']['cut_length'], - 'spike_cutting_intercept':dictionary['spike_cutting']['NOdeltaV']['intercept'], - 'spike_cutting_slope':dictionary['spike_cutting']['NOdeltaV']['slope'], - 'asc_amp_array':dictionary['asc']['amp'], - 'asc_tau_array':(1./np.array(dictionary['asc']['k'])).tolist(), - 'th_inf': dictionary['th_inf']['via_Vmeasure']['from_zero'], - 'deltaV': None, - 'threshold_adaptation': {'a_spike_component_of_threshold': dictionary['threshold_adaptation']['a_spike_component_of_threshold'], - 'b_spike_component_of_threshold':dictionary['threshold_adaptation']['b_spike_component_of_threshold'], - 'a_voltage_component_of_threshold':dictionary['threshold_adaptation']['a_voltage_comp_of_thr_from_fitab'], - 'b_voltage_component_of_threshold': dictionary['threshold_adaptation']['b_voltage_comp_of_thr_from_fitab']}, - 'MLIN': dictionary['MLIN'], - 'spike_inds': { - 'noise1': [ ], - 'noise2': [ ] - } - } - - # specify specific values different for different levels. Although there is only one neuron + """ + + output_dict = { + "El_reference": dictionary["El"]["El_noise"]["measured"]["mean"], + "El": 0.0, + "dt": dictionary["dt_used_for_preprocessor_calculations"], + "spike_cut_length": dictionary["spike_cutting"]["NOdeltaV"]["cut_length"], + "spike_cutting_intercept": dictionary["spike_cutting"]["NOdeltaV"]["intercept"], + "spike_cutting_slope": dictionary["spike_cutting"]["NOdeltaV"]["slope"], + "asc_amp_array": dictionary["asc"]["amp"], + "asc_tau_array": (1.0 / np.array(dictionary["asc"]["k"])).tolist(), + "th_inf": dictionary["th_inf"]["via_Vmeasure"]["from_zero"], + "deltaV": None, + "threshold_adaptation": { + "a_spike_component_of_threshold": dictionary["threshold_adaptation"]["a_spike_component_of_threshold"], + "b_spike_component_of_threshold": dictionary["threshold_adaptation"]["b_spike_component_of_threshold"], + "a_voltage_component_of_threshold": dictionary["threshold_adaptation"]["a_voltage_comp_of_thr_from_fitab"], + "b_voltage_component_of_threshold": dictionary["threshold_adaptation"]["b_voltage_comp_of_thr_from_fitab"], + }, + "MLIN": dictionary["MLIN"], + "spike_inds": {"noise1": [], "noise2": []}, + } + + # specify specific values different for different levels. Although there is only one neuron # type here, this would be a good place to add other user defined neuron types - if neuron_type=='simple_neuron': - output_dict['C']=dictionary['capacitance']['C_test_list']['mean'] - if dict_specifer in ['LIF', 'LIF_R']: - output_dict['R_input']=dictionary['resistance']['R_test_list']['mean'] - elif dict_specifer in ['LIF_ASC', 'LIF_R_ASC', 'LIF_R_ASC_AT']: - output_dict['R_input']=dictionary['resistance']['R_fit_ASC_and_R']['mean'] - - for k,v in dictionary['sweep_properties']['noise1'].items(): - output_dict['spike_inds']['noise1'].append( v['spike_ind'] ) - output_dict['spike_inds']['noise2'].append( v['spike_ind'] ) - + if neuron_type == "simple_neuron": + output_dict["C"] = dictionary["capacitance"]["C_test_list"]["mean"] + if dict_specifer in ["LIF", "LIF_R"]: + output_dict["R_input"] = dictionary["resistance"]["R_test_list"]["mean"] + elif dict_specifer in ["LIF_ASC", "LIF_R_ASC", "LIF_R_ASC_AT"]: + output_dict["R_input"] = dictionary["resistance"]["R_fit_ASC_and_R"]["mean"] + + for k, v in dictionary["sweep_properties"]["noise1"].items(): + output_dict["spike_inds"]["noise1"].append(v["spike_ind"]) + output_dict["spike_inds"]["noise2"].append(v["spike_ind"]) + return output_dict def validate_method_requirements(method_config_name, has_mss): - '''Confirm that the neuron has the specific sweeps required for the specified configuration - + """Confirm that the neuron has the specific sweeps required for the specified configuration + Parameters ---------- method_config_name: string - Specifies the model level. Options are: + Specifies the model level. Options are: 'LIF' (GLIF1) 'LIF_R' (GLIF2) 'LIF_ASC' (GLIF3) - 'LIF_R_ASC' (GLIF_4) + 'LIF_R_ASC' (GLIF_4) 'LIF_R_ASC_AT' (GLIF_5) has_mss: boolean Specifies if the neuron has a multi short square sweep (for fitting spike component of threshold). - ''' + """ if not has_mss: - valid_configs = ['LIF', 'LIF_ASC'] + valid_configs = ["LIF", "LIF_ASC"] else: - valid_configs = ['LIF', 'LIF_ASC','LIF_R', 'LIF_R_ASC', 'LIF_R_ASC_AT'] + valid_configs = ["LIF", "LIF_ASC", "LIF_R", "LIF_R_ASC", "LIF_R_ASC_AT"] if method_config_name not in valid_configs: - raise ModelConfigurationException("Model type %s cannot be configured due to missing data (mss: %s)" % ( method_config_name, str(has_mss))) + raise ModelConfigurationException( + "Model type %s cannot be configured due to missing data (mss: %s)" % (method_config_name, str(has_mss)) + ) + def update_neuron_method(method_type, arg_method_name, neuron_config): - #TODO: documentation - neuron_config[method_type] = { 'name': arg_method_name, 'params': None } + # TODO: documentation + neuron_config[method_type] = {"name": arg_method_name, "params": None} def configure_model(method_config, preprocessor_values): - '''Configures the model from the specified method configuration and preprocessor values. - + """Configures the model from the specified method configuration and preprocessor values. + Parameters ---------- method_config: dictionary contains values needed to configure the methods for the specified level within the dictionary preprocessor_values: dictionary dictionary from preprocessor - ''' - - preprocessor_values = specify_parameter_groups(preprocessor_values, method_config['name'], 'simple_neuron') + """ + + preprocessor_values = specify_parameter_groups(preprocessor_values, method_config["name"], "simple_neuron") neuron_config = {} neuron_config.update(DEFAULT_NEURON_PARAMETERS) optimizer_config = {} optimizer_config.update(DEFAULT_OPTIMIZER_PARAMETERS) - #a) select values want to use out of the preprocessor_values via specifying parameter_gropus - #b) look what levels are available via the levels available in the preprocessor_values. - + # a) select values want to use out of the preprocessor_values via specifying parameter_gropus + # b) look what levels are available via the levels available in the preprocessor_values. + # Skip trace if subthreshold noise has a spike in it. - noise1_ind = [ n1i for n1i in preprocessor_values['spike_inds']['noise1'] if n1i is not None ] + noise1_ind = [n1i for n1i in preprocessor_values["spike_inds"]["noise1"] if n1i is not None] noise1_ind = np.concatenate(noise1_ind) - if np.any(noise1_ind * preprocessor_values['dt'] < 8.0): + if np.any(noise1_ind * preprocessor_values["dt"] < 8.0): raise ModelConfigurationException("Subthreshold region of noise1 stimulus contains spikes.") # check if there is a short square triple - if preprocessor_values['threshold_adaptation']['b_spike_component_of_threshold'] and preprocessor_values['threshold_adaptation']['a_spike_component_of_threshold']: - has_mss=True + if ( + preprocessor_values["threshold_adaptation"]["b_spike_component_of_threshold"] + and preprocessor_values["threshold_adaptation"]["a_spike_component_of_threshold"] + ): + has_mss = True else: - has_mss=False - + has_mss = False + # make sure that the requested method config meets minimum requirements - validate_method_requirements(method_config['name'], has_mss) - - update_neuron_method('AScurrent_dynamics_method', method_config['AScurrent_dynamics_method'], neuron_config) - update_neuron_method('voltage_dynamics_method', method_config['voltage_dynamics_method'], neuron_config) - update_neuron_method('threshold_dynamics_method', method_config['threshold_dynamics_method'], neuron_config) - update_neuron_method('AScurrent_reset_method', method_config['AScurrent_reset_method'], neuron_config) - update_neuron_method('voltage_reset_method', method_config['voltage_reset_method'], neuron_config) - update_neuron_method('threshold_reset_method', method_config['threshold_reset_method'], neuron_config) - - neuron_config['El_reference'] = preprocessor_values['El_reference'] - neuron_config['C'] = preprocessor_values['C'] - neuron_config['El'] = preprocessor_values['El'] - neuron_config['spike_cut_length'] = preprocessor_values['spike_cut_length'] - neuron_config['asc_amp_array'] = preprocessor_values['asc_amp_array'] - neuron_config['asc_tau_array'] = preprocessor_values['asc_tau_array'] - neuron_config['R_input'] = preprocessor_values['R_input'] - neuron_config['th_inf'] = preprocessor_values['th_inf'] - - optimizer_config['error_function'] = method_config['error_function'] - optimizer_config['param_fit_names'] = method_config['param_fit_names'] - - #b) choose the sets want from the preprocessor_values - configure_method_parameters(neuron_config, - optimizer_config, - preprocessor_values['spike_cutting_slope'], - preprocessor_values['spike_cutting_intercept'], - preprocessor_values['threshold_adaptation']['a_spike_component_of_threshold'], - preprocessor_values['threshold_adaptation']['b_spike_component_of_threshold'], - preprocessor_values['threshold_adaptation']['a_voltage_component_of_threshold'], - preprocessor_values['threshold_adaptation']['b_voltage_component_of_threshold'], - preprocessor_values['MLIN']['var_of_section'], - preprocessor_values['MLIN']['sv_for_expsymm'], - preprocessor_values['MLIN']['tau_from_AC']) - - return { - 'neuron': neuron_config, - 'optimizer': optimizer_config - } + validate_method_requirements(method_config["name"], has_mss) + + update_neuron_method("AScurrent_dynamics_method", method_config["AScurrent_dynamics_method"], neuron_config) + update_neuron_method("voltage_dynamics_method", method_config["voltage_dynamics_method"], neuron_config) + update_neuron_method("threshold_dynamics_method", method_config["threshold_dynamics_method"], neuron_config) + update_neuron_method("AScurrent_reset_method", method_config["AScurrent_reset_method"], neuron_config) + update_neuron_method("voltage_reset_method", method_config["voltage_reset_method"], neuron_config) + update_neuron_method("threshold_reset_method", method_config["threshold_reset_method"], neuron_config) + + neuron_config["El_reference"] = preprocessor_values["El_reference"] + neuron_config["C"] = preprocessor_values["C"] + neuron_config["El"] = preprocessor_values["El"] + neuron_config["spike_cut_length"] = preprocessor_values["spike_cut_length"] + neuron_config["asc_amp_array"] = preprocessor_values["asc_amp_array"] + neuron_config["asc_tau_array"] = preprocessor_values["asc_tau_array"] + neuron_config["R_input"] = preprocessor_values["R_input"] + neuron_config["th_inf"] = preprocessor_values["th_inf"] + + optimizer_config["error_function"] = method_config["error_function"] + optimizer_config["param_fit_names"] = method_config["param_fit_names"] + + # b) choose the sets want from the preprocessor_values + configure_method_parameters( + neuron_config, + optimizer_config, + preprocessor_values["spike_cutting_slope"], + preprocessor_values["spike_cutting_intercept"], + preprocessor_values["threshold_adaptation"]["a_spike_component_of_threshold"], + preprocessor_values["threshold_adaptation"]["b_spike_component_of_threshold"], + preprocessor_values["threshold_adaptation"]["a_voltage_component_of_threshold"], + preprocessor_values["threshold_adaptation"]["b_voltage_component_of_threshold"], + preprocessor_values["MLIN"]["var_of_section"], + preprocessor_values["MLIN"]["sv_for_expsymm"], + preprocessor_values["MLIN"]["tau_from_AC"], + ) + + return {"neuron": neuron_config, "optimizer": optimizer_config} + + +def configure_method_parameters( + neuron_config, + optimizer_config, + v_reset_slope, + v_reset_intercept, + a_spike_component_of_threshold, + b_spike_component_of_threshold, + a_voltage_component_of_threshold, + b_voltage_component_of_threshold, + var_of_section, + sv_for_expsymm, + tau_from_AC, +): + """Configures the methods used to run the models -def configure_method_parameters(neuron_config, - optimizer_config, - v_reset_slope, - v_reset_intercept, - a_spike_component_of_threshold, - b_spike_component_of_threshold, - a_voltage_component_of_threshold, - b_voltage_component_of_threshold, - var_of_section, - sv_for_expsymm, - tau_from_AC): - '''Configures the methods used to run the models - Parameters ---------- neuron_config: dict @@ -237,162 +247,155 @@ def configure_method_parameters(neuron_config, parameter in MLIN optimization tau_from_AC: float time course of exponential fit to the autocorrelation - ''' - + """ + # configure voltage reset rules - method_config = neuron_config['voltage_reset_method'] - if method_config.get('params', None) is None: - if method_config['name'] == 'zero': - method_config['params'] = {} - elif method_config['name'] == 'v_before': - method_config['params'] = { - 'a': v_reset_slope, - 'b': v_reset_intercept - } - - elif method_config['name'] == 'i_v_before': - method_config['params'] = { - 'a': 1, - 'b': 2, - 'c': 3 - } - raise ModelConfigurationException('i_v_before of voltage reset method is not yet implemented') - elif method_config['name'] == 'fixed': - raise ModelConfigurationException('cannot use fixed voltage reset method in preprocessor') + method_config = neuron_config["voltage_reset_method"] + if method_config.get("params", None) is None: + if method_config["name"] == "zero": + method_config["params"] = {} + elif method_config["name"] == "v_before": + method_config["params"] = {"a": v_reset_slope, "b": v_reset_intercept} + + elif method_config["name"] == "i_v_before": + method_config["params"] = {"a": 1, "b": 2, "c": 3} + raise ModelConfigurationException("i_v_before of voltage reset method is not yet implemented") + elif method_config["name"] == "fixed": + raise ModelConfigurationException("cannot use fixed voltage reset method in preprocessor") else: - method_config['params'] = {} - - # configure threshold reset rules - method_config = neuron_config['threshold_reset_method'] - if method_config.get('params', None) is None: - coeff_th_inf = neuron_config.get('coeffs', {}).get('th_inf',1.0) - adjusted_th_inf = neuron_config['th_inf'] * coeff_th_inf - - if method_config['name'] == 'max_v_th': - raise ModelConfigurationException('max_v_th threshold reset rule is not currently in use') - - elif method_config['name'] == 'th_before': - raise ModelConfigurationException('th_before is not currently in use') - - elif method_config['name'] == 'inf': - method_config['params'] = {} - neuron_config['init_threshold'] = adjusted_th_inf - - elif method_config['name'] == 'three_components': - method_config['params'] = { 'a_spike': a_spike_component_of_threshold, - 'b_spike': b_spike_component_of_threshold } - neuron_config['init_threshold'] = adjusted_th_inf - - elif method_config['name'] == 'fixed': + method_config["params"] = {} + + # configure threshold reset rules + method_config = neuron_config["threshold_reset_method"] + if method_config.get("params", None) is None: + coeff_th_inf = neuron_config.get("coeffs", {}).get("th_inf", 1.0) + adjusted_th_inf = neuron_config["th_inf"] * coeff_th_inf + + if method_config["name"] == "max_v_th": + raise ModelConfigurationException("max_v_th threshold reset rule is not currently in use") + + elif method_config["name"] == "th_before": + raise ModelConfigurationException("th_before is not currently in use") + + elif method_config["name"] == "inf": + method_config["params"] = {} + neuron_config["init_threshold"] = adjusted_th_inf + + elif method_config["name"] == "three_components": + method_config["params"] = { + "a_spike": a_spike_component_of_threshold, + "b_spike": b_spike_component_of_threshold, + } + neuron_config["init_threshold"] = adjusted_th_inf + + elif method_config["name"] == "fixed": raise ModelConfigurationException("cannot use fixed threshold reset method in preprocessor") - else: - raise ModelConfigurationException("unknown threshold reset method: ", method_config['name']) - + else: + raise ModelConfigurationException("unknown threshold reset method: ", method_config["name"]) + # configure voltage dynamics rules - - method_config = neuron_config['voltage_dynamics_method'] - if method_config.get('params', None) is None: - if method_config['name'] == 'quadratic_i_of_v': - raise ModelConfigurationException('quadraticIofV of voltage_dynamics_method preprocessing is not yet implemented') - elif method_config['name'] == 'linear_forward_euler': - method_config['params'] = {} - elif method_config['name'] == 'linear_exact': - method_config['params'] = {} + + method_config = neuron_config["voltage_dynamics_method"] + if method_config.get("params", None) is None: + if method_config["name"] == "quadratic_i_of_v": + raise ModelConfigurationException( + "quadraticIofV of voltage_dynamics_method preprocessing is not yet implemented" + ) + elif method_config["name"] == "linear_forward_euler": + method_config["params"] = {} + elif method_config["name"] == "linear_exact": + method_config["params"] = {} else: - raise ModelConfigurationException("unknown voltage dynamics method: ", method_config['name']) - - # configure threshold dynamics rules - method_config = neuron_config['threshold_dynamics_method'] - - if method_config.get('params', None) is None: - if method_config['name'] == 'three_components_forward': - method_config['params'] = { - 'a_spike': a_spike_component_of_threshold, - 'b_spike': b_spike_component_of_threshold, - 'a_voltage': a_voltage_component_of_threshold, - 'b_voltage': b_voltage_component_of_threshold - } - - elif method_config['name'] == 'three_components_exact': - method_config['params'] = { - 'a_spike': a_spike_component_of_threshold, - 'b_spike': b_spike_component_of_threshold, - 'a_voltage': a_voltage_component_of_threshold, - 'b_voltage': b_voltage_component_of_threshold - } - - elif method_config['name'] == 'spike_component': - method_config['params'] = { - 'a_spike': a_spike_component_of_threshold, - 'b_spike': b_spike_component_of_threshold, - 'a_voltage': 0, - 'b_voltage': 0 - } - - elif method_config['name'] == 'inf': - method_config['params'] = {} - + raise ModelConfigurationException("unknown voltage dynamics method: ", method_config["name"]) + + # configure threshold dynamics rules + method_config = neuron_config["threshold_dynamics_method"] + + if method_config.get("params", None) is None: + if method_config["name"] == "three_components_forward": + method_config["params"] = { + "a_spike": a_spike_component_of_threshold, + "b_spike": b_spike_component_of_threshold, + "a_voltage": a_voltage_component_of_threshold, + "b_voltage": b_voltage_component_of_threshold, + } + + elif method_config["name"] == "three_components_exact": + method_config["params"] = { + "a_spike": a_spike_component_of_threshold, + "b_spike": b_spike_component_of_threshold, + "a_voltage": a_voltage_component_of_threshold, + "b_voltage": b_voltage_component_of_threshold, + } + + elif method_config["name"] == "spike_component": + method_config["params"] = { + "a_spike": a_spike_component_of_threshold, + "b_spike": b_spike_component_of_threshold, + "a_voltage": 0, + "b_voltage": 0, + } + + elif method_config["name"] == "inf": + method_config["params"] = {} + else: - raise ModelConfigurationException("unknown threshold dynamics method: ", method_config['name']) - + raise ModelConfigurationException("unknown threshold dynamics method: ", method_config["name"]) + # configure ascurrent dynamics rules - method_config = neuron_config['AScurrent_dynamics_method'] - if method_config.get('params', None) is None: + method_config = neuron_config["AScurrent_dynamics_method"] + if method_config.get("params", None) is None: # TODO: rename 'vector' to something more specific - if method_config['name'] == 'vector': - method_config['params'] = { - 'vector': [1, 2, 3] - } - raise ModelConfigurationException('vector of AScurrent_dynamics_method is not yet implemented') - elif method_config['name'] == 'none': - method_config['params'] = {} - elif method_config['name'] == 'exp': - method_config['params'] = {} + if method_config["name"] == "vector": + method_config["params"] = {"vector": [1, 2, 3]} + raise ModelConfigurationException("vector of AScurrent_dynamics_method is not yet implemented") + elif method_config["name"] == "none": + method_config["params"] = {} + elif method_config["name"] == "exp": + method_config["params"] = {} else: - raise ModelConfigurationException("unknown AScurrent dynamics method: ", method_config['name']) - - # configure ascurrent reset rule + raise ModelConfigurationException("unknown AScurrent dynamics method: ", method_config["name"]) + + # configure ascurrent reset rule # this is down here because it depends on numbers computed for the AScurrent_dynamics_method - method_config = neuron_config['AScurrent_reset_method'] - if method_config.get('params', None) is None: - if method_config['name'] == 'sum': - method_config['params'] = { - 'r': np.ones(len(neuron_config['asc_tau_array'])) - } - elif method_config['name'] == 'none': - method_config['params'] = {} + method_config = neuron_config["AScurrent_reset_method"] + if method_config.get("params", None) is None: + if method_config["name"] == "sum": + method_config["params"] = {"r": np.ones(len(neuron_config["asc_tau_array"]))} + elif method_config["name"] == "none": + method_config["params"] = {} else: - raise ModelConfigurationException("unknown AScurrent reset method: ", method_config['name']) - + raise ModelConfigurationException("unknown AScurrent reset method: ", method_config["name"]) + # configure parameters for MLIN optimization - if optimizer_config['error_function']=='MLIN': - optimizer_config['error_function_data'] = { - 'subthreshold_long_square_voltage_variance': var_of_section, - 'sv_for_expsymm': sv_for_expsymm, - 'tau_from_AC': tau_from_AC - } + if optimizer_config["error_function"] == "MLIN": + optimizer_config["error_function_data"] = { + "subthreshold_long_square_voltage_variance": var_of_section, + "sv_for_expsymm": sv_for_expsymm, + "tau_from_AC": tau_from_AC, + } # validation # make sure that the initial ascurrents have the correct size - if len(neuron_config['init_AScurrents']) != len(neuron_config['asc_tau_array']): + if len(neuron_config["init_AScurrents"]) != len(neuron_config["asc_tau_array"]): raise ModelConfigurationException("init_AScurrents have incorrect length.") - - spike_cut_length = neuron_config.get('spike_cut_length', None) - + + spike_cut_length = neuron_config.get("spike_cut_length", None) + if spike_cut_length is None: raise ModelConfigurationException("Spike cut length must be set, but it is not.") - + if spike_cut_length < 0: raise ModelConfigurationException("Spike cut length must be non-negative.") def main(): parser = argparse.ArgumentParser() - parser.add_argument('preprocessor_values_path', help='path to preprocessor values json') - parser.add_argument('method_config_path', help='path to method configuration json') - parser.add_argument('output_path', help='path to store final model configuration') + parser.add_argument("preprocessor_values_path", help="path to preprocessor values json") + parser.add_argument("method_config_path", help="path to method configuration json") + parser.add_argument("output_path", help="path to store final model configuration") args = parser.parse_args() diff --git a/allensdk/internal/model/glif/error_functions.py b/allensdk/internal/model/glif/error_functions.py index 27a23aaa23..117018e319 100644 --- a/allensdk/internal/model/glif/error_functions.py +++ b/allensdk/internal/model/glif/error_functions.py @@ -10,225 +10,241 @@ # TODO: clean up # TODO: license + def MLIN_list_error(param_guess, experiment, input_data): - #TODO: binning is now done in preprocessor so perhaps should take it out of here. -# voltage_distribution=norm(loc=0, scale=np.sqrt(voltage_variance)*10.) - sv=input_data['sv_for_expsymm'] #used in the expsymm function -# tau_4AC=experiment.neuron.R_input*experiment.neuron.C - tau_from_AC=input_data['tau_from_AC'] - spike_length=int(experiment.neuron.spike_cut_length) - noSpike_bin_size_ind=int(tau_from_AC/experiment.neuron.dt) - spike_bin_size_time=.005 #TODO: PLAY WITH THIS VALUE. 1 MS, 2, 4, 8 - spike_bin_size_ind=int(spike_bin_size_time/experiment.neuron.dt) - - logging.info('running parameter guess: %s' % param_guess) - + # TODO: binning is now done in preprocessor so perhaps should take it out of here. + # voltage_distribution=norm(loc=0, scale=np.sqrt(voltage_variance)*10.) + sv = input_data["sv_for_expsymm"] # used in the expsymm function + # tau_4AC=experiment.neuron.R_input*experiment.neuron.C + tau_from_AC = input_data["tau_from_AC"] + spike_length = int(experiment.neuron.spike_cut_length) + noSpike_bin_size_ind = int(tau_from_AC / experiment.neuron.dt) + spike_bin_size_time = 0.005 # TODO: PLAY WITH THIS VALUE. 1 MS, 2, 4, 8 + spike_bin_size_ind = int(spike_bin_size_time / experiment.neuron.dt) + + logging.info("running parameter guess: %s" % param_guess) MLIN_list = [] try: run_data = experiment.run(param_guess) except GlifNeuronException as e: - out=e.data + out = e.data raise e except GlifBadInitializationException as e: - logging.error('voltage STARTS above threshold: setting error to be large.Difference between thresh and voltage is: %f' % e.dv) + logging.error( + "voltage STARTS above threshold: setting error to be large.Difference between thresh and voltage is: %f" + % e.dv + ) raise Exception() except GlifBadResetException as e: - logging.error('THIS REALLY SHOULDNT HAPPEN WITH NEW INITIALIZATION EXCEPTION: voltage is above threshold at reset: setting error to be large. Difference between thresh and voltage is: %f' % e.dv) + logging.error( + "THIS REALLY SHOULDNT HAPPEN WITH NEW INITIALIZATION EXCEPTION: voltage is above threshold at reset: setting error to be large. Difference between thresh and voltage is: %f" + % e.dv + ) raise Exception() - v_model_list=[] - th_model_list=[] - non_spike_bin_ind_edges_list=[] - spike_edges_list_list=[] - spike_bins_list=[] - noSpike_bins_list=[] - spike_prob_list=[] - noSpike_prob_list=[] - - for stim_list_index in range(0,len(experiment.stim_list)): - #TODO: the following line is a hack to take care of the case when there are no spikes in a sweep - if len(experiment.spike_time_steps[stim_list_index])==0: - MLIN=0 - raise Exception('there are no spikes in the sweep') + v_model_list = [] + th_model_list = [] + non_spike_bin_ind_edges_list = [] + spike_edges_list_list = [] + spike_bins_list = [] + noSpike_bins_list = [] + spike_prob_list = [] + noSpike_prob_list = [] + + for stim_list_index in range(0, len(experiment.stim_list)): + # TODO: the following line is a hack to take care of the case when there are no spikes in a sweep + if len(experiment.spike_time_steps[stim_list_index]) == 0: + MLIN = 0 + raise Exception("there are no spikes in the sweep") else: - bio_spike_ind=experiment.spike_time_steps[stim_list_index] - v_model=run_data['voltage'][stim_list_index] - th_model=run_data['threshold'][stim_list_index] - #------------------------------------------------------------------------------------ - #---------------make all the bins---------------------------------------------------- - #------------------------------------------------------------------------------------ - - #--for every spike define a region for making non spiking bins - between_spike_edges_list=[] - between_spike_edges_list.append([0, bio_spike_ind[0]-spike_bin_size_ind]) #edges to first spike - for ii in range(0, len(bio_spike_ind)-1): - between_spike_edges_list.append([bio_spike_ind[ii]+spike_length, bio_spike_ind[ii+1]-spike_bin_size_ind]) - - - #--define no spike bin edges - non_spike_bin_ind_edges_in_ISI=[] + bio_spike_ind = experiment.spike_time_steps[stim_list_index] + v_model = run_data["voltage"][stim_list_index] + th_model = run_data["threshold"][stim_list_index] + # ------------------------------------------------------------------------------------ + # ---------------make all the bins---------------------------------------------------- + # ------------------------------------------------------------------------------------ + + # --for every spike define a region for making non spiking bins + between_spike_edges_list = [] + between_spike_edges_list.append([0, bio_spike_ind[0] - spike_bin_size_ind]) # edges to first spike + for ii in range(0, len(bio_spike_ind) - 1): + between_spike_edges_list.append( + [bio_spike_ind[ii] + spike_length, bio_spike_ind[ii + 1] - spike_bin_size_ind] + ) + + # --define no spike bin edges + non_spike_bin_ind_edges_in_ISI = [] for btw_spike_edges in between_spike_edges_list: - temp=range(btw_spike_edges[0], btw_spike_edges[1], noSpike_bin_size_ind) + temp = range(btw_spike_edges[0], btw_spike_edges[1], noSpike_bin_size_ind) temp.append(btw_spike_edges[1]) non_spike_bin_ind_edges_in_ISI.append(temp) - - #--define spike bin edges - spike_edges_list=[] + + # --define spike bin edges + spike_edges_list = [] for spike in bio_spike_ind: - spike_edges_list.append([spike-spike_bin_size_ind, spike]) - - #--nonspike bin edges need to be arranged correctly because they are just all the edges for a given ISI - non_spike_bin_ind_edges=[] + spike_edges_list.append([spike - spike_bin_size_ind, spike]) + + # --nonspike bin edges need to be arranged correctly because they are just all the edges for a given ISI + non_spike_bin_ind_edges = [] for edges_in_1_spike in non_spike_bin_ind_edges_in_ISI: - for ii in range(0,len(edges_in_1_spike)-1): - non_spike_bin_ind_edges.append([edges_in_1_spike[ii], edges_in_1_spike[ii+1]]) - - #----------------------------------------------------------------------------- - #-------------finding values in bins------------------------------------------ - #----------------------------------------------------------------------------- - - spike_bins={} -# spike_bins['vmax']=[] - spike_bins['th']=[] - spike_bins['v']=[] - spike_bins['v_th_diff']=[] -# spike_bins['ind_of_max_v']=[] - spike_bins['ind_of_max_diff_btw_v_th']=[] - + for ii in range(0, len(edges_in_1_spike) - 1): + non_spike_bin_ind_edges.append([edges_in_1_spike[ii], edges_in_1_spike[ii + 1]]) + + # ----------------------------------------------------------------------------- + # -------------finding values in bins------------------------------------------ + # ----------------------------------------------------------------------------- + + spike_bins = {} + # spike_bins['vmax']=[] + spike_bins["th"] = [] + spike_bins["v"] = [] + spike_bins["v_th_diff"] = [] + # spike_bins['ind_of_max_v']=[] + spike_bins["ind_of_max_diff_btw_v_th"] = [] + for bin_edges in spike_edges_list: - bin_ind=range(bin_edges[0], bin_edges[1]) - #vmax_in_bin=max(v_model[bin_ind]) - diff_vector=th_model[bin_ind]-v_model[bin_ind] - #the_ind=bin_ind[np.where(v_model[bin_ind]==vmax_in_bin)[0]] this is used to use in where v is at a max (as opposed to the difference between v and th) -# print("***********************************************************") -# print('th_model[bin_ind]', th_model[bin_ind]) -# print('v_model[bin_ind]',v_model[bin_ind]) -# print('diff_vector', diff_vector) -# print('np.where(diff_vector==min(diff_vector))[0]', np.where(diff_vector==min(diff_vector))[0]) - the_ind=bin_ind[np.where(diff_vector==min(diff_vector[~np.isnan(diff_vector)]))[0][0]] - th_in_bin=th_model[the_ind] - v_in_bin=v_model[the_ind] -# diffV=th_in_bin-vmax_in_bin - diffV=th_model[the_ind]-v_model[the_ind] - spike_bins['v'].append(v_in_bin) - spike_bins['th'].append(th_in_bin) - spike_bins['v_th_diff'].append(diffV) - spike_bins['ind_of_max_diff_btw_v_th'].append(the_ind) - - - noSpike_bins={} -# noSpike_bins['vmax']=[] - noSpike_bins['th']=[] - noSpike_bins['v']=[] - noSpike_bins['v_th_diff']=[] - noSpike_bins['ind_of_max_diff_btw_v_th']=[] + bin_ind = range(bin_edges[0], bin_edges[1]) + # vmax_in_bin=max(v_model[bin_ind]) + diff_vector = th_model[bin_ind] - v_model[bin_ind] + # the_ind=bin_ind[np.where(v_model[bin_ind]==vmax_in_bin)[0]] this is used to use in where v is at a max (as opposed to the difference between v and th) + # print("***********************************************************") + # print('th_model[bin_ind]', th_model[bin_ind]) + # print('v_model[bin_ind]',v_model[bin_ind]) + # print('diff_vector', diff_vector) + # print('np.where(diff_vector==min(diff_vector))[0]', np.where(diff_vector==min(diff_vector))[0]) + the_ind = bin_ind[np.where(diff_vector == min(diff_vector[~np.isnan(diff_vector)]))[0][0]] + th_in_bin = th_model[the_ind] + v_in_bin = v_model[the_ind] + # diffV=th_in_bin-vmax_in_bin + diffV = th_model[the_ind] - v_model[the_ind] + spike_bins["v"].append(v_in_bin) + spike_bins["th"].append(th_in_bin) + spike_bins["v_th_diff"].append(diffV) + spike_bins["ind_of_max_diff_btw_v_th"].append(the_ind) + + noSpike_bins = {} + # noSpike_bins['vmax']=[] + noSpike_bins["th"] = [] + noSpike_bins["v"] = [] + noSpike_bins["v_th_diff"] = [] + noSpike_bins["ind_of_max_diff_btw_v_th"] = [] for bin_edges in non_spike_bin_ind_edges: - bin_ind=range(bin_edges[0], bin_edges[1]) -# vmax_in_bin=max(v_model[bin_ind]) - diff_vector=th_model[bin_ind]-v_model[bin_ind] -# max_indicies=np.where(v_model[bin_ind]==vmax_in_bin)[0] - min_indicies=np.where(diff_vector==min(diff_vector))[0] -# if len(max_indicies)>1: -# print('there is more than one maximum indicie in a bin at', max_indicies, 'choosing last value for computation') -# print('all voltages in the bin are', v_model[bin_ind]) - the_ind=bin_ind[min_indicies[-1]] #this is here just incase there is more than one value at max voltage in a bin - th_in_bin=th_model[the_ind] - v_in_bin=v_model[the_ind] - diffV=th_model[the_ind]-v_model[the_ind] - noSpike_bins['v'].append(v_in_bin) - noSpike_bins['th'].append(th_in_bin) - noSpike_bins['v_th_diff'].append(diffV) - noSpike_bins['ind_of_max_diff_btw_v_th'].append(the_ind) - - #----------------------------------------------------------------------------- - #-------------calculate MLIN-------------------------------------------------- - #----------------------------------------------------------------------------- - -# this was the version with the normal distribution -# noSpike_prob=np.log(np.spacing(1)+voltage_distribution.cdf(noSpike_bins['v_th_diff'])) -# spike_prob=np.log(1+np.spacing(1)-voltage_distribution.cdf(spike_bins['v_th_diff'])) - #version with the expsymm function - - #---OPTION ONE-------- -# noSpike_prob=np.log(np.spacing(1)+expsymm_cdf(noSpike_bins['v_th_diff'], sv)) -# spike_prob=np.log(1+np.spacing(1)-expsymm_cdf(spike_bins['v_th_diff'], sv)) -# #---OPTION TWO-------- -# N_spike=np.float(len(spike_bins['v_th_diff'])) -# N_noSpike=np.float(len(noSpike_bins['v_th_diff'])) -# noSpike_prob=(N_spike/(N_noSpike+N_spike))*np.log(np.spacing(1)+expsymm_cdf(noSpike_bins['v_th_diff'], sv)) -# spike_prob=(N_noSpike/(N_noSpike+N_spike))*np.log(1+np.spacing(1)-expsymm_cdf(spike_bins['v_th_diff'], sv)) - #---OPTION THREE - noSpike_negDiff=(-np.log(2.)+np.array(noSpike_bins['v_th_diff'])/sv)[np.array(noSpike_bins['v_th_diff'])<=0.0] - noSpike_posDiff=(np.log(1.-0.5*np.exp(-np.array(noSpike_bins['v_th_diff'])/sv)))[np.array(noSpike_bins['v_th_diff'])>0.0] - noSpike_prob=np.append(noSpike_negDiff, noSpike_posDiff) #!!NOTE: this may not line up correctly in outputs of MLIN HACK - - spike_negDiff=(np.log(1.-0.5*np.exp(np.array(spike_bins['v_th_diff'])/sv)))[np.array(spike_bins['v_th_diff'])<=0.0] - spike_posDiff=(-np.log(2.)-np.array(spike_bins['v_th_diff'])/sv)[np.array(spike_bins['v_th_diff'])>0.0] - spike_prob=np.append(spike_negDiff, spike_posDiff) - - MLIN=-(sum(noSpike_prob)+sum(spike_prob)) - logging.info('MLIN: %f', MLIN) - + bin_ind = range(bin_edges[0], bin_edges[1]) + # vmax_in_bin=max(v_model[bin_ind]) + diff_vector = th_model[bin_ind] - v_model[bin_ind] + # max_indicies=np.where(v_model[bin_ind]==vmax_in_bin)[0] + min_indicies = np.where(diff_vector == min(diff_vector))[0] + # if len(max_indicies)>1: + # print('there is more than one maximum indicie in a bin at', max_indicies, 'choosing last value for computation') + # print('all voltages in the bin are', v_model[bin_ind]) + the_ind = bin_ind[ + min_indicies[-1] + ] # this is here just incase there is more than one value at max voltage in a bin + th_in_bin = th_model[the_ind] + v_in_bin = v_model[the_ind] + diffV = th_model[the_ind] - v_model[the_ind] + noSpike_bins["v"].append(v_in_bin) + noSpike_bins["th"].append(th_in_bin) + noSpike_bins["v_th_diff"].append(diffV) + noSpike_bins["ind_of_max_diff_btw_v_th"].append(the_ind) + + # ----------------------------------------------------------------------------- + # -------------calculate MLIN-------------------------------------------------- + # ----------------------------------------------------------------------------- + + # this was the version with the normal distribution + # noSpike_prob=np.log(np.spacing(1)+voltage_distribution.cdf(noSpike_bins['v_th_diff'])) + # spike_prob=np.log(1+np.spacing(1)-voltage_distribution.cdf(spike_bins['v_th_diff'])) + # version with the expsymm function + + # ---OPTION ONE-------- + # noSpike_prob=np.log(np.spacing(1)+expsymm_cdf(noSpike_bins['v_th_diff'], sv)) + # spike_prob=np.log(1+np.spacing(1)-expsymm_cdf(spike_bins['v_th_diff'], sv)) + # #---OPTION TWO-------- + # N_spike=np.float(len(spike_bins['v_th_diff'])) + # N_noSpike=np.float(len(noSpike_bins['v_th_diff'])) + # noSpike_prob=(N_spike/(N_noSpike+N_spike))*np.log(np.spacing(1)+expsymm_cdf(noSpike_bins['v_th_diff'], sv)) + # spike_prob=(N_noSpike/(N_noSpike+N_spike))*np.log(1+np.spacing(1)-expsymm_cdf(spike_bins['v_th_diff'], sv)) + # ---OPTION THREE + noSpike_negDiff = (-np.log(2.0) + np.array(noSpike_bins["v_th_diff"]) / sv)[ + np.array(noSpike_bins["v_th_diff"]) <= 0.0 + ] + noSpike_posDiff = (np.log(1.0 - 0.5 * np.exp(-np.array(noSpike_bins["v_th_diff"]) / sv)))[ + np.array(noSpike_bins["v_th_diff"]) > 0.0 + ] + noSpike_prob = np.append( + noSpike_negDiff, noSpike_posDiff + ) #!!NOTE: this may not line up correctly in outputs of MLIN HACK + + spike_negDiff = (np.log(1.0 - 0.5 * np.exp(np.array(spike_bins["v_th_diff"]) / sv)))[ + np.array(spike_bins["v_th_diff"]) <= 0.0 + ] + spike_posDiff = (-np.log(2.0) - np.array(spike_bins["v_th_diff"]) / sv)[ + np.array(spike_bins["v_th_diff"]) > 0.0 + ] + spike_prob = np.append(spike_negDiff, spike_posDiff) + + MLIN = -(sum(noSpike_prob) + sum(spike_prob)) + logging.info("MLIN: %f", MLIN) + MLIN_list.append([MLIN]) v_model_list.append(v_model) th_model_list.append(th_model) - non_spike_bin_ind_edges_list.append(non_spike_bin_ind_edges) + non_spike_bin_ind_edges_list.append(non_spike_bin_ind_edges) spike_edges_list_list.append(spike_edges_list) spike_bins_list.append(spike_bins) noSpike_bins_list.append(noSpike_bins) spike_prob_list.append(spike_prob) noSpike_prob_list.append(noSpike_prob) - concatenateMLINList=np.concatenate(MLIN_list) + concatenateMLINList = np.concatenate(MLIN_list) experiment.spike_errors.append(concatenateMLINList) - -# print('param Guess', param_guess, 'TRD', np.mean(concatenateTRDList)) - out =np.mean(concatenateMLINList) - logging.info('MLIN: %f', np.mean(concatenateMLINList)) - -#------------------------------------------------------------------- -#--------------------------plotting------------------------------------ -#--------------------------------------------------------------------- -# plt.subplot(2,1,1) -# plt.title('Model', fontsize=16) -# plt.plot(time, v_model, 'b-', label='voltage') -# plt.plot(time, th_model, 'b--', label='threshold') -# plt.plot(np.concatenate(non_spike_bin_ind_edges)*experiment.neuron.dt, v_model[np.concatenate(non_spike_bin_ind_edges)], 'k|', ms=16) -# plt.plot(np.concatenate(spike_edges_list)*experiment.neuron.dt, v_model[np.concatenate(spike_edges_list)], 'r|', ms=16) -# plt.plot(np.array(spike_bins['ind_of_max_diff_btw_v_th'])*experiment.neuron.dt, v_model[spike_bins['ind_of_max_diff_btw_v_th']], 'r.', ms=6) -# plt.plot(np.array(spike_bins['ind_of_max_diff_btw_v_th'])*experiment.neuron.dt, th_model[spike_bins['ind_of_max_diff_btw_v_th']], 'r.', ms=6) -# plt.plot(np.array(noSpike_bins['ind_of_max_diff_btw_v_th'])*experiment.neuron.dt, v_model[noSpike_bins['ind_of_max_diff_btw_v_th']], 'k.', ms=6) -# plt.plot(np.array(noSpike_bins['ind_of_max_diff_btw_v_th'])*experiment.neuron.dt, th_model[noSpike_bins['ind_of_max_diff_btw_v_th']], '.', ms=6) -# plt.title(' MLIN='+ str(MLIN)+' : sv='+str(sv)+' : ac_tau='+str(tau_from_AC)+' : spike bin size='+str(spike_bin_size_time)+'!!!!!! !!!!!', fontsize=20) -# plt.xlim([0, time[-1]]) -# -# plt.subplot(2,1,2) -# plt.plot(np.array(spike_bins['ind_of_max_diff_btw_v_th'])*experiment.neuron.dt, spike_prob, 'r.', ms=16, label='spike probablility') -# plt.plot(np.array(noSpike_bins['ind_of_max_diff_btw_v_th'])*experiment.neuron.dt, noSpike_prob, 'b.', ms=16, label='no spike probablility') -# plt.legend() -# -# print("coming out of function", out) -# -# plt.show() - - - #converted to list 4_11_15 - experiment.MLIN_HACK={ - 'v_model': v_model_list, - 'th_model': th_model_list, - 'non_spike_bin_ind_edges': non_spike_bin_ind_edges_list, - 'spike_edges_list': spike_edges_list_list, #TODO Corinne: why was this originally a list but nothing else a list - 'spike_bins': spike_bins_list, - 'noSpike_bins': noSpike_bins_list, - 'tau_from_AC': tau_from_AC, - 'spike_prob': spike_prob_list, - 'noSpike_prob': noSpike_prob_list, - 'spike_bin_size_time' : spike_bin_size_time, - 'sv': sv - + + # print('param Guess', param_guess, 'TRD', np.mean(concatenateTRDList)) + out = np.mean(concatenateMLINList) + logging.info("MLIN: %f", np.mean(concatenateMLINList)) + + # ------------------------------------------------------------------- + # --------------------------plotting------------------------------------ + # --------------------------------------------------------------------- + # plt.subplot(2,1,1) + # plt.title('Model', fontsize=16) + # plt.plot(time, v_model, 'b-', label='voltage') + # plt.plot(time, th_model, 'b--', label='threshold') + # plt.plot(np.concatenate(non_spike_bin_ind_edges)*experiment.neuron.dt, v_model[np.concatenate(non_spike_bin_ind_edges)], 'k|', ms=16) + # plt.plot(np.concatenate(spike_edges_list)*experiment.neuron.dt, v_model[np.concatenate(spike_edges_list)], 'r|', ms=16) + # plt.plot(np.array(spike_bins['ind_of_max_diff_btw_v_th'])*experiment.neuron.dt, v_model[spike_bins['ind_of_max_diff_btw_v_th']], 'r.', ms=6) + # plt.plot(np.array(spike_bins['ind_of_max_diff_btw_v_th'])*experiment.neuron.dt, th_model[spike_bins['ind_of_max_diff_btw_v_th']], 'r.', ms=6) + # plt.plot(np.array(noSpike_bins['ind_of_max_diff_btw_v_th'])*experiment.neuron.dt, v_model[noSpike_bins['ind_of_max_diff_btw_v_th']], 'k.', ms=6) + # plt.plot(np.array(noSpike_bins['ind_of_max_diff_btw_v_th'])*experiment.neuron.dt, th_model[noSpike_bins['ind_of_max_diff_btw_v_th']], '.', ms=6) + # plt.title(' MLIN='+ str(MLIN)+' : sv='+str(sv)+' : ac_tau='+str(tau_from_AC)+' : spike bin size='+str(spike_bin_size_time)+'!!!!!! !!!!!', fontsize=20) + # plt.xlim([0, time[-1]]) + # + # plt.subplot(2,1,2) + # plt.plot(np.array(spike_bins['ind_of_max_diff_btw_v_th'])*experiment.neuron.dt, spike_prob, 'r.', ms=16, label='spike probablility') + # plt.plot(np.array(noSpike_bins['ind_of_max_diff_btw_v_th'])*experiment.neuron.dt, noSpike_prob, 'b.', ms=16, label='no spike probablility') + # plt.legend() + # + # print("coming out of function", out) + # + # plt.show() + + # converted to list 4_11_15 + experiment.MLIN_HACK = { + "v_model": v_model_list, + "th_model": th_model_list, + "non_spike_bin_ind_edges": non_spike_bin_ind_edges_list, + "spike_edges_list": spike_edges_list_list, # TODO Corinne: why was this originally a list but nothing else a list + "spike_bins": spike_bins_list, + "noSpike_bins": noSpike_bins_list, + "tau_from_AC": tau_from_AC, + "spike_prob": spike_prob_list, + "noSpike_prob": noSpike_prob_list, + "spike_bin_size_time": spike_bin_size_time, + "sv": sv, } -# + # return out diff --git a/allensdk/internal/model/glif/find_spikes.py b/allensdk/internal/model/glif/find_spikes.py index 8cebfebc34..f00d6e5669 100644 --- a/allensdk/internal/model/glif/find_spikes.py +++ b/allensdk/internal/model/glif/find_spikes.py @@ -1,12 +1,13 @@ import numpy as np import allensdk.ephys.ephys_extractor as efex -ALIGN_CUT_WINDOW = np.array([ 0.002, 0.015 ]) +ALIGN_CUT_WINDOW = np.array([0.002, 0.015]) + def find_spikes_list_old(voltage_list, dt): out_idx = [] out_v = [] - + for v in voltage_list: idx, v = find_spikes_old(v, dt) out_idx.append(idx) @@ -14,69 +15,70 @@ def find_spikes_list_old(voltage_list, dt): return out_idx, out_v + def find_spikes_list(voltage_list, dt): - v_set = [ v * 1e3 for v in voltage_list ] - t_set = [ np.arange(0, len(v)) * dt for v in voltage_list ] - i_set = [ np.zeros(len(v)) for v in voltage_list ] - + v_set = [v * 1e3 for v in voltage_list] + t_set = [np.arange(0, len(v)) * dt for v in voltage_list] + i_set = [np.zeros(len(v)) for v in voltage_list] + ext = efex.EphysSweepSetFeatureExtractor(t_set, v_set, i_set, filter=None) ext.process_spikes() - sweep_spikes = [ s.spikes() for s in ext.sweeps() ] - - out_idx = [ np.array([ int(s['threshold_index']) for s in spikes ]) for spikes in sweep_spikes ] - out_v = [ np.array([ s['threshold_v'] for s in spikes ]) for spikes in sweep_spikes ] - + sweep_spikes = [s.spikes() for s in ext.sweeps()] + + out_idx = [np.array([int(s["threshold_index"]) for s in spikes]) for spikes in sweep_spikes] + out_v = [np.array([s["threshold_v"] for s in spikes]) for spikes in sweep_spikes] + return out_idx, out_v + SHORT_SQUARE_MAX_THRESH_FRAC = 0.1 + def find_spikes_ssq_list(voltage_list, dt, dv_cutoff, thresh_frac): - v_set = [ v * 1e3 for v in voltage_list ] - t_set = [ np.arange(0, len(v)) * dt for v in voltage_list ] - i_set = [ np.zeros(len(v)) for v in voltage_list ] + v_set = [v * 1e3 for v in voltage_list] + t_set = [np.arange(0, len(v)) * dt for v in voltage_list] + i_set = [np.zeros(len(v)) for v in voltage_list] thresh_frac = max(SHORT_SQUARE_MAX_THRESH_FRAC, thresh_frac) - - ext = efex.EphysSweepSetFeatureExtractor(t_set, v_set, i_set, - dv_cutoff=dv_cutoff, - thresh_frac=thresh_frac, - filter=None) + + ext = efex.EphysSweepSetFeatureExtractor( + t_set, v_set, i_set, dv_cutoff=dv_cutoff, thresh_frac=thresh_frac, filter=None + ) ext.process_spikes() - sweep_spikes = [ e.spikes() for e in ext.sweeps() ] - - out_idx = [ np.array([ int(s['threshold_index']) for s in spikes ]) for spikes in sweep_spikes ] - out_v = [ np.array([ s['threshold_v'] for s in spikes ]) for spikes in sweep_spikes ] - + sweep_spikes = [e.spikes() for e in ext.sweeps()] + + out_idx = [np.array([int(s["threshold_index"]) for s in spikes]) for spikes in sweep_spikes] + out_v = [np.array([s["threshold_v"] for s in spikes]) for spikes in sweep_spikes] + return out_idx, out_v + def find_spikes_old(v, dt): - v = v * 1e3 # convert V => mV - t = np.arange(0, len(v)) * dt + v = v * 1e3 # convert V => mV + t = np.arange(0, len(v)) * dt i = np.zeros(t.shape) - - - + fx = efex.EphysSweepFeatureExtractor(t=t, v=v, i=i) fx.process_spikes() feature_data = fx.spikes() - - #fx = EphysFeatureExtractor() - #fx.process_instance("", v, i ,t, 0, t[-1], "") - #feature_data = fx.feature_list[0].mean - ids = np.array([ s["threshold_idx"] for s in feature_data ]) - vs = np.array([ s["threshold_v"] for s in feature_data ]) + # fx = EphysFeatureExtractor() + # fx.process_instance("", v, i ,t, 0, t[-1], "") + # feature_data = fx.feature_list[0].mean - vs /= 1e3 # mV => V + ids = np.array([s["threshold_idx"] for s in feature_data]) + vs = np.array([s["threshold_v"] for s in feature_data]) + + vs /= 1e3 # mV => V return ids, vs -def align_and_cut_spikes(voltage_list, current_list, dt, spike_window = None): - ''' This function aligns the spikes to some criteria and returns a current and voltage trace of - of the spike over a time window. Also returns zero crossing,and threshold +def align_and_cut_spikes(voltage_list, current_list, dt, spike_window=None): + """This function aligns the spikes to some criteria and returns a current and voltage trace of + of the spike over a time window. Also returns zero crossing,and threshold in reference to the aligned spikes. - ''' + """ if spike_window is None: spike_window = ALIGN_CUT_WINDOW @@ -87,33 +89,32 @@ def align_and_cut_spikes(voltage_list, current_list, dt, spike_window = None): aligned_spike_ind = np.array([]) spike_sweeps = [] spikes_per_trace = np.array([]) - + spike_ind_list, _ = find_spikes_list(voltage_list, dt) for jj, voltage_and_current_and_spike in enumerate(zip(voltage_list, current_list, spike_ind_list)): voltage, current, whole_trace_spike_ind = voltage_and_current_and_spike - + spikes_per_trace = np.append(spikes_per_trace, len(whole_trace_spike_ind)) - - alignment_ind = whole_trace_spike_ind + + alignment_ind = whole_trace_spike_ind aligned_spike_ind = np.append(aligned_spike_ind, np.ones(len(whole_trace_spike_ind)) * index_before_spike) # print('alignment_ind', alignment_ind) spike_delimiters = [(ind - index_before_spike, ind + index_after_spike) for ind in alignment_ind] - for d in spike_delimiters: + for d in spike_delimiters: # this 'if' statement makes sure we don't cause a ValueError if min(d) > 0 and max(d) < len(voltage) - 1: - spike_trace = voltage[d[0]:d[1]] - current_trace = current[d[0]:d[1]] + spike_trace = voltage[d[0] : d[1]] + current_trace = current[d[0] : d[1]] spike_shapes.append(spike_trace) - current_shapes.append(current_trace) + current_shapes.append(current_trace) spike_sweeps.append(jj) - # note: that depending on how things were aligned, all of one of the values will be the same. print("spikes_per_trace", spikes_per_trace) - temp = np.append(0, np.cumsum(spikes_per_trace)) - print('temp', temp) + temp = np.append(0, np.cumsum(spikes_per_trace)) + print("temp", temp) wave_index_of_first_spikes = [int(ii) for ii in list(temp[range(0, len(temp) - 1)])] print("in cut spikes: wave_index_of_first_spikes ", wave_index_of_first_spikes) diff --git a/allensdk/internal/model/glif/find_sweeps.py b/allensdk/internal/model/glif/find_sweeps.py index 014d4a201c..cac860c634 100644 --- a/allensdk/internal/model/glif/find_sweeps.py +++ b/allensdk/internal/model/glif/find_sweeps.py @@ -7,63 +7,65 @@ import allensdk.core.json_utilities as ju -SHORT_SQUARE = 'Short Square' -SHORT_SQUARE_60 = 'Short Square - Hold -60mv' -SHORT_SQUARE_80 = 'Short Square - Hold -80mv' -LONG_SQUARE = 'Long Square' -RAMP = 'Ramp' -NOISE1 = 'Noise 1' -NOISE2 = 'Noise 2' -SHORT_SQUARE_TRIPLE = 'Short Square - Triple' -RAMP_TO_RHEO = 'Ramp to Rheobase' - - -class MissingSweepException( Exception ): +SHORT_SQUARE = "Short Square" +SHORT_SQUARE_60 = "Short Square - Hold -60mv" +SHORT_SQUARE_80 = "Short Square - Hold -80mv" +LONG_SQUARE = "Long Square" +RAMP = "Ramp" +NOISE1 = "Noise 1" +NOISE2 = "Noise 2" +SHORT_SQUARE_TRIPLE = "Short Square - Triple" +RAMP_TO_RHEO = "Ramp to Rheobase" + + +class MissingSweepException(Exception): pass + def get_sweep_numbers(sweep_list): - return [ s['sweep_number'] for s in sweep_list] + return [s["sweep_number"] for s in sweep_list] def get_sweeps_by_name(sweeps, sweep_type): if isinstance(sweeps, dict): - return [ s for sn,s in sweeps.items() if s[u'ephys_stimulus'][u'ephys_stimulus_type'][u'name'] == sweep_type ] + return [s for sn, s in sweeps.items() if s["ephys_stimulus"]["ephys_stimulus_type"]["name"] == sweep_type] else: - return [ s for s in sweeps if s[u'ephys_stimulus'][u'ephys_stimulus_type'][u'name'] == sweep_type ] + return [s for s in sweeps if s["ephys_stimulus"]["ephys_stimulus_type"]["name"] == sweep_type] def find_ranked_sweep(sweep_list, key, reverse=False): if sweep_list: sorted_sweep_list = sorted(sweep_list, key=lambda x: x[key], reverse=reverse) - - out_sweeps = [ sorted_sweep_list[0] ] - - for i in range(1,len(sweep_list)): + + out_sweeps = [sorted_sweep_list[0]] + + for i in range(1, len(sweep_list)): if sorted_sweep_list[i][key] == out_sweeps[0][key]: out_sweeps.append(sorted_sweep_list[i]) else: break - + return get_sweep_numbers(out_sweeps) else: return [] def organize_sweeps_by_name(sweeps, name): - sweep_list = sorted(get_sweeps_by_name(sweeps, name), key=lambda x: x['sweep_number']) + sweep_list = sorted(get_sweeps_by_name(sweeps, name), key=lambda x: x["sweep_number"]) - subthreshold_list = [ s for s in sweep_list if s.get('num_spikes',None) in [0, None] ] - suprathreshold_list = [ s for s in sweep_list if s.get('num_spikes',None) > 0 ] + subthreshold_list = [s for s in sweep_list if s.get("num_spikes", None) in [0, None]] + suprathreshold_list = [s for s in sweep_list if s.get("num_spikes", None) > 0] return { - 'all': get_sweep_numbers(sweep_list), - 'subthreshold': get_sweep_numbers(subthreshold_list), - 'suprathreshold': get_sweep_numbers(suprathreshold_list), - 'maximum_subthreshold': find_ranked_sweep(subthreshold_list, 'stimulus_amplitude', reverse=True), - 'minimum_suprathreshold': find_ranked_sweep(suprathreshold_list, 'stimulus_amplitude') + "all": get_sweep_numbers(sweep_list), + "subthreshold": get_sweep_numbers(subthreshold_list), + "suprathreshold": get_sweep_numbers(suprathreshold_list), + "maximum_subthreshold": find_ranked_sweep(subthreshold_list, "stimulus_amplitude", reverse=True), + "minimum_suprathreshold": find_ranked_sweep(suprathreshold_list, "stimulus_amplitude"), #'maximum_subthreshold': find_ranked_sweep(subthreshold_list, 'stimulus_absolute_amplitude', reverse=True), #'minimum_suprathreshold': find_ranked_sweep(suprathreshold_list, 'stimulus_absolute_amplitude') - } + } + def find_long_square_sweeps(sweeps): out = organize_sweeps_by_name(sweeps, LONG_SQUARE) @@ -76,77 +78,72 @@ def find_ramp_to_rheo_sweeps(sweeps): def find_short_square_sweeps(sweeps): - ''' + """ Find 1) all of the subthreshold short square sweeps 2) all of the superthreshold short square sweeps 3) the subthresholds short square sweep with maximum stimulus amplitude - ''' + """ out = organize_sweeps_by_name(sweeps, SHORT_SQUARE) out60 = organize_sweeps_by_name(sweeps, SHORT_SQUARE_60) out80 = organize_sweeps_by_name(sweeps, SHORT_SQUARE_80) out_triple = organize_sweeps_by_name(sweeps, SHORT_SQUARE_TRIPLE) - out['all_60'] = out60['all'] - out['all_80'] = out80['all'] - out['triple'] = out_triple['all'] - - if len(out['maximum_subthreshold']) == 0: + out["all_60"] = out60["all"] + out["all_80"] = out80["all"] + out["triple"] = out_triple["all"] + + if len(out["maximum_subthreshold"]) == 0: raise MissingSweepException("No maximum subthreshold short square") - if len(out['minimum_suprathreshold']) == 0: + if len(out["minimum_suprathreshold"]) == 0: raise MissingSweepException("No minimum suprathreshold short square") return out def find_ramp_sweeps(sweeps): - ''' + """ Find 1) all ramp sweeps 2) all subthreshold ramps 3) all superthreshold ramps - ''' + """ out = organize_sweeps_by_name(sweeps, RAMP) return out - + def find_noise_sweeps(sweeps): - ''' + """ Find 1) the noise1 sweeps 2) the noise2 sweeps 4) all noise sweeps - ''' + """ noise1 = organize_sweeps_by_name(sweeps, NOISE1) noise2 = organize_sweeps_by_name(sweeps, NOISE2) - - all_noise_sweeps = sorted(noise1['all'] + noise2['all']) - out = { - 'all': all_noise_sweeps, - 'noise1': noise1['all'], - 'noise2': noise2['all'] - } + all_noise_sweeps = sorted(noise1["all"] + noise2["all"]) + + out = {"all": all_noise_sweeps, "noise1": noise1["all"], "noise2": noise2["all"]} + + num_noise1_sweeps = len(out["noise1"]) + num_noise2_sweeps = len(out["noise2"]) - num_noise1_sweeps = len(out['noise1']) - num_noise2_sweeps = len(out['noise2']) - required_noise1_sweeps = 2 required_noise2_sweeps = 2 if num_noise1_sweeps < required_noise1_sweeps: raise MissingSweepException("not enough noise1 sweeps (%d/%d)" % (num_noise1_sweeps, required_noise1_sweeps)) - + if num_noise2_sweeps < required_noise2_sweeps: raise MissingSweepException("not enough noise2 sweeps (%d/%d)" % (num_noise2_sweeps, required_noise2_sweeps)) - + return out def find_sweeps(sweep_list): - - sweep_index = { s['sweep_number']: s for s in sweep_list } + sweep_index = {s["sweep_number"]: s for s in sweep_list} data = {} ssq_data = find_short_square_sweeps(sweep_index) @@ -165,13 +162,13 @@ def find_sweeps(sweep_list): data.update(noise_data) return data, sweep_index - + def parse_arguments(): - parser = argparse.ArgumentParser(description='find relevant sweeps from a sweep catalog') + parser = argparse.ArgumentParser(description="find relevant sweeps from a sweep catalog") - parser.add_argument('sweep_list_file', help='json file containing a list of sweeps for a cell') - parser.add_argument('output_file', help='output json data config file') + parser.add_argument("sweep_list_file", help="json file containing a list of sweeps for a cell") + parser.add_argument("output_file", help="output json data config file") args = parser.parse_args() @@ -200,5 +197,6 @@ def main(): logging.error(err) sys.exit(1) + if __name__ == "__main__": main() diff --git a/allensdk/internal/model/glif/glif_experiment.py b/allensdk/internal/model/glif/glif_experiment.py index c52fcb0953..0056d2830d 100644 --- a/allensdk/internal/model/glif/glif_experiment.py +++ b/allensdk/internal/model/glif/glif_experiment.py @@ -4,12 +4,20 @@ # TODO: license # TODO: document -class GlifExperiment( object ): - def __init__(self, neuron, dt, stim_list, resp_list, - spike_time_steps, grid_spike_times, grid_spike_voltages, - param_fit_names, - **kwargs): +class GlifExperiment(object): + def __init__( + self, + neuron, + dt, + stim_list, + resp_list, + spike_time_steps, + grid_spike_times, + grid_spike_voltages, + param_fit_names, + **kwargs, + ): self.neuron = neuron self.dt = dt self.stim_list = stim_list @@ -21,101 +29,95 @@ def __init__(self, neuron, dt, stim_list, resp_list, self.spike_errors = [] - def run(self, param_guess): - '''This code will run the loaded neuron model in reference to the target neuron spikes. + """This code will run the loaded neuron model in reference to the target neuron spikes. inputs: self: is the instance of the neuron model and parameters alone with the values of the target spikes. NOTE the values in each array of the self.gridSpikeIndexTarge_list and the self.interpolated_spike_times are in reference to the time start of of the stim in each induvidual array (not the universal time) param_guess: array of scalars of the values that will be inserted into the mapping function below. returns: - voltage_list: list of array of voltage values. NOTE: IF THE MODEL NEURON SPIKES BEFORE THE TARGET THE VOLTAGE WILL - NOT BE CALCULATED THEREFORE THE RESULTING VECTOR WILL NOT BE AS LONG AS THE TARGET AND ALSO WILL NOT + voltage_list: list of array of voltage values. NOTE: IF THE MODEL NEURON SPIKES BEFORE THE TARGET THE VOLTAGE WILL + NOT BE CALCULATED THEREFORE THE RESULTING VECTOR WILL NOT BE AS LONG AS THE TARGET AND ALSO WILL NOT MAKE SENSE WITH THE STIMULUS UNLESS YOU CUT IT AND OUTPUT IT TOO. grid_spike_times_list: - interpolated_spike_time_list: an array of the actual times of the spikes. NOTE: THESE TIMES ARE CALCULATED BY ADDING THE + interpolated_spike_time_list: an array of the actual times of the spikes. NOTE: THESE TIMES ARE CALCULATED BY ADDING THE TIME OF THE INDIVIDUAL SPIKE TO THE TIME OF THE LAST SPIKE. - gridISIFromLastTargSpike_list: list of arrays of spike times of the model in reference to the last target (biological) + gridISIFromLastTargSpike_list: list of arrays of spike times of the model in reference to the last target (biological) spike (not in reference to sweep start) - interpolatedISIFromLastTargSpike_list: list of arrays of spike times of the model in reference to the last target (biological) + interpolatedISIFromLastTargSpike_list: list of arrays of spike times of the model in reference to the last target (biological) spike (not in reference to sweep start) - voltageOfModelAtGridBioSpike_list: list of arrays of scalars that contain the voltage of the model neuron when the target or bio neuron spikes. - theshOfModelAtGridBioSpike_list: list of arrays of scalars that contain the threshold of the model neuron when the target or bio neuron spikes.''' + voltageOfModelAtGridBioSpike_list: list of arrays of scalars that contain the voltage of the model neuron when the target or bio neuron spikes. + theshOfModelAtGridBioSpike_list: list of arrays of scalars that contain the threshold of the model neuron when the target or bio neuron spikes.""" - self.set_neuron_parameters(param_guess) + self.set_neuron_parameters(param_guess) self.spike_errors = [] run_data = [] - - for stim_list_index in range(len(self.stim_list)): - run_data.append(self.neuron.run_with_biological_spikes(self.stim_list[stim_list_index], - self.resp_list[stim_list_index], - self.spike_time_steps[stim_list_index])) - - return { - 'voltage': [ rd['voltage'] for rd in run_data ], - 'threshold': [ rd['threshold'] for rd in run_data ], - 'AScurrent_matrix': [ rd['AScurrent_matrix'] for rd in run_data ], - - 'grid_ISI': [ rd['grid_ISI'] for rd in run_data ], - 'interpolated_ISI': [ rd['interpolated_ISI'] for rd in run_data ], - 'grid_model_spike_times': [ rd['grid_model_spike_times'] for rd in run_data ], - 'interpolated_model_spike_times': [ rd['interpolated_model_spike_times'] for rd in run_data ], - - 'grid_model_spike_voltages': [ rd['grid_model_spike_voltages'] for rd in run_data ], - 'interpolated_model_spike_voltages': [ rd['interpolated_model_spike_voltages'] for rd in run_data ], + for stim_list_index in range(len(self.stim_list)): + run_data.append( + self.neuron.run_with_biological_spikes( + self.stim_list[stim_list_index], + self.resp_list[stim_list_index], + self.spike_time_steps[stim_list_index], + ) + ) - 'grid_bio_spike_model_voltage': [ rd['grid_bio_spike_model_voltage'] for rd in run_data ], - 'grid_bio_spike_model_threshold': [ rd['grid_bio_spike_model_threshold'] for rd in run_data ] + return { + "voltage": [rd["voltage"] for rd in run_data], + "threshold": [rd["threshold"] for rd in run_data], + "AScurrent_matrix": [rd["AScurrent_matrix"] for rd in run_data], + "grid_ISI": [rd["grid_ISI"] for rd in run_data], + "interpolated_ISI": [rd["interpolated_ISI"] for rd in run_data], + "grid_model_spike_times": [rd["grid_model_spike_times"] for rd in run_data], + "interpolated_model_spike_times": [rd["interpolated_model_spike_times"] for rd in run_data], + "grid_model_spike_voltages": [rd["grid_model_spike_voltages"] for rd in run_data], + "interpolated_model_spike_voltages": [rd["interpolated_model_spike_voltages"] for rd in run_data], + "grid_bio_spike_model_voltage": [rd["grid_bio_spike_model_voltage"] for rd in run_data], + "grid_bio_spike_model_threshold": [rd["grid_bio_spike_model_threshold"] for rd in run_data], } - def run_base_model(self, param_guess): - '''This code will run the loaded neuron model. + """This code will run the loaded neuron model. inputs: self: is the instance of the neuron model and parameters alone with the values of the target spikes. NOTE the values in each array of the self.gridSpikeIndexTarge_list and the self.interpolated_spike_times are in reference to the time start of of the stim in each induvidual array (not the universal time) param_guess: array of scalars of the values that will be inserted into the mapping function below. returns: - voltage_list: list of array of voltage values. NOTE: IF THE MODEL NEURON SPIKES BEFORE THE TARGET THE VOLTAGE WILL - NOT BE CALCULATED THEREFORE THE RESULTING VECTOR WILL NOT BE AS LONG AS THE TARGET AND ALSO WILL NOT + voltage_list: list of array of voltage values. NOTE: IF THE MODEL NEURON SPIKES BEFORE THE TARGET THE VOLTAGE WILL + NOT BE CALCULATED THEREFORE THE RESULTING VECTOR WILL NOT BE AS LONG AS THE TARGET AND ALSO WILL NOT MAKE SENSE WITH THE STIMULUS UNLESS YOU CUT IT AND OUTPUT IT TOO. gridTime_list: - interpolatedTime_list: an array of the actual times of the spikes. NOTE: THESE TIMES ARE CALCULATED BY ADDING THE + interpolatedTime_list: an array of the actual times of the spikes. NOTE: THESE TIMES ARE CALCULATED BY ADDING THE TIME OF THE INDIVIDUAL SPIKE TO THE TIME OF THE LAST SPIKE. - grid_ISI_list: list of arrays of spike times of the model in reference to the last target (biological) + grid_ISI_list: list of arrays of spike times of the model in reference to the last target (biological) spike (not in reference to sweep start) - interpolated_ISI_list: list of arrays of spike times of the model in reference to the last target (biological) + interpolated_ISI_list: list of arrays of spike times of the model in reference to the last target (biological) spike (not in reference to sweep start) - grid_spike_voltage_list: list of arrays of scalars that contain the voltage of the model neuron when the target or bio neuron spikes. - grid_spike_threshold_list: list of arrays of scalars that contain the threshold of the model neuron when the target or bio neuron spikes.''' + grid_spike_voltage_list: list of arrays of scalars that contain the voltage of the model neuron when the target or bio neuron spikes. + grid_spike_threshold_list: list of arrays of scalars that contain the threshold of the model neuron when the target or bio neuron spikes.""" - stim_list = self.stim_list - - self.set_neuron_parameters(param_guess) + + self.set_neuron_parameters(param_guess) self.spike_errors = [] - + run_data = [] for stim_list_index in range(len(stim_list)): run_data.append(self.neuron.run(stim_list[stim_list_index])) - return { - 'voltage': [ rd['voltage'] for rd in run_data ], - 'threshold': [ rd['threshold'] for rd in run_data ], - 'AScurrents': [ rd['AScurrents'] for rd in run_data ], - - 'spike_time_steps': [ rd['spike_time_steps'] for rd in run_data ], - 'grid_spike_times': [ rd['grid_spike_times'] for rd in run_data ], - 'interpolated_spike_times': [ rd['interpolated_spike_times'] for rd in run_data ], - - 'interpolated_spike_voltage': [ rd['interpolated_spike_voltage'] for rd in run_data ], - 'interpolated_spike_threshold': [ rd['interpolated_spike_threshold'] for rd in run_data ] + "voltage": [rd["voltage"] for rd in run_data], + "threshold": [rd["threshold"] for rd in run_data], + "AScurrents": [rd["AScurrents"] for rd in run_data], + "spike_time_steps": [rd["spike_time_steps"] for rd in run_data], + "grid_spike_times": [rd["grid_spike_times"] for rd in run_data], + "interpolated_spike_times": [rd["interpolated_spike_times"] for rd in run_data], + "interpolated_spike_voltage": [rd["interpolated_spike_voltage"] for rd in run_data], + "interpolated_spike_threshold": [rd["interpolated_spike_threshold"] for rd in run_data], } def neuron_parameter_count(self): @@ -136,10 +138,10 @@ def neuron_parameter_count(self): count += 1 return count - def set_neuron_parameters(self, param_guess): - '''Maps the parameter guesses to the coefficients of the model. + def set_neuron_parameters(self, param_guess): + """Maps the parameter guesses to the coefficients of the model. input: - param_guess is vector of values. It is assumed that the length will be ''' + param_guess is vector of values. It is assumed that the length will be""" index = 0 for fit_name in self.param_fit_names: @@ -153,9 +155,8 @@ def set_neuron_parameters(self, param_guess): try: # this will throw a type error if 'coeff' is a scalar coeff_size = len(coeff) - self.neuron.coeffs[fit_name] = param_guess[index:index+coeff_size] + self.neuron.coeffs[fit_name] = param_guess[index : index + coeff_size] index += coeff_size except TypeError: self.neuron.coeffs[fit_name] = param_guess[index] index += 1 - diff --git a/allensdk/internal/model/glif/glif_optimizer.py b/allensdk/internal/model/glif/glif_optimizer.py index 981a619569..3ad66139e6 100644 --- a/allensdk/internal/model/glif/glif_optimizer.py +++ b/allensdk/internal/model/glif/glif_optimizer.py @@ -15,18 +15,26 @@ # TODO: license # TODO: document -class GlifOptimizer(object): - def __init__(self, experiment, dt, - outer_iterations, inner_iterations, - sigma_outer, sigma_inner, - param_fit_names, stim, - xtol, ftol, - internal_iterations, - bessel, - error_function = None, - error_function_data = None, - init_params = None): +class GlifOptimizer(object): + def __init__( + self, + experiment, + dt, + outer_iterations, + inner_iterations, + sigma_outer, + sigma_inner, + param_fit_names, + stim, + xtol, + ftol, + internal_iterations, + bessel, + error_function=None, + error_function_data=None, + init_params=None, + ): self.start_time = None self.rng = np.random.RandomState() @@ -51,12 +59,12 @@ def __init__(self, experiment, dt, self.ftol = ftol self.internal_iterations = internal_iterations - + self.bessel = bessel - logging.info('internal_iterations: %s' % internal_iterations) - logging.info('outer_iterations: %s' % outer_iterations) - logging.info('inner_iterations: %s' % inner_iterations) + logging.info("internal_iterations: %s" % internal_iterations) + logging.info("outer_iterations: %s" % outer_iterations) + logging.info("inner_iterations: %s" % inner_iterations) self.iteration_info = [] @@ -66,23 +74,26 @@ def __init__(self, experiment, dt, self.init_params = np.ones(expected_param_count) elif len(self.init_params) != expected_param_count: self.init_params = np.ones(expected_param_count) - logging.warning('optimizer init_params has wrong length (given %d, expected %d). settings to all ones' % (len(self.init_params), expected_param_count)) - + logging.warning( + "optimizer init_params has wrong length (given %d, expected %d). settings to all ones" + % (len(self.init_params), expected_param_count) + ) + def to_dict(self): return { - 'outer_iterations': self.outer_iterations, - 'inner_iterations': self.inner_iterations, - 'init_params': self.init_params, - 'sigma_outer': self.sigma_outer, - 'sigma_inner': self.sigma_inner, - 'param_fit_names': self.param_fit_names, - 'xtol': self.xtol, - 'ftol': self.ftol, - 'internal_iterations': self.internal_iterations, - 'iteration_info': self.iteration_info, - 'bessel': self.bessel + "outer_iterations": self.outer_iterations, + "inner_iterations": self.inner_iterations, + "init_params": self.init_params, + "sigma_outer": self.sigma_outer, + "sigma_inner": self.sigma_inner, + "param_fit_names": self.param_fit_names, + "xtol": self.xtol, + "ftol": self.ftol, + "internal_iterations": self.internal_iterations, + "iteration_info": self.iteration_info, + "bessel": self.bessel, } - + def randomize_parameter_values(self, values, sigma): values = np.array(self.rng.normal(values, sigma)) @@ -90,175 +101,197 @@ def randomize_parameter_values(self, values, sigma): if not values.shape: values = np.array([values]) return values - + def initiate_unique_seed(self, seed=None): - if seed is None: - x1=str(int(uuid4())) #get a uuid, turn it into int then turn it into string - x2=[x1[ii:ii+8]for ii in range(0,40,8)] #break it up into chunks - x3=[int(ii) for ii in x2]#turn string chunks back into integers - print('seed', x3) + x1 = str(int(uuid4())) # get a uuid, turn it into int then turn it into string + x2 = [x1[ii : ii + 8] for ii in range(0, 40, 8)] # break it up into chunks + x3 = [int(ii) for ii in x2] # turn string chunks back into integers + print("seed", x3) self.rng.seed(x3) else: self.rng.seed(seed) def evaluate(self, x, dt_multiplier=100): - self.experiment.neuron.dt_multiplier = dt_multiplier return self.error_function([x], self.experiment, self.error_function_data) - + def run_many(self, iteration_finished_callback=None, seed=None): self.initiate_unique_seed(seed=seed) params_start = self.init_params self.start_time = time.time() - params=params_start -# params=self.randomize_parameter_values(params_start, self.sigma_outer) - print('actual starting parameters', params) - - stop_flag=False - + params = params_start + # params=self.randomize_parameter_values(params_start, self.sigma_outer) + print("actual starting parameters", params) + + stop_flag = False + # TODO: unhardcode this dt_multiplier_list = [100, 32, 10] - #Note the following line may be useful when there are more iteration but is hasnt been tested -# dt_multiplier_list = np.ceil(np.logspace(1,2,self.inner_iterations))[::-1].astype(int) + # Note the following line may be useful when there are more iteration but is hasnt been tested + # dt_multiplier_list = np.ceil(np.logspace(1,2,self.inner_iterations))[::-1].astype(int) print(dt_multiplier_list) -# dt_multiplier_list = [10,10,10] - #TODO: figure out the implications of this being an int versus float - #TODO: make this so that dt multiplier actually gets set - for outer in range(0, self.outer_iterations): #outerloop - for inner in range(0, self.inner_iterations): #innerloop + # dt_multiplier_list = [10,10,10] + # TODO: figure out the implications of this being an int versus float + # TODO: make this so that dt multiplier actually gets set + for outer in range(0, self.outer_iterations): # outerloop + for inner in range(0, self.inner_iterations): # innerloop iteration_start_time = time.time() # run the optimizer once. first time is always the passed initial conditions. -# print('dt_multiplier_list[inner]', dt_multiplier_list[inner]) - #--set this equal to 1 if want to do it slow + # print('dt_multiplier_list[inner]', dt_multiplier_list[inner]) + # --set this equal to 1 if want to do it slow self.experiment.neuron.dt_multiplier = dt_multiplier_list[inner] - #self.experiment.neuron.dt_multiplier = 10 - - + # self.experiment.neuron.dt_multiplier = 10 + opt = self.run_once(params) xopt, fopt = opt[0], opt[1] - logging.info('fmin took %f secs, %f mins, %f hours' % (time.time() - iteration_start_time, (time.time() - iteration_start_time)/60, (time.time() - iteration_start_time)/60/60)) - - self.iteration_info.append({ - 'in_params': np.array(params).tolist(), - 'out_params': xopt.tolist(), - 'error': float(fopt), - 'dt_multiplier': self.experiment.neuron.dt_multiplier - }) - -# ER=self.iteration_info['error'] -# ETOL=1.e-4 -# if len(ER) >=3: -# #!!!!!!!!!!!!!!!fix this to use that actual parameters!!!!!!!!!!!!!!!!!!!!!! -# if np.abs(ER[-1]-ER[-2])=3: + # #!!!!!!!!!!!!!!!fix this to use that actual parameters!!!!!!!!!!!!!!!!!!!!!! + # if np.abs(ER[-1]-ER[-2])threshold_t0: - raise GlifBadInitializationException("Voltage STARTS above threshold: voltage_t0 (%f) threshold_t0 (%f)" % ( voltage_t0, threshold_t0)) + + if voltage_t0 > threshold_t0: + raise GlifBadInitializationException( + "Voltage STARTS above threshold: voltage_t0 (%f) threshold_t0 (%f)" % (voltage_t0, threshold_t0) + ) start_index = 0 end_index = 0 - + try: num_spikes = len(bio_spike_time_steps) - - # if there are no target spikes, just run until the model spikes - if num_spikes == 0: + # if there are no target spikes, just run until the model spikes + if num_spikes == 0: start_index = 0 end_index = len(stimulus) # evaluate the model starting from the beginning until the model spikes - run_data = self.run_until_biological_spike(voltage_t0, threshold_t0, AScurrents_t0, - stimulus, response, start_index, end_index, - []) + run_data = self.run_until_biological_spike( + voltage_t0, threshold_t0, AScurrents_t0, stimulus, response, start_index, end_index, [] + ) - voltage = run_data['voltage'] - threshold = run_data['threshold'] - AScurrent_matrix = run_data['AScurrent_matrix'] + voltage = run_data["voltage"] + threshold = run_data["threshold"] + AScurrent_matrix = run_data["AScurrent_matrix"] if len(voltage) != len(stimulus): - logging.warning('Your voltage output is not the same length as your stimulus') + logging.warning("Your voltage output is not the same length as your stimulus") if len(threshold) != len(stimulus): - logging.warning('Your threshold output is not the same length as your stimulus') + logging.warning("Your threshold output is not the same length as your stimulus") if len(AScurrent_matrix) != len(stimulus): - logging.warning('Your AScurrent_matrix output is not the same length as your stimulus') - + logging.warning("Your AScurrent_matrix output is not the same length as your stimulus") + # do not keep track of the spikes in the model that spike if the target doesn't spike. grid_ISI = np.array([]) interpolated_ISI = np.array([]) - grid_model_spike_times = np.array([]) + grid_model_spike_times = np.array([]) interpolated_model_spike_times = np.array([]) - - grid_model_spike_voltages = np.array([]) - interpolated_model_spike_voltages = np.array([]) + + grid_model_spike_voltages = np.array([]) + interpolated_model_spike_voltages = np.array([]) grid_bio_spike_model_voltage = np.array([]) grid_bio_spike_model_threshold = np.array([]) @@ -198,160 +205,176 @@ def run_with_biological_spikes(self, stimulus, response, bio_spike_time_steps): grid_model_spike_times = np.empty(num_spikes) interpolated_model_spike_times = np.empty(num_spikes) - - grid_model_spike_voltages = np.empty(num_spikes) + + grid_model_spike_voltages = np.empty(num_spikes) interpolated_model_spike_voltages = np.empty(num_spikes) grid_bio_spike_model_voltage = np.empty(num_spikes) grid_bio_spike_model_threshold = np.empty(num_spikes) - voltage = np.empty(len(stimulus)) - voltage[:] = np.nan + voltage[:] = np.nan threshold = np.empty(len(stimulus)) threshold[:] = np.nan AScurrent_matrix = np.empty(shape=(len(stimulus), len(AScurrents_t0))) AScurrent_matrix[:] = np.nan - - # run the simulation over the interspike intervals (starting at the beginning of the simulation). + + # run the simulation over the interspike intervals (starting at the beginning of the simulation). start_index = 0 for spike_num in range(num_spikes): - if spike_num % 10 == 0: - logging.debug("spike %d / %d" % (spike_num, num_spikes)) + logging.debug("spike %d / %d" % (spike_num, num_spikes)) end_index = int(bio_spike_time_steps[spike_num]) - - assert start_index < end_index, Exception("start_index > end_index: this is probably because spike_cut_length is longer than the previous inter-spike interval") + + assert start_index < end_index, Exception( + "start_index > end_index: this is probably because spike_cut_length is longer than the previous inter-spike interval" + ) # run the simulation over this interspike interval -# t0 = time.time() - run_data = self.run_until_biological_spike(voltage_t0, threshold_t0, AScurrents_t0, - stimulus, response, start_index, end_index, - bio_spike_time_steps) -# print('fast', time.time() - t0) - -# curr_voltage = run_data_fast['voltage'] -# curr_threshold = run_data_fast['threshold'] -# voltage_scrubbed = curr_voltage[np.logical_not(np.isnan(curr_voltage))] -# threshold_scrubbed = curr_threshold[np.logical_not(np.isnan(curr_threshold))] -# -# tmp_t = np.linspace(0,1,len(voltage_scrubbed)) -# print(voltage_scrubbed) -# plt.plot(tmp_t, voltage_scrubbed) -# plt.plot(tmp_t, threshold_scrubbed) -# -# plt.show() -# sys.exit() - -# for key, val in run_data.items(): -# print(key, val) -# sys.exit() - - + # t0 = time.time() + run_data = self.run_until_biological_spike( + voltage_t0, + threshold_t0, + AScurrents_t0, + stimulus, + response, + start_index, + end_index, + bio_spike_time_steps, + ) + # print('fast', time.time() - t0) + + # curr_voltage = run_data_fast['voltage'] + # curr_threshold = run_data_fast['threshold'] + # voltage_scrubbed = curr_voltage[np.logical_not(np.isnan(curr_voltage))] + # threshold_scrubbed = curr_threshold[np.logical_not(np.isnan(curr_threshold))] + # + # tmp_t = np.linspace(0,1,len(voltage_scrubbed)) + # print(voltage_scrubbed) + # plt.plot(tmp_t, voltage_scrubbed) + # plt.plot(tmp_t, threshold_scrubbed) + # + # plt.show() + # sys.exit() + + # for key, val in run_data.items(): + # print(key, val) + # sys.exit() + # assign the simulated data to the correct locations in the output arrays - voltage[start_index:end_index] = run_data['voltage'] - threshold[start_index:end_index] = run_data['threshold'] - AScurrent_matrix[start_index:end_index,:] = run_data['AScurrent_matrix'] + voltage[start_index:end_index] = run_data["voltage"] + threshold[start_index:end_index] = run_data["threshold"] + AScurrent_matrix[start_index:end_index, :] = run_data["AScurrent_matrix"] + + grid_ISI[spike_num] = run_data["grid_model_spike_time"] + interpolated_ISI[spike_num] = run_data["interpolated_model_spike_time"] - grid_ISI[spike_num] = run_data['grid_model_spike_time'] - interpolated_ISI[spike_num] = run_data['interpolated_model_spike_time'] + grid_model_spike_times[spike_num] = run_data["grid_model_spike_time"] + start_index * self.dt + interpolated_model_spike_times[spike_num] = ( + run_data["interpolated_model_spike_time"] + start_index * self.dt + ) - grid_model_spike_times[spike_num] = run_data['grid_model_spike_time'] + start_index * self.dt - interpolated_model_spike_times[spike_num] = run_data['interpolated_model_spike_time'] + start_index * self.dt - - grid_model_spike_voltages[spike_num] = run_data['grid_model_spike_voltage'] - interpolated_model_spike_voltages[spike_num] = run_data['interpolated_model_spike_voltage'] + grid_model_spike_voltages[spike_num] = run_data["grid_model_spike_voltage"] + interpolated_model_spike_voltages[spike_num] = run_data["interpolated_model_spike_voltage"] - grid_bio_spike_model_voltage[spike_num] = run_data['grid_bio_spike_model_voltage'] - grid_bio_spike_model_threshold[spike_num] = run_data['grid_bio_spike_model_threshold'] + grid_bio_spike_model_voltage[spike_num] = run_data["grid_bio_spike_model_voltage"] + grid_bio_spike_model_threshold[spike_num] = run_data["grid_bio_spike_model_threshold"] # update the voltage, threshold, and afterspike currents for the next interval - voltage_t0 = run_data['voltage_t0'] - threshold_t0 = run_data['threshold_t0'] - AScurrents_t0 = run_data['AScurrents_t0'] - + voltage_t0 = run_data["voltage_t0"] + threshold_t0 = run_data["threshold_t0"] + AScurrents_t0 = run_data["AScurrents_t0"] + start_index = end_index # if cutting spikes, jump forward the appropriate amount of time if self.spike_cut_length > 0: start_index += self.spike_cut_length - + # simulate the portion of the stimulus between the last spike and the end of the array. # no spikes are recorded from this time! - run_data = self.run_until_biological_spike(voltage_t0, threshold_t0, AScurrents_t0, - stimulus, response, start_index, len(stimulus), - bio_spike_time_steps) - - voltage[start_index:] = run_data['voltage'] - threshold[start_index:] = run_data['threshold'] - AScurrent_matrix[start_index:,:] = run_data['AScurrent_matrix'] + run_data = self.run_until_biological_spike( + voltage_t0, + threshold_t0, + AScurrents_t0, + stimulus, + response, + start_index, + len(stimulus), + bio_spike_time_steps, + ) + + voltage[start_index:] = run_data["voltage"] + threshold[start_index:] = run_data["threshold"] + AScurrent_matrix[start_index:, :] = run_data["AScurrent_matrix"] # make sure that the output data has the correct number of spikes in it - if ( len(interpolated_model_spike_times) != num_spikes or - len(grid_model_spike_times) != num_spikes or - len(grid_ISI) != num_spikes or - len(interpolated_ISI) != num_spikes or - len(grid_bio_spike_model_voltage) != num_spikes or - len(grid_bio_spike_model_threshold) != num_spikes): - raise Exception('The number of spikes in your output does not match your target') + if ( + len(interpolated_model_spike_times) != num_spikes + or len(grid_model_spike_times) != num_spikes + or len(grid_ISI) != num_spikes + or len(interpolated_ISI) != num_spikes + or len(grid_bio_spike_model_voltage) != num_spikes + or len(grid_bio_spike_model_threshold) != num_spikes + ): + raise Exception("The number of spikes in your output does not match your target") except GlifNeuronException as e: - # if an exception was raised during run_until_spike, record any simulated data before exiting - voltage[start_index:end_index] = e.data['voltage'] - threshold[start_index:end_index] = e.data['threshold'] - AScurrent_matrix[start_index:end_index,:] = e.data['AScurrent_matrix'] + voltage[start_index:end_index] = e.data["voltage"] + threshold[start_index:end_index] = e.data["threshold"] + AScurrent_matrix[start_index:end_index, :] = e.data["AScurrent_matrix"] out = { - 'voltage': voltage, - 'threshold': threshold, - 'AScurrent_matrix': AScurrent_matrix, - - 'grid_ISI': grid_ISI, - 'interpolated_ISI': interpolated_ISI, - - 'grid_model_spike_times': grid_model_spike_times, - 'interpolated_model_spike_times': interpolated_model_spike_times, - - 'grid_model_spike_voltages': grid_model_spike_voltages, - 'interpolated_model_spike_voltages': interpolated_model_spike_voltages, - - 'grid_bio_spike_model_voltage': grid_bio_spike_model_voltage, - 'grid_bio_spike_model_threshold': grid_bio_spike_model_threshold - } + "voltage": voltage, + "threshold": threshold, + "AScurrent_matrix": AScurrent_matrix, + "grid_ISI": grid_ISI, + "interpolated_ISI": interpolated_ISI, + "grid_model_spike_times": grid_model_spike_times, + "interpolated_model_spike_times": interpolated_model_spike_times, + "grid_model_spike_voltages": grid_model_spike_voltages, + "interpolated_model_spike_voltages": interpolated_model_spike_voltages, + "grid_bio_spike_model_voltage": grid_bio_spike_model_voltage, + "grid_bio_spike_model_threshold": grid_bio_spike_model_threshold, + } raise GlifNeuronException(e.message, out) return { - 'voltage': voltage, - 'threshold': threshold, - 'AScurrent_matrix': AScurrent_matrix, - - 'grid_model_spike_times': grid_model_spike_times, - 'interpolated_model_spike_times': interpolated_model_spike_times, - - 'grid_model_spike_voltages': grid_model_spike_voltages, - 'interpolated_model_spike_voltages': interpolated_model_spike_voltages, - - 'grid_ISI': grid_ISI, - 'interpolated_ISI': interpolated_ISI, - - 'grid_bio_spike_model_voltage': grid_bio_spike_model_voltage, - 'grid_bio_spike_model_threshold': grid_bio_spike_model_threshold + "voltage": voltage, + "threshold": threshold, + "AScurrent_matrix": AScurrent_matrix, + "grid_model_spike_times": grid_model_spike_times, + "interpolated_model_spike_times": interpolated_model_spike_times, + "grid_model_spike_voltages": grid_model_spike_voltages, + "interpolated_model_spike_voltages": interpolated_model_spike_voltages, + "grid_ISI": grid_ISI, + "interpolated_ISI": interpolated_ISI, + "grid_bio_spike_model_voltage": grid_bio_spike_model_voltage, + "grid_bio_spike_model_threshold": grid_bio_spike_model_threshold, } - - def run_until_biological_spike(self, voltage_t0, threshold_t0, AScurrents_t0, - stimulus, response, start_index, after_end_index, - bio_spike_time_steps): - """ Run the neuron simulation over a segment of a stimulus given initial conditions for use in the "forced spike" - optimization paradigm. [Note: the section of stimulus - is meant to be between two biological neuron spikes. Thus the stimulus is during the interspike interval (ISI)]. The - model is simulated until either the model spikes or the end of the segment is reached. If the model does not spike, a - spike time is extrapolated past the end of the simulation segment. - - This function also returns the initial conditions for the subsequent stimulus segment. In the forced spike paradigm - there are several ways + + def run_until_biological_spike( + self, + voltage_t0, + threshold_t0, + AScurrents_t0, + stimulus, + response, + start_index, + after_end_index, + bio_spike_time_steps, + ): + """Run the neuron simulation over a segment of a stimulus given initial conditions for use in the "forced spike" + optimization paradigm. [Note: the section of stimulus + is meant to be between two biological neuron spikes. Thus the stimulus is during the interspike interval (ISI)]. The + model is simulated until either the model spikes or the end of the segment is reached. If the model does not spike, a + spike time is extrapolated past the end of the simulation segment. + + This function also returns the initial conditions for the subsequent stimulus segment. In the forced spike paradigm + there are several ways Parameters ---------- @@ -379,15 +402,15 @@ def run_until_biological_spike(self, voltage_t0, threshold_t0, AScurrents_t0, 'voltage': simulated voltage value 'threshold': simulated threshold values 'AScurrent_matrix': afterspike current values during the simulation - 'grid_model_spike_time': model spike time (in units of dt) + 'grid_model_spike_time': model spike time (in units of dt) 'interpolated_model_spike_time': model spike time (in units of dt) interpolated between time steps - 'voltage_t0': reset voltage value to be used in subsequent simulation interval + 'voltage_t0': reset voltage value to be used in subsequent simulation interval 'threshold_t0': reset threshold value to be used in subsequent simulation interval 'AScurrents_t0': reset afterspike current value to be used in subsequent simulation interval 'grid_bio_spike_model_voltage': model voltage at the time of the input spike 'grid_bio_spike_model_threshold': model threshold at the time of the input spike """ - + grid_model_spike_time = None grid_model_spike_voltage = None interpolated_model_spike_time = None @@ -395,237 +418,346 @@ def run_until_biological_spike(self, voltage_t0, threshold_t0, AScurrents_t0, # preallocate arrays and matricies num_time_steps_fine = after_end_index - start_index - t_fine_grid = np.arange(num_time_steps_fine)*self.dt + t_fine_grid = np.arange(num_time_steps_fine) * self.dt - #-------------------------------------------------------------------------------- - #---Apply refinement factor to integrate over larger time steps (assumes--------- - #---current within the steps can be averaged):---------------------------------- - #-------------------------------------------------------------------------------- + # -------------------------------------------------------------------------------- + # ---Apply refinement factor to integrate over larger time steps (assumes--------- + # ---current within the steps can be averaged):---------------------------------- + # -------------------------------------------------------------------------------- dt_old = self.dt - self.dt = self.dt*self.dt_multiplier - + self.dt = self.dt * self.dt_multiplier + # define the local course grain indicies note the last graining will be shorter and is appended to the end - local_coarse_indicies=np.append(np.arange(num_time_steps_fine)[::self.dt_multiplier], after_end_index - start_index) #the last indicie in this array is still one longer than the last simulated index - # convert the local coarse grained indicies into global incidies - global_coarse_indicies=local_coarse_indicies+start_index + local_coarse_indicies = np.append( + np.arange(num_time_steps_fine)[:: self.dt_multiplier], after_end_index - start_index + ) # the last indicie in this array is still one longer than the last simulated index + # convert the local coarse grained indicies into global incidies + global_coarse_indicies = local_coarse_indicies + start_index # TODO: I dont think this does anything. - if len(local_coarse_indicies)==2: + if len(local_coarse_indicies) == 2: pass - + num_time_steps_coarse = len(local_coarse_indicies) voltage_out_coarse_grid = np.empty(num_time_steps_coarse) voltage_out_coarse_grid[:] = np.nan threshold_out_coarse_grid = np.empty(num_time_steps_coarse) threshold_out_coarse_grid[:] = np.nan - AScurrent_matrix_coarse_grid = np.empty(shape=(num_time_steps_coarse, len(AScurrents_t0))) + AScurrent_matrix_coarse_grid = np.empty(shape=(num_time_steps_coarse, len(AScurrents_t0))) AScurrent_matrix_coarse_grid[:] = np.nan # these grid times are in the local frame of reference - t_coarse_grid = np.arange(num_time_steps_coarse-1)*self.dt #subtracting the one off here because appending the actual last time that is not the same dt. + t_coarse_grid = ( + np.arange(num_time_steps_coarse - 1) * self.dt + ) # subtracting the one off here because appending the actual last time that is not the same dt. t_coarse_grid = np.append(t_coarse_grid, t_fine_grid[-1]) - dt_vector=t_coarse_grid[1:]-t_coarse_grid[:-1] #note that this vector is one index shorter than the t_course_grid + dt_vector = ( + t_coarse_grid[1:] - t_coarse_grid[:-1] + ) # note that this vector is one index shorter than the t_course_grid - # Define the coarse grain stimulus by taking the stimulus average between indicies. Note that since the initial input voltage is recorded in + # Define the coarse grain stimulus by taking the stimulus average between indicies. Note that since the initial input voltage is recorded in # the output vectors the stimulus average is indeed the input to the correct time step (i.e. current being fed in is the average current before the step) - stimulus_coarse=[stimulus[global_coarse_indicies[ii]:global_coarse_indicies[ii+1]].mean() for ii in range(len(global_coarse_indicies)-1)] + stimulus_coarse = [ + stimulus[global_coarse_indicies[ii] : global_coarse_indicies[ii + 1]].mean() + for ii in range(len(global_coarse_indicies) - 1) + ] # step though time steps and calculate voltage values - for time_step in range(len(local_coarse_indicies)-1): #minus 1 is needed to match vector sizes because initial inputs are recorded in output vectors. - # update output values (Note: in general one can update values before or after the first time step. + for time_step in range( + len(local_coarse_indicies) - 1 + ): # minus 1 is needed to match vector sizes because initial inputs are recorded in output vectors. + # update output values (Note: in general one can update values before or after the first time step. # Here the input starting value of voltage is recorded before a time step. This means the last value is not recorded.) - voltage_out_coarse_grid[time_step] = voltage_t0 + voltage_out_coarse_grid[time_step] = voltage_t0 threshold_out_coarse_grid[time_step] = threshold_t0 - AScurrent_matrix_coarse_grid[time_step,:] = np.matrix(AScurrents_t0) - + AScurrent_matrix_coarse_grid[time_step, :] = np.matrix(AScurrents_t0) + # record error in optimization if they are happening - if np.isnan(voltage_t0) or np.isinf(voltage_t0) or np.isnan(threshold_t0) or np.isinf(threshold_t0) or any(np.isnan(AScurrents_t0)) or any(np.isinf(AScurrents_t0)): + if ( + np.isnan(voltage_t0) + or np.isinf(voltage_t0) + or np.isnan(threshold_t0) + or np.isinf(threshold_t0) + or any(np.isnan(AScurrents_t0)) + or any(np.isinf(AScurrents_t0)) + ): logging.error(self) - logging.error('time step: %d / %d' % (time_step, num_time_steps_coarse)) - logging.error(' voltage_t0: %f' % voltage_t0) - logging.error(' voltage started the run at: %f' % voltage_out_coarse_grid[0]) - logging.error(' voltage before: %s' % voltage_out_coarse_grid[time_step-20:time_step]) - logging.error(' threshold_t0: %f' % threshold_t0) - logging.error(' threshold started the run at: %f' % threshold_out_coarse_grid[0]) - logging.error(' threshold before: %s' % threshold_out_coarse_grid[time_step-20:time_step]) - logging.error(' AScurrents_t0: %s' % AScurrents_t0) - if 'a_spike' in self.threshold_dynamics_method.params: - logging.error(' a_spike: %s' % self.threshold_dynamics_method.params['a_spike']) - if 'b_spike' in self.threshold_dynamics_method.params: - logging.error(' b_spike: %s' % self.threshold_dynamics_method.params['b_spike']) - + logging.error("time step: %d / %d" % (time_step, num_time_steps_coarse)) + logging.error(" voltage_t0: %f" % voltage_t0) + logging.error(" voltage started the run at: %f" % voltage_out_coarse_grid[0]) + logging.error(" voltage before: %s" % voltage_out_coarse_grid[time_step - 20 : time_step]) + logging.error(" threshold_t0: %f" % threshold_t0) + logging.error(" threshold started the run at: %f" % threshold_out_coarse_grid[0]) + logging.error(" threshold before: %s" % threshold_out_coarse_grid[time_step - 20 : time_step]) + logging.error(" AScurrents_t0: %s" % AScurrents_t0) + if "a_spike" in self.threshold_dynamics_method.params: + logging.error(" a_spike: %s" % self.threshold_dynamics_method.params["a_spike"]) + if "b_spike" in self.threshold_dynamics_method.params: + logging.error(" b_spike: %s" % self.threshold_dynamics_method.params["b_spike"]) + # plot output in original index space - temp_fine_grid_for_intp=np.arange(0,t_coarse_grid[time_step-1], dt_old) + temp_fine_grid_for_intp = np.arange(0, t_coarse_grid[time_step - 1], dt_old) voltage_out_fine_grid = np.empty(num_time_steps_fine) voltage_out_fine_grid[:] = np.nan threshold_out_fine_grid = np.empty(num_time_steps_fine) threshold_out_fine_grid[:] = np.nan - fv = spi.interp1d(t_coarse_grid[:time_step], voltage_out_coarse_grid[:time_step], assume_sorted=True, bounds_error=False, fill_value=voltage_out_coarse_grid[-1]) - ft = spi.interp1d(t_coarse_grid[:time_step], threshold_out_coarse_grid[:time_step], assume_sorted=True, bounds_error=False, fill_value=threshold_out_coarse_grid[-1]) + fv = spi.interp1d( + t_coarse_grid[:time_step], + voltage_out_coarse_grid[:time_step], + assume_sorted=True, + bounds_error=False, + fill_value=voltage_out_coarse_grid[-1], + ) + ft = spi.interp1d( + t_coarse_grid[:time_step], + threshold_out_coarse_grid[:time_step], + assume_sorted=True, + bounds_error=False, + fill_value=threshold_out_coarse_grid[-1], + ) voltage_with_error = fv(temp_fine_grid_for_intp) threshold_with_error = ft(temp_fine_grid_for_intp) - voltage_out_fine_grid[:len(voltage_with_error)]=voltage_with_error - threshold_out_fine_grid[:len(threshold_with_error)]=threshold_with_error - + voltage_out_fine_grid[: len(voltage_with_error)] = voltage_with_error + threshold_out_fine_grid[: len(threshold_with_error)] = threshold_with_error + AScurrent_matrix = np.empty(shape=(num_time_steps_fine, len(AScurrents_t0))) AScurrent_matrix[:] = np.nan for ii in range(len(AScurrents_t0)): - curr_fASc = spi.interp1d(t_coarse_grid[:time_step], AScurrent_matrix_coarse_grid[:time_step,ii], assume_sorted=True, bounds_error=False, fill_value=AScurrent_matrix_coarse_grid[-1,ii]) - temp_asc=curr_fASc(temp_fine_grid_for_intp) - AScurrent_matrix[:len(temp_asc),ii] = temp_asc - - raise GlifNeuronException('Invalid threshold, voltage, or after-spike current encountered.', { - 'voltage': voltage_out_fine_grid, - 'threshold': threshold_out_fine_grid, - 'AScurrent_matrix': AScurrent_matrix - }) - + curr_fASc = spi.interp1d( + t_coarse_grid[:time_step], + AScurrent_matrix_coarse_grid[:time_step, ii], + assume_sorted=True, + bounds_error=False, + fill_value=AScurrent_matrix_coarse_grid[-1, ii], + ) + temp_asc = curr_fASc(temp_fine_grid_for_intp) + AScurrent_matrix[: len(temp_asc), ii] = temp_asc + + raise GlifNeuronException( + "Invalid threshold, voltage, or after-spike current encountered.", + { + "voltage": voltage_out_fine_grid, + "threshold": threshold_out_fine_grid, + "AScurrent_matrix": AScurrent_matrix, + }, + ) + # changing dt be the dt of the coarse bin (which is variable for the last bin) - self.dt=dt_vector[time_step] - (voltage_t1, threshold_t1, AScurrents_t1) = self.dynamics(voltage_t0, threshold_t0, AScurrents_t0, stimulus_coarse[time_step], time_step+start_index, bio_spike_time_steps) #TODO fix list versus array - - # updating the input values - voltage_t0=voltage_t1 - threshold_t0=threshold_t1 - AScurrents_t0=AScurrents_t1 - + self.dt = dt_vector[time_step] + (voltage_t1, threshold_t1, AScurrents_t1) = self.dynamics( + voltage_t0, + threshold_t0, + AScurrents_t0, + stimulus_coarse[time_step], + time_step + start_index, + bio_spike_time_steps, + ) # TODO fix list versus array + + # updating the input values + voltage_t0 = voltage_t1 + threshold_t0 = threshold_t1 + AScurrents_t0 = AScurrents_t1 + # Inserting the last values into the nan at the end of the matricies so that when do the interpolation the end of the vector will not be nans # Note this should not mess with any of the outputs because is the interploated values that are the output. - voltage_out_coarse_grid[time_step+1] = voltage_t0 - threshold_out_coarse_grid[time_step+1] = threshold_t0 - AScurrent_matrix_coarse_grid[time_step+1,:] = np.matrix(AScurrents_t0) + voltage_out_coarse_grid[time_step + 1] = voltage_t0 + threshold_out_coarse_grid[time_step + 1] = threshold_t0 + AScurrent_matrix_coarse_grid[time_step + 1, :] = np.matrix(AScurrents_t0) # Reset dt to previous value: - self.dt = dt_old - - fv = spi.interp1d(t_coarse_grid, voltage_out_coarse_grid, assume_sorted=True, bounds_error=False, fill_value=voltage_out_coarse_grid[-1]) - ft = spi.interp1d(t_coarse_grid, threshold_out_coarse_grid, assume_sorted=True, bounds_error=False, fill_value=threshold_out_coarse_grid[-1]) + self.dt = dt_old + + fv = spi.interp1d( + t_coarse_grid, + voltage_out_coarse_grid, + assume_sorted=True, + bounds_error=False, + fill_value=voltage_out_coarse_grid[-1], + ) + ft = spi.interp1d( + t_coarse_grid, + threshold_out_coarse_grid, + assume_sorted=True, + bounds_error=False, + fill_value=threshold_out_coarse_grid[-1], + ) voltage_out = fv(t_fine_grid) threshold_out = ft(t_fine_grid) - + # initalize after spike current matrix AScurrent_matrix = np.empty(shape=(num_time_steps_fine, len(AScurrents_t0))) AScurrent_matrix[:] = np.nan for ii in range(len(AScurrents_t0)): - curr_fASc = spi.interp1d(t_coarse_grid, AScurrent_matrix_coarse_grid[:,ii], assume_sorted=True, bounds_error=False, fill_value=AScurrent_matrix_coarse_grid[-1,ii]) - AScurrent_matrix[:,ii] = curr_fASc(t_fine_grid) - - # find where model voltage crosses model threshold - grid_model_spike_time, grid_model_spike_voltage, interpolated_model_spike_time, interpolated_model_spike_voltage = find_first_model_spike(voltage_out, threshold_out, voltage_t1, threshold_t1, self.dt) + curr_fASc = spi.interp1d( + t_coarse_grid, + AScurrent_matrix_coarse_grid[:, ii], + assume_sorted=True, + bounds_error=False, + fill_value=AScurrent_matrix_coarse_grid[-1, ii], + ) + AScurrent_matrix[:, ii] = curr_fASc(t_fine_grid) + + # find where model voltage crosses model threshold + ( + grid_model_spike_time, + grid_model_spike_voltage, + interpolated_model_spike_time, + interpolated_model_spike_voltage, + ) = find_first_model_spike(voltage_out, threshold_out, voltage_t1, threshold_t1, self.dt) # if the model never spiked, extrapolate to guess when it would have spiked - if grid_model_spike_time is None: - grid_model_spike_time, grid_model_spike_voltage, interpolated_model_spike_time, interpolated_model_spike_voltage = self.extrapolation_method(self, voltage_out, threshold_out, voltage_t1, threshold_t1, self.dt) - + if grid_model_spike_time is None: + ( + grid_model_spike_time, + grid_model_spike_voltage, + interpolated_model_spike_time, + interpolated_model_spike_voltage, + ) = self.extrapolation_method(self, voltage_out, threshold_out, voltage_t1, threshold_t1, self.dt) + # when the target spikes, reset so that next round will start at reset but not recording it in the voltage here. # note that at the last section of the stimulus where there is no current injected the model will be reset even if - # the biological neuron doesn't spike. However, this doesnt matter as it won't be recorded. + # the biological neuron doesn't spike. However, this doesnt matter as it won't be recorded. num_spikes = len(bio_spike_time_steps) if num_spikes > 0: - if after_end_index threshold[time_step]: - grid_model_spike_time = dt * (time_step-1) - grid_model_spike_voltage = voltage[time_step-1] - - interpolated_model_spike_time = glif_neuron.interpolate_spike_time(dt, time_step-1, - threshold[time_step-1], threshold[time_step], - voltage[time_step-1], voltage[time_step]) - - interpolated_model_spike_voltage = interpolate_spike_voltage(dt, time_step-1, - threshold[time_step-1], threshold[time_step], - voltage[time_step-1], voltage[time_step]) - - return grid_model_spike_time, grid_model_spike_voltage, interpolated_model_spike_time, interpolated_model_spike_voltage + grid_model_spike_time = dt * (time_step - 1) + grid_model_spike_voltage = voltage[time_step - 1] + + interpolated_model_spike_time = glif_neuron.interpolate_spike_time( + dt, + time_step - 1, + threshold[time_step - 1], + threshold[time_step], + voltage[time_step - 1], + voltage[time_step], + ) + + interpolated_model_spike_voltage = interpolate_spike_voltage( + dt, + time_step - 1, + threshold[time_step - 1], + threshold[time_step], + voltage[time_step - 1], + voltage[time_step], + ) + + return ( + grid_model_spike_time, + grid_model_spike_voltage, + interpolated_model_spike_time, + interpolated_model_spike_voltage, + ) # if the last voltage is above threshold and there hasn't already been a spike - if voltage_t1 > threshold_t1: - grid_model_spike_time = dt * ( num_time_steps - 1 ) + if voltage_t1 > threshold_t1: + grid_model_spike_time = dt * (num_time_steps - 1) grid_model_spike_voltage = voltage_t1 - - interpolated_model_spike_time = glif_neuron.interpolate_spike_time(dt, num_time_steps - 1, threshold[num_time_steps-1], threshold_t1, voltage[num_time_steps-1], voltage_t1) - interpolated_model_spike_voltage = interpolate_spike_voltage(dt, num_time_steps, threshold[-1], threshold_t1, voltage[-1], voltage_t1) - return grid_model_spike_time, grid_model_spike_voltage, interpolated_model_spike_time, interpolated_model_spike_voltage + interpolated_model_spike_time = glif_neuron.interpolate_spike_time( + dt, num_time_steps - 1, threshold[num_time_steps - 1], threshold_t1, voltage[num_time_steps - 1], voltage_t1 + ) + interpolated_model_spike_voltage = interpolate_spike_voltage( + dt, num_time_steps, threshold[-1], threshold_t1, voltage[-1], voltage_t1 + ) + return ( + grid_model_spike_time, + grid_model_spike_voltage, + interpolated_model_spike_time, + interpolated_model_spike_voltage, + ) return None, None, None, None + def extrapolate_model_spike_from_endpoints(neuron, voltage, threshold, voltage_t1, threshold_t1, dt): - - #--extrapolate using first point in ISI and last point in ISI + # --extrapolate using first point in ISI and last point in ISI num_time_steps = len(voltage) - - interpolated_model_spike_time = extrapolate_spike_time(dt, num_time_steps, threshold[0], threshold_t1, voltage[0], voltage_t1) - interpolated_model_spike_voltage = extrapolate_spike_voltage(dt, num_time_steps, threshold[0], threshold_t1, voltage[0], voltage_t1) - - grid_model_spike_time = np.ceil(interpolated_model_spike_time / dt) * dt # grid spike time based off extrapolated spike time + + interpolated_model_spike_time = extrapolate_spike_time( + dt, num_time_steps, threshold[0], threshold_t1, voltage[0], voltage_t1 + ) + interpolated_model_spike_voltage = extrapolate_spike_voltage( + dt, num_time_steps, threshold[0], threshold_t1, voltage[0], voltage_t1 + ) + + grid_model_spike_time = ( + np.ceil(interpolated_model_spike_time / dt) * dt + ) # grid spike time based off extrapolated spike time grid_model_spike_voltage = interpolated_model_spike_voltage - result = grid_model_spike_time, grid_model_spike_voltage, interpolated_model_spike_time, interpolated_model_spike_voltage - + result = ( + grid_model_spike_time, + grid_model_spike_voltage, + interpolated_model_spike_time, + interpolated_model_spike_voltage, + ) return result - def extrapolate_model_spike_from_endpoints_single_tau(neuron, voltage, threshold, voltage_t1, threshold_t1, dt): tau_m = neuron.tau_m num_time_steps = len(voltage) - ii = np.floor(tau_m/dt) - starting_ind = max(0,(num_time_steps - ii)) - result = extrapolate_model_spike_from_endpoints(neuron, voltage[starting_ind:], threshold[starting_ind:], voltage_t1, threshold_t1, dt) - + ii = np.floor(tau_m / dt) + starting_ind = max(0, (num_time_steps - ii)) + result = extrapolate_model_spike_from_endpoints( + neuron, voltage[starting_ind:], threshold[starting_ind:], voltage_t1, threshold_t1, dt + ) + return result - + + def extrapolate_spike_time(dt, num_time_steps, threshold_t0, threshold_t1, voltage_t0, voltage_t1): - """ Given two voltage and threshold values and an interval between them, extrapolate a spike time - by intersecting lines the thresholds and voltages. """ + """Given two voltage and threshold values and an interval between them, extrapolate a spike time + by intersecting lines the thresholds and voltages.""" return glif_neuron.line_crossing_x(dt * num_time_steps, voltage_t0, voltage_t1, threshold_t0, threshold_t1) + def extrapolate_spike_voltage(dt, num_time_steps, threshold_t0, threshold_t1, voltage_t0, voltage_t1): - """ Given two voltage and threshold values and an interval between them, extrapolate a spike time - by intersecting lines the thresholds and voltages. """ + """Given two voltage and threshold values and an interval between them, extrapolate a spike time + by intersecting lines the thresholds and voltages.""" return glif_neuron.line_crossing_y(dt * num_time_steps, voltage_t0, voltage_t1, threshold_t0, threshold_t1) - + + def interpolate_spike_voltage(dt, time_step, threshold_t0, threshold_t1, voltage_t0, voltage_t1): - """ Given two voltage and threshold values, the dt between them and the initial time step, interpolate - a spike time within the dt interval by intersecting the two lines. """ - return time_step*dt + glif_neuron.line_crossing_y(dt, voltage_t0, voltage_t1, threshold_t0, threshold_t1) + """Given two voltage and threshold values, the dt between them and the initial time step, interpolate + a spike time within the dt interval by intersecting the two lines.""" + return time_step * dt + glif_neuron.line_crossing_y(dt, voltage_t0, voltage_t1, threshold_t0, threshold_t1) diff --git a/allensdk/internal/model/glif/optimize_neuron.py b/allensdk/internal/model/glif/optimize_neuron.py index ac06aff2ae..cb02552291 100644 --- a/allensdk/internal/model/glif/optimize_neuron.py +++ b/allensdk/internal/model/glif/optimize_neuron.py @@ -12,19 +12,21 @@ from allensdk.internal.model.glif.find_spikes import find_spikes_list import allensdk.internal.model.glif.preprocess_neuron as pn -def get_optimize_sweep_numbers(sweep_index): - #TODO: why is this here--why are sweep indicies being fed to a find_noise_sweeps sweeps and specifying - #noise?--shouldn't the sweeps already be provided? - return fs.find_noise_sweeps(sweep_index)['noise1'] + +def get_optimize_sweep_numbers(sweep_index): + # TODO: why is this here--why are sweep indicies being fed to a find_noise_sweeps sweeps and specifying + # noise?--shouldn't the sweeps already be provided? + return fs.find_noise_sweeps(sweep_index)["noise1"] + def optimize_neuron(model_config, sweep_index, nwb_file, save_callback=None): - '''Optimizes a neuron. + """Optimizes a neuron. 1. Loads optimizer and neuron configuration data. 2. Loads the voltage trace sweeps that will be optimized 3. Configures the experiment and optimizer 4. Runs the optimizer 5. TODO: where is data saved - + Parameters ---------- model_config : dictionary @@ -33,81 +35,87 @@ def optimize_neuron(model_config, sweep_index, nwb_file, save_callback=None): indices (as labeled in the data configuration file) of sweeps that will be optimized save_callback : module saves output - ''' + """ # define the neuron and optimizer dictionaries from the model configuration - neuron_config = model_config['neuron'] - optimizer_config = model_config['optimizer'] + neuron_config = model_config["neuron"] + optimizer_config = model_config["optimizer"] # load the neuron with along with the methods needed for optimization neuron = GlifOptimizerNeuron.from_dict(neuron_config) # TODO: not sure what this is doing optimize_sweeps = get_optimize_sweep_numbers(sweep_index) - + # load the sweeps to be optimized - optimize_data = load_sweeps(nwb_file, optimize_sweeps, neuron.dt, - optimizer_config["cut"], optimizer_config["bessel"]) + optimize_data = load_sweeps( + nwb_file, optimize_sweeps, neuron.dt, optimizer_config["cut"], optimizer_config["bessel"] + ) # needed to offset all voltages by El_reference - El_reference = neuron_config['El_reference'] + El_reference = neuron_config["El_reference"] # get indicies of spikes and voltage at those spikes - spike_ind, spike_v = find_spikes_list(optimize_data['voltage'], neuron_config['dt']) + spike_ind, spike_v = find_spikes_list(optimize_data["voltage"], neuron_config["dt"]) + + # get times of spikes + grid_spike_times = [si * neuron_config["dt"] for si in spike_ind] - # get times of spikes - grid_spike_times = [ si*neuron_config['dt'] for si in spike_ind ] - # convert voltage at spikes into reference frame of El - grid_spike_voltages_in_ref_to_zero = [ sv - El_reference for sv in spike_v ] + grid_spike_voltages_in_ref_to_zero = [sv - El_reference for sv in spike_v] - # convert voltage into reference frame of El - resp_list = [ d - El_reference for d in optimize_data['voltage'] ] + # convert voltage into reference frame of El + resp_list = [d - El_reference for d in optimize_data["voltage"]] # configure experiment - experiment = GlifExperiment(neuron = neuron, - dt = neuron.dt, - stim_list = optimize_data['current'], - resp_list = resp_list, - spike_time_steps = spike_ind, - grid_spike_times = grid_spike_times, - grid_spike_voltages = grid_spike_voltages_in_ref_to_zero, - param_fit_names = optimizer_config['param_fit_names']) + experiment = GlifExperiment( + neuron=neuron, + dt=neuron.dt, + stim_list=optimize_data["current"], + resp_list=resp_list, + spike_time_steps=spike_ind, + grid_spike_times=grid_spike_times, + grid_spike_voltages=grid_spike_voltages_in_ref_to_zero, + param_fit_names=optimizer_config["param_fit_names"], + ) # configure optimizer - optimizer = GlifOptimizer(experiment = experiment, - dt = neuron.dt, - outer_iterations = optimizer_config['outer_iterations'], - inner_iterations = optimizer_config['inner_iterations'], - sigma_inner = optimizer_config['sigma_inner'], - sigma_outer = optimizer_config['sigma_outer'], - param_fit_names = optimizer_config['param_fit_names'], - stim = optimize_data['current'], - error_function_data = optimizer_config['error_function_data'], - xtol = optimizer_config['xtol'], - ftol = optimizer_config['ftol'], - internal_iterations = optimizer_config['internal_iterations'], - init_params = optimizer_config.get('init_params', None), - bessel = optimizer_config['bessel']) + optimizer = GlifOptimizer( + experiment=experiment, + dt=neuron.dt, + outer_iterations=optimizer_config["outer_iterations"], + inner_iterations=optimizer_config["inner_iterations"], + sigma_inner=optimizer_config["sigma_inner"], + sigma_outer=optimizer_config["sigma_outer"], + param_fit_names=optimizer_config["param_fit_names"], + stim=optimize_data["current"], + error_function_data=optimizer_config["error_function_data"], + xtol=optimizer_config["xtol"], + ftol=optimizer_config["ftol"], + internal_iterations=optimizer_config["internal_iterations"], + init_params=optimizer_config.get("init_params", None), + bessel=optimizer_config["bessel"], + ) def save(optimizer, outer, inner): - logging.info('finished outer: %d inner: %d' % (outer, inner)) + logging.info("finished outer: %d inner: %d" % (outer, inner)) if save_callback: save_callback(optimizer, outer, inner) - + # run the optimizer - best_param, begin_param = optimizer.run_many(save) + best_param, begin_param = optimizer.run_many(save) # over write the the initial experiment parameters with the best found parameters # TODO: but why do this since it is not being returned experiment.set_neuron_parameters(best_param) - + return optimizer, best_param, begin_param - + + def main(): parser = argparse.ArgumentParser() - parser.add_argument('model_config_file') - parser.add_argument('sweeps_file') - parser.add_argument('output_file') + parser.add_argument("model_config_file") + parser.add_argument("sweeps_file") + parser.add_argument("output_file") parser.add_argument("--dt", default=pn.DEFAULT_DT) parser.add_argument("--bessel", default=pn.DEFAULT_BESSEL) parser.add_argument("--cut", default=pn.DEFAULT_CUT) @@ -117,7 +125,7 @@ def main(): model_config = ju.read(args.model_config_file) sweep_list = ju.read(args.sweeps_file) - sweep_index = { s['sweep_number']:s for s in sweep_list } + sweep_index = {s["sweep_number"]: s for s in sweep_list} try: neuron, best_param, begin_param = optimize_neuron(model_config, sweep_index, dt, cut, bessel) @@ -125,5 +133,6 @@ def main(): except Exception as e: logging.error(e.message) + if __name__ == "__main__": main() diff --git a/allensdk/internal/model/glif/plotting.py b/allensdk/internal/model/glif/plotting.py index a6e23ff420..9e35c5a3fa 100644 --- a/allensdk/internal/model/glif/plotting.py +++ b/allensdk/internal/model/glif/plotting.py @@ -1,106 +1,141 @@ -'''Written by Corinne Teeter 3-31-14 -''' +"""Written by Corinne Teeter 3-31-14""" import matplotlib.pyplot as plt import numpy as np -def checkPreprocess(originalStim_list, processedStim_list, originalVoltage_list, processedVoltage_list, config, blockME=False): - - timeOriginal=np.arange(len(np.concatenate(originalStim_list)))*config.neuron['dt'] - if 'subSample' in config.dictOfPreprocessMethods.keys(): - timeProcessed=np.arange(len(np.concatenate(processedStim_list)))*config.dictOfPreprocessMethods['subSample']['desired_time_step'] +def checkPreprocess( + originalStim_list, processedStim_list, originalVoltage_list, processedVoltage_list, config, blockME=False +): + timeOriginal = np.arange(len(np.concatenate(originalStim_list))) * config.neuron["dt"] + if "subSample" in config.dictOfPreprocessMethods.keys(): + timeProcessed = ( + np.arange(len(np.concatenate(processedStim_list))) + * config.dictOfPreprocessMethods["subSample"]["desired_time_step"] + ) else: - timeProcessed=timeOriginal - - plt.figure(figsize=(20,10)) - plt.subplot(4,1,1) - plt.plot(timeOriginal, np.concatenate(originalStim_list), 'b') - plt.title('original stimulation') - - plt.subplot(4,1,2) - plt.plot(timeProcessed, np.concatenate(processedStim_list), 'r') - plt.title('processed stimulation') - - plt.subplot(4,1,3) - plt.plot(timeOriginal, np.concatenate(originalVoltage_list), 'b') - plt.title('original voltage') - - plt.subplot(4,1,4) - plt.plot(timeProcessed, np.concatenate(processedVoltage_list), 'r') - plt.title('processed voltage') - - plt.annotate(config.cellName+': View result of preprocessing', xy=(.4, .975), - xycoords='figure fraction', - horizontalalignment='left', verticalalignment='top', - fontsize=20) + timeProcessed = timeOriginal + + plt.figure(figsize=(20, 10)) + plt.subplot(4, 1, 1) + plt.plot(timeOriginal, np.concatenate(originalStim_list), "b") + plt.title("original stimulation") + + plt.subplot(4, 1, 2) + plt.plot(timeProcessed, np.concatenate(processedStim_list), "r") + plt.title("processed stimulation") + + plt.subplot(4, 1, 3) + plt.plot(timeOriginal, np.concatenate(originalVoltage_list), "b") + plt.title("original voltage") + + plt.subplot(4, 1, 4) + plt.plot(timeProcessed, np.concatenate(processedVoltage_list), "r") + plt.title("processed voltage") + + plt.annotate( + config.cellName + ": View result of preprocessing", + xy=(0.4, 0.975), + xycoords="figure fraction", + horizontalalignment="left", + verticalalignment="top", + fontsize=20, + ) + # plt.show(block=blockME) + def plotSpikes(voltage_list, spike_ind_list, dt, blockME=False, method=False): - - converted_spike_ind_list=[] - time=np.arange(len(np.concatenate(voltage_list)))*dt - #--find the length of each vector - thelength=0 + converted_spike_ind_list = [] + time = np.arange(len(np.concatenate(voltage_list))) * dt + # --find the length of each vector + thelength = 0 for ii, voltage in enumerate(voltage_list): - converted_spike_ind_list.append(spike_ind_list[ii]+thelength) - thelength=thelength+len(voltage) - - subsampled_time=[time[ii] for ii in np.concatenate(converted_spike_ind_list)] + converted_spike_ind_list.append(spike_ind_list[ii] + thelength) + thelength = thelength + len(voltage) + + subsampled_time = [time[ii] for ii in np.concatenate(converted_spike_ind_list)] plt.figure(figsize=(20, 5)) - plt.plot(time, np.concatenate(voltage_list), 'b') - plt.plot(subsampled_time, [np.concatenate(voltage_list)[ii] for ii in np.concatenate(converted_spike_ind_list)], 'r.', ms=16) + plt.plot(time, np.concatenate(voltage_list), "b") + plt.plot( + subsampled_time, + [np.concatenate(voltage_list)[ii] for ii in np.concatenate(converted_spike_ind_list)], + "r.", + ms=16, + ) if not method: - plt.title('Spikes') + plt.title("Spikes") else: - plt.title('Spikes. Method used: '+method) - - plt.ylabel('voltage (V)') - plt.xlabel('time (s)') + plt.title("Spikes. Method used: " + method) + + plt.ylabel("voltage (V)") + plt.xlabel("time (s)") + + # plt.show(block=blockME) - -def checkSpikeCutting(originalStim_list, cutStim_list, originalVoltage_list, cutVoltage_list, allindOfNonSpiking_list, config, blockME=False): - - if len(originalStim_list)!=len(cutStim_list) or \ - len(originalStim_list)!=len(originalVoltage_list) or \ - len(originalStim_list)!=len(cutVoltage_list) or \ - len(originalStim_list)!=len(allindOfNonSpiking_list): - raise Exception('lists are not the same length') - - - lengthGoingToAdd=0 - whole_ind=np.array([]) - whole_v=np.array([]) + + +def checkSpikeCutting( + originalStim_list, + cutStim_list, + originalVoltage_list, + cutVoltage_list, + allindOfNonSpiking_list, + config, + blockME=False, +): + if ( + len(originalStim_list) != len(cutStim_list) + or len(originalStim_list) != len(originalVoltage_list) + or len(originalStim_list) != len(cutVoltage_list) + or len(originalStim_list) != len(allindOfNonSpiking_list) + ): + raise Exception("lists are not the same length") + + lengthGoingToAdd = 0 + whole_ind = np.array([]) + whole_v = np.array([]) for trace, ind_array in zip(originalVoltage_list, allindOfNonSpiking_list): - ind=np.arange(0, len(trace))+lengthGoingToAdd - whole_ind=np.append(whole_ind, [ind[ii] for ii in ind_array]) - whole_v=np.append(whole_v, [trace[ii] for ii in ind_array]) - lengthGoingToAdd=lengthGoingToAdd+len(trace) - - - time=np.arange(len(np.concatenate(originalStim_list)))*config.neuron['dt'] - plt.figure(figsize=(20,10)) - - plt.subplot(1,1,1) + ind = np.arange(0, len(trace)) + lengthGoingToAdd + whole_ind = np.append(whole_ind, [ind[ii] for ii in ind_array]) + whole_v = np.append(whole_v, [trace[ii] for ii in ind_array]) + lengthGoingToAdd = lengthGoingToAdd + len(trace) + + time = np.arange(len(np.concatenate(originalStim_list))) * config.neuron["dt"] + plt.figure(figsize=(20, 10)) + + plt.subplot(1, 1, 1) plt.plot(time, np.concatenate(originalVoltage_list)) - plt.title('voltage') - - plt.plot(whole_ind*config.neuron['dt'], whole_v, '--r', lw=2) - - plt.annotate(config.cellName+': check spike cutting', xy=(.4, .975), - xycoords='figure fraction', - horizontalalignment='left', verticalalignment='top', - fontsize=20) + plt.title("voltage") + + plt.plot(whole_ind * config.neuron["dt"], whole_v, "--r", lw=2) + + plt.annotate( + config.cellName + ": check spike cutting", + xy=(0.4, 0.975), + xycoords="figure fraction", + horizontalalignment="left", + verticalalignment="top", + fontsize=20, + ) + + # plt.show(block=blockME) -def plotLineRegress1(slope, intercept, r,xlim): - y=slope*xlim+intercept - print('slope=', slope, 'intercept=', intercept, 'xlim', xlim) - plt.plot(xlim, y, '-k', lw=4, label='slope='+"%.2f"%slope+', intercept='+"%.3f"%intercept+', r='+"%.2f"%r) -def plotLineRegressRed(slope, intercept, r,xlim): - y=slope*xlim+intercept - print('slope=', slope, 'intercept=', intercept, 'xlim', xlim) - plt.plot(xlim, y, '-r', lw=4, label='slope='+"%.2f"%slope+', intercept='+"%.3f"%intercept+', r='+"%.2f"%r) +def plotLineRegress1(slope, intercept, r, xlim): + y = slope * xlim + intercept + print("slope=", slope, "intercept=", intercept, "xlim", xlim) + plt.plot( + xlim, y, "-k", lw=4, label="slope=" + "%.2f" % slope + ", intercept=" + "%.3f" % intercept + ", r=" + "%.2f" % r + ) + + +def plotLineRegressRed(slope, intercept, r, xlim): + y = slope * xlim + intercept + print("slope=", slope, "intercept=", intercept, "xlim", xlim) + plt.plot( + xlim, y, "-r", lw=4, label="slope=" + "%.2f" % slope + ", intercept=" + "%.3f" % intercept + ", r=" + "%.2f" % r + ) diff --git a/allensdk/internal/model/glif/preprocess_neuron.py b/allensdk/internal/model/glif/preprocess_neuron.py index bc407d9cca..ee4a4191b4 100644 --- a/allensdk/internal/model/glif/preprocess_neuron.py +++ b/allensdk/internal/model/glif/preprocess_neuron.py @@ -19,23 +19,27 @@ import matplotlib.pyplot as plt import allensdk.internal.model.glif.plotting as plotting -RESTING_POTENTIAL = 'slow_vm_mv' +RESTING_POTENTIAL = "slow_vm_mv" DEFAULT_DT = 5e-05 DEFAULT_CUT = 0 -DEFAULT_BESSEL = { 'N': 4, 'freq': 10000 } +DEFAULT_BESSEL = {"N": 4, "freq": 10000} MAKE_PLOT = True SHOW_PLOT = False -SAVE_FIG =True +SAVE_FIG = True SHORT_RUN = False + class MissingSpikeException(Exception): pass -RESTING_POTENTIAL = 'slow_vm_mv' -def find_first_spike_voltage(voltage, dt, ssq=False, MAKE_PLOT=False, SHOW_PLOT=False, BLOCK=False, - dv_cutoff=20.0, thresh_frac=0.05): - '''calculate voltage at threshold of first spike +RESTING_POTENTIAL = "slow_vm_mv" + + +def find_first_spike_voltage( + voltage, dt, ssq=False, MAKE_PLOT=False, SHOW_PLOT=False, BLOCK=False, dv_cutoff=20.0, thresh_frac=0.05 +): + """calculate voltage at threshold of first spike Parameters ---------- voltage: numpy array @@ -54,417 +58,537 @@ def find_first_spike_voltage(voltage, dt, ssq=False, MAKE_PLOT=False, SHOW_PLOT= specifies cut off of the derivative of the voltage thresh_frac: float variable that goes into feature extractor - + Returns ------- :float voltage of threshold of first spike - ''' - + """ + if ssq: spike_time_steps, _ = find_spikes_ssq_list([voltage], dt, dv_cutoff=dv_cutoff, thresh_frac=thresh_frac) else: spike_time_steps, _ = find_spikes_list([voltage], dt) if MAKE_PLOT: - plotting.plotSpikes([voltage], spike_time_steps, dt, blockME=False, method='dvdt_v2') + plotting.plotSpikes([voltage], spike_time_steps, dt, blockME=False, method="dvdt_v2") if SHOW_PLOT: plt.show(block=BLOCK) if len(spike_time_steps[0]) == 0: - raise MissingSpikeException('No spike detected.') - + raise MissingSpikeException("No spike detected.") + return voltage[spike_time_steps[0][0]] + def tag_plot(tag, fs=9): - plt.annotate(tag, xy=(0.98, .01), - xycoords='figure fraction', - horizontalalignment='right', - verticalalignment='bottom', - fontsize=fs) - + plt.annotate( + tag, + xy=(0.98, 0.01), + xycoords="figure fraction", + horizontalalignment="right", + verticalalignment="bottom", + fontsize=fs, + ) + + def estimate_dv_cutoff(voltage_list, dt, start_t, end_t): - v_set = [ v * 1e3 for v in voltage_list ] - t_set = [ np.arange(0, len(v)) * dt for v in voltage_list ] - - dv_cutoff, thresh_frac = ft.estimate_adjusted_detection_parameters(v_set, t_set, - start_t, end_t, - filter=None) - + v_set = [v * 1e3 for v in voltage_list] + t_set = [np.arange(0, len(v)) * dt for v in voltage_list] + + dv_cutoff, thresh_frac = ft.estimate_adjusted_detection_parameters(v_set, t_set, start_t, end_t, filter=None) + return dv_cutoff, thresh_frac -def preprocess_neuron(nwb_file, sweep_list, cell_properties=None, - dt=None, cut=None, bessel=None, save_figure_path=None): + +def preprocess_neuron( + nwb_file, sweep_list, cell_properties=None, dt=None, cut=None, bessel=None, save_figure_path=None +): if dt is None: - dt = DEFAULT_DT + dt = DEFAULT_DT if cut is None: - cut = DEFAULT_CUT + cut = DEFAULT_CUT if bessel is None: bessel = DEFAULT_BESSEL - sweep_index = { s['sweep_number']: s for s in sweep_list } + sweep_index = {s["sweep_number"]: s for s in sweep_list} noise_sweeps = fs.find_noise_sweeps(sweep_index) - noise1_sweeps = noise_sweeps['noise1'] - noise2_sweeps = noise_sweeps['noise2'] - - ssq_sweeps = fs.find_short_square_sweeps(sweep_index) - all_ssq_data = load_sweeps(nwb_file, ssq_sweeps['all'], dt, cut, bessel) - ssq_dv_cutoff, ssq_thresh_frac = estimate_dv_cutoff(all_ssq_data['voltage'], dt, - efex.SHORT_SQUARES_WINDOW_START, - efex.SHORT_SQUARES_WINDOW_END) - - ssq_triple_sweeps = ssq_sweeps['triple'] - - ramp_sweeps = fs.find_ramp_sweeps(sweep_index)['suprathreshold'] - fs.find_ramp_to_rheo_sweeps(sweep_index)['all'] - + noise1_sweeps = noise_sweeps["noise1"] + noise2_sweeps = noise_sweeps["noise2"] + + ssq_sweeps = fs.find_short_square_sweeps(sweep_index) + all_ssq_data = load_sweeps(nwb_file, ssq_sweeps["all"], dt, cut, bessel) + ssq_dv_cutoff, ssq_thresh_frac = estimate_dv_cutoff( + all_ssq_data["voltage"], dt, efex.SHORT_SQUARES_WINDOW_START, efex.SHORT_SQUARES_WINDOW_END + ) + + ssq_triple_sweeps = ssq_sweeps["triple"] + + ramp_sweeps = fs.find_ramp_sweeps(sweep_index)["suprathreshold"] + fs.find_ramp_to_rheo_sweeps(sweep_index)["all"] + noise1_data = load_sweeps(nwb_file, noise1_sweeps, dt, cut, bessel) noise2_data = load_sweeps(nwb_file, noise2_sweeps, dt, cut, bessel) - maximum_subthreshold_short_square_sweeps = ssq_sweeps['maximum_subthreshold'] + maximum_subthreshold_short_square_sweeps = ssq_sweeps["maximum_subthreshold"] load_sweeps(nwb_file, [maximum_subthreshold_short_square_sweeps[0]], dt, cut, bessel) - minimum_suprathreshold_short_square_sweeps = ssq_sweeps['minimum_suprathreshold'] - minimum_suprathreshold_short_square_data = load_sweeps(nwb_file, [minimum_suprathreshold_short_square_sweeps[0]], dt, cut, bessel) - - dt = noise1_data['dt'][0] #getting subsampled dt returned for ease of use - - subthresh_noise_current_list=[] - subthresh_noise_voltage_list=[] - noise_El_list=[] - for ss in range(0, len(noise1_data['current'])): - #--subthreshold noise has first epoch of noise with a region of no stimulation before and after (note the selection of end point is hard coded) - subthresh_noise_current_list.append(noise1_data['current'][ss][noise1_data['start_idx'][ss]:int(6./dt)]) - subthresh_noise_voltage_list.append(noise1_data['voltage'][ss][noise1_data['start_idx'][ss]:int(6./dt)]) - noise_El_list.append(sweep_index[noise1_sweeps[ss]][RESTING_POTENTIAL]*1e-3) + minimum_suprathreshold_short_square_sweeps = ssq_sweeps["minimum_suprathreshold"] + minimum_suprathreshold_short_square_data = load_sweeps( + nwb_file, [minimum_suprathreshold_short_square_sweeps[0]], dt, cut, bessel + ) + + dt = noise1_data["dt"][0] # getting subsampled dt returned for ease of use + + subthresh_noise_current_list = [] + subthresh_noise_voltage_list = [] + noise_El_list = [] + for ss in range(0, len(noise1_data["current"])): + # --subthreshold noise has first epoch of noise with a region of no stimulation before and after (note the selection of end point is hard coded) + subthresh_noise_current_list.append(noise1_data["current"][ss][noise1_data["start_idx"][ss] : int(6.0 / dt)]) + subthresh_noise_voltage_list.append(noise1_data["voltage"][ss][noise1_data["start_idx"][ss] : int(6.0 / dt)]) + noise_El_list.append(sweep_index[noise1_sweeps[ss]][RESTING_POTENTIAL] * 1e-3) # Els calculated from QC - El_noise=np.mean(noise_El_list) - El_subthreshold_blip=sweep_index[maximum_subthreshold_short_square_sweeps[0]][RESTING_POTENTIAL]*1e-3 - El_suprathreshold_blip=sweep_index[minimum_suprathreshold_short_square_sweeps[0]][RESTING_POTENTIAL]*1e-3 + El_noise = np.mean(noise_El_list) + El_subthreshold_blip = sweep_index[maximum_subthreshold_short_square_sweeps[0]][RESTING_POTENTIAL] * 1e-3 + El_suprathreshold_blip = sweep_index[minimum_suprathreshold_short_square_sweeps[0]][RESTING_POTENTIAL] * 1e-3 if len(ramp_sweeps): - logging.info('has ramp') + logging.info("has ramp") load_sweeps(nwb_file, ramp_sweeps, dt, cut, bessel) - El_ramp=sweep_index[ramp_sweeps[0]][RESTING_POTENTIAL]*1e-3 + El_ramp = sweep_index[ramp_sweeps[0]][RESTING_POTENTIAL] * 1e-3 else: - ramp_sweeps=None - El_ramp=None + ramp_sweeps = None + El_ramp = None logging.info("No ramp") if len(ssq_triple_sweeps): - logging.info('has multi ss') + logging.info("has multi ss") multi_ssq_data = load_sweeps(nwb_file, ssq_triple_sweeps, dt, cut, bessel) - multi_ssq_dv_cutoff, multi_ssq_thresh_frac = estimate_dv_cutoff(multi_ssq_data['voltage'], dt, - efex.SHORT_SQUARE_TRIPLE_WINDOW_START, - efex.SHORT_SQUARE_TRIPLE_WINDOW_END) + multi_ssq_dv_cutoff, multi_ssq_thresh_frac = estimate_dv_cutoff( + multi_ssq_data["voltage"], dt, efex.SHORT_SQUARE_TRIPLE_WINDOW_START, efex.SHORT_SQUARE_TRIPLE_WINDOW_END + ) print("*************************") print("ssq", ssq_dv_cutoff, ssq_thresh_frac) - print("triple",multi_ssq_dv_cutoff, multi_ssq_thresh_frac) + print("triple", multi_ssq_dv_cutoff, multi_ssq_thresh_frac) else: - ssq_triple_sweeps=None + ssq_triple_sweeps = None multi_ssq_data = None logging.info("No multi short square") # Needed for MLIN long_square_config = fs.find_long_square_sweeps(sweep_index) - maximum_subthreshold_long_square_sweeps = long_square_config['maximum_subthreshold'] - #TODO: Here you are loading just one sweep: probably should load all - maximum_subthreshold_long_square_data = load_sweeps(nwb_file, [maximum_subthreshold_long_square_sweeps[0]], dt, cut, bessel) - El_max_subth_long_square=sweep_index[maximum_subthreshold_long_square_sweeps[0]][RESTING_POTENTIAL]*1e-3 + maximum_subthreshold_long_square_sweeps = long_square_config["maximum_subthreshold"] + # TODO: Here you are loading just one sweep: probably should load all + maximum_subthreshold_long_square_data = load_sweeps( + nwb_file, [maximum_subthreshold_long_square_sweeps[0]], dt, cut, bessel + ) + El_max_subth_long_square = sweep_index[maximum_subthreshold_long_square_sweeps[0]][RESTING_POTENTIAL] * 1e-3 - #--------------------------------------------------------------- - #---------find spiking indicies of spikes in noise-------------- - #--------------------------------------------------------------- + # --------------------------------------------------------------- + # ---------find spiking indicies of spikes in noise-------------- + # --------------------------------------------------------------- # note that when using find_spikes_list without removing the testpulse a warning will result from calculating feature_data['base_v'] in the feature extractor (line 375) this not relavent here - noise1_ind_wo_test_pulse_removed, _ = find_spikes_list(noise1_data['voltage'], dt) - noise2_ind_wo_test_pulse_removed, _ = find_spikes_list(noise2_data['voltage'], dt) - #Put all ISI ind in a - ISI_length=np.array([]) + noise1_ind_wo_test_pulse_removed, _ = find_spikes_list(noise1_data["voltage"], dt) + noise2_ind_wo_test_pulse_removed, _ = find_spikes_list(noise2_data["voltage"], dt) + # Put all ISI ind in a + ISI_length = np.array([]) for ii in range(len(noise1_ind_wo_test_pulse_removed)): - ISI_length=np.append(ISI_length,noise1_ind_wo_test_pulse_removed[ii][1:]-noise1_ind_wo_test_pulse_removed[ii][:-1]) + ISI_length = np.append( + ISI_length, noise1_ind_wo_test_pulse_removed[ii][1:] - noise1_ind_wo_test_pulse_removed[ii][:-1] + ) for ii in range(len(noise2_ind_wo_test_pulse_removed)): - ISI_length=np.append(ISI_length,noise2_ind_wo_test_pulse_removed[ii][1:]-noise2_ind_wo_test_pulse_removed[ii][:-1]) - min_ISI_len=np.min(ISI_length) - - - #------------------------------------------------------------------------------------------------------------------- - #---------------------Compute R, C and EL via least squares------------------------------------------------- - #------------------------------------------------------------------------------------------------------------------- - - #--compute R, C, and El via least squares tested in verify_RCEl_GLM_vs_lssq_and_smooth.py - (R_test_list, C_test_list, El_test_list)=least_squares_RCEl_calc_tested(subthresh_noise_voltage_list, subthresh_noise_current_list, dt) - R_test_list_mean=np.mean(R_test_list) - C_test_list_mean=np.mean(C_test_list) - - #----------------------------------------------------------------------------------------- - #------------------------ compute spike cut length---------------------------------------- - #----------------------------------------------------------------------------------------- - - #TODO: I should disentangle this function so I can get rid of the deltaV dependency - (spike_cut_length_NODELTAV, slope_at_min_expVar_list_NODELTAV, intercept_at_min_expVar_list_NODELTAV) \ - = calc_spike_cut_and_v_reset_via_expvar_residuals(noise1_data['current'], noise1_data['voltage'], - dt, El_noise, 0, - max_spike_cut_time=min_ISI_len*dt, - MAKE_PLOT=MAKE_PLOT, - SHOW_PLOT=SHOW_PLOT, - BLOCK=False) + ISI_length = np.append( + ISI_length, noise2_ind_wo_test_pulse_removed[ii][1:] - noise2_ind_wo_test_pulse_removed[ii][:-1] + ) + min_ISI_len = np.min(ISI_length) + + # ------------------------------------------------------------------------------------------------------------------- + # ---------------------Compute R, C and EL via least squares------------------------------------------------- + # ------------------------------------------------------------------------------------------------------------------- + + # --compute R, C, and El via least squares tested in verify_RCEl_GLM_vs_lssq_and_smooth.py + (R_test_list, C_test_list, El_test_list) = least_squares_RCEl_calc_tested( + subthresh_noise_voltage_list, subthresh_noise_current_list, dt + ) + R_test_list_mean = np.mean(R_test_list) + C_test_list_mean = np.mean(C_test_list) + + # ----------------------------------------------------------------------------------------- + # ------------------------ compute spike cut length---------------------------------------- + # ----------------------------------------------------------------------------------------- + + # TODO: I should disentangle this function so I can get rid of the deltaV dependency + (spike_cut_length_NODELTAV, slope_at_min_expVar_list_NODELTAV, intercept_at_min_expVar_list_NODELTAV) = ( + calc_spike_cut_and_v_reset_via_expvar_residuals( + noise1_data["current"], + noise1_data["voltage"], + dt, + El_noise, + 0, + max_spike_cut_time=min_ISI_len * dt, + MAKE_PLOT=MAKE_PLOT, + SHOW_PLOT=SHOW_PLOT, + BLOCK=False, + ) + ) if SAVE_FIG: - tag='spikeCutting_noDeltaV_regression.png' + tag = "spikeCutting_noDeltaV_regression.png" tag_plot(tag) - plt.savefig(os.path.join(save_figure_path,tag), format='png') + plt.savefig(os.path.join(save_figure_path, tag), format="png") plt.close() - tag='spikeCutting_noDeltaV_spike_wave_form.png' + tag = "spikeCutting_noDeltaV_spike_wave_form.png" tag_plot(tag) - plt.savefig(os.path.join(save_figure_path,tag), format='png') + plt.savefig(os.path.join(save_figure_path, tag), format="png") plt.close() - logging.info('spike cut length: %d', spike_cut_length_NODELTAV) - - #----------------------------------------------------------------------------------------- - #------------------------ compute ASC amplitudes------------------------------------------ - #----------------------------------------------------------------------------------------- - - #***Hack: k's are being hard coded into this function and are not necessarily consistent with what is in the setting of AS currents!!! - k_asc_possible=np.array([3, 10., 30., 100., 300.]) - if SHORT_RUN: #THIS IS JUST FOR DEBUGGING SO THAT YOU DONT HAVE TO WAIT FOR THE ENTIRE MODULE TO RUN - (best_k_pair_fit_ascR, best_asc_amp_fit_ascR, best_R_fit_ascR, best_llh_fit_ascR)=ASGLM_pairwise(k_asc_possible, - noise1_data['current'], noise1_data['voltage'], noise1_ind_wo_test_pulse_removed, - C_test_list_mean, C_test_list_mean*R_test_list_mean, spike_cut_length_NODELTAV, dt, El_noise, - SHORT_RUN=True, MAKE_PLOT=MAKE_PLOT, SHOW_PLOT=SHOW_PLOT, BLOCK=False) - asc_amp_from_ASGLM=np.mean(best_asc_amp_fit_ascR, axis=0) - R_from_ASGLM=np.mean(best_R_fit_ascR) + logging.info("spike cut length: %d", spike_cut_length_NODELTAV) + + # ----------------------------------------------------------------------------------------- + # ------------------------ compute ASC amplitudes------------------------------------------ + # ----------------------------------------------------------------------------------------- + + # ***Hack: k's are being hard coded into this function and are not necessarily consistent with what is in the setting of AS currents!!! + k_asc_possible = np.array([3, 10.0, 30.0, 100.0, 300.0]) + if SHORT_RUN: # THIS IS JUST FOR DEBUGGING SO THAT YOU DONT HAVE TO WAIT FOR THE ENTIRE MODULE TO RUN + (best_k_pair_fit_ascR, best_asc_amp_fit_ascR, best_R_fit_ascR, best_llh_fit_ascR) = ASGLM_pairwise( + k_asc_possible, + noise1_data["current"], + noise1_data["voltage"], + noise1_ind_wo_test_pulse_removed, + C_test_list_mean, + C_test_list_mean * R_test_list_mean, + spike_cut_length_NODELTAV, + dt, + El_noise, + SHORT_RUN=True, + MAKE_PLOT=MAKE_PLOT, + SHOW_PLOT=SHOW_PLOT, + BLOCK=False, + ) + asc_amp_from_ASGLM = np.mean(best_asc_amp_fit_ascR, axis=0) + R_from_ASGLM = np.mean(best_R_fit_ascR) else: - (best_k_pair_fit_ascR, best_asc_amp_fit_ascR, best_R_fit_ascR, best_llh_fit_ascR)=ASGLM_pairwise(k_asc_possible, - noise1_data['current'], noise1_data['voltage'], noise1_ind_wo_test_pulse_removed, - C_test_list_mean, C_test_list_mean*R_test_list_mean, spike_cut_length_NODELTAV, dt, El_noise, - SHORT_RUN=False, MAKE_PLOT=MAKE_PLOT, SHOW_PLOT=SHOW_PLOT, BLOCK=False) - asc_amp_from_ASGLM=np.mean(best_asc_amp_fit_ascR, axis=0) - R_from_ASGLM=np.mean(best_R_fit_ascR) + (best_k_pair_fit_ascR, best_asc_amp_fit_ascR, best_R_fit_ascR, best_llh_fit_ascR) = ASGLM_pairwise( + k_asc_possible, + noise1_data["current"], + noise1_data["voltage"], + noise1_ind_wo_test_pulse_removed, + C_test_list_mean, + C_test_list_mean * R_test_list_mean, + spike_cut_length_NODELTAV, + dt, + El_noise, + SHORT_RUN=False, + MAKE_PLOT=MAKE_PLOT, + SHOW_PLOT=SHOW_PLOT, + BLOCK=False, + ) + asc_amp_from_ASGLM = np.mean(best_asc_amp_fit_ascR, axis=0) + R_from_ASGLM = np.mean(best_R_fit_ascR) if SAVE_FIG: - tag='GLM_fit_ascR_basis.png' + tag = "GLM_fit_ascR_basis.png" tag_plot(tag) - plt.savefig(os.path.join(save_figure_path,tag), format='png') + plt.savefig(os.path.join(save_figure_path, tag), format="png") plt.close() - tag='GLM_fit_ascR_sumASC.png' + tag = "GLM_fit_ascR_sumASC.png" tag_plot(tag) - plt.savefig(os.path.join(save_figure_path,tag), format='png') + plt.savefig(os.path.join(save_figure_path, tag), format="png") plt.close() - tag='GLM_fit_ascR_individualASC.png' + tag = "GLM_fit_ascR_individualASC.png" tag_plot(tag) - plt.savefig(os.path.join(save_figure_path,tag), format='png') + plt.savefig(os.path.join(save_figure_path, tag), format="png") plt.close() - logging.info('Output of ASC fitting GLM') - logging.info('R out_of_GLM_Rfit_Cfixed %f %s', R_from_ASGLM/1e6, "MOhms") - logging.info('ASC amplitudes at the time of cut spike %s', str(asc_amp_from_ASGLM*1e12)) + logging.info("Output of ASC fitting GLM") + logging.info("R out_of_GLM_Rfit_Cfixed %f %s", R_from_ASGLM / 1e6, "MOhms") + logging.info("ASC amplitudes at the time of cut spike %s", str(asc_amp_from_ASGLM * 1e12)) logging.info("ks used %s", str(best_k_pair_fit_ascR)) - #----------------------------------------------------------------------------------------- - #------------------------ calculate thresholds----------------------------------------- - #----------------------------------------------------------------------------------------- + # ----------------------------------------------------------------------------------------- + # ------------------------ calculate thresholds----------------------------------------- + # ----------------------------------------------------------------------------------------- - # ---extract instantaneous threshold from suprathreshold blip + # ---extract instantaneous threshold from suprathreshold blip try: - th_inf_via_Vmeasure = find_first_spike_voltage(minimum_suprathreshold_short_square_data['voltage'][0][minimum_suprathreshold_short_square_data['start_idx'][0]:], - dt, - ssq=True, - MAKE_PLOT=MAKE_PLOT, - SHOW_PLOT=SHOW_PLOT, - BLOCK=False, - dv_cutoff=ssq_dv_cutoff, - thresh_frac=ssq_thresh_frac) - th_inf_via_Vmeasure_from0=th_inf_via_Vmeasure-El_suprathreshold_blip + th_inf_via_Vmeasure = find_first_spike_voltage( + minimum_suprathreshold_short_square_data["voltage"][0][ + minimum_suprathreshold_short_square_data["start_idx"][0] : + ], + dt, + ssq=True, + MAKE_PLOT=MAKE_PLOT, + SHOW_PLOT=SHOW_PLOT, + BLOCK=False, + dv_cutoff=ssq_dv_cutoff, + thresh_frac=ssq_thresh_frac, + ) + th_inf_via_Vmeasure_from0 = th_inf_via_Vmeasure - El_suprathreshold_blip if SAVE_FIG: - tag='th_inf_from_blip.png' + tag = "th_inf_from_blip.png" tag_plot(tag) - plt.savefig(os.path.join(save_figure_path, tag), format='png') + plt.savefig(os.path.join(save_figure_path, tag), format="png") plt.close() except MissingSpikeException: - raise MissingSpikeException("The suprathreshold short square sweep must have a spike, but no spike was detected. This means that feature extraction and GLIF spike detection are inconsistent.") + raise MissingSpikeException( + "The suprathreshold short square sweep must have a spike, but no spike was detected. This means that feature extraction and GLIF spike detection are inconsistent." + ) - #----------------------------------------------------------------------------------------------------- - #-----------------find spike and voltage component of the threshold--------------------------- - #----------------------------------------------------------------------------------------------------- + # ----------------------------------------------------------------------------------------------------- + # -----------------find spike and voltage component of the threshold--------------------------- + # ----------------------------------------------------------------------------------------------------- # If a multishort square stimulus exists calculate spike component of threshold.. if multi_ssq_data: - (a_spike_component_of_threshold, b_spike_component_of_threshold, - mean_voltage_first_spike_of_blip) = calc_spike_component_of_threshold_from_multiblip(multi_ssq_data, - dt, - multi_ssq_dv_cutoff, - multi_ssq_thresh_frac, - MAKE_PLOT=MAKE_PLOT, - SHOW_PLOT=False, - BLOCK=False, - PUBLICATION_PLOT=False) - #adjust values to be after spike cutting + (a_spike_component_of_threshold, b_spike_component_of_threshold, mean_voltage_first_spike_of_blip) = ( + calc_spike_component_of_threshold_from_multiblip( + multi_ssq_data, + dt, + multi_ssq_dv_cutoff, + multi_ssq_thresh_frac, + MAKE_PLOT=MAKE_PLOT, + SHOW_PLOT=False, + BLOCK=False, + PUBLICATION_PLOT=False, + ) + ) + # adjust values to be after spike cutting if a_spike_component_of_threshold is not None and b_spike_component_of_threshold is not None: - a_spike_component_of_threshold=spike_component_of_threshold_exact(a_spike_component_of_threshold, b_spike_component_of_threshold, spike_cut_length_NODELTAV*dt) + a_spike_component_of_threshold = spike_component_of_threshold_exact( + a_spike_component_of_threshold, b_spike_component_of_threshold, spike_cut_length_NODELTAV * dt + ) if SAVE_FIG: - tag='multiblip_fit.png' + tag = "multiblip_fit.png" tag_plot(tag) - plt.savefig(os.path.join(save_figure_path, tag), format='png') - plt.close() + plt.savefig(os.path.join(save_figure_path, tag), format="png") + plt.close() - tag='multiblip_data.png' + tag = "multiblip_data.png" tag_plot(tag, fs=9) - plt.savefig(os.path.join(save_figure_path,tag), format='png') - plt.close() + plt.savefig(os.path.join(save_figure_path, tag), format="png") + plt.close() - #---calculate voltage componet of threshold + # ---calculate voltage componet of threshold if a_spike_component_of_threshold is None or b_spike_component_of_threshold is None: logging.warning("spike component of threshold could not be calculated from the multiblip data") - a_voltage_comp_of_thr_from_fitab=None - b_voltage_comp_of_thr_from_fitab=None - a_voltage_comp_of_thr_from_fitabth=None - b_voltage_comp_of_thr_from_fitabth=None - th_inf_fit_w_v_comp_of_th=None - th_inf_fit_w_v_comp_of_th_from0=None + a_voltage_comp_of_thr_from_fitab = None + b_voltage_comp_of_thr_from_fitab = None + a_voltage_comp_of_thr_from_fitabth = None + b_voltage_comp_of_thr_from_fitabth = None + th_inf_fit_w_v_comp_of_th = None + th_inf_fit_w_v_comp_of_th_from0 = None else: - #TODO:this function needs to be changed to use experimental change of reference - b_voltage_guess = 5.0 - a_voltage_guess = 0.1*b_voltage_guess - fit_ab_vcomp_from_noise = fmin(func=fit_avoltage_bvoltage, - args=(noise1_data['voltage'], - noise_El_list, - spike_cut_length_NODELTAV, - noise1_ind_wo_test_pulse_removed, #NOTE THAT IF YOU WANT TO USE THIS TO GET A VOLTAGE WITHIN THE FUNCTION YOU NEED TO SUBTRACT OFF AND INDICIE BECAUSE THIS IS THE VALUE SET TO NAN AS THE SPIKE WAS INITIATED IN THE PREVIOUS TIME STEP. - th_inf_via_Vmeasure, - dt, - a_spike_component_of_threshold, - b_spike_component_of_threshold), - x0=[a_voltage_guess,b_voltage_guess]) - - a_voltage_comp_of_thr_from_fitab=fit_ab_vcomp_from_noise[0] - b_voltage_comp_of_thr_from_fitab=fit_ab_vcomp_from_noise[1] - - logging.info("spike components %s %s", str(a_spike_component_of_threshold), str(b_spike_component_of_threshold)) - logging.info("voltage components %s %s", str(a_voltage_comp_of_thr_from_fitab), str(b_voltage_comp_of_thr_from_fitab)) - - fit_ab_vcomp_th_from_thr_from_noise = fmin(func=fit_avoltage_bvoltage_th, - args=(noise1_data['voltage'], - noise_El_list, - spike_cut_length_NODELTAV, - noise1_ind_wo_test_pulse_removed, #NOTE THAT IF YOU WANT TO USE THIS TO GET A VOLTAGE WITHIN THE FUNCTION YOU NEED TO SUBTRACT OFF AND INDICIE BECAUSE THIS IS THE VALUE SET TO NAN AS THE SPIKE WAS INITIATED IN THE PREVIOUS TIME STEP. - dt, - a_spike_component_of_threshold, - b_spike_component_of_threshold), - x0=[a_voltage_guess,b_voltage_guess, th_inf_via_Vmeasure]) - - a_voltage_comp_of_thr_from_fitabth=fit_ab_vcomp_th_from_thr_from_noise[0] - b_voltage_comp_of_thr_from_fitabth=fit_ab_vcomp_th_from_thr_from_noise[1] - th_inf_fit_w_v_comp_of_th=fit_ab_vcomp_th_from_thr_from_noise[2] - th_inf_fit_w_v_comp_of_th_from0=th_inf_fit_w_v_comp_of_th-El_noise - logging.info("spike components %s %s", str(a_spike_component_of_threshold), str(b_spike_component_of_threshold)) - logging.info("voltage components %s %s %s %s", str(a_voltage_comp_of_thr_from_fitabth), str(b_voltage_comp_of_thr_from_fitabth), 'fit threshold', str(th_inf_fit_w_v_comp_of_th)) - + # TODO:this function needs to be changed to use experimental change of reference + b_voltage_guess = 5.0 + a_voltage_guess = 0.1 * b_voltage_guess + fit_ab_vcomp_from_noise = fmin( + func=fit_avoltage_bvoltage, + args=( + noise1_data["voltage"], + noise_El_list, + spike_cut_length_NODELTAV, + noise1_ind_wo_test_pulse_removed, # NOTE THAT IF YOU WANT TO USE THIS TO GET A VOLTAGE WITHIN THE FUNCTION YOU NEED TO SUBTRACT OFF AND INDICIE BECAUSE THIS IS THE VALUE SET TO NAN AS THE SPIKE WAS INITIATED IN THE PREVIOUS TIME STEP. + th_inf_via_Vmeasure, + dt, + a_spike_component_of_threshold, + b_spike_component_of_threshold, + ), + x0=[a_voltage_guess, b_voltage_guess], + ) + + a_voltage_comp_of_thr_from_fitab = fit_ab_vcomp_from_noise[0] + b_voltage_comp_of_thr_from_fitab = fit_ab_vcomp_from_noise[1] + + logging.info( + "spike components %s %s", str(a_spike_component_of_threshold), str(b_spike_component_of_threshold) + ) + logging.info( + "voltage components %s %s", str(a_voltage_comp_of_thr_from_fitab), str(b_voltage_comp_of_thr_from_fitab) + ) + + fit_ab_vcomp_th_from_thr_from_noise = fmin( + func=fit_avoltage_bvoltage_th, + args=( + noise1_data["voltage"], + noise_El_list, + spike_cut_length_NODELTAV, + noise1_ind_wo_test_pulse_removed, # NOTE THAT IF YOU WANT TO USE THIS TO GET A VOLTAGE WITHIN THE FUNCTION YOU NEED TO SUBTRACT OFF AND INDICIE BECAUSE THIS IS THE VALUE SET TO NAN AS THE SPIKE WAS INITIATED IN THE PREVIOUS TIME STEP. + dt, + a_spike_component_of_threshold, + b_spike_component_of_threshold, + ), + x0=[a_voltage_guess, b_voltage_guess, th_inf_via_Vmeasure], + ) + + a_voltage_comp_of_thr_from_fitabth = fit_ab_vcomp_th_from_thr_from_noise[0] + b_voltage_comp_of_thr_from_fitabth = fit_ab_vcomp_th_from_thr_from_noise[1] + th_inf_fit_w_v_comp_of_th = fit_ab_vcomp_th_from_thr_from_noise[2] + th_inf_fit_w_v_comp_of_th_from0 = th_inf_fit_w_v_comp_of_th - El_noise + logging.info( + "spike components %s %s", str(a_spike_component_of_threshold), str(b_spike_component_of_threshold) + ) + logging.info( + "voltage components %s %s %s %s", + str(a_voltage_comp_of_thr_from_fitabth), + str(b_voltage_comp_of_thr_from_fitabth), + "fit threshold", + str(th_inf_fit_w_v_comp_of_th), + ) else: - a_spike_component_of_threshold=None - b_spike_component_of_threshold=None - a_voltage_comp_of_thr_from_fitab=None - b_voltage_comp_of_thr_from_fitab=None - a_voltage_comp_of_thr_from_fitabth=None - b_voltage_comp_of_thr_from_fitabth=None - th_inf_fit_w_v_comp_of_th_from0=None - th_inf_fit_w_v_comp_of_th=None - - #-------------------------------------------------------------------------- - #------------------------ MLIN calculations-------------------------------- - #-------------------------------------------------------------------------- - - #TODO: probably want to use more than just one square pulse for this distribution - STLS_voltage=maximum_subthreshold_long_square_data['voltage'][0][maximum_subthreshold_long_square_data['start_idx'][0]:] - STLS_current=maximum_subthreshold_long_square_data['current'][0][maximum_subthreshold_long_square_data['start_idx'][0]:] - (var_of_section, sv_for_expsymm, tau_from_AC)=MLIN(STLS_voltage, STLS_current, R_test_list_mean, C_test_list_mean, dt, - MAKE_PLOT=MAKE_PLOT, - SHOW_PLOT=SHOW_PLOT, - BLOCK=False, - PUBLICATION_PLOT=False) + a_spike_component_of_threshold = None + b_spike_component_of_threshold = None + a_voltage_comp_of_thr_from_fitab = None + b_voltage_comp_of_thr_from_fitab = None + a_voltage_comp_of_thr_from_fitabth = None + b_voltage_comp_of_thr_from_fitabth = None + th_inf_fit_w_v_comp_of_th_from0 = None + th_inf_fit_w_v_comp_of_th = None + + # -------------------------------------------------------------------------- + # ------------------------ MLIN calculations-------------------------------- + # -------------------------------------------------------------------------- + + # TODO: probably want to use more than just one square pulse for this distribution + STLS_voltage = maximum_subthreshold_long_square_data["voltage"][0][ + maximum_subthreshold_long_square_data["start_idx"][0] : + ] + STLS_current = maximum_subthreshold_long_square_data["current"][0][ + maximum_subthreshold_long_square_data["start_idx"][0] : + ] + (var_of_section, sv_for_expsymm, tau_from_AC) = MLIN( + STLS_voltage, + STLS_current, + R_test_list_mean, + C_test_list_mean, + dt, + MAKE_PLOT=MAKE_PLOT, + SHOW_PLOT=SHOW_PLOT, + BLOCK=False, + PUBLICATION_PLOT=False, + ) if SAVE_FIG: - tag='MLIN.png' + tag = "MLIN.png" tag_plot(tag) - plt.savefig(os.path.join(save_figure_path,tag), format='png') - plt.close() + plt.savefig(os.path.join(save_figure_path, tag), format="png") + plt.close() - #-------------------------------------------------------------------------- - #------------------------ make output dictionaries------------------------- - #-------------------------------------------------------------------------- + # -------------------------------------------------------------------------- + # ------------------------ make output dictionaries------------------------- + # -------------------------------------------------------------------------- - #TODO: find out how many are the max number of all_passing_sweeps. - El_noise_1=[None, None, None, None, None] - sweep_noise_1=[None, None, None, None, None] - spike_ind_noise_1=[None, None, None, None, None] + # TODO: find out how many are the max number of all_passing_sweeps. + El_noise_1 = [None, None, None, None, None] + sweep_noise_1 = [None, None, None, None, None] + spike_ind_noise_1 = [None, None, None, None, None] def fill_in_lists(out_list, data_list): - '''note since the input is a list shouldnt need to return anything (pass by reference)''' + """note since the input is a list shouldnt need to return anything (pass by reference)""" for ii in range(len(data_list)): - out_list[ii]=data_list[ii] + out_list[ii] = data_list[ii] + fill_in_lists(El_noise_1, noise_El_list) fill_in_lists(sweep_noise_1, noise1_sweeps) fill_in_lists(spike_ind_noise_1, noise1_ind_wo_test_pulse_removed) - #--initialize output dictionaries - for_reference_dict={} - - for_reference_dict['dt_used_for_preprocessor_calculations']=dt - #for_reference_dict['optional_methods']=self.optional_methods - for_reference_dict['sweep_properties']={'noise1': - {'1':{'El': El_noise_1[0], 'sweep_num':sweep_noise_1[0], 'spike_ind': spike_ind_noise_1[0]}, - '2':{'El': El_noise_1[1], 'sweep_num':sweep_noise_1[1], 'spike_ind': spike_ind_noise_1[1]}, - '3':{'El': El_noise_1[2], 'sweep_num':sweep_noise_1[2], 'spike_ind': spike_ind_noise_1[2]}, - '4':{'El': El_noise_1[3], 'sweep_num':sweep_noise_1[3], 'spike_ind': spike_ind_noise_1[3]}, - '5':{'El': El_noise_1[4], 'sweep_num':sweep_noise_1[4], 'spike_ind': spike_ind_noise_1[4]}}, - 'ramp': {'sweep_num':ramp_sweeps}, - 'subthreshold_short_square': {'sweep_num':maximum_subthreshold_short_square_sweeps}, - 'suprathreshold_short_square': {'R_testpulsesweep_num':minimum_suprathreshold_short_square_sweeps}, - 'max_subthresh_long_square': {'sweep_num':maximum_subthreshold_long_square_sweeps}, - 'multi_short_square': {'sweep_num':ssq_triple_sweeps}} - - for_reference_dict['El']={'El_noise': {'measured': {'mean':El_noise, 'list':noise_El_list, 'dependencies':None}}, - 'El_ramp': {'value':El_ramp, 'dependencies': None}, - 'El_subthreshold_blip': {'value':El_subthreshold_blip, 'dependencies':None}, - 'El_suprathreshold_blip': {'value':El_suprathreshold_blip, 'dependencies':None}, - 'El_max_subth_long_square': {'value':El_max_subth_long_square, 'dependencies':None}} - - for_reference_dict['resistance']={#'R_lssq_Wrest':{'mean': R_lssq_wrest_mean, 'list': R_lssq_wrest_list, 'dependencies': 'from subthreshold (no spike cutting) noise'}, - 'R_from_lims':{'value':cell_properties['ri']*1e6}, - 'R_test_list': {'mean':R_test_list_mean, 'list': R_test_list}, - 'R_fit_ASC_and_R':{'mean':R_from_ASGLM, 'list': best_R_fit_ascR}} - - for_reference_dict['capacitance']={#'C_lssq_Wrest': {'mean':C_lssq_wrest_mean, 'list':C_lssq_wrest_list, 'dependencies': 'from subthreshold (no spike cutting) noise'}, - 'C_from_lims':{'value': (cell_properties['tau']*1e-3)/(cell_properties['ri']*1e6)}, - 'C_test_list': {'mean':C_test_list_mean, 'list': C_test_list}} - - for_reference_dict['spike_cut_length']={'no deltaV shift':{'length':spike_cut_length_NODELTAV, - 'slope':slope_at_min_expVar_list_NODELTAV, - 'intercept':intercept_at_min_expVar_list_NODELTAV, - 'dependencies':None}} - - for_reference_dict['spike_cutting']={'NOdeltaV': {'cut_length':spike_cut_length_NODELTAV, 'slope':slope_at_min_expVar_list_NODELTAV, 'intercept':intercept_at_min_expVar_list_NODELTAV, 'dependencies': None}} - - for_reference_dict['asc']={'k': best_k_pair_fit_ascR, 'amp':asc_amp_from_ASGLM, 'dependencies': 'Cap and res from least squares'} - - for_reference_dict['th_inf']={'via_Vmeasure':{'value':th_inf_via_Vmeasure, 'from_zero':th_inf_via_Vmeasure_from0, 'dependencies':'measured from suprathreshold blip'}, - 'fit_with_v_comp_of_th':{'value':th_inf_fit_w_v_comp_of_th, 'from_zero':th_inf_fit_w_v_comp_of_th_from0}} - - for_reference_dict['threshold_adaptation']={'a_spike_component_of_threshold':a_spike_component_of_threshold, - 'b_spike_component_of_threshold':b_spike_component_of_threshold, - 'a_voltage_comp_of_thr_from_fitab':a_voltage_comp_of_thr_from_fitab, - 'b_voltage_comp_of_thr_from_fitab':b_voltage_comp_of_thr_from_fitab, - 'a_voltage_comp_of_thr_from_fitabth':a_voltage_comp_of_thr_from_fitabth, - 'b_voltage_comp_of_thr_from_fitabth':b_voltage_comp_of_thr_from_fitabth} - for_reference_dict['MLIN']={'var_of_section':var_of_section, - 'sv_for_expsymm':sv_for_expsymm, - 'tau_from_AC':tau_from_AC} + # --initialize output dictionaries + for_reference_dict = {} + + for_reference_dict["dt_used_for_preprocessor_calculations"] = dt + # for_reference_dict['optional_methods']=self.optional_methods + for_reference_dict["sweep_properties"] = { + "noise1": { + "1": {"El": El_noise_1[0], "sweep_num": sweep_noise_1[0], "spike_ind": spike_ind_noise_1[0]}, + "2": {"El": El_noise_1[1], "sweep_num": sweep_noise_1[1], "spike_ind": spike_ind_noise_1[1]}, + "3": {"El": El_noise_1[2], "sweep_num": sweep_noise_1[2], "spike_ind": spike_ind_noise_1[2]}, + "4": {"El": El_noise_1[3], "sweep_num": sweep_noise_1[3], "spike_ind": spike_ind_noise_1[3]}, + "5": {"El": El_noise_1[4], "sweep_num": sweep_noise_1[4], "spike_ind": spike_ind_noise_1[4]}, + }, + "ramp": {"sweep_num": ramp_sweeps}, + "subthreshold_short_square": {"sweep_num": maximum_subthreshold_short_square_sweeps}, + "suprathreshold_short_square": {"R_testpulsesweep_num": minimum_suprathreshold_short_square_sweeps}, + "max_subthresh_long_square": {"sweep_num": maximum_subthreshold_long_square_sweeps}, + "multi_short_square": {"sweep_num": ssq_triple_sweeps}, + } + + for_reference_dict["El"] = { + "El_noise": {"measured": {"mean": El_noise, "list": noise_El_list, "dependencies": None}}, + "El_ramp": {"value": El_ramp, "dependencies": None}, + "El_subthreshold_blip": {"value": El_subthreshold_blip, "dependencies": None}, + "El_suprathreshold_blip": {"value": El_suprathreshold_blip, "dependencies": None}, + "El_max_subth_long_square": {"value": El_max_subth_long_square, "dependencies": None}, + } + + for_reference_dict[ + "resistance" + ] = { #'R_lssq_Wrest':{'mean': R_lssq_wrest_mean, 'list': R_lssq_wrest_list, 'dependencies': 'from subthreshold (no spike cutting) noise'}, + "R_from_lims": {"value": cell_properties["ri"] * 1e6}, + "R_test_list": {"mean": R_test_list_mean, "list": R_test_list}, + "R_fit_ASC_and_R": {"mean": R_from_ASGLM, "list": best_R_fit_ascR}, + } + + for_reference_dict[ + "capacitance" + ] = { #'C_lssq_Wrest': {'mean':C_lssq_wrest_mean, 'list':C_lssq_wrest_list, 'dependencies': 'from subthreshold (no spike cutting) noise'}, + "C_from_lims": {"value": (cell_properties["tau"] * 1e-3) / (cell_properties["ri"] * 1e6)}, + "C_test_list": {"mean": C_test_list_mean, "list": C_test_list}, + } + + for_reference_dict["spike_cut_length"] = { + "no deltaV shift": { + "length": spike_cut_length_NODELTAV, + "slope": slope_at_min_expVar_list_NODELTAV, + "intercept": intercept_at_min_expVar_list_NODELTAV, + "dependencies": None, + } + } + + for_reference_dict["spike_cutting"] = { + "NOdeltaV": { + "cut_length": spike_cut_length_NODELTAV, + "slope": slope_at_min_expVar_list_NODELTAV, + "intercept": intercept_at_min_expVar_list_NODELTAV, + "dependencies": None, + } + } + + for_reference_dict["asc"] = { + "k": best_k_pair_fit_ascR, + "amp": asc_amp_from_ASGLM, + "dependencies": "Cap and res from least squares", + } + + for_reference_dict["th_inf"] = { + "via_Vmeasure": { + "value": th_inf_via_Vmeasure, + "from_zero": th_inf_via_Vmeasure_from0, + "dependencies": "measured from suprathreshold blip", + }, + "fit_with_v_comp_of_th": {"value": th_inf_fit_w_v_comp_of_th, "from_zero": th_inf_fit_w_v_comp_of_th_from0}, + } + + for_reference_dict["threshold_adaptation"] = { + "a_spike_component_of_threshold": a_spike_component_of_threshold, + "b_spike_component_of_threshold": b_spike_component_of_threshold, + "a_voltage_comp_of_thr_from_fitab": a_voltage_comp_of_thr_from_fitab, + "b_voltage_comp_of_thr_from_fitab": b_voltage_comp_of_thr_from_fitab, + "a_voltage_comp_of_thr_from_fitabth": a_voltage_comp_of_thr_from_fitabth, + "b_voltage_comp_of_thr_from_fitabth": b_voltage_comp_of_thr_from_fitabth, + } + for_reference_dict["MLIN"] = { + "var_of_section": var_of_section, + "sv_for_expsymm": sv_for_expsymm, + "tau_from_AC": tau_from_AC, + } logging.info("finished") return for_reference_dict + def main(): parser = argparse.ArgumentParser() parser.add_argument("nwb_file") @@ -477,11 +601,11 @@ def main(): args = parser.parse_args() sweep_list = ju.read(args.sweep_list_file) - + values = preprocess_neuron(args.nwb_file, sweep_list) ju.write(args.output_json, values) - + if __name__ == "__main__": main() diff --git a/allensdk/internal/model/glif/rc.py b/allensdk/internal/model/glif/rc.py index e6d4eacad3..c10d67f950 100644 --- a/allensdk/internal/model/glif/rc.py +++ b/allensdk/internal/model/glif/rc.py @@ -1,47 +1,44 @@ import numpy as np - def least_squares_RCEl_calc_tested(voltage_list, current_list, dt): - '''Calculate resistance, capacitance and resting potential by performing + """Calculate resistance, capacitance and resting potential by performing least squares on current and voltage. - + Parameters ---------- - voltage_list: list of arrays + voltage_list: list of arrays voltage responses for several sweep repeats - current_list: list of arrays + current_list: list of arrays current injections for several sweep repeats dt: float time step size in voltage and current traces - + Returns ------- r_list: list of floats each value corresponds to the resistance of a sweep - c_list: list of floats + c_list: list of floats each value corresponds to the capacitance of a sweep - el_list: list of floats + el_list: list of floats each value corresponds to the resting potential of a sweep - ''' + """ - r_list=[] - c_list=[] - el_list=[] + r_list = [] + c_list = [] + el_list = [] for voltage, current in zip(voltage_list, current_list): - matrix=np.ones((len(voltage)-1, 3)) - matrix[:,0]=voltage[0:len(voltage)-1] - matrix[:,1]=current[0:len(current)-1] - lsq_non_der_raw=np.linalg.lstsq(matrix, voltage[1:])[0] -# (r_lsq_non_der_raw, c_lsq_non_der_raw, El_lsq_non_der_raw)=RCEL_from_standard_space(lsq_non_der_raw) - - c_lsq_non_der_raw=dt/lsq_non_der_raw[1] - r_lsq_non_der_raw=-lsq_non_der_raw[1]/(lsq_non_der_raw[0]-1) - El_lsq_non_der_raw=-lsq_non_der_raw[2]/(lsq_non_der_raw[0]-1) + matrix = np.ones((len(voltage) - 1, 3)) + matrix[:, 0] = voltage[0 : len(voltage) - 1] + matrix[:, 1] = current[0 : len(current) - 1] + lsq_non_der_raw = np.linalg.lstsq(matrix, voltage[1:])[0] + # (r_lsq_non_der_raw, c_lsq_non_der_raw, El_lsq_non_der_raw)=RCEL_from_standard_space(lsq_non_der_raw) + + c_lsq_non_der_raw = dt / lsq_non_der_raw[1] + r_lsq_non_der_raw = -lsq_non_der_raw[1] / (lsq_non_der_raw[0] - 1) + El_lsq_non_der_raw = -lsq_non_der_raw[2] / (lsq_non_der_raw[0] - 1) r_list.append(r_lsq_non_der_raw) c_list.append(c_lsq_non_der_raw) el_list.append(El_lsq_non_der_raw) - - return r_list, c_list, el_list - + return r_list, c_list, el_list diff --git a/allensdk/internal/model/glif/spike_cutting.py b/allensdk/internal/model/glif/spike_cutting.py index edfe0cf2dc..5bb0e2e380 100644 --- a/allensdk/internal/model/glif/spike_cutting.py +++ b/allensdk/internal/model/glif/spike_cutting.py @@ -2,215 +2,261 @@ from scipy import stats from allensdk.internal.model.glif.find_spikes import align_and_cut_spikes, ALIGN_CUT_WINDOW -import matplotlib.pyplot as plt - -def calc_spike_cut_and_v_reset_via_expvar_residuals(all_current_list, - all_voltage_list, dt, El_reference, deltaV, - max_spike_cut_time=False, - MAKE_PLOT=False, SHOW_PLOT=False, PUBLICATION_PLOT=False, BLOCK=False): - '''This function calculates where the spike should be cut based on explained variance. - The goal is to find a model where the voltage after a spike maximally explains the - voltage before a spike. This will also specify the voltage reset rule - inputs: - spike_determination_method: string specifing the method used to find threshold - all_current_list: list of current (list of current traces injected into neuron) - all_voltage_list: list of voltages (list of voltage trace) - The change is that if the slope is greater than one or intercept is greater than zero it forces it. - Regardless of required force the residuals are used. - ''' - - #--find the region of the spike needed for calculation of explained variance - (temp_v_spike_shape_list, all_i_spike_shape_list, all_thresholdInd, waveIndOfFirstSpikes, spikeFromWhichSweep) \ - = align_and_cut_spikes(all_voltage_list, all_current_list, dt) - - #--At this point it is unclear how this calculation should be done. - #--the slope should be fine no matter what, but the intercept dependency - #--will depend on the El, and deltaV - - #--change reference - all_v_spike_shape_list=[shape-El_reference-deltaV for shape in temp_v_spike_shape_list] +import matplotlib.pyplot as plt + + +def calc_spike_cut_and_v_reset_via_expvar_residuals( + all_current_list, + all_voltage_list, + dt, + El_reference, + deltaV, + max_spike_cut_time=False, + MAKE_PLOT=False, + SHOW_PLOT=False, + PUBLICATION_PLOT=False, + BLOCK=False, +): + """This function calculates where the spike should be cut based on explained variance. + The goal is to find a model where the voltage after a spike maximally explains the + voltage before a spike. This will also specify the voltage reset rule + inputs: + spike_determination_method: string specifing the method used to find threshold + all_current_list: list of current (list of current traces injected into neuron) + all_voltage_list: list of voltages (list of voltage trace) + The change is that if the slope is greater than one or intercept is greater than zero it forces it. + Regardless of required force the residuals are used. + """ + + # --find the region of the spike needed for calculation of explained variance + (temp_v_spike_shape_list, all_i_spike_shape_list, all_thresholdInd, waveIndOfFirstSpikes, spikeFromWhichSweep) = ( + align_and_cut_spikes(all_voltage_list, all_current_list, dt) + ) + + # --At this point it is unclear how this calculation should be done. + # --the slope should be fine no matter what, but the intercept dependency + # --will depend on the El, and deltaV + + # --change reference + all_v_spike_shape_list = [shape - El_reference - deltaV for shape in temp_v_spike_shape_list] # --setting limits to find explained variance - if max_spike_cut_time and max_spike_cut_time < .010: - expVarIndRangeAfterSpike = range(int(.001 / dt), int(max_spike_cut_time / dt)) #NOTE: THIS IS USED IN REFERENCE TO SPIKE TIME + if max_spike_cut_time and max_spike_cut_time < 0.010: + expVarIndRangeAfterSpike = range( + int(0.001 / dt), int(max_spike_cut_time / dt) + ) # NOTE: THIS IS USED IN REFERENCE TO SPIKE TIME else: - expVarIndRangeAfterSpike = range(int(.001 / dt), int(.010 / dt)) #NOTE: THIS IS USED IN REFERENCE TO SPIKE TIME + expVarIndRangeAfterSpike = range( + int(0.001 / dt), int(0.010 / dt) + ) # NOTE: THIS IS USED IN REFERENCE TO SPIKE TIME vectorIndex_of_max_explained_var = expVarIndRangeAfterSpike[0] # this is just here for the title of the plot list_of_endPointArrays = [] # this should end up a list of numpy arrays where each numpy array contains the indices of the v_spike_shape_list that are a certain time after the threshold for ii in expVarIndRangeAfterSpike: list_of_endPointArrays.append(np.array(all_thresholdInd) + ii) - - def line_force_slope_to_1(x,c): - return x+c - - def line_force_int_to_0(x, m): #TODO: CHANGE THIS TO REST TOD DISCONNECT EVERYTHING. - return m*x - -# HERE YOU GET THE SLOPE AND INTERCEPT AT EACH POINT + + def line_force_slope_to_1(x, c): + return x + c + + def line_force_int_to_0(x, m): # TODO: CHANGE THIS TO REST TOD DISCONNECT EVERYTHING. + return m * x + + # HERE YOU GET THE SLOPE AND INTERCEPT AT EACH POINT linRegress_error_4_each_time_end = [] - slope_at_each_time_end=[] - intercept_at_each_time_end=[] + slope_at_each_time_end = [] + intercept_at_each_time_end = [] varData_4_each_time_end = [] varModel_4_each_time_end = [] chi2 = [] - sum_residuals_4_each_time_end=[] + sum_residuals_4_each_time_end = [] xdata = np.array([v[all_thresholdInd[ii]] for ii, v in enumerate(all_v_spike_shape_list)]) - for jj, vectorOfIndAcrossWaves in enumerate(list_of_endPointArrays): # these indices should be in terms of the spike waveforms -# print('jj', jj) + for jj, vectorOfIndAcrossWaves in enumerate( + list_of_endPointArrays + ): # these indices should be in terms of the spike waveforms + # print('jj', jj) # TODO: Teeter get rid of the nonblipness - v_at_specificEndPoint = [all_v_spike_shape_list[ii][index] for ii, index in enumerate(vectorOfIndAcrossWaves)] # this is calculating variance at certain time points - # --currently the model of voltage reset is a linear regression between voltage before the spike and the voltage after the spike but it could be more complicated (for example as a function of current) + v_at_specificEndPoint = [ + all_v_spike_shape_list[ii][index] for ii, index in enumerate(vectorOfIndAcrossWaves) + ] # this is calculating variance at certain time points + # --currently the model of voltage reset is a linear regression between voltage before the spike and the voltage after the spike but it could be more complicated (for example as a function of current) ydata = np.array(v_at_specificEndPoint) # this is the voltage at the specified end point slope, intercept, r_value, p_value, std_err = stats.linregress(xdata, ydata) -# print(slope, intercept, r_value, p_value, std_err) - -# if slope>1.0: -# logging.warning('linear regression slope is bigger than one: forcing slope to 1 and refitting intercept.') -# slope=1.0 -# (intercept, nothing)=curve_fit(line_force_slope_to_1, xdata, ydata) -# #print("NEW INTERCEPT:", intercept) -# if intercept>0.0: -# #warnings.warn('/t ... and intercept is bigger than zero: forcing intercept to 0') -# intercept=0.0 -# -# if intercept>0.0: -# logging.warning('Intercept is bigger than zero: forcing intercept to 0 and refitting slope.') -# intercept=0.0 -# (slope, nothing)=curve_fit(line_force_int_to_0,xdata, ydata) -# #print("NEW SLOPE: ", slope) -# if slope>1.0: -# logging.warning('/t ... and linear regression slope is bigger than one: forcing slope to 1.') -# slope=1.0 - + # print(slope, intercept, r_value, p_value, std_err) + + # if slope>1.0: + # logging.warning('linear regression slope is bigger than one: forcing slope to 1 and refitting intercept.') + # slope=1.0 + # (intercept, nothing)=curve_fit(line_force_slope_to_1, xdata, ydata) + # #print("NEW INTERCEPT:", intercept) + # if intercept>0.0: + # #warnings.warn('/t ... and intercept is bigger than zero: forcing intercept to 0') + # intercept=0.0 + # + # if intercept>0.0: + # logging.warning('Intercept is bigger than zero: forcing intercept to 0 and refitting slope.') + # intercept=0.0 + # (slope, nothing)=curve_fit(line_force_int_to_0,xdata, ydata) + # #print("NEW SLOPE: ", slope) + # if slope>1.0: + # logging.warning('/t ... and linear regression slope is bigger than one: forcing slope to 1.') + # slope=1.0 + slope_at_each_time_end.append(slope) intercept_at_each_time_end.append(intercept) ymodel = slope * xdata + intercept residuals = ydata - ymodel - sum_residuals=sum(abs(residuals)) + sum_residuals = sum(abs(residuals)) sum_residuals_4_each_time_end.append(sum_residuals) chi2.append(np.var(residuals)) # how well the model describes the data linRegress_error_4_each_time_end.append(std_err) varData_4_each_time_end.append(np.var(v_at_specificEndPoint)) varModel_4_each_time_end.append(np.var(ymodel)) - + # --these will line up with how many arrays there are in the list vectorIndex_of_min_sum_residuals = sum_residuals_4_each_time_end.index(min(sum_residuals_4_each_time_end)) - #----NOTE THIS ISNT ACTUALLY CALCULATING EXPLAINED VARIANCE!!!!!!!!!!!!!!!!!! - vectorIndex_of_max_explained_var=vectorIndex_of_min_sum_residuals - - + # ----NOTE THIS ISNT ACTUALLY CALCULATING EXPLAINED VARIANCE!!!!!!!!!!!!!!!!!! + vectorIndex_of_max_explained_var = vectorIndex_of_min_sum_residuals + all_v_spike_init_list = [v[all_thresholdInd[ii]] for ii, v in enumerate(all_v_spike_shape_list)] -# USE THIS WHEN MUTIPLE VECTORS all_v_at_min_expVar_list=[v[list_of_endPointArrays[vectorIndex_of_max_explained_var][ii]] for ii, v in enumerate(all_v_spike_shape_list)] - all_v_at_min_expVar_list = [v[list_of_endPointArrays[vectorIndex_of_max_explained_var][ii]] for ii, v in enumerate(all_v_spike_shape_list)] - time_at_minExpVar=list_of_endPointArrays[vectorIndex_of_max_explained_var]*dt + # USE THIS WHEN MUTIPLE VECTORS all_v_at_min_expVar_list=[v[list_of_endPointArrays[vectorIndex_of_max_explained_var][ii]] for ii, v in enumerate(all_v_spike_shape_list)] + all_v_at_min_expVar_list = [ + v[list_of_endPointArrays[vectorIndex_of_max_explained_var][ii]] for ii, v in enumerate(all_v_spike_shape_list) + ] + time_at_minExpVar = list_of_endPointArrays[vectorIndex_of_max_explained_var] * dt if MAKE_PLOT: truncatedTime = np.arange(0, len(all_v_spike_shape_list[0])) * dt plt.figure(figsize=(20, 10)) for ii in range(0, len(all_v_spike_shape_list)): - plt.subplot(2,1,1) + plt.subplot(2, 1, 1) plt.plot(truncatedTime, temp_v_spike_shape_list[ii]) # plt.plot(truncatedTime[aligned_peakInd[ii]],spikewave[aligned_peakInd[ii]], '.k' - plt.plot(truncatedTime[all_thresholdInd[ii]], temp_v_spike_shape_list[ii][all_thresholdInd[ii]], '*k') - plt.title('Non adusted spikes') - - plt.subplot(2,1,2) + plt.plot(truncatedTime[all_thresholdInd[ii]], temp_v_spike_shape_list[ii][all_thresholdInd[ii]], "*k") + plt.title("Non adusted spikes") + + plt.subplot(2, 1, 2) plt.plot(truncatedTime, all_v_spike_shape_list[ii]) - plt.plot(time_at_minExpVar, all_v_at_min_expVar_list, '*k') - plt.xlabel('time (s)', fontsize=20) - plt.ylabel('voltage (mV)', fontsize=20) - plt.title("Adjusted spikes (RP=%.3g, deltaV=%.3g)" % (El_reference,deltaV)) - + plt.plot(time_at_minExpVar, all_v_at_min_expVar_list, "*k") + plt.xlabel("time (s)", fontsize=20) + plt.ylabel("voltage (mV)", fontsize=20) + plt.title("Adjusted spikes (RP=%.3g, deltaV=%.3g)" % (El_reference, deltaV)) + if PUBLICATION_PLOT: truncatedTime = np.arange(0, len(all_v_spike_shape_list[0])) * dt plt.figure(figsize=(20, 5)) for ii in range(0, len(all_v_spike_shape_list)): - plt.plot(truncatedTime*1000, temp_v_spike_shape_list[ii]*1e3, lw=2) + plt.plot(truncatedTime * 1000, temp_v_spike_shape_list[ii] * 1e3, lw=2) # plt.plot(truncatedTime[aligned_peakInd[ii]],spikewave[aligned_peakInd[ii]], '.k' - plt.plot(truncatedTime[all_thresholdInd[ii]]*1000, temp_v_spike_shape_list[ii][all_thresholdInd[ii]]*1e3, '.k', ms=10) -# plt.title('Spike Cutting', fontsize=20) -# plt.subplot(2,1,2) -# plt.plot(truncatedTime, all_v_spike_shape_list[ii]) - plt.plot(time_at_minExpVar*1000, (np.array(all_v_at_min_expVar_list)+El_reference+deltaV)*1.e3, '.k', ms=10) - plt.xlabel('Time (ms)', fontsize=16) - plt.ylabel('Voltage (mV)', fontsize=16) - plt.xlim([0,12]) + plt.plot( + truncatedTime[all_thresholdInd[ii]] * 1000, + temp_v_spike_shape_list[ii][all_thresholdInd[ii]] * 1e3, + ".k", + ms=10, + ) + # plt.title('Spike Cutting', fontsize=20) + # plt.subplot(2,1,2) + # plt.plot(truncatedTime, all_v_spike_shape_list[ii]) + plt.plot( + time_at_minExpVar * 1000, + (np.array(all_v_at_min_expVar_list) + El_reference + deltaV) * 1.0e3, + ".k", + ms=10, + ) + plt.xlabel("Time (ms)", fontsize=16) + plt.ylabel("Voltage (mV)", fontsize=16) + plt.xlim([0, 12]) plt.tight_layout() -# plt.title("Adjusted spikes (RP=%.3g, deltaV=%.3g)" % (El_reference,deltaV)) - - if SHOW_PLOT: - plt.show(block=BLOCK) + # plt.title("Adjusted spikes (RP=%.3g, deltaV=%.3g)" % (El_reference,deltaV)) + if SHOW_PLOT: + plt.show(block=BLOCK) -# indNotExcluded_In_regress=list(np.setdiff1d(np.array([theInd for theInd in spikeIndDict['nonblip']]), np.array(waveIndOfFirstSpikes))) -# something is wrong with all_v_at_min_expVar_list--look at the difference between starting at .003 and .005 after thresh + # indNotExcluded_In_regress=list(np.setdiff1d(np.array([theInd for theInd in spikeIndDict['nonblip']]), np.array(waveIndOfFirstSpikes))) + # something is wrong with all_v_at_min_expVar_list--look at the difference between starting at .003 and .005 after thresh if MAKE_PLOT: plt.figure(figsize=(20, 10)) - plt.plot(all_v_spike_init_list, all_v_at_min_expVar_list, 'b.', ms=16, label='noise') # list of voltage traces for blip - plt.xlabel('voltage at spike initiation (V)', fontsize=20) - plt.ylabel('voltage after spike (V)', fontsize=20) -# plt.title(cellTitle, fontsize=20) + plt.plot( + all_v_spike_init_list, all_v_at_min_expVar_list, "b.", ms=16, label="noise" + ) # list of voltage traces for blip + plt.xlabel("voltage at spike initiation (V)", fontsize=20) + plt.ylabel("voltage after spike (V)", fontsize=20) + # plt.title(cellTitle, fontsize=20) + + ( + slope_at_min_expVar_list, + intercept_at_min_expVar_list, + r_value_at_min_expVar_list, + p_value_at_min_expVar_list, + std_err_at_min_expVar_list, + ) = stats.linregress(np.array(all_v_spike_init_list), np.array(all_v_at_min_expVar_list)) - slope_at_min_expVar_list, intercept_at_min_expVar_list, r_value_at_min_expVar_list, p_value_at_min_expVar_list, std_err_at_min_expVar_list = \ - stats.linregress(np.array(all_v_spike_init_list), np.array(all_v_at_min_expVar_list)) + print("mean of voltage before spike", np.mean(all_v_spike_init_list)) + print("mean of voltage after spike", np.mean(all_v_at_min_expVar_list)) - print('mean of voltage before spike', np.mean(all_v_spike_init_list)) - print('mean of voltage after spike', np.mean(all_v_at_min_expVar_list)) - - - spike_cut_length= (list_of_endPointArrays[vectorIndex_of_max_explained_var][0])-int(ALIGN_CUT_WINDOW[0]/dt) #note this is dangerous if they arent' all at the same ind + spike_cut_length = (list_of_endPointArrays[vectorIndex_of_max_explained_var][0]) - int( + ALIGN_CUT_WINDOW[0] / dt + ) # note this is dangerous if they arent' all at the same ind - if MAKE_PLOT: xlim = np.array([min(all_v_spike_init_list), max(all_v_spike_init_list)]) plotLineRegress1(slope_at_min_expVar_list, intercept_at_min_expVar_list, r_value_at_min_expVar_list, xlim) plt.legend(loc=2, fontsize=20) - if MAKE_PLOT: xlim = np.array([min(all_v_spike_init_list), max(all_v_spike_init_list)]) - plotLineRegressRed(slope_at_each_time_end[vectorIndex_of_max_explained_var], intercept_at_each_time_end[vectorIndex_of_max_explained_var], np.nan, xlim) + plotLineRegressRed( + slope_at_each_time_end[vectorIndex_of_max_explained_var], + intercept_at_each_time_end[vectorIndex_of_max_explained_var], + np.nan, + xlim, + ) plt.legend(loc=2, fontsize=20) if SHOW_PLOT: - plt.show(block=BLOCK) + plt.show(block=BLOCK) if PUBLICATION_PLOT: - plt.figure(figsize=(7, 5)) - plt.plot(np.array(all_v_spike_init_list)*1e3, np.array(all_v_at_min_expVar_list)*1e3, 'b.', ms=16) # list of voltage traces for blip - plt.xlabel('Voltage at spike initiation (mV)', fontsize=16) - plt.ylabel('Voltage after spike (mV)', fontsize=16) -# plt.title('Voltage reset rules', fontsize=20) + plt.plot( + np.array(all_v_spike_init_list) * 1e3, np.array(all_v_at_min_expVar_list) * 1e3, "b.", ms=16 + ) # list of voltage traces for blip + plt.xlabel("Voltage at spike initiation (mV)", fontsize=16) + plt.ylabel("Voltage after spike (mV)", fontsize=16) + # plt.title('Voltage reset rules', fontsize=20) xlim = np.array([min(all_v_spike_init_list), max(all_v_spike_init_list)]) - - def plot_hack(slope, intercept, r,xlim): - y=slope*xlim+intercept - plt.plot(xlim, y, '-k', lw=4)# label='slope='+"%.2f"%slope+', intercept='+"%.3f"%intercept) - - plot_hack(slope_at_min_expVar_list, intercept_at_min_expVar_list*1e3, r_value_at_min_expVar_list, xlim*1e3) + + def plot_hack(slope, intercept, r, xlim): + y = slope * xlim + intercept + plt.plot(xlim, y, "-k", lw=4) # label='slope='+"%.2f"%slope+', intercept='+"%.3f"%intercept) + + plot_hack(slope_at_min_expVar_list, intercept_at_min_expVar_list * 1e3, r_value_at_min_expVar_list, xlim * 1e3) plt.legend(loc=2, fontsize=16) plt.tight_layout() plt.show(block=BLOCK) - #TODO: Corinne look to see if these were calculated with zeroed out El if not does is matter? + # TODO: Corinne look to see if these were calculated with zeroed out El if not does is matter? if isinstance(slope_at_min_expVar_list, np.ndarray): - slope_at_min_expVar_list=float(slope_at_min_expVar_list[0]) + slope_at_min_expVar_list = float(slope_at_min_expVar_list[0]) if isinstance(intercept_at_min_expVar_list, np.ndarray): - intercept_at_min_expVar_list=float(intercept_at_min_expVar_list[0]) + intercept_at_min_expVar_list = float(intercept_at_min_expVar_list[0]) - if type(intercept_at_min_expVar_list)==list or type(intercept_at_min_expVar_list)==np.ndarray: - intercept_at_min_expVar_list=intercept_at_min_expVar_list[0] + if type(intercept_at_min_expVar_list) == list or type(intercept_at_min_expVar_list) == np.ndarray: + intercept_at_min_expVar_list = intercept_at_min_expVar_list[0] return spike_cut_length, slope_at_min_expVar_list, intercept_at_min_expVar_list -def plotLineRegress1(slope, intercept, r,xlim): - y=slope*xlim+intercept - print('slope=', slope, 'intercept=', intercept, 'xlim', xlim) - plt.plot(xlim, y, '-k', lw=4, label='slope='+"%.2f"%slope+', intercept='+"%.3f"%intercept+', r='+"%.2f"%r) -def plotLineRegressRed(slope, intercept, r,xlim): - y=slope*xlim+intercept - print('slope=', slope, 'intercept=', intercept, 'xlim', xlim) - plt.plot(xlim, y, '-r', lw=4, label='slope='+"%.2f"%slope+', intercept='+"%.3f"%intercept+', r='+"%.2f"%r) +def plotLineRegress1(slope, intercept, r, xlim): + y = slope * xlim + intercept + print("slope=", slope, "intercept=", intercept, "xlim", xlim) + plt.plot( + xlim, y, "-k", lw=4, label="slope=" + "%.2f" % slope + ", intercept=" + "%.3f" % intercept + ", r=" + "%.2f" % r + ) + + +def plotLineRegressRed(slope, intercept, r, xlim): + y = slope * xlim + intercept + print("slope=", slope, "intercept=", intercept, "xlim", xlim) + plt.plot( + xlim, y, "-r", lw=4, label="slope=" + "%.2f" % slope + ", intercept=" + "%.3f" % intercept + ", r=" + "%.2f" % r + ) diff --git a/allensdk/internal/model/glif/threshold_adaptation.py b/allensdk/internal/model/glif/threshold_adaptation.py index 02c3a3533c..07014fca6e 100644 --- a/allensdk/internal/model/glif/threshold_adaptation.py +++ b/allensdk/internal/model/glif/threshold_adaptation.py @@ -2,219 +2,241 @@ from scipy.optimize import curve_fit import matplotlib.pyplot as plt import logging + THRESH_PCT_MULTIBLIP = 0.05 from allensdk.model.glif.glif_neuron_methods import spike_component_of_threshold_exact from allensdk.internal.model.glif.find_spikes import find_spikes_ssq_list -def calc_spike_component_of_threshold_from_multiblip(multi_SS, dt, dv_cutoff, thresh_frac, - MAKE_PLOT=False, SHOW_PLOT=False, BLOCK=False, PUBLICATION_PLOT=False): - '''Calculate the spike components of the threshold by fitting a decaying exponential function to data to threshold versus time - since last spike in the multiblip data. The exponential is forced to decay to the local th_inf (calculated as the mean all of the + +def calc_spike_component_of_threshold_from_multiblip( + multi_SS, dt, dv_cutoff, thresh_frac, MAKE_PLOT=False, SHOW_PLOT=False, BLOCK=False, PUBLICATION_PLOT=False +): + """Calculate the spike components of the threshold by fitting a decaying exponential function to data to threshold versus time + since last spike in the multiblip data. The exponential is forced to decay to the local th_inf (calculated as the mean all of the threshold values of the first spikes in each individual triblip stimulus). For each multiblip stimulus in a stimulus set if there - is more than one spike the difference in voltages from the first and second spike are plotted versus the separation in time. Note that - this algorithm should only be implemented on multiblips sweeps where the neuron spike on the first and second blip. Since there is - no easy way to do this, this erroneous data should not be provided to this algorithm (i.e is should be visually checked and eliminated - the preprocessor should hold back this data manually for now.) - - #TODO: check to see if this is still true. Notes: The standard SDK spike detection algorithm does not work with the multiblip stimulus + is more than one spike the difference in voltages from the first and second spike are plotted versus the separation in time. Note that + this algorithm should only be implemented on multiblips sweeps where the neuron spike on the first and second blip. Since there is + no easy way to do this, this erroneous data should not be provided to this algorithm (i.e is should be visually checked and eliminated + the preprocessor should hold back this data manually for now.) + + #TODO: check to see if this is still true. Notes: The standard SDK spike detection algorithm does not work with the multiblip stimulus due to artifacts when the stimulus turns on and off. Please see the find_multiblip_spikes module for more information. - + Input: - + multi_SS: dictionary contains multiblip information such as current and stimulus dt: float time step in seconds - + Returns: - + const_to_add_to_thresh_for_reset: float amplitude of the exponential fit otherwise known as a_spike. Note that this is without any spike cutting decay_const: float - decay constant of exponential. Note the function fit is a negative exponential which will mean this value will + decay constant of exponential. Note the function fit is a negative exponential which will mean this value will either have to be negated when it is used or the functions used will have to have to include the negative. thresh_inf: float - - ''' - multi_SS_v=multi_SS['voltage'] - multi_SS_i=multi_SS['current'] - # --get indicies of spikes - spike_ind, _=find_spikes_ssq_list(multi_SS_v, dt, dv_cutoff, thresh_frac) -# spike_ind=find_multiblip_spikes(multi_SS_i, multi_SS_v, dt) can depricate find_multiblip_spikes + """ + multi_SS_v = multi_SS["voltage"] + multi_SS_i = multi_SS["current"] + # --get indicies of spikes + spike_ind, _ = find_spikes_ssq_list(multi_SS_v, dt, dv_cutoff, thresh_frac) + # spike_ind=find_multiblip_spikes(multi_SS_i, multi_SS_v, dt) can depricate find_multiblip_spikes # eliminate spurious spikes that may exist - spike_lt=[np.where(SI0: - logging.warning('there is a spike before the stimulus in the multiblip') - spike_ind=[np.delete(SI,ind)for SI, ind in zip(spike_ind, spike_lt)] - spike_gt=[np.where(SI>int(3.0/dt))[0] for SI in spike_ind] - if len(np.concatenate(spike_gt))>0: - logging.warning('there is a spike after the stimulus in the multiblip') - spike_ind=[np.delete(SI,ind)for SI, ind in zip(spike_ind, spike_gt)] - + spike_lt = [np.where(SI < int(2.0 / dt))[0] for SI in spike_ind] + if len(np.concatenate(spike_lt)) > 0: + logging.warning("there is a spike before the stimulus in the multiblip") + spike_ind = [np.delete(SI, ind) for SI, ind in zip(spike_ind, spike_lt)] + spike_gt = [np.where(SI > int(3.0 / dt))[0] for SI in spike_ind] + if len(np.concatenate(spike_gt)) > 0: + logging.warning("there is a spike after the stimulus in the multiblip") + spike_ind = [np.delete(SI, ind) for SI, ind in zip(spike_ind, spike_gt)] + # intialize output lists - time_previous_spike=[] - threshold=[] - thresh_first_spike=[] #will set constant to this + time_previous_spike = [] + threshold = [] + thresh_first_spike = [] # will set constant to this if MAKE_PLOT: - plt.figure(figsize=(20,24)) - + plt.figure(figsize=(20, 24)) + # Loop though each tri blip stimulus in muliblip stimulus for k in range(0, len(multi_SS_v)): - thresh=[multi_SS_v[k][j] for j in spike_ind[k]] # voltage at all spikes in a single tri blip - if thresh!=[] and len(thresh)>1:# there needs to be more than one spike so that we can find the time difference - thresh_first_spike.append(thresh[0]) #Note that this finds the first spike (it might not be at the first stimulus blip) + thresh = [multi_SS_v[k][j] for j in spike_ind[k]] # voltage at all spikes in a single tri blip + if ( + thresh != [] and len(thresh) > 1 + ): # there needs to be more than one spike so that we can find the time difference + thresh_first_spike.append( + thresh[0] + ) # Note that this finds the first spike (it might not be at the first stimulus blip) threshold.append(thresh[1]) - time_previous_spike.append((spike_ind[k][1]-spike_ind[k][0])*dt) -# Old way when looked at all the spikes instead of just the first two (can be depricated; just here for record keeping) -# threshold.append(thresh[1:]) -# time_before_temp=[] -# for j in range(1,len(thresh)): -# time_before_temp.append((spike_ind[k][j]-spike_ind[k][j-1])*dt) -# #for each spike calculate the time from the previous spike -# time_previous_spike.append(time_before_temp) + time_previous_spike.append((spike_ind[k][1] - spike_ind[k][0]) * dt) + # Old way when looked at all the spikes instead of just the first two (can be depricated; just here for record keeping) + # threshold.append(thresh[1:]) + # time_before_temp=[] + # for j in range(1,len(thresh)): + # time_before_temp.append((spike_ind[k][j]-spike_ind[k][j-1])*dt) + # #for each spike calculate the time from the previous spike + # time_previous_spike.append(time_before_temp) if MAKE_PLOT: - plt.subplot(len(multi_SS_v)+1,1,1) - plt.plot(np.arange(0, len(multi_SS_i[k]))*dt, multi_SS_i[k]*1e12, lw=2) - plt.ylabel('current (pA)', fontsize=16) - plt.xlim([2., 2.12]) - plt.title('Triple Short Square', fontsize=20) - plt.subplot(len(multi_SS_v)+1,1,k+2) - plt.plot(np.arange(0, len(multi_SS_v[k]))*dt, multi_SS_v[k], lw=2) - plt.plot(spike_ind[k]*dt, thresh, '.k', ms=16) - plt.xlim([2., 2.12]) - + plt.subplot(len(multi_SS_v) + 1, 1, 1) + plt.plot(np.arange(0, len(multi_SS_i[k])) * dt, multi_SS_i[k] * 1e12, lw=2) + plt.ylabel("current (pA)", fontsize=16) + plt.xlim([2.0, 2.12]) + plt.title("Triple Short Square", fontsize=20) + plt.subplot(len(multi_SS_v) + 1, 1, k + 2) + plt.plot(np.arange(0, len(multi_SS_v[k])) * dt, multi_SS_v[k], lw=2) + plt.plot(spike_ind[k] * dt, thresh, ".k", ms=16) + plt.xlim([2.0, 2.12]) + if MAKE_PLOT: - plt.ylabel('voltage (V)', fontsize=16) - plt.xlabel('time (s)', fontsize=16) - - if SHOW_PLOT: + plt.ylabel("voltage (V)", fontsize=16) + plt.xlabel("time (s)", fontsize=16) + + if SHOW_PLOT: plt.show(block=False) # put numbers into one vector for fitting of exponential function - thresh_inf=np.mean(thresh_first_spike) #note this threshold infinity isnt the one coming from single blip - try: #this try here because sometimes even though have the traces there isnt more than one trace with two spikes -#--these two lines no longer needed because all single values now (depricate with lines up above) instead converst them to arrays -# threshold=np.concatenate(threshold) -# time_previous_spike=np.concatenate(time_previous_spike) #note that this will have nans in it - threshold=np.array(threshold) - time_previous_spike=np.array(time_previous_spike) #note that this will have nans in it - + thresh_inf = np.mean(thresh_first_spike) # note this threshold infinity isnt the one coming from single blip + try: # this try here because sometimes even though have the traces there isnt more than one trace with two spikes + # --these two lines no longer needed because all single values now (depricate with lines up above) instead converst them to arrays + # threshold=np.concatenate(threshold) + # time_previous_spike=np.concatenate(time_previous_spike) #note that this will have nans in it + threshold = np.array(threshold) + time_previous_spike = np.array(time_previous_spike) # note that this will have nans in it + if MAKE_PLOT: plt.figure() - plt.plot(time_previous_spike, threshold, '.k', ms=16) - plt.ylabel('threshold (mV)') - plt.xlabel('time since last spike (s)') - + plt.plot(time_previous_spike, threshold, ".k", ms=16) + plt.ylabel("threshold (mV)") + plt.xlabel("time since last spike (s)") + # calculate values of exponential function both if force function to local threshold infinity and not forcing to a value # (not forcing to a value seems less valid unless a bunch of points are added corresponding to the threshold of the # first spike at time equal infinity (because the first spike is a spike that happens where the spike before it was an # infinite time away)). Therefore, the values that are obtained from forcing are the ones that are used. - p0_force=[.002, -100.] - p0_fit=[.002, -100., thresh_inf] - - #TODO: THIS WOULD BE BETTER IF IT CALLED THE ACTUAL FUNCTION IN THE NEURON METHODS THAT WAY THEY WOULD HAVE TO BE THE SAME - (popt_force, pcov_force)= curve_fit(exp_force_c, (time_previous_spike, thresh_inf), threshold, p0=p0_force, maxfev=100000) - (popt_fit, pcov_fit)= curve_fit(exp_fit_c, time_previous_spike, threshold, p0=p0_fit, maxfev=100000) + p0_force = [0.002, -100.0] + p0_fit = [0.002, -100.0, thresh_inf] + + # TODO: THIS WOULD BE BETTER IF IT CALLED THE ACTUAL FUNCTION IN THE NEURON METHODS THAT WAY THEY WOULD HAVE TO BE THE SAME + (popt_force, pcov_force) = curve_fit( + exp_force_c, (time_previous_spike, thresh_inf), threshold, p0=p0_force, maxfev=100000 + ) + (popt_fit, pcov_fit) = curve_fit(exp_fit_c, time_previous_spike, threshold, p0=p0_fit, maxfev=100000) # viewing fit functions - time_previous_spike.sort() #since time is not in order, making new time vector so that obtained fit curve can be plotted - fit_force=exp_force_c((time_previous_spike, thresh_inf), popt_force[0], popt_force[1]) - fit_fit=exp_fit_c(time_previous_spike, popt_fit[0], popt_fit[1], popt_fit[2]) + time_previous_spike.sort() # since time is not in order, making new time vector so that obtained fit curve can be plotted + fit_force = exp_force_c((time_previous_spike, thresh_inf), popt_force[0], popt_force[1]) + fit_fit = exp_fit_c(time_previous_spike, popt_fit[0], popt_fit[1], popt_fit[2]) if MAKE_PLOT: - plt.plot(time_previous_spike, fit_force, 'r', lw=4, label="exp fit (force const to thesh first spike)\n k=%.3g, amp=%.3g" % (popt_force[1], popt_force[0])) - plt.plot(time_previous_spike, fit_fit, 'b', lw=4, label="exp fit (fit constant)\n k=%.3g, amp=%.3g" % (popt_fit[1], popt_fit[0])) + plt.plot( + time_previous_spike, + fit_force, + "r", + lw=4, + label="exp fit (force const to thesh first spike)\n k=%.3g, amp=%.3g" % (popt_force[1], popt_force[0]), + ) + plt.plot( + time_previous_spike, + fit_fit, + "b", + lw=4, + label="exp fit (fit constant)\n k=%.3g, amp=%.3g" % (popt_fit[1], popt_fit[0]), + ) plt.legend() - if SHOW_PLOT: + if SHOW_PLOT: plt.show(block=False) - - if PUBLICATION_PLOT: + + if PUBLICATION_PLOT: plt.figure(figsize=[14, 5]) - ax1=plt.subplot2grid((2, 2), (0,0)) - ax2=plt.subplot2grid((2, 2), (1,0)) - ax3=plt.subplot2grid((2, 2), (0,1), rowspan=2) + ax1 = plt.subplot2grid((2, 2), (0, 0)) + ax2 = plt.subplot2grid((2, 2), (1, 0)) + ax3 = plt.subplot2grid((2, 2), (0, 1), rowspan=2) for k in range(0, len(multi_SS_v)): - thresh=[multi_SS_v[k][j] for j in spike_ind[k]] - - ax1.plot(np.arange(0, len(multi_SS_i[k]))*dt, multi_SS_i[k]*1.e12, lw=2) - ax1.set_ylabel('Current (pA)', fontsize=16) - ax1.set_xlim([2., 2.12]) + thresh = [multi_SS_v[k][j] for j in spike_ind[k]] + + ax1.plot(np.arange(0, len(multi_SS_i[k])) * dt, multi_SS_i[k] * 1.0e12, lw=2) + ax1.set_ylabel("Current (pA)", fontsize=16) + ax1.set_xlim([2.0, 2.12]) ax1.axes.xaxis.set_ticklabels([]) - #ax1.set_title('Triple Short Square', fontsize=20) - - ax2.plot(np.arange(0, len(multi_SS_v[k]))*dt, multi_SS_v[k]*1.e3, lw=2) - ax2.plot(spike_ind[k]*dt, np.array(thresh)*1.e3, '.k', ms=16) - ax2.set_ylabel('Voltage (mV)', fontsize=16) - ax2.set_xlabel('Time (ms)', fontsize=16) - ax2.set_xlim([2., 2.12]) - - - ax3.plot(time_previous_spike, np.array(threshold)*1.e3, '.k', ms=16) - ax3.set_ylabel('Threshold (mV)', fontsize=16) - ax3.set_xlabel('Time since last spike (s)', fontsize=16) -# ax3.set_title('Spiking component of threshold', fontsize=20) - ax3.plot(time_previous_spike, fit_force*1.e3, 'r', lw=4)#, label="exp fit: k=%.3g, amp=%.3g" % (popt_force[1], popt_force[0])) - ax3.legend() + # ax1.set_title('Triple Short Square', fontsize=20) + + ax2.plot(np.arange(0, len(multi_SS_v[k])) * dt, multi_SS_v[k] * 1.0e3, lw=2) + ax2.plot(spike_ind[k] * dt, np.array(thresh) * 1.0e3, ".k", ms=16) + ax2.set_ylabel("Voltage (mV)", fontsize=16) + ax2.set_xlabel("Time (ms)", fontsize=16) + ax2.set_xlim([2.0, 2.12]) + + ax3.plot(time_previous_spike, np.array(threshold) * 1.0e3, ".k", ms=16) + ax3.set_ylabel("Threshold (mV)", fontsize=16) + ax3.set_xlabel("Time since last spike (s)", fontsize=16) + # ax3.set_title('Spiking component of threshold', fontsize=20) + ax3.plot( + time_previous_spike, fit_force * 1.0e3, "r", lw=4 + ) # , label="exp fit: k=%.3g, amp=%.3g" % (popt_force[1], popt_force[0])) + ax3.legend() plt.tight_layout() plt.show() - - const_to_add_to_thresh_for_reset=popt_force[0] - decay_const=popt_force[1] - - if decay_const >0: - logging.critical('This neuron has an increasing decay value for the spike component of the threshold') - if const_to_add_to_thresh_for_reset<0: - logging.critical('This neuron has a negative amplitude for the spike component of the threshold') - - #if the decay constant is positive, or the amplitute is negative set the amplitude to 0 so that there - #will be no spike component of the threshold - if decay_const >0 or const_to_add_to_thresh_for_reset < 0: - const_to_add_to_thresh_for_reset=0 - + + const_to_add_to_thresh_for_reset = popt_force[0] + decay_const = popt_force[1] + + if decay_const > 0: + logging.critical("This neuron has an increasing decay value for the spike component of the threshold") + if const_to_add_to_thresh_for_reset < 0: + logging.critical("This neuron has a negative amplitude for the spike component of the threshold") + + # if the decay constant is positive, or the amplitute is negative set the amplitude to 0 so that there + # will be no spike component of the threshold + if decay_const > 0 or const_to_add_to_thresh_for_reset < 0: + const_to_add_to_thresh_for_reset = 0 + # This decay constant was originally forced to be positive (i.e. decay_const=abs(popt_force[1])) - # and then it is negated everywhere it is utilized elsewhere in the code. Now things are forced in - # a different way above. However the decay constant still needs to be negated here for use in the - # rest of the code. - decay_const=-decay_const - + # and then it is negated everywhere it is utilized elsewhere in the code. Now things are forced in + # a different way above. However the decay constant still needs to be negated here for use in the + # rest of the code. + decay_const = -decay_const except Exception as e: logging.error(e.message) - const_to_add_to_thresh_for_reset=None - decay_const=None - + const_to_add_to_thresh_for_reset = None + decay_const = None + return const_to_add_to_thresh_for_reset, decay_const, thresh_inf -def fit_avoltage_bvoltage(x, v_trace_list, El_list, spike_cut_length, all_spikeInd_list, th_inf, dt, a_spike, - b_spike, fake=False): - '''This is a version of fit_avoltage_bvoltage_debug that does not require the th_trace, + +def fit_avoltage_bvoltage( + x, v_trace_list, El_list, spike_cut_length, all_spikeInd_list, th_inf, dt, a_spike, b_spike, fake=False +): + """This is a version of fit_avoltage_bvoltage_debug that does not require the th_trace, v_component_of_thresh_trace, and spike_component_of_thresh_trace needed for debugging. A test should be run to make sure the same output comes out from this and the debug function - + This function returns the squared error for the difference between the 'known' voltage - component of the threshold obtained from the biological neuron and the voltage component - of the threshold of the model obtained with the input parameters (so that the minimum can be + component of the threshold obtained from the biological neuron and the voltage component + of the threshold of the model obtained with the input parameters (so that the minimum can be searched for via fmin). The overall threshold is the sum of threshold infinity the spike component - of the threshold and the voltage component of the threshold. Therefore threshold infinity and + of the threshold and the voltage component of the threshold. Therefore threshold infinity and the spike component of the threshold must be subtracted from the threshold of the neuron in order to isolate the voltage component of the threshold. In the evaluation of the model the actual - voltage of the neuron is used so that any errors in the other components of the model will not + voltage of the neuron is used so that any errors in the other components of the model will not influence the fits here (for example, if a afterspike current was estimated incorrectly) - + Notes: - * The spike component of the threshold is subtracted from the + * The spike component of the threshold is subtracted from the voltage which means that the voltage component of the threshold should only be added to rules. - * b_spike was fit using a negative value in the function therefore the negative is placed in the - equation. + * b_spike was fit using a negative value in the function therefore the negative is placed in the + equation. * values in this function are in 'real' voltage as opposed to voltage - relative to resting potential. - * current injection during the spike is not taken into account. This seems reasonable as the + relative to resting potential. + * current injection during the spike is not taken into account. This seems reasonable as the ion channels are open during this time and injected current may not greatly influence the neuron. - + x: numpy array x[0]=a_voltage input, x[1] is b_voltage_input, x[2] is th_inf v_trace_list: list of numpy arrays @@ -224,10 +246,10 @@ def fit_avoltage_bvoltage(x, v_trace_list, El_list, spike_cut_length, all_spikeI spike_cut_length: int number of indicies removed after initiation of a spike all_spikeInd_list: list of numpy arrays - indicies of spike trains + indicies of spike trains th_inf: float threshold infinity (v_trace, El, and th_inf must be in the same frame of reference) - dt: float + dt: float size of time step (SI units) a_spike: float amplitude of spike component of threshold. @@ -236,86 +258,104 @@ def fit_avoltage_bvoltage(x, v_trace_list, El_list, spike_cut_length, all_spikeI fake: Boolean if True makes uses the voltage value of spike step-1 because there is not a voltage value at the spike step because it is set to nan in the simulator. - ''' - a_voltage=x[0] - b_voltage=x[1] + """ + a_voltage = x[0] + b_voltage = x[1] - total_err=0 + total_err = 0 for v_trace, El, all_spikeInd in zip(v_trace_list, El_list, all_spikeInd_list): # Calculate values along the whole trace and then take the values at the spike ind - internal_sp_comp_array=np.zeros(all_spikeInd[0]+spike_cut_length) - left_over=0 - #vector of spike component of of the threshold from each spike and previous spikes; note at first spike there is no spike component of threshold so initialized at zero - #Note that care has to be taken here to get make sure the right amount of decay is left over - for spike_number in range(1,len(all_spikeInd)): #skipping first spike since no residual spike component of threshold - integration_length=all_spikeInd[spike_number]-all_spikeInd[spike_number-1]+1 #this is the amount of time that needs to be integrated over note that it is one longer than the interval because of the last value is added to the aspike for the next ISI - local_spike_comp_of_threshold=spike_component_of_threshold_exact(a_spike+left_over, b_spike, np.arange(integration_length)*dt) - internal_sp_comp_array=np.append(internal_sp_comp_array, local_spike_comp_of_threshold[:-1]) - left_over=local_spike_comp_of_threshold[-1] - + internal_sp_comp_array = np.zeros(all_spikeInd[0] + spike_cut_length) + left_over = 0 + # vector of spike component of of the threshold from each spike and previous spikes; note at first spike there is no spike component of threshold so initialized at zero + # Note that care has to be taken here to get make sure the right amount of decay is left over + for spike_number in range( + 1, len(all_spikeInd) + ): # skipping first spike since no residual spike component of threshold + integration_length = ( + all_spikeInd[spike_number] - all_spikeInd[spike_number - 1] + 1 + ) # this is the amount of time that needs to be integrated over note that it is one longer than the interval because of the last value is added to the aspike for the next ISI + local_spike_comp_of_threshold = spike_component_of_threshold_exact( + a_spike + left_over, b_spike, np.arange(integration_length) * dt + ) + internal_sp_comp_array = np.append(internal_sp_comp_array, local_spike_comp_of_threshold[:-1]) + left_over = local_spike_comp_of_threshold[-1] + # Compute voltage component of threshold at biological spike (subtract th_inf and spike component of threshold # from biological voltage values at spike initiation) #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! # NOTE THAT THERE IS AN ISSUE HERE USING FAKE DATA. THE -1 IS HERE BECAUSE THE NEURON CROSSES THRESHOLD SOMETIME BETWEEN TWO INDICIES. # FOR THE FAKE DATA THE TIME OF THE SPIKE (THE POINT FOLLOWING WHEN THE VOLTAGE CROSSES THRESHOLD) IS SET TO NAN. - # THE INTERPOLATED VOLTAGE CAN BE USED BUT THEN THE INTERPOLATED VOLTAGE MUST BE CALCULATED FOR THE TRUE VOLTAGE + # THE INTERPOLATED VOLTAGE CAN BE USED BUT THEN THE INTERPOLATED VOLTAGE MUST BE CALCULATED FOR THE TRUE VOLTAGE # TRACE AND POSSIBLY IN THE INTEGRATION. if fake: - v_comp_of_th_at_each_spike_via_data=v_trace[all_spikeInd-1]-internal_sp_comp_array[all_spikeInd-1]-th_inf #USE THIS FOR FAKE DATA + v_comp_of_th_at_each_spike_via_data = ( + v_trace[all_spikeInd - 1] - internal_sp_comp_array[all_spikeInd - 1] - th_inf + ) # USE THIS FOR FAKE DATA else: - v_comp_of_th_at_each_spike_via_data=v_trace[all_spikeInd]-internal_sp_comp_array[all_spikeInd]-th_inf #USE THIS FOR REAL DATA (although probably not necessary) + v_comp_of_th_at_each_spike_via_data = ( + v_trace[all_spikeInd] - internal_sp_comp_array[all_spikeInd] - th_inf + ) # USE THIS FOR REAL DATA (although probably not necessary) #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - + # For each ISI, calculate the difference between the voltage dependent component of the threshold - # and the value that would be determined via a model that uses the actual voltage of neuron. - sq_err = [] #list to store squared error between the model and biological threshold - for spike_number in range(1, len(all_spikeInd)): #loop over all ISI's in data - v_start_ind = all_spikeInd[spike_number-1]+int(spike_cut_length) #dont want to use voltage during a spike + # and the value that would be determined via a model that uses the actual voltage of neuron. + sq_err = [] # list to store squared error between the model and biological threshold + for spike_number in range(1, len(all_spikeInd)): # loop over all ISI's in data + v_start_ind = all_spikeInd[spike_number - 1] + int( + spike_cut_length + ) # dont want to use voltage during a spike end_ind = all_spikeInd[spike_number] - v_in_ISI=v_trace[v_start_ind:end_ind] - #voltage component of threshold at the beginning and end of the ISI - #With fake data if go from one before fake data to fake data this should be exact - theta0=v_comp_of_th_at_each_spike_via_data[spike_number-1] #this assumes that the voltage component of the threshold does not change over the time period of the spike + v_in_ISI = v_trace[v_start_ind:end_ind] + # voltage component of threshold at the beginning and end of the ISI + # With fake data if go from one before fake data to fake data this should be exact + theta0 = v_comp_of_th_at_each_spike_via_data[ + spike_number - 1 + ] # this assumes that the voltage component of the threshold does not change over the time period of the spike # not sure this makes sense any moretheta0=v_comp_of_th_at_each_spike_via_data[spike_number-1]+sp_comp_of_offset_at_spike_list[spike_number-1] #use this is you want to add the biological component back on - theta1=v_comp_of_th_at_each_spike_via_data[spike_number] - tvec=np.arange(len(v_in_ISI))*dt - - #analytical solution should be exact with fake data--small differences could be because of the differences in the voltage at - #spike indicies are off by one - model=+theta0*np.exp(-b_voltage*dt*(end_ind-v_start_ind))+a_voltage*np.exp(-b_voltage*tvec[-1])*np.sum(dt*(v_in_ISI-El)*np.exp(b_voltage*tvec)) - err = (theta1-model)**2 + theta1 = v_comp_of_th_at_each_spike_via_data[spike_number] + tvec = np.arange(len(v_in_ISI)) * dt + + # analytical solution should be exact with fake data--small differences could be because of the differences in the voltage at + # spike indicies are off by one + model = +theta0 * np.exp(-b_voltage * dt * (end_ind - v_start_ind)) + a_voltage * np.exp( + -b_voltage * tvec[-1] + ) * np.sum(dt * (v_in_ISI - El) * np.exp(b_voltage * tvec)) + err = (theta1 - model) ** 2 if ~np.isnan(err): sq_err.append(err) - - total_err+=np.sum(sq_err) + + total_err += np.sum(sq_err) return total_err -def fit_avoltage_bvoltage_th(x, v_trace_list, El_list, spike_cut_length, all_spikeInd_list, dt, a_spike, - b_spike, fake=False): - '''This is a version of fit_avoltage_bvoltage_th_debug that does not require the th_trace, + +def fit_avoltage_bvoltage_th( + x, v_trace_list, El_list, spike_cut_length, all_spikeInd_list, dt, a_spike, b_spike, fake=False +): + """This is a version of fit_avoltage_bvoltage_th_debug that does not require the th_trace, v_component_of_thresh_trace, and spike_component_of_thresh_trace needed for debugging. A test should be run to make sure the same output comes out from this and the debug function - + This function returns the squared error for the difference between the 'known' voltage - component of the threshold obtained from the biological neuron and the voltage component - of the threshold of the model obtained with the input parameters (so that the minimum can be + component of the threshold obtained from the biological neuron and the voltage component + of the threshold of the model obtained with the input parameters (so that the minimum can be searched for via fmin). The overall threshold is the sum of threshold infinity the spike component - of the threshold and the voltage component of the threshold. Therefore threshold infinity and + of the threshold and the voltage component of the threshold. Therefore threshold infinity and the spike component of the threshold must be subtracted from the threshold of the neuron in order to isolate the voltage component of the threshold. In the evaluation of the model the actual - voltage of the neuron is used so that any errors in the other components of the model will not + voltage of the neuron is used so that any errors in the other components of the model will not influence the fits here (for example, if a afterspike current was estimated incorrectly) - + Notes: - * The spike component of the threshold is subtracted from the + * The spike component of the threshold is subtracted from the voltage which means that the voltage component of the threshold should only be added to rules. - * b_spike was fit using a negative value in the function therefore the negative is placed in the - equation. + * b_spike was fit using a negative value in the function therefore the negative is placed in the + equation. * values in this function are in 'real' voltage as opposed to voltage - relative to resting potential. - * current injection during the spike is not taken into account. This seems reasonable as the + relative to resting potential. + * current injection during the spike is not taken into account. This seems reasonable as the ion channels are open during this time and injected current may not greatly influence the neuron. - + x: numpy array x[0]=a_voltage input, x[1] is b_voltage_input, x[2] is th_inf v_trace_list: list of numpy arrays @@ -325,8 +365,8 @@ def fit_avoltage_bvoltage_th(x, v_trace_list, El_list, spike_cut_length, all_spi spike_cut_length: int number of indicies removed after initiation of a spike all_spikeInd_list: list of numpy arrays - indicies of spike trains - dt: float + indicies of spike trains + dt: float size of time step (SI units) a_spike: float amplitude of spike component of threshold. @@ -335,84 +375,100 @@ def fit_avoltage_bvoltage_th(x, v_trace_list, El_list, spike_cut_length, all_spi fake: Boolean if True makes uses the voltage value of spike step-1 because there is not a voltage value at the spike step because it is set to nan in the simulator. - ''' - a_voltage=x[0] - b_voltage=x[1] - th_inf=x[2] + """ + a_voltage = x[0] + b_voltage = x[1] + th_inf = x[2] - total_err=0 + total_err = 0 for v_trace, El, all_spikeInd in zip(v_trace_list, El_list, all_spikeInd_list): # Calculate values along the whole trace and then take the values at the spike ind - internal_sp_comp_array=np.zeros(all_spikeInd[0]+spike_cut_length) - left_over=0 - #vector of spike component of of the threshold from each spike and previous spikes; note at first spike there is no spike component of threshold so initialized at zero - #Note that care has to be taken here to get make sure the right amount of decay is left over - for spike_number in range(1,len(all_spikeInd)): #skipping first spike since no residual spike component of threshold - integration_length=all_spikeInd[spike_number]-all_spikeInd[spike_number-1]+1 #this is the amount of time that needs to be integrated over note that it is one longer than the interval because of the last value is added to the aspike for the next ISI - local_spike_comp_of_threshold=spike_component_of_threshold_exact(a_spike+left_over, b_spike, np.arange(integration_length)*dt) - internal_sp_comp_array=np.append(internal_sp_comp_array, local_spike_comp_of_threshold[:-1]) - left_over=local_spike_comp_of_threshold[-1] - + internal_sp_comp_array = np.zeros(all_spikeInd[0] + spike_cut_length) + left_over = 0 + # vector of spike component of of the threshold from each spike and previous spikes; note at first spike there is no spike component of threshold so initialized at zero + # Note that care has to be taken here to get make sure the right amount of decay is left over + for spike_number in range( + 1, len(all_spikeInd) + ): # skipping first spike since no residual spike component of threshold + integration_length = ( + all_spikeInd[spike_number] - all_spikeInd[spike_number - 1] + 1 + ) # this is the amount of time that needs to be integrated over note that it is one longer than the interval because of the last value is added to the aspike for the next ISI + local_spike_comp_of_threshold = spike_component_of_threshold_exact( + a_spike + left_over, b_spike, np.arange(integration_length) * dt + ) + internal_sp_comp_array = np.append(internal_sp_comp_array, local_spike_comp_of_threshold[:-1]) + left_over = local_spike_comp_of_threshold[-1] + # Compute voltage component of threshold at biological spike (subtract th_inf and spike component of threshold # from biological voltage values at spike initiation) #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! # NOTE THAT THERE IS AN ISSUE HERE USING FAKE DATA. THE -1 IS HERE BECAUSE THE NEURON CROSSES THRESHOLD SOMETIME BETWEEN TWO INDICIES. # FOR THE FAKE DATA THE TIME OF THE SPIKE (THE POINT FOLLOWING WHEN THE VOLTAGE CROSSES THRESHOLD) IS SET TO NAN. - # THE INTERPOLATED VOLTAGE CAN BE USED BUT THEN THE INTERPOLATED VOLTAGE MUST BE CALCULATED FOR THE TRUE VOLTAGE + # THE INTERPOLATED VOLTAGE CAN BE USED BUT THEN THE INTERPOLATED VOLTAGE MUST BE CALCULATED FOR THE TRUE VOLTAGE # TRACE AND POSSIBLY IN THE INTEGRATION. if fake: - v_comp_of_th_at_each_spike_via_data=v_trace[all_spikeInd-1]-internal_sp_comp_array[all_spikeInd-1]-th_inf #USE THIS FOR FAKE DATA + v_comp_of_th_at_each_spike_via_data = ( + v_trace[all_spikeInd - 1] - internal_sp_comp_array[all_spikeInd - 1] - th_inf + ) # USE THIS FOR FAKE DATA else: - v_comp_of_th_at_each_spike_via_data=v_trace[all_spikeInd]-internal_sp_comp_array[all_spikeInd]-th_inf #USE THIS FOR REAL DATA (although probably not necessary) + v_comp_of_th_at_each_spike_via_data = ( + v_trace[all_spikeInd] - internal_sp_comp_array[all_spikeInd] - th_inf + ) # USE THIS FOR REAL DATA (although probably not necessary) #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - + # For each ISI, calculate the difference between the voltage dependent component of the threshold - # and the value that would be determined via a model that uses the actual voltage of neuron. - sq_err = [] #list to store squared error between the model and biological threshold - for spike_number in range(1, len(all_spikeInd)): #loop over all ISI's in data - v_start_ind = all_spikeInd[spike_number-1]+int(spike_cut_length) #dont want to use voltage during a spike + # and the value that would be determined via a model that uses the actual voltage of neuron. + sq_err = [] # list to store squared error between the model and biological threshold + for spike_number in range(1, len(all_spikeInd)): # loop over all ISI's in data + v_start_ind = all_spikeInd[spike_number - 1] + int( + spike_cut_length + ) # dont want to use voltage during a spike end_ind = all_spikeInd[spike_number] - v_in_ISI=v_trace[v_start_ind:end_ind] - #voltage component of threshold at the beginning and end of the ISI - #With fake data if go from one before fake data to fake data this should be exact - theta0=v_comp_of_th_at_each_spike_via_data[spike_number-1] #this assumes that the voltage component of the threshold does not change over the time period of the spike + v_in_ISI = v_trace[v_start_ind:end_ind] + # voltage component of threshold at the beginning and end of the ISI + # With fake data if go from one before fake data to fake data this should be exact + theta0 = v_comp_of_th_at_each_spike_via_data[ + spike_number - 1 + ] # this assumes that the voltage component of the threshold does not change over the time period of the spike # not sure this makes sense any moretheta0=v_comp_of_th_at_each_spike_via_data[spike_number-1]+sp_comp_of_offset_at_spike_list[spike_number-1] #use this is you want to add the biological component back on - theta1=v_comp_of_th_at_each_spike_via_data[spike_number] - tvec=np.arange(len(v_in_ISI))*dt - - #analytical solution should be exact with fake data--small differences could be because of the differences in the voltage at - #spike indicies are off by one - model=+theta0*np.exp(-b_voltage*dt*(end_ind-v_start_ind))+a_voltage*np.exp(-b_voltage*tvec[-1])*np.sum(dt*(v_in_ISI-El)*np.exp(b_voltage*tvec)) - err = (theta1-model)**2 + theta1 = v_comp_of_th_at_each_spike_via_data[spike_number] + tvec = np.arange(len(v_in_ISI)) * dt + + # analytical solution should be exact with fake data--small differences could be because of the differences in the voltage at + # spike indicies are off by one + model = +theta0 * np.exp(-b_voltage * dt * (end_ind - v_start_ind)) + a_voltage * np.exp( + -b_voltage * tvec[-1] + ) * np.sum(dt * (v_in_ISI - El) * np.exp(b_voltage * tvec)) + err = (theta1 - model) ** 2 if ~np.isnan(err): sq_err.append(err) - - total_err+=np.sum(sq_err) + + total_err += np.sum(sq_err) return total_err -#TODO: depricate confirmed use of fit_avoltage_bvoltage_th -#def err_fix_th(x, v_trace, El, spike_cut_length, all_spikeInd, th_inf, dt, a_spike, b_spike): -# '''This function returns the squared error for the difference between the 'known' voltage -# component of the threshold obtained from the biological neuron and the voltage component -# of the threshold of the model obtained with the input parameters (so that the minimum can be +# TODO: depricate confirmed use of fit_avoltage_bvoltage_th +# def err_fix_th(x, v_trace, El, spike_cut_length, all_spikeInd, th_inf, dt, a_spike, b_spike): +# '''This function returns the squared error for the difference between the 'known' voltage +# component of the threshold obtained from the biological neuron and the voltage component +# of the threshold of the model obtained with the input parameters (so that the minimum can be # searched for via fmin). The overall threshold is the sum of threshold infinity the spike component -# of the threshold and the voltage component of the threshold. Therefore threshold infinity and +# of the threshold and the voltage component of the threshold. Therefore threshold infinity and # the spike component of the threshold must be subtracted from the threshold of the neuron in order # to isolate the voltage component of the threshold. In the evaluation of the model the actual -# voltage of the neuron is used so that any errors in the other components of the model will not +# voltage of the neuron is used so that any errors in the other components of the model will not # influence the fits here (for example if a afterspike current was estimated incorrectly) -# +# # Notes: -# * The spike component of the threshold is subtracted from the +# * The spike component of the threshold is subtracted from the # voltage which means that the voltage component of the threshold should only be added to rules. -# * b_spike was fit using a negative value in the function therefore the negative is placed in the -# equation. +# * b_spike was fit using a negative value in the function therefore the negative is placed in the +# equation. # * values in this function are in 'real' voltage as opposed to voltage -# relative to resting potential. -# * current injection during the spike is not taken into account. This seems reasonable as the +# relative to resting potential. +# * current injection during the spike is not taken into account. This seems reasonable as the # ion channels are open during this time and injected current may not greatly influence the neuron. -# +# # x: numpy array # x[0]=a_voltage input, x[1] is b_voltage_input # voltage: numpy array @@ -422,10 +478,10 @@ def fit_avoltage_bvoltage_th(x, v_trace_list, El_list, spike_cut_length, all_spi # spike_cut_length: int # number of indicies removed after initiation of a spike # all_spikeInd: numpy array -# indicies of spike train +# indicies of spike train # th_inf: float # threshold infinity (voltage, El, and th_inf must be in the same frame of reference) -# dt: float +# dt: float # size of time step (SI units) # a_spike: float # amplitude of spike component of threshold. @@ -435,7 +491,7 @@ def fit_avoltage_bvoltage_th(x, v_trace_list, El_list, spike_cut_length, all_spi # a_voltage=x[0] # b_voltage=x[1] # # effect of the spike component of the threshold from each spike and previous spikes -# sp_comp_of_offset_sum_vector=[0] #vector of spike component of of the threshold from each spike and previous spikes; note at first spike there is no spike component of threshold so initialized at zero +# sp_comp_of_offset_sum_vector=[0] #vector of spike component of of the threshold from each spike and previous spikes; note at first spike there is no spike component of threshold so initialized at zero # for spike_number in range(1,len(all_spikeInd)): #skipping first spike since no residual spike component of threshold # t=(all_spikeInd[spike_number]-all_spikeInd[spike_number-1]-spike_cut_length)*dt #spike ISI # sp_comp_of_th_offset_local = spike_component_of_threshold_exact(a_spike, b_spike, t) #spike component of threshold at each ISI for each individual spike @@ -450,21 +506,21 @@ def fit_avoltage_bvoltage_th(x, v_trace_list, El_list, spike_cut_length, all_spi # #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! # # NOTE THAT THERE IS AN ISSUE HERE USING FAKE DATA. THE -1 IS HERE BECAUSE THE NEURON CROSSES THRESHOLD SOMETIME BETWEEN TWO INDICIES. # # FOR THE FAKE DATA THE TIME OF THE SPIKE (THE POINT FOLLOWING WHEN THE VOLTAGE CROSSES THRESHOLD) IS SET TO NAN. -# # THE INTERPOLATED VOLTAGE CAN BE USED BUT THEN THE INTERPOLATED VOLTAGE MUST BE CALCULATED FOR THE TRUE VOLTAGE +# # THE INTERPOLATED VOLTAGE CAN BE USED BUT THEN THE INTERPOLATED VOLTAGE MUST BE CALCULATED FOR THE TRUE VOLTAGE # # TRACE AND POSSIBLY IN THE INTEGRATION. # v_comp_of_th_at_each_spike_via_data=v_trace[all_spikeInd-1]-np.array(sp_comp_of_offset_sum_vector)-th_inf #USE THIS FOR FAKE DATA ## v_comp_of_th_at_each_spike_via_data=v_trace[all_spikeInd]-np.array(sp_comp_of_offset_sum_vector)-th_inf #THIS IS PROBABLY APPROPRIATE FOR REAL DATA # #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! -# +# # # For each ISI, calculate the difference between the v_trace dependent component of the threshold -# # and the value that would be determined via a model that uses the actual v_trace of neuron. +# # and the value that would be determined via a model that uses the actual v_trace of neuron. # sq_err = [] #list to store squared error between the model and biological threshold # for spike_number in range(1, len(all_spikeInd)): #loop over all ISI's in data # v_start_ind = all_spikeInd[spike_number-1]+int(spike_cut_length) #dont want to use voltage during a spike # end_ind = all_spikeInd[spike_number] # v_in_ISI=v_trace[v_start_ind:end_ind] # #v_trace component of threshold at the beginning and end of the ISI -# +# # theta0=v_comp_of_th_at_each_spike_via_data[spike_number-1] #this assumes that the voltage component of the threshold does not change over the time period of the spike # # not sure this makes sense any moretheta0=v_comp_of_th_at_each_spike_via_data[spike_number-1]+sp_comp_of_offset_sum_vector[spike_number-1] #use this is you want to add the biological component back on # theta1=v_comp_of_th_at_each_spike_via_data[spike_number] @@ -475,22 +531,22 @@ def fit_avoltage_bvoltage_th(x, v_trace_list, El_list, spike_cut_length, all_spi # err = (theta1-model)**2 # if ~np.isnan(err): # sq_err.append(err) -# +# # return np.sum(sq_err) # TODO: can depricate, using sdk now -#def find_multiblip_spikes(multi_SS_i, multi_SS_v, dt): -# '''artifacts caused by turning stimulus on and off created artifacts that -# created problems for the standard spike detection algorithm. Several alterations +# def find_multiblip_spikes(multi_SS_i, multi_SS_v, dt): +# '''artifacts caused by turning stimulus on and off created artifacts that +# created problems for the standard spike detection algorithm. Several alterations # were made so that the algorithm would detect spikes appropriately. Please see multiblip -# spike cutting documentation for more information on how the code differs from the SDK +# spike cutting documentation for more information on how the code differs from the SDK # version and what was needed to solve specific issues. # input: # multi_SS_i: list of arrays # each array corresponds to the current stimulation of one triblip stimulus # multi_SS_v: list of arrays # each array corresponds to the voltage trace of triblip simulus -# dt: float +# dt: float # ''' # artifact_ave_window_time_s=0.0003 # # window_indicies_to_ave_len=int(artifact_ave_window_time_s/dt) @@ -502,14 +558,14 @@ def fit_avoltage_bvoltage_th(x, v_trace_list, El_list, spike_cut_length, all_spi # potential_artifact_indexes=np.sort(np.append(up_blip_index, down_blip_index)) # artifact_removed_voltage=copy.deepcopy(voltage) # window_boarders_index=[] -# #--remove artifacts +# #--remove artifacts # for index in potential_artifact_indexes: # smooth_window=range(index,index+window_indicies_to_ave_len+1) -# +# # # interpolate in the smoothing window # blah=interp1d([smooth_window[0], smooth_window[-1]], [artifact_removed_voltage[smooth_window[0]], artifact_removed_voltage[smooth_window[-1]]]) # artifact_removed_voltage[smooth_window]=blah(smooth_window) -# +# # #windows boarders are just for plotting ## window_boarders_index.append(smooth_window[0]) ## window_boarders_index.append(smooth_window[-1]) @@ -519,83 +575,83 @@ def fit_avoltage_bvoltage_th(x, v_trace_list, El_list, spike_cut_length, all_spi ## plt.plot(window_boarders_index, artifact_removed_voltage[window_boarders_index], '|g', ms=10) ## plt.xlim([40400, 41000]) ## plt.show() -## +## ## t = np.arange(0, len()) * dt # -# # keeping the smooth_v convention of the SDK find spike code. However in the SDK code +# # keeping the smooth_v convention of the SDK find spike code. However in the SDK code # # this is used to name data potentially smoothed by a bessel filter # smooth_v = artifact_removed_voltage -# dv = np.diff(smooth_v) +# dv = np.diff(smooth_v) # dvdt = dv / dt # dvv = np.diff(dvdt) -# +# # v=smooth_v[:-1] #truncating the end of v so it has the same dimensions as dvdt for time plotting # # spikes = [] # out_spk_idxs = [] -# +# # peaks=get_peaks(v) # find potential spikes by finding peak over zero mv -# +# # # Etay defines spike as time of threshold crossing. Threshold is defined as the time at which dvdt is some percent of maximum threshold. # # TODO: figure out how maximum threshold is defined in original code so I can say why I don't use it # for spk_n, peak_idx in enumerate(peaks): # #---------find spike peak---------------------------- # spk = {} -# -# spk["peak_idx"] = peak_idx -# upstroke_idx = np.argmax(dvdt[peak_idx-int(.001/dt):peak_idx]) + peak_idx-int(.001/dt) +# +# spk["peak_idx"] = peak_idx +# upstroke_idx = np.argmax(dvdt[peak_idx-int(.001/dt):peak_idx]) + peak_idx-int(.001/dt) # spk["upstroke"] = dvdt[upstroke_idx] # spk["upstroke_idx"] = upstroke_idx # spk["upstroke_v"] = v[upstroke_idx] -# +# # # Define threshold where dvdt = 5% * max upstroke # dvdt_thr_target = THRESH_PCT_MULTIBLIP * spk["upstroke"] # #print 'spk[upstroke]', spk["upstroke"], 'dvdt_thr_target', dvdt_thr_target # prev_idx = peak_idx-int(.0035/dt) -# #check to make sure prev_idx is not before or in a window where the stimulus blip comes on because +# #check to make sure prev_idx is not before or in a window where the stimulus blip comes on because # #it will errorniously trip the threshold dvdt # for index in up_blip_index: # if prev_idx<=index+int(.0005/dt) and prev_idx>= index-int(.0035/dt): # prev_idx=index+int(.0005/dt) -# +# # mean_dvv= [np.mean(dvv[pv-2:pv+3]) for pv in range(prev_idx,upstroke_idx)] #makes sure dv2/dt2 isnt spuriously going down by averaging 5 points -# find_thresh_idxs = np.where(np.logical_and(dvdt[prev_idx:upstroke_idx] >= dvdt_thr_target, np.greater(mean_dvv,0)))[0] -# +# find_thresh_idxs = np.where(np.logical_and(dvdt[prev_idx:upstroke_idx] >= dvdt_thr_target, np.greater(mean_dvv,0)))[0] +# # if len(find_thresh_idxs) < 1: # Can't find a good threshold value - probably a bad simulation case # # Fall back to the upstroke value # threshold_idx = upstroke_idx # else: # threshold_idx = find_thresh_idxs[0] + prev_idx -# +# # spk["threshold_idx"] = threshold_idx # spk["threshold_v"] = v[threshold_idx] -# +# # # Check for things that are probably not spikes: -# +# # # if the "spike" is less than 2 mV from threshold to peak, don't count it -# if v[peak_idx] - v[threshold_idx] < 0.002: +# if v[peak_idx] - v[threshold_idx] < 0.002: # print("\tnot counting spike is closer to peak than 2 mV") # continue -# +# # #NOTE: because threshold doesnt decay to zero in the multiblip this doesnt get rid of the situation that usually is only in the first spike of a stimulus # # if the spike is less the -30mV, don't count it # if v[peak_idx] < -0.04: # print("\tnot counting spike: peak is too small") # continue -# +# # spikes.append(spk) -# +# # #----figure out if I should still find a global threshold and then do it all again # # # find global threshold which is an average of the individual thresholds # # if len(spikes) > 0: # # dvdt_thr_target = np.array([spk["upstroke"] for spk in spikes]).mean() * THRESH_PCT_MULTIBLIP # # else: # if there weren't any spikes, move along # # return np.array([]) -# +# # out_spk_idxs.append(spk["threshold_idx"]) -# +# # out_spk_idxs_list.append(np.array(out_spk_idxs)) -# +# ## time_vector=np.arange(len(v))*dt ## plt.figure() ## # plt.subplot(3,1,1) @@ -619,28 +675,32 @@ def fit_avoltage_bvoltage_th(x, v_trace_list, El_list, spike_cut_length, all_spi ## plt.plot(time_vector[potential_artifact_indexes], v[potential_artifact_indexes], 'b|', ms=24, lw=4) ## plt.legend() ## plt.show() -# +# # return out_spk_idxs_list + def get_peaks(voltage, aboveValue=0): - '''This function was written by Corinne Teeter and calculates the action potential peaks of a voltage equation" + """This function was written by Corinne Teeter and calculates the action potential peaks of a voltage equation" inputs voltage: numpy array of voltages aboveValue: scalar voltage value over which voltage is considered a spike. - outputs: - peakInd: array of indicies of peaks''' - VshiftR=np.concatenate(([0], voltage[0:voltage.size-2])) - VshiftL=voltage[1:voltage.size] - IndShiftR=np.where(voltage[0:voltage.size-1]>VshiftR) - IndShiftL=np.where(voltage[0:voltage.size-1]>VshiftL) - greatThanThresh=np.where(voltage>aboveValue) #finds indicies greater than the value provided - peakInd=np.intersect1d(np.intersect1d(IndShiftL[0], IndShiftR[0]), greatThanThresh[0]) #find the indicies of the peak + outputs: + peakInd: array of indicies of peaks""" + VshiftR = np.concatenate(([0], voltage[0 : voltage.size - 2])) + VshiftL = voltage[1 : voltage.size] + IndShiftR = np.where(voltage[0 : voltage.size - 1] > VshiftR) + IndShiftL = np.where(voltage[0 : voltage.size - 1] > VshiftL) + greatThanThresh = np.where(voltage > aboveValue) # finds indicies greater than the value provided + peakInd = np.intersect1d( + np.intersect1d(IndShiftL[0], IndShiftR[0]), greatThanThresh[0] + ) # find the indicies of the peak return peakInd + def exp_force_c(t_const, a1, k1): (t, const) = t_const - return a1*(np.exp(k1*t))+const + return a1 * (np.exp(k1 * t)) + const -def exp_fit_c(t, a1, k1, const): - return a1*(np.exp(k1*t))+const +def exp_fit_c(t, a1, k1, const): + return a1 * (np.exp(k1 * t)) + const diff --git a/allensdk/internal/morphology/compartment.py b/allensdk/internal/morphology/compartment.py index 2a88781a80..bf95246856 100644 --- a/allensdk/internal/morphology/compartment.py +++ b/allensdk/internal/morphology/compartment.py @@ -15,6 +15,7 @@ import allensdk.internal.morphology.node as node + class Compartment(object): def __init__(self, node1, node2): if not isinstance(node1, node.Node) or not isinstance(node2, node.Node): @@ -28,4 +29,3 @@ def __str__(self): s = "%s %f" % (str(self.center), self.length) s += "\n\t" + self.node1.short_string() + "\n\t" + self.node2.short_string() return s - diff --git a/allensdk/internal/morphology/morphology.py b/allensdk/internal/morphology/morphology.py index 84491ebdd4..5b5b104149 100644 --- a/allensdk/internal/morphology/morphology.py +++ b/allensdk/internal/morphology/morphology.py @@ -20,9 +20,9 @@ from allensdk.internal.morphology.compartment import Compartment -class Morphology( object ): - """ - Keep track of the list of nodes in a morphology and provide +class Morphology(object): + """ + Keep track of the list of nodes in a morphology and provide a few helper methods (soma, tree information, pruning, etc). """ @@ -31,21 +31,21 @@ class Morphology( object ): BASAL_DENDRITE = 3 APICAL_DENDRITE = 4 - NODE_TYPES = [ SOMA, AXON, BASAL_DENDRITE, APICAL_DENDRITE ] + NODE_TYPES = [SOMA, AXON, BASAL_DENDRITE, APICAL_DENDRITE] def __init__(self, node_list=None): - """ + """ Try to initialize from a list of nodes first, then from a dictionary indexed by node id if that fails, and finally just leave everything empty. - + Parameters ---------- - node_list: list + node_list: list list of Node objects """ - self._node_list = [] # list of morphology node IDs - self._compartment_list = [] # list of morphology compartment IDs + self._node_list = [] # list of morphology node IDs + self._compartment_list = [] # list of morphology compartment IDs ############################################## # define tree list here for clarity, even though it's reset below @@ -58,7 +58,7 @@ def __init__(self, node_list=None): # NOTE: if morphology is manually manipulated, this value can # become incorrect self.dims = None - + ############################################## # construct the node list # first try to do so using the node list, then try using @@ -76,20 +76,19 @@ def __init__(self, node_list=None): # and construct internal associations self._reconstruct() - #################################################################### #################################################################### # class properties, and helper functions for them - @property + @property def node_list(self): - """ Return the node list. This is a property to ensure that the - node list and node index are in sync. """ + """Return the node list. This is a property to ensure that the + node list and node index are in sync.""" return self._node_list @node_list.setter def node_list(self, node_list): - """ Update the node list. """ + """Update the node list.""" self._set_nodes(node_list) @property @@ -98,14 +97,14 @@ def compartment_list(self): @property def num_trees(self): - """ Return the number of trees in the morphology. A tree is - defined as everything following from a single root node. """ + """Return the number of trees in the morphology. A tree is + defined as everything following from a single root node.""" return len(self._tree_list) @property def num_nodes(self): - """ - Return the number of nodes in the morphology. + """ + Return the number of nodes in the morphology. """ return len(self.node_list) @@ -117,7 +116,7 @@ def _set_nodes(self, node_list): file while also being able to initialize from the node list of an existing Morphology object. As nodes in a morphology object contain reference to nodes in that object, make a shallow copy - of input nodes and overwrite known references (ie, the + of input nodes and overwrite known references (ie, the 'children' array) """ self._node_list = [] @@ -133,10 +132,10 @@ def _set_nodes(self, node_list): # a soma can consist of multiple compartments, and there can # be multiple roots # replaced those calls with new soma_root(), which returns the - # same thing as the old soma() when only a single soma node is + # same thing as the old soma() when only a single soma node is # present def soma_root(self): - """ Returns root node of soma, if present""" + """Returns root node of soma, if present""" if len(self._tree_list) > 0 and self._tree_list[0][0].t == 1: return self._tree_list[0][0] return None @@ -146,16 +145,16 @@ def soma_root(self): # tree and node access def tree(self, n): - """ + """ Returns a list of all Morphology nodes within the specified tree. A tree is defined as a fully connected graph of nodes. Each tree has exactly one root. - + Parameters ---------- n: integer ID of desired tree - + Returns ------- A list of all morphology objects in the specified tree, or None @@ -165,36 +164,34 @@ def tree(self, n): return None return self._tree_list[n] - def node(self, n): """ Returns the morphology node having the specified ID. - + Parameters ---------- n: integer ID of desired node - + Returns ------- A morphology node having the specified ID, or None if such a node doesn't exist """ # undocumented feature -- if a node is supplied instead of a - # node ID, the node is returned and no error is + # node ID, the node is returned and no error is # triggered return self._resolve_node_type(n) - def compartment(self, n): - """ + """ Returns the morphology Compartment having the specified ID. - + Parameters ---------- n: integer ID of desired compartment - + Returns ------- A morphology object having the specified ID, or None if such a @@ -204,22 +201,21 @@ def compartment(self, n): return None return self._compartment_list[n] - def parent_of(self, seg): - """ Returns parent of the specified node. - + """Returns parent of the specified node. + Parameters ---------- seg: integer or Morphology Object The ID of the child node, or the child node itself - + Returns ------- A morphology object, or None if no parent exists or if the specified node ID doesn't exist """ # if ID passed in, make sure it's converted to a node - # don't trap for exception here -- if supplied segment is + # don't trap for exception here -- if supplied segment is # incorrect, make sure the user knows about it seg = self._resolve_node_type(seg) # return parent of specified node @@ -227,23 +223,21 @@ def parent_of(self, seg): return self._node_list[seg.parent] return None - def children_of(self, seg): - """ Returns a list of the children of the specified node - + """Returns a list of the children of the specified node + Parameters ---------- seg: integer or Morphology Object The ID of the parent node, or the parent node itself - + Returns ------- A list of the child morphology objects. If the ID of the parent node is invalid, None is returned. """ seg = self._resolve_node_type(seg) - return [ self._node_list[c] for c in seg.children ] - + return [self._node_list[c] for c in seg.children] def to_dict(self): """ @@ -251,8 +245,7 @@ def to_dict(self): of the Morphology. Modifying them will not modify anything in the Morphology itself. """ - return { c.n: c.to_dict() for c in self._node_list } - + return {c.n: c.to_dict() for c in self._node_list} ################################################################### ################################################################### @@ -274,20 +267,19 @@ def _resolve_node_type(self, seg): raise TypeError("Object not recognized as morphology Node or index") return seg - def change_parent(self, child, parent): - """ Change the parent of a node. The child node is adjusted to - point to the new parent, the child is taken off of the previous + """Change the parent of a node. The child node is adjusted to + point to the new parent, the child is taken off of the previous parent's child list, and it is added to the new parent's child list. - + Parameters ---------- child: integer or Morphology Object The ID of the child node, or the child node itself - + parent: integer or Morphology Object The ID of the parent node, or the parent node itself - + Returns ------- Nothing @@ -300,18 +292,18 @@ def change_parent(self, child, parent): old_par.children.remove(child_seg.n) parent_seg.children.append(child_seg.n) child_seg.parent = parent_seg.n - + def get_dimensions(self): - """ Returns tuple of overall width, height and depth of - morphology. - WARNING: if locations of nodes in morphology are manipulated - then this value can become incorrect. It can be reset and - recalculated by programmitcally setting self.dims to None. + """Returns tuple of overall width, height and depth of + morphology. + WARNING: if locations of nodes in morphology are manipulated + then this value can become incorrect. It can be reset and + recalculated by programmitcally setting self.dims to None. - Returns - ------- - 3 real arrays: [width, height, depth], [min_x, min_y, min_z], - [max_x, max_y, max_z] + Returns + ------- + 3 real arrays: [width, height, depth], [min_x, min_y, min_z], + [max_x, max_y, max_z] """ if self.dims is None: min_x = self.node_list[0].x @@ -328,27 +320,31 @@ def get_dimensions(self): min_x = min(node.x, min_x) min_y = min(node.y, min_y) min_z = min(node.z, min_z) - self.dims = [(max_x-min_x), (max_y-min_y), (max_z-min_z)], [min_x, min_y, min_z], [max_x, max_y, max_z] + self.dims = ( + [(max_x - min_x), (max_y - min_y), (max_z - min_z)], + [min_x, min_y, min_z], + [max_x, max_y, max_z], + ) return self.dims # returns a list of node located within dist of x,y,z def find(self, x, y, z, dist, node_type=None): - """ Returns a list of Morphology Objects located within 'dist' - of coordinate (x,y,z). If node_type is specified, the search + """Returns a list of Morphology Objects located within 'dist' + of coordinate (x,y,z). If node_type is specified, the search will be constrained to return only nodes of that type. - + Parameters ---------- x, y, z: float The x,y,z coordinates from which to search around - + dist: float The search radius - + node_type: enum (optional) - One of the following constants: SOMA, AXON, + One of the following constants: SOMA, AXON, BASAL_DENDRITE or APICAL_DENDRITE - + Returns ------- A list of all Morphology Objects matching the search criteria @@ -358,31 +354,29 @@ def find(self, x, y, z, dist, node_type=None): dx = seg.x - x dy = seg.y - y dz = seg.z - z - if math.sqrt(dx*dx + dy*dy + dz*dz) <= dist: + if math.sqrt(dx * dx + dy * dy + dz * dz) <= dist: if node_type is None or seg.t == node_type: found.append(seg) return found - def node_list_by_type(self, node_type): - """ Return an list of all nodes having the specified + """Return an list of all nodes having the specified node type. - + Parameters ---------- node_type: int Desired node type - + Returns ------- A list of of Morphology Objects """ return [x for x in self._node_list if x.t == node_type] - def save(self, file_name): - """ Write this morphology out to an SWC file - + """Write this morphology out to an SWC file + Parameters ---------- file_name: string @@ -399,23 +393,21 @@ def save(self, file_name): f.write("%d\n" % seg.parent) f.close() - # keep for backward compatibility, but don't publish in docs def write(self, file_name): self.save(file_name) - def sparsify(self, modulo): - """ Return a new Morphology object that has a given number of non-leaf, + """Return a new Morphology object that has a given number of non-leaf, non-root nodes removed. - + Parameters ---------- modulo: int keep 1 out of every modulo nodes. - + Returns - ------- + ------- Morphology A new morphology instance """ @@ -424,13 +416,15 @@ def sparsify(self, modulo): nodes = copy.deepcopy(self.node_list) # figure out which nodes to toss keep = {} - ctr = 0 # mod counter -- keep every modulo element (starting w/ 1st) + ctr = 0 # mod counter -- keep every modulo element (starting w/ 1st) for seg in nodes: nid = seg.n - if (seg.parent < 0 or - len(seg.children) != 1 or - nodes[seg.parent].t == Morphology.SOMA or - seg.t == Morphology.SOMA): + if ( + seg.parent < 0 + or len(seg.children) != 1 + or nodes[seg.parent].t == Morphology.SOMA + or seg.t == Morphology.SOMA + ): keep[nid] = True else: if ctr % modulo == 0: @@ -453,19 +447,18 @@ def sparsify(self, modulo): sparse.append(seg) return Morphology(sparse) - #################################################################### #################################################################### - + def _reconstruct(self): """ - Internal function. - Restructures data and establishes appropriate internal linking. - Data is re-order, removing 'holes' in the ID sequence so that - each object ID corresponds to its position in node list. + Internal function. + Restructures data and establishes appropriate internal linking. + Data is re-order, removing 'holes' in the ID sequence so that + each object ID corresponds to its position in node list. Dictionaries mapping IDs to objects are no longer necessary. Trees are (re)calculated - Parent-child indices are recalculated + Parent-child indices are recalculated A new compartment list is created """ remap = {} @@ -520,11 +513,10 @@ def _reconstruct(self): endpoint.compartment_id = len(self._compartment_list) self._compartment_list.append(compartment) - def append(self, nodes): - """ Add additional nodes to this Morphology. Those nodes must + """Add additional nodes to this Morphology. Those nodes must originate from another morphology object. - + Parameters ---------- nodes: list of Morphology nodes @@ -550,31 +542,29 @@ def append(self, nodes): node.parent = remap[node.parent] self._reconstruct() - def convert_type(self, from_type, to_type): - """ Convert all nodes in morphology from one type to another + """Convert all nodes in morphology from one type to another Parameters ---------- from_type: enum The node type that will be eliminated and replaced. - Use one of the following constants: SOMA, AXON, + Use one of the following constants: SOMA, AXON, BASAL_DENDRITE, or APICAL_DENDRITE to_type: enum - The new type that will replace it. - Use one of the following constants: SOMA, AXON, + The new type that will replace it. + Use one of the following constants: SOMA, AXON, BASAL_DENDRITE, or APICAL_DENDRITE """ for node in self.node_list: if node.t == from_type: node.t = to_type - def stumpify_axon(self, count=10): - """ Remove all axon nodes except the first 'count' + """Remove all axon nodes except the first 'count' nodes, as counted from the connected axon root. - + Parameters ---------- count: Integer @@ -599,7 +589,7 @@ def stumpify_axon(self, count=10): for i in range(count): # ignore bifurcations -- go 'count' deep on one line only ax.flag = i - #ax["flag"] = i + # ax["flag"] = i children = ax.children if len(children) > 0: ax = self.node(children[0]) @@ -607,14 +597,13 @@ def stumpify_axon(self, count=10): for i in range(len(self.node_list)): seg = self.node_list[i] if seg.t == Morphology.AXON: - #if "flag" not in seg: + # if "flag" not in seg: if seg.flag is None: self.node_list[i] = None self._reconstruct() - def _strip(self, flagged_for_removal): - """ Internal function with code common between + """Internal function with code common between strip_all_other_types() and strip_type() """ # if parent will be stripped and node will remain, convert @@ -643,23 +632,22 @@ def _strip(self, flagged_for_removal): seg.parent = -1 self._reconstruct() - # strip out everything but the soma and the specified SWC type def strip_all_other_types(self, node_type, keep_soma=True): - """ Strips everything from the morphology except for the + """Strips everything from the morphology except for the specified type. Parent and child relationships are updated accordingly, creating new roots when necessary. - + Parameters ---------- node_type: enum - The node type to keep in the morphology. - Use one of the following constants: SOMA, AXON, + The node type to keep in the morphology. + Use one of the following constants: SOMA, AXON, BASAL_DENDRITE, or APICAL_DENDRITE - + keep_soma: Boolean (optional) - True (default) if soma nodes should remain in the + True (default) if soma nodes should remain in the morpyhology, and False if the soma should also be stripped """ flagged_for_removal = {} @@ -678,19 +666,18 @@ def strip_all_other_types(self, node_type, keep_soma=True): flagged_for_removal[seg.n] = False self._strip(flagged_for_removal) - # strip out the specified SWC type def strip_type(self, node_type): - """ Strips all nodes of the specified type from the - morphology. + """Strips all nodes of the specified type from the + morphology. Parent and child relationships are updated accordingly, creating new roots when necessary. - + Parameters ---------- node_type: enum The node type to strip from the morphology. - Use one of the following constants: SOMA, AXON, + Use one of the following constants: SOMA, AXON, BASAL_DENDRITE, or APICAL_DENDRITE """ flagged_for_removal = {} @@ -704,36 +691,33 @@ def strip_type(self, node_type): else: flagged_for_removal[seg.n] = False self._strip(flagged_for_removal) - def clone(self): - """ Create a clone (deep copy) of this morphology - """ + """Create a clone (deep copy) of this morphology""" return copy.deepcopy(self) - def apply_affine_only_rotation(self, aff): - """ Apply an affine transform to all nodes in this + """Apply an affine transform to all nodes in this morphology. Only the rotation element of the transform is - performed (i.e., although the entire transformation and + performed (i.e., although the entire transformation and translation matrix is supplied, only the rotation element is used). The morphology is translated to the point where the soma root is at 0,0,0. - + Format of the affine matrix is: - + [x0 y0 z0] [tx] [x1 y1 z1] [ty] [x2 y2 z2] [tz] - - where the left 3x3 the matrix defines the affine rotation + + where the left 3x3 the matrix defines the affine rotation and scaling, and the right column is the translation vector. The matrix must be collapsed and stored in a list as follows: - + [x0 y0, z0, x1, y1, z1, x2, y2, z2, tx, ty, tz] - + Parameters ---------- aff: 3x4 array of floats (python 2D list, or numpy 2D array) @@ -758,39 +742,39 @@ def apply_affine_only_rotation(self, aff): affine[8] /= scale_z # apply rotation for seg in self.node_list: - x = seg.x*affine[0] + seg.y*affine[1] + seg.z*affine[2] - y = seg.x*affine[3] + seg.y*affine[4] + seg.z*affine[5] - z = seg.x*affine[6] + seg.y*affine[7] + seg.z*affine[8] + x = seg.x * affine[0] + seg.y * affine[1] + seg.z * affine[2] + y = seg.x * affine[3] + seg.y * affine[4] + seg.z * affine[5] + z = seg.x * affine[6] + seg.y * affine[7] + seg.z * affine[8] seg.x = x seg.y = y seg.z = z -# # relocate back to zero -# soma = self.soma_root() -# if soma is not None: -# for seg in self.node_list: -# seg.x -= soma.x -# seg.y -= soma.y -# seg.z -= soma.z + # # relocate back to zero + # soma = self.soma_root() + # if soma is not None: + # for seg in self.node_list: + # seg.x -= soma.x + # seg.y -= soma.y + # seg.z -= soma.z def apply_affine(self, aff, scale=None): - """ Apply an affine transform to all nodes in this + """Apply an affine transform to all nodes in this morphology. Compartment radius is adjusted as well. - + Format of the affine matrix is: - + [x0 y0 z0] [tx] [x1 y1 z1] [ty] [x2 y2 z2] [tz] - - where the left 3x3 the matrix defines the affine rotation + + where the left 3x3 the matrix defines the affine rotation and scaling, and the right column is the translation vector. - + The matrix must be collapsed and stored in a list as follows: - + [x0 y0, z0, x1, y1, z1, x2, y2, z2, tx, ty, tz] - + Parameters ---------- aff: 3x4 array of floats (python 2D list, or numpy 2D array) @@ -800,7 +784,7 @@ def apply_affine(self, aff, scale=None): # nodes, the radius of each node must be adjusted. # There are 2 ways to measure scale from a transform. Assuming # an isotropic transform, the scale is the cube root of the - # matrix determinant. The other ways is to measure scale + # matrix determinant. The other ways is to measure scale # independently along each axis. # For now, the node radius is only updated based on the average # scale along all 3 axes (eg, isotropic assumption), so calculate @@ -814,24 +798,23 @@ def apply_affine(self, aff, scale=None): # scale factor det_scale = np.power(abs(determinant), 1.0 / 3.0) ## measure scale along each axis - ## keep this code here in case - #scale_x = abs(aff[0] + aff[3] + aff[6]) - #scale_y = abs(aff[1] + aff[4] + aff[7]) - #scale_z = abs(aff[2] + aff[5] + aff[8]) - #avg_scale = (scale_x + scale_y + scale_z) / 3.0; + ## keep this code here in case + # scale_x = abs(aff[0] + aff[3] + aff[6]) + # scale_y = abs(aff[1] + aff[4] + aff[7]) + # scale_z = abs(aff[2] + aff[5] + aff[8]) + # avg_scale = (scale_x + scale_y + scale_z) / 3.0; # # use determinant for scaling for now as it's most simple scale = det_scale for seg in self.node_list: - x = seg.x*aff[0] + seg.y*aff[1] + seg.z*aff[2] + aff[9] - y = seg.x*aff[3] + seg.y*aff[4] + seg.z*aff[5] + aff[10] - z = seg.x*aff[6] + seg.y*aff[7] + seg.z*aff[8] + aff[11] + x = seg.x * aff[0] + seg.y * aff[1] + seg.z * aff[2] + aff[9] + y = seg.x * aff[3] + seg.y * aff[4] + seg.z * aff[5] + aff[10] + z = seg.x * aff[6] + seg.y * aff[7] + seg.z * aff[8] + aff[11] seg.x = x seg.y = y seg.z = z seg.radius *= scale - def _separate_trees(self): """ Construct list of independent trees (each tree has a root of -1). @@ -856,14 +839,14 @@ def _separate_trees(self): # figure out which tree to put node into # if there are muliple possibilities, merge all of them if len(local_trees) == 0: - tree_num = len(trees) # create new tree + tree_num = len(trees) # create new tree elif len(local_trees) == 1: - tree_num = local_trees[0] # use existing tree + tree_num = local_trees[0] # use existing tree elif len(local_trees) > 1: # this node is an intersection of multiple trees # merge all trees into the first one found tree_num = local_trees[0] - for j in range(1,len(local_trees)): + for j in range(1, len(local_trees)): dead_tree = local_trees[j] trees[dead_tree] = [] for node in self.node_list: @@ -896,7 +879,6 @@ def _separate_trees(self): # reset node tree_id to correct tree number self._reset_tree_ids() - def _reset_tree_ids(self): """ reset each node's tree_id value to the correct tree number @@ -905,7 +887,6 @@ def _reset_tree_ids(self): for j in range(len(self._tree_list[i])): self._tree_list[i][j].tree_id = i - def _check_consistency(self): """ internal function -- don't publish in the docs @@ -954,7 +935,6 @@ def _check_consistency(self): print("Failed consistency check: %d errors encountered" % errs) return errs - def _find_type_boundary(self): """ return a list of segments who have parents that are a different type @@ -968,11 +948,10 @@ def _find_type_boundary(self): adoptees.append(node) return adoptees - # remove tree from swc's "forest" def delete_tree(self, n): - """ Delete tree, and all of its nodes, from the morphology. - + """Delete tree, and all of its nodes, from the morphology. + Parameters ---------- n: Integer @@ -991,11 +970,9 @@ def delete_tree(self, n): # reset node tree_id to correct tree number self._reset_tree_ids() - def _print_all_nodes(self): """ debugging function. prints all nodes """ for node in self.node_list: print(node.short_string()) - diff --git a/allensdk/internal/morphology/morphvis.py b/allensdk/internal/morphology/morphvis.py index 2b1780306b..58c572297c 100644 --- a/allensdk/internal/morphology/morphvis.py +++ b/allensdk/internal/morphology/morphvis.py @@ -1,6 +1,7 @@ from PIL import Image, ImageDraw import numpy as np + class MorphologyColors(object): def __init__(self): self.soma = (0, 0, 0) @@ -24,38 +25,38 @@ def set_apical_color(self, r, g, b): # create empty image def create_image(w, h, color=None, alpha=False): if alpha: - mode = 'RGBA' + mode = "RGBA" else: - mode = 'RGB' + mode = "RGB" if color is not None: - return Image.new(mode, (w,h), color) + return Image.new(mode, (w, h), color) else: - return Image.new(mode, (w,h)) + return Image.new(mode, (w, h)) def calculate_scale(morph, pix_width, pix_height): - """ Calculates scaling factor and x,y insets required to auto-scale - and center morphology into box with specified numbers of pixels + """Calculates scaling factor and x,y insets required to auto-scale + and center morphology into box with specified numbers of pixels - Parameters - ---------- + Parameters + ---------- - morph: AISDK Morphology object + morph: AISDK Morphology object - pix_width: int - Number of image pixels on X axis + pix_width: int + Number of image pixels on X axis - pix_height: int - Number of image pixels on Y axis + pix_height: int + Number of image pixels on Y axis - Returns - ------- - real, real, real - First return value is the scaling factor. Second is the - number of pixels needed to adjust x-coordinates so that the - morphology is horizontally centered. Third is the number of - pixels needed to adjust the y-coordinates so that the morphology - is vertically centered. + Returns + ------- + real, real, real + First return value is the scaling factor. Second is the + number of pixels needed to adjust x-coordinates so that the + morphology is horizontally centered. Third is the number of + pixels needed to adjust the y-coordinates so that the morphology + is vertically centered. """ dims, low, high = morph.get_dimensions() # get boundaries of morphology @@ -77,75 +78,83 @@ def calculate_scale(morph, pix_width, pix_height): v_center = (ylow + yhigh) / 2.0 scale_inset_x = -low[0] * scale_factor # invert y coordinates for conversion to pixel space - scale_inset_y = pix_height/2 + scale_factor*v_center + scale_inset_y = pix_height / 2 + scale_factor * v_center else: # image constrained on vertical axis scale_factor = vscale # center image horizontally h_center = (xlow + xhigh) / 2.0 - scale_inset_x = pix_width/2 - scale_factor*h_center + scale_inset_x = pix_width / 2 - scale_factor * h_center scale_inset_y = -low[1] * scale_factor return scale_factor, scale_inset_x, scale_inset_y # draw morphology on image -- takes image and morphology, modifies image # options: scale to fit | linear scaling -def draw_morphology(img, morph, - inset_left=0, inset_right=0, inset_top=0, inset_bottom=0, - scale_to_fit=False, scale_factor=1.0, colors=None): - """ Draws morphology onto image - When no scaling is applied, and no insets are provided, the - coordinates of the morphology are used directly -- i.e., 100 in - morphology coordinates is equal to 100 pixels. - - The scale factor is multiplied to morphology coordinates before - being drawn. If scale_factor=2 then 50 in morphology coordinates - is 100 pixels. Left and top insets shift the coordinate axes - for drawing. E.g., if left=10 and top=5 then 0,0 in morphology - coordinates is 10,5 in pixel space. Bottom and right insets are - ignored. - - If scale_to_fit is set then scale factor is ignored. The - morphology is scaled to be the maximum size that fits in - the image, taking into account insets. In a 100x100 image, if - all insets=10, then the image is scaled to fit into the center - 80x80 pixel area, and nothing is drawn in the inset border areas. - - Axons are drawn before soma and dendrite compartments. - - - Parameters - ---------- - - img: PIL image object - - morph: AISDK Morphology object - - inset_*: real - This is the number of pixels to use as border on top/bottom/ - right/left. If scale_to_fit is false then only the top/left - values are used, as the scale_factor will determine how - large the morphology is (it can be drawn beyond insets and even - beyond image boundaries) - - scale_to_fit: boolean - If true then morphology is scaled to the inset area of the - image and scale_factor is ignored. Morphology is centered - in the image in the sense that the top/bottom and left/right - edges of the morphology are equidistant from image borders. - - scale_factor: real - A scalar amount that is multiplied to morphology coordinates - before drawing - - colors: MorphologyColors object - This is the color scheme used to draw the morphology. If - colors=None then default coloring is used - - Returns - ------- - - 2-dimensional array, the pixel coordinates of the soma root [x,y] +def draw_morphology( + img, + morph, + inset_left=0, + inset_right=0, + inset_top=0, + inset_bottom=0, + scale_to_fit=False, + scale_factor=1.0, + colors=None, +): + """Draws morphology onto image + When no scaling is applied, and no insets are provided, the + coordinates of the morphology are used directly -- i.e., 100 in + morphology coordinates is equal to 100 pixels. + + The scale factor is multiplied to morphology coordinates before + being drawn. If scale_factor=2 then 50 in morphology coordinates + is 100 pixels. Left and top insets shift the coordinate axes + for drawing. E.g., if left=10 and top=5 then 0,0 in morphology + coordinates is 10,5 in pixel space. Bottom and right insets are + ignored. + + If scale_to_fit is set then scale factor is ignored. The + morphology is scaled to be the maximum size that fits in + the image, taking into account insets. In a 100x100 image, if + all insets=10, then the image is scaled to fit into the center + 80x80 pixel area, and nothing is drawn in the inset border areas. + + Axons are drawn before soma and dendrite compartments. + + + Parameters + ---------- + + img: PIL image object + + morph: AISDK Morphology object + + inset_*: real + This is the number of pixels to use as border on top/bottom/ + right/left. If scale_to_fit is false then only the top/left + values are used, as the scale_factor will determine how + large the morphology is (it can be drawn beyond insets and even + beyond image boundaries) + + scale_to_fit: boolean + If true then morphology is scaled to the inset area of the + image and scale_factor is ignored. Morphology is centered + in the image in the sense that the top/bottom and left/right + edges of the morphology are equidistant from image borders. + + scale_factor: real + A scalar amount that is multiplied to morphology coordinates + before drawing + + colors: MorphologyColors object + This is the color scheme used to draw the morphology. If + colors=None then default coloring is used + + Returns + ------- + + 2-dimensional array, the pixel coordinates of the soma root [x,y] """ # determine drawing area, scaling factor, offset to origin # if scaling to fit, find value that scales morphology so height @@ -176,21 +185,21 @@ def draw_morphology(img, morph, for i in range(3): for comp in morph.compartment_list: if comp.node2.t == 1: - if i != 2: # soma drawn last + if i != 2: # soma drawn last # NOTE: there are unlikely to be soma compartments # additional soma-drawing code is below continue color = colors.soma elif comp.node2.t == 2: - if i != 0: # axon drawn first + if i != 0: # axon drawn first continue color = colors.axon elif comp.node2.t == 3: - if i != 1: # dendrite drawn second + if i != 1: # dendrite drawn second continue color = colors.basal elif comp.node2.t == 4: - if i != 1: # dendrite drawn second + if i != 1: # dendrite drawn second continue color = colors.apical x0 = scale_inset_x + inset_left + scale_factor * comp.node1.x @@ -211,72 +220,72 @@ def draw_morphology(img, morph, rad = scale_factor * root.radius x0 = int(x - rad) y0 = int(y - rad) - x1 = int(x0 + 2*rad) - y1 = int(y0 + 2*rad) - canvas.ellipse((x0,y0,x1,y1), fill=colors.soma, outline=colors.soma) + x1 = int(x0 + 2 * rad) + y1 = int(y0 + 2 * rad) + canvas.ellipse((x0, y0, x1, y1), fill=colors.soma, outline=colors.soma) # return soma root coordinate, in unit of pixels return [x, y] -def draw_density_hist(img, morph, vert_scale, - inset_left=0, inset_right=0, inset_top=0, inset_bottom=0, - num_bins=None, colors=None): - """ Draws density histogram onto image - When no scaling is applied, and no insets are provided, the - coordinates of the morphology are used directly -- i.e., 100 in - morphology coordinates is equal to 100 pixels. - - The scale factor is multiplied to morphology coordinates before - being drawn. If scale_factor=2 then 50 in morphology coordinates - is 100 pixels. Left and top insets shift the coordinate axes - for drawing. E.g., if left=10 and top=5 then 0,0 in morphology - coordinates is 10,5 in pixel space. Bottom and right insets are - ignored. - - If scale_to_fit is set then scale factor is ignored. The - morphology is scaled to be the maximum size that fits in - the image, taking into account insets. In a 100x100 image, if - all insets=10, then the image is scaled to fit into the center - 80x80 pixel area, and nothing is drawn in the inset border areas. - - Axons are drawn before soma and dendrite compartments. - - - Parameters - ---------- - - img: PIL image object - - morph: AISDK Morphology object - - vert_scale: real - This is the amout required to multiply to a moprhology - y-coordinate to convert it to relative cortical depth (on [0,1]). - This is the inverse of the cortical thickness. - - inset_*: real - This is the number of pixels to use as border on top/bottom/ - right/left. If scale_to_fit is false then only the top/left - values are used, as the scale_factor will determine how - large the morphology is (it can be drawn beyond insets and even - beyond image boundaries) - - num_bins: int - The number of bins in the histogram - - colors: MorphologyColors object - This is the color scheme used to draw the morphology. If - colors=None then default coloring is used - - Returns - ------- - - Histogram arrays: [hist, hist2, hist3, hist4] - where hist is the histgram of all neurites, and hist[234] are - the histograms of SWC types 2,3,4 +def draw_density_hist( + img, morph, vert_scale, inset_left=0, inset_right=0, inset_top=0, inset_bottom=0, num_bins=None, colors=None +): + """Draws density histogram onto image + When no scaling is applied, and no insets are provided, the + coordinates of the morphology are used directly -- i.e., 100 in + morphology coordinates is equal to 100 pixels. + + The scale factor is multiplied to morphology coordinates before + being drawn. If scale_factor=2 then 50 in morphology coordinates + is 100 pixels. Left and top insets shift the coordinate axes + for drawing. E.g., if left=10 and top=5 then 0,0 in morphology + coordinates is 10,5 in pixel space. Bottom and right insets are + ignored. + + If scale_to_fit is set then scale factor is ignored. The + morphology is scaled to be the maximum size that fits in + the image, taking into account insets. In a 100x100 image, if + all insets=10, then the image is scaled to fit into the center + 80x80 pixel area, and nothing is drawn in the inset border areas. + + Axons are drawn before soma and dendrite compartments. + + + Parameters + ---------- + + img: PIL image object + + morph: AISDK Morphology object + + vert_scale: real + This is the amout required to multiply to a moprhology + y-coordinate to convert it to relative cortical depth (on [0,1]). + This is the inverse of the cortical thickness. + + inset_*: real + This is the number of pixels to use as border on top/bottom/ + right/left. If scale_to_fit is false then only the top/left + values are used, as the scale_factor will determine how + large the morphology is (it can be drawn beyond insets and even + beyond image boundaries) + + num_bins: int + The number of bins in the histogram + + colors: MorphologyColors object + This is the color scheme used to draw the morphology. If + colors=None then default coloring is used + + Returns + ------- + + Histogram arrays: [hist, hist2, hist3, hist4] + where hist is the histgram of all neurites, and hist[234] are + the histograms of SWC types 2,3,4 """ - # if number of bins not specified, default to vertical size of + # if number of bins not specified, default to vertical size of # drawing area in image img_width, img_height = img.size draw_width = img_width - (inset_left + inset_right) @@ -298,14 +307,14 @@ def draw_density_hist(img, morph, vert_scale, canvas = ImageDraw.Draw(img) - # for each compartment, split its length (weight) between bins for + # for each compartment, split its length (weight) between bins for # start and end nodes for seg in morph.compartment_list: wt = seg.length / 2.0 # node 1 bin1 = int(-vert_scale * seg.node1.y * num_bins) if bin1 == num_bins: - bin1 = num_bins-1 + bin1 = num_bins - 1 # only include parts of the histogram that are in the viewable # range (ie, exclude parts that extend to pia and/or wm) if bin1 >= 0 and bin1 < num_bins: @@ -319,7 +328,7 @@ def draw_density_hist(img, morph, vert_scale, # node 2 bin2 = int(-vert_scale * seg.node1.y * num_bins) if bin2 == num_bins: - bin2 = num_bins-1 + bin2 = num_bins - 1 if bin2 >= 0 and bin2 < num_bins: if seg.node2.t == 2: hist_2[bin2] += wt @@ -334,11 +343,11 @@ def draw_density_hist(img, morph, vert_scale, x0 = inset_left x1 = x0 y0 = inset_top - y1 = img_height-inset_bottom + y1 = img_height - inset_bottom canvas.line((x0, y0, x1, y1), col) hist_step = 1.0 * draw_height / num_bins - hist_scale = (draw_width-1) / hist.max() + hist_scale = (draw_width - 1) / hist.max() for seg in morph.compartment_list: ypos = 1.0 * inset_top for i in range(num_bins): @@ -369,11 +378,12 @@ def draw_density_hist(img, morph, vert_scale, return hist, hist_2, hist_3, hist_4 + # TODO # draw path on image -- takes image and path, modifies image -#def draw_path(img, path, color): - # path is in units of pixels - # foreach vertex, draw line in specified color +# def draw_path(img, path, color): +# path is in units of pixels +# foreach vertex, draw line in specified color # refine section boundaries -- takes labeled regions and generates labeled mask # -> need constraints on input. lookup of points is easy. categorizing @@ -381,5 +391,3 @@ def draw_density_hist(img, morph, vert_scale, # label morphology -- takes morpholgoy and labeled mask and adds lables # to each compartment - - diff --git a/allensdk/internal/morphology/node.py b/allensdk/internal/morphology/node.py index 201cb828dc..d7b1779556 100644 --- a/allensdk/internal/morphology/node.py +++ b/allensdk/internal/morphology/node.py @@ -16,11 +16,12 @@ import json import math + def euclidean_distance(node1, node2): dx = node1.x - node2.x dy = node1.y - node2.y dz = node1.z - node2.z - return math.sqrt(dx*dx + dy*dy + dz*dz) + return math.sqrt(dx * dx + dy * dy + dz * dz) def midpoint(node1, node2): @@ -30,7 +31,7 @@ def midpoint(node1, node2): return [px, py, pz] -class Node(object): +class Node(object): """ Represents node in SWC morphology file """ @@ -68,15 +69,15 @@ def __init__(self, n, t, x, y, z, r, pn, **kwargs): self.z = z self.radius = r self.parent = pn - # + # self.children = [] # IDs of child nodes - self.tree_id = -1 # which unconnected graph this node belongs to + self.tree_id = -1 # which unconnected graph this node belongs to # number of compartment that has this node as its endpoint # all nodes except root nodes have a compartment - self.compartment_id = -1 + self.compartment_id = -1 def to_dict(self): - """ Convert the node into a serializable dictionary """ + """Convert the node into a serializable dictionary""" return { "id": self.n, "type": self.t, @@ -87,19 +88,19 @@ def to_dict(self): "parent": self.parent, "children": self.children, "tree_id": self.tree_id, - "compartment_id": self.compartment_id + "compartment_id": self.compartment_id, } @classmethod def from_dict(cls, d): return cls( - n = d["id"], - t = d["type"], - x = d["x"], - y = d["y"], - z = d["z"], - r = d["radius"], - pn = d["parent"], + n=d["id"], + t=d["type"], + x=d["x"], + y=d["y"], + z=d["z"], + r=d["radius"], + pn=d["parent"], ) def __getitem__(self, item): @@ -110,20 +111,29 @@ def __str__(self): return json.dumps(self.to_dict()) def short_string(self): - """ create string with node information in succinct, - single-line form """ - return "%d %d %.4f %.4f %.4f %.4f %d %s %d" % (self.n, self.t, self.x, self.y, self.z, self.radius, self.parent, - str(self.children), self.tree_id) + """create string with node information in succinct, + single-line form""" + return "%d %d %.4f %.4f %.4f %.4f %d %s %d" % ( + self.n, + self.t, + self.x, + self.y, + self.z, + self.radius, + self.parent, + str(self.children), + self.tree_id, + ) + # Morphology nodes have the following fields. These allow dictionary access # to node fields (this is for backward compatibility) -NODE_ID = 'id' -NODE_TYPE = 'type' -NODE_X = 'x' -NODE_Y = 'y' -NODE_Z = 'z' -NODE_R = 'radius' -NODE_PN = 'parent' -NODE_TREE_ID = 'tree_id' -NODE_CHILDREN = 'children' - +NODE_ID = "id" +NODE_TYPE = "type" +NODE_X = "x" +NODE_Y = "y" +NODE_Z = "z" +NODE_R = "radius" +NODE_PN = "parent" +NODE_TREE_ID = "tree_id" +NODE_CHILDREN = "children" diff --git a/allensdk/internal/morphology/validate_swc.py b/allensdk/internal/morphology/validate_swc.py index cba0e79ee1..b6ac083451 100755 --- a/allensdk/internal/morphology/validate_swc.py +++ b/allensdk/internal/morphology/validate_swc.py @@ -5,7 +5,7 @@ def resave_swc(orig_swc, new_file): - """ Reads SWC file into AllenSDK Morphology object and resaves + """Reads SWC file into AllenSDK Morphology object and resaves it. This can fix some problems in an SWC file that may disrupt other software tools reading the file (e.g., NEURON) @@ -41,14 +41,22 @@ def __init__(self, n, t, x, y, z, r, pn): self.children = [] # IDs of child nodes def __str__(self): - """ create string with node information in succinct, - single-line form """ - return "%d %d %.4f %.4f %.4f %.4f %d %s" % (self.n, self.t, self.x, self.y, self.z, self.r, self.pn, str(self.children)) - + """create string with node information in succinct, + single-line form""" + return "%d %d %.4f %.4f %.4f %.4f %d %s" % ( + self.n, + self.t, + self.x, + self.y, + self.z, + self.r, + self.pn, + str(self.children), + ) def validate_swc(swc_file): - """ + """ Tests SWC files for compatibility with AllenSDK To be compatible with NEURON, SWC files must have the following properties: @@ -91,37 +99,37 @@ def validate_swc(swc_file): success = False break - # if we've made it here, file is OK for using Morphology class, and + # if we've made it here, file is OK for using Morphology class, and # should be valid with internal processing. It may also be able # to be convertable for NEURON use by resaving it nodes = [] - node_table = [] # lookup table by node num + node_table = [] # lookup table by node num line_num = 1 try: with open(swc_file, "r") as f: for line in f: # remove comments - if line.lstrip().startswith('#'): + if line.lstrip().startswith("#"): continue # read values. expected SWC format is: # ID, type, x, y, z, rad, parent # x, y, z and rad are floats. the others are ints toks = line.split() vals = TestNode( - n = int(toks[0]), - t = int(toks[1]), - x = float(toks[2]), - y = float(toks[3]), - z = float(toks[4]), - r = float(toks[5]), - pn = int(toks[6].rstrip()), - ) + n=int(toks[0]), + t=int(toks[1]), + x=float(toks[2]), + y=float(toks[3]), + z=float(toks[4]), + r=float(toks[5]), + pn=int(toks[6].rstrip()), + ) # store this node while len(nodes) <= vals.n: nodes.append(None) nodes[vals.n] = vals - #nodes.append(vals) + # nodes.append(vals) # if vals.n < 0: print("Negative node ID not allowed") @@ -154,20 +162,24 @@ def validate_swc(swc_file): success = False # verify presence and number of soma and root nodes - num_soma_nodes = sum([ int(c is not None and c.t == 1) for c in nodes ]) + num_soma_nodes = sum([int(c is not None and c.t == 1) for c in nodes]) if num_soma_nodes == 0: print("SWC must have at least one soma node. Found: %d" % num_soma_nodes) print("----------------------------------") success = False elif num_soma_nodes > 1: - print("Warning: File has multiple soma nodes. This can interfere with feature analysis in some external software (e.g., vaa3d)") + print( + "Warning: File has multiple soma nodes. This can interfere with feature analysis in some external software (e.g., vaa3d)" + ) print("----------------------------------") - num_root_nodes = sum([ int(c is not None and c.pn == -1) for c in nodes ]) + num_root_nodes = sum([int(c is not None and c.pn == -1) for c in nodes]) # case of no root nodes covered by rule below that ID of child must # be greater than that of parent if num_root_nodes > 1: - print("Warning: File has multiple root nodes. This can interfere with feature analysis in some external software (e.g., vaa3d)") + print( + "Warning: File has multiple root nodes. This can interfere with feature analysis in some external software (e.g., vaa3d)" + ) print("----------------------------------") # get a list of all of the ids, make sure they are unique while we're at it @@ -188,13 +200,13 @@ def validate_swc(swc_file): success = False break all_ids.add(iid) - + # make sure that first root node is soma for n in nodes: if n is not None: root = n break - #root = nodes[0] + # root = nodes[0] if root.t != 1: # see if soma has a root if sum([int(c is not None and c.t == 2 and c.pn == -1) for c in nodes]) == 0: @@ -211,22 +223,21 @@ def validate_swc(swc_file): root_child = nodes[root_child_id] num_grand_children = len(root_child.children) if num_grand_children > 1: - print("Child of root (%s) has more than one child (%d)" % ( root_child_id, num_grand_children )) + print("Child of root (%s) has more than one child (%d)" % (root_child_id, num_grand_children)) print("----------------------------------") success = False # sort the ids and make sure there are no gaps sorted_ids = sorted(all_ids) for i in range(1, len(sorted_ids)): - if sorted_ids[i] - sorted_ids[i-1] != 1: + if sorted_ids[i] - sorted_ids[i - 1] != 1: print("Node IDs are not sequential") print("This can be fixed by calling resave_swc() on the file") print("----------------------------------") success = False return success - - - + + def main(): argc = len(sys.argv) if argc < 1: @@ -245,5 +256,7 @@ def main(): print(" FAIL") print(str(e)) exit(1) + + if __name__ == "__main__": main() diff --git a/allensdk/internal/mouse_connectivity/interval_unionize/cav_unionize.py b/allensdk/internal/mouse_connectivity/interval_unionize/cav_unionize.py index 5b4881764d..6c69c5091d 100644 --- a/allensdk/internal/mouse_connectivity/interval_unionize/cav_unionize.py +++ b/allensdk/internal/mouse_connectivity/interval_unionize/cav_unionize.py @@ -6,29 +6,26 @@ class CavUnionize(Unionize): - - __slots__ = ['sum_pixels', 'sum_cav_pixels'] + __slots__ = ["sum_pixels", "sum_cav_pixels"] def __init__(self, *args, **kwargs): for key in self.__slots__: setattr(self, key, 0) - def calculate(self, low, high, data_arrays): data_arrays = self.slice_arrays(low, high, data_arrays) - self.sum_pixels = data_arrays['sum_pixels'].sum() - self.sum_cav_pixels = np.multiply(data_arrays['sum_pixels'], data_arrays['cav_density']).sum() - + self.sum_pixels = data_arrays["sum_pixels"].sum() + self.sum_cav_pixels = np.multiply(data_arrays["sum_pixels"], data_arrays["cav_density"]).sum() def propagate(self, ancestor): ancestor.sum_pixels += self.sum_pixels ancestor.sum_cav_pixels += self.sum_cav_pixels return ancestor - def output(self, volume_scale, max_pixels): - return {'structure_volume': self.sum_pixels * volume_scale / max_pixels, - 'signal_volume': self.sum_cav_pixels * volume_scale / max_pixels, - 'signal_density': self.sum_cav_pixels / self.sum_pixels if self.sum_pixels > 0 else 0} - + return { + "structure_volume": self.sum_pixels * volume_scale / max_pixels, + "signal_volume": self.sum_cav_pixels * volume_scale / max_pixels, + "signal_density": self.sum_cav_pixels / self.sum_pixels if self.sum_pixels > 0 else 0, + } diff --git a/allensdk/internal/mouse_connectivity/interval_unionize/cav_unionizer.py b/allensdk/internal/mouse_connectivity/interval_unionize/cav_unionizer.py index b2beaa6462..b859f778e0 100755 --- a/allensdk/internal/mouse_connectivity/interval_unionize/cav_unionizer.py +++ b/allensdk/internal/mouse_connectivity/interval_unionize/cav_unionizer.py @@ -7,45 +7,37 @@ class CavUnionizer(IntervalUnionizer): - - @classmethod def record_cb(cls): return CavUnionize() - @classmethod def propagate_record(cls, child_record, ancestor_record, copy_all=False): return child_record.propagate(ancestor_record) - def extract_data(self, data_arrays, low, high): - '''As parent - ''' - + """As parent""" + unionize = self.__class__.record_cb() unionize.calculate(low, high, data_arrays) - - return unionize + return unionize def postprocess_unionizes(self, raw_unionizes, image_series_id, volume_scale, max_pixels): - unionizes = [] - - logging.info('getting formatted unionize output') + + logging.info("getting formatted unionize output") for sid, un in raw_unionizes.items(): - if sid < 0: - hemisphere = 'left' + hemisphere = "left" else: - hemisphere = 'right' + hemisphere = "right" sid = abs(sid) - + out = un.output(volume_scale, max_pixels) - out['structure_id'] = sid - out['hemisphere'] = hemisphere - out['image_series_id'] = image_series_id + out["structure_id"] = sid + out["hemisphere"] = hemisphere + out["image_series_id"] = image_series_id unionizes.append(out) diff --git a/allensdk/internal/mouse_connectivity/interval_unionize/data_utilities.py b/allensdk/internal/mouse_connectivity/interval_unionize/data_utilities.py index aed561f3f5..10bffa0a97 100755 --- a/allensdk/internal/mouse_connectivity/interval_unionize/data_utilities.py +++ b/allensdk/internal/mouse_connectivity/interval_unionize/data_utilities.py @@ -49,15 +49,11 @@ def get_sum_pixels(sum_pixels_path): return {"sum_pixels": read(sum_pixels_path)} -def get_sum_pixel_intensities( - sum_pixel_intensities_path, injection_sum_pixel_intensities_path -): +def get_sum_pixel_intensities(sum_pixel_intensities_path, injection_sum_pixel_intensities_path): logging.info("getting sum pixel intensities") return { "sum_pixel_intensities": read(sum_pixel_intensities_path), - "injection_sum_pixel_intensities": read( - injection_sum_pixel_intensities_path - ), + "injection_sum_pixel_intensities": read(injection_sum_pixel_intensities_path), } @@ -66,9 +62,7 @@ def get_cav_density(cav_density_path): return {"cav_density": read(cav_density_path)} -def get_injection_data( - injection_fraction_path, injection_density_path, injection_energy_path -): +def get_injection_data(injection_fraction_path, injection_density_path, injection_energy_path): """Read nrrd files containing injection signal data""" logging.info("getting injection_fraction") @@ -108,9 +102,7 @@ def get_projection_data( except (IOError, OSError, RuntimeError): logging.info("skipping aav exclusion fraction") - aav_exclusion_fraction = np.zeros( - projection_density.shape, dtype=bool, order="C" - ) + aav_exclusion_fraction = np.zeros(projection_density.shape, dtype=bool, order="C") return { "projection_density": projection_density, diff --git a/allensdk/internal/mouse_connectivity/interval_unionize/interval_unionizer.py b/allensdk/internal/mouse_connectivity/interval_unionize/interval_unionizer.py index 3c02150985..dab4c4d12a 100755 --- a/allensdk/internal/mouse_connectivity/interval_unionize/interval_unionizer.py +++ b/allensdk/internal/mouse_connectivity/interval_unionize/interval_unionizer.py @@ -5,145 +5,137 @@ import numpy as np -class IntervalUnionizer(object): - +class IntervalUnionizer(object): @classmethod def record_cb(cls): return defaultdict(lambda *a, **k: 0, {}) - def __init__(self, exclude_structure_ids=None): - '''Builds unionize records from grid data. Unionize records are - summaries of experimental observations occuring in a particular - spatial domain. Domains are generally specified by the intersection of - + """Builds unionize records from grid data. Unionize records are + summaries of experimental observations occuring in a particular + spatial domain. Domains are generally specified by the intersection of + 1. a brain structure 2. an injection polygon (or its inverse) 3. the left or right side of the brain - + Parameters ---------- exclude_structure_ids : list of int, optional - Don't generate records for these structures. Defaults to [0], + Don't generate records for these structures. Defaults to [0], which excludes everything not in the brain. - - ''' - + + """ + if exclude_structure_ids is None: exclude_structure_ids = [0] self.exclude_structure_ids = exclude_structure_ids - - + def setup_interval_map(self, annotation): - '''Build a map from structure ids to intervals in the sorted flattened - reference space. - + """Build a map from structure ids to intervals in the sorted flattened + reference space. + Parameters ---------- annotation : np.ndarray Segmentation label array. - - ''' - - logging.info('getting flat annotation') + + """ + + logging.info("getting flat annotation") flat_annot = annotation.flat - - logging.info('finding sort') + + logging.info("finding sort") self.sort = np.argsort(flat_annot) - - logging.info('sorting flat annotation') + + logging.info("sorting flat annotation") flat_annot = flat_annot[self.sort] - - logging.info('finding bounds') + + logging.info("finding bounds") diff = np.diff(flat_annot) bounds = np.nonzero(diff)[0] - uniques = [ flat_annot[ii] for ii in bounds ] + [flat_annot[-1]] - - logging.info('building map') + uniques = [flat_annot[ii] for ii in bounds] + [flat_annot[-1]] + + logging.info("building map") lower_bounds = [0] + (bounds + 1).tolist() upper_bounds = (bounds + 1).tolist() + [len(flat_annot)] - self.interval_map = {sid: item for sid, item - in zip(uniques, zip(lower_bounds, upper_bounds)) - if sid not in self.exclude_structure_ids} - - + self.interval_map = { + sid: item + for sid, item in zip(uniques, zip(lower_bounds, upper_bounds)) + if sid not in self.exclude_structure_ids + } + def extract_data(self, data_arrays, low, high, **kwargs): - '''Given flattened data arrays and a specified interval, generate + """Given flattened data arrays and a specified interval, generate summary data - + Parameters ---------- data_arrays : dict - Keys identify types of data volume. Values are flattened, sorted + Keys identify types of data volume. Values are flattened, sorted arrays. low : int Index at which interval of interest begins. Inclusive. high : int Index at which interval of interest ends. Exclusive. - - ''' - - raise NotImplementedError('specify in subclass!') - - + + """ + + raise NotImplementedError("specify in subclass!") + @classmethod def propagate_record(cls, child_record, ancestor_record, copy_all=False): - '''Updates one unionize corresponding to a rootward structure with + """Updates one unionize corresponding to a rootward structure with information from a unionize corresponding to a leafward structure - + Parameters ---------- child_record : unionize Data will be drawn from this record ancestor_record : unionize This record will be updated - - ''' - - raise NotImplementedError('specify in subclass!') - - + + """ + + raise NotImplementedError("specify in subclass!") + @classmethod def propagate_unionizes(cls, direct_unionizes, ancestor_id_map): - '''Structures are arranged in a tree, whose leafward-oriented edges - indicate physical containment. This method updates rootward unionize + """Structures are arranged in a tree, whose leafward-oriented edges + indicate physical containment. This method updates rootward unionize records with information from leafward ones. - + Parameters ---------- direct_unionizes : list of unionizes - Each entry is a unionize record produced from a collection of + Each entry is a unionize record produced from a collection of directly labeled voxels in the segmentation volume. ancestor_id_map : dict - Keys are structure ids. Values are ids of all structures rootward in + Keys are structure ids. Values are ids of all structures rootward in the tree, including the key node - + Returns ------- output_unionizes : list of unionizes - Contains completed unionize records at all depths in the structure + Contains completed unionize records at all depths in the structure tree - - ''' - + """ + output_unionizes = defaultdict(cls.record_cb, cp.deepcopy(direct_unionizes)) for k, v in direct_unionizes.items(): for aid in ancestor_id_map[k]: - if k == aid: continue - - logging.debug('propagating data from {0} to {1}'.format(k, aid)) + + logging.debug("propagating data from {0} to {1}".format(k, aid)) output_unionizes[aid] = cls.propagate_record(v, output_unionizes[aid]) - + return output_unionizes - @classmethod def propagate_to_bilateral(cls, lateral_unionizes): - bilateral = defaultdict(cls.record_cb, {}) for sid in list(lateral_unionizes.keys()): unionize = lateral_unionizes[sid] @@ -152,67 +144,63 @@ def propagate_to_bilateral(cls, lateral_unionizes): if (sid in bilateral) or (other_id in bilateral): continue - logging.debug('bilateralizing structure {0}'.format(sid)) + logging.debug("bilateralizing structure {0}".format(sid)) other = lateral_unionizes[other_id] - + bilateral[sid] = cls.propagate_record(unionize, bilateral[sid], True) bilateral[sid] = cls.propagate_record(other, bilateral[sid], True) return bilateral - - def postprocess_unionizes(self, raw_unionizes, **kwargs): - '''Carry out additional calculations/formatting derivative of core + """Carry out additional calculations/formatting derivative of core unionization. - + Parameters ---------- raw_unionizes : list of unionizes Each entry is a unionize record. - - ''' - raise NotImplementedError('specify in subclass!') - - + + """ + raise NotImplementedError("specify in subclass!") + def sort_data_arrays(self, data_arrays): - '''Apply the precomputed sort to flattened data arrays - + """Apply the precomputed sort to flattened data arrays + Parameters ---------- data_arrays : dict - Keys identify types of data volume. Values are flattened, unsorted + Keys identify types of data volume. Values are flattened, unsorted arrays. - + Returns ------- - dict : + dict : As input, but values are sorted - - ''' - - logging.info('sorting data arrays') + + """ + + logging.info("sorting data arrays") return {k: v[self.sort] for k, v in data_arrays.items()} - - + def direct_unionize(self, data_arrays, pre_sorted=False, **kwargs): - '''Obtain unionize records from directly annotated regions. - + """Obtain unionize records from directly annotated regions. + Parameters ---------- data_arrays : dict Keys identify types of data volume. Values are flattened arrays. sorted : bool, optional If False, data arrays will be sorted. - - ''' - + + """ + if not pre_sorted: data_arrays = self.sort_data_arrays(data_arrays) - + unionizes = {} for sid, (low, high) in self.interval_map.items(): - logging.debug( 'unionizing structure {0} :: voxel_count={1}'.format(sid, high - low) ) + logging.debug("unionizing structure {0} :: voxel_count={1}".format(sid, high - low)) unionizes[sid] = self.extract_data(data_arrays, low, high, **kwargs) - + return unionizes diff --git a/allensdk/internal/mouse_connectivity/interval_unionize/run_tissuecyte_unionize_cav.py b/allensdk/internal/mouse_connectivity/interval_unionize/run_tissuecyte_unionize_cav.py index 4ae2b29ef5..16db7d3248 100755 --- a/allensdk/internal/mouse_connectivity/interval_unionize/run_tissuecyte_unionize_cav.py +++ b/allensdk/internal/mouse_connectivity/interval_unionize/run_tissuecyte_unionize_cav.py @@ -10,59 +10,59 @@ def run(input_data): + logging.info("making ancestor id map") + ancestor_id_map = get_ancestor_id_map(input_data["structures"]) - logging.info('making ancestor id map') - ancestor_id_map = get_ancestor_id_map(input_data['structures']) + logging.info("computing volume scale factor") + volume_scale = (input_data["reference_spacing"] / 10**3) ** 3 # mum3 -> mm3 + logging.info("volume scale factor : {0}".format(volume_scale)) - logging.info('computing volume scale factor') - volume_scale = (input_data['reference_spacing'] / 10 ** 3) ** 3 # mum3 -> mm3 - logging.info('volume scale factor : {0}'.format(volume_scale)) + logging.info("reference shape : {0}".format(input_data["reference_shape"])) + logging.info("reference spacing : {0}".format(input_data["reference_spacing"])) + logging.info("image_series_id : {0}".format(input_data["image_series_id"])) - logging.info('reference shape : {0}'.format(input_data['reference_shape'])) - logging.info('reference spacing : {0}'.format(input_data['reference_spacing'])) - logging.info('image_series_id : {0}'.format(input_data['image_series_id'])) - - annotation = du.load_annotation(input_data['annotation_path'], input_data['grid_paths']['data_mask']) + annotation = du.load_annotation(input_data["annotation_path"], input_data["grid_paths"]["data_mask"]) unionizer = CavUnionizer() unionizer.setup_interval_map(annotation) del annotation - signal_arrays = du.get_cav_density(input_data['grid_paths']['cav_density']) - signal_arrays.update(du.get_sum_pixels(input_data['grid_paths']['sum_pixels'])) + signal_arrays = du.get_cav_density(input_data["grid_paths"]["cav_density"]) + signal_arrays.update(du.get_sum_pixels(input_data["grid_paths"]["sum_pixels"])) - max_pixels = float(np.amax(signal_arrays['sum_pixels'])) - logging.info('max pixels per voxel: {}'.format(max_pixels)) + max_pixels = float(np.amax(signal_arrays["sum_pixels"])) + logging.info("max pixels per voxel: {}".format(max_pixels)) for k, v in signal_arrays.items(): - logging.info('sorting {0} array'.format(k)) + logging.info("sorting {0} array".format(k)) signal_arrays[k] = v.flat[unionizer.sort] - - logging.info('computing unionizes from directly annotated voxels') + + logging.info("computing unionizes from directly annotated voxels") raw_unionizes = unionizer.direct_unionize(signal_arrays, pre_sorted=True) - - logging.info('propagating data to ancestor structures') + + logging.info("propagating data to ancestor structures") raw_unionizes = CavUnionizer.propagate_unionizes(raw_unionizes, ancestor_id_map) - logging.info('propagating data to bilateral unionizes') + logging.info("propagating data to bilateral unionizes") bilateral = CavUnionizer.propagate_to_bilateral(raw_unionizes) - cooked_unionizes = list(unionizer.postprocess_unionizes( - raw_unionizes, - image_series_id=input_data['image_series_id'], - volume_scale=volume_scale, - max_pixels=max_pixels - )) - cooked_bilateral = list(unionizer.postprocess_unionizes( - bilateral, - image_series_id=input_data['image_series_id'], - volume_scale=volume_scale, - max_pixels=max_pixels - )) + cooked_unionizes = list( + unionizer.postprocess_unionizes( + raw_unionizes, + image_series_id=input_data["image_series_id"], + volume_scale=volume_scale, + max_pixels=max_pixels, + ) + ) + cooked_bilateral = list( + unionizer.postprocess_unionizes( + bilateral, image_series_id=input_data["image_series_id"], volume_scale=volume_scale, max_pixels=max_pixels + ) + ) for item in cooked_bilateral: - item['hemisphere'] = '(none)' + item["hemisphere"] = "(none)" cooked_unionizes.append(item) - logging.info('computed {0} unionize records'.format(len(cooked_unionizes))) + logging.info("computed {0} unionize records".format(len(cooked_unionizes))) return cooked_unionizes diff --git a/allensdk/internal/mouse_connectivity/interval_unionize/run_tissuecyte_unionize_classic.py b/allensdk/internal/mouse_connectivity/interval_unionize/run_tissuecyte_unionize_classic.py index fd601f32ff..e84163ef56 100755 --- a/allensdk/internal/mouse_connectivity/interval_unionize/run_tissuecyte_unionize_classic.py +++ b/allensdk/internal/mouse_connectivity/interval_unionize/run_tissuecyte_unionize_classic.py @@ -7,86 +7,93 @@ def get_ancestor_id_map(structures): - - - tree = SimpleTree( structures, - lambda st: int(st['id']), - lambda st: st['parent_structure_id']) - ancestor_id_map = tree.value_map( lambda st: st['id'], - lambda st: tree.ancestor_ids([st['id']])[0] ) + tree = SimpleTree(structures, lambda st: int(st["id"]), lambda st: st["parent_structure_id"]) + ancestor_id_map = tree.value_map(lambda st: st["id"], lambda st: tree.ancestor_ids([st["id"]])[0]) for k in list(ancestor_id_map): ancestor_id_map[-k] = map(lambda x: -x, ancestor_id_map[k]) - + return ancestor_id_map def get_volume_scale(image_resolution, voxel_depth): - return image_resolution ** 2 * 10 ** -9 * voxel_depth + return image_resolution**2 * 10**-9 * voxel_depth def run(input_data): + logging.info("making ancestor id map") + ancestor_id_map = get_ancestor_id_map(input_data["structures"]) - logging.info('making ancestor id map') - ancestor_id_map = get_ancestor_id_map(input_data['structures']) - - logging.info('computing volume scale factor') - volume_scale = get_volume_scale(input_data['image_resolution'], input_data['reference_spacing']) - logging.info('volume scale factor : {0}'.format(volume_scale)) + logging.info("computing volume scale factor") + volume_scale = get_volume_scale(input_data["image_resolution"], input_data["reference_spacing"]) + logging.info("volume scale factor : {0}".format(volume_scale)) - logging.info('reference shape : {0}'.format(input_data['reference_shape'])) - logging.info('reference spacing : {0}'.format(input_data['reference_spacing'])) - logging.info('image_series_id : {0}'.format(input_data['image_series_id'])) + logging.info("reference shape : {0}".format(input_data["reference_shape"])) + logging.info("reference spacing : {0}".format(input_data["reference_spacing"])) + logging.info("image_series_id : {0}".format(input_data["image_series_id"])) - annotation = du.load_annotation(input_data['annotation_path'], input_data['grid_paths']['data_mask']) + annotation = du.load_annotation(input_data["annotation_path"], input_data["grid_paths"]["data_mask"]) unionizer = TissuecyteUnionizer() unionizer.setup_interval_map(annotation) del annotation - signal_arrays = du.get_injection_data(input_data['grid_paths']['injection_fraction'], - input_data['grid_paths']['injection_density'], - input_data['grid_paths']['injection_energy']) - signal_arrays.update(du.get_projection_data(input_data['grid_paths']['projection_density'], - input_data['grid_paths']['projection_energy'], - input_data['grid_paths']['aav_exclusion_fraction'])) - signal_arrays.update(du.get_sum_pixels(input_data['grid_paths']['sum_pixels'])) - signal_arrays.update(du.get_sum_pixel_intensities(input_data['grid_paths']['sum_pixel_intensities'], - input_data['grid_paths']['injection_sum_pixel_intensities'])) + signal_arrays = du.get_injection_data( + input_data["grid_paths"]["injection_fraction"], + input_data["grid_paths"]["injection_density"], + input_data["grid_paths"]["injection_energy"], + ) + signal_arrays.update( + du.get_projection_data( + input_data["grid_paths"]["projection_density"], + input_data["grid_paths"]["projection_energy"], + input_data["grid_paths"]["aav_exclusion_fraction"], + ) + ) + signal_arrays.update(du.get_sum_pixels(input_data["grid_paths"]["sum_pixels"])) + signal_arrays.update( + du.get_sum_pixel_intensities( + input_data["grid_paths"]["sum_pixel_intensities"], + input_data["grid_paths"]["injection_sum_pixel_intensities"], + ) + ) for k, v in signal_arrays.items(): - logging.info('sorting {0} array'.format(k)) + logging.info("sorting {0} array".format(k)) signal_arrays[k] = v.flat[unionizer.sort] - - logging.info('computing unionizes from directly annotated voxels') + + logging.info("computing unionizes from directly annotated voxels") raw_unionizes = unionizer.direct_unionize(signal_arrays, pre_sorted=True) - - logging.info('propagating data to ancestor structures') - raw_unionizes = TissuecyteUnionizer.propagate_unionizes(raw_unionizes, - ancestor_id_map) - logging.info('propagating data to bilateral unionizes') + logging.info("propagating data to ancestor structures") + raw_unionizes = TissuecyteUnionizer.propagate_unionizes(raw_unionizes, ancestor_id_map) + + logging.info("propagating data to bilateral unionizes") bilateral = TissuecyteUnionizer.propagate_to_bilateral(raw_unionizes) - cooked_unionizes = list(unionizer.postprocess_unionizes( - raw_unionizes, - image_series_id=input_data['image_series_id'], - output_spacing_iso=input_data['reference_spacing'], - volume_scale=volume_scale, - target_shape=input_data['reference_shape'], - sort=unionizer.sort - )) - - cooked_bilateral = list(unionizer.postprocess_unionizes( - bilateral, - image_series_id=input_data['image_series_id'], - output_spacing_iso=input_data['reference_spacing'], - volume_scale=volume_scale, - target_shape=input_data['reference_shape'], - sort=unionizer.sort - )) + cooked_unionizes = list( + unionizer.postprocess_unionizes( + raw_unionizes, + image_series_id=input_data["image_series_id"], + output_spacing_iso=input_data["reference_spacing"], + volume_scale=volume_scale, + target_shape=input_data["reference_shape"], + sort=unionizer.sort, + ) + ) + + cooked_bilateral = list( + unionizer.postprocess_unionizes( + bilateral, + image_series_id=input_data["image_series_id"], + output_spacing_iso=input_data["reference_spacing"], + volume_scale=volume_scale, + target_shape=input_data["reference_shape"], + sort=unionizer.sort, + ) + ) for item in cooked_bilateral: - item['hemisphere_id'] = 3 + item["hemisphere_id"] = 3 cooked_unionizes.append(item) - - logging.info('computed {0} unionize records'.format(len(cooked_unionizes))) + + logging.info("computed {0} unionize records".format(len(cooked_unionizes))) return cooked_unionizes diff --git a/allensdk/internal/mouse_connectivity/interval_unionize/tissuecyte_unionize_record.py b/allensdk/internal/mouse_connectivity/interval_unionize/tissuecyte_unionize_record.py index a5fb565f14..d5216de7c3 100755 --- a/allensdk/internal/mouse_connectivity/interval_unionize/tissuecyte_unionize_record.py +++ b/allensdk/internal/mouse_connectivity/interval_unionize/tissuecyte_unionize_record.py @@ -6,23 +6,29 @@ class TissuecyteBaseUnionize(Unionize): - __slots__ = ['sum_pixels', 'sum_projection_pixels', - 'sum_projection_pixel_intensity', - 'max_voxel_index', 'max_voxel_density', 'projection_density', - 'projection_energy', 'projection_intensity', - 'direct_sum_projection_pixels', - 'sum_pixel_intensity'] + __slots__ = [ + "sum_pixels", + "sum_projection_pixels", + "sum_projection_pixel_intensity", + "max_voxel_index", + "max_voxel_density", + "projection_density", + "projection_energy", + "projection_intensity", + "direct_sum_projection_pixels", + "sum_pixel_intensity", + ] def __init__(self): - '''A unionize record summarizing observations from a tissuecyte + """A unionize record summarizing observations from a tissuecyte projection experiment - ''' + """ for key in self.__slots__: setattr(self, key, 0) def propagate(self, ancestor, copy_all=False): - '''Update a rootward unionize with data from this unionize record + """Update a rootward unionize with data from this unionize record Parameters ---------- @@ -33,12 +39,11 @@ def propagate(self, ancestor, copy_all=False): ------- ancestor : TissuecyteBaseUnionize - ''' + """ ancestor.sum_pixels += self.sum_pixels ancestor.sum_projection_pixels += self.sum_projection_pixels - ancestor.sum_projection_pixel_intensity += \ - self.sum_projection_pixel_intensity + ancestor.sum_projection_pixel_intensity += self.sum_projection_pixel_intensity ancestor.sum_pixel_intensity += self.sum_pixel_intensity if ancestor.max_voxel_density <= self.max_voxel_density: @@ -46,13 +51,12 @@ def propagate(self, ancestor, copy_all=False): ancestor.max_voxel_index = self.max_voxel_index if copy_all: - ancestor.direct_sum_projection_pixels += \ - self.direct_sum_projection_pixels + ancestor.direct_sum_projection_pixels += self.direct_sum_projection_pixels return ancestor def set_max_voxel(self, density_array, low): - '''Find the voxel of greatest density in this unionizes spatial domain + """Find the voxel of greatest density in this unionizes spatial domain Parameters ---------- @@ -61,7 +65,7 @@ def set_max_voxel(self, density_array, low): low : int index in full flattened, sorted array of starting voxel - ''' + """ if self.sum_projection_pixels > 0: self.max_voxel_index = np.argmax(density_array) @@ -70,7 +74,7 @@ def set_max_voxel(self, density_array, low): self.max_voxel_index += low def output(self, output_spacing_iso, volume_scale, target_shape, sort): - '''Generate derived data for this unionize + """Generate derived data for this unionize Parameters ---------- @@ -81,95 +85,76 @@ def output(self, output_spacing_iso, volume_scale, target_shape, sort): target_shape : array-like of numeric Shape of reference space - ''' + """ if self.sum_pixels > 0: - self.projection_density = self.sum_projection_pixels / \ - self.sum_pixels - self.projection_energy = \ - self.sum_projection_pixel_intensity / self.sum_pixels + self.projection_density = self.sum_projection_pixels / self.sum_pixels + self.projection_energy = self.sum_projection_pixel_intensity / self.sum_pixels if self.sum_projection_pixels > 0: - self.projection_intensity = \ - self.sum_projection_pixel_intensity / \ - self.sum_projection_pixels + self.projection_intensity = self.sum_projection_pixel_intensity / self.sum_projection_pixels output = {k: getattr(self, k) for k in self.__slots__} - output['volume'] = self.sum_pixels * volume_scale - output[ - 'direct_projection_volume'] = \ - self.direct_sum_projection_pixels * volume_scale - output['projection_volume'] = self.sum_projection_pixels * volume_scale - output['sum_pixel_intensity'] = self.sum_pixel_intensity + output["volume"] = self.sum_pixels * volume_scale + output["direct_projection_volume"] = self.direct_sum_projection_pixels * volume_scale + output["projection_volume"] = self.sum_projection_pixels * volume_scale + output["sum_pixel_intensity"] = self.sum_pixel_intensity if self.max_voxel_index > 0: self.max_voxel_index = sort[self.max_voxel_index] - mv_pos = np.unravel_index([self.max_voxel_index], - shape=target_shape, order='C') + mv_pos = np.unravel_index([self.max_voxel_index], shape=target_shape, order="C") if len(mv_pos[0]) == 0: mv_pos = [[0], [0], [0]] else: mv_pos = [[0], [0], [0]] - output['max_voxel_x'] = mv_pos[0][0] * output_spacing_iso - output['max_voxel_y'] = mv_pos[1][0] * output_spacing_iso - output['max_voxel_z'] = mv_pos[2][0] * output_spacing_iso - del output['max_voxel_index'] + output["max_voxel_x"] = mv_pos[0][0] * output_spacing_iso + output["max_voxel_y"] = mv_pos[1][0] * output_spacing_iso + output["max_voxel_z"] = mv_pos[2][0] * output_spacing_iso + del output["max_voxel_index"] return output class TissuecyteInjectionUnionize(TissuecyteBaseUnionize): - def calculate(self, low, high, data_arrays): data_arrays = self.slice_arrays(low, high, data_arrays) - self.sum_pixels = np.multiply(data_arrays['sum_pixels'], - data_arrays['injection_fraction']).sum() - self.sum_projection_pixels = \ - np.multiply(data_arrays['sum_pixels'], - data_arrays['injection_density']).sum() + self.sum_pixels = np.multiply(data_arrays["sum_pixels"], data_arrays["injection_fraction"]).sum() + self.sum_projection_pixels = np.multiply(data_arrays["sum_pixels"], data_arrays["injection_density"]).sum() self.direct_sum_projection_pixels = self.sum_projection_pixels self.sum_projection_pixel_intensity = np.multiply( - data_arrays['sum_pixels'], data_arrays['injection_energy']).sum() - self.sum_pixel_intensity = data_arrays[ - 'injection_sum_pixel_intensities'].sum() + data_arrays["sum_pixels"], data_arrays["injection_energy"] + ).sum() + self.sum_pixel_intensity = data_arrays["injection_sum_pixel_intensities"].sum() - self.set_max_voxel(data_arrays['injection_density'], low) + self.set_max_voxel(data_arrays["injection_density"], low) class TissuecyteProjectionUnionize(TissuecyteBaseUnionize): - def calculate(self, low, high, data_arrays, ij_record): data_arrays = self.slice_arrays(low, high, data_arrays) - nex = np.logical_or( - data_arrays['injection_fraction'], - np.logical_not(data_arrays['aav_exclusion_fraction']) - ) + nex = np.logical_or(data_arrays["injection_fraction"], np.logical_not(data_arrays["aav_exclusion_fraction"])) - self.sum_pixels = data_arrays['sum_pixels'][nex].sum() + self.sum_pixels = data_arrays["sum_pixels"][nex].sum() self.sum_pixels -= ij_record.sum_pixels - self.sum_projection_pixels = np.multiply(data_arrays['sum_pixels'], - data_arrays[ - 'projection_density'])[ - nex].sum() + self.sum_projection_pixels = np.multiply(data_arrays["sum_pixels"], data_arrays["projection_density"])[ + nex + ].sum() self.sum_projection_pixels -= ij_record.sum_projection_pixels self.direct_sum_projection_pixels = self.sum_projection_pixels - self.sum_projection_pixel_intensity = \ - np.multiply(data_arrays['sum_pixels'], - data_arrays['projection_energy'])[nex].sum() - self.sum_projection_pixel_intensity -= \ - ij_record.sum_projection_pixel_intensity + self.sum_projection_pixel_intensity = np.multiply(data_arrays["sum_pixels"], data_arrays["projection_energy"])[ + nex + ].sum() + self.sum_projection_pixel_intensity -= ij_record.sum_projection_pixel_intensity - self.sum_pixel_intensity = float( - data_arrays['sum_pixel_intensities'][nex].sum()) + self.sum_pixel_intensity = float(data_arrays["sum_pixel_intensities"][nex].sum()) self.sum_pixel_intensity -= ij_record.sum_pixel_intensity - valid_density = np.multiply(nex, data_arrays['projection_density']) - valid_density = np.multiply(valid_density, - 1 - data_arrays['injection_fraction']) + valid_density = np.multiply(nex, data_arrays["projection_density"]) + valid_density = np.multiply(valid_density, 1 - data_arrays["injection_fraction"]) self.set_max_voxel(valid_density, low) diff --git a/allensdk/internal/mouse_connectivity/interval_unionize/tissuecyte_unionizer.py b/allensdk/internal/mouse_connectivity/interval_unionize/tissuecyte_unionizer.py index be8a9b54a5..401ca2df4a 100755 --- a/allensdk/internal/mouse_connectivity/interval_unionize/tissuecyte_unionizer.py +++ b/allensdk/internal/mouse_connectivity/interval_unionize/tissuecyte_unionizer.py @@ -3,49 +3,42 @@ from .interval_unionizer import IntervalUnionizer -from .tissuecyte_unionize_record import TissuecyteInjectionUnionize, \ - TissuecyteProjectionUnionize +from .tissuecyte_unionize_record import TissuecyteInjectionUnionize, TissuecyteProjectionUnionize + - class TissuecyteUnionizer(IntervalUnionizer): - '''A specialization of the IntervalUnionizer set up for unionizing + """A specialization of the IntervalUnionizer set up for unionizing Tissuecyte-derived projection data. - ''' - + """ @classmethod def record_cb(cls): - return {'injection': TissuecyteInjectionUnionize(), - 'projection': TissuecyteProjectionUnionize()} - - + return {"injection": TissuecyteInjectionUnionize(), "projection": TissuecyteProjectionUnionize()} + def extract_data(self, data_arrays, low, high): - '''As parent - ''' - + """As parent""" + unionize = self.__class__.record_cb() - unionize['injection'].calculate(low, high, data_arrays) - unionize['projection'].calculate(low, high, data_arrays, unionize['injection']) - - return unionize - - - @classmethod + unionize["injection"].calculate(low, high, data_arrays) + unionize["projection"].calculate(low, high, data_arrays, unionize["injection"]) + + return unionize + + @classmethod def propagate_record(cls, child_record, ancestor_record, copy_all=False): - '''As parent - ''' + """As parent""" for k, v in child_record.items(): v.propagate(ancestor_record[k], copy_all) - + return ancestor_record - - def postprocess_unionizes(self, raw_unionizes, image_series_id, - output_spacing_iso, volume_scale, target_shape, sort): - '''As parent - + def postprocess_unionizes( + self, raw_unionizes, image_series_id, output_spacing_iso, volume_scale, target_shape, sort + ): + """As parent + New Parameters -------------- output_spacing_iso : numeric @@ -54,53 +47,50 @@ def postprocess_unionizes(self, raw_unionizes, image_series_id, Scale factor mapping pixels to microns^3 target_shape : array-like of numeric Shape of reference space - - ''' + + """ unionizes = [] total_injection_volume = 0 - - logging.info('getting formatted unionize output') + + logging.info("getting formatted unionize output") for sid, un in raw_unionizes.items(): - if sid < 0: hemisphere = 1 else: hemisphere = 2 - + current = [] for ij, item in un.items(): - v = item.output(output_spacing_iso, volume_scale, target_shape, sort) - injection = True if ij == 'injection' else False - + injection = True if ij == "injection" else False + if injection and hemisphere != 3: - total_injection_volume += v['direct_projection_volume'] - - del v['direct_projection_volume'] - del v['direct_sum_projection_pixels'] - - v.update({'is_injection': injection, - 'hemisphere_id': hemisphere, - 'structure_id': abs(sid), - 'image_series_id': image_series_id}) - + total_injection_volume += v["direct_projection_volume"] + + del v["direct_projection_volume"] + del v["direct_sum_projection_pixels"] + + v.update( + { + "is_injection": injection, + "hemisphere_id": hemisphere, + "structure_id": abs(sid), + "image_series_id": image_series_id, + } + ) + current.append(v) - + unionizes.extend(current) - + if total_injection_volume > 0: - logging.info('computing normalized projection volume') + logging.info("computing normalized projection volume") for un in unionizes: - un['normalized_projection_volume'] = un['projection_volume'] / total_injection_volume + un["normalized_projection_volume"] = un["projection_volume"] / total_injection_volume else: - logging.warning('no injection found!') + logging.warning("no injection found!") for un in unionizes: - un['normalized_projection_volume'] = 0 - - return filter(lambda x: x['sum_pixels'] > 0, unionizes) - - - + un["normalized_projection_volume"] = 0 - + return filter(lambda x: x["sum_pixels"] > 0, unionizes) diff --git a/allensdk/internal/mouse_connectivity/interval_unionize/unionize_record.py b/allensdk/internal/mouse_connectivity/interval_unionize/unionize_record.py index 61dfd4d300..0a7111d89b 100755 --- a/allensdk/internal/mouse_connectivity/interval_unionize/unionize_record.py +++ b/allensdk/internal/mouse_connectivity/interval_unionize/unionize_record.py @@ -1,27 +1,21 @@ - class Unionize(object): - '''Abstract base class for unionize records. - ''' + """Abstract base class for unionize records.""" def __init__(self, *args, **kwargs): raise NotImplementedError() - - + def calculate(self, *args, **kwargs): raise NotImplementedError() - - + def propagate(self, ancestor, copy_all, *args, **kwargs): raise NotImplementedError() - - + def output(self, *args, **kwargs): raise NotImplementedError() - def slice_arrays(self, low, high, data_arrays): - '''Extract a slice from several aligned arrays - + """Extract a slice from several aligned arrays + Parameters ---------- low : int @@ -29,9 +23,9 @@ def slice_arrays(self, low, high, data_arrays): high : int end of slice, exclusive data_arrays : dict - keys are varieties of data. values are sorted, flattened + keys are varieties of data. values are sorted, flattened data arrays - - ''' - + + """ + return {k: v[low:high] for k, v in data_arrays.items()} diff --git a/allensdk/internal/mouse_connectivity/projection_thumbnail/generate_projection_strip.py b/allensdk/internal/mouse_connectivity/projection_thumbnail/generate_projection_strip.py index 75ef37e218..cd0c5a752c 100644 --- a/allensdk/internal/mouse_connectivity/projection_thumbnail/generate_projection_strip.py +++ b/allensdk/internal/mouse_connectivity/projection_thumbnail/generate_projection_strip.py @@ -24,54 +24,48 @@ def apply_colormap(image, colormap): def blend_with_background(image, background): - for ii in range(3): - image[:, :, ii] = np.multiply(np.squeeze(image[:, :, ii]), - np.squeeze(image[:, :, -1])) - image[:, :, ii] += np.multiply(np.squeeze(background), - np.squeeze(1.0 - image[:, :, -1])) + image[:, :, ii] = np.multiply(np.squeeze(image[:, :, ii]), np.squeeze(image[:, :, -1])) + image[:, :, ii] += np.multiply(np.squeeze(background), np.squeeze(1.0 - image[:, :, -1])) image[:, :, -1] = 1 return image def do_blur(image, blur): - for ii in range(3): im = sitk.GetImageFromArray(image[:, :, ii]) im = sitk.DiscreteGaussian(im, blur) image[:, :, ii] = sitk.GetArrayFromImage(im) - + return image def handle_output_image(sheet, out_image, colormap, nsteps): - sheet = sheet.copy() whole_sheet = sheet.get_output(-1) - if out_image['blur'] > 0.0: - logging.info('applying a gaussian blur with variance: {0:2.2f}'.format(out_image['blur'])) + if out_image["blur"] > 0.0: + logging.info("applying a gaussian blur with variance: {0:2.2f}".format(out_image["blur"])) whole_sheet = sitk.GetImageFromArray(whole_sheet) - whole_sheet = sitk.DiscreteGaussian(whole_sheet, out_image['blur']) + whole_sheet = sitk.DiscreteGaussian(whole_sheet, out_image["blur"]) whole_sheet = sitk.GetArrayFromImage(whole_sheet) - if out_image['scale'] != 1: - whole_sheet = zoom(whole_sheet, zoom=out_image['scale'], order=1) + if out_image["scale"] != 1: + whole_sheet = zoom(whole_sheet, zoom=out_image["scale"], order=1) whole_sheet = apply_colormap(whole_sheet, colormap) - if out_image['background'] is not None: - whole_sheet = blend_with_background(whole_sheet, out_image['background']) + if out_image["background"] is not None: + whole_sheet = blend_with_background(whole_sheet, out_image["background"]) else: whole_sheet = blend_with_background(whole_sheet, np.zeros_like(whole_sheet)[:, :, -1]) whole_sheet = np.around(whole_sheet * 255).astype(np.uint8) - out_image['write'](whole_sheet[:, :, :-1]) + out_image["write"](whole_sheet[:, :, :-1]) def simple_rotation(from_axis, to_axis, start, end, nsteps): - angles = np.linspace(start * np.pi, end * np.pi, nsteps, endpoint=False) from_axes = [from_axis] * nsteps to_axes = [to_axis] * nsteps @@ -80,7 +74,6 @@ def simple_rotation(from_axis, to_axis, start, end, nsteps): def run(volume, imin, imax, rotations, colormap): - volume = vis.sitk_safe_ln(volume) ln_imin = np.log(imin) if imin != 0 else -np.inf @@ -89,20 +82,19 @@ def run(volume, imin, imax, rotations, colormap): volume = sitk.IntensityWindowing(volume, ln_imin, ln_imax, 0.0, 1.0) for rotation in rotations: - max_sheet = ImageSheet() depth_sheet = ImageSheet() - vp = VolumeProjector.fixed_factory(volume, rotation['window_size']) - callback = functools.partial(max_cb, max_sheet, depth_sheet, **rotation['projection_parameters']) + vp = VolumeProjector.fixed_factory(volume, rotation["window_size"]) + callback = functools.partial(max_cb, max_sheet, depth_sheet, **rotation["projection_parameters"]) - rot = rotation['rotation_parameters'] + rot = rotation["rotation_parameters"] from_axes, to_axes, angles = simple_rotation(**rot) for response in vp.rotate_and_extract(from_axes, to_axes, angles, callback): pass - rotation['write_depth_sheet'](depth_sheet.get_output(-1)) + rotation["write_depth_sheet"](depth_sheet.get_output(-1)) - for out_image in rotation['output_images']: - handle_output_image(max_sheet, out_image, colormap, rot['nsteps']) + for out_image in rotation["output_images"]: + handle_output_image(max_sheet, out_image, colormap, rot["nsteps"]) diff --git a/allensdk/internal/mouse_connectivity/projection_thumbnail/image_sheet.py b/allensdk/internal/mouse_connectivity/projection_thumbnail/image_sheet.py index 9c489263b5..eba49419c2 100644 --- a/allensdk/internal/mouse_connectivity/projection_thumbnail/image_sheet.py +++ b/allensdk/internal/mouse_connectivity/projection_thumbnail/image_sheet.py @@ -6,36 +6,28 @@ class ImageSheet(object): - - def append(self, new_cell): - - if not hasattr(self, 'images'): + if not hasattr(self, "images"): self.images = [new_cell] else: self.images.append(new_cell) - def apply(self, fn, *args, **kwargs): fn = functools.partial(fn, *args, **kwargs) self.images = map(fn, self.images) - def copy(self): new_sheet = ImageSheet() new_sheet.images = cp.deepcopy(self.images) return new_sheet - def get_output(self, axis): output = np.concatenate(self.images, axis=axis) - logging.info('concatenated sheet has size: {0}'.format(output.shape)) + logging.info("concatenated sheet has size: {0}".format(output.shape)) return output - @staticmethod def build_from_image(image, n, axis): - images = np.split(image, n, axis) sheet = ImageSheet() diff --git a/allensdk/internal/mouse_connectivity/projection_thumbnail/projection_functions.py b/allensdk/internal/mouse_connectivity/projection_thumbnail/projection_functions.py index 19ac349817..bb5f16c77f 100644 --- a/allensdk/internal/mouse_connectivity/projection_thumbnail/projection_functions.py +++ b/allensdk/internal/mouse_connectivity/projection_thumbnail/projection_functions.py @@ -31,6 +31,3 @@ def template_projection(volume, axis, gain=2, maxv=1, *a, **k): output += gain * np.multiply(current, current) / maxv return output - - - diff --git a/allensdk/internal/mouse_connectivity/projection_thumbnail/visualization_utilities.py b/allensdk/internal/mouse_connectivity/projection_thumbnail/visualization_utilities.py index 8d179da6cf..4e828dedf1 100644 --- a/allensdk/internal/mouse_connectivity/projection_thumbnail/visualization_utilities.py +++ b/allensdk/internal/mouse_connectivity/projection_thumbnail/visualization_utilities.py @@ -8,24 +8,24 @@ import numpy as np -def convert_discrete_colormap(data, cm_name='custom', color_names=None): - '''Generates a matplotlib continuous colormap on [0, 1] from a discrete +def convert_discrete_colormap(data, cm_name="custom", color_names=None): + """Generates a matplotlib continuous colormap on [0, 1] from a discrete colormap at N evenly spaced points. Parameters ---------- data : list of list - Sublists are [r, g, b]. - + Sublists are [r, g, b]. + Returns ------- matplotlib.colors.LinearSegmentedColormap Gamma is 1. Output space is 3 X [0, 1] - - ''' + + """ if color_names is None: - color_names = ['red', 'green', 'blue'] + color_names = ["red", "green", "blue"] data = np.array(data) npoints = data.shape[0] @@ -35,7 +35,7 @@ def convert_discrete_colormap(data, cm_name='custom', color_names=None): for col, name in enumerate(color_names): color_array = np.zeros((npoints, 3)) - + color_array[:, 0] = domain color_array[:, 1] = minmax_norm(data[:, col]) color_array[:, 2] = color_array[:, 1] @@ -45,41 +45,36 @@ def convert_discrete_colormap(data, cm_name='custom', color_names=None): return mpl.colors.LinearSegmentedColormap(cm_name, color_arrays, npoints, gamma=1.0) - def minmax_norm(data): - rng = np.amax(data) - np.amin(data) if rng == 0: - return data + return data return (data - np.amin(data)) / rng - def sitk_safe_ln(data, minimum=10**-10): - - logging.info('thresholding below at {0}'.format(minimum)) + logging.info("thresholding below at {0}".format(minimum)) minimum = float(minimum) data = sitk.Threshold(data, minimum, np.inf, minimum) - logging.info('taking natural log') + logging.info("taking natural log") return sitk.Log(data) def normalize_intensity(data, in_min, in_max, out_min=0.0, out_max=0.0): - - logging.info('setting input range: [{0:2.3f}, {1:2.3f}]'.format(in_min, in_max)) + logging.info("setting input range: [{0:2.3f}, {1:2.3f}]".format(in_min, in_max)) data = sitk.ShiftScale(data, -in_min, 1.0 / (in_max - in_min)) data = sitk.Threshold(data, 0.0, np.inf, 0.0) data = sitk.Threshold(data, 0.0, 1.0, 1.0) - logging.info('setting output range: [{0:2.3f}, {1:2.3f}]'.format(out_min, out_max)) - data = sitk.ShiftScale(data, 0.0, out_max - out_min) # want to scale first - return sitk.ShiftScale(data, out_min, 1) # then shift + logging.info("setting output range: [{0:2.3f}, {1:2.3f}]".format(out_min, out_max)) + data = sitk.ShiftScale(data, 0.0, out_max - out_min) # want to scale first + return sitk.ShiftScale(data, out_min, 1) # then shift def blend(image_stack, weight_stack): - ''' + """ Parameters ---------- @@ -88,12 +83,14 @@ def blend(image_stack, weight_stack): weight_stack :: list of np.ndarray The weight of each image at each pixel. Will be normalized. - ''' + """ image_stack = np.array(image_stack) weight_stack = np.array(weight_stack) - weight_stack = weight_stack - np.amin(weight_stack, axis=0) / (np.amax(weight_stack, axis=0) - np.amin(weight_stack, axis=0)) + weight_stack = weight_stack - np.amin(weight_stack, axis=0) / ( + np.amax(weight_stack, axis=0) - np.amin(weight_stack, axis=0) + ) weight_stack[np.isnan(weight_stack)] = 0.5 weight_stack[np.isinf(weight_stack)] = 0.5 diff --git a/allensdk/internal/mouse_connectivity/projection_thumbnail/volume_projector.py b/allensdk/internal/mouse_connectivity/projection_thumbnail/volume_projector.py index 3fcbb8eca0..74a0783610 100644 --- a/allensdk/internal/mouse_connectivity/projection_thumbnail/volume_projector.py +++ b/allensdk/internal/mouse_connectivity/projection_thumbnail/volume_projector.py @@ -8,14 +8,12 @@ class VolumeProjector(object): - def __init__(self, view_volume): - logging.info('initializing volume projector') + logging.info("initializing volume projector") self.view_volume = view_volume - def build_rotation_transform(self, from_axis, to_axis, angle): - logging.info('constructing rotation') + logging.info("constructing rotation") transform = sitk.AffineTransform(3) transform.SetCenter((vol.sitk_get_center(self.view_volume)).tolist()) @@ -24,47 +22,36 @@ def build_rotation_transform(self, from_axis, to_axis, angle): logging.info(transform.__str__()) return transform - def rotate(self, from_axis, to_axis, angle): - logging.info('rotating from axis {0} to axis {1} ' - 'by {2:2.2f} radians'.format(from_axis, to_axis, angle)) - + logging.info("rotating from axis {0} to axis {1} by {2:2.2f} radians".format(from_axis, to_axis, angle)) + transform = self.build_rotation_transform(from_axis, to_axis, angle) - rotated = sitk.Resample(self.view_volume, transform, sitk.sitkLinear, - 0.0, self.view_volume.GetPixelID()) + rotated = sitk.Resample(self.view_volume, transform, sitk.sitkLinear, 0.0, self.view_volume.GetPixelID()) return rotated - - + def extract(self, cb, volume=None): - logging.info('extracting projection') - + logging.info("extracting projection") + if volume is None: - volume=self.view_volume + volume = self.view_volume return cb(volume) - def rotate_and_extract(self, from_axes, to_axes, angles, cb): - for fax, tax, angle in zip(from_axes, to_axes, angles): - rotated = self.rotate(fax, tax, angle) yield self.extract(cb, rotated) - @classmethod def fixed_factory(cls, volume, size): - view_volume = sitk.Image(int(size[0]), int(size[1]), int(size[2]), volume.GetPixelID()) view_volume = vol.sitk_paste_into_center(volume, view_volume) return cls(view_volume) - @classmethod def safe_factory(cls, volume): - max_extent = vol.sitk_get_diagonal_length(volume) max_extent = [np.ceil(max_extent).astype(int)] * 3 @@ -74,10 +61,5 @@ def safe_factory(cls, volume): for ax in range(volume.GetDimension()): if vpar[ax] != lpar[ax]: max_extent[ax] += 1 - - return cls.fixed_factory(volume, max_extent) - - - - + return cls.fixed_factory(volume, max_extent) diff --git a/allensdk/internal/mouse_connectivity/projection_thumbnail/volume_utilities.py b/allensdk/internal/mouse_connectivity/projection_thumbnail/volume_utilities.py index b5cc19c545..7e03f78f23 100644 --- a/allensdk/internal/mouse_connectivity/projection_thumbnail/volume_utilities.py +++ b/allensdk/internal/mouse_connectivity/projection_thumbnail/volume_utilities.py @@ -7,9 +7,7 @@ def sitk_get_image_parameters(volume): - return (np.array(volume.GetSpacing()), - np.array(volume.GetSize()), - np.array(volume.GetOrigin())) + return (np.array(volume.GetSpacing()), np.array(volume.GetSize()), np.array(volume.GetOrigin())) def sitk_get_center(volume): @@ -19,7 +17,7 @@ def sitk_get_center(volume): def sitk_get_size_parity(volume): _, size, _ = sitk_get_image_parameters(volume) - return np.mod(size, 2) + return np.mod(size, 2) def sitk_get_diagonal_length(volume): @@ -28,14 +26,25 @@ def sitk_get_diagonal_length(volume): def sitk_paste_into_center(smaller, larger): - smaller_parities = sitk_get_size_parity(smaller) larger_parities = sitk_get_size_parity(larger) if not np.allclose(smaller_parities, larger_parities): - logging.warn('parities differ, result will not be centered : {0}, {1}'.format(smaller_parities, larger_parities)) + logging.warn( + "parities differ, result will not be centered : {0}, {1}".format(smaller_parities, larger_parities) + ) smaller_center = sitk_get_center(smaller) larger_center = sitk_get_center(larger) offset = np.around(larger_center - smaller_center).astype(int).tolist() - return sitk.Paste(larger, smaller, smaller.GetSize(), [0, 0, 0,], offset) + return sitk.Paste( + larger, + smaller, + smaller.GetSize(), + [ + 0, + 0, + 0, + ], + offset, + ) diff --git a/allensdk/internal/mouse_connectivity/tissuecyte_stitching/stitcher.py b/allensdk/internal/mouse_connectivity/tissuecyte_stitching/stitcher.py index 50f0b8cfd6..7426ecdba9 100644 --- a/allensdk/internal/mouse_connectivity/tissuecyte_stitching/stitcher.py +++ b/allensdk/internal/mouse_connectivity/tissuecyte_stitching/stitcher.py @@ -6,62 +6,51 @@ import numpy as np - class Stitcher(object): - - def __init__(self, image_dimensions, tiles, average_tiles, channels): - - logging.info('image_dimensions: {0}'.format(image_dimensions)) - self.image_dimensions = image_dimensions + logging.info("image_dimensions: {0}".format(image_dimensions)) + self.image_dimensions = image_dimensions - self.average_tiles = defaultdict(lambda *a, **k: None, average_tiles) - - self.tiles = tiles - self.channels = channels + self.average_tiles = defaultdict(lambda *a, **k: None, average_tiles) + self.tiles = tiles + self.channels = channels def run(self, cb=np.array): - slice_image, stitched_indicator = initialize_images(self.image_dimensions, len(self.channels)) missing_tiles = {} for tile in self.tiles: - if tile.is_missing: - missing_tiles[tile.index] = tile.get_missing_path() tile.initialize_image() else: - - tile.apply_average_tile_to_self(self.average_tiles[tile.channel]) - tile.trim_self() + tile.apply_average_tile_to_self(self.average_tiles[tile.channel]) + tile.trim_self() self.stitch(slice_image, stitched_indicator, tile, cb) return slice_image, missing_tiles - def stitch(self, slice_image, stitched_indicator, tile, cb=np.array): - region = tile.get_image_region() - + current_region = slice_image[region] indicator_region = stitched_indicator[region] - stup = (tile.size['row'], tile.size['column']) + stup = (tile.size["row"], tile.size["column"]) blend = get_blend(indicator_region, stup, cb) blend = make_blended_tile(blend, tile.image, current_region) - logging.info('obtained blend') + logging.info("obtained blend") slice_image[region] = blend stitched_indicator[region] = 1 - logging.info('updated image region with tile data') + logging.info("updated image region with tile data") -def initialize_image(dimensions, nchannels, dtype, order='C'): - return np.zeros((dimensions['row'], dimensions['column'], nchannels), dtype=dtype, order=order) +def initialize_image(dimensions, nchannels, dtype, order="C"): + return np.zeros((dimensions["row"], dimensions["column"], nchannels), dtype=dtype, order=order) def initialize_images(dimensions, nchannels): @@ -73,9 +62,9 @@ def make_blended_tile(blend, tile, current_region): def get_indicator_bound_point(indicator, lg, axis): - '''Finds the index of first change in a binary mask + """Finds the index of first change in a binary mask along a specified axis in a specified direction - ''' + """ delta = np.diff(indicator, axis=axis) points = np.where(lg(delta, 0)) @@ -84,20 +73,20 @@ def get_indicator_bound_point(indicator, lg, axis): points = np.unique(points[axis]) size = indicator.shape[axis] points = points[lg(points, size / 2.0)] - + if len(points) > 0: return points[-1] return None def blend_component_from_point(point, mesh, lg): - '''Obtains a normalized component of the blend, which describes depth of + """Obtains a normalized component of the blend, which describes depth of overlap along a specified axis in a specified direction - ''' + """ - # this has the effect that the shallowest part of the blend + # this has the effect that the shallowest part of the blend # is always 0 - symmetric with the deepest after normalization. - blend = point - mesh + 1 + blend = point - mesh + 1 blend[lg(blend, 0)] = 0 blend = np.fabs(blend) @@ -108,8 +97,7 @@ def blend_component_from_point(point, mesh, lg): def get_blend_component(indicator, lg, axis, meshes): - ''' - ''' + """ """ point = get_indicator_bound_point(indicator, lg, axis) if point is None: @@ -119,8 +107,7 @@ def get_blend_component(indicator, lg, axis, meshes): def get_overall_blend(indicator, meshes): - ''' - ''' + """ """ blends = [] @@ -134,12 +121,9 @@ def get_overall_blend(indicator, meshes): def get_blend(indicator_region, stup, cb=np.array): - ''' - ''' + """ """ - meshes = np.meshgrid(*map(np.arange, stup), indexing='ij') + meshes = np.meshgrid(*map(np.arange, stup), indexing="ij") blend = get_overall_blend(indicator_region, meshes) return cb(np.multiply(blend, indicator_region)) - - diff --git a/allensdk/internal/mouse_connectivity/tissuecyte_stitching/tile.py b/allensdk/internal/mouse_connectivity/tissuecyte_stitching/tile.py index 64308ac37b..655b3872ba 100644 --- a/allensdk/internal/mouse_connectivity/tissuecyte_stitching/tile.py +++ b/allensdk/internal/mouse_connectivity/tissuecyte_stitching/tile.py @@ -4,10 +4,7 @@ class Tile(object): - - def __init__(self, index, image, is_missing, bounds, channel, size, - margins, *args, **kwargs): - + def __init__(self, index, image, is_missing, bounds, channel, size, margins, *args, **kwargs): # identifier self.index = index @@ -24,73 +21,60 @@ def __init__(self, index, image, is_missing, bounds, channel, size, self.margins = margins logging.info( - 'tile {index} on channel {channel} starts at ({0}, {1})'.format( - self.bounds['row']['start'], - self.bounds['column']['start'], - index=self.index, - channel=self.channel)) + "tile {index} on channel {channel} starts at ({0}, {1})".format( + self.bounds["row"]["start"], self.bounds["column"]["start"], index=self.index, channel=self.channel + ) + ) def trim_self(self): - logging.info('trimming tile') + logging.info("trimming tile") self.image = self.trim(self.image) def trim(self, image): - logging.info( - 'trimming with margins ({row}, {column})'.format(**self.margins)) + logging.info("trimming with margins ({row}, {column})".format(**self.margins)) return image[ - self.margins['row']: self.margins['row'] + self.size['row'], - self.margins['column']: self.margins['column'] + self.size[ - 'column']] + self.margins["row"] : self.margins["row"] + self.size["row"], + self.margins["column"] : self.margins["column"] + self.size["column"], + ] def average_tile_is_untrimmed(self, average_tile): - return average_tile.shape[0] > self.image.shape[0] \ - or average_tile.shape[1] > self.image.shape[1] + return average_tile.shape[0] > self.image.shape[0] or average_tile.shape[1] > self.image.shape[1] def apply_average_tile(self, average_tile): - if average_tile is None: logging.info( - 'no average tile found for tile with index {index} on ' - 'channel {channel}'.format( - **self.__dict__)) + "no average tile found for tile with index {index} on channel {channel}".format(**self.__dict__) + ) return self.image if self.average_tile_is_untrimmed(average_tile): - logging.info('trimming average tile') + logging.info("trimming average tile") average_tile = self.trim(average_tile) logging.info( - 'applying flatfield correction to tile with index {index} on ' - 'channel {channel}'.format( - **self.__dict__)) + "applying flatfield correction to tile with index {index} on channel {channel}".format(**self.__dict__) + ) return np.multiply(self.image, average_tile) def apply_average_tile_to_self(self, average_tile): self.image = self.apply_average_tile(average_tile) def get_image_region(self): + row = self.bounds["row"] + col = self.bounds["column"] - row = self.bounds['row'] - col = self.bounds['column'] - - return slice(row['start'], row['end']), \ - slice(col['start'], col['end']), \ - self.channel + return slice(row["start"], row["end"]), slice(col["start"], col["end"]), self.channel def get_missing_path(self): + row = self.bounds["row"] + col = self.bounds["column"] - row = self.bounds['row'] - col = self.bounds['column'] - - path = [row['start'], col['start'], - row['end'], col['start'], - row['end'], col['end'], - row['start'], col['end']] + path = [row["start"], col["start"], row["end"], col["start"], row["end"], col["end"], row["start"], col["end"]] - logging.info('missing tile starts at: ({0}, {1})'.format(*path)) + logging.info("missing tile starts at: ({0}, {1})".format(*path)) return path def initialize_image(self): - logging.info('initializing tile image to 0') - self.image = np.zeros((self.size['row'], self.size['column'])) + logging.info("initializing tile image to 0") + self.image = np.zeros((self.size["row"], self.size["column"])) diff --git a/allensdk/internal/notebooks/execute_notebooks.py b/allensdk/internal/notebooks/execute_notebooks.py index 43cc28156c..f36854742f 100644 --- a/allensdk/internal/notebooks/execute_notebooks.py +++ b/allensdk/internal/notebooks/execute_notebooks.py @@ -13,43 +13,25 @@ from traitlets.config import Config parser = ArgumentParser() -parser.add_argument( - '--notebooks_dir', - required=True, - help='Path to notebooks to execute' -) -parser.add_argument( - '--skip_notebooks', - nargs='+', - help='List of notebook names to skip', - default=[] -) +parser.add_argument("--notebooks_dir", required=True, help="Path to notebooks to execute") +parser.add_argument("--skip_notebooks", nargs="+", help="List of notebook names to skip", default=[]) args = parser.parse_args() logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(name='Notebook runner') +logger = logging.getLogger(name="Notebook runner") NOTEBOOK_ARGS = { - 'brain_observatory_analysis.ipynb': { - 'RUN_LOCALLY_SPARSE_NOISE': False - }, - 'ecephys_data_access.ipynb': { - 'DOWNLOAD_COMPLETE_DATASET': False - }, - 'visual_behavior_neuropixels_data_access.ipynb': { - 'DOWNLOAD_COMPLETE_DATASET': False - }, - 'visual_behavior_ophys_data_access.ipynb': { - 'DOWNLOAD_COMPLETE_DATASET': False - }, - 'ecephys_session.ipynb': { - 'DOWNLOAD_LFP': False - } + "brain_observatory_analysis.ipynb": {"RUN_LOCALLY_SPARSE_NOISE": False}, + "ecephys_data_access.ipynb": {"DOWNLOAD_COMPLETE_DATASET": False}, + "visual_behavior_neuropixels_data_access.ipynb": {"DOWNLOAD_COMPLETE_DATASET": False}, + "visual_behavior_ophys_data_access.ipynb": {"DOWNLOAD_COMPLETE_DATASET": False}, + "ecephys_session.ipynb": {"DOWNLOAD_LFP": False}, } class NotebookRunner: """Notebook runner""" + def __init__(self, notebooks_dir: str): """ @@ -58,12 +40,8 @@ def __init__(self, notebooks_dir: str): notebooks_dir Path to notebooks """ - notebook_paths = [ - Path(x) for x in glob.glob(os.path.join(notebooks_dir, "*.ipynb")) - ] - self._notebook_paths = [ - x for x in notebook_paths - if x.name not in args.skip_notebooks] + notebook_paths = [Path(x) for x in glob.glob(os.path.join(notebooks_dir, "*.ipynb"))] + self._notebook_paths = [x for x in notebook_paths if x.name not in args.skip_notebooks] def run(self): """Runs each notebook, overwriting it with updated output, @@ -78,7 +56,7 @@ def run(self): for notebook_path in self._notebook_paths: with tempfile.TemporaryDirectory() as tmp_dir: - tmp_nb_path = Path(tmp_dir) / 'scratch_nb.ipynb' + tmp_nb_path = Path(tmp_dir) / "scratch_nb.ipynb" try: papermill.execute_notebook( input_path=notebook_path, @@ -86,27 +64,24 @@ def run(self): # Note: notebook must have a variable with this name # and the cell must have tag 'parameters' parameters={ - **{ - 'output_dir': tmp_dir, - 'resources_dir': str(Path(__file__).parent / - 'resources') - }, - **NOTEBOOK_ARGS.get(notebook_path.name, {}) + **{"output_dir": tmp_dir, "resources_dir": str(Path(__file__).parent / "resources")}, + **NOTEBOOK_ARGS.get(notebook_path.name, {}), }, - kernel_name='python3' + kernel_name="python3", ) - self._remove_injected_parameters_cell( - notebook_path=tmp_nb_path) + self._remove_injected_parameters_cell(notebook_path=tmp_nb_path) - logging.info('Executing notebook succeeded. ' - f'Overwriting with new notebook output. ' - f'Moving {tmp_nb_path} to {notebook_path}') + logging.info( + "Executing notebook succeeded. " + f"Overwriting with new notebook output. " + f"Moving {tmp_nb_path} to {notebook_path}" + ) shutil.move(tmp_nb_path, notebook_path) except PapermillExecutionError as e: logging.error(e) errors.append(notebook_path.name) if len(errors) > 0: - msg = f'{len(errors)} notebooks failed. Errors in: {errors}' + msg = f"{len(errors)} notebooks failed. Errors in: {errors}" logging.error(msg) raise RuntimeError(msg) @@ -117,13 +92,10 @@ def _remove_injected_parameters_cell(notebook_path): """Removes cells with tag "injected-parameters" and outputs notebook""" c = Config() c.TagRemovePreprocessor.remove_cell_tags = ("injected-parameters",) - c.NotebookExporter.preprocessors = [ - "nbconvert.preprocessors.TagRemovePreprocessor"] + c.NotebookExporter.preprocessors = ["nbconvert.preprocessors.TagRemovePreprocessor"] exporter = NotebookExporter(config=c) - exporter.register_preprocessor(TagRemovePreprocessor(config=c), - enabled=True) - output = NotebookExporter(config=c).from_filename( - notebook_path) + exporter.register_preprocessor(TagRemovePreprocessor(config=c), enabled=True) + output = NotebookExporter(config=c).from_filename(notebook_path) with open(notebook_path, "w") as f: f.write(output[0]) @@ -133,5 +105,5 @@ def main(): notebook_runner.run() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/convert_igor_nwb.py b/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/convert_igor_nwb.py index 602eb9d62d..d27afbf157 100755 --- a/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/convert_igor_nwb.py +++ b/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/convert_igor_nwb.py @@ -134,8 +134,7 @@ def main(): t1 = t0 + PULSE_LEN ep = nd.create_epoch("TestPulse_" + num, t0, t1) ep.add_timeseries("stimulus", "stimulus/presentation/Sweep_" + num) - ep.add_timeseries("response", - "acquisition/timeseries/Sweep_" + num) + ep.add_timeseries("response", "acquisition/timeseries/Sweep_" + num) ep.finalize() # experiment epoch t0 = ts["starting_time"][()] @@ -143,8 +142,7 @@ def main(): t0 += EXPERIMENT_START_TIME ep = nd.create_epoch("Experiment_" + num, t0, t1) ep.add_timeseries("stimulus", "stimulus/presentation/Sweep_" + num) - ep.add_timeseries("response", - "acquisition/timeseries/Sweep_" + num) + ep.add_timeseries("response", "acquisition/timeseries/Sweep_" + num) ep.finalize() nd.close() @@ -152,9 +150,7 @@ def main(): # execute hdf5-repack to get it back to its original size try: print("Repacking hdf5 file with compression") - process = subprocess.Popen( - ["h5repack", "-f", "GZIP=4", tmpfile, outfile], - stdout=subprocess.PIPE) + process = subprocess.Popen(["h5repack", "-f", "GZIP=4", tmpfile, outfile], stdout=subprocess.PIPE) process.wait() except Exception: print("Unable to run h5repack on temporary nwb file") @@ -172,5 +168,5 @@ def main(): module.write_output_data({}) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/extract_nwb_data.py b/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/extract_nwb_data.py index 826f81030e..63828df9fc 100644 --- a/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/extract_nwb_data.py +++ b/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/extract_nwb_data.py @@ -10,18 +10,24 @@ # manual keys are values that can be passed in through input.json. # these values are used if the particular value cannot be computed. # a better name might be 'DEFAULT_VALUE_KEYS' -from allensdk.internal.pipeline_modules.IVSCC.ephys_nwb.lab_notebook_reader \ - import \ - create_lab_notebook_reader -from allensdk.internal.pipeline_modules.IVSCC.ephys_nwb.qc_support import \ - measure_blowout, measure_seal, find_stim_start, measure_vm, \ - get_last_vm_epoch, find_stim_interval, find_stim_amplitude_and_duration, \ - measure_electrode_0, get_last_vm_noise_epoch, get_stability_vm_epoch, \ - get_first_vm_noise_epoch, measure_input_resistance, \ - measure_initial_access_resistance - -MANUAL_KEYS = ['manual_seal_gohm', 'manual_initial_access_resistance_mohm', - 'manual_initial_input_mohm'] +from allensdk.internal.pipeline_modules.IVSCC.ephys_nwb.lab_notebook_reader import create_lab_notebook_reader +from allensdk.internal.pipeline_modules.IVSCC.ephys_nwb.qc_support import ( + measure_blowout, + measure_seal, + find_stim_start, + measure_vm, + get_last_vm_epoch, + find_stim_interval, + find_stim_amplitude_and_duration, + measure_electrode_0, + get_last_vm_noise_epoch, + get_stability_vm_epoch, + get_first_vm_noise_epoch, + measure_input_resistance, + measure_initial_access_resistance, +) + +MANUAL_KEYS = ["manual_seal_gohm", "manual_initial_access_resistance_mohm", "manual_initial_input_mohm"] # names of blocks used in output.json # for sweep-specific data: @@ -55,8 +61,7 @@ def build_sweep_stim_map(): try: nwb_file = h5py.File(nwb_file_name, "r") except Exception: - raise Exception( - "Unable to open input NWB file '%s'" % str(nwb_file_name)) + raise Exception("Unable to open input NWB file '%s'" % str(nwb_file_name)) print("Opened '%s'" % str(nwb_file_name)) sweep_stim_map = {} stim_sweep_map = {} @@ -81,13 +86,13 @@ def build_sweep_stim_map(): # fetches stimulus code for a given sweep name, or None if no stimulus # was found for the specified sweep def get_sweep_name_by_stimulus_code(stim_name): - """ Returns the first sweep name that uses the specified stimulus - type. 'First' does not mean lowest sweep number, only the first - one found using a [random] dictionary search. + """Returns the first sweep name that uses the specified stimulus + type. 'First' does not mean lowest sweep number, only the first + one found using a [random] dictionary search. - Input: stimulus name (string) + Input: stimulus name (string) - Output: sweep name (string), or None if no sweep found for this stim + Output: sweep name (string), or None if no sweep found for this stim """ global sweep_stim_map for k, v in stim_sweep_map.items(): @@ -99,24 +104,24 @@ def get_sweep_name_by_stimulus_code(stim_name): # returns True if stimulus name for specified sweep indicates the sweep # is a ramp and False otherwise def sweep_is_ramp(sweep_name): - """ Input: sweep name (string) + """Input: sweep name (string) - Output: boolean (True if sweep is ramp, False otherwise) + Output: boolean (True if sweep is ramp, False otherwise) """ global sweep_stim_map - return sweep_stim_map[sweep_name].startswith('C1RP') + return sweep_stim_map[sweep_name].startswith("C1RP") # old code based on using NwbDataSet objects. provide a way to # create them in order to leverage old code as much as possible def get_sweep_data(sweep_name): - """ Input: sweep name (string) + """Input: sweep name (string) - Output: NwbDataSet object + Output: NwbDataSet object """ global nwb_file_name try: - num = int(sweep_name.split('_')[-1]) + num = int(sweep_name.split("_")[-1]) except Exception: print("Unable to parse sweep number from '%s'" % str(sweep_name)) raise @@ -126,36 +131,32 @@ def get_sweep_data(sweep_name): # functions to lookup a sweep having the desired stimulus code # NOTE: if multiple instance exist then only one instance is returned def get_blowout_sweep(): - """ Returns NwbDataSet for the blowout sweep, or None if it's absent - """ - sweep_name = get_sweep_name_by_stimulus_code('EXTPBLWOUT') + """Returns NwbDataSet for the blowout sweep, or None if it's absent""" + sweep_name = get_sweep_name_by_stimulus_code("EXTPBLWOUT") if sweep_name is None: return None return get_sweep_data(sweep_name) def get_bath_sweep(): - """ Returns NwbDataSet for the bath sweep, or None if it's absent - """ - sweep_name = get_sweep_name_by_stimulus_code('EXTPINBATH') + """Returns NwbDataSet for the bath sweep, or None if it's absent""" + sweep_name = get_sweep_name_by_stimulus_code("EXTPINBATH") if sweep_name is None: return None return get_sweep_data(sweep_name) def get_seal_sweep(): - """ Returns NwbDataSet for the seal sweep, or None if it's absent - """ - sweep_name = get_sweep_name_by_stimulus_code('EXTPCllATT') + """Returns NwbDataSet for the seal sweep, or None if it's absent""" + sweep_name = get_sweep_name_by_stimulus_code("EXTPCllATT") if sweep_name is None: return None return get_sweep_data(sweep_name) def get_breakin_sweep(): - """ Returns NwbDataSet for the breakin sweep, or None if it's absent - """ - sweep_name = get_sweep_name_by_stimulus_code('EXTPBREAKN') + """Returns NwbDataSet for the breakin sweep, or None if it's absent""" + sweep_name = get_sweep_name_by_stimulus_code("EXTPBREAKN") if sweep_name is None: return None return get_sweep_data(sweep_name) @@ -167,42 +168,38 @@ def get_breakin_sweep(): ######################################################################## # QC-relevant feature extraction code + # cell-level values (for ephys_roi_results) def cell_level_features(jin, jout, sweep_tag_list, manual_values): - """ - """ + """ """ output_data = {} jout[JSON_BLOCK_EXPERIMENT_DATA] = output_data # measure blowout voltage try: blowout_data = get_blowout_sweep() - blowout = measure_blowout(blowout_data['response'], - blowout_data['index_range'][0]) - output_data['blowout_mv'] = blowout + blowout = measure_blowout(blowout_data["response"], blowout_data["index_range"][0]) + output_data["blowout_mv"] = blowout except Exception: msg = "Blowout is not available" sweep_tag_list.append(msg) logging.warning(msg) - output_data['blowout_mv'] = None + output_data["blowout_mv"] = None # measure "electrode 0" try: bath_data = get_bath_sweep() - e0 = measure_electrode_0(bath_data['response'], - bath_data['sampling_rate']) - output_data['electrode_0_pa'] = e0 + e0 = measure_electrode_0(bath_data["response"], bath_data["sampling_rate"]) + output_data["electrode_0_pa"] = e0 except Exception: msg = "Electrode 0 is not available" sweep_tag_list.append(msg) logging.warning(msg) - output_data['electrode_0_pa'] = None + output_data["electrode_0_pa"] = None # measure clamp seal try: seal_data = get_seal_sweep() - seal = measure_seal(seal_data['stimulus'], - seal_data['response'], - seal_data['sampling_rate']) + seal = measure_seal(seal_data["stimulus"], seal_data["response"], seal_data["sampling_rate"]) # error may arise in computing seal, which falls through to # exception handler. if seal computation didn't fail but # computation generated invalid value, trigger same @@ -215,7 +212,7 @@ def cell_level_features(jin, jout, sweep_tag_list, manual_values): sweep_tag_list.append(msg) logging.warning(msg) # look for manual seal value and use it if it's available - seal = manual_values.get('manual_seal_gohm', None) + seal = manual_values.get("manual_seal_gohm", None) if seal is not None: logging.info("using manual seal value: %f" % seal) sweep_tag_list.append("Seal set using manual value") @@ -238,15 +235,15 @@ def cell_level_features(jin, jout, sweep_tag_list, manual_values): ########################### # input resistance try: - ir = measure_input_resistance(breakin_data['stimulus'], - breakin_data['response'], - breakin_data['sampling_rate']) + ir = measure_input_resistance( + breakin_data["stimulus"], breakin_data["response"], breakin_data["sampling_rate"] + ) except Exception: logging.warning("Error reading input resistance.") # apply manual value if it's available if ir is None: sweep_tag_list.append("Input resistance is not available") - ir = manual_values.get('manual_initial_input_mohm', None) + ir = manual_values.get("manual_initial_input_mohm", None) if ir is not None: msg = "Using manual value for input resistance" logging.info(msg) @@ -254,23 +251,21 @@ def cell_level_features(jin, jout, sweep_tag_list, manual_values): ########################### # initial access resistance try: - sr = measure_initial_access_resistance(breakin_data['stimulus'], - breakin_data['response'], - breakin_data[ - 'sampling_rate']) + sr = measure_initial_access_resistance( + breakin_data["stimulus"], breakin_data["response"], breakin_data["sampling_rate"] + ) except Exception: logging.warning("Error reading initial access resistance.") # apply manual value if it's available if sr is None: sweep_tag_list.append("Initial access resistance is not available") - sr = manual_values.get('manual_initial_access_resistance_mohm', - None) + sr = manual_values.get("manual_initial_access_resistance_mohm", None) if sr is not None: msg = "Using manual initial access resistance" logging.info(msg) sweep_tag_list.append(msg) # - output_data['input_resistance_mohm'] = ir + output_data["input_resistance_mohm"] = ir output_data["initial_access_resistance_mohm"] = sr sr_ratio = None # input access resistance ratio @@ -279,19 +274,18 @@ def cell_level_features(jin, jout, sweep_tag_list, manual_values): sr_ratio = sr / ir except Exception: pass # let sr_ratio stay as None - output_data['input_access_resistance_ratio'] = sr_ratio + output_data["input_access_resistance_ratio"] = sr_ratio ############################## def sweep_level_features(jin, jout, sweep_tag_list): - """ - """ + """ """ global sweep_list # pull out features from each sweep (for ephys_sweeps) jout[JSON_BLOCK_SWEEP_DATA] = {} for sweep_name in sweep_list: # pull data streams from file - sweep_num = int(sweep_name.split('_')[-1]) + sweep_num = int(sweep_name.split("_")[-1]) try: sweep_data = NwbDataSet(nwb_file_name).get_sweep(sweep_num) except Exception: @@ -304,10 +298,10 @@ def sweep_level_features(jin, jout, sweep_tag_list): if sweep_data["stimulus_unit"] == "Volts": continue # voltage-clamp - volts = sweep_data['response'] - current = sweep_data['stimulus'] - hz = sweep_data['sampling_rate'] - idx_start, idx_stop = sweep_data['index_range'] + volts = sweep_data["response"] + current = sweep_data["stimulus"] + hz = sweep_data["sampling_rate"] + idx_start, idx_stop = sweep_data["index_range"] # measure Vm and noise before stimulus idx0, idx1 = get_first_vm_noise_epoch(idx_start, current, hz) @@ -320,7 +314,7 @@ def sweep_level_features(jin, jout, sweep_tag_list): # do not check for ramps, because they do not have enough time to # recover mean1 = None - sweep_not_truncated = (idx_stop == len(current) - 1) + sweep_not_truncated = idx_stop == len(current) - 1 if sweep_not_truncated and not sweep_is_ramp(sweep_name): idx0, idx1 = get_last_vm_epoch(idx_stop, current, hz) mean1, _ = measure_vm(1e3 * volts[idx0:idx1]) @@ -331,7 +325,7 @@ def sweep_level_features(jin, jout, sweep_tag_list): # measure Vm and noise over extended interval, to check stability stim_start = find_stim_start(idx_start, current) - sweep['stimulus_start_time'] = stim_start / sweep_data['sampling_rate'] + sweep["stimulus_start_time"] = stim_start / sweep_data["sampling_rate"] idx0, idx1 = get_stability_vm_epoch(idx_start, stim_start, hz) mean2, rms2 = measure_vm(1000 * volts[idx0:idx1]) @@ -350,13 +344,12 @@ def sweep_level_features(jin, jout, sweep_tag_list): sweep["vm_delta_mv"] = None # compute stimulus duration, amplitude, interal - stim_amp, stim_dur = find_stim_amplitude_and_duration(idx_start, - current, hz) + stim_amp, stim_dur = find_stim_amplitude_and_duration(idx_start, current, hz) stim_int = find_stim_interval(idx_start, current, hz) - sweep['stimulus_amplitude'] = stim_amp * 1e12 - sweep['stimulus_duration'] = stim_dur - sweep['stimulus_interval'] = stim_int + sweep["stimulus_amplitude"] = stim_amp * 1e12 + sweep["stimulus_duration"] = stim_dur + sweep["stimulus_interval"] = stim_int tag_list = [] for i in range(len(sweep_tag_list)): @@ -377,7 +370,7 @@ def summarize_sweeps(jin, jout): h5_file_name = jin.get("input_h5", None) notebook = create_lab_notebook_reader(nwb_file_name, h5_file_name) - borg = h5py.File(nwb_file_name, 'r') + borg = h5py.File(nwb_file_name, "r") # two json blocks to store data in exp_data = jout[JSON_BLOCK_EXPERIMENT_DATA] @@ -402,7 +395,7 @@ def summarize_sweeps(jin, jout): session_date = borg["session_start_time"][()] if len(session_date) == 1: session_date = session_date[0] - exp_data['recording_date'] = session_date + exp_data["recording_date"] = session_date # get sampling rate # use same output strategy as h5-nwb converter @@ -417,9 +410,8 @@ def summarize_sweeps(jin, jout): sampling_rate = sweep_ts["starting_time"].attrs["rate"] break if sampling_rate is None: - raise Exception( - "Unable to determine sampling rate from current clamp sweep.") - exp_data['sampling_rate'] = sampling_rate + raise Exception("Unable to determine sampling rate from current clamp sweep.") + exp_data["sampling_rate"] = sampling_rate # sweep_data = [] # output_data["sweep_summary"] = sweep_data @@ -427,22 +419,20 @@ def summarize_sweeps(jin, jout): for sweep_name in borg["acquisition/timeseries"]: # get h5 timeseries object, and the sweep number sweep_ts = borg["acquisition/timeseries"][sweep_name] - sweep_num = int(sweep_name.split('_')[-1]) + sweep_num = int(sweep_name.split("_")[-1]) # sweep_num = int(sweep_name[:-4].split('_')[-1]) # for reading igor # nwb # fetch stim name from lab notebook stim_name = notebook.get_value("Stim Wave Name", sweep_num, "") if len(stim_name) == 0: - raise Exception( - "Could not read stimulus wave name from lab notebook for " - "sweep %d" % sweep_num) + raise Exception("Could not read stimulus wave name from lab notebook for sweep %d" % sweep_num) # stim units are based on timeseries type ancestry = sweep_ts.attrs["ancestry"] if "CurrentClamp" in ancestry[-1]: - stim_units = 'pA' + stim_units = "pA" elif "VoltageClamp" in ancestry[-1]: - stim_units = 'mV' + stim_units = "mV" else: # it's probably OK to skip this sweep and put a 'continue' # here instead of an exception, but wait until there's @@ -454,9 +444,7 @@ def summarize_sweeps(jin, jout): # -> need to strip last 5 chars off to make match for lookup stim_type_name = stim_type_name_map.get(stim_name[:-5], None) if stim_type_name is None: - raise Exception( - "Could not find stimulus raw name (\"%s\") for sweep %d." % - (stim_name, sweep_num)) + raise Exception('Could not find stimulus raw name ("%s") for sweep %d.' % (stim_name, sweep_num)) # voltage-clamp sweeps shouldn't have a record yet -- make one if sweep_name not in swp_data: @@ -466,8 +454,7 @@ def summarize_sweeps(jin, jout): # sweep number info["sweep_number"] = sweep_num # bridge balance - bridge_balance = notebook.get_value("Bridge Bal Value", sweep_num, - None) + bridge_balance = notebook.get_value("Bridge Bal Value", sweep_num, None) # IT-14677 # if bridge_balance is None, that's OK. do NOT change it to NaN @@ -475,8 +462,7 @@ def summarize_sweeps(jin, jout): # stimulus units info["stimulus_units"] = stim_units # leak_pa (bias current) - bias_current = notebook.get_value("I-Clamp Holding Level", sweep_num, - None) + bias_current = notebook.get_value("I-Clamp Holding Level", sweep_num, None) # IT-14677 # if bias_current is None, that's OK. do NOT change it to NaN @@ -488,12 +474,12 @@ def summarize_sweeps(jin, jout): raise Exception("Unable to read scale factor for " + sweep_name) # PBS-229 change stim name by appending set_sweep_count cnt = notebook.get_value("Set Sweep Count", sweep_num, 0) - stim_name_ext = stim_name.split('_')[0] + "[%d]" % int(cnt) + stim_name_ext = stim_name.split("_")[0] + "[%d]" % int(cnt) info["ephys_stimulus"] = { # 'description': stim_name, - 'description': stim_name_ext, - 'amplitude': scale_factor, - 'ephys_stimulus_type': {'name': stim_type_name} + "description": stim_name_ext, + "amplitude": scale_factor, + "ephys_stimulus_type": {"name": stim_type_name}, } # borg.close() diff --git a/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/feature_extraction_module.py b/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/feature_extraction_module.py index 111cb3e404..549c858ce0 100755 --- a/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/feature_extraction_module.py +++ b/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/feature_extraction_module.py @@ -15,11 +15,11 @@ def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('input_json') - parser.add_argument('output_json') - parser.add_argument('--log_level') - parser.add_argument('--output_directory') - + parser.add_argument("input_json") + parser.add_argument("output_json") + parser.add_argument("--log_level") + parser.add_argument("--output_directory") + args = parser.parse_args() if args.log_level: @@ -27,6 +27,7 @@ def parse_args(): return args + def main(): args = parse_args() @@ -35,9 +36,9 @@ def main(): output_data = copy.deepcopy(input_err) - err_wkfs = output_data['well_known_files'] + err_wkfs = output_data["well_known_files"] nwb_file = lims_utilities.get_well_known_file_by_type(err_wkfs, lims_utilities.NWB_FILE_TYPE_ID) - storage_directory = args.output_directory or output_data['storage_directory'] + storage_directory = args.output_directory or output_data["storage_directory"] # move code to help make data extraction compatible with ephys qc tool try: sweep_list, sweep_features = extract_data(output_data, nwb_file) @@ -47,22 +48,22 @@ def main(): json_utilities.write(args.output_json, output_data) return # - + # embed spike times in NWB file logging.debug("Embedding spike times") - tmp_nwb_file = os.path.join(storage_directory, os.path.basename(nwb_file) + '.tmp') + tmp_nwb_file = os.path.join(storage_directory, os.path.basename(nwb_file) + ".tmp") out_nwb_file = os.path.join(storage_directory, os.path.basename(nwb_file)) shutil.copy(nwb_file, tmp_nwb_file) for sweep in sweep_list: - sweep_num = sweep['sweep_number'] + sweep_num = sweep["sweep_number"] if sweep_num not in sweep_features: continue try: - spikes = sweep_features[sweep_num]['spikes'] - spike_times = [ s['threshold_t'] for s in spikes ] + spikes = sweep_features[sweep_num]["spikes"] + spike_times = [s["threshold_t"] for s in spikes] NwbDataSet(tmp_nwb_file).set_spike_times(sweep_num, spike_times) except Exception as e: logging.info("sweep %d has no sweep features. %s", sweep_num, e.message) @@ -73,17 +74,17 @@ def main(): logging.error("Problem renaming file: %s -> %s" % (tmp_nwb_file, out_nwb_file)) raise e - qc_fig_dir = os.path.join(storage_directory, 'qc_figures') + qc_fig_dir = os.path.join(storage_directory, "qc_figures") save_qc_figures(qc_fig_dir, nwb_file, output_data, True) # regenerating this file - features_json = os.path.join(storage_directory, "%d_ephys_features.json" % output_data['id']) + features_json = os.path.join(storage_directory, "%d_ephys_features.json" % output_data["id"]) json_utilities.write(features_json, output_data) - lims_utilities.append_well_known_file(output_data['well_known_files'], features_json) + lims_utilities.append_well_known_file(output_data["well_known_files"], features_json) # write output json files json_utilities.write(args.output_json, output_data) - -if __name__ == "__main__": + +if __name__ == "__main__": main() diff --git a/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/lab_notebook_reader.py b/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/lab_notebook_reader.py index f82508e0bd..0f8cb30a4d 100644 --- a/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/lab_notebook_reader.py +++ b/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/lab_notebook_reader.py @@ -24,8 +24,7 @@ def register_enabled_names(self): # lab notebook has two sections, one for numeric data and the other # for text data. this is an internal function to fetch data from # the numeric part of the notebook - def get_numeric_value(self, name, data_col, sweep_col, enable_col, - sweep_num, default_val): + def get_numeric_value(self, name, data_col, sweep_col, enable_col, sweep_num, default_val): data = self.val_number # val_number has 3 dimensions -- the first has a shape of # (#fields * 9). there are many hundreds of elements in this @@ -51,8 +50,7 @@ def get_numeric_value(self, name, data_col, sweep_col, enable_col, return return_val # internal function for fetching data from the text part of the notebook - def get_text_value(self, name, data_col, sweep_col, enable_col, sweep_num, - default_val): + def get_text_value(self, name, data_col, sweep_col, enable_col, sweep_num, default_val): data = self.val_text # algorithm mirrors get_numeric_value # return value is last non-empty entry in specified column @@ -96,8 +94,7 @@ def get_value(self, name, sweep_num, default_val): enable_col = self.enabled[name] enable_idx = numeric_fields.tolist().index(enable_col) field_idx = numeric_fields.tolist().index(name) - return self.get_numeric_value(name, field_idx, sweep_idx, - enable_idx, sweep_num, default_val) + return self.get_numeric_value(name, field_idx, sweep_idx, enable_idx, sweep_num, default_val) elif name in text_fields: # first check to see if file includes old version of column name if "Sweep #" in text_fields: @@ -109,8 +106,7 @@ def get_value(self, name, sweep_num, default_val): enable_col = self.enabled[name] enable_idx = text_fields.tolist().index(enable_col) field_idx = text_fields.tolist().index(name) - return self.get_text_value(name, field_idx, sweep_idx, enable_idx, - sweep_num, default_val) + return self.get_text_value(name, field_idx, sweep_idx, enable_idx, sweep_num, default_val) else: return default_val diff --git a/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/nwb_publish.py b/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/nwb_publish.py index 08a4fd2193..4db5319e0d 100755 --- a/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/nwb_publish.py +++ b/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/nwb_publish.py @@ -22,8 +22,8 @@ """ local_dir = os.path.dirname(os.path.realpath(__file__)) -if not local_dir.endswith('/'): - local_dir += '/' +if not local_dir.endswith("/"): + local_dir += "/" ELECTRODE_NAME = "Electrode 1" ELECTRODE_PATH = "/general/intracellular_ephys/" + ELECTRODE_NAME @@ -78,8 +78,7 @@ def copy_timeseries(timeseries, old_file, new_file, folder, metadata): # older experiments put this in "units" unit = old_ts["data"].attrs["units"] - new_ts.set_data(data, conversion=conversion, resolution=resolution, - unit=unit) + new_ts.set_data(data, conversion=conversion, resolution=resolution, unit=unit) start_time = old_ts["starting_time"][()] sampling_rate = old_ts["starting_time"].attrs["rate"] @@ -124,7 +123,7 @@ def copy_epochs(timeseries, old_file, new_file, folder): anc = old_file["acquisition/timeseries/" + name].attrs["ancestry"] if anc[-1] == "VoltageClampSeries": continue - num = int(name.split('_')[-1]) + num = int(name.split("_")[-1]) # experiment block epname = "Experiment_%d" % num ep = old_file["epochs/%s" % epname] @@ -133,10 +132,8 @@ def copy_epochs(timeseries, old_file, new_file, folder): desc = ep["description"][()] ep = new_file.create_epoch(epname, start, stop) ep.set_value("description", desc) - ep.add_timeseries("stimulus", - "/stimulus/presentation/Sweep_%d" % num) - ep.add_timeseries("response", - "/acquisition/timeseries/Sweep_%d" % num) + ep.add_timeseries("stimulus", "/stimulus/presentation/Sweep_%d" % num) + ep.add_timeseries("response", "/acquisition/timeseries/Sweep_%d" % num) ep.finalize() # test-pulse block epname = "TestPulse_%d" % num @@ -146,10 +143,8 @@ def copy_epochs(timeseries, old_file, new_file, folder): desc = ep["description"][()] ep = new_file.create_epoch(epname, start, stop) ep.set_value("description", desc) - ep.add_timeseries("stimulus", - "/stimulus/presentation/Sweep_%d" % num) - ep.add_timeseries("response", - "/acquisition/timeseries/Sweep_%d" % num) + ep.add_timeseries("stimulus", "/stimulus/presentation/Sweep_%d" % num) + ep.add_timeseries("response", "/acquisition/timeseries/Sweep_%d" % num) ep.finalize() # sweep block epname = name @@ -159,10 +154,8 @@ def copy_epochs(timeseries, old_file, new_file, folder): desc = ep["description"][()] ep = new_file.create_epoch(epname, start, stop) ep.set_value("description", desc) - ep.add_timeseries("stimulus", - "/stimulus/presentation/Sweep_%d" % num) - ep.add_timeseries("response", - "/acquisition/timeseries/Sweep_%d" % num) + ep.add_timeseries("stimulus", "/stimulus/presentation/Sweep_%d" % num) + ep.add_timeseries("response", "/acquisition/timeseries/Sweep_%d" % num) ep.finalize() except Exception: @@ -173,7 +166,7 @@ def copy_epochs(timeseries, old_file, new_file, folder): def copy_file(infile, outfile, passing_sweeps, rsrc, metadata): print("Opening '%s'" % infile) - old = h5py.File(infile, 'r') + old = h5py.File(infile, "r") # top-level data try: vargs = {} @@ -229,56 +222,53 @@ def copy_file(infile, outfile, passing_sweeps, rsrc, metadata): def organize_metadata(ephys_roi_result): - metadata = {'sweeps': {}} + metadata = {"sweeps": {}} - cell_specimen = ephys_roi_result['specimens'][0] - slice_specimen = ephys_roi_result['specimen'] + cell_specimen = ephys_roi_result["specimens"][0] + slice_specimen = ephys_roi_result["specimen"] - metadata['donor_id'] = cell_specimen['donor_id'] - metadata['specimen_name'] = cell_specimen['name'] - metadata['specimen_id'] = cell_specimen['id'] + metadata["donor_id"] = cell_specimen["donor_id"] + metadata["specimen_name"] = cell_specimen["name"] + metadata["specimen_id"] = cell_specimen["id"] try: - metadata['species'] = slice_specimen['donor']['organism']["name"] + metadata["species"] = slice_specimen["donor"]["organism"]["name"] except Exception: logging.error("Unable to read organism name from input.json file") raise # structure try: - structure = cell_specimen['structure'] + structure = cell_specimen["structure"] except Exception: logging.error("Cell has no structure association.") raise soma_location = {} - db_cell_soma_location = cell_specimen['cell_soma_locations'][0] + db_cell_soma_location = cell_specimen["cell_soma_locations"][0] soma_location = {} try: - soma_location['cell_soma_location_x'] = 1e-9 * db_cell_soma_location[ - 'x'] - soma_location['cell_soma_location_y'] = 1e-9 * db_cell_soma_location[ - 'y'] - soma_location['cell_soma_location_z'] = 1e-9 * db_cell_soma_location[ - 'z'] - nd = db_cell_soma_location['normalized_depth'] + soma_location["cell_soma_location_x"] = 1e-9 * db_cell_soma_location["x"] + soma_location["cell_soma_location_y"] = 1e-9 * db_cell_soma_location["y"] + soma_location["cell_soma_location_z"] = 1e-9 * db_cell_soma_location["z"] + nd = db_cell_soma_location["normalized_depth"] if nd is not None: - soma_location['cell_soma_location_normalized_depth'] = nd + soma_location["cell_soma_location_normalized_depth"] = nd except Exception as e: logging.error(e.message) raise structure_info = {} try: - structure_info['structure_id'] = structure['id'] - structure_info['structure_name'] = structure['name'] - structure_info['structure_acronym'] = structure['acronym'] + structure_info["structure_id"] = structure["id"] + structure_info["structure_name"] = structure["name"] + structure_info["structure_acronym"] = structure["acronym"] except Exception: logging.error("Structure information is missing from input.json") raise structure_info.update(soma_location) - metadata['location'] = structure_info + metadata["location"] = structure_info tags = cell_specimen["specimen_tags"] @@ -300,83 +290,76 @@ def organize_metadata(ephys_roi_result): if dend_type is None: raise Exception("Cell has no dendrite type tag.") - metadata['dendrite_type'] = dend_type - metadata['dendrite_trunc'] = dend_trunc + metadata["dendrite_type"] = dend_type + metadata["dendrite_trunc"] = dend_trunc - metadata['ephys_roi_result_id'] = ephys_roi_result['id'] - metadata['seal_gohm'] = ephys_roi_result['seal_gohm'] - metadata['initial_access_resistance_mohm'] = ephys_roi_result[ - 'initial_access_resistance_mohm'] + metadata["ephys_roi_result_id"] = ephys_roi_result["id"] + metadata["seal_gohm"] = ephys_roi_result["seal_gohm"] + metadata["initial_access_resistance_mohm"] = ephys_roi_result["initial_access_resistance_mohm"] - slice_specimen = ephys_roi_result['specimen'] - donor = slice_specimen['donor'] + slice_specimen = ephys_roi_result["specimen"] + donor = slice_specimen["donor"] # gender try: - metadata['gender'] = donor['gender']['name'] + metadata["gender"] = donor["gender"]["name"] except Exception: logging.error("Donor requires gender association.") raise # age try: - age = donor['age'] + age = donor["age"] except Exception: logging.error("Donor requires age association.") raise - metadata['age'] = { - 'date_of_birth': donor['date_of_birth'], - 'name': age['name'] - } + metadata["age"] = {"date_of_birth": donor["date_of_birth"], "name": age["name"]} # cre line and genotype are mouse-only - if metadata['species'] == 'Mus musculus': - genotypes = donor['genotypes'] + if metadata["species"] == "Mus musculus": + genotypes = donor["genotypes"] try: - reporter_genotype = next( - g for g in genotypes if g['genotype_type_id'] == 177835595) - metadata['cre_line'] = reporter_genotype['name'] + reporter_genotype = next(g for g in genotypes if g["genotype_type_id"] == 177835595) + metadata["cre_line"] = reporter_genotype["name"] except Exception: logging.error("Could not find reporter genotype for mouse cell") raise - metadata['genotype'] = { - 'description': [g['description'] for g in genotypes], - 'type': [g['name'] for g in genotypes] + metadata["genotype"] = { + "description": [g["description"] for g in genotypes], + "type": [g["name"] for g in genotypes], } else: logging.info("non-mouse cells do not have cre line or genotypes") # subject - metadata['subject'] = { - 'subject_id': cell_specimen['donor_id'], - 'comments': 'subject_id value here corresponds to Allen Institute ' - 'cell specimen "donor_id"' + metadata["subject"] = { + "subject_id": cell_specimen["donor_id"], + "comments": 'subject_id value here corresponds to Allen Institute cell specimen "donor_id"', } # sweeps - sweeps = cell_specimen['ephys_sweeps'] + sweeps = cell_specimen["ephys_sweeps"] for sweep in sweeps: if "invalid" in sweep and sweep["invalid"]: - logging.debug("skipping sweep %d, invalid" % sweep['sweep_number']) + logging.debug("skipping sweep %d, invalid" % sweep["sweep_number"]) continue - wfs = sweep['workflow_state'] - if wfs not in ['manual_passed', 'auto_passed']: - logging.debug( - "skipping sweep %d, not passed" % sweep['sweep_number']) + wfs = sweep["workflow_state"] + if wfs not in ["manual_passed", "auto_passed"]: + logging.debug("skipping sweep %d, not passed" % sweep["sweep_number"]) continue - stimulus = sweep['ephys_stimulus'] - stimulus_type = stimulus['ephys_stimulus_type'] + stimulus = sweep["ephys_stimulus"] + stimulus_type = stimulus["ephys_stimulus_type"] - metadata['sweeps'][sweep['sweep_number']] = { - 'stimulus_name': stimulus['description'], - 'stimulus_interval': sweep['stimulus_interval'], - 'stimulus_amplitude': sweep['stimulus_amplitude'], - 'stimulus_type_name': stimulus_type['name'], - 'stimulus_units': sweep["stimulus_units"] + metadata["sweeps"][sweep["sweep_number"]] = { + "stimulus_name": stimulus["description"], + "stimulus_interval": sweep["stimulus_interval"], + "stimulus_amplitude": sweep["stimulus_amplitude"], + "stimulus_type_name": stimulus_type["name"], + "stimulus_units": sweep["stimulus_units"], } # IT-12498 add additional metadata to NWB file @@ -386,64 +369,54 @@ def organize_metadata(ephys_roi_result): metadata["pharmacology"] = "please see " + url metadata["citation_policy"] = "please see " + url metadata["institution"] = "Allen Institute for Brain Science" - metadata["generated_by"] = ["pipeline", PIPELINE_NAME, "version", - PIPELINE_VERSION] + metadata["generated_by"] = ["pipeline", PIPELINE_NAME, "version", PIPELINE_VERSION] return metadata def write_metadata(nwb_file, resources, metadata): - nwb_file.set_metadata(nwbco.SEX, metadata['gender']) - if 'cre_line' in metadata: - nwb_file.set_metadata("aibs_cre_line", metadata['cre_line']) + nwb_file.set_metadata(nwbco.SEX, metadata["gender"]) + if "cre_line" in metadata: + nwb_file.set_metadata("aibs_cre_line", metadata["cre_line"]) - if 'genotype' in metadata: - genotype = metadata['genotype'] - genotype_name = '; '.join(genotype['type']) + if "genotype" in metadata: + genotype = metadata["genotype"] + genotype_name = "; ".join(genotype["type"]) nwb_file.set_metadata(nwbco.GENOTYPE, genotype_name, **genotype) - nwb_file.set_metadata('generated_by', metadata['generated_by']) + nwb_file.set_metadata("generated_by", metadata["generated_by"]) - subject = metadata['subject'] + subject = metadata["subject"] nwb_file.set_metadata(nwbco.SUBJECT, resources.get("subject"), **subject) - age = metadata['age'] - nwb_file.set_metadata(nwbco.AGE, age['name'], **age) + age = metadata["age"] + nwb_file.set_metadata(nwbco.AGE, age["name"], **age) trode = ELECTRODE_NAME - nwb_file.set_metadata(nwbco.INTRA_ELECTRODE_DESCRIPTION(trode), - resources.get("electrode_description")) - nwb_file.set_metadata(nwbco.INTRA_ELECTRODE_FILTERING(trode), - resources.get("electrode_filtering")) - nwb_file.set_metadata(nwbco.INTRA_ELECTRODE_DEVICE(trode), - resources.get("electrode_device")) - - location = metadata['location'] - nwb_file.set_metadata(nwbco.INTRA_ELECTRODE_LOCATION(trode), - location['structure_name'], **location) - - nwb_file.set_metadata(nwbco.INTRA_ELECTRODE_RESISTANCE(trode), - resources.get("electrode_resistance")) - nwb_file.set_metadata(nwbco.INTRA_ELECTRODE_SLICE(trode), - resources.get("electrode_slice")) - - seal_gohm = str(metadata['seal_gohm']) - nwb_file.set_metadata(nwbco.INTRA_ELECTRODE_SEAL(trode), - seal_gohm + " GOhm") + nwb_file.set_metadata(nwbco.INTRA_ELECTRODE_DESCRIPTION(trode), resources.get("electrode_description")) + nwb_file.set_metadata(nwbco.INTRA_ELECTRODE_FILTERING(trode), resources.get("electrode_filtering")) + nwb_file.set_metadata(nwbco.INTRA_ELECTRODE_DEVICE(trode), resources.get("electrode_device")) + + location = metadata["location"] + nwb_file.set_metadata(nwbco.INTRA_ELECTRODE_LOCATION(trode), location["structure_name"], **location) + + nwb_file.set_metadata(nwbco.INTRA_ELECTRODE_RESISTANCE(trode), resources.get("electrode_resistance")) + nwb_file.set_metadata(nwbco.INTRA_ELECTRODE_SLICE(trode), resources.get("electrode_slice")) + + seal_gohm = str(metadata["seal_gohm"]) + nwb_file.set_metadata(nwbco.INTRA_ELECTRODE_SEAL(trode), seal_gohm + " GOhm") acc = str(metadata["initial_access_resistance_mohm"]) - nwb_file.set_metadata(nwbco.INTRA_ELECTRODE_INIT_ACCESS_RESISTANCE(trode), - acc + " MOhm") + nwb_file.set_metadata(nwbco.INTRA_ELECTRODE_INIT_ACCESS_RESISTANCE(trode), acc + " MOhm") - session = {'comments': 'session_id value corresponds to ephys_result_id'} - nwb_file.set_metadata(nwbco.SESSION_ID, - str(metadata['ephys_roi_result_id']), **session) + session = {"comments": "session_id value corresponds to ephys_result_id"} + nwb_file.set_metadata(nwbco.SESSION_ID, str(metadata["ephys_roi_result_id"]), **session) - nwb_file.set_metadata("aibs_specimen_name", metadata['specimen_name']) - nwb_file.set_metadata("aibs_specimen_id", str(metadata['specimen_id'])) + nwb_file.set_metadata("aibs_specimen_name", metadata["specimen_name"]) + nwb_file.set_metadata("aibs_specimen_id", str(metadata["specimen_id"])) - nwb_file.set_metadata("aibs_dendrite_type", metadata['dendrite_type']) - nwb_file.set_metadata("aibs_dendrite_trunc", metadata['dendrite_trunc']) + nwb_file.set_metadata("aibs_dendrite_type", metadata["dendrite_type"]) + nwb_file.set_metadata("aibs_dendrite_trunc", metadata["dendrite_trunc"]) # IT-12498 add additional metadata to NWB file nwb_file.set_metadata(nwbco.DATA_COLLECTION, metadata["data_collection"]) @@ -451,7 +424,7 @@ def write_metadata(nwb_file, resources, metadata): nwb_file.set_metadata(nwbco.PROTOCOL, metadata["protocol"]) nwb_file.set_metadata(nwbco.PHARMACOLOGY, metadata["pharmacology"]) nwb_file.set_metadata("citation_policy", metadata["citation_policy"]) - nwb_file.set_metadata(nwbco.SPECIES, metadata['species']) + nwb_file.set_metadata(nwbco.SPECIES, metadata["species"]) def main(jin): @@ -469,7 +442,7 @@ def main(jin): # TODO dig deeper here # only fetching metadata for passing sweeps - passing_sweeps = metadata['sweeps'].keys() + passing_sweeps = metadata["sweeps"].keys() copy_file(infile, outfile, passing_sweeps, rsrc, metadata) @@ -538,7 +511,7 @@ def main(jin): sweeps = jin[0]["specimens"][0]["ephys_sweeps"] for grp in acq: try: - num = int(str(grp).split('_')[-1]) + num = int(str(grp).split("_")[-1]) except Exception: continue try: @@ -552,39 +525,33 @@ def main(jin): # stim amplitude amp = sweep["stimulus_amplitude"] if amp is None: - amp = float('nan') + amp = float("nan") else: amp = float(amp) - ds = acq["Sweep_%d" % num].create_dataset( - "aibs_stimulus_amplitude_pa", data=amp) + ds = acq["Sweep_%d" % num].create_dataset("aibs_stimulus_amplitude_pa", data=amp) ds.attrs["neurodata_type"] = "Custom" - ds = stim["Sweep_%d" % num].create_dataset( - "aibs_stimulus_amplitude_pa", data=amp) + ds = stim["Sweep_%d" % num].create_dataset("aibs_stimulus_amplitude_pa", data=amp) ds.attrs["neurodata_type"] = "Custom" # stim interval interval = sweep["stimulus_interval"] if interval is None: - interval = float('nan') + interval = float("nan") else: interval = float(interval) - ds = acq["Sweep_%d" % num].create_dataset("aibs_stimulus_interval", - data=interval) + ds = acq["Sweep_%d" % num].create_dataset("aibs_stimulus_interval", data=interval) ds.attrs["neurodata_type"] = "Custom" - ds = stim["Sweep_%d" % num].create_dataset( - "aibs_stimulus_interval", data=interval) + ds = stim["Sweep_%d" % num].create_dataset("aibs_stimulus_interval", data=interval) ds.attrs["neurodata_type"] = "Custom" # stim name name = sweep["ephys_stimulus"]["ephys_stimulus_type"]["name"] - ds = acq["Sweep_%d" % num].create_dataset("aibs_stimulus_name", - data=name) + ds = acq["Sweep_%d" % num].create_dataset("aibs_stimulus_name", data=name) ds.attrs["neurodata_type"] = "Custom" - ds = stim["Sweep_%d" % num].create_dataset("aibs_stimulus_name", - data=name) + ds = stim["Sweep_%d" % num].create_dataset("aibs_stimulus_name", data=name) ds.attrs["neurodata_type"] = "Custom" # seal seal = jin[0]["seal_gohm"] if seal is None: - seal = float('nan') + seal = float("nan") else: seal = float(seal) ds = acq["Sweep_%d" % num].create_dataset("seal", data=seal) @@ -594,14 +561,12 @@ def main(jin): # initial access resistance res = jin[0]["initial_access_resistance_mohm"] if res is None: - res = float('nan') + res = float("nan") else: res = float(res) - ds = acq["Sweep_%d" % num].create_dataset( - "initial_access_resistance", data=res) + ds = acq["Sweep_%d" % num].create_dataset("initial_access_resistance", data=res) ds.attrs["neurodata_type"] = "Custom" - ds = stim["Sweep_%d" % num].create_dataset( - "initial_access_resistance", data=res) + ds = stim["Sweep_%d" % num].create_dataset("initial_access_resistance", data=res) ds.attrs["neurodata_type"] = "Custom" # # # recycle code from old publish module for custom sweep @@ -636,25 +601,21 @@ def main(jin): # TODO describe what's happening here sweeps_by_type = defaultdict(list) - for sweep_number, sweep_data in metadata['sweeps'].items(): - if sweep_data["stimulus_units"] in ["pA", - "Amps"]: # only compute spikes + for sweep_number, sweep_data in metadata["sweeps"].items(): + if sweep_data["stimulus_units"] in ["pA", "Amps"]: # only compute spikes # for current clamp sweeps - sweeps_by_type[sweep_data['stimulus_type_name']].append( - sweep_number) + sweeps_by_type[sweep_data["stimulus_type_name"]].append(sweep_number) - sweep_features = extract_cell_features.extract_sweep_features( - NwbDataSet(outfile), sweeps_by_type) + sweep_features = extract_cell_features.extract_sweep_features(NwbDataSet(outfile), sweeps_by_type) # TODO describe what's happening here for sweep_num in passing_sweeps: try: - spikes = sweep_features[sweep_num]['spikes'] - spike_times = [s['threshold_t'] for s in spikes] + spikes = sweep_features[sweep_num]["spikes"] + spike_times = [s["threshold_t"] for s in spikes] NwbDataSet(outfile).set_spike_times(sweep_num, spike_times) except Exception as e: - logging.info( - "sweep %d has no sweep features. %s" % (sweep_num, e.message)) + logging.info("sweep %d has no sweep features. %s" % (sweep_num, e.message)) empty = {} return empty diff --git a/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/qc.py b/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/qc.py index 800049e5d0..83395fd158 100755 --- a/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/qc.py +++ b/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/qc.py @@ -5,13 +5,13 @@ from allensdk.internal.core.lims_pipeline_module import PipelineModule from allensdk.core.nwb_data_set import NwbDataSet - + def main(jin): # load QC criteria and sweep table from input json file try: - qc_criteria = jin['ephys_qc_criteria'] - experiment_data = jin['experiment_data'] - sweep_data = jin['sweep_data'] + qc_criteria = jin["ephys_qc_criteria"] + experiment_data = jin["experiment_data"] + sweep_data = jin["sweep_data"] nwb_file = jin["nwb_file"] except Exception: raise IOError("Input json file is missing requisite data") @@ -53,7 +53,6 @@ def main(jin): exp_fail_tags.append("Error analyzing blowout. " + e.message) experiment_state["failed_blowout"] = True - # "electrode 0" experiment_state["failed_electrode_0"] = False try: @@ -68,7 +67,6 @@ def main(jin): exp_fail_tags.append("Error analyzing blowout. " + e.message) experiment_state["failed_electrode_0"] = True - # measure clamp seal experiment_state["failed_seal"] = False try: @@ -88,18 +86,17 @@ def main(jin): exp_fail_tags.append(msg) experiment_state["failed_seal"] = True - # input and access resistance sr_tags = [] try: - sir_ratio = experiment_data['input_access_resistance_ratio'] - #r = experiment_data['input_resistance_mohm'] + sir_ratio = experiment_data["input_access_resistance_ratio"] + # r = experiment_data['input_resistance_mohm'] except Exception: sr_tags.append("Resistance ratio not available") try: - sr = experiment_data['initial_access_resistance_mohm'] + sr = experiment_data["initial_access_resistance_mohm"] except Exception: sr_tags.append("Initial access resistance not available") @@ -129,10 +126,8 @@ def main(jin): if len(sr_tags) > 0: exp_fail_tags.extend(sr_tags) - experiment_state["fail_tags"] = exp_fail_tags - #################################################################### # check features for each sweep sweep_state = {} @@ -150,7 +145,7 @@ def main(jin): unit = sweep["stimulus_units"] # determine if sweep is current or voltage clamp # name may end in "[#]", so strip out section after open bracket - stim_short = stim.split('[')[0] + stim_short = stim.split("[")[0] if stim_short in jin["voltage_clamp_stimuli"]: if unit != "Volts" and unit != "mV": msg = "%s (%s) in wrong mode -- expected voltage clamp" % (name, stim) @@ -163,33 +158,33 @@ def main(jin): fail_tags.append("%s has unrecognized stimulus (%s)" % (name, stim)) if unit == "Volts" or unit == "mV": - continue # no QC on voltage clamp + continue # no QC on voltage clamp if len(fail_tags) > 0: sweep_state[name] = {} sweep_state[name]["state"] = "Fail" sweep_state[name]["reasons"] = fail_tags continue - + # pull data streams from file (this is for detecting truncated # sweeps) sweep_data = NwbDataSet(nwb_file).get_sweep(sweep_num) - current = sweep_data['stimulus'] - idx_start, idx_stop = sweep_data['index_range'] + current = sweep_data["stimulus"] + idx_start, idx_stop = sweep_data["index_range"] if sweep["pre_noise_rms_mv"] > qc_criteria["pre_noise_rms_mv_max"]: fail_tags.append("pre-noise") # check Vm and noise at end of recording - # only do so if acquisition not truncated - # do not check for ramps, because they do not have + # only do so if acquisition not truncated + # do not check for ramps, because they do not have # enough time to recover - is_ramp = stim.startswith('C1RP') + is_ramp = stim.startswith("C1RP") if is_ramp: logging.info("sweep %d skipping vrest criteria on ramp", sweep_num) else: # measure post-stimulus noise - sweep_not_truncated = ( idx_stop == len(current) - 1 ) + sweep_not_truncated = idx_stop == len(current) - 1 if sweep_not_truncated: post_noise_rms_mv = sweep["post_noise_rms_mv"] if post_noise_rms_mv > qc_criteria["post_noise_rms_mv_max"]: @@ -203,7 +198,6 @@ def main(jin): if sweep["vm_delta_mv"] > qc_criteria["vm_delta_mv_max"]: fail_tags.append("Vm delta") - # fail sweeps if stimulus duration is zero # Uncomment out hte following 3 lines to have sweeps without stimulus # faile QC @@ -212,7 +206,6 @@ def main(jin): if not desc.startswith("EXTP"): fail_tags.append("No stimulus detected") - sweep_state[name] = {} if len(fail_tags) > 0: sweep_state[name]["state"] = "Fail" @@ -233,11 +226,11 @@ def main(jin): return jout -if __name__ == "__main__": - # read module input. PipelineModule object automatically parses the + +if __name__ == "__main__": + # read module input. PipelineModule object automatically parses the # command line to pull out input.json and output.json file names module = PipelineModule() - jin = module.input_data() # loads input.json + jin = module.input_data() # loads input.json jout = main(jin) module.write_output_data(jout) # writes output.json - diff --git a/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/qc_support.py b/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/qc_support.py index 2c6fabfcb3..28c19da601 100644 --- a/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/qc_support.py +++ b/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/qc_support.py @@ -1,5 +1,6 @@ import numpy as np + def measure_vm(seg): vals = np.copy(seg) if len(vals) < 1: @@ -9,26 +10,32 @@ def measure_vm(seg): rms = np.sqrt(np.mean(np.square(vals))) return mean, rms + ######################################################################## # experiment-level metrics + def measure_blowout(v, idx0): return 1e3 * np.mean(v[idx0:]) + def measure_electrode_0(curr, hz, t=0.005): n_time_steps = int(t * hz) # electrode 0 is the average current reading with zero voltage input # (ie, the equivalent of resting potential in current-clamp mode) return 1e12 * np.mean(curr[0:n_time_steps]) + def measure_seal(v, curr, hz): t = np.arange(len(v)) / hz return 1e-9 * get_r_from_stable_pulse_response(v, curr, t) + def measure_input_resistance(v, curr, hz): t = np.arange(len(v)) / hz return 1e-6 * get_r_from_stable_pulse_response(v, curr, t) + def measure_initial_access_resistance(v, curr, hz): t = np.arange(len(v)) / hz return 1e-6 * get_r_from_peak_pulse_response(v, curr, t) @@ -36,13 +43,14 @@ def measure_initial_access_resistance(v, curr, hz): ######################################################################## + def get_r_from_stable_pulse_response(v, i, t): dv = np.diff(v) up_idx = np.flatnonzero(dv > 0) down_idx = np.flatnonzero(dv < 0) -# print up_idx -# print down_idx -# print "-----" + # print up_idx + # print down_idx + # print "-----" dt = t[1] - t[0] one_ms = int(0.001 / dt) r = [] @@ -50,26 +58,27 @@ def get_r_from_stable_pulse_response(v, i, t): # take average v and i one ms before end = up_idx[ii] - 1 start = end - one_ms -# print "\tbase" -# print "base interval: %d -> %d" % (start, end) + # print "\tbase" + # print "base interval: %d -> %d" % (start, end) avg_v_base = np.mean(v[start:end]) avg_i_base = np.mean(i[start:end]) -# print "\tv: %g" % avg_v_base -# print "\ti: %g" % avg_i_base + # print "\tv: %g" % avg_v_base + # print "\ti: %g" % avg_i_base # take average v and i one ms before end - end = down_idx[ii]-1 + end = down_idx[ii] - 1 start = end - one_ms -# print "\tsteady" -# print "steady interval: %d -> %d" % (start, end) + # print "\tsteady" + # print "steady interval: %d -> %d" % (start, end) avg_v_steady = np.mean(v[start:end]) avg_i_steady = np.mean(i[start:end]) -# print "\tv: %g" % avg_v_steady -# print "\ti: %g" % avg_i_steady - r_instance = (avg_v_steady-avg_v_base) / (avg_i_steady-avg_i_base) -# print 1e-6*r_instance + # print "\tv: %g" % avg_v_steady + # print "\ti: %g" % avg_i_steady + r_instance = (avg_v_steady - avg_v_base) / (avg_i_steady - avg_i_base) + # print 1e-6*r_instance r.append(r_instance) return np.mean(r) + def get_r_from_peak_pulse_response(v, i, t): dv = np.diff(v) up_idx = np.flatnonzero(dv > 0) @@ -89,34 +98,35 @@ def get_r_from_peak_pulse_response(v, i, t): idx = start + np.argmax(i[start:end]) avg_v_peak = v[idx] avg_i_peak = i[idx] - r_instance = (avg_v_peak-avg_v_base) / (avg_i_peak-avg_i_base) + r_instance = (avg_v_peak - avg_v_base) / (avg_i_peak - avg_i_base) r.append(r_instance) return np.mean(r) - - - def get_last_vm_epoch(idx1, stim, hz): - return idx1-int(0.500 * hz), idx1 + return idx1 - int(0.500 * hz), idx1 + def get_first_vm_noise_epoch(idx0, stim, hz): t0 = idx0 t1 = t0 + int(0.0015 * hz) return t0, t1 + def get_last_vm_noise_epoch(idx1, stim, hz): - return idx1-int(0.0015 * hz), idx1 + return idx1 - int(0.0015 * hz), idx1 + -#def get_stability_vm_epoch(idx0, stim, hz): +# def get_stability_vm_epoch(idx0, stim, hz): def get_stability_vm_epoch(idx0, stim_start, hz): dur = int(0.500 * hz) - #stim_start = find_stim_start(idx0, stim) - if dur > stim_start-1: - dur = stim_start-1 + # stim_start = find_stim_start(idx0, stim) + if dur > stim_start - 1: + dur = stim_start - 1 elif dur <= 0: return 0, 0 - return stim_start-1-dur, stim_start-1 + return stim_start - 1 - dur, stim_start - 1 + def find_stim_start(idx0, stim): # find stim start, using adaptation of nathan's numpy algorithm @@ -135,8 +145,8 @@ def find_stim_start(idx0, stim): # +1 to be first index of stim, not last index of pre-stim return first + 1 -def find_stim_amplitude_and_duration(idx0, stim, hz): +def find_stim_amplitude_and_duration(idx0, stim, hz): if len(stim) < idx0: idx0 = 0 @@ -147,14 +157,14 @@ def find_stim_amplitude_and_duration(idx0, stim, hz): # measure stimulus length # find index of first non-zero value, and last return to zero - nzero = np.where(stim!=0)[0] + nzero = np.where(stim != 0)[0] if len(nzero) > 0: start = nzero[0] end = nzero[-1] dur = (end - start) / hz else: dur = 0 - + dur = float(dur) if abs(peak_high) > abs(peak_low): @@ -164,13 +174,14 @@ def find_stim_amplitude_and_duration(idx0, stim, hz): return amp, dur + def find_stim_interval(idx0, stim, hz): stim = stim[idx0:] # indices where is the stimulus off zero_idxs = np.where(stim == 0)[0] - # derivative of off indices. when greater than one, indicates on period + # derivative of off indices. when greater than one, indicates on period dzero_idxs = np.diff(zero_idxs) dzero_break_idxs = np.where(dzero_idxs[:] > 1)[0] diff --git a/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/resource_file.py b/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/resource_file.py index 4ba410cd5f..6cef369550 100644 --- a/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/resource_file.py +++ b/allensdk/internal/pipeline_modules/IVSCC/ephys_nwb/resource_file.py @@ -1,16 +1,17 @@ import sys import string + class ResourceFile(object): def __init__(self): self.key_value = {} self.accessed = {} def load(self, infile): - """ Reads in a json, yaml or toml resource file. If multiple files - are loaded, resources between them are merged. An error - occurs if the same key is loaded multiple times and different - values are defined for it + """Reads in a json, yaml or toml resource file. If multiple files + are loaded, resources between them are merged. An error + occurs if the same key is loaded multiple times and different + values are defined for it """ if infile.endswith("json"): self.load_json(infile) @@ -21,38 +22,37 @@ def load(self, infile): else: print("Unrecognized extension for file '%s'. Please use json, yaml or toml" % infile) sys.exit(1) - + def get(self, key, default=None, replace_table=None): - """ Returns the resource value associated with the provided key. + """Returns the resource value associated with the provided key. - Arguments: - *key* (text) Name of resource + Arguments: + *key* (text) Name of resource - *default* (text) Value to be returned if key isn't found + *default* (text) Value to be returned if key isn't found - *replace_table* (dict) Substrings that are to be replaced - in resource string. E.g., if replace_table = { foo: "bar" } - then all instances of "foo" in the resource string will be - replaced with "bar" + *replace_table* (dict) Substrings that are to be replaced + in resource string. E.g., if replace_table = { foo: "bar" } + then all instances of "foo" in the resource string will be + replaced with "bar" - Returns: - Resource string if found and default value if not, with - applied substitutions from replace_table + Returns: + Resource string if found and default value if not, with + applied substitutions from replace_table """ self.accessed[key] = True val = self.key_value.get(key, default) if replace_table is not None: - for k,v in replace_table.items(): + for k, v in replace_table.items(): val = string.replace(val, k, v) return val def report(self): - """ Reports all resources that were defined but not used - """ + """Reports all resources that were defined but not used""" err = False print("-------------------------------") print("--- Resource report --") - for k,v in self.accessed.items(): + for k, v in self.accessed.items(): if not v: if not err: err = True @@ -64,11 +64,11 @@ def report(self): #################################################################### # internal procedure to load json file def load_json(self, infile): - """ Reads input json, yaml or toml file - """ + """Reads input json, yaml or toml file""" import json + try: - with open(infile, 'r') as f: + with open(infile, "r") as f: resources = json.load(f) f.close() except IOError: @@ -80,8 +80,9 @@ def load_json(self, infile): def load_yaml(self, infile): try: import yaml + try: - with open(infile, 'r') as f: + with open(infile, "r") as f: resources = yaml.load(f) f.close() except IOError: @@ -96,8 +97,9 @@ def load_yaml(self, infile): def load_toml(self, infile): try: import toml + try: - with open(infile, 'r') as f: + with open(infile, "r") as f: resources = toml.load(f) f.close() except IOError: @@ -108,11 +110,10 @@ def load_toml(self, infile): print("*** toml not available -- please pip install toml") sys.exit(1) - # internal procedure to read keys out of dictionary recursively def read_keys(self, resources): err = False - for k,v in resources.items(): + for k, v in resources.items(): if isinstance(v, dict): self.read_keys(v) else: @@ -128,12 +129,13 @@ def read_keys(self, resources): class ResourceFileTest(object): def create_json(self, fname, rsrc): import json + d = {} d["jone"] = "json one" d["jtwo"] = "json one" d["jsub"] = {} d["jsub"]["jthree"] = "json three" - with open(fname, 'w') as f: + with open(fname, "w") as f: json.dump(d, f, indent=2) f.close() rsrc.load(fname) @@ -141,12 +143,13 @@ def create_json(self, fname, rsrc): def create_yaml(self, fname, rsrc): try: import yaml + d = {} d["yone"] = "yaml one" d["ytwo"] = "yaml one" d["ysub"] = {} d["ysub"]["ythree"] = "yaml three" - with open(fname, 'w') as f: + with open(fname, "w") as f: yaml.dump(d, f, indent=2) f.close() rsrc.load(fname) @@ -156,12 +159,13 @@ def create_yaml(self, fname, rsrc): def create_toml(self, fname, rsrc): try: import toml + d = {} d["tone"] = "toml one" d["ttwo"] = "toml one" d["tsub"] = {} d["tsub"]["tthree"] = "toml three" - with open(fname, 'w') as f: + with open(fname, "w") as f: toml.dump(d, f) f.close() rsrc.load(fname) @@ -183,4 +187,3 @@ def run(self): print(rsrc.get("tthree", "** toml error")) # run report rsrc.report() - diff --git a/allensdk/internal/pipeline_modules/cell_types/morphology/calculate_features.py b/allensdk/internal/pipeline_modules/cell_types/morphology/calculate_features.py index 0fd637d076..78faaff5a1 100644 --- a/allensdk/internal/pipeline_modules/cell_types/morphology/calculate_features.py +++ b/allensdk/internal/pipeline_modules/cell_types/morphology/calculate_features.py @@ -1,4 +1,3 @@ - import neuron_morphology.swc as swc import neuron_morphology.features.feature_extractor as feature_extractor from allensdk.internal.core.lims_pipeline_module import PipelineModule @@ -6,6 +5,7 @@ ######################################################################## + def main(jin): try: swc_file = jin["swc_file"] @@ -15,7 +15,6 @@ def main(jin): print("** Unable to find requisite fields in input json") raise - #################################################################### # calculate features @@ -34,11 +33,11 @@ def main(jin): print("** Error applying affine transform") raise - #try: + # try: # # save a copy of affine-corrected file # tmp_swc_file = swc_file[:-4] + "_pia.swc" # nrn.write(tmp_swc_file) - #except: + # except: # # treat this as a soft error and print a warning # print("Note: unable to write copy of affine corrected pia file") @@ -81,14 +80,13 @@ def main(jin): feat["soma_surface"] = data["dendrite"]["soma_surface"] feat["overall_height"] = data["dendrite"]["height"] - md["features"] = feat data["morphology_data"] = md return data -if __name__=='__main__': +if __name__ == "__main__": module = PipelineModule() jin = module.input_data() jout = main(jin) diff --git a/allensdk/internal/pipeline_modules/cell_types/morphology/cortical_layers.py b/allensdk/internal/pipeline_modules/cell_types/morphology/cortical_layers.py index 8b65dd6192..d6dbfe4cdf 100755 --- a/allensdk/internal/pipeline_modules/cell_types/morphology/cortical_layers.py +++ b/allensdk/internal/pipeline_modules/cell_types/morphology/cortical_layers.py @@ -6,7 +6,7 @@ from neuron_morphology import swc from allensdk.internal.core.lims_pipeline_module import PipelineModule -#from surrogate_strategy import prep_json +# from surrogate_strategy import prep_json # TODO update run_python.sh to include this path twok_dir = "/shared/bioapps/itk/itk_shared/jp2/build" @@ -18,20 +18,21 @@ ######################################################################## # helper functions + def calculate_centroid(x, y): - ''' Calculates the center of a polygon, using weighted averages + """Calculates the center of a polygon, using weighted averages of vertex locations - ''' + """ assert len(x) == len(y), "Vertex arrays are of incorrect shape" tot_len = 0.0 tot_x = 0.0 tot_y = 0.0 for i in range(len(x)): - x0 = x[i-1] - y0 = y[i-1] + x0 = x[i - 1] + y0 = y[i - 1] x1 = x[i] y1 = y[i] - seg_len = math.sqrt((x1-x0)*(x1-x0) + (y1-y0)*(y1-y0)) + seg_len = math.sqrt((x1 - x0) * (x1 - x0) + (y1 - y0) * (y1 - y0)) tot_len += seg_len tot_x += seg_len * x0 tot_x += seg_len * x1 @@ -41,49 +42,52 @@ def calculate_centroid(x, y): tot_y /= 2.0 * tot_len return tot_x, tot_y + def convert_coords_str(coord_str): - vals = coord_str.split(',') + vals = coord_str.split(",") x = np.array(vals[0::2], dtype=float) y = np.array(vals[1::2], dtype=float) return x, y color_table = [] -color_table.append((255, 77, 77)) +color_table.append((255, 77, 77)) color_table.append((102, 102, 255)) -color_table.append(( 25, 255, 25)) +color_table.append((25, 255, 25)) color_table.append((177, 166, 255)) -color_table.append(( 46, 230, 230)) -color_table.append((255, 77, 255)) -color_table.append((128, 230, 46)) -color_table.append((255, 166, 77)) +color_table.append((46, 230, 230)) +color_table.append((255, 77, 255)) +color_table.append((128, 230, 46)) +color_table.append((255, 166, 77)) color_table.append((179, 179, 179)) -color_table.append(( 77, 255, 166)) -color_table.append((229, 229, 46)) -color_table.append((255, 51, 153)) -color_table.append((166, 77, 255)) -color_table.append((151, 166, 86)) - -color_table.append((153, 0, 0)) -color_table.append(( 0, 77, 153)) -color_table.append((153, 153, 0)) -color_table.append(( 77, 153, 0)) -color_table.append(( 0, 153, 153)) -color_table.append(( 13, 128, 13)) -color_table.append((153, 77, 0)) -color_table.append(( 0, 0, 153)) -color_table.append((153, 0, 153)) -color_table.append(( 0, 179, 89)) +color_table.append((77, 255, 166)) +color_table.append((229, 229, 46)) +color_table.append((255, 51, 153)) +color_table.append((166, 77, 255)) +color_table.append((151, 166, 86)) + +color_table.append((153, 0, 0)) +color_table.append((0, 77, 153)) +color_table.append((153, 153, 0)) +color_table.append((77, 153, 0)) +color_table.append((0, 153, 153)) +color_table.append((13, 128, 13)) +color_table.append((153, 77, 0)) +color_table.append((0, 0, 153)) +color_table.append((153, 0, 153)) +color_table.append((0, 179, 89)) color_table.append((102, 102, 102)) -color_table.append(( 77, 0, 153)) -color_table.append((153, 0, 77)) -color_table.append(( 78, 89, 30)) +color_table.append((77, 0, 153)) +color_table.append((153, 0, 77)) +color_table.append((78, 89, 30)) + def color_by_index(i): global color_table # return color_table[i % len(color_table)] + def draw_morphology(nrn, img, somax, somay, color_by_layer=False): global LINE_WIDTH, resolution # @@ -106,28 +110,29 @@ def draw_morphology(nrn, img, somax, somay, color_by_layer=False): color = dend_col elif c.node2.t == 4: color = apical_col - cv2.line(img, (x0,y0), (x1,y1), color, LINE_WIDTH) + cv2.line(img, (x0, y0), (x1, y1), color, LINE_WIDTH) + def write_svg(svgname, jin, nrn): resolution = jin["resolution"] dx = jin["soma"]["position"][0] - nrn.soma_root().x / resolution dy = jin["soma"]["position"][1] - nrn.soma_root().y / resolution with open(svgname, "w") as f: - #f.write('\n') + # f.write('\n') f.write('\n') # soma soma = jin["soma"]["path"][0] - coords = soma.split(',') + coords = soma.split(",") f.write(' \n') for layer in jin["layers"]: f.write(' \n') for c in nrn.compartment_list: try: @@ -138,17 +143,23 @@ def write_svg(svgname, jin, nrn): x1 = int(c.node2.x / resolution + dx) y0 = int(c.node1.y / resolution + dy) y1 = int(c.node2.y / resolution + dy) - f.write(' \n' % (x0, y0, x1, y1, color[0], color[1], color[2])) - f.write('\n') + f.write( + ' \n' + % (x0, y0, x1, y1, color[0], color[1], color[2]) + ) + f.write("\n") + ######################################################################## ######################################################################## # # global values, shared between functions -resolution = None # microns-per-pixel, from input.json +resolution = None # microns-per-pixel, from input.json LINE_WIDTH = 1 # default pen width -DOWNSAMPLE_STEPS = 1 # default image pyramid level -# +DOWNSAMPLE_STEPS = 1 # default image pyramid level + + +# def main(jin): global resolution, LINE_WIDTH, DOWNSAMPLE_STEPS jout = {} @@ -182,12 +193,12 @@ def main(jin): # calculate soma position and store in jin structure soma_res = jin["soma"]["path"] - soma_path = soma_res[0].split(',') + soma_path = soma_res[0].split(",") soma_x = np.array(soma_path[0::2], dtype=float) soma_y = np.array(soma_path[1::2], dtype=float) soma_path = [] for i in range(len(soma_x)): - soma_path.append([soma_x[i],soma_y[i]]) + soma_path.append([soma_x[i], soma_y[i]]) soma_path = np.array(soma_path, np.int32) soma_pos = calculate_centroid(soma_x, soma_y) jin["soma"]["position"] = soma_pos @@ -212,7 +223,7 @@ def main(jin): # get soma position in 20x pixel space, at present downsample level # resolution converts soma coords in microns to pixels # (resolution is microns / pixel) - resolution *= DOWNSAMPLE # divide pixels by X means mult res by same + resolution *= DOWNSAMPLE # divide pixels by X means mult res by same print("Image resolution (microns/pixel): %f" % resolution) dx = jin["soma"]["position"][0] / DOWNSAMPLE - morph.soma_root().x / resolution dy = jin["soma"]["position"][1] / DOWNSAMPLE - morph.soma_root().y / resolution @@ -220,9 +231,9 @@ def main(jin): dy = int(dy) ############################## - # no point in processing the entire image, as only a small part + # no point in processing the entire image, as only a small part # is relevant - # select min/max values for x,y of all polygons. restrict analysis + # select min/max values for x,y of all polygons. restrict analysis # to there min_x = 1e10 min_y = 1e10 @@ -230,7 +241,7 @@ def main(jin): max_y = 0 layers = jin["layers"] for layer in layers: - path_array = np.array(layer["path"].split(',')) + path_array = np.array(layer["path"].split(",")) x = np.array(path_array[0::2], dtype=float) y = np.array(path_array[1::2], dtype=float) min_x = min(min_x, x.min()) @@ -254,44 +265,43 @@ def main(jin): # make sure border doesn't extend beyond image limits TOP = max(TOP, 0) LEFT = max(LEFT, 0) - RIGHT = min(RIGHT, abs_width-1) - BOTTOM = min(BOTTOM, abs_height-1) + RIGHT = min(RIGHT, abs_width - 1) + BOTTOM = min(BOTTOM, abs_height - 1) WIDTH = int(RIGHT - LEFT) HEIGHT = int(BOTTOM - TOP) # adjust soma location for top and left of visible image area dx -= LEFT dy -= TOP - #print "inset soma position", dx, dy - + # print "inset soma position", dx, dy # make frame for each polygon and blur. blur radious should be # approx the size of largest gap or overlap between polygons. this # is for estimating which polygon each point is a best fit in layers = jin["layers"] for layer in layers: - path_array = np.array(layer["path"].split(',')) + path_array = np.array(layer["path"].split(",")) x = np.array(path_array[0::2], dtype=float) x /= DOWNSAMPLE x -= LEFT y = np.array(path_array[1::2], dtype=float) y /= DOWNSAMPLE y -= TOP - #print layer["label"] - #print x.min(), y.min() - #print x.max(), y.max() + # print layer["label"] + # print x.min(), y.min() + # print x.max(), y.max() xy = [] for i in range(len(x)): - xy.append([x[i],y[i]]) + xy.append([x[i], y[i]]) raw_frame = np.zeros((HEIGHT, WIDTH)) path = np.array(xy) cv2.fillPoly(raw_frame, np.int32([path]), 255) frame = cv2.blur(raw_frame, (GAUS_RAD, GAUS_RAD)) - layer["frame"] = frame # blurred polygon + layer["frame"] = frame # blurred polygon layer["raw_frame"] = raw_frame # raw polygon - # collapse all polys into single array, with value at each position - # corresponding to the index of the polygon that the pixel falls + # collapse all polys into single array, with value at each position + # corresponding to the index of the polygon that the pixel falls # into, or -1 if there's no match master = np.zeros((HEIGHT, WIDTH, 3)) master_idx = np.zeros((HEIGHT, WIDTH), dtype=int) @@ -320,7 +330,7 @@ def main(jin): ################################################# # draw standard morphology on 20x image - img = image_20x[TOP:BOTTOM,LEFT:RIGHT] + img = image_20x[TOP:BOTTOM, LEFT:RIGHT] print(img.shape) draw_morphology(morph, img, int(dx), int(dy)) outfile = "blockface_%d.png" % spec_id @@ -359,10 +369,10 @@ def main(jin): print("saving " + outfile) jout["outline_svg"] = outfile write_svg(outfile, jin, morph) - + ################################################# # draw layer-colored morphology on 20x image - img = image_20x[TOP:BOTTOM,LEFT:RIGHT] + img = image_20x[TOP:BOTTOM, LEFT:RIGHT] draw_morphology(morph, img, int(dx), int(dy), True) outfile = "layered_blockface_%d.png" % spec_id print("saving " + outfile) @@ -374,16 +384,16 @@ def main(jin): img = np.zeros((HEIGHT, WIDTH, 3)) layers = jin["layers"] for layer in layers: - path_array = np.array(layer["path"].split(',')) + path_array = np.array(layer["path"].split(",")) x = np.array(path_array[0::2], dtype=float) x /= DOWNSAMPLE x -= LEFT y = np.array(path_array[1::2], dtype=float) y /= DOWNSAMPLE y -= TOP - for i in range(1,len(x)): - cv2.line(img, (int(x[i-1]),int(y[i-1])), (int(x[i]),int(y[i])), (255, 255, 255), 1) - cv2.line(img, (int(x[-1]),int(y[-1])), (int(x[0]),int(y[0])), (255, 255, 255), 1) + for i in range(1, len(x)): + cv2.line(img, (int(x[i - 1]), int(y[i - 1])), (int(x[i]), int(y[i])), (255, 255, 255), 1) + cv2.line(img, (int(x[-1]), int(y[-1])), (int(x[0]), int(y[0])), (255, 255, 255), 1) draw_morphology(morph, img, int(dx), int(dy), True) outfile = "outline_%d.png" % spec_id print("saving " + outfile) @@ -394,28 +404,27 @@ def main(jin): # mouse -#ims_id = "489909914" -#spec_id = 488679042 +# ims_id = "489909914" +# spec_id = 488679042 -#ims_id = "491762612" -#spec_id = 490387590 +# ims_id = "491762612" +# spec_id = 490387590 # human -#ims_id = "487992082" -#spec_id = 488386504 +# ims_id = "487992082" +# spec_id = 488386504 -#ims_id = "488759189" -#spec_id = 488418027 +# ims_id = "488759189" +# spec_id = 488418027 -#spec_id = 528015670 +# spec_id = 528015670 if __name__ == "__main__": module = PipelineModule() - jin = module.input_data() # loads input.json + jin = module.input_data() # loads input.json # "get" input json - #jin = prep_json(spec_id) - #json.write("in_%d.json" % spec_id, jin) + # jin = prep_json(spec_id) + # json.write("in_%d.json" % spec_id, jin) jout = main(jin) module.write_output_data(jout) # writes output.json - #json.write("out_%d.json" % spec_id, jout) - + # json.write("out_%d.json" % spec_id, jout) diff --git a/allensdk/internal/pipeline_modules/cell_types/morphology/surrogate_strategy.py b/allensdk/internal/pipeline_modules/cell_types/morphology/surrogate_strategy.py index ea67b95430..2dc67a3c27 100755 --- a/allensdk/internal/pipeline_modules/cell_types/morphology/surrogate_strategy.py +++ b/allensdk/internal/pipeline_modules/cell_types/morphology/surrogate_strategy.py @@ -2,6 +2,7 @@ import sys import psycopg2 import psycopg2.extras + sys.path.append("/home/keithg/allen/allensd") import allensdk.core.json_utilities as json @@ -38,7 +39,7 @@ def prep_json(spec_id): order by 1 """ cursor.execute(layer_sql % spec_id) - #print layer_sql % spec_id + # print layer_sql % spec_id poly_info = cursor.fetchall() if len(poly_info) == 0: print("Error -- cannot no polygon data for Specimen %d" % spec_id) @@ -52,13 +53,13 @@ def prep_json(spec_id): block["label"] = label poly.append(block) ## break down string path into two numeric arrays - #path_array = np.array(path.split(',')) - #path_x = np.array(path_array[0::2], dtype=float) - #path_y = np.array(path_array[1::2], dtype=float) - #block["path_array"] = path_array - #block["path_x"] = path_x - #block["path_y"] = path_y - #poly[label] = block + # path_array = np.array(path.split(',')) + # path_x = np.array(path_array[0::2], dtype=float) + # path_y = np.array(path_array[1::2], dtype=float) + # block["path_array"] = path_array + # block["path_x"] = path_x + # block["path_y"] = path_y + # poly[label] = block jin["layers"] = poly jin["storage_directory"] = poly_info[0][2] jin["swc_file"] = poly_info[0][3] @@ -70,7 +71,6 @@ def prep_json(spec_id): jin["reconstruction_id"] = reconstruction_id jin["resolution"] = 0.363 - # it appears that we need to restrict image query to use this ims_id ims_id = poly_info[0][4] @@ -96,7 +96,7 @@ def prep_json(spec_id): and agl.name = 'Soma' and cell.id = %d """ - #print soma_sql % spec_id + # print soma_sql % spec_id cursor.execute(soma_sql % spec_id) soma_res = cursor.fetchall() som = {} @@ -126,25 +126,25 @@ def prep_json(spec_id): """ try: cursor.execute(img_sql % ims_id) - #cursor.execute(img_sql % spec_id) + # cursor.execute(img_sql % spec_id) img_res = cursor.fetchall() img_path = img_res[0][0] + img_res[0][1] - #img_path = "%s-20x.jpeg" % str(spec_id) + # img_path = "%s-20x.jpeg" % str(spec_id) except Exception: print("Error fetching path to 20x image from database") print(img_sql % spec_id) raise img = {} - #img["img_res"] = img_res + # img["img_res"] = img_res img["img_path"] = img_path jin["20x"] = img return jin + if __name__ == "__main__": spec_id = 490387590 jin = prep_json(spec_id) print("Test mode: creating input.json for specimen.id=%d" % spec_id) json.write("input.json", jin) - diff --git a/allensdk/internal/pipeline_modules/cell_types/morphology/upright_transform.py b/allensdk/internal/pipeline_modules/cell_types/morphology/upright_transform.py index ae73f54c5a..fceb73ed0b 100644 --- a/allensdk/internal/pipeline_modules/cell_types/morphology/upright_transform.py +++ b/allensdk/internal/pipeline_modules/cell_types/morphology/upright_transform.py @@ -5,18 +5,17 @@ import skimage.draw - def calculate_centroid(x, y): - ''' Calculates the center of a polygon, using weighted averages + """Calculates the center of a polygon, using weighted averages of vertex locations - ''' + """ assert len(x) == len(y), "Vertex arrays are of incorrect shape" tot_len = 0.0 tot_x = 0.0 tot_y = 0.0 for i in range(len(x)): - x0 = x[i-1] - y0 = y[i-1] + x0 = x[i - 1] + y0 = y[i - 1] x1 = x[i] y1 = y[i] seg_len = euclidean((x0, y0), (x1, y1)) @@ -31,15 +30,11 @@ def calculate_centroid(x, y): def construct_affine(theta): - tr_rot = [np.cos(theta), np.sin(theta), 0, - -np.sin(theta), np.cos(theta), 0, - 0, 0, 1, - 0, 0, 0 - ] + tr_rot = [np.cos(theta), np.sin(theta), 0, -np.sin(theta), np.cos(theta), 0, 0, 0, 1, 0, 0, 0] return tr_rot -#def get_pia_wm_rotation_transform(soma_coords, wm_coords, pia_coords, resolution): +# def get_pia_wm_rotation_transform(soma_coords, wm_coords, pia_coords, resolution): # # get soma position using weighted average of vertices # sx, sy = convert_coords_str(soma_coords) # soma_x, soma_y = calculate_centroid(sx, sy) @@ -56,13 +51,14 @@ def construct_affine(theta): def convert_coords_str(coord_str): - vals = coord_str.split(',') + vals = coord_str.split(",") x = np.array(vals[0::2], dtype=float) y = np.array(vals[1::2], dtype=float) return x, y + def calculate_shortest(soma_x, soma_y, pia, wm): - """ Calculates shortest distance through a point on the polygon wm + """Calculates shortest distance through a point on the polygon wm through the soma coordinates (soma_x, soma_y) and through a point on the polygon pia. @@ -78,29 +74,29 @@ def calculate_shortest(soma_x, soma_y, pia, wm): height = max(pia_ys) height = int(max(height, max(wm_ys))) - canvas = np.zeros((height+10, width+10, 3), dtype=np.uint8) - - for i in range(1,len(pia_xs)): - lr, lc = skimage.draw.line(int(pia_ys[i-1]), int(pia_xs[i-1]), int(pia_ys[i]), int(pia_xs[i])) - canvas[lr,lc,2] = 255 + canvas = np.zeros((height + 10, width + 10, 3), dtype=np.uint8) + + for i in range(1, len(pia_xs)): + lr, lc = skimage.draw.line(int(pia_ys[i - 1]), int(pia_xs[i - 1]), int(pia_ys[i]), int(pia_xs[i])) + canvas[lr, lc, 2] = 255 - for i in range(1,len(wm_xs)): - lr, lc = skimage.draw.line(int(wm_ys[i-1]), int(wm_xs[i-1]), int(wm_ys[i]), int(wm_xs[i])) - canvas[lr,lc,0] = 255 + for i in range(1, len(wm_xs)): + lr, lc = skimage.draw.line(int(wm_ys[i - 1]), int(wm_xs[i - 1]), int(wm_ys[i]), int(wm_xs[i])) + canvas[lr, lc, 0] = 255 # get points in white matter trace - wp_y, wp_x = np.nonzero(canvas[:,:,0]) + wp_y, wp_x = np.nonzero(canvas[:, :, 0]) ################## # draw an extended line from each wm pix through the soma - # (there are usually less WM pix than pia pix, so this should be + # (there are usually less WM pix than pia pix, so this should be # faster than iterating through pia pix) # make array of blue (pia) channel only to - pia = canvas[:,:,2] + pia = canvas[:, :, 2] # draw line from each wm pix through the soma and into infinity # (line terminates if/when it intersects with pia trace) min_dist = None - min_coord = None # stores [ pia_x, pia_y, wm_x, wm_y ] + min_coord = None # stores [ pia_x, pia_y, wm_x, wm_y ] for i in range(len(wp_x)): x0 = wp_x[i] y0 = wp_y[i] @@ -116,7 +112,7 @@ def calculate_shortest(soma_x, soma_y, pia, wm): if dx > dy: err = dx / 2.0 while x >= 0 and x < width and y >= 0 and y < height: - if pia[y,x] > 0: + if pia[y, x] > 0: dist = euclidean((x0, y0), (x, y)) if min_dist is None or min_dist > dist: min_dist = dist @@ -130,7 +126,7 @@ def calculate_shortest(soma_x, soma_y, pia, wm): else: err = dy / 2.0 while x >= 0 and x < width and y >= 0 and y < height: - if pia[y,x] > 0: + if pia[y, x] > 0: dist = euclidean((x0, y0), (x, y)) if min_dist is None or min_dist > dist: min_dist = dist @@ -160,7 +156,7 @@ def dist_proj_point_lineseg(p, q1, q2): # based on c code from http://stackoverflow.com/questions/849211/shortest-distance-between-a-point-and-a-line-segment l2 = euclidean(q1, q2) ** 2 if l2 == 0: - return euclidean(p, q1) # q1 == q2 case + return euclidean(p, q1) # q1 == q2 case t = max(0, min(1, np.dot(p - q1, q2 - q1) / l2)) proj = q1 + t * (q2 - q1) return euclidean(p, proj), proj @@ -169,8 +165,7 @@ def dist_proj_point_lineseg(p, q1, q2): def project_to_polyline(boundary, soma): x, y = convert_coords_str(boundary) points = zip(x, y) - dists_projs = [dist_proj_point_lineseg(soma, np.array(q1), np.array(q2)) - for q1, q2 in zip(points[:-1], points[1:])] + dists_projs = [dist_proj_point_lineseg(soma, np.array(q1), np.array(q2)) for q1, q2 in zip(points[:-1], points[1:])] min_idx = np.argmin(np.array([d[0] for d in dists_projs])) return dists_projs[min_idx][1] @@ -189,47 +184,47 @@ def main(jin): # per IT-14567, blockface analysis is no longer required ######################################################################### ## analyze blockface image - #try: + # try: # soma = jin["blockface"]["Soma"]["path"] # pia = jin["blockface"]["Pia"]["path"] # wm = jin["blockface"]["White Matter"]["path"] # res = float(jin["blockface"]["Pia"]["resolution"]) - #except: + # except: # print("** Error -- missing requisite blockface field(s) in input json") # raise # ## get soma position using weighted average of vertices - #try: + # try: # sx, sy = convert_coords_str(soma) # soma_x, soma_y = calculate_centroid(sx, sy) - #except: + # except: # print("** Error -- unable to calculate soma information (blockface)") # raise # ## calculate shortest path - #try: + # try: # px, py, wx, wy = calculate_shortest(soma_x, soma_y, pia, wm) ## calculate theta and affine # theta = vector_angle((0, 1), np.asarray([px,py]) - np.asarray([wx,wy])) ## calculate soma depth and cortical thickness # depth = res * euclidean((soma_x, soma_y), (px, py)) # blk_thickness = res * euclidean((wx, wy), (px, py)) - #except: + # except: # print("** Error calculating shortest path (blockface)") # raise # - #blockface = {} - #blockface["pia_intersect"] = [ px, py ] - #blockface["wm_intersect"] = [ wx, wy ] - #blockface["soma_center"] = [ soma_x, soma_y ] - #blockface["soma_depth_um"] = depth - #try: + # blockface = {} + # blockface["pia_intersect"] = [ px, py ] + # blockface["wm_intersect"] = [ wx, wy ] + # blockface["soma_center"] = [ soma_x, soma_y ] + # blockface["soma_depth_um"] = depth + # try: # blockface["soma_depth_relative"] = depth / blk_thickness - #except: + # except: # blockface["soma_depth_relative"] = -1.0 # NaN is not friendly to ruby - #blockface["cort_thickness_um"] = blk_thickness - #blockface["theta"] = theta - + # blockface["cort_thickness_um"] = blk_thickness + # blockface["theta"] = theta + ######################################################################## # analyze primary (20x) image try: @@ -240,7 +235,7 @@ def main(jin): except Exception: print("** Error -- missing requisite primary (20x) field(s) in input json") raise - + # get soma position using weighted average of vertices try: sx, sy = convert_coords_str(soma) @@ -252,7 +247,7 @@ def main(jin): # calculate shortest path px, py, wx, wy = calculate_shortest(soma_x, soma_y, pia, wm) # calculate theta and affine - theta = vector_angle((0, 1), np.asarray([px,py]) - np.asarray([wx,wy])) + theta = vector_angle((0, 1), np.asarray([px, py]) - np.asarray([wx, wy])) tr_rot = construct_affine(theta) inv_tr_rot = construct_affine(-theta) # calculate soma depth and cortical thickness @@ -261,25 +256,24 @@ def main(jin): except Exception: print("** Error calculating shortest path (primary)") raise - - + primary = {} - primary["pia_intersect"] = [ px, py ] - primary["wm_intersect"] = [ wx, wy ] - primary["soma_center"] = [ soma_x, soma_y ] + primary["pia_intersect"] = [px, py] + primary["wm_intersect"] = [wx, wy] + primary["soma_center"] = [soma_x, soma_y] primary["soma_depth_um"] = depth try: primary["soma_depth_relative"] = depth / raw_thickness except Exception: - primary["soma_depth_relative"] = -1.0 # NaN is not friendly to ruby + primary["soma_depth_relative"] = -1.0 # NaN is not friendly to ruby primary["cort_thickness_um"] = raw_thickness primary["theta"] = theta - + try: scale = raw_thickness / blk_thickness except Exception: - scale = -1.0 # NaN is not ruby-friendly - + scale = -1.0 # NaN is not ruby-friendly + soma_coords_avail = False if "swc_file" in jin: # if SWC file available, extract soma position from it @@ -314,14 +308,14 @@ def main(jin): if not soma_coords_avail: raise Exception("** Error: Unable to construct translation component of affine") - + try: # apply affine rotation to soma position - translate_x = soma_x*tr_rot[0] + soma_y*tr_rot[1] + soma_z*tr_rot[2] - translate_y = soma_x*tr_rot[3] + soma_y*tr_rot[4] + soma_z*tr_rot[5] - translate_z = soma_x*tr_rot[6] + soma_y*tr_rot[7] + soma_z*tr_rot[8] + translate_x = soma_x * tr_rot[0] + soma_y * tr_rot[1] + soma_z * tr_rot[2] + translate_y = soma_x * tr_rot[3] + soma_y * tr_rot[4] + soma_z * tr_rot[5] + translate_z = soma_x * tr_rot[6] + soma_y * tr_rot[7] + soma_z * tr_rot[8] # apply translation vector to transform - tr_rot[ 9] = -translate_x + tr_rot[9] = -translate_x tr_rot[10] = -translate_y - depth tr_rot[11] = -translate_z except Exception: @@ -330,16 +324,16 @@ def main(jin): soma_x = -translate_x soma_y = -translate_y - depth soma_z = -translate_z - + # apply affine rotation to soma position - translate_x = soma_x*inv_tr_rot[0] + soma_y*inv_tr_rot[1] + soma_z*inv_tr_rot[2] - translate_y = soma_x*inv_tr_rot[3] + soma_y*inv_tr_rot[4] + soma_z*inv_tr_rot[5] - translate_z = soma_x*inv_tr_rot[6] + soma_y*inv_tr_rot[7] + soma_z*inv_tr_rot[8] - - inv_tr_rot[ 9] = -translate_x + translate_x = soma_x * inv_tr_rot[0] + soma_y * inv_tr_rot[1] + soma_z * inv_tr_rot[2] + translate_y = soma_x * inv_tr_rot[3] + soma_y * inv_tr_rot[4] + soma_z * inv_tr_rot[5] + translate_z = soma_x * inv_tr_rot[6] + soma_y * inv_tr_rot[7] + soma_z * inv_tr_rot[8] + + inv_tr_rot[9] = -translate_x inv_tr_rot[10] = -translate_y inv_tr_rot[11] = -translate_z - + try: # upright transform. based on rotation in 20x image upright = {} @@ -348,7 +342,7 @@ def main(jin): upright["trv_%02d" % i] = inv_tr_rot[i] jout = {} jout["primary"] = primary - #jout["blockface"] = blockface # per IT-14567, disable blockface + # jout["blockface"] = blockface # per IT-14567, disable blockface jout["upright"] = upright alignment = {} alignment["scale"] = scale @@ -365,25 +359,25 @@ def main(jin): except Exception: print("** Internal error **") raise - + return jout # # test transform -- bar.swc should match the source file - #print("source swc: " + jin["swc_file"]) - #print tr_rot - #morph2 = swc.read_swc(jin["swc_file"]) - #morph2.apply_affine(tr_rot) - #morph2.save("foo.swc") - #morph3 = swc.read_swc("foo.swc") - #morph3.apply_affine(inv_tr_rot) - #morph3.save("bar.swc") + # print("source swc: " + jin["swc_file"]) + # print tr_rot + # morph2 = swc.read_swc(jin["swc_file"]) + # morph2.apply_affine(tr_rot) + # morph2.save("foo.swc") + # morph3 = swc.read_swc("foo.swc") + # morph3.apply_affine(inv_tr_rot) + # morph3.save("bar.swc") + if __name__ == "__main__": - # read module input. PipelineModule object automatically parses the + # read module input. PipelineModule object automatically parses the # command line to pull out input.json and output.json file names module = PipelineModule() - jin = module.input_data() # loads input.json + jin = module.input_data() # loads input.json jout = main(jin) module.write_output_data(jout) # writes output.json - diff --git a/allensdk/internal/pipeline_modules/gbm/generate_gbm_analysis_run_records.py b/allensdk/internal/pipeline_modules/gbm/generate_gbm_analysis_run_records.py index 30257be6ea..135aa9bdd1 100644 --- a/allensdk/internal/pipeline_modules/gbm/generate_gbm_analysis_run_records.py +++ b/allensdk/internal/pipeline_modules/gbm/generate_gbm_analysis_run_records.py @@ -9,27 +9,27 @@ def main(analysis_records_json_location, db_host, db_port, db_name, db_user, db_passwd): - conn = psycopg2.connect(host=db_host, port=db_port, dbname=db_name, user=db_user, password=db_passwd) cur = conn.cursor() - cur.execute("select distinct rna.id as rna_well_id, gen.storage_directory || gen.filename, trans.storage_directory " - "|| trans.filename from wells rna join rs_tubes t on t.sample_id = rna.id join rna_seq_experiments e " - "on e.rs_tube_id = t.id join rna_seq_analysis_runs_rna_seq_experiments ar2e on ar2e.rna_seq_experiment_id " - "= e.id join well_known_files gen on gen.attachable_id = ar2e.rna_seq_analysis_run_id and " - "gen.well_known_file_type_id = 267380639 join well_known_files trans on trans.attachable_id " - "= ar2e.rna_seq_analysis_run_id and trans.well_known_file_type_id = 267380638 where gen.published_at is " - "not null and gen.storage_directory ilike '%/gbm/%' order by rna.id;") + cur.execute( + "select distinct rna.id as rna_well_id, gen.storage_directory || gen.filename, trans.storage_directory " + "|| trans.filename from wells rna join rs_tubes t on t.sample_id = rna.id join rna_seq_experiments e " + "on e.rs_tube_id = t.id join rna_seq_analysis_runs_rna_seq_experiments ar2e on ar2e.rna_seq_experiment_id " + "= e.id join well_known_files gen on gen.attachable_id = ar2e.rna_seq_analysis_run_id and " + "gen.well_known_file_type_id = 267380639 join well_known_files trans on trans.attachable_id " + "= ar2e.rna_seq_analysis_run_id and trans.well_known_file_type_id = 267380638 where gen.published_at is " + "not null and gen.storage_directory ilike '%/gbm/%' order by rna.id;" + ) data = cur.fetchall() analysis_run_records = {"analysis_run_records": []} for item in data: record = {"rna_well_id": item[0], "analysis_run_gene_path": item[1], "analysis_run_transcript_path": item[2]} analysis_run_records["analysis_run_records"].append(record) - with open(analysis_records_json_location, 'w') as outfile: + with open(analysis_records_json_location, "w") as outfile: json.dump(analysis_run_records, outfile) -if __name__ == '__main__': - +if __name__ == "__main__": analysis_records_json_location = sys.argv[1] db_host = sys.argv[2] db_port = sys.argv[3] diff --git a/allensdk/internal/pipeline_modules/gbm/generate_gbm_heatmap.py b/allensdk/internal/pipeline_modules/gbm/generate_gbm_heatmap.py index 32a2d5e01d..e40931ecc5 100644 --- a/allensdk/internal/pipeline_modules/gbm/generate_gbm_heatmap.py +++ b/allensdk/internal/pipeline_modules/gbm/generate_gbm_heatmap.py @@ -33,46 +33,45 @@ def create_transcripts_for_genes(analysis_run_gene_file): + """Creates a list that contains the associated transcript for each gene sorted by entrez_id""" - """ Creates a list that contains the associated transcript for each gene sorted by entrez_id """ - - transcripts_for_genes = np.genfromtxt(analysis_run_gene_file["analysis_run_gene_path"], usecols=[0, 1], skip_header=1, - dtype='str').tolist() + transcripts_for_genes = np.genfromtxt( + analysis_run_gene_file["analysis_run_gene_path"], usecols=[0, 1], skip_header=1, dtype="str" + ).tolist() data = sorted(transcripts_for_genes, key=lambda row: int(row[0])) - header = ['gene_id', 'transcript_id(s)'] + header = ["gene_id", "transcript_id(s)"] data.insert(0, header) data = pd.DataFrame(data) return data def create_genes_for_transcripts(analysis_run_transcript_file): + """Creates a list that contains the associated gene for each transcript sorted alphabetically""" - """ Creates a list that contains the associated gene for each transcript sorted alphabetically """ - - genes_for_transcripts = np.genfromtxt(analysis_run_transcript_file["analysis_run_transcript_path"], usecols=[0, 1] - , skip_header=1, dtype='str').tolist() + genes_for_transcripts = np.genfromtxt( + analysis_run_transcript_file["analysis_run_transcript_path"], usecols=[0, 1], skip_header=1, dtype="str" + ).tolist() data = sorted(genes_for_transcripts, key=lambda row: row[0].lower()) - header = ['transcript_id', 'gene_id'] + header = ["transcript_id", "gene_id"] data.insert(0, header) data = pd.DataFrame(data) return data def create_gene_fpkm_table(analysis_run_records): - - """ Creates a a matrix ("rows x columns = genes x samples") of fpkm gene expression values for each particular - (gene, sample) pair. Rows are sorted by entrez_id and columns are by rna_well_id """ + """Creates a a matrix ("rows x columns = genes x samples") of fpkm gene expression values for each particular + (gene, sample) pair. Rows are sorted by entrez_id and columns are by rna_well_id""" gene_fpkm = [] rna_well_ids = [] for record in analysis_run_records: - gene_fpkm.append(np.genfromtxt(record["analysis_run_gene_path"], usecols=[-1] - , skip_header=1, dtype='str')) + gene_fpkm.append(np.genfromtxt(record["analysis_run_gene_path"], usecols=[-1], skip_header=1, dtype="str")) rna_well_ids.append(record["rna_well_id"]) - entrez_ids = np.genfromtxt(analysis_run_records[0]["analysis_run_gene_path"], usecols=[0], skip_header=1 - , dtype='str').tolist() + entrez_ids = np.genfromtxt( + analysis_run_records[0]["analysis_run_gene_path"], usecols=[0], skip_header=1, dtype="str" + ).tolist() entrez_ids_int = list(map(int, entrez_ids)) gene_fpkm_table = np.column_stack(gene_fpkm) df = pd.DataFrame(gene_fpkm_table, columns=rna_well_ids, index=entrez_ids_int) @@ -83,20 +82,21 @@ def create_gene_fpkm_table(analysis_run_records): def create_transcript_fpkm_table(analysis_run_records): - - """ Creates a a matrix ("rows x columns = transcripts x samples") of fpkm gene expression values for each particular - (transcript, sample) pair. Rows are sorted by transcript id and columns are by rna_well_id """ + """Creates a a matrix ("rows x columns = transcripts x samples") of fpkm gene expression values for each particular + (transcript, sample) pair. Rows are sorted by transcript id and columns are by rna_well_id""" transcript_fpkm = [] rna_well_ids = [] for record in analysis_run_records: - transcript_fpkm.append(np.genfromtxt(record["analysis_run_transcript_path"], usecols=[-1] - , skip_header=1, dtype='str')) + transcript_fpkm.append( + np.genfromtxt(record["analysis_run_transcript_path"], usecols=[-1], skip_header=1, dtype="str") + ) rna_well_ids.append(record["rna_well_id"]) - transcript_ids = np.genfromtxt(analysis_run_records[0]["analysis_run_transcript_path"], usecols=[0], skip_header=1 - , dtype='str').tolist() + transcript_ids = np.genfromtxt( + analysis_run_records[0]["analysis_run_transcript_path"], usecols=[0], skip_header=1, dtype="str" + ).tolist() transcript_fpkm_table = np.column_stack(transcript_fpkm) df = pd.DataFrame(transcript_fpkm_table, columns=rna_well_ids, index=transcript_ids) @@ -107,19 +107,17 @@ def create_transcript_fpkm_table(analysis_run_records): def create_sample_metadata(sample_metadata_records): + """Creates a table of sample metadata sorted by rna_well_id""" - """ Creates a table of sample metadata sorted by rna_well_id """ - - df = pd.DataFrame.from_dict(sample_metadata_records, orient='columns') - data = df.sort_values(by=['rna_well_id']).reset_index(drop=True) - rna_well_id = data['rna_well_id'] - data.drop(labels=['rna_well_id'], axis=1, inplace=True) - data.insert(0, 'rna_well_id', rna_well_id) + df = pd.DataFrame.from_dict(sample_metadata_records, orient="columns") + data = df.sort_values(by=["rna_well_id"]).reset_index(drop=True) + rna_well_id = data["rna_well_id"] + data.drop(labels=["rna_well_id"], axis=1, inplace=True) + data.insert(0, "rna_well_id", rna_well_id) return data def main(): - input_file = sys.argv[1] data = json.load(open(input_file)) transcripts_for_genes_output = data["transcripts_for_genes_output"] @@ -130,14 +128,16 @@ def main(): analysis_run_records = json.load(open(data["analysis_run_records"])) sample_metadata_records = json.load(open(data["sample_metadata_records"])) - create_transcripts_for_genes(analysis_run_records["analysis_run_records"][0]).to_csv(transcripts_for_genes_output, - index=False, header=False) - create_genes_for_transcripts(analysis_run_records["analysis_run_records"][0]).to_csv(genes_for_transcripts_output, - index=False, header=False) + create_transcripts_for_genes(analysis_run_records["analysis_run_records"][0]).to_csv( + transcripts_for_genes_output, index=False, header=False + ) + create_genes_for_transcripts(analysis_run_records["analysis_run_records"][0]).to_csv( + genes_for_transcripts_output, index=False, header=False + ) create_gene_fpkm_table(analysis_run_records["analysis_run_records"]).to_csv(gene_fpkm_table_output) create_transcript_fpkm_table(analysis_run_records["analysis_run_records"]).to_csv(transcript_fpkm_table_output) create_sample_metadata(sample_metadata_records).to_csv(columns_samples_output, index=False) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/allensdk/internal/pipeline_modules/gbm/generate_gbm_sample_metadata.py b/allensdk/internal/pipeline_modules/gbm/generate_gbm_sample_metadata.py index 0179dda6e3..5c9df47d02 100644 --- a/allensdk/internal/pipeline_modules/gbm/generate_gbm_sample_metadata.py +++ b/allensdk/internal/pipeline_modules/gbm/generate_gbm_sample_metadata.py @@ -10,34 +10,34 @@ def main(sample_metadata_json_location, db_host, db_port, db_name, db_user, db_passwd): - conn = psycopg2.connect(host=db_host, port=db_port, dbname=db_name, user=db_user, password=db_passwd) cur = conn.cursor(cursor_factory=RealDictCursor) - cur.execute("select distinct rna.id as rna_well_id, tumor.id as tumor_id, tumor.external_specimen_name as tumor_name" - ", block.id as block_id, block.external_specimen_name as block_name, sp.id as specimen_id" - ", sp.external_specimen_name as specimen_name, min(poly.id) as polygon_id, st.id as structure_id" - ", st.acronym as structure_abbreviation, to_hex(st.red) || to_hex(st.green) || to_hex(st.blue) as " - "structure_color, st.name as structure_name from wells rna join image_series mims on mims.id = " - "rna.image_series_id join specimens sp on sp.id = mims.specimen_id join specimens block on block.id = " - "sp.parent_id join specimens tumor on tumor.id = block.parent_id join avg_microarray_templates mt on " - "mt.image_series_id = mims.id join avg_graphic_objects poly on poly.id = mt.shape_id join structures st " - "on st.id = poly.structure_id join rs_tubes tube on tube.sample_id = rna.id join rna_seq_experiments exp " - "on exp.rs_tube_id = tube.id join rna_seq_analysis_runs_rna_seq_experiments ar2exp on " - "ar2exp.rna_seq_experiment_id = exp.id join analysis_runs ar on ar.id = ar2exp.rna_seq_analysis_run_id " - "join well_known_files fpkm on fpkm.attachable_id = ar.id where rna.sample_id_string like any (array " - "['366-___', '466-___']) and fpkm.published_at is not null group by tumor.id, " - "tumor.external_specimen_name, block.id, block.external_specimen_name, sp.id, sp.external_specimen_name, " - "rna.id, st.id, st.acronym, st.name, structure_color order by rna.id;") - with open(sample_metadata_json_location, 'w') as outfile: + cur.execute( + "select distinct rna.id as rna_well_id, tumor.id as tumor_id, tumor.external_specimen_name as tumor_name" + ", block.id as block_id, block.external_specimen_name as block_name, sp.id as specimen_id" + ", sp.external_specimen_name as specimen_name, min(poly.id) as polygon_id, st.id as structure_id" + ", st.acronym as structure_abbreviation, to_hex(st.red) || to_hex(st.green) || to_hex(st.blue) as " + "structure_color, st.name as structure_name from wells rna join image_series mims on mims.id = " + "rna.image_series_id join specimens sp on sp.id = mims.specimen_id join specimens block on block.id = " + "sp.parent_id join specimens tumor on tumor.id = block.parent_id join avg_microarray_templates mt on " + "mt.image_series_id = mims.id join avg_graphic_objects poly on poly.id = mt.shape_id join structures st " + "on st.id = poly.structure_id join rs_tubes tube on tube.sample_id = rna.id join rna_seq_experiments exp " + "on exp.rs_tube_id = tube.id join rna_seq_analysis_runs_rna_seq_experiments ar2exp on " + "ar2exp.rna_seq_experiment_id = exp.id join analysis_runs ar on ar.id = ar2exp.rna_seq_analysis_run_id " + "join well_known_files fpkm on fpkm.attachable_id = ar.id where rna.sample_id_string like any (array " + "['366-___', '466-___']) and fpkm.published_at is not null group by tumor.id, " + "tumor.external_specimen_name, block.id, block.external_specimen_name, sp.id, sp.external_specimen_name, " + "rna.id, st.id, st.acronym, st.name, structure_color order by rna.id;" + ) + with open(sample_metadata_json_location, "w") as outfile: json.dump(cur.fetchall(), outfile, indent=2) -if __name__ == '__main__': - +if __name__ == "__main__": sample_metadata_json_location = sys.argv[1] db_host = sys.argv[2] db_port = sys.argv[3] db_name = sys.argv[4] db_user = sys.argv[5] db_passwd = sys.argv[6] - main(sample_metadata_json_location, db_host, db_port, db_name, db_user, db_passwd) \ No newline at end of file + main(sample_metadata_json_location, db_host, db_port, db_name, db_user, db_passwd) diff --git a/allensdk/internal/pipeline_modules/run_annotated_region_metrics.py b/allensdk/internal/pipeline_modules/run_annotated_region_metrics.py index 03b949ba99..b5a6037468 100644 --- a/allensdk/internal/pipeline_modules/run_annotated_region_metrics.py +++ b/allensdk/internal/pipeline_modules/run_annotated_region_metrics.py @@ -1,36 +1,36 @@ """Run annotated region metrics calculations""" + import os import h5py from allensdk.internal.core.lims_utilities import get_input_json from allensdk.internal.brain_observatory.annotated_region_metrics import get_metrics -from allensdk.internal.core.lims_pipeline_module import (PipelineModule, - run_module) +from allensdk.internal.core.lims_pipeline_module import PipelineModule, run_module SDK_PATH = "/data/informatics/CAM/isi_metrics/allensdk" -SCRIPT_PATH = ("/data/informatics/CAM/isi_metrics/allensdk/allensdk/internal" - "/pipeline_modules/run_annotated_region_metrics.py") +SCRIPT_PATH = ( + "/data/informatics/CAM/isi_metrics/allensdk/allensdk/internal/pipeline_modules/run_annotated_region_metrics.py" +) + -def debug(region_id, storage_directory="./", local=True, - sdk_path=SDK_PATH, script_path=SCRIPT_PATH, lims_host="lims2"): +def debug(region_id, storage_directory="./", local=True, sdk_path=SDK_PATH, script_path=SCRIPT_PATH, lims_host="lims2"): strategy_class = "AnnotatedRegionMetricsStrategy" object_class = "AnnotatedRegion" - input_json = get_input_json(region_id, object_class, strategy_class, - lims_host) + input_json = get_input_json(region_id, object_class, strategy_class, lims_host) exp_dir = os.path.join(storage_directory, str(region_id)) - run_module(script_path, - input_json, - exp_dir, - sdk_path=sdk_path, - pbs=dict(vmem=4, - job_name="isi_metrics_{}".format(region_id), - walltime="1:00:00"), - local=local) + run_module( + script_path, + input_json, + exp_dir, + sdk_path=sdk_path, + pbs=dict(vmem=4, job_name="isi_metrics_{}".format(region_id), walltime="1:00:00"), + local=local, + ) def load_arrays(h5_file): with h5py.File(h5_file, "r") as f: - altitude_phase = f['retinotopy_altitude'][:] - azimuth_phase = f['retinotopy_azimuth'][:] + altitude_phase = f["retinotopy_altitude"][:] + azimuth_phase = f["retinotopy_azimuth"][:] return altitude_phase, azimuth_phase @@ -46,5 +46,6 @@ def main(): mod.write_output_data(output_data) + if __name__ == "__main__": main() diff --git a/allensdk/internal/pipeline_modules/run_demixing.py b/allensdk/internal/pipeline_modules/run_demixing.py index 9b147dcdf3..9e82fa452a 100644 --- a/allensdk/internal/pipeline_modules/run_demixing.py +++ b/allensdk/internal/pipeline_modules/run_demixing.py @@ -1,5 +1,6 @@ import matplotlib -matplotlib.use('agg') + +matplotlib.use("agg") import allensdk.internal.core.lims_utilities as lu from allensdk.internal.core.lims_pipeline_module import PipelineModule, run_module @@ -14,48 +15,64 @@ from allensdk.config.manifest import Manifest -EXCLUDE_LABELS = ["union", "duplicate", "motion_border", - "decrosstalk_ghost", - "decrosstalk_invalid_raw", - "decrosstalk_invalid_raw_active", - "decrosstalk_invalid_unmixed", - "decrosstalk_invalid_unmixed_active" ] +EXCLUDE_LABELS = [ + "union", + "duplicate", + "motion_border", + "decrosstalk_ghost", + "decrosstalk_invalid_raw", + "decrosstalk_invalid_raw_active", + "decrosstalk_invalid_unmixed", + "decrosstalk_invalid_unmixed_active", +] + def debug(experiment_id, local=False): OUTPUT_DIRECTORY = "/data/informatics/CAM/demix" SDK_PATH = "/data/informatics/CAM/analysis/allensdk" SCRIPT = "/data/informatics/CAM/analysis/allensdk/allensdk/internal/pipeline_modules/run_demixing.py" - sd = lu.query("select storage_directory from ophys_experiments where id = %d" % experiment_id)[0]['storage_directory'] + sd = lu.query("select storage_directory from ophys_experiments where id = %d" % experiment_id)[0][ + "storage_directory" + ] rois = lu.query("select * from cell_rois where ophys_experiment_id = %d" % experiment_id) - exc_labels = lu.query(""" + exc_labels = lu.query( + """ select cr.id, rel.name as exclusion_label from cell_rois cr join cell_rois_roi_exclusion_labels crrel on crrel.cell_roi_id = cr.id join roi_exclusion_labels rel on crrel.roi_exclusion_label_id = rel.id where cr.ophys_experiment_id = %d -""" % experiment_id) - - nrois = { roi['id']: dict(width=roi['width'], - height=roi['height'], - x=roi['x'], - y=roi['y'], - id=roi['id'], - valid=roi['valid_roi'], - mask=roi['mask_matrix'], - exclusion_labels=[]) - for roi in rois } +""" + % experiment_id + ) + + nrois = { + roi["id"]: dict( + width=roi["width"], + height=roi["height"], + x=roi["x"], + y=roi["y"], + id=roi["id"], + valid=roi["valid_roi"], + mask=roi["mask_matrix"], + exclusion_labels=[], + ) + for roi in rois + } for exc_label in exc_labels: - nrois[exc_label['id']]['exclusion_labels'].append(exc_label['exclusion_label']) + nrois[exc_label["id"]]["exclusion_labels"].append(exc_label["exclusion_label"]) - movie_path_response = lu.query(''' + movie_path_response = lu.query( + """ select wkf.filename, wkf.storage_directory from well_known_files wkf join well_known_file_types wkft on wkft.id = wkf.well_known_file_type_id where wkf.attachable_id = {} and wkf.attachable_type = 'OphysExperiment' and wkft.name = 'MotionCorrectedImageStack' - '''.format(experiment_id)) - movie_h5_path = os.path.join(movie_path_response[0]['storage_directory'], movie_path_response[0]['filename']) + """.format(experiment_id) + ) + movie_h5_path = os.path.join(movie_path_response[0]["storage_directory"], movie_path_response[0]["filename"]) exp_dir = os.path.join(OUTPUT_DIRECTORY, str(experiment_id)) @@ -63,18 +80,19 @@ def debug(experiment_id, local=False): "movie_h5": movie_h5_path, "traces_h5": os.path.join(sd, "processed", "roi_traces.h5"), "roi_masks": nrois.values(), - "output_file": os.path.join(exp_dir, "demixed_traces.h5") - } - - run_module(SCRIPT, - input_data, - exp_dir, - sdk_path=SDK_PATH, - pbs=dict(vmem=160, - job_name="demix_%d"% experiment_id, - walltime="36:00:00"), - local=local, - optional_args=['--log-level','DEBUG']) + "output_file": os.path.join(exp_dir, "demixed_traces.h5"), + } + + run_module( + SCRIPT, + input_data, + exp_dir, + sdk_path=SDK_PATH, + pbs=dict(vmem=160, job_name="demix_%d" % experiment_id, walltime="36:00:00"), + local=local, + optional_args=["--log-level", "DEBUG"], + ) + def assert_exists(file_name): if not os.path.exists(file_name): @@ -83,7 +101,7 @@ def assert_exists(file_name): def get_path(obj, key, check_exists): try: - path = obj[key] + path = obj[key] except KeyError: raise KeyError("required input field '%s' does not exist" % key) @@ -103,7 +121,7 @@ def parse_input(data, exclude_labels): with h5py.File(traces_h5, "r") as f: traces = f["data"][()] - trace_ids = [ int(rid) for rid in f["roi_names"][()] ] + trace_ids = [int(rid) for rid in f["roi_names"][()]] rois = get_path(data, "roi_masks", False) masks = None @@ -112,8 +130,8 @@ def parse_input(data, exclude_labels): for roi in rois: mask = np.zeros(movie_shape, dtype=bool) mask_matrix = np.array(roi["mask"], dtype=bool) - mask[roi["y"]:roi["y"]+roi["height"],roi["x"]:roi["x"]+roi["width"]] = mask_matrix - + mask[roi["y"] : roi["y"] + roi["height"], roi["x"] : roi["x"] + roi["width"]] = mask_matrix + if masks is None: masks = np.zeros((len(rois), mask.shape[0], mask.shape[1]), dtype=bool) valid = np.zeros(len(rois), dtype=bool) @@ -124,12 +142,13 @@ def parse_input(data, exclude_labels): except ValueError: raise ValueError("Could not find cell roi id %d in roi traces file" % rid) - masks[ridx,:,:] = mask + masks[ridx, :, :] = mask - valid[ridx] = len(set(exclude_labels) & set(roi.get("exclusion_labels",[]))) == 0 + valid[ridx] = len(set(exclude_labels) & set(roi.get("exclusion_labels", []))) == 0 return traces, masks, valid, np.array(trace_ids), movie_h5, output_h5 + def main(): mod = PipelineModule() mod.parser.add_argument("--exclude-labels", nargs="*", default=EXCLUDE_LABELS) @@ -147,8 +166,8 @@ def main(): Manifest.safe_mkdir(plot_dir) logging.debug("reading movie") - with h5py.File(movie_h5, 'r') as f: - movie = f['data'][()] + with h5py.File(movie_h5, "r") as f: + movie = f["data"][()] # only demix non-union, non-duplicate ROIs valid_idxs = np.where(valid) @@ -158,20 +177,15 @@ def main(): logging.debug("demixing") demixed_traces, drop_frames = demixer.demix_time_dep_masks(demix_traces, movie, demix_masks) - nt_inds = demixer.plot_negative_transients(demix_traces, - demixed_traces, - valid[valid_idxs], - demix_masks, - trace_ids[valid_idxs], - plot_dir) + nt_inds = demixer.plot_negative_transients( + demix_traces, demixed_traces, valid[valid_idxs], demix_masks, trace_ids[valid_idxs], plot_dir + ) logging.debug("rois with negative transients: %s", str(trace_ids[valid_idxs][nt_inds])) - nb_inds = demixer.plot_negative_baselines(demix_traces, - demixed_traces, - demix_masks, - trace_ids[valid_idxs], - plot_dir) + nb_inds = demixer.plot_negative_baselines( + demix_traces, demixed_traces, demix_masks, trace_ids[valid_idxs], plot_dir + ) # negative baseline rois (and those that overlap with them) become nans logging.debug("rois with negative baselines (or overlap with them): %s", str(trace_ids[valid_idxs][nb_inds])) @@ -182,15 +196,17 @@ def main(): out_traces[:] = np.nan out_traces[valid_idxs] = demixed_traces - with h5py.File(output_h5, 'w') as f: + with h5py.File(output_h5, "w") as f: f.create_dataset("data", data=out_traces, compression="gzip") roi_names = np.array([str(rn) for rn in trace_ids]).astype(np.bytes_) f.create_dataset("roi_names", data=roi_names) - mod.write_output_data(dict( + mod.write_output_data( + dict( negative_transient_roi_ids=trace_ids[valid_idxs][nt_inds], - negative_baseline_roi_ids=trace_ids[valid_idxs][nb_inds] - )) + negative_baseline_roi_ids=trace_ids[valid_idxs][nb_inds], + ) + ) if __name__ == "__main__": diff --git a/allensdk/internal/pipeline_modules/run_eye_tracking.py b/allensdk/internal/pipeline_modules/run_eye_tracking.py index c8ab3d0826..3e294ec4a6 100755 --- a/allensdk/internal/pipeline_modules/run_eye_tracking.py +++ b/allensdk/internal/pipeline_modules/run_eye_tracking.py @@ -1,39 +1,44 @@ import matplotlib -matplotlib.use('agg') + +matplotlib.use("agg") import logging import os from allensdk.internal.core.lims_pipeline_module import PipelineModule, run_module -from allensdk.internal.brain_observatory.run_itracker import (run_itracker, - compute_bounding_box, - DEFAULT_THRESHOLD_FACTOR, - get_experiment_info) +from allensdk.internal.brain_observatory.run_itracker import ( + run_itracker, + compute_bounding_box, + DEFAULT_THRESHOLD_FACTOR, + get_experiment_info, +) + def debug(experiment_id, num_frames=None, threshold_factor=None, local=False): OUTPUT_DIR = "/data/informatics/CAM/eye_tracking/" SDK_PATH = "/data/informatics/CAM/eye_tracking/allensdk" - SCRIPT_PATH = "/data/informatics/CAM/eye_tracking/allensdk/allensdk/internal/pipeline_modules/run_eye_tracking.py" - + SCRIPT_PATH = "/data/informatics/CAM/eye_tracking/allensdk/allensdk/internal/pipeline_modules/run_eye_tracking.py" + experiment_dir = os.path.join(OUTPUT_DIR, str(experiment_id)) - + info = get_experiment_info(experiment_id) - info['output_directory'] = experiment_dir + info["output_directory"] = experiment_dir - optional_args = [ ] + optional_args = [] if num_frames is not None: - optional_args += ['--num_frames',str(num_frames)] - - run_module(SCRIPT_PATH, - info, - experiment_dir, - sdk_path=SDK_PATH, - pbs=dict(vmem=160, - job_name="itrack_%d"% experiment_id, - walltime="10:00:00"), - local=local, - optional_args=optional_args) - + optional_args += ["--num_frames", str(num_frames)] + + run_module( + SCRIPT_PATH, + info, + experiment_dir, + sdk_path=SDK_PATH, + pbs=dict(vmem=160, job_name="itrack_%d" % experiment_id, walltime="10:00:00"), + local=local, + optional_args=optional_args, + ) + + def main(): mod = PipelineModule() mod.parser.add_argument("--num_frames", type=int, default=None) @@ -41,33 +46,32 @@ def main(): data = mod.input_data() args = dict( - movie_file=data['movie_file'], - metadata_file=data['metadata_file'], - output_directory=data['output_directory'], - threshold_factor=data.get('threshold_factor', mod.args.threshold_factor), + movie_file=data["movie_file"], + metadata_file=data["metadata_file"], + output_directory=data["output_directory"], + threshold_factor=data.get("threshold_factor", mod.args.threshold_factor), num_frames=mod.args.num_frames, auto=True, cache_input_frames=True, input_block_size=None, - output_annotated_movie_block_size=None - ) - - if data.get('pupil_points', None): - args['bbox_pupil'] = compute_bounding_box(data['pupil_points']) - if data.get('corneal_reflection_points', None): - args['bbox_cr'] = compute_bounding_box(data['corneal_reflection_points']) - + output_annotated_movie_block_size=None, + ) + + if data.get("pupil_points", None): + args["bbox_pupil"] = compute_bounding_box(data["pupil_points"]) + if data.get("corneal_reflection_points", None): + args["bbox_cr"] = compute_bounding_box(data["corneal_reflection_points"]) + tracker = run_itracker(**args) logging.debug("finished running itracker") - + output_data = dict( - pupil_file=tracker.pupil_file, - corneal_reflection_file=tracker.cr_file, - mean_frame_file=tracker.mean_frame_file - ) - + pupil_file=tracker.pupil_file, corneal_reflection_file=tracker.cr_file, mean_frame_file=tracker.mean_frame_file + ) + mod.write_output_data(output_data) - -if __name__=='__main__': + + +if __name__ == "__main__": main() diff --git a/allensdk/internal/pipeline_modules/run_neuropil_correction.py b/allensdk/internal/pipeline_modules/run_neuropil_correction.py index 12fd2a0a28..519a9f07f3 100755 --- a/allensdk/internal/pipeline_modules/run_neuropil_correction.py +++ b/allensdk/internal/pipeline_modules/run_neuropil_correction.py @@ -1,6 +1,7 @@ #!/usr/bin/python import matplotlib -matplotlib.use('agg') + +matplotlib.use("agg") import matplotlib.pyplot as plt import logging import numpy as np @@ -13,6 +14,7 @@ from allensdk.internal.core.lims_pipeline_module import PipelineModule, run_module + def debug(experiment_id, local=False): OUTPUT_DIRECTORY = "/data/informatics/CAM/neuropil" SDK_PATH = "/data/informatics/CAM/neuropil/allensdk" @@ -29,33 +31,36 @@ def debug(experiment_id, local=False): exp_dir = os.path.join(OUTPUT_DIRECTORY, str(experiment_id)) input_data = dict( - roi_trace_file = roi_trace_file, - neuropil_trace_file = os.path.join(sd, "processed", "neuropil_traces.h5"), - storage_directory = exp_dir - ) + roi_trace_file=roi_trace_file, + neuropil_trace_file=os.path.join(sd, "processed", "neuropil_traces.h5"), + storage_directory=exp_dir, + ) + + run_module( + SCRIPT, + input_data, + exp_dir, + sdk_path=SDK_PATH, + pbs=dict(vmem=160, job_name="np_%d" % experiment_id, walltime="36:00:00"), + local=local, + optional_args=["--log-level", "DEBUG"], + ) - run_module(SCRIPT, - input_data, - exp_dir, - sdk_path=SDK_PATH, - pbs=dict(vmem=160, - job_name="np_%d"% experiment_id, - walltime="36:00:00"), - local=local, - optional_args=['--log-level','DEBUG']) - def debug_plot(file_name, roi_trace, neuropil_trace, corrected_trace, r, r_vals=None, err_vals=None): - fig = plt.figure(figsize=(15,10)) + fig = plt.figure(figsize=(15, 10)) ax = fig.add_subplot(211) - ax.plot(roi_trace,'r', label="raw") - ax.plot(corrected_trace,'b', label="fc") - ax.plot(neuropil_trace,'g', label="neuropil") - ax.set_xlim(0,roi_trace.size) - ax.set_title('raw(%.02f, %.02f) fc(%.02f, %.02f) r(%f)' % (roi_trace.min(), roi_trace.max(), corrected_trace.min(), corrected_trace.max(), r)) + ax.plot(roi_trace, "r", label="raw") + ax.plot(corrected_trace, "b", label="fc") + ax.plot(neuropil_trace, "g", label="neuropil") + ax.set_xlim(0, roi_trace.size) + ax.set_title( + "raw(%.02f, %.02f) fc(%.02f, %.02f) r(%f)" + % (roi_trace.min(), roi_trace.max(), corrected_trace.min(), corrected_trace.max(), r) + ) ax.legend() - + if r_vals is not None: ax = fig.add_subplot(212) ax.plot(r_vals, err_vals, "o") @@ -63,6 +68,7 @@ def debug_plot(file_name, roi_trace, neuropil_trace, corrected_trace, r, r_vals= plt.savefig(file_name) plt.close() + def adjust_r_for_negativity(r, F_C, F_M, F_N): # this function is no longer used, but leaving it here just in case # loop through all of the negative spots and pick r to fix them @@ -78,7 +84,6 @@ def adjust_r_for_negativity(r, F_C, F_M, F_N): F_C = F_M - r * F_N logging.debug(" updated r to %f", r) - # if there is still a negative spot, it's off by some tiny epsilon. # step r down by delta_r increments until we find one that works. delta_r = -1e-5 @@ -130,13 +135,13 @@ def main(): logging.error("Error: unable to open neuropil trace file '%s'", neuropil_file) raise - ''' + """ get number of traces, length, etc. - ''' - num_traces, T = roi_traces['data'].shape + """ + num_traces, T = roi_traces["data"].shape T_orig = T - T_cross_val = int(T/2) - if (T - T_cross_val > T_cross_val): + T_cross_val = int(T / 2) + if T - T_cross_val > T_cross_val: T = T - 1 # make sure that ROI and neuropil trace files are organized the same @@ -146,23 +151,23 @@ def main(): assert len(n_id) == len(r_id), "Input trace files are not aligned (ROI count)" for i in range(len(n_id)): assert n_id[i] == r_id[i], "Input trace files are not aligned (ROI IDs)" - ''' + """ initialize storage variables and analysis routine - ''' - r_list = [ None ] * num_traces - RMSE_list = [ -1 ] * num_traces + """ + r_list = [None] * num_traces + RMSE_list = [-1] * num_traces roi_names = n_id corrected = np.zeros((num_traces, T_orig)) - r_vals = [ None ] * num_traces + r_vals = [None] * num_traces for n in range(num_traces): - roi = roi_traces['data'][n] - neuropil = neuropil_traces['data'][n] + roi = roi_traces["data"][n] + neuropil = neuropil_traces["data"][n] if np.any(np.isnan(neuropil)): logging.warning("neuropil trace for roi %d contains NaNs, skipping", n) continue - + if np.any(np.isnan(roi)): logging.warning("roi trace for roi %d contains NaNs, skipping", n) continue @@ -170,53 +175,52 @@ def main(): r = None logging.info("Correcting trace %d (roi %s)", n, str(n_id[n])) - results = estimate_contamination_ratios(roi, neuropil) + results = estimate_contamination_ratios(roi, neuropil) logging.info("r=%f err=%f it=%d", results["r"], results["err"], results["it"]) r = results["r"] fc = roi - r * neuropil RMSE_list[n] = results["err"] r_vals[n] = results["r_vals"] - - debug_plot(os.path.join(plot_dir, "initial_%04d.png" % n), - roi, neuropil, fc, r, results["r_vals"], results["err_vals"]) + + debug_plot( + os.path.join(plot_dir, "initial_%04d.png" % n), roi, neuropil, fc, r, results["r_vals"], results["err_vals"] + ) # mean of the corrected trace must be positive if fc.mean() > 0: r_list[n] = r - corrected[n,:] = fc + corrected[n, :] = fc else: logging.warning("fc has negative baseline, skipping this r value") # fill in empty r values - for n in range(num_traces): - roi = roi_traces['data'][n] - neuropil = neuropil_traces['data'][n] + for n in range(num_traces): + roi = roi_traces["data"][n] + neuropil = neuropil_traces["data"][n] if r_list[n] is None: logging.warning("Error estimated r for trace %d. Setting to zero.", n) r_list[n] = 0 - corrected[n,:] = roi + corrected[n, :] = roi # save a debug plot - debug_plot(os.path.join(plot_dir, "final_%04d.png" % n), - roi, neuropil, corrected[n,:], r_list[n]) + debug_plot(os.path.join(plot_dir, "final_%04d.png" % n), roi, neuropil, corrected[n, :], r_list[n]) # one last sanity check eps = -0.0001 - if np.mean(corrected[n,:]) < eps: + if np.mean(corrected[n, :]) < eps: raise Exception("Trace %d baseline is still negative value after correction" % n) if r_list[n] < 0.0: raise Exception("Trace %d ended with negative r" % n) - ######################################################################## # write out processed data try: savefile = os.path.join(storage_dir, "neuropil_correction.h5") - hf = h5py.File(savefile, 'w') + hf = h5py.File(savefile, "w") hf.create_dataset("r", data=r_list) hf.create_dataset("RMSE", data=RMSE_list) hf.create_dataset("FC", data=corrected, compression="gzip") @@ -229,8 +233,8 @@ def main(): hf.close() except Exception: logging.error("Error creating output h5 file") - raise - + raise + roi_traces.close() neuropil_traces.close() @@ -240,5 +244,6 @@ def main(): logging.info("finished") + if __name__ == "__main__": main() diff --git a/allensdk/internal/pipeline_modules/run_observatory_analysis.py b/allensdk/internal/pipeline_modules/run_observatory_analysis.py index 36e1f3cd53..e2e16cefe8 100644 --- a/allensdk/internal/pipeline_modules/run_observatory_analysis.py +++ b/allensdk/internal/pipeline_modules/run_observatory_analysis.py @@ -22,46 +22,59 @@ import os + def get_experiment_nwb_file(experiment_id): - res = lu.query(""" + res = lu.query( + """ select * from well_known_files wkf join well_known_file_types wkft on wkft.id = wkf.well_known_file_type_id where attachable_id = %d and wkft.name = 'NWBOphys' -""" % experiment_id) - return os.path.join(res[0]['storage_directory'], res[0]['filename']) +""" + % experiment_id + ) + return os.path.join(res[0]["storage_directory"], res[0]["filename"]) + def get_experiment_session(experiment_id): - return lu.query(""" + return lu.query( + """ select stimulus_name from ophys_sessions os join ophys_experiments oe on oe.ophys_session_id = os.id where oe.id = %d -""" % experiment_id)[0]['stimulus_name'] - -def debug(experiment_ids, local=False, - OUTPUT_DIR = "/data/informatics/CAM/analysis/", - SDK_PATH = "/data/informatics/CAM/analysis/allensdk/", - walltime="10:00:00", - python=SHARED_PYTHON, - queue='braintv'): - +""" + % experiment_id + )[0]["stimulus_name"] + + +def debug( + experiment_ids, + local=False, + OUTPUT_DIR="/data/informatics/CAM/analysis/", + SDK_PATH="/data/informatics/CAM/analysis/allensdk/", + walltime="10:00:00", + python=SHARED_PYTHON, + queue="braintv", +): input_data = {} for eid in experiment_ids: exp_dir = os.path.join(OUTPUT_DIR, str(eid)) - input_data[eid] = dict(nwb_file=get_experiment_nwb_file(eid), - output_file=os.path.join(exp_dir, "%d_analysis.h5" % eid), - session_name=get_experiment_session(eid)) - - run_module(os.path.abspath(__file__), - input_data, - exp_dir, - python=python, - sdk_path=SDK_PATH, - pbs=dict(vmem=32, - job_name="bobanalysis_%d"% eid, - walltime=walltime, - queue=queue), - local=local) + input_data[eid] = dict( + nwb_file=get_experiment_nwb_file(eid), + output_file=os.path.join(exp_dir, "%d_analysis.h5" % eid), + session_name=get_experiment_session(eid), + ) + + run_module( + os.path.abspath(__file__), + input_data, + exp_dir, + python=python, + sdk_path=SDK_PATH, + pbs=dict(vmem=32, job_name="bobanalysis_%d" % eid, walltime=walltime, queue=queue), + local=local, + ) + def main(): mod = PipelineModule() @@ -70,8 +83,8 @@ def main(): results = {} for ident, experiment in jin.items(): - nwb_file = experiment['nwb_file'] - output_file = experiment['output_file'] + nwb_file = experiment["nwb_file"] + output_file = experiment["output_file"] if experiment["session_name"] not in si.SESSION_STIMULUS_MAP.keys(): raise Exception("Could not run analysis for unknown session: %s" % experiment["session_name"]) @@ -80,8 +93,7 @@ def main(): logging.info("NWB file %s", nwb_file) logging.info("Output file %s", output_file) - results[ident] = run_session_analysis(nwb_file, output_file, - save_flag=True, plot_flag=False) + results[ident] = run_session_analysis(nwb_file, output_file, save_flag=True, plot_flag=False) logging.info("Generating output") @@ -92,14 +104,14 @@ def main(): # metric fields names = {} roi_id = None - for metric, values in data['cell'].items(): + for metric, values in data["cell"].items(): if metric == "roi_id": roi_id = values else: # convert dict to array vals = [] for i in range(len(values)): - vals.append(values[i]) # panda syntax + vals.append(values[i]) # panda syntax names[metric] = vals # make an output record for each roi_id if roi_id is not None: @@ -110,14 +122,12 @@ def main(): roi[field] = values[i] res[name] = roi - jout[session_name] = { - 'cell': res, - 'experiment': data['experiment'] - } + jout[session_name] = {"cell": res, "experiment": data["experiment"]} logging.info("Saving output") mod.write_output_data(jout) + if __name__ == "__main__": main() diff --git a/allensdk/internal/pipeline_modules/run_observatory_container_thumbnails.py b/allensdk/internal/pipeline_modules/run_observatory_container_thumbnails.py index 85ea97dd46..3b7a22fdc0 100644 --- a/allensdk/internal/pipeline_modules/run_observatory_container_thumbnails.py +++ b/allensdk/internal/pipeline_modules/run_observatory_container_thumbnails.py @@ -6,14 +6,19 @@ import allensdk.core.json_utilities as ju from allensdk.config.manifest import Manifest + def get_container_info(container_id): - res = lu.query(""" + res = lu.query( + """ select * from ophys_experiments oe where experiment_container_id = %d and oe.workflow_state != 'failed' -""" % container_id) +""" + % container_id + ) return res + def debug(container_id, local=False, plots=None): SCRIPT = "/data/informatics/CAM/analysis/allensdk/allensdk/internal/pipeline_modules/run_observatory_container_thumbnails.py" SDK_PATH = "/data/informatics/CAM/analysis/allensdk/" @@ -23,42 +28,42 @@ def debug(container_id, local=False, plots=None): input_data = [] for exp in get_container_info(container_id): - exp_data = robsth.get_input_data(exp['id']) + exp_data = robsth.get_input_data(exp["id"]) exp_input_json = os.path.join(exp_data["output_directory"], "input.json") - input_data.append(dict( - input_json=exp_input_json, - output_json=os.path.join(exp_data["output_directory"], "output.json") - )) + input_data.append( + dict(input_json=exp_input_json, output_json=os.path.join(exp_data["output_directory"], "output.json")) + ) Manifest.safe_make_parent_dirs(exp_input_json) ju.write(exp_input_json, exp_data) - run_module(SCRIPT, - input_data, - container_dir, - sdk_path=SDK_PATH, - pbs=dict(vmem=32, - job_name="cthumbs_%d"% container_id, - walltime="10:00:00"), - local=local, - optional_args=['--types='+','.join(plots)] if plots else None) + run_module( + SCRIPT, + input_data, + container_dir, + sdk_path=SDK_PATH, + pbs=dict(vmem=32, job_name="cthumbs_%d" % container_id, walltime="10:00:00"), + local=local, + optional_args=["--types=" + ",".join(plots)] if plots else None, + ) + def main(): mod = PipelineModule() - mod.parser.add_argument("--types", default=','.join(robsth.PLOT_TYPES)) + mod.parser.add_argument("--types", default=",".join(robsth.PLOT_TYPES)) mod.parser.add_argument("--threads", default=4) - + data = mod.input_data() - types = mod.args.types.split(',') + types = mod.args.types.split(",") for input_file in data: - exp_input_json = input_file['input_json'] + exp_input_json = input_file["input_json"] exp_input_data = ju.read(exp_input_json) nwb_file, analysis_file, output_directory = robsth.parse_input(exp_input_data) - - robsth.build_experiment_thumbnails(nwb_file, analysis_file, output_directory, - types, mod.args.threads) -if __name__=='__main__': + robsth.build_experiment_thumbnails(nwb_file, analysis_file, output_directory, types, mod.args.threads) + + +if __name__ == "__main__": main() diff --git a/allensdk/internal/pipeline_modules/run_observatory_thumbnails.py b/allensdk/internal/pipeline_modules/run_observatory_thumbnails.py index 209993eb28..36bd2f28a0 100644 --- a/allensdk/internal/pipeline_modules/run_observatory_thumbnails.py +++ b/allensdk/internal/pipeline_modules/run_observatory_thumbnails.py @@ -1,5 +1,6 @@ import matplotlib -matplotlib.use('agg') + +matplotlib.use("agg") import os import allensdk.core.json_utilities as ju @@ -18,9 +19,11 @@ from allensdk.brain_observatory.natural_scenes import NaturalScenes from allensdk.brain_observatory.natural_movie import NaturalMovie from allensdk.brain_observatory import observatory_plots as oplots -from allensdk.core.brain_observatory_nwb_data_set import (BrainObservatoryNwbDataSet, - MissingStimulusException, - NoEyeTrackingException) +from allensdk.core.brain_observatory_nwb_data_set import ( + BrainObservatoryNwbDataSet, + MissingStimulusException, + NoEyeTrackingException, +) from allensdk.config.manifest import Manifest from allensdk.internal.core.lims_pipeline_module import run_module import allensdk.internal.core.lims_utilities as lu @@ -30,31 +33,39 @@ SMALL_HEIGHT = 150 SMALL_FONT = 4 LARGE_FONT = 12 -PLOT_CONFIGS = { 'small': dict(height_px=SMALL_HEIGHT, pattern="%s_small.png", font_size=SMALL_FONT), - 'large': dict(height_px=LARGE_HEIGHT, pattern="%s_large.png", font_size=LARGE_FONT), - 'svg': dict(height_px=LARGE_HEIGHT, pattern="%s.svg", font_size=SMALL_FONT) } -PLOT_TYPES = ["dg", "sg", "ns", "lsn_on", - "lsn_off", "rf", - "nm1", "nm2", "nm3", "sp", - "corr", "eye"] +PLOT_CONFIGS = { + "small": dict(height_px=SMALL_HEIGHT, pattern="%s_small.png", font_size=SMALL_FONT), + "large": dict(height_px=LARGE_HEIGHT, pattern="%s_large.png", font_size=LARGE_FONT), + "svg": dict(height_px=LARGE_HEIGHT, pattern="%s.svg", font_size=SMALL_FONT), +} +PLOT_TYPES = ["dg", "sg", "ns", "lsn_on", "lsn_off", "rf", "nm1", "nm2", "nm3", "sp", "corr", "eye"] + def get_experiment_analysis_file(experiment_id): - res = lu.query(""" + res = lu.query( + """ select * from well_known_files wkf join well_known_file_types wkft on wkft.id = wkf.well_known_file_type_id where attachable_id = %d and wkft.name = 'OphysExperimentCellRoiMetricsFile' -""" % experiment_id) - return os.path.join(res[0]['storage_directory'], res[0]['filename']) +""" + % experiment_id + ) + return os.path.join(res[0]["storage_directory"], res[0]["filename"]) + def get_experiment_nwb_file(experiment_id): - res = lu.query(""" + res = lu.query( + """ select * from well_known_files wkf join well_known_file_types wkft on wkft.id = wkf.well_known_file_type_id where attachable_id = %d and wkft.name = 'NWBOphys' -""" % experiment_id) - return os.path.join(res[0]['storage_directory'], res[0]['filename']) +""" + % experiment_id + ) + return os.path.join(res[0]["storage_directory"], res[0]["filename"]) + def get_experiment_files(experiment_id): nwb_file = get_experiment_nwb_file(experiment_id) @@ -66,62 +77,66 @@ def get_experiment_files(experiment_id): if not os.path.exists(nwb_file): raise Exception("nwb file does not exist: %s" % nwb_file) - #if not os.path.exists(analysis_file): - # raise Exception("analysis file does not exist: %s" % analysis_file) + # if not os.path.exists(analysis_file): + # raise Exception("analysis file does not exist: %s" % analysis_file) return nwb_file, analysis_file + def get_input_data(experiment_id): OUTPUT_DIR = "/data/informatics/CAM/analysis/" nwb_file, analysis_file = get_experiment_files(experiment_id) output_directory = os.path.join(OUTPUT_DIR, str(experiment_id), "thumbnails") - + my_file = "/data/informatics/CAM/analysis/%d/%d_analysis.h5" % (experiment_id, experiment_id) if os.path.exists(my_file): analysis_file = my_file input_data = { - 'nwb_file': nwb_file, + "nwb_file": nwb_file, #'analysis_file': analysis_file, - 'analysis_file': analysis_file, - 'output_directory': output_directory - } + "analysis_file": analysis_file, + "output_directory": output_directory, + } return input_data + def debug(experiment_id, plots=None, local=False): SDK_PATH = "/data/informatics/CAM/analysis/allensdk/" input_data = get_input_data(experiment_id) - run_module(os.path.abspath(__file__), - input_data, - input_data["output_directory"], - sdk_path=SDK_PATH, - pbs=dict(vmem=32, - job_name="bobthumbs_%d"% experiment_id, - walltime="10:00:00"), - local=local, - optional_args=['--types='+','.join(plots)] if plots else None) + run_module( + os.path.abspath(__file__), + input_data, + input_data["output_directory"], + sdk_path=SDK_PATH, + pbs=dict(vmem=32, job_name="bobthumbs_%d" % experiment_id, walltime="10:00:00"), + local=local, + optional_args=["--types=" + ",".join(plots)] if plots else None, + ) + def build_plots(prefix, aspect, configs, output_dir, axes=None, transparent=False): Manifest.safe_mkdir(output_dir) for config in configs: - h = config['height_px'] + h = config["height_px"] w = int(h * aspect) - + file_name = os.path.join(output_dir, config["pattern"] % prefix) logging.debug("file: %s", file_name) with oplots.figure_in_px(w, h, file_name, transparent=transparent): - matplotlib.rcParams.update({'font.size': config['font_size']}) + matplotlib.rcParams.update({"font.size": config["font_size"]}) yield file_name + def build_cell_plots(cell_specimen_ids, prefix, aspect, configs, output_dir, axes=None, transparent=False): - for i,csid in enumerate(cell_specimen_ids): + for i, csid in enumerate(cell_specimen_ids): if np.isnan(csid): cell_dir = os.path.join(output_dir, str(i)) else: @@ -130,36 +145,37 @@ def build_cell_plots(cell_specimen_ids, prefix, aspect, configs, output_dir, axe for fn in build_plots(prefix, aspect, configs, cell_dir, transparent=transparent): yield fn, csid, i + def build_drifting_gratings(dga, configs, output_dir): - for fn in build_plots("drifting_gratings_axes_pref_dir", 1.0, [configs['large'], configs['svg']], output_dir): + for fn in build_plots("drifting_gratings_axes_pref_dir", 1.0, [configs["large"], configs["svg"]], output_dir): dga.plot_preferred_direction(include_labels=True) oplots.finalize_no_axes() - for fn in build_plots("drifting_gratings_pref_dir", 1.0, [configs['small']], output_dir): + for fn in build_plots("drifting_gratings_pref_dir", 1.0, [configs["small"]], output_dir): dga.plot_preferred_direction(include_labels=False) oplots.finalize_no_axes() - for fn in build_plots("drifting_gratings_axes_pref_tf", 1.0, [configs['large'], configs['svg']], output_dir): + for fn in build_plots("drifting_gratings_axes_pref_tf", 1.0, [configs["large"], configs["svg"]], output_dir): dga.plot_preferred_temporal_frequency() oplots.finalize_with_axes() - for fn in build_plots("drifting_gratings_pref_tf", 1.0, [configs['small']], output_dir): + for fn in build_plots("drifting_gratings_pref_tf", 1.0, [configs["small"]], output_dir): dga.plot_preferred_temporal_frequency() oplots.finalize_no_labels() - for fn in build_plots("drifting_gratings_axes_dsi", 1.0, [configs['large'], configs['svg']], output_dir): + for fn in build_plots("drifting_gratings_axes_dsi", 1.0, [configs["large"], configs["svg"]], output_dir): dga.plot_direction_selectivity() oplots.finalize_with_axes() - for fn in build_plots("drifting_gratings_dsi", 1.0, [configs['small']], output_dir): + for fn in build_plots("drifting_gratings_dsi", 1.0, [configs["small"]], output_dir): dga.plot_direction_selectivity() oplots.finalize_no_labels() - for fn in build_plots("drifting_gratings_axes_osi", 1.0, [configs['large'], configs['svg']], output_dir): + for fn in build_plots("drifting_gratings_axes_osi", 1.0, [configs["large"], configs["svg"]], output_dir): dga.plot_orientation_selectivity() oplots.finalize_with_axes() - for fn in build_plots("drifting_gratings_osi", 1.0, [configs['small']], output_dir): + for fn in build_plots("drifting_gratings_osi", 1.0, [configs["small"]], output_dir): dga.plot_orientation_selectivity() oplots.finalize_no_labels() @@ -168,36 +184,37 @@ def build_drifting_gratings(dga, configs, output_dir): dga.open_star_plot(csid, include_labels=False, cell_index=i) oplots.finalize_no_axes() + def build_static_gratings(sga, configs, output_dir): - for fn in build_plots("static_gratings_axes_time_to_peak", 1.0, [configs['large'], configs['svg']], output_dir): + for fn in build_plots("static_gratings_axes_time_to_peak", 1.0, [configs["large"], configs["svg"]], output_dir): sga.plot_time_to_peak() oplots.finalize_with_axes() - for fn in build_plots("static_gratings_time_to_peak", 1.0, [configs['small']], output_dir): + for fn in build_plots("static_gratings_time_to_peak", 1.0, [configs["small"]], output_dir): sga.plot_time_to_peak() oplots.finalize_no_labels() - for fn in build_plots("static_gratings_axes_pref_ori", 1.5, [configs['large'], configs['svg']], output_dir): + for fn in build_plots("static_gratings_axes_pref_ori", 1.5, [configs["large"], configs["svg"]], output_dir): sga.plot_preferred_orientation(include_labels=True) oplots.finalize_no_axes() - for fn in build_plots("static_gratings_pref_ori", 1.5, [configs['small']], output_dir): + for fn in build_plots("static_gratings_pref_ori", 1.5, [configs["small"]], output_dir): sga.plot_preferred_orientation(include_labels=False) oplots.finalize_no_axes() - for fn in build_plots("static_gratings_axes_osi", 1.0, [configs['large'], configs['svg']], output_dir): + for fn in build_plots("static_gratings_axes_osi", 1.0, [configs["large"], configs["svg"]], output_dir): sga.plot_orientation_selectivity() oplots.finalize_with_axes() - for fn in build_plots("static_gratings_osi", 1.0, [configs['small']], output_dir): + for fn in build_plots("static_gratings_osi", 1.0, [configs["small"]], output_dir): sga.plot_orientation_selectivity() oplots.finalize_no_labels() - for fn in build_plots("static_gratings_axes_pref_sf", 1.0, [configs['large'], configs['svg']], output_dir): + for fn in build_plots("static_gratings_axes_pref_sf", 1.0, [configs["large"], configs["svg"]], output_dir): sga.plot_preferred_spatial_frequency() oplots.finalize_with_axes() - for fn in build_plots("static_gratings_pref_sf", 1.0, [configs['small']], output_dir): + for fn in build_plots("static_gratings_pref_sf", 1.0, [configs["small"]], output_dir): sga.plot_preferred_spatial_frequency() oplots.finalize_no_labels() @@ -206,18 +223,20 @@ def build_static_gratings(sga, configs, output_dir): sga.open_fan_plot(csid, include_labels=False, cell_index=i) oplots.finalize_no_axes() + def build_natural_movie(nma, configs, output_dir, name): csids = nma.data_set.get_cell_specimen_ids() - for file_name, csid, i in build_cell_plots(csids, name, 1.0, configs.values(), output_dir): + for file_name, csid, i in build_cell_plots(csids, name, 1.0, configs.values(), output_dir): nma.open_track_plot(csid, cell_index=i) oplots.finalize_no_axes() + def build_natural_scenes(nsa, configs, output_dir): - for fn in build_plots("natural_scenes_axes_time_to_peak", 1.0, [configs['large'], configs['svg']], output_dir): + for fn in build_plots("natural_scenes_axes_time_to_peak", 1.0, [configs["large"], configs["svg"]], output_dir): nsa.plot_time_to_peak() oplots.finalize_with_axes() - for fn in build_plots("natural_scenes_time_to_peak", 1.0, [configs['small']], output_dir): + for fn in build_plots("natural_scenes_time_to_peak", 1.0, [configs["small"]], output_dir): nsa.plot_time_to_peak() oplots.finalize_no_labels() @@ -226,23 +245,24 @@ def build_natural_scenes(nsa, configs, output_dir): nsa.open_corona_plot(csid, cell_index=i) oplots.finalize_no_axes() + def build_locally_sparse_noise(lsna, configs, output_dir, on): prefix = "locally_sparse_noise_" + ("on" if on else "off") csids = lsna.data_set.get_cell_specimen_ids() - for file_name, csid, i in build_cell_plots(csids, prefix, 1.754, [configs['large'], configs['small']], output_dir): + for file_name, csid, i in build_cell_plots(csids, prefix, 1.754, [configs["large"], configs["small"]], output_dir): lsna.open_pincushion_plot(on, cell_specimen_id=csid, cell_index=i) oplots.finalize_no_axes() + def build_receptive_field(lsna, configs, output_dir): - lsn_movie, lsn_mask = lsna.data_set.get_locally_sparse_noise_stimulus_template(lsna.stimulus, - mask_off_screen=False) + lsn_movie, lsn_mask = lsna.data_set.get_locally_sparse_noise_stimulus_template(lsna.stimulus, mask_off_screen=False) if lsna.cell_index_receptive_field_analysis_data is None: logging.warning("receptive field analysis not performed, so no receptive field plots will be made") return - clim = np.nanpercentile(lsna.receptive_field, [1.0,99.0], axis=None) + clim = np.nanpercentile(lsna.receptive_field, [1.0, 99.0], axis=None) for fn in build_plots("population_receptive_field", 1.754, [configs["large"]], output_dir, transparent=True): lsna.plot_population_receptive_field(mask=lsn_mask, scalebar=True) @@ -253,42 +273,59 @@ def build_receptive_field(lsna, configs, output_dir): oplots.finalize_no_axes() csids = lsna.data_set.get_cell_specimen_ids() - for file_name, csid, i in build_cell_plots(csids, "receptive_field_on", 1.754, [configs["large"]], output_dir, transparent=True): - lsna.plot_cell_receptive_field(True, cell_specimen_id=csid, clim=clim, mask=lsn_mask, cell_index=i, scalebar=True) + for file_name, csid, i in build_cell_plots( + csids, "receptive_field_on", 1.754, [configs["large"]], output_dir, transparent=True + ): + lsna.plot_cell_receptive_field( + True, cell_specimen_id=csid, clim=clim, mask=lsn_mask, cell_index=i, scalebar=True + ) oplots.finalize_no_axes() - for file_name, csid, i in build_cell_plots(csids, "receptive_field_on", 1.754, [configs["small"]], output_dir, transparent=True): - lsna.plot_cell_receptive_field(True, cell_specimen_id=csid, clim=clim, mask=lsn_mask, cell_index=i, scalebar=False) + for file_name, csid, i in build_cell_plots( + csids, "receptive_field_on", 1.754, [configs["small"]], output_dir, transparent=True + ): + lsna.plot_cell_receptive_field( + True, cell_specimen_id=csid, clim=clim, mask=lsn_mask, cell_index=i, scalebar=False + ) oplots.finalize_no_axes() - for file_name, csid, i in build_cell_plots(csids, "receptive_field_off", 1.754, [configs["large"]], output_dir, transparent=True): - lsna.plot_cell_receptive_field(False, cell_specimen_id=csid, clim=clim, mask=lsn_mask, cell_index=i, scalebar=True) + for file_name, csid, i in build_cell_plots( + csids, "receptive_field_off", 1.754, [configs["large"]], output_dir, transparent=True + ): + lsna.plot_cell_receptive_field( + False, cell_specimen_id=csid, clim=clim, mask=lsn_mask, cell_index=i, scalebar=True + ) oplots.finalize_no_axes() - for file_name, csid, i in build_cell_plots(csids, "receptive_field_off", 1.754, [configs["small"]], output_dir, transparent=True): - lsna.plot_cell_receptive_field(False, cell_specimen_id=csid, clim=clim, mask=lsn_mask, cell_index=i, scalebar=False) + for file_name, csid, i in build_cell_plots( + csids, "receptive_field_off", 1.754, [configs["small"]], output_dir, transparent=True + ): + lsna.plot_cell_receptive_field( + False, cell_specimen_id=csid, clim=clim, mask=lsn_mask, cell_index=i, scalebar=False + ) oplots.finalize_no_axes() - + def build_speed_tuning(analysis, configs, output_dir): csids = analysis.data_set.get_cell_specimen_ids() - for fn in build_plots("running_speed", 1.0, [configs['large'], configs['svg']], output_dir): + for fn in build_plots("running_speed", 1.0, [configs["large"], configs["svg"]], output_dir): analysis.plot_running_speed_histogram() oplots.finalize_with_axes() - for fn in build_plots("running_speed", 1.0, [configs['small']], output_dir): + for fn in build_plots("running_speed", 1.0, [configs["small"]], output_dir): analysis.plot_running_speed_histogram() oplots.finalize_no_labels() - for fn, csid, i in build_cell_plots(csids, "speed_tuning", 1.0, [configs['large'], configs['svg']], output_dir): + for fn, csid, i in build_cell_plots(csids, "speed_tuning", 1.0, [configs["large"], configs["svg"]], output_dir): analysis.plot_speed_tuning(csid, cell_index=i) oplots.finalize_with_axes() - for fn, csid, i in build_cell_plots(csids, "speed_tuning", 1.0, [configs['small']], output_dir): + for fn, csid, i in build_cell_plots(csids, "speed_tuning", 1.0, [configs["small"]], output_dir): analysis.plot_speed_tuning(csid, cell_index=i) oplots.finalize_no_axes() + def build_correlation_plots(data_set, analysis_file, configs, output_dir): sig_corrs = [] noise_corrs = [] @@ -300,43 +337,47 @@ def build_correlation_plots(data_set, analysis_file, configs, output_dir): if si.DRIFTING_GRATINGS in avail_stims: dg = DriftingGratings.from_analysis_file(data_set, analysis_file) - if hasattr(dg, 'representational_similarity'): + if hasattr(dg, "representational_similarity"): ans.append(dg) labels.append(si.DRIFTING_GRATINGS_SHORT) colors.append(si.DRIFTING_GRATINGS_COLOR) - setups = [ ( [configs['large']], True ), ( [configs['small']], False )] + setups = [([configs["large"]], True), ([configs["small"]], False)] for cfgs, show_labels in setups: for fn in build_plots("drifting_gratings_representational_similarity", 1.0, cfgs, output_dir): - oplots.plot_representational_similarity(dg.representational_similarity, - dims=[dg.orivals, dg.tfvals[1:]], - dim_labels=["dir", "tf"], - dim_order=[1,0], - colors=['r','b'], - labels=show_labels) + oplots.plot_representational_similarity( + dg.representational_similarity, + dims=[dg.orivals, dg.tfvals[1:]], + dim_labels=["dir", "tf"], + dim_order=[1, 0], + colors=["r", "b"], + labels=show_labels, + ) if si.STATIC_GRATINGS in avail_stims: sg = StaticGratings.from_analysis_file(data_set, analysis_file) - if hasattr(sg, 'representational_similarity'): + if hasattr(sg, "representational_similarity"): ans.append(sg) labels.append(si.STATIC_GRATINGS_SHORT) colors.append(si.STATIC_GRATINGS_COLOR) - setups = [ ( [configs['large']], True ), ( [configs['small']], False )] + setups = [([configs["large"]], True), ([configs["small"]], False)] for cfgs, show_labels in setups: for fn in build_plots("static_gratings_representational_similarity", 1.0, cfgs, output_dir): - oplots.plot_representational_similarity(sg.representational_similarity, - dims=[sg.orivals, sg.sfvals[1:], sg.phasevals], - dim_labels=["ori", "sf", "ph"], - dim_order=[1,0,2], - colors=['r','g','b'], - labels=show_labels) + oplots.plot_representational_similarity( + sg.representational_similarity, + dims=[sg.orivals, sg.sfvals[1:], sg.phasevals], + dim_labels=["ori", "sf", "ph"], + dim_order=[1, 0, 2], + colors=["r", "g", "b"], + labels=show_labels, + ) if si.NATURAL_SCENES in avail_stims: ns = NaturalScenes.from_analysis_file(data_set, analysis_file) - if hasattr(ns, 'representational_similarity'): + if hasattr(ns, "representational_similarity"): ans.append(ns) labels.append(si.NATURAL_SCENES_SHORT) colors.append(si.NATURAL_SCENES_COLOR) - setups = [ ( [configs['large']], True ), ( [configs['small']], False )] + setups = [([configs["large"]], True), ([configs["small"]], False)] for cfgs, show_labels in setups: for fn in build_plots("natural_scenes_representational_similarity", 1.0, cfgs, output_dir): oplots.plot_representational_similarity(ns.representational_similarity, labels=show_labels) @@ -344,34 +385,32 @@ def build_correlation_plots(data_set, analysis_file, configs, output_dir): if len(ans): for an in ans: sig_corrs.append(an.signal_correlation) - extra_dims = range(2,len(an.noise_correlation.shape)) + extra_dims = range(2, len(an.noise_correlation.shape)) noise_corrs.append(an.noise_correlation.mean(axis=tuple(extra_dims))) - for fn in build_plots("correlation", 1.0, [configs['large'], configs['svg']], output_dir): + for fn in build_plots("correlation", 1.0, [configs["large"], configs["svg"]], output_dir): oplots.population_correlation_scatter(sig_corrs, noise_corrs, labels, colors, scale=16.0) oplots.finalize_with_axes() - for fn in build_plots("correlation", 1.0, [configs['small']], output_dir): + for fn in build_plots("correlation", 1.0, [configs["small"]], output_dir): oplots.population_correlation_scatter(sig_corrs, noise_corrs, labels, colors, scale=4.0) oplots.finalize_no_labels() csids = ans[0].data_set.get_cell_specimen_ids() - for fn, csid, i in build_cell_plots(csids, "signal_correlation", 1.0, [configs['large']], output_dir): + for fn, csid, i in build_cell_plots(csids, "signal_correlation", 1.0, [configs["large"]], output_dir): row = ans[0].row_from_cell_id(csid, i) - oplots.plot_cell_correlation([ np.delete(sig_corr[row],i) for sig_corr in sig_corrs ], - labels, colors) + oplots.plot_cell_correlation([np.delete(sig_corr[row], i) for sig_corr in sig_corrs], labels, colors) oplots.finalize_with_axes() - for fn, csid, i in build_cell_plots(csids, "signal_correlation", 1.0, [configs['small']], output_dir): + for fn, csid, i in build_cell_plots(csids, "signal_correlation", 1.0, [configs["small"]], output_dir): row = ans[0].row_from_cell_id(csid, i) - oplots.plot_cell_correlation([ np.delete(sig_corr[row],i) for sig_corr in sig_corrs ], - labels, colors) + oplots.plot_cell_correlation([np.delete(sig_corr[row], i) for sig_corr in sig_corrs], labels, colors) oplots.finalize_no_labels() - + def lsna_check_hvas(data_set, data_file): - avail_stims = si.stimuli_in_session(data_set.get_session_type()) - targeted_structure = data_set.get_metadata()['targeted_structure'] + avail_stims = si.stimuli_in_session(data_set.get_session_type()) + targeted_structure = data_set.get_metadata()["targeted_structure"] stim = None @@ -387,12 +426,12 @@ def lsna_check_hvas(data_set, data_file): stim = si.LOCALLY_SPARSE_NOISE if stim is None: - raise MissingStimulusException("Could not find appropriate LSN stimulus for session %s", - data_set.get_session_type()) + raise MissingStimulusException( + "Could not find appropriate LSN stimulus for session %s", data_set.get_session_type() + ) else: logging.debug("in structure %s, using %s stimulus for plots", targeted_structure, stim) - - + return LocallySparseNoise.from_analysis_file(data_set, data_file, stim) @@ -401,27 +440,23 @@ def build_eye_tracking_plots(data_set, configs, output_dir): pupil_times, xy_deg = data_set.get_pupil_location() xy_deg = xy_deg[np.isfinite(xy_deg).any(axis=1)] if len(xy_deg) == 0: - logging.debug("Eye tracking had no finite data, should have been " - "failed") + logging.debug("Eye tracking had no finite data, should have been failed") return elif len(xy_deg) < 3: - c = np.ones(len(xy_deg)) # not enough points for KDE, should probably be failed + c = np.ones(len(xy_deg)) # not enough points for KDE, should probably be failed else: c = gaussian_kde(xy_deg.T)(xy_deg.T) - for fn in build_plots("eye_tracking_gaze_axes", 1.0, - [configs['large'], configs['svg']], - output_dir): + for fn in build_plots("eye_tracking_gaze_axes", 1.0, [configs["large"], configs["svg"]], output_dir): oplots.plot_pupil_location(xy_deg, c=c, include_labels=True) oplots.finalize_with_axes() - for fn in build_plots("eye_tracking_gaze", 1.0, [configs['small']], - output_dir): + for fn in build_plots("eye_tracking_gaze", 1.0, [configs["small"]], output_dir): oplots.plot_pupil_location(xy_deg, c=c, include_labels=False) oplots.finalize_no_axes() except NoEyeTrackingException: logging.debug("No eye tracking found.") - + def build_type(nwb_file, data_file, configs, output_dir, type_name): data_set = BrainObservatoryNwbDataSet(nwb_file) @@ -468,6 +503,7 @@ def build_type(nwb_file, data_file, configs, output_dir, type_name): logging.critical("error running stimulus (%s)", type_name) raise e + def parse_input(data): nwb_file = data.get("nwb_file", None) @@ -483,18 +519,17 @@ def parse_input(data): if not os.path.exists(analysis_file): raise IOError("analysis file does not exists: %s" % analysis_file) - output_directory = data.get("output_directory", None) if output_directory is None: raise IOError("input JSON missing required field 'output_directory'") Manifest.safe_mkdir(output_directory) - + return nwb_file, analysis_file, output_directory - -def build_experiment_thumbnails(nwb_file, analysis_file, output_directory, - types=None, threads=4): + + +def build_experiment_thumbnails(nwb_file, analysis_file, output_directory, types=None, threads=4): if types is None: types = PLOT_TYPES @@ -517,18 +552,17 @@ def build_experiment_thumbnails(nwb_file, analysis_file, output_directory, p.close() p.join() - def main(): parser = argparse.ArgumentParser() parser.add_argument("-t", "--threads", type=int, default=4) parser.add_argument("--log-level", default=logging.DEBUG) - parser.add_argument("--types", default=','.join(PLOT_TYPES)) + parser.add_argument("--types", default=",".join(PLOT_TYPES)) parser.add_argument("input_json") parser.add_argument("output_json") args = parser.parse_args() - args.types = args.types.split(',') + args.types = args.types.split(",") logging.getLogger().setLevel(args.log_level) @@ -536,10 +570,10 @@ def main(): nwb_file, analysis_file, output_directory = parse_input(input_data) - build_experiment_thumbnails(nwb_file, analysis_file, output_directory, - args.types, args.threads) + build_experiment_thumbnails(nwb_file, analysis_file, output_directory, args.types, args.threads) ju.write(args.output_json, {}) + if __name__ == "__main__": main() diff --git a/allensdk/internal/pipeline_modules/run_ophys_eye_calibration.py b/allensdk/internal/pipeline_modules/run_ophys_eye_calibration.py index abca49b93d..a7a6534c4c 100644 --- a/allensdk/internal/pipeline_modules/run_ophys_eye_calibration.py +++ b/allensdk/internal/pipeline_modules/run_ophys_eye_calibration.py @@ -1,29 +1,32 @@ import os -from allensdk.internal.core.lims_pipeline_module import (PipelineModule, - run_module) -from allensdk.internal.brain_observatory import (eye_calibration, - itracker_utils) +from allensdk.internal.core.lims_pipeline_module import PipelineModule, run_module +from allensdk.internal.brain_observatory import eye_calibration, itracker_utils import allensdk.internal.core.lims_utilities as lu import numpy as np import h5py EYE_RADIUS = 0.1682 -CM_PER_PIXEL = 10.2/10000 +CM_PER_PIXEL = 10.2 / 10000 + def get_wkf(wkf_type, experiment_id): - wkf = lu.query(""" + wkf = lu.query( + """ select CONCAT(wkf.storage_directory, wkf.filename) as path from well_known_files wkf join well_known_file_types wkft on wkft.id = wkf.well_known_file_type_id where wkft.name LIKE '{}' and wkf.attachable_id = {} -""".format(wkf_type, experiment_id))[0]["path"] +""".format(wkf_type, experiment_id) + )[0]["path"] return wkf + def debug(experiment_id, local=False): OUTPUT_DIRECTORY = "/data/informatics/CAM/eye_calibration" SDK_PATH = "/data/informatics/CAM/eye_calibration/allensdk" - SCRIPT = ("/data/informatics/CAM/eye_calibration/allensdk/allensdk" - "/internal/pipeline_modules/run_ophys_eye_calibration.py") + SCRIPT = ( + "/data/informatics/CAM/eye_calibration/allensdk/allensdk/internal/pipeline_modules/run_ophys_eye_calibration.py" + ) frame_width = 640 frame_height = 480 @@ -31,11 +34,13 @@ def debug(experiment_id, local=False): cr_file = get_wkf("EyeTracking Corneal Reflection", experiment_id) pupil_file = get_wkf("EyeTracking Pupil", experiment_id) - exp_info = lu.query(""" + exp_info = lu.query( + """ select * from ophys_sessions os where os.id = {} -""".format(experiment_id))[0] +""".format(experiment_id) + )[0] exp_dir = os.path.join(OUTPUT_DIRECTORY, str(experiment_id)) # clear out missing values to let us get defaults @@ -48,55 +53,38 @@ def debug(experiment_id, local=False): "pupil_params_file": pupil_file, "frame_width": frame_width, "frame_height": frame_height, - "output_file": os.path.join(exp_dir, - "eye_tracking_to_screen_mapping.h5"), - "monitor_position_x_mm": exp_info.get( - "screen_center_x_mm", eye_calibration.MONITOR_POSITION_OLD[0]*10), - "monitor_position_y_mm": exp_info.get( - "screen_center_y_mm", eye_calibration.MONITOR_POSITION_OLD[1]*10), - "monitor_position_z_mm": exp_info.get( - "screen_center_z_mm", eye_calibration.MONITOR_POSITION_OLD[2]*10), + "output_file": os.path.join(exp_dir, "eye_tracking_to_screen_mapping.h5"), + "monitor_position_x_mm": exp_info.get("screen_center_x_mm", eye_calibration.MONITOR_POSITION_OLD[0] * 10), + "monitor_position_y_mm": exp_info.get("screen_center_y_mm", eye_calibration.MONITOR_POSITION_OLD[1] * 10), + "monitor_position_z_mm": exp_info.get("screen_center_z_mm", eye_calibration.MONITOR_POSITION_OLD[2] * 10), "monitor_rotation_x_deg": exp_info.get("screen_rotation_x_deg", 0), "monitor_rotation_y_deg": exp_info.get("screen_rotation_y_deg", 0), "monitor_rotation_z_deg": exp_info.get("screen_rotation_z_deg", 0), - "camera_position_x_mm": exp_info.get( - "camera_center_x_mm", eye_calibration.CAMERA_POSITION_OLD[0]*10), - "camera_position_y_mm": exp_info.get( - "camera_center_y_mm", eye_calibration.CAMERA_POSITION_OLD[1]*10), - "camera_position_z_mm": exp_info.get( - "camera_center_z_mm", eye_calibration.CAMERA_POSITION_OLD[2]*10), + "camera_position_x_mm": exp_info.get("camera_center_x_mm", eye_calibration.CAMERA_POSITION_OLD[0] * 10), + "camera_position_y_mm": exp_info.get("camera_center_y_mm", eye_calibration.CAMERA_POSITION_OLD[1] * 10), + "camera_position_z_mm": exp_info.get("camera_center_z_mm", eye_calibration.CAMERA_POSITION_OLD[2] * 10), "camera_rotation_x_deg": exp_info.get( - "camera_rotation_x_deg", - eye_calibration.CAMERA_ROTATIONS_OLD[0]*180/np.pi), + "camera_rotation_x_deg", eye_calibration.CAMERA_ROTATIONS_OLD[0] * 180 / np.pi + ), "camera_rotation_y_deg": exp_info.get( - "camera_rotation_y_deg", - eye_calibration.CAMERA_ROTATIONS_OLD[1]*180/np.pi), + "camera_rotation_y_deg", eye_calibration.CAMERA_ROTATIONS_OLD[1] * 180 / np.pi + ), "camera_rotation_z_deg": exp_info.get( - "camera_rotation_z_deg", - eye_calibration.CAMERA_ROTATIONS_OLD[2]*180/np.pi), - "led_position_x_mm": exp_info.get( - "led_center_x_mm", eye_calibration.LED_POSITION_ORIGINAL[0]*10), - "led_position_y_mm": exp_info.get( - "led_center_y_mm", eye_calibration.LED_POSITION_ORIGINAL[1]*10), - "led_position_z_mm": exp_info.get( - "led_center_z_mm", eye_calibration.LED_POSITION_ORIGINAL[2]*10) + "camera_rotation_z_deg", eye_calibration.CAMERA_ROTATIONS_OLD[2] * 180 / np.pi + ), + "led_position_x_mm": exp_info.get("led_center_x_mm", eye_calibration.LED_POSITION_ORIGINAL[0] * 10), + "led_position_y_mm": exp_info.get("led_center_y_mm", eye_calibration.LED_POSITION_ORIGINAL[1] * 10), + "led_position_z_mm": exp_info.get("led_center_z_mm", eye_calibration.LED_POSITION_ORIGINAL[2] * 10), } # TEMPORARY HACKS TO DEAL WITH BAD DATA IN LIMS # TODO: REMOVE WHEN DATAFIXES DONE if input_data["monitor_position_x_mm"] == -86.2: - input_data["monitor_position_x_mm"] = \ - eye_calibration.MONITOR_POSITION_NEW[0]*10 - input_data["monitor_position_y_mm"] = \ - eye_calibration.MONITOR_POSITION_NEW[1]*10 - input_data["monitor_position_z_mm"] = \ - eye_calibration.MONITOR_POSITION_NEW[2]*10 + input_data["monitor_position_x_mm"] = eye_calibration.MONITOR_POSITION_NEW[0] * 10 + input_data["monitor_position_y_mm"] = eye_calibration.MONITOR_POSITION_NEW[1] * 10 + input_data["monitor_position_z_mm"] = eye_calibration.MONITOR_POSITION_NEW[2] * 10 - run_module(SCRIPT, - input_data, - exp_dir, - sdk_path=SDK_PATH, - local=local) + run_module(SCRIPT, input_data, exp_dir, sdk_path=SDK_PATH, local=local) def parse_input_data(data): @@ -105,31 +93,41 @@ def parse_input_data(data): frame_width = data["frame_width"] frame_height = data["frame_height"] output_file = data["output_file"] - monitor_position = np.array([ - float(data['monitor_position_x_mm'])/10.0, - float(data['monitor_position_y_mm'])/10.0, - float(data['monitor_position_z_mm'])/10.0, - ]) - monitor_rotations = np.array([ - float(data['monitor_rotation_x_deg'])*np.pi/180, - float(data['monitor_rotation_y_deg'])*np.pi/180, - float(data['monitor_rotation_z_deg'])*np.pi/180, - ]) - camera_position = np.array([ - float(data['camera_position_x_mm'])/10.0, - float(data['camera_position_y_mm'])/10.0, - float(data['camera_position_z_mm'])/10.0, - ]) - camera_rotations = np.array([ - float(data['camera_rotation_x_deg'])*np.pi/180, - float(data['camera_rotation_y_deg'])*np.pi/180, - float(data['camera_rotation_z_deg'])*np.pi/180, - ]) - led_position = np.array([ - float(data['led_position_x_mm'])/10.0, - float(data['led_position_y_mm'])/10.0, - float(data['led_position_z_mm'])/10.0, - ]) + monitor_position = np.array( + [ + float(data["monitor_position_x_mm"]) / 10.0, + float(data["monitor_position_y_mm"]) / 10.0, + float(data["monitor_position_z_mm"]) / 10.0, + ] + ) + monitor_rotations = np.array( + [ + float(data["monitor_rotation_x_deg"]) * np.pi / 180, + float(data["monitor_rotation_y_deg"]) * np.pi / 180, + float(data["monitor_rotation_z_deg"]) * np.pi / 180, + ] + ) + camera_position = np.array( + [ + float(data["camera_position_x_mm"]) / 10.0, + float(data["camera_position_y_mm"]) / 10.0, + float(data["camera_position_z_mm"]) / 10.0, + ] + ) + camera_rotations = np.array( + [ + float(data["camera_rotation_x_deg"]) * np.pi / 180, + float(data["camera_rotation_y_deg"]) * np.pi / 180, + float(data["camera_rotation_z_deg"]) * np.pi / 180, + ] + ) + led_position = np.array( + [ + float(data["led_position_x_mm"]) / 10.0, + float(data["led_position_y_mm"]) / 10.0, + float(data["led_position_z_mm"]) / 10.0, + ] + ) calibrator = eye_calibration.EyeCalibration( monitor_position=monitor_position, monitor_rotations=monitor_rotations, @@ -137,22 +135,20 @@ def parse_input_data(data): camera_position=camera_position, camera_rotations=camera_rotations, eye_radius=EYE_RADIUS, - cm_per_pixel=CM_PER_PIXEL) + cm_per_pixel=CM_PER_PIXEL, + ) cr_params, _ = itracker_utils.post_process_cr(cr_params) pupil_params = itracker_utils.post_process_pupil(pupil_params) - cr_params = itracker_utils.filter_bad_params(cr_params, frame_width, - frame_height) - pupil_params = itracker_utils.filter_bad_params(pupil_params, frame_width, - frame_height) + cr_params = itracker_utils.filter_bad_params(cr_params, frame_width, frame_height) + pupil_params = itracker_utils.filter_bad_params(pupil_params, frame_width, frame_height) return calibrator, cr_params, pupil_params, output_file def write_output(filename, position_degrees, position_cm, areas): with h5py.File(filename, "w") as f: f.create_dataset("screen_coordinates", data=position_cm) - f.create_dataset("screen_coordinates_spherical", - data=position_degrees) + f.create_dataset("screen_coordinates_spherical", data=position_degrees) f.create_dataset("pupil_areas", data=areas) @@ -162,18 +158,16 @@ def main(): calibrator, cr_params, pupil_params, outfile = parse_input_data(data) pupil_areas = calibrator.compute_area(pupil_params) - pupil_on_monitor_deg = calibrator.pupil_position_on_monitor_in_degrees( - pupil_params, cr_params) - pupil_on_monitor_cm = calibrator.pupil_position_on_monitor_in_cm( - pupil_params, cr_params) + pupil_on_monitor_deg = calibrator.pupil_position_on_monitor_in_degrees(pupil_params, cr_params) + pupil_on_monitor_cm = calibrator.pupil_position_on_monitor_in_cm(pupil_params, cr_params) missing_index = np.isnan(pupil_areas) | np.isnan(pupil_on_monitor_deg.T[0]) pupil_areas[missing_index] = np.nan - pupil_on_monitor_deg[missing_index,:] = np.nan - pupil_on_monitor_cm[missing_index,:] = np.nan - write_output(outfile, pupil_on_monitor_deg, pupil_on_monitor_cm, - pupil_areas) + pupil_on_monitor_deg[missing_index, :] = np.nan + pupil_on_monitor_cm[missing_index, :] = np.nan + write_output(outfile, pupil_on_monitor_deg, pupil_on_monitor_cm, pupil_areas) mod.write_output_data({"screen_mapping_file": outfile}) + if __name__ == "__main__": main() diff --git a/allensdk/internal/pipeline_modules/run_ophys_session_decomposition.py b/allensdk/internal/pipeline_modules/run_ophys_session_decomposition.py index 8be31ef325..ed8dec9e89 100644 --- a/allensdk/internal/pipeline_modules/run_ophys_session_decomposition.py +++ b/allensdk/internal/pipeline_modules/run_ophys_session_decomposition.py @@ -1,5 +1,4 @@ -from allensdk.internal.core.lims_pipeline_module import (PipelineModule, - run_module) +from allensdk.internal.core.lims_pipeline_module import PipelineModule, run_module from allensdk.internal.brain_observatory import ophys_session_decomposition as osd from multiprocessing import Pool import os @@ -10,62 +9,69 @@ DEBUG_ITEMSIZE = 2 DEBUG_N_PLANES = 6 -def create_fake_metadata(exp_dir, raw_path, channels=None, - width=DEBUG_WIDTH, height=DEBUG_HEIGHT, - itemsize=DEBUG_ITEMSIZE, n_planes=DEBUG_N_PLANES): + +def create_fake_metadata( + exp_dir, + raw_path, + channels=None, + width=DEBUG_WIDTH, + height=DEBUG_HEIGHT, + itemsize=DEBUG_ITEMSIZE, + n_planes=DEBUG_N_PLANES, +): metadata = [] size = os.stat(raw_path).st_size if channels is None: channels = DEBUG_CHANNELS - n_frames = size/(itemsize*width*height) - frames_per_plane = n_frames/n_planes/len(channels) + n_frames = size / (itemsize * width * height) + frames_per_plane = n_frames / n_planes / len(channels) for plane in range(n_planes): experiment_id = plane outfile = os.path.join(exp_dir, "plane_{}.h5".format(plane)) frame_meta = [] for i, channel in enumerate(channels): - byte_offset = width * height * itemsize * \ - (plane * len(channels) + i) - strides = [width*height*itemsize*n_planes*len(channels), - width*itemsize, - itemsize] - frame_meta.append({"byte_offset": byte_offset, - "channel": i+1, - "channel_description": channel, - "frame_description": "plane_{}".format(plane), - "dtype": ">u{}".format(itemsize), - "position_offset": [None, 0, 0], - "shape": [frames_per_plane, height, width], - "strides": strides}) - metadata.append({"output_file": outfile, - "experiment_id": experiment_id, - "frame_metadata": frame_meta}) + byte_offset = width * height * itemsize * (plane * len(channels) + i) + strides = [width * height * itemsize * n_planes * len(channels), width * itemsize, itemsize] + frame_meta.append( + { + "byte_offset": byte_offset, + "channel": i + 1, + "channel_description": channel, + "frame_description": "plane_{}".format(plane), + "dtype": ">u{}".format(itemsize), + "position_offset": [None, 0, 0], + "shape": [frames_per_plane, height, width], + "strides": strides, + } + ) + metadata.append({"output_file": outfile, "experiment_id": experiment_id, "frame_metadata": frame_meta}) return metadata def debug(experiment_id, local=False, raw_path=None): OUTPUT_DIRECTORY = "/data/informatics/CAM/ophys_decomp" SDK_PATH = "/data/informatics/CAM/ophys_decomp/allensdk" - SCRIPT = ("/data/informatics/CAM/ophys_decomp/allensdk/allensdk/" - "internal/pipeline_modules/run_ophys_session_decomposition.py") + SCRIPT = ( + "/data/informatics/CAM/ophys_decomp/allensdk/allensdk/" + "internal/pipeline_modules/run_ophys_session_decomposition.py" + ) exp_dir = os.path.join(OUTPUT_DIRECTORY, str(experiment_id)) if raw_path is not None: conversion_definitions = create_fake_metadata(exp_dir, raw_path) - input_data = {"raw_filename": raw_path, - "frame_metadata": conversion_definitions} + input_data = {"raw_filename": raw_path, "frame_metadata": conversion_definitions} else: raise NotImplementedError("No real examples exist yet") - run_module(SCRIPT, - input_data, - exp_dir, - sdk_path=SDK_PATH, - pbs=dict(vmem=160, - job_name="ophys_decomp_%d"% experiment_id, - walltime="36:00:00"), - local=local) + run_module( + SCRIPT, + input_data, + exp_dir, + sdk_path=SDK_PATH, + pbs=dict(vmem=160, job_name="ophys_decomp_%d" % experiment_id, walltime="36:00:00"), + local=local, + ) def convert_frame(conversion_definition): @@ -74,13 +80,12 @@ def convert_frame(conversion_definition): auxiliary_hdf5_filename = conversion_definition["auxiliary_output_file"] experiment_id = conversion_definition["experiment_id"] frame_metadata = conversion_definition["frame_metadata"] - osd.export_frame_to_hdf5(raw_filename, ophys_hdf5_filename, - auxiliary_hdf5_filename, frame_metadata) + osd.export_frame_to_hdf5(raw_filename, ophys_hdf5_filename, auxiliary_hdf5_filename, frame_metadata) return experiment_id, ophys_hdf5_filename, auxiliary_hdf5_filename def parse_input(data): - '''Load all input data from the input json.''' + """Load all input data from the input json.""" conversion_definitions = data["frame_metadata"] for item in conversion_definitions: item["input_file"] = data["raw_filename"] @@ -98,16 +103,16 @@ def main(): pool = Pool(processes=mod.args.threads) output = pool.map(convert_frame, conversion_definitions) else: - output= [] + output = [] for definition in conversion_definitions: output.append(convert_frame(definition)) output_data = {} for eid, ophys_file, auxiliary_file in output: - output_data[eid] = {"ophys_data": ophys_file, - "auxiliary_data": auxiliary_file} + output_data[eid] = {"ophys_data": ophys_file, "auxiliary_data": auxiliary_file} mod.write_output_data(output_data) + if __name__ == "__main__": main() diff --git a/allensdk/internal/pipeline_modules/run_ophys_time_sync.py b/allensdk/internal/pipeline_modules/run_ophys_time_sync.py index 8162580a0f..b8b8fe3aa9 100644 --- a/allensdk/internal/pipeline_modules/run_ophys_time_sync.py +++ b/allensdk/internal/pipeline_modules/run_ophys_time_sync.py @@ -11,13 +11,11 @@ import allensdk from allensdk.internal.core.lims_pipeline_module import PipelineModule from allensdk.internal.brain_observatory import time_sync as ts -from allensdk.brain_observatory.argschema_utilities import \ - check_write_access_overwrite +from allensdk.brain_observatory.argschema_utilities import check_write_access_overwrite class TimeSyncOutputs(NamedTuple): - """ Schema for synchronization outputs - """ + """Schema for synchronization outputs""" # unique identifier for the experiment being aligned experiment_id: int @@ -25,7 +23,7 @@ class TimeSyncOutputs(NamedTuple): # calculated monitor delay (s) stimulus_delay: float - # For each data stream, the count of "extra" timestamps (compared to the + # For each data stream, the count of "extra" timestamps (compared to the # number of samples) ophys_delta: int stimulus_delta: int @@ -33,12 +31,12 @@ class TimeSyncOutputs(NamedTuple): behavior_delta: int # aligned timestamps for each data stream (s) - ophys_times: np.ndarray + ophys_times: np.ndarray stimulus_times: np.ndarray eye_times: np.ndarray behavior_times: np.ndarray - # for non-ophys data streams, a mapping from samples to corresponding ophys + # for non-ophys data streams, a mapping from samples to corresponding ophys # frames stimulus_alignment: np.ndarray eye_alignment: np.ndarray @@ -46,21 +44,16 @@ class TimeSyncOutputs(NamedTuple): class TimeSyncWriter: - - def __init__( - self, - output_h5_path: str, - output_json_path: Optional[str] = None - ): - """ Writes synchronization outputs to h5 and (optionally) json. + def __init__(self, output_h5_path: str, output_json_path: Optional[str] = None): + """Writes synchronization outputs to h5 and (optionally) json. Parameters ---------- - output_h5_path : "heavy" outputs (e.g aligned timestamps and - ophy frame correspondances) will ONLY be stored here. Lightweight + output_h5_path : "heavy" outputs (e.g aligned timestamps and + ophy frame correspondances) will ONLY be stored here. Lightweight outputs (e.g. stimulus delay) will also be written here as scalars. - output_json_path : if provided, lightweight outputs will be written - here, along with provenance information, such as the date and + output_json_path : if provided, lightweight outputs will be written + here, along with provenance information, such as the date and allensdk version. """ @@ -69,8 +62,8 @@ def __init__( self.output_json_path: Optional[str] = output_json_path def validate_paths(self): - """ Determines whether we can actually write to the specified paths, - allowing for creation of intermediate directories. It is a good idea + """Determines whether we can actually write to the specified paths, + allowing for creation of intermediate directories. It is a good idea to run this beore doing any heavy calculations! """ @@ -80,7 +73,7 @@ def validate_paths(self): check_write_access_overwrite(self.output_json_path) def write(self, outputs: TimeSyncOutputs): - """ Convenience for writing both an output h5 and (if applicable) an + """Convenience for writing both an output h5 and (if applicable) an output json. Parameters @@ -95,7 +88,7 @@ def write(self, outputs: TimeSyncOutputs): self.write_output_json(outputs) def write_output_h5(self, outputs): - """ Write (mainly) heaviweight data to an h5 file. + """Write (mainly) heaviweight data to an h5 file. Parameters ---------- @@ -117,7 +110,7 @@ def write_output_h5(self, outputs): output_h5["behavior_delta"] = outputs.behavior_delta def write_output_json(self, outputs): - """ Write lightweight data to a json + """Write lightweight data to a json Parameters ---------- @@ -127,21 +120,25 @@ def write_output_json(self, outputs): os.makedirs(os.path.dirname(self.output_json_path), exist_ok=True) with open(self.output_json_path, "w") as output_json: - json.dump({ - "allensdk_version": allensdk.__version__, - "date": str(datetime.datetime.now()), - "experiment_id": outputs.experiment_id, - "output_h5_path": self.output_h5_path, - "ophys_delta": outputs.ophys_delta, - "stim_delta": outputs.stimulus_delta, - "stim_delay": outputs.stimulus_delay, - "eye_delta": outputs.eye_delta, - "behavior_delta": outputs.behavior_delta - }, output_json, indent=2) + json.dump( + { + "allensdk_version": allensdk.__version__, + "date": str(datetime.datetime.now()), + "experiment_id": outputs.experiment_id, + "output_h5_path": self.output_h5_path, + "ophys_delta": outputs.ophys_delta, + "stim_delta": outputs.stimulus_delta, + "stim_delay": outputs.stimulus_delay, + "eye_delta": outputs.eye_delta, + "behavior_delta": outputs.behavior_delta, + }, + output_json, + indent=2, + ) def check_stimulus_delay(obt_delay: float, min_delay: float, max_delay: float): - """ Raise an exception if the monitor delay is not within specified bounds + """Raise an exception if the monitor delay is not within specified bounds Parameters ---------- @@ -153,34 +150,29 @@ def check_stimulus_delay(obt_delay: float, min_delay: float, max_delay: float): if obt_delay < min_delay or obt_delay > max_delay: raise ValueError( - f"calculated monitor delay was {obt_delay:.3f}s " - f"(acceptable interval: [{min_delay:.3f}s, " - f"{max_delay:.3f}s])" + f"calculated monitor delay was {obt_delay:.3f}s (acceptable interval: [{min_delay:.3f}s, {max_delay:.3f}s])" ) def run_ophys_time_sync( - aligner: ts.OphysTimeAligner, - experiment_id: int, - min_stimulus_delay: float, - max_stimulus_delay: float + aligner: ts.OphysTimeAligner, experiment_id: int, min_stimulus_delay: float, max_stimulus_delay: float ) -> TimeSyncOutputs: - """ Carry out synchronization of timestamps across the data streams of an + """Carry out synchronization of timestamps across the data streams of an ophys experiment. Parameters ---------- - aligner : drives alignment. See OphysTimeAligner for details of the + aligner : drives alignment. See OphysTimeAligner for details of the attributes and properties that must be implemented. experiment_id : unique identifier for the experiment being aligned - min_stimulus_delay : reject alignment run (raise a ValueError) if the + min_stimulus_delay : reject alignment run (raise a ValueError) if the calculated monitor delay is below this value (s). - max_stimulus_delay : reject alignment run (raise a ValueError) if the + max_stimulus_delay : reject alignment run (raise a ValueError) if the calculated monitor delay is above this value (s). Returns ------- - A TimeSyncOutputs (see definintion for more information) of output + A TimeSyncOutputs (see definintion for more information) of output parameters and arrays of aligned timestamps. """ @@ -199,11 +191,9 @@ def run_ophys_time_sync( # camera arrays are index of camera frame for each ophys frame ... # cam_nwb_creator depends on this so keeping it that way even though # it makes little sense... len(video_times) - eye_alignment = ts.get_alignment_array(eye_times, ophys_times, - int_method=np.ceil) + eye_alignment = ts.get_alignment_array(eye_times, ophys_times, int_method=np.ceil) - behavior_alignment = ts.get_alignment_array(beh_times, ophys_times, - int_method=np.ceil) + behavior_alignment = ts.get_alignment_array(beh_times, ophys_times, int_method=np.ceil) return TimeSyncOutputs( experiment_id, @@ -218,24 +208,23 @@ def run_ophys_time_sync( beh_times, stim_alignment, eye_alignment, - behavior_alignment + behavior_alignment, ) def main(): parser = argparse.ArgumentParser("Generate brain observatory alignment.") - parser.add_argument("input_json", type=str, - help="path to input json" - ) - parser.add_argument("output_json", type=str, nargs="?", - help="path to which output json will be written" - ) + parser.add_argument("input_json", type=str, help="path to input json") + parser.add_argument("output_json", type=str, nargs="?", help="path to which output json will be written") parser.add_argument("--log-level", default=logging.DEBUG) - parser.add_argument("--min-stimulus-delay", type=float, default=0.0, - help="reject results if monitor delay less than this value (s)" + parser.add_argument( + "--min-stimulus-delay", type=float, default=0.0, help="reject results if monitor delay less than this value (s)" ) - parser.add_argument("--max-stimulus-delay", type=float, default=0.07, - help="reject results if monitor delay greater than this value (s)" + parser.add_argument( + "--max-stimulus-delay", + type=float, + default=0.07, + help="reject results if monitor delay greater than this value (s)", ) mod = PipelineModule("Generate brain observatory alignment.", parser) @@ -245,25 +234,20 @@ def main(): writer.validate_paths() aligner = ts.OphysTimeAligner( - input_data.get("sync_file"), + input_data.get("sync_file"), scanner=input_data.get("scanner", None), dff_file=input_data.get("dff_file", None), stimulus_pkl=input_data.get("stimulus_pkl", None), eye_video=input_data.get("eye_video", None), behavior_video=input_data.get("behavior_video", None), - long_stim_threshold=input_data.get( - "long_stim_threshold", ts.LONG_STIM_THRESHOLD - ) + long_stim_threshold=input_data.get("long_stim_threshold", ts.LONG_STIM_THRESHOLD), ) outputs = run_ophys_time_sync( - aligner, - input_data.get("ophys_experiment_id"), - mod.args.min_stimulus_delay, - mod.args.max_stimulus_delay + aligner, input_data.get("ophys_experiment_id"), mod.args.min_stimulus_delay, mod.args.max_stimulus_delay ) writer.write(outputs) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/allensdk/internal/pipeline_modules/run_roi_filter.py b/allensdk/internal/pipeline_modules/run_roi_filter.py index e83f2f38af..2b5ded149d 100644 --- a/allensdk/internal/pipeline_modules/run_roi_filter.py +++ b/allensdk/internal/pipeline_modules/run_roi_filter.py @@ -1,10 +1,8 @@ import logging import allensdk.internal.core.lims_utilities as lu -from allensdk.internal.core.lims_pipeline_module import ( - PipelineModule, run_module) +from allensdk.internal.core.lims_pipeline_module import PipelineModule, run_module from allensdk.internal.brain_observatory import roi_filter, roi_filter_utils -from allensdk.brain_observatory.roi_masks import (RIGHT_SHIFT, LEFT_SHIFT, - DOWN_SHIFT, UP_SHIFT) +from allensdk.brain_observatory.roi_masks import RIGHT_SHIFT, LEFT_SHIFT, DOWN_SHIFT, UP_SHIFT import pandas as pd import os import h5py @@ -13,34 +11,38 @@ MAX_SHIFT = 30 OVERLAP_THRESHOLD = 0.9 DEBUG_SDK_PATH = "/data/informatics/CAM/roi_filter/allensdk/" -DEBUG_SCRIPT = os.path.join(DEBUG_SDK_PATH, "allensdk", "internal", - "pipeline_modules", "run_roi_filter.py") +DEBUG_SCRIPT = os.path.join(DEBUG_SDK_PATH, "allensdk", "internal", "pipeline_modules", "run_roi_filter.py") DEBUG_OUTPUT_DIRECTORY = "/data/informatics/CAM/roi_filter/" def get_motion_filepath(experiment_id): - return lu.query(""" + return lu.query( + """ select CONCAT(wkf.storage_directory, wkf.filename) as path from well_known_files wkf join well_known_file_types wkft on wkft.id = wkf.well_known_file_type_id join ophys_experiments oe on oe.id = wkf.attachable_id where oe.id = {} and -wkft.name like 'OphysMotionXyOffsetData'""".format(experiment_id))[0]["path"] +wkft.name like 'OphysMotionXyOffsetData'""".format(experiment_id) + )[0]["path"] def get_segmentation_filepath(experiment_id, file_type): - return lu.query(""" + return lu.query( + """ select CONCAT(wkf.storage_directory, wkf.filename) as path from well_known_files wkf join well_known_file_types wkft on wkft.id = wkf.well_known_file_type_id join ophys_cell_segmentation_runs ocsr on ocsr.id = wkf.attachable_id join ophys_experiments oe on oe.id = ocsr.ophys_experiment_id where oe.id = {} and wkft.name like '{}' -and ocsr.current = 't'""".format(experiment_id, file_type))[0]["path"] +and ocsr.current = 't'""".format(experiment_id, file_type) + )[0]["path"] def get_model_info(experiment_id): - res = lu.query(""" + res = lu.query( + """ select CONCAT(wkf.storage_directory, wkf.filename) as path, wkf.id from ophys_experiments oe join ophys_sessions os on os.id = oe.ophys_session_id @@ -48,12 +50,14 @@ def get_model_info(experiment_id): join well_known_files wkf on wkf.attachable_id = p.id join well_known_file_types wkft on wkft.id = wkf.well_known_file_type_id where oe.id = {} and -wkft.name = 'RoiLabelModel'""".format(experiment_id))[0] +wkft.name = 'RoiLabelModel'""".format(experiment_id) + )[0] return res["path"], res["id"] def get_genotype_info(experiment_id, code): - res = lu.query(""" + res = lu.query( + """ select g.name from ophys_experiments oe join ophys_sessions os on os.id = oe.ophys_session_id @@ -62,7 +66,8 @@ def get_genotype_info(experiment_id, code): join donors_genotypes dg on dg.donor_id = d.id join genotypes g on g.id = dg.genotype_id join genotype_types gt on gt.id = g.genotype_type_id -where oe.id = {} and gt.code like '{}'""".format(experiment_id, code)) +where oe.id = {} and gt.code like '{}'""".format(experiment_id, code) + ) output = set() for line in res: output.add(line["name"]) @@ -75,62 +80,60 @@ def create_input_data(experiment_id): model, model_id = get_model_info(experiment_id) data["roi_label_model"] = model data["roi_label_model_id"] = model_id - data["targeted_structure_id"] = lu.query(""" + data["targeted_structure_id"] = lu.query( + """ select targeted_structure_id from ophys_experiments oe -where oe.id = {}""".format(experiment_id))[0]["targeted_structure_id"] - data["imaging_depth"] = lu.query(""" +where oe.id = {}""".format(experiment_id) + )[0]["targeted_structure_id"] + data["imaging_depth"] = lu.query( + """ select calculated_depth from ophys_experiments oe -where oe.id = {}""".format(experiment_id))[0]["calculated_depth"] +where oe.id = {}""".format(experiment_id) + )[0]["calculated_depth"] data["drivers"] = get_genotype_info(experiment_id, "D") data["reporters"] = get_genotype_info(experiment_id, "R") - data["max_int_file"] = get_segmentation_filepath( - experiment_id, "OphysSegmentationMaskData") - data["object_list"] = get_segmentation_filepath( - experiment_id, "OphysSegmentationObjects") + data["max_int_file"] = get_segmentation_filepath(experiment_id, "OphysSegmentationMaskData") + data["object_list"] = get_segmentation_filepath(experiment_id, "OphysSegmentationObjects") return data -def debug(experiment_id, local=False, sdk_path=DEBUG_SDK_PATH, - script=DEBUG_SCRIPT, output_directory=DEBUG_OUTPUT_DIRECTORY): +def debug( + experiment_id, local=False, sdk_path=DEBUG_SDK_PATH, script=DEBUG_SCRIPT, output_directory=DEBUG_OUTPUT_DIRECTORY +): input_data = create_input_data(experiment_id) exp_dir = os.path.join(output_directory, str(experiment_id)) - run_module(script, - input_data, - exp_dir, - sdk_path=sdk_path, - local=local) + run_module(script, input_data, exp_dir, sdk_path=sdk_path, local=local) def load_object_list(filename): - '''Load the object list file.''' + """Load the object list file.""" dataframe = pd.read_csv(filename) dataframe.columns = [column.strip() for column in dataframe.columns] return dataframe def is_deprecated_motion_file(filename): - '''Check if a file is an old style motion correction file. + """Check if a file is an old style motion correction file. By agreement, new-style files will always have a header and that header will always contain at least 1 alpha character. - ''' + """ with open(filename, "r") as f: return not any([c.isalpha() for c in f.readline()]) def load_rigid_motion_transform(filename): - '''Load the rigid motion transform file.''' + """Load the rigid motion transform file.""" if is_deprecated_motion_file(filename): - return pd.read_csv(filename, header=None, - names=DEPRECATED_MOTION_HEADER) + return pd.read_csv(filename, header=None, names=DEPRECATED_MOTION_HEADER) else: return pd.read_csv(filename) def load_all_input(data): - '''Load all input data from the input json.''' + """Load all input data from the input json.""" try: object_list_file = data["object_list"] object_data = load_object_list(object_list_file) @@ -150,8 +153,7 @@ def load_all_input(data): logging.error("Input json missing log_0") raise except IOError: - logging.error("Could not read rigid motion transform file %s", - rigid_motion_transform_file) + logging.error("Could not read rigid motion transform file %s", rigid_motion_transform_file) raise try: @@ -214,30 +216,32 @@ def load_all_input(data): raise ValueError(f"no ROIs were found from {maxint_file}") rois = roi_filter_utils.order_rois_by_object_list(object_data, rois) - result = {"model_id": model_id, - "classifier": classifier, - "object_data": object_data, - "depth": depth, - "structure_id": structure_id, - "drivers": drivers, - "reporters": reporters, - "border": border, - "rois": rois} + result = { + "model_id": model_id, + "classifier": classifier, + "object_data": object_data, + "depth": depth, + "structure_id": structure_id, + "drivers": drivers, + "reporters": reporters, + "border": border, + "rois": rois, + } return result -def create_output_data(rois, model_id, border, excluded, - unexpected_features): +def create_output_data(rois, model_id, border, excluded, unexpected_features): data = {} - data["motion_border"] = {"x0": border[RIGHT_SHIFT], - "y0": border[DOWN_SHIFT], - "x1": border[LEFT_SHIFT], - "y1": border[UP_SHIFT]} + data["motion_border"] = { + "x0": border[RIGHT_SHIFT], + "y0": border[DOWN_SHIFT], + "x1": border[LEFT_SHIFT], + "y1": border[UP_SHIFT], + } data["roi_label_model_id"] = model_id data["unexpected_features"] = unexpected_features if rois: - data["image"] = {"width": rois[0].img_cols, - "height": rois[0].img_rows} + data["image"] = {"width": rois[0].img_cols, "height": rois[0].img_rows} json_rois = {} for i, roi in enumerate(rois): json_roi = {} @@ -274,15 +278,12 @@ def main(): border = data["border"] rois = data["rois"] - label_array = classifier.get_labels(object_data, depth, structure_id, - drivers, reporters) + label_array = classifier.get_labels(object_data, depth, structure_id, drivers, reporters) rois = roi_filter.apply_labels(rois, label_array, classifier.label_names) rois = roi_filter.label_unions_and_duplicates(rois, OVERLAP_THRESHOLD) - output_data = create_output_data(rois, model_id, border, - object_data["eXcluded"], - classifier.unexpected_features) + output_data = create_output_data(rois, model_id, border, object_data["eXcluded"], classifier.unexpected_features) mod.write_output_data(output_data) diff --git a/allensdk/internal/pipeline_modules/run_tissuecyte_projection_thumbnail_from_json.py b/allensdk/internal/pipeline_modules/run_tissuecyte_projection_thumbnail_from_json.py index f15d99443c..a29d2215e8 100644 --- a/allensdk/internal/pipeline_modules/run_tissuecyte_projection_thumbnail_from_json.py +++ b/allensdk/internal/pipeline_modules/run_tissuecyte_projection_thumbnail_from_json.py @@ -21,7 +21,7 @@ def write_depth_image(image, path): image = sitk.GetImageFromArray(image) sitk.WriteImage(image, str(path)) - + def load_background_image(path): background = sitk.ReadImage(str(path)) @@ -45,51 +45,50 @@ def pad(volume): def main(): - module = PipelineModule() input_data = module.input_data() output_dir = os.path.dirname(module.args.output_json) - logging.info('reading data volume from {0}'.format(input_data['volume_path'])) - volume = sitk.ReadImage(str(input_data['volume_path'])) + logging.info("reading data volume from {0}".format(input_data["volume_path"])) + volume = sitk.ReadImage(str(input_data["volume_path"])) volume = sitk.PermuteAxes(volume, PERMUTATION) volume = sitk.Flip(volume, FLIP) - logging.info('reading colormap from {0}'.format(input_data['colormap_path'])) - colormap = pd.read_csv(input_data['colormap_path'], header=None, - names=['red', 'green', 'blue'], delim_whitespace=True) - colormap = convert_discrete_colormap(colormap.values, 'projection') - - output_data = {'output_file_paths': []} - for rot in input_data['rotations']: - - rot['write_depth_sheet'] = functools.partial(write_depth_image, - path=str(os.path.join(output_dir, rot['depth_path']))) - output_data['output_file_paths'].append(os.path.join(output_dir, rot['depth_path'])) - - if isinstance(rot['window_size'], str): - if rot['window_size'] == 'no_pad': - rot['window_size'] = no_pad(volume) - elif rot['window_size'] == 'pad': - rot['window_size'] = pad(volume) + logging.info("reading colormap from {0}".format(input_data["colormap_path"])) + colormap = pd.read_csv( + input_data["colormap_path"], header=None, names=["red", "green", "blue"], delim_whitespace=True + ) + colormap = convert_discrete_colormap(colormap.values, "projection") + + output_data = {"output_file_paths": []} + for rot in input_data["rotations"]: + rot["write_depth_sheet"] = functools.partial( + write_depth_image, path=str(os.path.join(output_dir, rot["depth_path"])) + ) + output_data["output_file_paths"].append(os.path.join(output_dir, rot["depth_path"])) + + if isinstance(rot["window_size"], str): + if rot["window_size"] == "no_pad": + rot["window_size"] = no_pad(volume) + elif rot["window_size"] == "pad": + rot["window_size"] = pad(volume) else: - raise ValueError('did not understand window size option {0}'.format(rot['window_size'])) - logging.info('window_size: {0}'.format(rot['window_size'])) - - for out_image in rot['output_images']: - out_image['write'] = functools.partial(imsave, os.path.join(output_dir, out_image['path'])) - output_data['output_file_paths'].append(os.path.join(output_dir, out_image['path'])) - - if 'background_path' in out_image: - out_image['background'] = load_background_image(out_image['background_path']) + raise ValueError("did not understand window size option {0}".format(rot["window_size"])) + logging.info("window_size: {0}".format(rot["window_size"])) + + for out_image in rot["output_images"]: + out_image["write"] = functools.partial(imsave, os.path.join(output_dir, out_image["path"])) + output_data["output_file_paths"].append(os.path.join(output_dir, out_image["path"])) + + if "background_path" in out_image: + out_image["background"] = load_background_image(out_image["background_path"]) else: - out_image['background'] = None + out_image["background"] = None - run(volume, input_data['min_threshold'], input_data['max_threshold'], - input_data['rotations'], colormap) + run(volume, input_data["min_threshold"], input_data["max_threshold"], input_data["rotations"], colormap) module.write_output_data(output_data) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/allensdk/internal/pipeline_modules/run_tissuecyte_stitching_classic.py b/allensdk/internal/pipeline_modules/run_tissuecyte_stitching_classic.py index 21d885ac4b..6a67279a29 100644 --- a/allensdk/internal/pipeline_modules/run_tissuecyte_stitching_classic.py +++ b/allensdk/internal/pipeline_modules/run_tissuecyte_stitching_classic.py @@ -11,9 +11,9 @@ from allensdk.internal.mouse_connectivity.tissuecyte_stitching.tile import Tile import allensdk.core.json_utilities as ju -# TODO this ought to be installed with the actual python build? +# TODO this ought to be installed with the actual python build? # need to consult with sysadmins/refactor jp2 project build -sys.path.append('/shared/bioapps/itk/itk_shared/jp2/build') +sys.path.append("/shared/bioapps/itk/itk_shared/jp2/build") import jpeg_twok @@ -22,25 +22,23 @@ def get_missing_tile_paths(missing_tiles): - paths = [] for index, path in missing_tiles.items(): - spath = ','.join(map(str, path)) - logging.info('writing missing tile path for tile {0} as {1}'.format(index, spath)) + spath = ",".join(map(str, path)) + logging.info("writing missing tile path for tile {0} as {1}".format(index, spath)) paths.append(spath) return paths def read_image(file_name): - logging.info('reading image from {0}'.format(file_name)) + logging.info("reading image from {0}".format(file_name)) image = sitk.ReadImage(str(file_name)) return np.flipud(sitk.GetArrayFromImage(image)).T def normalize_image_by_median(image): - median = np.median(image) if median != 0: @@ -57,34 +55,32 @@ def load_average_tile(path): def get_average_tiles(average_tile_paths): - - average_tiles = {} + average_tiles = {} for key, path in average_tile_paths.items(): key = int(key) - 1 try: average_tiles[key] = load_average_tile(path) - logging.info('found average tile for channel {0} (zero-indexed)'.format(key)) - except(IOError, OSError, RuntimeError): + logging.info("found average tile for channel {0} (zero-indexed)".format(key)) + except (IOError, OSError, RuntimeError): average_tiles[key] = None - logging.info('did not find average tile for channel {0} (zero-indexed)'.format(key)) - + logging.info("did not find average tile for channel {0} (zero-indexed)".format(key)) + return average_tiles def generate_tiles(tiles): - for tile_params in tiles: tile = tile_params.copy() try: - tile['image'] = read_image(tile['path']) - tile['is_missing'] = False + tile["image"] = read_image(tile["path"]) + tile["is_missing"] = False except (IOError, OSError, RuntimeError): - tile['image'] = None - tile['is_missing'] = True - - tile['channel'] = tile['channel'] - 1 + tile["image"] = None + tile["is_missing"] = True + + tile["channel"] = tile["channel"] - 1 tile_obj = Tile(**tile) del tile @@ -96,34 +92,31 @@ def write_output(arr, spacing, path): def main(): - output_json = args.output_json output_directory = os.path.dirname(output_json) - slice_path = os.path.join(output_directory, data['slice_fname']) - - tiles = generate_tiles(data['tiles']) - average_tiles = get_average_tiles(data['average_tile_paths']) + slice_path = os.path.join(output_directory, data["slice_fname"]) + + tiles = generate_tiles(data["tiles"]) + average_tiles = get_average_tiles(data["average_tile_paths"]) - stitcher = Stitcher(data['image_dimensions'], tiles, average_tiles, data['channels']) + stitcher = Stitcher(data["image_dimensions"], tiles, average_tiles, data["channels"]) image, missing = stitcher.run() del tiles missing_tile_paths = get_missing_tile_paths(missing) - write_output(np.ascontiguousarray(image), data['spacing'], slice_path) + write_output(np.ascontiguousarray(image), data["spacing"], slice_path) - module_outputs = {'slice_fname': slice_path, - 'missing_tile_paths': missing_tile_paths} + module_outputs = {"slice_fname": slice_path, "missing_tile_paths": missing_tile_paths} ju.write(output_json, module_outputs) - -if __name__ == '__main__': - logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +if __name__ == "__main__": + logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") parser = argparse.ArgumentParser() - parser.add_argument('input_json', type=str) - parser.add_argument('output_json', type=str) + parser.add_argument("input_json", type=str) + parser.add_argument("output_json", type=str) args = parser.parse_args() data = ju.read(args.input_json) diff --git a/allensdk/internal/pipeline_modules/run_tissuecyte_unionize_cav_from_json.py b/allensdk/internal/pipeline_modules/run_tissuecyte_unionize_cav_from_json.py index 54175e54ea..4ed929b6a7 100644 --- a/allensdk/internal/pipeline_modules/run_tissuecyte_unionize_cav_from_json.py +++ b/allensdk/internal/pipeline_modules/run_tissuecyte_unionize_cav_from_json.py @@ -1,12 +1,8 @@ - from allensdk.internal.core.lims_pipeline_module import PipelineModule -from allensdk.internal.mouse_connectivity.interval_unionize.run_tissuecyte_unionize_cav import run - - +from allensdk.internal.mouse_connectivity.interval_unionize.run_tissuecyte_unionize_cav import run def main(): - module = PipelineModule() input_data = module.input_data() @@ -14,5 +10,5 @@ def main(): module.write_output_data(output_data) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/allensdk/internal/pipeline_modules/run_tissuecyte_unionize_classic_counts_from_json.py b/allensdk/internal/pipeline_modules/run_tissuecyte_unionize_classic_counts_from_json.py index 6b56a86b64..94a0e72ca8 100644 --- a/allensdk/internal/pipeline_modules/run_tissuecyte_unionize_classic_counts_from_json.py +++ b/allensdk/internal/pipeline_modules/run_tissuecyte_unionize_classic_counts_from_json.py @@ -1,10 +1,8 @@ - from allensdk.internal.core.lims_pipeline_module import PipelineModule -from allensdk.internal.mouse_connectivity.interval_unionize.run_tissuecyte_unionize_classic_counts import run +from allensdk.internal.mouse_connectivity.interval_unionize.run_tissuecyte_unionize_classic_counts import run def main(): - module = PipelineModule() input_data = module.input_data() @@ -12,5 +10,5 @@ def main(): module.write_output_data(output_data) -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/allensdk/internal/pipeline_modules/run_tissuecyte_unionize_classic_from_json.py b/allensdk/internal/pipeline_modules/run_tissuecyte_unionize_classic_from_json.py index 9bc2df2e55..972a0695eb 100644 --- a/allensdk/internal/pipeline_modules/run_tissuecyte_unionize_classic_from_json.py +++ b/allensdk/internal/pipeline_modules/run_tissuecyte_unionize_classic_from_json.py @@ -1,12 +1,8 @@ - from allensdk.internal.core.lims_pipeline_module import PipelineModule -from allensdk.internal.mouse_connectivity.interval_unionize.run_tissuecyte_unionize_classic import run - - +from allensdk.internal.mouse_connectivity.interval_unionize.run_tissuecyte_unionize_classic import run def main(): - module = PipelineModule() input_data = module.input_data() @@ -14,5 +10,5 @@ def main(): module.write_output_data(output_data) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/allensdk/model/__init__.py b/allensdk/model/__init__.py index 92ceaf67c3..8e51ec55db 100644 --- a/allensdk/model/__init__.py +++ b/allensdk/model/__init__.py @@ -32,4 +32,4 @@ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. -# \ No newline at end of file +# diff --git a/allensdk/model/biophys_sim/__init__.py b/allensdk/model/biophys_sim/__init__.py index 92ceaf67c3..8e51ec55db 100644 --- a/allensdk/model/biophys_sim/__init__.py +++ b/allensdk/model/biophys_sim/__init__.py @@ -32,4 +32,4 @@ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. -# \ No newline at end of file +# diff --git a/allensdk/model/biophys_sim/bps_command.py b/allensdk/model/biophys_sim/bps_command.py index 2546495bc7..705e78906e 100644 --- a/allensdk/model/biophys_sim/bps_command.py +++ b/allensdk/model/biophys_sim/bps_command.py @@ -40,24 +40,23 @@ from .config import Config -def choose_bps_command(command='bps_simple', conf_file=None): - log = logging.getLogger('allensdk.model.biophys_sim.bps_command') +def choose_bps_command(command="bps_simple", conf_file=None): + log = logging.getLogger("allensdk.model.biophys_sim.bps_command") log.info("bps command: %s" % (command)) if conf_file: conf_file = os.path.abspath(conf_file) - if command == 'help': - print(Config().argparser.parse_args(['--help'])) - elif command == 'nrnivmodl': - sp.call(['nrnivmodl', 'modfiles']) - elif command == 'run_simple': + if command == "help": + print(Config().argparser.parse_args(["--help"])) + elif command == "nrnivmodl": + sp.call(["nrnivmodl", "modfiles"]) + elif command == "run_simple": app_config = Config() description = app_config.load(conf_file) - sys.path.insert(1, description.manifest.get_path('CODE_DIR')) - (module_name, function_name) = description.data[ - 'runs'][0]['main'].split('#') + sys.path.insert(1, description.manifest.get_path("CODE_DIR")) + (module_name, function_name) = description.data["runs"][0]["main"].split("#") run_module(description, module_name, function_name) else: raise Exception("unknown command %s" % (command)) @@ -79,24 +78,23 @@ def main(): argv = sys.argv if len(argv) > 1: - if argv[0] == 'nrniv': - command = 'run_simple' + if argv[0] == "nrniv": + command = "run_simple" else: command = argv[1] else: - command = 'run_simple' + command = "run_simple" - if len(argv) > 2 and (argv[-1].endswith('.conf') or - argv[-1].endswith('.json')): + if len(argv) > 2 and (argv[-1].endswith(".conf") or argv[-1].endswith(".json")): conf_file = argv[-1] else: try: - conf_file = os.environ['CONF_FILE'] + conf_file = os.environ["CONF_FILE"] except Exception: pass choose_bps_command(command, conf_file) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/allensdk/model/biophys_sim/config.py b/allensdk/model/biophys_sim/config.py index 288bb5733f..5d4775a7d7 100644 --- a/allensdk/model/biophys_sim/config.py +++ b/allensdk/model/biophys_sim/config.py @@ -44,44 +44,44 @@ class Config(ApplicationConfig): _log = logging.getLogger(__name__) - _DEFAULT_LOG_CONFIG = str(files(__package__).joinpath('logging.conf')) + _DEFAULT_LOG_CONFIG = str(files(__package__).joinpath("logging.conf")) #: A structure that defines the available configuration parameters. #: The default value and help strings may be seen by viewing the source. _DEFAULTS = { - 'workdir': {'default': 'workdir', - 'help': 'writable directory where intermediate and output files are written.'}, - 'data_dir': {'default': '', - 'help': 'writable directory where intermediate and output files are written.'}, - 'model_file': {'default': 'config.json', - 'help': 'file where the model parameters are set.'}, - 'main': {'default': 'simulation#run', - 'help': 'module#function that runs the actual simulation'} + "workdir": { + "default": "workdir", + "help": "writable directory where intermediate and output files are written.", + }, + "data_dir": {"default": "", "help": "writable directory where intermediate and output files are written."}, + "model_file": {"default": "config.json", "help": "file where the model parameters are set."}, + "main": {"default": "simulation#run", "help": "module#function that runs the actual simulation"}, } def __init__(self): - super(Config, self).__init__(Config._DEFAULTS, - name='biophys', - halp='tools for biophysically detailed modeling at the Allen Institute.', - default_log_config=Config._DEFAULT_LOG_CONFIG) - - def load(self, config_path, - disable_existing_logs=False): - '''Parse the application configuration then immediately load + super(Config, self).__init__( + Config._DEFAULTS, + name="biophys", + halp="tools for biophysically detailed modeling at the Allen Institute.", + default_log_config=Config._DEFAULT_LOG_CONFIG, + ) + + def load(self, config_path, disable_existing_logs=False): + """Parse the application configuration then immediately load the model configuration files. Parameters ---------- disable_existing_logs : boolean, optional If false (default) leave existing logs after configuration. - ''' + """ super(Config, self).load([config_path], disable_existing_logs) description = self.read_model_description() return description def read_model_description(self): - '''parse the model_file field of the application configuration + """parse the model_file field of the application configuration and read the files. The model_file field of the application configuration is @@ -98,27 +98,25 @@ def read_model_description(self): ------- description : Description Configuration object. - ''' + """ reader = DescriptionParser() description = Description() Config._log.info("model file: %s" % self.model_file) # TODO: make space aware w/ regex - for model_file in self.model_file.split(','): + for model_file in self.model_file.split(","): if not model_file.startswith("file:"): - model_file = 'file:' + model_file + model_file = "file:" + model_file file_regex = re.compile(r"^file:([^?]*)(\?(.*)?)?") m = file_regex.match(model_file) model_file = m.group(1) file_url_params = {} if m.group(3): - file_url_params.update(((x[0], x[1]) - for x in (y.split('=') - for y in m.group(3).split('&')))) - if 'section' in file_url_params: - section = file_url_params['section'] + file_url_params.update(((x[0], x[1]) for x in (y.split("=") for y in m.group(3).split("&")))) + if "section" in file_url_params: + section = file_url_params["section"] else: section = None Config._log.info("reading model file %s" % (model_file)) diff --git a/allensdk/model/biophys_sim/neuron/__init__.py b/allensdk/model/biophys_sim/neuron/__init__.py index 92ceaf67c3..8e51ec55db 100644 --- a/allensdk/model/biophys_sim/neuron/__init__.py +++ b/allensdk/model/biophys_sim/neuron/__init__.py @@ -32,4 +32,4 @@ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. -# \ No newline at end of file +# diff --git a/allensdk/model/biophys_sim/neuron/hoc_utils.py b/allensdk/model/biophys_sim/neuron/hoc_utils.py index 9cebec3186..beda440d09 100644 --- a/allensdk/model/biophys_sim/neuron/hoc_utils.py +++ b/allensdk/model/biophys_sim/neuron/hoc_utils.py @@ -37,7 +37,7 @@ class HocUtils(object): - '''A helper class for containing references to NEUORN. + """A helper class for containing references to NEUORN. Attributes ---------- @@ -47,7 +47,8 @@ class HocUtils(object): The NEURON python object. neuron : module The NEURON module. - ''' + """ + _log = logging.getLogger(__name__) h = None nrn = None @@ -65,31 +66,31 @@ def __init__(self, description): self.description = description self.manifest = description.manifest - self.hoc_files = description.data['neuron'][0]['hoc'] + self.hoc_files = description.data["neuron"][0]["hoc"] self.initialize_hoc() def initialize_hoc(self): - '''Basic setup for NEURON.''' + """Basic setup for NEURON.""" h = self.h - params = self.description.data['conditions'][0] + params = self.description.data["conditions"][0] for hoc_file in self.hoc_files: HocUtils._log.info("loading hoc file %s" % (hoc_file)) HocUtils.h.load_file(str(hoc_file)) - h('starttime = startsw()') + h("starttime = startsw()") - if 'celsius' in params: - h.celsius = params['celsius'] + if "celsius" in params: + h.celsius = params["celsius"] - if 'v_init' in params: - h.v_init = params['v_init'] + if "v_init" in params: + h.v_init = params["v_init"] - if 'dt' in params: - h.dt = params['dt'] + if "dt" in params: + h.dt = params["dt"] h.steps_per_ms = 1.0 / h.dt - if 'tstop' in params: - h.tstop = params['tstop'] + if "tstop" in params: + h.tstop = params["tstop"] h.runStopAt = h.tstop diff --git a/allensdk/model/biophys_sim/scripts/__init__.py b/allensdk/model/biophys_sim/scripts/__init__.py index 92ceaf67c3..8e51ec55db 100644 --- a/allensdk/model/biophys_sim/scripts/__init__.py +++ b/allensdk/model/biophys_sim/scripts/__init__.py @@ -32,4 +32,4 @@ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. -# \ No newline at end of file +# diff --git a/allensdk/model/biophysical/__init__.py b/allensdk/model/biophysical/__init__.py index 92ceaf67c3..8e51ec55db 100644 --- a/allensdk/model/biophysical/__init__.py +++ b/allensdk/model/biophysical/__init__.py @@ -32,4 +32,4 @@ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. -# \ No newline at end of file +# diff --git a/allensdk/model/biophysical/run_simulate.py b/allensdk/model/biophysical/run_simulate.py index 2c5522d279..c40cc1d058 100644 --- a/allensdk/model/biophysical/run_simulate.py +++ b/allensdk/model/biophysical/run_simulate.py @@ -45,11 +45,9 @@ class RunSimulate(object): - _log = logging.getLogger('allensdk.model.biophysical.run_simulate') + _log = logging.getLogger("allensdk.model.biophysical.run_simulate") - def __init__(self, - input_json, - output_json): + def __init__(self, input_json, output_json): self.input_json = input_json self.output_json = output_json self.app_config = None @@ -58,77 +56,72 @@ def __init__(self, def load_manifest(self): self.app_config = Config().load(self.input_json) self.manifest = self.app_config.manifest - fix_sections = ['passive', 'axon_morph,', 'conditions', 'fitting'] + fix_sections = ["passive", "axon_morph,", "conditions", "fitting"] self.app_config.fix_unary_sections(fix_sections) def nrnivmodl(self): RunSimulate._log.debug("nrnivmodl") - subprocess.call(['nrnivmodl', './modfiles']) + subprocess.call(["nrnivmodl", "./modfiles"]) def simulate(self): - from allensdk.internal.api.queries.biophysical_module_reader \ - import BiophysicalModuleReader + from allensdk.internal.api.queries.biophysical_module_reader import BiophysicalModuleReader self.load_manifest() try: - stimulus_path = self.manifest.get_path('stimulus_path') + stimulus_path = self.manifest.get_path("stimulus_path") RunSimulate._log.info("stimulus path: %s" % (stimulus_path)) except Exception: - raise Exception( - 'Could not read input stimulus path from input config.') + raise Exception("Could not read input stimulus path from input config.") try: - out_path = self.manifest.get_path('output_path') + out_path = self.manifest.get_path("output_path") RunSimulate._log.info("result NWB file: %s" % (out_path)) except Exception: - raise Exception('Could not read output path from input config.') + raise Exception("Could not read output path from input config.") try: - morphology_path = self.manifest.get_path('MORPHOLOGY') + morphology_path = self.manifest.get_path("MORPHOLOGY") RunSimulate._log.info("morphology path: %s" % (morphology_path)) except Exception: - raise Exception( - 'Could not read morphology path from input config.') + raise Exception("Could not read morphology path from input config.") single_cell.run(self.app_config) lims_upload_config = BiophysicalModuleReader() - lims_upload_config.read_json( - self.manifest.get_path('neuronal_model_run_data')) + lims_upload_config.read_json(self.manifest.get_path("neuronal_model_run_data")) lims_upload_config.update_well_known_file(out_path) - lims_upload_config.set_workflow_state('passed') + lims_upload_config.set_workflow_state("passed") lims_upload_config.write_file(self.output_json) def main(command, lims_strategy_json, lims_response_json): - ''' Entry point for module. - :param command: select behavior, nrnivmodl or simulate - :type command: string - :param lims_strategy_json: path to json file output from lims. - :type lims_strategy_json: string - :param lims_response_json: path to json file returned to lims. - :type lims_response_json: string - ''' - rs = RunSimulate(lims_strategy_json, - lims_response_json) + """Entry point for module. + :param command: select behavior, nrnivmodl or simulate + :type command: string + :param lims_strategy_json: path to json file output from lims. + :type lims_strategy_json: string + :param lims_response_json: path to json file returned to lims. + :type lims_response_json: string + """ + rs = RunSimulate(lims_strategy_json, lims_response_json) RunSimulate._log.debug("command: %s" % (command)) RunSimulate._log.debug("lims strategy json: %s" % (lims_strategy_json)) RunSimulate._log.debug("lims upload json: %s" % (lims_response_json)) - log_config = str(files('allensdk.model.biophysical').joinpath('logging.conf')) + log_config = str(files("allensdk.model.biophysical").joinpath("logging.conf")) lc.fileConfig(log_config) - os.environ['LOG_CFG'] = log_config + os.environ["LOG_CFG"] = log_config - if 'nrnivmodl' == command: + if "nrnivmodl" == command: rs.nrnivmodl() else: rs.simulate() -if __name__ == '__main__': +if __name__ == "__main__": command, input_json, output_json = sys.argv[-3:] try: diff --git a/allensdk/model/biophysical/runner.py b/allensdk/model/biophysical/runner.py index f32c465fc1..3ad0ab3d88 100644 --- a/allensdk/model/biophysical/runner.py +++ b/allensdk/model/biophysical/runner.py @@ -45,16 +45,18 @@ from functools import partial import argparse -_runner_log = logging.getLogger('allensdk.model.biophysical.runner') +_runner_log = logging.getLogger("allensdk.model.biophysical.runner") _lock = None + def _init_lock(lock): global _lock _lock = lock + def run(args, sweeps=None, procs=6): - '''Main function for simulating sweeps in a biophysical experiment. + """Main function for simulating sweeps in a biophysical experiment. Parameters ---------- @@ -64,21 +66,20 @@ def run(args, sweeps=None, procs=6): number of sweeps to simulate simultaneously. sweeps : list list of experiment sweep numbers to simulate. If None, simulate all sweeps. - ''' + """ description = load_description(args) - - prepare_nwb_output(description.manifest.get_path('stimulus_path'), - description.manifest.get_path('output_path')) + + prepare_nwb_output(description.manifest.get_path("stimulus_path"), description.manifest.get_path("output_path")) if procs == 1: run_sync(description, sweeps) return if sweeps is None: - description.manifest.get_path('stimulus_path') - run_params = description.data['runs'][0] - sweeps = run_params['sweeps'] + description.manifest.get_path("stimulus_path") + run_params = description.data["runs"][0] + sweeps = run_params["sweeps"] lock = mp.Lock() pool = mp.Pool(procs, initializer=_init_lock, initargs=(lock,)) @@ -88,7 +89,7 @@ def run(args, sweeps=None, procs=6): def run_sync(description, sweeps=None): - '''Single-process main function for simulating sweeps in a biophysical experiment. + """Single-process main function for simulating sweeps in a biophysical experiment. Parameters ---------- @@ -96,7 +97,7 @@ def run_sync(description, sweeps=None): All information needed to run the experiment. sweeps : list list of experiment sweep numbers to simulate. If None, simulate all sweeps. - ''' + """ # configure NEURON utils = create_utils(description) @@ -104,17 +105,17 @@ def run_sync(description, sweeps=None): # configure model manifest = description.manifest - morphology_path = description.manifest.get_path('MORPHOLOGY').encode('ascii', 'ignore') + morphology_path = description.manifest.get_path("MORPHOLOGY").encode("ascii", "ignore") morphology_path = morphology_path.decode("utf-8") utils.generate_morphology(morphology_path) utils.load_cell_parameters() # configure stimulus and recording - stimulus_path = description.manifest.get_path('stimulus_path') - run_params = description.data['runs'][0] + stimulus_path = description.manifest.get_path("stimulus_path") + run_params = description.data["runs"][0] if sweeps is None: - sweeps = run_params['sweeps'] - sweeps_by_type = run_params['sweeps_by_type'] + sweeps = run_params["sweeps"] + sweeps_by_type = run_params["sweeps_by_type"] output_path = manifest.get_path("output_path") @@ -142,9 +143,8 @@ def run_sync(description, sweeps=None): _lock.release() -def prepare_nwb_output(nwb_stimulus_path, - nwb_result_path): - '''Copy the stimulus file, zero out the recorded voltages and spike times. +def prepare_nwb_output(nwb_stimulus_path, nwb_result_path): + """Copy the stimulus file, zero out the recorded voltages and spike times. Parameters ---------- @@ -152,7 +152,7 @@ def prepare_nwb_output(nwb_stimulus_path, NWB file name nwb_result_path : string NWB file name - ''' + """ output_dir = os.path.dirname(nwb_result_path) if not os.path.exists(output_dir): @@ -166,7 +166,7 @@ def prepare_nwb_output(nwb_stimulus_path, def save_nwb(output_path, v, sweep, sweeps_by_type): - '''Save a single voltage output result into an existing sweep in a NWB file. + """Save a single voltage output result into an existing sweep in a NWB file. This is intended to overwrite a recorded trace with a simulated voltage. Parameters @@ -177,30 +177,28 @@ def save_nwb(output_path, v, sweep, sweeps_by_type): voltage sweep : integer which entry to overwrite in the file. - ''' + """ output = NwbDataSet(output_path) output.set_sweep(sweep, None, v) - sweep_by_type = {t: [sweep] - for t, ss in sweeps_by_type.items() if sweep in ss} - sweep_features = extract_cell_features.extract_sweep_features(output, - sweep_by_type) + sweep_by_type = {t: [sweep] for t, ss in sweeps_by_type.items() if sweep in ss} + sweep_features = extract_cell_features.extract_sweep_features(output, sweep_by_type) try: - spikes = sweep_features[sweep]['spikes'] - spike_times = [s['threshold_t'] for s in spikes] + spikes = sweep_features[sweep]["spikes"] + spike_times = [s["threshold_t"] for s in spikes] output.set_spike_times(sweep, spike_times) except Exception as e: logging.info("sweep %d has no sweep features. %s" % (sweep, e.args)) def load_description(args_dict): - '''Read configurations. + """Read configurations. Parameters ---------- args_dict : dict Parsed arguments dictionary with following keys. - + manifest_file : string .json file with containing the experiment configuration axon_type : string @@ -210,29 +208,34 @@ def load_description(args_dict): ------- Config Object with all information needed to run the experiment. - ''' - manifest_json_path = args_dict['manifest_file'] - + """ + manifest_json_path = args_dict["manifest_file"] + description = Config().load(manifest_json_path) - + # For newest all-active models update the axon replacement - axon_replacement_dict = {'axon_type': args_dict.get('axon_type', 'truncated')} - description.update_data(axon_replacement_dict, 'biophys') + axon_replacement_dict = {"axon_type": args_dict.get("axon_type", "truncated")} + description.update_data(axon_replacement_dict, "biophys") # fix nonstandard description sections - fix_sections = ['passive', 'axon_morph,', 'conditions', 'fitting'] + fix_sections = ["passive", "axon_morph,", "conditions", "fitting"] description.fix_unary_sections(fix_sections) return description # Create the parser -sim_parser = argparse.ArgumentParser(description='Run simulation for biophysical models with the provided configuration') -sim_parser.add_argument('manifest_file', - help='.json configurations for running the simulations') -sim_parser.add_argument('--axon_type', default='truncated', choices=['stub', 'truncated'], - help='axon replacement for all-active models; truncated: truncate reconstructed axon after 60 micron, stub: replace reconstructed axon with a uniform stub 60 micron long and 1 micron in diameter') - -if '__main__' == __name__: +sim_parser = argparse.ArgumentParser( + description="Run simulation for biophysical models with the provided configuration" +) +sim_parser.add_argument("manifest_file", help=".json configurations for running the simulations") +sim_parser.add_argument( + "--axon_type", + default="truncated", + choices=["stub", "truncated"], + help="axon replacement for all-active models; truncated: truncate reconstructed axon after 60 micron, stub: replace reconstructed axon with a uniform stub 60 micron long and 1 micron in diameter", +) + +if "__main__" == __name__: schema = sim_parser.parse_args() run(vars(schema)) diff --git a/allensdk/model/biophysical/utils.py b/allensdk/model/biophysical/utils.py index 2e53c23b2b..bf3bcfc60c 100644 --- a/allensdk/model/biophysical/utils.py +++ b/allensdk/model/biophysical/utils.py @@ -106,9 +106,7 @@ def __init__(self, description): self.stim_vec_list = [] - def update_default_cell_hoc( - self, description, default_cell_hoc="cell.hoc" - ): + def update_default_cell_hoc(self, description, default_cell_hoc="cell.hoc"): """replace the default 'cell.hoc' path in the manifest with 'cell.hoc' packaged within AllenSDK if it does not exist""" @@ -187,20 +185,8 @@ def load_cell_parameters(self): h(p["name"] + " = %g " % p["value"]) else: if p["mechanism"] != "": - h( - 'forsec "' - + p["section"] - + '" { insert ' - + p["mechanism"] - + " }" - ) - h( - 'forsec "' - + p["section"] - + '" { ' - + p["name"] - + " = %g }" % p["value"] - ) + h('forsec "' + p["section"] + '" { insert ' + p["mechanism"] + " }") + h('forsec "' + p["section"] + '" { ' + p["name"] + " = %g }" % p["value"]) # Set reversal potentials for erev in conditions["erev"]: @@ -255,9 +241,7 @@ def read_stimulus(self, stimulus_path, sweep=0): sweep : integer, optional sweep index """ - Utils._log.info( - "reading stimulus path: %s, sweep %s", stimulus_path, sweep - ) + Utils._log.info("reading stimulus path: %s, sweep %s", stimulus_path, sweep) stimulus_data = NwbDataSet(stimulus_path) sweep_data = stimulus_data.get_sweep(sweep) @@ -274,8 +258,7 @@ def read_stimulus(self, stimulus_path, sweep=0): if hz != neuron_hz: Utils._log.debug( - "changing sampling rate from %d to %d to avoid NEURON " - "aliasing", + "changing sampling rate from %d to %d to avoid NEURON aliasing", hz, neuron_hz, ) @@ -306,17 +289,13 @@ def get_recorded_data(self, vec): 't' = numpy.ndarray with timestamps """ - junction_potential = self.description.data["fitting"][0][ - "junction_potential" - ] + junction_potential = self.description.data["fitting"][0]["junction_potential"] v = np.array(vec["v"]) t = np.array(vec["t"]) if self.stimulus_sampling_rate < self.simulation_sampling_rate: - factor = ( - self.simulation_sampling_rate / self.stimulus_sampling_rate - ) + factor = self.simulation_sampling_rate / self.stimulus_sampling_rate Utils._log.debug("subsampling recorded traces by %dX", factor) v = block_reduce(v, (factor,), np.mean)[: len(self.stim_curr)] @@ -360,16 +339,11 @@ def generate_morphology(self, morph_filename): Path to morphology. """ if self.axon_type == "stub": - self._log.info( - "Replacing axon with a stub : length 60 micron, diameter 1 " - "micron" - ) + self._log.info("Replacing axon with a stub : length 60 micron, diameter 1 micron") super(AllActiveUtils, self).generate_morphology(morph_filename) return - self._log.info( - "Legacy model - Truncating reconstructed axon after 60 micron" - ) + self._log.info("Legacy model - Truncating reconstructed axon after 60 micron") morph_basename = os.path.basename(morph_filename) morph_extension = morph_basename.split(".")[-1] if morph_extension.lower() == "swc": @@ -443,21 +417,12 @@ def load_cell_parameters(self): else: if hasattr(h, section_array): if mechanism != "": - print( - "Adding mechanism %s to %s" - % (mechanism, section_array) - ) + print("Adding mechanism %s to %s" % (mechanism, section_array)) for section in getattr(h, section_array): - if ( - self.h.ismembrane(str(mechanism), sec=section) - != 1 - ): + if self.h.ismembrane(str(mechanism), sec=section) != 1: section.insert(mechanism) - print( - "Setting %s to %.6g in %s" - % (param_name, param_value, section_array) - ) + print("Setting %s to %.6g in %s" % (param_name, param_value, section_array)) for section in getattr(h, section_array): setattr(section, param_name, param_value) @@ -467,10 +432,7 @@ def load_cell_parameters(self): ek = float(erev["ek"]) ena = float(erev["ena"]) - print( - "Setting ek to %.6g and ena to %.6g in %s" - % (ek, ena, erev_section_array) - ) + print("Setting ek to %.6g and ena to %.6g in %s" % (ek, ena, erev_section_array)) if hasattr(h, erev_section_array): for section in getattr(h, erev_section_array): @@ -480,10 +442,7 @@ def load_cell_parameters(self): if self.h.ismembrane("na_ion", sec=section) == 1: setattr(section, "ena", ena) else: - print( - "Warning: can't set erev for %s, " - "section array doesn't exist" % erev_section_array - ) + print("Warning: can't set erev for %s, section array doesn't exist" % erev_section_array) self.h.v_init = conditions["v_init"] self.h.celsius = conditions["celsius"] diff --git a/allensdk/model/glif/__init__.py b/allensdk/model/glif/__init__.py index 6ef7424a26..2755506be4 100644 --- a/allensdk/model/glif/__init__.py +++ b/allensdk/model/glif/__init__.py @@ -33,7 +33,7 @@ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. # -""" A Generalized Linear Integrate and Fire (GLIF) neuron modeling package. +"""A Generalized Linear Integrate and Fire (GLIF) neuron modeling package. Use this code to run the GLIF models available in the Allen Cell Types Atlas. See :doc:`glif_models` for more details. """ diff --git a/allensdk/model/glif/glif_neuron.py b/allensdk/model/glif/glif_neuron.py index 06e6c99e3a..f124e6e5f3 100755 --- a/allensdk/model/glif/glif_neuron.py +++ b/allensdk/model/glif/glif_neuron.py @@ -36,7 +36,7 @@ import logging import numpy as np -import simplejson as json +import simplejson as json import allensdk.core.json_utilities as ju import copy @@ -45,23 +45,26 @@ except Exception: from .glif_neuron_methods import GlifNeuronMethod, METHOD_LIBRARY -class GlifBadResetException( Exception ): - """ Exception raised when voltage is still above threshold after a reset rule is applied. """ + +class GlifBadResetException(Exception): + """Exception raised when voltage is still above threshold after a reset rule is applied.""" + def __init__(self, message, dv): super(Exception, self).__init__(message) self.dv = dv - -class GlifNeuron( object ): - """ Implements the current-based Mihalas Neiber GLIF neuron. Simulations model the voltage, + + +class GlifNeuron(object): + """Implements the current-based Mihalas Neiber GLIF neuron. Simulations model the voltage, threshold, and afterspike currents of a neuron given an input stimulus. A set of modular dynamics - rules are applied until voltage crosses threshold, at which point a set of modular reset rules are + rules are applied until voltage crosses threshold, at which point a set of modular reset rules are applied. See glif_neuron_methods.py for a list of what options there are for voltage, threshold, and afterspike current dynamics and reset rules. Parameters ---------- - El : float - resting potential + El : float + resting potential dt : float duration between time steps asc_tau_array: np.ndarray @@ -90,7 +93,7 @@ class GlifNeuron( object ): dictionary containing the 'name' of the voltage dynamics method to use and a 'params' dictionary parameters to pass to that function. threshold_reset_method : dict dictionary containing the 'name' of the threshold dynamics method to use and a 'params' dictionary parameters to pass to that function. - init_voltage : float + init_voltage : float initial voltage value init_threshold : float initial spike threshold value @@ -100,18 +103,36 @@ class GlifNeuron( object ): TYPE = "GLIF" - def __init__(self, El, dt, asc_tau_array, R_input, C, asc_amp_array, spike_cut_length, th_inf, th_adapt, coeffs, - AScurrent_dynamics_method, voltage_dynamics_method, threshold_dynamics_method, - AScurrent_reset_method, voltage_reset_method, threshold_reset_method, - init_voltage, init_threshold, init_AScurrents, **kwargs): - - """ Initialize the neuron.""" + def __init__( + self, + El, + dt, + asc_tau_array, + R_input, + C, + asc_amp_array, + spike_cut_length, + th_inf, + th_adapt, + coeffs, + AScurrent_dynamics_method, + voltage_dynamics_method, + threshold_dynamics_method, + AScurrent_reset_method, + voltage_reset_method, + threshold_reset_method, + init_voltage, + init_threshold, + init_AScurrents, + **kwargs, + ): + """Initialize the neuron.""" self.type = GlifNeuron.TYPE self.El = El self.dt = dt self.asc_tau_array = np.array(asc_tau_array) - + self.R_input = R_input self.C = C @@ -126,97 +147,101 @@ def __init__(self, El, dt, asc_tau_array, R_input, C, asc_amp_array, spike_cut_l self.init_threshold = init_threshold self.init_AScurrents = init_AScurrents - assert len(asc_tau_array) == len(asc_amp_array), Exception("After-spike current vector must have same length as asc_tau_array (%d vs %d)" % (asc_amp_array, asc_tau_array)) - assert len(self.init_AScurrents) == len(self.asc_tau_array), Exception("init_AScurrents length (%d) must have same length as asc_tau_array (%d)" % (len(self.init_AScurrents), len(self.asc_tau_array))) - + assert len(asc_tau_array) == len(asc_amp_array), Exception( + "After-spike current vector must have same length as asc_tau_array (%d vs %d)" + % (asc_amp_array, asc_tau_array) + ) + assert len(self.init_AScurrents) == len(self.asc_tau_array), Exception( + "init_AScurrents length (%d) must have same length as asc_tau_array (%d)" + % (len(self.init_AScurrents), len(self.asc_tau_array)) + ) # values computed based on inputs self.k = 1.0 / self.asc_tau_array self.G = 1.0 / self.R_input - # Values that can be fit: They scale the input values. + # Values that can be fit: They scale the input values. # These are allowed to have default values because they are going to get optimized. - self.coeffs = { - 'th_inf': 1, - 'C': 1, - 'G': 1, - 'b': 1, - 'a': 1, - 'asc_amp_array': np.ones(len(self.asc_tau_array)) - } + self.coeffs = {"th_inf": 1, "C": 1, "G": 1, "b": 1, "a": 1, "asc_amp_array": np.ones(len(self.asc_tau_array))} self.coeffs.update(coeffs) - - logging.debug('spike cut length: %d' % self.spike_cut_length) + + logging.debug("spike cut length: %d" % self.spike_cut_length) # initialize dynamics methods - self.AScurrent_dynamics_method = self.configure_library_method('AScurrent_dynamics_method', AScurrent_dynamics_method) - self.voltage_dynamics_method = self.configure_library_method('voltage_dynamics_method', voltage_dynamics_method) - self.threshold_dynamics_method = self.configure_library_method('threshold_dynamics_method', threshold_dynamics_method) + self.AScurrent_dynamics_method = self.configure_library_method( + "AScurrent_dynamics_method", AScurrent_dynamics_method + ) + self.voltage_dynamics_method = self.configure_library_method("voltage_dynamics_method", voltage_dynamics_method) + self.threshold_dynamics_method = self.configure_library_method( + "threshold_dynamics_method", threshold_dynamics_method + ) # initialize reset methods - self.AScurrent_reset_method = self.configure_library_method('AScurrent_reset_method', AScurrent_reset_method) - self.voltage_reset_method = self.configure_library_method('voltage_reset_method', voltage_reset_method) - self.threshold_reset_method = self.configure_library_method('threshold_reset_method', threshold_reset_method) + self.AScurrent_reset_method = self.configure_library_method("AScurrent_reset_method", AScurrent_reset_method) + self.voltage_reset_method = self.configure_library_method("voltage_reset_method", voltage_reset_method) + self.threshold_reset_method = self.configure_library_method("threshold_reset_method", threshold_reset_method) def __str__(self): return json.dumps(self.to_dict(), default=ju.json_handler, indent=2) @property def tau_m(self): - return self.R_input*self.C + return self.R_input * self.C @classmethod def from_dict(cls, d): - return cls(El = d['El'], - dt = d['dt'], - asc_tau_array = d['asc_tau_array'], - R_input = d['R_input'], - C = d['C'], - asc_amp_array = d['asc_amp_array'], - spike_cut_length = d['spike_cut_length'], - th_inf = d['th_inf'], - th_adapt = d['th_adapt'], - coeffs = d.get('coeffs', {}), - AScurrent_dynamics_method = d['AScurrent_dynamics_method'], - voltage_dynamics_method = d['voltage_dynamics_method'], - threshold_dynamics_method = d['threshold_dynamics_method'], - voltage_reset_method = d['voltage_reset_method'], - AScurrent_reset_method = d['AScurrent_reset_method'], - threshold_reset_method = d['threshold_reset_method'], - init_voltage = d['init_voltage'], - init_threshold = d['init_threshold'], - init_AScurrents = d['init_AScurrents']) + return cls( + El=d["El"], + dt=d["dt"], + asc_tau_array=d["asc_tau_array"], + R_input=d["R_input"], + C=d["C"], + asc_amp_array=d["asc_amp_array"], + spike_cut_length=d["spike_cut_length"], + th_inf=d["th_inf"], + th_adapt=d["th_adapt"], + coeffs=d.get("coeffs", {}), + AScurrent_dynamics_method=d["AScurrent_dynamics_method"], + voltage_dynamics_method=d["voltage_dynamics_method"], + threshold_dynamics_method=d["threshold_dynamics_method"], + voltage_reset_method=d["voltage_reset_method"], + AScurrent_reset_method=d["AScurrent_reset_method"], + threshold_reset_method=d["threshold_reset_method"], + init_voltage=d["init_voltage"], + init_threshold=d["init_threshold"], + init_AScurrents=d["init_AScurrents"], + ) def to_dict(self): - """ Convert the neuron to a serializable dictionary. """ + """Convert the neuron to a serializable dictionary.""" return { - 'type': self.type, - 'El': self.El, - 'dt': self.dt, - 'asc_tau_array': copy.deepcopy(self.asc_tau_array), - 'R_input': self.R_input, - 'C': self.C, - 'asc_amp_array': copy.deepcopy(self.asc_amp_array), - 'spike_cut_length': self.spike_cut_length, - 'th_inf': self.th_inf, - 'th_adapt': self.th_adapt, - 'coeffs': copy.deepcopy(self.coeffs), - 'AScurrent_dynamics_method': copy.deepcopy(self.AScurrent_dynamics_method), - 'voltage_dynamics_method': copy.deepcopy(self.voltage_dynamics_method), - 'threshold_dynamics_method': copy.deepcopy(self.threshold_dynamics_method), - 'AScurrent_reset_method': copy.deepcopy(self.AScurrent_reset_method), - 'voltage_reset_method': copy.deepcopy(self.voltage_reset_method), - 'threshold_reset_method': copy.deepcopy(self.threshold_reset_method), - 'init_voltage': self.init_voltage, - 'init_threshold': self.init_threshold, - 'init_AScurrents': copy.deepcopy(self.init_AScurrents), - 'El_reference': self.El + "type": self.type, + "El": self.El, + "dt": self.dt, + "asc_tau_array": copy.deepcopy(self.asc_tau_array), + "R_input": self.R_input, + "C": self.C, + "asc_amp_array": copy.deepcopy(self.asc_amp_array), + "spike_cut_length": self.spike_cut_length, + "th_inf": self.th_inf, + "th_adapt": self.th_adapt, + "coeffs": copy.deepcopy(self.coeffs), + "AScurrent_dynamics_method": copy.deepcopy(self.AScurrent_dynamics_method), + "voltage_dynamics_method": copy.deepcopy(self.voltage_dynamics_method), + "threshold_dynamics_method": copy.deepcopy(self.threshold_dynamics_method), + "AScurrent_reset_method": copy.deepcopy(self.AScurrent_reset_method), + "voltage_reset_method": copy.deepcopy(self.voltage_reset_method), + "threshold_reset_method": copy.deepcopy(self.threshold_reset_method), + "init_voltage": self.init_voltage, + "init_threshold": self.init_threshold, + "init_AScurrents": copy.deepcopy(self.init_AScurrents), + "El_reference": self.El, } @staticmethod def configure_method(method_name, method, method_params): - """ Create a GlifNeuronMethod instance given a name, a function, and function parameters. + """Create a GlifNeuronMethod instance given a name, a function, and function parameters. This is just a shortcut to the GlifNeuronMethod constructor. Parameters @@ -224,7 +249,7 @@ def configure_method(method_name, method, method_params): method_name : string name for refering to this method later method : function - a python function + a python function method_parameters : dict function arguments whose values should be fixed @@ -232,14 +257,14 @@ def configure_method(method_name, method, method_params): ------- GlifNeuronMethod a GlifNeuronMethod instance - """ + """ return GlifNeuronMethod(method_name, method, method_params) @staticmethod def configure_library_method(method_type, params): - """ Create a GlifNeuronMethod instance out of a library of functions organized by type name. - This refers to the METHOD_LIBRARY in glif_neuron_methods.py, which lays out the available functions + """Create a GlifNeuronMethod instance out of a library of functions organized by type name. + This refers to the METHOD_LIBRARY in glif_neuron_methods.py, which lays out the available functions that can be used for dynamics and reset rules. Parameters @@ -257,21 +282,21 @@ def configure_library_method(method_type, params): method_options = METHOD_LIBRARY.get(method_type, None) assert method_options is not None, Exception("Unknown method type (%s)" % method_type) - - method_name = params.get('name', None) - method_params = params.get('params', None) - + + method_name = params.get("name", None) + method_params = params.get("params", None) + assert method_name is not None, Exception("Method configuration for %s has no 'name'" % (method_type)) assert method_params is not None, Exception("Method configuration for %s has no 'params'" % (method_params)) - + method = method_options.get(method_name, None) - + assert method is not None, Exception("unknown method name %s of type %s" % (method_name, method_type)) - + return GlifNeuron.configure_method(method_name, method, method_params) - def dynamics(self, voltage_t0, threshold_t0, AScurrents_t0, inj, time_step, spike_time_steps): - """ Update the voltage, threshold, and afterspike currents of the neuron for a single time step. + def dynamics(self, voltage_t0, threshold_t0, AScurrents_t0, inj, time_step, spike_time_steps): + """Update the voltage, threshold, and afterspike currents of the neuron for a single time step. Parameters ---------- @@ -299,9 +324,9 @@ def dynamics(self, voltage_t0, threshold_t0, AScurrents_t0, inj, time_step, spik threshold_t1 = self.threshold_dynamics_method(self, threshold_t0, voltage_t0, AScurrents_t0, inj) return voltage_t1, threshold_t1, AScurrents_t1 - + def reset(self, voltage_t0, threshold_t0, AScurrents_t0): - """ Apply reset rules to the neuron's voltage, threshold, and afterspike currents assuming a spike has occurred (voltage is above threshold). + """Apply reset rules to the neuron's voltage, threshold, and afterspike currents assuming a spike has occurred (voltage is above threshold). Parameters ---------- @@ -317,22 +342,22 @@ def reset(self, voltage_t0, threshold_t0, AScurrents_t0): tuple voltage_t1 (voltage at next time step), threshold_t1 (threshold at next time step), AScurrents_t1 (afterspike currents at next time step) """ - - AScurrents_t1 = self.AScurrent_reset_method(self, AScurrents_t0) + + AScurrents_t1 = self.AScurrent_reset_method(self, AScurrents_t0) voltage_t1 = self.voltage_reset_method(self, voltage_t0) threshold_t1 = self.threshold_reset_method(self, threshold_t0, voltage_t1) - bad_reset_flag=False + bad_reset_flag = False if voltage_t1 > threshold_t1: - bad_reset_flag=True - #TODO put this back in eventually but would rather debug right now -# raise GlifBadResetException("Voltage reset above threshold: voltage_t1 (%f) threshold_t1 (%f), voltage_t0 (%f) threshold_t0 (%f) AScurrents_t0 (%s)" % ( voltage_t1, threshold_t1, voltage_t0, threshold_t0, repr(AScurrents_t0)), voltage_t1 - threshold_t1) + bad_reset_flag = True + # TODO put this back in eventually but would rather debug right now + # raise GlifBadResetException("Voltage reset above threshold: voltage_t1 (%f) threshold_t1 (%f), voltage_t0 (%f) threshold_t0 (%f) AScurrents_t0 (%s)" % ( voltage_t1, threshold_t1, voltage_t0, threshold_t0, repr(AScurrents_t0)), voltage_t1 - threshold_t1) return voltage_t1, threshold_t1, AScurrents_t1, bad_reset_flag - + def run(self, stim): - """ Run neuron simulation over a given stimulus. This steps through the stimulus applying dynamics equations. - After each step it checks if voltage is above threshold. If so, self.spike_cut_length NaNs are inserted - into the output voltages, reset rules are applied to the voltage, threshold, and afterspike currents, and the + """Run neuron simulation over a given stimulus. This steps through the stimulus applying dynamics equations. + After each step it checks if voltage is above threshold. If so, self.spike_cut_length NaNs are inserted + into the output voltages, reset rules are applied to the voltage, threshold, and afterspike currents, and the simulation resumes. Parameters @@ -343,35 +368,35 @@ def run(self, stim): Returns ------- dict - a dictionary containing: - 'voltage': simulated voltage values, + a dictionary containing: + 'voltage': simulated voltage values, 'threshold': threshold values during the simulation, - 'AScurrents': afterspike current values during the simulation, - 'grid_spike_times': spike times (in uits of self.dt) aligned to simulation time steps, - 'interpolated_spike_times': spike times (in units of self.dt) linearly interpolated between time steps, - 'spike_time_steps': the indices of grid spike times, - 'interpolated_spike_voltage': voltage of the simulation at interpolated spike times, + 'AScurrents': afterspike current values during the simulation, + 'grid_spike_times': spike times (in uits of self.dt) aligned to simulation time steps, + 'interpolated_spike_times': spike times (in units of self.dt) linearly interpolated between time steps, + 'spike_time_steps': the indices of grid spike times, + 'interpolated_spike_voltage': voltage of the simulation at interpolated spike times, 'interpolated_spike_threshold': threshold of the simulation at interpolated spike times """ - bad_reset_flag=False - + bad_reset_flag = False + # initialize the voltage, threshold, and afterspike current values voltage_t0 = self.init_voltage threshold_t0 = self.init_threshold AScurrents_t0 = self.init_AScurrents - self.threshold_components = None #get rid of lingering method data + self.threshold_components = None # get rid of lingering method data - num_time_steps = len(stim) + num_time_steps = len(stim) num_AScurrents = len(AScurrents_t0) - + # pre-allocate the output voltages, thresholds, and after-spike currents - voltage_out=np.empty(num_time_steps) - voltage_out[:]=np.nan - threshold_out=np.empty(num_time_steps) - threshold_out[:]=np.nan - AScurrents_out=np.empty(shape=(num_time_steps, num_AScurrents)) - AScurrents_out[:]=np.nan + voltage_out = np.empty(num_time_steps) + voltage_out[:] = np.nan + threshold_out = np.empty(num_time_steps) + threshold_out[:] = np.nan + AScurrents_out = np.empty(shape=(num_time_steps, num_AScurrents)) + AScurrents_out[:] = np.nan # array that will hold spike indices spike_time_steps = [] @@ -383,29 +408,38 @@ def run(self, stim): time_step = 0 while time_step < num_time_steps: if time_step % 10000 == 0: - logging.info("time step %d / %d" % (time_step, num_time_steps)) + logging.info("time step %d / %d" % (time_step, num_time_steps)) # compute voltage, threshold, and ascurrents at current time step - (voltage_t1, threshold_t1, AScurrents_t1) = self.dynamics(voltage_t0, threshold_t0, AScurrents_t0, stim[time_step], time_step, spike_time_steps) - - #if the voltage is bigger than the threshold record the spike and reset the values - if voltage_t1 > threshold_t1: + (voltage_t1, threshold_t1, AScurrents_t1) = self.dynamics( + voltage_t0, threshold_t0, AScurrents_t0, stim[time_step], time_step, spike_time_steps + ) + # if the voltage is bigger than the threshold record the spike and reset the values + if voltage_t1 > threshold_t1: # spike_time_steps are stimulus indices when voltage surpassed threshold spike_time_steps.append(time_step) - grid_spike_times.append(time_step * self.dt) + grid_spike_times.append(time_step * self.dt) - # compute higher fidelity spike time/voltage/threshold by linearly interpolating - interpolated_spike_times.append(interpolate_spike_time(self.dt, time_step, threshold_t0, threshold_t1, voltage_t0, voltage_t1)) + # compute higher fidelity spike time/voltage/threshold by linearly interpolating + interpolated_spike_times.append( + interpolate_spike_time(self.dt, time_step, threshold_t0, threshold_t1, voltage_t0, voltage_t1) + ) interpolated_spike_time_offset = interpolated_spike_times[-1] - (time_step - 1) * self.dt - interpolated_spike_voltage.append(interpolate_spike_value(self.dt, interpolated_spike_time_offset, voltage_t0, voltage_t1)) - interpolated_spike_threshold.append(interpolate_spike_value(self.dt, interpolated_spike_time_offset, threshold_t0, threshold_t1)) - + interpolated_spike_voltage.append( + interpolate_spike_value(self.dt, interpolated_spike_time_offset, voltage_t0, voltage_t1) + ) + interpolated_spike_threshold.append( + interpolate_spike_value(self.dt, interpolated_spike_time_offset, threshold_t0, threshold_t1) + ) + # reset voltage, threshold, and afterspike currents # Note that these values are not ever recorded unless the spike cut length doesnt happen (this doesnt seem quite right) - (voltage_t0, threshold_t0, AScurrents_t0, bad_reset_flag) = self.reset(voltage_t1, threshold_t1, AScurrents_t1) - + (voltage_t0, threshold_t0, AScurrents_t0, bad_reset_flag) = self.reset( + voltage_t1, threshold_t1, AScurrents_t1 + ) + # if we are not integrating during the spike (which includes right now), insert nans then jump ahead # TODO MAYBE ONE LAST NAN SHOULD BE INSERTED AND THIS VALUE SHOULD BE RECORDED FOR CONSISTANCY if self.spike_cut_length > 0: @@ -414,87 +448,88 @@ def run(self, stim): cut_past_end = (time_step + n) >= len(voltage_out) if cut_past_end: n = len(voltage_out) - time_step - - voltage_out[time_step:time_step+n] = np.nan - threshold_out[time_step:time_step+n] = np.nan - AScurrents_out[time_step:time_step+n,:] = np.nan + + voltage_out[time_step : time_step + n] = np.nan + threshold_out[time_step : time_step + n] = np.nan + AScurrents_out[time_step : time_step + n, :] = np.nan if not cut_past_end: - voltage_out[time_step+n] = voltage_t0 - threshold_out[time_step+n] = threshold_t0 - AScurrents_out[time_step+n,:] = AScurrents_t0 + voltage_out[time_step + n] = voltage_t0 + threshold_out[time_step + n] = threshold_t0 + AScurrents_out[time_step + n, :] = AScurrents_t0 - time_step += self.spike_cut_length+1 - else: - voltage_out[time_step] = voltage_t0 + time_step += self.spike_cut_length + 1 + else: + voltage_out[time_step] = voltage_t0 threshold_out[time_step] = threshold_t0 - AScurrents_out[time_step,:] = AScurrents_t0 + AScurrents_out[time_step, :] = AScurrents_t0 time_step += 1 - + if bad_reset_flag: - voltage_out[time_step:time_step+5] = voltage_t0 - threshold_out[time_step:time_step+5] = threshold_t0 - AScurrents_out[time_step:time_step+5] = AScurrents_t0 + voltage_out[time_step : time_step + 5] = voltage_t0 + threshold_out[time_step : time_step + 5] = threshold_t0 + AScurrents_out[time_step : time_step + 5] = AScurrents_t0 break else: # there was no spike, store the next voltages - voltage_out[time_step] = voltage_t1 + voltage_out[time_step] = voltage_t1 threshold_out[time_step] = threshold_t1 - AScurrents_out[time_step,:] = AScurrents_t1 + AScurrents_out[time_step, :] = AScurrents_t1 voltage_t0 = voltage_t1 threshold_t0 = threshold_t1 AScurrents_t0 = AScurrents_t1 - + time_step += 1 return { - 'voltage': voltage_out, - 'threshold': threshold_out, - 'AScurrents': AScurrents_out, - 'grid_spike_times': np.array(grid_spike_times), - 'interpolated_spike_times': np.array(interpolated_spike_times), - 'spike_time_steps': np.array(spike_time_steps), - 'interpolated_spike_voltage': np.array(interpolated_spike_voltage), - 'interpolated_spike_threshold': np.array(interpolated_spike_threshold) - } + "voltage": voltage_out, + "threshold": threshold_out, + "AScurrents": AScurrents_out, + "grid_spike_times": np.array(grid_spike_times), + "interpolated_spike_times": np.array(interpolated_spike_times), + "spike_time_steps": np.array(spike_time_steps), + "interpolated_spike_voltage": np.array(interpolated_spike_voltage), + "interpolated_spike_threshold": np.array(interpolated_spike_threshold), + } -# TODO: DEPRICATE -# def get_threshold_components(self): -# if self.threshold_components is None: -# self.threshold_components = { 'spike': [0], 'voltage': [0] } -# -# return self.threshold_components + # TODO: DEPRICATE + # def get_threshold_components(self): + # if self.threshold_components is None: + # self.threshold_components = { 'spike': [0], 'voltage': [0] } + # + # return self.threshold_components def append_threshold_components(self, spike, voltage): - self.threshold_components['spike'].append(spike) - self.threshold_components['voltage'].append(voltage) + self.threshold_components["spike"].append(spike) + self.threshold_components["voltage"].append(voltage) + # TODO: DEPRICATE # def reset_threshold_components(self): -# self.threshold_components = None - +# self.threshold_components = None def interpolate_spike_time(dt, time_step, threshold_t0, threshold_t1, voltage_t0, voltage_t1): - """ Given two voltage and threshold values, the dt between them and the initial time step, interpolate - a spike time within the dt interval by intersecting the two lines. """ - return time_step*dt + line_crossing_x(dt, voltage_t0, voltage_t1, threshold_t0, threshold_t1) + """Given two voltage and threshold values, the dt between them and the initial time step, interpolate + a spike time within the dt interval by intersecting the two lines.""" + return time_step * dt + line_crossing_x(dt, voltage_t0, voltage_t1, threshold_t0, threshold_t1) def interpolate_spike_value(dt, interpolated_spike_time_offset, v0, v1): - """ Take a value at two adjacent time steps and linearly interpolate what the value would be - at an offset between the two time steps. """ + """Take a value at two adjacent time steps and linearly interpolate what the value would be + at an offset between the two time steps.""" return v0 + (v1 - v0) * interpolated_spike_time_offset / dt def line_crossing_x(dx, a0, a1, b0, b1): - """ Find the x value of the intersection of two lines. """ - assert type(a0) != int and type(a1) != int and type(b0) != int and type(b1) != int, Exception("Do not pass integers into this function!") - return dx * (b0 - a0) / ( (a1 - a0) - (b1 - b0) ) - + """Find the x value of the intersection of two lines.""" + assert type(a0) != int and type(a1) != int and type(b0) != int and type(b1) != int, Exception( + "Do not pass integers into this function!" + ) + return dx * (b0 - a0) / ((a1 - a0) - (b1 - b0)) + def line_crossing_y(dx, a0, a1, b0, b1): - """ Find the y value of the intersection of two lines. """ + """Find the y value of the intersection of two lines.""" return b0 + (b1 - b0) * (b0 - a0) / ((a1 - a0) - (b1 - b0)) - diff --git a/allensdk/model/glif/glif_neuron_methods.py b/allensdk/model/glif/glif_neuron_methods.py index 73aedf4226..a76acd70c3 100644 --- a/allensdk/model/glif/glif_neuron_methods.py +++ b/allensdk/model/glif/glif_neuron_methods.py @@ -33,21 +33,22 @@ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. # -""" The methods in this module are used for configuring dynamics and reset rules for the GlifNeuron. +"""The methods in this module are used for configuring dynamics and reset rules for the GlifNeuron. For more details on how to use these methods, see :doc:`glif_models`. """ + import functools import numpy as np -class GlifNeuronMethod( object ): - """ A simple class to keep track of the name and parameters associated with a neuron method. +class GlifNeuronMethod(object): + """A simple class to keep track of the name and parameters associated with a neuron method. This class is initialized with a name, function, and parameters to pass to the function. The - function then has those passed parameters fixed to a partial function using functools.partial. - This class then mimics a function itself using the __call__ convention. Parameters that are not + function then has those passed parameters fixed to a partial function using functools.partial. + This class then mimics a function itself using the __call__ convention. Parameters that are not fixed in this way are assumed to be passed into the method when it is called. If the passed parameters contain an argument that is not part of the function signature, an exception will - be raised. + be raised. Parameters ---------- @@ -65,17 +66,14 @@ def __init__(self, method_name, method, method_params): self.method = functools.partial(method, **method_params) def __call__(self, *args, **kwargs): - """ Defining this method allows an instance to be called like a function """ + """Defining this method allows an instance to be called like a function""" return self.method(*args, **kwargs) def to_dict(self): - return { - 'name': self.name, - 'params': self.params - } + return {"name": self.name, "params": self.params} def modify_parameter(self, param, operator): - """ Modify a function parameter needs to be modified after initialization. + """Modify a function parameter needs to be modified after initialization. Parameters ---------- @@ -94,9 +92,9 @@ def modify_parameter(self, param, operator): return value -def max_of_line_and_const(x,b,c,d): - #TODO: move to other library - """ Find the maximum of a value and a position on a line +def max_of_line_and_const(x, b, c, d): + # TODO: move to other library + """Find the maximum of a value and a position on a line Parameters ---------- @@ -116,13 +114,13 @@ def max_of_line_and_const(x,b,c,d): """ one = b - two = c*x+d - return np.maximum(one,two) + two = c * x + d + return np.maximum(one, two) -def min_of_line_and_zero(x,c,d): - #TODO: move to other library - """ Find the minimum of a value and a position on a line +def min_of_line_and_zero(x, c, d): + # TODO: move to other library + """Find the minimum of a value and a position on a line Parameters ---------- @@ -142,42 +140,46 @@ def min_of_line_and_zero(x,c,d): """ one = 0 - two = c*x+d - return np.minimum(one,two) + two = c * x + d + return np.minimum(one, two) def dynamics_AScurrent_exp(neuron, AScurrents_t0, time_step, spike_time_steps): - """ Exponential afterspike current dynamics method takes a current at t0 and returns the current at + """Exponential afterspike current dynamics method takes a current at t0 and returns the current at a time step later. """ - return AScurrents_t0*np.exp(-neuron.k*neuron.dt) - + return AScurrents_t0 * np.exp(-neuron.k * neuron.dt) + def dynamics_AScurrent_none(neuron, AScurrents_t0, time_step, spike_time_steps): - """ This method always returns zeros for the afterspike currents, regardless of input. """ + """This method always returns zeros for the afterspike currents, regardless of input.""" return np.zeros(len(AScurrents_t0)) def dynamics_voltage_linear_forward_euler(neuron, voltage_t0, AScurrents_t0, inj): - """ (TODO) Linear voltage dynamics. """ - return voltage_t0 + (inj + np.sum(AScurrents_t0) - neuron.G * neuron.coeffs['G'] * (voltage_t0 - neuron.El)) * neuron.dt / (neuron.C * neuron.coeffs['C']) + """(TODO) Linear voltage dynamics.""" + return voltage_t0 + ( + inj + np.sum(AScurrents_t0) - neuron.G * neuron.coeffs["G"] * (voltage_t0 - neuron.El) + ) * neuron.dt / (neuron.C * neuron.coeffs["C"]) + def dynamics_voltage_linear_exact(neuron, voltage_t0, AScurrents_t0, inj): - """ (TODO) Linear voltage dynamics. """ + """(TODO) Linear voltage dynamics.""" - C = (neuron.C * neuron.coeffs['C']) + C = neuron.C * neuron.coeffs["C"] I = inj + np.sum(AScurrents_t0) - g = neuron.G * neuron.coeffs['G'] - tau = g/C - N = (I+ g*neuron.El)/C + g = neuron.G * neuron.coeffs["G"] + tau = g / C + N = (I + g * neuron.El) / C + + return voltage_t0 * np.exp(-neuron.dt * tau) + N * (1 - np.exp(-tau * neuron.dt)) / tau - return voltage_t0*np.exp(-neuron.dt*tau) + N*(1-np.exp(-tau*neuron.dt))/tau def spike_component_of_threshold_forward_euler(th_t0, b_spike, dt): - '''Spike component of threshold modeled as an exponential decay. Implemented - here for forward Euler - + """Spike component of threshold modeled as an exponential decay. Implemented + here for forward Euler + Parameters ---------- th_t0 : float @@ -186,14 +188,15 @@ def spike_component_of_threshold_forward_euler(th_t0, b_spike, dt): decay constant of exponential dt : float time step - ''' - b_spike=-b_spike #TODO: this is here because b_spike is always input as positive although it is negative - return th_t0 + th_t0*b_spike * dt - + """ + b_spike = -b_spike # TODO: this is here because b_spike is always input as positive although it is negative + return th_t0 + th_t0 * b_spike * dt + + def spike_component_of_threshold_exact(th0, b_spike, t): - '''Spike component of threshold modeled as an exponential decay. Implemented - here as exact analytical solution. - + """Spike component of threshold modeled as an exponential decay. Implemented + here as exact analytical solution. + Parameters ---------- th0 : float @@ -203,16 +206,17 @@ def spike_component_of_threshold_exact(th0, b_spike, t): t : float or array time step if used in an Euler setup time if used analytically - ''' - b_spike=-b_spike - return th0*np.exp(b_spike * t) + """ + b_spike = -b_spike + return th0 * np.exp(b_spike * t) + def voltage_component_of_threshold_forward_euler(th_t0, v_t0, dt, a_voltage, b_voltage, El): - '''Equation 2.1 of Mihalas and Nieber, 2009 implemented for use in forward Euler. Note - here all variables are in reference to threshold infinity. Therefore thr_inf is zero - here (replaced threshold_inf with 0 in the equation to be verbose). This is done so that - th_inf can be optimized without affecting this function. - + """Equation 2.1 of Mihalas and Nieber, 2009 implemented for use in forward Euler. Note + here all variables are in reference to threshold infinity. Therefore thr_inf is zero + here (replaced threshold_inf with 0 in the equation to be verbose). This is done so that + th_inf can be optimized without affecting this function. + Parameters ---------- th_t0 : float @@ -227,15 +231,16 @@ def voltage_component_of_threshold_forward_euler(th_t0, v_t0, dt, a_voltage, b_v constant b El : float reversal potential - ''' - return th_t0 + (a_voltage*(v_t0-El)-b_voltage*(th_t0-0))*dt + """ + return th_t0 + (a_voltage * (v_t0 - El) - b_voltage * (th_t0 - 0)) * dt + def voltage_component_of_threshold_exact(th0, v0, I, t, a_voltage, b_voltage, C, g, El): - '''Note this function is the exact formulation; however, dt is used because t0 is the initial time and dt + """Note this function is the exact formulation; however, dt is used because t0 is the initial time and dt is the time the function is exactly evaluated at. Note: that here, this equation is in reference to th_inf. - Therefore th0 is the total threshold-thr_inf (threshold_inf replaced with 0 in the equation to be verbose). + Therefore th0 is the total threshold-thr_inf (threshold_inf replaced with 0 in the equation to be verbose). This is done so that th_inf can be optimized without affecting this function. - + Parameters ---------- th0 : float @@ -255,26 +260,31 @@ def voltage_component_of_threshold_exact(th0, v0, I, t, a_voltage, b_voltage, C, capacitance g : float conductance (1/resistance) - El : float + El : float reversal potential - ''' - beta=(I+g*El)/g - phi=a_voltage/(b_voltage-g/C) - return phi*(v0-beta)*np.exp(-g*t/C)+1/(np.exp(b_voltage*t))*(th0-phi*(v0-beta)- - (a_voltage/b_voltage)*(beta-El)-0) +(a_voltage/b_voltage)*(beta-El) +0 - - -def dynamics_threshold_three_components_exact(neuron, threshold_t0, voltage_t0, AScurrents_t0, inj, - a_spike, b_spike, a_voltage, b_voltage): - """Analytical solution for threshold dynamics. The threshold will adapt via two mechanisms: - 1. a voltage dependent adaptation. - 2. a component initiated by a spike which decays as an exponential. - These two component are in reference to threshold infinity and are recorded + """ + beta = (I + g * El) / g + phi = a_voltage / (b_voltage - g / C) + return ( + phi * (v0 - beta) * np.exp(-g * t / C) + + 1 / (np.exp(b_voltage * t)) * (th0 - phi * (v0 - beta) - (a_voltage / b_voltage) * (beta - El) - 0) + + (a_voltage / b_voltage) * (beta - El) + + 0 + ) + + +def dynamics_threshold_three_components_exact( + neuron, threshold_t0, voltage_t0, AScurrents_t0, inj, a_spike, b_spike, a_voltage, b_voltage +): + """Analytical solution for threshold dynamics. The threshold will adapt via two mechanisms: + 1. a voltage dependent adaptation. + 2. a component initiated by a spike which decays as an exponential. + These two component are in reference to threshold infinity and are recorded in the neuron's threshold components. - The third component refers to th_inf which is added separately as opposed to being + The third component refers to th_inf which is added separately as opposed to being included in the voltage component of the threshold as is done in equation 2.1 of Mihalas and Nieber 2009. Threshold infinity is removed for simple optimization. - + Parameters ---------- neuron : class @@ -283,45 +293,49 @@ def dynamics_threshold_three_components_exact(neuron, threshold_t0, voltage_t0, voltage_t0 : float voltage input to function AScurrents_t0 : vector - values of after spike currents + values of after spike currents inj : float current injected into the neuron """ - #TODO: just having the get_threshold_components added an erroneous zero to the beginning of the list + # TODO: just having the get_threshold_components added an erroneous zero to the beginning of the list if neuron.threshold_components is None: - neuron.threshold_components = { 'spike': [], 'voltage': [] } + neuron.threshold_components = {"spike": [], "voltage": []} th_spike = 0 th_voltage = 0 - else: + else: tcs = neuron.threshold_components - th_spike = tcs['spike'][-1] - th_voltage = tcs['voltage'][-1] + th_spike = tcs["spike"][-1] + th_voltage = tcs["voltage"][-1] - a_voltage = a_voltage * neuron.coeffs['a'] - b_voltage = b_voltage * neuron.coeffs['b'] + a_voltage = a_voltage * neuron.coeffs["a"] + b_voltage = b_voltage * neuron.coeffs["b"] I = inj + np.sum(AScurrents_t0) - C = neuron.C * neuron.coeffs['C'] - g = neuron.G * neuron.coeffs['G'] + C = neuron.C * neuron.coeffs["C"] + g = neuron.G * neuron.coeffs["G"] - voltage_component=voltage_component_of_threshold_exact(th_voltage, voltage_t0, I, neuron.dt, a_voltage, b_voltage, C, g, neuron.El) + voltage_component = voltage_component_of_threshold_exact( + th_voltage, voltage_t0, I, neuron.dt, a_voltage, b_voltage, C, g, neuron.El + ) spike_component = spike_component_of_threshold_exact(th_spike, b_spike, neuron.dt) - - #------update the voltage and spiking values of the the + + # ------update the voltage and spiking values of the the neuron.append_threshold_components(spike_component, voltage_component) - - return voltage_component+spike_component+neuron.th_inf * neuron.coeffs['th_inf'] - -def dynamics_threshold_spike_component(neuron, threshold_t0, voltage_t0, AScurrents_t0, inj, - a_spike, b_spike, a_voltage, b_voltage): - """Analytical solution for spike component of threshold. The threshold will adapt via - a component initiated by a spike which decays as an exponential. The component is in - reference to threshold infinity and are recorded in the neuron's threshold components. The voltage - component of the threshold is set to zero in the threshold components because it is zero here - The third component refers to th_inf which is added separately as opposed to being + + return voltage_component + spike_component + neuron.th_inf * neuron.coeffs["th_inf"] + + +def dynamics_threshold_spike_component( + neuron, threshold_t0, voltage_t0, AScurrents_t0, inj, a_spike, b_spike, a_voltage, b_voltage +): + """Analytical solution for spike component of threshold. The threshold will adapt via + a component initiated by a spike which decays as an exponential. The component is in + reference to threshold infinity and are recorded in the neuron's threshold components. The voltage + component of the threshold is set to zero in the threshold components because it is zero here + The third component refers to th_inf which is added separately as opposed to being included in the voltage component of the threshold as is done in equation 2.1 of Mihalas and Nieber 2009. Threshold infinity is removed for simple optimization. - + Parameters ---------- neuron : class @@ -330,71 +344,75 @@ def dynamics_threshold_spike_component(neuron, threshold_t0, voltage_t0, AScurre voltage_t0 : float voltage input to function AScurrents_t0 : vector - values of after spike currents + values of after spike currents inj : float current injected into the neuron """ - #TODO: just having the get_threshold_components added an erroneous zero to the beginning of the list + # TODO: just having the get_threshold_components added an erroneous zero to the beginning of the list if neuron.threshold_components is None: - neuron.threshold_components = { 'spike': [], 'voltage': [] } + neuron.threshold_components = {"spike": [], "voltage": []} th_spike = 0 - else: + else: tcs = neuron.threshold_components - th_spike = tcs['spike'][-1] + th_spike = tcs["spike"][-1] spike_component = spike_component_of_threshold_exact(th_spike, b_spike, neuron.dt) - - #------update the voltage and spiking values of the the + + # ------update the voltage and spiking values of the the neuron.append_threshold_components(spike_component, 0.0) - - return spike_component+neuron.th_inf * neuron.coeffs['th_inf'] + + return spike_component + neuron.th_inf * neuron.coeffs["th_inf"] def dynamics_threshold_inf(neuron, threshold_t0, voltage_t0, AScurrents_t0, inj): - """ Set threshold to the neuron's instantaneous threshold. + """Set threshold to the neuron's instantaneous threshold. Parameters ---------- neuron : class threshold_t0 : not used here voltage_t0 : not used here - AScurrents_t0 : not used here + AScurrents_t0 : not used here inj : not used here - AScurrents_t0 : not used here + AScurrents_t0 : not used here inj : not used here - """ - return neuron.coeffs['th_inf'] * neuron.th_inf + """ + return neuron.coeffs["th_inf"] * neuron.th_inf def reset_AScurrent_sum(neuron, AScurrents_t0, r): - """ Reset afterspike currents by adding summed exponentials. Left over currents from last spikes as + """Reset afterspike currents by adding summed exponentials. Left over currents from last spikes as well as newly initiated currents from current spike. Currents amplitudes in neuron.asc_amp_array need - to be the amplitudes advanced though the spike cutting. I.e. In the preprocessor if the after spike currents - are calculated via the GLM from spike initiation the amplitude at the time after the spike cutting needs to + to be the amplitudes advanced though the spike cutting. I.e. In the preprocessor if the after spike currents + are calculated via the GLM from spike initiation the amplitude at the time after the spike cutting needs to be calculated and neuron.asc_amp_array needs to be set to this value. - + Parameters ---------- r : np.ndarray a coefficient vector applied to the afterspike currents """ - new_currents=neuron.asc_amp_array * neuron.coeffs['asc_amp_array'] #neuron.asc_amp_array are amplitudes initiating after the spike is cut - left_over_currents=AScurrents_t0 * r * np.exp(-(neuron.k * neuron.dt * neuron.spike_cut_length)) #advancing cut currents though the spike + new_currents = ( + neuron.asc_amp_array * neuron.coeffs["asc_amp_array"] + ) # neuron.asc_amp_array are amplitudes initiating after the spike is cut + left_over_currents = ( + AScurrents_t0 * r * np.exp(-(neuron.k * neuron.dt * neuron.spike_cut_length)) + ) # advancing cut currents though the spike - return new_currents+left_over_currents + return new_currents + left_over_currents def reset_AScurrent_none(neuron, AScurrents_t0): - """ Reset afterspike currents to zero. """ - - if np.sum(AScurrents_t0)!=0: - raise Exception('You are running a LIF but the AScurrents are not zero!') + """Reset afterspike currents to zero.""" + + if np.sum(AScurrents_t0) != 0: + raise Exception("You are running a LIF but the AScurrents are not zero!") return np.zeros(len(AScurrents_t0)) def reset_voltage_v_before(neuron, voltage_t0, a, b): - """ Reset voltage to the previous value with a scale and offset applied. + """Reset voltage to the previous value with a scale and offset applied. Parameters ---------- @@ -404,108 +422,103 @@ def reset_voltage_v_before(neuron, voltage_t0, a, b): voltage offset constant """ - return a*(voltage_t0)+b + return a * (voltage_t0) + b + def reset_voltage_zero(neuron, voltage_t0): - """ Reset voltage to zero. """ + """Reset voltage to zero.""" return 0.0 + def reset_threshold_inf(neuron, threshold_t0, voltage_v1): - """ Reset the threshold to instantaneous threshold. """ - return neuron.coeffs['th_inf'] * neuron.th_inf + """Reset the threshold to instantaneous threshold.""" + return neuron.coeffs["th_inf"] * neuron.th_inf + def reset_threshold_three_components(neuron, threshold_t0, voltage_v1, a_spike, b_spike): - '''This method calculates the two components of the threshold: a spike (fast) - component and a voltage (slow) component. The threshold_components vectors are then + """This method calculates the two components of the threshold: a spike (fast) + component and a voltage (slow) component. The threshold_components vectors are then updated so that the traces match the voltage, current, and total threshold traces. The spike component of the threshold decays via an exponential fit specified by the amplitude - a_spike and the time constant b_spike fit via the multiblip data. The voltage component - does not change during the duration of the spike. The - spike component are threshold component are summed along with threshold infinity to - return the total threshold. Note that in the current implementation a_spike is added to - the last value of the threshold_components which means that a_spike is the amplitude after - spike cutting (if there is any). - - Inputs: + a_spike and the time constant b_spike fit via the multiblip data. The voltage component + does not change during the duration of the spike. The + spike component are threshold component are summed along with threshold infinity to + return the total threshold. Note that in the current implementation a_spike is added to + the last value of the threshold_components which means that a_spike is the amplitude after + spike cutting (if there is any). + + Inputs: neuron: class contains attributes of the neuron threshold_t0, voltage_t0: float are not used but are here for consistency with other methods - a_spike: float + a_spike: float amplitude of the exponential decay of spike component of threshold after spike cutting has been implemented. b_spike: float amplitude of the exponential decay of spike component of threshold - + Outputs: Returns: float - the total threshold which is the sum of the spike component of threshold, the voltage + the total threshold which is the sum of the spike component of threshold, the voltage component of threshold and threshold infinity (with it's corresponding coefficient) neuron.threshold_components: dictionary containing a spike: list - vector of spiking component of threshold that corresponds to the voltage, current, + vector of spiking component of threshold that corresponds to the voltage, current, and total threshold traces b_spike: list - vector of voltage component of threshold that corresponds to the voltage, current, + vector of voltage component of threshold that corresponds to the voltage, current, and total threshold traces. - - Note that this function can be changed to use a_spike at the time of the spike and then have the + + Note that this function can be changed to use a_spike at the time of the spike and then have the the spike component plus the residual decay thought the spike. There are benefits and drawbacks to - this. This potential change would be beneficial as it perhaps makes more biological sense for the + this. This potential change would be beneficial as it perhaps makes more biological sense for the threshold to go up at the time of spike if the traces are ever used. Also this would mean that a_spike - would not have to be adjusted thought the spike cutting after the multiblip fit. However the current + would not have to be adjusted thought the spike cutting after the multiblip fit. However the current implementation makes sense in that it is similar to how afterspike currents are implemented. - ''' + """ if neuron.threshold_components is None: - raise Exception('reset should never happen at the beginning of a trace') - - tcs = neuron.threshold_components #for ease of updating - - # note that these values are at the indicie of the time of the spike which is the index right after the voltage crosses - # threshold since the neuron.threshold_components are updated by the dynamics method which is called before the reset. - th_spike=tcs['spike'][-1] #this needs to decay through the spike must be very particular about how many indicies to decay - th_voltage= tcs['voltage'][-1] - + raise Exception("reset should never happen at the beginning of a trace") + + tcs = neuron.threshold_components # for ease of updating + + # note that these values are at the indicie of the time of the spike which is the index right after the voltage crosses + # threshold since the neuron.threshold_components are updated by the dynamics method which is called before the reset. + th_spike = tcs["spike"][ + -1 + ] # this needs to decay through the spike must be very particular about how many indicies to decay + th_voltage = tcs["voltage"][-1] + # calculate spike component decay though spike from time =1 (not zero because zero is already in neuron.threshold_components # via the dynamics method) though the end of the spike cutting - spike_comp_decay=spike_component_of_threshold_exact(th_spike, b_spike, np.arange(1,neuron.spike_cut_length+1)*neuron.dt) #Note that the plus one is that one needs to know the decay and the inital condition for next starting point - - #update neuron.threshold_components via pass by reference. - [tcs['voltage'].append(value) for value in np.ones(neuron.spike_cut_length)*th_voltage] #note that here I don't need the plus one because I am starting from zero - [tcs['spike'].append(value) for value in spike_comp_decay] - + spike_comp_decay = spike_component_of_threshold_exact( + th_spike, b_spike, np.arange(1, neuron.spike_cut_length + 1) * neuron.dt + ) # Note that the plus one is that one needs to know the decay and the inital condition for next starting point + + # update neuron.threshold_components via pass by reference. + [ + tcs["voltage"].append(value) for value in np.ones(neuron.spike_cut_length) * th_voltage + ] # note that here I don't need the plus one because I am starting from zero + [tcs["spike"].append(value) for value in spike_comp_decay] + # add the amplitude of the spike component decay to last value of vector (reseting) - tcs['spike'][-1]=tcs['spike'][-1]+a_spike - - return tcs['spike'][-1] + tcs['voltage'][-1] + neuron.th_inf * neuron.coeffs['th_inf'] + tcs["spike"][-1] = tcs["spike"][-1] + a_spike + + return tcs["spike"][-1] + tcs["voltage"][-1] + neuron.th_inf * neuron.coeffs["th_inf"] -#: The METHOD_LIBRARY constant groups dynamics and reset methods by group name (e.g. 'voltage_dynamics_method'). -#Those groups assign each method in this file a string name. This is used by the GlifNeuron when initializing -#its dynamics and reset methods. +#: The METHOD_LIBRARY constant groups dynamics and reset methods by group name (e.g. 'voltage_dynamics_method'). +# Those groups assign each method in this file a string name. This is used by the GlifNeuron when initializing +# its dynamics and reset methods. METHOD_LIBRARY = { - 'AScurrent_dynamics_method': { - 'exp': dynamics_AScurrent_exp, - 'none': dynamics_AScurrent_none - }, - 'voltage_dynamics_method': { - 'linear_forward_euler': dynamics_voltage_linear_forward_euler - }, - 'threshold_dynamics_method': { - 'spike_component': dynamics_threshold_spike_component, - 'inf': dynamics_threshold_inf, - 'three_components_exact': dynamics_threshold_three_components_exact - }, - 'AScurrent_reset_method': { - 'sum': reset_AScurrent_sum, - 'none': reset_AScurrent_none - }, - 'voltage_reset_method': { - 'v_before': reset_voltage_v_before, - 'zero': reset_voltage_zero - }, - 'threshold_reset_method': { - 'inf': reset_threshold_inf, - 'three_components': reset_threshold_three_components - } + "AScurrent_dynamics_method": {"exp": dynamics_AScurrent_exp, "none": dynamics_AScurrent_none}, + "voltage_dynamics_method": {"linear_forward_euler": dynamics_voltage_linear_forward_euler}, + "threshold_dynamics_method": { + "spike_component": dynamics_threshold_spike_component, + "inf": dynamics_threshold_inf, + "three_components_exact": dynamics_threshold_three_components_exact, + }, + "AScurrent_reset_method": {"sum": reset_AScurrent_sum, "none": reset_AScurrent_none}, + "voltage_reset_method": {"v_before": reset_voltage_v_before, "zero": reset_voltage_zero}, + "threshold_reset_method": {"inf": reset_threshold_inf, "three_components": reset_threshold_three_components}, } diff --git a/allensdk/model/glif/simulate_neuron.py b/allensdk/model/glif/simulate_neuron.py index 5812e0ec45..01f0450824 100644 --- a/allensdk/model/glif/simulate_neuron.py +++ b/allensdk/model/glif/simulate_neuron.py @@ -43,25 +43,30 @@ from allensdk.api.queries.glif_api import GlifApi from allensdk.model.glif.glif_neuron import GlifNeuron -DEFAULT_SPIKE_CUT_VALUE = 0.05 # 50mV +DEFAULT_SPIKE_CUT_VALUE = 0.05 # 50mV -def parse_arguments(): - ''' Use argparse to get required arguments from the command line ''' - parser = argparse.ArgumentParser(description='fit a neuron') - parser.add_argument('--ephys_file', help='ephys file name') - parser.add_argument('--sweeps_file', help='JSON file listing sweep properties') - parser.add_argument('--neuron_config_file', help='neuron configuration JSON file ') - parser.add_argument('--neuronal_model_id', help='id of the neuronal model. Used when downloading sweep properties.', type=int) - parser.add_argument('--output_ephys_file', help='output file name') - parser.add_argument('--log_level', help='log level', default=logging.INFO) - parser.add_argument('--spike_cut_value', help='value to fill in for spike duration', default=DEFAULT_SPIKE_CUT_VALUE, type=float) +def parse_arguments(): + """Use argparse to get required arguments from the command line""" + parser = argparse.ArgumentParser(description="fit a neuron") + + parser.add_argument("--ephys_file", help="ephys file name") + parser.add_argument("--sweeps_file", help="JSON file listing sweep properties") + parser.add_argument("--neuron_config_file", help="neuron configuration JSON file ") + parser.add_argument( + "--neuronal_model_id", help="id of the neuronal model. Used when downloading sweep properties.", type=int + ) + parser.add_argument("--output_ephys_file", help="output file name") + parser.add_argument("--log_level", help="log level", default=logging.INFO) + parser.add_argument( + "--spike_cut_value", help="value to fill in for spike duration", default=DEFAULT_SPIKE_CUT_VALUE, type=float + ) return parser.parse_args() def simulate_sweep(neuron, stimulus, spike_cut_value): - ''' Simulate a neuron given a stimulus and initial conditions. ''' + """Simulate a neuron given a stimulus and initial conditions.""" start_time = time.time() @@ -69,7 +74,7 @@ def simulate_sweep(neuron, stimulus, spike_cut_value): data = neuron.run(stimulus) - voltage = data['voltage'] + voltage = data["voltage"] voltage[np.isnan(voltage)] = spike_cut_value logging.debug("simulation time %f" % (time.time() - start_time)) @@ -78,7 +83,7 @@ def simulate_sweep(neuron, stimulus, spike_cut_value): def load_sweep(file_name, sweep_number): - ''' Load the stimulus for a sweep from file. ''' + """Load the stimulus for a sweep from file.""" logging.debug("loading sweep %d" % sweep_number) load_start_time = time.time() @@ -90,7 +95,7 @@ def load_sweep(file_name, sweep_number): def write_sweep_response(file_name, sweep_number, response, spike_times): - ''' Overwrite the response in a file. ''' + """Overwrite the response in a file.""" logging.debug("writing sweep") @@ -104,7 +109,7 @@ def write_sweep_response(file_name, sweep_number, response, spike_times): def simulate_sweep_from_file(neuron, sweep_number, input_file_name, output_file_name, spike_cut_value): - ''' Load a sweep stimulus, simulate the response, and write it out. ''' + """Load a sweep stimulus, simulate the response, and write it out.""" sweep_start_time = time.time() @@ -115,16 +120,16 @@ def simulate_sweep_from_file(neuron, sweep_number, input_file_name, output_file_ raise # tell the neuron what dt should be for this sweep - neuron.dt = 1.0 / data['sampling_rate'] + neuron.dt = 1.0 / data["sampling_rate"] - sim_data = simulate_sweep(neuron, data['stimulus'], spike_cut_value) + sim_data = simulate_sweep(neuron, data["stimulus"], spike_cut_value) - write_sweep_response(output_file_name, sweep_number, sim_data['voltage'], sim_data['interpolated_spike_times']) + write_sweep_response(output_file_name, sweep_number, sim_data["voltage"], sim_data["interpolated_spike_times"]) - logging.debug("total sweep time %f" % ( time.time() - sweep_start_time )) + logging.debug("total sweep time %f" % (time.time() - sweep_start_time)) -def simulate_neuron(neuron, sweep_numbers, input_file_name, output_file_name, spike_cut_value): +def simulate_neuron(neuron, sweep_numbers, input_file_name, output_file_name, spike_cut_value): start_time = time.time() for sweep_number in sweep_numbers: @@ -132,17 +137,17 @@ def simulate_neuron(neuron, sweep_numbers, input_file_name, output_file_name, sp logging.debug("total elapsed time %f" % (time.time() - start_time)) + def main(): args = parse_arguments() logging.getLogger().setLevel(args.log_level) glif_api = None - if (args.neuron_config_file is None or - args.sweeps_file is None or - args.ephys_file is None): - - assert args.neuronal_model_id is not None, Exception("A neuronal model id is required if no neuron config file, sweeps file, or ephys data file is provided.") + if args.neuron_config_file is None or args.sweeps_file is None or args.ephys_file is None: + assert args.neuronal_model_id is not None, Exception( + "A neuronal model id is required if no neuron config file, sweeps file, or ephys data file is provided." + ) glif_api = GlifApi() glif_api.get_neuronal_model(args.neuronal_model_id) @@ -160,7 +165,7 @@ def main(): if args.ephys_file: ephys_file = args.ephys_file else: - ephys_file = 'stimulus_%d.nwb' % args.neuronal_model_id + ephys_file = "stimulus_%d.nwb" % args.neuronal_model_id if not os.path.exists(ephys_file): logging.info("Downloading stimulus to %s." % ephys_file) @@ -174,15 +179,13 @@ def main(): logging.warning("Overwriting input file data with simulated data in place.") output_ephys_file = ephys_file - neuron = GlifNeuron.from_dict(neuron_config) # filter out test sweeps - sweep_numbers = [ s['sweep_number'] for s in sweeps if s['stimulus_name'] != 'Test' ] + sweep_numbers = [s["sweep_number"] for s in sweeps if s["stimulus_name"] != "Test"] simulate_neuron(neuron, sweep_numbers, ephys_file, output_ephys_file, args.spike_cut_value) - if __name__ == "__main__": main() diff --git a/allensdk/morphology/__init__.py b/allensdk/morphology/__init__.py index 92ceaf67c3..8e51ec55db 100644 --- a/allensdk/morphology/__init__.py +++ b/allensdk/morphology/__init__.py @@ -32,4 +32,4 @@ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. -# \ No newline at end of file +# diff --git a/allensdk/morphology/validate_swc.py b/allensdk/morphology/validate_swc.py index 561863fe53..1df99290ba 100644 --- a/allensdk/morphology/validate_swc.py +++ b/allensdk/morphology/validate_swc.py @@ -36,6 +36,7 @@ import argparse import allensdk.core.swc as swc + def validate_swc(swc_file): """ To be compatible with NEURON, SWC files must have the following properties: @@ -46,33 +47,28 @@ def validate_swc(swc_file): soma_id = swc.Morphology.SOMA morphology = swc.read_swc(swc_file) # verify that there is a single root node - num_soma_nodes = sum([(int(c['type']) == soma_id) - for c in morphology.compartment_list]) + num_soma_nodes = sum([(int(c["type"]) == soma_id) for c in morphology.compartment_list]) if num_soma_nodes != 1: - raise Exception( - "SWC must have single soma compartment. Found: %d" % num_soma_nodes) + raise Exception("SWC must have single soma compartment. Found: %d" % num_soma_nodes) # sanity check root = morphology.root if root is None: raise Exception("Morphology has no root node") # verify that children of the root have max one child - for root_child_id in root['children']: + for root_child_id in root["children"]: root_child = morphology.compartment_index[root_child_id] - num_grand_children = len(root_child['children']) + num_grand_children = len(root_child["children"]) if num_grand_children > 1: - raise Exception("Child of root (%s) has more than one child (%d)" % ( - root_child_id, num_grand_children)) + raise Exception("Child of root (%s) has more than one child (%d)" % (root_child_id, num_grand_children)) # get a list of all of the ids, make sure they are unique while we're at it all_ids = set() for compartment in morphology.compartment_list: iid = int(compartment["id"]) if iid in all_ids: - raise Exception("Compartment ID %s is not unique." % - compartment["id"]) + raise Exception("Compartment ID %s is not unique." % compartment["id"]) pid = int(compartment["parent"]) if iid < pid: - raise Exception( - "Compartment (%d) has a smaller ID that its parent (%d)" % (iid, pid)) + raise Exception("Compartment (%d) has a smaller ID that its parent (%d)" % (iid, pid)) all_ids.add(iid) # sort the ids and make sure there are no gaps @@ -85,13 +81,14 @@ def validate_swc(swc_file): def main(): try: - parser = argparse.ArgumentParser( - "validate an SWC file for use with NEURON") - parser.add_argument('swc_file') + parser = argparse.ArgumentParser("validate an SWC file for use with NEURON") + parser.add_argument("swc_file") args = parser.parse_args() validate_swc(args.swc_file) except Exception as e: print(str(e)) exit(1) + + if __name__ == "__main__": main() diff --git a/allensdk/mouse_connectivity/grid/__init__.py b/allensdk/mouse_connectivity/grid/__init__.py index a968d7ab97..eea2d86961 100755 --- a/allensdk/mouse_connectivity/grid/__init__.py +++ b/allensdk/mouse_connectivity/grid/__init__.py @@ -1,19 +1,9 @@ - from .writers import classic_writer, count_writer, cav_writer from .subimage import CavSubImage, CountSubImage, ClassicSubImage cases = { - 'classic': { - 'writer': classic_writer, - 'subimage': ClassicSubImage - }, - 'count': { - 'writer': count_writer, - 'subimage': CountSubImage - }, - 'cav': { - 'writer': cav_writer, - 'subimage': CavSubImage - } -} \ No newline at end of file + "classic": {"writer": classic_writer, "subimage": ClassicSubImage}, + "count": {"writer": count_writer, "subimage": CountSubImage}, + "cav": {"writer": cav_writer, "subimage": CavSubImage}, +} diff --git a/allensdk/mouse_connectivity/grid/__main__.py b/allensdk/mouse_connectivity/grid/__main__.py index a2c14804ca..6be42d6ade 100755 --- a/allensdk/mouse_connectivity/grid/__main__.py +++ b/allensdk/mouse_connectivity/grid/__main__.py @@ -6,119 +6,120 @@ import argschema import requests -from allensdk.brain_observatory.argschema_utilities import \ - write_or_print_outputs +from allensdk.brain_observatory.argschema_utilities import write_or_print_outputs from . import cases from ._schemas import InputParameters, OutputParameters from .image_series_gridder import ImageSeriesGridder -def get_inputs_from_lims(host, image_series_id, output_root, job_queue, - strategy): - uri = ''.join(''' +def get_inputs_from_lims(host, image_series_id, output_root, job_queue, strategy): + uri = "".join( + """ {}/input_jsons? object_id={}& object_class=ImageSeries& strategy_class={}& job_queue_name={} - '''.format(host, image_series_id, strategy, job_queue).split()) + """.format(host, image_series_id, strategy, job_queue).split() + ) response = requests.get(uri) data = response.json() - if len(data) == 1 and 'error' in data: - raise ValueError('bad request uri: {} ({})'.format(uri, data['error'])) + if len(data) == 1 and "error" in data: + raise ValueError("bad request uri: {} ({})".format(uri, data["error"])) - data['storage_directory'] = os.path.join(output_root, os.path.split( - data['storage_directory'])[-1]) - data['grid_prefix'] = os.path.join(output_root, - os.path.split(data['grid_prefix'])[-1]) - data['accumulator_prefix'] = os.path.join(output_root, os.path.split( - data['accumulator_prefix'])[-1]) + data["storage_directory"] = os.path.join(output_root, os.path.split(data["storage_directory"])[-1]) + data["grid_prefix"] = os.path.join(output_root, os.path.split(data["grid_prefix"])[-1]) + data["accumulator_prefix"] = os.path.join(output_root, os.path.split(data["accumulator_prefix"])[-1]) return data def run_grid(args): try: - case = cases[args['case']] + case = cases[args["case"]] except KeyError: - logging.error('unrecognized case: {}'.format(args['case'])) + logging.error("unrecognized case: {}".format(args["case"])) raise - sub_images = args['sub_images'] + sub_images = args["sub_images"] - input_dimensions = [sub_images[0]['dimensions']['column'], - sub_images[0]['dimensions']['row'], - args['sub_image_count']] + input_dimensions = [ + sub_images[0]["dimensions"]["column"], + sub_images[0]["dimensions"]["row"], + args["sub_image_count"], + ] - input_spacing = [sub_images[0]['spacing']['column'], - sub_images[0]['spacing']['row'], - args['image_series_slice_spacing']] + input_spacing = [ + sub_images[0]["spacing"]["column"], + sub_images[0]["spacing"]["row"], + args["image_series_slice_spacing"], + ] for ii, si in enumerate(sub_images): - del si['dimensions'] - del si['spacing'] - si['polygon_info'] = si['polygons'] - del si['polygons'] - sub_images = sorted(sub_images, key=lambda si: si['specimen_tissue_index']) - logging.info('{} sub images with indices: {}'.format( - len(sub_images), [si['specimen_tissue_index'] for si in sub_images]) + del si["dimensions"] + del si["spacing"] + si["polygon_info"] = si["polygons"] + del si["polygons"] + sub_images = sorted(sub_images, key=lambda si: si["specimen_tissue_index"]) + logging.info( + "{} sub images with indices: {}".format(len(sub_images), [si["specimen_tissue_index"] for si in sub_images]) ) - output_dimensions = [args['reference_dimensions']['slice'], - args['reference_dimensions']['row'], - args['reference_dimensions']['column']] + output_dimensions = [ + args["reference_dimensions"]["slice"], + args["reference_dimensions"]["row"], + args["reference_dimensions"]["column"], + ] - output_spacing = [args['reference_spacing']['slice'], - args['reference_spacing']['row'], - args['reference_spacing']['column']] + output_spacing = [ + args["reference_spacing"]["slice"], + args["reference_spacing"]["row"], + args["reference_spacing"]["column"], + ] - subimage_kwargs = {'cls': case['subimage']} - if args['filter_bit'] is not None: - subimage_kwargs['filter_bit'] = args['filter_bit'] + subimage_kwargs = {"cls": case["subimage"]} + if args["filter_bit"] is not None: + subimage_kwargs["filter_bit"] = args["filter_bit"] gridder = ImageSeriesGridder( in_dims=input_dimensions, in_spacing=input_spacing, out_dims=output_dimensions, out_spacing=output_spacing, - reduce_level=args['reduce_level'], + reduce_level=args["reduce_level"], subimages=sub_images, subimage_kwargs=subimage_kwargs, - nprocesses=args['nprocesses'], - affine_params=args['affine_params'], - dfmfld_path=args['deformation_field_path'] + nprocesses=args["nprocesses"], + affine_params=args["affine_params"], + dfmfld_path=args["deformation_field_path"], ) gridder.setup_subimages() gridder.build_coarse_grids() - writer = case['writer'] - paths = writer(gridder, args['grid_prefix'], args['accumulator_prefix'], - target_spacings=args['target_spacings']) + writer = case["writer"] + paths = writer(gridder, args["grid_prefix"], args["accumulator_prefix"], target_spacings=args["target_spacings"]) - return {'output_file_paths': paths} + return {"output_file_paths": paths} def main(): - logging.basicConfig( - format='%(asctime)s - %(process)s - %(levelname)s - %(message)s') + logging.basicConfig(format="%(asctime)s - %(process)s - %(levelname)s - %(message)s") # TODO replace with argschema implementation of multisource parser remaining_args = sys.argv[1:] input_data = {} - if '--get_inputs_from_lims' in sys.argv: + if "--get_inputs_from_lims" in sys.argv: lims_parser = argparse.ArgumentParser(add_help=False) - lims_parser.add_argument('--host', type=str, default='http://lims2') - lims_parser.add_argument('--job_queue', type=str, default=None) - lims_parser.add_argument('--strategy', type=str, default=None) - lims_parser.add_argument('--image_series_id', type=int, default=None) - lims_parser.add_argument('--output_root', type=str, default=None) - - lims_args, remaining_args = lims_parser.parse_known_args( - remaining_args) - remaining_args = [item for item in remaining_args if - item != '--get_inputs_from_lims'] + lims_parser.add_argument("--host", type=str, default="http://lims2") + lims_parser.add_argument("--job_queue", type=str, default=None) + lims_parser.add_argument("--strategy", type=str, default=None) + lims_parser.add_argument("--image_series_id", type=int, default=None) + lims_parser.add_argument("--output_root", type=str, default=None) + + lims_args, remaining_args = lims_parser.parse_known_args(remaining_args) + remaining_args = [item for item in remaining_args if item != "--get_inputs_from_lims"] input_data = get_inputs_from_lims(**lims_args.__dict__) parser = argschema.ArgSchemaParser( @@ -132,5 +133,5 @@ def main(): write_or_print_outputs(output, parser) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/allensdk/mouse_connectivity/grid/_schemas.py b/allensdk/mouse_connectivity/grid/_schemas.py index 02228dd936..6f29a88ac1 100755 --- a/allensdk/mouse_connectivity/grid/_schemas.py +++ b/allensdk/mouse_connectivity/grid/_schemas.py @@ -3,11 +3,7 @@ from argschema.schemas import DefaultSchema from marshmallow import RAISE -VALID_CASES = ( - 'classic', - 'cav', - 'count' -) +VALID_CASES = ("classic", "cav", "count") class RaisingSchema(DefaultSchema): @@ -50,55 +46,40 @@ class InputParameters(ArgSchema): class Meta: unknown = RAISE - log_level = LogLevel(default='INFO', - description="set the logging level of the module") - case = String(required=True, validate=lambda s: s in VALID_CASES, - help='select a use case to run') - sub_images = Nested(SubImage, required=True, many=True, - help='Sub images composing this image series') - affine_params = List(Float, - help='Parameters of affine image stack to reference ' - 'space transform.') - deformation_field_path = String(required=True, - help='Path to parameters of the ' - 'deformable local transform from ' - 'affine-transformed image stack to ' - 'reference space transform.' - ) - image_series_slice_spacing = Float(required=True, - help='Distance (microns) between ' - 'successive images in this ' - 'series.') - target_spacings = List(Float, required=True, - help='For each volume produced, downsample to ' - 'this isometric resolution') - reference_spacing = Nested(ReferenceSpacing, required=True, - help='Native spacing of reference space (' - 'microns).') - reference_dimensions = Nested(ReferenceDimensions, required=True, - help='Native dimensions of reference space.') - sub_image_count = Int(required=True, help='Expected number of sub images') - grid_prefix = String(required=True, help='Write output grid files here') - accumulator_prefix = String(required=True, - help='If this run produces accumulators, ' - 'write them here.') - storage_directory = String(required=False, - help='Storage directory for this image ' - 'series. Not used') - filter_bit = Int(default=None, allow_none=True, - help='if provided, signals that pixels with this bit ' - 'high have passed the optional post-filter stage') - nprocesses = Int(default=8, help='spawn this many worker subprocesses') - reduce_level = Int(default=0, - help='power of two by which to downsample each input ' - 'axis') + log_level = LogLevel(default="INFO", description="set the logging level of the module") + case = String(required=True, validate=lambda s: s in VALID_CASES, help="select a use case to run") + sub_images = Nested(SubImage, required=True, many=True, help="Sub images composing this image series") + affine_params = List(Float, help="Parameters of affine image stack to reference space transform.") + deformation_field_path = String( + required=True, + help="Path to parameters of the " + "deformable local transform from " + "affine-transformed image stack to " + "reference space transform.", + ) + image_series_slice_spacing = Float( + required=True, help="Distance (microns) between successive images in this series." + ) + target_spacings = List( + Float, required=True, help="For each volume produced, downsample to this isometric resolution" + ) + reference_spacing = Nested(ReferenceSpacing, required=True, help="Native spacing of reference space (microns).") + reference_dimensions = Nested(ReferenceDimensions, required=True, help="Native dimensions of reference space.") + sub_image_count = Int(required=True, help="Expected number of sub images") + grid_prefix = String(required=True, help="Write output grid files here") + accumulator_prefix = String(required=True, help="If this run produces accumulators, write them here.") + storage_directory = String(required=False, help="Storage directory for this image series. Not used") + filter_bit = Int( + default=None, + allow_none=True, + help="if provided, signals that pixels with this bit high have passed the optional post-filter stage", + ) + nprocesses = Int(default=8, help="spawn this many worker subprocesses") + reduce_level = Int(default=0, help="power of two by which to downsample each input axis") class OutputSchema(RaisingSchema): - input_parameters = Nested(InputParameters, - description=("Input parameters the module " - "was run with"), - required=True) + input_parameters = Nested(InputParameters, description=("Input parameters the module was run with"), required=True) class OutputParameters(OutputSchema): diff --git a/allensdk/mouse_connectivity/grid/image_series_gridder.py b/allensdk/mouse_connectivity/grid/image_series_gridder.py index 54cec7209f..aef21b6bcb 100755 --- a/allensdk/mouse_connectivity/grid/image_series_gridder.py +++ b/allensdk/mouse_connectivity/grid/image_series_gridder.py @@ -9,148 +9,130 @@ from .utilities import image_utilities as iu -#============================================================================== +# ============================================================================== class ImageSeriesGridder(object): - @property def transform(self): - - if not hasattr(self, '_transform'): + if not hasattr(self, "_transform"): dfmfld = sitk.ReadImage(str(self.dfmfld_path)) self._transform = iu.build_composite_transform(dfmfld, self.affine_params) del dfmfld - - return self._transform + return self._transform - def __init__(self, in_dims, in_spacing, - out_dims, out_spacing, - reduce_level, - subimages, - subimage_kwargs, - nprocesses, - affine_params, - dfmfld_path): - + def __init__( + self, + in_dims, + in_spacing, + out_dims, + out_spacing, + reduce_level, + subimages, + subimage_kwargs, + nprocesses, + affine_params, + dfmfld_path, + ): self.in_dims = np.array(in_dims) self.in_spacing = np.array(in_spacing) - + self.out_dims = np.array(out_dims) self.out_spacing = np.array(out_spacing) - + self.reduce_level = reduce_level - + self.nprocesses = nprocesses - + self.affine_params = affine_params self.dfmfld_path = dfmfld_path - + self.volumes = {} - + self.subimages = subimages self.subimage_kwargs = subimage_kwargs - - + def set_coarse_grid_parameters(self): - - self.coarse_dims, self.coarse_spacing, self.coarse_grid_radius = \ - iu.compute_coarse_parameters(self.in_dims, self.in_spacing, - self.out_spacing[::-1], - self.reduce_level) - + self.coarse_dims, self.coarse_spacing, self.coarse_grid_radius = iu.compute_coarse_parameters( + self.in_dims, self.in_spacing, self.out_spacing[::-1], self.reduce_level + ) + self.coarse_dims[-1] = self.in_dims[-1] self.coarse_spacing[-1] = self.in_spacing[-1] self.coarse_grid_radius = self.coarse_grid_radius[0] - - + def setup_subimages(self): - - if not hasattr(self, 'coarse_grid_radius'): + if not hasattr(self, "coarse_grid_radius"): self.set_coarse_grid_parameters() - - dc = {'in_dims': self.in_dims[:2], - 'in_spacing': self.in_spacing[:2], - 'coarse_dims': self.coarse_dims[:2], - 'coarse_spacing': self.coarse_spacing[:2], - 'reduce_level': self.reduce_level} + + dc = { + "in_dims": self.in_dims[:2], + "in_spacing": self.in_spacing[:2], + "coarse_dims": self.coarse_dims[:2], + "coarse_spacing": self.coarse_spacing[:2], + "reduce_level": self.reduce_level, + } dc.update(self.subimage_kwargs) - + for si in self.subimages: si.update(dc) - - + def initialize_coarse_volume(self, key, dtype): - logging.info('initializing {0} coarse grid volume'.format(key)) + logging.info("initializing {0} coarse grid volume".format(key)) self.volumes[key] = iu.new_image(self.coarse_dims, self.coarse_spacing, dtype, True) origin = list(self.volumes[key].GetOrigin()) origin[2] = 0 self.volumes[key].SetOrigin(origin) - - + def paste_slice(self, key, index, slice_array): - ''' - ''' - + """ """ + if key not in self.volumes: sitk_type = iu.np_sitk_convert(slice_array.dtype) self.initialize_coarse_volume(key, sitk_type) - - logging.info('resampling data from index {0} into {1} coarse grid volume'.format(index, key)) + + logging.info("resampling data from index {0} into {1} coarse grid volume".format(index, key)) slice_image = iu.image_from_array(slice_array.T, self.coarse_spacing[:2], True) self.volumes[key] = iu.resample_into_volume(slice_image, None, index, self.volumes[key]) - - + def paste_subimage(self, index, output): - '''Inserts planar accumulators into coarse grid volumes - ''' - + """Inserts planar accumulators into coarse grid volumes""" + for key, array in output.items(): self.paste_slice(key, index, array) output[key] = None - + del output - - + def build_coarse_grids(self): - pool = mp.Pool(processes=self.nprocesses) mapper = pool.imap_unordered(run_subimage, self.subimages) - - logging.info('building coarse grids ({} processes)'.format(self.nprocesses)) + + logging.info("building coarse grids ({} processes)".format(self.nprocesses)) for index, output in mapper: - - logging.info('received coarse planar data from subimage at index {0}'.format(index)) + logging.info("received coarse planar data from subimage at index {0}".format(index)) self.paste_subimage(index, output) - def resample_volume(self, key): - logging.info('resampling {0} volume'.format(key)) - self.volumes[key] = iu.resample_volume(self.volumes[key], self.out_dims, - self.out_spacing, None, - self.transform) - + logging.info("resampling {0} volume".format(key)) + self.volumes[key] = iu.resample_volume(self.volumes[key], self.out_dims, self.out_spacing, None, self.transform) def consume_volume(self, key, cb): - logging.info('consuming {0} volume'.format(key)) + logging.info("consuming {0} volume".format(key)) self.resample_volume(key) cb(self.volumes[key]) del self.volumes[key] - def accumulator_to_numpy(self, key, cb): self.resample_volume(key) cb(self.volumes[key]) - logging.info('converting {0} volume to ndarray'.format(key)) + logging.info("converting {0} volume to ndarray".format(key)) self.volumes[key] = sitk.GetArrayFromImage(self.volumes[key]) - def make_ratio_volume(self, num_key, den_key, ratio_key): - '''assume parents numpified - ''' + """assume parents numpified""" self.volumes[ratio_key] = np.divide(self.volumes[num_key], self.volumes[den_key]) self.volumes[ratio_key][np.isnan(self.volumes[ratio_key])] = 0 - diff --git a/allensdk/mouse_connectivity/grid/subimage/__init__.py b/allensdk/mouse_connectivity/grid/subimage/__init__.py index 2a905eba94..601e62b2a5 100755 --- a/allensdk/mouse_connectivity/grid/subimage/__init__.py +++ b/allensdk/mouse_connectivity/grid/subimage/__init__.py @@ -8,22 +8,21 @@ def run_subimage(input_data): - # TODO: remove or fix - logging.basicConfig(format='%(asctime)s - %(process)s - %(levelname)s - %(message)s') - logging.getLogger('').setLevel(logging.INFO) - - index = input_data.pop('specimen_tissue_index') - cls = input_data.pop('cls') - logging.info('handling {0} at index {1}'.format(cls.__name__, index)) + logging.basicConfig(format="%(asctime)s - %(process)s - %(levelname)s - %(message)s") + logging.getLogger("").setLevel(logging.INFO) + + index = input_data.pop("specimen_tissue_index") + cls = input_data.pop("cls") + logging.info("handling {0} at index {1}".format(cls.__name__, index)) si = cls(**input_data) - + try: si.setup_images() si.compute_coarse_planes() except Exception as err: logging.exception(err) raise err - + return index, si.accumulators diff --git a/allensdk/mouse_connectivity/grid/subimage/base_subimage.py b/allensdk/mouse_connectivity/grid/subimage/base_subimage.py index f30371403f..5d1ffe937d 100755 --- a/allensdk/mouse_connectivity/grid/subimage/base_subimage.py +++ b/allensdk/mouse_connectivity/grid/subimage/base_subimage.py @@ -8,49 +8,40 @@ from allensdk.mouse_connectivity.grid.utilities import image_utilities as iu -#============================================================================== +# ============================================================================== class SubImage(object): - @property def pixel_counter(self): - if not hasattr(self, '_pixel_counter'): + if not hasattr(self, "_pixel_counter"): self._pixel_counter = self.make_pixel_counter() return self._pixel_counter - - def __init__(self, reduce_level, in_dims, in_spacing, coarse_spacing, - *args, **kwargs): - + def __init__(self, reduce_level, in_dims, in_spacing, coarse_spacing, *args, **kwargs): self.reduce_level = reduce_level self.in_dims = np.around(in_dims / 2**reduce_level).astype(int) self.in_spacing = in_spacing * 2**reduce_level self.coarse_spacing = coarse_spacing - - self.blocks, self.coarse_dims = iu.grid_image_blocks( - self.in_dims, self.in_spacing, self.coarse_spacing) - + + self.blocks, self.coarse_dims = iu.grid_image_blocks(self.in_dims, self.in_spacing, self.coarse_spacing) + self.images = {} self.accumulators = {} - - + def setup_images(self): pass - - + def compute_coarse_planes(self): raise NotImplementedError() - def binarize(self, image_name): - logging.info('binarizing {0}'.format(image_name)) + logging.info("binarizing {0}".format(image_name)) self.images[image_name][np.nonzero(self.images[image_name])] = 1 - def apply_mask(self, image_name, mask_name, positive=True): - logging.info('applying {0} mask to {1}'.format(mask_name, image_name)) - + logging.info("applying {0} mask to {1}".format(mask_name, image_name)) + mask = self.images[mask_name] if not positive: mask = np.logical_not(mask) @@ -58,211 +49,176 @@ def apply_mask(self, image_name, mask_name, positive=True): self.images[image_name] = np.multiply(self.images[image_name], mask) - def make_pixel_counter(self): def fn(x): - return np.sum(x) * 2 ** ( self.reduce_level + 1) # additional x2 <- is an area - return functools.partial(iu.block_apply, out_shape=self.coarse_dims, - dtype=np.float32, blocks=self.blocks, - fn=fn) + return np.sum(x) * 2 ** (self.reduce_level + 1) # additional x2 <- is an area + return functools.partial( + iu.block_apply, out_shape=self.coarse_dims, dtype=np.float32, blocks=self.blocks, fn=fn + ) def apply_pixel_counter(self, accumulator_name, image): self.accumulators[accumulator_name] = self.pixel_counter(image) - - -#============================================================================== - -class SegmentationSubImage(SubImage): - required_segmentations = [] +# ============================================================================== - def __init__(self, reduce_level, in_dims, in_spacing, coarse_spacing, - segmentation_paths, *args, **kwargs): +class SegmentationSubImage(SubImage): + required_segmentations = [] - super(SegmentationSubImage, self).__init__( - reduce_level, in_dims, in_spacing, coarse_spacing, *args, **kwargs) + def __init__(self, reduce_level, in_dims, in_spacing, coarse_spacing, segmentation_paths, *args, **kwargs): + super(SegmentationSubImage, self).__init__(reduce_level, in_dims, in_spacing, coarse_spacing, *args, **kwargs) self.segmentation_paths = segmentation_paths - - if 'filter_bit' in kwargs and kwargs['filter_bit'] is not None: - self.filter = 2 ** kwargs['filter_bit'] + if "filter_bit" in kwargs and kwargs["filter_bit"] is not None: + self.filter = 2 ** kwargs["filter_bit"] def setup_images(self): super(SegmentationSubImage, self).setup_images() self.get_segmentation() - def get_segmentation(self): - for name in self.__class__.required_segmentations: self.read_segmentation_image(name) self.process_segmentation() - def process_segmentation(self): pass - - def extract_signal_from_segmentation(self, segmentation_name='segmentation', - signal_name='signal'): - ''' + def extract_signal_from_segmentation(self, segmentation_name="segmentation", signal_name="signal"): + """ Notes ----- - Currently, the segmentation uses a series of codes to map 8-bit values - onto meaningful classifications. The code for signal pixels is a 1 in + Currently, the segmentation uses a series of codes to map 8-bit values + onto meaningful classifications. The code for signal pixels is a 1 in the leftmost bit. - In some cases, bit 5 indicates that the pixel was not removed in a + In some cases, bit 5 indicates that the pixel was not removed in a posfiltering process. Optionally, this postfilter can be applied in gridding. - ''' + """ - logging.info('extracting {0} mask'.format(signal_name)) + logging.info("extracting {0} mask".format(signal_name)) self.images[signal_name] = np.right_shift(self.images[segmentation_name], 7) - + signal_count = np.count_nonzero(self.images[signal_name]) - logging.info('{0} signal pixels were detected'.format(signal_count)) + logging.info("{0} signal pixels were detected".format(signal_count)) - if hasattr(self, 'filter'): + if hasattr(self, "filter"): filter_mask = np.bitwise_and(self.images[segmentation_name], self.filter) self.images[signal_name][filter_mask == 0] = 0 filter_count = np.count_nonzero(self.images[signal_name]) - logging.info('{0} / {1} pixels passed the signal filter'.format(filter_count, signal_count)) - - - + logging.info("{0} / {1} pixels passed the signal filter".format(filter_count, signal_count)) - def extract_injection_from_segmentation(self, segmentation_name='segmentation', - injection_name='injection'): - ''' + def extract_injection_from_segmentation(self, segmentation_name="segmentation", injection_name="injection"): + """ Notes ----- - Currently, the segmentation uses a series of codes to map 8-bit values - onto meaningful classifications. The code for signal pixels is a 1 in + Currently, the segmentation uses a series of codes to map 8-bit values + onto meaningful classifications. The code for signal pixels is a 1 in at least one of of the 5 rightmost bits. - ''' + """ - logging.info('extracting {0} mask'.format(injection_name)) + logging.info("extracting {0} mask".format(injection_name)) self.images[injection_name] = np.bitwise_and(self.images[segmentation_name], 31) self.images[injection_name][self.images[injection_name] > 0] = 1 - - def read_segmentation_image(self, segmentation_name='segmentation'): - ''' + def read_segmentation_image(self, segmentation_name="segmentation"): + """ Notes ----- - We downsample in memory rather than using the jp2 pyramid because the + We downsample in memory rather than using the jp2 pyramid because the segmentation is a label image. - ''' - + """ + path = self.segmentation_paths[segmentation_name] - logging.info('loading {} from {}'.format(segmentation_name, path)) + logging.info("loading {} from {}".format(segmentation_name, path)) segmentation = iu.read_segmentation_image(path) - logging.info('{} shape: {}'.format(segmentation_name, segmentation.shape)) + logging.info("{} shape: {}".format(segmentation_name, segmentation.shape)) if self.reduce_level > 0: - logging.info('downsampling {0}'.format(segmentation_name)) + logging.info("downsampling {0}".format(segmentation_name)) segmentation = zoom(segmentation, 1.0 / 2**self.reduce_level, order=0) - + self.images[segmentation_name] = segmentation -#============================================================================== +# ============================================================================== class IntensitySubImage(SubImage): - required_intensities = [] - - def __init__(self, reduce_level, in_dims, in_spacing, coarse_spacing, intensity_paths, - *args, **kwargs): - - super(IntensitySubImage, self).__init__(reduce_level, in_dims, - in_spacing, coarse_spacing, *args, **kwargs) + def __init__(self, reduce_level, in_dims, in_spacing, coarse_spacing, intensity_paths, *args, **kwargs): + super(IntensitySubImage, self).__init__(reduce_level, in_dims, in_spacing, coarse_spacing, *args, **kwargs) self.intensity_paths = intensity_paths - def get_intensity(self): - for name in self.__class__.required_intensities: info = self.intensity_paths[name] - logging.info('loading {} intensities from {}'.format(name, info['path'])) - - self.images[name] = iu.read_intensity_image(info['path'], self.reduce_level, info['channel']) - logging.info('loaded {} intensities to image of shape: {}'.format(name, self.images[name].shape)) - + logging.info("loading {} intensities from {}".format(name, info["path"])) + self.images[name] = iu.read_intensity_image(info["path"], self.reduce_level, info["channel"]) + logging.info("loaded {} intensities to image of shape: {}".format(name, self.images[name].shape)) def setup_images(self): super(IntensitySubImage, self).setup_images() self.get_intensity() -#============================================================================== +# ============================================================================== - -class PolygonSubImage(SubImage): +class PolygonSubImage(SubImage): required_polys = [] optional_polys = [] - def __init__(self, reduce_level, in_dims, in_spacing, coarse_spacing, - polygon_info, *args, **kwargs): - - super(PolygonSubImage, self).__init__( - reduce_level, in_dims, in_spacing, coarse_spacing, *args, **kwargs) - + def __init__(self, reduce_level, in_dims, in_spacing, coarse_spacing, polygon_info, *args, **kwargs): + super(PolygonSubImage, self).__init__(reduce_level, in_dims, in_spacing, coarse_spacing, *args, **kwargs) + self.polygon_info = polygon_info - - + def setup_images(self): super(PolygonSubImage, self).setup_images() self.get_polygons() - def get_polygons(self): - polygon_keys = [] polygon_keys.extend(self.__class__.optional_polys) polygon_keys.extend(self.__class__.required_polys) for key in polygon_keys: - logging.info('rasterizing {0} polygon'.format(key)) - + logging.info("rasterizing {0} polygon".format(key)) + points = self.polygon_info[key] - self.images[key] = iu.rasterize_polygons(self.in_dims.astype(int)[::-1], - [1.0 / 2**self.reduce_level, - 1.0 / 2**self.reduce_level], - points).T - - -#============================================================================== - - + self.images[key] = iu.rasterize_polygons( + self.in_dims.astype(int)[::-1], [1.0 / 2**self.reduce_level, 1.0 / 2**self.reduce_level], points + ).T + + +# ============================================================================== + + def run_subimage(input_data): - # TODO: not propagating the log level from the calling thread - logging.getLogger('').setLevel(logging.INFO) - - index = input_data.pop('specimen_tissue_index') - cls = input_data.pop('cls') - logging.info('handling {0} at index {1}'.format(cls.__name__, index)) + logging.getLogger("").setLevel(logging.INFO) + + index = input_data.pop("specimen_tissue_index") + cls = input_data.pop("cls") + logging.info("handling {0} at index {1}".format(cls.__name__, index)) si = cls(**input_data) - + si.setup_images() si.compute_coarse_planes() diff --git a/allensdk/mouse_connectivity/grid/subimage/cav_subimage.py b/allensdk/mouse_connectivity/grid/subimage/cav_subimage.py index af9bd4889b..628d3c3b8d 100755 --- a/allensdk/mouse_connectivity/grid/subimage/cav_subimage.py +++ b/allensdk/mouse_connectivity/grid/subimage/cav_subimage.py @@ -5,27 +5,24 @@ from .base_subimage import PolygonSubImage -#============================================================================== +# ============================================================================== class CavSubImage(PolygonSubImage): - - required_polys = ['missing_tile', 'cav_tracer'] - + required_polys = ["missing_tile", "cav_tracer"] def compute_coarse_planes(self): + nonmissing = np.logical_not(self.images["missing_tile"]) + del self.images["missing_tile"] - nonmissing = np.logical_not(self.images['missing_tile']) - del self.images['missing_tile'] - - self.apply_pixel_counter('sum_pixels', nonmissing) + self.apply_pixel_counter("sum_pixels", nonmissing) - cav_nonmissing = np.multiply(self.images['cav_tracer'], nonmissing) + cav_nonmissing = np.multiply(self.images["cav_tracer"], nonmissing) del nonmissing - self.apply_pixel_counter('cav_tracer', cav_nonmissing) + self.apply_pixel_counter("cav_tracer", cav_nonmissing) del cav_nonmissing del self.images -#============================================================================== +# ============================================================================== diff --git a/allensdk/mouse_connectivity/grid/subimage/classic_subimage.py b/allensdk/mouse_connectivity/grid/subimage/classic_subimage.py index 53dc5152d2..aef8023846 100755 --- a/allensdk/mouse_connectivity/grid/subimage/classic_subimage.py +++ b/allensdk/mouse_connectivity/grid/subimage/classic_subimage.py @@ -6,108 +6,111 @@ from .base_subimage import PolygonSubImage, SegmentationSubImage, IntensitySubImage -#============================================================================== +# ============================================================================== class ClassicSubImage(IntensitySubImage, SegmentationSubImage, PolygonSubImage): - - required_polys = ['missing_tile', 'no_signal', 'aav_exclusion'] - optional_polys = ['aav_tracer'] - required_segmentations = ['segmentation'] - required_intensities = ['green'] - - - def __init__(self, reduce_level, in_dims, in_spacing, coarse_spacing, - polygon_info, segmentation_paths, intensity_paths, - injection_polygon_key='aav_tracer', - *args, **kwargs): - + required_polys = ["missing_tile", "no_signal", "aav_exclusion"] + optional_polys = ["aav_tracer"] + required_segmentations = ["segmentation"] + required_intensities = ["green"] + + def __init__( + self, + reduce_level, + in_dims, + in_spacing, + coarse_spacing, + polygon_info, + segmentation_paths, + intensity_paths, + injection_polygon_key="aav_tracer", + *args, + **kwargs, + ): super(ClassicSubImage, self).__init__( - reduce_level, in_dims, in_spacing, coarse_spacing, - polygon_info=polygon_info, - segmentation_paths=segmentation_paths, - intensity_paths=intensity_paths, - *args, **kwargs) + reduce_level, + in_dims, + in_spacing, + coarse_spacing, + polygon_info=polygon_info, + segmentation_paths=segmentation_paths, + intensity_paths=intensity_paths, + *args, + **kwargs, + ) self.injection_polygon_key = injection_polygon_key - def process_segmentation(self): + self.apply_mask("segmentation", "missing_tile", False) + + self.extract_signal_from_segmentation(signal_name="projection") + + self.apply_mask("projection", "no_signal", False) + del self.images["no_signal"] - self.apply_mask('segmentation', 'missing_tile', False) - - self.extract_signal_from_segmentation(signal_name='projection') - - self.apply_mask('projection', 'no_signal', False) - del self.images['no_signal'] - if self.injection_polygon_key in self.images: - logging.info('reading injection from rasterized {} polygon'.format(self.injection_polygon_key)) - self.images['injection'] = self.images[self.injection_polygon_key] + logging.info("reading injection from rasterized {} polygon".format(self.injection_polygon_key)) + self.images["injection"] = self.images[self.injection_polygon_key] del self.images[self.injection_polygon_key] else: self.extract_injection_from_segmentation() - del self.images['segmentation'] - logging.info('injection pixel count: {}'.format(np.count_nonzero(self.images['injection']))) - - self.binarize('projection') - self.binarize('injection') - - - def compute_coarse_planes(self): + del self.images["segmentation"] + logging.info("injection pixel count: {}".format(np.count_nonzero(self.images["injection"]))) + self.binarize("projection") + self.binarize("injection") + + def compute_coarse_planes(self): # do these in batches to minimize peak memory usage self.compute_intensity() self.compute_injection() self.compute_projection() self.compute_sum_pixels() - + del self.images - def compute_intensity(self): - logging.info('computing green accumulators') - - self.apply_pixel_counter('sum_pixel_intensities', self.images['green']) + logging.info("computing green accumulators") - injection_intensity = np.multiply(self.images['green'], self.images['injection']) - self.apply_pixel_counter('injection_sum_pixel_intensities', injection_intensity) + self.apply_pixel_counter("sum_pixel_intensities", self.images["green"]) + + injection_intensity = np.multiply(self.images["green"], self.images["injection"]) + self.apply_pixel_counter("injection_sum_pixel_intensities", injection_intensity) del injection_intensity - - self.images['green'][self.images['projection'] == 0] = 0 - self.apply_pixel_counter('sum_projecting_pixel_intensities', self.images['green']) - - self.images['green'][self.images['injection'] == 0] = 0 - self.apply_pixel_counter('injectionsum_projecting_pixel_intensities', self.images['green']) - del self.images['green'] + self.images["green"][self.images["projection"] == 0] = 0 + self.apply_pixel_counter("sum_projecting_pixel_intensities", self.images["green"]) + + self.images["green"][self.images["injection"] == 0] = 0 + self.apply_pixel_counter("injectionsum_projecting_pixel_intensities", self.images["green"]) + del self.images["green"] def compute_injection(self): - logging.info('computing injection accumulators') - - self.apply_pixel_counter('injection_sum_pixels', self.images['injection']) + logging.info("computing injection accumulators") - injection_projecting_pixels = np.logical_and(self.images['injection'], self.images['projection']) - self.apply_pixel_counter('injection_sum_projecting_pixels', injection_projecting_pixels) - del injection_projecting_pixels + self.apply_pixel_counter("injection_sum_pixels", self.images["injection"]) - del self.images['injection'] + injection_projecting_pixels = np.logical_and(self.images["injection"], self.images["projection"]) + self.apply_pixel_counter("injection_sum_projecting_pixels", injection_projecting_pixels) + del injection_projecting_pixels + del self.images["injection"] def compute_projection(self): - logging.info('computing projection accumulators') - - self.apply_pixel_counter('sum_projecting_pixels', self.images['projection']) - del self.images['projection'] - - + logging.info("computing projection accumulators") + + self.apply_pixel_counter("sum_projecting_pixels", self.images["projection"]) + del self.images["projection"] + def compute_sum_pixels(self): - logging.info('computing sum pixel accumulators') + logging.info("computing sum pixel accumulators") + + self.apply_pixel_counter("sum_pixels", np.logical_not(self.images["missing_tile"])) - self.apply_pixel_counter('sum_pixels', np.logical_not(self.images['missing_tile'])) - - if 'aav_exclusion' in self.images: - self.apply_pixel_counter('aav_exclusion_sum_pixels', self.images['aav_exclusion']) + if "aav_exclusion" in self.images: + self.apply_pixel_counter("aav_exclusion_sum_pixels", self.images["aav_exclusion"]) -#============================================================================== +# ============================================================================== diff --git a/allensdk/mouse_connectivity/grid/subimage/count_subimage.py b/allensdk/mouse_connectivity/grid/subimage/count_subimage.py index d93601d0a2..4c5e9d9e83 100755 --- a/allensdk/mouse_connectivity/grid/subimage/count_subimage.py +++ b/allensdk/mouse_connectivity/grid/subimage/count_subimage.py @@ -7,71 +7,77 @@ class CountSubImage(SegmentationSubImage, PolygonSubImage): - - required_polys = ['missing_tile', 'no_signal', 'aav_exclusion'] - required_segmentations = ['segmentation'] - - def __init__(self, reduce_level, in_dims, in_spacing, coarse_spacing, - polygon_info, segmentation_paths, injection_polygon_key='aav_tracer', *args, **kwargs): - - super(CountSubImage, self).__init__(reduce_level, in_dims, in_spacing, - coarse_spacing, - polygon_info=polygon_info, - segmentation_paths=segmentation_paths, - *args, **kwargs) + required_polys = ["missing_tile", "no_signal", "aav_exclusion"] + required_segmentations = ["segmentation"] + + def __init__( + self, + reduce_level, + in_dims, + in_spacing, + coarse_spacing, + polygon_info, + segmentation_paths, + injection_polygon_key="aav_tracer", + *args, + **kwargs, + ): + super(CountSubImage, self).__init__( + reduce_level, + in_dims, + in_spacing, + coarse_spacing, + polygon_info=polygon_info, + segmentation_paths=segmentation_paths, + *args, + **kwargs, + ) self.injection_polygon_key = injection_polygon_key - def process_segmentation(self): + self.apply_mask("segmentation", "missing_tile", False) - self.apply_mask('segmentation', 'missing_tile', False) - - self.extract_signal_from_segmentation(signal_name='projection') + self.extract_signal_from_segmentation(signal_name="projection") - self.apply_mask('projection', 'no_signal', False) - del self.images['no_signal'] + self.apply_mask("projection", "no_signal", False) + del self.images["no_signal"] if self.injection_polygon_key in self.images: - self.images['injection'] = self.images[self.injection_polygon_key] + self.images["injection"] = self.images[self.injection_polygon_key] del self.images[injection_polygon_key] else: - self.extract_injection_from_segmentation() - del self.images['segmentation'] - - self.binarize('projection') - self.binarize('injection') + self.extract_injection_from_segmentation() + del self.images["segmentation"] + self.binarize("projection") + self.binarize("injection") def compute_injection(self): - logging.info('computing injection accumulators') + logging.info("computing injection accumulators") - self.apply_pixel_counter('injection_sum_pixels', self.images['injection']) + self.apply_pixel_counter("injection_sum_pixels", self.images["injection"]) - injection_projecting_pixels = np.logical_and(self.images['injection'], self.images['projection']) - self.apply_pixel_counter('injection_sum_projecting_pixels', injection_projecting_pixels) + injection_projecting_pixels = np.logical_and(self.images["injection"], self.images["projection"]) + self.apply_pixel_counter("injection_sum_projecting_pixels", injection_projecting_pixels) del injection_projecting_pixels - del self.images['injection'] - + del self.images["injection"] def compute_projection(self): - logging.info('computing projection accumulators') - - self.apply_pixel_counter('sum_projecting_pixels', self.images['projection']) - del self.images['projection'] + logging.info("computing projection accumulators") + self.apply_pixel_counter("sum_projecting_pixels", self.images["projection"]) + del self.images["projection"] def compute_sum_pixels(self): - logging.info('computing sum pixel accumulators') + logging.info("computing sum pixel accumulators") - self.apply_pixel_counter('sum_pixels', np.logical_not(self.images['missing_tile'])) - - if 'aav_exclusion' in self.images: - self.apply_pixel_counter('aav_exclusion_sum_pixels', self.images['aav_exclusion']) + self.apply_pixel_counter("sum_pixels", np.logical_not(self.images["missing_tile"])) + if "aav_exclusion" in self.images: + self.apply_pixel_counter("aav_exclusion_sum_pixels", self.images["aav_exclusion"]) def compute_coarse_planes(self): - self.compute_injection() self.compute_projection() self.compute_sum_pixels() diff --git a/allensdk/mouse_connectivity/grid/utilities/downsampling_utilities.py b/allensdk/mouse_connectivity/grid/utilities/downsampling_utilities.py index 1e63eacb01..aae6f2cbe5 100755 --- a/allensdk/mouse_connectivity/grid/utilities/downsampling_utilities.py +++ b/allensdk/mouse_connectivity/grid/utilities/downsampling_utilities.py @@ -9,7 +9,6 @@ def downsample_average(volume, current_spacing, target_spacing): - factor = target_spacing / current_spacing if factor == 1: @@ -20,33 +19,28 @@ def downsample_average(volume, current_spacing, target_spacing): elif factor - np.floor(factor) == 0.5: volume = window_average(volume, factor) else: - raise ValueError('voxels cannot be unevenly split!') + raise ValueError("voxels cannot be unevenly split!") return volume def block_average(volume, factor): - logging.info('downsampling by block averaging with a factor of {0}'.format(factor)) + logging.info("downsampling by block averaging with a factor of {0}".format(factor)) factor = np.around(factor).astype(int) return block_reduce(volume, tuple([factor, factor, factor]), np.mean, 0) def apply_divisions(image, window_size): - for axis in range(image.ndim): - - slc = tuple([ - slice(window_size-1, None, window_size) - if ii == axis - else slice(0, None) - for ii in range(image.ndim) - ]) + slc = tuple( + [slice(window_size - 1, None, window_size) if ii == axis else slice(0, None) for ii in range(image.ndim)] + ) image[slc] = image[slc] / 2 def window_average(volume, factor): - logging.info('downsampling by window averaging with a factor of {0}'.format(factor)) + logging.info("downsampling by window averaging with a factor of {0}".format(factor)) volume = volume.copy() window_size = np.ceil(factor).astype(int) @@ -60,21 +54,22 @@ def window_average(volume, factor): def conv(image, factor, window_size): - kernel = np.ones([window_size for ii in image.shape]) - return convolve(image, kernel, mode='constant', cval=0.0) / factor ** image.ndim + kernel = np.ones([window_size for ii in image.shape]) + return convolve(image, kernel, mode="constant", cval=0.0) / factor**image.ndim def extract(image, factor, window_size, window_step, output_shape): - - output = np.zeros( output_shape ) - - for case in it.product(*([[0, 1]] * image.ndim)): - - inp = tuple([slice(window_size - 2, None, window_step) - if not ii else slice(window_size, None, window_step) for ii in case]) + output = np.zeros(output_shape) + + for case in it.product(*([[0, 1]] * image.ndim)): + inp = tuple( + [ + slice(window_size - 2, None, window_step) if not ii else slice(window_size, None, window_step) + for ii in case + ] + ) out = tuple([slice(0, None, 2) if not ii else slice(1, None, 2) for ii in case]) output[out] = image[inp] return output - diff --git a/allensdk/mouse_connectivity/grid/utilities/image_utilities.py b/allensdk/mouse_connectivity/grid/utilities/image_utilities.py index c78a25578d..5f01a54709 100755 --- a/allensdk/mouse_connectivity/grid/utilities/image_utilities.py +++ b/allensdk/mouse_connectivity/grid/utilities/image_utilities.py @@ -28,8 +28,7 @@ def set_image_spacing(image, spacing, origin=True): - ''' - ''' + """ """ spacing = np.array(spacing) @@ -40,8 +39,7 @@ def set_image_spacing(image, spacing, origin=True): def new_image(dims, spacing, dtype, origin=True): - ''' - ''' + """ """ if len(dims) == 2: image = sitk.Image(dims[0], dims[1], dtype) @@ -53,8 +51,7 @@ def new_image(dims, spacing, dtype, origin=True): def image_from_array(array, spacing, origin=True): - ''' - ''' + """ """ image = sitk.GetImageFromArray(array) set_image_spacing(image, spacing, origin) @@ -63,15 +60,13 @@ def image_from_array(array, spacing, origin=True): def np_sitk_convert(np_type): - ''' - ''' + """ """ return NUMPY_SITK_TYPE_LOOKUP[np_type] def sitk_np_convert(sitk_type): - ''' - ''' + """ """ return SITK_NUMPY_TYPE_LOOKUP[sitk_type] @@ -80,8 +75,7 @@ def sitk_np_convert(sitk_type): def compute_coarse_parameters(in_dims, in_spacing, out_spacing, reduce_level): - ''' - ''' + """ """ reduce_factor = pow(2, reduce_level) fradius = np.divide(out_spacing, in_spacing) / 2.0 / reduce_factor @@ -90,51 +84,39 @@ def compute_coarse_parameters(in_dims, in_spacing, out_spacing, reduce_level): coarse_grid_size = (coarse_grid_radius * 2 + 1) * reduce_factor coarse_grid_spacing = np.multiply(in_spacing, coarse_grid_size) - coarse_grid_dims = np.ceil( - np.divide(in_dims, coarse_grid_size) - ).astype(int) + coarse_grid_dims = np.ceil(np.divide(in_dims, coarse_grid_size)).astype(int) return coarse_grid_dims, coarse_grid_spacing, coarse_grid_radius def block_apply(in_image, out_shape, dtype, blocks, fn): - ''' - ''' + """ """ out_image = np.zeros(out_shape, dtype=dtype) for ii, row_block in enumerate(blocks[0]): for jj, col_block in enumerate(blocks[1]): - - out_image[ii, jj] = fn(in_image[row_block[0]:row_block[1], - col_block[0]:col_block[1]]) + out_image[ii, jj] = fn(in_image[row_block[0] : row_block[1], col_block[0] : col_block[1]]) return out_image def grid_image_blocks(in_shape, in_spacing, out_spacing): - ''' - ''' + """ """ blocks = [] out_shape = [] for dim in range(len(in_shape)): - in_px_centers = np.arange(in_spacing[dim]*0.5, - in_shape[dim]*in_spacing[dim], - in_spacing[dim]) + in_px_centers = np.arange(in_spacing[dim] * 0.5, in_shape[dim] * in_spacing[dim], in_spacing[dim]) - out_px_edges = np.arange(out_spacing[dim], - (in_shape[dim]-0.5)*in_spacing[dim], - out_spacing[dim]) + out_px_edges = np.arange(out_spacing[dim], (in_shape[dim] - 0.5) * in_spacing[dim], out_spacing[dim]) dig = np.digitize(in_px_centers, out_px_edges) inds = np.where(np.diff(dig) > 0)[0] + 1 inds = [0] + inds.tolist() + [in_shape[dim]] - dim_blocks = [ - (int(inds[i]), int(inds[i+1])) for i in range(len(inds)-1) - ] + dim_blocks = [(int(inds[i]), int(inds[i + 1])) for i in range(len(inds) - 1)] out_shape.append(len(dim_blocks)) blocks.append(dim_blocks) @@ -146,16 +128,10 @@ def grid_image_blocks(in_shape, in_spacing, out_spacing): def rasterize_polygons(shape, scale, polys): - canvas = np.zeros(shape, dtype=np.uint8) for points in polys: - - rpts = np.array([ - int(np.around(item[1] * scale[1])) for item in points - ]) - cpts = np.array([ - int(np.around(item[0] * scale[0])) for item in points - ]) + rpts = np.array([int(np.around(item[1] * scale[1])) for item in points]) + cpts = np.array([int(np.around(item[0] * scale[0])) for item in points]) poly = polygon(rpts, cpts) canvas[poly] = 1 @@ -167,8 +143,7 @@ def rasterize_polygons(shape, scale, polys): def resample_into_volume(image, transform, z, vol, dtype=sitk.sitkFloat32): - ''' - ''' + """ """ if transform is None: transform = sitk.Transform() @@ -179,8 +154,7 @@ def resample_into_volume(image, transform, z, vol, dtype=sitk.sitkFloat32): def build_affine_transform(aff_params): - ''' - ''' + """ """ xfm = sitk.AffineTransform(3) xfm.SetParameters(aff_params) @@ -189,11 +163,9 @@ def build_affine_transform(aff_params): def build_composite_transform(dfmfield=None, aff_params=None): - ''' - ''' + """ """ - if dfmfield is not None and \ - dfmfield.GetPixelIDValue() != sitk.sitkVectorFloat64: + if dfmfield is not None and dfmfield.GetPixelIDValue() != sitk.sitkVectorFloat64: dfmfield = sitk.Cast(dfmfield, sitk.sitkVectorFloat64) if dfmfield is None and aff_params is None: @@ -212,8 +184,7 @@ def build_composite_transform(dfmfield=None, aff_params=None): def resample_volume(volume, dims, spacing, interpolator=None, transform=None): - ''' - ''' + """ """ if transform is None: transform = sitk.Transform() @@ -224,27 +195,20 @@ def resample_volume(volume, dims, spacing, interpolator=None, transform=None): return sitk.Resample(volume, ref, transform, interpolator) -def write_volume(volume, - name, - prefix=None, - specify_resolution=None, - extension='.nrrd', - paths=None): - +def write_volume(volume, name, prefix=None, specify_resolution=None, extension=".nrrd", paths=None): if prefix is None: path = name else: path = os.path.join(prefix, name) if specify_resolution is not None: - if isinstance(specify_resolution, (float, np.floating)) and \ - specify_resolution % 1.0 == 0: + if isinstance(specify_resolution, (float, np.floating)) and specify_resolution % 1.0 == 0: specify_resolution = int(specify_resolution) - path = path + '_{0}'.format(specify_resolution) + path = path + "_{0}".format(specify_resolution) path = path + extension - logging.info('writing {0} volume to {1}'.format(name, path)) + logging.info("writing {0} volume to {1}".format(name, path)) Manifest.safe_make_parent_dirs(path) volume.SetOrigin([0, 0, 0]) sitk.WriteImage(volume, str(path), True) @@ -255,13 +219,13 @@ def write_volume(volume, def __read_segmentation_image_with_kakadu(path): if not os.path.exists(path): - raise OSError('file not found at {}'.format(path)) + raise OSError("file not found at {}".format(path)) return jpeg_twok.read(path).T def __read_intensity_image_with_kakadu(path, reduce_level, channel): if not os.path.exists(path): - raise OSError('file not found at {}'.format(path)) + raise OSError("file not found at {}".format(path)) return jpeg_twok.read(path, reduce_level, channel).T @@ -281,11 +245,13 @@ def __read_intensity_image_with_glymur(path): # however, since it is proprietary, we can't share it # alongside the allensdk, # so we default to glymur (a python openjpeg) for external users. - sys.path.append('/shared/bioapps/itk/itk_shared/jp2/build') + sys.path.append("/shared/bioapps/itk/itk_shared/jp2/build") import jpeg_twok + read_segmentation_image = __read_segmentation_image_with_kakadu read_intensity_image = __read_intensity_image_with_kakadu except failed_import: import glymur + read_segmentation_image = __read_segmentation_image_with_glymur read_intensity_image = __read_intensity_image_with_glymur diff --git a/allensdk/mouse_connectivity/grid/writers/__init__.py b/allensdk/mouse_connectivity/grid/writers/__init__.py index aad37882e2..c87c603b9a 100755 --- a/allensdk/mouse_connectivity/grid/writers/__init__.py +++ b/allensdk/mouse_connectivity/grid/writers/__init__.py @@ -9,110 +9,186 @@ def count_writer(gridder, grid_prefix, accumulator_prefix, target_spacings, **kwargs): paths = [] - cb = functools.partial(write_volume, name='sum_pixels', prefix=accumulator_prefix, paths=paths) - gridder.accumulator_to_numpy('sum_pixels', cb) - - ratio_and_pyramid(gridder, 'sum_projecting_pixels', 'sum_pixels', 'projection_density', - accumulator_prefix, grid_prefix, target_spacings, paths=paths) - - ratio_and_pyramid(gridder, 'injection_sum_projecting_pixels', 'sum_pixels', 'injection_density', - accumulator_prefix, grid_prefix, target_spacings, paths=paths) - - ratio_and_pyramid(gridder, 'injection_sum_pixels', 'sum_pixels', 'injection_fraction', - accumulator_prefix, grid_prefix, target_spacings, paths=paths) - - ratio_and_pyramid(gridder, 'aav_exclusion_sum_pixels', 'sum_pixels', 'aav_exclusion_fraction', - accumulator_prefix, grid_prefix, target_spacings, paths=paths) - - gridder.volumes['data_mask'] = gridder.volumes['sum_pixels'] / np.amax(gridder.volumes['sum_pixels']) - del gridder.volumes['sum_pixels'] - gridder.volumes['data_mask'] = gridder.volumes['data_mask'].round() - handle_pyramid(gridder, 'data_mask', target_spacings, grid_prefix, paths=paths) + cb = functools.partial(write_volume, name="sum_pixels", prefix=accumulator_prefix, paths=paths) + gridder.accumulator_to_numpy("sum_pixels", cb) + + ratio_and_pyramid( + gridder, + "sum_projecting_pixels", + "sum_pixels", + "projection_density", + accumulator_prefix, + grid_prefix, + target_spacings, + paths=paths, + ) + + ratio_and_pyramid( + gridder, + "injection_sum_projecting_pixels", + "sum_pixels", + "injection_density", + accumulator_prefix, + grid_prefix, + target_spacings, + paths=paths, + ) + + ratio_and_pyramid( + gridder, + "injection_sum_pixels", + "sum_pixels", + "injection_fraction", + accumulator_prefix, + grid_prefix, + target_spacings, + paths=paths, + ) + + ratio_and_pyramid( + gridder, + "aav_exclusion_sum_pixels", + "sum_pixels", + "aav_exclusion_fraction", + accumulator_prefix, + grid_prefix, + target_spacings, + paths=paths, + ) + + gridder.volumes["data_mask"] = gridder.volumes["sum_pixels"] / np.amax(gridder.volumes["sum_pixels"]) + del gridder.volumes["sum_pixels"] + gridder.volumes["data_mask"] = gridder.volumes["data_mask"].round() + handle_pyramid(gridder, "data_mask", target_spacings, grid_prefix, paths=paths) return paths - -def cav_writer(gridder, grid_prefix, accumulator_prefix, **kwargs): +def cav_writer(gridder, grid_prefix, accumulator_prefix, **kwargs): paths = [] - cb = functools.partial(write_volume, name='cav_tracer_10', prefix=accumulator_prefix, paths=paths) - gridder.accumulator_to_numpy('cav_tracer', cb) + cb = functools.partial(write_volume, name="cav_tracer_10", prefix=accumulator_prefix, paths=paths) + gridder.accumulator_to_numpy("cav_tracer", cb) - cb = functools.partial(write_volume, name='sum_pixels_10', prefix=accumulator_prefix, paths=paths) - gridder.accumulator_to_numpy('sum_pixels', cb) + cb = functools.partial(write_volume, name="sum_pixels_10", prefix=accumulator_prefix, paths=paths) + gridder.accumulator_to_numpy("sum_pixels", cb) - gridder.make_ratio_volume('cav_tracer', 'sum_pixels', 'cav_density') - gridder.volumes['cav_density'] = image_from_array(gridder.volumes['cav_density'], gridder.out_spacing) - write_volume(gridder.volumes['cav_density'], name='cav_density_10', prefix=grid_prefix, paths=paths) - del gridder.volumes['cav_density'] + gridder.make_ratio_volume("cav_tracer", "sum_pixels", "cav_density") + gridder.volumes["cav_density"] = image_from_array(gridder.volumes["cav_density"], gridder.out_spacing) + write_volume(gridder.volumes["cav_density"], name="cav_density_10", prefix=grid_prefix, paths=paths) + del gridder.volumes["cav_density"] - gridder.volumes['data_mask'] = gridder.volumes['sum_pixels'] / np.amax(gridder.volumes['sum_pixels']) - del gridder.volumes['sum_pixels'] - gridder.volumes['data_mask'] = gridder.volumes['data_mask'].round() - gridder.volumes['data_mask'] = image_from_array(gridder.volumes['data_mask'], gridder.out_spacing) - write_volume(gridder.volumes['data_mask'], name='data_mask_10', prefix=accumulator_prefix, paths=paths) + gridder.volumes["data_mask"] = gridder.volumes["sum_pixels"] / np.amax(gridder.volumes["sum_pixels"]) + del gridder.volumes["sum_pixels"] + gridder.volumes["data_mask"] = gridder.volumes["data_mask"].round() + gridder.volumes["data_mask"] = image_from_array(gridder.volumes["data_mask"], gridder.out_spacing) + write_volume(gridder.volumes["data_mask"], name="data_mask_10", prefix=accumulator_prefix, paths=paths) del gridder.volumes return paths def classic_writer(gridder, grid_prefix, accumulator_prefix, target_spacings, **kwargs): - paths = [] - cb = functools.partial(write_volume, name='sum_pixel_intensities', prefix=accumulator_prefix, paths=paths) - gridder.consume_volume('sum_pixel_intensities', cb) - - cb = functools.partial(write_volume, name='injection_sum_pixel_intensities', prefix=accumulator_prefix, paths=paths) - gridder.consume_volume('injection_sum_pixel_intensities', cb) - - cb = functools.partial(write_volume, name='sum_pixels', prefix=accumulator_prefix, paths=paths) - gridder.accumulator_to_numpy('sum_pixels', cb) - - ratio_and_pyramid(gridder, 'sum_projecting_pixels', 'sum_pixels', 'projection_density', - accumulator_prefix, grid_prefix, target_spacings, paths=paths) - - ratio_and_pyramid(gridder, 'injection_sum_projecting_pixels', 'sum_pixels', 'injection_density', - accumulator_prefix, grid_prefix, target_spacings, paths=paths) - - ratio_and_pyramid(gridder, 'sum_projecting_pixel_intensities', 'sum_pixels', 'projection_energy', - accumulator_prefix, grid_prefix, target_spacings, paths=paths) - - ratio_and_pyramid(gridder, 'injectionsum_projecting_pixel_intensities', 'sum_pixels', 'injection_energy', - accumulator_prefix, grid_prefix, target_spacings, paths=paths) - - ratio_and_pyramid(gridder, 'injection_sum_pixels', 'sum_pixels', 'injection_fraction', - accumulator_prefix, grid_prefix, target_spacings, paths=paths) - - ratio_and_pyramid(gridder, 'aav_exclusion_sum_pixels', 'sum_pixels', 'aav_exclusion_fraction', - accumulator_prefix, grid_prefix, target_spacings, paths=paths) - - gridder.volumes['data_mask'] = gridder.volumes['sum_pixels'] / np.amax(gridder.volumes['sum_pixels']) - del gridder.volumes['sum_pixels'] - gridder.volumes['data_mask'] = gridder.volumes['data_mask'].round() - handle_pyramid(gridder, 'data_mask', target_spacings, grid_prefix, paths=paths) + cb = functools.partial(write_volume, name="sum_pixel_intensities", prefix=accumulator_prefix, paths=paths) + gridder.consume_volume("sum_pixel_intensities", cb) + + cb = functools.partial(write_volume, name="injection_sum_pixel_intensities", prefix=accumulator_prefix, paths=paths) + gridder.consume_volume("injection_sum_pixel_intensities", cb) + + cb = functools.partial(write_volume, name="sum_pixels", prefix=accumulator_prefix, paths=paths) + gridder.accumulator_to_numpy("sum_pixels", cb) + + ratio_and_pyramid( + gridder, + "sum_projecting_pixels", + "sum_pixels", + "projection_density", + accumulator_prefix, + grid_prefix, + target_spacings, + paths=paths, + ) + + ratio_and_pyramid( + gridder, + "injection_sum_projecting_pixels", + "sum_pixels", + "injection_density", + accumulator_prefix, + grid_prefix, + target_spacings, + paths=paths, + ) + + ratio_and_pyramid( + gridder, + "sum_projecting_pixel_intensities", + "sum_pixels", + "projection_energy", + accumulator_prefix, + grid_prefix, + target_spacings, + paths=paths, + ) + + ratio_and_pyramid( + gridder, + "injectionsum_projecting_pixel_intensities", + "sum_pixels", + "injection_energy", + accumulator_prefix, + grid_prefix, + target_spacings, + paths=paths, + ) + + ratio_and_pyramid( + gridder, + "injection_sum_pixels", + "sum_pixels", + "injection_fraction", + accumulator_prefix, + grid_prefix, + target_spacings, + paths=paths, + ) + + ratio_and_pyramid( + gridder, + "aav_exclusion_sum_pixels", + "sum_pixels", + "aav_exclusion_fraction", + accumulator_prefix, + grid_prefix, + target_spacings, + paths=paths, + ) + + gridder.volumes["data_mask"] = gridder.volumes["sum_pixels"] / np.amax(gridder.volumes["sum_pixels"]) + del gridder.volumes["sum_pixels"] + gridder.volumes["data_mask"] = gridder.volumes["data_mask"].round() + handle_pyramid(gridder, "data_mask", target_spacings, grid_prefix, paths=paths) return paths def handle_pyramid(isg, key, target_spacings, prefix, paths): - cspacing = isg.out_spacing[0] for tspacing in target_spacings: - downsampled = downsample_average(isg.volumes[key], cspacing, tspacing) - write_volume(image_from_array(downsampled, [tspacing] * 3), - key, prefix=prefix, specify_resolution=tspacing, paths=paths) + write_volume( + image_from_array(downsampled, [tspacing] * 3), key, prefix=prefix, specify_resolution=tspacing, paths=paths + ) del isg.volumes[key] def ratio_and_pyramid(isg, num, den, out, accumulator_prefix, grid_prefix, target_spacings, paths): - cb = functools.partial(write_volume, name=num, prefix=accumulator_prefix, paths=paths) isg.accumulator_to_numpy(num, cb) - + isg.make_ratio_volume(num, den, out) del isg.volumes[num] - handle_pyramid(isg, out, target_spacings, grid_prefix, paths) \ No newline at end of file + handle_pyramid(isg, out, target_spacings, grid_prefix, paths) diff --git a/allensdk/test/api/__init__.py b/allensdk/test/api/__init__.py index 93369a117d..31145a27f1 100644 --- a/allensdk/test/api/__init__.py +++ b/allensdk/test/api/__init__.py @@ -1,8 +1,8 @@ class SafeJsonMsg: - ''' Apes a paged query response from api.brain-map.org. - Safe to use with Pythons >= 3.7 (which implement pep 479, such that StopIteration errors in - generators are converted to RunTimeErrors). - ''' + """Apes a paged query response from api.brain-map.org. + Safe to use with Pythons >= 3.7 (which implement pep 479, such that StopIteration errors in + generators are converted to RunTimeErrors). + """ def __init__(self, data): self.data = iter(data) @@ -11,4 +11,4 @@ def __call__(self, *a, **k): try: return next(self.data) except StopIteration: - return {'msg': []} + return {"msg": []} diff --git a/allensdk/test/api/cloud_cache/conftest.py b/allensdk/test/api/cloud_cache/conftest.py index dbdd900e7c..be07c870ac 100644 --- a/allensdk/test/api/cloud_cache/conftest.py +++ b/allensdk/test/api/cloud_cache/conftest.py @@ -15,33 +15,24 @@ def example_datasets(): """ datasets = {} data = {} - data['f1.txt'] = {'data': b'1234567', - 'file_id': '1'} - data['f2.txt'] = {'data': b'4567890', - 'file_id': '2'} - data['f3.txt'] = {'data': b'11121314', - 'file_id': '3'} - datasets['1.0.0'] = data + data["f1.txt"] = {"data": b"1234567", "file_id": "1"} + data["f2.txt"] = {"data": b"4567890", "file_id": "2"} + data["f3.txt"] = {"data": b"11121314", "file_id": "3"} + datasets["1.0.0"] = data data = {} - data['f1.txt'] = {'data': b'abcdefg', - 'file_id': '1'} - data['f2.txt'] = {'data': b'4567890', - 'file_id': '2'} - data['f3.txt'] = {'data': b'11121314', - 'file_id': '3'} + data["f1.txt"] = {"data": b"abcdefg", "file_id": "1"} + data["f2.txt"] = {"data": b"4567890", "file_id": "2"} + data["f3.txt"] = {"data": b"11121314", "file_id": "3"} - datasets['2.0.0'] = data + datasets["2.0.0"] = data data = {} - data['f1.txt'] = {'data': b'1234567', - 'file_id': '1'} - data['f2.txt'] = {'data': b'xyzabcde', - 'file_id': '2'} - data['f3.txt'] = {'data': b'hijklmnop', - 'file_id': '3'} - - datasets['3.0.0'] = data + data["f1.txt"] = {"data": b"1234567", "file_id": "1"} + data["f2.txt"] = {"data": b"xyzabcde", "file_id": "2"} + data["f3.txt"] = {"data": b"hijklmnop", "file_id": "3"} + + datasets["3.0.0"] = data return datasets @@ -52,15 +43,15 @@ def baseline_data_with_metadata(): CloudCache API """ data = {} - data['f1.txt'] = {'file_id': '1', 'data': b'1234'} - data['f2.txt'] = {'file_id': '2', 'data': b'2345'} - data['f3.txt'] = {'file_id': '3', 'data': b'6789'} + data["f1.txt"] = {"file_id": "1", "data": b"1234"} + data["f2.txt"] = {"file_id": "2", "data": b"2345"} + data["f3.txt"] = {"file_id": "3", "data": b"6789"} metadata = {} - metadata['metadata_1.csv'] = b'abcdef' - metadata['metadata_2.csv'] = b'ghijklm' - metadata['metadata_3.csv'] = b'nopqrst' - return {'data': data, 'metadata': metadata} + metadata["metadata_1.csv"] = b"abcdef" + metadata["metadata_2.csv"] = b"ghijklm" + metadata["metadata_3.csv"] = b"nopqrst" + return {"data": data, "metadata": metadata} @pytest.fixture @@ -72,114 +63,114 @@ def example_datasets_with_metadata(baseline_data_with_metadata): """ example = {} - example['data'] = {} - example['metadata'] = {} + example["data"] = {} + example["metadata"] = {} data = copy.deepcopy(baseline_data_with_metadata) - example['data']['1.0.0'] = data['data'] - example['metadata']['1.0.0'] = data['metadata'] + example["data"]["1.0.0"] = data["data"] + example["metadata"]["1.0.0"] = data["metadata"] # delete one data file data = copy.deepcopy(baseline_data_with_metadata) - data['data'].pop('f2.txt') - example['data']['2.0.0'] = data['data'] - example['metadata']['2.0.0'] = data['metadata'] + data["data"].pop("f2.txt") + example["data"]["2.0.0"] = data["data"] + example["metadata"]["2.0.0"] = data["metadata"] # rename one data file data = copy.deepcopy(baseline_data_with_metadata) - old = data['data'].pop('f2.txt') - data['data']['f4.txt'] = {'file_id': '4', 'data': old['data']} - example['data']['3.0.0'] = data['data'] - example['metadata']['3.0.0'] = data['metadata'] + old = data["data"].pop("f2.txt") + data["data"]["f4.txt"] = {"file_id": "4", "data": old["data"]} + example["data"]["3.0.0"] = data["data"] + example["metadata"]["3.0.0"] = data["metadata"] # change one data file data = copy.deepcopy(baseline_data_with_metadata) - data['data']['f3.txt'] = {'file_id': '3', 'data': b'44556677'} - example['data']['4.0.0'] = data['data'] - example['metadata']['4.0.0'] = data['metadata'] + data["data"]["f3.txt"] = {"file_id": "3", "data": b"44556677"} + example["data"]["4.0.0"] = data["data"] + example["metadata"]["4.0.0"] = data["metadata"] # add a data file data = copy.deepcopy(baseline_data_with_metadata) - data['data']['f4.txt'] = {'file_id': '4', 'data': b'44556677'} - example['data']['5.0.0'] = data['data'] - example['metadata']['5.0.0'] = data['metadata'] + data["data"]["f4.txt"] = {"file_id": "4", "data": b"44556677"} + example["data"]["5.0.0"] = data["data"] + example["metadata"]["5.0.0"] = data["metadata"] # delete a data file and change another data = copy.deepcopy(baseline_data_with_metadata) - data['data'].pop('f2.txt') - data['data']['f1.txt'] = {'file_id': '1', 'data': b'xxxxxx'} - example['data']['6.0.0'] = data['data'] - example['metadata']['6.0.0'] = data['metadata'] + data["data"].pop("f2.txt") + data["data"]["f1.txt"] = {"file_id": "1", "data": b"xxxxxx"} + example["data"]["6.0.0"] = data["data"] + example["metadata"]["6.0.0"] = data["metadata"] # delete a data file and rename another data = copy.deepcopy(baseline_data_with_metadata) - data['data'].pop('f2.txt') - old = data['data'].pop('f3.txt') - data['data']['f5.txt'] = {'file_id': '5', 'data': old['data']} - example['data']['7.0.0'] = data['data'] - example['metadata']['7.0.0'] = data['metadata'] + data["data"].pop("f2.txt") + old = data["data"].pop("f3.txt") + data["data"]["f5.txt"] = {"file_id": "5", "data": old["data"]} + example["data"]["7.0.0"] = data["data"] + example["metadata"]["7.0.0"] = data["metadata"] # delete a data file and add another data = copy.deepcopy(baseline_data_with_metadata) - data['data'].pop('f2.txt') - data['data']['f5.txt'] = {'file_id': '5', 'data': b'yyyyy'} - example['data']['8.0.0'] = data['data'] - example['metadata']['8.0.0'] = data['metadata'] + data["data"].pop("f2.txt") + data["data"]["f5.txt"] = {"file_id": "5", "data": b"yyyyy"} + example["data"]["8.0.0"] = data["data"] + example["metadata"]["8.0.0"] = data["metadata"] # rename a data file and add another data = copy.deepcopy(baseline_data_with_metadata) - old = data['data'].pop('f3.txt') - data['data']['f4.txt'] = {'file_id': '4', 'data': old['data']} - data['data']['f5.txt'] = {'file_id': '5', 'data': b'wwwwww'} - example['data']['9.0.0'] = data['data'] - example['metadata']['9.0.0'] = data['metadata'] + old = data["data"].pop("f3.txt") + data["data"]["f4.txt"] = {"file_id": "4", "data": old["data"]} + data["data"]["f5.txt"] = {"file_id": "5", "data": b"wwwwww"} + example["data"]["9.0.0"] = data["data"] + example["metadata"]["9.0.0"] = data["metadata"] # delete a metadata file data = copy.deepcopy(baseline_data_with_metadata) - data['metadata'].pop('metadata_2.csv') - example['data']['10.0.0'] = data['data'] - example['metadata']['10.0.0'] = data['metadata'] + data["metadata"].pop("metadata_2.csv") + example["data"]["10.0.0"] = data["data"] + example["metadata"]["10.0.0"] = data["metadata"] # rename a metadata file data = copy.deepcopy(baseline_data_with_metadata) - old = data['metadata'].pop('metadata_2.csv') - data['metadata']['metadata_4.csv'] = old - example['data']['11.0.0'] = data['data'] - example['metadata']['11.0.0'] = data['metadata'] + old = data["metadata"].pop("metadata_2.csv") + data["metadata"]["metadata_4.csv"] = old + example["data"]["11.0.0"] = data["data"] + example["metadata"]["11.0.0"] = data["metadata"] # change a metadata file data = copy.deepcopy(baseline_data_with_metadata) - data['metadata']['metadata_3.csv'] = b'12345' - example['data']['12.0.0'] = data['data'] - example['metadata']['12.0.0'] = data['metadata'] + data["metadata"]["metadata_3.csv"] = b"12345" + example["data"]["12.0.0"] = data["data"] + example["metadata"]["12.0.0"] = data["metadata"] # add a metadata file data = copy.deepcopy(baseline_data_with_metadata) - data['metadata']['metadata_4.csv'] = b'12345' - example['data']['13.0.0'] = data['data'] - example['metadata']['13.0.0'] = data['metadata'] + data["metadata"]["metadata_4.csv"] = b"12345" + example["data"]["13.0.0"] = data["data"] + example["metadata"]["13.0.0"] = data["metadata"] # delete a data file and change a metadata file data = copy.deepcopy(baseline_data_with_metadata) - data['data'].pop('f2.txt') - old = data['metadata'].pop('metadata_3.csv') - data['metadata']['metadata_4.csv'] = old - example['data']['14.0.0'] = data['data'] - example['metadata']['14.0.0'] = data['metadata'] + data["data"].pop("f2.txt") + old = data["metadata"].pop("metadata_3.csv") + data["metadata"]["metadata_4.csv"] = old + example["data"]["14.0.0"] = data["data"] + example["metadata"]["14.0.0"] = data["metadata"] # rename a data file, add two data files # rename a metadata file and delete two metadata files data = copy.deepcopy(baseline_data_with_metadata) - old = data['data'].pop('f1.txt') - data['data']['f4.txt'] = old - data['data']['f5.txt'] = {'file_id': '5', 'data': b'babababa'} - data['data']['f6.txt'] = {'file_id': '6', 'data': b'neighneigh'} - old = data['metadata'].pop('metadata_2.csv') - data['metadata']['metadata_4.csv'] = old - data['metadata'].pop('metadata_1.csv') - data['metadata'].pop('metadata_3.csv') - - example['data']['15.0.0'] = data['data'] - example['metadata']['15.0.0'] = data['metadata'] + old = data["data"].pop("f1.txt") + data["data"]["f4.txt"] = old + data["data"]["f5.txt"] = {"file_id": "5", "data": b"babababa"} + data["data"]["f6.txt"] = {"file_id": "6", "data": b"neighneigh"} + old = data["metadata"].pop("metadata_2.csv") + data["metadata"]["metadata_4.csv"] = old + data["metadata"].pop("metadata_1.csv") + data["metadata"].pop("metadata_3.csv") + + example["data"]["15.0.0"] = data["data"] + example["metadata"]["15.0.0"] = data["metadata"] return example diff --git a/allensdk/test/api/cloud_cache/test_cache.py b/allensdk/test/api/cloud_cache/test_cache.py index d10220eb8e..8a8424b3a9 100644 --- a/allensdk/test/api/cloud_cache/test_cache.py +++ b/allensdk/test/api/cloud_cache/test_cache.py @@ -19,26 +19,19 @@ def test_list_all_manifests(tmpdir): Test that S3CloudCache.list_al_manifests() returns the correct result """ - test_bucket_name = 'list_manifest_bucket' + test_bucket_name = "list_manifest_bucket" - conn = boto3.resource('s3', region_name='us-east-1') + conn = boto3.resource("s3", region_name="us-east-1") conn.create_bucket(Bucket=test_bucket_name) - client = boto3.client('s3', region_name='us-east-1') - client.put_object(Bucket=test_bucket_name, - Key='proj/manifests/manifest_v1.0.0.json', - Body=b'123456') - client.put_object(Bucket=test_bucket_name, - Key='proj/manifests/manifest_v2.0.0.json', - Body=b'123456') - client.put_object(Bucket=test_bucket_name, - Key='junk.txt', - Body=b'123456') + client = boto3.client("s3", region_name="us-east-1") + client.put_object(Bucket=test_bucket_name, Key="proj/manifests/manifest_v1.0.0.json", Body=b"123456") + client.put_object(Bucket=test_bucket_name, Key="proj/manifests/manifest_v2.0.0.json", Body=b"123456") + client.put_object(Bucket=test_bucket_name, Key="junk.txt", Body=b"123456") - cache = S3CloudCache(tmpdir, test_bucket_name, 'proj') + cache = S3CloudCache(tmpdir, test_bucket_name, "proj") - assert cache.manifest_file_names == ['manifest_v1.0.0.json', - 'manifest_v2.0.0.json'] + assert cache.manifest_file_names == ["manifest_v1.0.0.json", "manifest_v2.0.0.json"] @mock_s3 @@ -48,24 +41,20 @@ def test_list_all_manifests_many(tmpdir): can return at a time """ - test_bucket_name = 'list_manifest_bucket' + test_bucket_name = "list_manifest_bucket" - conn = boto3.resource('s3', region_name='us-east-1') + conn = boto3.resource("s3", region_name="us-east-1") conn.create_bucket(Bucket=test_bucket_name) - client = boto3.client('s3', region_name='us-east-1') + client = boto3.client("s3", region_name="us-east-1") for ii in range(2000): - client.put_object(Bucket=test_bucket_name, - Key=f'proj/manifests/manifest_{ii}.json', - Body=b'123456') + client.put_object(Bucket=test_bucket_name, Key=f"proj/manifests/manifest_{ii}.json", Body=b"123456") - client.put_object(Bucket=test_bucket_name, - Key='junk.txt', - Body=b'123456') + client.put_object(Bucket=test_bucket_name, Key="junk.txt", Body=b"123456") - cache = S3CloudCache(tmpdir, test_bucket_name, 'proj') + cache = S3CloudCache(tmpdir, test_bucket_name, "proj") - expected = list([f'manifest_{ii}.json' for ii in range(2000)]) + expected = list([f"manifest_{ii}.json" for ii in range(2000)]) expected.sort() assert cache.manifest_file_names == expected @@ -76,63 +65,63 @@ def test_loading_manifest(tmpdir): Test loading manifests with S3CloudCache """ - test_bucket_name = 'list_manifest_bucket' - - conn = boto3.resource('s3', region_name='us-east-1') - conn.create_bucket(Bucket=test_bucket_name, ACL='public-read') - - client = boto3.client('s3', region_name='us-east-1') - - manifest_1 = {'manifest_version': '1', - 'metadata_file_id_column_name': 'file_id', - 'data_pipeline': 'placeholder', - 'project_name': 'sam-beckett', - 'data_files': {}, - 'metadata_files': {'a.csv': {'url': 'http://www.junk.com', - 'version_id': '1111', - 'file_hash': 'abcde'}, - 'b.csv': {'url': 'http://silly.com', - 'version_id': '2222', - 'file_hash': 'fghijk'}}} - - manifest_2 = {'manifest_version': '2', - 'metadata_file_id_column_name': 'file_id', - 'data_pipeline': 'placeholder', - 'project_name': 'al', - 'data_files': {}, - 'metadata_files': {'c.csv': {'url': 'http://www.absurd.com', - 'version_id': '3333', - 'file_hash': 'lmnop'}, - 'd.csv': {'url': 'http://nonsense.com', - 'version_id': '4444', - 'file_hash': 'qrstuv'}}} - - client.put_object(Bucket=test_bucket_name, - Key='proj/manifests/manifest_v1.0.0.json', - Body=bytes(json.dumps(manifest_1), 'utf-8')) - - client.put_object(Bucket=test_bucket_name, - Key='proj/manifests/manifest_v2.0.0.json', - Body=bytes(json.dumps(manifest_2), 'utf-8')) - - cache = S3CloudCache(pathlib.Path(tmpdir), test_bucket_name, 'proj') + test_bucket_name = "list_manifest_bucket" + + conn = boto3.resource("s3", region_name="us-east-1") + conn.create_bucket(Bucket=test_bucket_name, ACL="public-read") + + client = boto3.client("s3", region_name="us-east-1") + + manifest_1 = { + "manifest_version": "1", + "metadata_file_id_column_name": "file_id", + "data_pipeline": "placeholder", + "project_name": "sam-beckett", + "data_files": {}, + "metadata_files": { + "a.csv": {"url": "http://www.junk.com", "version_id": "1111", "file_hash": "abcde"}, + "b.csv": {"url": "http://silly.com", "version_id": "2222", "file_hash": "fghijk"}, + }, + } + + manifest_2 = { + "manifest_version": "2", + "metadata_file_id_column_name": "file_id", + "data_pipeline": "placeholder", + "project_name": "al", + "data_files": {}, + "metadata_files": { + "c.csv": {"url": "http://www.absurd.com", "version_id": "3333", "file_hash": "lmnop"}, + "d.csv": {"url": "http://nonsense.com", "version_id": "4444", "file_hash": "qrstuv"}, + }, + } + + client.put_object( + Bucket=test_bucket_name, Key="proj/manifests/manifest_v1.0.0.json", Body=bytes(json.dumps(manifest_1), "utf-8") + ) + + client.put_object( + Bucket=test_bucket_name, Key="proj/manifests/manifest_v2.0.0.json", Body=bytes(json.dumps(manifest_2), "utf-8") + ) + + cache = S3CloudCache(pathlib.Path(tmpdir), test_bucket_name, "proj") assert cache.current_manifest is None - cache.load_manifest('manifest_v1.0.0.json') + cache.load_manifest("manifest_v1.0.0.json") assert cache._manifest._data == manifest_1 - assert cache.version == '1' - assert cache.file_id_column == 'file_id' - assert cache.metadata_file_names == ['a.csv', 'b.csv'] - assert cache.current_manifest == 'manifest_v1.0.0.json' + assert cache.version == "1" + assert cache.file_id_column == "file_id" + assert cache.metadata_file_names == ["a.csv", "b.csv"] + assert cache.current_manifest == "manifest_v1.0.0.json" - cache.load_manifest('manifest_v2.0.0.json') + cache.load_manifest("manifest_v2.0.0.json") assert cache._manifest._data == manifest_2 - assert cache.version == '2' - assert cache.file_id_column == 'file_id' - assert cache.metadata_file_names == ['c.csv', 'd.csv'] + assert cache.version == "2" + assert cache.file_id_column == "file_id" + assert cache.metadata_file_names == ["c.csv", "d.csv"] with pytest.raises(ValueError) as context: - cache.load_manifest('manifest_v3.0.0.json') - msg = 'is not one of the valid manifest names' + cache.load_manifest("manifest_v3.0.0.json") + msg = "is not one of the valid manifest names" assert msg in context.value.args[0] @@ -142,46 +131,37 @@ def test_file_exists(tmpdir): Test that cache._file_exists behaves correctly """ - data = b'aakderasjklsafetss77123523asf' + data = b"aakderasjklsafetss77123523asf" hasher = hashlib.blake2b() hasher.update(data) true_checksum = hasher.hexdigest() - test_file_path = pathlib.Path(tmpdir)/'junk.txt' - with open(test_file_path, 'wb') as out_file: + test_file_path = pathlib.Path(tmpdir) / "junk.txt" + with open(test_file_path, "wb") as out_file: out_file.write(data) # need to populate a bucket in order for # S3CloudCache to be instantiated - test_bucket_name = 'silly_bucket' - conn = boto3.resource('s3', region_name='us-east-1') - conn.create_bucket(Bucket=test_bucket_name, ACL='public-read') + test_bucket_name = "silly_bucket" + conn = boto3.resource("s3", region_name="us-east-1") + conn.create_bucket(Bucket=test_bucket_name, ACL="public-read") - cache = S3CloudCache(tmpdir, test_bucket_name, 'proj') + cache = S3CloudCache(tmpdir, test_bucket_name, "proj") # should be true - good_attribute = CacheFileAttributes('http://silly.url.com', - '12345', - true_checksum, - test_file_path) + good_attribute = CacheFileAttributes("http://silly.url.com", "12345", true_checksum, test_file_path) assert cache._file_exists(good_attribute) # test when file path is wrong - bad_path = pathlib.Path('definitely/not/a/file.txt') - bad_attribute = CacheFileAttributes('http://silly.url.com', - '12345', - true_checksum, - bad_path) + bad_path = pathlib.Path("definitely/not/a/file.txt") + bad_attribute = CacheFileAttributes("http://silly.url.com", "12345", true_checksum, bad_path) assert not cache._file_exists(bad_attribute) # test when path exists but is not a file - bad_attribute = CacheFileAttributes('http://silly.url.com', - '12345', - true_checksum, - pathlib.Path(tmpdir)) + bad_attribute = CacheFileAttributes("http://silly.url.com", "12345", true_checksum, pathlib.Path(tmpdir)) with pytest.raises(RuntimeError) as context: cache._file_exists(bad_attribute) - assert 'but is not a file' in context.value.args[0] + assert "but is not a file" in context.value.args[0] @mock_s3 @@ -191,43 +171,38 @@ def test_download_file(tmpdir): """ hasher = hashlib.blake2b() - data = b'11235813kjlssergwesvsdd' + data = b"11235813kjlssergwesvsdd" hasher.update(data) true_checksum = hasher.hexdigest() - test_bucket_name = 'bucket_for_download' - conn = boto3.resource('s3', region_name='us-east-1') - conn.create_bucket(Bucket=test_bucket_name, ACL='public-read') + test_bucket_name = "bucket_for_download" + conn = boto3.resource("s3", region_name="us-east-1") + conn.create_bucket(Bucket=test_bucket_name, ACL="public-read") # turn on bucket versioning # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#bucketversioning bucket_versioning = conn.BucketVersioning(test_bucket_name) bucket_versioning.enable() - client = boto3.client('s3', region_name='us-east-1') - client.put_object(Bucket=test_bucket_name, - Key='data/data_file.txt', - Body=data) + client = boto3.client("s3", region_name="us-east-1") + client.put_object(Bucket=test_bucket_name, Key="data/data_file.txt", Body=data) response = client.list_object_versions(Bucket=test_bucket_name) - version_id = response['Versions'][0]['VersionId'] + version_id = response["Versions"][0]["VersionId"] - cache_dir = pathlib.Path(tmpdir) / 'download/test/cache' - cache = S3CloudCache(cache_dir, test_bucket_name, 'proj') + cache_dir = pathlib.Path(tmpdir) / "download/test/cache" + cache = S3CloudCache(cache_dir, test_bucket_name, "proj") - expected_path = cache_dir / true_checksum / 'data/data_file.txt' + expected_path = cache_dir / true_checksum / "data/data_file.txt" - url = f'http://{test_bucket_name}.s3.amazonaws.com/data/data_file.txt' - good_attributes = CacheFileAttributes(url, - version_id, - true_checksum, - expected_path) + url = f"http://{test_bucket_name}.s3.amazonaws.com/data/data_file.txt" + good_attributes = CacheFileAttributes(url, version_id, true_checksum, expected_path) assert not expected_path.exists() cache._download_file(good_attributes) assert expected_path.exists() hasher = hashlib.blake2b() - with open(expected_path, 'rb') as in_file: + with open(expected_path, "rb") as in_file: hasher.update(in_file.read()) assert hasher.hexdigest() == true_checksum @@ -244,81 +219,71 @@ def test_download_file_multiple_versions(tmpdir): """ hasher = hashlib.blake2b() - data_1 = b'11235813kjlssergwesvsdd' + data_1 = b"11235813kjlssergwesvsdd" hasher.update(data_1) true_checksum_1 = hasher.hexdigest() hasher = hashlib.blake2b() - data_2 = b'zzzzxxxxyyyywwwwjjjj' + data_2 = b"zzzzxxxxyyyywwwwjjjj" hasher.update(data_2) true_checksum_2 = hasher.hexdigest() assert true_checksum_2 != true_checksum_1 - test_bucket_name = 'bucket_for_download_versions' - conn = boto3.resource('s3', region_name='us-east-1') - conn.create_bucket(Bucket=test_bucket_name, ACL='public-read') + test_bucket_name = "bucket_for_download_versions" + conn = boto3.resource("s3", region_name="us-east-1") + conn.create_bucket(Bucket=test_bucket_name, ACL="public-read") # turn on bucket versioning # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#bucketversioning bucket_versioning = conn.BucketVersioning(test_bucket_name) bucket_versioning.enable() - client = boto3.client('s3', region_name='us-east-1') - client.put_object(Bucket=test_bucket_name, - Key='data/data_file.txt', - Body=data_1) + client = boto3.client("s3", region_name="us-east-1") + client.put_object(Bucket=test_bucket_name, Key="data/data_file.txt", Body=data_1) response = client.list_object_versions(Bucket=test_bucket_name) - version_id_1 = response['Versions'][0]['VersionId'] + version_id_1 = response["Versions"][0]["VersionId"] - client = boto3.client('s3', region_name='us-east-1') - client.put_object(Bucket=test_bucket_name, - Key='data/data_file.txt', - Body=data_2) + client = boto3.client("s3", region_name="us-east-1") + client.put_object(Bucket=test_bucket_name, Key="data/data_file.txt", Body=data_2) response = client.list_object_versions(Bucket=test_bucket_name) version_id_2 = None - for v in response['Versions']: - if v['IsLatest']: - version_id_2 = v['VersionId'] + for v in response["Versions"]: + if v["IsLatest"]: + version_id_2 = v["VersionId"] assert version_id_2 is not None assert version_id_2 != version_id_1 - cache_dir = pathlib.Path(tmpdir) / 'download/test/cache' - cache = S3CloudCache(cache_dir, test_bucket_name, 'proj') + cache_dir = pathlib.Path(tmpdir) / "download/test/cache" + cache = S3CloudCache(cache_dir, test_bucket_name, "proj") - url = f'http://{test_bucket_name}.s3.amazonaws.com/data/data_file.txt' + url = f"http://{test_bucket_name}.s3.amazonaws.com/data/data_file.txt" # download first version of file - expected_path = cache_dir / true_checksum_1 / 'data/data_file.txt' + expected_path = cache_dir / true_checksum_1 / "data/data_file.txt" - good_attributes = CacheFileAttributes(url, - version_id_1, - true_checksum_1, - expected_path) + good_attributes = CacheFileAttributes(url, version_id_1, true_checksum_1, expected_path) assert not expected_path.exists() cache._download_file(good_attributes) assert expected_path.exists() hasher = hashlib.blake2b() - with open(expected_path, 'rb') as in_file: + with open(expected_path, "rb") as in_file: hasher.update(in_file.read()) assert hasher.hexdigest() == true_checksum_1 # download second version of file - expected_path = cache_dir / true_checksum_2 / 'data/data_file.txt' + expected_path = cache_dir / true_checksum_2 / "data/data_file.txt" - good_attributes = CacheFileAttributes(url, - version_id_2, - true_checksum_2, - expected_path) + good_attributes = CacheFileAttributes(url, version_id_2, true_checksum_2, expected_path) assert not expected_path.exists() cache._download_file(good_attributes) assert expected_path.exists() hasher = hashlib.blake2b() - with open(expected_path, 'rb') as in_file: + with open(expected_path, "rb") as in_file: hasher.update(in_file.read()) assert hasher.hexdigest() == true_checksum_2 @@ -331,43 +296,38 @@ def test_re_download_file(tmpdir): """ hasher = hashlib.blake2b() - data = b'11235813kjlssergwesvsdd' + data = b"11235813kjlssergwesvsdd" hasher.update(data) true_checksum = hasher.hexdigest() - test_bucket_name = 'bucket_for_re_download' - conn = boto3.resource('s3', region_name='us-east-1') - conn.create_bucket(Bucket=test_bucket_name, ACL='public-read') + test_bucket_name = "bucket_for_re_download" + conn = boto3.resource("s3", region_name="us-east-1") + conn.create_bucket(Bucket=test_bucket_name, ACL="public-read") # turn on bucket versioning # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#bucketversioning bucket_versioning = conn.BucketVersioning(test_bucket_name) bucket_versioning.enable() - client = boto3.client('s3', region_name='us-east-1') - client.put_object(Bucket=test_bucket_name, - Key='data/data_file.txt', - Body=data) + client = boto3.client("s3", region_name="us-east-1") + client.put_object(Bucket=test_bucket_name, Key="data/data_file.txt", Body=data) response = client.list_object_versions(Bucket=test_bucket_name) - version_id = response['Versions'][0]['VersionId'] + version_id = response["Versions"][0]["VersionId"] - cache_dir = pathlib.Path(tmpdir) / 'download/test/cache' - cache = S3CloudCache(cache_dir, test_bucket_name, 'proj') + cache_dir = pathlib.Path(tmpdir) / "download/test/cache" + cache = S3CloudCache(cache_dir, test_bucket_name, "proj") - expected_path = cache_dir / true_checksum / 'data/data_file.txt' + expected_path = cache_dir / true_checksum / "data/data_file.txt" - url = f'http://{test_bucket_name}.s3.amazonaws.com/data/data_file.txt' - good_attributes = CacheFileAttributes(url, - version_id, - true_checksum, - expected_path) + url = f"http://{test_bucket_name}.s3.amazonaws.com/data/data_file.txt" + good_attributes = CacheFileAttributes(url, version_id, true_checksum, expected_path) assert not expected_path.exists() cache._download_file(good_attributes) assert expected_path.exists() hasher = hashlib.blake2b() - with open(expected_path, 'rb') as in_file: + with open(expected_path, "rb") as in_file: hasher.update(in_file.read()) assert hasher.hexdigest() == true_checksum @@ -378,7 +338,7 @@ def test_re_download_file(tmpdir): cache._download_file(good_attributes) assert expected_path.exists() hasher = hashlib.blake2b() - with open(expected_path, 'rb') as in_file: + with open(expected_path, "rb") as in_file: hasher.update(in_file.read()) assert hasher.hexdigest() == true_checksum @@ -390,56 +350,52 @@ def test_download_data(tmpdir): """ hasher = hashlib.blake2b() - data = b'11235813kjlssergwesvsdd' + data = b"11235813kjlssergwesvsdd" hasher.update(data) true_checksum = hasher.hexdigest() - test_bucket_name = 'bucket_for_download_data' - conn = boto3.resource('s3', region_name='us-east-1') - conn.create_bucket(Bucket=test_bucket_name, ACL='public-read') + test_bucket_name = "bucket_for_download_data" + conn = boto3.resource("s3", region_name="us-east-1") + conn.create_bucket(Bucket=test_bucket_name, ACL="public-read") # turn on bucket versioning # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#bucketversioning bucket_versioning = conn.BucketVersioning(test_bucket_name) bucket_versioning.enable() - client = boto3.client('s3', region_name='us-east-1') - client.put_object(Bucket=test_bucket_name, - Key='data/data_file.txt', - Body=data) + client = boto3.client("s3", region_name="us-east-1") + client.put_object(Bucket=test_bucket_name, Key="data/data_file.txt", Body=data) response = client.list_object_versions(Bucket=test_bucket_name) - version_id = response['Versions'][0]['VersionId'] + version_id = response["Versions"][0]["VersionId"] manifest = {} - manifest['manifest_version'] = '1' - manifest['project_name'] = "project-z" - manifest['metadata_file_id_column_name'] = 'file_id' - manifest['metadata_files'] = {} - url = f'http://{test_bucket_name}.s3.amazonaws.com/project-z/data/data_file.txt' # noqa: E501 - data_file = {'url': url, - 'version_id': version_id, - 'file_hash': true_checksum} - - manifest['data_files'] = {'only_data_file': data_file} - manifest['data_pipeline'] = 'placeholder' - - client.put_object(Bucket=test_bucket_name, - Key='proj/manifests/manifest_v1.0.0.json', - Body=bytes(json.dumps(manifest), 'utf-8')) + manifest["manifest_version"] = "1" + manifest["project_name"] = "project-z" + manifest["metadata_file_id_column_name"] = "file_id" + manifest["metadata_files"] = {} + url = f"http://{test_bucket_name}.s3.amazonaws.com/project-z/data/data_file.txt" # noqa: E501 + data_file = {"url": url, "version_id": version_id, "file_hash": true_checksum} + + manifest["data_files"] = {"only_data_file": data_file} + manifest["data_pipeline"] = "placeholder" + + client.put_object( + Bucket=test_bucket_name, Key="proj/manifests/manifest_v1.0.0.json", Body=bytes(json.dumps(manifest), "utf-8") + ) cache_dir = pathlib.Path(tmpdir) / "data/path/cache" - cache = S3CloudCache(cache_dir, test_bucket_name, 'proj') + cache = S3CloudCache(cache_dir, test_bucket_name, "proj") - cache.load_manifest('manifest_v1.0.0.json') + cache.load_manifest("manifest_v1.0.0.json") - expected_path = cache_dir / 'project-z-1' / 'data/data_file.txt' + expected_path = cache_dir / "project-z-1" / "data/data_file.txt" assert not expected_path.exists() # test data_path - attr = cache.data_path('only_data_file') - assert attr['local_path'] == expected_path - assert not attr['exists'] + attr = cache.data_path("only_data_file") + assert attr["local_path"] == expected_path + assert not attr["exists"] # NOTE: commenting out because moto does not support # list_object_versions and this is becoming difficult @@ -466,63 +422,57 @@ def test_download_metadata(tmpdir): """ hasher = hashlib.blake2b() - data = b'11235813kjlssergwesvsdd' + data = b"11235813kjlssergwesvsdd" hasher.update(data) true_checksum = hasher.hexdigest() - test_bucket_name = 'bucket_for_download_metadata' - conn = boto3.resource('s3', region_name='us-east-1') - conn.create_bucket(Bucket=test_bucket_name, ACL='public-read') + test_bucket_name = "bucket_for_download_metadata" + conn = boto3.resource("s3", region_name="us-east-1") + conn.create_bucket(Bucket=test_bucket_name, ACL="public-read") # turn on bucket versioning # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#bucketversioning bucket_versioning = conn.BucketVersioning(test_bucket_name) bucket_versioning.enable() - client = boto3.client('s3', region_name='us-east-1') - meta_version = client.put_object(Bucket=test_bucket_name, - Key='metadata_file.csv', - Body=data)["VersionId"] + client = boto3.client("s3", region_name="us-east-1") + meta_version = client.put_object(Bucket=test_bucket_name, Key="metadata_file.csv", Body=data)["VersionId"] response = client.list_object_versions(Bucket=test_bucket_name) - version_id = response['Versions'][0]['VersionId'] + version_id = response["Versions"][0]["VersionId"] manifest = {} - manifest['manifest_version'] = '1' - manifest['project_name'] = "project4" - manifest['metadata_file_id_column_name'] = 'file_id' - url = f'http://{test_bucket_name}.s3.amazonaws.com/project4/metadata_file.csv' # noqa: E501 - metadata_file = {'url': url, - 'version_id': version_id, - 'file_hash': true_checksum} - - manifest['metadata_files'] = {'metadata_file.csv': metadata_file} - manifest['data_files'] = {} - manifest['data_pipeline'] = 'placeholder' - - client.put_object(Bucket=test_bucket_name, - Key='proj/manifests/manifest_v1.0.0.json', - Body=bytes(json.dumps(manifest), 'utf-8')) + manifest["manifest_version"] = "1" + manifest["project_name"] = "project4" + manifest["metadata_file_id_column_name"] = "file_id" + url = f"http://{test_bucket_name}.s3.amazonaws.com/project4/metadata_file.csv" # noqa: E501 + metadata_file = {"url": url, "version_id": version_id, "file_hash": true_checksum} + + manifest["metadata_files"] = {"metadata_file.csv": metadata_file} + manifest["data_files"] = {} + manifest["data_pipeline"] = "placeholder" + + client.put_object( + Bucket=test_bucket_name, Key="proj/manifests/manifest_v1.0.0.json", Body=bytes(json.dumps(manifest), "utf-8") + ) cache_dir = pathlib.Path(tmpdir) / "metadata/path/cache" - cache = S3CloudCache(cache_dir, test_bucket_name, 'proj') + cache = S3CloudCache(cache_dir, test_bucket_name, "proj") - cache.load_manifest('manifest_v1.0.0.json') + cache.load_manifest("manifest_v1.0.0.json") - expected_path = cache_dir / "project4-1" / 'metadata_file.csv' + expected_path = cache_dir / "project4-1" / "metadata_file.csv" assert not expected_path.exists() # test that metadata_path also works - attr = cache.metadata_path('metadata_file.csv') - assert attr['local_path'] == expected_path - assert not attr['exists'] + attr = cache.metadata_path("metadata_file.csv") + assert attr["local_path"] == expected_path + assert not attr["exists"] def response_fun(Bucket, Prefix): # moto doesn't cover list_object_versions - return {"Versions": [{ - "VersionId": meta_version, - "Key": "metadata_file.csv", - "Size": 12}]} + return {"Versions": [{"VersionId": meta_version, "Key": "metadata_file.csv", "Size": 12}]} + # cache.s3_client.list_object_versions = response_fun # NOTE: commenting out because moto does not support @@ -548,59 +498,55 @@ def test_metadata(tmpdir): Test that S3CloudCache.metadata() returns the expected pandas DataFrame """ data = {} - data['mouse_id'] = [1, 4, 6, 8] - data['sex'] = ['F', 'F', 'M', 'M'] - data['age'] = ['P50', 'P46', 'P23', 'P40'] + data["mouse_id"] = [1, 4, 6, 8] + data["sex"] = ["F", "F", "M", "M"] + data["age"] = ["P50", "P46", "P23", "P40"] true_df = pd.DataFrame(data) with io.StringIO() as stream: true_df.to_csv(stream, index=False) stream.seek(0) - data = bytes(stream.read(), 'utf-8') + data = bytes(stream.read(), "utf-8") hasher = hashlib.blake2b() hasher.update(data) true_checksum = hasher.hexdigest() - test_bucket_name = 'bucket_for_metadata' - conn = boto3.resource('s3', region_name='us-east-1') - conn.create_bucket(Bucket=test_bucket_name, ACL='public-read') + test_bucket_name = "bucket_for_metadata" + conn = boto3.resource("s3", region_name="us-east-1") + conn.create_bucket(Bucket=test_bucket_name, ACL="public-read") # turn on bucket versioning # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#bucketversioning bucket_versioning = conn.BucketVersioning(test_bucket_name) bucket_versioning.enable() - client = boto3.client('s3', region_name='us-east-1') - client.put_object(Bucket=test_bucket_name, - Key='metadata_file.csv', - Body=data) + client = boto3.client("s3", region_name="us-east-1") + client.put_object(Bucket=test_bucket_name, Key="metadata_file.csv", Body=data) response = client.list_object_versions(Bucket=test_bucket_name) - version_id = response['Versions'][0]['VersionId'] + version_id = response["Versions"][0]["VersionId"] manifest = {} - manifest['manifest_version'] = '1' - manifest['project_name'] = "project-X" - manifest['metadata_file_id_column_name'] = 'file_id' - url = f'http://{test_bucket_name}.s3.amazonaws.com/metadata_file.csv' - metadata_file = {'url': url, - 'version_id': version_id, - 'file_hash': true_checksum} - - manifest['metadata_files'] = {'metadata_file.csv': metadata_file} - manifest['data_files'] = {} - manifest['data_pipeline'] = 'placeholder' - - client.put_object(Bucket=test_bucket_name, - Key='proj/manifests/manifest_v1.0.0.json', - Body=bytes(json.dumps(manifest), 'utf-8')) + manifest["manifest_version"] = "1" + manifest["project_name"] = "project-X" + manifest["metadata_file_id_column_name"] = "file_id" + url = f"http://{test_bucket_name}.s3.amazonaws.com/metadata_file.csv" + metadata_file = {"url": url, "version_id": version_id, "file_hash": true_checksum} + + manifest["metadata_files"] = {"metadata_file.csv": metadata_file} + manifest["data_files"] = {} + manifest["data_pipeline"] = "placeholder" + + client.put_object( + Bucket=test_bucket_name, Key="proj/manifests/manifest_v1.0.0.json", Body=bytes(json.dumps(manifest), "utf-8") + ) cache_dir = pathlib.Path(tmpdir) / "metadata/cache" - cache = S3CloudCache(cache_dir, test_bucket_name, 'proj') - cache.load_manifest('manifest_v1.0.0.json') + cache = S3CloudCache(cache_dir, test_bucket_name, "proj") + cache.load_manifest("manifest_v1.0.0.json") - metadata_df = cache.get_metadata('metadata_file.csv') + metadata_df = cache.get_metadata("metadata_file.csv") assert true_df.equals(metadata_df) @@ -610,23 +556,23 @@ def test_latest_manifest(tmpdir, example_datasets_with_metadata): Test that the methods which return the latest and latest downloaded manifest file names work correctly """ - bucket_name = 'latest_manifest_bucket' - create_bucket(bucket_name, - example_datasets_with_metadata['data'], - metadatasets=example_datasets_with_metadata['metadata']) + bucket_name = "latest_manifest_bucket" + create_bucket( + bucket_name, example_datasets_with_metadata["data"], metadatasets=example_datasets_with_metadata["metadata"] + ) - cache_dir = pathlib.Path(tmpdir) / 'cache' - cache = S3CloudCache(cache_dir, bucket_name, 'project-x') + cache_dir = pathlib.Path(tmpdir) / "cache" + cache = S3CloudCache(cache_dir, bucket_name, "project-x") - assert cache.latest_downloaded_manifest_file == '' + assert cache.latest_downloaded_manifest_file == "" - cache.load_manifest('project-x_manifest_v7.0.0.json') - cache.load_manifest('project-x_manifest_v3.0.0.json') - cache.load_manifest('project-x_manifest_v2.0.0.json') + cache.load_manifest("project-x_manifest_v7.0.0.json") + cache.load_manifest("project-x_manifest_v3.0.0.json") + cache.load_manifest("project-x_manifest_v2.0.0.json") - assert cache.latest_manifest_file == 'project-x_manifest_v15.0.0.json' + assert cache.latest_manifest_file == "project-x_manifest_v15.0.0.json" - expected = 'project-x_manifest_v7.0.0.json' + expected = "project-x_manifest_v7.0.0.json" assert cache.latest_downloaded_manifest_file == expected @@ -637,26 +583,24 @@ def test_outdated_manifest_warning(tmpdir, example_datasets_with_metadata): manifest """ - bucket_name = 'outdated_manifest_bucket' - metadatasets = example_datasets_with_metadata['metadata'] - create_bucket(bucket_name, - example_datasets_with_metadata['data'], - metadatasets=metadatasets) + bucket_name = "outdated_manifest_bucket" + metadatasets = example_datasets_with_metadata["metadata"] + create_bucket(bucket_name, example_datasets_with_metadata["data"], metadatasets=metadatasets) - cache_dir = pathlib.Path(tmpdir) / 'cache' - cache = S3CloudCache(cache_dir, bucket_name, 'project-x') + cache_dir = pathlib.Path(tmpdir) / "cache" + cache = S3CloudCache(cache_dir, bucket_name, "project-x") - m_warn_type = 'OutdatedManifestWarning' + m_warn_type = "OutdatedManifestWarning" with pytest.warns(OutdatedManifestWarning) as warnings: - cache.load_manifest('project-x_manifest_v7.0.0.json') + cache.load_manifest("project-x_manifest_v7.0.0.json") ct = 0 for w in warnings.list: if w._category_name == m_warn_type: msg = str(w.message) - assert 'is not the most up to date' in msg - assert 'S3CloudCache.compare_manifests' in msg - assert 'load_latest_manifest' in msg + assert "is not the most up to date" in msg + assert "S3CloudCache.compare_manifests" in msg + assert "load_latest_manifest" in msg ct += 1 assert ct > 0 @@ -665,9 +609,9 @@ def test_outdated_manifest_warning(tmpdir, example_datasets_with_metadata): # not OutdatedManifestWarnings with warnings_mod.catch_warnings(record=True) as w: warnings_mod.simplefilter("always") - cache.load_manifest('project-x_manifest_v11.0.0.json') + cache.load_manifest("project-x_manifest_v11.0.0.json") for wi in w: - assert wi.category.__name__ != 'OutdatedManifestWarning' + assert wi.category.__name__ != "OutdatedManifestWarning" @mock_s3 @@ -676,27 +620,23 @@ def test_list_all_downloaded(tmpdir, example_datasets_with_metadata): Test that list_all_downloaded_manifests works """ - bucket_name = 'outdated_manifest_bucket' - metadatasets = example_datasets_with_metadata['metadata'] - create_bucket(bucket_name, - example_datasets_with_metadata['data'], - metadatasets=metadatasets) + bucket_name = "outdated_manifest_bucket" + metadatasets = example_datasets_with_metadata["metadata"] + create_bucket(bucket_name, example_datasets_with_metadata["data"], metadatasets=metadatasets) - cache_dir = pathlib.Path(tmpdir) / 'cache' - cache = S3CloudCache(cache_dir, bucket_name, 'project-x') + cache_dir = pathlib.Path(tmpdir) / "cache" + cache = S3CloudCache(cache_dir, bucket_name, "project-x") assert cache.list_all_downloaded_manifests() == [] - cache.load_manifest('project-x_manifest_v5.0.0.json') - assert cache.current_manifest == 'project-x_manifest_v5.0.0.json' - cache.load_manifest('project-x_manifest_v2.0.0.json') - assert cache.current_manifest == 'project-x_manifest_v2.0.0.json' - cache.load_manifest('project-x_manifest_v3.0.0.json') - assert cache.current_manifest == 'project-x_manifest_v3.0.0.json' + cache.load_manifest("project-x_manifest_v5.0.0.json") + assert cache.current_manifest == "project-x_manifest_v5.0.0.json" + cache.load_manifest("project-x_manifest_v2.0.0.json") + assert cache.current_manifest == "project-x_manifest_v2.0.0.json" + cache.load_manifest("project-x_manifest_v3.0.0.json") + assert cache.current_manifest == "project-x_manifest_v3.0.0.json" - expected = {'project-x_manifest_v5.0.0.json', - 'project-x_manifest_v2.0.0.json', - 'project-x_manifest_v3.0.0.json'} + expected = {"project-x_manifest_v5.0.0.json", "project-x_manifest_v2.0.0.json", "project-x_manifest_v3.0.0.json"} downloaded = set(cache.list_all_downloaded_manifests()) assert downloaded == expected @@ -708,24 +648,22 @@ def test_latest_manifest_warning(tmpdir, example_datasets_with_metadata): to load_latest_manifest but that has not been downloaded yet """ - bucket_name = 'outdated_manifest_bucket' - metadatasets = example_datasets_with_metadata['metadata'] - create_bucket(bucket_name, - example_datasets_with_metadata['data'], - metadatasets=metadatasets) + bucket_name = "outdated_manifest_bucket" + metadatasets = example_datasets_with_metadata["metadata"] + create_bucket(bucket_name, example_datasets_with_metadata["data"], metadatasets=metadatasets) - cache_dir = pathlib.Path(tmpdir) / 'cache' - cache = S3CloudCache(cache_dir, bucket_name, 'project-x') + cache_dir = pathlib.Path(tmpdir) / "cache" + cache = S3CloudCache(cache_dir, bucket_name, "project-x") - cache.load_manifest('project-x_manifest_v4.0.0.json') + cache.load_manifest("project-x_manifest_v4.0.0.json") with pytest.warns(OutdatedManifestWarning) as warnings: cache.load_latest_manifest() assert len(warnings) == 1 msg = str(warnings[0].message) - assert 'project-x_manifest_v4.0.0.json' in msg - assert 'project-x_manifest_v15.0.0.json' in msg - assert 'It is possible that some data files' in msg + assert "project-x_manifest_v4.0.0.json" in msg + assert "project-x_manifest_v15.0.0.json" in msg + assert "It is possible that some data files" in msg cmd = "S3CloudCache.load_manifest('project-x_manifest_v4.0.0.json')" assert cmd in msg @@ -735,14 +673,12 @@ def test_load_last_manifest(tmpdir, example_datasets_with_metadata): """ Test that load_last_manifest works """ - bucket_name = 'load_lst_manifest_bucket' - metadatasets = example_datasets_with_metadata['metadata'] - create_bucket(bucket_name, - example_datasets_with_metadata['data'], - metadatasets=metadatasets) + bucket_name = "load_lst_manifest_bucket" + metadatasets = example_datasets_with_metadata["metadata"] + create_bucket(bucket_name, example_datasets_with_metadata["data"], metadatasets=metadatasets) - cache_dir = pathlib.Path(tmpdir) / 'load_last_cache' - cache = S3CloudCache(cache_dir, bucket_name, 'project-x') + cache_dir = pathlib.Path(tmpdir) / "load_last_cache" + cache = S3CloudCache(cache_dir, bucket_name, "project-x") # check that load_last_manifest in a new cache loads the # latest manifest without emitting a warning @@ -751,64 +687,59 @@ def test_load_last_manifest(tmpdir, example_datasets_with_metadata): cache.load_last_manifest() ct = 0 for wi in w: - if wi.category.__name__ == 'OutdatedManifestWarning': + if wi.category.__name__ == "OutdatedManifestWarning": ct += 1 assert ct == 0 - assert cache.current_manifest == 'project-x_manifest_v15.0.0.json' + assert cache.current_manifest == "project-x_manifest_v15.0.0.json" - cache.load_manifest('project-x_manifest_v7.0.0.json') + cache.load_manifest("project-x_manifest_v7.0.0.json") del cache # check that load_last_manifest on an old cache emits the # expected warning and loads the correct manifest - cache = S3CloudCache(cache_dir, bucket_name, 'project-x') - expected = 'A more up to date version of the ' - expected += 'dataset -- project-x_manifest_v15.0.0.json ' - expected += '-- exists online' - with pytest.warns(OutdatedManifestWarning, - match=expected): + cache = S3CloudCache(cache_dir, bucket_name, "project-x") + expected = "A more up to date version of the " + expected += "dataset -- project-x_manifest_v15.0.0.json " + expected += "-- exists online" + with pytest.warns(OutdatedManifestWarning, match=expected): cache.load_last_manifest() - assert cache.current_manifest == 'project-x_manifest_v7.0.0.json' - cache.load_manifest('project-x_manifest_v4.0.0.json') + assert cache.current_manifest == "project-x_manifest_v7.0.0.json" + cache.load_manifest("project-x_manifest_v4.0.0.json") del cache # repeat the above test, making sure the correct manifest is # loaded again - cache = S3CloudCache(cache_dir, bucket_name, 'project-x') - expected = 'A more up to date version of the ' - expected += 'dataset -- project-x_manifest_v15.0.0.json ' - expected += '-- exists online' - with pytest.warns(OutdatedManifestWarning, - match=expected): + cache = S3CloudCache(cache_dir, bucket_name, "project-x") + expected = "A more up to date version of the " + expected += "dataset -- project-x_manifest_v15.0.0.json " + expected += "-- exists online" + with pytest.warns(OutdatedManifestWarning, match=expected): cache.load_last_manifest() - assert cache.current_manifest == 'project-x_manifest_v4.0.0.json' + assert cache.current_manifest == "project-x_manifest_v4.0.0.json" @mock_s3 -def test_corrupted_load_last_manifest(tmpdir, - example_datasets_with_metadata): +def test_corrupted_load_last_manifest(tmpdir, example_datasets_with_metadata): """ Test that load_last_manifest works when the record of the last manifest has been corrupted """ - bucket_name = 'load_lst_manifest_bucket' - metadatasets = example_datasets_with_metadata['metadata'] - create_bucket(bucket_name, - example_datasets_with_metadata['data'], - metadatasets=metadatasets) + bucket_name = "load_lst_manifest_bucket" + metadatasets = example_datasets_with_metadata["metadata"] + create_bucket(bucket_name, example_datasets_with_metadata["data"], metadatasets=metadatasets) - cache_dir = pathlib.Path(tmpdir) / 'load_last_cache' - cache = S3CloudCache(cache_dir, bucket_name, 'project-x') - cache.load_manifest('project-x_manifest_v9.0.0.json') + cache_dir = pathlib.Path(tmpdir) / "load_last_cache" + cache = S3CloudCache(cache_dir, bucket_name, "project-x") + cache.load_manifest("project-x_manifest_v9.0.0.json") fname = cache._manifest_last_used.resolve() del cache - with open(fname, 'w') as out_file: - out_file.write('babababa') - cache = S3CloudCache(cache_dir, bucket_name, 'project-x') - expected = 'Loading latest version -- project-x_manifest_v15.0.0.json' + with open(fname, "w") as out_file: + out_file.write("babababa") + cache = S3CloudCache(cache_dir, bucket_name, "project-x") + expected = "Loading latest version -- project-x_manifest_v15.0.0.json" with pytest.warns(UserWarning, match=expected): cache.load_last_manifest() - assert cache.current_manifest == 'project-x_manifest_v15.0.0.json' + assert cache.current_manifest == "project-x_manifest_v15.0.0.json" diff --git a/allensdk/test/api/cloud_cache/test_change_log.py b/allensdk/test/api/cloud_cache/test_change_log.py index 91adf1ce45..face208519 100644 --- a/allensdk/test/api/cloud_cache/test_change_log.py +++ b/allensdk/test/api/cloud_cache/test_change_log.py @@ -10,142 +10,120 @@ def test_summarize_comparison(tmpdir, example_datasets_with_metadata): Test that CloudCacheBase.summarize_comparison reports the correct changes when comparing two manifests """ - bucket_name = 'summarizing_bucket' - create_bucket(bucket_name, - example_datasets_with_metadata['data'], - metadatasets=example_datasets_with_metadata['metadata']) - - cache_dir = pathlib.Path(tmpdir) / 'cache' - cache = S3CloudCache(cache_dir, bucket_name, 'project-x') - - log = cache.summarize_comparison('project-x_manifest_v1.0.0.json', - 'project-x_manifest_v2.0.0.json') - - assert len(log['metadata_changes']) == 0 - assert len(log['data_changes']) == 1 - assert ('data/f2.txt', 'data/f2.txt deleted') in log['data_changes'] - - log = cache.summarize_comparison('project-x_manifest_v1.0.0.json', - 'project-x_manifest_v3.0.0.json') - - assert len(log['metadata_changes']) == 0 - assert len(log['data_changes']) == 1 - assert ('data/f2.txt', - 'data/f2.txt renamed data/f4.txt') in log['data_changes'] - - log = cache.summarize_comparison('project-x_manifest_v1.0.0.json', - 'project-x_manifest_v4.0.0.json') - - assert len(log['metadata_changes']) == 0 - assert len(log['data_changes']) == 1 - assert ('data/f3.txt', 'data/f3.txt changed') in log['data_changes'] - - log = cache.summarize_comparison('project-x_manifest_v1.0.0.json', - 'project-x_manifest_v5.0.0.json') - - assert len(log['metadata_changes']) == 0 - assert len(log['data_changes']) == 1 - assert ('data/f4.txt', 'data/f4.txt created') in log['data_changes'] - - log = cache.summarize_comparison('project-x_manifest_v1.0.0.json', - 'project-x_manifest_v6.0.0.json') - - assert len(log['metadata_changes']) == 0 - assert len(log['data_changes']) == 2 - assert ('data/f2.txt', 'data/f2.txt deleted') in log['data_changes'] - assert ('data/f1.txt', 'data/f1.txt changed') in log['data_changes'] - - log = cache.summarize_comparison('project-x_manifest_v1.0.0.json', - 'project-x_manifest_v7.0.0.json') - - assert len(log['metadata_changes']) == 0 - assert len(log['data_changes']) == 2 - assert ('data/f2.txt', 'data/f2.txt deleted') in log['data_changes'] - assert ('data/f3.txt', 'data/f3.txt ' - 'renamed data/f5.txt') in log['data_changes'] - - log = cache.summarize_comparison('project-x_manifest_v1.0.0.json', - 'project-x_manifest_v8.0.0.json') - - assert len(log['metadata_changes']) == 0 - assert len(log['data_changes']) == 2 - assert ('data/f2.txt', 'data/f2.txt deleted') in log['data_changes'] - assert ('data/f5.txt', 'data/f5.txt created') in log['data_changes'] - - log = cache.summarize_comparison('project-x_manifest_v1.0.0.json', - 'project-x_manifest_v9.0.0.json') - - assert len(log['metadata_changes']) == 0 - assert len(log['data_changes']) == 2 - assert ('data/f3.txt', 'data/f3.txt renamed ' - 'data/f4.txt') in log['data_changes'] - assert ('data/f5.txt', 'data/f5.txt created') in log['data_changes'] - - log = cache.summarize_comparison('project-x_manifest_v1.0.0.json', - 'project-x_manifest_v10.0.0.json') - - assert len(log['data_changes']) == 0 - assert len(log['metadata_changes']) == 1 - assert ('project_metadata/metadata_2.csv', - 'project_metadata/metadata_2.csv ' - 'deleted') in log['metadata_changes'] - - log = cache.summarize_comparison('project-x_manifest_v1.0.0.json', - 'project-x_manifest_v11.0.0.json') - - assert len(log['data_changes']) == 0 - assert len(log['metadata_changes']) == 1 - assert ('project_metadata/metadata_2.csv', - 'project_metadata/metadata_2.csv renamed ' - 'project_metadata/metadata_4.csv') in log['metadata_changes'] - - log = cache.summarize_comparison('project-x_manifest_v1.0.0.json', - 'project-x_manifest_v12.0.0.json') - - assert len(log['data_changes']) == 0 - assert len(log['metadata_changes']) == 1 - assert ('project_metadata/metadata_3.csv', - 'project_metadata/metadata_3.csv ' - 'changed') in log['metadata_changes'] - - log = cache.summarize_comparison('project-x_manifest_v1.0.0.json', - 'project-x_manifest_v13.0.0.json') - - assert len(log['data_changes']) == 0 - assert len(log['metadata_changes']) == 1 - assert ('project_metadata/metadata_4.csv', - 'project_metadata/metadata_4.csv ' - 'created') in log['metadata_changes'] - - log = cache.summarize_comparison('project-x_manifest_v1.0.0.json', - 'project-x_manifest_v14.0.0.json') - assert len(log['data_changes']) == 1 - assert len(log['metadata_changes']) == 1 - assert ('data/f2.txt', 'data/f2.txt deleted') in log['data_changes'] - assert ('project_metadata/metadata_3.csv', - 'project_metadata/metadata_3.csv renamed ' - 'project_metadata/metadata_4.csv') in log['metadata_changes'] - - log = cache.summarize_comparison('project-x_manifest_v1.0.0.json', - 'project-x_manifest_v15.0.0.json') - assert len(log['data_changes']) == 3 - assert len(log['metadata_changes']) == 3 - - ans1 = ('data/f1.txt', 'data/f1.txt renamed data/f4.txt') - ans2 = ('data/f5.txt', 'data/f5.txt created') - ans3 = ('data/f6.txt', 'data/f6.txt created') - - assert set(log['data_changes']) == {ans1, ans2, ans3} - - ans1 = ('project_metadata/metadata_2.csv', - 'project_metadata/metadata_2.csv renamed ' - 'project_metadata/metadata_4.csv') - ans2 = ('project_metadata/metadata_1.csv', - 'project_metadata/metadata_1.csv deleted') - ans3 = ('project_metadata/metadata_3.csv', - 'project_metadata/metadata_3.csv deleted') - - assert set(log['metadata_changes']) == {ans1, ans2, ans3} + bucket_name = "summarizing_bucket" + create_bucket( + bucket_name, example_datasets_with_metadata["data"], metadatasets=example_datasets_with_metadata["metadata"] + ) + + cache_dir = pathlib.Path(tmpdir) / "cache" + cache = S3CloudCache(cache_dir, bucket_name, "project-x") + + log = cache.summarize_comparison("project-x_manifest_v1.0.0.json", "project-x_manifest_v2.0.0.json") + + assert len(log["metadata_changes"]) == 0 + assert len(log["data_changes"]) == 1 + assert ("data/f2.txt", "data/f2.txt deleted") in log["data_changes"] + + log = cache.summarize_comparison("project-x_manifest_v1.0.0.json", "project-x_manifest_v3.0.0.json") + + assert len(log["metadata_changes"]) == 0 + assert len(log["data_changes"]) == 1 + assert ("data/f2.txt", "data/f2.txt renamed data/f4.txt") in log["data_changes"] + + log = cache.summarize_comparison("project-x_manifest_v1.0.0.json", "project-x_manifest_v4.0.0.json") + + assert len(log["metadata_changes"]) == 0 + assert len(log["data_changes"]) == 1 + assert ("data/f3.txt", "data/f3.txt changed") in log["data_changes"] + + log = cache.summarize_comparison("project-x_manifest_v1.0.0.json", "project-x_manifest_v5.0.0.json") + + assert len(log["metadata_changes"]) == 0 + assert len(log["data_changes"]) == 1 + assert ("data/f4.txt", "data/f4.txt created") in log["data_changes"] + + log = cache.summarize_comparison("project-x_manifest_v1.0.0.json", "project-x_manifest_v6.0.0.json") + + assert len(log["metadata_changes"]) == 0 + assert len(log["data_changes"]) == 2 + assert ("data/f2.txt", "data/f2.txt deleted") in log["data_changes"] + assert ("data/f1.txt", "data/f1.txt changed") in log["data_changes"] + + log = cache.summarize_comparison("project-x_manifest_v1.0.0.json", "project-x_manifest_v7.0.0.json") + + assert len(log["metadata_changes"]) == 0 + assert len(log["data_changes"]) == 2 + assert ("data/f2.txt", "data/f2.txt deleted") in log["data_changes"] + assert ("data/f3.txt", "data/f3.txt renamed data/f5.txt") in log["data_changes"] + + log = cache.summarize_comparison("project-x_manifest_v1.0.0.json", "project-x_manifest_v8.0.0.json") + + assert len(log["metadata_changes"]) == 0 + assert len(log["data_changes"]) == 2 + assert ("data/f2.txt", "data/f2.txt deleted") in log["data_changes"] + assert ("data/f5.txt", "data/f5.txt created") in log["data_changes"] + + log = cache.summarize_comparison("project-x_manifest_v1.0.0.json", "project-x_manifest_v9.0.0.json") + + assert len(log["metadata_changes"]) == 0 + assert len(log["data_changes"]) == 2 + assert ("data/f3.txt", "data/f3.txt renamed data/f4.txt") in log["data_changes"] + assert ("data/f5.txt", "data/f5.txt created") in log["data_changes"] + + log = cache.summarize_comparison("project-x_manifest_v1.0.0.json", "project-x_manifest_v10.0.0.json") + + assert len(log["data_changes"]) == 0 + assert len(log["metadata_changes"]) == 1 + assert ("project_metadata/metadata_2.csv", "project_metadata/metadata_2.csv deleted") in log["metadata_changes"] + + log = cache.summarize_comparison("project-x_manifest_v1.0.0.json", "project-x_manifest_v11.0.0.json") + + assert len(log["data_changes"]) == 0 + assert len(log["metadata_changes"]) == 1 + assert ( + "project_metadata/metadata_2.csv", + "project_metadata/metadata_2.csv renamed project_metadata/metadata_4.csv", + ) in log["metadata_changes"] + + log = cache.summarize_comparison("project-x_manifest_v1.0.0.json", "project-x_manifest_v12.0.0.json") + + assert len(log["data_changes"]) == 0 + assert len(log["metadata_changes"]) == 1 + assert ("project_metadata/metadata_3.csv", "project_metadata/metadata_3.csv changed") in log["metadata_changes"] + + log = cache.summarize_comparison("project-x_manifest_v1.0.0.json", "project-x_manifest_v13.0.0.json") + + assert len(log["data_changes"]) == 0 + assert len(log["metadata_changes"]) == 1 + assert ("project_metadata/metadata_4.csv", "project_metadata/metadata_4.csv created") in log["metadata_changes"] + + log = cache.summarize_comparison("project-x_manifest_v1.0.0.json", "project-x_manifest_v14.0.0.json") + assert len(log["data_changes"]) == 1 + assert len(log["metadata_changes"]) == 1 + assert ("data/f2.txt", "data/f2.txt deleted") in log["data_changes"] + assert ( + "project_metadata/metadata_3.csv", + "project_metadata/metadata_3.csv renamed project_metadata/metadata_4.csv", + ) in log["metadata_changes"] + + log = cache.summarize_comparison("project-x_manifest_v1.0.0.json", "project-x_manifest_v15.0.0.json") + assert len(log["data_changes"]) == 3 + assert len(log["metadata_changes"]) == 3 + + ans1 = ("data/f1.txt", "data/f1.txt renamed data/f4.txt") + ans2 = ("data/f5.txt", "data/f5.txt created") + ans3 = ("data/f6.txt", "data/f6.txt created") + + assert set(log["data_changes"]) == {ans1, ans2, ans3} + + ans1 = ( + "project_metadata/metadata_2.csv", + "project_metadata/metadata_2.csv renamed project_metadata/metadata_4.csv", + ) + ans2 = ("project_metadata/metadata_1.csv", "project_metadata/metadata_1.csv deleted") + ans3 = ("project_metadata/metadata_3.csv", "project_metadata/metadata_3.csv deleted") + + assert set(log["metadata_changes"]) == {ans1, ans2, ans3} @mock_s3 @@ -155,27 +133,26 @@ def test_compare_manifesst_string(tmpdir, example_datasets_with_metadata): Test that CloudCacheBase.compare_manifests reports the correct changes when comparing two manifests """ - bucket_name = 'compare_manifest_bucket' - create_bucket(bucket_name, - example_datasets_with_metadata['data'], - metadatasets=example_datasets_with_metadata['metadata']) - - cache_dir = pathlib.Path(tmpdir) / 'cache' - cache = S3CloudCache(cache_dir, bucket_name, 'project-x') - - msg = cache.compare_manifests('project-x_manifest_v1.0.0.json', - 'project-x_manifest_v15.0.0.json') - - expected = 'Changes going from\n' - expected += 'project-x_manifest_v1.0.0.json\n' - expected += 'to\n' - expected += 'project-x_manifest_v15.0.0.json\n\n' - expected += 'project_metadata/metadata_1.csv deleted\n' - expected += 'project_metadata/metadata_2.csv renamed ' - expected += 'project_metadata/metadata_4.csv\n' - expected += 'project_metadata/metadata_3.csv deleted\n' - expected += 'data/f1.txt renamed data/f4.txt\n' - expected += 'data/f5.txt created\n' - expected += 'data/f6.txt created\n' + bucket_name = "compare_manifest_bucket" + create_bucket( + bucket_name, example_datasets_with_metadata["data"], metadatasets=example_datasets_with_metadata["metadata"] + ) + + cache_dir = pathlib.Path(tmpdir) / "cache" + cache = S3CloudCache(cache_dir, bucket_name, "project-x") + + msg = cache.compare_manifests("project-x_manifest_v1.0.0.json", "project-x_manifest_v15.0.0.json") + + expected = "Changes going from\n" + expected += "project-x_manifest_v1.0.0.json\n" + expected += "to\n" + expected += "project-x_manifest_v15.0.0.json\n\n" + expected += "project_metadata/metadata_1.csv deleted\n" + expected += "project_metadata/metadata_2.csv renamed " + expected += "project_metadata/metadata_4.csv\n" + expected += "project_metadata/metadata_3.csv deleted\n" + expected += "data/f1.txt renamed data/f4.txt\n" + expected += "data/f5.txt created\n" + expected += "data/f6.txt created\n" assert msg == expected diff --git a/allensdk/test/api/cloud_cache/test_file_attributes.py b/allensdk/test/api/cloud_cache/test_file_attributes.py index f3af08221f..8a9e630561 100644 --- a/allensdk/test/api/cloud_cache/test_file_attributes.py +++ b/allensdk/test/api/cloud_cache/test_file_attributes.py @@ -5,51 +5,46 @@ def test_cache_file_attributes(): - attr = CacheFileAttributes(url='http://my/url', - version_id='aaabbb', - file_hash='12345', - local_path=pathlib.Path('/my/local/path')) + attr = CacheFileAttributes( + url="http://my/url", version_id="aaabbb", file_hash="12345", local_path=pathlib.Path("/my/local/path") + ) - assert attr.url == 'http://my/url' - assert attr.version_id == 'aaabbb' - assert attr.file_hash == '12345' - assert attr.local_path == pathlib.Path('/my/local/path') + assert attr.url == "http://my/url" + assert attr.version_id == "aaabbb" + assert attr.file_hash == "12345" + assert attr.local_path == pathlib.Path("/my/local/path") # test that the correct ValueErrors are raised # when you pass invalid arguments with pytest.raises(ValueError) as context: - attr = CacheFileAttributes(url=5.0, - version_id='aaabbb', - file_hash='12345', - local_path=pathlib.Path('/my/local/path')) + attr = CacheFileAttributes( + url=5.0, version_id="aaabbb", file_hash="12345", local_path=pathlib.Path("/my/local/path") + ) msg = "url must be str; got " assert context.value.args[0] == msg with pytest.raises(ValueError) as context: - attr = CacheFileAttributes(url='http://my/url/', - version_id=5.0, - file_hash='12345', - local_path=pathlib.Path('/my/local/path')) + attr = CacheFileAttributes( + url="http://my/url/", version_id=5.0, file_hash="12345", local_path=pathlib.Path("/my/local/path") + ) msg = "version_id must be str; got " assert context.value.args[0] == msg with pytest.raises(ValueError) as context: - attr = CacheFileAttributes(url='http://my/url/', - version_id='aaabbb', - file_hash=5.0, - local_path=pathlib.Path('/my/local/path')) + attr = CacheFileAttributes( + url="http://my/url/", version_id="aaabbb", file_hash=5.0, local_path=pathlib.Path("/my/local/path") + ) msg = "file_hash must be str; got " assert context.value.args[0] == msg with pytest.raises(ValueError) as context: - attr = CacheFileAttributes(url='http://my/url/', - version_id='aaabbb', - file_hash='12345', - local_path='/my/local/path') + attr = CacheFileAttributes( + url="http://my/url/", version_id="aaabbb", file_hash="12345", local_path="/my/local/path" + ) msg = "local_path must be pathlib.Path; got " assert context.value.args[0] == msg @@ -59,15 +54,14 @@ def test_str(): """ Test the string representation of CacheFileParameters """ - attr = CacheFileAttributes(url='http://my/url', - version_id='aaabbb', - file_hash='12345', - local_path=pathlib.Path('/my/local/path')) + attr = CacheFileAttributes( + url="http://my/url", version_id="aaabbb", file_hash="12345", local_path=pathlib.Path("/my/local/path") + ) - s = f'{attr}' + s = f"{attr}" assert "CacheFileParameters{" in s assert '"file_hash": "12345"' in s assert '"url": "http://my/url"' in s assert '"version_id": "aaabbb"' in s - if platform.system().lower() != 'windows': + if platform.system().lower() != "windows": assert '"local_path": "/my/local/path"' in s diff --git a/allensdk/test/api/cloud_cache/test_full_process.py b/allensdk/test/api/cloud_cache/test_full_process.py index 8f8be62f2d..5e50a22d0a 100644 --- a/allensdk/test/api/cloud_cache/test_full_process.py +++ b/allensdk/test/api/cloud_cache/test_full_process.py @@ -16,246 +16,226 @@ def test_full_cache_system(tmpdir): each of which involve different versions of files """ - test_bucket_name = 'full_cache_bucket' + test_bucket_name = "full_cache_bucket" - conn = boto3.resource('s3', region_name='us-east-1') - conn.create_bucket(Bucket=test_bucket_name, ACL='public-read') + conn = boto3.resource("s3", region_name="us-east-1") + conn.create_bucket(Bucket=test_bucket_name, ACL="public-read") # turn on bucket versioning # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#bucketversioning bucket_versioning = conn.BucketVersioning(test_bucket_name) bucket_versioning.enable() - s3_client = boto3.client('s3', region_name='us-east-1') + s3_client = boto3.client("s3", region_name="us-east-1") # generate data and expected hashes true_hashes = {} version_id_lookup = {} - data1_v1 = b'12345678' - data1_v2 = b'45678901' - data2_v1 = b'abcdefghijk' - data2_v2 = b'lmnopqrstuv' - data3_v1 = b'jklmnopqrst' + data1_v1 = b"12345678" + data1_v2 = b"45678901" + data2_v1 = b"abcdefghijk" + data2_v2 = b"lmnopqrstuv" + data3_v1 = b"jklmnopqrst" - metadata1_v1 = pd.DataFrame({'mouse': [1, 2, 3], - 'sex': ['F', 'F', 'M']}) + metadata1_v1 = pd.DataFrame({"mouse": [1, 2, 3], "sex": ["F", "F", "M"]}) - metadata2_v1 = pd.DataFrame({'experiment': [5, 6, 7], - 'file_id': ['data1', 'data2', 'data3']}) + metadata2_v1 = pd.DataFrame({"experiment": [5, 6, 7], "file_id": ["data1", "data2", "data3"]}) - metadata1_v2 = pd.DataFrame({'mouse': [8, 9, 0], - 'sex': ['M', 'F', 'M']}) + metadata1_v2 = pd.DataFrame({"mouse": [8, 9, 0], "sex": ["M", "F", "M"]}) v1_hashes = {} - for data, key in zip((data1_v1, data2_v1, data3_v1), - ('data1', 'data2', 'data3')): - + for data, key in zip((data1_v1, data2_v1, data3_v1), ("data1", "data2", "data3")): hasher = hashlib.blake2b() hasher.update(data) v1_hashes[key] = hasher.hexdigest() - s3_client.put_object(Bucket=test_bucket_name, - Key=f'proj/data/{key}', - Body=data) - - for df, key in zip((metadata1_v1, metadata2_v1), - ('proj/metadata1.csv', 'proj/metadata2.csv')): + s3_client.put_object(Bucket=test_bucket_name, Key=f"proj/data/{key}", Body=data) + for df, key in zip((metadata1_v1, metadata2_v1), ("proj/metadata1.csv", "proj/metadata2.csv")): with io.StringIO() as stream: df.to_csv(stream, index=False) stream.seek(0) - data = bytes(stream.read(), 'utf-8') + data = bytes(stream.read(), "utf-8") hasher = hashlib.blake2b() hasher.update(data) - v1_hashes[key.replace('proj/', '')] = hasher.hexdigest() - s3_client.put_object(Bucket=test_bucket_name, - Key=key, - Body=data) + v1_hashes[key.replace("proj/", "")] = hasher.hexdigest() + s3_client.put_object(Bucket=test_bucket_name, Key=key, Body=data) - true_hashes['v1'] = v1_hashes + true_hashes["v1"] = v1_hashes v1_version_id = {} response = s3_client.list_object_versions(Bucket=test_bucket_name) - for v in response['Versions']: - vkey = v['Key'].replace('proj/', '').replace('data/', '') - v1_version_id[vkey] = v['VersionId'] + for v in response["Versions"]: + vkey = v["Key"].replace("proj/", "").replace("data/", "") + v1_version_id[vkey] = v["VersionId"] - version_id_lookup['v1'] = v1_version_id + version_id_lookup["v1"] = v1_version_id v2_hashes = {} v2_version_id = {} - for data, key in zip((data1_v2, data2_v2), - ('data1', 'data2')): - + for data, key in zip((data1_v2, data2_v2), ("data1", "data2")): hasher = hashlib.blake2b() hasher.update(data) v2_hashes[key] = hasher.hexdigest() - s3_client.put_object(Bucket=test_bucket_name, - Key=f'proj/data/{key}', - Body=data) + s3_client.put_object(Bucket=test_bucket_name, Key=f"proj/data/{key}", Body=data) - s3_client.delete_object(Bucket=test_bucket_name, - Key='proj/data/data3') + s3_client.delete_object(Bucket=test_bucket_name, Key="proj/data/data3") with io.StringIO() as stream: metadata1_v2.to_csv(stream, index=False) stream.seek(0) - data = bytes(stream.read(), 'utf-8') + data = bytes(stream.read(), "utf-8") hasher = hashlib.blake2b() hasher.update(data) - v2_hashes['metadata1.csv'] = hasher.hexdigest() - s3_client.put_object(Bucket=test_bucket_name, - Key='proj/metadata1.csv', - Body=data) + v2_hashes["metadata1.csv"] = hasher.hexdigest() + s3_client.put_object(Bucket=test_bucket_name, Key="proj/metadata1.csv", Body=data) - s3_client.delete_object(Bucket=test_bucket_name, - Key='proj/metadata2.csv') + s3_client.delete_object(Bucket=test_bucket_name, Key="proj/metadata2.csv") - true_hashes['v2'] = v2_hashes + true_hashes["v2"] = v2_hashes v2_version_id = {} response = s3_client.list_object_versions(Bucket=test_bucket_name) - for v in response['Versions']: - if not v['IsLatest']: + for v in response["Versions"]: + if not v["IsLatest"]: continue - vkey = v['Key'].replace('proj/', '').replace('data/', '') - v2_version_id[vkey] = v['VersionId'] - version_id_lookup['v2'] = v2_version_id + vkey = v["Key"].replace("proj/", "").replace("data/", "") + v2_version_id[vkey] = v["VersionId"] + version_id_lookup["v2"] = v2_version_id # check thata data3 and metadata2.csv do not occur in v2 of # the dataset, but other data/metadata files do - assert 'data3' in version_id_lookup['v1'] - assert 'data3' not in version_id_lookup['v2'] - assert 'data1' in version_id_lookup['v1'] - assert 'data2' in version_id_lookup['v1'] - assert 'data1' in version_id_lookup['v2'] - assert 'data2' in version_id_lookup['v2'] - assert 'metadata1.csv' in version_id_lookup['v1'] - assert 'metadata2.csv' in version_id_lookup['v1'] - assert 'metadata1.csv' in version_id_lookup['v2'] - assert 'metadata2.csv' not in version_id_lookup['v2'] + assert "data3" in version_id_lookup["v1"] + assert "data3" not in version_id_lookup["v2"] + assert "data1" in version_id_lookup["v1"] + assert "data2" in version_id_lookup["v1"] + assert "data1" in version_id_lookup["v2"] + assert "data2" in version_id_lookup["v2"] + assert "metadata1.csv" in version_id_lookup["v1"] + assert "metadata2.csv" in version_id_lookup["v1"] + assert "metadata1.csv" in version_id_lookup["v2"] + assert "metadata2.csv" not in version_id_lookup["v2"] # build manifests manifest_1 = {} - manifest_1['manifest_version'] = 'A' - manifest_1['project_name'] = "project-A1" - manifest_1['metadata_file_id_column_name'] = 'file_id' - manifest_1['data_pipeline'] = 'placeholder' + manifest_1["manifest_version"] = "A" + manifest_1["project_name"] = "project-A1" + manifest_1["metadata_file_id_column_name"] = "file_id" + manifest_1["data_pipeline"] = "placeholder" data_files_1 = {} - for k in ('data1', 'data2', 'data3'): + for k in ("data1", "data2", "data3"): obj = {} - obj['url'] = f'http://{test_bucket_name}.s3.amazonaws.com/proj/data/{k}' # noqa: E501 - obj['file_hash'] = true_hashes['v1'][k] - obj['version_id'] = version_id_lookup['v1'][k] + obj["url"] = f"http://{test_bucket_name}.s3.amazonaws.com/proj/data/{k}" # noqa: E501 + obj["file_hash"] = true_hashes["v1"][k] + obj["version_id"] = version_id_lookup["v1"][k] data_files_1[k] = obj - manifest_1['data_files'] = data_files_1 + manifest_1["data_files"] = data_files_1 metadata_files_1 = {} - for k in ('metadata1.csv', 'metadata2.csv'): + for k in ("metadata1.csv", "metadata2.csv"): obj = {} - obj['url'] = f'http://{test_bucket_name}.s3.amazonaws.com/proj/{k}' - obj['file_hash'] = true_hashes['v1'][k] - obj['version_id'] = version_id_lookup['v1'][k] + obj["url"] = f"http://{test_bucket_name}.s3.amazonaws.com/proj/{k}" + obj["file_hash"] = true_hashes["v1"][k] + obj["version_id"] = version_id_lookup["v1"][k] metadata_files_1[k] = obj - manifest_1['metadata_files'] = metadata_files_1 + manifest_1["metadata_files"] = metadata_files_1 manifest_2 = {} - manifest_2['manifest_version'] = 'B' - manifest_2['project_name'] = "project-B2" - manifest_2['metadata_file_id_column_name'] = 'file_id' - manifest_2['data_pipeline'] = 'placeholder' + manifest_2["manifest_version"] = "B" + manifest_2["project_name"] = "project-B2" + manifest_2["metadata_file_id_column_name"] = "file_id" + manifest_2["data_pipeline"] = "placeholder" data_files_2 = {} - for k in ('data1', 'data2'): + for k in ("data1", "data2"): obj = {} - obj['url'] = f'http://{test_bucket_name}.s3.amazonaws.com/proj/data/{k}' # noqa: E501 - obj['file_hash'] = true_hashes['v2'][k] - obj['version_id'] = version_id_lookup['v2'][k] + obj["url"] = f"http://{test_bucket_name}.s3.amazonaws.com/proj/data/{k}" # noqa: E501 + obj["file_hash"] = true_hashes["v2"][k] + obj["version_id"] = version_id_lookup["v2"][k] data_files_2[k] = obj - manifest_2['data_files'] = data_files_2 + manifest_2["data_files"] = data_files_2 metadata_files_2 = {} - for k in ['metadata1.csv']: + for k in ["metadata1.csv"]: obj = {} - obj['url'] = f'http://{test_bucket_name}.s3.amazonaws.com/proj/{k}' - obj['file_hash'] = true_hashes['v2'][k] - obj['version_id'] = version_id_lookup['v2'][k] + obj["url"] = f"http://{test_bucket_name}.s3.amazonaws.com/proj/{k}" + obj["file_hash"] = true_hashes["v2"][k] + obj["version_id"] = version_id_lookup["v2"][k] metadata_files_2[k] = obj - manifest_2['metadata_files'] = metadata_files_2 + manifest_2["metadata_files"] = metadata_files_2 - s3_client.put_object(Bucket=test_bucket_name, - Key='proj/manifests/manifest_v1.0.0.json', - Body=bytes(json.dumps(manifest_1), 'utf-8')) + s3_client.put_object( + Bucket=test_bucket_name, Key="proj/manifests/manifest_v1.0.0.json", Body=bytes(json.dumps(manifest_1), "utf-8") + ) - s3_client.put_object(Bucket=test_bucket_name, - Key='proj/manifests/manifest_v2.0.0.json', - Body=bytes(json.dumps(manifest_2), 'utf-8')) + s3_client.put_object( + Bucket=test_bucket_name, Key="proj/manifests/manifest_v2.0.0.json", Body=bytes(json.dumps(manifest_2), "utf-8") + ) # Use S3CloudCache to interact with dataset - cache_dir = pathlib.Path(tmpdir) / 'my/test/cache' - cache = S3CloudCache(cache_dir, test_bucket_name, 'proj') + cache_dir = pathlib.Path(tmpdir) / "my/test/cache" + cache = S3CloudCache(cache_dir, test_bucket_name, "proj") # load the first version of the dataset - cache.load_manifest('manifest_v1.0.0.json') - assert cache.version == 'A' + cache.load_manifest("manifest_v1.0.0.json") + assert cache.version == "A" # check that metadata dataframes have expected contents - m1 = cache.get_metadata('metadata1.csv') + m1 = cache.get_metadata("metadata1.csv") assert metadata1_v1.equals(m1) - m2 = cache.get_metadata('metadata2.csv') + m2 = cache.get_metadata("metadata2.csv") assert metadata2_v1.equals(m2) # check that data files have expected hashes - for k in ('data1', 'data2', 'data3'): - + for k in ("data1", "data2", "data3"): attr = cache.data_path(k) - assert not attr['exists'] + assert not attr["exists"] local_path = cache.download_data(k) assert local_path.exists() hasher = hashlib.blake2b() - with open(local_path, 'rb') as in_file: + with open(local_path, "rb") as in_file: hasher.update(in_file.read()) - assert hasher.hexdigest() == true_hashes['v1'][k] + assert hasher.hexdigest() == true_hashes["v1"][k] attr = cache.data_path(k) - assert attr['exists'] + assert attr["exists"] # now load the second version of the dataset - cache.load_manifest('manifest_v2.0.0.json') - assert cache.version == 'B' + cache.load_manifest("manifest_v2.0.0.json") + assert cache.version == "B" # metadata2.csv should not exist in this version of the dataset with pytest.raises(ValueError) as context: - cache.get_metadata('metadata2.csv') - assert 'is not in self.metadata_file_names' in context.value.args[0] + cache.get_metadata("metadata2.csv") + assert "is not in self.metadata_file_names" in context.value.args[0] # check that metadata1 has expected contents - m1 = cache.get_metadata('metadata1.csv') + m1 = cache.get_metadata("metadata1.csv") assert metadata1_v2.equals(m1) # data3 should not exist in this version of the dataset with pytest.raises(ValueError) as context: - _ = cache.download_data('data3') - assert 'not a data file listed' in context.value.args[0] + _ = cache.download_data("data3") + assert "not a data file listed" in context.value.args[0] with pytest.raises(ValueError) as context: - _ = cache.data_path('data3') - assert 'not a data file listed' in context.value.args[0] + _ = cache.data_path("data3") + assert "not a data file listed" in context.value.args[0] # check that data1, data2 have expected hashes - for k in ('data1', 'data2'): + for k in ("data1", "data2"): attr = cache.data_path(k) - assert not attr['exists'] + assert not attr["exists"] local_path = cache.download_data(k) assert local_path.exists() hasher = hashlib.blake2b() - with open(local_path, 'rb') as in_file: + with open(local_path, "rb") as in_file: hasher.update(in_file.read()) - assert hasher.hexdigest() == true_hashes['v2'][k] + assert hasher.hexdigest() == true_hashes["v2"][k] attr = cache.data_path(k) - assert attr['exists'] + assert attr["exists"] diff --git a/allensdk/test/api/cloud_cache/test_local_cache.py b/allensdk/test/api/cloud_cache/test_local_cache.py index 211c5c4641..b892323924 100644 --- a/allensdk/test/api/cloud_cache/test_local_cache.py +++ b/allensdk/test/api/cloud_cache/test_local_cache.py @@ -13,38 +13,37 @@ def test_local_cache_file_access(tmpdir, example_datasets): with LocalCache """ - bucket_name = 'local_cache_bucket' + bucket_name = "local_cache_bucket" create_bucket(bucket_name, example_datasets) - cache_dir = pathlib.Path(tmpdir) / 'cache' - cloud_cache = S3CloudCache(cache_dir, bucket_name, 'project-x') + cache_dir = pathlib.Path(tmpdir) / "cache" + cloud_cache = S3CloudCache(cache_dir, bucket_name, "project-x") - cloud_cache.load_manifest('project-x_manifest_v1.0.0.json') - cloud_cache.download_data('1') - cloud_cache.download_data('3') + cloud_cache.load_manifest("project-x_manifest_v1.0.0.json") + cloud_cache.download_data("1") + cloud_cache.download_data("3") - cloud_cache.load_manifest('project-x_manifest_v3.0.0.json') - cloud_cache.download_data('2') + cloud_cache.load_manifest("project-x_manifest_v3.0.0.json") + cloud_cache.download_data("2") del cloud_cache - local_cache = LocalCache(cache_dir, 'project-x') + local_cache = LocalCache(cache_dir, "project-x") manifest_set = set(local_cache.manifest_file_names) - assert manifest_set == {'project-x_manifest_v1.0.0.json', - 'project-x_manifest_v3.0.0.json'} - - local_cache.load_manifest('project-x_manifest_v1.0.0.json') - attr = local_cache.data_path('1') - assert attr['exists'] - attr = local_cache.data_path('2') - assert not attr['exists'] - attr = local_cache.data_path('3') - assert attr['exists'] - - local_cache.load_manifest('project-x_manifest_v3.0.0.json') - attr = local_cache.data_path('1') - assert attr['exists'] # because file 1 is the same in v1.0 and v3.0 - attr = local_cache.data_path('2') - assert attr['exists'] - attr = local_cache.data_path('3') - assert not attr['exists'] + assert manifest_set == {"project-x_manifest_v1.0.0.json", "project-x_manifest_v3.0.0.json"} + + local_cache.load_manifest("project-x_manifest_v1.0.0.json") + attr = local_cache.data_path("1") + assert attr["exists"] + attr = local_cache.data_path("2") + assert not attr["exists"] + attr = local_cache.data_path("3") + assert attr["exists"] + + local_cache.load_manifest("project-x_manifest_v3.0.0.json") + attr = local_cache.data_path("1") + assert attr["exists"] # because file 1 is the same in v1.0 and v3.0 + attr = local_cache.data_path("2") + assert attr["exists"] + attr = local_cache.data_path("3") + assert not attr["exists"] diff --git a/allensdk/test/api/cloud_cache/test_manifest.py b/allensdk/test/api/cloud_cache/test_manifest.py index b7f44a97d5..bc8a3faf73 100644 --- a/allensdk/test/api/cloud_cache/test_manifest.py +++ b/allensdk/test/api/cloud_cache/test_manifest.py @@ -10,12 +10,13 @@ def meta_json_path(tmpdir): jpath = tmpdir / "somejson.json" d = { - "project_name": "X", - "manifest_version": "Y", - "metadata_file_id_column_name": "Z", - "data_pipeline": "ZA", - "metadata_files": ["ZB", "ZC", "ZD"], - "data_files": {"AB": "ab", "BC": "bc", "CD": "cd"}} + "project_name": "X", + "manifest_version": "Y", + "metadata_file_id_column_name": "Z", + "data_pipeline": "ZA", + "metadata_files": ["ZB", "ZC", "ZD"], + "data_files": {"AB": "ab", "BC": "bc", "CD": "cd"}, + } with open(jpath, "w") as f: json.dump(d, f) yield jpath @@ -26,8 +27,8 @@ def test_constructor(meta_json_path): Make sure that the Manifest class __init__ runs and raises an error if you give it an unexpected cache_dir """ - Manifest('my/cache/dir', meta_json_path) - Manifest(pathlib.Path('my/other/cache/dir'), meta_json_path) + Manifest("my/cache/dir", meta_json_path) + Manifest(pathlib.Path("my/other/cache/dir"), meta_json_path) with pytest.raises(ValueError, match=r"cache_dir must be either a str.*"): Manifest(1234.2, meta_json_path) @@ -38,16 +39,14 @@ def test_create_file_attributes(meta_json_path): handles input parameters (this is mostly a test of local_path generation) """ - mfest = Manifest('/my/cache/dir', meta_json_path) - attr = mfest._create_file_attributes('http://my.url.com/path/to/file.txt', - '12345', - 'aaabbbcccddd') + mfest = Manifest("/my/cache/dir", meta_json_path) + attr = mfest._create_file_attributes("http://my.url.com/path/to/file.txt", "12345", "aaabbbcccddd") assert isinstance(attr, CacheFileAttributes) - assert attr.url == 'http://my.url.com/path/to/file.txt' - assert attr.version_id == '12345' - assert attr.file_hash == 'aaabbbcccddd' - expected_path = '/my/cache/dir/X-Y/to/file.txt' + assert attr.url == "http://my.url.com/path/to/file.txt" + assert attr.version_id == "12345" + assert attr.file_hash == "aaabbbcccddd" + expected_path = "/my/cache/dir/X-Y/to/file.txt" assert attr.local_path == pathlib.Path(expected_path).resolve() @@ -56,19 +55,19 @@ def manifest_for_metadata(tmpdir): jpath = tmpdir / "a_manifest.json" manifest = {} metadata_files = {} - metadata_files['a.txt'] = {'url': 'http://my.url.com/path/to/a.txt', - 'version_id': '12345', - 'file_hash': 'abcde'} - metadata_files['b.txt'] = {'url': 'http://my.other.url.com/different/path/to/b.txt', # noqa: E501 - 'version_id': '67890', - 'file_hash': 'fghijk'} - - manifest['metadata_files'] = metadata_files - manifest['data_files'] = {} - manifest['project_name'] = "some-project" - manifest['manifest_version'] = '000' - manifest['metadata_file_id_column_name'] = 'file_id' - manifest['data_pipeline'] = 'placeholder' + metadata_files["a.txt"] = {"url": "http://my.url.com/path/to/a.txt", "version_id": "12345", "file_hash": "abcde"} + metadata_files["b.txt"] = { + "url": "http://my.other.url.com/different/path/to/b.txt", # noqa: E501 + "version_id": "67890", + "file_hash": "fghijk", + } + + manifest["metadata_files"] = metadata_files + manifest["data_files"] = {} + manifest["project_name"] = "some-project" + manifest["manifest_version"] = "000" + manifest["metadata_file_id_column_name"] = "file_id" + manifest["data_pipeline"] = "placeholder" with open(jpath, "w") as f: json.dump(manifest, f) yield jpath @@ -81,21 +80,21 @@ def test_metadata_file_attributes(manifest_for_metadata): error when you ask for a metadata file that does not exist """ - mfest = Manifest('/my/cache/dir/', manifest_for_metadata) + mfest = Manifest("/my/cache/dir/", manifest_for_metadata) - a_obj = mfest.metadata_file_attributes('a.txt') - assert a_obj.url == 'http://my.url.com/path/to/a.txt' - assert a_obj.version_id == '12345' - assert a_obj.file_hash == 'abcde' - expected = safe_system_path('/my/cache/dir/some-project-000/to/a.txt') + a_obj = mfest.metadata_file_attributes("a.txt") + assert a_obj.url == "http://my.url.com/path/to/a.txt" + assert a_obj.version_id == "12345" + assert a_obj.file_hash == "abcde" + expected = safe_system_path("/my/cache/dir/some-project-000/to/a.txt") expected = pathlib.Path(expected).resolve() assert a_obj.local_path == expected - b_obj = mfest.metadata_file_attributes('b.txt') - assert b_obj.url == 'http://my.other.url.com/different/path/to/b.txt' - assert b_obj.version_id == '67890' - assert b_obj.file_hash == 'fghijk' - expected = safe_system_path('/my/cache/dir/some-project-000/path/to/b.txt') + b_obj = mfest.metadata_file_attributes("b.txt") + assert b_obj.url == "http://my.other.url.com/different/path/to/b.txt" + assert b_obj.version_id == "67890" + assert b_obj.file_hash == "fghijk" + expected = safe_system_path("/my/cache/dir/some-project-000/path/to/b.txt") expected = pathlib.Path(expected).resolve() assert b_obj.local_path == expected @@ -103,7 +102,7 @@ def test_metadata_file_attributes(manifest_for_metadata): # for a metadata file that does not exist with pytest.raises(ValueError) as context: - _ = mfest.metadata_file_attributes('c.txt') + _ = mfest.metadata_file_attributes("c.txt") msg = "c.txt\nis not in self.metadata_file_names" assert msg in context.value.args[0] @@ -112,19 +111,19 @@ def test_metadata_file_attributes(manifest_for_metadata): def manifest_with_data(tmpdir): jpath = tmpdir / "manifest_with files.json" manifest = {} - manifest['metadata_files'] = {} - manifest['manifest_version'] = '0' - manifest['project_name'] = "myproject" - manifest['metadata_file_id_column_name'] = 'file_id' - manifest['data_pipeline'] = 'placeholder' + manifest["metadata_files"] = {} + manifest["manifest_version"] = "0" + manifest["project_name"] = "myproject" + manifest["metadata_file_id_column_name"] = "file_id" + manifest["data_pipeline"] = "placeholder" data_files = {} - data_files['a'] = {'url': 'http://my.url.com/myproject/path/to/a.nwb', - 'version_id': '12345', - 'file_hash': 'abcde'} - data_files['b'] = {'url': 'http://my.other.url.com/different/path/b.nwb', - 'version_id': '67890', - 'file_hash': 'fghijk'} - manifest['data_files'] = data_files + data_files["a"] = {"url": "http://my.url.com/myproject/path/to/a.nwb", "version_id": "12345", "file_hash": "abcde"} + data_files["b"] = { + "url": "http://my.other.url.com/different/path/b.nwb", + "version_id": "67890", + "file_hash": "fghijk", + } + manifest["data_files"] = data_files with open(jpath, "w") as f: json.dump(manifest, f) yield jpath @@ -136,24 +135,24 @@ def test_data_file_attributes(manifest_with_data): CacheFileAttributes object and raises the correct error when you ask for a data file that does not exist """ - mfest = Manifest('/my/cache/dir', manifest_with_data) + mfest = Manifest("/my/cache/dir", manifest_with_data) - a_obj = mfest.data_file_attributes('a') - assert a_obj.url == 'http://my.url.com/myproject/path/to/a.nwb' - assert a_obj.version_id == '12345' - assert a_obj.file_hash == 'abcde' - expected = safe_system_path('/my/cache/dir/myproject-0/path/to/a.nwb') + a_obj = mfest.data_file_attributes("a") + assert a_obj.url == "http://my.url.com/myproject/path/to/a.nwb" + assert a_obj.version_id == "12345" + assert a_obj.file_hash == "abcde" + expected = safe_system_path("/my/cache/dir/myproject-0/path/to/a.nwb") assert a_obj.local_path == pathlib.Path(expected).resolve() - b_obj = mfest.data_file_attributes('b') - assert b_obj.url == 'http://my.other.url.com/different/path/b.nwb' - assert b_obj.version_id == '67890' - assert b_obj.file_hash == 'fghijk' - expected = safe_system_path('/my/cache/dir/myproject-0/path/b.nwb') + b_obj = mfest.data_file_attributes("b") + assert b_obj.url == "http://my.other.url.com/different/path/b.nwb" + assert b_obj.version_id == "67890" + assert b_obj.file_hash == "fghijk" + expected = safe_system_path("/my/cache/dir/myproject-0/path/b.nwb") assert b_obj.local_path == pathlib.Path(expected).resolve() with pytest.raises(ValueError) as context: - _ = mfest.data_file_attributes('c') + _ = mfest.data_file_attributes("c") msg = "file_id: c\nIs not a data file listed in manifest:" assert msg in context.value.args[0] @@ -164,10 +163,8 @@ def test_file_attribute_errors(meta_json_path): attributes before loading a manifest.json """ mfest = Manifest("/my/cache/dir", meta_json_path) - with pytest.raises(ValueError, - match=r".* not in self.metadata_file_names"): - mfest.metadata_file_attributes('some_file.txt') + with pytest.raises(ValueError, match=r".* not in self.metadata_file_names"): + mfest.metadata_file_attributes("some_file.txt") - with pytest.raises(ValueError, - match=r".* not a data file listed in manifest"): - mfest.data_file_attributes('other_file.txt') + with pytest.raises(ValueError, match=r".* not a data file listed in manifest"): + mfest.data_file_attributes("other_file.txt") diff --git a/allensdk/test/api/cloud_cache/test_smart_download.py b/allensdk/test/api/cloud_cache/test_smart_download.py index e9d17e94a8..2fee07d725 100644 --- a/allensdk/test/api/cloud_cache/test_smart_download.py +++ b/allensdk/test/api/cloud_cache/test_smart_download.py @@ -16,83 +16,76 @@ def test_smart_file_downloading(tmpdir, example_datasets): Test that the CloudCache is smart enough to build symlinks where possible """ - test_bucket_name = 'smart_download_bucket' - create_bucket(test_bucket_name, - example_datasets) + test_bucket_name = "smart_download_bucket" + create_bucket(test_bucket_name, example_datasets) - cache_dir = pathlib.Path(tmpdir) / 'cache' - cache = S3CloudCache(cache_dir, test_bucket_name, 'project-x') + cache_dir = pathlib.Path(tmpdir) / "cache" + cache = S3CloudCache(cache_dir, test_bucket_name, "project-x") # download all data files from all versions, keeping track # of the paths to the downloaded data files downloaded = {} - for version in ('1.0.0', '2.0.0', '3.0.0'): + for version in ("1.0.0", "2.0.0", "3.0.0"): downloaded[version] = {} - cache.load_manifest(f'project-x_manifest_v{version}.json') - for file_id in ('1', '2', '3'): + cache.load_manifest(f"project-x_manifest_v{version}.json") + for file_id in ("1", "2", "3"): downloaded[version][file_id] = cache.download_data(file_id) # check that the version 1.0.0 of all files are actual files - for file_id in ('1', '2', '3'): - assert downloaded['1.0.0'][file_id].is_file() - assert not downloaded['1.0.0'][file_id].is_symlink() + for file_id in ("1", "2", "3"): + assert downloaded["1.0.0"][file_id].is_file() + assert not downloaded["1.0.0"][file_id].is_symlink() # check that v2.0.0 f1.txt is a new file - assert downloaded['2.0.0']['1'].is_file() - assert not downloaded['2.0.0']['1'].is_symlink() + assert downloaded["2.0.0"]["1"].is_file() + assert not downloaded["2.0.0"]["1"].is_symlink() # check that v2.0.0 f2.txt and f3.txt are symlinks to # the correct v1.0.0 files - for file_id in ('2', '3'): - assert downloaded['2.0.0'][file_id].is_file() - assert downloaded['2.0.0'][file_id].is_symlink() + for file_id in ("2", "3"): + assert downloaded["2.0.0"][file_id].is_file() + assert downloaded["2.0.0"][file_id].is_symlink() # check that symlink points to the correct file - test = downloaded['2.0.0'][file_id].resolve() - control = downloaded['1.0.0'][file_id].resolve() + test = downloaded["2.0.0"][file_id].resolve() + control = downloaded["1.0.0"][file_id].resolve() if test != control: - test = downloaded['2.0.0'][file_id].resolve() - control = downloaded['1.0.0'][file_id].resolve() - raise RuntimeError(f'{test} != {control}\n' - 'even though the first is a symlink') + test = downloaded["2.0.0"][file_id].resolve() + control = downloaded["1.0.0"][file_id].resolve() + raise RuntimeError(f"{test} != {control}\neven though the first is a symlink") # check that the absolute paths of the files are different, # even though one is a symlink - test = downloaded['2.0.0'][file_id].absolute() - control = downloaded['1.0.0'][file_id].absolute() + test = downloaded["2.0.0"][file_id].absolute() + control = downloaded["1.0.0"][file_id].absolute() if test == control: - test = downloaded['2.0.0'][file_id].absolute() - control = downloaded['1.0.0'][file_id].absolute() - raise RuntimeError(f'{test} == {control}\n' - 'even though they should be ' - 'different absolute paths') + test = downloaded["2.0.0"][file_id].absolute() + control = downloaded["1.0.0"][file_id].absolute() + raise RuntimeError(f"{test} == {control}\neven though they should be different absolute paths") # repeat the above tests for v3.0.0, f1.txt - assert downloaded['3.0.0']['1'].is_file() - assert downloaded['3.0.0']['1'].is_symlink() + assert downloaded["3.0.0"]["1"].is_file() + assert downloaded["3.0.0"]["1"].is_symlink() - res3 = downloaded['3.0.0']['1'].resolve() - res1 = downloaded['1.0.0']['1'].resolve() + res3 = downloaded["3.0.0"]["1"].resolve() + res1 = downloaded["1.0.0"]["1"].resolve() if res3 != res1: - test = downloaded['3.0.0']['1'].resolve() - control = downloaded['1.0.0']['1'].resolve() - raise RuntimeError(f'{test} != {control}\n' - 'even though the first is a symlink') + test = downloaded["3.0.0"]["1"].resolve() + control = downloaded["1.0.0"]["1"].resolve() + raise RuntimeError(f"{test} != {control}\neven though the first is a symlink") - abs3 = downloaded['3.0.0']['1'].absolute() - abs1 = downloaded['1.0.0']['1'].absolute() + abs3 = downloaded["3.0.0"]["1"].absolute() + abs1 = downloaded["1.0.0"]["1"].absolute() if abs3 == abs1: - test = downloaded['3.0.0']['1'].absolute() - control = downloaded['1.0.0']['1'].absolute() - raise RuntimeError(f'{test} == {control}\n' - 'even though they should be ' - 'different absolute paths') + test = downloaded["3.0.0"]["1"].absolute() + control = downloaded["1.0.0"]["1"].absolute() + raise RuntimeError(f"{test} == {control}\neven though they should be different absolute paths") # check that v3 v2.txt and f3.txt are not symlinks - assert downloaded['3.0.0']['2'].is_file() - assert not downloaded['3.0.0']['2'].is_symlink() - assert downloaded['3.0.0']['3'].is_file() - assert not downloaded['3.0.0']['3'].is_symlink() + assert downloaded["3.0.0"]["2"].is_file() + assert not downloaded["3.0.0"]["2"].is_symlink() + assert downloaded["3.0.0"]["3"].is_file() + assert not downloaded["3.0.0"]["3"].is_symlink() @mock_s3 @@ -101,63 +94,62 @@ def test_on_corrupted_files(tmpdir, example_datasets): Test that the CloudCache re-downloads files when they have been corrupted """ - bucket_name = 'corruption_bucket' - create_bucket(bucket_name, - example_datasets) + bucket_name = "corruption_bucket" + create_bucket(bucket_name, example_datasets) - cache_dir = pathlib.Path(tmpdir) / 'cache' - cache = S3CloudCache(cache_dir, bucket_name, 'project-x') + cache_dir = pathlib.Path(tmpdir) / "cache" + cache = S3CloudCache(cache_dir, bucket_name, "project-x") - version_list = ('1.0.0', '2.0.0', '3.0.0') - file_id_list = ('1', '2', '3') + version_list = ("1.0.0", "2.0.0", "3.0.0") + file_id_list = ("1", "2", "3") for version in version_list: - cache.load_manifest(f'project-x_manifest_v{version}.json') + cache.load_manifest(f"project-x_manifest_v{version}.json") for file_id in file_id_list: cache.download_data(file_id) # make sure that all files exist for version in version_list: - cache.load_manifest(f'project-x_manifest_v{version}.json') + cache.load_manifest(f"project-x_manifest_v{version}.json") for file_id in file_id_list: attr = cache.data_path(file_id) - assert attr['exists'] + assert attr["exists"] hasher = hashlib.blake2b() - hasher.update(b'4567890') + hasher.update(b"4567890") true_hash = hasher.hexdigest() # Check that, when a file on disk gets removed, # all of the symlinks that point back to that file # get marked as `not exists` - cache.load_manifest('project-x_manifest_v1.0.0.json') - attr = cache.data_path('2') - attr['local_path'].unlink() + cache.load_manifest("project-x_manifest_v1.0.0.json") + attr = cache.data_path("2") + attr["local_path"].unlink() - attr = cache.data_path('2') - assert not attr['exists'] + attr = cache.data_path("2") + assert not attr["exists"] # note that v0.2.0/f2.txt is identical to v0.1.0/f2.txt # in the example data set - cache.load_manifest('project-x_manifest_v2.0.0.json') - attr = cache.data_path('2') - assert not attr['exists'] + cache.load_manifest("project-x_manifest_v2.0.0.json") + attr = cache.data_path("2") + assert not attr["exists"] # re-download one of the identical files, and verify # that both datasets are restored - cache.download_data('2') - attr = cache.data_path('2') - assert attr['exists'] - redownloaded_path = attr['local_path'] + cache.download_data("2") + attr = cache.data_path("2") + assert attr["exists"] + redownloaded_path = attr["local_path"] - cache.load_manifest('project-x_manifest_v1.0.0.json') - attr = cache.data_path('2') - assert attr['exists'] - other_path = attr['local_path'] + cache.load_manifest("project-x_manifest_v1.0.0.json") + attr = cache.data_path("2") + assert attr["exists"] + other_path = attr["local_path"] hasher = hashlib.blake2b() - with open(other_path, 'rb') as in_file: + with open(other_path, "rb") as in_file: hasher.update(in_file.read()) assert hasher.hexdigest() == true_hash @@ -175,34 +167,33 @@ def test_on_removed_files(tmpdir, example_datasets): Test that the CloudCache re-downloads files when the the files at the root of the symlinks have been removed """ - bucket_name = 'corruption_bucket' - create_bucket(bucket_name, - example_datasets) + bucket_name = "corruption_bucket" + create_bucket(bucket_name, example_datasets) - cache_dir = pathlib.Path(tmpdir) / 'cache' - cache = S3CloudCache(cache_dir, bucket_name, 'project-x') + cache_dir = pathlib.Path(tmpdir) / "cache" + cache = S3CloudCache(cache_dir, bucket_name, "project-x") - version_list = ('1.0.0', '2.0.0', '3.0.0') - file_id_list = ('1', '2', '3') + version_list = ("1.0.0", "2.0.0", "3.0.0") + file_id_list = ("1", "2", "3") for version in version_list: - cache.load_manifest(f'project-x_manifest_v{version}.json') + cache.load_manifest(f"project-x_manifest_v{version}.json") for file_id in file_id_list: cache.download_data(file_id) # make sure that all files exist for version in version_list: - cache.load_manifest(f'project-x_manifest_v{version}.json') + cache.load_manifest(f"project-x_manifest_v{version}.json") for file_id in file_id_list: attr = cache.data_path(file_id) - assert attr['exists'] + assert attr["exists"] hasher = hashlib.blake2b() - hasher.update(b'4567890') + hasher.update(b"4567890") true_hash = hasher.hexdigest() - p1 = cache_dir / 'project-x-1.0.0' / 'data' / 'f2.txt' - p2 = cache_dir / 'project-x-2.0.0' / 'data' / 'f2.txt' + p1 = cache_dir / "project-x-1.0.0" / "data" / "f2.txt" + p2 = cache_dir / "project-x-2.0.0" / "data" / "f2.txt" # note that f2.txt is identical between v 1.0.0 and 2.0.0 assert p1.is_file() @@ -219,23 +210,23 @@ def test_on_removed_files(tmpdir, example_datasets): # make sure that the file which has been moved is now # marked as not existing - cache.load_manifest('project-x_manifest_v1.0.0.json') - test_path = cache.data_path('2') - assert not test_path['exists'] + cache.load_manifest("project-x_manifest_v1.0.0.json") + test_path = cache.data_path("2") + assert not test_path["exists"] - cache.load_manifest('project-x_manifest_v2.0.0.json') - test_path = cache.data_path('2') - assert not test_path['exists'] + cache.load_manifest("project-x_manifest_v2.0.0.json") + test_path = cache.data_path("2") + assert not test_path["exists"] # now, re-download the data by way of manifest 2 # and verify that the symlink relationship is # re-established - p2 = cache.download_data('2') + p2 = cache.download_data("2") assert p2.is_file() assert p2.is_symlink() # because the symlink was not removed - cache.load_manifest('project-x_manifest_v1.0.0.json') - p1 = cache.download_data('2') + cache.load_manifest("project-x_manifest_v1.0.0.json") + p1 = cache.download_data("2") assert p1.is_file() assert not p1.is_symlink() @@ -243,7 +234,7 @@ def test_on_removed_files(tmpdir, example_datasets): assert p1.absolute() != p2.absolute() hasher = hashlib.blake2b() - with open(p2, 'rb') as in_file: + with open(p2, "rb") as in_file: hasher.update(in_file.read()) assert hasher.hexdigest() == true_hash @@ -254,34 +245,33 @@ def test_on_removed_symlinks(tmpdir, example_datasets): Test that the CloudCache re-downloads files when the the symlinks have been removed """ - bucket_name = 'corruption_bucket' - create_bucket(bucket_name, - example_datasets) + bucket_name = "corruption_bucket" + create_bucket(bucket_name, example_datasets) - cache_dir = pathlib.Path(tmpdir) / 'cache' - cache = S3CloudCache(cache_dir, bucket_name, 'project-x') + cache_dir = pathlib.Path(tmpdir) / "cache" + cache = S3CloudCache(cache_dir, bucket_name, "project-x") - version_list = ('1.0.0', '2.0.0', '3.0.0') - file_id_list = ('1', '2', '3') + version_list = ("1.0.0", "2.0.0", "3.0.0") + file_id_list = ("1", "2", "3") for version in version_list: - cache.load_manifest(f'project-x_manifest_v{version}.json') + cache.load_manifest(f"project-x_manifest_v{version}.json") for file_id in file_id_list: cache.download_data(file_id) # make sure that all files exist for version in version_list: - cache.load_manifest(f'project-x_manifest_v{version}.json') + cache.load_manifest(f"project-x_manifest_v{version}.json") for file_id in file_id_list: attr = cache.data_path(file_id) - assert attr['exists'] + assert attr["exists"] hasher = hashlib.blake2b() - hasher.update(b'4567890') + hasher.update(b"4567890") true_hash = hasher.hexdigest() - p1 = cache_dir / 'project-x-1.0.0' / 'data' / 'f2.txt' - p2 = cache_dir / 'project-x-2.0.0' / 'data' / 'f2.txt' + p1 = cache_dir / "project-x-1.0.0" / "data" / "f2.txt" + p2 = cache_dir / "project-x-2.0.0" / "data" / "f2.txt" # note that f2.txt is identical between v 1.0.0 and 2.0.0 assert p1.is_file() @@ -297,17 +287,17 @@ def test_on_removed_symlinks(tmpdir, example_datasets): assert not p2.is_symlink() assert p1.is_file() - cache.load_manifest('project-x_manifest_v2.0.0.json') - test_path = cache.data_path('2') - assert test_path['exists'] - p2 = pathlib.Path(test_path['local_path']) + cache.load_manifest("project-x_manifest_v2.0.0.json") + test_path = cache.data_path("2") + assert test_path["exists"] + p2 = pathlib.Path(test_path["local_path"]) assert p2.is_symlink() assert p2.exists() assert p1.absolute() != p2.absolute() assert p1.resolve() == p2.resolve() hasher = hashlib.blake2b() - with open(p2, 'rb') as in_file: + with open(p2, "rb") as in_file: hasher.update(in_file.read()) assert hasher.hexdigest() == true_hash @@ -318,61 +308,60 @@ def test_corrupted_download_manifest(tmpdir, example_datasets): Test that CloudCache can handle the case where the _downloaded_data_path dict gets corrupted """ - bucket_name = 'manifest_corruption_bucket' - create_bucket(bucket_name, - example_datasets) + bucket_name = "manifest_corruption_bucket" + create_bucket(bucket_name, example_datasets) - cache_dir = pathlib.Path(tmpdir) / 'cache' - cache = S3CloudCache(cache_dir, bucket_name, 'project-x') + cache_dir = pathlib.Path(tmpdir) / "cache" + cache = S3CloudCache(cache_dir, bucket_name, "project-x") - version_list = ('1.0.0', '2.0.0', '3.0.0') - file_id_list = ('1', '2', '3') + version_list = ("1.0.0", "2.0.0", "3.0.0") + file_id_list = ("1", "2", "3") for version in version_list: - cache.load_manifest(f'project-x_manifest_v{version}.json') + cache.load_manifest(f"project-x_manifest_v{version}.json") for file_id in file_id_list: cache.download_data(file_id) - with open(cache._downloaded_data_path, 'rb') as in_file: + with open(cache._downloaded_data_path, "rb") as in_file: src_data = json.load(in_file) # write a corrupted downloaded_data_path for k in src_data: - src_data[k] = '' - with open(cache._downloaded_data_path, 'w') as out_file: + src_data[k] = "" + with open(cache._downloaded_data_path, "w") as out_file: out_file.write(json.dumps(src_data, indent=2)) hasher = hashlib.blake2b() - hasher.update(b'4567890') + hasher.update(b"4567890") true_hash = hasher.hexdigest() - cache.load_manifest('project-x_manifest_v1.0.0.json') - attr = cache.data_path('2') + cache.load_manifest("project-x_manifest_v1.0.0.json") + attr = cache.data_path("2") # assert below will pass; because file exists and is not yet corrupted, # CloudCache won't consult _downloaded_data_path - assert attr['exists'] + assert attr["exists"] # now remove one of the data files - attr['local_path'].unlink() + attr["local_path"].unlink() # now that the file is corrupted, 'exists' is False - attr = cache.data_path('2') - assert not attr['exists'] + attr = cache.data_path("2") + assert not attr["exists"] # note that v0.2.0/f2.txt is identical to v0.1.0/f2.txt - cache.load_manifest('project-x_manifest_v2.0.0.json') - attr = cache.data_path('2') - assert not attr['exists'] + cache.load_manifest("project-x_manifest_v2.0.0.json") + attr = cache.data_path("2") + assert not attr["exists"] # re download the file - cache.download_data('2') - attr = cache.data_path('2') - downloaded_path = attr['local_path'] + cache.download_data("2") + attr = cache.data_path("2") + downloaded_path = attr["local_path"] - assert attr['exists'] + assert attr["exists"] hasher = hashlib.blake2b() - with open(attr['local_path'], 'rb') as in_file: + with open(attr["local_path"], "rb") as in_file: hasher.update(in_file.read()) test_hash = hasher.hexdigest() assert test_hash == true_hash @@ -380,11 +369,11 @@ def test_corrupted_download_manifest(tmpdir, example_datasets): # check that the v0.1.0 version of the file, which should be # identical to the v0.2.0 version of the file, is also # fixed - cache.load_manifest('project-x_manifest_v1.0.0.json') - attr = cache.data_path('2') - assert attr['exists'] - assert attr['local_path'].resolve() == downloaded_path.resolve() - assert attr['local_path'].absolute() != downloaded_path.absolute() + cache.load_manifest("project-x_manifest_v1.0.0.json") + attr = cache.data_path("2") + assert attr["exists"] + assert attr["local_path"].resolve() == downloaded_path.resolve() + assert attr["local_path"].absolute() != downloaded_path.absolute() @mock_s3 @@ -404,49 +393,48 @@ def _download_file(self, file_attributes: CacheFileAttributes): # first two versions of dataset are identical; # third differs example_data = {} - example_data['1.0.0'] = {} - example_data['1.0.0']['f1.txt'] = {'file_id': '1', 'data': b'abc'} - example_data['1.0.0']['f2.txt'] = {'file_id': '2', 'data': b'def'} + example_data["1.0.0"] = {} + example_data["1.0.0"]["f1.txt"] = {"file_id": "1", "data": b"abc"} + example_data["1.0.0"]["f2.txt"] = {"file_id": "2", "data": b"def"} - example_data['2.0.0'] = {} - example_data['2.0.0']['f1.txt'] = {'file_id': '1', 'data': b'abc'} - example_data['2.0.0']['f2.txt'] = {'file_id': '2', 'data': b'def'} + example_data["2.0.0"] = {} + example_data["2.0.0"]["f1.txt"] = {"file_id": "1", "data": b"abc"} + example_data["2.0.0"]["f2.txt"] = {"file_id": "2", "data": b"def"} - example_data['3.0.0'] = {} - example_data['3.0.0']['f1.txt'] = {'file_id': '1', 'data': b'tuv'} - example_data['3.0.0']['f2.txt'] = {'file_id': '2', 'data': b'wxy'} + example_data["3.0.0"] = {} + example_data["3.0.0"]["f1.txt"] = {"file_id": "1", "data": b"tuv"} + example_data["3.0.0"]["f2.txt"] = {"file_id": "2", "data": b"wxy"} - test_bucket_name = 'cache_from_scratch_bucket' - create_bucket(test_bucket_name, - example_data) + test_bucket_name = "cache_from_scratch_bucket" + create_bucket(test_bucket_name, example_data) - cache_dir = pathlib.Path(tmpdir) / 'cache' + cache_dir = pathlib.Path(tmpdir) / "cache" # read in v1.0.0 data files using normal S3 cache class with warnings_mod.catch_warnings(record=True) as w: warnings_mod.simplefilter("always") - cache = S3CloudCache(cache_dir, test_bucket_name, 'project-x') + cache = S3CloudCache(cache_dir, test_bucket_name, "project-x") # make sure no MissingLocalManifestWarnings were raised - w_type = 'MissingLocalManifestWarning' + w_type = "MissingLocalManifestWarning" for wi in w: if wi.category.__name__ == w_type: - msg = 'Raised MissingLocalManifestWarning on empty ' - msg += 'cache dir' + msg = "Raised MissingLocalManifestWarning on empty " + msg += "cache dir" assert False, msg expected_hash = {} - cache.load_manifest('project-x_manifest_v1.0.0.json') - for file_id in ('1', '2'): + cache.load_manifest("project-x_manifest_v1.0.0.json") + for file_id in ("1", "2"): local_path = cache.download_data(file_id) hasher = hashlib.blake2b() - with open(local_path, 'rb') as in_file: + with open(local_path, "rb") as in_file: hasher.update(in_file.read()) expected_hash[file_id] = hasher.hexdigest() # load the other manifests, so DummyCache can get it - cache.load_manifest('project-x_manifest_v2.0.0.json') - cache.load_manifest('project-x_manifest_v3.0.0.json') + cache.load_manifest("project-x_manifest_v2.0.0.json") + cache.load_manifest("project-x_manifest_v3.0.0.json") # delete the JSON file that maps local path to file hash lookup_path = cache._downloaded_data_path @@ -461,23 +449,23 @@ def _download_file(self, file_attributes: CacheFileAttributes): # are returned. This will mean that the local manifest mapping # filename to file hash was correctly reconstructed. with pytest.warns(MissingLocalManifestWarning): - dummy = DummyCache(cache_dir, test_bucket_name, 'project-x') + dummy = DummyCache(cache_dir, test_bucket_name, "project-x") dummy.construct_local_manifest() - dummy.load_manifest('project-x_manifest_v2.0.0.json') - for file_id in ('1', '2'): + dummy.load_manifest("project-x_manifest_v2.0.0.json") + for file_id in ("1", "2"): local_path = dummy.download_data(file_id) hasher = hashlib.blake2b() - with open(local_path, 'rb') as in_file: + with open(local_path, "rb") as in_file: hasher.update(in_file.read()) assert hasher.hexdigest() == expected_hash[file_id] # make sure that dummy really is unable to download by trying # (and failing) to get data from v3.0.0 - dummy.load_manifest('project-x_manifest_v3.0.0.json') + dummy.load_manifest("project-x_manifest_v3.0.0.json") with pytest.raises(RuntimeError): - dummy.download_data('1') + dummy.download_data("1") @mock_s3 @@ -486,22 +474,21 @@ def test_local_cache_symlink(tmpdir, example_datasets): Test that a LocalCache is smart enough to construct a symlink where appropriate """ - test_bucket_name = 'local_cache_test_bucket' - create_bucket(test_bucket_name, - example_datasets) + test_bucket_name = "local_cache_test_bucket" + create_bucket(test_bucket_name, example_datasets) - cache_dir = pathlib.Path(tmpdir) / 'cache' + cache_dir = pathlib.Path(tmpdir) / "cache" # create an online cache and download some data - online_cache = S3CloudCache(cache_dir, test_bucket_name, 'project-x') - online_cache.load_manifest('project-x_manifest_v1.0.0.json') - p0 = online_cache.download_data('1') - online_cache.load_manifest('project-x_manifest_v3.0.0.json') + online_cache = S3CloudCache(cache_dir, test_bucket_name, "project-x") + online_cache.load_manifest("project-x_manifest_v1.0.0.json") + p0 = online_cache.download_data("1") + online_cache.load_manifest("project-x_manifest_v3.0.0.json") # path to file we intend to download # (just making sure it wasn't accidentally created early # by the online cache) - shld_be = cache_dir / 'project-x-3.0.0/data/f1.txt' + shld_be = cache_dir / "project-x-3.0.0/data/f1.txt" assert not shld_be.exists() del online_cache @@ -509,17 +496,17 @@ def test_local_cache_symlink(tmpdir, example_datasets): # create a local cache pointing to the same cache directory # an try to access a data file that, while not downloaded, # is identical to a file that has been downloaded - local_cache = LocalCache(cache_dir, test_bucket_name, 'project-x') - local_cache.load_manifest('project-x_manifest_v3.0.0.json') - attr = local_cache.data_path('1') - assert attr['exists'] - assert attr['local_path'].absolute() == shld_be.absolute() - assert attr['local_path'].is_symlink() - assert attr['local_path'].resolve() == p0.resolve() + local_cache = LocalCache(cache_dir, test_bucket_name, "project-x") + local_cache.load_manifest("project-x_manifest_v3.0.0.json") + attr = local_cache.data_path("1") + assert attr["exists"] + assert attr["local_path"].absolute() == shld_be.absolute() + assert attr["local_path"].is_symlink() + assert attr["local_path"].resolve() == p0.resolve() # test that LocalCache does not have access to data that # has not been downloaded - attr = local_cache.data_path('2') - assert not attr['exists'] + attr = local_cache.data_path("2") + assert not attr["exists"] with pytest.raises(NotImplementedError): - local_cache.download_data('2') + local_cache.download_data("2") diff --git a/allensdk/test/api/cloud_cache/test_static_local_cache.py b/allensdk/test/api/cloud_cache/test_static_local_cache.py index 484828f018..ee183d9459 100644 --- a/allensdk/test/api/cloud_cache/test_static_local_cache.py +++ b/allensdk/test/api/cloud_cache/test_static_local_cache.py @@ -18,33 +18,21 @@ def mounted_s3_dataset_fixture(tmp_path, request) -> Tuple[Path, str, dict]: # Get fixture parameters project_name = request.param.get("project_name", "test_project_name_1") dataset_version = request.param.get("dataset_version", "0.3.0") - metadata_file_id_column_name = request.param.get( - "metadata_file_id_column_name", "file_id" - ) + metadata_file_id_column_name = request.param.get("metadata_file_id_column_name", "file_id") metadata_files_contents = request.param.get( "metadata_files_contents", # Each item in list is a tuple of: # (metadata_filename, metadata_contents) [ ("metadata_1.csv", {"mouse": [1, 2, 3], "sex": ["F", "F", "M"]}), - ( - "metadata_2.csv", - { - "experiment": [4, 5, 6], - metadata_file_id_column_name: ["data1", "data2", "data3"] - } - ) - ] + ("metadata_2.csv", {"experiment": [4, 5, 6], metadata_file_id_column_name: ["data1", "data2", "data3"]}), + ], ) data_files_contents = request.param.get( "data_files_contents", # Each item in list is a tuple of: # (data_filename, data_contents) - [ - ("data_1.nwb", "123456"), - ("data_2.nwb", "abcdef"), - ("data_3.nwb", "ghijkl") - ] + [("data_1.nwb", "123456"), ("data_2.nwb", "abcdef"), ("data_3.nwb", "ghijkl")], ) # Create mock mounted s3 directory structure @@ -64,12 +52,9 @@ def mounted_s3_dataset_fixture(tmp_path, request) -> Tuple[Path, str, dict]: df_to_save.to_csv(str(meta_save_path), index=False) manifest_meta_entries[meta_fname.rstrip(".csv")] = { - "url": ( - f"http://{project_name}.s3.amazonaws.com/{project_name}" - f"/project_metadata/{meta_fname}" - ), + "url": (f"http://{project_name}.s3.amazonaws.com/{project_name}/project_metadata/{meta_fname}"), "version_id": "test_placeholder", - "file_hash": file_hash_from_path(meta_save_path) + "file_hash": file_hash_from_path(meta_save_path), } # Create data files and manifest entries @@ -79,16 +64,13 @@ def mounted_s3_dataset_fixture(tmp_path, request) -> Tuple[Path, str, dict]: manifest_data_entries = dict() for file_fname, file_contents in data_files_contents: file_save_path = mock_data_dir / file_fname - with file_save_path.open('w') as f: + with file_save_path.open("w") as f: f.write(file_contents) manifest_data_entries[file_fname.rstrip(".nwb")] = { - "url": ( - f"http://{project_name}.s3.amazonaws.com/{project_name}" - f"/project_data/{file_fname}" - ), + "url": (f"http://{project_name}.s3.amazonaws.com/{project_name}/project_data/{file_fname}"), "version_id": "test_placeholder", - "file_hash": file_hash_from_path(file_save_path) + "file_hash": file_hash_from_path(file_save_path), } # Create manifest dir and manifest @@ -100,35 +82,22 @@ def mounted_s3_dataset_fixture(tmp_path, request) -> Tuple[Path, str, dict]: manifest_contents = { "project_name": project_name, "manifest_version": dataset_version, - "data_pipeline": [ - { - "name": "AllenSDK", - "version": "2.11.0", - "comment": "This is a test entry. NOT REAL." - } - ], + "data_pipeline": [{"name": "AllenSDK", "version": "2.11.0", "comment": "This is a test entry. NOT REAL."}], "metadata_file_id_column_name": metadata_file_id_column_name, "metadata_files": manifest_meta_entries, - "data_files": manifest_data_entries + "data_files": manifest_data_entries, } - with manifest_path.open('w') as f: + with manifest_path.open("w") as f: json.dump(manifest_contents, f, indent=4) - expected = { - "expected_metadata": metadata_files_contents, - "expected_data": data_files_contents - } + expected = {"expected_metadata": metadata_files_contents, "expected_data": data_files_contents} return mock_mounted_base_dir, project_name, expected @pytest.mark.parametrize( - "mounted_s3_dataset_fixture", - [ - {"project_name": "visual-behavior-ophys"} - ], - indirect=["mounted_s3_dataset_fixture"] + "mounted_s3_dataset_fixture", [{"project_name": "visual-behavior-ophys"}], indirect=["mounted_s3_dataset_fixture"] ) def test_static_local_cache_access(mounted_s3_dataset_fixture): local_static_cache_dir, proj_name, expected = mounted_s3_dataset_fixture @@ -152,30 +121,18 @@ def test_static_local_cache_access(mounted_s3_dataset_fixture): @pytest.mark.parametrize( "num_manifests, project_name, create_project_folders, expected", [ - ( - 2, - "test_project", - True, - ['test_project_manifest_v0.1.0.json'] - ), - ( - 4, - "test_project_2", - True, - ['test_project_2_manifest_v0.3.0.json'] - ), + (2, "test_project", True, ["test_project_manifest_v0.1.0.json"]), + (4, "test_project_2", True, ["test_project_2_manifest_v0.3.0.json"]), # This test case is expected to raise a RuntimeError ( None, # Not applicable "test_project_2", False, - None # Not applicable - ) - ] + None, # Not applicable + ), + ], ) -def test_static_local_cache_list_all_manifests( - tmp_path, num_manifests, project_name, create_project_folders, expected -): +def test_static_local_cache_list_all_manifests(tmp_path, num_manifests, project_name, create_project_folders, expected): cache_dir = tmp_path / "cache_dir" cache_dir.mkdir() @@ -187,9 +144,7 @@ def test_static_local_cache_list_all_manifests( manifests_dir.mkdir() for n in range(num_manifests): - manifest_path = ( - manifests_dir / f"{project_name}_manifest_v0.{n}.0.json" - ) + manifest_path = manifests_dir / f"{project_name}_manifest_v0.{n}.0.json" manifest_path.touch() cache = StaticLocalCache(cache_dir, project_name) @@ -197,7 +152,5 @@ def test_static_local_cache_list_all_manifests( assert cache._manifest_file_names == expected else: - with pytest.raises( - RuntimeError, match="Expected the provided cache_dir" - ): + with pytest.raises(RuntimeError, match="Expected the provided cache_dir"): _ = StaticLocalCache(cache_dir, project_name) diff --git a/allensdk/test/api/cloud_cache/test_utils.py b/allensdk/test/api/cloud_cache/test_utils.py index e0a6f86086..ee54ea2e34 100644 --- a/allensdk/test/api/cloud_cache/test_utils.py +++ b/allensdk/test/api/cloud_cache/test_utils.py @@ -5,7 +5,6 @@ def test_bucket_name_from_url(): - url = 'https://dummy_bucket.s3.amazonaws.com/txt_file.txt?versionId="jklaafdaerew"' # noqa: E501 bucket_name = utils.bucket_name_from_url(url) assert bucket_name == "dummy_bucket" @@ -29,21 +28,20 @@ def test_bucket_name_from_url(): def test_relative_path_from_url(): url = 'https://dummy_bucket.s3.amazonaws.com/my/dir/txt_file.txt?versionId="jklaafdaerew"' # noqa: E501 relative_path = utils.relative_path_from_url(url) - assert relative_path == 'my/dir/txt_file.txt' + assert relative_path == "my/dir/txt_file.txt" def test_file_hash_from_path(tmpdir): - rng = np.random.RandomState(881) - alphabet = list('abcdefghijklmnopqrstuvwxyz') - fname = tmpdir / 'hash_dummy.txt' - with open(fname, 'w') as out_file: + alphabet = list("abcdefghijklmnopqrstuvwxyz") + fname = tmpdir / "hash_dummy.txt" + with open(fname, "w") as out_file: for ii in range(10): - out_file.write(''.join(rng.choice(alphabet, size=10))) - out_file.write('\n') + out_file.write("".join(rng.choice(alphabet, size=10))) + out_file.write("\n") hasher = hashlib.blake2b() - with open(fname, 'rb') as in_file: + with open(fname, "rb") as in_file: chunk = in_file.read(7) while len(chunk) > 0: hasher.update(chunk) diff --git a/allensdk/test/api/cloud_cache/test_windows_isilon_paths.py b/allensdk/test/api/cloud_cache/test_windows_isilon_paths.py index ac656a0520..0381276b02 100644 --- a/allensdk/test/api/cloud_cache/test_windows_isilon_paths.py +++ b/allensdk/test/api/cloud_cache/test_windows_isilon_paths.py @@ -16,20 +16,31 @@ def test_windows_path_to_isilon(monkeypatch, tmpdir): cache_dir = Path(tmpdir) - manifest_1 = {'manifest_version': '1', - 'metadata_file_id_column_name': 'file_id', - 'data_pipeline': 'placeholder', - 'project_name': 'my-project', - 'metadata_files': {'a.csv': {'url': 'http://www.junk.com/path/to/a.csv', # noqa: E501 - 'version_id': '1111', - 'file_hash': 'abcde'}, - 'b.csv': {'url': 'http://silly.com/path/to/b.csv', # noqa: E501 - 'version_id': '2222', - 'file_hash': 'fghijk'}}, - 'data_files': {'data_1': {'url': 'http://www.junk.com/data/path/data.csv', # noqa: E501 - 'version_id': '1111', - 'file_hash': 'lmnopqrst'}} - } + manifest_1 = { + "manifest_version": "1", + "metadata_file_id_column_name": "file_id", + "data_pipeline": "placeholder", + "project_name": "my-project", + "metadata_files": { + "a.csv": { + "url": "http://www.junk.com/path/to/a.csv", # noqa: E501 + "version_id": "1111", + "file_hash": "abcde", + }, + "b.csv": { + "url": "http://silly.com/path/to/b.csv", # noqa: E501 + "version_id": "2222", + "file_hash": "fghijk", + }, + }, + "data_files": { + "data_1": { + "url": "http://www.junk.com/data/path/data.csv", # noqa: E501 + "version_id": "1111", + "file_hash": "lmnopqrst", + } + }, + } manifest_path = tmpdir / "manifest.json" with open(manifest_path, "w") as f: json.dump(manifest_1, f) @@ -39,15 +50,15 @@ def dummy_file_exists(self, m): # we do not want paths to `/allen` to be resolved to # a local drive on the user's machine - bad_windows_pattern = re.compile('^[A-Z]\:') # noqa: W605 + bad_windows_pattern = re.compile("^[A-Z]\:") # noqa: W605 # make sure pattern is correctly formulated - m = bad_windows_pattern.search('C:\\a\windows\path') # noqa: W605 + m = bad_windows_pattern.search("C:\\a\windows\path") # noqa: W605 assert m is not None with monkeypatch.context() as ctx: - class TestCloudCache(CloudCacheBase): + class TestCloudCache(CloudCacheBase): def _download_file(self, m, o): pass @@ -57,14 +68,12 @@ def _download_manifest(self, m, o): def _list_all_manifests(self): pass - ctx.setattr(TestCloudCache, - '_file_exists', - dummy_file_exists) + ctx.setattr(TestCloudCache, "_file_exists", dummy_file_exists) - cache = TestCloudCache(cache_dir, 'proj') + cache = TestCloudCache(cache_dir, "proj") cache._manifest = Manifest(cache_dir, json_input=manifest_path) - m_path = cache.metadata_path('a.csv') + m_path = cache.metadata_path("a.csv") assert bad_windows_pattern.match(str(m_path)) is None - d_path = cache.data_path('data_1') + d_path = cache.data_path("data_1") assert bad_windows_pattern.match(str(d_path)) is None diff --git a/allensdk/test/api/cloud_cache/utils.py b/allensdk/test/api/cloud_cache/utils.py index f58524bbab..5518b2048c 100644 --- a/allensdk/test/api/cloud_cache/utils.py +++ b/allensdk/test/api/cloud_cache/utils.py @@ -4,11 +4,9 @@ import hashlib -def load_dataset(data_blobs: dict, - metadata_blobs: Union[dict, None], - manifest_version: str, - bucket_name: str, - client: boto3.client) -> None: +def load_dataset( + data_blobs: dict, metadata_blobs: Union[dict, None], manifest_version: str, bucket_name: str, client: boto3.client +) -> None: """ Load a test dataset into moto's mocked S3 @@ -37,76 +35,64 @@ def load_dataset(data_blobs: dict, and uploads the manifest to moto3 """ - project_name = 'project-x' + project_name = "project-x" for fname in data_blobs: - client.put_object(Bucket=bucket_name, - Key=f'project-x/data/{fname}', - Body=data_blobs[fname]['data']) + client.put_object(Bucket=bucket_name, Key=f"project-x/data/{fname}", Body=data_blobs[fname]["data"]) if metadata_blobs is not None: for fname in metadata_blobs: - client.put_object(Bucket=bucket_name, - Key=f'project-x/project_metadata/{fname}', - Body=metadata_blobs[fname]) + client.put_object(Bucket=bucket_name, Key=f"project-x/project_metadata/{fname}", Body=metadata_blobs[fname]) response = client.list_object_versions(Bucket=bucket_name) fname_to_version = {} - for obj in response['Versions']: - if obj['IsLatest']: - fname = obj['Key'].split('/')[-1] - fname_to_version[fname] = obj['VersionId'] + for obj in response["Versions"]: + if obj["IsLatest"]: + fname = obj["Key"].split("/")[-1] + fname_to_version[fname] = obj["VersionId"] manifest = {} - manifest['manifest_version'] = manifest_version - manifest['project_name'] = project_name - manifest['metadata_file_id_column_name'] = 'file_id' - manifest['metadata_files'] = {} - manifest['data_pipeline'] = 'placeholder' + manifest["manifest_version"] = manifest_version + manifest["project_name"] = project_name + manifest["metadata_file_id_column_name"] = "file_id" + manifest["metadata_files"] = {} + manifest["data_pipeline"] = "placeholder" data_file_dict = {} - url_root = f'http://{bucket_name}.s3.amazonaws.com/{project_name}/data' + url_root = f"http://{bucket_name}.s3.amazonaws.com/{project_name}/data" for fname in data_blobs: - url = f'{url_root}/{fname}' + url = f"{url_root}/{fname}" hasher = hashlib.blake2b() - hasher.update(data_blobs[fname]['data']) + hasher.update(data_blobs[fname]["data"]) checksum = hasher.hexdigest() - data_file = {'url': url, - 'version_id': fname_to_version[fname], - 'file_hash': checksum} + data_file = {"url": url, "version_id": fname_to_version[fname], "file_hash": checksum} - data_file_dict[data_blobs[fname]['file_id']] = data_file + data_file_dict[data_blobs[fname]["file_id"]] = data_file - manifest['data_files'] = data_file_dict + manifest["data_files"] = data_file_dict if metadata_blobs is not None: - url_root = f'http://{bucket_name}.s3.amazonaws.com/{project_name}/' - url_root += 'project_metadata' + url_root = f"http://{bucket_name}.s3.amazonaws.com/{project_name}/" + url_root += "project_metadata" metadata_dict = {} for fname in metadata_blobs: - url = f'{url_root}/{fname}' + url = f"{url_root}/{fname}" hasher = hashlib.blake2b() hasher.update(metadata_blobs[fname]) - metadata_dict[fname] = {'url': url, - 'file_hash': hasher.hexdigest(), - 'version_id': fname_to_version[fname]} + metadata_dict[fname] = {"url": url, "file_hash": hasher.hexdigest(), "version_id": fname_to_version[fname]} - manifest['metadata_files'] = metadata_dict + manifest["metadata_files"] = metadata_dict - manifest_k = f'{project_name}/manifests/' - manifest_k += f'{project_name}_manifest_v{manifest_version}.json' - client.put_object(Bucket=bucket_name, - Key=manifest_k, - Body=bytes(json.dumps(manifest), 'utf-8')) + manifest_k = f"{project_name}/manifests/" + manifest_k += f"{project_name}_manifest_v{manifest_version}.json" + client.put_object(Bucket=bucket_name, Key=manifest_k, Body=bytes(json.dumps(manifest), "utf-8")) return None -def create_bucket(test_bucket_name: str, - datasets: dict, - metadatasets: Optional[dict] = None) -> None: +def create_bucket(test_bucket_name: str, datasets: dict, metadatasets: Optional[dict] = None) -> None: """ Create a bucket and populate it with example datasets @@ -124,14 +110,14 @@ def create_bucket(test_bucket_name: str, metadata files to be loaded to the bucket (default: None) """ - conn = boto3.resource('s3', region_name='us-east-1') - conn.create_bucket(Bucket=test_bucket_name, ACL='public-read') + conn = boto3.resource("s3", region_name="us-east-1") + conn.create_bucket(Bucket=test_bucket_name, ACL="public-read") # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#bucketversioning bucket_versioning = conn.BucketVersioning(test_bucket_name) bucket_versioning.enable() - client = boto3.client('s3', region_name='us-east-1') + client = boto3.client("s3", region_name="us-east-1") # upload first dataset for v in datasets.keys(): @@ -139,10 +125,6 @@ def create_bucket(test_bucket_name: str, m = metadatasets[v] else: m = None - load_dataset(datasets[v], - m, - v, - test_bucket_name, - client) + load_dataset(datasets[v], m, v, test_bucket_name, client) return None diff --git a/allensdk/test/api/queries/test_utils.py b/allensdk/test/api/queries/test_utils.py index 18a492e9f9..4738f09938 100644 --- a/allensdk/test/api/queries/test_utils.py +++ b/allensdk/test/api/queries/test_utils.py @@ -1,18 +1,10 @@ import pytest -from allensdk.internal.api.queries.utils import \ - _convert_list_of_string_to_sql_safe_string +from allensdk.internal.api.queries.utils import _convert_list_of_string_to_sql_safe_string @pytest.mark.parametrize( - 'strings, expected', - ( - (['A', 'B'], ["'A'", "'B'"]), - (["'A'", "'B'"], ["'A'", "'B'"]), - (['A'], ["'A'"]), - ([], []) - ) + "strings, expected", ((["A", "B"], ["'A'", "'B'"]), (["'A'", "'B'"], ["'A'", "'B'"]), (["A"], ["'A'"]), ([], [])) ) def test_convert_list_of_string_to_sql_safe_string(strings, expected): - assert _convert_list_of_string_to_sql_safe_string(strings=strings) == \ - expected + assert _convert_list_of_string_to_sql_safe_string(strings=strings) == expected diff --git a/allensdk/test/api/test_annotated_section_data_set_api.py b/allensdk/test/api/test_annotated_section_data_set_api.py index 8c22dafb15..fac1634b0a 100644 --- a/allensdk/test/api/test_annotated_section_data_set_api.py +++ b/allensdk/test/api/test_annotated_section_data_set_api.py @@ -33,8 +33,7 @@ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. # -from allensdk.api.queries.annotated_section_data_sets_api import \ - AnnotatedSectionDataSetsApi +from allensdk.api.queries.annotated_section_data_sets_api import AnnotatedSectionDataSetsApi import pytest from unittest.mock import MagicMock @@ -43,7 +42,7 @@ def annotated(): asdsa = AnnotatedSectionDataSetsApi() - asdsa.json_msg_query = MagicMock(name='json_msg_query') + asdsa.json_msg_query = MagicMock(name="json_msg_query") return asdsa @@ -54,13 +53,15 @@ def test_get_annotated_section_data_set(annotated): intensity_values=["High", "Low", "Medium"], density_values=["High", "Low"], pattern_values=["Full"], - age_names=["E11.5", "13.5"]) + age_names=["E11.5", "13.5"], + ) annotated.json_msg_query.assert_called_once_with( "http://api.brain-map.org/api/v2/annotated_section_data_sets.json" "?structures=112763676&intensity_values='High','Low','Medium'" "&density_values='High','Low'" - "&pattern_values='Full'&age_names='E11.5','13.5'") + "&pattern_values='Full'&age_names='E11.5','13.5'" + ) def test_get_compound_annotated_section_data_set(annotated): @@ -69,29 +70,31 @@ def test_get_compound_annotated_section_data_set(annotated): intensity_values=["High", "Low", "Medium"], density_values=["High", "Low"], pattern_values=["Full"], - age_names=["E11.5", "13.5"]) + age_names=["E11.5", "13.5"], + ) annotated.json_msg_query.assert_called_once_with( "http://api.brain-map.org/api/v2/annotated_section_data_sets.json?" "structures=112763676" "&intensity_values='High','Low','Medium'&density_values='High','Low'" "&pattern_values='Full'" - "&age_names='E11.5','13.5'") + "&age_names='E11.5','13.5'" + ) def test_get_annotated_section_data_set_via_rma(annotated): - annotated.json_msg_query = \ - MagicMock(name='json_msg_query') + annotated.json_msg_query = MagicMock(name="json_msg_query") annotated.get_compound_annotated_section_data_sets( - [{'structures': [112763676], - 'intensity_values': ['High', 'Low'], - 'link': 'or'}, - {'structures': [112763686], - 'intensity_values': ['Low']}]) + [ + {"structures": [112763676], "intensity_values": ["High", "Low"], "link": "or"}, + {"structures": [112763686], "intensity_values": ["Low"]}, + ] + ) annotated.json_msg_query.assert_called_once_with( "http://api.brain-map.org" "/api/v2/compound_annotated_section_data_sets.json" "?query=[structures $in 112763676 : intensity_values $in 'High','Low']" - " or [structures $in 112763686 : intensity_values $in 'Low']") + " or [structures $in 112763686 : intensity_values $in 'Low']" + ) diff --git a/allensdk/test/api/test_api.py b/allensdk/test/api/test_api.py index c00b5a87a0..124b829e87 100644 --- a/allensdk/test/api/test_api.py +++ b/allensdk/test/api/test_api.py @@ -47,7 +47,8 @@ from allensdk.api.api import Api, stream_file_over_http, stream_zip_directory_over_http -_msg = {'whatever': True} +_msg = {"whatever": True} + @pytest.fixture def api(): @@ -56,21 +57,19 @@ def api(): @pytest.fixture def response(): - resp = MagicMock() - resp.iter_content = lambda *a, **k: iter([b'1', b'2', b'3']) + resp.iter_content = lambda *a, **k: iter([b"1", b"2", b"3"]) return resp @pytest.fixture def zip_response(): - flike = io.BytesIO() - data = '122333444455555' + data = "122333444455555" - zipper = zipfile.ZipFile(flike, mode='w') - zipper.writestr('test.txt', data) + zipper = zipfile.ZipFile(flike, mode="w") + zipper.writestr("test.txt", data) zipper.close() return flike.getvalue() @@ -78,96 +77,79 @@ def zip_response(): def test_failed_download(api): with pytest.raises(HTTPError) as e_info: - api.retrieve_file_over_http('http://example.com/yo.jpg', - '/tmp/testfile') + api.retrieve_file_over_http("http://example.com/yo.jpg", "/tmp/testfile") - assert e_info.typename == 'HTTPError' + assert e_info.typename == "HTTPError" def test_request_timeout(api): def raise_read_timeout(response, path=None): raise requests.exceptions.ReadTimeout - with patch('requests.get', return_value=MagicMock()) as get_mock: + with patch("requests.get", return_value=MagicMock()) as get_mock: response_mock = get_mock.return_value response_mock.raise_for_status = MagicMock() - - with patch( - 'requests_toolbelt.downloadutils.stream.stream_response_to_file', - MagicMock(name='stream_response_to_file', - side_effect=raise_read_timeout)) as stream_mock: - - with patch(builtins.__name__ + '.open', - mock_open(), - create=True) as open_mock: - with patch('os.remove', MagicMock()) as os_remove: + with patch( + "requests_toolbelt.downloadutils.stream.stream_response_to_file", + MagicMock(name="stream_response_to_file", side_effect=raise_read_timeout), + ) as stream_mock: + with patch(builtins.__name__ + ".open", mock_open(), create=True) as open_mock: + with patch("os.remove", MagicMock()) as os_remove: with pytest.raises(requests.exceptions.ReadTimeout) as e_info: - api.retrieve_file_over_http('http://example.com/yo.jpg', - '/tmp/testfile') + api.retrieve_file_over_http("http://example.com/yo.jpg", "/tmp/testfile") - assert e_info.typename == 'ReadTimeout' + assert e_info.typename == "ReadTimeout" stream_mock.assert_called_with(response_mock, path=open_mock.return_value) - get_mock.assert_called_once_with('http://example.com/yo.jpg', - stream=True, - timeout=(9.05, 31.1)) - open_mock.assert_called_once_with('/tmp/testfile', 'wb') - os_remove.assert_called_once_with('/tmp/testfile') + get_mock.assert_called_once_with("http://example.com/yo.jpg", stream=True, timeout=(9.05, 31.1)) + open_mock.assert_called_once_with("/tmp/testfile", "wb") + os_remove.assert_called_once_with("/tmp/testfile") @patch("allensdk.core.json_utilities.read_url_post", return_value=_msg) def test_do_query_post(ju_read_url_post, api): - api.do_query(lambda *a, **k: 'http://localhost/%s' % (a[0]), - lambda d: d, - "wow", - post=True) + api.do_query(lambda *a, **k: "http://localhost/%s" % (a[0]), lambda d: d, "wow", post=True) - ju_read_url_post.assert_called_once_with('http://localhost/wow') + ju_read_url_post.assert_called_once_with("http://localhost/wow") @patch("allensdk.core.json_utilities.read_url_get", return_value=_msg) def test_do_query_get(ju_read_url_get, api): - api.do_query(lambda *a, **k: 'http://localhost/%s' % (a[0]), - lambda d: d, - "wow", - post=False) + api.do_query(lambda *a, **k: "http://localhost/%s" % (a[0]), lambda d: d, "wow", post=False) - ju_read_url_get.assert_called_once_with('http://localhost/wow') + ju_read_url_get.assert_called_once_with("http://localhost/wow") @patch("allensdk.core.json_utilities.read_url_get", return_value=_msg) def test_load_api_schema(ju_read_url_get, api): api.load_api_schema() - ju_read_url_get.assert_called_once_with( - 'http://api.brain-map.org/api/v2/data/enumerate.json') + ju_read_url_get.assert_called_once_with("http://api.brain-map.org/api/v2/data/enumerate.json") def test_stream_file_over_http(response, tmpdir_factory): + path = tmpdir_factory.mktemp("file_stream_test").join("test.txt") - path = tmpdir_factory.mktemp('file_stream_test').join('test.txt') - - with patch('requests.get', return_value=response): - stream_file_over_http('https://fish.gov', str(path)) + with patch("requests.get", return_value=response): + stream_file_over_http("https://fish.gov", str(path)) - with open(str(path), 'r') as fil: + with open(str(path), "r") as fil: data = fil.read() - assert( data == '123' ) + assert data == "123" def test_stream_zip_directory_over_http(zip_response, tmpdir_factory): + path = tmpdir_factory.mktemp("zip_stream_test").join("test.txt") - path = tmpdir_factory.mktemp('zip_stream_test').join('test.txt') - - with patch('requests.get'): - with patch('requests_toolbelt.downloadutils.stream.stream_response_to_file', - side_effect=lambda r, b: b.write(zip_response)): - - stream_zip_directory_over_http('https://fish.gov', os.path.dirname(str(path))) + with patch("requests.get"): + with patch( + "requests_toolbelt.downloadutils.stream.stream_response_to_file", + side_effect=lambda r, b: b.write(zip_response), + ): + stream_zip_directory_over_http("https://fish.gov", os.path.dirname(str(path))) - with open(str(path), 'r') as fil: + with open(str(path), "r") as fil: data = fil.read() - assert(data == '122333444455555') - \ No newline at end of file + assert data == "122333444455555" diff --git a/allensdk/test/api/test_biophysical_api.py b/allensdk/test/api/test_biophysical_api.py index abfeaf1f1e..aa114130ce 100644 --- a/allensdk/test/api/test_biophysical_api.py +++ b/allensdk/test/api/test_biophysical_api.py @@ -10,9 +10,9 @@ @pytest.fixture def neuronal_model_response(): dirname = os.path.dirname(__file__) - path = os.path.join(dirname, 'response_test_data', '472451419_response.json') + path = os.path.join(dirname, "response_test_data", "472451419_response.json") - with open(path, 'r') as jf: + with open(path, "r") as jf: data = json.load(jf) return data @@ -20,60 +20,62 @@ def neuronal_model_response(): @pytest.fixture def biophys_api(): - endpoint = 'http://twarehouse-backup' + endpoint = "http://twarehouse-backup" return BiophysicalApi(endpoint) -@pytest.mark.parametrize('model_id', [3]) -@pytest.mark.parametrize('fmt', [None, 'json', 'xml']) +@pytest.mark.parametrize("model_id", [3]) +@pytest.mark.parametrize("fmt", [None, "json", "xml"]) def test_build_rma(model_id, fmt, biophys_api): if fmt is None: - fmt_exp = 'json' + fmt_exp = "json" obt = biophys_api.build_rma(model_id) else: fmt_exp = fmt obt = biophys_api.build_rma(model_id, fmt_exp) - exp = 'http://twarehouse-backup/api/v2/data/query.{}?'\ - 'q=model::NeuronalModel,'\ - 'rma::criteria,[id$eq{}],'\ - 'neuronal_model_template(well_known_files(well_known_file_type)),'\ - 'specimen(ephys_result(well_known_files(well_known_file_type)),'\ - 'neuron_reconstructions(well_known_files(well_known_file_type)),ephys_sweeps),'\ - 'well_known_files(well_known_file_type),'\ - 'rma::include,neuronal_model_template(well_known_files(well_known_file_type)),'\ - 'specimen(ephys_result(well_known_files(well_known_file_type)),'\ - 'neuron_reconstructions(well_known_files(well_known_file_type)),ephys_sweeps),'\ - 'well_known_files(well_known_file_type)' + exp = ( + "http://twarehouse-backup/api/v2/data/query.{}?" + "q=model::NeuronalModel," + "rma::criteria,[id$eq{}]," + "neuronal_model_template(well_known_files(well_known_file_type))," + "specimen(ephys_result(well_known_files(well_known_file_type))," + "neuron_reconstructions(well_known_files(well_known_file_type)),ephys_sweeps)," + "well_known_files(well_known_file_type)," + "rma::include,neuronal_model_template(well_known_files(well_known_file_type))," + "specimen(ephys_result(well_known_files(well_known_file_type))," + "neuron_reconstructions(well_known_files(well_known_file_type)),ephys_sweeps)," + "well_known_files(well_known_file_type)" + ) exp = exp.format(fmt_exp, model_id) assert obt == exp def test_is_well_known_file_type(biophys_api): - wkf = {'well_known_file_type': {'name': 'fish'}} + wkf = {"well_known_file_type": {"name": "fish"}} - assert(biophys_api.is_well_known_file_type(wkf, 'fish')) - assert(not biophys_api.is_well_known_file_type(wkf, 'fowl')) + assert biophys_api.is_well_known_file_type(wkf, "fish") + assert not biophys_api.is_well_known_file_type(wkf, "fowl") @patch.object(BiophysicalApi, "json_msg_query") def test_get_neuronal_models(mock_json_msg_query, biophys_api): - biophys_api.get_neuronal_models([386049446,469753383]) + biophys_api.get_neuronal_models([386049446, 469753383]) mock_json_msg_query.assert_called_once_with( "http://twarehouse-backup/api/v2/data/query.json?" "q=model::NeuronalModel,rma::criteria,[neuronal_model_template_id$in491455321,329230710]," - "[specimen_id$in386049446,469753383],rma::options[num_rows$eq'all'][count$eqfalse]") + "[specimen_id$in386049446,469753383],rma::options[num_rows$eq'all'][count$eqfalse]" + ) def test_read_json(biophys_api, neuronal_model_response): - obt = biophys_api.read_json(neuronal_model_response) - assert(obt['stimulus']['491198851'] == "386049444.nwb") - assert(obt['morphology']['491459173'] == "Nr5a1-Cre_Ai14-177334.05.01.01_491459171_m.swc") - assert(obt['fit']['497235805'] == '386049446_fit.json') - assert(obt['marker']['496607103'] == 'Nr5a1-Cre_Ai14-177334.05.01.01_491459171_marker_m.swc') - assert(obt['modfiles']['395337293'] == os.path.join('modfiles', 'SK.mod')) - assert(np.allclose(biophys_api.sweeps, [42])) + assert obt["stimulus"]["491198851"] == "386049444.nwb" + assert obt["morphology"]["491459173"] == "Nr5a1-Cre_Ai14-177334.05.01.01_491459171_m.swc" + assert obt["fit"]["497235805"] == "386049446_fit.json" + assert obt["marker"]["496607103"] == "Nr5a1-Cre_Ai14-177334.05.01.01_491459171_marker_m.swc" + assert obt["modfiles"]["395337293"] == os.path.join("modfiles", "SK.mod") + assert np.allclose(biophys_api.sweeps, [42]) diff --git a/allensdk/test/api/test_brain_observatory_api.py b/allensdk/test/api/test_brain_observatory_api.py index 325de061a9..bec42d192d 100644 --- a/allensdk/test/api/test_brain_observatory_api.py +++ b/allensdk/test/api/test_brain_observatory_api.py @@ -66,11 +66,7 @@ def safe_msg5(): @pytest.fixture() def bo_api(): - endpoint = ( - os.environ["TEST_API_ENDPOINT"] - if "TEST_API_ENDPOINT" in os.environ - else "http://twarehouse-backup" - ) + endpoint = os.environ["TEST_API_ENDPOINT"] if "TEST_API_ENDPOINT" in os.environ else "http://twarehouse-backup" return BrainObservatoryApi(endpoint) @@ -80,41 +76,17 @@ def mock_containers(): { "targeted_structure": {"acronym": "CBS"}, "imaging_depth": 100, - "specimen": { - "donor": { - "transgenic_lines": [ - { - "name": "Shiny", - "transgenic_line_type_name": "driver"} - ] - } - }, + "specimen": {"donor": {"transgenic_lines": [{"name": "Shiny", "transgenic_line_type_name": "driver"}]}}, }, { "targeted_structure": {"acronym": "ABC"}, "imaging_depth": 150, - "specimen": { - "donor": { - "transgenic_lines": [ - { - "name": "ShinyCre", - "transgenic_line_type_name": "driver"} - ] - } - }, + "specimen": {"donor": {"transgenic_lines": [{"name": "ShinyCre", "transgenic_line_type_name": "driver"}]}}, }, { "targeted_structure": {"acronym": "NBC"}, "imaging_depth": 200, - "specimen": { - "donor": { - "transgenic_lines": [ - { - "name": "Don", - "transgenic_line_type_name": "reporter"} - ] - } - }, + "specimen": {"donor": {"transgenic_lines": [{"name": "Don", "transgenic_line_type_name": "reporter"}]}}, }, ] @@ -176,8 +148,7 @@ def mock_specimens(): def test_list_isi_experiments(mock_json_msg_query, bo_api): bo_api.list_isi_experiments() mock_json_msg_query.assert_called_once_with( - bo_api.api_url + "/api/v2/data/query.json?q=" - "model::IsiExperiment,rma::options[num_rows$eq'all'][count$eqfalse]" + bo_api.api_url + "/api/v2/data/query.json?q=model::IsiExperiment,rma::options[num_rows$eq'all'][count$eqfalse]" ) @@ -296,7 +267,7 @@ def test_get_cell_metrics_no_ids(mock_json_msg_query, bo_api): mock_json_msg_query.assert_called_once_with( bo_api.api_url + "/api/v2/data/query.json?q=" "model::ApiCamCellMetric," - "rma::options[num_rows$eq2000][start_row$eq0][order$eq'cell_specimen_id'][count$eqfalse]" # noqa e501 + "rma::options[num_rows$eq2000][start_row$eq0][order$eq'cell_specimen_id'][count$eqfalse]" # noqa e501 ) @@ -308,7 +279,7 @@ def test_get_cell_metrics_one_ids(mock_json_msg_query, bo_api): bo_api.api_url + "/api/v2/data/query.json?q=" "model::ApiCamCellMetric," "rma::criteria,[cell_specimen_id$in517394843]," - "rma::options[num_rows$eq2000][start_row$eq0][order$eq'cell_specimen_id'][count$eqfalse]" # noqa e501 + "rma::options[num_rows$eq2000][start_row$eq0][order$eq'cell_specimen_id'][count$eqfalse]" # noqa e501 ) @@ -320,14 +291,12 @@ def test_get_cell_metrics_two_ids(mock_json_msg_query, bo_api): bo_api.api_url + "/api/v2/data/query.json?q=" "model::ApiCamCellMetric," "rma::criteria,[cell_specimen_id$in517394843,517394850]," - "rma::options[num_rows$eq2000][start_row$eq0][order$eq'cell_specimen_id'][count$eqfalse]" # noqa e501 + "rma::options[num_rows$eq2000][start_row$eq0][order$eq'cell_specimen_id'][count$eqfalse]" # noqa e501 ) def test_get_cell_metrics_five_messages(bo_api, safe_msg5): - with patch( - "allensdk.core.json_utilities.read_url_get", side_effect=safe_msg5 - ) as ju_read_url_get: + with patch("allensdk.core.json_utilities.read_url_get", side_effect=safe_msg5) as ju_read_url_get: ids = [517394843, 517394850] list(bo_api.get_cell_metrics(cell_specimen_ids=ids)) @@ -337,10 +306,7 @@ def test_get_cell_metrics_five_messages(bo_api, safe_msg5): "rma::criteria,%5Bcell_specimen_id$in517394843,517394850%5D," "rma::options%5Bnum_rows$eq2000%5D%5Bstart_row$eq{}%5D%5Border$eq%27cell_specimen_id%27%5D%5Bcount$eqfalse%5D" # noqa e501 ) - expected_calls = map( - lambda c: call(base_query.format(c)), - [0, 2000, 4000, 6000, 8000, 10000] - ) + expected_calls = map(lambda c: call(base_query.format(c)), [0, 2000, 4000, 6000, 8000, 10000]) assert ju_read_url_get.call_args_list == list(expected_calls) @@ -351,22 +317,16 @@ def test_filter_experiment_containers_no_filters(bo_api, mock_containers): def test_filter_experiment_containers_depth_filter(bo_api, mock_containers): - containers = bo_api.filter_experiment_containers( - mock_containers, imaging_depths=[100] - ) + containers = bo_api.filter_experiment_containers(mock_containers, imaging_depths=[100]) assert len(containers) == 1 -def test_filter_experiment_containers_structures_filter( - bo_api, mock_containers): - containers = bo_api.filter_experiment_containers( - mock_containers, targeted_structures=["CBS"] - ) +def test_filter_experiment_containers_structures_filter(bo_api, mock_containers): + containers = bo_api.filter_experiment_containers(mock_containers, targeted_structures=["CBS"]) assert len(containers) == 1 -def test_filter_experiment_containers_lines_all_filters( - bo_api, mock_containers): +def test_filter_experiment_containers_lines_all_filters(bo_api, mock_containers): containers = bo_api.filter_experiment_containers( mock_containers, imaging_depths=[200], @@ -386,23 +346,16 @@ def test_filter_experiment_containers_lines_all_filters( assert len(containers) == 1 -def test_filter_experiment_containers_transgenic_lines( - bo_api, mock_containers): - containers = bo_api.filter_experiment_containers( - mock_containers, cre_lines=["Shiny"] - ) +def test_filter_experiment_containers_transgenic_lines(bo_api, mock_containers): + containers = bo_api.filter_experiment_containers(mock_containers, cre_lines=["Shiny"]) assert len(containers) == 0 - containers = bo_api.filter_experiment_containers( - mock_containers, cre_lines=["ShinyCre"] - ) + containers = bo_api.filter_experiment_containers(mock_containers, cre_lines=["ShinyCre"]) assert len(containers) == 1 - containers = bo_api.filter_experiment_containers( - mock_containers, transgenic_lines=["DON"] - ) + containers = bo_api.filter_experiment_containers(mock_containers, transgenic_lines=["DON"]) assert len(containers) == 1 @@ -413,37 +366,28 @@ def test_filter_ophys_experiments_no_filters(bo_api, mock_ophys_experiments): def test_filter_ophys_experiments_container_id(bo_api, mock_ophys_experiments): - experiments = bo_api.filter_ophys_experiments( - mock_ophys_experiments, experiment_container_ids=[1] - ) + experiments = bo_api.filter_ophys_experiments(mock_ophys_experiments, experiment_container_ids=[1]) assert len(experiments) == 1 def test_filter_ophys_experiments_stimuli(bo_api, mock_ophys_experiments): - experiments = bo_api.filter_ophys_experiments( - mock_ophys_experiments, stimuli=["static_gratings"] - ) + experiments = bo_api.filter_ophys_experiments(mock_ophys_experiments, stimuli=["static_gratings"]) assert len(experiments) == 1 def test_filter_ophys_experiments_eye_tracking(bo_api, mock_ophys_experiments): - experiments = bo_api.filter_ophys_experiments( - mock_ophys_experiments, require_eye_tracking=True - ) + experiments = bo_api.filter_ophys_experiments(mock_ophys_experiments, require_eye_tracking=True) assert len(experiments) == 1 def test_filter_cell_specimens(bo_api, mock_specimens): - specimens = bo_api.filter_cell_specimens(mock_specimens, - include_failed=True) + specimens = bo_api.filter_cell_specimens(mock_specimens, include_failed=True) assert specimens == mock_specimens specimens = bo_api.filter_cell_specimens(mock_specimens) assert len(specimens) == 2 - specimens = bo_api.filter_cell_specimens( - mock_specimens, ids=[mock_specimens[0]["cell_specimen_id"]] - ) + specimens = bo_api.filter_cell_specimens(mock_specimens, ids=[mock_specimens[0]["cell_specimen_id"]]) assert len(specimens) == 1 assert specimens[0] == mock_specimens[0] @@ -452,9 +396,7 @@ def test_filter_cell_specimens(bo_api, mock_specimens): cnt[sp["experiment_container_id"]] += 1 ecid = mock_specimens[0]["experiment_container_id"] - specimens = bo_api.filter_cell_specimens( - mock_specimens, experiment_container_ids=[ecid] - ) + specimens = bo_api.filter_cell_specimens(mock_specimens, experiment_container_ids=[ecid]) assert len(specimens) == cnt[ecid] assert specimens[0] == mock_specimens[0] @@ -465,9 +407,7 @@ def test_filter_cell_specimens(bo_api, mock_specimens): "json_msg_query", return_value=[{"download_link": "/url/path/to/file"}], ) -def test_save_ophys_experiment_data( - mock_json_msg_query, mock_retrieve_file_over_http, bo_api -): +def test_save_ophys_experiment_data(mock_json_msg_query, mock_retrieve_file_over_http, bo_api): with patch("allensdk.config.manifest.Manifest.safe_mkdir") as mkdir: bo_api.save_ophys_experiment_data(1, "/path/to/filename") @@ -480,9 +420,7 @@ def test_save_ophys_experiment_data( "[attachable_id$eq1],well_known_file_type[name$eqNWBOphys]," "rma::options[num_rows$eq'all'][count$eqfalse]" ) - mock_retrieve_file_over_http.assert_called_with( - bo_api.api_url + "/url/path/to/file", "/path/to/filename" - ) + mock_retrieve_file_over_http.assert_called_with(bo_api.api_url + "/url/path/to/file", "/path/to/filename") @patch.object(BrainObservatoryApi, "retrieve_file_over_http") @@ -491,9 +429,7 @@ def test_save_ophys_experiment_data( "json_msg_query", return_value=[{"download_link": "/url/path/to/file"}], ) -def test_save_ophys_experiment_event_data( - mock_json_msg_query, mock_retrieve_file_over_http, bo_api -): +def test_save_ophys_experiment_event_data(mock_json_msg_query, mock_retrieve_file_over_http, bo_api): with patch("allensdk.config.manifest.Manifest.safe_mkdir") as mkdir: bo_api.save_ophys_experiment_event_data(1, "/path/to/filename") @@ -503,12 +439,10 @@ def test_save_ophys_experiment_event_data( bo_api.api_url + "/api/v2/data/query.json?q=" "model::WellKnownFile," "rma::criteria," - "[attachable_id$eq1],well_known_file_type[name$eqObservatoryEventsFile]," # noqa e501 + "[attachable_id$eq1],well_known_file_type[name$eqObservatoryEventsFile]," # noqa e501 "rma::options[num_rows$eq'all'][count$eqfalse]" ) - mock_retrieve_file_over_http.assert_called_with( - bo_api.api_url + "/url/path/to/file", "/path/to/filename" - ) + mock_retrieve_file_over_http.assert_called_with(bo_api.api_url + "/url/path/to/file", "/path/to/filename") @pytest.mark.parametrize("ophys_experiment_id", [1, 2]) @@ -529,19 +463,11 @@ def get_dummy_data(file_id): def dummy_cloud_cache_init(cache_dir, bucket_name, project_name): pass - with patch.object(S3CloudCache, "__init__", - wraps=dummy_cloud_cache_init): - with patch.object(S3CloudCache, "load_latest_manifest", - wraps=lambda: None): - with patch.object( - S3CloudCache, "get_metadata", wraps=get_dummy_metadata - ): - with patch.object( - S3CloudCache, "download_data", wraps=get_dummy_data - ): - cloud_cache = S3CloudCache( - cache_dir="", bucket_name="", project_name="" - ) + with patch.object(S3CloudCache, "__init__", wraps=dummy_cloud_cache_init): + with patch.object(S3CloudCache, "load_latest_manifest", wraps=lambda: None): + with patch.object(S3CloudCache, "get_metadata", wraps=get_dummy_metadata): + with patch.object(S3CloudCache, "download_data", wraps=get_dummy_data): + cloud_cache = S3CloudCache(cache_dir="", bucket_name="", project_name="") if ophys_experiment_id == 2: with pytest.raises(ValueError): bo_api.save_ophys_experiment_eye_tracking_data( # noqa E501 @@ -549,11 +475,10 @@ def dummy_cloud_cache_init(cache_dir, bucket_name, project_name): cloud_cache=cloud_cache, ) else: - file_path = bo_api.\ - save_ophys_experiment_eye_tracking_data( - ophys_experiment_id=ophys_experiment_id, - cloud_cache=cloud_cache, - ) + file_path = bo_api.save_ophys_experiment_eye_tracking_data( + ophys_experiment_id=ophys_experiment_id, + cloud_cache=cloud_cache, + ) actual = np.load(str(file_path)) np.testing.assert_array_equal(actual, expected) @@ -564,9 +489,7 @@ def dummy_cloud_cache_init(cache_dir, bucket_name, project_name): "json_msg_query", return_value=[{"download_link": "/url/path/to/file"}], ) -def test_get_cell_specimen_id_mapping( - mock_json_msg_query, mock_retrieve_file_over_http, bo_api -): +def test_get_cell_specimen_id_mapping(mock_json_msg_query, mock_retrieve_file_over_http, bo_api): with patch("pandas.read_csv") as readcsv: bo_api.get_cell_specimen_id_mapping("/path/to/filename", 1) @@ -579,9 +502,7 @@ def test_get_cell_specimen_id_mapping( "[id$eq1],well_known_file_type[name$eqOphysCellSpecimenIdMapping]," "rma::options[num_rows$eq'all'][count$eqfalse]" ) - mock_retrieve_file_over_http.assert_called_with( - bo_api.api_url + "/url/path/to/file", "/path/to/filename" - ) + mock_retrieve_file_over_http.assert_called_with(bo_api.api_url + "/url/path/to/file", "/path/to/filename") def test_find_container_tags(): @@ -613,35 +534,17 @@ def test_find_specimen_cre_line(): assert cre is None # None if no 'Cre' - s = { - "donor": { - "transgenic_lines": [ - {"transgenic_line_type_name": "driver", "name": "banana"} - ] - } - } + s = {"donor": {"transgenic_lines": [{"transgenic_line_type_name": "driver", "name": "banana"}]}} cre = find_specimen_cre_line(s) assert cre is None # None if no 'Cre' - s = { - "donor": { - "transgenic_lines": [ - {"transgenic_line_type_name": "driver", "name": "bananaCre"} - ] - } - } + s = {"donor": {"transgenic_lines": [{"transgenic_line_type_name": "driver", "name": "bananaCre"}]}} cre = find_specimen_cre_line(s) assert cre == "bananaCre" # None if no 'driver' - s = { - "donor": { - "transgenic_lines": [ - {"transgenic_line_type_name": "reporter", "name": "bananaCre"} - ] - } - } + s = {"donor": {"transgenic_lines": [{"transgenic_line_type_name": "reporter", "name": "bananaCre"}]}} cre = find_specimen_cre_line(s) assert cre is None @@ -652,24 +555,12 @@ def test_find_specimen_reporter_line(): cre = find_specimen_reporter_line(s) assert cre is None - s = { - "donor": { - "transgenic_lines": [ - {"transgenic_line_type_name": "reporter", "name": "banana"} - ] - } - } + s = {"donor": {"transgenic_lines": [{"transgenic_line_type_name": "reporter", "name": "banana"}]}} cre = find_specimen_reporter_line(s) assert cre == "banana" # None if no "reporter" - s = { - "donor": { - "transgenic_lines": [ - {"transgenic_line_type_name": "driver", "name": "bananaCre"} - ] - } - } + s = {"donor": {"transgenic_lines": [{"transgenic_line_type_name": "driver", "name": "bananaCre"}]}} cre = find_specimen_reporter_line(s) assert cre is None diff --git a/allensdk/test/api/test_cache.py b/allensdk/test/api/test_cache.py index 631167c0a8..831e226c91 100755 --- a/allensdk/test/api/test_cache.py +++ b/allensdk/test/api/test_cache.py @@ -46,7 +46,7 @@ from allensdk.config.manifest import ManifestVersionError from allensdk.config.manifest_builder import ManifestBuilder -_msg = [{'whatever': True}] +_msg = [{"whatever": True}] _pd_msg = pd.DataFrame(_msg) @@ -62,7 +62,7 @@ def rma(): @pytest.fixture def wavefront_obj(): - return ''' + return """ v 8578 5484.96 5227.57 v 8509.2 5487.54 5237.07 @@ -87,13 +87,12 @@ def wavefront_obj(): f 3//3 2//2 5//5 f 6//6 3//3 5//5 - ''' + """ @pytest.fixture def dummy_cache(): class DummyCache(Cache): - VERSION = None def build_manifest(self, file_name): @@ -105,8 +104,7 @@ def build_manifest(self, file_name): def test_version_update(fn_temp_dir, dummy_cache): - - mpath = os.path.join(fn_temp_dir, 'manifest.json') + mpath = os.path.join(fn_temp_dir, "manifest.json") dummy_cache(manifest=mpath) dummy_cache(manifest=mpath) @@ -116,47 +114,37 @@ def test_version_update(fn_temp_dir, dummy_cache): def test_load_manifest(tmpdir_factory, dummy_cache): - - manifest = tmpdir_factory.mktemp('data').join('test_manifest.json') + manifest = tmpdir_factory.mktemp("data").join("test_manifest.json") cache = dummy_cache(manifest=str(manifest)) - assert(cache.manifest_path == str(manifest)) - assert(os.path.exists(cache.manifest_path)) + assert cache.manifest_path == str(manifest) + assert os.path.exists(cache.manifest_path) @patch("allensdk.core.json_utilities.write") @patch("allensdk.core.json_utilities.read", return_value=pd.DataFrame(_msg)) -@patch("allensdk.core.json_utilities.read_url_get", return_value={'msg': _msg}) +@patch("allensdk.core.json_utilities.read_url_get", return_value={"msg": _msg}) def test_wrap_json(ju_read_url_get, ju_read, ju_write, rma, cache): - df = cache.wrap(rma.model_query, - 'example.txt', - cache=True, - model='Hemisphere') + df = cache.wrap(rma.model_query, "example.txt", cache=True, model="Hemisphere") - assert df.loc[:, 'whatever'][0] + assert df.loc[:, "whatever"][0] - ju_read_url_get.assert_called_once_with( - 'http://api.brain-map.org/api/v2/data/query.json?q=model::Hemisphere') - ju_write.assert_called_once_with('example.txt', _msg) - ju_read.assert_called_once_with('example.txt') + ju_read_url_get.assert_called_once_with("http://api.brain-map.org/api/v2/data/query.json?q=model::Hemisphere") + ju_write.assert_called_once_with("example.txt", _msg) + ju_read.assert_called_once_with("example.txt") @patch("pandas.io.json.read_json", return_value=_msg) @patch("allensdk.core.json_utilities.write") -@patch("allensdk.core.json_utilities.read_url_get", return_value={'msg': _msg}) +@patch("allensdk.core.json_utilities.read_url_get", return_value={"msg": _msg}) def test_wrap_dataframe(ju_read_url_get, ju_write, mock_read_json, rma, cache): - json_data = cache.wrap(rma.model_query, - 'example.txt', - cache=True, - return_dataframe=True, - model='Hemisphere') + json_data = cache.wrap(rma.model_query, "example.txt", cache=True, return_dataframe=True, model="Hemisphere") - assert json_data[0]['whatever'] + assert json_data[0]["whatever"] - ju_read_url_get.assert_called_once_with( - 'http://api.brain-map.org/api/v2/data/query.json?q=model::Hemisphere') - ju_write.assert_called_once_with('example.txt', _msg) - mock_read_json.assert_called_once_with('example.txt', orient='records') + ju_read_url_get.assert_called_once_with("http://api.brain-map.org/api/v2/data/query.json?q=model::Hemisphere") + ju_write.assert_called_once_with("example.txt", _msg) + mock_read_json.assert_called_once_with("example.txt", orient="records") def test_memoize_with_function(): @@ -188,7 +176,7 @@ def test_memoize_with_kwarg_function(): @memoize def f(x, *, y, z=1): time.sleep(0.1) - return (x * y * z) + return x * y * z # Build cache f(2, y=1, z=2) @@ -225,6 +213,6 @@ def f(self, x): def test_get_default_manifest_file(): - assert get_default_manifest_file('brain_observatory') == 'brain_observatory/manifest.json' - assert get_default_manifest_file('cell_types') == 'cell_types/manifest.json' - assert get_default_manifest_file('mouse_connectivity') == 'mouse_connectivity/manifest.json' + assert get_default_manifest_file("brain_observatory") == "brain_observatory/manifest.json" + assert get_default_manifest_file("cell_types") == "cell_types/manifest.json" + assert get_default_manifest_file("mouse_connectivity") == "mouse_connectivity/manifest.json" diff --git a/allensdk/test/api/test_cacheable.py b/allensdk/test/api/test_cacheable.py index 73fe9284f1..f4a13d03c4 100644 --- a/allensdk/test/api/test_cacheable.py +++ b/allensdk/test/api/test_cacheable.py @@ -46,270 +46,235 @@ import io as StringIO -_msg = [{'whatever': True}] +_msg = [{"whatever": True}] _pd_msg = pd.DataFrame(_msg) -_csv_msg = pd.read_csv(StringIO.StringIO(""",whatever +_csv_msg = pd.read_csv( + StringIO.StringIO(""",whatever 0,True -"""), index_col=0) +"""), + index_col=0, +) @patch("allensdk.core.json_utilities.write") @patch("allensdk.core.json_utilities.read", return_value=_msg) -@patch("allensdk.core.json_utilities.read_url_get", return_value={'msg': _msg}) -@patch('csv.DictWriter') -@patch('pandas.read_csv', return_value=_csv_msg) -def test_cacheable_csv_dataframe(read_csv, dictwriter, ju_read_url_get, - ju_read, ju_write): +@patch("allensdk.core.json_utilities.read_url_get", return_value={"msg": _msg}) +@patch("csv.DictWriter") +@patch("pandas.read_csv", return_value=_csv_msg) +def test_cacheable_csv_dataframe(read_csv, dictwriter, ju_read_url_get, ju_read, ju_write): @cacheable() def get_hemispheres(): - return RmaApi().model_query(model='Hemisphere') + return RmaApi().model_query(model="Hemisphere") - with patch('allensdk.config.manifest.Manifest.safe_mkdir') as mkdir: - with patch(builtins.__name__ + '.open', - mock_open(), - create=True) as open_mock: + with patch("allensdk.config.manifest.Manifest.safe_mkdir") as mkdir: + with patch(builtins.__name__ + ".open", mock_open(), create=True) as open_mock: open_mock.return_value.write = MagicMock() - df = get_hemispheres(path='/xyz/abc/example.txt', - strategy='create', - **Cache.cache_csv_dataframe()) + df = get_hemispheres(path="/xyz/abc/example.txt", strategy="create", **Cache.cache_csv_dataframe()) - assert df.loc[:, 'whatever'][0] + assert df.loc[:, "whatever"][0] - ju_read_url_get.assert_called_once_with( - 'http://api.brain-map.org/api/v2/data/query.json?q=model::Hemisphere') - read_csv.assert_called_once_with('/xyz/abc/example.txt', parse_dates=True) - assert not ju_write.called, 'write should not have been called' - assert not ju_read.called, 'read should not have been called' - mkdir.assert_called_once_with('/xyz/abc') - open_mock.assert_called_once_with('/xyz/abc/example.txt', 'w') + ju_read_url_get.assert_called_once_with("http://api.brain-map.org/api/v2/data/query.json?q=model::Hemisphere") + read_csv.assert_called_once_with("/xyz/abc/example.txt", parse_dates=True) + assert not ju_write.called, "write should not have been called" + assert not ju_read.called, "read should not have been called" + mkdir.assert_called_once_with("/xyz/abc") + open_mock.assert_called_once_with("/xyz/abc/example.txt", "w") @patch("allensdk.core.json_utilities.write") @patch("allensdk.core.json_utilities.read", return_value=_msg) -@patch("allensdk.core.json_utilities.read_url_get", return_value={'msg': _msg}) -@patch.object(Manifest, 'safe_mkdir') -@patch('pandas.read_csv', return_value=_csv_msg) +@patch("allensdk.core.json_utilities.read_url_get", return_value={"msg": _msg}) +@patch.object(Manifest, "safe_mkdir") +@patch("pandas.read_csv", return_value=_csv_msg) def test_cacheable_json(read_csv, mkdir, ju_read_url_get, ju_read, ju_write): @cacheable() def get_hemispheres(): - return RmaApi().model_query(model='Hemisphere') + return RmaApi().model_query(model="Hemisphere") - df = get_hemispheres(path='/xyz/abc/example.json', - strategy='create', - **Cache.cache_json()) + df = get_hemispheres(path="/xyz/abc/example.json", strategy="create", **Cache.cache_json()) - assert 'whatever' in df[0] + assert "whatever" in df[0] - ju_read_url_get.assert_called_once_with( - 'http://api.brain-map.org/api/v2/data/query.json?q=model::Hemisphere') - assert not read_csv.called, 'read_csv should not have been called' - ju_write.assert_called_once_with('/xyz/abc/example.json', - _msg) - ju_read.assert_called_once_with('/xyz/abc/example.json') + ju_read_url_get.assert_called_once_with("http://api.brain-map.org/api/v2/data/query.json?q=model::Hemisphere") + assert not read_csv.called, "read_csv should not have been called" + ju_write.assert_called_once_with("/xyz/abc/example.json", _msg) + ju_read.assert_called_once_with("/xyz/abc/example.json") @patch("allensdk.core.json_utilities.write") @patch("allensdk.core.json_utilities.read", return_value=_msg) -@patch("allensdk.core.json_utilities.read_url_get", return_value={'msg': _msg}) -@patch.object(Manifest, 'safe_mkdir') +@patch("allensdk.core.json_utilities.read_url_get", return_value={"msg": _msg}) +@patch.object(Manifest, "safe_mkdir") def test_excpt(mkdir, ju_read_url_get, ju_read, ju_write): @cacheable() def get_hemispheres_excpt(): - return RmaApi().model_query(model='Hemisphere', - excpt=['symbol']) + return RmaApi().model_query(model="Hemisphere", excpt=["symbol"]) - df = get_hemispheres_excpt(path='/xyz/abc/example.json', - strategy='create', - **Cache.cache_json()) + df = get_hemispheres_excpt(path="/xyz/abc/example.json", strategy="create", **Cache.cache_json()) - assert 'whatever' in df[0] + assert "whatever" in df[0] ju_read_url_get.assert_called_once_with( - 'http://api.brain-map.org/api/v2/data/query.json?q=model::Hemisphere,rma::options%5Bexcept$eqsymbol%5D') - ju_write.assert_called_once_with('/xyz/abc/example.json', _msg) - ju_read.assert_called_once_with('/xyz/abc/example.json') - mkdir.assert_called_once_with('/xyz/abc') + "http://api.brain-map.org/api/v2/data/query.json?q=model::Hemisphere,rma::options%5Bexcept$eqsymbol%5D" + ) + ju_write.assert_called_once_with("/xyz/abc/example.json", _msg) + ju_read.assert_called_once_with("/xyz/abc/example.json") + mkdir.assert_called_once_with("/xyz/abc") @patch("allensdk.core.json_utilities.write") @patch("allensdk.core.json_utilities.read", return_value=_msg) -@patch("allensdk.core.json_utilities.read_url_get", return_value={'msg': _msg}) -@patch('pandas.read_csv', return_value=_csv_msg) +@patch("allensdk.core.json_utilities.read_url_get", return_value={"msg": _msg}) +@patch("pandas.read_csv", return_value=_csv_msg) def test_cacheable_no_cache_csv(read_csv, ju_read_url_get, ju_read, ju_write): @cacheable() def get_hemispheres(): - return RmaApi().model_query(model='Hemisphere') + return RmaApi().model_query(model="Hemisphere") - df = get_hemispheres(path='/xyz/abc/example.csv', - strategy='file', - **Cache.cache_csv()) + df = get_hemispheres(path="/xyz/abc/example.csv", strategy="file", **Cache.cache_csv()) - assert df.loc[:, 'whatever'][0] + assert df.loc[:, "whatever"][0] assert not ju_read_url_get.called - read_csv.assert_called_once_with('/xyz/abc/example.csv', parse_dates=True) - assert not ju_write.called, 'json write should not have been called' - assert not ju_read.called, 'json read should not have been called' + read_csv.assert_called_once_with("/xyz/abc/example.csv", parse_dates=True) + assert not ju_write.called, "json write should not have been called" + assert not ju_read.called, "json read should not have been called" @patch("pandas.io.json.read_json", return_value=_pd_msg) @patch("pandas.read_csv", return_value=_csv_msg) @patch("allensdk.core.json_utilities.write") @patch("allensdk.core.json_utilities.read", return_value=_msg) -@patch("allensdk.core.json_utilities.read_url_get", return_value={'msg': _msg}) -@patch.object(Manifest, 'safe_mkdir') -def test_cacheable_json_dataframe(mkdir, ju_read_url_get, ju_read, ju_write, - read_csv, mock_read_json): +@patch("allensdk.core.json_utilities.read_url_get", return_value={"msg": _msg}) +@patch.object(Manifest, "safe_mkdir") +def test_cacheable_json_dataframe(mkdir, ju_read_url_get, ju_read, ju_write, read_csv, mock_read_json): @cacheable() def get_hemispheres(): - return RmaApi().model_query(model='Hemisphere') + return RmaApi().model_query(model="Hemisphere") - df = get_hemispheres(path='/xyz/abc/example.json', - strategy='create', - **Cache.cache_json_dataframe()) + df = get_hemispheres(path="/xyz/abc/example.json", strategy="create", **Cache.cache_json_dataframe()) - assert df.loc[:, 'whatever'][0] + assert df.loc[:, "whatever"][0] - ju_read_url_get.assert_called_once_with( - 'http://api.brain-map.org/api/v2/data/query.json?q=model::Hemisphere') - assert not read_csv.called, 'read_csv should not have been called' - mock_read_json.assert_called_once_with('/xyz/abc/example.json', - orient='records') - ju_write.assert_called_once_with('/xyz/abc/example.json', _msg) - assert not ju_read.called, 'json read should not have been called' - mkdir.assert_called_once_with('/xyz/abc') + ju_read_url_get.assert_called_once_with("http://api.brain-map.org/api/v2/data/query.json?q=model::Hemisphere") + assert not read_csv.called, "read_csv should not have been called" + mock_read_json.assert_called_once_with("/xyz/abc/example.json", orient="records") + ju_write.assert_called_once_with("/xyz/abc/example.json", _msg) + assert not ju_read.called, "json read should not have been called" + mkdir.assert_called_once_with("/xyz/abc") @patch("pandas.io.json.read_json", return_value=_pd_msg) @patch("pandas.read_csv", return_value=_csv_msg) @patch("allensdk.core.json_utilities.write") @patch("allensdk.core.json_utilities.read", return_value=_msg) -@patch("allensdk.core.json_utilities.read_url_get", return_value={'msg': _msg}) -@patch('csv.DictWriter') -@patch.object(Manifest, 'safe_mkdir') -def test_cacheable_csv_json(mkdir, dictwriter, ju_read_url_get, ju_read, - ju_write, read_csv, mock_read_json): +@patch("allensdk.core.json_utilities.read_url_get", return_value={"msg": _msg}) +@patch("csv.DictWriter") +@patch.object(Manifest, "safe_mkdir") +def test_cacheable_csv_json(mkdir, dictwriter, ju_read_url_get, ju_read, ju_write, read_csv, mock_read_json): @cacheable() def get_hemispheres(): - return RmaApi().model_query(model='Hemisphere') + return RmaApi().model_query(model="Hemisphere") - with patch(builtins.__name__ + '.open', - mock_open(), - create=True) as open_mock: + with patch(builtins.__name__ + ".open", mock_open(), create=True) as open_mock: open_mock.return_value.write = MagicMock() - df = get_hemispheres(path='/xyz/example.csv', - strategy='create', - **Cache.cache_csv_json()) + df = get_hemispheres(path="/xyz/example.csv", strategy="create", **Cache.cache_csv_json()) - assert 'whatever' in df[0] + assert "whatever" in df[0] - ju_read_url_get.assert_called_once_with( - 'http://api.brain-map.org/api/v2/data/query.json?q=model::Hemisphere') - read_csv.assert_called_once_with('/xyz/example.csv', parse_dates=True) + ju_read_url_get.assert_called_once_with("http://api.brain-map.org/api/v2/data/query.json?q=model::Hemisphere") + read_csv.assert_called_once_with("/xyz/example.csv", parse_dates=True) dictwriter.return_value.writerow.assert_called() - assert not mock_read_json.called, 'pj.read_json should not have been called' - assert not ju_write.called, 'ju.write should not have been called' - assert not ju_read.called, 'json read should not have been called' - mkdir.assert_called_once_with('/xyz') - open_mock.assert_called_once_with('/xyz/example.csv', 'w') + assert not mock_read_json.called, "pj.read_json should not have been called" + assert not ju_write.called, "ju.write should not have been called" + assert not ju_read.called, "json read should not have been called" + mkdir.assert_called_once_with("/xyz") + open_mock.assert_called_once_with("/xyz/example.csv", "w") @patch("allensdk.core.json_utilities.write") @patch("allensdk.core.json_utilities.read", return_value=_msg) -@patch("allensdk.core.json_utilities.read_url_get", return_value={'msg': _msg}) +@patch("allensdk.core.json_utilities.read_url_get", return_value={"msg": _msg}) @patch("pandas.read_csv") @patch.object(pd.DataFrame, "to_csv") -def test_cacheable_no_save(to_csv, read_csv, ju_read_url_get, ju_read, - ju_write): +def test_cacheable_no_save(to_csv, read_csv, ju_read_url_get, ju_read, ju_write): @cacheable() def get_hemispheres(): - return RmaApi().model_query(model='Hemisphere') + return RmaApi().model_query(model="Hemisphere") data = get_hemispheres() - assert 'whatever' in data[0] + assert "whatever" in data[0] - ju_read_url_get.assert_called_once_with( - 'http://api.brain-map.org/api/v2/data/query.json?q=model::Hemisphere') - assert not to_csv.called, 'to_csv should not have been called' - assert not read_csv.called, 'read_csv should not have been called' - assert not ju_write.called, 'json write should not have been called' - assert not ju_read.called, 'json read should not have been called' + ju_read_url_get.assert_called_once_with("http://api.brain-map.org/api/v2/data/query.json?q=model::Hemisphere") + assert not to_csv.called, "to_csv should not have been called" + assert not read_csv.called, "read_csv should not have been called" + assert not ju_write.called, "json write should not have been called" + assert not ju_read.called, "json read should not have been called" @patch("allensdk.core.json_utilities.write") @patch("allensdk.core.json_utilities.read", return_value=_msg) -@patch("allensdk.core.json_utilities.read_url_get", return_value={'msg': _msg}) +@patch("allensdk.core.json_utilities.read_url_get", return_value={"msg": _msg}) @patch("pandas.read_csv", return_value=_csv_msg) @patch.object(pd.DataFrame, "to_csv") -def test_cacheable_no_save_dataframe(to_csv, read_csv, ju_read_url_get, - ju_read, ju_write): +def test_cacheable_no_save_dataframe(to_csv, read_csv, ju_read_url_get, ju_read, ju_write): @cacheable() def get_hemispheres(): - return RmaApi().model_query(model='Hemisphere') + return RmaApi().model_query(model="Hemisphere") df = get_hemispheres(**Cache.nocache_dataframe()) - assert df.loc[:, 'whatever'][0] + assert df.loc[:, "whatever"][0] - ju_read_url_get.assert_called_once_with( - 'http://api.brain-map.org/api/v2/data/query.json?q=model::Hemisphere') - assert not to_csv.called, 'to_csv should not have been called' - assert not read_csv.called, 'read_csv should not have been called' - assert not ju_write.called, 'json write should not have been called' - assert not ju_read.called, 'json read should not have been called' + ju_read_url_get.assert_called_once_with("http://api.brain-map.org/api/v2/data/query.json?q=model::Hemisphere") + assert not to_csv.called, "to_csv should not have been called" + assert not read_csv.called, "read_csv should not have been called" + assert not ju_write.called, "json write should not have been called" + assert not ju_read.called, "json read should not have been called" @patch("pandas.read_csv", return_value=_csv_msg) @patch("allensdk.core.json_utilities.write") @patch("allensdk.core.json_utilities.read", return_value=_msg) -@patch("allensdk.core.json_utilities.read_url_get", return_value={'msg': _msg}) -@patch('csv.DictWriter') -@patch.object(Manifest, 'safe_mkdir') -def test_cacheable_lazy_csv_no_file(mkdir, dictwriter, ju_read_url_get, - ju_read, ju_write, read_csv): +@patch("allensdk.core.json_utilities.read_url_get", return_value={"msg": _msg}) +@patch("csv.DictWriter") +@patch.object(Manifest, "safe_mkdir") +def test_cacheable_lazy_csv_no_file(mkdir, dictwriter, ju_read_url_get, ju_read, ju_write, read_csv): @cacheable() def get_hemispheres(): - return RmaApi().model_query(model='Hemisphere') + return RmaApi().model_query(model="Hemisphere") - with patch('os.path.exists', MagicMock(return_value=False)): - with patch(builtins.__name__ + '.open', - mock_open(), - create=True) as open_mock: + with patch("os.path.exists", MagicMock(return_value=False)): + with patch(builtins.__name__ + ".open", mock_open(), create=True) as open_mock: open_mock.return_value.write = MagicMock() - df = get_hemispheres(path='/xyz/abc/example.csv', - strategy='lazy', - **Cache.cache_csv()) + df = get_hemispheres(path="/xyz/abc/example.csv", strategy="lazy", **Cache.cache_csv()) - assert df.loc[:, 'whatever'][0] + assert df.loc[:, "whatever"][0] - ju_read_url_get.assert_called_once_with( - 'http://api.brain-map.org/api/v2/data/query.json?q=model::Hemisphere') - open_mock.assert_called_once_with('/xyz/abc/example.csv', 'w') + ju_read_url_get.assert_called_once_with("http://api.brain-map.org/api/v2/data/query.json?q=model::Hemisphere") + open_mock.assert_called_once_with("/xyz/abc/example.csv", "w") dictwriter.return_value.writerow.assert_called() - read_csv.assert_called_once_with('/xyz/abc/example.csv', parse_dates=True) - assert not ju_write.called, 'json write should not have been called' - assert not ju_read.called, 'json read should not have been called' + read_csv.assert_called_once_with("/xyz/abc/example.csv", parse_dates=True) + assert not ju_write.called, "json write should not have been called" + assert not ju_read.called, "json read should not have been called" @patch("allensdk.core.json_utilities.write") @patch("allensdk.core.json_utilities.read", return_value=_msg) -@patch("allensdk.core.json_utilities.read_url_get", return_value={'msg': _msg}) +@patch("allensdk.core.json_utilities.read_url_get", return_value={"msg": _msg}) @patch("pandas.read_csv", return_value=_csv_msg) -def test_cacheable_lazy_csv_file_exists(read_csv, ju_read_url_get, ju_read, - ju_write): +def test_cacheable_lazy_csv_file_exists(read_csv, ju_read_url_get, ju_read, ju_write): @cacheable() def get_hemispheres(): - return RmaApi().model_query(model='Hemisphere') + return RmaApi().model_query(model="Hemisphere") - with patch('os.path.exists', MagicMock(return_value=True)): - df = get_hemispheres(path='/xyz/abc/example.csv', - strategy='lazy', - **Cache.cache_csv()) + with patch("os.path.exists", MagicMock(return_value=True)): + df = get_hemispheres(path="/xyz/abc/example.csv", strategy="lazy", **Cache.cache_csv()) - assert df.loc[:, 'whatever'][0] + assert df.loc[:, "whatever"][0] assert not ju_read_url_get.called - read_csv.assert_called_once_with('/xyz/abc/example.csv', parse_dates=True) - assert not ju_write.called, 'json write should not have been called' - assert not ju_read.called, 'json read should not have been called' + read_csv.assert_called_once_with("/xyz/abc/example.csv", parse_dates=True) + assert not ju_write.called, "json write should not have been called" + assert not ju_read.called, "json read should not have been called" diff --git a/allensdk/test/api/test_caching_utilities.py b/allensdk/test/api/test_caching_utilities.py index 4a3578168a..70b8f8f042 100644 --- a/allensdk/test/api/test_caching_utilities.py +++ b/allensdk/test/api/test_caching_utilities.py @@ -9,15 +9,11 @@ def get_data(): - return pd.DataFrame( - {"a": [1, 2, 3, 4], "b": ["duck", "kangaroo", "walrus", "ibex"]} - ) + return pd.DataFrame({"a": [1, 2, 3, 4], "b": ["duck", "kangaroo", "walrus", "ibex"]}) def swapped_data(): - return pd.DataFrame( - {"b": [1, 2, 3, 4], "a": ["duck", "kangaroo", "walrus", "ibex"]} - ) + return pd.DataFrame({"b": [1, 2, 3, 4], "a": ["duck", "kangaroo", "walrus", "ibex"]}) def write_to_dict(dc, data): @@ -94,7 +90,7 @@ def swap(data): swapped_data(), 1, 0, - id="simple case" + id="simple case", ), pytest.param( False, @@ -109,7 +105,7 @@ def swap(data): swapped_data(), 1, 0, - id="eager success case" + id="eager success case", ), pytest.param( False, @@ -124,7 +120,7 @@ def swap(data): "raise", 1, 1, - id="lazy failure case" + id="lazy failure case", ), pytest.param( False, @@ -139,7 +135,7 @@ def swap(data): swapped_data(), 10, 9, - id="repeated failure case" + id="repeated failure case", ), pytest.param( False, @@ -154,7 +150,7 @@ def swap(data): "warn", 10, 9, - id="warning case" + id="warning case", ), pytest.param( False, @@ -169,7 +165,7 @@ def swap(data): "raise", 1, 1, - id="eager failure case" + id="eager failure case", ), pytest.param( True, @@ -184,7 +180,7 @@ def swap(data): get_data(), 0, 0, - id="existing data case" + id="existing data case", ), ], ) @@ -202,7 +198,6 @@ def test_call_caching( expected_fetches, expected_cleanups, ): - dc = {} if existing: write(dc, fetch()) @@ -213,17 +208,7 @@ def test_call_caching( read_fn = partial(read, dc) cleanup_fn = partial(cleanup, dc) - fn = partial( - cu.call_caching, - fetch, - write_fn, - read_fn, - pre_write, - cleanup_fn, - lazy, - num_tries, - failure_message - ) + fn = partial(cu.call_caching, fetch, write_fn, read_fn, pre_write, cleanup_fn, lazy, num_tries, failure_message) if isinstance(expected, str) and expected == "raise": with pytest.raises(ValueError): @@ -250,15 +235,12 @@ def test_one_file_call_caching(tmpdir_factory, existing): if existing: data.to_csv(path, index=False) + def getter(): return "foo" obtained = cu.one_file_call_caching( - path, - getter, - lambda path, df: df.to_csv(path, index=False), - lambda path: pd.read_csv(path), - num_tries=2 + path, getter, lambda path, df: df.to_csv(path, index=False), lambda path: pd.read_csv(path), num_tries=2 ) pd.testing.assert_frame_equal(get_data(), obtained, check_like=True, check_dtype=False) diff --git a/allensdk/test/api/test_cell_types_api.py b/allensdk/test/api/test_cell_types_api.py index 9b18ca6855..af29592457 100644 --- a/allensdk/test/api/test_cell_types_api.py +++ b/allensdk/test/api/test_cell_types_api.py @@ -3,93 +3,95 @@ from unittest.mock import patch from allensdk.api.queries.cell_types_api import CellTypesApi + @pytest.fixture def mock_cells_api(): return [ - { - 'cell_reporter_status': "fish", - 'csl__x': 1, - 'csl__y': 2, - 'csl__z': 3, - 'donor__species': 'taco', - 'specimen__id': 10, - 'specimen__name': 'joe', - 'structure__layer': 'fifteen', - 'structure_parent__id': 2, - 'structure_parent__acronym': 'ASAP', - 'line_name': 'bezier', - 'tag__dendrite_type': 'spikey', - 'tag__apical': 'stumpy', - 'nr__reconstruction_type': 'fancy', - 'donor__disease_state': 'influenza', - 'donor__id': 1, - 'specimen__hemisphere': 'hi', - 'csl__normalized_depth': 1 - },{ - - 'cell_reporter_status': "nofish", - 'csl__x': 1, - 'csl__y': 2, - 'csl__z': 3, - 'donor__species': 'taco', - 'specimen__id': 10, - 'specimen__name': 'joe', - 'structure__layer': 'fifteen', - 'structure_parent__id': 2, - 'structure_parent__acronym': 'ASAP', - 'line_name': 'bezier', - 'tag__dendrite_type': 'spikey', - 'tag__apical': 'stumpy', - 'nr__reconstruction_type': None, - 'donor__disease_state': None, - 'donor__id': 1, - 'specimen__hemisphere': 'hi', - 'csl__normalized_depth': 1 - } + { + "cell_reporter_status": "fish", + "csl__x": 1, + "csl__y": 2, + "csl__z": 3, + "donor__species": "taco", + "specimen__id": 10, + "specimen__name": "joe", + "structure__layer": "fifteen", + "structure_parent__id": 2, + "structure_parent__acronym": "ASAP", + "line_name": "bezier", + "tag__dendrite_type": "spikey", + "tag__apical": "stumpy", + "nr__reconstruction_type": "fancy", + "donor__disease_state": "influenza", + "donor__id": 1, + "specimen__hemisphere": "hi", + "csl__normalized_depth": 1, + }, + { + "cell_reporter_status": "nofish", + "csl__x": 1, + "csl__y": 2, + "csl__z": 3, + "donor__species": "taco", + "specimen__id": 10, + "specimen__name": "joe", + "structure__layer": "fifteen", + "structure_parent__id": 2, + "structure_parent__acronym": "ASAP", + "line_name": "bezier", + "tag__dendrite_type": "spikey", + "tag__apical": "stumpy", + "nr__reconstruction_type": None, + "donor__disease_state": None, + "donor__id": 1, + "specimen__hemisphere": "hi", + "csl__normalized_depth": 1, + }, ] - - + + @pytest.fixture def mock_cells(): return [ - { - 'specimen_tags': [], - 'neuron_reconstructions': [], - 'data_sets': [], - 'donor': { - 'transgenic_lines': [], - 'organism': { 'name': CellTypesApi.MOUSE }, - 'conditions': [ { 'name': 'disease categories - influenza' } ] - } + { + "specimen_tags": [], + "neuron_reconstructions": [], + "data_sets": [], + "donor": { + "transgenic_lines": [], + "organism": {"name": CellTypesApi.MOUSE}, + "conditions": [{"name": "disease categories - influenza"}], }, - { - 'specimen_tags': [], - 'neuron_reconstructions': [], - 'data_sets': [ {} ], - 'donor': { - 'transgenic_lines': [ { 'transgenic_line_type_name': 'driver', 'name': 'fish' } ], - 'organism': { 'name': 'fish' } - } + }, + { + "specimen_tags": [], + "neuron_reconstructions": [], + "data_sets": [{}], + "donor": { + "transgenic_lines": [{"transgenic_line_type_name": "driver", "name": "fish"}], + "organism": {"name": "fish"}, }, - { - 'specimen_tags': [], - 'neuron_reconstructions': [ {} ], - 'data_sets': [], - 'cell_reporter': { 'name': 'bob' }, - 'donor': { - 'transgenic_lines': [], - 'organism': { 'name': CellTypesApi.HUMAN }, - 'conditions': [ { 'name': 'disease categories - cheese' } ] - } + }, + { + "specimen_tags": [], + "neuron_reconstructions": [{}], + "data_sets": [], + "cell_reporter": {"name": "bob"}, + "donor": { + "transgenic_lines": [], + "organism": {"name": CellTypesApi.HUMAN}, + "conditions": [{"name": "disease categories - cheese"}], }, - ] + }, + ] + @pytest.fixture def cell_types_api(): endpoint = None - - if 'TEST_API_ENDPOINT' in os.environ: - endpoint = os.environ['TEST_API_ENDPOINT'] + + if "TEST_API_ENDPOINT" in os.environ: + endpoint = os.environ["TEST_API_ENDPOINT"] return CellTypesApi(endpoint) else: return None @@ -98,6 +100,7 @@ def cell_types_api(): @pytest.mark.requires_api_endpoint def test_list_cells_unmocked(cell_types_api): from allensdk.config import enable_console_log + enable_console_log() # this test will always require the latest warehouse @@ -111,24 +114,25 @@ def test_list_cells_mocked(mock_cells): cells = ctapi.list_cells() assert len(cells) == 3 - flu_cells = [ cell for cell in cells if cell['disease_categories'] == [('influenza')] ] + flu_cells = [cell for cell in cells if cell["disease_categories"] == [("influenza")]] assert len(flu_cells) == 1 - + cells = ctapi.list_cells(require_reconstruction=True) assert len(cells) == 1 - + cells = ctapi.list_cells(require_morphology=True) assert len(cells) == 1 - cells = ctapi.list_cells(reporter_status=['bob']) + cells = ctapi.list_cells(reporter_status=["bob"]) assert len(cells) == 1 - cells = ctapi.list_cells(species=['HOMO SAPIENS']) + cells = ctapi.list_cells(species=["HOMO SAPIENS"]) assert len(cells) == 1 - cells = ctapi.list_cells(species=['mus musculus']) + cells = ctapi.list_cells(species=["mus musculus"]) assert len(cells) == 1 + def test_list_cells_api_mocked(mock_cells_api): with patch.object(CellTypesApi, "model_query", return_value=mock_cells_api): ctapi = CellTypesApi() @@ -142,11 +146,8 @@ def test_list_cells_api_mocked(mock_cells_api): fcells = ctapi.filter_cells_api(cells, require_morphology=True) assert len(fcells) == 1 - fcells = ctapi.filter_cells_api(cells, species=['taco']) + fcells = ctapi.filter_cells_api(cells, species=["taco"]) assert len(fcells) == 2 - fcells = ctapi.filter_cells_api(cells, reporter_status=['fish']) + fcells = ctapi.filter_cells_api(cells, reporter_status=["fish"]) assert len(fcells) == 1 - - - diff --git a/allensdk/test/api/test_file_download.py b/allensdk/test/api/test_file_download.py index ace2ef9d60..38ba878b34 100644 --- a/allensdk/test/api/test_file_download.py +++ b/allensdk/test/api/test_file_download.py @@ -57,169 +57,150 @@ def cache(): @pytest.mark.parametrize("file_exists", (True, False)) -@patch("nrrd.read", return_value=('mock_annotation_data', - 'mock_annotation_image')) -@patch.object(Manifest, 'safe_mkdir') +@patch("nrrd.read", return_value=("mock_annotation_data", "mock_annotation_image")) +@patch.object(Manifest, "safe_mkdir") def test_file_download_lazy(nrrd_read, safe_mkdir, mca, cache, file_exists): with patch.object(mca, "retrieve_file_over_http") as mock_retrieve: - @cacheable(strategy='lazy', - reader=nrrd_read, - pathfinder=Cache.pathfinder(file_name_position=3, - secondary_file_name_position=1)) - def download_volumetric_data(data_path, - file_name, - voxel_resolution=None, - save_file_path=None, - release=None, - coordinate_framework=None): - url = mca.build_volumetric_data_download_url(data_path, - file_name, - voxel_resolution, - release, - coordinate_framework) + + @cacheable( + strategy="lazy", + reader=nrrd_read, + pathfinder=Cache.pathfinder(file_name_position=3, secondary_file_name_position=1), + ) + def download_volumetric_data( + data_path, file_name, voxel_resolution=None, save_file_path=None, release=None, coordinate_framework=None + ): + url = mca.build_volumetric_data_download_url( + data_path, file_name, voxel_resolution, release, coordinate_framework + ) mca.retrieve_file_over_http(url, save_file_path) - with patch('os.path.exists', - Mock(name="os.path.exists", - return_value=file_exists)): + with patch("os.path.exists", Mock(name="os.path.exists", return_value=file_exists)): nrrd_read.reset_mock() - download_volumetric_data(MCA.AVERAGE_TEMPLATE, - 'annotation_10.nrrd', - MCA.VOXEL_RESOLUTION_10_MICRONS, - 'volumetric.nrrd', - MCA.CCF_2016, - strategy='lazy') + download_volumetric_data( + MCA.AVERAGE_TEMPLATE, + "annotation_10.nrrd", + MCA.VOXEL_RESOLUTION_10_MICRONS, + "volumetric.nrrd", + MCA.CCF_2016, + strategy="lazy", + ) if file_exists: - assert not mock_retrieve.called, 'server call not needed when file exists' + assert not mock_retrieve.called, "server call not needed when file exists" else: mock_retrieve.assert_called_once_with( - 'http://download.alleninstitute.org/informatics-archive/annotation/ccf_2016/mouse_ccf/average_template/annotation_10.nrrd', - 'volumetric.nrrd') - assert not safe_mkdir.called, 'safe_mkdir should not have been called.' - nrrd_read.assert_called_once_with('volumetric.nrrd') + "http://download.alleninstitute.org/informatics-archive/annotation/ccf_2016/mouse_ccf/average_template/annotation_10.nrrd", + "volumetric.nrrd", + ) + assert not safe_mkdir.called, "safe_mkdir should not have been called." + nrrd_read.assert_called_once_with("volumetric.nrrd") @pytest.mark.parametrize("file_exists", (True, False)) -@patch("nrrd.read", return_value=('mock_annotation_data', - 'mock_annotation_image')) -@patch.object(Manifest, 'safe_mkdir') +@patch("nrrd.read", return_value=("mock_annotation_data", "mock_annotation_image")) +@patch.object(Manifest, "safe_mkdir") def test_file_download_server(nrrd_read, safe_mkdir, mca, cache, file_exists): with patch.object(mca, "retrieve_file_over_http") as mock_retrieve: - @cacheable(reader=nrrd_read, - pathfinder=Cache.pathfinder(file_name_position=3, - secondary_file_name_position=1)) - def download_volumetric_data(data_path, - file_name, - voxel_resolution=None, - save_file_path=None, - release=None, - coordinate_framework=None): - url = mca.build_volumetric_data_download_url(data_path, - file_name, - voxel_resolution, - release, - coordinate_framework) + + @cacheable(reader=nrrd_read, pathfinder=Cache.pathfinder(file_name_position=3, secondary_file_name_position=1)) + def download_volumetric_data( + data_path, file_name, voxel_resolution=None, save_file_path=None, release=None, coordinate_framework=None + ): + url = mca.build_volumetric_data_download_url( + data_path, file_name, voxel_resolution, release, coordinate_framework + ) mca.retrieve_file_over_http(url, save_file_path) - with patch('os.path.exists', - Mock(name="os.path.exists", - return_value=file_exists)): + with patch("os.path.exists", Mock(name="os.path.exists", return_value=file_exists)): nrrd_read.reset_mock() - - download_volumetric_data(MCA.AVERAGE_TEMPLATE, - 'annotation_10.nrrd', - MCA.VOXEL_RESOLUTION_10_MICRONS, - 'volumetric.nrrd', - MCA.CCF_2016, - strategy='create') + + download_volumetric_data( + MCA.AVERAGE_TEMPLATE, + "annotation_10.nrrd", + MCA.VOXEL_RESOLUTION_10_MICRONS, + "volumetric.nrrd", + MCA.CCF_2016, + strategy="create", + ) mock_retrieve.assert_called_once_with( - 'http://download.alleninstitute.org/informatics-archive/annotation/ccf_2016/mouse_ccf/average_template/annotation_10.nrrd', - 'volumetric.nrrd') - assert not safe_mkdir.called, 'safe_mkdir should not have been called.' - nrrd_read.assert_called_once_with('volumetric.nrrd') + "http://download.alleninstitute.org/informatics-archive/annotation/ccf_2016/mouse_ccf/average_template/annotation_10.nrrd", + "volumetric.nrrd", + ) + assert not safe_mkdir.called, "safe_mkdir should not have been called." + nrrd_read.assert_called_once_with("volumetric.nrrd") @pytest.mark.parametrize("file_exists", (True, False)) -@patch("nrrd.read", return_value=('mock_annotation_data', - 'mock_annotation_image')) -@patch.object(Manifest, 'safe_mkdir') +@patch("nrrd.read", return_value=("mock_annotation_data", "mock_annotation_image")) +@patch.object(Manifest, "safe_mkdir") def test_file_download_cached_file(nrrd_read, safe_mkdir, mca, cache, file_exists): with patch.object(mca, "retrieve_file_over_http") as mock_retrieve: - @cacheable(reader=nrrd_read, - pathfinder=Cache.pathfinder(file_name_position=3, - secondary_file_name_position=1)) - def download_volumetric_data(data_path, - file_name, - voxel_resolution=None, - save_file_path=None, - release=None, - coordinate_framework=None): - url = mca.build_volumetric_data_download_url(data_path, - file_name, - voxel_resolution, - release, - coordinate_framework) + + @cacheable(reader=nrrd_read, pathfinder=Cache.pathfinder(file_name_position=3, secondary_file_name_position=1)) + def download_volumetric_data( + data_path, file_name, voxel_resolution=None, save_file_path=None, release=None, coordinate_framework=None + ): + url = mca.build_volumetric_data_download_url( + data_path, file_name, voxel_resolution, release, coordinate_framework + ) mca.retrieve_file_over_http(url, save_file_path) - with patch('os.path.exists', - Mock(name="os.path.exists", - return_value=file_exists)): + with patch("os.path.exists", Mock(name="os.path.exists", return_value=file_exists)): nrrd_read.reset_mock() - download_volumetric_data(MCA.AVERAGE_TEMPLATE, - 'annotation_10.nrrd', - MCA.VOXEL_RESOLUTION_10_MICRONS, - 'volumetric.nrrd', - MCA.CCF_2016, - strategy='file') + download_volumetric_data( + MCA.AVERAGE_TEMPLATE, + "annotation_10.nrrd", + MCA.VOXEL_RESOLUTION_10_MICRONS, + "volumetric.nrrd", + MCA.CCF_2016, + strategy="file", + ) - assert not mock_retrieve.called, 'server should not have been called' - assert not safe_mkdir.called, 'safe_mkdir should not have been called.' - nrrd_read.assert_called_once_with('volumetric.nrrd') + assert not mock_retrieve.called, "server should not have been called" + assert not safe_mkdir.called, "safe_mkdir should not have been called." + nrrd_read.assert_called_once_with("volumetric.nrrd") @pytest.mark.parametrize("file_exists", (True, False)) -@patch("nrrd.read", return_value=('mock_annotation_data', - 'mock_annotation_image')) -@patch.object(Manifest, 'safe_mkdir') +@patch("nrrd.read", return_value=("mock_annotation_data", "mock_annotation_image")) +@patch.object(Manifest, "safe_mkdir") def test_file_kwarg(nrrd_read, safe_mkdir, mca, cache, file_exists): with patch.object(mca, "retrieve_file_over_http") as mock_retrieve: - @cacheable(reader=nrrd_read, - pathfinder=Cache.pathfinder(file_name_position=3, - secondary_file_name_position=1, - path_keyword='save_file_path')) - def download_volumetric_data(data_path, - file_name, - voxel_resolution=None, - save_file_path=None, - release=None, - coordinate_framework=None): - url = mca.build_volumetric_data_download_url(data_path, - file_name, - voxel_resolution, - release, - coordinate_framework) + + @cacheable( + reader=nrrd_read, + pathfinder=Cache.pathfinder( + file_name_position=3, secondary_file_name_position=1, path_keyword="save_file_path" + ), + ) + def download_volumetric_data( + data_path, file_name, voxel_resolution=None, save_file_path=None, release=None, coordinate_framework=None + ): + url = mca.build_volumetric_data_download_url( + data_path, file_name, voxel_resolution, release, coordinate_framework + ) mca.retrieve_file_over_http(url, save_file_path) - with patch('os.path.exists', - Mock(name="os.path.exists", - return_value=file_exists)): + with patch("os.path.exists", Mock(name="os.path.exists", return_value=file_exists)): nrrd_read.reset_mock() - download_volumetric_data(MCA.AVERAGE_TEMPLATE, - 'annotation_10.nrrd', - MCA.VOXEL_RESOLUTION_10_MICRONS, - 'volumetric.nrrd', - MCA.CCF_2016, - strategy='file', - save_file_path='file.nrrd' ) - - assert not mock_retrieve.called, 'server should not have been called' - assert not safe_mkdir.called, 'safe_mkdir should not have been called.' - nrrd_read.assert_called_once_with('file.nrrd') + download_volumetric_data( + MCA.AVERAGE_TEMPLATE, + "annotation_10.nrrd", + MCA.VOXEL_RESOLUTION_10_MICRONS, + "volumetric.nrrd", + MCA.CCF_2016, + strategy="file", + save_file_path="file.nrrd", + ) + + assert not mock_retrieve.called, "server should not have been called" + assert not safe_mkdir.called, "safe_mkdir should not have been called." + nrrd_read.assert_called_once_with("file.nrrd") diff --git a/allensdk/test/api/test_glif_api.py b/allensdk/test/api/test_glif_api.py index 47a1290576..8895eec068 100644 --- a/allensdk/test/api/test_glif_api.py +++ b/allensdk/test/api/test_glif_api.py @@ -53,8 +53,8 @@ def specimen_id(): def glif_api(): endpoint = None - if 'TEST_API_ENDPOINT' in os.environ: - endpoint = os.environ['TEST_API_ENDPOINT'] + if "TEST_API_ENDPOINT" in os.environ: + endpoint = os.environ["TEST_API_ENDPOINT"] return GlifApi(endpoint) else: return None @@ -66,24 +66,22 @@ def test_get_neuronal_model_templates(glif_api): assert len(glif_api.get_neuronal_model_templates()) == 7 for template in glif_api.get_neuronal_model_templates(): - - if template['id'] == 329230710: - assert 'perisomatic' in template['name'] - elif template['id'] == 395310498: - assert '(LIF-R-ASC-A)' in template['name'] - elif template['id'] == 395310469: - assert '(LIF)' in template['name'] - elif template['id'] == 395310475: - assert '(LIF-ASC)' in template['name'] - elif template['id'] == 395310479: - assert '(LIF-R)' in template['name'] - elif template['id'] == 471355161: - assert '(LIF-R-ASC)' in template['name'] - elif template['id'] == 491455321: - assert 'Biophysical - all active' in template['name'] + if template["id"] == 329230710: + assert "perisomatic" in template["name"] + elif template["id"] == 395310498: + assert "(LIF-R-ASC-A)" in template["name"] + elif template["id"] == 395310469: + assert "(LIF)" in template["name"] + elif template["id"] == 395310475: + assert "(LIF-ASC)" in template["name"] + elif template["id"] == 395310479: + assert "(LIF-R)" in template["name"] + elif template["id"] == 471355161: + assert "(LIF-R-ASC)" in template["name"] + elif template["id"] == 491455321: + assert "Biophysical - all active" in template["name"] else: - raise Exception('Unrecognized template: %s (%s)' % ( - template['id'], template['name'])) + raise Exception("Unrecognized template: %s (%s)" % (template["id"], template["name"])) @pytest.mark.requires_api_endpoint @@ -92,7 +90,7 @@ def test_get_neuronal_models(glif_api, specimen_id): cells = glif_api.get_neuronal_models([specimen_id]) assert len(cells) == 1 - assert len(cells[0]['neuronal_models']) == 2 + assert len(cells[0]["neuronal_models"]) == 2 @pytest.mark.requires_api_endpoint @@ -107,14 +105,12 @@ def test_get_neuronal_models_no_ids(glif_api): def test_get_neuron_configs(glif_api, specimen_id): model = glif_api.get_neuronal_models([specimen_id]) - neuronal_model_ids = [nm['id'] for nm in model[0]['neuronal_models']] + neuronal_model_ids = [nm["id"] for nm in model[0]["neuronal_models"]] assert set(neuronal_model_ids) == set((566283950, 566283946)) test_id = 566283950 - np.testing.assert_almost_equal( - glif_api.get_neuron_configs([test_id])[test_id]['th_inf'], - 0.024561992461740227) + np.testing.assert_almost_equal(glif_api.get_neuron_configs([test_id])[test_id]["th_inf"], 0.024561992461740227) @pytest.mark.requires_api_endpoint @@ -131,6 +127,6 @@ def test_deprecated(fn_temp_dir, glif_api, neuronal_model_id): glif_api.get_neuronal_model(neuronal_model_id) glif_api.get_neuron_config() - nwb_path = os.path.join(fn_temp_dir, 'tmp.nwb') + nwb_path = os.path.join(fn_temp_dir, "tmp.nwb") glif_api.get_neuronal_model(neuronal_model_id) glif_api.cache_stimulus_file(nwb_path) diff --git a/allensdk/test/api/test_grid_data_api.py b/allensdk/test/api/test_grid_data_api.py index 458d07a43f..e0c06f848e 100644 --- a/allensdk/test/api/test_grid_data_api.py +++ b/allensdk/test/api/test_grid_data_api.py @@ -41,41 +41,39 @@ @pytest.fixture def grid_data(): gda = GridDataApi() - gda.retrieve_file_over_http = \ - MagicMock(name='retrieve_file_over_http') + gda.retrieve_file_over_http = MagicMock(name="retrieve_file_over_http") return gda def test_download_gene_expression_grid_data(grid_data): - - path = '69816930/density.mhd' + path = "69816930/density.mhd" section_data_set_id = 69816930 - volume_type = 'density' + volume_type = "density" grid_data.download_gene_expression_grid_data(section_data_set_id, volume_type, path) - expected = 'http://api.brain-map.org/grid_data/download/69816930?include=density' + expected = "http://api.brain-map.org/grid_data/download/69816930?include=density" grid_data.retrieve_file_over_http.assert_called_once_with(expected, path, zipped=True) def test_api_doc_url_download_expression_grid(grid_data): - '''Url to download the 200um density volume + """Url to download the 200um density volume for the Mouse Brain Atlas SectionDataSet 69816930. Notes ----- See `Downloading 3-D Expression Grid Data `_ , example 'Download the 200um density volume for the Mouse Brain Atlas SectionDataSet 69816930'. - ''' - path = '69816930.zip' + """ + path = "69816930.zip" section_data_set_id = 69816930 grid_data.download_expression_grid_data(section_data_set_id) - expected = 'http://api.brain-map.org/grid_data/download/69816930' + expected = "http://api.brain-map.org/grid_data/download/69816930" grid_data.retrieve_file_over_http.assert_called_once_with(expected, path) def test_api_doc_url_download_expression_grid_energy_intensity(grid_data): - '''Url to download the 200um energy and intensity volumes for Mouse Brain Atlas SectionDataSet 69816930. + """Url to download the 200um energy and intensity volumes for Mouse Brain Atlas SectionDataSet 69816930. Notes ----- @@ -83,81 +81,80 @@ def test_api_doc_url_download_expression_grid_energy_intensity(grid_data): , example 'Download the 200um energy and intensity volumes for Mouse Brain Atlas SectionDataSet 69816930'. The id in the example url doesn't match the caption. - ''' - path = '183282970.zip' + """ + path = "183282970.zip" section_data_set_id = 183282970 - include = ['energy', 'intensity'] - grid_data.download_expression_grid_data(section_data_set_id, - include=include) + include = ["energy", "intensity"] + grid_data.download_expression_grid_data(section_data_set_id, include=include) grid_data.retrieve_file_over_http.assert_called_once_with( - "http://api.brain-map.org/grid_data/download/183282970" - "?include=energy,intensity", - path) + "http://api.brain-map.org/grid_data/download/183282970?include=energy,intensity", path + ) def test_api_doc_url_projection_grid(grid_data): - '''Url to download the 100um density volume for the Mouse Connectivity Atlas SectionDataSet 181777177. + """Url to download the 100um density volume for the Mouse Connectivity Atlas SectionDataSet 181777177. Notes ----- See `Downloading 3-D Projection Grid Data `_ , example 'Download the 100um density volume for the Mouse Connectivity Atlas SectionDataSet 181777177'. - ''' - path = '181777177.nrrd' + """ + path = "181777177.nrrd" section_data_set_id = 181777177 grid_data.download_projection_grid_data(section_data_set_id) - expected = 'http://api.brain-map.org/grid_data/download_file/181777177' + expected = "http://api.brain-map.org/grid_data/download_file/181777177" grid_data.retrieve_file_over_http.assert_called_once_with(expected, path) def test_api_doc_url_projection_grid_injection_fraction_resolution(grid_data): - '''Url to download the 25um injection_fraction volume for Mouse Connectivity Atlas SectionDataSet 181777177. + """Url to download the 25um injection_fraction volume for Mouse Connectivity Atlas SectionDataSet 181777177. Notes ----- See `Downloading 3-D Projection Grid Data `_ , example 'Download the 25um injection_fraction volume for Mouse Connectivity Atlas SectionDataSet 181777177'. - ''' + """ section_data_set_id = 181777177 - path = 'id.nrrd' - grid_data.download_projection_grid_data(section_data_set_id, - [grid_data.INJECTION_FRACTION], - resolution=25, - save_file_path=path) + path = "id.nrrd" + grid_data.download_projection_grid_data( + section_data_set_id, [grid_data.INJECTION_FRACTION], resolution=25, save_file_path=path + ) grid_data.retrieve_file_over_http.assert_called_once_with( - "http://api.brain-map.org/grid_data/download_file/181777177" - "?image=injection_fraction&resolution=25", - path) + "http://api.brain-map.org/grid_data/download_file/181777177?image=injection_fraction&resolution=25", path + ) def test_download_deformation_field(grid_data): grid_data.model_query = MagicMock( - name='model_query', + name="model_query", return_value=[ - {'well_known_file_type': {'name': 'DeformationFieldHeader'}, 'id': 123}, - {'well_known_file_type': {'name': 'DeformationFieldVoxels'}, 'id': 456} - ] + {"well_known_file_type": {"name": "DeformationFieldHeader"}, "id": 123}, + {"well_known_file_type": {"name": "DeformationFieldVoxels"}, "id": 456}, + ], ) grid_data.download_deformation_field(789) - grid_data.retrieve_file_over_http.assert_any_call('http://api.brain-map.org/api/v2/well_known_file_download/123', '789_dfmfld.mhd') - grid_data.retrieve_file_over_http.assert_any_call('http://api.brain-map.org/api/v2/well_known_file_download/456', '789_dfmfld.raw') + grid_data.retrieve_file_over_http.assert_any_call( + "http://api.brain-map.org/api/v2/well_known_file_download/123", "789_dfmfld.mhd" + ) + grid_data.retrieve_file_over_http.assert_any_call( + "http://api.brain-map.org/api/v2/well_known_file_download/456", "789_dfmfld.raw" + ) def test_download_alignment3d(grid_data): - grid_data.json_msg_query = MagicMock( - name='json_msg_query', - return_value=[{'alignment3d': 'foo'}] - ) + grid_data.json_msg_query = MagicMock(name="json_msg_query", return_value=[{"alignment3d": "foo"}]) obtained = grid_data.download_alignment3d(123) - assert 'foo' == obtained - grid_data.json_msg_query.assert_called_once_with(( - 'http://api.brain-map.org/api/v2/data/query.json?q=' - 'model::SectionDataSet[id$eq123],' - 'rma::include,alignment3d,' - 'rma::options[num_rows$eq\'all\'][count$eqfalse]' - )) + assert "foo" == obtained + grid_data.json_msg_query.assert_called_once_with( + ( + "http://api.brain-map.org/api/v2/data/query.json?q=" + "model::SectionDataSet[id$eq123]," + "rma::include,alignment3d," + "rma::options[num_rows$eq'all'][count$eqfalse]" + ) + ) diff --git a/allensdk/test/api/test_image_download_api.py b/allensdk/test/api/test_image_download_api.py index 0e57b05809..9002520e7a 100644 --- a/allensdk/test/api/test_image_download_api.py +++ b/allensdk/test/api/test_image_download_api.py @@ -43,539 +43,463 @@ def image_api(): image_api = ImageDownloadApi() - image_api.retrieve_file_over_http = \ - MagicMock(name='retrieve_file_over_http') - image_api.json_msg_query = MagicMock(name='json_msg_query') + image_api.retrieve_file_over_http = MagicMock(name="retrieve_file_over_http") + image_api.json_msg_query = MagicMock(name="json_msg_query") return image_api -def test_get_section_image_ranges(image_api): +def test_get_section_image_ranges(image_api): section_image_ids = [126862575, 297225768] image_api.get_section_image_ranges(section_image_ids) - image_api.json_msg_query.assert_called_once_with('http://api.brain-map.org/api/v2/data/query.json?q=model::Equalization,' - 'rma::criteria,section_data_set(section_images[id$in126862575,297225768]),' - 'rma::options[only$eq\'blue_lower,blue_upper,red_lower,red_upper,green_lower,green_upper\']' - '[num_rows$eq\'all\'][count$eqfalse]') + image_api.json_msg_query.assert_called_once_with( + "http://api.brain-map.org/api/v2/data/query.json?q=model::Equalization," + "rma::criteria,section_data_set(section_images[id$in126862575,297225768])," + "rma::options[only$eq'blue_lower,blue_upper,red_lower,red_upper,green_lower,green_upper']" + "[num_rows$eq'all'][count$eqfalse]" + ) -def test_get_section_data_sets_by_product(image_api): +def test_get_section_data_sets_by_product(image_api): product_ids = [10, 22] image_api.get_section_data_sets_by_product(product_ids) - image_api.json_msg_query.assert_called_once_with('http://api.brain-map.org/api/v2/data/query.json?'\ - 'q=model::SectionDataSet,'\ - 'rma::criteria,[failed$in\'false\'],products[id$in10,22],'\ - 'rma::options[num_rows$eq\'all\'][count$eqfalse]') + image_api.json_msg_query.assert_called_once_with( + "http://api.brain-map.org/api/v2/data/query.json?" + "q=model::SectionDataSet," + "rma::criteria,[failed$in'false'],products[id$in10,22]," + "rma::options[num_rows$eq'all'][count$eqfalse]" + ) -def test_get_section_data_sets_by_product_failedok(image_api): +def test_get_section_data_sets_by_product_failedok(image_api): product_ids = [10, 22] image_api.get_section_data_sets_by_product(product_ids, include_failed=True) - image_api.json_msg_query.assert_called_once_with('http://api.brain-map.org/api/v2/data/query.json?'\ - 'q=model::SectionDataSet,'\ - 'rma::criteria,[failed$in\'false\',\'true\'],products[id$in10,22],'\ - 'rma::options[num_rows$eq\'all\'][count$eqfalse]') + image_api.json_msg_query.assert_called_once_with( + "http://api.brain-map.org/api/v2/data/query.json?" + "q=model::SectionDataSet," + "rma::criteria,[failed$in'false','true'],products[id$in10,22]," + "rma::options[num_rows$eq'all'][count$eqfalse]" + ) -def test_get_section_image_ranges_as_list(image_api): - image_api.template_query = MagicMock(return_value=[{'blue_lower': 0, 'blue_upper': 1, 'green_lower': 2, 'green_upper': 3, 'red_lower': 4, 'red_upper': 5}]) +def test_get_section_image_ranges_as_list(image_api): + image_api.template_query = MagicMock( + return_value=[ + {"blue_lower": 0, "blue_upper": 1, "green_lower": 2, "green_upper": 3, "red_lower": 4, "red_upper": 5} + ] + ) obt = image_api.get_section_image_ranges([1]) - assert(np.allclose( [4, 5, 2, 3, 0, 1], obt[0] )) + assert np.allclose([4, 5, 2, 3, 0, 1], obt[0]) + def test_api_doc_url_download_section_image_downsampled(image_api): - ''' + """ Notes ----- See: `Experimental Overview and Metadata `_ , link labeled 'Download image downsampled by factor of 6 using default thresholds'. - ''' - path = '126862575.jpg' + """ + path = "126862575.jpg" section_image_id = 126862575 - image_api.download_section_image(section_image_id, - downsample=6, - range=[0, 932, 0, 1279, 0, 4095]) + image_api.download_section_image(section_image_id, downsample=6, range=[0, 932, 0, 1279, 0, 4095]) image_api.retrieve_file_over_http.assert_called_once_with( - "http://api.brain-map.org/api/v2/section_image_download/126862575" - "?downsample=6&range=0,932,0,1279,0,4095", - path) + "http://api.brain-map.org/api/v2/section_image_download/126862575?downsample=6&range=0,932,0,1279,0,4095", path + ) def test_api_doc_url_download_section_image_downsample_dimensions(image_api): - path = '126862575.jpg' + path = "126862575.jpg" section_image_id = 126862575 - image_api.download_section_image(section_image_id, - downsample=6, - downsample_dimensions=True) + image_api.download_section_image(section_image_id, downsample=6, downsample_dimensions=True) image_api.retrieve_file_over_http.assert_called_once_with( - "http://api.brain-map.org/api/v2/section_image_download/126862575" - "?downsample=6&downsample_dimensions=true", - path) + "http://api.brain-map.org/api/v2/section_image_download/126862575?downsample=6&downsample_dimensions=true", path + ) def test_api_doc_url_download_section_image_full_res(image_api): - path = '126862575.jpg' + path = "126862575.jpg" section_image_id = 126862575 image_api.download_section_image(section_image_id) image_api.retrieve_file_over_http.assert_called_once_with( - "http://api.brain-map.org/api/v2/section_image_download/126862575", - path) + "http://api.brain-map.org/api/v2/section_image_download/126862575", path + ) def test_api_doc_url_download_section_image_downsample_dimensions_false(image_api): - path = '126862575.jpg' + path = "126862575.jpg" section_image_id = 126862575 - image_api.download_section_image(section_image_id, - downsample=6, - downsample_dimensions=False) + image_api.download_section_image(section_image_id, downsample=6, downsample_dimensions=False) image_api.retrieve_file_over_http.assert_called_once_with( - "http://api.brain-map.org/api/v2/section_image_download/126862575" - "?downsample=6&downsample_dimensions=false", - path) + "http://api.brain-map.org/api/v2/section_image_download/126862575?downsample=6&downsample_dimensions=false", + path, + ) def test_api_doc_url_download_section_image_downsampled_low_quality(image_api): - path = '126862575.jpg' + path = "126862575.jpg" section_image_id = 126862575 - image_api.download_section_image(section_image_id, - downsample=3, - quality=50) + image_api.download_section_image(section_image_id, downsample=3, quality=50) image_api.retrieve_file_over_http.assert_called_once_with( - "http://api.brain-map.org/api/v2/section_image_download/126862575" - "?downsample=3&quality=50", - path) + "http://api.brain-map.org/api/v2/section_image_download/126862575?downsample=3&quality=50", path + ) def test_api_doc_url_download_section_image_tumor_feature_annotation(image_api): - path = '126862575.jpg' + path = "126862575.jpg" section_image_id = 126862575 - image_api.download_section_image(section_image_id, - downsample=6, - tumor_feature_annotation=True) + image_api.download_section_image(section_image_id, downsample=6, tumor_feature_annotation=True) image_api.retrieve_file_over_http.assert_called_once_with( - "http://api.brain-map.org/api/v2/section_image_download/126862575" - "?downsample=6&tumor_feature_annotation=true", - path) + "http://api.brain-map.org/api/v2/section_image_download/126862575?downsample=6&tumor_feature_annotation=true", + path, + ) def test_api_doc_url_download_section_image_tumor_feature_annotation_false(image_api): - path = '126862575.jpg' + path = "126862575.jpg" section_image_id = 126862575 - image_api.download_section_image(section_image_id, - downsample=6, - tumor_feature_annotation=False) + image_api.download_section_image(section_image_id, downsample=6, tumor_feature_annotation=False) image_api.retrieve_file_over_http.assert_called_once_with( - "http://api.brain-map.org/api/v2/section_image_download/126862575" - "?downsample=6&tumor_feature_annotation=false", - path) + "http://api.brain-map.org/api/v2/section_image_download/126862575?downsample=6&tumor_feature_annotation=false", + path, + ) def test_api_doc_url_download_section_image_tumor_feature_boundary(image_api): - path = '126862575.jpg' + path = "126862575.jpg" section_image_id = 126862575 - image_api.download_section_image(section_image_id, - downsample=6, - tumor_feature_boundary=True) + image_api.download_section_image(section_image_id, downsample=6, tumor_feature_boundary=True) image_api.retrieve_file_over_http.assert_called_once_with( - "http://api.brain-map.org/api/v2/section_image_download/126862575" - "?downsample=6&tumor_feature_boundary=true", - path) + "http://api.brain-map.org/api/v2/section_image_download/126862575?downsample=6&tumor_feature_boundary=true", + path, + ) def test_api_doc_url_download_section_image_tumor_feature_boundary_false(image_api): - path = '126862575.jpg' + path = "126862575.jpg" section_image_id = 126862575 - image_api.download_section_image(section_image_id, - downsample=6, - tumor_feature_boundary=False) + image_api.download_section_image(section_image_id, downsample=6, tumor_feature_boundary=False) image_api.retrieve_file_over_http.assert_called_once_with( - "http://api.brain-map.org/api/v2/section_image_download/126862575" - "?downsample=6&tumor_feature_boundary=false", - path) + "http://api.brain-map.org/api/v2/section_image_download/126862575?downsample=6&tumor_feature_boundary=false", + path, + ) def test_api_doc_url_download_section_image_expression(image_api): - path = '126862575.jpg' + path = "126862575.jpg" section_image_id = 126862575 - image_api.download_section_image(section_image_id, - downsample=6, - expression=True) + image_api.download_section_image(section_image_id, downsample=6, expression=True) image_api.retrieve_file_over_http.assert_called_once_with( - "http://api.brain-map.org/api/v2/section_image_download/126862575" - "?downsample=6&expression=true", - path) + "http://api.brain-map.org/api/v2/section_image_download/126862575?downsample=6&expression=true", path + ) def test_api_doc_url_download_section_image_expression_false(image_api): - path = '126862575.jpg' + path = "126862575.jpg" section_image_id = 126862575 - image_api.download_section_image(section_image_id, - downsample=6, - expression=False) + image_api.download_section_image(section_image_id, downsample=6, expression=False) image_api.retrieve_file_over_http.assert_called_once_with( - "http://api.brain-map.org/api/v2/section_image_download/126862575" - "?downsample=6&expression=false", - path) + "http://api.brain-map.org/api/v2/section_image_download/126862575?downsample=6&expression=false", path + ) def test_api_doc_url_download_atlas_image_downsampled(image_api): - path = '100883869.jpg' + path = "100883869.jpg" section_image_id = 100883869 - image_api.download_atlas_image(section_image_id, - downsample=4) + image_api.download_atlas_image(section_image_id, downsample=4) image_api.retrieve_file_over_http.assert_called_once_with( - "http://api.brain-map.org/api/v2/atlas_image_download/100883869" - "?downsample=4", - path) + "http://api.brain-map.org/api/v2/atlas_image_download/100883869?downsample=4", path + ) def test_api_doc_url_download_atlas_image_downsampled_low_quality(image_api): - path = '100883869.jpg' + path = "100883869.jpg" section_image_id = 100883869 - image_api.download_atlas_image(section_image_id, - downsample=4, - quality=50) + image_api.download_atlas_image(section_image_id, downsample=4, quality=50) image_api.retrieve_file_over_http.assert_called_once_with( - "http://api.brain-map.org/api/v2/atlas_image_download/100883869" - "?downsample=4&quality=50", - path) + "http://api.brain-map.org/api/v2/atlas_image_download/100883869?downsample=4&quality=50", path + ) def test_api_doc_url_download_atlas_image_annotation(image_api): - path = '100883869.jpg' + path = "100883869.jpg" section_image_id = 100883869 - image_api.download_atlas_image(section_image_id, - downsample=4, - annotation=True) + image_api.download_atlas_image(section_image_id, downsample=4, annotation=True) image_api.retrieve_file_over_http.assert_called_once_with( - "http://api.brain-map.org/api/v2/atlas_image_download/100883869" - "?downsample=4&annotation=true", - path) + "http://api.brain-map.org/api/v2/atlas_image_download/100883869?downsample=4&annotation=true", path + ) def test_api_doc_url_download_atlas_image_annotation_false(image_api): - path = '100883869.jpg' + path = "100883869.jpg" section_image_id = 100883869 - image_api.download_atlas_image(section_image_id, - downsample=4, - annotation=False) + image_api.download_atlas_image(section_image_id, downsample=4, annotation=False) image_api.retrieve_file_over_http.assert_called_once_with( - "http://api.brain-map.org/api/v2/atlas_image_download/100883869" - "?downsample=4&annotation=false", - path) + "http://api.brain-map.org/api/v2/atlas_image_download/100883869?downsample=4&annotation=false", path + ) def test_api_doc_url_download_atlas_image_atlas(image_api): - path = '100883869.jpg' + path = "100883869.jpg" section_image_id = 100883869 - image_api.download_atlas_image(section_image_id, - downsample=4, - annotation=True, - atlas=2) + image_api.download_atlas_image(section_image_id, downsample=4, annotation=True, atlas=2) image_api.retrieve_file_over_http.assert_called_once_with( - "http://api.brain-map.org/api/v2/atlas_image_download/100883869" - "?downsample=4&annotation=true&atlas=2", - path) + "http://api.brain-map.org/api/v2/atlas_image_download/100883869?downsample=4&annotation=true&atlas=2", path + ) def test_api_doc_url_download_atlas_full_resolution_region_of_interest(image_api): - path = '100883869.jpg' + path = "100883869.jpg" subimage_id = 100883869 - image_api.download_atlas_image(subimage_id, - left=6174, - top=2282, - width=1000, - height=1000) + image_api.download_atlas_image(subimage_id, left=6174, top=2282, width=1000, height=1000) image_api.retrieve_file_over_http.assert_called_once_with( - "http://api.brain-map.org/api/v2/atlas_image_download/100883869" - "?left=6174&top=2282&width=1000&height=1000", - path) + "http://api.brain-map.org/api/v2/atlas_image_download/100883869?left=6174&top=2282&width=1000&height=1000", path + ) def test_api_doc_url_download_projection_image_downsampled(image_api): - path = '126862583.jpg' + path = "126862583.jpg" section_image_id = 126862583 - image_api.download_projection_image(section_image_id, - downsample=4) + image_api.download_projection_image(section_image_id, downsample=4) image_api.retrieve_file_over_http.assert_called_once_with( - "http://api.brain-map.org/api/v2/projection_image_download/126862583" - "?downsample=4", - path) + "http://api.brain-map.org/api/v2/projection_image_download/126862583?downsample=4", path + ) def test_api_doc_url_download_projection_image_projection(image_api): - path = '126862583.jpg' + path = "126862583.jpg" section_image_id = 126862583 - image_api.download_projection_image(section_image_id, - downsample=4, - projection=True) + image_api.download_projection_image(section_image_id, downsample=4, projection=True) image_api.retrieve_file_over_http.assert_called_once_with( - "http://api.brain-map.org/api/v2/projection_image_download/126862583" - "?downsample=4&projection=true", - path) + "http://api.brain-map.org/api/v2/projection_image_download/126862583?downsample=4&projection=true", path + ) def test_api_doc_url_download_projection_image_projection_false(image_api): - path = '126862583.jpg' + path = "126862583.jpg" section_image_id = 126862583 - image_api.download_projection_image(section_image_id, - downsample=4, - projection=False) + image_api.download_projection_image(section_image_id, downsample=4, projection=False) image_api.retrieve_file_over_http.assert_called_once_with( - "http://api.brain-map.org/api/v2/projection_image_download/126862583" - "?downsample=4&projection=false", - path) + "http://api.brain-map.org/api/v2/projection_image_download/126862583?downsample=4&projection=false", path + ) def test_api_doc_url_download_projection_image_view(image_api): - path = '126862583.jpg' + path = "126862583.jpg" section_image_id = 126862583 - image_api.download_projection_image(section_image_id, - downsample=4, - view='projection') + image_api.download_projection_image(section_image_id, downsample=4, view="projection") image_api.retrieve_file_over_http.assert_called_once_with( - "http://api.brain-map.org/api/v2/projection_image_download/126862583" - "?downsample=4&view=projection", - path) + "http://api.brain-map.org/api/v2/projection_image_download/126862583?downsample=4&view=projection", path + ) def test_api_doc_url_download_projection_image_view_exception(image_api): - section_image_id = 126862583 - + with pytest.raises(ValueError) as excinfo: - image_api.download_projection_image(section_image_id, - downsample=4, - view='typo') + image_api.download_projection_image(section_image_id, downsample=4, view="typo") - assert excinfo.value.args[0] == "view argument should be 'expression', 'projection', 'tumor_feature_annotation' or 'tumor_feature_boundary'" + assert ( + excinfo.value.args[0] + == "view argument should be 'expression', 'projection', 'tumor_feature_annotation' or 'tumor_feature_boundary'" + ) def test_api_doc_url_download_image_downsampled(image_api): - ''' + """ Notes ----- See: `Image Download Service `_ - ''' - path = '69750516.jpg' + """ + path = "69750516.jpg" subimage_id = 69750516 - image_api.download_image(subimage_id, - downsample=4) + image_api.download_image(subimage_id, downsample=4) image_api.retrieve_file_over_http.assert_called_once_with( - "http://api.brain-map.org/api/v2/image_download/69750516" - "?downsample=4", - path) + "http://api.brain-map.org/api/v2/image_download/69750516?downsample=4", path + ) def test_api_doc_url_download_image_downsampled_low_quality(image_api): - ''' + """ Notes ----- See: `Image Download Service `_ - ''' - path = '69750516.jpg' + """ + path = "69750516.jpg" subimage_id = 69750516 - image_api.download_image(subimage_id, - downsample=3, - quality=50) + image_api.download_image(subimage_id, downsample=3, quality=50) image_api.retrieve_file_over_http.assert_called_once_with( - "http://api.brain-map.org/api/v2/image_download/69750516" - "?downsample=3&quality=50", - path) + "http://api.brain-map.org/api/v2/image_download/69750516?downsample=3&quality=50", path + ) def test_api_doc_url_download_full_resolution_region_of_interest(image_api): - ''' + """ Notes ----- See: `Image Download Service `_ - ''' - path = '69750516.jpg' + """ + path = "69750516.jpg" subimage_id = 69750516 - image_api.download_image(subimage_id, - left=6174, - top=2282, - width=1000, - height=1000) + image_api.download_image(subimage_id, left=6174, top=2282, width=1000, height=1000) image_api.retrieve_file_over_http.assert_called_once_with( - "http://api.brain-map.org/api/v2/image_download/69750516" - "?left=6174&top=2282&width=1000&height=1000", - path) + "http://api.brain-map.org/api/v2/image_download/69750516?left=6174&top=2282&width=1000&height=1000", path + ) def test_api_doc_url_download_image_expression_mask(image_api): - ''' + """ Notes ----- See: `Image Download Service `_ - ''' - path = '69750516.jpg' + """ + path = "69750516.jpg" subimage_id = 69750516 - image_api.download_image(subimage_id, - downsample=4, - view='expression') + image_api.download_image(subimage_id, downsample=4, view="expression") image_api.retrieve_file_over_http.assert_called_once_with( - "http://api.brain-map.org/api/v2/image_download/69750516" - "?downsample=4&view=expression", - path) + "http://api.brain-map.org/api/v2/image_download/69750516?downsample=4&view=expression", path + ) def test_api_doc_url_download_image_full_resolution(image_api): - ''' + """ Notes ----- See: `Experimental Overview and Metadata `_ , link labeled 'Download a region of interest at full resolution using default thresholds'. - ''' - expected = 'http://api.brain-map.org/api/v2/section_image_download/126862575?range=0,932,0,1279,0,4095&left=19045&top=11684&width=1000&height=1000' - path = '126862575.jpg' + """ + expected = "http://api.brain-map.org/api/v2/section_image_download/126862575?range=0,932,0,1279,0,4095&left=19045&top=11684&width=1000&height=1000" + path = "126862575.jpg" - image_api.retrieve_file_over_http = \ - MagicMock(name='retrieve_file_over_http') + image_api.retrieve_file_over_http = MagicMock(name="retrieve_file_over_http") section_image_id = 126862575 - image_api.download_section_image(section_image_id, - left=19045, - top=11684, - width=1000, - height=1000, - range=[0, 932, 0, 1279, 0, 4095]) + image_api.download_section_image( + section_image_id, left=19045, top=11684, width=1000, height=1000, range=[0, 932, 0, 1279, 0, 4095] + ) image_api.retrieve_file_over_http.assert_called_once_with(expected, path) def test_colormap_filter(image_api): - ''' - ''' - path = '70636013.jpg' + """ """ + path = "70636013.jpg" section_image_id = 70636013 - image_api.download_section_image(section_image_id, - downsample=4, - view='expression', - colormap=(0.9,"expression")) + image_api.download_section_image(section_image_id, downsample=4, view="expression", colormap=(0.9, "expression")) image_api.retrieve_file_over_http.assert_called_once_with( "http://api.brain-map.org/api/v2/section_image_download/70636013" "?downsample=4&colormap=0.5,0.9,0,256,4&view=expression", - path) + path, + ) def test_colormap_filter_string(image_api): - ''' - ''' - path = '70636013.jpg' + """ """ + path = "70636013.jpg" section_image_id = 70636013 - image_api.download_section_image(section_image_id, - downsample=4, - view='expression', - colormap="expression") + image_api.download_section_image(section_image_id, downsample=4, view="expression", colormap="expression") image_api.retrieve_file_over_http.assert_called_once_with( "http://api.brain-map.org/api/v2/section_image_download/70636013" "?downsample=4&colormap=expression&view=expression", - path) - + path, + ) def test_rgb_filter(image_api): - ''' - ''' - path = '70636013.jpg' + """ """ + path = "70636013.jpg" section_image_id = 70636013 - image_api.download_section_image(section_image_id, - downsample=4, - view='expression', - rgb=[0.25,0.5,1]) + image_api.download_section_image(section_image_id, downsample=4, view="expression", rgb=[0.25, 0.5, 1]) image_api.retrieve_file_over_http.assert_called_once_with( - "http://api.brain-map.org/api/v2/section_image_download/70636013" - "?downsample=4&rgb=0.25,0.5,1&view=expression", - path) + "http://api.brain-map.org/api/v2/section_image_download/70636013?downsample=4&rgb=0.25,0.5,1&view=expression", + path, + ) def test_contrast_filter(image_api): - ''' - ''' - path = '70636013.jpg' + """ """ + path = "70636013.jpg" section_image_id = 70636013 - image_api.download_section_image(section_image_id, - downsample=4, - view='expression', - contrast=[0.5,1]) + image_api.download_section_image(section_image_id, downsample=4, view="expression", contrast=[0.5, 1]) image_api.retrieve_file_over_http.assert_called_once_with( - "http://api.brain-map.org/api/v2/section_image_download/70636013" - "?downsample=4&contrast=0.5,1&view=expression", - path) + "http://api.brain-map.org/api/v2/section_image_download/70636013?downsample=4&contrast=0.5,1&view=expression", + path, + ) def test_atlas_image_query(image_api): - expected = "http://api.brain-map.org/api/v2/data/query.json?q=" + \ - "model::Atlas,rma::criteria,[id$eq1]," + \ - "rma::options[only$eqimage_type]," + \ - "pipe::list[type_name$is'image_type']," + \ - "model::AtlasImage,rma::criteria,[annotated$eqtrue]," + \ - "atlas_data_set(atlases[id$eq1])," + \ - "alternate_images[image_type$eq$type_name]," + \ - "rma::options[num_rows$eq'all']" + \ - "[order$eqsub_images.section_number]" + expected = ( + "http://api.brain-map.org/api/v2/data/query.json?q=" + + "model::Atlas,rma::criteria,[id$eq1]," + + "rma::options[only$eqimage_type]," + + "pipe::list[type_name$is'image_type']," + + "model::AtlasImage,rma::criteria,[annotated$eqtrue]," + + "atlas_data_set(atlases[id$eq1])," + + "alternate_images[image_type$eq$type_name]," + + "rma::options[num_rows$eq'all']" + + "[order$eqsub_images.section_number]" + ) adult_mouse_atlas_id = 1 image_api.atlas_image_query(adult_mouse_atlas_id) @@ -584,27 +508,29 @@ def test_atlas_image_query(image_api): def test_atlas_image_query_image_type_name(image_api): - expected = "http://api.brain-map.org/api/v2/data/query.json?q=" + \ - "model::AtlasImage,rma::criteria,[annotated$eqtrue]," + \ - "atlas_data_set(atlases[id$eq1])," + \ - "alternate_images[image_type$eq'Atlas - Adult Mouse']," + \ - "rma::options[num_rows$eq'all']" + \ - "[order$eqsub_images.section_number]" + expected = ( + "http://api.brain-map.org/api/v2/data/query.json?q=" + + "model::AtlasImage,rma::criteria,[annotated$eqtrue]," + + "atlas_data_set(atlases[id$eq1])," + + "alternate_images[image_type$eq'Atlas - Adult Mouse']," + + "rma::options[num_rows$eq'all']" + + "[order$eqsub_images.section_number]" + ) adult_mouse_atlas_id = 1 - adult_mouse_image_type_name = 'Atlas - Adult Mouse' - image_api.atlas_image_query(adult_mouse_atlas_id, - image_type_name=adult_mouse_image_type_name) + adult_mouse_image_type_name = "Atlas - Adult Mouse" + image_api.atlas_image_query(adult_mouse_atlas_id, image_type_name=adult_mouse_image_type_name) image_api.json_msg_query.assert_called_once_with(expected) def test_section_image_query(image_api): - - exp = 'http://api.brain-map.org/api/v2/data/query.json?'\ - 'q=model::SectionImage,'\ - 'rma::criteria,[data_set_id$eq70813257],'\ - 'rma::options[num_rows$eq\'all\'][count$eqfalse]' + exp = ( + "http://api.brain-map.org/api/v2/data/query.json?" + "q=model::SectionImage," + "rma::criteria,[data_set_id$eq70813257]," + "rma::options[num_rows$eq'all'][count$eqfalse]" + ) image_api.section_image_query(70813257) image_api.json_msg_query.assert_called_once_with(exp) diff --git a/allensdk/test/api/test_mouse_atlas_api.py b/allensdk/test/api/test_mouse_atlas_api.py index 258d7734f6..1e6e86afaf 100644 --- a/allensdk/test/api/test_mouse_atlas_api.py +++ b/allensdk/test/api/test_mouse_atlas_api.py @@ -40,7 +40,6 @@ from allensdk.api.queries.mouse_atlas_api import MouseAtlasApi as MAA - @pytest.fixture def atlas(): maa = MAA() @@ -49,55 +48,51 @@ def atlas(): @patch.object(MAA, "json_msg_query") def test_get_genes(mock_query, atlas): - - expected = 'http://api.brain-map.org/api/v2/data/query.json?'\ - 'q=model::Gene,rma::criteria,[organism_id$in2],rma::include,chromosome,'\ - 'rma::options[num_rows$eq2000][start_row$eq0][order$eq\'id\'][count$eqfalse]' + expected = ( + "http://api.brain-map.org/api/v2/data/query.json?" + "q=model::Gene,rma::criteria,[organism_id$in2],rma::include,chromosome," + "rma::options[num_rows$eq2000][start_row$eq0][order$eq'id'][count$eqfalse]" + ) for result in atlas.get_genes(): pass - + mock_query.assert_called_once_with(expected) @patch.object(MAA, "json_msg_query") def test_get_section_data_sets(mock_query, atlas): - - expected = 'http://api.brain-map.org/api/v2/data/query.json?'\ - 'q=model::SectionDataSet,rma::criteria,products[id$in1],rma::include,genes,'\ - 'rma::options[num_rows$eq2000][start_row$eq0][order$eq\'id\'][count$eqfalse]' + expected = ( + "http://api.brain-map.org/api/v2/data/query.json?" + "q=model::SectionDataSet,rma::criteria,products[id$in1],rma::include,genes," + "rma::options[num_rows$eq2000][start_row$eq0][order$eq'id'][count$eqfalse]" + ) for result in atlas.get_section_data_sets(): pass - + mock_query.assert_called_once_with(expected) def test_download_expression_density(atlas): - with patch('allensdk.api.api.Api.retrieve_file_over_http') as gda: + with patch("allensdk.api.api.Api.retrieve_file_over_http") as gda: with pytest.raises(RuntimeError): - atlas.download_expression_density('file.name', 12345) + atlas.download_expression_density("file.name", 12345) + + gda.assert_called_once_with("http://api.brain-map.org/grid_data/download/12345?include=density") - gda.assert_called_once_with( - 'http://api.brain-map.org/grid_data/download/'\ - '12345?include=density') - def test_download_expression_intensity(atlas): - with patch('allensdk.api.api.Api.retrieve_file_over_http') as gda: + with patch("allensdk.api.api.Api.retrieve_file_over_http") as gda: with pytest.raises(RuntimeError): - atlas.download_expression_intensity('file.name', 12345) + atlas.download_expression_intensity("file.name", 12345) - gda.assert_called_once_with( - 'http://api.brain-map.org/grid_data/download/'\ - '12345?include=intensity') + gda.assert_called_once_with("http://api.brain-map.org/grid_data/download/12345?include=intensity") def test_download_expression_energy(atlas): - with patch('allensdk.api.api.Api.retrieve_file_over_http') as gda: + with patch("allensdk.api.api.Api.retrieve_file_over_http") as gda: with pytest.raises(RuntimeError): - atlas.download_expression_energy('file.name', 12345) + atlas.download_expression_energy("file.name", 12345) - gda.assert_called_once_with( - 'http://api.brain-map.org/grid_data/download/'\ - '12345?include=energy') + gda.assert_called_once_with("http://api.brain-map.org/grid_data/download/12345?include=energy") diff --git a/allensdk/test/api/test_mouse_connectivity_api.py b/allensdk/test/api/test_mouse_connectivity_api.py index 3979ac5b52..b8709695df 100644 --- a/allensdk/test/api/test_mouse_connectivity_api.py +++ b/allensdk/test/api/test_mouse_connectivity_api.py @@ -39,218 +39,166 @@ import numpy as np from allensdk.api.queries.mouse_connectivity_api import MouseConnectivityApi as MCA -MOCK_ANNOTATION_DATA = 'mock_annotation_data' -MOCK_ANNOTATION_IMAGE = 'mock_annotation_image' -DOWNLOAD_LINK = '/path/to/link' +MOCK_ANNOTATION_DATA = "mock_annotation_data" +MOCK_ANNOTATION_IMAGE = "mock_annotation_image" +DOWNLOAD_LINK = "/path/to/link" @pytest.fixture def connectivity(): mca = MCA() - + return mca def CCF_VERSIONS(): - return [MCA.CCF_2015, - MCA.CCF_2016] + return [MCA.CCF_2015, MCA.CCF_2016] -def DATA_PATHS(): - return [MCA.AVERAGE_TEMPLATE, - MCA.ARA_NISSL, - MCA.MOUSE_2011, - MCA.DEVMOUSE_2012, - MCA.CCF_2015, - MCA.CCF_2016] +def DATA_PATHS(): + return [MCA.AVERAGE_TEMPLATE, MCA.ARA_NISSL, MCA.MOUSE_2011, MCA.DEVMOUSE_2012, MCA.CCF_2015, MCA.CCF_2016] def RESOLUTIONS(): - return [MCA.VOXEL_RESOLUTION_10_MICRONS, - MCA.VOXEL_RESOLUTION_25_MICRONS, - MCA.VOXEL_RESOLUTION_50_MICRONS, - MCA.VOXEL_RESOLUTION_100_MICRONS] + return [ + MCA.VOXEL_RESOLUTION_10_MICRONS, + MCA.VOXEL_RESOLUTION_25_MICRONS, + MCA.VOXEL_RESOLUTION_50_MICRONS, + MCA.VOXEL_RESOLUTION_100_MICRONS, + ] -@pytest.mark.parametrize("data_path,resolution", - it.product(DATA_PATHS(), - RESOLUTIONS())) +@pytest.mark.parametrize("data_path,resolution", it.product(DATA_PATHS(), RESOLUTIONS())) @patch.object(MCA, "retrieve_file_over_http") -def test_download_volumetric_data(mock_retrieve, - connectivity, - data_path, - resolution): +def test_download_volumetric_data(mock_retrieve, connectivity, data_path, resolution): cache_filename = "annotation_%d.nrrd" % (resolution) - connectivity.download_volumetric_data(data_path, - cache_filename, - resolution) + connectivity.download_volumetric_data(data_path, cache_filename, resolution) mock_retrieve.assert_called_once_with( "http://download.alleninstitute.org/informatics-archive/" - "current-release/mouse_ccf/%s/annotation_%d.nrrd" % - (data_path, - resolution), - cache_filename) + "current-release/mouse_ccf/%s/annotation_%d.nrrd" % (data_path, resolution), + cache_filename, + ) -@pytest.mark.parametrize("ccf_version,resolution", - it.product(CCF_VERSIONS(), - RESOLUTIONS())) +@pytest.mark.parametrize("ccf_version,resolution", it.product(CCF_VERSIONS(), RESOLUTIONS())) @patch.object(MCA, "retrieve_file_over_http") -@patch("nrrd.read", return_value=('mock_annotation_data', - 'mock_annotation_image')) -@patch('os.makedirs') -def test_download_annotation_volume(os_makedirs, - nrrd_read, - mock_retrieve, - connectivity, - ccf_version, - resolution): - cache_file = '/path/to/annotation_%d.nrrd' % (resolution) - - connectivity.download_annotation_volume( - ccf_version, - resolution, - cache_file, - reader=nrrd_read) +@patch("nrrd.read", return_value=("mock_annotation_data", "mock_annotation_image")) +@patch("os.makedirs") +def test_download_annotation_volume(os_makedirs, nrrd_read, mock_retrieve, connectivity, ccf_version, resolution): + cache_file = "/path/to/annotation_%d.nrrd" % (resolution) + + connectivity.download_annotation_volume(ccf_version, resolution, cache_file, reader=nrrd_read) nrrd_read.assert_called_once_with(cache_file) mock_retrieve.assert_called_once_with( "http://download.alleninstitute.org/informatics-archive/" - "current-release/mouse_ccf/%s/annotation_%d.nrrd" % - (ccf_version, - resolution), - "/path/to/annotation_%d.nrrd" % (resolution)) + "current-release/mouse_ccf/%s/annotation_%d.nrrd" % (ccf_version, resolution), + "/path/to/annotation_%d.nrrd" % (resolution), + ) - os_makedirs.assert_any_call('/path/to') + os_makedirs.assert_any_call("/path/to") -@pytest.mark.parametrize("resolution", - RESOLUTIONS()) +@pytest.mark.parametrize("resolution", RESOLUTIONS()) @patch.object(MCA, "retrieve_file_over_http") -@patch("nrrd.read", return_value=('mock_annotation_data', - 'mock_annotation_image')) -@patch('os.makedirs') -def test_download_annotation_volume_default(os_makedirs, - nrrd_read, - mock_retrieve, - connectivity, - resolution): +@patch("nrrd.read", return_value=("mock_annotation_data", "mock_annotation_image")) +@patch("os.makedirs") +def test_download_annotation_volume_default(os_makedirs, nrrd_read, mock_retrieve, connectivity, resolution): a, b = connectivity.download_annotation_volume( - None, - resolution, - '/path/to/annotation_%d.nrrd' % (resolution), - reader=nrrd_read) - + None, resolution, "/path/to/annotation_%d.nrrd" % (resolution), reader=nrrd_read + ) + assert a assert b mock_retrieve.assert_called_once_with( "http://download.alleninstitute.org/informatics-archive/" - "current-release/mouse_ccf/%s/annotation_%d.nrrd" % - (MCA.CCF_VERSION_DEFAULT, - resolution), - "/path/to/annotation_%d.nrrd" % (resolution)) + "current-release/mouse_ccf/%s/annotation_%d.nrrd" % (MCA.CCF_VERSION_DEFAULT, resolution), + "/path/to/annotation_%d.nrrd" % (resolution), + ) - os_makedirs.assert_any_call('/path/to') + os_makedirs.assert_any_call("/path/to") -@pytest.mark.parametrize("resolution", - RESOLUTIONS()) +@pytest.mark.parametrize("resolution", RESOLUTIONS()) @patch.object(MCA, "retrieve_file_over_http") -@patch("nrrd.read", return_value=('mock_annotation_data', - 'mock_annotation_image')) -@patch('os.makedirs') -def test_download_structure_mask(os_makedirs, - nrrd_read, - mock_retrieve, - connectivity, - resolution): - +@patch("nrrd.read", return_value=("mock_annotation_data", "mock_annotation_image")) +@patch("os.makedirs") +def test_download_structure_mask(os_makedirs, nrrd_read, mock_retrieve, connectivity, resolution): structure_id = 12 - a, b = connectivity.download_structure_mask(structure_id, - None, - resolution,'/path/to/foo.nrrd', - reader=nrrd_read) + a, b = connectivity.download_structure_mask(structure_id, None, resolution, "/path/to/foo.nrrd", reader=nrrd_read) assert a assert b - expected = 'http://download.alleninstitute.org/informatics-archive/'\ - 'current-release/mouse_ccf/{0}/structure_masks/'\ - 'structure_masks_{1}/structure_{2}.nrrd'.format(MCA.CCF_VERSION_DEFAULT, - resolution, - structure_id) - mock_retrieve.assert_called_once_with(expected, '/path/to/foo.nrrd') - os_makedirs.assert_any_call('/path/to') + expected = ( + "http://download.alleninstitute.org/informatics-archive/" + "current-release/mouse_ccf/{0}/structure_masks/" + "structure_masks_{1}/structure_{2}.nrrd".format(MCA.CCF_VERSION_DEFAULT, resolution, structure_id) + ) + mock_retrieve.assert_called_once_with(expected, "/path/to/foo.nrrd") + os_makedirs.assert_any_call("/path/to") -@pytest.mark.parametrize("resolution", - RESOLUTIONS()) +@pytest.mark.parametrize("resolution", RESOLUTIONS()) @patch.object(MCA, "retrieve_file_over_http") -@patch("nrrd.read", return_value=('mock_annotation_data', - 'mock_annotation_image')) -@patch('os.makedirs') -def test_download_template_volume(os_makedirs, - nrrd_read, - mock_retrieve, - connectivity, - resolution): +@patch("nrrd.read", return_value=("mock_annotation_data", "mock_annotation_image")) +@patch("os.makedirs") +def test_download_template_volume(os_makedirs, nrrd_read, mock_retrieve, connectivity, resolution): connectivity.download_template_volume( - resolution, - '/path/to/average_template_%d.nrrd' % (resolution), - reader=nrrd_read) + resolution, "/path/to/average_template_%d.nrrd" % (resolution), reader=nrrd_read + ) - nrrd_read.assert_called_once_with('/path/to/average_template_%d.nrrd' % (resolution)) + nrrd_read.assert_called_once_with("/path/to/average_template_%d.nrrd" % (resolution)) mock_retrieve.assert_called_once_with( "http://download.alleninstitute.org/informatics-archive/" - "current-release/mouse_ccf/average_template/average_template_%d.nrrd" % - (resolution), - "/path/to/average_template_%d.nrrd" % (resolution)) + "current-release/mouse_ccf/average_template/average_template_%d.nrrd" % (resolution), + "/path/to/average_template_%d.nrrd" % (resolution), + ) - os_makedirs.assert_any_call('/path/to') + os_makedirs.assert_any_call("/path/to") @patch.object(MCA, "json_msg_query") -def test_get_experiments_no_ids(mock_query, - connectivity): +def test_get_experiments_no_ids(mock_query, connectivity): connectivity.get_experiments(None) mock_query.assert_called_once_with( "http://api.brain-map.org/api/v2/data/query.json?q=" "model::SectionDataSet,rma::criteria,[failed$eqfalse]," - "products[id$in5,31]") + "products[id$in5,31]" + ) @patch.object(MCA, "json_msg_query") -def test_get_experiments_one_id(mock_query, - connectivity): +def test_get_experiments_one_id(mock_query, connectivity): connectivity.get_experiments(987) mock_query.assert_called_once_with( "http://api.brain-map.org/api/v2/data/query.json?q=" "model::SectionDataSet,rma::criteria,[failed$eqfalse]," - "products[id$in5,31],[id$in987]") + "products[id$in5,31],[id$in987]" + ) @patch.object(MCA, "json_msg_query") -def test_get_experiments_ids(mock_query, - connectivity): - connectivity.get_experiments([9,8,7]) +def test_get_experiments_ids(mock_query, connectivity): + connectivity.get_experiments([9, 8, 7]) mock_query.assert_called_once_with( "http://api.brain-map.org/api/v2/data/query.json?q=" "model::SectionDataSet,rma::criteria,[failed$eqfalse]," - "products[id$in5,31],[id$in9,8,7]") + "products[id$in5,31],[id$in9,8,7]" + ) @patch.object(MCA, "json_msg_query") -def test_get_manual_injection_summary(mock_query, - connectivity): +def test_get_manual_injection_summary(mock_query, connectivity): connectivity.get_manual_injection_summary(123) mock_query.assert_called_once_with( @@ -265,12 +213,12 @@ def test_get_manual_injection_summary(mock_query, "coordinates_dv,coordinates_ml,angle,sex,strain,injection_materials," "acronym,structures.name,days,transgenic_mice.name," "transgenic_lines.name,transgenic_lines.description," - "transgenic_lines.id,donors.id]") + "transgenic_lines.id,donors.id]" + ) @patch.object(MCA, "json_msg_query") -def test_get_experiment_detail(mock_query, - connectivity): +def test_get_experiment_detail(mock_query, connectivity): connectivity.get_experiment_detail(123) mock_query.assert_called_once_with( @@ -279,182 +227,153 @@ def test_get_experiment_detail(mock_query, "rma::include,specimen(stereotaxic_injections" "(primary_injection_structure,structures," "stereotaxic_injection_coordinates)),equalization,sub_images," - "rma::options[order$eq'sub_images.section_number$asc']") + "rma::options[order$eq'sub_images.section_number$asc']" + ) @patch.object(MCA, "json_msg_query") -def test_get_projection_image_info(mock_query, - connectivity): +def test_get_projection_image_info(mock_query, connectivity): connectivity.get_projection_image_info(123, 456) mock_query.assert_called_once_with( "http://api.brain-map.org/api/v2/data/query.json?q=" "model::SectionDataSet,rma::criteria,[id$eq123],rma::include," - "equalization,sub_images[section_number$eq456]") + "equalization,sub_images[section_number$eq456]" + ) def test_build_reference_aligned_channel_volumes_url(connectivity): - url = \ - connectivity.build_reference_aligned_image_channel_volumes_url(123456) + url = connectivity.build_reference_aligned_image_channel_volumes_url(123456) - assert url == ("http://api.brain-map.org/api/v2/data/query.json?q=" - "model::WellKnownFile,rma::criteria," - "well_known_file_type[name$eq'ImagesResampledTo25MicronARA']" - "[attachable_id$eq123456]") + assert url == ( + "http://api.brain-map.org/api/v2/data/query.json?q=" + "model::WellKnownFile,rma::criteria," + "well_known_file_type[name$eq'ImagesResampledTo25MicronARA']" + "[attachable_id$eq123456]" + ) @patch.object(MCA, "retrieve_file_over_http") @patch.object(MCA, "do_query", return_value=DOWNLOAD_LINK) -def test_reference_aligned_channel_volumes(mock_query, - mock_retrieve, - connectivity): +def test_reference_aligned_channel_volumes(mock_query, mock_retrieve, connectivity): connectivity.download_reference_aligned_image_channel_volumes(123456) - mock_retrieve.assert_called_once_with( - "http://api.brain-map.org/path/to/link", - "123456.zip") + mock_retrieve.assert_called_once_with("http://api.brain-map.org/path/to/link", "123456.zip") @patch.object(MCA, "json_msg_query") -def test_experiment_source_search(mock_query, - connectivity): - connectivity.experiment_source_search( - injection_structures='Isocortex', - primary_structure_only=True) +def test_experiment_source_search(mock_query, connectivity): + connectivity.experiment_source_search(injection_structures="Isocortex", primary_structure_only=True) mock_query.assert_called_once_with( "http://api.brain-map.org/api/v2/data/query.json?q=" "service::mouse_connectivity_injection_structure" - "[injection_structures$eqIsocortex][primary_structure_only$eqtrue]") + "[injection_structures$eqIsocortex][primary_structure_only$eqtrue]" + ) @patch.object(MCA, "json_msg_query") -def test_experiment_spatial_search(mock_query, - connectivity): - connectivity.experiment_spatial_search( - seed_point=[6900,5050,6450]) +def test_experiment_spatial_search(mock_query, connectivity): + connectivity.experiment_spatial_search(seed_point=[6900, 5050, 6450]) mock_query.assert_called_once_with( "http://api.brain-map.org/api/v2/data/query.json?q=" "service::mouse_connectivity_target_spatial" - "[seed_point$eq6900,5050,6450]") + "[seed_point$eq6900,5050,6450]" + ) @patch.object(MCA, "json_msg_query") -def test_injection_coordinate_search(mock_query, - connectivity): - connectivity.experiment_injection_coordinate_search( - seed_point=[6900,5050,6450]) +def test_injection_coordinate_search(mock_query, connectivity): + connectivity.experiment_injection_coordinate_search(seed_point=[6900, 5050, 6450]) mock_query.assert_called_once_with( "http://api.brain-map.org/api/v2/data/query.json?q=" "service::mouse_connectivity_injection_coordinate" - "[seed_point$eq6900,5050,6450]") + "[seed_point$eq6900,5050,6450]" + ) @patch.object(MCA, "json_msg_query") -def test_experiment_correlation_search(mock_query, - connectivity): - connectivity.experiment_correlation_search( - row=112670853, structure='TH') +def test_experiment_correlation_search(mock_query, connectivity): + connectivity.experiment_correlation_search(row=112670853, structure="TH") mock_query.assert_called_once_with( "http://api.brain-map.org/api/v2/data/query.json?q=" "service::mouse_connectivity_correlation" - "[row$eq112670853][structure$eqTH]") + "[row$eq112670853][structure$eqTH]" + ) -@pytest.mark.parametrize("injection,hemisphere", - it.product([True, False,None], - [['left'],['right'],None])) +@pytest.mark.parametrize("injection,hemisphere", it.product([True, False, None], [["left"], ["right"], None])) @patch.object(MCA, "json_msg_query") -def test_get_structure_unionizes(mock_query, - connectivity, - injection, - hemisphere): +def test_get_structure_unionizes(mock_query, connectivity, injection, hemisphere): connectivity.get_structure_unionizes( - experiment_ids=[126862385], - is_injection=injection, - hemisphere_ids=hemisphere, - include='structure') + experiment_ids=[126862385], is_injection=injection, hemisphere_ids=hemisphere, include="structure" + ) - i = '' + i = "" if injection is not None: i = "[is_injection$eq%s]" % (str(injection).lower()) - h = '' + h = "" if hemisphere is not None: h = "[hemisphere_id$in%s]" % (hemisphere[0]) mock_query.assert_called_once_with( - ("http://api.brain-map.org/api/v2/data/query.json?q=" - "model::ProjectionStructureUnionize,rma::criteria," - "[section_data_set_id$in126862385]%s%s," - "rma::include,structure,rma::options[num_rows$eq'all']" - "[count$eqfalse]") % (i, h)) + ( + "http://api.brain-map.org/api/v2/data/query.json?q=" + "model::ProjectionStructureUnionize,rma::criteria," + "[section_data_set_id$in126862385]%s%s," + "rma::include,structure,rma::options[num_rows$eq'all']" + "[count$eqfalse]" + ) + % (i, h) + ) def test_download_injection_density(connectivity): - with patch('allensdk.api.api.Api.retrieve_file_over_http') as gda: - connectivity.download_injection_density( - 'file.name', 12345, 10) + with patch("allensdk.api.api.Api.retrieve_file_over_http") as gda: + connectivity.download_injection_density("file.name", 12345, 10) gda.assert_called_once_with( - "http://api.brain-map.org/grid_data/download_file/" - "12345" - "?image=injection_density&resolution=10", - "file.name") + "http://api.brain-map.org/grid_data/download_file/12345?image=injection_density&resolution=10", "file.name" + ) def test_download_projection_density(connectivity): - with patch('allensdk.api.api.Api.retrieve_file_over_http') as gda: - connectivity.download_projection_density( - 'file.name', 12345, 10) + with patch("allensdk.api.api.Api.retrieve_file_over_http") as gda: + connectivity.download_projection_density("file.name", 12345, 10) gda.assert_called_once_with( - "http://api.brain-map.org/grid_data/download_file/" - "12345" - "?image=projection_density&resolution=10", - "file.name") + "http://api.brain-map.org/grid_data/download_file/12345?image=projection_density&resolution=10", "file.name" + ) def test_download_data_mask_density(connectivity): - with patch('allensdk.api.api.Api.retrieve_file_over_http') as gda: - connectivity.download_data_mask( - 'file.name', 12345, 10) + with patch("allensdk.api.api.Api.retrieve_file_over_http") as gda: + connectivity.download_data_mask("file.name", 12345, 10) gda.assert_called_once_with( - "http://api.brain-map.org/grid_data/download_file/" - "12345" - "?image=data_mask&resolution=10", - "file.name") + "http://api.brain-map.org/grid_data/download_file/12345?image=data_mask&resolution=10", "file.name" + ) def test_download_injection_fraction(connectivity): - with patch('allensdk.api.api.Api.retrieve_file_over_http') as gda: - connectivity.download_injection_fraction( - 'file.name', 12345, 10) + with patch("allensdk.api.api.Api.retrieve_file_over_http") as gda: + connectivity.download_injection_fraction("file.name", 12345, 10) gda.assert_called_once_with( - "http://api.brain-map.org/grid_data/download_file/" - "12345" - "?image=injection_fraction&resolution=10", - "file.name") + "http://api.brain-map.org/grid_data/download_file/12345?image=injection_fraction&resolution=10", "file.name" + ) def test_calculate_injection_centroid(connectivity): - density = np.array(([1.0,1.0,1.0,1.0], - [1.0,1.0,1.0,1.0], - [1.0,1.0,1.0,1.0], - [1.0,1.0,1.0,1.0])) - fraction = np.array(([1.0,1.0,1.0,1.0], - [1.0,1.0,1.0,1.0], - [1.0,1.0,1.0,1.0], - [1.0,1.0,1.0,1.0])) - - centroid = connectivity.calculate_injection_centroid( - density, fraction, resolution=25) - + density = np.array(([1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0])) + fraction = np.array(([1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0])) + + centroid = connectivity.calculate_injection_centroid(density, fraction, resolution=25) + assert np.array_equal(centroid, [37.5, 37.5]) diff --git a/allensdk/test/api/test_ontologies_api.py b/allensdk/test/api/test_ontologies_api.py index 15fbe8c548..20c6d3b836 100644 --- a/allensdk/test/api/test_ontologies_api.py +++ b/allensdk/test/api/test_ontologies_api.py @@ -53,7 +53,8 @@ def test_get_structure_graph(mock_json_msg_query, ontologies): "http://api.brain-map.org/api/v2/data/query.json?q=" "model::Structure,rma::criteria,[graph_id$in1]," "rma::options" - "[num_rows$eq'all'][order$eqstructures.graph_order][count$eqfalse]") + "[num_rows$eq'all'][order$eqstructures.graph_order][count$eqfalse]" + ) @patch.object(OntologiesApi, "json_msg_query") @@ -62,7 +63,8 @@ def test_list_structure_graphs(mock_json_msg_query, ontologies): mock_json_msg_query.assert_called_once_with( "http://api.brain-map.org/api/v2/data/query.json?q=" "model::StructureGraph," - "rma::options[num_rows$eq'all'][count$eqfalse]") + "rma::options[num_rows$eq'all'][count$eqfalse]" + ) @patch.object(OntologiesApi, "json_msg_query") @@ -70,7 +72,8 @@ def test_list_structure_sets_noarg(mock_json_msg_query, ontologies): ontologies.get_structure_sets() mock_json_msg_query.assert_called_once_with( "http://api.brain-map.org/api/v2/data/query.json?q=" - "model::StructureSet,rma::options[num_rows$eq'all'][count$eqfalse]") + "model::StructureSet,rma::options[num_rows$eq'all'][count$eqfalse]" + ) @patch.object(OntologiesApi, "json_msg_query") @@ -79,15 +82,16 @@ def test_list_structure_sets_args(mock_json_msg_query, ontologies): mock_json_msg_query.assert_called_once_with( "http://api.brain-map.org/api/v2/data/query.json?q=" "model::StructureSet,rma::criteria,[id$in2,3]," - "rma::options[num_rows$eq'all'][count$eqfalse]") + "rma::options[num_rows$eq'all'][count$eqfalse]" + ) @patch.object(OntologiesApi, "json_msg_query") def test_list_atlases(mock_json_msg_query, ontologies): ontologies.get_atlases() mock_json_msg_query.assert_called_once_with( - "http://api.brain-map.org/api/v2/data/query.json?q=" - "model::Atlas,rma::options[num_rows$eq'all'][count$eqfalse]") + "http://api.brain-map.org/api/v2/data/query.json?q=model::Atlas,rma::options[num_rows$eq'all'][count$eqfalse]" + ) @patch.object(OntologiesApi, "json_msg_query") @@ -98,20 +102,21 @@ def test_structure_graph_by_name(mock_json_msg_query, ontologies): "model::Structure,rma::criteria," "graph[structure_graphs.name$in'Mouse Brain Atlas']," "rma::options" - "[num_rows$eq'all'][order$eqstructures.graph_order][count$eqfalse]") + "[num_rows$eq'all'][order$eqstructures.graph_order][count$eqfalse]" + ) @patch.object(OntologiesApi, "json_msg_query") def test_structure_graphs_by_names(mock_json_msg_query, ontologies): - ontologies.get_structures(structure_graph_names=["'Mouse Brain Atlas'", - "'Human Brain Atlas'"]) + ontologies.get_structures(structure_graph_names=["'Mouse Brain Atlas'", "'Human Brain Atlas'"]) mock_json_msg_query.assert_called_once_with( "http://api.brain-map.org/api/v2/data/query.json?q=" "model::Structure,rma::criteria," "graph[structure_graphs.name$in'Mouse Brain Atlas'," "'Human Brain Atlas']," "rma::options" - "[num_rows$eq'all'][order$eqstructures.graph_order][count$eqfalse]") + "[num_rows$eq'all'][order$eqstructures.graph_order][count$eqfalse]" + ) @patch.object(OntologiesApi, "json_msg_query") @@ -121,7 +126,8 @@ def test_structure_set_by_id(mock_json_msg_query, ontologies): "http://api.brain-map.org/api/v2/data/query.json?q=" "model::Structure,rma::criteria,[structure_set_id$in8]," "rma::options" - "[num_rows$eq'all'][order$eqstructures.graph_order][count$eqfalse]") + "[num_rows$eq'all'][order$eqstructures.graph_order][count$eqfalse]" + ) @patch.object(OntologiesApi, "json_msg_query") @@ -131,20 +137,20 @@ def test_structure_sets_by_ids(mock_json_msg_query, ontologies): "http://api.brain-map.org/api/v2/data/query.json?q=" "model::Structure,rma::criteria,[structure_set_id$in7,8]," "rma::options" - "[num_rows$eq'all'][order$eqstructures.graph_order][count$eqfalse]") + "[num_rows$eq'all'][order$eqstructures.graph_order][count$eqfalse]" + ) @patch.object(OntologiesApi, "json_msg_query") def test_structure_set_by_name(mock_json_msg_query, ontologies): - ontologies.get_structures( - structure_set_names=ontologies.quote_string( - "Mouse Connectivity - Summary")) + ontologies.get_structures(structure_set_names=ontologies.quote_string("Mouse Connectivity - Summary")) mock_json_msg_query.assert_called_once_with( "http://api.brain-map.org/api/v2/data/query.json?q=" "model::Structure,rma::criteria," "structure_sets[name$in'Mouse Connectivity - Summary']," "rma::options" - "[num_rows$eq'all'][order$eqstructures.graph_order][count$eqfalse]") + "[num_rows$eq'all'][order$eqstructures.graph_order][count$eqfalse]" + ) @patch.object(OntologiesApi, "json_msg_query") @@ -152,13 +158,16 @@ def test_structure_set_by_names(mock_json_msg_query, ontologies): ontologies.get_structures( structure_set_names=[ ontologies.quote_string("NHP - Coarse"), - ontologies.quote_string("Mouse Connectivity - Summary")]) + ontologies.quote_string("Mouse Connectivity - Summary"), + ] + ) mock_json_msg_query.assert_called_once_with( "http://api.brain-map.org/api/v2/data/query.json?q=" "model::Structure,rma::criteria," "structure_sets[name$in'NHP - Coarse','Mouse Connectivity - Summary']," "rma::options" - "[num_rows$eq'all'][order$eqstructures.graph_order][count$eqfalse]") + "[num_rows$eq'all'][order$eqstructures.graph_order][count$eqfalse]" + ) @patch.object(OntologiesApi, "json_msg_query") @@ -167,7 +176,8 @@ def test_structure_set_no_order(mock_json_msg_query, ontologies): mock_json_msg_query.assert_called_once_with( "http://api.brain-map.org/api/v2/data/query.json?q=" "model::Structure,rma::criteria," - "[graph_id$in1],rma::options[num_rows$eq'all'][count$eqfalse]") + "[graph_id$in1],rma::options[num_rows$eq'all'][count$eqfalse]" + ) @patch.object(OntologiesApi, "json_msg_query") @@ -183,7 +193,8 @@ def test_atlas_1(mock_json_msg_query, ontologies): "ontologies.id,ontologies.name," "structure_graphs.id,structure_graphs.name," "graphic_group_labels.id,graphic_group_labels.name']" - "[num_rows$eq'all'][count$eqfalse]") + "[num_rows$eq'all'][count$eqfalse]" + ) @patch.object(OntologiesApi, "json_msg_query") @@ -194,7 +205,8 @@ def test_atlas_verbose(mock_json_msg_query, ontologies): "model::Atlas,rma::criteria," "structure_graph(ontology),graphic_group_labels," "rma::include,structure_graph(ontology),graphic_group_labels," - "rma::options[num_rows$eq'all'][count$eqfalse]") + "rma::options[num_rows$eq'all'][count$eqfalse]" + ) @patch.object(OntologiesApi, "json_msg_query") @@ -205,13 +217,13 @@ def test_get_structures_with_sets(mock_json_msg_query, ontologies): "model::Structure,rma::criteria,[graph_id$in1]," "rma::include,structure_sets," "rma::options[num_rows$eq'all'][order$eqstructures.graph_order]" - "[count$eqfalse]") + "[count$eqfalse]" + ) def test_unpack_structure_set_ancestors(ontologies): - - sdf = pd.DataFrame([{'structure_id_path': '/1/2/3/'}]) + sdf = pd.DataFrame([{"structure_id_path": "/1/2/3/"}]) ontologies.unpack_structure_set_ancestors(sdf) - - assert( 'structure_set_ancestor' in sdf.columns.values ) - assert( allclose(sdf['structure_set_ancestor'].values[0], [1, 2, 3]) ) + + assert "structure_set_ancestor" in sdf.columns.values + assert allclose(sdf["structure_set_ancestor"].values[0], [1, 2, 3]) diff --git a/allensdk/test/api/test_pager.py b/allensdk/test/api/test_pager.py index 3e6ef004df..162ea581b9 100644 --- a/allensdk/test/api/test_pager.py +++ b/allensdk/test/api/test_pager.py @@ -41,6 +41,7 @@ import builtins from allensdk.api.queries.rma_template import RmaTemplate from allensdk.api.warehouse_cache.cache import cacheable, Cache + try: import StringIO except Exception: @@ -52,65 +53,55 @@ def pager(): return RmaPager() -_msg = [{'whatever': True}] + +_msg = [{"whatever": True}] _pd_msg = pd.DataFrame(_msg) -_csv_msg = pd.read_csv(StringIO.StringIO(""",whatever +_csv_msg = pd.read_csv( + StringIO.StringIO(""",whatever 0,True -"""), index_col=0) - -_read_url_get_msg5 = [{'msg': _msg}, - {'msg': _msg}, - {'msg': _msg}, - {'msg': _msg}, - {'msg': _msg}] -_pj_msg5 = pd.DataFrame([{'whatever': True}, - {'whatever': True}, - {'whatever': True}, - {'whatever': True}, - {'whatever': True}]) -_read_msg5 = [{'whatever': True}, - {'whatever': True}, - {'whatever': True}, - {'whatever': True}, - {'whatever': True}] +"""), + index_col=0, +) + +_read_url_get_msg5 = [{"msg": _msg}, {"msg": _msg}, {"msg": _msg}, {"msg": _msg}, {"msg": _msg}] +_pj_msg5 = pd.DataFrame( + [{"whatever": True}, {"whatever": True}, {"whatever": True}, {"whatever": True}, {"whatever": True}] +) +_read_msg5 = [{"whatever": True}, {"whatever": True}, {"whatever": True}, {"whatever": True}, {"whatever": True}] @pytest.fixture def safe_read_url_get_msg5(): return SafeJsonMsg(_read_url_get_msg5) + @pytest.fixture def rma(): return RmaApi() -@patch("allensdk.core.json_utilities.read_url_get", - return_value={'msg': _msg}) + +@patch("allensdk.core.json_utilities.read_url_get", return_value={"msg": _msg}) def test_pageable_json(ju_read_url_get, rma): - @pageable() def get_genes(**kwargs): - return rma.model_query(model='Gene', **kwargs) + return rma.model_query(model="Gene", **kwargs) nr = 5 pp = 1 - tr = nr*pp + tr = nr * pp df = list(get_genes(num_rows=nr, total_rows=tr)) - assert df == [{'whatever': True}, - {'whatever': True}, - {'whatever': True}, - {'whatever': True}, - {'whatever': True}] + assert df == [{"whatever": True}, {"whatever": True}, {"whatever": True}, {"whatever": True}, {"whatever": True}] + + base_query = ( + "http://api.brain-map.org/api/v2/data/query.json?q=model::Gene" + ",rma::options%5Bnum_rows$eq5%5D%5Bstart_row$eq{}%5D" + "%5Bcount$eqfalse%5D" + ) - base_query = \ - ('http://api.brain-map.org/api/v2/data/query.json?q=model::Gene' - ',rma::options%5Bnum_rows$eq5%5D%5Bstart_row$eq{}%5D' - '%5Bcount$eqfalse%5D') + expected_calls = map(lambda c: call(base_query.format(c)), [0, 1, 2, 3, 4]) - expected_calls = map(lambda c: call(base_query.format(c)), - [0, 1, 2, 3, 4]) - assert ju_read_url_get.call_args_list == list(expected_calls) @@ -119,149 +110,144 @@ def test_all(safe_read_url_get_msg5, rma): @pageable() def get_genes(**kwargs): - return rma.model_query(model='Gene', **kwargs) + return rma.model_query(model="Gene", **kwargs) nr = 1 - df = list(get_genes(num_rows=nr, total_rows='all')) + df = list(get_genes(num_rows=nr, total_rows="all")) - assert df == [{'whatever': True}, - {'whatever': True}, - {'whatever': True}, - {'whatever': True}, - {'whatever': True}] + assert df == [ + {"whatever": True}, + {"whatever": True}, + {"whatever": True}, + {"whatever": True}, + {"whatever": True}, + ] - base_query = \ - ('http://api.brain-map.org/api/v2/data/query.json?q=model::Gene' - ',rma::options%5Bnum_rows$eq1%5D%5Bstart_row$eq{}%5D' - '%5Bcount$eqfalse%5D') + base_query = ( + "http://api.brain-map.org/api/v2/data/query.json?q=model::Gene" + ",rma::options%5Bnum_rows$eq1%5D%5Bstart_row$eq{}%5D" + "%5Bcount$eqfalse%5D" + ) # we get one extra call if total_rows % num_rows == 0 with current implementation - expected_calls = map(lambda c: call(base_query.format(c)), - [0, 1, 2, 3, 4, 5]) - + expected_calls = map(lambda c: call(base_query.format(c)), [0, 1, 2, 3, 4, 5]) + assert ju_read_url_get.call_args_list == list(expected_calls) -@pytest.mark.parametrize("cache_style", - (Cache.cache_csv, - Cache.cache_csv_json, - Cache.cache_csv_dataframe)) +@pytest.mark.parametrize("cache_style", (Cache.cache_csv, Cache.cache_csv_json, Cache.cache_csv_dataframe)) @patch("pandas.read_csv", return_value=_csv_msg) @patch("os.makedirs") -def test_cacheable_pageable_csv(os_makedirs, read_csv, - cache_style, safe_read_url_get_msg5): +def test_cacheable_pageable_csv(os_makedirs, read_csv, cache_style, safe_read_url_get_msg5): with patch("allensdk.core.json_utilities.read_url_get", side_effect=safe_read_url_get_msg5) as ju_read_url_get: - archive_templates = \ - {"cam_cell_queries": [ - {'name': 'cam_cell_metric', - 'description': 'see name', - 'model': 'ApiCamCellMetric', - 'num_rows': 1000, - 'count': False - } ] } + archive_templates = { + "cam_cell_queries": [ + { + "name": "cam_cell_metric", + "description": "see name", + "model": "ApiCamCellMetric", + "num_rows": 1000, + "count": False, + } + ] + } rmat = RmaTemplate(query_manifest=archive_templates) @cacheable() @pageable(num_rows=2000) - def get_cam_cell_metrics(*args, - **kwargs): - return rmat.template_query("cam_cell_queries", - 'cam_cell_metric', - *args, - **kwargs) - - with patch(builtins.__name__ + '.open', - mock_open(), - create=True): - with patch('csv.DictWriter.writerow') as csv_writerow: - get_cam_cell_metrics(strategy='create', - path='/path/to/cam_cell_metrics.csv', - num_rows=1, - total_rows='all', - **cache_style()) - - os_makedirs.assert_called_with('/path/to') - - base_query = ('http://api.brain-map.org/api/v2/data/query.json?' - 'q=model::ApiCamCellMetric,' - 'rma::options%5Bnum_rows$eq1%5D%5Bstart_row$eq{}%5D' - '%5Bcount$eqfalse%5D') - - expected_calls = map(lambda c: call(base_query.format(c)), - [0, 1, 2, 3, 4, 5]) + def get_cam_cell_metrics(*args, **kwargs): + return rmat.template_query("cam_cell_queries", "cam_cell_metric", *args, **kwargs) + + with patch(builtins.__name__ + ".open", mock_open(), create=True): + with patch("csv.DictWriter.writerow") as csv_writerow: + get_cam_cell_metrics( + strategy="create", + path="/path/to/cam_cell_metrics.csv", + num_rows=1, + total_rows="all", + **cache_style(), + ) + + os_makedirs.assert_called_with("/path/to") + + base_query = ( + "http://api.brain-map.org/api/v2/data/query.json?" + "q=model::ApiCamCellMetric," + "rma::options%5Bnum_rows$eq1%5D%5Bstart_row$eq{}%5D" + "%5Bcount$eqfalse%5D" + ) + + expected_calls = map(lambda c: call(base_query.format(c)), [0, 1, 2, 3, 4, 5]) assert ju_read_url_get.call_args_list == list(expected_calls) - read_csv.assert_called_once_with('/path/to/cam_cell_metrics.csv', parse_dates=True) + read_csv.assert_called_once_with("/path/to/cam_cell_metrics.csv", parse_dates=True) - assert csv_writerow.call_args_list == [call({'whatever': 'whatever'}), - call({'whatever': True}), - call({'whatever': True}), - call({'whatever': True}), - call({'whatever': True}), - call({'whatever': True})] + assert csv_writerow.call_args_list == [ + call({"whatever": "whatever"}), + call({"whatever": True}), + call({"whatever": True}), + call({"whatever": True}), + call({"whatever": True}), + call({"whatever": True}), + ] -@pytest.mark.parametrize("cache_style", - (Cache.cache_json, - Cache.cache_json_dataframe)) +@pytest.mark.parametrize("cache_style", (Cache.cache_json, Cache.cache_json_dataframe)) @patch("allensdk.core.json_utilities.read", return_value=_read_msg5) @patch("pandas.io.json.read_json", return_value=_pj_msg5) @patch("os.makedirs") -def test_cacheable_pageable_json(os_makedirs, pj_read_json, - ju_read, cache_style, safe_read_url_get_msg5): +def test_cacheable_pageable_json(os_makedirs, pj_read_json, ju_read, cache_style, safe_read_url_get_msg5): with patch("allensdk.core.json_utilities.read_url_get", side_effect=safe_read_url_get_msg5) as ju_read_url_get: - - archive_templates = \ - {"cam_cell_queries": [ - {'name': 'cam_cell_metric', - 'description': 'see name', - 'model': 'ApiCamCellMetric', - 'num_rows': 1000, - 'count': False - } ] } + archive_templates = { + "cam_cell_queries": [ + { + "name": "cam_cell_metric", + "description": "see name", + "model": "ApiCamCellMetric", + "num_rows": 1000, + "count": False, + } + ] + } rmat = RmaTemplate(query_manifest=archive_templates) @cacheable() @pageable(num_rows=2000) - def get_cam_cell_metrics(*args, - **kwargs): - return rmat.template_query("cam_cell_queries", - 'cam_cell_metric', - *args, - **kwargs) - - with patch(builtins.__name__ + '.open', - mock_open(), - create=True) as open_mock: - open_mock.return_value.read = \ - MagicMock(name='read', - return_value=[{'whatever': True}, - {'whatever': True}, - {'whatever': True}, - {'whatever': True}, - {'whatever': True}]) - cam_cell_metrics = \ - get_cam_cell_metrics(strategy='create', - path='/path/to/cam_cell_metrics.json', - num_rows=1, - total_rows='all', - **cache_style()) - - os_makedirs.assert_called_with('/path/to') - - base_query = \ - ('http://api.brain-map.org/api/v2/data/query.json?' - 'q=model::ApiCamCellMetric,' - 'rma::options%5Bnum_rows$eq1%5D%5Bstart_row$eq{}%5D' - '%5Bcount$eqfalse%5D') - - expected_calls = map(lambda c: call(base_query.format(c)), - [0, 1, 2, 3, 4, 5]) - - open_mock.assert_called_once_with('/path/to/cam_cell_metrics.json', 'wb') - open_mock.return_value.write.assert_called_once_with('[\n {\n "whatever": true\n },\n {\n "whatever": true\n },\n {\n "whatever": true\n },\n {\n "whatever": true\n },\n {\n "whatever": true\n }\n]') + def get_cam_cell_metrics(*args, **kwargs): + return rmat.template_query("cam_cell_queries", "cam_cell_metric", *args, **kwargs) + + with patch(builtins.__name__ + ".open", mock_open(), create=True) as open_mock: + open_mock.return_value.read = MagicMock( + name="read", + return_value=[ + {"whatever": True}, + {"whatever": True}, + {"whatever": True}, + {"whatever": True}, + {"whatever": True}, + ], + ) + cam_cell_metrics = get_cam_cell_metrics( + strategy="create", path="/path/to/cam_cell_metrics.json", num_rows=1, total_rows="all", **cache_style() + ) + + os_makedirs.assert_called_with("/path/to") + + base_query = ( + "http://api.brain-map.org/api/v2/data/query.json?" + "q=model::ApiCamCellMetric," + "rma::options%5Bnum_rows$eq1%5D%5Bstart_row$eq{}%5D" + "%5Bcount$eqfalse%5D" + ) + + expected_calls = map(lambda c: call(base_query.format(c)), [0, 1, 2, 3, 4, 5]) + + open_mock.assert_called_once_with("/path/to/cam_cell_metrics.json", "wb") + open_mock.return_value.write.assert_called_once_with( + '[\n {\n "whatever": true\n },\n {\n "whatever": true\n },\n {\n "whatever": true\n },\n {\n "whatever": true\n },\n {\n "whatever": true\n }\n]' + ) assert ju_read_url_get.call_args_list == list(expected_calls) assert len(cam_cell_metrics) == 5 diff --git a/allensdk/test/api/test_reference_space_api.py b/allensdk/test/api/test_reference_space_api.py index 3f17948113..61634079c1 100644 --- a/allensdk/test/api/test_reference_space_api.py +++ b/allensdk/test/api/test_reference_space_api.py @@ -42,179 +42,143 @@ @pytest.fixture def ref_space(): rsa = RSA() - + return rsa @pytest.fixture def mock_nrrd(): mocked_nrrd = MagicMock() - mocked_nrrd.read = MagicMock(return_value=('mock_annotation_data', - 'mock_annotation_image')) + mocked_nrrd.read = MagicMock(return_value=("mock_annotation_data", "mock_annotation_image")) return mocked_nrrd def CCF_VERSIONS(): - return [RSA.CCF_2015, - RSA.CCF_2016, - RSA.CCF_2017] + return [RSA.CCF_2015, RSA.CCF_2016, RSA.CCF_2017] -def DATA_PATHS(): - return [RSA.AVERAGE_TEMPLATE, - RSA.ARA_NISSL, - RSA.MOUSE_2011, - RSA.DEVMOUSE_2012, - RSA.CCF_2015, - RSA.CCF_2016, - RSA.CCF_2017] +def DATA_PATHS(): + return [ + RSA.AVERAGE_TEMPLATE, + RSA.ARA_NISSL, + RSA.MOUSE_2011, + RSA.DEVMOUSE_2012, + RSA.CCF_2015, + RSA.CCF_2016, + RSA.CCF_2017, + ] def RESOLUTIONS(): - return [RSA.VOXEL_RESOLUTION_10_MICRONS, - RSA.VOXEL_RESOLUTION_25_MICRONS, - RSA.VOXEL_RESOLUTION_50_MICRONS, - RSA.VOXEL_RESOLUTION_100_MICRONS] + return [ + RSA.VOXEL_RESOLUTION_10_MICRONS, + RSA.VOXEL_RESOLUTION_25_MICRONS, + RSA.VOXEL_RESOLUTION_50_MICRONS, + RSA.VOXEL_RESOLUTION_100_MICRONS, + ] -MOCK_ANNOTATION_DATA = 'mock_annotation_data' -MOCK_ANNOTATION_IMAGE = 'mock_annotation_image' +MOCK_ANNOTATION_DATA = "mock_annotation_data" +MOCK_ANNOTATION_IMAGE = "mock_annotation_image" def test_download_mouse_atlas_volume(ref_space): - - with patch.object(ref_space, 'retrieve_file_over_http') as mock_retrieve: + with patch.object(ref_space, "retrieve_file_over_http") as mock_retrieve: with pytest.raises(RuntimeError): - ref_space.download_mouse_atlas_volume('P56', 'Mouse_gridAnnotation', 'P56/gridAnnotation.mhd') + ref_space.download_mouse_atlas_volume("P56", "Mouse_gridAnnotation", "P56/gridAnnotation.mhd") mock_retrieve.assert_called_once_with( - 'http://download.alleninstitute.org/informatics-archive/'\ - 'current-release/mouse_annotation/'\ - 'P56_Mouse_gridAnnotation.zip', - 'P56/gridAnnotation.mhd', - zipped=True) - - -@pytest.mark.parametrize("data_path,resolution", - it.product(DATA_PATHS(), - RESOLUTIONS())) -def test_download_volumetric_data(ref_space, - data_path, - resolution): + "http://download.alleninstitute.org/informatics-archive/" + "current-release/mouse_annotation/" + "P56_Mouse_gridAnnotation.zip", + "P56/gridAnnotation.mhd", + zipped=True, + ) + + +@pytest.mark.parametrize("data_path,resolution", it.product(DATA_PATHS(), RESOLUTIONS())) +def test_download_volumetric_data(ref_space, data_path, resolution): cache_filename = "annotation_%d.nrrd" % (resolution) with patch.object(ref_space, "retrieve_file_over_http") as mock_retrieve: - ref_space.download_volumetric_data(data_path, - cache_filename, - resolution) + ref_space.download_volumetric_data(data_path, cache_filename, resolution) mock_retrieve.assert_called_once_with( "http://download.alleninstitute.org/informatics-archive/" - "current-release/mouse_ccf/%s/annotation_%d.nrrd" % - (data_path, - resolution), - cache_filename) - - -@pytest.mark.parametrize("resolution", - RESOLUTIONS()) -@patch("nrrd.read", return_value=('mock_annotation_data', - 'mock_annotation_image')) -@patch('os.makedirs') -def test_download_structure_mask(os_makedirs, - nrrd_read, - ref_space, - resolution): + "current-release/mouse_ccf/%s/annotation_%d.nrrd" % (data_path, resolution), + cache_filename, + ) + +@pytest.mark.parametrize("resolution", RESOLUTIONS()) +@patch("nrrd.read", return_value=("mock_annotation_data", "mock_annotation_image")) +@patch("os.makedirs") +def test_download_structure_mask(os_makedirs, nrrd_read, ref_space, resolution): structure_id = 12 with patch.object(ref_space, "retrieve_file_over_http") as mock_retrieve: - a, b = ref_space.download_structure_mask(structure_id, - None, resolution, - '/path/to/foo.nrrd', - reader=nrrd_read) + a, b = ref_space.download_structure_mask(structure_id, None, resolution, "/path/to/foo.nrrd", reader=nrrd_read) assert a assert b - expected = 'http://download.alleninstitute.org/informatics-archive/'\ - 'current-release/mouse_ccf/{0}/structure_masks/'\ - 'structure_masks_{1}/structure_{2}.nrrd'.format(RSA.CCF_VERSION_DEFAULT, - resolution, - structure_id) - mock_retrieve.assert_called_once_with(expected, '/path/to/foo.nrrd') - os_makedirs.assert_any_call('/path/to') - + expected = ( + "http://download.alleninstitute.org/informatics-archive/" + "current-release/mouse_ccf/{0}/structure_masks/" + "structure_masks_{1}/structure_{2}.nrrd".format(RSA.CCF_VERSION_DEFAULT, resolution, structure_id) + ) + mock_retrieve.assert_called_once_with(expected, "/path/to/foo.nrrd") + os_makedirs.assert_any_call("/path/to") -@patch('allensdk.core.obj_utilities.read_obj', return_value=('mock_obj')) -@patch('os.makedirs') -def test_download_structure_mesh(os_makedirs, - read_obj, - ref_space): +@patch("allensdk.core.obj_utilities.read_obj", return_value=("mock_obj")) +@patch("os.makedirs") +def test_download_structure_mesh(os_makedirs, read_obj, ref_space): structure_id = 12 with patch.object(ref_space, "retrieve_file_over_http") as mock_retrieve: - a = ref_space.download_structure_mesh(structure_id, - None, '/path/to/foo.obj', - reader=read_obj) - - assert a == 'mock_obj' - - expected = 'http://download.alleninstitute.org/informatics-archive/'\ - 'current-release/mouse_ccf/{0}/structure_meshes/'\ - '{1}.obj'.format(RSA.CCF_VERSION_DEFAULT, structure_id) - - mock_retrieve.assert_called_once_with(expected, '/path/to/foo.obj') - os_makedirs.assert_any_call('/path/to') - - -@pytest.mark.parametrize("ccf_version,resolution", - it.product(CCF_VERSIONS(), - RESOLUTIONS())) -@patch("nrrd.read", return_value=('mock_annotation_data', - 'mock_annotation_image')) -@patch('os.makedirs') -def test_download_annotation_volume(os_makedirs, - nrrd_read, - ref_space, - ccf_version, - resolution): - cache_file = '/path/to/annotation_%d.nrrd' % (resolution) + a = ref_space.download_structure_mesh(structure_id, None, "/path/to/foo.obj", reader=read_obj) + + assert a == "mock_obj" + + expected = ( + "http://download.alleninstitute.org/informatics-archive/" + "current-release/mouse_ccf/{0}/structure_meshes/" + "{1}.obj".format(RSA.CCF_VERSION_DEFAULT, structure_id) + ) + + mock_retrieve.assert_called_once_with(expected, "/path/to/foo.obj") + os_makedirs.assert_any_call("/path/to") + + +@pytest.mark.parametrize("ccf_version,resolution", it.product(CCF_VERSIONS(), RESOLUTIONS())) +@patch("nrrd.read", return_value=("mock_annotation_data", "mock_annotation_image")) +@patch("os.makedirs") +def test_download_annotation_volume(os_makedirs, nrrd_read, ref_space, ccf_version, resolution): + cache_file = "/path/to/annotation_%d.nrrd" % (resolution) with patch.object(ref_space, "retrieve_file_over_http") as mock_retrieve: - ref_space.download_annotation_volume( - ccf_version, - resolution, - cache_file, - reader=nrrd_read) + ref_space.download_annotation_volume(ccf_version, resolution, cache_file, reader=nrrd_read) nrrd_read.assert_called_once_with(cache_file) mock_retrieve.assert_called_once_with( "http://download.alleninstitute.org/informatics-archive/" - "current-release/mouse_ccf/%s/annotation_%d.nrrd" % - (ccf_version, resolution), - "/path/to/annotation_%d.nrrd" % (resolution)) - - os_makedirs.assert_any_call('/path/to') - - -@pytest.mark.parametrize("resolution", - RESOLUTIONS()) -@patch("nrrd.read", return_value=('mock_annotation_data', - 'mock_annotation_image')) -@patch('os.makedirs') -def test_download_annotation_volume_default(os_makedirs, - nrrd_read, - ref_space, - resolution): + "current-release/mouse_ccf/%s/annotation_%d.nrrd" % (ccf_version, resolution), + "/path/to/annotation_%d.nrrd" % (resolution), + ) + + os_makedirs.assert_any_call("/path/to") + + +@pytest.mark.parametrize("resolution", RESOLUTIONS()) +@patch("nrrd.read", return_value=("mock_annotation_data", "mock_annotation_image")) +@patch("os.makedirs") +def test_download_annotation_volume_default(os_makedirs, nrrd_read, ref_space, resolution): with patch.object(ref_space, "retrieve_file_over_http") as mock_retrieve: a, b = ref_space.download_annotation_volume( - None, - resolution, - '/path/to/annotation_%d.nrrd' % (resolution), - reader=nrrd_read) + None, resolution, "/path/to/annotation_%d.nrrd" % (resolution), reader=nrrd_read + ) assert a assert b @@ -223,32 +187,26 @@ def test_download_annotation_volume_default(os_makedirs, mock_retrieve.assert_called_once_with( "http://download.alleninstitute.org/informatics-archive/" - "current-release/mouse_ccf/%s/annotation_%d.nrrd" % - (RSA.CCF_VERSION_DEFAULT, resolution), - "/path/to/annotation_%d.nrrd" % (resolution)) - - os_makedirs.assert_any_call('/path/to') - - -@pytest.mark.parametrize("resolution", - RESOLUTIONS()) -@patch("nrrd.read", return_value=('mock_annotation_data', - 'mock_annotation_image')) -@patch('os.makedirs') -def test_download_template_volume(os_makedirs, - nrrd_read, - ref_space, - resolution): + "current-release/mouse_ccf/%s/annotation_%d.nrrd" % (RSA.CCF_VERSION_DEFAULT, resolution), + "/path/to/annotation_%d.nrrd" % (resolution), + ) + + os_makedirs.assert_any_call("/path/to") + + +@pytest.mark.parametrize("resolution", RESOLUTIONS()) +@patch("nrrd.read", return_value=("mock_annotation_data", "mock_annotation_image")) +@patch("os.makedirs") +def test_download_template_volume(os_makedirs, nrrd_read, ref_space, resolution): with patch.object(ref_space, "retrieve_file_over_http") as mock_retrieve: ref_space.download_template_volume( - resolution, - '/path/to/average_template_%d.nrrd' % (resolution), - reader=nrrd_read) + resolution, "/path/to/average_template_%d.nrrd" % (resolution), reader=nrrd_read + ) mock_retrieve.assert_called_once_with( "http://download.alleninstitute.org/informatics-archive/" - "current-release/mouse_ccf/average_template/average_template_%d.nrrd" % - (resolution), - "/path/to/average_template_%d.nrrd" % (resolution)) + "current-release/mouse_ccf/average_template/average_template_%d.nrrd" % (resolution), + "/path/to/average_template_%d.nrrd" % (resolution), + ) - os_makedirs.assert_any_call('/path/to') + os_makedirs.assert_any_call("/path/to") diff --git a/allensdk/test/api/test_rma_template.py b/allensdk/test/api/test_rma_template.py index 6c05e47644..8a236ccc98 100644 --- a/allensdk/test/api/test_rma_template.py +++ b/allensdk/test/api/test_rma_template.py @@ -38,94 +38,100 @@ from allensdk.api.queries.rma_template import RmaTemplate -_msg = {'msg': [{'whatever': True}]} +_msg = {"msg": [{"whatever": True}]} @pytest.fixture def rma(): - templates = \ - {"ontology_queries": [ - {'name': 'structures_by_graph_ids', - 'description': 'see name', - 'model': 'Structure', - 'criteria': '[graph_id$in{{ graph_ids }}]', - 'order': ['structures.graph_order'], - 'num_rows': 'all', - 'count': False, - 'criteria_params': ['graph_ids'] - }, - {'name': 'structures_by_graph_names', - 'description': 'see name', - 'model': 'Structure', - 'criteria': 'graph[structure_graphs.name$in{{ graph_names }}]', - 'order': ['structures.graph_order'], - 'num_rows': 'all', - 'count': False, - 'criteria_params': ['graph_names'] - }, - {'name': 'structures_by_set_ids', - 'description': 'see name', - 'model': 'Structure', - 'criteria': '[structure_set_id$in{{ set_ids }}]', - 'order': ['structures.graph_order'], - 'num_rows': 'all', - 'count': False, - 'criteria_params': ['set_ids'] - }, - {'name': 'structures_by_set_names', - 'description': 'see name', - 'model': 'Structure', - 'criteria': 'structure_sets[name$in{{ set_names }}]', - 'order': ['structures.graph_order'], - 'num_rows': 'all', - 'count': False, - 'criteria_params': ['set_names'] - }, - {'name': 'structure_graphs_list', - 'description': 'see name', - 'model': 'StructureGraph', - 'num_rows': 'all', - 'count': False - }, - {'name': 'structure_sets_list', - 'description': 'see name', - 'model': 'StructureSet', - 'num_rows': 'all', - 'count': False - }, - {'name': 'atlases_list', - 'description': 'see name', - 'model': 'Atlas', - 'num_rows': 'all', - 'count': False - }, - {'name': 'atlases_table', - 'description': 'see name', - 'model': 'Atlas', - 'criteria': '{% if graph_ids is defined %}[graph_id$in{{ graph_ids }}],{% endif %}structure_graph(ontology),graphic_group_labels', - 'include': '[structure_graph(ontology),graphic_group_labels', - 'num_rows': 'all', - 'count': False, - 'criteria_params': ['graph_ids'] - }, - {'name': 'atlases_table_brief', - 'description': 'see name', - 'model': 'Atlas', - 'criteria': 'structure_graph(ontology),graphic_group_labels', - 'include': 'structure_graph(ontology),graphic_group_labels', - 'only': ['atlases.id', - 'atlases.name', - 'atlases.image_type', - 'ontologies.id', - 'ontologies.name', - 'structure_graphs.id', - 'structure_graphs.name', - 'graphic_group_labels.id', - 'graphic_group_labels.name'], - 'num_rows': 'all', - 'count': False - } - ]} + templates = { + "ontology_queries": [ + { + "name": "structures_by_graph_ids", + "description": "see name", + "model": "Structure", + "criteria": "[graph_id$in{{ graph_ids }}]", + "order": ["structures.graph_order"], + "num_rows": "all", + "count": False, + "criteria_params": ["graph_ids"], + }, + { + "name": "structures_by_graph_names", + "description": "see name", + "model": "Structure", + "criteria": "graph[structure_graphs.name$in{{ graph_names }}]", + "order": ["structures.graph_order"], + "num_rows": "all", + "count": False, + "criteria_params": ["graph_names"], + }, + { + "name": "structures_by_set_ids", + "description": "see name", + "model": "Structure", + "criteria": "[structure_set_id$in{{ set_ids }}]", + "order": ["structures.graph_order"], + "num_rows": "all", + "count": False, + "criteria_params": ["set_ids"], + }, + { + "name": "structures_by_set_names", + "description": "see name", + "model": "Structure", + "criteria": "structure_sets[name$in{{ set_names }}]", + "order": ["structures.graph_order"], + "num_rows": "all", + "count": False, + "criteria_params": ["set_names"], + }, + { + "name": "structure_graphs_list", + "description": "see name", + "model": "StructureGraph", + "num_rows": "all", + "count": False, + }, + { + "name": "structure_sets_list", + "description": "see name", + "model": "StructureSet", + "num_rows": "all", + "count": False, + }, + {"name": "atlases_list", "description": "see name", "model": "Atlas", "num_rows": "all", "count": False}, + { + "name": "atlases_table", + "description": "see name", + "model": "Atlas", + "criteria": "{% if graph_ids is defined %}[graph_id$in{{ graph_ids }}],{% endif %}structure_graph(ontology),graphic_group_labels", + "include": "[structure_graph(ontology),graphic_group_labels", + "num_rows": "all", + "count": False, + "criteria_params": ["graph_ids"], + }, + { + "name": "atlases_table_brief", + "description": "see name", + "model": "Atlas", + "criteria": "structure_graph(ontology),graphic_group_labels", + "include": "structure_graph(ontology),graphic_group_labels", + "only": [ + "atlases.id", + "atlases.name", + "atlases.image_type", + "ontologies.id", + "ontologies.name", + "structure_graphs.id", + "structure_graphs.name", + "graphic_group_labels.id", + "graphic_group_labels.name", + ], + "num_rows": "all", + "count": False, + }, + ] + } rma = RmaTemplate(query_manifest=templates) return rma @@ -133,56 +139,53 @@ def rma(): @patch("allensdk.core.json_utilities.read_url_get", return_value=_msg) def test_atlases_list(ju_read_url_get, rma): - rma.template_query('ontology_queries', - 'atlases_list') + rma.template_query("ontology_queries", "atlases_list") ju_read_url_get.assert_called_once_with( "http://api.brain-map.org/api/v2/data/query.json?q=" "model::Atlas,rma::options" - "%5Bnum_rows$eq%27all%27%5D%5Bcount$eqfalse%5D") + "%5Bnum_rows$eq%27all%27%5D%5Bcount$eqfalse%5D" + ) @patch("allensdk.core.json_utilities.read_url_get", return_value=_msg) def test_structure_graphs_list(ju_read_url_get, rma): - rma.template_query('ontology_queries', - 'structure_graphs_list') + rma.template_query("ontology_queries", "structure_graphs_list") ju_read_url_get.assert_called_once_with( "http://api.brain-map.org/api/v2/data/query.json?q=" "model::StructureGraph,rma::options" - "%5Bnum_rows$eq%27all%27%5D%5Bcount$eqfalse%5D") + "%5Bnum_rows$eq%27all%27%5D%5Bcount$eqfalse%5D" + ) @patch("allensdk.core.json_utilities.read_url_get", return_value=_msg) def test_structure_sets_list(ju_read_url_get, rma): - rma.template_query('ontology_queries', - 'structure_sets_list') + rma.template_query("ontology_queries", "structure_sets_list") ju_read_url_get.assert_called_once_with( "http://api.brain-map.org/api/v2/data/query.json?q=" "model::StructureSet,rma::options" - "%5Bnum_rows$eq%27all%27%5D%5Bcount$eqfalse%5D") + "%5Bnum_rows$eq%27all%27%5D%5Bcount$eqfalse%5D" + ) @patch("allensdk.core.json_utilities.read_url_get", return_value=_msg) def test_structures_by_graph_ids(ju_read_url_get, rma): - rma.template_query('ontology_queries', - 'structures_by_graph_ids', - graph_ids='1') + rma.template_query("ontology_queries", "structures_by_graph_ids", graph_ids="1") ju_read_url_get.assert_called_once_with( "http://api.brain-map.org/api/v2/data/query.json?q=" "model::Structure,rma::criteria," "%5Bgraph_id$in1%5D,rma::options" "%5Bnum_rows$eq%27all%27%5D%5Border$eqstructures.graph_order%5D" - "%5Bcount$eqfalse%5D") + "%5Bcount$eqfalse%5D" + ) @patch("allensdk.core.json_utilities.read_url_get", return_value=_msg) def test_structures_by_two_graph_ids(ju_read_url_get, rma): - rma.template_query('ontology_queries', - 'structures_by_graph_ids', - graph_ids=[1, 2]) + rma.template_query("ontology_queries", "structures_by_graph_ids", graph_ids=[1, 2]) ju_read_url_get.assert_called_once_with( "http://api.brain-map.org/api/v2/data/query.json?q=" @@ -190,14 +193,15 @@ def test_structures_by_two_graph_ids(ju_read_url_get, rma): "%5Bgraph_id$in1,2%5D," "rma::options" "%5Bnum_rows$eq%27all%27%5D" - "%5Border$eqstructures.graph_order%5D%5Bcount$eqfalse%5D") + "%5Border$eqstructures.graph_order%5D%5Bcount$eqfalse%5D" + ) @patch("allensdk.core.json_utilities.read_url_get", return_value=_msg) def test_structures_by_graph_names(ju_read_url_get, rma): - rma.template_query('ontology_queries', - 'structures_by_graph_names', - graph_names=rma.quote_string('Human+Brain+Atlas')) + rma.template_query( + "ontology_queries", "structures_by_graph_names", graph_names=rma.quote_string("Human+Brain+Atlas") + ) ju_read_url_get.assert_called_once_with( "http://api.brain-map.org/api/v2/data/query.json?q=" @@ -205,53 +209,51 @@ def test_structures_by_graph_names(ju_read_url_get, rma): "graph%5Bstructure_graphs.name$in%27Human+Brain+Atlas%27%5D," "rma::options" "%5Bnum_rows$eq%27all%27%5D" - "%5Border$eqstructures.graph_order%5D%5Bcount$eqfalse%5D") + "%5Border$eqstructures.graph_order%5D%5Bcount$eqfalse%5D" + ) @patch("allensdk.core.json_utilities.read_url_get", return_value=_msg) def test_structures_by_set_ids(ju_read_url_get, rma): - rma.template_query('ontology_queries', - 'structures_by_graph_ids', - graph_ids='1') + rma.template_query("ontology_queries", "structures_by_graph_ids", graph_ids="1") ju_read_url_get.assert_called_once_with( "http://api.brain-map.org/api/v2/data/query.json?q=" "model::Structure,rma::criteria," "%5Bgraph_id$in1%5D,rma::options%5Bnum_rows$eq%27all%27%5D" - "%5Border$eqstructures.graph_order%5D%5Bcount$eqfalse%5D") + "%5Border$eqstructures.graph_order%5D%5Bcount$eqfalse%5D" + ) @patch("allensdk.core.json_utilities.read_url_get", return_value=_msg) def test_atlases_table(ju_read_url_get, rma): - rma.template_query('ontology_queries', - 'atlases_table') + rma.template_query("ontology_queries", "atlases_table") ju_read_url_get.assert_called_once_with( "http://api.brain-map.org/api/v2/data/query.json?q=" "model::Atlas,rma::criteria," "structure_graph%28ontology%29,graphic_group_labels," "rma::include,%5Bstructure_graph%28ontology%29,graphic_group_labels," - "rma::options%5Bnum_rows$eq%27all%27%5D%5Bcount$eqfalse%5D") + "rma::options%5Bnum_rows$eq%27all%27%5D%5Bcount$eqfalse%5D" + ) @patch("allensdk.core.json_utilities.read_url_get", return_value=_msg) def test_atlases_table_one_graph(ju_read_url_get, rma): - rma.template_query('ontology_queries', - 'atlases_table', - graph_ids=1) + rma.template_query("ontology_queries", "atlases_table", graph_ids=1) ju_read_url_get.assert_called_once_with( "http://api.brain-map.org/api/v2/data/query.json?q=" "model::Atlas,rma::criteria," "%5Bgraph_id$in1%5D,structure_graph%28ontology%29,graphic_group_labels," "rma::include,%5Bstructure_graph%28ontology%29,graphic_group_labels," - "rma::options%5Bnum_rows$eq%27all%27%5D%5Bcount$eqfalse%5D") + "rma::options%5Bnum_rows$eq%27all%27%5D%5Bcount$eqfalse%5D" + ) @patch("allensdk.core.json_utilities.read_url_get", return_value=_msg) def test_atlases_table_brief(ju_read_url_get, rma): - rma.template_query('ontology_queries', - 'atlases_table_brief') + rma.template_query("ontology_queries", "atlases_table_brief") ju_read_url_get.assert_called_once_with( "http://api.brain-map.org/api/v2/data/query.json?q=" @@ -261,4 +263,5 @@ def test_atlases_table_brief(ju_read_url_get, rma): "rma::options%5Bonly$eq%27atlases.id,atlases.name,atlases.image_type," "ontologies.id,ontologies.name,structure_graphs.id,structure_graphs.name," "graphic_group_labels.id,graphic_group_labels.name%27%5D%5B" - "num_rows$eq%27all%27%5D%5Bcount$eqfalse%5D") + "num_rows$eq%27all%27%5D%5Bcount$eqfalse%5D" + ) diff --git a/allensdk/test/api/test_svg_api.py b/allensdk/test/api/test_svg_api.py index 1fee1b37e8..e8bb1ffa92 100644 --- a/allensdk/test/api/test_svg_api.py +++ b/allensdk/test/api/test_svg_api.py @@ -33,64 +33,67 @@ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. # -#test AllenSDK svg api for download and show +# test AllenSDK svg api for download and show from allensdk.api.queries.svg_api import SvgApi import pytest from unittest.mock import MagicMock + @pytest.fixture def svg(): sa = SvgApi() return sa - + + def test_build_query(svg): ####download true url download = True groups = None section_image_id = 21889 - returned_url = svg.build_query(section_image_id, groups,download) - assert (returned_url == "http://api.brain-map.org/api/v2/svg_download/21889") - + returned_url = svg.build_query(section_image_id, groups, download) + assert returned_url == "http://api.brain-map.org/api/v2/svg_download/21889" + ####download true with one group url download = True groups = [1] section_image_id = 21889 - returned_url = svg.build_query(section_image_id, groups,download) - assert (returned_url == "http://api.brain-map.org/api/v2/svg_download/21889?groups=1") - + returned_url = svg.build_query(section_image_id, groups, download) + assert returned_url == "http://api.brain-map.org/api/v2/svg_download/21889?groups=1" + ####download true with groups url download = True - groups = [1,2] + groups = [1, 2] section_image_id = 21889 - returned_url = svg.build_query(section_image_id, groups,download) - assert (returned_url == "http://api.brain-map.org/api/v2/svg_download/21889?groups=1,2") - + returned_url = svg.build_query(section_image_id, groups, download) + assert returned_url == "http://api.brain-map.org/api/v2/svg_download/21889?groups=1,2" + ####download false url download = False groups = None section_image_id = 21889 - returned_url = svg.build_query(section_image_id, groups,download) - assert (returned_url == "http://api.brain-map.org/api/v2/svg/21889") - + returned_url = svg.build_query(section_image_id, groups, download) + assert returned_url == "http://api.brain-map.org/api/v2/svg/21889" + ####download false groups exist url download = False groups = [28] section_image_id = 21889 - returned_url = svg.build_query(section_image_id, groups,download) - assert (returned_url == "http://api.brain-map.org/api/v2/svg/21889?groups=28") + returned_url = svg.build_query(section_image_id, groups, download) + assert returned_url == "http://api.brain-map.org/api/v2/svg/21889?groups=28" + def test_download_svg(svg): - svg.retrieve_file_over_http = MagicMock(name='retrieve_file_over_http') + svg.retrieve_file_over_http = MagicMock(name="retrieve_file_over_http") section_image_id = 21889 groups = None file_path = None - - svg.download_svg(section_image_id,groups,file_path) - svg.retrieve_file_over_http.assert_called_with('http://api.brain-map.org/api/v2/svg_download/21889', '21889.svg') - - + + svg.download_svg(section_image_id, groups, file_path) + svg.retrieve_file_over_http.assert_called_with("http://api.brain-map.org/api/v2/svg_download/21889", "21889.svg") + + def test_get_svg(svg): - svg.retrieve_xml_over_http = MagicMock(name='retrieve_xml_over_http') + svg.retrieve_xml_over_http = MagicMock(name="retrieve_xml_over_http") ####groups None section_image_id = 100960033 @@ -98,7 +101,7 @@ def test_get_svg(svg): svg.get_svg(section_image_id, groups) svg.retrieve_xml_over_http.assert_called_with("http://api.brain-map.org/api/v2/svg/100960033") - + ####groups in 28 section_image_id = 100960033 groups = [28] diff --git a/allensdk/test/api/test_synchronization_api.py b/allensdk/test/api/test_synchronization_api.py index c9bf3da2e1..11d33bd6cd 100644 --- a/allensdk/test/api/test_synchronization_api.py +++ b/allensdk/test/api/test_synchronization_api.py @@ -41,28 +41,26 @@ @pytest.fixture def synch(): sa = SynchronizationApi() - sa.json_msg_query = MagicMock(name='json_msg_query') + sa.json_msg_query = MagicMock(name="json_msg_query") return sa def test_image_to_image(synch): - ''' + """ Notes ----- Expected link is slightly modified for json and float serialization of zeros. See: `Image Alignment `_ , link labeled 'Sync a VISp and VISal experiment to a location in a SCs SectionDataSet'. - ''' + """ section_image_id = 114754496 (x, y) = (18232, 10704) section_data_set_ids = [113887162, 116903968] - _ = synch.get_image_to_image(section_image_id, - x, y, - section_data_set_ids) - expected = 'http://api.brain-map.org/api/v2/image_to_image/114754496.json?x=18232.000000&y=10704.000000§ion_data_set_ids=113887162,116903968' + _ = synch.get_image_to_image(section_image_id, x, y, section_data_set_ids) + expected = "http://api.brain-map.org/api/v2/image_to_image/114754496.json?x=18232.000000&y=10704.000000§ion_data_set_ids=113887162,116903968" synch.json_msg_query.assert_called_once_with(expected) @@ -71,10 +69,8 @@ def test_image_to_image_2d(synch): (x, y) = (6208, 2368) section_image_ids = [68173103, 68173105, 68173107] - _ = synch.get_image_to_image_2d(section_image_id, - x, y, - section_image_ids) - expected = 'http://api.brain-map.org/api/v2/image_to_image_2d/68173101.json?x=6208.000000&y=2368.000000§ion_image_ids=68173103,68173105,68173107' + _ = synch.get_image_to_image_2d(section_image_id, x, y, section_image_ids) + expected = "http://api.brain-map.org/api/v2/image_to_image_2d/68173101.json?x=6208.000000&y=2368.000000§ion_image_ids=68173103,68173105,68173107" synch.json_msg_query.assert_called_once_with(expected) @@ -83,10 +79,8 @@ def test_reference_to_image(synch): (x, y, z) = (6085, 3670, 4883) section_data_set_ids = [68545324, 67810540] - _ = synch.get_reference_to_image(reference_space_id, - x, y, z, - section_data_set_ids) - expected = 'http://api.brain-map.org/api/v2/reference_to_image/10.json?x=6085.000000&y=3670.000000&z=4883.000000§ion_data_set_ids=68545324,67810540' + _ = synch.get_reference_to_image(reference_space_id, x, y, z, section_data_set_ids) + expected = "http://api.brain-map.org/api/v2/reference_to_image/10.json?x=6085.000000&y=3670.000000&z=4883.000000§ion_data_set_ids=68545324,67810540" synch.json_msg_query.assert_called_once_with(expected) @@ -94,37 +88,32 @@ def test_image_to_reference(synch): section_image_id = 68173101 (x, y) = (6208, 2368) - _ = synch.get_image_to_reference(section_image_id, - x, y) - expected = 'http://api.brain-map.org/api/v2/image_to_reference/68173101.json?x=6208.000000&y=2368.000000' + _ = synch.get_image_to_reference(section_image_id, x, y) + expected = "http://api.brain-map.org/api/v2/image_to_reference/68173101.json?x=6208.000000&y=2368.000000" synch.json_msg_query.assert_called_once_with(expected) def test_structure_to_image(synch): section_data_set_id = 68545324 - structure_ids = [315, 698, 1089, 703, 477, - 803, 512, 549, 1097, 313, 771, 354] + structure_ids = [315, 698, 1089, 703, 477, 803, 512, 549, 1097, 313, 771, 354] - _ = synch.get_structure_to_image(section_data_set_id, - structure_ids) - expected = 'http://api.brain-map.org/api/v2/structure_to_image/68545324.json?structure_ids=315,698,1089,703,477,803,512,549,1097,313,771,354' + _ = synch.get_structure_to_image(section_data_set_id, structure_ids) + expected = "http://api.brain-map.org/api/v2/structure_to_image/68545324.json?structure_ids=315,698,1089,703,477,803,512,549,1097,313,771,354" synch.json_msg_query.assert_called_once_with(expected) def test_image_to_atlas(synch): - ''' + """ Notes ----- Expected link is slightly modified for json and float serialization of zeros. See: `Image Alignment `_ , link labeled 'Sync the P56 coronal reference atlas to a location in the SCs SectionDataSet'. - ''' + """ section_image_id = 114754496 (x, y) = (18232, 10704) atlas_id = 1 - _ = synch.get_image_to_atlas(section_image_id, - x, y, - atlas_id) - expected = 'http://api.brain-map.org/api/v2/image_to_atlas/114754496.json?x=18232.000000&y=10704.000000&atlas_id=1' + _ = synch.get_image_to_atlas(section_image_id, x, y, atlas_id) + expected = "http://api.brain-map.org/api/v2/image_to_atlas/114754496.json?x=18232.000000&y=10704.000000&atlas_id=1" synch.json_msg_query.assert_called_once_with(expected) diff --git a/allensdk/test/api/test_tree_search_api.py b/allensdk/test/api/test_tree_search_api.py index ebb37c7b57..79a474f05f 100644 --- a/allensdk/test/api/test_tree_search_api.py +++ b/allensdk/test/api/test_tree_search_api.py @@ -38,57 +38,70 @@ import pytest from unittest.mock import MagicMock + @pytest.fixture def tree_search(): tsa = TreeSearchApi() - tsa.json_msg_query = MagicMock(name='json_msg_query') + tsa.json_msg_query = MagicMock(name="json_msg_query") return tsa + def test_get_specimen_tree(tree_search): ####ancestor true for Specimen - kind = 'Specimen' + kind = "Specimen" db_id = 113817886 ancestors = True descendants = None tree_search.get_tree(kind, db_id, ancestors, descendants) - tree_search.json_msg_query.assert_called_with("http://api.brain-map.org/api/v2/tree_search/Specimen/113817886.json?ancestors=true") - + tree_search.json_msg_query.assert_called_with( + "http://api.brain-map.org/api/v2/tree_search/Specimen/113817886.json?ancestors=true" + ) + ####ancestor true for Specimen - kind = 'Specimen' + kind = "Specimen" db_id = 113817886 ancestors = True descendants = False tree_search.get_tree(kind, db_id, ancestors, descendants) - tree_search.json_msg_query.assert_called_with("http://api.brain-map.org/api/v2/tree_search/Specimen/113817886.json?ancestors=true&descendants=false") - + tree_search.json_msg_query.assert_called_with( + "http://api.brain-map.org/api/v2/tree_search/Specimen/113817886.json?ancestors=true&descendants=false" + ) + ####ancestor false for Specimen - kind = 'Specimen' + kind = "Specimen" db_id = 113817886 ancestors = False descendants = True tree_search.get_tree(kind, db_id, ancestors, descendants) - tree_search.json_msg_query.assert_called_with("http://api.brain-map.org/api/v2/tree_search/Specimen/113817886.json?ancestors=false&descendants=true") + tree_search.json_msg_query.assert_called_with( + "http://api.brain-map.org/api/v2/tree_search/Specimen/113817886.json?ancestors=false&descendants=true" + ) + def test_get_structure_tree(tree_search): ####ancestor True for Structure - kind = 'Structure' + kind = "Structure" db_id = 12547 ancestors = True descendants = True tree_search.get_tree(kind, db_id, ancestors, descendants) - tree_search.json_msg_query.assert_called_with("http://api.brain-map.org/api/v2/tree_search/Structure/12547.json?ancestors=true&descendants=true") - + tree_search.json_msg_query.assert_called_with( + "http://api.brain-map.org/api/v2/tree_search/Structure/12547.json?ancestors=true&descendants=true" + ) + ####ancestor False for Structure - kind = 'Structure' + kind = "Structure" db_id = 12547 ancestors = False descendants = True tree_search.get_tree(kind, db_id, ancestors, descendants) - tree_search.json_msg_query.assert_called_with("http://api.brain-map.org/api/v2/tree_search/Structure/12547.json?ancestors=false&descendants=true") - + tree_search.json_msg_query.assert_called_with( + "http://api.brain-map.org/api/v2/tree_search/Structure/12547.json?ancestors=false&descendants=true" + ) + ####ancestor None for Structure - kind = 'Structure' + kind = "Structure" db_id = 12547 ancestors = None descendants = None diff --git a/allensdk/test/brain_observatory/behavior/behavior_project_cache/conftest.py b/allensdk/test/brain_observatory/behavior/behavior_project_cache/conftest.py index 31c2e5b81f..2ebd1f5ded 100644 --- a/allensdk/test/brain_observatory/behavior/behavior_project_cache/conftest.py +++ b/allensdk/test/brain_observatory/behavior/behavior_project_cache/conftest.py @@ -3,17 +3,16 @@ import pandas as pd import pytest import semver -from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.data_io.behavior_neuropixels_project_cloud_api import ( # noqa +from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.data_io.behavior_neuropixels_project_cloud_api import ( # noqa VisualBehaviorNeuropixelsProjectCloudApi, ) -from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.data_io.behavior_project_cloud_api import ( # noqa +from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.data_io.behavior_project_cloud_api import ( # noqa BehaviorProjectCloudApi, ) @pytest.fixture def vbo_s3_cloud_cache_data(): - all_versions = {} all_versions["data"] = {} all_versions["metadata"] = {} @@ -226,7 +225,6 @@ def vbo_s3_cloud_cache_data(): @pytest.fixture def vbn_s3_cloud_cache_data(): - all_versions = {} all_versions["data"] = {} all_versions["metadata"] = {} diff --git a/allensdk/test/brain_observatory/behavior/behavior_project_cache/tables/test_ophys_mixin.py b/allensdk/test/brain_observatory/behavior/behavior_project_cache/tables/test_ophys_mixin.py index 13bb06300e..35bc07db0b 100644 --- a/allensdk/test/brain_observatory/behavior/behavior_project_cache/tables/test_ophys_mixin.py +++ b/allensdk/test/brain_observatory/behavior/behavior_project_cache/tables/test_ophys_mixin.py @@ -3,20 +3,15 @@ import numpy as np import pandas as pd -from allensdk.brain_observatory.behavior.behavior_project_cache.tables\ - .ophys_mixin import \ - OphysMixin +from allensdk.brain_observatory.behavior.behavior_project_cache.tables.ophys_mixin import OphysMixin class TestOhysMixin: def test__merge_column_values_nothing_to_merge(self): """when nothing to merge, df stays same""" - df = pd.DataFrame({ - 'date_of_acquisition': [1, 2], - 'foo': [3, 4] - }) + df = pd.DataFrame({"date_of_acquisition": [1, 2], "foo": [3, 4]}) - with patch.object(OphysMixin, '__init__') as ophys_mixin: + with patch.object(OphysMixin, "__init__") as ophys_mixin: ophys_mixin.return_value = None mixin = OphysMixin() mixin._df = df @@ -25,47 +20,43 @@ def test__merge_column_values_nothing_to_merge(self): def test__merge_column_values(self): """ophys value chosen""" - df = pd.DataFrame({ - 'date_of_acquisition_behavior': [1, 2], - 'date_of_acquisition_ophys': [3, 4], - 'session_type_behavior': ['foo', 'bar'], - 'session_type_ophys': ['foo', 'baz'], - 'foo': [1, 2] - }) + df = pd.DataFrame( + { + "date_of_acquisition_behavior": [1, 2], + "date_of_acquisition_ophys": [3, 4], + "session_type_behavior": ["foo", "bar"], + "session_type_ophys": ["foo", "baz"], + "foo": [1, 2], + } + ) - with patch.object(OphysMixin, '__init__') as ophys_mixin: + with patch.object(OphysMixin, "__init__") as ophys_mixin: ophys_mixin.return_value = None mixin = OphysMixin() mixin._df = df mixin._merge_columns() - expected = pd.DataFrame({ - 'date_of_acquisition': [3, 4], - 'session_type': ['foo', 'baz'], - 'foo': [1, 2] - }) - pd.testing.assert_frame_equal( - expected.sort_index(axis=1), - mixin._df.sort_index(axis=1) - ) + expected = pd.DataFrame({"date_of_acquisition": [3, 4], "session_type": ["foo", "baz"], "foo": [1, 2]}) + pd.testing.assert_frame_equal(expected.sort_index(axis=1), mixin._df.sort_index(axis=1)) def test__merge_column_values_missing(self): """ophys value chosen and merged with non-null behavior""" - df = pd.DataFrame({ - 'date_of_acquisition_behavior': ['foo', 'bar'], - 'date_of_acquisition_ophys': ['baz', np.nan], - }) + df = pd.DataFrame( + { + "date_of_acquisition_behavior": ["foo", "bar"], + "date_of_acquisition_ophys": ["baz", np.nan], + } + ) - with patch.object(OphysMixin, '__init__') as ophys_mixin: + with patch.object(OphysMixin, "__init__") as ophys_mixin: ophys_mixin.return_value = None mixin = OphysMixin() mixin._df = df mixin._merge_columns() - expected = pd.DataFrame({ - 'date_of_acquisition': ['baz', 'bar'], - }) - pd.testing.assert_frame_equal( - expected.sort_index(axis=1), - mixin._df.sort_index(axis=1) + expected = pd.DataFrame( + { + "date_of_acquisition": ["baz", "bar"], + } ) + pd.testing.assert_frame_equal(expected.sort_index(axis=1), mixin._df.sort_index(axis=1)) diff --git a/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_behavior_neuropixels_project_cloud_api.py b/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_behavior_neuropixels_project_cloud_api.py index 1dbb447947..8e7d5d2107 100644 --- a/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_behavior_neuropixels_project_cloud_api.py +++ b/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_behavior_neuropixels_project_cloud_api.py @@ -40,9 +40,7 @@ def __init__( channel_table.to_csv(self.channel_table_path, index=False) probe_table.to_csv(self.probe_table_path, index=False) unit_table.to_csv(self.unit_table_path, index=False) - behavior_session_table.to_csv( - self.behavior_session_table_path, index=False - ) + behavior_session_table.to_csv(self.behavior_session_table_path, index=False) self._manifest = MagicMock() self._manifest.metadata_file_names = [ @@ -83,17 +81,13 @@ def mock_cache(tmpdir): "behavior_sessions": pd.DataFrame( { "behavior_session_id": [1, 2, 3, 4], - "ecephys_session_id": pd.Series( - [10, 11, 12, 13], dtype="Int64" - ), + "ecephys_session_id": pd.Series([10, 11, 12, 13], dtype="Int64"), "mouse_id": [4, 4, 2, 1], } ), "ecephys_sessions": pd.DataFrame( { - "ecephys_session_id": pd.Series( - [10, 11, 12, 13], dtype="Int64" - ), + "ecephys_session_id": pd.Series([10, 11, 12, 13], dtype="Int64"), "behavior_session_id": [1, 2, 3, 4], "file_id": [10, 11, 12, 13], } @@ -147,18 +141,12 @@ def mock_cache(tmpdir): @pytest.mark.parametrize("local", [True, False]) -def test_VisualBehaviorNeuropixelsProjectCloudApi( - mock_cache, monkeypatch, local -): +def test_VisualBehaviorNeuropixelsProjectCloudApi(mock_cache, monkeypatch, local): mocked_cache, expected = mock_cache - api = cloudapi.VisualBehaviorNeuropixelsProjectCloudApi( - mocked_cache, skip_version_check=True, local=False - ) + api = cloudapi.VisualBehaviorNeuropixelsProjectCloudApi(mocked_cache, skip_version_check=True, local=False) if local: - api = cloudapi.VisualBehaviorNeuropixelsProjectCloudApi( - mocked_cache, skip_version_check=True, local=True - ) + api = cloudapi.VisualBehaviorNeuropixelsProjectCloudApi(mocked_cache, skip_version_check=True, local=True) # behavior session table as expected bst = api.get_behavior_session_table() @@ -197,9 +185,7 @@ def test_VisualBehaviorNeuropixelsProjectCloudApi( def mock_nwb(nwb_path, probe_meta): return nwb_path - monkeypatch.setattr( - cloudapi.BehaviorEcephysSession, "from_nwb_path", mock_nwb - ) + monkeypatch.setattr(cloudapi.BehaviorEcephysSession, "from_nwb_path", mock_nwb) assert api.get_ecephys_session(12) == "12" @@ -212,31 +198,23 @@ def test_probe_meta(mock_cache, monkeypatch, has_lfp_data): probe_meta_table["has_lfp_data"] = has_lfp_data probe_meta_table.to_csv(mocked_cache.probe_table_path, index=False) - api = cloudapi.VisualBehaviorNeuropixelsProjectCloudApi( - mocked_cache, skip_version_check=True, local=False - ) + api = cloudapi.VisualBehaviorNeuropixelsProjectCloudApi(mocked_cache, skip_version_check=True, local=False) def mock_from_nwb_path(nwb_path, probe_meta): return probe_meta ecephys_session_id = 10 - with patch.object( - BehaviorEcephysSession, "from_nwb_path", wraps=mock_from_nwb_path - ): + with patch.object(BehaviorEcephysSession, "from_nwb_path", wraps=mock_from_nwb_path): probe_meta = api.get_ecephys_session(ecephys_session_id) probe_meta_table = api.get_probe_table() - probes = probe_meta_table.loc[ - probe_meta_table["ecephys_session_id"] == ecephys_session_id - ] + probes = probe_meta_table.loc[probe_meta_table["ecephys_session_id"] == ecephys_session_id] if has_lfp_data: for probe_name in probes["name"].unique(): obtained_probe_nwb_path = probe_meta[probe_name].lfp_csd_filepath() expected_probe_nwb_path = str( - probe_meta_table.loc[ - (probe_meta_table["name"] == probe_name) - ].iloc[0]["file_id"] + probe_meta_table.loc[(probe_meta_table["name"] == probe_name)].iloc[0]["file_id"] ) assert obtained_probe_nwb_path == expected_probe_nwb_path else: @@ -250,9 +228,7 @@ def mock_from_nwb_path(nwb_path, probe_meta): ("1.0.1", "2.9.0", "0.0.0", "1.0.0", True), ], ) -def test_version_check( - manifest_version, data_pipeline_version, cmin, cmax, exception -): +def test_version_check(manifest_version, data_pipeline_version, cmin, cmax, exception): if exception: with pytest.raises( BehaviorCloudCacheVersionException, @@ -300,9 +276,7 @@ def test_from_local_cache(monkeypatch): except (TypeError, FileNotFoundError): pass - mock_local_cache.assert_called_once_with( - "first_cache_dir", "project_1", "ui_1" - ) + mock_local_cache.assert_called_once_with("first_cache_dir", "project_1", "ui_1") # Test from_local_cache with use_static_cache=True try: @@ -312,6 +286,4 @@ def test_from_local_cache(monkeypatch): except (TypeError, FileNotFoundError): pass - mock_static_local_cache.assert_called_once_with( - "second_cache_dir", "project_2", "ui_2" - ) + mock_static_local_cache.assert_called_once_with("second_cache_dir", "project_2", "ui_2") diff --git a/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_behavior_project_cloud_api.py b/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_behavior_project_cloud_api.py index 04bf5ee56b..ccb3f9b313 100644 --- a/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_behavior_project_cloud_api.py +++ b/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_behavior_project_cloud_api.py @@ -35,12 +35,8 @@ def __init__( ophys_session_table.to_csv(self.session_table_path, index=False) ophys_cells_table.to_csv(self.ophys_cells_table_path, index=False) - behavior_session_table.to_csv( - self.behavior_session_table_path, index=False - ) - ophys_experiment_table.to_csv( - self.ophys_experiment_table_path, index=False - ) + behavior_session_table.to_csv(self.behavior_session_table_path, index=False) + ophys_experiment_table.to_csv(self.ophys_experiment_table_path, index=False) self._manifest = MagicMock() self._manifest.metadata_file_names = [ @@ -109,8 +105,7 @@ def mock_cache(request, tmpdir): ), "ophys_session_table": pd.DataFrame( { - "ophys_session_id": pd.Series([10, 11, 12, 13], - dtype='Int64'), + "ophys_session_id": pd.Series([10, 11, 12, 13], dtype="Int64"), "mouse_id": ["1"] * 4, "date_of_acquisition": pd.to_datetime(["2021-01-01"] * 4), "ophys_experiment_id": [4, 5, 6, [7, 8, 9]], @@ -120,9 +115,7 @@ def mock_cache(request, tmpdir): { "ophys_experiment_id": [4, 5, 6, 7, 8, 9], "mouse_id": ["1"] * 6, - "date_of_acquisition": pd.to_datetime( - ["2021-01-01"] * 6, utc=True - ), + "date_of_acquisition": pd.to_datetime(["2021-01-01"] * 6, utc=True), "file_id": [4, 5, 6, 7, 8, 9], } ), @@ -140,13 +133,9 @@ def mock_cache(request, tmpdir): @pytest.mark.parametrize("local", [True, False]) def test_BehaviorProjectCloudApi(mock_cache, monkeypatch, local): mocked_cache, expected = mock_cache - api = cloudapi.BehaviorProjectCloudApi( - mocked_cache, skip_version_check=True, local=False - ) + api = cloudapi.BehaviorProjectCloudApi(mocked_cache, skip_version_check=True, local=False) if local: - api = cloudapi.BehaviorProjectCloudApi( - mocked_cache, skip_version_check=True, local=True - ) + api = cloudapi.BehaviorProjectCloudApi(mocked_cache, skip_version_check=True, local=True) # behavior session table as expected bost = api.get_behavior_session_table() @@ -184,9 +173,7 @@ def mock_nwb(nwb_path): assert api.get_behavior_session(4) == "7" # direct check only for ophys experiment - monkeypatch.setattr( - cloudapi.BehaviorOphysExperiment, "from_nwb_path", mock_nwb - ) + monkeypatch.setattr(cloudapi.BehaviorOphysExperiment, "from_nwb_path", mock_nwb) assert api.get_behavior_ophys_experiment(8) == "8" @@ -197,9 +184,7 @@ def mock_nwb(nwb_path): ("1.0.1", "2.9.0", "0.0.0", "1.0.0", True), ], ) -def test_version_check( - manifest_version, data_pipeline_version, cmin, cmax, exception -): +def test_version_check(manifest_version, data_pipeline_version, cmin, cmax, exception): if exception: with pytest.raises( BehaviorCloudCacheVersionException, @@ -225,9 +210,7 @@ def test_from_local_cache(monkeypatch): "comment": "This is a test entry. NOT REAL.", } ] - mock_manifest.version = ( - cloudapi.BehaviorProjectCloudApi.MANIFEST_COMPATIBILITY[0] - ) + mock_manifest.version = cloudapi.BehaviorProjectCloudApi.MANIFEST_COMPATIBILITY[0] mock_local_cache = create_autospec(cloudapibase.LocalCache) type(mock_local_cache.return_value)._manifest = mock_manifest @@ -252,9 +235,7 @@ def test_from_local_cache(monkeypatch): except (TypeError, FileNotFoundError): pass - mock_local_cache.assert_called_once_with( - "first_cache_dir", "project_1", "ui_1" - ) + mock_local_cache.assert_called_once_with("first_cache_dir", "project_1", "ui_1") # Test from_local_cache with use_static_cache=True try: @@ -269,6 +250,4 @@ def test_from_local_cache(monkeypatch): except (TypeError, FileNotFoundError): pass - mock_static_local_cache.assert_called_once_with( - "second_cache_dir", "project_2", "ui_2" - ) + mock_static_local_cache.assert_called_once_with("second_cache_dir", "project_2", "ui_2") diff --git a/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_behavior_project_lims_api.py b/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_behavior_project_lims_api.py index c2d7ea30bc..291ee6d87c 100644 --- a/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_behavior_project_lims_api.py +++ b/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_behavior_project_lims_api.py @@ -42,9 +42,7 @@ def stream(self, endpoint): @pytest.fixture def MockBehaviorProjectLimsApi(): - return BehaviorProjectLimsApi( - MockQueryEngine(), MockQueryEngine(), MockQueryEngine() - ) + return BehaviorProjectLimsApi(MockQueryEngine(), MockQueryEngine(), MockQueryEngine()) @pytest.mark.parametrize( @@ -80,9 +78,7 @@ def MockBehaviorProjectLimsApi(): ), ], ) -def test_build_line_from_donor_query( - line, expected, MockBehaviorProjectLimsApi -): +def test_build_line_from_donor_query(line, expected, MockBehaviorProjectLimsApi): mbp_api = MockBehaviorProjectLimsApi assert expected == mbp_api._build_line_from_donor_query(line=line) @@ -96,18 +92,10 @@ def setup_class(cls): # Note: these tables will need to be updated if the expected table # changes - cls.release_behavior_sessions_table = pd.read_csv( - test_dir / "behavior_session_table.csv" - ) - cls.release_ophys_sessions_table = pd.read_csv( - test_dir / "ophys_session_table.csv" - ) - cls.release_ophys_experiments_table = pd.read_csv( - test_dir / "ophys_experiment_table.csv" - ) - cls.release_ophys_cells_table = pd.read_csv( - test_dir / "ophys_cells_table.csv" - ) + cls.release_behavior_sessions_table = pd.read_csv(test_dir / "behavior_session_table.csv") + cls.release_ophys_sessions_table = pd.read_csv(test_dir / "ophys_session_table.csv") + cls.release_ophys_experiments_table = pd.read_csv(test_dir / "ophys_experiment_table.csv") + cls.release_ophys_cells_table = pd.read_csv(test_dir / "ophys_cells_table.csv") cls.lims_api = BehaviorProjectLimsApi.default(passed_only=False) @@ -128,79 +116,43 @@ def _get_behavior_session(self, behavior_session_id, lims_db): def test_all_behavior_sessions(self): """Tests that when passed_only=False, that more sessions are returned than in the release table""" - with patch.object( - BehaviorMetadata, "from_lims", wraps=self._get_behavior_session - ): + with patch.object(BehaviorMetadata, "from_lims", wraps=self._get_behavior_session): obtained = self.lims_api.get_behavior_session_table() # Make sure ids returned are superset of release ids assert ( - len( - set(obtained.index).intersection( - self.release_behavior_sessions_table[ - "behavior_session_id" - ] - ) - ) - == self.release_behavior_sessions_table[ - "behavior_session_id" - ].nunique() - ) - assert ( - obtained.shape[0] - > self.release_behavior_sessions_table.shape[0] + len(set(obtained.index).intersection(self.release_behavior_sessions_table["behavior_session_id"])) + == self.release_behavior_sessions_table["behavior_session_id"].nunique() ) + assert obtained.shape[0] > self.release_behavior_sessions_table.shape[0] @pytest.mark.requires_bamboo def test_all_ophys_sessions(self): """Tests that when passed_only=False, that more sessions are returned than in the release table""" - with patch.object( - BehaviorMetadata, "from_lims", wraps=self._get_behavior_session - ): + with patch.object(BehaviorMetadata, "from_lims", wraps=self._get_behavior_session): obtained = self.lims_api.get_ophys_session_table() # Make sure ids returned are superset of release ids assert ( - len( - set(obtained.index).intersection( - self.release_ophys_sessions_table["ophys_session_id"] - ) - ) - == self.release_ophys_sessions_table[ - "ophys_session_id" - ].nunique() - ) - assert ( - obtained.shape[0] > self.release_ophys_sessions_table.shape[0] + len(set(obtained.index).intersection(self.release_ophys_sessions_table["ophys_session_id"])) + == self.release_ophys_sessions_table["ophys_session_id"].nunique() ) + assert obtained.shape[0] > self.release_ophys_sessions_table.shape[0] @pytest.mark.requires_bamboo def test_all_ophys_experiments(self): """Tests that when passed_only=False, that more experiments are returned than in the release table""" - with patch.object( - BehaviorMetadata, "from_lims", wraps=self._get_behavior_session - ): + with patch.object(BehaviorMetadata, "from_lims", wraps=self._get_behavior_session): obtained = self.lims_api.get_ophys_experiment_table() # Make sure ids returned are superset of release ids assert ( - len( - set(obtained.index).intersection( - self.release_ophys_experiments_table[ - "ophys_experiment_id" - ] - ) - ) - == self.release_ophys_experiments_table[ - "ophys_experiment_id" - ].nunique() - ) - assert ( - obtained.shape[0] - > self.release_ophys_experiments_table.shape[0] + len(set(obtained.index).intersection(self.release_ophys_experiments_table["ophys_experiment_id"])) + == self.release_ophys_experiments_table["ophys_experiment_id"].nunique() ) + assert obtained.shape[0] > self.release_ophys_experiments_table.shape[0] @pytest.mark.requires_bamboo def test_all_cells(self): @@ -211,9 +163,7 @@ def test_all_cells(self): # Make sure ids returned are superset of release ids assert ( len( - set(obtained["ophys_experiment_id"]).intersection( - self.release_ophys_cells_table["ophys_experiment_id"] - ) + set(obtained["ophys_experiment_id"]).intersection(self.release_ophys_cells_table["ophys_experiment_id"]) ) == self.release_ophys_cells_table["ophys_experiment_id"].nunique() ) @@ -230,16 +180,12 @@ def setup_class(cls): cls.test_dir = Path(__file__).parent / "test_data" / "vbo" - behavior_sessions_table = pd.read_csv( - cls.test_dir / "behavior_session_table.csv" - ) - cls.session_type_map = behavior_sessions_table.set_index( - "behavior_session_id" - )[["session_type"]].to_dict()["session_type"] + behavior_sessions_table = pd.read_csv(cls.test_dir / "behavior_session_table.csv") + cls.session_type_map = behavior_sessions_table.set_index("behavior_session_id")[["session_type"]].to_dict()[ + "session_type" + ] - cls.lims_cache = VisualBehaviorOphysProjectCache.from_lims( - data_release_date=["2021-03-25", "2021-08-12"] - ) + cls.lims_cache = VisualBehaviorOphysProjectCache.from_lims(data_release_date=["2021-03-25", "2021-08-12"]) cls.tempdir = tempdir with patch.object( @@ -247,9 +193,7 @@ def setup_class(cls): "_get_metadata_path", wraps=cls._get_project_table_path, ): - cls.cloud_cache = VisualBehaviorOphysProjectCache.from_s3_cache( - cache_dir=tempdir.name - ) + cls.cloud_cache = VisualBehaviorOphysProjectCache.from_s3_cache(cache_dir=tempdir.name) def teardown_class(self): self.tempdir.cleanup() @@ -261,9 +205,7 @@ def _get_behavior_session(self, behavior_session_id, lims_db): behavior_session_id=BehaviorSessionId(behavior_session_id), equipment=None, stimulus_frame_rate=None, - session_type=SessionType( - self.session_type_map[behavior_session_id] - ), + session_type=SessionType(self.session_type_map[behavior_session_id]), behavior_session_uuid=None, ) @@ -280,9 +222,7 @@ def _get_project_table_path(cls, fname): @pytest.mark.requires_bamboo def test_behavior_session_table(self): - with patch.object( - BehaviorMetadata, "from_lims", wraps=self._get_behavior_session - ): + with patch.object(BehaviorMetadata, "from_lims", wraps=self._get_behavior_session): from_lims = self.lims_cache.get_behavior_session_table() from_lims = from_lims.drop(columns=list(SESSION_SUPPRESS)) @@ -297,9 +237,7 @@ def test_behavior_session_table(self): @pytest.mark.requires_bamboo def test_ophys_session_table(self): - with patch.object( - BehaviorMetadata, "from_lims", wraps=self._get_behavior_session - ): + with patch.object(BehaviorMetadata, "from_lims", wraps=self._get_behavior_session): from_lims = self.lims_cache.get_ophys_session_table() from_lims = from_lims.drop(columns=list(SESSION_SUPPRESS)) @@ -313,14 +251,11 @@ def test_ophys_session_table(self): @pytest.mark.requires_bamboo def test_ophys_experiments_table(self): - with patch.object( - BehaviorMetadata, "from_lims", wraps=self._get_behavior_session - ): + with patch.object(BehaviorMetadata, "from_lims", wraps=self._get_behavior_session): from_lims = self.lims_cache.get_ophys_experiment_table() from_lims = from_lims.drop( - columns=list(OPHYS_EXPERIMENTS_SUPPRESS) - + list(OPHYS_EXPERIMENTS_SUPPRESS_FINAL), + columns=list(OPHYS_EXPERIMENTS_SUPPRESS) + list(OPHYS_EXPERIMENTS_SUPPRESS_FINAL), errors="ignore", ) diff --git a/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_behavior_project_metadata_writer.py b/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_behavior_project_metadata_writer.py index 7576ccb524..3488706e3e 100644 --- a/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_behavior_project_metadata_writer.py +++ b/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_behavior_project_metadata_writer.py @@ -25,32 +25,26 @@ def setup_class(cls): # Note: these tables will need to be updated if the expected table # changes - cls.expected_behavior_sessions_table = pd.read_csv( - test_dir / "behavior_session_table.csv" - ) - cls.expected_ophys_sessions_table = pd.read_csv( - test_dir / "ophys_session_table.csv" - ) - cls.expected_ophys_experiments_table = pd.read_csv( - test_dir / "ophys_experiment_table.csv" - ) - cls.expected_ophys_cells_table = pd.read_csv( - test_dir / "ophys_cells_table.csv" - ) + cls.expected_behavior_sessions_table = pd.read_csv(test_dir / "behavior_session_table.csv") + cls.expected_ophys_sessions_table = pd.read_csv(test_dir / "ophys_session_table.csv") + cls.expected_ophys_experiments_table = pd.read_csv(test_dir / "ophys_experiment_table.csv") + cls.expected_ophys_cells_table = pd.read_csv(test_dir / "ophys_cells_table.csv") - cls.session_type_map = cls.expected_behavior_sessions_table.set_index( - "behavior_session_id" - )[["session_type"]].to_dict()["session_type"] + cls.session_type_map = cls.expected_behavior_sessions_table.set_index("behavior_session_id")[ + ["session_type"] + ].to_dict()["session_type"] cls.test_dir = tempfile.TemporaryDirectory() - input_data = {"output_dir": cls.test_dir.name, - "data_release_date": ['2021-03-25', '2021-08-12'], - "clobber": True, - "log_level": "INFO", - "behavior_nwb_dir": str(test_dir), - "ophys_nwb_dir": str(test_dir), - "on_missing_file": "warn"} + input_data = { + "output_dir": cls.test_dir.name, + "data_release_date": ["2021-03-25", "2021-08-12"], + "clobber": True, + "log_level": "INFO", + "behavior_nwb_dir": str(test_dir), + "ophys_nwb_dir": str(test_dir), + "on_missing_file": "warn", + } cls.project_table_writer = BehaviorProjectMetadataWriter( input_data=input_data, args=[], @@ -69,156 +63,101 @@ def _get_behavior_session(self, behavior_session_id, lims_db): behavior_session_id=BehaviorSessionId(behavior_session_id), equipment=None, stimulus_frame_rate=None, - session_type=SessionType( - self.session_type_map[behavior_session_id] - ), + session_type=SessionType(self.session_type_map[behavior_session_id]), behavior_session_uuid=None, ) @pytest.mark.requires_bamboo def test_get_behavior_sessions_table(self): - with patch.object( - BehaviorMetadata, "from_lims", wraps=self._get_behavior_session - ): - self.project_table_writer._write_behavior_sessions( - include_trial_metrics=False - ) - obtained = pd.read_csv( - Path(self.test_dir.name) / "behavior_session_table.csv" - ) - obtained = obtained.sort_values("behavior_session_id").reset_index( - drop=True - ) - expected = self.expected_behavior_sessions_table.sort_values( - "behavior_session_id" - ).reset_index(drop=True) + with patch.object(BehaviorMetadata, "from_lims", wraps=self._get_behavior_session): + self.project_table_writer._write_behavior_sessions(include_trial_metrics=False) + obtained = pd.read_csv(Path(self.test_dir.name) / "behavior_session_table.csv") + obtained = obtained.sort_values("behavior_session_id").reset_index(drop=True) + expected = self.expected_behavior_sessions_table.sort_values("behavior_session_id").reset_index(drop=True) # File paths are not created by # AllenInstitute/informatics_release_tool hence we ignore them # here. pd.testing.assert_frame_equal( - obtained.drop('file_path', axis=1).sort_index(axis=1), - expected.drop('file_path', axis=1).sort_index(axis=1) + obtained.drop("file_path", axis=1).sort_index(axis=1), + expected.drop("file_path", axis=1).sort_index(axis=1), ) @pytest.mark.requires_bamboo def test_get_ophys_sessions_table(self): - with patch.object( - BehaviorMetadata, "from_lims", wraps=self._get_behavior_session - ): + with patch.object(BehaviorMetadata, "from_lims", wraps=self._get_behavior_session): self.project_table_writer._write_ophys_sessions() - obtained = pd.read_csv( - Path(self.test_dir.name) / "ophys_session_table.csv" - ) - obtained = obtained.sort_values("ophys_session_id").reset_index( - drop=True - ) - obtained['date_of_acquisition'] = pd.to_datetime( - obtained['date_of_acquisition'], utc=True) - expected = self.expected_ophys_sessions_table.sort_values( - "ophys_session_id" - ).reset_index(drop=True) - expected['date_of_acquisition'] = pd.to_datetime( - expected['date_of_acquisition'], utc=True) - pd.testing.assert_frame_equal( - obtained.sort_index(axis=1), - expected.sort_index(axis=1) - ) + obtained = pd.read_csv(Path(self.test_dir.name) / "ophys_session_table.csv") + obtained = obtained.sort_values("ophys_session_id").reset_index(drop=True) + obtained["date_of_acquisition"] = pd.to_datetime(obtained["date_of_acquisition"], utc=True) + expected = self.expected_ophys_sessions_table.sort_values("ophys_session_id").reset_index(drop=True) + expected["date_of_acquisition"] = pd.to_datetime(expected["date_of_acquisition"], utc=True) + pd.testing.assert_frame_equal(obtained.sort_index(axis=1), expected.sort_index(axis=1)) @pytest.mark.requires_bamboo def test_get_ophys_experiments_table(self): - with patch.object( - BehaviorMetadata, "from_lims", wraps=self._get_behavior_session - ): + with patch.object(BehaviorMetadata, "from_lims", wraps=self._get_behavior_session): self.project_table_writer._write_ophys_experiments() - obtained = pd.read_csv( - Path(self.test_dir.name) / "ophys_experiment_table.csv" - ) - obtained = obtained.sort_values("ophys_experiment_id").reset_index( - drop=True - ) - expected = self.expected_ophys_experiments_table.sort_values( - "ophys_experiment_id" - ).reset_index(drop=True) - - obtained['date_of_acquisition'] = pd.to_datetime( - obtained['date_of_acquisition'], utc=True) - expected['date_of_acquisition'] = pd.to_datetime( - expected['date_of_acquisition'], utc=True) + obtained = pd.read_csv(Path(self.test_dir.name) / "ophys_experiment_table.csv") + obtained = obtained.sort_values("ophys_experiment_id").reset_index(drop=True) + expected = self.expected_ophys_experiments_table.sort_values("ophys_experiment_id").reset_index(drop=True) + + obtained["date_of_acquisition"] = pd.to_datetime(obtained["date_of_acquisition"], utc=True) + expected["date_of_acquisition"] = pd.to_datetime(expected["date_of_acquisition"], utc=True) # File paths are not created by # AllenInstitute/informatics_release_tool hence we ignore them # here. pd.testing.assert_frame_equal( - obtained.drop('file_path', axis=1).sort_index(axis=1), - expected.drop('file_path', axis=1).sort_index(axis=1) + obtained.drop("file_path", axis=1).sort_index(axis=1), + expected.drop("file_path", axis=1).sort_index(axis=1), ) @pytest.mark.requires_bamboo def test_get_ophys_cells_table(self): self.project_table_writer._write_ophys_cells() - obtained = pd.read_csv( - Path(self.test_dir.name) / "ophys_cells_table.csv" - ) - pd.testing.assert_frame_equal( - obtained, self.expected_ophys_cells_table - ) + obtained = pd.read_csv(Path(self.test_dir.name) / "ophys_cells_table.csv") + pd.testing.assert_frame_equal(obtained, self.expected_ophys_cells_table) @pytest.mark.requires_bamboo def test_imaging_plane_group_only_mesoscope(self): """Tests that imaging plane group only applies to mesoscope""" - with patch.object( - BehaviorMetadata, 'from_lims', - wraps=self._get_behavior_session): + with patch.object(BehaviorMetadata, "from_lims", wraps=self._get_behavior_session): self.project_table_writer._write_ophys_sessions() self.project_table_writer._write_ophys_experiments() - ophys_session_tbl = pd.read_csv(Path(self.test_dir.name) / - 'ophys_session_table.csv') - ophys_experiment_tbl = pd.read_csv(Path(self.test_dir.name) / - 'ophys_experiment_table.csv') - df = ophys_session_tbl.merge(ophys_experiment_tbl, - on='ophys_session_id') - - assert (df[~df['equipment_name_x'].str.startswith('MESO')] - ['imaging_plane_group_count'].isna().all()) - assert (df[~df['equipment_name_x'].str.startswith('MESO')] - ['imaging_plane_group'].isna().all()) + ophys_session_tbl = pd.read_csv(Path(self.test_dir.name) / "ophys_session_table.csv") + ophys_experiment_tbl = pd.read_csv(Path(self.test_dir.name) / "ophys_experiment_table.csv") + df = ophys_session_tbl.merge(ophys_experiment_tbl, on="ophys_session_id") + + assert df[~df["equipment_name_x"].str.startswith("MESO")]["imaging_plane_group_count"].isna().all() + assert df[~df["equipment_name_x"].str.startswith("MESO")]["imaging_plane_group"].isna().all() @pytest.mark.requires_bamboo def test_imaging_plane_group_count_consistent(self): """Tests that imaging plane group count in ophys sessions table is consistent with the number of imaging plane groups in experiment table""" - with patch.object( - BehaviorMetadata, 'from_lims', - wraps=self._get_behavior_session): + with patch.object(BehaviorMetadata, "from_lims", wraps=self._get_behavior_session): self.project_table_writer._write_ophys_sessions() self.project_table_writer._write_ophys_experiments() - ophys_session_tbl = pd.read_csv(Path(self.test_dir.name) / - 'ophys_session_table.csv') - ophys_experiment_tbl = pd.read_csv(Path(self.test_dir.name) / - 'ophys_experiment_table.csv') - df = ophys_session_tbl.merge(ophys_experiment_tbl, - on='ophys_session_id') + ophys_session_tbl = pd.read_csv(Path(self.test_dir.name) / "ophys_session_table.csv") + ophys_experiment_tbl = pd.read_csv(Path(self.test_dir.name) / "ophys_experiment_table.csv") + df = ophys_session_tbl.merge(ophys_experiment_tbl, on="ophys_session_id") imaging_plane_group_count = ( - df[~df['imaging_plane_group'].isna()] - .groupby('ophys_session_id')['imaging_plane_group'].nunique() + df[~df["imaging_plane_group"].isna()] + .groupby("ophys_session_id")["imaging_plane_group"] + .nunique() .reset_index() - .rename( - columns={'imaging_plane_group': 'imaging_plane_group_count'}) + .rename(columns={"imaging_plane_group": "imaging_plane_group_count"}) ) ophys_session_tbl = ophys_session_tbl.merge( - imaging_plane_group_count, - on='ophys_session_id', - suffixes=('_from_lims', '_recalculated'), - how='left' + imaging_plane_group_count, on="ophys_session_id", suffixes=("_from_lims", "_recalculated"), how="left" ) assert ( - (ophys_session_tbl['imaging_plane_group_count_from_lims'] - [~ophys_session_tbl['imaging_plane_group_count_from_lims'].isna()] - == - ophys_session_tbl['imaging_plane_group_count_recalculated'] - [~ophys_session_tbl['imaging_plane_group_count_recalculated'] - .isna()]) - .all() - ) + ophys_session_tbl["imaging_plane_group_count_from_lims"][ + ~ophys_session_tbl["imaging_plane_group_count_from_lims"].isna() + ] + == ophys_session_tbl["imaging_plane_group_count_recalculated"][ + ~ophys_session_tbl["imaging_plane_group_count_recalculated"].isna() + ] + ).all() diff --git a/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_experiments_table_utils.py b/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_experiments_table_utils.py index c60bd0cf05..f912376998 100644 --- a/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_experiments_table_utils.py +++ b/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_experiments_table_utils.py @@ -1,70 +1,67 @@ import copy import pandas as pd -from allensdk.brain_observatory.behavior.behavior_project_cache.\ - tables.util.experiments_table_utils import ( - add_passive_flag_to_ophys_experiment_table, - add_image_set_to_experiment_table) +from allensdk.brain_observatory.behavior.behavior_project_cache.tables.util.experiments_table_utils import ( + add_passive_flag_to_ophys_experiment_table, + add_image_set_to_experiment_table, +) def test_add_passive_flag(): - input_data = [] expected_data = [] - datum = {'id': 0, 'session_number': 2} + datum = {"id": 0, "session_number": 2} input_data.append(copy.deepcopy(datum)) - datum['passive'] = True + datum["passive"] = True expected_data.append(copy.deepcopy(datum)) - datum = {'id': 1, 'session_number': 5} + datum = {"id": 1, "session_number": 5} input_data.append(copy.deepcopy(datum)) - datum['passive'] = True + datum["passive"] = True expected_data.append(copy.deepcopy(datum)) - datum = {'id': 2, 'session_number': 1} + datum = {"id": 2, "session_number": 1} input_data.append(copy.deepcopy(datum)) - datum['passive'] = False + datum["passive"] = False expected_data.append(copy.deepcopy(datum)) - datum = {'id': 3, 'session_number': 3} + datum = {"id": 3, "session_number": 3} input_data.append(copy.deepcopy(datum)) - datum['passive'] = False + datum["passive"] = False expected_data.append(copy.deepcopy(datum)) - datum = {'id': 4, 'session_number': 2} + datum = {"id": 4, "session_number": 2} input_data.append(copy.deepcopy(datum)) - datum['passive'] = True + datum["passive"] = True expected_data.append(copy.deepcopy(datum)) - datum = {'id': 5, 'session_number': 5} + datum = {"id": 5, "session_number": 5} input_data.append(copy.deepcopy(datum)) - datum['passive'] = True + datum["passive"] = True expected_data.append(copy.deepcopy(datum)) input_df = pd.DataFrame(input_data) expected_df = pd.DataFrame(expected_data) assert not input_df.equals(expected_df) - output_df = add_passive_flag_to_ophys_experiment_table( - input_df) + output_df = add_passive_flag_to_ophys_experiment_table(input_df) assert not input_df.equals(output_df) assert len(input_df.columns) != len(output_df.columns) assert output_df.equals(expected_df) def test_add_image_set_to_experiment_table(): - input_data = [] expected_data = [] - datum = {'id': 0, 'session_type': 'ophys_5_images_x_passive'} + datum = {"id": 0, "session_type": "ophys_5_images_x_passive"} input_data.append(copy.deepcopy(datum)) - datum['image_set'] = 'x' + datum["image_set"] = "x" expected_data.append(copy.deepcopy(datum)) - datum = {'id': 1, 'session_type': 'ophys_5'} + datum = {"id": 1, "session_type": "ophys_5"} input_data.append(copy.deepcopy(datum)) - datum['image_set'] = 'N/A' + datum["image_set"] = "N/A" expected_data.append(copy.deepcopy(datum)) input_df = pd.DataFrame(input_data) diff --git a/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_metadata_parsers.py b/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_metadata_parsers.py index 64fc3a525b..28012eab92 100644 --- a/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_metadata_parsers.py +++ b/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_metadata_parsers.py @@ -88,12 +88,7 @@ def test_parse_num_cortical_structures(): } ) expected = pd.Series([1, 1, 2, 4, None], dtype="Int64") - obtained = ( - df["project_code"] - .apply(parse_num_cortical_structures) - .astype("Int64") - .rename(None) - ) + obtained = df["project_code"].apply(parse_num_cortical_structures).astype("Int64").rename(None) pd.testing.assert_series_equal(expected, obtained) @@ -112,7 +107,5 @@ def test_parse_num_depths(): } ) expected = pd.Series([1, 1, 4, 2, None], dtype="Int64") - obtained = ( - df["project_code"].apply(parse_num_depths).astype("Int64").rename(None) - ) + obtained = df["project_code"].apply(parse_num_depths).astype("Int64").rename(None) pd.testing.assert_series_equal(expected, obtained) diff --git a/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_natural_movie_cache.py b/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_natural_movie_cache.py index 3497bd720a..2ef61b88d4 100644 --- a/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_natural_movie_cache.py +++ b/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_natural_movie_cache.py @@ -16,13 +16,9 @@ def test_natural_movie_cache(): "behavior_project_cache.project_apis.data_io." "natural_movie_one_cache.NaturalMovieOneCache." "get_raw_movie", - return_value=rng.integers( - low=0, high=256, size=(1, 304, 608), dtype=np.uint8 - ), + return_value=rng.integers(low=0, high=256, size=(1, 304, 608), dtype=np.uint8), ): - cache = NaturalMovieOneCache( - cache_dir="fake_dir", bucket_name="fake_bucket" - ) + cache = NaturalMovieOneCache(cache_dir="fake_dir", bucket_name="fake_bucket") movie = cache.get_processed_template_movie(n_workers=1) assert movie.index.name == "movie_frame_index" assert movie.columns.to_list() == ["unwarped", "warped"] diff --git a/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_pandas_compat.py b/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_pandas_compat.py index a6a178a860..50025dbd49 100644 --- a/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_pandas_compat.py +++ b/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_pandas_compat.py @@ -54,18 +54,18 @@ def test_schema_parses_timezone_aware_date(self): def test_cloud_api_mixed_precision_dates(self): """Simulate the cloud API's to_datetime call on a column with mixed microsecond precision, as found in real session CSVs.""" - df = pd.DataFrame({ - "date_of_acquisition": [ - "2019-09-25 13:31:46", - "2019-10-01 14:22:33.123456", - "2020-01-15 09:00:00", - "2020-06-30 16:45:12.500000", - ] - }) - # This is the exact call pattern used in behavior_project_cloud_api.py - df["date_of_acquisition"] = pd.to_datetime( - df["date_of_acquisition"], format="ISO8601", utc=True + df = pd.DataFrame( + { + "date_of_acquisition": [ + "2019-09-25 13:31:46", + "2019-10-01 14:22:33.123456", + "2020-01-15 09:00:00", + "2020-06-30 16:45:12.500000", + ] + } ) + # This is the exact call pattern used in behavior_project_cloud_api.py + df["date_of_acquisition"] = pd.to_datetime(df["date_of_acquisition"], format="ISO8601", utc=True) assert df["date_of_acquisition"].dt.tz is not None assert len(df) == 4 @@ -79,9 +79,11 @@ def test_mixed_precision_fails_without_explicit_format(self): with pytest.raises(ValueError, match="doesn't match format"): pd.to_datetime( - pd.Series([ - "2019-09-25 13:31:46", - "2019-10-01 14:22:33.123456", - ]), + pd.Series( + [ + "2019-09-25 13:31:46", + "2019-10-01 14:22:33.123456", + ] + ), format="%Y-%m-%d %H:%M:%S.%f", ) diff --git a/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_vbn_from_s3.py b/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_vbn_from_s3.py index eb41e609d3..85600316d3 100644 --- a/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_vbn_from_s3.py +++ b/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_vbn_from_s3.py @@ -2,10 +2,8 @@ import pytest -from allensdk.brain_observatory.ecephys.behavior_ecephys_session import \ - BehaviorEcephysSession -from allensdk.brain_observatory.behavior.behavior_session import \ - BehaviorSession +from allensdk.brain_observatory.ecephys.behavior_ecephys_session import BehaviorEcephysSession +from allensdk.brain_observatory.behavior.behavior_session import BehaviorSession from .utils import create_bucket, load_dataset import boto3 from moto import mock_s3 @@ -15,9 +13,9 @@ from allensdk.api.cloud_cache.cloud_cache import MissingLocalManifestWarning from allensdk.api.cloud_cache.cloud_cache import OutdatedManifestWarning -from allensdk.brain_observatory.\ - behavior.behavior_project_cache.behavior_neuropixels_project_cache \ - import VisualBehaviorNeuropixelsProjectCache +from allensdk.brain_observatory.behavior.behavior_project_cache.behavior_neuropixels_project_cache import ( + VisualBehaviorNeuropixelsProjectCache, +) @mock_s3 @@ -31,10 +29,7 @@ def test_vbn_metadata_tables(tmpdir, vbn_s3_cloud_cache_data): cache_dir = pathlib.Path(tmpdir) / "test_metadata" bucket_name = VisualBehaviorNeuropixelsProjectCache.BUCKET_NAME project_name = VisualBehaviorNeuropixelsProjectCache.PROJECT_NAME - create_bucket(bucket_name, - project_name, - data['data'], - data['metadata']) + create_bucket(bucket_name, project_name, data["data"], data["metadata"]) cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(cache_dir) cache.get_probe_table() @@ -48,8 +43,7 @@ def test_vbn_metadata_tables(tmpdir, vbn_s3_cloud_cache_data): assert len(ecephys) == 1 assert ecephys.index[0] == 222 - abnormal = cache.get_ecephys_session_table( - filter_abnormalities=False) + abnormal = cache.get_ecephys_session_table(filter_abnormalities=False) assert len(abnormal) == 3 @@ -60,44 +54,33 @@ def test_probe_nwb_file(monkeypatch, tmpdir, vbn_s3_cloud_cache_data): cache_dir = pathlib.Path(tmpdir) / "test_metadata" bucket_name = VisualBehaviorNeuropixelsProjectCache.BUCKET_NAME project_name = VisualBehaviorNeuropixelsProjectCache.PROJECT_NAME - create_bucket(bucket_name, - project_name, - data['data'], - data['metadata']) + create_bucket(bucket_name, project_name, data["data"], data["metadata"]) cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(cache_dir) probe_meta_table = cache.get_probe_table() for probe in probe_meta_table.itertuples(): with monkeypatch.context() as ctx: - ctx.setattr(BehaviorEcephysSession, 'from_nwb_path', - lambda path, probe_meta: probe_meta) - probe_meta = \ - cache.get_ecephys_session( - ecephys_session_id=probe.ecephys_session_id) + ctx.setattr(BehaviorEcephysSession, "from_nwb_path", lambda path, probe_meta: probe_meta) + probe_meta = cache.get_ecephys_session(ecephys_session_id=probe.ecephys_session_id) probe_id = probe.Index probe_nwb = probe_meta[probe.name].lfp_csd_filepath() - expected_path = (cache_dir / f'{project_name}-0.{len(data)}.0' / - 'data' / f'probe_{probe_id}_lfp.nwb') + expected_path = cache_dir / f"{project_name}-0.{len(data)}.0" / "data" / f"probe_{probe_id}_lfp.nwb" assert probe_nwb == expected_path assert expected_path.is_file() @mock_s3 def test_manifest_methods(tmpdir, vbn_s3_cloud_cache_data): - data, versions = vbn_s3_cloud_cache_data cache_dir = pathlib.Path(tmpdir) / "test_manifest_list" bucket_name = VisualBehaviorNeuropixelsProjectCache.BUCKET_NAME project_name = VisualBehaviorNeuropixelsProjectCache.PROJECT_NAME - create_bucket(bucket_name, - project_name, - data['data'], - data['metadata']) + create_bucket(bucket_name, project_name, data["data"], data["metadata"]) cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(cache_dir) - v_names = [f'{project_name}_manifest_v{i}.json' for i in versions] + v_names = [f"{project_name}_manifest_v{i}.json" for i in versions] m_list = cache.list_manifest_file_names() assert len(m_list) == 2 @@ -112,41 +95,33 @@ def test_manifest_methods(tmpdir, vbn_s3_cloud_cache_data): change_msg = cache.compare_manifests(v_names[0], v_names[-1]) - for mname in ('behavior_sessions', - 'ecephys_sessions', - 'probes'): + for mname in ("behavior_sessions", "ecephys_sessions", "probes"): print(change_msg) - assert f'project_metadata/{mname} changed' in change_msg + assert f"project_metadata/{mname} changed" in change_msg - assert 'ecephys_file_1.nwb changed' in change_msg - assert 'ecephys_file_3.nwb created' in change_msg + assert "ecephys_file_1.nwb changed" in change_msg + assert "ecephys_file_3.nwb created" in change_msg @mock_s3 -def test_local_cache_construction( - tmpdir, - vbn_s3_cloud_cache_data, - monkeypatch -): - +def test_local_cache_construction(tmpdir, vbn_s3_cloud_cache_data, monkeypatch): data, versions = vbn_s3_cloud_cache_data cache_dir = pathlib.Path(tmpdir) / "test_construction" bucket_name = VisualBehaviorNeuropixelsProjectCache.BUCKET_NAME project_name = VisualBehaviorNeuropixelsProjectCache.PROJECT_NAME - create_bucket(bucket_name, - project_name, - data['data'], - data['metadata']) + create_bucket(bucket_name, project_name, data["data"], data["metadata"]) cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(cache_dir) - v_names = [f'{project_name}_manifest_v{i}.json' for i in versions] + v_names = [f"{project_name}_manifest_v{i}.json" for i in versions] cache.load_manifest(v_names[0]) with monkeypatch.context() as ctx: - ctx.setattr(BehaviorEcephysSession, 'from_nwb_path', - lambda path, probe_meta: create_autospec( - BehaviorEcephysSession, instance=True)) + ctx.setattr( + BehaviorEcephysSession, + "from_nwb_path", + lambda path, probe_meta: create_autospec(BehaviorEcephysSession, instance=True), + ) cache.get_ecephys_session(ecephys_session_id=5111) assert cache.fetch_api.cache._downloaded_data_path.is_file() cache.fetch_api.cache._downloaded_data_path.unlink() @@ -156,10 +131,12 @@ def test_local_cache_construction( with pytest.warns(MissingLocalManifestWarning) as warnings: cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(cache_dir) - cmd = 'VisualBehaviorNeuropixelsProjectCache.construct_local_manifest()' + cmd = "VisualBehaviorNeuropixelsProjectCache.construct_local_manifest()" warning_msgs = [ - f'{warnings[i].message}' for i in range(len(warnings)) - if type(warnings[i].message) is MissingLocalManifestWarning] + f"{warnings[i].message}" + for i in range(len(warnings)) + if type(warnings[i].message) is MissingLocalManifestWarning + ] assert any([cmd in msg for msg in warning_msgs]) # Because, at the point where the cache was reconstitute, @@ -171,19 +148,15 @@ def test_local_cache_construction( cache.construct_local_manifest() assert cache.fetch_api.cache._downloaded_data_path.is_file() - with open(manifest_path, 'rb') as in_file: + with open(manifest_path, "rb") as in_file: local_manifest = json.load(in_file) fnames = set([pathlib.Path(k).name for k in local_manifest]) - assert 'ecephys_file_1.nwb' in fnames + assert "ecephys_file_1.nwb" in fnames assert len(local_manifest) == 9 # 8 metadata files and 1 data file @mock_s3 -def test_load_out_of_date_manifest( - tmpdir, - vbn_s3_cloud_cache_data, - monkeypatch -): +def test_load_out_of_date_manifest(tmpdir, vbn_s3_cloud_cache_data, monkeypatch): """ Test that VisualBehaviorNeuropixelsProjectCache can load a manifest other than the latest and download files @@ -194,58 +167,50 @@ def test_load_out_of_date_manifest( cache_dir = pathlib.Path(tmpdir) / "test_linkage" bucket_name = VisualBehaviorNeuropixelsProjectCache.BUCKET_NAME project_name = VisualBehaviorNeuropixelsProjectCache.PROJECT_NAME - create_bucket(bucket_name, - project_name, - data['data'], - data['metadata']) + create_bucket(bucket_name, project_name, data["data"], data["metadata"]) cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(cache_dir) - v_names = [f'{project_name}_manifest_v{i}.json' for i in versions] + v_names = [f"{project_name}_manifest_v{i}.json" for i in versions] cache.load_manifest(v_names[0]) for sess_id in (333, 444): with monkeypatch.context() as ctx: - ctx.setattr(BehaviorSession, 'from_nwb_path', - lambda path: create_autospec( - BehaviorSession, instance=True)) + ctx.setattr(BehaviorSession, "from_nwb_path", lambda path: create_autospec(BehaviorSession, instance=True)) cache.get_behavior_session(behavior_session_id=sess_id) for ses_id in (5111, 5112): with monkeypatch.context() as ctx: - ctx.setattr(BehaviorEcephysSession, 'from_nwb_path', - lambda path, probe_meta: create_autospec( - BehaviorEcephysSession, instance=True)) + ctx.setattr( + BehaviorEcephysSession, + "from_nwb_path", + lambda path, probe_meta: create_autospec(BehaviorEcephysSession, instance=True), + ) cache.get_ecephys_session(ecephys_session_id=ses_id) - v1_dir = cache_dir / f'{project_name}-{versions[0]}/data' + v1_dir = cache_dir / f"{project_name}-{versions[0]}/data" # Check that all expected file were downloaded - dir_glob = v1_dir.glob('*') + dir_glob = v1_dir.glob("*") file_names = set() file_contents = {} for p in dir_glob: file_names.add(p.name) - with open(p, 'rb') as in_file: + with open(p, "rb") as in_file: data = in_file.read() file_contents[p.name] = data - expected = {'ecephys_file_1.nwb', 'ecephys_file_2.nwb'} + expected = {"ecephys_file_1.nwb", "ecephys_file_2.nwb"} assert file_names == expected expected = {} - expected['ecephys_file_1.nwb'] = b'abcde' - expected['ecephys_file_2.nwb'] = b'fghijk' + expected["ecephys_file_1.nwb"] = b"abcde" + expected["ecephys_file_2.nwb"] = b"fghijk" assert file_contents == expected @mock_s3 @pytest.mark.parametrize("delete_cache", [True, False]) -def test_file_linkage( - tmpdir, - vbn_s3_cloud_cache_data, - delete_cache, - monkeypatch -): +def test_file_linkage(tmpdir, vbn_s3_cloud_cache_data, delete_cache, monkeypatch): """ Test that symlinks are used where appropriate @@ -259,15 +224,12 @@ def test_file_linkage( bucket_name = VisualBehaviorNeuropixelsProjectCache.BUCKET_NAME project_name = VisualBehaviorNeuropixelsProjectCache.PROJECT_NAME - create_bucket(bucket_name, - project_name, - data['data'], - data['metadata']) + create_bucket(bucket_name, project_name, data["data"], data["metadata"]) cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(cache_dir) - v_names = [f'{project_name}_manifest_v{i}.json' for i in versions] - v_dirs = [cache_dir / f'{project_name}-{i}/data' for i in versions] + v_names = [f"{project_name}_manifest_v{i}.json" for i in versions] + v_dirs = [cache_dir / f"{project_name}-{i}/data" for i in versions] assert cache.current_manifest() == v_names[-1] assert cache.list_all_downloaded_manifests() == [v_names[-1]] @@ -278,18 +240,18 @@ def test_file_linkage( for sess_id in (333, 444): with monkeypatch.context() as ctx: - ctx.setattr(BehaviorSession, 'from_nwb_path', - lambda path: create_autospec( - BehaviorSession, instance=True)) + ctx.setattr(BehaviorSession, "from_nwb_path", lambda path: create_autospec(BehaviorSession, instance=True)) cache.get_behavior_session(behavior_session_id=sess_id) for sess_id in (5111, 5112): with monkeypatch.context() as ctx: - ctx.setattr(BehaviorEcephysSession, 'from_nwb_path', - lambda path, probe_meta: create_autospec( - BehaviorEcephysSession, instance=True)) + ctx.setattr( + BehaviorEcephysSession, + "from_nwb_path", + lambda path, probe_meta: create_autospec(BehaviorEcephysSession, instance=True), + ) cache.get_ecephys_session(ecephys_session_id=sess_id) - v1_glob = v_dirs[0].glob('*') + v1_glob = v_dirs[0].glob("*") v1_paths = {} for p in v1_glob: v1_paths[p.name] = p @@ -308,30 +270,29 @@ def test_file_linkage( assert cache.current_manifest() == v_names[-1] for sess_id in (777, 888): with monkeypatch.context() as ctx: - ctx.setattr(BehaviorSession, 'from_nwb_path', - lambda path: create_autospec( - BehaviorSession, instance=True)) + ctx.setattr(BehaviorSession, "from_nwb_path", lambda path: create_autospec(BehaviorSession, instance=True)) cache.get_behavior_session(behavior_session_id=sess_id) for sess_id in (222, 333): with monkeypatch.context() as ctx: - ctx.setattr(BehaviorEcephysSession, 'from_nwb_path', - lambda path, probe_meta: create_autospec( - BehaviorEcephysSession, instance=True)) + ctx.setattr( + BehaviorEcephysSession, + "from_nwb_path", + lambda path, probe_meta: create_autospec(BehaviorEcephysSession, instance=True), + ) cache.get_ecephys_session(ecephys_session_id=sess_id) - v2_glob = v_dirs[-1].glob('*') + v2_glob = v_dirs[-1].glob("*") v2_paths = {} for p in v2_glob: v2_paths[p.name] = p # check symlinks - for name in ('ecephys_file_2.nwb',): - + for name in ("ecephys_file_2.nwb",): assert v2_paths[name].is_symlink() assert v2_paths[name].resolve() == v1_paths[name].resolve() assert v2_paths[name].absolute() != v1_paths[name].absolute() - name = 'ecephys_file_1.nwb' + name = "ecephys_file_1.nwb" assert not v2_paths[name].is_symlink() assert not v2_paths[name].absolute() == v1_paths[name].absolute() @@ -346,34 +307,26 @@ def test_when_data_updated(tmpdir, vbn_s3_cloud_cache_data, data_update): cache_dir = pathlib.Path(tmpdir) / "test_update" bucket_name = VisualBehaviorNeuropixelsProjectCache.BUCKET_NAME project_name = VisualBehaviorNeuropixelsProjectCache.PROJECT_NAME - create_bucket(bucket_name, - project_name, - data['data'], - data['metadata']) + create_bucket(bucket_name, project_name, data["data"], data["metadata"]) cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(cache_dir) del cache - client = boto3.client('s3', region_name='us-east-1') + client = boto3.client("s3", region_name="us-east-1") later_version = str(semver.parse_version_info(versions[-1]).bump_minor()) - load_dataset(data_update['data'], - data_update['metadata'], - later_version, - bucket_name, - project_name, - client) + load_dataset(data_update["data"], data_update["metadata"], later_version, bucket_name, project_name, client) - name3 = f'{project_name}_manifest_v{later_version}' - name2 = f'{project_name}_manifest_v{versions[-1]}' + name3 = f"{project_name}_manifest_v{later_version}" + name2 = f"{project_name}_manifest_v{versions[-1]}" - cmd = 'VisualBehaviorNeuropixelsProjectCache.load_manifest' + cmd = "VisualBehaviorNeuropixelsProjectCache.load_manifest" with pytest.warns(OutdatedManifestWarning, match=name3) as warnings: VisualBehaviorNeuropixelsProjectCache.from_s3_cache(cache_dir) checked_msg = False for w in warnings.list: - if w._category_name == 'OutdatedManifestWarning': + if w._category_name == "OutdatedManifestWarning": msg = str(w.message) assert name3 in msg assert name2 in msg @@ -393,20 +346,17 @@ def test_load_last(tmpdir, vbn_s3_cloud_cache_data, data_update): cache_dir = pathlib.Path(tmpdir) / "test_update" bucket_name = VisualBehaviorNeuropixelsProjectCache.BUCKET_NAME project_name = VisualBehaviorNeuropixelsProjectCache.PROJECT_NAME - create_bucket(bucket_name, - project_name, - data['data'], - data['metadata']) + create_bucket(bucket_name, project_name, data["data"], data["metadata"]) cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(cache_dir) - v_names = [f'{project_name}_manifest_v{i}.json' for i in versions] + v_names = [f"{project_name}_manifest_v{i}.json" for i in versions] assert cache.current_manifest() == v_names[-1] cache.load_manifest(v_names[0]) assert cache.current_manifest() == v_names[0] del cache - msg = 'VisualBehaviorNeuropixelsProjectCache.compare_manifests' + msg = "VisualBehaviorNeuropixelsProjectCache.compare_manifests" with pytest.warns(OutdatedManifestWarning, match=msg): cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(cache_dir) diff --git a/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_vbo_from_s3.py b/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_vbo_from_s3.py index a959a2c833..30bc1356df 100644 --- a/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_vbo_from_s3.py +++ b/allensdk/test/brain_observatory/behavior/behavior_project_cache/test_vbo_from_s3.py @@ -2,10 +2,8 @@ import pytest -from allensdk.brain_observatory.behavior.behavior_ophys_experiment import \ - BehaviorOphysExperiment -from allensdk.brain_observatory.behavior.behavior_session import \ - BehaviorSession +from allensdk.brain_observatory.behavior.behavior_ophys_experiment import BehaviorOphysExperiment +from allensdk.brain_observatory.behavior.behavior_session import BehaviorSession from .utils import create_bucket, load_dataset import boto3 from moto import mock_s3 @@ -15,27 +13,23 @@ from allensdk.api.cloud_cache.cloud_cache import MissingLocalManifestWarning from allensdk.api.cloud_cache.cloud_cache import OutdatedManifestWarning -from allensdk.brain_observatory.\ - behavior.behavior_project_cache.behavior_project_cache \ - import VisualBehaviorOphysProjectCache +from allensdk.brain_observatory.behavior.behavior_project_cache.behavior_project_cache import ( + VisualBehaviorOphysProjectCache, +) @mock_s3 def test_manifest_methods(tmpdir, vbo_s3_cloud_cache_data): - data, versions = vbo_s3_cloud_cache_data cache_dir = pathlib.Path(tmpdir) / "test_manifest_list" bucket_name = VisualBehaviorOphysProjectCache.BUCKET_NAME project_name = VisualBehaviorOphysProjectCache.PROJECT_NAME - create_bucket(bucket_name, - project_name, - data['data'], - data['metadata']) + create_bucket(bucket_name, project_name, data["data"], data["metadata"]) cache = VisualBehaviorOphysProjectCache.from_s3_cache(cache_dir) - v_names = [f'{project_name}_manifest_v{i}.json' for i in versions] + v_names = [f"{project_name}_manifest_v{i}.json" for i in versions] m_list = cache.list_manifest_file_names() assert len(m_list) == 2 @@ -50,43 +44,35 @@ def test_manifest_methods(tmpdir, vbo_s3_cloud_cache_data): change_msg = cache.compare_manifests(v_names[0], v_names[-1]) - for mname in ('behavior_session_table', - 'ophys_session_table', - 'ophys_experiment_table'): - assert f'project_metadata/{mname} changed' in change_msg + for mname in ("behavior_session_table", "ophys_session_table", "ophys_experiment_table"): + assert f"project_metadata/{mname} changed" in change_msg - assert 'ophys_file_1.nwb changed' in change_msg - assert 'ophys_file_5.nwb created' in change_msg - assert 'ophys_file_2.nwb' not in change_msg - assert 'behavior_file_3.nwb' not in change_msg - assert 'behavior_file_4.nwb' not in change_msg + assert "ophys_file_1.nwb changed" in change_msg + assert "ophys_file_5.nwb created" in change_msg + assert "ophys_file_2.nwb" not in change_msg + assert "behavior_file_3.nwb" not in change_msg + assert "behavior_file_4.nwb" not in change_msg @mock_s3 -def test_local_cache_construction( - tmpdir, - vbo_s3_cloud_cache_data, - monkeypatch -): - +def test_local_cache_construction(tmpdir, vbo_s3_cloud_cache_data, monkeypatch): data, versions = vbo_s3_cloud_cache_data cache_dir = pathlib.Path(tmpdir) / "test_construction" bucket_name = VisualBehaviorOphysProjectCache.BUCKET_NAME project_name = VisualBehaviorOphysProjectCache.PROJECT_NAME - create_bucket(bucket_name, - project_name, - data['data'], - data['metadata']) + create_bucket(bucket_name, project_name, data["data"], data["metadata"]) cache = VisualBehaviorOphysProjectCache.from_s3_cache(cache_dir) - v_names = [f'{project_name}_manifest_v{i}.json' for i in versions] + v_names = [f"{project_name}_manifest_v{i}.json" for i in versions] cache.load_manifest(v_names[0]) with monkeypatch.context() as ctx: - ctx.setattr(BehaviorOphysExperiment, 'from_nwb_path', - lambda nwb_path: create_autospec( - BehaviorOphysExperiment, instance=True)) + ctx.setattr( + BehaviorOphysExperiment, + "from_nwb_path", + lambda nwb_path: create_autospec(BehaviorOphysExperiment, instance=True), + ) cache.get_behavior_ophys_experiment(ophys_experiment_id=5111) assert cache.fetch_api.cache._downloaded_data_path.is_file() cache.fetch_api.cache._downloaded_data_path.unlink() @@ -96,10 +82,12 @@ def test_local_cache_construction( with pytest.warns(MissingLocalManifestWarning) as warnings: cache = VisualBehaviorOphysProjectCache.from_s3_cache(cache_dir) - cmd = 'VisualBehaviorOphysProjectCache.construct_local_manifest()' + cmd = "VisualBehaviorOphysProjectCache.construct_local_manifest()" warning_msgs = [ - f'{warnings[i].message}' for i in range(len(warnings)) - if type(warnings[i].message) is MissingLocalManifestWarning] + f"{warnings[i].message}" + for i in range(len(warnings)) + if type(warnings[i].message) is MissingLocalManifestWarning + ] assert any([cmd in msg for msg in warning_msgs]) # Because, at the point where the cache was reconstitute, @@ -111,19 +99,15 @@ def test_local_cache_construction( cache.construct_local_manifest() assert cache.fetch_api.cache._downloaded_data_path.is_file() - with open(manifest_path, 'rb') as in_file: + with open(manifest_path, "rb") as in_file: local_manifest = json.load(in_file) fnames = set([pathlib.Path(k).name for k in local_manifest]) - assert 'ophys_file_1.nwb' in fnames + assert "ophys_file_1.nwb" in fnames assert len(local_manifest) == 9 # 8 metadata files and 1 data file @mock_s3 -def test_load_out_of_date_manifest( - tmpdir, - vbo_s3_cloud_cache_data, - monkeypatch -): +def test_load_out_of_date_manifest(tmpdir, vbo_s3_cloud_cache_data, monkeypatch): """ Test that VisualBehaviorOphysProjectCache can load a manifest other than the latest and download files @@ -134,62 +118,54 @@ def test_load_out_of_date_manifest( cache_dir = pathlib.Path(tmpdir) / "test_linkage" bucket_name = VisualBehaviorOphysProjectCache.BUCKET_NAME project_name = VisualBehaviorOphysProjectCache.PROJECT_NAME - create_bucket(bucket_name, - project_name, - data['data'], - data['metadata']) + create_bucket(bucket_name, project_name, data["data"], data["metadata"]) cache = VisualBehaviorOphysProjectCache.from_s3_cache(cache_dir) - v_names = [f'{project_name}_manifest_v{i}.json' for i in versions] + v_names = [f"{project_name}_manifest_v{i}.json" for i in versions] cache.load_manifest(v_names[0]) for sess_id in (333, 444): with monkeypatch.context() as ctx: - ctx.setattr(BehaviorSession, 'from_nwb_path', - lambda nwb_path: - create_autospec( - BehaviorSession, instance=True)) + ctx.setattr( + BehaviorSession, "from_nwb_path", lambda nwb_path: create_autospec(BehaviorSession, instance=True) + ) cache.get_behavior_session(behavior_session_id=sess_id) for exp_id in (5111, 5222): with monkeypatch.context() as ctx: - ctx.setattr(BehaviorOphysExperiment, 'from_nwb_path', - lambda nwb_path: create_autospec( - BehaviorOphysExperiment, instance=True)) + ctx.setattr( + BehaviorOphysExperiment, + "from_nwb_path", + lambda nwb_path: create_autospec(BehaviorOphysExperiment, instance=True), + ) cache.get_behavior_ophys_experiment(ophys_experiment_id=exp_id) - v1_dir = cache_dir / f'{project_name}-{versions[0]}/data' + v1_dir = cache_dir / f"{project_name}-{versions[0]}/data" # Check that all expected file were downloaded - dir_glob = v1_dir.glob('*') + dir_glob = v1_dir.glob("*") file_names = set() file_contents = {} for p in dir_glob: file_names.add(p.name) - with open(p, 'rb') as in_file: + with open(p, "rb") as in_file: data = in_file.read() file_contents[p.name] = data - expected = {'ophys_file_1.nwb', 'ophys_file_2.nwb', - 'behavior_file_3.nwb', 'behavior_file_4.nwb'} + expected = {"ophys_file_1.nwb", "ophys_file_2.nwb", "behavior_file_3.nwb", "behavior_file_4.nwb"} assert file_names == expected expected = {} - expected['ophys_file_1.nwb'] = b'abcde' - expected['ophys_file_2.nwb'] = b'fghijk' - expected['behavior_file_3.nwb'] = b'12345' - expected['behavior_file_4.nwb'] = b'67890' + expected["ophys_file_1.nwb"] = b"abcde" + expected["ophys_file_2.nwb"] = b"fghijk" + expected["behavior_file_3.nwb"] = b"12345" + expected["behavior_file_4.nwb"] = b"67890" assert file_contents == expected @mock_s3 @pytest.mark.parametrize("delete_cache", [True, False]) -def test_file_linkage( - tmpdir, - vbo_s3_cloud_cache_data, - delete_cache, - monkeypatch -): +def test_file_linkage(tmpdir, vbo_s3_cloud_cache_data, delete_cache, monkeypatch): """ Test that symlinks are used where appropriate @@ -202,15 +178,12 @@ def test_file_linkage( cache_dir = pathlib.Path(tmpdir) / "test_linkage" bucket_name = VisualBehaviorOphysProjectCache.BUCKET_NAME project_name = VisualBehaviorOphysProjectCache.PROJECT_NAME - create_bucket(bucket_name, - project_name, - data['data'], - data['metadata']) + create_bucket(bucket_name, project_name, data["data"], data["metadata"]) cache = VisualBehaviorOphysProjectCache.from_s3_cache(cache_dir) - v_names = [f'{project_name}_manifest_v{i}.json' for i in versions] - v_dirs = [cache_dir / f'{project_name}-{i}/data' for i in versions] + v_names = [f"{project_name}_manifest_v{i}.json" for i in versions] + v_dirs = [cache_dir / f"{project_name}-{i}/data" for i in versions] assert cache.current_manifest() == v_names[-1] assert cache.list_all_downloaded_manifests() == [v_names[-1]] @@ -221,19 +194,20 @@ def test_file_linkage( for sess_id in (333, 444): with monkeypatch.context() as ctx: - ctx.setattr(BehaviorSession, 'from_nwb_path', - lambda nwb_path: - create_autospec( - BehaviorSession, instance=True)) + ctx.setattr( + BehaviorSession, "from_nwb_path", lambda nwb_path: create_autospec(BehaviorSession, instance=True) + ) cache.get_behavior_session(behavior_session_id=sess_id) for exp_id in (5111, 5222): with monkeypatch.context() as ctx: - ctx.setattr(BehaviorOphysExperiment, 'from_nwb_path', - lambda nwb_path: create_autospec( - BehaviorOphysExperiment, instance=True)) + ctx.setattr( + BehaviorOphysExperiment, + "from_nwb_path", + lambda nwb_path: create_autospec(BehaviorOphysExperiment, instance=True), + ) cache.get_behavior_ophys_experiment(ophys_experiment_id=exp_id) - v1_glob = v_dirs[0].glob('*') + v1_glob = v_dirs[0].glob("*") v1_paths = {} for p in v1_glob: v1_paths[p.name] = p @@ -252,37 +226,35 @@ def test_file_linkage( assert cache.current_manifest() == v_names[-1] for sess_id in (777, 888): with monkeypatch.context() as ctx: - ctx.setattr(BehaviorSession, 'from_nwb_path', - lambda nwb_path: - create_autospec( - BehaviorSession, instance=True)) + ctx.setattr( + BehaviorSession, "from_nwb_path", lambda nwb_path: create_autospec(BehaviorSession, instance=True) + ) cache.get_behavior_session(behavior_session_id=sess_id) for exp_id in (5444, 5666, 5777): with monkeypatch.context() as ctx: - ctx.setattr(BehaviorOphysExperiment, 'from_nwb_path', - lambda nwb_path: create_autospec( - BehaviorOphysExperiment, instance=True)) + ctx.setattr( + BehaviorOphysExperiment, + "from_nwb_path", + lambda nwb_path: create_autospec(BehaviorOphysExperiment, instance=True), + ) cache.get_behavior_ophys_experiment(ophys_experiment_id=exp_id) - v2_glob = v_dirs[-1].glob('*') + v2_glob = v_dirs[-1].glob("*") v2_paths = {} for p in v2_glob: v2_paths[p.name] = p # check symlinks - for name in ('ophys_file_2.nwb', - 'behavior_file_3.nwb', - 'behavior_file_4.nwb'): - + for name in ("ophys_file_2.nwb", "behavior_file_3.nwb", "behavior_file_4.nwb"): assert v2_paths[name].is_symlink() assert v2_paths[name].resolve() == v1_paths[name].resolve() assert v2_paths[name].absolute() != v1_paths[name].absolute() - name = 'ophys_file_1.nwb' + name = "ophys_file_1.nwb" assert not v2_paths[name].is_symlink() assert not v2_paths[name].absolute() == v1_paths[name].absolute() - assert 'ophys_file_5.nwb' in v2_paths + assert "ophys_file_5.nwb" in v2_paths @mock_s3 @@ -295,34 +267,26 @@ def test_when_data_updated(tmpdir, vbo_s3_cloud_cache_data, data_update): cache_dir = pathlib.Path(tmpdir) / "test_update" bucket_name = VisualBehaviorOphysProjectCache.BUCKET_NAME project_name = VisualBehaviorOphysProjectCache.PROJECT_NAME - create_bucket(bucket_name, - project_name, - data['data'], - data['metadata']) + create_bucket(bucket_name, project_name, data["data"], data["metadata"]) cache = VisualBehaviorOphysProjectCache.from_s3_cache(cache_dir) del cache - client = boto3.client('s3', region_name='us-east-1') + client = boto3.client("s3", region_name="us-east-1") later_version = str(semver.parse_version_info(versions[-1]).bump_minor()) - load_dataset(data_update['data'], - data_update['metadata'], - later_version, - bucket_name, - project_name, - client) + load_dataset(data_update["data"], data_update["metadata"], later_version, bucket_name, project_name, client) - name3 = f'{project_name}_manifest_v{later_version}' - name2 = f'{project_name}_manifest_v{versions[-1]}' + name3 = f"{project_name}_manifest_v{later_version}" + name2 = f"{project_name}_manifest_v{versions[-1]}" - cmd = 'VisualBehaviorOphysProjectCache.load_manifest' + cmd = "VisualBehaviorOphysProjectCache.load_manifest" with pytest.warns(OutdatedManifestWarning, match=name3) as warnings: VisualBehaviorOphysProjectCache.from_s3_cache(cache_dir) checked_msg = False for w in warnings.list: - if w._category_name == 'OutdatedManifestWarning': + if w._category_name == "OutdatedManifestWarning": msg = str(w.message) assert name3 in msg assert name2 in msg @@ -342,20 +306,17 @@ def test_load_last(tmpdir, vbo_s3_cloud_cache_data, data_update): cache_dir = pathlib.Path(tmpdir) / "test_update" bucket_name = VisualBehaviorOphysProjectCache.BUCKET_NAME project_name = VisualBehaviorOphysProjectCache.PROJECT_NAME - create_bucket(bucket_name, - project_name, - data['data'], - data['metadata']) + create_bucket(bucket_name, project_name, data["data"], data["metadata"]) cache = VisualBehaviorOphysProjectCache.from_s3_cache(cache_dir) - v_names = [f'{project_name}_manifest_v{i}.json' for i in versions] + v_names = [f"{project_name}_manifest_v{i}.json" for i in versions] assert cache.current_manifest() == v_names[-1] cache.load_manifest(v_names[0]) assert cache.current_manifest() == v_names[0] del cache - msg = 'VisualBehaviorOphysProjectCache.compare_manifests' + msg = "VisualBehaviorOphysProjectCache.compare_manifests" with pytest.warns(OutdatedManifestWarning, match=msg): cache = VisualBehaviorOphysProjectCache.from_s3_cache(cache_dir) diff --git a/allensdk/test/brain_observatory/behavior/behavior_project_cache/utils.py b/allensdk/test/brain_observatory/behavior/behavior_project_cache/utils.py index 0a51ff20ad..f47ea2cc43 100644 --- a/allensdk/test/brain_observatory/behavior/behavior_project_cache/utils.py +++ b/allensdk/test/brain_observatory/behavior/behavior_project_cache/utils.py @@ -4,12 +4,14 @@ import hashlib -def load_dataset(data_blobs: dict, - metadata_blobs: Union[dict, None], - manifest_version: str, - bucket_name: str, - project_name: str, - client: boto3.client) -> None: +def load_dataset( + data_blobs: dict, + metadata_blobs: Union[dict, None], + manifest_version: str, + bucket_name: str, + project_name: str, + client: boto3.client, +) -> None: """ Load a test dataset into moto's mocked S3 @@ -41,74 +43,63 @@ def load_dataset(data_blobs: dict, """ for fname in data_blobs: - client.put_object(Bucket=bucket_name, - Key=f'{project_name}/data/{fname}', - Body=data_blobs[fname]['data']) + client.put_object(Bucket=bucket_name, Key=f"{project_name}/data/{fname}", Body=data_blobs[fname]["data"]) if metadata_blobs is not None: for fname in metadata_blobs: - client.put_object(Bucket=bucket_name, - Key=f'{project_name}/project_metadata/{fname}', - Body=metadata_blobs[fname]) + client.put_object( + Bucket=bucket_name, Key=f"{project_name}/project_metadata/{fname}", Body=metadata_blobs[fname] + ) response = client.list_object_versions(Bucket=bucket_name) fname_to_version = {} - for obj in response['Versions']: - if obj['IsLatest']: - fname = obj['Key'].split('/')[-1] - fname_to_version[fname] = obj['VersionId'] + for obj in response["Versions"]: + if obj["IsLatest"]: + fname = obj["Key"].split("/")[-1] + fname_to_version[fname] = obj["VersionId"] manifest = {} - manifest['manifest_version'] = manifest_version - manifest['project_name'] = project_name - manifest['metadata_file_id_column_name'] = 'file_id' - manifest['metadata_files'] = {} - manifest['data_pipeline'] = [{'name': 'AllenSDK', 'version': '1.1.1'}] + manifest["manifest_version"] = manifest_version + manifest["project_name"] = project_name + manifest["metadata_file_id_column_name"] = "file_id" + manifest["metadata_files"] = {} + manifest["data_pipeline"] = [{"name": "AllenSDK", "version": "1.1.1"}] data_file_dict = {} - url_root = f'http://{bucket_name}.s3.amazonaws.com/{project_name}/data' + url_root = f"http://{bucket_name}.s3.amazonaws.com/{project_name}/data" for fname in data_blobs: - url = f'{url_root}/{fname}' + url = f"{url_root}/{fname}" hasher = hashlib.blake2b() - hasher.update(data_blobs[fname]['data']) + hasher.update(data_blobs[fname]["data"]) checksum = hasher.hexdigest() - data_file = {'url': url, - 'version_id': fname_to_version[fname], - 'file_hash': checksum} + data_file = {"url": url, "version_id": fname_to_version[fname], "file_hash": checksum} - data_file_dict[data_blobs[fname]['file_id']] = data_file + data_file_dict[data_blobs[fname]["file_id"]] = data_file - manifest['data_files'] = data_file_dict + manifest["data_files"] = data_file_dict if metadata_blobs is not None: - url_root = f'http://{bucket_name}.s3.amazonaws.com/{project_name}/' - url_root += 'project_metadata' + url_root = f"http://{bucket_name}.s3.amazonaws.com/{project_name}/" + url_root += "project_metadata" metadata_dict = {} for fname in metadata_blobs: - url = f'{url_root}/{fname}' + url = f"{url_root}/{fname}" hasher = hashlib.blake2b() hasher.update(metadata_blobs[fname]) - metadata_dict[fname] = {'url': url, - 'file_hash': hasher.hexdigest(), - 'version_id': fname_to_version[fname]} + metadata_dict[fname] = {"url": url, "file_hash": hasher.hexdigest(), "version_id": fname_to_version[fname]} - manifest['metadata_files'] = metadata_dict + manifest["metadata_files"] = metadata_dict - manifest_k = f'{project_name}/manifests/' - manifest_k += f'{project_name}_manifest_v{manifest_version}.json' - client.put_object(Bucket=bucket_name, - Key=manifest_k, - Body=bytes(json.dumps(manifest), 'utf-8')) + manifest_k = f"{project_name}/manifests/" + manifest_k += f"{project_name}_manifest_v{manifest_version}.json" + client.put_object(Bucket=bucket_name, Key=manifest_k, Body=bytes(json.dumps(manifest), "utf-8")) return None -def create_bucket(test_bucket_name: str, - project_name: str, - datasets: dict, - metadatasets: dict) -> None: +def create_bucket(test_bucket_name: str, project_name: str, datasets: dict, metadatasets: dict) -> None: """ Create a bucket and populate it with example datasets @@ -129,14 +120,14 @@ def create_bucket(test_bucket_name: str, metadata files to be loaded to the bucket (default: None) """ - conn = boto3.resource('s3', region_name='us-east-1') - conn.create_bucket(Bucket=test_bucket_name, ACL='public-read') + conn = boto3.resource("s3", region_name="us-east-1") + conn.create_bucket(Bucket=test_bucket_name, ACL="public-read") # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#bucketversioning bucket_versioning = conn.BucketVersioning(test_bucket_name) bucket_versioning.enable() - client = boto3.client('s3', region_name='us-east-1') + client = boto3.client("s3", region_name="us-east-1") # upload first dataset for v in datasets.keys(): @@ -144,11 +135,6 @@ def create_bucket(test_bucket_name: str, m = metadatasets[v] else: m = None - load_dataset(datasets[v], - m, - v, - test_bucket_name, - project_name, - client) + load_dataset(datasets[v], m, v, test_bucket_name, project_name, client) return None diff --git a/allensdk/test/brain_observatory/behavior/behavior_project_cache_data_model/conftest.py b/allensdk/test/brain_observatory/behavior/behavior_project_cache_data_model/conftest.py index 6624eaa156..ba782e280f 100644 --- a/allensdk/test/brain_observatory/behavior/behavior_project_cache_data_model/conftest.py +++ b/allensdk/test/brain_observatory/behavior/behavior_project_cache_data_model/conftest.py @@ -5,30 +5,25 @@ import pandas as pd import tempfile -from allensdk.brain_observatory.behavior.behavior_project_cache \ - import VisualBehaviorOphysProjectCache +from allensdk.brain_observatory.behavior.behavior_project_cache import VisualBehaviorOphysProjectCache -from allensdk.brain_observatory.behavior.behavior_project_cache.\ - tables.util.experiments_table_utils import ( - add_passive_flag_to_ophys_experiment_table, - add_image_set_to_experiment_table) +from allensdk.brain_observatory.behavior.behavior_project_cache.tables.util.experiments_table_utils import ( + add_passive_flag_to_ophys_experiment_table, + add_image_set_to_experiment_table, +) from allensdk.brain_observatory.behavior.behavior_project_cache.tables.util.prior_exposure_processing import ( # noqa: E501 add_experience_level_ophys, ) -from allensdk.brain_observatory.behavior.behavior_project_cache.tables \ - .util.prior_exposure_processing import \ - get_prior_exposures_to_session_type, \ - get_prior_exposures_to_image_set, \ - get_prior_exposures_to_omissions -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .subject_metadata.full_genotype import \ - FullGenotype -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .subject_metadata.reporter_line import \ - ReporterLine - - -@pytest.fixture(scope='session') +from allensdk.brain_observatory.behavior.behavior_project_cache.tables.util.prior_exposure_processing import ( + get_prior_exposures_to_session_type, + get_prior_exposures_to_image_set, + get_prior_exposures_to_omissions, +) +from allensdk.brain_observatory.behavior.data_objects.metadata.subject_metadata.full_genotype import FullGenotype +from allensdk.brain_observatory.behavior.data_objects.metadata.subject_metadata.reporter_line import ReporterLine + + +@pytest.fixture(scope="session") def behavior_session_id_list(): """ List of behavior_session_id; the most fundamental fixture @@ -36,115 +31,94 @@ def behavior_session_id_list(): return list(range(1, 9)) -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def session_name_lookup(behavior_session_id_list): """ Dict mapping behavior_session_id to session_name """ - return {ii: f'session_{ii}' - for ii in behavior_session_id_list} + return {ii: f"session_{ii}" for ii in behavior_session_id_list} -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def date_of_acquisition_lookup(behavior_session_id_list): """ Dict mapping behavior_session_id to date of acquisition """ - return {ii: np.datetime64(f'2020-02-{ii:02d}') - for ii in behavior_session_id_list} + return {ii: np.datetime64(f"2020-02-{ii:02d}") for ii in behavior_session_id_list} -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def session_type_lookup(behavior_session_id_list): """ Dict mapping behavior_session_id to session_type """ rng = np.random.default_rng(871231) - possible = ('TRAINING_1_gratings', - 'OPHYS_1_images_A', - 'OPHYS_1_images_B') + possible = ("TRAINING_1_gratings", "OPHYS_1_images_A", "OPHYS_1_images_B") - vals = rng.choice(possible, - size=len(behavior_session_id_list), - replace=True) + vals = rng.choice(possible, size=len(behavior_session_id_list), replace=True) - return {ii: vv - for ii, vv in zip(behavior_session_id_list, - vals)} + return {ii: vv for ii, vv in zip(behavior_session_id_list, vals)} -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def project_code_lookup(behavior_session_id_list): """ Dict mapping behavior_session_id to project_code """ - return {ii: 'code{ii}' - for ii in behavior_session_id_list} + return {ii: "code{ii}" for ii in behavior_session_id_list} -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def specimen_id_lookup(behavior_session_id_list): """ Dict mapping behavior_session_id to specimen_id """ - return {ii: 1111*ii - for ii in behavior_session_id_list} + return {ii: 1111 * ii for ii in behavior_session_id_list} -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def genotype_lookup(behavior_session_id_list): """ Dict mapping behavior_session_id to full_genotype """ rng = np.random.default_rng(981232) - possible = ('foo-SlcCre', - 'Vip-IRES-Cre/wt;Ai148(TIT2L-GC6f-ICL-tTA2)/wt', - 'bar', - 'foobar') - chosen = rng.choice(possible, - size=len(behavior_session_id_list), - replace=True) - return {ii: val - for ii, val in zip(behavior_session_id_list, chosen)} - - -@pytest.fixture(scope='session') + possible = ("foo-SlcCre", "Vip-IRES-Cre/wt;Ai148(TIT2L-GC6f-ICL-tTA2)/wt", "bar", "foobar") + chosen = rng.choice(possible, size=len(behavior_session_id_list), replace=True) + return {ii: val for ii, val in zip(behavior_session_id_list, chosen)} + + +@pytest.fixture(scope="session") def reporter_lookup(behavior_session_id_list): """ Dict mapping behavior_session_id to reporter_line """ - return {ii: f"Ai{90+ii}(TITL-GCaMP6f)" - for ii in behavior_session_id_list} + return {ii: f"Ai{90 + ii}(TITL-GCaMP6f)" for ii in behavior_session_id_list} -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def driver_lookup(behavior_session_id_list): """ Dict mapping behavior_session_id to driver_line. Note: driver_line is a list of strings """ rng = np.random.default_rng(1723213) - possible = [["aa"], - ["aa", "bb"], - ["cc"], - ["cc", "dd"]] - indices = rng.integers(0, len(possible), - size=len(behavior_session_id_list)) + possible = [["aa"], ["aa", "bb"], ["cc"], ["cc", "dd"]] + indices = rng.integers(0, len(possible), size=len(behavior_session_id_list)) chosen = [possible[i] for i in indices] - return {ii: val - for ii, val in zip(behavior_session_id_list, - chosen)} + return {ii: val for ii, val in zip(behavior_session_id_list, chosen)} -@pytest.fixture(scope='session') -def behavior_session_data_fixture(behavior_session_id_list, - session_name_lookup, - date_of_acquisition_lookup, - session_type_lookup, - specimen_id_lookup, - genotype_lookup, - reporter_lookup, - driver_lookup): +@pytest.fixture(scope="session") +def behavior_session_data_fixture( + behavior_session_id_list, + session_name_lookup, + date_of_acquisition_lookup, + session_type_lookup, + specimen_id_lookup, + genotype_lookup, + reporter_lookup, + driver_lookup, +): """ List of dicts. Each dict is an entry in the raw behavior_session_table as would be returned by the @@ -153,7 +127,6 @@ def behavior_session_data_fixture(behavior_session_id_list, behavior_session_list = [] for s_id in behavior_session_id_list: - genotype = genotype_lookup[s_id] driver = driver_lookup[s_id] reporter = reporter_lookup[s_id] @@ -161,20 +134,22 @@ def behavior_session_data_fixture(behavior_session_id_list, specimen_id = specimen_id_lookup[s_id] s_name = session_name_lookup[s_id] s_type = session_type_lookup[s_id] - datum = {'behavior_session_id': s_id, - 'session_name': s_name, - 'date_of_acquisition': date, - 'specimen_id': specimen_id, - 'session_type': s_type, - 'equipment_name': 'MESO2.0', - 'donor_id': 20+s_id, - 'full_genotype': genotype, - 'sex': ['m', 'f'][s_id % 2], - 'age_in_days': s_id*7, - 'foraging_id': s_id+30, - 'mouse_id': s_id+40, - 'reporter_line': reporter, - 'driver_line': driver} + datum = { + "behavior_session_id": s_id, + "session_name": s_name, + "date_of_acquisition": date, + "specimen_id": specimen_id, + "session_type": s_type, + "equipment_name": "MESO2.0", + "donor_id": 20 + s_id, + "full_genotype": genotype, + "sex": ["m", "f"][s_id % 2], + "age_in_days": s_id * 7, + "foraging_id": s_id + 30, + "mouse_id": s_id + 40, + "reporter_line": reporter, + "driver_line": driver, + } behavior_session_list.append(datum) @@ -191,16 +166,14 @@ def behavior_session_table(behavior_session_data_fixture): index = [] for datum in behavior_session_data_fixture: datum = copy.deepcopy(datum) - index.append(datum.pop('behavior_session_id')) + index.append(datum.pop("behavior_session_id")) data.append(datum) - df = pd.DataFrame( - data, - index=pd.Index(index, name='behavior_session_id')) + df = pd.DataFrame(data, index=pd.Index(index, name="behavior_session_id")) return df -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def behavior_session_to_ophys_session_map(behavior_session_id_list): """ Dict mapping behavior_session_id to ophys_session_id. @@ -217,7 +190,7 @@ def behavior_session_to_ophys_session_map(behavior_session_id_list): return lookup -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def ophys_session_to_experiment_map(behavior_session_to_ophys_session_map): """ Dict mapping ophys_session_id to a list of ophys_experiment_ids @@ -229,13 +202,13 @@ def ophys_session_to_experiment_map(behavior_session_to_ophys_session_map): ophys_vals = list(behavior_session_to_ophys_session_map.values()) ophys_vals.sort() for ii in ophys_vals: - lookup[ii] = list(range(i0, i0+dd)) + lookup[ii] = list(range(i0, i0 + dd)) i0 += dd return lookup -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def ophys_experiment_to_container_map(ophys_session_to_experiment_map): """ Dict mapping ophys_experiment_id to a list of ophys_container_ids @@ -256,7 +229,7 @@ def ophys_experiment_to_container_map(ophys_session_to_experiment_map): return lookup -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def container_state_lookup(ophys_experiment_to_container_map): """ Dict mapping ophys_container_id to container_workflow_state @@ -271,13 +244,13 @@ def container_state_lookup(ophys_experiment_to_container_map): local_container_list = ophys_experiment_to_container_map[exp_id] for container_id in local_container_list: assert container_id not in lookup - lookup[container_id] = 'junk' + lookup[container_id] = "junk" good_container = rng.choice(local_container_list) - lookup[good_container] = 'published' + lookup[good_container] = "published" return lookup -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def experiment_state_lookup(ophys_session_data_fixture): """ Dict mapping ophys_experiment_id to the experiment_workflow_state @@ -285,24 +258,26 @@ def experiment_state_lookup(ophys_session_data_fixture): rng = np.random.default_rng(772312) exp_id_list = [] for datum in ophys_session_data_fixture: - for exp_id in datum['ophys_experiment_id']: + for exp_id in datum["ophys_experiment_id"]: if exp_id not in exp_id_list: exp_id_list.append(exp_id) lookup = dict() for exp_id in exp_id_list: - lookup[exp_id] = ['passed', 'failed'][rng.integers(0, 2)] + lookup[exp_id] = ["passed", "failed"][rng.integers(0, 2)] return lookup -@pytest.fixture(scope='session') -def ophys_session_data_fixture(project_code_lookup, - session_name_lookup, - date_of_acquisition_lookup, - specimen_id_lookup, - session_type_lookup, - ophys_session_to_experiment_map, - ophys_experiment_to_container_map, - behavior_session_to_ophys_session_map): +@pytest.fixture(scope="session") +def ophys_session_data_fixture( + project_code_lookup, + session_name_lookup, + date_of_acquisition_lookup, + specimen_id_lookup, + session_type_lookup, + ophys_session_to_experiment_map, + ophys_experiment_to_container_map, + behavior_session_to_ophys_session_map, +): """ List of dicts. Each dict is one entry in the ophys_session_table as returned @@ -318,16 +293,17 @@ def ophys_session_data_fixture(project_code_lookup, for exp_id in ophys_session_to_experiment_map[o_session]: container_list += ophys_experiment_to_container_map[exp_id] - datum = {'behavior_session_id': beh, - 'project_code': project_code_lookup[beh], - 'date_of_acquisition': date_of_acquisition_lookup[beh], - 'session_name': session_name_lookup[beh], - 'session_type': session_type_lookup[beh], - 'ophys_experiment_id': - ophys_session_to_experiment_map[o_session], - 'ophys_container_id': container_list, - 'specimen_id': 9*beh, - 'ophys_session_id': o_session} + datum = { + "behavior_session_id": beh, + "project_code": project_code_lookup[beh], + "date_of_acquisition": date_of_acquisition_lookup[beh], + "session_name": session_name_lookup[beh], + "session_type": session_type_lookup[beh], + "ophys_experiment_id": ophys_session_to_experiment_map[o_session], + "ophys_container_id": container_list, + "specimen_id": 9 * beh, + "ophys_session_id": o_session, + } ophys_session_list.append(datum) return ophys_session_list @@ -342,20 +318,17 @@ def ophys_session_table(ophys_session_data_fixture): index = [] for datum in ophys_session_data_fixture: datum = copy.deepcopy(datum) - index.append(datum.pop('ophys_session_id')) + index.append(datum.pop("ophys_session_id")) data.append(datum) - df = pd.DataFrame( - data, - index=pd.Index(index, name='ophys_session_id')) + df = pd.DataFrame(data, index=pd.Index(index, name="ophys_session_id")) return df -@pytest.fixture(scope='session') -def ophys_experiment_data_fixture(ophys_session_data_fixture, - experiment_state_lookup, - container_state_lookup, - ophys_experiment_to_container_map): +@pytest.fixture(scope="session") +def ophys_experiment_data_fixture( + ophys_session_data_fixture, experiment_state_lookup, container_state_lookup, ophys_experiment_to_container_map +): """ List of dicts. Each dict is an entry in the ophys_experiment_table as returned @@ -366,27 +339,24 @@ def ophys_experiment_data_fixture(ophys_session_data_fixture, isi_id = 4000 ophys_experiment_list = [] for ophys_session in ophys_session_data_fixture: - for i_experiment in ophys_session['ophys_experiment_id']: + for i_experiment in ophys_session["ophys_experiment_id"]: cntr_id_list = ophys_experiment_to_container_map[i_experiment] for container_id in cntr_id_list: datum = { - 'ophys_session_id': ophys_session['ophys_session_id'], - 'session_type': ophys_session['session_type'], - 'behavior_session_id': - ophys_session['behavior_session_id'], - 'ophys_container_id': container_id, - 'container_workflow_state': - container_state_lookup[container_id], - 'experiment_workflow_state': - experiment_state_lookup[i_experiment], - 'session_name': ophys_session['session_name'], - 'date_of_acquisition': - ophys_session['date_of_acquisition'], - 'isi_experiment_id': isi_id, - 'imaging_depth': rng.integers(50, 200), - 'targeted_tructure': 'VISp', - 'published_at': ophys_session['date_of_acquisition'], - 'ophys_experiment_id': i_experiment} + "ophys_session_id": ophys_session["ophys_session_id"], + "session_type": ophys_session["session_type"], + "behavior_session_id": ophys_session["behavior_session_id"], + "ophys_container_id": container_id, + "container_workflow_state": container_state_lookup[container_id], + "experiment_workflow_state": experiment_state_lookup[i_experiment], + "session_name": ophys_session["session_name"], + "date_of_acquisition": ophys_session["date_of_acquisition"], + "isi_experiment_id": isi_id, + "imaging_depth": rng.integers(50, 200), + "targeted_tructure": "VISp", + "published_at": ophys_session["date_of_acquisition"], + "ophys_experiment_id": i_experiment, + } ophys_experiment_list.append(datum) return ophys_experiment_list @@ -401,18 +371,15 @@ def ophys_experiments_table(ophys_experiment_data_fixture): index = [] for datum in ophys_experiment_data_fixture: datum = copy.deepcopy(datum) - index.append(datum.pop('ophys_experiment_id')) + index.append(datum.pop("ophys_experiment_id")) data.append(datum) - df = pd.DataFrame( - data, - index=pd.Index(index, name='ophys_experiment_id')) + df = pd.DataFrame(data, index=pd.Index(index, name="ophys_experiment_id")) return df @pytest.fixture() -def intermediate_behavior_table(behavior_session_table, - mock_api): +def intermediate_behavior_table(behavior_session_table, mock_api): """ A dataframe created by adding/transfrming columns in behavior_session_table. This table is used to produce the @@ -420,32 +387,26 @@ def intermediate_behavior_table(behavior_session_table, """ df = behavior_session_table.copy(deep=True) - df['reporter_line'] = df['reporter_line'].apply( - ReporterLine.parse) - df['cre_line'] = df['full_genotype'].apply( - lambda x: FullGenotype(full_genotype=x).parse_cre_line()) - df['indicator'] = df['reporter_line'].apply( - lambda x: ReporterLine(reporter_line=x).parse_indicator()) - - df['prior_exposures_to_session_type'] = \ - get_prior_exposures_to_session_type(df=df) - df['prior_exposures_to_image_set'] = \ - get_prior_exposures_to_image_set(df=df) - df['prior_exposures_to_omissions'] = \ - get_prior_exposures_to_omissions( - df=df, - fetch_api=mock_api) + df["reporter_line"] = df["reporter_line"].apply(ReporterLine.parse) + df["cre_line"] = df["full_genotype"].apply(lambda x: FullGenotype(full_genotype=x).parse_cre_line()) + df["indicator"] = df["reporter_line"].apply(lambda x: ReporterLine(reporter_line=x).parse_indicator()) + + df["prior_exposures_to_session_type"] = get_prior_exposures_to_session_type(df=df) + df["prior_exposures_to_image_set"] = get_prior_exposures_to_image_set(df=df) + df["prior_exposures_to_omissions"] = get_prior_exposures_to_omissions(df=df, fetch_api=mock_api) return df @pytest.fixture() -def expected_behavior_session_table(intermediate_behavior_table, - ophys_session_data_fixture, - mock_api, - container_state_lookup, - experiment_state_lookup, - ophys_experiment_to_container_map, - request): +def expected_behavior_session_table( + intermediate_behavior_table, + ophys_session_data_fixture, + mock_api, + container_state_lookup, + experiment_state_lookup, + ophys_experiment_to_container_map, + request, +): """ The behavior_session_table as returned by the user-facing methods in behavior_project_cache. @@ -456,49 +417,48 @@ def expected_behavior_session_table(intermediate_behavior_table, 'passed_only' points to the value of passed_only used to generate the dataframe. """ - if hasattr(request, 'param'): + if hasattr(request, "param"): passed_only = request.param else: passed_only = True df = intermediate_behavior_table.copy(deep=True) - df['session_name_behavior'] = df['session_name'] - df = df.drop(['session_name'], axis=1) - df['specimen_id_behavior'] = df['specimen_id'] - df = df.drop(['specimen_id'], axis=1) + df["session_name_behavior"] = df["session_name"] + df = df.drop(["session_name"], axis=1) + df["specimen_id_behavior"] = df["specimen_id"] + df = df.drop(["specimen_id"], axis=1) - df['project_code'] = None - df['ophys_session_id'] = None - df['session_name_ophys'] = None - df['ophys_experiment_id'] = None - df['ophys_container_id'] = None - df['specimen_id_ophys'] = None + df["project_code"] = None + df["ophys_session_id"] = None + df["session_name_ophys"] = None + df["ophys_experiment_id"] = None + df["ophys_container_id"] = None + df["specimen_id_ophys"] = None session_number = [] - for v in df['session_type'].values: - if 'OPHYS' in v: + for v in df["session_type"].values: + if "OPHYS" in v: session_number.append(1) else: session_number.append(None) - df['session_number'] = session_number + df["session_number"] = session_number for ophys_session in ophys_session_data_fixture: - index = ophys_session['behavior_session_id'] - df.at[index, 'project_code'] = ophys_session['project_code'] - df.at[index, 'ophys_session_id'] = ophys_session['ophys_session_id'] - df.at[index, 'session_name_ophys'] = ophys_session['session_name'] + index = ophys_session["behavior_session_id"] + df.at[index, "project_code"] = ophys_session["project_code"] + df.at[index, "ophys_session_id"] = ophys_session["ophys_session_id"] + df.at[index, "session_name_ophys"] = ophys_session["session_name"] container_id_list = set() exp_id_list = set() - for exp_id in ophys_session['ophys_experiment_id']: + for exp_id in ophys_session["ophys_experiment_id"]: # because SessionsTable does not filter on experiment state exp_id_list.add(exp_id) - if experiment_state_lookup[exp_id] != 'passed' and passed_only: + if experiment_state_lookup[exp_id] != "passed" and passed_only: continue for container_id in ophys_experiment_to_container_map[exp_id]: - is_published = (container_state_lookup[container_id] - == 'published') + is_published = container_state_lookup[container_id] == "published" if is_published or not passed_only: container_id_list.add(container_id) @@ -507,21 +467,19 @@ def expected_behavior_session_table(intermediate_behavior_table, container_id_list = list(container_id_list) container_id_list.sort() - df.at[index, 'ophys_container_id'] = container_id_list - df.at[index, 'ophys_experiment_id'] = exp_id_list - df.at[index, 'specimen_id_ophys'] = ophys_session['specimen_id'] + df.at[index, "ophys_container_id"] = container_id_list + df.at[index, "ophys_experiment_id"] = exp_id_list + df.at[index, "specimen_id_ophys"] = ophys_session["specimen_id"] - df['ophys_session_id'] = df['ophys_session_id'].astype(float) + df["ophys_session_id"] = df["ophys_session_id"].astype(float) - return {'df': df, 'passed_only': passed_only} + return {"df": df, "passed_only": passed_only} @pytest.fixture() -def expected_experiments_table(ophys_experiments_table, - container_state_lookup, - experiment_state_lookup, - intermediate_behavior_table, - request): +def expected_experiments_table( + ophys_experiments_table, container_state_lookup, experiment_state_lookup, intermediate_behavior_table, request +): """ The experiments_table as returned by the user-facing methods in the behavior_project_cache @@ -533,7 +491,7 @@ def expected_experiments_table(ophys_experiments_table, generate the dataframe. """ - if hasattr(request, 'param'): + if hasattr(request, "param"): passed_only = request.param else: passed_only = True @@ -545,54 +503,58 @@ def expected_experiments_table(ophys_experiments_table, expected = expected.query("experiment_workflow_state=='passed'") expected = expected.query("container_workflow_state=='published'") - expected = expected.join(behavior_table[ - ['equipment_name', - 'donor_id', - 'full_genotype', - 'mouse_id', - 'driver_line', - 'sex', - 'age_in_days', - 'foraging_id', - 'reporter_line', - 'specimen_id', - 'prior_exposures_to_session_type', - 'prior_exposures_to_image_set', - 'prior_exposures_to_omissions', - 'indicator', - 'cre_line']], - on='behavior_session_id') - - expected = expected.join(behavior_table[ - ['session_name']], - on='behavior_session_id', - rsuffix='_behavior') + expected = expected.join( + behavior_table[ + [ + "equipment_name", + "donor_id", + "full_genotype", + "mouse_id", + "driver_line", + "sex", + "age_in_days", + "foraging_id", + "reporter_line", + "specimen_id", + "prior_exposures_to_session_type", + "prior_exposures_to_image_set", + "prior_exposures_to_omissions", + "indicator", + "cre_line", + ] + ], + on="behavior_session_id", + ) + + expected = expected.join(behavior_table[["session_name"]], on="behavior_session_id", rsuffix="_behavior") session_number = [] - for v in expected['session_type'].values: - if 'OPHYS' in v: + for v in expected["session_type"].values: + if "OPHYS" in v: session_number.append(1) else: session_number.append(None) - expected['session_number'] = session_number + expected["session_number"] = session_number expected = add_experience_level_ophys(expected) expected = add_passive_flag_to_ophys_experiment_table(expected) expected = add_image_set_to_experiment_table(expected) - expected['session_name_ophys'] = expected['session_name'] - expected = expected.drop(['session_name'], axis=1) + expected["session_name_ophys"] = expected["session_name"] + expected = expected.drop(["session_name"], axis=1) - return {'df': expected, 'passed_only': passed_only} + return {"df": expected, "passed_only": passed_only} @pytest.fixture() -def expected_ophys_session_table(ophys_session_table, - intermediate_behavior_table, - container_state_lookup, - experiment_state_lookup, - ophys_experiment_to_container_map, - request): +def expected_ophys_session_table( + ophys_session_table, + intermediate_behavior_table, + container_state_lookup, + experiment_state_lookup, + ophys_experiment_to_container_map, + request, +): """ The ophys_session_table as returned by the user-facing methods in the behavior_project_cache. @@ -603,7 +565,7 @@ def expected_ophys_session_table(ophys_session_table, 'passed_only' points to the value of passed_only used to generate the dataframe. """ - if hasattr(request, 'param'): + if hasattr(request, "param"): passed_only = request.param else: passed_only = True @@ -613,10 +575,10 @@ def expected_ophys_session_table(ophys_session_table, valid_containers = set() valid_experiments = set() for exp_id in ophys_experiment_to_container_map: - if experiment_state_lookup[exp_id] != 'passed': + if experiment_state_lookup[exp_id] != "passed": continue for container_id in ophys_experiment_to_container_map[exp_id]: - if container_state_lookup[container_id] == 'published': + if container_state_lookup[container_id] == "published": valid_containers.add(container_id) valid_experiments.add(exp_id) @@ -625,53 +587,52 @@ def expected_ophys_session_table(ophys_session_table, # that is probably supposed to happen at the level # of the LIMS query (?) for index_val in expected.index.values: - raw_containers = expected.loc[index_val]['ophys_container_id'] + raw_containers = expected.loc[index_val]["ophys_container_id"] container_id = [c for c in raw_containers if c in valid_containers] - expected.at[index_val, 'ophys_container_id'] = container_id + expected.at[index_val, "ophys_container_id"] = container_id behavior_table = intermediate_behavior_table.copy(deep=True) - expected = expected.join(behavior_table[ - ['equipment_name', - 'donor_id', - 'full_genotype', - 'mouse_id', - 'driver_line', - 'sex', - 'age_in_days', - 'foraging_id', - 'reporter_line', - 'prior_exposures_to_session_type', - 'prior_exposures_to_image_set', - 'prior_exposures_to_omissions', - 'indicator', - 'cre_line']], - on='behavior_session_id') + expected = expected.join( + behavior_table[ + [ + "equipment_name", + "donor_id", + "full_genotype", + "mouse_id", + "driver_line", + "sex", + "age_in_days", + "foraging_id", + "reporter_line", + "prior_exposures_to_session_type", + "prior_exposures_to_image_set", + "prior_exposures_to_omissions", + "indicator", + "cre_line", + ] + ], + on="behavior_session_id", + ) expected = expected.join( - behavior_table[['specimen_id', 'session_name']], - on='behavior_session_id', - rsuffix='_behavior', - lsuffix='_ophys') + behavior_table[["specimen_id", "session_name"]], on="behavior_session_id", rsuffix="_behavior", lsuffix="_ophys" + ) session_number = [] - for v in expected['session_type'].values: - if 'OPHYS' in v: + for v in expected["session_type"].values: + if "OPHYS" in v: session_number.append(1) else: session_number.append(None) - expected['session_number'] = session_number + expected["session_number"] = session_number - return {'df': expected, 'passed_only': passed_only} + return {"df": expected, "passed_only": passed_only} @pytest.fixture -def mock_api(ophys_session_table, - behavior_session_table, - ophys_experiments_table): - +def mock_api(ophys_session_table, behavior_session_table, ophys_experiments_table): class MockApi: - def get_ophys_session_table(self, n_workers=1): return ophys_session_table @@ -694,7 +655,5 @@ def get_behavior_stage_parameters(self, foraging_ids): def TempdirBehaviorCache(mock_api, request): temp_dir = tempfile.TemporaryDirectory() manifest = os.path.join(temp_dir.name, "manifest.json") - yield VisualBehaviorOphysProjectCache(fetch_api=mock_api(), - cache=request.param, - manifest=manifest) + yield VisualBehaviorOphysProjectCache(fetch_api=mock_api(), cache=request.param, manifest=manifest) temp_dir.cleanup() diff --git a/allensdk/test/brain_observatory/behavior/behavior_project_cache_data_model/test_behavior_project_cache.py b/allensdk/test/brain_observatory/behavior/behavior_project_cache_data_model/test_behavior_project_cache.py index 8cbca56f6d..537f07077a 100644 --- a/allensdk/test/brain_observatory/behavior/behavior_project_cache_data_model/test_behavior_project_cache.py +++ b/allensdk/test/brain_observatory/behavior/behavior_project_cache_data_model/test_behavior_project_cache.py @@ -4,14 +4,13 @@ @pytest.mark.parametrize("TempdirBehaviorCache", [True], indirect=True) -def test_session_table_reads_from_cache(TempdirBehaviorCache, - caplog): +def test_session_table_reads_from_cache(TempdirBehaviorCache, caplog): caplog.set_level(logging.INFO, logger="call_caching") cache = TempdirBehaviorCache cache.get_ophys_session_table() - reading_tuple = ('call_caching', logging.INFO, 'Reading data from cache') - no_file_tuple = ('call_caching', logging.INFO, 'No cache file found.') - writing_tuple = ('call_caching', logging.INFO, 'Writing data to cache') + reading_tuple = ("call_caching", logging.INFO, "Reading data from cache") + no_file_tuple = ("call_caching", logging.INFO, "No cache file found.") + writing_tuple = ("call_caching", logging.INFO, "Writing data to cache") assert reading_tuple in caplog.record_tuples assert no_file_tuple in caplog.record_tuples assert writing_tuple in caplog.record_tuples @@ -24,14 +23,13 @@ def test_session_table_reads_from_cache(TempdirBehaviorCache, @pytest.mark.parametrize("TempdirBehaviorCache", [True], indirect=True) -def test_behavior_table_reads_from_cache(TempdirBehaviorCache, - caplog): +def test_behavior_table_reads_from_cache(TempdirBehaviorCache, caplog): caplog.set_level(logging.INFO, logger="call_caching") cache = TempdirBehaviorCache cache.get_behavior_session_table() - reading_tuple = ('call_caching', logging.INFO, 'Reading data from cache') - no_file_tuple = ('call_caching', logging.INFO, 'No cache file found.') - writing_tuple = ('call_caching', logging.INFO, 'Writing data to cache') + reading_tuple = ("call_caching", logging.INFO, "Reading data from cache") + no_file_tuple = ("call_caching", logging.INFO, "No cache file found.") + writing_tuple = ("call_caching", logging.INFO, "Writing data to cache") assert reading_tuple in caplog.record_tuples assert no_file_tuple in caplog.record_tuples assert writing_tuple in caplog.record_tuples @@ -44,22 +42,16 @@ def test_behavior_table_reads_from_cache(TempdirBehaviorCache, @pytest.mark.parametrize("TempdirBehaviorCache", [True, False], indirect=True) -def test_get_ophys_session_table_by_experiment(TempdirBehaviorCache, - expected_ophys_session_table): - - raw = expected_ophys_session_table['df'][['ophys_experiment_id']] +def test_get_ophys_session_table_by_experiment(TempdirBehaviorCache, expected_ophys_session_table): + raw = expected_ophys_session_table["df"][["ophys_experiment_id"]] data = [] - for session_id, exp_id_list in zip(raw.index.values, - raw.ophys_experiment_id.values): + for session_id, exp_id_list in zip(raw.index.values, raw.ophys_experiment_id.values): for exp_id in exp_id_list: - data.append({'ophys_session_id': session_id, - 'ophys_experiment_id': exp_id}) + data.append({"ophys_session_id": session_id, "ophys_experiment_id": exp_id}) - expected = pd.DataFrame(data).set_index('ophys_experiment_id') + expected = pd.DataFrame(data).set_index("ophys_experiment_id") - actual = TempdirBehaviorCache.get_ophys_session_table( - index_column="ophys_experiment_id")[ - ["ophys_session_id"]] + actual = TempdirBehaviorCache.get_ophys_session_table(index_column="ophys_experiment_id")[["ophys_session_id"]] expected.index = expected.index.astype(actual.index.dtype) pd.testing.assert_frame_equal(expected, actual) @@ -71,43 +63,34 @@ def test_cloud_manifest_errors(TempdirBehaviorCache): Test that methods which should not exist for BehaviorProjectCaches that are not backed by CloudCaches raise NotImplementedError """ - msg = 'Method {mname} does not exist for this ' - msg += 'VisualBehaviorOphysProjectCache, which is based on MockApi' - with pytest.raises(NotImplementedError, - match=msg.format(mname='construct_local_manifest')): + msg = "Method {mname} does not exist for this " + msg += "VisualBehaviorOphysProjectCache, which is based on MockApi" + with pytest.raises(NotImplementedError, match=msg.format(mname="construct_local_manifest")): TempdirBehaviorCache.construct_local_manifest() - with pytest.raises(NotImplementedError, - match=msg.format(mname='compare_manifests')): - TempdirBehaviorCache.compare_manifests('a', 'b') + with pytest.raises(NotImplementedError, match=msg.format(mname="compare_manifests")): + TempdirBehaviorCache.compare_manifests("a", "b") - with pytest.raises(NotImplementedError, - match=msg.format(mname='load_latest_manifest')): + with pytest.raises(NotImplementedError, match=msg.format(mname="load_latest_manifest")): TempdirBehaviorCache.load_latest_manifest() - this_msg = msg.format(mname='latest_downloaded_manifest_file') - with pytest.raises(NotImplementedError, - match=this_msg): + this_msg = msg.format(mname="latest_downloaded_manifest_file") + with pytest.raises(NotImplementedError, match=this_msg): TempdirBehaviorCache.latest_downloaded_manifest_file() - with pytest.raises(NotImplementedError, - match=msg.format(mname='latest_manifest_file')): + with pytest.raises(NotImplementedError, match=msg.format(mname="latest_manifest_file")): TempdirBehaviorCache.latest_manifest_file() - with pytest.raises(NotImplementedError, - match=msg.format(mname='load_manifest')): - TempdirBehaviorCache.load_manifest('a') + with pytest.raises(NotImplementedError, match=msg.format(mname="load_manifest")): + TempdirBehaviorCache.load_manifest("a") - with pytest.raises(NotImplementedError, - match=msg.format(mname='current_manifest')): + with pytest.raises(NotImplementedError, match=msg.format(mname="current_manifest")): TempdirBehaviorCache.current_manifest() - this_msg = msg.format(mname='list_all_downloaded_manifests') - with pytest.raises(NotImplementedError, - match=this_msg): + this_msg = msg.format(mname="list_all_downloaded_manifests") + with pytest.raises(NotImplementedError, match=this_msg): TempdirBehaviorCache.list_all_downloaded_manifests() - this_msg = msg.format(mname='list_manifest_file_names') - with pytest.raises(NotImplementedError, - match=this_msg): + this_msg = msg.format(mname="list_manifest_file_names") + with pytest.raises(NotImplementedError, match=this_msg): TempdirBehaviorCache.list_manifest_file_names() diff --git a/allensdk/test/brain_observatory/behavior/conftest.py b/allensdk/test/brain_observatory/behavior/conftest.py index 5941b1e4fe..9b2c3afc0d 100644 --- a/allensdk/test/brain_observatory/behavior/conftest.py +++ b/allensdk/test/brain_observatory/behavior/conftest.py @@ -8,13 +8,12 @@ from pynwb import NWBFile from allensdk.test_utilities.custom_comparators import WhitespaceStrippedString -from allensdk.brain_observatory.behavior.behavior_ophys_experiment import ( - BehaviorOphysExperiment) +from allensdk.brain_observatory.behavior.behavior_ophys_experiment import BehaviorOphysExperiment def get_resources_dir(): behavior_dir = os.path.dirname(__file__) - return os.path.join(behavior_dir, 'resources') + return os.path.join(behavior_dir, "resources") def pytest_assertrepr_compare(config, op, left, right): @@ -23,14 +22,13 @@ def pytest_assertrepr_compare(config, op, left, right): right_compare = right.orig else: right_compare = right - return ["Comparing strings with whitespace stripped. ", - f"{left.orig} != {right_compare}.", "Diff:"] + left.diff + return ["Comparing strings with whitespace stripped. ", f"{left.orig} != {right_compare}.", "Diff:"] + left.diff def pytest_ignore_collect(path, config): - ''' The brain_observatory.ecephys submodule uses + """The brain_observatory.ecephys submodule uses python 3.6 features that may not be backwards compatible! - ''' + """ if sys.version_info < (3, 6): return True @@ -43,23 +41,13 @@ def behavior_stimuli_data_fixture(request): This fixture mimicks the behavior experiment stimuli data logs and allows parameterization for testing """ - images_set_log = request.param.get("images_set_log", [ - ('Image', 'im065', 5.809, 0)]) - images_draw_log = request.param.get("images_draw_log", [ - ([0] + [1] * 3 + [0] * 3) - ]) - grating_set_log = request.param.get("grating_set_log", [ - ('Ori', 90, 3.585, 0) - ]) - grating_draw_log = request.param.get("grating_draw_log", [ - ([0] + [1] * 3 + [0] * 3) - ]) - omitted_flash_frame_log = request.param.get("omitted_flash_frame_log", { - "grating_0": [] - }) + images_set_log = request.param.get("images_set_log", [("Image", "im065", 5.809, 0)]) + images_draw_log = request.param.get("images_draw_log", [([0] + [1] * 3 + [0] * 3)]) + grating_set_log = request.param.get("grating_set_log", [("Ori", 90, 3.585, 0)]) + grating_draw_log = request.param.get("grating_draw_log", [([0] + [1] * 3 + [0] * 3)]) + omitted_flash_frame_log = request.param.get("omitted_flash_frame_log", {"grating_0": []}) grating_phase = request.param.get("grating_phase", None) - grating_spatial_frequency = request.param.get("grating_spatial_frequency", - None) + grating_spatial_frequency = request.param.get("grating_spatial_frequency", None) has_images = request.param.get("has_images", True) has_grating = request.param.get("has_grating", True) @@ -69,27 +57,17 @@ def behavior_stimuli_data_fixture(request): image_data = { "set_log": images_set_log, "draw_log": images_draw_log, - "image_path": os.path.join(resources_dir, - 'stimulus_template', - 'input', - 'test_image_set.pkl') + "image_path": os.path.join(resources_dir, "stimulus_template", "input", "test_image_set.pkl"), } grating_data = { "set_log": grating_set_log, "draw_log": grating_draw_log, "phase": grating_phase, - "sf": grating_spatial_frequency + "sf": grating_spatial_frequency, } - data = { - "items": { - "behavior": { - "stimuli": {}, - "omitted_flash_frame_log": omitted_flash_frame_log - } - } - } + data = {"items": {"behavior": {"stimuli": {}, "omitted_flash_frame_log": omitted_flash_frame_log}}} if has_images: data["items"]["behavior"]["stimuli"]["images"] = image_data @@ -106,26 +84,25 @@ def skeletal_nwb_fixture(): Instantiate an NWB file that has no real data in it """ - timezone = pytz.timezone('UTC') + timezone = pytz.timezone("UTC") date = timezone.localize(datetime.datetime.now()) nwbfile = NWBFile( - session_description="dummy", - identifier="00001", - session_start_time=date, - file_create_date=date, - institution="Allen Institute for Brain Science" - ) + session_description="dummy", + identifier="00001", + session_start_time=date, + file_create_date=date, + institution="Allen Institute for Brain Science", + ) return nwbfile -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def behavior_ophys_experiment_fixture(): """ A valid BehaviorOphysExperiment instantiated from_lims """ experiment_id = 953443028 - experiment = BehaviorOphysExperiment.from_lims( - experiment_id) + experiment = BehaviorOphysExperiment.from_lims(experiment_id) return experiment diff --git a/allensdk/test/brain_observatory/behavior/data_files/conftest.py b/allensdk/test/brain_observatory/behavior/data_files/conftest.py index 35c7cfe7bc..7bfd5a6c2f 100644 --- a/allensdk/test/brain_observatory/behavior/data_files/conftest.py +++ b/allensdk/test/brain_observatory/behavior/data_files/conftest.py @@ -6,83 +6,67 @@ from allensdk.brain_observatory.behavior.data_files.stimulus_file import ( BehaviorStimulusFile, MappingStimulusFile, - ReplayStimulusFile) + ReplayStimulusFile, +) @pytest.fixture -def behavior_pkl_fixture( - tmp_path_factory, - helper_functions): +def behavior_pkl_fixture(tmp_path_factory, helper_functions): """ Write a behavior pkl file to disk. Return a dict containing the path to that file, as well as the expected number of frames associated with the pickle file. """ - tmpdir = tmp_path_factory.mktemp('behavior_pkl') - pkl_path = pathlib.Path( - tempfile.mkstemp(dir=tmpdir, suffix='.pkl')[1]) + tmpdir = tmp_path_factory.mktemp("behavior_pkl") + pkl_path = pathlib.Path(tempfile.mkstemp(dir=tmpdir, suffix=".pkl")[1]) nframes = 17 - result = {'items': - {'behavior': - {'intervalsms': list(range(nframes-1))}}} + result = {"items": {"behavior": {"intervalsms": list(range(nframes - 1))}}} pd.to_pickle(result, pkl_path) - yield {'path': pkl_path, 'expected_frames': nframes} + yield {"path": pkl_path, "expected_frames": nframes} helper_functions.windows_safe_cleanup(file_path=pkl_path) @pytest.fixture -def general_pkl_fixture( - tmp_path_factory, - helper_functions): +def general_pkl_fixture(tmp_path_factory, helper_functions): """ Write a non-behavior stimulus pkl file to disk. Return a dict containing the path to that file, as well as the expected number of frames associated with the pickle file. """ - tmpdir = tmp_path_factory.mktemp('general_pkl') - pkl_path = pathlib.Path( - tempfile.mkstemp(dir=tmpdir, suffix='.pkl')[1]) + tmpdir = tmp_path_factory.mktemp("general_pkl") + pkl_path = pathlib.Path(tempfile.mkstemp(dir=tmpdir, suffix=".pkl")[1]) nframes = 19 - result = {'intervalsms': list(range(nframes-1))} + result = {"intervalsms": list(range(nframes - 1))} pd.to_pickle(result, pkl_path) - yield {'path': pkl_path, 'expected_frames': nframes} + yield {"path": pkl_path, "expected_frames": nframes} helper_functions.windows_safe_cleanup(file_path=pkl_path) @pytest.fixture -def behavior_stim_fixture( - behavior_pkl_fixture): +def behavior_stim_fixture(behavior_pkl_fixture): """ A BehaviorStimulusFile """ - return BehaviorStimulusFile.from_json( - dict_repr={"behavior_stimulus_file": - behavior_pkl_fixture["path"]}) + return BehaviorStimulusFile.from_json(dict_repr={"behavior_stimulus_file": behavior_pkl_fixture["path"]}) @pytest.fixture -def replay_stim_fixture( - general_pkl_fixture): +def replay_stim_fixture(general_pkl_fixture): """ A ReplayStimulusFile """ - return ReplayStimulusFile.from_json( - dict_repr={"replay_stimulus_file": - general_pkl_fixture["path"]}) + return ReplayStimulusFile.from_json(dict_repr={"replay_stimulus_file": general_pkl_fixture["path"]}) @pytest.fixture -def mapping_stim_fixture( - general_pkl_fixture): +def mapping_stim_fixture(general_pkl_fixture): """ A MappingStimulusFile """ - return MappingStimulusFile.from_json( - dict_repr={"mapping_stimulus_file": - general_pkl_fixture["path"]}) + return MappingStimulusFile.from_json(dict_repr={"mapping_stimulus_file": general_pkl_fixture["path"]}) diff --git a/allensdk/test/brain_observatory/behavior/data_files/test_eye_tracking_metadata_file.py b/allensdk/test/brain_observatory/behavior/data_files/test_eye_tracking_metadata_file.py index 83c4054113..e23bbe0523 100644 --- a/allensdk/test/brain_observatory/behavior/data_files/test_eye_tracking_metadata_file.py +++ b/allensdk/test/brain_observatory/behavior/data_files/test_eye_tracking_metadata_file.py @@ -2,31 +2,24 @@ import tempfile import pathlib -from allensdk.brain_observatory.behavior.\ - data_files.eye_tracking_metadata_file import ( - EyeTrackingMetadataFile) +from allensdk.brain_observatory.behavior.data_files.eye_tracking_metadata_file import EyeTrackingMetadataFile -def test_eye_tracking_metadata_file( - tmp_path_factory, - helper_functions): +def test_eye_tracking_metadata_file(tmp_path_factory, helper_functions): """ Just a smoke test for EyeTrackingMetadataFile.from_json """ - json_path = pathlib.Path(tempfile.mkstemp(suffix='.json')[1]) + json_path = pathlib.Path(tempfile.mkstemp(suffix=".json")[1]) - data = {'a': [1, {'b': 3}], - 'c': 'x'} + data = {"a": [1, {"b": 3}], "c": "x"} - with open(json_path, 'w') as out_file: + with open(json_path, "w") as out_file: out_file.write(json.dumps(data)) - dict_repr = {'raw_eye_tracking_video_meta_data': - str(json_path.resolve().absolute())} + dict_repr = {"raw_eye_tracking_video_meta_data": str(json_path.resolve().absolute())} - data_file = EyeTrackingMetadataFile.from_json( - dict_repr=dict_repr) + data_file = EyeTrackingMetadataFile.from_json(dict_repr=dict_repr) assert isinstance(data_file, EyeTrackingMetadataFile) assert data_file.data == data diff --git a/allensdk/test/brain_observatory/behavior/data_files/test_stimulus_file.py b/allensdk/test/brain_observatory/behavior/data_files/test_stimulus_file.py index 1d26681f21..865bac1b24 100644 --- a/allensdk/test/brain_observatory/behavior/data_files/test_stimulus_file.py +++ b/allensdk/test/brain_observatory/behavior/data_files/test_stimulus_file.py @@ -69,22 +69,16 @@ def test_stimulus_file_from_lims(stimulus_file_fixture, behavior_session_id): # Basic test case mock_db_conn.fetchone.return_value = str(stim_pkl_path) - stimulus_file = BehaviorStimulusFile.from_lims( - mock_db_conn, behavior_session_id - ) + stimulus_file = BehaviorStimulusFile.from_lims(mock_db_conn, behavior_session_id) assert stimulus_file.data == stim_pkl_data # Now test caching by deleting stimulus_file and also asserting db # `fetchone` called only once stim_pkl_path.unlink() - stimfile_cached = BehaviorStimulusFile.from_lims( - mock_db_conn, behavior_session_id - ) + stimfile_cached = BehaviorStimulusFile.from_lims(mock_db_conn, behavior_session_id) assert stimfile_cached.data == stim_pkl_data - query = BEHAVIOR_STIMULUS_FILE_QUERY_TEMPLATE.format( - behavior_session_id=behavior_session_id - ) + query = BEHAVIOR_STIMULUS_FILE_QUERY_TEMPLATE.format(behavior_session_id=behavior_session_id) mock_db_conn.fetchone.assert_called_once_with(query, strict=True) @@ -134,9 +128,7 @@ def test_malformed_behavior_pkl(general_pkl_fixture): _ = stim.num_frames -def test_stimulus_file_lookup( - behavior_stim_fixture, mapping_stim_fixture, replay_stim_fixture -): +def test_stimulus_file_lookup(behavior_stim_fixture, mapping_stim_fixture, replay_stim_fixture): """ Smoke test of StimulusFileLookup """ diff --git a/allensdk/test/brain_observatory/behavior/data_files/test_sync_file.py b/allensdk/test/brain_observatory/behavior/data_files/test_sync_file.py index 5d5eefe6a7..99fbdacd25 100644 --- a/allensdk/test/brain_observatory/behavior/data_files/test_sync_file.py +++ b/allensdk/test/brain_observatory/behavior/data_files/test_sync_file.py @@ -8,9 +8,7 @@ from allensdk.internal.api import PostgresQueryMixin from allensdk.brain_observatory.behavior.data_files import SyncFile -from allensdk.brain_observatory.behavior.data_files.sync_file import ( - _get_sync_file_query_template -) +from allensdk.brain_observatory.behavior.data_files.sync_file import _get_sync_file_query_template @pytest.fixture @@ -32,18 +30,18 @@ def mock_get_sync_data(sync_path, permissive): return data -@pytest.mark.parametrize("sync_file_fixture", [ - ({"sync_data": [2, 3, 4, 5]}), -], indirect=["sync_file_fixture"]) +@pytest.mark.parametrize( + "sync_file_fixture", + [ + ({"sync_data": [2, 3, 4, 5]}), + ], + indirect=["sync_file_fixture"], +) def test_sync_file_from_json(monkeypatch, sync_file_fixture): sync_path, sync_data = sync_file_fixture with monkeypatch.context() as m: - m.setattr( - "allensdk.brain_observatory.behavior.data_files" - ".sync_file.get_sync_data", - mock_get_sync_data - ) + m.setattr("allensdk.brain_observatory.behavior.data_files.sync_file.get_sync_data", mock_get_sync_data) # Basic test case input_json_dict = {"sync_file": str(sync_path)} @@ -56,25 +54,18 @@ def test_sync_file_from_json(monkeypatch, sync_file_fixture): assert np.allclose(sync_file_cached.data, sync_data) -@pytest.mark.parametrize("sync_file_fixture, ophys_experiment_id", [ - ({"sync_data": [2, 3, 4, 5]}, 12), - ({"sync_data": [2, 3, 4, 5]}, 8) -], indirect=["sync_file_fixture"]) -def test_sync_file_from_lims( - monkeypatch, - sync_file_fixture, - ophys_experiment_id -): +@pytest.mark.parametrize( + "sync_file_fixture, ophys_experiment_id", + [({"sync_data": [2, 3, 4, 5]}, 12), ({"sync_data": [2, 3, 4, 5]}, 8)], + indirect=["sync_file_fixture"], +) +def test_sync_file_from_lims(monkeypatch, sync_file_fixture, ophys_experiment_id): sync_path, sync_data = sync_file_fixture mock_db_conn = create_autospec(PostgresQueryMixin, instance=True) with monkeypatch.context() as m: - m.setattr( - "allensdk.brain_observatory.behavior.data_files" - ".sync_file.get_sync_data", - mock_get_sync_data - ) + m.setattr("allensdk.brain_observatory.behavior.data_files.sync_file.get_sync_data", mock_get_sync_data) # Basic test case mock_db_conn.fetchone.return_value = str(sync_path) @@ -87,8 +78,6 @@ def test_sync_file_from_lims( stimfile_cached = SyncFile.from_lims(mock_db_conn, ophys_experiment_id) np.allclose(stimfile_cached.data, sync_data) - query = _get_sync_file_query_template( - behavior_session_id=ophys_experiment_id - ) + query = _get_sync_file_query_template(behavior_session_id=ophys_experiment_id) mock_db_conn.fetchone.assert_called_once_with(query, strict=True) diff --git a/allensdk/test/brain_observatory/behavior/data_objects/base/test_data_object.py b/allensdk/test/brain_observatory/behavior/data_objects/base/test_data_object.py index 6ddd08148c..b73212e35d 100644 --- a/allensdk/test/brain_observatory/behavior/data_objects/base/test_data_object.py +++ b/allensdk/test/brain_observatory/behavior/data_objects/base/test_data_object.py @@ -7,18 +7,19 @@ class TestDataObject: def test_to_dict_simple(self): class Simple(DataObject): def __init__(self): - super().__init__(name='simple', value=1) + super().__init__(name="simple", value=1) + s = Simple() - assert s.to_dict() == {'simple': 1} + assert s.to_dict() == {"simple": 1} def test_to_dict_nested(self): class B(DataObject): def __init__(self): - super().__init__(name='b', value='!') + super().__init__(name="b", value="!") class A(DataObject): def __init__(self, b: B): - super().__init__(name='a', value=None, is_value_self=True) + super().__init__(name="a", value=None, is_value_self=True) self._b = b @property @@ -27,18 +28,19 @@ def prop1(self): @property def prop2(self): - return '@' + return "@" + a = A(b=B()) - assert a.to_dict() == {'a': {'b': '!', 'prop2': '@'}} + assert a.to_dict() == {"a": {"b": "!", "prop2": "@"}} def test_to_dict_double_nested(self): class C(DataObject): def __init__(self): - super().__init__(name='c', value='!!!') + super().__init__(name="c", value="!!!") class B(DataObject): def __init__(self, c: C): - super().__init__(name='b', value=None, is_value_self=True) + super().__init__(name="b", value=None, is_value_self=True) self._c = c @property @@ -47,11 +49,11 @@ def prop1(self): @property def prop2(self): - return '!!' + return "!!" class A(DataObject): def __init__(self, b: B): - super().__init__(name='a', value=None, is_value_self=True) + super().__init__(name="a", value=None, is_value_self=True) self._b = b @property @@ -60,22 +62,21 @@ def prop1(self): @property def prop2(self): - return '@' + return "@" a = A(b=B(c=C())) - assert a.to_dict() == {'a': {'b': {'c': '!!!', 'prop2': '!!'}, - 'prop2': '@'}} + assert a.to_dict() == {"a": {"b": {"c": "!!!", "prop2": "!!"}, "prop2": "@"}} def test_not_equals(self): - s1 = DataObject(name='s1', value=1) - s2 = DataObject(name='s1', value='1') + s1 = DataObject(name="s1", value=1) + s2 = DataObject(name="s1", value="1") assert s1 != s2 def test_exclude_equals(self): - s1 = DataObject(name='s1', value=1, exclude_from_equals={'s1'}) - s2 = DataObject(name='s1', value='1') + s1 = DataObject(name="s1", value=1, exclude_from_equals={"s1"}) + s2 = DataObject(name="s1", value="1") assert s1 == s2 def test_cannot_compare(self): with pytest.raises(NotImplementedError): - assert DataObject(name='foo', value=1) == 1 + assert DataObject(name="foo", value=1) == 1 diff --git a/allensdk/test/brain_observatory/behavior/data_objects/eye_tracking/test_eye_tracking_table.py b/allensdk/test/brain_observatory/behavior/data_objects/eye_tracking/test_eye_tracking_table.py index e9124ce207..7cb6f814fd 100644 --- a/allensdk/test/brain_observatory/behavior/data_objects/eye_tracking/test_eye_tracking_table.py +++ b/allensdk/test/brain_observatory/behavior/data_objects/eye_tracking/test_eye_tracking_table.py @@ -6,27 +6,18 @@ import pandas as pd import pynwb import pytest -from allensdk.brain_observatory.behavior.data_files\ - .eye_tracking_metadata_file import \ - EyeTrackingMetadataFile +from allensdk.brain_observatory.behavior.data_files.eye_tracking_metadata_file import EyeTrackingMetadataFile from allensdk.brain_observatory.sync_dataset import Dataset as SyncDataset from allensdk.brain_observatory.behavior.data_objects import StimulusTimestamps from allensdk.brain_observatory import sync_utilities -from allensdk.brain_observatory.behavior.data_files import \ - SyncFile -from allensdk.brain_observatory.behavior.data_files.eye_tracking_file import \ - EyeTrackingFile +from allensdk.brain_observatory.behavior.data_files import SyncFile +from allensdk.brain_observatory.behavior.data_files.eye_tracking_file import EyeTrackingFile from allensdk.brain_observatory.behavior.data_objects import BehaviorSessionId -from allensdk.brain_observatory.behavior.data_objects.eye_tracking \ - .eye_tracking_table import \ - EyeTrackingTable -from allensdk.test.brain_observatory.behavior.data_objects.lims_util import \ - LimsTest -from allensdk.test.brain_observatory.behavior.test_eye_tracking_processing \ - import \ - create_refined_eye_tracking_df +from allensdk.brain_observatory.behavior.data_objects.eye_tracking.eye_tracking_table import EyeTrackingTable +from allensdk.test.brain_observatory.behavior.data_objects.lims_util import LimsTest +from allensdk.test.brain_observatory.behavior.test_eye_tracking_processing import create_refined_eye_tracking_df class TestFromDataFile(LimsTest): @@ -35,19 +26,16 @@ def setup_class(cls): cls.ophys_experiment_id = 994278291 dir = Path(__file__).parent.parent.resolve() - test_data_dir = dir / 'test_data' + test_data_dir = dir / "test_data" - df = pd.read_pickle(str(test_data_dir / 'eye_tracking_table.pkl')) + df = pd.read_pickle(str(test_data_dir / "eye_tracking_table.pkl")) cls.expected = EyeTrackingTable(eye_tracking=df) @pytest.mark.requires_bamboo def test_from_data_file(self): - behavior_session_id = BehaviorSessionId.from_lims( - db=self.dbconn, ophys_experiment_id=self.ophys_experiment_id) - etf = EyeTrackingFile.from_lims( - behavior_session_id=behavior_session_id.value, db=self.dbconn) - sync_file = SyncFile.from_lims( - behavior_session_id=behavior_session_id.value, db=self.dbconn) + behavior_session_id = BehaviorSessionId.from_lims(db=self.dbconn, ophys_experiment_id=self.ophys_experiment_id) + etf = EyeTrackingFile.from_lims(behavior_session_id=behavior_session_id.value, db=self.dbconn) + sync_file = SyncFile.from_lims(behavior_session_id=behavior_session_id.value, db=self.dbconn) sync_path = Path(sync_file.filepath) @@ -55,15 +43,12 @@ def test_from_data_file(self): session_sync_file=sync_path, sync_line_label_keys=SyncDataset.EYE_TRACKING_KEYS, drop_frames=None, - trim_after_spike=False) + trim_after_spike=False, + ) - stimulus_timestamps = StimulusTimestamps( - timestamps=frame_times, - monitor_delay=0.0) + stimulus_timestamps = StimulusTimestamps(timestamps=frame_times, monitor_delay=0.0) - ett = EyeTrackingTable.from_data_file( - data_file=etf, - stimulus_timestamps=stimulus_timestamps) + ett = EyeTrackingTable.from_data_file(data_file=etf, stimulus_timestamps=stimulus_timestamps) # filter to first 100 values for testing ett = EyeTrackingTable(eye_tracking=ett.value.iloc[:100]) @@ -74,36 +59,77 @@ class TestNWB: @classmethod def setup_class(cls): dir = Path(__file__).parent.parent.resolve() - cls.test_data_dir = dir / 'test_data' + cls.test_data_dir = dir / "test_data" df = create_refined_eye_tracking_df( - np.array([[0.1, 12 * np.pi, 72 * np.pi, 196 * np.pi, False, - 196 * np.pi, 12 * np.pi, 72 * np.pi, - 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., - 13., 14., 15.], - [0.2, 20 * np.pi, 90 * np.pi, 225 * np.pi, False, - 225 * np.pi, 20 * np.pi, 90 * np.pi, - 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., - 14., 15., 16.]]) + np.array( + [ + [ + 0.1, + 12 * np.pi, + 72 * np.pi, + 196 * np.pi, + False, + 196 * np.pi, + 12 * np.pi, + 72 * np.pi, + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + 6.0, + 7.0, + 8.0, + 9.0, + 10.0, + 11.0, + 12.0, + 13.0, + 14.0, + 15.0, + ], + [ + 0.2, + 20 * np.pi, + 90 * np.pi, + 225 * np.pi, + False, + 225 * np.pi, + 20 * np.pi, + 90 * np.pi, + 2.0, + 3.0, + 4.0, + 5.0, + 6.0, + 7.0, + 8.0, + 9.0, + 10.0, + 11.0, + 12.0, + 13.0, + 14.0, + 15.0, + 16.0, + ], + ] + ) ) cls.eye_tracking_table = EyeTrackingTable(eye_tracking=df) def setup_method(self, method): self.nwbfile = pynwb.NWBFile( - session_description='asession', - identifier='1234', - session_start_time=datetime.now() + session_description="asession", identifier="1234", session_start_time=datetime.now() ) - @pytest.mark.parametrize('roundtrip', [True, False]) - def test_read_write_nwb(self, roundtrip, - data_object_roundtrip_fixture): + @pytest.mark.parametrize("roundtrip", [True, False]) + def test_read_write_nwb(self, roundtrip, data_object_roundtrip_fixture): self.eye_tracking_table.to_nwb(nwbfile=self.nwbfile) if roundtrip: - obt = data_object_roundtrip_fixture( - nwbfile=self.nwbfile, - data_object_cls=EyeTrackingTable) + obt = data_object_roundtrip_fixture(nwbfile=self.nwbfile, data_object_cls=EyeTrackingTable) else: obt = EyeTrackingTable.from_nwb(nwbfile=self.nwbfile) @@ -113,24 +139,20 @@ def test_read_write_nwb(self, roundtrip, class TestTimeFrameAlignment: @classmethod def setup_class(cls): - with open('/allen/aibs/informatics/module_test_data/ecephys/' - 'BEHAVIOR_ECEPHYS_WRITE_NWB_QUEUE_1044594870_input.json') \ - as f: + with open( + "/allen/aibs/informatics/module_test_data/ecephys/BEHAVIOR_ECEPHYS_WRITE_NWB_QUEUE_1044594870_input.json" + ) as f: input_data = json.load(f) - cls.input_data = input_data['session_data'] + cls.input_data = input_data["session_data"] - cls.eye_tracking_file = EyeTrackingFile.from_json( - dict_repr=cls.input_data) + cls.eye_tracking_file = EyeTrackingFile.from_json(dict_repr=cls.input_data) - cls.metadata_file = EyeTrackingMetadataFile.from_json( - dict_repr=cls.input_data) + cls.metadata_file = EyeTrackingMetadataFile.from_json(dict_repr=cls.input_data) # Making up timestamps cls.stimulus_timestamps = StimulusTimestamps( - timestamps=( - np.linspace(1.33955, 9.72219322e+03, - cls.eye_tracking_file.data.shape[0])), - monitor_delay=0) + timestamps=(np.linspace(1.33955, 9.72219322e03, cls.eye_tracking_file.data.shape[0])), monitor_delay=0 + ) @pytest.mark.requires_bamboo def test_metadata_frame_is_dropped(self): @@ -138,17 +160,17 @@ def test_metadata_frame_is_dropped(self): (adds extra metadata frame at front), that this extra frame is dropped""" # Make it 1 shorter than # frames - timestamps = self.stimulus_timestamps.update_timestamps( - timestamps=self.stimulus_timestamps.value[:-1]) + timestamps = self.stimulus_timestamps.update_timestamps(timestamps=self.stimulus_timestamps.value[:-1]) ett = EyeTrackingTable.from_data_file( - data_file=self.eye_tracking_file, - stimulus_timestamps=timestamps, - metadata_file=self.metadata_file + data_file=self.eye_tracking_file, stimulus_timestamps=timestamps, metadata_file=self.metadata_file + ) + assert ( + ett.value.shape[0] + == + # Subtract 1 for the metadata frame + self.eye_tracking_file.data.shape[0] - 1 ) - assert (ett.value.shape[0] == - # Subtract 1 for the metadata frame - self.eye_tracking_file.data.shape[0] - 1) @pytest.mark.requires_bamboo def test_timestamps_are_truncated(self): @@ -156,15 +178,16 @@ def test_timestamps_are_truncated(self): that the timestamps are truncated""" # Make it 2 longer than # frames timestamps = self.stimulus_timestamps.update_timestamps( - timestamps=np.concatenate([self.stimulus_timestamps.value, - self.stimulus_timestamps.value[-2:]])) + timestamps=np.concatenate([self.stimulus_timestamps.value, self.stimulus_timestamps.value[-2:]]) + ) ett = EyeTrackingTable.from_data_file( - data_file=self.eye_tracking_file, - stimulus_timestamps=timestamps, - metadata_file=self.metadata_file + data_file=self.eye_tracking_file, stimulus_timestamps=timestamps, metadata_file=self.metadata_file ) - assert (ett.value.shape[0] == - # subtract off metadata frame - self.eye_tracking_file.data.shape[0] - 1) + assert ( + ett.value.shape[0] + == + # subtract off metadata frame + self.eye_tracking_file.data.shape[0] - 1 + ) diff --git a/allensdk/test/brain_observatory/behavior/data_objects/eye_tracking/test_eye_tracking_utils.py b/allensdk/test/brain_observatory/behavior/data_objects/eye_tracking/test_eye_tracking_utils.py index f6ec31dc5f..3203bd1341 100644 --- a/allensdk/test/brain_observatory/behavior/data_objects/eye_tracking/test_eye_tracking_utils.py +++ b/allensdk/test/brain_observatory/behavior/data_objects/eye_tracking/test_eye_tracking_utils.py @@ -4,49 +4,38 @@ import pathlib import numpy as np -from allensdk.brain_observatory.behavior.data_files.\ - eye_tracking_metadata_file import EyeTrackingMetadataFile +from allensdk.brain_observatory.behavior.data_files.eye_tracking_metadata_file import EyeTrackingMetadataFile -from allensdk.brain_observatory.behavior.\ - data_objects.eye_tracking.eye_tracking_table import ( - get_lost_frames) +from allensdk.brain_observatory.behavior.data_objects.eye_tracking.eye_tracking_table import get_lost_frames @pytest.mark.parametrize( - "input_str, lost_count, expected_array", - [('', 0, []), - ('13-19', 1, [12, 13, 14, 15, 16, 17, 18]), - ('5-7,100-103', 1, [4, 5, 6, 99, 100, 101, 102]), - ('77', 1, [76]), - ('3-5,21-25,201-204', 1, - [2, 3, 4, 20, 21, 22, 23, 24, 200, 201, 202, 203])]) -def test_get_lost_frames( - input_str, - lost_count, - expected_array, - tmp_path_factory, - helper_functions): + "input_str, lost_count, expected_array", + [ + ("", 0, []), + ("13-19", 1, [12, 13, 14, 15, 16, 17, 18]), + ("5-7,100-103", 1, [4, 5, 6, 99, 100, 101, 102]), + ("77", 1, [76]), + ("3-5,21-25,201-204", 1, [2, 3, 4, 20, 21, 22, 23, 24, 200, 201, 202, 203]), + ], +) +def test_get_lost_frames(input_str, lost_count, expected_array, tmp_path_factory, helper_functions): """ Test performance of get_lost_frames by constructing an example camera metadata json file with records of lost frames and running it through the method. """ - metadata = {'RecordingReport': - {'FramesLostCount': lost_count, - 'LostFrames': [input_str]}} + metadata = {"RecordingReport": {"FramesLostCount": lost_count, "LostFrames": [input_str]}} - tmpdir = pathlib.Path(tmp_path_factory.mktemp('get_lost_frames')) - json_path = pathlib.Path( - tempfile.mkstemp(dir=tmpdir, suffix='.json')[1]) - with open(json_path, 'w') as output_file: + tmpdir = pathlib.Path(tmp_path_factory.mktemp("get_lost_frames")) + json_path = pathlib.Path(tempfile.mkstemp(dir=tmpdir, suffix=".json")[1]) + with open(json_path, "w") as output_file: output_file.write(json.dumps(metadata)) - dict_repr = {'raw_eye_tracking_video_meta_data': - str(json_path.resolve().absolute())} + dict_repr = {"raw_eye_tracking_video_meta_data": str(json_path.resolve().absolute())} - metadata = EyeTrackingMetadataFile.from_json( - dict_repr=dict_repr) + metadata = EyeTrackingMetadataFile.from_json(dict_repr=dict_repr) actual = get_lost_frames(eye_tracking_metadata=metadata) np.testing.assert_array_equal(actual, np.array(expected_array)) diff --git a/allensdk/test/brain_observatory/behavior/data_objects/eye_tracking/test_rig_geometry.py b/allensdk/test/brain_observatory/behavior/data_objects/eye_tracking/test_rig_geometry.py index ae86f0d568..9b642ca557 100644 --- a/allensdk/test/brain_observatory/behavior/data_objects/eye_tracking/test_rig_geometry.py +++ b/allensdk/test/brain_observatory/behavior/data_objects/eye_tracking/test_rig_geometry.py @@ -8,11 +8,8 @@ import pytest from allensdk.brain_observatory.behavior.data_objects import BehaviorSessionId -from allensdk.brain_observatory.behavior.data_objects.eye_tracking \ - .rig_geometry import \ - RigGeometry, Coordinates -from allensdk.test.brain_observatory.behavior.data_objects.lims_util import \ - LimsTest +from allensdk.brain_observatory.behavior.data_objects.eye_tracking.rig_geometry import RigGeometry, Coordinates +from allensdk.test.brain_observatory.behavior.data_objects.lims_util import LimsTest class TestFromLims(LimsTest): @@ -21,20 +18,18 @@ def setup_class(cls): cls.ophys_experiment_id = 994278291 dir = Path(__file__).parent.parent.resolve() - test_data_dir = dir / 'test_data' + test_data_dir = dir / "test_data" - with open(test_data_dir / 'eye_tracking_rig_geometry.json') as f: + with open(test_data_dir / "eye_tracking_rig_geometry.json") as f: x = json.load(f) - x = x['rig_geometry'] - x = {'eye_tracking_rig_geometry': x} + x = x["rig_geometry"] + x = {"eye_tracking_rig_geometry": x} cls.expected = RigGeometry.from_json(dict_repr=x) @pytest.mark.requires_bamboo def test_from_lims(self): - behavior_session_id = BehaviorSessionId.from_lims( - db=self.dbconn, ophys_experiment_id=self.ophys_experiment_id) - rg = RigGeometry.from_lims( - behavior_session_id=behavior_session_id.value, lims_db=self.dbconn) + behavior_session_id = BehaviorSessionId.from_lims(db=self.dbconn, ophys_experiment_id=self.ophys_experiment_id) + rg = RigGeometry.from_lims(behavior_session_id=behavior_session_id.value, lims_db=self.dbconn) assert rg == self.expected @pytest.mark.requires_bamboo @@ -48,15 +43,14 @@ def test_rig_geometry_newer_than_experiment(self): # experiment date_of_acquisition ophys_experiment_id = 521405260 - rg = RigGeometry.from_lims( - ophys_experiment_id=ophys_experiment_id, lims_db=self.dbconn) + rg = RigGeometry.from_lims(ophys_experiment_id=ophys_experiment_id, lims_db=self.dbconn) expected = RigGeometry( camera_position_mm=Coordinates(x=130.0, y=0.0, z=0.0), led_position=Coordinates(x=265.1, y=-39.3, z=1.0), monitor_position_mm=Coordinates(x=170.0, y=0.0, z=0.0), camera_rotation_deg=Coordinates(x=0.0, y=0.0, z=13.1), monitor_rotation_deg=Coordinates(x=0.0, y=0.0, z=0.0), - equipment='CAM2P.1' + equipment="CAM2P.1", ) assert rg == expected @@ -64,11 +58,10 @@ def test_only_single_geometry_returned(self): """Tests that when a rig contains multiple geometries, that only 1 is returned""" dir = Path(__file__).parent.parent.resolve() - test_data_dir = dir / 'test_data' + test_data_dir = dir / "test_data" # This example contains multiple geometries per config - df = pd.read_pickle( - str(test_data_dir / 'raw_eye_tracking_rig_geometry.pkl')) + df = pd.read_pickle(str(test_data_dir / "raw_eye_tracking_rig_geometry.pkl")) obtained = RigGeometry._select_most_recent_geometry(rig_geometry=df) assert (obtained.groupby(obtained.index).size() == 1).all() @@ -80,18 +73,17 @@ def setup_class(cls): cls.ophys_experiment_id = 994278291 dir = Path(__file__).parent.parent.resolve() - test_data_dir = dir / 'test_data' + test_data_dir = dir / "test_data" - with open(test_data_dir / 'eye_tracking_rig_geometry.json') as f: + with open(test_data_dir / "eye_tracking_rig_geometry.json") as f: x = json.load(f) - x = x['rig_geometry'] - x = {'eye_tracking_rig_geometry': x} + x = x["rig_geometry"] + x = {"eye_tracking_rig_geometry": x} cls.expected = RigGeometry.from_json(dict_repr=x) @pytest.mark.requires_bamboo def test_from_json(self): - dict_repr = {'eye_tracking_rig_geometry': - self.expected.to_dict()['rig_geometry']} + dict_repr = {"eye_tracking_rig_geometry": self.expected.to_dict()["rig_geometry"]} rg = RigGeometry.from_json(dict_repr=dict_repr) assert rg == self.expected @@ -100,30 +92,25 @@ class TestNWB: @classmethod def setup_class(cls): dir = Path(__file__).parent.parent.resolve() - cls.test_data_dir = dir / 'test_data' + cls.test_data_dir = dir / "test_data" - with open(cls.test_data_dir / 'eye_tracking_rig_geometry.json') as f: + with open(cls.test_data_dir / "eye_tracking_rig_geometry.json") as f: x = json.load(f) - x = x['rig_geometry'] - x = {'eye_tracking_rig_geometry': x} + x = x["rig_geometry"] + x = {"eye_tracking_rig_geometry": x} cls.rig_geometry = RigGeometry.from_json(dict_repr=x) def setup_method(self, method): self.nwbfile = pynwb.NWBFile( - session_description='asession', - identifier='1234', - session_start_time=datetime.now() + session_description="asession", identifier="1234", session_start_time=datetime.now() ) - @pytest.mark.parametrize('roundtrip', [True, False]) - def test_read_write_nwb(self, roundtrip, - data_object_roundtrip_fixture): + @pytest.mark.parametrize("roundtrip", [True, False]) + def test_read_write_nwb(self, roundtrip, data_object_roundtrip_fixture): self.rig_geometry.to_nwb(nwbfile=self.nwbfile) if roundtrip: - obt = data_object_roundtrip_fixture( - nwbfile=self.nwbfile, - data_object_cls=RigGeometry) + obt = data_object_roundtrip_fixture(nwbfile=self.nwbfile, data_object_cls=RigGeometry) else: obt = RigGeometry.from_nwb(nwbfile=self.nwbfile) diff --git a/allensdk/test/brain_observatory/behavior/data_objects/lims_util.py b/allensdk/test/brain_observatory/behavior/data_objects/lims_util.py index 93226b09cd..539c499cc2 100644 --- a/allensdk/test/brain_observatory/behavior/data_objects/lims_util.py +++ b/allensdk/test/brain_observatory/behavior/data_objects/lims_util.py @@ -1,20 +1,17 @@ -from allensdk.core.auth_config import LIMS_DB_CREDENTIAL_MAP, \ - MTRAIN_DB_CREDENTIAL_MAP +from allensdk.core.auth_config import LIMS_DB_CREDENTIAL_MAP, MTRAIN_DB_CREDENTIAL_MAP from allensdk.internal.api import db_connection_creator class LimsTest: """Helper class for testing LIMS. For each test, checks whether bamboo is required and if so sets up a connection""" + def setup_method(self, method): - marks = getattr(method, 'pytestmark', None) + marks = getattr(method, "pytestmark", None) if marks: marks = [m.name for m in marks] # Will only create a dbconn if the test requires_bamboo - if 'requires_bamboo' in marks: - self.dbconn = db_connection_creator( - fallback_credentials=LIMS_DB_CREDENTIAL_MAP) - self.mtrainconn = db_connection_creator( - fallback_credentials=MTRAIN_DB_CREDENTIAL_MAP - ) + if "requires_bamboo" in marks: + self.dbconn = db_connection_creator(fallback_credentials=LIMS_DB_CREDENTIAL_MAP) + self.mtrainconn = db_connection_creator(fallback_credentials=MTRAIN_DB_CREDENTIAL_MAP) diff --git a/allensdk/test/brain_observatory/behavior/data_objects/metadata/behavior_metadata/test_behavior_metadata.py b/allensdk/test/brain_observatory/behavior/data_objects/metadata/behavior_metadata/test_behavior_metadata.py index 0eb4833b34..1e4bf44274 100644 --- a/allensdk/test/brain_observatory/behavior/data_objects/metadata/behavior_metadata/test_behavior_metadata.py +++ b/allensdk/test/brain_observatory/behavior/data_objects/metadata/behavior_metadata/test_behavior_metadata.py @@ -67,13 +67,8 @@ def _get_meta(): sex=Sex(sex="M"), age=Age(age=139), reporter_line=ReporterLine(reporter_line="Ai93(TITL-GCaMP6f)"), - full_genotype=FullGenotype( - full_genotype="Slc17a7-IRES2-Cre/wt;Camk2a-tTA/wt;" - "Ai93(TITL-GCaMP6f)/wt" - ), - driver_line=DriverLine( - driver_line=["Camk2a-tTA", "Slc17a7-IRES2-Cre"] - ), + full_genotype=FullGenotype(full_genotype="Slc17a7-IRES2-Cre/wt;Camk2a-tTA/wt;Ai93(TITL-GCaMP6f)/wt"), + driver_line=DriverLine(driver_line=["Camk2a-tTA", "Slc17a7-IRES2-Cre"]), mouse_id=MouseId(mouse_id="416369"), ) behavior_meta = BehaviorMetadata( @@ -82,12 +77,8 @@ def _get_meta(): equipment=Equipment(equipment_name="my_device"), stimulus_frame_rate=StimulusFrameRate(stimulus_frame_rate=60.0), session_type=SessionType(session_type="Unknown"), - behavior_session_uuid=BehaviorSessionUUID( - behavior_session_uuid=uuid.uuid4() - ), - date_of_acquisition=DateOfAcquisition( - datetime.datetime(2022, 8, 24, 12, 35) - ), + behavior_session_uuid=BehaviorSessionUUID(behavior_session_uuid=uuid.uuid4()), + date_of_acquisition=DateOfAcquisition(datetime.datetime(2022, 8, 24, 12, 35)), project_code=ProjectCode("1234"), ) return behavior_meta @@ -98,22 +89,16 @@ class TestLims(LimsTest): def test_behavior_session_uuid(self): behavior_session_id = 823847007 meta = BehaviorMetadata.from_lims( - behavior_session_id=BehaviorSessionId( - behavior_session_id=behavior_session_id - ), + behavior_session_id=BehaviorSessionId(behavior_session_id=behavior_session_id), lims_db=self.dbconn, ) - assert meta.behavior_session_uuid == uuid.UUID( - "394a910e-94c7-4472-9838-5345aff59ed8" - ) + assert meta.behavior_session_uuid == uuid.UUID("394a910e-94c7-4472-9838-5345aff59ed8") class TestBehaviorMetadata(BehaviorMetaTestCase): def test_cre_line(self): """Tests that cre_line properly parsed from driver_line""" - fg = FullGenotype( - full_genotype="Sst-IRES-Cre/wt;Ai148(TIT2L-GC6f-ICL-tTA2)/wt" - ) + fg = FullGenotype(full_genotype="Sst-IRES-Cre/wt;Ai148(TIT2L-GC6f-ICL-tTA2)/wt") assert fg.parse_cre_line() == "Sst-IRES-Cre" def test_cre_line_bad_full_genotype(self): @@ -123,10 +108,7 @@ def test_cre_line_bad_full_genotype(self): with pytest.warns(UserWarning) as record: cre_line = fg.parse_cre_line(warn=True) assert cre_line is None - assert ( - str(record[0].message) == "Unable to parse cre_line from " - "full_genotype" - ) + assert str(record[0].message) == "Unable to parse cre_line from full_genotype" def test_cre_line_full_genotype_is_none(self): """Test that cre_line is None and no error raised""" @@ -135,10 +117,7 @@ def test_cre_line_full_genotype_is_none(self): with pytest.warns(UserWarning) as record: cre_line = fg.parse_cre_line(warn=True) assert cre_line is None - assert ( - str(record[0].message) == "Unable to parse cre_line from " - "full_genotype" - ) + assert str(record[0].message) == "Unable to parse cre_line from full_genotype" def test_reporter_line(self): """Test that reporter line properly parsed from list""" @@ -155,21 +134,17 @@ def test_reporter_line_str(self): ( ( ("foo", "bar"), - "More than 1 reporter line. " "Returning the first one", + "More than 1 reporter line. Returning the first one", "foo", ), (None, "Error parsing reporter line. It is null.", None), ([], "Error parsing reporter line. The array is empty", None), ), ) - def test_reporter_edge_cases( - self, input_reporter_line, warning_msg, expected - ): + def test_reporter_edge_cases(self, input_reporter_line, warning_msg, expected): """Test reporter line edge cases""" with pytest.warns(UserWarning) as record: - reporter_line = ReporterLine.parse( - reporter_line=input_reporter_line, warn=True - ) + reporter_line = ReporterLine.parse(reporter_line=input_reporter_line, warn=True) assert reporter_line == expected assert str(record[0].message) == warning_msg @@ -183,21 +158,17 @@ def test_age_in_days(self): ( ( "unkown", - "Could not parse numeric age from age code " - '(age code does not start with "P")', + 'Could not parse numeric age from age code (age code does not start with "P")', None, ), ( "P", - "Could not parse numeric age from age code " - "(no numeric values found in age code)", + "Could not parse numeric age from age code (no numeric values found in age code)", None, ), ), ) - def test_age_in_days_edge_cases( - self, monkeypatch, input_age, warning_msg, expected - ): + def test_age_in_days_edge_cases(self, monkeypatch, input_age, warning_msg, expected): """Test age in days edge cases""" with pytest.warns(UserWarning) as record: age_in_days = Age._age_code_to_days(age=input_age, warn=True) @@ -211,12 +182,8 @@ def test_age_in_days_edge_cases( # Vanilla test case ( { - "extractor_expt_date": datetime.datetime.strptime( - "2021-03-14 03:14:15", "%Y-%m-%d %H:%M:%S" - ), - "pkl_expt_date": datetime.datetime.strptime( - "2021-03-14 03:14:15", "%Y-%m-%d %H:%M:%S" - ), + "extractor_expt_date": datetime.datetime.strptime("2021-03-14 03:14:15", "%Y-%m-%d %H:%M:%S"), + "pkl_expt_date": datetime.datetime.strptime("2021-03-14 03:14:15", "%Y-%m-%d %H:%M:%S"), "behavior_session_id": 1, }, None, @@ -224,9 +191,7 @@ def test_age_in_days_edge_cases( # pkl expt date stored in unix format ( { - "extractor_expt_date": datetime.datetime.strptime( - "2021-03-14 03:14:15", "%Y-%m-%d %H:%M:%S" - ), + "extractor_expt_date": datetime.datetime.strptime("2021-03-14 03:14:15", "%Y-%m-%d %H:%M:%S"), "pkl_expt_date": 1615716855.0, "behavior_session_id": 2, }, @@ -235,12 +200,8 @@ def test_age_in_days_edge_cases( # Extractor and pkl dates differ significantly ( { - "extractor_expt_date": datetime.datetime.strptime( - "2021-03-14 03:14:15", "%Y-%m-%d %H:%M:%S" - ), - "pkl_expt_date": datetime.datetime.strptime( - "2021-03-14 20:14:15", "%Y-%m-%d %H:%M:%S" - ), + "extractor_expt_date": datetime.datetime.strptime("2021-03-14 03:14:15", "%Y-%m-%d %H:%M:%S"), + "pkl_expt_date": datetime.datetime.strptime("2021-03-14 20:14:15", "%Y-%m-%d %H:%M:%S"), "behavior_session_id": 3, }, "The `date_of_acquisition` field in LIMS *", @@ -248,9 +209,7 @@ def test_age_in_days_edge_cases( # pkl file contains an unparseable datetime ( { - "extractor_expt_date": datetime.datetime.strptime( - "2021-03-14 03:14:15", "%Y-%m-%d %H:%M:%S" - ), + "extractor_expt_date": datetime.datetime.strptime("2021-03-14 03:14:15", "%Y-%m-%d %H:%M:%S"), "pkl_expt_date": None, "behavior_session_id": 4, }, @@ -258,9 +217,7 @@ def test_age_in_days_edge_cases( ), ], ) - def test_get_date_of_acquisition( - self, tmp_path, test_params, expected_warn_msg - ): + def test_get_date_of_acquisition(self, tmp_path, test_params, expected_warn_msg): mock_session_id = test_params["behavior_session_id"] pkl_save_path = tmp_path / f"mock_pkl_{mock_session_id}.pkl" @@ -268,9 +225,7 @@ def test_get_date_of_acquisition( pickle.dump({"start_time": test_params["pkl_expt_date"]}, handle) tz = pytz.timezone("America/Los_Angeles") - extractor_expt_date = tz.localize( - test_params["extractor_expt_date"] - ).astimezone(pytz.utc) + extractor_expt_date = tz.localize(test_params["extractor_expt_date"]).astimezone(pytz.utc) stimulus_file = BehaviorStimulusFile(filepath=pkl_save_path) obt_date = DateOfAcquisition(date_of_acquisition=extractor_expt_date) @@ -286,9 +241,7 @@ def test_get_date_of_acquisition( def test_indicator(self): """Test that indicator is parsed from full_genotype""" - reporter_line = ReporterLine( - reporter_line="Ai148(TIT2L-GC6f-ICL-tTA2)" - ) + reporter_line = ReporterLine(reporter_line="Ai148(TIT2L-GC6f-ICL-tTA2)") assert reporter_line.parse_indicator() == "GCaMP6f" @pytest.mark.parametrize( @@ -296,8 +249,7 @@ def test_indicator(self): ( ( None, - "Could not parse indicator from reporter because there is no " - "reporter", + "Could not parse indicator from reporter because there is no reporter", None, ), ( @@ -308,9 +260,7 @@ def test_indicator(self): ), ), ) - def test_indicator_edge_cases( - self, input_reporter_line, warning_msg, expected - ): + def test_indicator_edge_cases(self, input_reporter_line, warning_msg, expected): """Test indicator parsing edge cases""" with pytest.warns(UserWarning) as record: reporter_line = ReporterLine(reporter_line=input_reporter_line) @@ -326,33 +276,23 @@ def setup_class(cls): dir = Path(__file__).parent.parent.parent.resolve() test_data_dir = dir / "test_data" sf_path = test_data_dir / "stimulus_file.pkl" - cls.stimulus_file = BehaviorStimulusFile.from_json( - dict_repr={"behavior_stimulus_file": str(sf_path)} - ) + cls.stimulus_file = BehaviorStimulusFile.from_json(dict_repr={"behavior_stimulus_file": str(sf_path)}) def test_session_uuid(self): - uuid = BehaviorSessionUUID.from_stimulus_file( - stimulus_file=self.stimulus_file - ) + uuid = BehaviorSessionUUID.from_stimulus_file(stimulus_file=self.stimulus_file) expected = UUID("138531ab-fe59-4523-9154-07c8d97bbe03") assert expected == uuid.value def test_get_stimulus_frame_rate(self): - rate = StimulusFrameRate.from_stimulus_file( - stimulus_file=self.stimulus_file - ) + rate = StimulusFrameRate.from_stimulus_file(stimulus_file=self.stimulus_file) assert 62.0 == rate.value def test_date_of_acquisition_utc(): """Tests that when read from json (in Pacific time), that date of acquisition is converted to utc""" - expected = DateOfAcquisition( - date_of_acquisition=datetime.datetime(2019, 9, 26, 16, tzinfo=pytz.UTC) - ) - actual = DateOfAcquisition.from_json( - dict_repr={"date_of_acquisition": "2019-09-26 09:00:00"} - ) + expected = DateOfAcquisition(date_of_acquisition=datetime.datetime(2019, 9, 26, 16, tzinfo=pytz.UTC)) + actual = DateOfAcquisition.from_json(dict_repr={"date_of_acquisition": "2019-09-26 09:00:00"}) assert expected == actual @@ -361,21 +301,15 @@ def setup_method(self, method): self.nwbfile = pynwb.NWBFile( session_description="asession", identifier="afile", - session_start_time=datetime.datetime( - 2022, 8, 24, 12, 35, tzinfo=pytz.UTC - ), + session_start_time=datetime.datetime(2022, 8, 24, 12, 35, tzinfo=pytz.UTC), ) @pytest.mark.parametrize("roundtrip", [True, False]) - def test_add_behavior_only_metadata( - self, roundtrip, data_object_roundtrip_fixture - ): + def test_add_behavior_only_metadata(self, roundtrip, data_object_roundtrip_fixture): self.meta.to_nwb(nwbfile=self.nwbfile) if roundtrip: - meta_obt = data_object_roundtrip_fixture( - self.nwbfile, BehaviorMetadata - ) + meta_obt = data_object_roundtrip_fixture(self.nwbfile, BehaviorMetadata) else: meta_obt = BehaviorMetadata.from_nwb(nwbfile=self.nwbfile) assert self.meta == meta_obt diff --git a/allensdk/test/brain_observatory/behavior/data_objects/metadata/test_behavior_ophys_metadata.py b/allensdk/test/brain_observatory/behavior/data_objects/metadata/test_behavior_ophys_metadata.py index e5d6c58762..cb62b9590b 100644 --- a/allensdk/test/brain_observatory/behavior/data_objects/metadata/test_behavior_ophys_metadata.py +++ b/allensdk/test/brain_observatory/behavior/data_objects/metadata/test_behavior_ophys_metadata.py @@ -61,28 +61,20 @@ def _get_meta(): ophys_container_id=OphysContainerId(ophys_container_id=5678), field_of_view_shape=FieldOfViewShape(width=4, height=4), imaging_depth=ImagingDepth(imaging_depth=375), - targeted_imaging_depth=TargetedImagingDepth( - targeted_imaging_depth=375 - ), + targeted_imaging_depth=TargetedImagingDepth(targeted_imaging_depth=375), project_code=OphysProjectCode("1234"), ) behavior_metadata = TestBehaviorMetadata() behavior_metadata.setup_class() - return BehaviorOphysMetadata( - behavior_metadata=behavior_metadata.meta, ophys_metadata=ophys_meta - ) + return BehaviorOphysMetadata(behavior_metadata=behavior_metadata.meta, ophys_metadata=ophys_meta) def _get_multiplane_meta(self): bo_meta = self.meta - bo_meta.behavior_metadata._equipment = Equipment( - equipment_name="MESO.1" - ) + bo_meta.behavior_metadata._equipment = Equipment(equipment_name="MESO.1") ophys_experiment_metadata = bo_meta.ophys_metadata - imaging_plane_group = ImagingPlaneGroup( - plane_group_count=5, plane_group=0 - ) + imaging_plane_group = ImagingPlaneGroup(plane_group_count=5, plane_group=0) multiplane_meta = MultiplaneMetadata( ophys_experiment_id=ophys_experiment_metadata.ophys_experiment_id, ophys_session_id=ophys_experiment_metadata._ophys_session_id, @@ -108,9 +100,7 @@ def setup_method(self, method): # Will only create a dbconn if the test requires_bamboo if "requires_bamboo" in marks: - self.dbconn = db_connection_creator( - fallback_credentials=LIMS_DB_CREDENTIAL_MAP - ) + self.dbconn = db_connection_creator(fallback_credentials=LIMS_DB_CREDENTIAL_MAP) @pytest.mark.requires_bamboo @pytest.mark.parametrize("meso", [True, False]) @@ -131,17 +121,11 @@ def test_from_lims(self, meso): assert bom.ophys_metadata.targeted_imaging_depth == 150 assert bom.behavior_metadata.session_type == "OPHYS_1_images_A" - assert ( - bom.behavior_metadata.subject_metadata.reporter_line - == "Ai148(TIT2L-GC6f-ICL-tTA2)" - ) - assert bom.behavior_metadata.subject_metadata.driver_line == [ - "Sst-IRES-Cre" - ] + assert bom.behavior_metadata.subject_metadata.reporter_line == "Ai148(TIT2L-GC6f-ICL-tTA2)" + assert bom.behavior_metadata.subject_metadata.driver_line == ["Sst-IRES-Cre"] assert bom.behavior_metadata.subject_metadata.mouse_id == "457841" assert ( - bom.behavior_metadata.subject_metadata.full_genotype - == "Sst-IRES-Cre/wt;Ai148(TIT2L-GC6f-ICL-tTA2)/wt" + bom.behavior_metadata.subject_metadata.full_genotype == "Sst-IRES-Cre/wt;Ai148(TIT2L-GC6f-ICL-tTA2)/wt" ) assert bom.behavior_metadata.subject_metadata.age_in_days == 206 assert bom.behavior_metadata.subject_metadata.sex == "F" @@ -150,10 +134,7 @@ def test_from_lims(self, meso): assert bom.ophys_metadata.imaging_depth == 175 assert bom.ophys_metadata.targeted_imaging_depth == 175 assert bom.behavior_metadata.session_type == "OPHYS_4_images_A" - assert ( - bom.behavior_metadata.subject_metadata.reporter_line - == "Ai93(TITL-GCaMP6f)" - ) + assert bom.behavior_metadata.subject_metadata.reporter_line == "Ai93(TITL-GCaMP6f)" assert bom.behavior_metadata.subject_metadata.driver_line == [ "Camk2a-tTA", "Slc17a7-IRES2-Cre", @@ -176,9 +157,7 @@ def setup_method(self, method): dict_repr = json.load(f) dict_repr = dict_repr["session_data"] dict_repr["sync_file"] = str(test_data_dir / "sync.h5") - dict_repr["behavior_stimulus_file"] = str( - test_data_dir / "behavior_stimulus_file.pkl" - ) + dict_repr["behavior_stimulus_file"] = str(test_data_dir / "behavior_stimulus_file.pkl") dict_repr["dff_file"] = str(test_data_dir / "demix_file.h5") self.dict_repr = dict_repr @@ -187,9 +166,7 @@ def setup_method(self, method): def test_from_json(self, meso): if meso: self.dict_repr["rig_name"] = "MESO.1" - bom = BehaviorOphysMetadata.from_json( - dict_repr=self.dict_repr, is_multiplane=meso - ) + bom = BehaviorOphysMetadata.from_json(dict_repr=self.dict_repr, is_multiplane=meso) if meso: assert isinstance(bom.ophys_metadata, MultiplaneMetadata) @@ -203,16 +180,12 @@ def setup_method(self, method): self.nwbfile = pynwb.NWBFile( session_description="asession", identifier=str(self.meta.ophys_metadata.ophys_experiment_id), - session_start_time=datetime.datetime( - 2022, 8, 24, 12, 35, tzinfo=pytz.UTC - ), + session_start_time=datetime.datetime(2022, 8, 24, 12, 35, tzinfo=pytz.UTC), ) @pytest.mark.parametrize("meso", [True, False]) @pytest.mark.parametrize("roundtrip", [True, False]) - def test_read_write_nwb( - self, roundtrip, data_object_roundtrip_fixture, meso - ): + def test_read_write_nwb(self, roundtrip, data_object_roundtrip_fixture, meso): if meso: self.meta = self._get_multiplane_meta() diff --git a/allensdk/test/brain_observatory/behavior/data_objects/nwb_input_json.py b/allensdk/test/brain_observatory/behavior/data_objects/nwb_input_json.py index 785671dbed..224192fae5 100644 --- a/allensdk/test/brain_observatory/behavior/data_objects/nwb_input_json.py +++ b/allensdk/test/brain_observatory/behavior/data_objects/nwb_input_json.py @@ -10,13 +10,9 @@ def __init__(self): dict_repr = json.load(f) dict_repr = dict_repr["session_data"] dict_repr["sync_file"] = str(test_data_dir / "sync.h5") - dict_repr["behavior_stimulus_file"] = str( - test_data_dir / "behavior_stimulus_file.pkl" - ) + dict_repr["behavior_stimulus_file"] = str(test_data_dir / "behavior_stimulus_file.pkl") dict_repr["dff_file"] = str(test_data_dir / "demix_file.h5") - dict_repr["neuropil_corrected_file"] = str( - test_data_dir / "neuropil_corrected_file.h5" - ) + dict_repr["neuropil_corrected_file"] = str(test_data_dir / "neuropil_corrected_file.h5") dict_repr["demix_file"] = str(test_data_dir / "demix_file.h5") dict_repr["neuropil_file"] = str(test_data_dir / "demix_file.h5") dict_repr["events_file"] = str(test_data_dir / "events.h5") diff --git a/allensdk/test/brain_observatory/behavior/data_objects/running_speed/conftest.py b/allensdk/test/brain_observatory/behavior/data_objects/running_speed/conftest.py index 23cac166f2..c839c957b8 100644 --- a/allensdk/test/brain_observatory/behavior/data_objects/running_speed/conftest.py +++ b/allensdk/test/brain_observatory/behavior/data_objects/running_speed/conftest.py @@ -6,13 +6,13 @@ import h5py import json -from allensdk.brain_observatory.behavior.data_files.sync_file import ( - SyncFile) +from allensdk.brain_observatory.behavior.data_files.sync_file import SyncFile from allensdk.brain_observatory.behavior.data_files.stimulus_file import ( BehaviorStimulusFile, ReplayStimulusFile, - MappingStimulusFile) + MappingStimulusFile, +) @pytest.fixture @@ -26,13 +26,15 @@ def basic_running_stim_file_fixture(): "dx": rng.random((100,)), "vsig": rng.uniform(low=0.0, high=5.1, size=(100,)), "vin": rng.uniform(low=4.9, high=5.0, size=(100,)), - }]}}} + } + ] + } + } + } -@pytest.fixture(scope='session') -def stimulus_file_frame_fixture( - tmp_path_factory, - helper_functions): +@pytest.fixture(scope="session") +def stimulus_file_frame_fixture(tmp_path_factory, helper_functions): """ Writes some skeletal stimulus files (really only good for getting frame counts) to disk. Yields a tuple of dicts @@ -45,27 +47,21 @@ def stimulus_file_frame_fixture( temporary pickle file """ - tmpdir = tmp_path_factory.mktemp('all_frame_count_test') + tmpdir = tmp_path_factory.mktemp("all_frame_count_test") pkl_path_lookup = dict() - pkl_path_lookup['behavior'] = pathlib.Path( - tempfile.mkstemp(dir=tmpdir, suffix='.pkl')[1]) + pkl_path_lookup["behavior"] = pathlib.Path(tempfile.mkstemp(dir=tmpdir, suffix=".pkl")[1]) - pkl_path_lookup['mapping'] = pathlib.Path( - tempfile.mkstemp(dir=tmpdir, suffix='.pkl')[1]) + pkl_path_lookup["mapping"] = pathlib.Path(tempfile.mkstemp(dir=tmpdir, suffix=".pkl")[1]) - pkl_path_lookup['replay'] = pathlib.Path( - tempfile.mkstemp(dir=tmpdir, suffix='.pkl')[1]) + pkl_path_lookup["replay"] = pathlib.Path(tempfile.mkstemp(dir=tmpdir, suffix=".pkl")[1]) - frame_count_lookup = {'behavior': 13, 'mapping': 44, 'replay': 76} + frame_count_lookup = {"behavior": 13, "mapping": 44, "replay": 76} - data = {'items': - {'behavior': - {'intervalsms': - list(range(frame_count_lookup['behavior']-1))}}} - pd.to_pickle(data, pkl_path_lookup['behavior']) + data = {"items": {"behavior": {"intervalsms": list(range(frame_count_lookup["behavior"] - 1))}}} + pd.to_pickle(data, pkl_path_lookup["behavior"]) - for key in ('mapping', 'replay'): - data = {'intervalsms': list(range(frame_count_lookup[key]-1))} + for key in ("mapping", "replay"): + data = {"intervalsms": list(range(frame_count_lookup[key] - 1))} pd.to_pickle(data, pkl_path_lookup[key]) output_pkl_path = dict() @@ -77,11 +73,10 @@ def stimulus_file_frame_fixture( yield (frame_count_lookup, output_pkl_path) for key in pkl_path_lookup: - helper_functions.windows_safe_cleanup( - file_path=pkl_path_lookup[key]) + helper_functions.windows_safe_cleanup(file_path=pkl_path_lookup[key]) -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def merge_data_fixture(): """ Return a dict keyed on @@ -98,58 +93,50 @@ def merge_data_fixture(): output = dict() rng = np.random.default_rng(229988) n_total = 0 - for key in ('behavior', 'mapping', 'replay'): + for key in ("behavior", "mapping", "replay"): nt = rng.integers(20, 40, 1)[0] n_total += nt this_entry = dict() speed = rng.random(nt) - dx = rng.random(nt)+0.1 + dx = rng.random(nt) + 0.1 v_in = rng.random(nt) v_sig = rng.random(nt) - n_skip = rng.integers(nt//4, nt//3, 1)[0] - skipped = np.sort(rng.choice(np.arange(nt, dtype=int), - n_skip, - replace=False)) + n_skip = rng.integers(nt // 4, nt // 3, 1)[0] + skipped = np.sort(rng.choice(np.arange(nt, dtype=int), n_skip, replace=False)) dx[skipped] = 0.0 - df = pd.DataFrame(data={'dx': dx, - 'v_in': v_in, - 'v_sig': v_sig, - 'speed': speed}) + df = pd.DataFrame(data={"dx": dx, "v_in": v_in, "v_sig": v_sig, "speed": speed}) kept_mask = np.ones(nt, dtype=bool) kept_mask[skipped] = False - this_entry['speed'] = speed - this_entry['dx'] = dx - this_entry['v_in'] = v_in - this_entry['v_sig'] = v_sig - this_entry['dataframe'] = df - this_entry['kept_mask'] = kept_mask + this_entry["speed"] = speed + this_entry["dx"] = dx + this_entry["v_in"] = v_in + this_entry["v_sig"] = v_sig + this_entry["dataframe"] = df + this_entry["kept_mask"] = kept_mask output[key] = this_entry - output['n_timesteps'] = n_total + output["n_timesteps"] = n_total return output # Below here we define a self-consistent set of stimulus datasets for testing -@pytest.fixture(scope='session') -def pkl_tmp_dir_fixture( - tmp_path_factory): + +@pytest.fixture(scope="session") +def pkl_tmp_dir_fixture(tmp_path_factory): """ A directory where the temporary data files can be written """ - tmpdir = pathlib.Path(tmp_path_factory.mktemp( - 'self_consistent_multi_stim')) + tmpdir = pathlib.Path(tmp_path_factory.mktemp("self_consistent_multi_stim")) return tmpdir -@pytest.fixture(scope='session') -def behavior_pkl_fixture( - pkl_tmp_dir_fixture, - helper_functions): +@pytest.fixture(scope="session") +def behavior_pkl_fixture(pkl_tmp_dir_fixture, helper_functions): """ Write a pkl file for behavior data. @@ -163,17 +150,11 @@ def behavior_pkl_fixture( rng = np.random.default_rng(77123412) - pkl_path = pathlib.Path( - tempfile.mkstemp( - dir=pkl_tmp_dir_fixture, - prefix='behavior_', - suffix='.pkl')[1]) + pkl_path = pathlib.Path(tempfile.mkstemp(dir=pkl_tmp_dir_fixture, prefix="behavior_", suffix=".pkl")[1]) n_frames = 77 - n_skip = rng.integers(10, n_frames//2, 1)[0] - to_skip = rng.choice(np.arange(n_frames, dtype=int), - n_skip, - replace=False) + n_skip = rng.integers(10, n_frames // 2, 1)[0] + to_skip = rng.choice(np.arange(n_frames, dtype=int), n_skip, replace=False) kept_mask = np.ones(n_frames, dtype=bool) kept_mask[to_skip] = False @@ -190,35 +171,36 @@ def behavior_pkl_fixture( "dx": dx, "vsig": vsig, "vin": vin, - }]}}} + } + ] + } + } + } - data['items']['behavior']['intervalsms'] = list(range(n_frames-1)) + data["items"]["behavior"]["intervalsms"] = list(range(n_frames - 1)) pd.to_pickle(data, pkl_path) output = { - 'path_to_pkl': str(pkl_path.resolve().absolute()), - 'dx': dx, - 'vsig': vsig, - 'vin': vin, - 'n_frames': n_frames, - 'kept_mask': kept_mask} + "path_to_pkl": str(pkl_path.resolve().absolute()), + "dx": dx, + "vsig": vsig, + "vin": vin, + "n_frames": n_frames, + "kept_mask": kept_mask, + } yield output helper_functions.windows_safe_cleanup(file_path=pkl_path) -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def behavior_stim_file_fixture(behavior_pkl_fixture): - yield BehaviorStimulusFile.from_json( - dict_repr={'behavior_stimulus_file': - behavior_pkl_fixture['path_to_pkl']}) + yield BehaviorStimulusFile.from_json(dict_repr={"behavior_stimulus_file": behavior_pkl_fixture["path_to_pkl"]}) -@pytest.fixture(scope='session') -def mapping_pkl_fixture( - pkl_tmp_dir_fixture, - helper_functions): +@pytest.fixture(scope="session") +def mapping_pkl_fixture(pkl_tmp_dir_fixture, helper_functions): """ Write a pkl file for mapping data. @@ -232,17 +214,11 @@ def mapping_pkl_fixture( rng = np.random.default_rng(442138) - pkl_path = pathlib.Path( - tempfile.mkstemp( - dir=pkl_tmp_dir_fixture, - prefix='mapping_', - suffix='.pkl')[1]) + pkl_path = pathlib.Path(tempfile.mkstemp(dir=pkl_tmp_dir_fixture, prefix="mapping_", suffix=".pkl")[1]) n_frames = 53 - n_skip = rng.integers(10, n_frames//2, 1)[0] - to_skip = rng.choice(np.arange(n_frames, dtype=int), - n_skip, - replace=False) + n_skip = rng.integers(10, n_frames // 2, 1)[0] + to_skip = rng.choice(np.arange(n_frames, dtype=int), n_skip, replace=False) kept_mask = np.ones(n_frames, dtype=bool) kept_mask[to_skip] = False @@ -259,35 +235,36 @@ def mapping_pkl_fixture( "dx": dx, "vsig": vsig, "vin": vin, - }]}}} + } + ] + } + } + } - data['intervalsms'] = list(range(n_frames-1)) + data["intervalsms"] = list(range(n_frames - 1)) pd.to_pickle(data, pkl_path) output = { - 'path_to_pkl': str(pkl_path.resolve().absolute()), - 'dx': dx, - 'vsig': vsig, - 'vin': vin, - 'n_frames': n_frames, - 'kept_mask': kept_mask} + "path_to_pkl": str(pkl_path.resolve().absolute()), + "dx": dx, + "vsig": vsig, + "vin": vin, + "n_frames": n_frames, + "kept_mask": kept_mask, + } yield output helper_functions.windows_safe_cleanup(file_path=pkl_path) -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def mapping_stim_file_fixture(mapping_pkl_fixture): - yield MappingStimulusFile.from_json( - dict_repr={'mapping_stimulus_file': - mapping_pkl_fixture['path_to_pkl']}) + yield MappingStimulusFile.from_json(dict_repr={"mapping_stimulus_file": mapping_pkl_fixture["path_to_pkl"]}) -@pytest.fixture(scope='session') -def replay_pkl_fixture( - pkl_tmp_dir_fixture, - helper_functions): +@pytest.fixture(scope="session") +def replay_pkl_fixture(pkl_tmp_dir_fixture, helper_functions): """ Write a pkl file for replay data. @@ -301,17 +278,11 @@ def replay_pkl_fixture( rng = np.random.default_rng(55332211) - pkl_path = pathlib.Path( - tempfile.mkstemp( - dir=pkl_tmp_dir_fixture, - prefix='replay_', - suffix='.pkl')[1]) + pkl_path = pathlib.Path(tempfile.mkstemp(dir=pkl_tmp_dir_fixture, prefix="replay_", suffix=".pkl")[1]) n_frames = 47 - n_skip = rng.integers(10, n_frames//2, 1)[0] - to_skip = rng.choice(np.arange(n_frames, dtype=int), - n_skip, - replace=False) + n_skip = rng.integers(10, n_frames // 2, 1)[0] + to_skip = rng.choice(np.arange(n_frames, dtype=int), n_skip, replace=False) kept_mask = np.ones(n_frames, dtype=bool) kept_mask[to_skip] = False @@ -328,105 +299,101 @@ def replay_pkl_fixture( "dx": dx, "vsig": vsig, "vin": vin, - }]}}} + } + ] + } + } + } - data['intervalsms'] = list(range(n_frames-1)) + data["intervalsms"] = list(range(n_frames - 1)) pd.to_pickle(data, pkl_path) output = { - 'path_to_pkl': str(pkl_path.resolve().absolute()), - 'dx': dx, - 'vsig': vsig, - 'vin': vin, - 'n_frames': n_frames, - 'kept_mask': kept_mask} + "path_to_pkl": str(pkl_path.resolve().absolute()), + "dx": dx, + "vsig": vsig, + "vin": vin, + "n_frames": n_frames, + "kept_mask": kept_mask, + } yield output helper_functions.windows_safe_cleanup(file_path=pkl_path) -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def replay_stim_file_fixture(replay_pkl_fixture): - yield ReplayStimulusFile.from_json( - dict_repr={'replay_stimulus_file': - replay_pkl_fixture['path_to_pkl']}) + yield ReplayStimulusFile.from_json(dict_repr={"replay_stimulus_file": replay_pkl_fixture["path_to_pkl"]}) -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def sync_path_fixture( - behavior_pkl_fixture, - replay_pkl_fixture, - mapping_pkl_fixture, - pkl_tmp_dir_fixture, - helper_functions): + behavior_pkl_fixture, replay_pkl_fixture, mapping_pkl_fixture, pkl_tmp_dir_fixture, helper_functions +): """ path to test sync file that can go with these pickle files """ - metadata = {'ni_daq': - {'device': 'Dev1', - 'counter_output_freq': 100.0, - 'sample_rate': 100.0, - 'counter_bits': 32, - 'event_bits': 32}, - 'start_time': '2020-10-07 14:01:17.336502', - 'stop_time': '2020-10-07 16:42:24.177205', - 'line_labels': ['vsync_stim', - 'stim_running', - 'vsync_2p', - 'lick_sensor', - 'eye_tracking', - 'behavior_monitoring', - 'stim_photodiode'], - 'timeouts': [], - 'version': '2.2.1+g1bc7438.b42257', - 'sampling_type': 'frequency', - 'file_version': '1.0.0', - 'line_label_revision': 3, - 'total_samples': 10000} - - sync_path = pathlib.Path( - tempfile.mkstemp( - dir=pkl_tmp_dir_fixture, - prefix='example_', - suffix='.sync')[1]) - - nb = behavior_pkl_fixture['n_frames'] - nm = mapping_pkl_fixture['n_frames'] - nr = replay_pkl_fixture['n_frames'] + metadata = { + "ni_daq": { + "device": "Dev1", + "counter_output_freq": 100.0, + "sample_rate": 100.0, + "counter_bits": 32, + "event_bits": 32, + }, + "start_time": "2020-10-07 14:01:17.336502", + "stop_time": "2020-10-07 16:42:24.177205", + "line_labels": [ + "vsync_stim", + "stim_running", + "vsync_2p", + "lick_sensor", + "eye_tracking", + "behavior_monitoring", + "stim_photodiode", + ], + "timeouts": [], + "version": "2.2.1+g1bc7438.b42257", + "sampling_type": "frequency", + "file_version": "1.0.0", + "line_label_revision": 3, + "total_samples": 10000, + } + + sync_path = pathlib.Path(tempfile.mkstemp(dir=pkl_tmp_dir_fixture, prefix="example_", suffix=".sync")[1]) + + nb = behavior_pkl_fixture["n_frames"] + nm = mapping_pkl_fixture["n_frames"] + nr = replay_pkl_fixture["n_frames"] nframes = nb + nm + nr - n_lines = 5*nframes + n_lines = 5 * nframes data = np.zeros((n_lines, 2), dtype=np.uint32) - data[:, 0] = np.arange(n_lines, dtype=np.uint32)+1 + data[:, 0] = np.arange(n_lines, dtype=np.uint32) + 1 - for i_beh in range(2, 2+3*nb, 3): - data[i_beh:i_beh+1, 1] += 1 - data[1:i_beh+1, 1] += 2 + for i_beh in range(2, 2 + 3 * nb, 3): + data[i_beh : i_beh + 1, 1] += 1 + data[1 : i_beh + 1, 1] += 2 - for i_map in range(i_beh+3, i_beh+1+3*nm, 3): - data[i_map:i_map+1, 1] += 1 - data[i_beh+2:i_map+1] += 2 + for i_map in range(i_beh + 3, i_beh + 1 + 3 * nm, 3): + data[i_map : i_map + 1, 1] += 1 + data[i_beh + 2 : i_map + 1] += 2 - for i_replay in range(i_map+4, i_map+1+3*nb, 3): - data[i_replay: i_replay+1, 1] += 1 - data[i_map+4: i_replay+1] += 2 + for i_replay in range(i_map + 4, i_map + 1 + 3 * nb, 3): + data[i_replay : i_replay + 1, 1] += 1 + data[i_map + 4 : i_replay + 1] += 2 - with h5py.File(sync_path, 'w') as out_file: - out_file.create_dataset( - 'data', data=data) - out_file.create_dataset( - 'meta', - data=json.dumps(metadata).encode('utf-8')) + with h5py.File(sync_path, "w") as out_file: + out_file.create_dataset("data", data=data) + out_file.create_dataset("meta", data=json.dumps(metadata).encode("utf-8")) yield sync_path helper_functions.windows_safe_cleanup(file_path=sync_path) -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def sync_file_fixture(sync_path_fixture): - return SyncFile.from_json( - dict_repr={'sync_file': - str(sync_path_fixture.resolve().absolute())}) + return SyncFile.from_json(dict_repr={"sync_file": str(sync_path_fixture.resolve().absolute())}) diff --git a/allensdk/test/brain_observatory/behavior/data_objects/running_speed/test_multi_stim_running_processing.py b/allensdk/test/brain_observatory/behavior/data_objects/running_speed/test_multi_stim_running_processing.py index dfd6ddd14a..5332909205 100644 --- a/allensdk/test/brain_observatory/behavior/data_objects/running_speed/test_multi_stim_running_processing.py +++ b/allensdk/test/brain_observatory/behavior/data_objects/running_speed/test_multi_stim_running_processing.py @@ -6,56 +6,42 @@ import pathlib import pandas as pd -from allensdk.brain_observatory.behavior.\ - data_objects.running_speed.running_processing import ( - get_running_df) +from allensdk.brain_observatory.behavior.data_objects.running_speed.running_processing import get_running_df from allensdk.brain_observatory.behavior.data_files.stimulus_file import ( BehaviorStimulusFile, ReplayStimulusFile, - MappingStimulusFile) + MappingStimulusFile, +) -from allensdk.brain_observatory.behavior.\ - data_objects.running_speed.multi_stim_running_processing import ( - _extract_dx_info, - _merge_dx_data, - multi_stim_running_df_from_raw_data) +from allensdk.brain_observatory.behavior.data_objects.running_speed.multi_stim_running_processing import ( + _extract_dx_info, + _merge_dx_data, + multi_stim_running_df_from_raw_data, +) -@pytest.mark.parametrize( - "use_lowpass, zscore", - product((True, False), - (5.0, 10.0))) -def test_extract_dx_basic( - basic_running_stim_file_fixture, - use_lowpass, - zscore, - tmp_path_factory, - helper_functions): +@pytest.mark.parametrize("use_lowpass, zscore", product((True, False), (5.0, 10.0))) +def test_extract_dx_basic(basic_running_stim_file_fixture, use_lowpass, zscore, tmp_path_factory, helper_functions): """ Test that _extract_dx_info behaves like get_running_df with start_index and end_index handled correctly """ - tmpdir = tmp_path_factory.mktemp('extract_dx_test') + tmpdir = tmp_path_factory.mktemp("extract_dx_test") stim_file = basic_running_stim_file_fixture - n_time = len(stim_file['items']['behavior']['encoders'][0]['dx']) - time_array = np.linspace(0., 10., n_time) + n_time = len(stim_file["items"]["behavior"]["encoders"][0]["dx"]) + time_array = np.linspace(0.0, 10.0, n_time) expected_stim = copy.deepcopy(stim_file) - for key in ('dx', 'vsig', 'vin'): - raw = expected_stim['items']['behavior']['encoders'][0].pop(key) - expected_stim['items']['behavior']['encoders'][0][key] = raw + for key in ("dx", "vsig", "vin"): + raw = expected_stim["items"]["behavior"]["encoders"][0].pop(key) + expected_stim["items"]["behavior"]["encoders"][0][key] = raw - expected = get_running_df( - data=expected_stim, - time=time_array, - lowpass=use_lowpass, - zscore_threshold=zscore) + expected = get_running_df(data=expected_stim, time=time_array, lowpass=use_lowpass, zscore_threshold=zscore) - pkl_path = pathlib.Path( - tempfile.mkstemp(dir=tmpdir, suffix='.pkl')[1]) + pkl_path = pathlib.Path(tempfile.mkstemp(dir=tmpdir, suffix=".pkl")[1]) pd.to_pickle(expected_stim, pkl_path) @@ -64,34 +50,31 @@ def __init__(self, pkl_path): self.data = pd.read_pickle(pkl_path) actual = _extract_dx_info( - frame_times=time_array, - stimulus_file=DummyStimFile(pkl_path.resolve().absolute()), - zscore_threshold=zscore, - use_lowpass_filter=use_lowpass) + frame_times=time_array, + stimulus_file=DummyStimFile(pkl_path.resolve().absolute()), + zscore_threshold=zscore, + use_lowpass_filter=use_lowpass, + ) pd.testing.assert_frame_equal(actual, expected) helper_functions.windows_safe_cleanup(file_path=pkl_path) -def test_extract_dx_time_mismatch( - basic_running_stim_file_fixture, - tmp_path_factory, - helper_functions): +def test_extract_dx_time_mismatch(basic_running_stim_file_fixture, tmp_path_factory, helper_functions): """ Test that an exception gets thrown if frame_times is not of the correct length """ - tmpdir = tmp_path_factory.mktemp('extract_dx_test') + tmpdir = tmp_path_factory.mktemp("extract_dx_test") stim_file = basic_running_stim_file_fixture - n_time = len(stim_file['items']['behavior']['encoders'][0]['dx']) - time_array = np.linspace(0., 10., n_time-5) + n_time = len(stim_file["items"]["behavior"]["encoders"][0]["dx"]) + time_array = np.linspace(0.0, 10.0, n_time - 5) - pkl_path = pathlib.Path( - tempfile.mkstemp(dir=tmpdir, suffix='.pkl')[1]) + pkl_path = pathlib.Path(tempfile.mkstemp(dir=tmpdir, suffix=".pkl")[1]) pd.to_pickle(stim_file, pkl_path) @@ -101,88 +84,83 @@ def __init__(self, pkl_path): with pytest.raises(ValueError, match="length of v_in"): _extract_dx_info( - frame_times=time_array, - stimulus_file=DummyStimFile(pkl_path.resolve().absolute()), - zscore_threshold=10.0, - use_lowpass_filter=True) + frame_times=time_array, + stimulus_file=DummyStimFile(pkl_path.resolve().absolute()), + zscore_threshold=10.0, + use_lowpass_filter=True, + ) helper_functions.windows_safe_cleanup(file_path=pkl_path) -@pytest.mark.parametrize('start_frame', [0, 15]) -def test_merge_dx_data(merge_data_fixture, - start_frame): +@pytest.mark.parametrize("start_frame", [0, 15]) +def test_merge_dx_data(merge_data_fixture, start_frame): """ Test that _merge_dx_data correctly merges dataframes """ - frame_times = np.linspace(0., - 10., - merge_data_fixture['n_timesteps']) + frame_times = np.linspace(0.0, 10.0, merge_data_fixture["n_timesteps"]) - (velocity_df, - raw_df) = _merge_dx_data( - mapping_velocities=merge_data_fixture['mapping']['dataframe'], - behavior_velocities=merge_data_fixture['behavior']['dataframe'], - replay_velocities=merge_data_fixture['replay']['dataframe'], + (velocity_df, raw_df) = _merge_dx_data( + mapping_velocities=merge_data_fixture["mapping"]["dataframe"], + behavior_velocities=merge_data_fixture["behavior"]["dataframe"], + replay_velocities=merge_data_fixture["replay"]["dataframe"], frame_times=frame_times, - behavior_start_frame=start_frame) + behavior_start_frame=start_frame, + ) - assert set(velocity_df.columns) == set(['velocity', - 'net_rotation', - 'frame_indexes', - 'frame_time']) + assert set(velocity_df.columns) == set(["velocity", "net_rotation", "frame_indexes", "frame_time"]) - assert set(raw_df.columns) == set(['vsig', 'vin', 'frame_time', 'dx']) + assert set(raw_df.columns) == set(["vsig", "vin", "frame_time", "dx"]) # make sure that velocity_df skipped some timesteps assert len(velocity_df.velocity.values) < len(raw_df.dx.values) # check that velocity_df has the expected values, having # skipped the right timesteps - for in_key, out_key in zip(('dx', 'speed'), - ('net_rotation', 'velocity')): + for in_key, out_key in zip(("dx", "speed"), ("net_rotation", "velocity")): expected = [] - for pkl_key in ('behavior', 'mapping', 'replay'): - kept = merge_data_fixture[pkl_key]['kept_mask'] + for pkl_key in ("behavior", "mapping", "replay"): + kept = merge_data_fixture[pkl_key]["kept_mask"] expected.append(merge_data_fixture[pkl_key][in_key][kept]) expected = np.concatenate(expected) np.testing.assert_array_equal(velocity_df[out_key].values, expected) # check contents of frame_time and frame_indexes - global_kept = np.concatenate([merge_data_fixture['behavior']['kept_mask'], - merge_data_fixture['mapping']['kept_mask'], - merge_data_fixture['replay']['kept_mask']]) - np.testing.assert_array_equal( - frame_times[global_kept], - velocity_df.frame_time.values) + global_kept = np.concatenate( + [ + merge_data_fixture["behavior"]["kept_mask"], + merge_data_fixture["mapping"]["kept_mask"], + merge_data_fixture["replay"]["kept_mask"], + ] + ) + np.testing.assert_array_equal(frame_times[global_kept], velocity_df.frame_time.values) np.testing.assert_array_equal( - np.arange(start_frame, - merge_data_fixture['n_timesteps']+start_frame, - dtype=int)[global_kept], - velocity_df.frame_indexes.values) + np.arange(start_frame, merge_data_fixture["n_timesteps"] + start_frame, dtype=int)[global_kept], + velocity_df.frame_indexes.values, + ) # check contents of raw_df - for in_key, out_key in zip(('dx', 'v_in', 'v_sig'), - ('dx', 'vin', 'vsig')): + for in_key, out_key in zip(("dx", "v_in", "v_sig"), ("dx", "vin", "vsig")): expected = [] - for pkl_key in ('behavior', 'mapping', 'replay'): + for pkl_key in ("behavior", "mapping", "replay"): expected.append(merge_data_fixture[pkl_key][in_key]) expected = np.concatenate(expected) np.testing.assert_array_equal(expected, raw_df[out_key].values) - np.testing.assert_array_equal(raw_df.frame_time.values, - frame_times) + np.testing.assert_array_equal(raw_df.frame_time.values, frame_times) -@pytest.mark.parametrize("start_frame", [0, ]) +@pytest.mark.parametrize( + "start_frame", + [ + 0, + ], +) def test_multi_stim_running_df_from_raw_data( - start_frame, - behavior_pkl_fixture, - replay_pkl_fixture, - mapping_pkl_fixture, - sync_path_fixture): + start_frame, behavior_pkl_fixture, replay_pkl_fixture, mapping_pkl_fixture, sync_path_fixture +): """ test that multi_stim_running_df_from_raw_data can properly process stimulus pickle files @@ -191,24 +169,18 @@ def test_multi_stim_running_df_from_raw_data( use_lowpass = True zscore = 10.0 - b_stim = BehaviorStimulusFile.from_json( - dict_repr={'behavior_stimulus_file': - behavior_pkl_fixture['path_to_pkl']}) - - r_stim = ReplayStimulusFile.from_json( - dict_repr={'replay_stimulus_file': - replay_pkl_fixture['path_to_pkl']}) - - m_stim = MappingStimulusFile.from_json( - dict_repr={'mapping_stimulus_file': - mapping_pkl_fixture['path_to_pkl']}) - - (velocities_df, - raw_df) = multi_stim_running_df_from_raw_data( - sync_path=sync_path_fixture, - behavior_stimulus_file=b_stim, - mapping_stimulus_file=m_stim, - replay_stimulus_file=r_stim, - use_lowpass_filter=use_lowpass, - zscore_threshold=zscore, - behavior_start_frame=start_frame) + b_stim = BehaviorStimulusFile.from_json(dict_repr={"behavior_stimulus_file": behavior_pkl_fixture["path_to_pkl"]}) + + r_stim = ReplayStimulusFile.from_json(dict_repr={"replay_stimulus_file": replay_pkl_fixture["path_to_pkl"]}) + + m_stim = MappingStimulusFile.from_json(dict_repr={"mapping_stimulus_file": mapping_pkl_fixture["path_to_pkl"]}) + + (velocities_df, raw_df) = multi_stim_running_df_from_raw_data( + sync_path=sync_path_fixture, + behavior_stimulus_file=b_stim, + mapping_stimulus_file=m_stim, + replay_stimulus_file=r_stim, + use_lowpass_filter=use_lowpass, + zscore_threshold=zscore, + behavior_start_frame=start_frame, + ) diff --git a/allensdk/test/brain_observatory/behavior/data_objects/running_speed/test_running_acquisition.py b/allensdk/test/brain_observatory/behavior/data_objects/running_speed/test_running_acquisition.py index 01d6aa8852..4921af43c0 100644 --- a/allensdk/test/brain_observatory/behavior/data_objects/running_speed/test_running_acquisition.py +++ b/allensdk/test/brain_observatory/behavior/data_objects/running_speed/test_running_acquisition.py @@ -5,11 +5,9 @@ from allensdk.brain_observatory.behavior.data_files import BehaviorStimulusFile from allensdk.brain_observatory.behavior.data_objects.running_speed.running_processing import ( # noqa: E501 - get_running_df -) -from allensdk.brain_observatory.behavior.data_objects import ( - RunningAcquisition, StimulusTimestamps + get_running_df, ) +from allensdk.brain_observatory.behavior.data_objects import RunningAcquisition, StimulusTimestamps def test_nonzero_monitor_delay_acq(): @@ -17,17 +15,13 @@ def test_nonzero_monitor_delay_acq(): Test that RunningAcquisition throws an exception if instantiated with a timestamps object that has non-zero monitor_delay """ + class OtherTimestamps(object): monitor_delay = 0.01 value = 0.0 - with pytest.raises(RuntimeError, - match="should be no monitor delay"): - - RunningAcquisition( - running_acquisition=4.0, - stimulus_file=None, - stimulus_timestamps=OtherTimestamps()) + with pytest.raises(RuntimeError, match="should be no monitor delay"): + RunningAcquisition(running_acquisition=4.0, stimulus_file=None, stimulus_timestamps=OtherTimestamps()) @pytest.mark.parametrize( @@ -35,41 +29,26 @@ class OtherTimestamps(object): [ ( # dict_repr - { - "behavior_stimulus_file": "mock_stimulus_file.pkl", - "monitor_delay": 0.0 - }, + {"behavior_stimulus_file": "mock_stimulus_file.pkl", "monitor_delay": 0.0}, # returned_running_acq_df pd.DataFrame( - { - "timestamps": [1, 2], - "speed": [3, 4], - "dx": [5, 6], - "v_sig": [7, 8], - "v_in": [9, 10] - } + {"timestamps": [1, 2], "speed": [3, 4], "dx": [5, 6], "v_sig": [7, 8], "v_in": [9, 10]} ).set_index("timestamps"), # expected_running_acq_df - pd.DataFrame( - { - "timestamps": [1, 2], - "dx": [5, 6], - "v_sig": [7, 8], - "v_in": [9, 10] - } - ).set_index("timestamps") + pd.DataFrame({"timestamps": [1, 2], "dx": [5, 6], "v_sig": [7, 8], "v_in": [9, 10]}).set_index( + "timestamps" + ), ), - ] + ], ) -def test_running_acquisition_from_json( - monkeypatch, dict_repr, returned_running_acq_df, expected_running_acq_df -): +def test_running_acquisition_from_json(monkeypatch, dict_repr, returned_running_acq_df, expected_running_acq_df): mock_stimulus_file = create_autospec(BehaviorStimulusFile) mock_stimulus_timestamps = create_autospec(StimulusTimestamps) class DummyTimestamps(object): monitor_delay = 0.0 value = 0.0 + dummy_ts = DummyTimestamps() mock_stimulus_timestamps.from_stimulus_file.return_value = dummy_ts mock_stimulus_timestamps.from_json.return_value = dummy_ts @@ -80,30 +59,24 @@ class DummyTimestamps(object): with monkeypatch.context() as m: m.setattr( - "allensdk.brain_observatory.behavior.data_objects" - ".running_speed.running_acquisition.BehaviorStimulusFile", - mock_stimulus_file + "allensdk.brain_observatory.behavior.data_objects.running_speed.running_acquisition.BehaviorStimulusFile", + mock_stimulus_file, ) m.setattr( - "allensdk.brain_observatory.behavior.data_objects" - ".running_speed.running_acquisition.StimulusTimestamps", - mock_stimulus_timestamps + "allensdk.brain_observatory.behavior.data_objects.running_speed.running_acquisition.StimulusTimestamps", + mock_stimulus_timestamps, ) m.setattr( - "allensdk.brain_observatory.behavior.data_objects" - ".running_speed.running_acquisition.get_running_df", - mock_get_running_df + "allensdk.brain_observatory.behavior.data_objects.running_speed.running_acquisition.get_running_df", + mock_get_running_df, ) - obt = RunningAcquisition.from_stimulus_file( - behavior_stimulus_file=mock_stimulus_file) + obt = RunningAcquisition.from_stimulus_file(behavior_stimulus_file=mock_stimulus_file) mock_stimulus_file_instance = mock_stimulus_file.from_json(dict_repr) - mock_stimulus_timestamps_instance = \ - mock_stimulus_timestamps.from_stimulus_file( - stimulus_file=mock_stimulus_file_instance, - monitor_delay=0.0 - ) + mock_stimulus_timestamps_instance = mock_stimulus_timestamps.from_stimulus_file( + stimulus_file=mock_stimulus_file_instance, monitor_delay=0.0 + ) assert obt._stimulus_timestamps == mock_stimulus_timestamps_instance pd.testing.assert_frame_equal(obt.value, expected_running_acq_df) @@ -115,22 +88,18 @@ class DummyTimestamps(object): # data_object_roundtrip_fixture: # test/brain_observatory/behavior/data_objects/conftest.py @pytest.mark.parametrize("roundtrip", [True, False]) -@pytest.mark.parametrize("running_acq_data", [ - ( - # expected_running_acq_df - pd.DataFrame( - { - "timestamps": [2.0, 4.0], - "dx": [10.0, 12.0], - "v_sig": [14.0, 16.0], - "v_in": [18.0, 20.0] - } - ).set_index("timestamps") - ), -]) -def test_running_acquisition_nwb_roundtrip( - nwbfile, data_object_roundtrip_fixture, roundtrip, running_acq_data -): +@pytest.mark.parametrize( + "running_acq_data", + [ + ( + # expected_running_acq_df + pd.DataFrame( + {"timestamps": [2.0, 4.0], "dx": [10.0, 12.0], "v_sig": [14.0, 16.0], "v_in": [18.0, 20.0]} + ).set_index("timestamps") + ), + ], +) +def test_running_acquisition_nwb_roundtrip(nwbfile, data_object_roundtrip_fixture, roundtrip, running_acq_data): running_acq = RunningAcquisition(running_acquisition=running_acq_data) nwbfile = running_acq.to_nwb(nwbfile) @@ -139,6 +108,4 @@ def test_running_acquisition_nwb_roundtrip( else: obt = RunningAcquisition.from_nwb(nwbfile) - pd.testing.assert_frame_equal( - obt.value, running_acq_data, check_like=True - ) + pd.testing.assert_frame_equal(obt.value, running_acq_data, check_like=True) diff --git a/allensdk/test/brain_observatory/behavior/data_objects/running_speed/test_running_processing.py b/allensdk/test/brain_observatory/behavior/data_objects/running_speed/test_running_processing.py index b06b263d7b..021aed7dfe 100644 --- a/allensdk/test/brain_observatory/behavior/data_objects/running_speed/test_running_processing.py +++ b/allensdk/test/brain_observatory/behavior/data_objects/running_speed/test_running_processing.py @@ -2,16 +2,23 @@ import pytest from allensdk.brain_observatory.behavior.data_objects.running_speed.running_processing import ( # noqa: E501 - get_running_df, calc_deriv, deg_to_dist, _shift, _identify_wraps, - _unwrap_voltage_signal, _angular_change, _zscore_threshold_1d, - _clip_speed_wraps) + get_running_df, + calc_deriv, + deg_to_dist, + _shift, + _identify_wraps, + _unwrap_voltage_signal, + _angular_change, + _zscore_threshold_1d, + _clip_speed_wraps, +) import allensdk.brain_observatory.behavior.data_objects.running_speed.running_processing as rp # noqa: E501 @pytest.fixture def timestamps(): - return np.arange(0., 10., 0.1) + return np.arange(0.0, 10.0, 0.1) @pytest.fixture @@ -25,15 +32,20 @@ def running_data(): "dx": rng.random((100,)), "vsig": rng.uniform(low=0.0, high=5.1, size=(100,)), "vin": rng.uniform(low=4.9, high=5.0, size=(100,)), - }]}}} + } + ] + } + } + } @pytest.mark.parametrize( - "x,time,expected", [ + "x,time,expected", + [ ([1.0, 1.0], [1.0, 2.0], [np.nan, 0.0]), ([1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [np.nan, 1.0, 1.0]), - ([1.0, 2.0, 3.0], [1.0, 4.0, 6.0], [np.nan, 1.0/3.0, 1.0/2.0]) - ] + ([1.0, 2.0, 3.0], [1.0, 4.0, 6.0], [np.nan, 1.0 / 3.0, 1.0 / 2.0]), + ], ) def test_calc_deriv(x, time, expected): obtained = calc_deriv(x, time) @@ -45,98 +57,58 @@ def test_calc_deriv(x, time, expected): assert np.all(obtained == expected) -@pytest.mark.parametrize( - "speed,expected", [ - (np.array([1.0]), [5.5033]), - (np.array([0., 2.0]), [0., 11.0066]) - ] -) +@pytest.mark.parametrize("speed,expected", [(np.array([1.0]), [5.5033]), (np.array([0.0, 2.0]), [0.0, 11.0066])]) def test_deg_to_dist(speed, expected): np.testing.assert_allclose(deg_to_dist(speed), expected, atol=0.0001) -@pytest.mark.parametrize( - "lowpass", [True, False] -) +@pytest.mark.parametrize("lowpass", [True, False]) def test_get_running_df(running_data, timestamps, lowpass): actual = get_running_df(running_data, timestamps, lowpass=lowpass) np.testing.assert_array_equal(actual.index, timestamps) assert sorted(list(actual)) == ["dx", "speed", "v_in", "v_sig"] # Should bring raw data through - np.testing.assert_array_equal( - actual["v_sig"].values, - running_data["items"]["behavior"]["encoders"][0]["vsig"]) - np.testing.assert_array_equal( - actual["v_in"].values, - running_data["items"]["behavior"]["encoders"][0]["vin"]) - np.testing.assert_array_equal( - actual["dx"].values, - running_data["items"]["behavior"]["encoders"][0]["dx"]) + np.testing.assert_array_equal(actual["v_sig"].values, running_data["items"]["behavior"]["encoders"][0]["vsig"]) + np.testing.assert_array_equal(actual["v_in"].values, running_data["items"]["behavior"]["encoders"][0]["vin"]) + np.testing.assert_array_equal(actual["dx"].values, running_data["items"]["behavior"]["encoders"][0]["dx"]) if lowpass: assert np.count_nonzero(np.isnan(actual["speed"])) == 0 -@pytest.mark.parametrize( - "lowpass", [True, False] -) -def test_get_running_df_one_fewer_timestamp_check_warning(running_data, - timestamps, - lowpass): - with pytest.warns( - UserWarning, - match="Time array is 1 value shorter than encoder array.*" - ): +@pytest.mark.parametrize("lowpass", [True, False]) +def test_get_running_df_one_fewer_timestamp_check_warning(running_data, timestamps, lowpass): + with pytest.warns(UserWarning, match="Time array is 1 value shorter than encoder array.*"): # Call with one fewer timestamp, check for a warning - _ = get_running_df( - data=running_data, - time=timestamps[:-1], - lowpass=lowpass - ) + _ = get_running_df(data=running_data, time=timestamps[:-1], lowpass=lowpass) -@pytest.mark.parametrize( - "lowpass", [True, False] -) -def test_get_running_df_one_fewer_timestamp_check_truncation(running_data, - timestamps, - lowpass): +@pytest.mark.parametrize("lowpass", [True, False]) +def test_get_running_df_one_fewer_timestamp_check_truncation(running_data, timestamps, lowpass): # Call with one fewer timestamp - output = get_running_df( - data=running_data, - time=timestamps[:-1], - lowpass=lowpass - ) + output = get_running_df(data=running_data, time=timestamps[:-1], lowpass=lowpass) # Check that the output is actually trimmed, and the values are the same assert len(output) == len(timestamps) - 1 - np.testing.assert_equal( - output["v_sig"], - running_data["items"]["behavior"]["encoders"][0]["vsig"][:-1] - ) - np.testing.assert_equal( - output["v_in"], - running_data["items"]["behavior"]["encoders"][0]["vin"][:-1] - ) + np.testing.assert_equal(output["v_sig"], running_data["items"]["behavior"]["encoders"][0]["vsig"][:-1]) + np.testing.assert_equal(output["v_in"], running_data["items"]["behavior"]["encoders"][0]["vin"][:-1]) @pytest.mark.parametrize( "arr, periods, fill, expected", [ - ([1, 2, 3], 1, None, np.array([np.nan, 1., 2.])), - ([1, 2, 3], 2, 99., np.array([99., 99., 1.])), - ([1, 2, 3], 3, 99., np.array([99., 99., 99.])), - ([1, 2, 3], 4, 99., np.array([99., 99., 99.])), - ([], 2, 30, np.array([])) - ] + ([1, 2, 3], 1, None, np.array([np.nan, 1.0, 2.0])), + ([1, 2, 3], 2, 99.0, np.array([99.0, 99.0, 1.0])), + ([1, 2, 3], 3, 99.0, np.array([99.0, 99.0, 99.0])), + ([1, 2, 3], 4, 99.0, np.array([99.0, 99.0, 99.0])), + ([], 2, 30, np.array([])), + ], ) def test_shift(arr, periods, fill, expected): actual = _shift(arr, periods, fill) np.testing.assert_array_equal(actual, expected) -@pytest.mark.parametrize( - "periods", [0, -2] -) +@pytest.mark.parametrize("periods", [0, -2]) def test_shift_raises_error_periods_zero(periods): with pytest.raises(ValueError, match="Can only shift"): _shift(np.ones((5,)), periods) @@ -145,74 +117,96 @@ def test_shift_raises_error_periods_zero(periods): @pytest.mark.parametrize( "arr, min_threshold, max_threshold, expected", [ - (np.array( - [0, 2, 5, 0, # pos wrap 5-0 - 2, 0, 5 # neg wrap 0-5 - ]), 1.5, 3.5, (np.array([3]), np.array([6]))), + ( + np.array( + [ + 0, + 2, + 5, + 0, # pos wrap 5-0 + 2, + 0, + 5, # neg wrap 0-5 + ] + ), + 1.5, + 3.5, + (np.array([3]), np.array([6])), + ), (np.array([0, 2, 5, 0, 2, 5]), 0, 5, (np.array([]), np.array([]))), - ] + ], ) def test_identify_wraps(arr, min_threshold, max_threshold, expected): - actual = _identify_wraps( - arr, min_threshold=min_threshold, max_threshold=max_threshold) + actual = _identify_wraps(arr, min_threshold=min_threshold, max_threshold=max_threshold) np.testing.assert_array_equal( - actual[0], expected[0], - f"error identifying positive wraps, got {actual[0]}, " - f"expected {expected[0]}") + actual[0], expected[0], f"error identifying positive wraps, got {actual[0]}, expected {expected[0]}" + ) np.testing.assert_array_equal( - actual[1], expected[1], - f"error identifying negative wraps, got {actual[1]}, " - f"expected {expected[1]}") + actual[1], expected[1], f"error identifying negative wraps, got {actual[1]}, expected {expected[1]}" + ) @pytest.mark.parametrize( "vsig, pos_wrap_ix, neg_wrap_ix, vmax, max_threshold, max_diff, expected", [ - ( # No artifacts or baseline + ( # No artifacts or baseline np.array([0, 1, 3, 5, 0.5, 1, 2.5, 5, 0, 1, 4, 3]), - np.array([4, 8]), np.array([10]), 5.0, 5.1, 3.0, - np.array([np.nan, 1, 3, 5, 5.5, 6, 7.5, 10, 10, 11, 9, 8]) + np.array([4, 8]), + np.array([10]), + 5.0, + 5.1, + 3.0, + np.array([np.nan, 1, 3, 5, 5.5, 6, 7.5, 10, 10, 11, 9, 8]), ), - ( # Some diff artifacts, baseline + ( # Some diff artifacts, baseline np.array([1, 1, 3, 5, 0.5, 1, 2.5, 5, 0, 1, 4, 1.5]), - np.array([4, 8]), np.array([10]), 5.0, 5.1, 2.0, - np.array([np.nan, 1, 3, 5, 5.5, 6, 7.5, np.nan, 7.5, 8.5, 6.5, - np.nan]) + np.array([4, 8]), + np.array([10]), + 5.0, + 5.1, + 2.0, + np.array([np.nan, 1, 3, 5, 5.5, 6, 7.5, np.nan, 7.5, 8.5, 6.5, np.nan]), ), - ( # Max artifact -- use threshold instead + ( # Max artifact -- use threshold instead np.array([0, 7, 3, 5, 0.5, 1]), - np.array([2, 4]), np.array([]).astype(int), None, 5.1, 6.0, - np.array([np.nan, np.nan, 1, 3, 3.5, 4]) + np.array([2, 4]), + np.array([]).astype(int), + None, + 5.1, + 6.0, + np.array([np.nan, np.nan, 1, 3, 3.5, 4]), ), ( # No wraps - np.ones(5,), np.array([]), np.array([]), 5.0, 5.1, 3.0, - np.array([np.nan, 1., 1., 1., 1.]) - ) - ] + np.ones( + 5, + ), + np.array([]), + np.array([]), + 5.0, + 5.1, + 3.0, + np.array([np.nan, 1.0, 1.0, 1.0, 1.0]), + ), + ], ) -def test_unwrap_voltage_signal( - vsig, pos_wrap_ix, neg_wrap_ix, vmax, max_threshold, max_diff, - expected): +def test_unwrap_voltage_signal(vsig, pos_wrap_ix, neg_wrap_ix, vmax, max_threshold, max_diff, expected): actual = _unwrap_voltage_signal( - vsig, pos_wrap_ix, neg_wrap_ix, vmax=vmax, - max_threshold=max_threshold, max_diff=max_diff) + vsig, pos_wrap_ix, neg_wrap_ix, vmax=vmax, max_threshold=max_threshold, max_diff=max_diff + ) np.testing.assert_array_equal(actual, expected) @pytest.mark.parametrize( "vsig, vmax, expected", [ - ( - np.array([1, 2, 3, 4, 5]), 2.0, - np.array([np.nan, np.pi, np.pi, np.pi, np.pi]) - ), + (np.array([1, 2, 3, 4, 5]), 2.0, np.array([np.nan, np.pi, np.pi, np.pi, np.pi])), ( np.array([np.nan, 1, 3, np.nan, 4]), np.array([2.0, 2.0, 2.0, 2.0, 2.0]), - np.array([np.nan, np.nan, 2*np.pi, np.nan, np.nan]), - ) - ] + np.array([np.nan, np.nan, 2 * np.pi, np.nan, np.nan]), + ), + ], ) def test_angular_change(vsig, vmax, expected): actual = _angular_change(vsig, vmax) @@ -222,11 +216,26 @@ def test_angular_change(vsig, vmax, expected): @pytest.mark.parametrize( "arr, threshold, expected", [ - (np.ones(5,), 2.0, np.ones(5,)), - (np.array([99, 1, np.nan, 1, 1, 1]), 1.5, - np.array([np.nan, 1, np.nan, 1, 1, 1])), - (np.ones(1,), 2.0, np.ones(1,)) - ] + ( + np.ones( + 5, + ), + 2.0, + np.ones( + 5, + ), + ), + (np.array([99, 1, np.nan, 1, 1, 1]), 1.5, np.array([np.nan, 1, np.nan, 1, 1, 1])), + ( + np.ones( + 1, + ), + 2.0, + np.ones( + 1, + ), + ), + ], ) def test_zscore_threshold_1d(arr, threshold, expected): actual = _zscore_threshold_1d(arr, threshold=threshold) @@ -236,18 +245,17 @@ def test_zscore_threshold_1d(arr, threshold, expected): @pytest.mark.parametrize( "speed, time, wrap_indices, span, expected", [ - ( # Clip bottom, then clip top, then no clip required + ( # Clip bottom, then clip top, then no clip required np.array([0, 0, -1, 5, 0, 99, 6, 1, 2, 3]), np.array(range(10)).astype(float), [2, 5, 8], 1.0, - np.array([0, 0, 0, 5, 0, 6, 6, 1, 2, 3]) + np.array([0, 0, 0, 5, 0, 6, 6, 1, 2, 3]), ), - ] + ], ) -def test_clip_speed_wraps( - speed, time, wrap_indices, span, expected, monkeypatch): - monkeypatch.setattr(rp, "_local_boundaries", lambda x, y, z: (y-1, y+1)) +def test_clip_speed_wraps(speed, time, wrap_indices, span, expected, monkeypatch): + monkeypatch.setattr(rp, "_local_boundaries", lambda x, y, z: (y - 1, y + 1)) actual = _clip_speed_wraps(speed, time, wrap_indices, span) np.testing.assert_array_equal(actual, expected) @@ -255,10 +263,10 @@ def test_clip_speed_wraps( @pytest.mark.parametrize( "time, index, span", [ - (np.arange(10.), 0, 2.0), # no neighborhood before first point - (np.arange(10.), 9, 2.0), # no neighborhood after last point - (np.arange(10.), 4, 0.25), # data not sampled with enough frequency - ] + (np.arange(10.0), 0, 2.0), # no neighborhood before first point + (np.arange(10.0), 9, 2.0), # no neighborhood after last point + (np.arange(10.0), 4, 0.25), # data not sampled with enough frequency + ], ) def test_local_boundaries_raises_warning(time, index, span): with pytest.warns(UserWarning, match="Unable to find"): @@ -268,10 +276,10 @@ def test_local_boundaries_raises_warning(time, index, span): @pytest.mark.parametrize( "time, index", [ - (np.array([5., 4., 3., 4., 5.]), 2), - (np.array([1., 2., 3., 2., 1.]), 2), - (np.array([3., 3., 3., 2., 1.]), 2), - ] + (np.array([5.0, 4.0, 3.0, 4.0, 5.0]), 2), + (np.array([1.0, 2.0, 3.0, 2.0, 1.0]), 2), + (np.array([3.0, 3.0, 3.0, 2.0, 1.0]), 2), + ], ) def test_local_boundaries_raises_error_non_monotonic(time, index): with pytest.raises(ValueError, match="Data do not monotonically"): @@ -281,9 +289,9 @@ def test_local_boundaries_raises_error_non_monotonic(time, index): @pytest.mark.parametrize( "time, index, span, expected", [ - (np.arange(10.), 4, 2.0, (2.0, 6.0)), # Spans > 1 element +/- - (np.arange(10.), 2, 1.0, (1.0, 3.0)) # Spans = 1 element +/- - ] + (np.arange(10.0), 4, 2.0, (2.0, 6.0)), # Spans > 1 element +/- + (np.arange(10.0), 2, 1.0, (1.0, 3.0)), # Spans = 1 element +/- + ], ) def test_local_boundaries(time, index, span, expected): actual = rp._local_boundaries(time, index, span) diff --git a/allensdk/test/brain_observatory/behavior/data_objects/running_speed/test_running_speed.py b/allensdk/test/brain_observatory/behavior/data_objects/running_speed/test_running_speed.py index 8e52241b6f..c6056d3e88 100644 --- a/allensdk/test/brain_observatory/behavior/data_objects/running_speed/test_running_speed.py +++ b/allensdk/test/brain_observatory/behavior/data_objects/running_speed/test_running_speed.py @@ -6,11 +6,9 @@ from allensdk.core.exceptions import DataFrameIndexError from allensdk.brain_observatory.behavior.data_files import BehaviorStimulusFile from allensdk.brain_observatory.behavior.data_objects.running_speed.running_processing import ( # noqa: E501 - get_running_df -) -from allensdk.brain_observatory.behavior.data_objects import ( - RunningSpeed, StimulusTimestamps + get_running_df, ) +from allensdk.brain_observatory.behavior.data_objects import RunningSpeed, StimulusTimestamps class DummyTimestamps(object): @@ -18,6 +16,7 @@ class DummyTimestamps(object): A class meant to mock the StimulusTimestamps API by providing monitor_delay=0.0, and value=0.0 """ + monitor_delay = 0.0 value = 0.0 @@ -27,84 +26,67 @@ def test_nonzero_monitor_delay_speed(): Test that RunningSpeed throws an exception if instantiated with a timestamps object that has non-zero monitor_delay """ + class OtherTimestamps(object): monitor_delay = 0.01 value = 0.0 - with pytest.raises(RuntimeError, - match="should be no monitor delay"): - + with pytest.raises(RuntimeError, match="should be no monitor delay"): RunningSpeed( - running_speed=pd.DataFrame({'speed': [4.0]}), + running_speed=pd.DataFrame({"speed": [4.0]}), stimulus_file=None, sync_file=None, - stimulus_timestamps=OtherTimestamps()) + stimulus_timestamps=OtherTimestamps(), + ) @pytest.mark.parametrize("filtered", [True, False]) @pytest.mark.parametrize("zscore_threshold", [1.0, 4.2]) -@pytest.mark.parametrize("returned_running_df, expected_running_df, raises", [ - # Test basic case - ( - # returned_running_df - pd.DataFrame({ - "timestamps": [2, 4, 6, 8], - "speed": [1, 2, 3, 4] - }).set_index("timestamps"), - # expected_running_df - pd.DataFrame({ - "timestamps": [2, 4, 6, 8], - "speed": [1, 2, 3, 4] - }), - # raises - False - ), - # Test when returned dataframe lacks "timestamps" as index - ( - # returned_running_df - pd.DataFrame({ - "timestamps": [2, 4, 6, 8], - "speed": [1, 2, 3, 4] - }).set_index("speed"), - # expected_running_df - None, - # raises - "Expected running_data_df index to be named 'timestamps'" - ), -]) +@pytest.mark.parametrize( + "returned_running_df, expected_running_df, raises", + [ + # Test basic case + ( + # returned_running_df + pd.DataFrame({"timestamps": [2, 4, 6, 8], "speed": [1, 2, 3, 4]}).set_index("timestamps"), + # expected_running_df + pd.DataFrame({"timestamps": [2, 4, 6, 8], "speed": [1, 2, 3, 4]}), + # raises + False, + ), + # Test when returned dataframe lacks "timestamps" as index + ( + # returned_running_df + pd.DataFrame({"timestamps": [2, 4, 6, 8], "speed": [1, 2, 3, 4]}).set_index("speed"), + # expected_running_df + None, + # raises + "Expected running_data_df index to be named 'timestamps'", + ), + ], +) def test_get_running_speed_df( - monkeypatch, returned_running_df, filtered, zscore_threshold, - expected_running_df, raises + monkeypatch, returned_running_df, filtered, zscore_threshold, expected_running_df, raises ): - - mock_stimulus_file_instance = create_autospec( - BehaviorStimulusFile, - instance=True) - mock_stimulus_timestamps_instance = create_autospec( - StimulusTimestamps, instance=True - ) + mock_stimulus_file_instance = create_autospec(BehaviorStimulusFile, instance=True) + mock_stimulus_timestamps_instance = create_autospec(StimulusTimestamps, instance=True) mock_get_running_speed_df = create_autospec(get_running_df) mock_get_running_speed_df.return_value = returned_running_df with monkeypatch.context() as m: m.setattr( - "allensdk.brain_observatory.behavior.data_objects" - ".running_speed.running_speed.get_running_df", - mock_get_running_speed_df + "allensdk.brain_observatory.behavior.data_objects.running_speed.running_speed.get_running_df", + mock_get_running_speed_df, ) if raises: with pytest.raises(DataFrameIndexError, match=raises): _ = RunningSpeed._get_running_speed_df( - mock_stimulus_file_instance, - mock_stimulus_timestamps_instance, - filtered, zscore_threshold + mock_stimulus_file_instance, mock_stimulus_timestamps_instance, filtered, zscore_threshold ) else: obt = RunningSpeed._get_running_speed_df( - mock_stimulus_file_instance, - mock_stimulus_timestamps_instance, - filtered, zscore_threshold + mock_stimulus_file_instance, mock_stimulus_timestamps_instance, filtered, zscore_threshold ) pd.testing.assert_frame_equal(obt, expected_running_df) @@ -113,7 +95,7 @@ def test_get_running_speed_df( data=mock_stimulus_file_instance.data, time=mock_stimulus_timestamps_instance.value, lowpass=filtered, - zscore_threshold=zscore_threshold + zscore_threshold=zscore_threshold, ) @@ -124,23 +106,16 @@ def test_get_running_speed_df( [ ( # dict_repr - { - "behavior_stimulus_file": "mock_stimulus_file.pkl" - }, + {"behavior_stimulus_file": "mock_stimulus_file.pkl"}, # returned_running_df - pd.DataFrame( - {"timestamps": [1, 2], "speed": [3, 4]} - ).set_index("timestamps"), + pd.DataFrame({"timestamps": [1, 2], "speed": [3, 4]}).set_index("timestamps"), # expected_running_df - pd.DataFrame( - {"timestamps": [1, 2], "speed": [3, 4]} - ), + pd.DataFrame({"timestamps": [1, 2], "speed": [3, 4]}), ), - ] + ], ) def test_running_speed_from_json( - monkeypatch, dict_repr, returned_running_df, expected_running_df, - filtered, zscore_threshold + monkeypatch, dict_repr, returned_running_df, expected_running_df, filtered, zscore_threshold ): mock_stimulus_file = create_autospec(BehaviorStimulusFile) mock_stimulus_timestamps = create_autospec(StimulusTimestamps) @@ -155,40 +130,34 @@ def test_running_speed_from_json( with monkeypatch.context() as m: m.setattr( - "allensdk.brain_observatory.behavior.data_objects" - ".running_speed.running_speed.BehaviorStimulusFile", - mock_stimulus_file + "allensdk.brain_observatory.behavior.data_objects.running_speed.running_speed.BehaviorStimulusFile", + mock_stimulus_file, ) m.setattr( - "allensdk.brain_observatory.behavior.data_objects" - ".running_speed.running_speed.StimulusTimestamps", - mock_stimulus_timestamps + "allensdk.brain_observatory.behavior.data_objects.running_speed.running_speed.StimulusTimestamps", + mock_stimulus_timestamps, ) m.setattr( - "allensdk.brain_observatory.behavior.data_objects" - ".running_speed.running_speed.get_running_df", - mock_get_running_speed_df + "allensdk.brain_observatory.behavior.data_objects.running_speed.running_speed.get_running_df", + mock_get_running_speed_df, ) obt = RunningSpeed.from_stimulus_file( - behavior_stimulus_file=mock_stimulus_file, - filtered=filtered, - zscore_threshold=zscore_threshold) + behavior_stimulus_file=mock_stimulus_file, filtered=filtered, zscore_threshold=zscore_threshold + ) mock_stimulus_file_instance = mock_stimulus_file.from_json(dict_repr) - mock_stimulus_timestamps_instance = \ - mock_stimulus_timestamps.from_stimulus_file( - stimulus_file=mock_stimulus_file_instance, - monitor_delay=0.0 - ) + mock_stimulus_timestamps_instance = mock_stimulus_timestamps.from_stimulus_file( + stimulus_file=mock_stimulus_file_instance, monitor_delay=0.0 + ) assert obt._stimulus_timestamps == mock_stimulus_timestamps_instance mock_get_running_speed_df.assert_called_once_with( data=mock_stimulus_file.data, time=mock_stimulus_timestamps_instance.value, lowpass=filtered, - zscore_threshold=zscore_threshold + zscore_threshold=zscore_threshold, ) assert obt._filtered == filtered @@ -202,22 +171,18 @@ def test_running_speed_from_json( # test/brain_observatory/behavior/data_objects/conftest.py @pytest.mark.parametrize("roundtrip", [True, False]) @pytest.mark.parametrize("filtered", [True, False]) -@pytest.mark.parametrize("running_speed_data", [ - (pd.DataFrame({"timestamps": [3.0, 4.0], "speed": [5.0, 6.0]})), -]) -def test_running_speed_nwb_roundtrip( - nwbfile, data_object_roundtrip_fixture, roundtrip, running_speed_data, - filtered -): - running_speed = RunningSpeed( - running_speed=running_speed_data, filtered=filtered - ) +@pytest.mark.parametrize( + "running_speed_data", + [ + (pd.DataFrame({"timestamps": [3.0, 4.0], "speed": [5.0, 6.0]})), + ], +) +def test_running_speed_nwb_roundtrip(nwbfile, data_object_roundtrip_fixture, roundtrip, running_speed_data, filtered): + running_speed = RunningSpeed(running_speed=running_speed_data, filtered=filtered) nwbfile = running_speed.to_nwb(nwbfile) if roundtrip: - obt = data_object_roundtrip_fixture( - nwbfile, RunningSpeed, filtered=filtered - ) + obt = data_object_roundtrip_fixture(nwbfile, RunningSpeed, filtered=filtered) else: obt = RunningSpeed.from_nwb(nwbfile, filtered=filtered) diff --git a/allensdk/test/brain_observatory/behavior/data_objects/running_speed/test_running_speed_from_multi_stim.py b/allensdk/test/brain_observatory/behavior/data_objects/running_speed/test_running_speed_from_multi_stim.py index c942f844a8..1db2580568 100644 --- a/allensdk/test/brain_observatory/behavior/data_objects/running_speed/test_running_speed_from_multi_stim.py +++ b/allensdk/test/brain_observatory/behavior/data_objects/running_speed/test_running_speed_from_multi_stim.py @@ -6,97 +6,82 @@ import pytz import datetime -from allensdk.brain_observatory.behavior.data_objects.\ - running_speed.running_speed import ( - RunningSpeed) +from allensdk.brain_observatory.behavior.data_objects.running_speed.running_speed import RunningSpeed -from allensdk.brain_observatory.behavior.data_objects.\ - running_speed.running_acquisition import ( - RunningAcquisition) +from allensdk.brain_observatory.behavior.data_objects.running_speed.running_acquisition import RunningAcquisition -@pytest.mark.parametrize('filtered', [True, False]) +@pytest.mark.parametrize("filtered", [True, False]) def test_vbn_running_speed_round_trip( - behavior_stim_file_fixture, - replay_stim_file_fixture, - mapping_stim_file_fixture, - sync_file_fixture, - tmp_path_factory, - helper_functions, - filtered): + behavior_stim_file_fixture, + replay_stim_file_fixture, + mapping_stim_file_fixture, + sync_file_fixture, + tmp_path_factory, + helper_functions, + filtered, +): """ Test that we can round trip VBNRunningSpeed between from_json, to_json, from_nwb, and to_nwb """ - tmpdir = tmp_path_factory.mktemp('vbn_running_roundtrip') - nwb_path = pathlib.Path( - tempfile.mkstemp( - dir=tmpdir, - suffix='.nwb')[1]) + tmpdir = tmp_path_factory.mktemp("vbn_running_roundtrip") + nwb_path = pathlib.Path(tempfile.mkstemp(dir=tmpdir, suffix=".nwb")[1]) running_obj = RunningSpeed.from_multiple_stimulus_files( - behavior_stimulus_file=behavior_stim_file_fixture, - mapping_stimulus_file=mapping_stim_file_fixture, - replay_stimulus_file=replay_stim_file_fixture, - sync_file=sync_file_fixture, - filtered=filtered, - zscore_threshold=10.0) + behavior_stimulus_file=behavior_stim_file_fixture, + mapping_stimulus_file=mapping_stim_file_fixture, + replay_stimulus_file=replay_stim_file_fixture, + sync_file=sync_file_fixture, + filtered=filtered, + zscore_threshold=10.0, + ) start_time = pytz.utc.localize(datetime.datetime(2020, 7, 11)) - nwbfile = NWBFile( - session_description='running', - identifier='00001', - session_start_time=start_time) + nwbfile = NWBFile(session_description="running", identifier="00001", session_start_time=start_time) running_obj.to_nwb(nwbfile=nwbfile) - with NWBHDF5IO(nwb_path, 'w') as out_file: + with NWBHDF5IO(nwb_path, "w") as out_file: out_file.write(nwbfile) - with NWBHDF5IO(nwb_path, 'r') as in_file: - new_obj = RunningSpeed.from_nwb( - nwbfile=in_file.read(), - filtered=filtered) + with NWBHDF5IO(nwb_path, "r") as in_file: + new_obj = RunningSpeed.from_nwb(nwbfile=in_file.read(), filtered=filtered) pd.testing.assert_frame_equal(new_obj.value, running_obj.value) helper_functions.windows_safe_cleanup(file_path=nwb_path) def test_vbn_running_acq_round_trip( - behavior_stim_file_fixture, - replay_stim_file_fixture, - mapping_stim_file_fixture, - sync_file_fixture, - helper_functions, - tmp_path_factory): + behavior_stim_file_fixture, + replay_stim_file_fixture, + mapping_stim_file_fixture, + sync_file_fixture, + helper_functions, + tmp_path_factory, +): """ Test that we can round trip VBNRunningAcquistion between from_json, to_json, from_nwb, and to_nwb """ - tmpdir = tmp_path_factory.mktemp('vbn_running_acq_roundtrip') - nwb_path = pathlib.Path( - tempfile.mkstemp( - dir=tmpdir, - suffix='.nwb')[1]) + tmpdir = tmp_path_factory.mktemp("vbn_running_acq_roundtrip") + nwb_path = pathlib.Path(tempfile.mkstemp(dir=tmpdir, suffix=".nwb")[1]) running_obj = RunningAcquisition.from_multiple_stimulus_files( - behavior_stimulus_file=behavior_stim_file_fixture, - mapping_stimulus_file=mapping_stim_file_fixture, - replay_stimulus_file=replay_stim_file_fixture, - sync_file=sync_file_fixture) + behavior_stimulus_file=behavior_stim_file_fixture, + mapping_stimulus_file=mapping_stim_file_fixture, + replay_stimulus_file=replay_stim_file_fixture, + sync_file=sync_file_fixture, + ) start_time = pytz.utc.localize(datetime.datetime(202, 7, 11)) - nwbfile = NWBFile( - session_description='running', - identifier='00001', - session_start_time=start_time) + nwbfile = NWBFile(session_description="running", identifier="00001", session_start_time=start_time) running_obj.to_nwb(nwbfile=nwbfile) - with NWBHDF5IO(nwb_path, 'w') as out_file: + with NWBHDF5IO(nwb_path, "w") as out_file: out_file.write(nwbfile) - with NWBHDF5IO(nwb_path, 'r') as in_file: - new_obj = RunningAcquisition.from_nwb( - nwbfile=in_file.read()) + with NWBHDF5IO(nwb_path, "r") as in_file: + new_obj = RunningAcquisition.from_nwb(nwbfile=in_file.read()) pd.testing.assert_frame_equal(new_obj.value, running_obj.value) diff --git a/allensdk/test/brain_observatory/behavior/data_objects/stimulus_timestamps/test_stimulus_timestamps.py b/allensdk/test/brain_observatory/behavior/data_objects/stimulus_timestamps/test_stimulus_timestamps.py index 414b15f439..a0badcf135 100644 --- a/allensdk/test/brain_observatory/behavior/data_objects/stimulus_timestamps/test_stimulus_timestamps.py +++ b/allensdk/test/brain_observatory/behavior/data_objects/stimulus_timestamps/test_stimulus_timestamps.py @@ -9,115 +9,101 @@ from allensdk.internal.api import PostgresQueryMixin from allensdk.brain_observatory.behavior.data_files import ( - BehaviorStimulusFile, SyncFile, MappingStimulusFile, ReplayStimulusFile + BehaviorStimulusFile, + SyncFile, + MappingStimulusFile, + ReplayStimulusFile, +) +from allensdk.brain_observatory.behavior.data_objects.timestamps.stimulus_timestamps.timestamps_processing import ( + get_behavior_stimulus_timestamps, + get_ophys_stimulus_timestamps, ) -from allensdk.brain_observatory.behavior.data_objects.timestamps\ - .stimulus_timestamps.timestamps_processing import ( - get_behavior_stimulus_timestamps, get_ophys_stimulus_timestamps) from allensdk.brain_observatory.behavior.data_objects import StimulusTimestamps -@pytest.mark.parametrize("dict_repr, has_pkl, has_sync", [ - # Test where input json only has "behavior_stimulus_file" - ( - # dict_repr - { - "behavior_stimulus_file": "mock_stimulus_file.pkl" - }, - # has_pkl - True, - # has_sync - False - ), - # Test where input json has both "behavior_stimulus_file" and "sync_file" - ( - # dict_repr - { - "behavior_stimulus_file": "mock_stimulus_file.pkl", - "sync_file": "mock_sync_file.h5" - }, - # has_pkl - True, - # has_sync - True - ), -]) -def test_stimulus_timestamps_from_json( - monkeypatch, dict_repr, has_pkl, has_sync -): +@pytest.mark.parametrize( + "dict_repr, has_pkl, has_sync", + [ + # Test where input json only has "behavior_stimulus_file" + ( + # dict_repr + {"behavior_stimulus_file": "mock_stimulus_file.pkl"}, + # has_pkl + True, + # has_sync + False, + ), + # Test where input json has both "behavior_stimulus_file" and "sync_file" + ( + # dict_repr + {"behavior_stimulus_file": "mock_stimulus_file.pkl", "sync_file": "mock_sync_file.h5"}, + # has_pkl + True, + # has_sync + True, + ), + ], +) +def test_stimulus_timestamps_from_json(monkeypatch, dict_repr, has_pkl, has_sync): mock_stimulus_file = create_autospec(BehaviorStimulusFile) mock_sync_file = create_autospec(SyncFile) - mock_get_behavior_stimulus_timestamps = create_autospec( - get_behavior_stimulus_timestamps - ) - mock_get_ophys_stimulus_timestamps = create_autospec( - get_ophys_stimulus_timestamps - ) + mock_get_behavior_stimulus_timestamps = create_autospec(get_behavior_stimulus_timestamps) + mock_get_ophys_stimulus_timestamps = create_autospec(get_ophys_stimulus_timestamps) with monkeypatch.context() as m: m.setattr( "allensdk.brain_observatory.behavior.data_objects" ".timestamps.stimulus_timestamps" ".stimulus_timestamps.BehaviorStimulusFile", - mock_stimulus_file + mock_stimulus_file, ) m.setattr( "allensdk.brain_observatory.behavior.data_objects" ".timestamps.stimulus_timestamps.stimulus_timestamps.SyncFile", - mock_sync_file + mock_sync_file, ) m.setattr( "allensdk.brain_observatory.behavior.data_objects" ".timestamps.stimulus_timestamps.stimulus_timestamps" ".get_behavior_stimulus_timestamps", - mock_get_behavior_stimulus_timestamps + mock_get_behavior_stimulus_timestamps, ) m.setattr( "allensdk.brain_observatory.behavior.data_objects" ".timestamps.stimulus_timestamps.stimulus_timestamps" ".get_ophys_stimulus_timestamps", - mock_get_ophys_stimulus_timestamps + mock_get_ophys_stimulus_timestamps, ) mock_stimulus_file_instance = mock_stimulus_file.from_json(dict_repr) ts_from_stim = StimulusTimestamps.from_stimulus_file( - stimulus_file=mock_stimulus_file_instance, - monitor_delay=0.0) + stimulus_file=mock_stimulus_file_instance, monitor_delay=0.0 + ) if has_pkl and has_sync: mock_sync_file_instance = mock_sync_file.from_json(dict_repr) - ts_from_sync = StimulusTimestamps.from_sync_file( - sync_file=mock_sync_file_instance, - monitor_delay=0.0) + ts_from_sync = StimulusTimestamps.from_sync_file(sync_file=mock_sync_file_instance, monitor_delay=0.0) if has_pkl and has_sync: - mock_get_ophys_stimulus_timestamps.assert_called_once_with( - sync_path=mock_sync_file_instance.filepath - ) + mock_get_ophys_stimulus_timestamps.assert_called_once_with(sync_path=mock_sync_file_instance.filepath) assert ts_from_sync._sync_file == mock_sync_file_instance else: assert ts_from_stim._stimulus_file == mock_stimulus_file_instance - mock_get_behavior_stimulus_timestamps.assert_called_once_with( - stimulus_pkl=mock_stimulus_file_instance.data - ) + mock_get_behavior_stimulus_timestamps.assert_called_once_with(stimulus_pkl=mock_stimulus_file_instance.data) @pytest.fixture def stimulus_file_fixture(): dir = Path(__file__).parent.parent.resolve() - test_data_dir = dir / 'test_data' - sf_path = test_data_dir / 'stimulus_file.pkl' + test_data_dir = dir / "test_data" + sf_path = test_data_dir / "stimulus_file.pkl" - return BehaviorStimulusFile.from_json( - dict_repr={'behavior_stimulus_file': str(sf_path)}) + return BehaviorStimulusFile.from_json(dict_repr={"behavior_stimulus_file": str(sf_path)}) def test_stimulus_timestamps_from_json2(stimulus_file_fixture): - sf = stimulus_file_fixture - stimulus_timestamps = StimulusTimestamps.from_stimulus_file( - stimulus_file=sf, - monitor_delay=0.0) + stimulus_timestamps = StimulusTimestamps.from_stimulus_file(stimulus_file=sf, monitor_delay=0.0) expected = np.array([0.016 * i for i in range(11)]) assert np.allclose(expected, stimulus_timestamps.value) @@ -131,95 +117,64 @@ def test_stimulus_timestamps_from_json3(stimulus_file_fixture): """ sf = stimulus_file_fixture - sf._data['items']['behavior']['intervalsms'] = [0.1, 0.2, 0.3, 0.4] - - stimulus_timestamps = StimulusTimestamps.from_stimulus_file( - stimulus_file=sf, - monitor_delay=0.0) - - expected = np.array([0., 0.0001, 0.0003, 0.0006, 0.001]) - np.testing.assert_array_almost_equal(stimulus_timestamps.value, - expected, - decimal=10) - - -@pytest.mark.parametrize("behavior_session_id, ophys_experiment_id", [ - ( - 12345, - None - ), - ( - 1234, - 5678 - ) -]) -def test_stimulus_timestamps_from_lims( - monkeypatch, behavior_session_id, ophys_experiment_id -): + sf._data["items"]["behavior"]["intervalsms"] = [0.1, 0.2, 0.3, 0.4] + + stimulus_timestamps = StimulusTimestamps.from_stimulus_file(stimulus_file=sf, monitor_delay=0.0) + + expected = np.array([0.0, 0.0001, 0.0003, 0.0006, 0.001]) + np.testing.assert_array_almost_equal(stimulus_timestamps.value, expected, decimal=10) + + +@pytest.mark.parametrize("behavior_session_id, ophys_experiment_id", [(12345, None), (1234, 5678)]) +def test_stimulus_timestamps_from_lims(monkeypatch, behavior_session_id, ophys_experiment_id): mock_db_conn = create_autospec(PostgresQueryMixin, instance=True) mock_stimulus_file = create_autospec(BehaviorStimulusFile) mock_sync_file = create_autospec(SyncFile) - mock_get_behavior_stimulus_timestamps = create_autospec( - get_behavior_stimulus_timestamps - ) - mock_get_ophys_stimulus_timestamps = create_autospec( - get_ophys_stimulus_timestamps - ) + mock_get_behavior_stimulus_timestamps = create_autospec(get_behavior_stimulus_timestamps) + mock_get_ophys_stimulus_timestamps = create_autospec(get_ophys_stimulus_timestamps) with monkeypatch.context() as m: m.setattr( "allensdk.brain_observatory.behavior.data_objects" ".timestamps.stimulus_timestamps" ".stimulus_timestamps.BehaviorStimulusFile", - mock_stimulus_file + mock_stimulus_file, ) m.setattr( "allensdk.brain_observatory.behavior.data_objects" ".timestamps.stimulus_timestamps.stimulus_timestamps.SyncFile", - mock_sync_file + mock_sync_file, ) m.setattr( "allensdk.brain_observatory.behavior.data_objects" ".timestamps.stimulus_timestamps.stimulus_timestamps" ".get_behavior_stimulus_timestamps", - mock_get_behavior_stimulus_timestamps + mock_get_behavior_stimulus_timestamps, ) m.setattr( "allensdk.brain_observatory.behavior.data_objects" ".timestamps.stimulus_timestamps.stimulus_timestamps" ".get_ophys_stimulus_timestamps", - mock_get_ophys_stimulus_timestamps - ) - mock_stimulus_file_instance = mock_stimulus_file.from_lims( - mock_db_conn, behavior_session_id + mock_get_ophys_stimulus_timestamps, ) + mock_stimulus_file_instance = mock_stimulus_file.from_lims(mock_db_conn, behavior_session_id) ts_from_stim = StimulusTimestamps.from_stimulus_file( - stimulus_file=mock_stimulus_file_instance, - monitor_delay=0.0) + stimulus_file=mock_stimulus_file_instance, monitor_delay=0.0 + ) assert ts_from_stim._stimulus_file == mock_stimulus_file_instance if behavior_session_id is not None and ophys_experiment_id is not None: - mock_sync_file_instance = mock_sync_file.from_lims( - mock_db_conn, ophys_experiment_id - ) - ts_from_sync = StimulusTimestamps.from_sync_file( - sync_file=mock_sync_file_instance, - monitor_delay=0.0) + mock_sync_file_instance = mock_sync_file.from_lims(mock_db_conn, ophys_experiment_id) + ts_from_sync = StimulusTimestamps.from_sync_file(sync_file=mock_sync_file_instance, monitor_delay=0.0) if behavior_session_id is not None and ophys_experiment_id is not None: - mock_get_ophys_stimulus_timestamps.assert_called_once_with( - sync_path=mock_sync_file_instance.filepath - ) + mock_get_ophys_stimulus_timestamps.assert_called_once_with(sync_path=mock_sync_file_instance.filepath) assert ts_from_sync._sync_file == mock_sync_file_instance else: - mock_stimulus_file.from_lims.assert_called_with( - mock_db_conn, behavior_session_id - ) - mock_get_behavior_stimulus_timestamps.assert_called_once_with( - stimulus_pkl=mock_stimulus_file_instance.data - ) + mock_stimulus_file.from_lims.assert_called_with(mock_db_conn, behavior_session_id) + mock_get_behavior_stimulus_timestamps.assert_called_once_with(stimulus_pkl=mock_stimulus_file_instance.data) # Fixtures: @@ -228,21 +183,13 @@ def test_stimulus_timestamps_from_lims( # data_object_roundtrip_fixture: # test/brain_observatory/behavior/data_objects/conftest.py @pytest.mark.parametrize( - 'roundtrip, raw_stimulus_timestamps_data, monitor_delay', - product((True, False), - (np.arange(1, 6, 1), np.arange(6, 11, 1)), - (0.1, 0.2))) + "roundtrip, raw_stimulus_timestamps_data, monitor_delay", + product((True, False), (np.arange(1, 6, 1), np.arange(6, 11, 1)), (0.1, 0.2)), +) def test_stimulus_timestamps_nwb_roundtrip( - nwbfile, - data_object_roundtrip_fixture, - roundtrip, - raw_stimulus_timestamps_data, - monitor_delay + nwbfile, data_object_roundtrip_fixture, roundtrip, raw_stimulus_timestamps_data, monitor_delay ): - stimulus_timestamps = StimulusTimestamps( - timestamps=raw_stimulus_timestamps_data, - monitor_delay=monitor_delay - ) + stimulus_timestamps = StimulusTimestamps(timestamps=raw_stimulus_timestamps_data, monitor_delay=monitor_delay) nwbfile = stimulus_timestamps.to_nwb(nwbfile) if roundtrip: @@ -250,13 +197,11 @@ def test_stimulus_timestamps_nwb_roundtrip( else: obt = StimulusTimestamps.from_nwb(nwbfile) - assert np.allclose(obt.value, - raw_stimulus_timestamps_data+monitor_delay) + assert np.allclose(obt.value, raw_stimulus_timestamps_data + monitor_delay) -@pytest.fixture(scope='module') -def stimulus_timestamps_fixture( - behavior_ecephys_session_config_fixture): +@pytest.fixture(scope="module") +def stimulus_timestamps_fixture(behavior_ecephys_session_config_fixture): """ Return a StimulusTimestamps object constituted from multiple stimulus files @@ -266,30 +211,19 @@ def stimulus_timestamps_fixture( bsf = BehaviorStimulusFile.from_json(dict_repr=input_data) msf = MappingStimulusFile.from_json(dict_repr=input_data) rsf = ReplayStimulusFile.from_json(dict_repr=input_data) - obj = \ - StimulusTimestamps.from_multiple_stimulus_blocks( - sync_file=sync_file, - list_of_stims=[bsf, msf, rsf] - ) + obj = StimulusTimestamps.from_multiple_stimulus_blocks(sync_file=sync_file, list_of_stims=[bsf, msf, rsf]) return obj @pytest.mark.requires_bamboo -@pytest.mark.parametrize('roundtrip', [True, False]) -def test_read_write_nwb( - roundtrip, - data_object_roundtrip_fixture, - stimulus_timestamps_fixture, - helper_functions): - +@pytest.mark.parametrize("roundtrip", [True, False]) +def test_read_write_nwb(roundtrip, data_object_roundtrip_fixture, stimulus_timestamps_fixture, helper_functions): nwbfile = helper_functions.create_blank_nwb_file() stimulus_timestamps_fixture.to_nwb(nwbfile=nwbfile) if roundtrip: - obt = data_object_roundtrip_fixture( - nwbfile=nwbfile, - data_object_cls=StimulusTimestamps) + obt = data_object_roundtrip_fixture(nwbfile=nwbfile, data_object_cls=StimulusTimestamps) else: obt = StimulusTimestamps.from_nwb(nwbfile=nwbfile) @@ -304,13 +238,9 @@ def test_substract_monitor_delay(): rng = np.random.default_rng(22) timestamps = np.sort(rng.random(100)) monitor_delay = 0.57 - original_ts = StimulusTimestamps( - timestamps=timestamps, - monitor_delay=monitor_delay) + original_ts = StimulusTimestamps(timestamps=timestamps, monitor_delay=monitor_delay) - np.testing.assert_array_equal( - original_ts.value, - timestamps+monitor_delay) + np.testing.assert_array_equal(original_ts.value, timestamps + monitor_delay) assert np.isclose(original_ts.monitor_delay, monitor_delay) new_ts = original_ts.subtract_monitor_delay() diff --git a/allensdk/test/brain_observatory/behavior/data_objects/stimulus_timestamps/test_timestamps_processing.py b/allensdk/test/brain_observatory/behavior/data_objects/stimulus_timestamps/test_timestamps_processing.py index 179d3e3edb..c3b03e3c76 100644 --- a/allensdk/test/brain_observatory/behavior/data_objects/stimulus_timestamps/test_timestamps_processing.py +++ b/allensdk/test/brain_observatory/behavior/data_objects/stimulus_timestamps/test_timestamps_processing.py @@ -3,70 +3,82 @@ import numpy as np import pytest -from allensdk.brain_observatory.behavior.data_objects.timestamps \ - .stimulus_timestamps.timestamps_processing import ( - get_behavior_stimulus_timestamps, - get_ophys_stimulus_timestamps, - get_frame_indices) +from allensdk.brain_observatory.behavior.data_objects.timestamps.stimulus_timestamps.timestamps_processing import ( + get_behavior_stimulus_timestamps, + get_ophys_stimulus_timestamps, + get_frame_indices, +) from allensdk.internal.brain_observatory.time_sync import OphysTimeAligner -@pytest.mark.parametrize("pkl_data, expected", [ - # Extremely basic test case - ( - # pkl_data - { - "items": { - "behavior": { - "intervalsms": np.array([ - 1000, 1001, 1002, 1003, 1004, 1005 - ]) - } - } - }, - # expected - np.array([ - 0.0, 1.0, 2.001, 3.003, 4.006, 5.01, 6.015 - ]) - ), - # More realistic test case - ( - # pkl_data - { - "items": { - "behavior": { - "intervalsms": np.array([ - 16.5429, 16.6685, 16.66580001, 16.70569999, - 16.6668, - 16.69619999, 16.655, 16.6805, 16.75940001, 16.6831 - ]) +@pytest.mark.parametrize( + "pkl_data, expected", + [ + # Extremely basic test case + ( + # pkl_data + {"items": {"behavior": {"intervalsms": np.array([1000, 1001, 1002, 1003, 1004, 1005])}}}, + # expected + np.array([0.0, 1.0, 2.001, 3.003, 4.006, 5.01, 6.015]), + ), + # More realistic test case + ( + # pkl_data + { + "items": { + "behavior": { + "intervalsms": np.array( + [ + 16.5429, + 16.6685, + 16.66580001, + 16.70569999, + 16.6668, + 16.69619999, + 16.655, + 16.6805, + 16.75940001, + 16.6831, + ] + ) + } } - } - }, - # expected - np.array([ - 0.0, 0.0165429, 0.0332114, 0.0498772, 0.0665829, 0.0832497, - 0.0999459, 0.1166009, 0.1332814, 0.1500408, 0.1667239 - ]) - ) -]) + }, + # expected + np.array( + [ + 0.0, + 0.0165429, + 0.0332114, + 0.0498772, + 0.0665829, + 0.0832497, + 0.0999459, + 0.1166009, + 0.1332814, + 0.1500408, + 0.1667239, + ] + ), + ), + ], +) def test_get_behavior_stimulus_timestamps(pkl_data, expected): obt = get_behavior_stimulus_timestamps(pkl_data) assert np.allclose(obt, expected) -@pytest.mark.parametrize("sync_path, expected_timestamps", [ - ("/tmp/mock_sync_file.h5", [1, 2, 3]), -]) -def test_get_ophys_stimulus_timestamps( - monkeypatch, sync_path, expected_timestamps -): +@pytest.mark.parametrize( + "sync_path, expected_timestamps", + [ + ("/tmp/mock_sync_file.h5", [1, 2, 3]), + ], +) +def test_get_ophys_stimulus_timestamps(monkeypatch, sync_path, expected_timestamps): mock_ophys_time_aligner = create_autospec(OphysTimeAligner) mock_aligner_instance = mock_ophys_time_aligner.return_value - property_mock = PropertyMock( - return_value=(expected_timestamps, "ignored_return_val") - ) + property_mock = PropertyMock(return_value=(expected_timestamps, "ignored_return_val")) type(mock_aligner_instance).clipped_stim_timestamps = property_mock with monkeypatch.context() as m: @@ -74,7 +86,7 @@ def test_get_ophys_stimulus_timestamps( "allensdk.brain_observatory.behavior.data_objects" ".timestamps.stimulus_timestamps.timestamps_processing" ".OphysTimeAligner", - mock_ophys_time_aligner + mock_ophys_time_aligner, ) obt = get_ophys_stimulus_timestamps(sync_path) @@ -85,26 +97,18 @@ def test_get_ophys_stimulus_timestamps( @pytest.mark.parametrize( "frame_timestamps, event_timestamps, expected_indices", - [(np.arange(0, 1, 0.11), - np.array([0.22, 0.13, 0.32, 0.77]), - np.array([2, 1, 2, 7])), - (np.arange(0, 1, 0.11), - np.array([-0.1, 2.1, 0.55, 0.42]), - np.array([0, 9, 5, 3])), - (np.array([0.11, 0.11, 0.33, 0.33, 0.44]), - np.array([0.05, 0.12, 0.22, 0.33, 0.77]), - np.array([0, 1, 1, 2, 4]))]) -def test_get_frame_indices( - frame_timestamps, - event_timestamps, - expected_indices): + [ + (np.arange(0, 1, 0.11), np.array([0.22, 0.13, 0.32, 0.77]), np.array([2, 1, 2, 7])), + (np.arange(0, 1, 0.11), np.array([-0.1, 2.1, 0.55, 0.42]), np.array([0, 9, 5, 3])), + (np.array([0.11, 0.11, 0.33, 0.33, 0.44]), np.array([0.05, 0.12, 0.22, 0.33, 0.77]), np.array([0, 1, 1, 2, 4])), + ], +) +def test_get_frame_indices(frame_timestamps, event_timestamps, expected_indices): """ Test that get_frame_indices correctly associates events with frames """ - actual = get_frame_indices( - frame_timestamps=frame_timestamps, - event_timestamps=event_timestamps) + actual = get_frame_indices(frame_timestamps=frame_timestamps, event_timestamps=event_timestamps) np.testing.assert_array_equal(actual, expected_indices) @@ -117,8 +121,5 @@ def test_get_frame_indices_error(): frame_timestamps = np.array([0.1, 0.4, 0.3]) event_timestamps = np.array([0.11, 0.22]) - with pytest.raises(ValueError, - match="frame_timestamps are not in ascending order"): - get_frame_indices( - frame_timestamps=frame_timestamps, - event_timestamps=event_timestamps) + with pytest.raises(ValueError, match="frame_timestamps are not in ascending order"): + get_frame_indices(frame_timestamps=frame_timestamps, event_timestamps=event_timestamps) diff --git a/allensdk/test/brain_observatory/behavior/data_objects/test_cell_specimens.py b/allensdk/test/brain_observatory/behavior/data_objects/test_cell_specimens.py index 0228adcb5c..6b86fd41f7 100644 --- a/allensdk/test/brain_observatory/behavior/data_objects/test_cell_specimens.py +++ b/allensdk/test/brain_observatory/behavior/data_objects/test_cell_specimens.py @@ -8,24 +8,21 @@ import pytest from allensdk.core import DataObject -from allensdk.brain_observatory.behavior.data_objects.cell_specimens\ - .cell_specimens import ( - CellSpecimens, - CellSpecimenMeta, - EventsParams, - ) -from allensdk.brain_observatory.behavior.data_objects.cell_specimens\ - .rois_mixin import RoisMixin -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .ophys_experiment_metadata.imaging_plane import ImagingPlane -from allensdk.brain_observatory.behavior.data_objects.timestamps\ - .ophys_timestamps import OphysTimestamps +from allensdk.brain_observatory.behavior.data_objects.cell_specimens.cell_specimens import ( + CellSpecimens, + CellSpecimenMeta, + EventsParams, +) +from allensdk.brain_observatory.behavior.data_objects.cell_specimens.rois_mixin import RoisMixin +from allensdk.brain_observatory.behavior.data_objects.metadata.ophys_experiment_metadata.imaging_plane import ( + ImagingPlane, +) +from allensdk.brain_observatory.behavior.data_objects.timestamps.ophys_timestamps import OphysTimestamps from allensdk.core.auth_config import LIMS_DB_CREDENTIAL_MAP from allensdk.internal.api import db_connection_creator -from allensdk.test.brain_observatory.behavior.data_objects.metadata\ - .test_behavior_ophys_metadata import ( - TestBOM, - ) +from allensdk.test.brain_observatory.behavior.data_objects.metadata.test_behavior_ophys_metadata import ( + TestBOM, +) class TestLims: @@ -49,26 +46,18 @@ def setup_method(self, method): # Will only create a dbconn if the test requires_bamboo if "requires_bamboo" in marks: - self.dbconn = db_connection_creator( - fallback_credentials=LIMS_DB_CREDENTIAL_MAP - ) + self.dbconn = db_connection_creator(fallback_credentials=LIMS_DB_CREDENTIAL_MAP) @pytest.mark.requires_bamboo def test_from_lims(self): number_of_frames = 140296 - ots = OphysTimestamps( - timestamps=np.linspace( - start=0.1, stop=0.1 * number_of_frames, num=number_of_frames - ) - ) + ots = OphysTimestamps(timestamps=np.linspace(start=0.1, stop=0.1 * number_of_frames, num=number_of_frames)) csp = CellSpecimens.from_lims( ophys_experiment_id=self.ophys_experiment_id, lims_db=self.dbconn, ophys_timestamps=ots, segmentation_mask_image_spacing=(0.78125e-3, 0.78125e-3), - events_params=EventsParams( - filter_scale_seconds=2.0 / 31.0, filter_n_time_steps=20 - ), + events_params=EventsParams(filter_scale_seconds=2.0 / 31.0, filter_n_time_steps=20), ) assert not csp.table.empty assert not csp.events.empty @@ -86,15 +75,11 @@ def setup_class(cls): dict_repr = json.load(f) dict_repr = dict_repr["session_data"] dict_repr["sync_file"] = str(test_data_dir / "sync.h5") - dict_repr["behavior_stimulus_file"] = str( - test_data_dir / "behavior_stimulus_file.pkl" - ) + dict_repr["behavior_stimulus_file"] = str(test_data_dir / "behavior_stimulus_file.pkl") dict_repr["dff_file"] = str(test_data_dir / "demix_file.h5") dict_repr["demix_file"] = str(test_data_dir / "demix_file.h5") dict_repr["neuropil_file"] = str(test_data_dir / "demix_file.h5") - dict_repr["neuropil_corrected_file"] = str( - test_data_dir / "neuropil_corrected_file.h5" - ) + dict_repr["neuropil_corrected_file"] = str(test_data_dir / "neuropil_corrected_file.h5") dict_repr["events_file"] = str(test_data_dir / "events.h5") cls.dict_repr = dict_repr @@ -107,18 +92,14 @@ def setup_class(cls): targeted_structure="VISp", ), ) - cls.ophys_timestamps = OphysTimestamps( - timestamps=np.array([0.1, 0.2, 0.3]) - ) + cls.ophys_timestamps = OphysTimestamps(timestamps=np.array([0.1, 0.2, 0.3])) def test_from_json(self): csp = CellSpecimens.from_json( dict_repr=self.dict_repr, ophys_timestamps=self.ophys_timestamps, segmentation_mask_image_spacing=(0.78125e-3, 0.78125e-3), - events_params=EventsParams( - filter_scale_seconds=2.0 / 31.0, filter_n_time_steps=20 - ), + events_params=EventsParams(filter_scale_seconds=2.0 / 31.0, filter_n_time_steps=20), ) assert not csp.table.empty assert not csp.events.empty @@ -142,21 +123,15 @@ def test_roi_data_same_order_as_cell_specimen_table(self, data): dict_repr=self.dict_repr, ophys_timestamps=self.ophys_timestamps, segmentation_mask_image_spacing=(0.78125e-3, 0.78125e-3), - events_params=EventsParams( - filter_scale_seconds=2.0 / 31.0, filter_n_time_steps=20 - ), + events_params=EventsParams(filter_scale_seconds=2.0 / 31.0, filter_n_time_steps=20), ) private_attr = getattr(csp, f"_{data}") public_attr = getattr(csp, data) # Events stores cell_roi_id as column whereas traces is index - data_cell_roi_ids = getattr( - private_attr.value, "cell_roi_id" if data == "events" else "index" - ).values + data_cell_roi_ids = getattr(private_attr.value, "cell_roi_id" if data == "events" else "index").values - current_order = np.where( - data_cell_roi_ids == csp._cell_specimen_table["cell_roi_id"] - )[0] + current_order = np.where(data_cell_roi_ids == csp._cell_specimen_table["cell_roi_id"])[0] # make sure same order private_attr._value = private_attr.value.iloc[current_order] @@ -177,18 +152,14 @@ def test_roi_data_same_order_as_cell_specimen_table(self, data): "corrected_fluorescence_traces", ), ) - def test_trace_rois_different_than_cell_specimen_table( - self, trace_type, extra_in_trace - ): + def test_trace_rois_different_than_cell_specimen_table(self, trace_type, extra_in_trace): """check that an exception is raised if there is a mismatch in rois between cell specimen table and traces""" csp = CellSpecimens.from_json( dict_repr=self.dict_repr, ophys_timestamps=self.ophys_timestamps, segmentation_mask_image_spacing=(0.78125e-3, 0.78125e-3), - events_params=EventsParams( - filter_scale_seconds=2.0 / 31.0, filter_n_time_steps=20 - ), + events_params=EventsParams(filter_scale_seconds=2.0 / 31.0, filter_n_time_steps=20), ) private_trace_attr = getattr(csp, f"_{trace_type}") @@ -201,33 +172,28 @@ def test_trace_rois_different_than_cell_specimen_table( else: # Drop an roi from trace that is in cell specimen table csp_rois = csp._cell_specimen_table["cell_roi_id"] - private_trace_attr._value = private_trace_attr._value[ - private_trace_attr._value.index != csp_rois.iloc[0] - ] + private_trace_attr._value = private_trace_attr._value[private_trace_attr._value.index != csp_rois.iloc[0]] if trace_type == "dff_traces": trace_args = { "dff_traces": private_trace_attr, "demixed_traces": csp._demixed_traces, "neuropil_traces": csp._neuropil_traces, - "corrected_fluorescence_traces": - csp._corrected_fluorescence_traces, + "corrected_fluorescence_traces": csp._corrected_fluorescence_traces, } elif trace_type == "demixed_traces": trace_args = { "dff_traces": csp._dff_traces, "demixed_traces": private_trace_attr, "neuropil_traces": csp._neuropil_traces, - "corrected_fluorescence_traces": - csp._corrected_fluorescence_traces, + "corrected_fluorescence_traces": csp._corrected_fluorescence_traces, } elif trace_type == "neuropil_traces": trace_args = { "dff_traces": csp._dff_traces, "demixed_traces": csp._demixed_traces, "neuropil_traces": private_trace_attr, - "corrected_fluorescence_traces": - csp._corrected_fluorescence_traces, + "corrected_fluorescence_traces": csp._corrected_fluorescence_traces, } else: trace_args = { @@ -252,9 +218,7 @@ def test_trace_rois_different_than_cell_specimen_table( class TestNWB: @classmethod def setup_class(cls): - cls.ophys_timestamps = OphysTimestamps( - timestamps=np.array([0.1, 0.2, 0.3]) - ) + cls.ophys_timestamps = OphysTimestamps(timestamps=np.array([0.1, 0.2, 0.3])) def setup_method(self, method): self.nwbfile = pynwb.NWBFile( @@ -275,26 +239,20 @@ def setup_method(self, method): @pytest.mark.parametrize("exclude_invalid_rois", [True, False]) @pytest.mark.parametrize("roundtrip", [True, False]) - def test_read_write_nwb( - self, roundtrip, data_object_roundtrip_fixture, exclude_invalid_rois - ): + def test_read_write_nwb(self, roundtrip, data_object_roundtrip_fixture, exclude_invalid_rois): cell_specimens = CellSpecimens.from_json( dict_repr=self.dict_repr, ophys_timestamps=self.ophys_timestamps, segmentation_mask_image_spacing=(0.78125e-3, 0.78125e-3), exclude_invalid_rois=exclude_invalid_rois, - events_params=EventsParams( - filter_scale_seconds=2.0 / 31.0, filter_n_time_steps=20 - ), + events_params=EventsParams(filter_scale_seconds=2.0 / 31.0, filter_n_time_steps=20), ) csp = cell_specimens._cell_specimen_table valid_roi_id = csp[csp["valid_roi"]]["cell_roi_id"] - cell_specimens.to_nwb( - nwbfile=self.nwbfile, ophys_timestamps=self.ophys_timestamps - ) + cell_specimens.to_nwb(nwbfile=self.nwbfile, ophys_timestamps=self.ophys_timestamps) if roundtrip: obt = data_object_roundtrip_fixture( @@ -302,28 +260,20 @@ def test_read_write_nwb( data_object_cls=CellSpecimens, exclude_invalid_rois=exclude_invalid_rois, segmentation_mask_image_spacing=(0.78125e-3, 0.78125e-3), - events_params=EventsParams( - filter_scale_seconds=2.0 / 31.0, filter_n_time_steps=20 - ), + events_params=EventsParams(filter_scale_seconds=2.0 / 31.0, filter_n_time_steps=20), ) else: obt = cell_specimens.from_nwb( nwbfile=self.nwbfile, exclude_invalid_rois=exclude_invalid_rois, segmentation_mask_image_spacing=(0.78125e-3, 0.78125e-3), - events_params=EventsParams( - filter_scale_seconds=2.0 / 31.0, filter_n_time_steps=20 - ), + events_params=EventsParams(filter_scale_seconds=2.0 / 31.0, filter_n_time_steps=20), ) if exclude_invalid_rois: - cell_specimens._cell_specimen_table = ( - cell_specimens._cell_specimen_table[ - cell_specimens._cell_specimen_table["cell_roi_id"].isin( - valid_roi_id - ) - ] - ) + cell_specimens._cell_specimen_table = cell_specimens._cell_specimen_table[ + cell_specimens._cell_specimen_table["cell_roi_id"].isin(valid_roi_id) + ] assert obt == cell_specimens @@ -347,8 +297,6 @@ def __init__(self): raise_if_rois_missing=raise_if_rois_missing, ) else: - rois.filter_and_reorder( - roi_ids=roi_ids, raise_if_rois_missing=raise_if_rois_missing - ) + rois.filter_and_reorder(roi_ids=roi_ids, raise_if_rois_missing=raise_if_rois_missing) expected = pd.DataFrame({"cell_roi_id": [1], "foo": [2]}) pd.testing.assert_frame_equal(rois._value, expected) diff --git a/allensdk/test/brain_observatory/behavior/data_objects/test_licks.py b/allensdk/test/brain_observatory/behavior/data_objects/test_licks.py index ec6fd17426..d609a863e6 100644 --- a/allensdk/test/brain_observatory/behavior/data_objects/test_licks.py +++ b/allensdk/test/brain_observatory/behavior/data_objects/test_licks.py @@ -6,8 +6,7 @@ import pynwb import pytest -from allensdk.brain_observatory.behavior.data_files import ( - BehaviorStimulusFile, SyncFile) +from allensdk.brain_observatory.behavior.data_files import BehaviorStimulusFile, SyncFile from allensdk.brain_observatory.behavior.data_objects import StimulusTimestamps from allensdk.brain_observatory.behavior.data_objects.licks import Licks @@ -16,13 +15,11 @@ class TestFromBehaviorStimulusFile: @classmethod def setup_class(cls): dir = Path(__file__).parent.resolve() - test_data_dir = dir / 'test_data' + test_data_dir = dir / "test_data" - cls.stimulus_file = BehaviorStimulusFile( - filepath=test_data_dir / 'behavior_stimulus_file.pkl') - cls.sync_file = SyncFile( - filepath=test_data_dir / 'sync.h5') - expected = pd.read_pickle(str(test_data_dir / 'licks.pkl')) + cls.stimulus_file = BehaviorStimulusFile(filepath=test_data_dir / "behavior_stimulus_file.pkl") + cls.sync_file = SyncFile(filepath=test_data_dir / "sync.h5") + expected = pd.read_pickle(str(test_data_dir / "licks.pkl")) cls.expected = Licks(licks=expected) def test_monitor_delay_error(self): @@ -30,21 +27,13 @@ def test_monitor_delay_error(self): Test that an error is raised if Licks are instantiated with non-zero monitor delay """ - timestamps = StimulusTimestamps( - np.arange(10), - 0.1) - with pytest.raises(RuntimeError, - match="monitor_delay should be zero"): - Licks.from_stimulus_file( - stimulus_file=self.stimulus_file, - stimulus_timestamps=timestamps) + timestamps = StimulusTimestamps(np.arange(10), 0.1) + with pytest.raises(RuntimeError, match="monitor_delay should be zero"): + Licks.from_stimulus_file(stimulus_file=self.stimulus_file, stimulus_timestamps=timestamps) def test_from_stimulus_file(self): - st = StimulusTimestamps.from_stimulus_file( - stimulus_file=self.stimulus_file, - monitor_delay=0.0) - licks = Licks.from_stimulus_file(stimulus_file=self.stimulus_file, - stimulus_timestamps=st) + st = StimulusTimestamps.from_stimulus_file(stimulus_file=self.stimulus_file, monitor_delay=0.0) + licks = Licks.from_stimulus_file(stimulus_file=self.stimulus_file, stimulus_timestamps=st) assert licks == self.expected def test_from_stimulus_and_sync_file(self): @@ -52,11 +41,10 @@ def test_from_stimulus_and_sync_file(self): Test is slightly different from other tests as the sync file data is not matched up to the stim data. """ - lick_times = self.sync_file.data['lick_times'] - licks = Licks.from_stimulus_file(stimulus_file=self.stimulus_file, - stimulus_timestamps=lick_times) - assert licks.value['timestamps'][0] == lick_times[0] - assert licks.value['frame'][0] == self.expected.value['frame'][0] + lick_times = self.sync_file.data["lick_times"] + licks = Licks.from_stimulus_file(stimulus_file=self.stimulus_file, stimulus_timestamps=lick_times) + assert licks.value["timestamps"][0] == lick_times[0] + assert licks.value["frame"][0] == self.expected.value["frame"][0] def test_from_stimulus_file2(self, tmpdir): """ @@ -64,26 +52,18 @@ def test_from_stimulus_file2(self, tmpdir): of licks whose timestamps are based on their frame number with respect to the stimulus_timestamps """ - stimulus_filepath = self._create_test_stimulus_file( - lick_events=[12, 15, 90, 136], tmpdir=tmpdir) - stimulus_file = BehaviorStimulusFile.from_json( - dict_repr={'behavior_stimulus_file': str(stimulus_filepath)}) - timestamps = StimulusTimestamps(timestamps=np.arange(0, 2.0, 0.01), - monitor_delay=0.0) - licks = Licks.from_stimulus_file(stimulus_file=stimulus_file, - stimulus_timestamps=timestamps) - - expected_dict = {'timestamps': [0.12, 0.15, 0.90, 1.36], - 'frame': [12, 15, 90, 136]} + stimulus_filepath = self._create_test_stimulus_file(lick_events=[12, 15, 90, 136], tmpdir=tmpdir) + stimulus_file = BehaviorStimulusFile.from_json(dict_repr={"behavior_stimulus_file": str(stimulus_filepath)}) + timestamps = StimulusTimestamps(timestamps=np.arange(0, 2.0, 0.01), monitor_delay=0.0) + licks = Licks.from_stimulus_file(stimulus_file=stimulus_file, stimulus_timestamps=timestamps) + + expected_dict = {"timestamps": [0.12, 0.15, 0.90, 1.36], "frame": [12, 15, 90, 136]} expected_df = pd.DataFrame(expected_dict) assert expected_df.columns.equals(licks.value.columns) np.testing.assert_array_almost_equal( - expected_df.timestamps.to_numpy(), - licks.value['timestamps'].to_numpy(), - decimal=10) - np.testing.assert_array_almost_equal(expected_df.frame.to_numpy(), - licks.value['frame'].to_numpy(), - decimal=10) + expected_df.timestamps.to_numpy(), licks.value["timestamps"].to_numpy(), decimal=10 + ) + np.testing.assert_array_almost_equal(expected_df.frame.to_numpy(), licks.value["frame"].to_numpy(), decimal=10) def test_empty_licks(self, tmpdir): """ @@ -91,23 +71,16 @@ def test_empty_licks(self, tmpdir): there are no licks """ - stimulus_filepath = self._create_test_stimulus_file( - lick_events=[], tmpdir=tmpdir) - stimulus_file = BehaviorStimulusFile.from_json( - dict_repr={'behavior_stimulus_file': str(stimulus_filepath)}) - timestamps = StimulusTimestamps(timestamps=np.arange(0, 2.0, 0.01), - monitor_delay=0.0) - licks = Licks.from_stimulus_file(stimulus_file=stimulus_file, - stimulus_timestamps=timestamps) - - expected_dict = {'timestamps': [], - 'frame': []} + stimulus_filepath = self._create_test_stimulus_file(lick_events=[], tmpdir=tmpdir) + stimulus_file = BehaviorStimulusFile.from_json(dict_repr={"behavior_stimulus_file": str(stimulus_filepath)}) + timestamps = StimulusTimestamps(timestamps=np.arange(0, 2.0, 0.01), monitor_delay=0.0) + licks = Licks.from_stimulus_file(stimulus_file=stimulus_file, stimulus_timestamps=timestamps) + + expected_dict = {"timestamps": [], "frame": []} expected_df = pd.DataFrame(expected_dict) assert expected_df.columns.equals(licks.value.columns) - np.testing.assert_array_equal(expected_df.timestamps.to_numpy(), - licks.value['timestamps'].to_numpy()) - np.testing.assert_array_equal(expected_df.frame.to_numpy(), - licks.value['frame'].to_numpy()) + np.testing.assert_array_equal(expected_df.timestamps.to_numpy(), licks.value["timestamps"].to_numpy()) + np.testing.assert_array_equal(expected_df.frame.to_numpy(), licks.value["frame"].to_numpy()) def test_get_licks_excess(self, tmpdir): """ @@ -121,59 +94,46 @@ def test_get_licks_excess(self, tmpdir): """ stimulus_filepath = self._create_test_stimulus_file( lick_events=[12, 15, 90, 136, 200], # len(timestamps) == 200, - tmpdir=tmpdir) - stimulus_file = BehaviorStimulusFile.from_json( - dict_repr={'behavior_stimulus_file': str(stimulus_filepath)}) - timestamps = StimulusTimestamps(timestamps=np.arange(0, 2.0, 0.01), - monitor_delay=0.0) - licks = Licks.from_stimulus_file(stimulus_file=stimulus_file, - stimulus_timestamps=timestamps) - - expected_dict = {'timestamps': [0.12, 0.15, 0.90, 1.36], - 'frame': [12, 15, 90, 136]} + tmpdir=tmpdir, + ) + stimulus_file = BehaviorStimulusFile.from_json(dict_repr={"behavior_stimulus_file": str(stimulus_filepath)}) + timestamps = StimulusTimestamps(timestamps=np.arange(0, 2.0, 0.01), monitor_delay=0.0) + licks = Licks.from_stimulus_file(stimulus_file=stimulus_file, stimulus_timestamps=timestamps) + + expected_dict = {"timestamps": [0.12, 0.15, 0.90, 1.36], "frame": [12, 15, 90, 136]} expected_df = pd.DataFrame(expected_dict) assert expected_df.columns.equals(licks.value.columns) np.testing.assert_array_almost_equal( - expected_df.timestamps.to_numpy(), - licks.value['timestamps'].to_numpy(), - decimal=10) - np.testing.assert_array_almost_equal(expected_df.frame.to_numpy(), - licks.value['frame'].to_numpy(), - decimal=10) + expected_df.timestamps.to_numpy(), licks.value["timestamps"].to_numpy(), decimal=10 + ) + np.testing.assert_array_almost_equal(expected_df.frame.to_numpy(), licks.value["frame"].to_numpy(), decimal=10) def test_get_licks_failure(self, tmpdir): stimulus_filepath = self._create_test_stimulus_file( lick_events=[12, 15, 90, 136, 201], # len(timestamps) == 200, - tmpdir=tmpdir) - stimulus_file = BehaviorStimulusFile.from_json( - dict_repr={'behavior_stimulus_file': str(stimulus_filepath)}) - timestamps = StimulusTimestamps(timestamps=np.arange(0, 2.0, 0.01), - monitor_delay=0.0) + tmpdir=tmpdir, + ) + stimulus_file = BehaviorStimulusFile.from_json(dict_repr={"behavior_stimulus_file": str(stimulus_filepath)}) + timestamps = StimulusTimestamps(timestamps=np.arange(0, 2.0, 0.01), monitor_delay=0.0) with pytest.raises(IndexError): - Licks.from_stimulus_file(stimulus_file=stimulus_file, - stimulus_timestamps=timestamps) + Licks.from_stimulus_file(stimulus_file=stimulus_file, stimulus_timestamps=timestamps) @staticmethod def _create_test_stimulus_file(lick_events, tmpdir): trial_log = [ - {'licks': [(-1.0, 100), (-1.0, 200)]}, - {'licks': [(-1.0, 300), (-1.0, 400)]}, - {'licks': [(-1.0, 500), (-1.0, 600)]} + {"licks": [(-1.0, 100), (-1.0, 200)]}, + {"licks": [(-1.0, 300), (-1.0, 400)]}, + {"licks": [(-1.0, 500), (-1.0, 600)]}, ] - lick_events = [{'lick_events': lick_events}] + lick_events = [{"lick_events": lick_events}] data = { - 'items': { - 'behavior': { - 'trial_log': trial_log, - 'lick_sensors': lick_events - } - }, + "items": {"behavior": {"trial_log": trial_log, "lick_sensors": lick_events}}, } - tmp_path = tmpdir / 'stimulus_file.pkl' - with open(tmp_path, 'wb') as f: + tmp_path = tmpdir / "stimulus_file.pkl" + with open(tmp_path, "wb") as f: pickle.dump(data, f) f.seek(0) @@ -184,32 +144,23 @@ class TestNWB: @classmethod def setup_class(cls): dir = Path(__file__).parent.resolve() - test_data_dir = dir / 'test_data' + test_data_dir = dir / "test_data" - stimulus_file = BehaviorStimulusFile( - filepath=test_data_dir / 'behavior_stimulus_file.pkl') - ts = StimulusTimestamps.from_stimulus_file( - stimulus_file=stimulus_file, - monitor_delay=0.0) - cls.licks = Licks.from_stimulus_file(stimulus_file=stimulus_file, - stimulus_timestamps=ts) + stimulus_file = BehaviorStimulusFile(filepath=test_data_dir / "behavior_stimulus_file.pkl") + ts = StimulusTimestamps.from_stimulus_file(stimulus_file=stimulus_file, monitor_delay=0.0) + cls.licks = Licks.from_stimulus_file(stimulus_file=stimulus_file, stimulus_timestamps=ts) def setup_method(self, method): self.nwbfile = pynwb.NWBFile( - session_description='asession', - identifier='1234', - session_start_time=datetime.now() + session_description="asession", identifier="1234", session_start_time=datetime.now() ) - @pytest.mark.parametrize('roundtrip', [True, False]) - def test_read_write_nwb(self, roundtrip, - data_object_roundtrip_fixture): + @pytest.mark.parametrize("roundtrip", [True, False]) + def test_read_write_nwb(self, roundtrip, data_object_roundtrip_fixture): self.licks.to_nwb(nwbfile=self.nwbfile) if roundtrip: - obt = data_object_roundtrip_fixture( - nwbfile=self.nwbfile, - data_object_cls=Licks) + obt = data_object_roundtrip_fixture(nwbfile=self.nwbfile, data_object_cls=Licks) else: obt = self.licks.from_nwb(nwbfile=self.nwbfile) diff --git a/allensdk/test/brain_observatory/behavior/data_objects/test_motion_correction.py b/allensdk/test/brain_observatory/behavior/data_objects/test_motion_correction.py index 9590c48d50..dd4c0a1f19 100644 --- a/allensdk/test/brain_observatory/behavior/data_objects/test_motion_correction.py +++ b/allensdk/test/brain_observatory/behavior/data_objects/test_motion_correction.py @@ -7,27 +7,13 @@ import pynwb import pytest -from allensdk.brain_observatory.behavior.data_files\ - .rigid_motion_transform_file import \ - RigidMotionTransformFile -from allensdk.brain_observatory.behavior.data_objects.cell_specimens\ - .cell_specimens import ( - CellSpecimens, - EventsParams) -from allensdk.brain_observatory.behavior.data_objects.motion_correction \ - import \ - MotionCorrection -from allensdk.brain_observatory.behavior.data_objects.timestamps\ - .ophys_timestamps import \ - OphysTimestamps -from allensdk.test.brain_observatory.behavior.data_objects.lims_util import \ - LimsTest -from allensdk.test.brain_observatory.behavior.data_objects.metadata\ - .test_behavior_ophys_metadata import \ - TestBOM -from allensdk.test.brain_observatory.behavior.data_objects.nwb_input_json \ - import \ - NwbInputJson +from allensdk.brain_observatory.behavior.data_files.rigid_motion_transform_file import RigidMotionTransformFile +from allensdk.brain_observatory.behavior.data_objects.cell_specimens.cell_specimens import CellSpecimens, EventsParams +from allensdk.brain_observatory.behavior.data_objects.motion_correction import MotionCorrection +from allensdk.brain_observatory.behavior.data_objects.timestamps.ophys_timestamps import OphysTimestamps +from allensdk.test.brain_observatory.behavior.data_objects.lims_util import LimsTest +from allensdk.test.brain_observatory.behavior.data_objects.metadata.test_behavior_ophys_metadata import TestBOM +from allensdk.test.brain_observatory.behavior.data_objects.nwb_input_json import NwbInputJson class TestFromDataFile(LimsTest): @@ -38,11 +24,11 @@ def setup_class(cls): @pytest.mark.requires_bamboo def test_from_data_file(self): motion_correction_file = RigidMotionTransformFile.from_lims( - ophys_experiment_id=self.ophys_experiment_id, db=self.dbconn) - mc = MotionCorrection.from_data_file( - rigid_motion_transform_file=motion_correction_file) + ophys_experiment_id=self.ophys_experiment_id, db=self.dbconn + ) + mc = MotionCorrection.from_data_file(rigid_motion_transform_file=motion_correction_file) assert not mc.value.empty - expected_cols = ['x', 'y'] + expected_cols = ["x", "y"] assert len(mc.value.columns) == 2 for c in expected_cols: assert c in mc.value.columns @@ -52,35 +38,30 @@ class TestJson: @classmethod def setup_class(cls): dir = Path(__file__).parent.resolve() - test_data_dir = dir / 'test_data' - with open(test_data_dir / 'test_input.json') as f: + test_data_dir = dir / "test_data" + with open(test_data_dir / "test_input.json") as f: dict_repr = json.load(f) - dict_repr = dict_repr['session_data'] - dict_repr['rigid_motion_transform_file'] = \ - str(test_data_dir / 'rigid_motion_transform_file.csv') + dict_repr = dict_repr["session_data"] + dict_repr["rigid_motion_transform_file"] = str(test_data_dir / "rigid_motion_transform_file.csv") cls.dict_repr = dict_repr - cls.motion_correction_file = \ - RigidMotionTransformFile.from_json(dict_repr=dict_repr) - expected = pd.DataFrame({'x': [2, 3, 2], 'y': [-3, -4, -4]}) + cls.motion_correction_file = RigidMotionTransformFile.from_json(dict_repr=dict_repr) + expected = pd.DataFrame({"x": [2, 3, 2], "y": [-3, -4, -4]}) cls.expected = MotionCorrection(motion_correction=expected) def test_from_json(self): - mc = MotionCorrection.from_data_file( - rigid_motion_transform_file=self.motion_correction_file) + mc = MotionCorrection.from_data_file(rigid_motion_transform_file=self.motion_correction_file) assert mc == self.expected class TestNWB: @classmethod def setup_class(cls): - df = pd.DataFrame({'x': [2, 3, 2], 'y': [-3, -4, -4]}) + df = pd.DataFrame({"x": [2, 3, 2], "y": [-3, -4, -4]}) cls.motion_correction = MotionCorrection(motion_correction=df) def setup_method(self, method): self.nwbfile = pynwb.NWBFile( - session_description='asession', - identifier='1234', - session_start_time=datetime.now() + session_description="asession", identifier="1234", session_start_time=datetime.now() ) def _write_cell_specimen(): @@ -92,28 +73,24 @@ def _write_cell_specimen(): # write cell specimen ij = NwbInputJson() - ophys_timestamps = OphysTimestamps( - timestamps=np.array([.1, .2, .3])) + ophys_timestamps = OphysTimestamps(timestamps=np.array([0.1, 0.2, 0.3])) csp = CellSpecimens.from_json( - dict_repr=ij.dict_repr, ophys_timestamps=ophys_timestamps, - segmentation_mask_image_spacing=(.78125e-3, .78125e-3), - events_params=EventsParams( - filter_scale_seconds=2.0/31.0, - filter_n_time_steps=20)) + dict_repr=ij.dict_repr, + ophys_timestamps=ophys_timestamps, + segmentation_mask_image_spacing=(0.78125e-3, 0.78125e-3), + events_params=EventsParams(filter_scale_seconds=2.0 / 31.0, filter_n_time_steps=20), + ) csp.to_nwb(nwbfile=self.nwbfile, ophys_timestamps=ophys_timestamps) # need to write cell specimen, since it is a dependency _write_cell_specimen() - @pytest.mark.parametrize('roundtrip', [True, False]) - def test_read_write_nwb(self, roundtrip, - data_object_roundtrip_fixture): + @pytest.mark.parametrize("roundtrip", [True, False]) + def test_read_write_nwb(self, roundtrip, data_object_roundtrip_fixture): self.motion_correction.to_nwb(nwbfile=self.nwbfile) if roundtrip: - obt = data_object_roundtrip_fixture( - nwbfile=self.nwbfile, - data_object_cls=MotionCorrection) + obt = data_object_roundtrip_fixture(nwbfile=self.nwbfile, data_object_cls=MotionCorrection) else: obt = self.motion_correction.from_nwb(nwbfile=self.nwbfile) diff --git a/allensdk/test/brain_observatory/behavior/data_objects/test_ophys_timestamps.py b/allensdk/test/brain_observatory/behavior/data_objects/test_ophys_timestamps.py index bd1281e9bc..487e6cb6a9 100644 --- a/allensdk/test/brain_observatory/behavior/data_objects/test_ophys_timestamps.py +++ b/allensdk/test/brain_observatory/behavior/data_objects/test_ophys_timestamps.py @@ -4,60 +4,55 @@ import pytest from allensdk.brain_observatory.behavior.data_files import SyncFile -from allensdk.brain_observatory.behavior.data_objects.timestamps \ - .ophys_timestamps import \ - OphysTimestamps, OphysTimestampsMultiplane -from allensdk.test.brain_observatory.behavior.data_objects.lims_util import \ - LimsTest +from allensdk.brain_observatory.behavior.data_objects.timestamps.ophys_timestamps import ( + OphysTimestamps, + OphysTimestampsMultiplane, +) +from allensdk.test.brain_observatory.behavior.data_objects.lims_util import LimsTest class TestFromSyncFile(LimsTest): def setup_method(self, method): dir = Path(__file__).parent.resolve() - test_data_dir = dir / 'test_data' + test_data_dir = dir / "test_data" - self.sync_file = SyncFile(filepath=str(test_data_dir / 'sync.h5')) + self.sync_file = SyncFile(filepath=str(test_data_dir / "sync.h5")) def test_from_sync_file(self): - self.sync_file._data = {'ophys_frames': np.array([.1, .2, .3])} - ts = OphysTimestamps.from_sync_file(sync_file=self.sync_file)\ - .validate(number_of_frames=3) - expected = np.array([.1, .2, .3]) + self.sync_file._data = {"ophys_frames": np.array([0.1, 0.2, 0.3])} + ts = OphysTimestamps.from_sync_file(sync_file=self.sync_file).validate(number_of_frames=3) + expected = np.array([0.1, 0.2, 0.3]) np.testing.assert_equal(ts.value, expected) def test_too_long_single_plane(self): """test that timestamps are truncated for single plane data""" - self.sync_file._data = {'ophys_frames': np.array([.1, .2, .3])} - ts = OphysTimestamps.from_sync_file(sync_file=self.sync_file)\ - .validate(number_of_frames=2) - expected = np.array([.1, .2]) + self.sync_file._data = {"ophys_frames": np.array([0.1, 0.2, 0.3])} + ts = OphysTimestamps.from_sync_file(sync_file=self.sync_file).validate(number_of_frames=2) + expected = np.array([0.1, 0.2]) np.testing.assert_equal(ts.value, expected) def test_too_long_multi_plane(self): """test that exception raised when timestamps longer than # frames for multiplane data""" - self.sync_file._data = {'ophys_frames': np.array([.1, .2, .3])} + self.sync_file._data = {"ophys_frames": np.array([0.1, 0.2, 0.3])} with pytest.raises(RuntimeError): - OphysTimestampsMultiplane.from_sync_file(sync_file=self.sync_file, - group_count=2, - plane_group=0)\ - .validate(number_of_frames=1) + OphysTimestampsMultiplane.from_sync_file(sync_file=self.sync_file, group_count=2, plane_group=0).validate( + number_of_frames=1 + ) def test_too_short(self): """test when timestamps shorter than # frames""" - self.sync_file._data = {'ophys_frames': np.array([.1, .2, .3])} + self.sync_file._data = {"ophys_frames": np.array([0.1, 0.2, 0.3])} with pytest.raises(RuntimeError): - OphysTimestamps.from_sync_file(sync_file=self.sync_file)\ - .validate(number_of_frames=4) + OphysTimestamps.from_sync_file(sync_file=self.sync_file).validate(number_of_frames=4) def test_multiplane(self): """test timestamps properly extracted when multiplane""" - self.sync_file._data = {'ophys_frames': np.array([.1, .2, .3, .4])} - ts = OphysTimestampsMultiplane.from_sync_file(sync_file=self.sync_file, - group_count=2, - plane_group=0)\ - .validate(number_of_frames=2) - expected = np.array([.1, .3]) + self.sync_file._data = {"ophys_frames": np.array([0.1, 0.2, 0.3, 0.4])} + ts = OphysTimestampsMultiplane.from_sync_file(sync_file=self.sync_file, group_count=2, plane_group=0).validate( + number_of_frames=2 + ) + expected = np.array([0.1, 0.3]) np.testing.assert_equal(ts.value, expected) @pytest.mark.parametrize( @@ -72,20 +67,18 @@ def test_multiplane(self): # last (np.array([0, 1, 0, 1, 0, 1, 0, 1]), 1, 2, np.ones(4)), # only one group - (np.ones(10), 0, 1, np.ones(10)) - ] + (np.ones(10), 0, 1, np.ones(10)), + ], ) - def test_process_ophys_plane_timestamps( - self, timestamps, plane_group, group_count, expected): + def test_process_ophys_plane_timestamps(self, timestamps, plane_group, group_count, expected): """Various test cases""" - self.sync_file._data = {'ophys_frames': timestamps} - number_of_frames = len(timestamps) if group_count == 0 else \ - len(timestamps) / group_count + self.sync_file._data = {"ophys_frames": timestamps} + number_of_frames = len(timestamps) if group_count == 0 else len(timestamps) / group_count if group_count == 0: ts = OphysTimestamps.from_sync_file(sync_file=self.sync_file) else: ts = OphysTimestampsMultiplane.from_sync_file( - sync_file=self.sync_file, group_count=group_count, - plane_group=plane_group) + sync_file=self.sync_file, group_count=group_count, plane_group=plane_group + ) ts = ts.validate(number_of_frames=number_of_frames) np.testing.assert_array_equal(expected, ts.value) diff --git a/allensdk/test/brain_observatory/behavior/data_objects/test_projections.py b/allensdk/test/brain_observatory/behavior/data_objects/test_projections.py index e9a2cf6aee..43deeb3856 100644 --- a/allensdk/test/brain_observatory/behavior/data_objects/test_projections.py +++ b/allensdk/test/brain_observatory/behavior/data_objects/test_projections.py @@ -5,8 +5,7 @@ import pynwb import pytest -from allensdk.brain_observatory.behavior.data_objects.projections import \ - Projections +from allensdk.brain_observatory.behavior.data_objects.projections import Projections from allensdk.core.auth_config import LIMS_DB_CREDENTIAL_MAP from allensdk.internal.api import db_connection_creator @@ -17,30 +16,28 @@ def setup_class(cls): cls.ophys_experiment_id = 994278291 dir = Path(__file__).parent.resolve() - test_data_dir = dir / 'test_data' + test_data_dir = dir / "test_data" cls.expected_max = Projections._from_filepath( - filepath=str(test_data_dir / 'max_projection.png'), - pixel_size=.78125) + filepath=str(test_data_dir / "max_projection.png"), pixel_size=0.78125 + ) cls.expected_avg = Projections._from_filepath( - filepath=str(test_data_dir / 'avg_projection.png'), - pixel_size=.78125) + filepath=str(test_data_dir / "avg_projection.png"), pixel_size=0.78125 + ) def setup_method(self, method): - marks = getattr(method, 'pytestmark', None) + marks = getattr(method, "pytestmark", None) if marks: marks = [m.name for m in marks] # Will only create a dbconn if the test requires_bamboo - if 'requires_bamboo' in marks: - self.dbconn = db_connection_creator( - fallback_credentials=LIMS_DB_CREDENTIAL_MAP) + if "requires_bamboo" in marks: + self.dbconn = db_connection_creator(fallback_credentials=LIMS_DB_CREDENTIAL_MAP) @pytest.mark.requires_bamboo def test_from_lims(self): - projections = Projections.from_lims( - ophys_experiment_id=self.ophys_experiment_id, lims_db=self.dbconn) + projections = Projections.from_lims(ophys_experiment_id=self.ophys_experiment_id, lims_db=self.dbconn) assert projections.max_projection == self.expected_max assert projections.avg_projection == self.expected_avg @@ -50,23 +47,22 @@ class TestJson: @classmethod def setup_class(cls): dir = Path(__file__).parent.resolve() - test_data_dir = dir / 'test_data' - with open(test_data_dir / 'test_input.json') as f: + test_data_dir = dir / "test_data" + with open(test_data_dir / "test_input.json") as f: dict_repr = json.load(f) - dict_repr = dict_repr['session_data'] - dict_repr['max_projection_file'] = test_data_dir / \ - dict_repr['max_projection_file'] - dict_repr['average_intensity_projection_image_file'] = \ - test_data_dir / \ - dict_repr['average_intensity_projection_image_file'] + dict_repr = dict_repr["session_data"] + dict_repr["max_projection_file"] = test_data_dir / dict_repr["max_projection_file"] + dict_repr["average_intensity_projection_image_file"] = ( + test_data_dir / dict_repr["average_intensity_projection_image_file"] + ) cls.expected_max = Projections._from_filepath( - filepath=str(test_data_dir / 'max_projection.png'), - pixel_size=.78125) + filepath=str(test_data_dir / "max_projection.png"), pixel_size=0.78125 + ) cls.expected_avg = Projections._from_filepath( - filepath=str(test_data_dir / 'avg_projection.png'), - pixel_size=.78125) + filepath=str(test_data_dir / "avg_projection.png"), pixel_size=0.78125 + ) cls.dict_repr = dict_repr @@ -82,25 +78,19 @@ class TestNWB: def setup_class(cls): tj = TestJson() tj.setup_class() - cls.projections = Projections.from_json( - dict_repr=tj.dict_repr) + cls.projections = Projections.from_json(dict_repr=tj.dict_repr) def setup_method(self, method): self.nwbfile = pynwb.NWBFile( - session_description='asession', - identifier='1234', - session_start_time=datetime.now() + session_description="asession", identifier="1234", session_start_time=datetime.now() ) - @pytest.mark.parametrize('roundtrip', [True, False]) - def test_read_write_nwb(self, roundtrip, - data_object_roundtrip_fixture): + @pytest.mark.parametrize("roundtrip", [True, False]) + def test_read_write_nwb(self, roundtrip, data_object_roundtrip_fixture): self.projections.to_nwb(nwbfile=self.nwbfile) if roundtrip: - obt = data_object_roundtrip_fixture( - nwbfile=self.nwbfile, - data_object_cls=Projections) + obt = data_object_roundtrip_fixture(nwbfile=self.nwbfile, data_object_cls=Projections) else: obt = self.projections.from_nwb(nwbfile=self.nwbfile) diff --git a/allensdk/test/brain_observatory/behavior/data_objects/test_rewards.py b/allensdk/test/brain_observatory/behavior/data_objects/test_rewards.py index faabe2f25e..e3ea9a5734 100644 --- a/allensdk/test/brain_observatory/behavior/data_objects/test_rewards.py +++ b/allensdk/test/brain_observatory/behavior/data_objects/test_rewards.py @@ -9,8 +9,7 @@ from allensdk.brain_observatory.behavior.data_files import BehaviorStimulusFile from allensdk.brain_observatory.behavior.data_objects import StimulusTimestamps from allensdk.brain_observatory.behavior.data_objects.rewards import Rewards -from allensdk.test.brain_observatory.behavior.data_objects.lims_util import \ - LimsTest +from allensdk.test.brain_observatory.behavior.data_objects.lims_util import LimsTest class TestFromBehaviorStimulusFile(LimsTest): @@ -19,20 +18,16 @@ def setup_class(cls): cls.behavior_session_id = 994174745 dir = Path(__file__).parent.resolve() - test_data_dir = dir / 'test_data' + test_data_dir = dir / "test_data" - expected = pd.read_pickle(str(test_data_dir / 'rewards.pkl')) + expected = pd.read_pickle(str(test_data_dir / "rewards.pkl")) cls.expected = Rewards(rewards=expected) @pytest.mark.requires_bamboo def test_from_stimulus_file(self): - stimulus_file = BehaviorStimulusFile.from_lims( - behavior_session_id=self.behavior_session_id, db=self.dbconn) - timestamps = StimulusTimestamps.from_stimulus_file( - stimulus_file=stimulus_file, - monitor_delay=0.0) - rewards = Rewards.from_stimulus_file(stimulus_file=stimulus_file, - stimulus_timestamps=timestamps) + stimulus_file = BehaviorStimulusFile.from_lims(behavior_session_id=self.behavior_session_id, db=self.dbconn) + timestamps = StimulusTimestamps.from_stimulus_file(stimulus_file=stimulus_file, monitor_delay=0.0) + rewards = Rewards.from_stimulus_file(stimulus_file=stimulus_file, stimulus_timestamps=timestamps) assert rewards == self.expected def test_monitor_delay_error(self): @@ -40,14 +35,9 @@ def test_monitor_delay_error(self): Test that an error is raised if Rewards are instantiated with non-zero monitor delay """ - timestamps = StimulusTimestamps( - np.arange(10), - 0.1) - with pytest.raises(RuntimeError, - match="monitor_delay should be zero"): - Rewards.from_stimulus_file( - stimulus_file=None, - stimulus_timestamps=timestamps) + timestamps = StimulusTimestamps(np.arange(10), 0.1) + with pytest.raises(RuntimeError, match="monitor_delay should be zero"): + Rewards.from_stimulus_file(stimulus_file=None, stimulus_timestamps=timestamps) def test_from_stimulus_file2(self, tmpdir): """ @@ -59,37 +49,26 @@ def test_from_stimulus_file2(self, tmpdir): def _create_dummy_stimulus_file(): trial_log = [ - {'rewards': [(0.001, -1.0, 4)], - 'trial_params': {'auto_reward': True}}, - {'rewards': []}, - {'rewards': [(0.002, -1.0, 10)], - 'trial_params': {'auto_reward': False}} + {"rewards": [(0.001, -1.0, 4)], "trial_params": {"auto_reward": True}}, + {"rewards": []}, + {"rewards": [(0.002, -1.0, 10)], "trial_params": {"auto_reward": False}}, ] data = { - 'items': { - 'behavior': { - 'trial_log': trial_log - } - }, + "items": {"behavior": {"trial_log": trial_log}}, } - tmp_path = tmpdir / 'stimulus_file.pkl' - with open(tmp_path, 'wb') as f: + tmp_path = tmpdir / "stimulus_file.pkl" + with open(tmp_path, "wb") as f: pickle.dump(data, f) f.seek(0) return tmp_path stimulus_filepath = _create_dummy_stimulus_file() - stimulus_file = BehaviorStimulusFile.from_json( - dict_repr={'behavior_stimulus_file': str(stimulus_filepath)}) - timestamps = StimulusTimestamps(timestamps=np.arange(0, 2.0, 0.01), - monitor_delay=0.0) - rewards = Rewards.from_stimulus_file(stimulus_file=stimulus_file, - stimulus_timestamps=timestamps) - - expected_dict = {'volume': [0.001, 0.002], - 'timestamps': [0.04, 0.1], - 'auto_rewarded': [True, False]} + stimulus_file = BehaviorStimulusFile.from_json(dict_repr={"behavior_stimulus_file": str(stimulus_filepath)}) + timestamps = StimulusTimestamps(timestamps=np.arange(0, 2.0, 0.01), monitor_delay=0.0) + rewards = Rewards.from_stimulus_file(stimulus_file=stimulus_file, stimulus_timestamps=timestamps) + + expected_dict = {"volume": [0.001, 0.002], "timestamps": [0.04, 0.1], "auto_rewarded": [True, False]} expected_df = pd.DataFrame(expected_dict) expected_df = expected_df assert expected_df.equals(rewards.value) @@ -99,27 +78,22 @@ class TestNWB: @classmethod def setup_class(cls): dir = Path(__file__).parent.resolve() - test_data_dir = dir / 'test_data' + test_data_dir = dir / "test_data" - rewards = pd.read_pickle(str(test_data_dir / 'rewards.pkl')) + rewards = pd.read_pickle(str(test_data_dir / "rewards.pkl")) cls.rewards = Rewards(rewards=rewards) def setup_method(self, method): self.nwbfile = pynwb.NWBFile( - session_description='asession', - identifier='1234', - session_start_time=datetime.now() + session_description="asession", identifier="1234", session_start_time=datetime.now() ) - @pytest.mark.parametrize('roundtrip', [True, False]) - def test_read_write_nwb(self, roundtrip, - data_object_roundtrip_fixture): + @pytest.mark.parametrize("roundtrip", [True, False]) + def test_read_write_nwb(self, roundtrip, data_object_roundtrip_fixture): self.rewards.to_nwb(nwbfile=self.nwbfile) if roundtrip: - obt = data_object_roundtrip_fixture( - nwbfile=self.nwbfile, - data_object_cls=Rewards) + obt = data_object_roundtrip_fixture(nwbfile=self.nwbfile, data_object_cls=Rewards) else: obt = self.rewards.from_nwb(nwbfile=self.nwbfile) diff --git a/allensdk/test/brain_observatory/behavior/data_objects/test_stimuli.py b/allensdk/test/brain_observatory/behavior/data_objects/test_stimuli.py index 0873f53264..daf2a437cf 100644 --- a/allensdk/test/brain_observatory/behavior/data_objects/test_stimuli.py +++ b/allensdk/test/brain_observatory/behavior/data_objects/test_stimuli.py @@ -31,16 +31,10 @@ def setup_class(cls): dir = Path(__file__).parent.resolve() test_data_dir = dir / "test_data" - presentations = pd.read_pickle( - str(test_data_dir / "presentations.pkl") - ) + presentations = pd.read_pickle(str(test_data_dir / "presentations.pkl")) templates = pd.read_pickle(str(test_data_dir / "templates.pkl")) - cls.expected_presentations = StimulusPresentations( - presentations=presentations - ) - cls.expected_templates = Templates( - templates={templates.image_set_name: templates} - ) + cls.expected_presentations = StimulusPresentations(presentations=presentations) + cls.expected_templates = Templates(templates={templates.image_set_name: templates}) @pytest.mark.requires_bamboo def test_from_stimulus_file(self): @@ -58,12 +52,8 @@ def data(self): } ) - stimulus_file = BehaviorStimulusFile.from_lims( - behavior_session_id=self.behavior_session_id, db=self.dbconn - ) - stimulus_timestamps = StimulusTimestamps.from_stimulus_file( - stimulus_file=stimulus_file, monitor_delay=0.0 - ) + stimulus_file = BehaviorStimulusFile.from_lims(behavior_session_id=self.behavior_session_id, db=self.dbconn) + stimulus_timestamps = StimulusTimestamps.from_stimulus_file(stimulus_file=stimulus_file, monitor_delay=0.0) stimuli = Stimuli.from_stimulus_file( stimulus_file=stimulus_file, stimulus_timestamps=stimulus_timestamps, @@ -82,17 +72,13 @@ def presentations_fixture(behavior_ecephys_session_config_fixture): """ obj = Presentations.from_path( path=behavior_ecephys_session_config_fixture["stim_table_file"], - behavior_session_id=( - behavior_ecephys_session_config_fixture["behavior_session_id"] - ), + behavior_session_id=(behavior_ecephys_session_config_fixture["behavior_session_id"]), ) return obj @pytest.mark.requires_bamboo -@pytest.mark.parametrize( - "roundtrip, add_is_change", ([True, False], [True, False]) -) +@pytest.mark.parametrize("roundtrip, add_is_change", ([True, False], [True, False])) def test_read_write_nwb( roundtrip, add_is_change, @@ -104,12 +90,8 @@ def test_read_write_nwb( nwbfile = helper_functions.create_blank_nwb_file() # Need to write stimulus timestamps first - bsf = BehaviorStimulusFile.from_json( - dict_repr=behavior_ecephys_session_config_fixture - ) - ts = StimulusTimestamps.from_stimulus_file( - stimulus_file=bsf, monitor_delay=0.0 - ) + bsf = BehaviorStimulusFile.from_json(dict_repr=behavior_ecephys_session_config_fixture) + ts = StimulusTimestamps.from_stimulus_file(stimulus_file=bsf, monitor_delay=0.0) ts.to_nwb(nwbfile=nwbfile) presentations_fixture.to_nwb(nwbfile=nwbfile) @@ -121,9 +103,7 @@ def test_read_write_nwb( add_is_change=add_is_change, ) else: - obt = Presentations.from_nwb( - nwbfile=nwbfile, add_is_change=add_is_change - ) + obt = Presentations.from_nwb(nwbfile=nwbfile, add_is_change=add_is_change) assert obt == presentations_fixture @@ -134,9 +114,7 @@ def setup_class(cls): dir = Path(__file__).parent.resolve() cls.test_data_dir = dir / "test_data" - presentations = pd.read_pickle( - str(cls.test_data_dir / "presentations.pkl") - ) + presentations = pd.read_pickle(str(cls.test_data_dir / "presentations.pkl")) templates = pd.read_pickle(str(cls.test_data_dir / "templates.pkl")) presentations = presentations.drop("is_change", axis=1) presentations = presentations.drop("flashes_since_change", axis=1) @@ -152,12 +130,8 @@ def setup_method(self, method): ) # Need to write stimulus timestamps first - bsf = BehaviorStimulusFile( - filepath=self.test_data_dir / "behavior_stimulus_file.pkl" - ) - ts = StimulusTimestamps.from_stimulus_file( - stimulus_file=bsf, monitor_delay=0.0 - ) + bsf = BehaviorStimulusFile(filepath=self.test_data_dir / "behavior_stimulus_file.pkl") + ts = StimulusTimestamps.from_stimulus_file(stimulus_file=bsf, monitor_delay=0.0) ts.to_nwb(nwbfile=self.nwbfile) @pytest.mark.parametrize("roundtrip", [True, False]) @@ -165,18 +139,14 @@ def test_read_write_nwb(self, roundtrip, data_object_roundtrip_fixture): self.stimuli.to_nwb(nwbfile=self.nwbfile) if roundtrip: - obt = data_object_roundtrip_fixture( - nwbfile=self.nwbfile, data_object_cls=Stimuli - ) + obt = data_object_roundtrip_fixture(nwbfile=self.nwbfile, data_object_cls=Stimuli) else: obt = Stimuli.from_nwb(nwbfile=self.nwbfile) # is_change different due to limit_to_images. flashes_since_change # also relies on this column so we ommit that. obt.presentations.value.drop("is_change", axis=1, inplace=True) - obt.presentations.value.drop( - "flashes_since_change", axis=1, inplace=True - ) + obt.presentations.value.drop("flashes_since_change", axis=1, inplace=True) assert obt == self.stimuli @@ -213,11 +183,7 @@ def test_read_write_nwb(self, roundtrip, data_object_roundtrip_fixture): def test_set_omitted_stop_time(stimulus_table, expected_table_data): stimulus_table = pd.DataFrame.from_dict(data=stimulus_table) expected_table = pd.DataFrame.from_dict(data=expected_table_data) - stimulus_table = ( - StimulusPresentations._fill_missing_values_for_omitted_flashes( - df=stimulus_table - ) - ) + stimulus_table = StimulusPresentations._fill_missing_values_for_omitted_flashes(df=stimulus_table) assert stimulus_table.equals(expected_table) @@ -227,9 +193,7 @@ def stimulus_templates_fixture(behavior_ecephys_session_config_fixture): Return a Templates object """ - sf = BehaviorStimulusFile.from_json( - dict_repr=behavior_ecephys_session_config_fixture - ) + sf = BehaviorStimulusFile.from_json(dict_repr=behavior_ecephys_session_config_fixture) obj = Templates.from_stimulus_file(stimulus_file=sf) return obj @@ -248,14 +212,10 @@ def test_read_write_nwb_no_image_index( nwbfile = helper_functions.create_blank_nwb_file() - stimulus_templates_fixture.to_nwb( - nwbfile=nwbfile, stimulus_presentations=presentations_fixture - ) + stimulus_templates_fixture.to_nwb(nwbfile=nwbfile, stimulus_presentations=presentations_fixture) if roundtrip: - obt = data_object_roundtrip_fixture( - nwbfile=nwbfile, data_object_cls=Templates - ) + obt = data_object_roundtrip_fixture(nwbfile=nwbfile, data_object_cls=Templates) else: obt = Templates.from_nwb(nwbfile=nwbfile) diff --git a/allensdk/test/brain_observatory/behavior/data_objects/test_stimulus_presentations.py b/allensdk/test/brain_observatory/behavior/data_objects/test_stimulus_presentations.py index c3deace54e..85d0063940 100644 --- a/allensdk/test/brain_observatory/behavior/data_objects/test_stimulus_presentations.py +++ b/allensdk/test/brain_observatory/behavior/data_objects/test_stimulus_presentations.py @@ -9,24 +9,19 @@ from allensdk.brain_observatory.behavior.data_files import BehaviorStimulusFile from allensdk.brain_observatory.behavior.data_objects import StimulusTimestamps -from allensdk.brain_observatory.behavior.data_objects.stimuli\ - .fingerprint_stimulus import \ - FingerprintStimulus -from allensdk.brain_observatory.behavior.data_objects.stimuli.presentations \ - import get_spontaneous_block_indices, Presentations +from allensdk.brain_observatory.behavior.data_objects.stimuli.fingerprint_stimulus import FingerprintStimulus +from allensdk.brain_observatory.behavior.data_objects.stimuli.presentations import ( + get_spontaneous_block_indices, + Presentations, +) from allensdk.internal.brain_observatory.mouse import Mouse -@pytest.mark.parametrize('stimulus_blocks, expected', [ - ([0, 2, 3], [1]), - ([0, 2, 4], [1, 3]), - ([0, 1, 2], []) -]) +@pytest.mark.parametrize("stimulus_blocks, expected", [([0, 2, 3], [1]), ([0, 2, 4], [1, 3]), ([0, 1, 2], [])]) def test_get_spontaneous_block_indices(stimulus_blocks, expected): - stimulus_blocks = np.array(stimulus_blocks, dtype='int') - expected = np.array(expected, dtype='int') - obtained = get_spontaneous_block_indices( - stimulus_blocks=stimulus_blocks) + stimulus_blocks = np.array(stimulus_blocks, dtype="int") + expected = np.array(expected, dtype="int") + obtained = get_spontaneous_block_indices(stimulus_blocks=stimulus_blocks) assert np.array_equal(obtained, expected) @@ -34,24 +29,21 @@ class TestFingerprintStimulus: @classmethod def setup_class(cls): stim_file = { - 'items': { - 'behavior': { - 'items': { - 'fingerprint': { - 'static_stimulus': { - 'runs': 2, + "items": { + "behavior": { + "items": { + "fingerprint": { + "static_stimulus": { + "runs": 2, # Movie is 2 frames long, repeats 2 times - 'sweep_frames': np.array([(0, 1), (2, 3), - (4, 5), (6, 7)]), + "sweep_frames": np.array([(0, 1), (2, 3), (4, 5), (6, 7)]), # 2 frames of gray screen followed by 2 movie # frames - 'frame_list': np.array([-1, -1, 0, 1]) + "frame_list": np.array([-1, -1, 0, 1]), }, # 2 gray screen on frame 5, 6 followed by 4 frames # of movie that last 2 monitor frames each - 'frame_indices': np.array([5, 6] + - list(range( - 7, 7 + 4 * 2))) + "frame_indices": np.array([5, 6] + list(range(7, 7 + 4 * 2))), } } } @@ -59,20 +51,13 @@ def setup_class(cls): } tmpdir = tempfile.TemporaryDirectory() cls.tmpdir = tmpdir - with open(Path(tmpdir.name) / 'behavior_stimulus.pkl', 'wb') as f: + with open(Path(tmpdir.name) / "behavior_stimulus.pkl", "wb") as f: pickle.dump(stim_file, f) - cls.stimulus_file = BehaviorStimulusFile( - filepath=Path(tmpdir.name) / 'behavior_stimulus.pkl' - ) - cls.stimulus_presentations_table = pd.DataFrame({ - 'stimulus_block': [0] - }) - cls.stimulus_timestamps = StimulusTimestamps( - timestamps=np.arange(0, 20), - monitor_delay=0.0 - ) + cls.stimulus_file = BehaviorStimulusFile(filepath=Path(tmpdir.name) / "behavior_stimulus.pkl") + cls.stimulus_presentations_table = pd.DataFrame({"stimulus_block": [0]}) + cls.stimulus_timestamps = StimulusTimestamps(timestamps=np.arange(0, 20), monitor_delay=0.0) dir = Path(__file__).parent.resolve() - cls.test_data_dir = dir / 'test_data' + cls.test_data_dir = dir / "test_data" def teardown(self): self.tmpdir.cleanup() @@ -82,10 +67,9 @@ def test_fingerprint_stimulus(self): obt = FingerprintStimulus.from_stimulus_file( stimulus_presentations=self.stimulus_presentations_table, stimulus_file=self.stimulus_file, - stimulus_timestamps=self.stimulus_timestamps - + stimulus_timestamps=self.stimulus_timestamps, ) - with open(self.test_data_dir / 'fingerprint_stimulus.pkl', 'rb') as f: + with open(self.test_data_dir / "fingerprint_stimulus.pkl", "rb") as f: expected = pickle.load(f) obt = obt.table[sorted([c for c in obt.table])] @@ -95,61 +79,56 @@ def test_fingerprint_stimulus(self): def test_add_fingerprint_stimulus(self): """Checks that fingerprint block and spontaneous block are correctly added to table""" - with open(self.test_data_dir / 'fingerprint_stimulus.pkl', 'rb') as f: + with open(self.test_data_dir / "fingerprint_stimulus.pkl", "rb") as f: fingerprint_stim = pickle.load(f) obt = Presentations._add_fingerprint_stimulus( stimulus_presentations=self.stimulus_presentations_table, stimulus_file=self.stimulus_file, - stimulus_timestamps=self.stimulus_timestamps + stimulus_timestamps=self.stimulus_timestamps, ) # there should a block for 0, 1 (spontaneous) and 2 (fingerprint) - assert sorted(obt['stimulus_block'].unique().tolist()) == [0, 1, 2] + assert sorted(obt["stimulus_block"].unique().tolist()) == [0, 1, 2] expected_num_spontaneous_rows = 1 expected_num_fingerprint_rows = fingerprint_stim.shape[0] - assert obt[obt['stimulus_name'] == 'spontaneous'].shape[0] == \ - expected_num_spontaneous_rows + assert obt[obt["stimulus_name"] == "spontaneous"].shape[0] == expected_num_spontaneous_rows - assert obt.shape[0] == \ - self.stimulus_presentations_table.shape[0] + \ - expected_num_spontaneous_rows + \ - expected_num_fingerprint_rows + assert ( + obt.shape[0] + == self.stimulus_presentations_table.shape[0] + + expected_num_spontaneous_rows + + expected_num_fingerprint_rows + ) class TestStimulusPresentations: - @pytest.mark.parametrize('image_names, expected', [ - (['A', 'B'], {'A': False, 'B': False}), - (['A', 'C'], {'A': False, 'C': True}), - (['C', 'omitted', np.nan, None], {'C': True}) - ]) + @pytest.mark.parametrize( + "image_names, expected", + [ + (["A", "B"], {"A": False, "B": False}), + (["A", "C"], {"A": False, "C": True}), + (["C", "omitted", np.nan, None], {"C": True}), + ], + ) def test_get_is_image_novel(self, image_names, expected): - with patch.object(Mouse, attribute='from_behavior_session_id', - wraps=lambda behavior_session_id: Mouse('1')): - with patch.object( - Mouse, attribute='get_images_shown', - wraps=lambda up_to_behavior_session_id: {'A', 'B'}): - obt = Presentations._get_is_image_novel( - image_names=image_names, behavior_session_id=1) + with patch.object(Mouse, attribute="from_behavior_session_id", wraps=lambda behavior_session_id: Mouse("1")): + with patch.object(Mouse, attribute="get_images_shown", wraps=lambda up_to_behavior_session_id: {"A", "B"}): + obt = Presentations._get_is_image_novel(image_names=image_names, behavior_session_id=1) assert obt == expected def test_add_is_image_novel(self): - stimulus_presentations = pd.DataFrame({ - 'image_name': ['A', 'B', 'C', np.nan]}) - is_image_novel = { - 'A': False, - 'B': False, - 'C': True - } - with patch.object(Presentations, attribute='_get_is_image_novel', - wraps=lambda image_names, behavior_session_id: - is_image_novel): - Presentations._add_is_image_novel( - stimulus_presentations=stimulus_presentations, - behavior_session_id=1 + stimulus_presentations = pd.DataFrame({"image_name": ["A", "B", "C", np.nan]}) + is_image_novel = {"A": False, "B": False, "C": True} + with patch.object( + Presentations, + attribute="_get_is_image_novel", + wraps=lambda image_names, behavior_session_id: is_image_novel, + ): + Presentations._add_is_image_novel(stimulus_presentations=stimulus_presentations, behavior_session_id=1) + assert ( + stimulus_presentations["is_image_novel"].tolist() + == list(is_image_novel.values()) + # due to last stimulus which is not an image + + [np.nan] ) - assert (stimulus_presentations['is_image_novel'].tolist() == - list(is_image_novel.values()) - # due to last stimulus which is not an image - + [np.nan] - ) diff --git a/allensdk/test/brain_observatory/behavior/data_objects/test_task_parameters.py b/allensdk/test/brain_observatory/behavior/data_objects/test_task_parameters.py index fa0924c44d..5dfe2c96f7 100644 --- a/allensdk/test/brain_observatory/behavior/data_objects/test_task_parameters.py +++ b/allensdk/test/brain_observatory/behavior/data_objects/test_task_parameters.py @@ -6,10 +6,8 @@ import pytest from allensdk.brain_observatory.behavior.data_files import BehaviorStimulusFile -from allensdk.brain_observatory.behavior.data_objects.task_parameters import \ - TaskParameters -from allensdk.test.brain_observatory.behavior.data_objects.lims_util import \ - LimsTest +from allensdk.brain_observatory.behavior.data_objects.task_parameters import TaskParameters +from allensdk.test.brain_observatory.behavior.data_objects.lims_util import LimsTest class TestFromBehaviorStimulusFile(LimsTest): @@ -18,16 +16,15 @@ def setup_class(cls): cls.behavior_session_id = 994174745 dir = Path(__file__).parent.resolve() - test_data_dir = dir / 'test_data' + test_data_dir = dir / "test_data" - with open(test_data_dir / 'task_parameters.json') as f: + with open(test_data_dir / "task_parameters.json") as f: tp = json.load(f) cls.expected = TaskParameters(**tp) @pytest.mark.requires_bamboo def test_from_stimulus_file(self): - stimulus_file = BehaviorStimulusFile.from_lims( - behavior_session_id=self.behavior_session_id, db=self.dbconn) + stimulus_file = BehaviorStimulusFile.from_lims(behavior_session_id=self.behavior_session_id, db=self.dbconn) tp = TaskParameters.from_stimulus_file(stimulus_file=stimulus_file) assert tp == self.expected @@ -35,32 +32,26 @@ def test_from_stimulus_file(self): class TestNWB: def setup_method(self, method): self.nwbfile = pynwb.NWBFile( - session_description='asession', - identifier='1234', - session_start_time=datetime.now() + session_description="asession", identifier="1234", session_start_time=datetime.now() ) dir = Path(__file__).parent.resolve() - self.test_data_dir = dir / 'test_data' + self.test_data_dir = dir / "test_data" - with open(self.test_data_dir / 'task_parameters.json') as f: + with open(self.test_data_dir / "task_parameters.json") as f: tp = json.load(f) self.task_parameters = TaskParameters(**tp) - @pytest.mark.parametrize('is_stimulus_duration_sec_nan', [True, False]) - @pytest.mark.parametrize('roundtrip', [True, False]) - def test_read_write_nwb(self, roundtrip, - data_object_roundtrip_fixture, - is_stimulus_duration_sec_nan): + @pytest.mark.parametrize("is_stimulus_duration_sec_nan", [True, False]) + @pytest.mark.parametrize("roundtrip", [True, False]) + def test_read_write_nwb(self, roundtrip, data_object_roundtrip_fixture, is_stimulus_duration_sec_nan): if is_stimulus_duration_sec_nan: self.task_parameters._stimulus_duration_sec = np.nan self.task_parameters.to_nwb(nwbfile=self.nwbfile) if roundtrip: - obt = data_object_roundtrip_fixture( - nwbfile=self.nwbfile, - data_object_cls=TaskParameters) + obt = data_object_roundtrip_fixture(nwbfile=self.nwbfile, data_object_cls=TaskParameters) else: obt = TaskParameters.from_nwb(nwbfile=self.nwbfile) diff --git a/allensdk/test/brain_observatory/behavior/data_objects/test_templates.py b/allensdk/test/brain_observatory/behavior/data_objects/test_templates.py index fe2eea4e7e..8d083bddc8 100644 --- a/allensdk/test/brain_observatory/behavior/data_objects/test_templates.py +++ b/allensdk/test/brain_observatory/behavior/data_objects/test_templates.py @@ -7,13 +7,14 @@ @pytest.mark.parametrize( "input_templates, message", - ([{"images1": None, "Other Input": None}, - "Found multiple image StimulusTemplates"], - [{"movie1": None, "movie2": None}, - "Found multiple fingerprint movie StimulusTemplate"], - [{"images1": None, "Other Input": None, - "movie1": None, "movie2": None}, - "Found multiple image StimulusTemplates"]) + ( + [{"images1": None, "Other Input": None}, "Found multiple image StimulusTemplates"], + [{"movie1": None, "movie2": None}, "Found multiple fingerprint movie StimulusTemplate"], + [ + {"images1": None, "Other Input": None, "movie1": None, "movie2": None}, + "Found multiple image StimulusTemplates", + ], + ), ) def test_template_to_many_inputs_exception(input_templates, message): """Test that we catch exceptions for too many input StimulusTemplates.""" diff --git a/allensdk/test/brain_observatory/behavior/data_objects/test_trial_obj.py b/allensdk/test/brain_observatory/behavior/data_objects/test_trial_obj.py index 1a95c2b019..b0ebbaf5e4 100644 --- a/allensdk/test/brain_observatory/behavior/data_objects/test_trial_obj.py +++ b/allensdk/test/brain_observatory/behavior/data_objects/test_trial_obj.py @@ -1,136 +1,117 @@ import pytest import numpy as np import pandas as pd -from allensdk.brain_observatory.behavior.data_objects import ( - StimulusTimestamps) -from allensdk.brain_observatory.behavior.data_objects.trials.trial import ( - Trial) - - -@pytest.mark.parametrize("behavior_stimuli_data_fixture, trial, expected", - [({}, - {'events': [(None, None, None, 0)], - 'stimulus_changes': []}, - {'initial_image_name': 'gratings_90', - 'change_image_name': 'gratings_90'}), - ({}, - {'events': [(None, None, None, 0)], - 'stimulus_changes':[ - (('horizontal', 90), - ('vertical', 180), - None, None)]}, - {'initial_image_name': 'gratings_90', - 'change_image_name': 'gratings_180'}), - ({"images_set_log": [('Image', 'im065', 5, 0)], - "grating_set_log": [("Ori", 270, 15, 6)]}, - {'events': [(None, None, None, 5)], - 'stimulus_changes': - [(('im065', 'im065'), - ('im057', 'im057'), - None, None)]}, - {'initial_image_name': 'im065', - 'change_image_name': 'im057'})], - indirect=['behavior_stimuli_data_fixture']) -def test_get_trial_image_names(behavior_stimuli_data_fixture, trial, - expected): +from allensdk.brain_observatory.behavior.data_objects import StimulusTimestamps +from allensdk.brain_observatory.behavior.data_objects.trials.trial import Trial + +@pytest.mark.parametrize( + "behavior_stimuli_data_fixture, trial, expected", + [ + ( + {}, + {"events": [(None, None, None, 0)], "stimulus_changes": []}, + {"initial_image_name": "gratings_90", "change_image_name": "gratings_90"}, + ), + ( + {}, + { + "events": [(None, None, None, 0)], + "stimulus_changes": [(("horizontal", 90), ("vertical", 180), None, None)], + }, + {"initial_image_name": "gratings_90", "change_image_name": "gratings_180"}, + ), + ( + {"images_set_log": [("Image", "im065", 5, 0)], "grating_set_log": [("Ori", 270, 15, 6)]}, + { + "events": [(None, None, None, 5)], + "stimulus_changes": [(("im065", "im065"), ("im057", "im057"), None, None)], + }, + {"initial_image_name": "im065", "change_image_name": "im057"}, + ), + ], + indirect=["behavior_stimuli_data_fixture"], +) +def test_get_trial_image_names(behavior_stimuli_data_fixture, trial, expected): class DummyTrial(Trial): @staticmethod - def _calculate_trial_end(trial_end, - behavior_stimulus_file): + def _calculate_trial_end(trial_end, behavior_stimulus_file): return -999 - def _match_to_sync_timestamps( - self, - raw_stimulus_timestamps, - licks, - rewards, - stimuli): + def _match_to_sync_timestamps(self, raw_stimulus_timestamps, licks, rewards, stimuli): return dict() - stimuli = behavior_stimuli_data_fixture['items']['behavior']['stimuli'] - - trial_obj = DummyTrial(trial=trial, - start=None, - end=None, - behavior_stimulus_file=None, - index=None, - stimulus_timestamps=None, - licks=None, - rewards=None, - stimuli=stimuli) + stimuli = behavior_stimuli_data_fixture["items"]["behavior"]["stimuli"] + + trial_obj = DummyTrial( + trial=trial, + start=None, + end=None, + behavior_stimulus_file=None, + index=None, + stimulus_timestamps=None, + licks=None, + rewards=None, + stimuli=stimuli, + ) trial_image_names = trial_obj._get_trial_image_names(stimuli) assert trial_image_names == expected -@pytest.mark.parametrize("behavior_stimuli_data_fixture, start_frame," - "expected", - [({}, 0, ('grating', 90, 'gratings_90')), - ({ - "images_set_log": [ - ('Image', 'im065', 5, 0)], - "grating_set_log": [ - ("Ori", 270, 15, 6)]}, 0, - ('images', 'im065', 'im065')), - ({ - "images_set_log": [], - "grating_set_log": [] - }, 0, ('', '', ''))], - indirect=['behavior_stimuli_data_fixture']) -def test_resolve_initial_image(behavior_stimuli_data_fixture, start_frame, - expected): - +@pytest.mark.parametrize( + "behavior_stimuli_data_fixture, start_frame,expected", + [ + ({}, 0, ("grating", 90, "gratings_90")), + ( + {"images_set_log": [("Image", "im065", 5, 0)], "grating_set_log": [("Ori", 270, 15, 6)]}, + 0, + ("images", "im065", "im065"), + ), + ({"images_set_log": [], "grating_set_log": []}, 0, ("", "", "")), + ], + indirect=["behavior_stimuli_data_fixture"], +) +def test_resolve_initial_image(behavior_stimuli_data_fixture, start_frame, expected): class DummyTrial(Trial): @staticmethod - def _calculate_trial_end(trial_end, - behavior_stimulus_file): + def _calculate_trial_end(trial_end, behavior_stimulus_file): return -999 - def _match_to_sync_timestamps( - self, - raw_stimulus_timestamps, - licks, - rewards, - stimuli): + def _match_to_sync_timestamps(self, raw_stimulus_timestamps, licks, rewards, stimuli): return dict() - stimuli = behavior_stimuli_data_fixture['items']['behavior']['stimuli'] - - trial_obj = DummyTrial(trial=None, - start=None, - end=None, - behavior_stimulus_file=None, - index=None, - stimulus_timestamps=None, - licks=None, - rewards=None, - stimuli=stimuli) + stimuli = behavior_stimuli_data_fixture["items"]["behavior"]["stimuli"] + + trial_obj = DummyTrial( + trial=None, + start=None, + end=None, + behavior_stimulus_file=None, + index=None, + stimulus_timestamps=None, + licks=None, + rewards=None, + stimuli=stimuli, + ) resolved = trial_obj._resolve_initial_image(stimuli, start_frame) assert resolved == expected @pytest.mark.parametrize( - "go,catch,auto_rewarded,hit,false_alarm,aborted,errortext", [ - (False, False, False, True, False, True, - "'aborted' trials cannot be"), # aborted and hit - (False, False, False, False, True, True, - "'aborted' trials cannot be"), # aborted and false alarm - (False, False, True, False, False, True, - "'aborted' trials cannot be"), # aborted and auto_rewarded - (False, False, False, True, True, False, - "both `hit` and `false_alarm` cannot be True"), # hit and false alarm - (True, True, False, False, False, False, - "both `go` and `catch` cannot be True"), # go and catch + "go,catch,auto_rewarded,hit,false_alarm,aborted,errortext", + [ + (False, False, False, True, False, True, "'aborted' trials cannot be"), # aborted and hit + (False, False, False, False, True, True, "'aborted' trials cannot be"), # aborted and false alarm + (False, False, True, False, False, True, "'aborted' trials cannot be"), # aborted and auto_rewarded + (False, False, False, True, True, False, "both `hit` and `false_alarm` cannot be True"), # hit and false alarm + (True, True, False, False, False, False, "both `go` and `catch` cannot be True"), # go and catch # go and auto_rewarded - (True, False, True, False, False, False, - "both `go` and `auto_rewarded` cannot be True") - ] + (True, False, True, False, False, False, "both `go` and `auto_rewarded` cannot be True"), + ], ) -def test_get_trial_timing_exclusivity_assertions( - go, catch, auto_rewarded, hit, false_alarm, aborted, errortext): - +def test_get_trial_timing_exclusivity_assertions(go, catch, auto_rewarded, hit, false_alarm, aborted, errortext): # we just want to test a method of Trial, specifically test # errors that will be raised before any processing happens, # so we can define a child class with an empty __init__ @@ -139,37 +120,26 @@ def __init__(self): pass with pytest.raises(AssertionError) as e: - DummyTrial()._get_trial_timing( - None, None, go, catch, auto_rewarded, hit, false_alarm, - aborted) + DummyTrial()._get_trial_timing(None, None, go, catch, auto_rewarded, hit, false_alarm, aborted) assert errortext in str(e.value) def test_get_trial_timing(): event_dict = { - ('trial_start', ''): {'timestamp': 306.4785879253758, 'frame': 18075}, - ('initial_blank', 'enter'): {'timestamp': 306.47868008512637, - 'frame': 18075}, - ('initial_blank', 'exit'): {'timestamp': 306.4787637603285, - 'frame': 18075}, - ('pre_change', 'enter'): {'timestamp': 306.47883573270514, - 'frame': 18075}, - ('pre_change', 'exit'): {'timestamp': 306.4789062422286, - 'frame': 18075}, - ('stimulus_window', 'enter'): {'timestamp': 306.478977629464, - 'frame': 18075}, - ('stimulus_changed', ''): {'timestamp': 310.9827406729944, - 'frame': 18345}, - ('auto_reward', ''): {'timestamp': 310.98279450599154, 'frame': 18345}, - ('response_window', 'enter'): {'timestamp': 311.13223900212347, - 'frame': 18354}, - ('response_window', 'exit'): {'timestamp': 311.73284526699706, - 'frame': 18390}, - ('miss', ''): {'timestamp': 311.7330193465259, 'frame': 18390}, - ('stimulus_window', 'exit'): {'timestamp': 315.2356723770604, - 'frame': 18600}, - ('no_lick', 'exit'): {'timestamp': 315.23582480636213, 'frame': 18600}, - ('trial_end', ''): {'timestamp': 315.23590438557534, 'frame': 18600} + ("trial_start", ""): {"timestamp": 306.4785879253758, "frame": 18075}, + ("initial_blank", "enter"): {"timestamp": 306.47868008512637, "frame": 18075}, + ("initial_blank", "exit"): {"timestamp": 306.4787637603285, "frame": 18075}, + ("pre_change", "enter"): {"timestamp": 306.47883573270514, "frame": 18075}, + ("pre_change", "exit"): {"timestamp": 306.4789062422286, "frame": 18075}, + ("stimulus_window", "enter"): {"timestamp": 306.478977629464, "frame": 18075}, + ("stimulus_changed", ""): {"timestamp": 310.9827406729944, "frame": 18345}, + ("auto_reward", ""): {"timestamp": 310.98279450599154, "frame": 18345}, + ("response_window", "enter"): {"timestamp": 311.13223900212347, "frame": 18354}, + ("response_window", "exit"): {"timestamp": 311.73284526699706, "frame": 18390}, + ("miss", ""): {"timestamp": 311.7330193465259, "frame": 18390}, + ("stimulus_window", "exit"): {"timestamp": 315.2356723770604, "frame": 18600}, + ("no_lick", "exit"): {"timestamp": 315.23582480636213, "frame": 18600}, + ("trial_end", ""): {"timestamp": 315.23590438557534, "frame": 18600}, } licks = [ @@ -197,9 +167,7 @@ def test_get_trial_timing(): # need a mock Trial class that just populates # self._stimulus_timestamps - stimulus_timestamps = StimulusTimestamps( - timestamps=timestamps, - monitor_delay=monitor_delay) + stimulus_timestamps = StimulusTimestamps(timestamps=timestamps, monitor_delay=monitor_delay) class DummyTrial(Trial): def __init__(self, timestamps): @@ -208,28 +176,21 @@ def __init__(self, timestamps): this_trial = DummyTrial(timestamps=stimulus_timestamps) result = this_trial._get_trial_timing( - event_dict, - licks, - go=False, - catch=False, - auto_rewarded=True, - hit=False, - false_alarm=False, - aborted=False + event_dict, licks, go=False, catch=False, auto_rewarded=True, hit=False, false_alarm=False, aborted=False ) expected_result = { - 'start_time': 306.4785879253758, - 'stop_time': 315.23590438557534, - 'trial_length': 8.757316460199547, - 'response_time': 312.24876, - 'change_frame': 18345, - 'change_time': 311.78086, - 'response_latency': 0.4678999999999769 + "start_time": 306.4785879253758, + "stop_time": 315.23590438557534, + "trial_length": 8.757316460199547, + "response_time": 312.24876, + "change_frame": 18345, + "change_time": 311.78086, + "response_latency": 0.4678999999999769, } # use assert_frame_equal to take advantage of the # nice way it deals with NaNs - pd.testing.assert_frame_equal(pd.DataFrame(result, index=[0]), - pd.DataFrame(expected_result, index=[0]), - check_names=False) + pd.testing.assert_frame_equal( + pd.DataFrame(result, index=[0]), pd.DataFrame(expected_result, index=[0]), check_names=False + ) diff --git a/allensdk/test/brain_observatory/behavior/data_objects/test_trial_table.py b/allensdk/test/brain_observatory/behavior/data_objects/test_trial_table.py index 8a25ca67a1..bcd0930601 100644 --- a/allensdk/test/brain_observatory/behavior/data_objects/test_trial_table.py +++ b/allensdk/test/brain_observatory/behavior/data_objects/test_trial_table.py @@ -7,24 +7,16 @@ import pynwb import pytest -from allensdk.brain_observatory.behavior.data_files import ( - BehaviorStimulusFile, - SyncFile) +from allensdk.brain_observatory.behavior.data_files import BehaviorStimulusFile, SyncFile from allensdk.brain_observatory.behavior.data_objects import StimulusTimestamps from allensdk.brain_observatory.behavior.data_objects.licks import Licks -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .behavior_metadata.equipment import \ - Equipment +from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.equipment import Equipment from allensdk.brain_observatory.behavior.data_objects.rewards import Rewards -from allensdk.brain_observatory.behavior.data_objects.stimuli.util import \ - calculate_monitor_delay -from allensdk.brain_observatory.behavior.data_objects.task_parameters import \ - TaskParameters -from allensdk.brain_observatory.behavior.data_objects.trials.trials import \ - Trials +from allensdk.brain_observatory.behavior.data_objects.stimuli.util import calculate_monitor_delay +from allensdk.brain_observatory.behavior.data_objects.task_parameters import TaskParameters +from allensdk.brain_observatory.behavior.data_objects.trials.trials import Trials from allensdk.internal.brain_observatory.time_sync import OphysTimeAligner -from allensdk.test.brain_observatory.behavior.data_objects.lims_util import \ - LimsTest +from allensdk.test.brain_observatory.behavior.data_objects.lims_util import LimsTest class TestFromBehaviorStimulusFile(LimsTest): @@ -34,7 +26,7 @@ def setup_class(cls): cls.ophys_experiment_id = 994278291 dir = Path(__file__).parent.resolve() - test_data_dir = dir / 'test_data' + test_data_dir = dir / "test_data" # Note: trials.pkl must be created from a BehaviorSession, # not a BehaviorOphysExperiment. If it is created from @@ -42,20 +34,15 @@ def setup_class(cls): # instantiated from a sync file, rather than a stimulus file. # The tests expect stimulus_timestamps to be instantiated # from a stimulus file. - expected = pd.read_pickle(str(test_data_dir / 'trials.pkl')) + expected = pd.read_pickle(str(test_data_dir / "trials.pkl")) cls.expected = Trials(trials=expected, response_window_start=0) @pytest.mark.requires_bamboo def test_from_stimulus_file(self): - stimulus_file, stimulus_timestamps, licks, rewards, \ - response_window_start = \ - self._get_trial_table_data() + stimulus_file, stimulus_timestamps, licks, rewards, response_window_start = self._get_trial_table_data() trials = Trials.from_stimulus_file( - stimulus_file=stimulus_file, - stimulus_timestamps=stimulus_timestamps, - licks=licks, - rewards=rewards + stimulus_file=stimulus_file, stimulus_timestamps=stimulus_timestamps, licks=licks, rewards=rewards ) trials._response_window_start = response_window_start self.expected._response_window_start = response_window_start @@ -64,69 +51,56 @@ def test_from_stimulus_file(self): def test_from_stimulus_file2(self): dir = Path(__file__).parent.parent.resolve() - stimulus_filepath = dir / 'resources' / 'example_stimulus.pkl.gz' + stimulus_filepath = dir / "resources" / "example_stimulus.pkl.gz" stimulus_file = BehaviorStimulusFile(filepath=stimulus_filepath) - stimulus_file, stimulus_timestamps, licks, rewards, \ - response_window_start = \ - self._get_trial_table_data(stimulus_file=stimulus_file) + stimulus_file, stimulus_timestamps, licks, rewards, response_window_start = self._get_trial_table_data( + stimulus_file=stimulus_file + ) Trials.from_stimulus_file( - stimulus_file=stimulus_file, - stimulus_timestamps=stimulus_timestamps, - licks=licks, - rewards=rewards + stimulus_file=stimulus_file, stimulus_timestamps=stimulus_timestamps, licks=licks, rewards=rewards ) - def _get_trial_table_data( - self, - stimulus_file: Optional[BehaviorStimulusFile] = None): + def _get_trial_table_data(self, stimulus_file: Optional[BehaviorStimulusFile] = None): """returns data required to instantiate a TrialTable""" if stimulus_file is None: - stimulus_file = BehaviorStimulusFile.from_lims( - behavior_session_id=self.behavior_session_id, db=self.dbconn) - stimulus_timestamps = StimulusTimestamps.from_stimulus_file( - stimulus_file=stimulus_file, - monitor_delay=0.02115) + stimulus_file = BehaviorStimulusFile.from_lims(behavior_session_id=self.behavior_session_id, db=self.dbconn) + stimulus_timestamps = StimulusTimestamps.from_stimulus_file(stimulus_file=stimulus_file, monitor_delay=0.02115) stimulus_timestamps_no_delay = StimulusTimestamps.from_stimulus_file( - stimulus_file=stimulus_file, - monitor_delay=0.0) + stimulus_file=stimulus_file, monitor_delay=0.0 + ) - licks = Licks.from_stimulus_file( - stimulus_file=stimulus_file, - stimulus_timestamps=stimulus_timestamps_no_delay) + licks = Licks.from_stimulus_file(stimulus_file=stimulus_file, stimulus_timestamps=stimulus_timestamps_no_delay) rewards = Rewards.from_stimulus_file( - stimulus_file=stimulus_file, - stimulus_timestamps=stimulus_timestamps_no_delay) + stimulus_file=stimulus_file, stimulus_timestamps=stimulus_timestamps_no_delay + ) - response_window_start = TaskParameters.from_stimulus_file( - stimulus_file=stimulus_file - ).response_window_sec[0] + response_window_start = TaskParameters.from_stimulus_file(stimulus_file=stimulus_file).response_window_sec[0] - return stimulus_file, stimulus_timestamps, licks, rewards, \ - response_window_start + return stimulus_file, stimulus_timestamps, licks, rewards, response_window_start class TestMonitorDelay: @classmethod def setup_class(cls): cls.lookup_table_expected_values = { - 'CAM2P.1': 0.020842, - 'CAM2P.2': 0.037566, - 'CAM2P.3': 0.021390, - 'CAM2P.4': 0.021102, - 'CAM2P.5': 0.021192, - 'MESO.1': 0.03613 + "CAM2P.1": 0.020842, + "CAM2P.2": 0.037566, + "CAM2P.3": 0.021390, + "CAM2P.4": 0.021102, + "CAM2P.5": 0.021192, + "MESO.1": 0.03613, } def setup_method(self, method): dir = Path(__file__).parent.resolve() - test_data_dir = dir / 'test_data' + test_data_dir = dir / "test_data" - trials = pd.read_pickle(str(test_data_dir / 'trials.pkl')) - self.sync_file = SyncFile(filepath=str(test_data_dir / 'sync.h5')) + trials = pd.read_pickle(str(test_data_dir / "trials.pkl")) + self.sync_file = SyncFile(filepath=str(test_data_dir / "sync.h5")) self.trials = Trials(trials=trials, response_window_start=0) - @pytest.mark.parametrize('equipment_name', ('CAMP2.1', 'MESO.2')) + @pytest.mark.parametrize("equipment_name", ("CAMP2.1", "MESO.2")) def test_monitor_delay(self, monkeypatch, equipment_name): equipment = Equipment(equipment_name=equipment_name) @@ -134,11 +108,8 @@ def dummy_delay(self): return 1.12 with monkeypatch.context() as ctx: - ctx.setattr(OphysTimeAligner, - '_get_monitor_delay', - dummy_delay) - md = calculate_monitor_delay(sync_file=self.sync_file, - equipment=equipment) + ctx.setattr(OphysTimeAligner, "_get_monitor_delay", dummy_delay) + md = calculate_monitor_delay(sync_file=self.sync_file, equipment=equipment) assert abs(md - 1.12) < 1.0e-6 def test_monitor_delay_lookup(self, monkeypatch): @@ -147,14 +118,10 @@ def dummy_delay(self): raise ValueError("that did not work") with monkeypatch.context() as ctx: - ctx.setattr(OphysTimeAligner, - '_get_monitor_delay', - dummy_delay) - for equipment, expected in \ - self.lookup_table_expected_values.items(): + ctx.setattr(OphysTimeAligner, "_get_monitor_delay", dummy_delay) + for equipment, expected in self.lookup_table_expected_values.items(): equipment = Equipment(equipment_name=equipment) - md = calculate_monitor_delay( - sync_file=self.sync_file, equipment=equipment) + md = calculate_monitor_delay(sync_file=self.sync_file, equipment=equipment) assert abs(md - expected) < 1e-6 def test_unkown_rig_name(self, monkeypatch): @@ -163,53 +130,41 @@ def dummy_delay(self): raise ValueError("that did not work") with monkeypatch.context() as ctx: - ctx.setattr(OphysTimeAligner, - '_get_monitor_delay', - dummy_delay) - equipment = Equipment(equipment_name='spam') + ctx.setattr(OphysTimeAligner, "_get_monitor_delay", dummy_delay) + equipment = Equipment(equipment_name="spam") with pytest.raises(RuntimeError): - calculate_monitor_delay(sync_file=self.sync_file, - equipment=equipment) + calculate_monitor_delay(sync_file=self.sync_file, equipment=equipment) class TestNWB: @classmethod def setup_class(cls): dir = Path(__file__).parent.resolve() - test_data_dir = dir / 'test_data' + test_data_dir = dir / "test_data" - trials = pd.read_pickle(str(test_data_dir / 'trials.pkl')) + trials = pd.read_pickle(str(test_data_dir / "trials.pkl")) cls.trials = Trials(trials=trials, response_window_start=0) def setup_method(self, method): self.nwbfile = pynwb.NWBFile( - session_description='asession', - identifier='1234', - session_start_time=datetime.now() + session_description="asession", identifier="1234", session_start_time=datetime.now() ) @pytest.mark.requires_bamboo - @pytest.mark.parametrize('roundtrip', [True, False]) - def test_read_write_nwb(self, roundtrip, - data_object_roundtrip_fixture): + @pytest.mark.parametrize("roundtrip", [True, False]) + def test_read_write_nwb(self, roundtrip, data_object_roundtrip_fixture): self.trials.to_nwb(nwbfile=self.nwbfile) - with patch.object(TaskParameters, 'from_nwb', - lambda nwbfile: create_autospec( - TaskParameters, instance=True)): + with patch.object(TaskParameters, "from_nwb", lambda nwbfile: create_autospec(TaskParameters, instance=True)): if roundtrip: - obt = data_object_roundtrip_fixture( - nwbfile=self.nwbfile, - data_object_cls=Trials) + obt = data_object_roundtrip_fixture(nwbfile=self.nwbfile, data_object_cls=Trials) else: obt = self.trials.from_nwb(nwbfile=self.nwbfile) test = TestFromBehaviorStimulusFile() test.setup_class() test.setup_method(self.test_read_write_nwb) - _, _, _, _, \ - response_window_start = \ - test._get_trial_table_data() + _, _, _, _, response_window_start = test._get_trial_table_data() obt._response_window_start = response_window_start self.trials._response_window_start = response_window_start diff --git a/allensdk/test/brain_observatory/behavior/test_behavior_metadata_legacy.py b/allensdk/test/brain_observatory/behavior/test_behavior_metadata_legacy.py index 2f0c8b43b2..aa6ae92f86 100644 --- a/allensdk/test/brain_observatory/behavior/test_behavior_metadata_legacy.py +++ b/allensdk/test/brain_observatory/behavior/test_behavior_metadata_legacy.py @@ -1,192 +1,170 @@ import pytest import numpy as np -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .behavior_metadata.behavior_metadata import ( - description_dict, get_task_parameters, get_expt_description) +from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.behavior_metadata import ( + description_dict, + get_task_parameters, + get_expt_description, +) -@pytest.mark.parametrize("data, expected", - [pytest.param({ # noqa: E128 - "items": { - "behavior": { - "config": { - "DoC": { - "blank_duration_range": ( - 0.5, 0.6), - "response_window": [0.15, 0.75], - "change_time_dist": "geometric", - "auto_reward_volume": 0.002, - }, - "reward": { - "reward_volume": 0.007, - }, - "behavior": { - "task_id": "DoC_untranslated", - }, - }, - "params": { - "stage": "TRAINING_3_images_A", - "flash_omit_probability": 0.05 - }, - "stimuli": { - "images": {"draw_log": [1] * 10, - "flash_interval_sec": [ - 0.32, -1.0]} - }, - } - } - }, - { - "blank_duration_sec": [0.5, 0.6], - "stimulus_duration_sec": 0.32, - "omitted_flash_fraction": 0.05, - "response_window_sec": [0.15, 0.75], - "reward_volume": 0.007, - "session_type": "TRAINING_3_images_A", - "stimulus": "images", - "stimulus_distribution": "geometric", - "task": "change detection", - "n_stimulus_frames": 10, - "auto_reward_volume": 0.002 - }, id='basic'), - pytest.param({ - "items": { - "behavior": { - "config": { - "DoC": { - "blank_duration_range": ( - 0.5, 0.5), - "response_window": [0.15, - 0.75], - "change_time_dist": - "geometric", - "auto_reward_volume": 0.002 - }, - "reward": { - "reward_volume": 0.007, - }, - "behavior": { - "task_id": "DoC_untranslated", - }, - }, - "params": { - "stage": "TRAINING_3_images_A", - "flash_omit_probability": 0.05 - }, - "stimuli": { - "images": {"draw_log": [1] * 10, - "flash_interval_sec": [ - 0.32, -1.0]} - }, - } - } - }, - { - "blank_duration_sec": [0.5, 0.5], - "stimulus_duration_sec": 0.32, - "omitted_flash_fraction": 0.05, - "response_window_sec": [0.15, 0.75], - "reward_volume": 0.007, - "session_type": "TRAINING_3_images_A", - "stimulus": "images", - "stimulus_distribution": "geometric", - "task": "change detection", - "n_stimulus_frames": 10, - "auto_reward_volume": 0.002 - }, id='single_value_blank_duration'), - pytest.param({ - "items": { - "behavior": { - "config": { - "DoC": { - "blank_duration_range": ( - 0.5, 0.5), - "response_window": [0.15, - 0.75], - "change_time_dist": - "geometric", - "auto_reward_volume": 0.002 - }, - "reward": { - "reward_volume": 0.007, - }, - "behavior": { - "task_id": "DoC_untranslated", - }, - }, - "params": { - "stage": "TRAINING_3_images_A", - "flash_omit_probability": 0.05 - }, - "stimuli": { - "grating": {"draw_log": [1] * 10, - "flash_interval_sec": - [0.34, -1.0]} - }, - } - } - }, - { - "blank_duration_sec": [0.5, 0.5], - "stimulus_duration_sec": 0.34, - "omitted_flash_fraction": 0.05, - "response_window_sec": [0.15, 0.75], - "reward_volume": 0.007, - "session_type": "TRAINING_3_images_A", - "stimulus": "grating", - "stimulus_distribution": "geometric", - "task": "change detection", - "n_stimulus_frames": 10, - "auto_reward_volume": 0.002 - }, id='stimulus_duration_from_grating'), - pytest.param({ - "items": { - "behavior": { - "config": { - "DoC": { - "blank_duration_range": ( - 0.5, 0.5), - "response_window": [0.15, - 0.75], - "change_time_dist": - "geometric", - "auto_reward_volume": 0.002 - }, - "reward": { - "reward_volume": 0.007, - }, - "behavior": { - "task_id": "DoC_untranslated", - }, - }, - "params": { - "stage": "TRAINING_3_images_A", - "flash_omit_probability": 0.05 - }, - "stimuli": { - "grating": { - "draw_log": [1] * 10, - "flash_interval_sec": None} - }, - } - } - }, - { - "blank_duration_sec": [0.5, 0.5], - "stimulus_duration_sec": np.nan, - "omitted_flash_fraction": 0.05, - "response_window_sec": [0.15, 0.75], - "reward_volume": 0.007, - "session_type": "TRAINING_3_images_A", - "stimulus": "grating", - "stimulus_distribution": "geometric", - "task": "change detection", - "n_stimulus_frames": 10, - "auto_reward_volume": 0.002 - }, id='stimulus_duration_none') - ] - ) +@pytest.mark.parametrize( + "data, expected", + [ + pytest.param( + { # noqa: E128 + "items": { + "behavior": { + "config": { + "DoC": { + "blank_duration_range": (0.5, 0.6), + "response_window": [0.15, 0.75], + "change_time_dist": "geometric", + "auto_reward_volume": 0.002, + }, + "reward": { + "reward_volume": 0.007, + }, + "behavior": { + "task_id": "DoC_untranslated", + }, + }, + "params": {"stage": "TRAINING_3_images_A", "flash_omit_probability": 0.05}, + "stimuli": {"images": {"draw_log": [1] * 10, "flash_interval_sec": [0.32, -1.0]}}, + } + } + }, + { + "blank_duration_sec": [0.5, 0.6], + "stimulus_duration_sec": 0.32, + "omitted_flash_fraction": 0.05, + "response_window_sec": [0.15, 0.75], + "reward_volume": 0.007, + "session_type": "TRAINING_3_images_A", + "stimulus": "images", + "stimulus_distribution": "geometric", + "task": "change detection", + "n_stimulus_frames": 10, + "auto_reward_volume": 0.002, + }, + id="basic", + ), + pytest.param( + { + "items": { + "behavior": { + "config": { + "DoC": { + "blank_duration_range": (0.5, 0.5), + "response_window": [0.15, 0.75], + "change_time_dist": "geometric", + "auto_reward_volume": 0.002, + }, + "reward": { + "reward_volume": 0.007, + }, + "behavior": { + "task_id": "DoC_untranslated", + }, + }, + "params": {"stage": "TRAINING_3_images_A", "flash_omit_probability": 0.05}, + "stimuli": {"images": {"draw_log": [1] * 10, "flash_interval_sec": [0.32, -1.0]}}, + } + } + }, + { + "blank_duration_sec": [0.5, 0.5], + "stimulus_duration_sec": 0.32, + "omitted_flash_fraction": 0.05, + "response_window_sec": [0.15, 0.75], + "reward_volume": 0.007, + "session_type": "TRAINING_3_images_A", + "stimulus": "images", + "stimulus_distribution": "geometric", + "task": "change detection", + "n_stimulus_frames": 10, + "auto_reward_volume": 0.002, + }, + id="single_value_blank_duration", + ), + pytest.param( + { + "items": { + "behavior": { + "config": { + "DoC": { + "blank_duration_range": (0.5, 0.5), + "response_window": [0.15, 0.75], + "change_time_dist": "geometric", + "auto_reward_volume": 0.002, + }, + "reward": { + "reward_volume": 0.007, + }, + "behavior": { + "task_id": "DoC_untranslated", + }, + }, + "params": {"stage": "TRAINING_3_images_A", "flash_omit_probability": 0.05}, + "stimuli": {"grating": {"draw_log": [1] * 10, "flash_interval_sec": [0.34, -1.0]}}, + } + } + }, + { + "blank_duration_sec": [0.5, 0.5], + "stimulus_duration_sec": 0.34, + "omitted_flash_fraction": 0.05, + "response_window_sec": [0.15, 0.75], + "reward_volume": 0.007, + "session_type": "TRAINING_3_images_A", + "stimulus": "grating", + "stimulus_distribution": "geometric", + "task": "change detection", + "n_stimulus_frames": 10, + "auto_reward_volume": 0.002, + }, + id="stimulus_duration_from_grating", + ), + pytest.param( + { + "items": { + "behavior": { + "config": { + "DoC": { + "blank_duration_range": (0.5, 0.5), + "response_window": [0.15, 0.75], + "change_time_dist": "geometric", + "auto_reward_volume": 0.002, + }, + "reward": { + "reward_volume": 0.007, + }, + "behavior": { + "task_id": "DoC_untranslated", + }, + }, + "params": {"stage": "TRAINING_3_images_A", "flash_omit_probability": 0.05}, + "stimuli": {"grating": {"draw_log": [1] * 10, "flash_interval_sec": None}}, + } + } + }, + { + "blank_duration_sec": [0.5, 0.5], + "stimulus_duration_sec": np.nan, + "omitted_flash_fraction": 0.05, + "response_window_sec": [0.15, 0.75], + "reward_volume": 0.007, + "session_type": "TRAINING_3_images_A", + "stimulus": "grating", + "stimulus_distribution": "geometric", + "task": "change detection", + "n_stimulus_frames": 10, + "auto_reward_volume": 0.002, + }, + id="stimulus_duration_none", + ), + ], +) def test_get_task_parameters(data, expected): actual = get_task_parameters(data) for k, v in actual.items(): @@ -219,7 +197,7 @@ def test_get_task_parameters_task_id_exception(): "blank_duration_range": (0.5, 0.6), "response_window": [0.15, 0.75], "change_time_dist": "geometric", - "auto_reward_volume": 0.002 + "auto_reward_volume": 0.002, }, "reward": { "reward_volume": 0.007, @@ -228,14 +206,8 @@ def test_get_task_parameters_task_id_exception(): "task_id": "junk", }, }, - "params": { - "stage": "TRAINING_3_images_A", - "flash_omit_probability": 0.05 - }, - "stimuli": { - "images": {"draw_log": [1] * 10, - "flash_interval_sec": [0.32, -1.0]} - }, + "params": {"stage": "TRAINING_3_images_A", "flash_omit_probability": 0.05}, + "stimuli": {"images": {"draw_log": [1] * 10, "flash_interval_sec": [0.32, -1.0]}}, } } } @@ -258,7 +230,7 @@ def test_get_task_parameters_flash_duration_exception(): "blank_duration_range": (0.5, 0.6), "response_window": [0.15, 0.75], "change_time_dist": "geometric", - "auto_reward_volume": 0.002 + "auto_reward_volume": 0.002, }, "reward": { "reward_volume": 0.007, @@ -267,14 +239,8 @@ def test_get_task_parameters_flash_duration_exception(): "task_id": "DoC", }, }, - "params": { - "stage": "TRAINING_3_images_A", - "flash_omit_probability": 0.05 - }, - "stimuli": { - "junk": {"draw_log": [1] * 10, - "flash_interval_sec": [0.32, -1.0]} - }, + "params": {"stage": "TRAINING_3_images_A", "flash_omit_probability": 0.05}, + "stimuli": {"junk": {"draw_log": [1] * 10, "flash_interval_sec": [0.32, -1.0]}}, } } } @@ -285,35 +251,31 @@ def test_get_task_parameters_flash_duration_exception(): assert shld_be in error.value.args[0] -@pytest.mark.parametrize("session_type, expected_description", [ - ("OPHYS_0_images_Z", description_dict[r"\AOPHYS_0_images"]), - ("OPHYS_1_images_A", description_dict[r"\AOPHYS_[1|3]_images"]), - ("OPHYS_2_images_B", description_dict[r"\AOPHYS_2_images"]), - ("OPHYS_3_images_C", description_dict[r"\AOPHYS_[1|3]_images"]), - ("OPHYS_4_images_D", description_dict[r"\AOPHYS_[4|6]_images"]), - ("OPHYS_5_images_E", description_dict[r"\AOPHYS_5_images"]), - ("OPHYS_6_images_F", description_dict[r"\AOPHYS_[4|6]_images"]), - ("TRAINING_0_gratings_A", description_dict[r"\ATRAINING_0_gratings"]), - ("TRAINING_1_gratings_B", description_dict[r"\ATRAINING_1_gratings"]), - ("TRAINING_2_gratings_C", description_dict[r"\ATRAINING_2_gratings"]), - ("TRAINING_3_images_D", description_dict[r"\ATRAINING_3_images"]), - ("TRAINING_4_images_E", description_dict[r"\ATRAINING_4_images"]), - ('TRAINING_3_images_A_10uL_reward', - description_dict[r"\ATRAINING_3_images"]), - ('TRAINING_5_images_A_handoff_lapsed', - description_dict[r"\ATRAINING_5_images"]) -]) -def test_get_expt_description_with_valid_session_type(session_type, - expected_description): +@pytest.mark.parametrize( + "session_type, expected_description", + [ + ("OPHYS_0_images_Z", description_dict[r"\AOPHYS_0_images"]), + ("OPHYS_1_images_A", description_dict[r"\AOPHYS_[1|3]_images"]), + ("OPHYS_2_images_B", description_dict[r"\AOPHYS_2_images"]), + ("OPHYS_3_images_C", description_dict[r"\AOPHYS_[1|3]_images"]), + ("OPHYS_4_images_D", description_dict[r"\AOPHYS_[4|6]_images"]), + ("OPHYS_5_images_E", description_dict[r"\AOPHYS_5_images"]), + ("OPHYS_6_images_F", description_dict[r"\AOPHYS_[4|6]_images"]), + ("TRAINING_0_gratings_A", description_dict[r"\ATRAINING_0_gratings"]), + ("TRAINING_1_gratings_B", description_dict[r"\ATRAINING_1_gratings"]), + ("TRAINING_2_gratings_C", description_dict[r"\ATRAINING_2_gratings"]), + ("TRAINING_3_images_D", description_dict[r"\ATRAINING_3_images"]), + ("TRAINING_4_images_E", description_dict[r"\ATRAINING_4_images"]), + ("TRAINING_3_images_A_10uL_reward", description_dict[r"\ATRAINING_3_images"]), + ("TRAINING_5_images_A_handoff_lapsed", description_dict[r"\ATRAINING_5_images"]), + ], +) +def test_get_expt_description_with_valid_session_type(session_type, expected_description): obt = get_expt_description(session_type) assert obt == expected_description -@pytest.mark.parametrize("session_type", [ - ("bogus_session_type"), - ("stuff"), - ("OPHYS_7") -]) +@pytest.mark.parametrize("session_type", [("bogus_session_type"), ("stuff"), ("OPHYS_7")]) def test_get_expt_description_raises_with_invalid_session_type(session_type): with pytest.raises(RuntimeError, match="session type should match.*"): get_expt_description(session_type) diff --git a/allensdk/test/brain_observatory/behavior/test_behavior_ophys_experiment.py b/allensdk/test/brain_observatory/behavior/test_behavior_ophys_experiment.py index 779a6413ab..1fe80e879a 100644 --- a/allensdk/test/brain_observatory/behavior/test_behavior_ophys_experiment.py +++ b/allensdk/test/brain_observatory/behavior/test_behavior_ophys_experiment.py @@ -27,9 +27,7 @@ def test_nwb_end_to_end(tmpdir_factory): oeid = 795073741 tmpdir = "test_nwb_end_to_end" - nwb_filepath = os.path.join( - str(tmpdir_factory.mktemp(tmpdir)), "nwbfile.nwb" - ) + nwb_filepath = os.path.join(str(tmpdir_factory.mktemp(tmpdir)), "nwbfile.nwb") d1 = BehaviorOphysExperiment.from_lims( oeid, @@ -40,9 +38,7 @@ def test_nwb_end_to_end(tmpdir_factory): d2 = BehaviorOphysExperiment.from_nwb(nwbfile=nwbfile) - assert sessions_are_equal( - d1, d2, reraise=True, ignore_keys={"metadata": {"project_code"}} - ) + assert sessions_are_equal(d1, d2, reraise=True, ignore_keys={"metadata": {"project_code"}}) @pytest.mark.nightly @@ -60,9 +56,7 @@ def test_visbeh_ophys_data_set(): # for _, row in data_set.roi_masks.iterrows(): # print(np.array(row.to_dict()['mask']).sum()) - lims_db = db_connection_creator( - fallback_credentials=LIMS_DB_CREDENTIAL_MAP - ) + lims_db = db_connection_creator(fallback_credentials=LIMS_DB_CREDENTIAL_MAP) behavior_session_id = BehaviorSessionId.from_lims( db=lims_db, ophys_experiment_id=ophys_experiment_id, @@ -82,12 +76,10 @@ def test_visbeh_ophys_data_set(): assert stimulus_templates.loc["im000"].warped.shape == MONITOR_DIMENSIONS assert stimulus_templates.loc["im000"].unwarped.shape == MONITOR_DIMENSIONS - assert len(data_set.licks) == 2421 and set(data_set.licks.columns) == set( - ["timestamps", "frame"] + assert len(data_set.licks) == 2421 and set(data_set.licks.columns) == set(["timestamps", "frame"]) + assert len(data_set.rewards) == 85 and set(data_set.rewards.columns) == set( + ["timestamps", "volume", "auto_rewarded"] ) - assert len(data_set.rewards) == 85 and set( - data_set.rewards.columns - ) == set(["timestamps", "volume", "auto_rewarded"]) assert len(data_set.corrected_fluorescence_traces) == 258 and set( data_set.corrected_fluorescence_traces.columns ) == set(["cell_roi_id", "corrected_fluorescence", "RMSE", "r"]) @@ -100,17 +92,13 @@ def test_visbeh_ophys_data_set(): ) assert len(data_set.cell_specimen_table) == len(data_set.dff_traces) - assert ( - data_set.average_projection.data.shape - == data_set.max_projection.data.shape - ) + assert data_set.average_projection.data.shape == data_set.max_projection.data.shape assert set(data_set.motion_correction.columns) == set(["x", "y"]) assert len(data_set.trials) == 602 expected_metadata = { "stimulus_frame_rate": 60.0, - "full_genotype": "Slc17a7-IRES2-Cre/wt;Camk2a-tTA/wt;Ai93(" - "TITL-GCaMP6f)/wt", + "full_genotype": "Slc17a7-IRES2-Cre/wt;Camk2a-tTA/wt;Ai93(TITL-GCaMP6f)/wt", "ophys_experiment_id": 789359614, "behavior_session_id": 789295700, "imaging_plane_group_count": 0, @@ -118,12 +106,8 @@ def test_visbeh_ophys_data_set(): "session_type": "OPHYS_6_images_B", "driver_line": ["Camk2a-tTA", "Slc17a7-IRES2-Cre"], "cre_line": "Slc17a7-IRES2-Cre", - "behavior_session_uuid": uuid.UUID( - "69cdbe09-e62b-4b42-aab1-54b5773dfe78" - ), - "date_of_acquisition": pytz.utc.localize( - datetime.datetime(2018, 11, 30, 15, 58, 50, 325000) - ), + "behavior_session_uuid": uuid.UUID("69cdbe09-e62b-4b42-aab1-54b5773dfe78"), + "date_of_acquisition": pytz.utc.localize(datetime.datetime(2018, 11, 30, 15, 58, 50, 325000)), "ophys_frame_rate": 31.0, "imaging_depth": 375, "targeted_imaging_depth": 375, @@ -206,9 +190,7 @@ def test_event_detection(): ] assert len(events.columns) == len(expected_columns) # Assert they contain the same columns - assert len(set(expected_columns).intersection(events.columns)) == len( - expected_columns - ) + assert len(set(expected_columns).intersection(events.columns)) == len(expected_columns) assert events.index.name == "cell_specimen_id" @@ -298,15 +280,9 @@ def test_stim_v_trials_time(behavior_ophys_experiment_fixture): """ exp = behavior_ophys_experiment_fixture - stim = exp.stimulus_presentations[ - exp.stimulus_presentations["is_change"] - ].start_time.reset_index(drop=True) + stim = exp.stimulus_presentations[exp.stimulus_presentations["is_change"]].start_time.reset_index(drop=True) - trials = ( - exp.trials.query("not aborted") - .query("go or auto_rewarded")["change_time"] - .reset_index(drop=True) - ) + trials = exp.trials.query("not aborted").query("go or auto_rewarded")["change_time"].reset_index(drop=True) delta = np.abs(stim - trials) assert delta.max() < 1.0e-6 diff --git a/allensdk/test/brain_observatory/behavior/test_behavior_session.py b/allensdk/test/brain_observatory/behavior/test_behavior_session.py index 66b50172cc..9ee3d8d66f 100644 --- a/allensdk/test/brain_observatory/behavior/test_behavior_session.py +++ b/allensdk/test/brain_observatory/behavior/test_behavior_session.py @@ -1,5 +1,4 @@ -from allensdk.brain_observatory.behavior.behavior_session import ( - BehaviorSession) +from allensdk.brain_observatory.behavior.behavior_session import BehaviorSession from allensdk.brain_observatory.session_api_utils import sessions_are_equal @@ -7,59 +6,45 @@ import pathlib from pynwb import NWBHDF5IO -from allensdk.test.brain_observatory.behavior.data_objects.lims_util import \ - LimsTest +from allensdk.test.brain_observatory.behavior.data_objects.lims_util import LimsTest @pytest.mark.requires_bamboo -def test_nwb_end_to_end_session( - tmpdir_factory, - helper_functions): +def test_nwb_end_to_end_session(tmpdir_factory, helper_functions): session_id = 870987812 - tmpdir = tmpdir_factory.mktemp('session_nwb_end_to_end') + tmpdir = tmpdir_factory.mktemp("session_nwb_end_to_end") tmpdir = pathlib.Path(tmpdir) - nwb_path = tmpdir / f'session_{session_id}.nwb' - session = BehaviorSession.from_lims( - behavior_session_id=session_id) + nwb_path = tmpdir / f"session_{session_id}.nwb" + session = BehaviorSession.from_lims(behavior_session_id=session_id) nwb_file = session.to_nwb() - with NWBHDF5IO(nwb_path, 'w') as nwb_file_writer: + with NWBHDF5IO(nwb_path, "w") as nwb_file_writer: nwb_file_writer.write(nwb_file) - roundtrip = BehaviorSession.from_nwb_path( - nwb_path=str(nwb_path.resolve().absolute())) + roundtrip = BehaviorSession.from_nwb_path(nwb_path=str(nwb_path.resolve().absolute())) - assert sessions_are_equal( - session, - roundtrip, - reraise=True) + assert sessions_are_equal(session, roundtrip, reraise=True) helper_functions.windows_safe_cleanup_dir(tmpdir) @pytest.fixture def session_data_fixture(): - return { "behavior_session_id": 1010991549, "foraging_id": "bfcc4803-8892-4cb4-88e0-9437b98936db", - "driver_line": [ - "Vip-IRES-Cre" - ], - "reporter_line": [ - "Ai32(RCL-ChR2(H134R)_EYFP)" - ], + "driver_line": ["Vip-IRES-Cre"], + "reporter_line": ["Ai32(RCL-ChR2(H134R)_EYFP)"], "full_genotype": "Vip-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt", "rig_name": "BEH.G-Box1", "date_of_acquisition": "2020-02-28 03:11:17", "external_specimen_name": 506940, - "behavior_stimulus_file": - "/allen/programs/braintv/production/visualbehavior/prod0/" - "specimen_1000324129/behavior_session_1010991549/" - "200228111053_506940_bfcc4803-8892-4cb4-88e0-9437b98936db.pkl", + "behavior_stimulus_file": "/allen/programs/braintv/production/visualbehavior/prod0/" + "specimen_1000324129/behavior_session_1010991549/" + "200228111053_506940_bfcc4803-8892-4cb4-88e0-9437b98936db.pkl", "date_of_birth": "2019-11-24 16:00:00", "sex": "M", "age": "unknown", - "stimulus_name": None + "stimulus_name": None, } @@ -71,27 +56,27 @@ def dummy_init(self): pass with monkeypatch.context() as ctx: - ctx.setattr(BehaviorSession, '__init__', dummy_init) + ctx.setattr(BehaviorSession, "__init__", dummy_init) bs = BehaviorSession() obt = bs.list_data_attributes_and_methods() expected = { - 'behavior_session_id', - 'get_performance_metrics', - 'get_reward_rate', - 'get_rolling_performance_df', - 'licks', - 'metadata', - 'raw_running_speed', - 'rewards', - 'running_speed', - 'stimulus_presentations', - 'stimulus_templates', - 'stimulus_timestamps', - 'task_parameters', - 'trials', - 'eye_tracking', - 'eye_tracking_rig_geometry' + "behavior_session_id", + "get_performance_metrics", + "get_reward_rate", + "get_rolling_performance_df", + "licks", + "metadata", + "raw_running_speed", + "rewards", + "running_speed", + "stimulus_presentations", + "stimulus_templates", + "stimulus_timestamps", + "task_parameters", + "trials", + "eye_tracking", + "eye_tracking_rig_geometry", } assert any(expected ^ set(obt)) is False @@ -99,13 +84,10 @@ def dummy_init(self): @pytest.mark.nightly def test_behavior_session_equivalent_json_lims(session_data_fixture): - json_session = BehaviorSession.from_json(session_data_fixture) - behavior_session_id = session_data_fixture['behavior_session_id'] - lims_session = BehaviorSession.from_lims( - behavior_session_id - ) + behavior_session_id = session_data_fixture["behavior_session_id"] + lims_session = BehaviorSession.from_lims(behavior_session_id) assert sessions_are_equal(json_session, lims_session, reraise=True) @@ -116,6 +98,5 @@ def test_eye_tracking_loaded_with_metadata_frame(self): # This session uses MVR to record the eye tracking video sess_id = 1154034257 - sess = BehaviorSession.from_lims(behavior_session_id=sess_id, - lims_db=self.dbconn) + sess = BehaviorSession.from_lims(behavior_session_id=sess_id, lims_db=self.dbconn) assert not sess.eye_tracking.empty diff --git a/allensdk/test/brain_observatory/behavior/test_criteria.py b/allensdk/test/brain_observatory/behavior/test_criteria.py index 2bb3cb5bdb..131dfc1d75 100644 --- a/allensdk/test/brain_observatory/behavior/test_criteria.py +++ b/allensdk/test/brain_observatory/behavior/test_criteria.py @@ -8,31 +8,71 @@ "session_summary, expected", [ ( - pd.DataFrame({ - "training_day": {0: 0, 1: 1, 2: 3, }, - "dprime_peak": {0: 0, 1: 0, 2: 0, }, - }), + pd.DataFrame( + { + "training_day": { + 0: 0, + 1: 1, + 2: 3, + }, + "dprime_peak": { + 0: 0, + 1: 0, + 2: 0, + }, + } + ), False, ), ( - pd.DataFrame({ - "training_day": {0: 0, 1: 1, 2: 3, }, - "dprime_peak": {0: 2.0, 1: 2.0, 2: 2.0, }, - }), + pd.DataFrame( + { + "training_day": { + 0: 0, + 1: 1, + 2: 3, + }, + "dprime_peak": { + 0: 2.0, + 1: 2.0, + 2: 2.0, + }, + } + ), False, ), # should need to be greater than 2.0 ( - pd.DataFrame({ - "training_day": {0: 0, 1: 1, 2: 3, }, - "dprime_peak": {0: 0, 1: 0, 2: 2.1, }, - }), + pd.DataFrame( + { + "training_day": { + 0: 0, + 1: 1, + 2: 3, + }, + "dprime_peak": { + 0: 0, + 1: 0, + 2: 2.1, + }, + } + ), False, ), ( - pd.DataFrame({ - "training_day": {0: 0, 1: 1, 2: 3, }, - "dprime_peak": {0: 0, 1: 2.1, 2: 2.1, }, - }), + pd.DataFrame( + { + "training_day": { + 0: 0, + 1: 1, + 2: 3, + }, + "dprime_peak": { + 0: 0, + 1: 2.1, + 2: 2.1, + }, + } + ), True, ), ], @@ -45,17 +85,35 @@ def test_two_out_of_three_aint_bad(session_summary, expected): "session_summary, expected", [ ( - pd.DataFrame({ - "training_day": {0: 0, 1: 1, }, - "dprime_peak": {0: 0, 1: 2.1, } - }), + pd.DataFrame( + { + "training_day": { + 0: 0, + 1: 1, + }, + "dprime_peak": { + 0: 0, + 1: 2.1, + }, + } + ), pytest.raises(DataFrameIndexError), ), ( - pd.DataFrame({ - "training_day": {0: 0, 1: 1, 2: 3, }, - "other_col": {0: 0, 1: 2.1, 2: 2.1, } - }), + pd.DataFrame( + { + "training_day": { + 0: 0, + 1: 1, + 2: 3, + }, + "other_col": { + 0: 0, + 1: 2.1, + 2: 2.1, + }, + } + ), pytest.raises(DataFrameKeyError), ), ], @@ -69,34 +127,73 @@ def test_two_out_of_three_aint_bad_exception(session_summary, expected): "session_summary, expected", [ ( - pd.DataFrame({ - "training_day": {0: 0, 1: 1, 2: 3, }, - "dprime_peak": {0: 0, 1: 0, 2: 0, }, - }), + pd.DataFrame( + { + "training_day": { + 0: 0, + 1: 1, + 2: 3, + }, + "dprime_peak": { + 0: 0, + 1: 0, + 2: 0, + }, + } + ), False, ), ( - pd.DataFrame({ - "training_day": {0: 0, 1: 1, 2: 3, }, - "dprime_peak": {0: 2.0, 1: 2.0, 2: 2.0, }, - }), + pd.DataFrame( + { + "training_day": { + 0: 0, + 1: 1, + 2: 3, + }, + "dprime_peak": { + 0: 2.0, + 1: 2.0, + 2: 2.0, + }, + } + ), False, ), # should need to be greater than 2.0 ( - pd.DataFrame({ - "training_day": {0: 0, 1: 1, 2: 3, }, - "dprime_peak": {0: 0, 1: 2.1, 2: 0, }, - }), + pd.DataFrame( + { + "training_day": { + 0: 0, + 1: 1, + 2: 3, + }, + "dprime_peak": { + 0: 0, + 1: 2.1, + 2: 0, + }, + } + ), False, ), ( - pd.DataFrame({ - "training_day": {0: 0, 1: 1, 2: 3, }, - "dprime_peak": {0: 0, 1: 0, 2: 2.1, }, - }), + pd.DataFrame( + { + "training_day": { + 0: 0, + 1: 1, + 2: 3, + }, + "dprime_peak": { + 0: 0, + 1: 0, + 2: 2.1, + }, + } + ), True, ), - ], ) def test_yesterday_was_good(session_summary, expected): @@ -107,17 +204,29 @@ def test_yesterday_was_good(session_summary, expected): "session_summary,expected", [ ( - pd.DataFrame({ - "training_day": {}, - "dprime_peak": {}, - }), + pd.DataFrame( + { + "training_day": {}, + "dprime_peak": {}, + } + ), pytest.raises(DataFrameIndexError), ), ( - pd.DataFrame({ - "training_day": {0: 0, 1: 1, 2: 3, }, - "other_col": {0: 0, 1: 0, 2: 2.1, }, - }), + pd.DataFrame( + { + "training_day": { + 0: 0, + 1: 1, + 2: 3, + }, + "other_col": { + 0: 0, + 1: 0, + 2: 2.1, + }, + } + ), pytest.raises(DataFrameKeyError), ), ], @@ -131,52 +240,122 @@ def test_yesterday_was_good_exception(session_summary, expected): "session_summary, expected", [ ( - pd.DataFrame({ - "training_day": {0: 0, 1: 1, 2: 3, }, - "response_bias": {0: 0.0, 1: 0.0, 2: 0.0, }, - }), + pd.DataFrame( + { + "training_day": { + 0: 0, + 1: 1, + 2: 3, + }, + "response_bias": { + 0: 0.0, + 1: 0.0, + 2: 0.0, + }, + } + ), False, ), ( - pd.DataFrame({ - "training_day": {0: 0, 1: 1, 2: 3, }, - "response_bias": {0: 0.0, 1: 0.0, 2: 0.1, }, - }), + pd.DataFrame( + { + "training_day": { + 0: 0, + 1: 1, + 2: 3, + }, + "response_bias": { + 0: 0.0, + 1: 0.0, + 2: 0.1, + }, + } + ), False, ), # non-inclusive ( - pd.DataFrame({ - "training_day": {0: 0, 1: 1, 2: 3, }, - "response_bias": {0: 0.0, 1: 0.0, 2: 0.9, }, - }), + pd.DataFrame( + { + "training_day": { + 0: 0, + 1: 1, + 2: 3, + }, + "response_bias": { + 0: 0.0, + 1: 0.0, + 2: 0.9, + }, + } + ), False, ), # non-inclusive ( - pd.DataFrame({ - "training_day": {0: 0, 1: 1, 2: 3, }, - "response_bias": {0: 0.0, 1: 0.0, 2: 0.11, }, - }), + pd.DataFrame( + { + "training_day": { + 0: 0, + 1: 1, + 2: 3, + }, + "response_bias": { + 0: 0.0, + 1: 0.0, + 2: 0.11, + }, + } + ), True, ), ( - pd.DataFrame({ - "training_day": {0: 0, 1: 1, 2: 3, }, - "response_bias": {0: 0.0, 1: 0.0, 2: 0.89, }, - }), + pd.DataFrame( + { + "training_day": { + 0: 0, + 1: 1, + 2: 3, + }, + "response_bias": { + 0: 0.0, + 1: 0.0, + 2: 0.89, + }, + } + ), True, ), ( - pd.DataFrame({ - "training_day": {0: 0, 1: 1, 2: 3, }, - "response_bias": {0: 0.0, 1: 0.0, 2: 0.1011, }, - }), + pd.DataFrame( + { + "training_day": { + 0: 0, + 1: 1, + 2: 3, + }, + "response_bias": { + 0: 0.0, + 1: 0.0, + 2: 0.1011, + }, + } + ), True, ), # doesn't round ( - pd.DataFrame({ - "training_day": {0: 0, 1: 1, 2: 3, }, - "response_bias": {0: 0.0, 1: 0.0, 2: 0.8999, }, - }), + pd.DataFrame( + { + "training_day": { + 0: 0, + 1: 1, + 2: 3, + }, + "response_bias": { + 0: 0.0, + 1: 0.0, + 2: 0.8999, + }, + } + ), True, ), # doesn't round ], @@ -189,17 +368,29 @@ def test_no_response_bias(session_summary, expected): "session_summary, expected", [ ( - pd.DataFrame({ - "training_day": {}, - "response_bias": {}, - }), + pd.DataFrame( + { + "training_day": {}, + "response_bias": {}, + } + ), pytest.raises(DataFrameIndexError), ), ( - pd.DataFrame({ - "training_day": {0: 0, 1: 1, 2: 3, }, - "other_col": {0: 0.0, 1: 0.0, 2: 0.11, }, - }), + pd.DataFrame( + { + "training_day": { + 0: 0, + 1: 1, + 2: 3, + }, + "other_col": { + 0: 0.0, + 1: 0.0, + 2: 0.11, + }, + } + ), pytest.raises(DataFrameKeyError), ), ], @@ -213,31 +404,71 @@ def test_no_response_bias_exception(session_summary, expected): "session_summary, expected", [ ( - pd.DataFrame({ - "training_day": {0: 0, 1: 1, 2: 3, }, - "num_contingent_trials": {0: 0.0, 1: 0.0, 2: 0.0, }, - }), + pd.DataFrame( + { + "training_day": { + 0: 0, + 1: 1, + 2: 3, + }, + "num_contingent_trials": { + 0: 0.0, + 1: 0.0, + 2: 0.0, + }, + } + ), False, ), ( - pd.DataFrame({ - "training_day": {0: 0, 1: 1, 2: 3, }, - "num_contingent_trials": {0: 0.0, 1: 0.0, 2: 100.0, }, - }), + pd.DataFrame( + { + "training_day": { + 0: 0, + 1: 1, + 2: 3, + }, + "num_contingent_trials": { + 0: 0.0, + 1: 0.0, + 2: 100.0, + }, + } + ), False, ), # non-inclusive ( - pd.DataFrame({ - "training_day": {0: 0, 1: 1, 2: 3, }, - "num_contingent_trials": {0: 0.0, 1: 0.0, 2: 300.0, }, - }), + pd.DataFrame( + { + "training_day": { + 0: 0, + 1: 1, + 2: 3, + }, + "num_contingent_trials": { + 0: 0.0, + 1: 0.0, + 2: 300.0, + }, + } + ), False, ), # non-inclusive ( - pd.DataFrame({ - "training_day": {0: 0, 1: 1, 2: 3, }, - "num_contingent_trials": {0: 0.0, 1: 0.0, 2: 301.0, }, - }), + pd.DataFrame( + { + "training_day": { + 0: 0, + 1: 1, + 2: 3, + }, + "num_contingent_trials": { + 0: 0.0, + 1: 0.0, + 2: 301.0, + }, + } + ), True, ), ], @@ -250,20 +481,32 @@ def test_whole_lotta_trials(session_summary, expected): "session_summary, expected", [ ( - pd.DataFrame({ - "training_day": {}, - "num_contingent_trials": {}, - }), + pd.DataFrame( + { + "training_day": {}, + "num_contingent_trials": {}, + } + ), pytest.raises(DataFrameIndexError), ), ( - pd.DataFrame({ - "training_day": {0: 0, 1: 1, 2: 3, }, - "other_col": {0: 0.0, 1: 0.0, 2: 301.0, }, - }), + pd.DataFrame( + { + "training_day": { + 0: 0, + 1: 1, + 2: 3, + }, + "other_col": { + 0: 0.0, + 1: 0.0, + 2: 301.0, + }, + } + ), pytest.raises(DataFrameKeyError), ), - ] + ], ) def test_whole_lotta_trials_exception(session_summary, expected): with expected: @@ -274,35 +517,88 @@ def test_whole_lotta_trials_exception(session_summary, expected): "trials, expected", [ ( - pd.DataFrame({ - "training_day": {0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0, }, # associate all with same training day - "trial_type": {0: "aborted", 1: "go", 2: "catch", 3: "go", }, - "trial_length": {0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0, }, - }), + pd.DataFrame( + { + "training_day": { + 0: 0.0, + 1: 0.0, + 2: 0.0, + 3: 0.0, + }, # associate all with same training day + "trial_type": { + 0: "aborted", + 1: "go", + 2: "catch", + 3: "go", + }, + "trial_length": { + 0: 1.0, + 1: 1.0, + 2: 1.0, + 3: 1.0, + }, + } + ), True, ), ( - pd.DataFrame({ - "training_day": {0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0, }, # associate all with same training day - "trial_type": {0: "aborted", 1: "go", 2: "catch", 3: "aborted", }, - "trial_length": {0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0, }, - }), + pd.DataFrame( + { + "training_day": { + 0: 0.0, + 1: 0.0, + 2: 0.0, + 3: 0.0, + }, # associate all with same training day + "trial_type": { + 0: "aborted", + 1: "go", + 2: "catch", + 3: "aborted", + }, + "trial_length": { + 0: 1.0, + 1: 1.0, + 2: 1.0, + 3: 1.0, + }, + } + ), False, ), # non-inclusive ( - pd.DataFrame({ - "training_day": {0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0, }, # associate all with same training day - "trial_type": {0: "aborted", 1: "go", 2: "aborted", 3: "aborted", }, - "trial_length": {0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0, }, - }), + pd.DataFrame( + { + "training_day": { + 0: 0.0, + 1: 0.0, + 2: 0.0, + 3: 0.0, + }, # associate all with same training day + "trial_type": { + 0: "aborted", + 1: "go", + 2: "aborted", + 3: "aborted", + }, + "trial_length": { + 0: 1.0, + 1: 1.0, + 2: 1.0, + 3: 1.0, + }, + } + ), False, ), ( - pd.DataFrame({ - "training_day": {}, # associate all with same training day - "trial_type": {}, - "trial_length": {}, - }), + pd.DataFrame( + { + "training_day": {}, # associate all with same training day + "trial_type": {}, + "trial_length": {}, + } + ), False, ), ], @@ -315,27 +611,76 @@ def test_mostly_useful(trials, expected): "session_summary, expected", [ ( - pd.DataFrame({ - "task": {0: "Images", 1: "", 2: "", 3: "Images", }, - "dprime_peak": {0: 1.0, 1: 0.0, 2: 0.0, 3: 0.0, }, - "num_engaged_trials": {0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0, }, - }), + pd.DataFrame( + { + "task": { + 0: "Images", + 1: "", + 2: "", + 3: "Images", + }, + "dprime_peak": { + 0: 1.0, + 1: 0.0, + 2: 0.0, + 3: 0.0, + }, + "num_engaged_trials": { + 0: 0.0, + 1: 0.0, + 2: 0.0, + 3: 0.0, + }, + } + ), False, ), ( - pd.DataFrame({ - "task": {0: "", 1: "", 2: "", 3: "Images", 4: "Images", 5: "Images"}, - "dprime_peak": {0: 1.2, 1: 2.0, 2: 1.1, 3: 1.1, 4: 2.0, 5: 1.2, }, - "num_engaged_trials": {0: 101, 1: 102, 2: 200, 3: 101, 4: 102, 5: 99, }, - }), + pd.DataFrame( + { + "task": {0: "", 1: "", 2: "", 3: "Images", 4: "Images", 5: "Images"}, + "dprime_peak": { + 0: 1.2, + 1: 2.0, + 2: 1.1, + 3: 1.1, + 4: 2.0, + 5: 1.2, + }, + "num_engaged_trials": { + 0: 101, + 1: 102, + 2: 200, + 3: 101, + 4: 102, + 5: 99, + }, + } + ), False, ), ( - pd.DataFrame({ - "task": {0: "Images", 1: "Images", 2: "Images", 3: "Images", 4: "Images", 5: "Images"}, - "dprime_peak": {0: 1.2, 1: 2.0, 2: 1.1, 3: 1.1, 4: 2.0, 5: 1.2, }, - "num_engaged_trials": {0: 101, 1: 102, 2: 200, 3: 101, 4: 102, 5: 200, }, - }), + pd.DataFrame( + { + "task": {0: "Images", 1: "Images", 2: "Images", 3: "Images", 4: "Images", 5: "Images"}, + "dprime_peak": { + 0: 1.2, + 1: 2.0, + 2: 1.1, + 3: 1.1, + 4: 2.0, + 5: 1.2, + }, + "num_engaged_trials": { + 0: 101, + 1: 102, + 2: 200, + 3: 101, + 4: 102, + 5: 200, + }, + } + ), True, ), ], @@ -348,19 +693,43 @@ def test_meets_engagement_criteria(session_summary, expected): "session_summary, expected", [ ( - pd.DataFrame({ - "task": {0: "Images", 1: "Images", 2: "Images", 3: "Images", 4: "Images", 5: "Images"}, - "other_metric": {0: 1.2, 1: 2.0, 2: 1.1, 3: 1.1, 4: 2.0, 5: 1.2, }, - "num_engaged_trials": {0: 101, 1: 102, 2: 200, 3: 101, 4: 102, 5: 200, }, - }), + pd.DataFrame( + { + "task": {0: "Images", 1: "Images", 2: "Images", 3: "Images", 4: "Images", 5: "Images"}, + "other_metric": { + 0: 1.2, + 1: 2.0, + 2: 1.1, + 3: 1.1, + 4: 2.0, + 5: 1.2, + }, + "num_engaged_trials": { + 0: 101, + 1: 102, + 2: 200, + 3: 101, + 4: 102, + 5: 200, + }, + } + ), pytest.raises(DataFrameKeyError), ), ( - pd.DataFrame({ - "task": {0: "Images", 1: "Images", }, - "dprime_peak": {0: 1.2, 1: 2.0,}, - "num_engaged_trials": {0: 101, 1: 102}, - }), + pd.DataFrame( + { + "task": { + 0: "Images", + 1: "Images", + }, + "dprime_peak": { + 0: 1.2, + 1: 2.0, + }, + "num_engaged_trials": {0: 101, 1: 102}, + } + ), pytest.raises(DataFrameIndexError), ), ], @@ -373,9 +742,42 @@ def test_meets_engagement_criteria_exception(session_summary, expected): @pytest.mark.parametrize( "trials, expected", [ - (pd.DataFrame({'training_day': {0: 0, 1: 1, 2: 3, }, }), False, ), - (pd.DataFrame({'training_day': {0: 0, 1: 1, 2: 40, }, }), True, ), # inclusive - (pd.DataFrame({'training_day': {0: 0, 1: 1, 2: 41, }, }), True, ), + ( + pd.DataFrame( + { + "training_day": { + 0: 0, + 1: 1, + 2: 3, + }, + } + ), + False, + ), + ( + pd.DataFrame( + { + "training_day": { + 0: 0, + 1: 1, + 2: 40, + }, + } + ), + True, + ), # inclusive + ( + pd.DataFrame( + { + "training_day": { + 0: 0, + 1: 1, + 2: 41, + }, + } + ), + True, + ), ], ) def test_summer_over(trials, expected): diff --git a/allensdk/test/brain_observatory/behavior/test_dprime.py b/allensdk/test/brain_observatory/behavior/test_dprime.py index e67037baff..1d3d7ab70a 100644 --- a/allensdk/test/brain_observatory/behavior/test_dprime.py +++ b/allensdk/test/brain_observatory/behavior/test_dprime.py @@ -6,21 +6,21 @@ from collections import defaultdict -from allensdk.brain_observatory.behavior.dprime import \ - get_hit_rate, \ - get_false_alarm_rate, \ - get_rolling_dprime, \ - get_trial_count_corrected_false_alarm_rate, \ - get_trial_count_corrected_hit_rate, \ - get_dprime +from allensdk.brain_observatory.behavior.dprime import ( + get_hit_rate, + get_false_alarm_rate, + get_rolling_dprime, + get_trial_count_corrected_false_alarm_rate, + get_trial_count_corrected_hit_rate, + get_dprime, +) -NaN = float('nan') +NaN = float("nan") @pytest.fixture def mock_trials_fixture(): - n_tr = 500 np.random.seed(42) change = np.random.random(n_tr) > 0.8 @@ -28,39 +28,39 @@ def mock_trials_fixture(): detect = change.copy() detect[incorrect] = ~detect[incorrect] - trials = pd.DataFrame({ - 'change': change, - 'detect': detect, - },) - trials['trial_type'] = trials['change'].map(lambda x: ['catch', 'go'][x]) - trials['response'] = trials['detect'] - trials['change_time'] = np.sort(np.random.rand(n_tr)) * 3600 - trials['reward_lick_latency'] = 0.1 - trials['reward_lick_count'] = 10 - trials['auto_rewarded'] = False - trials['lick_frames'] = [[] for row in trials.iterrows()] - trials['trial_length'] = 8.5 - trials['reward_times'] = trials.apply( - lambda r: [r['change_time']+0.2] if r['change']*r['detect'] else [], - axis=1) - trials['reward_volume'] = 0.005 * trials['reward_times'].map(len) - trials['response_latency'] = trials.apply( - lambda r: 0.2 if r['detect'] else np.nan, axis=1) - trials['blank_duration_range'] = [[0.5, 0.5] for row in trials.iterrows()] + trials = pd.DataFrame( + { + "change": change, + "detect": detect, + }, + ) + trials["trial_type"] = trials["change"].map(lambda x: ["catch", "go"][x]) + trials["response"] = trials["detect"] + trials["change_time"] = np.sort(np.random.rand(n_tr)) * 3600 + trials["reward_lick_latency"] = 0.1 + trials["reward_lick_count"] = 10 + trials["auto_rewarded"] = False + trials["lick_frames"] = [[] for row in trials.iterrows()] + trials["trial_length"] = 8.5 + trials["reward_times"] = trials.apply( + lambda r: [r["change_time"] + 0.2] if r["change"] * r["detect"] else [], axis=1 + ) + trials["reward_volume"] = 0.005 * trials["reward_times"].map(len) + trials["response_latency"] = trials.apply(lambda r: 0.2 if r["detect"] else np.nan, axis=1) + trials["blank_duration_range"] = [[0.5, 0.5] for row in trials.iterrows()] metadata = {} - metadata['mouse_id'] = 'M999999' - metadata['user_id'] = 'johnd' + metadata["mouse_id"] = "M999999" + metadata["user_id"] = "johnd" - metadata['startdatetime'] = datetime.datetime( - 2017, 7, 19, 10, 35, 8, 369000, tzinfo=pytz.utc) - metadata['dayofweek'] = metadata['startdatetime'].weekday() - metadata['startdatetime'] = metadata['startdatetime'] + metadata["startdatetime"] = datetime.datetime(2017, 7, 19, 10, 35, 8, 369000, tzinfo=pytz.utc) + metadata["dayofweek"] = metadata["startdatetime"].weekday() + metadata["startdatetime"] = metadata["startdatetime"] - metadata['behavior_session_uuid'] = 12345 - metadata['stage'] = 'test' - metadata['stimulus'] = 'natural_scenes' - metadata['stimulus_distribution'] = 'exponential' + metadata["behavior_session_uuid"] = 12345 + metadata["stage"] = "test" + metadata["stimulus"] = "natural_scenes" + metadata["stimulus_distribution"] = "exponential" for k, v in metadata.items(): trials[k] = v @@ -69,78 +69,63 @@ def mock_trials_fixture(): @pytest.fixture def mock_rolling_dprime_fixture(mock_trials_fixture): - data_dict = defaultdict(list) - for ri, row in mock_trials_fixture[['trial_type', 'response', - 'change_time']].iterrows(): - assert not pd.isnull(row['change_time']) - if row['trial_type'] == 'go' and row['response']: + for ri, row in mock_trials_fixture[["trial_type", "response", "change_time"]].iterrows(): + assert not pd.isnull(row["change_time"]) + if row["trial_type"] == "go" and row["response"]: hit = True miss = false_alarm = correct_reject = False - elif row['trial_type'] == 'go' and not row['response']: + elif row["trial_type"] == "go" and not row["response"]: miss = True hit = false_alarm = correct_reject = False - elif row['trial_type'] == 'catch' and row['response']: + elif row["trial_type"] == "catch" and row["response"]: false_alarm = True miss = hit = correct_reject = False - elif row['trial_type'] == 'catch' and not row['response']: + elif row["trial_type"] == "catch" and not row["response"]: correct_reject = True hit = false_alarm = miss = False else: raise RuntimeError - data_dict['hit'].append(hit) - data_dict['miss'].append(miss) - data_dict['false_alarm'].append(false_alarm) - data_dict['correct_reject'].append(correct_reject) - data_dict['aborted'].append(False) + data_dict["hit"].append(hit) + data_dict["miss"].append(miss) + data_dict["false_alarm"].append(false_alarm) + data_dict["correct_reject"].append(correct_reject) + data_dict["aborted"].append(False) return pd.DataFrame(data_dict) def test_get_hit_rate(): + hit, miss, aborted = ([0, 1, 0, 0, 0, 1], [1, 0, 0, 0, 1, 0], [0, 0, 1, 1, 0, 0]) - hit, miss, aborted = ( - [0, 1, 0, 0, 0, 1], - [1, 0, 0, 0, 1, 0], - [0, 0, 1, 1, 0, 0]) - - result = get_hit_rate(hit=hit, miss=miss, aborted=aborted, - sliding_window=3) - np.testing.assert_allclose(result, [0, .5, 1/3, 2/3]) + result = get_hit_rate(hit=hit, miss=miss, aborted=aborted, sliding_window=3) + np.testing.assert_allclose(result, [0, 0.5, 1 / 3, 2 / 3]) def test_get_false_alarm_rate(mock_trials_fixture): + false_alarm, correct_reject, aborted = ([0, 1, 0, 0, 0, 1], [1, 0, 0, 0, 1, 0], [0, 0, 1, 1, 0, 0]) - false_alarm, correct_reject, aborted = ( - [0, 1, 0, 0, 0, 1], - [1, 0, 0, 0, 1, 0], - [0, 0, 1, 1, 0, 0]) - - result = get_false_alarm_rate(false_alarm=false_alarm, - correct_reject=correct_reject, - aborted=aborted, - sliding_window=3) - np.testing.assert_allclose(result, [0, .5, 1/3, 2/3]) + result = get_false_alarm_rate( + false_alarm=false_alarm, correct_reject=correct_reject, aborted=aborted, sliding_window=3 + ) + np.testing.assert_allclose(result, [0, 0.5, 1 / 3, 2 / 3]) def test_rolling_dprime_unit(): - hit, miss, false_alarm, correct_reject, aborted = ( [0, 0, 1, 0, 0, 1], [1, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 1, 0], - [0, 0, 0, 0, 0, 0]) + [0, 0, 0, 0, 0, 0], + ) hr = get_hit_rate(hit=hit, miss=miss, aborted=aborted, sliding_window=3) far = get_false_alarm_rate( - false_alarm=false_alarm, - correct_reject=correct_reject, - aborted=aborted, - sliding_window=3) + false_alarm=false_alarm, correct_reject=correct_reject, aborted=aborted, sliding_window=3 + ) result = get_rolling_dprime(hr, far) - np.testing.assert_allclose(result, - [NaN, NaN, NaN, 2.326348, 4.652696, 4.652696]) + np.testing.assert_allclose(result, [NaN, NaN, NaN, 2.326348, 4.652696, 4.652696]) def test_rolling_dprime_integration_legacy(mock_rolling_dprime_fixture): @@ -152,12 +137,10 @@ def test_rolling_dprime_integration_legacy(mock_rolling_dprime_fixture): correct_reject = mock_rolling_dprime_fixture.correct_reject aborted = mock_rolling_dprime_fixture.aborted - hr = get_hit_rate(hit=hit, miss=miss, aborted=aborted, - sliding_window=sliding_window) - cr = get_false_alarm_rate(false_alarm=false_alarm, - correct_reject=correct_reject, - aborted=aborted, - sliding_window=sliding_window) + hr = get_hit_rate(hit=hit, miss=miss, aborted=aborted, sliding_window=sliding_window) + cr = get_false_alarm_rate( + false_alarm=false_alarm, correct_reject=correct_reject, aborted=aborted, sliding_window=sliding_window + ) dprime = get_rolling_dprime(hr, cr) assert dprime[2] == 4.6526957480816815 @@ -172,25 +155,24 @@ def test_rolling_dprime_integration(mock_rolling_dprime_fixture): correct_reject = mock_rolling_dprime_fixture.correct_reject aborted = mock_rolling_dprime_fixture.aborted - hr = get_trial_count_corrected_hit_rate(hit=hit, miss=miss, - aborted=aborted, - sliding_window=sliding_window) + hr = get_trial_count_corrected_hit_rate(hit=hit, miss=miss, aborted=aborted, sliding_window=sliding_window) cr = get_trial_count_corrected_false_alarm_rate( - false_alarm=false_alarm, - correct_reject=correct_reject, - aborted=aborted, - sliding_window=sliding_window) + false_alarm=false_alarm, correct_reject=correct_reject, aborted=aborted, sliding_window=sliding_window + ) dprime = get_rolling_dprime(hr, cr) assert dprime[2] == 0.6744897501960817 -@pytest.mark.parametrize('hr, far, dprime', [ - pytest.param(1., 1., 0.), - pytest.param(.5, .5, 0.), - pytest.param(.25, .5, -0.6744897501960817), - pytest.param(.5, .25, 0.6744897501960817), -]) +@pytest.mark.parametrize( + "hr, far, dprime", + [ + pytest.param(1.0, 1.0, 0.0), + pytest.param(0.5, 0.5, 0.0), + pytest.param(0.25, 0.5, -0.6744897501960817), + pytest.param(0.5, 0.25, 0.6744897501960817), + ], +) def test_dprime(hr, far, dprime): val = get_dprime(hr, far) assert val == dprime diff --git a/allensdk/test/brain_observatory/behavior/test_event_detection.py b/allensdk/test/brain_observatory/behavior/test_event_detection.py index 9d1b4b2fae..4443c1ef06 100644 --- a/allensdk/test/brain_observatory/behavior/test_event_detection.py +++ b/allensdk/test/brain_observatory/behavior/test_event_detection.py @@ -1,8 +1,7 @@ import numpy as np import pytest -from allensdk.brain_observatory.behavior.event_detection import \ - filter_events_array +from allensdk.brain_observatory.behavior.event_detection import filter_events_array def test_filter_events_array(): diff --git a/allensdk/test/brain_observatory/behavior/test_eye_tracking_processing.py b/allensdk/test/brain_observatory/behavior/test_eye_tracking_processing.py index d7fd01ca6f..d8b0463101 100644 --- a/allensdk/test/brain_observatory/behavior/test_eye_tracking_processing.py +++ b/allensdk/test/brain_observatory/behavior/test_eye_tracking_processing.py @@ -6,9 +6,14 @@ import pandas as pd from allensdk.brain_observatory.behavior.eye_tracking_processing import ( - load_eye_tracking_hdf, determine_outliers, compute_circular_area, - compute_elliptical_area, determine_likely_blinks, - process_eye_tracking_data, EyeTrackingError) + load_eye_tracking_hdf, + determine_outliers, + compute_circular_area, + compute_elliptical_area, + determine_likely_blinks, + process_eye_tracking_data, + EyeTrackingError, +) def create_preload_eye_tracking_df(data: np.ndarray) -> pd.DataFrame: @@ -17,12 +22,25 @@ def create_preload_eye_tracking_df(data: np.ndarray) -> pd.DataFrame: def create_loaded_eye_tracking_df(data: np.ndarray) -> pd.DataFrame: - columns = ["cr_center_x", "cr_center_y", "cr_width", "cr_height", "cr_phi", - "eye_center_x", "eye_center_y", "eye_width", "eye_height", - "eye_phi", "pupil_center_x", "pupil_center_y", "pupil_width", - "pupil_height", "pupil_phi"] + columns = [ + "cr_center_x", + "cr_center_y", + "cr_width", + "cr_height", + "cr_phi", + "eye_center_x", + "eye_center_y", + "eye_width", + "eye_height", + "eye_phi", + "pupil_center_x", + "pupil_center_y", + "pupil_width", + "pupil_height", + "pupil_phi", + ] df = pd.DataFrame(data, columns=columns) - df.index.name = 'frame' + df.index.name = "frame" return df @@ -32,17 +50,36 @@ def create_area_df(data: np.ndarray) -> pd.DataFrame: def create_refined_eye_tracking_df(data: np.ndarray) -> pd.DataFrame: - columns = ["timestamps", "cr_area", "eye_area", "pupil_area", - "likely_blink", "pupil_area_raw", "cr_area_raw", "eye_area_raw", - "cr_center_x", "cr_center_y", "cr_width", "cr_height", "cr_phi", - "eye_center_x", "eye_center_y", "eye_width", "eye_height", - "eye_phi", "pupil_center_x", "pupil_center_y", "pupil_width", - "pupil_height", "pupil_phi"] + columns = [ + "timestamps", + "cr_area", + "eye_area", + "pupil_area", + "likely_blink", + "pupil_area_raw", + "cr_area_raw", + "eye_area_raw", + "cr_center_x", + "cr_center_y", + "cr_width", + "cr_height", + "cr_phi", + "eye_center_x", + "eye_center_y", + "eye_width", + "eye_height", + "eye_phi", + "pupil_center_x", + "pupil_center_y", + "pupil_width", + "pupil_height", + "pupil_phi", + ] df = pd.DataFrame(data, columns=columns) - df.index.name = 'frame' + df.index.name = "frame" # Initializing a df coerces all data to one dtype # restoring the bool dtype for the 'likely_blink' column. - df['likely_blink'] = df['likely_blink'].apply(bool) + df["likely_blink"] = df["likely_blink"].apply(bool) return df @@ -63,75 +100,80 @@ def hdf_fixture(request, tmp_path) -> Path: return tmp_hdf_path -@pytest.mark.parametrize("hdf_fixture, expected", [ - ({"cr": np.array([[1., 2., 3., 4., 5.]]), - "eye": np.array([[6., 7., 8., 9., 10.]]), - "pupil": np.array([[11., 12., 13., 14., 15.]])}, - - create_loaded_eye_tracking_df( - np.array([[1., 2., 3., 4., 5., 6., 7., 8., - 9., 10., 11., 12., 13., 14., 15.]])) - ), - - ({"cr": np.array([[5 + 2j, 4 + 1j, 3 + 1j, 2 + 8j, 1 + 1j]]), - "eye": np.array([[6, 7, 8, 9, 10]]), - "pupil": np.array([[15 + 1j, 14 + 3j, 13 + 2j, 12 + 1j, 11 + 1j]])}, - - create_loaded_eye_tracking_df( - np.array([[5., 4., 3., 2., 1., 6., 7., 8., - 9., 10., 15., 14., 13., 12., 11.]])) - ), - -], indirect=["hdf_fixture"]) +@pytest.mark.parametrize( + "hdf_fixture, expected", + [ + ( + { + "cr": np.array([[1.0, 2.0, 3.0, 4.0, 5.0]]), + "eye": np.array([[6.0, 7.0, 8.0, 9.0, 10.0]]), + "pupil": np.array([[11.0, 12.0, 13.0, 14.0, 15.0]]), + }, + create_loaded_eye_tracking_df( + np.array([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0]]) + ), + ), + ( + { + "cr": np.array([[5 + 2j, 4 + 1j, 3 + 1j, 2 + 8j, 1 + 1j]]), + "eye": np.array([[6, 7, 8, 9, 10]]), + "pupil": np.array([[15 + 1j, 14 + 3j, 13 + 2j, 12 + 1j, 11 + 1j]]), + }, + create_loaded_eye_tracking_df( + np.array([[5.0, 4.0, 3.0, 2.0, 1.0, 6.0, 7.0, 8.0, 9.0, 10.0, 15.0, 14.0, 13.0, 12.0, 11.0]]) + ), + ), + ], + indirect=["hdf_fixture"], +) def test_load_eye_tracking_hdf(hdf_fixture: Path, expected: pd.DataFrame): obtained = load_eye_tracking_hdf(hdf_fixture) assert expected.equals(obtained) -@pytest.mark.parametrize("data_df, z_threshold, expected", [ - (create_area_df( - np.array([[1, 1, 2], - [2, 2, 1], - [1, 7, 3], - [1, 1, 1], - [1, 3, 2], - [1, 1, 1], - [1, 2, 1], - [2, 1, 1000]])), - 2.5, - pd.Series([False, False, False, False, False, False, False, True])), - - (create_area_df( - np.array([[1, 1, 2], - [2, 2, 1], - [1, 7, 3], - [1, 1, 1], - [1, 3, 2], - [1, 1, 1], - [1, 2, 1], - [2, 1, 1000]])), - 2.0, - pd.Series([False, False, True, False, False, False, False, True])), - -]) +@pytest.mark.parametrize( + "data_df, z_threshold, expected", + [ + ( + create_area_df( + np.array([[1, 1, 2], [2, 2, 1], [1, 7, 3], [1, 1, 1], [1, 3, 2], [1, 1, 1], [1, 2, 1], [2, 1, 1000]]) + ), + 2.5, + pd.Series([False, False, False, False, False, False, False, True]), + ), + ( + create_area_df( + np.array([[1, 1, 2], [2, 2, 1], [1, 7, 3], [1, 1, 1], [1, 3, 2], [1, 1, 1], [1, 2, 1], [2, 1, 1000]]) + ), + 2.0, + pd.Series([False, False, True, False, False, False, False, True]), + ), + ], +) def test_determine_outliers(data_df, z_threshold, expected): obtained = determine_outliers(data_df, z_threshold) assert np.allclose(obtained, expected) -@pytest.mark.parametrize("df_row, expected", [ - (pd.Series([3, 2], index=["width", "height"]), 9 * np.pi), - (pd.Series([2, 3], index=["width", "height"]), 9 * np.pi), -]) +@pytest.mark.parametrize( + "df_row, expected", + [ + (pd.Series([3, 2], index=["width", "height"]), 9 * np.pi), + (pd.Series([2, 3], index=["width", "height"]), 9 * np.pi), + ], +) def test_compute_circular_area(df_row: pd.Series, expected: float): obtained_area = compute_circular_area(df_row) assert obtained_area == expected -@pytest.mark.parametrize("df_row, expected", [ - (pd.Series([3, 2], index=["width", "height"]), 6 * np.pi), - (pd.Series([2, 3], index=["width", "height"]), 6 * np.pi), -]) +@pytest.mark.parametrize( + "df_row, expected", + [ + (pd.Series([3, 2], index=["width", "height"]), 6 * np.pi), + (pd.Series([2, 3], index=["width", "height"]), 6 * np.pi), + ], +) def test_compute_elliptical_area(df_row: pd.Series, expected: float): obtained_area = compute_elliptical_area(df_row) assert obtained_area == expected @@ -140,90 +182,147 @@ def test_compute_elliptical_area(df_row: pd.Series, expected: float): @pytest.mark.parametrize( "eye_areas, pupil_areas, outliers, dilation_frames, expected", [ - (pd.Series([4, 8, 3, 20, np.nan, 10, 21, 19, 42]), - pd.Series([np.nan, 10, 2, 30, 99, 80, 93, 18, 777]), - pd.Series([False, False, False, False, False, False, - False, False, True]), - 2, - pd.Series([True, True, True, True, True, True, True, True, True])), - - (pd.Series([4, 8, 3, 20, np.nan, 10, 21, 19, 42]), - pd.Series([np.nan, 10, 2, 30, 99, 80, 93, 18, 777]), - pd.Series([False, False, False, False, False, False, - False, False, True]), - 1, - pd.Series([True, True, False, True, True, True, False, True, True])), - - - (pd.Series([4, 8, 3, 20, np.nan, 10, 21, 19, 42]), - pd.Series([np.nan, 10, 2, 30, 99, 80, 93, 18, 777]), - pd.Series([False, False, False, False, False, False, - False, False, True]), - 0, - pd.Series([True, False, False, False, True, False, False, - False, True])), - ]) -def test_determine_likely_blinks(eye_areas, pupil_areas, outliers, - dilation_frames, expected): - obtained = determine_likely_blinks(eye_areas, pupil_areas, outliers, - dilation_frames) + ( + pd.Series([4, 8, 3, 20, np.nan, 10, 21, 19, 42]), + pd.Series([np.nan, 10, 2, 30, 99, 80, 93, 18, 777]), + pd.Series([False, False, False, False, False, False, False, False, True]), + 2, + pd.Series([True, True, True, True, True, True, True, True, True]), + ), + ( + pd.Series([4, 8, 3, 20, np.nan, 10, 21, 19, 42]), + pd.Series([np.nan, 10, 2, 30, 99, 80, 93, 18, 777]), + pd.Series([False, False, False, False, False, False, False, False, True]), + 1, + pd.Series([True, True, False, True, True, True, False, True, True]), + ), + ( + pd.Series([4, 8, 3, 20, np.nan, 10, 21, 19, 42]), + pd.Series([np.nan, 10, 2, 30, 99, 80, 93, 18, 777]), + pd.Series([False, False, False, False, False, False, False, False, True]), + 0, + pd.Series([True, False, False, False, True, False, False, False, True]), + ), + ], +) +def test_determine_likely_blinks(eye_areas, pupil_areas, outliers, dilation_frames, expected): + obtained = determine_likely_blinks(eye_areas, pupil_areas, outliers, dilation_frames) assert expected.equals(obtained) -@pytest.mark.parametrize("eye_tracking_df, frame_times", [ - (create_loaded_eye_tracking_df( - np.array([[1, 1, 2, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 2, 1], - [2, 2, 1, 1, 2, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2]])), - pd.Series(np.arange(0, 1.8, 0.1))), -]) -def test_process_eye_tracking_data_raises_on_sync_error(eye_tracking_df, - frame_times): +@pytest.mark.parametrize( + "eye_tracking_df, frame_times", + [ + ( + create_loaded_eye_tracking_df( + np.array([[1, 1, 2, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 2, 1], [2, 2, 1, 1, 2, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2]]) + ), + pd.Series(np.arange(0, 1.8, 0.1)), + ), + ], +) +def test_process_eye_tracking_data_raises_on_sync_error(eye_tracking_df, frame_times): """ Test that an error is raised when the number of sync timestamps exceeds the number of eye tracking frames by more than 15 """ - with pytest.raises(EyeTrackingError, - match='Error! The number of sync file'): + with pytest.raises(EyeTrackingError, match="Error! The number of sync file"): process_eye_tracking_data(eye_tracking_df, frame_times) -@pytest.mark.parametrize("eye_tracking_df, frame_times", [ - (create_loaded_eye_tracking_df( - np.array([[1, 1, 2, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 2, 1], - [2, 2, 1, 1, 2, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2]])), - pd.Series(np.arange(0, 1.7, 0.1))), -]) -def test_process_eye_tracking_data_truncation(eye_tracking_df, - frame_times): +@pytest.mark.parametrize( + "eye_tracking_df, frame_times", + [ + ( + create_loaded_eye_tracking_df( + np.array([[1, 1, 2, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 2, 1], [2, 2, 1, 1, 2, 2, 1, 2, 2, 1, 2, 1, 2, 1, 2]]) + ), + pd.Series(np.arange(0, 1.7, 0.1)), + ), + ], +) +def test_process_eye_tracking_data_truncation(eye_tracking_df, frame_times): """ Test that the array of sync times is truncated when the number of raw sync timestamps exceeds the numer of eye tracking frames by <= 15 """ df = process_eye_tracking_data(eye_tracking_df, frame_times) - np.testing.assert_array_almost_equal(df.timestamps.to_numpy(), - np.array([0.0, 0.1]), - decimal=10) - - -@pytest.mark.parametrize("eye_tracking_df, frame_times, expected", [ - (create_loaded_eye_tracking_df( - np.array([[1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., - 12., 13., 14., 15.], - [2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., - 13., 14., 15., 16.]])), - pd.Series([0.1, 0.2]), - create_refined_eye_tracking_df( - np.array([[0.1, 12 * np.pi, 72 * np.pi, 196 * np.pi, False, - 196 * np.pi, 12 * np.pi, 72 * np.pi, - 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., - 13., 14., 15.], - [0.2, 20 * np.pi, 90 * np.pi, 225 * np.pi, False, - 225 * np.pi, 20 * np.pi, 90 * np.pi, - 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., - 14., 15., 16.]])) - ), -]) + np.testing.assert_array_almost_equal(df.timestamps.to_numpy(), np.array([0.0, 0.1]), decimal=10) + + +@pytest.mark.parametrize( + "eye_tracking_df, frame_times, expected", + [ + ( + create_loaded_eye_tracking_df( + np.array( + [ + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0], + [2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0], + ] + ) + ), + pd.Series([0.1, 0.2]), + create_refined_eye_tracking_df( + np.array( + [ + [ + 0.1, + 12 * np.pi, + 72 * np.pi, + 196 * np.pi, + False, + 196 * np.pi, + 12 * np.pi, + 72 * np.pi, + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + 6.0, + 7.0, + 8.0, + 9.0, + 10.0, + 11.0, + 12.0, + 13.0, + 14.0, + 15.0, + ], + [ + 0.2, + 20 * np.pi, + 90 * np.pi, + 225 * np.pi, + False, + 225 * np.pi, + 20 * np.pi, + 90 * np.pi, + 2.0, + 3.0, + 4.0, + 5.0, + 6.0, + 7.0, + 8.0, + 9.0, + 10.0, + 11.0, + 12.0, + 13.0, + 14.0, + 15.0, + 16.0, + ], + ] + ) + ), + ), + ], +) def test_process_eye_tracking_data(eye_tracking_df, frame_times, expected): obtained = process_eye_tracking_data(eye_tracking_df, frame_times) pd.testing.assert_frame_equal(obtained, expected) diff --git a/allensdk/test/brain_observatory/behavior/test_incomplete_data_objects.py b/allensdk/test/brain_observatory/behavior/test_incomplete_data_objects.py index c32a115a52..5f604cb42c 100644 --- a/allensdk/test/brain_observatory/behavior/test_incomplete_data_objects.py +++ b/allensdk/test/brain_observatory/behavior/test_incomplete_data_objects.py @@ -1,21 +1,14 @@ import pytest import copy -from allensdk.brain_observatory.behavior.\ - data_objects.eye_tracking.eye_tracking_table import EyeTrackingTable -from allensdk.brain_observatory.behavior.\ - data_objects.licks import Licks -from allensdk.brain_observatory.behavior.\ - data_objects.rewards import Rewards -from allensdk.brain_observatory.behavior.behavior_ophys_experiment import ( - BehaviorOphysExperiment) +from allensdk.brain_observatory.behavior.data_objects.eye_tracking.eye_tracking_table import EyeTrackingTable +from allensdk.brain_observatory.behavior.data_objects.licks import Licks +from allensdk.brain_observatory.behavior.data_objects.rewards import Rewards +from allensdk.brain_observatory.behavior.behavior_ophys_experiment import BehaviorOphysExperiment @pytest.mark.requires_bamboo -def test_incomplete_eye_tracking( - behavior_ophys_experiment_fixture, - skeletal_nwb_fixture): - +def test_incomplete_eye_tracking(behavior_ophys_experiment_fixture, skeletal_nwb_fixture): populated_eye_tracking = behavior_ophys_experiment_fixture.eye_tracking empty_eye_tracking = EyeTrackingTable.from_nwb(skeletal_nwb_fixture) empty_eye_tracking_df = empty_eye_tracking.value @@ -27,25 +20,23 @@ def test_incomplete_eye_tracking( empty_columns = set(empty_eye_tracking_df.columns) assert populated_columns == empty_columns - assert (populated_eye_tracking.index.name - == empty_eye_tracking_df.index.name) + assert populated_eye_tracking.index.name == empty_eye_tracking_df.index.name # make sure, when round-tripping the experiments, that the # populated experiment still writes out a populated data frame nwb1 = behavior_ophys_experiment_fixture.to_nwb() - assert 'EyeTracking' in nwb1.acquisition.keys() + assert "EyeTracking" in nwb1.acquisition.keys() roundtrip1 = BehaviorOphysExperiment.from_nwb(nwb1) assert len(roundtrip1.eye_tracking) > 0 nwb2 = empty_eye_tracking.to_nwb(skeletal_nwb_fixture) - assert 'EyeTracking' not in nwb2.acquisition.keys() + assert "EyeTracking" not in nwb2.acquisition.keys() roundtrip2 = EyeTrackingTable.from_nwb(nwb2) assert len(roundtrip2.value) == 0 @pytest.mark.requires_bamboo -def test_incomplete_eye_tracking_from_lims( - behavior_ophys_experiment_fixture): +def test_incomplete_eye_tracking_from_lims(behavior_ophys_experiment_fixture): """ Compare a BehaviorOphysExperiment without eye tracking data from_lims with a BehaviorOphysExperiment with eye tracking data. @@ -53,8 +44,7 @@ def test_incomplete_eye_tracking_from_lims( """ incomplete_exp_id = 806456687 - incomplete_experiment = BehaviorOphysExperiment.from_lims( - incomplete_exp_id) + incomplete_experiment = BehaviorOphysExperiment.from_lims(incomplete_exp_id) complete = behavior_ophys_experiment_fixture.eye_tracking incomplete = incomplete_experiment.eye_tracking @@ -66,10 +56,7 @@ def test_incomplete_eye_tracking_from_lims( @pytest.mark.requires_bamboo -def test_incomplete_licks( - behavior_ophys_experiment_fixture, - skeletal_nwb_fixture): - +def test_incomplete_licks(behavior_ophys_experiment_fixture, skeletal_nwb_fixture): populated_licks = behavior_ophys_experiment_fixture.licks empty_licks = Licks.from_nwb(skeletal_nwb_fixture) empty_licks_df = empty_licks.value @@ -86,21 +73,18 @@ def test_incomplete_licks( # make sure, when round-tripping the experiments, that the # populated experiment still writes out a populated data frame nwb1 = behavior_ophys_experiment_fixture.to_nwb() - assert 'licking' in nwb1.processing.keys() + assert "licking" in nwb1.processing.keys() roundtrip1 = BehaviorOphysExperiment.from_nwb(nwb1) assert len(roundtrip1.licks) > 0 nwb2 = empty_licks.to_nwb(skeletal_nwb_fixture) - assert 'licking' not in nwb2.processing.keys() + assert "licking" not in nwb2.processing.keys() roundtrip2 = Licks.from_nwb(nwb2) assert len(roundtrip2.value) == 0 @pytest.mark.requires_bamboo -def test_incomplete_rewards( - behavior_ophys_experiment_fixture, - skeletal_nwb_fixture): - +def test_incomplete_rewards(behavior_ophys_experiment_fixture, skeletal_nwb_fixture): populated_rewards = behavior_ophys_experiment_fixture.rewards empty_rewards = Rewards.from_nwb(skeletal_nwb_fixture) empty_rewards_df = empty_rewards.value @@ -117,23 +101,19 @@ def test_incomplete_rewards( # make sure, when round-tripping the experiments, that the # populated experiment still writes out a populated data frame nwb1 = behavior_ophys_experiment_fixture.to_nwb() - assert 'rewards' in nwb1.processing.keys() + assert "rewards" in nwb1.processing.keys() roundtrip1 = BehaviorOphysExperiment.from_nwb(nwb1) assert len(roundtrip1.rewards) > 0 nwb2 = empty_rewards.to_nwb(skeletal_nwb_fixture) - assert 'rewards' not in nwb2.processing.keys() + assert "rewards" not in nwb2.processing.keys() roundtrip2 = Rewards.from_nwb(nwb2) assert len(roundtrip2.value) == 0 @pytest.mark.requires_bamboo -def test_incomplete_rig_geometry( - behavior_ophys_experiment_fixture, - skeletal_nwb_fixture): - - populated_rig_geom = behavior_ophys_experiment_fixture.\ - eye_tracking_rig_geometry +def test_incomplete_rig_geometry(behavior_ophys_experiment_fixture, skeletal_nwb_fixture): + populated_rig_geom = behavior_ophys_experiment_fixture.eye_tracking_rig_geometry assert len(populated_rig_geom) > 0 @@ -147,22 +127,22 @@ def test_incomplete_rig_geometry( no_rig_geom._eye_tracking_rig_geometry = None # make sure we didn't alter the fixture - assert behavior_ophys_experiment_fixture.\ - _eye_tracking_rig_geometry is not None + assert behavior_ophys_experiment_fixture._eye_tracking_rig_geometry is not None nwb = no_rig_geom.to_nwb() test_ophys_experiment = BehaviorOphysExperiment.from_nwb(nwb) - assert 'eye_tracking_rig_metadata' not in nwb.processing.keys() + assert "eye_tracking_rig_metadata" not in nwb.processing.keys() assert len(test_ophys_experiment.eye_tracking_rig_geometry) == 0 assert isinstance( - test_ophys_experiment.eye_tracking_rig_geometry, - type(behavior_ophys_experiment_fixture.eye_tracking_rig_geometry)) + test_ophys_experiment.eye_tracking_rig_geometry, + type(behavior_ophys_experiment_fixture.eye_tracking_rig_geometry), + ) # make sure that the populated BehaviorOphysExperiment actually # writes out RigGeometry to NWB files nwb = behavior_ophys_experiment_fixture.to_nwb() - assert 'eye_tracking_rig_metadata' in nwb.processing.keys() + assert "eye_tracking_rig_metadata" in nwb.processing.keys() roundtrip = BehaviorOphysExperiment.from_nwb(nwb) assert len(roundtrip.eye_tracking_rig_geometry) > 0 diff --git a/allensdk/test/brain_observatory/behavior/test_mtrain_annotate.py b/allensdk/test/brain_observatory/behavior/test_mtrain_annotate.py index 983a419aef..c34ed18566 100644 --- a/allensdk/test/brain_observatory/behavior/test_mtrain_annotate.py +++ b/allensdk/test/brain_observatory/behavior/test_mtrain_annotate.py @@ -6,13 +6,10 @@ @pytest.fixture def trials(): - return pd.DataFrame({ - 'trial_type': ['go', 'catch', 'go', 'catch'], - 'response': [1.0, 1.0, 0.0, 0.0]}) + return pd.DataFrame({"trial_type": ["go", "catch", "go", "catch"], "response": [1.0, 1.0, 0.0, 0.0]}) def test_annotate_change_detect(trials): - annotate_change_detect(trials) - pd.testing.assert_series_equal(trials['change'], pd.Series([True, False, True, False], name='change')) - pd.testing.assert_series_equal(trials['detect'], pd.Series([True, True, False, False], name='detect')) + pd.testing.assert_series_equal(trials["change"], pd.Series([True, False, True, False], name="change")) + pd.testing.assert_series_equal(trials["detect"], pd.Series([True, True, False, False], name="detect")) diff --git a/allensdk/test/brain_observatory/behavior/test_prior_exposure_count_processing.py b/allensdk/test/brain_observatory/behavior/test_prior_exposure_count_processing.py index 82830eb8ba..0c891dfa6b 100644 --- a/allensdk/test/brain_observatory/behavior/test_prior_exposure_count_processing.py +++ b/allensdk/test/brain_observatory/behavior/test_prior_exposure_count_processing.py @@ -86,88 +86,57 @@ def get_behavior_stage_parameters(self, foraging_ids): def test_add_experience_level(): - input_data = [] expected_data = [] - datum = {'id': 0, - 'session_number': 1, - 'prior_exposures_to_image_set': 4, - 'session_type': 'OPHYS_1'} + datum = {"id": 0, "session_number": 1, "prior_exposures_to_image_set": 4, "session_type": "OPHYS_1"} input_data.append(copy.deepcopy(datum)) - datum['experience_level'] = 'Familiar' + datum["experience_level"] = "Familiar" expected_data.append(copy.deepcopy(datum)) - datum = {'id': 1, - 'session_number': 2, - 'prior_exposures_to_image_set': 5, - 'session_type': 'OPHYS_2'} + datum = {"id": 1, "session_number": 2, "prior_exposures_to_image_set": 5, "session_type": "OPHYS_2"} input_data.append(copy.deepcopy(datum)) - datum['experience_level'] = 'Familiar' + datum["experience_level"] = "Familiar" expected_data.append(copy.deepcopy(datum)) - datum = {'id': 2, - 'session_number': 3, - 'prior_exposures_to_image_set': 1772, - 'session_type': 'OPHYS_3'} + datum = {"id": 2, "session_number": 3, "prior_exposures_to_image_set": 1772, "session_type": "OPHYS_3"} input_data.append(copy.deepcopy(datum)) - datum['experience_level'] = 'Familiar' + datum["experience_level"] = "Familiar" expected_data.append(copy.deepcopy(datum)) - datum = {'id': 3, - 'session_number': 4, - 'prior_exposures_to_image_set': 0, - 'session_type': 'OPHYS_4'} + datum = {"id": 3, "session_number": 4, "prior_exposures_to_image_set": 0, "session_type": "OPHYS_4"} input_data.append(copy.deepcopy(datum)) - datum['experience_level'] = 'Novel 1' + datum["experience_level"] = "Novel 1" expected_data.append(copy.deepcopy(datum)) - datum = {'id': 4, - 'session_number': 5, - 'prior_exposures_to_image_set': 0, - 'session_type': 'OPHYS_5'} + datum = {"id": 4, "session_number": 5, "prior_exposures_to_image_set": 0, "session_type": "OPHYS_5"} input_data.append(copy.deepcopy(datum)) - datum['experience_level'] = 'Novel 1' + datum["experience_level"] = "Novel 1" expected_data.append(copy.deepcopy(datum)) - datum = {'id': 5, - 'session_number': 6, - 'prior_exposures_to_image_set': 0, - 'session_type': 'OPHYS_6'} + datum = {"id": 5, "session_number": 6, "prior_exposures_to_image_set": 0, "session_type": "OPHYS_6"} input_data.append(copy.deepcopy(datum)) - datum['experience_level'] = 'Novel 1' + datum["experience_level"] = "Novel 1" expected_data.append(copy.deepcopy(datum)) - datum = {'id': 7, - 'session_number': 4, - 'prior_exposures_to_image_set': 2, - 'session_type': 'OPHYS_4'} + datum = {"id": 7, "session_number": 4, "prior_exposures_to_image_set": 2, "session_type": "OPHYS_4"} input_data.append(copy.deepcopy(datum)) - datum['experience_level'] = 'Novel >1' + datum["experience_level"] = "Novel >1" expected_data.append(copy.deepcopy(datum)) - datum = {'id': 8, - 'session_number': 5, - 'prior_exposures_to_image_set': 1, - 'session_type': 'OPHYS_5'} + datum = {"id": 8, "session_number": 5, "prior_exposures_to_image_set": 1, "session_type": "OPHYS_5"} input_data.append(copy.deepcopy(datum)) - datum['experience_level'] = 'Novel >1' + datum["experience_level"] = "Novel >1" expected_data.append(copy.deepcopy(datum)) - datum = {'id': 9, - 'session_number': 6, - 'prior_exposures_to_image_set': 3, - 'session_type': 'OPHYS_6'} + datum = {"id": 9, "session_number": 6, "prior_exposures_to_image_set": 3, "session_type": "OPHYS_6"} input_data.append(copy.deepcopy(datum)) - datum['experience_level'] = 'Novel >1' + datum["experience_level"] = "Novel >1" expected_data.append(copy.deepcopy(datum)) - datum = {'id': 10, - 'session_number': 7, - 'prior_exposures_to_image_set': 3, - 'session_type': 'OPHYS_7'} + datum = {"id": 10, "session_number": 7, "prior_exposures_to_image_set": 3, "session_type": "OPHYS_7"} input_data.append(copy.deepcopy(datum)) - datum['experience_level'] = 'None' + datum["experience_level"] = "None" expected_data.append(copy.deepcopy(datum)) input_df = pd.DataFrame(input_data) diff --git a/allensdk/test/brain_observatory/behavior/test_rewards_processing.py b/allensdk/test/brain_observatory/behavior/test_rewards_processing.py index d344bce6ee..c37c3b3af8 100644 --- a/allensdk/test/brain_observatory/behavior/test_rewards_processing.py +++ b/allensdk/test/brain_observatory/behavior/test_rewards_processing.py @@ -10,29 +10,24 @@ def test_get_rewards(): "behavior": { "trial_log": [ { - 'rewards': [(0.007, 1085.96, 55)], - 'trial_params': { - 'catch': False, 'auto_reward': False, - 'change_time': 5}}, + "rewards": [(0.007, 1085.96, 55)], + "trial_params": {"catch": False, "auto_reward": False, "change_time": 5}, + }, { - 'rewards': [(0.008, 1090.01, 66)], - 'trial_params': { - 'catch': False, 'auto_reward': True, - 'change_time': 6}}, + "rewards": [(0.008, 1090.01, 66)], + "trial_params": {"catch": False, "auto_reward": True, "change_time": 6}, + }, { - 'rewards': [], - 'trial_params': { - 'catch': False, 'auto_reward': False, - 'change_time': 4}, + "rewards": [], + "trial_params": {"catch": False, "auto_reward": False, "change_time": 4}, }, - ] - }}} - expected = pd.DataFrame( - {"volume": [0.007, 0.008], - "timestamps": [14.0, 15.0], - "autorewarded": [False, True]}) + ] + } + } + } + expected = pd.DataFrame({"volume": [0.007, 0.008], "timestamps": [14.0, 15.0], "autorewarded": [False, True]}) - timesteps = -1*np.ones(100, dtype=float) + timesteps = -1 * np.ones(100, dtype=float) timesteps[55] = 14.0 timesteps[66] = 15.0 pd.testing.assert_frame_equal(expected, get_rewards(data, timesteps)) diff --git a/allensdk/test/brain_observatory/behavior/test_session_metrics.py b/allensdk/test/brain_observatory/behavior/test_session_metrics.py index 8fc6ad96ef..ceb887a512 100644 --- a/allensdk/test/brain_observatory/behavior/test_session_metrics.py +++ b/allensdk/test/brain_observatory/behavior/test_session_metrics.py @@ -8,29 +8,33 @@ "trials, detect_col, trial_types, expected", [ ( - pd.DataFrame({"trial_type": ["go", "go", "catch", "catch", "aborted"], - "detect": [True, False, True, True, True]}), + pd.DataFrame( + {"trial_type": ["go", "go", "catch", "catch", "aborted"], "detect": [True, False, True, True, True]} + ), "detect", ["go", "catch"], 0.75, ), ( - pd.DataFrame({"trial_type":[ "go", "go", "catch", "catch", "aborted"], - "detect": [True, False, True, True, True]}), + pd.DataFrame( + {"trial_type": ["go", "go", "catch", "catch", "aborted"], "detect": [True, False, True, True, True]} + ), "detect", ["go"], 0.5, ), ( - pd.DataFrame({"trial_type": ["go", "go", "catch", "catch", "aborted"], - "detect": [True, False, True, True, True]}), + pd.DataFrame( + {"trial_type": ["go", "go", "catch", "catch", "aborted"], "detect": [True, False, True, True, True]} + ), "detect", [], 0.8, ), ( - pd.DataFrame({"trial_type": ["go", "go", "catch", "catch", "aborted"], - "detect": [True, False, True, True, True]}), + pd.DataFrame( + {"trial_type": ["go", "go", "catch", "catch", "aborted"], "detect": [True, False, True, True, True]} + ), "detect", ["early"], np.nan, @@ -38,33 +42,48 @@ ], ) def test_response_bias(trials, detect_col, trial_types, expected): - assert metrics.response_bias(trials, detect_col, trial_types) == \ - pytest.approx(expected, nan_ok=True) + assert metrics.response_bias(trials, detect_col, trial_types) == pytest.approx(expected, nan_ok=True) @pytest.mark.parametrize( "trials, expected", [ ( - pd.DataFrame({"trial_type": ["go", "go", "catch", "catch", "aborted"],}), + pd.DataFrame( + { + "trial_type": ["go", "go", "catch", "catch", "aborted"], + } + ), 4, ), ( - pd.DataFrame({"trial_type":[ "go", "go", "go"],}), + pd.DataFrame( + { + "trial_type": ["go", "go", "go"], + } + ), 3, ), ( - pd.DataFrame({"trial_type": ["catch"],}), + pd.DataFrame( + { + "trial_type": ["catch"], + } + ), 1, ), ( - pd.DataFrame({"trial_type": [],}), + pd.DataFrame( + { + "trial_type": [], + } + ), 0, ), ( pd.DataFrame({"trial_type": ["aborted", "nogo"]}), 0, - ) + ), ], ) def test_num_contingent_trials(trials, expected): diff --git a/allensdk/test/brain_observatory/behavior/test_stimulus_processing.py b/allensdk/test/brain_observatory/behavior/test_stimulus_processing.py index 435c9ce9cf..b2a69a4e8a 100644 --- a/allensdk/test/brain_observatory/behavior/test_stimulus_processing.py +++ b/allensdk/test/brain_observatory/behavior/test_stimulus_processing.py @@ -32,16 +32,13 @@ def behavior_stimuli_time_fixture(request): timestamp_count = request.param["timestamp_count"] time_step = request.param["time_step"] - timestamps = np.array( - [time_step * i for i in range(timestamp_count)] - ).astype("int64") + timestamps = np.array([time_step * i for i in range(timestamp_count)]).astype("int64") return timestamps @pytest.mark.parametrize( - "behavior_stimuli_data_fixture,current_set_ix,start_frame," - "n_frames,expected", + "behavior_stimuli_data_fixture,current_set_ix,start_frame,n_frames,expected", [ ( { @@ -88,8 +85,7 @@ def test_get_stimulus_epoch( @pytest.mark.parametrize( - "behavior_stimuli_data_fixture,start_frame,stop_frame,expected," - "stimuli_type", + "behavior_stimuli_data_fixture,start_frame,stop_frame,expected,stimuli_type", [ ( { @@ -171,11 +167,7 @@ def test_get_draw_epochs( expected, stimuli_type, ): - draw_log = behavior_stimuli_data_fixture["items"]["behavior"]["stimuli"][ - stimuli_type - ][ - "draw_log" - ] # noqa: E128 + draw_log = behavior_stimuli_data_fixture["items"]["behavior"]["stimuli"][stimuli_type]["draw_log"] # noqa: E128 actual = _get_draw_epochs(draw_log, start_frame, stop_frame) assert actual == expected @@ -186,9 +178,7 @@ def test_get_draw_epochs( indirect=["behavior_stimuli_data_fixture"], ) def test_get_stimulus_templates(behavior_stimuli_data_fixture): - templates = get_stimulus_templates( - behavior_stimuli_data_fixture, grating_images_dict={} - ) + templates = get_stimulus_templates(behavior_stimuli_data_fixture, grating_images_dict={}) assert templates.image_set_name == "test_image_set" assert len(templates) == 1 @@ -197,9 +187,7 @@ def test_get_stimulus_templates(behavior_stimuli_data_fixture): for img in templates.values(): assert isinstance(img, StimulusImage) - expected_path = os.path.join( - get_resources_dir(), "stimulus_template", "expected" - ) + expected_path = os.path.join(get_resources_dir(), "stimulus_template", "expected") expected_unwarped_path = os.path.join(expected_path, "im065_unwarped.pkl") expected_unwarped = pd.read_pickle(expected_unwarped_path) @@ -219,7 +207,7 @@ def test_get_stimulus_templates(behavior_stimuli_data_fixture): @pytest.mark.parametrize( - ("behavior_stimuli_data_fixture, " "grating_images_dict, expected"), + ("behavior_stimuli_data_fixture, grating_images_dict, expected"), [ ( {"has_images": False}, @@ -234,21 +222,13 @@ def test_get_stimulus_templates(behavior_stimuli_data_fixture): ], indirect=["behavior_stimuli_data_fixture"], ) -def test_get_stimulus_templates_for_gratings( - behavior_stimuli_data_fixture, grating_images_dict, expected -): - templates = get_stimulus_templates( - behavior_stimuli_data_fixture, grating_images_dict=grating_images_dict - ) +def test_get_stimulus_templates_for_gratings(behavior_stimuli_data_fixture, grating_images_dict, expected): + templates = get_stimulus_templates(behavior_stimuli_data_fixture, grating_images_dict=grating_images_dict) assert templates.image_set_name == "grating" assert list(templates.keys()) == ["gratings_90.0"] - assert np.allclose( - templates["gratings_90.0"].warped, np.array([[1, 1], [1, 1]]) - ) - assert np.allclose( - templates["gratings_90.0"].unwarped, np.array([[2, 2], [2, 2]]) - ) + assert np.allclose(templates["gratings_90.0"].warped, np.array([[1, 1], [1, 1]])) + assert np.allclose(templates["gratings_90.0"].unwarped, np.array([[2, 2], [2, 2]])) # def test_get_images_dict(): @@ -259,8 +239,7 @@ def test_get_stimulus_templates_for_gratings( @pytest.mark.parametrize( - "behavior_stimuli_data_fixture, remove_stimuli, " - "starting_index, expected_metadata", + "behavior_stimuli_data_fixture, remove_stimuli, starting_index, expected_metadata", [ ( {"grating_set_log": []}, @@ -375,7 +354,7 @@ def test_get_gratings_metadata( @pytest.mark.parametrize( - "behavior_stimuli_data_fixture, remove_stimuli, " "expected_metadata", + "behavior_stimuli_data_fixture, remove_stimuli, expected_metadata", [ ( { @@ -419,9 +398,7 @@ def test_get_gratings_metadata( ], indirect=["behavior_stimuli_data_fixture"], ) -def test_get_stimulus_metadata( - behavior_stimuli_data_fixture, remove_stimuli, expected_metadata -): +def test_get_stimulus_metadata(behavior_stimuli_data_fixture, remove_stimuli, expected_metadata): for key in remove_stimuli: # do this because at current images are not tested and there's a # hard coded path that prevents testing when this is fixed this can @@ -436,9 +413,7 @@ def test_get_stimulus_metadata( @pytest.mark.parametrize( - "behavior_stimuli_time_fixture," - "behavior_stimuli_data_fixture, " - "expected", + "behavior_stimuli_time_fixture,behavior_stimuli_data_fixture, expected", [ ( {"timestamp_count": 15, "time_step": 1}, @@ -469,12 +444,8 @@ def test_get_stimulus_metadata( "behavior_stimuli_data_fixture", ], ) -def test_get_stimulus_presentations( - behavior_stimuli_time_fixture, behavior_stimuli_data_fixture, expected -): - presentations_df = get_stimulus_presentations( - behavior_stimuli_data_fixture, behavior_stimuli_time_fixture - ) +def test_get_stimulus_presentations(behavior_stimuli_time_fixture, behavior_stimuli_data_fixture, expected): + presentations_df = get_stimulus_presentations(behavior_stimuli_data_fixture, behavior_stimuli_time_fixture) expected_df = pd.DataFrame.from_dict(expected) expected_df.index.name = "stimulus_presentations_id" @@ -483,9 +454,7 @@ def test_get_stimulus_presentations( @pytest.mark.parametrize( - "behavior_stimuli_time_fixture," - "behavior_stimuli_data_fixture," - "expected_data", + "behavior_stimuli_time_fixture,behavior_stimuli_data_fixture,expected_data", [ ( {"timestamp_count": 15, "time_step": 1}, @@ -516,9 +485,7 @@ def test_get_stimulus_presentations( ("Image", "im065", 5, 0), ("Image", "im064", 25, 6), ], - "images_draw_log": ( - ([0] * 2 + [1] * 2 + [0] * 3) * 2 + [0] * 16 - ), + "images_draw_log": (([0] * 2 + [1] * 2 + [0] * 3) * 2 + [0] * 16), "grating_set_log": [ ("Ori", 90, -1, 12), # -1 because that element is not used @@ -542,12 +509,8 @@ def test_get_stimulus_presentations( "behavior_stimuli_data_fixture", ], ) -def test_get_visual_stimuli_df( - behavior_stimuli_time_fixture, behavior_stimuli_data_fixture, expected_data -): - stimuli_df = get_visual_stimuli_df( - behavior_stimuli_data_fixture, behavior_stimuli_time_fixture - ) +def test_get_visual_stimuli_df(behavior_stimuli_time_fixture, behavior_stimuli_data_fixture, expected_data): + stimuli_df = get_visual_stimuli_df(behavior_stimuli_data_fixture, behavior_stimuli_time_fixture) stimuli_df = stimuli_df.drop("index", axis=1) expected_df = pd.DataFrame.from_dict(expected_data) @@ -556,9 +519,7 @@ def test_get_visual_stimuli_df( def test_is_change_event_no_change(): """Test case for no change""" - stimulus_presentations = pd.DataFrame( - {"image_name": ["A", "A", "A"], "omitted": [False, False, False]} - ) + stimulus_presentations = pd.DataFrame({"image_name": ["A", "A", "A"], "omitted": [False, False, False]}) obtained = is_change_event(stimulus_presentations=stimulus_presentations) expected = pd.Series([False, False, False], name="is_change") @@ -567,9 +528,7 @@ def test_is_change_event_no_change(): def test_is_change_event_all_change(): """Test case for all change""" - stimulus_presentations = pd.DataFrame( - {"image_name": ["A", "B", "C"], "omitted": [False, False, False]} - ) + stimulus_presentations = pd.DataFrame({"image_name": ["A", "B", "C"], "omitted": [False, False, False]}) obtained = is_change_event(stimulus_presentations=stimulus_presentations) expected = pd.Series([False, True, True], name="is_change") @@ -578,9 +537,7 @@ def test_is_change_event_all_change(): def test_is_change_omission(): """Test case for single omission""" - stimulus_presentations = pd.DataFrame( - {"image_name": ["A", "B", "C"], "omitted": [False, True, False]} - ) + stimulus_presentations = pd.DataFrame({"image_name": ["A", "B", "C"], "omitted": [False, True, False]}) obtained = is_change_event(stimulus_presentations=stimulus_presentations) expected = pd.Series([False, False, True], name="is_change") @@ -645,18 +602,12 @@ def test_compute_trials_id_for_stimulus(): index=stimulus_presentations.index, dtype="int", ) - output_trials_ids = compute_trials_id_for_stimulus( - stimulus_presentations, trials - ) + output_trials_ids = compute_trials_id_for_stimulus(stimulus_presentations, trials) pd.testing.assert_series_equal(output_trials_ids, expected_trials_id) # Test with explicit active block. - stimulus_presentations["active"] = np.array( - [True, True, False, False, False, False, False, False] - ) - output_trials_ids = compute_trials_id_for_stimulus( - stimulus_presentations, trials - ) + stimulus_presentations["active"] = np.array([True, True, False, False, False, False, False, False]) + output_trials_ids = compute_trials_id_for_stimulus(stimulus_presentations, trials) pd.testing.assert_series_equal(output_trials_ids, expected_trials_id) @@ -688,12 +639,8 @@ def test_compute_is_shame_change(): index=stimulus_presentations.index, dtype="bool", ) - stimulus_presentations = compute_is_sham_change( - stimulus_presentations, trials - ) - pd.testing.assert_series_equal( - stimulus_presentations["is_sham_change"], expected - ) + stimulus_presentations = compute_is_sham_change(stimulus_presentations, trials) + pd.testing.assert_series_equal(stimulus_presentations["is_sham_change"], expected) def test_produce_stimulus_block_names(): @@ -711,9 +658,7 @@ def test_produce_stimulus_block_names(): ], index=stimulus_presentations.index, ) - output_stim_names = produce_stimulus_block_names( - stimulus_presentations, "OPHYS_1", "VisualBehaviorTask1B" - ) + output_stim_names = produce_stimulus_block_names(stimulus_presentations, "OPHYS_1", "VisualBehaviorTask1B") assert np.array_equal( output_stim_names["stimulus_block_name"].values, expected_stimulus_block_names.values, @@ -730,9 +675,7 @@ def test_produce_stimulus_block_names(): ], index=stimulus_presentations.index, ) - output_stim_names = produce_stimulus_block_names( - stimulus_presentations, "OPHYS_1_passive", "VisualBehaviorTask1B" - ) + output_stim_names = produce_stimulus_block_names(stimulus_presentations, "OPHYS_1_passive", "VisualBehaviorTask1B") assert np.array_equal( output_stim_names["stimulus_block_name"].values, expected_stimulus_block_names.values, @@ -742,7 +685,5 @@ def test_produce_stimulus_block_names(): stimulus_presentations = pd.DataFrame( data={"stimulus_block": [0, 1, 2, 3]}, ) - output_stim_names = produce_stimulus_block_names( - stimulus_presentations, "OPHYS_1_passive", "NotAProject" - ) + output_stim_names = produce_stimulus_block_names(stimulus_presentations, "OPHYS_1_passive", "NotAProject") assert "stimulus_block_name" not in output_stim_names.columns diff --git a/allensdk/test/brain_observatory/behavior/test_sync_processing.py b/allensdk/test/brain_observatory/behavior/test_sync_processing.py index 49aacc60fa..c6563d8f41 100644 --- a/allensdk/test/brain_observatory/behavior/test_sync_processing.py +++ b/allensdk/test/brain_observatory/behavior/test_sync_processing.py @@ -15,58 +15,53 @@ "specimen_789992909", "ophys_session_819949602", ) -sync_path = os.path.join( - base_dir, - "819949602_sync.h5" -) +sync_path = os.path.join(base_dir, "819949602_sync.h5") @pytest.mark.requires_bamboo -@pytest.mark.parametrize("sync_path, sync_key, count_exp, last_exp", [ - [sync_path, "ophys_frames", 140082, 4530.11659], - [sync_path, "lick_times", 2099, 3860.94482], - [sync_path, "ophys_trigger", 1, 6.8612], - [sync_path, "eye_tracking", 135908, 4531.00479], - [sync_path, "behavior_monitoring", 135887, 4530.19092], - [sync_path, "stim_photodiode", 4512, 4510.80997], - [sync_path, "stimulus_times_no_delay", 269977, 4510.25654], -]) +@pytest.mark.parametrize( + "sync_path, sync_key, count_exp, last_exp", + [ + [sync_path, "ophys_frames", 140082, 4530.11659], + [sync_path, "lick_times", 2099, 3860.94482], + [sync_path, "ophys_trigger", 1, 6.8612], + [sync_path, "eye_tracking", 135908, 4531.00479], + [sync_path, "behavior_monitoring", 135887, 4530.19092], + [sync_path, "stim_photodiode", 4512, 4510.80997], + [sync_path, "stimulus_times_no_delay", 269977, 4510.25654], + ], +) def test_get_time_sync_integration(sync_path, sync_key, count_exp, last_exp): obt = sync.get_sync_data(sync_path)[sync_key] assert count_exp == len(obt) assert last_exp == obt[-1] -@pytest.mark.parametrize("fn, key, rise, fall, expect", [ - [sync.get_trigger, "foo", None, None, None], - [sync.get_trigger, "2p_trigger", [1, 2, 3], [4, 5, 6], [1, 2, 3]], - [sync.get_trigger, "acq_trigger", [1, 2, 3], [4, 5, 6], [1, 2, 3]], - [sync.get_trigger, "2p_acq_trigger", [1, 2, 3], [4, 5, 6], [1, 2, 3]], - [sync.get_trigger, "2p_acquiring", [1, 2, 3], [4, 5, 6], [1, 2, 3]], - [sync.get_trigger, "stim_running", [1, 2, 3], [4, 5, 6], [1, 2, 3]], - [sync.get_eye_tracking, "cam2_exposure", [1, 2, 3], [4, 5, 6], [1, 2, 3]], - [sync.get_eye_tracking, "eye_tracking", [1, 2, 3], [4, 5, 6], [1, 2, 3]], - [sync.get_eye_tracking, "eye_frame_received", [1, 2, 3], [4, 5, 6], - [1, 2, 3]], - [sync.get_behavior_monitoring, "cam1_exposure", [1, 2, 3], [4, 5, 6], - [1, 2, 3]], - [sync.get_behavior_monitoring, "behavior_monitoring", [1, 2, 3], [4, 5, 6], - [1, 2, 3]], - [sync.get_behavior_monitoring, "beh_frame_received", [1, 2, 3], [4, 5, 6], - [1, 2, 3]], - [sync.get_stim_photodiode, "stim_photodiode", [1, 2, 3], [4, 5, 6], - [1, 2, 3, 4, 5, 6]], - [sync.get_stim_photodiode, "photodiode", [1, 2, 3], [4, 5, 6], - [1, 2, 3, 4, 5, 6]], - [sync.get_lick_times, "lick_times", [1, 2, 3], [4, 5, 6], [1, 2, 3]], - [sync.get_lick_times, "lick_sensor", [1, 2, 3], [4, 5, 6], [1, 2, 3]], - [sync.get_ophys_frames, "2p_vsync", [1, 2, 3], [4, 5, 6], [1, 2, 3]], - [sync.get_ophys_frames, "vsync_2p", [1, 2, 3], [4, 5, 6], [1, 2, 3]], - [sync.get_raw_stimulus_frames, "stim_vsync", [1, 2, 3], [4, 5, 6], - [4, 5, 6]], -]) +@pytest.mark.parametrize( + "fn, key, rise, fall, expect", + [ + [sync.get_trigger, "foo", None, None, None], + [sync.get_trigger, "2p_trigger", [1, 2, 3], [4, 5, 6], [1, 2, 3]], + [sync.get_trigger, "acq_trigger", [1, 2, 3], [4, 5, 6], [1, 2, 3]], + [sync.get_trigger, "2p_acq_trigger", [1, 2, 3], [4, 5, 6], [1, 2, 3]], + [sync.get_trigger, "2p_acquiring", [1, 2, 3], [4, 5, 6], [1, 2, 3]], + [sync.get_trigger, "stim_running", [1, 2, 3], [4, 5, 6], [1, 2, 3]], + [sync.get_eye_tracking, "cam2_exposure", [1, 2, 3], [4, 5, 6], [1, 2, 3]], + [sync.get_eye_tracking, "eye_tracking", [1, 2, 3], [4, 5, 6], [1, 2, 3]], + [sync.get_eye_tracking, "eye_frame_received", [1, 2, 3], [4, 5, 6], [1, 2, 3]], + [sync.get_behavior_monitoring, "cam1_exposure", [1, 2, 3], [4, 5, 6], [1, 2, 3]], + [sync.get_behavior_monitoring, "behavior_monitoring", [1, 2, 3], [4, 5, 6], [1, 2, 3]], + [sync.get_behavior_monitoring, "beh_frame_received", [1, 2, 3], [4, 5, 6], [1, 2, 3]], + [sync.get_stim_photodiode, "stim_photodiode", [1, 2, 3], [4, 5, 6], [1, 2, 3, 4, 5, 6]], + [sync.get_stim_photodiode, "photodiode", [1, 2, 3], [4, 5, 6], [1, 2, 3, 4, 5, 6]], + [sync.get_lick_times, "lick_times", [1, 2, 3], [4, 5, 6], [1, 2, 3]], + [sync.get_lick_times, "lick_sensor", [1, 2, 3], [4, 5, 6], [1, 2, 3]], + [sync.get_ophys_frames, "2p_vsync", [1, 2, 3], [4, 5, 6], [1, 2, 3]], + [sync.get_ophys_frames, "vsync_2p", [1, 2, 3], [4, 5, 6], [1, 2, 3]], + [sync.get_raw_stimulus_frames, "stim_vsync", [1, 2, 3], [4, 5, 6], [4, 5, 6]], + ], +) def test_timestamp_extractors(fn, key, rise, fall, expect): - class Ds(Dataset): def __init__(self): self.line_labels = [key, "1", "2"] diff --git a/allensdk/test/brain_observatory/behavior/test_trial_masks.py b/allensdk/test/brain_observatory/behavior/test_trial_masks.py index 65fa6b49a0..165c667f5f 100644 --- a/allensdk/test/brain_observatory/behavior/test_trial_masks.py +++ b/allensdk/test/brain_observatory/behavior/test_trial_masks.py @@ -7,79 +7,107 @@ "trials, trial_types, expected", [ ( - pd.DataFrame({"trial_type": ["go", "go", "catch", "catch", "aborted"], - "detect": [True, False, True, True, True],}), + pd.DataFrame( + { + "trial_type": ["go", "go", "catch", "catch", "aborted"], + "detect": [True, False, True, True, True], + } + ), ["go", "catch"], pd.Series([True, True, True, True, False], name="trial_type"), ), ( - pd.DataFrame({"trial_type": ["go", "go", "catch", "catch", "aborted"], - "detect": [True, False, True, True, True],}), + pd.DataFrame( + { + "trial_type": ["go", "go", "catch", "catch", "aborted"], + "detect": [True, False, True, True, True], + } + ), ["aborted"], - pd.Series([False, False, False, False, True], name="trial_type") + pd.Series([False, False, False, False, True], name="trial_type"), ), ( - pd.DataFrame({"trial_type": ["go", "go", "catch", "catch", "aborted"], - "detect": [True, False, True, True, True],}), + pd.DataFrame( + { + "trial_type": ["go", "go", "catch", "catch", "aborted"], + "detect": [True, False, True, True, True], + } + ), [], pd.Series([True, True, True, True, True], name="trial_type"), ), ( - pd.DataFrame({"trial_type": ["go", "go", "catch", "catch", "aborted"], - "detect": [True, False, True, True, True],}), + pd.DataFrame( + { + "trial_type": ["go", "go", "catch", "catch", "aborted"], + "detect": [True, False, True, True, True], + } + ), ["early"], pd.Series([False, False, False, False, False], name="trial_type"), ), ( - pd.DataFrame({"trial_type": [], - "detect": [],}), + pd.DataFrame( + { + "trial_type": [], + "detect": [], + } + ), ["go", "catch"], pd.Series([], name="trial_type"), ), ], ) def test_trial_types(trials, trial_types, expected): - pd.testing.assert_series_equal( - masks.trial_types(trials, trial_types), expected, check_dtype=False) - + pd.testing.assert_series_equal(masks.trial_types(trials, trial_types), expected, check_dtype=False) @pytest.mark.parametrize( "trials, trial_types, expected", [ ( - pd.DataFrame({"trial_type": ["go", "go", "catch", "catch", "aborted"], - "include": [True, False, True, True, True],}), + pd.DataFrame( + { + "trial_type": ["go", "go", "catch", "catch", "aborted"], + "include": [True, False, True, True, True], + } + ), ["go", "catch"], - pd.Series([True, True, True, False], name="trial_type", - index=[0, 2, 3, 4]), + pd.Series([True, True, True, False], name="trial_type", index=[0, 2, 3, 4]), ), ], ) def test_trial_types_works_with_subselection(trials, trial_types, expected): pd.testing.assert_series_equal( - masks.trial_types(trials[trials["include"]], trial_types), expected, - check_dtype=False) + masks.trial_types(trials[trials["include"]], trial_types), expected, check_dtype=False + ) @pytest.mark.parametrize( "trials, expected", [ ( - pd.DataFrame({"trial_type": ["go", "go", "catch", "catch", "aborted"], - "detect": [True, False, True, True, True],}), + pd.DataFrame( + { + "trial_type": ["go", "go", "catch", "catch", "aborted"], + "detect": [True, False, True, True, True], + } + ), pd.Series([True, True, True, True, False], name="trial_type"), ), ( - pd.DataFrame({"trial_type": [], - "detect": [],}), + pd.DataFrame( + { + "trial_type": [], + "detect": [], + } + ), pd.Series([], name="trial_type"), ), - ] + ], ) def test_contingent_trials(trials, expected): - pd.testing.assert_series_equal( - masks.contingent_trials(trials), expected, check_dtype=False) + pd.testing.assert_series_equal(masks.contingent_trials(trials), expected, check_dtype=False) @pytest.mark.parametrize( @@ -100,7 +128,7 @@ def test_contingent_trials(trials, expected): 3.0, pd.Series([False, False, False, False], name="reward_rate"), ), - ] + ], ) def test_reward_rate(trials, thresh, expected): - pd.testing.assert_series_equal(masks.reward_rate(trials, thresh), expected) \ No newline at end of file + pd.testing.assert_series_equal(masks.reward_rate(trials, thresh), expected) diff --git a/allensdk/test/brain_observatory/behavior/test_trials_processing.py b/allensdk/test/brain_observatory/behavior/test_trials_processing.py index 8489cd5b44..0798b60e78 100644 --- a/allensdk/test/brain_observatory/behavior/test_trials_processing.py +++ b/allensdk/test/brain_observatory/behavior/test_trials_processing.py @@ -4,226 +4,917 @@ import pandas as pd import numpy as np -from allensdk.brain_observatory.behavior.data_objects.trials.trials \ - import Trials +from allensdk.brain_observatory.behavior.data_objects.trials.trials import Trials _test_response_latency_0 = np.array( - [np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, - np.nan, np.nan, 0.3669842, np.nan, np.nan, np.nan, np.nan, np.nan, - np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, - np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, - np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, - np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, 0.41701037, - np.nan, np.nan, 0.31692564, np.nan, np.nan, np.nan, np.nan, np.nan, - np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, - np.nan, np.nan, np.nan, np.nan, np.nan, 0.28356898, np.nan, np.nan, - np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, - np.nan, np.nan, 0.33363652, np.nan, np.nan, np.nan, np.nan, np.nan, - np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, - np.nan, np.nan, np.nan, np.nan, 0.21683128, np.nan, np.nan, np.nan, - np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, - np.nan, np.nan, np.nan, np.nan, 0.38365788, np.nan, np.nan, np.nan, - np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, - np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, - np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, - np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, - np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, - np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, - np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, - np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, - np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan]) + [ + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + 0.3669842, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + 0.41701037, + np.nan, + np.nan, + 0.31692564, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + 0.28356898, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + 0.33363652, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + 0.21683128, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + 0.38365788, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + ] +) _test_starttime_0 = np.array( - [19.99986754, 22.9857173, 25.25430697, 27.50627203, 30.50876019, - 33.54466563, 37.314543, 38.06517225, 41.0677066, 42.58570486, - 44.83770039, 52.3774269, 57.63187756, 62.91968583, 66.67286547, - 68.17414034, 69.67541489, 74.19590836, 77.19846323, 78.69971244, - 80.9516403, 83.95418202, 86.22276367, 87.72406035, 89.97592663, - 92.97849208, 95.23039781, 97.49898909, 100.50154941, 102.75344423, - 104.25473193, 107.25726563, 109.50917136, 111.7610819, 114.01299179, - 116.26490714, 119.26744951, 121.51938378, 124.53862772, 127.55782612, - 129.80972512, 132.81226619, 137.31609721, 138.81734546, 141.08594764, - 143.33787839, 146.35708545, 149.35962973, 156.86599991, 159.88526245, - 163.65508305, 173.44672267, 175.69862167, 178.70117974, 180.95312267, - 183.22166712, 186.22421269, 189.24343194, 191.49533542, 195.28190828, - 198.28443076, 201.30364616, 203.55556985, 205.82415567, 209.577343, - 213.3472146, 216.34976979, 218.61835433, 222.38820892, 226.14139112, - 233.64779273, 234.39842519, 235.16570215, 237.4176191, 239.66953991, - 241.93811899, 244.94066135, 247.94320564, 249.44448821, 252.44703217, - 253.96497334, 254.71559618, 256.2168919, 263.73995532, 266.74247811, - 268.99441912, 271.2629764, 275.0161833, 276.51745143, 278.03538907, - 280.2873025, 282.53921593, 284.80780047, 287.81034956, 290.06226106, - 291.56355774, 294.56608823, 297.58530492, 299.83719974, 302.08911637, - 304.34102787, 308.1109572, 315.6172597, 317.11854964, 319.37043805, - 321.65571199, 323.90761451, 325.40892787, 327.66080216, 329.91271591, - 332.9319358, 336.68513692, 339.70434687, 341.20561115, 341.95625997, - 345.72611233, 347.22738976, 349.4626279, 357.01902553, 359.27093992, - 362.29016334, 365.29272109, 368.31197049, 372.06513441, 373.58308296, - 375.83498259, 377.3362504, 379.58815677, 382.60747224, 387.11118876, - 390.11375903, 392.3823596, 394.63425956, 396.88616336, 399.92207297, - 402.92464964, 405.94383939, 408.19575538, 410.46437038, 413.46690377, - 416.4694612, 418.72134384, 420.95662335, 423.95913236, 425.46040562, - 427.71231745, 429.98091481, 432.98349662, 434.48471376, 435.98598863, - 438.98854735, 442.0077961, 444.25968163, 446.54494947, 448.02954296, - 451.79940525, 454.81863412, 457.07053664, 459.33912952, 461.59103621, - 463.85962107, 465.360903, 467.61282348, 469.8814, 471.3826691, - 474.38523327, 477.42112684, 479.67303705, 481.94162769, 484.94417229, - 487.96340758, 490.21530914, 492.46722064, 494.73581031, 498.48898834, - 501.50822074, 503.00949272, 505.2780933, 507.53000255, 510.54922212, - 514.31907961, 516.57098918, 518.83957885, 520.34085308, 523.3600781, - 525.61201879, 529.39853328, 531.6504467, 533.91904696, 536.17094563, - 539.94081562, 542.19272616, 545.19526724, 546.69653986, 548.19781249, - 550.44972206, 553.46895029, 556.4714981, 559.47403694, 562.47658507]) + [ + 19.99986754, + 22.9857173, + 25.25430697, + 27.50627203, + 30.50876019, + 33.54466563, + 37.314543, + 38.06517225, + 41.0677066, + 42.58570486, + 44.83770039, + 52.3774269, + 57.63187756, + 62.91968583, + 66.67286547, + 68.17414034, + 69.67541489, + 74.19590836, + 77.19846323, + 78.69971244, + 80.9516403, + 83.95418202, + 86.22276367, + 87.72406035, + 89.97592663, + 92.97849208, + 95.23039781, + 97.49898909, + 100.50154941, + 102.75344423, + 104.25473193, + 107.25726563, + 109.50917136, + 111.7610819, + 114.01299179, + 116.26490714, + 119.26744951, + 121.51938378, + 124.53862772, + 127.55782612, + 129.80972512, + 132.81226619, + 137.31609721, + 138.81734546, + 141.08594764, + 143.33787839, + 146.35708545, + 149.35962973, + 156.86599991, + 159.88526245, + 163.65508305, + 173.44672267, + 175.69862167, + 178.70117974, + 180.95312267, + 183.22166712, + 186.22421269, + 189.24343194, + 191.49533542, + 195.28190828, + 198.28443076, + 201.30364616, + 203.55556985, + 205.82415567, + 209.577343, + 213.3472146, + 216.34976979, + 218.61835433, + 222.38820892, + 226.14139112, + 233.64779273, + 234.39842519, + 235.16570215, + 237.4176191, + 239.66953991, + 241.93811899, + 244.94066135, + 247.94320564, + 249.44448821, + 252.44703217, + 253.96497334, + 254.71559618, + 256.2168919, + 263.73995532, + 266.74247811, + 268.99441912, + 271.2629764, + 275.0161833, + 276.51745143, + 278.03538907, + 280.2873025, + 282.53921593, + 284.80780047, + 287.81034956, + 290.06226106, + 291.56355774, + 294.56608823, + 297.58530492, + 299.83719974, + 302.08911637, + 304.34102787, + 308.1109572, + 315.6172597, + 317.11854964, + 319.37043805, + 321.65571199, + 323.90761451, + 325.40892787, + 327.66080216, + 329.91271591, + 332.9319358, + 336.68513692, + 339.70434687, + 341.20561115, + 341.95625997, + 345.72611233, + 347.22738976, + 349.4626279, + 357.01902553, + 359.27093992, + 362.29016334, + 365.29272109, + 368.31197049, + 372.06513441, + 373.58308296, + 375.83498259, + 377.3362504, + 379.58815677, + 382.60747224, + 387.11118876, + 390.11375903, + 392.3823596, + 394.63425956, + 396.88616336, + 399.92207297, + 402.92464964, + 405.94383939, + 408.19575538, + 410.46437038, + 413.46690377, + 416.4694612, + 418.72134384, + 420.95662335, + 423.95913236, + 425.46040562, + 427.71231745, + 429.98091481, + 432.98349662, + 434.48471376, + 435.98598863, + 438.98854735, + 442.0077961, + 444.25968163, + 446.54494947, + 448.02954296, + 451.79940525, + 454.81863412, + 457.07053664, + 459.33912952, + 461.59103621, + 463.85962107, + 465.360903, + 467.61282348, + 469.8814, + 471.3826691, + 474.38523327, + 477.42112684, + 479.67303705, + 481.94162769, + 484.94417229, + 487.96340758, + 490.21530914, + 492.46722064, + 494.73581031, + 498.48898834, + 501.50822074, + 503.00949272, + 505.2780933, + 507.53000255, + 510.54922212, + 514.31907961, + 516.57098918, + 518.83957885, + 520.34085308, + 523.3600781, + 525.61201879, + 529.39853328, + 531.6504467, + 533.91904696, + 536.17094563, + 539.94081562, + 542.19272616, + 545.19526724, + 546.69653986, + 548.19781249, + 550.44972206, + 553.46895029, + 556.4714981, + 559.47403694, + 562.47658507, + ] +) -expected_result_0 = np.array([ - np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, - np.nan, np.nan, 0.85743611, 0.82215855, 0.79754855, 0.77420232, - 0.74532604, 0.72504419, 0.73828876, 0.73168092, 0.73168145, - 0.73844044, 0.745635, 0.75997116, 0.73889553, 0.7457893, 0.73212764, - 0.72533739, 0., 0., 0., 0., 0., 0., 0., 0.83147215, 0.76759434, - 0.76013235, 1.50562915, 1.37576878, 1.36403067, 1.3524898, - 1.3640296, 1.36377167, 1.35249025, 1.35223636, 1.35223623, - 1.31828762, 1.31828779, 1.30726822, 1.30726804, 1.30703059, 1.28600222, - 1.27551339, 1.26541718, 1.27551391, 1.26541723, 1.86854445, 1.78508514, - 1.85409645, 1.86822076, 1.86854435, 1.86854454, 1.88321881, 1.88321885, - 1.31756348, 1.33989546, 1.35147388, 0.74517267, 0.75933052, 1.54807324, - 1.44950587, 1.43676766, 1.44979704, 1.46306592, 1.43676702, 1.47718591, - 1.50468411, 1.51930166, 1.51930185, 1.51930188, 1.53387944, 1.5642303, - 1.59545215, 1.58003398, 1.59580631, 1.62831513, 0.87666335, 0.85784626, - 1.6450693, 1.53453391, 1.54940651, 1.54974049, 1.56423021, 1.57968714, - 1.5796865, 1.59545254, 1.5800338, 1.53420629, 1.49127149, 0.78984375, - 0.80576787, 0.82234767, 0.80576784, 0.83089596, 1.64507108, 1.51930204, - 1.51930202, 1.50468432, 1.49096252, 1.49065321, 1.46336336, 1.46306626, - 1.4765797, 1.50468436, 1.50468414, 1.4903434, 1.44979783, 1.46336463, - 0.78160518, 0.77403664, 0.77403649, 0.76661287, 0.75932993, 0.74501851, - 0.74501813, 0.74486366, 0.74501799, 0.75202743, 0.7593303, 0.75234155, - 0.73168169, 0.7524993, 0.74548119, 0.74517234, 0., 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0., 0., 0., -]) +expected_result_0 = np.array( + [ + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + 0.85743611, + 0.82215855, + 0.79754855, + 0.77420232, + 0.74532604, + 0.72504419, + 0.73828876, + 0.73168092, + 0.73168145, + 0.73844044, + 0.745635, + 0.75997116, + 0.73889553, + 0.7457893, + 0.73212764, + 0.72533739, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.83147215, + 0.76759434, + 0.76013235, + 1.50562915, + 1.37576878, + 1.36403067, + 1.3524898, + 1.3640296, + 1.36377167, + 1.35249025, + 1.35223636, + 1.35223623, + 1.31828762, + 1.31828779, + 1.30726822, + 1.30726804, + 1.30703059, + 1.28600222, + 1.27551339, + 1.26541718, + 1.27551391, + 1.26541723, + 1.86854445, + 1.78508514, + 1.85409645, + 1.86822076, + 1.86854435, + 1.86854454, + 1.88321881, + 1.88321885, + 1.31756348, + 1.33989546, + 1.35147388, + 0.74517267, + 0.75933052, + 1.54807324, + 1.44950587, + 1.43676766, + 1.44979704, + 1.46306592, + 1.43676702, + 1.47718591, + 1.50468411, + 1.51930166, + 1.51930185, + 1.51930188, + 1.53387944, + 1.5642303, + 1.59545215, + 1.58003398, + 1.59580631, + 1.62831513, + 0.87666335, + 0.85784626, + 1.6450693, + 1.53453391, + 1.54940651, + 1.54974049, + 1.56423021, + 1.57968714, + 1.5796865, + 1.59545254, + 1.5800338, + 1.53420629, + 1.49127149, + 0.78984375, + 0.80576787, + 0.82234767, + 0.80576784, + 0.83089596, + 1.64507108, + 1.51930204, + 1.51930202, + 1.50468432, + 1.49096252, + 1.49065321, + 1.46336336, + 1.46306626, + 1.4765797, + 1.50468436, + 1.50468414, + 1.4903434, + 1.44979783, + 1.46336463, + 0.78160518, + 0.77403664, + 0.77403649, + 0.76661287, + 0.75932993, + 0.74501851, + 0.74501813, + 0.74486366, + 0.74501799, + 0.75202743, + 0.7593303, + 0.75234155, + 0.73168169, + 0.7524993, + 0.74548119, + 0.74517234, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ] +) expected_result_1 = np.array( - [np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, - np.nan, np.nan, 1.811146, 1.944290, 1.898119, 1.811146, 1.733465, - 1.771897, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, - 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, - 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, - 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, - 0.000000, 0.000000, 2.417308, 2.047210, 1.994977, 3.890695, 3.321282, - 3.253684, 3.190197, 3.190196, 3.255157, 3.255157, 1.853143, 1.898129, - 1.897124, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, - 0.000000, 0.000000, 0.000000, 2.153860, 1.855050, 1.945345, 2.044882, - 2.155151, 2.279434, 2.344817, 2.279435, 2.347877, 2.574765, 0.000000, - 0.000000, 0.000000, 3.191613, 2.492687, 2.418930, 2.494413, 2.572924, - 2.346344, 2.492686, 2.492686, 2.346343, 2.279434, 0.000000, 0.000000, - 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, - 2.574758, 2.157737, 2.217599, 2.157739, 2.214870, 2.279435, 2.346341, - 2.346345, 2.346345, 2.417310, 0.000000, 0.000000, 0.000000, 0.000000, - 0.000000, 0.000000, 2.752063, 2.213507, 2.277990, 2.343290, 2.344815, - 2.213503, 1.992768, 2.153859, 2.097345, 2.152573, 0.000000, 0.000000, - 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, - 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, - 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, - 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, - 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, - 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, - 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, - 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, - 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, - 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, - 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, ]) + [ + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + 1.811146, + 1.944290, + 1.898119, + 1.811146, + 1.733465, + 1.771897, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 2.417308, + 2.047210, + 1.994977, + 3.890695, + 3.321282, + 3.253684, + 3.190197, + 3.190196, + 3.255157, + 3.255157, + 1.853143, + 1.898129, + 1.897124, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 2.153860, + 1.855050, + 1.945345, + 2.044882, + 2.155151, + 2.279434, + 2.344817, + 2.279435, + 2.347877, + 2.574765, + 0.000000, + 0.000000, + 0.000000, + 3.191613, + 2.492687, + 2.418930, + 2.494413, + 2.572924, + 2.346344, + 2.492686, + 2.492686, + 2.346343, + 2.279434, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 2.574758, + 2.157737, + 2.217599, + 2.157739, + 2.214870, + 2.279435, + 2.346341, + 2.346345, + 2.346345, + 2.417310, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 2.752063, + 2.213507, + 2.277990, + 2.343290, + 2.344815, + 2.213503, + 1.992768, + 2.153859, + 2.097345, + 2.152573, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + 0.000000, + ] +) -@pytest.mark.parametrize('kwargs, expected', [ - ( +@pytest.mark.parametrize( + "kwargs, expected", + [ + ( { - 'response_latency': _test_response_latency_0, - 'starttime': _test_starttime_0, - 'trial_window': 15, - 'initial_trials': 10, + "response_latency": _test_response_latency_0, + "starttime": _test_starttime_0, + "trial_window": 15, + "initial_trials": 10, }, expected_result_0, - ), - ( + ), + ( { - 'response_latency': _test_response_latency_0, - 'starttime': _test_starttime_0, - 'trial_window': 5, - 'initial_trials': 10, + "response_latency": _test_response_latency_0, + "starttime": _test_starttime_0, + "trial_window": 5, + "initial_trials": 10, }, expected_result_1, - ), -]) + ), + ], +) def test_calculate_reward_rate(kwargs, expected): - with patch.object(Trials, '_calculate_response_latency_list', - wraps=lambda: kwargs['response_latency']): - with patch.object(Trials, 'start_time', - new_callable=PropertyMock) as mock_start_time: - mock_start_time.return_value = pd.Series(kwargs['starttime']) + with patch.object(Trials, "_calculate_response_latency_list", wraps=lambda: kwargs["response_latency"]): + with patch.object(Trials, "start_time", new_callable=PropertyMock) as mock_start_time: + mock_start_time.return_value = pd.Series(kwargs["starttime"]) trials = Trials(trials=pd.DataFrame(), response_window_start=0) reward_rate = trials.calculate_reward_rate( - trial_window=kwargs['trial_window'], - initial_trials=kwargs['initial_trials'] + trial_window=kwargs["trial_window"], initial_trials=kwargs["initial_trials"] ) - assert np.allclose(reward_rate, expected, equal_nan=True), \ + assert np.allclose(reward_rate, expected, equal_nan=True), ( "calculated reward rate should match expected reward rate :(" + ) def trial_data_and_expectation_0(): test_trial = { - 'index': 3, - 'cumulative_rewards': 1, - 'licks': [(318.2737866026219, 18736), - (318.4235244484611, 18745), - (318.55351991075554, 18753), - (318.6735239364698, 18760), - (318.8235420609609, 18769), - (318.9733899117824, 18778), - (319.153503175955, 18789), - (319.35351008052305, 18801), - (321.24372627834714, 18914), - (321.3438153063156, 18920), - (321.49348118080985, 18929), - (321.6237259097134, 18937)], - 'stimulus_changes': [(('im065', 'im065'), - ('im062', 'im062'), - 317.76644976765834, - 18706)], - 'success': False, - 'cumulative_volume': 0.005, - 'trial_params': {'catch': False, - 'auto_reward': True, - 'change_time': 5}, - 'rewards': [(0.005, 317.92325660388286, 18715)], - 'events': [['trial_start', '', 314.0120642698258, 18481], - ['initial_blank', 'enter', 314.01216666808074, 18481], - ['initial_blank', 'exit', 314.0122573636779, 18481], - ['pre_change', 'enter', 314.01233489378524, 18481], - ['pre_change', 'exit', 314.0124103759274, 18481], - ['stimulus_window', 'enter', 314.01248819860115, 18481], - ['stimulus_changed', '', 317.7666744586863, 18706], - ['auto_reward', '', 317.76681547571155, 18706], - ['response_window', 'enter', 317.9231027139341, 18715], - ['response_window', 'exit', 318.532233361527, 18752], - ['miss', '', 318.5324346472395, 18752], - ['stimulus_window', 'exit', 322.0351179203675, 18962], - ['no_lick', 'exit', 322.0352864386384, 18962], - ['trial_end', '', 322.0353750862705, 18962]] + "index": 3, + "cumulative_rewards": 1, + "licks": [ + (318.2737866026219, 18736), + (318.4235244484611, 18745), + (318.55351991075554, 18753), + (318.6735239364698, 18760), + (318.8235420609609, 18769), + (318.9733899117824, 18778), + (319.153503175955, 18789), + (319.35351008052305, 18801), + (321.24372627834714, 18914), + (321.3438153063156, 18920), + (321.49348118080985, 18929), + (321.6237259097134, 18937), + ], + "stimulus_changes": [(("im065", "im065"), ("im062", "im062"), 317.76644976765834, 18706)], + "success": False, + "cumulative_volume": 0.005, + "trial_params": {"catch": False, "auto_reward": True, "change_time": 5}, + "rewards": [(0.005, 317.92325660388286, 18715)], + "events": [ + ["trial_start", "", 314.0120642698258, 18481], + ["initial_blank", "enter", 314.01216666808074, 18481], + ["initial_blank", "exit", 314.0122573636779, 18481], + ["pre_change", "enter", 314.01233489378524, 18481], + ["pre_change", "exit", 314.0124103759274, 18481], + ["stimulus_window", "enter", 314.01248819860115, 18481], + ["stimulus_changed", "", 317.7666744586863, 18706], + ["auto_reward", "", 317.76681547571155, 18706], + ["response_window", "enter", 317.9231027139341, 18715], + ["response_window", "exit", 318.532233361527, 18752], + ["miss", "", 318.5324346472395, 18752], + ["stimulus_window", "exit", 322.0351179203675, 18962], + ["no_lick", "exit", 322.0352864386384, 18962], + ["trial_end", "", 322.0353750862705, 18962], + ], } expected_result = { - 'reward_volume': 0.005, - 'hit': False, - 'false_alarm': False, - 'miss': False, - 'sham_change': False, - 'stimulus_change': True, - 'aborted': False, - 'go': False, - 'catch': False, - 'auto_rewarded': True, - 'correct_reject': False + "reward_volume": 0.005, + "hit": False, + "false_alarm": False, + "miss": False, + "sham_change": False, + "stimulus_change": True, + "aborted": False, + "go": False, + "catch": False, + "auto_rewarded": True, + "correct_reject": False, } return test_trial, expected_result @@ -231,43 +922,41 @@ def trial_data_and_expectation_0(): def trial_data_and_expectation_1(): test_trial = { - 'index': 4, - 'cumulative_rewards': 1, - 'licks': [(324.1935569847751, 19091), - (324.34329131981696, 19100), - (324.49368158882305, 19109)], - 'stimulus_changes': [], - 'success': False, - 'cumulative_volume': 0.005, - 'trial_params': {'catch': False, - 'auto_reward': True, - 'change_time': 6}, - 'rewards': [], - 'events': [['trial_start', '', 322.2688823113451, 18976], - ['initial_blank', 'enter', 322.2689858798658, 18976], - ['initial_blank', 'exit', 322.26907599033007, 18976], - ['pre_change', 'enter', 322.2691523501716, 18976], - ['pre_change', 'exit', 322.26922900257955, 18976], - ['stimulus_window', 'enter', 322.26930536242105, 18976], - ['early_response', '', 324.1937059010944, 19091], - ['abort', '', 324.1937848940339, 19091], - ['timeout', 'enter', 324.19388963282034, 19091], - ['timeout', 'exit', 324.8042502297378, 19128], - ['trial_end', '', 324.80448691598986, 19128]] + "index": 4, + "cumulative_rewards": 1, + "licks": [(324.1935569847751, 19091), (324.34329131981696, 19100), (324.49368158882305, 19109)], + "stimulus_changes": [], + "success": False, + "cumulative_volume": 0.005, + "trial_params": {"catch": False, "auto_reward": True, "change_time": 6}, + "rewards": [], + "events": [ + ["trial_start", "", 322.2688823113451, 18976], + ["initial_blank", "enter", 322.2689858798658, 18976], + ["initial_blank", "exit", 322.26907599033007, 18976], + ["pre_change", "enter", 322.2691523501716, 18976], + ["pre_change", "exit", 322.26922900257955, 18976], + ["stimulus_window", "enter", 322.26930536242105, 18976], + ["early_response", "", 324.1937059010944, 19091], + ["abort", "", 324.1937848940339, 19091], + ["timeout", "enter", 324.19388963282034, 19091], + ["timeout", "exit", 324.8042502297378, 19128], + ["trial_end", "", 324.80448691598986, 19128], + ], } expected_result = { - 'reward_volume': 0, - 'hit': False, - 'false_alarm': False, - 'miss': False, - 'sham_change': False, - 'stimulus_change': False, - 'aborted': True, - 'go': False, - 'catch': False, - 'auto_rewarded': False, - 'correct_reject': False + "reward_volume": 0, + "hit": False, + "false_alarm": False, + "miss": False, + "sham_change": False, + "stimulus_change": False, + "aborted": True, + "go": False, + "catch": False, + "auto_rewarded": False, + "correct_reject": False, } return test_trial, expected_result @@ -275,91 +964,87 @@ def trial_data_and_expectation_1(): def trial_data_and_expectation_2(): test_trial = { - 'index': 51, - 'cumulative_rewards': 11, - 'licks': [(542.6200214334176, 32186), - (542.7097825733969, 32191), - (542.8597161461861, 32200), - (542.9599280520605, 32206), - (543.059708422432, 32212), - (543.15998088956, 32218), - (543.2899491431752, 32226), - (543.4098750536493, 32233), - (543.5197477960238, 32240), - (543.6596846660369, 32248), - (543.7699336488565, 32255), - (543.8897463361172, 32262), - (544.0196821148575, 32270), - (544.13974055793, 32277), - (544.2596729048659, 32284), - (544.3896745110557, 32292), - (544.5397306691843, 32301)], - 'stimulus_changes': [(('im069', 'im069'), - ('im085', 'im085'), - 542.2007438794369, - 32161)], - 'success': True, - 'cumulative_volume': 0.067, - 'trial_params': {'catch': False, - 'auto_reward': False, - 'change_time': 4}, - 'rewards': [(0.007, 542.620156599114, 32186)], - 'events': [['trial_start', '', 539.1971251251088, 31981], - ['initial_blank', 'enter', 539.197228401063, 31981], - ['initial_blank', 'exit', 539.1973220223246, 31981], - ['pre_change', 'enter', 539.1974007226976, 31981], - ['pre_change', 'exit', 539.197477667672, 31981], - ['stimulus_window', 'enter', 539.1975575383109, 31981], - ['stimulus_changed', '', 542.2009428246179, 32161], - ['response_window', 'enter', 542.3661398812824, 32171], - ['hit', '', 542.6201402153932, 32186], - ['response_window', 'exit', 542.9666720011281, 32207], - ['stimulus_window', 'exit', 546.4695340323526, 32417], - ['no_lick', 'exit', 546.4696966992947, 32417], - ['trial_end', '', 546.4697827138287, 32417]] + "index": 51, + "cumulative_rewards": 11, + "licks": [ + (542.6200214334176, 32186), + (542.7097825733969, 32191), + (542.8597161461861, 32200), + (542.9599280520605, 32206), + (543.059708422432, 32212), + (543.15998088956, 32218), + (543.2899491431752, 32226), + (543.4098750536493, 32233), + (543.5197477960238, 32240), + (543.6596846660369, 32248), + (543.7699336488565, 32255), + (543.8897463361172, 32262), + (544.0196821148575, 32270), + (544.13974055793, 32277), + (544.2596729048659, 32284), + (544.3896745110557, 32292), + (544.5397306691843, 32301), + ], + "stimulus_changes": [(("im069", "im069"), ("im085", "im085"), 542.2007438794369, 32161)], + "success": True, + "cumulative_volume": 0.067, + "trial_params": {"catch": False, "auto_reward": False, "change_time": 4}, + "rewards": [(0.007, 542.620156599114, 32186)], + "events": [ + ["trial_start", "", 539.1971251251088, 31981], + ["initial_blank", "enter", 539.197228401063, 31981], + ["initial_blank", "exit", 539.1973220223246, 31981], + ["pre_change", "enter", 539.1974007226976, 31981], + ["pre_change", "exit", 539.197477667672, 31981], + ["stimulus_window", "enter", 539.1975575383109, 31981], + ["stimulus_changed", "", 542.2009428246179, 32161], + ["response_window", "enter", 542.3661398812824, 32171], + ["hit", "", 542.6201402153932, 32186], + ["response_window", "exit", 542.9666720011281, 32207], + ["stimulus_window", "exit", 546.4695340323526, 32417], + ["no_lick", "exit", 546.4696966992947, 32417], + ["trial_end", "", 546.4697827138287, 32417], + ], } expected_result = { - 'reward_volume': 0.007, - 'hit': True, - 'false_alarm': False, - 'miss': False, - 'sham_change': False, - 'stimulus_change': True, - 'aborted': False, - 'go': True, - 'catch': False, - 'auto_rewarded': False, - 'correct_reject': False + "reward_volume": 0.007, + "hit": True, + "false_alarm": False, + "miss": False, + "sham_change": False, + "stimulus_change": True, + "aborted": False, + "go": True, + "catch": False, + "auto_rewarded": False, + "correct_reject": False, } return test_trial, expected_result @pytest.mark.parametrize( - "trials, response_window_start, expected", - [ - ( - pd.DataFrame({ - "change_time": [1, 2, 3, 4], - "lick_times": [[1.1], [2.1, 2.2], [3.3, 3.4], [4.4]]}), - 0.0, - [0.1, 0.1, 0.3, 0.4]), - ( - pd.DataFrame({ - "change_time": [1, 2, 3, 4], - "lick_times": [[1.1], [], [3.3, 3.4], [4.4]]}), - 0.0, - [0.1, float("inf"), 0.3, 0.4]), - ( - pd.DataFrame({ - "change_time": [1, 2, 3, 4], - "lick_times": [[1.1], [], [3.3, 3.4], [4.4]]}), - 0.15, - [float("inf"), float("inf"), 0.3, 0.4]), - ]) -def test_calculate_response_latency_list( - trials, response_window_start, expected): + "trials, response_window_start, expected", + [ + ( + pd.DataFrame({"change_time": [1, 2, 3, 4], "lick_times": [[1.1], [2.1, 2.2], [3.3, 3.4], [4.4]]}), + 0.0, + [0.1, 0.1, 0.3, 0.4], + ), + ( + pd.DataFrame({"change_time": [1, 2, 3, 4], "lick_times": [[1.1], [], [3.3, 3.4], [4.4]]}), + 0.0, + [0.1, float("inf"), 0.3, 0.4], + ), + ( + pd.DataFrame({"change_time": [1, 2, 3, 4], "lick_times": [[1.1], [], [3.3, 3.4], [4.4]]}), + 0.15, + [float("inf"), float("inf"), 0.3, 0.4], + ), + ], +) +def test_calculate_response_latency_list(trials, response_window_start, expected): trials = Trials(trials=trials, response_window_start=response_window_start) latencies = trials._calculate_response_latency_list() np.testing.assert_allclose(latencies, expected) @@ -367,40 +1052,27 @@ def test_calculate_response_latency_list( @pytest.fixture def trials_example(): - """minimal example for test_construct_rolling_performance_df - """ + """minimal example for test_construct_rolling_performance_df""" trials_dict = { - 'start_time': { - 8: 368.305066913832, - 9: 378.0631642451044, - 10: 386.31999971927144, - 11: 394.57686825376004}, - 'change_time': { - 8: 1, - 9: 2, - 10: 3, - 11: 4 - }, - 'lick_times': { - 8: np.array([]), - 9: np.array([]), - 10: np.array([]), - 11: np.array([])}, - 'hit': {8: False, 9: False, 10: False, 11: False}, - 'false_alarm': {8: False, 9: False, 10: False, 11: False}, - 'miss': {8: True, 9: False, 10: True, 11: True}, - 'aborted': {8: False, 9: False, 10: False, 11: False}, - 'correct_reject': {8: False, 9: True, 10: False, 11: False}} + "start_time": {8: 368.305066913832, 9: 378.0631642451044, 10: 386.31999971927144, 11: 394.57686825376004}, + "change_time": {8: 1, 9: 2, 10: 3, 11: 4}, + "lick_times": {8: np.array([]), 9: np.array([]), 10: np.array([]), 11: np.array([])}, + "hit": {8: False, 9: False, 10: False, 11: False}, + "false_alarm": {8: False, 9: False, 10: False, 11: False}, + "miss": {8: True, 9: False, 10: True, 11: True}, + "aborted": {8: False, 9: False, 10: False, 11: False}, + "correct_reject": {8: False, 9: True, 10: False, 11: False}, + } return pd.DataFrame(trials_dict) -@pytest.mark.parametrize('is_passive', [True, False]) +@pytest.mark.parametrize("is_passive", [True, False]) def test_construct_rolling_performance_df(trials_example, is_passive): """tests that ending a session_type with "passive" replaces rolling_dprime values with all zeros """ # Mouse not rewarded if passive - trials_example['reward_volume'] = 0 if is_passive else [0, .25, 0, 0] + trials_example["reward_volume"] = 0 if is_passive else [0, 0.25, 0, 0] trials = Trials(trials=trials_example, response_window_start=0.15) df = trials.rolling_performance diff --git a/allensdk/test/brain_observatory/conftest.py b/allensdk/test/brain_observatory/conftest.py index acd2ab54dd..04cf76d661 100644 --- a/allensdk/test/brain_observatory/conftest.py +++ b/allensdk/test/brain_observatory/conftest.py @@ -4,44 +4,38 @@ import pynwb import numpy as np -from allensdk.brain_observatory.vbn_2022.input_json_writer.utils import ( - vbn_nwb_config_from_ecephys_session_id_list) +from allensdk.brain_observatory.vbn_2022.input_json_writer.utils import vbn_nwb_config_from_ecephys_session_id_list @pytest.fixture def running_speed(): from allensdk.brain_observatory.running_speed import RunningSpeed - return RunningSpeed( - timestamps=[1., 2., 3.], - values=[4, 5, 6] - ) + + return RunningSpeed(timestamps=[1.0, 2.0, 3.0], values=[4, 5, 6]) @pytest.fixture def nwbfile(): - return pynwb.NWBFile( - session_description='asession', - identifier='afile', - session_start_time=datetime.now() - ) + return pynwb.NWBFile(session_description="asession", identifier="afile", session_start_time=datetime.now()) @pytest.fixture def roundtripper(tmpdir_factory): def f(nwbfile, api_cls, **api_kwargs): - tmpdir = str(tmpdir_factory.mktemp('nwb_roundtrip_tests')) - nwb_path = os.path.join(tmpdir, 'nwbfile.nwb') + tmpdir = str(tmpdir_factory.mktemp("nwb_roundtrip_tests")) + nwb_path = os.path.join(tmpdir, "nwbfile.nwb") - with pynwb.NWBHDF5IO(nwb_path, 'w') as write_io: + with pynwb.NWBHDF5IO(nwb_path, "w") as write_io: write_io.write(nwbfile) return api_cls(nwb_path, **api_kwargs) + return f @pytest.fixture def stimulus_timestamps(): - return np.array([1., 2., 3.]) + return np.array([1.0, 2.0, 3.0]) @pytest.fixture @@ -51,29 +45,27 @@ def f(nwbfile, data_object_cls, **data_object_cls_kwargs): tmp_dir.mkdir() nwb_path = tmp_dir / "data_object_roundtrip_nwbfile.nwb" - with pynwb.NWBHDF5IO(str(nwb_path), 'w') as write_io: + with pynwb.NWBHDF5IO(str(nwb_path), "w") as write_io: write_io.write(nwbfile) - with pynwb.NWBHDF5IO(str(nwb_path), 'r') as read_io: + with pynwb.NWBHDF5IO(str(nwb_path), "r") as read_io: roundtripped_nwbfile = read_io.read() - data_object_instance = data_object_cls.from_nwb( - roundtripped_nwbfile, **data_object_cls_kwargs - ) + data_object_instance = data_object_cls.from_nwb(roundtripped_nwbfile, **data_object_cls_kwargs) return data_object_instance return f -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def behavior_ecephys_session_config_fixture(): """ Return a dict representing the session_data needed to create a BehaviorEcephysSession """ session_data_list = vbn_nwb_config_from_ecephys_session_id_list( - ecephys_session_id_list=[1111216934], - probes_to_skip=None) + ecephys_session_id_list=[1111216934], probes_to_skip=None + ) - return session_data_list['sessions'][0] + return session_data_list["sessions"][0] diff --git a/allensdk/test/brain_observatory/data_release_utils/metadata_utils/conftest.py b/allensdk/test/brain_observatory/data_release_utils/metadata_utils/conftest.py index 9c6faa507b..3fd7b208f1 100644 --- a/allensdk/test/brain_observatory/data_release_utils/metadata_utils/conftest.py +++ b/allensdk/test/brain_observatory/data_release_utils/metadata_utils/conftest.py @@ -6,21 +6,18 @@ import pandas as pd -@pytest.fixture(scope='session') -def some_files_fixture( - tmp_path_factory, - helper_functions): +@pytest.fixture(scope="session") +def some_files_fixture(tmp_path_factory, helper_functions): """ Create some temporary files; return a list of paths to them """ - tmpdir = pathlib.Path( - tmp_path_factory.mktemp('id_generator')) + tmpdir = pathlib.Path(tmp_path_factory.mktemp("id_generator")) path_list = [] for idx in range(4): - this_path = tmpdir / f'{idx}' / f'silly_file_{idx}.nwb' + this_path = tmpdir / f"{idx}" / f"silly_file_{idx}.nwb" os.makedirs(this_path.parent, exist_ok=True) - with open(this_path, 'w') as out_file: - out_file.write(f'this is file {idx}') + with open(this_path, "w") as out_file: + out_file.write(f"this is file {idx}") path_list.append(this_path) yield path_list @@ -28,8 +25,7 @@ def some_files_fixture( @pytest.fixture -def metadata_table_fixture( - some_files_fixture): +def metadata_table_fixture(some_files_fixture): """ A metadata_table suitable to be run through add_file_path_to_metadata_table with some_files_fixture @@ -41,18 +37,12 @@ def metadata_table_fixture( input_data = [] for i, file_path in enumerate(some_files_fixture): file_name = file_path.name - file_idx = file_name.split('_')[-1].split('.')[0] + file_idx = file_name.split("_")[-1].split(".")[0] file_idx = int(file_idx) - element = {'file_index': file_idx, - 'some_data': float(rng.random()), - 'session_id': i} + element = {"file_index": file_idx, "some_data": float(rng.random()), "session_id": i} input_data.append(element) # Not real files, never written - input_data.append({'file_index': 11232, - 'some_data': float(rng.random()), - 'session_id': 11232}) - input_data.append({'file_index': 4455, - 'some_data': float(rng.random()), - 'session_id': 4455}) + input_data.append({"file_index": 11232, "some_data": float(rng.random()), "session_id": 11232}) + input_data.append({"file_index": 4455, "some_data": float(rng.random()), "session_id": 4455}) return pd.DataFrame(data=input_data) diff --git a/allensdk/test/brain_observatory/data_release_utils/metadata_utils/test_id_generator.py b/allensdk/test/brain_observatory/data_release_utils/metadata_utils/test_id_generator.py index 9e408c70ed..baede3a675 100644 --- a/allensdk/test/brain_observatory/data_release_utils/metadata_utils/test_id_generator.py +++ b/allensdk/test/brain_observatory/data_release_utils/metadata_utils/test_id_generator.py @@ -1,8 +1,6 @@ import pytest import pathlib -from allensdk.brain_observatory.data_release_utils \ - .metadata_utils.id_generator import ( - FileIDGenerator) +from allensdk.brain_observatory.data_release_utils.metadata_utils.id_generator import FileIDGenerator def test_not_a_file_error(): @@ -11,13 +9,12 @@ def test_not_a_file_error(): something that is not a file. """ generator = FileIDGenerator() - dummy = pathlib.Path('something.nwb') + dummy = pathlib.Path("something.nwb") with pytest.raises(ValueError, match="is not a file"): generator.id_from_path(file_path=dummy) -def test_not_a_path_error( - some_files_fixture): +def test_not_a_path_error(some_files_fixture): """ Test that an error is raised when you try to get an ID from a file_path that is not a pathlib.Path @@ -28,26 +25,21 @@ def test_not_a_path_error( generator.id_from_path(file_path=str_path) -def test_symlink_error( - some_files_fixture, - tmp_path_factory, - helper_functions): +def test_symlink_error(some_files_fixture, tmp_path_factory, helper_functions): """ Test that an error is raised if you try to get an ID for a symlink """ - tmp_dir = pathlib.Path(tmp_path_factory.mktemp('symlink_test')) - symlink_path = tmp_dir / 'a_silly_symlink.nwb' + tmp_dir = pathlib.Path(tmp_path_factory.mktemp("symlink_test")) + symlink_path = tmp_dir / "a_silly_symlink.nwb" symlink_path.symlink_to(some_files_fixture[2]) assert symlink_path.is_symlink() generator = FileIDGenerator() - with pytest.raises(ValueError, - match="is a symlink; must be an actual path"): + with pytest.raises(ValueError, match="is a symlink; must be an actual path"): generator.id_from_path(file_path=symlink_path) helper_functions.windows_safe_cleanup(file_path=symlink_path) -def test_file_id_generator( - some_files_fixture): +def test_file_id_generator(some_files_fixture): """ Test that FileIDGenerator assigns unique IDs to files """ diff --git a/allensdk/test/brain_observatory/data_release_utils/metadata_utils/test_utils.py b/allensdk/test/brain_observatory/data_release_utils/metadata_utils/test_utils.py index c9b50355da..9a7929a1fe 100644 --- a/allensdk/test/brain_observatory/data_release_utils/metadata_utils/test_utils.py +++ b/allensdk/test/brain_observatory/data_release_utils/metadata_utils/test_utils.py @@ -1,61 +1,47 @@ import pytest import pandas as pd -from allensdk.brain_observatory.data_release_utils \ - .metadata_utils.id_generator import ( - FileIDGenerator) +from allensdk.brain_observatory.data_release_utils.metadata_utils.id_generator import FileIDGenerator -from allensdk.brain_observatory.data_release_utils \ - .metadata_utils.utils import ( - add_file_paths_to_metadata_table) +from allensdk.brain_observatory.data_release_utils.metadata_utils.utils import add_file_paths_to_metadata_table -def test_add_file_paths_to_metadata_table_on_missing_error( - some_files_fixture, - metadata_table_fixture): +def test_add_file_paths_to_metadata_table_on_missing_error(some_files_fixture, metadata_table_fixture): """ Test that an error is raised by add_file_paths_to_metadata_table when on_missing_file is a nonsense value """ - with pytest.raises(ValueError, - match="on_missing_file must be one of"): + with pytest.raises(ValueError, match="on_missing_file must be one of"): add_file_paths_to_metadata_table( metadata_table=metadata_table_fixture, id_generator=FileIDGenerator(), file_dir=some_files_fixture[0].parent, - file_prefix='silly_file', - index_col='file_index', - on_missing_file='whatever', - data_dir_col='session_id' + file_prefix="silly_file", + index_col="file_index", + on_missing_file="whatever", + data_dir_col="session_id", ) -def test_add_file_paths_to_metadata_table_no_file_error( - some_files_fixture, - metadata_table_fixture): +def test_add_file_paths_to_metadata_table_no_file_error(some_files_fixture, metadata_table_fixture): """ Test that an error is raised by add_file_paths_to_metadata_table when files are missing (if requested) """ - with pytest.raises(RuntimeError, - match="The following files do not exist"): + with pytest.raises(RuntimeError, match="The following files do not exist"): add_file_paths_to_metadata_table( metadata_table=metadata_table_fixture, id_generator=FileIDGenerator(), file_dir=some_files_fixture[0].parent, - file_prefix='silly_file', - index_col='file_index', - on_missing_file='error', - data_dir_col='session_id' + file_prefix="silly_file", + index_col="file_index", + on_missing_file="error", + data_dir_col="session_id", ) -@pytest.mark.parametrize( - 'on_missing_file', ['skip', 'warn']) -def test_add_file_paths_to_metadata_table( - some_files_fixture, - metadata_table_fixture, - on_missing_file): +@pytest.mark.parametrize("on_missing_file", ["skip", "warn"]) +def test_add_file_paths_to_metadata_table(some_files_fixture, metadata_table_fixture, on_missing_file): """ Test that add_file_paths_to_metadata_table behaves as expected when not raising an error @@ -65,39 +51,37 @@ def test_add_file_paths_to_metadata_table( id_generator = FileIDGenerator() - with pytest.warns(UserWarning, - match='The following files do not exist'): + with pytest.warns(UserWarning, match="The following files do not exist"): result = add_file_paths_to_metadata_table( metadata_table=metadata_table_fixture, id_generator=id_generator, file_dir=file_dir, - file_prefix='silly_file', - index_col='file_index', + file_prefix="silly_file", + index_col="file_index", on_missing_file=on_missing_file, - data_dir_col='session_id' + data_dir_col="session_id", ) # because we have not yet added file_id and file_path # to expected assert not expected.equals(result) - if on_missing_file == 'skip': + if on_missing_file == "skip": assert len(result) == len(some_files_fixture) else: assert len(result) == len(some_files_fixture) + 2 - expected['file_id'] = id_generator.dummy_value - expected['file_path'] = 'nothing' + expected["file_id"] = id_generator.dummy_value + expected["file_path"] = "nothing" for file_idx in expected.file_index: - file_path = file_dir / f'{file_idx}' / f'silly_file_{file_idx}.nwb' + file_path = file_dir / f"{file_idx}" / f"silly_file_{file_idx}.nwb" str_path = str(file_path.resolve().absolute()) - expected.loc[expected.file_index == file_idx, 'file_path'] = str_path + expected.loc[expected.file_index == file_idx, "file_path"] = str_path if file_path.exists(): file_id = id_generator.id_from_path(file_path=file_path) - expected.loc[expected.file_index == file_idx, 'file_id'] = file_id - elif on_missing_file == 'skip': - expected = expected.drop( - expected.loc[expected.file_index == file_idx].index) + expected.loc[expected.file_index == file_idx, "file_id"] = file_id + elif on_missing_file == "skip": + expected = expected.drop(expected.loc[expected.file_index == file_idx].index) pd.testing.assert_frame_equal(result, expected) diff --git a/allensdk/test/brain_observatory/ecephys/align_timestamps/test_align_timestamps_module.py b/allensdk/test/brain_observatory/ecephys/align_timestamps/test_align_timestamps_module.py index 273bb6f506..1406a5965d 100644 --- a/allensdk/test/brain_observatory/ecephys/align_timestamps/test_align_timestamps_module.py +++ b/allensdk/test/brain_observatory/ecephys/align_timestamps/test_align_timestamps_module.py @@ -12,10 +12,8 @@ ) -def apply_input_json_template( - template_path, input_json_path, temp_dir, data_dir=DATA_DIR -): - """ A utility for adjusting the input json so that: +def apply_input_json_template(template_path, input_json_path, temp_dir, data_dir=DATA_DIR): + """A utility for adjusting the input json so that: 1. input paths find cached data in the data dir 2. output paths write to a specified temp_dir The adjusted input json will be written to temp_dir. @@ -25,26 +23,15 @@ def apply_input_json_template( with open(template_path, "r") as input_json_file: input_json_data = json.load(input_json_file) - input_json_data["sync_h5_path"] = os.path.join( - data_dir, input_json_data["sync_h5_path"] - ) + input_json_data["sync_h5_path"] = os.path.join(data_dir, input_json_data["sync_h5_path"]) for probe in input_json_data["probes"]: - - probe["barcode_channel_states_path"] = os.path.join( - data_dir, probe["barcode_channel_states_path"] - ) - probe["barcode_timestamps_path"] = os.path.join( - data_dir, probe["barcode_timestamps_path"] - ) + probe["barcode_channel_states_path"] = os.path.join(data_dir, probe["barcode_channel_states_path"]) + probe["barcode_timestamps_path"] = os.path.join(data_dir, probe["barcode_timestamps_path"]) for timestamps_file in probe["mappable_timestamp_files"]: - timestamps_file["input_path"] = os.path.join( - data_dir, timestamps_file["input_path"] - ) - timestamps_file["output_path"] = os.path.join( - temp_dir, timestamps_file["output_path"] - ) + timestamps_file["input_path"] = os.path.join(data_dir, timestamps_file["input_path"]) + timestamps_file["output_path"] = os.path.join(temp_dir, timestamps_file["output_path"]) with open(input_json_path, "w") as input_json_file: json.dump(input_json_data, input_json_file) @@ -70,20 +57,12 @@ def align_timestamps_706875901_expected_params(): def align_timestamps_706875901_expected_files(): return lambda data_dir: { "probeA": { - "spikes_timestamps": os.path.join( - data_dir, "706875901_probeA_aligned_spike_timestamps.npy" - ), - "lfp_timestamps": os.path.join( - data_dir, "706875901_probeA_aligned_lfp_timestamps.npy" - ), + "spikes_timestamps": os.path.join(data_dir, "706875901_probeA_aligned_spike_timestamps.npy"), + "lfp_timestamps": os.path.join(data_dir, "706875901_probeA_aligned_lfp_timestamps.npy"), }, "probeB": { - "spikes_timestamps": os.path.join( - data_dir, "706875901_probeB_aligned_spike_timestamps.npy" - ), - "lfp_timestamps": os.path.join( - data_dir, "706875901_probeB_aligned_lfp_timestamps.npy" - ), + "spikes_timestamps": os.path.join(data_dir, "706875901_probeB_aligned_spike_timestamps.npy"), + "lfp_timestamps": os.path.join(data_dir, "706875901_probeB_aligned_lfp_timestamps.npy"), }, } @@ -98,9 +77,7 @@ def run_align_timestamps_706875901(tmpdir_factory): executable.extend(["--input_json", input_json_path]) executable.extend(["--output_json", output_json_path]) - input_json_template_path = os.path.join( - DATA_DIR, "706875901_align_timestamps_input.json" - ) + input_json_template_path = os.path.join(DATA_DIR, "706875901_align_timestamps_input.json") apply_input_json_template(input_json_template_path, input_json_path, base_path) sp.check_call(executable) @@ -112,7 +89,6 @@ def run_align_timestamps_706875901(tmpdir_factory): def test_align_timestamps_parameters_706875901( run_align_timestamps_706875901, align_timestamps_706875901_expected_params ): - with open(run_align_timestamps_706875901, "r") as output_json_file: output_json_data = json.load(output_json_file) @@ -120,27 +96,17 @@ def test_align_timestamps_parameters_706875901( expected = align_timestamps_706875901_expected_params[probe["name"]] assert expected["total_time_shift"] == probe["total_time_shift"] - assert ( - expected["global_probe_sampling_rate"] - == probe["global_probe_sampling_rate"] - ) - assert ( - expected["global_probe_lfp_sampling_rate"] - == probe["global_probe_lfp_sampling_rate"] - ) + assert expected["global_probe_sampling_rate"] == probe["global_probe_sampling_rate"] + assert expected["global_probe_lfp_sampling_rate"] == probe["global_probe_lfp_sampling_rate"] @pytest.mark.requires_bamboo -def test_align_timestamps_files_706875901( - run_align_timestamps_706875901, align_timestamps_706875901_expected_files -): - +def test_align_timestamps_files_706875901(run_align_timestamps_706875901, align_timestamps_706875901_expected_files): with open(run_align_timestamps_706875901, "r") as output_json_file: output_json_data = json.load(output_json_file) expected_files = align_timestamps_706875901_expected_files(DATA_DIR) for probe in output_json_data["probe_outputs"]: - for output_file_key, output_file_path in probe["output_paths"].items(): expected_file_path = expected_files[probe["name"]][output_file_key] expected_data = np.load(expected_file_path, allow_pickle=False) @@ -152,7 +118,6 @@ def test_align_timestamps_files_706875901( @pytest.mark.requires_bamboo def test_align_timestamps_barcode_agreement_706875901(run_align_timestamps_706875901): - with open(run_align_timestamps_706875901, "r") as output_json_file: output_json_data = json.load(output_json_file) @@ -167,13 +132,9 @@ def test_align_timestamps_barcode_agreement_706875901(run_align_timestamps_70687 barcode_data = np.load(probe["barcode_timestamps_path"], allow_pickle=False) total_time_shift = probe_parameters[name]["total_time_shift"] - global_probe_sampling_rate = probe_parameters[name][ - "global_probe_sampling_rate" - ] + global_probe_sampling_rate = probe_parameters[name]["global_probe_sampling_rate"] - aligned_barcode_data.append( - barcode_data / global_probe_sampling_rate - total_time_shift - ) + aligned_barcode_data.append(barcode_data / global_probe_sampling_rate - total_time_shift) barcode_timestamp_lengths.append(len(aligned_barcode_data)) min_length = np.amin(barcode_timestamp_lengths) diff --git a/allensdk/test/brain_observatory/ecephys/align_timestamps/test_barcode.py b/allensdk/test/brain_observatory/ecephys/align_timestamps/test_barcode.py index bb260c217a..c52824d310 100644 --- a/allensdk/test/brain_observatory/ecephys/align_timestamps/test_barcode.py +++ b/allensdk/test/brain_observatory/ecephys/align_timestamps/test_barcode.py @@ -6,7 +6,6 @@ @pytest.fixture def two_barcodes(): - on_times = np.array([11, 14, 30, 32, 34]) off_times = np.array([13, 15, 31, 33, 35]) @@ -20,7 +19,6 @@ def two_barcodes(): @pytest.fixture def master_barcodes_sequence(): - master_times = np.array([10, 25, 30, 37, 44, 45]) master_barcodes = np.array([1, 2, 3, 4, 5, 6]) @@ -28,7 +26,6 @@ def master_barcodes_sequence(): def test_extract_barcodes_from_times(two_barcodes): - starts_obt, codes_obt = barcode.extract_barcodes_from_times(*two_barcodes) starts_exp = [30] @@ -44,14 +41,11 @@ def test_extract_barcodes_from_times(two_barcodes): @pytest.mark.parametrize("prate", [10]) # , 1, -7, 0.1]) @pytest.mark.parametrize("npcodes", [-1]) # , 3]) def test_get_time_offset(sc, tr, sind, prate, npcodes, master_barcodes_sequence): - master_times, master_barcodes = master_barcodes_sequence probe_times = (master_times[:npcodes] + tr) * sc probe_barcodes = master_barcodes[:npcodes] - obt = barcode.get_probe_time_offset( - master_times, master_barcodes, probe_times, probe_barcodes, sind, prate - ) + obt = barcode.get_probe_time_offset(master_times, master_barcodes, probe_times, probe_barcodes, sind, prate) obt = [obt[0][0], obt[1][0], (obt[2][0][0], obt[2][1][0])] # total_time_shift, probe_rate, master_endpoints @@ -68,7 +62,6 @@ def test_get_time_offset(sc, tr, sind, prate, npcodes, master_barcodes_sequence) @pytest.mark.parametrize("sc", [-10, -2, -1, -0.5, 0.5, 1, 2, 10]) @pytest.mark.parametrize("tr", [-10, -2, -1, -0.5, 0, 0.5, 1, 2, 10]) def test_linear_transform_from_intervals(sc, tr, master_barcodes_sequence): - master = np.array([1, 2]) probe = (master + tr) * sc @@ -80,17 +73,12 @@ def test_linear_transform_from_intervals(sc, tr, master_barcodes_sequence): @pytest.mark.parametrize("sc", [-10, -2, -1, -0.5, 0.5, 1, 2, 10]) @pytest.mark.parametrize("tr", [-10, -2, -1, -0.5, 0, 0.5, 1, 2, 10]) -@pytest.mark.parametrize( - "npcodes", [-1, 3] -) # fails in region [0, 2] due to insufficient samples +@pytest.mark.parametrize("npcodes", [-1, 3]) # fails in region [0, 2] due to insufficient samples def test_match_barcodes(sc, tr, npcodes, master_barcodes_sequence): - master_times, master_barcodes = master_barcodes_sequence probe_times = (master_times + tr) * sc probe_times_cut = probe_times[:npcodes] probe_barcodes = master_barcodes[:npcodes] - pint, mint = barcode.match_barcodes( - master_times, master_barcodes, probe_times_cut, probe_barcodes - ) + pint, mint = barcode.match_barcodes(master_times, master_barcodes, probe_times_cut, probe_barcodes) assert pint[1] - pint[0] == sc * (mint[1] - mint[0]) diff --git a/allensdk/test/brain_observatory/ecephys/align_timestamps/test_barcode_sync_dataset.py b/allensdk/test/brain_observatory/ecephys/align_timestamps/test_barcode_sync_dataset.py index 69e5ff963c..3329b49727 100644 --- a/allensdk/test/brain_observatory/ecephys/align_timestamps/test_barcode_sync_dataset.py +++ b/allensdk/test/brain_observatory/ecephys/align_timestamps/test_barcode_sync_dataset.py @@ -8,11 +8,8 @@ ) -@pytest.mark.parametrize( - "line_labels,expected", [[["barcode"], 0], [["barcodes"], 0], [[], None]] -) +@pytest.mark.parametrize("line_labels,expected", [[["barcode"], 0], [["barcodes"], 0], [[], None]]) def test_barcode_line(line_labels, expected): - dataset = BarcodeSyncDataset() dataset.line_labels = line_labels @@ -32,10 +29,7 @@ def test_barcode_line(line_labels, expected): [1, np.array([30, 50, 50.08]), np.array([31, 50.04, 50.12]), [50], [3], True], ], ) -def test_extract_barcodes( - sample_frequency, rising_edges, falling_edges, times_exp, codes_exp, table -): - +def test_extract_barcodes(sample_frequency, rising_edges, falling_edges, times_exp, codes_exp, table): dataset = BarcodeSyncDataset() dataset.sample_frequency = sample_frequency dataset.line_labels = ["barcode"] @@ -48,7 +42,6 @@ def test_extract_barcodes( "allensdk.brain_observatory.sync_dataset.Dataset.get_falling_edges", return_value=falling_edges, ): - if table: table = dataset.get_barcode_table() times = table["times"] diff --git a/allensdk/test/brain_observatory/ecephys/align_timestamps/test_channel_states.py b/allensdk/test/brain_observatory/ecephys/align_timestamps/test_channel_states.py index 5f0d9eeda5..e1f9164e99 100644 --- a/allensdk/test/brain_observatory/ecephys/align_timestamps/test_channel_states.py +++ b/allensdk/test/brain_observatory/ecephys/align_timestamps/test_channel_states.py @@ -1,4 +1,3 @@ - import pytest import numpy as np @@ -17,10 +16,7 @@ ] ], ) -def test_extract_barcodes_from_states( - sample_frequency, events, times, times_exp, codes_exp -): - +def test_extract_barcodes_from_states(sample_frequency, events, times, times_exp, codes_exp): times, codes = cs.extract_barcodes_from_states(events, times, sample_frequency) assert np.allclose(times, times_exp) diff --git a/allensdk/test/brain_observatory/ecephys/align_timestamps/test_probe_synchronizer.py b/allensdk/test/brain_observatory/ecephys/align_timestamps/test_probe_synchronizer.py index 12e3f9a222..8d9cdc1c86 100644 --- a/allensdk/test/brain_observatory/ecephys/align_timestamps/test_probe_synchronizer.py +++ b/allensdk/test/brain_observatory/ecephys/align_timestamps/test_probe_synchronizer.py @@ -1,4 +1,3 @@ - import pytest import numpy as np @@ -9,7 +8,6 @@ def get_test_barcodes(): - master_barcode_times = np.linspace(0, 30, 10) master_barcodes = np.arange(0, 11) @@ -31,15 +29,12 @@ def get_test_barcodes(): @pytest.fixture def synchronizer(): - local_sampling_rate = 4.0 probe_start_index = 0 mbt, mb, pbt, pb, min_time, max_time = get_test_barcodes() - result = ProbeSynchronizer.compute( - mbt, mb, pbt, pb, min_time, max_time, probe_start_index, local_sampling_rate - ) + result = ProbeSynchronizer.compute(mbt, mb, pbt, pb, min_time, max_time, probe_start_index, local_sampling_rate) return result @@ -61,7 +56,6 @@ def synchronizer(): ], ) def test_call(synchronizer, samples, sync_condition, expected): - if sync_condition in ("master", "probe"): obtained = synchronizer(samples, sync_condition=sync_condition) # print(obtained) @@ -73,5 +67,4 @@ def test_call(synchronizer, samples, sync_condition, expected): def test_sampling_rate_scale(synchronizer): - assert synchronizer.sampling_rate_scale == 0.5 diff --git a/allensdk/test/brain_observatory/ecephys/conftest.py b/allensdk/test/brain_observatory/ecephys/conftest.py index e2f0743a00..42912bb732 100644 --- a/allensdk/test/brain_observatory/ecephys/conftest.py +++ b/allensdk/test/brain_observatory/ecephys/conftest.py @@ -2,10 +2,10 @@ def pytest_ignore_collect(path, config): - ''' + """ The brain_observatory.ecephys submodule uses python 3.6 features that may not be backwards compatible! - ''' + """ if sys.version_info < (3, 6): return True diff --git a/allensdk/test/brain_observatory/ecephys/ecephys_session_api/test_ecephys_nwb1_session_api.py b/allensdk/test/brain_observatory/ecephys/ecephys_session_api/test_ecephys_nwb1_session_api.py index c0e2ad3b12..e05d720ec0 100644 --- a/allensdk/test/brain_observatory/ecephys/ecephys_session_api/test_ecephys_nwb1_session_api.py +++ b/allensdk/test/brain_observatory/ecephys/ecephys_session_api/test_ecephys_nwb1_session_api.py @@ -8,9 +8,9 @@ @pytest.mark.requires_bamboo -@pytest.mark.parametrize("nwb_path", [ - Path("/", "allen", "aibs", "mat", "Kael", "ecephys_data", "mouse412792.spikes.nwb") -]) +@pytest.mark.parametrize( + "nwb_path", [Path("/", "allen", "aibs", "mat", "Kael", "ecephys_data", "mouse412792.spikes.nwb")] +) def test_spikes_nwb1(nwb_path): """ This test was based on the file /allen/aibs/mat/ecephys_data/mouse412792.spikes.nwb. To run this test please copy @@ -23,39 +23,37 @@ def test_spikes_nwb1(nwb_path): # TODO: Convert this NWB 1 file into a NWB 2 file that way we can check that NWB Adaptors return the same data # and computations (minus a few exceptions for missing NWB 1 data). session = EcephysSession.from_nwb_path(path=str(nwb_path), nwb_version=1) - assert(isinstance(session.units, pd.DataFrame)) - assert(len(session.units) == 1363) + assert isinstance(session.units, pd.DataFrame) + assert len(session.units) == 1363 print(session.stimulus_names) - assert(isinstance(session.stimulus_presentations, pd.DataFrame)) - assert(len(session.stimulus_presentations) == 70390) - assert(len(session.get_stimulus_table(['Natural Images_5'])) == 5950) - assert(len(session.get_stimulus_table(['drifting_gratings_2'])) == 630) - assert(len(session.get_stimulus_table(['flash_250ms_1'])) == 150) - assert(len(session.get_stimulus_table(['gabor_20_deg_250ms_0'])) == 3645) - assert(len(session.get_stimulus_table(['natural_movie_one_three'])) == 18000) - assert(len(session.get_stimulus_table(['natural_movie_three_four'])) == 36000) - assert(len(session.get_stimulus_table(['spontaneous'])) == 15) - assert(len(session.get_stimulus_table(['static_gratings_6'])) == 6000) + assert isinstance(session.stimulus_presentations, pd.DataFrame) + assert len(session.stimulus_presentations) == 70390 + assert len(session.get_stimulus_table(["Natural Images_5"])) == 5950 + assert len(session.get_stimulus_table(["drifting_gratings_2"])) == 630 + assert len(session.get_stimulus_table(["flash_250ms_1"])) == 150 + assert len(session.get_stimulus_table(["gabor_20_deg_250ms_0"])) == 3645 + assert len(session.get_stimulus_table(["natural_movie_one_three"])) == 18000 + assert len(session.get_stimulus_table(["natural_movie_three_four"])) == 36000 + assert len(session.get_stimulus_table(["spontaneous"])) == 15 + assert len(session.get_stimulus_table(["static_gratings_6"])) == 6000 - assert(session.running_speed.shape[0] == 365700) + assert session.running_speed.shape[0] == 365700 - assert(len(session.spike_times.keys()) == 1363) + assert len(session.spike_times.keys()) == 1363 - assert(len(session.mean_waveforms.keys()) == 1363) + assert len(session.mean_waveforms.keys()) == 1363 one_waveform = next(iter(session.mean_waveforms.values())) - assert(isinstance(one_waveform, xr.DataArray)) + assert isinstance(one_waveform, xr.DataArray) - assert(len(session.probes) == 6) - assert(len(session.channels) == 737) + assert len(session.probes) == 6 + assert len(session.channels) == 737 pst = session.presentationwise_spike_times() - assert(isinstance(pst, pd.DataFrame) and len(pst) > 0) + assert isinstance(pst, pd.DataFrame) and len(pst) > 0 cpc = session.conditionwise_spike_statistics( stimulus_presentation_ids=session.stimulus_presentations.index.values[:40] ) - assert(isinstance(cpc, pd.DataFrame) and len(cpc) > 0) - - + assert isinstance(cpc, pd.DataFrame) and len(cpc) > 0 diff --git a/allensdk/test/brain_observatory/ecephys/stimulus_analysis/conftest.py b/allensdk/test/brain_observatory/ecephys/stimulus_analysis/conftest.py index 7daf9984c8..5971e9eaba 100644 --- a/allensdk/test/brain_observatory/ecephys/stimulus_analysis/conftest.py +++ b/allensdk/test/brain_observatory/ecephys/stimulus_analysis/conftest.py @@ -6,6 +6,7 @@ class MockSessionApi(EcephysSessionApi): """Mock Data to create an EcephysSession object and pass it into stimulus analysis""" + def get_spike_times(self): return { 0: np.array([1, 2, 3, 4]), @@ -13,53 +14,62 @@ def get_spike_times(self): 2: np.array([1.01, 1.03, 1.02]), 3: np.array([]), 4: np.array([0.01, 1.7, 2.13, 3.19, 4.25]), - 5: np.array([1.5, 3.0, 4.5]) + 5: np.array([1.5, 3.0, 4.5]), } def get_channels(self): - return pd.DataFrame({ - 'local_index': [0, 1, 2], - 'probe_horizontal_position': [5, 10, 15], - 'probe_id': [0, 0, 1], - 'probe_vertical_position': [10, 22, 33], - 'valid_data': [False, True, True] - }, index=pd.Index(name='channel_id', data=[0, 1, 2])) + return pd.DataFrame( + { + "local_index": [0, 1, 2], + "probe_horizontal_position": [5, 10, 15], + "probe_id": [0, 0, 1], + "probe_vertical_position": [10, 22, 33], + "valid_data": [False, True, True], + }, + index=pd.Index(name="channel_id", data=[0, 1, 2]), + ) def get_units(self): - udf = pd.DataFrame({ - 'firing_rate': np.linspace(1, 3, 6), - 'isi_violations': [40, 0.5, 0.1, 0.2, 0.0, 0.1], - 'local_index': [0, 0, 1, 1, 2, 2], - 'peak_channel_id': [0, 2, 1, 1, 2, 0], - 'quality': ['good', 'good', 'good', 'bad', 'good', 'good'], - }, index=pd.Index(name='unit_id', data=np.arange(6)[::-1])) + udf = pd.DataFrame( + { + "firing_rate": np.linspace(1, 3, 6), + "isi_violations": [40, 0.5, 0.1, 0.2, 0.0, 0.1], + "local_index": [0, 0, 1, 1, 2, 2], + "peak_channel_id": [0, 2, 1, 1, 2, 0], + "quality": ["good", "good", "good", "bad", "good", "good"], + }, + index=pd.Index(name="unit_id", data=np.arange(6)[::-1]), + ) return udf def get_probes(self): - return pd.DataFrame({ - 'description': ['probeA', 'probeB'], - 'location': ['VISp', 'VISam'], - 'sampling_rate': [30000.0, 30000.0] - }, index=pd.Index(name='id', data=[0, 1])) + return pd.DataFrame( + {"description": ["probeA", "probeB"], "location": ["VISp", "VISam"], "sampling_rate": [30000.0, 30000.0]}, + index=pd.Index(name="id", data=[0, 1]), + ) def get_stimulus_presentations(self): - return pd.DataFrame({ - 'start_time': np.linspace(0.0, 4.5, 10, endpoint=True), - 'stop_time': np.linspace(0.5, 5.0, 10, endpoint=True), - 'stimulus_name': ['spontaneous'] + ['s0'] * 6 + ['spontaneous'] + ['s1'] * 2, - 'stimulus_block': [0] + [1] * 6 + [0] + [2] * 2, - 'duration': 0.5, - 'stimulus_index': [0] + [1] * 6 + [0] + [2] * 2, - 'conditions': [0, 0, 0, 0, 1, 1, 1, 0, 2, 3] # generic stimulus condition - }, index=pd.Index(name='id', data=np.arange(10))) + return pd.DataFrame( + { + "start_time": np.linspace(0.0, 4.5, 10, endpoint=True), + "stop_time": np.linspace(0.5, 5.0, 10, endpoint=True), + "stimulus_name": ["spontaneous"] + ["s0"] * 6 + ["spontaneous"] + ["s1"] * 2, + "stimulus_block": [0] + [1] * 6 + [0] + [2] * 2, + "duration": 0.5, + "stimulus_index": [0] + [1] * 6 + [0] + [2] * 2, + "conditions": [0, 0, 0, 0, 1, 1, 1, 0, 2, 3], # generic stimulus condition + }, + index=pd.Index(name="id", data=np.arange(10)), + ) def get_invalid_times(self): return pd.DataFrame() - def get_running_speed(self): - return pd.DataFrame({ - "start_time": np.linspace(0.0, 9.9, 100), - "end_time": np.linspace(0.1, 10.0, 100), - "velocity": np.linspace(-0.1, 11.0, 100) - }) \ No newline at end of file + return pd.DataFrame( + { + "start_time": np.linspace(0.0, 9.9, 100), + "end_time": np.linspace(0.1, 10.0, 100), + "velocity": np.linspace(-0.1, 11.0, 100), + } + ) diff --git a/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_dot_motion.py b/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_dot_motion.py index 1f3970c9a1..06f5e55243 100644 --- a/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_dot_motion.py +++ b/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_dot_motion.py @@ -3,36 +3,34 @@ import pandas as pd from .conftest import MockSessionApi -from allensdk.brain_observatory.ecephys.stimulus_analysis.dot_motion \ - import DotMotion +from allensdk.brain_observatory.ecephys.stimulus_analysis.dot_motion import DotMotion from allensdk.brain_observatory.ecephys.ecephys_session import EcephysSession class MockDMSessionApi(MockSessionApi): def get_stimulus_presentations(self): - features = np.array(np.meshgrid([0.0, 45.0, 90.0, 135.0, 180.0, - 225.0, 270.0, 315.0], # Dir - [0.001, 0.005, 0.01, 0.02]) # Speed - ).reshape(2, 32) - - features = np.concatenate((features, - np.array([np.nan, np.nan]).reshape((2, 1))), - axis=1) # null case - - return pd.DataFrame({ - 'start_time': np.concatenate(([0.0], np.linspace(0.5, 32.5, 33, - endpoint=True), [33.5])), - 'stop_time': np.concatenate(([0.5], np.linspace(1.5, 33.5, 33, - endpoint=True), [34.0])), - 'stimulus_name': ['spontaneous'] + - ['dot_motion']*33 + - ['spontaneous'], - 'stimulus_block': [0] + [1]*33 + [0], - 'duration': [0.5] + [1.0]*33 + [0.5], - 'stimulus_index': [0] + [1]*33 + [0], - 'Dir': np.concatenate(([np.nan], features[0, :], [np.nan])), - 'Speed': np.concatenate(([np.nan], features[1, :], [np.nan])) - }, index=pd.Index(name='id', data=np.arange(35))) + features = np.array( + np.meshgrid( + [0.0, 45.0, 90.0, 135.0, 180.0, 225.0, 270.0, 315.0], # Dir + [0.001, 0.005, 0.01, 0.02], + ) # Speed + ).reshape(2, 32) + + features = np.concatenate((features, np.array([np.nan, np.nan]).reshape((2, 1))), axis=1) # null case + + return pd.DataFrame( + { + "start_time": np.concatenate(([0.0], np.linspace(0.5, 32.5, 33, endpoint=True), [33.5])), + "stop_time": np.concatenate(([0.5], np.linspace(1.5, 33.5, 33, endpoint=True), [34.0])), + "stimulus_name": ["spontaneous"] + ["dot_motion"] * 33 + ["spontaneous"], + "stimulus_block": [0] + [1] * 33 + [0], + "duration": [0.5] + [1.0] * 33 + [0.5], + "stimulus_index": [0] + [1] * 33 + [0], + "Dir": np.concatenate(([np.nan], features[0, :], [np.nan])), + "Speed": np.concatenate(([np.nan], features[1, :], [np.nan])), + }, + index=pd.Index(name="id", data=np.arange(35)), + ) def get_invalid_times(self): return pd.DataFrame() @@ -46,61 +44,57 @@ def ecephys_api(): def test_load(ecephys_api): session = EcephysSession(api=ecephys_api) dm = DotMotion(ecephys_session=session) - assert(dm.name == 'Dot Motion') - assert(set(dm.unit_ids) == set(range(6))) - assert(len(dm.conditionwise_statistics) == 33*6) - assert(dm.conditionwise_psth.shape == (33, 1.0/0.001-1, 6)) - assert(not dm.presentationwise_spike_times.empty) - assert(len(dm.presentationwise_statistics) == 33*6) - assert(len(dm.stimulus_conditions) == 33) + assert dm.name == "Dot Motion" + assert set(dm.unit_ids) == set(range(6)) + assert len(dm.conditionwise_statistics) == 33 * 6 + assert dm.conditionwise_psth.shape == (33, 1.0 / 0.001 - 1, 6) + assert not dm.presentationwise_spike_times.empty + assert len(dm.presentationwise_statistics) == 33 * 6 + assert len(dm.stimulus_conditions) == 33 def test_stimulus(ecephys_api): session = EcephysSession(api=ecephys_api) dm = DotMotion(ecephys_session=session) - assert(isinstance(dm.stim_table, pd.DataFrame)) - assert(len(dm.stim_table) == 33) - assert(set(dm.stim_table.columns).issuperset({'Dir', - 'Speed', - 'start_time', - 'stop_time'})) + assert isinstance(dm.stim_table, pd.DataFrame) + assert len(dm.stim_table) == 33 + assert set(dm.stim_table.columns).issuperset({"Dir", "Speed", "start_time", "stop_time"}) - assert(set(dm.directions) == {0.0, 45.0, 90.0, 135.0, - 180.0, 225.0, 270.0, 315.0}) - assert(dm.number_directions == 8) + assert set(dm.directions) == {0.0, 45.0, 90.0, 135.0, 180.0, 225.0, 270.0, 315.0} + assert dm.number_directions == 8 - assert(set(dm.speeds) == {0.001, 0.005, 0.01, 0.02}) - assert(dm.number_speeds == 4) + assert set(dm.speeds) == {0.001, 0.005, 0.01, 0.02} + assert dm.number_speeds == 4 def test_metrics(ecephys_api): session = EcephysSession(api=ecephys_api) rfm = DotMotion(ecephys_session=session) - assert(isinstance(rfm.metrics, pd.DataFrame)) - assert(len(rfm.metrics) == 6) - assert(rfm.metrics.index.names == ['unit_id']) + assert isinstance(rfm.metrics, pd.DataFrame) + assert len(rfm.metrics) == 6 + assert rfm.metrics.index.names == ["unit_id"] - assert('pref_speed_dm' in rfm.metrics.columns) - assert(rfm.metrics['pref_speed_dm'].loc[0] == 0.001) - assert(rfm.metrics['pref_speed_dm'].loc[5] == 0.001) + assert "pref_speed_dm" in rfm.metrics.columns + assert rfm.metrics["pref_speed_dm"].loc[0] == 0.001 + assert rfm.metrics["pref_speed_dm"].loc[5] == 0.001 - assert('pref_dir_dm' in rfm.metrics.columns) - assert(rfm.metrics['pref_dir_dm'].loc[0] == 0.0) - assert(rfm.metrics['pref_dir_dm'].loc[4] == 45.0) + assert "pref_dir_dm" in rfm.metrics.columns + assert rfm.metrics["pref_dir_dm"].loc[0] == 0.0 + assert rfm.metrics["pref_dir_dm"].loc[4] == 45.0 - assert('firing_rate_dm' in rfm.metrics.columns) - assert('fano_dm' in rfm.metrics.columns) - assert('lifetime_sparseness_dm' in rfm.metrics.columns) - assert('run_pval_dm' in rfm.metrics.columns) - assert('run_mod_dm' in rfm.metrics.columns) + assert "firing_rate_dm" in rfm.metrics.columns + assert "fano_dm" in rfm.metrics.columns + assert "lifetime_sparseness_dm" in rfm.metrics.columns + assert "run_pval_dm" in rfm.metrics.columns + assert "run_mod_dm" in rfm.metrics.columns -@pytest.mark.skip(reason='metric not yet implemented') +@pytest.mark.skip(reason="metric not yet implemented") def test_speed_tuning_idx(): pass -if __name__ == '__main__': +if __name__ == "__main__": # test_load() # test_stimulus() test_metrics() diff --git a/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_drifting_gratings.py b/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_drifting_gratings.py index f8ab53fd22..2184100ce8 100644 --- a/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_drifting_gratings.py +++ b/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_drifting_gratings.py @@ -4,10 +4,14 @@ from .conftest import MockSessionApi from allensdk.brain_observatory.ecephys.ecephys_session import EcephysSession -from allensdk.brain_observatory.ecephys.stimulus_analysis.drifting_gratings \ - import DriftingGratings, modulation_index, c50, f1_f0 +from allensdk.brain_observatory.ecephys.stimulus_analysis.drifting_gratings import ( + DriftingGratings, + modulation_index, + c50, + f1_f0, +) -pd.set_option('display.max_columns', None) +pd.set_option("display.max_columns", None) class MockDGSessionApi(MockSessionApi): @@ -23,60 +27,57 @@ def get_spike_times(self): 1: np.array([2.5]), 2: np.array([1.01, 1.03, 1.02]), 3: np.array([]), - 4: np.array( - [0.01, 1.7, 2.13, 3.19, 4.25, 46.4, 48.7, 54.2, 80.3, 85.40, - 85.44, 85.47]), - + 4: np.array([0.01, 1.7, 2.13, 3.19, 4.25, 46.4, 48.7, 54.2, 80.3, 85.40, 85.44, 85.47]), # 5: np.array([1.5, 3.0, 4.5, 90.1]) # make sure there is a # spike for the contrast stimulus - 5: np.concatenate(([1.5, 3.0, 4.5], np.linspace(85.0, 89.0, 20))) + 5: np.concatenate(([1.5, 3.0, 4.5], np.linspace(85.0, 89.0, 20))), } def get_stimulus_presentations(self): - features = np.array(np.meshgrid([1.0, 2.0, 4.0, 8.0, 15.0], # TF - [0.0, 45.0, 90.0, 135.0, 180.0, 225.0, - 270.0, 315.0]) # ORI - ).reshape(2, 40) - - stim_table = pd.DataFrame({ - 'start_time': np.concatenate( - ([0.0], np.linspace(0.5, 78.5, 40, endpoint=True), [80.0])), - 'stop_time': np.concatenate( - ([0.0], np.linspace(2.5, 80.5, 40, endpoint=True), [81.0])), - 'stimulus_name': ['spontaneous'] + ['drifting_gratings'] * 40 + [ - 'spontaneous'], - 'stimulus_block': [0] + [1] * 40 + [0], - 'duration': [0.5] + [2.0] * 40 + [0.5], - 'stimulus_index': [0] + [1] * 40 + [0], - 'temporal_frequency': np.concatenate( - ([np.nan], features[0, :], [np.nan])), - 'orientation': np.concatenate( - ([np.nan], features[1, :], [np.nan])), - 'contrast': 0.8 - }, index=pd.Index(name='id', data=np.arange(42))) + features = np.array( + np.meshgrid( + [1.0, 2.0, 4.0, 8.0, 15.0], # TF + [0.0, 45.0, 90.0, 135.0, 180.0, 225.0, 270.0, 315.0], + ) # ORI + ).reshape(2, 40) + + stim_table = pd.DataFrame( + { + "start_time": np.concatenate(([0.0], np.linspace(0.5, 78.5, 40, endpoint=True), [80.0])), + "stop_time": np.concatenate(([0.0], np.linspace(2.5, 80.5, 40, endpoint=True), [81.0])), + "stimulus_name": ["spontaneous"] + ["drifting_gratings"] * 40 + ["spontaneous"], + "stimulus_block": [0] + [1] * 40 + [0], + "duration": [0.5] + [2.0] * 40 + [0.5], + "stimulus_index": [0] + [1] * 40 + [0], + "temporal_frequency": np.concatenate(([np.nan], features[0, :], [np.nan])), + "orientation": np.concatenate(([np.nan], features[1, :], [np.nan])), + "contrast": 0.8, + }, + index=pd.Index(name="id", data=np.arange(42)), + ) if self._with_dg_contrast: - features = np.array(np.meshgrid([0.0, 45.0, 90.0, 135.0], # ORI - [0.01, 0.02, 0.04, 0.08, 0.13, 0.2, - 0.35, 0.6, 1.0]) # contrast - ).reshape(2, 36) - - dg_constrast = pd.DataFrame({ - 'start_time': np.concatenate( - (80.0 + np.linspace(0.0, 17.5, 36, endpoint=True), - [97.5])), - 'stop_time': np.concatenate((81.5 + np.linspace(0.5, 18.0, 36, - endpoint=True), - [98.0])), - 'stimulus_name': ['drifting_gratings_contrast'] * 36 + [ - 'spontaneous'], - 'stimulus_block': [2] * 36 + [0], - 'duration': [0.5] * 36 + [0.5], - 'stimulus_index': [2] * 36 + [0], - 'temporal_frequency': 2.0, - 'orientation': np.concatenate((features[0, :], [np.nan])), - 'contrast': np.concatenate((features[1, :], [np.nan])) - }, index=pd.Index(name='id', data=np.arange(42, 42 + 37))) + features = np.array( + np.meshgrid( + [0.0, 45.0, 90.0, 135.0], # ORI + [0.01, 0.02, 0.04, 0.08, 0.13, 0.2, 0.35, 0.6, 1.0], + ) # contrast + ).reshape(2, 36) + + dg_constrast = pd.DataFrame( + { + "start_time": np.concatenate((80.0 + np.linspace(0.0, 17.5, 36, endpoint=True), [97.5])), + "stop_time": np.concatenate((81.5 + np.linspace(0.5, 18.0, 36, endpoint=True), [98.0])), + "stimulus_name": ["drifting_gratings_contrast"] * 36 + ["spontaneous"], + "stimulus_block": [2] * 36 + [0], + "duration": [0.5] * 36 + [0.5], + "stimulus_index": [2] * 36 + [0], + "temporal_frequency": 2.0, + "orientation": np.concatenate((features[0, :], [np.nan])), + "contrast": np.concatenate((features[1, :], [np.nan])), + }, + index=pd.Index(name="id", data=np.arange(42, 42 + 37)), + ) stim_table = pd.concat((stim_table, dg_constrast)) return stim_table @@ -98,190 +99,179 @@ def ecephys_api_w_contrast(): def test_load(ecephys_api): session = EcephysSession(api=ecephys_api) dg = DriftingGratings(ecephys_session=session) - assert (dg.name == 'Drifting Gratings') - assert (set(dg.unit_ids) == set(range(6))) - assert (len(dg.conditionwise_statistics) == 40 * 6) - assert (dg.conditionwise_psth.shape == (40, 2.0 / 0.001 - 1, 6)) - assert (not dg.presentationwise_spike_times.empty) - assert (len(dg.presentationwise_statistics) == 40 * 6) - assert (len(dg.stimulus_conditions) == 40) + assert dg.name == "Drifting Gratings" + assert set(dg.unit_ids) == set(range(6)) + assert len(dg.conditionwise_statistics) == 40 * 6 + assert dg.conditionwise_psth.shape == (40, 2.0 / 0.001 - 1, 6) + assert not dg.presentationwise_spike_times.empty + assert len(dg.presentationwise_statistics) == 40 * 6 + assert len(dg.stimulus_conditions) == 40 def test_stimulus(ecephys_api): session = EcephysSession(api=ecephys_api) dg = DriftingGratings(ecephys_session=session) - assert (isinstance(dg.stim_table, pd.DataFrame)) - assert (len(dg.stim_table) == 40) - assert (len(dg.stim_table_contrast) == 0) + assert isinstance(dg.stim_table, pd.DataFrame) + assert len(dg.stim_table) == 40 + assert len(dg.stim_table_contrast) == 0 - assert (set(dg.stim_table.columns).issuperset( - {'temporal_frequency', 'orientation', 'contrast', 'start_time', - 'stop_time'})) + assert set(dg.stim_table.columns).issuperset( + {"temporal_frequency", "orientation", "contrast", "start_time", "stop_time"} + ) - assert (set(dg.tfvals) == {1.0, 2.0, 4.0, 8.0, 15.0}) - assert (dg.number_tf == 5) + assert set(dg.tfvals) == {1.0, 2.0, 4.0, 8.0, 15.0} + assert dg.number_tf == 5 - assert (set(dg.orivals) == {0.0, 45.0, 90.0, 135.0, 180.0, 225.0, 270.0, - 315.0}) - assert (dg.number_ori == 8) + assert set(dg.orivals) == {0.0, 45.0, 90.0, 135.0, 180.0, 225.0, 270.0, 315.0} + assert dg.number_ori == 8 - assert (set(dg.contrastvals) == {0.8}) - assert (dg.number_contrast == 1) + assert set(dg.contrastvals) == {0.8} + assert dg.number_contrast == 1 def test_metrics(ecephys_api): # Run metrics with no drifting_gratings_contrast stimuli session = EcephysSession(api=ecephys_api) dg = DriftingGratings(ecephys_session=session) - assert (isinstance(dg.metrics, pd.DataFrame)) - assert (len(dg.metrics) == 6) - assert (dg.metrics.index.names == ['unit_id']) + assert isinstance(dg.metrics, pd.DataFrame) + assert len(dg.metrics) == 6 + assert dg.metrics.index.names == ["unit_id"] - assert ('pref_ori_dg' in dg.metrics.columns) - assert (np.all( - dg.metrics['pref_ori_dg'].loc[[0, 1, 2, 3, 4, 5]] == np.full(6, 0.0))) + assert "pref_ori_dg" in dg.metrics.columns + assert np.all(dg.metrics["pref_ori_dg"].loc[[0, 1, 2, 3, 4, 5]] == np.full(6, 0.0)) - assert ('pref_tf_dg' in dg.metrics.columns) - assert (np.all(dg.metrics['pref_tf_dg'].loc[[0, 5]] == [1.0, 2.0])) + assert "pref_tf_dg" in dg.metrics.columns + assert np.all(dg.metrics["pref_tf_dg"].loc[[0, 5]] == [1.0, 2.0]) # with no contrast stimuli the c50 metric should be null - assert ('c50_dg' in dg.metrics.columns) - assert ( - np.allclose(dg.metrics['c50_dg'].values, [np.nan] * 6, equal_nan=True)) + assert "c50_dg" in dg.metrics.columns + assert np.allclose(dg.metrics["c50_dg"].values, [np.nan] * 6, equal_nan=True) - assert ('f1_f0_dg' in dg.metrics.columns) - assert (np.allclose(dg.metrics['f1_f0_dg'].loc[[0, 1, 2, 3, 4, 5]], - [0.001572, np.nan, 1.999778, np.nan, 1.560436, - 1.999978], equal_nan=True, atol=1.0e-06)) + assert "f1_f0_dg" in dg.metrics.columns + assert np.allclose( + dg.metrics["f1_f0_dg"].loc[[0, 1, 2, 3, 4, 5]], + [0.001572, np.nan, 1.999778, np.nan, 1.560436, 1.999978], + equal_nan=True, + atol=1.0e-06, + ) - assert ('mod_idx_dg' in dg.metrics.columns) - assert ('g_osi_dg' in dg.metrics.columns) - assert (np.allclose(dg.metrics['g_osi_dg'].loc[[0, 3, 4, 5]], - [1.0, np.nan, 0.745356, 1.0], equal_nan=True)) + assert "mod_idx_dg" in dg.metrics.columns + assert "g_osi_dg" in dg.metrics.columns + assert np.allclose(dg.metrics["g_osi_dg"].loc[[0, 3, 4, 5]], [1.0, np.nan, 0.745356, 1.0], equal_nan=True) - assert ('g_dsi_dg' in dg.metrics.columns) - assert (np.allclose(dg.metrics['g_dsi_dg'].loc[[0, 3, 4, 5]], - [1.0, np.nan, 0.491209, 1.0], equal_nan=True)) + assert "g_dsi_dg" in dg.metrics.columns + assert np.allclose(dg.metrics["g_dsi_dg"].loc[[0, 3, 4, 5]], [1.0, np.nan, 0.491209, 1.0], equal_nan=True) - assert ('firing_rate_dg' in dg.metrics.columns) - assert ('fano_dg' in dg.metrics.columns) - assert ('lifetime_sparseness_dg' in dg.metrics.columns) - assert ('run_pval_dg' in dg.metrics.columns) - assert ('run_mod_dg' in dg.metrics.columns) + assert "firing_rate_dg" in dg.metrics.columns + assert "fano_dg" in dg.metrics.columns + assert "lifetime_sparseness_dg" in dg.metrics.columns + assert "run_pval_dg" in dg.metrics.columns + assert "run_mod_dg" in dg.metrics.columns def test_contrast_stimulus(ecephys_api_w_contrast): session = EcephysSession(api=ecephys_api_w_contrast) dg = DriftingGratings(ecephys_session=session) - assert (len(dg.stim_table) == 40) + assert len(dg.stim_table) == 40 - assert (len(dg.stim_table_contrast) == 36) - assert (len(dg.stimulus_conditions_contrast) == 36) - assert (len(dg.conditionwise_statistics_contrast) == 36 * 6) + assert len(dg.stim_table_contrast) == 36 + assert len(dg.stimulus_conditions_contrast) == 36 + assert len(dg.conditionwise_statistics_contrast) == 36 * 6 def test_metric_with_contrast(ecephys_api_w_contrast): session = EcephysSession(api=ecephys_api_w_contrast) dg = DriftingGratings(ecephys_session=session) - assert (isinstance(dg.metrics, pd.DataFrame)) - assert (len(dg.metrics) == 6) - assert (dg.metrics.index.names == ['unit_id']) + assert isinstance(dg.metrics, pd.DataFrame) + assert len(dg.metrics) == 6 + assert dg.metrics.index.names == ["unit_id"] # make sure normal prefered conditions remain the same - assert ('pref_ori_dg' in dg.metrics.columns) - assert (np.all( - dg.metrics['pref_ori_dg'].loc[[0, 1, 2, 3, 4, 5]] == np.full(6, 0.0))) - assert ('pref_tf_dg' in dg.metrics.columns) - assert (np.all(dg.metrics['pref_tf_dg'].loc[[0, 5]] == [1.0, 2.0])) + assert "pref_ori_dg" in dg.metrics.columns + assert np.all(dg.metrics["pref_ori_dg"].loc[[0, 1, 2, 3, 4, 5]] == np.full(6, 0.0)) + assert "pref_tf_dg" in dg.metrics.columns + assert np.all(dg.metrics["pref_tf_dg"].loc[[0, 5]] == [1.0, 2.0]) # Make sure class can see drifting_gratings_contrasts stimuli - assert ('c50_dg' in dg.metrics.columns) - assert np.isfinite(dg.metrics['c50_dg'].loc[0]) - assert np.isnan(dg.metrics['c50_dg'].loc[4]) + assert "c50_dg" in dg.metrics.columns + assert np.isfinite(dg.metrics["c50_dg"].loc[0]) + assert np.isnan(dg.metrics["c50_dg"].loc[4]) # NOTE beginning with a change that updated pandas, pyNWB and numpy # version dependencies, the underlying 'c50' calculation # (drifting_gratings.py) very occasionally is off by one index # in estimating the halfway point in the contrast curve. # accommodating that possibility here: - assert np.allclose(dg.metrics['c50_dg'].loc[[5]], 0.17585882, atol=0.5, - rtol=1.0) - - -@pytest.mark.parametrize('response,tf,sampling_rate,expected', - [ - (np.array([]), 2.0, 1000.0, np.nan), - # invalid input - (np.zeros(2000), 2.0, 1000.0, 0.0), - # no responses, MI ~ 0 - (np.ones(2000), 4.0, 1000.0, 0.0), - # no derivation, MI ~ 0 - (np.linspace(0.5, 12.1), 8.0, 1.0, np.nan), - # tf is outside niquist freq. - (np.array([0.1, 0.2, 0.2, 1.1]), 2.0, 4.0, - 0.1389328986), # low mi - (np.linspace(0.5, 12.1, 50), 8.0, 1000.0, - 4.993941), # high mi - ]) + assert np.allclose(dg.metrics["c50_dg"].loc[[5]], 0.17585882, atol=0.5, rtol=1.0) + + +@pytest.mark.parametrize( + "response,tf,sampling_rate,expected", + [ + (np.array([]), 2.0, 1000.0, np.nan), + # invalid input + (np.zeros(2000), 2.0, 1000.0, 0.0), + # no responses, MI ~ 0 + (np.ones(2000), 4.0, 1000.0, 0.0), + # no derivation, MI ~ 0 + (np.linspace(0.5, 12.1), 8.0, 1.0, np.nan), + # tf is outside niquist freq. + (np.array([0.1, 0.2, 0.2, 1.1]), 2.0, 4.0, 0.1389328986), # low mi + (np.linspace(0.5, 12.1, 50), 8.0, 1000.0, 4.993941), # high mi + ], +) def test_modulation_index(response, tf, sampling_rate, expected): mi = modulation_index(response, tf, sampling_rate) # return nan, invalid - assert (np.isclose(mi, expected, equal_nan=True)) - - -@pytest.mark.parametrize('contrast_vals,responses,expected', - [ - (np.array( - [0.01, 0.02, 0.04, 0.08, 0.13, 0.2, 0.35, 0.6, - 1.0]), np.array([]), np.nan), - # invalid input - (np.array( - [0.01, 0.02, 0.04, 0.08, 0.13, 0.2, 0.35, 0.6, - 1.0]), np.full(9, 12.0), 0.0090), - # flat non-zero curve - (np.array( - [0.01, 0.02, 0.04, 0.08, 0.13, 0.2, 0.35, 0.6, - 1.0]), np.zeros(9), None), - # no responses — degenerate fit, check finite only - (np.array( - [0.01, 0.02, 0.04, 0.08, 0.13, 0.2, 0.35, 0.6, - 1.0]), np.linspace(0.0, 12.0, 9), - 0.1330745098039216), - (np.array( - [0.01, 0.02, 0.04, 0.08, 0.13, 0.2, 0.35, 0.6, - 1.0]), np.array( - [10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0]), np.nan), - # nan, special case where curve can't be fitted - ]) + assert np.isclose(mi, expected, equal_nan=True) + + +@pytest.mark.parametrize( + "contrast_vals,responses,expected", + [ + (np.array([0.01, 0.02, 0.04, 0.08, 0.13, 0.2, 0.35, 0.6, 1.0]), np.array([]), np.nan), + # invalid input + (np.array([0.01, 0.02, 0.04, 0.08, 0.13, 0.2, 0.35, 0.6, 1.0]), np.full(9, 12.0), 0.0090), + # flat non-zero curve + (np.array([0.01, 0.02, 0.04, 0.08, 0.13, 0.2, 0.35, 0.6, 1.0]), np.zeros(9), None), + # no responses — degenerate fit, check finite only + (np.array([0.01, 0.02, 0.04, 0.08, 0.13, 0.2, 0.35, 0.6, 1.0]), np.linspace(0.0, 12.0, 9), 0.1330745098039216), + ( + np.array([0.01, 0.02, 0.04, 0.08, 0.13, 0.2, 0.35, 0.6, 1.0]), + np.array([10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), + np.nan, + ), + # nan, special case where curve can't be fitted + ], +) def test_c50(contrast_vals, responses, expected): c50_metric = c50(contrast_vals, responses) if expected is None: assert np.isfinite(c50_metric) else: - assert (np.isclose(c50_metric, expected, equal_nan=True)) - - -@pytest.mark.parametrize('data_arr,tf,trial_duration,expected', - [ - (np.array([]), 2.0, 1.0, np.nan), # invalid input - (np.zeros((5, 256)), 4.0, 2.0, np.nan), - # no spikes - (np.ones((5, 256)), 18.0, 16.0, np.nan), - # tf*trial_duration is too high, returns nan - (np.full((5, 256), 5.0), 4.0, 2.0, 0.0), - # has constant spiking - (np.array([0, 0, 1, 1, 2, 0, 5, 1]), 2.0, 1.0, - 0.894427190999916), # can handle arrays - (np.array([[0, 0, 1, 1, 2, 0, 5, 1]]), 2.0, 1.0, - 0.894427190999916) - # same as above but int matrix form - ]) + assert np.isclose(c50_metric, expected, equal_nan=True) + + +@pytest.mark.parametrize( + "data_arr,tf,trial_duration,expected", + [ + (np.array([]), 2.0, 1.0, np.nan), # invalid input + (np.zeros((5, 256)), 4.0, 2.0, np.nan), + # no spikes + (np.ones((5, 256)), 18.0, 16.0, np.nan), + # tf*trial_duration is too high, returns nan + (np.full((5, 256), 5.0), 4.0, 2.0, 0.0), + # has constant spiking + (np.array([0, 0, 1, 1, 2, 0, 5, 1]), 2.0, 1.0, 0.894427190999916), # can handle arrays + (np.array([[0, 0, 1, 1, 2, 0, 5, 1]]), 2.0, 1.0, 0.894427190999916), + # same as above but int matrix form + ], +) def test_f1_f0(data_arr, tf, trial_duration, expected): f1_f0_val = f1_f0(data_arr, tf, trial_duration) - assert (np.isclose(f1_f0_val, expected, equal_nan=True)) + assert np.isclose(f1_f0_val, expected, equal_nan=True) -if __name__ == '__main__': +if __name__ == "__main__": # test_stimulus() test_metrics() # test_stim_table_contrast() diff --git a/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_flashes.py b/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_flashes.py index 5d245999d5..c5f848562c 100644 --- a/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_flashes.py +++ b/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_flashes.py @@ -9,21 +9,23 @@ class MockFlSessionApi(MockSessionApi): def get_stimulus_presentations(self): - return pd.DataFrame({ - 'start_time': np.concatenate(([0.0], np.linspace(0.5, 4.25, 16, endpoint=True), [4.5])), - 'stop_time': np.concatenate(([0.5], np.linspace(0.75, 4.5, 16, endpoint=True), [5.0])), - 'stimulus_name': ['spontaneous'] + ['flashes']*16 + ['spontaneous'], - 'stimulus_block': [0] + [1]*16 + [0], - 'duration': [0.5] + [0.25]*16 + [0.5], - 'stimulus_index': [0] + [1]*16 + [0], - 'color': [np.nan, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, np.nan] - }, index=pd.Index(name='id', data=np.arange(18))) + return pd.DataFrame( + { + "start_time": np.concatenate(([0.0], np.linspace(0.5, 4.25, 16, endpoint=True), [4.5])), + "stop_time": np.concatenate(([0.5], np.linspace(0.75, 4.5, 16, endpoint=True), [5.0])), + "stimulus_name": ["spontaneous"] + ["flashes"] * 16 + ["spontaneous"], + "stimulus_block": [0] + [1] * 16 + [0], + "duration": [0.5] + [0.25] * 16 + [0.5], + "stimulus_index": [0] + [1] * 16 + [0], + "color": [np.nan, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, np.nan], + }, + index=pd.Index(name="id", data=np.arange(18)), + ) def get_invalid_times(self): return pd.DataFrame() - @pytest.fixture def ecephys_api(): return MockFlSessionApi() @@ -32,50 +34,54 @@ def ecephys_api(): def test_load(ecephys_api): session = EcephysSession(api=ecephys_api) fl = Flashes(ecephys_session=session) - assert(fl.name == 'Flashes') - assert(set(fl.unit_ids) == set(range(6))) - assert(len(fl.conditionwise_statistics) == 2*6) - assert(fl.conditionwise_psth.shape == (2, 249, 6)) - assert(not fl.presentationwise_spike_times.empty) - assert(len(fl.presentationwise_statistics) == 16*6) - assert(len(fl.stimulus_conditions) == 2) + assert fl.name == "Flashes" + assert set(fl.unit_ids) == set(range(6)) + assert len(fl.conditionwise_statistics) == 2 * 6 + assert fl.conditionwise_psth.shape == (2, 249, 6) + assert not fl.presentationwise_spike_times.empty + assert len(fl.presentationwise_statistics) == 16 * 6 + assert len(fl.stimulus_conditions) == 2 def test_stimulus(ecephys_api): session = EcephysSession(api=ecephys_api) fl = Flashes(ecephys_session=session) - assert(isinstance(fl.stim_table, pd.DataFrame)) - assert(len(fl.stim_table) == 16) - assert(set(fl.stim_table.columns).issuperset({'color', 'start_time', 'stop_time'})) + assert isinstance(fl.stim_table, pd.DataFrame) + assert len(fl.stim_table) == 16 + assert set(fl.stim_table.columns).issuperset({"color", "start_time", "stop_time"}) - assert(all(fl.colors == [-1.0, 1.0])) - assert(fl.number_colors == 2) + assert all(fl.colors == [-1.0, 1.0]) + assert fl.number_colors == 2 def test_metrics(ecephys_api): session = EcephysSession(api=ecephys_api) fl = Flashes(ecephys_session=session) - assert(isinstance(fl.metrics, pd.DataFrame)) - assert(len(fl.metrics) == 6) - assert(fl.metrics.index.names == ['unit_id']) - - assert('on_off_ratio_fl' in fl.metrics.columns) - assert(np.allclose(fl.metrics['on_off_ratio_fl'].loc[[0, 1, 2, 3, 4, 5]], - [0.0, np.nan, 0.0, np.nan, 3.0, 2.0], equal_nan=True)) # Check _get_on_off_ratio() method - - assert('sustained_idx_fl' in fl.metrics.columns) - assert(np.allclose(fl.metrics['sustained_idx_fl'].loc[[0, 1, 2, 3, 4, 5]].values, - [0.00401606, np.nan, 0.01204819, np.nan, 0.02811245, 0.00401606], equal_nan=True)) - - assert('firing_rate_fl' in fl.metrics.columns) - assert('time_to_peak_fl' in fl.metrics.columns) - assert('fano_fl' in fl.metrics.columns) - assert('lifetime_sparseness_fl' in fl.metrics.columns) - assert('run_pval_fl' in fl.metrics.columns) - assert('run_mod_fl' in fl.metrics.columns) - - -if __name__ == '__main__': + assert isinstance(fl.metrics, pd.DataFrame) + assert len(fl.metrics) == 6 + assert fl.metrics.index.names == ["unit_id"] + + assert "on_off_ratio_fl" in fl.metrics.columns + assert np.allclose( + fl.metrics["on_off_ratio_fl"].loc[[0, 1, 2, 3, 4, 5]], [0.0, np.nan, 0.0, np.nan, 3.0, 2.0], equal_nan=True + ) # Check _get_on_off_ratio() method + + assert "sustained_idx_fl" in fl.metrics.columns + assert np.allclose( + fl.metrics["sustained_idx_fl"].loc[[0, 1, 2, 3, 4, 5]].values, + [0.00401606, np.nan, 0.01204819, np.nan, 0.02811245, 0.00401606], + equal_nan=True, + ) + + assert "firing_rate_fl" in fl.metrics.columns + assert "time_to_peak_fl" in fl.metrics.columns + assert "fano_fl" in fl.metrics.columns + assert "lifetime_sparseness_fl" in fl.metrics.columns + assert "run_pval_fl" in fl.metrics.columns + assert "run_mod_fl" in fl.metrics.columns + + +if __name__ == "__main__": # test_load() # test_stimulus() test_metrics() diff --git a/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_natural_movies.py b/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_natural_movies.py index 68b5fd942c..60192fc644 100644 --- a/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_natural_movies.py +++ b/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_natural_movies.py @@ -16,12 +16,12 @@ def ecephys_api(): return MockNMSessionApi() -@pytest.mark.skip(reason='NaturalMovies not fully implemented.') +@pytest.mark.skip(reason="NaturalMovies not fully implemented.") def test_load(ecephys_api): session = EcephysSession(api=ecephys_api) nm = NaturalMovies(ecephys_session=session) - assert(nm.name == 'Natural Movies') - assert(set(nm.unit_ids) == set(range(6))) + assert nm.name == "Natural Movies" + assert set(nm.unit_ids) == set(range(6)) # assert(len(nm.conditionwise_statistics) == 119*6) # assert(nm.conditionwise_psth.shape == (119, 249, 6)) # assert(not nm.presentationwise_spike_times.empty) @@ -29,11 +29,11 @@ def test_load(ecephys_api): # assert(len(nm.stimulus_conditions) == 119) -@pytest.mark.skip(reason='NaturalMovies not fully implemented.') +@pytest.mark.skip(reason="NaturalMovies not fully implemented.") def test_stimulus(ecephys_api): session = EcephysSession(api=ecephys_api) nm = NaturalMovies(ecephys_session=session) - assert(isinstance(nm.stim_table, pd.DataFrame)) + assert isinstance(nm.stim_table, pd.DataFrame) # assert(len(nm.stim_table) == 119) # assert(set(nm.stim_table.columns).issuperset({'frame', 'start_time', 'stop_time'})) # assert(np.all(nm.images == np.arange(-1.0, 118))) @@ -41,16 +41,16 @@ def test_stimulus(ecephys_api): # assert(nm.number_nonblank == 118) -@pytest.mark.skip(reason='NaturalMovies not fully implemented.') +@pytest.mark.skip(reason="NaturalMovies not fully implemented.") def test_metrics(ecephys_api): session = EcephysSession(api=ecephys_api) nm = NaturalMovies(ecephys_session=session) - assert(isinstance(nm.metrics, pd.DataFrame)) - assert(len(nm.metrics) == 6) - assert(nm.metrics.index.names == ['unit_id']) - - assert('fano_nm' in nm.metrics.columns) - assert('firing_rate_nm' in nm.metrics.columns) - assert('lifetime_sparseness_nm' in nm.metrics.columns) - assert('run_pval_nm' in nm.metrics.columns) - assert('run_mod_nm' in nm.metrics.columns) + assert isinstance(nm.metrics, pd.DataFrame) + assert len(nm.metrics) == 6 + assert nm.metrics.index.names == ["unit_id"] + + assert "fano_nm" in nm.metrics.columns + assert "firing_rate_nm" in nm.metrics.columns + assert "lifetime_sparseness_nm" in nm.metrics.columns + assert "run_pval_nm" in nm.metrics.columns + assert "run_mod_nm" in nm.metrics.columns diff --git a/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_natural_scenes.py b/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_natural_scenes.py index c5b6e52a63..aaecf283ac 100644 --- a/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_natural_scenes.py +++ b/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_natural_scenes.py @@ -9,19 +9,23 @@ class MockNSSessionApi(MockSessionApi): def get_stimulus_presentations(self): - return pd.DataFrame({ - 'start_time': np.concatenate(([0.0], np.linspace(0.5, 29.50, 119, endpoint=True), [39.75])), - 'stop_time': np.concatenate(([0.5], np.linspace(0.75, 39.75, 119, endpoint=True), [40.25])), - 'stimulus_name': ['spontaneous'] + ['natural_scenes']*119 + ['spontaneous'], - 'stimulus_block': [0] + [1]*119 + [0], - 'duration': [0.5] + [0.25]*119 + [0.5], - 'stimulus_index': [0] + [1]*119 + [0], - 'frame': np.concatenate(([np.nan], np.arange(-1.0, 118.0), [np.nan])) - }, index=pd.Index(name='id', data=np.arange(121))) + return pd.DataFrame( + { + "start_time": np.concatenate(([0.0], np.linspace(0.5, 29.50, 119, endpoint=True), [39.75])), + "stop_time": np.concatenate(([0.5], np.linspace(0.75, 39.75, 119, endpoint=True), [40.25])), + "stimulus_name": ["spontaneous"] + ["natural_scenes"] * 119 + ["spontaneous"], + "stimulus_block": [0] + [1] * 119 + [0], + "duration": [0.5] + [0.25] * 119 + [0.5], + "stimulus_index": [0] + [1] * 119 + [0], + "frame": np.concatenate(([np.nan], np.arange(-1.0, 118.0), [np.nan])), + }, + index=pd.Index(name="id", data=np.arange(121)), + ) def get_invalid_times(self): return pd.DataFrame() + @pytest.fixture def ecephys_api(): return MockNSSessionApi() @@ -30,62 +34,64 @@ def ecephys_api(): def test_load(ecephys_api): session = EcephysSession(api=ecephys_api) ns = NaturalScenes(ecephys_session=session) - assert(ns.name == 'Natural Scenes') - assert(set(ns.unit_ids) == set(range(6))) - assert(len(ns.conditionwise_statistics) == 119*6) - assert(ns.conditionwise_psth.shape == (119, 249, 6)) - assert(not ns.presentationwise_spike_times.empty) - assert(len(ns.presentationwise_statistics) == 119*6) - assert(len(ns.stimulus_conditions) == 119) + assert ns.name == "Natural Scenes" + assert set(ns.unit_ids) == set(range(6)) + assert len(ns.conditionwise_statistics) == 119 * 6 + assert ns.conditionwise_psth.shape == (119, 249, 6) + assert not ns.presentationwise_spike_times.empty + assert len(ns.presentationwise_statistics) == 119 * 6 + assert len(ns.stimulus_conditions) == 119 def test_stimulus(ecephys_api): session = EcephysSession(api=ecephys_api) ns = NaturalScenes(ecephys_session=session) - assert(isinstance(ns.stim_table, pd.DataFrame)) - assert(len(ns.stim_table) == 119) - assert(set(ns.stim_table.columns).issuperset({'frame', 'start_time', 'stop_time'})) + assert isinstance(ns.stim_table, pd.DataFrame) + assert len(ns.stim_table) == 119 + assert set(ns.stim_table.columns).issuperset({"frame", "start_time", "stop_time"}) - assert(np.all(ns.images == np.arange(-1.0, 118))) - assert(ns.number_images == 119) - assert(ns.number_nonblank == 118) + assert np.all(ns.images == np.arange(-1.0, 118)) + assert ns.number_images == 119 + assert ns.number_nonblank == 118 def test_metrics(ecephys_api): session = EcephysSession(api=ecephys_api) ns = NaturalScenes(ecephys_session=session) - assert(isinstance(ns.metrics, pd.DataFrame)) - assert(len(ns.metrics) == 6) - assert(ns.metrics.index.names == ['unit_id']) - - assert('pref_image_ns' in ns.metrics.columns) - assert(np.all(ns.metrics['pref_image_ns'].loc[[0, 1, 3, 4]] == [2, 9, 2, 4])) - - assert('image_selectivity_ns' in ns.metrics.columns) - assert('firing_rate_ns' in ns.metrics.columns) - assert('fano_ns' in ns.metrics.columns) - assert('time_to_peak_ns' in ns.metrics.columns) - assert('lifetime_sparseness_ns' in ns.metrics.columns) - assert('run_pval_ns' in ns.metrics.columns) - assert('run_mod_ns' in ns.metrics.columns) - - -@pytest.mark.parametrize('responses,expected', - [ - (np.array([]), np.nan), # invalid input - (np.array([1.0]), np.nan), # selectivity of one image is undefined - (np.array([0.0]), np.nan), - (np.zeros(118), 0.0), # responds uniformly - (np.ones(118), 0.0), # responds uniformly - (np.array([0.0]*200 + [1.0]), 0.99004975), # reponse to 1 image ~ 1.0 - (np.array([5.5, 0.0, 15.0, 10.0, 2.3, 4.9]), 0.16166666666666674) - ]) + assert isinstance(ns.metrics, pd.DataFrame) + assert len(ns.metrics) == 6 + assert ns.metrics.index.names == ["unit_id"] + + assert "pref_image_ns" in ns.metrics.columns + assert np.all(ns.metrics["pref_image_ns"].loc[[0, 1, 3, 4]] == [2, 9, 2, 4]) + + assert "image_selectivity_ns" in ns.metrics.columns + assert "firing_rate_ns" in ns.metrics.columns + assert "fano_ns" in ns.metrics.columns + assert "time_to_peak_ns" in ns.metrics.columns + assert "lifetime_sparseness_ns" in ns.metrics.columns + assert "run_pval_ns" in ns.metrics.columns + assert "run_mod_ns" in ns.metrics.columns + + +@pytest.mark.parametrize( + "responses,expected", + [ + (np.array([]), np.nan), # invalid input + (np.array([1.0]), np.nan), # selectivity of one image is undefined + (np.array([0.0]), np.nan), + (np.zeros(118), 0.0), # responds uniformly + (np.ones(118), 0.0), # responds uniformly + (np.array([0.0] * 200 + [1.0]), 0.99004975), # reponse to 1 image ~ 1.0 + (np.array([5.5, 0.0, 15.0, 10.0, 2.3, 4.9]), 0.16166666666666674), + ], +) def test_image_selectivity(responses, expected): img_sel = image_selectivity(responses) - assert(np.isclose(img_sel, expected, equal_nan=True)) + assert np.isclose(img_sel, expected, equal_nan=True) -if __name__ == '__main__': +if __name__ == "__main__": test_load() # test_stimulus() # test_metrics() diff --git a/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_receptive_field_mapping.py b/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_receptive_field_mapping.py index b32e2b4a93..03b4b7d534 100644 --- a/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_receptive_field_mapping.py +++ b/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_receptive_field_mapping.py @@ -3,7 +3,7 @@ import pytest from allensdk.brain_observatory.ecephys.ecephys_session import EcephysSession -from allensdk.brain_observatory.ecephys.stimulus_analysis.receptive_field_mapping import ( # noqa +from allensdk.brain_observatory.ecephys.stimulus_analysis.receptive_field_mapping import ( # noqa ReceptiveFieldMapping, fit_2d_gaussian, threshold_rf, @@ -47,18 +47,12 @@ def get_stimulus_presentations(self): [21.25], ) ), - "stimulus_name": ["spontaneous"] - + ["gabors"] * 81 - + ["spontaneous"], + "stimulus_name": ["spontaneous"] + ["gabors"] * 81 + ["spontaneous"], "stimulus_block": [0] + [1] * 81 + [0], "duration": [0.5] + [0.25] * 81 + [0.5], "stimulus_index": [0] + [1] * 81 + [0], - "x_position": np.concatenate( - ([np.nan], features[0, :], [np.nan]) - ), - "y_position": np.concatenate( - ([np.nan], features[1, :], [np.nan]) - ), + "x_position": np.concatenate(([np.nan], features[0, :], [np.nan])), + "y_position": np.concatenate(([np.nan], features[1, :], [np.nan])), }, index=pd.Index(name="id", data=np.arange(83)), ) @@ -86,9 +80,7 @@ def test_stimulus(ecephys_api): rfm = ReceptiveFieldMapping(ecephys_session=session) assert isinstance(rfm.stim_table, pd.DataFrame) assert len(rfm.stim_table) == 81 - assert set(rfm.stim_table.columns).issuperset( - {"x_position", "y_position", "start_time", "stop_time"} - ) + assert set(rfm.stim_table.columns).issuperset({"x_position", "y_position", "start_time", "stop_time"}) assert set(rfm.azimuths) == { 30.0, @@ -182,24 +174,10 @@ def test_receptive_fields(ecephys_api): ) # Some randomly sampled testing to make sure everything works as expected - assert ( - rfm.receptive_fields["spike_counts"][{"unit_id": 0}].values.sum() == 4 - ) - assert ( - rfm.receptive_fields["spike_counts"][{"unit_id": 3}].values.sum() == 0 - ) - assert ( - rfm.receptive_fields["spike_counts"][ - {"unit_id": 2, "x_position": 8, "y_position": 3} - ] - == 3 - ) - assert np.all( - rfm.receptive_fields["spike_counts"][ - {"x_position": 2, "y_position": 5} - ] - == [1, 0, 0, 0, 1, 1] - ) + assert rfm.receptive_fields["spike_counts"][{"unit_id": 0}].values.sum() == 4 + assert rfm.receptive_fields["spike_counts"][{"unit_id": 3}].values.sum() == 0 + assert rfm.receptive_fields["spike_counts"][{"unit_id": 2, "x_position": 8, "y_position": 3}] == 3 + assert np.all(rfm.receptive_fields["spike_counts"][{"x_position": 2, "y_position": 5}] == [1, 0, 0, 0, 1, 1]) # Some special receptive fields for testing @@ -222,9 +200,7 @@ def test_receptive_fields(ecephys_api): # RF as a typical gaussian x, y = np.meshgrid(np.linspace(-1, 1, 9), np.linspace(-1, 1, 9)) -rf_field_gaussian = np.exp( - -((np.sqrt(x * x + y * y) - 0.0) ** 2 / (2.0 * 1.0**2)) -) +rf_field_gaussian = np.exp(-((np.sqrt(x * x + y * y) - 0.0) ** 2 / (2.0 * 1.0**2))) # Only activity at one of the corners of the field rf_field_edge = np.zeros((9, 9)) @@ -255,9 +231,7 @@ def test_receptive_fields(ecephys_api): (rf_field_edge, 0.05, None, 8.0, 8.0, 1.0), ], ) -def test_threshold_rf( - rf, threshold, expected_mask, expected_x, expected_y, expected_area -): +def test_threshold_rf(rf, threshold, expected_mask, expected_x, expected_y, expected_area): mask_rf, x, y, area = threshold_rf(rf, threshold) assert np.isclose(x, expected_x, equal_nan=True) assert np.isclose(y, expected_y, equal_nan=True) @@ -299,7 +273,5 @@ def test_fit_2d_gaussian(matrix, expected): if __name__ == "__main__": test_metrics() - test_fit_2d_gaussian( - rf_field_edge, (np.array([5.0, 8.0, 8.0, 0.0, 0.0]), True) - ) + test_fit_2d_gaussian(rf_field_edge, (np.array([5.0, 8.0, 8.0, 0.0, 0.0]), True)) pass diff --git a/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_static_gratings.py b/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_static_gratings.py index 0873a399d8..5efb1a7efb 100644 --- a/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_static_gratings.py +++ b/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_static_gratings.py @@ -9,21 +9,28 @@ class MockSGSessionApi(MockSessionApi): def get_stimulus_presentations(self): - features = np.array(np.meshgrid([0.02, 0.04, 0.08, 0.16, 0.32], # SF - [0.0, 30.0, 60.0, 90.0, 120.0, 150.0], # ORI - [0.0, 0.25, 0.50, 0.75])).reshape(3, 120) # Phase - - return pd.DataFrame({ - 'start_time': np.concatenate(([0.0], np.linspace(0.5, 30.25, 120, endpoint=True), [31.5])), - 'stop_time': np.concatenate(([0.5], np.linspace(0.75, 30.50, 120, endpoint=True), [32.0])), - 'stimulus_name': ['spontaneous'] + ['static_gratings']*120 + ['spontaneous'], - 'stimulus_block': [0] + [1]*120 + [0], - 'duration': [0.5] + [0.25]*120 + [0.5], - 'stimulus_index': [0] + [1]*120 + [0], - 'spatial_frequency': np.concatenate(([np.nan], features[0, :], [np.nan])), - 'orientation': np.concatenate(([np.nan], features[1, :], [np.nan])), - 'phase': np.concatenate(([np.nan], features[2, :], [np.nan])) - }, index=pd.Index(name='id', data=np.arange(122))) + features = np.array( + np.meshgrid( + [0.02, 0.04, 0.08, 0.16, 0.32], # SF + [0.0, 30.0, 60.0, 90.0, 120.0, 150.0], # ORI + [0.0, 0.25, 0.50, 0.75], + ) + ).reshape(3, 120) # Phase + + return pd.DataFrame( + { + "start_time": np.concatenate(([0.0], np.linspace(0.5, 30.25, 120, endpoint=True), [31.5])), + "stop_time": np.concatenate(([0.5], np.linspace(0.75, 30.50, 120, endpoint=True), [32.0])), + "stimulus_name": ["spontaneous"] + ["static_gratings"] * 120 + ["spontaneous"], + "stimulus_block": [0] + [1] * 120 + [0], + "duration": [0.5] + [0.25] * 120 + [0.5], + "stimulus_index": [0] + [1] * 120 + [0], + "spatial_frequency": np.concatenate(([np.nan], features[0, :], [np.nan])), + "orientation": np.concatenate(([np.nan], features[1, :], [np.nan])), + "phase": np.concatenate(([np.nan], features[2, :], [np.nan])), + }, + index=pd.Index(name="id", data=np.arange(122)), + ) @pytest.fixture @@ -34,102 +41,156 @@ def ecephys_api(): def test_load(ecephys_api): session = EcephysSession(api=ecephys_api) sg = StaticGratings(ecephys_session=session) - assert(sg.name == 'Static Gratings') - assert(set(sg.unit_ids) == set(range(6))) - assert(len(sg.conditionwise_statistics) == 120*6) - assert(sg.conditionwise_psth.shape == (120, 249, 6)) - assert(not sg.presentationwise_spike_times.empty) - assert(len(sg.presentationwise_statistics) == 120*6) - assert(len(sg.stimulus_conditions) == 120) + assert sg.name == "Static Gratings" + assert set(sg.unit_ids) == set(range(6)) + assert len(sg.conditionwise_statistics) == 120 * 6 + assert sg.conditionwise_psth.shape == (120, 249, 6) + assert not sg.presentationwise_spike_times.empty + assert len(sg.presentationwise_statistics) == 120 * 6 + assert len(sg.stimulus_conditions) == 120 def test_stimulus(ecephys_api): session = EcephysSession(api=ecephys_api) sg = StaticGratings(ecephys_session=session) - assert(isinstance(sg.stim_table, pd.DataFrame)) - assert(len(sg.stim_table) == 120) - assert(set(sg.stim_table.columns).issuperset({'spatial_frequency', 'orientation', 'phase', 'start_time', 'stop_time'})) + assert isinstance(sg.stim_table, pd.DataFrame) + assert len(sg.stim_table) == 120 + assert set(sg.stim_table.columns).issuperset( + {"spatial_frequency", "orientation", "phase", "start_time", "stop_time"} + ) - assert(set(sg.sfvals) == {0.02, 0.04, 0.08, 0.16, 0.32}) - assert(sg.number_sf == 5) + assert set(sg.sfvals) == {0.02, 0.04, 0.08, 0.16, 0.32} + assert sg.number_sf == 5 - assert(set(sg.orivals) == {0.0, 30.0, 60.0, 90.0, 120.0, 150.0}) - assert(sg.number_ori == 6) + assert set(sg.orivals) == {0.0, 30.0, 60.0, 90.0, 120.0, 150.0} + assert sg.number_ori == 6 - assert(set(sg.phasevals) == {0.0, 0.25, 0.50, 0.75}) - assert(sg.number_phase == 4) + assert set(sg.phasevals) == {0.0, 0.25, 0.50, 0.75} + assert sg.number_phase == 4 def test_bad_stimulus_key(ecephys_api): with pytest.raises(Exception): session = EcephysSession(api=ecephys_api) - sg = StaticGratings(ecephys_session=session, stimulus_key='gratings static') + sg = StaticGratings(ecephys_session=session, stimulus_key="gratings static") sg.stim_table def test_bad_col_key(ecephys_api): with pytest.raises(KeyError): session = EcephysSession(api=ecephys_api) - sg = StaticGratings(ecephys_session=session, col_sf='spatial_frequency', col_phase='esahp') + sg = StaticGratings(ecephys_session=session, col_sf="spatial_frequency", col_phase="esahp") sg.phasevals def test_metrics(ecephys_api): session = EcephysSession(api=ecephys_api) sg = StaticGratings(ecephys_session=session) - assert(isinstance(sg.metrics, pd.DataFrame)) - assert(len(sg.metrics) == 6) - assert(sg.metrics.index.names == ['unit_id']) - - assert('pref_sf_sg' in sg.metrics.columns) - assert(np.all(sg.metrics['pref_sf_sg'].loc[[0, 2, 4]] == [0.02, 0.02, 0.04])) - - assert('pref_ori_sg' in sg.metrics.columns) - assert(np.all(sg.metrics['pref_ori_sg'].loc[[0, 2, 4]] == [0.0, 0.0, 0.0])) - - assert('pref_phase_sg' in sg.metrics.columns) - assert(np.all(sg.metrics['pref_phase_sg'].loc[[0, 1, 2, 3]] == [0.25, 0.75, 0.5, 0.0])) - - assert('g_osi_sg' in sg.metrics.columns) - assert('time_to_peak_sg' in sg.metrics.columns) - assert('firing_rate_sg' in sg.metrics.columns) - assert('fano_sg' in sg.metrics.columns) - assert('lifetime_sparseness_sg' in sg.metrics.columns) - assert('run_pval_sg' in sg.metrics.columns) - assert('run_mod_sg' in sg.metrics.columns) - - -@pytest.mark.parametrize('sf_tuning_responses,mean_sweeps_trials,expected', - [ - (np.array([18.08333, 19.8333, 28.333, 14.80, 9.6170]), - np.array([12.0, 4.0, 8.0, 32.0, 4.0, 0.0, 4.0, 8.0, 24.0, 40.0, 32.0, 8.0, 20.0, 28.0, - 24.0, 28.0, 0.0, 4.0, 4.0, 24.0, 16.0, 8.0, 16.0, 4.0, 0.0, 4.0, 24.0, 4.0, - 12.0, 20.0, 0.0, 12.0, 0.0, 16.0]), 0.4402349784724991) - ]) + assert isinstance(sg.metrics, pd.DataFrame) + assert len(sg.metrics) == 6 + assert sg.metrics.index.names == ["unit_id"] + + assert "pref_sf_sg" in sg.metrics.columns + assert np.all(sg.metrics["pref_sf_sg"].loc[[0, 2, 4]] == [0.02, 0.02, 0.04]) + + assert "pref_ori_sg" in sg.metrics.columns + assert np.all(sg.metrics["pref_ori_sg"].loc[[0, 2, 4]] == [0.0, 0.0, 0.0]) + + assert "pref_phase_sg" in sg.metrics.columns + assert np.all(sg.metrics["pref_phase_sg"].loc[[0, 1, 2, 3]] == [0.25, 0.75, 0.5, 0.0]) + + assert "g_osi_sg" in sg.metrics.columns + assert "time_to_peak_sg" in sg.metrics.columns + assert "firing_rate_sg" in sg.metrics.columns + assert "fano_sg" in sg.metrics.columns + assert "lifetime_sparseness_sg" in sg.metrics.columns + assert "run_pval_sg" in sg.metrics.columns + assert "run_mod_sg" in sg.metrics.columns + + +@pytest.mark.parametrize( + "sf_tuning_responses,mean_sweeps_trials,expected", + [ + ( + np.array([18.08333, 19.8333, 28.333, 14.80, 9.6170]), + np.array( + [ + 12.0, + 4.0, + 8.0, + 32.0, + 4.0, + 0.0, + 4.0, + 8.0, + 24.0, + 40.0, + 32.0, + 8.0, + 20.0, + 28.0, + 24.0, + 28.0, + 0.0, + 4.0, + 4.0, + 24.0, + 16.0, + 8.0, + 16.0, + 4.0, + 0.0, + 4.0, + 24.0, + 4.0, + 12.0, + 20.0, + 0.0, + 12.0, + 0.0, + 16.0, + ] + ), + 0.4402349784724991, + ) + ], +) def test_get_sfdi(sf_tuning_responses, mean_sweeps_trials, expected): - assert(get_sfdi(sf_tuning_responses, mean_sweeps_trials, len(sf_tuning_responses)) == expected) - - -@pytest.mark.parametrize('sf_tuning_response,sf_vals,pref_sf_index,expected', - [ - (np.array([2.69565217, 3.91836735, 2.36734694, 1.52, 2.21276596]), - [0.02, 0.04, 0.08, 0.16, 0.32], 1, (0.22704947240176027, 0.0234087755414, np.nan, np.nan) - ), - (np.array([1.14285714, 0.73469388, 7.44, 13.6, 11.6]), [0.02, 0.04, 0.08, 0.16, 0.32], 3, - (3.290141840632274, 0.1956416782774323, 0.08, np.nan)), - (np.array([2.24, 1.83333333, 1.68, 1.87755102, 1.87755102]), - [0.02, 0.04, 0.08, 0.16, 0.32], 0, (0.0, 0.019999999552965164, np.nan, 0.32)) - ]) + assert get_sfdi(sf_tuning_responses, mean_sweeps_trials, len(sf_tuning_responses)) == expected + + +@pytest.mark.parametrize( + "sf_tuning_response,sf_vals,pref_sf_index,expected", + [ + ( + np.array([2.69565217, 3.91836735, 2.36734694, 1.52, 2.21276596]), + [0.02, 0.04, 0.08, 0.16, 0.32], + 1, + (0.22704947240176027, 0.0234087755414, np.nan, np.nan), + ), + ( + np.array([1.14285714, 0.73469388, 7.44, 13.6, 11.6]), + [0.02, 0.04, 0.08, 0.16, 0.32], + 3, + (3.290141840632274, 0.1956416782774323, 0.08, np.nan), + ), + ( + np.array([2.24, 1.83333333, 1.68, 1.87755102, 1.87755102]), + [0.02, 0.04, 0.08, 0.16, 0.32], + 0, + (0.0, 0.019999999552965164, np.nan, 0.32), + ), + ], +) def test_fit_sf_tuning(sf_tuning_response, sf_vals, pref_sf_index, expected): result = fit_sf_tuning(sf_tuning_response, sf_vals, pref_sf_index) assert np.allclose(result, expected, equal_nan=True, rtol=1e-3, atol=1e-4) -if __name__ == '__main__': +if __name__ == "__main__": # test_stimulus() # test_load() # test_bad_stimulus_key() # test_bad_col_key() test_metrics() pass - diff --git a/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_stimulus_analysis.py b/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_stimulus_analysis.py index 2cf28cd791..dfb82fa608 100644 --- a/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_stimulus_analysis.py +++ b/allensdk/test/brain_observatory/ecephys/stimulus_analysis/test_stimulus_analysis.py @@ -77,10 +77,7 @@ def get_stimulus_presentations(self): { "start_time": np.linspace(0.0, 4.5, 10, endpoint=True), "stop_time": np.linspace(0.5, 5.0, 10, endpoint=True), - "stimulus_name": ["spontaneous"] - + ["s0"] * 6 - + ["spontaneous"] - + ["s1"] * 2, + "stimulus_name": ["spontaneous"] + ["s0"] * 6 + ["spontaneous"] + ["s1"] * 2, "stimulus_block": [0] + [1] * 6 + [0] + [2] * 2, "duration": 0.5, "stimulus_index": [0] + [1] * 6 + [0] + [2] * 2, @@ -131,48 +128,36 @@ def test_unit_ids_filter_by_id(ecephys_api): assert set(stim_analysis.unit_ids) == {1, 2, 3} assert stim_analysis.unit_count == 3 - stim_analysis = StimulusAnalysis( - ecephys_session=session, filter={"unit_id": [3, 0]} - ) + stim_analysis = StimulusAnalysis(ecephys_session=session, filter={"unit_id": [3, 0]}) assert set(stim_analysis.unit_ids) == {0, 3} assert stim_analysis.unit_count == 2 with pytest.raises(KeyError): # If unit ids don't exists should raise an error - stim_analysis = StimulusAnalysis( - ecephys_session=session, filter=[100, 200] - ) + stim_analysis = StimulusAnalysis(ecephys_session=session, filter=[100, 200]) stim_analysis.unit_ids def test_unit_ids_filtered(ecephys_api): session = EcephysSession(api=ecephys_api) - stim_analysis = StimulusAnalysis( - ecephys_session=session, filter={"location": "VISp"} - ) + stim_analysis = StimulusAnalysis(ecephys_session=session, filter={"location": "VISp"}) assert set(stim_analysis.unit_ids) == {0, 2, 3, 5} assert stim_analysis.unit_count == 4 - stim_analysis = StimulusAnalysis( - ecephys_session=session, filter={"location": "VISp", "quality": "good"} - ) + stim_analysis = StimulusAnalysis(ecephys_session=session, filter={"location": "VISp", "quality": "good"}) assert set(stim_analysis.unit_ids) == {0, 3, 5} assert stim_analysis.unit_count == 3 with pytest.raises(Exception): # No units found should raise exception - stim_analysis = StimulusAnalysis( - ecephys_session=session, filter={"location": "pSIV"} - ) + stim_analysis = StimulusAnalysis(ecephys_session=session, filter={"location": "pSIV"}) stim_analysis.unit_ids stim_analysis.unit_count def test_stim_table(ecephys_api): session = EcephysSession(api=ecephys_api) - stim_analysis = StimulusAnalysis( - ecephys_session=session, stimulus_key="s0" - ) + stim_analysis = StimulusAnalysis(ecephys_session=session, stimulus_key="s0") assert isinstance(stim_analysis.stim_table, pd.DataFrame) assert len(stim_analysis.stim_table) == 6 assert stim_analysis.total_presentations == 6 @@ -185,9 +170,7 @@ def test_stim_table(ecephys_api): assert "duration" in stim_analysis.stim_table with pytest.raises(Exception): - stim_analysis = StimulusAnalysis( - ecephys_session=session, stimulus_key="0s" - ) + stim_analysis = StimulusAnalysis(ecephys_session=session, stimulus_key="0s") stim_analysis.stim_table @@ -195,16 +178,12 @@ def test_stim_table_spontaneous(ecephys_api): # By default table should be empty because non of the stimulus are above # the duration threshold session = EcephysSession(api=ecephys_api) - stim_analysis = StimulusAnalysis( - ecephys_session=session, spontaneous_threshold=0.49 - ) + stim_analysis = StimulusAnalysis(ecephys_session=session, spontaneous_threshold=0.49) assert isinstance(stim_analysis.stim_table_spontaneous, pd.DataFrame) assert len(stim_analysis.stim_table_spontaneous) == 2 # Check that threshold is working - stim_analysis = StimulusAnalysis( - ecephys_session=session, spontaneous_threshold=0.51 - ) + stim_analysis = StimulusAnalysis(ecephys_session=session, spontaneous_threshold=0.51) assert len(stim_analysis.stim_table_spontaneous) == 0 @@ -218,21 +197,11 @@ def test_conditionwise_psth(ecephys_api): ) assert isinstance(stim_analysis.conditionwise_psth, xr.DataArray) # assert(stim_analysis.conditionwise_psth.shape == (2, 4, 6)) - assert ( - stim_analysis.conditionwise_psth.coords[ - "time_relative_to_stimulus_onset" - ].size - == 4 - ) # 0.5/0.1 - 1 + assert stim_analysis.conditionwise_psth.coords["time_relative_to_stimulus_onset"].size == 4 # 0.5/0.1 - 1 assert stim_analysis.conditionwise_psth.coords["unit_id"].size == 6 - assert ( - stim_analysis.conditionwise_psth.coords["stimulus_condition_id"].size - == 2 - ) + assert stim_analysis.conditionwise_psth.coords["stimulus_condition_id"].size == 2 assert np.allclose( - stim_analysis.conditionwise_psth[ - {"unit_id": 0, "stimulus_condition_id": 1} - ].values, + stim_analysis.conditionwise_psth[{"unit_id": 0, "stimulus_condition_id": 1}].values, np.array([1.0 / 3.0, 0.0, 0.0, 0.0]), ) @@ -243,27 +212,15 @@ def test_conditionwise_psth(ecephys_api): trial_duration=0.5, psth_resolution=0.1, ) - assert ( - stim_analysis.conditionwise_psth.coords[ - "time_relative_to_stimulus_onset" - ].size - == 4 - ) + assert stim_analysis.conditionwise_psth.coords["time_relative_to_stimulus_onset"].size == 4 assert stim_analysis.conditionwise_psth.coords["unit_id"].size == 6 - assert ( - stim_analysis.conditionwise_psth.coords["stimulus_condition_id"].size - == 2 - ) + assert stim_analysis.conditionwise_psth.coords["stimulus_condition_id"].size == 2 def test_conditionwise_statistics(ecephys_api): session = EcephysSession(api=ecephys_api) - stim_analysis = StimulusAnalysis( - ecephys_session=session, stimulus_key="s0" - ) - assert ( - len(stim_analysis.conditionwise_statistics) == 2 * 6 - ) # units x condition_ids + stim_analysis = StimulusAnalysis(ecephys_session=session, stimulus_key="s0") + assert len(stim_analysis.conditionwise_statistics) == 2 * 6 # units x condition_ids assert set(stim_analysis.conditionwise_statistics.index.names) == { "unit_id", "stimulus_condition_id", @@ -299,36 +256,23 @@ def test_conditionwise_statistics(ecephys_api): def test_presentationwise_spike_times(ecephys_api): session = EcephysSession(api=ecephys_api) - stim_analysis = StimulusAnalysis( - ecephys_session=session, stimulus_key="s0" - ) + stim_analysis = StimulusAnalysis(ecephys_session=session, stimulus_key="s0") assert len(stim_analysis.presentationwise_spike_times) == 12 - assert list(stim_analysis.presentationwise_spike_times.index.names) == [ - "spike_time" - ] + assert list(stim_analysis.presentationwise_spike_times.index.names) == ["spike_time"] assert set(stim_analysis.presentationwise_spike_times.columns) == { "stimulus_presentation_id", "unit_id", "time_since_stimulus_presentation_onset", } assert stim_analysis.presentationwise_spike_times.loc[1.01]["unit_id"] == 2 - assert ( - stim_analysis.presentationwise_spike_times.loc[1.01][ - "stimulus_presentation_id" - ] - == 2 - ) + assert stim_analysis.presentationwise_spike_times.loc[1.01]["stimulus_presentation_id"] == 2 assert len(stim_analysis.presentationwise_spike_times.loc[3.0]) == 2 def test_presentationwise_statistics(ecephys_api): session = EcephysSession(api=ecephys_api) - stim_analysis = StimulusAnalysis( - ecephys_session=session, stimulus_key="s0", trial_duration=0.5 - ) - assert ( - len(stim_analysis.presentationwise_statistics) == 6 * 6 - ) # units x presentation_ids + stim_analysis = StimulusAnalysis(ecephys_session=session, stimulus_key="s0", trial_duration=0.5) + assert len(stim_analysis.presentationwise_statistics) == 6 * 6 # units x presentation_ids assert set(stim_analysis.presentationwise_statistics.index.names) == { "stimulus_presentation_id", "unit_id", @@ -338,16 +282,8 @@ def test_presentationwise_statistics(ecephys_api): "stimulus_condition_id", "running_speed", } - assert ( - stim_analysis.presentationwise_statistics.loc[1, 0]["spike_counts"] - == 1.0 - ) - assert ( - stim_analysis.presentationwise_statistics.loc[1, 0][ - "stimulus_condition_id" - ] - == 1.0 - ) + assert stim_analysis.presentationwise_statistics.loc[1, 0]["spike_counts"] == 1.0 + assert stim_analysis.presentationwise_statistics.loc[1, 0]["stimulus_condition_id"] == 1.0 assert np.isclose( stim_analysis.presentationwise_statistics.loc[1, 0]["running_speed"], 0.684848, @@ -356,13 +292,9 @@ def test_presentationwise_statistics(ecephys_api): def test_stimulus_conditions(ecephys_api): session = EcephysSession(api=ecephys_api) - stim_analysis = StimulusAnalysis( - ecephys_session=session, stimulus_key="s0", trial_duration=0.5 - ) + stim_analysis = StimulusAnalysis(ecephys_session=session, stimulus_key="s0", trial_duration=0.5) assert len(stim_analysis.stimulus_conditions) == 2 - assert np.all( - stim_analysis.stimulus_conditions["stimulus_name"].unique() == ["s0"] - ) + assert np.all(stim_analysis.stimulus_conditions["stimulus_name"].unique() == ["s0"]) assert set(stim_analysis.stimulus_conditions["conditions"].unique()) == { 0, 1, @@ -371,26 +303,16 @@ def test_stimulus_conditions(ecephys_api): def test_running_speed(ecephys_api): session = EcephysSession(api=ecephys_api) - stim_analysis = StimulusAnalysis( - ecephys_session=session, stimulus_key="s0" - ) + stim_analysis = StimulusAnalysis(ecephys_session=session, stimulus_key="s0") assert set(stim_analysis.running_speed.index.values) == set(range(1, 7)) - assert np.isclose( - stim_analysis.running_speed.loc[1]["running_speed"], 0.684848 - ) - assert np.isclose( - stim_analysis.running_speed.loc[3]["running_speed"], 1.806061 - ) - assert np.isclose( - stim_analysis.running_speed.loc[6]["running_speed"], 3.487879 - ) + assert np.isclose(stim_analysis.running_speed.loc[1]["running_speed"], 0.684848) + assert np.isclose(stim_analysis.running_speed.loc[3]["running_speed"], 1.806061) + assert np.isclose(stim_analysis.running_speed.loc[6]["running_speed"], 3.487879) def test_spikes(ecephys_api): session = EcephysSession(api=ecephys_api) - stim_analysis = StimulusAnalysis( - ecephys_session=session, stimulus_key="s0" - ) + stim_analysis = StimulusAnalysis(ecephys_session=session, stimulus_key="s0") assert isinstance(stim_analysis.spikes, dict) assert stim_analysis.spikes.keys() == set(range(6)) assert np.allclose(stim_analysis.spikes[0], [1, 2, 3, 4]) @@ -399,17 +321,13 @@ def test_spikes(ecephys_api): # Check that spikes dict is filtering units session = EcephysSession(api=ecephys_api) - stim_analysis = StimulusAnalysis( - ecephys_session=session, stimulus_key="s0", filter=[0, 2] - ) + stim_analysis = StimulusAnalysis(ecephys_session=session, stimulus_key="s0", filter=[0, 2]) assert stim_analysis.spikes.keys() == {0, 2} def test_get_preferred_condition(ecephys_api): session = EcephysSession(api=ecephys_api) - stim_analysis = StimulusAnalysis( - ecephys_session=session, stimulus_key="s0" - ) + stim_analysis = StimulusAnalysis(ecephys_session=session, stimulus_key="s0") assert stim_analysis._get_preferred_condition(3) == 1 with pytest.raises(KeyError): @@ -418,31 +336,16 @@ def test_get_preferred_condition(ecephys_api): def test_check_multiple_preferred_conditions(ecephys_api): session = EcephysSession(api=ecephys_api) - stim_analysis = StimulusAnalysis( - ecephys_session=session, stimulus_key="s0" - ) + stim_analysis = StimulusAnalysis(ecephys_session=session, stimulus_key="s0") - assert ( - stim_analysis._check_multiple_pref_conditions(0, "conditions", [0, 1]) - is False - ) - assert ( - stim_analysis._check_multiple_pref_conditions(3, "conditions", [0, 1]) - is True - ) + assert stim_analysis._check_multiple_pref_conditions(0, "conditions", [0, 1]) is False + assert stim_analysis._check_multiple_pref_conditions(3, "conditions", [0, 1]) is True def test_get_time_to_peak(ecephys_api): session = EcephysSession(api=ecephys_api) - stim_analysis = StimulusAnalysis( - ecephys_session=session, stimulus_key="s0", trial_duration=0.5 - ) - assert ( - stim_analysis._get_time_to_peak( - 1, stim_analysis._get_preferred_condition(1) - ) - == 0.0005 - ) + stim_analysis = StimulusAnalysis(ecephys_session=session, stimulus_key="s0", trial_duration=0.5) + assert stim_analysis._get_time_to_peak(1, stim_analysis._get_preferred_condition(1)) == 0.0005 @pytest.mark.parametrize( @@ -504,9 +407,7 @@ def test_get_time_to_peak(ecephys_api): ), ], ) -def test_running_modulation( - spike_counts, running_speeds, speed_threshold, expected -): +def test_running_modulation(spike_counts, running_speeds, speed_threshold, expected): rm = running_modulation(spike_counts, running_speeds, speed_threshold) assert np.allclose(rm, expected, equal_nan=True) @@ -619,9 +520,7 @@ def test_overall_firing_rate(start_times, stop_times, spike_times, expected): ], ) def test_get_fr(spikes, sampling_freq, sweep_length, expected): - frs = get_fr( - spikes, num_timestep_second=sampling_freq, sweep_length=sweep_length - ) + frs = get_fr(spikes, num_timestep_second=sampling_freq, sweep_length=sweep_length) assert len(frs) == int(sampling_freq * sweep_length) assert np.allclose(frs, expected) diff --git a/allensdk/test/brain_observatory/ecephys/stimulus_table/test_ephys_pre_spikes.py b/allensdk/test/brain_observatory/ecephys/stimulus_table/test_ephys_pre_spikes.py index f0d2dd4376..c3fe5652e2 100644 --- a/allensdk/test/brain_observatory/ecephys/stimulus_table/test_ephys_pre_spikes.py +++ b/allensdk/test/brain_observatory/ecephys/stimulus_table/test_ephys_pre_spikes.py @@ -28,7 +28,6 @@ def stimulus_psuedofixture_1(): def test_assign_sweep_values(): - stim_table = pd.DataFrame( [ { @@ -103,7 +102,6 @@ def test_assign_sweep_values(): ], ) def test_split_column(table, column, new_columns, drop, expected): - obtained = ephys_pre_spikes.split_column(table, column, new_columns, drop) pd.testing.assert_frame_equal(expected, obtained, check_like=True, check_column_type=False, check_dtype=False) @@ -128,7 +126,6 @@ def test_split_column(table, column, new_columns, drop, expected): ], ) def test_apply_display_sequence(sweeps, disp_seq, expected): - table = pd.DataFrame(sweeps) disp_seq = np.array(disp_seq) obt_table = ephys_pre_spikes.apply_display_sequence(table, disp_seq) @@ -153,11 +150,12 @@ def test_apply_display_sequence(sweeps, disp_seq, expected): ], ) def test_make_spontaneous_activity_tables(stimulus_tables, expected): - obtained = ephys_pre_spikes.make_spontaneous_activity_tables(stimulus_tables) if len(obtained) == 1: - pd.testing.assert_frame_equal(obtained[0], expected[0], check_like=True, check_column_type=False, check_dtype=False) + pd.testing.assert_frame_equal( + obtained[0], expected[0], check_like=True, check_column_type=False, check_dtype=False + ) else: assert len(obtained) == len(expected) @@ -182,22 +180,15 @@ def test_make_spontaneous_activity_tables(stimulus_tables, expected): ], ) def test_create_stim_table(stimuli, stim_tabler, spon_tabler, sort_key, expected): - - obtained = ephys_pre_spikes.create_stim_table( - stimuli, stim_tabler, spon_tabler, sort_key - ) - pd.testing.assert_frame_equal( - obtained, expected, check_like=True, check_dtype=False, check_column_type=False - ) + obtained = ephys_pre_spikes.create_stim_table(stimuli, stim_tabler, spon_tabler, sort_key) + pd.testing.assert_frame_equal(obtained, expected, check_like=True, check_dtype=False, check_column_type=False) @pytest.mark.parametrize( "stim_table,frame_times,fps,eft,map_cols,expected", [ [ - pd.DataFrame( - {"Start": [1, 2, 3, 4], "End": [2, 3, 4, 5], "data": [-1, -2, -3, -4]} - ), + pd.DataFrame({"Start": [1, 2, 3, 4], "End": [2, 3, 4, 5], "data": [-1, -2, -3, -4]}), np.array([100, 50, 25, 12.5, 6.25]), 10, True, @@ -213,10 +204,7 @@ def test_create_stim_table(stimuli, stim_tabler, spon_tabler, sort_key, expected ], ) def test_apply_frame_times(stim_table, frame_times, fps, eft, map_cols, expected): - - obtained = ephys_pre_spikes.apply_frame_times( - stim_table, frame_times, fps, eft, map_cols - ) + obtained = ephys_pre_spikes.apply_frame_times(stim_table, frame_times, fps, eft, map_cols) pd.testing.assert_frame_equal(obtained, expected, check_like=True, check_column_type=False, check_dtype=False) @@ -262,9 +250,8 @@ def test_apply_frame_times(stim_table, frame_times, fps, eft, map_cols, expected ], ) def test_build_stimuluswise_table(stimulus, stf, start_key, end_key, expected): - - obtained = ephys_pre_spikes.build_stimuluswise_table( - stimulus, stf, start_key, end_key - ) + obtained = ephys_pre_spikes.build_stimuluswise_table(stimulus, stf, start_key, end_key) for obtained_table, expected_table in zip(obtained, expected): - pd.testing.assert_frame_equal(obtained_table, expected_table, check_like=True, check_column_type=False, check_dtype=False) + pd.testing.assert_frame_equal( + obtained_table, expected_table, check_like=True, check_column_type=False, check_dtype=False + ) diff --git a/allensdk/test/brain_observatory/ecephys/stimulus_table/test_naming_utilities.py b/allensdk/test/brain_observatory/ecephys/stimulus_table/test_naming_utilities.py index 1e750d41d5..c289b13efe 100644 --- a/allensdk/test/brain_observatory/ecephys/stimulus_table/test_naming_utilities.py +++ b/allensdk/test/brain_observatory/ecephys/stimulus_table/test_naming_utilities.py @@ -77,9 +77,7 @@ def test_add_number_to_shuffled_movie(table, expected): "table,expected", [ [ - pd.DataFrame( - {"stimulus_name": ["natural_movie_4", "natural_movie_5_more_repeats"]} - ), + pd.DataFrame({"stimulus_name": ["natural_movie_4", "natural_movie_5_more_repeats"]}), pd.DataFrame( { "stimulus_name": [ @@ -105,14 +103,10 @@ def test_standardize_movie_numbers(table, expected): "Natural Images": "natural_scenes", "contrast_response": "drifting_gratings_contrast", }, - pd.DataFrame( - {"stimulus_name": ["natural_scenes", "drifting_gratings_contrast"]} - ), + pd.DataFrame({"stimulus_name": ["natural_scenes", "drifting_gratings_contrast"]}), ], [ - pd.DataFrame( - {"stimulus_name": ["Natural Images", "contrast_response", np.nan]} - ), + pd.DataFrame({"stimulus_name": ["Natural Images", "contrast_response", np.nan]}), { "Natural Images": "natural_scenes", "contrast_response": "drifting_gratings_contrast", @@ -186,6 +180,4 @@ def test_drop_empty_columns(table, expected): ) def test_collapse_colimns(table, expected): obtained = nu.collapse_columns(table) - pd.testing.assert_frame_equal( - expected, obtained, check_like=True, check_dtype=False - ) + pd.testing.assert_frame_equal(expected, obtained, check_like=True, check_dtype=False) diff --git a/allensdk/test/brain_observatory/ecephys/stimulus_table/test_stimulus_parameter_extraction.py b/allensdk/test/brain_observatory/ecephys/stimulus_table/test_stimulus_parameter_extraction.py index 36e3049637..b02bf1d437 100644 --- a/allensdk/test/brain_observatory/ecephys/stimulus_table/test_stimulus_parameter_extraction.py +++ b/allensdk/test/brain_observatory/ecephys/stimulus_table/test_stimulus_parameter_extraction.py @@ -21,7 +21,6 @@ def test_extract_const_params_from_stim_repr_duplicates(dup_stim_repr): def test_extract_const_params_from_stim_repr(stim_repr): - expected = { "autoDraw": False, "autoLog": True, @@ -40,7 +39,6 @@ def test_extract_const_params_from_stim_repr(stim_repr): def test_extract_stim_class_from_repr(stim_repr): - expected = "GratingStim" obtained = spe.extract_stim_class_from_repr(stim_repr) diff --git a/allensdk/test/brain_observatory/ecephys/stimulus_table/test_stimulus_table_module.py b/allensdk/test/brain_observatory/ecephys/stimulus_table/test_stimulus_table_module.py index 875f9875b0..0dcd1577ce 100644 --- a/allensdk/test/brain_observatory/ecephys/stimulus_table/test_stimulus_table_module.py +++ b/allensdk/test/brain_observatory/ecephys/stimulus_table/test_stimulus_table_module.py @@ -53,9 +53,7 @@ def stim_file(*a, **k): }, { "stim_path": "C:\\ecephys_stimulus_scripts\\static_gratings.stim", - "display_sequence": np.array( - [[45, 46], [100, 110]], dtype=np.int32 - ), + "display_sequence": np.array([[45, 46], [100, 110]], dtype=np.int32), "dimnames": ["Ori", "Phase"], "sweep_frames": [(0, 8), (9, 17), (18, 26), (27, 35)], "sweep_order": [0, 2, 3, 1], @@ -283,7 +281,6 @@ def expected_table(): new=sync_file(), ) def test_build_stimulus_table(tmpdir_factory, expected_table): - tmpdir = str(tmpdir_factory.mktemp("ecephys_stimulus_table_integration")) table_path = os.path.join(tmpdir, "stimulus_table.csv") frame_times_path = os.path.join(tmpdir, "frame_times.npy") @@ -306,7 +303,7 @@ def test_build_stimulus_table(tmpdir_factory, expected_table): column_name_map={}, output_stimulus_table_path=table_path, output_frame_times_path=frame_times_path, - fail_on_negative_duration=True + fail_on_negative_duration=True, ) obtained_table = pd.read_csv(table_path) diff --git a/allensdk/test/brain_observatory/ecephys/test_behavior_ecephys_metadata.py b/allensdk/test/brain_observatory/ecephys/test_behavior_ecephys_metadata.py index 3cd9901468..035cd01587 100644 --- a/allensdk/test/brain_observatory/ecephys/test_behavior_ecephys_metadata.py +++ b/allensdk/test/brain_observatory/ecephys/test_behavior_ecephys_metadata.py @@ -4,61 +4,50 @@ import pytz from pynwb import NWBFile -from allensdk.brain_observatory.ecephys._behavior_ecephys_metadata import \ - BehaviorEcephysMetadata +from allensdk.brain_observatory.ecephys._behavior_ecephys_metadata import BehaviorEcephysMetadata -@pytest.fixture(scope='module') -def behavior_ecephys_metadata_fixture( - behavior_ecephys_session_config_fixture): +@pytest.fixture(scope="module") +def behavior_ecephys_metadata_fixture(behavior_ecephys_session_config_fixture): """ Return a BehaviorEcephysMetadata object """ - obj = BehaviorEcephysMetadata.from_json( - dict_repr=behavior_ecephys_session_config_fixture) + obj = BehaviorEcephysMetadata.from_json(dict_repr=behavior_ecephys_session_config_fixture) return obj -@pytest.fixture(scope='module') -def ecephys_session_id_fixture( - behavior_ecephys_session_config_fixture): +@pytest.fixture(scope="module") +def ecephys_session_id_fixture(behavior_ecephys_session_config_fixture): """ Return an ecephys_session_id object """ - return behavior_ecephys_session_config_fixture['ecephys_session_id'] + return behavior_ecephys_session_config_fixture["ecephys_session_id"] -def create_nwb_file( - ecephys_session_id): +def create_nwb_file(ecephys_session_id): """ Return an NWB file with a specified ID """ nwbfile = NWBFile( - session_description='foo', + session_description="foo", identifier=str(ecephys_session_id), - session_id='foo', - session_start_time=datetime.datetime(2021, 6, 24, 13, 59, 17, 563000, - tzinfo=pytz.UTC), - institution="Allen Institute" + session_id="foo", + session_start_time=datetime.datetime(2021, 6, 24, 13, 59, 17, 563000, tzinfo=pytz.UTC), + institution="Allen Institute", ) return nwbfile @pytest.mark.requires_bamboo -@pytest.mark.parametrize('roundtrip', [True, False]) +@pytest.mark.parametrize("roundtrip", [True, False]) def test_read_write_nwb( - roundtrip, - data_object_roundtrip_fixture, - behavior_ecephys_metadata_fixture, - ecephys_session_id_fixture): - + roundtrip, data_object_roundtrip_fixture, behavior_ecephys_metadata_fixture, ecephys_session_id_fixture +): nwbfile = create_nwb_file(ecephys_session_id_fixture) behavior_ecephys_metadata_fixture.to_nwb(nwbfile=nwbfile) if roundtrip: - obt = data_object_roundtrip_fixture( - nwbfile=nwbfile, - data_object_cls=BehaviorEcephysMetadata) + obt = data_object_roundtrip_fixture(nwbfile=nwbfile, data_object_cls=BehaviorEcephysMetadata) else: obt = BehaviorEcephysMetadata.from_nwb(nwbfile=nwbfile) diff --git a/allensdk/test/brain_observatory/ecephys/test_behavior_ecephys_session.py b/allensdk/test/brain_observatory/ecephys/test_behavior_ecephys_session.py index b85251b1d1..0b78fbe29f 100644 --- a/allensdk/test/brain_observatory/ecephys/test_behavior_ecephys_session.py +++ b/allensdk/test/brain_observatory/ecephys/test_behavior_ecephys_session.py @@ -5,55 +5,45 @@ import numpy as np import copy -from allensdk.brain_observatory.ecephys.behavior_ecephys_session import \ - BehaviorEcephysSession +from allensdk.brain_observatory.ecephys.behavior_ecephys_session import BehaviorEcephysSession from allensdk.brain_observatory.ecephys._probe import ProbeWithLFPMeta -@pytest.fixture(scope='module') -def behavior_ecephys_session_fixture( - behavior_ecephys_session_config_fixture): +@pytest.fixture(scope="module") +def behavior_ecephys_session_fixture(behavior_ecephys_session_config_fixture): """ Return a BehaviorEcephysSession for testing """ config = copy.deepcopy(behavior_ecephys_session_config_fixture) # Don't load LFP here to speed up the tests - for probe in config['probes']: - probe['lfp'] = None + for probe in config["probes"]: + probe["lfp"] = None return BehaviorEcephysSession.from_json( - session_data=config, - skip_probes=['probeB', 'probeC', 'probeD', 'probeE', 'probeF'] + session_data=config, skip_probes=["probeB", "probeC", "probeD", "probeE", "probeF"] ) -@pytest.fixture(scope='module') -def behavior_ecephys_session_with_lfp_fixture( - behavior_ecephys_session_config_fixture): +@pytest.fixture(scope="module") +def behavior_ecephys_session_with_lfp_fixture(behavior_ecephys_session_config_fixture): """ Return a BehaviorEcephysSession for testing """ config = copy.deepcopy(behavior_ecephys_session_config_fixture) return BehaviorEcephysSession.from_json( - session_data=config, - skip_probes=['probeB', 'probeC', 'probeD', 'probeE', 'probeF'] + session_data=config, skip_probes=["probeB", "probeC", "probeD", "probeE", "probeF"] ) @pytest.mark.requires_bamboo -@pytest.mark.parametrize('roundtrip', [True, False]) -def test_read_write_session_nwb( - roundtrip, - data_object_roundtrip_fixture, - behavior_ecephys_session_fixture): +@pytest.mark.parametrize("roundtrip", [True, False]) +def test_read_write_session_nwb(roundtrip, data_object_roundtrip_fixture, behavior_ecephys_session_fixture): """Tests roundtrip of the session data""" nwbfile, _ = behavior_ecephys_session_fixture.to_nwb() if roundtrip: - obt = data_object_roundtrip_fixture( - nwbfile=nwbfile, - data_object_cls=BehaviorEcephysSession) + obt = data_object_roundtrip_fixture(nwbfile=nwbfile, data_object_cls=BehaviorEcephysSession) else: obt = BehaviorEcephysSession.from_nwb(nwbfile=nwbfile) @@ -62,34 +52,24 @@ def test_read_write_session_nwb( @pytest.mark.requires_bamboo def test_read_write_session_with_probe_nwb( - data_object_roundtrip_fixture, - behavior_ecephys_session_with_lfp_fixture, - tmpdir + data_object_roundtrip_fixture, behavior_ecephys_session_with_lfp_fixture, tmpdir ): """Tests roundtrip of a session with separate probe nwb files that store LFP and CSD data""" - nwbfile, probe_nwbfile_map = \ - behavior_ecephys_session_with_lfp_fixture.to_nwb() + nwbfile, probe_nwbfile_map = behavior_ecephys_session_with_lfp_fixture.to_nwb() probe_meta = dict() for probe_name, probe_nwbfile in probe_nwbfile_map.items(): - path = Path(tmpdir) / f'probe_{probe_name}_lfp.nwb' - with pynwb.NWBHDF5IO(path, 'w') as write_io: + path = Path(tmpdir) / f"probe_{probe_name}_lfp.nwb" + with pynwb.NWBHDF5IO(path, "w") as write_io: write_io.write(probe_nwbfile) - probe_meta[probe_name] = ProbeWithLFPMeta( - lfp_csd_filepath=path, - lfp_sampling_rate=1 - ) - - obt = data_object_roundtrip_fixture( - nwbfile=nwbfile, - data_object_cls=BehaviorEcephysSession, - probe_meta=probe_meta - ) + probe_meta[probe_name] = ProbeWithLFPMeta(lfp_csd_filepath=path, lfp_sampling_rate=1) + + obt = data_object_roundtrip_fixture(nwbfile=nwbfile, data_object_cls=BehaviorEcephysSession, probe_meta=probe_meta) # check that the lfp metadata fields are set before loading lfp - assert obt.probes['has_lfp_data'].all() - assert not obt.probes['lfp_sampling_rate'].isnull().any() + assert obt.probes["has_lfp_data"].all() + assert not obt.probes["lfp_sampling_rate"].isnull().any() # Load the LFP data into memory for probe in obt._probes: @@ -100,8 +80,7 @@ def test_read_write_session_with_probe_nwb( @pytest.mark.requires_bamboo -def test_session_consistency( - behavior_ecephys_session_fixture): +def test_session_consistency(behavior_ecephys_session_fixture): """ This method will test the self-consistency of the BehaviorEcephysSession @@ -113,13 +92,11 @@ def test_session_consistency( trials = behavior_ecephys_session_fixture.trials stim_frames = stim[stim.is_change & stim.active].start_frame trials_frames = trials[trials.is_change].change_frame - delta = stim_frames.values-trials_frames.values - np.testing.assert_array_equal( - delta, - np.zeros(len(delta), dtype=int)) + delta = stim_frames.values - trials_frames.values + np.testing.assert_array_equal(delta, np.zeros(len(delta), dtype=int)) # make sure that response_latency is not in the trials table - assert 'response_latency' not in trials.columns + assert "response_latency" not in trials.columns @pytest.mark.requires_bamboo @@ -133,8 +110,7 @@ def test_getters_sanity(behavior_ecephys_session_fixture): @pytest.mark.requires_bamboo -def test_getters_sanity_from_nwb( - behavior_ecephys_session_fixture): +def test_getters_sanity_from_nwb(behavior_ecephys_session_fixture): """Sanity check to make sure that the BehaviorEcephysSession can use the BehaviorSession base class getter methods when read from nwb """ diff --git a/allensdk/test/brain_observatory/ecephys/test_copy_utility.py b/allensdk/test/brain_observatory/ecephys/test_copy_utility.py index da74d4c15f..1dfde5c6a1 100644 --- a/allensdk/test/brain_observatory/ecephys/test_copy_utility.py +++ b/allensdk/test/brain_observatory/ecephys/test_copy_utility.py @@ -9,8 +9,7 @@ import argschema import allensdk.brain_observatory.ecephys.copy_utility.__main__ as cu -from allensdk.brain_observatory.ecephys.copy_utility._schemas import ( - SessionUploadInputSchema, SessionUploadOutputSchema) +from allensdk.brain_observatory.ecephys.copy_utility._schemas import SessionUploadInputSchema, SessionUploadOutputSchema @pytest.mark.parametrize("already_exists", [True, False]) @@ -25,18 +24,12 @@ def test_dst_dir_exists(already_exists, tmp_path, monkeypatch): outj_path = tmp_path / "output.json" args = { - "files": [ - { - "source": str(src_file), - "destination": str(dst_file), - "key": "something"}], - "output_json": str(outj_path)} + "files": [{"source": str(src_file), "destination": str(dst_file), "key": "something"}], + "output_json": str(outj_path), + } parser = argschema.ArgSchemaParser( - args, - schema_type=SessionUploadInputSchema, - output_schema_type=SessionUploadOutputSchema, - args=[] + args, schema_type=SessionUploadInputSchema, output_schema_type=SessionUploadOutputSchema, args=[] ) def mock_copy_file(source, dest, use_rsync, make_parent_dirs, chmod=None): @@ -53,11 +46,11 @@ def mock_copy_file(source, dest, use_rsync, make_parent_dirs, chmod=None): def test_hash_file(tmpdir_factory): - tempdir = str(tmpdir_factory.mktemp('ecephys_copy_utility_test_hash_file')) - path = os.path.join(tempdir, 'afile.txt') + tempdir = str(tmpdir_factory.mktemp("ecephys_copy_utility_test_hash_file")) + path = os.path.join(tempdir, "afile.txt") - st = 'hello world' - with open(path, 'wb') as f: + st = "hello world" + with open(path, "wb") as f: f.write(st.encode()) hasher_cls = hashlib.sha256 @@ -69,133 +62,108 @@ def test_hash_file(tmpdir_factory): assert expected == obtained -@pytest.mark.parametrize('use_rsync', [True, False]) -@pytest.mark.parametrize('make_parent_dirs', [True, False]) +@pytest.mark.parametrize("use_rsync", [True, False]) +@pytest.mark.parametrize("make_parent_dirs", [True, False]) @pytest.mark.parametrize("chmod", [777, 775, 755, None]) def test_copy_file_entry(tmpdir_factory, use_rsync, make_parent_dirs, chmod): - - mac_or_linux = ( - sys.platform.startswith('darwin') or sys.platform.startswith('linux') - ) + mac_or_linux = sys.platform.startswith("darwin") or sys.platform.startswith("linux") if use_rsync and not mac_or_linux: pytest.skip() - tempdir = str( - tmpdir_factory.mktemp('ecephys_copy_utility_test_copy_file_entry') - ) - spath = os.path.join(tempdir, 'afile.txt') - dpath = os.path.join(tempdir, 'bfile.txt') + tempdir = str(tmpdir_factory.mktemp("ecephys_copy_utility_test_copy_file_entry")) + spath = os.path.join(tempdir, "afile.txt") + dpath = os.path.join(tempdir, "bfile.txt") - with open(spath, 'w') as sf: - sf.write('foo') + with open(spath, "w") as sf: + sf.write("foo") cu.copy_file_entry(spath, dpath, use_rsync, make_parent_dirs, chmod) - with open(dpath, 'r') as df: - assert df.read() == 'foo' + with open(dpath, "r") as df: + assert df.read() == "foo" def get_human_mode(path): return int(oct(os.stat(path).st_mode & 0o777)[2:]) + expected_mode = chmod if chmod is not None else get_human_mode(spath) if mac_or_linux: assert get_human_mode(dpath) == expected_mode -@pytest.mark.parametrize('different', [True, False]) -@pytest.mark.parametrize('raise_if_comparison_fails', [True, False]) -def test_compare_directories(tmpdir_factory, - different, - raise_if_comparison_fails): +@pytest.mark.parametrize("different", [True, False]) +@pytest.mark.parametrize("raise_if_comparison_fails", [True, False]) +def test_compare_directories(tmpdir_factory, different, raise_if_comparison_fails): hasher_cls = hashlib.sha256 - base_dir = str( - tmpdir_factory.mktemp('ecephys_copy_utility_test_compare_directories') - ) - sdir = os.path.join(base_dir, 'src') + base_dir = str(tmpdir_factory.mktemp("ecephys_copy_utility_test_compare_directories")) + sdir = os.path.join(base_dir, "src") os.makedirs(sdir) - ddir = os.path.join(base_dir, 'dest') + ddir = os.path.join(base_dir, "dest") os.makedirs(ddir) if different: - - with open(os.path.join(sdir, 'foo.txt'), 'w') as f: - f.write('baz') + with open(os.path.join(sdir, "foo.txt"), "w") as f: + f.write("baz") if raise_if_comparison_fails: with pytest.raises(ValueError): - cu.compare_directories( - sdir, ddir, hasher_cls, raise_if_comparison_fails) + cu.compare_directories(sdir, ddir, hasher_cls, raise_if_comparison_fails) else: with pytest.warns(UserWarning): - cu.compare_directories( - sdir, ddir, hasher_cls, raise_if_comparison_fails) + cu.compare_directories(sdir, ddir, hasher_cls, raise_if_comparison_fails) else: - cu.compare_directories( - sdir, ddir, hasher_cls, raise_if_comparison_fails) + cu.compare_directories(sdir, ddir, hasher_cls, raise_if_comparison_fails) -@pytest.mark.parametrize('different', [True, False]) -@pytest.mark.parametrize('raise_if_comparison_fails', [True, False]) +@pytest.mark.parametrize("different", [True, False]) +@pytest.mark.parametrize("raise_if_comparison_fails", [True, False]) def test_compare_files(tmpdir_factory, different, raise_if_comparison_fails): hasher_cls = hashlib.sha256 - base_dir = str( - tmpdir_factory.mktemp('ecephys_copy_utility_test_compare_files') - ) - spath = os.path.join(base_dir, 'source.txt') - dpath = os.path.join(base_dir, 'dest.txt') + base_dir = str(tmpdir_factory.mktemp("ecephys_copy_utility_test_compare_files")) + spath = os.path.join(base_dir, "source.txt") + dpath = os.path.join(base_dir, "dest.txt") - with open(spath, 'w') as f: - f.write('baz') + with open(spath, "w") as f: + f.write("baz") if different: - - with open(dpath, 'w') as f: - f.write('fish') + with open(dpath, "w") as f: + f.write("fish") if raise_if_comparison_fails: with pytest.raises(ValueError): - cu.compare_files( - spath, dpath, hasher_cls, raise_if_comparison_fails) + cu.compare_files(spath, dpath, hasher_cls, raise_if_comparison_fails) else: with pytest.warns(UserWarning): - cu.compare_files( - spath, dpath, hasher_cls, raise_if_comparison_fails) + cu.compare_files(spath, dpath, hasher_cls, raise_if_comparison_fails) else: - - with open(dpath, 'w') as f: - f.write('baz') + with open(dpath, "w") as f: + f.write("baz") cu.compare_files(spath, dpath, hasher_cls, raise_if_comparison_fails) def test_SessionUploadSchema(tmpdir): - src_file = Path(tmpdir) / 'src.csv' + src_file = Path(tmpdir) / "src.csv" src_file.touch() - dst_file = Path(tmpdir) / 'dst.csv' + dst_file = Path(tmpdir) / "dst.csv" - output_json = Path(tmpdir) / 'output.json' + output_json = Path(tmpdir) / "output.json" test_data = { - 'files': [{ - 'source': str(src_file), - 'destination': str(dst_file), - 'key': '' - }], - 'output_json': str(output_json) + "files": [{"source": str(src_file), "destination": str(dst_file), "key": ""}], + "output_json": str(output_json), } parser = argschema.ArgSchemaParser( - test_data, - schema_type=SessionUploadInputSchema, - output_schema_type=SessionUploadOutputSchema, - args=[] + test_data, schema_type=SessionUploadInputSchema, output_schema_type=SessionUploadOutputSchema, args=[] ) # Mocking the functionality of the main method shutil.copy(src_file, dst_file) - parser.output({'files': test_data['files']}) + parser.output({"files": test_data["files"]}) diff --git a/allensdk/test/brain_observatory/ecephys/test_current_source_density.py b/allensdk/test/brain_observatory/ecephys/test_current_source_density.py index 160b263884..9fc9dc17b9 100644 --- a/allensdk/test/brain_observatory/ecephys/test_current_source_density.py +++ b/allensdk/test/brain_observatory/ecephys/test_current_source_density.py @@ -9,88 +9,90 @@ @pytest.fixture def stim_table(): - return pd.DataFrame({ - 'Start': [0, 1, 2, 3, 4, 5, 6], - 'End': [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5], - 'alpha': [None, -1, -2, -3, -4, -5, -6], - 'stimulus_name': [None, 'a', 'a', 'a', 'b', 'b', 'a'], - 'stimulus_index': [None, 0, 0, 0, 1, 1, 2] - }) + return pd.DataFrame( + { + "Start": [0, 1, 2, 3, 4, 5, 6], + "End": [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5], + "alpha": [None, -1, -2, -3, -4, -5, -6], + "stimulus_name": [None, "a", "a", "a", "b", "b", "a"], + "stimulus_index": [None, 0, 0, 0, 1, 1, 2], + } + ) # ------------ _current_source_density.py ------------ -@pytest.mark.parametrize('stim_index', [0, None]) +@pytest.mark.parametrize("stim_index", [0, None]) def test_extract_trial_windows(stim_table, stim_index): - - stim_name = 'a' + stim_name = "a" time_step = 0.1 pre_stim_time = 0.2 post_stim_time = 0.3 num_trials = 2 - expected = [ - [0.8, 0.9, 1.0, 1.1, 1.2], - [1.8, 1.9, 2.0, 2.1, 2.2] - ] + expected = [[0.8, 0.9, 1.0, 1.1, 1.2], [1.8, 1.9, 2.0, 2.1, 2.2]] exp_rel = [-0.2, -0.1, 0.0, 0.1, 0.2] obtained, obt_rel = csd.extract_trial_windows( - stim_table, stim_name, time_step, pre_stim_time, - post_stim_time, num_trials, stim_index + stim_table, stim_name, time_step, pre_stim_time, post_stim_time, num_trials, stim_index ) assert np.allclose(obtained, expected) assert np.allclose(obt_rel, exp_rel) -@pytest.mark.parametrize('times,raw,channels,windows,volts_per_bit,expected', [ +@pytest.mark.parametrize( + "times,raw,channels,windows,volts_per_bit,expected", [ - np.arange(10), - np.arange(50).reshape([10, 5]), - [1, 3], - [[5.5, 6], [7, 8]], - 1.0, [ - # data are rounded to int - [[28, 31], [30, 33]], - [[36, 41], [38, 43]] - ] - ], - [ - np.arange(10), - np.arange(50).reshape([10, 5]), - [1, 3], - [[5.5, 6], [7, 8]], - 0.5, + np.arange(10), + np.arange(50).reshape([10, 5]), + [1, 3], + [[5.5, 6], [7, 8]], + 1.0, + [ + # data are rounded to int + [[28, 31], [30, 33]], + [[36, 41], [38, 43]], + ], + ], [ - # volts_per_bit scaling may result in floats - [[14, 15.5], [15, 16.5]], - [[18, 20.5], [19, 21.5]] - ] - ] -]) -def test_accumulate_lfp_data(times, raw, channels, windows, volts_per_bit, - expected): - obtained = csd.accumulate_lfp_data(times, raw, channels, - windows, volts_per_bit) + np.arange(10), + np.arange(50).reshape([10, 5]), + [1, 3], + [[5.5, 6], [7, 8]], + 0.5, + [ + # volts_per_bit scaling may result in floats + [[14, 15.5], [15, 16.5]], + [[18, 20.5], [19, 21.5]], + ], + ], + ], +) +def test_accumulate_lfp_data(times, raw, channels, windows, volts_per_bit, expected): + obtained = csd.accumulate_lfp_data(times, raw, channels, windows, volts_per_bit) assert np.allclose(obtained, expected) -@pytest.mark.parametrize('trial_mean_accumulated,spacing,expected,expected_channels', [ +@pytest.mark.parametrize( + "trial_mean_accumulated,spacing,expected,expected_channels", [ - np.nanmean(np.arange(36).reshape([2, 6, 3]) ** 3, axis=0), - 1.0, - [[1728., 1926., 2142.], - [648., 702., 756.], - [810., 864., 918.], - [972., 1026., 1080.], - [1134., 1188., 1242.], - [-5292., -5706., -6138.]], - np.arange(6) - ] -]) + [ + np.nanmean(np.arange(36).reshape([2, 6, 3]) ** 3, axis=0), + 1.0, + [ + [1728.0, 1926.0, 2142.0], + [648.0, 702.0, 756.0], + [810.0, 864.0, 918.0], + [972.0, 1026.0, 1080.0], + [1134.0, 1188.0, 1242.0], + [-5292.0, -5706.0, -6138.0], + ], + np.arange(6), + ] + ], +) def test_compute_csd(trial_mean_accumulated, spacing, expected, expected_channels): - obtained, obtained_channels = csd.compute_csd(trial_mean_accumulated, spacing=spacing) assert np.allclose(obtained, expected) @@ -98,109 +100,106 @@ def test_compute_csd(trial_mean_accumulated, spacing, expected, expected_channel # ------------ _interpolation_utils.py ------------ -@pytest.mark.parametrize('min_chan, max_chan, expected', [ +@pytest.mark.parametrize( + "min_chan, max_chan, expected", [ - # min_chan - 0, - # max_chan - 4, - # expected actual channel locations - [[16, 0], [48, 0], [0, 20], [32, 20]] - ], - [ - 2, - 6, - [[0, 20], [32, 20], [16, 40], [48, 40]] - ], - [ - 0, - 8, - [[16, 0], [48, 0], [0, 20], [32, 20], - [16, 40], [48, 40], [0, 60], [32, 60]] - ], - [ - 4, - 8, - [[16, 40], [48, 40], [0, 60], [32, 60]] + [ + # min_chan + 0, + # max_chan + 4, + # expected actual channel locations + [[16, 0], [48, 0], [0, 20], [32, 20]], + ], + [2, 6, [[0, 20], [32, 20], [16, 40], [48, 40]]], + [0, 8, [[16, 0], [48, 0], [0, 20], [32, 20], [16, 40], [48, 40], [0, 60], [32, 60]]], + [4, 8, [[16, 40], [48, 40], [0, 60], [32, 60]]], + [5, 6, [[48, 40]]], ], - [ - 5, - 6, - [[48, 40]] - ] - -]) +) def test_make_actual_channel_locations(min_chan, max_chan, expected): - obtained = interp_utils.make_actual_channel_locations(min_chan=min_chan, - max_chan=max_chan) + obtained = interp_utils.make_actual_channel_locations(min_chan=min_chan, max_chan=max_chan) assert np.allclose(obtained, expected) -@pytest.mark.parametrize('min_chan, max_chan, expected', [ +@pytest.mark.parametrize( + "min_chan, max_chan, expected", [ - # min_chan - 0, - # max_chan - 7, - # expected interpolated channel locations - [[24, 0], [24, 10], [24, 20], [24, 30], [24, 40], [24, 50], [24, 60]] - ], - [ - 0, - 14, - [[24, 0], [24, 10], [24, 20], [24, 30], [24, 40], [24, 50], [24, 60], - [24, 70], [24, 80], [24, 90], [24, 100], [24, 110], [24, 120], [24, 130]] - ], - [ - 2, - 6, - [[24, 20], [24, 30], [24, 40], [24, 50]] - ], - [ - 7, - 14, - [[24, 70], [24, 80], [24, 90], [24, 100], [24, 110], [24, 120], [24, 130]] + [ + # min_chan + 0, + # max_chan + 7, + # expected interpolated channel locations + [[24, 0], [24, 10], [24, 20], [24, 30], [24, 40], [24, 50], [24, 60]], + ], + [ + 0, + 14, + [ + [24, 0], + [24, 10], + [24, 20], + [24, 30], + [24, 40], + [24, 50], + [24, 60], + [24, 70], + [24, 80], + [24, 90], + [24, 100], + [24, 110], + [24, 120], + [24, 130], + ], + ], + [2, 6, [[24, 20], [24, 30], [24, 40], [24, 50]]], + [7, 14, [[24, 70], [24, 80], [24, 90], [24, 100], [24, 110], [24, 120], [24, 130]]], + [8, 9, [[24, 80]]], ], - [ - 8, - 9, - [[24, 80]] - ] -]) +) def test_make_interp_channel_locations(min_chan, max_chan, expected): - obtained = interp_utils.make_interp_channel_locations(min_chan=min_chan, - max_chan=max_chan) + obtained = interp_utils.make_interp_channel_locations(min_chan=min_chan, max_chan=max_chan) assert np.allclose(obtained, expected) -@pytest.mark.parametrize('lfp, actual_locs, interp_locs, expected', [ +@pytest.mark.parametrize( + "lfp, actual_locs, interp_locs, expected", [ - # lfp - np.arange(36).reshape([2, 6, 3]) ** 3, - # actual_locs - interp_utils.make_actual_channel_locations(0, 6), - # interp_locs - interp_utils.make_interp_channel_locations(0, 6), - # expected (interp_lfp, spacing) - ([[[-1.48688877e+01, -1.65987508e+00, 2.25788198e+01], - [1.84977651e+02, 3.01621496e+02, 4.52968522e+02], - [5.82039685e+02, 8.18476063e+02, 1.11005335e+03], - [1.23712914e+03, 1.61285986e+03, 2.05929375e+03], - [2.03821497e+03, 2.56276850e+03, 3.17035171e+03], - [0.00000000e+00, 0.00000000e+00, 0.00000000e+00]], - - [[6.80643377e+03, 7.93617702e+03, 9.18494994e+03], - [1.24901530e+04, 1.41494541e+04, 1.59514583e+04], - [1.81704531e+04, 2.03174258e+04, 2.26275394e+04], - [2.37138688e+04, 2.62802568e+04, 2.90253479e+04], - [2.90797202e+04, 3.20168080e+04, 3.51449255e+04], - [0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]], 0.01) - ] -]) + [ + # lfp + np.arange(36).reshape([2, 6, 3]) ** 3, + # actual_locs + interp_utils.make_actual_channel_locations(0, 6), + # interp_locs + interp_utils.make_interp_channel_locations(0, 6), + # expected (interp_lfp, spacing) + ( + [ + [ + [-1.48688877e01, -1.65987508e00, 2.25788198e01], + [1.84977651e02, 3.01621496e02, 4.52968522e02], + [5.82039685e02, 8.18476063e02, 1.11005335e03], + [1.23712914e03, 1.61285986e03, 2.05929375e03], + [2.03821497e03, 2.56276850e03, 3.17035171e03], + [0.00000000e00, 0.00000000e00, 0.00000000e00], + ], + [ + [6.80643377e03, 7.93617702e03, 9.18494994e03], + [1.24901530e04, 1.41494541e04, 1.59514583e04], + [1.81704531e04, 2.03174258e04, 2.26275394e04], + [2.37138688e04, 2.62802568e04, 2.90253479e04], + [2.90797202e04, 3.20168080e04, 3.51449255e04], + [0.00000000e00, 0.00000000e00, 0.00000000e00], + ], + ], + 0.01, + ), + ] + ], +) def test_interp_channel_locs(lfp, actual_locs, interp_locs, expected): - obtained = interp_utils.interp_channel_locs(lfp=lfp, - actual_locs=actual_locs, - interp_locs=interp_locs) + obtained = interp_utils.interp_channel_locs(lfp=lfp, actual_locs=actual_locs, interp_locs=interp_locs) obtained_interp_lfp, obtained_spacing = obtained expected_interp_lfp, expected_spacing = expected @@ -210,58 +209,101 @@ def test_interp_channel_locs(lfp, actual_locs, interp_locs, expected): # ------------ _filter_utils.py ------------ -@pytest.mark.parametrize('lfp, ref_channels, noisy_thresh, expected', [ +@pytest.mark.parametrize( + "lfp, ref_channels, noisy_thresh, expected", [ - # lfp arrays in the form of: trials x channel x time samples - # channel 1 should be marked as 'noisy' and 2 should be removed - # for being a reference - np.array([[[0.1, 0.1, 0.1, 0.1], [0, 50, 500, 5000], [0, 0, 0, 0], [0.3, 0.3, 0.3, 0.3]], - [[0.15, 0.15, 0.15, 0.15], [0, 10, 100, 1000], [0, 0, 0, 0], [0.25, 0.25, 0.25, 0.25]], - [[0.2, 0.2, 0.2, 0.2], [0, 0, 0, 0], [0, 0, 0, 0], [0.2, 0.2, 0.2, 0.2]]]), - # reference channels - [2], - # noisy_channel_threshold - 2.0, - # expected output (cleaned_lfp, good_indices) - (np.array([[[0.1, 0.1, 0.1, 0.1], [0.3, 0.3, 0.3, 0.3]], - [[0.15, 0.15, 0.15, 0.15], [0.25, 0.25, 0.25, 0.25]], - [[0.2, 0.2, 0.2, 0.2], [0.2, 0.2, 0.2, 0.2]]]), - np.array([0, 3])) - ] -]) + [ + # lfp arrays in the form of: trials x channel x time samples + # channel 1 should be marked as 'noisy' and 2 should be removed + # for being a reference + np.array( + [ + [[0.1, 0.1, 0.1, 0.1], [0, 50, 500, 5000], [0, 0, 0, 0], [0.3, 0.3, 0.3, 0.3]], + [[0.15, 0.15, 0.15, 0.15], [0, 10, 100, 1000], [0, 0, 0, 0], [0.25, 0.25, 0.25, 0.25]], + [[0.2, 0.2, 0.2, 0.2], [0, 0, 0, 0], [0, 0, 0, 0], [0.2, 0.2, 0.2, 0.2]], + ] + ), + # reference channels + [2], + # noisy_channel_threshold + 2.0, + # expected output (cleaned_lfp, good_indices) + ( + np.array( + [ + [[0.1, 0.1, 0.1, 0.1], [0.3, 0.3, 0.3, 0.3]], + [[0.15, 0.15, 0.15, 0.15], [0.25, 0.25, 0.25, 0.25]], + [[0.2, 0.2, 0.2, 0.2], [0.2, 0.2, 0.2, 0.2]], + ] + ), + np.array([0, 3]), + ), + ] + ], +) def test_select_good_channels(lfp, ref_channels, noisy_thresh, expected): - obtained = filt_utils.select_good_channels(lfp, - ref_channels, - noisy_thresh) + obtained = filt_utils.select_good_channels(lfp, ref_channels, noisy_thresh) obtained_cleaned, obtained_good_inds = obtained assert np.allclose(obtained_cleaned, expected[0]) assert np.allclose(obtained_good_inds, expected[1]) -@pytest.mark.parametrize('lfp, sampling_rate, filter_cuts, filter_order, expected', [ +@pytest.mark.parametrize( + "lfp, sampling_rate, filter_cuts, filter_order, expected", [ - # lfp - np.arange(30).reshape([1, 3, 10]), - # sampling_rate - 1000, - # filter_cuts - [5.0, 150.0], - # filter_order - 1, - # expected output - [[[-8.03033681, -7.51102701, -7.00015769, -6.49716665, -6.00149202, - -5.51257727, -5.02987273, -4.55283625, -4.08093445, -3.61364687], - [-8.03033681, -7.51102701, -7.00015769, -6.49716665, -6.00149202, - -5.51257727, -5.02987273, -4.55283625, -4.08093445, -3.61364687], - [-8.03033681, -7.51102701, -7.00015769, -6.49716665, -6.00149202, - -5.51257727, -5.02987273, -4.55283625, -4.08093445, -3.61364687]]] - - ] - -]) + [ + # lfp + np.arange(30).reshape([1, 3, 10]), + # sampling_rate + 1000, + # filter_cuts + [5.0, 150.0], + # filter_order + 1, + # expected output + [ + [ + [ + -8.03033681, + -7.51102701, + -7.00015769, + -6.49716665, + -6.00149202, + -5.51257727, + -5.02987273, + -4.55283625, + -4.08093445, + -3.61364687, + ], + [ + -8.03033681, + -7.51102701, + -7.00015769, + -6.49716665, + -6.00149202, + -5.51257727, + -5.02987273, + -4.55283625, + -4.08093445, + -3.61364687, + ], + [ + -8.03033681, + -7.51102701, + -7.00015769, + -6.49716665, + -6.00149202, + -5.51257727, + -5.02987273, + -4.55283625, + -4.08093445, + -3.61364687, + ], + ] + ], + ] + ], +) def test_filter_lfp_channels(lfp, sampling_rate, filter_cuts, filter_order, expected): - obtained = filt_utils.filter_lfp_channels(lfp, - sampling_rate, - filter_cuts, - filter_order) + obtained = filt_utils.filter_lfp_channels(lfp, sampling_rate, filter_cuts, filter_order) assert np.allclose(obtained, expected) diff --git a/allensdk/test/brain_observatory/ecephys/test_ecephys_project_cache.py b/allensdk/test/brain_observatory/ecephys/test_ecephys_project_cache.py index b5bc38ebb6..e138bc6af0 100644 --- a/allensdk/test/brain_observatory/ecephys/test_ecephys_project_cache.py +++ b/allensdk/test/brain_observatory/ecephys/test_ecephys_project_cache.py @@ -11,114 +11,130 @@ import allensdk.brain_observatory.ecephys.ecephys_project_cache as epc import allensdk.brain_observatory.ecephys.nwb_util from allensdk.core.authentication import DbCredentials -from allensdk.brain_observatory.ecephys.ecephys_project_api.http_engine \ - import ( - write_from_stream, write_bytes_from_coroutine, AsyncHttpEngine, - HttpEngine, DEFAULT_TIMEOUT as HTTP_ENGINE_DEFAULT_TIMEOUT - ) +from allensdk.brain_observatory.ecephys.ecephys_project_api.http_engine import ( + write_from_stream, + write_bytes_from_coroutine, + AsyncHttpEngine, + HttpEngine, + DEFAULT_TIMEOUT as HTTP_ENGINE_DEFAULT_TIMEOUT, +) -mock_lims_credentials = DbCredentials(dbname='mock_lims', user='mock_user', - host='mock_host', port='mock_port', - password='mock') +mock_lims_credentials = DbCredentials( + dbname="mock_lims", user="mock_user", host="mock_host", port="mock_port", password="mock" +) @pytest.fixture def raw_sessions(): - return pd.DataFrame({ - 'session_type': ['stimulus_set_one', 'stimulus_set_two', - 'stimulus_set_two'], - "unit_count": [500, 1000, 1500], - "channel_count": [40, 90, 140], - "probe_count": [3, 4, 5], - "structure_acronyms": [["a", "v"], ["a", "c"], ["b"]] - }, index=pd.Series(name='id', data=[1, 2, 3])) + return pd.DataFrame( + { + "session_type": ["stimulus_set_one", "stimulus_set_two", "stimulus_set_two"], + "unit_count": [500, 1000, 1500], + "channel_count": [40, 90, 140], + "probe_count": [3, 4, 5], + "structure_acronyms": [["a", "v"], ["a", "c"], ["b"]], + }, + index=pd.Series(name="id", data=[1, 2, 3]), + ) @pytest.fixture def sessions(): - return pd.DataFrame({ - 'session_type': ['stimulus_set_one', 'stimulus_set_two', - 'stimulus_set_two'], - "unit_count": [500, 1000, 1500], - "channel_count": [40, 90, 140], - "probe_count": [3, 4, 5], - "ecephys_structure_acronyms": [["a", "v"], ["a", "c"], ["b"]] - }, index=pd.Series(name='id', data=[1, 2, 3])) + return pd.DataFrame( + { + "session_type": ["stimulus_set_one", "stimulus_set_two", "stimulus_set_two"], + "unit_count": [500, 1000, 1500], + "channel_count": [40, 90, 140], + "probe_count": [3, 4, 5], + "ecephys_structure_acronyms": [["a", "v"], ["a", "c"], ["b"]], + }, + index=pd.Series(name="id", data=[1, 2, 3]), + ) @pytest.fixture def units(): - return pd.DataFrame({ - 'ecephys_channel_id': [2, 1], - 'snr': [1.5, 4.9], - "amplitude_cutoff": [0.05, 0.2], - "presence_ratio": [10, 20], - "isi_violations": [0.3, 0.4], - "quality": ["good", "noise"] - }, index=pd.Series(name='id', data=[1, 2])) + return pd.DataFrame( + { + "ecephys_channel_id": [2, 1], + "snr": [1.5, 4.9], + "amplitude_cutoff": [0.05, 0.2], + "presence_ratio": [10, 20], + "isi_violations": [0.3, 0.4], + "quality": ["good", "noise"], + }, + index=pd.Series(name="id", data=[1, 2]), + ) @pytest.fixture def analysis_metrics(): - return pd.DataFrame({ - "a": [0, 1, 2], - "b": [3, 4, 5] - }, index=pd.Index(name="ecephys_unit_id", data=[1, 2, 3])) + return pd.DataFrame({"a": [0, 1, 2], "b": [3, 4, 5]}, index=pd.Index(name="ecephys_unit_id", data=[1, 2, 3])) @pytest.fixture def channels(): - return pd.DataFrame({ - 'ecephys_probe_id': [11, 11], - 'ap': [1000, 2000], - "unit_count": [5, 10], - "ecephys_structure_acronym": ["a", "b"] - }, index=pd.Series(name='id', data=[1, 2])) + return pd.DataFrame( + { + "ecephys_probe_id": [11, 11], + "ap": [1000, 2000], + "unit_count": [5, 10], + "ecephys_structure_acronym": ["a", "b"], + }, + index=pd.Series(name="id", data=[1, 2]), + ) @pytest.fixture def raw_probes(): - return pd.DataFrame({ - 'ecephys_session_id': [3], - "unit_count": [50], - "channel_count": [10], - "lfp_temporal_subsampling_factor": [2.0], - "lfp_sampling_rate": [1000.0], - }, index=pd.Series(name='id', data=[11])) + return pd.DataFrame( + { + "ecephys_session_id": [3], + "unit_count": [50], + "channel_count": [10], + "lfp_temporal_subsampling_factor": [2.0], + "lfp_sampling_rate": [1000.0], + }, + index=pd.Series(name="id", data=[11]), + ) @pytest.fixture def probes(): - return pd.DataFrame({ - 'ecephys_session_id': [3], - "unit_count": [50], - "channel_count": [10], - "lfp_temporal_subsampling_factor": [2.0], - "lfp_sampling_rate": [500.0], - }, index=pd.Series(name='id', data=[11])) + return pd.DataFrame( + { + "ecephys_session_id": [3], + "unit_count": [50], + "channel_count": [10], + "lfp_temporal_subsampling_factor": [2.0], + "lfp_sampling_rate": [500.0], + }, + index=pd.Series(name="id", data=[11]), + ) @pytest.fixture def annotated_probes(probes, sessions): - return pd.merge(probes, sessions, left_on="ecephys_session_id", - right_index=True, suffixes=["_probe", "_session"]) + return pd.merge(probes, sessions, left_on="ecephys_session_id", right_index=True, suffixes=["_probe", "_session"]) @pytest.fixture def annotated_channels(channels, annotated_probes): - return pd.merge(channels, annotated_probes, left_on="ecephys_probe_id", - right_index=True, suffixes=["_channel", "_probe"]) + return pd.merge( + channels, annotated_probes, left_on="ecephys_probe_id", right_index=True, suffixes=["_channel", "_probe"] + ) @pytest.fixture def annotated_units(units, annotated_channels): - return pd.merge(units, annotated_channels, left_on="ecephys_channel_id", - right_index=True, suffixes=["_unit", "_channel"]) + return pd.merge( + units, annotated_channels, left_on="ecephys_channel_id", right_index=True, suffixes=["_unit", "_channel"] + ) @pytest.fixture def shared_tmpdir(tmpdir_factory): - return str(tmpdir_factory.mktemp('test_ecephys_project_cache')) + return str(tmpdir_factory.mktemp("test_ecephys_project_cache")) class MockEngine: @@ -127,10 +143,8 @@ def __init__(self): @pytest.fixture -def mock_api(shared_tmpdir, raw_sessions, units, channels, raw_probes, - analysis_metrics): +def mock_api(shared_tmpdir, raw_sessions, units, channels, raw_probes, analysis_metrics): class MockApi: - def __init__(self, **kwargs): self.accesses = collections.defaultdict(lambda: 1) self.rma_engine = MockEngine() @@ -151,38 +165,32 @@ def get_probes(self, **kwargs): return raw_probes def get_session_data(self, session_id, **kwargs): - path = os.path.join(shared_tmpdir, 'tmp.nwb') + path = os.path.join(shared_tmpdir, "tmp.nwb") nwbfile = pynwb.NWBFile( - session_description='EcephysSession', - identifier=f"{session_id}", - session_start_time=datetime.now() + session_description="EcephysSession", identifier=f"{session_id}", session_start_time=datetime.now() ) allensdk.brain_observatory.ecephys.nwb_util.add_probe_to_nwbfile( - nwbfile, 11, sampling_rate=1.0, - lfp_sampling_rate=2.0, - has_lfp_data=True, - name="Test Probe") + nwbfile, 11, sampling_rate=1.0, lfp_sampling_rate=2.0, has_lfp_data=True, name="Test Probe" + ) with pynwb.NWBHDF5IO(path, "w") as io: io.write(nwbfile) - return open(path, 'rb') + return open(path, "rb") def get_probe_lfp_data(self, probe_id): path = os.path.join(shared_tmpdir, f"probe_{probe_id}.nwb") nwbfile = pynwb.NWBFile( - session_description='EcephysProbe', - identifier=f"{probe_id}", - session_start_time=datetime.now() + session_description="EcephysProbe", identifier=f"{probe_id}", session_start_time=datetime.now() ) with pynwb.NWBHDF5IO(path, "w") as io: io.write(nwbfile) - return open(path, 'rb') + return open(path, "rb") def get_natural_scene_template(self, number): path = os.path.join(shared_tmpdir, "tmp.tiff") @@ -203,12 +211,9 @@ def get_unit_analysis_metrics(self, *a, **k): @pytest.fixture def tmpdir_cache(shared_tmpdir, mock_api): - man_path = os.path.join(shared_tmpdir, 'manifest.json') + man_path = os.path.join(shared_tmpdir, "manifest.json") - return epc.EcephysProjectCache( - fetch_api=mock_api(), - manifest=man_path - ) + return epc.EcephysProjectCache(fetch_api=mock_api(), manifest=man_path) def lazy_cache_test(cache, cache_name, api_name, expected, *args, **kwargs): @@ -222,45 +227,39 @@ def lazy_cache_test(cache, cache_name, api_name, expected, *args, **kwargs): def test_get_sessions(tmpdir_cache, sessions): - lazy_cache_test(tmpdir_cache, '_get_sessions', "get_sessions", sessions) + lazy_cache_test(tmpdir_cache, "_get_sessions", "get_sessions", sessions) @pytest.mark.parametrize("filter_by_validity", [False, True]) def test_get_units(tmpdir_cache, units, filter_by_validity): if filter_by_validity: units = units[units["quality"] == "good"].drop(columns="quality") - lazy_cache_test(tmpdir_cache, '_get_units', "get_units", units, - filter_by_validity=filter_by_validity) + lazy_cache_test(tmpdir_cache, "_get_units", "get_units", units, filter_by_validity=filter_by_validity) else: units = units[units["amplitude_cutoff"] <= 0.1] - lazy_cache_test(tmpdir_cache, '_get_units', "get_units", units, - filter_by_validity=filter_by_validity) + lazy_cache_test(tmpdir_cache, "_get_units", "get_units", units, filter_by_validity=filter_by_validity) def test_get_probes(tmpdir_cache, probes): - lazy_cache_test(tmpdir_cache, '_get_probes', "get_probes", probes) + lazy_cache_test(tmpdir_cache, "_get_probes", "get_probes", probes) def test_get_channels(tmpdir_cache, channels): - lazy_cache_test(tmpdir_cache, '_get_channels', "get_channels", channels) + lazy_cache_test(tmpdir_cache, "_get_channels", "get_channels", channels) def test_get_annotated_probes(tmpdir_cache, probes, annotated_probes): - lazy_cache_test(tmpdir_cache, "_get_annotated_probes", "get_probes", - annotated_probes) + lazy_cache_test(tmpdir_cache, "_get_annotated_probes", "get_probes", annotated_probes) def test_get_annotated_channels(tmpdir_cache, channels, annotated_channels): - lazy_cache_test(tmpdir_cache, "_get_annotated_channels", "get_channels", - annotated_channels) + lazy_cache_test(tmpdir_cache, "_get_annotated_channels", "get_channels", annotated_channels) def test_get_annotated_units(tmpdir_cache, units, annotated_units): - annotated_units = annotated_units[ - annotated_units["amplitude_cutoff"] < 0.1] + annotated_units = annotated_units[annotated_units["amplitude_cutoff"] < 0.1] - lazy_cache_test(tmpdir_cache, "_get_annotated_units", "get_units", - annotated_units, filter_by_validity=False) + lazy_cache_test(tmpdir_cache, "_get_annotated_units", "get_units", annotated_units, filter_by_validity=False) def test_get_session_data(shared_tmpdir, tmpdir_cache): @@ -268,9 +267,8 @@ def test_get_session_data(shared_tmpdir, tmpdir_cache): data_one = tmpdir_cache.get_session_data(sid) - assert 1 == tmpdir_cache.fetch_api.accesses['get_session_data'] - assert os.path.join(shared_tmpdir, f"session_{sid}", - f"session_{sid}.nwb") == data_one.api.path + assert 1 == tmpdir_cache.fetch_api.accesses["get_session_data"] + assert os.path.join(shared_tmpdir, f"session_{sid}", f"session_{sid}.nwb") == data_one.api.path def test_get_natural_scene_template(shared_tmpdir, tmpdir_cache): @@ -294,38 +292,33 @@ def test_get_natural_movie_template(shared_tmpdir, tmpdir_cache): def test_get_unit_analysis_metrics_for_session(tmpdir_cache, analysis_metrics): lazy_cache_test( tmpdir_cache, - 'get_unit_analysis_metrics_for_session', + "get_unit_analysis_metrics_for_session", "get_unit_analysis_metrics", analysis_metrics, session_id=3, - annotate=False + annotate=False, ) -def test_get_unit_analysis_metrics_by_session_type(tmpdir_cache, - analysis_metrics): +def test_get_unit_analysis_metrics_by_session_type(tmpdir_cache, analysis_metrics): lazy_cache_test( tmpdir_cache, - 'get_unit_analysis_metrics_by_session_type', + "get_unit_analysis_metrics_by_session_type", "get_unit_analysis_metrics", analysis_metrics, session_type="stimulus_set_two", - annotate=False + annotate=False, ) def test_get_session_data_eventual_success(tmpdir_factory, mock_api): - man_path = os.path.join( - tmpdir_factory.mktemp("get_session_data"), - "manifest.json" - ) + man_path = os.path.join(tmpdir_factory.mktemp("get_session_data"), "manifest.json") class InitiallyFailingApi(mock_api): def get_session_data(self, session_id, **kwargs): if self.accesses["get_session_data"] < 1: raise ValueError("bad news!") - return super(InitiallyFailingApi, self).get_session_data( - session_id, **kwargs) + return super(InitiallyFailingApi, self).get_session_data(session_id, **kwargs) api = InitiallyFailingApi() cache = epc.EcephysProjectCache(manifest=man_path, fetch_api=api) @@ -336,10 +329,7 @@ def get_session_data(self, session_id, **kwargs): def test_get_session_data_continual_failure(tmpdir_factory, mock_api): - man_path = os.path.join( - tmpdir_factory.mktemp("get_session_data"), - "manifest.json" - ) + man_path = os.path.join(tmpdir_factory.mktemp("get_session_data"), "manifest.json") class ContinuallyFailingApi(mock_api): def get_session_data(self, session_id, **kwargs): @@ -354,17 +344,13 @@ def get_session_data(self, session_id, **kwargs): def test_get_probe_lfp_data(tmpdir_factory, mock_api): - man_path = os.path.join( - tmpdir_factory.mktemp("get_lfp_data"), - "manifest.json" - ) + man_path = os.path.join(tmpdir_factory.mktemp("get_lfp_data"), "manifest.json") class InitiallyFailingApi(mock_api): def get_probe_lfp_data(self, probe_id, **kwargs): if self.accesses["get_probe_data"] < 1: raise ValueError("bad news!") - return super(InitiallyFailingApi, self).get_probe_lfp_data( - probe_id, **kwargs) + return super(InitiallyFailingApi, self).get_probe_lfp_data(probe_id, **kwargs) api = InitiallyFailingApi() cache = epc.EcephysProjectCache(manifest=man_path, fetch_api=api) @@ -379,10 +365,7 @@ def get_probe_lfp_data(self, probe_id, **kwargs): def test_get_probe_lfp_data_continually_failing(tmpdir_factory, mock_api): - man_path = os.path.join( - tmpdir_factory.mktemp("get_lfp_data"), - "manifest.json" - ) + man_path = os.path.join(tmpdir_factory.mktemp("get_lfp_data"), "manifest.json") class ContinuallyFailingApi(mock_api): def get_probe_lfp_data(self, probe_id, **kwargs): @@ -404,8 +387,7 @@ def test_from_lims_default(tmpdir_factory): tmpdir = str(tmpdir_factory.mktemp("test_from_lims_default")) cache = epc.EcephysProjectCache.from_lims( - manifest=os.path.join(tmpdir, "manifest.json"), - lims_credentials=mock_lims_credentials + manifest=os.path.join(tmpdir, "manifest.json"), lims_credentials=mock_lims_credentials ) assert isinstance(cache.fetch_api.app_engine, HttpEngine) assert cache.stream_writer is write_from_stream @@ -416,9 +398,7 @@ def test_from_lims_default(tmpdir_factory): def test_from_warehouse_default(tmpdir_factory): tmpdir = str(tmpdir_factory.mktemp("test_from_warehouse_default")) - cache = epc.EcephysProjectCache.from_warehouse( - manifest=os.path.join(tmpdir, "manifest.json") - ) + cache = epc.EcephysProjectCache.from_warehouse(manifest=os.path.join(tmpdir, "manifest.json")) assert isinstance(cache.fetch_api.rma_engine, HttpEngine) assert cache.stream_writer is write_from_stream assert cache.fetch_api.rma_engine.scheme == "http" @@ -427,9 +407,7 @@ def test_from_warehouse_default(tmpdir_factory): def test_init_default(tmpdir_factory): tmpdir = str(tmpdir_factory.mktemp("test_init_default")) - cache = epc.EcephysProjectCache( - manifest=os.path.join(tmpdir, "manifest.json") - ) + cache = epc.EcephysProjectCache(manifest=os.path.join(tmpdir, "manifest.json")) assert isinstance(cache.fetch_api.rma_engine, HttpEngine) assert cache.stream_writer is cache.fetch_api.rma_engine.write_bytes assert cache.fetch_api.rma_engine.scheme == "http" @@ -437,31 +415,42 @@ def test_init_default(tmpdir_factory): @pytest.mark.parametrize( - ("cache_constructor, asynchronous, engine_attr, expected_engine," - "expected_scheme, expected_host, expected_stream_writer"), [ + ( + "cache_constructor, asynchronous, engine_attr, expected_engine," + "expected_scheme, expected_host, expected_stream_writer" + ), + [ ( - epc.EcephysProjectCache.from_lims, True, - "app_engine", AsyncHttpEngine, "http", "lims2", - write_bytes_from_coroutine + epc.EcephysProjectCache.from_lims, + True, + "app_engine", + AsyncHttpEngine, + "http", + "lims2", + write_bytes_from_coroutine, ), - ( - epc.EcephysProjectCache.from_lims, False, - "app_engine", HttpEngine, "http", "lims2", - write_from_stream - ) - ]) + (epc.EcephysProjectCache.from_lims, False, "app_engine", HttpEngine, "http", "lims2", write_from_stream), + ], +) def test_stream_asynchronous_arg_from_lims( - cache_constructor, asynchronous, engine_attr, expected_engine, - expected_scheme, expected_host, expected_stream_writer, - tmpdir_factory): - """ Ensure the proper stream engine is chosen from the `asynchronous` + cache_constructor, + asynchronous, + engine_attr, + expected_engine, + expected_scheme, + expected_host, + expected_stream_writer, + tmpdir_factory, +): + """Ensure the proper stream engine is chosen from the `asynchronous` argument in the EcephysProjectCache constructors (using other default values).""" tmpdir = str(tmpdir_factory.mktemp("test_stream_async_args")) cache = cache_constructor( asynchronous=asynchronous, manifest=os.path.join(tmpdir, "manifest.json"), - lims_credentials=mock_lims_credentials) + lims_credentials=mock_lims_credentials, + ) engine = getattr(cache.fetch_api, engine_attr) assert isinstance(engine, expected_engine) assert cache.stream_writer is expected_stream_writer @@ -470,31 +459,46 @@ def test_stream_asynchronous_arg_from_lims( @pytest.mark.parametrize( - ("cache_constructor, asynchronous, engine_attr, expected_engine," - "expected_scheme, expected_host, expected_stream_writer"), [ + ( + "cache_constructor, asynchronous, engine_attr, expected_engine," + "expected_scheme, expected_host, expected_stream_writer" + ), + [ ( - epc.EcephysProjectCache.from_warehouse, True, - "rma_engine", AsyncHttpEngine, "http", "api.brain-map.org", - write_bytes_from_coroutine + epc.EcephysProjectCache.from_warehouse, + True, + "rma_engine", + AsyncHttpEngine, + "http", + "api.brain-map.org", + write_bytes_from_coroutine, ), ( - epc.EcephysProjectCache.from_warehouse, False, - "rma_engine", HttpEngine, "http", "api.brain-map.org", - write_from_stream - ) - ]) + epc.EcephysProjectCache.from_warehouse, + False, + "rma_engine", + HttpEngine, + "http", + "api.brain-map.org", + write_from_stream, + ), + ], +) def test_stream_asynchronous_arg_from_warehouse( - cache_constructor, asynchronous, engine_attr, expected_engine, - expected_scheme, expected_host, expected_stream_writer, - tmpdir_factory): - """ Ensure the proper stream engine is chosen from the `asynchronous` + cache_constructor, + asynchronous, + engine_attr, + expected_engine, + expected_scheme, + expected_host, + expected_stream_writer, + tmpdir_factory, +): + """Ensure the proper stream engine is chosen from the `asynchronous` argument in the EcephysProjectCache constructors (using other default values).""" tmpdir = str(tmpdir_factory.mktemp("test_stream_async_args")) - cache = cache_constructor( - asynchronous=asynchronous, - manifest=os.path.join(tmpdir, "manifest.json") - ) + cache = cache_constructor(asynchronous=asynchronous, manifest=os.path.join(tmpdir, "manifest.json")) engine = getattr(cache.fetch_api, engine_attr) assert isinstance(engine, expected_engine) assert cache.stream_writer is expected_stream_writer @@ -514,9 +518,7 @@ def test_stream_writer_method_default_correct(tmpdir_factory): def test_default_timeout_from_warehouse(tmpdir_factory): tmpdir = str(tmpdir_factory.mktemp("test_from_warehouse_default")) - cache = epc.EcephysProjectCache.from_warehouse( - manifest=os.path.join(tmpdir, "manifest.json") - ) + cache = epc.EcephysProjectCache.from_warehouse(manifest=os.path.join(tmpdir, "manifest.json")) assert cache.fetch_api.rma_engine.timeout == HTTP_ENGINE_DEFAULT_TIMEOUT @@ -524,7 +526,6 @@ def test_user_provided_timeout_from_warehouse(tmpdir_factory): user_provided_timeout = 3 tmpdir = str(tmpdir_factory.mktemp("test_from_warehouse_default")) cache = epc.EcephysProjectCache.from_warehouse( - manifest=os.path.join(tmpdir, "manifest.json"), - timeout=user_provided_timeout + manifest=os.path.join(tmpdir, "manifest.json"), timeout=user_provided_timeout ) assert cache.fetch_api.rma_engine.timeout == user_provided_timeout diff --git a/allensdk/test/brain_observatory/ecephys/test_ecephys_project_fixed_api.py b/allensdk/test/brain_observatory/ecephys/test_ecephys_project_fixed_api.py index b0c59dcc51..f660402a0c 100644 --- a/allensdk/test/brain_observatory/ecephys/test_ecephys_project_fixed_api.py +++ b/allensdk/test/brain_observatory/ecephys/test_ecephys_project_fixed_api.py @@ -13,4 +13,4 @@ def test_get_session_data(): api = EcephysProjectFixedApi() with pytest.raises(MissingDataError) as err: api.get_session_data(12345) - assert re.compile("12345").search(err.message) is not None \ No newline at end of file + assert re.compile("12345").search(err.message) is not None diff --git a/allensdk/test/brain_observatory/ecephys/test_ecephys_project_lims_api.py b/allensdk/test/brain_observatory/ecephys/test_ecephys_project_lims_api.py index 05bb5b0635..330f56fdab 100644 --- a/allensdk/test/brain_observatory/ecephys/test_ecephys_project_lims_api.py +++ b/allensdk/test/brain_observatory/ecephys/test_ecephys_project_lims_api.py @@ -10,13 +10,12 @@ ecephys_project_lims_api as epla, ) -mock_lims_credentials = DbCredentials(dbname='mock_lims', user='mock_user', - host='mock_host', port='mock_port', - password='mock') +mock_lims_credentials = DbCredentials( + dbname="mock_lims", user="mock_user", host="mock_host", port="mock_port", password="mock" +) class MockSelector: - def __init__(self, checks, response): self.checks = checks self.response = response @@ -29,162 +28,113 @@ def __call__(self, query, *args, **kwargs): return self.response -@pytest.mark.parametrize("method_name,kwargs,response,checks,expected", [ - [ - "get_units", - {}, - pd.DataFrame({"id": [5, 6], "something": [12, 14]}), - { - "no_pa_check": lambda st: "published_at" not in st - }, - pd.DataFrame( - {"something": [12, 14]}, - index=pd.Index(name="id", data=[5, 6]) - ) - ], - [ - "get_units", - {"session_ids": [1, 2, 3]}, - pd.DataFrame({"id": [5, 6], "something": [12, 14]}), - { - "filters_sessions": lambda st: re.compile( - r".+and es.id in \(1,2,3\).*", re.DOTALL).match(st) is not None - }, - pd.DataFrame( - {"something": [12, 14]}, - index=pd.Index(name="id", data=[5, 6]) - ) - ], - [ - "get_units", - {"unit_ids": [1, 2, 3]}, - pd.DataFrame({"id": [5, 6], "something": [12, 14]}), - { - "filters_units": lambda st: re.compile( - r".+and eu.id in \(1,2,3\).*", re.DOTALL).match(st) is not None - }, - pd.DataFrame( - {"something": [12, 14]}, - index=pd.Index(name="id", data=[5, 6]) - ) - ], - [ - "get_units", - {"channel_ids": [1, 2, 3], "probe_ids": [4, 5, 6]}, - pd.DataFrame({"id": [5, 6], "something": [12, 14]}), - { - "filters_channels": lambda st: re.compile( - r".+and ec.id in \(1,2,3\).*", re.DOTALL).match( - st) is not None, - "filters_probes": lambda st: re.compile( - r".+and ep.id in \(4,5,6\).*", re.DOTALL).match(st) is not None - }, - pd.DataFrame( - {"something": [12, 14]}, - index=pd.Index(name="id", data=[5, 6]) - ) - ], - [ - "get_units", - {"published_at": "2019-10-22"}, - pd.DataFrame({"id": [5, 6], "something": [12, 14]}), - { - "checks_pa_not_null": lambda st: re.compile( - r".+and es.published_at is not null.*", re.DOTALL).match( - st) is not None, - "checks_pa": lambda st: re.compile( - r".+and es.published_at <= '2019-10-22'.*", re.DOTALL).match( - st) is not None - }, - pd.DataFrame( - {"something": [12, 14]}, - index=pd.Index(name="id", data=[5, 6]) - ) - ], - [ - "get_channels", - {"published_at": "2019-10-22", "session_ids": [1, 2, 3]}, - pd.DataFrame({"id": [5, 6], "something": [12, 14]}), - { - "checks_pa_not_null": lambda st: re.compile( - r".+and es.published_at is not null.*", re.DOTALL).match( - st) is not None, - "checks_pa": lambda st: re.compile( - r".+and es.published_at <= '2019-10-22'.*", re.DOTALL).match( - st) is not None, - "filters_sessions": lambda st: re.compile( - r".+and es.id in \(1,2,3\).*", re.DOTALL).match(st) is not None - }, - pd.DataFrame( - {"something": [12, 14]}, - index=pd.Index(name="id", data=[5, 6]) - ) - ], - [ - "get_probes", - {"published_at": "2019-10-22", "session_ids": [1, 2, 3]}, - pd.DataFrame({"id": [5, 6], "something": [12, 14]}), - { - "checks_pa_not_null": lambda st: re.compile( - r".+and es.published_at is not null.*", re.DOTALL).match( - st) is not None, - "checks_pa": lambda st: re.compile( - r".+and es.published_at <= '2019-10-22'.*", re.DOTALL).match( - st) is not None, - "filters_sessions": lambda st: re.compile( - r".+and es.id in \(1,2,3\).*", re.DOTALL).match(st) is not None - }, - pd.DataFrame( - {"something": [12, 14]}, - index=pd.Index(name="id", data=[5, 6]) - ) - ], +@pytest.mark.parametrize( + "method_name,kwargs,response,checks,expected", [ - "get_sessions", - {"published_at": "2019-10-22", "session_ids": [1, 2, 3]}, - pd.DataFrame({"id": [5, 6], "something": [12, 14], - "genotype": ["foo", np.nan]}), - { - "checks_pa_not_null": lambda st: re.compile( - r".+and es.published_at is not null.*", re.DOTALL).match( - st) is not None, - "checks_pa": lambda st: re.compile( - r".+and es.published_at <= '2019-10-22'.*", re.DOTALL).match( - st) is not None, - "filters_sessions": lambda st: re.compile( - r".+and es.id in \(1,2,3\).*", re.DOTALL).match(st) is not None - }, - pd.DataFrame( - {"something": [12, 14], "genotype": ["foo", "wt"]}, - index=pd.Index(name="id", data=[5, 6]) - ) + [ + "get_units", + {}, + pd.DataFrame({"id": [5, 6], "something": [12, 14]}), + {"no_pa_check": lambda st: "published_at" not in st}, + pd.DataFrame({"something": [12, 14]}, index=pd.Index(name="id", data=[5, 6])), + ], + [ + "get_units", + {"session_ids": [1, 2, 3]}, + pd.DataFrame({"id": [5, 6], "something": [12, 14]}), + {"filters_sessions": lambda st: re.compile(r".+and es.id in \(1,2,3\).*", re.DOTALL).match(st) is not None}, + pd.DataFrame({"something": [12, 14]}, index=pd.Index(name="id", data=[5, 6])), + ], + [ + "get_units", + {"unit_ids": [1, 2, 3]}, + pd.DataFrame({"id": [5, 6], "something": [12, 14]}), + {"filters_units": lambda st: re.compile(r".+and eu.id in \(1,2,3\).*", re.DOTALL).match(st) is not None}, + pd.DataFrame({"something": [12, 14]}, index=pd.Index(name="id", data=[5, 6])), + ], + [ + "get_units", + {"channel_ids": [1, 2, 3], "probe_ids": [4, 5, 6]}, + pd.DataFrame({"id": [5, 6], "something": [12, 14]}), + { + "filters_channels": lambda st: re.compile(r".+and ec.id in \(1,2,3\).*", re.DOTALL).match(st) + is not None, + "filters_probes": lambda st: re.compile(r".+and ep.id in \(4,5,6\).*", re.DOTALL).match(st) is not None, + }, + pd.DataFrame({"something": [12, 14]}, index=pd.Index(name="id", data=[5, 6])), + ], + [ + "get_units", + {"published_at": "2019-10-22"}, + pd.DataFrame({"id": [5, 6], "something": [12, 14]}), + { + "checks_pa_not_null": lambda st: re.compile(r".+and es.published_at is not null.*", re.DOTALL).match(st) + is not None, + "checks_pa": lambda st: re.compile(r".+and es.published_at <= '2019-10-22'.*", re.DOTALL).match(st) + is not None, + }, + pd.DataFrame({"something": [12, 14]}, index=pd.Index(name="id", data=[5, 6])), + ], + [ + "get_channels", + {"published_at": "2019-10-22", "session_ids": [1, 2, 3]}, + pd.DataFrame({"id": [5, 6], "something": [12, 14]}), + { + "checks_pa_not_null": lambda st: re.compile(r".+and es.published_at is not null.*", re.DOTALL).match(st) + is not None, + "checks_pa": lambda st: re.compile(r".+and es.published_at <= '2019-10-22'.*", re.DOTALL).match(st) + is not None, + "filters_sessions": lambda st: re.compile(r".+and es.id in \(1,2,3\).*", re.DOTALL).match(st) + is not None, + }, + pd.DataFrame({"something": [12, 14]}, index=pd.Index(name="id", data=[5, 6])), + ], + [ + "get_probes", + {"published_at": "2019-10-22", "session_ids": [1, 2, 3]}, + pd.DataFrame({"id": [5, 6], "something": [12, 14]}), + { + "checks_pa_not_null": lambda st: re.compile(r".+and es.published_at is not null.*", re.DOTALL).match(st) + is not None, + "checks_pa": lambda st: re.compile(r".+and es.published_at <= '2019-10-22'.*", re.DOTALL).match(st) + is not None, + "filters_sessions": lambda st: re.compile(r".+and es.id in \(1,2,3\).*", re.DOTALL).match(st) + is not None, + }, + pd.DataFrame({"something": [12, 14]}, index=pd.Index(name="id", data=[5, 6])), + ], + [ + "get_sessions", + {"published_at": "2019-10-22", "session_ids": [1, 2, 3]}, + pd.DataFrame({"id": [5, 6], "something": [12, 14], "genotype": ["foo", np.nan]}), + { + "checks_pa_not_null": lambda st: re.compile(r".+and es.published_at is not null.*", re.DOTALL).match(st) + is not None, + "checks_pa": lambda st: re.compile(r".+and es.published_at <= '2019-10-22'.*", re.DOTALL).match(st) + is not None, + "filters_sessions": lambda st: re.compile(r".+and es.id in \(1,2,3\).*", re.DOTALL).match(st) + is not None, + }, + pd.DataFrame({"something": [12, 14], "genotype": ["foo", "wt"]}, index=pd.Index(name="id", data=[5, 6])), + ], + [ + "get_unit_analysis_metrics", + {"ecephys_session_ids": [1, 2, 3]}, + pd.DataFrame({"id": [5, 6], "data": [{"a": 1, "b": 2}, {"a": 3, "b": 4}], "ecephys_unit_id": [10, 11]}), + {"filters_sessions": lambda st: re.compile(r".+and es.id in \(1,2,3\).*", re.DOTALL).match(st) is not None}, + pd.DataFrame( + {"id": [5, 6], "a": [1, 3], "b": [2, 4]}, index=pd.Index(name="ecephys_unit_id", data=[10, 11]) + ), + ], ], - [ - "get_unit_analysis_metrics", - {"ecephys_session_ids": [1, 2, 3]}, - pd.DataFrame( - {"id": [5, 6], "data": [{"a": 1, "b": 2}, {"a": 3, "b": 4}], - "ecephys_unit_id": [10, 11]}), - { - "filters_sessions": lambda st: re.compile( - r".+and es.id in \(1,2,3\).*", re.DOTALL).match(st) is not None - }, - pd.DataFrame( - {"id": [5, 6], "a": [1, 3], "b": [2, 4]}, - index=pd.Index(name="ecephys_unit_id", data=[10, 11]) - ) - ] -]) +) def test_pg_query(method_name, kwargs, response, checks, expected): selector = MockSelector(checks, response) - with mock.patch("allensdk.internal.api.psycopg2_select", - new=selector) as ptc: - api = epla.EcephysProjectLimsApi.default( - lims_credentials=mock_lims_credentials) + with mock.patch("allensdk.internal.api.psycopg2_select", new=selector) as ptc: + api = epla.EcephysProjectLimsApi.default(lims_credentials=mock_lims_credentials) obtained = getattr(api, method_name)(**kwargs) - pd.testing.assert_frame_equal(expected, obtained, check_like=True, - check_dtype=False) + pd.testing.assert_frame_equal(expected, obtained, check_like=True, check_dtype=False) any_checks_failed = False for name, result in ptc.passed.items(): @@ -201,13 +151,11 @@ def test_pg_query(method_name, kwargs, response, checks, expected): class MockPgEngine: - def __init__(self, query_pattern): self.query_pattern = query_pattern class MockTemplatePgEngine(MockPgEngine): - def select_one(self, rendered): assert self.query_pattern.match(rendered) is not None return {"well_known_file_id": WKF_ID} @@ -224,35 +172,35 @@ def stream(self, url): assert url == f"well_known_files/download/{WKF_ID}?wkf_id={WKF_ID}" -@pytest.mark.parametrize("method,kwargs,query_pattern,pg_engine_cls", [ +@pytest.mark.parametrize( + "method,kwargs,query_pattern,pg_engine_cls", [ - "get_natural_movie_template", - {"number": 12}, - re.compile(".+st.name = 'natural_movie_12'.+", re.DOTALL), - MockTemplatePgEngine + [ + "get_natural_movie_template", + {"number": 12}, + re.compile(".+st.name = 'natural_movie_12'.+", re.DOTALL), + MockTemplatePgEngine, + ], + [ + "get_natural_scene_template", + {"number": 12}, + re.compile(".+st.name = 'natural_scene_12'.+", re.DOTALL), + MockTemplatePgEngine, + ], + [ + "get_probe_lfp_data", + {"probe_id": 53}, + re.compile(r".+and earp.ecephys_probe_id = 53.+", re.DOTALL), + MockDataPgEngine, + ], + [ + "get_session_data", + {"session_id": 53}, + re.compile(r".+and ear.ecephys_session_id = 53.+", re.DOTALL), + MockDataPgEngine, + ], ], - [ - "get_natural_scene_template", - {"number": 12}, - re.compile(".+st.name = 'natural_scene_12'.+", re.DOTALL), - MockTemplatePgEngine - ], - [ - "get_probe_lfp_data", - {"probe_id": 53}, - re.compile(r".+and earp.ecephys_probe_id = 53.+", re.DOTALL), - MockDataPgEngine - ], - [ - "get_session_data", - {"session_id": 53}, - re.compile(r".+and ear.ecephys_session_id = 53.+", re.DOTALL), - MockDataPgEngine - ] -]) +) def test_file_getter(method, kwargs, query_pattern, pg_engine_cls): - api = epla.EcephysProjectLimsApi( - postgres_engine=pg_engine_cls(query_pattern), - app_engine=MockHttpEngine() - ) + api = epla.EcephysProjectLimsApi(postgres_engine=pg_engine_cls(query_pattern), app_engine=MockHttpEngine()) getattr(api, method)(**kwargs) diff --git a/allensdk/test/brain_observatory/ecephys/test_ecephys_project_warehouse_api.py b/allensdk/test/brain_observatory/ecephys/test_ecephys_project_warehouse_api.py index 61b56bcc31..f0e2059725 100644 --- a/allensdk/test/brain_observatory/ecephys/test_ecephys_project_warehouse_api.py +++ b/allensdk/test/brain_observatory/ecephys/test_ecephys_project_warehouse_api.py @@ -2,6 +2,7 @@ from allensdk.brain_observatory.ecephys.ecephys_project_api import ecephys_project_warehouse_api as epwa + @pytest.mark.skipif(True, reason="broken test") @pytest.mark.parametrize( "method,conditions,expected_query", @@ -9,23 +10,17 @@ [ "get_sessions", {}, - ( - "criteria=model::EcephysSession" - ), + ("criteria=model::EcephysSession"), ], [ "get_sessions", {"session_ids": [779839471, 759228117]}, - ( - "criteria=model::EcephysSession,rma::criteria[id$in779839471,759228117]" - ), + ("criteria=model::EcephysSession,rma::criteria[id$in779839471,759228117]"), ], [ "get_sessions", {"session_ids": [779839471, 759228117], "has_eye_tracking": True}, - ( - "criteria=model::EcephysSession,rma::criteria[id$in779839471,759228117][fail_eye_tracking$eqfalse]" - ), + ("criteria=model::EcephysSession,rma::criteria[id$in779839471,759228117][fail_eye_tracking$eqfalse]"), ], [ "get_sessions", diff --git a/allensdk/test/brain_observatory/ecephys/test_ecephys_session.py b/allensdk/test/brain_observatory/ecephys/test_ecephys_session.py index 67949c42ef..58068bd823 100644 --- a/allensdk/test/brain_observatory/ecephys/test_ecephys_session.py +++ b/allensdk/test/brain_observatory/ecephys/test_ecephys_session.py @@ -4,107 +4,105 @@ import xarray as xr import types -from allensdk.brain_observatory.ecephys.ecephys_session_api import \ - EcephysSessionApi -from allensdk.brain_observatory.ecephys.ecephys_session import \ - EcephysSession, nan_intervals, build_spike_histogram +from allensdk.brain_observatory.ecephys.ecephys_session_api import EcephysSessionApi +from allensdk.brain_observatory.ecephys.ecephys_session import EcephysSession, nan_intervals, build_spike_histogram @pytest.fixture def raw_stimulus_table(): - return pd.DataFrame({ - 'start_time': np.arange(4)/2, - 'stop_time': np.arange(1, 5)/2, - 'stimulus_name': ['a', 'a', 'a', 'a_movie'], - 'stimulus_block': [0, 0, 0, 1], - 'TF': np.empty(4) * np.nan, - 'SF': np.empty(4) * np.nan, - 'Ori': np.empty(4) * np.nan, - 'Contrast': np.empty(4) * np.nan, - 'Pos_x': np.empty(4) * np.nan, - 'Pos_y': np.empty(4) * np.nan, - 'stimulus_index': [0, 0, 1, 1], - 'Color': np.arange(4)*5.5, - 'Image': np.empty(4) * np.nan, - 'Phase': np.linspace(0, 180, 4), - "texRes": np.ones([4]) - }, index=pd.Index(name='id', data=np.arange(4))) + return pd.DataFrame( + { + "start_time": np.arange(4) / 2, + "stop_time": np.arange(1, 5) / 2, + "stimulus_name": ["a", "a", "a", "a_movie"], + "stimulus_block": [0, 0, 0, 1], + "TF": np.empty(4) * np.nan, + "SF": np.empty(4) * np.nan, + "Ori": np.empty(4) * np.nan, + "Contrast": np.empty(4) * np.nan, + "Pos_x": np.empty(4) * np.nan, + "Pos_y": np.empty(4) * np.nan, + "stimulus_index": [0, 0, 1, 1], + "Color": np.arange(4) * 5.5, + "Image": np.empty(4) * np.nan, + "Phase": np.linspace(0, 180, 4), + "texRes": np.ones([4]), + }, + index=pd.Index(name="id", data=np.arange(4)), + ) @pytest.fixture def raw_invalid_times_table(): - return pd.DataFrame({ - "start_time": [0.3, 1.1, 1.6], - "stop_time": [0.6, 1.54, 2.3], - "tags": - [ + return pd.DataFrame( + { + "start_time": [0.3, 1.1, 1.6], + "stop_time": [0.6, 1.54, 2.3], + "tags": [ ["EcephysSession", "739448407", "stimulus"], ["EcephysProbe", "123448407", "probeA"], ["EcephysProbe", "123448407", "all_probes"], - ] - }) + ], + } + ) @pytest.fixture def raw_spike_times(): - return { - 0: np.array([5, 6, 7, 8]), - 1: np.array([2.5]), - 2: np.array([1.01, 1.03, 1.02]) - } + return {0: np.array([5, 6, 7, 8]), 1: np.array([2.5]), 2: np.array([1.01, 1.03, 1.02])} @pytest.fixture def raw_mean_waveforms(): - return { - 0: np.zeros((3, 20)), - 1: np.zeros((3, 20)) + 1, - 2: np.zeros((3, 20)) + 2 - } + return {0: np.zeros((3, 20)), 1: np.zeros((3, 20)) + 1, 2: np.zeros((3, 20)) + 2} @pytest.fixture def raw_channels(): - return pd.DataFrame({ - 'probe_channel_number': [0, 1, 2], - 'probe_horizontal_position': [5, 10, 15], - 'probe_id': [0, 0, 0], - 'probe_vertical_position': [10, 22, 33], - 'valid_data': [False, True, True] - }, index=pd.Index(name='channel_id', data=[0, 1, 2])) + return pd.DataFrame( + { + "probe_channel_number": [0, 1, 2], + "probe_horizontal_position": [5, 10, 15], + "probe_id": [0, 0, 0], + "probe_vertical_position": [10, 22, 33], + "valid_data": [False, True, True], + }, + index=pd.Index(name="channel_id", data=[0, 1, 2]), + ) @pytest.fixture def raw_units(): - return pd.DataFrame({ - 'firing_rate': np.linspace(1, 3, 3), - 'isi_violations': [40, 0.5, 0.1], - 'probe_channel_number': [0, 0, 1], - 'peak_channel_id': [2, 1, 0], - 'quality': ['good', 'good', 'noise'], - 'snr': [0.1, 1.4, 10.0], - 'on_screen_rf': [True, False, True], - 'p_value_rf': [0.001, 0.01, 0.05] - }, index=pd.Index(name='unit_id', data=np.arange(3)[::-1])) + return pd.DataFrame( + { + "firing_rate": np.linspace(1, 3, 3), + "isi_violations": [40, 0.5, 0.1], + "probe_channel_number": [0, 0, 1], + "peak_channel_id": [2, 1, 0], + "quality": ["good", "good", "noise"], + "snr": [0.1, 1.4, 10.0], + "on_screen_rf": [True, False, True], + "p_value_rf": [0.001, 0.01, 0.05], + }, + index=pd.Index(name="unit_id", data=np.arange(3)[::-1]), + ) @pytest.fixture def raw_probes(): - return pd.DataFrame({ - 'description': ['probeA', 'probeB'], - 'location': ['VISp', 'VISam'], - 'sampling_rate': [30000.0, 30000.0] - }, index=pd.Index(name='id', data=[0, 1])) + return pd.DataFrame( + {"description": ["probeA", "probeB"], "location": ["VISp", "VISam"], "sampling_rate": [30000.0, 30000.0]}, + index=pd.Index(name="id", data=[0, 1]), + ) @pytest.fixture def raw_lfp(): return { 0: xr.DataArray( - data=np.array([[1, 2, 3, 4, 5], - [6, 7, 8, 9, 10]]), - dims=['channel', 'time'], - coords=[[2, 1], np.linspace(0, 2, 5)] + data=np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), + dims=["channel", "time"], + coords=[[2, 1], np.linspace(0, 2, 5)], ) } @@ -117,6 +115,7 @@ def get_stimulus_presentations(self): def get_invalid_times(self): return pd.DataFrame() + return EcephysJustStimulusTableApi() @@ -142,11 +141,7 @@ def get_invalid_times(self): @pytest.fixture -def lfp_masking_api(raw_channels, - raw_probes, - raw_lfp, - raw_stimulus_table, - raw_invalid_times_table): +def lfp_masking_api(raw_channels, raw_probes, raw_lfp, raw_stimulus_table, raw_invalid_times_table): class EcephysMaskInvalidLFPApi(EcephysSessionApi): def get_channels(self): return raw_channels @@ -162,6 +157,7 @@ def get_stimulus_presentations(self): def get_invalid_times(self): return raw_invalid_times_table + return EcephysMaskInvalidLFPApi() @@ -176,6 +172,7 @@ def get_units(self): def get_probes(self): return raw_probes + return EcephysUnitsTableApi() @@ -187,14 +184,12 @@ def get_invalid_times(self): def get_stimulus_presentations(self): return raw_stimulus_table + return EcephysValidStimulusTableApi() @pytest.fixture -def mean_waveforms_api(raw_mean_waveforms, - raw_channels, - raw_units, - raw_probes): +def mean_waveforms_api(raw_mean_waveforms, raw_channels, raw_units, raw_probes): class EcephysMeanWaveformsApi(EcephysSessionApi): def get_mean_waveforms(self): return raw_mean_waveforms @@ -207,15 +202,12 @@ def get_units(self): def get_probes(self): return raw_probes + return EcephysMeanWaveformsApi() @pytest.fixture -def spike_times_api(raw_units, - raw_channels, - raw_probes, - raw_stimulus_table, - raw_spike_times): +def spike_times_api(raw_units, raw_channels, raw_probes, raw_stimulus_table, raw_spike_times): class EcephysSpikeTimesApi(EcephysSessionApi): def get_spike_times(self): return raw_spike_times @@ -241,11 +233,7 @@ def get_invalid_times(self): def get_no_spikes_times(self): # A special method used for testing cases when there are no spikes for a # given session, will be swapped out for get_spike_times() - return { - 0: np.array([]), - 1: np.array([]), - 2: np.array([]) - } + return {0: np.array([]), 1: np.array([]), 2: np.array([])} @pytest.fixture @@ -253,18 +241,20 @@ def session_metadata_api(): class EcephysSessionMetadataApi(EcephysSessionApi): def get_ecephys_session_id(self): return 12345 + return EcephysSessionMetadataApi() def test_get_stimulus_epochs(just_stim_table_api): - - expected = pd.DataFrame({ - "start_time": [0, 3/2], - "stop_time": [3/2, 2], - "duration": [3/2, 1/2], - "stimulus_name": ["a", "a_movie"], - "stimulus_block": [0, 1] - }) + expected = pd.DataFrame( + { + "start_time": [0, 3 / 2], + "stop_time": [3 / 2, 2], + "duration": [3 / 2, 1 / 2], + "stimulus_name": ["a", "a_movie"], + "stimulus_block": [0, 1], + } + ) session = EcephysSession(api=just_stim_table_api) obtained = session.get_stimulus_epochs() @@ -272,71 +262,55 @@ def test_get_stimulus_epochs(just_stim_table_api): print(expected) print(obtained) - pd.testing.assert_frame_equal(expected, - obtained, - check_like=True, - check_dtype=False) + pd.testing.assert_frame_equal(expected, obtained, check_like=True, check_dtype=False) def test_get_invalid_times(valid_stimulus_table_api, raw_invalid_times_table): - expected = raw_invalid_times_table session = EcephysSession(api=valid_stimulus_table_api) obtained = session.get_invalid_times() - pd.testing.assert_frame_equal(expected, - obtained, - check_like=True, - check_dtype=False) + pd.testing.assert_frame_equal(expected, obtained, check_like=True, check_dtype=False) def test_get_stimulus_presentations(valid_stimulus_table_api): - - expected = pd.DataFrame({ - "start_time": [0, 1/2, 1, 3/2], - "stop_time": [1/2, 1, 3/2, 2], - "stimulus_name": ['invalid_presentation', - 'invalid_presentation', 'a', 'a_movie'], - "phase": [np.nan, np.nan, 120.0, 180.0] - }, index=pd.Index(name='stimulus_presentation_id', data=[0, 1, 2, 3])) + expected = pd.DataFrame( + { + "start_time": [0, 1 / 2, 1, 3 / 2], + "stop_time": [1 / 2, 1, 3 / 2, 2], + "stimulus_name": ["invalid_presentation", "invalid_presentation", "a", "a_movie"], + "phase": [np.nan, np.nan, 120.0, 180.0], + }, + index=pd.Index(name="stimulus_presentation_id", data=[0, 1, 2, 3]), + ) session = EcephysSession(api=valid_stimulus_table_api) - obtained = session.stimulus_presentations[["start_time", - "stop_time", - "stimulus_name", - "phase"]] + obtained = session.stimulus_presentations[["start_time", "stop_time", "stimulus_name", "phase"]] print(expected) print(obtained) - pd.testing.assert_frame_equal(expected, - obtained, - check_like=True, - check_dtype=False) + pd.testing.assert_frame_equal(expected, obtained, check_like=True, check_dtype=False) def test_get_stimulus_presentations_no_invalid_times(just_stim_table_api): - - expected = pd.DataFrame({ - "start_time": [0, 1/2, 1, 3/2], - "stop_time": [1/2, 1, 3/2, 2], - 'stimulus_name': ['a', 'a', 'a', 'a_movie'], - - }, index=pd.Index(name='stimulus_presentation_id', data=[0, 1, 2, 3])) + expected = pd.DataFrame( + { + "start_time": [0, 1 / 2, 1, 3 / 2], + "stop_time": [1 / 2, 1, 3 / 2, 2], + "stimulus_name": ["a", "a", "a", "a_movie"], + }, + index=pd.Index(name="stimulus_presentation_id", data=[0, 1, 2, 3]), + ) session = EcephysSession(api=just_stim_table_api) - obtained = session.stimulus_presentations[["start_time", - "stop_time", - "stimulus_name"]] + obtained = session.stimulus_presentations[["start_time", "stop_time", "stimulus_name"]] print(expected) print(obtained) - pd.testing.assert_frame_equal(expected, - obtained, - check_like=True, - check_dtype=False) + pd.testing.assert_frame_equal(expected, obtained, check_like=True, check_dtype=False) def test_session_metadata(session_metadata_api): @@ -347,10 +321,21 @@ def test_session_metadata(session_metadata_api): def test_build_stimulus_presentations(just_stim_table_api): expected_columns = [ - 'start_time', 'stop_time', 'stimulus_name', 'stimulus_block', - 'temporal_frequency', 'spatial_frequency', 'orientation', 'contrast', - 'x_position', 'y_position', 'color', 'frame', 'phase', - 'duration', "stimulus_condition_id" + "start_time", + "stop_time", + "stimulus_name", + "stimulus_block", + "temporal_frequency", + "spatial_frequency", + "orientation", + "contrast", + "x_position", + "y_position", + "color", + "frame", + "phase", + "duration", + "stimulus_condition_id", ] session = EcephysSession(api=just_stim_table_api) @@ -360,7 +345,7 @@ def test_build_stimulus_presentations(just_stim_table_api): print(obtained.columns) assert set(expected_columns) == set(obtained.columns) - assert 'stimulus_presentation_id' == obtained.index.name + assert "stimulus_presentation_id" == obtained.index.name assert 4 == obtained.shape[0] @@ -377,58 +362,47 @@ def test_build_units_table(units_table_api): obtained = session.units assert 3 == session.num_units - assert np.allclose([10, 22, 33], obtained['probe_vertical_position']) + assert np.allclose([10, 22, 33], obtained["probe_vertical_position"]) assert np.allclose([0, 1, 2], obtained.index.values) - assert np.allclose([0.05, 0.01, 0.001], obtained['p_value_rf'].values) + assert np.allclose([0.05, 0.01, 0.001], obtained["p_value_rf"].values) def test_presentationwise_spike_counts(spike_times_api): session = EcephysSession(api=spike_times_api) - obtained = \ - session.presentationwise_spike_counts( - np.linspace(-.1, .1, 3), - session.stimulus_presentations.index.values, - session.units.index.values) + obtained = session.presentationwise_spike_counts( + np.linspace(-0.1, 0.1, 3), session.stimulus_presentations.index.values, session.units.index.values + ) - first = obtained.loc[{'unit_id': 2, 'stimulus_presentation_id': 2}] + first = obtained.loc[{"unit_id": 2, "stimulus_presentation_id": 2}] assert np.allclose([0, 3], first) - second = obtained.loc[{'unit_id': 1, 'stimulus_presentation_id': 3}] + second = obtained.loc[{"unit_id": 1, "stimulus_presentation_id": 3}] assert np.allclose([0, 0], second) assert np.allclose([4, 2, 3], obtained.shape) -@pytest.mark.parametrize("spike_times,time_domain,expected", [ +@pytest.mark.parametrize( + "spike_times,time_domain,expected", [ - {1: [1.5, 2.5]}, - [[1, 2, 3, 4], [1.1, 2.1, 3.1, 4.1]], - np.array([[1, 1, 0], [1, 1, 0]])[:, :, None] + [{1: [1.5, 2.5]}, [[1, 2, 3, 4], [1.1, 2.1, 3.1, 4.1]], np.array([[1, 1, 0], [1, 1, 0]])[:, :, None]], + [{1: [1.5, 2.5]}, [[1, 2, 3, 4], [1.6, 2.0, 4.0, 4.1]], np.array([[1, 1, 0], [0, 1, 0]])[:, :, None]], + [ + {1: [1.5, 2.5], 2: [1.5, 2.5]}, + [[1, 2, 3, 4], [1.6, 2.0, 4.0, 4.1]], + np.stack(([[1, 1, 0], [0, 1, 0]], [[1, 1, 0], [0, 1, 0]]), axis=2), + ], + [ + {1: [1.5, 2.5], 2: [1.5, 1.55]}, + [[1, 2, 3, 4], [1.6, 2.0, 4.0, 4.1]], + np.stack(([[1, 1, 0], [0, 1, 0]], [[2, 0, 0], [0, 0, 0]]), axis=2), + ], ], - [ - {1: [1.5, 2.5]}, - [[1, 2, 3, 4], [1.6, 2.0, 4.0, 4.1]], - np.array([[1, 1, 0], [0, 1, 0]])[:, :, None] - ], - [ - {1: [1.5, 2.5], 2: [1.5, 2.5]}, - [[1, 2, 3, 4], [1.6, 2.0, 4.0, 4.1]], - np.stack(([[1, 1, 0], [0, 1, 0]], [[1, 1, 0], [0, 1, 0]]), axis=2) - ], - [ - {1: [1.5, 2.5], 2: [1.5, 1.55]}, - [[1, 2, 3, 4], [1.6, 2.0, 4.0, 4.1]], - np.stack(([[1, 1, 0], [0, 1, 0]], [[2, 0, 0], [0, 0, 0]]), axis=2) - ] -]) +) @pytest.mark.parametrize("binarize", [True, False]) def test_build_spike_histogram(spike_times, time_domain, expected, binarize): - unit_ids = [k for k in spike_times.keys()] - obtained = build_spike_histogram(time_domain, - spike_times, - unit_ids, - binarize=binarize) + obtained = build_spike_histogram(time_domain, spike_times, unit_ids, binarize=binarize) expected = np.array(expected) if binarize: @@ -440,33 +414,34 @@ def test_build_spike_histogram(spike_times, time_domain, expected, binarize): def test_presentationwise_spike_times(spike_times_api): session = EcephysSession(api=spike_times_api) - obtained = \ - session.presentationwise_spike_times( - session.stimulus_presentations.index.values, - session.units.index.values) + obtained = session.presentationwise_spike_times( + session.stimulus_presentations.index.values, session.units.index.values + ) - expected = pd.DataFrame({ - 'unit_id': [2, 2, 2], - 'stimulus_presentation_id': [2, 2, 2, ], - 'time_since_stimulus_presentation_onset': [0.01, 0.02, 0.03] - }, index=pd.Index(name='spike_time', data=[1.01, 1.02, 1.03])) + expected = pd.DataFrame( + { + "unit_id": [2, 2, 2], + "stimulus_presentation_id": [ + 2, + 2, + 2, + ], + "time_since_stimulus_presentation_onset": [0.01, 0.02, 0.03], + }, + index=pd.Index(name="spike_time", data=[1.01, 1.02, 1.03]), + ) - pd.testing.assert_frame_equal(expected, - obtained, - check_like=True, - check_dtype=False) + pd.testing.assert_frame_equal(expected, obtained, check_like=True, check_dtype=False) def test_empty_presentationwise_spike_times(spike_times_api): # Test that when there are no spikes presentationwise_spike_times # doesn't fail and instead returns a empty dataframe - spike_times_api.get_spike_times = types.MethodType(get_no_spikes_times, - spike_times_api) + spike_times_api.get_spike_times = types.MethodType(get_no_spikes_times, spike_times_api) session = EcephysSession(api=spike_times_api) - obtained = \ - session.presentationwise_spike_times( - session.stimulus_presentations.index.values, - session.units.index.values) + obtained = session.presentationwise_spike_times( + session.stimulus_presentations.index.values, session.units.index.values + ) assert isinstance(obtained, pd.DataFrame) assert obtained.empty @@ -474,11 +449,9 @@ def test_empty_presentationwise_spike_times(spike_times_api): def test_conditionwise_spike_statistics(spike_times_api): session = EcephysSession(api=spike_times_api) - obtained = \ - session.conditionwise_spike_statistics( - stimulus_presentation_ids=[0, 1, 2]) + obtained = session.conditionwise_spike_statistics(stimulus_presentation_ids=[0, 1, 2]) - pd.set_option('display.max_columns', None) + pd.set_option("display.max_columns", None) assert obtained.loc[(2, 2), "spike_count"] == 3 assert obtained.loc[(2, 2), "stimulus_presentation_count"] == 1 @@ -486,38 +459,31 @@ def test_conditionwise_spike_statistics(spike_times_api): def test_conditionwise_spike_statistics_using_rates(spike_times_api): session = EcephysSession(api=spike_times_api) - obtained = \ - session.conditionwise_spike_statistics( - stimulus_presentation_ids=[0, 1, 2], use_rates=True) + obtained = session.conditionwise_spike_statistics(stimulus_presentation_ids=[0, 1, 2], use_rates=True) - pd.set_option('display.max_columns', None) + pd.set_option("display.max_columns", None) assert np.allclose([0, 0, 6], obtained["spike_mean"].values) def test_empty_conditionwise_spike_statistics(spike_times_api): # special case when there are no spikes - spike_times_api.get_spike_times = \ - types.MethodType(get_no_spikes_times, spike_times_api) + spike_times_api.get_spike_times = types.MethodType(get_no_spikes_times, spike_times_api) session = EcephysSession(api=spike_times_api) obtained = session.conditionwise_spike_statistics( - stimulus_presentation_ids=session.stimulus_presentations.index.values, - unit_ids=session.units.index.values + stimulus_presentation_ids=session.stimulus_presentations.index.values, unit_ids=session.units.index.values ) assert len(obtained) == 12 - assert not np.any(obtained['spike_count']) # check all spike_counts are 0 - assert not np.any(obtained['spike_mean']) # spike_means are 0 - assert np.all(np.isnan(obtained['spike_std'])) # std/sem is undefined - assert np.all(np.isnan(obtained['spike_sem'])) + assert not np.any(obtained["spike_count"]) # check all spike_counts are 0 + assert not np.any(obtained["spike_mean"]) # spike_means are 0 + assert np.all(np.isnan(obtained["spike_std"])) # std/sem is undefined + assert np.all(np.isnan(obtained["spike_sem"])) def test_get_stimulus_parameter_values(just_stim_table_api): session = EcephysSession(api=just_stim_table_api) obtained = session.get_stimulus_parameter_values() - expected = { - 'color': [0, 5.5, 11, 16.5], - 'phase': [0, 60, 120, 180] - } + expected = {"color": [0, 5.5, 11, 16.5], "phase": [0, 60, 120, 180]} for k, v in expected.items(): assert np.allclose(v, obtained[k]) @@ -525,89 +491,69 @@ def test_get_stimulus_parameter_values(just_stim_table_api): @pytest.mark.parametrize("detailed", [True, False]) -def test_get_stimulus_table(detailed, - just_stim_table_api, - raw_stimulus_table): +def test_get_stimulus_table(detailed, just_stim_table_api, raw_stimulus_table): session = EcephysSession(api=just_stim_table_api) - obtained = session.get_stimulus_table( - ['a'], - include_detailed_parameters=detailed) - - expected_columns = ['start_time', - 'stop_time', - 'stimulus_name', - 'stimulus_block', - 'Color', - 'Phase'] + obtained = session.get_stimulus_table(["a"], include_detailed_parameters=detailed) + + expected_columns = ["start_time", "stop_time", "stimulus_name", "stimulus_block", "Color", "Phase"] if detailed: expected_columns.append("texRes") expected = raw_stimulus_table.loc[:2, expected_columns] - expected['duration'] = expected['stop_time'] - expected['start_time'] + expected["duration"] = expected["stop_time"] - expected["start_time"] expected["stimulus_condition_id"] = [0, 1, 2] expected.rename(columns={"Color": "color", "Phase": "phase"}, inplace=True) print(expected) print(obtained) - pd.testing.assert_frame_equal(expected, - obtained, - check_like=True, - check_dtype=False) + pd.testing.assert_frame_equal(expected, obtained, check_like=True, check_dtype=False) def test_filter_owned_df(just_stim_table_api): session = EcephysSession(api=just_stim_table_api) ids = [0, 2] - obtained = session._filter_owned_df('stimulus_presentations', ids) + obtained = session._filter_owned_df("stimulus_presentations", ids) - assert np.allclose([0, 120], obtained['phase'].values) + assert np.allclose([0, 120], obtained["phase"].values) def test_filter_owned_df_scalar(just_stim_table_api): session = EcephysSession(api=just_stim_table_api) ids = 3 - obtained = session._filter_owned_df('stimulus_presentations', ids) - assert obtained['phase'].values[0] == 180 + obtained = session._filter_owned_df("stimulus_presentations", ids) + assert obtained["phase"].values[0] == 180 def test_build_inter_presentation_intervals(just_stim_table_api): session = EcephysSession(api=just_stim_table_api) obtained = session.inter_presentation_intervals - expected = pd.DataFrame({ - 'interval': [0, 0, 0] - }, index=pd.MultiIndex( + expected = pd.DataFrame( + {"interval": [0, 0, 0]}, + index=pd.MultiIndex( levels=[[0, 1, 2], [1, 2, 3]], codes=[[0, 1, 2], [0, 1, 2]], - names=['from_presentation_id', 'to_presentation_id'] - ) + names=["from_presentation_id", "to_presentation_id"], + ), ) - pd.testing.assert_frame_equal(expected, - obtained, - check_like=True, - check_dtype=False) + pd.testing.assert_frame_equal(expected, obtained, check_like=True, check_dtype=False) def test_get_inter_presentation_intervals_for_stimulus(just_stim_table_api): session = EcephysSession(api=just_stim_table_api) - obtained = session.get_inter_presentation_intervals_for_stimulus('a') - - expected = pd.DataFrame({ - 'interval': [0, 0] - }, index=pd.MultiIndex( - levels=[[0, 1], [1, 2]], - codes=[[0, 1], [0, 1]], - names=['from_presentation_id', 'to_presentation_id'] - ) + obtained = session.get_inter_presentation_intervals_for_stimulus("a") + + expected = pd.DataFrame( + {"interval": [0, 0]}, + index=pd.MultiIndex( + levels=[[0, 1], [1, 2]], codes=[[0, 1], [0, 1]], names=["from_presentation_id", "to_presentation_id"] + ), ) - pd.testing.assert_frame_equal(expected, - obtained, - check_like=True, - check_dtype=False) + pd.testing.assert_frame_equal(expected, obtained, check_like=True, check_dtype=False) def test_get_lfp(channels_table_api): @@ -615,10 +561,9 @@ def test_get_lfp(channels_table_api): obtained = session.get_lfp(0) expected = xr.DataArray( - data=np.array([[1, 2, 3, 4, 5], - [6, 7, 8, 9, 10]]), - dims=['channel', 'time'], - coords=[[2, 1], np.linspace(0, 2, 5)] + data=np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), + dims=["channel", "time"], + coords=[[2, 1], np.linspace(0, 2, 5)], ) xr.testing.assert_equal(expected, obtained) @@ -629,10 +574,9 @@ def test_get_lfp_mask_invalid(lfp_masking_api): obtained = session.get_lfp(0) expected = xr.DataArray( - data=np.array([[1, 2, 3, np.nan, np.nan], - [6, 7, 8, np.nan, np.nan]]), - dims=['channel', 'time'], - coords=[[2, 1], np.linspace(0, 2, 5)] + data=np.array([[1, 2, 3, np.nan, np.nan], [6, 7, 8, np.nan, np.nan]]), + dims=["channel", "time"], + coords=[[2, 1], np.linspace(0, 2, 5)], ) print(expected) print(obtained) @@ -640,10 +584,6 @@ def test_get_lfp_mask_invalid(lfp_masking_api): xr.testing.assert_equal(expected, obtained) -@pytest.mark.parametrize("inp,expected", [ - [[np.nan, np.nan, 4, 4, 4, 5, 5], [0, 2, 5, 7]] -]) +@pytest.mark.parametrize("inp,expected", [[[np.nan, np.nan, 4, 4, 4, 5, 5], [0, 2, 5, 7]]]) def test_nan_intervals(inp, expected): - assert np.allclose( - expected, nan_intervals(inp) - ) + assert np.allclose(expected, nan_intervals(inp)) diff --git a/allensdk/test/brain_observatory/ecephys/test_ecephys_session_nwb_api.py b/allensdk/test/brain_observatory/ecephys/test_ecephys_session_nwb_api.py index 42dcc0f7dc..157ee7fa94 100644 --- a/allensdk/test/brain_observatory/ecephys/test_ecephys_session_nwb_api.py +++ b/allensdk/test/brain_observatory/ecephys/test_ecephys_session_nwb_api.py @@ -6,26 +6,25 @@ import allensdk.brain_observatory.ecephys.utils -@pytest.mark.parametrize("left,right,expected,left_on,right_on", [ +@pytest.mark.parametrize( + "left,right,expected,left_on,right_on", [ - pd.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}), - pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}), - pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}), - "a", - "a" + [ + pd.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}), + pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}), + pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}), + "a", + "a", + ], + [ + pd.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}), + pd.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3], "c": [7, 8, 9]}), + pd.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3], "c": [7, 8, 9]}), + ["a", "b"], + ["a", "b"], + ], ], - [ - pd.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}), - pd.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3], "c": [7, 8, 9]}), - pd.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3], "c": [7, 8, 9]}), - ["a", "b"], - ["a", "b"] - ] -]) +) def test_clobbering_merge(left, right, expected, left_on, right_on): - obtained = allensdk.brain_observatory.ecephys.utils.clobbering_merge( - left, - right, - left_on=left_on, - right_on=left_on) + obtained = allensdk.brain_observatory.ecephys.utils.clobbering_merge(left, right, left_on=left_on, right_on=left_on) pd.testing.assert_frame_equal(expected, obtained, check_like=True) diff --git a/allensdk/test/brain_observatory/ecephys/test_ecephys_sync_dataset.py b/allensdk/test/brain_observatory/ecephys/test_ecephys_sync_dataset.py index bc693c4010..a6761d1bc8 100644 --- a/allensdk/test/brain_observatory/ecephys/test_ecephys_sync_dataset.py +++ b/allensdk/test/brain_observatory/ecephys/test_ecephys_sync_dataset.py @@ -6,30 +6,34 @@ from allensdk.brain_observatory.ecephys.file_io.ecephys_sync_dataset import EcephysSyncDataset -@pytest.mark.parametrize('expected', [1, None]) +@pytest.mark.parametrize("expected", [1, None]) def test_sample_frequency(expected): dataset = EcephysSyncDataset() - dataset.meta_data = {'ni_daq': {}} + dataset.meta_data = {"ni_daq": {}} dataset.sample_frequency = expected assert dataset.sample_frequency == expected - assert dataset.sample_frequency == dataset.meta_data['ni_daq']['counter_output_freq'] + assert dataset.sample_frequency == dataset.meta_data["ni_daq"]["counter_output_freq"] -@pytest.mark.parametrize('key,line_labels,led_vals', [ - [ 'foo', ('LED_sync',), np.array([1, 2, 3]) ], - [ 'LED_sync', ('LED_sync',), np.array([1, 2, 3]) ], -]) +@pytest.mark.parametrize( + "key,line_labels,led_vals", + [ + ["foo", ("LED_sync",), np.array([1, 2, 3])], + ["LED_sync", ("LED_sync",), np.array([1, 2, 3])], + ], +) def test_extract_led_times(key, line_labels, led_vals): - dataset = EcephysSyncDataset() dataset.line_labels = line_labels dataset.sample_frequency = 1000 - with mock.patch('allensdk.brain_observatory.sync_dataset.Dataset.get_all_times', return_value=led_vals): - with mock.patch("allensdk.brain_observatory.sync_dataset.Dataset.get_bit_changes", return_value=np.ones_like(led_vals)) as q: + with mock.patch("allensdk.brain_observatory.sync_dataset.Dataset.get_all_times", return_value=led_vals): + with mock.patch( + "allensdk.brain_observatory.sync_dataset.Dataset.get_bit_changes", return_value=np.ones_like(led_vals) + ) as q: obtained = dataset.extract_led_times(key) - + if key in line_labels: q.assert_called_once_with(0) else: @@ -38,32 +42,36 @@ def test_extract_led_times(key, line_labels, led_vals): assert np.allclose(obtained, led_vals) -@pytest.mark.parametrize('photodiode_times,vsyncs,cycle,expected', [ - [ # expected timing, using vsyncs - np.arange(5.0, 5 + (100 * 0.75), 0.75), - np.arange(5.0, 5 + (298 * 0.25), 0.25) - 0.0625 * np.random.rand(298), # num frames is (num_vsyncs - 1) * cycle + 1 - 3, - np.arange(5.0, 5 + (298 * 0.25), 0.25) - ] -]) +@pytest.mark.parametrize( + "photodiode_times,vsyncs,cycle,expected", + [ + [ # expected timing, using vsyncs + np.arange(5.0, 5 + (100 * 0.75), 0.75), + np.arange(5.0, 5 + (298 * 0.25), 0.25) + - 0.0625 * np.random.rand(298), # num frames is (num_vsyncs - 1) * cycle + 1 + 3, + np.arange(5.0, 5 + (298 * 0.25), 0.25), + ] + ], +) def test_extract_frame_times_from_photodiode(photodiode_times, vsyncs, cycle, expected): - class TimesWrapper: def __call__(self, ignore, keys): - if 'photodiode' in keys: + if "photodiode" in keys: return photodiode_times - elif 'frames' in keys: + elif "frames" in keys: return vsyncs dataset = EcephysSyncDataset() - with mock.patch('allensdk.brain_observatory.ecephys.file_io.ecephys_sync_dataset.EcephysSyncDataset.get_edges', new_callable=TimesWrapper): + with mock.patch( + "allensdk.brain_observatory.ecephys.file_io.ecephys_sync_dataset.EcephysSyncDataset.get_edges", + new_callable=TimesWrapper, + ): obtained = dataset.extract_frame_times_from_photodiode(photodiode_cycle=cycle) assert np.allclose(obtained, expected) - def test_factory(): - - with mock.patch('allensdk.brain_observatory.sync_dataset.Dataset.load') as p: - EcephysSyncDataset.factory('foo') - p.assert_called_with('foo') \ No newline at end of file + with mock.patch("allensdk.brain_observatory.sync_dataset.Dataset.load") as p: + EcephysSyncDataset.factory("foo") + p.assert_called_with("foo") diff --git a/allensdk/test/brain_observatory/ecephys/test_ecephys_utils.py b/allensdk/test/brain_observatory/ecephys/test_ecephys_utils.py index 983892157e..efdf3b68c5 100644 --- a/allensdk/test/brain_observatory/ecephys/test_ecephys_utils.py +++ b/allensdk/test/brain_observatory/ecephys/test_ecephys_utils.py @@ -1,7 +1,6 @@ import pytest import numpy as np -from allensdk.brain_observatory.ecephys.utils import ( - strip_substructure_acronym) +from allensdk.brain_observatory.ecephys.utils import strip_substructure_acronym def test_strip_substructure_acronym(): @@ -9,17 +8,15 @@ def test_strip_substructure_acronym(): Test that strip_substructure_acronym behaves properly """ - assert strip_substructure_acronym('abcde-fg-hi') == 'abcde' + assert strip_substructure_acronym("abcde-fg-hi") == "abcde" assert strip_substructure_acronym(None) is None - data = ['DG-mo', 'DG-pd', 'LS-ab', 'LT-x', 'AB-cd', - 'WX-yz', 'AB-ef'] - expected = ['AB', 'DG', 'LS', 'LT', 'WX'] + data = ["DG-mo", "DG-pd", "LS-ab", "LT-x", "AB-cd", "WX-yz", "AB-ef"] + expected = ["AB", "DG", "LS", "LT", "WX"] assert strip_substructure_acronym(data) == expected - data = [None, 'DG-mo', 'DG-pd', 'LS-ab', 'LT-x', 'AB-cd', - 'WX-yz', None, 'AB-ef', np.nan] - expected = ['AB', 'DG', 'LS', 'LT', 'WX'] + data = [None, "DG-mo", "DG-pd", "LS-ab", "LT-x", "AB-cd", "WX-yz", None, "AB-ef", np.nan] + expected = ["AB", "DG", "LS", "LT", "WX"] assert strip_substructure_acronym(data) == expected assert strip_substructure_acronym([None]) == [] @@ -29,7 +26,7 @@ def test_strip_substructure_acronym(): # pass in a tuple; check that it fails since that is not # a str or a list with pytest.raises(RuntimeError, match="list or a str"): - strip_substructure_acronym(('a', 'b', 'c')) + strip_substructure_acronym(("a", "b", "c")) with pytest.raises(RuntimeError, match="list or a str"): - strip_substructure_acronym(['abc', 2.3]) + strip_substructure_acronym(["abc", 2.3]) diff --git a/allensdk/test/brain_observatory/ecephys/test_http_engine.py b/allensdk/test/brain_observatory/ecephys/test_http_engine.py index 8b2bb34337..5397656152 100644 --- a/allensdk/test/brain_observatory/ecephys/test_http_engine.py +++ b/allensdk/test/brain_observatory/ecephys/test_http_engine.py @@ -4,44 +4,32 @@ import requests import pytest -from allensdk.brain_observatory.ecephys.ecephys_project_api import ( - http_engine -) +from allensdk.brain_observatory.ecephys.ecephys_project_api import http_engine class MockResponse: - @property def headers(self): - return {"Content-length": 10 * 1024 ** 2} - + return {"Content-length": 10 * 1024**2} + def iter_content(self, chunksize): for ii in range(5): yield f"{ii}_{chunksize}_".encode() def test_stream(): - engine = http_engine.HttpEngine( - scheme="http", - host="api.brain-map.org/api/v2" - ) + engine = http_engine.HttpEngine(scheme="http", host="api.brain-map.org/api/v2") with mock.patch("requests.get", return_value=MockResponse()) as p: - results = [item for item in engine.stream("fish")] - p.assert_called_once_with( - "http://api.brain-map.org/api/v2/fish", stream=True - ) + p.assert_called_once_with("http://api.brain-map.org/api/v2/fish", stream=True) assert f"3_{engine.chunksize}_" == results[3].decode() + def test_stream_timeout(): - engine = http_engine.HttpEngine( - scheme="http", - host="api.brain-map.org/api/v2", - timeout=0 - ) + engine = http_engine.HttpEngine(scheme="http", host="api.brain-map.org/api/v2", timeout=0) with mock.patch("requests.get", return_value=MockResponse()): with pytest.raises(requests.Timeout): @@ -50,21 +38,15 @@ def test_stream_timeout(): def test_stream_to_file(tmpdir_factory): - tmpdir = str(tmpdir_factory.mktemp("stream_test")) path = os.path.join(tmpdir, "look_at_this_file") - engine = http_engine.HttpEngine( - scheme="http", - host="api.brain-map.org/api/v2", - chunksize="hi" - ) + engine = http_engine.HttpEngine(scheme="http", host="api.brain-map.org/api/v2", chunksize="hi") with mock.patch("requests.get", return_value=MockResponse()): - stream = engine.stream("fish") http_engine.write_from_stream(path, stream) - + with open(path, "r") as fil: assert "0_hi_1_hi_2_hi_3_hi_4_hi_" == fil.read() @@ -75,7 +57,6 @@ def get(self, url): class MockAsyncResponse: - async def __aenter__(self): return self @@ -91,18 +72,13 @@ def content(self): class MockAsyncContent: - async def iter_chunked(self, chunksize): for ii in range(10): yield (f"{ii}".encode()) def test_async_stream_to_file(tmpdir_factory): - engine = http_engine.AsyncHttpEngine( - scheme="http", - host="api.brain.map.org/api/v2", - session=MockAsyncSession() - ) + engine = http_engine.AsyncHttpEngine(scheme="http", host="api.brain.map.org/api/v2", session=MockAsyncSession()) tmpdir = str(tmpdir_factory.mktemp("async_stream_test")) path = os.path.join(tmpdir, "one_two.three") @@ -111,5 +87,4 @@ def test_async_stream_to_file(tmpdir_factory): http_engine.write_bytes_from_coroutine(path, stream) with open(path, "r") as fil: - assert "0123456789" == fil.read() - + assert "0123456789" == fil.read() diff --git a/allensdk/test/brain_observatory/ecephys/test_lfp_subsampling.py b/allensdk/test/brain_observatory/ecephys/test_lfp_subsampling.py index b66939f8c2..2be2606da7 100644 --- a/allensdk/test/brain_observatory/ecephys/test_lfp_subsampling.py +++ b/allensdk/test/brain_observatory/ecephys/test_lfp_subsampling.py @@ -6,63 +6,67 @@ import allensdk.brain_observatory.ecephys.lfp_subsampling.subsampling as subsampling -@pytest.mark.parametrize('total_channels', [100, 384]) -@pytest.mark.parametrize('surface_offset', [-20, -50]) -@pytest.mark.parametrize('surface_padding', [10, 20]) -@pytest.mark.parametrize('start_channel_offset', [0, 1, 2]) -@pytest.mark.parametrize('channel_stride', [1, 2, 4, 10]) +@pytest.mark.parametrize("total_channels", [100, 384]) +@pytest.mark.parametrize("surface_offset", [-20, -50]) +@pytest.mark.parametrize("surface_padding", [10, 20]) +@pytest.mark.parametrize("start_channel_offset", [0, 1, 2]) +@pytest.mark.parametrize("channel_stride", [1, 2, 4, 10]) def test_select_channels(total_channels, surface_offset, surface_padding, start_channel_offset, channel_stride): input_channels = np.arange(start_channel_offset, total_channels + surface_offset + surface_padding) - selected, actual = subsampling.select_channels(total_channels=total_channels, - surface_channel=total_channels + surface_offset, - surface_padding=surface_padding, - start_channel_offset=start_channel_offset, - channel_stride=channel_stride, - channel_order=np.arange(total_channels)) + selected, actual = subsampling.select_channels( + total_channels=total_channels, + surface_channel=total_channels + surface_offset, + surface_padding=surface_padding, + start_channel_offset=start_channel_offset, + channel_stride=channel_stride, + channel_order=np.arange(total_channels), + ) assert np.allclose(selected, actual) assert len(selected) == len(input_channels[::channel_stride]) -@pytest.mark.parametrize('remove_references', [True, False]) -@pytest.mark.parametrize('reference_channels', [np.array([0, 1, 2]), np.array([9, 10, 11]), np.array([10, 11, 12])]) -@pytest.mark.parametrize('remove_noisy_channels', [True, False]) -@pytest.mark.parametrize('noisy_channels', [np.array([10, 11, 12])]) +@pytest.mark.parametrize("remove_references", [True, False]) +@pytest.mark.parametrize("reference_channels", [np.array([0, 1, 2]), np.array([9, 10, 11]), np.array([10, 11, 12])]) +@pytest.mark.parametrize("remove_noisy_channels", [True, False]) +@pytest.mark.parametrize("noisy_channels", [np.array([10, 11, 12])]) def test_select_channels_filtered(remove_references, reference_channels, remove_noisy_channels, noisy_channels): - """Similar to test above but focused on ability to remove reference """ + """Similar to test above but focused on ability to remove reference""" total_channels = 100 surface_offset = -20 start_channel_offset = 0 channel_stride = 1 surface_padding = 10 - selected, actual = subsampling.select_channels(total_channels=total_channels, - surface_channel=total_channels + surface_offset, - surface_padding=surface_padding, - start_channel_offset=start_channel_offset, - channel_stride=channel_stride, - channel_order=np.arange(total_channels), - noisy_channels=noisy_channels, - remove_noisy_channels=remove_noisy_channels, - reference_channels=reference_channels, - remove_references=remove_references) + selected, actual = subsampling.select_channels( + total_channels=total_channels, + surface_channel=total_channels + surface_offset, + surface_padding=surface_padding, + start_channel_offset=start_channel_offset, + channel_stride=channel_stride, + channel_order=np.arange(total_channels), + noisy_channels=noisy_channels, + remove_noisy_channels=remove_noisy_channels, + reference_channels=reference_channels, + remove_references=remove_references, + ) assert np.allclose(selected, actual) removed_channels = set() if remove_noisy_channels: - assert(not np.any(np.isin(noisy_channels, selected))) + assert not np.any(np.isin(noisy_channels, selected)) removed_channels |= set(noisy_channels) if remove_references: - assert(not np.any(np.isin(reference_channels, selected))) + assert not np.any(np.isin(reference_channels, selected)) removed_channels |= set(reference_channels) input_channels = np.arange(start_channel_offset, total_channels + surface_offset + surface_padding) - assert(len(selected) == len(input_channels) - len(removed_channels)) + assert len(selected) == len(input_channels) - len(removed_channels) -@pytest.mark.parametrize('array_length', [50]) # , 150, 2001]) -@pytest.mark.parametrize('subsampling_factor', [1]) # , 2, 4, 10]) +@pytest.mark.parametrize("array_length", [50]) # , 150, 2001]) +@pytest.mark.parametrize("subsampling_factor", [1]) # , 2, 4, 10]) def test_subsample_timestamps(subsampling_factor, array_length): timestamps = np.linspace(0, 50, array_length) ts_subsampled = subsampling.subsample_timestamps(timestamps, subsampling_factor) @@ -96,12 +100,13 @@ def test_remove_lfp_noise(): assert np.array_equal(np.unique(lfp_noise_removed), np.array([-1, 0])) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() - logging.getLogger('ecephys_pipeline.modules.lfp_subsampling').setLevel(logging.INFO) + logging.getLogger("ecephys_pipeline.modules.lfp_subsampling").setLevel(logging.INFO) for tc, so, sp, sco, cs in itertools.product([100, 384], [-20, -50], [10, 20], [0, 1, 2], [1, 2, 4, 10]): - test_select_channels(total_channels=tc, surface_offset=so, surface_padding=sp, start_channel_offset=sco, - channel_stride=cs) + test_select_channels( + total_channels=tc, surface_offset=so, surface_padding=sp, start_channel_offset=sco, channel_stride=cs + ) test_subsample_timestamps(subsampling_factor=1, array_length=50) test_subsample_lfp() test_remove_lfp_offset() diff --git a/allensdk/test/brain_observatory/ecephys/test_optotagging_table.py b/allensdk/test/brain_observatory/ecephys/test_optotagging_table.py index 0e8cec7647..a6dd22715e 100644 --- a/allensdk/test/brain_observatory/ecephys/test_optotagging_table.py +++ b/allensdk/test/brain_observatory/ecephys/test_optotagging_table.py @@ -3,30 +3,23 @@ from allensdk.brain_observatory.ecephys.optotagging import OptotaggingTable -@pytest.fixture(scope='module') -def optotagging_fixture( - behavior_ecephys_session_config_fixture): +@pytest.fixture(scope="module") +def optotagging_fixture(behavior_ecephys_session_config_fixture): """ Return an OptotaggingTable """ - obj = OptotaggingTable.from_json( - dict_repr=behavior_ecephys_session_config_fixture) + obj = OptotaggingTable.from_json(dict_repr=behavior_ecephys_session_config_fixture) return obj @pytest.mark.requires_bamboo -@pytest.mark.parametrize('roundtrip', [True, False]) -def test_read_write_nwb(roundtrip, - data_object_roundtrip_fixture, - optotagging_fixture, - helper_functions): +@pytest.mark.parametrize("roundtrip", [True, False]) +def test_read_write_nwb(roundtrip, data_object_roundtrip_fixture, optotagging_fixture, helper_functions): nwbfile = helper_functions.create_blank_nwb_file() optotagging_fixture.to_nwb(nwbfile=nwbfile) if roundtrip: - obt = data_object_roundtrip_fixture( - nwbfile=nwbfile, - data_object_cls=OptotaggingTable) + obt = data_object_roundtrip_fixture(nwbfile=nwbfile, data_object_cls=OptotaggingTable) else: obt = OptotaggingTable.from_nwb(nwbfile=nwbfile) diff --git a/allensdk/test/brain_observatory/ecephys/test_probes.py b/allensdk/test/brain_observatory/ecephys/test_probes.py index b4bbba8de2..ec72eb022e 100644 --- a/allensdk/test/brain_observatory/ecephys/test_probes.py +++ b/allensdk/test/brain_observatory/ecephys/test_probes.py @@ -6,27 +6,23 @@ from allensdk.brain_observatory.ecephys.write_nwb.schemas import Probe -@pytest.fixture(scope='module') -def probes_config_fixture( - behavior_ecephys_session_config_fixture): +@pytest.fixture(scope="module") +def probes_config_fixture(behavior_ecephys_session_config_fixture): """ Return config data for a Probes object """ - input_data = copy.deepcopy( - behavior_ecephys_session_config_fixture) + input_data = copy.deepcopy(behavior_ecephys_session_config_fixture) # trim down the number of probes to reduce memory footprint of test - input_data['probes'] = ( - input_data['probes'][:3]) + input_data["probes"] = input_data["probes"][:3] - input_data = input_data['probes'] + input_data = input_data["probes"] return input_data -@pytest.fixture(scope='module') -def probes_fixture( - probes_config_fixture): +@pytest.fixture(scope="module") +def probes_fixture(probes_config_fixture): """ Return a Probes object """ @@ -34,24 +30,19 @@ def probes_fixture( # Don't test lfp here for probe in probes: - probe['lfp'] = None + probe["lfp"] = None obj = Probes.from_json(probes=probes) return obj @pytest.mark.requires_bamboo -@pytest.mark.parametrize('roundtrip', [True, False]) -def test_read_write_nwb(roundtrip, - data_object_roundtrip_fixture, - probes_fixture, - helper_functions): +@pytest.mark.parametrize("roundtrip", [True, False]) +def test_read_write_nwb(roundtrip, data_object_roundtrip_fixture, probes_fixture, helper_functions): nwbfile = helper_functions.create_blank_nwb_file() probes_fixture.to_nwb(nwbfile=nwbfile) if roundtrip: - obt = data_object_roundtrip_fixture( - nwbfile=nwbfile, - data_object_cls=Probes) + obt = data_object_roundtrip_fixture(nwbfile=nwbfile, data_object_cls=Probes) else: obt = Probes.from_nwb(nwbfile=nwbfile) @@ -59,20 +50,16 @@ def test_read_write_nwb(roundtrip, @pytest.mark.requires_bamboo -def test_skip_probes( - probes_config_fixture): +def test_skip_probes(probes_config_fixture): """tests that when skip_probes is passed that the probe is skipped""" - names = [p['name'] for p in probes_config_fixture] + names = [p["name"] for p in probes_config_fixture] skip_probes = [names[0]] - probes = Probes.from_json( - probes=probes_config_fixture, skip_probes=skip_probes) - assert sorted([p.name for p in probes]) == \ - sorted([p for p in names if p not in skip_probes]) + probes = Probes.from_json(probes=probes_config_fixture, skip_probes=skip_probes) + assert sorted([p.name for p in probes]) == sorted([p for p in names if p not in skip_probes]) @pytest.mark.requires_bamboo -def test_units_from_structure_with_acronym( - probes_fixture): +def test_units_from_structure_with_acronym(probes_fixture): """Checks that if there are channels with subregion in manual structure id, that units detected from this region are still included in units table""" @@ -81,17 +68,16 @@ def test_units_from_structure_with_acronym( # Set the _structure_acronym to something with a hyphen for probe in probes_fixture.probes: for channel in probe.channels.value: - if channel._structure_acronym == 'MGd': - channel._structure_acronym = 'MGd-foo' + if channel._structure_acronym == "MGd": + channel._structure_acronym = "MGd-foo" obtained_n_units = probes_fixture.get_units_table().shape[0] assert expected_n_units == obtained_n_units -@pytest.mark.parametrize('structure_acronym', ('LGd-sh', 'LGd', None)) -@pytest.mark.parametrize('strip_structure_subregion', (True, False)) -def test_probe_channels_strip_subregion( - structure_acronym, strip_structure_subregion): +@pytest.mark.parametrize("structure_acronym", ("LGd-sh", "LGd", None)) +@pytest.mark.parametrize("strip_structure_subregion", (True, False)) +def test_probe_channels_strip_subregion(structure_acronym, strip_structure_subregion): """Tests that subregion is stripped from manual structure acronym""" c = Channel( id=1, @@ -101,14 +87,13 @@ def test_probe_channels_strip_subregion( probe_id=1, valid_data=True, structure_acronym=structure_acronym, - strip_structure_subregion=strip_structure_subregion + strip_structure_subregion=strip_structure_subregion, ) if type(structure_acronym) is str: if strip_structure_subregion: - expected = 'LGd' + expected = "LGd" else: - expected = 'LGd-sh' if structure_acronym == 'LGd-sh' \ - else 'LGd' + expected = "LGd-sh" if structure_acronym == "LGd-sh" else "LGd" assert c.structure_acronym == expected else: c.structure_acronym is None diff --git a/allensdk/test/brain_observatory/ecephys/test_rma_engine.py b/allensdk/test/brain_observatory/ecephys/test_rma_engine.py index 1f29a2b9ec..4b12173724 100644 --- a/allensdk/test/brain_observatory/ecephys/test_rma_engine.py +++ b/allensdk/test/brain_observatory/ecephys/test_rma_engine.py @@ -5,28 +5,18 @@ import allensdk.brain_observatory.ecephys.ecephys_project_api.rma_engine as rma_engine -@pytest.mark.parametrize("dataframe,expected_types", [ +@pytest.mark.parametrize( + "dataframe,expected_types", [ - pd.DataFrame({ - "a": ["1", "2", "3"], - "b": ["a", "1", "2"] - }), - {"a": np.dtype("int64"), "b": np.dtype("O")} + [pd.DataFrame({"a": ["1", "2", "3"], "b": ["a", "1", "2"]}), {"a": np.dtype("int64"), "b": np.dtype("O")}], + [pd.DataFrame({"a": ["1", "2.4", "3"], "b": ["a", "1", "2"]}), {"a": float, "b": np.dtype("O")}], ], - [ - pd.DataFrame({ - "a": ["1", "2.4", "3"], - "b": ["a", "1", "2"] - }), - {"a": float, "b": np.dtype("O")} - ] -]) +) def test_infer_column_types(dataframe, expected_types): - obtained = rma_engine.infer_column_types(dataframe) obtained_types = {colname: obtained[colname].dtype for colname in obtained.columns} - - assert(set(expected_types.keys()) == set(obtained_types.keys())) + + assert set(expected_types.keys()) == set(obtained_types.keys()) for key, value in expected_types.items(): assert np.dtype(value) == np.dtype(obtained_types[key]) diff --git a/allensdk/test/brain_observatory/ecephys/test_stim_file.py b/allensdk/test/brain_observatory/ecephys/test_stim_file.py index 8510b6fe22..22f6752d0e 100644 --- a/allensdk/test/brain_observatory/ecephys/test_stim_file.py +++ b/allensdk/test/brain_observatory/ecephys/test_stim_file.py @@ -11,41 +11,29 @@ # ideally these would be fixtures, but I want to parametrize over them def stim_pkl_data(): return { - 'fps': 1000, - 'pre_blank_sec': 20, - 'stimuli': [{'a': 1}, {'a': 1}], - 'items': { - 'foraging': { - 'encoders': [ - { - 'dx': [1, 2, 3] - } - ] - } - } + "fps": 1000, + "pre_blank_sec": 20, + "stimuli": [{"a": 1}, {"a": 1}], + "items": {"foraging": {"encoders": [{"dx": [1, 2, 3]}]}}, } def stim_pkl_data_toplevel_dx(): return { - 'fps': 1000, - 'pre_blank_sec': 20, - 'dx': [1, 2, 3], - 'stimuli': [{'a': 1}, {'a': 1}], - 'items': { - 'foraging': { - 'encoders': [] - } - } + "fps": 1000, + "pre_blank_sec": 20, + "dx": [1, 2, 3], + "stimuli": [{"a": 1}, {"a": 1}], + "items": {"foraging": {"encoders": []}}, } @pytest.fixture(params=[stim_pkl_data, stim_pkl_data_toplevel_dx]) def stim_pkl_on_disk(tmpdir_factory, request): - tmpdir = str(tmpdir_factory.mktemp('stim_files')) - file_path = os.path.join(tmpdir, 'stim.pkl') + tmpdir = str(tmpdir_factory.mktemp("stim_files")) + file_path = os.path.join(tmpdir, "stim.pkl") - with open(file_path, 'wb') as pkl_file: + with open(file_path, "wb") as pkl_file: pickle.dump(request.param(), pkl_file) return file_path @@ -56,12 +44,15 @@ def camstimone_pickle_stim_file(stim_pkl_on_disk): return stim_file.CamStimOnePickleStimFile.factory(stim_pkl_on_disk) -@pytest.mark.parametrize('prop_name,expected,comp', [ - ['frames_per_second', 1000, op.eq], - ['pre_blank_sec', 20, op.eq], - ['angular_wheel_rotation', [1, 2, 3], np.allclose], - ['angular_wheel_velocity', [1000, 2000, 3000], np.allclose] -]) +@pytest.mark.parametrize( + "prop_name,expected,comp", + [ + ["frames_per_second", 1000, op.eq], + ["pre_blank_sec", 20, op.eq], + ["angular_wheel_rotation", [1, 2, 3], np.allclose], + ["angular_wheel_velocity", [1000, 2000, 3000], np.allclose], + ], +) def test_properties(camstimone_pickle_stim_file, prop_name, expected, comp): obtained = getattr(camstimone_pickle_stim_file, prop_name) - assert comp(obtained, expected) \ No newline at end of file + assert comp(obtained, expected) diff --git a/allensdk/test/brain_observatory/ecephys/test_stimulus_sync.py b/allensdk/test/brain_observatory/ecephys/test_stimulus_sync.py index c285d248db..a534c5c8f2 100644 --- a/allensdk/test/brain_observatory/ecephys/test_stimulus_sync.py +++ b/allensdk/test/brain_observatory/ecephys/test_stimulus_sync.py @@ -7,172 +7,202 @@ # manual test cases for compute_frame_times, allocate_by_vsync, assign_to_last -@pytest.mark.parametrize('photodiode_times,frame_duration,num_frames,cycle,vsyncs,expected', [ - [ # super basic, no vsyncs, no bad frames - np.linspace(5, 30.0, 11), 0.25, 100, 10, None, - [ - np.arange(5, 30, 0.25), - np.arange(5.25, 30.25, 0.25) - ] - ], - [ # also no bad frames - np.array([5, 5.75, 6.5, 7.25]), 0.25, 9, 3, None, - [ - np.array([5, 5.25, 5.5, 5.75, 6.0, 6.25, 6.5, 6.75, 7.0]), - np.array([5.25, 5.5, 5.75, 6.0, 6.25, 6.5, 6.75, 7.0, 7.25 ]), - ] - ], - [ # now add in a long-short, using the append_to_last rule - np.array([5, 5.75, 6.75, 7.25, 8.0]), 0.25, 12, 3, None, - [ - np.array([5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.75, 7.00, 7.25, 7.25, 7.50, 7.75]), - np.array([5.25, 5.50, 5.75, 6.00, 6.25, 6.75, 7.00, 7.25, 7.25, 7.50, 7.75, 8.00]), - ] - ], - [ # expected timing, using vsyncs - np.array([5, 5.75, 6.5, 7.25, 8.0]), 0.25, 12, 3, - np.array([4.9 , 5.15, 5.4 , 5.65, 5.9 , 6.15, 6.4 , 6.65, 6.9 , 7.15, 7.4 , 7.65, 7.9]), - [ - np.array([5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 7.0, 7.25, 7.50, 7.75]), - np.array([5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 7.0, 7.25, 7.50, 7.75, 8.00]) - ] - ], - [ # classic extra frame case - np.array([5, 5.75, 6.5, 7.5, 8.25]), 0.25, 12, 3, - np.array([4.9 , 5.15, 5.4 , 5.65, 5.9 , 6.15, 6.4 , 6.65, 7.15, 7.4 , 7.65, 7.9 , 8.15]), - [ - np.array([5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 7.25, 7.50, 7.75, 8.00]), - np.array([5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 7.25, 7.50, 7.75, 8., 8.25]) - ] - ], - [ # long-short, using vsyncs - np.array([5, 5.75, 6.5, 7.50, 8.0]), 0.25, 12, 3, - np.array([4.9 , 5.15, 5.4 , 5.65, 5.9 , 6.15, 6.4 , 6.9 , 7.15, 7.4, 7.4, 7.65, 7.9]), - [ - np.array([5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 7.0, 7.25, 7.50, 7.50, 7.75]), - np.array([5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 7.0, 7.25, 7.50, 7.50, 7.75, 8.00]) - ] +@pytest.mark.parametrize( + "photodiode_times,frame_duration,num_frames,cycle,vsyncs,expected", + [ + [ # super basic, no vsyncs, no bad frames + np.linspace(5, 30.0, 11), + 0.25, + 100, + 10, + None, + [np.arange(5, 30, 0.25), np.arange(5.25, 30.25, 0.25)], + ], + [ # also no bad frames + np.array([5, 5.75, 6.5, 7.25]), + 0.25, + 9, + 3, + None, + [ + np.array([5, 5.25, 5.5, 5.75, 6.0, 6.25, 6.5, 6.75, 7.0]), + np.array([5.25, 5.5, 5.75, 6.0, 6.25, 6.5, 6.75, 7.0, 7.25]), + ], + ], + [ # now add in a long-short, using the append_to_last rule + np.array([5, 5.75, 6.75, 7.25, 8.0]), + 0.25, + 12, + 3, + None, + [ + np.array([5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.75, 7.00, 7.25, 7.25, 7.50, 7.75]), + np.array([5.25, 5.50, 5.75, 6.00, 6.25, 6.75, 7.00, 7.25, 7.25, 7.50, 7.75, 8.00]), + ], + ], + [ # expected timing, using vsyncs + np.array([5, 5.75, 6.5, 7.25, 8.0]), + 0.25, + 12, + 3, + np.array([4.9, 5.15, 5.4, 5.65, 5.9, 6.15, 6.4, 6.65, 6.9, 7.15, 7.4, 7.65, 7.9]), + [ + np.array([5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 7.0, 7.25, 7.50, 7.75]), + np.array([5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 7.0, 7.25, 7.50, 7.75, 8.00]), + ], + ], + [ # classic extra frame case + np.array([5, 5.75, 6.5, 7.5, 8.25]), + 0.25, + 12, + 3, + np.array([4.9, 5.15, 5.4, 5.65, 5.9, 6.15, 6.4, 6.65, 7.15, 7.4, 7.65, 7.9, 8.15]), + [ + np.array([5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 7.25, 7.50, 7.75, 8.00]), + np.array([5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 7.25, 7.50, 7.75, 8.0, 8.25]), + ], + ], + [ # long-short, using vsyncs + np.array([5, 5.75, 6.5, 7.50, 8.0]), + 0.25, + 12, + 3, + np.array([4.9, 5.15, 5.4, 5.65, 5.9, 6.15, 6.4, 6.9, 7.15, 7.4, 7.4, 7.65, 7.9]), + [ + np.array([5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 7.0, 7.25, 7.50, 7.50, 7.75]), + np.array([5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 7.0, 7.25, 7.50, 7.50, 7.75, 8.00]), + ], + ], + [ # only short, using vsyncs + np.array([5, 5.75, 6.5, 7.0, 7.75]), + 0.25, + 12, + 3, + np.array([4.9, 5.15, 5.4, 5.65, 5.9, 6.15, 6.4, 6.65, 6.65, 6.9, 7.15, 7.4, 7.65, 7.9]), + [ + np.array([5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 6.75, 7.0, 7.25, 7.50]), + np.array([5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 6.75, 7.0, 7.25, 7.50, 7.75]), + ], + ], ], - [ # only short, using vsyncs - np.array([5, 5.75, 6.5, 7.0, 7.75]), 0.25, 12, 3, - np.array([4.9 , 5.15, 5.4 , 5.65, 5.9 , 6.15, 6.4, 6.65, 6.65, 6.9, 7.15, 7.4 , 7.65, 7.9]), - [ - np.array([5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 6.75, 7.0, 7.25, 7.50]), - np.array([5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 6.75, 7.0, 7.25, 7.50, 7.75]) - ] - ], -]) +) def test_compute_frame_times(photodiode_times, frame_duration, num_frames, cycle, vsyncs, expected): - if vsyncs is not None: cb = partial(stimulus_sync.allocate_by_vsync, np.diff(vsyncs)) else: cb = stimulus_sync.assign_to_last - obt_indices, obt_starts, obt_ends = stimulus_sync.compute_frame_times(photodiode_times, frame_duration, num_frames, cycle, cb) - assert(np.allclose(obt_indices, np.arange(num_frames))) + obt_indices, obt_starts, obt_ends = stimulus_sync.compute_frame_times( + photodiode_times, frame_duration, num_frames, cycle, cb + ) + assert np.allclose(obt_indices, np.arange(num_frames)) assert np.allclose(obt_starts, expected[0]) assert np.allclose(obt_ends, expected[1]) -@pytest.mark.parametrize('process,pctiles', [ - [partial(np.random.rand, 1000), (5, 95)], - [partial(np.random.rand, 1000), (45, 55)] -]) +@pytest.mark.parametrize( + "process,pctiles", [[partial(np.random.rand, 1000), (5, 95)], [partial(np.random.rand, 1000), (45, 55)]] +) def test_trimmed_stats(process, pctiles): - data = np.sort(process()) true_mean = np.mean(data) true_std = np.std(data) - lower_missing = pctiles[0] + lower_missing = pctiles[0] upper_missing = 100 - pctiles[1] total_missing = lower_missing + upper_missing fraction_lower = lower_missing / total_missing - num_missing = (data.size * total_missing / 100) / ( 1 - total_missing / 100 ) - num_missing_lower = int(np.around( num_missing * fraction_lower )) - num_missing_upper = int(np.around( num_missing * ( 1 - fraction_lower ) )) + num_missing = (data.size * total_missing / 100) / (1 - total_missing / 100) + num_missing_lower = int(np.around(num_missing * fraction_lower)) + num_missing_upper = int(np.around(num_missing * (1 - fraction_lower))) - data = np.concatenate([ - data, - np.zeros(num_missing_lower) - 1000, - np.zeros(num_missing_upper) + 1000 - ]) + data = np.concatenate([data, np.zeros(num_missing_lower) - 1000, np.zeros(num_missing_upper) + 1000]) obt_mean, obt_std = stimulus_sync.trimmed_stats(data, pctiles=pctiles) assert obt_mean == true_mean assert obt_std == true_std -@pytest.mark.parametrize('pd_times,vs_times, expected', [ - [ [1, 2, 3, 4, 5], [1.8, 3, 4], [2, 3, 4] ] -]) +@pytest.mark.parametrize("pd_times,vs_times, expected", [[[1, 2, 3, 4, 5], [1.8, 3, 4], [2, 3, 4]]]) def test_trim_border_pulses(pd_times, vs_times, expected): obtained = stimulus_sync.trim_border_pulses(pd_times, vs_times) assert np.allclose(obtained, expected) -@pytest.mark.parametrize('base,effect', [ - [ np.arange(20, dtype=float), [0.25, -0.25] ], - [ np.arange(20, dtype=float), [0.25, -0.25] ], - # [ np.arange(20, dtype=float), [0.25, -0.4] ], misses for assymmetric cases - # [ np.arange(20, dtype=float), [0.4, -0.25] ] -]) +@pytest.mark.parametrize( + "base,effect", + [ + [np.arange(20, dtype=float), [0.25, -0.25]], + [np.arange(20, dtype=float), [0.25, -0.25]], + # [ np.arange(20, dtype=float), [0.25, -0.4] ], misses for assymmetric cases + # [ np.arange(20, dtype=float), [0.4, -0.25] ] + ], +) def test_correct_on_off_effects(base, effect): impacted = base.copy() impacted[::2] += effect[0] impacted[1::2] += effect[1] - obtained = stimulus_sync.correct_on_off_effects(impacted) assert np.allclose(base, obtained) -@pytest.mark.parametrize('pd_times,ndevs,expected_mask', [ - [ [1, 2, 3, 9, 10, 11, 12], 4, [1, 1, 0, 0, 1, 1, 1] ], - [ [1.03, 2.10, 2.99, 8.9, 10.0, 11.1, 11.98], 10, [1, 1, 0, 0, 1, 1, 1] ] -]) +@pytest.mark.parametrize( + "pd_times,ndevs,expected_mask", + [ + [[1, 2, 3, 9, 10, 11, 12], 4, [1, 1, 0, 0, 1, 1, 1]], + [[1.03, 2.10, 2.99, 8.9, 10.0, 11.1, 11.98], 10, [1, 1, 0, 0, 1, 1, 1]], + ], +) def test_flag_unexpected_edges(pd_times, ndevs, expected_mask): - obtained_mask = stimulus_sync.flag_unexpected_edges(pd_times, ndevs) assert np.allclose(obtained_mask, expected_mask) -@pytest.mark.parametrize('pd_times,ndevs,cycle,max_offset,expected', [ - [ [0, 1, 2, 3, 4, 9, 10, 11], 10, 60, 5, np.arange(12) ], - [ - np.concatenate([[0, 1, 2, 3, 3.95, 4, 4.1, 4.7, 9, 10, 11], np.arange(12, 1000)]), - 0.1, 60, 5, - np.arange(1000) - ] -]) +@pytest.mark.parametrize( + "pd_times,ndevs,cycle,max_offset,expected", + [ + [[0, 1, 2, 3, 4, 9, 10, 11], 10, 60, 5, np.arange(12)], + [ + np.concatenate([[0, 1, 2, 3, 3.95, 4, 4.1, 4.7, 9, 10, 11], np.arange(12, 1000)]), + 0.1, + 60, + 5, + np.arange(1000), + ], + ], +) def test_fix_unexpected_edges(pd_times, ndevs, cycle, max_offset, expected): obtained = stimulus_sync.fix_unexpected_edges(pd_times, ndevs, cycle, max_offset) assert np.allclose(obtained, expected) -@pytest.mark.parametrize('pd_times,cycle,expected', [ - [ [0, 1, 2, 3, 4, 5.1, 6, 7, 8], 1, 1] -]) +@pytest.mark.parametrize("pd_times,cycle,expected", [[[0, 1, 2, 3, 4, 5.1, 6, 7, 8], 1, 1]]) def test_estimate_frame_duration(pd_times, cycle, expected): obtained = stimulus_sync.estimate_frame_duration(pd_times, cycle) assert obtained == expected -@pytest.mark.parametrize('ends,frame_duration,irregularity,expected', [ - [ np.arange(20, dtype=float), 0.5, 1, np.concatenate([np.arange(19), [19.5]]) ] -]) +@pytest.mark.parametrize( + "ends,frame_duration,irregularity,expected", + [[np.arange(20, dtype=float), 0.5, 1, np.concatenate([np.arange(19), [19.5]])]], +) def test_assign_to_last(ends, frame_duration, irregularity, expected): _, obt_ends = stimulus_sync.assign_to_last(None, None, ends, frame_duration, irregularity, None) assert np.allclose(obt_ends, expected) -@pytest.mark.parametrize('vs_diff,index,starts,ends,frame_duration,irregularity,cycle,expected', [ - [ [1, 1, 1, 1, 2, 1, 1, 1, 1], 1, [5, 6, 7], [6, 7, 8], 1, 1, 3, [[5, 6, 8], [6, 8, 9]] ], - [ [1, 1, 1, 1, 0.5, 1, 1, 1, 1], 1, [5, 6, 7], [6, 7, 8], 1, -1, 3, [[5, 6, 6], [6, 6, 7]] ], -]) +@pytest.mark.parametrize( + "vs_diff,index,starts,ends,frame_duration,irregularity,cycle,expected", + [ + [[1, 1, 1, 1, 2, 1, 1, 1, 1], 1, [5, 6, 7], [6, 7, 8], 1, 1, 3, [[5, 6, 8], [6, 8, 9]]], + [[1, 1, 1, 1, 0.5, 1, 1, 1, 1], 1, [5, 6, 7], [6, 7, 8], 1, -1, 3, [[5, 6, 6], [6, 6, 7]]], + ], +) def test_allocate_by_vsync(vs_diff, index, starts, ends, frame_duration, irregularity, cycle, expected): - obt_starts, obt_ends = stimulus_sync.allocate_by_vsync(vs_diff, index, starts, ends, frame_duration, irregularity, cycle) + obt_starts, obt_ends = stimulus_sync.allocate_by_vsync( + vs_diff, index, starts, ends, frame_duration, irregularity, cycle + ) assert np.allclose(obt_starts, expected[0]) - assert np.allclose(obt_ends, expected[1]) \ No newline at end of file + assert np.allclose(obt_ends, expected[1]) diff --git a/allensdk/test/brain_observatory/ecephys/test_visualization.py b/allensdk/test/brain_observatory/ecephys/test_visualization.py index 099aa18f63..173276b6a0 100644 --- a/allensdk/test/brain_observatory/ecephys/test_visualization.py +++ b/allensdk/test/brain_observatory/ecephys/test_visualization.py @@ -1,14 +1,22 @@ import allensdk.brain_observatory.ecephys.visualization.__init__ as vis import pandas as pd + def test_raster_plot(): - spike_times = pd.DataFrame({ - 'unit_id': [2, 1, 2], - 'stimulus_presentation_id': [2, 2, 2, ], - 'time_since_stimulus_presentation_onset': [0.01, 0.02, 0.03] - }, index=pd.Index(name='spike_time', data=[1.01, 1.02, 1.03])) - + spike_times = pd.DataFrame( + { + "unit_id": [2, 1, 2], + "stimulus_presentation_id": [ + 2, + 2, + 2, + ], + "time_since_stimulus_presentation_onset": [0.01, 0.02, 0.03], + }, + index=pd.Index(name="spike_time", data=[1.01, 1.02, 1.03]), + ) + fig = vis.raster_plot(spike_times) ax = fig.get_axes()[0] - assert len(spike_times['unit_id'].unique()) == len(ax.collections) + assert len(spike_times["unit_id"].unique()) == len(ax.collections) diff --git a/allensdk/test/brain_observatory/ecephys/test_write_nwb.py b/allensdk/test/brain_observatory/ecephys/test_write_nwb.py index 3d2ba6f9c1..5d22f1c4cf 100644 --- a/allensdk/test/brain_observatory/ecephys/test_write_nwb.py +++ b/allensdk/test/brain_observatory/ecephys/test_write_nwb.py @@ -194,9 +194,7 @@ def test_add_metadata(nwbfile, roundtripper, metadata, expected_metadata): if obtained[key] != value: misses[key] = {"expected": value, "obtained": obtained[key]} - assert ( - len(misses) == 0 - ), f"the following metadata items were mismatched: {misses}" + assert len(misses) == 0, f"the following metadata items were mismatched: {misses}" @pytest.mark.parametrize( @@ -216,9 +214,7 @@ def test_add_metadata(nwbfile, roundtripper, metadata, expected_metadata): ], "stop_time": [2.0, 4.0, 5.0, 6.0, 8.0], }, - index=pd.Index( - name="stimulus_presentations_id", data=[0, 1, 2, 3, 4] - ), + index=pd.Index(name="stimulus_presentations_id", data=[0, 1, 2, 3, 4]), ) ), ( @@ -244,9 +240,7 @@ def test_add_metadata(nwbfile, roundtripper, metadata, expected_metadata): "stop_time": [2.0, 4.0, 5.0, 6.0, 8.0], "color": [np.nan] + ["[1.0, 1.0, 1.0]"] * 4, }, - index=pd.Index( - name="stimulus_presentations_id", data=[0, 1, 2, 3, 4] - ), + index=pd.Index(name="stimulus_presentations_id", data=[0, 1, 2, 3, 4]), ) ), ], @@ -263,17 +257,14 @@ def test_add_stimulus_presentations(nwbfile, presentations, roundtripper): presentations.value["color_triplet"] = [""] + ["[1.0, 1.0, 1.0]"] * 4 presentations.value["color"] = "" expected_df = presentations.value[sorted(presentations.value.columns)] - obtained_df = obtained_stimulus_table[ - sorted(obtained_stimulus_table.columns)].copy() + obtained_df = obtained_stimulus_table[sorted(obtained_stimulus_table.columns)].copy() for col in obtained_df.columns: if col in expected_df.columns: obtained_df[col] = obtained_df[col].astype(expected_df[col].dtype) pd.testing.assert_frame_equal(expected_df, obtained_df, check_dtype=False) -def test_add_stimulus_presentations_color( - nwbfile, stimulus_presentations_color, roundtripper -): +def test_add_stimulus_presentations_color(nwbfile, stimulus_presentations_color, roundtripper): write_nwb.add_stimulus_timestamps(nwbfile, [0, 1]) presentations = Presentations(presentations=stimulus_presentations_color) presentations.to_nwb(nwbfile=nwbfile, stimulus_name_column="stimulus_name") @@ -289,9 +280,7 @@ def test_add_stimulus_presentations_color( if expected != obtained: mismatched = True - assert ( - not mismatched - ), f"expected: {expected_color}, obtained: {obtained_color}" + assert not mismatched, f"expected: {expected_color}, obtained: {obtained_color}" @pytest.mark.parametrize( @@ -345,9 +334,7 @@ def test_add_stimulus_presentations_color( ), ], ) -def test_add_optotagging_table_to_nwbfile( - nwbfile, roundtripper, opto_table, expected -): +def test_add_optotagging_table_to_nwbfile(nwbfile, roundtripper, opto_table, expected): opto_table["duration"] = opto_table["stop_time"] - opto_table["start_time"] opto_table = OptotaggingTable(table=opto_table) @@ -428,9 +415,7 @@ def test_add_probe_to_nwbfile( ], ) def test_add_ecephys_electrode_columns(nwbfile, columns_to_add): - allensdk.brain_observatory.ecephys.nwb_util._add_ecephys_electrode_columns( - nwbfile, columns_to_add - ) + allensdk.brain_observatory.ecephys.nwb_util._add_ecephys_electrode_columns(nwbfile, columns_to_add) if columns_to_add is None: expected_columns = [ @@ -445,7 +430,7 @@ def test_add_ecephys_electrode_columns(nwbfile, columns_to_add): @pytest.mark.parametrize( - ("channels, channel_number_whitelist, " "expected_electrode_table"), + ("channels, channel_number_whitelist, expected_electrode_table"), [ ( [ @@ -461,8 +446,7 @@ def test_add_ecephys_electrode_columns(nwbfile, columns_to_add): "left_right_ccf_coordinate": 25.0, "structure_acronym": "CA1", "impedance": np.nan, - "filtering": "AP band: 500 Hz high-pass; LFP " - "band: 1000 Hz low-pass", + "filtering": "AP band: 500 Hz high-pass; LFP band: 1000 Hz low-pass", }, { "id": 2, @@ -495,17 +479,14 @@ def test_add_ecephys_electrode_columns(nwbfile, columns_to_add): "impedance": [42.0, np.nan], "filtering": [ "custom", - "AP band: 500 Hz high-pass; " - "LFP band: 1000 Hz low-pass", + "AP band: 500 Hz high-pass; LFP band: 1000 Hz low-pass", ], } ).set_index("id"), ) ], ) -def test_add_ecephys_electrodes( - nwbfile, channels, channel_number_whitelist, expected_electrode_table -): +def test_add_ecephys_electrodes(nwbfile, channels, channel_number_whitelist, expected_electrode_table): mock_device = pynwb.device.Device(name="mock_device") mock_electrode_group = pynwb.ecephys.ElectrodeGroup( name="mock_group", description="", location="", device=mock_device @@ -515,14 +496,10 @@ def test_add_ecephys_electrodes( nwbfile, channels, mock_electrode_group, channel_number_whitelist ) - obt_electrode_table = nwbfile.electrodes.to_dataframe().drop( - columns=["group", "group_name"] - ) + obt_electrode_table = nwbfile.electrodes.to_dataframe().drop(columns=["group", "group_name"]) expected_electrode_table.rename(columns={"impedance": "imp"}, inplace=True) - pd.testing.assert_frame_equal( - obt_electrode_table, expected_electrode_table, check_like=True - ) + pd.testing.assert_frame_equal(obt_electrode_table, expected_electrode_table, check_like=True) @pytest.mark.parametrize( @@ -552,12 +529,8 @@ def test_add_ragged_data_to_dynamic_table(units_table, spike_times): assert np.allclose([13, 4, 12], units_table["spike_times"][2]) -@pytest.mark.parametrize( - "roundtrip,include_rotation", [[True, True], [True, False]] -) -def test_add_running_speed_to_nwbfile( - nwbfile, running_speed, roundtripper, roundtrip, include_rotation -): +@pytest.mark.parametrize("roundtrip,include_rotation", [[True, True], [True, False]]) +def test_add_running_speed_to_nwbfile(nwbfile, running_speed, roundtripper, roundtrip, include_rotation): nwbfile = write_nwb.add_running_speed_to_nwbfile(nwbfile, running_speed) if roundtrip: api_obt = roundtripper(nwbfile, EcephysNwbSessionApi) @@ -573,12 +546,8 @@ def test_add_running_speed_to_nwbfile( @pytest.mark.parametrize("roundtrip", [[True]]) -def test_add_raw_running_data_to_nwbfile( - nwbfile, raw_running_data, roundtripper, roundtrip -): - nwbfile = write_nwb.add_raw_running_data_to_nwbfile( - nwbfile, raw_running_data - ) +def test_add_raw_running_data_to_nwbfile(nwbfile, raw_running_data, roundtripper, roundtrip): + nwbfile = write_nwb.add_raw_running_data_to_nwbfile(nwbfile, raw_running_data) if roundtrip: api_obt = roundtripper(nwbfile, EcephysNwbSessionApi) else: @@ -703,9 +672,7 @@ def test_read_stimulus_table( columns_to_drop, expected, ): - expected = expected.set_index( - pd.Index(range(expected.shape[0]), name="stimulus_presentations_id") - ) + expected = expected.set_index(pd.Index(range(expected.shape[0]), name="stimulus_presentations_id")) dirname = str(tmpdir_factory.mktemp("ecephys_nwb_test")) stim_table_path = os.path.join(dirname, "stim_table.csv") @@ -717,9 +684,7 @@ def add_is_image_novel(stimulus_presentations, behavior_session_id): # not testing this for vcn return None - with patch.object( - Presentations, "_add_is_image_novel", wraps=add_is_image_novel - ): + with patch.object(Presentations, "_add_is_image_novel", wraps=add_is_image_novel): obt = Presentations.from_path( path=stim_table_path, behavior_session_id=1, @@ -750,9 +715,7 @@ def test_read_spike_times_to_dictionary(tmpdir_factory): spike_times_path, spike_units_path, local_to_global_unit_map ) for ii in range(15): - assert np.allclose( - obtained[-ii], sorted([spike_times[ii], spike_times[15 + ii]]) - ) + assert np.allclose(obtained[-ii], sorted([spike_times[ii], spike_times[15 + ii]])) def test_read_waveforms_to_dictionary(tmpdir_factory): @@ -781,9 +744,9 @@ def lfp_data(): subsample_channels = np.array([3, 2]) return { - "data": np.arange( - total_timestamps * len(subsample_channels), dtype=np.int16 - ).reshape((total_timestamps, len(subsample_channels))), + "data": np.arange(total_timestamps * len(subsample_channels), dtype=np.int16).reshape( + (total_timestamps, len(subsample_channels)) + ), "timestamps": np.linspace(0, 1, total_timestamps), "subsample_channels": subsample_channels, } @@ -810,8 +773,7 @@ def probe_data(): "left_right_ccf_coordinate": 15.0, "structure_acronym": "CA1", "impedence": np.nan, - "filtering": "AP band: 500 Hz high-pass; " - "LFP band: 1000 Hz low-pass", + "filtering": "AP band: 500 Hz high-pass; LFP band: 1000 Hz low-pass", }, { "id": 1, @@ -825,8 +787,7 @@ def probe_data(): "left_right_ccf_coordinate": 20.0, "structure_acronym": "CA2", "impedence": np.nan, - "filtering": "AP band: 500 Hz high-pass; " - "LFP band: 1000 Hz low-pass", + "filtering": "AP band: 500 Hz high-pass; LFP band: 1000 Hz low-pass", }, { "id": 2, @@ -840,8 +801,7 @@ def probe_data(): "left_right_ccf_coordinate": 25.0, "structure_acronym": "CA3", "impedence": np.nan, - "filtering": "AP band: 500 Hz high-pass; " - "LFP band: 1000 Hz low-pass", + "filtering": "AP band: 500 Hz high-pass; LFP band: 1000 Hz low-pass", }, ], "lfp": { @@ -903,9 +863,7 @@ def dummy_meta_from_json(dict_repr): ecephys_session_id=1, behavior_session_id=BehaviorSessionId(1), behavior_session_uuid=BehaviorSessionUUID(None), - date_of_acquisition=DateOfAcquisition( - date_of_acquisition=datetime.now() - ), + date_of_acquisition=DateOfAcquisition(date_of_acquisition=datetime.now()), equipment=Equipment("foo"), session_type=SessionType("foo"), stimulus_frame_rate=StimulusFrameRate(1.0), @@ -926,16 +884,12 @@ def dummy_meta_from_json(dict_repr): write_csd_to_h5(path=input_csd_path, **csd_data) np.save(input_timestamps_path, lfp_data["timestamps"], allow_pickle=False) - np.save( - input_channels_path, lfp_data["subsample_channels"], allow_pickle=False - ) + np.save(input_channels_path, lfp_data["subsample_channels"], allow_pickle=False) with open(input_data_path, "wb") as input_data_file: input_data_file.write(lfp_data["data"].tobytes()) with patch.object(Units, "from_json", wraps=lambda probe: None): - with patch.object( - BehaviorEcephysMetadata, "from_json", wraps=dummy_meta_from_json - ): + with patch.object(BehaviorEcephysMetadata, "from_json", wraps=dummy_meta_from_json): write_nwb.write_probe_lfp_file( 4242, test_session_metadata, @@ -944,9 +898,7 @@ def dummy_meta_from_json(dict_repr): probe_data, ) - exp_electrodes = ( - pd.DataFrame(probe_data["channels"]).set_index("id").loc[[2, 1], :] - ) + exp_electrodes = pd.DataFrame(probe_data["channels"]).set_index("id").loc[[2, 1], :] exp_electrodes = exp_electrodes.rename(columns={"impedance": "imp"}) exp_electrodes.rename( columns={ @@ -999,20 +951,14 @@ def dummy_meta_from_json(dict_repr): check_dtype=False, ) else: - pd.testing.assert_frame_equal( - obt_electrodes, exp_electrodes, check_like=True - ) + pd.testing.assert_frame_equal(obt_electrodes, exp_electrodes, check_like=True) - processing_module = obt_f.get_processing_module( - "current_source_density" - ) + processing_module = obt_f.get_processing_module("current_source_density") csd_series = processing_module["ecephys_csd"] assert np.allclose(csd_data["csd"], csd_series.time_series.data[:].T) - assert np.allclose( - csd_data["relative_window"], csd_series.time_series.timestamps[:] - ) + assert np.allclose(csd_data["relative_window"], csd_series.time_series.timestamps[:]) obt_channel_locations = np.stack( ( csd_series.virtual_electrode_x_positions, @@ -1026,9 +972,7 @@ def dummy_meta_from_json(dict_repr): @pytest.mark.parametrize("roundtrip", [True, False]) -def test_write_probe_lfp_file_roundtrip( - tmpdir_factory, roundtrip, lfp_data, probe_data, csd_data -): +def test_write_probe_lfp_file_roundtrip(tmpdir_factory, roundtrip, lfp_data, probe_data, csd_data): expected_csd = xr.DataArray( name="CSD", data=csd_data["csd"], @@ -1074,9 +1018,7 @@ def test_write_probe_lfp_file_roundtrip( write_csd_to_h5(path=input_csd_path, **csd_data) np.save(input_timestamps_path, lfp_data["timestamps"], allow_pickle=False) - np.save( - input_channels_path, lfp_data["subsample_channels"], allow_pickle=False - ) + np.save(input_channels_path, lfp_data["subsample_channels"], allow_pickle=False) with open(input_data_path, "wb") as input_data_file: input_data_file.write(lfp_data["data"].tobytes()) @@ -1084,17 +1026,11 @@ def test_write_probe_lfp_file_roundtrip( with patch.object( BehaviorEcephysMetadata, "from_json", - wraps=lambda dict_repr: create_autospec( - BehaviorEcephysMetadata, instance=True - ), + wraps=lambda dict_repr: create_autospec(BehaviorEcephysMetadata, instance=True), ): - write_nwb.write_probe_lfp_file( - 4242, None, datetime.now(), logging.INFO, probe_data - ) + write_nwb.write_probe_lfp_file(4242, None, datetime.now(), logging.INFO, probe_data) - obt = EcephysNwbSessionApi( - path=None, probe_lfp_paths={12345: NWBHDF5IO(output_path, "r").read} - ) + obt = EcephysNwbSessionApi(path=None, probe_lfp_paths={12345: NWBHDF5IO(output_path, "r").read}) obtained_lfp = obt.get_lfp(12345) obtained_csd = obt.get_current_source_density(12345) @@ -1133,9 +1069,7 @@ def invalid_epochs(): def test_add_invalid_times(invalid_epochs, tmpdir_factory): - nwbfile_name = str( - tmpdir_factory.mktemp("test").join("test_invalid_times.nwb") - ) + nwbfile_name = str(tmpdir_factory.mktemp("test").join("test_invalid_times.nwb")) nwbfile = NWBFile( session_description="EcephysSession", @@ -1152,9 +1086,7 @@ def test_add_invalid_times(invalid_epochs, tmpdir_factory): df = nwbfile.invalid_times.to_dataframe() df_in = nwbfile_in.invalid_times.to_dataframe() - pd.testing.assert_frame_equal( - df, df_in, check_like=True, check_dtype=False - ) + pd.testing.assert_frame_equal(df, df_in, check_like=True, check_dtype=False) def test_roundtrip_add_invalid_times(nwbfile, invalid_epochs, roundtripper): @@ -1213,9 +1145,7 @@ def expected_amplitudes(): return np.array([0, 15, 60, 45, 120]) -def test_scale_amplitudes( - spike_amplitudes, templates, spike_templates, expected_amplitudes -): +def test_scale_amplitudes(spike_amplitudes, templates, spike_templates, expected_amplitudes): scale_factor = 0.195 expected = expected_amplitudes * scale_factor @@ -1239,9 +1169,7 @@ def test_read_spike_amplitudes_to_dictionary( spike_units_path = os.path.join(tmpdir, "spike_units.npy") templates_path = os.path.join(tmpdir, "templates.npy") spike_templates_path = os.path.join(tmpdir, "spike_templates.npy") - inverse_whitening_matrix_path = os.path.join( - tmpdir, "inverse_whitening_matrix_path.npy" - ) + inverse_whitening_matrix_path = os.path.join(tmpdir, "inverse_whitening_matrix_path.npy") whitening_matrix = np.diag(np.arange(3) + 1) inverse_whitening_matrix = np.linalg.inv(whitening_matrix) @@ -1292,9 +1220,7 @@ def test_read_spike_amplitudes_to_dictionary( 54321: np.array([5, 4, 3, -1, 6]), }, { - 12345: np.array( - [0, 1, 2, 3, 4, 5] - ), # spike_amplitudes_mapping + 12345: np.array([0, 1, 2, 3, 4, 5]), # spike_amplitudes_mapping 54321: np.array([0, 1, 2, 3, 4]), }, ( @@ -1310,25 +1236,17 @@ def test_read_spike_amplitudes_to_dictionary( ), ], ) -def test_filter_and_sort_spikes( - spike_times_mapping, spike_amplitudes_mapping, expected -): +def test_filter_and_sort_spikes(spike_times_mapping, spike_amplitudes_mapping, expected): for unit in spike_times_mapping: expected_spike_times, expected_spike_amplitudes = expected ( obtained_spike_times, obtained_spike_amplitudes, - ) = _get_filtered_and_sorted_spikes( - spike_times_mapping[unit], spike_amplitudes_mapping[unit] - ) + ) = _get_filtered_and_sorted_spikes(spike_times_mapping[unit], spike_amplitudes_mapping[unit]) - np.testing.assert_equal( - obtained_spike_times, expected_spike_times[unit] - ) - np.testing.assert_equal( - obtained_spike_amplitudes, expected_spike_amplitudes[unit] - ) + np.testing.assert_equal(obtained_spike_times, expected_spike_times[unit]) + np.testing.assert_equal(obtained_spike_amplitudes, expected_spike_amplitudes[unit]) @pytest.mark.parametrize("roundtrip", [True, False]) @@ -1405,35 +1323,23 @@ def test_filter_and_sort_spikes( } ).set_index(keys="id", drop=True), { - 777: np.array( - [0.0, 1.0, 2.0, -1.0, 5.0, 4.0] - ), # spike_times + 777: np.array([0.0, 1.0, 2.0, -1.0, 5.0, 4.0]), # spike_times 778: np.array([5.0, 4.0, 3.0, -1.0, 6.0]), }, { - 777: np.array( - [0.0, 1.0, 2.0, 3.0, 4.0, 5.0] - ), # spike_amplitudes + 777: np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), # spike_amplitudes 778: np.array([0.0, 1.0, 2.0, 3.0, 4.0]), }, { - 777: np.array( - [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] - ), # mean_waveforms + 777: np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]), # mean_waveforms 778: np.array([1.0, 2.0, 3.0, 4.0, 5.0]), }, ), ) ], ) -def test_add_probewise_data_to_nwbfile( - monkeypatch, nwbfile, roundtripper, roundtrip, probes, parsed_probe_data -): - expected_units_table = pd.read_pickle( - Path(__file__).absolute().parent - / "resources" - / "expected_units_table.pkl" - ) +def test_add_probewise_data_to_nwbfile(monkeypatch, nwbfile, roundtripper, roundtrip, probes, parsed_probe_data): + expected_units_table = pd.read_pickle(Path(__file__).absolute().parent / "resources" / "expected_units_table.pkl") units = Units( [ @@ -1463,9 +1369,7 @@ def test_add_probewise_data_to_nwbfile( else: obt = EcephysNwbSessionApi.from_nwbfile(nwbfile) - pd.testing.assert_frame_equal( - obt.nwbfile.units.to_dataframe(), expected_units_table - ) + pd.testing.assert_frame_equal(obt.nwbfile.units.to_dataframe(), expected_units_table) @pytest.mark.parametrize("roundtrip", [True, False]) @@ -1501,43 +1405,29 @@ def test_add_probewise_data_to_nwbfile( def test_add_eye_tracking_rig_geometry_data_to_nwbfile( nwbfile, roundtripper, roundtrip, eye_tracking_rig_geom, expected ): - nwbfile = write_nwb.add_eye_tracking_rig_geometry_data_to_nwbfile( - nwbfile, eye_tracking_rig_geom - ) + nwbfile = write_nwb.add_eye_tracking_rig_geometry_data_to_nwbfile(nwbfile, eye_tracking_rig_geom) if roundtrip: obt = roundtripper(nwbfile, EcephysNwbSessionApi) else: obt = EcephysNwbSessionApi.from_nwbfile(nwbfile) obtained_metadata = obt.get_rig_metadata() - pd.testing.assert_frame_equal( - obtained_metadata["geometry"], expected["geometry"], check_like=True - ) + pd.testing.assert_frame_equal(obtained_metadata["geometry"], expected["geometry"], check_like=True) assert obtained_metadata["equipment"] == expected["equipment"] @pytest.mark.parametrize("roundtrip", [True, False]) @pytest.mark.parametrize( - ( - "eye_tracking_frame_times, eye_dlc_tracking_data, " - "eye_gaze_data, expected_pupil_data, " - "expected_gaze_data" - ), + ("eye_tracking_frame_times, eye_dlc_tracking_data, eye_gaze_data, expected_pupil_data, expected_gaze_data"), [ ( # eye_tracking_frame_times pd.Series([3.0, 4.0, 5.0, 6.0, 7.0]), # eye_dlc_tracking_data { - "pupil_params": create_preload_eye_tracking_df( - np.full((5, 5), 1.0) - ), - "cr_params": create_preload_eye_tracking_df( - np.full((5, 5), 2.0) - ), - "eye_params": create_preload_eye_tracking_df( - np.full((5, 5), 3.0) - ), + "pupil_params": create_preload_eye_tracking_df(np.full((5, 5), 1.0)), + "cr_params": create_preload_eye_tracking_df(np.full((5, 5), 2.0)), + "eye_params": create_preload_eye_tracking_df(np.full((5, 5), 3.0)), }, # eye_gaze_data { @@ -1569,9 +1459,7 @@ def test_add_eye_tracking_rig_geometry_data_to_nwbfile( "x": [3.0, 5.0, np.nan, 9.0, 11.0], } ), - "synced_frame_timestamps": pd.Series( - [3.0, 4.0, 5.0, 6.0, 7.0] - ), + "synced_frame_timestamps": pd.Series([3.0, 4.0, 5.0, 6.0, 7.0]), }, # expected_pupil_data pd.DataFrame( @@ -1592,9 +1480,7 @@ def test_add_eye_tracking_rig_geometry_data_to_nwbfile( "eye_width": [6.0] * 5, "eye_phi": [3.0] * 5, }, - index=pd.Index( - name="Time (s)", data=[3.0, 4.0, 5.0, 6.0, 7.0] - ), + index=pd.Index(name="Time (s)", data=[3.0, 4.0, 5.0, 6.0, 7.0]), ), # expected_gaze_data pd.DataFrame( @@ -1648,9 +1534,7 @@ def test_add_eye_tracking_rig_geometry_data_to_nwbfile( 10.0, ], }, - index=pd.Index( - name="Time (s)", data=[3.0, 4.0, 5.0, 6.0, 7.0] - ), + index=pd.Index(name="Time (s)", data=[3.0, 4.0, 5.0, 6.0, 7.0]), ), ), ], @@ -1674,13 +1558,7 @@ def test_add_eye_tracking_data_to_nwbfile( else: obt = EcephysNwbSessionApi.from_nwbfile(nwbfile) obtained_pupil_data = obt.get_pupil_data() - obtained_screen_gaze_data = obt.get_screen_gaze_data( - include_filtered_data=True - ) + obtained_screen_gaze_data = obt.get_screen_gaze_data(include_filtered_data=True) - pd.testing.assert_frame_equal( - obtained_pupil_data, expected_pupil_data, check_like=True - ) - pd.testing.assert_frame_equal( - obtained_screen_gaze_data, expected_gaze_data, check_like=True - ) + pd.testing.assert_frame_equal(obtained_pupil_data, expected_pupil_data, check_like=True) + pd.testing.assert_frame_equal(obtained_screen_gaze_data, expected_gaze_data, check_like=True) diff --git a/allensdk/test/brain_observatory/extract_running_speed/test_extract_running_speed_module.py b/allensdk/test/brain_observatory/extract_running_speed/test_extract_running_speed_module.py index 94d783a4c2..9e98343732 100644 --- a/allensdk/test/brain_observatory/extract_running_speed/test_extract_running_speed_module.py +++ b/allensdk/test/brain_observatory/extract_running_speed/test_extract_running_speed_module.py @@ -5,32 +5,28 @@ import pytest + @pytest.fixture def use_temp_dir(tmpdir_factory): def fn(data_dir, tempdir_name, input_json_fname, output_json_fname, module, renamer_cb): - temp_dir = str(tmpdir_factory.mktemp(tempdir_name)) - + input_json_path = os.path.join(data_dir, input_json_fname) new_input_json_path = os.path.join(temp_dir, input_json_fname) output_json_path = os.path.join(temp_dir, output_json_fname) - with open(input_json_path, 'r') as input_json: + with open(input_json_path, "r") as input_json: input_json_data = json.load(input_json) input_json_data = renamer_cb(input_json_data, data_dir, temp_dir) - with open(new_input_json_path, 'w') as new_input_json: + with open(new_input_json_path, "w") as new_input_json: json.dump(input_json_data, new_input_json) - sp.check_call([ - 'python', '-m', module, - '--input_json', new_input_json_path, - '--output_json', output_json_path - ]) + sp.check_call(["python", "-m", module, "--input_json", new_input_json_path, "--output_json", output_json_path]) - with open(output_json_path, 'r') as output_json: + with open(output_json_path, "r") as output_json: output_json_data = json.load(output_json) return output_json_data @@ -49,28 +45,32 @@ def reparent(path, new_parent): @pytest.mark.requires_bamboo -@pytest.mark.parametrize('input_json_fname,output_json_fname,exp_fname', [ +@pytest.mark.parametrize( + "input_json_fname,output_json_fname,exp_fname", [ - "ECEPHYS_EXTRACT_RUNNING_SPEED_QUEUE_744228101_input.json", - 'ECEPHYS_EXTRACT_RUNNING_SPEED_QUEUE_744228101ls_output.json', - '744228101_running_speeds.h5', - ] -]) -def test_extract_running_speed_module( - use_temp_dir, input_json_fname, output_json_fname, exp_fname -): - + [ + "ECEPHYS_EXTRACT_RUNNING_SPEED_QUEUE_744228101_input.json", + "ECEPHYS_EXTRACT_RUNNING_SPEED_QUEUE_744228101ls_output.json", + "744228101_running_speeds.h5", + ] + ], +) +def test_extract_running_speed_module(use_temp_dir, input_json_fname, output_json_fname, exp_fname): def renamer(input_json_data, data_dir, temp_dir): - input_json_data['sync_h5_path'] = reparent(input_json_data['sync_h5_path'], data_dir) - input_json_data['stimulus_pkl_path'] = reparent(input_json_data['stimulus_pkl_path'], data_dir) + input_json_data["sync_h5_path"] = reparent(input_json_data["sync_h5_path"], data_dir) + input_json_data["stimulus_pkl_path"] = reparent(input_json_data["stimulus_pkl_path"], data_dir) - input_json_data['output_path'] = reparent(input_json_data['output_path'], temp_dir) + input_json_data["output_path"] = reparent(input_json_data["output_path"], temp_dir) return input_json_data use_temp_dir( - DATA_DIR, 'test_extract_running_speed', input_json_fname, output_json_fname, - 'allensdk.brain_observatory.extract_running_speed', renamer + DATA_DIR, + "test_extract_running_speed", + input_json_fname, + output_json_fname, + "allensdk.brain_observatory.extract_running_speed", + renamer, ) expected_path = os.path.join(DATA_DIR, exp_fname) assert os.path.exists(expected_path) diff --git a/allensdk/test/brain_observatory/gaze_mapping/test_gaze_mapping.py b/allensdk/test/brain_observatory/gaze_mapping/test_gaze_mapping.py index e269921362..1dd4a23b13 100644 --- a/allensdk/test/brain_observatory/gaze_mapping/test_gaze_mapping.py +++ b/allensdk/test/brain_observatory/gaze_mapping/test_gaze_mapping.py @@ -15,7 +15,7 @@ def gaze_mapper_fixture(request): "camera_position": np.array([0, 0, 0]), "camera_rotations": np.array([0, 0, 0]), "eye_radius": 0.1682, - "cm_per_pixel": (10.2 / 10000.0) + "cm_per_pixel": (10.2 / 10000.0), } default_params.update(request.param) return gm.GazeMapper(**default_params) @@ -25,206 +25,248 @@ def gaze_mapper_fixture(request): def rig_component_fixture(request): default_params = { "position_in_eye_coord_frame": np.array([0, 0, 0]), - "rotations_in_self_coord_frame": np.array([0, 0, 0]) + "rotations_in_self_coord_frame": np.array([0, 0, 0]), } default_params.update(request.param) return gm.EyeTrackingRigObject(**default_params) # ======== EyeTrackingRigObject tests ======== -@pytest.mark.parametrize('rig_component_fixture,expected', [ - ({"position_in_eye_coord_frame": [1, 0, 0]}, - [[0, 0, -1], - [-1, 0, 0], - [0, 1, 0]]), - - ({"position_in_eye_coord_frame": [0, 1, 0]}, - [[1, 0, 0], - [0, 0, -1], - [0, 1, 0]]), - - ({"position_in_eye_coord_frame": [0, 0, 1]}, - [[0, -1, 0], - [-1, 0, 0], - [0, 0, -1]]), -], indirect=['rig_component_fixture']) +@pytest.mark.parametrize( + "rig_component_fixture,expected", + [ + ({"position_in_eye_coord_frame": [1, 0, 0]}, [[0, 0, -1], [-1, 0, 0], [0, 1, 0]]), + ({"position_in_eye_coord_frame": [0, 1, 0]}, [[1, 0, 0], [0, 0, -1], [0, 1, 0]]), + ({"position_in_eye_coord_frame": [0, 0, 1]}, [[0, -1, 0], [-1, 0, 0], [0, 0, -1]]), + ], + indirect=["rig_component_fixture"], +) def test_generate_self_to_eye_frame_xform(rig_component_fixture, expected): obtained = rig_component_fixture.generate_self_to_eye_frame_xform() assert np.allclose(obtained.as_matrix(), expected) # ======== GazeMapper tests ======== -@pytest.mark.parametrize('gaze_mapper_fixture,expected', [ - # Simple 2D scenarios - ({"led_position": np.array([100, 0, 50])}, np.array([0.08417078, 0, 0.04208539])), - ({"led_position": np.array([50, 0, 20])}, np.array([0.08424169, 0, 0.03369668])), - # 3D scenarios - ({"led_position": np.array([246, 92.3, 52.6])}, np.array([0.07876523, 0.02955297, 0.01684167])), - ({"led_position": np.array([258.9, -61.2, 32.1])}, np.array([0.08187032, -0.01935289, 0.01015078])) - -], indirect=["gaze_mapper_fixture"]) +@pytest.mark.parametrize( + "gaze_mapper_fixture,expected", + [ + # Simple 2D scenarios + ({"led_position": np.array([100, 0, 50])}, np.array([0.08417078, 0, 0.04208539])), + ({"led_position": np.array([50, 0, 20])}, np.array([0.08424169, 0, 0.03369668])), + # 3D scenarios + ({"led_position": np.array([246, 92.3, 52.6])}, np.array([0.07876523, 0.02955297, 0.01684167])), + ({"led_position": np.array([258.9, -61.2, 32.1])}, np.array([0.08187032, -0.01935289, 0.01015078])), + ], + indirect=["gaze_mapper_fixture"], +) def test_compute_cr_coordinate(gaze_mapper_fixture, expected): obtained = gaze_mapper_fixture.compute_cr_coordinate() assert np.allclose(obtained, expected) -@pytest.mark.parametrize("gaze_mapper_fixture, method_inputs, expected", [ - ({"monitor_position": np.array([170, 0, 0]), - "camera_position": np.array([150, 0, 0]), - "led_position": np.array([130, 0, 0])}, - {"cam_pupil_params": np.array([[300, 300], [350, 350], [325, 325], [290, 290]]), - "cam_cr_params": np.array([[300, 300], [300, 300], [300, 300], [300, 300]])}, - [[-0.16820, 0.0000, 0.0000], - [-0.15195, -0.0510, 0.0510], - [-0.16428, -0.0255, 0.0255], - [-0.16758, 0.0102, -0.0102]]), - - # Test when params result in estimated pupil location outside of eye - ({"monitor_position": np.array([170, 0, 0]), - "camera_position": np.array([150, 0, 0]), - "led_position": np.array([130, 0, 0])}, - {"cam_pupil_params": np.array([[900, 900], [350, 350], [325, 325], [250, 250], [100, 100]]), - "cam_cr_params": np.array([[300, 300], [300, 300], [300, 300], [300, 300], [800, 800]])}, - [[np.nan, np.nan, np.nan], - [-0.15195, -0.0510, 0.0510], - [-0.16428, -0.0255, 0.0255], - [-0.15195, 0.051, -0.051], - [np.nan, np.nan, np.nan]]), - -], indirect=['gaze_mapper_fixture']) -def test_pupil_pos_in_eye_coords(gaze_mapper_fixture, - method_inputs, - expected): +@pytest.mark.parametrize( + "gaze_mapper_fixture, method_inputs, expected", + [ + ( + { + "monitor_position": np.array([170, 0, 0]), + "camera_position": np.array([150, 0, 0]), + "led_position": np.array([130, 0, 0]), + }, + { + "cam_pupil_params": np.array([[300, 300], [350, 350], [325, 325], [290, 290]]), + "cam_cr_params": np.array([[300, 300], [300, 300], [300, 300], [300, 300]]), + }, + [ + [-0.16820, 0.0000, 0.0000], + [-0.15195, -0.0510, 0.0510], + [-0.16428, -0.0255, 0.0255], + [-0.16758, 0.0102, -0.0102], + ], + ), + # Test when params result in estimated pupil location outside of eye + ( + { + "monitor_position": np.array([170, 0, 0]), + "camera_position": np.array([150, 0, 0]), + "led_position": np.array([130, 0, 0]), + }, + { + "cam_pupil_params": np.array([[900, 900], [350, 350], [325, 325], [250, 250], [100, 100]]), + "cam_cr_params": np.array([[300, 300], [300, 300], [300, 300], [300, 300], [800, 800]]), + }, + [ + [np.nan, np.nan, np.nan], + [-0.15195, -0.0510, 0.0510], + [-0.16428, -0.0255, 0.0255], + [-0.15195, 0.051, -0.051], + [np.nan, np.nan, np.nan], + ], + ), + ], + indirect=["gaze_mapper_fixture"], +) +def test_pupil_pos_in_eye_coords(gaze_mapper_fixture, method_inputs, expected): obtained = gaze_mapper_fixture.pupil_pos_in_eye_coords(**method_inputs) assert np.allclose(obtained, expected, rtol=1e-4, equal_nan=True) -@pytest.mark.parametrize("gaze_mapper_fixture, method_inputs, expected", [ - ({"monitor_position": np.array([170, 0, 0]), - "camera_position": np.array([150, 0, 0]), - "led_position": np.array([130, 0, 0])}, - {"cam_pupil_params": np.array([[300, 300], [350, 350], [325, 325], [290, 290]]), - "cam_cr_params": np.array([[300, 300], [300, 300], [300, 300], [300, 300]])}, - [[0, 0], - [-57.057, -57.057], - [-26.386, -26.386], - [10.3472, 10.347]]), - - # Test when params result in estimated pupil location outside of eye - ({"monitor_position": np.array([170, 0, 0]), - "camera_position": np.array([150, 0, 0]), - "led_position": np.array([130, 0, 0])}, - {"cam_pupil_params": np.array([[300, 300], [900, 900], [325, 325], [200, 200]]), - "cam_cr_params": np.array([[300, 300], [300, 300], [300, 300], [800, 800]])}, - [[0, 0], - [np.nan, np.nan], - [-26.386, -26.386], - [np.nan, np.nan]]), - -], indirect=['gaze_mapper_fixture']) -def test_pupil_position_on_monitor_in_cm(gaze_mapper_fixture, - method_inputs, - expected): +@pytest.mark.parametrize( + "gaze_mapper_fixture, method_inputs, expected", + [ + ( + { + "monitor_position": np.array([170, 0, 0]), + "camera_position": np.array([150, 0, 0]), + "led_position": np.array([130, 0, 0]), + }, + { + "cam_pupil_params": np.array([[300, 300], [350, 350], [325, 325], [290, 290]]), + "cam_cr_params": np.array([[300, 300], [300, 300], [300, 300], [300, 300]]), + }, + [[0, 0], [-57.057, -57.057], [-26.386, -26.386], [10.3472, 10.347]], + ), + # Test when params result in estimated pupil location outside of eye + ( + { + "monitor_position": np.array([170, 0, 0]), + "camera_position": np.array([150, 0, 0]), + "led_position": np.array([130, 0, 0]), + }, + { + "cam_pupil_params": np.array([[300, 300], [900, 900], [325, 325], [200, 200]]), + "cam_cr_params": np.array([[300, 300], [300, 300], [300, 300], [800, 800]]), + }, + [[0, 0], [np.nan, np.nan], [-26.386, -26.386], [np.nan, np.nan]], + ), + ], + indirect=["gaze_mapper_fixture"], +) +def test_pupil_position_on_monitor_in_cm(gaze_mapper_fixture, method_inputs, expected): obtained = gaze_mapper_fixture.pupil_position_on_monitor_in_cm(**method_inputs) assert np.allclose(obtained, expected, rtol=1e-4, equal_nan=True) -@pytest.mark.parametrize("gaze_mapper_fixture, method_inputs, expected", [ - ({"monitor_position": np.array([170, 0, 0])}, # rig geometry parameters - {"pupil_pos_on_monitor_in_cm": np.array([[2, 5]])}, - np.array([[0.6740368979845053, 1.6845678100189891]])), # expected - - ({"monitor_position": np.array([100, 0, 0])}, - {"pupil_pos_on_monitor_in_cm": np.array([[5, 6], [8, 9]])}, - np.array([[2.862405226111748, 3.429356585864454], - [4.573921259900861, 5.126473695179203]])) -], indirect=['gaze_mapper_fixture']) -def test_pupil_position_on_monitor_in_degrees(gaze_mapper_fixture, - method_inputs, - expected): - obtained = gaze_mapper_fixture.pupil_position_on_monitor_in_degrees( - **method_inputs - ) +@pytest.mark.parametrize( + "gaze_mapper_fixture, method_inputs, expected", + [ + ( + {"monitor_position": np.array([170, 0, 0])}, # rig geometry parameters + {"pupil_pos_on_monitor_in_cm": np.array([[2, 5]])}, + np.array([[0.6740368979845053, 1.6845678100189891]]), + ), # expected + ( + {"monitor_position": np.array([100, 0, 0])}, + {"pupil_pos_on_monitor_in_cm": np.array([[5, 6], [8, 9]])}, + np.array([[2.862405226111748, 3.429356585864454], [4.573921259900861, 5.126473695179203]]), + ), + ], + indirect=["gaze_mapper_fixture"], +) +def test_pupil_position_on_monitor_in_degrees(gaze_mapper_fixture, method_inputs, expected): + obtained = gaze_mapper_fixture.pupil_position_on_monitor_in_degrees(**method_inputs) assert np.allclose(obtained, expected) -@pytest.mark.parametrize('gaze_mapper_fixture,ellipse_fits,expected,deg_diff_tolerance', [ - # General sanity check using extreme pupil values to see if output - # screen mapped coordinates are generally in the right quadrant/hemisphere. - - # As if looking at top half of screen - ({"led_position": np.array([135, 0, 0]), # rig geometry parameters - "monitor_position": np.array([170, 0, 0]), - "camera_position": np.array([130, 0, 0])}, - {"cam_pupil_params": np.array([[300, 200]]), # Pupil center (x, y) coords - "cam_cr_params": np.array([[300, 300]])}, # Corneal reflect (x, y) coords - np.array([[0, 1]]), # Expected general direction of outputs in unit vector form - 0), # Allowed angle tolerance (in degrees) between `expected - obtained` - - # As if looking at bottom half of screen - ({"led_position": np.array([135, 0, 0]), - "monitor_position": np.array([170, 0, 0]), - "camera_position": np.array([130, 0, 0])}, - {"cam_pupil_params": np.array([[300, 400]]), - "cam_cr_params": np.array([[300, 300]])}, - np.array([[0, -1]]), - 0), - - # As if looking at right side of screen - ({"led_position": np.array([135, 0, 0]), - "monitor_position": np.array([170, 0, 0]), - "camera_position": np.array([130, 0, 0])}, - {"cam_pupil_params": np.array([[200, 300]]), - "cam_cr_params": np.array([[300, 300]])}, - np.array([[1, 0]]), - 0), - - # As if looking at left side of screen - ({"led_position": np.array([135, 0, 0]), - "monitor_position": np.array([170, 0, 0]), - "camera_position": np.array([130, 0, 0])}, - {"cam_pupil_params": np.array([[400, 300]]), - "cam_cr_params": np.array([[300, 300]])}, - np.array([[-1, 0]]), - 0), - - # As if looking at upper right quadrant of screen - ({"led_position": np.array([135, 0, 0]), - "monitor_position": np.array([170, 0, 0]), - "camera_position": np.array([130, 0, 0])}, - {"cam_pupil_params": np.array([[200, 200]]), - "cam_cr_params": np.array([[300, 300]])}, - np.array([[1, 1]]), - 0), - - # As if looking at lower right quadrant of screen - ({"led_position": np.array([135, 0, 0]), - "monitor_position": np.array([170, 0, 0]), - "camera_position": np.array([130, 0, 0])}, - {"cam_pupil_params": np.array([[200, 400]]), - "cam_cr_params": np.array([[300, 300]])}, - np.array([[1, -1]]), - 0), - - # As if looking to upper right quadrant of screen - ({"led_position": np.array([135, 0, 0]), - "monitor_position": np.array([170, 0, 0]), - "camera_position": np.array([130, 0, 0])}, - {"cam_pupil_params": np.array([[400, 200]]), - "cam_cr_params": np.array([[300, 300]])}, - np.array([[-1, 1]]), - 0), - - # As if looking at lower left quadrant of screen - ({"led_position": np.array([135, 0, 0]), - "monitor_position": np.array([170, 0, 0]), - "camera_position": np.array([130, 0, 0])}, - {"cam_pupil_params": np.array([[400, 400]]), - "cam_cr_params": np.array([[300, 300]])}, - np.array([[-1, -1]]), - 0), - -], indirect=["gaze_mapper_fixture"]) +@pytest.mark.parametrize( + "gaze_mapper_fixture,ellipse_fits,expected,deg_diff_tolerance", + [ + # General sanity check using extreme pupil values to see if output + # screen mapped coordinates are generally in the right quadrant/hemisphere. + # As if looking at top half of screen + ( + { + "led_position": np.array([135, 0, 0]), # rig geometry parameters + "monitor_position": np.array([170, 0, 0]), + "camera_position": np.array([130, 0, 0]), + }, + { + "cam_pupil_params": np.array([[300, 200]]), # Pupil center (x, y) coords + "cam_cr_params": np.array([[300, 300]]), + }, # Corneal reflect (x, y) coords + np.array([[0, 1]]), # Expected general direction of outputs in unit vector form + 0, + ), # Allowed angle tolerance (in degrees) between `expected - obtained` + # As if looking at bottom half of screen + ( + { + "led_position": np.array([135, 0, 0]), + "monitor_position": np.array([170, 0, 0]), + "camera_position": np.array([130, 0, 0]), + }, + {"cam_pupil_params": np.array([[300, 400]]), "cam_cr_params": np.array([[300, 300]])}, + np.array([[0, -1]]), + 0, + ), + # As if looking at right side of screen + ( + { + "led_position": np.array([135, 0, 0]), + "monitor_position": np.array([170, 0, 0]), + "camera_position": np.array([130, 0, 0]), + }, + {"cam_pupil_params": np.array([[200, 300]]), "cam_cr_params": np.array([[300, 300]])}, + np.array([[1, 0]]), + 0, + ), + # As if looking at left side of screen + ( + { + "led_position": np.array([135, 0, 0]), + "monitor_position": np.array([170, 0, 0]), + "camera_position": np.array([130, 0, 0]), + }, + {"cam_pupil_params": np.array([[400, 300]]), "cam_cr_params": np.array([[300, 300]])}, + np.array([[-1, 0]]), + 0, + ), + # As if looking at upper right quadrant of screen + ( + { + "led_position": np.array([135, 0, 0]), + "monitor_position": np.array([170, 0, 0]), + "camera_position": np.array([130, 0, 0]), + }, + {"cam_pupil_params": np.array([[200, 200]]), "cam_cr_params": np.array([[300, 300]])}, + np.array([[1, 1]]), + 0, + ), + # As if looking at lower right quadrant of screen + ( + { + "led_position": np.array([135, 0, 0]), + "monitor_position": np.array([170, 0, 0]), + "camera_position": np.array([130, 0, 0]), + }, + {"cam_pupil_params": np.array([[200, 400]]), "cam_cr_params": np.array([[300, 300]])}, + np.array([[1, -1]]), + 0, + ), + # As if looking to upper right quadrant of screen + ( + { + "led_position": np.array([135, 0, 0]), + "monitor_position": np.array([170, 0, 0]), + "camera_position": np.array([130, 0, 0]), + }, + {"cam_pupil_params": np.array([[400, 200]]), "cam_cr_params": np.array([[300, 300]])}, + np.array([[-1, 1]]), + 0, + ), + # As if looking at lower left quadrant of screen + ( + { + "led_position": np.array([135, 0, 0]), + "monitor_position": np.array([170, 0, 0]), + "camera_position": np.array([130, 0, 0]), + }, + {"cam_pupil_params": np.array([[400, 400]]), "cam_cr_params": np.array([[300, 300]])}, + np.array([[-1, -1]]), + 0, + ), + ], + indirect=["gaze_mapper_fixture"], +) def test_mapping_gives_sane_outputs(gaze_mapper_fixture, ellipse_fits, expected, deg_diff_tolerance): obtained = gaze_mapper_fixture.pupil_position_on_monitor_in_cm(**ellipse_fits) for obt, exp in zip(obtained, expected): @@ -236,109 +278,131 @@ def test_mapping_gives_sane_outputs(gaze_mapper_fixture, ellipse_fits, expected, # ======== Standalone function tests ======== -@pytest.mark.parametrize('ellipse_params,expected', [ - (pd.DataFrame({"height": [1, 1, 1, 1], "width": [2, 2, 2, 2]}), - pd.Series([4 * np.pi] * 4)), - - (pd.DataFrame({"height": [2, 2, 2, 2], "width": [1, 1, 1, 1]}), - pd.Series([4 * np.pi] * 4)), - - (pd.DataFrame({"height": [2, 4, 8, 16], "width": [1, 3, 9, 27]}), - pd.Series([4 * np.pi, 16 * np.pi, 81 * np.pi, 729 * np.pi])), - - (pd.DataFrame({"height": [1, 3, 9, 27], "width": [2, 4, 8, 16]}), - pd.Series([4 * np.pi, 16 * np.pi, 81 * np.pi, 729 * np.pi])), - - (pd.DataFrame({"height": [np.nan, 3, np.nan, 27], - "width": [2, 4, np.nan, np.nan]}), - pd.Series([4 * np.pi, 16 * np.pi, np.nan, 729 * np.pi])), -]) +@pytest.mark.parametrize( + "ellipse_params,expected", + [ + (pd.DataFrame({"height": [1, 1, 1, 1], "width": [2, 2, 2, 2]}), pd.Series([4 * np.pi] * 4)), + (pd.DataFrame({"height": [2, 2, 2, 2], "width": [1, 1, 1, 1]}), pd.Series([4 * np.pi] * 4)), + ( + pd.DataFrame({"height": [2, 4, 8, 16], "width": [1, 3, 9, 27]}), + pd.Series([4 * np.pi, 16 * np.pi, 81 * np.pi, 729 * np.pi]), + ), + ( + pd.DataFrame({"height": [1, 3, 9, 27], "width": [2, 4, 8, 16]}), + pd.Series([4 * np.pi, 16 * np.pi, 81 * np.pi, 729 * np.pi]), + ), + ( + pd.DataFrame({"height": [np.nan, 3, np.nan, 27], "width": [2, 4, np.nan, np.nan]}), + pd.Series([4 * np.pi, 16 * np.pi, np.nan, 729 * np.pi]), + ), + ], +) def test_compute_circular_areas(ellipse_params, expected): obtained = gm.compute_circular_areas(ellipse_params) assert np.allclose(obtained, expected, equal_nan=True) -@pytest.mark.parametrize('ellipse_params, expected', [ - (pd.DataFrame({"height": [1, 2, 3, 4], "width": [4, 3, 2, 1]}), - pd.Series([4 * np.pi, 6 * np.pi, 6 * np.pi, 4 * np.pi])), - - (pd.DataFrame({"height": [np.nan, 7, 11, 12, np.nan], - "width": [5, 3, 11, np.nan, np.nan]}), - pd.Series([np.nan, np.pi * 21, np.pi * 121, np.nan, np.nan])) -]) +@pytest.mark.parametrize( + "ellipse_params, expected", + [ + ( + pd.DataFrame({"height": [1, 2, 3, 4], "width": [4, 3, 2, 1]}), + pd.Series([4 * np.pi, 6 * np.pi, 6 * np.pi, 4 * np.pi]), + ), + ( + pd.DataFrame({"height": [np.nan, 7, 11, 12, np.nan], "width": [5, 3, 11, np.nan, np.nan]}), + pd.Series([np.nan, np.pi * 21, np.pi * 121, np.nan, np.nan]), + ), + ], +) def test_compute_elliptical_areas(ellipse_params, expected): obtained = gm.compute_elliptical_areas(ellipse_params) assert np.allclose(obtained, expected, equal_nan=True) -@pytest.mark.parametrize("function_inputs,expected", [ - ({"plane_normal": np.array([1, 1, 1]), - "plane_point": np.array([1, 1, -5]), - "line_vectors": np.array([[6, 1, 4]]), - "line_points": np.array([[-5, 1, -1]])}, - np.array([-3.90909091, 1.18181818, -0.27272727])), - - ({"plane_normal": np.array([1, 0, 0]), - "plane_point": np.array([10, 0, 0]), - "line_vectors": np.array([[1, 0, 0], [1, 1, 1], [1, 1, 0]]), - "line_points": np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0]])}, - np.array([[10, 0, 0], [10, 10, 10], [10, 10, 0]])), - - ({"plane_normal": np.array([1, 0, 0]), - "plane_point": np.array([10, 0, 0]), - "line_vectors": np.array([[1, 0, 0], [1, 1, 1], [1, 1, 0], [1, 0, 0], [1, 0, 0]]), - "line_points": np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]])}, - np.array([[10, 0, 0], [10, 10, 10], [10, 10, 0], [10, 0, 0], [10, 0, 0]])), - - ({"plane_normal": np.array([1, 1, 1]), - "plane_point": np.array([10, 10, 10]), - "line_vectors": np.array([[1, 1, 1], [1, 1, 1]]), - "line_points": np.array([[1, 2, 3], [0, 0, 0]])}, - np.array([[9, 10, 11], [10, 10, 10]])), - - ({"plane_normal": np.array([2, 1, -4]), - "plane_point": np.array([1, 1, -0.25]), - "line_vectors": np.array([[1, 3, 1]]), - "line_points": np.array([[0, 2, 0]])}, - np.array([[2, 8, 2]])), -]) +@pytest.mark.parametrize( + "function_inputs,expected", + [ + ( + { + "plane_normal": np.array([1, 1, 1]), + "plane_point": np.array([1, 1, -5]), + "line_vectors": np.array([[6, 1, 4]]), + "line_points": np.array([[-5, 1, -1]]), + }, + np.array([-3.90909091, 1.18181818, -0.27272727]), + ), + ( + { + "plane_normal": np.array([1, 0, 0]), + "plane_point": np.array([10, 0, 0]), + "line_vectors": np.array([[1, 0, 0], [1, 1, 1], [1, 1, 0]]), + "line_points": np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0]]), + }, + np.array([[10, 0, 0], [10, 10, 10], [10, 10, 0]]), + ), + ( + { + "plane_normal": np.array([1, 0, 0]), + "plane_point": np.array([10, 0, 0]), + "line_vectors": np.array([[1, 0, 0], [1, 1, 1], [1, 1, 0], [1, 0, 0], [1, 0, 0]]), + "line_points": np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]), + }, + np.array([[10, 0, 0], [10, 10, 10], [10, 10, 0], [10, 0, 0], [10, 0, 0]]), + ), + ( + { + "plane_normal": np.array([1, 1, 1]), + "plane_point": np.array([10, 10, 10]), + "line_vectors": np.array([[1, 1, 1], [1, 1, 1]]), + "line_points": np.array([[1, 2, 3], [0, 0, 0]]), + }, + np.array([[9, 10, 11], [10, 10, 10]]), + ), + ( + { + "plane_normal": np.array([2, 1, -4]), + "plane_point": np.array([1, 1, -0.25]), + "line_vectors": np.array([[1, 3, 1]]), + "line_points": np.array([[0, 2, 0]]), + }, + np.array([[2, 8, 2]]), + ), + ], +) def test_project_to_plane(function_inputs, expected): obtained = gm.project_to_plane(**function_inputs) print(obtained) assert np.allclose(obtained, expected) -@pytest.mark.parametrize("function_inputs, expected", [ - ({'x_rotation': 0.5, - 'y_rotation': 0.5, - 'z_rotation': 0.5}, - [[0.77015115, -0.21902415, 0.59907898], - [0.42073549, 0.88034656, -0.21902415], - [-0.47942554, 0.42073549, 0.77015115]]), - - ({'x_rotation': 0.5, - 'y_rotation': 0, - 'z_rotation': 0}, - [[1.0, 0.0, 0.0], - [0.0, 0.87758256, -0.47942554], - [0.0, 0.47942554, 0.87758256]]), - - ({'x_rotation': 0, - 'y_rotation': 0.5, - 'z_rotation': 0}, - [[0.87758256, 0.0, 0.47942554], - [0.0, 1.0, 0.0], - [-0.47942554, 0.0, 0.87758256]]), - - ({'x_rotation': 0, - 'y_rotation': 0, - 'z_rotation': 0.5}, - [[0.87758256, -0.47942554, 0.0], - [0.47942554, 0.87758256, 0.0], - [0.0, 0.0, 1.0]]), -]) +@pytest.mark.parametrize( + "function_inputs, expected", + [ + ( + {"x_rotation": 0.5, "y_rotation": 0.5, "z_rotation": 0.5}, + [ + [0.77015115, -0.21902415, 0.59907898], + [0.42073549, 0.88034656, -0.21902415], + [-0.47942554, 0.42073549, 0.77015115], + ], + ), + ( + {"x_rotation": 0.5, "y_rotation": 0, "z_rotation": 0}, + [[1.0, 0.0, 0.0], [0.0, 0.87758256, -0.47942554], [0.0, 0.47942554, 0.87758256]], + ), + ( + {"x_rotation": 0, "y_rotation": 0.5, "z_rotation": 0}, + [[0.87758256, 0.0, 0.47942554], [0.0, 1.0, 0.0], [-0.47942554, 0.0, 0.87758256]], + ), + ( + {"x_rotation": 0, "y_rotation": 0, "z_rotation": 0.5}, + [[0.87758256, -0.47942554, 0.0], [0.47942554, 0.87758256, 0.0], [0.0, 0.0, 1.0]], + ), + ], +) def test_generate_object_rotation_xform(function_inputs, expected): obtained = gm.generate_object_rotation_xform(**function_inputs) assert np.allclose(obtained.as_matrix(), expected) diff --git a/allensdk/test/brain_observatory/gaze_mapping/test_main.py b/allensdk/test/brain_observatory/gaze_mapping/test_main.py index c814b85174..1bbc1f49fe 100644 --- a/allensdk/test/brain_observatory/gaze_mapping/test_main.py +++ b/allensdk/test/brain_observatory/gaze_mapping/test_main.py @@ -10,34 +10,39 @@ from allensdk.brain_observatory.sync_dataset import Dataset -def create_sample_ellipse_hdf(output_file: Path, - cr_data: pd.DataFrame, - eye_data: pd.DataFrame, - pupil_data: pd.DataFrame): - cr_data.to_hdf(output_file, key='cr', mode='w') - eye_data.to_hdf(output_file, key='eye', mode='a') - pupil_data.to_hdf(output_file, key='pupil', mode='a') +def create_sample_ellipse_hdf( + output_file: Path, cr_data: pd.DataFrame, eye_data: pd.DataFrame, pupil_data: pd.DataFrame +): + cr_data.to_hdf(output_file, key="cr", mode="w") + eye_data.to_hdf(output_file, key="eye", mode="a") + pupil_data.to_hdf(output_file, key="pupil", mode="a") @pytest.fixture def ellipse_fits_fixture(tmp_path, request) -> dict: - cr = {"center_x": [300, 305, 295, 310, 280], - "center_y": [300, 305, 295, 310, 280], - "width": [7, 8, 6, 7, 10], - "height": [6, 9, 5, 6, 8], - "phi": [0, 0.1, 0.15, 0.1, 0]} - - eye = {"center_x": [300, 305, 295, 310, 280], - "center_y": [300, 305, 295, 310, 280], - "width": [150, 155, 160, 150, 155], - "height": [120, 115, 120, 110, 100], - "phi": [0, 0.1, 0.15, 0.1, 0]} - - pupil = {"center_x": [300, 305, 295, 310, 280], - "center_y": [300, 305, 295, 310, 280], - "width": [30, 35, 40, 25, 50], - "height": [25, 27, 30, 20, 45], - "phi": [0, 0.1, 0.15, 0.1, 0]} + cr = { + "center_x": [300, 305, 295, 310, 280], + "center_y": [300, 305, 295, 310, 280], + "width": [7, 8, 6, 7, 10], + "height": [6, 9, 5, 6, 8], + "phi": [0, 0.1, 0.15, 0.1, 0], + } + + eye = { + "center_x": [300, 305, 295, 310, 280], + "center_y": [300, 305, 295, 310, 280], + "width": [150, 155, 160, 150, 155], + "height": [120, 115, 120, 110, 100], + "phi": [0, 0.1, 0.15, 0.1, 0], + } + + pupil = { + "center_x": [300, 305, 295, 310, 280], + "center_y": [300, 305, 295, 310, 280], + "width": [30, 35, 40, 25, 50], + "height": [25, 27, 30, 20, 45], + "phi": [0, 0.1, 0.15, 0.1, 0], + } test_dir = tmp_path / "test_load_ellipse_fit_params" test_dir.mkdir() @@ -46,8 +51,7 @@ def ellipse_fits_fixture(tmp_path, request) -> dict: test_path = test_dir / "good_ellipse_fits.h5" else: test_path = test_dir / "bad_ellipse_fits.h5" - pupil = {"center_x": [300], "center_y": [300], "width": [30], - "height": [25], "phi": [0]} + pupil = {"center_x": [300], "center_y": [300], "width": [30], "height": [25], "phi": [0]} cr = pd.DataFrame(cr) eye = pd.DataFrame(eye) @@ -55,20 +59,20 @@ def ellipse_fits_fixture(tmp_path, request) -> dict: create_sample_ellipse_hdf(test_path, cr, eye, pupil) - return {"cr": pd.DataFrame(cr), - "eye": pd.DataFrame(eye), - "pupil": pd.DataFrame(pupil), - "file_path": test_path} + return {"cr": pd.DataFrame(cr), "eye": pd.DataFrame(eye), "pupil": pd.DataFrame(pupil), "file_path": test_path} -@pytest.mark.parametrize("ellipse_fits_fixture, expect_good_file", [ - ({"create_good_fits_file": True}, True), - ({"create_good_fits_file": False}, False) -], indirect=["ellipse_fits_fixture"]) +@pytest.mark.parametrize( + "ellipse_fits_fixture, expect_good_file", + [({"create_good_fits_file": True}, True), ({"create_good_fits_file": False}, False)], + indirect=["ellipse_fits_fixture"], +) def test_load_ellipse_fit_params(ellipse_fits_fixture: dict, expect_good_file: bool): - expected = {"cr_params": pd.DataFrame(ellipse_fits_fixture["cr"]).astype(float), - "pupil_params": pd.DataFrame(ellipse_fits_fixture["pupil"]).astype(float), - "eye_params": pd.DataFrame(ellipse_fits_fixture["eye"]).astype(float)} + expected = { + "cr_params": pd.DataFrame(ellipse_fits_fixture["cr"]).astype(float), + "pupil_params": pd.DataFrame(ellipse_fits_fixture["pupil"]).astype(float), + "eye_params": pd.DataFrame(ellipse_fits_fixture["eye"]).astype(float), + } if expect_good_file: obtained = main.load_ellipse_fit_params(ellipse_fits_fixture["file_path"]) @@ -79,57 +83,64 @@ def test_load_ellipse_fit_params(ellipse_fits_fixture: dict, expect_good_file: b obtained = main.load_ellipse_fit_params(ellipse_fits_fixture["file_path"]) -@pytest.mark.parametrize("input_args, expected", [ - ({"input_file": Path("input_file.h5"), - "session_sync_file": Path("sync_file.h5"), - "output_file": Path("output_file.h5"), - "monitor_position_x_mm": 100.0, - "monitor_position_y_mm": 500.0, - "monitor_position_z_mm": 300.0, - "monitor_rotation_x_deg": 30, - "monitor_rotation_y_deg": 60, - "monitor_rotation_z_deg": 90, - "camera_position_x_mm": 200.0, - "camera_position_y_mm": 600.0, - "camera_position_z_mm": 700.0, - "camera_rotation_x_deg": 20, - "camera_rotation_y_deg": 180, - "camera_rotation_z_deg": 5, - "led_position_x_mm": 800.0, - "led_position_y_mm": 900.0, - "led_position_z_mm": 1000.0, - "eye_radius_cm": 0.1682, - "cm_per_pixel": 0.0001, - "equipment": "Rig A", - "date_of_acquisition": "Some Date", - "eye_video_file": Path("eye_video.avi")}, - - {"pupil_params": "pupil_params_placeholder", - "cr_params": "cr_params_placeholder", - "eye_params": "eye_params_placeholder", - "session_sync_file": Path("sync_file.h5"), - "output_file": Path("output_file.h5"), - "monitor_position": np.array([10.0, 50.0, 30.0]), - "monitor_rotations": np.array([np.pi / 6, np.pi / 3, np.pi / 2]), - "camera_position": np.array([20.0, 60.0, 70.0]), - "camera_rotations": np.array([np.pi / 9, np.pi, np.pi / 36]), - "led_position": np.array([80.0, 90.0, 100.0]), - "eye_radius_cm": 0.1682, - "cm_per_pixel": 0.0001, - "equipment": "Rig A", - "date_of_acquisition": "Some Date", - "eye_video_file": Path("eye_video.avi")} - ), - -]) +@pytest.mark.parametrize( + "input_args, expected", + [ + ( + { + "input_file": Path("input_file.h5"), + "session_sync_file": Path("sync_file.h5"), + "output_file": Path("output_file.h5"), + "monitor_position_x_mm": 100.0, + "monitor_position_y_mm": 500.0, + "monitor_position_z_mm": 300.0, + "monitor_rotation_x_deg": 30, + "monitor_rotation_y_deg": 60, + "monitor_rotation_z_deg": 90, + "camera_position_x_mm": 200.0, + "camera_position_y_mm": 600.0, + "camera_position_z_mm": 700.0, + "camera_rotation_x_deg": 20, + "camera_rotation_y_deg": 180, + "camera_rotation_z_deg": 5, + "led_position_x_mm": 800.0, + "led_position_y_mm": 900.0, + "led_position_z_mm": 1000.0, + "eye_radius_cm": 0.1682, + "cm_per_pixel": 0.0001, + "equipment": "Rig A", + "date_of_acquisition": "Some Date", + "eye_video_file": Path("eye_video.avi"), + }, + { + "pupil_params": "pupil_params_placeholder", + "cr_params": "cr_params_placeholder", + "eye_params": "eye_params_placeholder", + "session_sync_file": Path("sync_file.h5"), + "output_file": Path("output_file.h5"), + "monitor_position": np.array([10.0, 50.0, 30.0]), + "monitor_rotations": np.array([np.pi / 6, np.pi / 3, np.pi / 2]), + "camera_position": np.array([20.0, 60.0, 70.0]), + "camera_rotations": np.array([np.pi / 9, np.pi, np.pi / 36]), + "led_position": np.array([80.0, 90.0, 100.0]), + "eye_radius_cm": 0.1682, + "cm_per_pixel": 0.0001, + "equipment": "Rig A", + "date_of_acquisition": "Some Date", + "eye_video_file": Path("eye_video.avi"), + }, + ), + ], +) def test_preprocess_input_args(monkeypatch, input_args: dict, expected: dict): def mock_load_ellipse_fit_params(*args, **kwargs): - return {"pupil_params": "pupil_params_placeholder", - "cr_params": "cr_params_placeholder", - "eye_params": "eye_params_placeholder"} + return { + "pupil_params": "pupil_params_placeholder", + "cr_params": "cr_params_placeholder", + "eye_params": "eye_params_placeholder", + } - monkeypatch.setattr(main, "load_ellipse_fit_params", - mock_load_ellipse_fit_params) + monkeypatch.setattr(main, "load_ellipse_fit_params", mock_load_ellipse_fit_params) obtained = main.preprocess_input_args(input_args) @@ -140,16 +151,14 @@ def mock_load_ellipse_fit_params(*args, **kwargs): assert obtained[key] == expected[key] -@pytest.mark.parametrize("pupil_params_rows, expected, expect_fail", [ - (5, pd.Series([1, 2, 3, 4, 5]), False), - (4, None, True) -]) +@pytest.mark.parametrize( + "pupil_params_rows, expected, expect_fail", [(5, pd.Series([1, 2, 3, 4, 5]), False), (4, None, True)] +) def test_load_sync_file_timings(monkeypatch, pupil_params_rows, expected, expect_fail): def mock_get_synchronized_frame_times(*args, **kwargs): return pd.Series([1, 2, 3, 4, 5]) - monkeypatch.setattr(main.su, "get_synchronized_frame_times", - mock_get_synchronized_frame_times) + monkeypatch.setattr(main.su, "get_synchronized_frame_times", mock_get_synchronized_frame_times) if expect_fail: with pytest.raises(RuntimeError, match="number of camera sync pulses"): @@ -170,10 +179,9 @@ class MockDataset(Dataset): def __init__(self, path): pass - def get_edges(self, kind, keys, units='seconds'): + def get_edges(self, kind, keys, units="seconds"): return pd.Series([1, 2, 3, 4, 500, 501, 502, 503], dtype=np.int64) - with monkeypatch.context() as ctx: ctx.setattr(su, "Dataset", MockDataset) timestamps = main.load_sync_file_timings("", 8, False) diff --git a/allensdk/test/brain_observatory/multi_stimulus_running_speed/test_multi_stimulus_running_speed.py b/allensdk/test/brain_observatory/multi_stimulus_running_speed/test_multi_stimulus_running_speed.py index 793b4c07cd..ee43e693f4 100644 --- a/allensdk/test/brain_observatory/multi_stimulus_running_speed/test_multi_stimulus_running_speed.py +++ b/allensdk/test/brain_observatory/multi_stimulus_running_speed/test_multi_stimulus_running_speed.py @@ -4,23 +4,14 @@ import pathlib import pandas as pd -from allensdk.brain_observatory.\ - multi_stimulus_running_speed.multi_stimulus_running_speed import ( - MultiStimulusRunningSpeed - ) +from allensdk.brain_observatory.multi_stimulus_running_speed.multi_stimulus_running_speed import ( + MultiStimulusRunningSpeed, +) DATA_DIR = os.environ.get( "ECEPHYS_PIPELINE_DATA", - os.path.join( - "/", - "allen", - "aibs", - "informatics", - "module_test_data", - "ecephys", - "filtered_running_speed" - ), + os.path.join("/", "allen", "aibs", "informatics", "module_test_data", "ecephys", "filtered_running_speed"), ) NUMBER_OF_VSYNCS = 523198 @@ -33,103 +24,66 @@ @pytest.mark.requires_bamboo @pytest.fixture(scope="session") def sync_h5_path_fixture(): - sync_h5_path = os.path.join( - DATA_DIR, - '1090803859_553960_20210317.sync' - ) + sync_h5_path = os.path.join(DATA_DIR, "1090803859_553960_20210317.sync") return sync_h5_path @pytest.mark.requires_bamboo @pytest.fixture(scope="session") def pkl_path_fixture(): - mapping_pkl_path = os.path.join( - DATA_DIR, - '1090803859_553960_20210317_mapping.pkl' - ) + mapping_pkl_path = os.path.join(DATA_DIR, "1090803859_553960_20210317_mapping.pkl") - behavior_pkl_path = os.path.join( - DATA_DIR, - '1090803859_553960_20210317_behavior.pkl' - ) + behavior_pkl_path = os.path.join(DATA_DIR, "1090803859_553960_20210317_behavior.pkl") - replay_pkl_path = os.path.join( - DATA_DIR, - '1090803859_553960_20210317_replay.pkl' - ) + replay_pkl_path = os.path.join(DATA_DIR, "1090803859_553960_20210317_replay.pkl") - return {'behavior': behavior_pkl_path, - 'mapping': mapping_pkl_path, - 'replay': replay_pkl_path} + return {"behavior": behavior_pkl_path, "mapping": mapping_pkl_path, "replay": replay_pkl_path} @pytest.mark.requires_bamboo @pytest.fixture(scope="session") -def multi_stimulus_fixture(tmpdir_factory, - sync_h5_path_fixture, - pkl_path_fixture): - - temp_output_dir = pathlib.Path( - tmpdir_factory.mktemp('MultiStimulusRunningSpeedOutput') - ) +def multi_stimulus_fixture(tmpdir_factory, sync_h5_path_fixture, pkl_path_fixture): + temp_output_dir = pathlib.Path(tmpdir_factory.mktemp("MultiStimulusRunningSpeedOutput")) - output_path = tempfile.mkstemp( - dir=temp_output_dir, - prefix='output_', - suffix='.h5')[1] + output_path = tempfile.mkstemp(dir=temp_output_dir, prefix="output_", suffix=".h5")[1] - output_json = tempfile.mkstemp( - dir=temp_output_dir, - prefix='output_', - suffix='.json')[1] + output_json = tempfile.mkstemp(dir=temp_output_dir, prefix="output_", suffix=".json")[1] args = { - 'mapping_pkl_path': pkl_path_fixture['mapping'], - 'behavior_pkl_path': pkl_path_fixture['behavior'], - 'replay_pkl_path': pkl_path_fixture['replay'], - 'sync_h5_path': sync_h5_path_fixture, - 'output_json': output_json, - 'output_path': output_path, - 'use_lowpass_filter': True, - 'zscore_threshold': 10.0 + "mapping_pkl_path": pkl_path_fixture["mapping"], + "behavior_pkl_path": pkl_path_fixture["behavior"], + "replay_pkl_path": pkl_path_fixture["replay"], + "sync_h5_path": sync_h5_path_fixture, + "output_json": output_json, + "output_path": output_path, + "use_lowpass_filter": True, + "zscore_threshold": 10.0, } - return MultiStimulusRunningSpeed( - args=[], - input_data=args - ) + return MultiStimulusRunningSpeed(args=[], input_data=args) # smoke test @pytest.mark.requires_bamboo def test_proccessing(multi_stimulus_fixture): - multi_stimulus_fixture.process() - output_path = multi_stimulus_fixture.args['output_path'] + output_path = multi_stimulus_fixture.args["output_path"] - obtained_velocity = pd.read_hdf( - output_path, - key="running_speed" - ) + obtained_velocity = pd.read_hdf(output_path, key="running_speed") - obtained_raw = pd.read_hdf( - output_path, - key="raw_data" - ) + obtained_raw = pd.read_hdf(output_path, key="raw_data") # check that keys exist - assert('net_rotation' in obtained_velocity) - assert('velocity' in obtained_velocity) - assert('frame_indexes' in obtained_velocity) - assert('frame_time' in obtained_velocity) - assert('vsig' in obtained_raw) - assert('vin' in obtained_raw) - assert('frame_time' in obtained_raw) - assert('dx' in obtained_raw) + assert "net_rotation" in obtained_velocity + assert "velocity" in obtained_velocity + assert "frame_indexes" in obtained_velocity + assert "frame_time" in obtained_velocity + assert "vsig" in obtained_raw + assert "vin" in obtained_raw + assert "frame_time" in obtained_raw + assert "dx" in obtained_raw # check that the data is the correct length - assert(len(obtained_raw['frame_time']) == NUMBER_OF_VSYNCS) - assert( - len(obtained_velocity['net_rotation']) == NUMBER_OF_NET_ROTATION_ITEMS - ) + assert len(obtained_raw["frame_time"]) == NUMBER_OF_VSYNCS + assert len(obtained_velocity["net_rotation"]) == NUMBER_OF_NET_ROTATION_ITEMS diff --git a/allensdk/test/brain_observatory/nwb/conftest.py b/allensdk/test/brain_observatory/nwb/conftest.py index 6f2fa9d8b3..fa7c2e942b 100644 --- a/allensdk/test/brain_observatory/nwb/conftest.py +++ b/allensdk/test/brain_observatory/nwb/conftest.py @@ -1,10 +1,8 @@ import sys - def pytest_ignore_collect(path, config): - ''' The brain_observatory.ecephys submodule uses python 3.6 features that may not be backwards compatible! - ''' + """The brain_observatory.ecephys submodule uses python 3.6 features that may not be backwards compatible!""" if sys.version_info < (3, 6): return True diff --git a/allensdk/test/brain_observatory/nwb/test_nwb.py b/allensdk/test/brain_observatory/nwb/test_nwb.py index 48ed3f65be..cd9e53bea8 100644 --- a/allensdk/test/brain_observatory/nwb/test_nwb.py +++ b/allensdk/test/brain_observatory/nwb/test_nwb.py @@ -7,7 +7,6 @@ @pytest.fixture def version_only_nwbfile_fixture(tmp_path, request): - nwb_version = request.param.get("nwb_version", "2.2.2") nwbfile_path = tmp_path / "version_only_nwbfile.nwb" @@ -26,31 +25,29 @@ def version_only_nwbfile_fixture(tmp_path, request): return str(nwbfile_path) -@pytest.mark.parametrize("version_only_nwbfile_fixture, min_desired_version" - ", warns, warn_msg, invalid_nwb", [ - ({"nwb_version": None}, "2.2.2" , True, "Warn msg A", True), - ({"nwb_version": "0.9.0c"}, "2.2.2" , True, "Warn msg B", False), - ({"nwb_version": "2"}, "2.2.2", True, "Warn msg C", False), - ({"nwb_version": "2.0b"}, "2.2.2", True, "Warn msg D", False), - ({"nwb_version": "2.2.2"}, "2.2.2", False, None, False), - ({"nwb_version": "2.2.8"}, "2.2.2", False, None, False), - ({"nwb_version": "3.0"}, "2.2.2", False, None, False) -], indirect=["version_only_nwbfile_fixture"]) -def test_check_nwbfile_version(version_only_nwbfile_fixture, - min_desired_version, warns, - warn_msg, invalid_nwb): - +@pytest.mark.parametrize( + "version_only_nwbfile_fixture, min_desired_version, warns, warn_msg, invalid_nwb", + [ + ({"nwb_version": None}, "2.2.2", True, "Warn msg A", True), + ({"nwb_version": "0.9.0c"}, "2.2.2", True, "Warn msg B", False), + ({"nwb_version": "2"}, "2.2.2", True, "Warn msg C", False), + ({"nwb_version": "2.0b"}, "2.2.2", True, "Warn msg D", False), + ({"nwb_version": "2.2.2"}, "2.2.2", False, None, False), + ({"nwb_version": "2.2.8"}, "2.2.2", False, None, False), + ({"nwb_version": "3.0"}, "2.2.2", False, None, False), + ], + indirect=["version_only_nwbfile_fixture"], +) +def test_check_nwbfile_version(version_only_nwbfile_fixture, min_desired_version, warns, warn_msg, invalid_nwb): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - check_nwbfile_version(nwbfile_path=version_only_nwbfile_fixture, - desired_minimum_version=min_desired_version, - warning_msg=warn_msg) + check_nwbfile_version( + nwbfile_path=version_only_nwbfile_fixture, desired_minimum_version=min_desired_version, warning_msg=warn_msg + ) if warns: if invalid_nwb: - assert ("neither a 'nwb_version' field " - "nor dataset could be found" - in str(w[-1].message)) + assert "neither a 'nwb_version' field nor dataset could be found" in str(w[-1].message) else: assert warn_msg in str(w[-1].message) else: diff --git a/allensdk/test/brain_observatory/nwb/test_nwb_api.py b/allensdk/test/brain_observatory/nwb/test_nwb_api.py index 059474f149..cff59ede7f 100644 --- a/allensdk/test/brain_observatory/nwb/test_nwb_api.py +++ b/allensdk/test/brain_observatory/nwb/test_nwb_api.py @@ -6,6 +6,6 @@ def test_missing_file(tmpdir_factory): - path = os.path.join(str(tmpdir_factory.mktemp('nwb_api_missing_file_test')), 'foo.nwb') + path = os.path.join(str(tmpdir_factory.mktemp("nwb_api_missing_file_test")), "foo.nwb") with pytest.raises(OSError): - NwbApi.from_path(path) \ No newline at end of file + NwbApi.from_path(path) diff --git a/allensdk/test/brain_observatory/nwb/test_nwb_utils.py b/allensdk/test/brain_observatory/nwb/test_nwb_utils.py index 45224ef52d..fc6a89cdbc 100644 --- a/allensdk/test/brain_observatory/nwb/test_nwb_utils.py +++ b/allensdk/test/brain_observatory/nwb/test_nwb_utils.py @@ -25,9 +25,7 @@ ), ], ) -def test_get_stimulus_name_column( - input_cols, possible_names, expected_intersection -): +def test_get_stimulus_name_column(input_cols, possible_names, expected_intersection): column_name = nwb_utils.get_column_name(input_cols, possible_names) assert column_name == expected_intersection @@ -53,9 +51,7 @@ def test_get_stimulus_name_column( ), ], ) -def test_get_stimulus_name_column_exceptions( - input_cols, possible_names, expected_excep_cols -): +def test_get_stimulus_name_column_exceptions(input_cols, possible_names, expected_excep_cols): with pytest.raises(KeyError) as error: nwb_utils.get_column_name(input_cols, possible_names) for expected_value in expected_excep_cols: diff --git a/allensdk/test/brain_observatory/receptive_field_analysis/test_chisquarerf.py b/allensdk/test/brain_observatory/receptive_field_analysis/test_chisquarerf.py index 57cae07df4..c912eb3ecc 100644 --- a/allensdk/test/brain_observatory/receptive_field_analysis/test_chisquarerf.py +++ b/allensdk/test/brain_observatory/receptive_field_analysis/test_chisquarerf.py @@ -41,8 +41,7 @@ import scipy.stats as stats import numpy as np -from allensdk.brain_observatory.receptive_field_analysis import \ - chisquarerf as chi +from allensdk.brain_observatory.receptive_field_analysis import chisquarerf as chi @pytest.fixture @@ -51,8 +50,7 @@ def rf_events(): def make(receptive_field_mask, lsn): activity = np.logical_or(lsn == 255, lsn == 0) - return np.logical_and(activity, receptive_field_mask).sum(axis=(1, 2))[ - :, None] + return np.logical_and(activity, receptive_field_mask).sum(axis=(1, 2))[:, None] return make @@ -69,7 +67,7 @@ def make(ntr, nr, nc): def rf_mask(): def make(nr, nc, slices): mask = np.zeros((nr, nc)) - mask[slices[0][0]:slices[0][-1], slices[1][0]:slices[1][-1]] = 1 + mask[slices[0][0] : slices[0][-1], slices[1][0] : slices[1][-1]] = 1 return mask return make @@ -103,8 +101,7 @@ def trials_per_pixel(): # not testing d < 1 here -@pytest.mark.parametrize('r,c,d', - [[2, 3, 4], [28, 16, 3], [28, 16, 2], [10, 20, 12]]) +@pytest.mark.parametrize("r,c,d", [[2, 3, 4], [28, 16, 3], [28, 16, 2], [10, 20, 12]]) def test_interpolate_rf(r, c, d): image = np.arange(r * c).reshape([r, c]) @@ -114,20 +111,20 @@ def test_interpolate_rf(r, c, d): obtained = chi.interpolate_RF(image, d) grad = np.gradient(obtained) - assert (np.allclose(grad[0], np.zeros_like(grad[0]) + delta_row)) - assert (np.allclose(grad[1], np.zeros_like(grad[1]) + delta_col)) + assert np.allclose(grad[0], np.zeros_like(grad[0]) + delta_row) + assert np.allclose(grad[1], np.zeros_like(grad[1]) + delta_col) # tests integration with interpolate # not testing case where r, c are small -@pytest.mark.parametrize('r,c,d', [[28, 16, 3], [28, 16, 2], [10, 20, 12]]) +@pytest.mark.parametrize("r,c,d", [[28, 16, 3], [28, 16, 2], [10, 20, 12]]) def test_deinterpolate_rf(r, c, d): image = np.arange(r * c).reshape([r, c]) interp = chi.interpolate_RF(image, d) obt = chi.deinterpolate_RF(interp, c, r, d) - assert (np.allclose(image, obt)) + assert np.allclose(image, obt) def test_smooth_sta(): @@ -139,9 +136,9 @@ def test_smooth_sta(): thresholded[thresholded < 0.5] = 0 thresholded[thresholded > 0.5] = 1 - assert (np.allclose(smoothed.T, smoothed)) - assert (np.allclose(image, thresholded)) - assert (np.count_nonzero(smoothed) > np.count_nonzero(image)) + assert np.allclose(smoothed.T, smoothed) + assert np.allclose(image, thresholded) + assert np.count_nonzero(smoothed) > np.count_nonzero(image) def test_build_trial_matrix(): @@ -156,37 +153,31 @@ def test_build_trial_matrix(): exp[0, 0, 1, 1] = 1 obt = chi.build_trial_matrix(lsn_template, 2) - assert (np.allclose(exp, obt)) + assert np.allclose(exp, obt) -def test_get_expected_events_by_pixel(exclusion_mask, events_per_pixel, - trials_per_pixel): - obt = chi.get_expected_events_by_pixel(exclusion_mask, events_per_pixel, - trials_per_pixel) +def test_get_expected_events_by_pixel(exclusion_mask, events_per_pixel, trials_per_pixel): + obt = chi.get_expected_events_by_pixel(exclusion_mask, events_per_pixel, trials_per_pixel) - assert (obt[ - 0, 0, 0, 0] == 0.625) # 5 events, 8 trials (events counted + assert obt[0, 0, 0, 0] == 0.625 # 5 events, 8 trials (events counted # even if 0 trials) - assert (obt[1, 1, 0, 0] == 0.5) # 4 events, 8 trials - assert (obt[0, 0, 0, 1] == 0.0) # no trials - assert (obt[1, 3, 3, 0] == 0.0) # out of mask + assert obt[1, 1, 0, 0] == 0.5 # 4 events, 8 trials + assert obt[0, 0, 0, 1] == 0.0 # no trials + assert obt[1, 3, 3, 0] == 0.0 # out of mask -def test_chi_square_within_mask(exclusion_mask, events_per_pixel, - trials_per_pixel): - obt_p, obt_ch = chi.chi_square_within_mask(exclusion_mask, - events_per_pixel, - trials_per_pixel) +def test_chi_square_within_mask(exclusion_mask, events_per_pixel, trials_per_pixel): + obt_p, obt_ch = chi.chi_square_within_mask(exclusion_mask, events_per_pixel, trials_per_pixel) resps = np.array([4, 0, 0, 0, 0, 0, 0, 0]) resids = resps - 0.5 - chi_sum = (resids ** 2 / 0.5).sum() + chi_sum = (resids**2 / 0.5).sum() exp_p = 1.0 - stats.chi2.cdf(chi_sum, 15) # the zeroth test cell has a response without a trial. # this is infinitely surprising, so the pval is 0 - assert (np.allclose(obt_p, [0, exp_p])) + assert np.allclose(obt_p, [0, exp_p]) def test_get_disc_masks(): @@ -204,8 +195,8 @@ def test_get_disc_masks(): obt = chi.get_disc_masks(lsn_template, radius=1) - assert (np.allclose(exp1, obt[1, 1, :, :])) - assert (np.allclose(exp0, obt[0, 0, :, :])) + assert np.allclose(exp1, obt[1, 1, :, :]) + assert np.allclose(exp0, obt[0, 0, :, :]) def test_get_events_per_pixel(): @@ -234,34 +225,31 @@ def test_get_events_per_pixel(): exp[1, 1, 1, 1] = 2 obt = chi.get_events_per_pixel(events, trials) - assert (np.allclose(obt, exp)) + assert np.allclose(obt, exp) -@pytest.mark.parametrize('base,ex', [[5., 10], [0.1, 12], - [np.arange(20), np.linspace(0, 1, 20)]]) +@pytest.mark.parametrize("base,ex", [[5.0, 10], [0.1, 12], [np.arange(20), np.linspace(0, 1, 20)]]) def test_nll_to_pvalue(base, ex): obt = chi.NLL_to_pvalue(ex, base) exp = np.power(base, -ex) - assert (np.allclose(exp, obt)) + assert np.allclose(exp, obt) # test by reversing nll_to_pvalue -@pytest.mark.parametrize('base,ex', [[10., 2], [10., 4], - [np.array([10, 10, 10]), - np.linspace(0, 1, 3)]]) +@pytest.mark.parametrize("base,ex", [[10.0, 2], [10.0, 4], [np.array([10, 10, 10]), np.linspace(0, 1, 3)]]) def test_pvalue_to_nll(base, ex): pv = chi.NLL_to_pvalue(ex, base) max_nll = np.amax(ex) obt = chi.pvalue_to_NLL(pv, max_nll) - assert (np.allclose(ex, obt)) + assert np.allclose(ex, obt) -@pytest.mark.skipif(os.getenv('NO_TEST_RANDOM') == 'true', - reason="random seed may not produce the same results on " - "all machines") +@pytest.mark.skipif( + os.getenv("NO_TEST_RANDOM") == "true", reason="random seed may not produce the same results on all machines" +) def test_chi_square_binary(locally_sparse_noise, rf_events, rf_mask): ntr = 2000 nr = 20 @@ -273,14 +261,13 @@ def test_chi_square_binary(locally_sparse_noise, rf_events, rf_mask): events = rf_events(mask, lsn) obt = chi.chi_square_binary(events, lsn) - assert (obt[0][slices[0][0]:slices[0][-1], - slices[1][0]:slices[1][-1]].sum() == 0) - assert (obt.sum() > 0) + assert obt[0][slices[0][0] : slices[0][-1], slices[1][0] : slices[1][-1]].sum() == 0 + assert obt.sum() > 0 -@pytest.mark.skipif(os.getenv('NO_TEST_RANDOM') == 'true', - reason="random seed may not produce the same results on " - "all machines") +@pytest.mark.skipif( + os.getenv("NO_TEST_RANDOM") == "true", reason="random seed may not produce the same results on all machines" +) def test_get_peak_significance(locally_sparse_noise, rf_events, rf_mask): ntr = 2000 nr = 20 @@ -296,8 +283,8 @@ def test_get_peak_significance(locally_sparse_noise, rf_events, rf_mask): significant_cells, best_p, _, _ = chi.get_peak_significance(chi_nll, lsn) - assert (np.allclose(best_p, 0)) - assert (np.allclose(significant_cells, [True])) + assert np.allclose(best_p, 0) + assert np.allclose(significant_cells, [True]) def test_locate_median(): @@ -305,4 +292,4 @@ def test_locate_median(): where = np.where(mask) obt = chi.locate_median(*where) - assert (np.allclose(obt, [4, 4])) + assert np.allclose(obt, [4, 4]) diff --git a/allensdk/test/brain_observatory/receptive_field_analysis/test_fitgaussian2D.py b/allensdk/test/brain_observatory/receptive_field_analysis/test_fitgaussian2D.py index b3c4487c5b..380b605fc9 100644 --- a/allensdk/test/brain_observatory/receptive_field_analysis/test_fitgaussian2D.py +++ b/allensdk/test/brain_observatory/receptive_field_analysis/test_fitgaussian2D.py @@ -47,16 +47,14 @@ import allensdk.brain_observatory.receptive_field_analysis.fitgaussian2D as gauss -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def gaussian_pdf(): - def gpdf(mean, cov, axes, scale=1): + rv = multivariate_normal(mean, cov) + mesh = np.meshgrid(*axes, indexing="ij") + pos = np.rollaxis(np.array(mesh), 0, len(axes) + 1) - rv = multivariate_normal( mean, cov ) - mesh = np.meshgrid( *axes, indexing='ij' ) - pos = np.rollaxis( np.array(mesh), 0, len( axes ) + 1 ) - - out = rv.pdf( pos ) + out = rv.pdf(pos) out = out / np.amax(out) * scale return out, mesh @@ -64,24 +62,21 @@ def gpdf(mean, cov, axes, scale=1): return gpdf -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def domain_axes(): - start = 0 stop = 201 step = 1 naxes = 2 - axes = [ np.arange(start, stop, step) for ii in range(naxes) ] + axes = [np.arange(start, stop, step) for ii in range(naxes)] return axes -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def simple_fill(): - def do_fill(domain_axes, fn): - - arr = np.zeros([ len(da) for da in domain_axes ]) + arr = np.zeros([len(da) for da in domain_axes]) for pt in it.product(*domain_axes): arr[pt] = fn(*pt) @@ -90,101 +85,114 @@ def do_fill(domain_axes, fn): return do_fill -@pytest.mark.parametrize('mean,cov,scale', [ [ [ 100, 100 ], [ 25, 25 ], 1 ], - [ [ 100, 100 ], [ 10, 25 ], 1 ], - [ [ 100, 110 ], [ 25, 25 ], 1 ], - [ [ 100, 110 ], [ 10, 25 ], 1 ], - [ [ 110, 100 ], [ 10, 25 ], 1 ] ]) +@pytest.mark.parametrize( + "mean,cov,scale", + [ + [[100, 100], [25, 25], 1], + [[100, 100], [10, 25], 1], + [[100, 110], [25, 25], 1], + [[100, 110], [10, 25], 1], + [[110, 100], [10, 25], 1], + ], +) def test_gaussian2D_norot(mean, cov, scale, gaussian_pdf, domain_axes, simple_fill): + full_cov = [[cov[0], 0], [0, cov[1]]] + exp, mesh = gaussian_pdf(mean, full_cov, domain_axes, scale) - full_cov = [ [ cov[0], 0 ], [ 0, cov[1] ] ] - exp, mesh = gaussian_pdf( mean, full_cov, domain_axes, scale ) - - obt_fn = gauss.gaussian2D( scale, mean[0], mean[1], np.sqrt(cov[0]), np.sqrt(cov[1]), 0 ) - obt = simple_fill( domain_axes, obt_fn ) + obt_fn = gauss.gaussian2D(scale, mean[0], mean[1], np.sqrt(cov[0]), np.sqrt(cov[1]), 0) + obt = simple_fill(domain_axes, obt_fn) - assert( np.allclose( obt, exp ) ) + assert np.allclose(obt, exp) # only providing independent cov - using rotation after the fact -@pytest.mark.skipif(skimage.__version__ < '0.11.1', reason='cannot rotate about non-center point before .11.1') -@pytest.mark.parametrize('mean,cov,scale,rot', [ [ [ 100, 100 ], [ 25, 25 ], 1, 0 ], - [ [ 100, 100 ], [ 10, 25 ], 1, 0 ], - [ [ 100, 110 ], [ 25, 25 ], 1, 0 ], - [ [ 100, 110 ], [ 10, 25 ], 1, 0 ], - [ [ 110, 100 ], [ 10, 25 ], 1, 0 ], - [ [ 100, 100 ], [ 25, 25 ], 1, 90 ], - [ [ 100, 100 ], [ 25, 20 ], 1, 180 ], - [ [ 100, 100 ], [ 30, 25 ], 1, -90 ], - [ [ 100, 110 ], [ 20, 15 ], 1, -45 ], - [ [ 100, 110 ], [ 20, 15 ], 1, 30 ], - [ [ 100, 100 ], [ 15, 20 ], 1, 10 ], - [ [ 100, 100 ], [ 10, 25 ], 10, 0 ] ]) +@pytest.mark.skipif(skimage.__version__ < "0.11.1", reason="cannot rotate about non-center point before .11.1") +@pytest.mark.parametrize( + "mean,cov,scale,rot", + [ + [[100, 100], [25, 25], 1, 0], + [[100, 100], [10, 25], 1, 0], + [[100, 110], [25, 25], 1, 0], + [[100, 110], [10, 25], 1, 0], + [[110, 100], [10, 25], 1, 0], + [[100, 100], [25, 25], 1, 90], + [[100, 100], [25, 20], 1, 180], + [[100, 100], [30, 25], 1, -90], + [[100, 110], [20, 15], 1, -45], + [[100, 110], [20, 15], 1, 30], + [[100, 100], [15, 20], 1, 10], + [[100, 100], [10, 25], 10, 0], + ], +) def test_gaussian2D(mean, cov, scale, rot, gaussian_pdf, domain_axes, simple_fill): - - full_cov = [ [ cov[0], 0 ], [ 0, cov[1] ] ] - exp, mesh = gaussian_pdf( mean, full_cov, domain_axes, scale ) + full_cov = [[cov[0], 0], [0, cov[1]]] + exp, mesh = gaussian_pdf(mean, full_cov, domain_axes, scale) if rot != 0: - exp = rotate( exp, -rot, False, center=mean[::-1] ) # negative rotation - - obt_fn = gauss.gaussian2D( scale, mean[0], mean[1], np.sqrt(cov[0]), np.sqrt(cov[1]), rot ) - obt = simple_fill( domain_axes, obt_fn ) + exp = rotate(exp, -rot, False, center=mean[::-1]) # negative rotation + + obt_fn = gauss.gaussian2D(scale, mean[0], mean[1], np.sqrt(cov[0]), np.sqrt(cov[1]), rot) + obt = simple_fill(domain_axes, obt_fn) if rot == 0: - assert( np.allclose( obt, exp ) ) + assert np.allclose(obt, exp) else: - assert( np.linalg.norm( obt - exp ) / np.linalg.norm(exp) < 10 ** -2 ) - - -@pytest.mark.parametrize('mean,cov,scale', [ [ [ 100, 100 ], [ [1, 0 ], [0, 1] ], 1 ], - [ [ 100, 150 ], [ [1, 0 ], [0, 1] ], 1 ], - [ [ 125, 125 ], [ [1, 0 ], [0, 1] ], 1 ], - [ [ 110, 100 ], [ [1, 0 ], [0, 1] ], 1 ], - [ [ 90, 100 ], [ [1, 0 ], [0, 1] ], 1 ], - [ [ 100, 100 ], [ [1, 0 ], [0, 1] ], 2 ], - [ [ 100, 100 ], [ [5, 0 ], [0, 1] ], 1 ] ]) + assert np.linalg.norm(obt - exp) / np.linalg.norm(exp) < 10**-2 + + +@pytest.mark.parametrize( + "mean,cov,scale", + [ + [[100, 100], [[1, 0], [0, 1]], 1], + [[100, 150], [[1, 0], [0, 1]], 1], + [[125, 125], [[1, 0], [0, 1]], 1], + [[110, 100], [[1, 0], [0, 1]], 1], + [[90, 100], [[1, 0], [0, 1]], 1], + [[100, 100], [[1, 0], [0, 1]], 2], + [[100, 100], [[5, 0], [0, 1]], 1], + ], +) def test_moments2(mean, cov, scale, gaussian_pdf, domain_axes): + pdf, mesh = gaussian_pdf(mean, cov, domain_axes, scale) + mom_exp = np.array([scale, mean[0], mean[1], np.sqrt(cov[1][1]), np.sqrt(cov[0][0])]) - pdf, mesh = gaussian_pdf( mean, cov, domain_axes, scale ) - mom_exp = np.array([ scale, - mean[0], mean[1], - np.sqrt(cov[1][1]), np.sqrt(cov[0][0]) ]) + mom_obt = gauss.moments2(pdf) - mom_obt = gauss.moments2( pdf ) + assert np.allclose(mom_obt[:-1], mom_exp) + assert mom_obt[-1] is None # TODO: why? - assert( np.allclose( mom_obt[:-1], mom_exp ) ) - assert( mom_obt[-1] is None ) # TODO: why? - # we probably want to test rotation here at some point, but there is no way that it could work now, given # that moments2 assumes independence ... -@pytest.mark.parametrize('mean,cov,scale', [ [ [ 100, 100 ], [ 25, 25 ], 1 ], - [ [ 100, 100 ], [ 10, 25 ], 1 ], - [ [ 100, 110 ], [ 25, 25 ], 1 ], - [ [ 100, 110 ], [ 10, 25 ], 1 ], - [ [ 110, 100 ], [ 10, 25 ], 1 ] ]) +@pytest.mark.parametrize( + "mean,cov,scale", + [ + [[100, 100], [25, 25], 1], + [[100, 100], [10, 25], 1], + [[100, 110], [25, 25], 1], + [[100, 110], [10, 25], 1], + [[110, 100], [10, 25], 1], + ], +) def test_fitgaussian2D(mean, cov, scale, gaussian_pdf, domain_axes): + full_cov = [[cov[0], 0], [0, cov[1]]] + img, mesh = gaussian_pdf(mean, full_cov, domain_axes, scale) - full_cov = [ [ cov[0], 0 ], [ 0, cov[1] ] ] - img, mesh = gaussian_pdf( mean, full_cov, domain_axes, scale ) + obt = gauss.fitgaussian2D(img) + exp = [scale, mean[0], mean[1], np.sqrt(cov[0]), np.sqrt(cov[1]), 0] - obt = gauss.fitgaussian2D( img ) - exp = [ scale, mean[0], mean[1], np.sqrt(cov[0]), np.sqrt(cov[1]), 0 ] - - assert( np.allclose( exp, obt, atol=10**-3 ) ) + assert np.allclose(exp, obt, atol=10**-3) def test_fitgaussian2D_failure(): - data = np.eye(10) res = mock.MagicMock() res.success = False res.status = 3 - res.message = 'foo' + res.message = "foo" res.x = np.array([1.0, 1.0, 1.0, 0.0]) - with mock.patch('scipy.optimize.minimize', return_value=res): - with pytest.raises( gauss.GaussianFitError ): + with mock.patch("scipy.optimize.minimize", return_value=res): + with pytest.raises(gauss.GaussianFitError): gauss.fitgaussian2D(data) diff --git a/allensdk/test/brain_observatory/sync_utilities/conftest.py b/allensdk/test/brain_observatory/sync_utilities/conftest.py index aae110cf0c..4d32edff48 100644 --- a/allensdk/test/brain_observatory/sync_utilities/conftest.py +++ b/allensdk/test/brain_observatory/sync_utilities/conftest.py @@ -32,13 +32,11 @@ def line_name_fixture(): """ List of line names to use in test sync file """ - return ['lineA', '', 'lineB', 'lineC', 'lineD'] + return ["lineA", "", "lineB", "lineC", "lineD"] @pytest.fixture -def line_to_edges_fixture( - line_name_fixture, - sync_sample_fixture): +def line_to_edges_fixture(line_name_fixture, sync_sample_fixture): """ A dict mapping line name to lists of rising and falling edge indexes in the test sync file (note that the lists @@ -66,73 +64,68 @@ def line_to_edges_fixture( # every line starts out with some random initial # samles set to 1 - this_rising = [0, ] + this_rising = [ + 0, + ] this_falling = [rng.integers(2, 18)] for idx in range(0, len(changes), 2): this_rising.append(changes[idx]) - this_falling.append(changes[idx+1]) - result[line_name] = {'rising_idx': np.array(this_rising), - 'falling_idx': np.array(this_falling)} + this_falling.append(changes[idx + 1]) + result[line_name] = {"rising_idx": np.array(this_rising), "falling_idx": np.array(this_falling)} return result @pytest.fixture -def sync_metadata_fixture( - sync_freq_fixture, - line_name_fixture): +def sync_metadata_fixture(sync_freq_fixture, line_name_fixture): """ Dict representing 'meta' dataset in test sync file """ - metadata = {'ni_daq': - {'device': 'Dev1', - 'counter_output_freq': sync_freq_fixture, - 'sample_rate': sync_freq_fixture, - 'counter_bits': 32, - 'event_bits': 32}, - 'start_time': '2020-10-07 14:01:17.336502', - 'stop_time': '2020-10-07 16:42:24.177205', - 'line_labels': line_name_fixture, - 'timeouts': [], - 'version': '2.2.1+g1bc7438.b42257', - 'sampling_type': 'frequency', - 'file_version': '1.0.0', - 'line_label_revision': 3, - 'total_samples': 10000} + metadata = { + "ni_daq": { + "device": "Dev1", + "counter_output_freq": sync_freq_fixture, + "sample_rate": sync_freq_fixture, + "counter_bits": 32, + "event_bits": 32, + }, + "start_time": "2020-10-07 14:01:17.336502", + "stop_time": "2020-10-07 16:42:24.177205", + "line_labels": line_name_fixture, + "timeouts": [], + "version": "2.2.1+g1bc7438.b42257", + "sampling_type": "frequency", + "file_version": "1.0.0", + "line_label_revision": 3, + "total_samples": 10000, + } return metadata @pytest.fixture def sync_file_fixture( - sync_metadata_fixture, - line_name_fixture, - line_to_edges_fixture, - sync_sample_fixture, - tmp_path_factory): + sync_metadata_fixture, line_name_fixture, line_to_edges_fixture, sync_sample_fixture, tmp_path_factory +): """ Yields the path to a sync file for testing """ - tmpdir = pathlib.Path(tmp_path_factory.mktemp('external_sync_test')) - sync_path = pathlib.Path(tempfile.mkstemp(dir=tmpdir, suffix='sync')[1]) + tmpdir = pathlib.Path(tmp_path_factory.mktemp("external_sync_test")) + sync_path = pathlib.Path(tempfile.mkstemp(dir=tmpdir, suffix="sync")[1]) n_samples = len(sync_sample_fixture) data = np.zeros((n_samples, 2), dtype=np.uint32) data[:, 0] = sync_sample_fixture for pwr_of_2, line_name in enumerate(line_name_fixture): - this_rising = line_to_edges_fixture[line_name]['rising_idx'] - this_falling = line_to_edges_fixture[line_name]['falling_idx'] + this_rising = line_to_edges_fixture[line_name]["rising_idx"] + this_falling = line_to_edges_fixture[line_name]["falling_idx"] for rising, falling in zip(this_rising, this_falling): data[rising:falling, 1] += 2**pwr_of_2 - with h5py.File(sync_path, 'w') as out_file: - out_file.create_dataset( - 'data', - data=data) - out_file.create_dataset( - 'meta', - data=json.dumps(sync_metadata_fixture).encode('utf-8')) + with h5py.File(sync_path, "w") as out_file: + out_file.create_dataset("data", data=data) + out_file.create_dataset("meta", data=json.dumps(sync_metadata_fixture).encode("utf-8")) yield sync_path diff --git a/allensdk/test/brain_observatory/sync_utilities/test_sync_stim_alignment.py b/allensdk/test/brain_observatory/sync_utilities/test_sync_stim_alignment.py index 2fca707727..279ec51c03 100644 --- a/allensdk/test/brain_observatory/sync_utilities/test_sync_stim_alignment.py +++ b/allensdk/test/brain_observatory/sync_utilities/test_sync_stim_alignment.py @@ -11,58 +11,51 @@ _choose_line, _get_rising_times, _get_falling_times, - _get_line_starts_and_ends) + _get_line_starts_and_ends, +) -def test_choose_line( - sync_file_fixture): +def test_choose_line(sync_file_fixture): """ Test that _choose_line chooses the expected line """ with SyncDataset(sync_file_fixture) as data: - assert _choose_line(data, 'lineC') == 'lineC' - assert _choose_line(data, ('lineB', 'lineA')) == 'lineB' - assert _choose_line(data, ('lineA', 'lineB')) == 'lineA' - assert _choose_line(data, ('xxxx', 'lineD')) == 'lineD' - with pytest.raises(RuntimeError, match='Could not find one of'): - _choose_line(data, ('xxxx', 'yyyy')) - with pytest.raises(RuntimeError, match='Could not find one of'): - _choose_line(data, 'zzzz') + assert _choose_line(data, "lineC") == "lineC" + assert _choose_line(data, ("lineB", "lineA")) == "lineB" + assert _choose_line(data, ("lineA", "lineB")) == "lineA" + assert _choose_line(data, ("xxxx", "lineD")) == "lineD" + with pytest.raises(RuntimeError, match="Could not find one of"): + _choose_line(data, ("xxxx", "yyyy")) + with pytest.raises(RuntimeError, match="Could not find one of"): + _choose_line(data, "zzzz") @pytest.mark.parametrize( - "specified_lines, expected_line", - [('lineD', 'lineD'), - (('nonsense', 'lineC'), 'lineC'), - (('lineC', 'lineB'), 'lineC'), - (('lineB', 'lineC'), 'lineB')]) + "specified_lines, expected_line", + [ + ("lineD", "lineD"), + (("nonsense", "lineC"), "lineC"), + (("lineC", "lineB"), "lineC"), + (("lineB", "lineC"), "lineB"), + ], +) def test_get_rising_times( - sync_file_fixture, - sync_freq_fixture, - sync_sample_fixture, - line_to_edges_fixture, - specified_lines, - expected_line): + sync_file_fixture, sync_freq_fixture, sync_sample_fixture, line_to_edges_fixture, specified_lines, expected_line +): """ Test that _get_rising_times returns the expected timestamp arrays """ with SyncDataset(sync_file_fixture) as data: - actual = _get_rising_times( - data=data, - sync_lines=specified_lines) + actual = _get_rising_times(data=data, sync_lines=specified_lines) - expected_idx = line_to_edges_fixture[expected_line]['rising_idx'][1:] - expected_time = sync_sample_fixture[expected_idx]/sync_freq_fixture + expected_idx = line_to_edges_fixture[expected_line]["rising_idx"][1:] + expected_time = sync_sample_fixture[expected_idx] / sync_freq_fixture np.testing.assert_allclose(expected_time, actual) -@pytest.mark.parametrize( - "specified_lines", - ['lineZ', ('lineU', 'lineW')]) -def test_get_rising_times_exception( - sync_file_fixture, - specified_lines): +@pytest.mark.parametrize("specified_lines", ["lineZ", ("lineU", "lineW")]) +def test_get_rising_times_exception(sync_file_fixture, specified_lines): """ Test that _get_rising_times raises the expected exception when you specify non-existent lines @@ -70,72 +63,60 @@ def test_get_rising_times_exception( with SyncDataset(sync_file_fixture) as data: with pytest.raises(RuntimeError, match="Could not find one of"): - _get_rising_times( - data=data, - sync_lines=specified_lines) + _get_rising_times(data=data, sync_lines=specified_lines) @pytest.mark.parametrize( - "specified_lines, expected_line", - [('lineD', 'lineD'), - (('nonsense', 'lineC'), 'lineC'), - (('lineC', 'lineB'), 'lineC'), - (('lineB', 'lineC'), 'lineB')]) + "specified_lines, expected_line", + [ + ("lineD", "lineD"), + (("nonsense", "lineC"), "lineC"), + (("lineC", "lineB"), "lineC"), + (("lineB", "lineC"), "lineB"), + ], +) def test_get_falling_times( - sync_file_fixture, - sync_freq_fixture, - sync_sample_fixture, - line_to_edges_fixture, - specified_lines, - expected_line): + sync_file_fixture, sync_freq_fixture, sync_sample_fixture, line_to_edges_fixture, specified_lines, expected_line +): """ Test that _get_falling_times returns the expected timestamp arrays """ with SyncDataset(sync_file_fixture) as data: - actual = _get_falling_times( - data=data, - sync_lines=specified_lines) + actual = _get_falling_times(data=data, sync_lines=specified_lines) - expected_idx = line_to_edges_fixture[expected_line]['falling_idx'][1:] - expected_time = sync_sample_fixture[expected_idx]/sync_freq_fixture + expected_idx = line_to_edges_fixture[expected_line]["falling_idx"][1:] + expected_time = sync_sample_fixture[expected_idx] / sync_freq_fixture np.testing.assert_allclose(expected_time, actual) @pytest.mark.parametrize( - "specified_lines, expected_line", - [('lineD', 'lineD'), - (('nonsense', 'lineC'), 'lineC'), - (('lineC', 'lineB'), 'lineC'), - (('lineB', 'lineC'), 'lineB')]) + "specified_lines, expected_line", + [ + ("lineD", "lineD"), + (("nonsense", "lineC"), "lineC"), + (("lineC", "lineB"), "lineC"), + (("lineB", "lineC"), "lineB"), + ], +) def test_get_line_starts_and_ends( - sync_file_fixture, - sync_freq_fixture, - sync_sample_fixture, - line_to_edges_fixture, - specified_lines, - expected_line): + sync_file_fixture, sync_freq_fixture, sync_sample_fixture, line_to_edges_fixture, specified_lines, expected_line +): """ Test that _get_line_starts_and_ends works as expected """ with SyncDataset(sync_file_fixture) as data: - actual = _get_line_starts_and_ends( - data=data, - sync_lines=specified_lines) - start_idx = line_to_edges_fixture[expected_line]['rising_idx'][1:] - end_idx = line_to_edges_fixture[expected_line]['falling_idx'][1:] - start_times = sync_sample_fixture[start_idx]/sync_freq_fixture - end_times = sync_sample_fixture[end_idx]/sync_freq_fixture + actual = _get_line_starts_and_ends(data=data, sync_lines=specified_lines) + start_idx = line_to_edges_fixture[expected_line]["rising_idx"][1:] + end_idx = line_to_edges_fixture[expected_line]["falling_idx"][1:] + start_times = sync_sample_fixture[start_idx] / sync_freq_fixture + end_times = sync_sample_fixture[end_idx] / sync_freq_fixture np.testing.assert_allclose(actual[0], start_times) np.testing.assert_allclose(actual[1], end_times) -@pytest.mark.parametrize( - "specified_lines", - ['lineZ', ('lineU', 'lineW')]) -def test_get_falling_times_exception( - sync_file_fixture, - specified_lines): +@pytest.mark.parametrize("specified_lines", ["lineZ", ("lineU", "lineW")]) +def test_get_falling_times_exception(sync_file_fixture, specified_lines): """ Test that _get_falling_times raises the expected exception when you specify non-existent lines @@ -143,6 +124,4 @@ def test_get_falling_times_exception( with SyncDataset(sync_file_fixture) as data: with pytest.raises(RuntimeError, match="Could not find one of"): - _get_falling_times( - data=data, - sync_lines=specified_lines) + _get_falling_times(data=data, sync_lines=specified_lines) diff --git a/allensdk/test/brain_observatory/sync_utilities/test_sync_stim_get_start_frames.py b/allensdk/test/brain_observatory/sync_utilities/test_sync_stim_get_start_frames.py index c3a74339d6..ec8b56b5da 100644 --- a/allensdk/test/brain_observatory/sync_utilities/test_sync_stim_get_start_frames.py +++ b/allensdk/test/brain_observatory/sync_utilities/test_sync_stim_get_start_frames.py @@ -1,9 +1,7 @@ import pytest import numpy as np from allensdk.brain_observatory.sync_dataset import Dataset as SyncDataset -from allensdk.brain_observatory.sync_stim_aligner import ( - _get_start_frames, - get_stim_timestamps_from_stimulus_blocks) +from allensdk.brain_observatory.sync_stim_aligner import _get_start_frames, get_stim_timestamps_from_stimulus_blocks class DummyStim(object): @@ -23,12 +21,11 @@ def num_frames(self): @pytest.fixture def line_name_fixture(): - return ['lineA', 'stim_running', 'lineB', 'vsync_stim'] + return ["lineA", "stim_running", "lineB", "vsync_stim"] @pytest.fixture -def line_to_edges_fixture( - sync_sample_fixture): +def line_to_edges_fixture(sync_sample_fixture): n_samples = len(sync_sample_fixture) result = dict() @@ -36,23 +33,21 @@ def line_to_edges_fixture( # fill lineA and lineB with random bits rng = np.random.default_rng(66123) indexes = np.arange(0, n_samples, 3) - for line_name in ('lineA', 'lineB'): + for line_name in ("lineA", "lineB"): changes = rng.choice(indexes, 50, replace=False) changes = np.sort(changes) this_rising = [] this_falling = [] for idx in range(0, len(changes), 2): this_rising.append(changes[idx]) - this_falling.append(changes[idx+1]) - result[line_name] = {'rising_idx': np.array(this_rising), - 'falling_idx': np.array(this_falling)} + this_falling.append(changes[idx + 1]) + result[line_name] = {"rising_idx": np.array(this_rising), "falling_idx": np.array(this_falling)} # create four intentional blocks of length # 33, 66, 99, 132 in the stim_running line this_rising = [12, 100, 204, 500] this_falling = [45, 166, 303, 632] - result['stim_running'] = {'rising_idx': np.array(this_rising), - 'falling_idx': np.array(this_falling)} + result["stim_running"] = {"rising_idx": np.array(this_rising), "falling_idx": np.array(this_falling)} # set vsync_stim lines; # because we are placing an edge every three frames, @@ -60,23 +55,21 @@ def line_to_edges_fixture( # be 1/3 that specified in the above block v_rising = np.arange(4, n_samples, 3, dtype=int) v_falling = v_rising + 1 - result['vsync_stim'] = {'rising_idx': v_rising, - 'falling_idx': v_falling} + result["vsync_stim"] = {"rising_idx": v_rising, "falling_idx": v_falling} return result @pytest.fixture -def expected_start_frames_fixture( - line_to_edges_fixture): +def expected_start_frames_fixture(line_to_edges_fixture): """ Return dict that maps 'rising', 'falling' to the expected start frames for all of the stimulus blocks in our test sync file """ result = dict() - for edge_type in ('rising', 'falling'): - frame_edges = line_to_edges_fixture['vsync_stim'][f'{edge_type}_idx'] - stim_edge_list = line_to_edges_fixture['stim_running']['rising_idx'] + for edge_type in ("rising", "falling"): + frame_edges = line_to_edges_fixture["vsync_stim"][f"{edge_type}_idx"] + stim_edge_list = line_to_edges_fixture["stim_running"]["rising_idx"] expected_idx = [] for stim_edge in stim_edge_list: this_idx = np.where(frame_edges >= stim_edge)[0].min() @@ -85,11 +78,7 @@ def expected_start_frames_fixture( return result -def test_get_start_frames_exact( - sync_file_fixture, - line_to_edges_fixture, - sync_sample_fixture, - sync_freq_fixture): +def test_get_start_frames_exact(sync_file_fixture, line_to_edges_fixture, sync_sample_fixture, sync_freq_fixture): """ Test case where _get_frame_offsets is expected to return exact matches @@ -97,77 +86,56 @@ def test_get_start_frames_exact( with SyncDataset(sync_file_fixture) as sync_data: start_frames = _get_start_frames( - data=sync_data, - raw_frame_times=sync_sample_fixture/sync_freq_fixture, - stimulus_frame_counts=[44, 55, 66, 77], - tolerance=0.0) - np.testing.assert_array_equal( - start_frames, - line_to_edges_fixture['stim_running']['rising_idx']) + data=sync_data, + raw_frame_times=sync_sample_fixture / sync_freq_fixture, + stimulus_frame_counts=[44, 55, 66, 77], + tolerance=0.0, + ) + np.testing.assert_array_equal(start_frames, line_to_edges_fixture["stim_running"]["rising_idx"]) -@pytest.mark.parametrize( - "stimulus_frame_counts, expected", - [([33, 66, 132], [12, 100, 500]), - ([66, 132], [100, 500])]) +@pytest.mark.parametrize("stimulus_frame_counts, expected", [([33, 66, 132], [12, 100, 500]), ([66, 132], [100, 500])]) def test_get_start_frames_skip_one( - sync_file_fixture, - sync_sample_fixture, - sync_freq_fixture, - stimulus_frame_counts, - expected): + sync_file_fixture, sync_sample_fixture, sync_freq_fixture, stimulus_frame_counts, expected +): """ Test the case where one of the blocks in stim_running is erroneous """ with SyncDataset(sync_file_fixture) as sync_data: start_frames = _get_start_frames( - data=sync_data, - raw_frame_times=sync_sample_fixture/sync_freq_fixture, - stimulus_frame_counts=stimulus_frame_counts, - tolerance=0.0) - np.testing.assert_array_equal( - start_frames, - expected) + data=sync_data, + raw_frame_times=sync_sample_fixture / sync_freq_fixture, + stimulus_frame_counts=stimulus_frame_counts, + tolerance=0.0, + ) + np.testing.assert_array_equal(start_frames, expected) @pytest.mark.parametrize( - "tolerance, stimulus_frame_counts, expected", - [(0.1, [35, 61, 127], [12, 100, 500]), - (0.05, [32, 68, 136], [12, 100, 500]), - (0.05, [67, 96], [100, 204])]) + "tolerance, stimulus_frame_counts, expected", + [(0.1, [35, 61, 127], [12, 100, 500]), (0.05, [32, 68, 136], [12, 100, 500]), (0.05, [67, 96], [100, 204])], +) def test_get_start_frames_tolerance( - sync_file_fixture, - sync_sample_fixture, - sync_freq_fixture, - tolerance, - stimulus_frame_counts, - expected): + sync_file_fixture, sync_sample_fixture, sync_freq_fixture, tolerance, stimulus_frame_counts, expected +): """ Test that _get_start_frames correctly infers starting frames within tolerance """ with SyncDataset(sync_file_fixture) as sync_data: start_frames = _get_start_frames( - data=sync_data, - raw_frame_times=sync_sample_fixture/sync_freq_fixture, - stimulus_frame_counts=stimulus_frame_counts, - tolerance=tolerance) - np.testing.assert_array_equal( - start_frames, - expected) + data=sync_data, + raw_frame_times=sync_sample_fixture / sync_freq_fixture, + stimulus_frame_counts=stimulus_frame_counts, + tolerance=tolerance, + ) + np.testing.assert_array_equal(start_frames, expected) -@pytest.mark.parametrize( - "tolerance, stimulus_frame_counts", - [(0.1, [32, 40, 99]), - (0.05, [33, 66, 300])]) +@pytest.mark.parametrize("tolerance, stimulus_frame_counts", [(0.1, [32, 40, 99]), (0.05, [33, 66, 300])]) def test_get_start_frames_tolerance_failures( - sync_file_fixture, - line_to_edges_fixture, - sync_sample_fixture, - sync_freq_fixture, - tolerance, - stimulus_frame_counts): + sync_file_fixture, line_to_edges_fixture, sync_sample_fixture, sync_freq_fixture, tolerance, stimulus_frame_counts +): """ Test that _get_start_frames correctly fails when the best guess is outside of the specified tolerance @@ -175,54 +143,51 @@ def test_get_start_frames_tolerance_failures( with SyncDataset(sync_file_fixture) as sync_data: with pytest.raises(RuntimeError, match="Could not find matching sync"): _get_start_frames( - data=sync_data, - raw_frame_times=sync_sample_fixture/sync_freq_fixture, - stimulus_frame_counts=stimulus_frame_counts, - tolerance=tolerance) + data=sync_data, + raw_frame_times=sync_sample_fixture / sync_freq_fixture, + stimulus_frame_counts=stimulus_frame_counts, + tolerance=tolerance, + ) -def test_get_start_frames_too_many_pkl( - sync_file_fixture, - sync_sample_fixture, - sync_freq_fixture): +def test_get_start_frames_too_many_pkl(sync_file_fixture, sync_sample_fixture, sync_freq_fixture): """ Test the case where you specify too many stimulus_frame_counts """ with SyncDataset(sync_file_fixture) as sync_data: with pytest.raises(RuntimeError, match="more pkl frame count entries"): _get_start_frames( - data=sync_data, - raw_frame_times=sync_sample_fixture/sync_freq_fixture, - stimulus_frame_counts=[44, 55, 77, 100, 300, 55], - tolerance=0.0) + data=sync_data, + raw_frame_times=sync_sample_fixture / sync_freq_fixture, + stimulus_frame_counts=[44, 55, 77, 100, 300, 55], + tolerance=0.0, + ) -def test_user_facing_get_stims_error( - sync_file_fixture): +def test_user_facing_get_stims_error(sync_file_fixture): """ Make sure that get_stim_timestamps_from_stimulus_blocks raises the expected error if you give it a bad raw_frame_time_direction """ - with pytest.raises(ValueError, - match="Cannot parse raw_frame_time_direction"): + with pytest.raises(ValueError, match="Cannot parse raw_frame_time_direction"): get_stim_timestamps_from_stimulus_blocks( - stimulus_files=[DummyStim(n_frames=2), - DummyStim(n_frames=4)], + stimulus_files=[DummyStim(n_frames=2), DummyStim(n_frames=4)], sync_file=sync_file_fixture, - raw_frame_time_lines='vsync_stim', - raw_frame_time_direction='nonsense', - frame_count_tolerance=0.0) + raw_frame_time_lines="vsync_stim", + raw_frame_time_direction="nonsense", + frame_count_tolerance=0.0, + ) -@pytest.mark.parametrize( - "edge_type", ["rising", "falling"]) +@pytest.mark.parametrize("edge_type", ["rising", "falling"]) def test_user_facing_get_stim_timestamps_smoke( - sync_file_fixture, - edge_type, - expected_start_frames_fixture, - line_to_edges_fixture, - sync_sample_fixture, - sync_freq_fixture): + sync_file_fixture, + edge_type, + expected_start_frames_fixture, + line_to_edges_fixture, + sync_sample_fixture, + sync_freq_fixture, +): """ Test user-facing get_stim_timestaps_from_stimulus_blocks in case where the number of blocks in stim_running matches @@ -234,73 +199,71 @@ def test_user_facing_get_stim_timestamps_smoke( # found by analyzing stim_running matches the expected # number of blocks exactly, the start_frames of those # blocks will be returned - stim_list = [DummyStim(n_frames=33), - DummyStim(n_frames=66), - DummyStim(n_frames=99), - DummyStim(n_frames=132)] + stim_list = [DummyStim(n_frames=33), DummyStim(n_frames=66), DummyStim(n_frames=99), DummyStim(n_frames=132)] result = get_stim_timestamps_from_stimulus_blocks( - stimulus_files=stim_list, - sync_file=sync_file_fixture, - raw_frame_time_lines='vsync_stim', - raw_frame_time_direction=edge_type, - frame_count_tolerance=0.0) + stimulus_files=stim_list, + sync_file=sync_file_fixture, + raw_frame_time_lines="vsync_stim", + raw_frame_time_direction=edge_type, + frame_count_tolerance=0.0, + ) assert len(result["timestamps"]) == 4 - for ii, (this_array, - this_start_frame, - this_stim) in enumerate(zip(result["timestamps"], - result["start_frames"], - stim_list)): - raw_idx = line_to_edges_fixture['vsync_stim'][f'{edge_type}_idx'] - raw_times = sync_sample_fixture[raw_idx]/sync_freq_fixture + for ii, (this_array, this_start_frame, this_stim) in enumerate( + zip(result["timestamps"], result["start_frames"], stim_list) + ): + raw_idx = line_to_edges_fixture["vsync_stim"][f"{edge_type}_idx"] + raw_times = sync_sample_fixture[raw_idx] / sync_freq_fixture idx0 = expected_start_frames_fixture[edge_type][ii] - expected = raw_times[idx0: idx0+this_stim.num_frames] + expected = raw_times[idx0 : idx0 + this_stim.num_frames] np.testing.assert_array_equal(this_array, expected) assert this_start_frame == expected_start_frames_fixture[edge_type][ii] @pytest.mark.parametrize( - "edge_type, stim_frame_inputs, tolerance, expected_idx", - [('rising', (11, 33, 44), 0.0, [0, 2, 3]), - ('falling', (11, 33, 44), 0.0, [0, 2, 3]), - ('rising', (21, 34, 45), 0.05, [1, 2, 3]), - ('falling', (31, 47), 0.1, [2, 3])]) + "edge_type, stim_frame_inputs, tolerance, expected_idx", + [ + ("rising", (11, 33, 44), 0.0, [0, 2, 3]), + ("falling", (11, 33, 44), 0.0, [0, 2, 3]), + ("rising", (21, 34, 45), 0.05, [1, 2, 3]), + ("falling", (31, 47), 0.1, [2, 3]), + ], +) def test_user_facing_get_stim_timestamps( - sync_file_fixture, - edge_type, - stim_frame_inputs, - tolerance, - expected_idx, - expected_start_frames_fixture, - line_to_edges_fixture, - sync_sample_fixture, - sync_freq_fixture): + sync_file_fixture, + edge_type, + stim_frame_inputs, + tolerance, + expected_idx, + expected_start_frames_fixture, + line_to_edges_fixture, + sync_sample_fixture, + sync_freq_fixture, +): """ Test the user-facing get_start_frames_from_stimulus_blocks in cases of differing stimulus specifications and tolerances """ stim_list = [DummyStim(n_frames=n) for n in stim_frame_inputs] - expected_start = [expected_start_frames_fixture[edge_type][idx] - for idx in expected_idx] + expected_start = [expected_start_frames_fixture[edge_type][idx] for idx in expected_idx] result = get_stim_timestamps_from_stimulus_blocks( - stimulus_files=stim_list, - sync_file=sync_file_fixture, - raw_frame_time_lines='vsync_stim', - raw_frame_time_direction=edge_type, - frame_count_tolerance=tolerance) + stimulus_files=stim_list, + sync_file=sync_file_fixture, + raw_frame_time_lines="vsync_stim", + raw_frame_time_direction=edge_type, + frame_count_tolerance=tolerance, + ) assert len(result["timestamps"]) == len(expected_idx) - for ii, (this_array, - this_start_frame, - this_stim) in enumerate(zip(result["timestamps"], - result["start_frames"], - stim_list)): - raw_idx = line_to_edges_fixture['vsync_stim'][f'{edge_type}_idx'] - raw_times = sync_sample_fixture[raw_idx]/sync_freq_fixture + for ii, (this_array, this_start_frame, this_stim) in enumerate( + zip(result["timestamps"], result["start_frames"], stim_list) + ): + raw_idx = line_to_edges_fixture["vsync_stim"][f"{edge_type}_idx"] + raw_times = sync_sample_fixture[raw_idx] / sync_freq_fixture idx0 = expected_start[ii] - expected = raw_times[idx0: idx0+this_stim.num_frames] + expected = raw_times[idx0 : idx0 + this_stim.num_frames] np.testing.assert_array_equal(this_array, expected) assert this_start_frame == expected_start[ii] diff --git a/allensdk/test/brain_observatory/sync_utilities/test_sync_utilities.py b/allensdk/test/brain_observatory/sync_utilities/test_sync_utilities.py index 3aaf2c8576..8ba52662a7 100644 --- a/allensdk/test/brain_observatory/sync_utilities/test_sync_utilities.py +++ b/allensdk/test/brain_observatory/sync_utilities/test_sync_utilities.py @@ -8,15 +8,14 @@ class MockDataset(Dataset): - def __init__(self, path: str, - eye_tracking_timings, behavior_tracking_timings): + def __init__(self, path: str, eye_tracking_timings, behavior_tracking_timings): # Note: eye_tracking_timings and behavior_tracking_timings are test # inputs that can be parametrized and do not exist in the real # `Dataset` class. self.eye_tracking_timings = eye_tracking_timings self.behavior_tracking_timings = behavior_tracking_timings - def get_edges(self, kind, keys, units='seconds'): + def get_edges(self, kind, keys, units="seconds"): if keys == self.EYE_TRACKING_KEYS: return self.eye_tracking_timings elif keys == self.BEHAVIOR_TRACKING_KEYS: @@ -25,17 +24,12 @@ def get_edges(self, kind, keys, units='seconds'): @pytest.fixture def mock_dataset_fixture(request): - test_params = { - "eye_tracking_timings": [], - "behavior_tracking_timings": [] - } + test_params = {"eye_tracking_timings": [], "behavior_tracking_timings": []} test_params.update(request.param) return partial(MockDataset, **test_params) -@pytest.mark.parametrize('vs_times, expected', [ - [[0.016, 0.033, 0.051, 0.067, 3.0], [0.016, 0.033, 0.051, 0.067]] -]) +@pytest.mark.parametrize("vs_times, expected", [[[0.016, 0.033, 0.051, 0.067, 3.0], [0.016, 0.033, 0.051, 0.067]]]) def test_trim_discontiguous_vsyncs(vs_times, expected): obtained = su.trim_discontiguous_times(vs_times) assert np.allclose(obtained, expected) @@ -44,136 +38,124 @@ def test_trim_discontiguous_vsyncs(vs_times, expected): @pytest.mark.parametrize( "mock_dataset_fixture,sync_line_label_keys,expected", [ - ({"eye_tracking_timings": [0.020, 0.030, 0.040, 0.050, 3.0]}, - Dataset.EYE_TRACKING_KEYS, [0.020, 0.030, 0.040, 0.050]), - - ({"behavior_tracking_timings": [0.080, 0.090, 0.100, 0.110, 8.0]}, - Dataset.BEHAVIOR_TRACKING_KEYS, [0.08, 0.090, 0.100, 0.110]) + ( + {"eye_tracking_timings": [0.020, 0.030, 0.040, 0.050, 3.0]}, + Dataset.EYE_TRACKING_KEYS, + [0.020, 0.030, 0.040, 0.050], + ), + ( + {"behavior_tracking_timings": [0.080, 0.090, 0.100, 0.110, 8.0]}, + Dataset.BEHAVIOR_TRACKING_KEYS, + [0.08, 0.090, 0.100, 0.110], + ), ], - indirect=["mock_dataset_fixture"]) -def test_get_synchronized_frame_times(monkeypatch, mock_dataset_fixture, - sync_line_label_keys, expected): + indirect=["mock_dataset_fixture"], +) +def test_get_synchronized_frame_times(monkeypatch, mock_dataset_fixture, sync_line_label_keys, expected): monkeypatch.setattr(su, "Dataset", mock_dataset_fixture) - obtained = su.get_synchronized_frame_times( - "dummy_path", - sync_line_label_keys - ) + obtained = su.get_synchronized_frame_times("dummy_path", sync_line_label_keys) assert np.allclose(obtained, expected) @pytest.mark.parametrize( "mock_dataset_fixture,sync_line_label_keys,expected", [ - ({"eye_tracking_timings": [0.020, 0.030, 0.040, 0.050, 3.0]}, - Dataset.EYE_TRACKING_KEYS, [0.020, 0.030, 0.040, 0.050, 3.0]), - - ({"behavior_tracking_timings": [0.080, 0.090, 0.100, 0.110, 8.0]}, - Dataset.BEHAVIOR_TRACKING_KEYS, [0.08, 0.090, 0.100, 0.110, 8.0]) - ], indirect=["mock_dataset_fixture"]) -def test_get_synchronized_frame_times_no_trim( - monkeypatch, mock_dataset_fixture, sync_line_label_keys, expected -): + ( + {"eye_tracking_timings": [0.020, 0.030, 0.040, 0.050, 3.0]}, + Dataset.EYE_TRACKING_KEYS, + [0.020, 0.030, 0.040, 0.050, 3.0], + ), + ( + {"behavior_tracking_timings": [0.080, 0.090, 0.100, 0.110, 8.0]}, + Dataset.BEHAVIOR_TRACKING_KEYS, + [0.08, 0.090, 0.100, 0.110, 8.0], + ), + ], + indirect=["mock_dataset_fixture"], +) +def test_get_synchronized_frame_times_no_trim(monkeypatch, mock_dataset_fixture, sync_line_label_keys, expected): monkeypatch.setattr(su, "Dataset", mock_dataset_fixture) - obtained = su.get_synchronized_frame_times( - "dummy_path", - sync_line_label_keys, - trim_after_spike=False - ) + obtained = su.get_synchronized_frame_times("dummy_path", sync_line_label_keys, trim_after_spike=False) assert np.allclose(obtained, expected) @pytest.mark.parametrize( "mock_dataset_fixture,sync_line_label_keys,expected", [ - ({"eye_tracking_timings": - [0.020, 0.030, 3.0, 0.040, 0.050, 0.040]}, - Dataset.EYE_TRACKING_KEYS, [0.020, 0.030]), - - ({"behavior_tracking_timings": - [0.080, 8.0, 0.090, 0.100, 0.110, 0.150, 0.085, 0.110, 0.13]}, - Dataset.BEHAVIOR_TRACKING_KEYS, [0.08]) - ], indirect=["mock_dataset_fixture"]) + ({"eye_tracking_timings": [0.020, 0.030, 3.0, 0.040, 0.050, 0.040]}, Dataset.EYE_TRACKING_KEYS, [0.020, 0.030]), + ( + {"behavior_tracking_timings": [0.080, 8.0, 0.090, 0.100, 0.110, 0.150, 0.085, 0.110, 0.13]}, + Dataset.BEHAVIOR_TRACKING_KEYS, + [0.08], + ), + ], + indirect=["mock_dataset_fixture"], +) def test_get_synchronized_frame_times_trim_with_spike( monkeypatch, mock_dataset_fixture, sync_line_label_keys, expected ): monkeypatch.setattr(su, "Dataset", mock_dataset_fixture) - obtained = su.get_synchronized_frame_times( - "dummy_path", sync_line_label_keys - ) + obtained = su.get_synchronized_frame_times("dummy_path", sync_line_label_keys) assert np.allclose(obtained, expected) @pytest.mark.parametrize( "mock_dataset_fixture,sync_line_label_keys,expected", [ - ({"eye_tracking_timings": [3.0, 0.030, 0.040, 0.050]}, - Dataset.EYE_TRACKING_KEYS, []), - - ({"behavior_tracking_timings": [8.0, 0.080, 0.090, 0.100, 0.110]}, - Dataset.BEHAVIOR_TRACKING_KEYS, []) - ], indirect=["mock_dataset_fixture"]) -def test_get_synchronized_frame_times_trim_all( - monkeypatch, - mock_dataset_fixture, - sync_line_label_keys, - expected -): - + ({"eye_tracking_timings": [3.0, 0.030, 0.040, 0.050]}, Dataset.EYE_TRACKING_KEYS, []), + ({"behavior_tracking_timings": [8.0, 0.080, 0.090, 0.100, 0.110]}, Dataset.BEHAVIOR_TRACKING_KEYS, []), + ], + indirect=["mock_dataset_fixture"], +) +def test_get_synchronized_frame_times_trim_all(monkeypatch, mock_dataset_fixture, sync_line_label_keys, expected): monkeypatch.setattr(su, "Dataset", mock_dataset_fixture) - obtained = su.get_synchronized_frame_times( - "dummy_path", - sync_line_label_keys - ) + obtained = su.get_synchronized_frame_times("dummy_path", sync_line_label_keys) assert np.allclose(obtained, expected) @pytest.mark.parametrize( "mock_dataset_fixture,sync_line_label_keys,expected", [ - ({"eye_tracking_timings": [0.020, 0.030, 3.0, 0.050, 0.040]}, - Dataset.EYE_TRACKING_KEYS, [0.020, 0.030, 3.0, 0.050, 0.040]), - - ({"behavior_tracking_timings": [0.080, 8.0, 0.090, 0.100, 0.110]}, - Dataset.BEHAVIOR_TRACKING_KEYS, [0.080, 8.0, 0.090, 0.100, 0.110]) + ( + {"eye_tracking_timings": [0.020, 0.030, 3.0, 0.050, 0.040]}, + Dataset.EYE_TRACKING_KEYS, + [0.020, 0.030, 3.0, 0.050, 0.040], + ), + ( + {"behavior_tracking_timings": [0.080, 8.0, 0.090, 0.100, 0.110]}, + Dataset.BEHAVIOR_TRACKING_KEYS, + [0.080, 8.0, 0.090, 0.100, 0.110], + ), ], - indirect=["mock_dataset_fixture"]) + indirect=["mock_dataset_fixture"], +) def test_get_synchronized_frame_times_no_trim_with_spike( - monkeypatch, - mock_dataset_fixture, - sync_line_label_keys, - expected + monkeypatch, mock_dataset_fixture, sync_line_label_keys, expected ): monkeypatch.setattr(su, "Dataset", mock_dataset_fixture) - obtained = su.get_synchronized_frame_times( - "dummy_path", - sync_line_label_keys, - trim_after_spike=False - ) + obtained = su.get_synchronized_frame_times("dummy_path", sync_line_label_keys, trim_after_spike=False) assert np.allclose(obtained, expected) @pytest.mark.parametrize( "mock_dataset_fixture,sync_line_label_keys,expected", [ - ({"eye_tracking_timings": [0.020, 0.030, 0.040, 0.050, 3.0]}, - Dataset.EYE_TRACKING_KEYS, [0.020, 0.030, 0.050]), - - ({"behavior_tracking_timings": [0.080, 0.090, 0.100, 0.110, 8.0]}, - Dataset.BEHAVIOR_TRACKING_KEYS, [0.080, 0.090, 0.110]) - ], indirect=["mock_dataset_fixture"]) -def test_get_synchronized_frame_times_drop_frame( - monkeypatch, mock_dataset_fixture, sync_line_label_keys, expected -): + ({"eye_tracking_timings": [0.020, 0.030, 0.040, 0.050, 3.0]}, Dataset.EYE_TRACKING_KEYS, [0.020, 0.030, 0.050]), + ( + {"behavior_tracking_timings": [0.080, 0.090, 0.100, 0.110, 8.0]}, + Dataset.BEHAVIOR_TRACKING_KEYS, + [0.080, 0.090, 0.110], + ), + ], + indirect=["mock_dataset_fixture"], +) +def test_get_synchronized_frame_times_drop_frame(monkeypatch, mock_dataset_fixture, sync_line_label_keys, expected): monkeypatch.setattr(su, "Dataset", mock_dataset_fixture) - obtained = su.get_synchronized_frame_times( - "dummy_path", - sync_line_label_keys, - drop_frames=[2] - ) + obtained = su.get_synchronized_frame_times("dummy_path", sync_line_label_keys, drop_frames=[2]) assert np.allclose(obtained, expected) diff --git a/allensdk/test/brain_observatory/test_circle_plots.py b/allensdk/test/brain_observatory/test_circle_plots.py index fd4a5cb753..5a50a71856 100644 --- a/allensdk/test/brain_observatory/test_circle_plots.py +++ b/allensdk/test/brain_observatory/test_circle_plots.py @@ -36,6 +36,7 @@ import allensdk.brain_observatory.circle_plots as cplots import numpy as np + def test_polar_to_xy(): d = cplots.polar_to_xy([0], 0.0) assert d.shape[0] == 1 @@ -43,61 +44,69 @@ def test_polar_to_xy(): d = cplots.polar_to_xy([0, np.pi], 1.0) - assert np.allclose(d, [[ 1, 0 ], [-1, 0]]) + assert np.allclose(d, [[1, 0], [-1, 0]]) + def test_polar_linspace(): d = cplots.polar_linspace(1, 0, 180, 2, endpoint=True, degrees=True) - assert np.allclose(d, [[1,0],[-1,0]]) + assert np.allclose(d, [[1, 0], [-1, 0]]) d = cplots.polar_linspace(1, 0, np.pi, 2, endpoint=False, degrees=False) - assert np.allclose(d, [[1,0],[0,1.0]]) + assert np.allclose(d, [[1, 0], [0, 1.0]]) - d = cplots.polar_linspace(2, 0, 2*np.pi, 4, endpoint=False, degrees=False) - assert np.allclose(d, [[2,0],[0,2],[-2,0],[0,-2]]) + d = cplots.polar_linspace(2, 0, 2 * np.pi, 4, endpoint=False, degrees=False) + assert np.allclose(d, [[2, 0], [0, 2], [-2, 0], [0, -2]]) d = cplots.polar_linspace(3, 0, 360, 5, endpoint=True, degrees=True) - assert np.allclose(d, [[3,0],[0,3],[-3,0],[0,-3],[3,0]]) + assert np.allclose(d, [[3, 0], [0, 3], [-3, 0], [0, -3], [3, 0]]) + def test_spiral_trials(): - coll = cplots.spiral_trials([0,2]) + coll = cplots.spiral_trials([0, 2]) assert len(coll.get_paths()) == 2 + def test_spiral_trials_polar(): coll = cplots.spiral_trials_polar(1.0, 0.0, [1.0]) assert len(coll.get_paths()) == 1 - coll = cplots.spiral_trials_polar(1.0, 0.0, [1.0], offset=[1,0]) + coll = cplots.spiral_trials_polar(1.0, 0.0, [1.0], offset=[1, 0]) assert len(coll.get_paths()) == 1 + def test_angle_lines(): lines = cplots.angle_lines([0], 0, 1) assert len(lines.get_paths()) == 1 - lines = cplots.angle_lines([0,1], 0, 1) + lines = cplots.angle_lines([0, 1], 0, 1) assert len(lines.get_paths()) == 2 + def test_radial_arcs(): arcs = cplots.radial_arcs([1], 0, 1) assert len(arcs.get_paths()) == 1 - arcs = cplots.radial_arcs([1,2], 0, 1) + arcs = cplots.radial_arcs([1, 2], 0, 1) assert len(arcs.get_paths()) == 2 + def test_radial_circles(): d = cplots.radial_circles([1]) assert len(d.get_paths()) == 1 - d = cplots.radial_circles([1,2]) + d = cplots.radial_circles([1, 2]) assert len(d.get_paths()) == 2 + def test_polar_line_circles(): - d = cplots.polar_line_circles([1],0) + d = cplots.polar_line_circles([1], 0) assert len(d.get_paths()) == 1 - d = cplots.polar_line_circles([1],0,0) + d = cplots.polar_line_circles([1], 0, 0) assert len(d.get_paths()) == 1 + def test_wedge_ring(): d = cplots.wedge_ring(1, 0, 1, 0, 180) assert len(d.get_paths()) == 1 @@ -105,26 +114,24 @@ def test_wedge_ring(): d = cplots.wedge_ring(2, 0, 1) assert len(d.get_paths()) == 2 + def test_reset_hex_pack(): cplots.hex_pack(1.0, 1) cplots.reset_hex_pack() assert len(cplots.HEX_POSITIONS) == 0 - + + def test_hex_pack(): cplots.reset_hex_pack() pos = cplots.hex_pack(1.0, 1) assert pos.shape[0] == 1 - assert np.allclose(pos, [[0,0]]) - assert np.allclose(cplots.HEX_POSITIONS.shape, [1,2]) + assert np.allclose(pos, [[0, 0]]) + assert np.allclose(cplots.HEX_POSITIONS.shape, [1, 2]) pos = cplots.hex_pack(2.0, 2) - assert np.allclose(pos, [[0,0],[4,0]]) - assert np.allclose(cplots.HEX_POSITIONS.shape, [7,2]) + assert np.allclose(pos, [[0, 0], [4, 0]]) + assert np.allclose(cplots.HEX_POSITIONS.shape, [7, 2]) pos = cplots.hex_pack(2.0, 8) - assert np.allclose(cplots.HEX_POSITIONS.shape, [19,2]) - - - - + assert np.allclose(cplots.HEX_POSITIONS.shape, [19, 2]) diff --git a/allensdk/test/brain_observatory/test_demixer.py b/allensdk/test/brain_observatory/test_demixer.py index 66d7bca345..8c754f7fe1 100644 --- a/allensdk/test/brain_observatory/test_demixer.py +++ b/allensdk/test/brain_observatory/test_demixer.py @@ -8,82 +8,78 @@ "source_frame,mask_traces,flat_masks,pixels_per_mask,expected", [ ( - np.array([2., 2., 2., 1.]), + np.array([2.0, 2.0, 2.0, 1.0]), np.array([2.0, 2.0]), sparse.csr_matrix(np.array([[1, 0, 0, 0], [1, 1, 0, 0]])), np.array([1, 2]), np.array([0, 2]), ), ( - np.array([2., 0., 2., 1.]), - np.array([2.0, 0.]), # zero in mask trace + np.array([2.0, 0.0, 2.0, 1.0]), + np.array([2.0, 0.0]), # zero in mask trace sparse.csr_matrix(np.array([[1, 0, 0, 0], [1, 1, 0, 0]])), np.array([1, 2]), None, ), ( - np.array([2., 0., 2., 1.]), - np.array([2.0, 0.]), + np.array([2.0, 0.0, 2.0, 1.0]), + np.array([2.0, 0.0]), sparse.csr_matrix(np.array([[1, 0, 0, 0], [0, 0, 0, 0]])), - np.array([1, 0]), # invalid mask (zero pixels) + np.array([1, 0]), # invalid mask (zero pixels) None, ), ( - np.zeros(4), # singular overlap matrix + np.zeros(4), # singular overlap matrix np.ones(2), sparse.csr_matrix(np.array([[1, 0, 0, 0], [1, 1, 0, 0]])), np.array([1, 2]), np.zeros(2), - ) - ] + ), + ], ) -def test_demix_point( - source_frame, mask_traces, flat_masks, pixels_per_mask, expected): - result = dmx._demix_point(source_frame, mask_traces, flat_masks, - pixels_per_mask) +def test_demix_point(source_frame, mask_traces, flat_masks, pixels_per_mask, expected): + result = dmx._demix_point(source_frame, mask_traces, flat_masks, pixels_per_mask) np.testing.assert_equal(result, expected) - @pytest.mark.parametrize( "raw_traces,stack,masks,max_block_size,expected", [ ( np.array([[2.0, 0.0], [2.0, 2.0]]), - np.array([[[2., 2.], [2., 1.]], [[2., 2.], [2., 1.]]]), + np.array([[[2.0, 2.0], [2.0, 1.0]], [[2.0, 2.0], [2.0, 1.0]]]), np.array([[[1, 0], [0, 0]], [[1, 1], [0, 0]]]), - 1, # max_block_size < stack length - (np.array([[0, 0], [2, 0]]), [False, True]) + 1, # max_block_size < stack length + (np.array([[0, 0], [2, 0]]), [False, True]), ), ( np.array([[2.0, 0.0], [2.0, 2.0]]), - np.array([[[2., 2.], [2., 1.]], [[2., 2.], [2., 1.]]]), + np.array([[[2.0, 2.0], [2.0, 1.0]], [[2.0, 2.0], [2.0, 1.0]]]), np.array([[[1, 0], [0, 0]], [[1, 1], [0, 0]]]), - 2, # max_block_size = stack length - (np.array([[0, 0], [2, 0]]), [False, True]) + 2, # max_block_size = stack length + (np.array([[0, 0], [2, 0]]), [False, True]), ), ( np.array([[2.0, 0.0], [2.0, 2.0]]), - np.array([[[2., 2.], [2., 1.]], [[2., 2.], [2., 1.]]]), + np.array([[[2.0, 2.0], [2.0, 1.0]], [[2.0, 2.0], [2.0, 1.0]]]), np.array([[[1, 0], [0, 0]], [[1, 1], [0, 0]]]), 1000, # max_block_size > stack length - (np.array([[0, 0], [2, 0]]), [False, True]) + (np.array([[0, 0], [2, 0]]), [False, True]), ), ( np.array([[2.0, 0.0], [2.0, 2.0]]), - np.array([[[2., 2.], [2., 1.]], [[2., 2.], [2., 1.]]]), + np.array([[[2.0, 2.0], [2.0, 1.0]], [[2.0, 2.0], [2.0, 1.0]]]), np.array([[[1, 0], [0, 0]], [[1, 1], [0, 0]]]), -1, # stack processed in one block - (np.array([[0, 0], [2, 0]]), [False, True]) + (np.array([[0, 0], [2, 0]]), [False, True]), ), ( np.array([[2.0, 0.0, 1.0], [2.0, 2.0, 0.0]]), - np.array([[[2., 2.], [2., 1.]], [[2., 2.], [2., 1.]], [[1., 2.], [1., 2.]]]), + np.array([[[2.0, 2.0], [2.0, 1.0]], [[2.0, 2.0], [2.0, 1.0]], [[1.0, 2.0], [1.0, 2.0]]]), np.array([[[1, 0], [0, 0]], [[1, 1], [0, 0]]]), 2, # stack length not divisible by max_block_size - (np.array([[0, 0, 0], [2, 0, 0]]), [False, True, True]) + (np.array([[0, 0, 0], [2, 0, 0]]), [False, True, True]), ), - ], ) def test_demix_time_dep_masks(raw_traces, stack, masks, max_block_size, expected): @@ -97,13 +93,12 @@ def test_demix_time_dep_masks(raw_traces, stack, masks, max_block_size, expected [ ( np.array([[2.0, 0.0], [2.0, 2.0]]), - np.array([[[2., 2.], [2., 1.]], [[2., 2.], [2., 1.]]]), + np.array([[[2.0, 2.0], [2.0, 1.0]], [[2.0, 2.0], [2.0, 1.0]]]), np.array([[[1, 0], [0, 0]], [[1, 1], [0, 0]]]), - -2, # invalid max_block_size) + -2, # invalid max_block_size) ), ], ) def test_demix_invalid_max_block_size(raw_traces, stack, masks, max_block_size): with pytest.raises(ValueError, match="Invalid maximum block size*"): dmx.demix_time_dep_masks(raw_traces, stack, masks, max_block_size) - diff --git a/allensdk/test/brain_observatory/test_dff.py b/allensdk/test/brain_observatory/test_dff.py index 504c51df45..5f4e328edd 100644 --- a/allensdk/test/brain_observatory/test_dff.py +++ b/allensdk/test/brain_observatory/test_dff.py @@ -89,7 +89,7 @@ def test_compute_dff_windowed_mode(): y = dff.compute_dff_windowed_mode(x) - assert(y.shape == x.shape) + assert y.shape == x.shape def test_compute_dff_windowed_median(): @@ -104,21 +104,22 @@ def test_compute_dff_windowed_median(): with pytest.raises(ValueError): dff.compute_dff_windowed_median(x) - x = np.sin(np.arange(0, 200)).reshape(1,200) + x = np.sin(np.arange(0, 200)).reshape(1, 200) - y = dff.compute_dff_windowed_median(x, median_kernel_long=101, - median_kernel_short=11, - noise_kernel_length=5) + y = dff.compute_dff_windowed_median(x, median_kernel_long=101, median_kernel_short=11, noise_kernel_length=5) - assert(y.shape == x.shape) + assert y.shape == x.shape noise_stds = [] small_frames = [] - y = dff.compute_dff_windowed_median(x, median_kernel_long=101, - median_kernel_short=11, - noise_stds=noise_stds, - n_small_baseline_frames=small_frames, - noise_kernel_length=5) + y = dff.compute_dff_windowed_median( + x, + median_kernel_long=101, + median_kernel_short=11, + noise_stds=noise_stds, + n_small_baseline_frames=small_frames, + noise_kernel_length=5, + ) assert len(noise_stds) == 1 assert len(small_frames) == 1 @@ -129,8 +130,7 @@ def test_calculate_dff(): with patch("os.makedirs") as mock_makedirs: with patch.object(Figure, "savefig") as mock_save: - with patch.object(dff, "compute_dff_windowed_median", - return_value=x) as mock_computation: + with patch.object(dff, "compute_dff_windowed_median", return_value=x) as mock_computation: dff.calculate_dff(x) assert mock_makedirs.call_count == 0 assert mock_save.call_count == 0 @@ -139,22 +139,23 @@ def test_calculate_dff(): with patch("os.makedirs") as mock_makedirs: with patch.object(Figure, "savefig") as mock_save: mock_computation = MagicMock(return_value=x) - dff.calculate_dff(x, dff_computation_cb=mock_computation, - save_plot_dir="./test") + dff.calculate_dff(x, dff_computation_cb=mock_computation, save_plot_dir="./test") mock_makedirs.assert_called_once_with("./test") mock_save.assert_called_once() mock_computation.assert_called_once_with(x) - x = np.sin(np.arange(0, 200)).reshape(1,200) + x = np.sin(np.arange(0, 200)).reshape(1, 200) noise_stds = [] small_frames = [] - computation_cb = partial(dff.compute_dff_windowed_median, - median_kernel_long=101, - median_kernel_short=11, - noise_stds=noise_stds, - n_small_baseline_frames=small_frames, - noise_kernel_length=5) + computation_cb = partial( + dff.compute_dff_windowed_median, + median_kernel_long=101, + median_kernel_short=11, + noise_stds=noise_stds, + n_small_baseline_frames=small_frames, + noise_kernel_length=5, + ) dff.calculate_dff(x, dff_computation_cb=computation_cb) assert len(noise_stds) == 1 assert len(small_frames) == 1 diff --git a/allensdk/test/brain_observatory/test_drifting_gratings.py b/allensdk/test/brain_observatory/test_drifting_gratings.py index 46a3d50037..4e0ff65134 100644 --- a/allensdk/test/brain_observatory/test_drifting_gratings.py +++ b/allensdk/test/brain_observatory/test_drifting_gratings.py @@ -42,58 +42,48 @@ @pytest.fixture def dataset(): - dataset = MagicMock(name='dataset') - - timestamps = MagicMock(name='timestamps') - celltraces = MagicMock(name='celltraces') - dataset.get_corrected_fluorescence_traces = \ - MagicMock(name='get_corrected_fluorescence_traces', - return_value=(timestamps, celltraces)) - dataset.get_roi_ids = MagicMock(name='get_roi_ids') - dataset.get_cell_specimen_ids = MagicMock(name='get_cell_specimen_ids') + dataset = MagicMock(name="dataset") + + timestamps = MagicMock(name="timestamps") + celltraces = MagicMock(name="celltraces") + dataset.get_corrected_fluorescence_traces = MagicMock( + name="get_corrected_fluorescence_traces", return_value=(timestamps, celltraces) + ) + dataset.get_roi_ids = MagicMock(name="get_roi_ids") + dataset.get_cell_specimen_ids = MagicMock(name="get_cell_specimen_ids") dff_traces = MagicMock(name="dfftraces") - dataset.get_dff_traces = MagicMock(name='get_dff_traces', - return_value=(None, dff_traces)) - dxcm = MagicMock(name='dxcm') - dxtime = MagicMock(name='dxtime') - dataset.get_running_speed=MagicMock(name='get_running_speed', - return_value=(dxcm, dxtime)) - dataset.get_stimulus_table=MagicMock(name='get_stimulus_table', - return_value=MagicMock()) - + dataset.get_dff_traces = MagicMock(name="get_dff_traces", return_value=(None, dff_traces)) + dxcm = MagicMock(name="dxcm") + dxtime = MagicMock(name="dxtime") + dataset.get_running_speed = MagicMock(name="get_running_speed", return_value=(dxcm, dxtime)) + dataset.get_stimulus_table = MagicMock(name="get_stimulus_table", return_value=MagicMock()) + return dataset + def mock_speed_tuning(): - binned_dx_sp = MagicMock(name='binned_dx_sp') - binned_cells_sp = MagicMock(name='binned_cells_sp') - binned_dx_vis = MagicMock(name='binned_dx_vis') - binned_cells_vis = MagicMock(name='binned_cells_vis') - peak_run = MagicMock(name='peak_run') - - return MagicMock(name='get_speed_tuning', - return_value=(binned_dx_sp, - binned_cells_sp, - binned_dx_vis, - binned_cells_vis, - peak_run)) + binned_dx_sp = MagicMock(name="binned_dx_sp") + binned_cells_sp = MagicMock(name="binned_cells_sp") + binned_dx_vis = MagicMock(name="binned_dx_vis") + binned_cells_vis = MagicMock(name="binned_cells_vis") + peak_run = MagicMock(name="peak_run") + + return MagicMock( + name="get_speed_tuning", return_value=(binned_dx_sp, binned_cells_sp, binned_dx_vis, binned_cells_vis, peak_run) + ) + def mock_sweep_response(): - sweep_response = MagicMock(name='sweep_response') - mean_sweep_response = MagicMock(name='mean_sweep_response') - pval = MagicMock(name='pval') - - return MagicMock(name='get_sweep_response', - return_value=(sweep_response, - mean_sweep_response, - pval)) - -@patch.object(StimulusAnalysis, - 'get_speed_tuning', - mock_speed_tuning()) -@patch.object(StimulusAnalysis, - 'get_sweep_response', - mock_sweep_response()) -@pytest.mark.parametrize('trigger', (1, 2, 3, 4, 5)) + sweep_response = MagicMock(name="sweep_response") + mean_sweep_response = MagicMock(name="mean_sweep_response") + pval = MagicMock(name="pval") + + return MagicMock(name="get_sweep_response", return_value=(sweep_response, mean_sweep_response, pval)) + + +@patch.object(StimulusAnalysis, "get_speed_tuning", mock_speed_tuning()) +@patch.object(StimulusAnalysis, "get_sweep_response", mock_sweep_response()) +@pytest.mark.parametrize("trigger", (1, 2, 3, 4, 5)) def test_harness(dataset, trigger): dg = DriftingGratings(dataset) diff --git a/allensdk/test/brain_observatory/test_locally_sparse_noise.py b/allensdk/test/brain_observatory/test_locally_sparse_noise.py index 9e2a461645..2f5a8fef82 100644 --- a/allensdk/test/brain_observatory/test_locally_sparse_noise.py +++ b/allensdk/test/brain_observatory/test_locally_sparse_noise.py @@ -42,71 +42,58 @@ @pytest.fixture def dataset(): - dataset = MagicMock(name='dataset') - - timestamps = MagicMock(name='timestamps') - celltraces = MagicMock(name='celltraces') - dataset.get_corrected_fluorescence_traces = \ - MagicMock(name='get_corrected_fluorescence_traces', - return_value=(timestamps, celltraces)) - dataset.get_roi_ids = MagicMock(name='get_roi_ids') - dataset.get_cell_specimen_ids = MagicMock(name='get_cell_specimen_ids') + dataset = MagicMock(name="dataset") + + timestamps = MagicMock(name="timestamps") + celltraces = MagicMock(name="celltraces") + dataset.get_corrected_fluorescence_traces = MagicMock( + name="get_corrected_fluorescence_traces", return_value=(timestamps, celltraces) + ) + dataset.get_roi_ids = MagicMock(name="get_roi_ids") + dataset.get_cell_specimen_ids = MagicMock(name="get_cell_specimen_ids") dff_traces = MagicMock(name="dfftraces") - dataset.get_dff_traces = MagicMock(name='get_dff_traces', - return_value=(None, dff_traces)) - dxcm = MagicMock(name='dxcm') - dxtime = MagicMock(name='dxtime') - dataset.get_running_speed=MagicMock(name='get_running_speed', - return_value=(dxcm, dxtime)) - - LSN = MagicMock(name='LSN') - LSN_mask = MagicMock(name='LSN_mask') - dataset.get_locally_sparse_noise_stimulus_template = \ - MagicMock(name='get_locally_sparse_noise_stimulus_template', - return_value=(LSN, LSN_mask)) + dataset.get_dff_traces = MagicMock(name="get_dff_traces", return_value=(None, dff_traces)) + dxcm = MagicMock(name="dxcm") + dxtime = MagicMock(name="dxtime") + dataset.get_running_speed = MagicMock(name="get_running_speed", return_value=(dxcm, dxtime)) + + LSN = MagicMock(name="LSN") + LSN_mask = MagicMock(name="LSN_mask") + dataset.get_locally_sparse_noise_stimulus_template = MagicMock( + name="get_locally_sparse_noise_stimulus_template", return_value=(LSN, LSN_mask) + ) return dataset + def mock_speed_tuning(): - binned_dx_sp = MagicMock(name='binned_dx_sp') - binned_cells_sp = MagicMock(name='binned_cells_sp') - binned_dx_vis = MagicMock(name='binned_dx_vis') - binned_cells_vis = MagicMock(name='binned_cells_vis') - peak_run = MagicMock(name='peak_run') - - return MagicMock(name='get_speed_tuning', - return_value=(binned_dx_sp, - binned_cells_sp, - binned_dx_vis, - binned_cells_vis, - peak_run)) + binned_dx_sp = MagicMock(name="binned_dx_sp") + binned_cells_sp = MagicMock(name="binned_cells_sp") + binned_dx_vis = MagicMock(name="binned_dx_vis") + binned_cells_vis = MagicMock(name="binned_cells_vis") + peak_run = MagicMock(name="peak_run") + + return MagicMock( + name="get_speed_tuning", return_value=(binned_dx_sp, binned_cells_sp, binned_dx_vis, binned_cells_vis, peak_run) + ) + def mock_sweep_response(): - sweep_response = MagicMock(name='sweep_response') - mean_sweep_response = MagicMock(name='mean_sweep_response') - pval = MagicMock(name='pval') - - return MagicMock(name='get_sweep_response', - return_value=(sweep_response, - mean_sweep_response, - pval)) - -@patch.object(StimulusAnalysis, - 'get_sweep_response', - mock_sweep_response()) -@patch.object(LocallySparseNoise, - 'get_receptive_field', - MagicMock(name='get_receptive_field')) -@pytest.mark.parametrize('stimulus,trigger', - it.product(('locally_sparse_noise', - 'locally_sparse_noise_4deg', - 'locally_sparse_noise_8deg'), - (1,2,3,4,5,6))) -def test_harness(dataset, - stimulus, - trigger): - with patch('allensdk.brain_observatory.stimulus_analysis.StimulusAnalysis.get_speed_tuning', - mock_speed_tuning()): + sweep_response = MagicMock(name="sweep_response") + mean_sweep_response = MagicMock(name="mean_sweep_response") + pval = MagicMock(name="pval") + + return MagicMock(name="get_sweep_response", return_value=(sweep_response, mean_sweep_response, pval)) + + +@patch.object(StimulusAnalysis, "get_sweep_response", mock_sweep_response()) +@patch.object(LocallySparseNoise, "get_receptive_field", MagicMock(name="get_receptive_field")) +@pytest.mark.parametrize( + "stimulus,trigger", + it.product(("locally_sparse_noise", "locally_sparse_noise_4deg", "locally_sparse_noise_8deg"), (1, 2, 3, 4, 5, 6)), +) +def test_harness(dataset, stimulus, trigger): + with patch("allensdk.brain_observatory.stimulus_analysis.StimulusAnalysis.get_speed_tuning", mock_speed_tuning()): lsn = LocallySparseNoise(dataset, stimulus) assert lsn._stim_table is StimulusAnalysis._PRELOAD @@ -119,7 +106,7 @@ def test_harness(dataset, assert lsn._mean_sweep_response is StimulusAnalysis._PRELOAD assert lsn._pval is StimulusAnalysis._PRELOAD assert lsn._receptive_field is StimulusAnalysis._PRELOAD - + if trigger == 1: print(lsn.stim_table) print(lsn.sweep_response) diff --git a/allensdk/test/brain_observatory/test_mouse.py b/allensdk/test/brain_observatory/test_mouse.py index e0d01cfe79..cdda5f7b8f 100644 --- a/allensdk/test/brain_observatory/test_mouse.py +++ b/allensdk/test/brain_observatory/test_mouse.py @@ -8,20 +8,20 @@ from allensdk.brain_observatory.behavior.data_files import BehaviorStimulusFile from allensdk.brain_observatory.behavior.data_objects import BehaviorSessionId -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .behavior_metadata.behavior_metadata import \ - BehaviorMetadata -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .behavior_metadata.date_of_acquisition import \ - DateOfAcquisition +from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.behavior_metadata import ( + BehaviorMetadata, +) +from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.date_of_acquisition import ( + DateOfAcquisition, +) from allensdk.internal.brain_observatory.mouse import Mouse class TestMouse: @classmethod def setup_class(cls): - cls.mouse = Mouse(mouse_id='1') - cls.image_names = ('A', 'B') + cls.mouse = Mouse(mouse_id="1") + cls.image_names = ("A", "B") cls.tmpdir = tempfile.TemporaryDirectory() def teardown_class(self): @@ -35,64 +35,53 @@ def get_behavior_sessions(): return [ BehaviorMetadata( date_of_acquisition=DateOfAcquisition( - date_of_acquisition=datetime.datetime( - year=2022, month=12, day=day)), + date_of_acquisition=datetime.datetime(year=2022, month=12, day=day) + ), behavior_session_id=BehaviorSessionId(behavior_session_id=day), behavior_session_uuid=None, equipment=None, session_type=None, stimulus_frame_rate=None, - subject_metadata=None + subject_metadata=None, ) - for day in range(1, 11)] + for day in range(1, 11) + ] def get_behavior_stimulus_file(self, behavior_session_id, db): # need to create 10 dummy stimulus dictionaries for each of the 10 # behavior sessions for this mouse behavior_session_id = int(behavior_session_id) if behavior_session_id <= 5: - stimulus_category = 'image' + stimulus_category = "image" if behavior_session_id % 2 == 0: image_name = self.image_names[0] else: image_name = self.image_names[1] else: - stimulus_category = 'grating' + stimulus_category = "grating" image_name = None d = { - 'items': { - 'behavior': { - 'stimuli': { + "items": { + "behavior": { + "stimuli": { # unused - '': { - 'set_log': [ - (stimulus_category, - image_name, '', '') for _ in range(10)] - } + "": {"set_log": [(stimulus_category, image_name, "", "") for _ in range(10)]} } } } } - with open(Path(self.tmpdir.name) / f'stim_{behavior_session_id}.pkl', - 'wb') as f: + with open(Path(self.tmpdir.name) / f"stim_{behavior_session_id}.pkl", "wb") as f: pickle.dump(d, f) - return BehaviorStimulusFile( - filepath=Path(self.tmpdir.name) / f'stim_{behavior_session_id}' - f'.pkl') + return BehaviorStimulusFile(filepath=Path(self.tmpdir.name) / f"stim_{behavior_session_id}.pkl") - @pytest.mark.parametrize('upto_behavior_session_id', (None, 1, 2, 6, 9)) + @pytest.mark.parametrize("upto_behavior_session_id", (None, 1, 2, 6, 9)) def test_images_shown(self, upto_behavior_session_id): with patch( - 'allensdk.internal.brain_observatory.mouse.' - 'db_connection_creator', - wraps=lambda fallback_credentials: None): - with patch.object(Mouse, attribute='get_behavior_sessions', - wraps=self.get_behavior_sessions): - with patch.object(BehaviorStimulusFile, attribute='from_lims', - wraps=self.get_behavior_stimulus_file): - obt = self.mouse.get_images_shown( - up_to_behavior_session_id=upto_behavior_session_id, - n_workers=1) + "allensdk.internal.brain_observatory.mouse.db_connection_creator", wraps=lambda fallback_credentials: None + ): + with patch.object(Mouse, attribute="get_behavior_sessions", wraps=self.get_behavior_sessions): + with patch.object(BehaviorStimulusFile, attribute="from_lims", wraps=self.get_behavior_stimulus_file): + obt = self.mouse.get_images_shown(up_to_behavior_session_id=upto_behavior_session_id, n_workers=1) if upto_behavior_session_id is None: assert obt == set(self.image_names) elif upto_behavior_session_id == 1: @@ -100,7 +89,7 @@ def test_images_shown(self, upto_behavior_session_id): assert obt == set() elif upto_behavior_session_id == 2: # only been shown B - assert obt == {'B'} + assert obt == {"B"} elif upto_behavior_session_id == 6: # shown all images assert obt == set(self.image_names) diff --git a/allensdk/test/brain_observatory/test_natural_movie.py b/allensdk/test/brain_observatory/test_natural_movie.py index 20b9a08ee8..422d6c50ca 100644 --- a/allensdk/test/brain_observatory/test_natural_movie.py +++ b/allensdk/test/brain_observatory/test_natural_movie.py @@ -42,71 +42,60 @@ @pytest.fixture def stimulus_table(): - return pd.DataFrame([ - {'frame': 0, 'start': 0, 'stop': 1}, - {'frame': 0, 'start': 1, 'stop': 2}, - {'frame': 1, 'start': 2, 'stop': 3}, - ]) + return pd.DataFrame( + [ + {"frame": 0, "start": 0, "stop": 1}, + {"frame": 0, "start": 1, "stop": 2}, + {"frame": 1, "start": 2, "stop": 3}, + ] + ) @pytest.fixture def dataset(stimulus_table): - dataset = MagicMock(name='dataset') - - timestamps = MagicMock(name='timestamps') - celltraces = MagicMock(name='celltraces') - dataset.get_corrected_fluorescence_traces = \ - MagicMock(name='get_corrected_fluorescence_traces', - return_value=(timestamps, celltraces)) - dataset.get_roi_ids = MagicMock(name='get_roi_ids') - dataset.get_cell_specimen_ids = MagicMock(name='get_cell_specimen_ids') + dataset = MagicMock(name="dataset") + + timestamps = MagicMock(name="timestamps") + celltraces = MagicMock(name="celltraces") + dataset.get_corrected_fluorescence_traces = MagicMock( + name="get_corrected_fluorescence_traces", return_value=(timestamps, celltraces) + ) + dataset.get_roi_ids = MagicMock(name="get_roi_ids") + dataset.get_cell_specimen_ids = MagicMock(name="get_cell_specimen_ids") dff_traces = MagicMock(name="dfftraces") - dataset.get_dff_traces = MagicMock(name='get_dff_traces', - return_value=(None, dff_traces)) - dxcm = MagicMock(name='dxcm') - dxtime = MagicMock(name='dxtime') - dataset.get_running_speed=MagicMock(name='get_running_speed', - return_value=(dxcm, dxtime)) - dataset.get_stimulus_table=MagicMock(name='get_stimulus_table', - return_value=stimulus_table) - + dataset.get_dff_traces = MagicMock(name="get_dff_traces", return_value=(None, dff_traces)) + dxcm = MagicMock(name="dxcm") + dxtime = MagicMock(name="dxtime") + dataset.get_running_speed = MagicMock(name="get_running_speed", return_value=(dxcm, dxtime)) + dataset.get_stimulus_table = MagicMock(name="get_stimulus_table", return_value=stimulus_table) + return dataset + def mock_speed_tuning(): - binned_dx_sp = MagicMock(name='binned_dx_sp') - binned_cells_sp = MagicMock(name='binned_cells_sp') - binned_dx_vis = MagicMock(name='binned_dx_vis') - binned_cells_vis = MagicMock(name='binned_cells_vis') - peak_run = MagicMock(name='peak_run') - - return MagicMock(name='get_speed_tuning', - return_value=(binned_dx_sp, - binned_cells_sp, - binned_dx_vis, - binned_cells_vis, - peak_run)) + binned_dx_sp = MagicMock(name="binned_dx_sp") + binned_cells_sp = MagicMock(name="binned_cells_sp") + binned_dx_vis = MagicMock(name="binned_dx_vis") + binned_cells_vis = MagicMock(name="binned_cells_vis") + peak_run = MagicMock(name="peak_run") + + return MagicMock( + name="get_speed_tuning", return_value=(binned_dx_sp, binned_cells_sp, binned_dx_vis, binned_cells_vis, peak_run) + ) + def mock_sweep_response(): - sweep_response = MagicMock(name='sweep_response') - mean_sweep_response = MagicMock(name='mean_sweep_response') - pval = MagicMock(name='pval') - - return MagicMock(name='get_sweep_response', - return_value=(sweep_response, - mean_sweep_response, - pval)) - -@patch.object(StimulusAnalysis, - 'get_speed_tuning', - mock_speed_tuning()) -@patch.object(StimulusAnalysis, - 'get_sweep_response', - mock_sweep_response()) + sweep_response = MagicMock(name="sweep_response") + mean_sweep_response = MagicMock(name="mean_sweep_response") + pval = MagicMock(name="pval") + + return MagicMock(name="get_sweep_response", return_value=(sweep_response, mean_sweep_response, pval)) + + +@patch.object(StimulusAnalysis, "get_speed_tuning", mock_speed_tuning()) +@patch.object(StimulusAnalysis, "get_sweep_response", mock_sweep_response()) @pytest.mark.parametrize( - 'trigger', [ - ('stim_table', 'sweep_response', 'peak'), - ('sweeplength', 'sweep_response', 'peak') - ] + "trigger", [("stim_table", "sweep_response", "peak"), ("sweeplength", "sweep_response", "peak")] ) def test_harness(dataset, trigger): movie_name = "Mock Movie Name" diff --git a/allensdk/test/brain_observatory/test_natural_scenes.py b/allensdk/test/brain_observatory/test_natural_scenes.py index d4ce68e709..699aee23a1 100644 --- a/allensdk/test/brain_observatory/test_natural_scenes.py +++ b/allensdk/test/brain_observatory/test_natural_scenes.py @@ -41,58 +41,48 @@ @pytest.fixture def dataset(): - dataset = MagicMock(name='dataset') - - timestamps = MagicMock(name='timestamps') - celltraces = MagicMock(name='celltraces') - dataset.get_corrected_fluorescence_traces = \ - MagicMock(name='get_corrected_fluorescence_traces', - return_value=(timestamps, celltraces)) - dataset.get_roi_ids = MagicMock(name='get_roi_ids') - dataset.get_cell_specimen_ids = MagicMock(name='get_cell_specimen_ids') + dataset = MagicMock(name="dataset") + + timestamps = MagicMock(name="timestamps") + celltraces = MagicMock(name="celltraces") + dataset.get_corrected_fluorescence_traces = MagicMock( + name="get_corrected_fluorescence_traces", return_value=(timestamps, celltraces) + ) + dataset.get_roi_ids = MagicMock(name="get_roi_ids") + dataset.get_cell_specimen_ids = MagicMock(name="get_cell_specimen_ids") dff_traces = MagicMock(name="dfftraces") - dataset.get_dff_traces = MagicMock(name='get_dff_traces', - return_value=(None, dff_traces)) - dxcm = MagicMock(name='dxcm') - dxtime = MagicMock(name='dxtime') - dataset.get_running_speed=MagicMock(name='get_running_speed', - return_value=(dxcm, dxtime)) - dataset.get_stimulus_table=MagicMock(name='get_stimulus_table', - return_value=MagicMock()) - + dataset.get_dff_traces = MagicMock(name="get_dff_traces", return_value=(None, dff_traces)) + dxcm = MagicMock(name="dxcm") + dxtime = MagicMock(name="dxtime") + dataset.get_running_speed = MagicMock(name="get_running_speed", return_value=(dxcm, dxtime)) + dataset.get_stimulus_table = MagicMock(name="get_stimulus_table", return_value=MagicMock()) + return dataset + def mock_speed_tuning(): - binned_dx_sp = MagicMock(name='binned_dx_sp') - binned_cells_sp = MagicMock(name='binned_cells_sp') - binned_dx_vis = MagicMock(name='binned_dx_vis') - binned_cells_vis = MagicMock(name='binned_cells_vis') - peak_run = MagicMock(name='peak_run') - - return MagicMock(name='get_speed_tuning', - return_value=(binned_dx_sp, - binned_cells_sp, - binned_dx_vis, - binned_cells_vis, - peak_run)) + binned_dx_sp = MagicMock(name="binned_dx_sp") + binned_cells_sp = MagicMock(name="binned_cells_sp") + binned_dx_vis = MagicMock(name="binned_dx_vis") + binned_cells_vis = MagicMock(name="binned_cells_vis") + peak_run = MagicMock(name="peak_run") + + return MagicMock( + name="get_speed_tuning", return_value=(binned_dx_sp, binned_cells_sp, binned_dx_vis, binned_cells_vis, peak_run) + ) + def mock_sweep_response(): - sweep_response = MagicMock(name='sweep_response') - mean_sweep_response = MagicMock(name='mean_sweep_response') - pval = MagicMock(name='pval') - - return MagicMock(name='get_sweep_response', - return_value=(sweep_response, - mean_sweep_response, - pval)) - -@patch.object(StimulusAnalysis, - 'get_speed_tuning', - mock_speed_tuning()) -@patch.object(StimulusAnalysis, - 'get_sweep_response', - mock_sweep_response()) -@pytest.mark.parametrize('trigger', (1, 2, 3, 4, 5)) + sweep_response = MagicMock(name="sweep_response") + mean_sweep_response = MagicMock(name="mean_sweep_response") + pval = MagicMock(name="pval") + + return MagicMock(name="get_sweep_response", return_value=(sweep_response, mean_sweep_response, pval)) + + +@patch.object(StimulusAnalysis, "get_speed_tuning", mock_speed_tuning()) +@patch.object(StimulusAnalysis, "get_sweep_response", mock_sweep_response()) +@pytest.mark.parametrize("trigger", (1, 2, 3, 4, 5)) def test_harness(dataset, trigger): ns = NaturalScenes(dataset) diff --git a/allensdk/test/brain_observatory/test_notebook.py b/allensdk/test/brain_observatory/test_notebook.py index 7d250787b1..cb082e802a 100644 --- a/allensdk/test/brain_observatory/test_notebook.py +++ b/allensdk/test/brain_observatory/test_notebook.py @@ -48,8 +48,8 @@ @pytest.fixture def boc(tmpdir_factory): - manifest_file = tmpdir_factory.mktemp('data').join(os.path.join('boc','manifest.json')) - endpoint = os.environ['TEST_API_ENDPOINT'] if 'TEST_API_ENDPOINT' in os.environ else 'http://api.brain-map.org' + manifest_file = tmpdir_factory.mktemp("data").join(os.path.join("boc", "manifest.json")) + endpoint = os.environ["TEST_API_ENDPOINT"] if "TEST_API_ENDPOINT" in os.environ else "http://api.brain-map.org" return BrainObservatoryCache(manifest_file=str(manifest_file), base_uri=endpoint) @@ -60,13 +60,13 @@ def test_brain_observatory_trace_analysis_notebook(boc): dg = DriftingGratings(data_set) specimen_id = 517425074 specimen_ids = data_set.get_cell_specimen_ids() - - cell_loc = np.argwhere(specimen_ids==specimen_id)[0][0] - assert cell_loc == 97 - + cell_loc = np.argwhere(specimen_ids == specimen_id)[0][0] + + assert cell_loc == 97 + # temporal frequency plot - dg.response[:,1:,cell_loc,0] + dg.response[:, 1:, cell_loc, 0] dg.tfvals[1:] # peak @@ -75,20 +75,22 @@ def test_brain_observatory_trace_analysis_notebook(boc): # trials for cell's preferred condition pref_ori = dg.orivals[dg.peak.ori_dg[cell_loc]] pref_tf = dg.tfvals[dg.peak.tf_dg[cell_loc]] - assert pref_ori == 180 + assert pref_ori == 180 assert pref_tf == 2 - pref_trials = dg.stim_table[(dg.stim_table.orientation==pref_ori)&(dg.stim_table.temporal_frequency==pref_tf)] - assert pref_trials['start'][1] == 837 - assert pref_trials['end'][1] == 897 + pref_trials = dg.stim_table[(dg.stim_table.orientation == pref_ori) & (dg.stim_table.temporal_frequency == pref_tf)] + assert pref_trials["start"][1] == 837 + assert pref_trials["end"][1] == 897 # mean sweep response - dg.sweep_response[(dg.stim_table.orientation==pref_ori)&(dg.stim_table.temporal_frequency==pref_tf)] - subset_mean = dg.mean_sweep_response[(dg.stim_table.orientation==pref_ori)&(dg.stim_table.temporal_frequency==pref_tf)] - assert np.isclose(subset_mean['dx'][1], 0.920868) + dg.sweep_response[(dg.stim_table.orientation == pref_ori) & (dg.stim_table.temporal_frequency == pref_tf)] + subset_mean = dg.mean_sweep_response[ + (dg.stim_table.orientation == pref_ori) & (dg.stim_table.temporal_frequency == pref_tf) + ] + assert np.isclose(subset_mean["dx"][1], 0.920868) # response to each trial - np.arange(-1*dg.interlength, dg.interlength+dg.sweeplength, 1.)/dg.acquisition_rate + np.arange(-1 * dg.interlength, dg.interlength + dg.sweeplength, 1.0) / dg.acquisition_rate @pytest.mark.nightly @@ -97,8 +99,8 @@ def test_brain_observatory_static_gratings_notebook(boc): sg = StaticGratings(data_set) peak_head = sg.peak.head() - assert peak_head['cell_specimen_id'][0] == 517399188 - assert np.isclose(peak_head['reliability_sg'][0], -0.010099250163301616) + assert peak_head["cell_specimen_id"][0] == 517399188 + assert np.isclose(peak_head["reliability_sg"][0], -0.010099250163301616) @pytest.mark.nightly @@ -106,73 +108,102 @@ def test_brain_observatory_natural_scenes_notebook(boc): data_set = boc.get_ophys_experiment_data(510938357) ns = NaturalScenes(data_set) ns_head = ns.peak.head() - - assert np.isclose(ns_head['peak_dff_ns'][0], 4.9614532738) - assert ns_head['cell_specimen_id'][0] == 517399188 + + assert np.isclose(ns_head["peak_dff_ns"][0], 4.9614532738) + assert ns_head["cell_specimen_id"][0] == 517399188 @pytest.mark.nightly def test_brain_observatory_locally_sparse_noise_notebook(boc): specimen_id = 517410165 cell = boc.get_cell_specimens(ids=[specimen_id])[0] - - exp = boc.get_ophys_experiments(experiment_container_ids=[cell['experiment_container_id']], - stimuli=[stim_info.LOCALLY_SPARSE_NOISE])[0] - - data_set = boc.get_ophys_experiment_data(exp['id']) + + exp = boc.get_ophys_experiments( + experiment_container_ids=[cell["experiment_container_id"]], stimuli=[stim_info.LOCALLY_SPARSE_NOISE] + )[0] + + data_set = boc.get_ophys_experiment_data(exp["id"]) lsn = LocallySparseNoise(data_set) specimen_ids = data_set.get_cell_specimen_ids() - cell_loc = np.argwhere(specimen_ids==specimen_id)[0][0] - lsn.receptive_field[:,:,cell_loc,0] + cell_loc = np.argwhere(specimen_ids == specimen_id)[0][0] + lsn.receptive_field[:, :, cell_loc, 0] assert True - #assert cell_loc - #assert receptive_field + # assert cell_loc + # assert receptive_field + @pytest.mark.nightly def test_brain_observatory_experiment_containers_notebook(boc): boc.get_all_targeted_structures() - visp_ecs = boc.get_experiment_containers(targeted_structures=['VISp']) + visp_ecs = boc.get_experiment_containers(targeted_structures=["VISp"]) depths = boc.get_all_imaging_depths() stims = boc.get_all_stimuli() cre_lines = boc.get_all_cre_lines() - cux2_ecs = boc.get_experiment_containers(cre_lines=['Cux2-CreERT2']) - cux2_ec_id = cux2_ecs[-1]['id'] + cux2_ecs = boc.get_experiment_containers(cre_lines=["Cux2-CreERT2"]) + cux2_ec_id = cux2_ecs[-1]["id"] boc.get_ophys_experiments(experiment_container_ids=[cux2_ec_id]) - exp = boc.get_ophys_experiments(experiment_container_ids=[cux2_ec_id], - stimuli=[stim_info.STATIC_GRATINGS])[0] - exp = boc.get_ophys_experiment_data(exp['id']) - - assert set(depths) == set([175, 185, 195, 200, 205, 225, 250, 265, 275, 276, 285, - 300, 320, 325, 335, 350, 365, 375, 390, 400, 550, 570, - 625]) - - expected_stimuli = ['drifting_gratings', - 'locally_sparse_noise', - 'locally_sparse_noise_4deg', - 'locally_sparse_noise_8deg', - 'natural_movie_one', - 'natural_movie_three', - 'natural_movie_two', - 'natural_scenes', - 'spontaneous', - 'static_gratings'] + exp = boc.get_ophys_experiments(experiment_container_ids=[cux2_ec_id], stimuli=[stim_info.STATIC_GRATINGS])[0] + exp = boc.get_ophys_experiment_data(exp["id"]) + + assert set(depths) == set( + [ + 175, + 185, + 195, + 200, + 205, + 225, + 250, + 265, + 275, + 276, + 285, + 300, + 320, + 325, + 335, + 350, + 365, + 375, + 390, + 400, + 550, + 570, + 625, + ] + ) + + expected_stimuli = [ + "drifting_gratings", + "locally_sparse_noise", + "locally_sparse_noise_4deg", + "locally_sparse_noise_8deg", + "natural_movie_one", + "natural_movie_three", + "natural_movie_two", + "natural_scenes", + "spontaneous", + "static_gratings", + ] assert set(stims) == set(expected_stimuli) - expected_cre_lines = [ u'Cux2-CreERT2', - u'Emx1-IRES-Cre', - u'Fezf2-CreER', - u'Nr5a1-Cre', - u'Ntsr1-Cre_GN220', - u'Pvalb-IRES-Cre', - u'Rbp4-Cre_KL100', - u'Rorb-IRES2-Cre', - u'Scnn1a-Tg3-Cre', - u'Slc17a7-IRES2-Cre', - u'Sst-IRES-Cre', - u'Tlx3-Cre_PL56', - u'Vip-IRES-Cre' ] + expected_cre_lines = [ + "Cux2-CreERT2", + "Emx1-IRES-Cre", + "Fezf2-CreER", + "Nr5a1-Cre", + "Ntsr1-Cre_GN220", + "Pvalb-IRES-Cre", + "Rbp4-Cre_KL100", + "Rorb-IRES2-Cre", + "Scnn1a-Tg3-Cre", + "Slc17a7-IRES2-Cre", + "Sst-IRES-Cre", + "Tlx3-Cre_PL56", + "Vip-IRES-Cre", + ] assert set(cre_lines) == set(expected_cre_lines) @@ -181,25 +212,25 @@ def test_brain_observatory_experiment_containers_notebook(boc): cells = pd.DataFrame.from_records(cells) # find direction selective cells in VISp - visp_ec_ids = [ ec['id'] for ec in visp_ecs ] - visp_cells = cells[cells['experiment_container_id'].isin(visp_ec_ids)] + visp_ec_ids = [ec["id"] for ec in visp_ecs] + visp_cells = cells[cells["experiment_container_id"].isin(visp_ec_ids)] # significant response to drifting gratings stimulus - sig_cells = visp_cells[visp_cells['p_dg'] < 0.05] + sig_cells = visp_cells[visp_cells["p_dg"] < 0.05] # direction selective cells - dsi_cells = sig_cells[(sig_cells['dsi_dg'] > 0.5) & (sig_cells['dsi_dg'] < 1.5)] - #assert len(cells) == 27124 + dsi_cells = sig_cells[(sig_cells["dsi_dg"] > 0.5) & (sig_cells["dsi_dg"] < 1.5)] + # assert len(cells) == 27124 assert len(cells) > 0 - #assert len(visp_cells) == 16031 + # assert len(visp_cells) == 16031 assert len(visp_cells) > 0 - #assert len(sig_cells) == 8669 + # assert len(sig_cells) == 8669 assert len(sig_cells) > 0 - #assert len(dsi_cells) == 4943 + # assert len(dsi_cells) == 4943 assert len(dsi_cells) > 0 # find experiment containers for those cells - dsi_ec_ids = dsi_cells['experiment_container_id'].unique() + dsi_ec_ids = dsi_cells["experiment_container_id"].unique() # Download the ophys experiments containing the drifting gratings stimulus for VISp experiment containers boc.get_ophys_experiments(experiment_container_ids=dsi_ec_ids, stimuli=[stim_info.DRIFTING_GRATINGS]) @@ -208,13 +239,14 @@ def test_brain_observatory_experiment_containers_notebook(boc): dsi_cell = dsi_cells.iloc[0] # figure out which ophys experiment has the drifting gratings stimulus for the cell's experiment container - cell_exp = boc.get_ophys_experiments(experiment_container_ids=[dsi_cell['experiment_container_id']], - stimuli=[stim_info.DRIFTING_GRATINGS])[0] - - data_set = boc.get_ophys_experiment_data(cell_exp['id']) + cell_exp = boc.get_ophys_experiments( + experiment_container_ids=[dsi_cell["experiment_container_id"]], stimuli=[stim_info.DRIFTING_GRATINGS] + )[0] + + data_set = boc.get_ophys_experiment_data(cell_exp["id"]) # Fluorescence - dsi_cell_id = dsi_cell['cell_specimen_id'] + dsi_cell_id = dsi_cell["cell_specimen_id"] time, raw_traces = data_set.get_fluorescence_traces(cell_specimen_ids=[dsi_cell_id]) _, demixed_traces = data_set.get_demixed_traces(cell_specimen_ids=[dsi_cell_id]) _, neuropil_traces = data_set.get_neuropil_traces(cell_specimen_ids=[dsi_cell_id]) @@ -223,14 +255,14 @@ def test_brain_observatory_experiment_containers_notebook(boc): # ROI Masks data_set = boc.get_ophys_experiment_data(510221121) - + # get the specimen IDs for a few cells cids = data_set.get_cell_specimen_ids()[:15:5] - + # get masks for specific cells data_set.get_roi_mask(cell_specimen_ids=cids) - - # make a mask of all ROIs in the experiment + + # make a mask of all ROIs in the experiment all_roi_masks = data_set.get_roi_mask_array() all_roi_masks.max(axis=0) @@ -242,36 +274,34 @@ def test_brain_observatory_experiment_containers_notebook(boc): dg = DriftingGratings(data_set) # filter for visually responding, selective cells - vis_cells = (dg.peak.ptest_dg < 0.05) & (dg.peak.peak_dff_dg > 3) + vis_cells = (dg.peak.ptest_dg < 0.05) & (dg.peak.peak_dff_dg > 3) osi_cells = vis_cells & (dg.peak.osi_dg > 0.5) & (dg.peak.osi_dg <= 1.5) dsi_cells = vis_cells & (dg.peak.dsi_dg > 0.5) & (dg.peak.dsi_dg <= 1.5) # 2-d tf vs. ori histogram # tfval = 0 is used for the blank sweep, so we are ignoring it here - os = np.zeros((len(dg.orivals), len(dg.tfvals)-1)) - ds = np.zeros((len(dg.orivals), len(dg.tfvals)-1)) - - for i,trial in dg.peak[osi_cells].iterrows(): - os[trial.ori_dg, trial.tf_dg-1] += 1 - - for i,trial in dg.peak[dsi_cells].iterrows(): - ds[trial.ori_dg, trial.tf_dg-1] += 1 - + os = np.zeros((len(dg.orivals), len(dg.tfvals) - 1)) + ds = np.zeros((len(dg.orivals), len(dg.tfvals) - 1)) + + for i, trial in dg.peak[osi_cells].iterrows(): + os[trial.ori_dg, trial.tf_dg - 1] += 1 + + for i, trial in dg.peak[dsi_cells].iterrows(): + ds[trial.ori_dg, trial.tf_dg - 1] += 1 + max(os.max(), ds.max()) # Neuropil correction data_set = boc.get_ophys_experiment_data(569407590) csid = data_set.get_cell_specimen_ids()[0] - time, demixed_traces = data_set.get_demixed_traces( - cell_specimen_ids=[csid]) + time, demixed_traces = data_set.get_demixed_traces(cell_specimen_ids=[csid]) _, neuropil_traces = data_set.get_neuropil_traces(cell_specimen_ids=[csid]) results = estimate_contamination_ratios(demixed_traces[0], neuropil_traces[0]) - demixed_traces[0] - results['r'] * neuropil_traces[0] - _, corrected_traces = data_set.get_corrected_fluorescence_traces( - cell_specimen_ids=[csid]) - + demixed_traces[0] - results["r"] * neuropil_traces[0] + _, corrected_traces = data_set.get_corrected_fluorescence_traces(cell_specimen_ids=[csid]) + # Running Speed and Motion Correction data_set = boc.get_ophys_experiment_data(512326618) dxcm, dxtime = data_set.get_running_speed() diff --git a/allensdk/test/brain_observatory/test_observatory_plots.py b/allensdk/test/brain_observatory/test_observatory_plots.py index 9bec285e8e..29691c8a94 100644 --- a/allensdk/test/brain_observatory/test_observatory_plots.py +++ b/allensdk/test/brain_observatory/test_observatory_plots.py @@ -49,16 +49,17 @@ from importlib.resources import files -data_file = os.environ.get('TEST_OBSERVATORY_EXPERIMENT_PLOTS_DATA', 'skip') -if data_file == 'default': - data_file = str(files('allensdk.test.brain_observatory').joinpath('test_observatory_plots_data.json')) +data_file = os.environ.get("TEST_OBSERVATORY_EXPERIMENT_PLOTS_DATA", "skip") +if data_file == "default": + data_file = str(files("allensdk.test.brain_observatory").joinpath("test_observatory_plots_data.json")) -if data_file == 'skip': - EXPERIMENT_CONTAINER=None - TEST_DATA_DIR=None +if data_file == "skip": + EXPERIMENT_CONTAINER = None + TEST_DATA_DIR = None else: EXPERIMENT_CONTAINER = ju.read(data_file) - TEST_DATA_DIR = EXPERIMENT_CONTAINER['image_directory'] + TEST_DATA_DIR = EXPERIMENT_CONTAINER["image_directory"] + class AnalysisSingleton(object): def __init__(self, klass, session, *args): @@ -70,16 +71,17 @@ def __init__(self, klass, session, *args): @staticmethod def experiment_for_session(session): - return next(exp for exp in EXPERIMENT_CONTAINER['experiments'] if exp['session'] == session) + return next(exp for exp in EXPERIMENT_CONTAINER["experiments"] if exp["session"] == session) def __call__(self): if self.obj is None: exp = self.experiment_for_session(self.session) - data_set = BrainObservatoryNwbDataSet(exp['nwb_file']) - self.obj = self.klass.from_analysis_file(data_set, exp['analysis_file'], *self.args) + data_set = BrainObservatoryNwbDataSet(exp["nwb_file"]) + self.obj = self.klass.from_analysis_file(data_set, exp["analysis_file"], *self.args) return self.obj + STATIC_GRATINGS = AnalysisSingleton(StaticGratings, stiminfo.THREE_SESSION_B) DRIFTING_GRATINGS = AnalysisSingleton(DriftingGratings, stiminfo.THREE_SESSION_A) NATURAL_SCENES = AnalysisSingleton(NaturalScenes, stiminfo.THREE_SESSION_B) @@ -91,7 +93,7 @@ def __call__(self): LOCALLY_SPARSE_NOISE = AnalysisSingleton(LocallySparseNoise, stiminfo.THREE_SESSION_C, stiminfo.LOCALLY_SPARSE_NOISE) if EXPERIMENT_CONTAINER: - CELL_SPECIMEN_ID = EXPERIMENT_CONTAINER['cells'][0] + CELL_SPECIMEN_ID = EXPERIMENT_CONTAINER["cells"][0] else: CELL_SPECIMEN_ID = None @@ -100,168 +102,175 @@ def assert_images_match(new_file, test_file, shape): assert os.path.exists(new_file) new_img = mpimg.imread(new_file) assert np.allclose(new_img.shape[:2], shape) - + assert os.path.exists(test_file) test_img = mpimg.imread(test_file) assert np.allclose(new_img.shape, test_img.shape) assert np.allclose(test_img.shape[:2], shape) assert (new_img - test_img).mean() < 0.1 - + os.remove(new_file) -@pytest.mark.skipif(data_file == 'skip', reason='NWB Data files not configured') -@pytest.mark.parametrize("new_file,static_gratings", - [ [ 'static_gratings_ttp.png', STATIC_GRATINGS ] ]) -def test_ttp_static_gratings(new_file, static_gratings, shape=[500,500]): +@pytest.mark.skipif(data_file == "skip", reason="NWB Data files not configured") +@pytest.mark.parametrize("new_file,static_gratings", [["static_gratings_ttp.png", STATIC_GRATINGS]]) +def test_ttp_static_gratings(new_file, static_gratings, shape=[500, 500]): with oplots.figure_in_px(shape[1], shape[0], new_file): static_gratings().plot_time_to_peak() oplots.finalize_with_axes() assert_images_match(new_file, os.path.join(TEST_DATA_DIR, new_file), shape) -@pytest.mark.skipif(data_file == 'skip', reason='NWB Data files not configured') -@pytest.mark.parametrize("new_file,static_gratings", - [ [ 'static_gratings_pref_ori.png', STATIC_GRATINGS ] ]) -def test_pref_ori_static_gratings(new_file, static_gratings, shape=[250,500]): +@pytest.mark.skipif(data_file == "skip", reason="NWB Data files not configured") +@pytest.mark.parametrize("new_file,static_gratings", [["static_gratings_pref_ori.png", STATIC_GRATINGS]]) +def test_pref_ori_static_gratings(new_file, static_gratings, shape=[250, 500]): with oplots.figure_in_px(shape[1], shape[0], new_file): static_gratings().plot_preferred_orientation() oplots.finalize_no_axes() assert_images_match(new_file, os.path.join(TEST_DATA_DIR, new_file), shape) -@pytest.mark.skipif(data_file == 'skip', reason='NWB Data files not configured') -@pytest.mark.parametrize("new_file,static_gratings", - [ [ 'static_gratings_ori.png', STATIC_GRATINGS ] ]) -def test_osi_static_gratings(new_file, static_gratings, shape=[500,500]): +@pytest.mark.skipif(data_file == "skip", reason="NWB Data files not configured") +@pytest.mark.parametrize("new_file,static_gratings", [["static_gratings_ori.png", STATIC_GRATINGS]]) +def test_osi_static_gratings(new_file, static_gratings, shape=[500, 500]): with oplots.figure_in_px(shape[1], shape[0], new_file): static_gratings().plot_orientation_selectivity() oplots.finalize_with_axes() assert_images_match(new_file, os.path.join(TEST_DATA_DIR, new_file), shape) -@pytest.mark.skipif(data_file == 'skip', reason='NWB Data files not configured') -@pytest.mark.parametrize("new_file,static_gratings", - [ [ 'static_gratings_pref_sf.png', STATIC_GRATINGS ] ]) -def test_pref_sf(new_file, static_gratings, shape=[500,500]): +@pytest.mark.skipif(data_file == "skip", reason="NWB Data files not configured") +@pytest.mark.parametrize("new_file,static_gratings", [["static_gratings_pref_sf.png", STATIC_GRATINGS]]) +def test_pref_sf(new_file, static_gratings, shape=[500, 500]): with oplots.figure_in_px(shape[1], shape[0], new_file): static_gratings().plot_preferred_spatial_frequency() oplots.finalize_with_axes() assert_images_match(new_file, os.path.join(TEST_DATA_DIR, new_file), shape) -@pytest.mark.skipif(data_file == 'skip', reason='NWB Data files not configured') -@pytest.mark.parametrize("new_file,drifting_gratings", - [ [ 'drifting_gratings_pref_dir.png', DRIFTING_GRATINGS ] ]) -def test_pref_dir_drifting_gratings(new_file, drifting_gratings, shape=[500,500]): +@pytest.mark.skipif(data_file == "skip", reason="NWB Data files not configured") +@pytest.mark.parametrize("new_file,drifting_gratings", [["drifting_gratings_pref_dir.png", DRIFTING_GRATINGS]]) +def test_pref_dir_drifting_gratings(new_file, drifting_gratings, shape=[500, 500]): with oplots.figure_in_px(shape[1], shape[0], new_file): drifting_gratings().plot_preferred_direction() oplots.finalize_no_axes() assert_images_match(new_file, os.path.join(TEST_DATA_DIR, new_file), shape) -@pytest.mark.skipif(data_file == 'skip', reason='NWB Data files not configured') -@pytest.mark.parametrize("new_file,drifting_gratings", - [ [ 'drifting_gratings_pref_tf.png', DRIFTING_GRATINGS ] ]) -def test_pref_tf_drifting_gratings(new_file, drifting_gratings, shape=[500,500]): +@pytest.mark.skipif(data_file == "skip", reason="NWB Data files not configured") +@pytest.mark.parametrize("new_file,drifting_gratings", [["drifting_gratings_pref_tf.png", DRIFTING_GRATINGS]]) +def test_pref_tf_drifting_gratings(new_file, drifting_gratings, shape=[500, 500]): with oplots.figure_in_px(shape[1], shape[0], new_file): drifting_gratings().plot_preferred_temporal_frequency() oplots.finalize_with_axes() assert_images_match(new_file, os.path.join(TEST_DATA_DIR, new_file), shape) -@pytest.mark.skipif(data_file == 'skip', reason='NWB Data files not configured') -@pytest.mark.parametrize("new_file,drifting_gratings", - [ [ 'drifting_gratings_dsi.png', DRIFTING_GRATINGS ] ]) -def test_dsi_drifting_gratings(new_file, drifting_gratings, shape=[500,500]): +@pytest.mark.skipif(data_file == "skip", reason="NWB Data files not configured") +@pytest.mark.parametrize("new_file,drifting_gratings", [["drifting_gratings_dsi.png", DRIFTING_GRATINGS]]) +def test_dsi_drifting_gratings(new_file, drifting_gratings, shape=[500, 500]): with oplots.figure_in_px(shape[1], shape[0], new_file): drifting_gratings().plot_direction_selectivity() oplots.finalize_with_axes() assert_images_match(new_file, os.path.join(TEST_DATA_DIR, new_file), shape) -@pytest.mark.skipif(data_file == 'skip', reason='NWB Data files not configured') -@pytest.mark.parametrize("new_file,drifting_gratings", - [ [ 'drifting_gratings_osi.png', DRIFTING_GRATINGS ] ]) -def test_osi_drifting_gratings(new_file, drifting_gratings, shape=[500,500]): +@pytest.mark.skipif(data_file == "skip", reason="NWB Data files not configured") +@pytest.mark.parametrize("new_file,drifting_gratings", [["drifting_gratings_osi.png", DRIFTING_GRATINGS]]) +def test_osi_drifting_gratings(new_file, drifting_gratings, shape=[500, 500]): with oplots.figure_in_px(shape[1], shape[0], new_file): drifting_gratings().plot_orientation_selectivity() oplots.finalize_with_axes() assert_images_match(new_file, os.path.join(TEST_DATA_DIR, new_file), shape) -@pytest.mark.skipif(data_file == 'skip', reason='NWB Data files not configured') -@pytest.mark.parametrize("new_file,natural_scenes", - [ [ 'natural_scenes_ttp.png', NATURAL_SCENES ] ]) -def test_ttp_natural_scenes(new_file, natural_scenes, shape=[500,500]): +@pytest.mark.skipif(data_file == "skip", reason="NWB Data files not configured") +@pytest.mark.parametrize("new_file,natural_scenes", [["natural_scenes_ttp.png", NATURAL_SCENES]]) +def test_ttp_natural_scenes(new_file, natural_scenes, shape=[500, 500]): with oplots.figure_in_px(shape[1], shape[0], new_file): natural_scenes().plot_time_to_peak() oplots.finalize_with_axes() assert_images_match(new_file, os.path.join(TEST_DATA_DIR, new_file), shape) -@pytest.mark.skipif(data_file == 'skip', reason='NWB Data files not configured') -@pytest.mark.parametrize("new_file,static_gratings,cell_specimen_id", - [ [ 'static_gratings_fan_plot.png', STATIC_GRATINGS, CELL_SPECIMEN_ID ] ]) -def test_fan_plot(new_file, static_gratings, cell_specimen_id, shape=[250,500]): +@pytest.mark.skipif(data_file == "skip", reason="NWB Data files not configured") +@pytest.mark.parametrize( + "new_file,static_gratings,cell_specimen_id", [["static_gratings_fan_plot.png", STATIC_GRATINGS, CELL_SPECIMEN_ID]] +) +def test_fan_plot(new_file, static_gratings, cell_specimen_id, shape=[250, 500]): with oplots.figure_in_px(shape[1], shape[0], new_file): static_gratings().open_fan_plot(cell_specimen_id) oplots.finalize_no_axes() assert_images_match(new_file, os.path.join(TEST_DATA_DIR, new_file), shape) -@pytest.mark.skipif(data_file == 'skip', reason='NWB Data files not configured') -@pytest.mark.parametrize("new_file,natural_scenes,cell_specimen_id", - [ [ 'natural_scenes_fan_plot.png', NATURAL_SCENES, CELL_SPECIMEN_ID ] ]) -def test_corona_plot(new_file, natural_scenes, cell_specimen_id, shape=[500,500]): +@pytest.mark.skipif(data_file == "skip", reason="NWB Data files not configured") +@pytest.mark.parametrize( + "new_file,natural_scenes,cell_specimen_id", [["natural_scenes_fan_plot.png", NATURAL_SCENES, CELL_SPECIMEN_ID]] +) +def test_corona_plot(new_file, natural_scenes, cell_specimen_id, shape=[500, 500]): with oplots.figure_in_px(shape[1], shape[0], new_file): natural_scenes().open_corona_plot(cell_specimen_id) oplots.finalize_no_axes() assert_images_match(new_file, os.path.join(TEST_DATA_DIR, new_file), shape) -@pytest.mark.skipif(data_file == 'skip', reason='NWB Data files not configured') -@pytest.mark.parametrize("new_file,natural_movie,cell_specimen_id", - [ ("natural_movie_one_a_track_plot.png", NATURAL_MOVIE_ONE_A, CELL_SPECIMEN_ID), - ("natural_movie_one_b_track_plot.png", NATURAL_MOVIE_ONE_B, CELL_SPECIMEN_ID), - ("natural_movie_one_c_track_plot.png", NATURAL_MOVIE_ONE_C, CELL_SPECIMEN_ID), - ("natural_movie_two_track_plot.png", NATURAL_MOVIE_TWO, CELL_SPECIMEN_ID), - ("natural_movie_three_track_plot.png", NATURAL_MOVIE_THREE, CELL_SPECIMEN_ID) ]) -def test_track_plot(new_file, natural_movie, cell_specimen_id, shape=[500,500]): +@pytest.mark.skipif(data_file == "skip", reason="NWB Data files not configured") +@pytest.mark.parametrize( + "new_file,natural_movie,cell_specimen_id", + [ + ("natural_movie_one_a_track_plot.png", NATURAL_MOVIE_ONE_A, CELL_SPECIMEN_ID), + ("natural_movie_one_b_track_plot.png", NATURAL_MOVIE_ONE_B, CELL_SPECIMEN_ID), + ("natural_movie_one_c_track_plot.png", NATURAL_MOVIE_ONE_C, CELL_SPECIMEN_ID), + ("natural_movie_two_track_plot.png", NATURAL_MOVIE_TWO, CELL_SPECIMEN_ID), + ("natural_movie_three_track_plot.png", NATURAL_MOVIE_THREE, CELL_SPECIMEN_ID), + ], +) +def test_track_plot(new_file, natural_movie, cell_specimen_id, shape=[500, 500]): with oplots.figure_in_px(shape[1], shape[0], new_file): natural_movie().open_track_plot(cell_specimen_id) oplots.finalize_no_axes() assert_images_match(new_file, os.path.join(TEST_DATA_DIR, new_file), shape) -@pytest.mark.skipif(data_file == 'skip', reason='NWB Data files not configured') -@pytest.mark.parametrize("new_file,drifting_gratings,cell_specimen_id", - [ [ 'drifting_gratings_star_plot.png', DRIFTING_GRATINGS, CELL_SPECIMEN_ID ] ]) -def test_star_plot(new_file, drifting_gratings, cell_specimen_id, shape=[500,500]): +@pytest.mark.skipif(data_file == "skip", reason="NWB Data files not configured") +@pytest.mark.parametrize( + "new_file,drifting_gratings,cell_specimen_id", + [["drifting_gratings_star_plot.png", DRIFTING_GRATINGS, CELL_SPECIMEN_ID]], +) +def test_star_plot(new_file, drifting_gratings, cell_specimen_id, shape=[500, 500]): with oplots.figure_in_px(shape[1], shape[0], new_file): drifting_gratings().open_star_plot(cell_specimen_id) oplots.finalize_no_axes() assert_images_match(new_file, os.path.join(TEST_DATA_DIR, new_file), shape) -@pytest.mark.skipif(data_file == 'skip', reason='NWB Data files not configured') -@pytest.mark.parametrize("new_file,analysis,cell_specimen_id", - [ ("3sa_speed_tuning_plot.png", DRIFTING_GRATINGS, CELL_SPECIMEN_ID), - ("3sb_speed_tuning_plot.png", STATIC_GRATINGS, CELL_SPECIMEN_ID), - ("3sc_speed_tuning_plot.png", NATURAL_MOVIE_TWO, CELL_SPECIMEN_ID) ]) -def test_speed_tuning_plot(new_file, analysis, cell_specimen_id, shape=[500,500]): +@pytest.mark.skipif(data_file == "skip", reason="NWB Data files not configured") +@pytest.mark.parametrize( + "new_file,analysis,cell_specimen_id", + [ + ("3sa_speed_tuning_plot.png", DRIFTING_GRATINGS, CELL_SPECIMEN_ID), + ("3sb_speed_tuning_plot.png", STATIC_GRATINGS, CELL_SPECIMEN_ID), + ("3sc_speed_tuning_plot.png", NATURAL_MOVIE_TWO, CELL_SPECIMEN_ID), + ], +) +def test_speed_tuning_plot(new_file, analysis, cell_specimen_id, shape=[500, 500]): with oplots.figure_in_px(shape[1], shape[0], new_file): analysis().plot_speed_tuning(cell_specimen_id) oplots.finalize_with_axes() assert_images_match(new_file, os.path.join(TEST_DATA_DIR, new_file), shape) -@pytest.mark.skipif(data_file == 'skip', reason='NWB Data files not configured') -@pytest.mark.parametrize("new_file,locally_sparse_noise,on,cell_specimen_id", - [ ('locally_sparse_noise_on.png', LOCALLY_SPARSE_NOISE, True, CELL_SPECIMEN_ID), - ('locally_sparse_noise_off.png', LOCALLY_SPARSE_NOISE, False, CELL_SPECIMEN_ID) ]) -def test_pincushion_plot(new_file, locally_sparse_noise, on, cell_specimen_id, shape=[500,877]): +@pytest.mark.skipif(data_file == "skip", reason="NWB Data files not configured") +@pytest.mark.parametrize( + "new_file,locally_sparse_noise,on,cell_specimen_id", + [ + ("locally_sparse_noise_on.png", LOCALLY_SPARSE_NOISE, True, CELL_SPECIMEN_ID), + ("locally_sparse_noise_off.png", LOCALLY_SPARSE_NOISE, False, CELL_SPECIMEN_ID), + ], +) +def test_pincushion_plot(new_file, locally_sparse_noise, on, cell_specimen_id, shape=[500, 877]): with oplots.figure_in_px(shape[1], shape[0], new_file): locally_sparse_noise().open_pincushion_plot(on, cell_specimen_id) oplots.finalize_no_axes() diff --git a/allensdk/test/brain_observatory/test_roi_masks.py b/allensdk/test/brain_observatory/test_roi_masks.py index e0266038d4..39b580cbb6 100644 --- a/allensdk/test/brain_observatory/test_roi_masks.py +++ b/allensdk/test/brain_observatory/test_roi_masks.py @@ -73,15 +73,13 @@ def test_init_by_pixels_large(): a = np.random.random((512, 512)) a[a > 0.5] = 1 - m = roi_masks.create_roi_mask( - 512, 512, [0, 0, 0, 0], pix_list=np.argwhere(a)) + m = roi_masks.create_roi_mask(512, 512, [0, 0, 0, 0], pix_list=np.argwhere(a)) npx = len(np.where(a)[0]) assert npx == len(np.where(m.get_mask_plane())[0]) def test_create_neuropil_mask(): - image_width = 100 image_height = 80 @@ -97,7 +95,7 @@ def test_create_neuropil_mask(): roi = roi_masks.create_roi_mask(image_w=image_width, image_h=image_height, border=border, roi_mask=roi_mask) obtained = roi_masks.create_neuropil_mask(roi, border, combined_binary_mask) - expected_mask = np.zeros((58-27, 45-17), dtype=np.uint8) + expected_mask = np.zeros((58 - 27, 45 - 17), dtype=np.uint8) expected_mask[:, :] = 1 assert np.allclose(expected_mask, obtained.mask) @@ -124,40 +122,35 @@ def test_create_empty_neuropil_mask(): obtained = roi_masks.create_neuropil_mask(roi, border, combined_binary_mask) assert obtained.mask is None - assert 'zero_pixels' in obtained.flags + assert "zero_pixels" in obtained.flags @pytest.fixture def image_dims(): - return { - 'width': 100, - 'height': 100 - } + return {"width": 100, "height": 100} @pytest.fixture def motion_border(): return [5.0, 5.0, 5.0, 5.0] + @pytest.fixture def roi_mask_list(image_dims, motion_border): - base_pixels = np.argwhere(np.ones((10, 10))) masks = [] for ii in range(10): pixels = base_pixels + ii * 10 - masks.append(roi_masks.create_roi_mask( - image_dims['width'], - image_dims['height'], - motion_border, - pix_list=pixels, - label=str(ii), - mask_group=-1 - )) + masks.append( + roi_masks.create_roi_mask( + image_dims["width"], image_dims["height"], motion_border, pix_list=pixels, label=str(ii), mask_group=-1 + ) + ) return masks + @pytest.fixture def neuropil_masks(roi_mask_list, motion_border): neuropil_masks = [] @@ -166,18 +159,14 @@ def neuropil_masks(roi_mask_list, motion_border): combined_mask = mask_array.max(axis=0) for roi_mask in roi_mask_list: - neuropil_masks.append(roi_masks.create_neuropil_mask( - roi_mask, - motion_border, - combined_mask, - roi_mask.label - )) + neuropil_masks.append(roi_masks.create_neuropil_mask(roi_mask, motion_border, combined_mask, roi_mask.label)) return neuropil_masks + @pytest.fixture def video(image_dims): num_frames = 20 - data = np.ones((num_frames, image_dims['height'], image_dims['width'])) + data = np.ones((num_frames, image_dims["height"], image_dims["width"])) data[:, 50:, 50:] = 2 return data @@ -185,16 +174,15 @@ def video(image_dims): def test_calculate_traces(video, roi_mask_list): roi_traces, exclusions = roi_masks.calculate_traces(video, roi_mask_list) - expected_exclusions = pd.DataFrame({ - 'roi_id': ['0', '9'], - 'exclusion_label_name': ['motion_border', 'motion_border'] - }) + expected_exclusions = pd.DataFrame( + {"roi_id": ["0", "9"], "exclusion_label_name": ["motion_border", "motion_border"]} + ) assert np.all(np.isnan(roi_traces[0, :])) assert np.all(roi_traces[4, :] == 1) assert np.all(roi_traces[6, :] == 2) assert np.all(np.isnan(roi_traces[9, :])) - + pd.testing.assert_frame_equal(expected_exclusions, pd.DataFrame(exclusions), check_like=True) @@ -207,9 +195,10 @@ def test_validate_masks(roi_mask_list, neuropil_masks): for mask in roi_mask_list: obtained.extend(roi_masks.validate_mask(mask)) - expected_exclusions = pd.DataFrame({ - 'roi_id': ['0', '3', '9', '7'], - 'exclusion_label_name': ['motion_border', 'empty_roi_mask', 'motion_border', 'empty_neuropil_mask'] - }) + expected_exclusions = pd.DataFrame( + { + "roi_id": ["0", "3", "9", "7"], + "exclusion_label_name": ["motion_border", "empty_roi_mask", "motion_border", "empty_neuropil_mask"], + } + ) pd.testing.assert_frame_equal(expected_exclusions, pd.DataFrame(obtained), check_like=True) - diff --git a/allensdk/test/brain_observatory/test_session_analysis.py b/allensdk/test/brain_observatory/test_session_analysis.py index 75caecbd70..d367fa3ffc 100644 --- a/allensdk/test/brain_observatory/test_session_analysis.py +++ b/allensdk/test/brain_observatory/test_session_analysis.py @@ -35,8 +35,7 @@ # import pytest from unittest.mock import patch -from allensdk.core.brain_observatory_nwb_data_set import \ - BrainObservatoryNwbDataSet +from allensdk.core.brain_observatory_nwb_data_set import BrainObservatoryNwbDataSet from allensdk.brain_observatory.session_analysis import SessionAnalysis import os @@ -46,18 +45,19 @@ def mock_stimulus_table(dset, name): t = _orig_get_stimulus_table(dset, name) - t.at[0, 'end'] = t.loc[0, 'start'] + 10 + t.at[0, "end"] = t.loc[0, "start"] + 10 return t @pytest.fixture def session_a(): - filename = os.path.abspath(os.path.join( - "/", "allen", "aibs", "informatics", "module_test_data", - "observatory", "test_nwb", "out_510390912.nwb" - )) - save_path = 'xyza' + filename = os.path.abspath( + os.path.join( + "/", "allen", "aibs", "informatics", "module_test_data", "observatory", "test_nwb", "out_510390912.nwb" + ) + ) + save_path = "xyza" sa = SessionAnalysis(filename, save_path) @@ -66,11 +66,12 @@ def session_a(): @pytest.fixture def session_b(): - filename = os.path.abspath(os.path.join( - "/", "allen", "aibs", "informatics", "module_test_data", - "observatory", "test_nwb", "506278598.nwb" - )) - save_path = 'xyzb' + filename = os.path.abspath( + os.path.join( + "/", "allen", "aibs", "informatics", "module_test_data", "observatory", "test_nwb", "506278598.nwb" + ) + ) + save_path = "xyzb" sa = SessionAnalysis(filename, save_path) @@ -79,11 +80,12 @@ def session_b(): @pytest.fixture def session_c(): - filename = os.path.abspath(os.path.join( - "/", "allen", "aibs", "informatics", "module_test_data", - "observatory", "test_nwb", "out_510221121.nwb" - )) - save_path = 'xyzc' + filename = os.path.abspath( + os.path.join( + "/", "allen", "aibs", "informatics", "module_test_data", "observatory", "test_nwb", "out_510221121.nwb" + ) + ) + save_path = "xyzc" sa = SessionAnalysis(filename, save_path) @@ -91,33 +93,36 @@ def session_c(): @pytest.mark.nightly -@pytest.mark.parametrize('plot_flag', [False]) +@pytest.mark.parametrize("plot_flag", [False]) def test_session_a(session_a, plot_flag): - with patch('allensdk.core.brain_observatory_nwb_data_set.' - 'BrainObservatoryNwbDataSet.get_stimulus_table', - mock_stimulus_table): + with patch( + "allensdk.core.brain_observatory_nwb_data_set.BrainObservatoryNwbDataSet.get_stimulus_table", + mock_stimulus_table, + ): session_a.session_a(plot_flag=plot_flag) assert True @pytest.mark.nightly -@pytest.mark.parametrize('plot_flag', [False]) +@pytest.mark.parametrize("plot_flag", [False]) def test_session_b(session_b, plot_flag): - with patch('allensdk.core.brain_observatory_nwb_data_set.' - 'BrainObservatoryNwbDataSet.get_stimulus_table', - mock_stimulus_table): + with patch( + "allensdk.core.brain_observatory_nwb_data_set.BrainObservatoryNwbDataSet.get_stimulus_table", + mock_stimulus_table, + ): session_b.session_b(plot_flag=plot_flag) assert True @pytest.mark.nightly -@pytest.mark.parametrize('plot_flag', [False]) +@pytest.mark.parametrize("plot_flag", [False]) def test_session_c(session_c, plot_flag): - with patch('allensdk.core.brain_observatory_nwb_data_set.' - 'BrainObservatoryNwbDataSet.get_stimulus_table', - mock_stimulus_table): + with patch( + "allensdk.core.brain_observatory_nwb_data_set.BrainObservatoryNwbDataSet.get_stimulus_table", + mock_stimulus_table, + ): session_c.session_c(plot_flag=plot_flag) assert True @@ -127,18 +132,18 @@ def test_session_c(session_c, plot_flag): def test_session_get_session_type(session_a): session_type = session_a.nwb.get_session_type() - assert session_type == 'three_session_A' + assert session_type == "three_session_A" @pytest.mark.nightly def test_session_get_session_type_b(session_b): session_type = session_b.nwb.get_session_type() - assert session_type == 'three_session_B' + assert session_type == "three_session_B" @pytest.mark.nightly def test_session_get_session_type_c(session_c): session_type = session_c.nwb.get_session_type() - assert session_type == 'three_session_C' + assert session_type == "three_session_C" diff --git a/allensdk/test/brain_observatory/test_session_analysis_regression.py b/allensdk/test/brain_observatory/test_session_analysis_regression.py index c040585e28..dc445a6bdd 100644 --- a/allensdk/test/brain_observatory/test_session_analysis_regression.py +++ b/allensdk/test/brain_observatory/test_session_analysis_regression.py @@ -14,59 +14,59 @@ from allensdk.brain_observatory.natural_scenes import NaturalScenes from allensdk.brain_observatory.locally_sparse_noise import LocallySparseNoise from allensdk.brain_observatory.session_analysis import SessionAnalysis -from allensdk.core.brain_observatory_nwb_data_set import \ - BrainObservatoryNwbDataSet as BODS +from allensdk.core.brain_observatory_nwb_data_set import BrainObservatoryNwbDataSet as BODS import allensdk.brain_observatory.stimulus_info as si logging.basicConfig(level=logging.DEBUG) -if 'TEST_SESSION_ANALYSIS_REGRESSION_DATA' in os.environ: - data_file = os.environ['TEST_SESSION_ANALYSIS_REGRESSION_DATA'] +if "TEST_SESSION_ANALYSIS_REGRESSION_DATA" in os.environ: + data_file = os.environ["TEST_SESSION_ANALYSIS_REGRESSION_DATA"] else: - data_file = str(files('allensdk.test.brain_observatory').joinpath('test_session_analysis_regression_data.json')) + data_file = str(files("allensdk.test.brain_observatory").joinpath("test_session_analysis_regression_data.json")) @pytest.fixture(scope="module") def paths(): pyversion = sys.version_info[0] logging.debug("loading " + data_file) - with open(data_file, 'r') as f: + with open(data_file, "r") as f: data = json.load(f) return data[str(pyversion)] @pytest.fixture(scope="module") def nwb_a(paths): - return paths['nwb_a'] + return paths["nwb_a"] @pytest.fixture(scope="module") def nwb_b(paths): - return paths['nwb_b'] + return paths["nwb_b"] @pytest.fixture(scope="module") def nwb_c(paths): - return paths['nwb_c'] + return paths["nwb_c"] @pytest.fixture(scope="module") def analysis_a(paths): - return paths['analysis_a'] + return paths["analysis_a"] @pytest.fixture(scope="module") def analysis_b(paths): - return paths['analysis_b'] + return paths["analysis_b"] @pytest.fixture(scope="module") def analysis_c(paths): - return paths['analysis_c'] + return paths["analysis_c"] # session a + @pytest.fixture(scope="module") def dg(nwb_a, analysis_a): return DriftingGratings.from_analysis_file(BODS(nwb_a), analysis_a) @@ -74,18 +74,17 @@ def dg(nwb_a, analysis_a): @pytest.fixture(scope="module") def nm1a(nwb_a, analysis_a): - return NaturalMovie.from_analysis_file(BODS(nwb_a), analysis_a, - si.NATURAL_MOVIE_ONE) + return NaturalMovie.from_analysis_file(BODS(nwb_a), analysis_a, si.NATURAL_MOVIE_ONE) @pytest.fixture(scope="module") def nm3(nwb_a, analysis_a): - return NaturalMovie.from_analysis_file(BODS(nwb_a), analysis_a, - si.NATURAL_MOVIE_THREE) + return NaturalMovie.from_analysis_file(BODS(nwb_a), analysis_a, si.NATURAL_MOVIE_THREE) # session b + @pytest.fixture(scope="module") def sg(nwb_b, analysis_b): return StaticGratings.from_analysis_file(BODS(nwb_b), analysis_b) @@ -93,8 +92,7 @@ def sg(nwb_b, analysis_b): @pytest.fixture(scope="module") def nm1b(nwb_b, analysis_b): - return NaturalMovie.from_analysis_file(BODS(nwb_b), analysis_b, - si.NATURAL_MOVIE_ONE) + return NaturalMovie.from_analysis_file(BODS(nwb_b), analysis_b, si.NATURAL_MOVIE_ONE) @pytest.fixture(scope="module") @@ -108,22 +106,19 @@ def lsn(nwb_c, analysis_c): # in order to work around 2/3 unicode compatibility, separate files are # specified for python 2 and 3 # we need to look up a different key depending on python version - key = si.LOCALLY_SPARSE_NOISE_4DEG if sys.version_info < (3,) else \ - si.LOCALLY_SPARSE_NOISE + key = si.LOCALLY_SPARSE_NOISE_4DEG if sys.version_info < (3,) else si.LOCALLY_SPARSE_NOISE return LocallySparseNoise.from_analysis_file(BODS(nwb_c), analysis_c, key) @pytest.fixture(scope="module") def nm1c(nwb_c, analysis_c): - return NaturalMovie.from_analysis_file(BODS(nwb_c), analysis_c, - si.NATURAL_MOVIE_ONE) + return NaturalMovie.from_analysis_file(BODS(nwb_c), analysis_c, si.NATURAL_MOVIE_ONE) @pytest.fixture(scope="module") def nm2(nwb_c, analysis_c): - return NaturalMovie.from_analysis_file(BODS(nwb_c), analysis_c, - si.NATURAL_MOVIE_TWO) + return NaturalMovie.from_analysis_file(BODS(nwb_c), analysis_c, si.NATURAL_MOVIE_TWO) @pytest.fixture(scope="module") @@ -159,7 +154,7 @@ def analysis_c_new(nwb_c, tmpdir_factory): logging.debug("running analysis c") session_analysis = SessionAnalysis(nwb_c, save_path) - session_type = BODS(nwb_c).get_metadata()['session_type'] + session_type = BODS(nwb_c).get_metadata()["session_type"] if session_type == si.THREE_SESSION_C2: session_analysis.session_c2(plot_flag=False, save_flag=True) elif session_type == si.THREE_SESSION_C: @@ -177,15 +172,17 @@ def compare_peak(p1, p2): p1 = p1.infer_objects() p2 = p2.infer_objects() - peak_blacklist = ["rf_center_on_x_lsn", - "rf_center_on_y_lsn", - "rf_center_off_x_lsn", - "rf_center_off_y_lsn", - "rf_area_on_lsn", - "rf_area_off_lsn", - "rf_distance_lsn", - "rf_overlap_index_lsn", - "rf_chi2_lsn"] + peak_blacklist = [ + "rf_center_on_x_lsn", + "rf_center_on_y_lsn", + "rf_center_off_x_lsn", + "rf_center_off_y_lsn", + "rf_area_on_lsn", + "rf_area_off_lsn", + "rf_distance_lsn", + "rf_overlap_index_lsn", + "rf_chi2_lsn", + ] for col in p1.select_dtypes(include=[np.number]): if col in peak_blacklist: @@ -195,7 +192,7 @@ def compare_peak(p1, p2): logging.debug("checking " + col) assert np.allclose(p1[col], p2[col], equal_nan=True, atol=1e-4) - for col in p1.select_dtypes(include=['O']): + for col in p1.select_dtypes(include=["O"]): logging.debug("checking " + col) assert all(p1[col] == p2[col]) @@ -212,33 +209,23 @@ def test_drifting_gratings(dg, nwb_a, analysis_a_new): logging.debug("reading outputs") dg_new = DriftingGratings.from_analysis_file(BODS(nwb_a), analysis_a_new) # assert np.allclose(dg.sweep_response, dg_new.sweep_response) - assert np.allclose(dg.mean_sweep_response, dg_new.mean_sweep_response, - equal_nan=True) + assert np.allclose(dg.mean_sweep_response, dg_new.mean_sweep_response, equal_nan=True) - assert np.allclose(dg.response, dg_new.response, equal_nan=True, - atol=1e-4, rtol=1e-4) - assert np.allclose(dg.noise_correlation, dg_new.noise_correlation, - equal_nan=True) - assert np.allclose(dg.signal_correlation, dg_new.signal_correlation, - equal_nan=True) - assert np.allclose(dg.representational_similarity, - dg_new.representational_similarity, equal_nan=True) + assert np.allclose(dg.response, dg_new.response, equal_nan=True, atol=1e-4, rtol=1e-4) + assert np.allclose(dg.noise_correlation, dg_new.noise_correlation, equal_nan=True) + assert np.allclose(dg.signal_correlation, dg_new.signal_correlation, equal_nan=True) + assert np.allclose(dg.representational_similarity, dg_new.representational_similarity, equal_nan=True) @pytest.mark.nightly def test_natural_movie_one_a(nm1a, nwb_a, analysis_a_new): - nm1a_new = NaturalMovie.from_analysis_file(BODS(nwb_a), analysis_a_new, - si.NATURAL_MOVIE_ONE) + nm1a_new = NaturalMovie.from_analysis_file(BODS(nwb_a), analysis_a_new, si.NATURAL_MOVIE_ONE) # assert np.allclose(nm1a.sweep_response, nm1a_new.sweep_response) - assert np.allclose(nm1a.binned_cells_sp, nm1a_new.binned_cells_sp, - equal_nan=True) - assert np.allclose(nm1a.binned_cells_vis, nm1a_new.binned_cells_vis, - equal_nan=True) - assert np.allclose(nm1a.binned_dx_sp, nm1a_new.binned_dx_sp, - equal_nan=True) - assert np.allclose(nm1a.binned_dx_vis, nm1a_new.binned_dx_vis, - equal_nan=True) + assert np.allclose(nm1a.binned_cells_sp, nm1a_new.binned_cells_sp, equal_nan=True) + assert np.allclose(nm1a.binned_cells_vis, nm1a_new.binned_cells_vis, equal_nan=True) + assert np.allclose(nm1a.binned_dx_sp, nm1a_new.binned_dx_sp, equal_nan=True) + assert np.allclose(nm1a.binned_dx_vis, nm1a_new.binned_dx_vis, equal_nan=True) @pytest.mark.nightly @@ -260,48 +247,34 @@ def test_session_b(analysis_b, analysis_b_new): def test_static_gratings(sg, nwb_b, analysis_b_new): sg_new = StaticGratings.from_analysis_file(BODS(nwb_b), analysis_b_new) # assert np.allclose(sg.sweep_response, sg_new.sweep_response) - assert np.allclose(sg.mean_sweep_response, sg_new.mean_sweep_response, - equal_nan=True) + assert np.allclose(sg.mean_sweep_response, sg_new.mean_sweep_response, equal_nan=True) - assert np.allclose(sg.response, sg_new.response, equal_nan=True, - atol=1e-4, rtol=1e-4) - assert np.allclose(sg.noise_correlation, sg_new.noise_correlation, - equal_nan=True) - assert np.allclose(sg.signal_correlation, sg_new.signal_correlation, - equal_nan=True) - assert np.allclose(sg.representational_similarity, - sg_new.representational_similarity, equal_nan=True) + assert np.allclose(sg.response, sg_new.response, equal_nan=True, atol=1e-4, rtol=1e-4) + assert np.allclose(sg.noise_correlation, sg_new.noise_correlation, equal_nan=True) + assert np.allclose(sg.signal_correlation, sg_new.signal_correlation, equal_nan=True) + assert np.allclose(sg.representational_similarity, sg_new.representational_similarity, equal_nan=True) @pytest.mark.nightly def test_natural_movie_one_b(nm1b, nwb_b, analysis_b_new): - nm1b_new = NaturalMovie.from_analysis_file(BODS(nwb_b), analysis_b_new, - si.NATURAL_MOVIE_ONE) + nm1b_new = NaturalMovie.from_analysis_file(BODS(nwb_b), analysis_b_new, si.NATURAL_MOVIE_ONE) # assert np.allclose(nm1b.sweep_response, nm1b_new.sweep_response) - assert np.allclose(nm1b.binned_cells_sp, nm1b_new.binned_cells_sp, - equal_nan=True) - assert np.allclose(nm1b.binned_cells_vis, nm1b_new.binned_cells_vis, - equal_nan=True) - assert np.allclose(nm1b.binned_dx_sp, nm1b_new.binned_dx_sp, - equal_nan=True) - assert np.allclose(nm1b.binned_dx_vis, nm1b_new.binned_dx_vis, - equal_nan=True) + assert np.allclose(nm1b.binned_cells_sp, nm1b_new.binned_cells_sp, equal_nan=True) + assert np.allclose(nm1b.binned_cells_vis, nm1b_new.binned_cells_vis, equal_nan=True) + assert np.allclose(nm1b.binned_dx_sp, nm1b_new.binned_dx_sp, equal_nan=True) + assert np.allclose(nm1b.binned_dx_vis, nm1b_new.binned_dx_vis, equal_nan=True) @pytest.mark.nightly def test_natural_scenes(ns, nwb_b, analysis_b_new): ns_new = NaturalScenes.from_analysis_file(BODS(nwb_b), analysis_b_new) # assert np.allclose(ns.sweep_response, ns_new.sweep_response) - assert np.allclose(ns.mean_sweep_response, ns_new.mean_sweep_response, - equal_nan=True) + assert np.allclose(ns.mean_sweep_response, ns_new.mean_sweep_response, equal_nan=True) - assert np.allclose(ns.noise_correlation, ns_new.noise_correlation, - equal_nan=True) - assert np.allclose(ns.signal_correlation, ns_new.signal_correlation, - equal_nan=True) - assert np.allclose(ns.representational_similarity, - ns_new.representational_similarity, equal_nan=True) + assert np.allclose(ns.noise_correlation, ns_new.noise_correlation, equal_nan=True) + assert np.allclose(ns.signal_correlation, ns_new.signal_correlation, equal_nan=True) + assert np.allclose(ns.representational_similarity, ns_new.representational_similarity, equal_nan=True) @pytest.mark.nightly @@ -314,34 +287,24 @@ def test_session_c(analysis_c, analysis_c_new): @pytest.mark.nightly def test_locally_sparse_noise(lsn, nwb_c, analysis_c_new): ds = BODS(nwb_c) - session_type = ds.get_metadata()['session_type'] + session_type = ds.get_metadata()["session_type"] logging.debug(session_type) if session_type == si.THREE_SESSION_C: - lsn_new = LocallySparseNoise.from_analysis_file( - ds, analysis_c_new, - si.LOCALLY_SPARSE_NOISE) + lsn_new = LocallySparseNoise.from_analysis_file(ds, analysis_c_new, si.LOCALLY_SPARSE_NOISE) elif session_type == si.THREE_SESSION_C2: - lsn_new = LocallySparseNoise.from_analysis_file( - ds, analysis_c_new, - si.LOCALLY_SPARSE_NOISE_4DEG) + lsn_new = LocallySparseNoise.from_analysis_file(ds, analysis_c_new, si.LOCALLY_SPARSE_NOISE_4DEG) # assert np.allclose(lsn.sweep_response, lsn_new.sweep_response) - assert np.allclose(lsn.mean_sweep_response, lsn_new.mean_sweep_response, - equal_nan=True) + assert np.allclose(lsn.mean_sweep_response, lsn_new.mean_sweep_response, equal_nan=True) @pytest.mark.nightly def test_natural_movie_one_c(nm1c, nwb_c, analysis_c_new): - nm1c_new = NaturalMovie.from_analysis_file(BODS(nwb_c), analysis_c_new, - si.NATURAL_MOVIE_ONE) + nm1c_new = NaturalMovie.from_analysis_file(BODS(nwb_c), analysis_c_new, si.NATURAL_MOVIE_ONE) # assert np.allclose(nm1c.sweep_response, nm1c_new.sweep_response) - assert np.allclose(nm1c.binned_dx_sp, nm1c_new.binned_dx_sp, - equal_nan=True) - assert np.allclose(nm1c.binned_dx_vis, nm1c_new.binned_dx_vis, - equal_nan=True) - assert np.allclose(nm1c.binned_cells_sp, nm1c_new.binned_cells_sp, - equal_nan=True) - assert np.allclose(nm1c.binned_cells_vis, nm1c_new.binned_cells_vis, - equal_nan=True) + assert np.allclose(nm1c.binned_dx_sp, nm1c_new.binned_dx_sp, equal_nan=True) + assert np.allclose(nm1c.binned_dx_vis, nm1c_new.binned_dx_vis, equal_nan=True) + assert np.allclose(nm1c.binned_cells_sp, nm1c_new.binned_cells_sp, equal_nan=True) + assert np.allclose(nm1c.binned_cells_vis, nm1c_new.binned_cells_vis, equal_nan=True) diff --git a/allensdk/test/brain_observatory/test_session_api_utils.py b/allensdk/test/brain_observatory/test_session_api_utils.py index 00c61d9714..e6631f8723 100644 --- a/allensdk/test/brain_observatory/test_session_api_utils.py +++ b/allensdk/test/brain_observatory/test_session_api_utils.py @@ -9,12 +9,18 @@ class ParamsMixinTestHarness(ParamsMixin): - - def __init__(self, param_to_ignore, a_param_1: int, a_param_2: float, - b_param_1: list, c_param_1: bool, d_param_1: np.ndarray, - e_param_1: pd.Series, f_param_1: pd.DataFrame): - - super().__init__(ignore={'param_to_ignore'}) + def __init__( + self, + param_to_ignore, + a_param_1: int, + a_param_2: float, + b_param_1: list, + c_param_1: bool, + d_param_1: np.ndarray, + e_param_1: pd.Series, + f_param_1: pd.DataFrame, + ): + super().__init__(ignore={"param_to_ignore"}) self._a_param_1 = a_param_1 self._a_param_2 = a_param_2 @@ -27,287 +33,402 @@ def __init__(self, param_to_ignore, a_param_1: int, a_param_2: float, @pytest.fixture def mixin_harness_fixture(request) -> ParamsMixinTestHarness: - param_to_ignore = request.param.get('param_to_ignore', 'x') - a_param_1 = request.param.get('a_param_1', 8) - a_param_2 = request.param.get('a_param_2', 42.0) - b_param_1 = request.param.get('b_param_1', [1, 2, 3]) - c_param_1 = request.param.get('c_param_1', True) - d_param_1 = request.param.get('d_param_1', np.array([5, 5])) - e_param_1 = request.param.get('e_param_1', pd.Series([4.0, 5.0])) - f_param_1 = request.param.get('f_param_1', pd.DataFrame([1, 2, 3])) - - mixed_in = ParamsMixinTestHarness(param_to_ignore, a_param_1, - a_param_2, b_param_1, c_param_1, - d_param_1, e_param_1, f_param_1) - mixed_in._updated_params = request.param.get('updated_params', set()) + param_to_ignore = request.param.get("param_to_ignore", "x") + a_param_1 = request.param.get("a_param_1", 8) + a_param_2 = request.param.get("a_param_2", 42.0) + b_param_1 = request.param.get("b_param_1", [1, 2, 3]) + c_param_1 = request.param.get("c_param_1", True) + d_param_1 = request.param.get("d_param_1", np.array([5, 5])) + e_param_1 = request.param.get("e_param_1", pd.Series([4.0, 5.0])) + f_param_1 = request.param.get("f_param_1", pd.DataFrame([1, 2, 3])) + + mixed_in = ParamsMixinTestHarness( + param_to_ignore, a_param_1, a_param_2, b_param_1, c_param_1, d_param_1, e_param_1, f_param_1 + ) + mixed_in._updated_params = request.param.get("updated_params", set()) return mixed_in -@pytest.mark.parametrize("a, b, expected", [ - (2, 2, True), - ('1', '1', True), - (1.5, 1.5, True), - ([1, 2, 3], [1, 2, 3], True), - ({1, 2, 3}, {1, 2, 3}, True), - ({'a', 'b', 'c'}, {'c', 'a', 'b'}, True), - ({'a': 0, 'z': 42}, {'a': 0, 'z': 42}, True), - (np.array([1, 2, 3]), np.array([1, 2, 3]), True), - ({'c': np.array([5, 5])}, {'c': np.array([5, 5])}, True), - (pd.Series([5, 5, 5]), pd.Series([5, 5, 5]), True), - (pd.DataFrame([10, 10]), pd.DataFrame([10, 10]), True), - ([pd.DataFrame(['a', 'b', 'c'])], [pd.DataFrame(['a', 'b', 'c'])], True), - ({'a': np.array([1, 2, 3])}, {'a': np.array([1, 2, 3])}, True), - ({'a': {'x': pd.Series([5.0, 6.0])}}, {'a': {'x': pd.Series([5.0, 6.0])}}, True), - ({'a': 20, 'b': 30}, {'b': 30, 'a': 20}, True), - - (1, 2.0, False), - ('1', 2, False), - ([1, 2, 3], 5, False), - ([1, 2, 3], [1, 2], False), - ([1, 2, 3], [3, 2, 1], False), - (['a', 'b'], {'a', 'b'}, False), - ({'a'}, {'a', 'b'}, False), - ({'a'}, {'b'}, False), - ({'a', 'b'}, np.array(['a', 'b']), False), - (np.array([3, 4, 5]), np.array([3, 4]), False), - ({'c': np.array([5, 5])}, {'c': np.array([5, 6])}, False), - (pd.Series([5, 5, 5]), pd.Series([5, 6, 5]), False), - (pd.Series([1, 2, 3]), pd.Series([1, 2]), False), - (pd.DataFrame([10, 10]), pd.DataFrame([10, 7]), False), - (pd.DataFrame([10, 20, 30]), pd.DataFrame([10, 20]), False), - ([pd.DataFrame(['a', 'b', 'c'])], [pd.DataFrame(['a', 'b', 'd'])], False), - ({'a': np.array([1, 2, 3])}, {'a': np.array([1, 2, 5])}, False), - ({'a': {'x': pd.Series([5.0, 6.0])}}, {'a': {'x': pd.Series([5.0, 7.0])}}, False), - - (pd.Series([5, 5, 5]), np.array([5, 5, 5]), False), - (np.array([8, 8, 8]), pd.DataFrame([8, 8, 8]), False), - (pd.Series([3, 3, 3]), pd.DataFrame([3, 3, 3]), False), -]) +@pytest.mark.parametrize( + "a, b, expected", + [ + (2, 2, True), + ("1", "1", True), + (1.5, 1.5, True), + ([1, 2, 3], [1, 2, 3], True), + ({1, 2, 3}, {1, 2, 3}, True), + ({"a", "b", "c"}, {"c", "a", "b"}, True), + ({"a": 0, "z": 42}, {"a": 0, "z": 42}, True), + (np.array([1, 2, 3]), np.array([1, 2, 3]), True), + ({"c": np.array([5, 5])}, {"c": np.array([5, 5])}, True), + (pd.Series([5, 5, 5]), pd.Series([5, 5, 5]), True), + (pd.DataFrame([10, 10]), pd.DataFrame([10, 10]), True), + ([pd.DataFrame(["a", "b", "c"])], [pd.DataFrame(["a", "b", "c"])], True), + ({"a": np.array([1, 2, 3])}, {"a": np.array([1, 2, 3])}, True), + ({"a": {"x": pd.Series([5.0, 6.0])}}, {"a": {"x": pd.Series([5.0, 6.0])}}, True), + ({"a": 20, "b": 30}, {"b": 30, "a": 20}, True), + (1, 2.0, False), + ("1", 2, False), + ([1, 2, 3], 5, False), + ([1, 2, 3], [1, 2], False), + ([1, 2, 3], [3, 2, 1], False), + (["a", "b"], {"a", "b"}, False), + ({"a"}, {"a", "b"}, False), + ({"a"}, {"b"}, False), + ({"a", "b"}, np.array(["a", "b"]), False), + (np.array([3, 4, 5]), np.array([3, 4]), False), + ({"c": np.array([5, 5])}, {"c": np.array([5, 6])}, False), + (pd.Series([5, 5, 5]), pd.Series([5, 6, 5]), False), + (pd.Series([1, 2, 3]), pd.Series([1, 2]), False), + (pd.DataFrame([10, 10]), pd.DataFrame([10, 7]), False), + (pd.DataFrame([10, 20, 30]), pd.DataFrame([10, 20]), False), + ([pd.DataFrame(["a", "b", "c"])], [pd.DataFrame(["a", "b", "d"])], False), + ({"a": np.array([1, 2, 3])}, {"a": np.array([1, 2, 5])}, False), + ({"a": {"x": pd.Series([5.0, 6.0])}}, {"a": {"x": pd.Series([5.0, 7.0])}}, False), + (pd.Series([5, 5, 5]), np.array([5, 5, 5]), False), + (np.array([8, 8, 8]), pd.DataFrame([8, 8, 8]), False), + (pd.Series([3, 3, 3]), pd.DataFrame([3, 3, 3]), False), + ], +) def test_is_equal(a, b, expected): assert is_equal(a, b) == expected -@pytest.mark.parametrize("mixin_harness_fixture, expected", [ - ({}, - [Parameter('param_to_ignore', Parameter.POSITIONAL_OR_KEYWORD), - Parameter('a_param_1', Parameter.POSITIONAL_OR_KEYWORD, annotation=int), - Parameter('a_param_2', Parameter.POSITIONAL_OR_KEYWORD, annotation=float), - Parameter('b_param_1', Parameter.POSITIONAL_OR_KEYWORD, annotation=list), - Parameter('c_param_1', Parameter.POSITIONAL_OR_KEYWORD, annotation=bool), - Parameter('d_param_1', Parameter.POSITIONAL_OR_KEYWORD, annotation=np.ndarray), - Parameter('e_param_1', Parameter.POSITIONAL_OR_KEYWORD, annotation=pd.Series), - Parameter('f_param_1', Parameter.POSITIONAL_OR_KEYWORD, annotation=pd.DataFrame)]), -], indirect=["mixin_harness_fixture"]) +@pytest.mark.parametrize( + "mixin_harness_fixture, expected", + [ + ( + {}, + [ + Parameter("param_to_ignore", Parameter.POSITIONAL_OR_KEYWORD), + Parameter("a_param_1", Parameter.POSITIONAL_OR_KEYWORD, annotation=int), + Parameter("a_param_2", Parameter.POSITIONAL_OR_KEYWORD, annotation=float), + Parameter("b_param_1", Parameter.POSITIONAL_OR_KEYWORD, annotation=list), + Parameter("c_param_1", Parameter.POSITIONAL_OR_KEYWORD, annotation=bool), + Parameter("d_param_1", Parameter.POSITIONAL_OR_KEYWORD, annotation=np.ndarray), + Parameter("e_param_1", Parameter.POSITIONAL_OR_KEYWORD, annotation=pd.Series), + Parameter("f_param_1", Parameter.POSITIONAL_OR_KEYWORD, annotation=pd.DataFrame), + ], + ), + ], + indirect=["mixin_harness_fixture"], +) def test_get_param_signatures(mixin_harness_fixture, expected): obtained = mixin_harness_fixture._get_param_signatures() assert obtained == expected -@pytest.mark.parametrize("mixin_harness_fixture, expected", [ - ({}, - {'param_to_ignore': Parameter.empty, 'a_param_1': int, - 'a_param_2': float, 'b_param_1': list, 'c_param_1': bool, - 'd_param_1': np.ndarray, 'e_param_1': pd.Series, - 'f_param_1': pd.DataFrame}), -], indirect=["mixin_harness_fixture"]) +@pytest.mark.parametrize( + "mixin_harness_fixture, expected", + [ + ( + {}, + { + "param_to_ignore": Parameter.empty, + "a_param_1": int, + "a_param_2": float, + "b_param_1": list, + "c_param_1": bool, + "d_param_1": np.ndarray, + "e_param_1": pd.Series, + "f_param_1": pd.DataFrame, + }, + ), + ], + indirect=["mixin_harness_fixture"], +) def test_get_param_type_annotations(mixin_harness_fixture, expected): obtained = mixin_harness_fixture._get_param_type_annotations() assert obtained == expected -@pytest.mark.parametrize("mixin_harness_fixture, expected", [ - ({}, - ['a_param_1', 'a_param_2', 'b_param_1', 'c_param_1', 'd_param_1', - 'e_param_1', 'f_param_1', 'param_to_ignore']), -], indirect=["mixin_harness_fixture"]) +@pytest.mark.parametrize( + "mixin_harness_fixture, expected", + [ + ( + {}, + [ + "a_param_1", + "a_param_2", + "b_param_1", + "c_param_1", + "d_param_1", + "e_param_1", + "f_param_1", + "param_to_ignore", + ], + ), + ], + indirect=["mixin_harness_fixture"], +) def test_get_param_names(mixin_harness_fixture, expected): obtained = mixin_harness_fixture._get_param_names() assert obtained == expected -@pytest.mark.parametrize("mixin_harness_fixture, expected", [ - ({}, - {'a_param_1': 8, 'a_param_2': 42.0, 'b_param_1': [1, 2, 3], - 'c_param_1': True, 'd_param_1': np.array([5, 5]), - 'e_param_1': pd.Series([4.0, 5.0]), 'f_param_1': pd.DataFrame([1, 2, 3])}), - - ({'a_param_1': 2, 'a_param_2': 10.0, 'b_param_1': [1], 'c_param_1': False}, - {'a_param_1': 2, 'a_param_2': 10.0, 'b_param_1': [1], - 'c_param_1': False, 'd_param_1': np.array([5, 5]), - 'e_param_1': pd.Series([4.0, 5.0]), 'f_param_1': pd.DataFrame([1, 2, 3])}) -], indirect=["mixin_harness_fixture"]) +@pytest.mark.parametrize( + "mixin_harness_fixture, expected", + [ + ( + {}, + { + "a_param_1": 8, + "a_param_2": 42.0, + "b_param_1": [1, 2, 3], + "c_param_1": True, + "d_param_1": np.array([5, 5]), + "e_param_1": pd.Series([4.0, 5.0]), + "f_param_1": pd.DataFrame([1, 2, 3]), + }, + ), + ( + {"a_param_1": 2, "a_param_2": 10.0, "b_param_1": [1], "c_param_1": False}, + { + "a_param_1": 2, + "a_param_2": 10.0, + "b_param_1": [1], + "c_param_1": False, + "d_param_1": np.array([5, 5]), + "e_param_1": pd.Series([4.0, 5.0]), + "f_param_1": pd.DataFrame([1, 2, 3]), + }, + ), + ], + indirect=["mixin_harness_fixture"], +) def test_get_params(mixin_harness_fixture, expected): obtained = mixin_harness_fixture.get_params() is_equal(obtained, expected) -@pytest.mark.parametrize("mixin_harness_fixture, params_to_set, expected", [ - ({}, - {'a_param_1': 5}, - {'a_param_1': 5, 'a_param_2': 42.0, 'b_param_1': [1, 2, 3], - 'c_param_1': True, 'd_param_1': np.array([5, 5]), - 'e_param_1': pd.Series([4.0, 5.0]), 'f_param_1': pd.DataFrame([1, 2, 3])}), - - ({}, - {'a_param_2': 10.0}, - {'a_param_1': 8, 'a_param_2': 10.0, 'b_param_1': [1, 2, 3], - 'c_param_1': True, 'd_param_1': np.array([5, 5]), - 'e_param_1': pd.Series([4.0, 5.0]), 'f_param_1': pd.DataFrame([1, 2, 3])}), - - ({}, - {'b_param_1': [3, 4, 5]}, - {'a_param_1': 8, 'a_param_2': 42.0, 'b_param_1': [3, 4, 5], - 'c_param_1': True, 'd_param_1': np.array([5, 5]), - 'e_param_1': pd.Series([4.0, 5.0]), 'f_param_1': pd.DataFrame([1, 2, 3])}), - - ({}, - {'a_param_1': 20, 'a_param_2': 3.14, 'b_param_1': [9, 10]}, - {'a_param_1': 20, 'a_param_2': 3.14, 'b_param_1': [9, 10], - 'c_param_1': True, 'd_param_1': np.array([5, 5]), - 'e_param_1': pd.Series([4.0, 5.0]), 'f_param_1': pd.DataFrame([1, 2, 3])}), - - ({}, - {'d_param_1': np.array([20, 20]), 'e_param_1': pd.Series([1, 2, 3]), 'b_param_1': [9, 10]}, - {'a_param_1': 20, 'a_param_2': 3.14, 'b_param_1': [9, 10], - 'c_param_1': True, 'd_param_1': np.array([20, 20]), - 'e_param_1': pd.Series([1, 2, 3]), 'f_param_1': pd.DataFrame([1, 2, 3])}), - -], indirect=["mixin_harness_fixture"]) +@pytest.mark.parametrize( + "mixin_harness_fixture, params_to_set, expected", + [ + ( + {}, + {"a_param_1": 5}, + { + "a_param_1": 5, + "a_param_2": 42.0, + "b_param_1": [1, 2, 3], + "c_param_1": True, + "d_param_1": np.array([5, 5]), + "e_param_1": pd.Series([4.0, 5.0]), + "f_param_1": pd.DataFrame([1, 2, 3]), + }, + ), + ( + {}, + {"a_param_2": 10.0}, + { + "a_param_1": 8, + "a_param_2": 10.0, + "b_param_1": [1, 2, 3], + "c_param_1": True, + "d_param_1": np.array([5, 5]), + "e_param_1": pd.Series([4.0, 5.0]), + "f_param_1": pd.DataFrame([1, 2, 3]), + }, + ), + ( + {}, + {"b_param_1": [3, 4, 5]}, + { + "a_param_1": 8, + "a_param_2": 42.0, + "b_param_1": [3, 4, 5], + "c_param_1": True, + "d_param_1": np.array([5, 5]), + "e_param_1": pd.Series([4.0, 5.0]), + "f_param_1": pd.DataFrame([1, 2, 3]), + }, + ), + ( + {}, + {"a_param_1": 20, "a_param_2": 3.14, "b_param_1": [9, 10]}, + { + "a_param_1": 20, + "a_param_2": 3.14, + "b_param_1": [9, 10], + "c_param_1": True, + "d_param_1": np.array([5, 5]), + "e_param_1": pd.Series([4.0, 5.0]), + "f_param_1": pd.DataFrame([1, 2, 3]), + }, + ), + ( + {}, + {"d_param_1": np.array([20, 20]), "e_param_1": pd.Series([1, 2, 3]), "b_param_1": [9, 10]}, + { + "a_param_1": 20, + "a_param_2": 3.14, + "b_param_1": [9, 10], + "c_param_1": True, + "d_param_1": np.array([20, 20]), + "e_param_1": pd.Series([1, 2, 3]), + "f_param_1": pd.DataFrame([1, 2, 3]), + }, + ), + ], + indirect=["mixin_harness_fixture"], +) def test_set_params_basic(mixin_harness_fixture, params_to_set, expected): mixin_harness_fixture.set_params(**params_to_set) obtained = mixin_harness_fixture.get_params() is_equal(obtained, expected) -@pytest.mark.parametrize("mixin_harness_fixture, params_to_set, expected", [ - ({}, - {'a_param': 5}, - {'a_param_1': 8, 'a_param_2': 42.0, 'b_param_1': [1, 2, 3], - 'c_param_1': True, 'd_param_1': np.array([5, 5]), - 'e_param_1': pd.Series([4.0, 5.0]), 'f_param_1': pd.DataFrame([1, 2, 3])}), - - ({}, - {'something_random': 10.0}, - {'a_param_1': 8, 'a_param_2': 42.0, 'b_param_1': [1, 2, 3], - 'c_param_1': True, 'd_param_1': np.array([5, 5]), - 'e_param_1': pd.Series([4.0, 5.0]), 'f_param_1': pd.DataFrame([1, 2, 3])}), - -], indirect=["mixin_harness_fixture"]) -def test_set_params_with_invalid_params(mixin_harness_fixture, - params_to_set, expected): +@pytest.mark.parametrize( + "mixin_harness_fixture, params_to_set, expected", + [ + ( + {}, + {"a_param": 5}, + { + "a_param_1": 8, + "a_param_2": 42.0, + "b_param_1": [1, 2, 3], + "c_param_1": True, + "d_param_1": np.array([5, 5]), + "e_param_1": pd.Series([4.0, 5.0]), + "f_param_1": pd.DataFrame([1, 2, 3]), + }, + ), + ( + {}, + {"something_random": 10.0}, + { + "a_param_1": 8, + "a_param_2": 42.0, + "b_param_1": [1, 2, 3], + "c_param_1": True, + "d_param_1": np.array([5, 5]), + "e_param_1": pd.Series([4.0, 5.0]), + "f_param_1": pd.DataFrame([1, 2, 3]), + }, + ), + ], + indirect=["mixin_harness_fixture"], +) +def test_set_params_with_invalid_params(mixin_harness_fixture, params_to_set, expected): with warnings.catch_warnings(record=True) as w: mixin_harness_fixture.set_params(**params_to_set) - assert 'not valid and is being ignored' in str(w[-1].message) + assert "not valid and is being ignored" in str(w[-1].message) obtained = mixin_harness_fixture.get_params() is_equal(obtained, expected) -@pytest.mark.parametrize("mixin_harness_fixture, params_to_set, expected", [ - ({}, - {'a_param_1': 'hello'}, - {'a_param_1': 8, 'a_param_2': 42.0, 'b_param_1': [1, 2, 3], - 'c_param_1': True, 'd_param_1': np.array([5, 5]), - 'e_param_1': pd.Series([4.0, 5.0]), 'f_param_1': pd.DataFrame([1, 2, 3])}), - - ({}, - {'a_param_2': [5]}, - {'a_param_1': 8, 'a_param_2': 42.0, 'b_param_1': [1, 2, 3], - 'c_param_1': True, 'd_param_1': np.array([5, 5]), - 'e_param_1': pd.Series([4.0, 5.0]), 'f_param_1': pd.DataFrame([1, 2, 3])}), - - ({}, - {'b_param_1': 1}, - {'a_param_1': 8, 'a_param_2': 42.0, 'b_param_1': [1, 2, 3], - 'c_param_1': True, 'd_param_1': np.array([5, 5]), - 'e_param_1': pd.Series([4.0, 5.0]), 'f_param_1': pd.DataFrame([1, 2, 3])}), - - ({}, - {'d_param_1': [1, 2, 3]}, - {'a_param_1': 8, 'a_param_2': 42.0, 'b_param_1': [1, 2, 3], - 'c_param_1': True, 'd_param_1': np.array([5, 5]), - 'e_param_1': pd.Series([4.0, 5.0]), 'f_param_1': pd.DataFrame([1, 2, 3])}), - - ({}, - {'e_param_1': {1, 2, 3}}, - {'a_param_1': 8, 'a_param_2': 42.0, 'b_param_1': [1, 2, 3], - 'c_param_1': True, 'd_param_1': np.array([5, 5]), - 'e_param_1': pd.Series([4.0, 5.0]), 'f_param_1': pd.DataFrame([1, 2, 3])}), - -], indirect=["mixin_harness_fixture"]) +@pytest.mark.parametrize( + "mixin_harness_fixture, params_to_set, expected", + [ + ( + {}, + {"a_param_1": "hello"}, + { + "a_param_1": 8, + "a_param_2": 42.0, + "b_param_1": [1, 2, 3], + "c_param_1": True, + "d_param_1": np.array([5, 5]), + "e_param_1": pd.Series([4.0, 5.0]), + "f_param_1": pd.DataFrame([1, 2, 3]), + }, + ), + ( + {}, + {"a_param_2": [5]}, + { + "a_param_1": 8, + "a_param_2": 42.0, + "b_param_1": [1, 2, 3], + "c_param_1": True, + "d_param_1": np.array([5, 5]), + "e_param_1": pd.Series([4.0, 5.0]), + "f_param_1": pd.DataFrame([1, 2, 3]), + }, + ), + ( + {}, + {"b_param_1": 1}, + { + "a_param_1": 8, + "a_param_2": 42.0, + "b_param_1": [1, 2, 3], + "c_param_1": True, + "d_param_1": np.array([5, 5]), + "e_param_1": pd.Series([4.0, 5.0]), + "f_param_1": pd.DataFrame([1, 2, 3]), + }, + ), + ( + {}, + {"d_param_1": [1, 2, 3]}, + { + "a_param_1": 8, + "a_param_2": 42.0, + "b_param_1": [1, 2, 3], + "c_param_1": True, + "d_param_1": np.array([5, 5]), + "e_param_1": pd.Series([4.0, 5.0]), + "f_param_1": pd.DataFrame([1, 2, 3]), + }, + ), + ( + {}, + {"e_param_1": {1, 2, 3}}, + { + "a_param_1": 8, + "a_param_2": 42.0, + "b_param_1": [1, 2, 3], + "c_param_1": True, + "d_param_1": np.array([5, 5]), + "e_param_1": pd.Series([4.0, 5.0]), + "f_param_1": pd.DataFrame([1, 2, 3]), + }, + ), + ], + indirect=["mixin_harness_fixture"], +) def test_set_params_with_invalid_type(mixin_harness_fixture, params_to_set, expected): with warnings.catch_warnings(record=True) as w: mixin_harness_fixture.set_params(**params_to_set) - assert 'should be of type' in str(w[-1].message) + assert "should be of type" in str(w[-1].message) obtained = mixin_harness_fixture.get_params() is_equal(obtained, expected) -@pytest.mark.parametrize("mixin_harness_fixture, params_to_set, data_params, expected", [ - ({}, - {'a_param_1': 42}, - {'a_param_1'}, - True), - - ({}, - {'a_param_1': 8}, - {'a_param_1'}, - False), - - ({}, - {'a_param_1': 8.0}, - {'a_param_1'}, - False), - - ({}, - {'a_param_2': 3.0}, - {'a_param_1'}, - False), - - ({}, - {'a_param_2': 2.5, 'b_param_1': ['a', 'b', 'c']}, - {'b_param_1'}, - True), - - ({}, - {'a_param_1': 10, 'a_param_2': 9.0}, - {'b_param_1'}, - False), - - ({}, - {'d_param_1': np.array([98, 99, 100]), 'a_param_2': 9.0}, - {'d_param_1'}, - True), - - ({}, - {'d_param_1': np.array([98, 99, 100]), 'e_param_1': pd.Series([1, 3, 5])}, - {'e_param_1'}, - True), - - ({}, - {'d_param_1': np.array([98, 99, 100]), 'e_param_1': pd.Series([4.0, 5.0])}, - {'e_param_1'}, - False), - - -], indirect=["mixin_harness_fixture"]) +@pytest.mark.parametrize( + "mixin_harness_fixture, params_to_set, data_params, expected", + [ + ({}, {"a_param_1": 42}, {"a_param_1"}, True), + ({}, {"a_param_1": 8}, {"a_param_1"}, False), + ({}, {"a_param_1": 8.0}, {"a_param_1"}, False), + ({}, {"a_param_2": 3.0}, {"a_param_1"}, False), + ({}, {"a_param_2": 2.5, "b_param_1": ["a", "b", "c"]}, {"b_param_1"}, True), + ({}, {"a_param_1": 10, "a_param_2": 9.0}, {"b_param_1"}, False), + ({}, {"d_param_1": np.array([98, 99, 100]), "a_param_2": 9.0}, {"d_param_1"}, True), + ({}, {"d_param_1": np.array([98, 99, 100]), "e_param_1": pd.Series([1, 3, 5])}, {"e_param_1"}, True), + ({}, {"d_param_1": np.array([98, 99, 100]), "e_param_1": pd.Series([4.0, 5.0])}, {"e_param_1"}, False), + ], + indirect=["mixin_harness_fixture"], +) def test_needs_data_refresh(mixin_harness_fixture, params_to_set, data_params, expected): mixin_harness_fixture.set_params(**params_to_set) obtained = mixin_harness_fixture.needs_data_refresh(data_params) assert obtained == expected -@pytest.mark.parametrize("mixin_harness_fixture, data_params, expected", [ - ({'updated_params': {'a_param_1', 'b_param_1'}}, - {'a_param_1'}, - {'b_param_1'}), - - ({'updated_params': {'a_param_1', 'a_param_2', 'b_param_1'}}, - {'a_param_1', 'a_param_2'}, - {'b_param_1'}), -], indirect=["mixin_harness_fixture"]) +@pytest.mark.parametrize( + "mixin_harness_fixture, data_params, expected", + [ + ({"updated_params": {"a_param_1", "b_param_1"}}, {"a_param_1"}, {"b_param_1"}), + ({"updated_params": {"a_param_1", "a_param_2", "b_param_1"}}, {"a_param_1", "a_param_2"}, {"b_param_1"}), + ], + indirect=["mixin_harness_fixture"], +) def test_clear_updated_params(mixin_harness_fixture, data_params, expected): mixin_harness_fixture.clear_updated_params(data_params) assert mixin_harness_fixture._updated_params == expected diff --git a/allensdk/test/brain_observatory/test_static_gratings.py b/allensdk/test/brain_observatory/test_static_gratings.py index 48e3d9dc2a..fb5f0df931 100644 --- a/allensdk/test/brain_observatory/test_static_gratings.py +++ b/allensdk/test/brain_observatory/test_static_gratings.py @@ -41,58 +41,48 @@ @pytest.fixture def dataset(): - dataset = MagicMock(name='dataset') - - timestamps = MagicMock(name='timestamps') - celltraces = MagicMock(name='celltraces') - dataset.get_corrected_fluorescence_traces = \ - MagicMock(name='get_corrected_fluorescence_traces', - return_value=(timestamps, celltraces)) - dataset.get_roi_ids = MagicMock(name='get_roi_ids') - dataset.get_cell_specimen_ids = MagicMock(name='get_cell_specimen_ids') + dataset = MagicMock(name="dataset") + + timestamps = MagicMock(name="timestamps") + celltraces = MagicMock(name="celltraces") + dataset.get_corrected_fluorescence_traces = MagicMock( + name="get_corrected_fluorescence_traces", return_value=(timestamps, celltraces) + ) + dataset.get_roi_ids = MagicMock(name="get_roi_ids") + dataset.get_cell_specimen_ids = MagicMock(name="get_cell_specimen_ids") dff_traces = MagicMock(name="dfftraces") - dataset.get_dff_traces = MagicMock(name='get_dff_traces', - return_value=(None, dff_traces)) - dxcm = MagicMock(name='dxcm') - dxtime = MagicMock(name='dxtime') - dataset.get_running_speed=MagicMock(name='get_running_speed', - return_value=(dxcm, dxtime)) - dataset.get_stimulus_table=MagicMock(name='get_stimulus_table', - return_value=MagicMock()) - + dataset.get_dff_traces = MagicMock(name="get_dff_traces", return_value=(None, dff_traces)) + dxcm = MagicMock(name="dxcm") + dxtime = MagicMock(name="dxtime") + dataset.get_running_speed = MagicMock(name="get_running_speed", return_value=(dxcm, dxtime)) + dataset.get_stimulus_table = MagicMock(name="get_stimulus_table", return_value=MagicMock()) + return dataset + def mock_speed_tuning(): - binned_dx_sp = MagicMock(name='binned_dx_sp') - binned_cells_sp = MagicMock(name='binned_cells_sp') - binned_dx_vis = MagicMock(name='binned_dx_vis') - binned_cells_vis = MagicMock(name='binned_cells_vis') - peak_run = MagicMock(name='peak_run') - - return MagicMock(name='get_speed_tuning', - return_value=(binned_dx_sp, - binned_cells_sp, - binned_dx_vis, - binned_cells_vis, - peak_run)) + binned_dx_sp = MagicMock(name="binned_dx_sp") + binned_cells_sp = MagicMock(name="binned_cells_sp") + binned_dx_vis = MagicMock(name="binned_dx_vis") + binned_cells_vis = MagicMock(name="binned_cells_vis") + peak_run = MagicMock(name="peak_run") + + return MagicMock( + name="get_speed_tuning", return_value=(binned_dx_sp, binned_cells_sp, binned_dx_vis, binned_cells_vis, peak_run) + ) + def mock_sweep_response(): - sweep_response = MagicMock(name='sweep_response') - mean_sweep_response = MagicMock(name='mean_sweep_response') - pval = MagicMock(name='pval') - - return MagicMock(name='get_sweep_response', - return_value=(sweep_response, - mean_sweep_response, - pval)) - -@patch.object(StimulusAnalysis, - 'get_speed_tuning', - mock_speed_tuning()) -@patch.object(StimulusAnalysis, - 'get_sweep_response', - mock_sweep_response()) -@pytest.mark.parametrize('trigger', (1, 2, 3, 4, 5, 6, 7, 8, 9, 10)) + sweep_response = MagicMock(name="sweep_response") + mean_sweep_response = MagicMock(name="mean_sweep_response") + pval = MagicMock(name="pval") + + return MagicMock(name="get_sweep_response", return_value=(sweep_response, mean_sweep_response, pval)) + + +@patch.object(StimulusAnalysis, "get_speed_tuning", mock_speed_tuning()) +@patch.object(StimulusAnalysis, "get_sweep_response", mock_sweep_response()) +@pytest.mark.parametrize("trigger", (1, 2, 3, 4, 5, 6, 7, 8, 9, 10)) def test_harness(dataset, trigger): sg = StaticGratings(dataset) diff --git a/allensdk/test/brain_observatory/test_stimulus_analysis.py b/allensdk/test/brain_observatory/test_stimulus_analysis.py index 5d88a7288b..53043d21c5 100644 --- a/allensdk/test/brain_observatory/test_stimulus_analysis.py +++ b/allensdk/test/brain_observatory/test_stimulus_analysis.py @@ -40,45 +40,40 @@ @pytest.fixture def dataset(): - dataset = MagicMock(name='dataset') - - timestamps = MagicMock(name='timestamps') - celltraces = MagicMock(name='celltraces') - dataset.get_corrected_fluorescence_traces = \ - MagicMock(name='get_corrected_fluorescence_traces', - return_value=(timestamps, celltraces)) - dataset.get_roi_ids = MagicMock(name='get_roi_ids') - dataset.get_cell_specimen_ids = MagicMock(name='get_cell_specimen_ids') + dataset = MagicMock(name="dataset") + + timestamps = MagicMock(name="timestamps") + celltraces = MagicMock(name="celltraces") + dataset.get_corrected_fluorescence_traces = MagicMock( + name="get_corrected_fluorescence_traces", return_value=(timestamps, celltraces) + ) + dataset.get_roi_ids = MagicMock(name="get_roi_ids") + dataset.get_cell_specimen_ids = MagicMock(name="get_cell_specimen_ids") dff_traces = MagicMock(name="dfftraces") - dataset.get_dff_traces = MagicMock(name='get_dff_traces', - return_value=(None, dff_traces)) - dxcm = MagicMock(name='dxcm') - dxtime = MagicMock(name='dxtime') - dataset.get_running_speed=MagicMock(name='get_running_speed', - return_value=(dxcm, dxtime)) + dataset.get_dff_traces = MagicMock(name="get_dff_traces", return_value=(None, dff_traces)) + dxcm = MagicMock(name="dxcm") + dxtime = MagicMock(name="dxtime") + dataset.get_running_speed = MagicMock(name="get_running_speed", return_value=(dxcm, dxtime)) return dataset + def mock_speed_tuning(): - binned_dx_sp = MagicMock(name='binned_dx_sp') - binned_cells_sp = MagicMock(name='binned_cells_sp') - binned_dx_vis = MagicMock(name='binned_dx_vis') - binned_cells_vis = MagicMock(name='binned_cells_vis') - peak_run = MagicMock(name='peak_run') - - return MagicMock(name='get_speed_tuning', - return_value=(binned_dx_sp, - binned_cells_sp, - binned_dx_vis, - binned_cells_vis, - peak_run)) - - -@pytest.mark.parametrize('trigger', - (1,2,3,4,5)) -def test_harness(dataset, - trigger): - with patch('allensdk.brain_observatory.stimulus_analysis.StimulusAnalysis.get_speed_tuning', - mock_speed_tuning()) as get_speed_tuning: + binned_dx_sp = MagicMock(name="binned_dx_sp") + binned_cells_sp = MagicMock(name="binned_cells_sp") + binned_dx_vis = MagicMock(name="binned_dx_vis") + binned_cells_vis = MagicMock(name="binned_cells_vis") + peak_run = MagicMock(name="peak_run") + + return MagicMock( + name="get_speed_tuning", return_value=(binned_dx_sp, binned_cells_sp, binned_dx_vis, binned_cells_vis, peak_run) + ) + + +@pytest.mark.parametrize("trigger", (1, 2, 3, 4, 5)) +def test_harness(dataset, trigger): + with patch( + "allensdk.brain_observatory.stimulus_analysis.StimulusAnalysis.get_speed_tuning", mock_speed_tuning() + ) as get_speed_tuning: sa = StimulusAnalysis(dataset) assert sa._timestamps == StimulusAnalysis._PRELOAD diff --git a/allensdk/test/brain_observatory/test_stimulus_info.py b/allensdk/test/brain_observatory/test_stimulus_info.py index 62fbdcdb11..3404dac9d9 100755 --- a/allensdk/test/brain_observatory/test_stimulus_info.py +++ b/allensdk/test/brain_observatory/test_stimulus_info.py @@ -3,252 +3,274 @@ import os from allensdk.core.brain_observatory_nwb_data_set import BrainObservatoryNwbDataSet, si from pathlib import Path + NWB_FLAVORS = [] -if 'TEST_NWB_FILES' in os.environ: - nwb_list_file = os.environ['TEST_NWB_FILES'] +if "TEST_NWB_FILES" in os.environ: + nwb_list_file = os.environ["TEST_NWB_FILES"] else: - nwb_list_file = str(Path(__file__).parent / '..' / 'core' / 'nwb_files.txt') + nwb_list_file = str(Path(__file__).parent / ".." / "core" / "nwb_files.txt") -if nwb_list_file == 'skip': +if nwb_list_file == "skip": NWB_FLAVORS = [] else: - with open(nwb_list_file, 'r') as f: + with open(nwb_list_file, "r") as f: NWB_FLAVORS = [l.strip() for l in f] + @pytest.fixture(params=NWB_FLAVORS) def data_set(request): data_set = BrainObservatoryNwbDataSet(request.param) return data_set + def test_BinaryIntervalSearchTree(): + bist = si.BinaryIntervalSearchTree([(0, 0.9, "A"), (1, 1.9, "B"), (3, 3.9, "D"), (2, 2.9, "C")]) + assert bist.search(1.5)[2] == "B" + assert bist.search(0)[2] == "A" + assert bist.search(2.5)[2] == "C" + assert bist.search(3.5)[2] == "D" - bist = si.BinaryIntervalSearchTree([(0, .9, 'A'), (1, 1.9, 'B'), (3, 3.9, 'D'), (2, 2.9, 'C')]) - assert bist.search(1.5)[2] == 'B' - assert bist.search(0)[2] == 'A' - assert bist.search(2.5)[2] == 'C' - assert bist.search(3.5)[2] == 'D' def test_BinaryIntervalSearchTree_shared_endpoint(): - - bist = si.BinaryIntervalSearchTree([(0, 1, 'A'), (1, 2, 'B')]) - assert bist.search(0)[2] == 'A' - assert bist.search(1)[2] == 'A' - assert bist.search(1.5)[2] == 'B' + bist = si.BinaryIntervalSearchTree([(0, 1, "A"), (1, 2, "B")]) + assert bist.search(0)[2] == "A" + assert bist.search(1)[2] == "A" + assert bist.search(1.5)[2] == "B" + def test_pixels_to_visual_degrees(): m = si.BrainObservatoryMonitor() - np.testing.assert_almost_equal(m.pixels_to_visual_degrees(1), 0.103270443661,10) + np.testing.assert_almost_equal(m.pixels_to_visual_degrees(1), 0.103270443661, 10) -@pytest.mark.skipif(not os.path.exists('/projects/neuralcoding'), - reason="test NWB file not available") -def test_StimulusSearch(data_set): +@pytest.mark.skipif(not os.path.exists("/projects/neuralcoding"), reason="test NWB file not available") +def test_StimulusSearch(data_set): epoch_df = data_set.get_stimulus_epoch_table() s = si.StimulusSearch(data_set) - assert len(s.search(epoch_df.iloc[2]['end'])) == 3 - assert s.search(epoch_df.iloc[2]['end'] + 1) is None + assert len(s.search(epoch_df.iloc[2]["end"])) == 3 + assert s.search(epoch_df.iloc[2]["end"] + 1) is None assert len(s.search(752)) == 3 def test_sessions_with_stimulus(): - for session_type, stimulus_type_list in si.SESSION_STIMULUS_MAP.items(): for stimulus_type in stimulus_type_list: assert session_type in si.sessions_with_stimulus(stimulus_type) -def test_stimuli_in_session(): +def test_stimuli_in_session(): for session_type, stimulus_type_list in si.SESSION_STIMULUS_MAP.items(): for stimulus_type in stimulus_type_list: assert session_type in si.sessions_with_stimulus(stimulus_type) -def test_stimuli_in_session(): - test_dict = {si.THREE_SESSION_A:4, - si.THREE_SESSION_B:4, - si.THREE_SESSION_C:4, - si.THREE_SESSION_C2:5} +def test_stimuli_in_session(): + test_dict = {si.THREE_SESSION_A: 4, si.THREE_SESSION_B: 4, si.THREE_SESSION_C: 4, si.THREE_SESSION_C2: 5} for session in si.SESSION_LIST: assert session in si.SESSION_STIMULUS_MAP assert len(si.stimuli_in_session(session)) == test_dict[session] assert len(si.SESSION_STIMULUS_MAP) == len(si.SESSION_LIST) == 4 + def test_all_stimuli(): assert len(si.all_stimuli()) == 10 + def test_rotate(): - np.testing.assert_array_almost_equal(np.array(si.rotate(1,1,np.pi)), np.array([-1,-1])) + np.testing.assert_array_almost_equal(np.array(si.rotate(1, 1, np.pi)), np.array([-1, -1])) -def test_get_spatial_grating(): +def test_get_spatial_grating(): data = si.get_spatial_grating(height=100, aspect_ratio=2, ori=45, pix_per_cycle=10, phase=0, p2p_amp=2, baseline=1) - assert data.shape == (100,200) - np.testing.assert_almost_equal(data[0,0], data[-1,-1]) + assert data.shape == (100, 200) + np.testing.assert_almost_equal(data[0, 0], data[-1, -1]) np.testing.assert_almost_equal(data.max(), 2, 3) np.testing.assert_almost_equal(data.min(), 0, 3) - np.testing.assert_almost_equal(data[50,100], 2) - -def test_get_spatio_temporal_grating(): + np.testing.assert_almost_equal(data[50, 100], 2) - for t, test_val in zip([0,.5,1], [2,0,2]): - data = si.get_spatio_temporal_grating(t, height=100, aspect_ratio=2, ori=45, pix_per_cycle=10, phase=0, p2p_amp=2, baseline=1, temporal_frequency=1) - np.testing.assert_almost_equal(data[50,100], test_val) - data = si.get_spatio_temporal_grating(0, height=100, - aspect_ratio=2, - ori=45, - pix_per_cycle=20, - phase=0, - p2p_amp=2, - baseline=1, - temporal_frequency=1) +def test_get_spatio_temporal_grating(): + for t, test_val in zip([0, 0.5, 1], [2, 0, 2]): + data = si.get_spatio_temporal_grating( + t, + height=100, + aspect_ratio=2, + ori=45, + pix_per_cycle=10, + phase=0, + p2p_amp=2, + baseline=1, + temporal_frequency=1, + ) + np.testing.assert_almost_equal(data[50, 100], test_val) + + data = si.get_spatio_temporal_grating( + 0, height=100, aspect_ratio=2, ori=45, pix_per_cycle=20, phase=0, p2p_amp=2, baseline=1, temporal_frequency=1 + ) x1 = data[50, 100] - data = si.get_spatio_temporal_grating(.5, height=100, - aspect_ratio=2, - ori=45, - pix_per_cycle=20, - phase=.5, - p2p_amp=2, - baseline=1, - temporal_frequency=1) + data = si.get_spatio_temporal_grating( + 0.5, + height=100, + aspect_ratio=2, + ori=45, + pix_per_cycle=20, + phase=0.5, + p2p_amp=2, + baseline=1, + temporal_frequency=1, + ) x2 = data[50, 100] np.testing.assert_almost_equal(x1, x2) -def test_map_template_monitor(): - - - np.testing.assert_almost_equal(np.array((500, 250)), - si.map_template_coordinate_to_monitor_coordinate((20, 20), (1000, 500), (40, 40))) +def test_map_template_monitor(): + np.testing.assert_almost_equal( + np.array((500, 250)), si.map_template_coordinate_to_monitor_coordinate((20, 20), (1000, 500), (40, 40)) + ) + np.testing.assert_almost_equal( + np.array((20, 20)), si.map_monitor_coordinate_to_template_coordinate((500, 250), (1000, 500), (40, 40)) + ) - np.testing.assert_almost_equal(np.array((20,20)), - si.map_monitor_coordinate_to_template_coordinate((500, 250), (1000, 500), (40,40))) def test_lsn_monitor(): - lsn4_template_coordinate = (8, 14) - lsn4_monitor_coordinate = np.array(si.MONITOR_DIMENSIONS)/2 #(600,960) - np.testing.assert_almost_equal(np.array(lsn4_monitor_coordinate), - si.map_stimulus_coordinate_to_monitor_coordinate(lsn4_template_coordinate, si.MONITOR_DIMENSIONS, si.LOCALLY_SPARSE_NOISE_4DEG)) + lsn4_monitor_coordinate = np.array(si.MONITOR_DIMENSIONS) / 2 # (600,960) + np.testing.assert_almost_equal( + np.array(lsn4_monitor_coordinate), + si.map_stimulus_coordinate_to_monitor_coordinate( + lsn4_template_coordinate, si.MONITOR_DIMENSIONS, si.LOCALLY_SPARSE_NOISE_4DEG + ), + ) lsn4_template_coordinate = (4, 7) - lsn4_monitor_coordinate = np.array(si.MONITOR_DIMENSIONS)/2#(600,960) - np.testing.assert_almost_equal(np.array(lsn4_monitor_coordinate), - si.map_stimulus_coordinate_to_monitor_coordinate(lsn4_template_coordinate, si.MONITOR_DIMENSIONS, si.LOCALLY_SPARSE_NOISE_8DEG)) - - lsn4_template_coordinate = (0,0) - lsn4_monitor_coordinate = (240,330) - np.testing.assert_almost_equal(np.array(lsn4_monitor_coordinate), - si.map_stimulus_coordinate_to_monitor_coordinate(lsn4_template_coordinate, - si.MONITOR_DIMENSIONS, - si.LOCALLY_SPARSE_NOISE_4DEG)) - - lsn4_template_coordinate = (0,0) - lsn4_monitor_coordinate = (240,330) - np.testing.assert_almost_equal(np.array(lsn4_template_coordinate), - si.monitor_coordinate_to_lsn_coordinate(lsn4_monitor_coordinate, - si.MONITOR_DIMENSIONS, - si.LOCALLY_SPARSE_NOISE_4DEG)) - - - lsn4_template_coordinate = (0,0) - lsn4_monitor_coordinate = (240,330) - np.testing.assert_almost_equal(np.array(lsn4_monitor_coordinate), - si.map_stimulus_coordinate_to_monitor_coordinate(lsn4_template_coordinate, - si.MONITOR_DIMENSIONS, - si.LOCALLY_SPARSE_NOISE_8DEG)) - - lsn4_template_coordinate = (0,0) - lsn4_monitor_coordinate = (240,330) - np.testing.assert_almost_equal(np.array(lsn4_template_coordinate), - si.monitor_coordinate_to_lsn_coordinate(lsn4_monitor_coordinate, - si.MONITOR_DIMENSIONS, - si.LOCALLY_SPARSE_NOISE_8DEG)) + lsn4_monitor_coordinate = np.array(si.MONITOR_DIMENSIONS) / 2 # (600,960) + np.testing.assert_almost_equal( + np.array(lsn4_monitor_coordinate), + si.map_stimulus_coordinate_to_monitor_coordinate( + lsn4_template_coordinate, si.MONITOR_DIMENSIONS, si.LOCALLY_SPARSE_NOISE_8DEG + ), + ) + + lsn4_template_coordinate = (0, 0) + lsn4_monitor_coordinate = (240, 330) + np.testing.assert_almost_equal( + np.array(lsn4_monitor_coordinate), + si.map_stimulus_coordinate_to_monitor_coordinate( + lsn4_template_coordinate, si.MONITOR_DIMENSIONS, si.LOCALLY_SPARSE_NOISE_4DEG + ), + ) + + lsn4_template_coordinate = (0, 0) + lsn4_monitor_coordinate = (240, 330) + np.testing.assert_almost_equal( + np.array(lsn4_template_coordinate), + si.monitor_coordinate_to_lsn_coordinate( + lsn4_monitor_coordinate, si.MONITOR_DIMENSIONS, si.LOCALLY_SPARSE_NOISE_4DEG + ), + ) + + lsn4_template_coordinate = (0, 0) + lsn4_monitor_coordinate = (240, 330) + np.testing.assert_almost_equal( + np.array(lsn4_monitor_coordinate), + si.map_stimulus_coordinate_to_monitor_coordinate( + lsn4_template_coordinate, si.MONITOR_DIMENSIONS, si.LOCALLY_SPARSE_NOISE_8DEG + ), + ) + + lsn4_template_coordinate = (0, 0) + lsn4_monitor_coordinate = (240, 330) + np.testing.assert_almost_equal( + np.array(lsn4_template_coordinate), + si.monitor_coordinate_to_lsn_coordinate( + lsn4_monitor_coordinate, si.MONITOR_DIMENSIONS, si.LOCALLY_SPARSE_NOISE_8DEG + ), + ) -def test_natural_scene_monitor(): - template_coordinate = (0,0) +def test_natural_scene_monitor(): + template_coordinate = (0, 0) monitor_coordinate = (141, 373) - np.testing.assert_almost_equal(np.array(monitor_coordinate), - si.natural_scene_coordinate_to_monitor_coordinate(template_coordinate, - si.MONITOR_DIMENSIONS)) + np.testing.assert_almost_equal( + np.array(monitor_coordinate), + si.natural_scene_coordinate_to_monitor_coordinate(template_coordinate, si.MONITOR_DIMENSIONS), + ) - template_coordinate = (0,0) + template_coordinate = (0, 0) monitor_coordinate = (141, 373) - np.testing.assert_almost_equal(np.array(template_coordinate), - si.map_monitor_coordinate_to_stimulus_coordinate(monitor_coordinate, - si.MONITOR_DIMENSIONS, - si.NATURAL_SCENES)) + np.testing.assert_almost_equal( + np.array(template_coordinate), + si.map_monitor_coordinate_to_stimulus_coordinate(monitor_coordinate, si.MONITOR_DIMENSIONS, si.NATURAL_SCENES), + ) -def test_natural_movie_monitor(): - template_coordinate = (0,0) +def test_natural_movie_monitor(): + template_coordinate = (0, 0) monitor_coordinate = (60, 0) - np.testing.assert_almost_equal(np.array(monitor_coordinate), - si.natural_movie_coordinate_to_monitor_coordinate(template_coordinate, - si.MONITOR_DIMENSIONS)) + np.testing.assert_almost_equal( + np.array(monitor_coordinate), + si.natural_movie_coordinate_to_monitor_coordinate(template_coordinate, si.MONITOR_DIMENSIONS), + ) - template_coordinate = (0,0) + template_coordinate = (0, 0) monitor_coordinate = (60, 0) - np.testing.assert_almost_equal(np.array(template_coordinate), - si.map_monitor_coordinate_to_stimulus_coordinate(monitor_coordinate, - si.MONITOR_DIMENSIONS, - si.NATURAL_MOVIE_ONE)) + np.testing.assert_almost_equal( + np.array(template_coordinate), + si.map_monitor_coordinate_to_stimulus_coordinate( + monitor_coordinate, si.MONITOR_DIMENSIONS, si.NATURAL_MOVIE_ONE + ), + ) -def test_bijective_all_stimuli(): +def test_bijective_all_stimuli(): for stimulus in si.all_stimuli(): + template_coordinate = (10, 10) + monitor_coordinate = si.map_stimulus_coordinate_to_monitor_coordinate( + template_coordinate, si.MONITOR_DIMENSIONS, stimulus + ) - template_coordinate = (10,10) - monitor_coordinate = si.map_stimulus_coordinate_to_monitor_coordinate(template_coordinate, - si.MONITOR_DIMENSIONS, - stimulus) - - new_template_coordinate = si.map_monitor_coordinate_to_stimulus_coordinate(monitor_coordinate, - si.MONITOR_DIMENSIONS, - stimulus) + new_template_coordinate = si.map_monitor_coordinate_to_stimulus_coordinate( + monitor_coordinate, si.MONITOR_DIMENSIONS, stimulus + ) np.testing.assert_array_almost_equal(template_coordinate, new_template_coordinate) - for original_loc in [(0,0), (10,10)]: - + for original_loc in [(0, 0), (10, 10)]: for target_stimulus in si.all_stimuli(): - new_loc = si.map_stimulus(original_loc, stimulus, target_stimulus, si.MONITOR_DIMENSIONS) new_original_loc = si.map_stimulus(new_loc, target_stimulus, stimulus, si.MONITOR_DIMENSIONS) np.testing.assert_array_almost_equal(new_original_loc, original_loc) -def test_monitor_basic_spatial_unit(): - m = si.Monitor(300,400, 5, 'cm') - m.set_spatial_unit('cm') +def test_monitor_basic_spatial_unit(): + m = si.Monitor(300, 400, 5, "cm") + m.set_spatial_unit("cm") - m.set_spatial_unit('inch') + m.set_spatial_unit("inch") np.testing.assert_almost_equal(m.panel_size, 1.968505, 5) - np.testing.assert_almost_equal(1./m.aspect_ratio, 3./4) + np.testing.assert_almost_equal(1.0 / m.aspect_ratio, 3.0 / 4) np.testing.assert_almost_equal(m.height, 0.46500143220300011) np.testing.assert_almost_equal(m.width, 0.62000190960400015) np.testing.assert_almost_equal(m.pixel_size, 0.0015500047740100004) - m.set_spatial_unit('cm') + m.set_spatial_unit("cm") np.testing.assert_almost_equal(m.panel_size, 5) - np.testing.assert_almost_equal(1./m.aspect_ratio, 3./4) + np.testing.assert_almost_equal(1.0 / m.aspect_ratio, 3.0 / 4) np.testing.assert_almost_equal(m.height, 3) np.testing.assert_almost_equal(m.width, 4) - np.testing.assert_almost_equal(m.pixel_size, .01) + np.testing.assert_almost_equal(m.pixel_size, 0.01) def test_pixels_to_visual_degrees(): - m = si.BrainObservatoryMonitor() np.testing.assert_almost_equal(m.pixels_to_visual_degrees(45), 4.64716996476) @@ -257,46 +279,41 @@ def test_pixels_to_visual_degrees(): np.testing.assert_almost_equal(m.pixels_to_visual_degrees(1), 0.103270443661) np.testing.assert_almost_equal(m.pixels_to_visual_degrees(1, small_angle_approximation=False), 0.103270415704) -@pytest.mark.skipif(not os.path.exists('/projects/neuralcoding'), - reason="test NWB file not available") -def test_lsn_image_to_screen(data_set): +@pytest.mark.skipif(not os.path.exists("/projects/neuralcoding"), reason="test NWB file not available") +def test_lsn_image_to_screen(data_set): compare_set = set(data_set.list_stimuli()).intersection(si.LOCALLY_SPARSE_NOISE_STIMULUS_TYPES) if len(compare_set) > 0: for stimulus_type in compare_set: - template = data_set.get_stimulus_template(stimulus_type) m = si.BrainObservatoryMonitor() - m.lsn_image_to_screen(template[0,:,:]).shape == si.MONITOR_DIMENSIONS + m.lsn_image_to_screen(template[0, :, :]).shape == si.MONITOR_DIMENSIONS -@pytest.mark.skipif(not os.path.exists('/projects/neuralcoding'), - reason="test NWB file not available") -def test_natural_movie_image_to_screen(data_set): +@pytest.mark.skipif(not os.path.exists("/projects/neuralcoding"), reason="test NWB file not available") +def test_natural_movie_image_to_screen(data_set): compare_set = set(data_set.list_stimuli()).intersection(si.NATURAL_MOVIE_STIMULUS_TYPES) if len(compare_set) > 0: for stimulus_type in compare_set: - template = data_set.get_stimulus_template(stimulus_type) m = si.BrainObservatoryMonitor() m.natural_movie_image_to_screen(template[0, :, :]).shape == si.MONITOR_DIMENSIONS -@pytest.mark.skipif(not os.path.exists('/projects/neuralcoding'), - reason="test NWB file not available") -def test_grating_to_screen(data_set): +@pytest.mark.skipif(not os.path.exists("/projects/neuralcoding"), reason="test NWB file not available") +def test_grating_to_screen(data_set): compare_set = set(data_set.list_stimuli()).intersection([si.STATIC_GRATINGS, si.DRIFTING_GRATINGS]) if len(compare_set) > 0: - for stimulus_type in compare_set: m = si.BrainObservatoryMonitor() curr_row = data_set.get_stimulus_table(stimulus_type).iloc[10] phase = 0 - spatial_frequency = .04 + spatial_frequency = 0.04 orientation = curr_row.orientation - template = m.grating_to_screen(phase, spatial_frequency, orientation) + template = m.grating_to_screen(phase, spatial_frequency, orientation) assert m.natural_movie_image_to_screen(template).shape == si.MONITOR_DIMENSIONS + def test_get_mask(): m = si.BrainObservatoryMonitor() mask = m.get_mask() @@ -308,15 +325,15 @@ def test_get_mask(): def test_mask(): m = si.BrainObservatoryMonitor() - assert(m._mask is None) + assert m._mask is None - assert(m.mask.sum() == 931286) - assert(m.mask.shape == si.MONITOR_DIMENSIONS) - assert(m._mask is not None) + assert m.mask.sum() == 931286 + assert m.mask.shape == si.MONITOR_DIMENSIONS + assert m._mask is not None def test_translate_image_and_fill(): - ''' + """ [[1 2 3] [4 5 6] [7 8 9]] @@ -324,52 +341,155 @@ def test_translate_image_and_fill(): [[127 4 5] [127 7 8] [127 127 127]] - ''' - + """ - X = np.array([[1,2,3],[4,5,6],[7,8,9]]) + X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) X_test = np.array([[127, 4, 5], [127, 7, 8], [127, 127, 127]]) - X_result = si.translate_image_and_fill(X, translation=(1,1)) + X_result = si.translate_image_and_fill(X, translation=(1, 1)) np.testing.assert_array_almost_equal(X_result, X_test) + def test_visual_degrees_to_pixels(): - m = si.BrainObservatoryMonitor() np.testing.assert_approx_equal(m.visual_degrees_to_pixels(4.5), 43.5749072092) -def test_spatial_frequency_to_pix_per_cycle(): +def test_spatial_frequency_to_pix_per_cycle(): m = si.BrainObservatoryMonitor() - x1 = m.spatial_frequency_to_pix_per_cycle(.1, 15.0) - x2 = m.spatial_frequency_to_pix_per_cycle(.05, 15.0) + x1 = m.spatial_frequency_to_pix_per_cycle(0.1, 15.0) + x2 = m.spatial_frequency_to_pix_per_cycle(0.05, 15.0) np.testing.assert_almost_equal(x1, 97.7072500845) - np.testing.assert_almost_equal(x2/x1, 2) + np.testing.assert_almost_equal(x2 / x1, 2) -def test_show_image(): +def test_show_image(): m = si.BrainObservatoryMonitor() img = np.zeros(si.MONITOR_DIMENSIONS) m.show_image(img, show=False, warp=True, mask=False) m.show_image(img, show=False, warp=False, mask=True) -def test_map_stimulus(): +def test_map_stimulus(): m = si.BrainObservatoryMonitor() - test_list = [(0, 0), (-5.333333333333333, -7.333333333333333), (-5.333333333333333, -7.333333333333333), (-2.6666666666666665, -3.6666666666666665), (-16.88888888888889, 0.0), (-16.88888888888889, 0.0), (-16.88888888888889, 0.0), (-141.0, -373.0), (0, 0), (0, 0), (240.0, 330.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0), (50.666666666666664, 104.5), (50.666666666666664, 104.5), (50.666666666666664, 104.5), (99.0, -43.0), (240.0, 330.0), (240.0, 330.0), (240.0, 330.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0), (50.666666666666664, 104.5), (50.666666666666664, 104.5), (50.666666666666664, 104.5), (99.0, -43.0), (240.0, 330.0), (240.0, 330.0), (240.0, 330.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0), (50.666666666666664, 104.5), (50.666666666666664, 104.5), (50.666666666666664, 104.5), (99.0, -43.0), (240.0, 330.0), (240.0, 330.0), (60.0, 0.0), (-4.0, -7.333333333333333), (-4.0, -7.333333333333333), (-2.0, -3.6666666666666665), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0), (-81.0, -373.0), (60.0, 0.0), (60.0, 0.0), (60.0, 0.0), (-4.0, -7.333333333333333), (-4.0, -7.333333333333333), (-2.0, -3.6666666666666665), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0), (-81.0, -373.0), (60.0, 0.0), (60.0, 0.0), (60.0, 0.0), (-4.0, -7.333333333333333), (-4.0, -7.333333333333333), (-2.0, -3.6666666666666665), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0), (-81.0, -373.0), (60.0, 0.0), (60.0, 0.0), (141.0, 373.0), (-2.2, 0.9555555555555556), (-2.2, 0.9555555555555556), (-1.1, 0.4777777777777778), (22.8, 118.11666666666666), (22.8, 118.11666666666666), (22.8, 118.11666666666666), (0.0, 0.0), (141.0, 373.0), (141.0, 373.0), (0, 0), (-5.333333333333333, -7.333333333333333), (-5.333333333333333, -7.333333333333333), (-2.6666666666666665, -3.6666666666666665), (-16.88888888888889, 0.0), (-16.88888888888889, 0.0), (-16.88888888888889, 0.0), (-141.0, -373.0), (0, 0), (0, 0), (0, 0), (-5.333333333333333, -7.333333333333333), (-5.333333333333333, -7.333333333333333), (-2.6666666666666665, -3.6666666666666665), (-16.88888888888889, 0.0), (-16.88888888888889, 0.0), (-16.88888888888889, 0.0), (-141.0, -373.0), (0, 0), (0, 0)] + test_list = [ + (0, 0), + (-5.333333333333333, -7.333333333333333), + (-5.333333333333333, -7.333333333333333), + (-2.6666666666666665, -3.6666666666666665), + (-16.88888888888889, 0.0), + (-16.88888888888889, 0.0), + (-16.88888888888889, 0.0), + (-141.0, -373.0), + (0, 0), + (0, 0), + (240.0, 330.0), + (0.0, 0.0), + (0.0, 0.0), + (0.0, 0.0), + (50.666666666666664, 104.5), + (50.666666666666664, 104.5), + (50.666666666666664, 104.5), + (99.0, -43.0), + (240.0, 330.0), + (240.0, 330.0), + (240.0, 330.0), + (0.0, 0.0), + (0.0, 0.0), + (0.0, 0.0), + (50.666666666666664, 104.5), + (50.666666666666664, 104.5), + (50.666666666666664, 104.5), + (99.0, -43.0), + (240.0, 330.0), + (240.0, 330.0), + (240.0, 330.0), + (0.0, 0.0), + (0.0, 0.0), + (0.0, 0.0), + (50.666666666666664, 104.5), + (50.666666666666664, 104.5), + (50.666666666666664, 104.5), + (99.0, -43.0), + (240.0, 330.0), + (240.0, 330.0), + (60.0, 0.0), + (-4.0, -7.333333333333333), + (-4.0, -7.333333333333333), + (-2.0, -3.6666666666666665), + (0.0, 0.0), + (0.0, 0.0), + (0.0, 0.0), + (-81.0, -373.0), + (60.0, 0.0), + (60.0, 0.0), + (60.0, 0.0), + (-4.0, -7.333333333333333), + (-4.0, -7.333333333333333), + (-2.0, -3.6666666666666665), + (0.0, 0.0), + (0.0, 0.0), + (0.0, 0.0), + (-81.0, -373.0), + (60.0, 0.0), + (60.0, 0.0), + (60.0, 0.0), + (-4.0, -7.333333333333333), + (-4.0, -7.333333333333333), + (-2.0, -3.6666666666666665), + (0.0, 0.0), + (0.0, 0.0), + (0.0, 0.0), + (-81.0, -373.0), + (60.0, 0.0), + (60.0, 0.0), + (141.0, 373.0), + (-2.2, 0.9555555555555556), + (-2.2, 0.9555555555555556), + (-1.1, 0.4777777777777778), + (22.8, 118.11666666666666), + (22.8, 118.11666666666666), + (22.8, 118.11666666666666), + (0.0, 0.0), + (141.0, 373.0), + (141.0, 373.0), + (0, 0), + (-5.333333333333333, -7.333333333333333), + (-5.333333333333333, -7.333333333333333), + (-2.6666666666666665, -3.6666666666666665), + (-16.88888888888889, 0.0), + (-16.88888888888889, 0.0), + (-16.88888888888889, 0.0), + (-141.0, -373.0), + (0, 0), + (0, 0), + (0, 0), + (-5.333333333333333, -7.333333333333333), + (-5.333333333333333, -7.333333333333333), + (-2.6666666666666665, -3.6666666666666665), + (-16.88888888888889, 0.0), + (-16.88888888888889, 0.0), + (-16.88888888888889, 0.0), + (-141.0, -373.0), + (0, 0), + (0, 0), + ] counter = 0 for source_stimulus in sorted(si.all_stimuli()): for target_stimulus in sorted(si.all_stimuli()): - tmp = m.map_stimulus((0,0), source_stimulus, target_stimulus) + tmp = m.map_stimulus((0, 0), source_stimulus, target_stimulus) np.testing.assert_array_almost_equal(tmp, test_list[counter]) counter += 1 - np.testing.assert_array_almost_equal(m.map_stimulus(tmp, target_stimulus, source_stimulus), np.array([0,0])) + np.testing.assert_array_almost_equal( + m.map_stimulus(tmp, target_stimulus, source_stimulus), np.array([0, 0]) + ) + if __name__ == "__main__": -# + # # with open(nwb_list_file, 'r') as f: # NWB_FLAVORS = [l.strip() for l in f] # @@ -395,4 +515,4 @@ def test_map_stimulus(): # test_spatial_frequency_to_pix_per_cycle() # test_get_mask() # test_show_image() - test_map_stimulus() \ No newline at end of file + test_map_stimulus() diff --git a/allensdk/test/brain_observatory/vbn_2022/input_json_writer/test_json_writer_cli.py b/allensdk/test/brain_observatory/vbn_2022/input_json_writer/test_json_writer_cli.py index 5af5a3daf7..9e4d1d81c1 100644 --- a/allensdk/test/brain_observatory/vbn_2022/input_json_writer/test_json_writer_cli.py +++ b/allensdk/test/brain_observatory/vbn_2022/input_json_writer/test_json_writer_cli.py @@ -2,28 +2,24 @@ import json import pathlib -from allensdk.brain_observatory.vbn_2022.input_json_writer \ - .input_json_writer import VBN2022InputJsonWriter +from allensdk.brain_observatory.vbn_2022.input_json_writer.input_json_writer import VBN2022InputJsonWriter -from allensdk.brain_observatory.ecephys.write_nwb.vbn._schemas import ( - VBNInputSchema) +from allensdk.brain_observatory.ecephys.write_nwb.vbn._schemas import VBNInputSchema @pytest.mark.requires_bamboo -def test_writer_cli( - tmp_path_factory, - helper_functions): +def test_writer_cli(tmp_path_factory, helper_functions): """ This will just be a smoke test to write out the input json and validate it against the NWB writer schema """ - base_dir = pathlib.Path(tmp_path_factory.mktemp('input_json_cli')) + base_dir = pathlib.Path(tmp_path_factory.mktemp("input_json_cli")) - json_dir = base_dir / 'input_jsons' + json_dir = base_dir / "input_jsons" json_dir.mkdir() - nwb_dir = base_dir / 'nwb_files' + nwb_dir = base_dir / "nwb_files" nwb_dir.mkdir() json_prefix = "test_input_jsons" @@ -37,11 +33,8 @@ def test_writer_cli( probes_to_skip = [] for session_id in session_list: - for suffix in 'BDF': - probes_to_skip.append({ - "session": session_id, - "probe": f"probe{suffix}" - }) + for suffix in "BDF": + probes_to_skip.append({"session": session_id, "probe": f"probe{suffix}"}) json_generation_data = { "log_level": "INFO", @@ -51,10 +44,10 @@ def test_writer_cli( "nwb_output_dir": str(nwb_dir.resolve().absolute()), "json_prefix": json_prefix, "nwb_prefix": nwb_prefix, - "probes_to_skip": probes_to_skip} + "probes_to_skip": probes_to_skip, + } - writer = VBN2022InputJsonWriter(args=[], - input_data=json_generation_data) + writer = VBN2022InputJsonWriter(args=[], input_data=json_generation_data) writer.run() ct_valid = 0 @@ -63,18 +56,17 @@ def test_writer_cli( if session_id == 9: assert not expected_path.exists() else: - ct_valid += 1 assert expected_path.is_file() # read in the written file and verify that it # passes schema validation - with open(expected_path, 'rb') as in_file: + with open(expected_path, "rb") as in_file: json_data = json.load(in_file) - session_data = json_data['session_data'] - assert len(session_data['probes']) > 0 - assert len(session_data['probes'][0]['channels']) > 0 - assert len(session_data['probes'][0]['units']) > 0 + session_data = json_data["session_data"] + assert len(session_data["probes"]) > 0 + assert len(session_data["probes"][0]["channels"]) > 0 + assert len(session_data["probes"][0]["units"]) > 0 schema = VBNInputSchema() assert len(schema.validate(data=json_data)) == 0 @@ -82,5 +74,4 @@ def test_writer_cli( # make sure we actually tested some valid files assert ct_valid > 0 - helper_functions.windows_safe_cleanup_dir( - dir_path=base_dir) + helper_functions.windows_safe_cleanup_dir(dir_path=base_dir) diff --git a/allensdk/test/brain_observatory/vbn_2022/metadata_writer/conftest.py b/allensdk/test/brain_observatory/vbn_2022/metadata_writer/conftest.py index 9edd7af1fa..0b26ab4be2 100644 --- a/allensdk/test/brain_observatory/vbn_2022/metadata_writer/conftest.py +++ b/allensdk/test/brain_observatory/vbn_2022/metadata_writer/conftest.py @@ -4,9 +4,9 @@ import pathlib import tempfile -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .behavior_metadata.date_of_acquisition import \ - DateOfAcquisition +from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.date_of_acquisition import ( + DateOfAcquisition, +) @pytest.fixture @@ -15,8 +15,8 @@ def smoketest_config_fixture(): config parameters for on-prem metadata writer smoketest """ config = { - "ecephys_session_id_list": [1115077618, 1081429294], - "probes_to_skip": [{"session": 1115077618, "probe": "probeC"}] + "ecephys_session_id_list": [1115077618, 1081429294], + "probes_to_skip": [{"session": 1115077618, "probe": "probeC"}], } return config @@ -27,17 +27,15 @@ def smoketest_with_failed_sessions_config_fixture(): config parameters for on-prem metadata writer smoketest """ config = { - "ecephys_session_id_list": [1051155866], - "failed_ecephys_session_id_list": [1050962145], - "probes_to_skip": [{"session": 1115077618, "probe": "probeC"}] + "ecephys_session_id_list": [1051155866], + "failed_ecephys_session_id_list": [1050962145], + "probes_to_skip": [{"session": 1115077618, "probe": "probeC"}], } return config @pytest.fixture -def patching_pickle_file_fixture( - helper_functions, - tmp_path_factory): +def patching_pickle_file_fixture(helper_functions, tmp_path_factory): """ Write mock data to some stimulus pickles. Return a dict mapping behavior_session_id to @@ -45,52 +43,30 @@ def patching_pickle_file_fixture( The session_type stored in the pickle The date_of_acquisition stored in the pickle """ - tmp_dir = tmp_path_factory.mktemp('patching_pickles') + tmp_dir = tmp_path_factory.mktemp("patching_pickles") output = dict() - this_date = DateOfAcquisition( - datetime.datetime(year=2020, month=6, day=7)).value - this_stage = 'first_stage' - pkl_data = {'start_time': this_date, - 'items': - {'behavior': - {'params': {'stage': this_stage}}}} - pkl_path = pathlib.Path( - tempfile.mkstemp(dir=tmp_dir, suffix='.pkl')[1]) + this_date = DateOfAcquisition(datetime.datetime(year=2020, month=6, day=7)).value + this_stage = "first_stage" + pkl_data = {"start_time": this_date, "items": {"behavior": {"params": {"stage": this_stage}}}} + pkl_path = pathlib.Path(tempfile.mkstemp(dir=tmp_dir, suffix=".pkl")[1]) pd.to_pickle(pkl_data, pkl_path) - output[1123] = {'pkl_path': pkl_path, - 'date_of_acquisition': this_date, - 'session_type': this_stage} + output[1123] = {"pkl_path": pkl_path, "date_of_acquisition": this_date, "session_type": this_stage} - this_date = DateOfAcquisition( - datetime.datetime(year=1998, month=3, day=14)).value - this_stage = 'second_stage' - pkl_data = {'start_time': this_date, - 'items': - {'behavior': - {'params': {'stage': this_stage}}}} - pkl_path = pathlib.Path( - tempfile.mkstemp(dir=tmp_dir, suffix='.pkl')[1]) + this_date = DateOfAcquisition(datetime.datetime(year=1998, month=3, day=14)).value + this_stage = "second_stage" + pkl_data = {"start_time": this_date, "items": {"behavior": {"params": {"stage": this_stage}}}} + pkl_path = pathlib.Path(tempfile.mkstemp(dir=tmp_dir, suffix=".pkl")[1]) pd.to_pickle(pkl_data, pkl_path) - output[5813] = {'pkl_path': pkl_path, - 'date_of_acquisition': this_date, - 'session_type': this_stage} + output[5813] = {"pkl_path": pkl_path, "date_of_acquisition": this_date, "session_type": this_stage} - this_date = DateOfAcquisition( - datetime.datetime(year=1974, month=7, day=22)).value - this_stage = 'third_stage' - pkl_data = {'start_time': this_date, - 'items': - {'behavior': - {'params': {'stage': this_stage}}}} - pkl_path = pathlib.Path( - tempfile.mkstemp(dir=tmp_dir, suffix='.pkl')[1]) + this_date = DateOfAcquisition(datetime.datetime(year=1974, month=7, day=22)).value + this_stage = "third_stage" + pkl_data = {"start_time": this_date, "items": {"behavior": {"params": {"stage": this_stage}}}} + pkl_path = pathlib.Path(tempfile.mkstemp(dir=tmp_dir, suffix=".pkl")[1]) pd.to_pickle(pkl_data, pkl_path) - output[2134] = {'pkl_path': pkl_path, - 'date_of_acquisition': this_date, - 'session_type': this_stage} + output[2134] = {"pkl_path": pkl_path, "date_of_acquisition": this_date, "session_type": this_stage} yield output - helper_functions.windows_safe_cleanup_dir( - dir_path=pathlib.Path(tmp_dir)) + helper_functions.windows_safe_cleanup_dir(dir_path=pathlib.Path(tmp_dir)) diff --git a/allensdk/test/brain_observatory/vbn_2022/metadata_writer/test_cli.py b/allensdk/test/brain_observatory/vbn_2022/metadata_writer/test_cli.py index 362271e99c..a76b249d8c 100644 --- a/allensdk/test/brain_observatory/vbn_2022/metadata_writer/test_cli.py +++ b/allensdk/test/brain_observatory/vbn_2022/metadata_writer/test_cli.py @@ -5,20 +5,14 @@ import pandas as pd import tempfile -from allensdk.brain_observatory.vbn_2022.metadata_writer \ - .metadata_writer import VBN2022MetadataWriterClass +from allensdk.brain_observatory.vbn_2022.metadata_writer.metadata_writer import VBN2022MetadataWriterClass @pytest.mark.requires_bamboo -@pytest.mark.parametrize( - 'on_missing_file, with_supplement', - [('skip', False), ('warn', False), ('warn', True)]) +@pytest.mark.parametrize("on_missing_file, with_supplement", [("skip", False), ("warn", False), ("warn", True)]) def test_metadata_writer_smoketest( - smoketest_config_fixture, - tmp_path_factory, - helper_functions, - on_missing_file, - with_supplement): + smoketest_config_fixture, tmp_path_factory, helper_functions, on_missing_file, with_supplement +): """ smoke test for VBN 2022 metadata writer. Requires LIMS and mtrain connections. @@ -27,31 +21,24 @@ def test_metadata_writer_smoketest( to the ecephys_sessions.csv file and test for their existence """ - output_names = ('units.csv', 'probes.csv', 'channels.csv', - 'ecephys_sessions.csv', 'behavior_sessions.csv') + output_names = ("units.csv", "probes.csv", "channels.csv", "ecephys_sessions.csv", "behavior_sessions.csv") config = copy.deepcopy(smoketest_config_fixture) - output_dir = tmp_path_factory.mktemp('vbn_metadata_smoketest') + output_dir = tmp_path_factory.mktemp("vbn_metadata_smoketest") output_dir = pathlib.Path(output_dir) - output_json_path = pathlib.Path( - tempfile.mkstemp(dir=output_dir, - suffix='.json')[1]) - config['output_dir'] = str(output_dir.resolve().absolute()) + output_json_path = pathlib.Path(tempfile.mkstemp(dir=output_dir, suffix=".json")[1]) + config["output_dir"] = str(output_dir.resolve().absolute()) if with_supplement: # add suplemental columns to configuration - supplement = [{'ecephys_session_id': 1115077618, - 'supplementA': 'cat', - 'supplementB': 5}, - {'ecephys_session_id': 1081429294, - 'supplementA': 'frog', - 'supplementB': 6}, - {'ecephys_session_id': 11111, - 'supplementA': None, - 'supplementB': None}] - - config['supplemental_data'] = supplement + supplement = [ + {"ecephys_session_id": 1115077618, "supplementA": "cat", "supplementB": 5}, + {"ecephys_session_id": 1081429294, "supplementA": "frog", "supplementB": 6}, + {"ecephys_session_id": 11111, "supplementA": None, "supplementB": None}, + ] + + config["supplemental_data"] = supplement expected_paths = [] for name in output_names: @@ -59,14 +46,14 @@ def test_metadata_writer_smoketest( assert not file_path.exists() expected_paths.append(file_path) - this_dir = pathlib.Path('.') + this_dir = pathlib.Path(".") this_dir = str(this_dir.resolve().absolute()) - config['ecephys_nwb_prefix'] = 'not_there' - config['ecephys_nwb_dir'] = this_dir - config['clobber'] = False - config['on_missing_file'] = on_missing_file - config['output_json'] = str(output_json_path.resolve().absolute()) + config["ecephys_nwb_prefix"] = "not_there" + config["ecephys_nwb_dir"] = this_dir + config["clobber"] = False + config["on_missing_file"] = on_missing_file + config["output_json"] = str(output_json_path.resolve().absolute()) writer = VBN2022MetadataWriterClass(args=[], input_data=config) writer.run() @@ -74,83 +61,71 @@ def test_metadata_writer_smoketest( # load a dict mapping the name of a metadata.csv # to the list of columns it is supposed to contain this_dir = pathlib.Path(__file__).parent - resource_dir = this_dir / 'resources' - with open(resource_dir / 'column_lookup.json', 'rb') as in_file: + resource_dir = this_dir / "resources" + with open(resource_dir / "column_lookup.json", "rb") as in_file: column_lookup = json.load(in_file) for file_path in expected_paths: assert file_path.exists() df = pd.read_csv(file_path) expected_columns = set(column_lookup[file_path.name]) - if with_supplement and file_path.name == 'ecephys_sessions.csv': - expected_columns.add('supplementA') - expected_columns.add('supplementB') + if with_supplement and file_path.name == "ecephys_sessions.csv": + expected_columns.add("supplementA") + expected_columns.add("supplementB") actual_columns = set(df.columns) assert expected_columns == actual_columns if with_supplement: - df = pd.read_csv(output_dir / 'ecephys_sessions.csv') + df = pd.read_csv(output_dir / "ecephys_sessions.csv") # make sure that no extra rows were added when adding # the supplemental data assert len(df) == 2 for expected in supplement[:2]: - this_row = df.loc[ - df.ecephys_session_id == expected['ecephys_session_id']] + this_row = df.loc[df.ecephys_session_id == expected["ecephys_session_id"]] assert len(this_row) == 1 - assert this_row.supplementA.values[0] == expected['supplementA'] - assert this_row.supplementB.values[0] == expected['supplementB'] + assert this_row.supplementA.values[0] == expected["supplementA"] + assert this_row.supplementB.values[0] == expected["supplementB"] - helper_functions.windows_safe_cleanup_dir( - dir_path=output_dir) + helper_functions.windows_safe_cleanup_dir(dir_path=output_dir) @pytest.mark.requires_bamboo -def test_with_failed_sessions( - smoketest_with_failed_sessions_config_fixture, - tmp_path_factory, - helper_functions): +def test_with_failed_sessions(smoketest_with_failed_sessions_config_fixture, tmp_path_factory, helper_functions): """ Test that metadata writer CLI can handle failed_ecephys_session_id_list """ - output_dir = pathlib.Path(tmp_path_factory.mktemp('failed_session_test')) - output_json_path = output_dir / 'output.json' + output_dir = pathlib.Path(tmp_path_factory.mktemp("failed_session_test")) + output_json_path = output_dir / "output.json" config = copy.deepcopy(smoketest_with_failed_sessions_config_fixture) - config['clobber'] = False - config['output_dir'] = str(output_dir.resolve().absolute()) - config['on_missing_file'] = 'warn' - config['ecephys_nwb_prefix'] = 'not_here' - config['ecephys_nwb_dir'] = str(output_dir.resolve().absolute()) - config['output_json'] = str(output_json_path.resolve().absolute()) + config["clobber"] = False + config["output_dir"] = str(output_dir.resolve().absolute()) + config["on_missing_file"] = "warn" + config["ecephys_nwb_prefix"] = "not_here" + config["ecephys_nwb_dir"] = str(output_dir.resolve().absolute()) + config["output_json"] = str(output_json_path.resolve().absolute()) writer = VBN2022MetadataWriterClass(args=[], input_data=config) writer.run() - for fname in ('behavior_sessions.csv', - 'ecephys_sessions.csv', - 'channels.csv', - 'units.csv', - 'probes.csv'): + for fname in ("behavior_sessions.csv", "ecephys_sessions.csv", "channels.csv", "units.csv", "probes.csv"): file_path = output_dir / fname assert file_path.is_file() - ecephys_sessions_df = pd.read_csv(output_dir / 'ecephys_sessions.csv') + ecephys_sessions_df = pd.read_csv(output_dir / "ecephys_sessions.csv") assert len(ecephys_sessions_df) == 1 - for bad_session_id in config['failed_ecephys_session_id_list']: - assert (bad_session_id not in - ecephys_sessions_df.ecephys_session_id.values) + for bad_session_id in config["failed_ecephys_session_id_list"]: + assert bad_session_id not in ecephys_sessions_df.ecephys_session_id.values - for good_session_id in config['ecephys_session_id_list']: - assert (good_session_id in - ecephys_sessions_df.ecephys_session_id.values) + for good_session_id in config["ecephys_session_id_list"]: + assert good_session_id in ecephys_sessions_df.ecephys_session_id.values # make sure this session was recorded with session_number = 2 assert ecephys_sessions_df.session_number.values[0] == 2 - helper_functions.windows_safe_cleanup_dir( - dir_path=output_dir) + helper_functions.windows_safe_cleanup_dir(dir_path=output_dir) diff --git a/allensdk/test/brain_observatory/vbn_2022/metadata_writer/test_dataframe_manipulations.py b/allensdk/test/brain_observatory/vbn_2022/metadata_writer/test_dataframe_manipulations.py index 0d59bfffbc..3b0cac6989 100644 --- a/allensdk/test/brain_observatory/vbn_2022/metadata_writer/test_dataframe_manipulations.py +++ b/allensdk/test/brain_observatory/vbn_2022/metadata_writer/test_dataframe_manipulations.py @@ -76,9 +76,7 @@ def test_add_images_from_behavior(): ecephys_data.append({"ecephys_session_id": 3, "figure": 8}) ecephys_table = pd.DataFrame(data=ecephys_data) - ecephys_table = _add_images_from_behavior( - ecephys_table=ecephys_table, behavior_table=beh_table - ) + ecephys_table = _add_images_from_behavior(ecephys_table=ecephys_table, behavior_table=beh_table) expected_data = [] expected_data.append( @@ -184,17 +182,11 @@ def test_add_experience_level(): """ input_data = [] - input_data.append({"session": 1, - "prior_exposures_to_image_set": 2, - 'session_type': 'ECEPHYS_1'}) + input_data.append({"session": 1, "prior_exposures_to_image_set": 2, "session_type": "ECEPHYS_1"}) - input_data.append({"session": 2, - "prior_exposures_to_image_set": None, - 'session_type': 'ECEPHYS_2'}) + input_data.append({"session": 2, "prior_exposures_to_image_set": None, "session_type": "ECEPHYS_2"}) - input_data.append({"session": 1, - "prior_exposures_to_image_set": 0, - 'session_type': 'ECEPHYS_1'}) + input_data.append({"session": 1, "prior_exposures_to_image_set": 0, "session_type": "ECEPHYS_1"}) input_df = pd.DataFrame(data=input_data) actual = add_experience_level_simple(input_df=input_df) @@ -287,9 +279,7 @@ def test_add_experience_level(): ), ], ) -def test_patch_date_and_stage_from_pickle_file( - patching_pickle_file_fixture, flag_columns, ids_to_fix, cols_to_fix -): +def test_patch_date_and_stage_from_pickle_file(patching_pickle_file_fixture, flag_columns, ids_to_fix, cols_to_fix): """ Test that _patch_date_and_stage_from_pickle_file correctly patches sessions that are missing @@ -578,11 +568,7 @@ def test_remove_aborted_sessions(self): behavior_session_id=BehaviorSessionId(x), equipment=None, stimulus_frame_rate=None, - session_type=( - SessionType( - self.behavior_sessions_df.loc[x]["session_type"] - ) - ), + session_type=(SessionType(self.behavior_sessions_df.loc[x]["session_type"])), behavior_session_uuid=None, session_duration=self.session_durations.loc[x], ) @@ -595,6 +581,4 @@ def test_remove_aborted_sessions(self): remove_sessions_after_mouse_death_date=False, ) expected = [behavior_sessions[1], behavior_sessions[3]] - assert sorted([x.behavior_session_id for x in actual]) == sorted( - [x.behavior_session_id for x in expected] - ) + assert sorted([x.behavior_session_id for x in actual]) == sorted([x.behavior_session_id for x in expected]) diff --git a/allensdk/test/brain_observatory/vbn_2022/metadata_writer/test_vbn_2022_metadata_writer_lims_queries.py b/allensdk/test/brain_observatory/vbn_2022/metadata_writer/test_vbn_2022_metadata_writer_lims_queries.py index 786ee50245..05bccc8654 100644 --- a/allensdk/test/brain_observatory/vbn_2022/metadata_writer/test_vbn_2022_metadata_writer_lims_queries.py +++ b/allensdk/test/brain_observatory/vbn_2022/metadata_writer/test_vbn_2022_metadata_writer_lims_queries.py @@ -3,22 +3,17 @@ import datetime from allensdk.brain_observatory.behavior.data_objects import BehaviorSessionId -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .behavior_metadata.behavior_metadata import \ - BehaviorMetadata -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .behavior_metadata.date_of_acquisition import \ - DateOfAcquisition -from allensdk.brain_observatory.behavior.data_objects.metadata\ - .subject_metadata.subject_metadata import \ - SubjectMetadata -from allensdk.brain_observatory.vbn_2022.metadata_writer.lims_queries import ( - _merge_ecephys_id_and_failed) +from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.behavior_metadata import ( + BehaviorMetadata, +) +from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.date_of_acquisition import ( + DateOfAcquisition, +) +from allensdk.brain_observatory.behavior.data_objects.metadata.subject_metadata.subject_metadata import SubjectMetadata +from allensdk.brain_observatory.vbn_2022.metadata_writer.lims_queries import _merge_ecephys_id_and_failed from allensdk.internal.brain_observatory.mouse import Mouse -from allensdk.internal.brain_observatory.util.multi_session_utils import \ - remove_invalid_sessions -from allensdk.test.brain_observatory.behavior.data_objects.lims_util import \ - LimsTest +from allensdk.internal.brain_observatory.util.multi_session_utils import remove_invalid_sessions +from allensdk.test.brain_observatory.behavior.data_objects.lims_util import LimsTest class TestLimsQueries(LimsTest): @@ -31,19 +26,14 @@ def test_exclude_deceased_mice(self): ecephys_session_id = 1071300149 behavior_session_id = BehaviorSessionId.from_ecephys_session_id( - db=self.dbconn, - ecephys_session_id=ecephys_session_id + db=self.dbconn, ecephys_session_id=ecephys_session_id ) - mouse = Mouse.from_behavior_session_id( - behavior_session_id=behavior_session_id.value) + mouse = Mouse.from_behavior_session_id(behavior_session_id=behavior_session_id.value) obtained = mouse.get_behavior_sessions(exclude_invalid_sessions=True) - obtained_all = mouse.get_behavior_sessions( - exclude_invalid_sessions=False) + obtained_all = mouse.get_behavior_sessions(exclude_invalid_sessions=False) - assert all([x.date_of_acquisition <= - x.subject_metadata.get_death_date() for x in obtained]) - assert len([x.subject_metadata.get_death_date() > - x.date_of_acquisition for x in obtained_all]) >= 1 + assert all([x.date_of_acquisition <= x.subject_metadata.get_death_date() for x in obtained]) + assert len([x.subject_metadata.get_death_date() > x.date_of_acquisition for x in obtained_all]) >= 1 def test_filter_on_death_date(): @@ -54,9 +44,7 @@ def test_filter_on_death_date(): for is_valid in (True, False): sessions += [ BehaviorMetadata( - date_of_acquisition=DateOfAcquisition( - date_of_acquisition=datetime.datetime(2020, 6, 7) - ), + date_of_acquisition=DateOfAcquisition(date_of_acquisition=datetime.datetime(2020, 6, 7)), behavior_session_id=i, behavior_session_uuid=None, equipment=None, @@ -69,23 +57,19 @@ def test_filter_on_death_date(): mouse_id=i, reporter_line=None, sex=None, - death_on=( - datetime.datetime(2020, 6, 8) if is_valid else - datetime.datetime(2020, 6, 6) - ) - ) - ) for i in range(2)] + death_on=(datetime.datetime(2020, 6, 8) if is_valid else datetime.datetime(2020, 6, 6)), + ), + ) + for i in range(2) + ] actual = remove_invalid_sessions( behavior_sessions=sessions, remove_pretest_sessions=False, remove_aborted_sessions=False, - remove_sessions_after_mouse_death_date=True + remove_sessions_after_mouse_death_date=True, ) - expected = [ - sessions[0], - sessions[1] - ] + expected = [sessions[0], sessions[1]] assert actual == expected @@ -98,37 +82,31 @@ def test_merge_ecephys_id_and_failed(): """ ecephys_data = [ - {'ecephys_session_id': 4, - 'donor_id': '100'}, - {'ecephys_session_id': 1, - 'donor_id': '200'}, - {'ecephys_session_id': 3, - 'donor_id': '300'}, - {'ecephys_session_id': 2, - 'donor_id': '400'}] + {"ecephys_session_id": 4, "donor_id": "100"}, + {"ecephys_session_id": 1, "donor_id": "200"}, + {"ecephys_session_id": 3, "donor_id": "300"}, + {"ecephys_session_id": 2, "donor_id": "400"}, + ] failed_data = [ - {'ecephys_session_id': 7, - 'donor_id': '300'}, - {'ecphys_session_id': 6, - 'donor_id': '900'}, - {'ecephys_session_id': 5, - 'donor_id': '200'}] + {"ecephys_session_id": 7, "donor_id": "300"}, + {"ecphys_session_id": 6, "donor_id": "900"}, + {"ecephys_session_id": 5, "donor_id": "200"}, + ] class DummyConnection(object): - def select(self, query=None): - if '5' in query: + if "5" in query: return pd.DataFrame(data=failed_data) - elif '3' in query: + elif "3" in query: return pd.DataFrame(data=ecephys_data) else: - raise RuntimeError( - f"cannot mock query={query}") + raise RuntimeError(f"cannot mock query={query}") expected = [1, 2, 3, 4, 5, 7] actual = _merge_ecephys_id_and_failed( - lims_connection=DummyConnection(), - ecephys_session_id_list=[1, 2, 3, 4], - failed_ecephys_session_id_list=[5, 6, 7]) + lims_connection=DummyConnection(), + ecephys_session_id_list=[1, 2, 3, 4], + failed_ecephys_session_id_list=[5, 6, 7], + ) assert expected == actual diff --git a/allensdk/test/brain_observatory/vbn_2022/metadata_writer/test_vbn_prior_omissions.py b/allensdk/test/brain_observatory/vbn_2022/metadata_writer/test_vbn_prior_omissions.py index c01c29eaad..8104d007cd 100644 --- a/allensdk/test/brain_observatory/vbn_2022/metadata_writer/test_vbn_prior_omissions.py +++ b/allensdk/test/brain_observatory/vbn_2022/metadata_writer/test_vbn_prior_omissions.py @@ -11,9 +11,7 @@ import datetime import copy -from allensdk.brain_observatory.vbn_2022 \ - .metadata_writer.dataframe_manipulations import ( - _add_prior_omissions) +from allensdk.brain_observatory.vbn_2022.metadata_writer.dataframe_manipulations import _add_prior_omissions @pytest.fixture @@ -33,12 +31,14 @@ def session_list_fixture(): # session_type followed by boolean indicating # if it should increment prior_exposures_to_omissions - session_types = [('junk_ephys', True), - ('silly', False), - ('ephys_junk', True), - ('nonsense', False), - ('some_ecephys_data', True), - (None, False)] + session_types = [ + ("junk_ephys", True), + ("silly", False), + ("ephys_junk", True), + ("nonsense", False), + ("some_ecephys_data", True), + (None, False), + ] tt = 12345678.1 beh = 0 @@ -62,19 +62,20 @@ def session_list_fixture(): session_type = rng.choice(session_types) - tt += rng.random()*10.0 + tt += rng.random() * 10.0 if session_type[0] is None: prior = None else: prior = ct_omissions this_session = { - 'mouse_id': mouse_id, - 'behavior_session_id': behavior_id, - 'ecephys_session_id': ecephys_id, - 'date_of_acquisition': datetime.datetime.fromtimestamp(tt), - 'session_type': session_type[0], - 'prior_exposures_to_omissions': prior} + "mouse_id": mouse_id, + "behavior_session_id": behavior_id, + "ecephys_session_id": ecephys_id, + "date_of_acquisition": datetime.datetime.fromtimestamp(tt), + "session_type": session_type[0], + "prior_exposures_to_omissions": prior, + } output.append(this_session) if session_type[1]: ct_omissions += 1 @@ -83,8 +84,7 @@ def session_list_fixture(): @pytest.fixture -def session_dfs_fixture( - session_list_fixture): +def session_dfs_fixture(session_list_fixture): """ Return behavior_sessions_df and ecephys_sessions_df based on session_list_fixture. Include both input dataframes @@ -97,10 +97,10 @@ def session_dfs_fixture( behavior_data = [] ecephys_data = [] for session in session_list: - if session['ecephys_session_id'] is None: + if session["ecephys_session_id"] is None: is_beh = True is_ece = False - elif session['behavior_session_id'] is None: + elif session["behavior_session_id"] is None: is_beh = False is_ece = True else: @@ -122,74 +122,61 @@ def session_dfs_fixture( expected_behavior_df = pd.DataFrame(data=behavior_data) expected_ecephys_df = pd.DataFrame(data=ecephys_data) - input_behavior_df = expected_behavior_df.copy().drop( - axis='columns', columns=['prior_exposures_to_omissions']) - input_ecephys_df = expected_ecephys_df.copy().drop( - axis='columns', columns=['prior_exposures_to_omissions']) + input_behavior_df = expected_behavior_df.copy().drop(axis="columns", columns=["prior_exposures_to_omissions"]) + input_ecephys_df = expected_ecephys_df.copy().drop(axis="columns", columns=["prior_exposures_to_omissions"]) - return {'behavior': input_behavior_df, - 'ecephys': input_ecephys_df, - 'expected_behavior': expected_behavior_df, - 'expected_ecephys': expected_ecephys_df} + return { + "behavior": input_behavior_df, + "ecephys": input_ecephys_df, + "expected_behavior": expected_behavior_df, + "expected_ecephys": expected_ecephys_df, + } -def test_vbn_add_prior_omissions( - session_dfs_fixture): +def test_vbn_add_prior_omissions(session_dfs_fixture): """ test _add_prior_omissions on actual data """ - behavior_df = session_dfs_fixture['behavior'].copy(deep=True) - ecephys_df = session_dfs_fixture['ecephys'].copy(deep=True) + behavior_df = session_dfs_fixture["behavior"].copy(deep=True) + ecephys_df = session_dfs_fixture["ecephys"].copy(deep=True) for df in (behavior_df, ecephys_df): - assert 'prior_exposures_to_omissions' not in df.columns + assert "prior_exposures_to_omissions" not in df.columns - result = _add_prior_omissions( - behavior_sessions_df=behavior_df, - ecephys_sessions_df=ecephys_df) + result = _add_prior_omissions(behavior_sessions_df=behavior_df, ecephys_sessions_df=ecephys_df) - behavior_df = result['behavior'] - ecephys_df = result['ecephys'] + behavior_df = result["behavior"] + ecephys_df = result["ecephys"] for df in (behavior_df, ecephys_df): - assert 'prior_exposures_to_omissions' in df.columns + assert "prior_exposures_to_omissions" in df.columns - pd.testing.assert_frame_equal( - behavior_df, - session_dfs_fixture['expected_behavior'], - check_dtype=False) + pd.testing.assert_frame_equal(behavior_df, session_dfs_fixture["expected_behavior"], check_dtype=False) - pd.testing.assert_frame_equal( - ecephys_df, - session_dfs_fixture['expected_ecephys'], - check_dtype=False) + pd.testing.assert_frame_equal(ecephys_df, session_dfs_fixture["expected_ecephys"], check_dtype=False) -def test_vbn_date_off_warning( - session_dfs_fixture): +def test_vbn_date_off_warning(session_dfs_fixture): """ Test that a warning is raised when behavior_sessions_df and ecephys_sessions_df disagree on the date of a session """ - behavior_sessions = session_dfs_fixture['behavior'].copy(deep=True) - ecephys_sessions = session_dfs_fixture['ecephys'].copy(deep=True) + behavior_sessions = session_dfs_fixture["behavior"].copy(deep=True) + ecephys_sessions = session_dfs_fixture["ecephys"].copy(deep=True) - in_beh = set([beh for beh in - behavior_sessions.loc[ - behavior_sessions.behavior_session_id.notnull() - ].behavior_session_id]) + in_beh = set( + [beh for beh in behavior_sessions.loc[behavior_sessions.behavior_session_id.notnull()].behavior_session_id] + ) in_ece = set(ecephys_sessions.behavior_session_id) in_both = in_ece.intersection(in_beh) assert len(in_both) > 0 in_both = list(in_both) in_both.sort() - chosen_id = in_both[len(in_both)//2] + chosen_id = in_both[len(in_both) // 2] - beh_date = behavior_sessions.loc[ - behavior_sessions.behavior_session_id == chosen_id] - ece_date = ecephys_sessions.loc[ - ecephys_sessions.behavior_session_id == chosen_id] + beh_date = behavior_sessions.loc[behavior_sessions.behavior_session_id == chosen_id] + ece_date = ecephys_sessions.loc[ecephys_sessions.behavior_session_id == chosen_id] # make sure the two dataframes agree on the date initially assert len(beh_date) == 1 @@ -200,14 +187,12 @@ def test_vbn_date_off_warning( # change the date in one of the frames and make sure the # disagreement shows up - behavior_sessions.loc[ - behavior_sessions.behavior_session_id == chosen_id, - 'date_of_acquisition'] = datetime.datetime.fromtimestamp(99.0) + behavior_sessions.loc[behavior_sessions.behavior_session_id == chosen_id, "date_of_acquisition"] = ( + datetime.datetime.fromtimestamp(99.0) + ) - beh_date = behavior_sessions.loc[ - behavior_sessions.behavior_session_id == chosen_id] - ece_date = ecephys_sessions.loc[ - ecephys_sessions.behavior_session_id == chosen_id] + beh_date = behavior_sessions.loc[behavior_sessions.behavior_session_id == chosen_id] + ece_date = ecephys_sessions.loc[ecephys_sessions.behavior_session_id == chosen_id] assert len(beh_date) == 1 assert len(ece_date) == 1 @@ -215,28 +200,27 @@ def test_vbn_date_off_warning( ece_date = ece_date.date_of_acquisition.values[0] assert beh_date != ece_date - with pytest.warns(UserWarning, - match="disagree on the date of behavior session"): - _add_prior_omissions( - behavior_sessions_df=behavior_sessions, - ecephys_sessions_df=ecephys_sessions) + with pytest.warns(UserWarning, match="disagree on the date of behavior session"): + _add_prior_omissions(behavior_sessions_df=behavior_sessions, ecephys_sessions_df=ecephys_sessions) -def test_vbn_session_type_off_error( - session_dfs_fixture): +def test_vbn_session_type_off_error(session_dfs_fixture): """ Test that an exception is thrown when behavior_sessions_df and ecephys_sessions_df disagree on the session_type """ - behavior_sessions = session_dfs_fixture['behavior'].copy(deep=True) - ecephys_sessions = session_dfs_fixture['ecephys'].copy(deep=True) + behavior_sessions = session_dfs_fixture["behavior"].copy(deep=True) + ecephys_sessions = session_dfs_fixture["ecephys"].copy(deep=True) - in_beh = set([beh for beh in - behavior_sessions.loc[ - behavior_sessions.session_type.notnull() - & behavior_sessions.behavior_session_id.notnull() - ].behavior_session_id]) + in_beh = set( + [ + beh + for beh in behavior_sessions.loc[ + behavior_sessions.session_type.notnull() & behavior_sessions.behavior_session_id.notnull() + ].behavior_session_id + ] + ) in_ece = set(ecephys_sessions.behavior_session_id) @@ -244,12 +228,10 @@ def test_vbn_session_type_off_error( assert len(in_both) > 0 in_both = list(in_both) in_both.sort() - chosen_id = in_both[len(in_both)//2] + chosen_id = in_both[len(in_both) // 2] - beh_type = behavior_sessions.loc[ - behavior_sessions.behavior_session_id == chosen_id] - ece_type = ecephys_sessions.loc[ - ecephys_sessions.behavior_session_id == chosen_id] + beh_type = behavior_sessions.loc[behavior_sessions.behavior_session_id == chosen_id] + ece_type = ecephys_sessions.loc[ecephys_sessions.behavior_session_id == chosen_id] # make sure the two dataframes agree on the session_type initially assert len(beh_type) == 1 @@ -260,14 +242,12 @@ def test_vbn_session_type_off_error( # change the session_type in one of the frames and make sure the # disagreement shows up - behavior_sessions.loc[ - behavior_sessions.behavior_session_id == chosen_id, - 'session_type'] = 'something_new_and_bizarre' + behavior_sessions.loc[behavior_sessions.behavior_session_id == chosen_id, "session_type"] = ( + "something_new_and_bizarre" + ) - beh_type = behavior_sessions.loc[ - behavior_sessions.behavior_session_id == chosen_id] - ece_type = ecephys_sessions.loc[ - ecephys_sessions.behavior_session_id == chosen_id] + beh_type = behavior_sessions.loc[behavior_sessions.behavior_session_id == chosen_id] + ece_type = ecephys_sessions.loc[ecephys_sessions.behavior_session_id == chosen_id] assert len(beh_type) == 1 assert len(ece_type) == 1 @@ -275,8 +255,5 @@ def test_vbn_session_type_off_error( ece_type = ece_type.session_type.values[0] assert beh_type != ece_type - with pytest.raises(RuntimeError, - match="disagree on the session type"): - _add_prior_omissions( - behavior_sessions_df=behavior_sessions, - ecephys_sessions_df=ecephys_sessions) + with pytest.raises(RuntimeError, match="disagree on the session type"): + _add_prior_omissions(behavior_sessions_df=behavior_sessions, ecephys_sessions_df=ecephys_sessions) diff --git a/allensdk/test/config/test_config_single_file_json.py b/allensdk/test/config/test_config_single_file_json.py index 390e0839e0..d430ee3221 100644 --- a/allensdk/test/config/test_config_single_file_json.py +++ b/allensdk/test/config/test_config_single_file_json.py @@ -36,6 +36,7 @@ import pytest from unittest.mock import patch, mock_open from allensdk.model.biophys_sim.config import Config + try: import __builtin__ as builtins # @UnresolvedImport except Exception: @@ -44,7 +45,7 @@ @pytest.fixture def simple_config(): - manifest = '''{ + manifest = """{ "manifest": [ { "type": "dir", "spec": "MOCK_DOT", @@ -52,22 +53,21 @@ def simple_config(): }], "biophys": [{ "hoc": [ "stdgui.hoc"] }] - }''' + }""" - with patch(builtins.__name__ + ".open", - mock_open(read_data=manifest)): - config = Config().load('config.json', False) + with patch(builtins.__name__ + ".open", mock_open(read_data=manifest)): + config = Config().load("config.json", False) return config def testAccessHocFilesInData(simple_config): - assert simple_config.data['biophys'][0]['hoc'][0] == 'stdgui.hoc' + assert simple_config.data["biophys"][0]["hoc"][0] == "stdgui.hoc" def testManifestIsNotInData(simple_config): - assert 'manifest' not in simple_config.data + assert "manifest" not in simple_config.data def testManifestInReservedData(simple_config): - assert 'manifest' in simple_config.reserved_data[0] + assert "manifest" in simple_config.reserved_data[0] diff --git a/allensdk/test/config/test_json_comments.py b/allensdk/test/config/test_json_comments.py index e42aa71a71..49c04a5300 100644 --- a/allensdk/test/config/test_json_comments.py +++ b/allensdk/test/config/test_json_comments.py @@ -39,6 +39,7 @@ import allensdk.core.json_utilities as ju from allensdk.core.json_utilities import JsonComments import logging + try: import __builtin__ as builtins # @UnresolvedImport except Exception: @@ -47,163 +48,125 @@ @pytest.fixture def commented_json(): - return ("{\n" - " // comment\n" - " \"color\": \"blue\"\n" - "}") + return '{\n // comment\n "color": "blue"\n}' @pytest.fixture def blank_line_json(): - return ("{\n" - "\n" - "\n" - "\n" - " \"color\": \"blue\"\n" - "}") + return '{\n\n\n\n "color": "blue"\n}' @pytest.fixture def multi_line_json(): - return ("{\n" - "/* \n" - " * multiline comment\n" - " */\n" - " \"color\": \"blue\"\n" - "}") + return '{\n/* \n * multiline comment\n */\n "color": "blue"\n}' @pytest.fixture def two_multi_line_json(): - return ("{\n" - " \"colors\": [\"blue\",\n" - " /* comment these out\n" - " \"red\",\n" - " \"yellow\",\n" - " ... but not these */\n" - " \"orange\",\n" - " \"purple\",\n" - " /* also comment this out\n" - " \"indigo\",\n" - " .... end comment */\n" - " \"violet\"\n" - " ]\n" - "}") + return ( + "{\n" + ' "colors": ["blue",\n' + " /* comment these out\n" + ' "red",\n' + ' "yellow",\n' + " ... but not these */\n" + ' "orange",\n' + ' "purple",\n' + " /* also comment this out\n" + ' "indigo",\n' + " .... end comment */\n" + ' "violet"\n' + " ]\n" + "}" + ) @pytest.fixture def corrupted_json(): - return ("{\n" - " \"colors\": \"blue\",\n" - " /* comment these out\n" - " \"red\",\n" - " \"yel") + return '{\n "colors": "blue",\n /* comment these out\n "red",\n "yel' @pytest.fixture def ju_logger(): - log = logging.getLogger('allensdk.core.json_utilities') + log = logging.getLogger("allensdk.core.json_utilities") log.error = Mock() return log -def testSingleLineCommentJSONDecodeError(corrupted_json, - ju_logger): +def testSingleLineCommentJSONDecodeError(corrupted_json, ju_logger): with pytest.raises(JSONDecodeError) as e_info: - with patch(builtins.__name__ + ".open", - mock_open(read_data=corrupted_json)): + with patch(builtins.__name__ + ".open", mock_open(read_data=corrupted_json)): JsonComments.read_file("corrupted.json") - ju_logger.error.assert_called_once_with( - 'Could not load json object from file: corrupted.json') - assert e_info.typename == 'JSONDecodeError' + ju_logger.error.assert_called_once_with("Could not load json object from file: corrupted.json") + assert e_info.typename == "JSONDecodeError" def testSingleLineComment(commented_json): - parsed_json = JsonComments.read_string( - commented_json) + parsed_json = JsonComments.read_string(commented_json) - assert('color' in parsed_json and - parsed_json['color'] == 'blue') + assert "color" in parsed_json and parsed_json["color"] == "blue" def testBlankLines(blank_line_json): - parsed_json = JsonComments.read_string( - blank_line_json) + parsed_json = JsonComments.read_string(blank_line_json) - assert('color' in parsed_json and - parsed_json['color'] == 'blue') + assert "color" in parsed_json and parsed_json["color"] == "blue" def testMultiLineComment(multi_line_json): - parsed_json = JsonComments.read_string( - multi_line_json) + parsed_json = JsonComments.read_string(multi_line_json) - assert('color' in parsed_json and - parsed_json['color'] == 'blue') + assert "color" in parsed_json and parsed_json["color"] == "blue" def testTwoMultiLineComments(two_multi_line_json): - parsed_json = JsonComments.read_string( - two_multi_line_json) + parsed_json = JsonComments.read_string(two_multi_line_json) - assert('colors' in parsed_json) - assert(len(parsed_json['colors']) == 4) - assert('blue' in parsed_json['colors']) - assert('orange' in parsed_json['colors']) - assert('purple' in parsed_json['colors']) - assert('violet' in parsed_json['colors']) + assert "colors" in parsed_json + assert len(parsed_json["colors"]) == 4 + assert "blue" in parsed_json["colors"] + assert "orange" in parsed_json["colors"] + assert "purple" in parsed_json["colors"] + assert "violet" in parsed_json["colors"] def testSingleLineCommentFile(commented_json): - with patch(builtins.__name__ + ".open", - mock_open( - read_data=commented_json)): - parsed_json = JsonComments.read_file('mock.json') + with patch(builtins.__name__ + ".open", mock_open(read_data=commented_json)): + parsed_json = JsonComments.read_file("mock.json") - assert('color' in parsed_json and - parsed_json['color'] == 'blue') + assert "color" in parsed_json and parsed_json["color"] == "blue" def testBlankLinesFile(blank_line_json): - with patch(builtins.__name__ + ".open", - mock_open( - read_data=blank_line_json)): - parsed_json = JsonComments.read_file('mock.json') + with patch(builtins.__name__ + ".open", mock_open(read_data=blank_line_json)): + parsed_json = JsonComments.read_file("mock.json") - assert('color' in parsed_json and - parsed_json['color'] == 'blue') + assert "color" in parsed_json and parsed_json["color"] == "blue" def testMultiLineFile(multi_line_json): - with patch(builtins.__name__ + ".open", - mock_open( - read_data=multi_line_json)): - parsed_json = JsonComments.read_file('mock.json') + with patch(builtins.__name__ + ".open", mock_open(read_data=multi_line_json)): + parsed_json = JsonComments.read_file("mock.json") - assert('color' in parsed_json and - parsed_json['color'] == 'blue') + assert "color" in parsed_json and parsed_json["color"] == "blue" def testTwoMultiLineFile(two_multi_line_json): - with patch(builtins.__name__ + ".open", - mock_open( - read_data=two_multi_line_json)): - parsed_json = JsonComments.read_file('mock.json') + with patch(builtins.__name__ + ".open", mock_open(read_data=two_multi_line_json)): + parsed_json = JsonComments.read_file("mock.json") - assert('colors' in parsed_json) - assert(len(parsed_json['colors']) == 4) - assert('blue' in parsed_json['colors']) - assert('orange' in parsed_json['colors']) - assert('purple' in parsed_json['colors']) - assert('violet' in parsed_json['colors']) + assert "colors" in parsed_json + assert len(parsed_json["colors"]) == 4 + assert "blue" in parsed_json["colors"] + assert "orange" in parsed_json["colors"] + assert "purple" in parsed_json["colors"] + assert "violet" in parsed_json["colors"] def test_write_nan(): - with patch(builtins.__name__ + ".open", - mock_open(), - create=True) as mo: - ju.write('/some/file/test.json', { "thing": float('nan')}) - - assert 'null' in str(mo().write.call_args_list[0]) + with patch(builtins.__name__ + ".open", mock_open(), create=True) as mo: + ju.write("/some/file/test.json", {"thing": float("nan")}) + + assert "null" in str(mo().write.call_args_list[0]) diff --git a/allensdk/test/config/test_manifest.py b/allensdk/test/config/test_manifest.py index cddbb90c1f..6a745bc391 100644 --- a/allensdk/test/config/test_manifest.py +++ b/allensdk/test/config/test_manifest.py @@ -42,55 +42,49 @@ @pytest.fixture def builder(): b = ManifestBuilder() - b.add_path('BASEDIR', '/home/username/example') + b.add_path("BASEDIR", "/home/username/example") return b def testManifestConstructor(builder): manifest = builder.get_manifest() - expected = os.path.abspath('/home/username/example') - actual = manifest.get_path('BASEDIR') - assert(expected == actual) + expected = os.path.abspath("/home/username/example") + actual = manifest.get_path("BASEDIR") + assert expected == actual def testManifestParent(builder): - builder.add_path('WORKDIR', - 'work', - parent_key='BASEDIR') + builder.add_path("WORKDIR", "work", parent_key="BASEDIR") manifest = builder.get_manifest() - expected = os.path.abspath('/home/username/example/work') - actual = manifest.get_path('WORKDIR') - assert(expected == actual) + expected = os.path.abspath("/home/username/example/work") + actual = manifest.get_path("WORKDIR") + assert expected == actual def testManifestBuilderDataFrame(builder): - builder.add_path('WORKDIR', - 'work', - parent_key='BASEDIR') + builder.add_path("WORKDIR", "work", parent_key="BASEDIR") builder_df = builder.as_dataframe() - assert('key' in builder_df.keys()) - assert('type' in builder_df.keys()) - assert('spec' in builder_df.keys()) - assert('parent_key' in builder_df.keys()) - assert('format' in builder_df.keys()) - assert(5 == len(builder_df.keys())) + assert "key" in builder_df.keys() + assert "type" in builder_df.keys() + assert "spec" in builder_df.keys() + assert "parent_key" in builder_df.keys() + assert "format" in builder_df.keys() + assert 5 == len(builder_df.keys()) def testManifestDataFrame(builder): - builder.add_path('WORKDIR', - 'work', - parent_key='BASEDIR') + builder.add_path("WORKDIR", "work", parent_key="BASEDIR") manifest = builder.get_manifest() df = manifest.as_dataframe() - assert('type' in df.keys()) - assert('spec' in df.keys()) - assert(2 == len(df.keys())) + assert "type" in df.keys() + assert "spec" in df.keys() + assert 2 == len(df.keys()) def safe_mkdir_root_dir(): directory = os.path.abspath(os.sep) - Manifest.safe_mkdir(directory) # should not error \ No newline at end of file + Manifest.safe_mkdir(directory) # should not error diff --git a/allensdk/test/config/test_multi_file_config.py b/allensdk/test/config/test_multi_file_config.py index 6472b4d178..af7bfee945 100644 --- a/allensdk/test/config/test_multi_file_config.py +++ b/allensdk/test/config/test_multi_file_config.py @@ -36,6 +36,7 @@ import pytest from unittest.mock import patch, mock_open from allensdk.config.model.description_parser import DescriptionParser + try: import __builtin__ as builtins except Exception: @@ -44,91 +45,76 @@ @pytest.fixture def multiconfig(): - file_1 = ("{\n" - " \"section_A\": [\n" - " {\n" - " \"prop_a\": \"val_a\",\n" - " \"prop_b\": \"val_b\"\n" - " },\n" - " {\n" - " \"prop_c\": \"val_c\",\n" - " \"prop_d\": \"val_d\"\n" - " }\n" - " ],\n" - - " \"section_B\": [\n" - " {\n" - " \"prop_e\": \"val_e\",\n" - " \"prop_f\": \"val_f\"\n" - " },\n" - " {\n" - " \"prop_g\": \"val_g\",\n" - " \"prop_h\": \"val_h\"\n" - " }\n" - " ]\n" - "}\n" - ) - file_2 = ("{\n" - " \"section_B\": [\n" - " {\n" - " \"prop_i\": \"val_i\",\n" - " \"prop_j\": \"val_j\"\n" - " }\n" - " ],\n" - " \"section_C\": [\n" - " {\n" - " \"prop_k\": \"val_k\",\n" - " \"prop_l\": \"val_l\"\n" - " }\n" - " ]\n" - "}\n" - ) + file_1 = ( + "{\n" + ' "section_A": [\n' + " {\n" + ' "prop_a": "val_a",\n' + ' "prop_b": "val_b"\n' + " },\n" + " {\n" + ' "prop_c": "val_c",\n' + ' "prop_d": "val_d"\n' + " }\n" + " ],\n" + ' "section_B": [\n' + " {\n" + ' "prop_e": "val_e",\n' + ' "prop_f": "val_f"\n' + " },\n" + " {\n" + ' "prop_g": "val_g",\n' + ' "prop_h": "val_h"\n' + " }\n" + " ]\n" + "}\n" + ) + file_2 = ( + "{\n" + ' "section_B": [\n' + " {\n" + ' "prop_i": "val_i",\n' + ' "prop_j": "val_j"\n' + " }\n" + " ],\n" + ' "section_C": [\n' + " {\n" + ' "prop_k": "val_k",\n' + ' "prop_l": "val_l"\n' + " }\n" + " ]\n" + "}\n" + ) parser = DescriptionParser() - with patch(builtins.__name__ + ".open", - mock_open(read_data=file_1)): + with patch(builtins.__name__ + ".open", mock_open(read_data=file_1)): description = parser.read("mock_1.json") - with patch(builtins.__name__ + ".open", - mock_open(read_data=file_2)): + with patch(builtins.__name__ + ".open", mock_open(read_data=file_2)): parser.read("mock_2.json", description) return description def testAllSectionsPresent(multiconfig): - assert ('section_A' in multiconfig.data and - 'section_B' in multiconfig.data and - 'section_C' in multiconfig.data) + assert "section_A" in multiconfig.data and "section_B" in multiconfig.data and "section_C" in multiconfig.data assert len(multiconfig.data.keys()) == 3 def testSectionA(multiconfig): - assert len(multiconfig.data['section_A']) == 2 - assert multiconfig.data['section_A'][0] == { - 'prop_a': 'val_a', - 'prop_b': 'val_b'} - assert multiconfig.data['section_A'][1] == { - 'prop_c': 'val_c', - 'prop_d': 'val_d'} + assert len(multiconfig.data["section_A"]) == 2 + assert multiconfig.data["section_A"][0] == {"prop_a": "val_a", "prop_b": "val_b"} + assert multiconfig.data["section_A"][1] == {"prop_c": "val_c", "prop_d": "val_d"} def testSectionB(multiconfig): - assert len(multiconfig.data['section_B']) == 3 - assert multiconfig.data['section_B'][0] == { - 'prop_e': 'val_e', - 'prop_f': 'val_f'} - assert multiconfig.data['section_B'][1] == { - 'prop_g': 'val_g', - 'prop_h': 'val_h'} - assert multiconfig.data['section_B'][2] == { - 'prop_i': 'val_i', - 'prop_j': 'val_j'} + assert len(multiconfig.data["section_B"]) == 3 + assert multiconfig.data["section_B"][0] == {"prop_e": "val_e", "prop_f": "val_f"} + assert multiconfig.data["section_B"][1] == {"prop_g": "val_g", "prop_h": "val_h"} + assert multiconfig.data["section_B"][2] == {"prop_i": "val_i", "prop_j": "val_j"} def testSectionC(multiconfig): - assert len(multiconfig.data['section_C']) == 1 - assert multiconfig.data['section_C'][0] == { - 'prop_k': 'val_k', - 'prop_l': 'val_l'} + assert len(multiconfig.data["section_C"]) == 1 + assert multiconfig.data["section_C"][0] == {"prop_k": "val_k", "prop_l": "val_l"} diff --git a/allensdk/test/config/test_pyconfig_parser.py b/allensdk/test/config/test_pyconfig_parser.py index 77b0668819..e18f39a0f1 100644 --- a/allensdk/test/config/test_pyconfig_parser.py +++ b/allensdk/test/config/test_pyconfig_parser.py @@ -36,6 +36,7 @@ import pytest from unittest.mock import patch, mock_open from allensdk.config.model.description_parser import DescriptionParser + try: import __builtin__ as builtins except Exception: @@ -44,93 +45,76 @@ @pytest.fixture def pyconfig(): - file_1 = ("{\n" - " \"section_A\": [\n" - " {\n" - " \"prop_a\": \"val_a\",\n" - " \"prop_b\": \"val_b\"\n" - " },\n" - " {\n" - " \"prop_c\": \"val_c\",\n" - " \"prop_d\": \"val_d\"\n" - " }\n" - " ],\n" - " \"section_B\": [\n" - " {\n" - " \"prop_e\": \"val_e\",\n" - " \"prop_f\": \"val_f\"\n" - " },\n" - " {\n" - " \"prop_g\": \"val_g\",\n" - " \"prop_h\": \"val_h\"\n" - " }\n" - " ]\n" - "}\n" - ) - file_2 = ("{\n" - " \"section_B\": [\n" - " {\n" - " \"prop_i\": \"val_i\",\n" - " \"prop_j\": \"val_j\"\n" - " }\n" - " ],\n" - " \"section_C\": [\n" - " {\n" - " \"prop_k\": \"val_k\",\n" - " \"prop_l\": \"val_l\"\n" - " }\n" - " ]\n" - "}\n" - ) + file_1 = ( + "{\n" + ' "section_A": [\n' + " {\n" + ' "prop_a": "val_a",\n' + ' "prop_b": "val_b"\n' + " },\n" + " {\n" + ' "prop_c": "val_c",\n' + ' "prop_d": "val_d"\n' + " }\n" + " ],\n" + ' "section_B": [\n' + " {\n" + ' "prop_e": "val_e",\n' + ' "prop_f": "val_f"\n' + " },\n" + " {\n" + ' "prop_g": "val_g",\n' + ' "prop_h": "val_h"\n' + " }\n" + " ]\n" + "}\n" + ) + file_2 = ( + "{\n" + ' "section_B": [\n' + " {\n" + ' "prop_i": "val_i",\n' + ' "prop_j": "val_j"\n' + " }\n" + " ],\n" + ' "section_C": [\n' + " {\n" + ' "prop_k": "val_k",\n' + ' "prop_l": "val_l"\n' + " }\n" + " ]\n" + "}\n" + ) - with patch(builtins.__name__ + ".open", - mock_open( - read_data=file_1)): + with patch(builtins.__name__ + ".open", mock_open(read_data=file_1)): parser = DescriptionParser() description = parser.read("mock_1.pycfg") - with patch(builtins.__name__ + ".open", - mock_open( - read_data=file_2)): + with patch(builtins.__name__ + ".open", mock_open(read_data=file_2)): parser = DescriptionParser() - parser.read("mock_2.pycfg", - description) + parser.read("mock_2.pycfg", description) return description def testAllSectionsPresent(pyconfig): - assert('section_A' in pyconfig.data and - 'section_B' in pyconfig.data and - 'section_C' in pyconfig.data) - assert(len(pyconfig.data.keys()) == 3) + assert "section_A" in pyconfig.data and "section_B" in pyconfig.data and "section_C" in pyconfig.data + assert len(pyconfig.data.keys()) == 3 def testSectionA(pyconfig): - assert len(pyconfig.data['section_A']) == 2 - assert pyconfig.data['section_A'][0] == { - 'prop_a': 'val_a', - 'prop_b': 'val_b'} - assert pyconfig.data['section_A'][1] == { - 'prop_c': 'val_c', - 'prop_d': 'val_d'} + assert len(pyconfig.data["section_A"]) == 2 + assert pyconfig.data["section_A"][0] == {"prop_a": "val_a", "prop_b": "val_b"} + assert pyconfig.data["section_A"][1] == {"prop_c": "val_c", "prop_d": "val_d"} def testSectionB(pyconfig): - assert len(pyconfig.data['section_B']) == 3 - assert pyconfig.data['section_B'][0] == { - 'prop_e': 'val_e', - 'prop_f': 'val_f'} - assert pyconfig.data['section_B'][1] == { - 'prop_g': 'val_g', - 'prop_h': 'val_h'} - assert pyconfig.data['section_B'][2] == { - 'prop_i': 'val_i', - 'prop_j': 'val_j'} + assert len(pyconfig.data["section_B"]) == 3 + assert pyconfig.data["section_B"][0] == {"prop_e": "val_e", "prop_f": "val_f"} + assert pyconfig.data["section_B"][1] == {"prop_g": "val_g", "prop_h": "val_h"} + assert pyconfig.data["section_B"][2] == {"prop_i": "val_i", "prop_j": "val_j"} def testSectionC(pyconfig): - assert len(pyconfig.data['section_C']) == 1 - assert pyconfig.data['section_C'][0] == { - 'prop_k': 'val_k', - 'prop_l': 'val_l'} + assert len(pyconfig.data["section_C"]) == 1 + assert pyconfig.data["section_C"][0] == {"prop_k": "val_k", "prop_l": "val_l"} diff --git a/allensdk/test/conftest.py b/allensdk/test/conftest.py index e41fd3b9c2..2a26a92212 100644 --- a/allensdk/test/conftest.py +++ b/allensdk/test/conftest.py @@ -6,18 +6,17 @@ class HelperFunctions(object): - @staticmethod def create_blank_nwb_file(): """ Create and return an empty NWB file """ nwbfile = NWBFile( - session_description='foo', - identifier='1', - session_id='foo', + session_description="foo", + identifier="1", + session_id="foo", session_start_time=datetime.datetime.now(), - institution="Allen Institute" + institution="Allen Institute", ) return nwbfile @@ -49,8 +48,7 @@ def _first_pass_safe_cleanup_dir(dir_path: pathlib.Path): if this_path.is_file(): HelperFunctions.windows_safe_cleanup(file_path=this_path) elif this_path.is_dir(): - HelperFunctions.windows_safe_cleanup_dir( - dir_path=this_path) + HelperFunctions.windows_safe_cleanup_dir(dir_path=this_path) try: this_path.rmdir() except Exception: @@ -63,8 +61,7 @@ def windows_safe_cleanup_dir(dir_path: pathlib.Path): If a PermissionError is raised, ignore if the system is Windows (this has been observed on our CI systems) """ - HelperFunctions._first_pass_safe_cleanup_dir( - dir_path=dir_path) + HelperFunctions._first_pass_safe_cleanup_dir(dir_path=dir_path) contents_list = [n for n in dir_path.iterdir()] for this_path in contents_list: @@ -75,7 +72,7 @@ def windows_safe_cleanup_dir(dir_path: pathlib.Path): raise -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def helper_functions(): """ See solution to making helper functions available across diff --git a/allensdk/test/core/test_authentication.py b/allensdk/test/core/test_authentication.py index f8b2531aab..38f23ed8ea 100644 --- a/allensdk/test/core/test_authentication.py +++ b/allensdk/test/core/test_authentication.py @@ -1,51 +1,51 @@ import pytest -from allensdk.core.authentication import ( - EnvCredentialProvider, credential_injector, set_credential_provider) +from allensdk.core.authentication import EnvCredentialProvider, credential_injector, set_credential_provider @pytest.mark.parametrize( "provider,credential_map,expected", [ - (EnvCredentialProvider({"LIMS_USER": "user", "LIMS_PASSWORD": "1234"}), - {"user": "LIMS_USER", "password": "LIMS_PASSWORD"}, - ("user", "1234")), - ] + ( + EnvCredentialProvider({"LIMS_USER": "user", "LIMS_PASSWORD": "1234"}), + {"user": "LIMS_USER", "password": "LIMS_PASSWORD"}, + ("user", "1234"), + ), + ], ) def test_credential_injector(provider, credential_map, expected): def mock_func(*, user, password): return (user, password) - assert ( - credential_injector(credential_map, provider)(mock_func)() == expected) + + assert credential_injector(credential_map, provider)(mock_func)() == expected @pytest.mark.parametrize( "provider,credential_map,expected", [ - (EnvCredentialProvider({"LIMS_USER": "user", "LIMS_PASSWORD": "1234"}), - {"user": "LIMS_USER"}, - ("user")), - ] + (EnvCredentialProvider({"LIMS_USER": "user", "LIMS_PASSWORD": "1234"}), {"user": "LIMS_USER"}, ("user")), + ], ) -def test_credential_injector_only_injects_existing_kwargs( - provider, credential_map, expected): +def test_credential_injector_only_injects_existing_kwargs(provider, credential_map, expected): def mock_func(*, user): return user - assert ( - credential_injector(credential_map, provider)(mock_func)() == expected) + + assert credential_injector(credential_map, provider)(mock_func)() == expected @pytest.mark.parametrize( "provider,credential_map", [ - (EnvCredentialProvider({"LIMS_USER": "user", "LIMS_PASSWORD": "1234"}), - {"user": "LIMS_USER", "password": "LIMS_PASSWORD"},), - ] + ( + EnvCredentialProvider({"LIMS_USER": "user", "LIMS_PASSWORD": "1234"}), + {"user": "LIMS_USER", "password": "LIMS_PASSWORD"}, + ), + ], ) -def test_credential_injector_only_injects_mapped_credentials( - provider, credential_map): +def test_credential_injector_only_injects_mapped_credentials(provider, credential_map): def mock_func(*, user, db): pass + with pytest.raises(TypeError): credential_injector(credential_map, provider)(mock_func)() @@ -53,22 +53,28 @@ def mock_func(*, user, db): @pytest.mark.parametrize( "provider,credential_map", [ - (EnvCredentialProvider({"LIMS_USER": "user", "LIMS_PASSWORD": "1234"}), - {"user": "LIMS_USER", "password": "LIMS_PASSWORD"},), - ] + ( + EnvCredentialProvider({"LIMS_USER": "user", "LIMS_PASSWORD": "1234"}), + {"user": "LIMS_USER", "password": "LIMS_PASSWORD"}, + ), + ], ) def test_credential_injector_preserves_function_args(provider, credential_map): def mock_func(arg1, kwarg1=None, *, user, password): return (arg1, kwarg1, user, password) - assert ( - credential_injector(credential_map, provider) - (mock_func)("arg1", kwarg1="kwarg1") - == ("arg1", "kwarg1", "user", "1234")) + + assert credential_injector(credential_map, provider)(mock_func)("arg1", kwarg1="kwarg1") == ( + "arg1", + "kwarg1", + "user", + "1234", + ) def test_credential_injector_with_provider_update(): def mock_func(*, user): return user + provider = EnvCredentialProvider({"LIMS_USER": "user"}) credential_map = {"user": "LIMS_USER"} set_credential_provider(provider) diff --git a/allensdk/test/core/test_brain_observatory_cache.py b/allensdk/test/core/test_brain_observatory_cache.py index c99957fe05..b9367d07b1 100644 --- a/allensdk/test/core/test_brain_observatory_cache.py +++ b/allensdk/test/core/test_brain_observatory_cache.py @@ -115,8 +115,10 @@ @pytest.fixture() def events_test_data(): - return {"pattern": "/allen/aibs/informatics/module_test_data/observatory/events/%d_events.npz", - "experiment_id": 715923832} + return { + "pattern": "/allen/aibs/informatics/module_test_data/observatory/events/%d_events.npz", + "experiment_id": 715923832, + } @pytest.fixture(scope="function") @@ -124,31 +126,25 @@ def brain_observatory_cache(): boc = None try: - manifest_data = bytes(CACHE_MANIFEST, 'UTF-8') # Python 3 + manifest_data = bytes(CACHE_MANIFEST, "UTF-8") # Python 3 except Exception: manifest_data = bytes(CACHE_MANIFEST) # Python 2.7 - with patch('os.path.exists', - return_value=True): - with patch(builtins.__name__ + ".open", - mock_open(read_data=manifest_data)): + with patch("os.path.exists", return_value=True): + with patch(builtins.__name__ + ".open", mock_open(read_data=manifest_data)): # Download a list of all targeted areas - boc = BrainObservatoryCache(manifest_file="some_path/manifest.json", - base_uri='http://api.brain-map.org') + boc = BrainObservatoryCache(manifest_file="some_path/manifest.json", base_uri="http://api.brain-map.org") return boc @patch.object(BrainObservatoryApi, "json_msg_query") -def test_get_all_targeted_structures(mock_json_msg_query, - brain_observatory_cache): - with patch('os.path.exists') as m: +def test_get_all_targeted_structures(mock_json_msg_query, brain_observatory_cache): + with patch("os.path.exists") as m: m.return_value = False - with patch('allensdk.core.json_utilities.write', - MagicMock(name='write_json')): - with patch('allensdk.core.json_utilities.read', - MagicMock(name='read_json')): + with patch("allensdk.core.json_utilities.write", MagicMock(name="write_json")): + with patch("allensdk.core.json_utilities.read", MagicMock(name="read_json")): brain_observatory_cache.get_all_targeted_structures() mock_json_msg_query.assert_called_once_with( @@ -157,41 +153,36 @@ def test_get_all_targeted_structures(mock_json_msg_query, "ophys_experiments,isi_experiment," "specimen(donor(conditions,age,transgenic_lines))," "targeted_structure," - "rma::options[num_rows$eq'all'][count$eqfalse]") + "rma::options[num_rows$eq'all'][count$eqfalse]" + ) @patch.object(BrainObservatoryApi, "json_msg_query") -def test_get_experiment_containers(mock_json_msg_query, - brain_observatory_cache): - with patch('os.path.exists') as m: +def test_get_experiment_containers(mock_json_msg_query, brain_observatory_cache): + with patch("os.path.exists") as m: m.return_value = False - with patch('allensdk.core.json_utilities.write', - MagicMock(name='write_json')): - with patch('allensdk.core.json_utilities.read', - MagicMock(name='read_json')): + with patch("allensdk.core.json_utilities.write", MagicMock(name="write_json")): + with patch("allensdk.core.json_utilities.read", MagicMock(name="read_json")): # Download experiment containers for VISp experiments - brain_observatory_cache.get_experiment_containers( - targeted_structures=['VISp']) + brain_observatory_cache.get_experiment_containers(targeted_structures=["VISp"]) mock_json_msg_query.assert_called_once_with( "http://api.brain-map.org/api/v2/data/query.json?q=" "model::ExperimentContainer,rma::include," "ophys_experiments,isi_experiment," "specimen(donor(conditions,age,transgenic_lines)),targeted_structure," - "rma::options[num_rows$eq'all'][count$eqfalse]") + "rma::options[num_rows$eq'all'][count$eqfalse]" + ) @patch.object(BrainObservatoryApi, "json_msg_query") -def test_get_all_cre_lines(mock_json_msg_query, - brain_observatory_cache): - with patch('os.path.exists') as m: +def test_get_all_cre_lines(mock_json_msg_query, brain_observatory_cache): + with patch("os.path.exists") as m: m.return_value = False - with patch('allensdk.core.json_utilities.write', - MagicMock(name='write_json')): - with patch('allensdk.core.json_utilities.read', - MagicMock(name='read_json')): + with patch("allensdk.core.json_utilities.write", MagicMock(name="write_json")): + with patch("allensdk.core.json_utilities.read", MagicMock(name="read_json")): # Download a list of all cre lines brain_observatory_cache.get_all_cre_lines() @@ -200,103 +191,103 @@ def test_get_all_cre_lines(mock_json_msg_query, "model::ExperimentContainer,rma::include," "ophys_experiments,isi_experiment," "specimen(donor(conditions,age,transgenic_lines)),targeted_structure," - "rma::options[num_rows$eq'all'][count$eqfalse]") + "rma::options[num_rows$eq'all'][count$eqfalse]" + ) @patch.object(BrainObservatoryApi, "json_msg_query") -def test_get_ophys_experiments(mock_json_msg_query, - brain_observatory_cache): - with patch('os.path.exists') as m: +def test_get_ophys_experiments(mock_json_msg_query, brain_observatory_cache): + with patch("os.path.exists") as m: m.return_value = False - with patch('allensdk.core.json_utilities.write', - MagicMock(name='write_json')): - with patch('allensdk.core.json_utilities.read', - MagicMock(name='read_json')): + with patch("allensdk.core.json_utilities.write", MagicMock(name="write_json")): + with patch("allensdk.core.json_utilities.read", MagicMock(name="read_json")): # Download a list of all transgenic driver lines brain_observatory_cache.get_ophys_experiments() - calls = [call("http://api.brain-map.org/api/v2/data/query.json?q=" - "model::OphysExperiment,rma::include,experiment_container," - "well_known_files(well_known_file_type),targeted_structure," - "specimen(donor(age,transgenic_lines))," - "rma::options[num_rows$eq'all'][count$eqfalse]"), - - call("http://api.brain-map.org/api/v2/data/query.json?q=" - "model::WellKnownFile,rma::criteria,well_known_file_type[name$eqEyeDlcScreenMapping]," - "rma::options[num_rows$eq'all'][count$eqfalse]")] + calls = [ + call( + "http://api.brain-map.org/api/v2/data/query.json?q=" + "model::OphysExperiment,rma::include,experiment_container," + "well_known_files(well_known_file_type),targeted_structure," + "specimen(donor(age,transgenic_lines))," + "rma::options[num_rows$eq'all'][count$eqfalse]" + ), + call( + "http://api.brain-map.org/api/v2/data/query.json?q=" + "model::WellKnownFile,rma::criteria,well_known_file_type[name$eqEyeDlcScreenMapping]," + "rma::options[num_rows$eq'all'][count$eqfalse]" + ), + ] mock_json_msg_query.assert_has_calls(calls) @patch.object(BrainObservatoryApi, "json_msg_query") -def test_get_all_session_types(mock_json_msg_query, - brain_observatory_cache): - with patch('os.path.exists') as m: +def test_get_all_session_types(mock_json_msg_query, brain_observatory_cache): + with patch("os.path.exists") as m: m.return_value = False - with patch('allensdk.core.json_utilities.write', - MagicMock(name='write_json')): - with patch('allensdk.core.json_utilities.read', - MagicMock(name='read_json')): + with patch("allensdk.core.json_utilities.write", MagicMock(name="write_json")): + with patch("allensdk.core.json_utilities.read", MagicMock(name="read_json")): # Download a list of all transgenic driver lines brain_observatory_cache.get_all_session_types() - calls = [call("http://api.brain-map.org/api/v2/data/query.json?q=" - "model::OphysExperiment,rma::include,experiment_container," - "well_known_files(well_known_file_type),targeted_structure," - "specimen(donor(age,transgenic_lines))," - "rma::options[num_rows$eq'all'][count$eqfalse]"), - - call("http://api.brain-map.org/api/v2/data/query.json?q=" - "model::WellKnownFile,rma::criteria,well_known_file_type[name$eqEyeDlcScreenMapping]," - "rma::options[num_rows$eq'all'][count$eqfalse]")] + calls = [ + call( + "http://api.brain-map.org/api/v2/data/query.json?q=" + "model::OphysExperiment,rma::include,experiment_container," + "well_known_files(well_known_file_type),targeted_structure," + "specimen(donor(age,transgenic_lines))," + "rma::options[num_rows$eq'all'][count$eqfalse]" + ), + call( + "http://api.brain-map.org/api/v2/data/query.json?q=" + "model::WellKnownFile,rma::criteria,well_known_file_type[name$eqEyeDlcScreenMapping]," + "rma::options[num_rows$eq'all'][count$eqfalse]" + ), + ] mock_json_msg_query.assert_has_calls(calls) @patch.object(BrainObservatoryApi, "json_msg_query") -def test_get_stimulus_mappings(mock_json_msg_query, - brain_observatory_cache): - with patch('os.path.exists') as m: +def test_get_stimulus_mappings(mock_json_msg_query, brain_observatory_cache): + with patch("os.path.exists") as m: m.return_value = False - with patch('allensdk.core.json_utilities.write', - MagicMock(name='write_json')): - with patch('allensdk.core.json_utilities.read', - MagicMock(name='read_json')): + with patch("allensdk.core.json_utilities.write", MagicMock(name="write_json")): + with patch("allensdk.core.json_utilities.read", MagicMock(name="read_json")): # Download a list of all transgenic driver lines brain_observatory_cache._get_stimulus_mappings() mock_json_msg_query.assert_called_once_with( "http://api.brain-map.org/api/v2/data/query.json?q=" "model::ApiCamStimulusMapping," - "rma::options[num_rows$eq'all'][count$eqfalse]") + "rma::options[num_rows$eq'all'][count$eqfalse]" + ) @pytest.mark.skipif(True, reason="need to develop mocks") @patch.object(BrainObservatoryApi, "json_msg_query") -def test_get_cell_specimens(mock_json_msg_query, - brain_observatory_cache): - with patch('os.path.exists') as m: +def test_get_cell_specimens(mock_json_msg_query, brain_observatory_cache): + with patch("os.path.exists") as m: m.return_value = False - with patch('allensdk.core.json_utilities.write', - MagicMock(name='write_json')): + with patch("allensdk.core.json_utilities.write", MagicMock(name="write_json")): # Download a list of all transgenic driver lines brain_observatory_cache.get_cell_specimens() - mock_json_msg_query.assert_called_once_with( - "http://api.brain-map.org/api/v2/data/query.json?q=") + mock_json_msg_query.assert_called_once_with("http://api.brain-map.org/api/v2/data/query.json?q=") # NOTE: This test should be updated when ugly hack for associating # ophys experiment id with ophys session id is resolved. @patch.object(BrainObservatoryApi, "json_msg_query") -def test_get_ophys_pupil_data(mock_json_msg_query, - brain_observatory_cache): - - with patch.dict('allensdk.core.ophys_experiment_session_id_mapping.ophys_experiment_session_id_map', {111: 777}, clear=True): +def test_get_ophys_pupil_data(mock_json_msg_query, brain_observatory_cache): + with patch.dict( + "allensdk.core.ophys_experiment_session_id_mapping.ophys_experiment_session_id_map", {111: 777}, clear=True + ): # We are only testing that rma query is correct try: brain_observatory_cache.get_ophys_pupil_data(111, suppress_pupil_data=False) @@ -313,16 +304,16 @@ def test_get_ophys_pupil_data(mock_json_msg_query, def test_build_manifest(tmpdir_factory): try: - manifest_data = bytes(CACHE_MANIFEST, 'UTF-8') # Python 3 + manifest_data = bytes(CACHE_MANIFEST, "UTF-8") # Python 3 except Exception: manifest_data = bytes(CACHE_MANIFEST) # Python 2.7 manifest_file = str(tmpdir_factory.mktemp("boc").join("manifest.json")) - with patch('allensdk.config.manifest_builder.ManifestBuilder.write_json_string') as mock_write_json_string: + with patch("allensdk.config.manifest_builder.ManifestBuilder.write_json_string") as mock_write_json_string: mock_write_json_string.return_value = manifest_data BrainObservatoryCache(manifest_file=manifest_file) - with open(manifest_file, 'rb') as f: + with open(manifest_file, "rb") as f: read_manifest_data = f.read() assert manifest_data == read_manifest_data @@ -332,41 +323,41 @@ def test_string_argument_errors(brain_observatory_cache): boc = brain_observatory_cache with pytest.raises(TypeError): - boc.get_experiment_containers(targeted_structures='str') + boc.get_experiment_containers(targeted_structures="str") with pytest.raises(TypeError): - boc.get_experiment_containers(cre_lines='str') + boc.get_experiment_containers(cre_lines="str") with pytest.raises(TypeError): - boc.get_ophys_experiments(targeted_structures='str') + boc.get_ophys_experiments(targeted_structures="str") with pytest.raises(TypeError): - boc.get_ophys_experiments(cre_lines='str') + boc.get_ophys_experiments(cre_lines="str") with pytest.raises(TypeError): - boc.get_ophys_experiments(stimuli='str') + boc.get_ophys_experiments(stimuli="str") with pytest.raises(TypeError): - boc.get_ophys_experiments(session_types='str') + boc.get_ophys_experiments(session_types="str") -@pytest.mark.skipif(not os.path.exists('/allen/aibs/informatics/module_test_data'), reason='AIBS path not available') -@pytest.mark.parametrize("path_dict", get_list_of_path_dict()) -def test_brain_observatory_cache_get_analysis_file(brain_observatory_cache, path_dict): - nwb_path_pattern = os.path.join(os.path.dirname(path_dict['nwb_file']), '%d.nwb') +@pytest.mark.skipif(not os.path.exists("/allen/aibs/informatics/module_test_data"), reason="AIBS path not available") +@pytest.mark.parametrize("path_dict", get_list_of_path_dict()) +def test_brain_observatory_cache_get_analysis_file(brain_observatory_cache, path_dict): + nwb_path_pattern = os.path.join(os.path.dirname(path_dict["nwb_file"]), "%d.nwb") brain_observatory_cache.manifest.add_path(brain_observatory_cache.EXPERIMENT_DATA_KEY, nwb_path_pattern) - analysis_path_pattern = os.path.join(os.path.dirname(path_dict['analysis_file']), '%d_%s_analysis.h5') + analysis_path_pattern = os.path.join(os.path.dirname(path_dict["analysis_file"]), "%d_%s_analysis.h5") brain_observatory_cache.manifest.add_path(brain_observatory_cache.ANALYSIS_DATA_KEY, analysis_path_pattern) - oeid = path_dict['ophys_experiment_id'] + oeid = path_dict["ophys_experiment_id"] data_set = brain_observatory_cache.get_ophys_experiment_data(oeid) for stimulus in data_set.list_stimuli(): if stimulus != si.SPONTANEOUS_ACTIVITY: brain_observatory_cache.get_ophys_experiment_analysis(oeid, stimulus) -@pytest.mark.skipif(not os.path.exists('/allen/aibs/informatics/module_test_data'), reason='AIBS path not available') +@pytest.mark.skipif(not os.path.exists("/allen/aibs/informatics/module_test_data"), reason="AIBS path not available") def test_brain_observatory_cache_get_events_data(brain_observatory_cache, events_test_data): eid = events_test_data["experiment_id"] data_file = events_test_data["pattern"] % eid @@ -375,4 +366,4 @@ def test_brain_observatory_cache_get_events_data(brain_observatory_cache, events events = brain_observatory_cache.get_ophys_experiment_events(eid) true_events = np.load(data_file, allow_pickle=False)["ev"] - assert(np.all(events == true_events)) + assert np.all(events == true_events) diff --git a/allensdk/test/core/test_brain_observatory_nwb_data_set.py b/allensdk/test/core/test_brain_observatory_nwb_data_set.py index e616a8084b..2b5bf0f5ad 100755 --- a/allensdk/test/core/test_brain_observatory_nwb_data_set.py +++ b/allensdk/test/core/test_brain_observatory_nwb_data_set.py @@ -44,16 +44,15 @@ from test_h5_utilities import mem_h5 # noqa: F401 -- pytest fixture - NWB_FLAVORS = [] -if 'TEST_NWB_FILES' in os.environ: - nwb_list_file = os.environ['TEST_NWB_FILES'] +if "TEST_NWB_FILES" in os.environ: + nwb_list_file = os.environ["TEST_NWB_FILES"] else: - nwb_list_file = str(files('allensdk.test.core').joinpath('nwb_files.txt')) + nwb_list_file = str(files("allensdk.test.core").joinpath("nwb_files.txt")) -if os.environ.get('TEST_COMPLETE', None) == 'true': - with open(nwb_list_file, 'r') as f: +if os.environ.get("TEST_COMPLETE", None) == "true": + with open(nwb_list_file, "r") as f: NWB_FLAVORS = [l.strip() for l in f] @@ -68,26 +67,27 @@ def data_set(request): @pytest.fixture def stim_pres_h5(mem_h5): def make_stim_pres_h5(stimulus_name): - mem_h5.create_group('stimulus/presentation/{}'.format(stimulus_name)) - mem_h5.create_group('stimulus/not_presentation/{}'.format(stimulus_name)) + mem_h5.create_group("stimulus/presentation/{}".format(stimulus_name)) + mem_h5.create_group("stimulus/not_presentation/{}".format(stimulus_name)) return mem_h5 + return make_stim_pres_h5 @pytest.fixture def abstract_feature_series_h5(mem_h5): def make_abstract_feature_series_h5(stimulus_name, stim_data, features, frame_dur): - - stimulus_path = 'stimulus/presentation/{}'.format(stimulus_name) - frame_dur_path = '{}/frame_duration'.format(stimulus_path) - features_path = '{}/features'.format(stimulus_path) - stim_data_path = '{}/data'.format(stimulus_path) + stimulus_path = "stimulus/presentation/{}".format(stimulus_name) + frame_dur_path = "{}/frame_duration".format(stimulus_path) + features_path = "{}/features".format(stimulus_path) + stim_data_path = "{}/data".format(stimulus_path) mem_h5[frame_dur_path] = frame_dur mem_h5[stim_data_path] = stim_data mem_h5[features_path] = features return mem_h5 + return make_abstract_feature_series_h5 @@ -103,21 +103,37 @@ def test_get_roi_ids(data_set): ids = data_set.get_roi_ids() assert len(ids) == len(data_set.get_cell_specimen_ids()) + def test_get_metadata(data_set): md = data_set.get_metadata() - valid_fields = [ 'genotype', 'cre_line', 'imaging_depth_um', 'ophys_experiment_id', 'experiment_container_id', - 'session_start_time', 'age_days', 'device', 'device_name', 'pipeline_version', 'sex', - 'targeted_structure', 'excitation_lambda', 'indicator', 'fov', 'session_type', 'specimen_name' ] - - invalid_fields = [ 'imaging_depth', 'age', 'device_string', 'generated_by' ] + valid_fields = [ + "genotype", + "cre_line", + "imaging_depth_um", + "ophys_experiment_id", + "experiment_container_id", + "session_start_time", + "age_days", + "device", + "device_name", + "pipeline_version", + "sex", + "targeted_structure", + "excitation_lambda", + "indicator", + "fov", + "session_type", + "specimen_name", + ] + + invalid_fields = ["imaging_depth", "age", "device_string", "generated_by"] for field in valid_fields: assert md[field] is not None for field in invalid_fields: assert field not in md - def test_get_cell_specimen_indices(data_set): @@ -174,8 +190,8 @@ def test_get_dff_traces(data_set): timestamps, traces = data_set.get_dff_traces([ids[0]]) assert traces.shape[0] == 1 -def test_get_neuropil_r(data_set): +def test_get_neuropil_r(data_set): ids = data_set.get_cell_specimen_ids() r = data_set.get_neuropil_r() assert len(ids) == len(r) @@ -187,6 +203,7 @@ def test_get_neuropil_r(data_set): r = data_set.get_neuropil_r(short_list) assert len(short_list) == len(r) + def test_get_corrected_fluorescence_traces(data_set): ids = data_set.get_cell_specimen_ids() @@ -231,7 +248,6 @@ def test_get_roi_mask_array(data_set): def test_get_stimulus_epoch_table(data_set): - summary_df = data_set.get_stimulus_epoch_table() session_type = data_set.get_session_type() @@ -242,11 +258,11 @@ def test_get_stimulus_epoch_table(data_set): elif session_type == si.THREE_SESSION_C2: assert len(summary_df) == 10 else: - raise NotImplementedError('Code not tested for session of type: %s' % session_type) + raise NotImplementedError("Code not tested for session of type: %s" % session_type) -def test_get_stimulus_table_master(data_set): - master_df = data_set.get_stimulus_table('master') +def test_get_stimulus_table_master(data_set): + master_df = data_set.get_stimulus_table("master") session_type = data_set.get_session_type() if session_type == si.THREE_SESSION_A: @@ -258,97 +274,89 @@ def test_get_stimulus_table_master(data_set): elif session_type == si.THREE_SESSION_C2: assert len(master_df) == 29398 else: - raise NotImplementedError('Code not tested for session of type: %s' % session_type) + raise NotImplementedError("Code not tested for session of type: %s" % session_type) def test_make_indexed_time_series_stimulus_table(): - frame_dur_exp = np.arange(20).reshape((10, 2)) inds_exp = np.arange(10) obt = bonds._make_indexed_time_series_stimulus_table(inds_exp, frame_dur_exp) - frame_dur_obt = np.array([ obt['start'].values, obt['end'].values ]).T - assert(np.allclose( frame_dur_obt, frame_dur_exp )) + frame_dur_obt = np.array([obt["start"].values, obt["end"].values]).T + assert np.allclose(frame_dur_obt, frame_dur_exp) def test_make_indexed_time_series_stimulus_table_out_of_order(): - frame_dur_exp = np.arange(20).reshape((10, 2)) frame_dur_file = frame_dur_exp.copy()[::-1, :] inds_exp = np.arange(10) obt = bonds._make_indexed_time_series_stimulus_table(inds_exp, frame_dur_file) - frame_dur_obt = np.array([ obt['start'].values, obt['end'].values ]).T - assert(np.allclose( frame_dur_obt, frame_dur_exp )) + frame_dur_obt = np.array([obt["start"].values, obt["end"].values]).T + assert np.allclose(frame_dur_obt, frame_dur_exp) def test_make_abstract_feature_series_stimulus_table_out_of_order(): - frame_dur_exp = np.arange(20).reshape((10, 2)) frame_dur_file = frame_dur_exp.copy()[::-1, :] - features_exp = ['orientation', 'spatial_frequency', 'phase'] + features_exp = ["orientation", "spatial_frequency", "phase"] data_exp = np.arange(30).reshape((10, 3)) data_file = data_exp.copy()[::-1, :] obt = bonds._make_abstract_feature_series_stimulus_table(data_file, features_exp, frame_dur_file) - frame_dur_obt = np.array([ obt['start'].values, obt['end'].values ]).T - assert(np.allclose( frame_dur_obt, frame_dur_exp )) + frame_dur_obt = np.array([obt["start"].values, obt["end"].values]).T + assert np.allclose(frame_dur_obt, frame_dur_exp) - data_obt = np.array([ obt['orientation'].values, obt['spatial_frequency'].values, obt['phase'].values ]).T - assert(np.allclose( data_obt, data_exp )) + data_obt = np.array([obt["orientation"].values, obt["spatial_frequency"].values, obt["phase"].values]).T + assert np.allclose(data_obt, data_exp) def test_make_spontanous_activity_stimulus_table(): - table_values_exp = [[0, 2], [4, 6]] frame_dur = np.arange(8).reshape((4, 2)) - events = np.array([ 1, -1, 1, -1 ]) + events = np.array([1, -1, 1, -1]) obt = bonds._make_spontaneous_activity_stimulus_table(events, frame_dur) - assert(np.allclose( obt.values, table_values_exp )) + assert np.allclose(obt.values, table_values_exp) def test_make_repeated_indexed_time_series_stimulus_table(): - frame_dur_exp = np.arange(20).reshape((10, 2)) inds_exp = np.array([0, 1, 2, 3, 4, 0, 1, 2, 3, 4]) repeats_exp = np.array([0] * 5 + [1] * 5) - + obt = bonds._make_repeated_indexed_time_series_stimulus_table(inds_exp, frame_dur_exp) - frame_dur_obt = np.array([ obt['start'].values, obt['end'].values ]).T - assert(np.allclose( frame_dur_obt, frame_dur_exp )) - assert(np.allclose( repeats_exp, obt['repeat'] )) + frame_dur_obt = np.array([obt["start"].values, obt["end"].values]).T + assert np.allclose(frame_dur_obt, frame_dur_exp) + assert np.allclose(repeats_exp, obt["repeat"]) def test_find_stimulus_presentation_group(stim_pres_h5): - - stimulus_name = 'fish' + stimulus_name = "fish" stim_pres_h5 = stim_pres_h5(stimulus_name) obt = bonds._find_stimulus_presentation_group(stim_pres_h5, stimulus_name) - assert( obt.name == '/stimulus/presentation/fish' ) + assert obt.name == "/stimulus/presentation/fish" def test_find_stimulus_presentation_group_missing(stim_pres_h5): - - stimulus_name = 'fish' - stim_pres_h5 = stim_pres_h5('fowl') + stimulus_name = "fish" + stim_pres_h5 = stim_pres_h5("fowl") with pytest.raises(MissingStimulusException): bonds._find_stimulus_presentation_group(stim_pres_h5, stimulus_name) def test_find_stimulus_presentation_group_duplicate(stim_pres_h5): - - stimulus_name = 'fish' - stim_pres_h5 = stim_pres_h5('fish') - stim_pres_h5.create_group('/stimulus/presentation/fish_stimulus') + stimulus_name = "fish" + stim_pres_h5 = stim_pres_h5("fish") + stim_pres_h5.create_group("/stimulus/presentation/fish_stimulus") with pytest.raises(MissingStimulusException): bonds._find_stimulus_presentation_group(stim_pres_h5, stimulus_name) diff --git a/allensdk/test/core/test_cell_filters.py b/allensdk/test/core/test_cell_filters.py index 6385db40eb..ae1008b634 100644 --- a/allensdk/test/core/test_cell_filters.py +++ b/allensdk/test/core/test_cell_filters.py @@ -40,10 +40,8 @@ from zipfile import ZipFile from unittest.mock import patch, mock_open, MagicMock from test_brain_observatory_cache import CACHE_MANIFEST -from allensdk.core.brain_observatory_cache \ - import BrainObservatoryCache -from allensdk.api.queries.brain_observatory_api \ - import BrainObservatoryApi +from allensdk.core.brain_observatory_cache import BrainObservatoryCache +from allensdk.api.queries.brain_observatory_api import BrainObservatoryApi try: @@ -51,94 +49,99 @@ except ModuleNotFoundError: import builtins # @UnresolvedImport -CELL_SPECIMEN_ZIP_URL = ("http://observatory.brain-map.org/visualcoding/" - "data/cell_metrics.csv.zip") +CELL_SPECIMEN_ZIP_URL = "http://observatory.brain-map.org/visualcoding/data/cell_metrics.csv.zip" @pytest.fixture def cells(): - return [{u'tld1_id': 177839004, - u'natural_movie_two_small': None, - u'natural_movie_one_a_small': None, - u'speed_tuning_c_large': None, - u'speed_tuning_c_small': None, - u'drifting_grating_small': None, - u'tld1_name': u'Cux2-CreERT2', - u'imaging_depth': 275, - u'tlr1_id': 265943423, - u'pref_dir_dg': None, - u'osi_sg': 0.728589701688166, - u'osi_dg': None, - u'tlr1_name': u'Ai93(TITL-GCaMP6f)', - u'area': u'VISpm', - u'pref_image_ns': 89.0, - u'natural_movie_one_c_small': None, - u'locally_sparse_noise_on_small': None, - u'drifting_grating_large': None, - u'experiment_container_id': 511498500, - u'natural_movie_one_a_large': None, - u'natural_movie_one_c_large': None, - u'tld2_name': u'Camk2a-tTA', - u'p_ns': 2.64407299505246e-05, - u'natural_movie_three_large': None, - u'pref_ori_sg': 30.0, - u'speed_tuning_a_large': None, - u'p_dg': None, - u'time_to_peak_sg': 0.199499999999999, - u'p_sg': 7.60972815250796e-05, - u'time_to_peak_ns': 0.299249999999998, - u'locally_sparse_noise_on_large': None, - u'dsi_dg': None, - u'pref_tf_dg': None, - u'natural_movie_three_small': None, - u'pref_sf_sg': 0.32, - u'tld2_id': 177837320, - u'locally_sparse_noise_off_large': None, - u'locally_sparse_noise_off_small': None, - u'cell_specimen_id': 517394843, - u'pref_phase_sg': 0.5}, - {u'tld1_id': 177839004, - u'natural_movie_two_small': None, - u'natural_movie_one_a_small': None, - u'speed_tuning_c_large': None, - u'speed_tuning_c_small': None, - u'drifting_grating_small': None, - u'tld1_name': u'Cux2-CreERT2', - u'imaging_depth': 275, - u'tlr1_id': 265943423, - u'natural_movie_two_large': None, - u'speed_tuning_a_small': None, - u'pref_dir_dg': None, - u'osi_sg': 0.899272239777491, - u'osi_dg': None, - u'tlr1_name': u'Ai93(TITL-GCaMP6f)', - u'area': u'VISpm', - u'pref_image_ns': 15.0, - u'natural_movie_one_c_small': None, - u'locally_sparse_noise_on_small': None, - u'drifting_grating_large': None, - u'experiment_container_id': 511498500, - u'natural_movie_one_a_large': None, - u'natural_movie_one_c_large': None, - u'tld2_name': u'Camk2a-tTA', - u'p_ns': 0.000356823517642681, - u'natural_movie_three_large': None, - u'pref_ori_sg': 0.0, - u'speed_tuning_a_large': None, - u'p_dg': None, - u'time_to_peak_sg': 0.565249999999996, - u'p_sg': 0.0565790644804479, - u'time_to_peak_ns': 0.432249999999997, - u'locally_sparse_noise_on_large': None, - u'dsi_dg': None, - u'pref_tf_dg': None, - u'natural_movie_three_small': None, - u'pref_sf_sg': 0.32, - u'tld2_id': 177837320, - u'locally_sparse_noise_off_large': None, - u'locally_sparse_noise_off_small': None, - u'cell_specimen_id': 517394850, - u'pref_phase_sg': 0.5}] + return [ + { + "tld1_id": 177839004, + "natural_movie_two_small": None, + "natural_movie_one_a_small": None, + "speed_tuning_c_large": None, + "speed_tuning_c_small": None, + "drifting_grating_small": None, + "tld1_name": "Cux2-CreERT2", + "imaging_depth": 275, + "tlr1_id": 265943423, + "pref_dir_dg": None, + "osi_sg": 0.728589701688166, + "osi_dg": None, + "tlr1_name": "Ai93(TITL-GCaMP6f)", + "area": "VISpm", + "pref_image_ns": 89.0, + "natural_movie_one_c_small": None, + "locally_sparse_noise_on_small": None, + "drifting_grating_large": None, + "experiment_container_id": 511498500, + "natural_movie_one_a_large": None, + "natural_movie_one_c_large": None, + "tld2_name": "Camk2a-tTA", + "p_ns": 2.64407299505246e-05, + "natural_movie_three_large": None, + "pref_ori_sg": 30.0, + "speed_tuning_a_large": None, + "p_dg": None, + "time_to_peak_sg": 0.199499999999999, + "p_sg": 7.60972815250796e-05, + "time_to_peak_ns": 0.299249999999998, + "locally_sparse_noise_on_large": None, + "dsi_dg": None, + "pref_tf_dg": None, + "natural_movie_three_small": None, + "pref_sf_sg": 0.32, + "tld2_id": 177837320, + "locally_sparse_noise_off_large": None, + "locally_sparse_noise_off_small": None, + "cell_specimen_id": 517394843, + "pref_phase_sg": 0.5, + }, + { + "tld1_id": 177839004, + "natural_movie_two_small": None, + "natural_movie_one_a_small": None, + "speed_tuning_c_large": None, + "speed_tuning_c_small": None, + "drifting_grating_small": None, + "tld1_name": "Cux2-CreERT2", + "imaging_depth": 275, + "tlr1_id": 265943423, + "natural_movie_two_large": None, + "speed_tuning_a_small": None, + "pref_dir_dg": None, + "osi_sg": 0.899272239777491, + "osi_dg": None, + "tlr1_name": "Ai93(TITL-GCaMP6f)", + "area": "VISpm", + "pref_image_ns": 15.0, + "natural_movie_one_c_small": None, + "locally_sparse_noise_on_small": None, + "drifting_grating_large": None, + "experiment_container_id": 511498500, + "natural_movie_one_a_large": None, + "natural_movie_one_c_large": None, + "tld2_name": "Camk2a-tTA", + "p_ns": 0.000356823517642681, + "natural_movie_three_large": None, + "pref_ori_sg": 0.0, + "speed_tuning_a_large": None, + "p_dg": None, + "time_to_peak_sg": 0.565249999999996, + "p_sg": 0.0565790644804479, + "time_to_peak_ns": 0.432249999999997, + "locally_sparse_noise_on_large": None, + "dsi_dg": None, + "pref_tf_dg": None, + "natural_movie_three_small": None, + "pref_sf_sg": 0.32, + "tld2_id": 177837320, + "locally_sparse_noise_off_large": None, + "locally_sparse_noise_off_small": None, + "cell_specimen_id": 517394850, + "pref_phase_sg": 0.5, + }, + ] @pytest.fixture @@ -160,15 +163,12 @@ def unmocked_boc(fn_temp_dir): def brain_observatory_cache(fn_temp_dir): boc = None - manifest_data = bytes(CACHE_MANIFEST, 'UTF-8') + manifest_data = bytes(CACHE_MANIFEST, "UTF-8") - with patch('os.path.exists', - return_value=True): - with patch(builtins.__name__ + ".open", - mock_open(read_data=manifest_data)): + with patch("os.path.exists", return_value=True): + with patch(builtins.__name__ + ".open", mock_open(read_data=manifest_data)): manifest_file = os.path.join(fn_temp_dir, "boc", "manifest.json") - boc = BrainObservatoryCache(manifest_file=manifest_file, - base_uri='http://api.brain-map.org') + boc = BrainObservatoryCache(manifest_file=manifest_file, base_uri="http://api.brain-map.org") return boc @@ -181,8 +181,7 @@ def cell_specimen_table(tmpdir_factory): data_dir = str(tmpdir_factory.mktemp("data")) zipped = os.path.join("cell_specimens.zip") api.retrieve_file_over_http(CELL_SPECIMEN_ZIP_URL, zipped) - df = pd.read_csv(ZipFile(zipped).open("cell_metrics.csv"), - true_values=["t"], false_values=["f"]) + df = pd.read_csv(ZipFile(zipped).open("cell_metrics.csv"), true_values=["t"], false_values=["f"]) js = json.loads(df.to_json(orient="records")) table_file = os.path.join(data_dir, "cell_specimens.json") with open(table_file, "w") as f: @@ -192,68 +191,52 @@ def cell_specimen_table(tmpdir_factory): @pytest.fixture def example_filters(): - f = [{"field": "p_dg", - "op": "<=", - "value": 0.001}, - {"field": "pref_dir_dg", - "op": "=", "value": 45}, - {"field": "area", "op": "in", "value": ["VISpm"]}, - {"field": "tld1_name", - "op": "in", - "value": ["Rbp4-Cre", "Cux2-CreERT2", "Rorb-IRES2-Cre"]}] + f = [ + {"field": "p_dg", "op": "<=", "value": 0.001}, + {"field": "pref_dir_dg", "op": "=", "value": 45}, + {"field": "area", "op": "in", "value": ["VISpm"]}, + {"field": "tld1_name", "op": "in", "value": ["Rbp4-Cre", "Cux2-CreERT2", "Rorb-IRES2-Cre"]}, + ] return f @pytest.fixture def between_filter(): - f = [{"field": "p_ns", - "op": "between", - "value": [0.00034, 0.00035]}] + f = [{"field": "p_ns", "op": "between", "value": [0.00034, 0.00035]}] return f FILTER_OPERATORS = ["=", "<", ">", "<=", ">=", "between", "in", "is"] QUERY_TEMPLATES = { - "=": '({0} == {1})', - "<": '({0} < {1})', - ">": '({0} > {1})', - "<=": '({0} <= {1})', - ">=": '({0} >= {1})', - "between": '({0} >= {1}) and ({0} <= {1})', - "in": '({0} == {1})', - "is": '({0} == {1})' + "=": "({0} == {1})", + "<": "({0} < {1})", + ">": "({0} > {1})", + "<=": "({0} <= {1})", + ">=": "({0} >= {1})", + "between": "({0} >= {1}) and ({0} <= {1})", + "in": "({0} == {1})", + "is": "({0} == {1})", } @pytest.mark.skipif(True, reason="not done") @patch.object(BrainObservatoryApi, "json_msg_query") -def test_dataframe_query(mock_json_msg_query, - brain_observatory_cache, - between_filter, - cells): +def test_dataframe_query(mock_json_msg_query, brain_observatory_cache, between_filter, cells): brain_observatory_cache = unmocked_boc - with patch('os.path.exists', - MagicMock(return_value=True)): - with patch('allensdk.core.json_utilities.read', - MagicMock(return_value=cells)): - cells = brain_observatory_cache.get_cell_specimens( - filters=between_filter) + with patch("os.path.exists", MagicMock(return_value=True)): + with patch("allensdk.core.json_utilities.read", MagicMock(return_value=cells)): + cells = brain_observatory_cache.get_cell_specimens(filters=between_filter) assert len(cells) > 0 @pytest.mark.todo_flaky -def test_dataframe_query_unmocked(unmocked_boc, - example_filters, - cells, - cell_specimen_table): +def test_dataframe_query_unmocked(unmocked_boc, example_filters, cells, cell_specimen_table): brain_observatory_cache = unmocked_boc - cells = brain_observatory_cache.get_cell_specimens( - filters=example_filters, - file_name=cell_specimen_table) + cells = brain_observatory_cache.get_cell_specimens(filters=example_filters, file_name=cell_specimen_table) # total lines = 18260, can make fail by passing no filters # expected = 105 @@ -261,15 +244,10 @@ def test_dataframe_query_unmocked(unmocked_boc, @pytest.mark.todo_flaky -def test_dataframe_query_between_unmocked(unmocked_boc, - between_filter, - cells, - cell_specimen_table): +def test_dataframe_query_between_unmocked(unmocked_boc, between_filter, cells, cell_specimen_table): brain_observatory_cache = unmocked_boc - cells = brain_observatory_cache.get_cell_specimens( - filters=between_filter, - file_name=cell_specimen_table) + cells = brain_observatory_cache.get_cell_specimens(filters=between_filter, file_name=cell_specimen_table) # total lines = 18260, can make fail by passing no filters # expected = 15 @@ -277,39 +255,26 @@ def test_dataframe_query_between_unmocked(unmocked_boc, @pytest.mark.todo_flaky -def test_dataframe_query_is_unmocked(unmocked_boc, - cells, - cell_specimen_table): +def test_dataframe_query_is_unmocked(unmocked_boc, cells, cell_specimen_table): brain_observatory_cache = unmocked_boc - is_filter = [ - {"field": "all_stim", - "op": "is", - "value": True}] + is_filter = [{"field": "all_stim", "op": "is", "value": True}] - cells = brain_observatory_cache.get_cell_specimens( - filters=is_filter, - file_name=cell_specimen_table) + cells = brain_observatory_cache.get_cell_specimens(filters=is_filter, file_name=cell_specimen_table) assert len(cells) > 0 def test_dataframe_query_string_between(api): - filters = [ - {"field": "p_ns", - "op": "between", - "value": [0.00034, 0.00035]}] + filters = [{"field": "p_ns", "op": "between", "value": [0.00034, 0.00035]}] query_string = api.dataframe_query_string(filters) - assert query_string == '(p_ns >= 0.00034) and (p_ns <= 0.00035)' + assert query_string == "(p_ns >= 0.00034) and (p_ns <= 0.00035)" def test_dataframe_query_string_in(api): - filters = [ - {"field": "name", - "op": "in", - "value": ['Abc', 'Def', 'Ghi']}] + filters = [{"field": "name", "op": "in", "value": ["Abc", "Def", "Ghi"]}] query_string = api.dataframe_query_string(filters) @@ -317,10 +282,7 @@ def test_dataframe_query_string_in(api): def test_dataframe_query_string_in_floats(api): - filters = [ - {"field": "rating", - "op": "in", - "value": [9.9, 8.7, 0.1]}] + filters = [{"field": "rating", "op": "in", "value": [9.9, 8.7, 0.1]}] query_string = api.dataframe_query_string(filters) @@ -328,21 +290,19 @@ def test_dataframe_query_string_in_floats(api): def test_dataframe_query_string_is_boolean(api): - filters = [ - {"field": "fact_check", - "op": "is", - "value": False}] + filters = [{"field": "fact_check", "op": "is", "value": False}] query_string = api.dataframe_query_string(filters) assert query_string == "(fact_check == False)" -def test_dataframe_query_string_multi_filters(api, - example_filters): +def test_dataframe_query_string_multi_filters(api, example_filters): query_string = api.dataframe_query_string(example_filters) - assert query_string == ("(p_dg <= 0.001) & (pref_dir_dg == 45) & " - "(area == ['VISpm']) & " - "(tld1_name == " - "['Rbp4-Cre', 'Cux2-CreERT2', 'Rorb-IRES2-Cre'])") + assert query_string == ( + "(p_dg <= 0.001) & (pref_dir_dg == 45) & " + "(area == ['VISpm']) & " + "(tld1_name == " + "['Rbp4-Cre', 'Cux2-CreERT2', 'Rorb-IRES2-Cre'])" + ) diff --git a/allensdk/test/core/test_cell_types_cache_unit.py b/allensdk/test/core/test_cell_types_cache_unit.py index 056f1bc7a0..1dbf291d92 100644 --- a/allensdk/test/core/test_cell_types_cache_unit.py +++ b/allensdk/test/core/test_cell_types_cache_unit.py @@ -43,7 +43,7 @@ import itertools as it import pandas as pd -_MOCK_PATH = '/path/to/xyz.txt' +_MOCK_PATH = "/path/to/xyz.txt" @pytest.fixture(scope="session", autouse=True) @@ -76,28 +76,24 @@ def cache_fixture(tmpdir_factory): return ctc -@pytest.mark.parametrize('path_exists', - (False, True)) -@patch('allensdk.core.cell_types_cache.NwbDataSet') -def test_sweep_data_with_api(mock_nwb, - cache_fixture, - path_exists): +@pytest.mark.parametrize("path_exists", (False, True)) +@patch("allensdk.core.cell_types_cache.NwbDataSet") +def test_sweep_data_with_api(mock_nwb, cache_fixture, path_exists): ctc = cache_fixture specimen_id = 464212183 - ephys_result = [{'ephys_result': - {'well_known_files': [ - {'download_link': '/path/to/data.nwb' }]}}] + ephys_result = [{"ephys_result": {"well_known_files": [{"download_link": "/path/to/data.nwb"}]}}] # this saves the NWB file to 'cell_types/specimen_464212183/ephys.nwb' with patch.object(ctc, "get_cache_path", return_value=_MOCK_PATH): - with patch('allensdk.api.queries.cell_types_api.CellTypesApi.retrieve_file_over_http') as mock_http: - with patch('allensdk.api.queries.cell_types_api.CellTypesApi.model_query', - MagicMock(name='model query', - return_value=ephys_result)) as query_mock: - with patch('os.path.exists', MagicMock(return_value=path_exists)) as ope: - with patch('allensdk.config.manifest.Manifest.safe_make_parent_dirs') as mkd: + with patch("allensdk.api.queries.cell_types_api.CellTypesApi.retrieve_file_over_http") as mock_http: + with patch( + "allensdk.api.queries.cell_types_api.CellTypesApi.model_query", + MagicMock(name="model query", return_value=ephys_result), + ) as query_mock: + with patch("os.path.exists", MagicMock(return_value=path_exists)) as ope: + with patch("allensdk.config.manifest.Manifest.safe_make_parent_dirs") as mkd: mock_nwb.reset_mock() _ = ctc.get_ephys_data(specimen_id, _MOCK_PATH) @@ -110,8 +106,7 @@ def test_sweep_data_with_api(mock_nwb, # both levels of cacheable methods check if the directory exists. assert mkd.call_args_list == [call(_MOCK_PATH)] assert query_mock.called - mock_http.assert_called_once_with('http://api.brain-map.org/path/to/data.nwb', - _MOCK_PATH) + mock_http.assert_called_once_with("http://api.brain-map.org/path/to/data.nwb", _MOCK_PATH) def test_sweep_data_exception(cache_fixture): @@ -119,159 +114,151 @@ def test_sweep_data_exception(cache_fixture): specimen_id = 464212183 - ephys_result = [{'ephys_result': - {'well_known_files': [] }}] + ephys_result = [{"ephys_result": {"well_known_files": []}}] with pytest.raises(Exception) as exc: with patch.object(ctc, "get_cache_path", return_value=_MOCK_PATH): - with patch('allensdk.api.queries.cell_types_api.CellTypesApi.retrieve_file_over_http'): - with patch('allensdk.api.queries.cell_types_api.CellTypesApi.model_query', - MagicMock(name='model query', - return_value=ephys_result)): - with patch('os.path.exists', MagicMock(return_value=False)): - with patch('allensdk.config.manifest.Manifest.safe_make_parent_dirs'): - with patch('allensdk.core.cell_types_cache.NwbDataSet'): + with patch("allensdk.api.queries.cell_types_api.CellTypesApi.retrieve_file_over_http"): + with patch( + "allensdk.api.queries.cell_types_api.CellTypesApi.model_query", + MagicMock(name="model query", return_value=ephys_result), + ): + with patch("os.path.exists", MagicMock(return_value=False)): + with patch("allensdk.config.manifest.Manifest.safe_make_parent_dirs"): + with patch("allensdk.core.cell_types_cache.NwbDataSet"): _ = ctc.get_ephys_data(specimen_id) - - assert 'has no ephys data' in str(exc.value) - - -@pytest.mark.parametrize('path_exists,morph_flag,recon_flag,statuses,species,simple', - it.product((False, True), - (False, True), - (False, True), - (RS.POSITIVE, ['list', 'of', 'statuses']), - (None, ['mouse'], ['human']), - (False,))) - -def test_get_cells(cache_fixture, - path_exists, - morph_flag, - recon_flag, - statuses, - species, - simple): + + assert "has no ephys data" in str(exc.value) + + +@pytest.mark.parametrize( + "path_exists,morph_flag,recon_flag,statuses,species,simple", + it.product( + (False, True), + (False, True), + (False, True), + (RS.POSITIVE, ["list", "of", "statuses"]), + (None, ["mouse"], ["human"]), + (False,), + ), +) +def test_get_cells(cache_fixture, path_exists, morph_flag, recon_flag, statuses, species, simple): ctc = cache_fixture # this downloads metadata for all cells with morphology images with patch.object(ctc, "get_cache_path", return_value=_MOCK_PATH): - with patch('os.path.exists', MagicMock(return_value=path_exists)): - with patch('allensdk.core.json_utilities.read', - return_value=['mock_cells_from_server']): - with patch('allensdk.api.queries.cell_types_api.CellTypesApi.list_cells_api', - MagicMock(return_value=['mock_cells_from_server'])): - with patch('allensdk.api.queries.cell_types_api.CellTypesApi.filter_cells_api', - MagicMock(return_value=['mock_cells'])) as filter_cells_mock: - with patch('allensdk.core.json_utilities.write'): - cells = ctc.get_cells(require_morphology=morph_flag, - require_reconstruction=recon_flag, - reporter_status=statuses, - species=species, - simple=simple) - - assert cells == ['mock_cells'] - - if (statuses == RS.POSITIVE): + with patch("os.path.exists", MagicMock(return_value=path_exists)): + with patch("allensdk.core.json_utilities.read", return_value=["mock_cells_from_server"]): + with patch( + "allensdk.api.queries.cell_types_api.CellTypesApi.list_cells_api", + MagicMock(return_value=["mock_cells_from_server"]), + ): + with patch( + "allensdk.api.queries.cell_types_api.CellTypesApi.filter_cells_api", + MagicMock(return_value=["mock_cells"]), + ) as filter_cells_mock: + with patch("allensdk.core.json_utilities.write"): + cells = ctc.get_cells( + require_morphology=morph_flag, + require_reconstruction=recon_flag, + reporter_status=statuses, + species=species, + simple=simple, + ) + + assert cells == ["mock_cells"] + + if statuses == RS.POSITIVE: expected_status = [statuses] else: expected_status = statuses - filter_cells_mock.assert_called_once_with(['mock_cells_from_server'], - morph_flag, - recon_flag, - expected_status, - species, - simple) - - -@pytest.mark.parametrize('path_exists,morph_flag,recon_flag,statuses', - it.product((False, True), - (False, True), - (False, True), - (RS.POSITIVE, ['list', 'of', 'statuses']))) -def test_get_cells_with_api(cache_fixture, - path_exists, - morph_flag, - recon_flag, - statuses): + filter_cells_mock.assert_called_once_with( + ["mock_cells_from_server"], morph_flag, recon_flag, expected_status, species, simple + ) + + +@pytest.mark.parametrize( + "path_exists,morph_flag,recon_flag,statuses", + it.product((False, True), (False, True), (False, True), (RS.POSITIVE, ["list", "of", "statuses"])), +) +def test_get_cells_with_api(cache_fixture, path_exists, morph_flag, recon_flag, statuses): ctc = cache_fixture # note, this is only a mock for coverage, # and has not a lot of relation to the actual data form sweeps = [1, 2, 3] - return_dicts = [{'sweep_number': x, - 'tags': ['what - ever'], - 'neuron_reconstructions' : [], - 'data_sets': [], - 'reporter_status': 'whatever', - 'has_morphology': False, - 'has_reconstruction': False, - 'donor': { 'transgenic_lines': [{'transgenic_line_type_name': 'driver', - 'name': 'harold'}]}, - 'cell_reporter': {'name': 'tired'}, - 'specimen_tags': [{'name': 'a - b', - 'value': 123}]} for x \ - in sweeps] + return_dicts = [ + { + "sweep_number": x, + "tags": ["what - ever"], + "neuron_reconstructions": [], + "data_sets": [], + "reporter_status": "whatever", + "has_morphology": False, + "has_reconstruction": False, + "donor": {"transgenic_lines": [{"transgenic_line_type_name": "driver", "name": "harold"}]}, + "cell_reporter": {"name": "tired"}, + "specimen_tags": [{"name": "a - b", "value": 123}], + } + for x in sweeps + ] with patch.object(ctc, "get_cache_path", return_value=_MOCK_PATH): - with patch('allensdk.api.queries.cell_types_api.CellTypesApi.model_query', - MagicMock(name='model query', return_value=return_dicts)): - with patch('os.path.exists', MagicMock(return_value=path_exists)): - with patch('allensdk.core.json_utilities.read', - return_value=return_dicts) as ju_read: - with patch('allensdk.core.json_utilities.write') as ju_write: - with patch('allensdk.config.manifest.Manifest.safe_make_parent_dirs'): - ctc.get_cells(require_morphology=morph_flag, - require_reconstruction=recon_flag, - reporter_status=statuses, - simple=True) + with patch( + "allensdk.api.queries.cell_types_api.CellTypesApi.model_query", + MagicMock(name="model query", return_value=return_dicts), + ): + with patch("os.path.exists", MagicMock(return_value=path_exists)): + with patch("allensdk.core.json_utilities.read", return_value=return_dicts) as ju_read: + with patch("allensdk.core.json_utilities.write") as ju_write: + with patch("allensdk.config.manifest.Manifest.safe_make_parent_dirs"): + ctc.get_cells( + require_morphology=morph_flag, + require_reconstruction=recon_flag, + reporter_status=statuses, + simple=True, + ) if path_exists: ju_read.assert_called_once_with(_MOCK_PATH) else: assert ju_write.called -@pytest.mark.parametrize('path_exists', - (False, True)) -def test_get_reconstruction(cache_fixture, - cell_id, - path_exists): + +@pytest.mark.parametrize("path_exists", (False, True)) +def test_get_reconstruction(cache_fixture, cell_id, path_exists): ctc = cache_fixture - save_recon = \ - 'allensdk.api.queries.cell_types_api.CellTypesApi.save_reconstruction' + save_recon = "allensdk.api.queries.cell_types_api.CellTypesApi.save_reconstruction" with patch.object(ctc, "get_cache_path", return_value=_MOCK_PATH): with patch(save_recon) as save_recon_mock: - with patch('allensdk.core.swc.read_swc') as read_swc_mock: + with patch("allensdk.core.swc.read_swc") as read_swc_mock: # download and open an SWC file _ = ctc.get_reconstruction(cell_id) if path_exists is False: - save_recon_mock.assert_called_once_with(cell_id, - _MOCK_PATH) + save_recon_mock.assert_called_once_with(cell_id, _MOCK_PATH) read_swc_mock.assert_called_once_with(_MOCK_PATH) -@pytest.mark.parametrize('path_exists', - (False, True)) +@pytest.mark.parametrize("path_exists", (False, True)) @patch.object(DataFrame, "to_csv") -def test_get_reconstruction_with_api(to_csv, - cache_fixture, - cell_id, - path_exists): +def test_get_reconstruction_with_api(to_csv, cache_fixture, cell_id, path_exists): ctc = cache_fixture - reconstruction_data = [{'neuron_reconstructions': [ - {'well_known_files': [ - {'download_link': 'http://example.org'}]}]}] + reconstruction_data = [ + {"neuron_reconstructions": [{"well_known_files": [{"download_link": "http://example.org"}]}]} + ] with patch.object(ctc, "get_cache_path", return_value=_MOCK_PATH): - with patch('allensdk.api.queries.cell_types_api.CellTypesApi.retrieve_file_over_http'): - with patch('allensdk.api.queries.cell_types_api.CellTypesApi.model_query', - MagicMock(name='model query', - return_value=reconstruction_data)) as query_mock: - with patch('allensdk.core.swc.read_swc') as read_swc_mock: - with patch('os.path.exists', MagicMock(return_value=path_exists)): - with patch('allensdk.config.manifest.Manifest.safe_make_parent_dirs'): + with patch("allensdk.api.queries.cell_types_api.CellTypesApi.retrieve_file_over_http"): + with patch( + "allensdk.api.queries.cell_types_api.CellTypesApi.model_query", + MagicMock(name="model query", return_value=reconstruction_data), + ) as query_mock: + with patch("allensdk.core.swc.read_swc") as read_swc_mock: + with patch("os.path.exists", MagicMock(return_value=path_exists)): + with patch("allensdk.config.manifest.Manifest.safe_make_parent_dirs"): _ = ctc.get_reconstruction(cell_id) if path_exists: @@ -281,57 +268,49 @@ def test_get_reconstruction_with_api(to_csv, @patch.object(DataFrame, "to_csv") -def test_get_reconstruction_exception(to_csv, - cache_fixture, - cell_id): +def test_get_reconstruction_exception(to_csv, cache_fixture, cell_id): ctc = cache_fixture - reconstruction_data = [{'neuron_reconstructions': [ - {'well_known_files': None}]}] + reconstruction_data = [{"neuron_reconstructions": [{"well_known_files": None}]}] with pytest.raises(Exception) as exc: with patch.object(ctc, "get_cache_path", return_value=_MOCK_PATH): - with patch('allensdk.api.queries.cell_types_api.CellTypesApi.retrieve_file_over_http'): - with patch('allensdk.api.queries.cell_types_api.CellTypesApi.model_query', - MagicMock(name='model query', - return_value=reconstruction_data)): - with patch('allensdk.core.swc.read_swc'): - with patch('os.path.exists', MagicMock(return_value=False)): - with patch('allensdk.config.manifest.Manifest.safe_make_parent_dirs'): + with patch("allensdk.api.queries.cell_types_api.CellTypesApi.retrieve_file_over_http"): + with patch( + "allensdk.api.queries.cell_types_api.CellTypesApi.model_query", + MagicMock(name="model query", return_value=reconstruction_data), + ): + with patch("allensdk.core.swc.read_swc"): + with patch("os.path.exists", MagicMock(return_value=False)): + with patch("allensdk.config.manifest.Manifest.safe_make_parent_dirs"): _ = ctc.get_reconstruction(cell_id) - assert 'has no reconstruction' in str(exc.value) + assert "has no reconstruction" in str(exc.value) -@pytest.mark.parametrize('path_exists,lookup_error', - it.product((False, True), - (False, True))) -def test_get_reconstruction_markers(cache_fixture, - cell_id, - path_exists, - lookup_error): +@pytest.mark.parametrize("path_exists,lookup_error", it.product((False, True), (False, True))) +def test_get_reconstruction_markers(cache_fixture, cell_id, path_exists, lookup_error): ctc = cache_fixture if lookup_error: + def lookup(i, n): - raise(LookupError('mock lookup error')) + raise (LookupError("mock lookup error")) else: + def lookup(i, n): return - save_recon_marker = \ - 'allensdk.api.queries.cell_types_api.CellTypesApi.save_reconstruction_markers' + save_recon_marker = "allensdk.api.queries.cell_types_api.CellTypesApi.save_reconstruction_markers" # download and open a marker file with patch.object(ctc, "get_cache_path", return_value=_MOCK_PATH): - with patch(save_recon_marker, - MagicMock(side_effect=lookup)) as save_recon_markers_mock: - with patch('allensdk.core.swc.read_marker_file') as read_marker_mock: + with patch(save_recon_marker, MagicMock(side_effect=lookup)) as save_recon_markers_mock: + with patch("allensdk.core.swc.read_marker_file") as read_marker_mock: _ = ctc.get_reconstruction_markers(cell_id) if path_exists is False: - save_recon_markers_mock.assert_called_once_with(cell_id, - _MOCK_PATH) + save_recon_markers_mock.assert_called_once_with(cell_id, _MOCK_PATH) if lookup_error: assert not read_marker_mock.called @@ -339,64 +318,55 @@ def lookup(i, n): read_marker_mock.assert_called_once_with(_MOCK_PATH) -@pytest.mark.parametrize('path_exists,lookup_error', - it.product((False, True), - (False, True))) -def test_get_reconstruction_markers_with_api(cache_fixture, - cell_id, - path_exists, - lookup_error): +@pytest.mark.parametrize("path_exists,lookup_error", it.product((False, True), (False, True))) +def test_get_reconstruction_markers_with_api(cache_fixture, cell_id, path_exists, lookup_error): ctc = cache_fixture - reconstruction_data = [{'neuron_reconstructions': [ - {'well_known_files': [ - {'download_link': '/mock/path_to_file'}]}]}] + reconstruction_data = [ + {"neuron_reconstructions": [{"well_known_files": [{"download_link": "/mock/path_to_file"}]}]} + ] with patch.object(ctc, "get_cache_path", return_value=_MOCK_PATH): - with patch('allensdk.api.queries.cell_types_api.CellTypesApi.retrieve_file_over_http') as mock_http: - with patch('allensdk.api.queries.cell_types_api.CellTypesApi.model_query', - MagicMock(name='model query', - return_value=reconstruction_data)): - with patch('allensdk.core.swc.read_marker_file') as marker_mock: - with patch('os.path.exists', MagicMock(return_value=path_exists)): - with patch('allensdk.config.manifest.Manifest.safe_make_parent_dirs'): + with patch("allensdk.api.queries.cell_types_api.CellTypesApi.retrieve_file_over_http") as mock_http: + with patch( + "allensdk.api.queries.cell_types_api.CellTypesApi.model_query", + MagicMock(name="model query", return_value=reconstruction_data), + ): + with patch("allensdk.core.swc.read_marker_file") as marker_mock: + with patch("os.path.exists", MagicMock(return_value=path_exists)): + with patch("allensdk.config.manifest.Manifest.safe_make_parent_dirs"): _ = ctc.get_reconstruction_markers(cell_id) if path_exists: assert marker_mock.called else: - mock_http.assert_called_once_with('http://api.brain-map.org/mock/path_to_file', - _MOCK_PATH) + mock_http.assert_called_once_with("http://api.brain-map.org/mock/path_to_file", _MOCK_PATH) -def test_get_reconstruction_markers_exception(cache_fixture, - cell_id): +def test_get_reconstruction_markers_exception(cache_fixture, cell_id): ctc = cache_fixture - reconstruction_data = [{'neuron_reconstructions': [ - {'well_known_files': []}]}] + reconstruction_data = [{"neuron_reconstructions": [{"well_known_files": []}]}] with patch.object(ctc, "get_cache_path", return_value=_MOCK_PATH): - with patch('allensdk.api.queries.cell_types_api.CellTypesApi.retrieve_file_over_http'): - with patch('allensdk.api.queries.cell_types_api.CellTypesApi.model_query', - MagicMock(name='model query', - return_value=reconstruction_data)): - with patch('allensdk.core.swc.read_marker_file'): - with patch('os.path.exists', MagicMock(return_value=False)): - with patch('allensdk.config.manifest.Manifest.safe_make_parent_dirs'): + with patch("allensdk.api.queries.cell_types_api.CellTypesApi.retrieve_file_over_http"): + with patch( + "allensdk.api.queries.cell_types_api.CellTypesApi.model_query", + MagicMock(name="model query", return_value=reconstruction_data), + ): + with patch("allensdk.core.swc.read_marker_file"): + with patch("os.path.exists", MagicMock(return_value=False)): + with patch("allensdk.config.manifest.Manifest.safe_make_parent_dirs"): markers = ctc.get_reconstruction_markers(cell_id) assert len(markers) == 0 -@pytest.mark.parametrize('dataframe', - (False, True)) -def test_get_ephys_features(cache_fixture, - dataframe): +@pytest.mark.parametrize("dataframe", (False, True)) +def test_get_ephys_features(cache_fixture, dataframe): ctc = cache_fixture - api_get_ephys_features = \ - 'allensdk.api.queries.cell_types_api.CellTypesApi.get_ephys_features' + api_get_ephys_features = "allensdk.api.queries.cell_types_api.CellTypesApi.get_ephys_features" with patch.object(ctc, "get_cache_path", return_value=_MOCK_PATH): with patch(api_get_ephys_features) as api_get_ephys_features_mock: @@ -406,32 +376,22 @@ def test_get_ephys_features(cache_fixture, assert api_get_ephys_features_mock.called -@pytest.mark.parametrize('df,path_exists', - it.product((False,True), - (False,True))) +@pytest.mark.parametrize("df,path_exists", it.product((False, True), (False, True))) @patch.object(DataFrame, "to_csv") @patch("pandas.read_csv") -def test_get_ephys_features_with_api(read_csv, - to_csv, - cache_fixture, - df, - path_exists): +def test_get_ephys_features_with_api(read_csv, to_csv, cache_fixture, df, path_exists): ctc = cache_fixture - mock_data = [{'lorem': 1, - 'ipsum': 2 }, - {'lorem': 3, - 'ipsum': 4 }] + mock_data = [{"lorem": 1, "ipsum": 2}, {"lorem": 3, "ipsum": 4}] with patch.object(ctc, "get_cache_path", return_value=_MOCK_PATH): - with patch('allensdk.api.queries.cell_types_api.CellTypesApi.model_query', - MagicMock(name='model query', - return_value=mock_data)) as query_mock: - with patch('os.path.exists', MagicMock(return_value=path_exists)): - with patch(builtins.__name__ + '.open', - mock_open(), - create=True): - with patch('allensdk.config.manifest.Manifest.safe_make_parent_dirs') as mkd: + with patch( + "allensdk.api.queries.cell_types_api.CellTypesApi.model_query", + MagicMock(name="model query", return_value=mock_data), + ) as query_mock: + with patch("os.path.exists", MagicMock(return_value=path_exists)): + with patch(builtins.__name__ + ".open", mock_open(), create=True): + with patch("allensdk.config.manifest.Manifest.safe_make_parent_dirs") as mkd: _ = ctc.get_ephys_features(dataframe=df) if path_exists: @@ -441,21 +401,17 @@ def test_get_ephys_features_with_api(read_csv, assert query_mock.called -@pytest.mark.parametrize('df', (False, True)) -def test_get_ephys_features_cache_roundtrip(cached_csv, - cache_fixture, - df): +@pytest.mark.parametrize("df", (False, True)) +def test_get_ephys_features_cache_roundtrip(cached_csv, cache_fixture, df): ctc = cache_fixture - mock_data = [{'lorem': 1, - 'ipsum': 2 }, - {'lorem': 3, - 'ipsum': 4 }] + mock_data = [{"lorem": 1, "ipsum": 2}, {"lorem": 3, "ipsum": 4}] with patch.object(ctc, "get_cache_path", return_value=cached_csv): - with patch('allensdk.api.queries.cell_types_api.CellTypesApi.model_query', - MagicMock(name='model query', - return_value=mock_data)): + with patch( + "allensdk.api.queries.cell_types_api.CellTypesApi.model_query", + MagicMock(name="model query", return_value=mock_data), + ): data = ctc.get_ephys_features() pandas_data = pd.read_csv(cached_csv, parse_dates=True) @@ -463,40 +419,29 @@ def test_get_ephys_features_cache_roundtrip(cached_csv, assert sorted(data[0].keys()) == sorted(pandas_data.columns) -@pytest.mark.parametrize('path_exists,df', - it.product((False, True), - (False, True))) +@pytest.mark.parametrize("path_exists,df", it.product((False, True), (False, True))) @patch.object(DataFrame, "to_csv") -@patch("pandas.read_csv", - return_value=DataFrame([{ 'stuff': 'whatever'}, - { 'stuff': 'nonsense'}])) -def test_get_morphology_features(read_csv, - to_csv, - cache_fixture, - path_exists, - df): +@patch("pandas.read_csv", return_value=DataFrame([{"stuff": "whatever"}, {"stuff": "nonsense"}])) +def test_get_morphology_features(read_csv, to_csv, cache_fixture, path_exists, df): ctc = cache_fixture - json_data = [{ 'stuff': 'whatever'}, - { 'stuff': 'nonsense'}] - + json_data = [{"stuff": "whatever"}, {"stuff": "nonsense"}] + with patch.object(ctc, "get_cache_path", return_value=_MOCK_PATH): - with patch('os.path.exists', MagicMock(return_value=path_exists)): - with patch('allensdk.config.manifest.Manifest.safe_make_parent_dirs') as mkd: - with patch(builtins.__name__ + '.open', - mock_open(), - create=True): - with patch('allensdk.api.queries.cell_types_api.CellTypesApi.model_query', - MagicMock(name='model query', - return_value=json_data)) as query_mock: + with patch("os.path.exists", MagicMock(return_value=path_exists)): + with patch("allensdk.config.manifest.Manifest.safe_make_parent_dirs") as mkd: + with patch(builtins.__name__ + ".open", mock_open(), create=True): + with patch( + "allensdk.api.queries.cell_types_api.CellTypesApi.model_query", + MagicMock(name="model query", return_value=json_data), + ) as query_mock: data = ctc.get_morphology_features(df, _MOCK_PATH) if df: - assert ('stuff' in data) + assert "stuff" in data else: - assert all(['stuff' in f for f in data]) + assert all(["stuff" in f for f in data]) - if path_exists: if df: read_csv.assert_called_once_with(_MOCK_PATH, parse_dates=True) @@ -508,47 +453,41 @@ def test_get_morphology_features(read_csv, assert mkd.called -@pytest.mark.parametrize('path_exists', - (False, True)) -def test_get_ephys_sweeps(cache_fixture, - path_exists): +@pytest.mark.parametrize("path_exists", (False, True)) +def test_get_ephys_sweeps(cache_fixture, path_exists): ctc = cache_fixture cell_id = 464212183 - get_ephys_sweeps = \ - 'allensdk.api.queries.cell_types_api.CellTypesApi.get_ephys_sweeps' + get_ephys_sweeps = "allensdk.api.queries.cell_types_api.CellTypesApi.get_ephys_sweeps" with patch.object(ctc, "get_cache_path", return_value=_MOCK_PATH): with patch(get_ephys_sweeps) as get_ephys_sweeps_mock: - with patch('os.path.exists', MagicMock(return_value=path_exists)): - with patch('allensdk.core.json_utilities.read', - return_value=['mock_data']): - with patch('allensdk.core.json_utilities.write'): + with patch("os.path.exists", MagicMock(return_value=path_exists)): + with patch("allensdk.core.json_utilities.read", return_value=["mock_data"]): + with patch("allensdk.core.json_utilities.write"): _ = ctc.get_ephys_sweeps(cell_id) if not path_exists: get_ephys_sweeps_mock.assert_called_once() -@pytest.mark.parametrize('path_exists', - (False, True)) -def test_get_ephys_sweeps_with_api(cache_fixture, - path_exists): +@pytest.mark.parametrize("path_exists", (False, True)) +def test_get_ephys_sweeps_with_api(cache_fixture, path_exists): ctc = cache_fixture cell_id = 464212183 sweeps = [1, 2, 3] - return_dicts = [{'sweep_number': x} for x in sweeps] + return_dicts = [{"sweep_number": x} for x in sweeps] with patch.object(ctc, "get_cache_path", return_value=_MOCK_PATH): - with patch('allensdk.api.queries.cell_types_api.CellTypesApi.model_query', - MagicMock(name='model query', - return_value=return_dicts)) as query_mock: - with patch('os.path.exists', MagicMock(return_value=path_exists)): - with patch('allensdk.config.manifest.Manifest.safe_make_parent_dirs'): - with patch('allensdk.core.json_utilities.read', - return_value=['mock_data']) as ju_read: - with patch('allensdk.core.json_utilities.write'): + with patch( + "allensdk.api.queries.cell_types_api.CellTypesApi.model_query", + MagicMock(name="model query", return_value=return_dicts), + ) as query_mock: + with patch("os.path.exists", MagicMock(return_value=path_exists)): + with patch("allensdk.config.manifest.Manifest.safe_make_parent_dirs"): + with patch("allensdk.core.json_utilities.read", return_value=["mock_data"]) as ju_read: + with patch("allensdk.core.json_utilities.write"): _ = ctc.get_ephys_sweeps(cell_id) # read will be called regardless @@ -560,58 +499,44 @@ def test_get_ephys_sweeps_with_api(cache_fixture, assert query_mock.called -@pytest.mark.parametrize('path_exists,require_reconstruction', - it.product((False, True), - (False, True))) -@patch('pandas.DataFrame.merge') +@pytest.mark.parametrize("path_exists,require_reconstruction", it.product((False, True), (False, True))) +@patch("pandas.DataFrame.merge") @patch.object(DataFrame, "to_csv") -@patch("pandas.read_csv", - return_value=DataFrame([{ 'stuff': 'whatever'}, - { 'stuff': 'nonsense'}])) -def test_get_all_features(read_csv, - to_csv, - mock_merge, - cache_fixture, - path_exists, - require_reconstruction): +@patch("pandas.read_csv", return_value=DataFrame([{"stuff": "whatever"}, {"stuff": "nonsense"}])) +def test_get_all_features(read_csv, to_csv, mock_merge, cache_fixture, path_exists, require_reconstruction): ctc = cache_fixture sweeps = [1, 2, 3] - return_dicts = [{'sweep_number': x, - 'tags': 'whatever'} for x in sweeps] + return_dicts = [{"sweep_number": x, "tags": "whatever"} for x in sweeps] with patch.object(ctc, "get_cache_path", return_value=_MOCK_PATH): - with patch('allensdk.api.queries.cell_types_api.CellTypesApi.model_query', - MagicMock(name='model query', - return_value=return_dicts)) as query_mock: - with patch('os.path.exists', MagicMock(return_value=path_exists)): - with patch('allensdk.config.manifest.Manifest.safe_make_parent_dirs'): - with patch('allensdk.core.json_utilities.read', - return_value=return_dicts): - with patch(builtins.__name__ + '.open', - mock_open(), - create=True): - with patch('allensdk.core.json_utilities.write'): - _ = ctc.get_all_features( - require_reconstruction=require_reconstruction) + with patch( + "allensdk.api.queries.cell_types_api.CellTypesApi.model_query", + MagicMock(name="model query", return_value=return_dicts), + ) as query_mock: + with patch("os.path.exists", MagicMock(return_value=path_exists)): + with patch("allensdk.config.manifest.Manifest.safe_make_parent_dirs"): + with patch("allensdk.core.json_utilities.read", return_value=return_dicts): + with patch(builtins.__name__ + ".open", mock_open(), create=True): + with patch("allensdk.core.json_utilities.write"): + _ = ctc.get_all_features(require_reconstruction=require_reconstruction) if path_exists: assert read_csv.called else: assert query_mock.called - + assert mock_merge.called def test_build_manifest(cache_fixture): ctc = cache_fixture - - mb_mock = MagicMock(name='manifest builder') + + mb_mock = MagicMock(name="manifest builder") with patch.object(ctc, "get_cache_path", return_value=_MOCK_PATH): - with patch('allensdk.core.cell_types_cache.ManifestBuilder', - return_value=mb_mock): - ctc.build_manifest('test_manifest.json') + with patch("allensdk.core.cell_types_cache.ManifestBuilder", return_value=mb_mock): + ctc.build_manifest("test_manifest.json") assert mb_mock.add_path.call_count == 8 - mb_mock.write_json_file.assert_called_once_with('test_manifest.json') + mb_mock.write_json_file.assert_called_once_with("test_manifest.json") diff --git a/allensdk/test/core/test_datafame_utils.py b/allensdk/test/core/test_datafame_utils.py index d330ac33db..7a205a0130 100644 --- a/allensdk/test/core/test_datafame_utils.py +++ b/allensdk/test/core/test_datafame_utils.py @@ -92,9 +92,7 @@ def test_error_on_not_unique_index(target_df_fixture): @pytest.mark.parametrize("original_index", [None, "a", "c", "b"]) -def test_patch_no_duplicates( - source_df_fixture, target_df_fixture, original_index -): +def test_patch_no_duplicates(source_df_fixture, target_df_fixture, original_index): """ Test that we get the expected dataframe back in the case where there are no duplicate values of index_column @@ -131,9 +129,7 @@ def test_patch_no_duplicates( @pytest.mark.parametrize("original_index", [None, "c", "b"]) -def test_patch_with_duplicates( - source_df_fixture, target_df_fixture, original_index -): +def test_patch_with_duplicates(source_df_fixture, target_df_fixture, original_index): """ Test that we get the expected dataframe back in the case where there are duplicate values of index_column @@ -231,9 +227,7 @@ def test_multiple_indexes_in_dataframe(): mock_behavior_sessions = pd.DataFrame( { "behavior_session_id": [1, 2, 2, 3, 4], - "ecephys_session_id": pd.Series( - [10, 11, 0, 12, 13], dtype="Int64" - ), + "ecephys_session_id": pd.Series([10, 11, 0, 12, 13], dtype="Int64"), "mouse_id": [4, 4, 4, 2, 1], } ).set_index("behavior_session_id") diff --git a/allensdk/test/core/test_h5_utilities.py b/allensdk/test/core/test_h5_utilities.py index 80c5a6974b..8529eefa2c 100644 --- a/allensdk/test/core/test_h5_utilities.py +++ b/allensdk/test/core/test_h5_utilities.py @@ -9,8 +9,7 @@ @pytest.fixture def mem_h5(request): - my_file = h5py.File('my_file.h5', driver='core', backing_store=False, - mode='w') + my_file = h5py.File("my_file.h5", driver="core", backing_store=False, mode="w") def fin(): my_file.close() @@ -22,29 +21,29 @@ def fin(): @pytest.fixture def simple_h5(mem_h5): - mem_h5.create_group('a') - mem_h5.create_group('a/b') - mem_h5.create_group('a/b/c') - mem_h5.create_group('d') - mem_h5.create_group('a/e') + mem_h5.create_group("a") + mem_h5.create_group("a/b") + mem_h5.create_group("a/b/c") + mem_h5.create_group("d") + mem_h5.create_group("a/e") return mem_h5 @pytest.fixture def simple_h5_with_datsets(simple_h5): - simple_h5.create_dataset(name='/a/b/c/fish', data=np.eye(10)) - simple_h5.create_dataset(name='a/fowl', data=np.eye(15)) - simple_h5.create_dataset(name='a/b/mammal', data=np.eye(20)) + simple_h5.create_dataset(name="/a/b/c/fish", data=np.eye(10)) + simple_h5.create_dataset(name="a/fowl", data=np.eye(15)) + simple_h5.create_dataset(name="a/b/mammal", data=np.eye(20)) return simple_h5 def test_decode_bytes(): - inp = np.array([b'a', b'b', b'c']) + inp = np.array([b"a", b"b", b"c"]) obt = h5_utilities.decode_bytes(inp) - assert (np.array_equal(obt, ['a', 'b', 'c'])) + assert np.array_equal(obt, ["a", "b", "c"]) def test_traverse_h5_file(simple_h5): @@ -55,37 +54,32 @@ def cb(name, node): h5_utilities.traverse_h5_file(cb, simple_h5) - assert (set(names) == set(['a', 'a/b', 'a/b/c', 'd', 'a/e'])) + assert set(names) == set(["a", "a/b", "a/b/c", "d", "a/e"]) def test_locate_h5_objects(simple_h5): - matcher_cb = functools.partial(h5_utilities.h5_object_matcher_relname_in, - ['c', 'e']) + matcher_cb = functools.partial(h5_utilities.h5_object_matcher_relname_in, ["c", "e"]) matches = h5_utilities.locate_h5_objects(matcher_cb, simple_h5) match_names = [match.name for match in matches] - assert (set(match_names) == set(['/a/e', '/a/b/c'])) + assert set(match_names) == set(["/a/e", "/a/b/c"]) def test_keyed_locate_h5_objects(simple_h5): matcher_cbs = { - 'e': functools.partial(h5_utilities.h5_object_matcher_relname_in, - ['e']), - 'c': functools.partial(h5_utilities.h5_object_matcher_relname_in, - ['c']), + "e": functools.partial(h5_utilities.h5_object_matcher_relname_in, ["e"]), + "c": functools.partial(h5_utilities.h5_object_matcher_relname_in, ["c"]), } matches = h5_utilities.keyed_locate_h5_objects(matcher_cbs, simple_h5) - assert (matches['e'].name == '/a/e') - assert (matches['c'].name == '/a/b/c') + assert matches["e"].name == "/a/e" + assert matches["c"].name == "/a/b/c" def test_load_datasets_by_relnames(simple_h5_with_datsets): - relnames = ['fish', 'fowl', 'mammal'] - obt = h5_utilities.load_datasets_by_relnames(relnames, - simple_h5_with_datsets, - simple_h5_with_datsets['a/b']) - - assert (len(obt) == 2) - assert (np.allclose(obt['fish'], np.eye(10))) - assert (np.allclose(obt['mammal'], np.eye(20))) + relnames = ["fish", "fowl", "mammal"] + obt = h5_utilities.load_datasets_by_relnames(relnames, simple_h5_with_datsets, simple_h5_with_datsets["a/b"]) + + assert len(obt) == 2 + assert np.allclose(obt["fish"], np.eye(10)) + assert np.allclose(obt["mammal"], np.eye(20)) diff --git a/allensdk/test/core/test_json_utilities.py b/allensdk/test/core/test_json_utilities.py index 1779b7133b..d959635617 100644 --- a/allensdk/test/core/test_json_utilities.py +++ b/allensdk/test/core/test_json_utilities.py @@ -41,19 +41,21 @@ @pytest.fixture def dict_obj(): int_array, y = np.meshgrid(np.arange(2), np.arange(2)) - float_array = y.astype(float)/4.2 - bool_array = int_array > 0 - - object = {"string": "test string", - "float_array": float_array, - "int_array": int_array, - "bool_array": bool_array, - "list": ["this", "is", 1, "list"]} + float_array = y.astype(float) / 4.2 + bool_array = int_array > 0 + + object = { + "string": "test string", + "float_array": float_array, + "int_array": int_array, + "bool_array": bool_array, + "list": ["this", "is", 1, "list"], + } return object def test_write_integer_array(dict_obj): - s_in = ju.write_string({ "int_array": dict_obj["int_array"] }) + s_in = ju.write_string({"int_array": dict_obj["int_array"]}) s_out = """{ "int_array": [ [ @@ -69,9 +71,10 @@ def test_write_integer_array(dict_obj): assert s_in == s_out + def test_write_float_array(dict_obj): - s_in = ju.write_string({ "float_array": dict_obj["float_array"] }) - s_out ="""{ + s_in = ju.write_string({"float_array": dict_obj["float_array"]}) + s_out = """{ "float_array": [ [ 0.0, @@ -86,16 +89,18 @@ def test_write_float_array(dict_obj): assert s_in == s_out + def test_write_string(dict_obj): - s_in = ju.write_string({ "string": dict_obj["string"] }) + s_in = ju.write_string({"string": dict_obj["string"]}) s_out = """{ "string": "test string" }""" assert s_in == s_out + def test_write_bool_array(dict_obj): - s_in = ju.write_string({ "bool_array": dict_obj["bool_array"] }) + s_in = ju.write_string({"bool_array": dict_obj["bool_array"]}) s_out = """{ "bool_array": [ [ @@ -110,8 +115,9 @@ def test_write_bool_array(dict_obj): }""" assert s_in == s_out + def test_write_list(dict_obj): - s_in = ju.write_string({ "list": dict_obj["list"] }) + s_in = ju.write_string({"list": dict_obj["list"]}) s_out = """{ "list": [ "this", diff --git a/allensdk/test/core/test_lazy_property.py b/allensdk/test/core/test_lazy_property.py index a587209b3f..34b1d362e7 100644 --- a/allensdk/test/core/test_lazy_property.py +++ b/allensdk/test/core/test_lazy_property.py @@ -10,7 +10,6 @@ def get_data(self, original_data): class DataClass(LazyPropertyMixin): - def __init__(self, original_data, api=None): self.api = CopyApi() if api is None else api self.original_data = original_data @@ -18,14 +17,14 @@ def __init__(self, original_data, api=None): self.data = self.LazyProperty(self.api.get_data, original_data=self.original_data) -@pytest.mark.parametrize('original_data', [{'a': 'b'}, [None]]) +@pytest.mark.parametrize("original_data", [{"a": "b"}, [None]]) def test_first_compute(original_data): data_obj = DataClass(original_data) assert data_obj.data == original_data assert data_obj.data is not original_data - -@pytest.mark.parametrize('original_data', [1, '1', [None]]) + +@pytest.mark.parametrize("original_data", [1, "1", [None]]) def test_is_lazy(original_data): data_obj = DataClass(original_data) @@ -34,9 +33,9 @@ def test_is_lazy(original_data): assert first is second -@pytest.mark.parametrize('original_data', [1, '1', [None]]) +@pytest.mark.parametrize("original_data", [1, "1", [None]]) def test_not_settable(original_data): data_obj = DataClass(original_data) with pytest.raises(AttributeError) as err: - data_obj.data = '12345' - assert "Can't set LazyLoadable attribute" in err \ No newline at end of file + data_obj.data = "12345" + assert "Can't set LazyLoadable attribute" in err diff --git a/allensdk/test/core/test_mouse_connectivity_cache.py b/allensdk/test/core/test_mouse_connectivity_cache.py index 100cbab6b8..a4014f88a6 100755 --- a/allensdk/test/core/test_mouse_connectivity_cache.py +++ b/allensdk/test/core/test_mouse_connectivity_cache.py @@ -52,225 +52,277 @@ def cached_csv(tmpdir_factory): return csv -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def mcc(tmpdir_factory): - manifest_file = tmpdir_factory.mktemp("mcc").join('manifest.json') + manifest_file = tmpdir_factory.mktemp("mcc").join("manifest.json") return MouseConnectivityCache(manifest_file=str(manifest_file)) -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def new_nodes(): - - return [{'id': 0, 'structure_id_path': '/0/', - 'color_hex_triplet': '000000', 'acronym': 'rt', - 'name': 'root', 'structure_sets':[{'id': 1}, {'id': 4}, {'id': 167587189}] }] - - -@pytest.fixture(scope='function') + return [ + { + "id": 0, + "structure_id_path": "/0/", + "color_hex_triplet": "000000", + "acronym": "rt", + "name": "root", + "structure_sets": [{"id": 1}, {"id": 4}, {"id": 167587189}], + } + ] + + +@pytest.fixture(scope="function") def old_nodes(): - - return [{'id': 0, 'structure_id_path': '/0/', - 'color_hex_triplet': '000000', 'acronym': 'rt', - 'name': 'root', 'parent_structure_id': 12}] - -@pytest.fixture(scope='function') + return [ + { + "id": 0, + "structure_id_path": "/0/", + "color_hex_triplet": "000000", + "acronym": "rt", + "name": "root", + "parent_structure_id": 12, + } + ] + + +@pytest.fixture(scope="function") def experiments(): - return [{'data_set_id': 1, 'name': 'foo', 'storage_directory': 'meep', 'transgenic_line': { 'name': 'most_creish' }, - 'injection_structures': '234/324', 'structure_id': 97}, - {'data_set_id': 2, 'name': 'bar', 'storage_directory': 'meep', 'transgenic_line': None, - 'injection_structures': '234/324/234', 'structure_id': 21}] - - -@pytest.fixture(scope='function') + return [ + { + "data_set_id": 1, + "name": "foo", + "storage_directory": "meep", + "transgenic_line": {"name": "most_creish"}, + "injection_structures": "234/324", + "structure_id": 97, + }, + { + "data_set_id": 2, + "name": "bar", + "storage_directory": "meep", + "transgenic_line": None, + "injection_structures": "234/324/234", + "structure_id": 21, + }, + ] + + +@pytest.fixture(scope="function") def unionizes(): - # note that I've mucked around with these values a bit - return [{"hemisphere_id": 1, "id": 169991412, "is_injection": False, - "max_voxel_density": 0.284863, "max_voxel_x": 7700, - "max_voxel_y": 6500, "max_voxel_z": 5000, - "normalized_projection_volume": 0.0, - "projection_density": 0.116754, "projection_energy": 30.7332, - "projection_intensity": 263.231, "projection_volume": 0.0018718, - "section_data_set_id": 166218353, "structure_id": 1, - "sum_pixel_intensity": 99234900.0, "sum_pixels": 1308740.0, - "sum_projection_pixel_intensity": 40221700.0, - "sum_projection_pixels": 152800.0, - "volume": 0.016032}, - {"hemisphere_id": 2, "id": 169991601, "is_injection": False, - "max_voxel_density": 0.0614783, "max_voxel_x": 7500, - "max_voxel_y": 4900, "max_voxel_z": 1700, - "normalized_projection_volume": 0.0, - "projection_density": 0.0168009, - "projection_energy": 1.96084, "projection_intensity": 116.71, - "projection_volume": 0.00148144, - "section_data_set_id": 166218353, "structure_id": 60, - "sum_pixel_intensity": 261941000.0, "sum_pixels": 7198050.0, - "sum_projection_pixel_intensity": 14114200.0, - "sum_projection_pixels": 120934.0, "volume": 0.0881761}] - - -@pytest.fixture(scope='function') + return [ + { + "hemisphere_id": 1, + "id": 169991412, + "is_injection": False, + "max_voxel_density": 0.284863, + "max_voxel_x": 7700, + "max_voxel_y": 6500, + "max_voxel_z": 5000, + "normalized_projection_volume": 0.0, + "projection_density": 0.116754, + "projection_energy": 30.7332, + "projection_intensity": 263.231, + "projection_volume": 0.0018718, + "section_data_set_id": 166218353, + "structure_id": 1, + "sum_pixel_intensity": 99234900.0, + "sum_pixels": 1308740.0, + "sum_projection_pixel_intensity": 40221700.0, + "sum_projection_pixels": 152800.0, + "volume": 0.016032, + }, + { + "hemisphere_id": 2, + "id": 169991601, + "is_injection": False, + "max_voxel_density": 0.0614783, + "max_voxel_x": 7500, + "max_voxel_y": 4900, + "max_voxel_z": 1700, + "normalized_projection_volume": 0.0, + "projection_density": 0.0168009, + "projection_energy": 1.96084, + "projection_intensity": 116.71, + "projection_volume": 0.00148144, + "section_data_set_id": 166218353, + "structure_id": 60, + "sum_pixel_intensity": 261941000.0, + "sum_pixels": 7198050.0, + "sum_projection_pixel_intensity": 14114200.0, + "sum_projection_pixels": 120934.0, + "volume": 0.0881761, + }, + ] + + +@pytest.fixture(scope="function") def top_injection_unionizes(): - return pd.DataFrame([{'experiment_id': 1, 'is_injection': True, 'hemisphere_id': 1, 'structure_id': 10, 'normalized_projection_volume': 0.75}, - {'experiment_id': 1, 'is_injection': True, 'hemisphere_id': 2, 'structure_id': 15, 'normalized_projection_volume': 0.25}, - {'experiment_id': 1, 'is_injection': False, 'hemisphere_id': 1, 'structure_id': 10, 'normalized_projection_volume': 2.0}, - {'experiment_id': 1, 'is_injection': False, 'hemisphere_id': 2, 'structure_id': 11, 'normalized_projection_volume': 0.001}]) + return pd.DataFrame( + [ + { + "experiment_id": 1, + "is_injection": True, + "hemisphere_id": 1, + "structure_id": 10, + "normalized_projection_volume": 0.75, + }, + { + "experiment_id": 1, + "is_injection": True, + "hemisphere_id": 2, + "structure_id": 15, + "normalized_projection_volume": 0.25, + }, + { + "experiment_id": 1, + "is_injection": False, + "hemisphere_id": 1, + "structure_id": 10, + "normalized_projection_volume": 2.0, + }, + { + "experiment_id": 1, + "is_injection": False, + "hemisphere_id": 2, + "structure_id": 11, + "normalized_projection_volume": 0.001, + }, + ] + ) def test_init(mcc): - assert( os.path.exists(mcc.manifest_path) ) + assert os.path.exists(mcc.manifest_path) def test_get_annotation_volume(mcc): - eye = np.eye(100) - path = os.path.join(os.path.dirname(mcc.manifest_path), - 'annotation', 'ccf_2017', - 'annotation_25.nrrd') - - with mock.patch.object(mcc.api, "retrieve_file_over_http", - new=lambda a, b: nrrd.write(b, eye)): + path = os.path.join(os.path.dirname(mcc.manifest_path), "annotation", "ccf_2017", "annotation_25.nrrd") + + with mock.patch.object(mcc.api, "retrieve_file_over_http", new=lambda a, b: nrrd.write(b, eye)): obtained, _ = mcc.get_annotation_volume() with mock.patch.object(mcc.api, "retrieve_file_over_http") as mock_rtrv: mcc.get_annotation_volume() mock_rtrv.assert_not_called() - assert( np.allclose(obtained, eye) ) - assert( os.path.exists(path) ) + assert np.allclose(obtained, eye) + assert os.path.exists(path) def test_get_template_volume(mcc): eye = np.eye(100) - path = os.path.join(os.path.dirname(mcc.manifest_path), - 'average_template_25.nrrd') + path = os.path.join(os.path.dirname(mcc.manifest_path), "average_template_25.nrrd") - with mock.patch.object(mcc.api, "retrieve_file_over_http", - new=lambda a, b: nrrd.write(b, eye)): + with mock.patch.object(mcc.api, "retrieve_file_over_http", new=lambda a, b: nrrd.write(b, eye)): obtained, _ = mcc.get_template_volume() with mock.patch.object(mcc.api, "retrieve_file_over_http") as mock_rtrv: mcc.get_template_volume() mock_rtrv.assert_not_called() - assert( np.allclose(obtained, eye) ) - assert( os.path.exists(path) ) + assert np.allclose(obtained, eye) + assert os.path.exists(path) def test_get_projection_density(mcc): - eye = np.eye(100) eid = 123456789 - path = os.path.join(os.path.dirname(mcc.manifest_path), - 'experiment_{0}'.format(eid), - 'projection_density_25.nrrd') + path = os.path.join(os.path.dirname(mcc.manifest_path), "experiment_{0}".format(eid), "projection_density_25.nrrd") - with mock.patch('allensdk.api.queries.grid_data_api.GridDataApi.' - 'retrieve_file_over_http', - new=lambda a, b, c: nrrd.write(c, eye)): + with mock.patch( + "allensdk.api.queries.grid_data_api.GridDataApi.retrieve_file_over_http", new=lambda a, b, c: nrrd.write(c, eye) + ): obtained, _ = mcc.get_projection_density(eid) with mock.patch.object(mcc.api, "retrieve_file_over_http") as mock_rtrv: mcc.get_projection_density(eid) mock_rtrv.assert_not_called() - assert( np.allclose(obtained, eye) ) - assert( os.path.exists(path) ) + assert np.allclose(obtained, eye) + assert os.path.exists(path) def test_get_injection_density(mcc): - eye = np.eye(100) eid = 123456789 - path = os.path.join(os.path.dirname(mcc.manifest_path), - 'experiment_{0}'.format(eid), - 'injection_density_25.nrrd') + path = os.path.join(os.path.dirname(mcc.manifest_path), "experiment_{0}".format(eid), "injection_density_25.nrrd") - with mock.patch('allensdk.api.queries.grid_data_api.GridDataApi.' - 'retrieve_file_over_http', - new=lambda a, b, c: nrrd.write(c, eye)): + with mock.patch( + "allensdk.api.queries.grid_data_api.GridDataApi.retrieve_file_over_http", new=lambda a, b, c: nrrd.write(c, eye) + ): obtained, _ = mcc.get_injection_density(eid) with mock.patch.object(mcc.api, "retrieve_file_over_http") as mock_rtrv: mcc.get_injection_density(eid) mock_rtrv.assert_not_called() - assert( np.allclose(obtained, eye) ) - assert( os.path.exists(path) ) + assert np.allclose(obtained, eye) + assert os.path.exists(path) def test_get_injection_fraction(mcc): - eye = np.eye(100) eid = 123456789 - path = os.path.join(os.path.dirname(mcc.manifest_path), - 'experiment_{0}'.format(eid), - 'injection_fraction_25.nrrd') + path = os.path.join(os.path.dirname(mcc.manifest_path), "experiment_{0}".format(eid), "injection_fraction_25.nrrd") - with mock.patch('allensdk.api.queries.grid_data_api.GridDataApi.' - 'retrieve_file_over_http', - new=lambda a, b, c: nrrd.write(c, eye)): + with mock.patch( + "allensdk.api.queries.grid_data_api.GridDataApi.retrieve_file_over_http", new=lambda a, b, c: nrrd.write(c, eye) + ): obtained, _ = mcc.get_injection_fraction(eid) with mock.patch.object(mcc.api, "retrieve_file_over_http") as mock_rtrv: mcc.get_injection_fraction(eid) mock_rtrv.assert_not_called() - assert( np.allclose(obtained, eye) ) - assert( os.path.exists(path) ) + assert np.allclose(obtained, eye) + assert os.path.exists(path) def test_get_data_mask(mcc): - eye = np.eye(100) eid = 123456789 - path = os.path.join(os.path.dirname(mcc.manifest_path), - 'experiment_{0}'.format(eid), - 'data_mask_25.nrrd') + path = os.path.join(os.path.dirname(mcc.manifest_path), "experiment_{0}".format(eid), "data_mask_25.nrrd") - with mock.patch('allensdk.api.queries.grid_data_api.GridDataApi.' - 'retrieve_file_over_http', - new=lambda a, b, c: nrrd.write(c, eye)): + with mock.patch( + "allensdk.api.queries.grid_data_api.GridDataApi.retrieve_file_over_http", new=lambda a, b, c: nrrd.write(c, eye) + ): obtained, _ = mcc.get_data_mask(eid) with mock.patch.object(mcc.api, "retrieve_file_over_http") as mock_rtrv: mcc.get_data_mask(eid) mock_rtrv.assert_not_called() - assert( np.allclose(obtained, eye) ) - assert( os.path.exists(path) ) + assert np.allclose(obtained, eye) + assert os.path.exists(path) def test_get_structure_tree(mcc, new_nodes): + path = os.path.join(os.path.dirname(mcc.manifest_path), "structures.json") - path = os.path.join(os.path.dirname(mcc.manifest_path), - 'structures.json') - - with mock.patch('allensdk.api.queries.ontologies_api.' - 'OntologiesApi.model_query', - return_value=new_nodes) as p: - + with mock.patch("allensdk.api.queries.ontologies_api.OntologiesApi.model_query", return_value=new_nodes) as p: obtained = mcc.get_structure_tree() mcc.get_structure_tree() p.assert_called_once() - assert( obtained.node_ids()[0] == 0 ) - + assert obtained.node_ids()[0] == 0 + cm_obt = obtained.get_colormap() - assert(len(cm_obt[0]) == 3) + assert len(cm_obt[0]) == 3 - assert( os.path.exists(path) ) + assert os.path.exists(path) def test_get_experiments(mcc, experiments): + file_path = os.path.join(os.path.dirname(mcc.manifest_path), "experiments.json") - file_path = os.path.join(os.path.dirname(mcc.manifest_path), 'experiments.json') + def new_fn(*args, **kwargs): + return experiments - def new_fn(*args, **kwargs): return experiments - - with mock.patch.object(mcc.api, "model_query", - new=new_fn): + with mock.patch.object(mcc.api, "model_query", new=new_fn): obtained = mcc.get_experiments() with mock.patch.object(mcc.api, "model_query") as mock_squery: @@ -278,15 +330,14 @@ def new_fn(*args, **kwargs): return experiments mock_squery.assert_not_called() assert os.path.exists(file_path) - assert 'storage_directory' not in obtained[0] - assert obtained[0]['transgenic_line'] == 'most_creish' + assert "storage_directory" not in obtained[0] + assert obtained[0]["transgenic_line"] == "most_creish" - obtained = mcc.get_experiments(cre=['MOST_CREISH']) + obtained = mcc.get_experiments(cre=["MOST_CREISH"]) assert len(obtained) == 1 def test_filter_experiments(mcc, experiments): - pass_line = mcc.filter_experiments(experiments, cre=True) fail_line = mcc.filter_experiments(experiments, cre=False) @@ -297,96 +348,76 @@ def fake_tree(*a, **k): class FakeTree(object): def descendant_ids(*a, **k): return [[97, 98], []] + return FakeTree() - with mock.patch.object(mcc, 'get_structure_tree', new=fake_tree): + with mock.patch.object(mcc, "get_structure_tree", new=fake_tree): sid_line = mcc.filter_experiments(experiments, cre=True, injection_structure_ids=[97, 98]) assert len(sid_line) == 1 -def test_rank_structures(mcc, top_injection_unionizes): - os.path.join(os.path.dirname(mcc.manifest_path), - 'experiment_{0}'.format(1), - 'structure_unionizes.csv') +def test_rank_structures(mcc, top_injection_unionizes): + os.path.join(os.path.dirname(mcc.manifest_path), "experiment_{0}".format(1), "structure_unionizes.csv") - with mock.patch.object(mcc.api, "model_query", - lambda *args, **kwargs: top_injection_unionizes): + with mock.patch.object(mcc.api, "model_query", lambda *args, **kwargs: top_injection_unionizes): obt = mcc.rank_structures([1], True, [15], [1, 2]) - assert(len(obt) == 1) + assert len(obt) == 1 exp = obt[0] - assert(len(exp) == 1) + assert len(exp) == 1 st = exp[0] - assert(st['structure_id'] == 15) - assert(st['normalized_projection_volume'] == 0.25) + assert st["structure_id"] == 15 + assert st["normalized_projection_volume"] == 0.25 def test_default_structure_ids(mcc, new_nodes): + os.path.join(os.path.dirname(mcc.manifest_path), "structures.json") - os.path.join(os.path.dirname(mcc.manifest_path), - 'structures.json') - - with mock.patch('allensdk.api.queries.ontologies_api.' - 'OntologiesApi.model_query', - return_value=new_nodes): - + with mock.patch("allensdk.api.queries.ontologies_api.OntologiesApi.model_query", return_value=new_nodes): default_structure_ids = mcc.default_structure_ids - assert(len(default_structure_ids) == 1) - assert(default_structure_ids[0] == 0) + assert len(default_structure_ids) == 1 + assert default_structure_ids[0] == 0 def test_get_experiment_structure_unionizes(mcc, unionizes): - eid = 166218353 - path = os.path.join(os.path.dirname(mcc.manifest_path), - 'experiment_{0}'.format(eid), - 'structure_unionizes.csv') + path = os.path.join(os.path.dirname(mcc.manifest_path), "experiment_{0}".format(eid), "structure_unionizes.csv") - with mock.patch.object(mcc.api, "model_query", - new=lambda *args, **kwargs: unionizes): + with mock.patch.object(mcc.api, "model_query", new=lambda *args, **kwargs: unionizes): obtained = mcc.get_experiment_structure_unionizes(eid) with mock.patch.object(mcc.api, "model_query") as mock_query: mcc.get_experiment_structure_unionizes(eid) mock_query.assert_not_called() - assert obtained.loc[0, 'projection_intensity'] == 263.231 + assert obtained.loc[0, "projection_intensity"] == 263.231 assert os.path.exists(path) -def test_get_experiment_structure_unionizes_cache_roundtrip(mcc, unionizes, - cached_csv): - +def test_get_experiment_structure_unionizes_cache_roundtrip(mcc, unionizes, cached_csv): eid = 166218353 - with mock.patch.object(mcc.api, "model_query", - new=lambda *args, **kwargs: unionizes): - obtained = mcc.get_experiment_structure_unionizes( - eid, file_name=cached_csv) + with mock.patch.object(mcc.api, "model_query", new=lambda *args, **kwargs: unionizes): + obtained = mcc.get_experiment_structure_unionizes(eid, file_name=cached_csv) pandas_data = pd.read_csv(cached_csv, index_col=0, parse_dates=True) - assert obtained.loc[0, 'projection_intensity'] == 263.231 - assert(sorted(obtained.keys()) == sorted(pandas_data.columns)) + assert obtained.loc[0, "projection_intensity"] == 263.231 + assert sorted(obtained.keys()) == sorted(pandas_data.columns) def test_filter_structure_unionizes(mcc, unionizes): + obtained = mcc.filter_structure_unionizes(pd.DataFrame(unionizes), hemisphere_ids=[1]) - obtained = mcc.filter_structure_unionizes(pd.DataFrame(unionizes), - hemisphere_ids=[1]) + assert obtained.loc[0, "volume"] == 0.016032 - assert obtained.loc[0, 'volume'] == 0.016032 + mcc.filter_structure_unionizes(pd.DataFrame(unionizes), hemisphere_ids=[1], structure_ids=[1, 60, 90]) - mcc.filter_structure_unionizes(pd.DataFrame(unionizes), - hemisphere_ids=[1], - structure_ids=[1,60,90]) + assert obtained.loc[0, "volume"] == 0.016032 - assert obtained.loc[0, 'volume'] == 0.016032 def test_get_structure_unionizes(mcc, unionizes): - - with mock.patch.object(mcc, "get_experiment_structure_unionizes", - new=lambda *a, **k: pd.DataFrame(unionizes)): + with mock.patch.object(mcc, "get_experiment_structure_unionizes", new=lambda *a, **k: pd.DataFrame(unionizes)): obtained = mcc.get_structure_unionizes([1, 2, 3]) assert obtained.shape[0] == 6 @@ -395,130 +426,118 @@ def test_get_structure_unionizes(mcc, unionizes): def test_get_projection_matrix(mcc): # yup - unionizes = [{'experiment_id': 1, - 'structure_id': 2, - 'hemisphere_id': 1, - 'value': 30}, - {'experiment_id': 1, - 'structure_id': 2, - 'hemisphere_id': 2, - 'value': 40},] - - with mock.patch.object(mcc, "get_structure_unionizes", - new=lambda *a, **k: pd.DataFrame(unionizes)): + unionizes = [ + {"experiment_id": 1, "structure_id": 2, "hemisphere_id": 1, "value": 30}, + {"experiment_id": 1, "structure_id": 2, "hemisphere_id": 2, "value": 40}, + ] + + with mock.patch.object(mcc, "get_structure_unionizes", new=lambda *a, **k: pd.DataFrame(unionizes)): + class FakeTree(object): def value_map(*a, **k): - return {1: 'one', 2: 'two'} - with mock.patch.object(mcc, "get_structure_tree", - new=lambda *a, **k: FakeTree()): - obtained = mcc.get_projection_matrix([1], [2], [1, 2], ['value']) + return {1: "one", 2: "two"} - assert np.allclose(obtained['matrix'], np.array([[30, 40]])) - assert np.array_equal([ii['label'] for ii in obtained['columns']], - ['two-L', 'two-R']) + with mock.patch.object(mcc, "get_structure_tree", new=lambda *a, **k: FakeTree()): + obtained = mcc.get_projection_matrix([1], [2], [1, 2], ["value"]) + assert np.allclose(obtained["matrix"], np.array([[30, 40]])) + assert np.array_equal([ii["label"] for ii in obtained["columns"]], ["two-L", "two-R"]) -def test_get_reference_space(mcc, new_nodes): +def test_get_reference_space(mcc, new_nodes): tree = StructureTree(StructureTree.clean_structures(new_nodes)) - with mock.patch.object(mcc, "get_structure_tree", - new=lambda *a, **k: tree): + with mock.patch.object(mcc, "get_structure_tree", new=lambda *a, **k: tree): annot = np.arange(125).reshape((5, 5, 5)) - with mock.patch.object(mcc, "get_annotation_volume", - new=lambda *a, **k: (annot, 'foo')): + with mock.patch.object(mcc, "get_annotation_volume", new=lambda *a, **k: (annot, "foo")): rsp_obt = mcc.get_reference_space() - assert( np.allclose(rsp_obt.resolution, [25, 25, 25]) ) - assert( np.allclose( rsp_obt.annotation, annot ) ) + assert np.allclose(rsp_obt.resolution, [25, 25, 25]) + assert np.allclose(rsp_obt.annotation, annot) def test_get_structure_mask(mcc): - sid = 12 eye = np.eye(100) - path = os.path.join(os.path.dirname(mcc.manifest_path), - 'annotation', 'ccf_2017', 'structure_masks', - 'resolution_25', 'structure_{0}.nrrd'.format(sid)) - - with mock.patch.object(mcc.api, "retrieve_file_over_http", - new=lambda a, b: nrrd.write(b, eye)): + path = os.path.join( + os.path.dirname(mcc.manifest_path), + "annotation", + "ccf_2017", + "structure_masks", + "resolution_25", + "structure_{0}.nrrd".format(sid), + ) + + with mock.patch.object(mcc.api, "retrieve_file_over_http", new=lambda a, b: nrrd.write(b, eye)): obtained, _ = mcc.get_structure_mask(sid) with mock.patch.object(mcc.api, "retrieve_file_over_http") as mock_rtrv: mcc.get_structure_mask(sid) mock_rtrv.assert_not_called() - assert( np.allclose(obtained, eye) ) - assert( os.path.exists(path) ) + assert np.allclose(obtained, eye) + assert os.path.exists(path) -@pytest.mark.parametrize('inp,fails', [(1, False), - (pd.Series([2]), False), - ('qwerty', True)]) +@pytest.mark.parametrize("inp,fails", [(1, False), (pd.Series([2]), False), ("qwerty", True)]) def test_validate_structure_id(inp, fails): - if fails: with pytest.raises(ValueError): MouseConnectivityCache.validate_structure_id(inp) else: out = MouseConnectivityCache.validate_structure_id(inp) - assert( out == int(inp) ) + assert out == int(inp) -@pytest.mark.parametrize('inp,fails', [([1, 2, 3], False), - ([pd.Series([2]), pd.Series([3])], False), - (['qwerty', 1], True)]) +@pytest.mark.parametrize( + "inp,fails", [([1, 2, 3], False), ([pd.Series([2]), pd.Series([3])], False), (["qwerty", 1], True)] +) def test_validate_structure_ids(inp, fails): - if fails: with pytest.raises(ValueError): MouseConnectivityCache.validate_structure_ids(inp) else: out = MouseConnectivityCache.validate_structure_ids(inp) - assert( out == [ int(i) for i in inp ] ) + assert out == [int(i) for i in inp] def test_get_deformation_field(mcc): - arr = np.random.rand(2, 4, 5, 3) def write_dfmfld(*a, **k): img = sitk.GetImageFromArray(arr) - sitk.WriteImage(img, str(k['header_path']), True) # TODO the str call here is only necessary in 2.7 + sitk.WriteImage(img, str(k["header_path"]), True) # TODO the str call here is only necessary in 2.7 - with mock.patch.object(mcc.api, 'download_deformation_field', new=write_dfmfld): + with mock.patch.object(mcc.api, "download_deformation_field", new=write_dfmfld): obtained = mcc.get_deformation_field(123) assert np.allclose(arr, obtained) def test_get_affine_parameters(mcc): - def new_fn(*args, **kwargs): - return [{'alignment3d': { - 'trv_00': 1, - 'trv_01': 2, - 'trv_02': 3, - 'trv_03': 4, - 'trv_04': 5, - 'trv_05': 6, - 'trv_06': 7, - 'trv_07': 8, - 'trv_08': 9, - 'trv_09': 10, - 'trv_10': 11, - 'trv_11': 12, - }}] - - expected = np.array([ - [1, 2, 3], - [4, 5, 6], - [7, 8, 9], - [10, 11, 12] - ]) + return [ + { + "alignment3d": { + "trv_00": 1, + "trv_01": 2, + "trv_02": 3, + "trv_03": 4, + "trv_04": 5, + "trv_05": 6, + "trv_06": 7, + "trv_07": 8, + "trv_08": 9, + "trv_09": 10, + "trv_10": 11, + "trv_11": 12, + } + } + ] + + expected = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]) with mock.patch.object(mcc.api, "model_query", new=new_fn): obtained = mcc.get_affine_parameters(1245) - assert np.allclose(expected, obtained) \ No newline at end of file + assert np.allclose(expected, obtained) diff --git a/allensdk/test/core/test_mouse_connectivity_notebook.py b/allensdk/test/core/test_mouse_connectivity_notebook.py index 3886d12e8b..7e91495fed 100644 --- a/allensdk/test/core/test_mouse_connectivity_notebook.py +++ b/allensdk/test/core/test_mouse_connectivity_notebook.py @@ -36,18 +36,16 @@ import pytest - @pytest.mark.nightly def test_notebook(tmpdir_factory): - # coding: utf-8 # ## Mouse Connectivity - # + # # This notebook demonstrates how to access and manipulate data in the Allen Mouse Brain Connectivity Atlas. The `MouseConnectivityCache` AllenSDK class provides methods for downloading metadata about experiments, including their viral injection site and the mouse's transgenic line. You can request information either as a Pandas DataFrame or a simple list of dictionaries. - # + # # An important feature of the `MouseConnectivityCache` is how it stores and retrieves data for you. By default, it will create (or read) a manifest file that keeps track of where various connectivity atlas data are stored. If you request something that has not already been downloaded, it will download it and store it in a well known location. - # + # # Download this notebook in .ipynb format here. # In[1]: @@ -58,7 +56,7 @@ def test_notebook(tmpdir_factory): # the data that has already been downloaded onto the hard drives. # If you supply a relative path, it is assumed to be relative to your # current working directory. - manifest_file = tmpdir_factory.mktemp('mcc').join('manifest.json') + manifest_file = tmpdir_factory.mktemp("mcc").join("manifest.json") mcc = MouseConnectivityCache(manifest_file=str(manifest_file)) # open up a list of all of the experiments @@ -68,7 +66,6 @@ def test_notebook(tmpdir_factory): # take a look at what we know about an experiment with a primary motor injection all_experiments.loc[122642490] - # `MouseConnectivityCache` has a method for retrieving the adult mouse structure tree as an `StructureTree` class instance. This is a wrapper around a list of dictionaries, where each dictionary describes a structure. It is principally useful for looking up structures by their properties. # In[2]: @@ -80,16 +77,15 @@ def test_notebook(tmpdir_factory): structure_tree = mcc.get_structure_tree() # get info on some structures - structures = structure_tree.get_structures_by_name(['Primary visual area', 'Hypothalamus']) + structures = structure_tree.get_structures_by_name(["Primary visual area", "Hypothalamus"]) pd.DataFrame(structures) - # As a convenience, structures are grouped in to named collections called "structure sets". These sets can be used to quickly gather a useful subset of structures from the tree. The criteria used to define structure sets are eclectic; a structure set might list: - # + # # * structures that were used in a particular project. # * structures that coarsely partition the brain. # * structures that bear functional similarity. - # + # # or something else entirely. To view all of the available structure sets along with their descriptions, follow this [link](http://api.brain-map.org/api/v2/data/StructureSet/query.json). To see only structure sets relevant to the adult mouse brain, use the StructureTree: # In[3]: @@ -104,7 +100,6 @@ def test_notebook(tmpdir_factory): # query the API for information on those structure sets pd.DataFrame(oapi.get_structure_sets(structure_set_ids)) - # On the connectivity atlas web site, you'll see that we show most of our data at a fairly coarse structure level. We did this by creating a structure set of ~300 structures, which we call the "summary structures". We can use the structure tree to get all of the structures in this set: # In[4]: @@ -113,98 +108,97 @@ def test_notebook(tmpdir_factory): summary_structures = structure_tree.get_structures_by_set_id([167587189]) pd.DataFrame(summary_structures) - # This is how you can filter experiments by transgenic line: # In[5]: # fetch the experiments that have injections in the isocortex of cre-positive mice - isocortex = structure_tree.get_structures_by_name(['Isocortex'])[0] - cre_cortical_experiments = mcc.get_experiments(cre=True, - injection_structure_ids=[isocortex['id']]) + isocortex = structure_tree.get_structures_by_name(["Isocortex"])[0] + cre_cortical_experiments = mcc.get_experiments(cre=True, injection_structure_ids=[isocortex["id"]]) print("%d cre cortical experiments" % len(cre_cortical_experiments)) # same as before, but restrict the cre line - rbp4_cortical_experiments = mcc.get_experiments(cre=[ 'Rbp4-Cre_KL100' ], - injection_structure_ids=[isocortex['id']]) - + rbp4_cortical_experiments = mcc.get_experiments(cre=["Rbp4-Cre_KL100"], injection_structure_ids=[isocortex["id"]]) print("%d Rbp4 cortical experiments" % len(rbp4_cortical_experiments)) - # ## Structure Signal Unionization - # + # # The ProjectionStructureUnionizes API data tells you how much signal there was in a given structure and experiment. It contains the density of projecting signal, volume of projecting signal, and other information. `MouseConnectivityCache` provides methods for querying and storing this data. # In[6]: # find wild-type injections into primary visual area - visp = structure_tree.get_structures_by_acronym(['VISp'])[0] - visp_experiments = mcc.get_experiments(cre=False, - injection_structure_ids=[visp['id']]) + visp = structure_tree.get_structures_by_acronym(["VISp"])[0] + visp_experiments = mcc.get_experiments(cre=False, injection_structure_ids=[visp["id"]]) print("%d VISp experiments" % len(visp_experiments)) - structure_unionizes = mcc.get_structure_unionizes([ e['id'] for e in visp_experiments ], - is_injection=False, - structure_ids=[isocortex['id']], - include_descendants=True) + structure_unionizes = mcc.get_structure_unionizes( + [e["id"] for e in visp_experiments], + is_injection=False, + structure_ids=[isocortex["id"]], + include_descendants=True, + ) print("%d VISp non-injection, cortical structure unionizes" % len(structure_unionizes)) - # In[7]: structure_unionizes.head() - # This is a rather large table, even for a relatively small number of experiments. You can filter it down to a smaller list of structures like this. # In[8]: - dense_unionizes = structure_unionizes[ structure_unionizes.projection_density > .5 ] - large_unionizes = dense_unionizes[ dense_unionizes.volume > .5 ] + dense_unionizes = structure_unionizes[structure_unionizes.projection_density > 0.5] + large_unionizes = dense_unionizes[dense_unionizes.volume > 0.5] large_structures = pd.DataFrame(structure_tree.nodes(large_unionizes.structure_id)) - print("%d large, dense, cortical, non-injection unionizes, %d structures" % ( len(large_unionizes), len(large_structures) )) + print( + "%d large, dense, cortical, non-injection unionizes, %d structures" + % (len(large_unionizes), len(large_structures)) + ) print(large_structures.name) large_unionizes - # ## Generating a Projection Matrix - # The `MouseConnectivityCache` class provides a helper method for converting ProjectionStructureUnionize records for a set of experiments and structures into a matrix. This code snippet demonstrates how to make a matrix of projection density values in auditory sub-structures for cre-negative VISp experiments. + # The `MouseConnectivityCache` class provides a helper method for converting ProjectionStructureUnionize records for a set of experiments and structures into a matrix. This code snippet demonstrates how to make a matrix of projection density values in auditory sub-structures for cre-negative VISp experiments. # In[9]: import numpy as np import matplotlib.pyplot as plt import warnings - warnings.filterwarnings('ignore') - visp_experiment_ids = [ e['id'] for e in visp_experiments ] - ctx_children = structure_tree.child_ids( [isocortex['id']] )[0] + warnings.filterwarnings("ignore") + + visp_experiment_ids = [e["id"] for e in visp_experiments] + ctx_children = structure_tree.child_ids([isocortex["id"]])[0] - pm = mcc.get_projection_matrix(experiment_ids = visp_experiment_ids, - projection_structure_ids = ctx_children, - hemisphere_ids= [2], # right hemisphere, ipsilateral - parameter = 'projection_density') + pm = mcc.get_projection_matrix( + experiment_ids=visp_experiment_ids, + projection_structure_ids=ctx_children, + hemisphere_ids=[2], # right hemisphere, ipsilateral + parameter="projection_density", + ) - row_labels = pm['rows'] # these are just experiment ids - column_labels = [ c['label'] for c in pm['columns'] ] - matrix = pm['matrix'] + row_labels = pm["rows"] # these are just experiment ids + column_labels = [c["label"] for c in pm["columns"]] + matrix = pm["matrix"] - fig, ax = plt.subplots(figsize=(15,15)) + fig, ax = plt.subplots(figsize=(15, 15)) ax.pcolor(matrix, cmap=plt.cm.afmhot) # put the major ticks at the middle of each cell - ax.set_xticks(np.arange(matrix.shape[1])+0.5, minor=False) - ax.set_yticks(np.arange(matrix.shape[0])+0.5, minor=False) + ax.set_xticks(np.arange(matrix.shape[1]) + 0.5, minor=False) + ax.set_yticks(np.arange(matrix.shape[0]) + 0.5, minor=False) ax.set_xlim([0, matrix.shape[1]]) - ax.set_ylim([0, matrix.shape[0]]) + ax.set_ylim([0, matrix.shape[0]]) # want a more natural, table-like display ax.invert_yaxis() @@ -214,9 +208,9 @@ def test_notebook(tmpdir_factory): ax.set_yticklabels(row_labels, minor=False) # ## Manipulating Grid Data - # + # # The `MouseConnectivityCache` class also helps you download and open every experiment's projection grid data volume. By default it will download 25um volumes, but you could also download data at other resolutions if you prefer (10um, 50um, 100um). - # + # # This demonstrates how you can load the projection density for a particular experiment. It also shows how to download the template volume to which all grid data is registered. Voxels in that template have been structurally annotated by neuroanatomists and the resulting labels stored in a separate annotation volume image. # In[10]: @@ -224,7 +218,6 @@ def test_notebook(tmpdir_factory): # we'll take this experiment - an injection into the primary somatosensory - as an example experiment_id = 181599674 - # In[11]: # projection density: number of projecting pixels / voxel volume @@ -250,7 +243,6 @@ def test_notebook(tmpdir_factory): print(pd_info) print(pd.shape, template.shape, annot.shape) - # Once you have these loaded, you can use matplotlib see what they look like. # In[12]: @@ -263,16 +255,15 @@ def test_notebook(tmpdir_factory): # show that slice of all volumes side-by-side f, pr_axes = plt.subplots(1, 3, figsize=(15, 6)) - pr_axes[0].imshow(pd_mip, cmap='hot', aspect='equal') + pr_axes[0].imshow(pd_mip, cmap="hot", aspect="equal") pr_axes[0].set_title("projection density MaxIP") - pr_axes[1].imshow(ind_mip, cmap='hot', aspect='equal') + pr_axes[1].imshow(ind_mip, cmap="hot", aspect="equal") pr_axes[1].set_title("injection density MaxIP") - pr_axes[2].imshow(inf_mip, cmap='hot', aspect='equal') + pr_axes[2].imshow(inf_mip, cmap="hot", aspect="equal") pr_axes[2].set_title("injection fraction MaxIP") - # In[13]: # Look at a slice from the average template and annotation volumes @@ -282,25 +273,22 @@ def test_notebook(tmpdir_factory): f, ccf_axes = plt.subplots(1, 3, figsize=(15, 6)) - ccf_axes[0].imshow(template[slice_idx,:,:], cmap='gray', aspect='equal', vmin=template.min(), vmax=template.max()) + ccf_axes[0].imshow(template[slice_idx, :, :], cmap="gray", aspect="equal", vmin=template.min(), vmax=template.max()) ccf_axes[0].set_title("registration template") - ccf_axes[1].imshow(annot[slice_idx,:,:], cmap='gray', aspect='equal', vmin=0, vmax=2000) + ccf_axes[1].imshow(annot[slice_idx, :, :], cmap="gray", aspect="equal", vmin=0, vmax=2000) ccf_axes[1].set_title("annotation volume") - ccf_axes[2].imshow(cortex_mask[slice_idx,:,:], cmap='gray', aspect='equal', vmin=0, vmax=1) + ccf_axes[2].imshow(cortex_mask[slice_idx, :, :], cmap="gray", aspect="equal", vmin=0, vmax=1) ccf_axes[2].set_title("isocortex mask") - # On occasion the TissueCyte microscope fails to acquire a tile. In this case the data from that tile should not be used for analysis. The data mask associated with each experiment can be used to determine which portions of the grid data came from correctly acquired tiles. - # + # # In this experiment, a missed tile can be seen in the data mask as a dark warped square. The values in the mask exist within [0, 1], describing the fraction of each voxel that was correctly acquired # In[14]: f, data_mask_axis = plt.subplots(figsize=(5, 6)) - data_mask_axis.imshow(dm[81, :, :], cmap='hot', aspect='equal', vmin=0, vmax=1) - data_mask_axis.set_title('data mask') - - + data_mask_axis.imshow(dm[81, :, :], cmap="hot", aspect="equal", vmin=0, vmax=1) + data_mask_axis.set_title("data mask") diff --git a/allensdk/test/core/test_nwb_data_set.py b/allensdk/test/core/test_nwb_data_set.py index c4b48b5f04..68c1ff80d9 100644 --- a/allensdk/test/core/test_nwb_data_set.py +++ b/allensdk/test/core/test_nwb_data_set.py @@ -42,11 +42,11 @@ NWB_FLAVORS = [] -if 'TEST_EPHYS_NWB_FILES' in os.environ: - nwb_list_file = os.environ['TEST_EPHYS_NWB_FILES'] +if "TEST_EPHYS_NWB_FILES" in os.environ: + nwb_list_file = os.environ["TEST_EPHYS_NWB_FILES"] else: - nwb_list_file = str(files('allensdk.test.core').joinpath('nwb_ephys_files.txt')) -with open(nwb_list_file, 'r') as f: + nwb_list_file = str(files("allensdk.test.core").joinpath("nwb_ephys_files.txt")) +with open(nwb_list_file, "r") as f: NWB_FLAVORS = [x.strip() for x in f] @@ -103,7 +103,7 @@ def mock_h5py_file(m=None, data=None): @pytest.fixture def mock_data_set(): - nwb_file = 'fixture.nwb' + nwb_file = "fixture.nwb" data_set = NwbDataSet(nwb_file) return data_set @@ -124,33 +124,24 @@ def __getitem__(self, item): return self.value h5 = { - 'epochs': { - 'Sweep_1': { - 'response': { - 'timeseries': { - 'data': np.ones(DATA_LENGTH) - } + "epochs": { + "Sweep_1": {"response": {"timeseries": {"data": np.ones(DATA_LENGTH)}}}, + "Experiment_1": { + "stimulus": { + "idx_start": H5Scalar(1), + "count": H5Scalar(3), # truncation is here + "timeseries": {"data": np.ones(DATA_LENGTH)}, } }, - 'Experiment_1': { - 'stimulus': { - 'idx_start': H5Scalar(1), - 'count': H5Scalar(3), # truncation is here - 'timeseries': { - 'data': np.ones(DATA_LENGTH) - } - } - } } } - with patch('h5py.File', mock_h5py_file(data=h5)): + with patch("h5py.File", mock_h5py_file(data=h5)): data_set.fill_sweep_responses(0.0, [1], extend_experiment=True) - assert h5['epochs']['Experiment_1']['stimulus']['count'] == 4 - assert h5['epochs']['Experiment_1']['stimulus']['idx_start'] == 1 - assert np.all( - h5['epochs']['Sweep_1']['response']['timeseries']['data'] == 0.0) + assert h5["epochs"]["Experiment_1"]["stimulus"]["count"] == 4 + assert h5["epochs"]["Experiment_1"]["stimulus"]["idx_start"] == 1 + assert np.all(h5["epochs"]["Sweep_1"]["response"]["timeseries"]["data"] == 0.0) def test_fill_sweep_responses(mock_data_set): @@ -158,39 +149,31 @@ def test_fill_sweep_responses(mock_data_set): DATA_LENGTH = 5 h5 = { - 'stimulus': { - 'presentation': { - 'Sweep_1': { - 'aibs_stimulus_amplitude_pa': 15.0, - 'aibs_stimulus_name': 'Joe', - 'gain': 1.0, - 'initial_access_resistance': 0.05, - 'seal': True + "stimulus": { + "presentation": { + "Sweep_1": { + "aibs_stimulus_amplitude_pa": 15.0, + "aibs_stimulus_name": "Joe", + "gain": 1.0, + "initial_access_resistance": 0.05, + "seal": True, } } }, - 'epochs': { - 'Sweep_1': { - 'description': 'sweep 1 description', - 'stimulus': {}, - 'response': { - 'count': DATA_LENGTH, - 'idx_start': 0, - 'timeseries': { - 'data': np.ones(DATA_LENGTH) * 1.0 - } - } + "epochs": { + "Sweep_1": { + "description": "sweep 1 description", + "stimulus": {}, + "response": {"count": DATA_LENGTH, "idx_start": 0, "timeseries": {"data": np.ones(DATA_LENGTH) * 1.0}}, } - } + }, } - with patch('h5py.File', mock_h5py_file(data=h5)): + with patch("h5py.File", mock_h5py_file(data=h5)): data_set.fill_sweep_responses(0.0, [1]) - assert not np.any( - h5['epochs']['Sweep_1']['response']['timeseries']['data']) - assert len(h5['epochs']['Sweep_1']['response']['timeseries']['data']) == \ - DATA_LENGTH + assert not np.any(h5["epochs"]["Sweep_1"]["response"]["timeseries"]["data"]) + assert len(h5["epochs"]["Sweep_1"]["response"]["timeseries"]["data"]) == DATA_LENGTH @pytest.mark.xfail @@ -199,38 +182,28 @@ def test_set_spike_times(mock_data_set): DATA_LENGTH = 5 h5 = { - 'analysis': { - 'spike_times': { - 'Sweep_1': {} - } - }, - 'stimulus': { - 'presentation': { - 'Sweep_1': { - 'aibs_stimulus_amplitude_pa': 15.0, - 'aibs_stimulus_name': 'Joe', - 'gain': 1.0, - 'initial_access_resistance': 0.05, - 'seal': True + "analysis": {"spike_times": {"Sweep_1": {}}}, + "stimulus": { + "presentation": { + "Sweep_1": { + "aibs_stimulus_amplitude_pa": 15.0, + "aibs_stimulus_name": "Joe", + "gain": 1.0, + "initial_access_resistance": 0.05, + "seal": True, } } }, - 'epochs': { - 'Sweep_1': { - 'description': 'sweep 1 description', - 'stimulus': {}, - 'response': { - 'count': DATA_LENGTH, - 'idx_start': 0, - 'timeseries': { - 'data': np.ones(DATA_LENGTH) * 1.0 - } - } + "epochs": { + "Sweep_1": { + "description": "sweep 1 description", + "stimulus": {}, + "response": {"count": DATA_LENGTH, "idx_start": 0, "timeseries": {"data": np.ones(DATA_LENGTH) * 1.0}}, } - } + }, } - with patch('h5py.File', mock_h5py_file(data=h5)): + with patch("h5py.File", mock_h5py_file(data=h5)): data_set.set_spike_times(1, [0.1, 0.2, 0.3, 0.4, 0.5]) assert False diff --git a/allensdk/test/core/test_obj_utilities.py b/allensdk/test/core/test_obj_utilities.py index c8cafdd605..73b3d034b9 100644 --- a/allensdk/test/core/test_obj_utilities.py +++ b/allensdk/test/core/test_obj_utilities.py @@ -44,7 +44,7 @@ @pytest.fixture def wavefront_obj(): - return ''' + return """ v 8578 5484.96 5227.57 v 8509.2 5487.54 5237.07 @@ -69,26 +69,24 @@ def wavefront_obj(): f 3//3 2//2 5//5 f 6//6 3//3 5//5 - ''' + """ def test_read_obj(wavefront_obj): - - path = 'path!' + path = "path!" # need to patch the version in allensdk.api.cache because of import x from y syntax above - with patch( 'allensdk.core.obj_utilities.open', mock_open(read_data=wavefront_obj), create=True ) as p: + with patch("allensdk.core.obj_utilities.open", mock_open(read_data=wavefront_obj), create=True) as p: obt = read_obj(path) - p.assert_called_with(path, 'r') - assert( obt is not None ) + p.assert_called_with(path, "r") + assert obt is not None def test_parse_obj(wavefront_obj): - - lines = wavefront_obj.split('\n') + lines = wavefront_obj.split("\n") vertices, vertex_normals, face_vertices, face_normals = parse_obj(lines) - - assert(np.allclose( face_vertices, face_normals )) - assert(np.allclose( face_vertices[2, :], [2, 1, 4] )) - assert(np.allclose( vertices[1, :], [8509.2, 5487.54, 5237.07] )) - assert(np.allclose( vertex_normals[2, :], [-0.0880336, -0.0323767, -0.995591] )) + + assert np.allclose(face_vertices, face_normals) + assert np.allclose(face_vertices[2, :], [2, 1, 4]) + assert np.allclose(vertices[1, :], [8509.2, 5487.54, 5237.07]) + assert np.allclose(vertex_normals[2, :], [-0.0880336, -0.0323767, -0.995591]) diff --git a/allensdk/test/core/test_pickle_utils.py b/allensdk/test/core/test_pickle_utils.py index 3300a6f3ef..262d1f3f69 100644 --- a/allensdk/test/core/test_pickle_utils.py +++ b/allensdk/test/core/test_pickle_utils.py @@ -3,8 +3,7 @@ import pathlib import pandas as pd -from allensdk.test_utilities.custom_comparators import ( - stimulus_pickle_equivalence) +from allensdk.test_utilities.custom_comparators import stimulus_pickle_equivalence from allensdk.core.pickle_utils import ( _sanitize_list, @@ -12,7 +11,8 @@ _sanitize_tuple, _sanitize_list_or_tuple, _sanitize_pickle_data, - load_and_sanitize_pickle) + load_and_sanitize_pickle, +) @pytest.fixture @@ -24,9 +24,9 @@ def list_data_fixture(): 'output' maps to the desired list with bytes cast as strings. """ - input_data = [b'happy', 2.3, 5, 'pig', b'banana'] - output_data = ['happy', 2.3, 5, 'pig', 'banana'] - return {'input': input_data, 'output': output_data} + input_data = [b"happy", 2.3, 5, "pig", b"banana"] + output_data = ["happy", 2.3, 5, "pig", "banana"] + return {"input": input_data, "output": output_data} @pytest.fixture @@ -38,17 +38,11 @@ def dict_data_fixture(): 'output' maps to the desired dicts with bytes cast as strings. """ - input_data = {b'the_first': 2.1, - 'the_second': b'funny', - b'the_third': 'two', - b'the_fourth': b'three'} + input_data = {b"the_first": 2.1, "the_second": b"funny", b"the_third": "two", b"the_fourth": b"three"} - output_data = {'the_first': 2.1, - 'the_second': 'funny', - 'the_third': 'two', - 'the_fourth': 'three'} + output_data = {"the_first": 2.1, "the_second": "funny", "the_third": "two", "the_fourth": "three"} - return {'input': input_data, 'output': output_data} + return {"input": input_data, "output": output_data} @pytest.fixture @@ -60,52 +54,39 @@ def nested_dict_data_fixture(): 'output' maps to the desired dicts with bytes cast as strings. """ - input_data = { - b'a': [1, b'b', 2, 'c'], - 'c': {b'd': 4, b'e': b'f'}, - b'g': b'h'} + input_data = {b"a": [1, b"b", 2, "c"], "c": {b"d": 4, b"e": b"f"}, b"g": b"h"} - output_data = { - 'a': [1, 'b', 2, 'c'], - 'c': {'d': 4, 'e': 'f'}, - 'g': 'h'} + output_data = {"a": [1, "b", 2, "c"], "c": {"d": 4, "e": "f"}, "g": "h"} - return {'input': input_data, 'output': output_data} + return {"input": input_data, "output": output_data} @pytest.fixture -def nested_list_data_fixture( - list_data_fixture, - nested_dict_data_fixture): +def nested_list_data_fixture(list_data_fixture, nested_dict_data_fixture): """ Return a dict. 'input' maps to a list that has bytes where we want strings. 'output' maps to the desired list with bytes cast as strings. """ - input_data = [copy.deepcopy(list_data_fixture['input']), - copy.deepcopy(nested_dict_data_fixture['input'])] + input_data = [copy.deepcopy(list_data_fixture["input"]), copy.deepcopy(nested_dict_data_fixture["input"])] - output_data = [copy.deepcopy(list_data_fixture['output']), - copy.deepcopy(nested_dict_data_fixture['output'])] + output_data = [copy.deepcopy(list_data_fixture["output"]), copy.deepcopy(nested_dict_data_fixture["output"])] - return {'input': input_data, 'output': output_data} + return {"input": input_data, "output": output_data} -@pytest.mark.parametrize('nested', [True, False]) -def test_sanitize_list( - list_data_fixture, - nested_list_data_fixture, - nested): +@pytest.mark.parametrize("nested", [True, False]) +def test_sanitize_list(list_data_fixture, nested_list_data_fixture, nested): """ Test that _sanitize_list behaves well on an un-nested list """ if nested: - input_data = copy.deepcopy(nested_list_data_fixture['input']) - output_data = nested_list_data_fixture['output'] + input_data = copy.deepcopy(nested_list_data_fixture["input"]) + output_data = nested_list_data_fixture["output"] else: - input_data = copy.deepcopy(list_data_fixture['input']) - output_data = list_data_fixture['output'] + input_data = copy.deepcopy(list_data_fixture["input"]) + output_data = list_data_fixture["output"] for_later = copy.deepcopy(input_data) actual = _sanitize_list(input_data) @@ -113,20 +94,17 @@ def test_sanitize_list( assert not actual == for_later -@pytest.mark.parametrize('nested', [True, False]) -def test_sanitize_dict( - dict_data_fixture, - nested_dict_data_fixture, - nested): +@pytest.mark.parametrize("nested", [True, False]) +def test_sanitize_dict(dict_data_fixture, nested_dict_data_fixture, nested): """ Test that _sanitize_dict behaves well on un-nested dict """ if nested: - input_data = copy.deepcopy(nested_dict_data_fixture['input']) - output_data = nested_dict_data_fixture['output'] + input_data = copy.deepcopy(nested_dict_data_fixture["input"]) + output_data = nested_dict_data_fixture["output"] else: - input_data = copy.deepcopy(dict_data_fixture['input']) - output_data = dict_data_fixture['output'] + input_data = copy.deepcopy(dict_data_fixture["input"]) + output_data = dict_data_fixture["output"] for_later = copy.deepcopy(input_data) actual = _sanitize_dict(input_data) assert actual == output_data @@ -137,15 +115,9 @@ def test_sanitize_tuple(): """ Test that _sanitize_tuple works as expected """ - input_data = (['a', b'b', 'c', 2], - 'cat', - b'dog', - {b'd': 2, 'e': 3, b'f': b'g'}) + input_data = (["a", b"b", "c", 2], "cat", b"dog", {b"d": 2, "e": 3, b"f": b"g"}) - expected_data = (['a', 'b', 'c', 2], - 'cat', - 'dog', - {'d': 2, 'e': 3, 'f': 'g'}) + expected_data = (["a", "b", "c", 2], "cat", "dog", {"d": 2, "e": 3, "f": "g"}) actual = _sanitize_tuple(input_data) assert actual == expected_data @@ -155,45 +127,29 @@ def test_sanitize_list_or_tuple(): """ Test that _sanitize_list_or_tuple works as expected """ - input_data = (['a', b'b', 'c', 2], - 'cat', - b'dog', - {b'd': 2, 'e': 3, b'f': b'g'}) + input_data = (["a", b"b", "c", 2], "cat", b"dog", {b"d": 2, "e": 3, b"f": b"g"}) - expected_data = (['a', 'b', 'c', 2], - 'cat', - 'dog', - {'d': 2, 'e': 3, 'f': 'g'}) + expected_data = (["a", "b", "c", 2], "cat", "dog", {"d": 2, "e": 3, "f": "g"}) actual = _sanitize_list_or_tuple(input_data) assert actual == expected_data - input_data = [['h', b'i', 'j', 2], - 'frog', - b'fly', - {b'k': 2, 'l': 3, b'm': b'n'}] + input_data = [["h", b"i", "j", 2], "frog", b"fly", {b"k": 2, "l": 3, b"m": b"n"}] - expected_data = [['h', 'i', 'j', 2], - 'frog', - 'fly', - {'k': 2, 'l': 3, 'm': 'n'}] + expected_data = [["h", "i", "j", 2], "frog", "fly", {"k": 2, "l": 3, "m": "n"}] actual = _sanitize_list_or_tuple(input_data) assert actual == expected_data -def test_sanitize_pickle_data( - nested_list_data_fixture, - nested_dict_data_fixture): +def test_sanitize_pickle_data(nested_list_data_fixture, nested_dict_data_fixture): """ Test user-facing sanitization method """ - actual = _sanitize_pickle_data( - copy.deepcopy(nested_list_data_fixture['input'])) - assert actual == nested_list_data_fixture['output'] + actual = _sanitize_pickle_data(copy.deepcopy(nested_list_data_fixture["input"])) + assert actual == nested_list_data_fixture["output"] - actual = _sanitize_pickle_data( - copy.deepcopy(nested_dict_data_fixture['input'])) - assert actual == nested_dict_data_fixture['output'] + actual = _sanitize_pickle_data(copy.deepcopy(nested_dict_data_fixture["input"])) + assert actual == nested_dict_data_fixture["output"] def test_load_and_sanitize_error(): @@ -201,11 +157,11 @@ def test_load_and_sanitize_error(): Make sure load_and_sanitize_pickle raises an error if given something that is neither .pkl or .gz """ - with pytest.raises(ValueError, match='Can open .pkl and .gz'): - load_and_sanitize_pickle(pickle_path='junk.txt') + with pytest.raises(ValueError, match="Can open .pkl and .gz"): + load_and_sanitize_pickle(pickle_path="junk.txt") - with pytest.raises(ValueError, match='Can open .pkl and .gz'): - load_and_sanitize_pickle(pickle_path=pathlib.Path('junk.txt')) + with pytest.raises(ValueError, match="Can open .pkl and .gz"): + load_and_sanitize_pickle(pickle_path=pathlib.Path("junk.txt")) def test_local_pickle_equivalence(): @@ -215,12 +171,11 @@ def test_local_pickle_equivalence(): """ this_dir = pathlib.Path(__file__).parent.parent - pkl_path = this_dir / 'brain_observatory/behavior/resources' - pkl_path = pkl_path / 'example_stimulus.pkl.gz' + pkl_path = this_dir / "brain_observatory/behavior/resources" + pkl_path = pkl_path / "example_stimulus.pkl.gz" pd_data = pd.read_pickle(pkl_path) - sanitized_data = load_and_sanitize_pickle( - pickle_path=pkl_path) + sanitized_data = load_and_sanitize_pickle(pickle_path=pkl_path) assert stimulus_pickle_equivalence(sanitized_data, pd_data) @@ -233,9 +188,9 @@ def test_pickle_equivalence(): file_path = pathlib.Path( "/allen/programs/braintv/production/visualbehavior/" - "prod2/specimen_850862430/behavior_session_951520319/951410079.pkl") + "prod2/specimen_850862430/behavior_session_951520319/951410079.pkl" + ) pd_data = pd.read_pickle(file_path) - sanitized_data = load_and_sanitize_pickle( - pickle_path=file_path) + sanitized_data = load_and_sanitize_pickle(pickle_path=file_path) assert stimulus_pickle_equivalence(sanitized_data, pd_data) diff --git a/allensdk/test/core/test_reference_space.py b/allensdk/test/core/test_reference_space.py index c2fd502b5c..2faf4534b0 100644 --- a/allensdk/test/core/test_reference_space.py +++ b/allensdk/test/core/test_reference_space.py @@ -47,14 +47,15 @@ @pytest.fixture def rsp(): - - tree = [{'id': 1, 'structure_id_path': [1]}, - {'id': 2, 'structure_id_path': [1, 2]}, - {'id': 3, 'structure_id_path': [1, 3]}, - {'id': 4, 'structure_id_path': [1, 2, 4]}, - {'id': 5, 'structure_id_path': [1, 2, 5]}, - {'id': 6, 'structure_id_path': [1, 2, 5, 6]}, - {'id': 7, 'structure_id_path': [1, 7]}] + tree = [ + {"id": 1, "structure_id_path": [1]}, + {"id": 2, "structure_id_path": [1, 2]}, + {"id": 3, "structure_id_path": [1, 3]}, + {"id": 4, "structure_id_path": [1, 2, 4]}, + {"id": 5, "structure_id_path": [1, 2, 5]}, + {"id": 6, "structure_id_path": [1, 2, 5, 6]}, + {"id": 7, "structure_id_path": [1, 7]}, + ] # leaves are 6, 4, 3 # additionally annotate 2, 5 for realism :) @@ -71,8 +72,8 @@ def rsp(): @pytest.fixture def itksnap_rsp(): tree = [ - {'id': 1, 'rgb_triplet': [1, 2, 3], 'acronym': 'b', 'structure_id_path': [1]}, - {'id': 5000, 'rgb_triplet': [4, 5, 6], 'acronym': 'a', 'structure_id_path': [1, 5000]}, + {"id": 1, "rgb_triplet": [1, 2, 3], "acronym": "b", "structure_id_path": [1]}, + {"id": 5000, "rgb_triplet": [4, 5, 6], "acronym": "a", "structure_id_path": [1, 5000]}, ] annotation = np.zeros((10, 10, 10)) @@ -85,116 +86,111 @@ def itksnap_rsp(): def test_direct_voxel_counts(rsp): obt_one = rsp.direct_voxel_map obt_two = rsp.direct_voxel_map - - assert( obt_one[3] == 8 ) - assert( obt_one[2] == 4**3 - 2**3 - 1 ) - assert( obt_two[1] == 0 ) - assert( obt_two[2] == 4**3 - 2**3 - 1 ) - -def test_total_voxel_counts(rsp): + assert obt_one[3] == 8 + assert obt_one[2] == 4**3 - 2**3 - 1 + assert obt_two[1] == 0 + assert obt_two[2] == 4**3 - 2**3 - 1 + +def test_total_voxel_counts(rsp): obt = rsp.total_voxel_map - - assert( obt[2] == 4**3 ) - assert( obt[6] == 4 ) - - -def test_remove_unassigned(rsp): + assert obt[2] == 4**3 + assert obt[6] == 4 + + +def test_remove_unassigned(rsp): rsp.remove_unassigned() node_ids = rsp.structure_tree.node_ids() - - assert( 1 in node_ids ) - assert( 7 not in node_ids ) - - -def test_make_structure_mask(rsp): + assert 1 in node_ids + assert 7 not in node_ids + + +def test_make_structure_mask(rsp): exp = np.zeros((10, 10, 10)) exp[4:8, 4:8, 4:8] = 1 exp[8:10, 8:10, 8:10] = 1 obt = rsp.make_structure_mask([2, 3, 7]) - assert( np.allclose(obt, exp) ) - - -def test_make_structure_mask_direct(rsp): + assert np.allclose(obt, exp) + +def test_make_structure_mask_direct(rsp): exp = np.zeros((10, 10, 10)) exp[5:7, 5:7, 6:7] = 1 obt = rsp.make_structure_mask([5], True) - assert( np.allclose(obt, exp) ) - - -def test_many_structure_masks(rsp): + assert np.allclose(obt, exp) + +def test_many_structure_masks(rsp): cb = mock.MagicMock() [ii for ii in rsp.many_structure_masks([2, 3], output_cb=cb)] - assert( cb.call_count == 2 ) + assert cb.call_count == 2 + - def test_many_structure_masks_default_cb(rsp): - rsp.make_structure_mask = mock.MagicMock(return_value=2) for item in rsp.many_structure_masks([1]): - assert( np.allclose(item, [1, 2]) ) - - + assert np.allclose(item, [1, 2]) + + def test_check_coverage(rsp): - mask = np.zeros((10, 10, 10)) mask[7:10, 7:10, 7:10] = 1 - + obt = rsp.check_coverage([3], mask) - assert( np.count_nonzero(obt) == 27 - 8 ) - - -def test_validate_structures(rsp): + assert np.count_nonzero(obt) == 27 - 8 + +def test_validate_structures(rsp): rsp.structure_tree.has_overlaps = mock.MagicMock() rsp.check_coverage = mock.MagicMock() - + rsp.validate_structures(1, 2) - + rsp.structure_tree.has_overlaps.assert_called_with(1) rsp.check_coverage.assert_called_with(1, 2) - -def test_downsample(rsp): +def test_downsample(rsp): target = rsp.downsample((10, 20, 20)) - assert( np.allclose(target.annotation.shape, [10, 5, 5]) ) + assert np.allclose(target.annotation.shape, [10, 5, 5]) def test_get_slice_image(rsp): - - cmap = {0: [0, 0, 0], 1: [0, 0, 0], 2: [0, 0, 0], 3: [1, 2, 3], - 4: [0, 0, 0], 5: [0, 0, 0], 6: [0, 0, 0], 7: [0, 0, 0], } + cmap = { + 0: [0, 0, 0], + 1: [0, 0, 0], + 2: [0, 0, 0], + 3: [1, 2, 3], + 4: [0, 0, 0], + 5: [0, 0, 0], + 6: [0, 0, 0], + 7: [0, 0, 0], + } image = rsp.get_slice_image(0, 90, cmap=cmap) - assert( image[:, :, 0].sum() == 4 ) + assert image[:, :, 0].sum() == 4 def test_direct_voxel_map_setter(rsp): - rsp.direct_voxel_map = 4 - assert( rsp.direct_voxel_map == 4 ) + assert rsp.direct_voxel_map == 4 def test_total_voxel_map_setter(rsp): - rsp.total_voxel_map = 3 - assert( rsp.total_voxel_map == 3 ) + assert rsp.total_voxel_map == 3 def test_export_itksnap_labels(itksnap_rsp): - annot, labels = itksnap_rsp.export_itksnap_labels(id_type=np.uint8) exp = np.zeros((10, 10, 10)) @@ -202,16 +198,15 @@ def test_export_itksnap_labels(itksnap_rsp): exp[:, :, 7:] = 1 assert set(np.unique(annot)) == set([0, 1, 2]) - assert np.array_equal(labels['LABEL'][:], ['a', 'b']) - assert set(labels['IDX'].values) == set([1, 2]) + assert np.array_equal(labels["LABEL"][:], ["a", "b"]) + assert set(labels["IDX"].values) == set([1, 2]) assert np.allclose(exp, annot) def test_write_itksnap_labels(itksnap_rsp, tmpdir_factory): - - tmpdir = str(tmpdir_factory.mktemp('test_write_itksnap_labels')) - annot_path = os.path.join(tmpdir, 'annot.nrrd') - labels_path = os.path.join(tmpdir, 'labels.csv') + tmpdir = str(tmpdir_factory.mktemp("test_write_itksnap_labels")) + annot_path = os.path.join(tmpdir, "annot.nrrd") + labels_path = os.path.join(tmpdir, "labels.csv") itksnap_rsp.write_itksnap_labels(annot_path, labels_path, id_type=np.uint8) exp_annot, exp_labels = itksnap_rsp.export_itksnap_labels(id_type=np.uint8) @@ -220,13 +215,12 @@ def test_write_itksnap_labels(itksnap_rsp, tmpdir_factory): assert np.allclose(obt_annot, exp_annot) obt_labels = pd.read_csv( - labels_path, - delim_whitespace=True, - names=['IDX', '-R-', '-G-', '-B-', '-A-', 'VIS', 'MSH', 'LABEL'], - index_col=False + labels_path, + delim_whitespace=True, + names=["IDX", "-R-", "-G-", "-B-", "-A-", "VIS", "MSH", "LABEL"], + index_col=False, ) pd.testing.assert_frame_equal(obt_labels, exp_labels, check_index_type=False) assert os.path.exists(labels_path) assert os.path.exists(annot_path) - diff --git a/allensdk/test/core/test_reference_space_cache.py b/allensdk/test/core/test_reference_space_cache.py index db18a2f360..16b7c6d564 100644 --- a/allensdk/test/core/test_reference_space_cache.py +++ b/allensdk/test/core/test_reference_space_cache.py @@ -47,7 +47,7 @@ @pytest.fixture() def rsp_version(): - return 'annotation/look_a_version' + return "annotation/look_a_version" @pytest.fixture() @@ -55,42 +55,48 @@ def resolution(): return 25 -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def old_nodes(): - - return [{'id': 0, 'structure_id_path': '/0/', - 'color_hex_triplet': '000000', 'acronym': 'rt', - 'name': 'root', 'parent_structure_id': 12}] - - -@pytest.fixture(scope='function') + return [ + { + "id": 0, + "structure_id_path": "/0/", + "color_hex_triplet": "000000", + "acronym": "rt", + "name": "root", + "parent_structure_id": 12, + } + ] + + +@pytest.fixture(scope="function") def new_nodes(): - - return [{'id': 0, 'structure_id_path': '/0/', - 'color_hex_triplet': '000000', 'acronym': 'rt', - 'name': 'root', 'structure_sets':[{'id': 1}, {'id': 4}, {'id': 167587189}] }] - - -@pytest.fixture(scope='function') + return [ + { + "id": 0, + "structure_id_path": "/0/", + "color_hex_triplet": "000000", + "acronym": "rt", + "name": "root", + "structure_sets": [{"id": 1}, {"id": 4}, {"id": 167587189}], + } + ] + + +@pytest.fixture(scope="function") def rsp(fn_temp_dir, rsp_version, resolution): - - manifest_path = os.path.join(fn_temp_dir, 'manifest.json') - return ReferenceSpaceCache(reference_space_key=rsp_version, - resolution=resolution, - manifest=manifest_path) - + manifest_path = os.path.join(fn_temp_dir, "manifest.json") + return ReferenceSpaceCache(reference_space_key=rsp_version, resolution=resolution, manifest=manifest_path) def test_init(rsp, fn_temp_dir): - - manifest_path = os.path.join(fn_temp_dir, 'manifest.json') - assert( os.path.exists(manifest_path) ) + manifest_path = os.path.join(fn_temp_dir, "manifest.json") + assert os.path.exists(manifest_path) def test_get_annotation_volume(rsp, fn_temp_dir, rsp_version, resolution): - eye = np.eye(100) - path = os.path.join(fn_temp_dir, rsp_version, 'annotation_{0}.nrrd'.format(resolution)) + path = os.path.join(fn_temp_dir, rsp_version, "annotation_{0}.nrrd".format(resolution)) rsp.api.retrieve_file_over_http = lambda a, b: nrrd.write(b, eye) obtained, _ = rsp.get_annotation_volume() @@ -99,14 +105,13 @@ def test_get_annotation_volume(rsp, fn_temp_dir, rsp_version, resolution): rsp.get_annotation_volume() rsp.api.retrieve_file_over_http.assert_not_called() - assert( np.allclose(obtained, eye) ) - assert( os.path.exists(path) ) + assert np.allclose(obtained, eye) + assert os.path.exists(path) def test_get_template_volume(rsp, fn_temp_dir, resolution): - eye = np.eye(100) - path = os.path.join(fn_temp_dir, 'average_template_{0}.nrrd'.format(resolution)) + path = os.path.join(fn_temp_dir, "average_template_{0}.nrrd".format(resolution)) rsp.api.retrieve_file_over_http = lambda a, b: nrrd.write(b, eye) obtained, _ = rsp.get_template_volume() @@ -115,52 +120,45 @@ def test_get_template_volume(rsp, fn_temp_dir, resolution): rsp.get_template_volume() rsp.api.retrieve_file_over_http.assert_not_called() - assert( np.allclose(obtained, eye) ) - assert( os.path.exists(path) ) + assert np.allclose(obtained, eye) + assert os.path.exists(path) def test_get_structure_tree(rsp, fn_temp_dir, new_nodes): + path = os.path.join(fn_temp_dir, "structures.json") - path = os.path.join(fn_temp_dir, 'structures.json') - - with mock.patch('allensdk.api.queries.ontologies_api.' - 'OntologiesApi.model_query', - return_value=new_nodes) as p: - + with mock.patch("allensdk.api.queries.ontologies_api.OntologiesApi.model_query", return_value=new_nodes) as p: obtained = rsp.get_structure_tree() rsp.get_structure_tree() p.assert_called_once() - assert(obtained.node_ids()[0] == 0) - + assert obtained.node_ids()[0] == 0 + cm_obt = obtained.get_colormap() - assert(len(cm_obt[0]) == 3) + assert len(cm_obt[0]) == 3 - assert( os.path.exists(path) ) + assert os.path.exists(path) def test_get_reference_space(rsp, new_nodes): - tree = StructureTree(StructureTree.clean_structures(new_nodes)) rsp.get_structure_tree = lambda *a, **k: tree annot = np.arange(125).reshape((5, 5, 5)) - rsp.get_annotation_volume = lambda *a, **k: (annot, 'foo') + rsp.get_annotation_volume = lambda *a, **k: (annot, "foo") rsp_obt = rsp.get_reference_space() - assert( np.allclose(rsp_obt.resolution, [25, 25, 25]) ) - assert( np.allclose( rsp_obt.annotation, annot ) ) + assert np.allclose(rsp_obt.resolution, [25, 25, 25]) + assert np.allclose(rsp_obt.annotation, annot) def test_get_structure_mask(rsp, fn_temp_dir, rsp_version): - sid = 12 eye = np.eye(100) - path = os.path.join(fn_temp_dir, rsp_version, 'structure_masks', - 'resolution_25', 'structure_{0}.nrrd'.format(sid)) + path = os.path.join(fn_temp_dir, rsp_version, "structure_masks", "resolution_25", "structure_{0}.nrrd".format(sid)) rsp.api.retrieve_file_over_http = lambda a, b: nrrd.write(b, eye) obtained, _ = rsp.get_structure_mask(sid) @@ -169,19 +167,18 @@ def test_get_structure_mask(rsp, fn_temp_dir, rsp_version): rsp.get_structure_mask(sid) rsp.api.retrieve_file_over_http.assert_not_called() - assert( np.allclose(obtained, eye) ) - assert( os.path.exists(path) ) + assert np.allclose(obtained, eye) + assert os.path.exists(path) def test_get_structure_mesh(rsp, fn_temp_dir, rsp_version): - sid = 12 - path = os.path.join(fn_temp_dir, rsp_version, 'structure_meshes','structure_{0}.obj'.format(sid)) + path = os.path.join(fn_temp_dir, rsp_version, "structure_meshes", "structure_{0}.obj".format(sid)) def write_obj(path): - with open(path, 'w') as fil: - fil.write('vn 1 2 4') + with open(path, "w") as fil: + fil.write("vn 1 2 4") expected = [1, 2, 4] @@ -192,31 +189,27 @@ def write_obj(path): rsp.get_structure_mesh(sid) rsp.api.retrieve_file_over_http.assert_not_called() - assert( np.allclose(obtained[1], expected) ) - assert( os.path.exists(path) ) + assert np.allclose(obtained[1], expected) + assert os.path.exists(path) -@pytest.mark.parametrize('inp,fails', [(1, False), - (pd.Series([2]), False), - ('qwerty', True)]) +@pytest.mark.parametrize("inp,fails", [(1, False), (pd.Series([2]), False), ("qwerty", True)]) def test_validate_structure_id(inp, fails): - if fails: with pytest.raises(ValueError): ReferenceSpaceCache.validate_structure_id(inp) else: out = ReferenceSpaceCache.validate_structure_id(inp) - assert( out == int(inp) ) + assert out == int(inp) -@pytest.mark.parametrize('inp,fails', [([1, 2, 3], False), - ([pd.Series([2]), pd.Series([3])], False), - (['qwerty', 1], True)]) +@pytest.mark.parametrize( + "inp,fails", [([1, 2, 3], False), ([pd.Series([2]), pd.Series([3])], False), (["qwerty", 1], True)] +) def test_validate_structure_ids(inp, fails): - if fails: with pytest.raises(ValueError): ReferenceSpaceCache.validate_structure_ids(inp) else: out = ReferenceSpaceCache.validate_structure_ids(inp) - assert( out == list(map(int, inp)) ) + assert out == list(map(int, inp)) diff --git a/allensdk/test/core/test_reference_space_notebook.py b/allensdk/test/core/test_reference_space_notebook.py index 701892c589..cb33013dc5 100644 --- a/allensdk/test/core/test_reference_space_notebook.py +++ b/allensdk/test/core/test_reference_space_notebook.py @@ -40,7 +40,6 @@ @pytest.mark.nightly def test_notebook(tmpdir_factory): - # coding: utf-8 # # Reference Space @@ -66,16 +65,14 @@ def test_notebook(tmpdir_factory): structure_graph = oapi.get_structures_with_sets([1]) # 1 is the id of the adult mouse structure graph # This removes some unused fields returned by the query - structure_graph = StructureTree.clean_structures(structure_graph) + structure_graph = StructureTree.clean_structures(structure_graph) tree = StructureTree(structure_graph) - # In[2]: # now let's take a look at a structure - tree.get_structures_by_name(['Dorsal auditory area']) - + tree.get_structures_by_name(["Dorsal auditory area"]) # The fields are: # * acronym: a shortened name for the structure @@ -94,7 +91,6 @@ def test_notebook(tmpdir_factory): # get a structure's parent tree.parent([1011]) - # In[4]: # get a dictionary mapping structure ids to names @@ -102,7 +98,6 @@ def test_notebook(tmpdir_factory): name_map = tree.get_name_map() name_map[247] - # In[5]: # ask whether one structure is contained within another @@ -110,20 +105,18 @@ def test_notebook(tmpdir_factory): strida = 385 stridb = 247 - is_desc = '' if tree.structure_descends_from(385, 247) else ' not' - - print( '{0} is{1} in {2}'.format(name_map[strida], is_desc, name_map[stridb]) ) + is_desc = "" if tree.structure_descends_from(385, 247) else " not" + print("{0} is{1} in {2}".format(name_map[strida], is_desc, name_map[stridb])) # In[6]: # build a custom map that looks up acronyms by ids - # the syntax here is just a pair of node-wise functions. + # the syntax here is just a pair of node-wise functions. # The first one returns keys while the second one returns values - acronym_map = tree.value_map(lambda x: x['id'], lambda y: y['acronym']) - print( acronym_map[385] ) - + acronym_map = tree.value_map(lambda x: x["id"], lambda y: y["acronym"]) + print(acronym_map[385]) # ## Downloading an annotation volume # @@ -139,27 +132,25 @@ def test_notebook(tmpdir_factory): from allensdk.api.queries.mouse_connectivity_api import MouseConnectivityApi # the annotation download writes a file, so we will need somwhere to put it - annotation_dir = str(tmpdir_factory.mktemp('annotation')) + annotation_dir = str(tmpdir_factory.mktemp("annotation")) - annotation_path = os.path.join(annotation_dir, 'annotation.nrrd') + annotation_path = os.path.join(annotation_dir, "annotation.nrrd") mcapi = MouseConnectivityApi() - mcapi.download_annotation_volume('annotation/ccf_2016', 25, annotation_path) + mcapi.download_annotation_volume("annotation/ccf_2016", 25, annotation_path) annotation, meta = nrrd.read(annotation_path) - # ## Constructing a ReferenceSpace # In[8]: from allensdk.core.reference_space import ReferenceSpace - # build a reference space from a StructureTree and annotation volume, the third argument is + # build a reference space from a StructureTree and annotation volume, the third argument is # the resolution of the space in microns rsp = ReferenceSpace(tree, annotation, [25, 25, 25]) - # ## Using a ReferenceSpace # #### making structure masks @@ -173,14 +164,13 @@ def test_notebook(tmpdir_factory): # view in coronal section - # What if you want a mask for a whole collection of ontologically disparate structures? Just pass more structure ids to make_structure_masks: # In[10]: # This gets all of the structures targeted by the Allen Brain Observatory project brain_observatory_structures = rsp.structure_tree.get_structures_by_set_id([514166994]) - brain_observatory_ids = [st['id'] for st in brain_observatory_structures] + brain_observatory_ids = [st["id"] for st in brain_observatory_structures] rsp.make_structure_mask(brain_observatory_ids) @@ -192,21 +182,20 @@ def test_notebook(tmpdir_factory): import functools - # Define a wrapper function that will control the mask generation. - # This one checks for a nrrd file in the specified base directory + # Define a wrapper function that will control the mask generation. + # This one checks for a nrrd file in the specified base directory # and builds/writes the mask only if one does not exist mask_writer = functools.partial(ReferenceSpace.check_and_write, annotation_dir) - + # many_structure_masks is a generator - nothing has actrually been run yet mask_generator = rsp.many_structure_masks([385, 1097], mask_writer) # consume the resulting iterator to make and write the masks for structure_id in mask_generator: - print( 'made mask for structure {0}.'.format(structure_id) ) + print("made mask for structure {0}.".format(structure_id)) os.listdir(annotation_dir) - # #### Removing unassigned structures # A structure graph may contain structures that are not used in a particular reference space. Having these around can complicate use of the reference space, so we generally want to remove them. @@ -216,8 +205,8 @@ def test_notebook(tmpdir_factory): # In[12]: # Double-check the voxel counts - no_voxel_id = rsp.structure_tree.get_structures_by_name(['Somatosensory areas, layer 6a'])[0]['id'] - print( 'voxel count for structure {0}: {1}'.format(no_voxel_id, rsp.total_voxel_map[no_voxel_id]) ) + no_voxel_id = rsp.structure_tree.get_structures_by_name(["Somatosensory areas, layer 6a"])[0]["id"] + print("voxel count for structure {0}: {1}".format(no_voxel_id, rsp.total_voxel_map[no_voxel_id])) # remove unassigned structures from the ReferenceSpace's StructureTree rsp.remove_unassigned() @@ -225,13 +214,10 @@ def test_notebook(tmpdir_factory): # check the structure tree no_voxel_id in rsp.structure_tree.node_ids() - # #### View a slice from the annotation # In[13]: - - # #### Downsample the space # # If you want an annotation at a resolution we don't provide, you can make one with the downsample method. @@ -242,20 +228,18 @@ def test_notebook(tmpdir_factory): target_resolution = [75, 75, 75] - # in some versions of scipy, scipy.ndimage.zoom raises a helpful but distracting - # warning about the method used to truncate integers. - warnings.simplefilter('ignore') + # in some versions of scipy, scipy.ndimage.zoom raises a helpful but distracting + # warning about the method used to truncate integers. + warnings.simplefilter("ignore") sf_rsp = rsp.downsample(target_resolution) # re-enable warnings - warnings.simplefilter('default') - - print( rsp.annotation.shape ) - print( sf_rsp.annotation.shape ) + warnings.simplefilter("default") + print(rsp.annotation.shape) + print(sf_rsp.annotation.shape) # Now view the downsampled space: # In[15]: - diff --git a/allensdk/test/core/test_simple_tree.py b/allensdk/test/core/test_simple_tree.py index cc008e63f6..4ce48d2e27 100644 --- a/allensdk/test/core/test_simple_tree.py +++ b/allensdk/test/core/test_simple_tree.py @@ -41,155 +41,148 @@ @pytest.fixture def tree(): - s = frozenset([1, 2, 3]) - nodes = [{'id': 0, 'parent': None, 1: 2, s: 'a'}, {'id': 1, 'parent': 0, 1: 7, s: 'd'}, - {'id': 2, 'parent': 0, 1: 3, s: 'b'}, {'id': 3, 'parent': 1, 1: 6, s: 'e'}, - {'id': 4, 'parent': 1, 1: 4, s: 'c'}, {'id': 5, 'parent': 2, 1: 5, s: 'f'}] - + nodes = [ + {"id": 0, "parent": None, 1: 2, s: "a"}, + {"id": 1, "parent": 0, 1: 7, s: "d"}, + {"id": 2, "parent": 0, 1: 3, s: "b"}, + {"id": 3, "parent": 1, 1: 6, s: "e"}, + {"id": 4, "parent": 1, 1: 4, s: "c"}, + {"id": 5, "parent": 2, 1: 5, s: "f"}, + ] + def parent_fn(node): - return node['parent'] + return node["parent"] + def id_fn(node): - return node['id'] - + return node["id"] + return SimpleTree(nodes, id_fn, parent_fn) - - -def test_initialization(tree): - assert( None in tree._parent_ids.values() ) - assert( len(tree._child_ids) == 6 ) - - -def test_filter_nodes(tree): - - two_par = tree.filter_nodes(lambda node: node['parent'] == 2) - assert( two_par[0]['id'] == 5 ) - assert( len(two_par) == 1 ) + +def test_initialization(tree): + assert None in tree._parent_ids.values() + assert len(tree._child_ids) == 6 -@pytest.mark.parametrize('key,val,to,exp', [ ['id', [2, 1, 3], lambda x: x['id'],[2, 1, 3]], - [lambda x: x['id'], [2, 1, 3], lambda x: x['id'],[2, 1, 3]], - [1, [3, 7, 6], lambda x: x['id'],[2, 1, 3]], - [frozenset([1, 2, 3]), ['b'], lambda x: x[1], [3]] ]) +def test_filter_nodes(tree): + two_par = tree.filter_nodes(lambda node: node["parent"] == 2) + assert two_par[0]["id"] == 5 + assert len(two_par) == 1 + + +@pytest.mark.parametrize( + "key,val,to,exp", + [ + ["id", [2, 1, 3], lambda x: x["id"], [2, 1, 3]], + [lambda x: x["id"], [2, 1, 3], lambda x: x["id"], [2, 1, 3]], + [1, [3, 7, 6], lambda x: x["id"], [2, 1, 3]], + [frozenset([1, 2, 3]), ["b"], lambda x: x[1], [3]], + ], +) def test_nodes_by_property(tree, key, val, to, exp): + obt = tree.nodes_by_property(key, val, to_fn=to) + assert allclose(obt, exp) - obt = tree.nodes_by_property( key, val, to_fn=to ) - assert( allclose( obt, exp) ) - def test_value_map(tree): - - parent_map = tree.value_map(lambda node: node['id'], - lambda node: node['parent']) - - assert( len(parent_map) == 6 ) - assert( parent_map[2] == 0 ) - assert( parent_map[3] == 1 ) - + parent_map = tree.value_map(lambda node: node["id"], lambda node: node["parent"]) + + assert len(parent_map) == 6 + assert parent_map[2] == 0 + assert parent_map[3] == 1 + def test_value_map_nonunique(tree): - - with pytest.raises( RuntimeError ): - tree.value_map(lambda node: node['parent'], - lambda node: node['id']) + with pytest.raises(RuntimeError): + tree.value_map(lambda node: node["parent"], lambda node: node["id"]) - -def test_node_ids(tree): +def test_node_ids(tree): obtained = tree.node_ids() expected = range(6) - - assert( set(obtained) == set(expected) ) - - -def test_parent_ids(tree): + assert set(obtained) == set(expected) + + +def test_parent_ids(tree): nodes = [5, 4, 2] obtained = tree.parent_ids(nodes) - - assert( allclose([2, 1, 0], obtained) ) - - -def test_child_ids(tree): + assert allclose([2, 1, 0], obtained) + + +def test_child_ids(tree): obtained = tree.child_ids([1]) - assert( set(obtained[0]) == set([4, 3]) ) - assert( len(obtained) == 1 ) - - -def test_ancestor_ids(tree): + assert set(obtained[0]) == set([4, 3]) + assert len(obtained) == 1 + +def test_ancestor_ids(tree): obtained = tree.ancestor_ids([5, 1]) - - assert( len(obtained) == 2 ) - assert( set(obtained[0]) == set([5, 2, 0]) ) - assert( set(obtained[1]) == set([1, 0]) ) - - -def test_descendant_ids(tree): + assert len(obtained) == 2 + assert set(obtained[0]) == set([5, 2, 0]) + assert set(obtained[1]) == set([1, 0]) + + +def test_descendant_ids(tree): obtained = tree.descendant_ids([0, 3]) - assert( len(obtained) == 2 ) - assert( set(obtained[0]) == set(range(6)) ) - assert( set(obtained[1]) == set([3]) ) - - + assert len(obtained) == 2 + assert set(obtained[0]) == set(range(6)) + assert set(obtained[1]) == set([3]) + + def test_nodes(tree): - obtained = tree.nodes([0, 1]) - - assert( len(obtained) == 2 ) - assert( obtained[0]['parent'] is None ) - assert( obtained[1]['id'] == 1 ) - - -def test_nodes_default(tree): + assert len(obtained) == 2 + assert obtained[0]["parent"] is None + assert obtained[1]["id"] == 1 + + +def test_nodes_default(tree): obtained = tree.nodes() - assert( len(obtained) == 6 ) - + assert len(obtained) == 6 -def test_parents(tree): +def test_parents(tree): obtained = tree.parents([0, 1]) - assert( len(obtained) == 2 ) - assert( obtained[0] is None ) + assert len(obtained) == 2 + assert obtained[0] is None + def test_children(tree): - obtained = tree.children([0, 5]) - assert( len(obtained) == 2 ) - assert( set(obtained[1]) == set([]) ) - assert( len(obtained[0]) == 2 ) - assert( isinstance(obtained[0][0], dict) ) - -def test_descendants(tree): + assert len(obtained) == 2 + assert set(obtained[1]) == set([]) + assert len(obtained[0]) == 2 + assert isinstance(obtained[0][0], dict) + +def test_descendants(tree): obtained = tree.descendants([0, 3]) - assert( len(obtained) == 2 ) - assert( len(obtained[0]) == 6 ) - assert( obtained[1][0]['id'] == 3 ) - assert( isinstance(obtained[0][0], dict) ) + assert len(obtained) == 2 + assert len(obtained[0]) == 6 + assert obtained[1][0]["id"] == 3 + assert isinstance(obtained[0][0], dict) def test_ancestors(tree): - obtained = tree.ancestors([5, 1]) - - assert( len(obtained) == 2 ) - assert( len(obtained[0]) == 3 ) - assert( isinstance(obtained[0][0], dict) ) - assert( len(obtained[1]) == 2 ) + assert len(obtained) == 2 + assert len(obtained[0]) == 3 + assert isinstance(obtained[0][0], dict) + assert len(obtained[1]) == 2 -def test_cbs(tree): +def test_cbs(tree): nodes = tree.nodes() for node in nodes: - assert( node['id'] == tree.node_id_cb(node) ) - assert( node['parent'] == tree.parent_id_cb(node) ) + assert node["id"] == tree.node_id_cb(node) + assert node["parent"] == tree.parent_id_cb(node) diff --git a/allensdk/test/core/test_sitk_utilities.py b/allensdk/test/core/test_sitk_utilities.py index 4c8a839b34..0622aef3c8 100644 --- a/allensdk/test/core/test_sitk_utilities.py +++ b/allensdk/test/core/test_sitk_utilities.py @@ -42,35 +42,33 @@ from allensdk.core import sitk_utilities as su - @pytest.fixture(params=[1, 2, 3, 4]) def ncomponents(request): return request.param -@pytest.fixture(params=[ [10, 20], [10, 20, 30], [10, 20, 30], [10, 10, 10], [20, 20] ]) +@pytest.fixture(params=[[10, 20], [10, 20, 30], [10, 20, 30], [10, 10, 10], [20, 20]]) def size(request): return request.param -@pytest.fixture(params=[ lambda x: list(range(x+1))[1:], lambda x: [10] * x ]) +@pytest.fixture(params=[lambda x: list(range(x + 1))[1:], lambda x: [10] * x]) def spacing(request): return request.param -@pytest.fixture(params=[ lambda x: list(range(x+1))[1:], lambda x: [5] * x ]) +@pytest.fixture(params=[lambda x: list(range(x + 1))[1:], lambda x: [5] * x]) def origin(request): return request.param -@pytest.fixture(params=[ lambda x: np.eye(x).flatten() ]) +@pytest.fixture(params=[lambda x: np.eye(x).flatten()]) def direction(request): return request.param -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def image(size, ncomponents, spacing, origin, direction): - if ncomponents > 1: img = sitk.Image(size, sitk.sitkVectorUInt8, ncomponents) else: @@ -87,96 +85,95 @@ def image(size, ncomponents, spacing, origin, direction): dir_val = direction(ndim) img.SetDirection(dir_val) - return img, {'ncomponents': ncomponents, - 'size': size, - 'spacing': spacing_val, - 'origin': origin_val, - 'direction': dir_val} + return img, { + "ncomponents": ncomponents, + "size": size, + "spacing": spacing_val, + "origin": origin_val, + "direction": dir_val, + } def test_get_sitk_image_information(image): - obtained = su.get_sitk_image_information(image[0]) for key, value in image[1].items(): - assert(np.allclose( obtained[key], value )) + assert np.allclose(obtained[key], value) def test_set_sitk_image_information_roundtrip(image): - info = su.get_sitk_image_information(image[0]) arr = sitk.GetArrayFromImage(image[0]) - new_image = sitk.GetImageFromArray(arr, info['ncomponents'] > 1) + new_image = sitk.GetImageFromArray(arr, info["ncomponents"] > 1) su.set_sitk_image_information(new_image, info) obtained = su.get_sitk_image_information(new_image) for key, value in info.items(): - assert(np.allclose( obtained[key], value )) - - -@pytest.mark.parametrize('act,dec,nc', [ ([10, 20], [20, 10], 1), - ([10, 20, 30], [30, 20, 10], 1), - ([10, 20, 30, 3], [30, 20, 10], 3), - ([10, 10, 10, 3], [10, 10, 10], 3 ) ]) + assert np.allclose(obtained[key], value) + + +@pytest.mark.parametrize( + "act,dec,nc", + [ + ([10, 20], [20, 10], 1), + ([10, 20, 30], [30, 20, 10], 1), + ([10, 20, 30, 3], [30, 20, 10], 3), + ([10, 10, 10, 3], [10, 10, 10], 3), + ], +) def test_fix_array_dimensions(act, dec, nc): - arr = np.zeros(act) obt = su.fix_array_dimensions(arr, nc) if nc == 1: - assert(np.array_equal( obt.shape, dec )) + assert np.array_equal(obt.shape, dec) else: - assert(np.array_equal( obt.shape[:-1], dec )) + assert np.array_equal(obt.shape[:-1], dec) - assert( not np.isfortran(obt) ) + assert not np.isfortran(obt) def test_sitk_metaimage_roundtrip(tmpdir_factory, size): + path = tmpdir_factory.mktemp("metaimage_io_test").join("dummy.mhd") - path = tmpdir_factory.mktemp('metaimage_io_test').join('dummy.mhd') - array = np.random.rand(*size) su.write_ndarray_with_sitk(array, path) obt_image, obt_info = su.read_ndarray_with_sitk(path) - assert(np.allclose( obt_image, array )) + assert np.allclose(obt_image, array) def test_sitk_metaimage_vector_roundtrip(tmpdir_factory, size): + path = tmpdir_factory.mktemp("metaimage_io_test").join("dummy.mhd") - path = tmpdir_factory.mktemp('metaimage_io_test').join('dummy.mhd') - size = list(size) + [3] array = np.random.rand(*size) su.write_ndarray_with_sitk(array, path, ncomponents=3) obt_image, obt_info = su.read_ndarray_with_sitk(path) - assert(np.allclose( obt_image, array )) - assert( obt_info['ncomponents'] == 3 ) + assert np.allclose(obt_image, array) + assert obt_info["ncomponents"] == 3 def test_sitk_nrrd_read(tmpdir_factory, size): - - path = tmpdir_factory.mktemp('nrrd_io_test').join('dummy.nrrd') + path = tmpdir_factory.mktemp("nrrd_io_test").join("dummy.nrrd") array = np.random.rand(*size) - + nrrd.write(str(path), array) obt_image, obt_info = su.read_ndarray_with_sitk(path) - assert(np.allclose( obt_image, array )) - + assert np.allclose(obt_image, array) def test_sitk_nrrd_write(tmpdir_factory, size): + path = tmpdir_factory.mktemp("nrrd_io_test").join("dummy_again.nrrd") - path = tmpdir_factory.mktemp('nrrd_io_test').join('dummy_again.nrrd') - array = np.random.rand(*size) su.write_ndarray_with_sitk(array, path) obt_image, obt_info = nrrd.read(str(path)) - assert(np.allclose( obt_image, array )) + assert np.allclose(obt_image, array) diff --git a/allensdk/test/core/test_structure_tree.py b/allensdk/test/core/test_structure_tree.py index 2cab817e60..095c069b1d 100644 --- a/allensdk/test/core/test_structure_tree.py +++ b/allensdk/test/core/test_structure_tree.py @@ -45,206 +45,242 @@ if sys.version_info > (3,): long = int + @pytest.fixture def nodes(): - - return [{'id': 0, 'structure_id_path': [0], 'rgb_triplet': [0, 0, 0], 'acronym': 'rt', 'name': 'root', 'structure_set_ids':[1, 4]}, - {'id': 1, 'structure_id_path': [0, 1], 'rgb_triplet': [0, 15, 255], 'acronym': 'a', 'name': 'alpha', 'structure_set_ids': [1, 3]}, - {'id': 2, 'structure_id_path': [0, 2], 'rgb_triplet': [255, 255, 255], 'acronym': 'b', 'name': 'beta', 'structure_set_ids': [1, 2]}] + return [ + { + "id": 0, + "structure_id_path": [0], + "rgb_triplet": [0, 0, 0], + "acronym": "rt", + "name": "root", + "structure_set_ids": [1, 4], + }, + { + "id": 1, + "structure_id_path": [0, 1], + "rgb_triplet": [0, 15, 255], + "acronym": "a", + "name": "alpha", + "structure_set_ids": [1, 3], + }, + { + "id": 2, + "structure_id_path": [0, 2], + "rgb_triplet": [255, 255, 255], + "acronym": "b", + "name": "beta", + "structure_set_ids": [1, 2], + }, + ] @pytest.fixture def tree(nodes): return StructureTree(nodes) - + @pytest.fixture def oapi(): oa = OntologiesApi() - - oa.get_structures = mock.MagicMock(return_value=[{'id': 1, 'structure_id_path': '1'}]) + + oa.get_structures = mock.MagicMock(return_value=[{"id": 1, "structure_id_path": "1"}]) oa.get_structure_set_map = mock.MagicMock(return_value={1: [2, 3]}) - + return oa - - + + def test_get_structures_by_id(tree): - obtained = tree.get_structures_by_id([1, 2]) - assert( len(obtained) == 2 ) - - + assert len(obtained) == 2 + + def test_get_structures_by_name(tree): - - obtained = tree.get_structures_by_name(['root']) - assert( len(obtained) == 1 ) - - + obtained = tree.get_structures_by_name(["root"]) + assert len(obtained) == 1 + + def test_get_structures_by_acronym(tree): + obtained = tree.get_structures_by_acronym(["rt", "a", "b"]) + assert len(obtained) == 3 - obtained = tree.get_structures_by_acronym(['rt', 'a', 'b']) - assert( len(obtained) == 3) - def test_get_structures_by_set_id(tree): - obtained = tree.get_structures_by_set_id([2, 3]) - assert( len(obtained) == 2 ) - - + assert len(obtained) == 2 + + def test_get_colormap(tree): - obtained = tree.get_colormap() - assert( allclose(obtained[0], [0, 0, 0]) ) - assert( allclose(obtained[2], [255, 255, 255]) ) - - + assert allclose(obtained[0], [0, 0, 0]) + assert allclose(obtained[2], [255, 255, 255]) + + def test_get_name_map(tree): - obtained = tree.get_name_map() - assert( obtained[0] == 'root' ) - assert( obtained[2] == 'beta' ) - - + assert obtained[0] == "root" + assert obtained[2] == "beta" + + def test_get_id_acronym_map(tree): - obtained = tree.get_id_acronym_map() - assert( obtained['rt'] == 0 ) - + assert obtained["rt"] == 0 -def test_get_ancestor_id_map(tree): +def test_get_ancestor_id_map(tree): obtained = tree.get_ancestor_id_map() - assert( set(obtained[2]) == set([2, 0]) ) - + assert set(obtained[2]) == set([2, 0]) + def test_structure_descends_from(tree): - - assert( tree.structure_descends_from(2, 0) ) - assert( not tree.structure_descends_from(0, 1) ) - - + assert tree.structure_descends_from(2, 0) + assert not tree.structure_descends_from(0, 1) + + def test_has_overlaps(tree): - obtained = tree.has_overlaps([0, 1, 2]) - assert( obtained == set([0]) ) - + assert obtained == set([0]) + obag = tree.has_overlaps([1, 2]) - assert( not obag ) + assert not obag def test_clean_structures(nodes): + dirty_node = { + "id": 0, + "structure_id_path": "/0/", + "color_hex_triplet": "000000", + "acronym": "rt", + "name": "root", + "structure_sets": [{"id": 1}, {"id": 4}], + } - dirty_node = {'id': 0, 'structure_id_path': '/0/', - 'color_hex_triplet': '000000', 'acronym': 'rt', - 'name': 'root', 'structure_sets':[{'id': 1}, {'id': 4}]} - clean_node = StructureTree.clean_structures([dirty_node])[0] - assert( isinstance(clean_node['rgb_triplet'], list) ) - assert( isinstance(clean_node['structure_id_path'], list) ) - - + assert isinstance(clean_node["rgb_triplet"], list) + assert isinstance(clean_node["structure_id_path"], list) + + def test_clean_structures_no_sets(): - - dirty_node = {'id': 0, 'structure_id_path': '/0/', - 'color_hex_triplet': '000000', 'acronym': 'rt', - 'name': 'root'} - + dirty_node = {"id": 0, "structure_id_path": "/0/", "color_hex_triplet": "000000", "acronym": "rt", "name": "root"} + clean_node = StructureTree.clean_structures([dirty_node]) StructureTree(clean_node) - - assert( len(clean_node[0]['structure_set_ids']) == 0 ) - - + + assert len(clean_node[0]["structure_set_ids"]) == 0 + + def test_clean_structures_only_ids(): - - dirty_node = {'id': 0, 'structure_id_path': '/0/', - 'color_hex_triplet': '000000', 'acronym': 'rt', - 'name': 'root', 'structure_set_ids': [1, 2, 3] } - + dirty_node = { + "id": 0, + "structure_id_path": "/0/", + "color_hex_triplet": "000000", + "acronym": "rt", + "name": "root", + "structure_set_ids": [1, 2, 3], + } + clean_node = StructureTree.clean_structures([dirty_node]) StructureTree(clean_node) - - assert( len(clean_node[0]['structure_set_ids']) == 3 ) - - + + assert len(clean_node[0]["structure_set_ids"]) == 3 + + def test_clean_structures_ids_sets(): - - dirty_node = {'id': 0, 'structure_id_path': '/0/', - 'color_hex_triplet': '000000', 'acronym': 'rt', - 'name': 'root', 'structure_set_ids': [1, 2, 3], - 'structure_sets': [{'id': 1}, {'id': 4}] } - + dirty_node = { + "id": 0, + "structure_id_path": "/0/", + "color_hex_triplet": "000000", + "acronym": "rt", + "name": "root", + "structure_set_ids": [1, 2, 3], + "structure_sets": [{"id": 1}, {"id": 4}], + } + clean_node = StructureTree.clean_structures([dirty_node]) StructureTree(clean_node) - - assert( len(clean_node[0]['structure_set_ids']) == 4 ) - - + + assert len(clean_node[0]["structure_set_ids"]) == 4 + + def test_clean_structures_str_id(): + dirty_node = { + "id": "0", + "structure_id_path": "/0/", + "color_hex_triplet": "000000", + "acronym": "rt", + "name": "root", + "structure_set_ids": [1, 2, 3], + "structure_sets": [{"id": 1}, {"id": 4}], + } - dirty_node = {'id': '0', 'structure_id_path': '/0/', - 'color_hex_triplet': '000000', 'acronym': 'rt', - 'name': 'root', 'structure_set_ids': [1, 2, 3], - 'structure_sets': [{'id': 1}, {'id': 4}] } - clean_node = StructureTree.clean_structures([dirty_node]) st = StructureTree(clean_node) - - assert( set(st.node_ids()) == set([0]) ) - - -def test_get_structure_sets(tree): + assert set(st.node_ids()) == set([0]) + + +def test_get_structure_sets(tree): expected = set([1, 2, 3, 4]) obtained = tree.get_structure_sets() - assert( expected == obtained ) + assert expected == obtained def test_clean_structures_weird_keys(): - - dirty_node = {'id': 5, 'dummy_key': 'dummy_val'} + dirty_node = {"id": 5, "dummy_key": "dummy_val"} clean_node = StructureTree.clean_structures([dirty_node])[0] - assert( len(clean_node) == 2 ) - assert( clean_node['id'] == 5 ) + assert len(clean_node) == 2 + assert clean_node["id"] == 5 -@pytest.mark.parametrize('inp,out', [('990099', [153, 0, 153]), - ('#990099', [153, 0, 153]), - ([153, 0, 153], [153, 0, 153]), - ((153., 0., 153.), [153, 0, 153]), - ([long(153), long(0), long(153)], [153, 0, 153])]) +@pytest.mark.parametrize( + "inp,out", + [ + ("990099", [153, 0, 153]), + ("#990099", [153, 0, 153]), + ([153, 0, 153], [153, 0, 153]), + ((153.0, 0.0, 153.0), [153, 0, 153]), + ([long(153), long(0), long(153)], [153, 0, 153]), + ], +) def test_hex_to_rgb(inp, out): obt = StructureTree.hex_to_rgb(inp) - assert(allclose(obt, out)) + assert allclose(obt, out) -@pytest.mark.parametrize('inp,out', [('/1/2/3/', [1, 2, 3]), - ('1/2/3/', [1, 2, 3]), - ('/1/2/3', [1, 2, 3]), - ('1/2/3', [1, 2, 3]), - ([1, 2, 3], [1, 2, 3]), - ([1.0, long(2), 3], [1, 2, 3]), - ((1, 2, 3), [1, 2, 3]), - ('', [])]) +@pytest.mark.parametrize( + "inp,out", + [ + ("/1/2/3/", [1, 2, 3]), + ("1/2/3/", [1, 2, 3]), + ("/1/2/3", [1, 2, 3]), + ("1/2/3", [1, 2, 3]), + ([1, 2, 3], [1, 2, 3]), + ([1.0, long(2), 3], [1, 2, 3]), + ((1, 2, 3), [1, 2, 3]), + ("", []), + ], +) def test_path_to_list(inp, out): obt = StructureTree.path_to_list(inp) - assert(allclose(obt, out)) + assert allclose(obt, out) def test_export_label_description(tree): - exp = pd.DataFrame({ - 'IDX': [0, 1, 2], - '-R-': [0, 0, 255], - '-G-': [0, 15, 255], - '-B-': [0, 255, 255], - '-A-': [1.0, 1.0, 1.0], - 'VIS': [1, 1, 1], - 'MSH': [1, 1, 1], - 'LABEL': ['rt', 'a', 'b'] - }).loc[:, ('IDX', '-R-', '-G-', '-B-', '-A-', 'VIS', 'MSH', 'LABEL')] + exp = pd.DataFrame( + { + "IDX": [0, 1, 2], + "-R-": [0, 0, 255], + "-G-": [0, 15, 255], + "-B-": [0, 255, 255], + "-A-": [1.0, 1.0, 1.0], + "VIS": [1, 1, 1], + "MSH": [1, 1, 1], + "LABEL": ["rt", "a", "b"], + } + ).loc[:, ("IDX", "-R-", "-G-", "-B-", "-A-", "VIS", "MSH", "LABEL")] obt = tree.export_label_description() - pd.testing.assert_frame_equal(obt, exp) \ No newline at end of file + pd.testing.assert_frame_equal(obt, exp) diff --git a/allensdk/test/ephys/test_extractor.py b/allensdk/test/ephys/test_extractor.py index c8322f2a3e..cc6ef528cf 100644 --- a/allensdk/test/ephys/test_extractor.py +++ b/allensdk/test/ephys/test_extractor.py @@ -41,6 +41,7 @@ from allensdk.ephys.ephys_extractor import EphysSweepSetFeatureExtractor, input_resistance import allensdk.ephys.ephys_extractor as ephys_extractor import os + path = os.path.dirname(__file__) @@ -139,38 +140,36 @@ def test_extractor_with_high_init_dvdt(): def test_extractor_input_resistance(): t = np.arange(0, 1.0, 5e-6) - v1 = np.ones_like(t) * -5. - v2 = np.ones_like(t) * -10. - i1 = np.ones_like(t) * -50. - i2 = np.ones_like(t) * -100. + v1 = np.ones_like(t) * -5.0 + v2 = np.ones_like(t) * -10.0 + i1 = np.ones_like(t) * -50.0 + i2 = np.ones_like(t) * -100.0 ext = EphysSweepSetFeatureExtractor([t, t], [v1, v2], [i1, i2]) ri = input_resistance(ext) - assert np.allclose(ri, 100.) + assert np.allclose(ri, 100.0) def test_fit_fi_slope(): - nsweeps = 5 - weights = np.array([ 2, 1 ]) + weights = np.array([2, 1]) amps = np.random.rand(nsweeps) iteramps = iter(amps) design = np.array([amps, np.ones_like(amps)]).T rates = np.dot(design, weights) + def build_stim_amps(): return lambda sweep: next(iteramps) class Ext(object): def sweeps(self): return np.zeros([nsweeps]) + def sweep_features(self, key): return rates - with mock.patch( - 'allensdk.ephys.ephys_extractor._step_stim_amp', - new_callable=build_stim_amps): - + with mock.patch("allensdk.ephys.ephys_extractor._step_stim_amp", new_callable=build_stim_amps): slope_obt = ephys_extractor.fit_fi_slope(Ext()) - assert(np.allclose(weights[0], slope_obt)) \ No newline at end of file + assert np.allclose(weights[0], slope_obt) diff --git a/allensdk/test/ephys/test_features.py b/allensdk/test/ephys/test_features.py index d5e2ecabdb..de417e24fd 100644 --- a/allensdk/test/ephys/test_features.py +++ b/allensdk/test/ephys/test_features.py @@ -37,6 +37,7 @@ import allensdk.ephys.ephys_features as ft import numpy as np import os + path = os.path.dirname(__file__) @@ -77,7 +78,7 @@ def test_fixed_dt(): assert ft.has_fixed_dt(t) # Change the first time point to make time steps inconsistent - t[0] -= 3. + t[0] -= 3.0 assert not ft.has_fixed_dt(t) @@ -211,8 +212,7 @@ def test_troughs_with_peak_at_end(): peaks = np.array([812, 3478]) clipped = np.array([False, True]) - troughs = ft.find_trough_indexes(v[:peaks[-1]], t[:peaks[-1]], - spikes, peaks, clipped=clipped) + troughs = ft.find_trough_indexes(v[: peaks[-1]], t[: peaks[-1]], spikes, peaks, clipped=clipped) assert np.isnan(troughs[-1]) diff --git a/allensdk/test/glif_tests.py b/allensdk/test/glif_tests.py index e5d940103f..7390b9ce1f 100644 --- a/allensdk/test/glif_tests.py +++ b/allensdk/test/glif_tests.py @@ -34,21 +34,22 @@ # POSSIBILITY OF SUCH DAMAGE. # import matplotlib -matplotlib.use('agg') -import matplotlib.pyplot as plt # noqa: #402 + +matplotlib.use("agg") +import matplotlib.pyplot as plt # noqa: #402 from allensdk.api.queries.glif_api import GlifApi # noqa: #402 import allensdk.core.json_utilities as json_utilities # noqa: #402 from allensdk.model.glif.glif_neuron import GlifNeuron # noqa: #402 from allensdk.model.glif.simulate_neuron import simulate_neuron # noqa: #402 -import os # noqa: #402 -import shutil # noqa: #402 -import logging # noqa: #402 +import os # noqa: #402 +import shutil # noqa: #402 +import logging # noqa: #402 # NEURONAL_MODEL_ID = 491547163 # level 1 LIF NEURONAL_MODEL_ID = 491547171 # level 5 GLIF -OUTPUT_DIR = 'tmp' +OUTPUT_DIR = "tmp" def test_download(): @@ -59,22 +60,18 @@ def test_download(): glif_api = GlifApi() glif_api.get_neuronal_model(NEURONAL_MODEL_ID) - glif_api.cache_stimulus_file(os.path.join( - OUTPUT_DIR, '%d.nwb' % NEURONAL_MODEL_ID)) + glif_api.cache_stimulus_file(os.path.join(OUTPUT_DIR, "%d.nwb" % NEURONAL_MODEL_ID)) neuron_config = glif_api.get_neuron_config() - json_utilities.write(os.path.join( - OUTPUT_DIR, '%d_neuron_config.json' % NEURONAL_MODEL_ID), neuron_config) + json_utilities.write(os.path.join(OUTPUT_DIR, "%d_neuron_config.json" % NEURONAL_MODEL_ID), neuron_config) ephys_sweeps = glif_api.get_ephys_sweeps() - json_utilities.write(os.path.join( - OUTPUT_DIR, 'ephys_sweeps.json'), ephys_sweeps) + json_utilities.write(os.path.join(OUTPUT_DIR, "ephys_sweeps.json"), ephys_sweeps) def test_run(): # initialize the neuron - neuron_config = json_utilities.read(os.path.join( - OUTPUT_DIR, '%d_neuron_config.json' % NEURONAL_MODEL_ID)) + neuron_config = json_utilities.read(os.path.join(OUTPUT_DIR, "%d_neuron_config.json" % NEURONAL_MODEL_ID)) neuron = GlifNeuron.from_dict(neuron_config) # make a short square pulse. stimulus units should be in Amps. @@ -86,27 +83,25 @@ def test_run(): # simulate the neuron output = neuron.run(stimulus) - voltage = output['voltage'] - threshold = output['threshold'] + voltage = output["voltage"] + threshold = output["threshold"] plt.plot(voltage) plt.plot(threshold) - plt.savefig(os.path.join(OUTPUT_DIR, 'plot.png')) + plt.savefig(os.path.join(OUTPUT_DIR, "plot.png")) def test_simulate(): logging.getLogger().setLevel(logging.DEBUG) - neuron_config = json_utilities.read(os.path.join( - OUTPUT_DIR, '%d_neuron_config.json' % NEURONAL_MODEL_ID)) - ephys_sweeps = json_utilities.read( - os.path.join(OUTPUT_DIR, 'ephys_sweeps.json')) - ephys_file_name = os.path.join(OUTPUT_DIR, '%d.nwb' % NEURONAL_MODEL_ID) + neuron_config = json_utilities.read(os.path.join(OUTPUT_DIR, "%d_neuron_config.json" % NEURONAL_MODEL_ID)) + ephys_sweeps = json_utilities.read(os.path.join(OUTPUT_DIR, "ephys_sweeps.json")) + ephys_file_name = os.path.join(OUTPUT_DIR, "%d.nwb" % NEURONAL_MODEL_ID) neuron = GlifNeuron.from_dict(neuron_config) - sweep_numbers = [s['sweep_number'] for s in ephys_sweeps] - simulate_neuron(neuron, sweep_numbers, - ephys_file_name, ephys_file_name, 0.05) + sweep_numbers = [s["sweep_number"] for s in ephys_sweeps] + simulate_neuron(neuron, sweep_numbers, ephys_file_name, ephys_file_name, 0.05) + if __name__ == "__main__": # test_download() diff --git a/allensdk/test/internal/api/test_api_prerelease.py b/allensdk/test/internal/api/test_api_prerelease.py index 4f3c06e204..3160753ff0 100644 --- a/allensdk/test/internal/api/test_api_prerelease.py +++ b/allensdk/test/internal/api/test_api_prerelease.py @@ -16,8 +16,8 @@ def api(): def test_retrieve_file_from_storage(api, fn_temp_dir): eye = np.eye(100) - target = os.path.join(fn_temp_dir, 'target') - store = os.path.join(fn_temp_dir, 'store') + target = os.path.join(fn_temp_dir, "target") + store = os.path.join(fn_temp_dir, "store") nrrd.write(store, eye) api.retrieve_file_from_storage(store, target) diff --git a/allensdk/test/internal/api/test_grid_data_api_prerelease.py b/allensdk/test/internal/api/test_grid_data_api_prerelease.py index 5bdb766698..78eb972be8 100644 --- a/allensdk/test/internal/api/test_grid_data_api_prerelease.py +++ b/allensdk/test/internal/api/test_grid_data_api_prerelease.py @@ -8,18 +8,21 @@ from allensdk.config.manifest import Manifest -from allensdk.internal.api.queries.grid_data_api_prerelease \ - import GridDataApiPrerelease, _get_grid_storage_directories +from allensdk.internal.api.queries.grid_data_api_prerelease import GridDataApiPrerelease, _get_grid_storage_directories + @pytest.fixture def storage_dirs(fn_temp_dir): - return {"111" : os.path.join(fn_temp_dir, "111"), - "222" : os.path.join(fn_temp_dir, "222")} + return {"111": os.path.join(fn_temp_dir, "111"), "222": os.path.join(fn_temp_dir, "222")} + @pytest.fixture def query_result(fn_temp_dir): - return [{b'id' : 111, b'storage_directory' : os.path.join(fn_temp_dir, "111")}, - {b'id' : 222, b'storage_directory' : os.path.join(fn_temp_dir, "222")}] + return [ + {b"id": 111, b"storage_directory": os.path.join(fn_temp_dir, "111")}, + {b"id": 222, b"storage_directory": os.path.join(fn_temp_dir, "222")}, + ] + @pytest.fixture def grid_data(storage_dirs, fn_temp_dir): @@ -34,8 +37,7 @@ def grid_data(storage_dirs, fn_temp_dir): def test_get_grid_storage_directories(storage_dirs, query_result, fn_temp_dir): # ------------------------------------------------------------------------ # test dirs only have grid/ subdirectory - with mock.patch('allensdk.internal.core.lims_utilities.query', - new=lambda a: query_result): + with mock.patch("allensdk.internal.core.lims_utilities.query", new=lambda a: query_result): obtained = _get_grid_storage_directories(GridDataApiPrerelease.GRID_DATA_DIRECTORY) assert not obtained @@ -43,10 +45,9 @@ def test_get_grid_storage_directories(storage_dirs, query_result, fn_temp_dir): # ------------------------------------------------------------------------ # test returns storage_dirs for path in storage_dirs.values(): - Manifest.safe_make_parent_dirs(os.path.join(path, 'grid')) + Manifest.safe_make_parent_dirs(os.path.join(path, "grid")) - with mock.patch('allensdk.internal.core.lims_utilities.query', - new=lambda a: query_result): + with mock.patch("allensdk.internal.core.lims_utilities.query", new=lambda a: query_result): obtained = _get_grid_storage_directories(GridDataApiPrerelease.GRID_DATA_DIRECTORY) for key, value in obtained: @@ -58,15 +59,15 @@ def test_get_grid_storage_directories(storage_dirs, query_result, fn_temp_dir): # ---------------------------------------------------------------------------- @pytest.mark.prerelease() def test_from_file_name(storage_dirs, fn_temp_dir): - file_name = os.path.join(fn_temp_dir, 'storage_dirs.json') + file_name = os.path.join(fn_temp_dir, "storage_dirs.json") - with mock.patch('allensdk.internal.api.queries.grid_data_api_prerelease.' - '_get_grid_storage_directories', - new=lambda a: storage_dirs): + with mock.patch( + "allensdk.internal.api.queries.grid_data_api_prerelease._get_grid_storage_directories", + new=lambda a: storage_dirs, + ): GridDataApiPrerelease.from_file_name(file_name) - with mock.patch('allensdk.internal.api.queries.grid_data_api_prerelease.' - '_get_grid_storage_directories') as ggsd: + with mock.patch("allensdk.internal.api.queries.grid_data_api_prerelease._get_grid_storage_directories") as ggsd: GridDataApiPrerelease.from_file_name(file_name) ggsd.assert_not_called() @@ -75,20 +76,18 @@ def test_from_file_name(storage_dirs, fn_temp_dir): @pytest.mark.prerelease() def test_download_projection_grid_data(grid_data, fn_temp_dir): - eye = np.eye(100) - target = os.path.join(fn_temp_dir, 'target') + target = os.path.join(fn_temp_dir, "target") # test invalid experiment id/no grid - assert_raises(ValueError, grid_data.download_projection_grid_data, target, - 0, 'projection_density_100.nrrd') + assert_raises(ValueError, grid_data.download_projection_grid_data, target, 0, "projection_density_100.nrrd") assert not os.path.exists(target) # test valid - with mock.patch('allensdk.internal.api.api_prerelease.ApiPrerelease.' - 'retrieve_file_from_storage', - new=lambda a, b, c: nrrd.write(c, eye)): - - grid_data.download_projection_grid_data(target, 111, 'projection_density_100.nrrd') + with mock.patch( + "allensdk.internal.api.api_prerelease.ApiPrerelease.retrieve_file_from_storage", + new=lambda a, b, c: nrrd.write(c, eye), + ): + grid_data.download_projection_grid_data(target, 111, "projection_density_100.nrrd") assert os.path.exists(target) diff --git a/allensdk/test/internal/api/test_mouse_connectivity_api_prerelease.py b/allensdk/test/internal/api/test_mouse_connectivity_api_prerelease.py index 7257c8f5ee..d0c77ad638 100644 --- a/allensdk/test/internal/api/test_mouse_connectivity_api_prerelease.py +++ b/allensdk/test/internal/api/test_mouse_connectivity_api_prerelease.py @@ -6,25 +6,31 @@ from allensdk.core import json_utilities -from allensdk.internal.api.queries.mouse_connectivity_api_prerelease \ - import MouseConnectivityApiPrerelease, _experiment_dict +from allensdk.internal.api.queries.mouse_connectivity_api_prerelease import ( + MouseConnectivityApiPrerelease, + _experiment_dict, +) + @pytest.fixture def storage_dirs(fn_temp_dir): - return {"111" : os.path.join(fn_temp_dir, "111")} + return {"111": os.path.join(fn_temp_dir, "111")} + @pytest.fixture def connectivity(storage_dirs, fn_temp_dir): - file_name = os.path.join(fn_temp_dir, 'storage_directories.json') + file_name = os.path.join(fn_temp_dir, "storage_directories.json") json_utilities.write(file_name, storage_dirs) mca = MouseConnectivityApiPrerelease(file_name) return mca + _STRUCTURE_TREE_ROOT_ID = 997 _STRUCTURE_TREE_ROOT_NAME = "root" _STRUCTURE_TREE_ROOT_ACRONYM = "root" + # ---------------------------------------------------------------------------- # module level functions # ---------------------------------------------------------------------------- @@ -32,37 +38,39 @@ def connectivity(storage_dirs, fn_temp_dir): def tests_experiment_dict(): # ------------------------------------------------------------------------ # null row - row = {b'id':1, - b'age' : None, - b'gender' : None, - b'project_code' : None, - b'specimen_name' : None, - b'transgenic_line' : None, - b'workflow_state' : None, - b'workflows' : None, - b'structure_id' : None, - b'structure_name' : None, - b'structure_acronym' : None, - b'injection_structures_id' : None, - b'injection_structures_name' : None, - b'injection_structures_acronym' : None} + row = { + b"id": 1, + b"age": None, + b"gender": None, + b"project_code": None, + b"specimen_name": None, + b"transgenic_line": None, + b"workflow_state": None, + b"workflows": None, + b"structure_id": None, + b"structure_name": None, + b"structure_acronym": None, + b"injection_structures_id": None, + b"injection_structures_name": None, + b"injection_structures_acronym": None, + } exp = _experiment_dict(row) - assert exp.pop('id') == 1 + assert exp.pop("id") == 1 - assert exp.get('structure_id') == exp.get('injection_structures')[0].get('id') - assert exp.get('structure_name') == exp.get('injection_structures')[0].get('name') - assert exp.get('structure_abbrev') == exp.get('injection_structures')[0].get('abbreviation') + assert exp.get("structure_id") == exp.get("injection_structures")[0].get("id") + assert exp.get("structure_name") == exp.get("injection_structures")[0].get("name") + assert exp.get("structure_abbrev") == exp.get("injection_structures")[0].get("abbreviation") - assert exp.pop('structure_id') == _STRUCTURE_TREE_ROOT_ID - assert exp.pop('structure_name') == _STRUCTURE_TREE_ROOT_NAME - assert exp.pop('structure_abbrev') == _STRUCTURE_TREE_ROOT_ACRONYM + assert exp.pop("structure_id") == _STRUCTURE_TREE_ROOT_ID + assert exp.pop("structure_name") == _STRUCTURE_TREE_ROOT_NAME + assert exp.pop("structure_abbrev") == _STRUCTURE_TREE_ROOT_ACRONYM - assert len(exp.pop('injection_structures')) == 1 + assert len(exp.pop("injection_structures")) == 1 - assert exp.get('workflows')[0] == "" - assert len(exp.pop('workflows')) == 1 + assert exp.get("workflows")[0] == "" + assert len(exp.pop("workflows")) == 1 for value in exp.values(): assert value == "" @@ -80,12 +88,10 @@ def test_get_structure_unionizes(connectivity): def test_download_injection_density(connectivity, storage_dirs, fn_temp_dir): eid = 111 store = storage_dirs[str(eid)] - source = os.path.join(store, 'grid', 'injection_density_25.nrrd') - target = os.path.join(fn_temp_dir, 'experiment_{0}'.format(eid), - 'injection_density_25.nrrd') + source = os.path.join(store, "grid", "injection_density_25.nrrd") + target = os.path.join(fn_temp_dir, "experiment_{0}".format(eid), "injection_density_25.nrrd") - with mock.patch('allensdk.internal.api.api_prerelease.ApiPrerelease.' - 'retrieve_file_from_storage') as gda: + with mock.patch("allensdk.internal.api.api_prerelease.ApiPrerelease.retrieve_file_from_storage") as gda: connectivity.download_injection_density(target, eid, 25) gda.assert_called_once_with(source, target) @@ -95,12 +101,10 @@ def test_download_injection_density(connectivity, storage_dirs, fn_temp_dir): def test_download_projection_density(connectivity, storage_dirs, fn_temp_dir): eid = 111 store = storage_dirs[str(eid)] - source = os.path.join(store, 'grid', 'projection_density_25.nrrd') - target = os.path.join(fn_temp_dir, 'experiment_{0}'.format(eid), - 'projection_density_25.nrrd') + source = os.path.join(store, "grid", "projection_density_25.nrrd") + target = os.path.join(fn_temp_dir, "experiment_{0}".format(eid), "projection_density_25.nrrd") - with mock.patch('allensdk.internal.api.api_prerelease.ApiPrerelease.' - 'retrieve_file_from_storage') as gda: + with mock.patch("allensdk.internal.api.api_prerelease.ApiPrerelease.retrieve_file_from_storage") as gda: connectivity.download_projection_density(target, eid, 25) gda.assert_called_once_with(source, target) @@ -110,12 +114,10 @@ def test_download_projection_density(connectivity, storage_dirs, fn_temp_dir): def test_download_injection_fraction(connectivity, storage_dirs, fn_temp_dir): eid = 111 store = storage_dirs[str(eid)] - source = os.path.join(store, 'grid', 'injection_fraction_25.nrrd') - target = os.path.join(fn_temp_dir, 'experiment_{0}'.format(eid), - 'injection_fraction_25.nrrd') + source = os.path.join(store, "grid", "injection_fraction_25.nrrd") + target = os.path.join(fn_temp_dir, "experiment_{0}".format(eid), "injection_fraction_25.nrrd") - with mock.patch('allensdk.internal.api.api_prerelease.ApiPrerelease.' - 'retrieve_file_from_storage') as gda: + with mock.patch("allensdk.internal.api.api_prerelease.ApiPrerelease.retrieve_file_from_storage") as gda: connectivity.download_injection_fraction(target, eid, 25) gda.assert_called_once_with(source, target) @@ -125,12 +127,10 @@ def test_download_injection_fraction(connectivity, storage_dirs, fn_temp_dir): def test_download_data_mask(connectivity, storage_dirs, fn_temp_dir): eid = 111 store = storage_dirs[str(eid)] - source = os.path.join(store, 'grid', 'data_mask_25.nrrd') - target = os.path.join(fn_temp_dir, 'experiment_{0}'.format(eid), - 'data_mask_25.nrrd') + source = os.path.join(store, "grid", "data_mask_25.nrrd") + target = os.path.join(fn_temp_dir, "experiment_{0}".format(eid), "data_mask_25.nrrd") - with mock.patch('allensdk.internal.api.api_prerelease.ApiPrerelease.' - 'retrieve_file_from_storage') as gda: + with mock.patch("allensdk.internal.api.api_prerelease.ApiPrerelease.retrieve_file_from_storage") as gda: connectivity.download_data_mask(target, eid, 25) gda.assert_called_once_with(source, target) diff --git a/allensdk/test/internal/api/test_pre_release.py b/allensdk/test/internal/api/test_pre_release.py index b0addd9d85..7eadb1395d 100644 --- a/allensdk/test/internal/api/test_pre_release.py +++ b/allensdk/test/internal/api/test_pre_release.py @@ -4,18 +4,18 @@ import os import numpy as np -@pytest.fixture(scope='function') + +@pytest.fixture(scope="function") def tmpdir(tmpdir_factory): - fn = tmpdir_factory.mktemp('tmpdir') + fn = tmpdir_factory.mktemp("tmpdir") return fn @pytest.mark.prerelease def test_pre_release_get_containers(tmpdir): - # Values from original boc/api: - temp_dir_base = os.path.join(str(tmpdir), 'base-api') - outfile_base = os.path.join(temp_dir_base, 'manifest.json') + temp_dir_base = os.path.join(str(tmpdir), "base-api") + outfile_base = os.path.join(temp_dir_base, "manifest.json") boc_base = BrainObservatoryCache(manifest_file=outfile_base) containers_base = boc_base.get_experiment_containers() @@ -25,22 +25,31 @@ def test_pre_release_get_containers(tmpdir): # raise try: - temp_dir_extended = os.path.join(str(tmpdir), 'extended-api') - outfile_extended = os.path.join(temp_dir_extended, 'manifest.json') + temp_dir_extended = os.path.join(str(tmpdir), "extended-api") + outfile_extended = os.path.join(temp_dir_extended, "manifest.json") boc_extended = BrainObservatoryCache(manifest_file=outfile_extended, api=BrainObservatoryApiPreRelease()) except TypeError: raise RuntimeError('Allensdk out-of-date; upgrade with "pip install --force --upgrade allensdk"') containers_extended = boc_extended.get_experiment_containers() - + # # For development: print key/val pairs actually populated by adapter: # for key, val in sorted(containers_extended[0].items(), key=lambda x: x[0]): # print key, val # raise - assert len(containers_extended) > 0 - check_key_list = ['failed', 'tags', 'specimen_name', 'imaging_depth', 'donor_name', 'reporter_line', 'targeted_structure', 'cre_line', 'id'] + check_key_list = [ + "failed", + "tags", + "specimen_name", + "imaging_depth", + "donor_name", + "reporter_line", + "targeted_structure", + "cre_line", + "id", + ] for key in check_key_list: assert key in containers_extended[0] assert len(containers_extended[0]) == len(containers_base[0]) @@ -48,11 +57,11 @@ def test_pre_release_get_containers(tmpdir): id_container_dict = {} for c_e in containers_extended: - curr_id = c_e['id'] + curr_id = c_e["id"] id_container_dict[curr_id] = c_e for c_b in containers_base: - c_e = id_container_dict[c_b['id']] + c_e = id_container_dict[c_b["id"]] for key in c_e: if not c_e[key] == c_b[key]: print(key, c_e[key], c_b[key]) @@ -61,10 +70,9 @@ def test_pre_release_get_containers(tmpdir): @pytest.mark.prerelease def test_pre_release_get_experiments(tmpdir): - # Values from original boc/api: - temp_dir_base = os.path.join(str(tmpdir), 'base-api') - outfile_base = os.path.join(temp_dir_base, 'manifest.json') + temp_dir_base = os.path.join(str(tmpdir), "base-api") + outfile_base = os.path.join(temp_dir_base, "manifest.json") boc_base = BrainObservatoryCache(manifest_file=outfile_base) experiments_base = boc_base.get_ophys_experiments() @@ -73,10 +81,9 @@ def test_pre_release_get_experiments(tmpdir): # print key, val # raise - try: - temp_dir_extended = os.path.join(str(tmpdir), 'extended-api') - outfile_extended = os.path.join(temp_dir_extended, 'manifest.json') + temp_dir_extended = os.path.join(str(tmpdir), "extended-api") + outfile_extended = os.path.join(temp_dir_extended, "manifest.json") boc_extended = BrainObservatoryCache(manifest_file=outfile_extended, api=BrainObservatoryApiPreRelease()) except TypeError: raise RuntimeError('Allensdk out-of-date; upgrade with "pip install --force --upgrade allensdk"') @@ -88,7 +95,19 @@ def test_pre_release_get_experiments(tmpdir): # print key, val # raise - check_key_list = ['acquisition_age_days', 'cre_line', 'donor_name', 'experiment_container_id', 'fail_eye_tracking', 'id', 'imaging_depth', 'reporter_line', 'session_type', 'specimen_name', 'targeted_structure'] + check_key_list = [ + "acquisition_age_days", + "cre_line", + "donor_name", + "experiment_container_id", + "fail_eye_tracking", + "id", + "imaging_depth", + "reporter_line", + "session_type", + "specimen_name", + "targeted_structure", + ] for key in check_key_list: assert key in experiments_extended[0] @@ -97,11 +116,11 @@ def test_pre_release_get_experiments(tmpdir): id_experiment_dict = {} for c_e in experiments_extended: - curr_id = c_e['id'] + curr_id = c_e["id"] id_experiment_dict[curr_id] = c_e for c_b in experiments_base: - c_e = id_experiment_dict[c_b['id']] + c_e = id_experiment_dict[c_b["id"]] for key in c_e: # assert c_e[key] == c_b[key] if not c_e[key] == c_b[key]: @@ -111,13 +130,12 @@ def test_pre_release_get_experiments(tmpdir): @pytest.mark.prerelease def test_pre_release_get_cell_specimens(tmpdir): - # Values from original boc/api: Useful debugging code below, commented out # import warnings # warnings.warn('hard coding tmpdir while I dev, because query takes a long time') # temp_dir_base = '/home/nicholasc/tmp/base-api' - temp_dir_base = os.path.join(str(tmpdir), 'base-api') - outfile_base = os.path.join(temp_dir_base, 'manifest.json') + temp_dir_base = os.path.join(str(tmpdir), "base-api") + outfile_base = os.path.join(temp_dir_base, "manifest.json") boc_base = BrainObservatoryCache(manifest_file=outfile_base) cell_specimens_base = boc_base.get_cell_specimens(include_failed=True) @@ -130,32 +148,30 @@ def test_pre_release_get_cell_specimens(tmpdir): # raise try: - temp_dir_extended = os.path.join(str(tmpdir), 'extended-api') - outfile_extended = os.path.join(temp_dir_extended, 'manifest.json') + temp_dir_extended = os.path.join(str(tmpdir), "extended-api") + outfile_extended = os.path.join(temp_dir_extended, "manifest.json") boc_extended = BrainObservatoryCache(manifest_file=outfile_extended, api=BrainObservatoryApiPreRelease()) except TypeError: raise RuntimeError('Allensdk out-of-date; upgrade with "pip install --force --upgrade allensdk"') cell_specimens_extended = boc_extended.get_cell_specimens(include_failed=True) - assert len(cell_specimens_extended) > 0 assert set(cell_specimens_base[0].keys()) == set(cell_specimens_extended[0].keys()) id_experiment_dict = {} for c_b in cell_specimens_base: - curr_id = c_b['cell_specimen_id'] + curr_id = c_b["cell_specimen_id"] id_experiment_dict[curr_id] = c_b for c_e in cell_specimens_extended: - if c_e['cell_specimen_id'] in id_experiment_dict: - c_b = id_experiment_dict[c_e['cell_specimen_id']] + if c_e["cell_specimen_id"] in id_experiment_dict: + c_b = id_experiment_dict[c_e["cell_specimen_id"]] for key in sorted([key2 for key2 in c_b]): assert key in c_e - if not c_e[key] == c_b[key] and not key == 'specimen_id': # Failure mode 1: specimen_id changed - + if not c_e[key] == c_b[key] and not key == "specimen_id": # Failure mode 1: specimen_id changed if isinstance(c_b[key], (float, complex, int)) and isinstance(c_e[key], (float, complex, int)): - assert np.isclose(c_e[key], c_b[key], 1e-12) # Failure mode 2: floating-point precision + assert np.isclose(c_e[key], c_b[key], 1e-12) # Failure mode 2: floating-point precision elif c_b[key] is None and isinstance(c_e[key], (float, complex, int)): pass else: diff --git a/allensdk/test/internal/biophysical/conftest.py b/allensdk/test/internal/biophysical/conftest.py index c467d2d328..84cb319dbb 100644 --- a/allensdk/test/internal/biophysical/conftest.py +++ b/allensdk/test/internal/biophysical/conftest.py @@ -2,5 +2,5 @@ # ignore test_optimize_run.py if don't have neuron installed collect_ignore = [] -if os.getenv("TEST_NEURON") != 'true': +if os.getenv("TEST_NEURON") != "true": collect_ignore.append("test_optimize_run.py") diff --git a/allensdk/test/internal/biophysical/test_ephys_utils.py b/allensdk/test/internal/biophysical/test_ephys_utils.py index cc366b7099..0c3150df29 100644 --- a/allensdk/test/internal/biophysical/test_ephys_utils.py +++ b/allensdk/test/internal/biophysical/test_ephys_utils.py @@ -6,23 +6,18 @@ @pytest.fixture def data_set(): - data = { 'stimulus': 1.0 * np.arange(10), - 'response': 1.0 * np.arange(10), - 'sampling_rate': 0.1 - } - + data = {"stimulus": 1.0 * np.arange(10), "response": 1.0 * np.arange(10), "sampling_rate": 0.1} + data_set = Mock() - data_set.get_sweep = Mock(name='sweep_data', - return_value=data) - + data_set.get_sweep = Mock(name="sweep_data", return_value=data) + return data_set def test_passive_preprocess(data_set): - s = { 'sweep_number': 5 } + s = {"sweep_number": 5} - v, i, t = ephys_utils.get_sweep_v_i_t_from_set(data_set, - s['sweep_number']) + v, i, t = ephys_utils.get_sweep_v_i_t_from_set(data_set, s["sweep_number"]) assert np.array_equal(v, np.arange(10) * 1000.0) assert np.array_equal(i, np.arange(10) * 1.0e12) assert np.array_equal(t, np.arange(10) * 10.0) diff --git a/allensdk/test/internal/biophysical/test_optimize_run.py b/allensdk/test/internal/biophysical/test_optimize_run.py index 517aa787cc..657b362ca4 100644 --- a/allensdk/test/internal/biophysical/test_optimize_run.py +++ b/allensdk/test/internal/biophysical/test_optimize_run.py @@ -4,22 +4,21 @@ from allensdk.model.biophys_sim.config import Config import os from unittest import mock + try: import __builtin__ as builtins except Exception: import builtins -from allensdk.internal.model.biophysical.run_optimize \ - import RunOptimize -from allensdk.internal.api.queries.optimize_config_reader \ - import OptimizeConfigReader +from allensdk.internal.model.biophysical.run_optimize import RunOptimize +from allensdk.internal.api.queries.optimize_config_reader import OptimizeConfigReader from allensdk.model.biophys_sim.neuron.hoc_utils import HocUtils real_import = __import__ -#import allensdk.eclipse_debug +# import allensdk.eclipse_debug -MANIFEST_JSON = ''' +MANIFEST_JSON = """ { "biophys": [ { @@ -551,9 +550,9 @@ } ] } -''' +""" -LIMS_MESSAGE = ''' +LIMS_MESSAGE = """ { "created_at": "2015-02-13T16:52:57-08:00", "id": 329322394, @@ -6784,7 +6783,7 @@ ], "workflow_state": "has_been_fit" } -''' +""" def mock_import(mod, *args): @@ -6801,30 +6800,29 @@ def mock_read_lims_file(self, lims_path): @pytest.fixture def run_optimize(): - rs = RunOptimize('manifest_sdk.json', 'out.json') - - mock.patch.object(OptimizeConfigReader, - 'read_lims_file', - mock_read_lims_file) - + rs = RunOptimize("manifest_sdk.json", "out.json") + + mock.patch.object(OptimizeConfigReader, "read_lims_file", mock_read_lims_file) + return rs def xtest_init(run_optimize): - assert run_optimize.input_json == 'manifest_sdk.json' - assert run_optimize.output_json == 'out.json' + assert run_optimize.input_json == "manifest_sdk.json" + assert run_optimize.output_json == "out.json" assert run_optimize.app_config is None assert run_optimize.manifest is None orig_open = open + def open_configs(n, *args): (_, fn) = os.path.split(n) - if fn == 'manifest_sdk.json': + if fn == "manifest_sdk.json": data_string = MANIFEST_JSON - elif fn == 'lims_message_optimize.json': + elif fn == "lims_message_optimize.json": data_string = LIMS_MESSAGE else: return orig_open(n, *args) @@ -6837,7 +6835,7 @@ def open_configs(n, *args): @patch("shutil.copy") @patch("allensdk.model.biophysical.runner.save_nwb") @patch.object(HocUtils, "__init__") -@patch(builtins.__name__+".__import__", side_effect=mock_import) +@patch(builtins.__name__ + ".__import__", side_effect=mock_import) @patch("allensdk.core.json_utilities.write") @patch("allensdk.internal.model.biophysical.fit_stage_2.run_stage_2") @patch("allensdk.internal.model.biophysical.fit_stage_2.prepare_stage_2") @@ -6845,22 +6843,24 @@ def open_configs(n, *args): @patch("allensdk.internal.model.biophysical.fit_stage_1.prepare_stage_1") @patch("allensdk.internal.model.biophysical.run_passive_fit.run_passive_fit") @patch("allensdk.core.nwb_data_set.NwbDataSet") -def test_start_specimen(nwb_data_set, - passive_fit, - prepare_stage_1, - run_stage_1, - prepare_stage_2, - run_stage_2, - json_utilities_write, - import_mock, - hoc_init, - save_nwb, - shutil_copy, - path_exists, - run_optimize): - with patch(builtins.__name__+".open", open_configs): - fit_description = Config().load('manifest_sdk.json') +def test_start_specimen( + nwb_data_set, + passive_fit, + prepare_stage_1, + run_stage_1, + prepare_stage_2, + run_stage_2, + json_utilities_write, + import_mock, + hoc_init, + save_nwb, + shutil_copy, + path_exists, + run_optimize, +): + with patch(builtins.__name__ + ".open", open_configs): + fit_description = Config().load("manifest_sdk.json") Utils.description = fit_description run_optimize.start_specimen() - + assert True diff --git a/allensdk/test/internal/biophysical/test_simulate_run.py b/allensdk/test/internal/biophysical/test_simulate_run.py index 8584a75457..35f31329c2 100644 --- a/allensdk/test/internal/biophysical/test_simulate_run.py +++ b/allensdk/test/internal/biophysical/test_simulate_run.py @@ -1,16 +1,16 @@ import pytest from unittest.mock import patch, mock_open, Mock, MagicMock + try: import __builtin__ as builtins except Exception: import builtins from allensdk.model.biophysical.utils import Utils from allensdk.model.biophys_sim.config import Config -from allensdk.internal.model.biophysical.run_simulate_lims \ - import RunSimulateLims +from allensdk.internal.model.biophysical.run_simulate_lims import RunSimulateLims from allensdk.model.biophys_sim.neuron.hoc_utils import HocUtils -MANIFEST_JSON = ''' +MANIFEST_JSON = """ { "biophys": [ { @@ -710,18 +710,19 @@ } ] } -''' +""" + @pytest.fixture def run_simulate(): - rs = RunSimulateLims('manifest.json', 'out.json') - + rs = RunSimulateLims("manifest.json", "out.json") + return rs def test_init(run_simulate): - assert run_simulate.input_json == 'manifest.json' - assert run_simulate.output_json == 'out.json' + assert run_simulate.input_json == "manifest.json" + assert run_simulate.output_json == "out.json" assert run_simulate.app_config is None assert run_simulate.manifest is None @@ -731,23 +732,18 @@ def test_init(run_simulate): @patch.object(HocUtils, "__init__") def test_simulate(hoc_init, mock_h, run_simulate): # import allensdk.eclipse_debug - - mock_utils = Mock(name='mock_utils', - h=mock_h) - with patch('allensdk.internal.api.queries.biophysical_module_reader.BiophysicalModuleReader', - MagicMock(name="bio_mod_reader")): - with patch('allensdk.model.biophysical.runner.save_nwb', - MagicMock(name="save_nwb")): - with patch('allensdk.model.biophysical.runner.NwbDataSet', - MagicMock(name='nwb_data_set')): - with patch('allensdk.model.biophysical.runner.copy', - MagicMock(name='shutil_copy')): - with patch('allensdk.model.biophysical.utils.create_utils', - return_value=mock_utils): - with patch(builtins.__name__ + ".open", - mock_open( - read_data=MANIFEST_JSON)): - fit_description = Config().load('manifest.json') + mock_utils = Mock(name="mock_utils", h=mock_h) + + with patch( + "allensdk.internal.api.queries.biophysical_module_reader.BiophysicalModuleReader", + MagicMock(name="bio_mod_reader"), + ): + with patch("allensdk.model.biophysical.runner.save_nwb", MagicMock(name="save_nwb")): + with patch("allensdk.model.biophysical.runner.NwbDataSet", MagicMock(name="nwb_data_set")): + with patch("allensdk.model.biophysical.runner.copy", MagicMock(name="shutil_copy")): + with patch("allensdk.model.biophysical.utils.create_utils", return_value=mock_utils): + with patch(builtins.__name__ + ".open", mock_open(read_data=MANIFEST_JSON)): + fit_description = Config().load("manifest.json") Utils.description = fit_description run_simulate.simulate() diff --git a/allensdk/test/internal/brain_observatory/test_roi_filter_utils.py b/allensdk/test/internal/brain_observatory/test_roi_filter_utils.py index 835dc0f9fc..0ce8236ecf 100644 --- a/allensdk/test/internal/brain_observatory/test_roi_filter_utils.py +++ b/allensdk/test/internal/brain_observatory/test_roi_filter_utils.py @@ -1,36 +1,22 @@ import pytest -from allensdk.internal.brain_observatory.roi_filter_utils import ( - get_indices_by_distance) +from allensdk.internal.brain_observatory.roi_filter_utils import get_indices_by_distance @pytest.mark.parametrize( - "tree_points, query_points, expected, exception", - [ - ( - [[0, 0], [0, 1], [0, 2], [1, 2], [2, 2]], - [[0, 0], [2, 2]], - [0, 4], - None - ), - ( - [[0, 0], [0, 1], [0, 2], [1, 2], [2, 2]], - [[0, 0.4], [0.1, 0.6]], - [0, 1], - pytest.raises(AssertionError, - match="Max match distance greater than 0") - ), - ( - [], - [], - [], - pytest.raises(ValueError, - match=("number of dimensions is incorrect. " - "Expected 2 got 1")) - ) - ]) -def test_get_indices_by_distance(tree_points, query_points, - expected, exception): + "tree_points, query_points, expected, exception", + [ + ([[0, 0], [0, 1], [0, 2], [1, 2], [2, 2]], [[0, 0], [2, 2]], [0, 4], None), + ( + [[0, 0], [0, 1], [0, 2], [1, 2], [2, 2]], + [[0, 0.4], [0.1, 0.6]], + [0, 1], + pytest.raises(AssertionError, match="Max match distance greater than 0"), + ), + ([], [], [], pytest.raises(ValueError, match=("number of dimensions is incorrect. Expected 2 got 1"))), + ], +) +def test_get_indices_by_distance(tree_points, query_points, expected, exception): """tests exceptions with simple 2D vectors. Actual code has 5D vectors for a basic cell-matching to [minx, miny, maxx, maxy, area] """ diff --git a/allensdk/test/internal/brain_observatory/test_run_ophys_time_sync.py b/allensdk/test/internal/brain_observatory/test_run_ophys_time_sync.py index d1de2dda09..d8ffc6e1a8 100644 --- a/allensdk/test/internal/brain_observatory/test_run_ophys_time_sync.py +++ b/allensdk/test/internal/brain_observatory/test_run_ophys_time_sync.py @@ -1,4 +1,4 @@ -""" Tests for the executable that synchronizes distinct data streams within an +"""Tests for the executable that synchronizes distinct data streams within an ophys experiment. For tests of the logic used by this executable, see test_time_sync """ @@ -13,7 +13,10 @@ import allensdk from allensdk.internal.pipeline_modules.run_ophys_time_sync import ( - TimeSyncOutputs, TimeSyncWriter, check_stimulus_delay, run_ophys_time_sync + TimeSyncOutputs, + TimeSyncWriter, + check_stimulus_delay, + run_ophys_time_sync, ) @@ -32,38 +35,37 @@ def outputs(): np.linspace(3, 4, 10), np.arange(10), np.arange(10, 20), - np.arange(20, 30) + np.arange(20, 30), ) @pytest.fixture def writer(tmpdir_factory): tmpdir_path = str(tmpdir_factory.mktemp("run_ophys_time_sync_tests")) - return TimeSyncWriter( - os.path.join(tmpdir_path, "data.h5"), - os.path.join(tmpdir_path, "output.json") - ) + return TimeSyncWriter(os.path.join(tmpdir_path, "data.h5"), os.path.join(tmpdir_path, "output.json")) def test_validate_paths_writable(writer): try: writer.validate_paths() except Exception as err: - pytest.fail( - f"expected no error. Got: {err.__class__.__name__}(\"{err}\")") - - -@pytest.mark.parametrize("h5_key,expected", [ - ["stimulus_alignment", np.arange(10)], - ["eye_tracking_alignment", np.arange(10, 20)], - ["body_camera_alignment", np.arange(20, 30)], - ["twop_vsync_fall", np.linspace(0, 1, 10)], - ["ophys_delta", 0], - ["stim_delta", 1], - ["stim_delay", 0.35], - ["eye_delta", 2], - ["behavior_delta", 3] -]) + pytest.fail(f'expected no error. Got: {err.__class__.__name__}("{err}")') + + +@pytest.mark.parametrize( + "h5_key,expected", + [ + ["stimulus_alignment", np.arange(10)], + ["eye_tracking_alignment", np.arange(10, 20)], + ["body_camera_alignment", np.arange(20, 30)], + ["twop_vsync_fall", np.linspace(0, 1, 10)], + ["ophys_delta", 0], + ["stim_delta", 1], + ["stim_delay", 0.35], + ["eye_delta", 2], + ["behavior_delta", 3], + ], +) def test_write_output_h5(writer, outputs, h5_key, expected): writer.write_output_h5(outputs) @@ -76,15 +78,18 @@ def test_write_output_h5(writer, outputs, h5_key, expected): assert obtained[()] == expected -@pytest.mark.parametrize("json_key,expected", [ - ["allensdk_version", allensdk.__version__], - ["experiment_id", 100], - ["ophys_delta", 0], - ["stim_delta", 1], - ["stim_delay", 0.35], - ["eye_delta", 2], - ["behavior_delta", 3] -]) +@pytest.mark.parametrize( + "json_key,expected", + [ + ["allensdk_version", allensdk.__version__], + ["experiment_id", 100], + ["ophys_delta", 0], + ["stim_delta", 1], + ["stim_delay", 0.35], + ["eye_delta", 2], + ["behavior_delta", 3], + ], +) def test_write_output_json(writer, outputs, json_key, expected): writer.write_output_json(outputs) @@ -113,12 +118,7 @@ class Aligner(NamedTuple): corrected_eye_video_timestamps: np.ndarray corrected_behavior_video_timestamps: np.ndarray - aligner = Aligner( - (np.arange(10), 0, 0.5), - (np.arange(10), 1), - (np.arange(10), 2), - (np.arange(10), 3) - ) + aligner = Aligner((np.arange(10), 0, 0.5), (np.arange(10), 1), (np.arange(10), 2), (np.arange(10), 3)) obtained = run_ophys_time_sync(aligner, 100, 0.0, 2.0) @@ -137,9 +137,8 @@ class Aligner(NamedTuple): ["behavior_times", np.arange(10)], ["stimulus_alignment", np.arange(10)], ["eye_alignment", np.arange(10)], - ["behavior_alignment", np.arange(10)] + ["behavior_alignment", np.arange(10)], ]: - current_obt = getattr(obtained, name) if isinstance(expected, np.ndarray): @@ -148,9 +147,6 @@ class Aligner(NamedTuple): match = expected == current_obt if not match: - mismatches.append( - f"{name} mismatched: expected {expected}, " - f"obtained {current_obt}" - ) + mismatches.append(f"{name} mismatched: expected {expected}, obtained {current_obt}") assert len(mismatches) == 0, "\n" + "\n".join(mismatches) diff --git a/allensdk/test/internal/brain_observatory/test_time_sync.py b/allensdk/test/internal/brain_observatory/test_time_sync.py index ba7692049a..79a2d7b130 100644 --- a/allensdk/test/internal/brain_observatory/test_time_sync.py +++ b/allensdk/test/internal/brain_observatory/test_time_sync.py @@ -13,7 +13,7 @@ ASSUMED_DELAY = 0.0351 -data_file = str(files('allensdk.test.internal.brain_observatory').joinpath("time_sync_test_data.json")) +data_file = str(files("allensdk.test.internal.brain_observatory").joinpath("time_sync_test_data.json")) test_data = json.load(open(data_file, "r")) data_skip = False @@ -22,8 +22,8 @@ # Functions from lims2_modules ophys_time_sync.py for regression testing -MIN_BOUND = .03 -MAX_BOUND = .04 +MIN_BOUND = 0.03 +MAX_BOUND = 0.04 mock_keys = { @@ -33,7 +33,7 @@ "eye_camera": "cam2_exposure", "behavior_camera": "cam1_exposure", "acquiring": "2p_acquiring", - "lick_sensor": "lick_1" + "lick_sensor": "lick_1", } @@ -42,6 +42,7 @@ class MockSyncDataset(Dataset): Mock the Dataset class so it doesn't load an h5 file upon initialization. """ + def __init__(self, data, line_labels=None): self.dfile = data self.line_labels = line_labels @@ -59,9 +60,7 @@ def calculate_stimulus_alignment(stim_time, valid_twop_vsync_fall): stimulus_alignment = np.empty(len(stim_time)) for index in range(len(stim_time)): - crossings = np.nonzero( - np.ediff1d( - np.sign(valid_twop_vsync_fall - stim_time[index])) > 0) + crossings = np.nonzero(np.ediff1d(np.sign(valid_twop_vsync_fall - stim_time[index])) > 0) try: stimulus_alignment[index] = int(crossings[0][0]) except: # noqa: E722 @@ -71,23 +70,19 @@ def calculate_stimulus_alignment(stim_time, valid_twop_vsync_fall): def calculate_valid_twop_vsync_fall(sync_data, sample_frequency): - twop_vsync_fall = sync_data.get_falling_edges('2p_vsync') /\ - sample_frequency + twop_vsync_fall = sync_data.get_falling_edges("2p_vsync") / sample_frequency if len(twop_vsync_fall) == 0: - raise ValueError('Error: twop_vsync_fall length is 0, possible ' - 'invalid, missing, and/or bad data') + raise ValueError("Error: twop_vsync_fall length is 0, possible invalid, missing, and/or bad data") ophys_start = twop_vsync_fall[0] - valid_twop_vsync_fall = twop_vsync_fall[np.where( - twop_vsync_fall > ophys_start)[0]] + valid_twop_vsync_fall = twop_vsync_fall[np.where(twop_vsync_fall > ophys_start)[0]] return valid_twop_vsync_fall def calculate_stim_vsync_fall(sync_data, sample_frequency): - stim_vsync_fall = sync_data.get_falling_edges('stim_vsync')[0:] /\ - sample_frequency + stim_vsync_fall = sync_data.get_falling_edges("stim_vsync")[0:] / sample_frequency return stim_vsync_fall @@ -116,40 +111,33 @@ def find_start(twop_vsync_fall): return start_index -def sync_camera_stimulus(sync_data, sample_frequency, camera, - ophys_experiment_id): - twop_vsync_fall = sync_data.get_falling_edges('2p_vsync') /\ - sample_frequency +def sync_camera_stimulus(sync_data, sample_frequency, camera, ophys_experiment_id): + twop_vsync_fall = sync_data.get_falling_edges("2p_vsync") / sample_frequency if len(twop_vsync_fall) == 0: - raise ValueError('Error: twop_vsync_fall length is 0, ' - 'possible invalid, missing, and/or bad data') + raise ValueError("Error: twop_vsync_fall length is 0, possible invalid, missing, and/or bad data") try: - twop_acquiring = sync_data.get_rising_edges('2p_acquiring') + twop_acquiring = sync_data.get_rising_edges("2p_acquiring") ophys_start = twop_acquiring / sample_frequency except: # noqa: E722 ophys_start = [find_start(twop_vsync_fall)] - twop_vsync_fall = twop_vsync_fall[np.where( - twop_vsync_fall > ophys_start)[0]] + twop_vsync_fall = twop_vsync_fall[np.where(twop_vsync_fall > ophys_start)[0]] cam_fall = None if camera == 1: - cam_fall = sync_data.get_falling_edges('cam1_exposure') /\ - sample_frequency + cam_fall = sync_data.get_falling_edges("cam1_exposure") / sample_frequency elif camera == 2: - cam_fall = sync_data.get_falling_edges('cam2_exposure') /\ - sample_frequency + cam_fall = sync_data.get_falling_edges("cam2_exposure") / sample_frequency else: - raise ValueError(f'Error: camera value {camera} is invalid') + raise ValueError(f"Error: camera value {camera} is invalid") frames = np.zeros((len(twop_vsync_fall), 1)) for i in range(len(frames)): - crossings = np.nonzero( - np.ediff1d(np.sign(cam_fall - twop_vsync_fall[i])) > 0) + crossings = np.nonzero(np.ediff1d(np.sign(cam_fall - twop_vsync_fall[i])) > 0) try: frames[i] = crossings[0][0] except: # noqa: E722 @@ -157,6 +145,7 @@ def sync_camera_stimulus(sync_data, sample_frequency, camera, return frames + # End of regression functions @@ -178,7 +167,7 @@ def scientifica_input(): def input_json(tmpdir_factory): output_file = str(tmpdir_factory.mktemp("test").join("output.h5")) input_data = test_data["nikon"].copy() - input_data['output_file'] = output_file + input_data["output_file"] = output_file json_file = str(tmpdir_factory.mktemp("test").join("input.json")) with open(json_file, "w") as f: json.dump(input_data, f) @@ -206,7 +195,7 @@ def test_get_alignment_array(): def test_regression_valid_2p_timestamps(nikon_input, scientifica_input): sync_file = nikon_input.pop("sync_file") aligner = ts.OphysTimeAligner(sync_file, **nikon_input) - freq = aligner.dataset.meta_data['ni_daq']['counter_output_freq'] + freq = aligner.dataset.meta_data["ni_daq"]["counter_output_freq"] old_times = calculate_valid_twop_vsync_fall(aligner.dataset, freq) new_times = aligner.ophys_timestamps assert np.allclose(new_times[1:], old_times) @@ -214,7 +203,7 @@ def test_regression_valid_2p_timestamps(nikon_input, scientifica_input): # old scientifica used falling edges as timestamps incorrectly sync_file = scientifica_input.pop("sync_file") aligner = ts.OphysTimeAligner(sync_file, **scientifica_input) - freq = aligner.dataset.meta_data['ni_daq']['counter_output_freq'] + freq = aligner.dataset.meta_data["ni_daq"]["counter_output_freq"] old_times = calculate_valid_twop_vsync_fall(aligner.dataset, freq) new_times = aligner.ophys_timestamps assert len(new_times) - len(old_times) == 1 @@ -227,21 +216,18 @@ def test_regression_stim_timestamps(nikon_input, scientifica_input): for input_data in [nikon_input, scientifica_input]: sync_file = input_data.pop("sync_file") aligner = ts.OphysTimeAligner(sync_file, **input_data) - freq = aligner.dataset.meta_data['ni_daq']['counter_output_freq'] + freq = aligner.dataset.meta_data["ni_daq"]["counter_output_freq"] old_times = calculate_stim_vsync_fall(aligner.dataset, freq) assert np.allclose(aligner.stim_timestamps, old_times) @pytest.mark.skipif(data_skip, reason="No sync or data") -def test_regression_calculate_stimulus_alignment(nikon_input, - scientifica_input): +def test_regression_calculate_stimulus_alignment(nikon_input, scientifica_input): for input_data in [nikon_input, scientifica_input]: sync_file = input_data.pop("sync_file") aligner = ts.OphysTimeAligner(sync_file, **input_data) - old_align = calculate_stimulus_alignment(aligner.stim_timestamps, - aligner.ophys_timestamps) - new_align = ts.get_alignment_array(aligner.ophys_timestamps, - aligner.stim_timestamps) + old_align = calculate_stimulus_alignment(aligner.stim_timestamps, aligner.ophys_timestamps) + new_align = ts.get_alignment_array(aligner.ophys_timestamps, aligner.stim_timestamps) # Old alignment assigned simultaneous stim frames to the previous ophys # frame. Methods should only differ when ophys and stim are identical. @@ -250,32 +236,30 @@ def test_regression_calculate_stimulus_alignment(nikon_input, mis_s = aligner.stim_timestamps[mismatch] assert np.all(mis_o == mis_s) # Occurence of mismatch should be rare - assert len(mis_o) < 0.005*len(aligner.ophys_timestamps) + assert len(mis_o) < 0.005 * len(aligner.ophys_timestamps) @pytest.mark.skipif(data_skip, reason="No sync or data") -def test_regression_calculate_camera_alignment(nikon_input, - scientifica_input): +def test_regression_calculate_camera_alignment(nikon_input, scientifica_input): for input_data in [nikon_input, scientifica_input]: sync_file = input_data.pop("sync_file") aligner = ts.OphysTimeAligner(sync_file, **input_data) - freq = aligner.dataset.meta_data['ni_daq']['counter_output_freq'] + freq = aligner.dataset.meta_data["ni_daq"]["counter_output_freq"] old_eye_align = sync_camera_stimulus(aligner.dataset, freq, 2, 1) # old alignment throws out the first ophys timestamp - new_eye_align = ts.get_alignment_array(aligner.eye_video_timestamps, - aligner.ophys_timestamps[1:], - int_method=np.ceil) + new_eye_align = ts.get_alignment_array( + aligner.eye_video_timestamps, aligner.ophys_timestamps[1:], int_method=np.ceil + ) mismatch = np.where(old_eye_align[:, 0] != new_eye_align) - mis_e = \ - aligner.eye_video_timestamps[new_eye_align[mismatch].astype(int)] + mis_e = aligner.eye_video_timestamps[new_eye_align[mismatch].astype(int)] mis_o = aligner.ophys_timestamps[1:][mismatch] - mis_o_plus = aligner.ophys_timestamps[1:][(mismatch[0]+1,)] + mis_o_plus = aligner.ophys_timestamps[1:][(mismatch[0] + 1,)] # New method should only disagree when old method was wrong (old method # set an eye tracking frame to an earlier ophys frame). assert np.all(mis_o < mis_e) assert np.all(mis_o_plus >= mis_e) # Occurence of mismatch should be rare - assert len(mis_o) < 0.005*len(aligner.ophys_timestamps[1:]) + assert len(mis_o) < 0.005 * len(aligner.ophys_timestamps[1:]) @pytest.mark.parametrize("eye_data_length", (None, 5000, 6000)) @@ -287,8 +271,7 @@ def test_get_corrected_eye_times(eye_data_length): aligner = ts.OphysTimeAligner("test") aligner.eye_data_length = eye_data_length - with patch.object(ts.Dataset, "get_falling_edges", - return_value=true_times) as mock_falling: + with patch.object(ts.Dataset, "get_falling_edges", return_value=true_times) as mock_falling: with patch("logging.info") as mock_log: times, delta = aligner.corrected_eye_video_timestamps @@ -315,8 +298,7 @@ def test_get_corrected_behavior_times(behavior_data_length): aligner = ts.OphysTimeAligner("test") aligner.behavior_data_length = behavior_data_length - with patch.object(ts.Dataset, "get_falling_edges", - return_value=true_times) as mock_falling: + with patch.object(ts.Dataset, "get_falling_edges", return_value=true_times) as mock_falling: with patch("logging.info") as mock_log: times, delta = aligner.corrected_behavior_video_timestamps @@ -334,14 +316,10 @@ def test_get_corrected_behavior_times(behavior_data_length): assert delta == (len(true_times) - behavior_data_length) -@pytest.mark.parametrize("stim_data_length,start_delay", [ - (None, False), - (None, True), - (5000, False), - (5000, True), - (6000, False), - (6000, True) - ]) +@pytest.mark.parametrize( + "stim_data_length,start_delay", + [(None, False), (None, True), (5000, False), (5000, True), (6000, False), (6000, True)], +) def test_get_corrected_stim_times(stim_data_length, start_delay): true_falling = np.arange(0, 60, 0.01) true_rising = true_falling + 0.005 @@ -354,15 +332,11 @@ def test_get_corrected_stim_times(stim_data_length, start_delay): aligner = ts.OphysTimeAligner("test") aligner.stim_data_length = stim_data_length - with patch.object(ts, "calculate_monitor_delay", - return_value=ASSUMED_DELAY): - with patch.object(ts.Dataset, "get_falling_edges", - return_value=true_falling): - with patch.object(ts.Dataset, "get_rising_edges", - return_value=true_rising) as mock_rising: + with patch.object(ts, "calculate_monitor_delay", return_value=ASSUMED_DELAY): + with patch.object(ts.Dataset, "get_falling_edges", return_value=true_falling): + with patch.object(ts.Dataset, "get_rising_edges", return_value=true_rising) as mock_rising: with patch("logging.info") as mock_log: - times, delta, stim_delay = \ - aligner.corrected_stim_timestamps + times, delta, stim_delay = aligner.corrected_stim_timestamps if stim_data_length is None: mock_log.assert_called_once() @@ -395,13 +369,10 @@ def test_get_corrected_ophys_times_nikon(ophys_data_length): aligner = ts.OphysTimeAligner("test", "NIKONA1RMP") aligner.ophys_data_length = ophys_data_length - with patch.object(ts.Dataset, "get_falling_edges", - return_value=true_times): - with patch.object(ts.Dataset, "get_rising_edges", - return_value=[0]): + with patch.object(ts.Dataset, "get_falling_edges", return_value=true_times): + with patch.object(ts.Dataset, "get_rising_edges", return_value=[0]): with patch("logging.info") as mock_log: - if ophys_data_length is not None and \ - ophys_data_length > len(true_times): + if ophys_data_length is not None and ophys_data_length > len(true_times): with pytest.raises(ValueError): times, delta = aligner.corrected_ophys_timestamps else: @@ -439,60 +410,55 @@ def test_module(input_json): aligner = ts.OphysTimeAligner(sync_file, **input_data) with h5py.File(output_file) as f: t, d = aligner.corrected_ophys_timestamps - assert np.all(t == f['twop_vsync_fall'][()]) - assert np.all(d == f['ophys_delta'][()]) + assert np.all(t == f["twop_vsync_fall"][()]) + assert np.all(d == f["ophys_delta"][()]) st, sd, stim_delay = aligner.corrected_stim_timestamps align = ts.get_alignment_array(t, st) - assert np.allclose(align, f['stimulus_alignment'][()], - equal_nan=True) - assert np.all(sd == f['stim_delta'][()]) + assert np.allclose(align, f["stimulus_alignment"][()], equal_nan=True) + assert np.all(sd == f["stim_delta"][()]) et, ed = aligner.corrected_eye_video_timestamps align = ts.get_alignment_array(et, t, int_method=np.ceil) - assert np.allclose(align, f['eye_tracking_alignment'][()], - equal_nan=True) - assert np.all(ed == f['eye_delta'][()]) + assert np.allclose(align, f["eye_tracking_alignment"][()], equal_nan=True) + assert np.all(ed == f["eye_delta"][()]) bt, bd = aligner.corrected_behavior_video_timestamps align = ts.get_alignment_array(bt, t, int_method=np.ceil) - assert np.allclose(align, f['body_camera_alignment'][()], - equal_nan=True) + assert np.allclose(align, f["body_camera_alignment"][()], equal_nan=True) @pytest.mark.parametrize( "sync_dset,stim_times,transition_interval,expected", [ - (np.array([1.0, 2.0, 3.0, 4.0, 5.0]), - np.array([0.99, 1.99, 2.99, 3.99, 4.99]), 1, 0.01), - (np.array([1.0, 2.0, 3.0, 4.0]), - np.array([0.95, 2.0, 2.95, 4.0]), 1, 0.025), - (np.array([1.0]), np.array([1.0]), 1, 0.0) + (np.array([1.0, 2.0, 3.0, 4.0, 5.0]), np.array([0.99, 1.99, 2.99, 3.99, 4.99]), 1, 0.01), + (np.array([1.0, 2.0, 3.0, 4.0]), np.array([0.95, 2.0, 2.95, 4.0]), 1, 0.025), + (np.array([1.0]), np.array([1.0]), 1, 0.0), ], ) -def test_monitor_delay(sync_dset, stim_times, transition_interval, expected, - monkeypatch): - monkeypatch.setattr(ts, "get_real_photodiode_events", - mock_get_real_photodiode_events) - pytest.approx(expected, - ts.calculate_monitor_delay(sync_dset, stim_times, "key", - transition_interval)) +def test_monitor_delay(sync_dset, stim_times, transition_interval, expected, monkeypatch): + monkeypatch.setattr(ts, "get_real_photodiode_events", mock_get_real_photodiode_events) + pytest.approx(expected, ts.calculate_monitor_delay(sync_dset, stim_times, "key", transition_interval)) @pytest.mark.parametrize( "sync_dset,stim_times,transition_interval", [ # Negative - (np.array([1.0, 2.0, 3.0]), np.array([0.9, 1.9, 2.9]), 1,), + ( + np.array([1.0, 2.0, 3.0]), + np.array([0.9, 1.9, 2.9]), + 1, + ), # Too big - (np.array([1.0, 2.0, 3.0, 4.0]), np.array([1.1, 2.1, 3.1, 4.1]), 1,), + ( + np.array([1.0, 2.0, 3.0, 4.0]), + np.array([1.1, 2.1, 3.1, 4.1]), + 1, + ), ], ) -def test_monitor_delay_raises_error( - sync_dset, stim_times, transition_interval, - monkeypatch): - monkeypatch.setattr(ts, "get_real_photodiode_events", - mock_get_real_photodiode_events) +def test_monitor_delay_raises_error(sync_dset, stim_times, transition_interval, monkeypatch): + monkeypatch.setattr(ts, "get_real_photodiode_events", mock_get_real_photodiode_events) with pytest.raises(ValueError): - ts.calculate_monitor_delay(sync_dset, stim_times, - "key", transition_interval) + ts.calculate_monitor_delay(sync_dset, stim_times, "key", transition_interval) @pytest.mark.parametrize( @@ -504,7 +470,7 @@ def test_monitor_delay_raises_error( (np.array([]), lambda x: x < 1, 1, None), (np.array([1, 2, 3]), lambda x: x < 3, 4, None), (np.array([1, 2, 2, 3, 2]), lambda x: x == 2, 3, None), - ] + ], ) def test_find_n(arr, cond, n, expected): assert expected == ts._find_n(arr, n, cond) @@ -519,7 +485,7 @@ def test_find_n(arr, cond, n, expected): (np.array([]), lambda x: x < 1, 1, None), (np.array([1, 2, 3]), lambda x: x < 3, 4, None), (np.array([1, 2, 2, 3, 2]), lambda x: x == 2, 3, None), - ] + ], ) def test_find_last_n(arr, cond, n, expected): assert expected == ts._find_last_n(arr, n, cond) @@ -528,19 +494,18 @@ def test_find_last_n(arr, cond, n, expected): @pytest.mark.parametrize( "sync_dset,expected", [ - ([0.25, 0.5, 0.75, 1., 2., 3., 5., 5.75], [1., 2., 3.]), - ([1., 2., 3., 4.], [1., 2., 3., 4.]), + ([0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 5.0, 5.75], [1.0, 2.0, 3.0]), + ([1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]), # false alarm start - ([0.25, 1., 2., 2.1, 2.2, 3., 4., 5.], [3., 4., 5.]), + ([0.25, 1.0, 2.0, 2.1, 2.2, 3.0, 4.0, 5.0], [3.0, 4.0, 5.0]), # false alarm end - ([0.25, 1., 2., 3., 4., 4.5, 5.1, 6.1], [1., 2., 3., 4.]), + ([0.25, 1.0, 2.0, 3.0, 4.0, 4.5, 5.1, 6.1], [1.0, 2.0, 3.0, 4.0]), ], ) def test_get_photodiode_events(sync_dset, expected, monkeypatch): ds = MockSyncDataset(sync_dset) monkeypatch.setattr(ds, "get_events_by_line", mock_get_events_by_line) - np.testing.assert_array_equal( - expected, ts.get_photodiode_events(ds, sync_dset)) + np.testing.assert_array_equal(expected, ts.get_photodiode_events(ds, sync_dset)) @pytest.mark.parametrize( @@ -548,8 +513,8 @@ def test_get_photodiode_events(sync_dset, expected, monkeypatch): [ ([]), ([0.25, 0.25]), - ([1., 2.]), - ] + ([1.0, 2.0]), + ], ) def test_photodiode_events_error_if_none_found(sync_dset, monkeypatch): ds = MockSyncDataset(sync_dset) @@ -558,11 +523,14 @@ def test_photodiode_events_error_if_none_found(sync_dset, monkeypatch): ts.get_photodiode_events(ds, sync_dset) -@pytest.mark.parametrize("deserialized_pkl,expected", [ - ({"vsynccount": 100}, 100), - ({"items": {"behavior": {"intervalsms": [2, 2, 2, 2, 2]}}}, 6), - ({"vsynccount": 20, "items": {"behavior": {"intervalsms": [3, 3]}}}, 20) -]) +@pytest.mark.parametrize( + "deserialized_pkl,expected", + [ + ({"vsynccount": 100}, 100), + ({"items": {"behavior": {"intervalsms": [2, 2, 2, 2, 2]}}}, 6), + ({"vsynccount": 20, "items": {"behavior": {"intervalsms": [3, 3]}}}, 20), + ], +) def test_get_stim_data_length(monkeypatch, deserialized_pkl, expected): def mock_read_pickle(*args, **kwargs): return deserialized_pkl @@ -574,109 +542,143 @@ def mock_read_pickle(*args, **kwargs): @pytest.mark.parametrize( - "sync_dset, line_labels, expected_line_labels, expected_log", - [ - (None, ['2p_vsync', 'stim_vsync', 'stim_photodiode', - 'acq_trigger', '', 'cam1_exposure', - 'cam2_exposure', 'lick_sensor'], - { - "photodiode": "stim_photodiode", - "2p": "2p_vsync", - "stimulus": "stim_vsync", - "eye_camera": "cam2_exposure", - "behavior_camera": "cam1_exposure", - "lick_sensor": "lick_sensor", - "acquiring": "acq_trigger"}, - []), - (None, ['2p_vsync', 'stim_vsync', 'photodiode', - 'acq_trigger', 'behavior_monitoring', - 'eye_tracking', 'lick_1'], - { - "photodiode": "photodiode", - "2p": "2p_vsync", - "stimulus": "stim_vsync", - "eye_camera": "eye_tracking", - "behavior_camera": "behavior_monitoring", - "lick_sensor": "lick_1", - "acquiring": "acq_trigger"}, - []), - (None, ['2p_vsync', 'stim_vsync', 'photodiode', - 'acq_trigger', '', 'behavior_monitoring', - 'lick_1'], - { - "photodiode": "photodiode", - "2p": "2p_vsync", - "stimulus": "stim_vsync", - "behavior_camera": "behavior_monitoring", - "lick_sensor": "lick_1", - "acquiring": "acq_trigger"}, - [('root', 30, 'Could not find valid lines for the ' - 'following data sources'), - ('root', 30, "eye_camera (valid line label(s) = " - "['cam2_exposure', 'eye_tracking', " - "'eye_frame_received']")]), - (None, [], - {}, - [('root', 30, - 'Could not find valid lines for the ' - 'following data sources'), - ('root', 30, - "photodiode (valid line label(s) = " - "['stim_photodiode', 'photodiode']"), - ('root', 30, - "2p (valid line label(s) = ['2p_vsync']"), - ('root', 30, - "stimulus (valid line label(s) = " - "['stim_vsync', 'vsync_stim']"), - ('root', 30, - "eye_camera (valid line label(s) = " - "['cam2_exposure', 'eye_tracking', 'eye_frame_received']"), - ('root', 30, "behavior_camera (valid line label(s) " - "= ['cam1_exposure', " - "'behavior_monitoring', " - "'beh_frame_received']"), - ('root', 30, "acquiring (valid line label(s) = " - "['2p_acquiring', 'acq_trigger']"), - ('root', 30, "lick_sensor (valid line label(s) = " - "['lick_1', 'lick_sensor']")]), - (None, ['', 'stim_vsync', 'photodiode', 'acq_trigger', - 'eye_tracking', 'lick_1', 'acq_trigger', - 'cam1_exposure'], - { - "photodiode": "photodiode", - "stimulus": "stim_vsync", - "eye_camera": "eye_tracking", - "behavior_camera": "cam1_exposure", - "lick_sensor": "lick_1", - "acquiring": "acq_trigger"}, - [('root', 30, 'Could not find valid lines for the ' - 'following data sources'), - ('root', 30, "2p (valid line label(s) = " - "['2p_vsync']")]), - (None, ['barcode_ephys', 'vsync_stim', - 'stim_photodiode', 'stim_running', - 'beh_frame_received', 'eye_frame_received', - 'face_frame_received', 'stim_running_opto', - 'stim_trial_opto', 'face_came_frame_readout', - 'eye_cam_frame_readout', - 'beh_cam_frame_readout', 'face_cam_exposing', - 'eye_cam_exposing', 'beh_cam_exposing', - 'lick_sensor'], - { - "photodiode": "stim_photodiode", - "stimulus": "vsync_stim", - "eye_camera": "eye_frame_received", - "behavior_camera": "beh_frame_received", - "lick_sensor": "lick_sensor"}, - [('root', 30, 'Could not find valid lines for the ' - 'following data sources'), - ('root', 30, "2p (valid line label(s) = " - "['2p_vsync']"), - ('root', 30, "acquiring (valid line label(s) = " - "['2p_acquiring', 'acq_trigger']")]) - ]) -def test_get_keys(sync_dset, line_labels, expected_line_labels, expected_log, - caplog): + "sync_dset, line_labels, expected_line_labels, expected_log", + [ + ( + None, + [ + "2p_vsync", + "stim_vsync", + "stim_photodiode", + "acq_trigger", + "", + "cam1_exposure", + "cam2_exposure", + "lick_sensor", + ], + { + "photodiode": "stim_photodiode", + "2p": "2p_vsync", + "stimulus": "stim_vsync", + "eye_camera": "cam2_exposure", + "behavior_camera": "cam1_exposure", + "lick_sensor": "lick_sensor", + "acquiring": "acq_trigger", + }, + [], + ), + ( + None, + ["2p_vsync", "stim_vsync", "photodiode", "acq_trigger", "behavior_monitoring", "eye_tracking", "lick_1"], + { + "photodiode": "photodiode", + "2p": "2p_vsync", + "stimulus": "stim_vsync", + "eye_camera": "eye_tracking", + "behavior_camera": "behavior_monitoring", + "lick_sensor": "lick_1", + "acquiring": "acq_trigger", + }, + [], + ), + ( + None, + ["2p_vsync", "stim_vsync", "photodiode", "acq_trigger", "", "behavior_monitoring", "lick_1"], + { + "photodiode": "photodiode", + "2p": "2p_vsync", + "stimulus": "stim_vsync", + "behavior_camera": "behavior_monitoring", + "lick_sensor": "lick_1", + "acquiring": "acq_trigger", + }, + [ + ("root", 30, "Could not find valid lines for the following data sources"), + ( + "root", + 30, + "eye_camera (valid line label(s) = ['cam2_exposure', 'eye_tracking', 'eye_frame_received']", + ), + ], + ), + ( + None, + [], + {}, + [ + ("root", 30, "Could not find valid lines for the following data sources"), + ("root", 30, "photodiode (valid line label(s) = ['stim_photodiode', 'photodiode']"), + ("root", 30, "2p (valid line label(s) = ['2p_vsync']"), + ("root", 30, "stimulus (valid line label(s) = ['stim_vsync', 'vsync_stim']"), + ( + "root", + 30, + "eye_camera (valid line label(s) = ['cam2_exposure', 'eye_tracking', 'eye_frame_received']", + ), + ( + "root", + 30, + "behavior_camera (valid line label(s) " + "= ['cam1_exposure', " + "'behavior_monitoring', " + "'beh_frame_received']", + ), + ("root", 30, "acquiring (valid line label(s) = ['2p_acquiring', 'acq_trigger']"), + ("root", 30, "lick_sensor (valid line label(s) = ['lick_1', 'lick_sensor']"), + ], + ), + ( + None, + ["", "stim_vsync", "photodiode", "acq_trigger", "eye_tracking", "lick_1", "acq_trigger", "cam1_exposure"], + { + "photodiode": "photodiode", + "stimulus": "stim_vsync", + "eye_camera": "eye_tracking", + "behavior_camera": "cam1_exposure", + "lick_sensor": "lick_1", + "acquiring": "acq_trigger", + }, + [ + ("root", 30, "Could not find valid lines for the following data sources"), + ("root", 30, "2p (valid line label(s) = ['2p_vsync']"), + ], + ), + ( + None, + [ + "barcode_ephys", + "vsync_stim", + "stim_photodiode", + "stim_running", + "beh_frame_received", + "eye_frame_received", + "face_frame_received", + "stim_running_opto", + "stim_trial_opto", + "face_came_frame_readout", + "eye_cam_frame_readout", + "beh_cam_frame_readout", + "face_cam_exposing", + "eye_cam_exposing", + "beh_cam_exposing", + "lick_sensor", + ], + { + "photodiode": "stim_photodiode", + "stimulus": "vsync_stim", + "eye_camera": "eye_frame_received", + "behavior_camera": "beh_frame_received", + "lick_sensor": "lick_sensor", + }, + [ + ("root", 30, "Could not find valid lines for the following data sources"), + ("root", 30, "2p (valid line label(s) = ['2p_vsync']"), + ("root", 30, "acquiring (valid line label(s) = ['2p_acquiring', 'acq_trigger']"), + ], + ), + ], +) +def test_get_keys(sync_dset, line_labels, expected_line_labels, expected_log, caplog): """ Test Cases: 1) Test Case with V2 keys diff --git a/allensdk/test/internal/conftest.py b/allensdk/test/internal/conftest.py index 557d1b85d9..26bc96e45d 100644 --- a/allensdk/test/internal/conftest.py +++ b/allensdk/test/internal/conftest.py @@ -1,8 +1,6 @@ import os - def pytest_ignore_collect(path, config): - ''' These tests (or the code they test) can only run on the local network at the Allen Institute for Brain Science. - ''' - return(os.getenv('TEST_COMPLETE') != 'true') and (os.getenv('TEST_INTERNAL') != 'true') + """These tests (or the code they test) can only run on the local network at the Allen Institute for Brain Science.""" + return (os.getenv("TEST_COMPLETE") != "true") and (os.getenv("TEST_INTERNAL") != "true") diff --git a/allensdk/test/internal/core/test_mouse_connectivity_cache_prerelease.py b/allensdk/test/internal/core/test_mouse_connectivity_cache_prerelease.py index 66d1c08a55..750d899a07 100644 --- a/allensdk/test/internal/core/test_mouse_connectivity_cache_prerelease.py +++ b/allensdk/test/internal/core/test_mouse_connectivity_cache_prerelease.py @@ -7,46 +7,46 @@ from allensdk.core import json_utilities -from allensdk.internal.core.mouse_connectivity_cache_prerelease \ - import MouseConnectivityCachePrerelease +from allensdk.internal.core.mouse_connectivity_cache_prerelease import MouseConnectivityCachePrerelease -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def mcc(fn_temp_dir): - storage_dirs = {"111" : os.path.join(fn_temp_dir, "111"), - "222" : os.path.join(fn_temp_dir, "222")} + storage_dirs = {"111": os.path.join(fn_temp_dir, "111"), "222": os.path.join(fn_temp_dir, "222")} - file_name = os.path.join(fn_temp_dir, 'storage_directories.json') + file_name = os.path.join(fn_temp_dir, "storage_directories.json") json_utilities.write(file_name, storage_dirs) - manifest_path = os.path.join(fn_temp_dir, 'manifest.json') - return MouseConnectivityCachePrerelease( - manifest_file=manifest_path, storage_directories_file_name=file_name) + manifest_path = os.path.join(fn_temp_dir, "manifest.json") + return MouseConnectivityCachePrerelease(manifest_file=manifest_path, storage_directories_file_name=file_name) + @pytest.fixture def experiments(): - return [{'id':111, - 'age' : "10 wks", - 'gender' : "M", - 'project_code' : "Connectional Atlas", - 'specimen_name' : "", - 'transgenic_line' : "", - 'workflow_state' : "passed", - 'workflows' : ["2P Serial Imaging"], - 'structure_id' : 184, - 'structure_name' : "Frontal pole, cerebral cortex", - 'structure_abbrev' : "FRP", - 'injection_structures' : [ - {'id' : 184, - 'name' : "Frontal pole, cerebral cortex", - 'abbreviation' : 'FRP'}, - {'id' : 993, - 'name' : "Secondary motor area", - 'abbreviation' : 'MOs'}]}] + return [ + { + "id": 111, + "age": "10 wks", + "gender": "M", + "project_code": "Connectional Atlas", + "specimen_name": "", + "transgenic_line": "", + "workflow_state": "passed", + "workflows": ["2P Serial Imaging"], + "structure_id": 184, + "structure_name": "Frontal pole, cerebral cortex", + "structure_abbrev": "FRP", + "injection_structures": [ + {"id": 184, "name": "Frontal pole, cerebral cortex", "abbreviation": "FRP"}, + {"id": 993, "name": "Secondary motor area", "abbreviation": "MOs"}, + ], + } + ] + @pytest.mark.prerelease def test_init(mcc, fn_temp_dir): - manifest_path = os.path.join(fn_temp_dir, 'manifest.json') + manifest_path = os.path.join(fn_temp_dir, "manifest.json") assert os.path.exists(manifest_path) @@ -54,16 +54,15 @@ def test_init(mcc, fn_temp_dir): def test_get_projection_density(mcc, fn_temp_dir): eye = np.eye(100) eid = 111 - path = os.path.join(fn_temp_dir, 'experiment_{0}'.format(eid), - 'projection_density_25.nrrd') + path = os.path.join(fn_temp_dir, "experiment_{0}".format(eid), "projection_density_25.nrrd") - with mock.patch('allensdk.internal.api.api_prerelease.ApiPrerelease.' - 'retrieve_file_from_storage', - new=lambda a, b, c: nrrd.write(c, eye)): + with mock.patch( + "allensdk.internal.api.api_prerelease.ApiPrerelease.retrieve_file_from_storage", + new=lambda a, b, c: nrrd.write(c, eye), + ): obtained, _ = mcc.get_projection_density(eid) - with mock.patch.object(mcc.api.grid_data_api.api, - "retrieve_file_from_storage") as mock_rtrv: + with mock.patch.object(mcc.api.grid_data_api.api, "retrieve_file_from_storage") as mock_rtrv: mcc.get_projection_density(eid) mock_rtrv.assert_not_called() @@ -75,16 +74,15 @@ def test_get_projection_density(mcc, fn_temp_dir): def test_get_injection_density(mcc, fn_temp_dir): eye = np.eye(100) eid = 111 - path = os.path.join(fn_temp_dir, 'experiment_{0}'.format(eid), - 'injection_density_25.nrrd') + path = os.path.join(fn_temp_dir, "experiment_{0}".format(eid), "injection_density_25.nrrd") - with mock.patch('allensdk.internal.api.api_prerelease.ApiPrerelease.' - 'retrieve_file_from_storage', - new=lambda a, b, c: nrrd.write(c, eye)): + with mock.patch( + "allensdk.internal.api.api_prerelease.ApiPrerelease.retrieve_file_from_storage", + new=lambda a, b, c: nrrd.write(c, eye), + ): obtained, _ = mcc.get_injection_density(eid) - with mock.patch.object(mcc.api.grid_data_api.api, - "retrieve_file_from_storage") as mock_rtrv: + with mock.patch.object(mcc.api.grid_data_api.api, "retrieve_file_from_storage") as mock_rtrv: mcc.get_injection_density(eid) mock_rtrv.assert_not_called() @@ -96,16 +94,15 @@ def test_get_injection_density(mcc, fn_temp_dir): def test_get_injection_fraction(mcc, fn_temp_dir): eye = np.eye(100) eid = 111 - path = os.path.join(fn_temp_dir, 'experiment_{0}'.format(eid), - 'injection_fraction_25.nrrd') + path = os.path.join(fn_temp_dir, "experiment_{0}".format(eid), "injection_fraction_25.nrrd") - with mock.patch('allensdk.internal.api.api_prerelease.ApiPrerelease.' - 'retrieve_file_from_storage', - new=lambda a, b, c: nrrd.write(c, eye)): + with mock.patch( + "allensdk.internal.api.api_prerelease.ApiPrerelease.retrieve_file_from_storage", + new=lambda a, b, c: nrrd.write(c, eye), + ): obtained, _ = mcc.get_injection_fraction(eid) - with mock.patch.object(mcc.api.grid_data_api.api, - "retrieve_file_from_storage") as mock_rtrv: + with mock.patch.object(mcc.api.grid_data_api.api, "retrieve_file_from_storage") as mock_rtrv: mcc.get_injection_fraction(eid) mock_rtrv.assert_not_called() @@ -117,16 +114,15 @@ def test_get_injection_fraction(mcc, fn_temp_dir): def test_get_data_mask(mcc, fn_temp_dir): eye = np.eye(100) eid = 111 - path = os.path.join(fn_temp_dir, 'experiment_{0}'.format(eid), - 'data_mask_25.nrrd') + path = os.path.join(fn_temp_dir, "experiment_{0}".format(eid), "data_mask_25.nrrd") - with mock.patch('allensdk.internal.api.api_prerelease.ApiPrerelease.' - 'retrieve_file_from_storage', - new=lambda a, b, c: nrrd.write(c, eye)): + with mock.patch( + "allensdk.internal.api.api_prerelease.ApiPrerelease.retrieve_file_from_storage", + new=lambda a, b, c: nrrd.write(c, eye), + ): obtained, _ = mcc.get_data_mask(eid) - with mock.patch.object(mcc.api.grid_data_api.api, - "retrieve_file_from_storage") as mock_rtrv: + with mock.patch.object(mcc.api.grid_data_api.api, "retrieve_file_from_storage") as mock_rtrv: mcc.get_data_mask(eid) mock_rtrv.assert_not_called() @@ -136,7 +132,6 @@ def test_get_data_mask(mcc, fn_temp_dir): @pytest.mark.prerelease def test_filter_experiments(mcc, experiments): - # ------------------------------------------------------------------------ # test cre cre = mcc.filter_experiments(experiments, cre=True) @@ -153,51 +148,53 @@ def test_filter_experiments(mcc, experiments): # ------------------------------------------------------------------------ # test age - pass_age = mcc.filter_experiments(experiments, age=['10 wks', '12 wks']) - fail_age = mcc.filter_experiments(experiments, age=['12 wks']) + pass_age = mcc.filter_experiments(experiments, age=["10 wks", "12 wks"]) + fail_age = mcc.filter_experiments(experiments, age=["12 wks"]) assert len(pass_age) == 1 assert not fail_age # ------------------------------------------------------------------------ # test gender - pass_gender = mcc.filter_experiments(experiments, gender=['MALE']) - fail_gender = mcc.filter_experiments(experiments, gender=['f']) + pass_gender = mcc.filter_experiments(experiments, gender=["MALE"]) + fail_gender = mcc.filter_experiments(experiments, gender=["f"]) assert len(pass_gender) == 1 assert not fail_gender # ------------------------------------------------------------------------ # test workflow-sate - pass_ws = mcc.filter_experiments(experiments, workflow_state=['qc', 'passed']) - fail_ws = mcc.filter_experiments(experiments, workflow_state=['failed']) + pass_ws = mcc.filter_experiments(experiments, workflow_state=["qc", "passed"]) + fail_ws = mcc.filter_experiments(experiments, workflow_state=["failed"]) assert len(pass_ws) == 1 assert not fail_ws # ------------------------------------------------------------------------ # test workflows - pass_w = mcc.filter_experiments(experiments, workflows=['2P SERial ImaGing']) - fail_w = mcc.filter_experiments(experiments, workflows=['trans-synaptic']) + pass_w = mcc.filter_experiments(experiments, workflows=["2P SERial ImaGing"]) + fail_w = mcc.filter_experiments(experiments, workflows=["trans-synaptic"]) assert len(pass_w) == 1 assert not fail_w # ------------------------------------------------------------------------ # test project_code - pass_pc = mcc.filter_experiments(experiments, project_code=['ConNECTIOnal Atlas']) - fail_pc = mcc.filter_experiments(experiments, project_code=['not a code']) + pass_pc = mcc.filter_experiments(experiments, project_code=["ConNECTIOnal Atlas"]) + fail_pc = mcc.filter_experiments(experiments, project_code=["not a code"]) assert len(pass_pc) == 1 assert not fail_pc # ------------------------------------------------------------------------ # test a bunch - conditions = dict(injection_structure_ids=[184, 98], - age=['10 wKS', '12 wks'], - gender=['maLE'], - workflow_state=['qC', 'pASsed'], - workflows=['2p serial imaging']) + conditions = dict( + injection_structure_ids=[184, 98], + age=["10 wKS", "12 wks"], + gender=["maLE"], + workflow_state=["qC", "pASsed"], + workflows=["2p serial imaging"], + ) passed = mcc.filter_experiments(experiments, **conditions) assert len(passed) == 1 diff --git a/allensdk/test/internal/gbm/test_generate_gbm_heatmap.py b/allensdk/test/internal/gbm/test_generate_gbm_heatmap.py index 62752f828e..c5d04f8bf6 100644 --- a/allensdk/test/internal/gbm/test_generate_gbm_heatmap.py +++ b/allensdk/test/internal/gbm/test_generate_gbm_heatmap.py @@ -12,100 +12,156 @@ def test_create_transcripts_for_genes(): - - analysis_run_gene_file = {"analysis_run_gene_path": TEST_GENE_FILE, "analysis_run_transcript_path": - TEST_TRANSCRIPT_FILE, "rna_well_id": 300173630} + analysis_run_gene_file = { + "analysis_run_gene_path": TEST_GENE_FILE, + "analysis_run_transcript_path": TEST_TRANSCRIPT_FILE, + "rna_well_id": 300173630, + } data = heatmap.create_transcripts_for_genes(analysis_run_gene_file) - d = [["gene_id", "transcript_id(s)"], - ["1000", "NM_001792_3"], - ["124989", "NM_001195192_1,NM_152347_4"], - ["100008586", "NM_001098405_1"]] + d = [ + ["gene_id", "transcript_id(s)"], + ["1000", "NM_001792_3"], + ["124989", "NM_001195192_1,NM_152347_4"], + ["100008586", "NM_001098405_1"], + ] expected_data = pd.DataFrame(data=d) - assert(expected_data.equals(data)) + assert expected_data.equals(data) def test_create_genes_for_transcripts(): - - analysis_run_transcript_file = {"analysis_run_gene_path": TEST_GENE_FILE, "analysis_run_transcript_path": - TEST_TRANSCRIPT_FILE, "rna_well_id": - 300173630} + analysis_run_transcript_file = { + "analysis_run_gene_path": TEST_GENE_FILE, + "analysis_run_transcript_path": TEST_TRANSCRIPT_FILE, + "rna_well_id": 300173630, + } data = heatmap.create_genes_for_transcripts(analysis_run_transcript_file) - d = [["transcript_id", "gene_id"], - ["NM_000015_2", "10"], - ["NM_130786_3", "1"], - ["tRNA-Tyr.100009601.chr14", "100"]] + d = [["transcript_id", "gene_id"], ["NM_000015_2", "10"], ["NM_130786_3", "1"], ["tRNA-Tyr.100009601.chr14", "100"]] expected_data = pd.DataFrame(data=d) - assert(expected_data.equals(data)) + assert expected_data.equals(data) def test_create_gene_fpkm_table(): - - analysis_run_records = [{"analysis_run_gene_path": TEST_GENE_FILE, "analysis_run_transcript_path": - TEST_TRANSCRIPT_FILE, "rna_well_id": 300173630}, - {"analysis_run_gene_path": TEST2_GENE_FILE, "analysis_run_transcript_path": - TEST2_TRANSCRIPT_FILE, "rna_well_id": 300173634}] + analysis_run_records = [ + { + "analysis_run_gene_path": TEST_GENE_FILE, + "analysis_run_transcript_path": TEST_TRANSCRIPT_FILE, + "rna_well_id": 300173630, + }, + { + "analysis_run_gene_path": TEST2_GENE_FILE, + "analysis_run_transcript_path": TEST2_TRANSCRIPT_FILE, + "rna_well_id": 300173634, + }, + ] data = heatmap.create_gene_fpkm_table(analysis_run_records) d = np.column_stack([["108.14", "5.10", "0.00"], ["11.05", "21.41", "11.57"]]) expected_data = pd.DataFrame(data=d, columns=[300173630, 300173634], index=[1000, 124989, 100008586]) - assert (expected_data.equals(data)) + assert expected_data.equals(data) def test_create_transcript_fpkm_table(): - - analysis_run_records = [{"analysis_run_gene_path": TEST_GENE_FILE, "analysis_run_transcript_path": - TEST_TRANSCRIPT_FILE, "rna_well_id": 300173630}, - {"analysis_run_gene_path": TEST2_GENE_FILE, "analysis_run_transcript_path": - TEST2_TRANSCRIPT_FILE, "rna_well_id": 300173634}] + analysis_run_records = [ + { + "analysis_run_gene_path": TEST_GENE_FILE, + "analysis_run_transcript_path": TEST_TRANSCRIPT_FILE, + "rna_well_id": 300173630, + }, + { + "analysis_run_gene_path": TEST2_GENE_FILE, + "analysis_run_transcript_path": TEST2_TRANSCRIPT_FILE, + "rna_well_id": 300173634, + }, + ] data = heatmap.create_transcript_fpkm_table(analysis_run_records) d = np.column_stack([["10.00", "100.00", "100.00"], ["0.00", "100.00", "0.00"]]) - expected_data = pd.DataFrame(data=d, columns=[300173630, 300173634], index=["NM_000015_2", "NM_130786_3", - "tRNA-Tyr.100009601.chr14"]) - assert(expected_data.equals(data)) + expected_data = pd.DataFrame( + data=d, columns=[300173630, 300173634], index=["NM_000015_2", "NM_130786_3", "tRNA-Tyr.100009601.chr14"] + ) + assert expected_data.equals(data) def test_create_sample_metadata(): - sample_metadata_records = [ - { - "structure_name": "Microvascular proliferation sampled by reference histology", - "structure_id": 309780906, - "structure_color": "ff330", - "block_id": 703397, - "polygon_id": 298726077, - "specimen_id": 297710593, - "tumor_name": "W1-1-2", - "rna_well_id": 300173634, - "structure_abbreviation": "CTmvp-reference-histology", - "block_name": "W1-1-2-D.2", - "tumor_id": 703393, - "specimen_name": "W1-1-2-D.2.01" - }, - { - "structure_name": "Cellular Tumor sampled by reference histology", - "structure_id": 309780592, - "structure_color": "5d04", - "block_id": 703397, - "polygon_id": 298727153, - "specimen_id": 297710593, - "tumor_name": "W1-1-2", - "rna_well_id": 300173630, - "structure_abbreviation": "CT-reference-histology", - "block_name": "W1-1-2-D.2", - "tumor_id": 703393, - "specimen_name": "W1-1-2-D.2.01" - } + { + "structure_name": "Microvascular proliferation sampled by reference histology", + "structure_id": 309780906, + "structure_color": "ff330", + "block_id": 703397, + "polygon_id": 298726077, + "specimen_id": 297710593, + "tumor_name": "W1-1-2", + "rna_well_id": 300173634, + "structure_abbreviation": "CTmvp-reference-histology", + "block_name": "W1-1-2-D.2", + "tumor_id": 703393, + "specimen_name": "W1-1-2-D.2.01", + }, + { + "structure_name": "Cellular Tumor sampled by reference histology", + "structure_id": 309780592, + "structure_color": "5d04", + "block_id": 703397, + "polygon_id": 298727153, + "specimen_id": 297710593, + "tumor_name": "W1-1-2", + "rna_well_id": 300173630, + "structure_abbreviation": "CT-reference-histology", + "block_name": "W1-1-2-D.2", + "tumor_id": 703393, + "specimen_name": "W1-1-2-D.2.01", + }, ] data = heatmap.create_sample_metadata(sample_metadata_records) - d = [[300173630, 703397, "W1-1-2-D.2", 298727153, 297710593, "W1-1-2-D.2.01", "CT-reference-histology", "5d04", - 309780592, "Cellular Tumor sampled by reference histology", 703393, "W1-1-2"], - [300173634, 703397, "W1-1-2-D.2", 298726077, 297710593, "W1-1-2-D.2.01", "CTmvp-reference-histology", "ff330", - 309780906, "Microvascular proliferation sampled by reference histology", 703393, "W1-1-2"]] - - expected_data = pd.DataFrame(data=d, columns=["rna_well_id", "block_id", "block_name", "polygon_id", "specimen_id", - "specimen_name", "structure_abbreviation", "structure_color", - "structure_id", "structure_name", "tumor_id", "tumor_name"]) + d = [ + [ + 300173630, + 703397, + "W1-1-2-D.2", + 298727153, + 297710593, + "W1-1-2-D.2.01", + "CT-reference-histology", + "5d04", + 309780592, + "Cellular Tumor sampled by reference histology", + 703393, + "W1-1-2", + ], + [ + 300173634, + 703397, + "W1-1-2-D.2", + 298726077, + 297710593, + "W1-1-2-D.2.01", + "CTmvp-reference-histology", + "ff330", + 309780906, + "Microvascular proliferation sampled by reference histology", + 703393, + "W1-1-2", + ], + ] + + expected_data = pd.DataFrame( + data=d, + columns=[ + "rna_well_id", + "block_id", + "block_name", + "polygon_id", + "specimen_id", + "specimen_name", + "structure_abbreviation", + "structure_color", + "structure_id", + "structure_name", + "tumor_id", + "tumor_name", + ], + ) pd.testing.assert_frame_equal(expected_data, data, check_like=True) diff --git a/allensdk/test/internal/morphology/test_apply_affine.py b/allensdk/test/internal/morphology/test_apply_affine.py index 666c9aa6cd..cb04a2b70f 100644 --- a/allensdk/test/internal/morphology/test_apply_affine.py +++ b/allensdk/test/internal/morphology/test_apply_affine.py @@ -8,24 +8,24 @@ def test_apply_affine(): node_list = [Node(0, 1, 0, 0, 0, 3, -1), Node(1, 2, 0, 0, 1, 1, 0)] morph = Morphology(node_list) - scale = [2, 0, 0, - 0, 2, 0, - 0, 0, 2] + scale = [2, 0, 0, 0, 2, 0, 0, 0, 2] translate = [1, 0, 0] affine = scale + translate morph.apply_affine(affine) # was at (0, 0, 1) with r = 1 - expected_node1 = {'id': 1, - 'type': 2, - 'x': 1, - 'y': 0, - 'z': 2, - 'radius': 2, - 'parent': 0, - 'children': [], - 'tree_id': 0, - 'compartment_id': 0} + expected_node1 = { + "id": 1, + "type": 2, + "x": 1, + "y": 0, + "z": 2, + "radius": 2, + "parent": 0, + "children": [], + "tree_id": 0, + "compartment_id": 0, + } obtained_node1 = morph.node_list[1] for key, value in expected_node1.items(): diff --git a/allensdk/test/internal/mouse_connectivity/test_interval_unionizer.py b/allensdk/test/internal/mouse_connectivity/test_interval_unionizer.py index b0774f6670..5ae2b11389 100644 --- a/allensdk/test/internal/mouse_connectivity/test_interval_unionizer.py +++ b/allensdk/test/internal/mouse_connectivity/test_interval_unionizer.py @@ -4,60 +4,51 @@ import pytest from unittest import mock -from allensdk.internal.mouse_connectivity.interval_unionize.interval_unionizer \ - import IntervalUnionizer +from allensdk.internal.mouse_connectivity.interval_unionize.interval_unionizer import IntervalUnionizer -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def annotation(): - annot = np.zeros((10, 10, 10)) annot[:, :, 4:] = 1 # 4 off, 6 on, ... annot[5:8, :, :] = 2 # solid block from 500:800 - - return annot - - + + return annot + + def test_init(): - iu = IntervalUnionizer([1, 2, 3]) - assert( np.allclose(iu.exclude_structure_ids, [1, 2, 3]) ) - + assert np.allclose(iu.exclude_structure_ids, [1, 2, 3]) + iu = IntervalUnionizer() - assert( sum(iu.exclude_structure_ids) == 0 ) - + assert sum(iu.exclude_structure_ids) == 0 -def test_setup_interval_map(annotation): +def test_setup_interval_map(annotation): bounds_exp = {1: (280, 700), 2: (700, 1000)} - + iu = IntervalUnionizer() iu.setup_interval_map(annotation) for k, v in iu.interval_map.items(): - assert( np.allclose(v, bounds_exp[k]) ) - - + assert np.allclose(v, bounds_exp[k]) + + def test_extract_data(): - iu = IntervalUnionizer() - + with pytest.raises(NotImplementedError): - iu.extract_data('olive', 'reticulated', 'western_woma') - - + iu.extract_data("olive", "reticulated", "western_woma") + + def test_propagate_record(): - with pytest.raises(NotImplementedError): - IntervalUnionizer.propagate_record('olive', 'reticulated') - - + IntervalUnionizer.propagate_record("olive", "reticulated") + + def test_propagate_unionizes(): + uns = {1: {"a": 1, "b": 2}, 2: {"a": 2, "b": 3}, 3: {"a": 3, "b": 4}} - uns = {1: {'a': 1, 'b': 2}, - 2: {'a': 2, 'b': 3}, - 3: {'a': 3, 'b': 4}} - amap = {1: [1, 2], 2: [2], 3: [3]} def dummy_prop(cls, c, a): @@ -66,54 +57,48 @@ def dummy_prop(cls, c, a): IntervalUnionizer.propagate_record = classmethod(dummy_prop) ou = IntervalUnionizer.propagate_unionizes(uns, amap) - assert( ou[3]['a'] == 3 ) - assert( ou[2]['b'] == 5 ) - assert( ou[1]['a'] == 1 ) - - + assert ou[3]["a"] == 3 + assert ou[2]["b"] == 5 + assert ou[1]["a"] == 1 + + def test_postprocess_unionizes(): - iu = IntervalUnionizer() with pytest.raises(NotImplementedError): - iu.postprocess_unionizes('foo') - - -def test_sort_data_arrays(): + iu.postprocess_unionizes("foo") + +def test_sort_data_arrays(): data_arrays = {1: np.arange(10), 2: np.arange(10, 20)} sort = np.array([3, 1, 5, 2, 6, 4, 7, 8, 9, 0]) - + iu = IntervalUnionizer() iu.sort = sort - + obt = iu.sort_data_arrays(data_arrays) - - assert( np.allclose(obt[1], sort) ) - assert( np.allclose(obt[2], 10 + sort ) ) - - -def test_direct_unionize(): - data = {'savu': np.arange(1000)} - im = {1: (280, 700), 2: (700, 1000)} + assert np.allclose(obt[1], sort) + assert np.allclose(obt[2], 10 + sort) - with mock.patch('allensdk.internal.mouse_connectivity.interval_unionize.' - 'interval_unionizer.IntervalUnionizer.sort_data_arrays', - new=lambda s, x: x): +def test_direct_unionize(): + data = {"savu": np.arange(1000)} + im = {1: (280, 700), 2: (700, 1000)} + + with mock.patch( + "allensdk.internal.mouse_connectivity.interval_unionize.interval_unionizer.IntervalUnionizer.sort_data_arrays", + new=lambda s, x: x, + ): class IU(IntervalUnionizer): def extract_data(self, d, l, h, **k): - return d['savu'][l:h].sum() + return d["savu"][l:h].sum() iu = IU() - + iu.interval_map = im - + obt = iu.direct_unionize(data) - - assert( obt[1] == np.arange(280, 700).sum() ) - assert( obt[2] == np.arange(700, 1000).sum() ) - - + assert obt[1] == np.arange(280, 700).sum() + assert obt[2] == np.arange(700, 1000).sum() diff --git a/allensdk/test/internal/mouse_connectivity/test_projection_thumbnail/test_projection_functions.py b/allensdk/test/internal/mouse_connectivity/test_projection_thumbnail/test_projection_functions.py index aba85b6cfb..37d4255a77 100644 --- a/allensdk/test/internal/mouse_connectivity/test_projection_thumbnail/test_projection_functions.py +++ b/allensdk/test/internal/mouse_connectivity/test_projection_thumbnail/test_projection_functions.py @@ -12,25 +12,21 @@ def example_volume(): return sitk.GetImageFromArray(array) -def test_convert_axis(): # :) - +def test_convert_axis(): # :) obt = prf.convert_axis(2) - assert(obt == 0) + assert obt == 0 def test_max_projection(example_volume): - max_obt, depth_obt = prf.max_projection(example_volume, 2) depth_exp = np.zeros([2, 2]) + 1 max_exp = np.array([[4, 5], [6, 7]]) - assert(np.allclose(depth_exp, depth_obt)) - assert(np.allclose(max_exp, max_obt)) + assert np.allclose(depth_exp, depth_obt) + assert np.allclose(max_exp, max_obt) def test_template_projection(example_volume): - prf.template_projection(example_volume, 2, 1, 1) np.array([[16, 25], [20, 5]]) - diff --git a/allensdk/test/internal/mouse_connectivity/test_projection_thumbnail/test_visualization_utilities.py b/allensdk/test/internal/mouse_connectivity/test_projection_thumbnail/test_visualization_utilities.py index 8c8126ee2a..19729e5b24 100644 --- a/allensdk/test/internal/mouse_connectivity/test_projection_thumbnail/test_visualization_utilities.py +++ b/allensdk/test/internal/mouse_connectivity/test_projection_thumbnail/test_visualization_utilities.py @@ -9,7 +9,7 @@ @pytest.fixture def example_volume(): - arr = np.arange(5*6*7, dtype=np.float64).reshape([5, 6, 7]) + arr = np.arange(5 * 6 * 7, dtype=np.float64).reshape([5, 6, 7]) return sitk.GetImageFromArray(arr) @@ -29,32 +29,29 @@ def test_convert_discrete_colormap(discrete_cmap): def test_sitk_safe_ln(example_volume): - obt = sitk.GetArrayFromImage(vis.sitk_safe_ln(example_volume)) arr = sitk.GetArrayFromImage(example_volume) - + arr = np.log(arr) arr[0, 0, 0] = np.log(10**-10) - + print(obt) print(arr) - assert(np.allclose(arr, obt)) + assert np.allclose(arr, obt) def test_normalize_intensity(example_volume): - obt = vis.normalize_intensity(example_volume, 2, 4, 50, 100) obt = sitk.GetArrayFromImage(obt) - assert(75 == obt[0, 0, 3]) + assert 75 == obt[0, 0, 3] def test_blend(): - images = [np.eye(2), np.array([[1, 2], [3, 4]])] weights = [np.fliplr(np.eye(2)), [[0, 0], [0, 1]]] exp = [[1, 0], [0, 4]] obt = vis.blend(images, weights) - assert(np.allclose(exp, obt)) + assert np.allclose(exp, obt) diff --git a/allensdk/test/internal/mouse_connectivity/test_projection_thumbnail/test_volume_projector.py b/allensdk/test/internal/mouse_connectivity/test_projection_thumbnail/test_volume_projector.py index dd5c85eeac..f90f1198c2 100644 --- a/allensdk/test/internal/mouse_connectivity/test_projection_thumbnail/test_volume_projector.py +++ b/allensdk/test/internal/mouse_connectivity/test_projection_thumbnail/test_volume_projector.py @@ -8,7 +8,7 @@ @pytest.fixture def simple_volume(): arr = np.arange(8 * 9 * 10, dtype=np.float64).reshape([8, 9, 10]) - return sitk.GetImageFromArray(arr) # swaps 0 <=> 2 axes + return sitk.GetImageFromArray(arr) # swaps 0 <=> 2 axes @pytest.fixture @@ -19,62 +19,57 @@ def cube_volume(): def test_init(simple_volume): vp = VolumeProjector(simple_volume) - assert(vp.view_volume.GetPixel(3, 5, 3) == simple_volume.GetPixel(3, 5, 3)) + assert vp.view_volume.GetPixel(3, 5, 3) == simple_volume.GetPixel(3, 5, 3) def test_build_rotation_transform(simple_volume): - vp = VolumeProjector(simple_volume) trans_obt = vp.build_rotation_transform(0, 2, np.pi / 2.0) - exp = [0, 0, -1, 0, 1, 0, 1, 0, 0] # just hstacked rows - assert(np.allclose(trans_obt.GetMatrix(), exp)) + exp = [0, 0, -1, 0, 1, 0, 1, 0, 0] # just hstacked rows + assert np.allclose(trans_obt.GetMatrix(), exp) -@pytest.mark.parametrize('angle,check', [(2*np.pi, [2, 3, 4]), (np.pi, [7, 3, 3])]) +@pytest.mark.parametrize("angle,check", [(2 * np.pi, [2, 3, 4]), (np.pi, [7, 3, 3])]) def test_rotate(simple_volume, angle, check): - vp = VolumeProjector(simple_volume) obt = vp.rotate(0, 2, angle) - assert(np.allclose(simple_volume.GetPixel(2, 3, 4), obt.GetPixel(*check))) - + assert np.allclose(simple_volume.GetPixel(2, 3, 4), obt.GetPixel(*check)) -def test_extract(): +def test_extract(): arr = np.eye(20) vp = VolumeProjector(arr) obt = vp.extract(np.sum) exp = 20 - assert(obt == exp) + assert obt == exp -@pytest.mark.parametrize('angle,exp', [(0.0, 0.0), (2 * np.pi, 0.0)]) +@pytest.mark.parametrize("angle,exp", [(0.0, 0.0), (2 * np.pi, 0.0)]) def test_rotate_and_extract(angle, exp, cube_volume): - def cb(x): return x.GetPixel(0, 0, 0) + vp = VolumeProjector(cube_volume) - + for obt in vp.rotate_and_extract([0], [2], [angle], cb): - assert(np.allclose(obt, exp)) + assert np.allclose(obt, exp) def test_fixed_factory(simple_volume): - shape = [5, 6, 7] vp = VolumeProjector.fixed_factory(simple_volume, shape) - + shape_obt = vp.view_volume.GetSize() - assert(np.allclose(shape, shape_obt)) + assert np.allclose(shape, shape_obt) def test_safe_factory(simple_volume): - vp = VolumeProjector.safe_factory(simple_volume) shape_exp = [16, 17, 16] - assert(np.allclose(shape_exp, vp.view_volume.GetSize())) - assert(vp.view_volume.GetPixel(8, 9, 8) == simple_volume.GetPixel(5, 5, 4)) + assert np.allclose(shape_exp, vp.view_volume.GetSize()) + assert vp.view_volume.GetPixel(8, 9, 8) == simple_volume.GetPixel(5, 5, 4) diff --git a/allensdk/test/internal/mouse_connectivity/test_projection_thumbnail/test_volume_utilities.py b/allensdk/test/internal/mouse_connectivity/test_projection_thumbnail/test_volume_utilities.py index 97b141b514..f3e086bf51 100644 --- a/allensdk/test/internal/mouse_connectivity/test_projection_thumbnail/test_volume_utilities.py +++ b/allensdk/test/internal/mouse_connectivity/test_projection_thumbnail/test_volume_utilities.py @@ -14,53 +14,47 @@ def empty_image(): @pytest.fixture def even_image(): arr = np.zeros([12, 14, 16]) - arr[6, 7, 8] = 1 - arr[6, 7, 8] = 1 + arr[6, 7, 8] = 1 + arr[6, 7, 8] = 1 img = sitk.GetImageFromArray(arr) return img def test_sitk_get_image_parameters(empty_image): - sp, sz, og = vol.sitk_get_image_parameters(empty_image) - - assert(np.allclose(sp, [1, 1, 1])) - assert(np.allclose(sz, [2, 3, 4])) - assert(np.allclose(og, [0, 0, 0])) + assert np.allclose(sp, [1, 1, 1]) + assert np.allclose(sz, [2, 3, 4]) + assert np.allclose(og, [0, 0, 0]) -def test_sitk_get_center_even(even_image): +def test_sitk_get_center_even(even_image): center = vol.sitk_get_center(even_image) - assert(np.allclose(center, [7.5, 6.5, 5.5])) + assert np.allclose(center, [7.5, 6.5, 5.5]) def test_sitk_get_center_odd(empty_image): obt = vol.sitk_get_center(empty_image) - assert(np.allclose(obt, [0.5, 1, 1.5])) + assert np.allclose(obt, [0.5, 1, 1.5]) -@pytest.mark.parametrize('size,exp', [([2, 3], [0, 1]), ([3, 3], [1, 1])]) +@pytest.mark.parametrize("size,exp", [([2, 3], [0, 1]), ([3, 3], [1, 1])]) def test_sitk_get_size_parity(size, exp): - image = sitk.Image(size[0], size[1], sitk.sitkUInt8) obt = vol.sitk_get_size_parity(image) - - assert(np.allclose(exp, obt)) + assert np.allclose(exp, obt) -@pytest.mark.parametrize('shape,exp', [([1, 1, 1], np.sqrt(3)), - ([2, 4], np.sqrt(20))]) -def test_sitk_get_diagonal_length(shape, exp): +@pytest.mark.parametrize("shape,exp", [([1, 1, 1], np.sqrt(3)), ([2, 4], np.sqrt(20))]) +def test_sitk_get_diagonal_length(shape, exp): img = sitk.GetImageFromArray(np.zeros(shape)) obt = vol.sitk_get_diagonal_length(img) - assert(exp == obt) + assert exp == obt def test_sitk_paste_into_center_even(): - smaller = sitk.GetImageFromArray(np.eye(2)) larger = sitk.GetImageFromArray(np.zeros((4, 4))) @@ -71,11 +65,10 @@ def test_sitk_paste_into_center_even(): exp[1, 1] = 1 exp[2, 2] = 1 - assert(np.allclose(obt, exp)) + assert np.allclose(obt, exp) def test_sitk_paste_into_center_odd(): - smaller = sitk.GetImageFromArray(np.eye(3, 3)) larger = sitk.GetImageFromArray(np.zeros((5, 5))) @@ -89,4 +82,4 @@ def test_sitk_paste_into_center_odd(): exp[2, 2] = 1 exp[3, 3] = 1 - assert(np.allclose(exp, obt)) + assert np.allclose(exp, obt) diff --git a/allensdk/test/internal/mouse_connectivity/test_tissuecyte_unionize_record.py b/allensdk/test/internal/mouse_connectivity/test_tissuecyte_unionize_record.py index a4fd240678..fcc9a7efe6 100644 --- a/allensdk/test/internal/mouse_connectivity/test_tissuecyte_unionize_record.py +++ b/allensdk/test/internal/mouse_connectivity/test_tissuecyte_unionize_record.py @@ -4,86 +4,84 @@ import pytest from unittest import mock -from allensdk.internal.mouse_connectivity.interval_unionize\ - .tissuecyte_unionize_record import TissuecyteBaseUnionize, \ - TissuecyteInjectionUnionize,TissuecyteProjectionUnionize - - -@pytest.fixture(scope='function') -def data_arrays(): +from allensdk.internal.mouse_connectivity.interval_unionize.tissuecyte_unionize_record import ( + TissuecyteBaseUnionize, + TissuecyteInjectionUnionize, + TissuecyteProjectionUnionize, +) + +@pytest.fixture(scope="function") +def data_arrays(): fq = np.ones(100) fq[25:] = 0 - + lq = np.ones(100) lq[:75] = 0 top = np.ones(100) top[:70] = 0 - return {'injection_fraction': fq, - 'aav_exclusion_fraction': top, - 'projection_density': np.arange(100), - 'projection_energy': np.arange(100) * 2, - 'injection_density': np.multiply(np.arange(100), fq), - 'injection_energy': np.multiply(np.arange(100) * 2, fq), - 'sum_pixels': np.ones(100) * 900, - 'sum_pixel_intensities': np.ones(100), - 'injection_sum_pixel_intensities': np.ones(100)[:50] - } - - -def test_base_init(): + return { + "injection_fraction": fq, + "aav_exclusion_fraction": top, + "projection_density": np.arange(100), + "projection_energy": np.arange(100) * 2, + "injection_density": np.multiply(np.arange(100), fq), + "injection_energy": np.multiply(np.arange(100) * 2, fq), + "sum_pixels": np.ones(100) * 900, + "sum_pixel_intensities": np.ones(100), + "injection_sum_pixel_intensities": np.ones(100)[:50], + } + +def test_base_init(): tbu = TissuecyteBaseUnionize() for item in TissuecyteBaseUnionize.__slots__: - assert( getattr(tbu, item) == 0 ) - + assert getattr(tbu, item) == 0 + -@pytest.mark.parametrize('anc_mvd', [0, 1]) +@pytest.mark.parametrize("anc_mvd", [0, 1]) def test_base_propagate(anc_mvd): - an = TissuecyteBaseUnionize() an.sum_pixels = 12 an.max_voxel_index = 100 an.max_voxel_density = anc_mvd - + ch = TissuecyteBaseUnionize() ch.sum_pixels = 5 ch.max_voxel_index = 50 ch.max_voxel_density = 0.5 - + ch.propagate(an) - - assert( an.sum_pixels == 17 ) - + + assert an.sum_pixels == 17 + if an.max_voxel_density == 1: - assert( an.max_voxel_index == 100 ) + assert an.max_voxel_index == 100 else: - assert( an.max_voxel_index == 50 ) - - -@pytest.mark.parametrize('spp', [0, 1]) -def test_base_set_max_voxel(spp): + assert an.max_voxel_index == 50 + - darr = np.arange(25) / 24 +@pytest.mark.parametrize("spp", [0, 1]) +def test_base_set_max_voxel(spp): + darr = np.arange(25) / 24 darr[15:] = 0 low = 12 tbu = TissuecyteBaseUnionize() tbu.sum_projection_pixels = spp tbu.set_max_voxel(darr, low) - + if spp == 1: - assert( tbu.max_voxel_index == 26 ) - assert( tbu.max_voxel_density == 14 / 24 ) + assert tbu.max_voxel_index == 26 + assert tbu.max_voxel_density == 14 / 24 else: - assert( tbu.max_voxel_index == 0 ) - assert( tbu.max_voxel_density == 0 ) - - -def test_base_slice_arrays(): + assert tbu.max_voxel_index == 0 + assert tbu.max_voxel_density == 0 + +def test_base_slice_arrays(): arrays = {ii: np.arange(10) + ii for ii in range(20)} low = 5 high = 8 @@ -92,77 +90,73 @@ def test_base_slice_arrays(): sl = tbu.slice_arrays(low, high, arrays) for k, v in sl.items(): - assert( len(v) == 3 ) - assert( v.sum() == k * 3 + 18 ) - - -@pytest.mark.parametrize('sum_pixels,sum_projection_pixels', [(0, 2), (0, 2)]) -def test_base_output(sum_pixels, sum_projection_pixels): + assert len(v) == 3 + assert v.sum() == k * 3 + 18 + +@pytest.mark.parametrize("sum_pixels,sum_projection_pixels", [(0, 2), (0, 2)]) +def test_base_output(sum_pixels, sum_projection_pixels): tbu = TissuecyteBaseUnionize() - + tbu.sum_pixels = sum_pixels tbu.direct_sum_projection_pixels = sum_projection_pixels / 2 tbu.sum_projection_pixels = sum_projection_pixels tbu.sum_projection_pixel_intensity = 100 - + tbu.max_voxel_index = 999 tbu.max_voxel_density = 1 - + out = tbu.output(10, 900, (10, 10, 10), np.arange(1000)) - - assert( out['volume'] == sum_pixels * 900 ) - assert( out['direct_projection_volume'] == sum_projection_pixels * 450 ) - assert( out['projection_volume'] == sum_projection_pixels * 900 ) - + + assert out["volume"] == sum_pixels * 900 + assert out["direct_projection_volume"] == sum_projection_pixels * 450 + assert out["projection_volume"] == sum_projection_pixels * 900 + if sum_pixels > 0: - assert( out['projection_density'] == sum_projection_pixels / sum_pixels ) + assert out["projection_density"] == sum_projection_pixels / sum_pixels else: - assert( out['projection_density'] == 0 ) - + assert out["projection_density"] == 0 + if sum_pixels > 0: - assert( out['projection_energy'] == 100 / sum_pixels ) + assert out["projection_energy"] == 100 / sum_pixels else: - assert( out['projection_energy'] == 0 ) - + assert out["projection_energy"] == 0 + if sum_projection_pixels > 0: - assert( out['projection_intensity'] == 100 / sum_projection_pixels ) + assert out["projection_intensity"] == 100 / sum_projection_pixels else: - assert( out['projection_intensity'] == 0 ) - - assert( out['max_voxel_x'] == 90 ) - assert( out['max_voxel_y'] == 90 ) - assert( out['max_voxel_z'] == 90 ) - - -def test_injection_calculate(data_arrays): + assert out["projection_intensity"] == 0 + + assert out["max_voxel_x"] == 90 + assert out["max_voxel_y"] == 90 + assert out["max_voxel_z"] == 90 + +def test_injection_calculate(data_arrays): tiu = TissuecyteInjectionUnionize() - + tiu.calculate(20, 80, data_arrays) - - assert( tiu.sum_pixels == 4500 ) - assert( tiu.sum_projection_pixels == 900 * np.arange(20, 25).sum() ) - assert( tiu.sum_projection_pixel_intensity == 1800 * np.arange(20, 25).sum() ) - - assert( tiu.max_voxel_index == 24 ) - assert( tiu.max_voxel_density == 24 ) - - -def test_projection_calculate(data_arrays): + assert tiu.sum_pixels == 4500 + assert tiu.sum_projection_pixels == 900 * np.arange(20, 25).sum() + assert tiu.sum_projection_pixel_intensity == 1800 * np.arange(20, 25).sum() + + assert tiu.max_voxel_index == 24 + assert tiu.max_voxel_density == 24 + + +def test_projection_calculate(data_arrays): tiu = mock.MagicMock() tiu.sum_pixels = 1 tiu.sum_projection_pixels = 2 tiu.sum_projection_pixel_intensity = 3 - + tpu = TissuecyteProjectionUnionize() tpu.calculate(20, 80, data_arrays, tiu) - - assert( tpu.sum_pixels == 900 * 50 - 1 ) - assert( tpu.sum_projection_pixels == 900 * np.arange(20, 70).sum() - 2 ) - assert( tpu.sum_projection_pixel_intensity == 1800 * np.arange(20, 70).sum() - 3 ) - - assert( tpu.max_voxel_index == 69 ) - assert( tpu.max_voxel_density == 69 ) + assert tpu.sum_pixels == 900 * 50 - 1 + assert tpu.sum_projection_pixels == 900 * np.arange(20, 70).sum() - 2 + assert tpu.sum_projection_pixel_intensity == 1800 * np.arange(20, 70).sum() - 3 + + assert tpu.max_voxel_index == 69 + assert tpu.max_voxel_density == 69 diff --git a/allensdk/test/internal/mouse_connectivity/test_unionize_record.py b/allensdk/test/internal/mouse_connectivity/test_unionize_record.py index 5c4fe941a7..673d7dc5d7 100644 --- a/allensdk/test/internal/mouse_connectivity/test_unionize_record.py +++ b/allensdk/test/internal/mouse_connectivity/test_unionize_record.py @@ -1,14 +1,11 @@ - - import pytest from allensdk.internal.mouse_connectivity.interval_unionize.unionize_record import Unionize -@pytest.mark.parametrize('method', ['__init__', 'calculate', 'propagate', 'output']) +@pytest.mark.parametrize("method", ["__init__", "calculate", "propagate", "output"]) def test_unionize(method): - un = object.__new__(Unionize) - + with pytest.raises(NotImplementedError): - getattr(un, method)('foo', 'fish') + getattr(un, method)("foo", "fish") diff --git a/allensdk/test/internal/test_annotated_region_metrics.py b/allensdk/test/internal/test_annotated_region_metrics.py index a23fbeee1f..c79240b6b9 100644 --- a/allensdk/test/internal/test_annotated_region_metrics.py +++ b/allensdk/test/internal/test_annotated_region_metrics.py @@ -2,9 +2,10 @@ import numpy as np from allensdk.internal.brain_observatory import annotated_region_metrics + @pytest.fixture def mask(): - return np.ones((10,10), dtype=bool) + return np.ones((10, 10), dtype=bool) def retinotopic_map(return_x=True): @@ -27,41 +28,34 @@ def altitude_map(): def test_eccentricity(azimuth_map, altitude_map): - ecc = annotated_region_metrics.eccentricity(azimuth_map, altitude_map, - 0.0, 0.0) - assert(ecc.shape == azimuth_map.shape) + ecc = annotated_region_metrics.eccentricity(azimuth_map, altitude_map, 0.0, 0.0) + assert ecc.shape == azimuth_map.shape def test_create_region_mask(mask): height, width = mask.shape x = y = 30 - region_mask = annotated_region_metrics.create_region_mask((100,100), - x, y, width, - height, - mask.tolist()) - assert(region_mask.shape == (100,100)) - assert(region_mask.sum() == mask.sum()) - assert(np.all(region_mask[y:y+height,x:x+width] == mask)) + region_mask = annotated_region_metrics.create_region_mask((100, 100), x, y, width, height, mask.tolist()) + assert region_mask.shape == (100, 100) + assert region_mask.sum() == mask.sum() + assert np.all(region_mask[y : y + height, x : x + width] == mask) def test_retinotopy_metric(azimuth_map, mask): height, width = mask.shape x = y = 30 - region_mask = annotated_region_metrics.create_region_mask( - azimuth_map.shape, x, y, width, height, mask.tolist()) - rmin, rmax, rrange, rbias = annotated_region_metrics.retinotopy_metric( - region_mask, azimuth_map) + region_mask = annotated_region_metrics.create_region_mask(azimuth_map.shape, x, y, width, height, mask.tolist()) + rmin, rmax, rrange, rbias = annotated_region_metrics.retinotopy_metric(region_mask, azimuth_map) rmap = np.degrees(azimuth_map[np.where(region_mask > 0)]) - assert(rmin == rmap.min()) - assert(rmax == rmap.max()) + assert rmin == rmap.min() + assert rmax == rmap.max() def test_get_metrics(altitude_map, azimuth_map, mask): height, width = mask.shape x = y = 30 - result = annotated_region_metrics.get_metrics(altitude_map, azimuth_map, - x=x, y=y, width=width, - height=height, - mask=mask.tolist()) - assert(isinstance(result, dict)) - assert('azimuth_min' in result) + result = annotated_region_metrics.get_metrics( + altitude_map, azimuth_map, x=x, y=y, width=width, height=height, mask=mask.tolist() + ) + assert isinstance(result, dict) + assert "azimuth_min" in result diff --git a/allensdk/test/internal/test_biophysical_modules.py b/allensdk/test/internal/test_biophysical_modules.py index ee0cbeb901..e42c4dd65a 100644 --- a/allensdk/test/internal/test_biophysical_modules.py +++ b/allensdk/test/internal/test_biophysical_modules.py @@ -1,12 +1,11 @@ -from allensdk.internal.api.queries.biophysical_module_api \ - import BiophysicalModuleApi +from allensdk.internal.api.queries.biophysical_module_api import BiophysicalModuleApi import pytest from unittest.mock import patch @pytest.fixture def biophysical_api(): - bma = BiophysicalModuleApi('http://axon:3000') + bma = BiophysicalModuleApi("http://axon:3000") return bma @@ -15,16 +14,18 @@ def test_get_neuronal_model_runs(biophysical_api): neuronal_model_run_id = 464137111 with patch.object(biophysical_api, "json_msg_query") as mock_query: biophysical_api.get_neuronal_model_runs(neuronal_model_run_id) - expected = ("http://axon:3000/api/v2/data/query.json?q=model::Neuronal" - "ModelRun,rma::criteria,[id$in464137111],rma::include,well_" - "known_files(well_known_file_type),neuronal_model(well_known_" - "files(well_known_file_type),specimen(project,specimen_tags," - "ephys_roi_result(ephys_qc_criteria,well_known_files(well_" - "known_file_type)),neuron_reconstructions(well_known_files" - "(well_known_file_type)),ephys_sweeps(ephys_sweep_tags,ephys_" - "stimulus(ephys_stimulus_type))),neuronal_model_template" - "(neuronal_model_template_type,well_known_files(well_known_" - "file_type))),rma::options[num_rows$eq'all'][count$eqfalse]") + expected = ( + "http://axon:3000/api/v2/data/query.json?q=model::Neuronal" + "ModelRun,rma::criteria,[id$in464137111],rma::include,well_" + "known_files(well_known_file_type),neuronal_model(well_known_" + "files(well_known_file_type),specimen(project,specimen_tags," + "ephys_roi_result(ephys_qc_criteria,well_known_files(well_" + "known_file_type)),neuron_reconstructions(well_known_files" + "(well_known_file_type)),ephys_sweeps(ephys_sweep_tags,ephys_" + "stimulus(ephys_stimulus_type))),neuronal_model_template" + "(neuronal_model_template_type,well_known_files(well_known_" + "file_type))),rma::options[num_rows$eq'all'][count$eqfalse]" + ) mock_query.assert_called_once_with(expected) @@ -32,14 +33,16 @@ def test_get_neuronal_models(biophysical_api): neuronal_model_id = 329322394 with patch.object(biophysical_api, "json_msg_query") as mock_query: biophysical_api.get_neuronal_models(neuronal_model_id) - expected = ("http://axon:3000/api/v2/data/query.json?q=model::Neuronal" - "Model,rma::criteria,[id$in329322394],rma::include," - "well_known_files(well_known_file_type),specimen(project," - "specimen_tags,ephys_roi_result(ephys_qc_criteria,well_known_" - "files(well_known_file_type)),neuron_reconstructions(well_" - "known_files(well_known_file_type)),ephys_sweeps(ephys_sweep_" - "tags,ephys_stimulus(ephys_stimulus_type))),neuronal_model_" - "template(neuronal_model_template_type,well_known_files(well_" - "known_file_type)),rma::options[num_rows$eq'all'][count$" - "eqfalse]") + expected = ( + "http://axon:3000/api/v2/data/query.json?q=model::Neuronal" + "Model,rma::criteria,[id$in329322394],rma::include," + "well_known_files(well_known_file_type),specimen(project," + "specimen_tags,ephys_roi_result(ephys_qc_criteria,well_known_" + "files(well_known_file_type)),neuron_reconstructions(well_" + "known_files(well_known_file_type)),ephys_sweeps(ephys_sweep_" + "tags,ephys_stimulus(ephys_stimulus_type))),neuronal_model_" + "template(neuronal_model_template_type,well_known_files(well_" + "known_file_type)),rma::options[num_rows$eq'all'][count$" + "eqfalse]" + ) mock_query.assert_called_once_with(expected) diff --git a/allensdk/test/internal/test_core_feature_extract.py b/allensdk/test/internal/test_core_feature_extract.py index 48686b60dd..0913373f47 100644 --- a/allensdk/test/internal/test_core_feature_extract.py +++ b/allensdk/test/internal/test_core_feature_extract.py @@ -1,17 +1,20 @@ -from allensdk.internal.ephys.core_feature_extract import ( find_stim_start, - filter_sweeps, - find_coarse_long_square_amp_delta, - nan_get ) +from allensdk.internal.ephys.core_feature_extract import ( + find_stim_start, + filter_sweeps, + find_coarse_long_square_amp_delta, + nan_get, +) + def test_find_stim_start(): - a = [0,0,0,1,1,1,0,0,0] + a = [0, 0, 0, 1, 1, 1, 0, 0, 0] idx = find_stim_start(a) assert idx == 3 idx = find_stim_start(a, 1) assert idx == 3 - a = [0,0,0,-1,-1,-1,0,0,0] + a = [0, 0, 0, -1, -1, -1, 0, 0, 0] idx = find_stim_start(a) assert idx == 3 @@ -19,7 +22,7 @@ def test_find_stim_start(): idx = find_stim_start(a) assert idx == -1 - a = [0,0,0] + a = [0, 0, 0] idx = find_stim_start(a) assert idx == -1 @@ -27,79 +30,190 @@ def test_find_stim_start(): idx = find_stim_start(a) assert idx == -1 + def test_filter_sweeps(): - a = [ { 'sweep_number': 1 }, { 'sweep_number': 0 } ] + a = [{"sweep_number": 1}, {"sweep_number": 0}] sweeps = filter_sweeps(a, passed_only=False, iclamp_only=False) assert len(sweeps) == 2 - assert [ s['sweep_number'] for s in sweeps ] == [ 0,1 ] - - a = [ { 'sweep_number': 1, 'workflow_state': 'auto_passed', 'stimulus_units': 'fish' }, - { 'sweep_number': 0, 'workflow_state': 'auto_failed', 'stimulus_units': 'pA' }, - { 'sweep_number': 2, 'workflow_state': 'manual_passed', 'stimulus_units': 'Amps' }, - { 'sweep_number': 3, 'workflow_state': 'manual_failed', 'stimulus_units': 'taco' } ] + assert [s["sweep_number"] for s in sweeps] == [0, 1] + + a = [ + {"sweep_number": 1, "workflow_state": "auto_passed", "stimulus_units": "fish"}, + {"sweep_number": 0, "workflow_state": "auto_failed", "stimulus_units": "pA"}, + {"sweep_number": 2, "workflow_state": "manual_passed", "stimulus_units": "Amps"}, + {"sweep_number": 3, "workflow_state": "manual_failed", "stimulus_units": "taco"}, + ] sweeps = filter_sweeps(a, passed_only=True, iclamp_only=False) assert len(sweeps) == 2 sweeps = filter_sweeps(a, passed_only=True, iclamp_only=True) assert len(sweeps) == 1 - a = [ { 'sweep_number': 1, 'ephys_stimulus': { 'description': 'T1x' } }, - { 'sweep_number': 0, 'ephys_stimulus': { 'description': 'T2x' } }, - { 'sweep_number': 2, 'ephys_stimulus': { 'description': 'T3x' } }, - { 'sweep_number': 3, 'ephys_stimulus': { 'description': 'T1x' } } ] + a = [ + {"sweep_number": 1, "ephys_stimulus": {"description": "T1x"}}, + {"sweep_number": 0, "ephys_stimulus": {"description": "T2x"}}, + {"sweep_number": 2, "ephys_stimulus": {"description": "T3x"}}, + {"sweep_number": 3, "ephys_stimulus": {"description": "T1x"}}, + ] - sweeps = filter_sweeps(a, passed_only=False, iclamp_only=False, types=['T1', 'T2']) + sweeps = filter_sweeps(a, passed_only=False, iclamp_only=False, types=["T1", "T2"]) assert len(sweeps) == 3 + def test_find_coarse_long_square_amp_delta(): - a = [ { 'stimulus_amplitude': 10, 'sweep_number': 1, 'ephys_stimulus': { 'description': 'C1LSCOARSE' }, 'workflow_state': 'auto_passed', 'stimulus_units': 'pA' }, - { 'stimulus_amplitude': 10, 'sweep_number': 0, 'ephys_stimulus': { 'description': 'C1LSFINE' }, 'workflow_state': 'auto_passed', 'stimulus_units': 'pA' }, - { 'stimulus_amplitude': 10, 'sweep_number': 2, 'ephys_stimulus': { 'description': 'C1LSCOARSE' }, 'workflow_state': 'auto_failed', 'stimulus_units': 'pA' }, - { 'stimulus_amplitude': 10, 'sweep_number': 3, 'ephys_stimulus': { 'description': 'C1LSCOARSE' }, 'workflow_state': 'auto_passed', 'stimulus_units': 'pA' } ] + a = [ + { + "stimulus_amplitude": 10, + "sweep_number": 1, + "ephys_stimulus": {"description": "C1LSCOARSE"}, + "workflow_state": "auto_passed", + "stimulus_units": "pA", + }, + { + "stimulus_amplitude": 10, + "sweep_number": 0, + "ephys_stimulus": {"description": "C1LSFINE"}, + "workflow_state": "auto_passed", + "stimulus_units": "pA", + }, + { + "stimulus_amplitude": 10, + "sweep_number": 2, + "ephys_stimulus": {"description": "C1LSCOARSE"}, + "workflow_state": "auto_failed", + "stimulus_units": "pA", + }, + { + "stimulus_amplitude": 10, + "sweep_number": 3, + "ephys_stimulus": {"description": "C1LSCOARSE"}, + "workflow_state": "auto_passed", + "stimulus_units": "pA", + }, + ] delta = find_coarse_long_square_amp_delta(a) assert delta == 0 - a = [ { 'stimulus_amplitude': 10, 'sweep_number': 1, 'ephys_stimulus': { 'description': 'C1LSCOARSE' }, 'workflow_state': 'auto_passed', 'stimulus_units': 'pA' }, - { 'stimulus_amplitude': 20, 'sweep_number': 0, 'ephys_stimulus': { 'description': 'C1LSFINE' }, 'workflow_state': 'auto_passed', 'stimulus_units': 'pA' }, - { 'stimulus_amplitude': 30, 'sweep_number': 2, 'ephys_stimulus': { 'description': 'C1LSCOARSE' }, 'workflow_state': 'auto_failed', 'stimulus_units': 'pA' }, - { 'stimulus_amplitude': 40, 'sweep_number': 3, 'ephys_stimulus': { 'description': 'C1LSCOARSE' }, 'workflow_state': 'auto_passed', 'stimulus_units': 'pA' } ] + a = [ + { + "stimulus_amplitude": 10, + "sweep_number": 1, + "ephys_stimulus": {"description": "C1LSCOARSE"}, + "workflow_state": "auto_passed", + "stimulus_units": "pA", + }, + { + "stimulus_amplitude": 20, + "sweep_number": 0, + "ephys_stimulus": {"description": "C1LSFINE"}, + "workflow_state": "auto_passed", + "stimulus_units": "pA", + }, + { + "stimulus_amplitude": 30, + "sweep_number": 2, + "ephys_stimulus": {"description": "C1LSCOARSE"}, + "workflow_state": "auto_failed", + "stimulus_units": "pA", + }, + { + "stimulus_amplitude": 40, + "sweep_number": 3, + "ephys_stimulus": {"description": "C1LSCOARSE"}, + "workflow_state": "auto_passed", + "stimulus_units": "pA", + }, + ] delta = find_coarse_long_square_amp_delta(a) assert delta == 10 - a = [ { 'stimulus_amplitude': 10, 'sweep_number': 1, 'ephys_stimulus': { 'description': 'C1LSCOARSE' }, 'workflow_state': 'auto_passed', 'stimulus_units': 'pA' }, - { 'stimulus_amplitude': 20, 'sweep_number': 0, 'ephys_stimulus': { 'description': 'C1LSFINE' }, 'workflow_state': 'auto_passed', 'stimulus_units': 'pA' }, - { 'stimulus_amplitude': 20, 'sweep_number': 2, 'ephys_stimulus': { 'description': 'C1LSCOARSE' }, 'workflow_state': 'auto_failed', 'stimulus_units': 'pA' }, - { 'stimulus_amplitude': 30, 'sweep_number': 3, 'ephys_stimulus': { 'description': 'C1LSCOARSE' }, 'workflow_state': 'auto_passed', 'stimulus_units': 'pA' } ] + a = [ + { + "stimulus_amplitude": 10, + "sweep_number": 1, + "ephys_stimulus": {"description": "C1LSCOARSE"}, + "workflow_state": "auto_passed", + "stimulus_units": "pA", + }, + { + "stimulus_amplitude": 20, + "sweep_number": 0, + "ephys_stimulus": {"description": "C1LSFINE"}, + "workflow_state": "auto_passed", + "stimulus_units": "pA", + }, + { + "stimulus_amplitude": 20, + "sweep_number": 2, + "ephys_stimulus": {"description": "C1LSCOARSE"}, + "workflow_state": "auto_failed", + "stimulus_units": "pA", + }, + { + "stimulus_amplitude": 30, + "sweep_number": 3, + "ephys_stimulus": {"description": "C1LSCOARSE"}, + "workflow_state": "auto_passed", + "stimulus_units": "pA", + }, + ] delta = find_coarse_long_square_amp_delta(a) assert delta == 10 - a = [ { 'stimulus_amplitude': 10, 'sweep_number': 0, 'ephys_stimulus': { 'description': 'C1LSCOARSE' }, 'workflow_state': 'auto_passed', 'stimulus_units': 'pA' }, - { 'stimulus_amplitude': 20, 'sweep_number': 1, 'ephys_stimulus': { 'description': 'C1LSCOARSE' }, 'workflow_state': 'auto_passed', 'stimulus_units': 'pA' }, - { 'stimulus_amplitude': 20, 'sweep_number': 2, 'ephys_stimulus': { 'description': 'C1LSCOARSE' }, 'workflow_state': 'auto_failed', 'stimulus_units': 'pA' }, - { 'stimulus_amplitude': 30, 'sweep_number': 3, 'ephys_stimulus': { 'description': 'C1LSCOARSE' }, 'workflow_state': 'auto_passed', 'stimulus_units': 'pA' }, - { 'stimulus_amplitude': 50, 'sweep_number': 5, 'ephys_stimulus': { 'description': 'C1LSCOARSE' }, 'workflow_state': 'auto_passed', 'stimulus_units': 'pA' } ] + a = [ + { + "stimulus_amplitude": 10, + "sweep_number": 0, + "ephys_stimulus": {"description": "C1LSCOARSE"}, + "workflow_state": "auto_passed", + "stimulus_units": "pA", + }, + { + "stimulus_amplitude": 20, + "sweep_number": 1, + "ephys_stimulus": {"description": "C1LSCOARSE"}, + "workflow_state": "auto_passed", + "stimulus_units": "pA", + }, + { + "stimulus_amplitude": 20, + "sweep_number": 2, + "ephys_stimulus": {"description": "C1LSCOARSE"}, + "workflow_state": "auto_failed", + "stimulus_units": "pA", + }, + { + "stimulus_amplitude": 30, + "sweep_number": 3, + "ephys_stimulus": {"description": "C1LSCOARSE"}, + "workflow_state": "auto_passed", + "stimulus_units": "pA", + }, + { + "stimulus_amplitude": 50, + "sweep_number": 5, + "ephys_stimulus": {"description": "C1LSCOARSE"}, + "workflow_state": "auto_passed", + "stimulus_units": "pA", + }, + ] delta = find_coarse_long_square_amp_delta(a) assert delta == 10 + def test_nan_get(): a = {} - v = nan_get(a, 'fish') + v = nan_get(a, "fish") assert v is None - a = { 'fish': 1 } - v = nan_get(a, 'fish') + a = {"fish": 1} + v = nan_get(a, "fish") assert v == 1 - a = { 'fish': float("nan") } - v = nan_get(a, 'fish') + a = {"fish": float("nan")} + v = nan_get(a, "fish") assert v is None - - - - - - diff --git a/allensdk/test/internal/test_eye_calibration.py b/allensdk/test/internal/test_eye_calibration.py index 2df6d2210f..88b21f7530 100644 --- a/allensdk/test/internal/test_eye_calibration.py +++ b/allensdk/test/internal/test_eye_calibration.py @@ -2,79 +2,88 @@ import numpy as np from allensdk.internal.brain_observatory import eye_calibration + def cr_params(): - x, y = np.meshgrid(np.array([300, 320, 340]), - np.array([220, 240, 260])) + x, y = np.meshgrid(np.array([300, 320, 340]), np.array([220, 240, 260])) return np.vstack((x.flatten(), y.flatten())).T + def pupil_params(): - x, y = np.meshgrid(np.array([280, 320, 380]), - np.array([200, 240, 280])) + x, y = np.meshgrid(np.array([280, 320, 380]), np.array([200, 240, 280])) return np.vstack((x.flatten(), y.flatten())).T -@pytest.mark.parametrize("led_position,eye_radius", [ - (np.array([25.89, -6.12, 3.21]), 0.1682), - (np.array([24.6, 9.23, 5.26]), 0.1682), - (np.array([20.0, 20.0, 20.0]), 500), -]) +@pytest.mark.parametrize( + "led_position,eye_radius", + [ + (np.array([25.89, -6.12, 3.21]), 0.1682), + (np.array([24.6, 9.23, 5.26]), 0.1682), + (np.array([20.0, 20.0, 20.0]), 500), + ], +) def test_cr_position_in_mouse_eye_coordinates(led_position, eye_radius): - cr = eye_calibration.EyeCalibration.cr_position_in_mouse_eye_coordinates( - led_position, eye_radius) + cr = eye_calibration.EyeCalibration.cr_position_in_mouse_eye_coordinates(led_position, eye_radius) tol = 0.000000001 - assert(np.abs(np.linalg.norm(cr) - 0.5*eye_radius) < tol) - err = np.abs(cr/np.linalg.norm(cr) - \ - led_position/np.linalg.norm(led_position)) - assert(np.all(err < tol)) + assert np.abs(np.linalg.norm(cr) - 0.5 * eye_radius) < tol + err = np.abs(cr / np.linalg.norm(cr) - led_position / np.linalg.norm(led_position)) + assert np.all(err < tol) -@pytest.mark.parametrize("led_position,camera_rotations", [ - (np.array([10.0, 0.0, 0.0]),np.array([0.0, 0.0, 0.0])), - (np.array([10.0, 0.0, 0.0]),np.array([0.0, 0.0, np.pi/4])) -]) -def test_pupil_position_in_mouse_eye_coordinates_right( - led_position, camera_rotations): - CM_PER_PIXEL = 10.2/10000 +@pytest.mark.parametrize( + "led_position,camera_rotations", + [ + (np.array([10.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0])), + (np.array([10.0, 0.0, 0.0]), np.array([0.0, 0.0, np.pi / 4])), + ], +) +def test_pupil_position_in_mouse_eye_coordinates_right(led_position, camera_rotations): + CM_PER_PIXEL = 10.2 / 10000 TOL = 0.0000000001 c = eye_calibration.EyeCalibration( - led_position=led_position, cm_per_pixel=CM_PER_PIXEL, - eye_radius=0.1682, camera_rotations=camera_rotations, - camera_position=np.array([13.0, 0.0, 0.0])) + led_position=led_position, + cm_per_pixel=CM_PER_PIXEL, + eye_radius=0.1682, + camera_rotations=camera_rotations, + camera_position=np.array([13.0, 0.0, 0.0]), + ) pupil = pupil_params() cr = cr_params() - bad = c.pupil_position_in_mouse_eye_coordinates(np.array([[1000, 0], [0, 1000]]), - np.array([[0, 0], [0, 0]])) - assert(np.all(np.isnan(bad))) + bad = c.pupil_position_in_mouse_eye_coordinates(np.array([[1000, 0], [0, 1000]]), np.array([[0, 0], [0, 0]])) + assert np.all(np.isnan(bad)) pos = c.pupil_position_in_mouse_eye_coordinates(pupil, cr) - x = (pupil.T[0] - cr.T[0])*CM_PER_PIXEL - y = (cr.T[1] - pupil.T[1])*CM_PER_PIXEL - xr = x*np.cos(-camera_rotations[2]) - y*np.sin(-camera_rotations[2]) - yr = x*np.sin(-camera_rotations[2]) + y*np.cos(-camera_rotations[2]) - assert(np.all(np.abs(pos.T[2] - (yr + c.cr[2])) < TOL)) - assert(np.all(np.abs(pos.T[1] - (xr + c.cr[1])) < TOL)) + x = (pupil.T[0] - cr.T[0]) * CM_PER_PIXEL + y = (cr.T[1] - pupil.T[1]) * CM_PER_PIXEL + xr = x * np.cos(-camera_rotations[2]) - y * np.sin(-camera_rotations[2]) + yr = x * np.sin(-camera_rotations[2]) + y * np.cos(-camera_rotations[2]) + assert np.all(np.abs(pos.T[2] - (yr + c.cr[2])) < TOL) + assert np.all(np.abs(pos.T[1] - (xr + c.cr[1])) < TOL) -@pytest.mark.parametrize("led_position,camera_rotations", [ - (np.array([10.0, 0.0, 0.0]),np.array([0.0, 0.0, 0.0])), - (np.array([10.0, 0.0, 0.0]),np.array([0.0, 0.0, np.pi/4])) -]) -def test_pupil_position_in_mouse_eye_coordinates_front( - led_position, camera_rotations): - CM_PER_PIXEL = 10.2/10000 +@pytest.mark.parametrize( + "led_position,camera_rotations", + [ + (np.array([10.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0])), + (np.array([10.0, 0.0, 0.0]), np.array([0.0, 0.0, np.pi / 4])), + ], +) +def test_pupil_position_in_mouse_eye_coordinates_front(led_position, camera_rotations): + CM_PER_PIXEL = 10.2 / 10000 TOL = 0.0000000001 c = eye_calibration.EyeCalibration( - led_position=led_position, cm_per_pixel=CM_PER_PIXEL, - eye_radius=0.1682, camera_rotations=camera_rotations, - camera_position=np.array([0.0, 13.0, 0.0])) + led_position=led_position, + cm_per_pixel=CM_PER_PIXEL, + eye_radius=0.1682, + camera_rotations=camera_rotations, + camera_position=np.array([0.0, 13.0, 0.0]), + ) pupil = pupil_params() cr = cr_params() - bad = c.pupil_position_in_mouse_eye_coordinates(np.array([[1000, 0], [0, 1000]]), - np.array([[0, 0], [0, 0]])) - assert(np.all(np.isnan(bad))) + bad = c.pupil_position_in_mouse_eye_coordinates(np.array([[1000, 0], [0, 1000]]), np.array([[0, 0], [0, 0]])) + assert np.all(np.isnan(bad)) pos = c.pupil_position_in_mouse_eye_coordinates(pupil, cr) - x = (pupil.T[0] - cr.T[0])*CM_PER_PIXEL - y = (cr.T[1] - pupil.T[1])*CM_PER_PIXEL - xr = x*np.cos(-camera_rotations[2]) - y*np.sin(-camera_rotations[2]) - yr = x*np.sin(-camera_rotations[2]) + y*np.cos(-camera_rotations[2]) - assert(np.all(np.abs(pos.T[2] - (yr + c.cr[2])) < TOL)) - assert(np.all(np.abs(pos.T[0] + (xr - c.cr[0])) < TOL)) + x = (pupil.T[0] - cr.T[0]) * CM_PER_PIXEL + y = (cr.T[1] - pupil.T[1]) * CM_PER_PIXEL + xr = x * np.cos(-camera_rotations[2]) - y * np.sin(-camera_rotations[2]) + yr = x * np.sin(-camera_rotations[2]) + y * np.cos(-camera_rotations[2]) + assert np.all(np.abs(pos.T[2] - (yr + c.cr[2])) < TOL) + assert np.all(np.abs(pos.T[0] + (xr - c.cr[0])) < TOL) diff --git a/allensdk/test/internal/test_internal.py b/allensdk/test/internal/test_internal.py index 7008a6e14f..d3b2d6b3b2 100644 --- a/allensdk/test/internal/test_internal.py +++ b/allensdk/test/internal/test_internal.py @@ -7,6 +7,7 @@ Tests for `internal` module. """ + import pytest @@ -16,10 +17,6 @@ def decorated_example(): See more at: http://doc.pytest.org/en/latest/fixture.html """ -def test_example(decorated_example): - """Sample pytest test function with the pytest fixture as an argument. - """ - - - +def test_example(decorated_example): + """Sample pytest test function with the pytest fixture as an argument.""" diff --git a/allensdk/test/internal/test_mtrain_api.py b/allensdk/test/internal/test_mtrain_api.py index 9e10283522..a83f0e827e 100644 --- a/allensdk/test/internal/test_mtrain_api.py +++ b/allensdk/test/internal/test_mtrain_api.py @@ -4,20 +4,26 @@ @pytest.mark.nightly -@pytest.mark.parametrize('api', [ - pytest.param(MtrainApi()), - pytest.param(MtrainSqlApi()), -]) +@pytest.mark.parametrize( + "api", + [ + pytest.param(MtrainApi()), + pytest.param(MtrainSqlApi()), + ], +) def test_get_subjects(api): subject_list = api.get_subjects() assert len(subject_list) > 190 and 423746 in subject_list @pytest.mark.nightly -@pytest.mark.parametrize('api', [ - pytest.param(MtrainApi()), - pytest.param(MtrainSqlApi()), -]) +@pytest.mark.parametrize( + "api", + [ + pytest.param(MtrainApi()), + pytest.param(MtrainSqlApi()), + ], +) def test_get_behavior_training_df(api): LabTracks_ID = 423986 df = api.get_behavior_training_df(LabTracks_ID) @@ -27,91 +33,83 @@ def test_get_behavior_training_df(api): @pytest.mark.nightly -@pytest.mark.parametrize('LabTracks_ID', [ - pytest.param(423986), -]) +@pytest.mark.parametrize( + "LabTracks_ID", + [ + pytest.param(423986), + ], +) def test_get_current_stage(LabTracks_ID): api = MtrainApi() stage = api.get_current_stage(LabTracks_ID) - assert stage == 'OPHYS_6_images_B' + assert stage == "OPHYS_6_images_B" @pytest.mark.nightly -@pytest.mark.parametrize('behavior_session_uuid, behavior_session_id', - [pytest.param('394a910e-94c7-4472-9838-5345aff59ed8', - None), - pytest.param(None, 823847007), - pytest.param('394a910e-94c7-4472-9838-5345aff59ed8', - 823847007), - ]) +@pytest.mark.parametrize( + "behavior_session_uuid, behavior_session_id", + [ + pytest.param("394a910e-94c7-4472-9838-5345aff59ed8", None), + pytest.param(None, 823847007), + pytest.param("394a910e-94c7-4472-9838-5345aff59ed8", 823847007), + ], +) def test_get_session(behavior_session_uuid, behavior_session_id): api = MtrainApi() - kwargs = {key: val for key, val in - [('behavior_session_uuid', behavior_session_uuid), - ('behavior_session_id', behavior_session_id)] if - val is not None} + kwargs = { + key: val + for key, val in [("behavior_session_uuid", behavior_session_uuid), ("behavior_session_id", behavior_session_id)] + if val is not None + } session_dict = api.get_session(**kwargs) - trials_df = session_dict.pop('trials') + trials_df = session_dict.pop("trials") assert len(trials_df) == 576 assert "stages" in session_dict.keys() # Remove stages because it's # very long del session_dict["stages"] - assert session_dict == {u'name': u'TRAINING_1_gratings', - u'parameters': {u'auto_reward_delay': 0.15, - u'change_time_scale': 2.0, - u'end_after_response': True, - u'change_flashes_max': None, - u'change_time_dist': - u'exponential', - u'stimulus_window': 6.0, - u'response_window': [0.15, 1.0], - u'change_flashes_min': None, - u'catch_frequency': 0.25, - u'min_no_lick_time': 0.0, - u'timeout_duration': 0.3, - u'free_reward_trials': 10, - u'volume_limit': 5.0, - u'max_task_duration_min': 60.0, - u'reward_volume': 0.01, - u'end_after_response_sec': 3.5, - u'start_stop_padding': 20.0, - u'periodic_flash': None, - u'stage': u'TRAINING_1_gratings', - u'auto_reward_vol': 0.005, - u'task_id': u'DoC', - u'stimulus': { - u'params': {u'phase': 0.25, - u'tex': u'sqr', - u'units': u'deg', - u'sf': 0.04, - u'size': [200, - 150]}, - u'class': u'grating', - u'groups': {u'horizontal': { - u'Ori': [90, 270]}, - u'vertical': { - u'Ori': [0, - 180]}}}, - u'failure_repeats': 5, - u'warm_up_trials': 5, - u'pre_change_time': 2.25}, - u'script': - u'http://stash.corp.alleninstitute.org/' - u'projects/VB/repos/visual_behavior_scripts/' - u'raw/change_detection_with_fingerprint.py?at=' - u'021ec55fbbdbb05aad1681c016e83066fe5aa1dd', - 'behavior_session_uuid': - u'394a910e-94c7-4472-9838-5345aff59ed8', - u'script_md5': u'e0535f3b6f03ccc8eeccaed2118f3c1d', - u'LabTracks_ID': 431151, - u'date': - u'2019-02-15T13:01:23.672000', - 'regimen_name': u'VisualBehavior_Task1A_v1.0.1', - u'default_x': False, - u'regimens': [{u'active': False, - u'default': False, - u'id': 14, - u'name': - u'VisualBehavior_Task1A_v1.0.1' - }], - u'default_y': False} + assert session_dict == { + "name": "TRAINING_1_gratings", + "parameters": { + "auto_reward_delay": 0.15, + "change_time_scale": 2.0, + "end_after_response": True, + "change_flashes_max": None, + "change_time_dist": "exponential", + "stimulus_window": 6.0, + "response_window": [0.15, 1.0], + "change_flashes_min": None, + "catch_frequency": 0.25, + "min_no_lick_time": 0.0, + "timeout_duration": 0.3, + "free_reward_trials": 10, + "volume_limit": 5.0, + "max_task_duration_min": 60.0, + "reward_volume": 0.01, + "end_after_response_sec": 3.5, + "start_stop_padding": 20.0, + "periodic_flash": None, + "stage": "TRAINING_1_gratings", + "auto_reward_vol": 0.005, + "task_id": "DoC", + "stimulus": { + "params": {"phase": 0.25, "tex": "sqr", "units": "deg", "sf": 0.04, "size": [200, 150]}, + "class": "grating", + "groups": {"horizontal": {"Ori": [90, 270]}, "vertical": {"Ori": [0, 180]}}, + }, + "failure_repeats": 5, + "warm_up_trials": 5, + "pre_change_time": 2.25, + }, + "script": "http://stash.corp.alleninstitute.org/" + "projects/VB/repos/visual_behavior_scripts/" + "raw/change_detection_with_fingerprint.py?at=" + "021ec55fbbdbb05aad1681c016e83066fe5aa1dd", + "behavior_session_uuid": "394a910e-94c7-4472-9838-5345aff59ed8", + "script_md5": "e0535f3b6f03ccc8eeccaed2118f3c1d", + "LabTracks_ID": 431151, + "date": "2019-02-15T13:01:23.672000", + "regimen_name": "VisualBehavior_Task1A_v1.0.1", + "default_x": False, + "regimens": [{"active": False, "default": False, "id": 14, "name": "VisualBehavior_Task1A_v1.0.1"}], + "default_y": False, + } diff --git a/allensdk/test/internal/test_optimize_config_reader.py b/allensdk/test/internal/test_optimize_config_reader.py index 957e94f29c..853922405c 100644 --- a/allensdk/test/internal/test_optimize_config_reader.py +++ b/allensdk/test/internal/test_optimize_config_reader.py @@ -1,7 +1,7 @@ -from allensdk.internal.api.queries.optimize_config_reader import \ - OptimizeConfigReader +from allensdk.internal.api.queries.optimize_config_reader import OptimizeConfigReader import pytest from unittest.mock import patch, mock_open + try: import __builtin__ as builtins except Exception: @@ -87,12 +87,11 @@ def no_param_config(): ocr = OptimizeConfigReader() - lims_json_path = 'lims_message.json' - - with patch(builtins.__name__ + ".open", - mock_open(read_data=LIMS_MESSAGE_NO_PARAM_FILES)): + lims_json_path = "lims_message.json" + + with patch(builtins.__name__ + ".open", mock_open(read_data=LIMS_MESSAGE_NO_PARAM_FILES)): ocr.read_lims_file(lims_json_path) - + return ocr @@ -100,12 +99,11 @@ def no_param_config(): def one_param_config(): ocr = OptimizeConfigReader() - lims_json_path = 'lims_message.json' - - with patch(builtins.__name__ + ".open", - mock_open(read_data=LIMS_MESSAGE_ONE_PARAM_FILE)): - ocr.read_lims_file(lims_json_path) - + lims_json_path = "lims_message.json" + + with patch(builtins.__name__ + ".open", mock_open(read_data=LIMS_MESSAGE_ONE_PARAM_FILE)): + ocr.read_lims_file(lims_json_path) + return ocr @@ -113,34 +111,30 @@ def one_param_config(): def two_param_config(): ocr = OptimizeConfigReader() - lims_json_path = 'lims_message.json' - - with patch(builtins.__name__ + ".open", - mock_open(read_data=LIMS_MESSAGE_TWO_PARAM_FILES)): + lims_json_path = "lims_message.json" + + with patch(builtins.__name__ + ".open", mock_open(read_data=LIMS_MESSAGE_TWO_PARAM_FILES)): ocr.read_lims_file(lims_json_path) - + return ocr def test_no_params(no_param_config): - assert no_param_config.lims_data['well_known_files'][0]['well_known_file_type']['id'] != 329230374 - no_param_config.update_well_known_file('/path/to/params_fit.json', - OptimizeConfigReader.NEURONAL_MODEL_PARAMETERS) - assert no_param_config.lims_update_data['well_known_files'][1]['well_known_file_type_id'] == 329230374 - + assert no_param_config.lims_data["well_known_files"][0]["well_known_file_type"]["id"] != 329230374 + no_param_config.update_well_known_file("/path/to/params_fit.json", OptimizeConfigReader.NEURONAL_MODEL_PARAMETERS) + assert no_param_config.lims_update_data["well_known_files"][1]["well_known_file_type_id"] == 329230374 + def test_one_param(one_param_config): - assert one_param_config.lims_data['well_known_files'][1]['well_known_file_type']['id'] == 329230374 - one_param_config.update_well_known_file('/path/to/params_fit.json', - OptimizeConfigReader.NEURONAL_MODEL_PARAMETERS) - assert one_param_config.lims_update_data['well_known_files'][1]['well_known_file_type_id'] == 329230374 - assert one_param_config.lims_update_data['well_known_files'][1]['id'] == 22222 + assert one_param_config.lims_data["well_known_files"][1]["well_known_file_type"]["id"] == 329230374 + one_param_config.update_well_known_file("/path/to/params_fit.json", OptimizeConfigReader.NEURONAL_MODEL_PARAMETERS) + assert one_param_config.lims_update_data["well_known_files"][1]["well_known_file_type_id"] == 329230374 + assert one_param_config.lims_update_data["well_known_files"][1]["id"] == 22222 def test_two_params(two_param_config): - assert two_param_config.lims_data['well_known_files'][1]['well_known_file_type']['id'] == 329230374 - two_param_config.update_well_known_file('/path/to/params_fit.json', - OptimizeConfigReader.NEURONAL_MODEL_PARAMETERS) - assert two_param_config.lims_update_data['well_known_files'][1]['well_known_file_type_id'] == 329230374 - assert two_param_config.lims_update_data['well_known_files'][1]['id'] == 22222 - assert len(two_param_config.lims_update_data['well_known_files']) == 2 + assert two_param_config.lims_data["well_known_files"][1]["well_known_file_type"]["id"] == 329230374 + two_param_config.update_well_known_file("/path/to/params_fit.json", OptimizeConfigReader.NEURONAL_MODEL_PARAMETERS) + assert two_param_config.lims_update_data["well_known_files"][1]["well_known_file_type_id"] == 329230374 + assert two_param_config.lims_update_data["well_known_files"][1]["id"] == 22222 + assert len(two_param_config.lims_update_data["well_known_files"]) == 2 diff --git a/allensdk/test/internal/test_optimize_manifest.py b/allensdk/test/internal/test_optimize_manifest.py index 92c557083e..1ad25ef5c5 100644 --- a/allensdk/test/internal/test_optimize_manifest.py +++ b/allensdk/test/internal/test_optimize_manifest.py @@ -1,5 +1,4 @@ -from allensdk.internal.api.queries.optimize_config_reader import \ - OptimizeConfigReader +from allensdk.internal.api.queries.optimize_config_reader import OptimizeConfigReader import pytest from unittest.mock import patch, mock_open, MagicMock from io import StringIO, IOBase @@ -86,31 +85,28 @@ def manifest_as_string(reader): output = StringIO() - with patch(builtins.__name__ + ".open", - mock_open(), - create=True) as manifest_f: + with patch(builtins.__name__ + ".open", mock_open(), create=True) as manifest_f: manifest_f.return_value = MagicMock(spec=IOBase) file_handle = manifest_f.return_value.__enter__.return_value file_handle.write.side_effect = output.write - + reader.to_manifest("test_manifest.json") - + return output.getvalue() @pytest.fixture def no_param_config(): json_data = json.loads(LIMS_MESSAGE_ONE_PARAM_FILE) - json_data['well_known_files'] = [] + json_data["well_known_files"] = [] lims_message_no_param_file = json.dumps(json_data) ocr = OptimizeConfigReader() - lims_json_path = 'lims_message.json' - - with patch(builtins.__name__ + ".open", - mock_open(read_data=lims_message_no_param_file)): + lims_json_path = "lims_message.json" + + with patch(builtins.__name__ + ".open", mock_open(read_data=lims_message_no_param_file)): ocr.read_lims_file(lims_json_path) - + return ocr @@ -118,12 +114,11 @@ def no_param_config(): def one_param_config(): ocr = OptimizeConfigReader() - lims_json_path = 'lims_message.json' - - with patch(builtins.__name__ + ".open", - mock_open(read_data=LIMS_MESSAGE_ONE_PARAM_FILE)): + lims_json_path = "lims_message.json" + + with patch(builtins.__name__ + ".open", mock_open(read_data=LIMS_MESSAGE_ONE_PARAM_FILE)): ocr.read_lims_file(lims_json_path) - + return ocr @@ -132,67 +127,84 @@ def one_param_manifest_dict(one_param_config): reader = one_param_config json_string = manifest_as_string(reader) the_dict = json.loads(json_string) - + return the_dict def test_to_manifest(one_param_config): - with patch(builtins.__name__ + ".open", - mock_open(), - create=True) as manifest_f: + with patch(builtins.__name__ + ".open", mock_open(), create=True) as manifest_f: manifest_f.return_value = MagicMock(spec=IOBase) one_param_config.to_manifest("test_manifest.json") manifest_f.assert_called_once_with("test_manifest.json", "wb+") def test_top_level_keys(one_param_manifest_dict): - assert set(one_param_manifest_dict.keys()) == set(['biophys', 'runs', 'neuron', 'manifest']) + assert set(one_param_manifest_dict.keys()) == set(["biophys", "runs", "neuron", "manifest"]) def test_manifest_hoc(one_param_manifest_dict): - assert set(one_param_manifest_dict['neuron'][0].keys()) == set(['hoc']) + assert set(one_param_manifest_dict["neuron"][0].keys()) == set(["hoc"]) def test_specimen_id(one_param_manifest_dict): - assert one_param_manifest_dict['runs'][0]['specimen_id'] == 98765 + assert one_param_manifest_dict["runs"][0]["specimen_id"] == 98765 def test_sweeps(one_param_manifest_dict): - assert set(one_param_manifest_dict['runs'][0]['sweeps']) == set([1, 2]) + assert set(one_param_manifest_dict["runs"][0]["sweeps"]) == set([1, 2]) def test_mod_file_paths(one_param_config): - assert set(one_param_config.mod_file_paths()) == set(['/path/to/mod_files/mod_file_1.mod']) + assert set(one_param_config.mod_file_paths()) == set(["/path/to/mod_files/mod_file_1.mod"]) def test_update_well_known_file_not_existing(no_param_config): fit_file_type = 329230374 - no_param_config.update_well_known_file('/path/to/new_fit.json') - wkf = no_param_config.lims_update_data['well_known_files'][0] - assert 'id' not in wkf - assert wkf['storage_directory'] == '/path/to' - assert wkf['filename'] == 'new_fit.json' - assert wkf['well_known_file_type_id'] == fit_file_type + no_param_config.update_well_known_file("/path/to/new_fit.json") + wkf = no_param_config.lims_update_data["well_known_files"][0] + assert "id" not in wkf + assert wkf["storage_directory"] == "/path/to" + assert wkf["filename"] == "new_fit.json" + assert wkf["well_known_file_type_id"] == fit_file_type def test_update_well_known_file_existing(one_param_config): fit_file_type = 329230374 - one_param_config.update_well_known_file('/path/to/new_fit.json') - wkf = one_param_config.lims_update_data['well_known_files'][0] - assert wkf['id'] == 22222 - assert wkf['storage_directory'] == '/path/to' - assert wkf['filename'] == 'new_fit.json' - assert wkf['well_known_file_type_id'] == fit_file_type - - + one_param_config.update_well_known_file("/path/to/new_fit.json") + wkf = one_param_config.lims_update_data["well_known_files"][0] + assert wkf["id"] == 22222 + assert wkf["storage_directory"] == "/path/to" + assert wkf["filename"] == "new_fit.json" + assert wkf["well_known_file_type_id"] == fit_file_type + + def test_manifest_keys(one_param_manifest_dict): - expected_keys = set(['BASEDIR', 'WORKDIR', 'MORPHOLOGY', 'MODFILE_DIR', - 'MOD_FILE_mod_file_1', 'stimulus_path', 'manifest', - 'output', 'neuronal_model_data', 'upfile', 'downfile', - 'passive_fit_data', 'stage_1_jobs', 'fit_1_file', - 'fit_2_file', 'fit_3_file', 'fit_type_path', - 'target_path', 'fit_config_json', 'final_hof_fit', - 'final_hof', 'output_fit_file']) - - actual_keys = set([e['key'] for e in one_param_manifest_dict['manifest']]) + expected_keys = set( + [ + "BASEDIR", + "WORKDIR", + "MORPHOLOGY", + "MODFILE_DIR", + "MOD_FILE_mod_file_1", + "stimulus_path", + "manifest", + "output", + "neuronal_model_data", + "upfile", + "downfile", + "passive_fit_data", + "stage_1_jobs", + "fit_1_file", + "fit_2_file", + "fit_3_file", + "fit_type_path", + "target_path", + "fit_config_json", + "final_hof_fit", + "final_hof", + "output_fit_file", + ] + ) + + actual_keys = set([e["key"] for e in one_param_manifest_dict["manifest"]]) assert actual_keys == expected_keys diff --git a/allensdk/test/internal/test_roi_filter.py b/allensdk/test/internal/test_roi_filter.py index 1ffac4440f..b2887e09f4 100644 --- a/allensdk/test/internal/test_roi_filter.py +++ b/allensdk/test/internal/test_roi_filter.py @@ -20,7 +20,7 @@ def __init__(self, stack, n_rois, has_duplicates, has_unions): def create_mask_plane(img_shape, dot_positions, radius=15): img = np.zeros(img_shape, dtype=np.uint8) for r, c in dot_positions: - if getattr(draw, 'circle', None) is not None: + if getattr(draw, "circle", None) is not None: # older version of skimage img[draw.circle(r, c, radius, shape=img_shape)] = 1 else: @@ -29,8 +29,7 @@ def create_mask_plane(img_shape, dot_positions, radius=15): return img -@pytest.fixture(params=[(False, False), (True, True), - (False, True), (True, False)]) +@pytest.fixture(params=[(False, False), (True, True), (False, True), (True, False)]) def segmentation(request): has_unions, has_duplicates = request.param plane1 = create_mask_plane((200, 200), [(20, 20), (130, 60), (170, 110)]) @@ -47,28 +46,144 @@ def segmentation(request): masks.append(dplane) n_rois += 1 - return TestSegmentation(np.array(masks), n_rois, has_duplicates, - has_unions) + return TestSegmentation(np.array(masks), n_rois, has_duplicates, has_unions) @pytest.fixture def object_list(): - columns = ["index", "traceindex", "tempIndex", "cx", "cy", "mask2Frame", - "frame", "object", "minx", "miny", "maxx", "maxy", "area", - "shape0", "shape1", "eXcluded", "meanInt0", "maxInt0", - "meanInt1", "maxInt1", "maxMeanRatio", "snpoffsetmean", - "snpoffsetstdv", "act2", "act3", "OvlpCount", "OvlpAreaPer", - "OvlpObj0", "corcoef0", "OvlpObj1", "corcoef1"] + columns = [ + "index", + "traceindex", + "tempIndex", + "cx", + "cy", + "mask2Frame", + "frame", + "object", + "minx", + "miny", + "maxx", + "maxy", + "area", + "shape0", + "shape1", + "eXcluded", + "meanInt0", + "maxInt0", + "meanInt1", + "maxInt1", + "maxMeanRatio", + "snpoffsetmean", + "snpoffsetstdv", + "act2", + "act3", + "OvlpCount", + "OvlpAreaPer", + "OvlpObj0", + "corcoef0", + "OvlpObj1", + "corcoef1", + ] data = [ - [0, 0, 112, 363, 12, 0, 81, 1, 354, 5, 371, 18, 170, 0.679, 9, 0, 49, - 73, 32, 54, 0.6875, -18.598810, 12.540119, 2778, 995, 1, 82, 85, - -1.000, 0, 0.000], - [1, 1, 12, 224, 13, 0, 2, 1, 218, 8, 230, 18, 106, 0.653, 10, 11, 30, - 62, 12, 23, 0.9167, -34.688274, 16.209919, 2818, 390, 0, 0, 0, - 0.000, 0, 0.000], - [2, 999, 109, 323, 9, 0, 206, 2, 315, 2, 331, 22, 193, 0.454, 16, 2, - 123, 255, 92, 225, 1.4457, 0.000000, 0.000000, 0, 0, 0, 0, 0, 0.000, - 0, 0.000] + [ + 0, + 0, + 112, + 363, + 12, + 0, + 81, + 1, + 354, + 5, + 371, + 18, + 170, + 0.679, + 9, + 0, + 49, + 73, + 32, + 54, + 0.6875, + -18.598810, + 12.540119, + 2778, + 995, + 1, + 82, + 85, + -1.000, + 0, + 0.000, + ], + [ + 1, + 1, + 12, + 224, + 13, + 0, + 2, + 1, + 218, + 8, + 230, + 18, + 106, + 0.653, + 10, + 11, + 30, + 62, + 12, + 23, + 0.9167, + -34.688274, + 16.209919, + 2818, + 390, + 0, + 0, + 0, + 0.000, + 0, + 0.000, + ], + [ + 2, + 999, + 109, + 323, + 9, + 0, + 206, + 2, + 315, + 2, + 331, + 22, + 193, + 0.454, + 16, + 2, + 123, + 255, + 92, + 225, + 1.4457, + 0.000000, + 0.000000, + 0, + 0, + 0, + 0, + 0, + 0.000, + 0, + 0.000, + ], ] return pd.DataFrame(data=data, columns=columns) @@ -81,8 +196,10 @@ def xy_data(): @pytest.fixture(scope="module") def old_csv(tmpdir_factory, xy_data): - data = ["0,{},154.086,-14.9831,0,0,1,0.0254625".format(xy_data[0]), - "1,{},-0.78758,-2.39286,0,0,0,0.348251".format(xy_data[1])] + data = [ + "0,{},154.086,-14.9831,0,0,1,0.0254625".format(xy_data[0]), + "1,{},-0.78758,-2.39286,0,0,0,0.348251".format(xy_data[1]), + ] filename = str(tmpdir_factory.mktemp("test").join("old.csv")) with open(filename, "w") as f: f.write("\n".join(data)) @@ -91,9 +208,11 @@ def old_csv(tmpdir_factory, xy_data): @pytest.fixture(scope="module") def new_csv(tmpdir_factory, xy_data): - data = ["framenumber,x,y,correlation,input_x,input_y,estimate", - "0,{},0.65566,3.00745,-1.75258,PhaseCorrelated".format(xy_data[0]), - "1,{},0.65727,3.15259,-2.8105,PhaseCorrelated".format(xy_data[1])] + data = [ + "framenumber,x,y,correlation,input_x,input_y,estimate", + "0,{},0.65566,3.00745,-1.75258,PhaseCorrelated".format(xy_data[0]), + "1,{},0.65727,3.15259,-2.8105,PhaseCorrelated".format(xy_data[1]), + ] filename = str(tmpdir_factory.mktemp("test").join("new.csv")) with open(filename, "w") as f: f.write("\n".join(data)) @@ -103,29 +222,27 @@ def new_csv(tmpdir_factory, xy_data): def model_data(ol, is_valid=True): training_columns = list(ol.columns) if is_valid: - training_columns.extend( - [1, "depth", "driver1", "driver2", "reporter1"]) + training_columns.extend([1, "depth", "driver1", "driver2", "reporter1"]) else: - training_columns.extend( - [2, "depth", "driver1", "driver2", "reporter1"]) - data = {"structure_ids": [1], - "drivers": ["driver1", "driver2"], - "reporters": ["reporter1"], - "training_features": pd.DataFrame(columns=training_columns)} + training_columns.extend([2, "depth", "driver1", "driver2", "reporter1"]) + data = { + "structure_ids": [1], + "drivers": ["driver1", "driver2"], + "reporters": ["reporter1"], + "training_features": pd.DataFrame(columns=training_columns), + } return data def test_calculate_max_border_all_outliers(): - df = pd.DataFrame( - np.ones((100, 9)), - columns=["index", "x", "y", "a", "b", "c", "d", "e", "f"]) + df = pd.DataFrame(np.ones((100, 9)), columns=["index", "x", "y", "a", "b", "c", "d", "e", "f"]) with pytest.raises(ValueError): roi_filter_utils.calculate_max_border(df, 0) def test_get_rois(segmentation): rois = roi_filter_utils.get_rois(segmentation.stack) - assert (len(rois) == segmentation.n_rois) + assert len(rois) == segmentation.n_rois def test_label_unions_and_duplicates(segmentation): @@ -138,8 +255,8 @@ def test_label_unions_and_duplicates(segmentation): duplicates |= 1 if "union" in roi.labels: unions |= 1 - assert (duplicates == segmentation.has_duplicates) - assert (unions == segmentation.has_unions) + assert duplicates == segmentation.has_duplicates + assert unions == segmentation.has_unions def test_create_feature_array(object_list): @@ -150,24 +267,20 @@ def test_create_feature_array(object_list): passing_data = model_data(object_list, True) failing_data = model_data(object_list, False) with pytest.raises(KeyError): - roi_filter.create_feature_array(failing_data, object_list, depth, - structure_id, drivers, reporters) - feature_array = roi_filter.create_feature_array(passing_data, object_list, - depth, structure_id, - drivers, reporters) - assert (np.all(feature_array.columns == - passing_data["training_features"].columns)) + roi_filter.create_feature_array(failing_data, object_list, depth, structure_id, drivers, reporters) + feature_array = roi_filter.create_feature_array(passing_data, object_list, depth, structure_id, drivers, reporters) + assert np.all(feature_array.columns == passing_data["training_features"].columns) def test_training_label_classifier(object_list): classifier = roi_filter_utils.TrainingMultiLabelClassifier() - assert (classifier.labels == sorted(roi_filter_utils.CRITERIA().keys())) + assert classifier.labels == sorted(roi_filter_utils.CRITERIA().keys()) def test_read_csv(old_csv, new_csv): - assert (not run_roi_filter.is_deprecated_motion_file(new_csv)) - assert (run_roi_filter.is_deprecated_motion_file(old_csv)) + assert not run_roi_filter.is_deprecated_motion_file(new_csv) + assert run_roi_filter.is_deprecated_motion_file(old_csv) old_data = run_roi_filter.load_rigid_motion_transform(old_csv) new_data = run_roi_filter.load_rigid_motion_transform(new_csv) - assert (np.all(np.isclose(old_data["x"], new_data["x"]))) - assert (np.all(np.isclose(old_data["y"], new_data["y"]))) + assert np.all(np.isclose(old_data["x"], new_data["x"])) + assert np.all(np.isclose(old_data["y"], new_data["y"])) diff --git a/allensdk/test/internal/test_simulate_manifest.py b/allensdk/test/internal/test_simulate_manifest.py index ceaed8b84f..413cd79bf0 100644 --- a/allensdk/test/internal/test_simulate_manifest.py +++ b/allensdk/test/internal/test_simulate_manifest.py @@ -1,5 +1,4 @@ -from allensdk.internal.api.queries.biophysical_module_reader import \ - BiophysicalModuleReader +from allensdk.internal.api.queries.biophysical_module_reader import BiophysicalModuleReader import pytest from unittest.mock import patch, mock_open, MagicMock from io import StringIO, IOBase @@ -128,34 +127,31 @@ def manifest_as_string(reader): output = StringIO() - with patch(builtins.__name__ + ".open", - mock_open(), - create=True) as manifest_f: + with patch(builtins.__name__ + ".open", mock_open(), create=True) as manifest_f: manifest_f.return_value = MagicMock(spec=IOBase) file_handle = manifest_f.return_value.__enter__.return_value file_handle.write.side_effect = output.write - + reader.to_manifest("test_manifest.json") - + manifest_string = output.getvalue() print(manifest_string) - + return manifest_string @pytest.fixture def no_param_config(): json_data = json.loads(LIMS_MESSAGE_ONE_PARAM_FILE) - json_data['well_known_files'] = [] + json_data["well_known_files"] = [] lims_message_no_param_file = json.dumps(json_data) scr = BiophysicalModuleReader() - lims_json_path = 'lims_message.json' - - with patch(builtins.__name__ + ".open", - mock_open(read_data=lims_message_no_param_file)): + lims_json_path = "lims_message.json" + + with patch(builtins.__name__ + ".open", mock_open(read_data=lims_message_no_param_file)): scr.read_lims_file(lims_json_path) - + return scr @@ -163,12 +159,11 @@ def no_param_config(): def one_param_config(): scr = BiophysicalModuleReader() - lims_json_path = 'lims_message.json' - - with patch(builtins.__name__ + ".open", - mock_open(read_data=LIMS_MESSAGE_ONE_PARAM_FILE)): + lims_json_path = "lims_message.json" + + with patch(builtins.__name__ + ".open", mock_open(read_data=LIMS_MESSAGE_ONE_PARAM_FILE)): scr.read_lims_file(lims_json_path) - + return scr @@ -177,85 +172,92 @@ def one_param_manifest_dict(one_param_config): reader = one_param_config json_string = manifest_as_string(reader) the_dict = json.loads(json_string) - + return the_dict def test_to_manifest(one_param_config): - with patch(builtins.__name__ + ".open", - mock_open(), - create=True) as manifest_f: + with patch(builtins.__name__ + ".open", mock_open(), create=True) as manifest_f: manifest_f.return_value = MagicMock(spec=IOBase) one_param_config.to_manifest("test_manifest.json") manifest_f.assert_called_once_with("test_manifest.json", "wb+") def test_top_level_keys(one_param_manifest_dict): - assert set(one_param_manifest_dict.keys()) == \ - set(['biophys', 'runs', 'neuron', 'manifest']) + assert set(one_param_manifest_dict.keys()) == set(["biophys", "runs", "neuron", "manifest"]) def test_manifest_hoc(one_param_manifest_dict): - assert set(one_param_manifest_dict['neuron'][0].keys()) == set(['hoc']) + assert set(one_param_manifest_dict["neuron"][0].keys()) == set(["hoc"]) def test_neuronal_model_run_id(one_param_manifest_dict): - assert one_param_manifest_dict['runs'][0]['neuronal_model_run_id'] == 8888 + assert one_param_manifest_dict["runs"][0]["neuronal_model_run_id"] == 8888 def test_sweeps(one_param_manifest_dict): - assert set(one_param_manifest_dict['runs'][0]['sweeps']) == set([1, 2]) + assert set(one_param_manifest_dict["runs"][0]["sweeps"]) == set([1, 2]) def test_sweeps_by_type(one_param_manifest_dict): - sweeps_by_type = one_param_manifest_dict['runs'][0]['sweeps_by_type'] - assert set(sweeps_by_type['Test']) == set([1]) - assert set(sweeps_by_type['Unknown']) == set([2]) - assert set(sweeps_by_type['Long Square']) == set([3]) + sweeps_by_type = one_param_manifest_dict["runs"][0]["sweeps_by_type"] + assert set(sweeps_by_type["Test"]) == set([1]) + assert set(sweeps_by_type["Unknown"]) == set([2]) + assert set(sweeps_by_type["Long Square"]) == set([3]) assert len(sweeps_by_type.keys()) == 3 def test_mod_file_paths(one_param_config): - assert set(one_param_config.mod_file_paths()) == \ - set(['/path/to/mod_files/mod_file_1.mod']) + assert set(one_param_config.mod_file_paths()) == set(["/path/to/mod_files/mod_file_1.mod"]) def test_update_well_known_file_not_existing(no_param_config): nwb_uncompressed_file_type = 478840678 - no_param_config.update_well_known_file('/path/to/pre_existing_output.nwb') - wkf = no_param_config.lims_update_data['well_known_files'][0] - assert 'id' not in wkf - assert wkf['storage_directory'] == '/path/to' - assert wkf['filename'] == 'pre_existing_output.nwb' - assert wkf['well_known_file_type_id'] == nwb_uncompressed_file_type + no_param_config.update_well_known_file("/path/to/pre_existing_output.nwb") + wkf = no_param_config.lims_update_data["well_known_files"][0] + assert "id" not in wkf + assert wkf["storage_directory"] == "/path/to" + assert wkf["filename"] == "pre_existing_output.nwb" + assert wkf["well_known_file_type_id"] == nwb_uncompressed_file_type def test_update_well_known_file_existing(one_param_config): nwb_uncompressed_file_type = 478840678 - one_param_config.update_well_known_file('/neuronal/model/run/dir/pre_existing_output.nwb') - wkf = one_param_config.lims_update_data['well_known_files'][0] - assert wkf['id'] == 343434 - assert wkf['storage_directory'] == '/neuronal/model/run/dir' - assert wkf['filename'] == 'pre_existing_output.nwb' - assert wkf['well_known_file_type_id'] == nwb_uncompressed_file_type + one_param_config.update_well_known_file("/neuronal/model/run/dir/pre_existing_output.nwb") + wkf = one_param_config.lims_update_data["well_known_files"][0] + assert wkf["id"] == 343434 + assert wkf["storage_directory"] == "/neuronal/model/run/dir" + assert wkf["filename"] == "pre_existing_output.nwb" + assert wkf["well_known_file_type_id"] == nwb_uncompressed_file_type def test_update_well_known_file_existing_name_mismatch(one_param_config): nwb_uncompressed_file_type = 478840678 - one_param_config.update_well_known_file( - '/neuronal/model/run/dir/8888_virtual_experiment.nwb') - wkf = one_param_config.lims_update_data['well_known_files'][0] - assert 'id' not in wkf - assert wkf['storage_directory'] == '/neuronal/model/run/dir' - assert wkf['filename'] == '8888_virtual_experiment.nwb' - assert wkf['well_known_file_type_id'] == nwb_uncompressed_file_type + one_param_config.update_well_known_file("/neuronal/model/run/dir/8888_virtual_experiment.nwb") + wkf = one_param_config.lims_update_data["well_known_files"][0] + assert "id" not in wkf + assert wkf["storage_directory"] == "/neuronal/model/run/dir" + assert wkf["filename"] == "8888_virtual_experiment.nwb" + assert wkf["well_known_file_type_id"] == nwb_uncompressed_file_type def test_manifest_keys(one_param_manifest_dict): - expected_keys = set(['BASEDIR', 'WORKDIR', 'MORPHOLOGY', 'CODE_DIR', - 'MODFILE_DIR', 'MOD_FILE_mod_file_1', 'stimulus_path', - 'manifest', 'output_path', 'fit_parameters', - 'neuronal_model_run_data', 'fit_parameters']) - - actual_keys = set([e['key'] for e in one_param_manifest_dict['manifest']]) + expected_keys = set( + [ + "BASEDIR", + "WORKDIR", + "MORPHOLOGY", + "CODE_DIR", + "MODFILE_DIR", + "MOD_FILE_mod_file_1", + "stimulus_path", + "manifest", + "output_path", + "fit_parameters", + "neuronal_model_run_data", + "fit_parameters", + ] + ) + + actual_keys = set([e["key"] for e in one_param_manifest_dict["manifest"]]) assert actual_keys == expected_keys diff --git a/allensdk/test/internal/test_simulate_update_output.py b/allensdk/test/internal/test_simulate_update_output.py index 727e93daed..08c9b9431f 100644 --- a/allensdk/test/internal/test_simulate_update_output.py +++ b/allensdk/test/internal/test_simulate_update_output.py @@ -1,7 +1,7 @@ -from allensdk.internal.api.queries.biophysical_module_reader import \ - BiophysicalModuleReader +from allensdk.internal.api.queries.biophysical_module_reader import BiophysicalModuleReader import pytest from unittest.mock import patch, mock_open + try: import __builtin__ as builtins except Exception: @@ -105,12 +105,11 @@ def no_nwb_config(): bmr = BiophysicalModuleReader() - lims_json_path = 'lims_message.json' - - with patch(builtins.__name__ + ".open", - mock_open(read_data=LIMS_MESSAGE_NO_NWB_FILES)): + lims_json_path = "lims_message.json" + + with patch(builtins.__name__ + ".open", mock_open(read_data=LIMS_MESSAGE_NO_NWB_FILES)): bmr.read_lims_file(lims_json_path) - + return bmr @@ -118,12 +117,11 @@ def no_nwb_config(): def one_nwb_config(): bmr = BiophysicalModuleReader() - lims_json_path = 'lims_message.json' - - with patch(builtins.__name__ + ".open", - mock_open(read_data=LIMS_MESSAGE_ONE_NWB_FILE)): + lims_json_path = "lims_message.json" + + with patch(builtins.__name__ + ".open", mock_open(read_data=LIMS_MESSAGE_ONE_NWB_FILE)): bmr.read_lims_file(lims_json_path) - + return bmr @@ -131,32 +129,34 @@ def one_nwb_config(): def two_nwb_config(): bmr = BiophysicalModuleReader() - lims_json_path = 'lims_message.json' - - with patch(builtins.__name__ + ".open", - mock_open(read_data=LIMS_MESSAGE_TWO_NWB_FILES)): + lims_json_path = "lims_message.json" + + with patch(builtins.__name__ + ".open", mock_open(read_data=LIMS_MESSAGE_TWO_NWB_FILES)): bmr.read_lims_file(lims_json_path) - + return bmr def test_no_nwb(no_nwb_config): - assert no_nwb_config.lims_data['well_known_files'][0]['well_known_file_type']['id'] != 478840678 - no_nwb_config.update_well_known_file('/path/to/example.nwb') - assert no_nwb_config.lims_update_data['well_known_files'][1]['well_known_file_type_id'] == 478840678 - + assert no_nwb_config.lims_data["well_known_files"][0]["well_known_file_type"]["id"] != 478840678 + no_nwb_config.update_well_known_file("/path/to/example.nwb") + assert no_nwb_config.lims_update_data["well_known_files"][1]["well_known_file_type_id"] == 478840678 def test_one_nwb(one_nwb_config): - assert one_nwb_config.lims_data['well_known_files'][1]['well_known_file_type']['id'] == 478840678 - one_nwb_config.update_well_known_file('/projects/mousecelltypes/vol1/prod572/neuronal_model_run_496537307/496537307_virtual_experiment.nwb') - assert one_nwb_config.lims_update_data['well_known_files'][1]['well_known_file_type_id'] == 478840678 - assert one_nwb_config.lims_update_data['well_known_files'][1]['id'] == 22222 + assert one_nwb_config.lims_data["well_known_files"][1]["well_known_file_type"]["id"] == 478840678 + one_nwb_config.update_well_known_file( + "/projects/mousecelltypes/vol1/prod572/neuronal_model_run_496537307/496537307_virtual_experiment.nwb" + ) + assert one_nwb_config.lims_update_data["well_known_files"][1]["well_known_file_type_id"] == 478840678 + assert one_nwb_config.lims_update_data["well_known_files"][1]["id"] == 22222 def test_two_nwb(two_nwb_config): - assert two_nwb_config.lims_data['well_known_files'][1]['well_known_file_type']['id'] == 478840678 - two_nwb_config.update_well_known_file('/projects/mousecelltypes/vol1/prod572/neuronal_model_run_496537307/496537307_virtual_experiment.nwb') - assert two_nwb_config.lims_update_data['well_known_files'][1]['well_known_file_type_id'] == 478840678 - assert two_nwb_config.lims_update_data['well_known_files'][1]['id'] == 22222 - assert len(two_nwb_config.lims_update_data['well_known_files']) == 2 + assert two_nwb_config.lims_data["well_known_files"][1]["well_known_file_type"]["id"] == 478840678 + two_nwb_config.update_well_known_file( + "/projects/mousecelltypes/vol1/prod572/neuronal_model_run_496537307/496537307_virtual_experiment.nwb" + ) + assert two_nwb_config.lims_update_data["well_known_files"][1]["well_known_file_type_id"] == 478840678 + assert two_nwb_config.lims_update_data["well_known_files"][1]["id"] == 22222 + assert len(two_nwb_config.lims_update_data["well_known_files"]) == 2 diff --git a/allensdk/test/internal/tissuecyte_stitching/test_stitcher.py b/allensdk/test/internal/tissuecyte_stitching/test_stitcher.py index 7f8764ef43..089f5328d5 100644 --- a/allensdk/test/internal/tissuecyte_stitching/test_stitcher.py +++ b/allensdk/test/internal/tissuecyte_stitching/test_stitcher.py @@ -8,26 +8,26 @@ def test_initialize_image(): + image = stitcher.initialize_image({"row": 40, "column": 12}, 2, np.float32, "F") - image = stitcher.initialize_image({'row': 40, 'column': 12}, 2, np.float32, 'F') - - assert( np.allclose( image.shape, [40, 12, 2] ) ) - assert( np.sum(image) == 0 ) - assert( image.dtype == np.float32 ) - assert( image.flags.f_contiguous ) + assert np.allclose(image.shape, [40, 12, 2]) + assert np.sum(image) == 0 + assert image.dtype == np.float32 + assert image.flags.f_contiguous def test_initialize_images(): + ( + a, + b, + ) = stitcher.initialize_images({"row": 40, "column": 12}, 1) - a, b, = stitcher.initialize_images({'row': 40, 'column': 12}, 1) - - assert(a.dtype == np.uint16) - assert(b.dtype == np.int8) - assert(len(a.shape) == 3) + assert a.dtype == np.uint16 + assert b.dtype == np.int8 + assert len(a.shape) == 3 def test_make_blended_tile(): - tile = np.arange(25, dtype=float).reshape(5, 5) current_region = np.ones((5, 5)) * 30 blend = np.zeros((5, 5)) @@ -37,14 +37,13 @@ def test_make_blended_tile(): exp = tile.copy() exp[-1, :] = 30 exp[-2, :] = (np.arange(15, 20, dtype=float) + 30) / 2 - + obt = stitcher.make_blended_tile(blend, tile, current_region) - assert(np.allclose(exp, obt)) + assert np.allclose(exp, obt) -@pytest.mark.parametrize('lg,axis,point', [(op.lt, 0, 0), (op.gt, 0, 8), (op.lt, 1, 1), (op.gt, 1, 7)]) +@pytest.mark.parametrize("lg,axis,point", [(op.lt, 0, 0), (op.gt, 0, 8), (op.lt, 1, 1), (op.gt, 1, 7)]) def test_get_indicator_bound_point(lg, axis, point): - indicator = np.zeros((10, 10)) indicator[9:, :] = 1 indicator[:, 8:] = 1 @@ -53,14 +52,13 @@ def test_get_indicator_bound_point(lg, axis, point): indicator[7:, 7:] = 0 obt = stitcher.get_indicator_bound_point(indicator, lg, axis) - - assert( obt == point ) + assert obt == point -def test_blend_component_from_point(): +def test_blend_component_from_point(): mesh = np.tile(np.arange(20), (10, 1)) - point = 15 # indexed in the diff: 17 -> only last r/c + point = 15 # indexed in the diff: 17 -> only last r/c lg = op.gt exp = np.zeros((10, 20)) @@ -70,49 +68,46 @@ def test_blend_component_from_point(): exp[:, -1] = 1 obt = stitcher.blend_component_from_point(point, mesh, lg) - assert( np.allclose( obt, exp ) ) + assert np.allclose(obt, exp) def test_blend_component_from_point_divzero(): - mesh = np.zeros((20, 20)) point = 0 lg = op.gt obt = stitcher.blend_component_from_point(point, mesh, lg) - - assert( np.allclose(obt, mesh) ) + assert np.allclose(obt, mesh) -def test_get_blend_component_nopoint(): - with mock.patch('allensdk.internal.mouse_connectivity.tissuecyte_stitching.stitcher.get_indicator_bound_point', - new=lambda *a, **k: None): - - assert( len(stitcher.get_blend_component(1, 2, 3, 4)) == 0 ) +def test_get_blend_component_nopoint(): + with mock.patch( + "allensdk.internal.mouse_connectivity.tissuecyte_stitching.stitcher.get_indicator_bound_point", + new=lambda *a, **k: None, + ): + assert len(stitcher.get_blend_component(1, 2, 3, 4)) == 0 def test_get_blend_component_actual(): - indicator = np.zeros((20, 10)) indicator[16:, :] = 1 lg = op.gt axis = 0 - meshes = np.meshgrid(np.arange(20), np.arange(10), indexing='ij') + meshes = np.meshgrid(np.arange(20), np.arange(10), indexing="ij") exp = np.zeros_like(indicator) - exp[17, :] = 1.0 / 3.0 + exp[17, :] = 1.0 / 3.0 exp[18, :] = 2.0 / 3.0 exp[19, :] = 1.0 - + obt = stitcher.get_blend_component(indicator, lg, axis, meshes) - assert( np.allclose( obt, exp ) ) + assert np.allclose(obt, exp) def test_get_overall_blend(): - - meshes = np.meshgrid(np.arange(20), np.arange(20), indexing='ij') + meshes = np.meshgrid(np.arange(20), np.arange(20), indexing="ij") indicator = np.zeros((20, 20)) indicator[16:, :] = 1 @@ -124,13 +119,12 @@ def test_get_overall_blend(): exp[18, :] = 2.0 / 3.0 exp[:, -1] = 1 exp[-1, :] = 1 - + obt = stitcher.get_overall_blend(indicator, meshes) - assert( np.allclose(obt, exp) ) + assert np.allclose(obt, exp) def test_get_blend(): - indicator = np.zeros((20, 10)) indicator[16:, :] = 1 indicator[:, 7:] = 1 @@ -147,4 +141,4 @@ def test_get_blend(): exp = np.sqrt(exp) obt = stitcher.get_blend(indicator, stup, cb) - assert( np.allclose( obt, exp ) ) + assert np.allclose(obt, exp) diff --git a/allensdk/test/internal/tissuecyte_stitching/test_tile.py b/allensdk/test/internal/tissuecyte_stitching/test_tile.py index 15bb4dfc14..46881bc1dc 100644 --- a/allensdk/test/internal/tissuecyte_stitching/test_tile.py +++ b/allensdk/test/internal/tissuecyte_stitching/test_tile.py @@ -4,78 +4,73 @@ from allensdk.internal.mouse_connectivity.tissuecyte_stitching.tile import Tile - -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def small_tile(): - index = 20 - image = np.arange(200).reshape((10, 20)) # columns fast + image = np.arange(200).reshape((10, 20)) # columns fast is_missing = False - bounds = {'row': {'start': 40, 'end': 48}, 'column': {'start': 500, 'end': 516}} + bounds = {"row": {"start": 40, "end": 48}, "column": {"start": 500, "end": 516}} channel = 2 - size = {'row': 8, 'column': 16} - margins = {'row': 1, 'column': 2} + size = {"row": 8, "column": 16} + margins = {"row": 1, "column": 2} return Tile(index, image, is_missing, bounds, channel, size, margins) def test_trim_self(small_tile): - small_tile.trim_self() - assert( np.allclose( small_tile.image.shape, [8, 16] ) ) - assert( np.amin(small_tile.image) == 22 ) - + assert np.allclose(small_tile.image.shape, [8, 16]) + assert np.amin(small_tile.image) == 22 def test_trim(small_tile): - image = np.diag(np.arange(20)) out = small_tile.trim(image) - assert( np.allclose( out.shape, [8, 16] ) ) - assert( np.amax(out) == 8 ) - + assert np.allclose(out.shape, [8, 16]) + assert np.amax(out) == 8 -@pytest.mark.parametrize('rs,cs,yn', [(8, 16, False), (11, 21, True)]) -def test_average_tile_is_untrimmed(small_tile, rs, cs, yn): +@pytest.mark.parametrize("rs,cs,yn", [(8, 16, False), (11, 21, True)]) +def test_average_tile_is_untrimmed(small_tile, rs, cs, yn): image = np.zeros((rs, cs)) res = small_tile.average_tile_is_untrimmed(image) - assert( res == yn ) + assert res == yn -@pytest.mark.parametrize('avt,do_trim', [(np.ones((8, 16)) * 2, True), - (np.ones((8, 16)) * 2, True), - (np.ones((10, 20)) * 2, True), - (np.ones((10, 20)) * 2, False)]) +@pytest.mark.parametrize( + "avt,do_trim", + [ + (np.ones((8, 16)) * 2, True), + (np.ones((8, 16)) * 2, True), + (np.ones((10, 20)) * 2, True), + (np.ones((10, 20)) * 2, False), + ], +) def test_apply_average_tile(small_tile, avt, do_trim): - if do_trim: small_tile.trim_self() res = small_tile.apply_average_tile(avt) - assert( np.allclose(res, small_tile.image * 2) ) + assert np.allclose(res, small_tile.image * 2) def test_no_average_tile(small_tile): - res = small_tile.apply_average_tile(None) - assert( np.allclose(res, small_tile.image) ) + assert np.allclose(res, small_tile.image) def test_apply_average_tile_to_self(small_tile): - av = np.ones((10, 20)) * 3 prev = small_tile.image.copy() small_tile.apply_average_tile_to_self(av) - assert( np.allclose( small_tile.image, prev * 3 ) ) + assert np.allclose(small_tile.image, prev * 3) def test_get_image_region(small_tile): - image = np.zeros((1000, 1000, 3)) image[:, :, 0] += 1 image[:, :, 1] += 2 @@ -84,22 +79,20 @@ def test_get_image_region(small_tile): slc = small_tile.get_image_region() image = image[slc] - - assert( np.allclose( image.shape, [8, 16] ) ) - assert( image.sum() == 8 * 16 * 3 + 1 ) + assert np.allclose(image.shape, [8, 16]) + assert image.sum() == 8 * 16 * 3 + 1 -def test_get_missing_path(small_tile): +def test_get_missing_path(small_tile): obt = small_tile.get_missing_path() exp = [40, 500, 48, 500, 48, 516, 40, 516] - assert( np.allclose(obt, exp) ) + assert np.allclose(obt, exp) def test_initialize_image(small_tile): - exp = np.zeros((8, 16)) small_tile.initialize_image() - assert( np.allclose( small_tile.image, exp ) ) + assert np.allclose(small_tile.image, exp) diff --git a/allensdk/test/model/aa_model/test_biophysical_all_active.py b/allensdk/test/model/aa_model/test_biophysical_all_active.py index 92c85af85a..5d5da2de61 100644 --- a/allensdk/test/model/aa_model/test_biophysical_all_active.py +++ b/allensdk/test/model/aa_model/test_biophysical_all_active.py @@ -41,25 +41,26 @@ from allensdk.ephys import ephys_features import subprocess + @pytest.mark.requires_neuron def test_biophysical_aa(): """ Test for backward compatibility of the legacy all-active models """ - - subprocess.check_call(['nrnivmodl', 'modfiles/']) - description = Config().load('manifest.json') + subprocess.check_call(["nrnivmodl", "modfiles/"]) + + description = Config().load("manifest.json") utils = AllActiveUtils(description) h = utils.h manifest = description.manifest - morphology_path = manifest.get_path('MORPHOLOGY') - utils.generate_morphology(morphology_path.encode('ascii', 'ignore').decode("utf-8")) + morphology_path = manifest.get_path("MORPHOLOGY") + utils.generate_morphology(morphology_path.encode("ascii", "ignore").decode("utf-8")) utils.load_cell_parameters() stim = h.IClamp(h.soma[0](0.5)) - stim.amp = 0.33 # Sweep 46 + stim.amp = 0.33 # Sweep 46 stim.delay = 1000.0 stim.dur = 1000.0 @@ -70,12 +71,12 @@ def test_biophysical_aa(): h.finitialize() h.run() - junction_potential = description.data['fitting'][0]['junction_potential'] + junction_potential = description.data["fitting"][0]["junction_potential"] ms = 1.0e-3 - output_data = (numpy.array(vec['v']) - junction_potential) # in mV - output_times = numpy.array(vec['t']) * ms # in s - output_path = 'output_voltage.dat' + output_data = numpy.array(vec["v"]) - junction_potential # in mV + output_times = numpy.array(vec["t"]) * ms # in s + output_path = "output_voltage.dat" DatUtilities.save_voltage(output_path, output_data, output_times) diff --git a/allensdk/test/model/check_parser.py b/allensdk/test/model/check_parser.py index 5843fc6c2d..a9b23c00b8 100644 --- a/allensdk/test/model/check_parser.py +++ b/allensdk/test/model/check_parser.py @@ -1,8 +1,10 @@ from allensdk.model.biophysical.runner import sim_parser + def get_parsed_args(schema): print(vars(schema)) -if __name__ == '__main__': + +if __name__ == "__main__": schema = sim_parser.parse_args() get_parsed_args(schema) diff --git a/allensdk/test/model/peri_model/test_biophysical_peri.py b/allensdk/test/model/peri_model/test_biophysical_peri.py index 3fd8f10ab7..80341fe1ab 100644 --- a/allensdk/test/model/peri_model/test_biophysical_peri.py +++ b/allensdk/test/model/peri_model/test_biophysical_peri.py @@ -41,25 +41,26 @@ from allensdk.ephys import ephys_features import subprocess + @pytest.mark.requires_neuron def test_biophysical_peri(): """ Test for backward compatibility of the perisomatic models """ - - subprocess.check_call(['nrnivmodl', 'modfiles/']) - description = Config().load('manifest.json') + subprocess.check_call(["nrnivmodl", "modfiles/"]) + + description = Config().load("manifest.json") utils = Utils(description) h = utils.h manifest = description.manifest - morphology_path = manifest.get_path('MORPHOLOGY') - utils.generate_morphology(morphology_path.encode('ascii', 'ignore').decode("utf-8")) + morphology_path = manifest.get_path("MORPHOLOGY") + utils.generate_morphology(morphology_path.encode("ascii", "ignore").decode("utf-8")) utils.load_cell_parameters() stim = h.IClamp(h.soma[0](0.5)) - stim.amp = 0.35 # Sweep 47 + stim.amp = 0.35 # Sweep 47 stim.delay = 1000.0 stim.dur = 1000.0 @@ -70,12 +71,12 @@ def test_biophysical_peri(): h.finitialize() h.run() - junction_potential = description.data['fitting'][0]['junction_potential'] + junction_potential = description.data["fitting"][0]["junction_potential"] ms = 1.0e-3 - output_data = (numpy.array(vec['v']) - junction_potential) # in mV - output_times = numpy.array(vec['t']) * ms # in s - output_path = 'output_voltage.dat' + output_data = numpy.array(vec["v"]) - junction_potential # in mV + output_times = numpy.array(vec["t"]) * ms # in s + output_path = "output_voltage.dat" DatUtilities.save_voltage(output_path, output_data, output_times) diff --git a/allensdk/test/model/test_biophysical_perisomatic.py b/allensdk/test/model/test_biophysical_perisomatic.py index c84fd669ff..86600ab605 100644 --- a/allensdk/test/model/test_biophysical_perisomatic.py +++ b/allensdk/test/model/test_biophysical_perisomatic.py @@ -42,26 +42,25 @@ from allensdk.api.queries.biophysical_api import BiophysicalApi -@pytest.mark.skipif(True, - reason="partial testing") +@pytest.mark.skipif(True, reason="partial testing") @pytest.mark.xfail def test_biophysical(): - neuronal_model_id = 472451419 # get this from the web site + neuronal_model_id = 472451419 # get this from the web site - model_directory = '.' + model_directory = "." - bp = BiophysicalApi('http://api.brain-map.org') + bp = BiophysicalApi("http://api.brain-map.org") bp.cache_stimulus = False # don't want to download the large stimulus NWB file bp.cache_data(neuronal_model_id, working_directory=model_directory) - os.system('nrnivmodl modfiles') + os.system("nrnivmodl modfiles") - description = Config().load('manifest.json') + description = Config().load("manifest.json") utils = Utils(description) h = utils.h manifest = description.manifest - morphology_path = manifest.get_path('MORPHOLOGY') - utils.generate_morphology(morphology_path.encode('ascii', 'ignore')) + morphology_path = manifest.get_path("MORPHOLOGY") + utils.generate_morphology(morphology_path.encode("ascii", "ignore")) utils.load_cell_parameters() stim = h.IClamp(h.soma[0](0.5)) @@ -76,15 +75,15 @@ def test_biophysical(): h.finitialize() h.run() - output_path = 'output_voltage.dat' + output_path = "output_voltage.dat" - junction_potential = description.data['fitting'][0]['junction_potential'] + junction_potential = description.data["fitting"][0]["junction_potential"] mV = 1.0e-3 ms = 1.0e-3 - output_data = (numpy.array(vec['v']) - junction_potential) * mV - output_times = numpy.array(vec['t']) * ms + output_data = (numpy.array(vec["v"]) - junction_potential) * mV + output_times = numpy.array(vec["t"]) * ms DatUtilities.save_voltage(output_path, output_data, output_times) - + assert numpy.count_nonzero(output_data) > 0 diff --git a/allensdk/test/model/test_glif.py b/allensdk/test/model/test_glif.py index 6dfd08f60a..9376b1faf7 100644 --- a/allensdk/test/model/test_glif.py +++ b/allensdk/test/model/test_glif.py @@ -55,8 +55,8 @@ def ephys_sweeps_file(fn_temp_dir): def glif_api(): endpoint = None - if 'TEST_API_ENDPOINT' in os.environ: - endpoint = os.environ['TEST_API_ENDPOINT'] + if "TEST_API_ENDPOINT" in os.environ: + endpoint = os.environ["TEST_API_ENDPOINT"] return GlifApi(endpoint) else: return GlifApi() @@ -70,8 +70,7 @@ def neuronal_model_id(): @pytest.fixture -def configured_glif_api(glif_api, neuronal_model_id, neuron_config_file, - ephys_sweeps_file): +def configured_glif_api(glif_api, neuronal_model_id, neuron_config_file, ephys_sweeps_file): glif_api.get_neuronal_model(neuronal_model_id) neuron_config = glif_api.get_neuron_config() @@ -87,18 +86,18 @@ def configured_glif_api(glif_api, neuronal_model_id, neuron_config_file, def output(neuron_config_file, ephys_sweeps_file): neuron_config = json_utilities.read(neuron_config_file) ephys_sweeps = json_utilities.read(ephys_sweeps_file) - ephys_file_name = 'stimulus.nwb' + ephys_file_name = "stimulus.nwb" # pull out the stimulus for the first sweep ephys_sweep = ephys_sweeps[0] ds = NwbDataSet(ephys_file_name) - data = ds.get_sweep(ephys_sweep['sweep_number']) - stimulus = data['stimulus'] + data = ds.get_sweep(ephys_sweep["sweep_number"]) + stimulus = data["stimulus"] # initialize the neuron # important! update the neuron's dt for your stimulus neuron = GlifNeuron.from_dict(neuron_config) - neuron.dt = 1.0 / data['sampling_rate'] + neuron.dt = 1.0 / data["sampling_rate"] # simulate the neuron truncate = 56041 @@ -110,13 +109,13 @@ def output(neuron_config_file, ephys_sweeps_file): @pytest.fixture def stimulus(neuron_config_file, ephys_sweeps_file): ephys_sweeps = json_utilities.read(ephys_sweeps_file) - ephys_file_name = 'stimulus.nwb' + ephys_file_name = "stimulus.nwb" # pull out the stimulus for the first sweep ephys_sweep = ephys_sweeps[0] ds = NwbDataSet(ephys_file_name) - data = ds.get_sweep(ephys_sweep['sweep_number']) - stimulus = data['stimulus'] + data = ds.get_sweep(ephys_sweep["sweep_number"]) + stimulus = data["stimulus"] return stimulus @@ -136,10 +135,15 @@ def test_run_glifneuron(configured_glif_api, neuron_config_file): # simulate the neuron output = neuron.run(stimulus) - expected_fields = {"AScurrents", "grid_spike_times", - "interpolated_spike_threshold", - "interpolated_spike_times", - "interpolated_spike_voltage", - "spike_time_steps", "threshold", "voltage"} + expected_fields = { + "AScurrents", + "grid_spike_times", + "interpolated_spike_threshold", + "interpolated_spike_times", + "interpolated_spike_voltage", + "spike_time_steps", + "threshold", + "voltage", + } assert expected_fields.difference(output.keys()) == set() diff --git a/allensdk/test/model/test_runner.py b/allensdk/test/model/test_runner.py index 4c93f1db94..2950e70541 100644 --- a/allensdk/test/model/test_runner.py +++ b/allensdk/test/model/test_runner.py @@ -1,14 +1,17 @@ import sys import subprocess + def test_args(): """ Test for legacy and newest biophysical model simulation calls """ # Legacy all-active simulation call pattern - args_legacy = subprocess.check_output([sys.executable, '-m', 'allensdk.test.model.check_parser', 'manifest.json']) - assert 'stub' not in args_legacy.decode('utf-8') + args_legacy = subprocess.check_output([sys.executable, "-m", "allensdk.test.model.check_parser", "manifest.json"]) + assert "stub" not in args_legacy.decode("utf-8") # Current all-active simulation call pattern - args_new = subprocess.check_output([sys.executable, '-m', 'allensdk.test.model.check_parser', 'manifest.json', '--axon_type', 'stub']) - assert 'stub' in args_new.decode('utf-8') + args_new = subprocess.check_output( + [sys.executable, "-m", "allensdk.test.model.check_parser", "manifest.json", "--axon_type", "stub"] + ) + assert "stub" in args_new.decode("utf-8") diff --git a/allensdk/test/mouse_connectivity/grid/test_base_subimage.py b/allensdk/test/mouse_connectivity/grid/test_base_subimage.py index 23c7898c96..7407cfab54 100644 --- a/allensdk/test/mouse_connectivity/grid/test_base_subimage.py +++ b/allensdk/test/mouse_connectivity/grid/test_base_subimage.py @@ -55,11 +55,7 @@ def polygon_params(): "in_dims": np.array([30000, 40000]), "in_spacing": np.array([0.35, 0.35]), "coarse_spacing": np.array([16.8, 16.8]), - "polygon_info": { - "hello_am_square": [ - [(8000, 8000), (16000, 8000), (16000, 16000), (8000, 16000)] - ] - }, + "polygon_info": {"hello_am_square": [[(8000, 8000), (16000, 8000), (16000, 16000), (8000, 16000)]]}, } @@ -68,14 +64,12 @@ def polygon_params(): def test_init_base(base_params): - si = SubImage(**base_params) assert np.allclose(si.coarse_dims, [625, 834]) def test_binarize(base_params): - si = SubImage(**base_params) si.images["fish"] = np.arange(25).reshape([5, 5]) @@ -89,7 +83,6 @@ def test_binarize(base_params): @pytest.mark.parametrize("positive", [(True), (False)]) def test_apply_mask(base_params, positive): - si = SubImage(**base_params) si.images["submarine"] = np.arange(25).reshape([5, 5]) @@ -106,7 +99,6 @@ def test_apply_mask(base_params, positive): def test_make_pixel_counter(base_params): - si = SubImage(**base_params) reducer = si.make_pixel_counter() @@ -131,7 +123,6 @@ def test_init_segmentation(segmentation_params): def test_extract_signal_from_segmentation(segmentation_params): - si = SegmentationSubImage(**segmentation_params) segmentation_name = "fish" @@ -145,7 +136,6 @@ def test_extract_signal_from_segmentation(segmentation_params): def test_extract_injection_from_segmentation(segmentation_params): - si = SegmentationSubImage(**segmentation_params) segmentation_name = "fish" @@ -162,7 +152,6 @@ def test_extract_injection_from_segmentation(segmentation_params): def test_read_segmentation_image(segmentation_params): - si = SegmentationSubImage(**segmentation_params) arr = np.zeros((32, 32)) arr[:16, :] = 1 @@ -170,8 +159,7 @@ def test_read_segmentation_image(segmentation_params): exp = np.array([[1, 1], [0, 0]]) with mock.patch( - "allensdk.mouse_connectivity.grid.utilities.image_utilities" - ".read_segmentation_image", + "allensdk.mouse_connectivity.grid.utilities.image_utilities.read_segmentation_image", return_value=arr, ) as p: si.read_segmentation_image("name_one") @@ -185,7 +173,6 @@ def test_read_segmentation_image(segmentation_params): def test_init_intensity(intensity_params): - si = IntensitySubImage(**intensity_params) assert si.intensity_paths["name_one"]["path"] == "path_one" @@ -193,22 +180,19 @@ def test_init_intensity(intensity_params): def test_get_intensity(intensity_params): - arr = np.eye(1000) arr[999, 0] = 1 si = IntensitySubImage(**intensity_params) with mock.patch( - "allensdk.mouse_connectivity.grid.subimage.base_subimage" - ".IntensitySubImage.required_intensities", + "allensdk.mouse_connectivity.grid.subimage.base_subimage.IntensitySubImage.required_intensities", new_callable=mock.PropertyMock, ) as a: a.return_value = ["name_one"] with mock.patch( - "allensdk.mouse_connectivity.grid.utilities.image_utilities" - ".read_intensity_image", + "allensdk.mouse_connectivity.grid.utilities.image_utilities.read_intensity_image", return_value=arr, ) as p: si.get_intensity() @@ -222,27 +206,22 @@ def test_get_intensity(intensity_params): def test_init_polygon(polygon_params): - si = PolygonSubImage(**polygon_params) assert np.allclose( si.polygon_info["hello_am_square"], - np.array( - [[(8000, 8000), (16000, 8000), (16000, 16000), (8000, 16000)]] - ), + np.array([[(8000, 8000), (16000, 8000), (16000, 16000), (8000, 16000)]]), ) def test_get_polygons(polygon_params): - si = PolygonSubImage(**polygon_params) arr = np.zeros([1875, 2500]) arr[500:1001, 500:1001] = 1 with mock.patch( - "allensdk.mouse_connectivity.grid.subimage.base_subimage" - ".PolygonSubImage.required_polys", + "allensdk.mouse_connectivity.grid.subimage.base_subimage.PolygonSubImage.required_polys", new_callable=mock.PropertyMock, ) as a: a.return_value = ["hello_am_square"] diff --git a/allensdk/test/mouse_connectivity/grid/test_cav_subimage.py b/allensdk/test/mouse_connectivity/grid/test_cav_subimage.py index ad972b2e60..58fa4ba060 100644 --- a/allensdk/test/mouse_connectivity/grid/test_cav_subimage.py +++ b/allensdk/test/mouse_connectivity/grid/test_cav_subimage.py @@ -5,24 +5,25 @@ from allensdk.mouse_connectivity.grid.subimage import CavSubImage -#============================================================================== -#============================================================================== +# ============================================================================== +# ============================================================================== -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def cav_params(): - return {'reduce_level': 4, - 'in_dims': np.array([30000, 40000]), - 'in_spacing': np.array([0.35, 0.35]), - 'coarse_spacing': np.array([16.8, 16.8]), - 'polygon_info': {'missing_tile': [[(8000, 8000), (16000, 8000), - (16000, 16000), (8000, 16000)]], - 'cav_tracer': [(4000, 4000), (8000, 4000), - (8000, 8000), (4000, 8000)]}} + return { + "reduce_level": 4, + "in_dims": np.array([30000, 40000]), + "in_spacing": np.array([0.35, 0.35]), + "coarse_spacing": np.array([16.8, 16.8]), + "polygon_info": { + "missing_tile": [[(8000, 8000), (16000, 8000), (16000, 16000), (8000, 16000)]], + "cav_tracer": [(4000, 4000), (8000, 4000), (8000, 8000), (4000, 8000)], + }, + } def test_compute_coarse_planes(cav_params): - mt = np.ones([1875, 2500]) mt[:300, :] = 0 @@ -30,8 +31,8 @@ def test_compute_coarse_planes(cav_params): ct[:, :300] = 1 si = CavSubImage(**cav_params) - si.images['cav_tracer'] = ct - si.images['missing_tile'] = mt + si.images["cav_tracer"] = ct + si.images["missing_tile"] = mt si.compute_coarse_planes() @@ -42,5 +43,5 @@ def test_compute_coarse_planes(cav_params): sum_pixels_expected[:100, :] = 144 sum_pixels_expected[:100, -1] = 48 - assert(np.allclose( sum_pixels_expected * 2, si.accumulators['sum_pixels'] )) - assert(np.allclose( cav_tracer_expected * 2, si.accumulators['cav_tracer'] )) + assert np.allclose(sum_pixels_expected * 2, si.accumulators["sum_pixels"]) + assert np.allclose(cav_tracer_expected * 2, si.accumulators["cav_tracer"]) diff --git a/allensdk/test/mouse_connectivity/grid/test_classic_subimage.py b/allensdk/test/mouse_connectivity/grid/test_classic_subimage.py index a1722f8358..559029ed37 100644 --- a/allensdk/test/mouse_connectivity/grid/test_classic_subimage.py +++ b/allensdk/test/mouse_connectivity/grid/test_classic_subimage.py @@ -5,41 +5,42 @@ from allensdk.mouse_connectivity.grid.subimage import ClassicSubImage -#============================================================================== -#============================================================================== +# ============================================================================== +# ============================================================================== -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def classic_params(): - return {'reduce_level': 4, - 'in_dims': np.array([30000, 40000]), - 'in_spacing': np.array([0.35, 0.35]), - 'coarse_spacing': np.array([16.8, 16.8]), - 'segmentation_paths': {'segmentation': '/path/to_segmentation'}, - 'intensity_paths': {'green': {'path': '/path/to/intensity', 'channel': 1}}, - 'polygon_info': {'missing_tile': [[(8000, 8000), (16000, 8000), - (16000, 16000), (8000, 16000)]], - 'no_signal': [(4000, 4000), (8000, 4000), - (8000, 8000), (4000, 8000)]}} - - -@pytest.fixture(scope='function') + return { + "reduce_level": 4, + "in_dims": np.array([30000, 40000]), + "in_spacing": np.array([0.35, 0.35]), + "coarse_spacing": np.array([16.8, 16.8]), + "segmentation_paths": {"segmentation": "/path/to_segmentation"}, + "intensity_paths": {"green": {"path": "/path/to/intensity", "channel": 1}}, + "polygon_info": { + "missing_tile": [[(8000, 8000), (16000, 8000), (16000, 16000), (8000, 16000)]], + "no_signal": [(4000, 4000), (8000, 4000), (8000, 8000), (4000, 8000)], + }, + } + + +@pytest.fixture(scope="function") def missing_tile(): image = np.zeros([1875, 2500], dtype=np.uint8) image[500:1000, 500:1000] = 1 return image -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def no_signal(): image = np.zeros([1875, 2500], dtype=np.uint8) image[250:500, 250:500] = 1 return image -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def segmentation_image(): - image = np.zeros([1875, 2500], dtype=np.uint8) image[:1000, :] += 1 image[:, :1000] += 128 @@ -48,9 +49,8 @@ def segmentation_image(): return image -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def projection(): - image = np.zeros([1875, 2500], dtype=np.uint8) image[:, :1000] = 1 @@ -61,9 +61,8 @@ def projection(): return image -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def injection(): - image = np.zeros([1875, 2500], dtype=np.uint8) image[:1000, :] = 1 @@ -73,36 +72,32 @@ def injection(): return image -#============================================================================== -#============================================================================== +# ============================================================================== +# ============================================================================== def test_init_classic(classic_params): - si = ClassicSubImage(**classic_params) - assert( hasattr(si, 'intensity_paths') ) - assert( hasattr(si, 'segmentation_paths') ) - assert( hasattr(si, 'polygon_info') ) + assert hasattr(si, "intensity_paths") + assert hasattr(si, "segmentation_paths") + assert hasattr(si, "polygon_info") -def test_process_segmentation(classic_params, segmentation_image, - missing_tile, no_signal, projection, injection): - +def test_process_segmentation(classic_params, segmentation_image, missing_tile, no_signal, projection, injection): si = ClassicSubImage(**classic_params) - si.images['segmentation'] = segmentation_image - si.images['missing_tile'] = missing_tile - si.images['no_signal'] = no_signal + si.images["segmentation"] = segmentation_image + si.images["missing_tile"] = missing_tile + si.images["no_signal"] = no_signal si.process_segmentation() - assert(np.allclose( projection, si.images['projection'] )) - assert(np.allclose( injection, si.images['injection'] )) - assert( 'segmentation' not in si.images ) - + assert np.allclose(projection, si.images["projection"]) + assert np.allclose(injection, si.images["injection"]) + assert "segmentation" not in si.images -def test_compute_intensity(classic_params): +def test_compute_intensity(classic_params): pr = np.zeros([1875, 2500]) pr[:600, :] = 1 @@ -110,9 +105,9 @@ def test_compute_intensity(classic_params): ij[:, :600] = 1 si = ClassicSubImage(**classic_params) - si.images['green'] = np.ones([1875, 2500]) * 2 - si.images['projection'] = pr - si.images['injection'] = ij + si.images["green"] = np.ones([1875, 2500]) * 2 + si.images["projection"] = pr + si.images["injection"] = ij si.compute_intensity() @@ -127,16 +122,15 @@ def test_compute_intensity(classic_params): isppi_expected = np.zeros_like(spi_expected) isppi_expected[:200, :200] = spi_expected[:200, :200] - - assert(np.allclose( spi_expected * 2, si.accumulators['sum_pixel_intensities'] )) - assert(np.allclose( ispi_expected * 2, si.accumulators['injection_sum_pixel_intensities'] )) - assert(np.allclose( sppi_expected * 2, si.accumulators['sum_projecting_pixel_intensities'] )) - assert(np.allclose( isppi_expected * 2, si.accumulators['injectionsum_projecting_pixel_intensities'] )) - assert( 'intensity' not in si.images ) + assert np.allclose(spi_expected * 2, si.accumulators["sum_pixel_intensities"]) + assert np.allclose(ispi_expected * 2, si.accumulators["injection_sum_pixel_intensities"]) + assert np.allclose(sppi_expected * 2, si.accumulators["sum_projecting_pixel_intensities"]) + assert np.allclose(isppi_expected * 2, si.accumulators["injectionsum_projecting_pixel_intensities"]) + assert "intensity" not in si.images -def test_compute_injection(classic_params): +def test_compute_injection(classic_params): pr = np.zeros([1875, 2500]) pr[:600, :] = 1 @@ -144,8 +138,8 @@ def test_compute_injection(classic_params): ij[:, :600] = 1 si = ClassicSubImage(**classic_params) - si.images['projection'] = pr - si.images['injection'] = ij + si.images["projection"] = pr + si.images["injection"] = ij si.compute_injection() @@ -155,53 +149,50 @@ def test_compute_injection(classic_params): ispp_expected = np.zeros([625, 834]) ispp_expected[:200, :200] = 144 - assert(np.allclose( isp_expected * 2, si.accumulators['injection_sum_pixels'] )) - assert(np.allclose( ispp_expected * 2, si.accumulators['injection_sum_projecting_pixels'] )) - assert( 'injection' not in si.images ) + assert np.allclose(isp_expected * 2, si.accumulators["injection_sum_pixels"]) + assert np.allclose(ispp_expected * 2, si.accumulators["injection_sum_projecting_pixels"]) + assert "injection" not in si.images def test_compute_projection(classic_params): - pr = np.zeros([1875, 2500]) pr[:600, :] = 1 si = ClassicSubImage(**classic_params) - si.images['projection'] = pr + si.images["projection"] = pr si.compute_projection() spp_expected = np.zeros([625, 834]) spp_expected[:200, :] = 144 spp_expected[:200, -1] = 48 - - assert(np.allclose( spp_expected * 2, si.accumulators['sum_projecting_pixels'] )) - assert( 'projection' not in si.images ) + assert np.allclose(spp_expected * 2, si.accumulators["sum_projecting_pixels"]) + assert "projection" not in si.images -def test_compute_sum_pixels_aav(classic_params): +def test_compute_sum_pixels_aav(classic_params): mt = np.zeros([1875, 2500]) mt[:300, :300] = 1 - + aav = np.zeros([1875, 2500]) aav[300:600, :] = 1 si = ClassicSubImage(**classic_params) - si.images['missing_tile'] = mt - si.images['aav_exclusion'] = aav + si.images["missing_tile"] = mt + si.images["aav_exclusion"] = aav si.compute_sum_pixels() sp_expected = np.zeros([625, 834]) + 144 sp_expected[:100, :100] = 0 -# sp_expected[100, :101] = 48 -# sp_expected[:101, 100] = 48 + # sp_expected[100, :101] = 48 + # sp_expected[:101, 100] = 48 sp_expected[:, -1] = 48 - + aesp_expected = np.zeros([625, 834]) aesp_expected[100:200, :] = 144 aesp_expected[100:200, -1] = 48 - assert(np.allclose( aesp_expected * 2, si.accumulators['aav_exclusion_sum_pixels'] )) - assert(np.allclose( sp_expected * 2, si.accumulators['sum_pixels'] )) - + assert np.allclose(aesp_expected * 2, si.accumulators["aav_exclusion_sum_pixels"]) + assert np.allclose(sp_expected * 2, si.accumulators["sum_pixels"]) diff --git a/allensdk/test/mouse_connectivity/grid/test_image_series_gridder.py b/allensdk/test/mouse_connectivity/grid/test_image_series_gridder.py index 8a7afac80f..7e7d559995 100644 --- a/allensdk/test/mouse_connectivity/grid/test_image_series_gridder.py +++ b/allensdk/test/mouse_connectivity/grid/test_image_series_gridder.py @@ -8,183 +8,195 @@ def small_gridder(): - in_dims = [12353, 16471, 140] in_spacing = [0.85, 0.85, 100.0] out_dims = [1320, 800, 1140] out_spacing = [10.0, 10.0, 10.0] reduce_level = 0 - - subimages = [{'index': 12, - 'segmentation_path': '/path/to/projection_12.jp2', - 'intensity_path': '/path/to/image_12.jp2', - 'polygon_info': {'missing_tile': [], - 'no_signal': [], - 'aav_exclusion': []}}] - - subimage_kwargs = {'cls': dict, 'channel': 1} - + + subimages = [ + { + "index": 12, + "segmentation_path": "/path/to/projection_12.jp2", + "intensity_path": "/path/to/image_12.jp2", + "polygon_info": {"missing_tile": [], "no_signal": [], "aav_exclusion": []}, + } + ] + + subimage_kwargs = {"cls": dict, "channel": 1} + nprocesses = 8 - - affine_params = [2, 0 , 0, 0, 2, 0, 0, 0, 2, 1, 1, 1] - dfmfld_path = '/path/to/deformation_field_header.mhd' + affine_params = [2, 0, 0, 0, 2, 0, 0, 0, 2, 1, 1, 1] + dfmfld_path = "/path/to/deformation_field_header.mhd" + + return ImageSeriesGridder( + in_dims, + in_spacing, + out_dims, + out_spacing, + reduce_level, + subimages, + subimage_kwargs, + nprocesses, + affine_params, + dfmfld_path, + ) - return ImageSeriesGridder(in_dims, in_spacing, out_dims, out_spacing, - reduce_level, subimages, subimage_kwargs, - nprocesses, affine_params, dfmfld_path) - def large_gridder(): - in_dims = [30000, 40000, 140] in_spacing = [0.35, 0.35, 100.0] out_dims = [1320, 800, 1140] out_spacing = [10.0, 10.0, 10.0] reduce_level = 1 - - subimages = [{'index': 12, - 'segmentation_path': '/path/to/projection_12.jp2', - 'intensity_path': '/path/to/image_12.jp2', - 'polygon_info': {'missing_tile': [], - 'no_signal': [], - 'aav_exclusion': []}}] - - subimage_kwargs = {'cls': dict, 'channel': 1} - + + subimages = [ + { + "index": 12, + "segmentation_path": "/path/to/projection_12.jp2", + "intensity_path": "/path/to/image_12.jp2", + "polygon_info": {"missing_tile": [], "no_signal": [], "aav_exclusion": []}, + } + ] + + subimage_kwargs = {"cls": dict, "channel": 1} + nprocesses = 8 - - affine_params = [2, 0 , 0, 0, 2, 0, 0, 0, 2, 1, 1, 1] - dfmfld_path = '/path/to/deformation_field_header.mhd' - - - return ImageSeriesGridder(in_dims, in_spacing, out_dims, out_spacing, - reduce_level, subimages, subimage_kwargs, - nprocesses, affine_params, dfmfld_path) - - -@pytest.mark.parametrize('gridder_fn,cgd,cgs,cgr', [(small_gridder, [951, 1267, 140], [11.05, 11.05, 100], 6), - (large_gridder, [1000, 1334, 140], [10.5, 10.5, 100], 7)]) -def test_set_coarse_grid_parameters(gridder_fn, cgd, cgs, cgr): + affine_params = [2, 0, 0, 0, 2, 0, 0, 0, 2, 1, 1, 1] + dfmfld_path = "/path/to/deformation_field_header.mhd" + + return ImageSeriesGridder( + in_dims, + in_spacing, + out_dims, + out_spacing, + reduce_level, + subimages, + subimage_kwargs, + nprocesses, + affine_params, + dfmfld_path, + ) + + +@pytest.mark.parametrize( + "gridder_fn,cgd,cgs,cgr", + [ + (small_gridder, [951, 1267, 140], [11.05, 11.05, 100], 6), + (large_gridder, [1000, 1334, 140], [10.5, 10.5, 100], 7), + ], +) +def test_set_coarse_grid_parameters(gridder_fn, cgd, cgs, cgr): gridder = gridder_fn() gridder.set_coarse_grid_parameters() - - assert(np.allclose( gridder.coarse_dims, cgd )) - assert(np.allclose( gridder.coarse_spacing, cgs )) - assert(np.allclose( gridder.coarse_grid_radius, cgr )) + assert np.allclose(gridder.coarse_dims, cgd) + assert np.allclose(gridder.coarse_spacing, cgs) + assert np.allclose(gridder.coarse_grid_radius, cgr) -@pytest.mark.parametrize('gridder_fn,reduce_level', [(small_gridder, 0), (large_gridder, 1)]) -def test_setup_subimages(gridder_fn, reduce_level): +@pytest.mark.parametrize("gridder_fn,reduce_level", [(small_gridder, 0), (large_gridder, 1)]) +def test_setup_subimages(gridder_fn, reduce_level): gridder = gridder_fn() gridder.setup_subimages() - + for s in gridder.subimages: - assert( s['reduce_level'] == reduce_level ) + assert s["reduce_level"] == reduce_level gridder = gridder_fn() - + def test_initialize_coarse_volume(): - - key = 'amethystine' + key = "amethystine" size = [1, 2, 3] spacing = [4, 5, 6] - + gridder = small_gridder() gridder.coarse_dims = size gridder.coarse_spacing = spacing - + gridder.initialize_coarse_volume(key, sitk.sitkFloat32) - - assert(np.allclose( gridder.volumes[key].GetSize(), size )) - assert(np.allclose( gridder.volumes[key].GetSpacing(), spacing )) - - + + assert np.allclose(gridder.volumes[key].GetSize(), size) + assert np.allclose(gridder.volumes[key].GetSpacing(), spacing) + + def test_paste_slice(): - - key = 'halmahera' + key = "halmahera" slice_array = np.eye(1000) index = 12 - + volume = sitk.Image(1000, 1000, 140, sitk.sitkFloat32) volume.SetSpacing([1, 1, 100]) - + gridder = small_gridder() gridder.coarse_spacing = [1, 1, 100] gridder.volumes[key] = volume - + gridder.paste_slice(key, index, slice_array) - + obt = sitk.GetArrayFromImage(gridder.volumes[key]) - assert(np.allclose( obt[12, :, :], slice_array )) - - + assert np.allclose(obt[12, :, :], slice_array) + + def test_paste_subimage(): - index = 1 - output = {'a': 1, 'b': 2} - + output = {"a": 1, "b": 2} + gridder = small_gridder() gridder.paste_slice = mock.MagicMock() - + gridder.paste_subimage(index, output) - - assert( len(gridder.paste_slice.mock_calls) == 2 ) - assert( output['a'] is None and output['b'] is None ) - - + + assert len(gridder.paste_slice.mock_calls) == 2 + assert output["a"] is None and output["b"] is None + + def test_build_coarse_grids(): # mp is hard to test class Dummy(object): - def __init__(self, *a, **k): pass - + def imap_unordered(*a, **k): for ii in range(20): yield ii, ii - with mock.patch('multiprocessing.Pool', new=Dummy): - + with mock.patch("multiprocessing.Pool", new=Dummy): gridder = small_gridder() gridder.paste_subimage = mock.MagicMock() - + gridder.build_coarse_grids() - + for ii in range(20): - assert( mock.call(ii, ii) in gridder.paste_subimage.mock_calls ) - - + assert mock.call(ii, ii) in gridder.paste_subimage.mock_calls + + def test_resample_volume(): - - def make_dfield(*a , **k): + def make_dfield(*a, **k): return - + def make_transform(*a, **k): return sitk.TranslationTransform(3, [2, 2, 2]) - - key = 'green_tree' - + + key = "green_tree" + volume = sitk.Image(10, 10, 10, sitk.sitkFloat32) volume.SetSpacing([1, 1, 1]) volume += 1 - - with mock.patch('SimpleITK.ReadImage', new=make_dfield): + + with mock.patch("SimpleITK.ReadImage", new=make_dfield): with mock.patch( - 'allensdk.mouse_connectivity.grid.utilities.image_utilities.build_composite_transform', - new=make_transform + "allensdk.mouse_connectivity.grid.utilities.image_utilities.build_composite_transform", new=make_transform ): - gridder = small_gridder() gridder.out_dims = [10, 10, 10] gridder.out_spacing = [1, 1, 1] - - gridder.volumes[key] = volume + + gridder.volumes[key] = volume gridder.resample_volume(key) - + arr = sitk.GetArrayFromImage(gridder.volumes[key]) - assert( arr.sum() == 8**3 ) - + assert arr.sum() == 8**3 diff --git a/allensdk/test/mouse_connectivity/grid/test_image_utilities.py b/allensdk/test/mouse_connectivity/grid/test_image_utilities.py index 0ab94d033c..0389d053d2 100644 --- a/allensdk/test/mouse_connectivity/grid/test_image_utilities.py +++ b/allensdk/test/mouse_connectivity/grid/test_image_utilities.py @@ -6,7 +6,6 @@ @pytest.fixture(scope="function") def dfmfld(): - disp = sitk.Image(10, 10, 10, sitk.sitkVectorFloat64) disp.SetSpacing([1, 1, 1]) @@ -21,7 +20,6 @@ def aff_params(): def test_set_image_spacing(): - im = sitk.Image(5, 5, 5, sitk.sitkFloat32) iu.set_image_spacing(im, [1, 2, 3]) @@ -31,7 +29,6 @@ def test_set_image_spacing(): def test_new_image_3d(): - im = iu.new_image([300, 200, 100], [1, 2, 3], sitk.sitkFloat32) assert np.allclose(im.GetSize(), [300, 200, 100]) @@ -40,7 +37,6 @@ def test_new_image_3d(): def test_new_image_2d(): - im = iu.new_image([300, 200], [1, 2], sitk.sitkFloat32) assert np.allclose(im.GetSize(), [300, 200]) @@ -50,7 +46,6 @@ def test_new_image_2d(): @pytest.mark.parametrize("np_type,sitk_type", [(np.float32, sitk.sitkFloat32)]) def test_np_sitk_convert(np_type, sitk_type): - arr = np.zeros((100, 100), dtype=np_type) sitk_obt = iu.np_sitk_convert(arr.dtype) @@ -61,7 +56,6 @@ def test_np_sitk_convert(np_type, sitk_type): def test_compute_coarse_parameters(): - in_dims = [1000, 2000] in_spacing = [5, 5] out_spacing = [100, 100] @@ -71,9 +65,7 @@ def test_compute_coarse_parameters(): cgs_exp = [100, 100] cgr_exp = [2, 2] - cgd_obt, cgs_obt, cgr_obt = iu.compute_coarse_parameters( - in_dims, in_spacing, out_spacing, reduce_level - ) + cgd_obt, cgs_obt, cgr_obt = iu.compute_coarse_parameters(in_dims, in_spacing, out_spacing, reduce_level) assert np.allclose(cgd_obt, cgd_exp) assert np.allclose(cgs_obt, cgs_exp) @@ -81,7 +73,6 @@ def test_compute_coarse_parameters(): def test_block_apply(): - row_blocks = [(ii, jj) for ii, jj in zip(range(0, 10, 2), range(2, 12, 2))] col_blocks = [(ii, jj) for ii, jj in zip(range(0, 10, 5), range(5, 15, 5))] blocks = [row_blocks, col_blocks] @@ -96,7 +87,6 @@ def test_block_apply(): def test_grid_image_blocks(): - in_shape = [10, 10] in_spacing = [5, 5] out_spacing = [20, 20] @@ -111,7 +101,6 @@ def test_grid_image_blocks(): def test_rasterize_polygons(): - shape = [10, 10] scale = [1, 1] points_list = [[(4, 4), (6, 4), (6, 6), (4, 6)]] @@ -124,7 +113,6 @@ def test_rasterize_polygons(): def test_resample_into_volume(): - vol = sitk.Image(20, 20, 10, sitk.sitkFloat32) vol.SetSpacing([10, 10, 10]) @@ -141,7 +129,6 @@ def test_resample_into_volume(): def test_build_affine_transform(aff_params): - point = (1, 2, 3) tf = iu.build_affine_transform(aff_params) @@ -152,7 +139,6 @@ def test_build_affine_transform(aff_params): def test_build_composite_transform(dfmfld, aff_params): - point = (5, 5, 5) exp = (15, 15, 15) @@ -163,7 +149,6 @@ def test_build_composite_transform(dfmfld, aff_params): def test_resample_volume(): - volume = np.ones((10, 10, 10)).astype(np.float32) dims = [10, 10, 10] spacing = [1, 1, 1] diff --git a/allensdk/test/test_argschema_utilities.py b/allensdk/test/test_argschema_utilities.py index b2455df7cb..b3d8934b12 100644 --- a/allensdk/test/test_argschema_utilities.py +++ b/allensdk/test/test_argschema_utilities.py @@ -9,7 +9,8 @@ OutputFile, RaisingSchema, check_write_access, - check_write_access_overwrite) + check_write_access_overwrite, +) from marshmallow import Schema, ValidationError READ_ONLY = stat.S_IREAD | stat.S_IRGRP | stat.S_IROTH @@ -102,7 +103,6 @@ def teardown(self): ], ) def test_check_write_access(tmpdir_factory, harness_cls, fn, raises): - base_dir = str(tmpdir_factory.mktemp("HW")) harness = harness_cls(base_dir) @@ -126,13 +126,15 @@ class GenericOutputSchema(RaisingSchema): class TestInputFile(object): - def setup_method(self): self.parser = GenericInputSchema() - @pytest.mark.parametrize("input_data", [ - ({"input_file": "/some/invalid_filepath/input.h5"}), - ]) + @pytest.mark.parametrize( + "input_data", + [ + ({"input_file": "/some/invalid_filepath/input.h5"}), + ], + ) def test_invalid_input_file(self, input_data): with pytest.raises(ValidationError, match=r"No such file or directory"): self.parser.load(input_data) @@ -147,19 +149,21 @@ def test_valid_input_file(self, tmpdir): class TestOutputFile(object): - def setup_method(self): self.parser = GenericOutputSchema() - @pytest.mark.parametrize("output_data", [ - ({"output_file": "////invalid_filepath/output.json"}), - ]) + @pytest.mark.parametrize( + "output_data", + [ + ({"output_file": "////invalid_filepath/output.json"}), + ], + ) def test_invalid_output_file(self, output_data): # Apparently allensdk.brain_observatory.argschema_utilities tests are # skipped on Windows systems and the `check_write_access_overwrite` # function itself does not work correctly on Windows systems. # TODO: This is a stopgap for now - if os.name == 'nt': + if os.name == "nt": pytest.skip() # This test was failing on Bamboo because it was run in a container # as root (which means pretty much anything is writable). If this test diff --git a/allensdk/test/test_deprecated.py b/allensdk/test/test_deprecated.py index e226ae88f6..8a2b08c494 100644 --- a/allensdk/test/test_deprecated.py +++ b/allensdk/test/test_deprecated.py @@ -40,22 +40,20 @@ @pytest.fixture def deprecated_method(): - @deprecated() def i_am_deprecated(): pass return i_am_deprecated - - + + @pytest.fixture def deprecated_class(): - - @class_deprecated('msg') + @class_deprecated("msg") class dep_cls(object): def __init__(self, a): self.a = a - + return dep_cls @@ -63,23 +61,23 @@ def test_deprecated(deprecated_method): expected = "Function i_am_deprecated is deprecated. " with warnings.catch_warnings(record=True) as c: - warnings.simplefilter('always') + warnings.simplefilter("always") deprecated_method() print(expected) print(str(c[-1].message)) assert expected == str(c[-1].message) - - + + def test_deprecated_class(deprecated_class, deprecated_method): - expected = 'Class dep_cls is deprecated. msg' - + expected = "Class dep_cls is deprecated. msg" + with warnings.catch_warnings(record=True) as c: - warnings.simplefilter('always') + warnings.simplefilter("always") deprecated_method() obj = deprecated_class(1) - assert( expected == str(c[-1].message) ) - assert( obj.a == 1 ) + assert expected == str(c[-1].message) + assert obj.a == 1 diff --git a/allensdk/test/test_inline_examples.py b/allensdk/test/test_inline_examples.py index d20a2b11bb..421bb3391a 100644 --- a/allensdk/test/test_inline_examples.py +++ b/allensdk/test/test_inline_examples.py @@ -4,20 +4,12 @@ import pytest -EXAMPLE_DIR = os.path.join( - os.path.dirname(__file__), - '..', - '..', - 'doc_template', - 'examples_root', - 'examples' -) -EXAMPLES = [filename for filename in os.listdir(EXAMPLE_DIR) if filename.split('.')[-1] == 'py'] +EXAMPLE_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "doc_template", "examples_root", "examples") +EXAMPLES = [filename for filename in os.listdir(EXAMPLE_DIR) if filename.split(".")[-1] == "py"] @pytest.mark.nightly -@pytest.mark.parametrize('script_name', EXAMPLES) +@pytest.mark.parametrize("script_name", EXAMPLES) def test_inline_examples(script_name, tmpdir_factory): - - data_dir = tmpdir_factory.mktemp('inline_examples_data') - sp.check_call(['python', os.path.join(EXAMPLE_DIR, script_name)], cwd=str(data_dir)) \ No newline at end of file + data_dir = tmpdir_factory.mktemp("inline_examples_data") + sp.check_call(["python", os.path.join(EXAMPLE_DIR, script_name)], cwd=str(data_dir)) diff --git a/allensdk/test/test_lims_queries.py b/allensdk/test/test_lims_queries.py index 65f6359646..dc81641972 100644 --- a/allensdk/test/test_lims_queries.py +++ b/allensdk/test/test_lims_queries.py @@ -1,22 +1,18 @@ import pytest import pandas as pd -from allensdk.internal.api.queries.utils import ( - build_in_list_selector_query, - _sanitize_uuid_list, build_where_clause) +from allensdk.internal.api.queries.utils import build_in_list_selector_query, _sanitize_uuid_list, build_where_clause -from allensdk.internal.api.queries.behavior_lims_queries import ( - foraging_id_map_from_behavior_session_id) +from allensdk.internal.api.queries.behavior_lims_queries import foraging_id_map_from_behavior_session_id -from allensdk.internal.api.queries.ecephys_lims_queries import ( - donor_id_list_from_ecephys_session_ids) +from allensdk.internal.api.queries.ecephys_lims_queries import donor_id_list_from_ecephys_session_ids -from allensdk.test_utilities.custom_comparators import ( - WhitespaceStrippedString) +from allensdk.test_utilities.custom_comparators import WhitespaceStrippedString @pytest.mark.parametrize( - "col,valid_list,operator,valid,expected", [ + "col,valid_list,operator,valid,expected", + [ ("os.id", [1, 2, 3], "WHERE", True, "WHERE os.id IN (1,2,3)"), ("id2", ["'a'", "'b'"], "AND", True, "AND id2 IN ('a','b')"), ("id3", [1.0], "OR", True, "OR id3 IN (1.0)"), @@ -24,16 +20,11 @@ ("os.id", [1, 2, 3], "WHERE", False, "WHERE os.id NOT IN (1,2,3)"), ("id2", ["'a'", "'b'"], "AND", False, "AND id2 NOT IN ('a','b')"), ("id3", [1.0], "OR", False, "OR id3 NOT IN (1.0)"), - ("id4", None, "WHERE", False, "")] + ("id4", None, "WHERE", False, ""), + ], ) -def test_build_in_list_selector_query( - col, valid_list, operator, valid, expected): - assert (expected - == build_in_list_selector_query( - col=col, - valid_list=valid_list, - operator=operator, - valid=valid)) +def test_build_in_list_selector_query(col, valid_list, operator, valid, expected): + assert expected == build_in_list_selector_query(col=col, valid_list=valid_list, operator=operator, valid=valid) def test_build_in_selector_error(): @@ -41,23 +32,22 @@ def test_build_in_selector_error(): Test that build_in_list_selector_query raises the expected error for an invalid operator """ - with pytest.raises(ValueError, match='Operator must be'): - build_in_list_selector_query( - col='silly', - valid_list=[1, 2, 3], - operator='above', - valid=True) - - -@pytest.mark.parametrize('clauses, expected', [ - (['foo=b', 'baz=a'], 'WHERE foo=b AND baz=a'), - (['WHERE foo=b and baz=c'], 'WHERE foo=b and baz=c'), - (['foo=b', 'baz=a', 'bar=c'], 'WHERE foo=b AND baz=a AND bar=c'), - (['WHERE foo=b', 'baz=a'], 'WHERE foo=b AND baz=a'), - (['where foo=b', 'baz=a'], 'where foo=b AND baz=a'), - ([], ''), - (['foo=b'], 'WHERE foo=b') -]) + with pytest.raises(ValueError, match="Operator must be"): + build_in_list_selector_query(col="silly", valid_list=[1, 2, 3], operator="above", valid=True) + + +@pytest.mark.parametrize( + "clauses, expected", + [ + (["foo=b", "baz=a"], "WHERE foo=b AND baz=a"), + (["WHERE foo=b and baz=c"], "WHERE foo=b and baz=c"), + (["foo=b", "baz=a", "bar=c"], "WHERE foo=b AND baz=a AND bar=c"), + (["WHERE foo=b", "baz=a"], "WHERE foo=b AND baz=a"), + (["where foo=b", "baz=a"], "where foo=b AND baz=a"), + ([], ""), + (["foo=b"], "WHERE foo=b"), + ], +) def test_build_where_clause(clauses, expected): assert build_where_clause(clauses=clauses) == expected @@ -80,29 +70,32 @@ def query(self, _query): @pytest.mark.parametrize( - "behavior_session_ids,expected", [ - (None, + "behavior_session_ids,expected", + [ + ( + None, WhitespaceStrippedString(""" SELECT foraging_id, id as behavior_session_id FROM behavior_sessions WHERE foraging_id IS NOT NULL ; - """)), - (["'id1'", "'id2'"], + """), + ), + ( + ["'id1'", "'id2'"], WhitespaceStrippedString(""" SELECT foraging_id, id as behavior_session_id FROM behavior_sessions WHERE foraging_id IS NOT NULL AND id IN ('id1','id2'); - """)) - ] + """), + ), + ], ) -def test_foraging_id_map( - behavior_session_ids, expected): +def test_foraging_id_map(behavior_session_ids, expected): assert expected == foraging_id_map_from_behavior_session_id( - lims_engine=MockQueryEngine(), - logger=None, - behavior_session_ids=behavior_session_ids) + lims_engine=MockQueryEngine(), logger=None, behavior_session_ids=behavior_session_ids + ) def test_sanitize_uuid_list(): @@ -111,16 +104,15 @@ def test_sanitize_uuid_list(): """ input_list = [ - '12345678123456781234567812345678', - 'aaa', - '1234567812345678123456781234567812345678', - 'abcdefab-1234-abcd-0123-0123456789ab'] + "12345678123456781234567812345678", + "aaa", + "1234567812345678123456781234567812345678", + "abcdefab-1234-abcd-0123-0123456789ab", + ] - sanitized = _sanitize_uuid_list( - uuid_list=input_list) + sanitized = _sanitize_uuid_list(uuid_list=input_list) - assert sanitized == ['12345678123456781234567812345678', - 'abcdefab-1234-abcd-0123-0123456789ab'] + assert sanitized == ["12345678123456781234567812345678", "abcdefab-1234-abcd-0123-0123456789ab"] def test_donor_id_list_from_ecephys_session_ids(): @@ -133,20 +125,19 @@ def test_donor_id_list_from_ecephys_session_ids(): """ class DummyConnection(object): - def select(self, query=None): - data = [{'ecephys_session_id': 1, 'donor_id': 2}, - {'ecephys_session_id': 3, 'donor_id': 2}, - {'ecephys_session_id': 4, 'donor_id': 0}, - {'ecephys_session_id': 5, 'donor_id': 7}, - {'ecephys_session_id': 6, 'donor_id': 2}] + data = [ + {"ecephys_session_id": 1, "donor_id": 2}, + {"ecephys_session_id": 3, "donor_id": 2}, + {"ecephys_session_id": 4, "donor_id": 0}, + {"ecephys_session_id": 5, "donor_id": 7}, + {"ecephys_session_id": 6, "donor_id": 2}, + ] return pd.DataFrame(data=data) expected = [0, 2, 7] - actual = donor_id_list_from_ecephys_session_ids( - lims_connection=DummyConnection(), - session_id_list=[9, 9, 9]) + actual = donor_id_list_from_ecephys_session_ids(lims_connection=DummyConnection(), session_id_list=[9, 9, 9]) assert actual == expected diff --git a/allensdk/test/test_temp_dir.py b/allensdk/test/test_temp_dir.py index 595638f33c..37be3da1e9 100644 --- a/allensdk/test/test_temp_dir.py +++ b/allensdk/test/test_temp_dir.py @@ -9,16 +9,14 @@ def mock_request(): return MagicMock() -@pytest.mark.parametrize("ismount,base_path",[ - (True, os.path.normpath(os.path.join('/', 'dev', 'shm'))), - (False, os.path.dirname(temp_dir.__file__)) -]) -@patch("numpy.random.randint", side_effect=([1, 2, 3, 4, 5, 6], - [1, 2, 3, 4, 5, 7])) +@pytest.mark.parametrize( + "ismount,base_path", + [(True, os.path.normpath(os.path.join("/", "dev", "shm"))), (False, os.path.dirname(temp_dir.__file__))], +) +@patch("numpy.random.randint", side_effect=([1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 7])) @patch("os.listdir", return_value=["allensdk_test_123456"]) @patch("os.makedirs") -def test_tmp_dir(os_makedirs, os_listdir, randint, - mock_request, ismount, base_path): +def test_tmp_dir(os_makedirs, os_listdir, randint, mock_request, ismount, base_path): with patch("os.path.exists", return_value=True): with patch("os.path.ismount", return_value=ismount): path = temp_dir.temp_dir(mock_request) diff --git a/allensdk/test_utilities/__init__.py b/allensdk/test_utilities/__init__.py index 92ceaf67c3..8e51ec55db 100644 --- a/allensdk/test_utilities/__init__.py +++ b/allensdk/test_utilities/__init__.py @@ -32,4 +32,4 @@ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. -# \ No newline at end of file +# diff --git a/allensdk/test_utilities/custom_comparators.py b/allensdk/test_utilities/custom_comparators.py index 5eb3a22f15..503b2b5a8c 100644 --- a/allensdk/test_utilities/custom_comparators.py +++ b/allensdk/test_utilities/custom_comparators.py @@ -12,8 +12,8 @@ class WhitespaceStrippedString(object): matches the regex \\s, (which includes [ \\t\\n\\r\\f\\v], and other unicode whitespace characters). """ - def __init__(self, string: str, whitespace_chars: str = r"\s", - ASCII: bool = False): + + def __init__(self, string: str, whitespace_chars: str = r"\s", ASCII: bool = False): self.orig = string self.whitespace_chars = whitespace_chars self.flags = re.ASCII if ASCII else 0 @@ -22,15 +22,12 @@ def __init__(self, string: str, whitespace_chars: str = r"\s", def __eq__(self, other: Union[str, "WhitespaceStrippedString"]): if isinstance(other, str): - other = WhitespaceStrippedString( - other, self.whitespace_chars, self.flags) + other = WhitespaceStrippedString(other, self.whitespace_chars, self.flags) self.diff = list(self.differ.compare(self.value, other.value)) return self.value == other.value -def safe_df_comparison(expected: pd.DataFrame, - obtained: pd.DataFrame, - expect_identical_column_order: bool = False): +def safe_df_comparison(expected: pd.DataFrame, obtained: pd.DataFrame, expect_identical_column_order: bool = False): """ Compare two dataframes in a way that is agnostic to column order and datatype of NULL values @@ -60,7 +57,7 @@ def safe_df_comparison(expected: pd.DataFrame, - loops over non-null values, casts arrays into lists, and compares with == """ - msg = '' + msg = "" columns_match = True if not expect_identical_column_order: obtained_column_set = set(obtained.columns) @@ -72,11 +69,11 @@ def safe_df_comparison(expected: pd.DataFrame, columns_match = False if not columns_match: - msg += 'column mis-match\n' - msg += 'obtained columns\n' - msg += f'{obtained.columns}\n' - msg += 'expected columns\n' - msg += f'{expected.columns}\n' + msg += "column mis-match\n" + msg += "obtained columns\n" + msg += f"{obtained.columns}\n" + msg += "expected columns\n" + msg += f"{expected.columns}\n" missing_from_obtained = [] for c in expected.columns: @@ -86,30 +83,30 @@ def safe_df_comparison(expected: pd.DataFrame, for c in obtained.columns: if c not in expected.columns: missing_from_expected.append(c) - msg += f'missing from obtained\n{missing_from_obtained}\n' - msg += f'missing from expected\n{missing_from_expected}\n' + msg += f"missing from obtained\n{missing_from_obtained}\n" + msg += f"missing from expected\n{missing_from_expected}\n" raise RuntimeError(msg) if not expected.index.equals(obtained.index): - msg += 'index mis-match\n' - msg += 'expected index\n' - msg += f'{expected.index}\n' - msg += 'obtained index\n' - msg += f'{obtained.index}\n' + msg += "index mis-match\n" + msg += "expected index\n" + msg += f"{expected.index}\n" + msg += "obtained index\n" + msg += f"{obtained.index}\n" raise RuntimeError(msg) for col in expected.columns: expected_null = expected[col].isnull() obtained_null = obtained[col].isnull() if not expected_null.equals(obtained_null): - msg += f'\n{col} not null at same point in ' - msg += 'obtained and expected\n' + msg += f"\n{col} not null at same point in " + msg += "obtained and expected\n" continue expected_valid = expected[~expected_null] obtained_valid = obtained[~obtained_null] if not expected_valid.index.equals(obtained_valid.index): - msg += '\nindex mismatch in non-null when checking ' - msg += f'{col}\n' + msg += "\nindex mismatch in non-null when checking " + msg += f"{col}\n" for index_val in expected_valid.index.values: e = expected_valid.loc[index_val, col] o = obtained_valid.loc[index_val, col] @@ -118,16 +115,14 @@ def safe_df_comparison(expected: pd.DataFrame, if isinstance(o, pd.Series): o = list(o) if not e == o: - msg += f'\n{col}\n' - msg += f'expected: {e}\n' - msg += f'obtained: {o}\n' - if msg != '': + msg += f"\n{col}\n" + msg += f"expected: {e}\n" + msg += f"obtained: {o}\n" + if msg != "": raise RuntimeError(msg) -def stimulus_pickle_equivalence( - data0: dict, - data1: dict) -> bool: +def stimulus_pickle_equivalence(data0: dict, data1: dict) -> bool: """ Compare two sets of data loaded from a stimulus pickle file. Return True if they are identical. @@ -136,9 +131,7 @@ def stimulus_pickle_equivalence( return _nested_dict_equivalence(data0, data1) -def _nested_scalar_equivalence( - val0: Any, - val1: Any) -> bool: +def _nested_scalar_equivalence(val0: Any, val1: Any) -> bool: """ Compare two scalars. Return True if the scalars are identical. @@ -163,9 +156,7 @@ def _nested_scalar_equivalence( return True -def _nested_iterable_equivalence( - list0: Iterable, - list1: Iterable) -> bool: +def _nested_iterable_equivalence(list0: Iterable, list1: Iterable) -> bool: """ Compare the contents of two iterables. Return True if they are identical. @@ -187,7 +178,7 @@ def _nested_iterable_equivalence( if isinstance(v0, dict): if not _nested_dict_equivalence(v0, v1): return False - elif hasattr(v0, '__len__'): + elif hasattr(v0, "__len__"): if not _nested_iterable_equivalence(v0, v1): return False else: @@ -197,9 +188,7 @@ def _nested_iterable_equivalence( return True -def _nested_dict_equivalence( - dict0: dict, - dict1: dict) -> bool: +def _nested_dict_equivalence(dict0: dict, dict1: dict) -> bool: """ Compare the contents of two dicts. Return True if the dicts are identical. @@ -223,7 +212,7 @@ def _nested_dict_equivalence( if isinstance(val0, dict): if not _nested_dict_equivalence(val0, val1): return False - elif hasattr(val0, '__len__'): + elif hasattr(val0, "__len__"): if not _nested_iterable_equivalence(val0, val1): return False else: diff --git a/allensdk/test_utilities/regression_fixture.py b/allensdk/test_utilities/regression_fixture.py index f9ec2ad57c..bc8856ebdc 100644 --- a/allensdk/test_utilities/regression_fixture.py +++ b/allensdk/test_utilities/regression_fixture.py @@ -4,13 +4,16 @@ import json from pathlib import Path -if 'TEST_SESSION_ANALYSIS_REGRESSION_DATA' in os.environ: - data_file = os.environ['TEST_SESSION_ANALYSIS_REGRESSION_DATA'] +if "TEST_SESSION_ANALYSIS_REGRESSION_DATA" in os.environ: + data_file = os.environ["TEST_SESSION_ANALYSIS_REGRESSION_DATA"] else: - data_file = str(Path(__file__).parent / '..' / 'test' / 'brain_observatory' / 'test_session_analysis_regression_data_list.json') + data_file = str( + Path(__file__).parent / ".." / "test" / "brain_observatory" / "test_session_analysis_regression_data_list.json" + ) + def get_list_of_path_dict(): pyversion = sys.version_info[0] logging.debug("loading " + data_file) - with open(data_file,'r') as f: - return [curr_fixture for curr_fixture in json.load(f) if curr_fixture['version'] == pyversion] + with open(data_file, "r") as f: + return [curr_fixture for curr_fixture in json.load(f) if curr_fixture["version"] == pyversion] diff --git a/allensdk/test_utilities/temp_dir.py b/allensdk/test_utilities/temp_dir.py index d738a7ff22..b59747fd8e 100644 --- a/allensdk/test_utilities/temp_dir.py +++ b/allensdk/test_utilities/temp_dir.py @@ -42,30 +42,27 @@ # Intended as a fixture. Used in conftest.py. def temp_dir(request): - - tmpfs = os.path.normpath(os.path.join('/', 'dev', 'shm')) + tmpfs = os.path.normpath(os.path.join("/", "dev", "shm")) # would like to check mount type, but that requires system calls if os.path.exists(tmpfs) and os.path.ismount(tmpfs): - base_path = tmpfs else: - base_path = os.path.dirname(__file__) fls = os.listdir(base_path) while True: - dname = ''.join(map(str, np.random.randint(0, 10, 6))) + dname = "".join(map(str, np.random.randint(0, 10, 6))) if "allensdk_test_{}".format(dname) not in fls: break - specific_path = os.path.join(base_path, 'allensdk_test_' + dname) + specific_path = os.path.join(base_path, "allensdk_test_" + dname) os.makedirs(specific_path) def fin(): shutil.rmtree(specific_path) if os.path.exists(specific_path): - warnings.warn('test dir {0} still exists!', UserWarning) + warnings.warn("test dir {0} still exists!", UserWarning) request.addfinalizer(fin) diff --git a/conftest.py b/conftest.py index 64a7908d5a..2ea97ee443 100644 --- a/conftest.py +++ b/conftest.py @@ -1,7 +1,8 @@ import os import matplotlib -matplotlib.use('agg') + +matplotlib.use("agg") import pytest # noqa: E402 from allensdk.test_utilities.temp_dir import temp_dir # noqa: E402 @@ -17,59 +18,59 @@ def md_temp_dir(request): def pytest_collection_modifyitems(config, items): - ''' A pytest magic function. This function is called post-collection and gives us a hook for modifying the + """A pytest magic function. This function is called post-collection and gives us a hook for modifying the collected items. - ''' + """ skip_api_endpoint_test = pytest.mark.skipif( - 'TEST_API_ENDPOINT' not in os.environ, - reason='this test requires that an API endpoint be specified (set the TEST_API_ENDPOINT environment variable).' + "TEST_API_ENDPOINT" not in os.environ, + reason="this test requires that an API endpoint be specified (set the TEST_API_ENDPOINT environment variable).", ) skip_nightly_test = pytest.mark.skipif( - os.getenv('TEST_COMPLETE') != 'true', - reason='this test is either time/memory/compute expensive or it depends on resources internal to the Allen Institute. '\ - 'Either way, it does\'nt run by default and must be opted into (it does run in our nightly builds).' + os.getenv("TEST_COMPLETE") != "true", + reason="this test is either time/memory/compute expensive or it depends on resources internal to the Allen Institute. " + "Either way, it does'nt run by default and must be opted into (it does run in our nightly builds).", ) skip_flaky_test = pytest.mark.skipif( - (os.getenv('TEST_COMPLETE') != 'true') and (os.getenv('TEST_FLAKY') != 'true'), - reason='this test does not consistently pass (for instance, because it makes requests that sometimes time out).'\ - 'All such tests should be fixed, but in the mean time we\'ve restricted it to run in our nightly build only '\ - 'in order to reduce the prevalence of bogus test results.' + (os.getenv("TEST_COMPLETE") != "true") and (os.getenv("TEST_FLAKY") != "true"), + reason="this test does not consistently pass (for instance, because it makes requests that sometimes time out)." + "All such tests should be fixed, but in the mean time we've restricted it to run in our nightly build only " + "in order to reduce the prevalence of bogus test results.", ) skip_prerelease_test = pytest.mark.skipif( - os.environ.get('TEST_PRERELEASE') != 'true', - reason='prerelease tests are only valid if external and internal data expected to align' + os.environ.get("TEST_PRERELEASE") != "true", + reason="prerelease tests are only valid if external and internal data expected to align", ) skip_neuron_test = pytest.mark.skipif( - os.getenv('TEST_NEURON') != 'true', - reason='this test depends on the NEURON simulation library. This dependency is not straghtforward to build '\ - 'and install, so you must opt in to running this test' + os.getenv("TEST_NEURON") != "true", + reason="this test depends on the NEURON simulation library. This dependency is not straghtforward to build " + "and install, so you must opt in to running this test", ) skip_outside_bamboo_test = pytest.mark.skipif( - os.environ.get('TEST_BAMBOO') != 'true', - reason='this test depends on the resources only available to Bamboo agents, but are still fast. If they are slow, mark with nightly' + os.environ.get("TEST_BAMBOO") != "true", + reason="this test depends on the resources only available to Bamboo agents, but are still fast. If they are slow, mark with nightly", ) for item in items: - if 'requires_api_endpoint' in item.keywords: + if "requires_api_endpoint" in item.keywords: item.add_marker(skip_api_endpoint_test) - if 'nightly' in item.keywords: + if "nightly" in item.keywords: item.add_marker(skip_nightly_test) - if 'todo_flaky' in item.keywords: + if "todo_flaky" in item.keywords: item.add_marker(skip_flaky_test) - if 'prerelease' in item.keywords: + if "prerelease" in item.keywords: item.add_marker(skip_prerelease_test) - if 'requires_neuron' in item.keywords: + if "requires_neuron" in item.keywords: item.add_marker(skip_neuron_test) - if 'requires_bamboo' in item.keywords: + if "requires_bamboo" in item.keywords: item.add_marker(skip_outside_bamboo_test) diff --git a/doc_template/conf.py b/doc_template/conf.py index 481f5764b3..7d858b9729 100644 --- a/doc_template/conf.py +++ b/doc_template/conf.py @@ -13,7 +13,8 @@ # this lets the docs build in a headless environment import matplotlib -matplotlib.use('agg') + +matplotlib.use("agg") import sys import os @@ -22,144 +23,142 @@ # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath("..")) # -- General configuration ----------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' +# needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be extensions # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. -extensions = ['sphinx.ext.autodoc', 'sphinx.ext.viewcode', 'sphinx.ext.autosummary', 'numpydoc'] +extensions = ["sphinx.ext.autodoc", "sphinx.ext.viewcode", "sphinx.ext.autosummary", "numpydoc"] # Add any paths that contain templates here, relative to this directory. -templates_path = ['aibs_sphinx/templates'] +templates_path = ["aibs_sphinx/templates"] # The suffix of source filenames. -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. -#source_encoding = 'utf-8-sig' +# source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = u'Allen SDK' -copyright = u'2015, Allen Institute for Brain Science' +project = "Allen SDK" +copyright = "2015, Allen Institute for Brain Science" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. -version = 'master' +version = "master" # The full version, including alpha/beta/rc tags. -release = 'dev' +release = "dev" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. -#language = None +# language = None # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: -#today = '' +# today = '' # Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' +# today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ['_build'] +exclude_patterns = ["_build"] # The reST default role (used for this markup: `text`) to use for all documents. -#default_role = None +# default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True +# add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). -#add_module_names = True +# add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. -#show_authors = False +# show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. -modindex_common_prefix = ['allensdk.'] +modindex_common_prefix = ["allensdk."] # -- Options for HTML output --------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'aibs_sphinx' -html_theme_options = { - "sidebarwidth": "300" -} +html_theme = "aibs_sphinx" +html_theme_options = {"sidebarwidth": "300"} # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. -#html_theme_options = {} +# html_theme_options = {} # Add any paths that contain custom themes here, relative to this directory. -html_theme_path = ['../doc_template'] +html_theme_path = ["../doc_template"] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". -#html_title = None +# html_title = None # A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None +# html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = 'logo.jpg' +# html_logo = 'logo.jpg' # The name of an image file (within the te path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. -#html_favicon = None +# html_favicon = None # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['aibs_sphinx/static', 'examples_root'] +html_static_path = ["aibs_sphinx/static", "examples_root"] -html_extra_path = ['../doc_template/.nojekyll'] +html_extra_path = ["../doc_template/.nojekyll"] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' +# html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. -#html_use_smartypants = True +# html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -#html_sidebars = { '**': [ 'globaltoc.html', 'localtoc.html', 'sourcelink.html', 'searchbox.html' ]} -html_sidebars = { '**': [ 'globaltoc.html', 'searchbox.html' ]} +# html_sidebars = { '**': [ 'globaltoc.html', 'localtoc.html', 'sourcelink.html', 'searchbox.html' ]} +html_sidebars = {"**": ["globaltoc.html", "searchbox.html"]} # Additional templates that should be rendered to pages, maps page names to # template names. -#html_additional_pages = {} +# html_additional_pages = {} # If false, no module index is generated. -#html_domain_indices = True +# html_domain_indices = True # If false, no index is generated. -#html_use_index = True +# html_use_index = True # If true, the index is split into individual pages for each letter. -#html_split_index = False +# html_split_index = False # If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True +# html_show_sourcelink = True # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. html_show_sphinx = False @@ -170,51 +169,48 @@ # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. -#html_use_opensearch = '' +# html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None +# html_file_suffix = None # Output file base name for HTML help builder. -htmlhelp_basename = 'Allen SDKdoc' +htmlhelp_basename = "Allen SDKdoc" # -- Options for LaTeX output -------------------------------------------------- latex_elements = { -# The paper size ('letterpaper' or 'a4paper'). -#'papersize': 'letterpaper', - -# The font size ('10pt', '11pt' or '12pt'). -#'pointsize': '10pt', - -# Additional stuff for the LaTeX preamble. -#'preamble': '', + # The paper size ('letterpaper' or 'a4paper'). + #'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + #'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + #'preamble': '', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, author, documentclass [howto/manual]). latex_documents = [ - ('index', 'Allen SDK.tex', u'Allen SDK Documentation', - u'Allen Institute for Brain Science', 'manual'), + ("index", "Allen SDK.tex", "Allen SDK Documentation", "Allen Institute for Brain Science", "manual"), ] # The name of an image file (relative to this directory) to place at the top of # the title page. -#latex_logo = None +# latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. -#latex_use_parts = False +# latex_use_parts = False # If true, show page references after internal links. -#latex_show_pagerefs = False +# latex_show_pagerefs = False # If true, show URL addresses after external links. -#latex_show_urls = False +# latex_show_urls = False # Documents to append as an appendix to all manuals. -#latex_appendices = [] +# latex_appendices = [] # If false, no module index is generated. # latex_domain_indices = True @@ -223,10 +219,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - ('index', 'Allen SDK', u'Allen SDK Documentation', - [u'Allen Institute for Brain Science'], 1) -] +man_pages = [("index", "Allen SDK", "Allen SDK Documentation", ["Allen Institute for Brain Science"], 1)] # If true, show URL addresses after external links. # man_show_urls = False @@ -237,64 +230,70 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - ('index', 'Allen SDK', u'Allen SDK Documentation', - u'Allen Institute for Brain Science', 'Allen SDK', 'One line description of project.', - 'Miscellaneous'), + ( + "index", + "Allen SDK", + "Allen SDK Documentation", + "Allen Institute for Brain Science", + "Allen SDK", + "One line description of project.", + "Miscellaneous", + ), ] # Documents to append as an appendix to all manuals. -#texinfo_appendices = [] +# texinfo_appendices = [] # If false, no module index is generated. -#texinfo_domain_indices = True +# texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. -#texinfo_show_urls = 'footnote' +# texinfo_show_urls = 'footnote' # -- Options for Epub output --------------------------------------------------- # Bibliographic Dublin Core info. -epub_title = u'Allen SDK' -epub_author = u'Allen Institute for Brain Science' -epub_publisher = u'Allen Institute for Brain Science' -epub_copyright = u'2015, Allen Institute for Brain Science' +epub_title = "Allen SDK" +epub_author = "Allen Institute for Brain Science" +epub_publisher = "Allen Institute for Brain Science" +epub_copyright = "2015, Allen Institute for Brain Science" # The language of the text. It defaults to the language option # or en if the language is not set. -#epub_language = '' +# epub_language = '' # The scheme of the identifier. Typical schemes are ISBN or URL. -#epub_scheme = '' +# epub_scheme = '' # The unique identifier of the text. This can be a ISBN number # or the project homepage. -#epub_identifier = '' +# epub_identifier = '' # A unique identification for the text. -#epub_uid = '' +# epub_uid = '' # A tuple containing the cover image and cover page html template filenames. -#epub_cover = () +# epub_cover = () # HTML files that should be inserted before the pages created by sphinx. # The format is a list of tuples containing the path and title. -#epub_pre_files = [] +# epub_pre_files = [] # HTML files shat should be inserted after the pages created by sphinx. # The format is a list of tuples containing the path and title. -#epub_post_files = [] +# epub_post_files = [] # A list of files that should not be packed into the epub file. -#epub_exclude_files = [] +# epub_exclude_files = [] # The depth of the table of contents in toc.ncx. -#epub_tocdepth = 3 +# epub_tocdepth = 3 # Allow duplicate toc entries. -#epub_tocdup = True +# epub_tocdup = True -skip_autodoc_names = [ 'DEFAULTS' ] +skip_autodoc_names = ["DEFAULTS"] numpydoc_show_class_members = False @@ -312,22 +311,22 @@ def skip_autodoc(app, what, name, obj, skip, options): def render_notebooks(_): - nb_root = os.path.abspath('examples_root/examples/nb') + nb_root = os.path.abspath("examples_root/examples/nb") for filename in os.listdir(nb_root): - if filename.endswith('.ipynb'): + if filename.endswith(".ipynb"): nb = os.path.join(nb_root, filename) - nb_html = nb.replace("ipynb","html") - os.system('jupyter-nbconvert --to html %s --output %s' % (nb, nb_html)) + nb_html = nb.replace("ipynb", "html") + os.system("jupyter-nbconvert --to html %s --output %s" % (nb, nb_html)) def run_apidoc(_): - sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + sys.path.append(os.path.join(os.path.dirname(__file__), "..")) cur_dir = os.path.abspath(os.path.dirname(__file__)) - module = os.path.join(cur_dir,"..","allensdk") - os.system('sphinx-apidoc -e -o %s %s --force' % (cur_dir, module)) + module = os.path.join(cur_dir, "..", "allensdk") + os.system("sphinx-apidoc -e -o %s %s --force" % (cur_dir, module)) def setup(app): - app.connect('autodoc-skip-member', skip_autodoc) - app.connect('builder-inited', render_notebooks) - app.connect('builder-inited', run_apidoc) + app.connect("autodoc-skip-member", skip_autodoc) + app.connect("builder-inited", render_notebooks) + app.connect("builder-inited", run_apidoc) diff --git a/doc_template/examples_root/examples/biophysical_ex1.py b/doc_template/examples_root/examples/biophysical_ex1.py index 7877abea43..24054beab2 100644 --- a/doc_template/examples_root/examples/biophysical_ex1.py +++ b/doc_template/examples_root/examples/biophysical_ex1.py @@ -1,8 +1,6 @@ from allensdk.api.queries.biophysical_api import BiophysicalApi bp = BiophysicalApi() -bp.cache_stimulus = True # change to False to not download the large stimulus NWB file -neuronal_model_id = 472451419 # get this from the web site as above -bp.cache_data(neuronal_model_id, working_directory='neuronal_model') - - +bp.cache_stimulus = True # change to False to not download the large stimulus NWB file +neuronal_model_id = 472451419 # get this from the web site as above +bp.cache_data(neuronal_model_id, working_directory="neuronal_model") diff --git a/doc_template/examples_root/examples/biophysical_sim/biophysical_sim.ipynb b/doc_template/examples_root/examples/biophysical_sim/biophysical_sim.ipynb index d85269b4b3..634500a52c 100644 --- a/doc_template/examples_root/examples/biophysical_sim/biophysical_sim.ipynb +++ b/doc_template/examples_root/examples/biophysical_sim/biophysical_sim.ipynb @@ -23,12 +23,13 @@ "import seaborn as sns\n", "import shutil\n", "import warnings\n", - "warnings.filterwarnings('ignore')\n", + "\n", + "warnings.filterwarnings(\"ignore\")\n", "\n", "%matplotlib inline\n", "\n", - "sns.set(style='whitegrid',font_scale=1.5)\n", - "plt.rcParams.update({'axes.grid':False})" + "sns.set(style=\"whitegrid\", font_scale=1.5)\n", + "plt.rcParams.update({\"axes.grid\": False})" ] }, { @@ -44,24 +45,24 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "def get_manifest_args(args):\n", " return runner.load_description(args)\n", - " \n", "\n", - "def get_sweep_data(nwb_file,sweep_number):\n", + "\n", + "def get_sweep_data(nwb_file, sweep_number):\n", " nwb = NwbDataSet(nwb_file)\n", " sweep = nwb.get_sweep(sweep_number)\n", - " stim_diff = np.diff(sweep['stimulus']*1e12)\n", + " stim_diff = np.diff(sweep[\"stimulus\"] * 1e12)\n", " stim_start = np.where(stim_diff != 0)[0][-2]\n", " stim_end = np.where(stim_diff != 0)[0][-1]\n", - " \n", + "\n", " # read v and t as numpy arrays\n", - " v = sweep['response']*1e3\n", - " dt = 1.0e3 / sweep['sampling_rate']\n", + " v = sweep[\"response\"] * 1e3\n", + " dt = 1.0e3 / sweep[\"sampling_rate\"]\n", " num_samples = len(v)\n", " t = np.arange(num_samples) * dt\n", - " return t,v,t[stim_start],t[stim_end]\n", + " return t, v, t[stim_start], t[stim_end]\n", + "\n", "\n", "def copytree(src, dst, symlinks=False, ignore=None):\n", " for item in os.listdir(src):\n", @@ -70,8 +71,7 @@ " if os.path.isdir(s):\n", " shutil.copytree(s, d, symlinks, ignore)\n", " else:\n", - " shutil.copy2(s, d)\n", - " " + " shutil.copy2(s, d)" ] }, { @@ -87,9 +87,8 @@ "metadata": {}, "outputs": [], "source": [ - " \n", - "cell_id = 468193142 # get this from the web site: http://celltypes.brain-map.org\n", - "sweep_num = 46 # Select a Long Square sweep : 1s DC\n" + "cell_id = 468193142 # get this from the web site: http://celltypes.brain-map.org\n", + "sweep_num = 46 # Select a Long Square sweep : 1s DC" ] }, { @@ -111,17 +110,17 @@ } ], "source": [ - "sdk_model_templates = {'all_active':491455321,'perisomatic':329230710}\n", + "sdk_model_templates = {\"all_active\": 491455321, \"perisomatic\": 329230710}\n", "bp = BiophysicalApi()\n", - "bp.cache_stimulus = True \n", - "model_list = bp.get_neuronal_models(cell_id,model_type_ids=[sdk_model_templates['all_active']])\n", + "bp.cache_stimulus = True\n", + "model_list = bp.get_neuronal_models(cell_id, model_type_ids=[sdk_model_templates[\"all_active\"]])\n", "model_dict = model_list[0]\n", "\n", - "model_dir = 'all_active_models'\n", - "bp.cache_data(model_dict['id'], working_directory=model_dir) \n", - "new_model_file = 'fit_parameters_new.json'\n", - "shutil.copyfile(new_model_file,os.path.join(model_dir,new_model_file))\n", - "copytree('modfiles',os.path.join(model_dir,'modfiles'))\n" + "model_dir = \"all_active_models\"\n", + "bp.cache_data(model_dict[\"id\"], working_directory=model_dir)\n", + "new_model_file = \"fit_parameters_new.json\"\n", + "shutil.copyfile(new_model_file, os.path.join(model_dir, new_model_file))\n", + "copytree(\"modfiles\", os.path.join(model_dir, \"modfiles\"))" ] }, { @@ -144,18 +143,18 @@ ], "source": [ "os.chdir(model_dir)\n", - "os.system('nrnivmodl modfiles/')\n", - "manifest_file = 'manifest.json'\n", + "os.system(\"nrnivmodl modfiles/\")\n", + "manifest_file = \"manifest.json\"\n", "manifest_dict = json.load(open(manifest_file))\n", "\n", "# sweeps by type is not populated when the model is downloaded using the api\n", - "# in that case add the sweep type to the manifest \n", - "if 'sweeps_by_type' not in manifest_dict['runs'][0]:\n", - " manifest_dict['runs'][0]['sweeps_by_type'] = {\"Long Square\":[sweep_num]}\n", - "json.dump(manifest_dict,open(manifest_file,'w'),indent=2) \n", + "# in that case add the sweep type to the manifest\n", + "if \"sweeps_by_type\" not in manifest_dict[\"runs\"][0]:\n", + " manifest_dict[\"runs\"][0][\"sweeps_by_type\"] = {\"Long Square\": [sweep_num]}\n", + "json.dump(manifest_dict, open(manifest_file, \"w\"), indent=2)\n", "\n", "schema_legacy = dict(manifest_file=manifest_file)\n", - "runner.run(schema_legacy,procs=1,sweeps=[sweep_num])\n" + "runner.run(schema_legacy, procs=1, sweeps=[sweep_num])" ] }, { @@ -177,21 +176,20 @@ } ], "source": [ - "\n", "manifest_dict = json.load(open(manifest_file))\n", "\n", "# Change the simulation output directory to avoid overwriting for the new models\n", - "for manifest_config in manifest_dict['manifest']:\n", - " if manifest_config['key'] == 'WORKDIR':\n", - " manifest_config['spec'] = 'work_new'\n", + "for manifest_config in manifest_dict[\"manifest\"]:\n", + " if manifest_config[\"key\"] == \"WORKDIR\":\n", + " manifest_config[\"spec\"] = \"work_new\"\n", "\n", - "new_manifest_file = 'manifest_new.json' \n", - " \n", - "manifest_dict['biophys'][0]['model_file'] = [new_manifest_file,new_model_file]\n", - "json.dump(manifest_dict,open(new_manifest_file,'w'),indent=2) \n", + "new_manifest_file = \"manifest_new.json\"\n", "\n", - "schema_new = dict(manifest_file=new_manifest_file, axon_type = 'stub')\n", - "runner.run(schema_new,procs=1,sweeps=[sweep_num])\n" + "manifest_dict[\"biophys\"][0][\"model_file\"] = [new_manifest_file, new_model_file]\n", + "json.dump(manifest_dict, open(new_manifest_file, \"w\"), indent=2)\n", + "\n", + "schema_new = dict(manifest_file=new_manifest_file, axon_type=\"stub\")\n", + "runner.run(schema_new, procs=1, sweeps=[sweep_num])" ] }, { @@ -219,30 +217,29 @@ } ], "source": [ - "\n", "legacy_config = get_manifest_args(schema_legacy)\n", - "output_nwb_legacy = legacy_config.manifest.get_path('output_path')\n", + "output_nwb_legacy = legacy_config.manifest.get_path(\"output_path\")\n", "\n", - "new_config = get_manifest_args(schema_new)\n", - "output_nwb_new = new_config.manifest.get_path('output_path')\n", + "new_config = get_manifest_args(schema_new)\n", + "output_nwb_new = new_config.manifest.get_path(\"output_path\")\n", "\n", - "exp_nwb = new_config.manifest.get_path('stimulus_path')\n", + "exp_nwb = new_config.manifest.get_path(\"stimulus_path\")\n", "\n", - "t_exp,v_exp,stim_start,stim_end = get_sweep_data(exp_nwb,sweep_num)\n", - "t_legacy_aa,v_legacy_aa,_,_ = get_sweep_data(output_nwb_legacy,sweep_num)\n", - "t_new_aa,v_new_aa,_,_ = get_sweep_data(output_nwb_new,sweep_num)\n", + "t_exp, v_exp, stim_start, stim_end = get_sweep_data(exp_nwb, sweep_num)\n", + "t_legacy_aa, v_legacy_aa, _, _ = get_sweep_data(output_nwb_legacy, sweep_num)\n", + "t_new_aa, v_new_aa, _, _ = get_sweep_data(output_nwb_new, sweep_num)\n", "\n", - "fig,ax = plt.subplots(figsize=(10,6))\n", - "ax.plot(t_exp,v_exp, color = 'k', label='Experiment')\n", - "ax.plot(t_legacy_aa,v_legacy_aa,color = 'r', label= 'Legacy Model')\n", - "ax.plot(t_new_aa,v_new_aa,color='b', label = 'New Model')\n", - "ax.set_xlim([stim_start-200,stim_end+200])\n", + "fig, ax = plt.subplots(figsize=(10, 6))\n", + "ax.plot(t_exp, v_exp, color=\"k\", label=\"Experiment\")\n", + "ax.plot(t_legacy_aa, v_legacy_aa, color=\"r\", label=\"Legacy Model\")\n", + "ax.plot(t_new_aa, v_new_aa, color=\"b\", label=\"New Model\")\n", + "ax.set_xlim([stim_start - 200, stim_end + 200])\n", "sns.despine(ax=ax)\n", - "h,l = ax.get_legend_handles_labels()\n", - "fig.legend(h,l,ncol=3,frameon=False,loc='lower center')\n", - "fig.tight_layout(rect=[0.0,0.05,1,0.95])\n", + "h, l = ax.get_legend_handles_labels()\n", + "fig.legend(h, l, ncol=3, frameon=False, loc=\"lower center\")\n", + "fig.tight_layout(rect=[0.0, 0.05, 1, 0.95])\n", "\n", - "plt.show()\n" + "plt.show()" ] } ], diff --git a/doc_template/examples_root/examples/biophysical_sim/biophysical_sim.py b/doc_template/examples_root/examples/biophysical_sim/biophysical_sim.py index c65ee91205..6e3d93a517 100644 --- a/doc_template/examples_root/examples/biophysical_sim/biophysical_sim.py +++ b/doc_template/examples_root/examples/biophysical_sim/biophysical_sim.py @@ -9,14 +9,15 @@ import shutil import subprocess -sns.set(style='whitegrid', font_scale=1.5) -plt.rcParams.update({'axes.grid': False}) +sns.set(style="whitegrid", font_scale=1.5) +plt.rcParams.update({"axes.grid": False}) + +# %% Define utility functions -#%% Define utility functions def get_manifest_args(args): return runner.load_description(args) - + def get_sweep_data(nwb_file, sweep_number, time_scale=1e3, voltage_scale=1e3, stim_scale=1e12): """ @@ -26,7 +27,7 @@ def get_sweep_data(nwb_file, sweep_number, time_scale=1e3, voltage_scale=1e3, st nwb_file : string File name of a pre-existing NWB file. sweep_number : integer - + time_scale : float Convert to ms scale voltage_scale : float @@ -47,20 +48,21 @@ def get_sweep_data(nwb_file, sweep_number, time_scale=1e3, voltage_scale=1e3, st """ nwb = NwbDataSet(nwb_file) sweep = nwb.get_sweep(sweep_number) - stim = sweep['stimulus'] * stim_scale # in pA + stim = sweep["stimulus"] * stim_scale # in pA stim_diff = np.diff(stim) stim_start = np.where(stim_diff != 0)[0][-2] stim_end = np.where(stim_diff != 0)[0][-1] - + # read v and t as numpy arrays - v = sweep['response'] * voltage_scale # in mV - dt = time_scale / sweep['sampling_rate'] # in ms + v = sweep["response"] * voltage_scale # in mV + dt = time_scale / sweep["sampling_rate"] # in ms num_samples = len(v) t = np.arange(num_samples) * dt stim_start_time = t[stim_start] stim_end_time = t[stim_end] return t, v, stim_start_time, stim_end_time + def copytree(src, dst, symlinks=False, ignore=None): for item in os.listdir(src): s = os.path.join(src, item) @@ -69,80 +71,83 @@ def copytree(src, dst, symlinks=False, ignore=None): shutil.copytree(s, d, symlinks, ignore) else: shutil.copy2(s, d) - -#%% Pick a cell and a sweep to run from the available set of protocols - + + +# %% Pick a cell and a sweep to run from the available set of protocols + cell_id = 468193142 # get this from the web site: http://celltypes.brain-map.org sweep_num = 46 # Select a Long Square sweep : 1s DC -#%% Download all-acive model +# %% Download all-acive model -sdk_model_templates = {'all_active': 491455321, 'perisomatic': 329230710} +sdk_model_templates = {"all_active": 491455321, "perisomatic": 329230710} bp = BiophysicalApi() bp.cache_stimulus = True -model_list = bp.get_neuronal_models(cell_id, model_type_ids=[sdk_model_templates['all_active']]) # Only get the all-active model for the cell +model_list = bp.get_neuronal_models( + cell_id, model_type_ids=[sdk_model_templates["all_active"]] +) # Only get the all-active model for the cell model_dict = model_list[0] -model_dir = 'all_active_models' -bp.cache_data(model_dict['id'], working_directory=model_dir) -new_model_file = 'fit_parameters_new.json' +model_dir = "all_active_models" +bp.cache_data(model_dict["id"], working_directory=model_dir) +new_model_file = "fit_parameters_new.json" shutil.copyfile(new_model_file, os.path.join(model_dir, new_model_file)) -copytree('modfiles', os.path.join(model_dir, 'modfiles')) +copytree("modfiles", os.path.join(model_dir, "modfiles")) -#%% Running the legacy all-active models +# %% Running the legacy all-active models os.chdir(model_dir) -subprocess.check_call(['nrnivmodl', 'modfiles/']) -manifest_file = 'manifest.json' +subprocess.check_call(["nrnivmodl", "modfiles/"]) +manifest_file = "manifest.json" manifest_dict = json.load(open(manifest_file)) # sweeps by type is not populated when the model is downloaded using the api # in that case add the sweep type to the manifest -if 'sweeps_by_type' not in manifest_dict['runs'][0]: - manifest_dict['runs'][0]['sweeps_by_type'] = {"Long Square": [sweep_num]} -json.dump(manifest_dict, open(manifest_file, 'w'), indent=2) +if "sweeps_by_type" not in manifest_dict["runs"][0]: + manifest_dict["runs"][0]["sweeps_by_type"] = {"Long Square": [sweep_num]} +json.dump(manifest_dict, open(manifest_file, "w"), indent=2) schema_legacy = dict(manifest_file=manifest_file) runner.run(schema_legacy, procs=1, sweeps=[sweep_num]) -#%% Running the new all-active models +# %% Running the new all-active models manifest_dict = json.load(open(manifest_file)) # Change the simulation output directory to avoid overwriting for the new models -for manifest_config in manifest_dict['manifest']: - if manifest_config['key'] == 'WORKDIR': - manifest_config['spec'] = 'work_new' +for manifest_config in manifest_dict["manifest"]: + if manifest_config["key"] == "WORKDIR": + manifest_config["spec"] = "work_new" + +new_manifest_file = "manifest_new.json" -new_manifest_file = 'manifest_new.json' - -manifest_dict['biophys'][0]['model_file'] = [new_manifest_file, new_model_file] -json.dump(manifest_dict, open(new_manifest_file, 'w'), indent=2) +manifest_dict["biophys"][0]["model_file"] = [new_manifest_file, new_model_file] +json.dump(manifest_dict, open(new_manifest_file, "w"), indent=2) -schema_new = dict(manifest_file=new_manifest_file, axon_type='stub') +schema_new = dict(manifest_file=new_manifest_file, axon_type="stub") runner.run(schema_new, procs=1, sweeps=[sweep_num]) -#%% Comparing the responses +# %% Comparing the responses legacy_config = get_manifest_args(schema_legacy) -output_nwb_legacy = legacy_config.manifest.get_path('output_path') +output_nwb_legacy = legacy_config.manifest.get_path("output_path") new_config = get_manifest_args(schema_new) -output_nwb_new = new_config.manifest.get_path('output_path') +output_nwb_new = new_config.manifest.get_path("output_path") -exp_nwb = new_config.manifest.get_path('stimulus_path') +exp_nwb = new_config.manifest.get_path("stimulus_path") t_exp, v_exp, stim_start, stim_end = get_sweep_data(exp_nwb, sweep_num) t_legacy_aa, v_legacy_aa, _, _ = get_sweep_data(output_nwb_legacy, sweep_num) t_new_aa, v_new_aa, _, _ = get_sweep_data(output_nwb_new, sweep_num) fig, ax = plt.subplots(figsize=(10, 6)) -ax.plot(t_exp, v_exp, color='k', label='Experiment') -ax.plot(t_legacy_aa, v_legacy_aa, color='r', label='Legacy Model') -ax.plot(t_new_aa, v_new_aa, color='b', label='New Model') +ax.plot(t_exp, v_exp, color="k", label="Experiment") +ax.plot(t_legacy_aa, v_legacy_aa, color="r", label="Legacy Model") +ax.plot(t_new_aa, v_new_aa, color="b", label="New Model") ax.set_xlim([stim_start - 200, stim_end + 200]) sns.despine(ax=ax) handles, legends = ax.get_legend_handles_labels() -fig.legend(handles, legends, ncol=3, frameon=False, loc='lower center') +fig.legend(handles, legends, ncol=3, frameon=False, loc="lower center") fig.tight_layout(rect=[0.0, 0.05, 1, 0.95]) plt.show() diff --git a/doc_template/examples_root/examples/cell_types_ex.py b/doc_template/examples_root/examples/cell_types_ex.py index cec9e9e5c8..89c7e90c1b 100644 --- a/doc_template/examples_root/examples/cell_types_ex.py +++ b/doc_template/examples_root/examples/cell_types_ex.py @@ -1,30 +1,30 @@ -#=============================================================================== +# =============================================================================== # example 1 -#=============================================================================== +# =============================================================================== from allensdk.core.cell_types_cache import CellTypesCache -ctc = CellTypesCache(manifest_file='cell_types/manifest.json') +ctc = CellTypesCache(manifest_file="cell_types/manifest.json") # a list of cell metadata for cells with reconstructions, download if necessary cells = ctc.get_cells(require_reconstruction=True) # open the electrophysiology data of one cell, download if necessary -data_set = ctc.get_ephys_data(cells[0]['id']) +data_set = ctc.get_ephys_data(cells[0]["id"]) # read the reconstruction, download if necessary -reconstruction = ctc.get_reconstruction(cells[0]['id']) +reconstruction = ctc.get_reconstruction(cells[0]["id"]) -#=============================================================================== +# =============================================================================== # example 2 -#=============================================================================== +# =============================================================================== from allensdk.core.cell_types_cache import CellTypesCache from allensdk.ephys.extract_cell_features import extract_cell_features from collections import defaultdict # initialize the cache -ctc = CellTypesCache(manifest_file='cell_types/manifest.json') +ctc = CellTypesCache(manifest_file="cell_types/manifest.json") # pick a cell to analyze specimen_id = 324257146 @@ -33,25 +33,24 @@ data_set = ctc.get_ephys_data(specimen_id) sweeps = ctc.get_ephys_sweeps(specimen_id) -# group the sweeps by stimulus +# group the sweeps by stimulus sweep_numbers = defaultdict(list) for sweep in sweeps: - sweep_numbers[sweep['stimulus_name']].append(sweep['sweep_number']) + sweep_numbers[sweep["stimulus_name"]].append(sweep["sweep_number"]) # calculate features -cell_features = extract_cell_features(data_set, - sweep_numbers['Ramp'], - sweep_numbers['Short Square'], - sweep_numbers['Long Square']) +cell_features = extract_cell_features( + data_set, sweep_numbers["Ramp"], sweep_numbers["Short Square"], sweep_numbers["Long Square"] +) -#=============================================================================== +# =============================================================================== # example 3 -#=============================================================================== +# =============================================================================== import allensdk.core.swc as swc # if you ran the examples above, you will have a reconstruction here -file_name = 'cell_types/specimen_485909730/reconstruction.swc' +file_name = "cell_types/specimen_485909730/reconstruction.swc" morphology = swc.read_swc(file_name) # subsample the morphology 3x. root, soma, junctions, and the first child of the root are preserved. @@ -63,39 +62,39 @@ # a dictionary of compartments indexed by compartment id compartments_by_id = sparse_morphology.compartment_index -# the root soma compartment +# the root soma compartment soma = morphology.soma # all compartments are dictionaries of compartment properties # compartments also keep track of ids of their children for child in morphology.children_of(soma): - print(child['x'], child['y'], child['z'], child['radius']) + print(child["x"], child["y"], child["z"], child["radius"]) -#=============================================================================== +# =============================================================================== # example 4 -#=============================================================================== +# =============================================================================== from allensdk.core.nwb_data_set import NwbDataSet # if you ran the examples above, you will have a NWB file here -file_name = 'cell_types/specimen_485909730/ephys.nwb' +file_name = "cell_types/specimen_485909730/ephys.nwb" data_set = NwbDataSet(file_name) sweep_numbers = data_set.get_sweep_numbers() -sweep_number = sweep_numbers[0] +sweep_number = sweep_numbers[0] sweep_data = data_set.get_sweep(sweep_number) # spike times are in seconds relative to the start of the sweep spike_times = data_set.get_spike_times(sweep_number) # stimulus is a numpy array in amps -stimulus = sweep_data['stimulus'] +stimulus = sweep_data["stimulus"] # response is a numpy array in volts -reponse = sweep_data['response'] +reponse = sweep_data["response"] # sampling rate is in Hz -sampling_rate = sweep_data['sampling_rate'] +sampling_rate = sweep_data["sampling_rate"] # start/stop indices that exclude the experimental test pulse (if applicable) -index_range = sweep_data['index_range'] +index_range = sweep_data["index_range"] diff --git a/doc_template/examples_root/examples/connectivity_ex.py b/doc_template/examples_root/examples/connectivity_ex.py index 49364f9675..be0b140323 100644 --- a/doc_template/examples_root/examples/connectivity_ex.py +++ b/doc_template/examples_root/examples/connectivity_ex.py @@ -1,6 +1,6 @@ -#=============================================================================== +# =============================================================================== # example 1 -#=============================================================================== +# =============================================================================== from allensdk.core.mouse_connectivity_cache import MouseConnectivityCache @@ -9,20 +9,19 @@ # use the structure tree class to get information about the isocortex structure structure_tree = mcc.get_structure_tree() -isocortex_id = structure_tree.get_structures_by_name(['Isocortex'])[0]['id'] +isocortex_id = structure_tree.get_structures_by_name(["Isocortex"])[0]["id"] # a list of dictionaries containing metadata for non-Cre experiments -experiments = mcc.get_experiments(file_name='non_cre.json', - injection_structure_ids=[isocortex_id]) +experiments = mcc.get_experiments(file_name="non_cre.json", injection_structure_ids=[isocortex_id]) # download the projection density volume for one of the experiments -pd = mcc.get_projection_density(experiments[0]['id']) +pd = mcc.get_projection_density(experiments[0]["id"]) -#=============================================================================== +# =============================================================================== # example 2 -#=============================================================================== +# =============================================================================== import nrrd -file_name = 'mouse_connectivity/experiment_644250774/projection_density_25.nrrd' +file_name = "mouse_connectivity/experiment_644250774/projection_density_25.nrrd" data_array, metadata = nrrd.read(file_name) diff --git a/doc_template/examples_root/examples/data_api_client_ex.py b/doc_template/examples_root/examples/data_api_client_ex.py index f33eb21514..486e7ad356 100644 --- a/doc_template/examples_root/examples/data_api_client_ex.py +++ b/doc_template/examples_root/examples/data_api_client_ex.py @@ -1,132 +1,119 @@ -#=============================================================================== +# =============================================================================== # example 1 -#=============================================================================== +# =============================================================================== from allensdk.api.queries.rma_api import RmaApi rma = RmaApi() -data = rma.model_query('Atlas', - criteria="[name$il'*Mouse*']") +data = rma.model_query("Atlas", criteria="[name$il'*Mouse*']") -#=============================================================================== +# =============================================================================== # example 2 -#=============================================================================== - -associations = ''.join(['[id$eq1]', - 'structure_graph(ontology),', - 'graphic_group_labels']) - -atlas_data = rma.model_query('Atlas', - include=associations, - criteria=associations, - only=['atlases.id', - 'atlases.name', - 'atlases.image_type', - 'ontologies.id', - 'ontologies.name', - 'structure_graphs.id', - 'structure_graphs.name', - 'graphic_group_labels.id', - 'graphic_group_labels.name']) - -#=============================================================================== +# =============================================================================== + +associations = "".join(["[id$eq1]", "structure_graph(ontology),", "graphic_group_labels"]) + +atlas_data = rma.model_query( + "Atlas", + include=associations, + criteria=associations, + only=[ + "atlases.id", + "atlases.name", + "atlases.image_type", + "ontologies.id", + "ontologies.name", + "structure_graphs.id", + "structure_graphs.name", + "graphic_group_labels.id", + "graphic_group_labels.name", + ], +) + +# =============================================================================== # example 3 -#=============================================================================== +# =============================================================================== # http://api.brain-map.org/api/v2/data.json schema = rma.get_schema() for entry in schema: - data_description = entry['DataDescription'] + data_description = entry["DataDescription"] clz = list(data_description.keys())[0] info = list(data_description.values())[0] - fields = info['fields'] - associations = info['associations'] - table = info['table'] + fields = info["fields"] + associations = info["associations"] + table = info["table"] print("class: %s" % (clz)) - print("fields: %s" % (','.join(f['name'] for f in fields))) - print("associations: %s" % (','.join(a['name'] for a in associations))) + print("fields: %s" % (",".join(f["name"] for f in fields))) + print("associations: %s" % (",".join(a["name"] for a in associations))) print("table: %s\n" % (table)) -#=============================================================================== +# =============================================================================== # example 4 -#=============================================================================== +# =============================================================================== import pandas as pd -structures = pd.DataFrame( - rma.model_query('Structure', - criteria='[graph_id$eq1]', - num_rows='all')) +structures = pd.DataFrame(rma.model_query("Structure", criteria="[graph_id$eq1]", num_rows="all")) -#=============================================================================== +# =============================================================================== # example 5 -#=============================================================================== +# =============================================================================== -names_and_acronyms = structures.loc[:,['name', 'acronym']] +names_and_acronyms = structures.loc[:, ["name", "acronym"]] -#=============================================================================== +# =============================================================================== # example 6 -#=============================================================================== +# =============================================================================== -mea = structures[structures.acronym == 'MEA'] -mea_id = mea.iloc[0,:].id +mea = structures[structures.acronym == "MEA"] +mea_id = mea.iloc[0, :].id mea_children = structures[structures.parent_structure_id == mea_id] -print(mea_children['name']) +print(mea_children["name"]) -#=============================================================================== +# =============================================================================== # example 7 -#=============================================================================== +# =============================================================================== criteria_string = "structure_sets[name$eq'Mouse Connectivity - Summary']" include_string = "ontology" -summary_structures = \ - pd.DataFrame( - rma.model_query('Structure', - criteria=criteria_string, - include=include_string, - num_rows='all')) -ontologies = \ - pd.DataFrame( - list(summary_structures.ontology)).drop_duplicates() -flat_structures_dataframe = summary_structures.drop(['ontology'], axis=1) - -#=============================================================================== +summary_structures = pd.DataFrame( + rma.model_query("Structure", criteria=criteria_string, include=include_string, num_rows="all") +) +ontologies = pd.DataFrame(list(summary_structures.ontology)).drop_duplicates() +flat_structures_dataframe = summary_structures.drop(["ontology"], axis=1) + +# =============================================================================== # example 8 -#=============================================================================== +# =============================================================================== -print(summary_structures.ontology[0]['name']) +print(summary_structures.ontology[0]["name"]) -#=============================================================================== +# =============================================================================== # example 9 -#=============================================================================== +# =============================================================================== -summary_structures[['id', - 'parent_structure_id', - 'acronym']].to_csv('summary_structures.csv', - index_label='structure_id') -reread = pd.read_csv('summary_structures.csv') +summary_structures[["id", "parent_structure_id", "acronym"]].to_csv( + "summary_structures.csv", index_label="structure_id" +) +reread = pd.read_csv("summary_structures.csv") -#=============================================================================== +# =============================================================================== # example 10 -#=============================================================================== +# =============================================================================== -for id, name, parent_structure_id in summary_structures[['name', - 'parent_structure_id']].itertuples(): +for id, name, parent_structure_id in summary_structures[["name", "parent_structure_id"]].itertuples(): print("%d %s %d" % (id, name, parent_structure_id)) -#=============================================================================== +# =============================================================================== # example 11 -#=============================================================================== +# =============================================================================== from allensdk.api.warehouse_cache.cache import Cache cache_writer = Cache() -do_cache=True -structures_from_api = \ - cache_writer.wrap(rma.model_query, - path='summary.csv', - cache=do_cache, - model='Structure', - criteria='[graph_id$eq1]', - num_rows='all') +do_cache = True +structures_from_api = cache_writer.wrap( + rma.model_query, path="summary.csv", cache=do_cache, model="Structure", criteria="[graph_id$eq1]", num_rows="all" +) diff --git a/doc_template/examples_root/examples/glif_ex.py b/doc_template/examples_root/examples/glif_ex.py index d8658cb41d..2db6faa643 100644 --- a/doc_template/examples_root/examples/glif_ex.py +++ b/doc_template/examples_root/examples/glif_ex.py @@ -1,10 +1,11 @@ # set matplotlib headless - this turns off the production of visible plots! import matplotlib + matplotlib.use("Agg") -#=============================================================================== +# =============================================================================== # example 1 -#=============================================================================== +# =============================================================================== from allensdk.api.queries.glif_api import GlifApi from allensdk.core.cell_types_cache import CellTypesCache @@ -19,26 +20,26 @@ # download the model configuration file nc = glif_api.get_neuron_configs([neuronal_model_id])[neuronal_model_id] neuron_config = glif_api.get_neuron_configs([neuronal_model_id]) -json_utilities.write('neuron_config.json', neuron_config) +json_utilities.write("neuron_config.json", neuron_config) # download information about the cell ctc = CellTypesCache() -ctc.get_ephys_data(nm['specimen_id'], file_name='stimulus.nwb') -ctc.get_ephys_sweeps(nm['specimen_id'], file_name='ephys_sweeps.json') +ctc.get_ephys_data(nm["specimen_id"], file_name="stimulus.nwb") +ctc.get_ephys_sweeps(nm["specimen_id"], file_name="ephys_sweeps.json") -#=============================================================================== +# =============================================================================== # example 2 -#=============================================================================== +# =============================================================================== import allensdk.core.json_utilities as json_utilities from allensdk.model.glif.glif_neuron import GlifNeuron # initialize the neuron -neuron_config = json_utilities.read('neuron_config.json')['566302806'] +neuron_config = json_utilities.read("neuron_config.json")["566302806"] neuron = GlifNeuron.from_dict(neuron_config) # make a short square pulse. stimulus units should be in Amps. -stimulus = [ 0.0 ] * 100 + [ 10e-9 ] * 100 + [ 0.0 ] * 100 +stimulus = [0.0] * 100 + [10e-9] * 100 + [0.0] * 100 # important! set the neuron's dt value for your stimulus in seconds neuron.dt = 5e-6 @@ -46,133 +47,131 @@ # simulate the neuron output = neuron.run(stimulus) -voltage = output['voltage'] -threshold = output['threshold'] -spike_times = output['interpolated_spike_times'] +voltage = output["voltage"] +threshold = output["threshold"] +spike_times = output["interpolated_spike_times"] -#=============================================================================== +# =============================================================================== # example 3 -#=============================================================================== +# =============================================================================== import allensdk.core.json_utilities as json_utilities from allensdk.model.glif.glif_neuron import GlifNeuron from allensdk.model.glif.simulate_neuron import simulate_neuron -neuron_config = json_utilities.read('neuron_config.json')['566302806'] -ephys_sweeps = json_utilities.read('ephys_sweeps.json') -ephys_file_name = 'stimulus.nwb' +neuron_config = json_utilities.read("neuron_config.json")["566302806"] +ephys_sweeps = json_utilities.read("ephys_sweeps.json") +ephys_file_name = "stimulus.nwb" neuron = GlifNeuron.from_dict(neuron_config) -sweep_numbers = [ s['sweep_number'] for s in ephys_sweeps if s['stimulus_units'] == 'Amps' ] -sweep_numbers = sweep_numbers[:1] # for the sake of a speedy example, just run the first one +sweep_numbers = [s["sweep_number"] for s in ephys_sweeps if s["stimulus_units"] == "Amps"] +sweep_numbers = sweep_numbers[:1] # for the sake of a speedy example, just run the first one simulate_neuron(neuron, sweep_numbers, ephys_file_name, ephys_file_name, 0.05) -#=============================================================================== +# =============================================================================== # example 4 -#=============================================================================== +# =============================================================================== import allensdk.core.json_utilities as json_utilities from allensdk.model.glif.glif_neuron import GlifNeuron from allensdk.core.nwb_data_set import NwbDataSet -neuron_config = json_utilities.read('neuron_config.json')['566302806'] -ephys_sweeps = json_utilities.read('ephys_sweeps.json') -ephys_file_name = 'stimulus.nwb' +neuron_config = json_utilities.read("neuron_config.json")["566302806"] +ephys_sweeps = json_utilities.read("ephys_sweeps.json") +ephys_file_name = "stimulus.nwb" # pull out the stimulus for the current-clamp first sweep -ephys_sweep = next( s for s in ephys_sweeps - if s['stimulus_units'] == 'Amps' ) +ephys_sweep = next(s for s in ephys_sweeps if s["stimulus_units"] == "Amps") ds = NwbDataSet(ephys_file_name) -data = ds.get_sweep(ephys_sweep['sweep_number']) -stimulus = data['stimulus'] +data = ds.get_sweep(ephys_sweep["sweep_number"]) +stimulus = data["stimulus"] # initialize the neuron # important! update the neuron's dt for your stimulus neuron = GlifNeuron.from_dict(neuron_config) -neuron.dt = 1.0 / data['sampling_rate'] +neuron.dt = 1.0 / data["sampling_rate"] # simulate the neuron output = neuron.run(stimulus) -voltage = output['voltage'] -threshold = output['threshold'] -spike_times = output['interpolated_spike_times'] +voltage = output["voltage"] +threshold = output["threshold"] +spike_times = output["interpolated_spike_times"] -#=============================================================================== +# =============================================================================== # example 5 -#=============================================================================== +# =============================================================================== import numpy as np import matplotlib.pyplot as plt -voltage = output['voltage'] -threshold = output['threshold'] -interpolated_spike_times = output['interpolated_spike_times'] -spike_times = output['interpolated_spike_times'] -interpolated_spike_voltages = output['interpolated_spike_voltage'] -interpolated_spike_thresholds = output['interpolated_spike_threshold'] -grid_spike_indices = output['spike_time_steps'] -grid_spike_times = output['grid_spike_times'] -after_spike_currents = output['AScurrents'] +voltage = output["voltage"] +threshold = output["threshold"] +interpolated_spike_times = output["interpolated_spike_times"] +spike_times = output["interpolated_spike_times"] +interpolated_spike_voltages = output["interpolated_spike_voltage"] +interpolated_spike_thresholds = output["interpolated_spike_threshold"] +grid_spike_indices = output["spike_time_steps"] +grid_spike_times = output["grid_spike_times"] +after_spike_currents = output["AScurrents"] # create a time array for plotting -time = np.arange(len(stimulus))*neuron.dt +time = np.arange(len(stimulus)) * neuron.dt plt.figure(figsize=(10, 10)) # plot stimulus -plt.subplot(3,1,1) +plt.subplot(3, 1, 1) plt.plot(time, stimulus) -plt.xlabel('time (s)') -plt.ylabel('current (A)') -plt.title('Stimulus') +plt.xlabel("time (s)") +plt.ylabel("current (A)") +plt.title("Stimulus") # plot model output -plt.subplot(3,1,2) -plt.plot(time, voltage, label='voltage') -plt.plot(time, threshold, label='threshold') +plt.subplot(3, 1, 2) +plt.plot(time, voltage, label="voltage") +plt.plot(time, threshold, label="threshold") if grid_spike_indices is not None: - plt.plot(interpolated_spike_times, interpolated_spike_voltages, 'x', - label='interpolated spike') + plt.plot(interpolated_spike_times, interpolated_spike_voltages, "x", label="interpolated spike") - plt.plot((grid_spike_indices-1)*neuron.dt, voltage[grid_spike_indices-1], '.', - label='last step before spike') + plt.plot((grid_spike_indices - 1) * neuron.dt, voltage[grid_spike_indices - 1], ".", label="last step before spike") -plt.xlabel('time (s)') -plt.ylabel('voltage (V)') +plt.xlabel("time (s)") +plt.ylabel("voltage (V)") plt.legend(loc=3) -plt.title('Model Response') +plt.title("Model Response") # plot after spike currents -plt.subplot(3,1,3) +plt.subplot(3, 1, 3) for ii in range(np.shape(after_spike_currents)[1]): - plt.plot(time, after_spike_currents[:,ii]) -plt.xlabel('time (s)') -plt.ylabel('current (A)') -plt.title('After Spike Currents') + plt.plot(time, after_spike_currents[:, ii]) +plt.xlabel("time (s)") +plt.ylabel("current (A)") +plt.title("After Spike Currents") plt.tight_layout() plt.show() -#=============================================================================== +# =============================================================================== # example 6 -#=============================================================================== +# =============================================================================== + -# define your own custom voltage reset rule +# define your own custom voltage reset rule # this one linearly scales the input voltage def custom_voltage_reset_rule(neuron, voltage_t0, custom_param_a, custom_param_b): return custom_param_a * voltage_t0 + custom_param_b + # initialize a neuron from a neuron config file -neuron_config = json_utilities.read('neuron_config.json')['566302806'] +neuron_config = json_utilities.read("neuron_config.json")["566302806"] neuron = GlifNeuron.from_dict(neuron_config) # configure a new method and overwrite the neuron's old method -method = neuron.configure_method('custom', custom_voltage_reset_rule, - { 'custom_param_a': 0.1, 'custom_param_b': 0.0 }) +method = neuron.configure_method("custom", custom_voltage_reset_rule, {"custom_param_a": 0.1, "custom_param_b": 0.0}) neuron.voltage_reset_method = method output = neuron.run(stimulus) diff --git a/doc_template/examples_root/examples/internal/Lims Behavior Project Cache.ipynb b/doc_template/examples_root/examples/internal/Lims Behavior Project Cache.ipynb index d8e48a134c..9844a27786 100644 --- a/doc_template/examples_root/examples/internal/Lims Behavior Project Cache.ipynb +++ b/doc_template/examples_root/examples/internal/Lims Behavior Project Cache.ipynb @@ -381,13 +381,17 @@ } ], "source": [ - "latest = ophys_experiments.query(\"project_code == 'VisualBehavior'\"\n", - " \"& experiment_workflow_state == 'passed'\")\\\n", - " .sort_values(\"date_of_acquisition\", ascending=False).iloc[0]\n", + "latest = (\n", + " ophys_experiments.query(\"project_code == 'VisualBehavior'& experiment_workflow_state == 'passed'\")\n", + " .sort_values(\"date_of_acquisition\", ascending=False)\n", + " .iloc[0]\n", + ")\n", + "\n", + "print(\n", + " f\"Latest experiment id: {latest.name}. Acquired on {latest['date_of_acquisition']}. \"\n", + " f\"Session: {latest['session_type']}\"\n", + ")\n", "\n", - "print(f\"Latest experiment id: {latest.name}. Acquired on {latest['date_of_acquisition']}. \"\n", - " f\"Session: {latest['session_type']}\")\n", - " \n", "session = cache.get_session_data(latest.name)" ] }, @@ -754,8 +758,8 @@ "\n", "# Plot the running speed\n", "plt.plot(behav_sess.running_speed.timestamps, behav_sess.running_speed.values)\n", - "_ = plt.xlabel('Time (Second)')\n", - "_ = plt.ylabel('Speed (cm/Second)')" + "_ = plt.xlabel(\"Time (Second)\")\n", + "_ = plt.ylabel(\"Speed (cm/Second)\")" ] }, { diff --git a/doc_template/examples_root/examples/internal/Using updated behavior project cache to access eye tracking data.ipynb b/doc_template/examples_root/examples/internal/Using updated behavior project cache to access eye tracking data.ipynb index 42e302910e..49cd0a393c 100644 --- a/doc_template/examples_root/examples/internal/Using updated behavior project cache to access eye tracking data.ipynb +++ b/doc_template/examples_root/examples/internal/Using updated behavior project cache to access eye tracking data.ipynb @@ -1040,7 +1040,7 @@ "outputs": [], "source": [ "# Select only frames that do *not* have likely blinks\n", - "blink_filtered = eye_data.loc[~eye_data['likely_blink']]" + "blink_filtered = eye_data.loc[~eye_data[\"likely_blink\"]]" ] }, { @@ -1072,7 +1072,7 @@ } ], "source": [ - "blink_filtered_normalized = blink_filtered['pupil_area'] / np.percentile(blink_filtered['pupil_area'], 99)\n", + "blink_filtered_normalized = blink_filtered[\"pupil_area\"] / np.percentile(blink_filtered[\"pupil_area\"], 99)\n", "blink_filtered_normalized" ] }, @@ -1489,8 +1489,8 @@ ], "source": [ "# Set eye_area and pupil_area for likely_blink frames to NAN\n", - "eye_data2.loc[eye_data2['likely_blink'], ('eye_area', 'pupil_area')] = np.nan\n", - "eye_data2[eye_data2['likely_blink']]" + "eye_data2.loc[eye_data2[\"likely_blink\"], (\"eye_area\", \"pupil_area\")] = np.nan\n", + "eye_data2[eye_data2[\"likely_blink\"]]" ] }, { @@ -1522,9 +1522,9 @@ } ], "source": [ - "interpolated_pupil_areas = eye_data2.loc[:, ('pupil_area')].interpolate()\n", + "interpolated_pupil_areas = eye_data2.loc[:, (\"pupil_area\")].interpolate()\n", "normalized_interpolated_pupil_areas = interpolated_pupil_areas / np.percentile(interpolated_pupil_areas, 99)\n", - "normalized_interpolated_pupil_areas.rename(columns={'pupil_area': 'normalized_pupil_area'})" + "normalized_interpolated_pupil_areas.rename(columns={\"pupil_area\": \"normalized_pupil_area\"})" ] }, { @@ -2041,29 +2041,18 @@ "\n", "import matplotlib.pyplot as plt\n", "\n", - "fig,ax=plt.subplots()\n", + "fig, ax = plt.subplots()\n", "session.set_params(eye_tracking_z_threshold=1.0, eye_tracking_dilation_frames=1)\n", "eye_data1 = session.eye_tracking.copy(deep=True)\n", - "ax.plot(\n", - " eye_data1['time'],\n", - " eye_data1['likely_blink'],\n", - " color = 'black'\n", - ")\n", + "ax.plot(eye_data1[\"time\"], eye_data1[\"likely_blink\"], color=\"black\")\n", "\n", "\n", "session.set_params(eye_tracking_z_threshold=1.0, eye_tracking_dilation_frames=4)\n", "eye_data2 = session.eye_tracking.copy(deep=True)\n", - "ax.plot(\n", - " eye_data2['time'],\n", - " eye_data2['likely_blink'],\n", - " color = 'red',\n", - " linewidth=4,\n", - " alpha=0.5,\n", - " linestyle=':'\n", - ")\n", + "ax.plot(eye_data2[\"time\"], eye_data2[\"likely_blink\"], color=\"red\", linewidth=4, alpha=0.5, linestyle=\":\")\n", "\n", - "ax.legend(['dilation frames = 1', 'dilation frames = 4'])\n", - "ax.set_xlim(2580,2595)" + "ax.legend([\"dilation frames = 1\", \"dilation frames = 4\"])\n", + "ax.set_xlim(2580, 2595)" ] }, { @@ -2473,7 +2462,7 @@ ], "source": [ "# Trying to set parameters with incorrect types will result in them being *IGNORED*\n", - "session.set_params(eyey_tracking_z_threshold=2, eye_tracking_dilation_frames='hello')\n", + "session.set_params(eyey_tracking_z_threshold=2, eye_tracking_dilation_frames=\"hello\")\n", "session.eye_tracking" ] }, @@ -2904,7 +2893,7 @@ ], "source": [ "# Some more subtle examples where using incorrect types will result in user set params being ignored\n", - "session.set_params(eye_tracking_z_threshold='1', eye_tracking_dilation_frames=4.0)\n", + "session.set_params(eye_tracking_z_threshold=\"1\", eye_tracking_dilation_frames=4.0)\n", "session.eye_tracking" ] } diff --git a/doc_template/examples_root/examples/multicell/multi.py b/doc_template/examples_root/examples/multicell/multi.py index 8458b71840..a505eb426b 100644 --- a/doc_template/examples_root/examples/multicell/multi.py +++ b/doc_template/examples_root/examples/multicell/multi.py @@ -2,11 +2,12 @@ from utils import Utils import matplotlib + matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np -config = Config().load('config.json') +config = Config().load("config.json") # configure NEURON utils = Utils(config) @@ -31,11 +32,8 @@ h.run() # save output voltage to text file -data = np.transpose(np.vstack((vec["t"], - vec["v"][0], - vec["v"][1], - vec["v"][2]))) -np.savetxt('multicell.dat', data) +data = np.transpose(np.vstack((vec["t"], vec["v"][0], vec["v"][1], vec["v"][2]))) +np.savetxt("multicell.dat", data) # use matplotlib to plot to png image fig, axes = plt.subplots(nrows=3, ncols=1, sharex=True) @@ -43,4 +41,4 @@ axes[i].plot(vec["t"], vec["v"][i]) axes[i].set_title(utils.cells_data[i]["type"]) plt.tight_layout() -plt.savefig('multicell.png') +plt.savefig("multicell.png") diff --git a/doc_template/examples_root/examples/multicell/multicell_diff.py b/doc_template/examples_root/examples/multicell/multicell_diff.py index 990142e93b..c921dd00cd 100644 --- a/doc_template/examples_root/examples/multicell/multicell_diff.py +++ b/doc_template/examples_root/examples/multicell/multicell_diff.py @@ -1,14 +1,11 @@ -'''Simple script to compare the result of "python multi.py" with expected values.''' +"""Simple script to compare the result of "python multi.py" with expected values.""" + import numpy as np -result = np.loadtxt('multicell.dat', - dtype={'names': ('t', 'v0', 'v1', 'v2'), - 'formats': ('f4', 'f4', 'f4', 'f4')}) -expected = np.loadtxt('multicell_expected.dat', - dtype={'names': ('t', 'v0', 'v1', 'v2'), - 'formats': ('f4', 'f4', 'f4', 'f4')}) +result = np.loadtxt("multicell.dat", dtype={"names": ("t", "v0", "v1", "v2"), "formats": ("f4", "f4", "f4", "f4")}) +expected = np.loadtxt( + "multicell_expected.dat", dtype={"names": ("t", "v0", "v1", "v2"), "formats": ("f4", "f4", "f4", "f4")} +) -for trace in ['v0', 'v1', 'v2']: - print("%s matches expected values: %s" % (trace, - np.allclose(result[trace], - expected[trace]))) +for trace in ["v0", "v1", "v2"]: + print("%s matches expected values: %s" % (trace, np.allclose(result[trace], expected[trace]))) diff --git a/doc_template/examples_root/examples/multicell/utils.py b/doc_template/examples_root/examples/multicell/utils.py index 7e9ac8adb6..b7a7b20b5d 100644 --- a/doc_template/examples_root/examples/multicell/utils.py +++ b/doc_template/examples_root/examples/multicell/utils.py @@ -4,39 +4,39 @@ class Utils(HocUtils): _log = logging.getLogger(__name__) - + def __init__(self, description): super(Utils, self).__init__(description) self.stim = None self.stim_curr = None self.sampling_rate = None - + def generate_cells(self): - fit_ids = self.description.data['fit_ids'][0] - self.cells_data = self.description.data['biophys'][0]['cells'] + fit_ids = self.description.data["fit_ids"][0] + self.cells_data = self.description.data["biophys"][0]["cells"] self.cells = [] for cell_data in self.cells_data: cell = self.h.cell() self.cells.append(cell) - morphology_path = self.description.manifest.get_path('MORPHOLOGY_%s' % (cell_data['type'])) + morphology_path = self.description.manifest.get_path("MORPHOLOGY_%s" % (cell_data["type"])) self.generate_morphology(cell, morphology_path) - self.load_cell_parameters(cell, fit_ids[cell_data['type']]) + self.load_cell_parameters(cell, fit_ids[cell_data["type"]]) def generate_morphology(self, cell, morph_filename): h = self.h - + swc = self.h.Import3d_SWC_read() swc.input(morph_filename) imprt = self.h.Import3d_GUI(swc, 0) imprt.instantiate(cell) - + for seg in cell.soma[0]: seg.area() for sec in cell.all: sec.nseg = 1 + 2 * int(sec.L / 40) - + cell.simplify_axon() for sec in cell.axonal: sec.L = 30 @@ -45,18 +45,18 @@ def generate_morphology(self, cell, morph_filename): cell.axon[0].connect(cell.soma[0], 0.5, 0) cell.axon[1].connect(cell.axon[0], 1, 0) h.define_shape() - + def load_cell_parameters(self, cell, type_index): - passive = self.description.data['fit'][type_index]['passive'][0] - conditions = self.description.data['fit'][type_index]['conditions'][0] - genome = self.description.data['fit'][type_index]['genome'] + passive = self.description.data["fit"][type_index]["passive"][0] + conditions = self.description.data["fit"][type_index]["conditions"][0] + genome = self.description.data["fit"][type_index]["genome"] # Set passive properties - cm_dict = dict([(c['section'], c['cm']) for c in passive['cm']]) + cm_dict = dict([(c["section"], c["cm"]) for c in passive["cm"]]) for sec in cell.all: - sec.Ra = passive['ra'] + sec.Ra = passive["ra"] sec.cm = cm_dict[sec.name().split(".")[1][:4]] - sec.insert('pas') + sec.insert("pas") for seg in sec: seg.pas.e = passive["e_pas"] @@ -67,9 +67,9 @@ def load_cell_parameters(self, cell, type_index): if p["mechanism"] != "": sec.insert(p["mechanism"]) setattr(sec, p["name"], p["value"]) - + # Set reversal potentials - for erev in conditions['erev']: + for erev in conditions["erev"]: sections = [s for s in cell.all if s.name().split(".")[1][:4] == erev["section"]] for sec in sections: sec.ena = erev["ena"] @@ -103,13 +103,11 @@ def setup_iclamp_step(self, target_cell, amp, delay, dur): self.stim.dur = dur def record_values(self): - vec = { "v": [], - "t": self.h.Vector() } - + vec = {"v": [], "t": self.h.Vector()} + for i, cell in enumerate(self.cells): vec["v"].append(self.h.Vector()) vec["v"][i].record(cell.soma[0](0.5)._ref_v) vec["t"].record(self.h._ref_t) - + return vec - \ No newline at end of file diff --git a/doc_template/examples_root/examples/nb/aligning_behavioral_data_to_task_events_with_the_stimulus_and_trials_tables.ipynb b/doc_template/examples_root/examples/nb/aligning_behavioral_data_to_task_events_with_the_stimulus_and_trials_tables.ipynb index 37ef6d9aab..5b284ceb61 100644 --- a/doc_template/examples_root/examples/nb/aligning_behavioral_data_to_task_events_with_the_stimulus_and_trials_tables.ipynb +++ b/doc_template/examples_root/examples/nb/aligning_behavioral_data_to_task_events_with_the_stimulus_and_trials_tables.ipynb @@ -77,9 +77,9 @@ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", - "from allensdk.brain_observatory.behavior.behavior_project_cache.\\\n", - " behavior_neuropixels_project_cache \\\n", - " import VisualBehaviorNeuropixelsProjectCache" + "from allensdk.brain_observatory.behavior.behavior_project_cache.behavior_neuropixels_project_cache import (\n", + " VisualBehaviorNeuropixelsProjectCache,\n", + ")" ] }, { @@ -164,17 +164,16 @@ "\n", "is not deleted between instantiations of this cache\n", " warnings.warn(msg, MissingLocalManifestWarning)\n", - "ecephys_sessions.csv: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 64.7k/64.7k [00:00<00:00, 510kMB/s] \n", - "behavior_sessions.csv: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 562k/562k [00:00<00:00, 3.02MMB/s] \n", - "units.csv: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 132M/132M [00:05<00:00, 24.1MMB/s]\n", - "probes.csv: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 130k/130k [00:00<00:00, 798kMB/s] \n", - "channels.csv: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 27.9M/27.9M [00:01<00:00, 18.9MMB/s]\n" + "ecephys_sessions.csv: 100%|██████████| 64.7k/64.7k [00:00<00:00, 510kMB/s] \n", + "behavior_sessions.csv: 100%|██████████| 562k/562k [00:00<00:00, 3.02MMB/s] \n", + "units.csv: 100%|██████████| 132M/132M [00:05<00:00, 24.1MMB/s]\n", + "probes.csv: 100%|██████████| 130k/130k [00:00<00:00, 798kMB/s] \n", + "channels.csv: 100%|██████████| 27.9M/27.9M [00:01<00:00, 18.9MMB/s]\n" ] } ], "source": [ - "cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(\n", - " cache_dir=Path(output_dir))\n", + "cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(cache_dir=Path(output_dir))\n", "\n", "ecephys_sessions_table = cache.get_ecephys_session_table()" ] @@ -277,7 +276,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "ecephys_session_1065437523.nwb: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 3.20G/3.20G [02:26<00:00, 21.8MMB/s]\n", + "ecephys_session_1065437523.nwb: 100%|██████████| 3.20G/3.20G [02:26<00:00, 21.8MMB/s]\n", "/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/hdmf/spec/namespace.py:531: UserWarning: Ignoring cached namespace 'hdmf-common' version 1.5.1 because version 1.8.0 is already loaded.\n", " warn(\"Ignoring cached namespace '%s' version %s because version %s is already loaded.\"\n", "/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/hdmf/spec/namespace.py:531: UserWarning: Ignoring cached namespace 'core' version 2.5.0 because version 2.6.0-alpha is already loaded.\n", @@ -288,8 +287,7 @@ } ], "source": [ - "session = cache.get_ecephys_session(\n", - " ecephys_session_id=1065437523)" + "session = cache.get_ecephys_session(ecephys_session_id=1065437523)" ] }, { @@ -497,11 +495,9 @@ } ], "source": [ - "stimulus_presentations.groupby('stimulus_block')[['stimulus_block', \n", - " 'stimulus_name', \n", - " 'active', \n", - " 'duration', \n", - " 'start_time']].head(1)" + "stimulus_presentations.groupby(\"stimulus_block\")[\n", + " [\"stimulus_block\", \"stimulus_name\", \"active\", \"duration\", \"start_time\"]\n", + "].head(1)" ] }, { @@ -648,9 +644,9 @@ } ], "source": [ - "active_image_presentations = stimulus_presentations[stimulus_presentations['stimulus_block']==0]\n", - "passive_image_presentations = stimulus_presentations[stimulus_presentations['stimulus_block']==5]\n", - "np.all(active_image_presentations['image_name'].values == passive_image_presentations['image_name'].values )" + "active_image_presentations = stimulus_presentations[stimulus_presentations[\"stimulus_block\"] == 0]\n", + "passive_image_presentations = stimulus_presentations[stimulus_presentations[\"stimulus_block\"] == 5]\n", + "np.all(active_image_presentations[\"image_name\"].values == passive_image_presentations[\"image_name\"].values)" ] }, { @@ -752,22 +748,22 @@ } ], "source": [ - "#get the active behavior part of the stim table\n", - "active_behavior = stimulus_presentations[stimulus_presentations['active']]\n", + "# get the active behavior part of the stim table\n", + "active_behavior = stimulus_presentations[stimulus_presentations[\"active\"]]\n", "\n", - "#for now, let's leave out the omitted stimuli\n", - "active_behavior_no_omissions = active_behavior[~active_behavior['omitted']]\n", + "# for now, let's leave out the omitted stimuli\n", + "active_behavior_no_omissions = active_behavior[~active_behavior[\"omitted\"]]\n", "\n", - "#plot histogram of the stimulus durations\n", + "# plot histogram of the stimulus durations\n", "fig, axes = plt.subplots(1, 2)\n", - "fig.set_size_inches([10,4])\n", - "_ = axes[0].hist(active_behavior_no_omissions.duration, 50)\n", - "axes[0].set_xlabel('Flash Duration (s)')\n", - "axes[0].set_ylabel('Count')\n", + "fig.set_size_inches([10, 4])\n", + "_ = axes[0].hist(active_behavior_no_omissions.duration, 50)\n", + "axes[0].set_xlabel(\"Flash Duration (s)\")\n", + "axes[0].set_ylabel(\"Count\")\n", "\n", - "inter_flash = active_behavior_no_omissions['start_time'].diff()\n", + "inter_flash = active_behavior_no_omissions[\"start_time\"].diff()\n", "_ = axes[1].hist(inter_flash, np.arange(0.7, 1.6, 0.05))\n", - "axes[1].set_xlabel('Inter-flash interval (s)')\n", + "axes[1].set_xlabel(\"Inter-flash interval (s)\")\n", "axes[1].set_xticks(np.arange(0.75, 1.6, 0.25))" ] }, @@ -826,8 +822,8 @@ } ], "source": [ - "#look at the percentage of flashes that were omissions\n", - "np.sum(active_behavior.omitted)/len(active_behavior)" + "# look at the percentage of flashes that were omissions\n", + "np.sum(active_behavior.omitted) / len(active_behavior)" ] }, { @@ -1238,23 +1234,23 @@ "outputs": [], "source": [ "def get_change_time_from_stim_table(row):\n", - " '''\n", + " \"\"\"\n", " Given a particular row in the trials table,\n", - " find the corresponding change time in the \n", + " find the corresponding change time in the\n", " stimulus presentations table\n", - " '''\n", + " \"\"\"\n", " table = stimulus_presentations\n", - " change_frame = row['change_frame']\n", + " change_frame = row[\"change_frame\"]\n", " if np.isnan(change_frame) or change_frame < 0:\n", " return np.nan\n", - " \n", - " change_time = table[table.start_frame==change_frame]\\\n", - " ['start_time'].values[0]\n", - " \n", + "\n", + " change_time = table[table.start_frame == change_frame][\"start_time\"].values[0]\n", + "\n", " return change_time\n", "\n", + "\n", "change_times = trials.apply(get_change_time_from_stim_table, axis=1)\n", - "trials['change_time_with_display_delay'] = change_times" + "trials[\"change_time_with_display_delay\"] = change_times" ] }, { @@ -1362,19 +1358,19 @@ ], "source": [ "# filter for the hit trials\n", - "hit_trials = trials[trials['hit']]\n", + "hit_trials = trials[trials[\"hit\"]]\n", "\n", "# find the time of the first lick after each change\n", - "lick_indices = np.searchsorted(licks.timestamps, hit_trials['change_time_with_display_delay'])\n", + "lick_indices = np.searchsorted(licks.timestamps, hit_trials[\"change_time_with_display_delay\"])\n", "first_lick_times = licks.timestamps.values[lick_indices]\n", - "response_latencies = first_lick_times - hit_trials['change_time_with_display_delay']\n", + "response_latencies = first_lick_times - hit_trials[\"change_time_with_display_delay\"]\n", "\n", "# plot the latencies\n", "fig, ax = plt.subplots()\n", - "fig.suptitle('Response Latency Histogram for Hit trials')\n", + "fig.suptitle(\"Response Latency Histogram for Hit trials\")\n", "ax.hist(response_latencies, bins=np.linspace(-0.1, 0.8, 50))\n", - "ax.set_xlabel('Time from change (s)')\n", - "ax.set_ylabel('Trial count')" + "ax.set_xlabel(\"Time from change (s)\")\n", + "ax.set_ylabel(\"Trial count\")" ] }, { @@ -1667,7 +1663,7 @@ " \n", " \n", "\n", - "

5 rows \u00d7 23 columns

\n", + "

5 rows × 23 columns

\n", "" ], "text/plain": [ @@ -1944,7 +1940,7 @@ " \n", " \n", "\n", - "

5 rows \u00d7 23 columns

\n", + "

5 rows × 23 columns

\n", "" ], "text/plain": [ @@ -1989,7 +1985,7 @@ } ], "source": [ - "eye_tracking_noblinks = eye_tracking[~eye_tracking['likely_blink']]\n", + "eye_tracking_noblinks = eye_tracking[~eye_tracking[\"likely_blink\"]]\n", "eye_tracking_noblinks.head()" ] }, @@ -2265,52 +2261,56 @@ } ], "source": [ - "time_before = 3.0 #how much time to plot before the reward\n", - "time_after = 3.0 #how much time to plot after the reward\n", - "reward_time = session.rewards.iloc[15]['timestamps'] #get a random reward time\n", - "\n", - "#Get running data aligned to this reward\n", - "trial_running = running_speed.query('timestamps >= {} and timestamps <= {} '.\n", - " format(reward_time-time_before, reward_time+time_after))\n", - "\n", - "#Get pupil data aligned to this reward\n", - "trial_pupil_area = eye_tracking_noblinks.query('timestamps >= {} and timestamps <= {} '.\n", - " format(reward_time-time_before, reward_time+time_after))\n", - "\n", - "#Get stimulus presentations around this reward\n", - "behavior_presentations = stimulus_presentations[stimulus_presentations['active']]\n", - "behavior_presentations = behavior_presentations[~behavior_presentations['omitted']]\n", - "trial_stimuli = behavior_presentations.query('end_time >= {} and start_time <= {}'.\n", - " format(reward_time-time_before, reward_time+time_after))\n", - "\n", - "#Get licking aligned to this reward\n", - "trial_licking = licks.query('timestamps >= {} and timestamps <= {} '.\n", - " format(reward_time-time_before, reward_time+time_after))\n", - "\n", - "\n", - "#Plot running, pupil area and licks\n", + "time_before = 3.0 # how much time to plot before the reward\n", + "time_after = 3.0 # how much time to plot after the reward\n", + "reward_time = session.rewards.iloc[15][\"timestamps\"] # get a random reward time\n", + "\n", + "# Get running data aligned to this reward\n", + "trial_running = running_speed.query(\n", + " \"timestamps >= {} and timestamps <= {} \".format(reward_time - time_before, reward_time + time_after)\n", + ")\n", + "\n", + "# Get pupil data aligned to this reward\n", + "trial_pupil_area = eye_tracking_noblinks.query(\n", + " \"timestamps >= {} and timestamps <= {} \".format(reward_time - time_before, reward_time + time_after)\n", + ")\n", + "\n", + "# Get stimulus presentations around this reward\n", + "behavior_presentations = stimulus_presentations[stimulus_presentations[\"active\"]]\n", + "behavior_presentations = behavior_presentations[~behavior_presentations[\"omitted\"]]\n", + "trial_stimuli = behavior_presentations.query(\n", + " \"end_time >= {} and start_time <= {}\".format(reward_time - time_before, reward_time + time_after)\n", + ")\n", + "\n", + "# Get licking aligned to this reward\n", + "trial_licking = licks.query(\n", + " \"timestamps >= {} and timestamps <= {} \".format(reward_time - time_before, reward_time + time_after)\n", + ")\n", + "\n", + "\n", + "# Plot running, pupil area and licks\n", "fig, axr = plt.subplots()\n", - "fig.set_size_inches(14,6)\n", - "axr.plot(trial_running['timestamps'], trial_running['speed'], 'k')\n", + "fig.set_size_inches(14, 6)\n", + "axr.plot(trial_running[\"timestamps\"], trial_running[\"speed\"], \"k\")\n", "axp = axr.twinx()\n", - "axp.plot(trial_pupil_area['timestamps'], trial_pupil_area['pupil_area'], 'g')\n", - "rew_handle, = axr.plot(reward_time, 0, 'db', markersize=10)\n", - "lick_handle, = axr.plot(trial_licking['timestamps'], np.zeros(len(trial_licking['timestamps'])), 'mo')\n", - "axr.legend([rew_handle, lick_handle], ['reward', 'licks'])\n", - "\n", - "axr.set_ylabel('running speed (cm/s)')\n", - "axp.set_ylabel('pupil area\\n$(pixels^2)$')\n", - "axr.set_xlabel('Experiment time (s)')\n", - "\n", - "axp.yaxis.label.set_color('g')\n", - "axp.spines['right'].set_color('g')\n", - "axp.tick_params(axis='y', colors='g')\n", - "\n", - "#Plot the image flashes as grey bars. \n", - "colors = ['0.3', '0.8']\n", - "stimulus_colors = {stim: c for stim,c in zip(trial_stimuli['image_name'].unique(), colors)}\n", + "axp.plot(trial_pupil_area[\"timestamps\"], trial_pupil_area[\"pupil_area\"], \"g\")\n", + "(rew_handle,) = axr.plot(reward_time, 0, \"db\", markersize=10)\n", + "(lick_handle,) = axr.plot(trial_licking[\"timestamps\"], np.zeros(len(trial_licking[\"timestamps\"])), \"mo\")\n", + "axr.legend([rew_handle, lick_handle], [\"reward\", \"licks\"])\n", + "\n", + "axr.set_ylabel(\"running speed (cm/s)\")\n", + "axp.set_ylabel(\"pupil area\\n$(pixels^2)$\")\n", + "axr.set_xlabel(\"Experiment time (s)\")\n", + "\n", + "axp.yaxis.label.set_color(\"g\")\n", + "axp.spines[\"right\"].set_color(\"g\")\n", + "axp.tick_params(axis=\"y\", colors=\"g\")\n", + "\n", + "# Plot the image flashes as grey bars.\n", + "colors = [\"0.3\", \"0.8\"]\n", + "stimulus_colors = {stim: c for stim, c in zip(trial_stimuli[\"image_name\"].unique(), colors)}\n", "for idx, stimulus in trial_stimuli.iterrows():\n", - " axr.axvspan(stimulus['start_time'], stimulus['end_time'], color=stimulus_colors[stimulus['image_name']], alpha=0.5)" + " axr.axvspan(stimulus[\"start_time\"], stimulus[\"end_time\"], color=stimulus_colors[stimulus[\"image_name\"]], alpha=0.5)" ] }, { diff --git a/doc_template/examples_root/examples/nb/api_modernization/updated_vs_legacy_api.ipynb b/doc_template/examples_root/examples/nb/api_modernization/updated_vs_legacy_api.ipynb index ae4748cfda..85ef4da286 100644 --- a/doc_template/examples_root/examples/nb/api_modernization/updated_vs_legacy_api.ipynb +++ b/doc_template/examples_root/examples/nb/api_modernization/updated_vs_legacy_api.ipynb @@ -63,7 +63,7 @@ } ], "source": [ - "dff_trace = session.dff_traces.loc[cell_specimen_id]['dff']\n", + "dff_trace = session.dff_traces.loc[cell_specimen_id][\"dff\"]\n", "timestamps = session.ophys_timestamps\n", "\n", "_ = plt.plot(timestamps, dff_trace)" diff --git a/doc_template/examples_root/examples/nb/behavior_ophys_session.ipynb b/doc_template/examples_root/examples/nb/behavior_ophys_session.ipynb index b4f8d1253d..2168170f0b 100644 --- a/doc_template/examples_root/examples/nb/behavior_ophys_session.ipynb +++ b/doc_template/examples_root/examples/nb/behavior_ophys_session.ipynb @@ -67,7 +67,7 @@ " \"specimen_830940327\",\n", " \"ophys_session_878436988\",\n", " \"ophys_experiment_879332693\",\n", - " \"behavior_ophys_session_879332693.nwb\"\n", + " \"behavior_ophys_session_879332693.nwb\",\n", ")\n", "session = BehaviorOphysSession.from_nwb_path(full_filepath)" ] @@ -195,14 +195,13 @@ } ], "source": [ - "\n", "print(session.ophys_experiment_id)\n", "print()\n", "for key, val in session.metadata.items():\n", - " print('{}: {}'.format(key, val))\n", + " print(\"{}: {}\".format(key, val))\n", "print()\n", "for key, val in session.task_parameters.items():\n", - " print('{}: {}'.format(key, val))" + " print(\"{}: {}\".format(key, val))" ] }, { @@ -1017,7 +1016,7 @@ "fig, ax = plt.subplots(1, 3)\n", "ax[0].imshow(sitk.GetArrayFromImage(session.max_projection))\n", "ax[1].imshow(sitk.GetArrayFromImage(session.average_projection))\n", - "ax[2].imshow(sitk.GetArrayFromImage(session.segmentation_mask_image))\n" + "ax[2].imshow(sitk.GetArrayFromImage(session.segmentation_mask_image))" ] }, { @@ -1055,8 +1054,8 @@ ], "source": [ "plt.plot(session.running_speed.timestamps, session.running_speed.values)\n", - "_ = plt.xlabel('Time (Second)')\n", - "_ = plt.ylabel('Speed (cm/Second)')" + "_ = plt.xlabel(\"Time (Second)\")\n", + "_ = plt.ylabel(\"Speed (cm/Second)\")" ] }, { @@ -1384,7 +1383,7 @@ "outputs": [], "source": [ "def cut_trace(start_time, stop_time, timestamps, values):\n", - " inds = np.nonzero(np.logical_and(timestamps>start_time, timestamps start_time, timestamps < stop_time))\n", " return values[inds]" ] }, @@ -1416,7 +1415,7 @@ "source": [ "def get_dff_by_epoch(epoch, session=None, cell_specimen_id=None):\n", " timestamps = session.ophys_timestamps[:]\n", - " values = session.dff_traces['dff'].loc[cell_specimen_id]\n", + " values = session.dff_traces[\"dff\"].loc[cell_specimen_id]\n", " return cut_trace_by_epoch(epoch, timestamps, values)" ] }, @@ -1456,12 +1455,13 @@ "def cut_trace_by_epoch(epoch, timestamps, values):\n", " return cut_trace(epoch.start_time, epoch.stop_time, timestamps, values)\n", "\n", + "\n", "class Epoch:\n", - " \n", " def __init__(self, start_time, stop_time):\n", " self.start_time = start_time\n", " self.stop_time = stop_time\n", - " \n", + "\n", + "\n", "e = Epoch(100, 100.25)\n", "cell_specimen_id = 879381164 # for example\n", "\n", @@ -1514,7 +1514,7 @@ "source": [ "for trial_index, row in session.trials.head().iterrows():\n", " cut_running_speed = cut_trace_by_epoch(row, session.running_speed.timestamps, session.running_speed.values)\n", - " print('trial {}: mean_speed = {:5.2f}'.format(trial_index, cut_running_speed.mean()))" + " print(\"trial {}: mean_speed = {:5.2f}\".format(trial_index, cut_running_speed.mean()))" ] }, { @@ -1541,7 +1541,6 @@ "outputs": [], "source": [ "def ragged_average(series):\n", - " \n", " nr = len(series)\n", " nc = max([len(arr) for arr in series.values])\n", "\n", @@ -1549,8 +1548,8 @@ " x.fill(np.nan)\n", " for ri in range(nr):\n", " data = series.iloc[ri]\n", - " x[ri,:len(data)] = data\n", - " \n", + " x[ri, : len(data)] = data\n", + "\n", " return np.nanmean(x, axis=0)" ] }, diff --git a/doc_template/examples_root/examples/nb/brain_observatory.ipynb b/doc_template/examples_root/examples/nb/brain_observatory.ipynb index 67bded38b3..9dcc81c2cc 100644 --- a/doc_template/examples_root/examples/nb/brain_observatory.ipynb +++ b/doc_template/examples_root/examples/nb/brain_observatory.ipynb @@ -99,7 +99,7 @@ }, "outputs": [], "source": [ - "output_dir = '.'" + "output_dir = \".\"" ] }, { @@ -138,12 +138,11 @@ "from allensdk.core.brain_observatory_cache import BrainObservatoryCache\n", "import pprint\n", "\n", - "# This class uses a 'manifest' to keep track of downloaded data and metadata. \n", + "# This class uses a 'manifest' to keep track of downloaded data and metadata.\n", "# All downloaded files will be stored relative to the directory holding the manifest\n", - "# file. If 'manifest_file' is a relative path (as it is below), it will be \n", + "# file. If 'manifest_file' is a relative path (as it is below), it will be\n", "# saved relative to your working directory. It can also be an absolute path.\n", - "boc = BrainObservatoryCache(\n", - " manifest_file=str(Path(output_dir) / 'brain_observatory_manifest.json'))\n", + "boc = BrainObservatoryCache(manifest_file=str(Path(output_dir) / \"brain_observatory_manifest.json\"))\n", "\n", "# Download a list of all targeted areas\n", "targeted_structures = boc.get_all_targeted_structures()\n", @@ -185,12 +184,11 @@ "source": [ "from allensdk.core.brain_observatory_cache import BrainObservatoryCache\n", "\n", - "# This class uses a 'manifest' to keep track of downloaded data and metadata. \n", + "# This class uses a 'manifest' to keep track of downloaded data and metadata.\n", "# All downloaded files will be stored relative to the directory holding the manifest\n", - "# file. If 'manifest_file' is a relative path (as it is below), it will be \n", + "# file. If 'manifest_file' is a relative path (as it is below), it will be\n", "# saved relative to your working directory. It can also be an absolute path.\n", - "boc = BrainObservatoryCache(\n", - " manifest_file=str(Path(output_dir) / 'brain_observatory_manifest.json'))\n", + "boc = BrainObservatoryCache(manifest_file=str(Path(output_dir) / \"brain_observatory_manifest.json\"))\n", "\n", "# Download a list of all targeted areas\n", "targeted_structures = boc.get_all_targeted_structures()\n", @@ -232,12 +230,11 @@ "source": [ "from allensdk.core.brain_observatory_cache import BrainObservatoryCache\n", "\n", - "# This class uses a 'manifest' to keep track of downloaded data and metadata. \n", + "# This class uses a 'manifest' to keep track of downloaded data and metadata.\n", "# All downloaded files will be stored relative to the directory holding the manifest\n", - "# file. If 'manifest_file' is a relative path (as it is below), it will be \n", + "# file. If 'manifest_file' is a relative path (as it is below), it will be\n", "# saved relative to your working directory. It can also be an absolute path.\n", - "boc = BrainObservatoryCache(\n", - " manifest_file=str(Path(output_dir) / 'brain_observatory_manifest.json'))\n", + "boc = BrainObservatoryCache(manifest_file=str(Path(output_dir) / \"brain_observatory_manifest.json\"))\n", "\n", "# Download a list of all targeted areas\n", "targeted_structures = boc.get_all_targeted_structures()\n", @@ -278,7 +275,7 @@ ], "source": [ "# Download experiment containers for VISp experiments\n", - "visp_ecs = boc.get_experiment_containers(targeted_structures=['VISp'])\n", + "visp_ecs = boc.get_experiment_containers(targeted_structures=[\"VISp\"])\n", "print(\"all VISp experiment containers: %d\" % len(visp_ecs))" ] }, @@ -417,7 +414,7 @@ } ], "source": [ - "# Download a list of all cre driver lines \n", + "# Download a list of all cre driver lines\n", "cre_lines = boc.get_all_cre_lines()\n", "print(\"all cre lines:\\n\")\n", "pprint.pprint(cre_lines)" @@ -468,7 +465,7 @@ ], "source": [ "# Download experiment containers for Cux2 experiments\n", - "cux2_ecs = boc.get_experiment_containers(cre_lines=['Cux2-CreERT2'])\n", + "cux2_ecs = boc.get_experiment_containers(cre_lines=[\"Cux2-CreERT2\"])\n", "print(\"Cux2 experiments: %d\\n\" % len(cux2_ecs))\n", "\n", "print(\"Example experiment container record:\")\n", @@ -564,7 +561,7 @@ ], "source": [ "# Find all of the experiments for an experiment container\n", - "cux2_ec_id = cux2_ecs[0]['id']\n", + "cux2_ec_id = cux2_ecs[0][\"id\"]\n", "exps = boc.get_ophys_experiments(experiment_container_ids=[cux2_ec_id])\n", "print(\"Experiments for experiment_container_id %d: %d\\n\" % (cux2_ec_id, len(exps)))\n", "pprint.pprint(exps)" @@ -637,11 +634,10 @@ "import allensdk.brain_observatory.stimulus_info as stim_info\n", "\n", "# pick one of the cux2 experiment containers\n", - "cux2_ec_id = cux2_ecs[-1]['id']\n", + "cux2_ec_id = cux2_ecs[-1][\"id\"]\n", "\n", "# Find the experiment with the static static gratings stimulus\n", - "exp = boc.get_ophys_experiments(experiment_container_ids=[cux2_ec_id], \n", - " stimuli=[stim_info.STATIC_GRATINGS])[0]\n", + "exp = boc.get_ophys_experiments(experiment_container_ids=[cux2_ec_id], stimuli=[stim_info.STATIC_GRATINGS])[0]\n", "print(\"Experiment with static gratings:\")\n", "pprint.pprint(exp)" ] @@ -722,7 +718,7 @@ } ], "source": [ - "data_set = boc.get_ophys_experiment_data(exp['id'])\n", + "data_set = boc.get_ophys_experiment_data(exp[\"id\"])\n", "\n", "# print out the metadata available in the NWB file\n", "pprint.pprint(data_set.get_metadata())" @@ -795,16 +791,16 @@ "print(\"total cells: %d\" % len(cells))\n", "\n", "# find direction selective cells in VISp\n", - "visp_ec_ids = [ ec['id'] for ec in visp_ecs ]\n", - "visp_cells = cells[cells['experiment_container_id'].isin(visp_ec_ids)]\n", + "visp_ec_ids = [ec[\"id\"] for ec in visp_ecs]\n", + "visp_cells = cells[cells[\"experiment_container_id\"].isin(visp_ec_ids)]\n", "print(\"VISp cells: %d\" % len(visp_cells))\n", "\n", "# significant response to drifting gratings stimulus\n", - "sig_cells = visp_cells[visp_cells['p_dg'] < 0.05]\n", + "sig_cells = visp_cells[visp_cells[\"p_dg\"] < 0.05]\n", "print(\"cells with sig. response to drifting gratings: %d\" % len(sig_cells))\n", "\n", "# direction selective cells\n", - "dsi_cells = sig_cells[(sig_cells['g_dsi_dg'] > 0.9)]\n", + "dsi_cells = sig_cells[(sig_cells[\"g_dsi_dg\"] > 0.9)]\n", "print(\"direction-selective cells: %d\" % len(dsi_cells))" ] }, @@ -878,7 +874,7 @@ "import allensdk.brain_observatory.stimulus_info as stim_info\n", "\n", "# find experiment containers for those cells\n", - "dsi_ec_ids = dsi_cells['experiment_container_id'].unique()\n", + "dsi_ec_ids = dsi_cells[\"experiment_container_id\"].unique()\n", "print(\"total dsi experiment containers: %d\" % len(dsi_ec_ids))\n", "\n", "# Download the ophys experiments containing the drifting gratings stimulus for VISp experiment containers\n", @@ -973,10 +969,11 @@ "dsi_cell = dsi_cells.iloc[0]\n", "\n", "# figure out which ophys experiment has the drifting gratings stimulus for that cell\n", - "cell_exp = boc.get_ophys_experiments(cell_specimen_ids=[dsi_cell['cell_specimen_id']],\n", - " stimuli=[stim_info.DRIFTING_GRATINGS])[0]\n", + "cell_exp = boc.get_ophys_experiments(\n", + " cell_specimen_ids=[dsi_cell[\"cell_specimen_id\"]], stimuli=[stim_info.DRIFTING_GRATINGS]\n", + ")[0]\n", "\n", - "data_set = boc.get_ophys_experiment_data(cell_exp['id'])\n", + "data_set = boc.get_ophys_experiment_data(cell_exp[\"id\"])\n", "\n", "print(\"Metadata from NWB file:\")\n", "pprint.pprint(data_set.get_metadata())\n", @@ -1031,7 +1028,7 @@ }, "outputs": [], "source": [ - "dsi_cell_id = dsi_cell['cell_specimen_id']\n", + "dsi_cell_id = dsi_cell[\"cell_specimen_id\"]\n", "time, raw_traces = data_set.get_fluorescence_traces(cell_specimen_ids=[dsi_cell_id])\n", "_, demixed_traces = data_set.get_demixed_traces(cell_specimen_ids=[dsi_cell_id])\n", "_, neuropil_traces = data_set.get_neuropil_traces(cell_specimen_ids=[dsi_cell_id])\n", @@ -1106,29 +1103,30 @@ ], "source": [ "from matplotlib import pyplot as plt\n", + "\n", "%matplotlib inline\n", "\n", "# plot raw and corrected ROI trace\n", - "plt.figure(figsize=(14,4))\n", + "plt.figure(figsize=(14, 4))\n", "plt.title(\"Raw Fluorescence Trace\")\n", "plt.plot(time, raw_traces[0])\n", "plt.show()\n", "\n", - "plt.figure(figsize=(14,4))\n", + "plt.figure(figsize=(14, 4))\n", "plt.title(\"Demixed Fluorescence Trace\")\n", "plt.plot(time, demixed_traces[0])\n", "plt.show()\n", "\n", - "plt.figure(figsize=(14,4))\n", + "plt.figure(figsize=(14, 4))\n", "plt.title(\"Neuropil-corrected Fluorescence Trace\")\n", "plt.plot(time, corrected_traces[0])\n", "plt.show()\n", "\n", - "plt.figure(figsize=(14,4))\n", + "plt.figure(figsize=(14, 4))\n", "plt.title(\"dF/F Trace\")\n", - "# warning: dF/F can occasionally be one element longer or shorter \n", + "# warning: dF/F can occasionally be one element longer or shorter\n", "# than the time stamps for the original traces.\n", - "plt.plot(time[:len(dff_traces[0])], dff_traces[0])\n", + "plt.plot(time[: len(dff_traces[0])], dff_traces[0])\n", "plt.show()" ] }, @@ -1198,6 +1196,7 @@ "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", + "\n", "%matplotlib inline\n", "\n", "data_set = boc.get_ophys_experiment_data(510221121)\n", @@ -1209,22 +1208,22 @@ "roi_mask_list = data_set.get_roi_mask(cell_specimen_ids=cids)\n", "\n", "# plot each mask\n", - "f, axes = plt.subplots(1, len(cids)+2, figsize=(15, 3))\n", + "f, axes = plt.subplots(1, len(cids) + 2, figsize=(15, 3))\n", "for ax, roi_mask, cid in zip(axes[:-2], roi_mask_list, cids):\n", - " ax.imshow(roi_mask.get_mask_plane(), cmap='gray')\n", - " ax.set_title('cell %d' % cid)\n", + " ax.imshow(roi_mask.get_mask_plane(), cmap=\"gray\")\n", + " ax.set_title(\"cell %d\" % cid)\n", "\n", - "# make a mask of all ROIs in the experiment \n", + "# make a mask of all ROIs in the experiment\n", "all_roi_masks = data_set.get_roi_mask_array()\n", "combined_mask = all_roi_masks.max(axis=0)\n", "\n", - "axes[-2].imshow(combined_mask, cmap='gray')\n", - "axes[-2].set_title('all ROIs')\n", + "axes[-2].imshow(combined_mask, cmap=\"gray\")\n", + "axes[-2].set_title(\"all ROIs\")\n", "\n", "# show the movie max projection\n", "max_projection = data_set.get_max_projection()\n", - "axes[-1].imshow(max_projection, cmap='gray')\n", - "axes[-1].set_title('max projection')\n", + "axes[-1].imshow(max_projection, cmap=\"gray\")\n", + "axes[-1].set_title(\"max projection\")\n", "\n", "plt.show()" ] @@ -1336,53 +1335,54 @@ ], "source": [ "from matplotlib import pyplot as plt\n", + "\n", "%matplotlib inline\n", "\n", "# filter for visually responding, selective cells\n", - "vis_cells = (dg_peak.ptest_dg < 0.05) & (dg_peak.peak_dff_dg > 3)\n", + "vis_cells = (dg_peak.ptest_dg < 0.05) & (dg_peak.peak_dff_dg > 3)\n", "osi_cells = vis_cells & (dg_peak.osi_dg > 0.5) & (dg_peak.osi_dg <= 1.5)\n", "dsi_cells = vis_cells & (dg_peak.dsi_dg > 0.5) & (dg_peak.dsi_dg <= 1.5)\n", "\n", "# 2-d tf vs. ori histogram\n", "# tfval = 0 is used for the blank sweep, so we are ignoring it here\n", - "os = np.zeros((len(dg.orivals), len(dg.tfvals)-1))\n", - "ds = np.zeros((len(dg.orivals), len(dg.tfvals)-1))\n", + "os = np.zeros((len(dg.orivals), len(dg.tfvals) - 1))\n", + "ds = np.zeros((len(dg.orivals), len(dg.tfvals) - 1))\n", + "\n", + "for i, trial in dg_peak[osi_cells].iterrows():\n", + " os[trial.ori_dg, trial.tf_dg - 1] += 1\n", "\n", - "for i,trial in dg_peak[osi_cells].iterrows():\n", - " os[trial.ori_dg, trial.tf_dg-1] += 1\n", - " \n", - "for i,trial in dg_peak[dsi_cells].iterrows():\n", - " ds[trial.ori_dg, trial.tf_dg-1] += 1\n", + "for i, trial in dg_peak[dsi_cells].iterrows():\n", + " ds[trial.ori_dg, trial.tf_dg - 1] += 1\n", "\n", "max_count = max(os.max(), ds.max())\n", "\n", - "fig, (ax1, ax2) = plt.subplots(1,2)\n", + "fig, (ax1, ax2) = plt.subplots(1, 2)\n", "\n", "# plot direction selectivity\n", - "im = ax1.imshow(ds, clim=[0,max_count], cmap='hot', interpolation='nearest')\n", - "ax1.set_xlabel('temporal frequency')\n", - "ax1.set_ylabel('direction')\n", - "ax1.set_xticks(np.arange(len(dg.tfvals)-1))\n", + "im = ax1.imshow(ds, clim=[0, max_count], cmap=\"hot\", interpolation=\"nearest\")\n", + "ax1.set_xlabel(\"temporal frequency\")\n", + "ax1.set_ylabel(\"direction\")\n", + "ax1.set_xticks(np.arange(len(dg.tfvals) - 1))\n", "ax1.set_xticklabels(dg.tfvals[1:])\n", "ax1.set_yticks(np.arange(len(dg.orivals)))\n", "ax1.set_yticklabels(dg.orivals)\n", - "ax1.set_title('direction selective cells')\n", + "ax1.set_title(\"direction selective cells\")\n", "\n", "# plot orientation selectivity\n", - "im = ax2.imshow(os, clim=[0,max_count], cmap='hot', interpolation='nearest')\n", - "ax2.set_xlabel('temporal frequency')\n", - "ax2.set_ylabel('orientation')\n", - "ax2.set_xticks(np.arange(len(dg.tfvals)-1))\n", + "im = ax2.imshow(os, clim=[0, max_count], cmap=\"hot\", interpolation=\"nearest\")\n", + "ax2.set_xlabel(\"temporal frequency\")\n", + "ax2.set_ylabel(\"orientation\")\n", + "ax2.set_xticks(np.arange(len(dg.tfvals) - 1))\n", "ax2.set_xticklabels(dg.tfvals[1:])\n", "ax2.set_yticks(np.arange(len(dg.orivals)))\n", "ax2.set_yticklabels(dg.orivals)\n", - "ax2.set_title('orientation selective cells')\n", + "ax2.set_title(\"orientation selective cells\")\n", "\n", "# plot a colorbar\n", "fig.subplots_adjust(right=0.9)\n", "cbar_ax = fig.add_axes([0.95, 0.05, 0.05, 0.85])\n", "cbar = fig.colorbar(im, cax=cbar_ax)\n", - "cbar.set_ticks(np.arange(0, max_count, 2)+0.5)\n", + "cbar.set_ticks(np.arange(0, max_count, 2) + 0.5)\n", "cbar.set_ticklabels(np.arange(0, max_count, 2, dtype=int))\n", "\n", "plt.show()" @@ -1459,9 +1459,9 @@ "_, neuropil_traces = data_set.get_neuropil_traces(cell_specimen_ids=[csid])\n", "\n", "results = estimate_contamination_ratios(demixed_traces[0], neuropil_traces[0])\n", - "correction = demixed_traces[0] - results['r'] * neuropil_traces[0]\n", - "print(\"r = %f\" % results['r'])\n", - "print(\"max error = %f\" % results['err'])" + "correction = demixed_traces[0] - results[\"r\"] * neuropil_traces[0]\n", + "print(\"r = %f\" % results[\"r\"])\n", + "print(\"max error = %f\" % results[\"err\"])" ] }, { @@ -1522,7 +1522,7 @@ "source": [ "_, corrected_traces = data_set.get_corrected_fluorescence_traces(cell_specimen_ids=[csid])\n", "\n", - "plt.figure(figsize=(14,4))\n", + "plt.figure(figsize=(14, 4))\n", "plt.title(\"Neuropil-corrected Fluorescence Trace\")\n", "plt.plot(time, corrected_traces[0])\n", "plt.show()" @@ -1587,10 +1587,10 @@ "source": [ "from allensdk.brain_observatory.dff import compute_dff_windowed_mode\n", "\n", - "plt.figure(figsize=(14,4))\n", + "plt.figure(figsize=(14, 4))\n", "plt.title(\"dF/F Trace\")\n", "dff = compute_dff_windowed_mode(np.array(corrected_traces))\n", - "plt.plot(time, dff[0,:])\n", + "plt.plot(time, dff[0, :])\n", "plt.show()" ] }, @@ -1661,12 +1661,13 @@ ], "source": [ "from matplotlib import pyplot as plt\n", + "\n", "%matplotlib inline\n", "\n", "data_set = boc.get_ophys_experiment_data(501940850)\n", "\n", "dxcm, dxtime = data_set.get_running_speed()\n", - "plt.figure(figsize=(14,4))\n", + "plt.figure(figsize=(14, 4))\n", "plt.plot(dxtime, dxcm)\n", "plt.show()" ] @@ -1729,16 +1730,17 @@ ], "source": [ "from matplotlib import pyplot as plt\n", + "\n", "%matplotlib inline\n", "\n", "data_set = boc.get_ophys_experiment_data(501940850)\n", "\n", "mc = data_set.get_motion_correction()\n", "\n", - "plt.figure(figsize=(14,4))\n", + "plt.figure(figsize=(14, 4))\n", "plt.plot(mc.timestamp, mc.x_motion)\n", "plt.plot(mc.timestamp, mc.y_motion)\n", - "plt.legend(['x motion','y motion'])\n", + "plt.legend([\"x motion\", \"y motion\"])\n", "plt.show()" ] }, @@ -1828,6 +1830,7 @@ "source": [ "from allensdk.brain_observatory.brain_observatory_exceptions import NoEyeTrackingException\n", "from matplotlib import pyplot as plt\n", + "\n", "%matplotlib inline\n", "\n", "# example with no eye tracking data\n", @@ -1838,23 +1841,23 @@ " print(\"No eye tracking for experiment %s.\" % data_set.get_metadata()[\"ophys_experiment_id\"])\n", "\n", "data_set = boc.get_ophys_experiment_data(569407590)\n", - " \n", + "\n", "# looking at azimuth and altitude over time\n", "# by default locations returned are (azimuth, altitude)\n", "# passing as_spherical=False to get_pupil_location will return (x,y) in cm\n", "timestamps, locations = data_set.get_pupil_location()\n", - "plt.figure(figsize=(14,4))\n", + "plt.figure(figsize=(14, 4))\n", "plt.plot(timestamps, locations.T[0])\n", "plt.plot(timestamps, locations.T[1])\n", "plt.title(\"Eye position over time\")\n", "plt.xlabel(\"time (s)\")\n", "plt.ylabel(\"angle (deg)\")\n", - "plt.legend(['azimuth', 'altitude'])\n", + "plt.legend([\"azimuth\", \"altitude\"])\n", "plt.show()\n", "\n", - "#pupil size over time\n", + "# pupil size over time\n", "timestamps, area = data_set.get_pupil_size()\n", - "plt.figure(figsize=(14,4))\n", + "plt.figure(figsize=(14, 4))\n", "plt.plot(timestamps, area)\n", "plt.title(\"Pupil size over time\")\n", "plt.xlabel(\"time (s)\")\n", @@ -1864,7 +1867,7 @@ "\n", "# scatter of gaze positions over approximate screen area\n", "plt.figure()\n", - "plt.scatter(locations.T[0], locations.T[1], s=2, c=\"m\", edgecolor=['none'])\n", + "plt.scatter(locations.T[0], locations.T[1], s=2, c=\"m\", edgecolor=[\"none\"])\n", "plt.title(\"Eye position scatter plot\")\n", "plt.xlim(-70, 70)\n", "plt.ylim(-60, 60)\n", diff --git a/doc_template/examples_root/examples/nb/brain_observatory_analysis.ipynb b/doc_template/examples_root/examples/nb/brain_observatory_analysis.ipynb index e6ce0dc919..e493d4e3f2 100644 --- a/doc_template/examples_root/examples/nb/brain_observatory_analysis.ipynb +++ b/doc_template/examples_root/examples/nb/brain_observatory_analysis.ipynb @@ -78,8 +78,8 @@ }, "outputs": [], "source": [ - "output_dir = '.'\n", - "resources_dir = Path.cwd().parent / 'resources'\n", + "output_dir = \".\"\n", + "resources_dir = Path.cwd().parent / \"resources\"\n", "RUN_LOCALLY_SPARSE_NOISE = True" ] }, @@ -111,8 +111,7 @@ "from allensdk.core.brain_observatory_cache import BrainObservatoryCache\n", "from pathlib import Path\n", "\n", - "boc = BrainObservatoryCache(\n", - " manifest_file=str(Path(output_dir) / 'brain_observatory_manifest.json'))" + "boc = BrainObservatoryCache(manifest_file=str(Path(output_dir) / \"brain_observatory_manifest.json\"))" ] }, { @@ -170,6 +169,7 @@ ], "source": [ "import matplotlib.pyplot as plt\n", + "\n", "%matplotlib inline\n", "\n", "from allensdk.brain_observatory.drifting_gratings import DriftingGratings\n", @@ -233,6 +233,7 @@ ], "source": [ "import numpy as np\n", + "\n", "specimen_id = 517425074\n", "cell_loc = data_set.get_cell_specimen_indices([specimen_id])[0]\n", "\n", @@ -304,13 +305,13 @@ ], "source": [ "# skip the blank sweep column of the temporal frequency dimension\n", - "plt.imshow(dg.response[:,1:,cell_loc,0], cmap='hot', interpolation='none')\n", + "plt.imshow(dg.response[:, 1:, cell_loc, 0], cmap=\"hot\", interpolation=\"none\")\n", "plt.xticks(range(5), dg.tfvals[1:])\n", "plt.yticks(range(8), dg.orivals)\n", "plt.xlabel(\"Temporal frequency (Hz)\", fontsize=20)\n", "plt.ylabel(\"Direction (deg)\", fontsize=20)\n", "plt.tick_params(labelsize=14)\n", - "cbar= plt.colorbar()\n", + "cbar = plt.colorbar()\n", "cbar.set_label(\"DF/F (%)\")" ] }, @@ -668,7 +669,7 @@ } ], "source": [ - "pref_trials = dg.stim_table[(dg.stim_table.orientation==pref_ori)&(dg.stim_table.temporal_frequency==pref_tf)]\n", + "pref_trials = dg.stim_table[(dg.stim_table.orientation == pref_ori) & (dg.stim_table.temporal_frequency == pref_tf)]\n", "pref_trials" ] }, @@ -717,7 +718,7 @@ }, "outputs": [], "source": [ - "subset = dg.sweep_response[(dg.stim_table.orientation==pref_ori)&(dg.stim_table.temporal_frequency==pref_tf)]" + "subset = dg.sweep_response[(dg.stim_table.orientation == pref_ori) & (dg.stim_table.temporal_frequency == pref_tf)]" ] }, { @@ -791,8 +792,10 @@ } ], "source": [ - "subset_mean = dg.mean_sweep_response[(dg.stim_table.orientation==pref_ori)&(dg.stim_table.temporal_frequency==pref_tf)]\n", - "subset_mean['dx']" + "subset_mean = dg.mean_sweep_response[\n", + " (dg.stim_table.orientation == pref_ori) & (dg.stim_table.temporal_frequency == pref_tf)\n", + "]\n", + "subset_mean[\"dx\"]" ] }, { @@ -851,20 +854,20 @@ } ], "source": [ - "trial_timestamps = np.arange(-1*dg.interlength, dg.interlength+dg.sweeplength, 1.)/dg.acquisition_rate\n", - "plt.figure(figsize=(8,20))\n", + "trial_timestamps = np.arange(-1 * dg.interlength, dg.interlength + dg.sweeplength, 1.0) / dg.acquisition_rate\n", + "plt.figure(figsize=(8, 20))\n", "for i in range(len(subset)):\n", - " plt.subplot(len(pref_trials),1,i+1)\n", - " plt.plot(trial_timestamps, subset[str(cell_loc)].iloc[i], color='k', lw=2)\n", - " plt.axvspan(0,2,color='red', alpha=0.3)\n", + " plt.subplot(len(pref_trials), 1, i + 1)\n", + " plt.plot(trial_timestamps, subset[str(cell_loc)].iloc[i], color=\"k\", lw=2)\n", + " plt.axvspan(0, 2, color=\"red\", alpha=0.3)\n", " plt.ylabel(\"DF/F (%)\")\n", - " plt.ylim(-10,600)\n", - " plt.yticks(range(0,700,200))\n", - " plt.text(2.5, 300, str(round(subset_mean['dx'].iloc[i],2))+\" cm/s\")\n", - " if i<(len(subset)-1):\n", + " plt.ylim(-10, 600)\n", + " plt.yticks(range(0, 700, 200))\n", + " plt.text(2.5, 300, str(round(subset_mean[\"dx\"].iloc[i], 2)) + \" cm/s\")\n", + " if i < (len(subset) - 1):\n", " plt.xticks([])\n", " else:\n", - " plt.xticks([-1,0,1,2,3])\n", + " plt.xticks([-1, 0, 1, 2, 3])\n", " plt.xlabel(\"Time (s)\")" ] }, @@ -1375,10 +1378,11 @@ "specimen_id = 587179530\n", "cell = boc.get_cell_specimens(ids=[specimen_id])[0]\n", "\n", - "exp = boc.get_ophys_experiments(experiment_container_ids=[cell['experiment_container_id']],\n", - " stimuli=[stim_info.LOCALLY_SPARSE_NOISE])[0]\n", - " \n", - "data_set = boc.get_ophys_experiment_data(exp['id'])" + "exp = boc.get_ophys_experiments(\n", + " experiment_container_ids=[cell[\"experiment_container_id\"]], stimuli=[stim_info.LOCALLY_SPARSE_NOISE]\n", + ")[0]\n", + "\n", + "data_set = boc.get_ophys_experiment_data(exp[\"id\"])" ] }, { @@ -1429,6 +1433,7 @@ "import matplotlib.pyplot as plt\n", "from IPython.display import display\n", "from PIL import Image\n", + "\n", "%matplotlib inline\n", "\n", "if RUN_LOCALLY_SPARSE_NOISE:\n", @@ -1437,22 +1442,22 @@ "\n", " cell_idx = data_set.get_cell_specimen_indices([specimen_id])[0]\n", "\n", - " plt.imshow(lsn.receptive_field[:,:,cell_idx,0], interpolation='nearest', cmap='PuRd', origin='lower')\n", + " plt.imshow(lsn.receptive_field[:, :, cell_idx, 0], interpolation=\"nearest\", cmap=\"PuRd\", origin=\"lower\")\n", " plt.title(\"on receptive field\")\n", " plt.show()\n", - " plt.imshow(lsn.receptive_field[:,:,cell_idx,1], interpolation='nearest', cmap='Blues', origin='lower')\n", + " plt.imshow(lsn.receptive_field[:, :, cell_idx, 1], interpolation=\"nearest\", cmap=\"Blues\", origin=\"lower\")\n", " plt.title(\"off receptive field\")\n", " plt.show()\n", "else:\n", " # read in LocallySparseNoise output and display it\n", - " off_receptive_field = Image.open(Path(resources_dir) /\n", - " 'brain_observatory_analysis' /\n", - " 'off_receptive_field.png').convert('RGB')\n", + " off_receptive_field = Image.open(\n", + " Path(resources_dir) / \"brain_observatory_analysis\" / \"off_receptive_field.png\"\n", + " ).convert(\"RGB\")\n", " display(off_receptive_field)\n", "\n", - " on_receptive_field = Image.open(Path(resources_dir) /\n", - " 'brain_observatory_analysis' /\n", - " 'on_receptive_field.png').convert('RGB')\n", + " on_receptive_field = Image.open(\n", + " Path(resources_dir) / \"brain_observatory_analysis\" / \"on_receptive_field.png\"\n", + " ).convert(\"RGB\")\n", " display(on_receptive_field)" ] } diff --git a/doc_template/examples_root/examples/nb/brain_observatory_monitor.ipynb b/doc_template/examples_root/examples/nb/brain_observatory_monitor.ipynb index 32103be1c7..f3bf7d3796 100644 --- a/doc_template/examples_root/examples/nb/brain_observatory_monitor.ipynb +++ b/doc_template/examples_root/examples/nb/brain_observatory_monitor.ipynb @@ -67,7 +67,7 @@ }, "outputs": [], "source": [ - "output_dir = '.'" + "output_dir = \".\"" ] }, { @@ -100,8 +100,8 @@ "import numpy as np\n", "import allensdk.brain_observatory.stimulus_info as si\n", "from allensdk.core.brain_observatory_cache import BrainObservatoryCache\n", - "boc = BrainObservatoryCache(\n", - " manifest_file=str(Path(output_dir) / 'brain_observatory_manifest.json'))" + "\n", + "boc = BrainObservatoryCache(manifest_file=str(Path(output_dir) / \"brain_observatory_manifest.json\"))" ] }, { @@ -352,7 +352,7 @@ } ], "source": [ - "img = plt.imshow(template, cmap=plt.cm.gray, interpolation='none')" + "img = plt.imshow(template, cmap=plt.cm.gray, interpolation=\"none\")" ] }, { @@ -515,51 +515,45 @@ } ], "source": [ - "y0, x0 = (5,12)\n", + "y0, x0 = (5, 12)\n", "y1, x1 = m.map_stimulus((y0, x0), si.LOCALLY_SPARSE_NOISE, si.NATURAL_MOVIE_ONE)\n", - "ym, xm = si.map_stimulus_coordinate_to_monitor_coordinate((y0, x0), \n", - " (m.n_pixels_r, m.n_pixels_c), \n", - " si.LOCALLY_SPARSE_NOISE)\n", + "ym, xm = si.map_stimulus_coordinate_to_monitor_coordinate(\n", + " (y0, x0), (m.n_pixels_r, m.n_pixels_c), si.LOCALLY_SPARSE_NOISE\n", + ")\n", "\n", - "img_lsn = nwb_dataset.get_stimulus_template('locally_sparse_noise')[0,:,:]\n", - "img_movie = nwb_dataset.get_stimulus_template('natural_movie_one')[0,:,:]\n", + "img_lsn = nwb_dataset.get_stimulus_template(\"locally_sparse_noise\")[0, :, :]\n", + "img_movie = nwb_dataset.get_stimulus_template(\"natural_movie_one\")[0, :, :]\n", "\n", "\n", - "fig, ax = plt.subplots(2,2, figsize=(10,5))\n", - "natural_movie_image = m.natural_movie_image_to_screen(img_movie, origin='upper')\n", - "m.show_image(natural_movie_image, \n", - " ax=ax[0,1], \n", - " show=False, \n", - " origin='upper', \n", - " mask=True)\n", - "ax[0,1].plot([xm], [ym], 'r.')\n", - "ax[0,1].set_title('Natural Movie One, Monitor\\n1200x1920')\n", - "lsn_image = m.lsn_image_to_screen(img_lsn, origin='upper')\n", - "m.show_image(lsn_image, \n", - " ax=ax[0,0], \n", - " show=False, \n", - " origin='upper', \n", - " mask=True)\n", - "ax[0,0].plot([xm], [ym], 'r.')\n", - "ax[0,0].set_title('Locally Sparse Noise, Monitor\\n1200x1920')\n", + "fig, ax = plt.subplots(2, 2, figsize=(10, 5))\n", + "natural_movie_image = m.natural_movie_image_to_screen(img_movie, origin=\"upper\")\n", + "m.show_image(natural_movie_image, ax=ax[0, 1], show=False, origin=\"upper\", mask=True)\n", + "ax[0, 1].plot([xm], [ym], \"r.\")\n", + "ax[0, 1].set_title(\"Natural Movie One, Monitor\\n1200x1920\")\n", + "lsn_image = m.lsn_image_to_screen(img_lsn, origin=\"upper\")\n", + "m.show_image(lsn_image, ax=ax[0, 0], show=False, origin=\"upper\", mask=True)\n", + "ax[0, 0].plot([xm], [ym], \"r.\")\n", + "ax[0, 0].set_title(\"Locally Sparse Noise, Monitor\\n1200x1920\")\n", "\n", - "ax[1,0].imshow(np.flipud(img_lsn), interpolation='none', cmap=plt.cm.gray, \n", - " extent=[0,img_lsn.shape[1],img_lsn.shape[0],0])\n", - "ax[1,0].plot([x0], [y0], 'r.')\n", - "ax[1,0].axes.get_xaxis().set_visible(False)\n", - "ax[1,0].axes.get_yaxis().set_visible(False)\n", - "ax[1,0].set_xlim((0, img_lsn.shape[1]-1))\n", - "ax[1,0].set_ylim((img_lsn.shape[0]-1, 0))\n", - "ax[1,0].set_title('Locally Sparse Noise, Template\\n16x28')\n", + "ax[1, 0].imshow(\n", + " np.flipud(img_lsn), interpolation=\"none\", cmap=plt.cm.gray, extent=[0, img_lsn.shape[1], img_lsn.shape[0], 0]\n", + ")\n", + "ax[1, 0].plot([x0], [y0], \"r.\")\n", + "ax[1, 0].axes.get_xaxis().set_visible(False)\n", + "ax[1, 0].axes.get_yaxis().set_visible(False)\n", + "ax[1, 0].set_xlim((0, img_lsn.shape[1] - 1))\n", + "ax[1, 0].set_ylim((img_lsn.shape[0] - 1, 0))\n", + "ax[1, 0].set_title(\"Locally Sparse Noise, Template\\n16x28\")\n", "\n", - "ax[1,1].imshow(img_movie, interpolation='none', cmap=plt.cm.gray,\n", - " extent=[0,img_movie.shape[1],img_movie.shape[0],0])\n", - "ax[1,1].plot([x1], [y1], 'r.')\n", - "ax[1,1].axes.get_xaxis().set_visible(False)\n", - "ax[1,1].axes.get_yaxis().set_visible(False)\n", - "ax[1,1].set_xlim((0, img_movie.shape[1]-1))\n", - "ax[1,1].set_ylim((img_movie.shape[0]-1, 0))\n", - "ax[1,1].set_title('Natural Movie One, Template\\n304x608')\n", + "ax[1, 1].imshow(\n", + " img_movie, interpolation=\"none\", cmap=plt.cm.gray, extent=[0, img_movie.shape[1], img_movie.shape[0], 0]\n", + ")\n", + "ax[1, 1].plot([x1], [y1], \"r.\")\n", + "ax[1, 1].axes.get_xaxis().set_visible(False)\n", + "ax[1, 1].axes.get_yaxis().set_visible(False)\n", + "ax[1, 1].set_xlim((0, img_movie.shape[1] - 1))\n", + "ax[1, 1].set_ylim((img_movie.shape[0] - 1, 0))\n", + "ax[1, 1].set_title(\"Natural Movie One, Template\\n304x608\")\n", "\n", "\n", "fig.tight_layout()" @@ -632,30 +626,47 @@ } ], "source": [ - "origin = 'upper'\n", - "img_movie = nwb_dataset.get_stimulus_template(si.NATURAL_MOVIE_ONE)[0,:,:]\n", + "origin = \"upper\"\n", + "img_movie = nwb_dataset.get_stimulus_template(si.NATURAL_MOVIE_ONE)[0, :, :]\n", "natural_movie_image = m.natural_movie_image_to_screen(img_movie, origin=origin)\n", - "natural_movie_image_translated = m.natural_movie_image_to_screen(img_movie, origin=origin, translation=(100,-300))\n", + "natural_movie_image_translated = m.natural_movie_image_to_screen(img_movie, origin=origin, translation=(100, -300))\n", "\n", - "img_lsn = nwb_dataset.get_stimulus_template(si.LOCALLY_SPARSE_NOISE)[0,:,:]\n", + "img_lsn = nwb_dataset.get_stimulus_template(si.LOCALLY_SPARSE_NOISE)[0, :, :]\n", "lsn_image = m.lsn_image_to_screen(img_lsn, origin=origin)\n", - "img_lsn_translated = m.lsn_image_to_screen(img_lsn, origin=origin, translation=(-500,0))\n", + "img_lsn_translated = m.lsn_image_to_screen(img_lsn, origin=origin, translation=(-500, 0))\n", "\n", - "img_grating = m.grating_to_screen(.5,.02,0)\n", - "img_grating_translated = m.grating_to_screen(.5,.02,0, translation=(250,250))\n", + "img_grating = m.grating_to_screen(0.5, 0.02, 0)\n", + "img_grating_translated = m.grating_to_screen(0.5, 0.02, 0, translation=(250, 250))\n", "\n", "\n", - "fig, ax = plt.subplots(3,2, figsize=(20,10))\n", - "_ = ax[0,0].imshow(natural_movie_image, interpolation='none', cmap=plt.cm.gray,\n", - " extent=[0,natural_movie_image.shape[1],natural_movie_image.shape[0],0])\n", - "_ = ax[0,1].imshow(natural_movie_image_translated, interpolation='none', cmap=plt.cm.gray,\n", - " extent=[0,natural_movie_image_translated.shape[1],natural_movie_image_translated.shape[0],0])\n", - "_ = ax[1,0].imshow(lsn_image, interpolation='none', cmap=plt.cm.gray,\n", - " extent=[0,lsn_image.shape[1],lsn_image.shape[0],0])\n", - "_ = ax[1,1].imshow(img_lsn_translated, interpolation='none', cmap=plt.cm.gray,\n", - " extent=[0,img_lsn_translated.shape[1],img_lsn_translated.shape[0],0])\n", - "_ = ax[2,0].imshow(img_grating, interpolation='none', cmap=plt.cm.gray,)\n", - "_ = ax[2,1].imshow(img_grating_translated, interpolation='none', cmap=plt.cm.gray)" + "fig, ax = plt.subplots(3, 2, figsize=(20, 10))\n", + "_ = ax[0, 0].imshow(\n", + " natural_movie_image,\n", + " interpolation=\"none\",\n", + " cmap=plt.cm.gray,\n", + " extent=[0, natural_movie_image.shape[1], natural_movie_image.shape[0], 0],\n", + ")\n", + "_ = ax[0, 1].imshow(\n", + " natural_movie_image_translated,\n", + " interpolation=\"none\",\n", + " cmap=plt.cm.gray,\n", + " extent=[0, natural_movie_image_translated.shape[1], natural_movie_image_translated.shape[0], 0],\n", + ")\n", + "_ = ax[1, 0].imshow(\n", + " lsn_image, interpolation=\"none\", cmap=plt.cm.gray, extent=[0, lsn_image.shape[1], lsn_image.shape[0], 0]\n", + ")\n", + "_ = ax[1, 1].imshow(\n", + " img_lsn_translated,\n", + " interpolation=\"none\",\n", + " cmap=plt.cm.gray,\n", + " extent=[0, img_lsn_translated.shape[1], img_lsn_translated.shape[0], 0],\n", + ")\n", + "_ = ax[2, 0].imshow(\n", + " img_grating,\n", + " interpolation=\"none\",\n", + " cmap=plt.cm.gray,\n", + ")\n", + "_ = ax[2, 1].imshow(img_grating_translated, interpolation=\"none\", cmap=plt.cm.gray)" ] } ], diff --git a/doc_template/examples_root/examples/nb/brain_observatory_stimuli.ipynb b/doc_template/examples_root/examples/nb/brain_observatory_stimuli.ipynb index 89edab1490..d79d27255c 100644 --- a/doc_template/examples_root/examples/nb/brain_observatory_stimuli.ipynb +++ b/doc_template/examples_root/examples/nb/brain_observatory_stimuli.ipynb @@ -69,7 +69,7 @@ }, "outputs": [], "source": [ - "output_dir = '.'" + "output_dir = \".\"" ] }, { @@ -96,19 +96,21 @@ "source": [ "import matplotlib.pyplot as plt\n", "import matplotlib.patches as patches\n", + "\n", "%matplotlib inline\n", "\n", + "\n", "def plot_stimulus_table(stim_table, title):\n", " fstart = stim_table.start.min()\n", " fend = stim_table.end.max()\n", - " \n", - " fig = plt.figure(figsize=(15,1))\n", + "\n", + " fig = plt.figure(figsize=(15, 1))\n", " ax = fig.gca()\n", - " for i, trial in stim_table.iterrows(): \n", + " for i, trial in stim_table.iterrows():\n", " x1 = float(trial.start - fstart) / (fend - fstart)\n", - " x2 = float(trial.end - fstart) / (fend - fstart) \n", - " ax.add_patch(patches.Rectangle((x1, 0.0), x2 - x1, 1.0, color='r'))\n", - " ax.set_xticks((0,1))\n", + " x2 = float(trial.end - fstart) / (fend - fstart)\n", + " ax.add_patch(patches.Rectangle((x1, 0.0), x2 - x1, 1.0, color=\"r\"))\n", + " ax.set_xticks((0, 1))\n", " ax.set_xticklabels((fstart, fend))\n", " ax.set_yticks(())\n", " ax.set_title(title)\n", @@ -174,14 +176,14 @@ ], "source": [ "from allensdk.core.brain_observatory_cache import BrainObservatoryCache\n", - "boc = BrainObservatoryCache(\n", - " manifest_file=str(Path(output_dir) / 'brain_observatory_manifest.json'))\n", + "\n", + "boc = BrainObservatoryCache(manifest_file=str(Path(output_dir) / \"brain_observatory_manifest.json\"))\n", "data_set = boc.get_ophys_experiment_data(501940850)\n", "\n", "# this is a pandas DataFrame. find trials with a given stimulus condition.\n", "temporal_frequency = 4\n", "orientation = 225\n", - "stim_table = data_set.get_stimulus_table('drifting_gratings')\n", + "stim_table = data_set.get_stimulus_table(\"drifting_gratings\")\n", "stim_table = stim_table[(stim_table.temporal_frequency == temporal_frequency) & (stim_table.orientation == orientation)]\n", "\n", "# plot the trials\n", @@ -252,10 +254,12 @@ "spatial_frequency = 0.02\n", "orientation = 30\n", "phase = 0.0\n", - "stim_table = data_set.get_stimulus_table('static_gratings')\n", - "stim_table = stim_table[(stim_table.spatial_frequency == spatial_frequency) & \\\n", - " (stim_table.orientation == orientation) & \\\n", - " (stim_table.phase == phase) ]\n", + "stim_table = data_set.get_stimulus_table(\"static_gratings\")\n", + "stim_table = stim_table[\n", + " (stim_table.spatial_frequency == spatial_frequency)\n", + " & (stim_table.orientation == orientation)\n", + " & (stim_table.phase == phase)\n", + "]\n", "\n", "# plot the trials\n", "plot_stimulus_table(stim_table, \"SF %.02f ORI %d Phase %.02f\" % (spatial_frequency, orientation, phase))" @@ -317,14 +321,14 @@ "scene_nums = [4, 83]\n", "\n", "# read in the array of images\n", - "scenes = data_set.get_stimulus_template('natural_scenes')\n", + "scenes = data_set.get_stimulus_template(\"natural_scenes\")\n", "\n", "# display a couple of the scenes\n", - "fig, axes = plt.subplots(1,len(scene_nums))\n", - "for ax,scene in zip(axes, scene_nums):\n", - " ax.imshow(scenes[scene,:,:], cmap='gray')\n", + "fig, axes = plt.subplots(1, len(scene_nums))\n", + "for ax, scene in zip(axes, scene_nums):\n", + " ax.imshow(scenes[scene, :, :], cmap=\"gray\")\n", " ax.set_axis_off()\n", - " ax.set_title('scene %d' % scene)" + " ax.set_title(\"scene %d\" % scene)" ] }, { @@ -380,12 +384,12 @@ "data_set = boc.get_ophys_experiment_data(501498760)\n", "\n", "# the natural scenes stimulus table describes when each scene is on the screen\n", - "stim_table = data_set.get_stimulus_table('natural_scenes')\n", + "stim_table = data_set.get_stimulus_table(\"natural_scenes\")\n", "\n", "# build up a mask of trials for which one of a list of scenes is visible\n", "trial_mask = stim_table.frame == -2\n", "for scene in scene_nums:\n", - " trial_mask |= (stim_table.frame == scene)\n", + " trial_mask |= stim_table.frame == scene\n", "stim_table = stim_table[trial_mask]\n", "\n", "# plot the trials\n", @@ -446,13 +450,13 @@ "data_set = boc.get_ophys_experiment_data(501498760)\n", "\n", "# read in the natural movie one clip\n", - "movie = data_set.get_stimulus_template('natural_movie_one')\n", + "movie = data_set.get_stimulus_template(\"natural_movie_one\")\n", "\n", "# display a random frame for reference\n", "frame = 200\n", - "plt.imshow(movie[frame,:,:], cmap='gray')\n", - "plt.axis('off')\n", - "plt.title('frame %d' % frame)\n", + "plt.imshow(movie[frame, :, :], cmap=\"gray\")\n", + "plt.axis(\"off\")\n", + "plt.title(\"frame %d\" % frame)\n", "plt.show()" ] }, @@ -509,10 +513,10 @@ "data_set = boc.get_ophys_experiment_data(501498760)\n", "\n", "# read in the stimulus table, which describes when a given frame is displayed\n", - "stim_table = data_set.get_stimulus_table('natural_movie_one')\n", + "stim_table = data_set.get_stimulus_table(\"natural_movie_one\")\n", "\n", "# find out when a particular frame range is displayed\n", - "frame_range = [ 100, 120 ]\n", + "frame_range = [100, 120]\n", "stim_table = stim_table[(stim_table.frame >= frame_range[0]) & (stim_table.frame <= frame_range[1])]\n", "\n", "plot_stimulus_table(stim_table, \"frames %d -> %d \" % (frame_range[0], frame_range[1]))" @@ -578,16 +582,17 @@ "source": [ "data_set = boc.get_ophys_experiment_data(505693621)\n", "\n", - "# read in the locally sparse noise stimulus movie. \n", + "# read in the locally sparse noise stimulus movie.\n", "# the 'mask_offscreen' argument will set off-screen grid locations to LocallySparseNoise.LSN_OFF_SCREEN\n", - "lsn_movie, offscreen_mask = data_set.get_locally_sparse_noise_stimulus_template('locally_sparse_noise',\n", - " mask_off_screen=True)\n", + "lsn_movie, offscreen_mask = data_set.get_locally_sparse_noise_stimulus_template(\n", + " \"locally_sparse_noise\", mask_off_screen=True\n", + ")\n", "\n", "# show a single frame of the stimulus for reference\n", "frame = 200\n", - "plt.imshow(lsn_movie[frame,:,:], cmap='gray', interpolation='nearest')\n", - "plt.axis('off')\n", - "plt.title('frame %d' % frame)\n", + "plt.imshow(lsn_movie[frame, :, :], cmap=\"gray\", interpolation=\"nearest\")\n", + "plt.axis(\"off\")\n", + "plt.title(\"frame %d\" % frame)\n", "plt.show()" ] }, @@ -645,14 +650,14 @@ "from allensdk.brain_observatory.locally_sparse_noise import LocallySparseNoise\n", "\n", "# find frames at a given grid location that are 'on'\n", - "loc = (10,15)\n", - "on_frames = np.where(lsn_movie[:,loc[0],loc[1]] == LocallySparseNoise.LSN_ON)[0]\n", + "loc = (10, 15)\n", + "on_frames = np.where(lsn_movie[:, loc[0], loc[1]] == LocallySparseNoise.LSN_ON)[0]\n", "\n", "# For some reason, stim table only goes to frame 8879, while locally_sparse_noise is 9000 frames\n", "on_frames = on_frames[on_frames < stim_table.index.max()]\n", "\n", "# pull these trials out of the stimulus table\n", - "stim_table = data_set.get_stimulus_table('locally_sparse_noise')\n", + "stim_table = data_set.get_stimulus_table(\"locally_sparse_noise\")\n", "stim_table = stim_table.loc[on_frames]\n", "\n", "plot_stimulus_table(stim_table, \"loc (%d,%d) \" % loc)" diff --git a/doc_template/examples_root/examples/nb/cell_specimen_mapping.ipynb b/doc_template/examples_root/examples/nb/cell_specimen_mapping.ipynb index bffbe4edf2..3e5655d8cf 100644 --- a/doc_template/examples_root/examples/nb/cell_specimen_mapping.ipynb +++ b/doc_template/examples_root/examples/nb/cell_specimen_mapping.ipynb @@ -87,7 +87,7 @@ }, "outputs": [], "source": [ - "output_dir = '.'" + "output_dir = \".\"" ] }, { @@ -296,10 +296,10 @@ "import numpy as np\n", "from matplotlib import pyplot as plt\n", "from matplotlib.colors import ListedColormap\n", + "\n", "%matplotlib inline\n", "\n", - "boc = BrainObservatoryCache(\n", - " manifest_file=str(Path(output_dir) / 'brain_observatory_manifest.json'))\n", + "boc = BrainObservatoryCache(manifest_file=str(Path(output_dir) / \"brain_observatory_manifest.json\"))\n", "\n", "cell_ids = {}\n", "cell_ids[\"three_session_B\"] = int(table[table.old_cell_id == old_cell_id].session_B_new_cell_id)\n", @@ -308,26 +308,25 @@ "datasets = {}\n", "for session, cell_id in iteritems(cell_ids):\n", " # Find and download the session experiment\n", - " exp = boc.get_ophys_experiments(cell_specimen_ids=[cell_id],\n", - " session_types=[session])[0]\n", - " datasets[session] = boc.get_ophys_experiment_data(exp['id'])\n", + " exp = boc.get_ophys_experiments(cell_specimen_ids=[cell_id], session_types=[session])[0]\n", + " datasets[session] = boc.get_ophys_experiment_data(exp[\"id\"])\n", "\n", "# set up a color map for overlay\n", "overlay_map = ListedColormap([\"g\"])\n", - "overlay_map.set_bad(color='k', alpha=0)\n", + "overlay_map.set_bad(color=\"k\", alpha=0)\n", "\n", "# overlay the cell ROIs on the max projection\n", "plt.figure(figsize=(11, 5))\n", "for i, key in enumerate(sorted(datasets.keys())):\n", - " plt.subplot(1, 2, i+1)\n", + " plt.subplot(1, 2, i + 1)\n", " dataset = datasets[key]\n", " cell_id = cell_ids[key]\n", " overlay_mask = dataset.get_roi_mask(cell_specimen_ids=[cell_id])[0].get_mask_plane().astype(float)\n", " overlay_mask[overlay_mask == 0] = np.nan\n", - " plt.imshow(dataset.get_max_projection(), cmap='gray')\n", + " plt.imshow(dataset.get_max_projection(), cmap=\"gray\")\n", " plt.imshow(overlay_mask, cmap=overlay_map, alpha=0.3)\n", " plt.title(key)\n", - " plt.axis('off')\n", + " plt.axis(\"off\")\n", "plt.show()" ] }, @@ -409,14 +408,20 @@ "# want to see if it's a cell I previously did some analysis on\n", "new_cell_id = 517396395\n", "\n", - "old_cell_id = int(table[(table.session_A_new_cell_id==new_cell_id) | \n", - " (table.session_B_new_cell_id==new_cell_id) |\n", - " (table.session_C_new_cell_id==new_cell_id)].old_cell_id)\n", + "old_cell_id = int(\n", + " table[\n", + " (table.session_A_new_cell_id == new_cell_id)\n", + " | (table.session_B_new_cell_id == new_cell_id)\n", + " | (table.session_C_new_cell_id == new_cell_id)\n", + " ].old_cell_id\n", + ")\n", "\n", "# I can see that this cell is found in all sessions as well.\n", - "table[(table.session_A_new_cell_id==new_cell_id) | \n", - " (table.session_B_new_cell_id==new_cell_id) |\n", - " (table.session_C_new_cell_id==new_cell_id)]" + "table[\n", + " (table.session_A_new_cell_id == new_cell_id)\n", + " | (table.session_B_new_cell_id == new_cell_id)\n", + " | (table.session_C_new_cell_id == new_cell_id)\n", + "]" ] } ], diff --git a/doc_template/examples_root/examples/nb/cell_types.ipynb b/doc_template/examples_root/examples/nb/cell_types.ipynb index 337344b4b0..74a901e50a 100644 --- a/doc_template/examples_root/examples/nb/cell_types.ipynb +++ b/doc_template/examples_root/examples/nb/cell_types.ipynb @@ -81,7 +81,7 @@ }, "outputs": [], "source": [ - "output_dir = '.'" + "output_dir = \".\"" ] }, { @@ -108,7 +108,24 @@ "tags": [] }, "outputs": [], - "source": "from allensdk.core.cell_types_cache import CellTypesCache\nimport os\n\n# Instantiate the CellTypesCache instance. The manifest_file argument\n# tells it where to store the manifest, which is a JSON file that tracks\n# file paths. If you supply a relative path it will go into your\n# current working directory\ntry:\n ctc = CellTypesCache(manifest_file=Path(output_dir) / 'manifest.json')\nexcept Exception:\n os.remove(Path(output_dir) / 'manifest.json')\n ctc = CellTypesCache(manifest_file=Path(output_dir) / 'manifest.json')\n \n# this saves the NWB file to 'cell_types/specimen_464212183/ephys.nwb'\ncell_specimen_id = 464212183\ndata_set = ctc.get_ephys_data(cell_specimen_id)" + "source": [ + "from allensdk.core.cell_types_cache import CellTypesCache\n", + "import os\n", + "\n", + "# Instantiate the CellTypesCache instance. The manifest_file argument\n", + "# tells it where to store the manifest, which is a JSON file that tracks\n", + "# file paths. If you supply a relative path it will go into your\n", + "# current working directory\n", + "try:\n", + " ctc = CellTypesCache(manifest_file=Path(output_dir) / \"manifest.json\")\n", + "except Exception:\n", + " os.remove(Path(output_dir) / \"manifest.json\")\n", + " ctc = CellTypesCache(manifest_file=Path(output_dir) / \"manifest.json\")\n", + "\n", + "# this saves the NWB file to 'cell_types/specimen_464212183/ephys.nwb'\n", + "cell_specimen_id = 464212183\n", + "data_set = ctc.get_ephys_data(cell_specimen_id)" + ] }, { "cell_type": "code", @@ -141,7 +158,7 @@ "# tells it where to store the manifest, which is a JSON file that tracks\n", "# file paths. If you supply a relative path it will go into your\n", "# current working directory\n", - "ctc = CellTypesCache(manifest_file=Path(output_dir) / 'manifest.json')\n", + "ctc = CellTypesCache(manifest_file=Path(output_dir) / \"manifest.json\")\n", "\n", "# this saves the NWB file to 'cell_types/specimen_464212183/ephys.nwb'\n", "cell_specimen_id = 464212183\n", @@ -179,7 +196,7 @@ "# tells it where to store the manifest, which is a JSON file that tracks\n", "# file paths. If you supply a relative path it will go into your\n", "# current working directory\n", - "ctc = CellTypesCache(manifest_file=Path(output_dir) / 'manifest.json')\n", + "ctc = CellTypesCache(manifest_file=Path(output_dir) / \"manifest.json\")\n", "\n", "# this saves the NWB file to 'cell_types/specimen_464212183/ephys.nwb'\n", "cell_specimen_id = 464212183\n", @@ -251,18 +268,18 @@ "sweep_data = data_set.get_sweep(sweep_number)\n", "\n", "index_range = sweep_data[\"index_range\"]\n", - "i = sweep_data[\"stimulus\"][0:index_range[1]+1] # in A\n", - "v = sweep_data[\"response\"][0:index_range[1]+1] # in V\n", - "i *= 1e12 # to pA\n", - "v *= 1e3 # to mV\n", + "i = sweep_data[\"stimulus\"][0 : index_range[1] + 1] # in A\n", + "v = sweep_data[\"response\"][0 : index_range[1] + 1] # in V\n", + "i *= 1e12 # to pA\n", + "v *= 1e3 # to mV\n", "\n", - "sampling_rate = sweep_data[\"sampling_rate\"] # in Hz\n", + "sampling_rate = sweep_data[\"sampling_rate\"] # in Hz\n", "t = np.arange(0, len(v)) * (1.0 / sampling_rate)\n", "\n", - "plt.style.use('ggplot')\n", + "plt.style.use(\"ggplot\")\n", "fig, axes = plt.subplots(2, 1, sharex=True)\n", - "axes[0].plot(t, v, color='black')\n", - "axes[1].plot(t, i, color='gray')\n", + "axes[0].plot(t, v, color=\"black\")\n", + "axes[1].plot(t, i, color=\"gray\")\n", "axes[0].set_ylabel(\"mV\")\n", "axes[1].set_ylabel(\"pA\")\n", "axes[1].set_xlabel(\"seconds\")\n", @@ -347,16 +364,15 @@ "print(\"Human cells: %d\" % len(cells))\n", "\n", "# cells with reconstructions\n", - "cells = ctc.get_cells(require_reconstruction = True)\n", + "cells = ctc.get_cells(require_reconstruction=True)\n", "print(\"Cells with reconstructions: %d\" % len(cells))\n", "\n", "# all cre positive cells\n", - "cells = ctc.get_cells(reporter_status = RS.POSITIVE)\n", + "cells = ctc.get_cells(reporter_status=RS.POSITIVE)\n", "print(\"Cre-positive cells: %d\" % len(cells))\n", "\n", "# cre negative cells with reconstructions\n", - "cells = ctc.get_cells(require_reconstruction = True, \n", - " reporter_status = RS.NEGATIVE)\n", + "cells = ctc.get_cells(require_reconstruction=True, reporter_status=RS.NEGATIVE)\n", "print(\"Cre-negative cells with reconstructions: %d\" % len(cells))" ] }, @@ -447,7 +463,7 @@ "\n", "# download and open an SWC file\n", "cell_id = 480114344\n", - "morphology = ctc.get_reconstruction(cell_id) \n", + "morphology = ctc.get_reconstruction(cell_id)\n", "\n", "# the compartment list has all of the nodes in the file\n", "pprint.pprint(morphology.compartment_list[0])" @@ -516,7 +532,7 @@ ], "source": [ "# download and open a marker file\n", - "markers = ctc.get_reconstruction_markers(cell_id) \n", + "markers = ctc.get_reconstruction_markers(cell_id)\n", "pprint.pprint(markers[0])" ] }, @@ -577,31 +593,32 @@ ], "source": [ "from allensdk.core.swc import Marker\n", + "\n", "fig, axes = plt.subplots(1, 2, sharey=True, sharex=True)\n", - "axes[0].set_aspect('equal', 'box')\n", - "axes[1].set_aspect('equal', 'box')\n", + "axes[0].set_aspect(\"equal\", \"box\")\n", + "axes[1].set_aspect(\"equal\", \"box\")\n", "\n", "# Make a line drawing of x-y and y-z views\n", "for n in morphology.compartment_list:\n", " for c in morphology.children_of(n):\n", - " axes[0].plot([n['x'], c['x']], [n['y'], c['y']], color='black')\n", - " axes[1].plot([n['z'], c['z']], [n['y'], c['y']], color='black')\n", + " axes[0].plot([n[\"x\"], c[\"x\"]], [n[\"y\"], c[\"y\"]], color=\"black\")\n", + " axes[1].plot([n[\"z\"], c[\"z\"]], [n[\"y\"], c[\"y\"]], color=\"black\")\n", "\n", "# cut dendrite markers\n", - "dm = [ m for m in markers if m['name'] == Marker.CUT_DENDRITE ]\n", + "dm = [m for m in markers if m[\"name\"] == Marker.CUT_DENDRITE]\n", "\n", - "axes[0].scatter([m['x'] for m in dm], [m['y'] for m in dm], color='#3333ff')\n", - "axes[1].scatter([m['z'] for m in dm], [m['y'] for m in dm], color='#3333ff')\n", + "axes[0].scatter([m[\"x\"] for m in dm], [m[\"y\"] for m in dm], color=\"#3333ff\")\n", + "axes[1].scatter([m[\"z\"] for m in dm], [m[\"y\"] for m in dm], color=\"#3333ff\")\n", "\n", "# no reconstruction markers\n", - "nm = [ m for m in markers if m['name'] == Marker.NO_RECONSTRUCTION ]\n", + "nm = [m for m in markers if m[\"name\"] == Marker.NO_RECONSTRUCTION]\n", "\n", - "axes[0].scatter([m['x'] for m in nm], [m['y'] for m in nm], color='#333333')\n", - "axes[1].scatter([m['z'] for m in nm], [m['y'] for m in nm], color='#333333')\n", + "axes[0].scatter([m[\"x\"] for m in nm], [m[\"y\"] for m in nm], color=\"#333333\")\n", + "axes[1].scatter([m[\"z\"] for m in nm], [m[\"y\"] for m in nm], color=\"#333333\")\n", "\n", - "axes[0].set_ylabel('y')\n", - "axes[0].set_xlabel('x')\n", - "axes[1].set_xlabel('z')\n", + "axes[0].set_ylabel(\"y\")\n", + "axes[0].set_xlabel(\"x\")\n", + "axes[1].set_xlabel(\"z\")\n", "plt.show()" ] }, @@ -776,7 +793,7 @@ "\n", "# filter down to a specific cell\n", "specimen_id = 464212183\n", - "cell_ephys_features = ef_df[ef_df['specimen_id']== specimen_id]\n", + "cell_ephys_features = ef_df[ef_df[\"specimen_id\"] == specimen_id]\n", "cell_ephys_features" ] }, @@ -837,8 +854,7 @@ ], "source": [ "plt.figure()\n", - "plt.scatter(ef_df['fast_trough_v_long_square'], \n", - " ef_df['upstroke_downstroke_ratio_long_square'], color='#2ca25f')\n", + "plt.scatter(ef_df[\"fast_trough_v_long_square\"], ef_df[\"upstroke_downstroke_ratio_long_square\"], color=\"#2ca25f\")\n", "plt.ylabel(\"upstroke-downstroke ratio\")\n", "plt.xlabel(\"fast trough depth (mV)\")\n", "plt.show()" @@ -913,21 +929,17 @@ } ], "source": [ - "A = np.vstack([ef_df['fast_trough_v_long_square'], \n", - " np.ones_like(ef_df['upstroke_downstroke_ratio_long_square'])]).T\n", + "A = np.vstack([ef_df[\"fast_trough_v_long_square\"], np.ones_like(ef_df[\"upstroke_downstroke_ratio_long_square\"])]).T\n", "\n", "print(\"First 5 rows of A:\")\n", "print(A[:5, :])\n", "\n", - "m, c = np.linalg.lstsq(A, ef_df['upstroke_downstroke_ratio_long_square'], rcond=None)[0]\n", + "m, c = np.linalg.lstsq(A, ef_df[\"upstroke_downstroke_ratio_long_square\"], rcond=None)[0]\n", "print(\"m\", m, \"c\", c)\n", "\n", "plt.figure()\n", - "plt.scatter(ef_df['fast_trough_v_long_square'], \n", - " ef_df['upstroke_downstroke_ratio_long_square'], \n", - " color='#2ca25f')\n", - "plt.plot(ef_df['fast_trough_v_long_square'],\n", - " m * ef_df['fast_trough_v_long_square'] + c, c='gray')\n", + "plt.scatter(ef_df[\"fast_trough_v_long_square\"], ef_df[\"upstroke_downstroke_ratio_long_square\"], color=\"#2ca25f\")\n", + "plt.plot(ef_df[\"fast_trough_v_long_square\"], m * ef_df[\"fast_trough_v_long_square\"] + c, c=\"gray\")\n", "plt.ylabel(\"upstroke-downstroke ratio\")\n", "plt.xlabel(\"fast trough depth (mV)\")\n", "plt.show()" @@ -993,24 +1005,22 @@ "\n", "# we want to add dendrite type as a column to the ephys. features dataframe\n", "# first build an index on cell specimen ID, then create array of dendrite types\n", - "cell_index = { c['id']: c for c in cells }\n", - "dendrite_types = [ cell_index[cid]['dendrite_type'] for cid in ef_df['specimen_id'] ]\n", + "cell_index = {c[\"id\"]: c for c in cells}\n", + "dendrite_types = [cell_index[cid][\"dendrite_type\"] for cid in ef_df[\"specimen_id\"]]\n", "\n", "# now add the new column\n", - "ef_df['dendrite_type'] = pd.Series(dendrite_types, index=ef_df.index)\n", + "ef_df[\"dendrite_type\"] = pd.Series(dendrite_types, index=ef_df.index)\n", "\n", "fig = plt.figure()\n", "\n", - "for d_type, color in [ [\"spiny\", \"#d95f02\"], [\"aspiny\", \"#7570b3\"] ]:\n", - " df = ef_df[ef_df['dendrite_type'] == d_type]\n", - " plt.scatter(df['fast_trough_v_long_square'], \n", - " df['upstroke_downstroke_ratio_long_square'], \n", - " color=color, label=d_type)\n", + "for d_type, color in [[\"spiny\", \"#d95f02\"], [\"aspiny\", \"#7570b3\"]]:\n", + " df = ef_df[ef_df[\"dendrite_type\"] == d_type]\n", + " plt.scatter(df[\"fast_trough_v_long_square\"], df[\"upstroke_downstroke_ratio_long_square\"], color=color, label=d_type)\n", "\n", " plt.ylabel(\"upstroke-downstroke ratio\")\n", " plt.xlabel(\"fast trough depth (mV)\")\n", - " plt.legend(loc='best')\n", - " \n", + " plt.legend(loc=\"best\")\n", + "\n", "plt.show()" ] }, @@ -1366,19 +1376,19 @@ "sweep_data = data_set.get_sweep(sweep_number)\n", "\n", "index_range = sweep_data[\"index_range\"]\n", - "i = sweep_data[\"stimulus\"][0:index_range[1]+1] # in A\n", - "v = sweep_data[\"response\"][0:index_range[1]+1] # in V\n", - "i *= 1e12 # to pA\n", - "v *= 1e3 # to mV\n", + "i = sweep_data[\"stimulus\"][0 : index_range[1] + 1] # in A\n", + "v = sweep_data[\"response\"][0 : index_range[1] + 1] # in V\n", + "i *= 1e12 # to pA\n", + "v *= 1e3 # to mV\n", "\n", - "sampling_rate = sweep_data[\"sampling_rate\"] # in Hz\n", + "sampling_rate = sweep_data[\"sampling_rate\"] # in Hz\n", "t = np.arange(0, len(v)) * (1.0 / sampling_rate)\n", "\n", "sweep_ext = EphysSweepFeatureExtractor(t=t, v=v, i=i, start=1.02, end=2.02)\n", "sweep_ext.process_spikes()\n", "\n", "print(\"Avg spike threshold: %.01f mV\" % sweep_ext.spike_feature(\"threshold_v\").mean())\n", - "print(\"Avg spike width: %.02f ms\" % (1e3 * np.nanmean(sweep_ext.spike_feature(\"width\"))))" + "print(\"Avg spike width: %.02f ms\" % (1e3 * np.nanmean(sweep_ext.spike_feature(\"width\"))))" ] }, { @@ -1600,13 +1610,13 @@ ], "source": [ "fig = plt.figure()\n", - "p = plt.plot(t, v, color='black')\n", + "p = plt.plot(t, v, color=\"black\")\n", "\n", "min_v = v.min()\n", "\n", "v_level = min_v - 5\n", "\n", - "plt.scatter(spike_times, np.ones(len(spike_times)) * min_v, c='firebrick')\n", + "plt.scatter(spike_times, np.ones(len(spike_times)) * min_v, c=\"firebrick\")\n", "plt.xlim(0.9, 1.2)" ] }, @@ -1677,12 +1687,12 @@ ], "source": [ "fig = plt.figure()\n", - "plt.plot(t, v, color='black')\n", + "plt.plot(t, v, color=\"black\")\n", "\n", "threshold_v = sweep_ext.spike_feature(\"threshold_v\")\n", "\n", "# setting zorder puts the dots on top of the trace\n", - "plt.scatter(spike_times, threshold_v, s=50, c='firebrick', zorder=20)\n", + "plt.scatter(spike_times, threshold_v, s=50, c=\"firebrick\", zorder=20)\n", "plt.xlim(1.015, 1.08)" ] } diff --git a/doc_template/examples_root/examples/nb/download_data_via_api.ipynb b/doc_template/examples_root/examples/nb/download_data_via_api.ipynb index 4f0930461e..8619c48943 100644 --- a/doc_template/examples_root/examples/nb/download_data_via_api.ipynb +++ b/doc_template/examples_root/examples/nb/download_data_via_api.ipynb @@ -41,7 +41,7 @@ }, "outputs": [], "source": [ - "output_dir = '.'" + "output_dir = \".\"" ] }, { @@ -74,8 +74,10 @@ "\n", "rma = RmaApi()\n", "\n", + "\n", "def read_data(parsed_json):\n", - " return parsed_json['msg']\n", + " return parsed_json[\"msg\"]\n", + "\n", "\n", "def pretty(result):\n", " print(json.dumps(result, indent=2))" @@ -168,12 +170,13 @@ "source": [ "# Get the atlas id\n", "def query_atlases(search_pattern):\n", - " return rma.build_query_url(rma.model_stage('Atlas',\n", - " criteria=\"[name$il'%s']\" % (search_pattern),\n", - " only=['id', 'name']))\n", + " return rma.build_query_url(\n", + " rma.model_stage(\"Atlas\", criteria=\"[name$il'%s']\" % (search_pattern), only=[\"id\", \"name\"])\n", + " )\n", + "\n", "\n", - "atlases = rma.do_query(query_atlases, read_data, 'mouse*')\n", - "pretty(atlases)\n" + "atlases = rma.do_query(query_atlases, read_data, \"mouse*\")\n", + "pretty(atlases)" ] }, { @@ -211,12 +214,15 @@ "source": [ "# get the structure\n", "def query_structure(acronym, ontology_id):\n", - " return rma.build_query_url(rma.model_stage('Structure',\n", - " criteria=\"[acronym$eq'%s'][ontology_id$eq%d]\" % (acronym, ontology_id),\n", - " only=['id','name']))\n", + " return rma.build_query_url(\n", + " rma.model_stage(\n", + " \"Structure\", criteria=\"[acronym$eq'%s'][ontology_id$eq%d]\" % (acronym, ontology_id), only=[\"id\", \"name\"]\n", + " )\n", + " )\n", "\n", - "structure = rma.do_query(query_structure, read_data, 'VISp', 1)[0]\n", - "pretty(structure)\n" + "\n", + "structure = rma.do_query(query_structure, read_data, \"VISp\", 1)[0]\n", + "pretty(structure)" ] }, { @@ -253,9 +259,10 @@ ], "source": [ "def query_hemisphere(name):\n", - " return rma.build_query_url(rma.model_stage('Hemisphere', criteria=\"[name$il'%s']\" % (name)))\n", - " \n", - "left_hemisphere_id = rma.do_query(query_hemisphere, read_data, 'left')[0]['id']\n", + " return rma.build_query_url(rma.model_stage(\"Hemisphere\", criteria=\"[name$il'%s']\" % (name)))\n", + "\n", + "\n", + "left_hemisphere_id = rma.do_query(query_hemisphere, read_data, \"left\")[0][\"id\"]\n", "\n", "left_hemisphere_id" ] @@ -293,18 +300,18 @@ } ], "source": [ - "mca = MouseConnectivityApi('http://api.brain-map.org')\n", - "experiments = mca.get_experiments(structure['id'])\n", + "mca = MouseConnectivityApi(\"http://api.brain-map.org\")\n", + "experiments = mca.get_experiments(structure[\"id\"])\n", "\n", "# TODO: figure out why this didn't work w/ left hemisphere\n", - "#other_hemisphere_id = 3\n", + "# other_hemisphere_id = 3\n", "\n", "# get experiments doesn't take hemisphere into account, so filter the results with a list comprehension\n", - "#left_hemisphere_experiments = [e for e in experiments\n", + "# left_hemisphere_experiments = [e for e in experiments\n", "# if any([(injection['primary_injection_structure']['hemisphere_id'] == other_hemisphere_id)\n", "# for injection in e['specimen']['stereotaxic_injections']])]\n", - " \n", - "#pretty(left_hemisphere_experiments)\n", + "\n", + "# pretty(left_hemisphere_experiments)\n", "len(experiments)" ] }, @@ -418,21 +425,16 @@ "# TODO: show search to get this\n", "section_data_set_id = 183282970\n", "\n", - "image_list = ['projection_density',\n", - " 'projection_energy',\n", - " 'injection_fraction',\n", - " 'injection_density',\n", - " 'injection_energy']\n", + "image_list = [\"projection_density\", \"projection_energy\", \"injection_fraction\", \"injection_density\", \"injection_energy\"]\n", "\n", - "resolution = 100 # or 10, 25, 50\n", + "resolution = 100 # or 10, 25, 50\n", "\n", "# Hmm, this didn't work for an image list of length > 1\n", "for image in image_list:\n", - " nrrd_file = str(Path(output_dir) / f'grid_{section_data_set_id}_{image}.nrrd')\n", - " gda.download_projection_grid_data(section_data_set_id,\n", - " image=[image],\n", - " resolution=resolution,\n", - " save_file_path=nrrd_file)\n", + " nrrd_file = str(Path(output_dir) / f\"grid_{section_data_set_id}_{image}.nrrd\")\n", + " gda.download_projection_grid_data(\n", + " section_data_set_id, image=[image], resolution=resolution, save_file_path=nrrd_file\n", + " )\n", "\n", "\n", "# TODO: data mask" @@ -488,9 +490,7 @@ ], "source": [ "gda.download_expression_grid_data(\n", - " section_data_set_id, \n", - " include=image_list,\n", - " path=str(Path(output_dir) / f'{section_data_set_id}.zip')\n", + " section_data_set_id, include=image_list, path=str(Path(output_dir) / f\"{section_data_set_id}.zip\")\n", ")" ] }, @@ -559,59 +559,66 @@ } ], "source": [ - "#http://api.brain-map.org/api/v2/data/query.csv?criteria=\n", - "#model::ProjectionStructureUnionize,\n", - "#rma::criteria,[is_injection$eq'f'],hemisphere,structure,section_data_set[id$eq183282970](specimen(stereotaxic_injections(primary_injection_structure,stereotaxic_injection_coordinates))),rma::include,section_data_set(specimen(stereotaxic_injections(primary_injection_structure))),\n", - "#rma::options[tabular$eq'distinct+specimens.name+as+specimen_name,stereotaxic_injection_coordinates.coordinates_ap,stereotaxic_injection_coordinates.coordinates_dv,stereotaxic_injection_coordinates.coordinates_ml,data_sets.id+as+data_set_id,stereotaxic_injections.primary_injection_structure_id,structures.acronym+as+target_structure,hemispheres.symbol+as+hemisphere,projection_structure_unionizes.is_injection,projection_structure_unionizes.sum_pixels,projection_structure_unionizes.sum_projection_pixels,projection_structure_unionizes.sum_pixel_intensity,projection_structure_unionizes.sum_projection_pixel_intensity,projection_structure_unionizes.projection_density,projection_structure_unionizes.projection_intensity,projection_structure_unionizes.projection_energy,projection_structure_unionizes.volume,projection_structure_unionizes.projection_volume,projection_structure_unionizes.normalized_projection_volume,projection_structure_unionizes.max_voxel_density,projection_structure_unionizes.max_voxel_x,projection_structure_unionizes.max_voxel_y,projection_structure_unionizes.max_voxel_z'][start_row$eq0][num_rows$eq3000]\n", + "# http://api.brain-map.org/api/v2/data/query.csv?criteria=\n", + "# model::ProjectionStructureUnionize,\n", + "# rma::criteria,[is_injection$eq'f'],hemisphere,structure,section_data_set[id$eq183282970](specimen(stereotaxic_injections(primary_injection_structure,stereotaxic_injection_coordinates))),rma::include,section_data_set(specimen(stereotaxic_injections(primary_injection_structure))),\n", + "# rma::options[tabular$eq'distinct+specimens.name+as+specimen_name,stereotaxic_injection_coordinates.coordinates_ap,stereotaxic_injection_coordinates.coordinates_dv,stereotaxic_injection_coordinates.coordinates_ml,data_sets.id+as+data_set_id,stereotaxic_injections.primary_injection_structure_id,structures.acronym+as+target_structure,hemispheres.symbol+as+hemisphere,projection_structure_unionizes.is_injection,projection_structure_unionizes.sum_pixels,projection_structure_unionizes.sum_projection_pixels,projection_structure_unionizes.sum_pixel_intensity,projection_structure_unionizes.sum_projection_pixel_intensity,projection_structure_unionizes.projection_density,projection_structure_unionizes.projection_intensity,projection_structure_unionizes.projection_energy,projection_structure_unionizes.volume,projection_structure_unionizes.projection_volume,projection_structure_unionizes.normalized_projection_volume,projection_structure_unionizes.max_voxel_density,projection_structure_unionizes.max_voxel_x,projection_structure_unionizes.max_voxel_y,projection_structure_unionizes.max_voxel_z'][start_row$eq0][num_rows$eq3000]\n", + "\n", "\n", "def build_query(section_data_set_id):\n", - " criteria_string = ''.join([\"[is_injection$eq'f'],\",\n", - " \"hemisphere,\",\n", - " \"structure,\",\n", - " \"section_data_set[id$eq%d]\" % (section_data_set_id),\n", - " \"(specimen\",\n", - " \"(stereotaxic_injections\",\n", - " \"(primary_injection_structure,stereotaxic_injection_coordinates)\",\n", - " \"))\"])\n", - " include_string = ''.join([\"section_data_set\",\n", - " \"(specimen\",\n", - " \"(stereotaxic_injections\",\n", - " \"(primary_injection_structure)\",\n", - " \"))\"])\n", - " tabular_list = ['distinct+specimens.name+as+specimen_name',\n", - " 'stereotaxic_injection_coordinates.coordinates_ap',\n", - " 'stereotaxic_injection_coordinates.coordinates_dv',\n", - " 'stereotaxic_injection_coordinates.coordinates_ml',\n", - " 'data_sets.id+as+data_set_id',\n", - " 'stereotaxic_injections.primary_injection_structure_id',\n", - " 'structures.acronym+as+target_structure',\n", - " 'hemispheres.symbol+as+hemisphere',\n", - " 'projection_structure_unionizes.is_injection',\n", - " 'projection_structure_unionizes.sum_pixels',\n", - " 'projection_structure_unionizes.sum_projection_pixels',\n", - " 'projection_structure_unionizes.sum_pixel_intensity',\n", - " 'projection_structure_unionizes.sum_projection_pixel_intensity',\n", - " 'projection_structure_unionizes.projection_density',\n", - " 'projection_structure_unionizes.projection_intensity',\n", - " 'projection_structure_unionizes.projection_energy',\n", - " 'projection_structure_unionizes.volume',\n", - " 'projection_structure_unionizes.projection_volume',\n", - " 'projection_structure_unionizes.normalized_projection_volume',\n", - " 'projection_structure_unionizes.max_voxel_density',\n", - " 'projection_structure_unionizes.max_voxel_x',\n", - " 'projection_structure_unionizes.max_voxel_y',\n", - " 'projection_structure_unionizes.max_voxel_z']\n", - " model_stage = rma.model_stage('ProjectionStructureUnionize',\n", - " criteria=criteria_string,\n", - " include=include_string,\n", - " tabular=[\"'%s'\" % ','.join(tabular_list)], # TODO: better handling of tabular quotes\n", - " num_rows='all')\n", - " url = rma.build_query_url(model_stage, fmt='csv')\n", - " \n", + " criteria_string = \"\".join(\n", + " [\n", + " \"[is_injection$eq'f'],\",\n", + " \"hemisphere,\",\n", + " \"structure,\",\n", + " \"section_data_set[id$eq%d]\" % (section_data_set_id),\n", + " \"(specimen\",\n", + " \"(stereotaxic_injections\",\n", + " \"(primary_injection_structure,stereotaxic_injection_coordinates)\",\n", + " \"))\",\n", + " ]\n", + " )\n", + " include_string = \"\".join(\n", + " [\"section_data_set\", \"(specimen\", \"(stereotaxic_injections\", \"(primary_injection_structure)\", \"))\"]\n", + " )\n", + " tabular_list = [\n", + " \"distinct+specimens.name+as+specimen_name\",\n", + " \"stereotaxic_injection_coordinates.coordinates_ap\",\n", + " \"stereotaxic_injection_coordinates.coordinates_dv\",\n", + " \"stereotaxic_injection_coordinates.coordinates_ml\",\n", + " \"data_sets.id+as+data_set_id\",\n", + " \"stereotaxic_injections.primary_injection_structure_id\",\n", + " \"structures.acronym+as+target_structure\",\n", + " \"hemispheres.symbol+as+hemisphere\",\n", + " \"projection_structure_unionizes.is_injection\",\n", + " \"projection_structure_unionizes.sum_pixels\",\n", + " \"projection_structure_unionizes.sum_projection_pixels\",\n", + " \"projection_structure_unionizes.sum_pixel_intensity\",\n", + " \"projection_structure_unionizes.sum_projection_pixel_intensity\",\n", + " \"projection_structure_unionizes.projection_density\",\n", + " \"projection_structure_unionizes.projection_intensity\",\n", + " \"projection_structure_unionizes.projection_energy\",\n", + " \"projection_structure_unionizes.volume\",\n", + " \"projection_structure_unionizes.projection_volume\",\n", + " \"projection_structure_unionizes.normalized_projection_volume\",\n", + " \"projection_structure_unionizes.max_voxel_density\",\n", + " \"projection_structure_unionizes.max_voxel_x\",\n", + " \"projection_structure_unionizes.max_voxel_y\",\n", + " \"projection_structure_unionizes.max_voxel_z\",\n", + " ]\n", + " model_stage = rma.model_stage(\n", + " \"ProjectionStructureUnionize\",\n", + " criteria=criteria_string,\n", + " include=include_string,\n", + " tabular=[\"'%s'\" % \",\".join(tabular_list)], # TODO: better handling of tabular quotes\n", + " num_rows=\"all\",\n", + " )\n", + " url = rma.build_query_url(model_stage, fmt=\"csv\")\n", + "\n", " return url\n", "\n", - "print(build_query(183282970))\n", - " \n" + "\n", + "print(build_query(183282970))" ] }, { @@ -782,12 +789,12 @@ } ], "source": [ - "resolution=25\n", + "resolution = 25\n", "mca.download_volumetric_data(\n", - " 'annotation/ccf_2015', \n", - " 'annotation_%d.nrrd' % (resolution),\n", - " save_file_path=str(Path(output_dir) / 'volumetric_data.nrrd')\n", - ")\n" + " \"annotation/ccf_2015\",\n", + " \"annotation_%d.nrrd\" % (resolution),\n", + " save_file_path=str(Path(output_dir) / \"volumetric_data.nrrd\"),\n", + ")" ] }, { diff --git a/doc_template/examples_root/examples/nb/ecephys_data_access.ipynb b/doc_template/examples_root/examples/nb/ecephys_data_access.ipynb index 7e95bc138b..f47091386b 100644 --- a/doc_template/examples_root/examples/nb/ecephys_data_access.ipynb +++ b/doc_template/examples_root/examples/nb/ecephys_data_access.ipynb @@ -179,7 +179,7 @@ }, "outputs": [], "source": [ - "output_dir = '/local1/ecephys_cache_dir' # must be updated to a valid directory in your filesystem\n", + "output_dir = \"/local1/ecephys_cache_dir\" # must be updated to a valid directory in your filesystem\n", "DOWNLOAD_COMPLETE_DATASET = True" ] }, @@ -479,7 +479,7 @@ "source": [ "sessions = cache.get_session_table()\n", "\n", - "print('Total number of sessions: ' + str(len(sessions)))\n", + "print(\"Total number of sessions: \" + str(len(sessions)))\n", "\n", "sessions.head()" ] @@ -657,11 +657,12 @@ } ], "source": [ - "filtered_sessions = sessions[(sessions.sex == 'M') & \\\n", - " (sessions.full_genotype.str.find('Sst') > -1) & \\\n", - " (sessions.session_type == 'brain_observatory_1.1') & \\\n", - " (['VISl' in acronyms for acronyms in \n", - " sessions.ecephys_structure_acronyms])]\n", + "filtered_sessions = sessions[\n", + " (sessions.sex == \"M\")\n", + " & (sessions.full_genotype.str.find(\"Sst\") > -1)\n", + " & (sessions.session_type == \"brain_observatory_1.1\")\n", + " & ([\"VISl\" in acronyms for acronyms in sessions.ecephys_structure_acronyms])\n", + "]\n", "\n", "filtered_sessions.head()" ] @@ -864,7 +865,7 @@ "source": [ "probes = cache.get_probes()\n", "\n", - "print('Total number of probes: ' + str(len(probes)))\n", + "print(\"Total number of probes: \" + str(len(probes)))\n", "\n", "probes.head()" ] @@ -1130,7 +1131,7 @@ "source": [ "channels = cache.get_channels()\n", "\n", - "print('Total number of channels: ' + str(len(channels)))\n", + "print(\"Total number of channels: \" + str(len(channels)))\n", "\n", "channels.head()" ] @@ -1192,7 +1193,7 @@ "source": [ "units = cache.get_units()\n", "\n", - "print('Total number of units: ' + str(len(units)))" + "print(\"Total number of units: \" + str(len(units)))" ] }, { @@ -1258,11 +1259,9 @@ } ], "source": [ - "units = cache.get_units(amplitude_cutoff_maximum = np.inf,\n", - " presence_ratio_minimum = -np.inf,\n", - " isi_violations_maximum = np.inf)\n", + "units = cache.get_units(amplitude_cutoff_maximum=np.inf, presence_ratio_minimum=-np.inf, isi_violations_maximum=np.inf)\n", "\n", - "print('Total number of units: ' + str(len(units)))" + "print(\"Total number of units: \" + str(len(units)))" ] }, { @@ -1321,12 +1320,12 @@ } ], "source": [ - "analysis_metrics1 = cache.get_unit_analysis_metrics_by_session_type('brain_observatory_1.1')\n", + "analysis_metrics1 = cache.get_unit_analysis_metrics_by_session_type(\"brain_observatory_1.1\")\n", "\n", - "analysis_metrics2 = cache.get_unit_analysis_metrics_by_session_type('functional_connectivity')\n", + "analysis_metrics2 = cache.get_unit_analysis_metrics_by_session_type(\"functional_connectivity\")\n", "\n", - "print(str(len(analysis_metrics1)) + ' units in table 1')\n", - "print(str(len(analysis_metrics2)) + ' units in table 2')" + "print(str(len(analysis_metrics1)) + \" units in table 1\")\n", + "print(str(len(analysis_metrics2)) + \" units in table 2\")" ] }, { @@ -1384,17 +1383,23 @@ } ], "source": [ - "analysis_metrics1 = cache.get_unit_analysis_metrics_by_session_type('brain_observatory_1.1', amplitude_cutoff_maximum = np.inf,\n", - " presence_ratio_minimum = -np.inf,\n", - " isi_violations_maximum = np.inf)\n", - "\n", - "analysis_metrics2 = cache.get_unit_analysis_metrics_by_session_type('functional_connectivity', amplitude_cutoff_maximum = np.inf,\n", - " presence_ratio_minimum = -np.inf,\n", - " isi_violations_maximum = np.inf)\n", + "analysis_metrics1 = cache.get_unit_analysis_metrics_by_session_type(\n", + " \"brain_observatory_1.1\",\n", + " amplitude_cutoff_maximum=np.inf,\n", + " presence_ratio_minimum=-np.inf,\n", + " isi_violations_maximum=np.inf,\n", + ")\n", + "\n", + "analysis_metrics2 = cache.get_unit_analysis_metrics_by_session_type(\n", + " \"functional_connectivity\",\n", + " amplitude_cutoff_maximum=np.inf,\n", + " presence_ratio_minimum=-np.inf,\n", + " isi_violations_maximum=np.inf,\n", + ")\n", "\n", "all_metrics = pd.concat([analysis_metrics1, analysis_metrics2], sort=False)\n", "\n", - "print(str(len(all_metrics)) + ' units overall')" + "print(str(len(all_metrics)) + \" units overall\")" ] }, { @@ -1491,13 +1496,14 @@ } ], "source": [ - "session = cache.get_session_data(filtered_sessions.index.values[0],\n", - " isi_violations_maximum = np.inf,\n", - " amplitude_cutoff_maximum = np.inf,\n", - " presence_ratio_minimum = -np.inf\n", - " )\n", - "\n", - "print([attr_or_method for attr_or_method in dir(session) if attr_or_method[0] != '_'])" + "session = cache.get_session_data(\n", + " filtered_sessions.index.values[0],\n", + " isi_violations_maximum=np.inf,\n", + " amplitude_cutoff_maximum=np.inf,\n", + " presence_ratio_minimum=-np.inf,\n", + ")\n", + "\n", + "print([attr_or_method for attr_or_method in dir(session) if attr_or_method[0] != \"_\"])" ] }, { @@ -1638,9 +1644,8 @@ "source": [ "if DOWNLOAD_COMPLETE_DATASET:\n", " for session_id, row in sessions.iterrows():\n", - "\n", " truncated_file = True\n", - " directory = os.path.join(output_dir + '/session_' + str(session_id))\n", + " directory = os.path.join(output_dir + \"/session_\" + str(session_id))\n", "\n", " while truncated_file:\n", " session = cache.get_session_data(session_id)\n", @@ -1652,8 +1657,7 @@ " print(\" Truncated spikes file, re-downloading\")\n", "\n", " for probe_id, probe in session.probes.iterrows():\n", - "\n", - " print(' ' + probe.description)\n", + " print(\" \" + probe.description)\n", " truncated_lfp = True\n", "\n", " while truncated_lfp:\n", @@ -1661,7 +1665,7 @@ " lfp = session.get_lfp(probe_id)\n", " truncated_lfp = False\n", " except OSError:\n", - " fname = directory + '/probe_' + str(probe_id) + '_lfp.nwb'\n", + " fname = directory + \"/probe_\" + str(probe_id) + \"_lfp.nwb\"\n", " os.remove(fname)\n", " print(\" Truncated LFP file, re-downloading\")\n", " except ValueError:\n", @@ -1956,19 +1960,19 @@ ], "source": [ "def retrieve_link(session_id):\n", - " \n", " well_known_files = build_and_execute(\n", " (\n", - " \"criteria=model::WellKnownFile\"\n", - " \",rma::criteria,well_known_file_type[name$eq'EcephysNwb']\"\n", - " \"[attachable_type$eq'EcephysSession']\"\n", - " r\"[attachable_id$eq{{session_id}}]\"\n", + " \"criteria=model::WellKnownFile\"\n", + " \",rma::criteria,well_known_file_type[name$eq'EcephysNwb']\"\n", + " \"[attachable_type$eq'EcephysSession']\"\n", + " r\"[attachable_id$eq{{session_id}}]\"\n", " ),\n", - " engine=rma_engine.get_rma_tabular, \n", - " session_id=session_id\n", + " engine=rma_engine.get_rma_tabular,\n", + " session_id=session_id,\n", " )\n", - " \n", - " return 'http://api.brain-map.org/' + well_known_files['download_link'].iloc[0]\n", + "\n", + " return \"http://api.brain-map.org/\" + well_known_files[\"download_link\"].iloc[0]\n", + "\n", "\n", "download_links = [retrieve_link(session_id) for session_id in sessions.index.values]\n", "\n", @@ -2036,7 +2040,6 @@ }, "outputs": [], "source": [ - "\n", "# nwb_path = '/mnt/nvme0/ecephys_cache_dir_10_31/session_721123822/session_721123822.nwb'\n", "\n", "# session = EcephysSession.from_nwb_path(nwb_path, api_kwargs={\n", @@ -2433,7 +2436,6 @@ ], "source": [ "def retrieve_lfp_link(probe_id):\n", - "\n", " well_known_files = build_and_execute(\n", " (\n", " \"criteria=model::WellKnownFile\"\n", @@ -2441,14 +2443,15 @@ " \"[attachable_type$eq'EcephysProbe']\"\n", " r\"[attachable_id$eq{{probe_id}}]\"\n", " ),\n", - " engine=rma_engine.get_rma_tabular, \n", - " probe_id=probe_id\n", + " engine=rma_engine.get_rma_tabular,\n", + " probe_id=probe_id,\n", " )\n", "\n", " if well_known_files.shape[0] != 1:\n", - " return 'file for probe ' + str(probe_id) + ' not found'\n", - " \n", - " return 'http://api.brain-map.org/' + well_known_files.loc[0, \"download_link\"]\n", + " return \"file for probe \" + str(probe_id) + \" not found\"\n", + "\n", + " return \"http://api.brain-map.org/\" + well_known_files.loc[0, \"download_link\"]\n", + "\n", "\n", "probes = cache.get_probes()\n", "\n", diff --git a/doc_template/examples_root/examples/nb/ecephys_lfp_analysis.ipynb b/doc_template/examples_root/examples/nb/ecephys_lfp_analysis.ipynb index 6ff62e5580..ad67dd5351 100644 --- a/doc_template/examples_root/examples/nb/ecephys_lfp_analysis.ipynb +++ b/doc_template/examples_root/examples/nb/ecephys_lfp_analysis.ipynb @@ -87,6 +87,7 @@ "import pandas as pd\n", "\n", "import matplotlib.pyplot as plt\n", + "\n", "%matplotlib inline\n", "\n", "from allensdk.brain_observatory.ecephys.ecephys_project_cache import EcephysProjectCache" @@ -117,7 +118,7 @@ "outputs": [], "source": [ "# Example cache directory path, it determines where downloaded data will be stored\n", - "output_dir = '/local1/ecephys_cache_dir/'" + "output_dir = \"/local1/ecephys_cache_dir/\"" ] }, { @@ -207,7 +208,7 @@ "output_type": "stream", "text": [ "WARNING:root:downloading a 2723.916MiB file from http://api.brain-map.org//api/v2/well_known_file_download/1026124469\n", - "Downloading: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 2.86G/2.86G [03:33<00:00, 13.4MB/s]\n", + "Downloading: 100%|██████████| 2.86G/2.86G [03:33<00:00, 13.4MB/s]\n", "/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/hdmf/spec/namespace.py:531: UserWarning: Ignoring cached namespace 'hdmf-common' version 1.1.3 because version 1.8.0 is already loaded.\n", " warn(\"Ignoring cached namespace '%s' version %s because version %s is already loaded.\"\n", "/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/hdmf/spec/namespace.py:531: UserWarning: Ignoring cached namespace 'core' version 2.2.2 because version 2.6.0-alpha is already loaded.\n", @@ -445,9 +446,9 @@ "source": [ "probes = cache.get_probes()\n", "\n", - "print('Fraction of probes with LFP: ' + str(np.around( np.sum(probes.has_lfp_data) / len(probes), 3) ) )\n", - "print(' ')\n", - "print('Sessions with missing LFP files: ' + str(list(probes[~probes.has_lfp_data].ecephys_session_id.unique())))" + "print(\"Fraction of probes with LFP: \" + str(np.around(np.sum(probes.has_lfp_data) / len(probes), 3)))\n", + "print(\" \")\n", + "print(\"Sessions with missing LFP files: \" + str(list(probes[~probes.has_lfp_data].ecephys_session_id.unique())))" ] }, { @@ -572,16 +573,16 @@ } ], "source": [ - "plt.rcParams.update({'font.size': 14})\n", + "plt.rcParams.update({\"font.size\": 14})\n", "\n", "x_coords = session.channels.left_right_ccf_coordinate\n", "y_coords = session.channels.anterior_posterior_ccf_coordinate\n", "color = session.channels.probe_vertical_position\n", "\n", - "plt.figure(figsize=(8,8))\n", - "_ = plt.scatter(-x_coords[x_coords > 0], -y_coords[x_coords > 0], c=color[x_coords > 0], cmap='inferno')\n", - "_ = plt.xlabel('<< lateral --- medial >>')\n", - "_ = plt.ylabel('<< posterior --- anterior >>')" + "plt.figure(figsize=(8, 8))\n", + "_ = plt.scatter(-x_coords[x_coords > 0], -y_coords[x_coords > 0], c=color[x_coords > 0], cmap=\"inferno\")\n", + "_ = plt.xlabel(\"<< lateral --- medial >>\")\n", + "_ = plt.ylabel(\"<< posterior --- anterior >>\")" ] }, { @@ -700,9 +701,12 @@ } ], "source": [ - "{session.probes.loc[probe_id].description : \n", - " list(session.channels[session.channels.probe_id == probe_id].ecephys_structure_acronym.unique())\n", - " for probe_id in session.probes.index.values}" + "{\n", + " session.probes.loc[probe_id].description: list(\n", + " session.channels[session.channels.probe_id == probe_id].ecephys_structure_acronym.unique()\n", + " )\n", + " for probe_id in session.probes.index.values\n", + "}" ] }, { @@ -786,12 +790,12 @@ "/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/hdmf/spec/namespace.py:531: UserWarning: Ignoring cached namespace 'core' version 2.2.2 because version 2.6.0-alpha is already loaded.\n", " warn(\"Ignoring cached namespace '%s' version %s because version %s is already loaded.\"\n", "WARNING:root:downloading a 2345.194MiB file from http://api.brain-map.org//api/v2/well_known_file_download/1026124475\n", - "Downloading: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 2.46G/2.46G [03:11<00:00, 12.8MB/s]\n" + "Downloading: 100%|██████████| 2.46G/2.46G [03:11<00:00, 12.8MB/s]\n" ] } ], "source": [ - "probe_id = session.probes[session.probes.description == 'probeE'].index.values[0]\n", + "probe_id = session.probes[session.probes.description == \"probeE\"].index.values[0]\n", "\n", "lfp = session.get_lfp(probe_id)" ] @@ -961,7 +965,7 @@ "\n", ".xr-section-summary-in + label:before {\n", " display: inline-block;\n", - " content: '\u25ba';\n", + " content: '►';\n", " font-size: 11px;\n", " width: 15px;\n", " text-align: center;\n", @@ -972,7 +976,7 @@ "}\n", "\n", ".xr-section-summary-in:checked + label:before {\n", - " content: '\u25bc';\n", + " content: '▼';\n", "}\n", "\n", ".xr-section-summary-in:checked + label > span {\n", @@ -1469,7 +1473,7 @@ "\n", ".xr-section-summary-in + label:before {\n", " display: inline-block;\n", - " content: '\u25ba';\n", + " content: '►';\n", " font-size: 11px;\n", " width: 15px;\n", " text-align: center;\n", @@ -1480,7 +1484,7 @@ "}\n", "\n", ".xr-section-summary-in:checked + label:before {\n", - " content: '\u25bc';\n", + " content: '▼';\n", "}\n", "\n", ".xr-section-summary-in:checked + label > span {\n", @@ -1807,7 +1811,7 @@ } ], "source": [ - "lfp_slice = lfp.sel(time=slice(100,101))\n", + "lfp_slice = lfp.sel(time=slice(100, 101))\n", "\n", "lfp_slice" ] @@ -1874,10 +1878,10 @@ } ], "source": [ - "plt.figure(figsize=(10,2))\n", + "plt.figure(figsize=(10, 2))\n", "_ = plt.plot(lfp_slice.time, lfp_slice.sel(channel=lfp_slice.channel[10]))\n", - "plt.xlabel('Time (s)')\n", - "plt.ylabel('LFP (V)')" + "plt.xlabel(\"Time (s)\")\n", + "plt.ylabel(\"LFP (V)\")" ] }, { @@ -1930,11 +1934,11 @@ } ], "source": [ - "plt.figure(figsize=(8,8))\n", - "im = plt.imshow(lfp_slice.T,aspect='auto',origin='lower',vmin=-1e-3, vmax=1e-3)\n", + "plt.figure(figsize=(8, 8))\n", + "im = plt.imshow(lfp_slice.T, aspect=\"auto\", origin=\"lower\", vmin=-1e-3, vmax=1e-3)\n", "_ = plt.colorbar(im, fraction=0.036, pad=0.04)\n", - "_ = plt.xlabel('Sample number')\n", - "_ = plt.ylabel('Channel index')" + "_ = plt.xlabel(\"Sample number\")\n", + "_ = plt.ylabel(\"Channel index\")" ] }, { @@ -1996,16 +2000,17 @@ } ], "source": [ - "channel_ids = session.channels[(session.channels.probe_id == probe_id) & \\\n", - " (session.channels.ecephys_structure_acronym.isin(['CA1','CA3','DG']))].index.values\n", + "channel_ids = session.channels[\n", + " (session.channels.probe_id == probe_id) & (session.channels.ecephys_structure_acronym.isin([\"CA1\", \"CA3\", \"DG\"]))\n", + "].index.values\n", "\n", "lfp_slice2 = lfp_slice.sel(channel=slice(np.min(channel_ids), np.max(channel_ids)))\n", "\n", - "plt.figure(figsize=(8,4))\n", - "im = plt.imshow(lfp_slice2.T,aspect='auto',origin='lower',vmin=-1e-3, vmax=1e-3)\n", + "plt.figure(figsize=(8, 4))\n", + "im = plt.imshow(lfp_slice2.T, aspect=\"auto\", origin=\"lower\", vmin=-1e-3, vmax=1e-3)\n", "_ = plt.colorbar(im, fraction=0.036, pad=0.04)\n", - "_ = plt.xlabel('Sample number')\n", - "_ = plt.ylabel('Channel index')" + "_ = plt.xlabel(\"Sample number\")\n", + "_ = plt.ylabel(\"Channel index\")" ] }, { @@ -2103,7 +2108,7 @@ } ], "source": [ - "presentation_table = session.stimulus_presentations[session.stimulus_presentations.stimulus_name == 'flashes']\n", + "presentation_table = session.stimulus_presentations[session.stimulus_presentations.stimulus_name == \"flashes\"]\n", "\n", "presentation_times = presentation_table.start_time.values\n", "presentation_ids = presentation_table.index.values" @@ -2148,16 +2153,17 @@ }, "outputs": [], "source": [ - "trial_window = np.arange(-0.5, 0.5, 1/500)\n", + "trial_window = np.arange(-0.5, 0.5, 1 / 500)\n", "time_selection = np.concatenate([trial_window + t for t in presentation_times])\n", "\n", - "inds = pd.MultiIndex.from_product((presentation_ids, trial_window), \n", - " names=('presentation_id', 'time_from_presentation_onset'))\n", + "inds = pd.MultiIndex.from_product(\n", + " (presentation_ids, trial_window), names=(\"presentation_id\", \"time_from_presentation_onset\")\n", + ")\n", "\n", - "ds = lfp.sel(time = time_selection, method='nearest').to_dataset(name = 'aligned_lfp')\n", - "ds = ds.assign(time=inds).unstack('time')\n", + "ds = lfp.sel(time=time_selection, method=\"nearest\").to_dataset(name=\"aligned_lfp\")\n", + "ds = ds.assign(time=inds).unstack(\"time\")\n", "\n", - "aligned_lfp = ds['aligned_lfp']" + "aligned_lfp = ds[\"aligned_lfp\"]" ] }, { @@ -2214,11 +2220,11 @@ } ], "source": [ - "plt.figure(figsize=(8,6))\n", - "im = plt.imshow(aligned_lfp.mean(dim='presentation_id'), aspect='auto', origin='lower', vmin=-1e-4, vmax=1e-4)\n", + "plt.figure(figsize=(8, 6))\n", + "im = plt.imshow(aligned_lfp.mean(dim=\"presentation_id\"), aspect=\"auto\", origin=\"lower\", vmin=-1e-4, vmax=1e-4)\n", "_ = plt.colorbar(im, fraction=0.036, pad=0.04)\n", - "_ = plt.xlabel('Sample number')\n", - "_ = plt.ylabel('Channel index')" + "_ = plt.xlabel(\"Sample number\")\n", + "_ = plt.ylabel(\"Channel index\")" ] }, { @@ -2338,11 +2344,13 @@ } ], "source": [ - "units_of_interest = session.units[(session.units.probe_id == probe_id) &\n", - " (session.units.ecephys_structure_acronym.str.find('VIS') > -1) &\n", - " (session.units.firing_rate > 10) & \n", - " (session.units.nn_hit_rate > 0.95)]\n", - " \n", + "units_of_interest = session.units[\n", + " (session.units.probe_id == probe_id)\n", + " & (session.units.ecephys_structure_acronym.str.find(\"VIS\") > -1)\n", + " & (session.units.firing_rate > 10)\n", + " & (session.units.nn_hit_rate > 0.95)\n", + "]\n", + "\n", "len(units_of_interest)" ] }, @@ -2455,8 +2463,9 @@ } ], "source": [ - "channel_id = session.channels[(session.channels.probe_channel_number == channel_index) & \n", - " (session.channels.probe_id == probe_id)].index.values[0]\n", + "channel_id = session.channels[\n", + " (session.channels.probe_channel_number == channel_index) & (session.channels.probe_id == probe_id)\n", + "].index.values[0]\n", "\n", "channel_id" ] @@ -2507,8 +2516,8 @@ "\n", "times_in_range = spike_times[(spike_times > start_time) & (spike_times < end_time)]\n", "\n", - "lfp_data = lfp.sel(time = slice(start_time, end_time))\n", - "lfp_data = lfp_data.sel(channel = channel_id, method='nearest')" + "lfp_data = lfp.sel(time=slice(start_time, end_time))\n", + "lfp_data = lfp_data.sel(channel=channel_id, method=\"nearest\")" ] }, { @@ -2562,9 +2571,9 @@ ], "source": [ "_ = plt.plot(lfp_data.time, lfp_data)\n", - "_ = plt.plot(times_in_range, np.ones(times_in_range.shape)*3e-4, '.r')\n", - "_ = plt.xlabel('Time (s)')\n", - "_ = plt.ylabel('LFP (V)')" + "_ = plt.plot(times_in_range, np.ones(times_in_range.shape) * 3e-4, \".r\")\n", + "_ = plt.xlabel(\"Time (s)\")\n", + "_ = plt.ylabel(\"LFP (V)\")" ] }, { @@ -2774,7 +2783,7 @@ "\n", ".xr-section-summary-in + label:before {\n", " display: inline-block;\n", - " content: '\u25ba';\n", + " content: '►';\n", " font-size: 11px;\n", " width: 15px;\n", " text-align: center;\n", @@ -2785,7 +2794,7 @@ "}\n", "\n", ".xr-section-summary-in:checked + label:before {\n", - " content: '\u25bc';\n", + " content: '▼';\n", "}\n", "\n", ".xr-section-summary-in:checked + label > span {\n", @@ -3229,9 +3238,9 @@ "source": [ "from scipy.ndimage.filters import gaussian_filter\n", "\n", - "_ = plt.figure(figsize=(10,10))\n", + "_ = plt.figure(figsize=(10, 10))\n", "\n", - "filtered_csd = gaussian_filter(csd.data, sigma=(5,1))\n", + "filtered_csd = gaussian_filter(csd.data, sigma=(5, 1))\n", "\n", "fig, ax = plt.subplots(figsize=(6, 6))\n", "\n", @@ -3293,9 +3302,13 @@ } ], "source": [ - "list(session.channels[(session.channels.probe_id == probe_id) &\n", - " (session.channels.probe_vertical_position > 700) &\n", - " (session.channels.probe_vertical_position < 1200)].ecephys_structure_acronym.unique())" + "list(\n", + " session.channels[\n", + " (session.channels.probe_id == probe_id)\n", + " & (session.channels.probe_vertical_position > 700)\n", + " & (session.channels.probe_vertical_position < 1200)\n", + " ].ecephys_structure_acronym.unique()\n", + ")" ] }, { diff --git a/doc_template/examples_root/examples/nb/ecephys_optotagging.ipynb b/doc_template/examples_root/examples/nb/ecephys_optotagging.ipynb index d9d6a518cb..0e2705f9d2 100644 --- a/doc_template/examples_root/examples/nb/ecephys_optotagging.ipynb +++ b/doc_template/examples_root/examples/nb/ecephys_optotagging.ipynb @@ -98,6 +98,7 @@ "import xarray as xr\n", "\n", "import matplotlib.pyplot as plt\n", + "\n", "%matplotlib inline\n", "\n", "from allensdk.brain_observatory.ecephys.ecephys_project_cache import EcephysProjectCache" @@ -153,7 +154,7 @@ "outputs": [], "source": [ "# Example cache directory path, it determines where downloaded data will be stored\n", - "output_dir = '/local1/ecephys_cache_dir/'" + "output_dir = \"/local1/ecephys_cache_dir/\"" ] }, { @@ -554,7 +555,7 @@ } ], "source": [ - "pvalb_sessions = sessions[sessions.full_genotype.str.match('Pvalb')]\n", + "pvalb_sessions = sessions[sessions.full_genotype.str.match(\"Pvalb\")]\n", "\n", "pvalb_sessions" ] @@ -1078,9 +1079,11 @@ } ], "source": [ - "columns = ['stimulus_name', 'duration','level']\n", + "columns = [\"stimulus_name\", \"duration\", \"level\"]\n", "\n", - "session.optogenetic_stimulation_epochs.drop_duplicates(columns).sort_values(by=columns).drop(columns=['start_time','stop_time'])" + "session.optogenetic_stimulation_epochs.drop_duplicates(columns).sort_values(by=columns).drop(\n", + " columns=[\"start_time\", \"stop_time\"]\n", + ")" ] }, { @@ -1264,44 +1267,43 @@ } ], "source": [ - "trials = session.optogenetic_stimulation_epochs[(session.optogenetic_stimulation_epochs.duration > 0.009) & \\\n", - " (session.optogenetic_stimulation_epochs.duration < 0.02)]\n", + "trials = session.optogenetic_stimulation_epochs[\n", + " (session.optogenetic_stimulation_epochs.duration > 0.009) & (session.optogenetic_stimulation_epochs.duration < 0.02)\n", + "]\n", "\n", - "units = session.units[session.units.ecephys_structure_acronym.str.match('VIS')]\n", + "units = session.units[session.units.ecephys_structure_acronym.str.match(\"VIS\")]\n", "\n", - "time_resolution = 0.0005 # 0.5 ms bins\n", + "time_resolution = 0.0005 # 0.5 ms bins\n", "\n", "bin_edges = np.arange(-0.01, 0.025, time_resolution)\n", "\n", + "\n", "def optotagging_spike_counts(bin_edges, trials, units):\n", - " \n", " time_resolution = np.mean(np.diff(bin_edges))\n", "\n", - " spike_matrix = np.zeros( (len(trials), len(bin_edges), len(units)) )\n", + " spike_matrix = np.zeros((len(trials), len(bin_edges), len(units)))\n", "\n", " for unit_idx, unit_id in enumerate(units.index.values):\n", - "\n", " spike_times = session.spike_times[unit_id]\n", "\n", " for trial_idx, trial_start in enumerate(trials.start_time.values):\n", + " in_range = (spike_times > (trial_start + bin_edges[0])) * (spike_times < (trial_start + bin_edges[-1]))\n", "\n", - " in_range = (spike_times > (trial_start + bin_edges[0])) * \\\n", - " (spike_times < (trial_start + bin_edges[-1]))\n", - "\n", - " binned_times = ((spike_times[in_range] - (trial_start + bin_edges[0])) / time_resolution).astype('int')\n", + " binned_times = ((spike_times[in_range] - (trial_start + bin_edges[0])) / time_resolution).astype(\"int\")\n", " spike_matrix[trial_idx, binned_times, unit_idx] = 1\n", "\n", " return xr.DataArray(\n", - " name='spike_counts',\n", + " name=\"spike_counts\",\n", " data=spike_matrix,\n", " coords={\n", - " 'trial_id': trials.index.values,\n", - " 'time_relative_to_stimulus_onset': bin_edges,\n", - " 'unit_id': units.index.values\n", + " \"trial_id\": trials.index.values,\n", + " \"time_relative_to_stimulus_onset\": bin_edges,\n", + " \"unit_id\": units.index.values,\n", " },\n", - " dims=['trial_id', 'time_relative_to_stimulus_onset', 'unit_id']\n", + " dims=[\"trial_id\", \"time_relative_to_stimulus_onset\", \"unit_id\"],\n", " )\n", "\n", + "\n", "da = optotagging_spike_counts(bin_edges, trials, units)" ] }, @@ -1362,23 +1364,26 @@ ], "source": [ "def plot_optotagging_response(da):\n", - "\n", - " plt.figure(figsize=(5,10))\n", - "\n", - " plt.imshow(da.mean(dim='trial_id').T / time_resolution, \n", - " extent=[np.min(bin_edges), np.max(bin_edges),\n", - " 0, len(units)],\n", - " aspect='auto', vmin=0, vmax=200) \n", + " plt.figure(figsize=(5, 10))\n", + "\n", + " plt.imshow(\n", + " da.mean(dim=\"trial_id\").T / time_resolution,\n", + " extent=[np.min(bin_edges), np.max(bin_edges), 0, len(units)],\n", + " aspect=\"auto\",\n", + " vmin=0,\n", + " vmax=200,\n", + " )\n", "\n", " for bound in [0.0005, 0.0095]:\n", - " plt.plot([bound, bound],[0, len(units)], ':', color='white', linewidth=1.0)\n", + " plt.plot([bound, bound], [0, len(units)], \":\", color=\"white\", linewidth=1.0)\n", "\n", - " plt.xlabel('Time (s)')\n", - " plt.ylabel('Unit #')\n", + " plt.xlabel(\"Time (s)\")\n", + " plt.ylabel(\"Unit #\")\n", "\n", " cb = plt.colorbar(fraction=0.046, pad=0.04)\n", - " cb.set_label('Mean firing rate (Hz)')\n", - " \n", + " cb.set_label(\"Mean firing rate (Hz)\")\n", + "\n", + "\n", "plot_optotagging_response(da)" ] }, @@ -1478,13 +1483,13 @@ }, "outputs": [], "source": [ - "baseline = da.sel(time_relative_to_stimulus_onset=slice(-0.01,-0.002))\n", + "baseline = da.sel(time_relative_to_stimulus_onset=slice(-0.01, -0.002))\n", "\n", - "baseline_rate = baseline.sum(dim='time_relative_to_stimulus_onset').mean(dim='trial_id') / 0.008\n", + "baseline_rate = baseline.sum(dim=\"time_relative_to_stimulus_onset\").mean(dim=\"trial_id\") / 0.008\n", "\n", - "evoked = da.sel(time_relative_to_stimulus_onset=slice(0.001,0.009))\n", + "evoked = da.sel(time_relative_to_stimulus_onset=slice(0.001, 0.009))\n", "\n", - "evoked_rate = evoked.sum(dim='time_relative_to_stimulus_onset').mean(dim='trial_id') / 0.008" + "evoked_rate = evoked.sum(dim=\"time_relative_to_stimulus_onset\").mean(dim=\"trial_id\") / 0.008" ] }, { @@ -1543,18 +1548,18 @@ } ], "source": [ - "plt.figure(figsize=(5,5))\n", + "plt.figure(figsize=(5, 5))\n", "\n", "plt.scatter(baseline_rate, evoked_rate, s=3)\n", "\n", "axis_limit = 250\n", - "plt.plot([0,axis_limit],[0,axis_limit], ':k')\n", - "plt.plot([0,axis_limit],[0,axis_limit*2], ':r')\n", - "plt.xlim([0,axis_limit])\n", - "plt.ylim([0,axis_limit])\n", + "plt.plot([0, axis_limit], [0, axis_limit], \":k\")\n", + "plt.plot([0, axis_limit], [0, axis_limit * 2], \":r\")\n", + "plt.xlim([0, axis_limit])\n", + "plt.ylim([0, axis_limit])\n", "\n", - "plt.xlabel('Baseline rate (Hz)')\n", - "_ = plt.ylabel('Evoked rate (Hz)')" + "plt.xlabel(\"Baseline rate (Hz)\")\n", + "_ = plt.ylabel(\"Evoked rate (Hz)\")" ] }, { @@ -1622,7 +1627,7 @@ } ], "source": [ - "cre_pos_units = da.unit_id[(evoked_rate / (baseline_rate + 1)) > 2].values # add 1 to prevent divide-by-zero errors\n", + "cre_pos_units = da.unit_id[(evoked_rate / (baseline_rate + 1)) > 2].values # add 1 to prevent divide-by-zero errors\n", "\n", "cre_pos_units" ] @@ -1693,18 +1698,17 @@ } ], "source": [ - "plt.figure(figsize=(5,5))\n", + "plt.figure(figsize=(5, 5))\n", "\n", "for unit_id in cre_pos_units:\n", - " \n", " peak_channel = session.units.loc[unit_id].peak_channel_id\n", - " wv = session.mean_waveforms[unit_id].sel(channel_id = peak_channel)\n", - " \n", - " plt.plot(wv.time * 1000, wv, 'k', alpha=0.3)\n", + " wv = session.mean_waveforms[unit_id].sel(channel_id=peak_channel)\n", + "\n", + " plt.plot(wv.time * 1000, wv, \"k\", alpha=0.3)\n", "\n", - "plt.xlabel('Time (ms)')\n", - "plt.ylabel('Amplitude (microvolts)')\n", - "_ =plt.plot([1.0, 1.0],[-160, 100],':c')" + "plt.xlabel(\"Time (ms)\")\n", + "plt.ylabel(\"Amplitude (microvolts)\")\n", + "_ = plt.plot([1.0, 1.0], [-160, 100], \":c\")" ] }, { @@ -1809,7 +1813,7 @@ } ], "source": [ - "sst_sessions = sessions[sessions.full_genotype.str.match('Sst')]\n", + "sst_sessions = sessions[sessions.full_genotype.str.match(\"Sst\")]\n", "\n", "session = cache.get_session_data(sst_sessions.index.values[-1])" ] @@ -1866,10 +1870,11 @@ } ], "source": [ - "trials = session.optogenetic_stimulation_epochs[(session.optogenetic_stimulation_epochs.duration > 0.009) & \\\n", - " (session.optogenetic_stimulation_epochs.duration < 0.02)]\n", + "trials = session.optogenetic_stimulation_epochs[\n", + " (session.optogenetic_stimulation_epochs.duration > 0.009) & (session.optogenetic_stimulation_epochs.duration < 0.02)\n", + "]\n", "\n", - "units = session.units[session.units.ecephys_structure_acronym.str.match('VIS')]\n", + "units = session.units[session.units.ecephys_structure_acronym.str.match(\"VIS\")]\n", "\n", "bin_edges = np.arange(-0.01, 0.025, 0.0005)\n", "\n", @@ -1962,13 +1967,13 @@ }, "outputs": [], "source": [ - "baseline = da.sel(time_relative_to_stimulus_onset=slice(-0.01,-0.002))\n", + "baseline = da.sel(time_relative_to_stimulus_onset=slice(-0.01, -0.002))\n", "\n", - "baseline_rate = baseline.sum(dim='time_relative_to_stimulus_onset').mean(dim='trial_id') / 0.008\n", + "baseline_rate = baseline.sum(dim=\"time_relative_to_stimulus_onset\").mean(dim=\"trial_id\") / 0.008\n", "\n", - "evoked = da.sel(time_relative_to_stimulus_onset=slice(0.001,0.009))\n", + "evoked = da.sel(time_relative_to_stimulus_onset=slice(0.001, 0.009))\n", "\n", - "evoked_rate = evoked.sum(dim='time_relative_to_stimulus_onset').mean(dim='trial_id') / 0.008" + "evoked_rate = evoked.sum(dim=\"time_relative_to_stimulus_onset\").mean(dim=\"trial_id\") / 0.008" ] }, { @@ -2007,18 +2012,18 @@ } ], "source": [ - "plt.figure(figsize=(5,5))\n", + "plt.figure(figsize=(5, 5))\n", "\n", "plt.scatter(baseline_rate, evoked_rate, s=3)\n", "\n", "axis_limit = 175\n", - "plt.plot([0,axis_limit],[0,axis_limit], ':k')\n", - "plt.plot([0,axis_limit],[0,axis_limit*2], ':r')\n", - "plt.xlim([0,axis_limit])\n", - "plt.ylim([0,axis_limit])\n", + "plt.plot([0, axis_limit], [0, axis_limit], \":k\")\n", + "plt.plot([0, axis_limit], [0, axis_limit * 2], \":r\")\n", + "plt.xlim([0, axis_limit])\n", + "plt.ylim([0, axis_limit])\n", "\n", - "plt.xlabel('Baseline rate (Hz)')\n", - "_ = plt.ylabel('Evoked rate (Hz)')" + "plt.xlabel(\"Baseline rate (Hz)\")\n", + "_ = plt.ylabel(\"Evoked rate (Hz)\")" ] }, { @@ -2089,18 +2094,17 @@ "source": [ "cre_pos_units = da.unit_id[(evoked_rate / (baseline_rate + 1)) > 2].values\n", "\n", - "plt.figure(figsize=(5,5))\n", + "plt.figure(figsize=(5, 5))\n", "\n", "for unit_id in cre_pos_units:\n", - " \n", " peak_channel = session.units.loc[unit_id].peak_channel_id\n", - " wv = session.mean_waveforms[unit_id].sel(channel_id = peak_channel)\n", - " \n", - " plt.plot(wv.time * 1000, wv, 'k', alpha=0.3)\n", + " wv = session.mean_waveforms[unit_id].sel(channel_id=peak_channel)\n", + "\n", + " plt.plot(wv.time * 1000, wv, \"k\", alpha=0.3)\n", "\n", - "plt.xlabel('Time (ms)')\n", - "plt.ylabel('Amplitude (microvolts)')\n", - "_ =plt.plot([1.0, 1.0],[-160, 100],':c')" + "plt.xlabel(\"Time (ms)\")\n", + "plt.ylabel(\"Amplitude (microvolts)\")\n", + "_ = plt.plot([1.0, 1.0], [-160, 100], \":c\")" ] }, { @@ -2181,7 +2185,7 @@ } ], "source": [ - "vip_sessions = sessions[sessions.full_genotype.str.match('Vip')]\n", + "vip_sessions = sessions[sessions.full_genotype.str.match(\"Vip\")]\n", "\n", "session = cache.get_session_data(vip_sessions.index.values[-1])" ] @@ -2238,10 +2242,11 @@ } ], "source": [ - "trials = session.optogenetic_stimulation_epochs[(session.optogenetic_stimulation_epochs.duration > 0.009) & \\\n", - " (session.optogenetic_stimulation_epochs.duration < 0.02)]\n", + "trials = session.optogenetic_stimulation_epochs[\n", + " (session.optogenetic_stimulation_epochs.duration > 0.009) & (session.optogenetic_stimulation_epochs.duration < 0.02)\n", + "]\n", "\n", - "units = session.units[session.units.ecephys_structure_acronym.str.match('VIS')]\n", + "units = session.units[session.units.ecephys_structure_acronym.str.match(\"VIS\")]\n", "\n", "bin_edges = np.arange(-0.01, 0.025, 0.0005)\n", "\n", diff --git a/doc_template/examples_root/examples/nb/ecephys_quality_metrics.ipynb b/doc_template/examples_root/examples/nb/ecephys_quality_metrics.ipynb index 6841496507..b5aef6a852 100644 --- a/doc_template/examples_root/examples/nb/ecephys_quality_metrics.ipynb +++ b/doc_template/examples_root/examples/nb/ecephys_quality_metrics.ipynb @@ -149,6 +149,7 @@ "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", + "\n", "%matplotlib inline\n", "\n", "from allensdk.brain_observatory.ecephys.ecephys_project_cache import EcephysProjectCache" @@ -179,7 +180,7 @@ "outputs": [], "source": [ "# Example cache directory path, it determines where downloaded data will be stored\n", - "output_dir = '/local1/ecephys_cache_dir/'" + "output_dir = \"/local1/ecephys_cache_dir/\"" ] }, { @@ -305,9 +306,7 @@ } ], "source": [ - "units = cache.get_units(amplitude_cutoff_maximum = np.inf,\n", - " presence_ratio_minimum = -np.inf,\n", - " isi_violations_maximum = np.inf)\n", + "units = cache.get_units(amplitude_cutoff_maximum=np.inf, presence_ratio_minimum=-np.inf, isi_violations_maximum=np.inf)\n", "\n", "len(units)" ] @@ -363,10 +362,11 @@ ], "source": [ "from scipy.ndimage.filters import gaussian_filter1d\n", - "plt.rcParams.update({'font.size': 14})\n", + "\n", + "plt.rcParams.update({\"font.size\": 14})\n", + "\n", "\n", "def plot_metric(data, bins, x_axis_label, color, max_value=-1):\n", - " \n", " h, b = np.histogram(data, bins=bins, density=True)\n", "\n", " x = b[:-1]\n", @@ -375,11 +375,11 @@ " plt.plot(x, y, color=color)\n", " plt.xlabel(x_axis_label)\n", " plt.gca().get_yaxis().set_visible(False)\n", - " [plt.gca().spines[loc].set_visible(False) for loc in ['right', 'top', 'left']]\n", + " [plt.gca().spines[loc].set_visible(False) for loc in [\"right\", \"top\", \"left\"]]\n", " if max_value < np.max(y) * 1.1:\n", " max_value = np.max(y) * 1.1\n", " plt.ylim([0, max_value])\n", - " \n", + "\n", " return max_value" ] }, @@ -450,10 +450,10 @@ } ], "source": [ - "data = units['firing_rate']\n", - "bins = np.linspace(0,50,100)\n", + "data = units[\"firing_rate\"]\n", + "bins = np.linspace(0, 50, 100)\n", "\n", - "max_value = plot_metric(data, bins, 'Firing rate (Hz)', 'red')" + "max_value = plot_metric(data, bins, \"Firing rate (Hz)\", \"red\")" ] }, { @@ -506,10 +506,10 @@ } ], "source": [ - "data = np.log10(units['firing_rate'])\n", - "bins = np.linspace(-3,2,100)\n", + "data = np.log10(units[\"firing_rate\"])\n", + "bins = np.linspace(-3, 2, 100)\n", "\n", - "max_value = plot_metric(data, bins, 'log$_{10}$ firing rate (Hz)', 'red')" + "max_value = plot_metric(data, bins, \"log$_{10}$ firing rate (Hz)\", \"red\")" ] }, { @@ -562,10 +562,10 @@ } ], "source": [ - "data = np.log10(units[units.nn_hit_rate > 0.9]['firing_rate'])\n", - "bins = np.linspace(-3,2,100)\n", + "data = np.log10(units[units.nn_hit_rate > 0.9][\"firing_rate\"])\n", + "bins = np.linspace(-3, 2, 100)\n", "\n", - "max_value = plot_metric(data, bins, 'log$_{10}$ firing rate (Hz)', 'red')" + "max_value = plot_metric(data, bins, \"log$_{10}$ firing rate (Hz)\", \"red\")" ] }, { @@ -618,26 +618,47 @@ } ], "source": [ - "region_dict = {'cortex' : ['VISp', 'VISl', 'VISrl', 'VISam', 'VISpm', 'VIS', 'VISal','VISmma','VISmmp','VISli'],\n", - " 'thalamus' : ['LGd','LD', 'LP', 'VPM', 'TH', 'MGm','MGv','MGd','PO','LGv','VL',\n", - " 'VPL','POL','Eth','PoT','PP','PIL','IntG','IGL','SGN','VPL','PF','RT'],\n", - " 'hippocampus' : ['CA1', 'CA2','CA3', 'DG', 'SUB', 'POST','PRE','ProS','HPF'],\n", - " 'midbrain': ['MB','SCig','SCiw','SCsg','SCzo','PPT','APN','NOT','MRN','OP','LT','RPF','CP']}\n", - "\n", - "color_dict = {'cortex' : '#08858C',\n", - " 'thalamus' : '#FC6B6F',\n", - " 'hippocampus' : '#7ED04B',\n", - " 'midbrain' : '#FC9DFE'}\n", - "\n", - "bins = np.linspace(-3,2,100)\n", + "region_dict = {\n", + " \"cortex\": [\"VISp\", \"VISl\", \"VISrl\", \"VISam\", \"VISpm\", \"VIS\", \"VISal\", \"VISmma\", \"VISmmp\", \"VISli\"],\n", + " \"thalamus\": [\n", + " \"LGd\",\n", + " \"LD\",\n", + " \"LP\",\n", + " \"VPM\",\n", + " \"TH\",\n", + " \"MGm\",\n", + " \"MGv\",\n", + " \"MGd\",\n", + " \"PO\",\n", + " \"LGv\",\n", + " \"VL\",\n", + " \"VPL\",\n", + " \"POL\",\n", + " \"Eth\",\n", + " \"PoT\",\n", + " \"PP\",\n", + " \"PIL\",\n", + " \"IntG\",\n", + " \"IGL\",\n", + " \"SGN\",\n", + " \"VPL\",\n", + " \"PF\",\n", + " \"RT\",\n", + " ],\n", + " \"hippocampus\": [\"CA1\", \"CA2\", \"CA3\", \"DG\", \"SUB\", \"POST\", \"PRE\", \"ProS\", \"HPF\"],\n", + " \"midbrain\": [\"MB\", \"SCig\", \"SCiw\", \"SCsg\", \"SCzo\", \"PPT\", \"APN\", \"NOT\", \"MRN\", \"OP\", \"LT\", \"RPF\", \"CP\"],\n", + "}\n", + "\n", + "color_dict = {\"cortex\": \"#08858C\", \"thalamus\": \"#FC6B6F\", \"hippocampus\": \"#7ED04B\", \"midbrain\": \"#FC9DFE\"}\n", + "\n", + "bins = np.linspace(-3, 2, 100)\n", "max_value = -np.inf\n", "\n", "for idx, region in enumerate(region_dict.keys()):\n", - " \n", - " data = np.log10(units[units.ecephys_structure_acronym.isin(region_dict[region])]['firing_rate'])\n", - " \n", - " max_value = plot_metric(data, bins, 'log$_{10}$ firing rate (Hz)', color_dict[region], max_value)\n", - " \n", + " data = np.log10(units[units.ecephys_structure_acronym.isin(region_dict[region])][\"firing_rate\"])\n", + "\n", + " max_value = plot_metric(data, bins, \"log$_{10}$ firing rate (Hz)\", color_dict[region], max_value)\n", + "\n", "_ = plt.legend(region_dict.keys())" ] }, @@ -749,18 +770,17 @@ } ], "source": [ - "bins = np.linspace(0,1,100)\n", + "bins = np.linspace(0, 1, 100)\n", "max_value = -np.inf\n", "\n", "for idx, region in enumerate(region_dict.keys()):\n", - " \n", - " data = units[units.ecephys_structure_acronym.isin(region_dict[region])]['presence_ratio']\n", - " \n", - " max_value = plot_metric(data, bins, 'Presence ratio', color_dict[region], max_value)\n", - " \n", + " data = units[units.ecephys_structure_acronym.isin(region_dict[region])][\"presence_ratio\"]\n", + "\n", + " max_value = plot_metric(data, bins, \"Presence ratio\", color_dict[region], max_value)\n", + "\n", "_ = plt.legend(region_dict.keys())\n", "\n", - "plt.plot([0.9, 0.9],[0,max_value], ':')" + "plt.plot([0.9, 0.9], [0, max_value], \":\")" ] }, { @@ -923,18 +943,17 @@ } ], "source": [ - "bins = np.linspace(0,0.5,200)\n", + "bins = np.linspace(0, 0.5, 200)\n", "max_value = -np.inf\n", "\n", "for idx, region in enumerate(region_dict.keys()):\n", - " \n", - " data = units[units.ecephys_structure_acronym.isin(region_dict[region])]['amplitude_cutoff']\n", - " \n", - " max_value = plot_metric(data, bins, 'Amplitude cutoff', color_dict[region], max_value)\n", - " \n", + " data = units[units.ecephys_structure_acronym.isin(region_dict[region])][\"amplitude_cutoff\"]\n", + "\n", + " max_value = plot_metric(data, bins, \"Amplitude cutoff\", color_dict[region], max_value)\n", + "\n", "_ = plt.legend(region_dict.keys())\n", "\n", - "plt.plot([0.1, 0.1],[0,max_value], ':')" + "plt.plot([0.1, 0.1], [0, max_value], \":\")" ] }, { @@ -1099,18 +1118,17 @@ } ], "source": [ - "bins = np.linspace(0,10,200)\n", + "bins = np.linspace(0, 10, 200)\n", "max_value = -np.inf\n", "\n", "for idx, region in enumerate(region_dict.keys()):\n", - " \n", - " data = units[units.ecephys_structure_acronym.isin(region_dict[region])]['isi_violations']\n", - " \n", - " max_value = plot_metric(data, bins, 'ISI violations', color_dict[region], max_value)\n", - " \n", + " data = units[units.ecephys_structure_acronym.isin(region_dict[region])][\"isi_violations\"]\n", + "\n", + " max_value = plot_metric(data, bins, \"ISI violations\", color_dict[region], max_value)\n", + "\n", "_ = plt.legend(region_dict.keys())\n", "\n", - "plt.plot([0.5, 0.5],[0,max_value], ':')" + "plt.plot([0.5, 0.5], [0, max_value], \":\")" ] }, { @@ -1173,18 +1191,17 @@ } ], "source": [ - "bins = np.linspace(-6,2.5,100)\n", + "bins = np.linspace(-6, 2.5, 100)\n", "max_value = -np.inf\n", "\n", "for idx, region in enumerate(region_dict.keys()):\n", - " \n", - " data = np.log10(units[units.ecephys_structure_acronym.isin(region_dict[region])]['isi_violations'] + 1e-5) \n", - " \n", - " max_value = plot_metric(data, bins, '$log_{10}$ ISI violations', color_dict[region], max_value)\n", - " \n", + " data = np.log10(units[units.ecephys_structure_acronym.isin(region_dict[region])][\"isi_violations\"] + 1e-5)\n", + "\n", + " max_value = plot_metric(data, bins, \"$log_{10}$ ISI violations\", color_dict[region], max_value)\n", + "\n", "_ = plt.legend(region_dict.keys())\n", "\n", - "plt.plot([np.log10(0.5), np.log10(0.5)],[0,max_value], ':')" + "plt.plot([np.log10(0.5), np.log10(0.5)], [0, max_value], \":\")" ] }, { @@ -1342,15 +1359,14 @@ } ], "source": [ - "bins = np.linspace(0,10,100)\n", + "bins = np.linspace(0, 10, 100)\n", "max_value = -np.inf\n", "\n", "for idx, region in enumerate(region_dict.keys()):\n", - " \n", - " data = units[units.ecephys_structure_acronym.isin(region_dict[region])]['snr']\n", - " \n", - " max_value = plot_metric(data, bins, 'SNR', color_dict[region], max_value)\n", - " \n", + " data = units[units.ecephys_structure_acronym.isin(region_dict[region])][\"snr\"]\n", + "\n", + " max_value = plot_metric(data, bins, \"SNR\", color_dict[region], max_value)\n", + "\n", "_ = plt.legend(region_dict.keys())" ] }, @@ -1469,15 +1485,14 @@ } ], "source": [ - "bins = np.linspace(0,170,50)\n", + "bins = np.linspace(0, 170, 50)\n", "max_value = -np.inf\n", "\n", "for idx, region in enumerate(region_dict.keys()):\n", - " \n", - " data = units[units.ecephys_structure_acronym.isin(region_dict[region])]['isolation_distance']\n", - " \n", - " max_value = plot_metric(data, bins, 'Isolation distance', color_dict[region], max_value)\n", - " \n", + " data = units[units.ecephys_structure_acronym.isin(region_dict[region])][\"isolation_distance\"]\n", + "\n", + " max_value = plot_metric(data, bins, \"Isolation distance\", color_dict[region], max_value)\n", + "\n", "_ = plt.legend(region_dict.keys())" ] }, @@ -1572,15 +1587,14 @@ } ], "source": [ - "bins = np.linspace(0,15,50)\n", + "bins = np.linspace(0, 15, 50)\n", "max_value = -np.inf\n", "\n", "for idx, region in enumerate(region_dict.keys()):\n", - " \n", - " data = units[units.ecephys_structure_acronym.isin(region_dict[region])]['d_prime']\n", - " \n", - " max_value = plot_metric(data, bins, 'd-prime', color_dict[region], max_value)\n", - " \n", + " data = units[units.ecephys_structure_acronym.isin(region_dict[region])][\"d_prime\"]\n", + "\n", + " max_value = plot_metric(data, bins, \"d-prime\", color_dict[region], max_value)\n", + "\n", "_ = plt.legend(region_dict.keys())" ] }, @@ -1674,15 +1688,14 @@ } ], "source": [ - "bins = np.linspace(0,1,100)\n", + "bins = np.linspace(0, 1, 100)\n", "max_value = -np.inf\n", "\n", "for idx, region in enumerate(region_dict.keys()):\n", - " \n", - " data = units[units.ecephys_structure_acronym.isin(region_dict[region])]['nn_hit_rate']\n", - " \n", - " max_value = plot_metric(data, bins, 'Nearest-neighbors hit rate', color_dict[region], max_value)\n", - " \n", + " data = units[units.ecephys_structure_acronym.isin(region_dict[region])][\"nn_hit_rate\"]\n", + "\n", + " max_value = plot_metric(data, bins, \"Nearest-neighbors hit rate\", color_dict[region], max_value)\n", + "\n", "_ = plt.legend(region_dict.keys())" ] }, @@ -1776,39 +1789,33 @@ } ], "source": [ - "metrics = ['firing_rate', \n", - " 'presence_ratio', \n", - " 'amplitude_cutoff', \n", - " 'isi_violations', \n", - " 'snr', \n", - " 'isolation_distance', \n", - " 'd_prime', \n", - " 'nn_hit_rate']\n", - "\n", - "ranges = [[0,20],\n", - " [0.9,0.995],\n", - " [0,0.5],\n", - " [0,2],\n", - " [0,8],\n", - " [0,125],\n", - " [0,10],\n", - " [0,1]]\n", - "\n", - "_ = plt.figure(figsize=(5,10))\n", + "metrics = [\n", + " \"firing_rate\",\n", + " \"presence_ratio\",\n", + " \"amplitude_cutoff\",\n", + " \"isi_violations\",\n", + " \"snr\",\n", + " \"isolation_distance\",\n", + " \"d_prime\",\n", + " \"nn_hit_rate\",\n", + "]\n", + "\n", + "ranges = [[0, 20], [0.9, 0.995], [0, 0.5], [0, 2], [0, 8], [0, 125], [0, 10], [0, 1]]\n", + "\n", + "_ = plt.figure(figsize=(5, 10))\n", "\n", "for idx, metric in enumerate(metrics):\n", - " \n", " data = units[metric].values\n", " data = data[np.invert(np.isnan(data))]\n", "\n", - " _ = plt.subplot(len(metrics),1,idx+1)\n", + " _ = plt.subplot(len(metrics), 1, idx + 1)\n", " _ = plt.boxplot(data, showfliers=False, showcaps=False, vert=False)\n", - " _ = plt.ylim([0.8,1.2])\n", + " _ = plt.ylim([0.8, 1.2])\n", " _ = plt.xlim(ranges[idx])\n", " _ = plt.yticks([])\n", - " \n", + "\n", " plt.title(metric)\n", - " \n", + "\n", "plt.tight_layout()" ] } diff --git a/doc_template/examples_root/examples/nb/ecephys_quickstart.ipynb b/doc_template/examples_root/examples/nb/ecephys_quickstart.ipynb index 21647e63b5..4d4b363523 100644 --- a/doc_template/examples_root/examples/nb/ecephys_quickstart.ipynb +++ b/doc_template/examples_root/examples/nb/ecephys_quickstart.ipynb @@ -106,7 +106,7 @@ "outputs": [], "source": [ "# Example cache directory path, it determines where downloaded data will be stored\n", - "output_dir = '/local1/ecephys_cache_dir/'" + "output_dir = \"/local1/ecephys_cache_dir/\"" ] }, { @@ -662,15 +662,13 @@ ], "source": [ "presentations = session.get_stimulus_table(\"flashes\")\n", - "units = session.units[session.units[\"ecephys_structure_acronym\"] == 'VISp']\n", + "units = session.units[session.units[\"ecephys_structure_acronym\"] == \"VISp\"]\n", "\n", "time_step = 0.01\n", "time_bins = np.arange(-0.1, 0.5 + time_step, time_step)\n", "\n", "histograms = session.presentationwise_spike_counts(\n", - " stimulus_presentation_ids=presentations.index.values, \n", - " bin_edges=time_bins,\n", - " unit_ids=units.index.values\n", + " stimulus_presentation_ids=presentations.index.values, bin_edges=time_bins, unit_ids=units.index.values\n", ")\n", "\n", "histograms.coords" @@ -713,11 +711,11 @@ "\n", "fig, ax = plt.subplots(figsize=(8, 8))\n", "ax.pcolormesh(\n", - " mean_histograms[\"time_relative_to_stimulus_onset\"], \n", + " mean_histograms[\"time_relative_to_stimulus_onset\"],\n", " np.arange(mean_histograms[\"unit_id\"].size),\n", - " mean_histograms.T, \n", + " mean_histograms.T,\n", " vmin=0,\n", - " vmax=1\n", + " vmax=1,\n", ")\n", "\n", "ax.set_ylabel(\"unit\", fontsize=24)\n", @@ -928,8 +926,7 @@ "visp_units = session.units[session.units[\"ecephys_structure_acronym\"] == \"VISp\"]\n", "\n", "spikes = session.presentationwise_spike_times(\n", - " stimulus_presentation_ids=scene_presentations.index.values,\n", - " unit_ids=visp_units.index.values[:]\n", + " stimulus_presentation_ids=scene_presentations.index.values, unit_ids=visp_units.index.values[:]\n", ")\n", "\n", "spikes" @@ -1395,12 +1392,7 @@ "spikes = spikes.groupby([\"stimulus_presentation_id\", \"unit_id\"]).count()\n", "\n", "design = pd.pivot_table(\n", - " spikes, \n", - " values=\"count\", \n", - " index=\"stimulus_presentation_id\", \n", - " columns=\"unit_id\", \n", - " fill_value=0.0,\n", - " aggfunc=np.sum\n", + " spikes, values=\"count\", index=\"stimulus_presentation_id\", columns=\"unit_id\", fill_value=0.0, aggfunc=np.sum\n", ")\n", "\n", "design" @@ -1579,16 +1571,15 @@ "confusions = []\n", "\n", "for train_indices, test_indices in KFold(n_splits=5).split(design_arr):\n", - " \n", " clf = svm.SVC(gamma=\"scale\", kernel=\"rbf\")\n", " clf.fit(design_arr[train_indices], targets_arr[train_indices])\n", - " \n", + "\n", " test_targets = targets_arr[test_indices]\n", " test_predictions = clf.predict(design_arr[test_indices])\n", - " \n", + "\n", " accuracy = 1 - (np.count_nonzero(test_predictions - test_targets) / test_predictions.size)\n", " print(accuracy)\n", - " \n", + "\n", " accuracies.append(accuracy)\n", " confusions.append(confusion_matrix(y_true=test_targets, y_pred=test_predictions, labels=labels))" ] @@ -1625,7 +1616,7 @@ ], "source": [ "print(f\"mean accuracy: {np.mean(accuracy)}\")\n", - "print(f\"chance: {1/labels.size}\")" + "print(f\"chance: {1 / labels.size}\")" ] }, { diff --git a/doc_template/examples_root/examples/nb/ecephys_receptive_fields.ipynb b/doc_template/examples_root/examples/nb/ecephys_receptive_fields.ipynb index 59fa7383ea..ffb491d3d6 100644 --- a/doc_template/examples_root/examples/nb/ecephys_receptive_fields.ipynb +++ b/doc_template/examples_root/examples/nb/ecephys_receptive_fields.ipynb @@ -69,6 +69,7 @@ "import numpy as np\n", "\n", "import matplotlib.pyplot as plt\n", + "\n", "%matplotlib inline\n", "\n", "from allensdk.brain_observatory.ecephys.ecephys_project_cache import EcephysProjectCache" @@ -116,7 +117,7 @@ "outputs": [], "source": [ "# Example cache directory path, it determines where downloaded data will be stored\n", - "output_dir = '/local1/ecephys_cache_dir/'" + "output_dir = \"/local1/ecephys_cache_dir/\"" ] }, { @@ -274,7 +275,7 @@ } ], "source": [ - "rf_stim_table = session.stimulus_presentations[session.stimulus_presentations.stimulus_name == 'gabors']\n", + "rf_stim_table = session.stimulus_presentations[session.stimulus_presentations.stimulus_name == \"gabors\"]\n", "\n", "len(rf_stim_table)" ] @@ -388,9 +389,9 @@ } ], "source": [ - "print('Unique orientations : ' + str(list(np.sort(rf_stim_table.orientation.unique()))))\n", - "print('Unique x positions : ' + str(list(np.sort(rf_stim_table.x_position.unique()))))\n", - "print('Unique y positions : ' + str(list(np.sort(rf_stim_table.y_position.unique()))))" + "print(\"Unique orientations : \" + str(list(np.sort(rf_stim_table.orientation.unique()))))\n", + "print(\"Unique x positions : \" + str(list(np.sort(rf_stim_table.x_position.unique()))))\n", + "print(\"Unique y positions : \" + str(list(np.sort(rf_stim_table.y_position.unique()))))" ] }, { @@ -549,10 +550,10 @@ } ], "source": [ - "print('Spatial frequency: ' + str(rf_stim_table.spatial_frequency.unique()[0]))\n", - "print('Temporal frequency: ' + str(rf_stim_table.temporal_frequency.unique()[0]))\n", - "print('Size: ' + str(rf_stim_table['size'].unique()[0]))\n", - "print('Contrast: ' + str(rf_stim_table['contrast'].unique()[0]))" + "print(\"Spatial frequency: \" + str(rf_stim_table.spatial_frequency.unique()[0]))\n", + "print(\"Temporal frequency: \" + str(rf_stim_table.temporal_frequency.unique()[0]))\n", + "print(\"Size: \" + str(rf_stim_table[\"size\"].unique()[0]))\n", + "print(\"Contrast: \" + str(rf_stim_table[\"contrast\"].unique()[0]))" ] }, { @@ -1082,7 +1083,7 @@ }, "outputs": [], "source": [ - "v1_units = session.units[session.units.ecephys_structure_acronym == 'VISp']" + "v1_units = session.units[session.units.ecephys_structure_acronym == \"VISp\"]" ] }, { @@ -1179,9 +1180,9 @@ } ], "source": [ - "plt.figure(figsize=(5,5))\n", + "plt.figure(figsize=(5, 5))\n", "_ = plt.imshow(RF)\n", - "_ = plt.axis('off')" + "_ = plt.axis(\"off\")" ] }, { @@ -1238,11 +1239,12 @@ "source": [ "def plot_rf(unit_id, index):\n", " RF = rf_mapping.get_receptive_field(unit_id)\n", - " _ = plt.subplot(6,10,index+1)\n", + " _ = plt.subplot(6, 10, index + 1)\n", " _ = plt.imshow(RF)\n", - " _ = plt.axis('off')\n", - " \n", - "_ = plt.figure(figsize=(10,6))\n", + " _ = plt.axis(\"off\")\n", + "\n", + "\n", + "_ = plt.figure(figsize=(10, 6))\n", "_ = [plot_rf(RF, index) for index, RF in enumerate(v1_units.index.values)]" ] } diff --git a/doc_template/examples_root/examples/nb/ecephys_session.ipynb b/doc_template/examples_root/examples/nb/ecephys_session.ipynb index ce53831c0f..481d218dd1 100644 --- a/doc_template/examples_root/examples/nb/ecephys_session.ipynb +++ b/doc_template/examples_root/examples/nb/ecephys_session.ipynb @@ -152,7 +152,7 @@ "outputs": [], "source": [ "output_dir = \"/local1/ecephys_cache_dir/\"\n", - "resources_dir = Path.cwd().parent / 'resources'\n", + "resources_dir = Path.cwd().parent / \"resources\"\n", "DOWNLOAD_LFP = False" ] }, @@ -1499,7 +1499,7 @@ } ], "source": [ - "session_id = 756029989 # for example\n", + "session_id = 756029989 # for example\n", "session = cache.get_session_data(session_id)" ] }, @@ -2133,9 +2133,9 @@ ], "source": [ "# how many units have signal to noise ratios that are greater than 4?\n", - "print(f'{session.units.shape[0]} units total')\n", - "units_with_very_high_snr = session.units[session.units['snr'] > 4]\n", - "print(f'{units_with_very_high_snr.shape[0]} units have snr > 4')" + "print(f\"{session.units.shape[0]} units total\")\n", + "units_with_very_high_snr = session.units[session.units[\"snr\"] > 4]\n", + "print(f\"{units_with_very_high_snr.shape[0]} units have snr > 4\")" ] }, { @@ -2488,7 +2488,7 @@ } ], "source": [ - "session.stimulus_names # just the unique values from the 'stimulus_name' column" + "session.stimulus_names # just the unique values from the 'stimulus_name' column" ] }, { @@ -3061,7 +3061,7 @@ } ], "source": [ - "session.get_stimulus_table(['drifting_gratings']).head()" + "session.get_stimulus_table([\"drifting_gratings\"]).head()" ] }, { @@ -3123,7 +3123,7 @@ ], "source": [ "for key, values in session.get_stimulus_parameter_values().items():\n", - " print(f'{key}: {values}')" + " print(f\"{key}: {values}\")" ] }, { @@ -3421,7 +3421,7 @@ } ], "source": [ - " # grab an arbitrary (though high-snr!) unit (we made units_with_high_snr above)\n", + "# grab an arbitrary (though high-snr!) unit (we made units_with_high_snr above)\n", "high_snr_unit_ids = units_with_very_high_snr.index.values\n", "unit_id = high_snr_unit_ids[0]\n", "\n", @@ -3558,14 +3558,13 @@ } ], "source": [ - "# get spike times from the first block of drifting gratings presentations \n", + "# get spike times from the first block of drifting gratings presentations\n", "drifting_gratings_presentation_ids = session.stimulus_presentations.loc[\n", - " (session.stimulus_presentations['stimulus_name'] == 'drifting_gratings')\n", + " (session.stimulus_presentations[\"stimulus_name\"] == \"drifting_gratings\")\n", "].index.values\n", "\n", "times = session.presentationwise_spike_times(\n", - " stimulus_presentation_ids=drifting_gratings_presentation_ids,\n", - " unit_ids=high_snr_unit_ids\n", + " stimulus_presentation_ids=drifting_gratings_presentation_ids, unit_ids=high_snr_unit_ids\n", ")\n", "\n", "times.head()" @@ -3648,10 +3647,10 @@ } ], "source": [ - "first_drifting_grating_presentation_id = times['stimulus_presentation_id'].values[0]\n", - "plot_times = times[times['stimulus_presentation_id'] == first_drifting_grating_presentation_id]\n", + "first_drifting_grating_presentation_id = times[\"stimulus_presentation_id\"].values[0]\n", + "plot_times = times[times[\"stimulus_presentation_id\"] == first_drifting_grating_presentation_id]\n", "\n", - "fig = raster_plot(plot_times, title=f'spike raster for stimulus presentation {first_drifting_grating_presentation_id}')\n", + "fig = raster_plot(plot_times, title=f\"spike raster for stimulus presentation {first_drifting_grating_presentation_id}\")\n", "plt.show()\n", "\n", "# also print out this presentation\n", @@ -3962,8 +3961,7 @@ ], "source": [ "stats = session.conditionwise_spike_statistics(\n", - " stimulus_presentation_ids=drifting_gratings_presentation_ids,\n", - " unit_ids=high_snr_unit_ids\n", + " stimulus_presentation_ids=drifting_gratings_presentation_ids, unit_ids=high_snr_unit_ids\n", ")\n", "\n", "# display the parameters associated with each condition\n", @@ -4246,9 +4244,12 @@ "source": [ "with_repeats = stats[stats[\"stimulus_presentation_count\"] >= 5]\n", "\n", + "\n", "def highest_mean_rate(df):\n", - " return df.loc[df['spike_mean'].idxmax()]\n", - "max_rate_conditions = with_repeats.groupby('unit_id').apply(highest_mean_rate)\n", + " return df.loc[df[\"spike_mean\"].idxmax()]\n", + "\n", + "\n", + "max_rate_conditions = with_repeats.groupby(\"unit_id\").apply(highest_mean_rate)\n", "max_rate_conditions.head()" ] }, @@ -4874,18 +4875,16 @@ "\n", "# look at responses to the flash stimulus\n", "flash_250_ms_stimulus_presentation_ids = session.stimulus_presentations[\n", - " session.stimulus_presentations['stimulus_name'] == 'flashes'\n", + " session.stimulus_presentations[\"stimulus_name\"] == \"flashes\"\n", "].index.values\n", "\n", "# and get a set of units with only decent snr\n", - "decent_snr_unit_ids = session.units[\n", - " session.units['snr'] >= 1.5\n", - "].index.values\n", + "decent_snr_unit_ids = session.units[session.units[\"snr\"] >= 1.5].index.values\n", "\n", "spike_counts_da = session.presentationwise_spike_counts(\n", " bin_edges=time_bin_edges,\n", " stimulus_presentation_ids=flash_250_ms_stimulus_presentation_ids,\n", - " unit_ids=decent_snr_unit_ids\n", + " unit_ids=decent_snr_unit_ids,\n", ")\n", "spike_counts_da" ] @@ -4949,12 +4948,12 @@ } ], "source": [ - "presentation_id = 3796 # chosen arbitrarily\n", + "presentation_id = 3796 # chosen arbitrarily\n", "plot_spike_counts(\n", - " spike_counts_da.loc[{'stimulus_presentation_id': presentation_id}], \n", - " spike_counts_da['time_relative_to_stimulus_onset'],\n", - " 'spike count', \n", - " f'unitwise spike counts on presentation {presentation_id}'\n", + " spike_counts_da.loc[{\"stimulus_presentation_id\": presentation_id}],\n", + " spike_counts_da[\"time_relative_to_stimulus_onset\"],\n", + " \"spike count\",\n", + " f\"unitwise spike counts on presentation {presentation_id}\",\n", ")\n", "plt.show()" ] @@ -5470,7 +5469,7 @@ } ], "source": [ - "mean_spike_counts = spike_counts_da.mean(dim='stimulus_presentation_id')\n", + "mean_spike_counts = spike_counts_da.mean(dim=\"stimulus_presentation_id\")\n", "mean_spike_counts" ] }, @@ -5526,10 +5525,10 @@ ], "source": [ "plot_spike_counts(\n", - " mean_spike_counts, \n", - " mean_spike_counts['time_relative_to_stimulus_onset'],\n", - " 'mean spike count', \n", - " 'mean spike counts on flash_250_ms presentations'\n", + " mean_spike_counts,\n", + " mean_spike_counts[\"time_relative_to_stimulus_onset\"],\n", + " \"mean spike count\",\n", + " \"mean spike counts on flash_250_ms presentations\",\n", ")\n", "plt.show()" ] @@ -5589,7 +5588,7 @@ "units_of_interest = high_snr_unit_ids[:35]\n", "\n", "waveforms = {uid: session.mean_waveforms[uid] for uid in units_of_interest}\n", - "peak_channels = {uid: session.units.loc[uid, 'peak_channel_id'] for uid in units_of_interest}\n", + "peak_channels = {uid: session.units.loc[uid, \"peak_channel_id\"] for uid in units_of_interest}\n", "\n", "# plot the mean waveform on each unit's peak channel/\n", "plot_mean_waveforms(waveforms, units_of_interest, peak_channels)\n", @@ -5661,8 +5660,8 @@ "ax.yaxis.set_major_locator(plt.NullLocator())\n", "ax.set_ylabel(\"channel\", fontsize=16)\n", "\n", - "ax.set_xticks(np.arange(0, len(unit_waveforms['time']), 20))\n", - "ax.set_xticklabels([f'{float(ii):1.4f}' for ii in unit_waveforms['time'][::20]], rotation=45)\n", + "ax.set_xticks(np.arange(0, len(unit_waveforms[\"time\"]), 20))\n", + "ax.set_xticklabels([f\"{float(ii):1.4f}\" for ii in unit_waveforms[\"time\"][::20]], rotation=45)\n", "ax.set_xlabel(\"time (s)\", fontsize=16)\n", "\n", "plt.show()" @@ -5722,14 +5721,10 @@ } ], "source": [ - "running_speed_midpoints = session.running_speed[\"start_time\"] + \\\n", - " (session.running_speed[\"end_time\"] - session.running_speed[\"start_time\"]) / 2\n", - "plot_running_speed(\n", - " running_speed_midpoints, \n", - " session.running_speed[\"velocity\"], \n", - " start_index=5000,\n", - " stop_index=5100\n", + "running_speed_midpoints = (\n", + " session.running_speed[\"start_time\"] + (session.running_speed[\"end_time\"] - session.running_speed[\"start_time\"]) / 2\n", ")\n", + "plot_running_speed(running_speed_midpoints, session.running_speed[\"velocity\"], start_index=5000, stop_index=5100)\n", "plt.show()" ] }, @@ -6397,20 +6392,20 @@ "from matplotlib import animation\n", "from matplotlib.patches import Ellipse\n", "\n", - "def plot_animated_ellipse_fits(pupil_data: pd.DataFrame, start_frame: int, end_frame: int):\n", "\n", + "def plot_animated_ellipse_fits(pupil_data: pd.DataFrame, start_frame: int, end_frame: int):\n", " start_frame = 0 if (start_frame < 0) else start_frame\n", " end_frame = len(pupil_data) if (end_frame > len(pupil_data)) else end_frame\n", - " \n", + "\n", " frame_times = pupil_data.index.values[start_frame:end_frame]\n", " interval = np.average(np.diff(frame_times)) * 1000\n", "\n", " fig = plt.figure()\n", " ax = plt.axes(xlim=(0, 480), ylim=(0, 480))\n", "\n", - " cr_ellipse = Ellipse((0, 0), width=0.0, height=0.0, angle=0, color='white')\n", - " pupil_ellipse = Ellipse((0, 0), width=0.0, height=0.0, angle=0, color='black')\n", - " eye_ellipse = Ellipse((0, 0), width=0.0, height=0.0, angle=0, color='grey')\n", + " cr_ellipse = Ellipse((0, 0), width=0.0, height=0.0, angle=0, color=\"white\")\n", + " pupil_ellipse = Ellipse((0, 0), width=0.0, height=0.0, angle=0, color=\"black\")\n", + " eye_ellipse = Ellipse((0, 0), width=0.0, height=0.0, angle=0, color=\"grey\")\n", "\n", " ax.add_patch(eye_ellipse)\n", " ax.add_patch(pupil_ellipse)\n", @@ -6421,20 +6416,23 @@ " ellipse_patch.width = ellipse_frame_vals[f\"{prefix}_width\"]\n", " ellipse_patch.height = ellipse_frame_vals[f\"{prefix}_height\"]\n", " ellipse_patch.angle = np.degrees(ellipse_frame_vals[f\"{prefix}_phi\"])\n", - " \n", + "\n", " def init():\n", " return [cr_ellipse, pupil_ellipse, eye_ellipse]\n", "\n", " def animate(i):\n", " ellipse_frame_vals = pupil_data.iloc[i]\n", - " \n", + "\n", " update_ellipse(cr_ellipse, ellipse_frame_vals, prefix=\"corneal_reflection\")\n", " update_ellipse(pupil_ellipse, ellipse_frame_vals, prefix=\"pupil\")\n", " update_ellipse(eye_ellipse, ellipse_frame_vals, prefix=\"eye\")\n", - " \n", + "\n", " return [cr_ellipse, pupil_ellipse, eye_ellipse]\n", - " \n", - " return animation.FuncAnimation(fig, animate, init_func=init, interval=interval, frames=range(start_frame, end_frame), blit=True)\n", + "\n", + " return animation.FuncAnimation(\n", + " fig, animate, init_func=init, interval=interval, frames=range(start_frame, end_frame), blit=True\n", + " )\n", + "\n", "\n", "anim = plot_animated_ellipse_fits(pupil_data, 100, 600)" ] @@ -150635,11 +150633,11 @@ " structure_acronyms, intervals = session.channel_structure_intervals(lfp[\"channel\"])\n", " interval_midpoints = [aa + (bb - aa) / 2 for aa, bb in zip(intervals[:-1], intervals[1:])]\n", "else:\n", - " with open(Path(resources_dir) / 'ecephys_session' / 'structure_acronyms.json') as f:\n", + " with open(Path(resources_dir) / \"ecephys_session\" / \"structure_acronyms.json\") as f:\n", " structure_acronyms = json.load(f)\n", " structure_acronyms = np.array(structure_acronyms)\n", - " \n", - " with open(Path(resources_dir) / 'ecephys_session' / 'intervals.json') as f:\n", + "\n", + " with open(Path(resources_dir) / \"ecephys_session\" / \"intervals.json\") as f:\n", " intervals = json.load(f)\n", " intervals = np.array(structure_acronyms)\n", "print(structure_acronyms)\n", @@ -150693,16 +150691,14 @@ "\n", " num_time_labels = 8\n", " time_label_indices = np.around(np.linspace(1, len(window), num_time_labels)).astype(int) - 1\n", - " time_labels = [ f\"{val:1.3}\" for val in lfp[\"time\"].values[window][time_label_indices]]\n", + " time_labels = [f\"{val:1.3}\" for val in lfp[\"time\"].values[window][time_label_indices]]\n", " ax.set_xticks(time_label_indices + 0.5)\n", " ax.set_xticklabels(time_labels)\n", " ax.set_xlabel(\"time (s)\", fontsize=20)\n", "\n", " plt.show()\n", "else:\n", - " lfp_plot = Image.open(Path(resources_dir) /\n", - " 'ecephys_session' /\n", - " 'lfp_plot.png').convert('RGB')\n", + " lfp_plot = Image.open(Path(resources_dir) / \"ecephys_session\" / \"lfp_plot.png\").convert(\"RGB\")\n", " display(lfp_plot)" ] }, @@ -150802,9 +150798,7 @@ "\n", " plt.show()\n", "else:\n", - " filtered_csd_plot = Image.open(Path(resources_dir) /\n", - " 'ecephys_session' /\n", - " 'filtered_csd_plot.png').convert('RGB')\n", + " filtered_csd_plot = Image.open(Path(resources_dir) / \"ecephys_session\" / \"filtered_csd_plot.png\").convert(\"RGB\")\n", " display(filtered_csd_plot)" ] }, diff --git a/doc_template/examples_root/examples/nb/friday_harbor/experiment_detail_example.ipynb b/doc_template/examples_root/examples/nb/friday_harbor/experiment_detail_example.ipynb index e81f4ec8a6..b6e0db5d3f 100644 --- a/doc_template/examples_root/examples/nb/friday_harbor/experiment_detail_example.ipynb +++ b/doc_template/examples_root/examples/nb/friday_harbor/experiment_detail_example.ipynb @@ -23,15 +23,17 @@ "rma = RmaApi()\n", "mca = MouseConnectivityApi()\n", "\n", + "\n", "def read_data(parsed_json):\n", - " return parsed_json['msg']\n", + " return parsed_json[\"msg\"]\n", + "\n", "\n", "def pretty(result):\n", " print(json.dumps(result, indent=2))\n", - " \n", + "\n", + "\n", "def tabular_dataframe(parsed_tabular_json):\n", - " df = pd.DataFrame.from_records(parsed_tabular_json,\n", - " columns=parsed_tabular_json[0].keys())\n", + " df = pd.DataFrame.from_records(parsed_tabular_json, columns=parsed_tabular_json[0].keys())\n", " return df" ] }, @@ -72,21 +74,25 @@ "# rma version from detail page\n", "results = mca.get_manual_injection_summary(section_data_set_id)\n", "\n", - "print('Manual Injection Summary')\n", - "print('Experiment: %d' % (results[0]['id']))\n", - "injections = results[0]['specimen']['injections']\n", - "print('Primary Structure: %s' % (injections[0]['structure']['name']))\n", + "print(\"Manual Injection Summary\")\n", + "print(\"Experiment: %d\" % (results[0][\"id\"]))\n", + "injections = results[0][\"specimen\"][\"injections\"]\n", + "print(\"Primary Structure: %s\" % (injections[0][\"structure\"][\"name\"]))\n", "\n", - "print('Coords. (AP, ML, DV, L):')\n", + "print(\"Coords. (AP, ML, DV, L):\")\n", "for injection in injections:\n", - " print ('\\t%s(%f, %f, %f, %f)' % (injection['registration_point'],\n", - " injection['coordinates_ap'],\n", - " injection['coordinates_ml'],\n", - " injection['coordinates_dv'],\n", - " injection['angle']))\n", - "print('Transgenic Line: %s' % (results[0]['specimen']['donor']['transgenic_mouse']['transgenic_lines'][0]['name']))\n", - "print(results[0]['specimen']['donor']['transgenic_mouse']['transgenic_lines'][0]['description'])\n", - "\n" + " print(\n", + " \"\\t%s(%f, %f, %f, %f)\"\n", + " % (\n", + " injection[\"registration_point\"],\n", + " injection[\"coordinates_ap\"],\n", + " injection[\"coordinates_ml\"],\n", + " injection[\"coordinates_dv\"],\n", + " injection[\"angle\"],\n", + " )\n", + " )\n", + "print(\"Transgenic Line: %s\" % (results[0][\"specimen\"][\"donor\"][\"transgenic_mouse\"][\"transgenic_lines\"][0][\"name\"]))\n", + "print(results[0][\"specimen\"][\"donor\"][\"transgenic_mouse\"][\"transgenic_lines\"][0][\"description\"])" ] }, { @@ -111,7 +117,7 @@ "source": [ "projection_results = mca.get_structure_projection_signal_statistics(section_data_set_id)\n", "projection_df = tabular_dataframe(projection_results)\n", - "#pretty(projection_results)" + "# pretty(projection_results)" ] }, { @@ -211,17 +217,18 @@ ], "source": [ "from allensdk.api.queries.ontologies_api import OntologiesApi\n", + "\n", "oa = OntologiesApi()\n", "\n", "atlases = tabular_dataframe(oa.get_atlases_table(brief=True))\n", - "mouse_p56_coronal_id = atlases[atlases.name == 'Mouse, P56, Coronal']['structure_graph'].iloc[0]['id'] # 1\n", + "mouse_p56_coronal_id = atlases[atlases.name == \"Mouse, P56, Coronal\"][\"structure_graph\"].iloc[0][\"id\"] # 1\n", "\n", "# calculate a list of only the ancestors that are in the limited structure set\n", "structure_set_df = tabular_dataframe(oa.get_structures(mouse_p56_coronal_id))\n", "\n", "oa.unpack_structure_set_ancestors(structure_set_df)\n", "\n", - "structure_set_df[structure_set_df.acronym == 'PIR']" + "structure_set_df[structure_set_df.acronym == \"PIR\"]" ] }, { @@ -249,21 +256,29 @@ ], "source": [ "cutoff = 0.0050\n", - "data_field = 'projection_volume'\n", + "data_field = \"projection_volume\"\n", "\n", "# projection joined to structures\n", - "table = pd.concat((\n", - " pd.merge(projection_df[projection_df.hemisphere_id == h]\n", - " [projection_df.projection_volume > cutoff].loc[:,['hemisphere_id',\n", - " 'structure_id',\n", - " data_field]],\n", - " structure_set_df[['id','acronym','graph_order']],\n", - " how='inner',\n", - " left_on='structure_id', right_on='id')\n", - " for h in [1, 2])).sort('graph_order').\\\n", - " reset_index('structure_id', 'hemisphere_id')\n", - " \n", - "table[['acronym', 'hemisphere_id', 'projection_volume']]\n", + "table = (\n", + " pd.concat(\n", + " (\n", + " pd.merge(\n", + " projection_df[projection_df.hemisphere_id == h][projection_df.projection_volume > cutoff].loc[\n", + " :, [\"hemisphere_id\", \"structure_id\", data_field]\n", + " ],\n", + " structure_set_df[[\"id\", \"acronym\", \"graph_order\"]],\n", + " how=\"inner\",\n", + " left_on=\"structure_id\",\n", + " right_on=\"id\",\n", + " )\n", + " for h in [1, 2]\n", + " )\n", + " )\n", + " .sort(\"graph_order\")\n", + " .reset_index(\"structure_id\", \"hemisphere_id\")\n", + ")\n", + "\n", + "table[[\"acronym\", \"hemisphere_id\", \"projection_volume\"]]\n", "\n", "table" ] @@ -394,10 +409,13 @@ } ], "source": [ - "display_table = table.loc[:,['acronym', 'hemisphere_id', 'projection_volume', 'graph_order']].pivot(index='graph_order',\n", - " columns='hemisphere_id').iloc[:,1:]\n", + "display_table = (\n", + " table.loc[:, [\"acronym\", \"hemisphere_id\", \"projection_volume\", \"graph_order\"]]\n", + " .pivot(index=\"graph_order\", columns=\"hemisphere_id\")\n", + " .iloc[:, 1:]\n", + ")\n", "\n", - "display_table.columns = ['structure','L', 'R']\n", + "display_table.columns = [\"structure\", \"L\", \"R\"]\n", "display_table" ] }, diff --git a/doc_template/examples_root/examples/nb/image_download.ipynb b/doc_template/examples_root/examples/nb/image_download.ipynb index cdae26665a..c64a028fc9 100644 --- a/doc_template/examples_root/examples/nb/image_download.ipynb +++ b/doc_template/examples_root/examples/nb/image_download.ipynb @@ -85,6 +85,7 @@ "from base64 import b64encode\n", "\n", "from IPython.display import HTML, display\n", + "\n", "%matplotlib inline" ] }, @@ -132,18 +133,20 @@ "\n", " fig, ax = plt.subplots(figsize=figsize)\n", " ax.imshow(image)\n", - " \n", - " \n", + "\n", + "\n", "def verify_svg(file_path, width_scale, height_scale):\n", " # we're using this function to display scaled svg in the rendered notebook.\n", " # we suggest that in your own work you use a tool such as inkscape or illustrator to view svg\n", - " \n", - " with open(file_path, 'rb') as svg_file:\n", + "\n", + " with open(file_path, \"rb\") as svg_file:\n", " svg = svg_file.read()\n", " encoded_svg = b64encode(svg)\n", - " decoded_svg = encoded_svg.decode('ascii')\n", - " \n", - " st = r''.format(decoded_svg, width_scale, height_scale)\n", + " decoded_svg = encoded_svg.decode(\"ascii\")\n", + "\n", + " st = r''.format(\n", + " decoded_svg, width_scale, height_scale\n", + " )\n", " display(HTML(st))" ] }, @@ -235,7 +238,7 @@ }, "outputs": [], "source": [ - "output_dir = '.'" + "output_dir = \".\"" ] }, { @@ -261,7 +264,7 @@ "outputs": [], "source": [ "section_image_id = 70945123\n", - "file_path = Path(output_dir) / '70945123.jpg'" + "file_path = Path(output_dir) / \"70945123.jpg\"" ] }, { @@ -410,7 +413,7 @@ "outputs": [], "source": [ "section_image_id = 297225716\n", - "file_path = Path(output_dir) / '297225716_connectivity.jpg'\n", + "file_path = Path(output_dir) / \"297225716_connectivity.jpg\"\n", "downsample = 3" ] }, @@ -576,12 +579,11 @@ ], "source": [ "section_image_id = 297225716\n", - "file_path = Path(output_dir) / '297225716_projection.jpg'\n", + "file_path = Path(output_dir) / \"297225716_projection.jpg\"\n", "downsample = 3\n", - "projection=True\n", + "projection = True\n", "\n", - "image_api.download_projection_image(section_image_id, file_path, downsample=downsample, \n", - " projection=projection)\n", + "image_api.download_projection_image(section_image_id, file_path, downsample=downsample, projection=projection)\n", "\n", "verify_image(file_path)" ] @@ -629,7 +631,7 @@ "source": [ "atlas_image_id = 112282603\n", "downsample = 6\n", - "file_path = Path(output_dir) / '112282603_nissl.jpg'" + "file_path = Path(output_dir) / \"112282603_nissl.jpg\"" ] }, { @@ -781,7 +783,7 @@ "atlas_image_id = 112282603\n", "annotation = True\n", "downsample = 6\n", - "file_path = Path(output_dir) / '112282603_annotation.jpg'\n", + "file_path = Path(output_dir) / \"112282603_annotation.jpg\"\n", "\n", "image_api.download_atlas_image(atlas_image_id, file_path, annotation=annotation, downsample=downsample)\n", "verify_image(file_path)" @@ -839,7 +841,7 @@ "from allensdk.api.queries.svg_api import SvgApi\n", "\n", "svg_api = SvgApi()\n", - "svg_api.download_svg(atlas_image_id, file_path=Path(output_dir) / '112282603.svg')" + "svg_api.download_svg(atlas_image_id, file_path=Path(output_dir) / \"112282603.svg\")" ] }, { @@ -878,7 +880,7 @@ } ], "source": [ - "verify_svg(Path(output_dir) / '112282603.svg', 35, 35)" + "verify_svg(Path(output_dir) / \"112282603.svg\", 35, 35)" ] }, { @@ -1203,7 +1205,7 @@ } ], "source": [ - "atlas_image_dataframe['id'].head()" + "atlas_image_dataframe[\"id\"].head()" ] }, { @@ -1250,10 +1252,10 @@ "outputs": [], "source": [ "section_data_set_id = 100147602\n", - "downsample=6\n", + "downsample = 6\n", "\n", - "section_image_directory = Path(output_dir) / '75214738_section_images'\n", - "format_str = '.jpg'" + "section_image_directory = Path(output_dir) / \"75214738_section_images\"\n", + "format_str = \".jpg\"" ] }, { @@ -1304,7 +1306,7 @@ ], "source": [ "section_images = image_api.section_image_query(section_data_set_id)\n", - "section_image_ids = [si['id'] for si in section_images]\n", + "section_image_ids = [si[\"id\"] for si in section_images]\n", "\n", "print(len(section_image_ids))" ] @@ -1349,21 +1351,20 @@ }, "outputs": [], "source": [ - "# You have probably noticed that the AllenSDK has a logger which notifies you of file downloads. \n", + "# You have probably noticed that the AllenSDK has a logger which notifies you of file downloads.\n", "# Since we are downloading ~300 images, we don't want to see messages for each one.\n", "# The following line will temporarily disable the download logger.\n", - "logging.getLogger('allensdk.api.api.retrieve_file_over_http').disabled = True\n", + "logging.getLogger(\"allensdk.api.api.retrieve_file_over_http\").disabled = True\n", "\n", "for section_image_id in section_image_ids:\n", - " \n", " file_name = str(section_image_id) + format_str\n", " file_path = os.path.join(section_image_directory, file_name)\n", - " \n", + "\n", " Manifest.safe_make_parent_dirs(file_path)\n", " image_api.download_section_image(section_image_id, file_path=file_path, downsample=downsample)\n", - " \n", + "\n", "# re-enable the logger\n", - "logging.getLogger('allensdk.api.api.retrieve_file_over_http').disabled = False" + "logging.getLogger(\"allensdk.api.api.retrieve_file_over_http\").disabled = False" ] }, { @@ -1536,7 +1537,7 @@ } ], "source": [ - "print('there are {} section data sets in the Human Brain ISH SubCortex Study'.format(len(human_subcortex_datasets)))" + "print(\"there are {} section data sets in the Human Brain ISH SubCortex Study\".format(len(human_subcortex_datasets)))" ] }, { diff --git a/doc_template/examples_root/examples/nb/mouse_connectivity.ipynb b/doc_template/examples_root/examples/nb/mouse_connectivity.ipynb index ec1ab735fe..3af575cf81 100644 --- a/doc_template/examples_root/examples/nb/mouse_connectivity.ipynb +++ b/doc_template/examples_root/examples/nb/mouse_connectivity.ipynb @@ -80,7 +80,7 @@ }, "outputs": [], "source": [ - "output_dir = '.'" + "output_dir = \".\"" ] }, { @@ -149,7 +149,7 @@ "# the data that has already been downloaded onto the hard drives.\n", "# If you supply a relative path, it is assumed to be relative to your\n", "# current working directory.\n", - "mcc = MouseConnectivityCache(manifest_file=Path(output_dir) / 'manifest.json')\n", + "mcc = MouseConnectivityCache(manifest_file=Path(output_dir) / \"manifest.json\")\n", "\n", "# open up a list of all of the experiments\n", "all_experiments = mcc.get_experiments(dataframe=True)\n", @@ -225,7 +225,7 @@ "# the data that has already been downloaded onto the hard drives.\n", "# If you supply a relative path, it is assumed to be relative to your\n", "# current working directory.\n", - "mcc = MouseConnectivityCache(manifest_file=Path(output_dir) / 'manifest.json')\n", + "mcc = MouseConnectivityCache(manifest_file=Path(output_dir) / \"manifest.json\")\n", "\n", "# open up a list of all of the experiments\n", "all_experiments = mcc.get_experiments(dataframe=True)\n", @@ -300,7 +300,7 @@ "# the data that has already been downloaded onto the hard drives.\n", "# If you supply a relative path, it is assumed to be relative to your\n", "# current working directory.\n", - "mcc = MouseConnectivityCache(manifest_file=Path(output_dir) / 'manifest.json')\n", + "mcc = MouseConnectivityCache(manifest_file=Path(output_dir) / \"manifest.json\")\n", "\n", "# open up a list of all of the experiments\n", "all_experiments = mcc.get_experiments(dataframe=True)\n", @@ -439,7 +439,7 @@ "structure_tree = mcc.get_structure_tree()\n", "\n", "# get info on some structures\n", - "structures = structure_tree.get_structures_by_name(['Primary visual area', 'Hypothalamus'])\n", + "structures = structure_tree.get_structures_by_name([\"Primary visual area\", \"Hypothalamus\"])\n", "pd.DataFrame(structures)" ] }, @@ -1135,15 +1135,13 @@ ], "source": [ "# fetch the experiments that have injections in the isocortex of cre-positive mice\n", - "isocortex = structure_tree.get_structures_by_name(['Isocortex'])[0]\n", - "cre_cortical_experiments = mcc.get_experiments(cre=True, \n", - " injection_structure_ids=[isocortex['id']])\n", + "isocortex = structure_tree.get_structures_by_name([\"Isocortex\"])[0]\n", + "cre_cortical_experiments = mcc.get_experiments(cre=True, injection_structure_ids=[isocortex[\"id\"]])\n", "\n", "print(\"%d cre cortical experiments\" % len(cre_cortical_experiments))\n", "\n", "# same as before, but restrict the cre line\n", - "rbp4_cortical_experiments = mcc.get_experiments(cre=[ 'Rbp4-Cre_KL100' ], \n", - " injection_structure_ids=[isocortex['id']])\n", + "rbp4_cortical_experiments = mcc.get_experiments(cre=[\"Rbp4-Cre_KL100\"], injection_structure_ids=[isocortex[\"id\"]])\n", "\n", "\n", "print(\"%d Rbp4 cortical experiments\" % len(rbp4_cortical_experiments))" @@ -1206,16 +1204,14 @@ ], "source": [ "# find wild-type injections into primary visual area\n", - "visp = structure_tree.get_structures_by_acronym(['VISp'])[0]\n", - "visp_experiments = mcc.get_experiments(cre=False, \n", - " injection_structure_ids=[visp['id']])\n", + "visp = structure_tree.get_structures_by_acronym([\"VISp\"])[0]\n", + "visp_experiments = mcc.get_experiments(cre=False, injection_structure_ids=[visp[\"id\"]])\n", "\n", "print(\"%d VISp experiments\" % len(visp_experiments))\n", "\n", - "structure_unionizes = mcc.get_structure_unionizes([ e['id'] for e in visp_experiments ], \n", - " is_injection=False,\n", - " structure_ids=[isocortex['id']],\n", - " include_descendants=True)\n", + "structure_unionizes = mcc.get_structure_unionizes(\n", + " [e[\"id\"] for e in visp_experiments], is_injection=False, structure_ids=[isocortex[\"id\"]], include_descendants=True\n", + ")\n", "\n", "print(\"%d VISp non-injection, cortical structure unionizes\" % len(structure_unionizes))" ] @@ -2088,11 +2084,13 @@ } ], "source": [ - "dense_unionizes = structure_unionizes[ structure_unionizes.projection_density > .5 ]\n", - "large_unionizes = dense_unionizes[ dense_unionizes.volume > .5 ]\n", + "dense_unionizes = structure_unionizes[structure_unionizes.projection_density > 0.5]\n", + "large_unionizes = dense_unionizes[dense_unionizes.volume > 0.5]\n", "large_structures = pd.DataFrame(structure_tree.nodes(large_unionizes.structure_id))\n", "\n", - "print(\"%d large, dense, cortical, non-injection unionizes, %d structures\" % ( len(large_unionizes), len(large_structures) ))\n", + "print(\n", + " \"%d large, dense, cortical, non-injection unionizes, %d structures\" % (len(large_unionizes), len(large_structures))\n", + ")\n", "\n", "print(large_structures.name)\n", "\n", @@ -2159,30 +2157,33 @@ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import warnings\n", - "warnings.filterwarnings('ignore')\n", + "\n", + "warnings.filterwarnings(\"ignore\")\n", "%matplotlib inline\n", "\n", - "visp_experiment_ids = [ e['id'] for e in visp_experiments ]\n", - "ctx_children = structure_tree.child_ids( [isocortex['id']] )[0]\n", + "visp_experiment_ids = [e[\"id\"] for e in visp_experiments]\n", + "ctx_children = structure_tree.child_ids([isocortex[\"id\"]])[0]\n", "\n", - "pm = mcc.get_projection_matrix(experiment_ids = visp_experiment_ids, \n", - " projection_structure_ids = ctx_children,\n", - " hemisphere_ids= [2], # right hemisphere, ipsilateral\n", - " parameter = 'projection_density')\n", + "pm = mcc.get_projection_matrix(\n", + " experiment_ids=visp_experiment_ids,\n", + " projection_structure_ids=ctx_children,\n", + " hemisphere_ids=[2], # right hemisphere, ipsilateral\n", + " parameter=\"projection_density\",\n", + ")\n", "\n", - "row_labels = pm['rows'] # these are just experiment ids\n", - "column_labels = [ c['label'] for c in pm['columns'] ] \n", - "matrix = pm['matrix']\n", + "row_labels = pm[\"rows\"] # these are just experiment ids\n", + "column_labels = [c[\"label\"] for c in pm[\"columns\"]]\n", + "matrix = pm[\"matrix\"]\n", "\n", - "fig, ax = plt.subplots(figsize=(15,15))\n", + "fig, ax = plt.subplots(figsize=(15, 15))\n", "heatmap = ax.pcolor(matrix, cmap=plt.cm.afmhot)\n", "\n", "# put the major ticks at the middle of each cell\n", - "ax.set_xticks(np.arange(matrix.shape[1])+0.5, minor=False)\n", - "ax.set_yticks(np.arange(matrix.shape[0])+0.5, minor=False)\n", + "ax.set_xticks(np.arange(matrix.shape[1]) + 0.5, minor=False)\n", + "ax.set_yticks(np.arange(matrix.shape[0]) + 0.5, minor=False)\n", "\n", "ax.set_xlim([0, matrix.shape[1]])\n", - "ax.set_ylim([0, matrix.shape[0]]) \n", + "ax.set_ylim([0, matrix.shape[0]])\n", "\n", "# want a more natural, table-like display\n", "ax.invert_yaxis()\n", @@ -2384,13 +2385,13 @@ "# show that slice of all volumes side-by-side\n", "f, pr_axes = plt.subplots(1, 3, figsize=(15, 6))\n", "\n", - "pr_axes[0].imshow(pd_mip, cmap='hot', aspect='equal')\n", + "pr_axes[0].imshow(pd_mip, cmap=\"hot\", aspect=\"equal\")\n", "pr_axes[0].set_title(\"projection density MaxIP\")\n", "\n", - "pr_axes[1].imshow(ind_mip, cmap='hot', aspect='equal')\n", + "pr_axes[1].imshow(ind_mip, cmap=\"hot\", aspect=\"equal\")\n", "pr_axes[1].set_title(\"injection density MaxIP\")\n", "\n", - "pr_axes[2].imshow(inf_mip, cmap='hot', aspect='equal')\n", + "pr_axes[2].imshow(inf_mip, cmap=\"hot\", aspect=\"equal\")\n", "pr_axes[2].set_title(\"injection fraction MaxIP\")\n", "\n", "plt.show()" @@ -2439,13 +2440,13 @@ "\n", "f, ccf_axes = plt.subplots(1, 3, figsize=(15, 6))\n", "\n", - "ccf_axes[0].imshow(template[slice_idx,:,:], cmap='gray', aspect='equal', vmin=template.min(), vmax=template.max())\n", + "ccf_axes[0].imshow(template[slice_idx, :, :], cmap=\"gray\", aspect=\"equal\", vmin=template.min(), vmax=template.max())\n", "ccf_axes[0].set_title(\"registration template\")\n", "\n", - "ccf_axes[1].imshow(annot[slice_idx,:,:], cmap='gray', aspect='equal', vmin=0, vmax=2000)\n", + "ccf_axes[1].imshow(annot[slice_idx, :, :], cmap=\"gray\", aspect=\"equal\", vmin=0, vmax=2000)\n", "ccf_axes[1].set_title(\"annotation volume\")\n", "\n", - "ccf_axes[2].imshow(cortex_mask[slice_idx,:,:], cmap='gray', aspect='equal', vmin=0, vmax=1)\n", + "ccf_axes[2].imshow(cortex_mask[slice_idx, :, :], cmap=\"gray\", aspect=\"equal\", vmin=0, vmax=1)\n", "ccf_axes[2].set_title(\"isocortex mask\")\n", "\n", "plt.show()" @@ -2511,8 +2512,8 @@ "source": [ "f, data_mask_axis = plt.subplots(figsize=(5, 6))\n", "\n", - "data_mask_axis.imshow(dm[81, :, :], cmap='hot', aspect='equal', vmin=0, vmax=1)\n", - "data_mask_axis.set_title('data mask')\n", + "data_mask_axis.imshow(dm[81, :, :], cmap=\"hot\", aspect=\"equal\", vmin=0, vmax=1)\n", + "data_mask_axis.set_title(\"data mask\")\n", "\n", "plt.show()" ] diff --git a/doc_template/examples_root/examples/nb/neuron/pulse_stimulus.ipynb b/doc_template/examples_root/examples/nb/neuron/pulse_stimulus.ipynb index 3edab113d3..cb84ce0aa5 100644 --- a/doc_template/examples_root/examples/nb/neuron/pulse_stimulus.ipynb +++ b/doc_template/examples_root/examples/nb/neuron/pulse_stimulus.ipynb @@ -78,12 +78,12 @@ "source": [ "from allensdk.api.queries.biophysical_api import BiophysicalApi\n", "\n", - "neuronal_model_id = 472451419 # get this from the web site\n", - "#neuronal_model_id = 480361288\n", - "model_directory = '.'\n", + "neuronal_model_id = 472451419 # get this from the web site\n", + "# neuronal_model_id = 480361288\n", + "model_directory = \".\"\n", "\n", - "bp = BiophysicalApi('http://api.brain-map.org')\n", - "bp.cache_stimulus = False # don't want to download the large stimulus NWB file\n", + "bp = BiophysicalApi(\"http://api.brain-map.org\")\n", + "bp.cache_stimulus = False # don't want to download the large stimulus NWB file\n", "bp.cache_data(neuronal_model_id, working_directory=model_directory)" ] }, @@ -108,7 +108,7 @@ "source": [ "import os\n", "\n", - "os.system('nrnivmodl modfiles')" + "os.system(\"nrnivmodl modfiles\")" ] }, { @@ -145,7 +145,7 @@ } ], "source": [ - "description = Config().load('manifest.json')\n", + "description = Config().load(\"manifest.json\")\n", "utils = Utils(description)\n", "h = utils.h" ] @@ -170,8 +170,8 @@ "source": [ "# configure model\n", "manifest = description.manifest\n", - "morphology_path = description.manifest.get_path('MORPHOLOGY')\n", - "utils.generate_morphology(morphology_path.encode('ascii', 'ignore'))\n", + "morphology_path = description.manifest.get_path(\"MORPHOLOGY\")\n", + "utils.generate_morphology(morphology_path.encode(\"ascii\", \"ignore\"))\n", "utils.load_cell_parameters()\n", "\n", "# At this point the cell model has been fully set up in NEURON" @@ -245,17 +245,17 @@ "# save the result to a simple time and voltage space-separated text file\n", "import numpy\n", "\n", - "output_path = 'output_voltage.dat'\n", + "output_path = \"output_voltage.dat\"\n", "\n", - "junction_potential = description.data['fitting'][0]['junction_potential']\n", + "junction_potential = description.data[\"fitting\"][0][\"junction_potential\"]\n", "mV = 1.0e-3\n", "ms = 1.0e-3\n", "\n", - "output_data = (numpy.array(vec['v']) - junction_potential) * mV\n", - "output_times = numpy.array(vec['t']) * ms\n", + "output_data = (numpy.array(vec[\"v\"]) - junction_potential) * mV\n", + "output_times = numpy.array(vec[\"t\"]) * ms\n", "\n", "data = numpy.transpose(numpy.vstack((output_times, output_data)))\n", - "with open (output_path, \"w\") as f:\n", + "with open(output_path, \"w\") as f:\n", " numpy.savetxt(f, data)" ] }, @@ -278,10 +278,11 @@ "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", - "plt.plot(vec['t'], numpy.array(vec['v']) - junction_potential)\n", - "plt.xlabel('time (ms)')\n", - "plt.ylabel('membrane potential (mV)')\n", - "plt.show()\n" + "\n", + "plt.plot(vec[\"t\"], numpy.array(vec[\"v\"]) - junction_potential)\n", + "plt.xlabel(\"time (ms)\")\n", + "plt.ylabel(\"membrane potential (mV)\")\n", + "plt.show()" ] }, { diff --git a/doc_template/examples_root/examples/nb/receptive_fields.ipynb b/doc_template/examples_root/examples/nb/receptive_fields.ipynb index 82d59d975a..0a2b5da5ee 100644 --- a/doc_template/examples_root/examples/nb/receptive_fields.ipynb +++ b/doc_template/examples_root/examples/nb/receptive_fields.ipynb @@ -49,6 +49,7 @@ "import allensdk.brain_observatory.receptive_field_analysis.receptive_field as rf\n", "import matplotlib.pyplot as plt\n", "from pathlib import Path\n", + "\n", "%matplotlib inline" ] }, @@ -93,7 +94,7 @@ }, "outputs": [], "source": [ - "output_dir = '.'" + "output_dir = \".\"" ] }, { @@ -135,13 +136,11 @@ "source": [ "cell_specimen_id = 587377366\n", "\n", - "boc = BrainObservatoryCache(\n", - " manifest_file=str(Path(output_dir) / 'brain_observatory_manifest.json'))\n", + "boc = BrainObservatoryCache(manifest_file=str(Path(output_dir) / \"brain_observatory_manifest.json\"))\n", "\n", - "exps = boc.get_ophys_experiments(cell_specimen_ids=[cell_specimen_id],\n", - " stimuli=['locally_sparse_noise'])\n", + "exps = boc.get_ophys_experiments(cell_specimen_ids=[cell_specimen_id], stimuli=[\"locally_sparse_noise\"])\n", "\n", - "data_set = boc.get_ophys_experiment_data(exps[0]['id'])\n", + "data_set = boc.get_ophys_experiment_data(exps[0][\"id\"])\n", "\n", "cell_index = data_set.get_cell_specimen_indices([cell_specimen_id])[0]\n", "\n", @@ -201,11 +200,9 @@ } ], "source": [ - "rf_data = rf.compute_receptive_field_with_postprocessing(data_set, \n", - " cell_index, \n", - " 'locally_sparse_noise', \n", - " alpha=0.5, \n", - " number_of_shuffles=10000)" + "rf_data = rf.compute_receptive_field_with_postprocessing(\n", + " data_set, cell_index, \"locally_sparse_noise\", alpha=0.5, number_of_shuffles=10000\n", + ")" ] }, { @@ -314,7 +311,7 @@ } ], "source": [ - "fig, (ax1, ax2) = plt.subplots(1,2)\n", + "fig, (ax1, ax2) = plt.subplots(1, 2)\n", "rfvis.plot_rts_summary(rf_data, ax1, ax2)" ] }, @@ -369,7 +366,7 @@ } ], "source": [ - "fig, (ax1, ax2) = plt.subplots(1,2)\n", + "fig, (ax1, ax2) = plt.subplots(1, 2)\n", "rfvis.plot_rts_blur_summary(rf_data, ax1, ax2)" ] }, @@ -425,7 +422,7 @@ } ], "source": [ - "fig, (ax1, ax2) = plt.subplots(1,2)\n", + "fig, (ax1, ax2) = plt.subplots(1, 2)\n", "rfvis.plot_p_values(rf_data, ax1, ax2)" ] }, @@ -481,7 +478,7 @@ } ], "source": [ - "fig, (ax1, ax2) = plt.subplots(1,2)\n", + "fig, (ax1, ax2) = plt.subplots(1, 2)\n", "rfvis.plot_mask(rf_data, ax1, ax2)" ] }, @@ -536,7 +533,7 @@ } ], "source": [ - "fig, (ax1, ax2) = plt.subplots(1,2)\n", + "fig, (ax1, ax2) = plt.subplots(1, 2)\n", "rfvis.plot_gaussian_fit(rf_data, ax1, ax2)" ] }, @@ -593,7 +590,7 @@ "cell_specimen_id = 662279767\n", "exps = boc.get_ophys_experiments(cell_specimen_ids=[cell_specimen_id])\n", "for exp in exps:\n", - " print(boc.get_ophys_experiment_stimuli(exp['id']))" + " print(boc.get_ophys_experiment_stimuli(exp[\"id\"]))" ] }, { @@ -652,9 +649,8 @@ } ], "source": [ - "exps = boc.get_ophys_experiments(cell_specimen_ids=[cell_specimen_id],\n", - " stimuli=['locally_sparse_noise_4deg'])\n", - "data_set = boc.get_ophys_experiment_data(exps[0]['id'])\n", + "exps = boc.get_ophys_experiments(cell_specimen_ids=[cell_specimen_id], stimuli=[\"locally_sparse_noise_4deg\"])\n", + "data_set = boc.get_ophys_experiment_data(exps[0][\"id\"])\n", "cell_index = data_set.get_cell_specimen_indices([cell_specimen_id])[0]\n", "print(\"cell %d has index %d\" % (cell_specimen_id, cell_index))" ] @@ -711,11 +707,9 @@ } ], "source": [ - "rf_data = rf.compute_receptive_field_with_postprocessing(data_set, \n", - " cell_index, \n", - " 'locally_sparse_noise_4deg', \n", - " alpha=0.5, \n", - " number_of_shuffles=10000)" + "rf_data = rf.compute_receptive_field_with_postprocessing(\n", + " data_set, cell_index, \"locally_sparse_noise_4deg\", alpha=0.5, number_of_shuffles=10000\n", + ")" ] }, { @@ -788,7 +782,7 @@ } ], "source": [ - "fig, (ax1, ax2) = plt.subplots(1,2)\n", + "fig, (ax1, ax2) = plt.subplots(1, 2)\n", "rfvis.plot_rts_summary(rf_data, ax1, ax2)" ] } diff --git a/doc_template/examples_root/examples/nb/reference_space.ipynb b/doc_template/examples_root/examples/nb/reference_space.ipynb index 49ac6442fa..f30e9ee471 100644 --- a/doc_template/examples_root/examples/nb/reference_space.ipynb +++ b/doc_template/examples_root/examples/nb/reference_space.ipynb @@ -101,7 +101,7 @@ }, "outputs": [], "source": [ - "output_dir = '.'" + "output_dir = \".\"" ] }, { @@ -126,11 +126,11 @@ }, "outputs": [], "source": [ - "reference_space_key = os.path.join('annotation', 'ccf_2017')\n", + "reference_space_key = os.path.join(\"annotation\", \"ccf_2017\")\n", "resolution = 25\n", - "rspc = ReferenceSpaceCache(resolution, reference_space_key, manifest=Path(output_dir) / 'manifest.json')\n", + "rspc = ReferenceSpaceCache(resolution, reference_space_key, manifest=Path(output_dir) / \"manifest.json\")\n", "# ID 1 is the adult mouse structure graph\n", - "tree = rspc.get_structure_tree(structure_graph_id=1) " + "tree = rspc.get_structure_tree(structure_graph_id=1)" ] }, { @@ -181,7 +181,7 @@ ], "source": [ "# now let's take a look at a structure\n", - "tree.get_structures_by_name(['Dorsal auditory area'])" + "tree.get_structures_by_name([\"Dorsal auditory area\"])" ] }, { @@ -344,8 +344,8 @@ "structure_id_a = 385\n", "structure_id_b = 247\n", "\n", - "is_desc = '' if tree.structure_descends_from(structure_id_a, structure_id_b) else ' not'\n", - "print( '{0} is{1} in {2}'.format(name_map[structure_id_a], is_desc, name_map[structure_id_b]) )" + "is_desc = \"\" if tree.structure_descends_from(structure_id_a, structure_id_b) else \" not\"\n", + "print(\"{0} is{1} in {2}\".format(name_map[structure_id_a], is_desc, name_map[structure_id_b]))" ] }, { @@ -379,11 +379,11 @@ ], "source": [ "# build a custom map that looks up acronyms by ids\n", - "# the syntax here is just a pair of node-wise functions. \n", + "# the syntax here is just a pair of node-wise functions.\n", "# The first one returns keys while the second one returns values\n", "\n", - "acronym_map = tree.value_map(lambda x: x['id'], lambda y: y['acronym'])\n", - "print( acronym_map[structure_id_a] )" + "acronym_map = tree.value_map(lambda x: x[\"id\"], lambda y: y[\"acronym\"])\n", + "print(acronym_map[structure_id_a])" ] }, { @@ -579,6 +579,7 @@ ], "source": [ "import matplotlib.pyplot as plt\n", + "\n", "%matplotlib inline\n", "\n", "# A complete mask for one structure\n", @@ -586,7 +587,7 @@ "\n", "# view in coronal section\n", "fig, ax = plt.subplots(figsize=(10, 10))\n", - "plt.imshow(whole_cortex_mask[150, :], interpolation='none', cmap=plt.cm.afmhot)" + "plt.imshow(whole_cortex_mask[150, :], interpolation=\"none\", cmap=plt.cm.afmhot)" ] }, { @@ -651,13 +652,13 @@ "source": [ "# This gets all of the structures targeted by the Allen Brain Observatory project\n", "brain_observatory_structures = rsp.structure_tree.get_structures_by_set_id([514166994])\n", - "brain_observatory_ids = [st['id'] for st in brain_observatory_structures]\n", + "brain_observatory_ids = [st[\"id\"] for st in brain_observatory_structures]\n", "\n", "brain_observatory_mask = rsp.make_structure_mask(brain_observatory_ids)\n", "\n", "# view in horizontal section\n", "fig, ax = plt.subplots(figsize=(10, 10))\n", - "plt.imshow(brain_observatory_mask[:, 40, :], interpolation='none', cmap=plt.cm.afmhot)" + "plt.imshow(brain_observatory_mask[:, 40, :], interpolation=\"none\", cmap=plt.cm.afmhot)" ] }, { @@ -722,18 +723,18 @@ "import functools\n", "from allensdk.core.reference_space import ReferenceSpace\n", "\n", - "# Define a wrapper function that will control the mask generation. \n", - "# This one checks for a nrrd file in the specified base directory \n", + "# Define a wrapper function that will control the mask generation.\n", + "# This one checks for a nrrd file in the specified base directory\n", "# and builds/writes the mask only if one does not exist\n", - "annotation_dir = Path(output_dir) / 'annotation'\n", + "annotation_dir = Path(output_dir) / \"annotation\"\n", "mask_writer = functools.partial(ReferenceSpace.check_and_write, annotation_dir)\n", - " \n", + "\n", "# many_structure_masks is a generator - nothing has actrually been run yet\n", "mask_generator = rsp.many_structure_masks([385, 1097], mask_writer)\n", "\n", "# consume the resulting iterator to make and write the masks\n", "for structure_id in mask_generator:\n", - " print( 'made mask for structure {0}.'.format(structure_id) ) \n", + " print(\"made mask for structure {0}.\".format(structure_id))\n", "\n", "os.listdir(annotation_dir)" ] @@ -815,8 +816,8 @@ ], "source": [ "# Double-check the voxel counts\n", - "no_voxel_id = rsp.structure_tree.get_structures_by_name(['Somatosensory areas, layer 6a'])[0]['id']\n", - "print( 'voxel count for structure {0}: {1}'.format(no_voxel_id, rsp.total_voxel_map[no_voxel_id]) )\n", + "no_voxel_id = rsp.structure_tree.get_structures_by_name([\"Somatosensory areas, layer 6a\"])[0][\"id\"]\n", + "print(\"voxel count for structure {0}: {1}\".format(no_voxel_id, rsp.total_voxel_map[no_voxel_id]))\n", "\n", "# remove unassigned structures from the ReferenceSpace's StructureTree\n", "rsp.remove_unassigned()\n", @@ -885,9 +886,8 @@ } ], "source": [ - "\n", "fig, ax = plt.subplots(figsize=(10, 10))\n", - "plt.imshow(rsp.get_slice_image(1, 5000), interpolation='none')" + "plt.imshow(rsp.get_slice_image(1, 5000), interpolation=\"none\")" ] }, { @@ -944,17 +944,17 @@ "\n", "target_resolution = [75, 75, 75]\n", "\n", - "# in some versions of scipy, scipy.ndimage.zoom raises a helpful but distracting \n", - "# warning about the method used to truncate integers. \n", - "warnings.simplefilter('ignore')\n", + "# in some versions of scipy, scipy.ndimage.zoom raises a helpful but distracting\n", + "# warning about the method used to truncate integers.\n", + "warnings.simplefilter(\"ignore\")\n", "\n", "sf_rsp = rsp.downsample(target_resolution)\n", "\n", "# re-enable warnings\n", - "warnings.simplefilter('default')\n", + "warnings.simplefilter(\"default\")\n", "\n", - "print( rsp.annotation.shape )\n", - "print( sf_rsp.annotation.shape )" + "print(rsp.annotation.shape)\n", + "print(sf_rsp.annotation.shape)" ] }, { @@ -1018,7 +1018,7 @@ ], "source": [ "fig, ax = plt.subplots(figsize=(10, 10))\n", - "plt.imshow(sf_rsp.get_slice_image(1, 5000), interpolation='none')" + "plt.imshow(sf_rsp.get_slice_image(1, 5000), interpolation=\"none\")" ] }, { @@ -1085,7 +1085,7 @@ "source": [ "# using the downsampled annotations\n", "hm_rsp = rsp.downsample([100, 100, 100])\n", - "hm_rsp.write_itksnap_labels('ccf_2017_itksnap.nrrd', 'ccf_2017_itksnap_labels.txt')" + "hm_rsp.write_itksnap_labels(\"ccf_2017_itksnap.nrrd\", \"ccf_2017_itksnap_labels.txt\")" ] }, { diff --git a/doc_template/examples_root/examples/nb/summer_workshop_2015/experiment_detail_example.ipynb b/doc_template/examples_root/examples/nb/summer_workshop_2015/experiment_detail_example.ipynb index 1491c071c2..a405a0735f 100644 --- a/doc_template/examples_root/examples/nb/summer_workshop_2015/experiment_detail_example.ipynb +++ b/doc_template/examples_root/examples/nb/summer_workshop_2015/experiment_detail_example.ipynb @@ -23,15 +23,17 @@ "rma = RmaApi()\n", "mca = MouseConnectivityApi()\n", "\n", + "\n", "def read_data(parsed_json):\n", - " return parsed_json['msg']\n", + " return parsed_json[\"msg\"]\n", + "\n", "\n", "def pretty(result):\n", " print(json.dumps(result, indent=2))\n", - " \n", + "\n", + "\n", "def tabular_dataframe(parsed_tabular_json):\n", - " df = pd.DataFrame.from_records(parsed_tabular_json,\n", - " columns=parsed_tabular_json[0].keys())\n", + " df = pd.DataFrame.from_records(parsed_tabular_json, columns=parsed_tabular_json[0].keys())\n", " return df" ] }, @@ -72,21 +74,25 @@ "# rma version from detail page\n", "results = mca.get_manual_injection_summary(section_data_set_id)\n", "\n", - "print('Manual Injection Summary')\n", - "print('Experiment: %d' % (results[0]['id']))\n", - "injections = results[0]['specimen']['injections']\n", - "print('Primary Structure: %s' % (injections[0]['structure']['name']))\n", + "print(\"Manual Injection Summary\")\n", + "print(\"Experiment: %d\" % (results[0][\"id\"]))\n", + "injections = results[0][\"specimen\"][\"injections\"]\n", + "print(\"Primary Structure: %s\" % (injections[0][\"structure\"][\"name\"]))\n", "\n", - "print('Coords. (AP, ML, DV, L):')\n", + "print(\"Coords. (AP, ML, DV, L):\")\n", "for injection in injections:\n", - " print ('\\t%s(%f, %f, %f, %f)' % (injection['registration_point'],\n", - " injection['coordinates_ap'],\n", - " injection['coordinates_ml'],\n", - " injection['coordinates_dv'],\n", - " injection['angle']))\n", - "print('Transgenic Line: %s' % (results[0]['specimen']['donor']['transgenic_mouse']['transgenic_lines'][0]['name']))\n", - "print(results[0]['specimen']['donor']['transgenic_mouse']['transgenic_lines'][0]['description'])\n", - "\n" + " print(\n", + " \"\\t%s(%f, %f, %f, %f)\"\n", + " % (\n", + " injection[\"registration_point\"],\n", + " injection[\"coordinates_ap\"],\n", + " injection[\"coordinates_ml\"],\n", + " injection[\"coordinates_dv\"],\n", + " injection[\"angle\"],\n", + " )\n", + " )\n", + "print(\"Transgenic Line: %s\" % (results[0][\"specimen\"][\"donor\"][\"transgenic_mouse\"][\"transgenic_lines\"][0][\"name\"]))\n", + "print(results[0][\"specimen\"][\"donor\"][\"transgenic_mouse\"][\"transgenic_lines\"][0][\"description\"])" ] }, { @@ -1985,17 +1991,18 @@ ], "source": [ "from allensdk.api.queries.ontologies_api import OntologiesApi\n", + "\n", "oa = OntologiesApi()\n", "\n", "atlases = tabular_dataframe(oa.get_atlases_table(brief=True))\n", - "mouse_p56_coronal_id = atlases[atlases.name == 'Mouse, P56, Coronal']['structure_graph'].iloc[0]['id'] # 1\n", + "mouse_p56_coronal_id = atlases[atlases.name == \"Mouse, P56, Coronal\"][\"structure_graph\"].iloc[0][\"id\"] # 1\n", "\n", "# calculate a list of only the ancestors that are in the limited structure set\n", "structure_set_df = tabular_dataframe(oa.get_structures(mouse_p56_coronal_id))\n", "\n", "oa.unpack_structure_set_ancestors(structure_set_df)\n", "\n", - "structure_set_df[structure_set_df.acronym == 'PIR']" + "structure_set_df[structure_set_df.acronym == \"PIR\"]" ] }, { @@ -2259,19 +2266,28 @@ ], "source": [ "cutoff = 0.0050\n", - "data_field = 'projection_volume'\n", + "data_field = \"projection_volume\"\n", "\n", "# projection joined to structures\n", - "table = pd.concat((\n", - " pd.merge(projection_df[projection_df.hemisphere_id == h]\n", - " [projection_df.projection_volume > cutoff].loc[:,['hemisphere_id',\n", - " 'structure_id',\n", - " data_field]],\n", - " structure_set_df[['id','acronym','graph_order']],\n", - " how='inner',\n", - " left_on='structure_id', right_on='id')\n", - " for h in [1, 2])).sort_values('graph_order').set_index('structure_id')\n", - " \n", + "table = (\n", + " pd.concat(\n", + " (\n", + " pd.merge(\n", + " projection_df[projection_df.hemisphere_id == h][projection_df.projection_volume > cutoff].loc[\n", + " :, [\"hemisphere_id\", \"structure_id\", data_field]\n", + " ],\n", + " structure_set_df[[\"id\", \"acronym\", \"graph_order\"]],\n", + " how=\"inner\",\n", + " left_on=\"structure_id\",\n", + " right_on=\"id\",\n", + " )\n", + " for h in [1, 2]\n", + " )\n", + " )\n", + " .sort_values(\"graph_order\")\n", + " .set_index(\"structure_id\")\n", + ")\n", + "\n", "# table[['acronym', 'hemisphere_id', 'projection_volume']]\n", "\n", "table" @@ -2444,10 +2460,12 @@ } ], "source": [ - "display_table = table[['hemisphere_id', 'acronym', 'projection_volume', 'graph_order']].pivot(index='graph_order', columns='hemisphere_id')\n", + "display_table = table[[\"hemisphere_id\", \"acronym\", \"projection_volume\", \"graph_order\"]].pivot(\n", + " index=\"graph_order\", columns=\"hemisphere_id\"\n", + ")\n", "# display_table.columns = ['structure','L', 'R']\n", "display_table.columns = display_table.columns.droplevel(0)\n", - "display_table.rename(columns={1:'L', 2:'R'})" + "display_table.rename(columns={1: \"L\", 2: \"R\"})" ] } ], diff --git a/doc_template/examples_root/examples/nb/visual_behavior_compare_across_trial_types.ipynb b/doc_template/examples_root/examples/nb/visual_behavior_compare_across_trial_types.ipynb index 5b879d207e..2004b3a2a4 100644 --- a/doc_template/examples_root/examples/nb/visual_behavior_compare_across_trial_types.ipynb +++ b/doc_template/examples_root/examples/nb/visual_behavior_compare_across_trial_types.ipynb @@ -433,12 +433,14 @@ "import seaborn as sns\n", "import numpy as np\n", "import pandas as pd\n", - "pd.set_option('display.max_columns', 500)\n", + "\n", + "pd.set_option(\"display.max_columns\", 500)\n", "\n", "import allensdk.brain_observatory.behavior.behavior_project_cache as bpc\n", "\n", "from importlib.metadata import version\n", - "print('allensdk version 2.10.2 or higher is required, you have {} installed'.format(version(\"allensdk\")))" + "\n", + "print(\"allensdk version 2.10.2 or higher is required, you have {} installed\".format(version(\"allensdk\")))" ] }, { @@ -515,6 +517,7 @@ "outputs": [], "source": [ "from IPython.display import display, HTML\n", + "\n", "display(HTML(\"\"))" ] }, @@ -545,7 +548,7 @@ }, "outputs": [], "source": [ - "output_dir = '/path/to/vbo'" + "output_dir = \"/path/to/vbo\"" ] }, { @@ -603,8 +606,8 @@ ], "source": [ "bc = bpc.VisualBehaviorOphysProjectCache.from_s3_cache(cache_dir=output_dir)\n", - " \n", - "experiment_table = bc.get_ophys_experiment_table() " + "\n", + "experiment_table = bc.get_ophys_experiment_table()" ] }, { @@ -1093,7 +1096,7 @@ } ], "source": [ - "np.sort(experiment_table['session_type'].unique())" + "np.sort(experiment_table[\"session_type\"].unique())" ] }, { @@ -1158,7 +1161,7 @@ ], "source": [ "experiment_id = experiment_table.query('session_type == \"OPHYS_1_images_A\"').sample(random_state=10).index[0]\n", - "print('getting experiment data for experiment_id {}'.format(experiment_id))\n", + "print(\"getting experiment data for experiment_id {}\".format(experiment_id))\n", "experiment_dataset = bc.get_behavior_ophys_experiment(experiment_id)" ] }, @@ -1291,9 +1294,11 @@ "outputs": [], "source": [ "trials_df = experiment_dataset.trials.merge(\n", - " experiment_dataset.get_rolling_performance_df().fillna(method='ffill'), # performance data is NaN on aborted trials. Fill forward to populate.\n", - " left_index = True,\n", - " right_index = True\n", + " experiment_dataset.get_rolling_performance_df().fillna(\n", + " method=\"ffill\"\n", + " ), # performance data is NaN on aborted trials. Fill forward to populate.\n", + " left_index=True,\n", + " right_index=True,\n", ")" ] }, @@ -2671,30 +2676,18 @@ } ], "source": [ - "fig, ax = plt.subplots(2, 1, figsize = (15,5), sharex=True)\n", - "ax[0].plot(\n", - " trials_df['start_time']/60.,\n", - " trials_df['hit_rate'],\n", - " color='darkgreen'\n", - ")\n", + "fig, ax = plt.subplots(2, 1, figsize=(15, 5), sharex=True)\n", + "ax[0].plot(trials_df[\"start_time\"] / 60.0, trials_df[\"hit_rate\"], color=\"darkgreen\")\n", "\n", - "ax[0].plot(\n", - " trials_df['start_time']/60.,\n", - " trials_df['false_alarm_rate'],\n", - " color='darkred'\n", - ")\n", + "ax[0].plot(trials_df[\"start_time\"] / 60.0, trials_df[\"false_alarm_rate\"], color=\"darkred\")\n", "\n", - "ax[0].legend(['rolling hit rate', 'rolling false alarm rate'])\n", + "ax[0].legend([\"rolling hit rate\", \"rolling false alarm rate\"])\n", "\n", - "ax[1].plot(\n", - " trials_df['start_time']/60.,\n", - " trials_df['rolling_dprime'],\n", - " color='black'\n", - ")\n", + "ax[1].plot(trials_df[\"start_time\"] / 60.0, trials_df[\"rolling_dprime\"], color=\"black\")\n", "\n", - "ax[1].set_xlabel('trial start time (minutes)')\n", - "ax[0].set_ylabel('response rate')\n", - "ax[0].set_title('hit and false alarm rates')\n", + "ax[1].set_xlabel(\"trial start time (minutes)\")\n", + "ax[0].set_ylabel(\"response rate\")\n", + "ax[0].set_title(\"hit and false alarm rates\")\n", "ax[1].set_title(\"d'\")\n", "\n", "fig.tight_layout()" @@ -2989,7 +2982,8 @@ "source": [ "# Grab the image stimulus only\n", "stimulus_presentations = experiment_dataset.stimulus_presentations[\n", - " experiment_dataset.stimulus_presentations.stimulus_block_name.str.contains('change_detection')]\n", + " experiment_dataset.stimulus_presentations.stimulus_block_name.str.contains(\"change_detection\")\n", + "]\n", "stimulus_presentations.head()" ] }, @@ -3347,9 +3341,11 @@ } ], "source": [ - "unique_stimuli = [stimulus for stimulus in stimulus_presentations['image_name'].unique() if stimulus != 'omitted']\n", - "colormap = {image_name: sns.color_palette()[image_number] for image_number, image_name in enumerate(np.sort(unique_stimuli))}\n", - "colormap['omitted'] = np.nan # assign gray to omitted\n", + "unique_stimuli = [stimulus for stimulus in stimulus_presentations[\"image_name\"].unique() if stimulus != \"omitted\"]\n", + "colormap = {\n", + " image_name: sns.color_palette()[image_number] for image_number, image_name in enumerate(np.sort(unique_stimuli))\n", + "}\n", + "colormap[\"omitted\"] = np.nan # assign gray to omitted\n", "colormap" ] }, @@ -3378,7 +3374,7 @@ }, "outputs": [], "source": [ - "stimulus_presentations['color'] = stimulus_presentations['image_name'].map(lambda image_name: colormap[image_name])" + "stimulus_presentations[\"color\"] = stimulus_presentations[\"image_name\"].map(lambda image_name: colormap[image_name])" ] }, { @@ -4200,7 +4196,7 @@ ], "source": [ "def get_cell_timeseries_dict(dataset, cell_specimen_id):\n", - " '''\n", + " \"\"\"\n", " for a given cell_specimen ID, this function creates a dictionary with the following keys\n", " * timestamps: ophys timestamps\n", " * cell_roi_id\n", @@ -4212,18 +4208,21 @@ " cell_specimen_id\n", " returns\n", " dict\n", - " '''\n", + " \"\"\"\n", " cell_dict = {\n", - " 'timestamps': dataset.ophys_timestamps,\n", - " 'cell_roi_id': [dataset.dff_traces.loc[cell_specimen_id]['cell_roi_id']] * len(dataset.ophys_timestamps),\n", - " 'cell_specimen_id': [cell_specimen_id] * len(dataset.ophys_timestamps),\n", - " 'dff': dataset.dff_traces.loc[cell_specimen_id]['dff'],\n", - "\n", + " \"timestamps\": dataset.ophys_timestamps,\n", + " \"cell_roi_id\": [dataset.dff_traces.loc[cell_specimen_id][\"cell_roi_id\"]] * len(dataset.ophys_timestamps),\n", + " \"cell_specimen_id\": [cell_specimen_id] * len(dataset.ophys_timestamps),\n", + " \"dff\": dataset.dff_traces.loc[cell_specimen_id][\"dff\"],\n", " }\n", " return cell_dict\n", "\n", + "\n", "experiment_dataset.tidy_dff_traces = pd.concat(\n", - " [pd.DataFrame(get_cell_timeseries_dict(experiment_dataset, cell_specimen_id)) for cell_specimen_id in experiment_dataset.dff_traces.reset_index()['cell_specimen_id']]\n", + " [\n", + " pd.DataFrame(get_cell_timeseries_dict(experiment_dataset, cell_specimen_id))\n", + " for cell_specimen_id in experiment_dataset.dff_traces.reset_index()[\"cell_specimen_id\"]\n", + " ]\n", ").reset_index(drop=True)\n", "\n", "experiment_dataset.tidy_dff_traces.sample(5, random_state=42)" @@ -4276,109 +4275,113 @@ "outputs": [], "source": [ "def plot_stimuli(trial, ax):\n", - " '''\n", + " \"\"\"\n", " plot stimuli as colored bars on specified axis\n", - " '''\n", + " \"\"\"\n", " # Fixup type for use in query.\n", - " stimulus_presentations['omitted'] = stimulus_presentations['omitted'].astype('bool')\n", - " stimuli = stimulus_presentations.query('end_time >= {} and start_time <= {} and not omitted'.format(float(trial['start_time']), float(trial['stop_time'])))\n", + " stimulus_presentations[\"omitted\"] = stimulus_presentations[\"omitted\"].astype(\"bool\")\n", + " stimuli = stimulus_presentations.query(\n", + " \"end_time >= {} and start_time <= {} and not omitted\".format(\n", + " float(trial[\"start_time\"]), float(trial[\"stop_time\"])\n", + " )\n", + " )\n", " for idx, stimulus in stimuli.iterrows():\n", - " ax.axvspan(stimulus['start_time'], stimulus['end_time'], color=stimulus['color'], alpha=0.5)\n", + " ax.axvspan(stimulus[\"start_time\"], stimulus[\"end_time\"], color=stimulus[\"color\"], alpha=0.5)\n", + "\n", "\n", - " \n", "def plot_running(trial, ax):\n", - " '''\n", + " \"\"\"\n", " plot running speed for trial on specified axes\n", - " '''\n", - " trial_running_speed = experiment_dataset.running_speed.query('timestamps >= {} and timestamps <= {} '.format(float(trial['start_time']), float(trial['stop_time'])))\n", - " ax.plot(\n", - " trial_running_speed['timestamps'],\n", - " trial_running_speed['speed'],\n", - " color='black'\n", + " \"\"\"\n", + " trial_running_speed = experiment_dataset.running_speed.query(\n", + " \"timestamps >= {} and timestamps <= {} \".format(float(trial[\"start_time\"]), float(trial[\"stop_time\"]))\n", " )\n", - " ax.set_title('running speed')\n", - " ax.set_ylabel('speed (cm/s)')\n", - " \n", + " ax.plot(trial_running_speed[\"timestamps\"], trial_running_speed[\"speed\"], color=\"black\")\n", + " ax.set_title(\"running speed\")\n", + " ax.set_ylabel(\"speed (cm/s)\")\n", + "\n", "\n", "def plot_licks(trial, ax):\n", - " '''\n", + " \"\"\"\n", " plot licks as black dots on specified axis\n", - " '''\n", - " trial_licks = experiment_dataset.licks.query('timestamps >= {} and timestamps <= {} '.format(float(trial['start_time']), float(trial['stop_time'])))\n", + " \"\"\"\n", + " trial_licks = experiment_dataset.licks.query(\n", + " \"timestamps >= {} and timestamps <= {} \".format(float(trial[\"start_time\"]), float(trial[\"stop_time\"]))\n", + " )\n", " ax.plot(\n", - " trial_licks['timestamps'],\n", - " np.zeros_like(trial_licks['timestamps']),\n", - " marker = 'o',\n", - " linestyle = 'none',\n", - " color='black'\n", + " trial_licks[\"timestamps\"], np.zeros_like(trial_licks[\"timestamps\"]), marker=\"o\", linestyle=\"none\", color=\"black\"\n", " )\n", - " \n", + "\n", "\n", "def plot_rewards(trial, ax):\n", - " '''\n", + " \"\"\"\n", " plot rewards as blue diamonds on specified axis\n", - " '''\n", - " trial_rewards = experiment_dataset.rewards.query('timestamps >= {} and timestamps <= {} '.format(float(trial['start_time']), float(trial['stop_time'])))\n", + " \"\"\"\n", + " trial_rewards = experiment_dataset.rewards.query(\n", + " \"timestamps >= {} and timestamps <= {} \".format(float(trial[\"start_time\"]), float(trial[\"stop_time\"]))\n", + " )\n", " ax.plot(\n", - " trial_rewards['timestamps'],\n", - " np.zeros_like(trial_rewards['timestamps']),\n", - " marker = 'd',\n", - " linestyle = 'none',\n", - " color='blue',\n", - " markersize = 10,\n", - " alpha = 0.25\n", + " trial_rewards[\"timestamps\"],\n", + " np.zeros_like(trial_rewards[\"timestamps\"]),\n", + " marker=\"d\",\n", + " linestyle=\"none\",\n", + " color=\"blue\",\n", + " markersize=10,\n", + " alpha=0.25,\n", " )\n", - " \n", + "\n", + "\n", "def plot_pupil(trial, ax):\n", - " '''\n", + " \"\"\"\n", " plot pupil area on specified axis\n", - " '''\n", - " trial_eye_tracking = experiment_dataset.eye_tracking.query('timestamps >= {} and timestamps <= {} '.format(float(trial['start_time']), float(trial['stop_time'])))\n", - " ax.plot(\n", - " trial_eye_tracking['timestamps'],\n", - " trial_eye_tracking['pupil_area'],\n", - " color='black'\n", + " \"\"\"\n", + " trial_eye_tracking = experiment_dataset.eye_tracking.query(\n", + " \"timestamps >= {} and timestamps <= {} \".format(float(trial[\"start_time\"]), float(trial[\"stop_time\"]))\n", " )\n", - " ax.set_title('pupil area')\n", - " ax.set_ylabel('pupil area\\n')\n", - " \n", + " ax.plot(trial_eye_tracking[\"timestamps\"], trial_eye_tracking[\"pupil_area\"], color=\"black\")\n", + " ax.set_title(\"pupil area\")\n", + " ax.set_ylabel(\"pupil area\\n\")\n", + "\n", "\n", "def plot_dff(trial, ax):\n", - " '''\n", + " \"\"\"\n", " plot each cell's dff response for a given trial\n", - " '''\n", - " trial_dff_traces = experiment_dataset.tidy_dff_traces.query('timestamps >= {} and timestamps <= {} '.format(float(trial['start_time']), float(trial['stop_time'])))\n", - " for cell_specimen_id in experiment_dataset.tidy_dff_traces['cell_specimen_id'].unique():\n", + " \"\"\"\n", + " trial_dff_traces = experiment_dataset.tidy_dff_traces.query(\n", + " \"timestamps >= {} and timestamps <= {} \".format(float(trial[\"start_time\"]), float(trial[\"stop_time\"]))\n", + " )\n", + " for cell_specimen_id in experiment_dataset.tidy_dff_traces[\"cell_specimen_id\"].unique():\n", " ax.plot(\n", - " trial_dff_traces.query('cell_specimen_id == @cell_specimen_id')['timestamps'],\n", - " trial_dff_traces.query('cell_specimen_id == @cell_specimen_id')['dff']\n", + " trial_dff_traces.query(\"cell_specimen_id == @cell_specimen_id\")[\"timestamps\"],\n", + " trial_dff_traces.query(\"cell_specimen_id == @cell_specimen_id\")[\"dff\"],\n", " )\n", - " ax.set_title('deltaF/F responses')\n", - " ax.set_ylabel('dF/F')\n", - " \n", + " ax.set_title(\"deltaF/F responses\")\n", + " ax.set_ylabel(\"dF/F\")\n", + "\n", + "\n", "def make_trial_plot(trial):\n", - " '''\n", + " \"\"\"\n", " combine all plots for a given trial\n", - " '''\n", - " fig, axes = plt.subplots(4, 1, figsize = (15, 8), sharex=True)\n", + " \"\"\"\n", + " fig, axes = plt.subplots(4, 1, figsize=(15, 8), sharex=True)\n", "\n", " for ax in axes:\n", " plot_stimuli(trial, ax)\n", - " \n", + "\n", " plot_running(trial, axes[0])\n", "\n", " plot_licks(trial, axes[1])\n", " plot_rewards(trial, axes[1])\n", - " \n", - " axes[1].set_title('licks and rewards')\n", + "\n", + " axes[1].set_title(\"licks and rewards\")\n", " axes[1].set_yticks([])\n", - " axes[1].legend(['licks','rewards'])\n", + " axes[1].legend([\"licks\", \"rewards\"])\n", "\n", " plot_pupil(trial, axes[2])\n", "\n", " plot_dff(trial, axes[3])\n", - " \n", - " axes[3].set_xlabel('time in session (seconds)')\n", + "\n", + " axes[3].set_xlabel(\"time in session (seconds)\")\n", " fig.tight_layout()\n", " return fig, axes" ] @@ -5474,7 +5477,7 @@ } ], "source": [ - "trial = experiment_dataset.trials.query('hit').sample(random_state = 1)\n", + "trial = experiment_dataset.trials.query(\"hit\").sample(random_state=1)\n", "fig, axes = make_trial_plot(trial)" ] }, @@ -6527,7 +6530,7 @@ } ], "source": [ - "trial = experiment_dataset.trials.query('miss').sample(random_state = 2)\n", + "trial = experiment_dataset.trials.query(\"miss\").sample(random_state=2)\n", "fig, axes = make_trial_plot(trial)" ] }, @@ -7580,7 +7583,7 @@ } ], "source": [ - "trial = experiment_dataset.trials.query('false_alarm').sample(random_state = 2)\n", + "trial = experiment_dataset.trials.query(\"false_alarm\").sample(random_state=2)\n", "fig, axes = make_trial_plot(trial)" ] }, @@ -8631,7 +8634,7 @@ } ], "source": [ - "trial = experiment_dataset.trials.query('correct_reject').sample(random_state = 10)\n", + "trial = experiment_dataset.trials.query(\"correct_reject\").sample(random_state=10)\n", "fig, axes = make_trial_plot(trial)" ] } diff --git a/doc_template/examples_root/examples/nb/visual_behavior_mouse_history.ipynb b/doc_template/examples_root/examples/nb/visual_behavior_mouse_history.ipynb index a069dc43ca..2c3f2cb121 100644 --- a/doc_template/examples_root/examples/nb/visual_behavior_mouse_history.ipynb +++ b/doc_template/examples_root/examples/nb/visual_behavior_mouse_history.ipynb @@ -388,12 +388,14 @@ "import seaborn as sns\n", "import numpy as np\n", "import pandas as pd\n", - "pd.set_option('display.max_columns', 500)\n", + "\n", + "pd.set_option(\"display.max_columns\", 500)\n", "\n", "import allensdk.brain_observatory.behavior.behavior_project_cache as bpc\n", "\n", "from importlib.metadata import version\n", - "print('allensdk version 2.10.2 or higher is required, you have {} installed'.format(version(\"allensdk\")))" + "\n", + "print(\"allensdk version 2.10.2 or higher is required, you have {} installed\".format(version(\"allensdk\")))" ] }, { @@ -444,6 +446,7 @@ "outputs": [], "source": [ "from IPython.display import display, HTML\n", + "\n", "display(HTML(\"\"))" ] }, @@ -491,7 +494,7 @@ "outputs": [], "source": [ "# choose a location on your file system to cache NWB files as they are loaded:\n", - "output_dir = '/tmp/cache'" + "output_dir = \"/tmp/cache\"" ] }, { @@ -546,8 +549,8 @@ ], "source": [ "bc = bpc.VisualBehaviorOphysProjectCache.from_s3_cache(cache_dir=output_dir)\n", - " \n", - "behavior_session_table = bc.get_behavior_session_table() " + "\n", + "behavior_session_table = bc.get_behavior_session_table()" ] }, { @@ -2840,9 +2843,9 @@ ], "source": [ "# Select a mouse id\n", - "mouse_id = '445002'\n", - "this_mouse_table = behavior_session_table.query('mouse_id == @mouse_id').sort_values(by = 'date_of_acquisition')\n", - "# note that the following is functionally equivalent if you find the syntax easier to read: \n", + "mouse_id = \"445002\"\n", + "this_mouse_table = behavior_session_table.query(\"mouse_id == @mouse_id\").sort_values(by=\"date_of_acquisition\")\n", + "# note that the following is functionally equivalent if you find the syntax easier to read:\n", "# this_mouse_table = behavior_session_table[behavior_session_table['mouse_id'] == mouse_id]\n", "this_mouse_table" ] @@ -3149,7 +3152,7 @@ "outputs": [], "source": [ "behavior_session_id = this_mouse_table.query('session_type == \"TRAINING_5_images_A_handoff_ready\"').index[-1]\n", - "# note that the following is functionally equivalent if you find the syntax easier to read: \n", + "# note that the following is functionally equivalent if you find the syntax easier to read:\n", "# behavior_session_id = this_mouse_table[this_mouse_table['session_type'] == \"TRAINING_5_images_A_handoff_ready\"].index[-1]\n", "dataset = behavior_session_dict[behavior_session_id]" ] @@ -3839,10 +3842,15 @@ "source": [ "# Get the image stimulus block\n", "image_stimulus_presentations = dataset.stimulus_presentations[\n", - " dataset.stimulus_presentations.stimulus_block_name.str.contains('change_detection')]\n", - "unique_stimuli = [stimulus for stimulus in image_stimulus_presentations['image_name'].unique()]\n", - "colormap = {image_name: sns.color_palette()[image_number] for image_number, image_name in enumerate(np.sort(unique_stimuli))}\n", - "image_stimulus_presentations['color'] = image_stimulus_presentations['image_name'].map(lambda image_name: colormap[image_name])" + " dataset.stimulus_presentations.stimulus_block_name.str.contains(\"change_detection\")\n", + "]\n", + "unique_stimuli = [stimulus for stimulus in image_stimulus_presentations[\"image_name\"].unique()]\n", + "colormap = {\n", + " image_name: sns.color_palette()[image_number] for image_number, image_name in enumerate(np.sort(unique_stimuli))\n", + "}\n", + "image_stimulus_presentations[\"color\"] = image_stimulus_presentations[\"image_name\"].map(\n", + " lambda image_name: colormap[image_name]\n", + ")" ] }, { @@ -3885,67 +3893,67 @@ "outputs": [], "source": [ "def plot_running(ax, initial_time, final_time):\n", - " '''\n", + " \"\"\"\n", " a simple function to plot running speed between two specified times on a specified axis\n", " inputs:\n", " ax: axis on which to plot\n", " intial_time: initial time to plot from\n", " final_time: final time to plot to\n", - " '''\n", - " running_sample = dataset.running_speed.query('timestamps >= @initial_time and timestamps <= @final_time')\n", - " ax.plot(\n", - " running_sample['timestamps'],\n", - " running_sample['speed']\n", - " )\n", + " \"\"\"\n", + " running_sample = dataset.running_speed.query(\"timestamps >= @initial_time and timestamps <= @final_time\")\n", + " ax.plot(running_sample[\"timestamps\"], running_sample[\"speed\"])\n", + "\n", "\n", "def plot_licks(ax, initial_time, final_time):\n", - " '''\n", + " \"\"\"\n", " a simple function to plot licks as dots between two specified times on a specified axis\n", " inputs:\n", " ax: axis on which to plot\n", " intial_time: initial time to plot from\n", " final_time: final time to plot to\n", - " '''\n", - " licking_sample = dataset.licks.query('timestamps >= @initial_time and timestamps <= @final_time')\n", + " \"\"\"\n", + " licking_sample = dataset.licks.query(\"timestamps >= @initial_time and timestamps <= @final_time\")\n", " ax.plot(\n", - " licking_sample['timestamps'],\n", - " np.zeros_like(licking_sample['timestamps']),\n", - " marker = 'o',\n", - " color = 'black',\n", - " linestyle = 'none'\n", + " licking_sample[\"timestamps\"],\n", + " np.zeros_like(licking_sample[\"timestamps\"]),\n", + " marker=\"o\",\n", + " color=\"black\",\n", + " linestyle=\"none\",\n", " )\n", - " \n", + "\n", + "\n", "def plot_rewards(ax, initial_time, final_time):\n", - " '''\n", + " \"\"\"\n", " a simple function to plot rewards between two specified times as blue diamonds on a specified axis\n", " inputs:\n", " ax: axis on which to plot\n", " intial_time: initial time to plot from\n", " final_time: final time to plot to\n", - " '''\n", - " rewards_sample = dataset.rewards.query('timestamps >= @initial_time and timestamps <= @final_time')\n", + " \"\"\"\n", + " rewards_sample = dataset.rewards.query(\"timestamps >= @initial_time and timestamps <= @final_time\")\n", " ax.plot(\n", - " rewards_sample['timestamps'],\n", - " np.zeros_like(rewards_sample['timestamps']),\n", - " marker = 'd',\n", - " color = 'blue',\n", - " linestyle = 'none',\n", - " markersize = 12,\n", - " alpha = 0.5\n", + " rewards_sample[\"timestamps\"],\n", + " np.zeros_like(rewards_sample[\"timestamps\"]),\n", + " marker=\"d\",\n", + " color=\"blue\",\n", + " linestyle=\"none\",\n", + " markersize=12,\n", + " alpha=0.5,\n", " )\n", - " \n", + "\n", + "\n", "def plot_stimuli(ax, ti, tf, image_stim_table):\n", - " '''\n", + " \"\"\"\n", " a simple function to plot stimuli as colored vertical spans on a s\n", " inputs:\n", " ax: axis on which to plot\n", " intial_time: initial time to plot from\n", " final_time: final time to plot to\n", " image_stim_table: Set of image stimuli to plot.\n", - " '''\n", - " stimulus_presentations_sample = image_stim_table.query('end_time >= @initial_time and start_time <= @final_time')\n", + " \"\"\"\n", + " stimulus_presentations_sample = image_stim_table.query(\"end_time >= @initial_time and start_time <= @final_time\")\n", " for idx, stimulus in stimulus_presentations_sample.iterrows():\n", - " ax.axvspan(stimulus['start_time'], stimulus['end_time'], color=stimulus['color'], alpha=0.25)" + " ax.axvspan(stimulus[\"start_time\"], stimulus[\"end_time\"], color=stimulus[\"color\"], alpha=0.25)" ] }, { @@ -5983,22 +5991,22 @@ } ], "source": [ - "initial_time = 775 # initial time for plot, in seconds\n", - "final_time = 800 # final time for plot, in seconds\n", + "initial_time = 775 # initial time for plot, in seconds\n", + "final_time = 800 # final time for plot, in seconds\n", "\n", "plt.clf()\n", - "fig, ax = plt.subplots(figsize = (15,5))\n", + "fig, ax = plt.subplots(figsize=(15, 5))\n", "plot_running(ax, initial_time, final_time)\n", "plot_licks(ax, initial_time, final_time)\n", "plot_rewards(ax, initial_time, final_time)\n", "plot_stimuli(ax, initial_time, final_time, image_stimulus_presentations)\n", "\n", - "ax.legend(['running speed', 'licks', 'rewards'])\n", + "ax.legend([\"running speed\", \"licks\", \"rewards\"])\n", "\n", - "ax.set_ylabel('running speed (cm/s)')\n", - "ax.set_xlabel('time in session (s)')\n", + "ax.set_ylabel(\"running speed (cm/s)\")\n", + "ax.set_xlabel(\"time in session (s)\")\n", "ax.set_xlim(initial_time, final_time)\n", - "ax.set_title('a short section of the session');" + "ax.set_title(\"a short section of the session\");" ] }, { @@ -6388,7 +6396,7 @@ } ], "source": [ - "dataset.trials.query('hit').sample(random_state=0).to_dict('records')" + "dataset.trials.query(\"hit\").sample(random_state=0).to_dict(\"records\")" ] }, { @@ -6507,7 +6515,10 @@ "outputs": [], "source": [ "behavior_performance_table = pd.DataFrame(\n", - " [behavior_session_dict[behavior_session_id].get_performance_metrics() for behavior_session_id in behavior_session_ids]\n", + " [\n", + " behavior_session_dict[behavior_session_id].get_performance_metrics()\n", + " for behavior_session_id in behavior_session_ids\n", + " ]\n", ").set_index(behavior_session_ids)" ] }, @@ -7489,8 +7500,8 @@ "source": [ "this_mouse_table = this_mouse_table.merge(\n", " behavior_performance_table,\n", - " left_index = True,\n", - " right_index = True,\n", + " left_index=True,\n", + " right_index=True,\n", ")\n", "this_mouse_table.head()" ] @@ -8533,24 +8544,20 @@ } ], "source": [ - "fig, ax = plt.subplots(figsize = (15,5))\n", + "fig, ax = plt.subplots(figsize=(15, 5))\n", "\n", - "ax.plot(\n", - " np.arange(len(this_mouse_table)),\n", - " this_mouse_table['max_dprime'],\n", - " marker = 'o'\n", - ")\n", + "ax.plot(np.arange(len(this_mouse_table)), this_mouse_table[\"max_dprime\"], marker=\"o\")\n", "ax.set_xticks(range(len(this_mouse_table)))\n", - "ax.set_xticklabels(list(this_mouse_table['session_type'].values),rotation = 30, ha='right')\n", + "ax.set_xticklabels(list(this_mouse_table[\"session_type\"].values), rotation=30, ha=\"right\")\n", "\n", "# make alternating black/gray vspans for visual clarity\n", - "colors = ['black', 'gray']\n", + "colors = [\"black\", \"gray\"]\n", "for ii in range(len(this_mouse_table)):\n", - " ax.axvspan(ii - 0.5, ii + 0.5, color = colors[ii%2], alpha=0.25)\n", + " ax.axvspan(ii - 0.5, ii + 0.5, color=colors[ii % 2], alpha=0.25)\n", "\n", "ax.set_xlim(-0.5, len(this_mouse_table) - 0.5)\n", - "ax.set_ylabel('dprime')\n", - "ax.set_xlabel('session type')\n", + "ax.set_ylabel(\"dprime\")\n", + "ax.set_xlabel(\"session type\")\n", "ax.set_title(\"Max of rolling d' for every session for mouse {}\".format(mouse_id))\n", "fig.tight_layout()" ] diff --git a/doc_template/examples_root/examples/nb/visual_behavior_neuropixels_LFP_analysis.ipynb b/doc_template/examples_root/examples/nb/visual_behavior_neuropixels_LFP_analysis.ipynb index efabea6325..f49c4a5590 100644 --- a/doc_template/examples_root/examples/nb/visual_behavior_neuropixels_LFP_analysis.ipynb +++ b/doc_template/examples_root/examples/nb/visual_behavior_neuropixels_LFP_analysis.ipynb @@ -98,9 +98,9 @@ } ], "source": [ - "from allensdk.brain_observatory.behavior.behavior_project_cache.\\\n", - " behavior_neuropixels_project_cache \\\n", - " import VisualBehaviorNeuropixelsProjectCache\n", + "from allensdk.brain_observatory.behavior.behavior_project_cache.behavior_neuropixels_project_cache import (\n", + " VisualBehaviorNeuropixelsProjectCache,\n", + ")\n", "\n", "import numpy as np\n", "import pandas as pd\n", @@ -188,17 +188,16 @@ "\n", "is not deleted between instantiations of this cache\n", " warnings.warn(msg, MissingLocalManifestWarning)\n", - "ecephys_sessions.csv: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 64.7k/64.7k [00:00<00:00, 507kMB/s] \n", - "behavior_sessions.csv: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 562k/562k [00:00<00:00, 3.10MMB/s] \n", - "units.csv: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 132M/132M [00:05<00:00, 24.4MMB/s]\n", - "probes.csv: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 130k/130k [00:00<00:00, 700kMB/s] \n", - "channels.csv: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 27.9M/27.9M [00:01<00:00, 18.1MMB/s]\n" + "ecephys_sessions.csv: 100%|██████████| 64.7k/64.7k [00:00<00:00, 507kMB/s] \n", + "behavior_sessions.csv: 100%|██████████| 562k/562k [00:00<00:00, 3.10MMB/s] \n", + "units.csv: 100%|██████████| 132M/132M [00:05<00:00, 24.4MMB/s]\n", + "probes.csv: 100%|██████████| 130k/130k [00:00<00:00, 700kMB/s] \n", + "channels.csv: 100%|██████████| 27.9M/27.9M [00:01<00:00, 18.1MMB/s]\n" ] } ], "source": [ - "cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(\n", - " cache_dir=output_dir)" + "cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(cache_dir=output_dir)" ] }, { @@ -251,9 +250,9 @@ ], "source": [ "probes = cache.get_probe_table()\n", - "valid_lfp = probes[probes['has_lfp_data']]\n", + "valid_lfp = probes[probes[\"has_lfp_data\"]]\n", "\n", - "print('Fraction of insertions with valid LFP: ', len(valid_lfp)/len(probes))" + "print(\"Fraction of insertions with valid LFP: \", len(valid_lfp) / len(probes))" ] }, { @@ -450,7 +449,7 @@ } ], "source": [ - "valid_lfp[valid_lfp['structure_acronyms'].str.contains(\"'VISp',\")].head()" + "valid_lfp[valid_lfp[\"structure_acronyms\"].str.contains(\"'VISp',\")].head()" ] }, { @@ -495,7 +494,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "ecephys_session_1064644573.nwb: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 2.99G/2.99G [02:09<00:00, 23.1MMB/s]\n", + "ecephys_session_1064644573.nwb: 100%|██████████| 2.99G/2.99G [02:09<00:00, 23.1MMB/s]\n", "/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/hdmf/spec/namespace.py:531: UserWarning: Ignoring cached namespace 'hdmf-common' version 1.5.1 because version 1.8.0 is already loaded.\n", " warn(\"Ignoring cached namespace '%s' version %s because version %s is already loaded.\"\n", "/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/hdmf/spec/namespace.py:531: UserWarning: Ignoring cached namespace 'core' version 2.5.0 because version 2.6.0-alpha is already loaded.\n", @@ -506,9 +505,8 @@ } ], "source": [ - "session_id = probes.loc[1064735073]['ecephys_session_id']\n", - "session = cache.get_ecephys_session(\n", - " ecephys_session_id=session_id)" + "session_id = probes.loc[1064735073][\"ecephys_session_id\"]\n", + "session = cache.get_ecephys_session(ecephys_session_id=session_id)" ] }, { @@ -572,7 +570,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "probe_probeC_lfp.nwb: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 4.68G/4.68G [03:16<00:00, 23.7MMB/s]\n", + "probe_probeC_lfp.nwb: 100%|██████████| 4.68G/4.68G [03:16<00:00, 23.7MMB/s]\n", "/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/hdmf/spec/namespace.py:531: UserWarning: Ignoring cached namespace 'hdmf-common' version 1.5.1 because version 1.8.0 is already loaded.\n", " warn(\"Ignoring cached namespace '%s' version %s because version %s is already loaded.\"\n", "/opt/hostedtoolcache/Python/3.8.18/x64/lib/python3.8/site-packages/hdmf/spec/namespace.py:531: UserWarning: Ignoring cached namespace 'core' version 2.5.0 because version 2.6.0-alpha is already loaded.\n", @@ -733,7 +731,7 @@ "\n", ".xr-section-summary-in + label:before {\n", " display: inline-block;\n", - " content: '\u25ba';\n", + " content: '►';\n", " font-size: 11px;\n", " width: 15px;\n", " text-align: center;\n", @@ -744,7 +742,7 @@ "}\n", "\n", ".xr-section-summary-in:checked + label:before {\n", - " content: '\u25bc';\n", + " content: '▼';\n", "}\n", "\n", ".xr-section-summary-in:checked + label > span {\n", @@ -1245,7 +1243,7 @@ "\n", ".xr-section-summary-in + label:before {\n", " display: inline-block;\n", - " content: '\u25ba';\n", + " content: '►';\n", " font-size: 11px;\n", " width: 15px;\n", " text-align: center;\n", @@ -1256,7 +1254,7 @@ "}\n", "\n", ".xr-section-summary-in:checked + label:before {\n", - " content: '\u25bc';\n", + " content: '▼';\n", "}\n", "\n", ".xr-section-summary-in:checked + label > span {\n", @@ -1587,7 +1585,7 @@ } ], "source": [ - "lfp_slice = lfp.sel(time=slice(100,101))\n", + "lfp_slice = lfp.sel(time=slice(100, 101))\n", "\n", "lfp_slice" ] @@ -1654,10 +1652,10 @@ } ], "source": [ - "plt.figure(figsize=(10,2))\n", + "plt.figure(figsize=(10, 2))\n", "_ = plt.plot(lfp_slice.time, lfp_slice.sel(channel=lfp_slice.channel[10]))\n", - "plt.xlabel('Time (s)')\n", - "plt.ylabel('LFP (V)')" + "plt.xlabel(\"Time (s)\")\n", + "plt.ylabel(\"LFP (V)\")" ] }, { @@ -1710,11 +1708,11 @@ } ], "source": [ - "plt.figure(figsize=(8,8))\n", - "im = plt.imshow(lfp_slice.T,aspect='auto',origin='lower',vmin=-1e-3, vmax=1e-3)\n", + "plt.figure(figsize=(8, 8))\n", + "im = plt.imshow(lfp_slice.T, aspect=\"auto\", origin=\"lower\", vmin=-1e-3, vmax=1e-3)\n", "_ = plt.colorbar(im, fraction=0.036, pad=0.04)\n", - "_ = plt.xlabel('Sample number')\n", - "_ = plt.ylabel('Channel index')" + "_ = plt.xlabel(\"Sample number\")\n", + "_ = plt.ylabel(\"Channel index\")" ] }, { @@ -1799,7 +1797,7 @@ "outputs": [], "source": [ "stim_presentations = session.stimulus_presentations\n", - "flashes = stim_presentations[stim_presentations['stimulus_name'].str.contains('flash')]\n", + "flashes = stim_presentations[stim_presentations[\"stimulus_name\"].str.contains(\"flash\")]\n", "presentation_times = flashes.start_time.values\n", "presentation_ids = flashes.index.values" ] @@ -1843,8 +1841,8 @@ }, "outputs": [], "source": [ - "def align_lfp(lfp, trial_window, alignment_times, trial_ids = None):\n", - " '''\n", + "def align_lfp(lfp, trial_window, alignment_times, trial_ids=None):\n", + " \"\"\"\n", " Aligns the LFP data array to experiment times of interest\n", " INPUTS:\n", " lfp: data array containing LFP data for one probe insertion\n", @@ -1852,23 +1850,24 @@ " alignment_times: experiment times around which to excise data\n", " trial_ids: indices in the session stim table specifying which stimuli to use for alignment.\n", " None if aligning to non-stimulus times\n", - " \n", + "\n", " OUTPUT:\n", " aligned data array with dimensions channels x trials x time\n", - " '''\n", - " \n", + " \"\"\"\n", + "\n", " time_selection = np.concatenate([trial_window + t for t in alignment_times])\n", - " \n", + "\n", " if trial_ids is None:\n", " trial_ids = np.arange(len(alignment_times))\n", - " \n", - " inds = pd.MultiIndex.from_product((trial_ids, trial_window), \n", - " names=('presentation_id', 'time_from_presentation_onset'))\n", "\n", - " ds = lfp.sel(time = time_selection, method='nearest').to_dataset(name = 'aligned_lfp')\n", - " ds = ds.assign(time=inds).unstack('time')\n", + " inds = pd.MultiIndex.from_product(\n", + " (trial_ids, trial_window), names=(\"presentation_id\", \"time_from_presentation_onset\")\n", + " )\n", + "\n", + " ds = lfp.sel(time=time_selection, method=\"nearest\").to_dataset(name=\"aligned_lfp\")\n", + " ds = ds.assign(time=inds).unstack(\"time\")\n", "\n", - " return ds['aligned_lfp']" + " return ds[\"aligned_lfp\"]" ] }, { @@ -1893,7 +1892,7 @@ }, "outputs": [], "source": [ - "aligned_lfp = align_lfp(lfp, np.arange(-0.5, 0.5, 1/500), presentation_times, presentation_ids)" + "aligned_lfp = align_lfp(lfp, np.arange(-0.5, 0.5, 1 / 500), presentation_times, presentation_ids)" ] }, { @@ -1938,10 +1937,10 @@ "outputs": [], "source": [ "chans = session.get_channels()\n", - "lfp_chan_depths = [chans.loc[c]['probe_vertical_position'] for c in lfp.channel.values]\n", + "lfp_chan_depths = [chans.loc[c][\"probe_vertical_position\"] for c in lfp.channel.values]\n", "\n", - "chans_in_brain = chans[(chans['probe_id']==1064735073)&(~chans['structure_acronym'].str.contains('root'))]\n", - "first_channel_in_brain_position = chans_in_brain['probe_vertical_position'].max()" + "chans_in_brain = chans[(chans[\"probe_id\"] == 1064735073) & (~chans[\"structure_acronym\"].str.contains(\"root\"))]\n", + "first_channel_in_brain_position = chans_in_brain[\"probe_vertical_position\"].max()" ] }, { @@ -1988,16 +1987,18 @@ ], "source": [ "fig, ax = plt.subplots()\n", - "fig.suptitle('Flash aligned mean LFP')\n", - "im = ax.pcolor(aligned_lfp.time_from_presentation_onset.values, lfp_chan_depths, aligned_lfp.mean(dim='presentation_id').data)\n", + "fig.suptitle(\"Flash aligned mean LFP\")\n", + "im = ax.pcolor(\n", + " aligned_lfp.time_from_presentation_onset.values, lfp_chan_depths, aligned_lfp.mean(dim=\"presentation_id\").data\n", + ")\n", "_ = plt.colorbar(im, fraction=0.036, pad=0.04)\n", - "_ = plt.xlabel('Time from flash onset (s)')\n", - "_ = plt.ylabel('Channel Position from Tip (um)')\n", + "_ = plt.xlabel(\"Time from flash onset (s)\")\n", + "_ = plt.ylabel(\"Channel Position from Tip (um)\")\n", "\n", - "ax.axvline(0, c='w', ls='dotted')\n", - "ax.axvline(0.25, c='w', ls='dotted')\n", - "ax.axhline(first_channel_in_brain_position, c='w')\n", - "ax.text(-0.4, first_channel_in_brain_position+50, 'brain surface', c='w')" + "ax.axvline(0, c=\"w\", ls=\"dotted\")\n", + "ax.axvline(0.25, c=\"w\", ls=\"dotted\")\n", + "ax.axhline(first_channel_in_brain_position, c=\"w\")\n", + "ax.text(-0.4, first_channel_in_brain_position + 50, \"brain surface\", c=\"w\")" ] }, { @@ -2079,23 +2080,25 @@ }, "outputs": [], "source": [ - "sess_units = session.get_units()\n", + "sess_units = session.get_units()\n", "\n", - "#Grab units whose peak channels are in the LFP data, have relatively low isi violations and high amplitude spikes\n", - "units_on_lfp_chans = sess_units[(sess_units.peak_channel_id.isin(lfp.channel.values)) &\n", - " (sess_units.isi_violations < 0.5) &\n", - " (sess_units.amplitude > 200)]\n", + "# Grab units whose peak channels are in the LFP data, have relatively low isi violations and high amplitude spikes\n", + "units_on_lfp_chans = sess_units[\n", + " (sess_units.peak_channel_id.isin(lfp.channel.values))\n", + " & (sess_units.isi_violations < 0.5)\n", + " & (sess_units.amplitude > 200)\n", + "]\n", "\n", - "#Merge this curated unit table with the channel table to get CCF locations for these units\n", - "units_on_lfp_chans = units_on_lfp_chans.merge(chans, left_on='peak_channel_id', right_index=True)\n", + "# Merge this curated unit table with the channel table to get CCF locations for these units\n", + "units_on_lfp_chans = units_on_lfp_chans.merge(chans, left_on=\"peak_channel_id\", right_index=True)\n", "\n", - "#Select a unit in V1\n", - "v1_units = units_on_lfp_chans[units_on_lfp_chans.structure_acronym.str.contains('VISp')]\n", + "# Select a unit in V1\n", + "v1_units = units_on_lfp_chans[units_on_lfp_chans.structure_acronym.str.contains(\"VISp\")]\n", "unit_id = v1_units.index.values[0]\n", "\n", - "#Get the peak channel ID for this unit (the channel on which it had the greatest spike amplitude)\n", - "peak_chan_id = units_on_lfp_chans.loc[unit_id]['peak_channel_id']\n", - "peak_probe_position = units_on_lfp_chans.loc[unit_id]['probe_vertical_position']" + "# Get the peak channel ID for this unit (the channel on which it had the greatest spike amplitude)\n", + "peak_chan_id = units_on_lfp_chans.loc[unit_id][\"peak_channel_id\"]\n", + "peak_probe_position = units_on_lfp_chans.loc[unit_id][\"probe_vertical_position\"]" ] }, { @@ -2144,8 +2147,8 @@ "\n", "times_in_range = spike_times[(spike_times > start_time) & (spike_times < end_time)]\n", "\n", - "lfp_data = lfp.sel(time = slice(start_time, end_time))\n", - "lfp_data = lfp_data.sel(channel = peak_chan_id, method='nearest')" + "lfp_data = lfp.sel(time=slice(start_time, end_time))\n", + "lfp_data = lfp_data.sel(channel=peak_chan_id, method=\"nearest\")" ] }, { @@ -2187,8 +2190,11 @@ }, "outputs": [], "source": [ - "stims_in_window = stim_presentations[(stim_presentations.start_time>start_time)&(stim_presentations.start_time start_time)\n", + " & (stim_presentations.start_time < end_time)\n", + " & (~stim_presentations.omitted)\n", + "]\n", "stim_times_in_window = stims_in_window.start_time.values" ] }, @@ -2254,14 +2260,13 @@ ], "source": [ "_ = plt.plot(lfp_data.time, lfp_data)\n", - "_ = plt.plot(times_in_range, np.ones(times_in_range.shape)*3e-4, '.r')\n", - "_ = plt.xlabel('Time (s)')\n", - "_ = plt.ylabel('LFP (V)')\n", + "_ = plt.plot(times_in_range, np.ones(times_in_range.shape) * 3e-4, \".r\")\n", + "_ = plt.xlabel(\"Time (s)\")\n", + "_ = plt.ylabel(\"LFP (V)\")\n", "\n", - "_ = plt.plot(stim_times_in_window, np.ones(stim_times_in_window.size)*4e-4, 'vg')\n", + "_ = plt.plot(stim_times_in_window, np.ones(stim_times_in_window.size) * 4e-4, \"vg\")\n", "\n", - "plt.legend(['LFP', 'spikes', 'stim times'])\n", - " " + "plt.legend([\"LFP\", \"spikes\", \"stim times\"])" ] }, { @@ -2320,9 +2325,9 @@ }, "outputs": [], "source": [ - "rng = np.random.default_rng(seed=42) #set seed for deterministic results\n", + "rng = np.random.default_rng(seed=42) # set seed for deterministic results\n", "spikes_to_use = rng.choice(spike_times, min((spike_times.size, 1000)), replace=False)\n", - "spike_triggered_lfp = align_lfp(lfp, np.arange(-0.1, 0.1, 1/1250), spikes_to_use)" + "spike_triggered_lfp = align_lfp(lfp, np.arange(-0.1, 0.1, 1 / 1250), spikes_to_use)" ] }, { @@ -2386,14 +2391,18 @@ ], "source": [ "fig, ax = plt.subplots()\n", - "im = ax.pcolor(spike_triggered_lfp.time_from_presentation_onset.values, lfp_chan_depths, \n", - " spike_triggered_lfp.mean(dim='presentation_id').data, shading='auto')\n", + "im = ax.pcolor(\n", + " spike_triggered_lfp.time_from_presentation_onset.values,\n", + " lfp_chan_depths,\n", + " spike_triggered_lfp.mean(dim=\"presentation_id\").data,\n", + " shading=\"auto\",\n", + ")\n", "\n", - "ax.plot(-0.01, peak_probe_position, '>w')\n", - "ax.text(-0.015, peak_probe_position, 'peak channel', c='w', va='center', ha='right')\n", - "ax.set_ylim([peak_probe_position-300, peak_probe_position+300])\n", - "ax.set_xlabel('Time from spike (s)')\n", - "ax.set_ylabel('Channel depth')" + "ax.plot(-0.01, peak_probe_position, \">w\")\n", + "ax.text(-0.015, peak_probe_position, \"peak channel\", c=\"w\", va=\"center\", ha=\"right\")\n", + "ax.set_ylim([peak_probe_position - 300, peak_probe_position + 300])\n", + "ax.set_xlabel(\"Time from spike (s)\")\n", + "ax.set_ylabel(\"Channel depth\")" ] }, { @@ -2608,7 +2617,7 @@ "\n", ".xr-section-summary-in + label:before {\n", " display: inline-block;\n", - " content: '\u25ba';\n", + " content: '►';\n", " font-size: 11px;\n", " width: 15px;\n", " text-align: center;\n", @@ -2619,7 +2628,7 @@ "}\n", "\n", ".xr-section-summary-in:checked + label:before {\n", - " content: '\u25bc';\n", + " content: '▼';\n", "}\n", "\n", ".xr-section-summary-in:checked + label > span {\n", @@ -3060,9 +3069,9 @@ "source": [ "from scipy.ndimage.filters import gaussian_filter\n", "\n", - "_ = plt.figure(figsize=(10,10))\n", + "_ = plt.figure(figsize=(10, 10))\n", "\n", - "filtered_csd = gaussian_filter(csd.data, sigma=(5,1))\n", + "filtered_csd = gaussian_filter(csd.data, sigma=(5, 1))\n", "\n", "fig, ax = plt.subplots(figsize=(6, 6))\n", "\n", @@ -3072,13 +3081,13 @@ "_ = ax.set_ylabel(\"vertical position (um)\")\n", "\n", "\n", - "chans_in_v1 = chans[(chans['probe_id']==1064735073)&(chans['structure_acronym'].str.contains('VISp'))]\n", - "last_cortex_channel_position = chans_in_v1['probe_vertical_position'].min()\n", + "chans_in_v1 = chans[(chans[\"probe_id\"] == 1064735073) & (chans[\"structure_acronym\"].str.contains(\"VISp\"))]\n", + "last_cortex_channel_position = chans_in_v1[\"probe_vertical_position\"].min()\n", "\n", - "ax.axhline(first_channel_in_brain_position, c='w')\n", - "ax.text(-0.075, first_channel_in_brain_position+50, 'brain surface', c='w')\n", - "ax.axhline(last_cortex_channel_position, c='w')\n", - "ax.text(-0.075, last_cortex_channel_position+50, 'end of cortex', c='w')" + "ax.axhline(first_channel_in_brain_position, c=\"w\")\n", + "ax.text(-0.075, first_channel_in_brain_position + 50, \"brain surface\", c=\"w\")\n", + "ax.axhline(last_cortex_channel_position, c=\"w\")\n", + "ax.text(-0.075, last_cortex_channel_position + 50, \"end of cortex\", c=\"w\")" ] }, { @@ -3131,7 +3140,7 @@ } ], "source": [ - "probes.loc[1064735073]['structure_acronyms']" + "probes.loc[1064735073][\"structure_acronyms\"]" ] } ], diff --git a/doc_template/examples_root/examples/nb/visual_behavior_neuropixels_analyzing_behavior_only_data.ipynb b/doc_template/examples_root/examples/nb/visual_behavior_neuropixels_analyzing_behavior_only_data.ipynb index 75db05cfd5..e0fab6b716 100644 --- a/doc_template/examples_root/examples/nb/visual_behavior_neuropixels_analyzing_behavior_only_data.ipynb +++ b/doc_template/examples_root/examples/nb/visual_behavior_neuropixels_analyzing_behavior_only_data.ipynb @@ -80,9 +80,9 @@ } ], "source": [ - "from allensdk.brain_observatory.behavior.behavior_project_cache.\\\n", - " behavior_neuropixels_project_cache \\\n", - " import VisualBehaviorNeuropixelsProjectCache\n", + "from allensdk.brain_observatory.behavior.behavior_project_cache.behavior_neuropixels_project_cache import (\n", + " VisualBehaviorNeuropixelsProjectCache,\n", + ")\n", "\n", "import numpy as np\n", "from matplotlib import pyplot as plt\n", @@ -179,8 +179,7 @@ } ], "source": [ - "cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(\n", - " cache_dir=output_dir)" + "cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(cache_dir=output_dir)" ] }, { @@ -1584,7 +1583,9 @@ ], "source": [ "mouse_id = 533537\n", - "mouse_session_table = behavior_session_table[behavior_session_table['mouse_id']==mouse_id].sort_values(by='date_of_acquisition')\n", + "mouse_session_table = behavior_session_table[behavior_session_table[\"mouse_id\"] == mouse_id].sort_values(\n", + " by=\"date_of_acquisition\"\n", + ")\n", "mouse_session_table" ] }, @@ -2587,11 +2588,13 @@ }, "outputs": [], "source": [ - "unique_stimuli = [stimulus for stimulus in behavior_session.stimulus_presentations['image_name'].unique()]\n", - "colormap = {image_name: sns.color_palette()[image_number] for image_number, image_name in enumerate(np.sort(unique_stimuli))}\n", + "unique_stimuli = [stimulus for stimulus in behavior_session.stimulus_presentations[\"image_name\"].unique()]\n", + "colormap = {\n", + " image_name: sns.color_palette()[image_number] for image_number, image_name in enumerate(np.sort(unique_stimuli))\n", + "}\n", "\n", - "stim_colors = stimulus_presentations.apply(lambda row: colormap[row['image_name']], axis=1)\n", - "stimulus_presentations['color'] = stim_colors" + "stim_colors = stimulus_presentations.apply(lambda row: colormap[row[\"image_name\"]], axis=1)\n", + "stimulus_presentations[\"color\"] = stim_colors" ] }, { @@ -2634,66 +2637,68 @@ "outputs": [], "source": [ "def plot_running(ax, initial_time, final_time):\n", - " '''\n", + " \"\"\"\n", " a simple function to plot running speed between two specified times on a specified axis\n", " inputs:\n", " ax: axis on which to plot\n", " intial_time: initial time to plot from\n", " final_time: final time to plot to\n", - " '''\n", - " running_sample = behavior_session.running_speed.query('timestamps >= @initial_time and timestamps <= @final_time')\n", - " ax.plot(\n", - " running_sample['timestamps'],\n", - " running_sample['speed']\n", - " )\n", + " \"\"\"\n", + " running_sample = behavior_session.running_speed.query(\"timestamps >= @initial_time and timestamps <= @final_time\")\n", + " ax.plot(running_sample[\"timestamps\"], running_sample[\"speed\"])\n", + "\n", "\n", "def plot_licks(ax, initial_time, final_time):\n", - " '''\n", + " \"\"\"\n", " a simple function to plot licks as dots between two specified times on a specified axis\n", " inputs:\n", " ax: axis on which to plot\n", " intial_time: initial time to plot from\n", " final_time: final time to plot to\n", - " '''\n", - " licking_sample = behavior_session.licks.query('timestamps >= @initial_time and timestamps <= @final_time')\n", + " \"\"\"\n", + " licking_sample = behavior_session.licks.query(\"timestamps >= @initial_time and timestamps <= @final_time\")\n", " ax.plot(\n", - " licking_sample['timestamps'],\n", - " np.zeros_like(licking_sample['timestamps']),\n", - " marker = 'o',\n", - " color = 'black',\n", - " linestyle = 'none'\n", + " licking_sample[\"timestamps\"],\n", + " np.zeros_like(licking_sample[\"timestamps\"]),\n", + " marker=\"o\",\n", + " color=\"black\",\n", + " linestyle=\"none\",\n", " )\n", - " \n", + "\n", + "\n", "def plot_rewards(ax, initial_time, final_time):\n", - " '''\n", + " \"\"\"\n", " a simple function to plot rewards between two specified times as blue diamonds on a specified axis\n", " inputs:\n", " ax: axis on which to plot\n", " intial_time: initial time to plot from\n", " final_time: final time to plot to\n", - " '''\n", - " rewards_sample = behavior_session.rewards.query('timestamps >= @initial_time and timestamps <= @final_time')\n", + " \"\"\"\n", + " rewards_sample = behavior_session.rewards.query(\"timestamps >= @initial_time and timestamps <= @final_time\")\n", " ax.plot(\n", - " rewards_sample['timestamps'],\n", - " np.zeros_like(rewards_sample['timestamps']),\n", - " marker = 'd',\n", - " color = 'blue',\n", - " linestyle = 'none',\n", - " markersize = 12,\n", - " alpha = 0.5\n", + " rewards_sample[\"timestamps\"],\n", + " np.zeros_like(rewards_sample[\"timestamps\"]),\n", + " marker=\"d\",\n", + " color=\"blue\",\n", + " linestyle=\"none\",\n", + " markersize=12,\n", + " alpha=0.5,\n", " )\n", - " \n", + "\n", + "\n", "def plot_stimuli(ax, ti, tf):\n", - " '''\n", + " \"\"\"\n", " a simple function to plot stimuli as colored vertical spans on a s\n", " inputs:\n", " ax: axis on which to plot\n", " intial_time: initial time to plot from\n", " final_time: final time to plot to\n", - " '''\n", - " stimulus_presentations_sample = stimulus_presentations.query('end_time >= @initial_time and start_time <= @final_time')\n", + " \"\"\"\n", + " stimulus_presentations_sample = stimulus_presentations.query(\n", + " \"end_time >= @initial_time and start_time <= @final_time\"\n", + " )\n", " for idx, stimulus in stimulus_presentations_sample.iterrows():\n", - " ax.axvspan(stimulus['start_time'], stimulus['end_time'], color=stimulus['color'], alpha=0.25)" + " ax.axvspan(stimulus[\"start_time\"], stimulus[\"end_time\"], color=stimulus[\"color\"], alpha=0.25)" ] }, { @@ -2756,22 +2761,22 @@ } ], "source": [ - "initial_time = 775 # initial time for plot, in seconds\n", - "final_time = 800 # final time for plot, in seconds\n", + "initial_time = 775 # initial time for plot, in seconds\n", + "final_time = 800 # final time for plot, in seconds\n", "\n", "plt.clf()\n", - "fig, ax = plt.subplots(figsize = (15,5))\n", + "fig, ax = plt.subplots(figsize=(15, 5))\n", "plot_running(ax, initial_time, final_time)\n", "plot_licks(ax, initial_time, final_time)\n", "plot_rewards(ax, initial_time, final_time)\n", "plot_stimuli(ax, initial_time, final_time)\n", "\n", - "ax.legend(['running speed', 'licks', 'rewards'])\n", + "ax.legend([\"running speed\", \"licks\", \"rewards\"])\n", "\n", - "ax.set_ylabel('running speed (cm/s)')\n", - "ax.set_xlabel('time in session (s)')\n", + "ax.set_ylabel(\"running speed (cm/s)\")\n", + "ax.set_xlabel(\"time in session (s)\")\n", "ax.set_xlim(initial_time, final_time)\n", - "ax.set_title('a short section of the session');" + "ax.set_title(\"a short section of the session\");" ] }, { @@ -3354,8 +3359,8 @@ } ], "source": [ - "hit_miss_table = trials.pivot_table(index='change_image_name', values=['hit', 'miss'], aggfunc=sum)\n", - "hit_miss_table['hit_rate'] = hit_miss_table['hit']/(hit_miss_table['hit'] + hit_miss_table['miss'])\n", + "hit_miss_table = trials.pivot_table(index=\"change_image_name\", values=[\"hit\", \"miss\"], aggfunc=sum)\n", + "hit_miss_table[\"hit_rate\"] = hit_miss_table[\"hit\"] / (hit_miss_table[\"hit\"] + hit_miss_table[\"miss\"])\n", "hit_miss_table" ] }, @@ -3424,9 +3429,9 @@ "\n", "rolling_performance = behavior_session.get_rolling_performance_df()\n", "\n", - "ax.plot(rolling_performance['reward_rate'])\n", - "ax.set_ylabel('Rewards per minute')\n", - "ax.set_xlabel('Trials')" + "ax.plot(rolling_performance[\"reward_rate\"])\n", + "ax.set_ylabel(\"Rewards per minute\")\n", + "ax.set_xlabel(\"Trials\")" ] }, { @@ -3571,10 +3576,12 @@ } ], "source": [ - "engaged_trials = trials[rolling_performance['reward_rate']>2]\n", + "engaged_trials = trials[rolling_performance[\"reward_rate\"] > 2]\n", "\n", - "hit_miss_table_engaged = engaged_trials.pivot_table(index='change_image_name', values=['hit', 'miss'], aggfunc=sum)\n", - "hit_miss_table_engaged['hit_rate'] = hit_miss_table_engaged['hit']/(hit_miss_table_engaged['hit'] + hit_miss_table_engaged['miss'])\n", + "hit_miss_table_engaged = engaged_trials.pivot_table(index=\"change_image_name\", values=[\"hit\", \"miss\"], aggfunc=sum)\n", + "hit_miss_table_engaged[\"hit_rate\"] = hit_miss_table_engaged[\"hit\"] / (\n", + " hit_miss_table_engaged[\"hit\"] + hit_miss_table_engaged[\"miss\"]\n", + ")\n", "hit_miss_table_engaged" ] }, diff --git a/doc_template/examples_root/examples/nb/visual_behavior_neuropixels_data_access.ipynb b/doc_template/examples_root/examples/nb/visual_behavior_neuropixels_data_access.ipynb index 8d6127f0fa..0e03c1c885 100644 --- a/doc_template/examples_root/examples/nb/visual_behavior_neuropixels_data_access.ipynb +++ b/doc_template/examples_root/examples/nb/visual_behavior_neuropixels_data_access.ipynb @@ -1045,7 +1045,7 @@ } ], "source": [ - "cache.load_manifest('visual-behavior-neuropixels_project_manifest_v0.2.0.json')" + "cache.load_manifest(\"visual-behavior-neuropixels_project_manifest_v0.2.0.json\")" ] }, { @@ -1148,8 +1148,10 @@ "source": [ "# This cell will not be useful until an updated version of the data release is issued\n", "\n", - "msg = cache.compare_manifests('visual-behavior-neuropixels_project_manifest_v0.1.0.json',\n", - " 'visual-behavior-neuropixels_project_manifest_v0.2.0.json')\n", + "msg = cache.compare_manifests(\n", + " \"visual-behavior-neuropixels_project_manifest_v0.1.0.json\",\n", + " \"visual-behavior-neuropixels_project_manifest_v0.2.0.json\",\n", + ")\n", "print(msg)" ] }, @@ -3027,7 +3029,7 @@ ], "source": [ "# Visualizing a particular stimulus\n", - "plt.imshow(ecephys_session.stimulus_templates['warped']['im104_r'], cmap='gray')" + "plt.imshow(ecephys_session.stimulus_templates[\"warped\"][\"im104_r\"], cmap=\"gray\")" ] }, { @@ -3106,7 +3108,7 @@ "\n", "if DOWNLOAD_COMPLETE_DATASET:\n", " for ecephys_session_id, _ in filtered_ecephys_sessions.iterrows():\n", - " cache.get_ecephys_session(ecephys_session_id=ecephys_session_id)\n" + " cache.get_ecephys_session(ecephys_session_id=ecephys_session_id)" ] }, { @@ -3204,11 +3206,15 @@ "source": [ "from urllib.parse import urljoin\n", "\n", + "\n", "def get_manifest_url(manifest_version: str) -> str:\n", " hostname = \"https://visual-behavior-neuropixels-data.s3.us-west-2.amazonaws.com\"\n", - " object_key = f\"visual-behavior-neuropixels/manifests/visual-behavior-neuropixels_project_manifest_v{manifest_version}.json\"\n", + " object_key = (\n", + " f\"visual-behavior-neuropixels/manifests/visual-behavior-neuropixels_project_manifest_v{manifest_version}.json\"\n", + " )\n", " return urljoin(hostname, object_key)\n", "\n", + "\n", "# Example:\n", "print(get_manifest_url(\"0.1.0\"))" ] @@ -3251,6 +3257,7 @@ " object_key = f\"visual-behavior-neuropixels/project_metadata/{metadata_table_name}.csv\"\n", " return urljoin(hostname, object_key)\n", "\n", + "\n", "# Example:\n", "print(get_metadata_url(\"behavior_sessions\"))" ] @@ -3293,6 +3300,7 @@ " object_key = f\"visual-behavior-neuropixels/ecephys_sessions/ecephys_session_{ecephys_session_id}.nwb\"\n", " return urljoin(hostname, object_key)\n", "\n", + "\n", "# Example:\n", "print(get_behavior_session_url(1052533639))" ] @@ -7695,12 +7703,13 @@ "# The location will differ based on where you downloaded the manifest.json!\n", "my_manifest_location = output_dir / cache.current_manifest()\n", "\n", + "\n", "def generate_all_download_urls_from_manifest(manifest_path: Path) -> List[str]:\n", - " with manifest_path.open('r') as fp:\n", + " with manifest_path.open(\"r\") as fp:\n", " manifest = json.load(fp)\n", - " \n", + "\n", " download_links = []\n", - " \n", + "\n", " # Get download links for specific version of metadata files\n", " for metadata_file_entry in manifest[\"metadata_files\"].values():\n", " base_download_url = metadata_file_entry[\"url\"]\n", @@ -7713,12 +7722,13 @@ " base_download_url = data_file_entry[\"url\"]\n", " version_query = f\"?versionId={data_file_entry['version_id']}\"\n", " full_download_url = urljoin(base_download_url, version_query)\n", - " download_links.append(full_download_url) \n", + " download_links.append(full_download_url)\n", "\n", " return download_links\n", "\n", + "\n", "# Example:\n", - "print('\\n'.join(generate_all_download_urls_from_manifest(my_manifest_location)))" + "print(\"\\n\".join(generate_all_download_urls_from_manifest(my_manifest_location)))" ] } ], diff --git a/doc_template/examples_root/examples/nb/visual_behavior_neuropixels_dataset_manifest.ipynb b/doc_template/examples_root/examples/nb/visual_behavior_neuropixels_dataset_manifest.ipynb index 5125241f7a..44d8e72ef6 100644 --- a/doc_template/examples_root/examples/nb/visual_behavior_neuropixels_dataset_manifest.ipynb +++ b/doc_template/examples_root/examples/nb/visual_behavior_neuropixels_dataset_manifest.ipynb @@ -434,9 +434,9 @@ "import numpy as np\n", "from pathlib import Path\n", "\n", - "from allensdk.brain_observatory.behavior.behavior_project_cache.\\\n", - " behavior_neuropixels_project_cache \\\n", - " import VisualBehaviorNeuropixelsProjectCache" + "from allensdk.brain_observatory.behavior.behavior_project_cache.behavior_neuropixels_project_cache import (\n", + " VisualBehaviorNeuropixelsProjectCache,\n", + ")" ] }, { @@ -539,8 +539,7 @@ } ], "source": [ - "cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(\n", - " cache_dir=Path(output_dir))" + "cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(cache_dir=Path(output_dir))" ] }, { @@ -568,8 +567,7 @@ }, "outputs": [], "source": [ - "cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(\n", - " cache_dir=Path(output_dir))" + "cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(cache_dir=Path(output_dir))" ] }, { @@ -967,9 +965,9 @@ } ], "source": [ - "sessions_by_imageset_experience_day = ecephys_sessions_table.pivot_table(index=['session_number', 'experience_level'], \n", - " columns=['image_set'], \n", - " values='behavior_session_id', aggfunc=len)\n", + "sessions_by_imageset_experience_day = ecephys_sessions_table.pivot_table(\n", + " index=[\"session_number\", \"experience_level\"], columns=[\"image_set\"], values=\"behavior_session_id\", aggfunc=len\n", + ")\n", "display(sessions_by_imageset_experience_day)" ] }, @@ -1063,7 +1061,7 @@ } ], "source": [ - "print('the different transgenic lines included in this dataset are:\\n')\n", + "print(\"the different transgenic lines included in this dataset are:\\n\")\n", "print(np.sort(ecephys_sessions_table.genotype.unique()))" ] }, @@ -1154,10 +1152,11 @@ } ], "source": [ - "#Number of sessions per genotype/sex\n", - "sessions_by_genotype_sex = ecephys_sessions_table.pivot_table(values='session_number', index='genotype', \n", - " columns='sex', aggfunc=len)\n", - "display(sessions_by_genotype_sex.rename(columns={'session_number': 'session_count'}))" + "# Number of sessions per genotype/sex\n", + "sessions_by_genotype_sex = ecephys_sessions_table.pivot_table(\n", + " values=\"session_number\", index=\"genotype\", columns=\"sex\", aggfunc=len\n", + ")\n", + "display(sessions_by_genotype_sex.rename(columns={\"session_number\": \"session_count\"}))" ] }, { @@ -1267,9 +1266,10 @@ } ], "source": [ - "#Number of mice per genotype/sex\n", - "mice_by_genotype_sex = ecephys_sessions_table.pivot_table(values='mouse_id', index='genotype', \n", - " columns='sex', aggfunc=lambda x: len(np.unique(x)))\n", + "# Number of mice per genotype/sex\n", + "mice_by_genotype_sex = ecephys_sessions_table.pivot_table(\n", + " values=\"mouse_id\", index=\"genotype\", columns=\"sex\", aggfunc=lambda x: len(np.unique(x))\n", + ")\n", "display(mice_by_genotype_sex)" ] }, @@ -1328,8 +1328,8 @@ ], "source": [ "ecephys_sessions_no_filter = cache.get_ecephys_session_table(filter_abnormalities=False)\n", - "print(f'Number sessions returned by default: {len(ecephys_sessions_table)}')\n", - "print(f'Number of sessions returned without filtering abnormalities: {len(ecephys_sessions_no_filter)}')\n" + "print(f\"Number sessions returned by default: {len(ecephys_sessions_table)}\")\n", + "print(f\"Number of sessions returned without filtering abnormalities: {len(ecephys_sessions_no_filter)}\")" ] }, { @@ -1394,9 +1394,10 @@ ], "source": [ "# grab a session that was flagged for both tissue damage and epileptiform activity\n", - "ecephys_sessions_no_filter[['abnormal_histology', 'abnormal_activity']]\\\n", - " [~ecephys_sessions_no_filter['abnormal_histology'].isnull() & \n", - " ~ecephys_sessions_no_filter['abnormal_activity'].isnull()].iloc[0]" + "ecephys_sessions_no_filter[[\"abnormal_histology\", \"abnormal_activity\"]][\n", + " ~ecephys_sessions_no_filter[\"abnormal_histology\"].isnull()\n", + " & ~ecephys_sessions_no_filter[\"abnormal_activity\"].isnull()\n", + "].iloc[0]" ] }, { @@ -1864,7 +1865,7 @@ } ], "source": [ - "print('behavior data could be recorded on these experimental systems:\\n')\n", + "print(\"behavior data could be recorded on these experimental systems:\\n\")\n", "print(np.sort(behavior_sessions.equipment_name.unique()))" ] }, @@ -1941,7 +1942,7 @@ } ], "source": [ - "print('there are', len(behavior_sessions.mouse_id.unique()), 'mice in the dataset')" + "print(\"there are\", len(behavior_sessions.mouse_id.unique()), \"mice in the dataset\")" ] }, { @@ -2057,7 +2058,7 @@ } ], "source": [ - "print('the session_types available in this dataset are:\\n')\n", + "print(\"the session_types available in this dataset are:\\n\")\n", "print(np.sort(behavior_sessions.session_type.unique()))" ] }, @@ -2578,9 +2579,17 @@ } ], "source": [ - "training_history = behavior_sessions[behavior_sessions['mouse_id']==556016]\n", - "training_history = training_history.sort_values(by='date_of_acquisition')\n", - "training_history[['session_type', 'equipment_name', 'date_of_acquisition', 'prior_exposures_to_image_set', 'prior_exposures_to_omissions']]" + "training_history = behavior_sessions[behavior_sessions[\"mouse_id\"] == 556016]\n", + "training_history = training_history.sort_values(by=\"date_of_acquisition\")\n", + "training_history[\n", + " [\n", + " \"session_type\",\n", + " \"equipment_name\",\n", + " \"date_of_acquisition\",\n", + " \"prior_exposures_to_image_set\",\n", + " \"prior_exposures_to_omissions\",\n", + " ]\n", + "]" ] }, { @@ -2939,7 +2948,7 @@ ], "source": [ "units = cache.get_unit_table()\n", - "print(f'This dataset contains {len(units)} total units')\n", + "print(f\"This dataset contains {len(units)} total units\")\n", "\n", "units.head()" ] @@ -3297,9 +3306,9 @@ } ], "source": [ - "#grab the ecephys session id for one experiment; these session ids are the indices of the ecephys_sessions_table\n", + "# grab the ecephys session id for one experiment; these session ids are the indices of the ecephys_sessions_table\n", "session_id = ecephys_sessions_table.index.values[1]\n", - "session_units = units[units['ecephys_session_id']==session_id]\n", + "session_units = units[units[\"ecephys_session_id\"] == session_id]\n", "session_units.head()" ] }, @@ -3338,7 +3347,7 @@ "source": [ "# Looks like we inserted all 6 probes during this experiment\n", "session_probes_from_units_table = np.sort(session_units.ecephys_probe_id.unique())\n", - "print(f'We recorded from {len(session_probes_from_units_table)} probes this session')" + "print(f\"We recorded from {len(session_probes_from_units_table)} probes this session\")" ] }, { @@ -3398,8 +3407,8 @@ ], "source": [ "probes = cache.get_probe_table()\n", - "session_probes = probes[probes.ecephys_session_id==session_id].index.values\n", - "np.all(session_probes_from_units_table==session_probes)" + "session_probes = probes[probes.ecephys_session_id == session_id].index.values\n", + "np.all(session_probes_from_units_table == session_probes)" ] }, { @@ -3916,8 +3925,8 @@ }, "outputs": [], "source": [ - "#first let's merge the units and channels tables\n", - "session_units_channels = session_units.merge(channels, left_on='ecephys_channel_id', right_index=True)" + "# first let's merge the units and channels tables\n", + "session_units_channels = session_units.merge(channels, left_on=\"ecephys_channel_id\", right_index=True)" ] }, { @@ -3979,23 +3988,26 @@ "from matplotlib import pyplot as plt\n", "\n", "fig = plt.figure()\n", - "fig.set_size_inches([14,8])\n", - "ax = fig.add_subplot(111, projection='3d')\n", + "fig.set_size_inches([14, 8])\n", + "ax = fig.add_subplot(111, projection=\"3d\")\n", + "\n", + "\n", "def plot_probe_coords(probe_group):\n", - " ax.scatter(probe_group['left_right_ccf_coordinate_x'],\n", - " probe_group['anterior_posterior_ccf_coordinate_x'],\n", - " -probe_group['dorsal_ventral_ccf_coordinate_x'], #reverse the z coord so that down is into the brain\n", - " )\n", - " return probe_group['ecephys_probe_id_x'].values[0]\n", + " ax.scatter(\n", + " probe_group[\"left_right_ccf_coordinate_x\"],\n", + " probe_group[\"anterior_posterior_ccf_coordinate_x\"],\n", + " -probe_group[\"dorsal_ventral_ccf_coordinate_x\"], # reverse the z coord so that down is into the brain\n", + " )\n", + " return probe_group[\"ecephys_probe_id_x\"].values[0]\n", + "\n", "\n", - "probe_ids = session_units_channels.groupby('ecephys_probe_id_x').apply(plot_probe_coords)\n", + "probe_ids = session_units_channels.groupby(\"ecephys_probe_id_x\").apply(plot_probe_coords)\n", "\n", - "ax.set_zlabel('D/V')\n", - "ax.set_xlabel('Left/Right')\n", - "ax.set_ylabel('A/P')\n", + "ax.set_zlabel(\"D/V\")\n", + "ax.set_xlabel(\"Left/Right\")\n", + "ax.set_ylabel(\"A/P\")\n", "ax.legend(probe_ids)\n", - "ax.view_init(elev=55, azim=70)\n", - "\n" + "ax.view_init(elev=55, azim=70)" ] }, { diff --git a/doc_template/examples_root/examples/nb/visual_behavior_neuropixels_quality_metrics.ipynb b/doc_template/examples_root/examples/nb/visual_behavior_neuropixels_quality_metrics.ipynb index ee89e205c8..ae31a77cf4 100644 --- a/doc_template/examples_root/examples/nb/visual_behavior_neuropixels_quality_metrics.ipynb +++ b/doc_template/examples_root/examples/nb/visual_behavior_neuropixels_quality_metrics.ipynb @@ -150,6 +150,7 @@ "import numpy as np\n", "from pathlib import Path\n", "import matplotlib.pyplot as plt\n", + "\n", "%matplotlib inline\n", "\n", "from allensdk.brain_observatory.behavior.behavior_project_cache import VisualBehaviorNeuropixelsProjectCache" @@ -311,10 +312,11 @@ "outputs": [], "source": [ "from scipy.ndimage import gaussian_filter1d\n", - "plt.rcParams.update({'font.size': 14})\n", + "\n", + "plt.rcParams.update({\"font.size\": 14})\n", + "\n", "\n", "def plot_metric(data, bins, x_axis_label, color, max_value=-1):\n", - " \n", " h, b = np.histogram(data, bins=bins, density=True)\n", "\n", " x = b[:-1]\n", @@ -323,11 +325,11 @@ " plt.plot(x, y, color=color)\n", " plt.xlabel(x_axis_label)\n", " plt.gca().get_yaxis().set_visible(False)\n", - " [plt.gca().spines[loc].set_visible(False) for loc in ['right', 'top', 'left']]\n", + " [plt.gca().spines[loc].set_visible(False) for loc in [\"right\", \"top\", \"left\"]]\n", " if max_value < np.max(y) * 1.1:\n", " max_value = np.max(y) * 1.1\n", " plt.ylim([0, max_value])\n", - " \n", + "\n", " return max_value" ] }, @@ -398,10 +400,10 @@ } ], "source": [ - "data = units['firing_rate']\n", - "bins = np.linspace(0,50,100)\n", + "data = units[\"firing_rate\"]\n", + "bins = np.linspace(0, 50, 100)\n", "\n", - "max_value = plot_metric(data, bins, 'Firing rate (Hz)', 'red')" + "max_value = plot_metric(data, bins, \"Firing rate (Hz)\", \"red\")" ] }, { @@ -454,10 +456,10 @@ } ], "source": [ - "data = np.log10(units['firing_rate'])\n", - "bins = np.linspace(-3,2,100)\n", + "data = np.log10(units[\"firing_rate\"])\n", + "bins = np.linspace(-3, 2, 100)\n", "\n", - "max_value = plot_metric(data, bins, 'log$_{10}$ firing rate (Hz)', 'red')" + "max_value = plot_metric(data, bins, \"log$_{10}$ firing rate (Hz)\", \"red\")" ] }, { @@ -510,10 +512,10 @@ } ], "source": [ - "data = np.log10(units[units.nn_hit_rate > 0.9]['firing_rate'])\n", - "bins = np.linspace(-3,2,100)\n", + "data = np.log10(units[units.nn_hit_rate > 0.9][\"firing_rate\"])\n", + "bins = np.linspace(-3, 2, 100)\n", "\n", - "max_value = plot_metric(data, bins, 'log$_{10}$ firing rate (Hz)', 'red')" + "max_value = plot_metric(data, bins, \"log$_{10}$ firing rate (Hz)\", \"red\")" ] }, { @@ -566,26 +568,47 @@ } ], "source": [ - "region_dict = {'cortex' : ['VISp', 'VISl', 'VISrl', 'VISam', 'VISpm', 'VIS', 'VISal','VISmma','VISmmp','VISli'],\n", - " 'thalamus' : ['LGd','LD', 'LP', 'VPM', 'TH', 'MGm','MGv','MGd','PO','LGv','VL',\n", - " 'VPL','POL','Eth','PoT','PP','PIL','IntG','IGL','SGN','VPL','PF','RT'],\n", - " 'hippocampus' : ['CA1', 'CA2','CA3', 'DG', 'SUB', 'POST','PRE','ProS','HPF'],\n", - " 'midbrain': ['MB','SCig','SCiw','SCsg','SCzo','PPT','APN','NOT','MRN','OP','LT','RPF','CP']}\n", - "\n", - "color_dict = {'cortex' : '#08858C',\n", - " 'thalamus' : '#FC6B6F',\n", - " 'hippocampus' : '#7ED04B',\n", - " 'midbrain' : '#FC9DFE'}\n", - "\n", - "bins = np.linspace(-3,2,100)\n", + "region_dict = {\n", + " \"cortex\": [\"VISp\", \"VISl\", \"VISrl\", \"VISam\", \"VISpm\", \"VIS\", \"VISal\", \"VISmma\", \"VISmmp\", \"VISli\"],\n", + " \"thalamus\": [\n", + " \"LGd\",\n", + " \"LD\",\n", + " \"LP\",\n", + " \"VPM\",\n", + " \"TH\",\n", + " \"MGm\",\n", + " \"MGv\",\n", + " \"MGd\",\n", + " \"PO\",\n", + " \"LGv\",\n", + " \"VL\",\n", + " \"VPL\",\n", + " \"POL\",\n", + " \"Eth\",\n", + " \"PoT\",\n", + " \"PP\",\n", + " \"PIL\",\n", + " \"IntG\",\n", + " \"IGL\",\n", + " \"SGN\",\n", + " \"VPL\",\n", + " \"PF\",\n", + " \"RT\",\n", + " ],\n", + " \"hippocampus\": [\"CA1\", \"CA2\", \"CA3\", \"DG\", \"SUB\", \"POST\", \"PRE\", \"ProS\", \"HPF\"],\n", + " \"midbrain\": [\"MB\", \"SCig\", \"SCiw\", \"SCsg\", \"SCzo\", \"PPT\", \"APN\", \"NOT\", \"MRN\", \"OP\", \"LT\", \"RPF\", \"CP\"],\n", + "}\n", + "\n", + "color_dict = {\"cortex\": \"#08858C\", \"thalamus\": \"#FC6B6F\", \"hippocampus\": \"#7ED04B\", \"midbrain\": \"#FC9DFE\"}\n", + "\n", + "bins = np.linspace(-3, 2, 100)\n", "max_value = -np.inf\n", "\n", "for idx, region in enumerate(region_dict.keys()):\n", - " \n", - " data = np.log10(units[units.structure_acronym.isin(region_dict[region])]['firing_rate'])\n", - " \n", - " max_value = plot_metric(data, bins, 'log$_{10}$ firing rate (Hz)', color_dict[region], max_value)\n", - " \n", + " data = np.log10(units[units.structure_acronym.isin(region_dict[region])][\"firing_rate\"])\n", + "\n", + " max_value = plot_metric(data, bins, \"log$_{10}$ firing rate (Hz)\", color_dict[region], max_value)\n", + "\n", "_ = plt.legend(region_dict.keys())" ] }, @@ -697,18 +720,17 @@ } ], "source": [ - "bins = np.linspace(0,1,100)\n", + "bins = np.linspace(0, 1, 100)\n", "max_value = -np.inf\n", "\n", "for idx, region in enumerate(region_dict.keys()):\n", - " \n", - " data = units[units.structure_acronym.isin(region_dict[region])]['presence_ratio']\n", - " \n", - " max_value = plot_metric(data, bins, 'Presence ratio', color_dict[region], max_value)\n", - " \n", + " data = units[units.structure_acronym.isin(region_dict[region])][\"presence_ratio\"]\n", + "\n", + " max_value = plot_metric(data, bins, \"Presence ratio\", color_dict[region], max_value)\n", + "\n", "_ = plt.legend(region_dict.keys())\n", "\n", - "plt.plot([0.9, 0.9],[0,max_value], ':')" + "plt.plot([0.9, 0.9], [0, max_value], \":\")" ] }, { @@ -871,18 +893,17 @@ } ], "source": [ - "bins = np.linspace(0,0.5,200)\n", + "bins = np.linspace(0, 0.5, 200)\n", "max_value = -np.inf\n", "\n", "for idx, region in enumerate(region_dict.keys()):\n", - " \n", - " data = units[units.structure_acronym.isin(region_dict[region])]['amplitude_cutoff']\n", - " \n", - " max_value = plot_metric(data, bins, 'Amplitude cutoff', color_dict[region], max_value)\n", - " \n", + " data = units[units.structure_acronym.isin(region_dict[region])][\"amplitude_cutoff\"]\n", + "\n", + " max_value = plot_metric(data, bins, \"Amplitude cutoff\", color_dict[region], max_value)\n", + "\n", "_ = plt.legend(region_dict.keys())\n", "\n", - "plt.plot([0.1, 0.1],[0,max_value], ':')" + "plt.plot([0.1, 0.1], [0, max_value], \":\")" ] }, { @@ -1047,18 +1068,17 @@ } ], "source": [ - "bins = np.linspace(0,10,200)\n", + "bins = np.linspace(0, 10, 200)\n", "max_value = -np.inf\n", "\n", "for idx, region in enumerate(region_dict.keys()):\n", - " \n", - " data = units[units.structure_acronym.isin(region_dict[region])]['isi_violations']\n", - " \n", - " max_value = plot_metric(data, bins, 'ISI violations', color_dict[region], max_value)\n", - " \n", + " data = units[units.structure_acronym.isin(region_dict[region])][\"isi_violations\"]\n", + "\n", + " max_value = plot_metric(data, bins, \"ISI violations\", color_dict[region], max_value)\n", + "\n", "_ = plt.legend(region_dict.keys())\n", "\n", - "plt.plot([0.5, 0.5],[0,max_value], ':')" + "plt.plot([0.5, 0.5], [0, max_value], \":\")" ] }, { @@ -1121,18 +1141,17 @@ } ], "source": [ - "bins = np.linspace(-6,2.5,100)\n", + "bins = np.linspace(-6, 2.5, 100)\n", "max_value = -np.inf\n", "\n", "for idx, region in enumerate(region_dict.keys()):\n", - " \n", - " data = np.log10(units[units.structure_acronym.isin(region_dict[region])]['isi_violations'] + 1e-5) \n", - " \n", - " max_value = plot_metric(data, bins, '$log_{10}$ ISI violations', color_dict[region], max_value)\n", - " \n", + " data = np.log10(units[units.structure_acronym.isin(region_dict[region])][\"isi_violations\"] + 1e-5)\n", + "\n", + " max_value = plot_metric(data, bins, \"$log_{10}$ ISI violations\", color_dict[region], max_value)\n", + "\n", "_ = plt.legend(region_dict.keys())\n", "\n", - "plt.plot([np.log10(0.5), np.log10(0.5)],[0,max_value], ':')" + "plt.plot([np.log10(0.5), np.log10(0.5)], [0, max_value], \":\")" ] }, { @@ -1290,15 +1309,14 @@ } ], "source": [ - "bins = np.linspace(0,10,100)\n", + "bins = np.linspace(0, 10, 100)\n", "max_value = -np.inf\n", "\n", "for idx, region in enumerate(region_dict.keys()):\n", - " \n", - " data = units[units.structure_acronym.isin(region_dict[region])]['snr']\n", - " \n", - " max_value = plot_metric(data, bins, 'SNR', color_dict[region], max_value)\n", - " \n", + " data = units[units.structure_acronym.isin(region_dict[region])][\"snr\"]\n", + "\n", + " max_value = plot_metric(data, bins, \"SNR\", color_dict[region], max_value)\n", + "\n", "_ = plt.legend(region_dict.keys())" ] }, @@ -1417,15 +1435,14 @@ } ], "source": [ - "bins = np.linspace(0,170,50)\n", + "bins = np.linspace(0, 170, 50)\n", "max_value = -np.inf\n", "\n", "for idx, region in enumerate(region_dict.keys()):\n", - " \n", - " data = units[units.structure_acronym.isin(region_dict[region])]['isolation_distance']\n", - " \n", - " max_value = plot_metric(data, bins, 'Isolation distance', color_dict[region], max_value)\n", - " \n", + " data = units[units.structure_acronym.isin(region_dict[region])][\"isolation_distance\"]\n", + "\n", + " max_value = plot_metric(data, bins, \"Isolation distance\", color_dict[region], max_value)\n", + "\n", "_ = plt.legend(region_dict.keys())" ] }, @@ -1520,15 +1537,14 @@ } ], "source": [ - "bins = np.linspace(0,15,50)\n", + "bins = np.linspace(0, 15, 50)\n", "max_value = -np.inf\n", "\n", "for idx, region in enumerate(region_dict.keys()):\n", - " \n", - " data = units[units.structure_acronym.isin(region_dict[region])]['d_prime']\n", - " \n", - " max_value = plot_metric(data, bins, 'd-prime', color_dict[region], max_value)\n", - " \n", + " data = units[units.structure_acronym.isin(region_dict[region])][\"d_prime\"]\n", + "\n", + " max_value = plot_metric(data, bins, \"d-prime\", color_dict[region], max_value)\n", + "\n", "_ = plt.legend(region_dict.keys())" ] }, @@ -1622,15 +1638,14 @@ } ], "source": [ - "bins = np.linspace(0,1,100)\n", + "bins = np.linspace(0, 1, 100)\n", "max_value = -np.inf\n", "\n", "for idx, region in enumerate(region_dict.keys()):\n", - " \n", - " data = units[units.structure_acronym.isin(region_dict[region])]['nn_hit_rate']\n", - " \n", - " max_value = plot_metric(data, bins, 'Nearest-neighbors hit rate', color_dict[region], max_value)\n", - " \n", + " data = units[units.structure_acronym.isin(region_dict[region])][\"nn_hit_rate\"]\n", + "\n", + " max_value = plot_metric(data, bins, \"Nearest-neighbors hit rate\", color_dict[region], max_value)\n", + "\n", "_ = plt.legend(region_dict.keys())" ] }, @@ -1724,39 +1739,33 @@ } ], "source": [ - "metrics = ['firing_rate', \n", - " 'presence_ratio', \n", - " 'amplitude_cutoff', \n", - " 'isi_violations', \n", - " 'snr', \n", - " 'isolation_distance', \n", - " 'd_prime', \n", - " 'nn_hit_rate']\n", - "\n", - "ranges = [[0,20],\n", - " [0.9,0.995],\n", - " [0,0.5],\n", - " [0,2],\n", - " [0,8],\n", - " [0,125],\n", - " [0,10],\n", - " [0,1]]\n", - "\n", - "_ = plt.figure(figsize=(5,10))\n", + "metrics = [\n", + " \"firing_rate\",\n", + " \"presence_ratio\",\n", + " \"amplitude_cutoff\",\n", + " \"isi_violations\",\n", + " \"snr\",\n", + " \"isolation_distance\",\n", + " \"d_prime\",\n", + " \"nn_hit_rate\",\n", + "]\n", + "\n", + "ranges = [[0, 20], [0.9, 0.995], [0, 0.5], [0, 2], [0, 8], [0, 125], [0, 10], [0, 1]]\n", + "\n", + "_ = plt.figure(figsize=(5, 10))\n", "\n", "for idx, metric in enumerate(metrics):\n", - " \n", " data = units[metric].values\n", " data = data[np.invert(np.isnan(data))]\n", "\n", - " _ = plt.subplot(len(metrics),1,idx+1)\n", + " _ = plt.subplot(len(metrics), 1, idx + 1)\n", " _ = plt.boxplot(data, showfliers=False, showcaps=False, vert=False)\n", - " _ = plt.ylim([0.8,1.2])\n", + " _ = plt.ylim([0.8, 1.2])\n", " _ = plt.xlim(ranges[idx])\n", " _ = plt.yticks([])\n", - " \n", + "\n", " plt.title(metric)\n", - " \n", + "\n", "plt.tight_layout()" ] }, @@ -1834,9 +1843,7 @@ } ], "source": [ - "units_filt1 = units[(units.isi_violations<0.5) \n", - " & (units.amplitude_cutoff<0.1) \n", - " & (units.presence_ratio>0.9)]\n", + "units_filt1 = units[(units.isi_violations < 0.5) & (units.amplitude_cutoff < 0.1) & (units.presence_ratio > 0.9)]\n", "len(units_filt1)" ] }, @@ -1895,7 +1902,7 @@ } ], "source": [ - "units_filt2 = units[(units.snr>1) & (units.firing_rate>0.2)]\n", + "units_filt2 = units[(units.snr > 1) & (units.firing_rate > 0.2)]\n", "len(units_filt2)" ] }, @@ -1987,8 +1994,7 @@ } ], "source": [ - "session = cache.get_ecephys_session(\n", - " ecephys_session_id=1065437523)" + "session = cache.get_ecephys_session(ecephys_session_id=1065437523)" ] }, { @@ -2016,8 +2022,8 @@ "units = session.get_units()\n", "channels = session.get_channels()\n", "\n", - "#merge the units and channels tables to get full CCF/channel info for each unit\n", - "units = units.merge(channels, left_on='peak_channel_id', right_index=True)" + "# merge the units and channels tables to get full CCF/channel info for each unit\n", + "units = units.merge(channels, left_on=\"peak_channel_id\", right_index=True)" ] }, { @@ -2238,12 +2244,14 @@ } ], "source": [ - "area_waveform_stats = units.pivot_table(index='structure_acronym', \n", - " values=['velocity_above', 'velocity_below', 'waveform_duration'], \n", - " aggfunc=['mean', 'count'])\n", + "area_waveform_stats = units.pivot_table(\n", + " index=\"structure_acronym\",\n", + " values=[\"velocity_above\", \"velocity_below\", \"waveform_duration\"],\n", + " aggfunc=[\"mean\", \"count\"],\n", + ")\n", "\n", - "print('Mean waveform features across areas')\n", - "display(area_waveform_stats[area_waveform_stats['count']['waveform_duration']>50]['mean'])" + "print(\"Mean waveform features across areas\")\n", + "display(area_waveform_stats[area_waveform_stats[\"count\"][\"waveform_duration\"] > 50][\"mean\"])" ] }, { @@ -2287,15 +2295,19 @@ }, "outputs": [], "source": [ - "unit1 = units[(units['velocity_below']<0) & \n", - " (units['waveform_duration']>0.4) &\n", - " (units['structure_acronym']=='CA1')&\n", - " (units['quality']=='good')].iloc[1]\n", + "unit1 = units[\n", + " (units[\"velocity_below\"] < 0)\n", + " & (units[\"waveform_duration\"] > 0.4)\n", + " & (units[\"structure_acronym\"] == \"CA1\")\n", + " & (units[\"quality\"] == \"good\")\n", + "].iloc[1]\n", "\n", - "unit2 = units[(units['velocity_below']>0) & \n", - " (units['waveform_duration']<0.3)&\n", - " (units['structure_acronym']=='MRN')&\n", - " (units['quality']=='good')].iloc[0]" + "unit2 = units[\n", + " (units[\"velocity_below\"] > 0)\n", + " & (units[\"waveform_duration\"] < 0.3)\n", + " & (units[\"structure_acronym\"] == \"MRN\")\n", + " & (units[\"quality\"] == \"good\")\n", + "].iloc[0]" ] }, { @@ -2331,17 +2343,17 @@ } ], "source": [ - "fig, ax = plt.subplots(1,2)\n", - "ylabels = ['probe channel', '']\n", + "fig, ax = plt.subplots(1, 2)\n", + "ylabels = [\"probe channel\", \"\"]\n", "for iu, u in enumerate([unit1, unit2]):\n", " waveform = session.mean_waveforms[u.name]\n", - " peak_chan = u['probe_channel_number']\n", + " peak_chan = u[\"probe_channel_number\"]\n", " ax[iu].imshow(waveform)\n", - " ax[iu].set_ylim([peak_chan-30, peak_chan+30])\n", + " ax[iu].set_ylim([peak_chan - 30, peak_chan + 30])\n", " ax[iu].set_xticks([0, 30, 60])\n", " ax[iu].set_xticklabels([0, 1, 2])\n", " ax[iu].set_ylabel(ylabels[iu])\n", - " ax[iu].set_xlabel('time (ms)')\n", + " ax[iu].set_xlabel(\"time (ms)\")\n", " ax[iu].set_title(u.structure_acronym)" ] }, diff --git a/doc_template/examples_root/examples/nb/visual_behavior_neuropixels_quickstart.ipynb b/doc_template/examples_root/examples/nb/visual_behavior_neuropixels_quickstart.ipynb index ade62105ea..cbf8404777 100644 --- a/doc_template/examples_root/examples/nb/visual_behavior_neuropixels_quickstart.ipynb +++ b/doc_template/examples_root/examples/nb/visual_behavior_neuropixels_quickstart.ipynb @@ -62,9 +62,9 @@ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", - "from allensdk.brain_observatory.behavior.behavior_project_cache.\\\n", - " behavior_neuropixels_project_cache \\\n", - " import VisualBehaviorNeuropixelsProjectCache" + "from allensdk.brain_observatory.behavior.behavior_project_cache.behavior_neuropixels_project_cache import (\n", + " VisualBehaviorNeuropixelsProjectCache,\n", + ")" ] }, { @@ -159,8 +159,7 @@ } ], "source": [ - "cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(\n", - " cache_dir=Path(output_dir))\n", + "cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(cache_dir=Path(output_dir))\n", "\n", "# get the metadata tables\n", "units_table = cache.get_unit_table()\n", @@ -488,8 +487,9 @@ } ], "source": [ - "sst_novel_sessions = ecephys_sessions_table.loc[(ecephys_sessions_table['genotype'].str.contains('Sst')) & \n", - " (ecephys_sessions_table['experience_level']=='Novel')]\n", + "sst_novel_sessions = ecephys_sessions_table.loc[\n", + " (ecephys_sessions_table[\"genotype\"].str.contains(\"Sst\")) & (ecephys_sessions_table[\"experience_level\"] == \"Novel\")\n", + "]\n", "sst_novel_sessions.head()" ] }, @@ -547,8 +547,7 @@ ], "source": [ "session_id = 1064644573\n", - "session = cache.get_ecephys_session(\n", - " ecephys_session_id=session_id)" + "session = cache.get_ecephys_session(ecephys_session_id=session_id)" ] }, { @@ -662,7 +661,7 @@ "units = session.get_units()\n", "channels = session.get_channels()\n", "\n", - "unit_channels = units.merge(channels, left_on='peak_channel_id', right_index=True)" + "unit_channels = units.merge(channels, left_on=\"peak_channel_id\", right_index=True)" ] }, { @@ -743,7 +742,7 @@ } ], "source": [ - "unit_channels.value_counts('structure_acronym')" + "unit_channels.value_counts(\"structure_acronym\")" ] }, { @@ -802,13 +801,13 @@ }, "outputs": [], "source": [ - "#first let's sort our units by depth\n", - "unit_channels = unit_channels.sort_values('probe_vertical_position', ascending=False)\n", + "# first let's sort our units by depth\n", + "unit_channels = unit_channels.sort_values(\"probe_vertical_position\", ascending=False)\n", "\n", - "#now we'll filter them\n", - "good_unit_filter = ((unit_channels['snr']>1)&\n", - " (unit_channels['isi_violations']<1)&\n", - " (unit_channels['firing_rate']>0.1))\n", + "# now we'll filter them\n", + "good_unit_filter = (\n", + " (unit_channels[\"snr\"] > 1) & (unit_channels[\"isi_violations\"] < 1) & (unit_channels[\"firing_rate\"] > 0.1)\n", + ")\n", "\n", "good_units = unit_channels.loc[good_unit_filter]\n", "spike_times = session.spike_times" @@ -854,8 +853,9 @@ "outputs": [], "source": [ "stimulus_presentations = session.stimulus_presentations\n", - "change_times = stimulus_presentations[stimulus_presentations['active']&\n", - " stimulus_presentations['is_change']]['start_time'].values" + "change_times = stimulus_presentations[stimulus_presentations[\"active\"] & stimulus_presentations[\"is_change\"]][\n", + " \"start_time\"\n", + "].values" ] }, { @@ -880,17 +880,17 @@ }, "outputs": [], "source": [ - "#Convenience function to compute the PSTH\n", + "# Convenience function to compute the PSTH\n", "def makePSTH(spikes, startTimes, windowDur, binSize=0.001):\n", - " bins = np.arange(0,windowDur+binSize,binSize)\n", - " counts = np.zeros(bins.size-1)\n", - " for i,start in enumerate(startTimes):\n", + " bins = np.arange(0, windowDur + binSize, binSize)\n", + " counts = np.zeros(bins.size - 1)\n", + " for i, start in enumerate(startTimes):\n", " startInd = np.searchsorted(spikes, start)\n", - " endInd = np.searchsorted(spikes, start+windowDur)\n", - " counts = counts + np.histogram(spikes[startInd:endInd]-start, bins)[0]\n", - " \n", - " counts = counts/startTimes.size\n", - " return counts/binSize, bins" + " endInd = np.searchsorted(spikes, start + windowDur)\n", + " counts = counts + np.histogram(spikes[startInd:endInd] - start, bins)[0]\n", + "\n", + " counts = counts / startTimes.size\n", + " return counts / binSize, bins" ] }, { @@ -932,17 +932,15 @@ }, "outputs": [], "source": [ - "#Here's where we loop through the units in our area of interest and compute their PSTHs\n", - "area_of_interest = 'VISp'\n", + "# Here's where we loop through the units in our area of interest and compute their PSTHs\n", + "area_of_interest = \"VISp\"\n", "area_change_responses = []\n", - "area_units = good_units[good_units['structure_acronym']==area_of_interest]\n", + "area_units = good_units[good_units[\"structure_acronym\"] == area_of_interest]\n", "time_before_change = 1\n", "duration = 2.5\n", "for iu, unit in area_units.iterrows():\n", " unit_spike_times = spike_times[iu]\n", - " unit_change_response, bins = makePSTH(unit_spike_times, \n", - " change_times-time_before_change, \n", - " duration, binSize=0.01)\n", + " unit_change_response, bins = makePSTH(unit_spike_times, change_times - time_before_change, duration, binSize=0.01)\n", " area_change_responses.append(unit_change_response)\n", "area_change_responses = np.array(area_change_responses)" ] @@ -990,23 +988,22 @@ } ], "source": [ - "#Plot the results\n", - "fig, ax = plt.subplots(1,2)\n", - "fig.set_size_inches([12,4])\n", + "# Plot the results\n", + "fig, ax = plt.subplots(1, 2)\n", + "fig.set_size_inches([12, 4])\n", "\n", - "clims = [np.percentile(area_change_responses, p) for p in (0.1,99.9)]\n", + "clims = [np.percentile(area_change_responses, p) for p in (0.1, 99.9)]\n", "im = ax[0].imshow(area_change_responses, clim=clims)\n", - "ax[0].set_title('Active Change Responses for {}'.format(area_of_interest))\n", - "ax[0].set_ylabel('Unit number, sorted by depth')\n", - "ax[0].set_xlabel('Time from change (s)')\n", - "ax[0].set_xticks(np.arange(0, bins.size-1, 20))\n", - "_ = ax[0].set_xticklabels(np.round(bins[:-1:20]-time_before_change, 2))\n", + "ax[0].set_title(\"Active Change Responses for {}\".format(area_of_interest))\n", + "ax[0].set_ylabel(\"Unit number, sorted by depth\")\n", + "ax[0].set_xlabel(\"Time from change (s)\")\n", + "ax[0].set_xticks(np.arange(0, bins.size - 1, 20))\n", + "_ = ax[0].set_xticklabels(np.round(bins[:-1:20] - time_before_change, 2))\n", "\n", - "ax[1].plot(bins[:-1]-time_before_change, np.mean(area_change_responses, axis=0), 'k')\n", - "ax[1].set_title('{} population active change response (n={})'\\\n", - " .format(area_of_interest, area_change_responses.shape[0]))\n", - "ax[1].set_xlabel('Time from change (s)')\n", - "ax[1].set_ylabel('Firing Rate')\n" + "ax[1].plot(bins[:-1] - time_before_change, np.mean(area_change_responses, axis=0), \"k\")\n", + "ax[1].set_title(\"{} population active change response (n={})\".format(area_of_interest, area_change_responses.shape[0]))\n", + "ax[1].set_xlabel(\"Time from change (s)\")\n", + "ax[1].set_ylabel(\"Firing Rate\")" ] }, { @@ -1065,9 +1062,9 @@ }, "outputs": [], "source": [ - "rf_stim_table = stimulus_presentations[stimulus_presentations['stimulus_name'].str.contains('gabor')]\n", - "xs = np.sort(rf_stim_table.position_x.unique()) #positions of gabor along azimuth\n", - "ys = np.sort(rf_stim_table.position_y.unique()) #positions of gabor along elevation" + "rf_stim_table = stimulus_presentations[stimulus_presentations[\"stimulus_name\"].str.contains(\"gabor\")]\n", + "xs = np.sort(rf_stim_table.position_x.unique()) # positions of gabor along azimuth\n", + "ys = np.sort(rf_stim_table.position_y.unique()) # positions of gabor along elevation" ] }, { @@ -1096,14 +1093,14 @@ " unit_rf = np.zeros([ys.size, xs.size])\n", " for ix, x in enumerate(xs):\n", " for iy, y in enumerate(ys):\n", - " stim_times = rf_stim_table[(rf_stim_table.position_x==x)\n", - " &(rf_stim_table.position_y==y)]['start_time'].values\n", - " unit_response, bins = makePSTH(spikes, \n", - " stim_times+0.01, \n", - " 0.2, binSize=0.001)\n", + " stim_times = rf_stim_table[(rf_stim_table.position_x == x) & (rf_stim_table.position_y == y)][\n", + " \"start_time\"\n", + " ].values\n", + " unit_response, bins = makePSTH(spikes, stim_times + 0.01, 0.2, binSize=0.001)\n", " unit_rf[iy, ix] = unit_response.mean()\n", " return unit_rf\n", "\n", + "\n", "area_rfs = []\n", "for iu, unit in area_units.iterrows():\n", " unit_spike_times = spike_times[iu]\n", @@ -1144,14 +1141,14 @@ } ], "source": [ - "fig, axes = plt.subplots(int(len(area_rfs)/10)+1, 10)\n", + "fig, axes = plt.subplots(int(len(area_rfs) / 10) + 1, 10)\n", "fig.set_size_inches(12, 8)\n", "for irf, rf in enumerate(area_rfs):\n", - " ax_row = int(irf/10)\n", - " ax_col = irf%10\n", - " axes[ax_row][ax_col].imshow(rf, origin='lower')\n", + " ax_row = int(irf / 10)\n", + " ax_col = irf % 10\n", + " axes[ax_row][ax_col].imshow(rf, origin=\"lower\")\n", "for ax in axes.flat:\n", - " ax.axis('off')" + " ax.axis(\"off\")" ] }, { @@ -1400,26 +1397,23 @@ }, "outputs": [], "source": [ - "duration = opto_table.duration.min() #get the short pulses\n", - "level = opto_table.level.max() #and the high power trials\n", + "duration = opto_table.duration.min() # get the short pulses\n", + "level = opto_table.level.max() # and the high power trials\n", "\n", - "cortical_units = good_units[good_units['structure_acronym'].str.contains('VIS')]\n", + "cortical_units = good_units[good_units[\"structure_acronym\"].str.contains(\"VIS\")]\n", "\n", "\n", - "opto_times = opto_table.loc[(opto_table['duration']==duration)&\n", - " (opto_table['level']==level)]['start_time'].values\n", + "opto_times = opto_table.loc[(opto_table[\"duration\"] == duration) & (opto_table[\"level\"] == level)][\"start_time\"].values\n", "\n", - "time_before = 0.01 # seconds to take before the laser start for PSTH\n", - "duration = 0.03 # total duration of trial for PSTH in seconds\n", - "binSize = 0.001 # 1ms bin size for PSTH\n", + "time_before = 0.01 # seconds to take before the laser start for PSTH\n", + "duration = 0.03 # total duration of trial for PSTH in seconds\n", + "binSize = 0.001 # 1ms bin size for PSTH\n", "opto_response = []\n", "unit_id = []\n", "for iu, unit in cortical_units.iterrows():\n", " unit_spike_times = spike_times[iu]\n", - " unit_response, bins = makePSTH(unit_spike_times, \n", - " opto_times-time_before, duration, \n", - " binSize=binSize)\n", - " \n", + " unit_response, bins = makePSTH(unit_spike_times, opto_times - time_before, duration, binSize=binSize)\n", + "\n", " opto_response.append(unit_response)\n", " unit_id.append(iu)\n", "\n", @@ -1460,23 +1454,23 @@ ], "source": [ "fig, ax = plt.subplots()\n", - "fig.set_size_inches((5,10))\n", - "fig.suptitle('Optotagging: ' + str(session.metadata['ecephys_session_id'])\n", - " + ' ' + session.metadata['full_genotype'])\n", - "im = ax.imshow(opto_response, \n", - " origin='lower', aspect='auto',\n", - " )\n", + "fig.set_size_inches((5, 10))\n", + "fig.suptitle(\"Optotagging: \" + str(session.metadata[\"ecephys_session_id\"]) + \" \" + session.metadata[\"full_genotype\"])\n", + "im = ax.imshow(\n", + " opto_response,\n", + " origin=\"lower\",\n", + " aspect=\"auto\",\n", + ")\n", "min_clim_val = 0\n", "max_clim_val = 250\n", - "im.set_clim([min_clim_val, max_clim_val]) \n", - "[ax.axvline(bound, linestyle=':', color='white', linewidth=1.0)\\\n", - " for bound in [10, 19]]\n", - "ax.set_xlabel('Time from laser onset (ms)')\n", - "ax.set_ylabel('Unit number')\n", - "ax.set_xticks(1000*bins[:-1:5])\n", + "im.set_clim([min_clim_val, max_clim_val])\n", + "[ax.axvline(bound, linestyle=\":\", color=\"white\", linewidth=1.0) for bound in [10, 19]]\n", + "ax.set_xlabel(\"Time from laser onset (ms)\")\n", + "ax.set_ylabel(\"Unit number\")\n", + "ax.set_xticks(1000 * bins[:-1:5])\n", "\n", - "time_labels = np.round(1000*(bins[:-1:5]-time_before), 0)\n", - "_=ax.set_xticklabels(time_labels)" + "time_labels = np.round(1000 * (bins[:-1:5] - time_before), 0)\n", + "_ = ax.set_xticklabels(time_labels)" ] }, { @@ -1521,10 +1515,11 @@ "outputs": [], "source": [ "baseline_window = slice(0, 9) # baseline epoch\n", - "response_window = slice(11,18) # laser epoch\n", + "response_window = slice(11, 18) # laser epoch\n", "\n", - "response_magnitudes = np.mean(opto_response[:, response_window], axis=1) \\\n", - " - np.mean(opto_response[:, baseline_window], axis=1)" + "response_magnitudes = np.mean(opto_response[:, response_window], axis=1) - np.mean(\n", + " opto_response[:, baseline_window], axis=1\n", + ")" ] }, { @@ -1570,23 +1565,27 @@ } ], "source": [ - "fig, axes = plt.subplots(1,2)\n", + "fig, axes = plt.subplots(1, 2)\n", "fig.set_size_inches(10, 5)\n", "\n", "# Plot scatter of opto rate vs baseline rate\n", - "axes[0].plot(np.mean(opto_response[:, baseline_window], axis=1),\n", - " np.mean(opto_response[:, response_window], axis=1), 'k.', alpha=0.2)\n", + "axes[0].plot(\n", + " np.mean(opto_response[:, baseline_window], axis=1),\n", + " np.mean(opto_response[:, response_window], axis=1),\n", + " \"k.\",\n", + " alpha=0.2,\n", + ")\n", "axes[0].set_xlim([-10, 200])\n", "axes[0].set_ylim([-10, 400])\n", - "axes[0].set_aspect('equal')\n", - "axes[0].set_ylabel('response rate (Hz)')\n", - "axes[0].set_xlabel('baseline rate (Hz)')\n", + "axes[0].set_aspect(\"equal\")\n", + "axes[0].set_ylabel(\"response rate (Hz)\")\n", + "axes[0].set_xlabel(\"baseline rate (Hz)\")\n", "\n", "# Plot histogram of opto-evoked rate (note log yscale)\n", "_ = axes[1].hist(response_magnitudes, bins=20)\n", - "axes[1].set_yscale('log')\n", - "axes[1].set_xlabel('Opto-evoked rate (Hz)')\n", - "axes[1].set_ylabel('Unit Count')" + "axes[1].set_yscale(\"log\")\n", + "axes[1].set_xlabel(\"Opto-evoked rate (Hz)\")\n", + "axes[1].set_ylabel(\"Unit Count\")" ] } ], diff --git a/doc_template/examples_root/examples/nb/visual_behavior_ophys_data_access.ipynb b/doc_template/examples_root/examples/nb/visual_behavior_ophys_data_access.ipynb index 12747fc992..a3a2cbc69a 100644 --- a/doc_template/examples_root/examples/nb/visual_behavior_ophys_data_access.ipynb +++ b/doc_template/examples_root/examples/nb/visual_behavior_ophys_data_access.ipynb @@ -536,7 +536,7 @@ "# Update this to a valid directory in your filesystem\n", "# Remember to choose a location that has plenty of free space available.\n", "output_dir = \"/local1/visual_behavior_ophys_cache_dir\"\n", - "DOWNLOAD_COMPLETE_DATASET = True " + "DOWNLOAD_COMPLETE_DATASET = True" ] }, { @@ -1038,11 +1038,10 @@ } ], "source": [ - "from allensdk.brain_observatory.behavior.behavior_project_cache.utils import \\\n", - " BehaviorCloudCacheVersionException\n", + "from allensdk.brain_observatory.behavior.behavior_project_cache.utils import BehaviorCloudCacheVersionException\n", "\n", "try:\n", - " cache.load_manifest('visual-behavior-ophys_project_manifest_v0.1.0.json')\n", + " cache.load_manifest(\"visual-behavior-ophys_project_manifest_v0.1.0.json\")\n", "except BehaviorCloudCacheVersionException as e:\n", " print(e)\n", " cache.load_manifest(cache.latest_manifest_file())" @@ -1148,8 +1147,9 @@ } ], "source": [ - "msg = cache.compare_manifests('visual-behavior-ophys_project_manifest_v0.1.0.json',\n", - " 'visual-behavior-ophys_project_manifest_v0.2.0.json')\n", + "msg = cache.compare_manifests(\n", + " \"visual-behavior-ophys_project_manifest_v0.1.0.json\", \"visual-behavior-ophys_project_manifest_v0.2.0.json\"\n", + ")\n", "print(msg)" ] }, @@ -2474,7 +2474,7 @@ ], "source": [ "# Visualizing a particular stimulus\n", - "plt.imshow(behavior_session.stimulus_templates['warped']['gratings_90.0'], cmap='gray')" + "plt.imshow(behavior_session.stimulus_templates[\"warped\"][\"gratings_90.0\"], cmap=\"gray\")" ] }, { @@ -2638,7 +2638,7 @@ } ], "source": [ - "plt.imshow(ophys_experiment.max_projection, cmap='gray')" + "plt.imshow(ophys_experiment.max_projection, cmap=\"gray\")" ] }, { @@ -2819,11 +2819,13 @@ "source": [ "from urllib.parse import urljoin\n", "\n", + "\n", "def get_manifest_url(manifest_version: str) -> str:\n", " hostname = \"https://visual-behavior-ophys-data.s3-us-west-2.amazonaws.com/\"\n", " object_key = f\"visual-behavior-ophys/manifests/visual-behavior-ophys_project_manifest_v{manifest_version}.json\"\n", " return urljoin(hostname, object_key)\n", "\n", + "\n", "# Example:\n", "print(get_manifest_url(\"0.1.0\"))" ] @@ -2866,6 +2868,7 @@ " object_key = f\"visual-behavior-ophys/project_metadata/{metadata_table_name}.csv\"\n", " return urljoin(hostname, object_key)\n", "\n", + "\n", "# Example:\n", "print(get_metadata_url(\"behavior_session_table\"))" ] @@ -2908,6 +2911,7 @@ " object_key = f\"visual-behavior-ophys/behavior_sessions/behavior_session_{behavior_session_id}.nwb\"\n", " return urljoin(hostname, object_key)\n", "\n", + "\n", "# Example:\n", "print(get_behavior_session_url(870987812))" ] @@ -2950,6 +2954,7 @@ " object_key = f\"visual-behavior-ophys/behavior_ophys_experiments/behavior_ophys_experiment_{ophys_experiment_id}.nwb\"\n", " return urljoin(hostname, object_key)\n", "\n", + "\n", "# Example:\n", "print(get_behavior_ophys_experiment_url(951980471))" ] @@ -9073,12 +9078,13 @@ "# The location will differ based on where you downloaded the manifest.json!\n", "my_manifest_location = output_dir / cache.latest_manifest_file()\n", "\n", + "\n", "def generate_all_download_urls_from_manifest(manifest_path: Path) -> List[str]:\n", - " with manifest_path.open('r') as fp:\n", + " with manifest_path.open(\"r\") as fp:\n", " manifest = json.load(fp)\n", - " \n", + "\n", " download_links = []\n", - " \n", + "\n", " # Get download links for specific version of metadata files\n", " for metadata_file_entry in manifest[\"metadata_files\"].values():\n", " base_download_url = metadata_file_entry[\"url\"]\n", @@ -9091,12 +9097,13 @@ " base_download_url = data_file_entry[\"url\"]\n", " version_query = f\"?versionId={data_file_entry['version_id']}\"\n", " full_download_url = urljoin(base_download_url, version_query)\n", - " download_links.append(full_download_url) \n", + " download_links.append(full_download_url)\n", "\n", " return download_links\n", "\n", + "\n", "# Example:\n", - "print('\\n'.join(generate_all_download_urls_from_manifest(my_manifest_location)))" + "print(\"\\n\".join(generate_all_download_urls_from_manifest(my_manifest_location)))" ] }, { diff --git a/doc_template/examples_root/examples/nb/visual_behavior_ophys_dataset_manifest.ipynb b/doc_template/examples_root/examples/nb/visual_behavior_ophys_dataset_manifest.ipynb index ba5e325bb5..f96c3fb636 100644 --- a/doc_template/examples_root/examples/nb/visual_behavior_ophys_dataset_manifest.ipynb +++ b/doc_template/examples_root/examples/nb/visual_behavior_ophys_dataset_manifest.ipynb @@ -281,7 +281,7 @@ "Collecting pip\r\n", " Downloading pip-23.3.1-py3-none-any.whl.metadata (3.5 kB)\r\n", "Downloading pip-23.3.1-py3-none-any.whl (2.1 MB)\r\n", - "\u001b[2K \u001b[90m\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u001b[0m \u001b[32m2.1/2.1 MB\u001b[0m \u001b[31m42.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.1/2.1 MB\u001b[0m \u001b[31m42.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", "\u001b[?25hInstalling collected packages: pip\r\n", " Attempting uninstall: pip\r\n", " Found existing installation: pip 23.0.1\r\n", @@ -533,10 +533,10 @@ "\n", "is not deleted between instantiations of this cache\n", " warnings.warn(msg, MissingLocalManifestWarning)\n", - "ophys_session_table.csv: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 247k/247k [00:00<00:00, 1.63MMB/s] \n", - "behavior_session_table.csv: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1.59M/1.59M [00:00<00:00, 7.10MMB/s]\n", - "ophys_experiment_table.csv: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 657k/657k [00:00<00:00, 4.52MMB/s] \n", - "ophys_cells_table.csv: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 4.28M/4.28M [00:00<00:00, 11.8MMB/s]\n", + "ophys_session_table.csv: 100%|██████████| 247k/247k [00:00<00:00, 1.63MMB/s] \n", + "behavior_session_table.csv: 100%|██████████| 1.59M/1.59M [00:00<00:00, 7.10MMB/s]\n", + "ophys_experiment_table.csv: 100%|██████████| 657k/657k [00:00<00:00, 4.52MMB/s] \n", + "ophys_cells_table.csv: 100%|██████████| 4.28M/4.28M [00:00<00:00, 11.8MMB/s]\n", "/home/runner/work/AllenSDK/AllenSDK/allensdk/brain_observatory/behavior/behavior_project_cache/behavior_project_cache.py:135: UpdatedStimulusPresentationTableWarning: \n", "\tAs of AllenSDK version 2.16.0, the latest Visual Behavior Ophys data has been significantly updated from previous releases. Specifically the user will need to update all processing of the stimulus_presentations tables. These tables now include multiple stimulus types delineated by the columns `stimulus_block` and `stimulus_block_name`.\n", "\n", @@ -944,7 +944,7 @@ " \n", " \n", "\n", - "

5 rows \u00d7 34 columns

\n", + "

5 rows × 34 columns

\n", "" ], "text/plain": [ @@ -1158,7 +1158,7 @@ } ], "source": [ - "print('behavior data could be recorded on these experimental systems:\\n')\n", + "print(\"behavior data could be recorded on these experimental systems:\\n\")\n", "print(np.sort(behavior_sessions.equipment_name.unique()))" ] }, @@ -1255,7 +1255,7 @@ } ], "source": [ - "print('there are ', len(behavior_sessions.mouse_id.unique()), 'mice in the dataset')" + "print(\"there are \", len(behavior_sessions.mouse_id.unique()), \"mice in the dataset\")" ] }, { @@ -1316,7 +1316,7 @@ } ], "source": [ - "print('the different transgenic lines included in this dataset are:\\n')\n", + "print(\"the different transgenic lines included in this dataset are:\\n\")\n", "print(np.sort(behavior_sessions.full_genotype.unique()))" ] }, @@ -1375,7 +1375,7 @@ } ], "source": [ - "print('the different cre lines used in this dataset are:\\n')\n", + "print(\"the different cre lines used in this dataset are:\\n\")\n", "print(np.sort(behavior_sessions.cre_line.unique()))" ] }, @@ -1456,7 +1456,7 @@ } ], "source": [ - "print('the different reporter lines used in this dataset are:\\n')\n", + "print(\"the different reporter lines used in this dataset are:\\n\")\n", "print(np.sort(behavior_sessions.reporter_line.unique()))" ] }, @@ -1495,7 +1495,7 @@ } ], "source": [ - "print('the different indicators used in this dataset are:\\n')\n", + "print(\"the different indicators used in this dataset are:\\n\")\n", "print(np.sort(behavior_sessions.indicator.unique()))" ] }, @@ -1629,7 +1629,9 @@ } ], "source": [ - "behavior_sessions.groupby(['full_genotype', 'mouse_id']).count().reset_index().groupby('full_genotype').count()[['mouse_id']]" + "behavior_sessions.groupby([\"full_genotype\", \"mouse_id\"]).count().reset_index().groupby(\"full_genotype\").count()[\n", + " [\"mouse_id\"]\n", + "]" ] }, { @@ -1744,9 +1746,8 @@ } ], "source": [ - "print('the session_types available in this dataset are:\\n')\n", - "print(np.sort(behavior_sessions.session_type[\n", - " ~behavior_sessions.session_type.isna()].unique()))" + "print(\"the session_types available in this dataset are:\\n\")\n", + "print(np.sort(behavior_sessions.session_type[~behavior_sessions.session_type.isna()].unique()))" ] }, { @@ -1883,10 +1884,9 @@ } ], "source": [ - "# reminder about possible session types \n", - "print('the different session_types available in this dataset are:\\n')\n", - "print(np.sort(behavior_sessions.session_type[\n", - " ~behavior_sessions.session_type.isna()].unique()))" + "# reminder about possible session types\n", + "print(\"the different session_types available in this dataset are:\\n\")\n", + "print(np.sort(behavior_sessions.session_type[~behavior_sessions.session_type.isna()].unique()))" ] }, { @@ -2220,7 +2220,7 @@ " \n", " \n", "\n", - "

107 rows \u00d7 2 columns

\n", + "

107 rows × 2 columns

\n", "" ], "text/plain": [ @@ -2247,7 +2247,11 @@ ], "source": [ "# get a table of the project code for each mouse\n", - "project_code_lookup = behavior_sessions[~behavior_sessions.project_code.isnull()].reset_index().drop_duplicates('mouse_id')[['mouse_id','project_code']]\n", + "project_code_lookup = (\n", + " behavior_sessions[~behavior_sessions.project_code.isnull()]\n", + " .reset_index()\n", + " .drop_duplicates(\"mouse_id\")[[\"mouse_id\", \"project_code\"]]\n", + ")\n", "project_code_lookup" ] }, @@ -2276,10 +2280,11 @@ }, "outputs": [], "source": [ - "behavior_sessions = behavior_sessions.merge(project_code_lookup, on='mouse_id',\n", - " how='left', suffixes=('_session', '_mouse'))\n", - "behavior_sessions = behavior_sessions.drop(columns='project_code_session')\n", - "behavior_sessions = behavior_sessions.rename(columns={'project_code_mouse': 'project_code'})" + "behavior_sessions = behavior_sessions.merge(\n", + " project_code_lookup, on=\"mouse_id\", how=\"left\", suffixes=(\"_session\", \"_mouse\")\n", + ")\n", + "behavior_sessions = behavior_sessions.drop(columns=\"project_code_session\")\n", + "behavior_sessions = behavior_sessions.rename(columns={\"project_code_mouse\": \"project_code\"})" ] }, { @@ -2434,12 +2439,14 @@ } ], "source": [ - "for project_code in behavior_sessions.project_code.unique(): \n", - " project_sessions = behavior_sessions[behavior_sessions.project_code==project_code]\n", - " print('\\n project_code:', project_code)\n", - " print('\\n has these session types:\\n', np.sort(\n", - " project_sessions.session_type[~project_sessions.session_type.isna()].unique()))\n", - " print('\\n')" + "for project_code in behavior_sessions.project_code.unique():\n", + " project_sessions = behavior_sessions[behavior_sessions.project_code == project_code]\n", + " print(\"\\n project_code:\", project_code)\n", + " print(\n", + " \"\\n has these session types:\\n\",\n", + " np.sort(project_sessions.session_type[~project_sessions.session_type.isna()].unique()),\n", + " )\n", + " print(\"\\n\")" ] }, { @@ -2754,7 +2761,7 @@ " \n", " \n", "\n", - "

5 rows \u00d7 25 columns

\n", + "

5 rows × 25 columns

\n", "" ], "text/plain": [ @@ -3077,10 +3084,10 @@ } ], "source": [ - "# what do the ophys_experiment_id and ophys_container_id columns look like? \n", - "# are there always the same number of experiments and containers in different sessions? \n", - "# does the number of experiments and containers depend on the microscope used? \n", - "ophys_sessions[['ophys_experiment_id', 'ophys_container_id', 'equipment_name']][:15]" + "# what do the ophys_experiment_id and ophys_container_id columns look like?\n", + "# are there always the same number of experiments and containers in different sessions?\n", + "# does the number of experiments and containers depend on the microscope used?\n", + "ophys_sessions[[\"ophys_experiment_id\", \"ophys_container_id\", \"equipment_name\"]][:15]" ] }, { @@ -3252,12 +3259,13 @@ ], "source": [ "# pick a mouse\n", - "mouse_id = 445002 \n", + "mouse_id = 445002\n", "# get behavior sessions that took place on the microscope\n", - "mouse_ophys_sessions = behavior_sessions[(behavior_sessions.mouse_id==mouse_id)&\n", - " (behavior_sessions.equipment_name=='CAM2P.3')]\n", + "mouse_ophys_sessions = behavior_sessions[\n", + " (behavior_sessions.mouse_id == mouse_id) & (behavior_sessions.equipment_name == \"CAM2P.3\")\n", + "]\n", "# only look at the relevant columns\n", - "mouse_ophys_sessions.sort_values(by='date_of_acquisition')[['session_type', 'date_of_acquisition', 'ophys_session_id']]" + "mouse_ophys_sessions.sort_values(by=\"date_of_acquisition\")[[\"session_type\", \"date_of_acquisition\", \"ophys_session_id\"]]" ] }, { @@ -3314,8 +3322,8 @@ } ], "source": [ - "print('there are', len(mouse_ophys_sessions), 'ophys sessions in the behavior_session_table for this mouse')\n", - "print('this includes ophys sessions that failed QC for ophys, but still have behavior data')" + "print(\"there are\", len(mouse_ophys_sessions), \"ophys sessions in the behavior_session_table for this mouse\")\n", + "print(\"this includes ophys sessions that failed QC for ophys, but still have behavior data\")" ] }, { @@ -3373,8 +3381,12 @@ } ], "source": [ - "print('there are', len(ophys_sessions[ophys_sessions.mouse_id==mouse_id]), 'sessions in the ophys_session_table for this mouse')\n", - "print('these are the sessions with valid ophys data')" + "print(\n", + " \"there are\",\n", + " len(ophys_sessions[ophys_sessions.mouse_id == mouse_id]),\n", + " \"sessions in the ophys_session_table for this mouse\",\n", + ")\n", + "print(\"these are the sessions with valid ophys data\")" ] }, { @@ -3448,7 +3460,7 @@ } ], "source": [ - "ophys_sessions[ophys_sessions.mouse_id==mouse_id][['date_of_acquisition', 'session_type']]" + "ophys_sessions[ophys_sessions.mouse_id == mouse_id][[\"date_of_acquisition\", \"session_type\"]]" ] }, { @@ -3581,10 +3593,11 @@ "# pick a mouse\n", "mouse_id = 453911\n", "# get behavior sessions that took place on the microscope\n", - "mouse_ophys_sessions = behavior_sessions[(behavior_sessions.mouse_id==mouse_id)&\n", - " (behavior_sessions.equipment_name=='MESO.1')]\n", + "mouse_ophys_sessions = behavior_sessions[\n", + " (behavior_sessions.mouse_id == mouse_id) & (behavior_sessions.equipment_name == \"MESO.1\")\n", + "]\n", "# only look at the relevant columns\n", - "mouse_ophys_sessions.sort_values(by='date_of_acquisition')[['date_of_acquisition', 'session_type', 'ophys_session_id']]" + "mouse_ophys_sessions.sort_values(by=\"date_of_acquisition\")[[\"date_of_acquisition\", \"session_type\", \"ophys_session_id\"]]" ] }, { @@ -3698,7 +3711,7 @@ } ], "source": [ - "ophys_sessions[ophys_sessions.mouse_id==mouse_id][['date_of_acquisition', 'session_type']]" + "ophys_sessions[ophys_sessions.mouse_id == mouse_id][[\"date_of_acquisition\", \"session_type\"]]" ] }, { @@ -3811,10 +3824,11 @@ "# pick a mouse\n", "mouse_id = 438912\n", "# get behavior sessions that took place on the microscope\n", - "mouse_ophys_sessions = behavior_sessions[(behavior_sessions.mouse_id==mouse_id)&\n", - " (behavior_sessions.equipment_name=='MESO.1')]\n", + "mouse_ophys_sessions = behavior_sessions[\n", + " (behavior_sessions.mouse_id == mouse_id) & (behavior_sessions.equipment_name == \"MESO.1\")\n", + "]\n", "# only look at the relevant columns\n", - "mouse_ophys_sessions.sort_values(by='date_of_acquisition')[['date_of_acquisition', 'session_type', 'ophys_session_id']]" + "mouse_ophys_sessions.sort_values(by=\"date_of_acquisition\")[[\"date_of_acquisition\", \"session_type\", \"ophys_session_id\"]]" ] }, { @@ -4005,12 +4019,15 @@ } ], "source": [ - "mouse_id = 445002 \n", + "mouse_id = 445002\n", "# get behavior sessions that took place on the microscope\n", - "mouse_ophys_sessions = behavior_sessions[(behavior_sessions.mouse_id==mouse_id)&\n", - " (behavior_sessions.equipment_name=='CAM2P.3')]\n", + "mouse_ophys_sessions = behavior_sessions[\n", + " (behavior_sessions.mouse_id == mouse_id) & (behavior_sessions.equipment_name == \"CAM2P.3\")\n", + "]\n", "# only look at the relevant columns\n", - "mouse_ophys_sessions.sort_values(by='date_of_acquisition')[['session_type', 'date_of_acquisition', 'ophys_session_id', 'prior_exposures_to_image_set']]" + "mouse_ophys_sessions.sort_values(by=\"date_of_acquisition\")[\n", + " [\"session_type\", \"date_of_acquisition\", \"ophys_session_id\", \"prior_exposures_to_image_set\"]\n", + "]" ] }, { @@ -4121,11 +4138,13 @@ } ], "source": [ - "mouse_id = 445002 \n", + "mouse_id = 445002\n", "# get behavior sessions that took place on the microscope\n", - "mouse_ophys_sessions = behavior_sessions[(behavior_sessions.mouse_id==mouse_id)]\n", + "mouse_ophys_sessions = behavior_sessions[(behavior_sessions.mouse_id == mouse_id)]\n", "# only look at the relevant columns\n", - "mouse_ophys_sessions.sort_values(by='date_of_acquisition')[['session_type', 'date_of_acquisition', 'ophys_session_id', 'prior_exposures_to_image_set']]" + "mouse_ophys_sessions.sort_values(by=\"date_of_acquisition\")[\n", + " [\"session_type\", \"date_of_acquisition\", \"ophys_session_id\", \"prior_exposures_to_image_set\"]\n", + "]" ] }, { @@ -4221,7 +4240,9 @@ } ], "source": [ - "ophys_sessions[ophys_sessions.mouse_id==mouse_id][['date_of_acquisition', 'session_type', 'prior_exposures_to_image_set']]" + "ophys_sessions[ophys_sessions.mouse_id == mouse_id][\n", + " [\"date_of_acquisition\", \"session_type\", \"prior_exposures_to_image_set\"]\n", + "]" ] }, { @@ -4355,10 +4376,13 @@ "# pick a mouse\n", "mouse_id = 456915\n", "# get behavior sessions that took place on the microscope\n", - "mouse_ophys_sessions = behavior_sessions[(behavior_sessions.mouse_id==mouse_id)&\n", - " (behavior_sessions.equipment_name=='MESO.1')]\n", + "mouse_ophys_sessions = behavior_sessions[\n", + " (behavior_sessions.mouse_id == mouse_id) & (behavior_sessions.equipment_name == \"MESO.1\")\n", + "]\n", "# only look at the relevant columns\n", - "mouse_ophys_sessions.sort_values(by='date_of_acquisition')[['date_of_acquisition', 'session_type', 'ophys_session_id', 'prior_exposures_to_session_type']]" + "mouse_ophys_sessions.sort_values(by=\"date_of_acquisition\")[\n", + " [\"date_of_acquisition\", \"session_type\", \"ophys_session_id\", \"prior_exposures_to_session_type\"]\n", + "]" ] }, { @@ -4434,7 +4458,9 @@ } ], "source": [ - "ophys_sessions[ophys_sessions.mouse_id==mouse_id][['date_of_acquisition', 'session_type', 'prior_exposures_to_session_type']]" + "ophys_sessions[ophys_sessions.mouse_id == mouse_id][\n", + " [\"date_of_acquisition\", \"session_type\", \"prior_exposures_to_session_type\"]\n", + "]" ] }, { @@ -4579,7 +4605,7 @@ } ], "source": [ - "np.sort(behavior_sessions[behavior_sessions.equipment_name=='CAM2P.4'].mouse_id.unique())" + "np.sort(behavior_sessions[behavior_sessions.equipment_name == \"CAM2P.4\"].mouse_id.unique())" ] }, { @@ -4654,10 +4680,13 @@ "# pick a mouse\n", "mouse_id = 436662\n", "# get behavior sessions that took place on the microscope\n", - "mouse_ophys_sessions = behavior_sessions[(behavior_sessions.mouse_id==mouse_id)&\n", - " (behavior_sessions.equipment_name=='CAM2P.4')]\n", + "mouse_ophys_sessions = behavior_sessions[\n", + " (behavior_sessions.mouse_id == mouse_id) & (behavior_sessions.equipment_name == \"CAM2P.4\")\n", + "]\n", "# only look at the relevant columns\n", - "mouse_ophys_sessions.sort_values(by='date_of_acquisition')[['date_of_acquisition', 'session_type', 'ophys_session_id', 'equipment_name', 'prior_exposures_to_omissions']]" + "mouse_ophys_sessions.sort_values(by=\"date_of_acquisition\")[\n", + " [\"date_of_acquisition\", \"session_type\", \"ophys_session_id\", \"equipment_name\", \"prior_exposures_to_omissions\"]\n", + "]" ] }, { @@ -4772,9 +4801,11 @@ "# pick a mouse\n", "mouse_id = 423606\n", "# get behavior sessions - include training as well\n", - "mouse_ophys_sessions = behavior_sessions[(behavior_sessions.mouse_id==mouse_id)]\n", + "mouse_ophys_sessions = behavior_sessions[(behavior_sessions.mouse_id == mouse_id)]\n", "# only look at the relevant columns\n", - "mouse_ophys_sessions.sort_values(by='date_of_acquisition')[['date_of_acquisition', 'session_type', 'ophys_session_id', 'equipment_name', 'prior_exposures_to_omissions']]" + "mouse_ophys_sessions.sort_values(by=\"date_of_acquisition\")[\n", + " [\"date_of_acquisition\", \"session_type\", \"ophys_session_id\", \"equipment_name\", \"prior_exposures_to_omissions\"]\n", + "]" ] }, { @@ -4830,15 +4861,19 @@ } ], "source": [ - "# get all behavior sessions that were habituation sessions (image set A or B) \n", + "# get all behavior sessions that were habituation sessions (image set A or B)\n", "# where the prior exposures to omissions was not zero\n", - "habituation_with_omission = behavior_sessions[((behavior_sessions.session_type=='OPHYS_0_images_A_habituation')|\n", - " (behavior_sessions.session_type=='OPHYS_0_images_B_habituation'))&\n", - " (behavior_sessions.prior_exposures_to_omissions>0)]\n", + "habituation_with_omission = behavior_sessions[\n", + " (\n", + " (behavior_sessions.session_type == \"OPHYS_0_images_A_habituation\")\n", + " | (behavior_sessions.session_type == \"OPHYS_0_images_B_habituation\")\n", + " )\n", + " & (behavior_sessions.prior_exposures_to_omissions > 0)\n", + "]\n", "\n", "mice_with_omission_during_habituation = habituation_with_omission.mouse_id.unique()\n", "\n", - "print(len(mice_with_omission_during_habituation), ' mice had omissions during habituation')" + "print(len(mice_with_omission_during_habituation), \" mice had omissions during habituation\")" ] }, { @@ -5124,7 +5159,7 @@ " \n", " \n", "\n", - "

5 rows \u00d7 30 columns

\n", + "

5 rows × 30 columns

\n", "" ], "text/plain": [ @@ -5430,11 +5465,15 @@ "source": [ "# loop through project codes and print the available imaging_depths and targeted_structures\n", "for project_code in ophys_experiments.project_code.unique():\n", - " \n", - " project_experiments = ophys_experiments[ophys_experiments.project_code==project_code]\n", - " print('\\nimaging_depths available for', project_code, 'include: ', project_experiments.imaging_depth.unique())\n", - " print('\\ntargeted_structures available for', project_code, 'include: ', project_experiments.targeted_structure.unique())\n", - " print('\\n')" + " project_experiments = ophys_experiments[ophys_experiments.project_code == project_code]\n", + " print(\"\\nimaging_depths available for\", project_code, \"include: \", project_experiments.imaging_depth.unique())\n", + " print(\n", + " \"\\ntargeted_structures available for\",\n", + " project_code,\n", + " \"include: \",\n", + " project_experiments.targeted_structure.unique(),\n", + " )\n", + " print(\"\\n\")" ] }, { @@ -5825,7 +5864,7 @@ " \n", " \n", "\n", - "

5 rows \u00d7 30 columns

\n", + "

5 rows × 30 columns

\n", "" ], "text/plain": [ @@ -5918,7 +5957,7 @@ } ], "source": [ - "container_experiments = ophys_experiments[ophys_experiments.ophys_container_id==ophys_container_id]\n", + "container_experiments = ophys_experiments[ophys_experiments.ophys_container_id == ophys_container_id]\n", "container_experiments" ] }, @@ -6286,8 +6325,9 @@ ], "source": [ "# get all Sst experiments in the relevant project code\n", - "sst_experiments = ophys_experiments[(ophys_experiments.cre_line=='Sst-IRES-Cre')&\n", - " (ophys_experiments.project_code=='VisualBehaviorTask1B')]\n", + "sst_experiments = ophys_experiments[\n", + " (ophys_experiments.cre_line == \"Sst-IRES-Cre\") & (ophys_experiments.project_code == \"VisualBehaviorTask1B\")\n", + "]\n", "\n", "# pick some container from this set\n", "ophys_container_id = sst_experiments.ophys_container_id.unique()[1]\n", @@ -6533,7 +6573,7 @@ " \n", " \n", "\n", - "

6 rows \u00d7 30 columns

\n", + "

6 rows × 30 columns

\n", "" ], "text/plain": [ @@ -6636,8 +6676,8 @@ } ], "source": [ - "# what experiments are there for this container? \n", - "sst_container_experiments = sst_experiments[sst_experiments.ophys_container_id==ophys_container_id]\n", + "# what experiments are there for this container?\n", + "sst_container_experiments = sst_experiments[sst_experiments.ophys_container_id == ophys_container_id]\n", "sst_container_experiments" ] }, @@ -6717,12 +6757,12 @@ "name": "stderr", "output_type": "stream", "text": [ - "behavior_ophys_experiment_955276580.nwb: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 313M/313M [00:13<00:00, 22.8MMB/s]\n", - "behavior_ophys_experiment_956903375.nwb: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 322M/322M [00:13<00:00, 23.4MMB/s]\n", - "behavior_ophys_experiment_957652800.nwb: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 318M/318M [00:13<00:00, 22.8MMB/s]\n", - "behavior_ophys_experiment_959337347.nwb: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 342M/342M [00:14<00:00, 22.8MMB/s]\n", - "behavior_ophys_experiment_960351917.nwb: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 327M/327M [00:14<00:00, 22.6MMB/s]\n", - "behavior_ophys_experiment_960960480.nwb: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 332M/332M [00:18<00:00, 18.0MMB/s]\n" + "behavior_ophys_experiment_955276580.nwb: 100%|██████████| 313M/313M [00:13<00:00, 22.8MMB/s]\n", + "behavior_ophys_experiment_956903375.nwb: 100%|██████████| 322M/322M [00:13<00:00, 23.4MMB/s]\n", + "behavior_ophys_experiment_957652800.nwb: 100%|██████████| 318M/318M [00:13<00:00, 22.8MMB/s]\n", + "behavior_ophys_experiment_959337347.nwb: 100%|██████████| 342M/342M [00:14<00:00, 22.8MMB/s]\n", + "behavior_ophys_experiment_960351917.nwb: 100%|██████████| 327M/327M [00:14<00:00, 22.6MMB/s]\n", + "behavior_ophys_experiment_960960480.nwb: 100%|██████████| 332M/332M [00:18<00:00, 18.0MMB/s]\n" ] }, { @@ -6741,13 +6781,13 @@ "ophys_experiment_ids = sst_container_experiments.index.values\n", "\n", "# create figure axis\n", - "fig, ax = plt.subplots(1, len(ophys_experiment_ids), figsize=(20,5))\n", + "fig, ax = plt.subplots(1, len(ophys_experiment_ids), figsize=(20, 5))\n", "# enumerate over experiments in this container\n", - "for i, ophys_experiment_id in enumerate(ophys_experiment_ids): \n", + "for i, ophys_experiment_id in enumerate(ophys_experiment_ids):\n", " # get the dataset object\n", " dataset = cache.get_behavior_ophys_experiment(ophys_experiment_id=ophys_experiment_id)\n", " # get the max intensity projection and plot on the appropriate axis\n", - " ax[i].imshow(dataset.max_projection.data, cmap='gray')\n", + " ax[i].imshow(dataset.max_projection.data, cmap=\"gray\")\n", " ax[i].set_title(ophys_experiment_id)" ] }, @@ -6817,9 +6857,11 @@ "outputs": [], "source": [ "# get all Vip sessions in the Multiscope project code\n", - "vip_sessions = ophys_sessions[(ophys_sessions.cre_line=='Vip-IRES-Cre')&\n", - " (ophys_sessions.project_code=='VisualBehaviorMultiscope')&\n", - " (ophys_sessions.prior_exposures_to_image_set==0)]\n", + "vip_sessions = ophys_sessions[\n", + " (ophys_sessions.cre_line == \"Vip-IRES-Cre\")\n", + " & (ophys_sessions.project_code == \"VisualBehaviorMultiscope\")\n", + " & (ophys_sessions.prior_exposures_to_image_set == 0)\n", + "]\n", "\n", "# ophys_session_id is the index of the ophys_session_table\n", "ophys_session_id = vip_sessions.index.values[0]" @@ -6976,12 +7018,12 @@ "name": "stderr", "output_type": "stream", "text": [ - "behavior_ophys_experiment_1050762966.nwb: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 248M/248M [00:10<00:00, 22.6MMB/s]\n", - "behavior_ophys_experiment_1050762969.nwb: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 258M/258M [00:11<00:00, 23.3MMB/s]\n", - "behavior_ophys_experiment_1050762972.nwb: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 251M/251M [00:10<00:00, 25.1MMB/s]\n", - "behavior_ophys_experiment_1050762974.nwb: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 251M/251M [00:10<00:00, 23.6MMB/s]\n", - "behavior_ophys_experiment_1050762975.nwb: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 259M/259M [00:10<00:00, 25.0MMB/s]\n", - "behavior_ophys_experiment_1050762977.nwb: 100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 250M/250M [00:10<00:00, 23.9MMB/s]\n" + "behavior_ophys_experiment_1050762966.nwb: 100%|██████████| 248M/248M [00:10<00:00, 22.6MMB/s]\n", + "behavior_ophys_experiment_1050762969.nwb: 100%|██████████| 258M/258M [00:11<00:00, 23.3MMB/s]\n", + "behavior_ophys_experiment_1050762972.nwb: 100%|██████████| 251M/251M [00:10<00:00, 25.1MMB/s]\n", + "behavior_ophys_experiment_1050762974.nwb: 100%|██████████| 251M/251M [00:10<00:00, 23.6MMB/s]\n", + "behavior_ophys_experiment_1050762975.nwb: 100%|██████████| 259M/259M [00:10<00:00, 25.0MMB/s]\n", + "behavior_ophys_experiment_1050762977.nwb: 100%|██████████| 250M/250M [00:10<00:00, 23.9MMB/s]\n" ] }, { @@ -7007,9 +7049,9 @@ ], "source": [ "# create figure axis\n", - "fig, ax = plt.subplots(1,1, figsize=(15,4))\n", + "fig, ax = plt.subplots(1, 1, figsize=(15, 4))\n", "# enumerate over experiments in this session\n", - "for i, ophys_experiment_id in enumerate(ophys_experiment_ids): \n", + "for i, ophys_experiment_id in enumerate(ophys_experiment_ids):\n", " # get the dataset object\n", " dataset = cache.get_behavior_ophys_experiment(ophys_experiment_id=ophys_experiment_id)\n", " # get ophys timestamps\n", @@ -7022,14 +7064,14 @@ " # take the mean over the cell axis\n", " average_dFF = np.mean(dff_traces, axis=0)\n", " # get the imaging_depth and targeted_structure for this experiment\n", - " imaging_depth = dataset.metadata['imaging_depth']\n", - " targeted_structure = dataset.metadata['targeted_structure']\n", + " imaging_depth = dataset.metadata[\"imaging_depth\"]\n", + " targeted_structure = dataset.metadata[\"targeted_structure\"]\n", " # plot it, including the imaging_depth and targeted_structure in the legend label\n", - " ax.plot(ophys_timestamps, average_dFF, label=targeted_structure+'_'+str(imaging_depth))\n", - " ax.set_title(dataset.metadata['cre_line']+', ophys_session_id: '+str(ophys_session_id))\n", - "ax.set_ylabel('dF/F')\n", - "ax.set_xlabel('time (seconds)')\n", - "ax.set_xlim(5*60, 10*60)\n", + " ax.plot(ophys_timestamps, average_dFF, label=targeted_structure + \"_\" + str(imaging_depth))\n", + " ax.set_title(dataset.metadata[\"cre_line\"] + \", ophys_session_id: \" + str(ophys_session_id))\n", + "ax.set_ylabel(\"dF/F\")\n", + "ax.set_xlabel(\"time (seconds)\")\n", + "ax.set_xlim(5 * 60, 10 * 60)\n", "ax.legend()" ] }, @@ -7226,10 +7268,10 @@ } ], "source": [ - "cell_per_exp = cells_table.groupby('ophys_experiment_id').size()\n", + "cell_per_exp = cells_table.groupby(\"ophys_experiment_id\").size()\n", "fig = plt.hist(cell_per_exp, bins=50)\n", - "plt.xlabel('Cell count')\n", - "plt.ylabel('Number of experiments')\n", + "plt.xlabel(\"Cell count\")\n", + "plt.ylabel(\"Number of experiments\")\n", "plt.show()\n", "cell_per_exp.describe()" ] @@ -7273,7 +7315,7 @@ }, "outputs": [], "source": [ - "ophys_experiments['n_cells'] = ophys_experiments.index.map(cell_per_exp)" + "ophys_experiments[\"n_cells\"] = ophys_experiments.index.map(cell_per_exp)" ] }, { @@ -7327,9 +7369,9 @@ ], "source": [ "fig, ax = plt.subplots(figsize=(30, 10))\n", - "ax.scatter(ophys_experiments['imaging_depth'], ophys_experiments['n_cells'], alpha=.3)\n", - "ax.set_xlabel('Imaging depth (microns)')\n", - "ax.set_ylabel('Cell count')\n", + "ax.scatter(ophys_experiments[\"imaging_depth\"], ophys_experiments[\"n_cells\"], alpha=0.3)\n", + "ax.set_xlabel(\"Imaging depth (microns)\")\n", + "ax.set_ylabel(\"Cell count\")\n", "plt.show()" ] }, @@ -7394,7 +7436,8 @@ ], "source": [ "import seaborn as sns\n", - "sns.boxplot(data=ophys_experiments, x='n_cells', y='cre_line')" + "\n", + "sns.boxplot(data=ophys_experiments, x=\"n_cells\", y=\"cre_line\")" ] } ], diff --git a/doc_template/examples_root/examples/simple/simple.py b/doc_template/examples_root/examples/simple/simple.py index 0db8fd5be7..18fc4add00 100644 --- a/doc_template/examples_root/examples/simple/simple.py +++ b/doc_template/examples_root/examples/simple/simple.py @@ -2,7 +2,7 @@ from utils import Utils import numpy -config = Config().load('config.json') +config = Config().load("config.json") # configure NEURON utils = Utils(config) @@ -30,11 +30,11 @@ # scaling mV = 1.0e-3 ms = 1.0e-3 -output_data = numpy.array(vec['v']) * mV -output_times = numpy.array(vec['t']) * ms +output_data = numpy.array(vec["v"]) * mV +output_times = numpy.array(vec["t"]) * ms output = numpy.column_stack((output_times, output_data)) # write to a dat File v_out_path = manifest.get_path("output_dat") -with open (v_out_path, "w") as f: - numpy.savetxt(f, output) \ No newline at end of file +with open(v_out_path, "w") as f: + numpy.savetxt(f, output) diff --git a/doc_template/examples_root/examples/simple/utils.py b/doc_template/examples_root/examples/simple/utils.py index 9ac58a0b1b..beb290f1d5 100644 --- a/doc_template/examples_root/examples/simple/utils.py +++ b/doc_template/examples_root/examples/simple/utils.py @@ -4,41 +4,41 @@ class Utils(HocUtils): _log = logging.getLogger(__name__) - + def __init__(self, description): super(Utils, self).__init__(description) self.stim = None self.stim_curr = None self.sampling_rate = None - + def generate_morphology(self): h = self.h self.soma = h.Section() self.soma.L = 10.0 self.soma.diam = 10.0 - + def load_cell_parameters(self): - passive = self.description.data['passive'][0] - channel_parameters = self.description.data['channel_parameters'][0] - conditions = self.description.data['conditions'][0] - + passive = self.description.data["passive"][0] + channel_parameters = self.description.data["channel_parameters"][0] + conditions = self.description.data["conditions"][0] + # Insert channels - self.soma.insert('pas') - self.soma.insert('NaTs') - self.soma.insert('K_P') + self.soma.insert("pas") + self.soma.insert("NaTs") + self.soma.insert("K_P") # Set fixed passive properties - self.soma.Ra = passive['ra'] + self.soma.Ra = passive["ra"] self.soma.cm = passive["cm"] for seg in self.soma: seg.pas.e = passive["e_pas"] seg.pas.g = channel_parameters["g_pas"] - + # Set active channel densities for seg in self.soma: seg.NaTs.gbar = channel_parameters["gbar_Na"] seg.K_P.gbar = channel_parameters["gbar_K"] - + # Set reversal potentials self.soma.ena = conditions["erev"][0]["ena"] self.soma.ek = conditions["erev"][0]["ek"] @@ -50,11 +50,9 @@ def setup_iclamp(self): self.stim.dur = 2.0 def record_values(self): - vec = { "v": self.h.Vector(), - "t": self.h.Vector() } - + vec = {"v": self.h.Vector(), "t": self.h.Vector()} + vec["v"].record(self.soma(0.5)._ref_v) vec["t"].record(self.h._ref_t) - + return vec - \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 6343c48445..18a4861b0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,6 +108,9 @@ notebooks = [ packages = ["allensdk"] exclude = ["allensdk/test", "allensdk/test_utilities"] +[tool.black] +line-length = 120 + [tool.ruff] line-length = 120 diff --git a/scripts/brain_observatory/create_input_json.py b/scripts/brain_observatory/create_input_json.py index 3ceb9400db..3cf67fa2c8 100644 --- a/scripts/brain_observatory/create_input_json.py +++ b/scripts/brain_observatory/create_input_json.py @@ -8,245 +8,234 @@ import datetime import re -def createInputJson(directory, resort_directory, module, output_file, last_unit_id): +def createInputJson(directory, resort_directory, module, output_file, last_unit_id): session_id = basename(directory) + sync_file = glob(join(directory, "*.sync"))[0] - sync_file = glob(join(directory, '*.sync'))[0] - - LIMS_session_id = os.path.basename(sync_file).split('.')[0][:9] + LIMS_session_id = os.path.basename(sync_file).split(".")[0][:9] - nwb_output_path = join('/mnt/nvme0/ecephys_nwb_files_20190827-2', session_id + '.spikes.nwb2') + nwb_output_path = join("/mnt/nvme0/ecephys_nwb_files_20190827-2", session_id + ".spikes.nwb2") - join(directory, 'stim_table.csv') + join(directory, "stim_table.csv") - #print(resort_directory) + # print(resort_directory) - probe_directories = glob(join(resort_directory, '*probe*','*probe*')) + probe_directories = glob(join(resort_directory, "*probe*", "*probe*")) - #print(probe_directories) + # print(probe_directories) probe_directories.sort() probes = [] for probe_idx, probe_directory in enumerate(probe_directories): - print(probe_directory) original_probe_directory = os.path.join(directory, os.path.basename(probe_directory)) name = probe_directory[-13:-7] - new_sorting_directory = glob(join(probe_directory, 'continuous', 'Neuropix-*-100.0'))[0] - original_sorting_directory = glob(join(original_probe_directory, 'continuous', 'Neuropix-*-100.0'))[0] + new_sorting_directory = glob(join(probe_directory, "continuous", "Neuropix-*-100.0"))[0] + original_sorting_directory = glob(join(original_probe_directory, "continuous", "Neuropix-*-100.0"))[0] - if original_sorting_directory.find('PXI') > -1: - probe_type = 'PXI' + if original_sorting_directory.find("PXI") > -1: + probe_type = "PXI" else: - probe_type = '3a' - + probe_type = "3a" - #lfp_dict = { + # lfp_dict = { # 'input_data_path' : 'none', # 'input_timestamps_path' : 'none', # 'input_channels_path' : 'none', # 'output_path' : 'none' - #} + # } timestamp_files = [] timestamp_files.append( - { 'name' : 'spike_timestamps', - 'input_path' : join(new_sorting_directory,'spike_times.npy'), - 'output_path' : join(new_sorting_directory, 'spike_times_master_clock.npy'), + { + "name": "spike_timestamps", + "input_path": join(new_sorting_directory, "spike_times.npy"), + "output_path": join(new_sorting_directory, "spike_times_master_clock.npy"), } ) - #timestamp_files.append( + # timestamp_files.append( # { 'name' : 'lfp_timestamps', # 'input_path' : join(new_sorting_directory,'spike_times.npy'), # 'output_path' : join(new_sorting_directory, 'spike_times_master_clock.npy'), # } - #) + # ) - - if module == 'allensdk.brain_observatory.ecephys.align_timestamps': + if module == "allensdk.brain_observatory.ecephys.align_timestamps": probe_dict = { - 'name' : name, - 'sampling_rate' : 30000., - 'lfp_sampling_rate' : 2500., - 'barcode_channel_states_path' : join(original_probe_directory, 'events', 'Neuropix-' + probe_type + '-100.0', 'TTL_1', 'channel_states.npy'), - 'barcode_timestamps_path' : join(original_probe_directory, 'events', 'Neuropix-' + probe_type + '-100.0', 'TTL_1', 'event_timestamps.npy'), - 'mappable_timestamp_files' : timestamp_files, + "name": name, + "sampling_rate": 30000.0, + "lfp_sampling_rate": 2500.0, + "barcode_channel_states_path": join( + original_probe_directory, + "events", + "Neuropix-" + probe_type + "-100.0", + "TTL_1", + "channel_states.npy", + ), + "barcode_timestamps_path": join( + original_probe_directory, + "events", + "Neuropix-" + probe_type + "-100.0", + "TTL_1", + "event_timestamps.npy", + ), + "mappable_timestamp_files": timestamp_files, } else: - channel_info = pd.read_csv(join(original_sorting_directory, 'ccf_regions_new.csv'), index_col=0) + channel_info = pd.read_csv(join(original_sorting_directory, "ccf_regions_new.csv"), index_col=0) channels = [] for idx, row in channel_info.iterrows(): + structure_acronym = row["structure_acronym"] + numbers = re.findall(r"\d+", structure_acronym) - structure_acronym = row['structure_acronym'] - numbers = re.findall(r'\d+', structure_acronym) - - if (len(numbers) > 0 and name[:2] != 'CA'): + if len(numbers) > 0 and name[:2] != "CA": structure_acronym = structure_acronym.split(numbers[0])[0] - cortical_layer = '/'.join(numbers) + cortical_layer = "/".join(numbers) else: - cortical_layer = 'none' + cortical_layer = "none" channel_dict = { - 'id' : idx + probe_idx * 1000, - 'valid_data' : row['is_valid'], - 'probe_id' : probe_idx, - 'local_index' : idx, - 'probe_vertical_position' : row['vertical_position'], - 'probe_horizontal_position' : row['horizontal_position'], - 'structure_id' : row['structure_id'], - 'cortical_layer' : cortical_layer, - 'structure_acronym' : structure_acronym, - 'AP_coordinate' : row['A/P'], - 'DV_coordinate' : row['D/V'], - 'ML_coordinate' : row['M/L'], - 'cortical_depth' : row['cortical_depth'] + "id": idx + probe_idx * 1000, + "valid_data": row["is_valid"], + "probe_id": probe_idx, + "local_index": idx, + "probe_vertical_position": row["vertical_position"], + "probe_horizontal_position": row["horizontal_position"], + "structure_id": row["structure_id"], + "cortical_layer": cortical_layer, + "structure_acronym": structure_acronym, + "AP_coordinate": row["A/P"], + "DV_coordinate": row["D/V"], + "ML_coordinate": row["M/L"], + "cortical_depth": row["cortical_depth"], } channels.append(channel_dict) - unit_info = pd.read_csv(join(new_sorting_directory, 'metrics.csv.v2'), index_col=0) - #unit_quality = pd.read_csv(join(new_sorting_directory, 'cluster_group.tsv'), index_col=0, sep='\t') - #unit_quality = unit_quality.replace(to_replace='unsorted',value='good') + unit_info = pd.read_csv(join(new_sorting_directory, "metrics.csv.v2"), index_col=0) + # unit_quality = pd.read_csv(join(new_sorting_directory, 'cluster_group.tsv'), index_col=0, sep='\t') + # unit_quality = unit_quality.replace(to_replace='unsorted',value='good') units = [] print(len(unit_info)) for idx, row in unit_info.iterrows(): - - if row['quality'] == 'good': - - unit_dict = { - 'id' : last_unit_id, - 'peak_channel_id' : row['peak_channel'] + probe_idx * 1000, - 'local_index' : idx, - 'cluster_id' : row['cluster_id'], - 'quality' : row['quality'], - 'firing_rate' : cleanUpNanAndInf(row['firing_rate']), - 'snr' : cleanUpNanAndInf(row['snr']), - 'isi_violations' : cleanUpNanAndInf(row['isi_viol']), - 'presence_ratio' : cleanUpNanAndInf(row['presence_ratio']), - 'amplitude_cutoff' : cleanUpNanAndInf(row['amplitude_cutoff']), - 'isolation_distance' : cleanUpNanAndInf(row['isolation_distance']), - 'l_ratio' : cleanUpNanAndInf(row['l_ratio']), - 'd_prime' : cleanUpNanAndInf(row['d_prime']), - 'nn_hit_rate' : cleanUpNanAndInf(row['nn_hit_rate']), - 'nn_miss_rate' : cleanUpNanAndInf(row['nn_miss_rate']), - 'max_drift' : cleanUpNanAndInf(row['max_drift']), - 'cumulative_drift' : cleanUpNanAndInf(row['cumulative_drift']), - 'silhouette_score' : cleanUpNanAndInf(row['silhouette_score']), - 'waveform_duration' : cleanUpNanAndInf(row['duration']), - 'waveform_halfwidth' : cleanUpNanAndInf(row['halfwidth']), - 'waveform_PT_ratio' : cleanUpNanAndInf(row['PT_ratio']), - 'waveform_repolarization_slope' : cleanUpNanAndInf(row['repolarization_slope']), - 'waveform_recovery_slope' : cleanUpNanAndInf(row['recovery_slope']), - 'waveform_amplitude' : cleanUpNanAndInf(row['amplitude']), - 'waveform_spread' : cleanUpNanAndInf(row['spread']), - 'waveform_velocity_above' : cleanUpNanAndInf(row['velocity_above']), - 'waveform_velocity_below' : cleanUpNanAndInf(row['velocity_below']) - } - - #if channel_info.loc[row['peak_channel']]['structure_acronym'] == 'VISp5': - units.append(unit_dict) - last_unit_id += 1 - - #print(len(unit_info)) + if row["quality"] == "good": + unit_dict = { + "id": last_unit_id, + "peak_channel_id": row["peak_channel"] + probe_idx * 1000, + "local_index": idx, + "cluster_id": row["cluster_id"], + "quality": row["quality"], + "firing_rate": cleanUpNanAndInf(row["firing_rate"]), + "snr": cleanUpNanAndInf(row["snr"]), + "isi_violations": cleanUpNanAndInf(row["isi_viol"]), + "presence_ratio": cleanUpNanAndInf(row["presence_ratio"]), + "amplitude_cutoff": cleanUpNanAndInf(row["amplitude_cutoff"]), + "isolation_distance": cleanUpNanAndInf(row["isolation_distance"]), + "l_ratio": cleanUpNanAndInf(row["l_ratio"]), + "d_prime": cleanUpNanAndInf(row["d_prime"]), + "nn_hit_rate": cleanUpNanAndInf(row["nn_hit_rate"]), + "nn_miss_rate": cleanUpNanAndInf(row["nn_miss_rate"]), + "max_drift": cleanUpNanAndInf(row["max_drift"]), + "cumulative_drift": cleanUpNanAndInf(row["cumulative_drift"]), + "silhouette_score": cleanUpNanAndInf(row["silhouette_score"]), + "waveform_duration": cleanUpNanAndInf(row["duration"]), + "waveform_halfwidth": cleanUpNanAndInf(row["halfwidth"]), + "waveform_PT_ratio": cleanUpNanAndInf(row["PT_ratio"]), + "waveform_repolarization_slope": cleanUpNanAndInf(row["repolarization_slope"]), + "waveform_recovery_slope": cleanUpNanAndInf(row["recovery_slope"]), + "waveform_amplitude": cleanUpNanAndInf(row["amplitude"]), + "waveform_spread": cleanUpNanAndInf(row["spread"]), + "waveform_velocity_above": cleanUpNanAndInf(row["velocity_above"]), + "waveform_velocity_below": cleanUpNanAndInf(row["velocity_below"]), + } + + # if channel_info.loc[row['peak_channel']]['structure_acronym'] == 'VISp5': + units.append(unit_dict) + last_unit_id += 1 + + # print(len(unit_info)) probe_dict = { - 'id' : probe_idx, - 'name' : name, - 'spike_times_path' : join(new_sorting_directory, 'spike_times_master_clock.npy'), - 'spike_clusters_file' : join(new_sorting_directory, 'spike_clusters.npy'), - 'mean_waveforms_path' : join(new_sorting_directory, 'mean_waveforms.npy'), - 'channels' : channels, - 'units' : units, + "id": probe_idx, + "name": name, + "spike_times_path": join(new_sorting_directory, "spike_times_master_clock.npy"), + "spike_clusters_file": join(new_sorting_directory, "spike_clusters.npy"), + "mean_waveforms_path": join(new_sorting_directory, "mean_waveforms.npy"), + "channels": channels, + "units": units, #'lfp' : lfp_dict } probes.append(probe_dict) - - if module == 'allensdk.brain_observatory.ecephys.align_timestamps': - - dictionary = \ - { - 'sync_h5_path' : glob(join(directory, '*.sync'))[0], - "probes" : probes, + if module == "allensdk.brain_observatory.ecephys.align_timestamps": + dictionary = { + "sync_h5_path": glob(join(directory, "*.sync"))[0], + "probes": probes, } - elif module == 'allensdk.brain_observatory.ecephys.stimulus_table': - - dictionary = \ - { - 'stimulus_pkl_path' : glob(join(directory, '*.stim.pkl'))[0], - 'sync_h5_path' : glob(join(directory, '*.sync'))[0], - 'output_stimulus_table_path' : os.path.join(directory, 'stim_table_allensdk.csv'), - 'output_frame_times_path' : os.path.join(directory, 'frame_times.npy'), - - "log_level" : 'INFO' + elif module == "allensdk.brain_observatory.ecephys.stimulus_table": + dictionary = { + "stimulus_pkl_path": glob(join(directory, "*.stim.pkl"))[0], + "sync_h5_path": glob(join(directory, "*.sync"))[0], + "output_stimulus_table_path": os.path.join(directory, "stim_table_allensdk.csv"), + "output_frame_times_path": os.path.join(directory, "frame_times.npy"), + "log_level": "INFO", } - elif module == 'allensdk.brain_observatory.extract_running_speed': - - dictionary = \ - { - 'stimulus_pkl_path' : glob(join(directory, '*.stim.pkl'))[0], - 'sync_h5_path' : glob(join(directory, '*.sync'))[0], - - 'output_path' : join(directory, 'running_speed.h5'), - - "log_level" : 'INFO' + elif module == "allensdk.brain_observatory.extract_running_speed": + dictionary = { + "stimulus_pkl_path": glob(join(directory, "*.stim.pkl"))[0], + "sync_h5_path": glob(join(directory, "*.sync"))[0], + "output_path": join(directory, "running_speed.h5"), + "log_level": "INFO", } - elif module == 'allensdk.brain_observatory.ecephys.optotagging_table': - - dictionary = \ - { - 'opto_pickle_path' : glob(join(directory, '*.opto.pkl.v2'))[0], - 'sync_h5_path' : glob(join(directory, '*.sync'))[0], - 'output_opto_table_path' : join(directory, 'optotagging_table.csv') + elif module == "allensdk.brain_observatory.ecephys.optotagging_table": + dictionary = { + "opto_pickle_path": glob(join(directory, "*.opto.pkl.v2"))[0], + "sync_h5_path": glob(join(directory, "*.sync"))[0], + "output_opto_table_path": join(directory, "optotagging_table.csv"), } - elif module == 'allensdk.brain_observatory.ecephys.write_nwb': - + elif module == "allensdk.brain_observatory.ecephys.write_nwb": session_string = os.path.basename(probe_directories[0]) YYYY = int(session_string[17:21]) MM = int(session_string[21:23]) DD = int(session_string[23:25]) - dictionary = \ - { - "log_level" : 'INFO', - "output_path" : nwb_output_path, - "session_id" : int(LIMS_session_id), - "session_start_time" : datetime.datetime(YYYY, MM, DD, 0, 0, 0).isoformat(), - "stimulus_table_path" : os.path.join(directory, 'stim_table_allensdk.csv'), - "probes" : probes, - "running_speed_path" : join(directory, 'running_speed.h5')#, - # "optotagging_table_path" : join(directory, 'optotagging_table.csv') + dictionary = { + "log_level": "INFO", + "output_path": nwb_output_path, + "session_id": int(LIMS_session_id), + "session_start_time": datetime.datetime(YYYY, MM, DD, 0, 0, 0).isoformat(), + "stimulus_table_path": os.path.join(directory, "stim_table_allensdk.csv"), + "probes": probes, + "running_speed_path": join(directory, "running_speed.h5"), # , + # "optotagging_table_path" : join(directory, 'optotagging_table.csv') } - - with io.open(output_file, 'w', encoding='utf-8') as f: - f.write(json.dumps(dictionary, ensure_ascii=False, sort_keys=True, indent=4)) + with io.open(output_file, "w", encoding="utf-8") as f: + f.write(json.dumps(dictionary, ensure_ascii=False, sort_keys=True, indent=4)) return dictionary, last_unit_id - def cleanUpNanAndInf(value): - if np.isnan(value) or np.isinf(value): return -1 else: - return value \ No newline at end of file + return value diff --git a/scripts/brain_observatory/deploy_visual_coding_ophys_eye_tracking.py b/scripts/brain_observatory/deploy_visual_coding_ophys_eye_tracking.py index eb672509f0..1f651ec496 100644 --- a/scripts/brain_observatory/deploy_visual_coding_ophys_eye_tracking.py +++ b/scripts/brain_observatory/deploy_visual_coding_ophys_eye_tracking.py @@ -1,4 +1,5 @@ """Script to deploy add-on visual coding data to S3""" + import os import tempfile from argparse import ArgumentParser @@ -9,30 +10,21 @@ import allensdk import pandas as pd -from allensdk.brain_observatory.data_release_utils.metadata_utils\ - .id_generator import \ - FileIDGenerator -from allensdk.brain_observatory.data_release_utils.metadata_utils.utils \ - import \ - add_file_paths_to_metadata_table +from allensdk.brain_observatory.data_release_utils.metadata_utils.id_generator import FileIDGenerator +from allensdk.brain_observatory.data_release_utils.metadata_utils.utils import add_file_paths_to_metadata_table -parser = ArgumentParser('deploy visual coding ophys eye tracking data') -parser.add_argument( - '--data_path', - help='dir containing eye tracking data', - required=True) +parser = ArgumentParser("deploy visual coding ophys eye tracking data") +parser.add_argument("--data_path", help="dir containing eye tracking data", required=True) args = parser.parse_args() def main(): - files = [x for x in os.listdir(args.data_path) if Path(x).suffix == '.npy'] + files = [x for x in os.listdir(args.data_path) if Path(x).suffix == ".npy"] if len(files) > 0: - exp_ids = [x.replace('.npy', '') for x in files] + exp_ids = [x.replace(".npy", "") for x in files] for exp_id in exp_ids: os.makedirs(Path(args.data_path) / exp_id) - os.rename( - Path(args.data_path) / f'{exp_id}.npy', - Path(args.data_path) / exp_id / 'eye_tracking.npy') + os.rename(Path(args.data_path) / f"{exp_id}.npy", Path(args.data_path) / exp_id / "eye_tracking.npy") metadata_path = _write_metadata() pipeline_metadata = [] @@ -45,23 +37,25 @@ def main(): pipeline_metadata.append(sdk_metadata) with tempfile.TemporaryDirectory() as tmp_dir: - release_tool = DataReleaseTool(input_data={ - 'metadata_files': [metadata_path], - 'project_name': 'visual-coding-ophys', - 'data_pipeline_metadata': pipeline_metadata, - 'bucket_name': 'visual-coding-ophys-data', - 'remote_client': 'AWS_S3', - 'release_semver_type': 'minor', - 'staging_directory': tmp_dir - }, args=[]) + release_tool = DataReleaseTool( + input_data={ + "metadata_files": [metadata_path], + "project_name": "visual-coding-ophys", + "data_pipeline_metadata": pipeline_metadata, + "bucket_name": "visual-coding-ophys-data", + "remote_client": "AWS_S3", + "release_semver_type": "minor", + "staging_directory": tmp_dir, + }, + args=[], + ) release_tool.run() def _write_metadata(): - exp_ids = [x for x in os.listdir(args.data_path) - if (Path(args.data_path) / x).is_dir()] + exp_ids = [x for x in os.listdir(args.data_path) if (Path(args.data_path) / x).is_dir()] - metadata_table = pd.DataFrame({'ophys_experiment_id': exp_ids}) + metadata_table = pd.DataFrame({"ophys_experiment_id": exp_ids}) file_id_generator = FileIDGenerator() metadata_table = add_file_paths_to_metadata_table( @@ -71,14 +65,14 @@ def _write_metadata(): file_prefix=None, index_col="ophys_experiment_id", data_dir_col="ophys_experiment_id", - on_missing_file='error', - file_suffix='npy', - file_stem='eye_tracking' + on_missing_file="error", + file_suffix="npy", + file_stem="eye_tracking", ) - output_path = str(Path(args.data_path) / 'metadata.csv') + output_path = str(Path(args.data_path) / "metadata.csv") metadata_table.to_csv(output_path, index=False) return output_path -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/scripts/brain_observatory/run_ecephys_nwb_packaging.py b/scripts/brain_observatory/run_ecephys_nwb_packaging.py index cbed0831af..ba6f1e6ba1 100644 --- a/scripts/brain_observatory/run_ecephys_nwb_packaging.py +++ b/scripts/brain_observatory/run_ecephys_nwb_packaging.py @@ -6,66 +6,74 @@ from create_input_json import createInputJson import warnings + warnings.filterwarnings("ignore", message="numpy.dtype size changed") -#available_directories = glob.glob('/mnt/hdd0/RE-SORT/mouse*') +# available_directories = glob.glob('/mnt/hdd0/RE-SORT/mouse*') -df = pd.read_csv('/home/joshs/Documents/mouse_table.csv') +df = pd.read_csv("/home/joshs/Documents/mouse_table.csv") -#mice = list(df['Mouse'].values) -#mice = mice[12:] +# mice = list(df['Mouse'].values) +# mice = mice[12:] mice = [404551, 404553, 404555, 404568] -#mice = [int(name[-6:]) for name in available_directories] -# = [int(name) for name in mice] - +# mice = [int(name[-6:]) for name in available_directories] +# = [int(name) for name in mice] + -json_directory = '/mnt/md0/data/json_files' +json_directory = "/mnt/md0/data/json_files" -modules = [#'allensdk.brain_observatory.ecephys.align_timestamps', - #'allensdk.brain_observatory.ecephys.stimulus_table', - #'allensdk.brain_observatory.ecephys.optotagging_table', - #'allensdk.brain_observatory.extract_running_speed', #, - 'allensdk.brain_observatory.ecephys.write_nwb'] +modules = [ #'allensdk.brain_observatory.ecephys.align_timestamps', + #'allensdk.brain_observatory.ecephys.stimulus_table', + #'allensdk.brain_observatory.ecephys.optotagging_table', + #'allensdk.brain_observatory.extract_running_speed', #, + "allensdk.brain_observatory.ecephys.write_nwb" +] -data_directory = '/mnt/md0/data' -resort_directory = '/mnt/hdd0/RE-SORT' +data_directory = "/mnt/md0/data" +resort_directory = "/mnt/hdd0/RE-SORT" df = pd.DataFrame() last_unit_id = 0 for mouse in mice: + # try: + mouse_directory = data_directory + "/mouse" + str(mouse) - #try: - mouse_directory = data_directory + '/mouse' + str(mouse) - - if os.path.exists(mouse_directory): - probe_data_directory = resort_directory + '/mouse' + str(mouse) - - pkl_file = glob.glob(mouse_directory + '/*.stim.pkl')[0] - session_id = os.path.basename(pkl_file).split('.')[0] - - print(session_id) + if os.path.exists(mouse_directory): + probe_data_directory = resort_directory + "/mouse" + str(mouse) - print(mouse_directory) + pkl_file = glob.glob(mouse_directory + "/*.stim.pkl")[0] + session_id = os.path.basename(pkl_file).split(".")[0] - for module in modules: + print(session_id) - input_json = os.path.join(json_directory, session_id + '-' + module + '-input.json') - output_json = os.path.join(json_directory, session_id + '-' + module + '-output.json') + print(mouse_directory) - info, last_unit_id = createInputJson(mouse_directory, probe_data_directory, module, input_json, last_unit_id) - - if not os.path.exists(info['output_path']): + for module in modules: + input_json = os.path.join(json_directory, session_id + "-" + module + "-input.json") + output_json = os.path.join(json_directory, session_id + "-" + module + "-output.json") - print('Running ' + module) + info, last_unit_id = createInputJson( + mouse_directory, probe_data_directory, module, input_json, last_unit_id + ) - command_string = ["python", "-W", "ignore", "-m", module, - "--input_json", input_json, - "--output_json", output_json] + if not os.path.exists(info["output_path"]): + print("Running " + module) - subprocess.check_call(command_string) - #except: - # print('Error processing') + command_string = [ + "python", + "-W", + "ignore", + "-m", + module, + "--input_json", + input_json, + "--output_json", + output_json, + ] + subprocess.check_call(command_string) +# except: +# print('Error processing') diff --git a/scripts/brain_observatory/run_ecephys_stimulus_analysis.py b/scripts/brain_observatory/run_ecephys_stimulus_analysis.py index 0b1c7c4878..d2d9288e1c 100644 --- a/scripts/brain_observatory/run_ecephys_stimulus_analysis.py +++ b/scripts/brain_observatory/run_ecephys_stimulus_analysis.py @@ -5,86 +5,60 @@ import glob import pandas as pd -def createInputJson(output_file): - - pd.read_csv('/mnt/md0/data/production_QC/experiment_table_2019-08-27.csv',index_col=0) - - #fc_mice = np.sort(df[df['stimulus_set'].str.match('Functional')].index.values) - - nwb_files = glob.glob('/mnt/nvme0/ecephys_nwb_files_20190827/*.nwb2') #['/mnt/nvme0/ecephys_nwb_files_20190727/mouse' + str(mouse) + '.spikes.nwb2' for mouse in df.index.values] - - print('Found ' + str(len(nwb_files)) + ' nwb files') - - dictionary = { \ - - "drifting_gratings" : - { - "stimulus_key" : "drifting_gratings" - }, - - "static_gratings" : - { - "stimulus_key" : "static_gratings" - }, - - "natural_scenes" : - { - "stimulus_key" : "natural_scenes" - }, - - "natural_movies" : - { - "stimulus_key" : "natural_movies" - }, - "dot_motion" : - { - "stimulus_key" : "dot_motion" - }, - - "contrast_tuning" : - { - "stimulus_key" : "contrast_tuning" - }, - - - "flashes" : - { - "stimulus_key" : "flashes" - }, - - - "receptive_field_mapping" : - { - "stimulus_key" : "gabors", - "mask_threshold" : 1.0, - "minimum_spike_count" : 10 - }, - - "output_file" : '/mnt/md0/data/production_QC/stimulus_analysis_TEST20190805.csv', - - "nwb_paths" : nwb_files[:1] #['/mnt/nvme0/ecephys_nwb_files_20190727/mouse412804_integration_test_fc.spikes.nwb2', - #'/mnt/nvme0/ecephys_nwb_files_20190727/integration_test.spikes.nwb2']#' - #nwb_files[:20] - } - - with io.open(output_file, 'w', encoding='utf-8') as f: - f.write(json.dumps(dictionary, ensure_ascii=False, sort_keys=True, indent=4)) +def createInputJson(output_file): + pd.read_csv("/mnt/md0/data/production_QC/experiment_table_2019-08-27.csv", index_col=0) + + # fc_mice = np.sort(df[df['stimulus_set'].str.match('Functional')].index.values) + + nwb_files = glob.glob( + "/mnt/nvme0/ecephys_nwb_files_20190827/*.nwb2" + ) # ['/mnt/nvme0/ecephys_nwb_files_20190727/mouse' + str(mouse) + '.spikes.nwb2' for mouse in df.index.values] + + print("Found " + str(len(nwb_files)) + " nwb files") + + dictionary = { + "drifting_gratings": {"stimulus_key": "drifting_gratings"}, + "static_gratings": {"stimulus_key": "static_gratings"}, + "natural_scenes": {"stimulus_key": "natural_scenes"}, + "natural_movies": {"stimulus_key": "natural_movies"}, + "dot_motion": {"stimulus_key": "dot_motion"}, + "contrast_tuning": {"stimulus_key": "contrast_tuning"}, + "flashes": {"stimulus_key": "flashes"}, + "receptive_field_mapping": {"stimulus_key": "gabors", "mask_threshold": 1.0, "minimum_spike_count": 10}, + "output_file": "/mnt/md0/data/production_QC/stimulus_analysis_TEST20190805.csv", + "nwb_paths": nwb_files[ + :1 + ], # ['/mnt/nvme0/ecephys_nwb_files_20190727/mouse412804_integration_test_fc.spikes.nwb2', + #'/mnt/nvme0/ecephys_nwb_files_20190727/integration_test.spikes.nwb2']#' + # nwb_files[:20] + } + + with io.open(output_file, "w", encoding="utf-8") as f: + f.write(json.dumps(dictionary, ensure_ascii=False, sort_keys=True, indent=4)) return dictionary -json_directory = '/mnt/md0/data/json_files' +json_directory = "/mnt/md0/data/json_files" -module = 'stimulus_analysis' +module = "stimulus_analysis" -input_json = os.path.join(json_directory, module + '-input.json') -output_json = os.path.join(json_directory, module + '-output.json') +input_json = os.path.join(json_directory, module + "-input.json") +output_json = os.path.join(json_directory, module + "-output.json") info = createInputJson(input_json) -command_string = ["python", "-W", "ignore", "-m", "allensdk.brain_observatory.ecephys." + module, - "--input_json", input_json, - "--output_json", output_json] +command_string = [ + "python", + "-W", + "ignore", + "-m", + "allensdk.brain_observatory.ecephys." + module, + "--input_json", + input_json, + "--output_json", + output_json, +] subprocess.check_call(command_string) diff --git a/scripts/brain_observatory/run_session_analysis.py b/scripts/brain_observatory/run_session_analysis.py index 91201b3653..ceb8a40b18 100644 --- a/scripts/brain_observatory/run_session_analysis.py +++ b/scripts/brain_observatory/run_session_analysis.py @@ -9,7 +9,7 @@ def main(): parser.add_argument("input_nwb") parser.add_argument("output_h5") - parser.add_argument("--plot", action='store_true') + parser.add_argument("--plot", action="store_true") args = parser.parse_args() logging.basicConfig() @@ -18,5 +18,5 @@ def main(): run_session_analysis(args.input_nwb, args.output_h5, args.plot) -if __name__ == '__main__': +if __name__ == "__main__": main() From c9371ae4f112958aadc6a297e0e59c30ce3f38a0 Mon Sep 17 00:00:00 2001 From: Galen Lynch Date: Fri, 20 Feb 2026 16:30:18 -0800 Subject: [PATCH 2/2] Suppress many flake8 warnings that are stylistic MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add E742 and E743 to ignore list - ambiguous variable names are OK in math-y code F841 — Unused variables (8) These are all tuple unpacking with _-prefixed names like _i_prev, _t_prev, _r_prev and except Exception as _:. The convention communicates intent. F401 — Unused imports (14) These are re-exports in __init__.py files, moviepy.editor import * for side effects, and json_utilities re-export. Standard pattern. E721 — Type comparison (15) Most are type(self) != type(other) in __eq__ methods or exact type dispatch where isinstance would have different semantics (matching subclasses) --- .flake8 | 5 +- allensdk/api/queries/rma_template.py | 2 +- .../behavior_project_cache/__init__.py | 4 +- .../project_apis/abcs/__init__.py | 2 +- .../project_apis/data_io/__init__.py | 10 +-- .../behavior/data_objects/__init__.py | 4 +- .../ecephys/ecephys_session.py | 2 +- .../stimulus_analysis/drifting_gratings.py | 4 +- .../ecephys/stimulus_analysis/flashes.py | 2 +- .../stimulus_analysis/static_gratings.py | 2 +- .../eye_tracking/stage_1/DLC_Eye_Tracking.py | 2 +- .../stage_2/DLC_Ellipse_Fitting.py | 2 +- .../eye_tracking/stage_3/DLC_Labeled_Video.py | 2 +- allensdk/brain_observatory/nwb/metadata.py | 6 +- allensdk/brain_observatory/roi_masks.py | 2 +- .../brain_observatory/session_api_utils.py | 2 +- .../core/_data_object_base/data_object.py | 2 +- .../core/brain_observatory_nwb_data_set.py | 2 +- allensdk/core/cell_types_cache.py | 2 +- allensdk/ephys/ephys_extractor.py | 6 +- allensdk/ephys/feature_extractor.py | 2 +- .../internal/brain_observatory/itracker.py | 2 +- .../internal/ephys/core_feature_extract.py | 4 +- allensdk/internal/model/GLM.py | 2 +- .../internal/model/biophysical/fit_stage_1.py | 2 +- allensdk/internal/model/glif/ASGLM.py | 2 +- .../internal/model/glif/error_functions.py | 2 +- allensdk/internal/model/glif/find_sweeps.py | 4 +- .../internal/model/glif/glif_optimizer.py | 14 ++-- .../internal/model/glif/preprocess_neuron.py | 4 +- allensdk/internal/model/glif/spike_cutting.py | 2 +- .../model/glif/threshold_adaptation.py | 76 +++++++++---------- allensdk/internal/morphology/morphology.py | 4 +- .../morphology/surrogate_strategy.py | 2 +- .../morphology/upright_transform.py | 10 +-- .../run_observatory_thumbnails.py | 2 +- allensdk/model/glif/glif_neuron.py | 2 +- allensdk/test/api/test_svg_api.py | 14 ++-- allensdk/test/api/test_tree_search_api.py | 14 ++-- allensdk/test_utilities/custom_comparators.py | 6 +- pyproject.toml | 2 + 41 files changed, 119 insertions(+), 118 deletions(-) diff --git a/.flake8 b/.flake8 index a5929244c9..1ca6010485 100644 --- a/.flake8 +++ b/.flake8 @@ -2,11 +2,10 @@ max-line-length = 120 extend-ignore = E203, - E501, E402, E741, + E742, + E743, F403, F405, - W291, - W293, W503 diff --git a/allensdk/api/queries/rma_template.py b/allensdk/api/queries/rma_template.py index c876593c86..73d98ecc73 100644 --- a/allensdk/api/queries/rma_template.py +++ b/allensdk/api/queries/rma_template.py @@ -48,7 +48,7 @@ def __init__(self, base_uri=None, query_manifest=None): self.templates = query_manifest def to_filter_rhs(self, rhs): - if type(rhs) == list: + if type(rhs) == list: # noqa: E721 return ",".join(str(r) for r in rhs) return rhs diff --git a/allensdk/brain_observatory/behavior/behavior_project_cache/__init__.py b/allensdk/brain_observatory/behavior/behavior_project_cache/__init__.py index 198c6d8a37..8df7cc22e8 100644 --- a/allensdk/brain_observatory/behavior/behavior_project_cache/__init__.py +++ b/allensdk/brain_observatory/behavior/behavior_project_cache/__init__.py @@ -1,7 +1,7 @@ -from allensdk.brain_observatory.behavior.behavior_project_cache.behavior_project_cache import ( +from allensdk.brain_observatory.behavior.behavior_project_cache.behavior_project_cache import ( # noqa: F401 VisualBehaviorOphysProjectCache, ) # noqa F401 -from allensdk.brain_observatory.behavior.behavior_project_cache.behavior_neuropixels_project_cache import ( +from allensdk.brain_observatory.behavior.behavior_project_cache.behavior_neuropixels_project_cache import ( # noqa: F401 VisualBehaviorNeuropixelsProjectCache, ) # noqa F401 diff --git a/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/abcs/__init__.py b/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/abcs/__init__.py index 8a46b05196..a3cdfc897d 100644 --- a/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/abcs/__init__.py +++ b/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/abcs/__init__.py @@ -1,3 +1,3 @@ -from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.abcs.behavior_project_base import ( +from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.abcs.behavior_project_base import ( # noqa: F401 BehaviorProjectBase, ) # noqa: F401, E501 diff --git a/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/data_io/__init__.py b/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/data_io/__init__.py index e72177b134..0629bad157 100644 --- a/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/data_io/__init__.py +++ b/allensdk/brain_observatory/behavior/behavior_project_cache/project_apis/data_io/__init__.py @@ -1,15 +1,15 @@ -from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.data_io.behavior_project_lims_api import ( +from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.data_io.behavior_project_lims_api import ( # noqa: F401 BehaviorProjectLimsApi, ) # noqa: F401, E501 -from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.data_io.behavior_project_cloud_api import ( +from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.data_io.behavior_project_cloud_api import ( # noqa: F401 BehaviorProjectCloudApi, ) # noqa: F401, E501 -from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.data_io.behavior_neuropixels_project_cloud_api import ( +from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.data_io.behavior_neuropixels_project_cloud_api import ( # noqa: F401 VisualBehaviorNeuropixelsProjectCloudApi, ) # noqa: F401, E501 -from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.data_io.behavior_neuropixels_project_cloud_api import ( +from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.data_io.behavior_neuropixels_project_cloud_api import ( # noqa: F401 ProjectCloudApiBase, ) # noqa: F401, E501 -from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.data_io.natural_movie_one_cache import ( +from allensdk.brain_observatory.behavior.behavior_project_cache.project_apis.data_io.natural_movie_one_cache import ( # noqa: F401 NaturalMovieOneCache, ) # noqa: F401, E501 diff --git a/allensdk/brain_observatory/behavior/data_objects/__init__.py b/allensdk/brain_observatory/behavior/data_objects/__init__.py index 7bcc137971..124c671d90 100644 --- a/allensdk/brain_observatory/behavior/data_objects/__init__.py +++ b/allensdk/brain_observatory/behavior/data_objects/__init__.py @@ -1,7 +1,7 @@ -from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.behavior_session_id import ( +from allensdk.brain_observatory.behavior.data_objects.metadata.behavior_metadata.behavior_session_id import ( # noqa: F401 BehaviorSessionId, ) # noqa: E501, F401 -from allensdk.brain_observatory.behavior.data_objects.timestamps.stimulus_timestamps.stimulus_timestamps import ( +from allensdk.brain_observatory.behavior.data_objects.timestamps.stimulus_timestamps.stimulus_timestamps import ( # noqa: F401 StimulusTimestamps, ) # noqa: E501, F401 from allensdk.brain_observatory.behavior.data_objects.running_speed.running_speed import RunningSpeed # noqa: E501, F401 diff --git a/allensdk/brain_observatory/ecephys/ecephys_session.py b/allensdk/brain_observatory/ecephys/ecephys_session.py index 00b0ff5dfb..d3219fb3f6 100644 --- a/allensdk/brain_observatory/ecephys/ecephys_session.py +++ b/allensdk/brain_observatory/ecephys/ecephys_session.py @@ -1287,7 +1287,7 @@ def nan_intervals(array, nan_like=["null"]): def is_distinct_from(left, right): - if type(left) != type(right): + if type(left) != type(right): # noqa: E721 return True if pd.isna(left) and pd.isna(right): return False diff --git a/allensdk/brain_observatory/ecephys/stimulus_analysis/drifting_gratings.py b/allensdk/brain_observatory/ecephys/stimulus_analysis/drifting_gratings.py index 22aefff25c..8a26a3458d 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_analysis/drifting_gratings.py +++ b/allensdk/brain_observatory/ecephys/stimulus_analysis/drifting_gratings.py @@ -550,7 +550,7 @@ def _fit_tf_tuning(self, unit_id, pref_ori, pref_tf): return fit_tf_ind, fit_tf, tf_low_cutoff, tf_high_cutoff ''' - ## VISUALIZATION ## + # VISUALIZATION ## def plot_raster(self, stimulus_condition_id, unit_id): """Plot raster for one condition and one unit""" idx_tf = np.where(self.tfvals == self.stimulus_conditions.loc[stimulus_condition_id][self._col_tf])[0] @@ -638,7 +638,7 @@ def make_star_plot(self, unit_id): plt.axis("off") -### General functions ### +# General functions ### def _gauss_function(x, a, x0, sigma): """ fit gaussian function at log scale diff --git a/allensdk/brain_observatory/ecephys/stimulus_analysis/flashes.py b/allensdk/brain_observatory/ecephys/stimulus_analysis/flashes.py index a200f35a82..551fc8b429 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_analysis/flashes.py +++ b/allensdk/brain_observatory/ecephys/stimulus_analysis/flashes.py @@ -185,7 +185,7 @@ def _get_on_off_ratio(self, unit_id): else: return np.nan - ## VISUALIZATION ## + # VISUALIZATION ## def plot_raster(self, stimulus_condition_id, unit_id): """Plot raster for one condition and one unit""" diff --git a/allensdk/brain_observatory/ecephys/stimulus_analysis/static_gratings.py b/allensdk/brain_observatory/ecephys/stimulus_analysis/static_gratings.py index 919a1f3f85..36516935ef 100644 --- a/allensdk/brain_observatory/ecephys/stimulus_analysis/static_gratings.py +++ b/allensdk/brain_observatory/ecephys/stimulus_analysis/static_gratings.py @@ -333,7 +333,7 @@ def _get_osi(self, unit_id, pref_sf, pref_phase): tuning = np.array(df["spike_mean"].values) return osi(orivals_rad, tuning) - ## VISUALIZATION ## + # VISUALIZATION ## def plot_raster(self, stimulus_condition_id, unit_id): """Plot raster for one condition and one unit""" diff --git a/allensdk/brain_observatory/eye_tracking/stage_1/DLC_Eye_Tracking.py b/allensdk/brain_observatory/eye_tracking/stage_1/DLC_Eye_Tracking.py index 379c2fd068..7c371a956a 100644 --- a/allensdk/brain_observatory/eye_tracking/stage_1/DLC_Eye_Tracking.py +++ b/allensdk/brain_observatory/eye_tracking/stage_1/DLC_Eye_Tracking.py @@ -6,7 +6,7 @@ os.environ["DLClight"] = "True" import deeplabcut -from moviepy.editor import * +from moviepy.editor import * # noqa: F401 import argparse import logging diff --git a/allensdk/brain_observatory/eye_tracking/stage_2/DLC_Ellipse_Fitting.py b/allensdk/brain_observatory/eye_tracking/stage_2/DLC_Ellipse_Fitting.py index 05afcee2a0..d428c649e7 100644 --- a/allensdk/brain_observatory/eye_tracking/stage_2/DLC_Ellipse_Fitting.py +++ b/allensdk/brain_observatory/eye_tracking/stage_2/DLC_Ellipse_Fitting.py @@ -3,7 +3,7 @@ t0 = time.time() import os -from moviepy.editor import * +from moviepy.editor import * # noqa: F401 import numpy as np import pandas as pd import argparse diff --git a/allensdk/brain_observatory/eye_tracking/stage_3/DLC_Labeled_Video.py b/allensdk/brain_observatory/eye_tracking/stage_3/DLC_Labeled_Video.py index 37a0a025cb..472b494f00 100644 --- a/allensdk/brain_observatory/eye_tracking/stage_3/DLC_Labeled_Video.py +++ b/allensdk/brain_observatory/eye_tracking/stage_3/DLC_Labeled_Video.py @@ -6,7 +6,7 @@ os.environ["DLClight"] = "True" import deeplabcut -from moviepy.editor import * +from moviepy.editor import * # noqa: F401 import argparse import logging diff --git a/allensdk/brain_observatory/nwb/metadata.py b/allensdk/brain_observatory/nwb/metadata.py index 11c8340921..39c4965c62 100644 --- a/allensdk/brain_observatory/nwb/metadata.py +++ b/allensdk/brain_observatory/nwb/metadata.py @@ -24,7 +24,7 @@ def extract_from_schema(schema): if name in fields_to_skip: continue - if type(val) == fields.Nested: + if type(val) == fields.Nested: # noqa: E721 dataset = _extract_dataset(val=val) datasets.append(dataset) continue @@ -108,7 +108,7 @@ def _extract_attributes(attributes, fields_to_skip=None): if fields_to_skip and name in fields_to_skip: continue - if type(val) == fields.List: + if type(val) == fields.List: # noqa: E721 res.append( NWBAttributeSpec( name=name, @@ -118,7 +118,7 @@ def _extract_attributes(attributes, fields_to_skip=None): required=val.required, ) ) - elif type(val) == fields.Nested: + elif type(val) == fields.Nested: # noqa: E721 continue else: res.append(NWBAttributeSpec(name=name, dtype=STYPE_DICT[type(val)], doc=val.metadata["doc"])) diff --git a/allensdk/brain_observatory/roi_masks.py b/allensdk/brain_observatory/roi_masks.py index f3cae5744e..8f8ebbb35a 100644 --- a/allensdk/brain_observatory/roi_masks.py +++ b/allensdk/brain_observatory/roi_masks.py @@ -472,7 +472,7 @@ def calculate_roi_and_neuropil_traces(movie_h5, roi_mask_list, motion_border): neuropil_masks.append(nmask) num_rois = len(roi_mask_list) - combined_list = roi_mask_list + neuropil_masks # read the large image stack only once + combined_list = roi_mask_list + neuropil_masks # read the large image stack only once with h5py.File(movie_h5, "r") as movie_f: stack_frames = movie_f["data"] diff --git a/allensdk/brain_observatory/session_api_utils.py b/allensdk/brain_observatory/session_api_utils.py index 6ec6f76f68..7dc0468d6d 100644 --- a/allensdk/brain_observatory/session_api_utils.py +++ b/allensdk/brain_observatory/session_api_utils.py @@ -20,7 +20,7 @@ def is_equal(a: Any, b: Any) -> bool: """Function to deal with checking if two variables of possibly mixed types have the same value.""" - if type(a) != type(b): + if type(a) != type(b): # noqa: E721 return False if isinstance(a, (pd.Series, pd.DataFrame)): diff --git a/allensdk/core/_data_object_base/data_object.py b/allensdk/core/_data_object_base/data_object.py index 76a7ada069..c7d9e51eea 100644 --- a/allensdk/core/_data_object_base/data_object.py +++ b/allensdk/core/_data_object_base/data_object.py @@ -157,7 +157,7 @@ def is_prop(attr): return {name: getattr(self, name) for name in props} def __eq__(self, other: "DataObject"): - if type(self) != type(other): + if type(self) != type(other): # noqa: E721 msg = f"Do not know how to compare with type {type(other)}" raise NotImplementedError(msg) diff --git a/allensdk/core/brain_observatory_nwb_data_set.py b/allensdk/core/brain_observatory_nwb_data_set.py index be992d0aaa..da0b1996f5 100755 --- a/allensdk/core/brain_observatory_nwb_data_set.py +++ b/allensdk/core/brain_observatory_nwb_data_set.py @@ -773,7 +773,7 @@ def get_metadata(self): # parse the device string (ugly, sorry) device_string = meta.pop("device_string", None) if device_string: - m = re.match("(.*?)\.\s(.*?)\sPlease*", device_string) + m = re.match(r"(.*?)\.\s(.*?)\sPlease*", device_string) if m: device, device_name = m.groups() meta["device"] = device diff --git a/allensdk/core/cell_types_cache.py b/allensdk/core/cell_types_cache.py index 4da5b62f66..63ef6c9fcc 100644 --- a/allensdk/core/cell_types_cache.py +++ b/allensdk/core/cell_types_cache.py @@ -39,7 +39,7 @@ from allensdk.api.warehouse_cache.cache import Cache, get_default_manifest_file from allensdk.api.queries.cell_types_api import CellTypesApi -from . import json_utilities as json_utilities +from . import json_utilities as json_utilities # noqa: F401 from .nwb_data_set import NwbDataSet from . import swc diff --git a/allensdk/ephys/ephys_extractor.py b/allensdk/ephys/ephys_extractor.py index 497db44a3d..649e8e844a 100644 --- a/allensdk/ephys/ephys_extractor.py +++ b/allensdk/ephys/ephys_extractor.py @@ -750,13 +750,13 @@ def _set_sweeps( baseline_detect_thresh, id_set, ): - if type(t_set) != list: + if type(t_set) != list: # noqa: E721 raise ValueError("t_set must be a list") - if type(v_set) != list: + if type(v_set) != list: # noqa: E721 raise ValueError("v_set must be a list") - if i_set is not None and type(i_set) != list: + if i_set is not None and type(i_set) != list: # noqa: E721 raise ValueError("i_set must be a list") if len(t_set) != len(v_set): diff --git a/allensdk/ephys/feature_extractor.py b/allensdk/ephys/feature_extractor.py index 974fb13e1f..1384bfe562 100644 --- a/allensdk/ephys/feature_extractor.py +++ b/allensdk/ephys/feature_extractor.py @@ -518,7 +518,7 @@ def adaptation_index(self, spikes, stim_end): adi /= cnt return adi - ##---------------------------------------------------------------------- + # ---------------------------------------------------------------------- # trough (AHP) is presently defined as the minimum voltage level # observed between successive spikes in a burst diff --git a/allensdk/internal/brain_observatory/itracker.py b/allensdk/internal/brain_observatory/itracker.py index fac605a4d4..13008b1075 100644 --- a/allensdk/internal/brain_observatory/itracker.py +++ b/allensdk/internal/brain_observatory/itracker.py @@ -395,7 +395,7 @@ def process_image(self, im, bbox_pupil=None, bbox_cr=None): pupil_params = result # fe.ransac_fit(pupil_candidate_points) else: logging.debug("No good fit found") - pupil_params = ((np.nan, np.nan), np.nan, (np.nan, np.nan)) # np.nan*np.ones(5) + pupil_params = ((np.nan, np.nan), np.nan, (np.nan, np.nan)) # np.nan*np.ones(5) # code for finding corneal reflection, start with finding rays from center of cr cr_rays, cr_ray_values = generate_rays(im, self.cr_loc) diff --git a/allensdk/internal/ephys/core_feature_extract.py b/allensdk/internal/ephys/core_feature_extract.py index 04a49a5b1f..47006bde0d 100644 --- a/allensdk/internal/ephys/core_feature_extract.py +++ b/allensdk/internal/ephys/core_feature_extract.py @@ -234,7 +234,7 @@ def generate_output_cell_features(cell_features, sweep_features, sweep_index): def extract_data(data, nwb_file): ########################################################## - #### alings with ephys_sweep_qc_tool extract_features #### + # alings with ephys_sweep_qc_tool extract_features #### cell_specimen = data["specimens"][0] sweep_list = cell_specimen["ephys_sweeps"] sweep_index = {s["sweep_number"]: s for s in sweep_list} @@ -323,7 +323,7 @@ def extract_data(data, nwb_file): except IndexError: cell_specimen["ephys_features"] = [ephys_features] - #### breaks with ephys_sweep_qc_tool extract_features #### + # breaks with ephys_sweep_qc_tool extract_features #### ########################################################## return sweep_list, sweep_features diff --git a/allensdk/internal/model/GLM.py b/allensdk/internal/model/GLM.py index 5c67f9875d..d87d3aef78 100644 --- a/allensdk/internal/model/GLM.py +++ b/allensdk/internal/model/GLM.py @@ -23,7 +23,7 @@ def create_basis_IPSP(neye, ncos, kpeaks, ks, DTsim, t0, I_stim, nkt, flag_exp, # print int(t0[kk]), spind-190000 spike_stim[spind] = 1.0 - ##Convolve temporal basis functions with spike-stim + # Convolve temporal basis functions with spike-stim c = np.zeros((len(spike_stim), ncos)) for jj in range(ncos): basisfilt = gg0["ktbas"][:, jj] diff --git a/allensdk/internal/model/biophysical/fit_stage_1.py b/allensdk/internal/model/biophysical/fit_stage_1.py index 0479dd0c33..4375215fe2 100644 --- a/allensdk/internal/model/biophysical/fit_stage_1.py +++ b/allensdk/internal/model/biophysical/fit_stage_1.py @@ -326,7 +326,7 @@ def prepare_stage_1(description, passive_fit_data): target_dict = {} target_dict["passive"] = [{"ra": ra, "cm": {"soma": cm1, "axon": cm1, "dend": cm2}, "e_pas": baseline_v}] - swc_data = pd.read_table(swc_path, sep="\s", comment="#", header=None) + swc_data = pd.read_table(swc_path, sep=r"\s", comment="#", header=None) has_apic = False if APICAL_DENDRITE_TYPE in pd.unique(swc_data[1]): has_apic = True diff --git a/allensdk/internal/model/glif/ASGLM.py b/allensdk/internal/model/glif/ASGLM.py index c337869216..06ce6dbb91 100644 --- a/allensdk/internal/model/glif/ASGLM.py +++ b/allensdk/internal/model/glif/ASGLM.py @@ -250,7 +250,7 @@ def ASGLM_pairwise( asc_amp_for_all_ks_pairs.append(asc_amp_for_each_sweep) llh_for_all_ks_pairs.append(llh_for_each_sweep) - #!!!!!!!!!!!!!we multiplied ks by dt for SI!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # !!!!!!!!!!!!!we multiplied ks by dt for SI!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ave_llh_for_each_pair = np.mean(llh_for_all_ks_pairs, axis=1) best_ks_pair_ind = np.where(np.max(ave_llh_for_each_pair) == ave_llh_for_each_pair)[0][0] diff --git a/allensdk/internal/model/glif/error_functions.py b/allensdk/internal/model/glif/error_functions.py index 117018e319..36b77db87d 100644 --- a/allensdk/internal/model/glif/error_functions.py +++ b/allensdk/internal/model/glif/error_functions.py @@ -177,7 +177,7 @@ def MLIN_list_error(param_guess, experiment, input_data): ] noSpike_prob = np.append( noSpike_negDiff, noSpike_posDiff - ) #!!NOTE: this may not line up correctly in outputs of MLIN HACK + ) # !!NOTE: this may not line up correctly in outputs of MLIN HACK spike_negDiff = (np.log(1.0 - 0.5 * np.exp(np.array(spike_bins["v_th_diff"]) / sv)))[ np.array(spike_bins["v_th_diff"]) <= 0.0 diff --git a/allensdk/internal/model/glif/find_sweeps.py b/allensdk/internal/model/glif/find_sweeps.py index cac860c634..5f7a2b7017 100644 --- a/allensdk/internal/model/glif/find_sweeps.py +++ b/allensdk/internal/model/glif/find_sweeps.py @@ -62,8 +62,8 @@ def organize_sweeps_by_name(sweeps, name): "suprathreshold": get_sweep_numbers(suprathreshold_list), "maximum_subthreshold": find_ranked_sweep(subthreshold_list, "stimulus_amplitude", reverse=True), "minimum_suprathreshold": find_ranked_sweep(suprathreshold_list, "stimulus_amplitude"), - #'maximum_subthreshold': find_ranked_sweep(subthreshold_list, 'stimulus_absolute_amplitude', reverse=True), - #'minimum_suprathreshold': find_ranked_sweep(suprathreshold_list, 'stimulus_absolute_amplitude') + # 'maximum_subthreshold': find_ranked_sweep(subthreshold_list, 'stimulus_absolute_amplitude', reverse=True), + # 'minimum_suprathreshold': find_ranked_sweep(suprathreshold_list, 'stimulus_absolute_amplitude') } diff --git a/allensdk/internal/model/glif/glif_optimizer.py b/allensdk/internal/model/glif/glif_optimizer.py index 3ad66139e6..21680cd3ae 100644 --- a/allensdk/internal/model/glif/glif_optimizer.py +++ b/allensdk/internal/model/glif/glif_optimizer.py @@ -305,19 +305,19 @@ def run_once(self, param0): # options={} # # options['avextox']=eps # options['maxiter']=500 -## options['full_output']=True -## options['disp']=True -## options['retall']=True +# options['full_output']=True +# options['disp']=True +# options['retall']=True # # print('Using Newton-CG method') # iteration_start_time = time.time() # xopt = minimize(self.error_function, param0, args=(self.experiment,), method='Newton-CG', jac=f_prime_constructor(self.error_function), callback=mycallback_ncg, options=options, tol=eps) # print('Newton-CG method took', (time.time()-iteration_start_time)/60., 'seconds') # -## print('Using Nelder-Mead method') -## iteration_start_time = time.time() -## xopt = minimize(self.error_function, param0, args=(self.experiment,), method='Nelder-Mead', callback=mycallback_nm, options=options, tol=eps) -## print('Nelder-Mead method took', (time.time()-iteration_start_time)/60., 'seconds') +# print('Using Nelder-Mead method') +# iteration_start_time = time.time() +# xopt = minimize(self.error_function, param0, args=(self.experiment,), method='Nelder-Mead', callback=mycallback_nm, options=options, tol=eps) +# print('Nelder-Mead method took', (time.time()-iteration_start_time)/60., 'seconds') # # print(xopt) # return xopt, fopt diff --git a/allensdk/internal/model/glif/preprocess_neuron.py b/allensdk/internal/model/glif/preprocess_neuron.py index ee4a4191b4..d3ae0b2573 100644 --- a/allensdk/internal/model/glif/preprocess_neuron.py +++ b/allensdk/internal/model/glif/preprocess_neuron.py @@ -525,7 +525,7 @@ def fill_in_lists(out_list, data_list): for_reference_dict[ "resistance" - ] = { #'R_lssq_Wrest':{'mean': R_lssq_wrest_mean, 'list': R_lssq_wrest_list, 'dependencies': 'from subthreshold (no spike cutting) noise'}, + ] = { # 'R_lssq_Wrest':{'mean': R_lssq_wrest_mean, 'list': R_lssq_wrest_list, 'dependencies': 'from subthreshold (no spike cutting) noise'}, "R_from_lims": {"value": cell_properties["ri"] * 1e6}, "R_test_list": {"mean": R_test_list_mean, "list": R_test_list}, "R_fit_ASC_and_R": {"mean": R_from_ASGLM, "list": best_R_fit_ascR}, @@ -533,7 +533,7 @@ def fill_in_lists(out_list, data_list): for_reference_dict[ "capacitance" - ] = { #'C_lssq_Wrest': {'mean':C_lssq_wrest_mean, 'list':C_lssq_wrest_list, 'dependencies': 'from subthreshold (no spike cutting) noise'}, + ] = { # 'C_lssq_Wrest': {'mean':C_lssq_wrest_mean, 'list':C_lssq_wrest_list, 'dependencies': 'from subthreshold (no spike cutting) noise'}, "C_from_lims": {"value": (cell_properties["tau"] * 1e-3) / (cell_properties["ri"] * 1e6)}, "C_test_list": {"mean": C_test_list_mean, "list": C_test_list}, } diff --git a/allensdk/internal/model/glif/spike_cutting.py b/allensdk/internal/model/glif/spike_cutting.py index 5bb0e2e380..8f58eeb73d 100644 --- a/allensdk/internal/model/glif/spike_cutting.py +++ b/allensdk/internal/model/glif/spike_cutting.py @@ -240,7 +240,7 @@ def plot_hack(slope, intercept, r, xlim): if isinstance(intercept_at_min_expVar_list, np.ndarray): intercept_at_min_expVar_list = float(intercept_at_min_expVar_list[0]) - if type(intercept_at_min_expVar_list) == list or type(intercept_at_min_expVar_list) == np.ndarray: + if type(intercept_at_min_expVar_list) == list or type(intercept_at_min_expVar_list) == np.ndarray: # noqa: E721 intercept_at_min_expVar_list = intercept_at_min_expVar_list[0] return spike_cut_length, slope_at_min_expVar_list, intercept_at_min_expVar_list diff --git a/allensdk/internal/model/glif/threshold_adaptation.py b/allensdk/internal/model/glif/threshold_adaptation.py index 07014fca6e..cf0713ce9a 100644 --- a/allensdk/internal/model/glif/threshold_adaptation.py +++ b/allensdk/internal/model/glif/threshold_adaptation.py @@ -283,7 +283,7 @@ def fit_avoltage_bvoltage( # Compute voltage component of threshold at biological spike (subtract th_inf and spike component of threshold # from biological voltage values at spike initiation) - #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! # NOTE THAT THERE IS AN ISSUE HERE USING FAKE DATA. THE -1 IS HERE BECAUSE THE NEURON CROSSES THRESHOLD SOMETIME BETWEEN TWO INDICIES. # FOR THE FAKE DATA THE TIME OF THE SPIKE (THE POINT FOLLOWING WHEN THE VOLTAGE CROSSES THRESHOLD) IS SET TO NAN. # THE INTERPOLATED VOLTAGE CAN BE USED BUT THEN THE INTERPOLATED VOLTAGE MUST BE CALCULATED FOR THE TRUE VOLTAGE @@ -296,7 +296,7 @@ def fit_avoltage_bvoltage( v_comp_of_th_at_each_spike_via_data = ( v_trace[all_spikeInd] - internal_sp_comp_array[all_spikeInd] - th_inf ) # USE THIS FOR REAL DATA (although probably not necessary) - #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! # For each ISI, calculate the difference between the voltage dependent component of the threshold # and the value that would be determined via a model that uses the actual voltage of neuron. @@ -401,7 +401,7 @@ def fit_avoltage_bvoltage_th( # Compute voltage component of threshold at biological spike (subtract th_inf and spike component of threshold # from biological voltage values at spike initiation) - #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! # NOTE THAT THERE IS AN ISSUE HERE USING FAKE DATA. THE -1 IS HERE BECAUSE THE NEURON CROSSES THRESHOLD SOMETIME BETWEEN TWO INDICIES. # FOR THE FAKE DATA THE TIME OF THE SPIKE (THE POINT FOLLOWING WHEN THE VOLTAGE CROSSES THRESHOLD) IS SET TO NAN. # THE INTERPOLATED VOLTAGE CAN BE USED BUT THEN THE INTERPOLATED VOLTAGE MUST BE CALCULATED FOR THE TRUE VOLTAGE @@ -414,7 +414,7 @@ def fit_avoltage_bvoltage_th( v_comp_of_th_at_each_spike_via_data = ( v_trace[all_spikeInd] - internal_sp_comp_array[all_spikeInd] - th_inf ) # USE THIS FOR REAL DATA (although probably not necessary) - #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! # For each ISI, calculate the difference between the voltage dependent component of the threshold # and the value that would be determined via a model that uses the actual voltage of neuron. @@ -496,7 +496,7 @@ def fit_avoltage_bvoltage_th( # t=(all_spikeInd[spike_number]-all_spikeInd[spike_number-1]-spike_cut_length)*dt #spike ISI # sp_comp_of_th_offset_local = spike_component_of_threshold_exact(a_spike, b_spike, t) #spike component of threshold at each ISI for each individual spike # #I THINK THE LINE BELOW MIGHT JUST BE WRONG BECAUSE THE OLD OFF SET WOULD DECAY AND I DONT THINK IT IS HERE:THIS IS WHAT IS BEING USED -## sp_comp_of_offset_sum_vector.append(sp_comp_of_offset_sum_vector[-1] + sp_comp_of_th_offset_local) #keeping track of residual spike component of threshold at each spike +# sp_comp_of_offset_sum_vector.append(sp_comp_of_offset_sum_vector[-1] + sp_comp_of_th_offset_local) #keeping track of residual spike component of threshold at each spike # left_over_decay=spike_component_of_threshold_exact(sp_comp_of_offset_sum_vector[-1], b_spike, t) # sp_comp_of_offset_sum_vector.append(left_over_decay + sp_comp_of_th_offset_local) #keeping track of spike component of threshold with residuals at each spike # @@ -509,7 +509,7 @@ def fit_avoltage_bvoltage_th( # # THE INTERPOLATED VOLTAGE CAN BE USED BUT THEN THE INTERPOLATED VOLTAGE MUST BE CALCULATED FOR THE TRUE VOLTAGE # # TRACE AND POSSIBLY IN THE INTEGRATION. # v_comp_of_th_at_each_spike_via_data=v_trace[all_spikeInd-1]-np.array(sp_comp_of_offset_sum_vector)-th_inf #USE THIS FOR FAKE DATA -## v_comp_of_th_at_each_spike_via_data=v_trace[all_spikeInd]-np.array(sp_comp_of_offset_sum_vector)-th_inf #THIS IS PROBABLY APPROPRIATE FOR REAL DATA +# v_comp_of_th_at_each_spike_via_data=v_trace[all_spikeInd]-np.array(sp_comp_of_offset_sum_vector)-th_inf #THIS IS PROBABLY APPROPRIATE FOR REAL DATA # #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! # # # For each ISI, calculate the difference between the v_trace dependent component of the threshold @@ -567,16 +567,16 @@ def fit_avoltage_bvoltage_th( # artifact_removed_voltage[smooth_window]=blah(smooth_window) # # #windows boarders are just for plotting -## window_boarders_index.append(smooth_window[0]) -## window_boarders_index.append(smooth_window[-1]) -## plt.figure() -## plt.plot(voltage, 'b', lw=4) -## plt.plot(artifact_removed_voltage, 'r', lw=2) -## plt.plot(window_boarders_index, artifact_removed_voltage[window_boarders_index], '|g', ms=10) -## plt.xlim([40400, 41000]) -## plt.show() +# window_boarders_index.append(smooth_window[0]) +# window_boarders_index.append(smooth_window[-1]) +# plt.figure() +# plt.plot(voltage, 'b', lw=4) +# plt.plot(artifact_removed_voltage, 'r', lw=2) +# plt.plot(window_boarders_index, artifact_removed_voltage[window_boarders_index], '|g', ms=10) +# plt.xlim([40400, 41000]) +# plt.show() ## -## t = np.arange(0, len()) * dt +# t = np.arange(0, len()) * dt # # # keeping the smooth_v convention of the SDK find spike code. However in the SDK code # # this is used to name data potentially smoothed by a bessel filter @@ -652,29 +652,29 @@ def fit_avoltage_bvoltage_th( # # out_spk_idxs_list.append(np.array(out_spk_idxs)) # -## time_vector=np.arange(len(v))*dt -## plt.figure() -## # plt.subplot(3,1,1) -## # plt.plot(time_vector, ddv) -## # plt.plot(time_vector[out_spk_idxs], ddv[out_spk_idxs], '.r', ms=16) -## # plt.xlim([40300, 42000]) -## # plt.ylabel('ddv') -## plt.subplot(2,1,1) -## plt.plot(time_vector, dvdt) -## plt.plot(time_vector[out_spk_idxs], dvdt[out_spk_idxs], '.r', ms=16) -## plt.plot(time_vector[potential_artifact_indexes], dvdt[potential_artifact_indexes], 'b|', ms=24, lw=4) -## plt.xlim([40300*dt, 42000*dt]) -## plt.ylabel('dvdt') -## plt.subplot(2,1,2) -## plt.plot(time_vector, v) -## plt.plot(time_vector[out_spk_idxs], v[out_spk_idxs], 'r.', ms=16, label='threshold') -## plt.xlim([40300*dt, 42000*dt]) -## plt.ylabel('voltage (V)') -## plt.plot(time_vector[peaks], v[peaks], '.g', ms=16, label='peaks') -## plt.plot(time_vector[[spikes[ii]['upstroke_idx'] for ii in range(len(spikes))]], [spikes[ii]['upstroke_v'] for ii in range(len(spikes))], '.c', ms=16, label = 'max upstroke') -## plt.plot(time_vector[potential_artifact_indexes], v[potential_artifact_indexes], 'b|', ms=24, lw=4) -## plt.legend() -## plt.show() +# time_vector=np.arange(len(v))*dt +# plt.figure() +# # plt.subplot(3,1,1) +# # plt.plot(time_vector, ddv) +# # plt.plot(time_vector[out_spk_idxs], ddv[out_spk_idxs], '.r', ms=16) +# # plt.xlim([40300, 42000]) +# # plt.ylabel('ddv') +# plt.subplot(2,1,1) +# plt.plot(time_vector, dvdt) +# plt.plot(time_vector[out_spk_idxs], dvdt[out_spk_idxs], '.r', ms=16) +# plt.plot(time_vector[potential_artifact_indexes], dvdt[potential_artifact_indexes], 'b|', ms=24, lw=4) +# plt.xlim([40300*dt, 42000*dt]) +# plt.ylabel('dvdt') +# plt.subplot(2,1,2) +# plt.plot(time_vector, v) +# plt.plot(time_vector[out_spk_idxs], v[out_spk_idxs], 'r.', ms=16, label='threshold') +# plt.xlim([40300*dt, 42000*dt]) +# plt.ylabel('voltage (V)') +# plt.plot(time_vector[peaks], v[peaks], '.g', ms=16, label='peaks') +# plt.plot(time_vector[[spikes[ii]['upstroke_idx'] for ii in range(len(spikes))]], [spikes[ii]['upstroke_v'] for ii in range(len(spikes))], '.c', ms=16, label = 'max upstroke') +# plt.plot(time_vector[potential_artifact_indexes], v[potential_artifact_indexes], 'b|', ms=24, lw=4) +# plt.legend() +# plt.show() # # return out_spk_idxs_list diff --git a/allensdk/internal/morphology/morphology.py b/allensdk/internal/morphology/morphology.py index 5b5b104149..bfd3104505 100644 --- a/allensdk/internal/morphology/morphology.py +++ b/allensdk/internal/morphology/morphology.py @@ -797,8 +797,8 @@ def apply_affine(self, aff, scale=None): # assume equal scaling along all axes. take 3rd root to get # scale factor det_scale = np.power(abs(determinant), 1.0 / 3.0) - ## measure scale along each axis - ## keep this code here in case + # measure scale along each axis + # keep this code here in case # scale_x = abs(aff[0] + aff[3] + aff[6]) # scale_y = abs(aff[1] + aff[4] + aff[7]) # scale_z = abs(aff[2] + aff[5] + aff[8]) diff --git a/allensdk/internal/pipeline_modules/cell_types/morphology/surrogate_strategy.py b/allensdk/internal/pipeline_modules/cell_types/morphology/surrogate_strategy.py index 2dc67a3c27..5a419d2f5a 100755 --- a/allensdk/internal/pipeline_modules/cell_types/morphology/surrogate_strategy.py +++ b/allensdk/internal/pipeline_modules/cell_types/morphology/surrogate_strategy.py @@ -52,7 +52,7 @@ def prep_json(spec_id): block["path"] = path block["label"] = label poly.append(block) - ## break down string path into two numeric arrays + # break down string path into two numeric arrays # path_array = np.array(path.split(',')) # path_x = np.array(path_array[0::2], dtype=float) # path_y = np.array(path_array[1::2], dtype=float) diff --git a/allensdk/internal/pipeline_modules/cell_types/morphology/upright_transform.py b/allensdk/internal/pipeline_modules/cell_types/morphology/upright_transform.py index fceb73ed0b..d854545244 100644 --- a/allensdk/internal/pipeline_modules/cell_types/morphology/upright_transform.py +++ b/allensdk/internal/pipeline_modules/cell_types/morphology/upright_transform.py @@ -183,7 +183,7 @@ def vector_angle(v1, v2): def main(jin): # per IT-14567, blockface analysis is no longer required ######################################################################### - ## analyze blockface image + # analyze blockface image # try: # soma = jin["blockface"]["Soma"]["path"] # pia = jin["blockface"]["Pia"]["path"] @@ -193,7 +193,7 @@ def main(jin): # print("** Error -- missing requisite blockface field(s) in input json") # raise # - ## get soma position using weighted average of vertices + # get soma position using weighted average of vertices # try: # sx, sy = convert_coords_str(soma) # soma_x, soma_y = calculate_centroid(sx, sy) @@ -201,12 +201,12 @@ def main(jin): # print("** Error -- unable to calculate soma information (blockface)") # raise # - ## calculate shortest path + # calculate shortest path # try: # px, py, wx, wy = calculate_shortest(soma_x, soma_y, pia, wm) - ## calculate theta and affine + # calculate theta and affine # theta = vector_angle((0, 1), np.asarray([px,py]) - np.asarray([wx,wy])) - ## calculate soma depth and cortical thickness + # calculate soma depth and cortical thickness # depth = res * euclidean((soma_x, soma_y), (px, py)) # blk_thickness = res * euclidean((wx, wy), (px, py)) # except: diff --git a/allensdk/internal/pipeline_modules/run_observatory_thumbnails.py b/allensdk/internal/pipeline_modules/run_observatory_thumbnails.py index 36bd2f28a0..ab27b0f07e 100644 --- a/allensdk/internal/pipeline_modules/run_observatory_thumbnails.py +++ b/allensdk/internal/pipeline_modules/run_observatory_thumbnails.py @@ -96,7 +96,7 @@ def get_input_data(experiment_id): input_data = { "nwb_file": nwb_file, - #'analysis_file': analysis_file, + # 'analysis_file': analysis_file, "analysis_file": analysis_file, "output_directory": output_directory, } diff --git a/allensdk/model/glif/glif_neuron.py b/allensdk/model/glif/glif_neuron.py index f124e6e5f3..0b1a981394 100755 --- a/allensdk/model/glif/glif_neuron.py +++ b/allensdk/model/glif/glif_neuron.py @@ -524,7 +524,7 @@ def interpolate_spike_value(dt, interpolated_spike_time_offset, v0, v1): def line_crossing_x(dx, a0, a1, b0, b1): """Find the x value of the intersection of two lines.""" - assert type(a0) != int and type(a1) != int and type(b0) != int and type(b1) != int, Exception( + assert type(a0) != int and type(a1) != int and type(b0) != int and type(b1) != int, Exception( # noqa: E721 "Do not pass integers into this function!" ) return dx * (b0 - a0) / ((a1 - a0) - (b1 - b0)) diff --git a/allensdk/test/api/test_svg_api.py b/allensdk/test/api/test_svg_api.py index e8bb1ffa92..1237861812 100644 --- a/allensdk/test/api/test_svg_api.py +++ b/allensdk/test/api/test_svg_api.py @@ -46,35 +46,35 @@ def svg(): def test_build_query(svg): - ####download true url + # download true url download = True groups = None section_image_id = 21889 returned_url = svg.build_query(section_image_id, groups, download) assert returned_url == "http://api.brain-map.org/api/v2/svg_download/21889" - ####download true with one group url + # download true with one group url download = True groups = [1] section_image_id = 21889 returned_url = svg.build_query(section_image_id, groups, download) assert returned_url == "http://api.brain-map.org/api/v2/svg_download/21889?groups=1" - ####download true with groups url + # download true with groups url download = True groups = [1, 2] section_image_id = 21889 returned_url = svg.build_query(section_image_id, groups, download) assert returned_url == "http://api.brain-map.org/api/v2/svg_download/21889?groups=1,2" - ####download false url + # download false url download = False groups = None section_image_id = 21889 returned_url = svg.build_query(section_image_id, groups, download) assert returned_url == "http://api.brain-map.org/api/v2/svg/21889" - ####download false groups exist url + # download false groups exist url download = False groups = [28] section_image_id = 21889 @@ -95,14 +95,14 @@ def test_download_svg(svg): def test_get_svg(svg): svg.retrieve_xml_over_http = MagicMock(name="retrieve_xml_over_http") - ####groups None + # groups None section_image_id = 100960033 groups = None svg.get_svg(section_image_id, groups) svg.retrieve_xml_over_http.assert_called_with("http://api.brain-map.org/api/v2/svg/100960033") - ####groups in 28 + # groups in 28 section_image_id = 100960033 groups = [28] diff --git a/allensdk/test/api/test_tree_search_api.py b/allensdk/test/api/test_tree_search_api.py index 79a474f05f..c48b084fde 100644 --- a/allensdk/test/api/test_tree_search_api.py +++ b/allensdk/test/api/test_tree_search_api.py @@ -33,7 +33,7 @@ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. # -####test AllenSDK tree search api for Specimen and Structure +# test AllenSDK tree search api for Specimen and Structure from allensdk.api.queries.tree_search_api import TreeSearchApi import pytest from unittest.mock import MagicMock @@ -48,7 +48,7 @@ def tree_search(): def test_get_specimen_tree(tree_search): - ####ancestor true for Specimen + # ancestor true for Specimen kind = "Specimen" db_id = 113817886 ancestors = True @@ -58,7 +58,7 @@ def test_get_specimen_tree(tree_search): "http://api.brain-map.org/api/v2/tree_search/Specimen/113817886.json?ancestors=true" ) - ####ancestor true for Specimen + # ancestor true for Specimen kind = "Specimen" db_id = 113817886 ancestors = True @@ -68,7 +68,7 @@ def test_get_specimen_tree(tree_search): "http://api.brain-map.org/api/v2/tree_search/Specimen/113817886.json?ancestors=true&descendants=false" ) - ####ancestor false for Specimen + # ancestor false for Specimen kind = "Specimen" db_id = 113817886 ancestors = False @@ -80,7 +80,7 @@ def test_get_specimen_tree(tree_search): def test_get_structure_tree(tree_search): - ####ancestor True for Structure + # ancestor True for Structure kind = "Structure" db_id = 12547 ancestors = True @@ -90,7 +90,7 @@ def test_get_structure_tree(tree_search): "http://api.brain-map.org/api/v2/tree_search/Structure/12547.json?ancestors=true&descendants=true" ) - ####ancestor False for Structure + # ancestor False for Structure kind = "Structure" db_id = 12547 ancestors = False @@ -100,7 +100,7 @@ def test_get_structure_tree(tree_search): "http://api.brain-map.org/api/v2/tree_search/Structure/12547.json?ancestors=false&descendants=true" ) - ####ancestor None for Structure + # ancestor None for Structure kind = "Structure" db_id = 12547 ancestors = None diff --git a/allensdk/test_utilities/custom_comparators.py b/allensdk/test_utilities/custom_comparators.py index 503b2b5a8c..162cbc942d 100644 --- a/allensdk/test_utilities/custom_comparators.py +++ b/allensdk/test_utilities/custom_comparators.py @@ -137,7 +137,7 @@ def _nested_scalar_equivalence(val0: Any, val1: Any) -> bool: Return True if the scalars are identical. Return False otherwise. """ - if type(val0) != type(val1): + if type(val0) != type(val1): # noqa: E721 return False if isinstance(val0, numbers.Number): @@ -172,7 +172,7 @@ def _nested_iterable_equivalence(list0: Iterable, list1: Iterable) -> bool: for idx in range(len(list0)): v0 = list0[idx] v1 = list1[idx] - if type(v0) != type(v1): + if type(v0) != type(v1): # noqa: E721 return False if isinstance(v0, dict): @@ -206,7 +206,7 @@ def _nested_dict_equivalence(dict0: dict, dict1: dict) -> bool: for this_key in k0_list: val0 = dict0[this_key] val1 = dict1[this_key] - if type(val0) != type(val1): + if type(val0) != type(val1): # noqa: E721 return False if isinstance(val0, dict): diff --git a/pyproject.toml b/pyproject.toml index 18a4861b0e..bed81f3bb0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,6 +118,8 @@ line-length = 120 ignore = [ "E402", # module-level import not at top of file "E741", # ambiguous variable name (l, O, I, etc.) + "E742", # ambiguous class name + "E743", # ambiguous function name "F403", # wildcard import used "F405", # name may be undefined from wildcard import ]