Skip to content

Commit ec46a40

Browse files
7heMech23rd
authored andcommitted
Added Cloudflare STT integration for voice message transcription. (#442)
* Added Cloudflare STT integration for voice message transcription. * fix Cloudflare STT resource leak, silent fallback, default toggle state, and add credential gating
1 parent 2e958f2 commit ec46a40

7 files changed

Lines changed: 438 additions & 14 deletions

File tree

TMessagesProj/build.gradle

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,10 @@ android {
9999
coreLibraryDesugaringEnabled true
100100
}
101101

102+
kotlinOptions {
103+
jvmTarget = JavaVersion.VERSION_17
104+
}
105+
102106
defaultConfig.versionCode = Integer.parseInt(Utils['getVersionCode']())
103107

104108
defaultConfig {
Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
package org.telegram.messenger;
2+
3+
import android.media.MediaCodec;
4+
import android.media.MediaExtractor;
5+
import android.media.MediaFormat;
6+
import android.media.MediaMuxer;
7+
import android.text.TextUtils;
8+
import android.util.Base64;
9+
10+
import com.google.gson.Gson;
11+
import com.google.gson.annotations.Expose;
12+
import com.google.gson.annotations.SerializedName;
13+
14+
import org.telegram.ui.Components.BulletinFactory;
15+
import org.telegram.ui.Components.AlertsCreator;
16+
import org.telegram.ui.LaunchActivity;
17+
import org.telegram.ui.ActionBar.BaseFragment;
18+
19+
import java.io.BufferedReader;
20+
import java.io.File;
21+
import java.io.IOException;
22+
import java.io.InputStreamReader;
23+
import java.io.OutputStream;
24+
import java.net.HttpURLConnection;
25+
import java.net.URL;
26+
import java.nio.ByteBuffer;
27+
import java.nio.file.Files;
28+
import java.util.List;
29+
import java.util.concurrent.ExecutorService;
30+
import java.util.concurrent.Executors;
31+
import java.util.function.BiConsumer;
32+
33+
public class CloudflareSTT {
34+
private static final Gson gson = new Gson();
35+
private static final ExecutorService executorService = Executors.newCachedThreadPool();
36+
37+
public static boolean isConfigured() {
38+
return SharedConfig.cfEnableStt && !TextUtils.isEmpty(SharedConfig.cfAccountID) && !TextUtils.isEmpty(SharedConfig.cfApiToken);
39+
}
40+
41+
public static void showErrorDialog(Exception e) {
42+
var fragment = LaunchActivity.getSafeLastFragment();
43+
var message = e.getLocalizedMessage();
44+
if (fragment == null || !BulletinFactory.canShowBulletin(fragment) || message == null) {
45+
return;
46+
}
47+
if (message.length() > 45) {
48+
AlertsCreator.showSimpleAlert(fragment, LocaleController.getString("ErrorOccurred", R.string.ErrorOccurred), e.getMessage());
49+
} else {
50+
BulletinFactory.of(fragment).createErrorBulletin(message).show();
51+
}
52+
}
53+
54+
private static void extractAudio(String inputFilePath, String outputFilePath) throws IOException {
55+
var extractor = new MediaExtractor();
56+
MediaMuxer muxer = null;
57+
try {
58+
extractor.setDataSource(inputFilePath);
59+
60+
MediaFormat audioFormat = null;
61+
int audioTrackIndex = -1;
62+
for (int i = 0; i < extractor.getTrackCount(); i++) {
63+
var format = extractor.getTrackFormat(i);
64+
var mime = format.getString(MediaFormat.KEY_MIME);
65+
if (mime != null && mime.startsWith("audio/")) {
66+
audioFormat = format;
67+
audioTrackIndex = i;
68+
break;
69+
}
70+
}
71+
72+
if (audioFormat == null) {
73+
throw new IOException("No audio track found in " + inputFilePath);
74+
}
75+
76+
muxer = new MediaMuxer(outputFilePath, MediaMuxer.OutputFormat.MUXER_OUTPUT_MPEG_4);
77+
var trackIndex = muxer.addTrack(audioFormat);
78+
muxer.start();
79+
80+
extractor.selectTrack(audioTrackIndex);
81+
82+
var bufferInfo = new MediaCodec.BufferInfo();
83+
var buffer = ByteBuffer.allocate(65536);
84+
85+
while (true) {
86+
var sampleSize = extractor.readSampleData(buffer, 0);
87+
if (sampleSize < 0) {
88+
break;
89+
}
90+
91+
bufferInfo.offset = 0;
92+
bufferInfo.size = sampleSize;
93+
bufferInfo.presentationTimeUs = extractor.getSampleTime();
94+
bufferInfo.flags = 0;
95+
96+
muxer.writeSampleData(trackIndex, buffer, bufferInfo);
97+
extractor.advance();
98+
}
99+
100+
muxer.stop();
101+
} finally {
102+
if (muxer != null) {
103+
muxer.release();
104+
}
105+
extractor.release();
106+
}
107+
}
108+
109+
public static void requestWorkersAi(String path, boolean video, BiConsumer<String, Exception> callback) {
110+
if (!isConfigured()) {
111+
callback.accept(null, new Exception(LocaleController.getString("CloudflareCredentialsNotSet", R.string.CloudflareCredentialsNotSet)));
112+
return;
113+
}
114+
executorService.submit(() -> {
115+
File audioPath;
116+
if (video) {
117+
var audioFile = new File(path + ".m4a");
118+
try {
119+
extractAudio(path, audioFile.getAbsolutePath());
120+
} catch (IOException e) {
121+
FileLog.e(e);
122+
callback.accept(null, e);
123+
return;
124+
}
125+
audioPath = audioFile;
126+
} else {
127+
audioPath = new File(path);
128+
}
129+
byte[] audio;
130+
try {
131+
audio = Files.readAllBytes(audioPath.toPath());
132+
} catch (IOException e) {
133+
callback.accept(null, e);
134+
return;
135+
}
136+
137+
var payload = new WhisperRequest();
138+
payload.audio = Base64.encodeToString(audio, Base64.NO_WRAP);
139+
payload.vadFilter = false;
140+
141+
try {
142+
URL url = new URL("https://api.cloudflare.com/client/v4/accounts/" + SharedConfig.cfAccountID + "/ai/run/@cf/openai/whisper-large-v3-turbo");
143+
HttpURLConnection conn = (HttpURLConnection) url.openConnection();
144+
conn.setRequestMethod("POST");
145+
conn.setRequestProperty("Authorization", "Bearer " + SharedConfig.cfApiToken);
146+
conn.setRequestProperty("Content-Type", "application/json");
147+
conn.setConnectTimeout(120000);
148+
conn.setReadTimeout(120000);
149+
conn.setDoOutput(true);
150+
151+
String jsonInputString = gson.toJson(payload);
152+
try (OutputStream os = conn.getOutputStream()) {
153+
byte[] input = jsonInputString.getBytes("utf-8");
154+
os.write(input, 0, input.length);
155+
}
156+
157+
int code = conn.getResponseCode();
158+
BufferedReader br = new BufferedReader(new InputStreamReader(
159+
code >= 200 && code < 300 ? conn.getInputStream() : conn.getErrorStream(), "utf-8"));
160+
StringBuilder response = new StringBuilder();
161+
String responseLine;
162+
while ((responseLine = br.readLine()) != null) {
163+
response.append(responseLine.trim());
164+
}
165+
166+
var whisperResponse = gson.fromJson(response.toString(), WhisperResponse.class);
167+
if (whisperResponse.success != null && whisperResponse.success && whisperResponse.result != null) {
168+
callback.accept(whisperResponse.result.text, null);
169+
} else {
170+
var errors = whisperResponse.errors;
171+
if (errors != null && !errors.isEmpty()) {
172+
callback.accept(null, new Exception(errors.size() == 1 ? errors.get(0).message : errors.toString()));
173+
} else {
174+
callback.accept(null, new Exception("Unknown error from Cloudflare: " + code));
175+
}
176+
}
177+
} catch (Exception e) {
178+
callback.accept(null, e);
179+
}
180+
});
181+
}
182+
183+
public static class WhisperRequest {
184+
@SerializedName("audio")
185+
@Expose
186+
public String audio;
187+
@SerializedName("vad_filter")
188+
@Expose
189+
public Boolean vadFilter;
190+
}
191+
192+
public static class Result {
193+
@SerializedName("text")
194+
@Expose
195+
public String text;
196+
}
197+
198+
public static class WhisperResponse {
199+
@SerializedName("result")
200+
@Expose
201+
public Result result;
202+
@SerializedName("success")
203+
@Expose
204+
public Boolean success;
205+
@SerializedName("errors")
206+
@Expose
207+
public List<Error> errors;
208+
}
209+
210+
public static class Error {
211+
@SerializedName("message")
212+
@Expose
213+
public String message;
214+
215+
@Override
216+
public String toString() {
217+
return message;
218+
}
219+
}
220+
}

TMessagesProj/src/main/java/org/telegram/messenger/SharedConfig.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ public class SharedConfig {
6161
PASSCODE_TYPE_PASSWORD = 1;
6262
private static int legacyDevicePerformanceClass = -1;
6363

64+
public static String cfAccountID = "";
65+
public static String cfApiToken = "";
66+
public static boolean cfEnableStt = false;
67+
6468
public static boolean loopStickers() {
6569
return LiteMode.isEnabled(LiteMode.FLAG_ANIMATED_STICKERS_CHAT);
6670
}
@@ -478,6 +482,9 @@ public static void saveConfig() {
478482
editor.putString("storageCacheDir", !TextUtils.isEmpty(storageCacheDir) ? storageCacheDir : "");
479483
editor.putBoolean("proxyRotationEnabled", proxyRotationEnabled);
480484
editor.putInt("proxyRotationTimeout", proxyRotationTimeout);
485+
editor.putString("cfAccountID", cfAccountID);
486+
editor.putString("cfApiToken", cfApiToken);
487+
editor.putBoolean("cfEnableStt", cfEnableStt);
481488

482489
if (pendingAppUpdate != null) {
483490
try {
@@ -545,6 +552,9 @@ public static void loadConfig() {
545552
storageCacheDir = preferences.getString("storageCacheDir", null);
546553
proxyRotationEnabled = preferences.getBoolean("proxyRotationEnabled", false);
547554
proxyRotationTimeout = preferences.getInt("proxyRotationTimeout", ProxyRotationController.DEFAULT_TIMEOUT_INDEX);
555+
cfAccountID = preferences.getString("cfAccountID", "");
556+
cfApiToken = preferences.getString("cfApiToken", "");
557+
cfEnableStt = preferences.getBoolean("cfEnableStt", false);
548558
String authKeyString = preferences.getString("pushAuthKey", null);
549559
if (!TextUtils.isEmpty(authKeyString)) {
550560
pushAuthKey = Base64.decode(authKeyString, Base64.DEFAULT);

TMessagesProj/src/main/java/org/telegram/ui/Cells/ChatMessageCell.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12490,6 +12490,8 @@ private void updateWaveform() {
1249012490
(
1249112491
UserConfig.getInstance(currentAccount).isPremium()
1249212492
||
12493+
org.telegram.messenger.CloudflareSTT.isConfigured()
12494+
||
1249312495
TranscribeButton.isFreeTranscribeInChat(currentMessageObject)
1249412496
||
1249512497
MessagesController.getInstance(currentAccount).transcribeAudioTrialWeeklyNumber > 0 &&

TMessagesProj/src/main/java/org/telegram/ui/Components/TranscribeButton.java

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.telegram.messenger.ChatObject;
3535
import org.telegram.messenger.DialogObject;
3636
import org.telegram.messenger.FileLog;
37+
import org.telegram.messenger.LocaleController;
3738
import org.telegram.messenger.MessageObject;
3839
import org.telegram.messenger.MessagesController;
3940
import org.telegram.messenger.MessagesStorage;
@@ -48,6 +49,7 @@
4849
import org.telegram.ui.Cells.ChatMessageCell;
4950
import org.telegram.ui.PremiumPreviewFragment;
5051

52+
import java.io.File;
5153
import java.util.ArrayList;
5254
import java.util.HashMap;
5355
import java.util.Objects;
@@ -116,7 +118,7 @@ public TranscribeButton(ChatMessageCell parent, SeekBarWaveform seekBar) {
116118

117119
this.isOpen = false;
118120
this.shouldBeOpen = false;
119-
premium = parent.getMessageObject() != null && UserConfig.getInstance(parent.getMessageObject().currentAccount).isPremium();
121+
premium = parent.getMessageObject() != null && (UserConfig.getInstance(parent.getMessageObject().currentAccount).isPremium() || org.telegram.messenger.CloudflareSTT.isConfigured());
120122

121123
loadingFloat = new AnimatedFloat(parent, 250, CubicBezierInterpolator.EASE_OUT_QUINT);
122124
animatedDrawLock = new AnimatedFloat(parent, 250, CubicBezierInterpolator.EASE_OUT_QUINT);
@@ -309,7 +311,7 @@ public void setBounds(int x, int y, int w, int h, int r) {
309311
this.radius = Math.min(Math.min(w, h) / 2, r);
310312
this.diameter = this.radius * 2;
311313
}
312-
314+
313315
public int width() {
314316
return this.bounds.width();
315317
}
@@ -678,6 +680,63 @@ private static void transcribePressed(MessageObject messageObject, boolean open,
678680
if (BuildVars.LOGS_ENABLED) {
679681
FileLog.d("sending Transcription request, msg_id=" + messageId + " dialog_id=" + dialogId);
680682
}
683+
if (org.telegram.messenger.CloudflareSTT.isConfigured()) {
684+
File path = null;
685+
String attachPath = messageObject.messageOwner.attachPath;
686+
if (!TextUtils.isEmpty(attachPath)) {
687+
File temp = new File(attachPath);
688+
if (temp.exists()) {
689+
path = temp;
690+
}
691+
}
692+
if (path == null) {
693+
path = org.telegram.messenger.FileLoader.getInstance(account).getPathToMessage(messageObject.messageOwner);
694+
if (path != null && !path.exists()) {
695+
path = null;
696+
}
697+
}
698+
if (path == null) {
699+
path = org.telegram.messenger.FileLoader.getInstance(account).getPathToAttach(messageObject.getDocument(), true);
700+
}
701+
if (path == null || !path.exists()) {
702+
NotificationCenter.getInstance(account).postNotificationName(NotificationCenter.voiceTranscriptionUpdate, messageObject);
703+
NotificationCenter.getInstance(account).postNotificationName(NotificationCenter.updateTranscriptionLock);
704+
NotificationCenter.getGlobalInstance().postNotificationName(NotificationCenter.showBulletin, Bulletin.TYPE_ERROR, LocaleController.getString(R.string.PleaseDownload));
705+
return;
706+
}
707+
long id = org.telegram.messenger.Utilities.random.nextLong();
708+
if (transcribeOperationsByDialogPosition == null) {
709+
transcribeOperationsByDialogPosition = new HashMap<>();
710+
}
711+
transcribeOperationsByDialogPosition.put(reqInfoHash(messageObject), messageObject);
712+
org.telegram.messenger.CloudflareSTT.requestWorkersAi(path.getAbsolutePath(), messageObject.isRoundVideo(), (text, exception) -> {
713+
if (text != null) {
714+
if (transcribeOperationsById == null) {
715+
transcribeOperationsById = new HashMap<>();
716+
}
717+
transcribeOperationsById.put(id, messageObject);
718+
messageObject.messageOwner.voiceTranscriptionId = id;
719+
720+
final long duration = SystemClock.elapsedRealtime() - start;
721+
TranscribeButton.openVideoTranscription(messageObject);
722+
messageObject.messageOwner.voiceTranscriptionOpen = true;
723+
messageObject.messageOwner.voiceTranscriptionFinal = true;
724+
725+
MessagesStorage.getInstance(account).updateMessageVoiceTranscription(dialogId, messageId, text, messageObject.messageOwner);
726+
AndroidUtilities.runOnUIThread(() -> finishTranscription(messageObject, id, text), Math.max(0, minDuration - duration));
727+
} else {
728+
AndroidUtilities.runOnUIThread(() -> {
729+
if (transcribeOperationsByDialogPosition != null) {
730+
transcribeOperationsByDialogPosition.remove(reqInfoHash(messageObject));
731+
}
732+
NotificationCenter.getInstance(account).postNotificationName(NotificationCenter.voiceTranscriptionUpdate, messageObject);
733+
NotificationCenter.getInstance(account).postNotificationName(NotificationCenter.updateTranscriptionLock);
734+
org.telegram.messenger.CloudflareSTT.showErrorDialog(exception);
735+
});
736+
}
737+
});
738+
return;
739+
}
681740
TLRPC.TL_messages_transcribeAudio req = new TLRPC.TL_messages_transcribeAudio();
682741
req.peer = peer;
683742
req.msg_id = messageId;
@@ -848,6 +907,9 @@ public static boolean showTranscribeLock(MessageObject messageObject) {
848907
if (messageObject == null || messageObject.messageOwner == null) {
849908
return false;
850909
}
910+
if (org.telegram.messenger.CloudflareSTT.isConfigured()) {
911+
return false;
912+
}
851913
if (isFreeTranscribeInChat(messageObject)) {
852914
return false;
853915
}

0 commit comments

Comments
 (0)