Skip to content

Commit 7eede7c

Browse files
vid277cdxker
authored andcommitted
feature: add audio transcription route
1 parent 9342ab8 commit 7eede7c

2 files changed

Lines changed: 93 additions & 0 deletions

File tree

server/src/handlers/message_handler.rs

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ use openai_dive::v1::{
3535
},
3636
image::{EditImageParametersBuilder, ImageData, ImageQuality, ImageSize},
3737
shared::{FileUpload, FileUploadBytes},
38+
audio::{AudioOutputFormat, AudioTranscriptionParametersBuilder},
3839
},
40+
models::WhisperModel,
3941
};
4042
use serde::{Deserialize, Serialize};
4143
use serde_json::json;
@@ -1412,3 +1414,90 @@ pub async fn edit_image(
14121414

14131415
Ok(HttpResponse::Ok().json(ImageEditResponse { image_urls }))
14141416
}
1417+
1418+
#[derive(Debug, Serialize, Deserialize, ToSchema)]
1419+
pub struct TranscribeAudioReqPayload {
1420+
/// The base64 encoded audio input of the user's input message.
1421+
pub audio_base64: String,
1422+
}
1423+
1424+
/// Transcribe Audio
1425+
///
1426+
/// Uses `whisper-1` to transcribe an audio file passed in as a base64 encoded string.
1427+
#[utoipa::path(
1428+
post,
1429+
path = "/message/transcribe_audio",
1430+
context_path = "/api",
1431+
tag = "Message",
1432+
request_body(content = TranscribeAudioReqPayload, description = "JSON request payload to transcribe an audio file", content_type = "application/json"),
1433+
responses(
1434+
(status = 200, description = "The transcribed text", body = String,
1435+
headers(
1436+
("TR-QueryID" = uuid::Uuid, description = "Query ID that is used for tracking analytics")
1437+
)
1438+
),
1439+
(status = 400, description = "Service error relating to transcribing the audio", body = ErrorResponseBody),
1440+
),
1441+
params(
1442+
("TR-Dataset" = uuid::Uuid, Header, description = "The dataset id or tracking_id to use for the request. We assume you intend to use an id if the value is a valid uuid."),
1443+
),
1444+
security(
1445+
("ApiKey" = ["admin"]),
1446+
)
1447+
)]
1448+
pub async fn transcribe_audio(
1449+
data: web::Json<TranscribeAudioReqPayload>,
1450+
dataset_org_plan_sub: DatasetAndOrgWithSubAndPlan,
1451+
_required_user: AdminOnly,
1452+
) -> Result<HttpResponse, ServiceError> {
1453+
let audio_bytes = base64::decode(data.audio_base64.clone()).map_err(|err| {
1454+
log::error!("Failed to decode base64 audio: {:?}", err);
1455+
ServiceError::BadRequest(format!("Error decoding audio base64: {:?}", err))
1456+
})?;
1457+
1458+
let dataset_config =
1459+
DatasetConfiguration::from_json(dataset_org_plan_sub.dataset.clone().server_configuration);
1460+
1461+
let llm_api_key = get_llm_api_key(&dataset_config);
1462+
let mut base_url = dataset_config.LLM_BASE_URL.clone();
1463+
1464+
if !base_url.contains("openai.com") {
1465+
base_url = "https://api.openai.com/v1".to_string();
1466+
}
1467+
1468+
let client = Client {
1469+
headers: None,
1470+
api_key: llm_api_key,
1471+
project: None,
1472+
http_client: reqwest::Client::new(),
1473+
base_url,
1474+
organization: None,
1475+
};
1476+
1477+
let file_upload = FileUpload::Bytes(FileUploadBytes {
1478+
filename: "audio.mp3".to_string(),
1479+
bytes: audio_bytes.into(),
1480+
});
1481+
1482+
let parameters = AudioTranscriptionParametersBuilder::default()
1483+
.file(file_upload)
1484+
.model(WhisperModel::Whisper1.to_string())
1485+
.response_format(AudioOutputFormat::Text)
1486+
.language("en".to_string())
1487+
.build()
1488+
.map_err(|err| {
1489+
log::error!("Failed to build transcription parameters: {:?}", err);
1490+
ServiceError::InternalServerError(format!("Failed to build transcription parameters: {:?}", err))
1491+
})?;
1492+
1493+
let transcribed_text = client
1494+
.audio()
1495+
.create_transcription(parameters)
1496+
.await
1497+
.map_err(|err| {
1498+
log::error!("Failed to transcribe audio: {:?}", err);
1499+
ServiceError::InternalServerError(format!("Failed to transcribe audio: {:?}", err))
1500+
})?;
1501+
1502+
Ok(HttpResponse::Ok().json(transcribed_text.replace("\n", "")))
1503+
}

server/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,6 +1171,10 @@ pub fn main() -> std::io::Result<()> {
11711171
web::resource("/message/edit_image")
11721172
.route(web::post().to(handlers::message_handler::edit_image))
11731173
)
1174+
.service(
1175+
web::resource("/message/transcribe_audio")
1176+
.route(web::post().to(handlers::message_handler::transcribe_audio))
1177+
)
11741178
.service(
11751179
web::resource("/message/{message_id}")
11761180
.route(web::get().to(handlers::message_handler::get_message_by_id))

0 commit comments

Comments
 (0)