From a571cf0e7b76a57ec44bc99705c175f86a992cda Mon Sep 17 00:00:00 2001 From: ZhengHongming888 Date: Tue, 4 Jun 2024 15:28:07 -0700 Subject: [PATCH 1/2] revision for 7b mistral gaudi support --- sentence_transformers/SentenceTransformer.py | 11 +++++++++-- sentence_transformers/models/Transformer.py | 18 ++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index dfcc297c8..74eefd6a4 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -363,6 +363,7 @@ def encode( convert_to_tensor: bool = False, device: str = None, normalize_embeddings: bool = False, + kwargs: Optional[Dict[str, Any]] = None, ) -> Union[List[Tensor], ndarray, Tensor]: """ Computes sentence embeddings. @@ -485,11 +486,17 @@ def encode( if self.device.type == "hpu": if "input_ids" in features: curr_tokenize_len = features["input_ids"].shape - additional_pad_len = 2 ** math.ceil(math.log2(curr_tokenize_len[1])) - curr_tokenize_len[1] + if curr_tokenize_len[1] > 4096: + additional_pad_len = math.ceil(curr_tokenize_len[1] / 128) * 128 - curr_tokenize_len[1] + + extra_features.update(kwargs["gaudi_kwargs"]) + else: + additional_pad_len = 2 ** math.ceil(math.log2(curr_tokenize_len[1])) - curr_tokenize_len[1] + features["input_ids"] = torch.cat( ( features["input_ids"], - torch.ones((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8), + torch.zeros((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8), ), -1, ) diff --git a/sentence_transformers/models/Transformer.py b/sentence_transformers/models/Transformer.py index d5a670869..84080a512 100644 --- a/sentence_transformers/models/Transformer.py +++ b/sentence_transformers/models/Transformer.py @@ -4,6 +4,7 @@ from torch import nn from transformers import AutoConfig, AutoModel, AutoTokenizer, MT5Config, T5Config +from sentence_transformers.util import get_device_name class Transformer(nn.Module): @@ -114,6 +115,23 @@ def forward(self, features): if "token_type_ids" in features: trans_features["token_type_ids"] = features["token_type_ids"] + device = get_device_name() + curr_tokenize_len = features["input_ids"].shape + if ( + device == "hpu" + and curr_tokenize_len[1] > 4096 + and "attn_softmax_bf16" in features + and "reuse_cache" in features + and "use_flash_attention" in features + and "flash_attention_recompute" in features + and "flash_attention_causal_mask" in features + ): + trans_features["attn_softmax_bf16"] = features["attn_softmax_bf16"] + trans_features["reuse_cache"] = features["reuse_cache"] + trans_features["use_flash_attention"] = features["use_flash_attention"] + trans_features["flash_attention_recompute"] = features["flash_attention_recompute"] + trans_features["flash_attention_causal_mask"] = features["flash_attention_causal_mask"] + output_states = self.auto_model(**trans_features, return_dict=False) output_tokens = output_states[0] From 560845a0ff6df0e081498866f6aea81da4cbae5e Mon Sep 17 00:00:00 2001 From: ZhengHongming888 Date: Tue, 4 Jun 2024 16:02:05 -0700 Subject: [PATCH 2/2] rename gaudi_kwargs by hpu_kwargs --- sentence_transformers/SentenceTransformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index 74eefd6a4..a6df287e1 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -489,7 +489,7 @@ def encode( if curr_tokenize_len[1] > 4096: additional_pad_len = math.ceil(curr_tokenize_len[1] / 128) * 128 - curr_tokenize_len[1] - extra_features.update(kwargs["gaudi_kwargs"]) + extra_features.update(kwargs["hpu_kwargs"]) else: additional_pad_len = 2 ** math.ceil(math.log2(curr_tokenize_len[1])) - curr_tokenize_len[1]