Skip to content

Commit 64b4d66

Browse files
Add DNT and custom translation support in NMT client (#108)
* changes made for custom translation, dnt phrases * formatted * minor fixes done * changes made * minor fix done * changes done * changes made * changes made for dnt and custom translation * .gitmodules changed --------- Co-authored-by: Manisha Johnson <manishaj@nvidia.com>
1 parent 6f7bd6b commit 64b4d66

4 files changed

Lines changed: 52 additions & 4 deletions

File tree

.gitmodules

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
[submodule "common"]
22
path = common
33
url = https://github.com/nvidia-riva/common.git
4-
branch = main
4+
branch = release/2.18.0

common

riva/client/nmt.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@ def streaming_s2t_request_generator(
2222
for chunk in audio_chunks:
2323
yield riva_nmt.StreamingTranslateSpeechToTextRequest(audio_content=chunk)
2424

25+
def add_dnt_phrases_dict(req, dnt_phrases_dict):
    """Attach do-not-translate phrases and custom translations to a request.

    Every ``key -> value`` entry of ``dnt_phrases_dict`` is encoded as
    ``"key##value"``. The encoded entries are joined with commas and the
    resulting single string is appended to ``req.dnt_phrases``.

    Args:
        req: Request object exposing an appendable ``dnt_phrases`` field;
            it is modified in place.
        dnt_phrases_dict: Mapping of phrase to replacement, or ``None``.
            ``None`` and an empty mapping leave ``req`` untouched.
    """
    if dnt_phrases_dict is None:
        return

    encoded_entries = [
        f"{phrase}##{replacement}" for phrase, replacement in dnt_phrases_dict.items()
    ]
    if not encoded_entries:
        return

    # All entries travel as one comma-separated element, not one per phrase.
    req.dnt_phrases.append(",".join(encoded_entries))
31+
2532
class NeuralMachineTranslationClient:
2633
"""
2734
A class for translating text to text. Provides :meth:`translate` which returns translated text
@@ -137,6 +144,7 @@ def translate(
137144
source_language: str,
138145
target_language: str,
139146
future: bool = False,
147+
dnt_phrases_dict: Optional[dict] = None,
140148
) -> Union[riva_nmt.TranslateTextResponse, _MultiThreadedRendezvous]:
141149
"""
142150
Translate input list of input text :param:`text` using model :param:`model` from :param:`source_language` into :param:`target_language`
@@ -158,7 +166,7 @@ def translate(
158166
source_language=source_language,
159167
target_language=target_language
160168
)
161-
169+
add_dnt_phrases_dict(req, dnt_phrases_dict)
162170
func = self.stub.TranslateText.future if future else self.stub.TranslateText
163171
return func(req, metadata=self.auth.get_auth_metadata())
164172

scripts/nmt/nmt.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,35 @@
3838
from riva.client.argparse_utils import add_connection_argparse_parameters
3939

4040

41+
def read_dnt_phrases_file(file_path):
    """Load do-not-translate phrases and custom translations from a text file.

    Each non-blank line is either ``phrase##translation`` or a bare
    ``phrase``. A line is split on its first ``"##"``; both sides are
    stripped of surrounding whitespace. Lines whose phrase part is empty
    are ignored, and a phrase that appears more than once keeps the value
    from its last occurrence.

    Args:
        file_path: Path of the phrases file. A falsy value (``None`` or
            ``""``) yields an empty dictionary without touching the disk.

    Returns:
        Dictionary mapping each phrase to its translation; the translation
        is ``""`` for lines that contain no ``"##"``.

    Raises:
        RuntimeError: If the file cannot be opened or read. The underlying
            ``OSError`` is attached as ``__cause__``.
    """
    dnt_phrases_dict = {}
    if not file_path:
        return dnt_phrases_dict

    try:
        # Decode explicitly as UTF-8: phrase files may contain non-ASCII text
        # and the platform default encoding is not reliable for that.
        with open(file_path, "r", encoding="utf-8") as infile:
            for raw_line in infile:
                line = raw_line.strip()
                if not line:
                    continue

                # partition() splits on the first "##"; when the separator is
                # absent the whole line is the key and the value is "".
                key, _, value = line.partition("##")
                key = key.strip()
                if key:
                    dnt_phrases_dict[key] = value.strip()
    except OSError as err:
        raise RuntimeError(f"Could not open file {file_path}") from err

    return dnt_phrases_dict
69+
4170
def parse_args() -> argparse.Namespace:
4271
parser = argparse.ArgumentParser(
4372
description="Neural machine translation by Riva AI Services",
@@ -48,6 +77,7 @@ def parse_args() -> argparse.Namespace:
4877
"--text", default="mir Das ist mir Wurs, bien ich ein berliner", type=str, help="Text to translate"
4978
)
5079
inputs.add_argument("--text-file", type=str, help="Path to file for translation")
80+
parser.add_argument("--dnt-phrases-file", type=str, help="Path to file which contains dnt phrases and custom translations")
5181
parser.add_argument("--model-name", default="", type=str, help="model to use to translate")
5282
parser.add_argument(
5383
"--source-language-code", type=str, default="en-US", help="Source language code (according to BCP-47 standard)"
@@ -65,7 +95,17 @@ def parse_args() -> argparse.Namespace:
6595
def main() -> None:
6696
def request(inputs,args):
6797
try:
68-
response = nmt_client.translate(inputs, args.model_name, args.source_language_code, args.target_language_code)
98+
dnt_phrases_input = {}
99+
if args.dnt_phrases_file != None:
100+
dnt_phrases_input = read_dnt_phrases_file(args.dnt_phrases_file)
101+
response = nmt_client.translate(
102+
texts=inputs,
103+
model=args.model_name,
104+
source_language=args.source_language_code,
105+
target_language=args.target_language_code,
106+
future=False,
107+
dnt_phrases_dict=dnt_phrases_input,
108+
)
69109
for translation in response.translations:
70110
print(translation.text)
71111
except grpc.RpcError as e:

0 commit comments

Comments
 (0)