Skip to content

Commit bccfd43

Browse files
committed
Add a new model: WavLM
1 parent fe7a80e commit bccfd43

File tree

19 files changed

+944
-11
lines changed

19 files changed

+944
-11
lines changed

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ set(SOURCES
129129
src/layers/common.cc
130130
src/layers/decoder.cc
131131
src/layers/transformer.cc
132+
src/layers/wavlm.cc
132133
src/layers/wav2vec2.cc
133134
src/layers/wav2vec2bert.cc
134135
src/layers/whisper.cc
@@ -139,6 +140,7 @@ set(SOURCES
139140
src/models/model_reader.cc
140141
src/models/sequence_to_sequence.cc
141142
src/models/transformer.cc
143+
src/models/wavlm.cc
142144
src/models/wav2vec2.cc
143145
src/models/wav2vec2bert.cc
144146
src/models/whisper.cc

include/ctranslate2/layers/attention.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,16 @@ namespace ctranslate2 {
6161
const StorageView* _relative_position_keys;
6262
const StorageView* _relative_asymmetric_position_keys;
6363
const StorageView* _relative_position_values;
64+
const StorageView* _gru_relative_position_const;
6465
dim_t _maximum_relative_position;
6566
dim_t _relative_left_max_position;
6667
dim_t _relative_right_max_position;
6768
const bool _merge_time_and_head_dims;
6869
const dim_t _cache_time_dim;
6970
std::unique_ptr<const LayerNorm> _q_norm; // Query normalization
7071
std::unique_ptr<const LayerNorm> _k_norm; // Key normalization
72+
protected:
73+
const std::unique_ptr<const Dense> _gru_relative_position_linear;
7174
};
7275
}
7376
}

include/ctranslate2/layers/wavlm.h

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
#pragma once

#include <optional>

#include "ctranslate2/layers/transformer.h"

namespace ctranslate2 {
  namespace layers {

    // One layer of the WavLM convolutional feature extractor.
    // The members suggest the forward pass is Conv1D -> LayerNorm -> GELU,
    // with transposes around the normalization — confirm in src/layers/wavlm.cc.
    class WavLMLayerNormConvLayer : public Layer {
    public:
      WavLMLayerNormConvLayer(const models::Model& model,
                              const std::string& scope,
                              dim_t stride,
                              dim_t padding);

      // Applies the layer to the input features.
      void operator()(const StorageView& input, StorageView& output) const;

      DataType output_type() const override {
        return _conv.output_type();
      }

      dim_t output_size() const override {
        return _conv.output_size();
      }

    private:
      dim_t _stride;
      dim_t _padding;
      const Conv1D _conv;
      const LayerNorm _output_norm;
      const ops::Transpose _transpose;
      const ops::GELU _gelu;
    };

    // Convolutional positional embedding layer (Conv1D + GELU with transposes),
    // applied before the Transformer encoder stack.
    class WavLMPosConvLayer : public Layer {
    public:
      WavLMPosConvLayer(const models::Model& model, const std::string& scope);

      // Applies the layer to the input features.
      void operator()(const StorageView& input, StorageView& output) const;

      DataType output_type() const override {
        return _conv.output_type();
      }

      dim_t output_size() const override {
        return _conv.output_size();
      }

    private:
      const Conv1D _conv;
      const ops::Transpose _transpose;
      const ops::GELU _gelu;
    };

    // WavLM encoder: an optional convolutional feature extractor and feature
    // projection, an optional positional convolution, a stack of Transformer
    // encoder layers, and an optional LM head producing logits.
    class WavLMEncoder : public Layer {
    public:
      WavLMEncoder(const models::Model& model, const std::string& scope);

      // Encodes the input features (raw audio features or already extracted
      // hidden states, see is_encoded()).
      void operator()(const StorageView& features, StorageView& output);

      DataType output_type() const override {
        // When an LM head is present the encoder outputs logits.
        return _lm_head ? _lm_head->output_type() : _output_norm.output_type();
      }

      dim_t output_size() const override {
        return _lm_head ? _lm_head->output_size() : _output_norm.output_size();
      }

      dim_t input_size() const {
        return 1024;
      }

      bool is_encoded(const StorageView& features) const {
        // Input features shape: [batch_size, input_size, input_time]
        // Encoder output shape: [batch_size, input_time // 2, output_size]
        //
        // input_time is variable so we check that dimension 1 is different than its original value.

        return (features.rank() == 3
                && features.dim(2) == output_size()
                && features.dim(1) != input_size());
      }

      // NOTE(review): public non-owning pointer used as a model-capability
      // flag; consider making it private behind an accessor.
      const StorageView* _upgraded_model;

