Skip to content

Commit 57a6e2a

Browse files
Fix force all traffic (#761)
* fix client policy condition * Fix macos * Fix apple.rs --------- Co-authored-by: Adam Ciarciński <adam@defguard.net>
1 parent 00fd54b commit 57a6e2a

File tree

3 files changed

+49
-25
lines changed

3 files changed

+49
-25
lines changed

src-tauri/src/apple.rs

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,20 @@ use objc2_network_extension::{
3030
NETunnelProviderManager, NETunnelProviderProtocol, NETunnelProviderSession, NEVPNStatus,
3131
};
3232
use serde::Deserialize;
33-
use sqlx::SqliteExecutor;
3433
use tauri::{AppHandle, Emitter, Manager};
3534
use tracing::Level;
3635

3736
use crate::{
3837
active_connections::find_connection,
3938
appstate::AppState,
4039
database::{
41-
models::{location::Location, tunnel::Tunnel, wireguard_keys::WireguardKeys, Id},
40+
models::{
41+
instance::{ClientTrafficPolicy, Instance},
42+
location::Location,
43+
tunnel::Tunnel,
44+
wireguard_keys::WireguardKeys,
45+
Id,
46+
},
4247
DB_POOL,
4348
},
4449
error::Error,
@@ -931,7 +936,7 @@ pub async fn sync_locations_and_tunnels(mtu: Option<u32>) -> Result<(), sqlx::Er
931936
let all_locations = Location::all(&*DB_POOL, false).await?;
932937
for location in &all_locations {
933938
// For syncing, set `preshred_key` to `None`.
934-
let Ok(tunnel_config) = location.tunnel_configurarion(&*DB_POOL, None, mtu).await else {
939+
let Ok(tunnel_config) = location.tunnel_configurarion(None, mtu).await else {
935940
error!(
936941
"Failed to convert location {} to tunnel configuration.",
937942
location.name
@@ -1019,17 +1024,13 @@ pub async fn sync_locations_and_tunnels(mtu: Option<u32>) -> Result<(), sqlx::Er
10191024

10201025
impl Location<Id> {
10211026
/// Build [`TunnelConfiguration`] from [`Location`].
1022-
pub(crate) async fn tunnel_configurarion<'e, E>(
1027+
pub(crate) async fn tunnel_configurarion(
10231028
&self,
1024-
executor: E,
10251029
preshared_key: Option<String>,
10261030
mtu: Option<u32>,
1027-
) -> Result<TunnelConfiguration, Error>
1028-
where
1029-
E: SqliteExecutor<'e>,
1030-
{
1031+
) -> Result<TunnelConfiguration, Error> {
10311032
debug!("Looking for WireGuard keys for location {self} instance");
1032-
let Some(keys) = WireguardKeys::find_by_instance_id(executor, self.instance_id).await?
1033+
let Some(keys) = WireguardKeys::find_by_instance_id(&*DB_POOL, self.instance_id).await?
10331034
else {
10341035
error!("No keys found for instance: {}", self.instance_id);
10351036
return Err(Error::InternalError(
@@ -1057,7 +1058,19 @@ impl Location<Id> {
10571058
}
10581059

10591060
debug!("Parsing location {self} allowed IPs: {}", self.allowed_ips);
1060-
let allowed_ips = if self.route_all_traffic {
1061+
let Some(instance) = Instance::find_by_id(&*DB_POOL, self.instance_id).await? else {
1062+
error!("Instance {} not found", self.instance_id);
1063+
return Err(Error::InternalError(format!(
1064+
"Instance {} not found",
1065+
self.instance_id
1066+
)));
1067+
};
1068+
let route_all_traffic = match instance.client_traffic_policy {
1069+
ClientTrafficPolicy::ForceAllTraffic => true,
1070+
ClientTrafficPolicy::DisableAllTraffic => false,
1071+
ClientTrafficPolicy::None => self.route_all_traffic,
1072+
};
1073+
let allowed_ips = if route_all_traffic {
10611074
debug!("Using all traffic routing for location {self}");
10621075
vec![DEFAULT_ROUTE_IPV4.into(), DEFAULT_ROUTE_IPV6.into()]
10631076
} else {

src-tauri/src/database/models/location.rs

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@ use sqlx::{prelude::Type, query, query_as, query_scalar, Error as SqlxError, Sql
1111
use super::wireguard_keys::WireguardKeys;
1212
use super::{Id, NoId};
1313
#[cfg(not(target_os = "macos"))]
14-
use crate::utils::{DEFAULT_ROUTE_IPV4, DEFAULT_ROUTE_IPV6};
14+
use crate::{
15+
database::DbPool,
16+
utils::{DEFAULT_ROUTE_IPV4, DEFAULT_ROUTE_IPV6},
17+
};
1518
use crate::{
1619
error::Error,
1720
proto::{
@@ -240,19 +243,17 @@ impl Location<Id> {
240243
}
241244

242245
#[cfg(not(target_os = "macos"))]
243-
pub(crate) async fn interface_configuration<'e, E>(
246+
pub(crate) async fn interface_configuration(
244247
&self,
245-
executor: E,
248+
pool: &DbPool,
246249
interface_name: String,
247250
preshared_key: Option<String>,
248251
mtu: Option<u32>,
249-
) -> Result<InterfaceConfiguration, Error>
250-
where
251-
E: SqliteExecutor<'e>,
252-
{
252+
) -> Result<InterfaceConfiguration, Error> {
253+
use crate::database::models::instance::{ClientTrafficPolicy, Instance};
254+
253255
debug!("Looking for WireGuard keys for location {self} instance");
254-
let Some(keys) = WireguardKeys::find_by_instance_id(executor, self.instance_id).await?
255-
else {
256+
let Some(keys) = WireguardKeys::find_by_instance_id(pool, self.instance_id).await? else {
256257
error!("No keys found for instance: {}", self.instance_id);
257258
return Err(Error::InternalError(
258259
"No keys found for instance".to_string(),
@@ -279,7 +280,19 @@ impl Location<Id> {
279280
}
280281

281282
debug!("Parsing location {self} allowed IPs: {}", self.allowed_ips);
282-
let allowed_ips = if self.route_all_traffic {
283+
let Some(instance) = Instance::find_by_id(pool, self.instance_id).await? else {
284+
error!("Instance {} not found", self.instance_id);
285+
return Err(Error::InternalError(format!(
286+
"Instance {} not found",
287+
self.instance_id
288+
)));
289+
};
290+
let route_all_traffic = match instance.client_traffic_policy {
291+
ClientTrafficPolicy::ForceAllTraffic => true,
292+
ClientTrafficPolicy::DisableAllTraffic => false,
293+
ClientTrafficPolicy::None => self.route_all_traffic,
294+
};
295+
let allowed_ips = if route_all_traffic {
283296
debug!("Using all traffic routing for location {self}");
284297
vec![DEFAULT_ROUTE_IPV4.into(), DEFAULT_ROUTE_IPV6.into()]
285298
} else {

src-tauri/src/utils.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,11 +131,9 @@ pub(crate) async fn setup_interface(
131131
_name: &str,
132132
preshared_key: Option<String>,
133133
mtu: Option<u32>,
134-
pool: &DbPool,
134+
_pool: &DbPool,
135135
) -> Result<String, Error> {
136-
let tunnel_config = location
137-
.tunnel_configurarion(pool, preshared_key, mtu)
138-
.await?;
136+
let tunnel_config = location.tunnel_configurarion(preshared_key, mtu).await?;
139137

140138
tunnel_config.save();
141139
tokio::time::sleep(TUNNEL_START_DELAY).await;

0 commit comments

Comments
 (0)