Skip to content

Commit 9727c61

Browse files
committed
Add weighting to TTS Services
1 parent d24fd69 commit 9727c61

8 files changed

Lines changed: 88 additions & 50 deletions

File tree

src/main.rs

Lines changed: 49 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::{
66

77
use anyhow::Ok;
88
use parking_lot::Mutex;
9+
use small_fixed_array::FixedArray;
910

1011
use poise::serenity_prelude as serenity;
1112
use serenity::small_fixed_array::FixedString;
@@ -56,7 +57,7 @@ async fn main_(start_time: std::time::SystemTime) -> Result<()> {
5657
let http = Arc::new(http_builder.build());
5758

5859
println!("Performing big startup join");
59-
let tts_service = || config.main.tts_services[0].clone();
60+
let tts_service = || config.tts_services[0].url.clone();
6061
let (
6162
ws_connections,
6263
webhooks,
@@ -73,7 +74,7 @@ async fn main_(start_time: std::time::SystemTime) -> Result<()> {
7374
shard_count,
7475
premium_user,
7576
) = tokio::try_join!(
76-
setup_ws_stream(&config.main),
77+
setup_ws_stream(&config.tts_services),
7778
get_webhooks(&http, config.webhooks),
7879
create_db_handler!(pool.clone(), "guilds", "guild_id"),
7980
create_db_handler!(pool.clone(), "userinfo", "user_id"),
@@ -111,43 +112,65 @@ async fn main_(start_time: std::time::SystemTime) -> Result<()> {
111112
tokio::spawn(analytics.clone().start());
112113

113114
let data = Arc::new(Data {
114-
pool,
115+
analytics,
116+
guilds_db,
117+
userinfo_db,
118+
nickname_db,
119+
user_voice_db,
120+
guild_voice_db,
121+
122+
entitlement_cache: mini_moka::sync::Cache::builder()
123+
.time_to_live(Duration::from_hours(1))
124+
.build(),
125+
startup_message,
126+
premium_avatar_url: FixedString::from_string_trunc(premium_user.face()),
115127
system_info: Mutex::new(sysinfo::System::new()),
116-
bot_list_tokens: Mutex::new(config.bot_list_tokens),
128+
start_time,
129+
reqwest,
130+
regex_cache: RegexCache::new()?,
131+
webhooks,
132+
pool,
117133

134+
service_weight_lookups: FixedArray::try_from(
135+
config
136+
.tts_services
137+
.iter()
138+
.enumerate()
139+
.flat_map(|(index, service)| {
140+
(1..=service.weight.get()).map(move |_| index.try_into().unwrap())
141+
})
142+
.collect::<Box<[_]>>(),
143+
)
144+
.unwrap(),
145+
tts_services: FixedArray::try_from(
146+
config
147+
.tts_services
148+
.into_iter()
149+
.map(|service| service.url)
150+
.collect::<Box<[_]>>(),
151+
)
152+
.unwrap(),
118153
ws_connections,
154+
voice_connections: Mutex::default(),
155+
156+
config: config.main,
157+
premium_config: config.premium,
119158
runners: OnceLock::new(), // Filled in later
159+
160+
website_info: Mutex::new(config.website_info),
161+
bot_list_tokens: Mutex::new(config.bot_list_tokens),
120162
fully_started: AtomicBool::new(false),
121-
voice_connections: parking_lot::Mutex::default(),
122163
update_startup_lock: tokio::sync::Mutex::new(()),
123-
entitlement_cache: mini_moka::sync::Cache::builder()
124-
.time_to_live(Duration::from_hours(1))
125-
.build(),
126164

127-
gtts_voices,
128165
espeak_voices,
129-
translation_languages,
130-
gcloud_voices: prepare_gcloud_voices(gcloud_voices),
166+
gtts_voices,
131167
polly_voices: polly_voices
132168
.into_iter()
133169
.map(|v| (v.id.clone(), v))
134170
.collect::<BTreeMap<_, _>>(),
171+
gcloud_voices: prepare_gcloud_voices(gcloud_voices),
135172

136-
config: config.main,
137-
premium_config: config.premium,
138-
website_info: Mutex::new(config.website_info),
139-
reqwest,
140-
premium_avatar_url: FixedString::from_string_trunc(premium_user.face()),
141-
analytics,
142-
webhooks,
143-
start_time,
144-
startup_message,
145-
regex_cache: RegexCache::new()?,
146-
guilds_db,
147-
userinfo_db,
148-
nickname_db,
149-
user_voice_db,
150-
guild_voice_db,
173+
translation_languages,
151174
});
152175

153176
let framework_options = poise::FrameworkOptions {

src/startup.rs

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use tokio_tungstenite::tungstenite::Message;
1212
use tts_core::{
1313
opt_ext::OptionTryUnwrap as _,
1414
structs::{
15-
Data, GoogleGender, GoogleVoice, MainConfig, Result, TTSMode, WebhookConfig,
15+
Data, GoogleGender, GoogleVoice, Result, TTSMode, TTSServiceConfig, WebhookConfig,
1616
WebhookConfigRaw,
1717
},
1818
voice,
@@ -121,14 +121,18 @@ async fn connect_ws_stream(mut url: reqwest::Url) -> Result<voice::RawWSStream>
121121
Ok(tokio_tungstenite::connect_async(url).await?.0)
122122
}
123123

124-
pub async fn setup_ws_stream(config: &MainConfig) -> Result<FixedArray<voice::LockedWSStream, u8>> {
125-
let tasks = config.tts_services.iter().map(async |url| {
126-
let stream = connect_ws_stream(url.clone()).await?;
127-
anyhow::Ok(voice::LockedWSStream::new(stream))
128-
});
124+
pub async fn setup_ws_stream(
125+
tts_services: &[TTSServiceConfig],
126+
) -> Result<FixedArray<voice::LockedWSStream, u8>> {
127+
let tasks = tts_services
128+
.iter()
129+
.map(async |TTSServiceConfig { url, .. }| {
130+
let stream = connect_ws_stream(url.clone()).await?;
131+
anyhow::Ok(voice::LockedWSStream::new(stream))
132+
});
129133

130134
let streams = futures::future::try_join_all(tasks).await?;
131-
println!("Connected to {} tts-services", config.tts_services.len());
135+
println!("Connected to {} tts-services", tts_services.len());
132136
Ok(streams.trunc_into())
133137
}
134138

@@ -176,13 +180,13 @@ pub fn start_ws_health_checks(data: &Arc<Data>) {
176180
if check_ws_healthy(&mut rng, &mut ws_tx).await {
177181
tracing::debug!("Health check passed for tts-service-{index}");
178182
} else {
179-
let url = data.config.tts_services[index].clone();
183+
let url = data.tts_services[index].clone();
180184
*ws_tx = reconnect_ws_stream(&url, index).await;
181185
}
182186
}
183187
};
184188

185-
for index in 0..data.ws_connections.len() {
189+
for index in 0..data.tts_services.len() {
186190
tokio::spawn(health_check(data.clone(), index));
187191
}
188192
}

tts_commands/src/other.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use poise::{
99

1010
use aformat::ToArrayString;
1111
use tts_core::{
12-
common::{build_invite_components, fetch_audio, prepare_url, select_tts_index},
12+
common::{build_invite_components, fetch_audio, prepare_url},
1313
constants::OPTION_SEPERATORS,
1414
opt_ext::OptionTryUnwrap,
1515
require_guild,
@@ -123,12 +123,12 @@ async fn tts_(ctx: Context<'_>, author: &serenity::User, message: &str) -> Comma
123123
let speaking_rate = data.speaking_rate(author.id, mode).await?;
124124

125125
let tts_service_index = match ctx.guild_id() {
126-
Some(guild_id) => select_tts_index(guild_id, data.config.tts_services.len()),
126+
Some(guild_id) => data.select_tts_index(guild_id),
127127
None => 0,
128128
};
129129

130130
let url = prepare_url(
131-
data.config.tts_services[tts_service_index].clone(),
131+
data.tts_services[tts_service_index].clone(),
132132
message,
133133
&voice,
134134
mode,

tts_commands/src/owner.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,13 +156,15 @@ pub async fn info_(ctx: Context<'_>) -> CommandResult {
156156
.await?;
157157

158158
let voice_debug = voice::debug_info(&data, guild_id);
159+
let tts_service_url = &data.tts_services[data.select_tts_index(guild_id)];
159160

160161
let embed = CreateEmbed::default()
161162
.title("TTS Bot Debug Info")
162163
.description(format!(
163164
"
164165
Shard ID: `{shard_id}`
165166
Voice Connection: `{voice_debug:?}`
167+
TTS Service URL: `{tts_service_url}`
166168
167169
Server Data: `{guild_row:?}`
168170
User Data: `{user_row:?}`

tts_core/src/common.rs

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,6 @@ pub(crate) fn timestamp_in_future(ts: serenity::Timestamp) -> bool {
1111
*ts > chrono::Utc::now()
1212
}
1313

14-
#[must_use]
15-
pub fn select_tts_index(guild_id: serenity::GuildId, tts_services: u8) -> u8 {
16-
(guild_id.get() % u64::from(tts_services))
17-
.try_into()
18-
.unwrap()
19-
}
20-
2114
/// Builds components for invite command and invite link in DMs
2215
///
2316
/// This has to allocate internally due to lifetime issues with trying CPS.

tts_core/src/structs.rs

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ macro_rules! into_static_display {
4848
pub struct Config {
4949
#[serde(rename = "Main")]
5050
pub main: MainConfig,
51+
#[serde(rename = "TTS-Services")]
52+
pub tts_services: FixedArray<TTSServiceConfig, u8>,
5153
#[serde(rename = "Webhook-Info")]
5254
pub webhooks: WebhookConfigRaw,
5355
#[serde(rename = "Website-Info")]
@@ -60,7 +62,6 @@ pub struct Config {
6062

6163
#[derive(serde::Deserialize)]
6264
pub struct MainConfig {
63-
pub tts_services: FixedArray<reqwest::Url, u8>,
6465
pub website_url: Option<reqwest::Url>,
6566
pub announcements_channel: ChannelId,
6667
pub main_server_invite: FixedString,
@@ -74,6 +75,12 @@ pub struct MainConfig {
7475
pub gtts_disabled: AtomicBool,
7576
}
7677

78+
#[derive(serde::Deserialize)]
79+
pub struct TTSServiceConfig {
80+
pub url: reqwest::Url,
81+
pub weight: NonZeroU8,
82+
}
83+
7784
#[derive(serde::Deserialize)]
7885
pub struct PostgresConfig {
7986
pub host: String,
@@ -204,11 +211,14 @@ pub struct Data {
204211
pub webhooks: WebhookConfig,
205212
pub pool: sqlx::PgPool,
206213

214+
pub service_weight_lookups: FixedArray<u8, u8>, // Maps weighted index to non-weighted.
215+
pub tts_services: FixedArray<reqwest::Url, u8>,
207216
pub ws_connections: FixedArray<voice::LockedWSStream, u8>,
208217
pub voice_connections: Mutex<HashMap<serenity::GuildId, voice::ConnectionEntry>>,
209-
pub runners: OnceLock<Arc<DashMap<serenity::ShardId, serenity::ShardRunnerMetadata>>>,
218+
210219
pub config: MainConfig,
211220
pub premium_config: Option<PremiumConfig>,
221+
pub runners: OnceLock<Arc<DashMap<serenity::ShardId, serenity::ShardRunnerMetadata>>>,
212222

213223
// Startup information
214224
pub website_info: Mutex<Option<WebsiteInfo>>,
@@ -231,6 +241,13 @@ impl std::fmt::Debug for Data {
231241
}
232242

233243
impl Data {
244+
#[must_use]
245+
pub fn select_tts_index(&self, guild_id: serenity::GuildId) -> u8 {
246+
self.service_weight_lookups[(guild_id.get() % u64::from(self.service_weight_lookups.len()))
247+
.try_into()
248+
.unwrap()]
249+
}
250+
234251
pub async fn speaking_rate(&self, user_id: UserId, mode: TTSMode) -> Result<f32> {
235252
let row = self.user_voice_db.get((user_id.into(), mode)).await?;
236253

tts_core/src/voice/mod.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ use tokio::{net::TcpStream, sync::Mutex as TMutex};
2121
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, tungstenite::Message as RawWSMessage};
2222

2323
use crate::{
24-
common::select_tts_index,
2524
structs::Data,
2625
voice::models::{WSConnectionInfo, WSMessageFrame},
2726
};
@@ -112,7 +111,7 @@ pub async fn start_connection(data: &Data, ctx: VCContext) -> StartConnectionRes
112111
//
113112
// It is important that `rx` is dropped AFTER the leave notifier is triggered, the `rx` drop
114113
// will any pending leave notifiers and therefore trigger them.
115-
let ws_tx = &data.ws_connections[select_tts_index(guild_id, data.ws_connections.len())];
114+
let ws_tx = &data.ws_connections[data.select_tts_index(guild_id)];
116115
let leave_notifier = ws_task(ctx, ws_tx, &mut rx, connect_tx).await;
117116

118117
data.voice_connections.lock().remove(&guild_id);

tts_migrations/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ pub async fn load_db_and_conf() -> Result<(sqlx::PgPool, Config)> {
240240
run(&mut config_toml, &pool).await?;
241241

242242
let config: Config = config_toml.try_into()?;
243-
if config.main.tts_services.is_empty() {
243+
if config.tts_services.is_empty() {
244244
return Err(anyhow::anyhow!("No TTS services are configured"));
245245
}
246246

0 commit comments

Comments
 (0)