    private:
      // Non-owning; presumably its presence toggles the logits path — confirm
      // in src/layers/wavlm.cc.
      const StorageView* _return_logits;
      std::optional<WavLMLayerNormConvLayer> _feat_layer0;
      std::optional<std::vector<std::unique_ptr<const WavLMLayerNormConvLayer>>> _feat_layers;
      std::optional<LayerNorm> _fp_norm;
      std::optional<Dense> _fp_ff;
      std::optional<WavLMPosConvLayer> _pos_conv_embed;
      const ops::Transpose _transpose;
      const ops::GELU _gelu;
      const dim_t _num_heads;
      const std::vector<std::unique_ptr<const TransformerEncoderLayer>> _layers;
      const LayerNorm _output_norm;
      std::optional<Dense> _lm_head;
    };

  }
}

include/ctranslate2/models/wavlm.h

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
#pragma once

#include "ctranslate2/layers/wavlm.h"
#include "ctranslate2/models/model.h"
#include "ctranslate2/replica_pool.h"

namespace ctranslate2 {
  namespace models {

    struct WavLMOptions {
      // NOTE(review): these generation options appear copied from the Whisper
      // model; only encode() is exposed below — confirm which fields are used.

      // Maximum generation length.
      size_t max_length = 448;

      // Randomly sample from the top K candidates (set 0 to sample from the full distribution).
      size_t sampling_topk = 1;

      // Maximum index of the first predicted timestamp.
      size_t max_initial_timestamp_index = 50;

      // Suppress blank outputs at the beginning of the sampling.
      bool suppress_blank = true;

      // List of token IDs to suppress.
      // -1 will suppress a default set of symbols as defined in the model config.json file.
      std::vector<int> suppress_tokens = {-1};
    };


    // Model wrapper loading the converted WavLM weights and vocabulary.
    class WavLMModel : public Model {
    public:
      const Vocabulary& get_vocabulary() const;
      size_t current_spec_revision() const override;
      bool is_quantizable(const std::string& variable_name) const override;
      bool is_linear_weight(const std::string& variable_name) const override;
      std::unique_ptr<Model> clone() const override;

      bool use_global_int16_scale() const override {
        return false;
      }

    protected:
      void initialize(ModelReader& model_reader) override;

    private:
      std::shared_ptr<const Vocabulary> _vocabulary;
    };

    // One model replica owning its encoder instance.
    class WavLMReplica : public ModelReplica {
    public:
      static std::unique_ptr<WavLMReplica> create_from_model(const Model& model);

      // NOTE(review): single-argument constructor; consider marking explicit.
      WavLMReplica(const std::shared_ptr<const WavLMModel>& model);

      // Encodes the input features and optionally copies the result to the CPU.
      StorageView encode(StorageView features, const bool to_cpu);

    private:
      const std::shared_ptr<const WavLMModel> _model;
      const std::unique_ptr<layers::WavLMEncoder> _encoder;

      // Runs the encoder unless the features are already encoded.
      StorageView maybe_encode(StorageView features);
    };

    // Pool of WavLM replicas exposing an asynchronous encode API.
    class WavLM : public ReplicaPool<WavLMReplica> {
    public:
      using ReplicaPool::ReplicaPool;

      // Asynchronously encodes the input features.
      std::future<StorageView> encode(const StorageView& features, const bool to_cpu);
    };

  }
}

