diff --git a/examples/speaker_tasks/diarization/conf/inference/diar_infer_general.yaml b/examples/speaker_tasks/diarization/conf/inference/diar_infer_general.yaml index a977164456ba..5b141b50e183 100644 --- a/examples/speaker_tasks/diarization/conf/inference/diar_infer_general.yaml +++ b/examples/speaker_tasks/diarization/conf/inference/diar_infer_general.yaml @@ -1,5 +1,5 @@ # This YAML file is created for all types of offline speaker diarization inference tasks in `/example/speaker_tasks/diarization` folder. -# The inference parameters for VAD, speaker embedding extractor, clustering module, ASR decoder are all included in this YAML file. +# The inference parameters for VAD, speaker embedding extractor, clustering module, ASR decoder are all included in this YAML file. # All the keys under `diarizer` key (`vad`, `speaker_embeddings`, `clustering`, `asr`) can be selectively used for its own purpose and also can be ignored if the module is not used. # The configurations in this YAML file is optimized to show balanced performances on various types of domain. VAD is optimized on multilingual ASR datasets and diarizer is optimized on DIHARD3 development set. # An example line in an input manifest file (`.json` format): @@ -20,21 +20,21 @@ diarizer: ignore_overlap: True # Consider or ignore overlap segments while scoring vad: - model_path: vad_multilingual_marblenet # .nemo local model path or pretrained VAD model name + model_path: vad_multilingual_marblenet # .nemo local model path or pretrained VAD model name external_vad_manifest: null # This option is provided to use external vad and provide its speech activity labels for speaker embeddings extraction. Only one of model_path or external_vad_manifest should be set parameters: # Tuned by detection error rate (false alarm + miss) on multilingual ASR evaluation datasets - window_length_in_sec: 0.63 # Window length in sec for VAD context input + window_length_in_sec: 0.63 # Window length in sec for VAD context input shift_length_in_sec: 0.08 # Shift length in sec for generate frame level VAD prediction smoothing: False # False or type of smoothing method (eg: median) overlap: 0.5 # Overlap ratio for overlapped mean/median smoothing filter - onset: 0.5 # Onset threshold for detecting the beginning and end of a speech + onset: 0.5 # Onset threshold for detecting the beginning and end of a speech offset: 0.3 # Offset threshold for detecting the end of a speech - pad_onset: 0.2 # Adding durations before each speech segment - pad_offset: 0.2 # Adding durations after each speech segment + pad_onset: 0.2 # Adding durations before each speech segment + pad_offset: 0.2 # Adding durations after each speech segment min_duration_on: 0.5 # Threshold for short speech segment deletion min_duration_off: 0.5 # Threshold for small non_speech deletion - filter_speech_first: True + filter_speech_first: True speaker_embeddings: model_path: titanet_large # .nemo local model path or pretrained model name (titanet_large, ecapa_tdnn or speakerverification_speakernet) @@ -42,15 +42,15 @@ diarizer: window_length_in_sec: [1.9,1.2,0.5] # Window length(s) in sec (floating-point number). either a number or a list. ex) 1.5 or [1.5,1.0,0.5] shift_length_in_sec: [0.95,0.6,0.25] # Shift length(s) in sec (floating-point number). either a number or a list. ex) 0.75 or [0.75,0.5,0.25] multiscale_weights: [1,1,1] # Weight for each scale. should be null (for single scale) or a list matched with window/shift scale count. 
ex) [0.33,0.33,0.33] - save_embeddings: True # If True, save speaker embeddings in pickle format. - + save_embeddings: True # If True, save speaker embedding tensor. + clustering: parameters: oracle_num_speakers: False # If True, use num of speakers value provided in manifest file. max_num_speakers: 8 # Max number of speakers for each recording. If an oracle number of speakers is passed, this value is ignored. enhanced_count_thres: 80 # If the number of segments is lower than this number, enhanced speaker counting is activated. - max_rp_threshold: 0.25 # Determines the range of p-value search: 0 < p <= max_rp_threshold. - sparse_search_volume: 10 # The higher the number, the more values will be examined with more time. + max_rp_threshold: 0.25 # Determines the range of p-value search: 0 < p <= max_rp_threshold. + sparse_search_volume: 10 # The higher the number, the more values will be examined with more time. maj_vote_spk_count: False # If True, take a majority vote on multiple p-values to estimate the number of speakers. chunk_cluster_count: 50 # Number of forced clusters (overclustering) per unit chunk in long-form audio clustering. embeddings_per_chunk: 10000 # Number of embeddings in each chunk for long-form audio clustering. Adjust based on GPU memory capacity. (default: 10000, approximately 40 mins of audio) @@ -62,13 +62,13 @@ diarizer: asr_based_vad_threshold: 1.0 # Threshold (in sec) that caps the gap between two words when generating VAD timestamps using ASR based VAD. asr_batch_size: null # Batch size can be dependent on each ASR model. Default batch sizes are applied if set to null. decoder_delay_in_sec: null # Native decoder delay. null is recommended to use the default values for each ASR model. - word_ts_anchor_offset: null # Offset to set a reference point from the start of the word. Recommended range of values is [-0.05 0.2]. + word_ts_anchor_offset: null # Offset to set a reference point from the start of the word. Recommended range of values is [-0.05 0.2]. word_ts_anchor_pos: "start" # Select which part of the word timestamp we want to use. The options are: 'start', 'end', 'mid'. fix_word_ts_with_VAD: False # Fix the word timestamp using VAD output. You must provide a VAD model to use this feature. colored_text: False # If True, use colored text to distinguish speakers in the output transcript. print_time: True # If True, the start and end time of each speaker turn is printed in the output transcript. break_lines: False # If True, the output transcript breaks the line to fix the line width (default is 90 chars) - + ctc_decoder_parameters: # Optional beam search decoder (pyctcdecode) pretrained_language_model: null # KenLM model file: .arpa model file or .bin binary file. beam_width: 32 diff --git a/examples/speaker_tasks/diarization/conf/inference/diar_infer_meeting.yaml b/examples/speaker_tasks/diarization/conf/inference/diar_infer_meeting.yaml index 90c478b08606..303cbb8e9041 100644 --- a/examples/speaker_tasks/diarization/conf/inference/diar_infer_meeting.yaml +++ b/examples/speaker_tasks/diarization/conf/inference/diar_infer_meeting.yaml @@ -1,5 +1,5 @@ # This YAML file is created for all types of offline speaker diarization inference tasks in `/example/speaker_tasks/diarization` folder. -# The inference parameters for VAD, speaker embedding extractor, clustering module, ASR decoder are all included in this YAML file. +# The inference parameters for VAD, speaker embedding extractor, clustering module, ASR decoder are all included in this YAML file. 
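A note on the `save_embeddings` comment change repeated across these configs: the embeddings are now written with `torch.save` to a `<manifest_name>_embeddings.pt` file (see the `clustering_diarizer.py` hunk later in this diff), so downstream consumers should switch from `pickle.load` to `torch.load`. A minimal sketch of reading one back — the path is a placeholder, and the dict-of-tensors layout is assumed from the diarizer code below:

```python
import torch

# Embeddings saved by the diarizer: assumed to be a dict keyed by unique
# recording name, with stacked speaker-embedding tensors as values.
embeddings = torch.load("outputs/input_manifest_embeddings.pt", map_location="cpu")
for uniq_name, emb in embeddings.items():
    print(uniq_name, emb.shape)
```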
# All the keys under `diarizer` key (`vad`, `speaker_embeddings`, `clustering`, `asr`) can be selectively used for its own purpose and also can be ignored if the module is not used. # The configurations in this YAML file is suitable for 3~5 speakers participating in a meeting and may not show the best performance on other types of dialogues. # An example line in an input manifest file (`.json` format): @@ -20,21 +20,21 @@ diarizer: ignore_overlap: True # Consider or ignore overlap segments while scoring vad: - model_path: vad_multilingual_marblenet # .nemo local model path or pretrained VAD model name + model_path: vad_multilingual_marblenet # .nemo local model path or pretrained VAD model name external_vad_manifest: null # This option is provided to use external vad and provide its speech activity labels for speaker embeddings extraction. Only one of model_path or external_vad_manifest should be set - parameters: # Tuned parameters for CH109 (using the 11 multi-speaker sessions as dev set) - window_length_in_sec: 0.63 # Window length in sec for VAD context input + parameters: # Tuned parameters for CH109 (using the 11 multi-speaker sessions as dev set) + window_length_in_sec: 0.63 # Window length in sec for VAD context input shift_length_in_sec: 0.01 # Shift length in sec for generate frame level VAD prediction smoothing: False # False or type of smoothing method (eg: median) overlap: 0.5 # Overlap ratio for overlapped mean/median smoothing filter - onset: 0.9 # Onset threshold for detecting the beginning and end of a speech + onset: 0.9 # Onset threshold for detecting the beginning and end of a speech offset: 0.5 # Offset threshold for detecting the end of a speech - pad_onset: 0 # Adding durations before each speech segment - pad_offset: 0 # Adding durations after each speech segment + pad_onset: 0 # Adding durations before each speech segment + pad_offset: 0 # Adding durations after each speech segment min_duration_on: 0 # Threshold for short speech segment deletion min_duration_off: 0.6 # Threshold for small non_speech deletion - filter_speech_first: True + filter_speech_first: True speaker_embeddings: model_path: titanet_large # .nemo local model path or pretrained model name (titanet_large, ecapa_tdnn or speakerverification_speakernet) @@ -42,19 +42,19 @@ diarizer: window_length_in_sec: [3.0,2.5,2.0,1.5,1.0,0.5] # Window length(s) in sec (floating-point number). either a number or a list. ex) 1.5 or [1.5,1.0,0.5] shift_length_in_sec: [1.5,1.25,1.0,0.75,0.5,0.25] # Shift length(s) in sec (floating-point number). either a number or a list. ex) 0.75 or [0.75,0.5,0.25] multiscale_weights: [1,1,1,1,1,1] # Weight for each scale. should be null (for single scale) or a list matched with window/shift scale count. ex) [0.33,0.33,0.33] - save_embeddings: True # If True, save speaker embeddings in pickle format. - + save_embeddings: True # If True, save speaker embedding tensor. + clustering: parameters: oracle_num_speakers: False # If True, use num of speakers value provided in manifest file. max_num_speakers: 8 # Max number of speakers for each recording. If an oracle number of speakers is passed, this value is ignored. enhanced_count_thres: 80 # If the number of segments is lower than this number, enhanced speaker counting is activated. - max_rp_threshold: 0.25 # Determines the range of p-value search: 0 < p <= max_rp_threshold. - sparse_search_volume: 30 # The higher the number, the more values will be examined with more time. 
+ max_rp_threshold: 0.25 # Determines the range of p-value search: 0 < p <= max_rp_threshold. + sparse_search_volume: 30 # The higher the number, the more values will be examined with more time. maj_vote_spk_count: False # If True, take a majority vote on multiple p-values to estimate the number of speakers. chunk_cluster_count: 50 # Number of forced clusters (overclustering) per unit chunk in long-form audio clustering. - embeddings_per_chunk: 10000 # Number of embeddings in each chunk for long-form audio clustering. Adjust based on GPU memory capacity. (default: 10000, approximately 40 mins of audio) - + embeddings_per_chunk: 10000 # Number of embeddings in each chunk for long-form audio clustering. Adjust based on GPU memory capacity. (default: 10000, approximately 40 mins of audio) + asr: model_path: stt_en_conformer_ctc_large # Provide NGC cloud ASR model name. stt_en_conformer_ctc_* models are recommended for diarization purposes. parameters: @@ -62,13 +62,13 @@ diarizer: asr_based_vad_threshold: 1.0 # Threshold (in sec) that caps the gap between two words when generating VAD timestamps using ASR based VAD. asr_batch_size: null # Batch size can be dependent on each ASR model. Default batch sizes are applied if set to null. decoder_delay_in_sec: null # Native decoder delay. null is recommended to use the default values for each ASR model. - word_ts_anchor_offset: null # Offset to set a reference point from the start of the word. Recommended range of values is [-0.05 0.2]. + word_ts_anchor_offset: null # Offset to set a reference point from the start of the word. Recommended range of values is [-0.05 0.2]. word_ts_anchor_pos: "start" # Select which part of the word timestamp we want to use. The options are: 'start', 'end', 'mid'. fix_word_ts_with_VAD: False # Fix the word timestamp using VAD output. You must provide a VAD model to use this feature. colored_text: False # If True, use colored text to distinguish speakers in the output transcript. print_time: True # If True, the start and end time of each speaker turn is printed in the output transcript. break_lines: False # If True, the output transcript breaks the line to fix the line width (default is 90 chars) - + ctc_decoder_parameters: # Optional beam search decoder (pyctcdecode) pretrained_language_model: null # KenLM model file: .arpa model file or .bin binary file. beam_width: 32 diff --git a/examples/speaker_tasks/diarization/conf/inference/diar_infer_telephonic.yaml b/examples/speaker_tasks/diarization/conf/inference/diar_infer_telephonic.yaml index a51883315e08..14f6b4e97183 100644 --- a/examples/speaker_tasks/diarization/conf/inference/diar_infer_telephonic.yaml +++ b/examples/speaker_tasks/diarization/conf/inference/diar_infer_telephonic.yaml @@ -1,5 +1,5 @@ # This YAML file is created for all types of offline speaker diarization inference tasks in `/example/speaker_tasks/diarization` folder. -# The inference parameters for VAD, speaker embedding extractor, clustering module, ASR decoder are all included in this YAML file. +# The inference parameters for VAD, speaker embedding extractor, clustering module, ASR decoder are all included in this YAML file. # All the keys under `diarizer` key (`vad`, `speaker_embeddings`, `clustering`, `asr`) can be selectively used for its own purpose and also can be ignored if the module is not used. 
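For context on how these per-domain inference YAMLs are consumed: the offline pipeline loads one with OmegaConf and hands it to `ClusteringDiarizer` (whose embedding-saving code is updated later in this diff). A minimal sketch, assuming a prepared input manifest; the paths are placeholders:

```python
from omegaconf import OmegaConf

from nemo.collections.asr.models import ClusteringDiarizer

# Load a domain-specific inference config and point it at our data.
cfg = OmegaConf.load("conf/inference/diar_infer_telephonic.yaml")
cfg.diarizer.manifest_filepath = "input_manifest.json"  # placeholder
cfg.diarizer.out_dir = "./diar_outputs"  # placeholder

# Runs VAD, multiscale embedding extraction, and clustering end to end.
ClusteringDiarizer(cfg=cfg).diarize()
```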
# The configurations in this YAML file is suitable for telephone recordings involving 2~8 speakers in a session and may not show the best performance on the other types of acoustic conditions or dialogues. # An example line in an input manifest file (`.json` format): @@ -20,21 +20,21 @@ diarizer: ignore_overlap: True # Consider or ignore overlap segments while scoring vad: - model_path: vad_multilingual_marblenet # .nemo local model path or pretrained VAD model name + model_path: vad_multilingual_marblenet # .nemo local model path or pretrained VAD model name external_vad_manifest: null # This option is provided to use external vad and provide its speech activity labels for speaker embeddings extraction. Only one of model_path or external_vad_manifest should be set - parameters: # Tuned parameters for CH109 (using the 11 multi-speaker sessions as dev set) - window_length_in_sec: 0.15 # Window length in sec for VAD context input + parameters: # Tuned parameters for CH109 (using the 11 multi-speaker sessions as dev set) + window_length_in_sec: 0.15 # Window length in sec for VAD context input shift_length_in_sec: 0.01 # Shift length in sec for generate frame level VAD prediction smoothing: "median" # False or type of smoothing method (eg: median) overlap: 0.5 # Overlap ratio for overlapped mean/median smoothing filter - onset: 0.1 # Onset threshold for detecting the beginning and end of a speech + onset: 0.1 # Onset threshold for detecting the beginning and end of a speech offset: 0.1 # Offset threshold for detecting the end of a speech - pad_onset: 0.1 # Adding durations before each speech segment - pad_offset: 0 # Adding durations after each speech segment + pad_onset: 0.1 # Adding durations before each speech segment + pad_offset: 0 # Adding durations after each speech segment min_duration_on: 0 # Threshold for short speech segment deletion min_duration_off: 0.2 # Threshold for small non_speech deletion - filter_speech_first: True + filter_speech_first: True speaker_embeddings: model_path: titanet_large # .nemo local model path or pretrained model name (titanet_large, ecapa_tdnn or speakerverification_speakernet) @@ -42,19 +42,19 @@ diarizer: window_length_in_sec: [1.5,1.25,1.0,0.75,0.5] # Window length(s) in sec (floating-point number). either a number or a list. ex) 1.5 or [1.5,1.0,0.5] shift_length_in_sec: [0.75,0.625,0.5,0.375,0.25] # Shift length(s) in sec (floating-point number). either a number or a list. ex) 0.75 or [0.75,0.5,0.25] multiscale_weights: [1,1,1,1,1] # Weight for each scale. should be null (for single scale) or a list matched with window/shift scale count. ex) [0.33,0.33,0.33] - save_embeddings: True # If True, save speaker embeddings in pickle format. - - clustering: + save_embeddings: True # If True, save speaker embedding tensor. + + clustering: parameters: oracle_num_speakers: False # If True, use num of speakers value provided in manifest file. max_num_speakers: 8 # Max number of speakers for each recording. If an oracle number of speakers is passed, this value is ignored. enhanced_count_thres: 80 # If the number of segments is lower than this number, enhanced speaker counting is activated. - max_rp_threshold: 0.25 # Determines the range of p-value search: 0 < p <= max_rp_threshold. - sparse_search_volume: 30 # The higher the number, the more values will be examined with more time. + max_rp_threshold: 0.25 # Determines the range of p-value search: 0 < p <= max_rp_threshold. 
+ sparse_search_volume: 30 # The higher the number, the more values will be examined with more time. maj_vote_spk_count: False # If True, take a majority vote on multiple p-values to estimate the number of speakers. chunk_cluster_count: 50 # Number of forced clusters (overclustering) per unit chunk in long-form audio clustering. - embeddings_per_chunk: 10000 # Number of embeddings in each chunk for long-form audio clustering. Adjust based on GPU memory capacity. (default: 10000, approximately 40 mins of audio) - + embeddings_per_chunk: 10000 # Number of embeddings in each chunk for long-form audio clustering. Adjust based on GPU memory capacity. (default: 10000, approximately 40 mins of audio) + asr: model_path: stt_en_conformer_ctc_large # Provide NGC cloud ASR model name. stt_en_conformer_ctc_* models are recommended for diarization purposes. parameters: @@ -62,13 +62,13 @@ diarizer: asr_based_vad_threshold: 1.0 # Threshold (in sec) that caps the gap between two words when generating VAD timestamps using ASR based VAD. asr_batch_size: null # Batch size can be dependent on each ASR model. Default batch sizes are applied if set to null. decoder_delay_in_sec: null # Native decoder delay. null is recommended to use the default values for each ASR model. - word_ts_anchor_offset: null # Offset to set a reference point from the start of the word. Recommended range of values is [-0.05 0.2]. + word_ts_anchor_offset: null # Offset to set a reference point from the start of the word. Recommended range of values is [-0.05 0.2]. word_ts_anchor_pos: "start" # Select which part of the word timestamp we want to use. The options are: 'start', 'end', 'mid'. fix_word_ts_with_VAD: False # Fix the word timestamp using VAD output. You must provide a VAD model to use this feature. colored_text: False # If True, use colored text to distinguish speakers in the output transcript. print_time: True # If True, the start and end time of each speaker turn is printed in the output transcript. break_lines: False # If True, the output transcript breaks the line to fix the line width (default is 90 chars) - + ctc_decoder_parameters: # Optional beam search decoder (pyctcdecode) pretrained_language_model: null # KenLM model file: .arpa model file or .bin binary file. beam_width: 32 diff --git a/examples/speaker_tasks/recognition/README.md b/examples/speaker_tasks/recognition/README.md index 0e0f5ae3b4fc..5cc60a2f62dc 100644 --- a/examples/speaker_tasks/recognition/README.md +++ b/examples/speaker_tasks/recognition/README.md @@ -13,45 +13,45 @@ Documentation section for speaker related tasks can be found at: ## Training Speaker Recognition models can be trained in a similar way as other models in NeMo using train and dev manifest files. Steps on how to create manifest files for voxceleb are provided below. -We provide three model configurations based on TitaNet, SpeakerNet and modified ECAPA_TDNN, with pretrained models provided for each of them. +We provide three model configurations based on TitaNet, SpeakerNet and modified ECAPA_TDNN, with pretrained models provided for each of them. 
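Each of these configurations has a matching pretrained checkpoint that can also be loaded programmatically rather than through the training scripts below — a minimal sketch, using one of the model names listed above:

```python
import nemo.collections.asr as nemo_asr

# Downloads (or reuses a cached copy of) the pretrained checkpoint from NGC.
speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained("titanet_large")
```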
For training titanet_large (channel-attention) model:
```bash
-python speaker_reco.py --config_path='conf' --config_name='titanet_large.yaml'
+python speaker_reco.py --config_path='conf' --config_name='titanet_large.yaml'
```
For training speakernet (x-vector) model:
```bash
-python speaker_reco.py --config_path='conf' --config_name='SpeakerNet_verification_3x2x256.yaml'
+python speaker_reco.py --config_path='conf' --config_name='SpeakerNet_verification_3x2x256.yaml'
```
For training ecapa_tdnn (channel-attention) model:
```bash
-python speaker_reco.py --config_path='conf' --config_name='ecapa_tdnn.yaml'
+python speaker_reco.py --config_path='conf' --config_name='ecapa_tdnn.yaml'
```
For step by step tutorial see [notebook](https://github.com/NVIDIA/NeMo/blob/main/tutorials/speaker_tasks/Speaker_Identification_Verification.ipynb).

### Fine Tuning
For fine tuning on a pretrained .nemo speaker recognition model,
```bash
-python speaker_reco_finetune.py --config_path='conf' --config_name='titanet-finetune.yaml'
+python speaker_reco_finetune.py --config_path='conf' --config_name='titanet-finetune.yaml'
```
for fine tuning tips see this [tutorial](https://github.com/NVIDIA/NeMo/blob/main/tutorials/speaker_tasks/Speaker_Identification_Verification.ipynb)

## Inference
-We provide generic scripts for manifest file creation, embedding extraction, Voxceleb evaluation and speaker ID inference. Hence most of the steps would be common and differ slightly based on your end application.
+We provide generic scripts for manifest file creation, embedding extraction, Voxceleb evaluation and speaker ID inference. Hence most of the steps are common and differ only slightly based on your end application.
We explain here the process for voxceleb EER calculation on voxceleb-O cleaned [trail file](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt)

### Manifest Creation
-We first generate manifest file to get embeddings. The embeddings are then used by `voxceleb_eval.py` script to get EER
+We first generate a manifest file to get embeddings. The embeddings are then used by the `voxceleb_eval.py` script to get the EER.
```bash
# create list of files from voxceleb1 test folder (40 speaker test set)
find -iname '*.wav' > voxceleb1_test_files.txt
-python /scripts/speaker_tasks/filelist_to_manifest.py --filelist voxceleb1_test_files.txt --id -3 --out voxceleb1_test_manifest.json
+python /scripts/speaker_tasks/filelist_to_manifest.py --filelist voxceleb1_test_files.txt --id -3 --out voxceleb1_test_manifest.json
```
-### Embedding Extraction
+### Embedding Extraction
Now using the manifest file created, we can extract embeddings to `data` folder using:
```bash
python extract_speaker_embeddings.py --manifest=voxceleb1_test_manifest.json --model_path='titanet_large' --embedding_dir='./'
@@ -65,39 +65,39 @@ embs = speaker_model.get_embedding('audio_path')

### Voxceleb Evaluation
``` bash
-python voxceleb_eval.py --trial_file='/path/to/trail/file' --emb='./embeddings/voxceleb1_test_manifest_embeddings.pkl'
-```
-The above command gives the performance of models on voxceleb-o cleaned trial file.
+python voxceleb_eval.py --trial_file='/path/to/trial/file' --emb='./embeddings/voxceleb1_test_manifest_embeddings.pt'
+```
+The above command gives the performance of models on the voxceleb-O cleaned trial file.

### SpeakerID inference
Using data from an enrollment set, one can infer labels on a test set using various backends such as cosine-similarity or a neural classifier.
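With the cosine-similarity backend, each test embedding is scored against the enrollment embeddings and takes the label of the best match. A minimal sketch of that scoring step, reusing the `get_embedding` helper shown above (audio paths are placeholders):

```python
import torch.nn.functional as F

# Extract one enrollment and one test embedding (placeholder audio paths).
enroll_emb = speaker_model.get_embedding('enrollment_utt.wav')
test_emb = speaker_model.get_embedding('test_utt.wav')

# Cosine similarity in [-1, 1]; higher means more likely the same speaker.
score = F.cosine_similarity(enroll_emb, test_emb).item()
```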
To infer speaker labels using cosine_similarity backend
-```bash
+```bash
python speaker_identification_infer.py data.enrollment_manifest= data.test_manifest= backend.backend_model=cosine_similarity
-```
+```
refer to conf/speaker_identification_infer.yaml for more options.

## Voxceleb Data Preparation
-Scripts we provide for data preparation are very generic and can be applied to any dataset with a few path changes.
-For VoxCeleb datasets, we first download the datasets individually and make a list of audio files. Then we use the script to generate manifest files for training and validation.
-Download [voxceleb1](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1.html) and [voxceleb2](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox2.html) data.
+Scripts we provide for data preparation are very generic and can be applied to any dataset with a few path changes.
+For VoxCeleb datasets, we first download the datasets individually and make a list of audio files. Then we use the script to generate manifest files for training and validation.
+Download [voxceleb1](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1.html) and [voxceleb2](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox2.html) data.

-Once downloaded and uncompressed, use programs such as ffmpeg to convert audio files from m4a format to wav format.
+Once downloaded and uncompressed, use programs such as ffmpeg to convert audio files from m4a format to wav format.
Refer to the following sample command
```bash
-ffmpeg -v 8 -i -f wav -acodec pcm_s16le
+ffmpeg -v 8 -i -f wav -acodec pcm_s16le
```
Generate a list file that contains paths to all the dev audio files from voxceleb1 and voxceleb2 using find command as shown below:
-```bash
+```bash
find -iname '*.wav' > voxceleb1_dev.txt
find -iname '*.wav' > voxceleb2_dev.txt
cat voxceleb1_dev.txt voxceleb2_dev.txt > voxceleb12.txt
-```
+```

-This list file is now used to generate training and validation manifest files using a script provided in `/scripts/speaker_tasks/`. This script has optional arguments to split the whole manifest file in to train and dev and also chunk audio files to smaller segments for robust training (for testing, we don't need this).
+This list file is now used to generate training and validation manifest files using a script provided in `/scripts/speaker_tasks/`. This script has optional arguments to split the whole manifest file into train and dev and also chunk audio files into smaller segments for robust training (for testing, we don't need this).
```bash python /scripts/speaker_tasks/filelist_to_manifest.py --filelist voxceleb12.txt --id -3 --out voxceleb12_manifest.json --split --create_segments diff --git a/examples/speaker_tasks/recognition/extract_speaker_embeddings.py b/examples/speaker_tasks/recognition/extract_speaker_embeddings.py index e67dc7d5aec3..511c4517b42d 100644 --- a/examples/speaker_tasks/recognition/extract_speaker_embeddings.py +++ b/examples/speaker_tasks/recognition/extract_speaker_embeddings.py @@ -15,13 +15,13 @@ """ This is a helper script to extract speaker embeddings based on manifest file Usage: -python extract_speaker_embeddings.py --manifest=/path/to/manifest/file' +python extract_speaker_embeddings.py --manifest=/path/to/manifest/file' --model_path='/path/to/.nemo/file'(optional) --embedding_dir='/path/to/embedding/directory' Args: --manifest: path to manifest file containing audio_file paths for which embeddings need to be extracted ---model_path(optional): path to .nemo speaker verification model file to extract embeddings, if not passed SpeakerNet-M model would +--model_path(optional): path to .nemo speaker verification model file to extract embeddings, if not passed SpeakerNet-M model would be downloaded from NGC and used to extract embeddings --embeddings_dir(optional): path to directory where embeddings need to stored default:'./' @@ -30,7 +30,6 @@ import json import os -import pickle as pkl from argparse import ArgumentParser import numpy as np @@ -43,10 +42,10 @@ def get_embeddings(speaker_model, manifest_file, batch_size=1, embedding_dir='./', device='cuda'): """ - save embeddings to pickle file + save embeddings to cached file Args: - speaker_model: NeMo model - manifest_file: path to the manifest file containing the audio file path from which the + speaker_model: NeMo model + manifest_file: path to the manifest file containing the audio file path from which the embeddings should be extracted batch_size: batch_size for inference embedding_dir: path to directory to store embeddings file @@ -72,15 +71,18 @@ def get_embeddings(speaker_model, manifest_file, batch_size=1, embedding_dir='./ prefix = manifest_file.split('/')[-1].rsplit('.', 1)[-2] name = os.path.join(embedding_dir, prefix) - embeddings_file = name + '_embeddings.pkl' - pkl.dump(out_embeddings, open(embeddings_file, 'wb')) + embeddings_file = name + '_embeddings.pt' + torch.save(out_embeddings, embeddings_file) logging.info("Saved embedding files to {}".format(embedding_dir)) def main(): parser = ArgumentParser() parser.add_argument( - "--manifest", type=str, required=True, help="Path to manifest file", + "--manifest", + type=str, + required=True, + help="Path to manifest file", ) parser.add_argument( "--model_path", @@ -90,7 +92,11 @@ def main(): help="path to .nemo speaker verification model file to extract embeddings, if not passed SpeakerNet-M model would be downloaded from NGC and used to extract embeddings", ) parser.add_argument( - "--batch_size", type=int, default=1, required=False, help="batch size", + "--batch_size", + type=int, + default=1, + required=False, + help="batch size", ) parser.add_argument( "--embedding_dir", diff --git a/examples/speaker_tasks/recognition/voxceleb_eval.py b/examples/speaker_tasks/recognition/voxceleb_eval.py index bf21c62c5709..88749031efdb 100644 --- a/examples/speaker_tasks/recognition/voxceleb_eval.py +++ b/examples/speaker_tasks/recognition/voxceleb_eval.py @@ -14,7 +14,6 @@ import argparse import os -import pickle as pkl import sys import numpy as np @@ -23,14 +22,13 @@ from 
sklearn.metrics import roc_curve
from tqdm import tqdm

-
"""
-This script faciliates to get EER % based on cosine-smilarity
+This script facilitates getting EER % based on cosine-similarity
for Voxceleb dataset.

Args:
trial_file str: path to voxceleb trial file
- emb : path to pickle file of embeddings dictionary (generated from spkr_get_emb.py)
+ emb : path to cached file of embeddings dictionary (generated from spkr_get_emb.py)
save_kaldi_emb: if required pass this argument to save kaldi embeddings for KALDI PLDA training later

Note: order of audio files in manifest file should match the embeddings
"""

@@ -40,8 +38,7 @@ def get_acc(trial_file='', emb='', save_kaldi_emb=False):
trial_score = open('trial_score.txt', 'w')
dirname = os.path.dirname(trial_file)
- with open(emb, 'rb') as f:
- emb = pkl.load(f)
+ emb = torch.load(emb)
trial_embs = []
keys = []
all_scores = []
diff --git a/nemo/collections/asr/models/clustering_diarizer.py b/nemo/collections/asr/models/clustering_diarizer.py
index 59f043f74555..4b76d1755f56 100644
--- a/nemo/collections/asr/models/clustering_diarizer.py
+++ b/nemo/collections/asr/models/clustering_diarizer.py
@@ -14,7 +14,6 @@
import json
import os
-import pickle as pkl
import shutil
import tempfile
from copy import deepcopy
@@ -378,8 +377,8 @@ def _extract_embeddings(self, manifest_file: str, scale_idx: int, num_scales: in
prefix = get_uniqname_from_filepath(manifest_file)
name = os.path.join(embedding_dir, prefix)
- self._embeddings_file = name + '_embeddings.pkl'
- pkl.dump(self.embeddings, open(self._embeddings_file, 'wb'))
+ self._embeddings_file = name + '_embeddings.pt'
+ torch.save(self.embeddings, self._embeddings_file)
logging.info("Saved embedding files to {}".format(embedding_dir))

def diarize(self, paths2audio_files: List[str] = None, batch_size: int = 0):
diff --git a/nemo/collections/asr/models/confidence_ensemble.py b/nemo/collections/asr/models/confidence_ensemble.py
index 7f0b262ae420..4e86befcfcc0 100644
--- a/nemo/collections/asr/models/confidence_ensemble.py
+++ b/nemo/collections/asr/models/confidence_ensemble.py
@@ -12,20 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import os.path
-
-import pickle
-import warnings
from dataclasses import dataclass
-
-try:
-    from joblib.numpy_pickle_utils import _read_fileobject as _validate_joblib_file
-except ImportError:
-    from joblib.numpy_pickle_utils import _validate_fileobject_and_memmap as _validate_joblib_file
import torch
-from sklearn.linear_model import LogisticRegression
-from sklearn.pipeline import Pipeline
-from sklearn.preprocessing import StandardScaler

from nemo.collections.asr.parts.utils.asr_confidence_utils import (
ConfidenceConfig,
@@ -150,84 +138,3 @@ def compute_confidence(hypothesis: Hypothesis, confidence_cfg: ConfidenceConfig)
conf_value = aggr_func(conf_func(filtered_logprobs, v=vocab_size, t=alpha)).cpu().item()

return conf_value
-
-
-def safe_joblib_load(file_path: str) -> Pipeline:
- """
- Safely load a joblib file containing a scikit-learn pipeline.
- - Args: - file_path: Path to the joblib file - - Returns: - Pipeline: A scikit-learn pipeline object - - Raises: - ValueError: If the file doesn't exist or contains unauthorized content - SecurityError: If the file contains potentially malicious content - """ - if not os.path.exists(file_path): - raise ValueError(f"Model file not found: {file_path}") - - # Define whitelist of allowed classes for deserialization - ALLOWED_CLASSES = { - 'sklearn.pipeline.Pipeline', - 'sklearn.preprocessing._data.StandardScaler', - 'sklearn.linear_model._logistic.LogisticRegression', - 'numpy.ndarray', - 'numpy.dtype', - 'numpy._pickle', - } - - class RestrictedUnpickler(pickle.Unpickler): - def find_class(self, module, name): - # Only allow specific classes to be loaded - class_path = f"{module}.{name}" - if class_path in ALLOWED_CLASSES: - if module == "numpy._pickle": - import numpy as np - - return getattr(np, name) - return super().find_class(module, name) - # Log and raise exception for unauthorized classes - raise SecurityError(f"Unauthorized class {class_path} in joblib file") - - try: - # Use joblib's load function with our custom unpickler - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - # First try to load with our custom unpickler - try: - with open(file_path, 'rb') as rawf: - with _validate_joblib_file(rawf, file_path, mmap_mode=None) as stream: - if isinstance(stream, tuple): - stream = stream[0] - - if isinstance(stream, str): - with open(stream, "rb") as f: - model = RestrictedUnpickler(f).load() - else: - model = RestrictedUnpickler(stream).load() - - # Validate the loaded object is a sklearn Pipeline - if not isinstance(model, Pipeline): - raise ValueError("Loaded model must be a scikit-learn Pipeline") - - # Validate pipeline steps - for step_name, step_obj in model.named_steps.items(): - if not (isinstance(step_obj, (StandardScaler, LogisticRegression))): - raise ValueError(f"Unauthorized pipeline step: {type(step_obj)}") - - except (pickle.UnpicklingError, AttributeError) as e: - raise SecurityError(f"Failed to safely load model: {e}") - - return model - - except Exception as e: - raise SecurityError(f"Failed to safely load model: {str(e)}") - - -class SecurityError(Exception): - """Custom exception for security-related errors.""" - - pass diff --git a/nemo/collections/asr/models/configs/diarizer_config.py b/nemo/collections/asr/models/configs/diarizer_config.py index 63f220b5f494..f27222b0b6c0 100644 --- a/nemo/collections/asr/models/configs/diarizer_config.py +++ b/nemo/collections/asr/models/configs/diarizer_config.py @@ -113,7 +113,7 @@ class SpeakerEmbeddingsParams(DiarizerComponentConfig): shift_length_in_sec: Tuple[float] = (0.75, 0.625, 0.5, 0.375, 0.25) # Weight for each scale. None (for single scale) or list with window/shift scale count. ex) [0.33,0.33,0.33] multiscale_weights: Tuple[float] = (1, 1, 1, 1, 1) - # save speaker embeddings in pickle format. + # save speaker embeddings in torch format. 
save_embeddings: bool = True diff --git a/nemo/collections/common/tokenizers/__init__.py b/nemo/collections/common/tokenizers/__init__.py index 4ba946cf9f76..6a71920bf6d4 100644 --- a/nemo/collections/common/tokenizers/__init__.py +++ b/nemo/collections/common/tokenizers/__init__.py @@ -19,7 +19,6 @@ from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer from nemo.collections.common.tokenizers.regex_tokenizer import RegExTokenizer from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer -from nemo.collections.common.tokenizers.tiktoken_tokenizer import TiktokenTokenizer from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec from nemo.collections.common.tokenizers.word_tokenizer import WordTokenizer diff --git a/nemo/collections/common/tokenizers/column_coder.py b/nemo/collections/common/tokenizers/column_coder.py deleted file mode 100644 index bf1ab2f7accb..000000000000 --- a/nemo/collections/common/tokenizers/column_coder.py +++ /dev/null @@ -1,305 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from typing import Dict, List, Tuple - -import numpy as np -from numpy import ndarray -from sklearn.preprocessing import PowerTransformer, QuantileTransformer, RobustScaler - -from nemo.utils import logging - -__all__ = ["IntCode", "FloatCode", "CategoryCode", "ColumnCodes"] - - -class Code(object): - def compute_code(self, data_series: ndarray): - """ - @params: - data_series: an array of input data used to calculate mapping - """ - raise NotImplementedError() - - def __init__(self, col_name: str, code_len: int, start_id: int, fillall: bool = True, hasnan: bool = True): - """ - @params: - col_name: name of the column - code_len: number of tokens used to code the column. - start_id: offset for token_id. - fillall: if True, reserve space for digit number even the digit number is - not present in the data_series. Otherwise, only reserve space for the numbers - in the data_series. - hasnan: if True, reserve space for nan - """ - self.name = col_name - self.code_len = code_len - self.start_id = start_id - self.end_id = start_id - self.fillall = fillall - self.hasnan = hasnan - - def encode(self, item: str) -> List[int]: - raise NotImplementedError() - - def decode(self, ids: List[int]) -> str: - raise NotImplementedError() - - @property - def code_range(self) -> List[Tuple[int, int]]: - """ - get the vocab id range for each of the encoded tokens - @returns [(min, max), (min, max), ...] 
- """ - return [(self.start_id, self.end_id)] - - -class IntCode(Code): - def __init__( - self, col_name: str, code_len: int, start_id: int, fillall: bool = True, base: int = 100, hasnan: bool = True - ): - super().__init__(col_name, code_len, start_id, fillall, hasnan) - self.base = base - self.int_min: int = None - - def compute_code(self, data_series: ndarray): - significant_val = self.array_convert_to_int(data_series) - - digits_id_to_item = [{} for _ in range(self.code_len)] - digits_item_to_id = [{} for _ in range(self.code_len)] - for i in range(self.code_len): - id_to_item = digits_id_to_item[i] - item_to_id = digits_item_to_id[i] - v = (significant_val // self.base ** i) % self.base - if self.fillall: - uniq_items = range(0, self.base) - else: - uniq_items = sorted(np.unique(v).tolist()) - for k in range(len(uniq_items)): - item = str(uniq_items[k]) - item_to_id[item] = self.end_id - id_to_item[self.end_id] = item - self.end_id += 1 - self.digits_id_to_item = digits_id_to_item - self.digits_item_to_id = digits_item_to_id - self.NA_token = 'nan' - if self.hasnan: - self.end_id += 1 # add the N/A token - codes = [] - ranges = self.code_range - for i in ranges: - codes.append(i[1] - 1) - self.NA_token_id = codes - - def array_convert_to_int(self, val: ndarray): - val = val.astype(int) - self.int_min = val.min() - return val - self.int_min - - def convert_to_int(self, val: float) -> int: - return int(val) - self.int_min - - def reverse_convert_to_int(self, val: int) -> int: - return val + self.int_min - - @property - def code_range(self) -> List[Tuple[int, int]]: - """ - get the vocab id range for each of the encoded tokens - @returns [(min, max), (min, max), ...] - """ - # first largest digits - outputs = [] - c = 0 - for i in reversed(range(self.code_len)): - ids = self.digits_id_to_item[i].keys() - if c == 0: - if self.hasnan: - outputs.append((min(ids), max(ids) + 2)) # the first token contains the N/A - else: - outputs.append((min(ids), max(ids) + 1)) # non N/A - else: - outputs.append((min(ids), max(ids) + 1)) - c += 1 - return outputs - - def encode(self, item: str) -> List[int]: - if self.hasnan and item == self.NA_token: - return self.NA_token_id - elif not self.hasnan and item == self.NA_token: - raise ValueError(f"colum {self.name} cannot handle nan, please set hasnan=True") - val = float(item) - val_int = self.convert_to_int(val) - digits = [] - for i in range(self.code_len): - digit = (val_int // self.base ** i) % self.base - digits.append(str(digit)) - if (val_int // self.base ** self.code_len) != 0: - raise ValueError("not right length") - codes = [] - for i in reversed(range(self.code_len)): - digit_str = digits[i] - if digit_str in self.digits_item_to_id[i]: - codes.append(self.digits_item_to_id[i][digit_str]) - else: - # find the nearest encode id - allowed_digits = np.array([int(d) for d in self.digits_item_to_id[i].keys()]) - near_id = np.argmin(np.abs(allowed_digits - int(digit_str))) - digit_str = str(allowed_digits[near_id]) - codes.append(self.digits_item_to_id[i][digit_str]) - logging.warning('out of domain num is encounterd, use nearest code') - return codes - - def decode(self, ids: List[int]) -> str: - if self.hasnan and ids[0] == self.NA_token_id[0]: - return self.NA_token - v = 0 - for i in reversed(range(self.code_len)): - digit = int(self.digits_id_to_item[i][ids[self.code_len - i - 1]]) - v += digit * self.base ** i - v = self.reverse_convert_to_int(v) - return str(v) - - -class FloatCode(IntCode): - def __init__( - self, - col_name: str, - code_len: 
int, - start_id: int, - fillall: bool = True, - base: int = 100, - hasnan: bool = True, - transform: str = 'quantile', - ): - super().__init__(col_name, code_len, start_id, fillall, base, hasnan) - if transform == 'yeo-johnson': - self.scaler = PowerTransformer(standardize=True) - elif transform == 'quantile': - self.scaler = QuantileTransformer(output_distribution='uniform', n_quantiles=100) - elif transform == 'robust': - self.scaler = RobustScaler() - else: - raise ValueError('Supported data transformations are "yeo-johnson", "quantile", and "robust"') - - def convert_to_int(self, val: float) -> int: - val = np.expand_dims(np.array(val), axis=0) - values = self.scaler.transform(val[:, None])[:, 0] - self.mval - values = (values * self.base ** self.extra_digits).astype(int) - output = values[0] - return output - - def array_convert_to_int(self, val: ndarray): - values = self.scaler.fit_transform(val[:, None])[:, 0] - self.mval = values.min() - values = values - self.mval - digits = int(math.log(values.max(), self.base)) + 1 - # extra digits used for 'float' part of the number - extra_digits = self.code_len - digits - if extra_digits < 0: - raise ValueError("need large length to code the nummber") - self.extra_digits = extra_digits - values = (values * self.base ** self.extra_digits).astype(int) - return values - - def reverse_convert_to_int(self, val: int) -> float: - val = val / self.base ** self.extra_digits - val = np.expand_dims(np.array(val), axis=0) - v = self.scaler.inverse_transform(val[:, None] + self.mval)[0, 0] - return v - - def decode(self, ids: List[int]) -> str: - if self.hasnan and ids[0] == self.NA_token_id[0]: - return self.NA_token - v = 0 - for i in reversed(range(self.code_len)): - digit = int(self.digits_id_to_item[i][ids[self.code_len - i - 1]]) - v += digit * self.base ** i - v = self.reverse_convert_to_int(v) - accuracy = max(int(abs(np.log10(0.1 / self.base ** self.extra_digits))), 1) - return f"{v:.{accuracy}f}" - - -class CategoryCode(Code): - def __init__(self, col_name: str, start_id: int): - super().__init__(col_name, 1, start_id, True, False) - - def compute_code(self, data_series: ndarray): - uniq_items = np.unique(data_series).tolist() - id_to_item = {} - item_to_id = {} - for i in range(len(uniq_items)): - item = str(uniq_items[i]) - item_to_id[item] = self.end_id - id_to_item[self.end_id] = item - self.end_id += 1 - self.id_to_item = id_to_item - self.item_to_id = item_to_id - - def encode(self, item) -> List[int]: - return [self.item_to_id[item]] - - def decode(self, ids: List[int]) -> str: - return self.id_to_item[ids[0]] - - -column_map = {"int": IntCode, "float": FloatCode, "category": CategoryCode} - - -class ColumnCodes(object): - def __init__(self): - self.column_codes: Dict[str, Code] = {} - self.columns = [] - self.sizes = [] - - @property - def vocab_size(self): - return self.column_codes[self.columns[-1]].end_id - - def register(self, name: str, ccode: Code): - self.columns.append(name) - self.column_codes[name] = ccode - self.sizes.append(ccode.code_len) - - def encode(self, col: str, item: str) -> List[int]: - if col in self.column_codes: - return self.column_codes[col].encode(item) - else: - raise ValueError(f"cannot encode {col} {item}") - - def decode(self, col: str, ids: List[int]) -> str: - if col in self.column_codes: - return self.column_codes[col].decode(ids) - else: - raise ValueError("cannot decode") - - def get_range(self, column_id: int) -> List[Tuple[int, int]]: - return self.column_codes[self.columns[column_id]].code_range - 
- @classmethod - def get_column_codes(cls, column_configs, example_arrays): - column_codes = cls() - beg = 0 - cc = None - for config in column_configs: - col_name = config['name'] - coder = column_map[config['code_type']] - args = config.get('args', {}) - start_id = beg if cc is None else cc.end_id - args['start_id'] = start_id - args['col_name'] = col_name - cc = coder(**args) - cc.compute_code(example_arrays[col_name]) - column_codes.register(col_name, cc) - return column_codes diff --git a/nemo/collections/common/tokenizers/tabular_tokenizer.py b/nemo/collections/common/tokenizers/tabular_tokenizer.py deleted file mode 100644 index 5fa36832959c..000000000000 --- a/nemo/collections/common/tokenizers/tabular_tokenizer.py +++ /dev/null @@ -1,199 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pickle -from typing import List - -import numpy - -from nemo.collections.common.tokenizers.column_coder import ColumnCodes -from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec - -__all__ = ['TabularTokenizer'] - -END_OF_TEXT = '<|endoftext|>' -NEW_LINE = '\n' - - -def find_index_of(list_input, item): - output = -1 - try: - output = list_input.index(item) - except ValueError: - pass - return output - - -class TabularTokenizer(TokenizerSpec): - def __init__(self, coder, special_tokens=[END_OF_TEXT, NEW_LINE], delimiter=','): - if isinstance(coder, ColumnCodes): - self.code_column: ColumnCodes = coder - else: - with open(coder, 'rb') as handle: - self.code_column: ColumnCodes = pickle.load(handle) - self.num_columns = len(self.code_column.columns) - self.special_tokens = {} - self.special_tokens_decoder = {} - self.add_special_tokens(special_tokens) - self.delimiter = delimiter - self.eod_id = self.special_tokens[END_OF_TEXT] - self.eos_id = self.eod_id - self.bos_id = self.eos_id - - def __len__(self): - return self.vocab_size - - @property - def vocab_size(self): - return max(self.special_tokens_decoder.keys()) + 1 - - def text_to_ids(self, text): - return self.encode(text) - - def ids_to_text(self, token_ids): - return self.decode(token_ids) - - @property - def eod(self): - return self.eod_id - - @property - def eor(self): - return self.special_tokens[NEW_LINE] - - def add_special_tokens(self, special_tokens): - """ Add a list of additional tokens to the encoder. - The additional tokens are indexed starting from the last - index of the - current vocabulary in the order of the `special_tokens` list. - """ - if not special_tokens: - self.special_tokens = {} - self.special_tokens_decoder = {} - return - new = dict( - (tok, self.code_column.vocab_size + i) - for i, tok in enumerate(special_tokens) - if tok not in self.special_tokens - ) - self.special_tokens.update(new) - self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()} - - def text_to_tokens(self, text): - """ Tokenize a string. 
""" - tokens = [] - rows = text.split(NEW_LINE) - num_rows = len(rows) - for row_id in range(num_rows): - row = rows[row_id] - if row == '': - continue - fields = row.split(self.delimiter) - for f in fields: - splits = f.split(END_OF_TEXT) - if len(splits) == 1: - tokens.append(f.strip()) - elif len(splits) == 2: - if splits[0] != '': - tokens.append(splits[0].strip()) - tokens.append(END_OF_TEXT) - if splits[1] != '': - tokens.append(splits[1].strip()) - else: - raise ValueError("delimiter error") - if row_id != num_rows - 1: - tokens.append(NEW_LINE) - return tokens - - def tokens_to_ids(self, tokens: List[str]): - """ Converts a sequence of tokens into ids using the vocab. """ - ids = [] - cindex = 0 - if NEW_LINE in tokens: - idd = tokens.index(NEW_LINE) - cindex = (self.num_columns - idd) % self.num_columns - for token in tokens: - - if token in self.special_tokens: - ids.append(self.special_tokens[token]) - else: - index = cindex % self.num_columns - column = self.code_column.columns[index] - ids.extend(self.code_column.encode(column, token)) - cindex += 1 - return ids - - def ids_to_tokens(self, ids, skip_special_tokens=False): - """Converts a sequence of ids in Tabular tokens using the vocab.""" - tokens = [] - sizes = self.code_column.sizes - ids_size = sum(sizes) - cindex = 0 - eor_pos = find_index_of(ids, self.eor) - eod_pos = find_index_of(ids, self.eod) - if eor_pos >= 0 and eod_pos >= 0: - idd = min(eor_pos, eod_pos) - cindex = (ids_size - idd) % ids_size - elif eor_pos >= 0 and eod_pos < 0: - idd = eor_pos - cindex = (ids_size - idd) % ids_size - elif eod_pos >= 0 and eor_pos < 0: - idd = eod_pos - cindex = (ids_size - idd) % ids_size - cum_sizes = numpy.cumsum(sizes) - old_column_index = -1 - token_ids = [] - for i in ids: - if i in self.special_tokens_decoder: - if not skip_special_tokens: - tokens.append(self.special_tokens_decoder[i]) - else: - index = cindex % ids_size - column_index = numpy.where(index < cum_sizes)[0][0] - column = self.code_column.columns[column_index] - if old_column_index != column_index: - token_ids = [i] - old_column_index = column_index - else: - token_ids.append(i) - if len(token_ids) == sizes[column_index]: - tokens.append(self.code_column.decode(column, token_ids)) - cindex += 1 - return tokens - - def encode(self, text): - return self.tokens_to_ids(self.text_to_tokens(text)) - - def decode(self, token_ids): - tokens = self.ids_to_tokens(token_ids, skip_special_tokens=False) - return self.tokens_to_text(tokens) - - def tokens_to_text(self, tokens): - all_lines = [] - line = [] - for token in tokens: - if token == END_OF_TEXT or token == NEW_LINE: - if len(line) != 0: - line_text = self.delimiter.join(line) - all_lines.append(line_text) - all_lines.append(token) - line = [] - else: - line.append(token) - if len(line) != 0: - # remaining items - line_text = self.delimiter.join(line) - all_lines.append(line_text) - text = "".join(all_lines) - return text diff --git a/nemo/collections/common/tokenizers/tokenizer_utils.py b/nemo/collections/common/tokenizers/tokenizer_utils.py index 7c9ea39dff9f..fbdfcfa847c6 100644 --- a/nemo/collections/common/tokenizers/tokenizer_utils.py +++ b/nemo/collections/common/tokenizers/tokenizer_utils.py @@ -149,6 +149,8 @@ def get_tokenizer( return tokenizer +# TODO: this is unused code, should remove all unused tokenizers +# Should also remove it from docs/source/core/core.rst def get_nmt_tokenizer( library: str = "sentencepiece", model_name: Optional[str] = None, @@ -253,10 +255,6 @@ def get_nmt_tokenizer( 
special_tokens=special_tokens_dict, chat_template=chat_template, ) - elif library == "tabular": - from nemo.collections.common.tokenizers.tabular_tokenizer import TabularTokenizer - - return TabularTokenizer(vocab_file, delimiter=delimiter) elif library == "tiktoken": from nemo.collections.common.tokenizers.tiktoken_tokenizer import TiktokenTokenizer diff --git a/nemo/collections/tts/data/dataset.py b/nemo/collections/tts/data/dataset.py index 3b5ace9af9e9..857207cc6274 100644 --- a/nemo/collections/tts/data/dataset.py +++ b/nemo/collections/tts/data/dataset.py @@ -15,7 +15,6 @@ import json import math import os -import pickle import random from collections import defaultdict from pathlib import Path @@ -133,7 +132,7 @@ def __init__( min_duration (Optional[float]): Min duration of audio clips in seconds. All samples lower than this will be pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load audio to compute duration. Defaults to None which does not prune. - ignore_file (Optional[Union[str, Path]]): The location of a pickle-saved list of audio paths + ignore_file (Optional[Union[str, Path]]): The location of a json-saved list of audio paths that will be pruned prior to training. Defaults to None which does not prune. trim (bool): Whether to apply `librosa.effects.trim` to trim leading and trailing silence from an audio signal. Defaults to False. @@ -343,8 +342,8 @@ def __init__( def filter_files(data, ignore_file, min_duration, max_duration, total_duration): if ignore_file: logging.info(f"Using {ignore_file} to prune dataset.") - with open(Path(ignore_file).expanduser(), "rb") as f: - wavs_to_ignore = set(pickle.load(f)) + with open(Path(ignore_file).expanduser(), "r") as f: + wavs_to_ignore = set(json.load(f)) filtered_data: List[Dict] = [] pruned_duration = 0 if total_duration is not None else None @@ -941,7 +940,7 @@ def __init__( min_duration (Optional[float]): Min duration of audio clips in seconds. All samples lower than this will be pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load audio to compute duration. Defaults to None which does not prune. - ignore_file (Optional[Union[str, Path]]): The location of a pickle-saved list of audio paths + ignore_file (Optional[Union[str, Path]]): The location of a json-saved list of audio paths that will be pruned prior to training. Defaults to None which does not prune. trim (bool): Whether to apply librosa.effects.trim to the audio file. Defaults to False. load_precomputed_mel (bool): Whether to load precomputed mel (useful for fine-tuning). @@ -1092,7 +1091,7 @@ def __init__( min_duration (Optional[float]): Min duration of audio clips in seconds. All samples lower than this will be pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load audio to compute duration. Defaults to None which does not prune. - ignore_file (Optional[Union[str, Path]]): The location of a pickle-saved list of audio paths + ignore_file (Optional[Union[str, Path]]): The location of a json-saved list of audio paths that will be pruned prior to training. Defaults to None which does not prune. trim (bool): Whether to apply `librosa.effects.trim` to trim leading and trailing silence from an audio signal. Defaults to False. 
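Since `ignore_file` now points to a JSON-saved list rather than a pickle, existing prune lists have to be regenerated. A minimal sketch of writing one in the format the new `json.load` call above expects (file names are placeholders):

```python
import json

# Audio paths to prune from the dataset before training.
wavs_to_ignore = ["/data/noisy_sample.wav", "/data/clipped_sample.wav"]

with open("ignore_file.json", "w") as f:
    json.dump(wavs_to_ignore, f)
```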
diff --git a/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_ctc.py b/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_ctc.py index 7afe8d922ecf..db98c9f831c4 100644 --- a/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_ctc.py +++ b/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_ctc.py @@ -15,8 +15,8 @@ """ # This script would evaluate an N-gram language model trained with KenLM library (https://github.com/kpu/kenlm) in -# fusion with beam search decoders on top of a trained ASR model with CTC decoder. To evaluate a model with -# Transducer (RNN-T) decoder use another script 'scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_transducer.py'. +# fusion with beam search decoders on top of a trained ASR model with CTC decoder. To evaluate a model with +# Transducer (RNN-T) decoder use another script 'scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_transducer.py'. # NeMo's beam search decoders are capable of using the KenLM's N-gram models # to find the best candidates. This script supports both character level and BPE level # encodings and models which is detected automatically from the type of the model. @@ -59,12 +59,12 @@ import contextlib import json import os -import pickle from dataclasses import dataclass, field, is_dataclass from pathlib import Path from typing import List, Optional import editdistance +import msgpack import numpy as np import torch from omegaconf import MISSING, OmegaConf @@ -78,7 +78,6 @@ from nemo.core.config import hydra_runner from nemo.utils import logging - # fmt: off @@ -113,7 +112,7 @@ class EvalBeamSearchNGramConfig: decoding_strategy: str = "beam" decoding: ctc_beam_decoding.BeamCTCInferConfig = field(default_factory=lambda: ctc_beam_decoding.BeamCTCInferConfig(beam_size=128)) - + text_processing: Optional[TextProcessingConfig] = field(default_factory=lambda: TextProcessingConfig( punctuation_marks = ".,?", separate_punctuation = False, @@ -278,10 +277,10 @@ def main(cfg: EvalBeamSearchNGramConfig): target_transcripts = apply_text_processing(punctuation_capitalization, cfg, target_transcripts) if cfg.hyps_cache_file and os.path.exists(cfg.hyps_cache_file): - logging.info(f"Found a pickle file of hypotheses at '{cfg.hyps_cache_file}'.") - logging.info(f"Loading the cached pickle file of hypotheses from '{cfg.hyps_cache_file}' ...") + logging.info(f"Found a cached file of hypotheses at '{cfg.hyps_cache_file}'.") + logging.info(f"Loading the cached file of hypotheses from '{cfg.hyps_cache_file}' ...") with open(cfg.hyps_cache_file, 'rb') as probs_file: - all_hyps = pickle.load(probs_file) + all_hyps = msgpack.load(probs_file) if len(all_hyps) != len(audio_file_paths): raise ValueError( @@ -298,9 +297,9 @@ def main(cfg: EvalBeamSearchNGramConfig): if cfg.hyps_cache_file: os.makedirs(os.path.split(cfg.hyps_cache_file)[0], exist_ok=True) - logging.info(f"Writing pickle files of hypotheses at '{cfg.hyps_cache_file}'...") + logging.info(f"Writing cached files of hypotheses at '{cfg.hyps_cache_file}'...") with open(cfg.hyps_cache_file, 'wb') as f_dump: - pickle.dump(all_hyps, f_dump) + msgpack.dump(all_hyps, f_dump) wer_dist_greedy = 0 cer_dist_greedy = 0 diff --git a/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_transducer.py b/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_transducer.py index 57bf9db6f3bd..febf085e0263 100644 --- a/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_transducer.py +++ 
b/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_transducer.py @@ -59,13 +59,13 @@ import contextlib import json import os -import pickle import tempfile from dataclasses import dataclass, field, is_dataclass from pathlib import Path from typing import List, Optional import editdistance +import msgpack import numpy as np import torch from omegaconf import MISSING, OmegaConf @@ -287,10 +287,10 @@ def main(cfg: EvalBeamSearchNGramConfig): audio_file_paths.append(str(audio_file.absolute())) if cfg.probs_cache_file and os.path.exists(cfg.probs_cache_file): - logging.info(f"Found a pickle file of probabilities at '{cfg.probs_cache_file}'.") - logging.info(f"Loading the cached pickle file of probabilities from '{cfg.probs_cache_file}' ...") + logging.info(f"Found a cached file of probabilities at '{cfg.probs_cache_file}'.") + logging.info(f"Loading the cached file of probabilities from '{cfg.probs_cache_file}' ...") with open(cfg.probs_cache_file, 'rb') as probs_file: - all_probs = pickle.load(probs_file) + all_probs = msgpack.load(probs_file) if len(all_probs) != len(audio_file_paths): raise ValueError( @@ -330,9 +330,9 @@ def main(cfg: EvalBeamSearchNGramConfig): all_probs.append(encoded_no_pad) if cfg.probs_cache_file: - logging.info(f"Writing pickle files of probabilities at '{cfg.probs_cache_file}'...") + logging.info(f"Writing cached files of probabilities at '{cfg.probs_cache_file}'...") with open(cfg.probs_cache_file, 'wb') as f_dump: - pickle.dump(all_probs, f_dump) + msgpack.dump(all_probs, f_dump) if cfg.decoding_strategy == "greedy_batch": asr_model = asr_model.to('cpu') diff --git a/scripts/asr_language_modeling/ngram_lm/eval_wfst_decoding_ctc.py b/scripts/asr_language_modeling/ngram_lm/eval_wfst_decoding_ctc.py index 71ff48563683..b677a769e06b 100644 --- a/scripts/asr_language_modeling/ngram_lm/eval_wfst_decoding_ctc.py +++ b/scripts/asr_language_modeling/ngram_lm/eval_wfst_decoding_ctc.py @@ -56,12 +56,12 @@ import contextlib import json import os -import pickle from dataclasses import dataclass, field, is_dataclass from pathlib import Path from typing import List, Optional import editdistance +import msgpack import numpy as np import torch from omegaconf import MISSING, OmegaConf @@ -110,7 +110,7 @@ class EvalWFSTNGramConfig: decoding: ctc_beam_decoding.WfstCTCInferConfig = field( default_factory=lambda: ctc_beam_decoding.WfstCTCInferConfig(beam_size=1) ) - + text_processing: Optional[TextProcessingConfig] = field(default_factory=lambda: TextProcessingConfig( punctuation_marks = ".,?", separate_punctuation = False, @@ -288,10 +288,10 @@ def main(cfg: EvalWFSTNGramConfig): target_transcripts = punctuation_capitalization.separate_punctuation(target_transcripts) if cfg.probs_cache_file and os.path.exists(cfg.probs_cache_file): - logging.info(f"Found a pickle file of probabilities at '{cfg.probs_cache_file}'.") - logging.info(f"Loading the cached pickle file of probabilities from '{cfg.probs_cache_file}' ...") + logging.info(f"Found a cached file of probabilities at '{cfg.probs_cache_file}'.") + logging.info(f"Loading the cached file of probabilities from '{cfg.probs_cache_file}' ...") with open(cfg.probs_cache_file, 'rb') as probs_file: - all_probs = pickle.load(probs_file) + all_probs = msgpack.load(probs_file) if len(all_probs) != len(audio_file_paths): raise ValueError( @@ -312,9 +312,9 @@ def main(cfg: EvalWFSTNGramConfig): all_probs = all_logits if cfg.probs_cache_file: os.makedirs(os.path.split(cfg.probs_cache_file)[0], exist_ok=True) - logging.info(f"Writing 
pickle files of probabilities at '{cfg.probs_cache_file}'...") + logging.info(f"Writing cached files of probabilities at '{cfg.probs_cache_file}'...") with open(cfg.probs_cache_file, 'wb') as f_dump: - pickle.dump(all_probs, f_dump) + msgpack.dump(all_probs, f_dump) wer_dist_greedy = 0 cer_dist_greedy = 0 diff --git a/scripts/freesound_download_resample/download_resample_freesound.sh b/scripts/freesound_download_resample/download_resample_freesound.sh deleted file mode 100644 index af0a07f44d52..000000000000 --- a/scripts/freesound_download_resample/download_resample_freesound.sh +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -#!/bin/bash - -# This is bash script actually run the downloading and resampling script. -# See instructions in freesound_download.py - -# Change this arguments if you want -page_size=100 # Number of sounds per page -max_samples=200 # Maximum number of sound samples -min_filesize=0 # Minimum filesize allowed (in MB) -max_filesize=100 # Maximum filesize allowed (in MB) - -if [[ $# -ne 3 ]]; then - echo "Require number of all files | data directory | resample data directory as arguments to the script" - exit 2 -fi - - -NUM_ALL_FILES=$1 -DATADIR=$2 -RESAMPLE_DATADIR=$3 - - -if [ ! -d "$DATADIR" ]; then - echo "Creating dir $DATADIR" - mkdir -p "$DATADIR" -fi - -if [ ! -d "$RESAMPLE_DATADIR" ]; then - echo "Creating dir $RESAMPLE_DATADIR" - mkdir -p "$RESAMPLE_DATADIR" -fi - -# we just need background categories for constructing dataset, feel free to include other (speech) categories for testing and training your VAD model -# background -categories=( - "Air brake" - "Static" - "Acoustic environment" - "Distortion" - "Tape hiss" - "Hubbub" - "Vibration" - "Cacophony" - "Throbbing" - "Reverberation" - "Inside, public space" - "Inside, small room" - "Echo" - "Outside, rural" - "Outside, natural" - "Outside, urban" - "Outside, manmade" - "Car" - "Bus" - "Traffic noise" - "Roadway noise" - "Truck" - "Emergency vehicle" - "Motorcycle" - "Aircraft engine" - "Aircraft" - "Helicopter" - "Bicycle" - "Skateboard" - "Subway, metro, underground" - "Railroad car" - "Train wagon" - "Train" - "Sailboat" - "Rowboat" - "Ship" -) - - -WAV_FILECOUNT="$(find $DATADIR -name '*.wav' -type f | wc -l)" -FLAC_FILECOUNT="$(find $DATADIR -name '*.flac' -type f | wc -l)" -FILECOUNT="$((WAV_FILECOUNT + FLAC_FILECOUNT))" -echo "File count: " $FILECOUNT - - -while((FILECOUNT <= NUM_ALL_FILES)) -do - for category in "${categories[@]}" - do - python freesound_download.py --data_dir "${DATADIR}" --category "${category}" --page_size "${page_size}" --max_samples "${max_samples}" --min_filesize "${min_filesize}" --max_filesize "${max_filesize}" - ret=$? 
- if [ $ret -ne 0 ]; then - exit 1 - fi - done - - WAV_FILECOUNT="$(find $DATADIR -name '*.wav' -type f | wc -l)" - FLAC_FILECOUNT="$(find $DATADIR -name '*.flac' -type f | wc -l)" - FILECOUNT="$((WAV_FILECOUNT + FLAC_FILECOUNT))" - echo "Current file count is: " $FILECOUNT -done - -# RESAMPLE -echo "Got enough data. Start resample!" -python freesound_resample.py --data_dir="${DATADIR}" --resampled_dir="${RESAMPLE_DATADIR}" - -echo "Done resample data!" diff --git a/scripts/freesound_download_resample/freesound_download.py b/scripts/freesound_download_resample/freesound_download.py deleted file mode 100644 index 438d0d20cc26..000000000000 --- a/scripts/freesound_download_resample/freesound_download.py +++ /dev/null @@ -1,590 +0,0 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import argparse -import os -import pickle -import time - -try: - import librosa - import requests - import requests_oauthlib - from joblib import Parallel, delayed - from oauthlib.oauth2 import TokenExpiredError -except (ModuleNotFoundError, ImportError) as e: - raise e - -try: - import freesound -except ModuleNotFoundError as e: - raise ModuleNotFoundError( - "freesound is not installed. Execute `pip install --no-cache-dir git+https://github.com/MTG/freesound-python.git` in terminal" - ) - - -""" -Instructions -1. We will need some requirements including freesound, requests, requests_oauthlib, joblib, librosa and sox. If they are not installed, please run `pip install -r freesound_requirements.txt` -2. Create an API key for freesound.org at https://freesound.org/help/developers/ -3. Create a python file called `freesound_private_apikey.py` and add lined `api_key = ` and `client_id = ` -4. Authorize by run `python freesound_download.py --authorize` and visit website, and paste response code -5. Feel free to change any arguments in download_resample_freesound.sh such as max_samples and max_filesize -6. 
Run `bash download_resample_freesound.sh ` -""" - -# Import the API Key -try: - from freesound_private_apikey import api_key, client_id - - print("API Key found !") -except ImportError: - raise ImportError( - "Create a python file called `freesound_private_apikey.py` and add lined `api_key = ` and `client_id = `" - ) - -auth_url = 'https://freesound.org/apiv2/oauth2/authorize/' -redirect_url = 'https://freesound.org/home/app_permissions/permission_granted/' -token_url = 'https://freesound.org/apiv2/oauth2/access_token/' -scope = ["read", "write"] - -BACKGROUND_CLASSES = [ - "Air brake", - "Static", - "Acoustic environment", - "Distortion", - "Tape hiss", - "Hubbub", - "Vibration", - "Cacophony", - "Throbbing", - "Reverberation", - "Inside, public space", - "Inside, small room", - "Echo", - "Outside, rural", - "Outside, natural", - "Outside, urban", - "Outside, manmade", - "Car", - "Bus", - "Traffic noise", - "Roadway noise", - "Truck", - "Emergency vehicle", - "Motorcycle", - "Aircraft engine", - "Aircraft", - "Helicopter", - "Bicycle", - "Skateboard", - "Subway, metro, underground", - "Railroad car", - "Train wagon", - "Train", - "Sailboat", - "Rowboat", - "Ship", -] - -SPEECH_CLASSES = [ - "Male speech", - "Female speech", - "Speech synthesizer", - "Babbling", - "Conversation", - "Child speech", - "Narration", - "Laughter", - "Yawn", - "Whispering", - "Whimper", - "Baby cry", - "Sigh", - "Groan", - "Humming", - "Male singing", - "Female singing", - "Child singing", - "Children shouting", -] - - -def initialize_oauth(): - # If token already exists, then just load it - if os.path.exists('_token.pkl'): - token = unpickle_object('_token') - oauth = requests_oauthlib.OAuth2Session(client_id, redirect_uri=redirect_url, scope=scope, token=token) - - else: - # Construct a new token after OAuth2 flow - # Initialize a OAuth2 session - oauth = requests_oauthlib.OAuth2Session(client_id, redirect_uri=redirect_url, scope=scope) - - authorization_url, state = oauth.authorization_url(auth_url) - print(f"Visit below website and paste access token below : \n\n{authorization_url}\n") - - authorization_response = input("Paste authorization response code here :\n") - - token = oauth.fetch_token( - token_url, - authorization_response=authorization_response, - code=authorization_response, - client_secret=api_key, - ) - - # Save the token generated - pickle_object(token, '_token') - - return oauth, token - - -def instantiate_session(): - # Reconstruct session in process, and force singular execution thread to reduce session - # connections to server - token = unpickle_object('_token') - session = requests_oauthlib.OAuth2Session(client_id, redirect_uri=redirect_url, scope=scope, token=token) - adapter = requests.adapters.HTTPAdapter(pool_connections=1, pool_maxsize=1) - session.mount('http://', adapter) - return session - - -def refresh_token(session): - print("Refreshing tokens...") - # Token expired, perform token refresh - extras = {'client_id': client_id, 'client_secret': api_key} - token = session.refresh_token(token_url, **extras) - print("Token refresh performed...") - # Save the refreshed token - pickle_object(token, '_token') - return session - - -def pickle_object(token, name): - with open(name + '.pkl', 'wb') as f: - pickle.dump(token, f) - - -def unpickle_object(name): - fp = name + '.pkl' - if os.path.exists(fp): - with open(fp, 'rb') as f: - token = pickle.load(f) - - return token - else: - raise FileNotFoundError('Token not found!') - - -def is_resource_limited(e: freesound.FreesoundException): - 
""" - Test if the reason for a freesound exception was either rate limit - or daily limit. - - If it was for either reason, sleep for an appropriate delay and return - to try again. - - Args: - e: Freesound Exception object - - Returns: - A boolean which describes whether the error was due to some - api limit issue, or if it was some other reason. - - If false is returned, then the user should carefully check the cause - and log it. - """ - detail = e.detail['detail'] - - if '2000' in detail: - # This is the request limit, hold off for 1 hour and try again - print(f"Hit daily limit, sleeping for 20 minutes.") - time.sleep(60 * 20) - return True - - elif '60' in detail: - # This is the request limit per minute, hold off for 1 minute and try again - print(f"Hit rate limit, sleeping for 1 minute.") - time.sleep(60) - return True - - else: - return False - - -def prepare_client(client: freesound.FreesoundClient, token) -> freesound.FreesoundClient: - # Initialize the client with token auth - client.set_token(token['access_token'], auth_type='oauth') - print("Client ready !") - return client - - -def get_text_query_with_resource_limit_checks(client, query: str, filters: list, fields: str, page_size: int): - """ - Performs a text query, checks for rate / api limits, and retries. - - Args: - client: FreesoundAPI client - query: query string (either exact or inexact) - filters: list of string filters - fields: String of values to recover - page_size: samples per page returned - - Returns: - - """ - pages = None - attempts = 20 - - while pages is None: - try: - pages = client.text_search(query=query, filter=" ".join(filters), fields=fields, page_size=str(page_size),) - - except freesound.FreesoundException as e: - # Most probably a rate limit or a request limit - # Check if that was the case, and wait appropriate ammount of time - # for retry - was_resource_limited = is_resource_limited(e) - - # If result of test False, it means that failure was due to some other reason. - # Log it, then break loop - if not was_resource_limited: - print(e.with_traceback(None)) - break - - attempts -= 1 - - # Attempt to refresh tokens if it fails multiple times - if attempts % 5 == 0 and attempts > 0: - session = instantiate_session() - refresh_token(session) - session.close() - token = unpickle_object('_token') - client = prepare_client(client, token) - - if attempts <= 0: - print(f"Failed to query pages for '{query}' after 10 attempts, skipping query") - break - - if pages is None: - print(f"Query attempts remaining = {attempts}") - - return client, pages - - -def get_resource_with_auto_refresh(session, download_url): - """ - Attempts download of audio with a token refresh if necessary. 
- """ - try: - result = session.get(download_url) - - except TokenExpiredError as e: - session = refresh_token(session) - result = session.get(download_url) - - except Exception as e: - result = None - - print(f"Skipping file {download_url} due to exception below\n\n") - print(e) - - return result.content - - -def download_song(basepath, id, name, download_url): - # Cleanup name - name = name.encode('ascii', 'replace').decode() - name = name.replace("?", "-") - name = name.replace(":", "-") - name = name.replace("(", "-") - name = name.replace(")", "-") - name = name.replace("'", "") - name = name.replace(",", "-") - name = name.replace("/", "-") - name = name.replace("\\", "-") - name = name.replace(".", "-") - name = name.replace(" ", "") - - # Correct last `.` for filetype - name = name[:-4] + '.wav' - - # Add file id to filename - name = f"id_{id}" + "_" + name - - fp = os.path.join(basepath, name) - - # Check if file, if exists already, can be loaded by librosa - # If it cannot be loaded, possibly corrupted file. - # Delete and then re-download - if os.path.exists(fp): - try: - _ = librosa.load(path=fp) - except Exception: - # File is currupted, delete and re-download. - os.remove(fp) - - print(f"Pre-existing file {fp} was corrupt and was deleted, will be re-downloaded.") - - if not os.path.exists(fp): - print("Downloading file :", name) - - session = instantiate_session() - - data = None - attempts = 10 - - try: - while data is None: - - try: - # Get the sound data - data = get_resource_with_auto_refresh(session, download_url) - - except freesound.FreesoundException as e: - # Most probably a rate limit or a request limit - # Check if that was the case, and wait appropriate amount of time - # for retry - was_resource_limited = is_resource_limited(e) - - # If result of test False, it means that failure was due to some other reason. - # Log it, then break loop - if not was_resource_limited: - print(e) - break - - attempts -= 1 - - if attempts <= 0: - print(f"Failed to download file {fp} after 10 attempts, skipping file") - break - - if data is None: - print(f"Download attempts remaining = {attempts}") - - finally: - session.close() - - # Write the data to file - if data is not None: - print("Downloaded file :", name) - - with open(fp, 'wb') as f: - f.write(data) - - # If file size is less than 89, then this probably is a text format and not an actual audio file. 
- if os.path.getsize(fp) > 89: - print(f"File written : {fp}") - - else: - os.remove(fp) - print(f"File corrupted and has been deleted: {fp}") - - else: - print(f"File [{fp}] corrupted or faced some issue when downloading, skipped.") - - # Sleep to avoid hitting rate limits - time.sleep(5) - - else: - print(f"File [{fp}] already exists in dataset, skipping re-download.") - - -def get_songs_by_category( - client: freesound.FreesoundClient, - category: str, - data_dir: str, - max_num_samples=100, - page_size=100, - min_filesize_in_mb=0, - max_filesize_in_mb=10, - n_jobs=None, -): - """ - Download songs of a category with restrictions - - Args: - client: FreesoundAPI client - category: category to be downloaded - data_dir: directory of downloaded songs - max_num_samples: maximum number of samples of this category - page_size: samples per page returned - min_filesize_in_mb: minimum filesize of the song in MB - max_filesize_in_mb: maximum filesize of the song in MB - n_jobs: number of jobs for parallel processing - - Returns: - - """ - # quote string to force exact match - query = f'"{category}"' - print(f"Query : {query}") - - page_size = min(page_size, 150) - max_filesize = int(max_filesize_in_mb * (2 ** 20)) - - if min_filesize_in_mb == 0: - min_filesize_in_mb = 1 - else: - min_filesize_in_mb = int(min_filesize_in_mb * (2 ** 20)) - - if max_num_samples < 0: - max_num_samples = int(1e6) - - filters = [ - 'type:(wav OR flac)', - 'license:("Attribution" OR "Creative Commons 0")', - f'filesize:[{min_filesize_in_mb} TO {max_filesize}]', - ] - - fields = "id,name,download,license" - - client, pages = get_text_query_with_resource_limit_checks( - client, query=query, filters=filters, fields=fields, page_size=page_size - ) - - if pages is None: - print(f"Number of attempts exceeded limit, skipping query {query}") - return - - num_pages = pages.count - - # Check if returned empty result; if so, fallback to inexact category search - if num_pages == 0: - print(f"Found 0 samples of results for query '{query}'") - print(f"Trying less restricted query : {category}") - - client, pages = get_text_query_with_resource_limit_checks( - client, query=category, filters=filters, fields=fields, page_size=page_size - ) - - if pages is None: - print(f"Number of attempts exceeded limit, skipping query {query}") - return - - num_pages = pages.count - - print(f"Found {num_pages} samples of results for query '{query}'") - - category = category.replace(' ', '_') - basepath = os.path.join(data_dir, category) - - if not os.path.exists(basepath): - os.makedirs(basepath) - - sounds = [] - sample_count = 0 - - # Retrieve sound license information - with open(os.path.join(basepath, 'licenses.txt'), 'w') as f: - f.write("ID,LICENSE\n") - f.flush() - - while True: - for sound in pages: - if sample_count >= max_num_samples: - print( - f"Collected {sample_count} samples, which is >= max number of samples requested " - f"{max_num_samples}. 
Stopping for this category : {category}" - ) - break - - sounds.append(sound) - sample_count += 1 - - f.write(f"{sound.id},{sound.license}\n") - f.flush() - - if sample_count >= max_num_samples: - break - - try: - pages = pages.next_page() - except ValueError: - break - - if n_jobs is None: - n_jobs = max(1, len(sounds)) - - # Parallel download all songs - with Parallel(n_jobs=n_jobs, verbose=10) as parallel: - _ = parallel(delayed(download_song)(basepath, sound.id, sound.name, sound.download) for sound in sounds) - - -if __name__ == '__main__': - - parser = argparse.ArgumentParser(description="Freesound download script") - - parser.add_argument( - '--authorize', action='store_true', dest='auth', help='Flag to only perform OAuth2 authorization step' - ) - - parser.add_argument('-c', '--category', default='', type=str, help='Category required to download') - - parser.add_argument('-d', '--data_dir', default='', type=str, help='Destination folder to store data') - - parser.add_argument('--page_size', default=100, type=int, help='Number of sounds per page') - - parser.add_argument('--max_samples', default=100, type=int, help='Maximum number of sound samples') - - parser.add_argument('--min_filesize', default=0, type=int, help='Maximum filesize allowed (in MB)') - - parser.add_argument('--max_filesize', default=20, type=int, help='Maximum filesize allowed (in MB)') - - parser.set_defaults(auth=False) - - args = parser.parse_args() - - if args.auth: - """ Initialize oauth token to be used by all """ - oauth, token = initialize_oauth() - oauth.close() - - print("Authentication suceeded ! Token stored in `_token.pkl`") - exit(0) - - if not os.path.exists('_token.pkl'): - raise FileNotFoundError( - "Please authorize the application first using " "`python freesound_download.py --authorize`" - ) - if args.data_dir == '': - raise ValueError("Data dir must be passed as an argument using `--data_dir`") - - data_dir = args.data_dir - - page_size = args.page_size - max_num_samples = args.max_samples - min_filesize_in_mb = args.min_filesize - max_filesize_in_mb = args.max_filesize - - # Initialize and authenticate client - token = unpickle_object('_token') - freesound_client = freesound.FreesoundClient() - client = prepare_client(freesound_client, token) - - category = args.category - - if category == '': - raise ValueError("Cannot pass empty string as it will select all of FreeSound data !") - - print(f"Downloading category : {category}") - get_songs_by_category( - client, - category, - data_dir=data_dir, - max_num_samples=max_num_samples, - page_size=page_size, - min_filesize_in_mb=min_filesize_in_mb, - max_filesize_in_mb=max_filesize_in_mb, - n_jobs=30, - ) diff --git a/scripts/freesound_download_resample/freesound_requirements.txt b/scripts/freesound_download_resample/freesound_requirements.txt deleted file mode 100644 index 1b8d7ec3dd90..000000000000 --- a/scripts/freesound_download_resample/freesound_requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -git+https://github.com/MTG/freesound-python.git -joblib -librosa -requests -requests_oauthlib -sox diff --git a/scripts/freesound_download_resample/freesound_resample.py b/scripts/freesound_download_resample/freesound_resample.py deleted file mode 100644 index 9e48620ff7ac..000000000000 --- a/scripts/freesound_download_resample/freesound_resample.py +++ /dev/null @@ -1,119 +0,0 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import argparse -import glob -import os -import time - -import librosa -import sox -from joblib import Parallel, delayed - - -def resample_file(resampled_dir, filepath, ext, sample_rate): - """ - Resample an audio file to 16kHZ and transform to monochannel - Remove incompatible files. - - Args: - resampled_dir: Directory of transformed files. - filepath: Filepath of Audio - ext: File type e.g. "wav", "flac" - - Returns: - - """ - head, filename = os.path.split(filepath) - _, clsname = os.path.split(head) - - filename, _ = os.path.splitext(filename) - - new_dir = os.path.join(resampled_dir, clsname) - if not os.path.exists(new_dir): - os.makedirs(new_dir) - - new_path = os.path.join(new_dir, filename + f'.{ext}') - - # check if the resampled data exists. - if os.path.exists(new_path): - print(f"Resampled file {filepath} exists. Skip it.") - return None - - transform = sox.Transformer() - transform.set_output_format(file_type='wav') - transform.convert(samplerate=sample_rate, n_channels=1) - - try: - transform.build(filepath, new_path) - print(f"Finished converting file {filepath}.") - - return None - - except sox.core.SoxError as e: - - try: - # Check if the file is readable - librosa.load(path=filepath) - - # if it is, force input format and try again - transform.set_input_format(file_type=ext) - transform.build(filepath, new_path) - return None - - except Exception: - return filepath - - -def main(): - start = time.time() - parser = argparse.ArgumentParser(description='Freesound data resample') - parser.add_argument("--data_dir", required=True, default=None, type=str) - parser.add_argument('--resampled_dir', required=True, default=None, type=str) - parser.add_argument('--sample_rate', default=16000, type=int) - args = parser.parse_args() - - data_dir = args.data_dir - resampled_dir = args.resampled_dir - sample_rate = args.sample_rate - - wav_files = sorted(glob.glob(os.path.join(data_dir, '*/*.wav'))) - flac_files = sorted(glob.glob(os.path.join(data_dir, '*/*.flac'))) - - with Parallel(n_jobs=-1, verbose=10) as parallel: - wav_files_failed = parallel( - delayed(resample_file)(resampled_dir, filepath, ext='wav', sample_rate=sample_rate) - for filepath in wav_files - ) - - flac_files_failed = parallel( - delayed(resample_file)(resampled_dir, filepath, ext='flac', sample_rate=sample_rate) - for filepath in flac_files - ) - - with open('dataset_conversion_logs.txt', 'w') as f: - for file in wav_files_failed: - if file is not None: - f.write(f"{file}\n") - - for file in flac_files_failed: - if file is not None: - f.write(f"{file}\n") - - end = time.time() - print(f'Resample data in {data_dir} and save to {resampled_dir} takes {end-start} seconds.') - - -if __name__ == '__main__': - - main() diff --git a/scripts/tokenizers/train_tabular_data_tokenizer.py b/scripts/tokenizers/train_tabular_data_tokenizer.py deleted file mode 100644 index d160b2468f3a..000000000000 --- a/scripts/tokenizers/train_tabular_data_tokenizer.py +++ /dev/null @@ -1,40 +0,0 
@@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pickle - -import pandas as pd -from omegaconf import OmegaConf - -from nemo.collections.common.tokenizers.column_coder import ColumnCodes -from nemo.core.config import hydra_runner -from nemo.utils import logging - - -@hydra_runner(config_path="conf", config_name="tabular_data_tokenizer") -def main(cfg) -> None: - logging.info("\n\n************** Experiment configuration ***********") - logging.info(OmegaConf.to_yaml(cfg)) - table = pd.read_csv(cfg.table_csv_file) - example_arrays = {} - for col in cfg.table_structure: - col_name = col['name'] - example_arrays[col_name] = table[col_name].dropna().unique() - cc = ColumnCodes.get_column_codes(cfg.table_structure, example_arrays) - with open(cfg.tokenizer_file, 'wb') as handle: - pickle.dump(cc, handle) - - -if __name__ == '__main__': - main() diff --git a/tests/collections/common/pl_utils.py b/tests/collections/common/pl_utils.py index a2e9609c8492..ddf57122685c 100644 --- a/tests/collections/common/pl_utils.py +++ b/tests/collections/common/pl_utils.py @@ -27,8 +27,8 @@ # limitations under the License. import os -import pickle import sys +import tempfile from functools import partial from typing import Callable, Optional @@ -51,7 +51,7 @@ def setup_ddp(rank, world_size): - """ Setup ddp enviroment """ + """Setup ddp environment""" os.environ["MASTER_ADDR"] = 'localhost' os.environ['MASTER_PORT'] = '8088' @@ -72,29 +72,35 @@ def _class_test( check_batch: bool = True, atol: float = 1e-8, ): - """ Utility function doing the actual comparison between lightning class metric - and reference metric. - Args: - rank: rank of current process - worldsize: number of processes - preds: torch tensor with predictions - target: torch tensor with targets - metric_class: lightning metric class that should be tested - sk_metric: callable function that is used for comparison - dist_sync_on_step: bool, if true will synchronize metric state across - processes at each ``forward()`` - metric_args: dict with additional arguments used for class initialization - check_dist_sync_on_step: bool, if true will check if the metric is also correctly - calculated per batch per device (and not just at the end) - check_batch: bool, if true will check if the metric is also correctly - calculated across devices for each batch (and not just at the end) + """Utility function doing the actual comparison between lightning class metric + and reference metric. + Args: + rank: rank of current process + worldsize: number of processes + preds: torch tensor with predictions + target: torch tensor with targets + metric_class: lightning metric class that should be tested + sk_metric: callable function that is used for comparison + dist_sync_on_step: bool, if true will synchronize metric state across + processes at each ``forward()`` + metric_args: dict with additional arguments used for class initialization + check_dist_sync_on_step: bool, if true will check if the metric is also correctly + calculated per batch per device (and not just at the end) + check_batch: bool, if true will check if the metric is also correctly + calculated across devices for each batch (and not just at the end) """ # Instanciate lightning metric metric = metric_class(dist_sync_on_step=dist_sync_on_step, **metric_args) - # verify metrics work after being loaded from pickled state - pickled_metric = pickle.dumps(metric) - metric = pickle.loads(pickled_metric) + # verify metrics work after being loaded from saved state + # As per https://lightning.ai/docs/torchmetrics/stable/pages/overview.html#saving-and-loading-metrics, best to + # save and load state_dicts + if len(metric.state_dict()) > 0: + metric.persistent(True) + with tempfile.TemporaryFile() as fp: + torch.save(metric.state_dict(), fp) + fp.seek(0) + metric.load_state_dict(torch.load(fp, map_location="cpu")) for i in range(rank, NUM_BATCHES, worldsize): batch_result = metric(preds[i], target[i]) @@ -133,14 +138,14 @@ def _functional_test( metric_args: dict = {}, atol: float = 1e-8, ): - """ Utility function doing the actual comparison between lightning functional metric - and reference metric. - Args: - preds: torch tensor with predictions - target: torch tensor with targets - metric_functional: lightning metric functional that should be tested - sk_metric: callable function that is used for comparison - metric_args: dict with additional arguments used for class initialization + """Utility function doing the actual comparison between lightning functional metric + and reference metric. + Args: + preds: torch tensor with predictions + target: torch tensor with targets + metric_functional: lightning metric functional that should be tested + sk_metric: callable function that is used for comparison + metric_args: dict with additional arguments used for class initialization """ metric = partial(metric_functional, **metric_args) @@ -153,19 +158,19 @@ class MetricTester: - """ Class used for efficiently run alot of parametrized tests in ddp mode. - Makes sure that ddp is only setup once and that pool of processes are - used for all tests. - All tests should subclass from this and implement a new method called - `test_metric_name` - where the method `self.run_metric_test` is called inside. + """Class used to efficiently run a lot of parametrized tests in ddp mode. + Makes sure that ddp is only setup once and that pool of processes are + used for all tests. + All tests should subclass from this and implement a new method called + `test_metric_name` + where the method `self.run_metric_test` is called inside. """ atol = 1e-8 def setup_class(self): - """ Setup the metric class. This will spawn the pool of workers that are - used for metric testing and setup_ddp + """Setup the metric class.
This will spawn the pool of workers that are + used for metric testing and setup_ddp """ try: set_start_method('spawn') @@ -176,7 +181,7 @@ def setup_class(self): self.pool.starmap(setup_ddp, [(rank, self.poolSize) for rank in range(self.poolSize)]) def teardown_class(self): - """ Close pool of workers """ + """Close pool of workers""" self.pool.close() self.pool.join() @@ -188,14 +193,14 @@ def run_functional_metric_test( sk_metric: Callable, metric_args: dict = {}, ): - """ Main method that should be used for testing functions. Call this inside - testing method - Args: - preds: torch tensor with predictions - target: torch tensor with targets - metric_functional: lightning metric class that should be tested - sk_metric: callable function that is used for comparison - metric_args: dict with additional arguments used for class initialization + """Main method that should be used for testing functions. Call this inside + testing method + Args: + preds: torch tensor with predictions + target: torch tensor with targets + metric_functional: lightning metric class that should be tested + sk_metric: callable function that is used for comparison + metric_args: dict with additional arguments used for class initialization """ _functional_test( preds=preds, @@ -218,21 +223,21 @@ def run_class_metric_test( check_dist_sync_on_step: bool = True, check_batch: bool = True, ): - """ Main method that should be used for testing class. Call this inside testing - methods. - Args: - ddp: bool, if running in ddp mode or not - preds: torch tensor with predictions - target: torch tensor with targets - metric_class: lightning metric class that should be tested - sk_metric: callable function that is used for comparison - dist_sync_on_step: bool, if true will synchronize metric state across - processes at each ``forward()`` - metric_args: dict with additional arguments used for class initialization - check_dist_sync_on_step: bool, if true will check if the metric is also correctly - calculated per batch per device (and not just at the end) - check_batch: bool, if true will check if the metric is also correctly - calculated across devices for each batch (and not just at the end) + """Main method that should be used for testing class. Call this inside testing + methods. + Args: + ddp: bool, if running in ddp mode or not + preds: torch tensor with predictions + target: torch tensor with targets + metric_class: lightning metric class that should be tested + sk_metric: callable function that is used for comparison + dist_sync_on_step: bool, if true will synchronize metric state across + processes at each ``forward()`` + metric_args: dict with additional arguments used for class initialization + check_dist_sync_on_step: bool, if true will check if the metric is also correctly + calculated per batch per device (and not just at the end) + check_batch: bool, if true will check if the metric is also correctly + calculated across devices for each batch (and not just at the end) """ if ddp: if sys.platform == "win32": @@ -286,21 +291,21 @@ def _perplexity_class_test( check_batch: bool = True, atol: float = 1e-8, ): - """ Utility function doing the actual comparison between lightning class metric - and reference metric. - Args: - rank: rank of current process - worldsize: number of processes - probs: torch tensor with probabilities - logits: torch tensor with logits. The function checks ``probs`` and ``logits are mutually exclusive for - ``Perplexity`` metric. 
- dist_sync_on_step: bool, if true will synchronize metric state across - processes at each ``forward()`` - metric_args: dict with additional arguments used for class initialization - check_dist_sync_on_step: bool, if true will check if the metric is also correctly - calculated per batch per device (and not just at the end) - check_batch: bool, if true will check if the metric is also correctly - calculated across devices for each batch (and not just at the end) + """Utility function doing the actual comparison between lightning class metric + and reference metric. + Args: + rank: rank of current process + worldsize: number of processes + probs: torch tensor with probabilities + logits: torch tensor with logits. The function checks ``probs`` and ``logits`` are mutually exclusive for + ``Perplexity`` metric. + dist_sync_on_step: bool, if true will synchronize metric state across + processes at each ``forward()`` + metric_args: dict with additional arguments used for class initialization + check_dist_sync_on_step: bool, if true will check if the metric is also correctly + calculated per batch per device (and not just at the end) + check_batch: bool, if true will check if the metric is also correctly + calculated across devices for each batch (and not just at the end) """ # Instanciate lightning metric perplexity = Perplexity(dist_sync_on_step=dist_sync_on_step, **metric_args) @@ -309,9 +314,13 @@ def _perplexity_class_test( perplexity(probs, logits) return - # verify perplexity works after being loaded from pickled state - pickled_metric = pickle.dumps(perplexity) - perplexity = pickle.loads(pickled_metric) + # verify perplexity works after being loaded from saved state + if len(perplexity.state_dict()) > 0: + perplexity.persistent(True) + with tempfile.TemporaryFile() as fp: + torch.save(perplexity.state_dict(), fp) + fp.seek(0) + perplexity.load_state_dict(torch.load(fp, map_location="cpu")) for i in range(rank, NUM_BATCHES, worldsize): batch_result = perplexity(None if probs is None else probs[i], None if logits is None else logits[i]) @@ -361,20 +369,20 @@ def run_class_perplexity_test( check_dist_sync_on_step: bool = True, check_batch: bool = True, ): - """ Main method that should be used for testing class. Call this inside testing - methods. - Args: - ddp: bool, if running in ddp mode or not - probs: torch tensor with probabilities. - logits: torch tensor with logits. This test checks that probs and logits are mutually exclusive for - ``Perplexity`` metric. - dist_sync_on_step: bool, if true will synchronize metric state across - processes at each ``forward()`` - metric_args: dict with additional arguments used for class initialization - check_dist_sync_on_step: bool, if true will check if the metric is also correctly - calculated per batch per device (and not just at the end) - check_batch: bool, if true will check if the metric is also correctly - calculated across devices for each batch (and not just at the end) + """Main method that should be used for testing class. Call this inside testing + methods. + Args: + ddp: bool, if running in ddp mode or not + probs: torch tensor with probabilities. + logits: torch tensor with logits. This test checks that probs and logits are mutually exclusive for + ``Perplexity`` metric. + dist_sync_on_step: bool, if true will synchronize metric state across + processes at each ``forward()`` + metric_args: dict with additional arguments used for class initialization + check_dist_sync_on_step: bool, if true will check if the metric is also correctly + calculated per batch per device (and not just at the end) + check_batch: bool, if true will check if the metric is also correctly + calculated across devices for each batch (and not just at the end) """ if ddp: if sys.platform == "win32": @@ -447,28 +455,33 @@ def _loss_class_test( check_batch: bool = True, atol: float = 1e-8, ): - """ Utility function doing the actual comparison between lightning class metric - and reference metric. - Args: - rank: rank of current process - worldsize: number of processes - loss_sum_or_avg: a one dimensional float torch tensor with loss sums or means. - num_measurements: a one dimensional integer torch tensor with number of values on which sums or means from - ``loss_sum_or_avg`` were computed. - dist_sync_on_step: bool, if true will synchronize metric state across processes at each call of the - method :meth:`forward()` - take_avg_loss: dict with additional arguments used for class initialization - check_dist_sync_on_step: bool, if true will check if the metric is also correctly - calculated per batch per device (and not just at the end) - check_batch: bool, if true will check if the metric is also correctly - calculated across devices for each batch (and not just at the end) + """Utility function doing the actual comparison between lightning class metric + and reference metric. + Args: + rank: rank of current process + worldsize: number of processes + loss_sum_or_avg: a one dimensional float torch tensor with loss sums or means. + num_measurements: a one dimensional integer torch tensor with number of values on which sums or means from + ``loss_sum_or_avg`` were computed. + dist_sync_on_step: bool, if true will synchronize metric state across processes at each call of the + method :meth:`forward()` + take_avg_loss: bool flag used for class initialization + check_dist_sync_on_step: bool, if true will check if the metric is also correctly + calculated per batch per device (and not just at the end) + check_batch: bool, if true will check if the metric is also correctly + calculated across devices for each batch (and not just at the end) """ # Instantiate lightning metric loss_metric = GlobalAverageLossMetric(dist_sync_on_step=dist_sync_on_step, take_avg_loss=take_avg_loss) - # verify loss works after being loaded from pickled state - pickled_metric = pickle.dumps(loss_metric) - loss_metric = pickle.loads(pickled_metric) + # verify loss works after being loaded from saved state + if len(loss_metric.state_dict()) > 0: + loss_metric.persistent(True) + with tempfile.TemporaryFile() as fp: + torch.save(loss_metric.state_dict(), fp) + fp.seek(0) + loss_metric.load_state_dict(torch.load(fp, map_location="cpu")) + for i in range(rank, NUM_BATCHES, worldsize): batch_result = loss_metric(loss_sum_or_avg[i], num_measurements[i]) if loss_metric.dist_sync_on_step: diff --git a/tutorials/speaker_tasks/Speaker_Identification_Verification.ipynb b/tutorials/speaker_tasks/Speaker_Identification_Verification.ipynb index 3db99889d92e..49eaafa49331 100644 --- a/tutorials/speaker_tasks/Speaker_Identification_Verification.ipynb +++ b/tutorials/speaker_tasks/Speaker_Identification_Verification.ipynb @@ -1085,8 +1085,7 @@ " def autocast(enabled=None):\n", " yield\n", "import numpy as np\n", - "import json\n", - "import pickle as pkl" + "import json" ] }, { @@ -1144,8 +1143,8 @@ " prefix = manifest_file.split('/')[-1].rsplit('.', 1)[-2]\n", "\n", " name = os.path.join(embedding_dir, prefix)\n", - " embeddings_file = name + '_embeddings.pkl'\n", - " pkl.dump(out_embeddings, open(embeddings_file, 'wb'))\n", + " embeddings_file = name + '_embeddings.pt'\n", + " torch.save(out_embeddings, embeddings_file)\n", " print(\"Saved embedding files to {}\".format(embedding_dir))" ] },
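The test changes above persist torchmetrics state via `state_dict` instead of pickling the metric object. A standalone sketch of the same round trip on an illustrative `MeanMetric` (not part of the diff), showing the rewind before `torch.load` and the in-place `load_state_dict`:

    # Save/load a torchmetrics metric through its state_dict, as in the
    # updated tests. MeanMetric is just an example metric.
    import tempfile

    import torch
    from torchmetrics import MeanMetric

    metric = MeanMetric()
    metric.update(torch.tensor([1.0, 2.0, 3.0]))

    metric.persistent(True)  # include metric states in the state_dict
    with tempfile.TemporaryFile() as fp:
        torch.save(metric.state_dict(), fp)
        fp.seek(0)  # rewind so torch.load reads from the start of the file
        metric.load_state_dict(torch.load(fp, map_location="cpu"))

    assert float(metric.compute()) == 2.0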
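Likewise, the tutorial cell above now caches speaker embeddings with `torch.save` on a plain dict of tensors. A minimal round-trip sketch; the keys and the 192-dim size (192 matches titanet_large, but any size works) are illustrative:

    # Mirror the notebook's torch.save embedding cache and read it back.
    import torch

    out_embeddings = {
        "session1_speaker_0": torch.randn(192),
        "session1_speaker_1": torch.randn(192),
    }

    embeddings_file = "manifest_embeddings.pt"
    torch.save(out_embeddings, embeddings_file)

    loaded = torch.load(embeddings_file, map_location="cpu")
    assert torch.equal(loaded["session1_speaker_0"], out_embeddings["session1_speaker_0"])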