|
| 1 | +#include "module.h" |
| 2 | + |
| 3 | +#include <ctranslate2/models/wavlm.h> |
| 4 | + |
| 5 | +#include "replica_pool.h" |
| 6 | + |
| 7 | +#include <iostream> |
| 8 | + |
| 9 | +namespace ctranslate2 { |
| 10 | + namespace python { |
| 11 | + |
| 12 | + class WavLMWrapper : public ReplicaPoolHelper<models::WavLM> { |
| 13 | + public: |
| 14 | + using ReplicaPoolHelper::ReplicaPoolHelper; |
| 15 | + |
| 16 | + StorageView encode(const StorageView& features, const bool to_cpu) { |
| 17 | + std::shared_lock lock(_mutex); |
| 18 | + assert_model_is_ready(); |
| 19 | + return _pool->encode(features, to_cpu).get(); |
| 20 | + } |
| 21 | + }; |
| 22 | + |
| 23 | + void register_wavlm(py::module& m) { |
| 24 | + py::class_<WavLMWrapper>( |
| 25 | + m, "WavLM", |
| 26 | + R"pbdoc( |
| 27 | + Implements the WavLM speech recognition model published by Microsoft. |
| 28 | + )pbdoc") |
| 29 | + |
| 30 | + .def(py::init<const std::string&, const std::string&, const std::variant<int, std::vector<int>>&, const StringOrMap&, size_t, size_t, long, bool, bool, py::object>(), |
| 31 | + py::arg("model_path"), |
| 32 | + py::arg("device")="cpu", |
| 33 | + py::kw_only(), |
| 34 | + py::arg("device_index")=0, |
| 35 | + py::arg("compute_type")="default", |
| 36 | + py::arg("inter_threads")=1, |
| 37 | + py::arg("intra_threads")=0, |
| 38 | + py::arg("max_queued_batches")=0, |
| 39 | + py::arg("flash_attention")=false, |
| 40 | + py::arg("tensor_parallel")=false, |
| 41 | + py::arg("files")=py::none(), |
| 42 | + R"pbdoc( |
| 43 | + Initializes a WavLM model from a converted model. |
| 44 | +
|
| 45 | + Arguments: |
| 46 | + model_path: Path to the CTranslate2 model directory. |
| 47 | + device: Device to use (possible values are: cpu, cuda, auto). |
| 48 | + device_index: Device IDs where to place this model on. |
| 49 | + compute_type: Model computation type or a dictionary mapping a device name |
| 50 | + to the computation type (possible values are: default, auto, int8, int8_float32, |
| 51 | + int8_float16, int8_bfloat16, int16, float16, bfloat16, float32). |
| 52 | + inter_threads: Number of workers to allow executing multiple batches in parallel. |
| 53 | + intra_threads: Number of OpenMP threads per worker (0 to use a default value). |
| 54 | + max_queued_batches: Maximum numbers of batches in the worker queue (-1 for unlimited, |
| 55 | + 0 for an automatic value). When the queue is full, future requests will block |
| 56 | + until a free slot is available. |
| 57 | + flash_attention: run model with flash attention 2 for self-attention layer |
| 58 | + tensor_parallel: run model with tensor parallel mode |
| 59 | + files: Load model files from the memory. This argument is a dictionary mapping |
| 60 | + file names to file contents as file-like or bytes objects. If this is set, |
| 61 | + :obj:`model_path` acts as an identifier for this model. |
| 62 | + )pbdoc") |
| 63 | + |
| 64 | + .def_property_readonly("device", &WavLMWrapper::device, |
| 65 | + "Device this model is running on.") |
| 66 | + .def_property_readonly("device_index", &WavLMWrapper::device_index, |
| 67 | + "List of device IDs where this model is running on.") |
| 68 | + .def_property_readonly("compute_type", &WavLMWrapper::compute_type, |
| 69 | + "Computation type used by the model.") |
| 70 | + .def_property_readonly("num_workers", &WavLMWrapper::num_replicas, |
| 71 | + "Number of model workers backing this instance.") |
| 72 | + .def_property_readonly("num_queued_batches", &WavLMWrapper::num_queued_batches, |
| 73 | + "Number of batches waiting to be processed.") |
| 74 | + .def_property_readonly("tensor_parallel", &WavLMWrapper::tensor_parallel, |
| 75 | + "Run model with tensor parallel mode.") |
| 76 | + .def_property_readonly("num_active_batches", &WavLMWrapper::num_active_batches, |
| 77 | + "Number of batches waiting to be processed or currently processed.") |
| 78 | + |
| 79 | + .def("encode", &WavLMWrapper::encode, |
| 80 | + py::arg("features"), |
| 81 | + py::arg("to_cpu")=false, |
| 82 | + py::call_guard<py::gil_scoped_release>(), |
| 83 | + R"pbdoc( |
| 84 | + Encodes the input features. |
| 85 | +
|
| 86 | + Arguments: |
| 87 | + features: hidden_states (up to v.4.3.1, https://github.com/OpenNMT/CTranslate2/blob/59c7dda738892df7a064aa360d0e45a4c3840b07/python/tests/test_transformers.py#L1028) or |
| 88 | + raw audio, as a float array with shape (followed by VAD) |
| 89 | + ``[batch_size, 409, 1024]`` or ``[batch_size, 1, 131200]`` |
| 90 | + to_cpu: Copy the encoder output to the CPU before returning the value. |
| 91 | +
|
| 92 | + Returns: |
| 93 | + The encoder output. |
| 94 | + )pbdoc") |
| 95 | + |
| 96 | + .def("unload_model", &WavLMWrapper::unload_model, |
| 97 | + py::arg("to_cpu")=false, |
| 98 | + py::call_guard<py::gil_scoped_release>(), |
| 99 | + R"pbdoc( |
| 100 | + Unloads the model attached to this wavlm but keep enough runtime context |
| 101 | + to quickly resume wavlm on the initial device. |
| 102 | +
|
| 103 | + Arguments: |
| 104 | + to_cpu: If ``True``, the model is moved to the CPU memory and not fully unloaded. |
| 105 | + )pbdoc") |
| 106 | + |
| 107 | + .def("load_model", &WavLMWrapper::load_model, |
| 108 | + py::arg("keep_cache")=false, |
| 109 | + py::call_guard<py::gil_scoped_release>(), |
| 110 | + R"pbdoc( |
| 111 | + Loads the model back to the initial device. |
| 112 | +
|
| 113 | + Arguments: |
| 114 | + keep_cache: If ``True``, the model cache in the CPU memory is not deleted if it exists. |
| 115 | + )pbdoc") |
| 116 | + |
| 117 | + .def_property_readonly("model_is_loaded", &WavLMWrapper::model_is_loaded, |
| 118 | + "Whether the model is loaded on the initial device and ready to be used.") |
| 119 | + ; |
| 120 | + } |
| 121 | + |
| 122 | + } |
| 123 | +} |
0 commit comments