Skip to content

Commit 50e19b6

Browse files
authored
Fixes custom schema processing (#287)
The `sanitizeSchema` method originally added changed `"` to `""` and added double quotes around the entire schema so that it would support schema names with single and double quotes in them such as `C$+0m'`. However, this approach is not compatible with Migration #10, where we need to use the schema name as a column value.
1 parent cf5e4a9 commit 50e19b6

10 files changed

Lines changed: 363 additions & 120 deletions

File tree

transact-cli/src/test/java/dev/dbos/transact/cli/MigrateCommandTest.java

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import static org.junit.jupiter.api.Assertions.assertTrue;
66

77
import dev.dbos.transact.Constants;
8+
import dev.dbos.transact.database.SystemDatabase;
89
import dev.dbos.transact.migrations.MigrationManager;
910

1011
import java.io.PrintWriter;
@@ -17,6 +18,8 @@
1718
import org.junit.jupiter.api.BeforeEach;
1819
import org.junit.jupiter.api.Test;
1920
import org.junit.jupiter.api.Timeout;
21+
import org.junit.jupiter.params.ParameterizedTest;
22+
import org.junit.jupiter.params.provider.ValueSource;
2023
import picocli.CommandLine;
2124

2225
@Timeout(value = 2, unit = TimeUnit.MINUTES)
@@ -72,18 +75,32 @@ public void migrate_twice() throws Exception {
7275
assertTrue(checkTable(Constants.DB_SCHEMA, "workflow_status"));
7376
}
7477

75-
@Test
76-
public void migrate_custom_schema() throws Exception {
77-
78+
@ParameterizedTest
79+
@ValueSource(strings = {"invalid\"schema", "invalid'schema"})
80+
void testRunMigrations_fails_invalid_schema(String schema) throws Exception {
7881
assertFalse(checkConnection());
7982

80-
var schema = "C\"$+0m'";
83+
var cmd = new CommandLine(new DBOSCommand());
84+
var sw = new StringWriter();
85+
cmd.setOut(new PrintWriter(sw));
86+
87+
var exitCode =
88+
cmd.execute("migrate", "-D=" + db_url, "-U=" + db_user, "--schema", "%s".formatted(schema));
89+
assertEquals(1, exitCode);
90+
}
91+
92+
@ParameterizedTest
93+
@ValueSource(strings = {"F8nny_sCHem@-n@m3", "embedded\0null"})
94+
public void migrate_custom_schema(String schema) throws Exception {
95+
96+
assertFalse(checkConnection());
8197

8298
var cmd = new CommandLine(new DBOSCommand());
8399
var sw = new StringWriter();
84100
cmd.setOut(new PrintWriter(sw));
85101

86-
var exitCode = cmd.execute("migrate", "-D=" + db_url, "-U=" + db_user, "--schema=" + schema);
102+
var exitCode =
103+
cmd.execute("migrate", "-D=" + db_url, "-U=" + db_user, "--schema", "%s".formatted(schema));
87104
assertEquals(0, exitCode);
88105

89106
assertTrue(checkConnection());
@@ -105,7 +122,7 @@ static boolean checkTable(String schema, String table) throws SQLException {
105122
"SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_schema = ? AND table_name = ?)";
106123
try (var conn = DriverManager.getConnection(db_url, db_user, db_password);
107124
var stmt = conn.prepareStatement(sql)) {
108-
stmt.setString(1, schema);
125+
stmt.setString(1, SystemDatabase.sanitizeSchema(schema));
109126
stmt.setString(2, table);
110127
try (var rs = stmt.executeQuery()) {
111128
if (rs.next()) {

transact/src/main/java/dev/dbos/transact/database/NotificationsDAO.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ void send(
7474
// Insert notification
7575
final String sql =
7676
"""
77-
INSERT INTO %s.notifications (destination_uuid, topic, message) VALUES (?, ?, ?)
77+
INSERT INTO "%s".notifications (destination_uuid, topic, message) VALUES (?, ?, ?)
7878
"""
7979
.formatted(this.schema);
8080

@@ -162,7 +162,7 @@ Object recv(
162162
try (Connection conn = dataSource.getConnection()) {
163163
final String sql =
164164
"""
165-
SELECT topic FROM %s.notifications WHERE destination_uuid = ? AND topic = ?
165+
SELECT topic FROM "%s".notifications WHERE destination_uuid = ? AND topic = ?
166166
"""
167167
.formatted(this.schema);
168168

@@ -214,12 +214,12 @@ Object recv(
214214
"""
215215
WITH oldest_entry AS (
216216
SELECT destination_uuid, topic, message, created_at_epoch_ms
217-
FROM %1$s.notifications
217+
FROM "%1$s".notifications
218218
WHERE destination_uuid = ? AND topic = ?
219219
ORDER BY created_at_epoch_ms ASC
220220
LIMIT 1
221221
)
222-
DELETE FROM %1$s.notifications
222+
DELETE FROM "%1$s".notifications
223223
WHERE destination_uuid = (SELECT destination_uuid FROM oldest_entry)
224224
AND topic = (SELECT topic FROM oldest_entry)
225225
AND created_at_epoch_ms = (SELECT created_at_epoch_ms FROM oldest_entry)
@@ -263,7 +263,7 @@ private void setEvent(
263263
throws SQLException {
264264
final String eventSql =
265265
"""
266-
INSERT INTO %s.workflow_events (workflow_uuid, key, value)
266+
INSERT INTO "%s".workflow_events (workflow_uuid, key, value)
267267
VALUES (?, ?, ?)
268268
ON CONFLICT (workflow_uuid, key)
269269
DO UPDATE SET value = EXCLUDED.value
@@ -279,7 +279,7 @@ ON CONFLICT (workflow_uuid, key)
279279

280280
final String eventHistorySql =
281281
"""
282-
INSERT INTO %s.workflow_events_history (workflow_uuid, function_id, key, value)
282+
INSERT INTO "%s".workflow_events_history (workflow_uuid, function_id, key, value)
283283
VALUES (?, ?, ?, ?)
284284
ON CONFLICT (workflow_uuid, key, function_id)
285285
DO UPDATE SET value = EXCLUDED.value
@@ -382,7 +382,7 @@ Object getEvent(
382382
Object value = null;
383383
final String sql =
384384
"""
385-
SELECT value FROM %s.workflow_events WHERE workflow_uuid = ? AND key = ?
385+
SELECT value FROM "%s".workflow_events WHERE workflow_uuid = ? AND key = ?
386386
"""
387387
.formatted(this.schema);
388388

transact/src/main/java/dev/dbos/transact/database/QueuesDAO.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ List<String> getAndStartQueuedWorkflows(
6262
var limiterQuery =
6363
"""
6464
SELECT COUNT(*)
65-
FROM %s.workflow_status
65+
FROM "%s".workflow_status
6666
WHERE queue_name = ?
6767
AND status != ?
6868
AND started_at_epoch_ms > ?
@@ -100,7 +100,7 @@ SELECT COUNT(*)
100100
String pendingQuery =
101101
"""
102102
SELECT executor_id, COUNT(*) as task_count
103-
FROM %s.workflow_status
103+
FROM "%s".workflow_status
104104
WHERE queue_name = ? AND status = ?
105105
"""
106106
.formatted(this.schema);
@@ -170,7 +170,7 @@ SELECT executor_id, COUNT(*) as task_count
170170
var query =
171171
"""
172172
SELECT workflow_uuid
173-
FROM %s.workflow_status
173+
FROM "%s".workflow_status
174174
WHERE queue_name = ?
175175
AND status = ?
176176
AND (application_version = ? OR application_version IS NULL)
@@ -226,7 +226,7 @@ SELECT executor_id, COUNT(*) as task_count
226226
List<String> updatedWorkflowIds = new ArrayList<>();
227227
String updateQuery =
228228
"""
229-
UPDATE %s.workflow_status
229+
UPDATE "%s".workflow_status
230230
SET status = ?,
231231
application_version = ?,
232232
executor_id = ?,
@@ -273,7 +273,7 @@ boolean clearQueueAssignment(String workflowId) throws SQLException {
273273

274274
final String sql =
275275
"""
276-
UPDATE %s.workflow_status
276+
UPDATE "%s".workflow_status
277277
SET started_at_epoch_ms = NULL, status = ?
278278
WHERE workflow_uuid = ? AND queue_name IS NOT NULL AND status = ?
279279
"""
@@ -294,7 +294,7 @@ List<String> getQueuePartitions(String queueName) throws SQLException {
294294
final String sql =
295295
"""
296296
SELECT DISTINCT queue_partition_key
297-
FROM %s.workflow_status
297+
FROM "%s".workflow_status
298298
WHERE queue_name = ?
299299
AND status = ?
300300
AND queue_partition_key IS NOT NULL

transact/src/main/java/dev/dbos/transact/database/StepsDAO.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ static void recordStepResultTxn(
5555
Objects.requireNonNull(schema);
5656
String sql =
5757
"""
58-
INSERT INTO %s.operation_outputs
58+
INSERT INTO "%s".operation_outputs
5959
(workflow_uuid, function_id, function_name, output, error, child_workflow_id, started_at_epoch_ms, completed_at_epoch_ms)
6060
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
6161
ON CONFLICT DO NOTHING RETURNING completed_at_epoch_ms
@@ -132,7 +132,7 @@ static StepResult checkStepExecutionTxn(
132132
Objects.requireNonNull(schema);
133133
final String sql =
134134
"""
135-
SELECT status FROM %s.workflow_status WHERE workflow_uuid = ?
135+
SELECT status FROM "%s".workflow_status WHERE workflow_uuid = ?
136136
"""
137137
.formatted(schema);
138138

@@ -158,7 +158,7 @@ static StepResult checkStepExecutionTxn(
158158
String operationOutputSql =
159159
"""
160160
SELECT output, error, function_name
161-
FROM %s.operation_outputs
161+
FROM "%s".operation_outputs
162162
WHERE workflow_uuid = ? AND function_id = ?
163163
"""
164164
.formatted(schema);
@@ -203,7 +203,7 @@ List<StepInfo> listWorkflowSteps(Connection connection, String workflowId) throw
203203
final String sql =
204204
"""
205205
SELECT function_id, function_name, output, error, child_workflow_id, started_at_epoch_ms, completed_at_epoch_ms
206-
FROM %s.operation_outputs
206+
FROM "%s".operation_outputs
207207
WHERE workflow_uuid = ?
208208
ORDER BY function_id;
209209
"""

transact/src/main/java/dev/dbos/transact/database/SystemDatabase.java

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,7 @@ public class SystemDatabase implements AutoCloseable {
3737
private static final Logger logger = LoggerFactory.getLogger(SystemDatabase.class);
3838

3939
public static String sanitizeSchema(String schema) {
40-
schema =
41-
Objects.requireNonNullElse(schema, Constants.DB_SCHEMA)
42-
.replace("\0", "")
43-
.replace("\"", "\"\"");
44-
return "\"%s\"".formatted(schema);
40+
return Objects.requireNonNullElse(schema, Constants.DB_SCHEMA).replace("\0", "");
4541
}
4642

4743
private final DataSource dataSource;
@@ -55,7 +51,12 @@ public static String sanitizeSchema(String schema) {
5551
private final NotificationService notificationService;
5652

5753
private SystemDatabase(DataSource dataSource, String schema, boolean created) {
58-
this.schema = sanitizeSchema(schema);
54+
schema = sanitizeSchema(schema);
55+
if (schema.contains("'") || schema.contains("\"")) {
56+
throw new IllegalArgumentException("Schema name must not contain single or double quotes");
57+
}
58+
59+
this.schema = schema;
5960
this.dataSource = dataSource;
6061
this.created = created;
6162

@@ -425,7 +426,7 @@ public Optional<ExternalState> getExternalState(String service, String workflowN
425426
() -> {
426427
final String sql =
427428
"""
428-
SELECT value, update_seq, update_time FROM %s.event_dispatch_kv WHERE service_name = ? AND workflow_fn_name = ? AND key = ?
429+
SELECT value, update_seq, update_time FROM "%s".event_dispatch_kv WHERE service_name = ? AND workflow_fn_name = ? AND key = ?
429430
"""
430431
.formatted(this.schema);
431432

@@ -456,7 +457,7 @@ public ExternalState upsertExternalState(ExternalState state) {
456457
() -> {
457458
final var sql =
458459
"""
459-
INSERT INTO %s.event_dispatch_kv (
460+
INSERT INTO "%s".event_dispatch_kv (
460461
service_name, workflow_fn_name, key, value, update_time, update_seq)
461462
VALUES (?, ?, ?, ?, ?, ?)
462463
ON CONFLICT (service_name, workflow_fn_name, key)
@@ -509,15 +510,15 @@ public List<MetricData> getMetrics(Instant startTime, Instant endTime) {
509510
final var wfSQL =
510511
"""
511512
SELECT name, COUNT(workflow_uuid) as count
512-
FROM %s.workflow_status
513+
FROM "%s".workflow_status
513514
WHERE created_at >= ? AND created_at < ?
514515
GROUP BY name
515516
"""
516517
.formatted(this.schema);
517518
final var stepSQL =
518519
"""
519520
SELECT function_name, COUNT(*) as count
520-
FROM %s.operation_outputs
521+
FROM "%s".operation_outputs
521522
WHERE completed_at_epoch_ms >= ? AND completed_at_epoch_ms < ?
522523
GROUP BY function_name
523524
"""
@@ -559,7 +560,7 @@ private String getCheckpointName(Connection conn, String workflowId, int functio
559560
var sql =
560561
"""
561562
SELECT function_name
562-
FROM %s.operation_outputs
563+
FROM "%s".operation_outputs
563564
WHERE workflow_uuid = ? AND function_id = ?
564565
"""
565566
.formatted(this.schema);
@@ -613,7 +614,7 @@ public void deleteWorkflows(String... workflowIds) {
613614

614615
var sql =
615616
"""
616-
DELETE FROM %s.workflow_status
617+
DELETE FROM "%s".workflow_status
617618
WHERE workflow_uuid = ANY(?);
618619
"""
619620
.formatted(this.schema);
@@ -642,7 +643,7 @@ List<String> getWorkflowChildrenInternal(String workflowId) throws SQLException
642643
var sql =
643644
"""
644645
SELECT child_workflow_id
645-
FROM %s.operation_outputs
646+
FROM "%s".operation_outputs
646647
WHERE workflow_uuid = ? AND child_workflow_id IS NOT NULL
647648
"""
648649
.formatted(this.schema);
@@ -673,7 +674,7 @@ List<WorkflowEvent> listWorkflowEvents(Connection conn, String workflowId) throw
673674
var sql =
674675
"""
675676
SELECT key, value
676-
FROM %s.workflow_events
677+
FROM "%s".workflow_events
677678
WHERE workflow_uuid = ?
678679
"""
679680
.formatted(this.schema);
@@ -697,7 +698,7 @@ List<WorkflowEventHistory> listWorkflowEventHistory(Connection conn, String work
697698
var sql =
698699
"""
699700
SELECT key, value, function_id
700-
FROM %s.workflow_events_history
701+
FROM "%s".workflow_events_history
701702
WHERE workflow_uuid = ?
702703
"""
703704
.formatted(this.schema);
@@ -721,7 +722,7 @@ List<WorkflowStream> listWorkflowStreams(Connection conn, String workflowId) thr
721722
var sql =
722723
"""
723724
SELECT key, value, "offset", function_id
724-
FROM %s.streams
725+
FROM "%s".streams
725726
WHERE workflow_uuid = ?
726727
"""
727728
.formatted(this.schema);
@@ -771,7 +772,7 @@ public List<ExportedWorkflow> exportWorkflow(String workflowId, boolean exportCh
771772
public void importWorkflow(List<ExportedWorkflow> workflows) {
772773
var wfSQL =
773774
"""
774-
INSERT INTO %s.workflow_status (
775+
INSERT INTO "%s".workflow_status (
775776
workflow_uuid, status,
776777
name, class_name, config_name,
777778
authenticated_user, assumed_role, authenticated_roles,
@@ -789,7 +790,7 @@ public void importWorkflow(List<ExportedWorkflow> workflows) {
789790

790791
var stepSQL =
791792
"""
792-
INSERT INTO %s.operation_outputs (
793+
INSERT INTO "%s".operation_outputs (
793794
workflow_uuid, function_id, function_name,
794795
output, error, child_workflow_id,
795796
started_at_epoch_ms, completed_at_epoch_ms
@@ -801,7 +802,7 @@ public void importWorkflow(List<ExportedWorkflow> workflows) {
801802

802803
var eventSQL =
803804
"""
804-
INSERT INTO %s.workflow_events (
805+
INSERT INTO "%s".workflow_events (
805806
workflow_uuid, key, value
806807
) VALUES (
807808
?, ?, ?
@@ -811,7 +812,7 @@ public void importWorkflow(List<ExportedWorkflow> workflows) {
811812

812813
var eventHistorySQL =
813814
"""
814-
INSERT INTO %s.workflow_events_history (
815+
INSERT INTO "%s".workflow_events_history (
815816
workflow_uuid, key, value, function_id
816817
) VALUES (
817818
?, ?, ?, ?
@@ -821,7 +822,7 @@ public void importWorkflow(List<ExportedWorkflow> workflows) {
821822

822823
var streamsSQL =
823824
"""
824-
INSERT INTO %s.streams (
825+
INSERT INTO "%s".streams (
825826
workflow_uuid, key, value, function_id, offset
826827
) VALUES (
827828
?, ?, ?, ?, ?

0 commit comments

Comments
 (0)