include/ctranslate2/storage_view.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ namespace ctranslate2 {
230230
template <typename T>
231231
StorageView& fill(T value);
232232
StorageView& zero();
233+
StorageView& one();
233234

234235
StorageView& copy_from(const StorageView& other, bool synchronous = false);
235236

python/cpp/module.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ PYBIND11_MODULE(_ext, m)
8686
ctranslate2::python::register_generator(m);
8787
ctranslate2::python::register_encoder(m);
8888
ctranslate2::python::register_whisper(m);
89+
ctranslate2::python::register_wavlm(m);
8990
ctranslate2::python::register_wav2vec2(m);
9091
ctranslate2::python::register_wav2vec2bert(m);
9192
ctranslate2::python::register_mpi(m);

python/cpp/module.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ namespace ctranslate2 {
1717
void register_translation_stats(py::module& m);
1818
void register_translator(py::module& m);
1919
void register_whisper(py::module& m);
20+
void register_wavlm(py::module& m);
2021
void register_wav2vec2(py::module& m);
2122
void register_wav2vec2bert(py::module& m);
2223
void register_mpi(py::module& m);

python/cpp/wavlm.cc

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
#include "module.h"

#include <ctranslate2/models/wavlm.h>

#include "replica_pool.h"

namespace ctranslate2 {
  namespace python {

    // Python-facing wrapper exposing the WavLM encoder on top of the shared
    // replica pool helper.
    class WavLMWrapper : public ReplicaPoolHelper<models::WavLM> {
    public:
      using ReplicaPoolHelper::ReplicaPoolHelper;

      // Runs the encoder and blocks until the result is available.
      StorageView encode(const StorageView& features, const bool to_cpu) {
        std::shared_lock lock(_mutex);
        assert_model_is_ready();
        return _pool->encode(features, to_cpu).get();
      }
    };

    // Registers the ctranslate2.WavLM Python class and its methods.
    void register_wavlm(py::module& m) {
      py::class_<WavLMWrapper>(
        m, "WavLM",
        R"pbdoc(
            Implements the WavLM speech recognition model published by Microsoft.
        )pbdoc")

        .def(py::init<const std::string&, const std::string&, const std::variant<int, std::vector<int>>&, const StringOrMap&, size_t, size_t, long, bool, bool, py::object>(),
             py::arg("model_path"),
             py::arg("device")="cpu",
             py::kw_only(),
             py::arg("device_index")=0,
             py::arg("compute_type")="default",
             py::arg("inter_threads")=1,
             py::arg("intra_threads")=0,
             py::arg("max_queued_batches")=0,
             py::arg("flash_attention")=false,
             py::arg("tensor_parallel")=false,
             py::arg("files")=py::none(),
             R"pbdoc(
                 Initializes a WavLM model from a converted model.

                 Arguments:
                   model_path: Path to the CTranslate2 model directory.
                   device: Device to use (possible values are: cpu, cuda, auto).
                   device_index: Device IDs where to place this model on.
                   compute_type: Model computation type or a dictionary mapping a device name
                     to the computation type (possible values are: default, auto, int8, int8_float32,
                     int8_float16, int8_bfloat16, int16, float16, bfloat16, float32).
                   inter_threads: Number of workers to allow executing multiple batches in parallel.
                   intra_threads: Number of OpenMP threads per worker (0 to use a default value).
                   max_queued_batches: Maximum numbers of batches in the worker queue (-1 for unlimited,
                     0 for an automatic value). When the queue is full, future requests will block
                     until a free slot is available.
                   flash_attention: run model with flash attention 2 for self-attention layer
                   tensor_parallel: run model with tensor parallel mode
                   files: Load model files from the memory. This argument is a dictionary mapping
                     file names to file contents as file-like or bytes objects. If this is set,
                     :obj:`model_path` acts as an identifier for this model.
             )pbdoc")

        .def_property_readonly("device", &WavLMWrapper::device,
                               "Device this model is running on.")
        .def_property_readonly("device_index", &WavLMWrapper::device_index,
                               "List of device IDs where this model is running on.")
        .def_property_readonly("compute_type", &WavLMWrapper::compute_type,
                               "Computation type used by the model.")
        .def_property_readonly("num_workers", &WavLMWrapper::num_replicas,
                               "Number of model workers backing this instance.")
        .def_property_readonly("num_queued_batches", &WavLMWrapper::num_queued_batches,
                               "Number of batches waiting to be processed.")
        .def_property_readonly("tensor_parallel", &WavLMWrapper::tensor_parallel,
                               "Run model with tensor parallel mode.")
        .def_property_readonly("num_active_batches", &WavLMWrapper::num_active_batches,
                               "Number of batches waiting to be processed or currently processed.")

        .def("encode", &WavLMWrapper::encode,
             py::arg("features"),
             py::arg("to_cpu")=false,
             py::call_guard<py::gil_scoped_release>(),
             R"pbdoc(
                 Encodes the input features.

                 Arguments:
                   features: hidden_states (up to v.4.3.1, https://github.com/OpenNMT/CTranslate2/blob/59c7dda738892df7a064aa360d0e45a4c3840b07/python/tests/test_transformers.py#L1028) or
                     raw audio, as a float array with shape (followed by VAD)
                     ``[batch_size, 409, 1024]`` or ``[batch_size, 1, 131200]``
                   to_cpu: Copy the encoder output to the CPU before returning the value.

                 Returns:
                   The encoder output.
             )pbdoc")

        .def("unload_model", &WavLMWrapper::unload_model,
             py::arg("to_cpu")=false,
             py::call_guard<py::gil_scoped_release>(),
             R"pbdoc(
                 Unloads the model attached to this wavlm but keep enough runtime context
                 to quickly resume wavlm on the initial device.

                 Arguments:
                   to_cpu: If ``True``, the model is moved to the CPU memory and not fully unloaded.
             )pbdoc")

        .def("load_model", &WavLMWrapper::load_model,
             py::arg("keep_cache")=false,
             py::call_guard<py::gil_scoped_release>(),
             R"pbdoc(
                 Loads the model back to the initial device.

                 Arguments:
                   keep_cache: If ``True``, the model cache in the CPU memory is not deleted if it exists.
             )pbdoc")

        .def_property_readonly("model_is_loaded", &WavLMWrapper::model_is_loaded,
                               "Whether the model is loaded on the initial device and ready to be used.")
        ;
    }

  }
}

0 commit comments

Comments
 (0)