Skip to content

Commit 45354f1

Browse files
vrdn-23alvarobartt
andauthored
Add support for DebertaV2 (#746)
Co-authored-by: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
1 parent 882d027 commit 45354f1

7 files changed

Lines changed: 1572 additions & 4 deletions

File tree

backends/candle/src/layers/linear.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ pub enum HiddenAct {
1313
Relu,
1414
Silu,
1515
Swiglu,
16+
Tanh,
1617
}
1718

1819
impl HiddenAct {
@@ -22,6 +23,7 @@ impl HiddenAct {
2223
Self::Relu => x.relu(),
2324
Self::Silu => x.silu(),
2425
Self::Swiglu => candle_nn::ops::swiglu(x),
26+
Self::Tanh => x.tanh(),
2527
}
2628
}
2729
}
@@ -91,6 +93,7 @@ impl Linear {
9193
HiddenAct::Relu => x.relu(),
9294
HiddenAct::Silu => x.silu(),
9395
HiddenAct::Swiglu => candle_nn::ops::swiglu(&x),
96+
HiddenAct::Tanh => x.tanh(),
9497
}
9598
} else {
9699
Ok(x)

backends/candle/src/lib.rs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@ use crate::compute_cap::{
2222
compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap,
2323
};
2424
use crate::models::{
25-
BertConfig, BertModel, Dense, DenseConfig, DenseLayer, DistilBertConfig, DistilBertModel,
26-
GTEConfig, GTEModel, Gemma3Config, Gemma3Model, JinaBertModel, JinaCodeBertModel, LLamaConfig,
27-
MPNetConfig, MPNetModel, MistralConfig, Model, ModernBertConfig, ModernBertModel,
28-
NomicBertModel, NomicConfig, Qwen2Config, Qwen3Config, Qwen3Model,
25+
BertConfig, BertModel, DebertaV2Config, DebertaV2Model, Dense, DenseConfig, DenseLayer,
26+
DistilBertConfig, DistilBertModel, GTEConfig, GTEModel, Gemma3Config, Gemma3Model,
27+
JinaBertModel, JinaCodeBertModel, LLamaConfig, MPNetConfig, MPNetModel, MistralConfig, Model,
28+
ModernBertConfig, ModernBertModel, NomicBertModel, NomicConfig, Qwen2Config, Qwen3Config,
29+
Qwen3Model,
2930
};
3031
#[cfg(feature = "cuda")]
3132
use crate::models::{
@@ -92,6 +93,7 @@ impl<'de> Deserialize<'de> for BertConfigWrapper {
9293
#[serde(tag = "model_type", rename_all = "kebab-case")]
9394
enum Config {
9495
Bert(BertConfigWrapper),
96+
DebertaV2(DebertaV2Config),
9597
Camembert(BertConfig),
9698
#[serde(rename(deserialize = "distilbert"))]
9799
DistilBert(DistilBertConfig),
@@ -265,6 +267,10 @@ impl CandleBackend {
265267
Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
266268
}
267269
},
270+
(Config::DebertaV2(config), Device::Cpu | Device::Metal(_)) => {
271+
tracing::info!("Starting DebertaV2 model on {:?}", device);
272+
Ok(Box::new(DebertaV2Model::load(vb, &config, model_type).s()?))
273+
}
268274
(
269275
Config::Camembert(config) | Config::Roberta(config) | Config::XlmRoberta(config),
270276
Device::Cpu | Device::Metal(_),
@@ -392,6 +398,11 @@ impl CandleBackend {
392398
}
393399
}
394400
#[cfg(feature = "cuda")]
401+
(Config::DebertaV2(config), Device::Cuda(_)) => {
402+
tracing::info!("Starting DebertaV2 model on {:?}", device);
403+
Ok(Box::new(DebertaV2Model::load(vb, &config, model_type).s()?))
404+
}
405+
#[cfg(feature = "cuda")]
395406
(Config::DistilBert(config), Device::Cuda(_)) => {
396407
if cfg!(feature = "flash-attn")
397408
&& dtype == DType::F16

0 commit comments

Comments
 (0)