|
1 | | -use std::fmt; |
| 1 | +use std::{fmt, i32}; |
2 | 2 |
|
3 | 3 | use serde::{Deserialize, Serialize}; |
4 | 4 | use sqlx::{prelude::Type, query, query_as, query_scalar, Error as SqlxError, SqliteExecutor}; |
@@ -93,19 +93,16 @@ impl Location<Id> { |
93 | 93 | where |
94 | 94 | E: SqliteExecutor<'e>, |
95 | 95 | { |
96 | | - let max_mode = if include_service_locations { |
97 | | - ServiceLocationMode::AlwaysOn as i32 |
98 | | - } else { |
99 | | - ServiceLocationMode::Disabled as i32 |
100 | | - }; |
| 96 | + let max_service_location_mode = |
| 97 | + Self::get_service_location_mode_filter(include_service_locations); |
101 | 98 | query_as!( |
102 | 99 | Self, |
103 | 100 | "SELECT id, instance_id, name, address, pubkey, endpoint, allowed_ips, dns, network_id,\ |
104 | 101 | route_all_traffic, keepalive_interval, \ |
105 | 102 | location_mfa_mode \"location_mfa_mode: LocationMfaMode\", service_location_mode \"service_location_mode: ServiceLocationMode\" \ |
106 | 103 | FROM location WHERE service_location_mode <= $1 \ |
107 | 104 | ORDER BY name ASC;", |
108 | | - max_mode |
| 105 | + max_service_location_mode |
109 | 106 | ) |
110 | 107 | .fetch_all(executor) |
111 | 108 | .await |
@@ -167,19 +164,16 @@ impl Location<Id> { |
167 | 164 | where |
168 | 165 | E: SqliteExecutor<'e>, |
169 | 166 | { |
170 | | - let max_mode = if include_service_locations { |
171 | | - ServiceLocationMode::AlwaysOn as i32 |
172 | | - } else { |
173 | | - ServiceLocationMode::Disabled as i32 |
174 | | - }; |
| 167 | + let max_service_location_mode = |
| 168 | + Self::get_service_location_mode_filter(include_service_locations); |
175 | 169 | query_as!( |
176 | 170 | Self, |
177 | 171 | "SELECT id \"id: _\", instance_id, name, address, pubkey, endpoint, allowed_ips, dns, \ |
178 | 172 | network_id, route_all_traffic, keepalive_interval, location_mfa_mode \"location_mfa_mode: LocationMfaMode\", service_location_mode \"service_location_mode: ServiceLocationMode\" \ |
179 | 173 | FROM location WHERE instance_id = $1 AND service_location_mode <= $2 \ |
180 | 174 | ORDER BY name ASC", |
181 | 175 | instance_id, |
182 | | - max_mode |
| 176 | + max_service_location_mode |
183 | 177 | ) |
184 | 178 | .fetch_all(executor) |
185 | 179 | .await |
@@ -236,6 +230,16 @@ impl Location<Id> { |
236 | 230 | LocationMfaMode::Internal | LocationMfaMode::External => true, |
237 | 231 | } |
238 | 232 | } |
| 233 | + |
| 234 | + /// Returns a filter value that can be used in SQL queries like `service_location_mode <= ?` when querying locations |
| 235 | + /// to exclude (<= 1) or include service locations (all service locations modes). |
| 236 | + fn get_service_location_mode_filter(include_service_locations: bool) -> i32 { |
| 237 | + if include_service_locations { |
| 238 | + i32::MAX |
| 239 | + } else { |
| 240 | + ServiceLocationMode::Disabled as i32 |
| 241 | + } |
| 242 | + } |
239 | 243 | } |
240 | 244 |
|
241 | 245 | impl Location<NoId> { |
|
0 commit comments