/*
 * Review notes for this patch (git unified diff touching three files).
 * This preamble sits ahead of the first "diff --git" header and is not part
 * of any hunk; the patch text below is unchanged.
 *
 * src/cypher/cypher.c
 *   - lex_string_literal: every store into the CBM_SZ_4K-sized stack buffer
 *     `buf` is now guarded by `blen < max_blen`, with
 *     max_blen = CBM_SZ_4K - 1. Input past that limit is still consumed
 *     (*pos keeps advancing) but is not stored, so an over-long literal is
 *     truncated without any error being raised in the lines shown.
 *   - parse_props, parse_rel_types, parse_in_list, parse_case_expr: the
 *     initial malloc / calloc results are now NULL-checked, and each
 *     `x = safe_realloc(x, ...)` growth step is replaced by realloc into a
 *     temporary. The capacity variable is updated only after realloc
 *     succeeds. On failure the entries stored so far and the array itself
 *     are freed, and the function returns CBM_NOT_FOUND (parse_props,
 *     parse_rel_types) or NULL (parse_in_list, parse_case_expr).
 *   - parse_in_list also frees c->variable, c->property and c->op on both
 *     failure paths. NOTE(review): confirm the caller does not free those
 *     fields again after receiving NULL.
 *   - parse_case_expr: the realloc-failure path releases `when` and every
 *     stored when_expr through expr_free, but neither the current
 *     `then_val` nor the stored then_val pointers. NOTE(review): if
 *     parse_value_literal hands back heap memory these would leak; confirm
 *     ownership before merging.
 *   - NOTE(review): the heap_strdup results written into the arrays are
 *     still not checked in the lines shown.
 *
 * src/store/store.c
 *   - cbm_store_get_schema (both queries) and collect_pkg_names now test
 *     the sqlite3_prepare_v2 result and return CBM_NOT_FOUND when it is not
 *     SQLITE_OK or the statement is NULL. The second query in
 *     cbm_store_get_schema calls cbm_store_schema_free(out) before
 *     returning.
 *
 * tests/test_cypher.c
 *   - Adds cypher_lex_string_overflow: it lexes a double-quoted literal of
 *     5000 'A' characters and expects rc == 0, no error, a TOK_STRING first
 *     token, and a token text length of 4095. The test is registered in
 *     SUITE(cypher).
 */
diff --git a/src/cypher/cypher.c b/src/cypher/cypher.c index 6aedeb9..b4b049b 100644 --- a/src/cypher/cypher.c +++ b/src/cypher/cypher.c @@ -92,25 +92,30 @@ static void lex_string_literal(const char *input, int len, int *pos, char quote, int start = *pos; char buf[CBM_SZ_4K]; int blen = 0; + const int max_blen = CBM_SZ_4K - 1; while (*pos < len && input[*pos] != quote) { if (input[*pos] == '\\' && *pos + SKIP_ONE < len) { (*pos)++; - switch (input[*pos]) { - case 'n': - buf[blen++] = '\n'; - break; - case 't': - buf[blen++] = '\t'; - break; - case '\\': - buf[blen++] = '\\'; - break; - default: - buf[blen++] = input[*pos]; - break; + if (blen < max_blen) { + switch (input[*pos]) { + case 'n': + buf[blen++] = '\n'; + break; + case 't': + buf[blen++] = '\t'; + break; + case '\\': + buf[blen++] = '\\'; + break; + default: + buf[blen++] = input[*pos]; + break; + } } } else { - buf[blen++] = input[*pos]; + if (blen < max_blen) { + buf[blen++] = input[*pos]; + } } (*pos)++; } @@ -469,6 +474,9 @@ static int parse_props(parser_t *p, cbm_prop_filter_t **out, int *count) { int cap = CYP_INIT_CAP4; int n = 0; cbm_prop_filter_t *arr = malloc(cap * sizeof(cbm_prop_filter_t)); + if (!arr) { + return CBM_NOT_FOUND; + } while (!check(p, TOK_RBRACE) && !check(p, TOK_EOF)) { const cbm_token_t *key = expect(p, TOK_IDENT); @@ -487,8 +495,18 @@ static int parse_props(parser_t *p, cbm_prop_filter_t **out, int *count) { } if (n >= cap) { - cap *= PAIR_LEN; - arr = safe_realloc(arr, cap * sizeof(cbm_prop_filter_t)); + int new_cap = cap * PAIR_LEN; + void *tmp = realloc(arr, new_cap * sizeof(cbm_prop_filter_t)); + if (!tmp) { + for (int i = 0; i < n; i++) { + free((void *)arr[i].key); + free((void *)arr[i].value); + } + free(arr); + return CBM_NOT_FOUND; + } + arr = tmp; + cap = new_cap; } arr[n].key = heap_strdup(key->text); arr[n].value = heap_strdup(val->text); @@ -569,6 +587,9 @@ static int parse_rel_types(parser_t *p, cbm_rel_pattern_t *out) { int cap = CYP_INIT_CAP4; int n = 0; const 
char **types = malloc(cap * sizeof(const char *)); + if (!types) { + return CBM_NOT_FOUND; + } const cbm_token_t *t = expect(p, TOK_IDENT); if (!t) { @@ -587,8 +608,17 @@ static int parse_rel_types(parser_t *p, cbm_rel_pattern_t *out) { return CBM_NOT_FOUND; } if (n >= cap) { - cap *= PAIR_LEN; - types = safe_realloc(types, cap * sizeof(const char *)); + int new_cap = cap * PAIR_LEN; + void *tmp = realloc(types, new_cap * sizeof(const char *)); + if (!tmp) { + for (int i = 0; i < n; i++) { + free((void *)types[i]); + } + free(types); + return CBM_NOT_FOUND; + } + types = (const char **)tmp; + cap = new_cap; } types[n++] = heap_strdup(t->text); } @@ -762,14 +792,32 @@ static cbm_expr_t *parse_in_list(parser_t *p, cbm_condition_t *c) { int vcap = CYP_INIT_CAP8; int vn = 0; const char **vals = malloc(vcap * sizeof(const char *)); + if (!vals) { + free((void *)c->variable); + free((void *)c->property); + free((void *)c->op); + return NULL; + } while (!check(p, TOK_RBRACKET) && !check(p, TOK_EOF)) { if (vn > 0) { match(p, TOK_COMMA); } if (check(p, TOK_STRING) || check(p, TOK_NUMBER)) { if (vn >= vcap) { - vcap *= PAIR_LEN; - vals = safe_realloc(vals, vcap * sizeof(const char *)); + int new_vcap = vcap * PAIR_LEN; + void *tmp = realloc((void *)vals, new_vcap * sizeof(const char *)); + if (!tmp) { + for (int i = 0; i < vn; i++) { + free((void *)vals[i]); + } + free((void *)vals); + free((void *)c->variable); + free((void *)c->property); + free((void *)c->op); + return NULL; + } + vals = (const char **)tmp; + vcap = new_vcap; } vals[vn++] = heap_strdup(advance(p)->text); } else { @@ -1061,8 +1109,15 @@ static const char *parse_value_literal(parser_t *p) { static cbm_case_expr_t *parse_case_expr(parser_t *p) { /* CASE already consumed */ cbm_case_expr_t *kase = calloc(CBM_ALLOC_ONE, sizeof(cbm_case_expr_t)); + if (!kase) { + return NULL; + } int bcap = CYP_INIT_CAP4; kase->branches = malloc(bcap * sizeof(cbm_case_branch_t)); + if (!kase->branches) { + free(kase); + return 
NULL; + } while (check(p, TOK_WHEN)) { advance(p); @@ -1073,8 +1128,19 @@ static cbm_case_expr_t *parse_case_expr(parser_t *p) { } const char *then_val = parse_value_literal(p); if (kase->branch_count >= bcap) { - bcap *= PAIR_LEN; - kase->branches = safe_realloc(kase->branches, bcap * sizeof(cbm_case_branch_t)); + int new_bcap = bcap * PAIR_LEN; + void *tmp = realloc(kase->branches, new_bcap * sizeof(cbm_case_branch_t)); + if (!tmp) { + expr_free(when); + for (int i = 0; i < kase->branch_count; i++) { + expr_free(kase->branches[i].when_expr); + } + free(kase->branches); + free(kase); + return NULL; + } + kase->branches = tmp; + bcap = new_bcap; } kase->branches[kase->branch_count++] = (cbm_case_branch_t){.when_expr = when, .then_val = then_val}; diff --git a/src/store/store.c b/src/store/store.c index 4920732..876a828 100644 --- a/src/store/store.c +++ b/src/store/store.c @@ -2552,7 +2552,12 @@ int cbm_store_get_schema(cbm_store_t *s, const char *project, cbm_schema_info_t const char *sql = "SELECT label, COUNT(*) FROM nodes WHERE project = ?1 GROUP BY label " "ORDER BY COUNT(*) DESC;"; sqlite3_stmt *stmt = NULL; - sqlite3_prepare_v2(s->db, sql, CBM_NOT_FOUND, &stmt, NULL); + if (sqlite3_prepare_v2(s->db, sql, CBM_NOT_FOUND, &stmt, NULL) != SQLITE_OK || !stmt) { + if (stmt) { + sqlite3_finalize(stmt); + } + return CBM_NOT_FOUND; + } bind_text(stmt, SKIP_ONE, project); int cap = ST_INIT_CAP_8; @@ -2577,7 +2582,13 @@ int cbm_store_get_schema(cbm_store_t *s, const char *project, cbm_schema_info_t const char *sql = "SELECT type, COUNT(*) FROM edges WHERE project = ?1 GROUP BY type ORDER " "BY COUNT(*) DESC;"; sqlite3_stmt *stmt = NULL; - sqlite3_prepare_v2(s->db, sql, CBM_NOT_FOUND, &stmt, NULL); + if (sqlite3_prepare_v2(s->db, sql, CBM_NOT_FOUND, &stmt, NULL) != SQLITE_OK || !stmt) { + if (stmt) { + sqlite3_finalize(stmt); + } + cbm_store_schema_free(out); + return CBM_NOT_FOUND; + } bind_text(stmt, SKIP_ONE, project); int cap = ST_INIT_CAP_8; @@ -3283,7 +3294,12 @@ 
static bool pkg_in_list(const char *pkg, char **list, int count) { static int collect_pkg_names(cbm_store_t *s, const char *sql, const char *project, char **pkgs, int max_pkgs) { sqlite3_stmt *stmt = NULL; - sqlite3_prepare_v2(s->db, sql, CBM_NOT_FOUND, &stmt, NULL); + if (sqlite3_prepare_v2(s->db, sql, CBM_NOT_FOUND, &stmt, NULL) != SQLITE_OK || !stmt) { + if (stmt) { + sqlite3_finalize(stmt); + } + return CBM_NOT_FOUND; + } bind_text(stmt, SKIP_ONE, project); int count = 0; while (sqlite3_step(stmt) == SQLITE_ROW && count < max_pkgs) { diff --git a/tests/test_cypher.c b/tests/test_cypher.c index 13527d5..a169432 100644 --- a/tests/test_cypher.c +++ b/tests/test_cypher.c @@ -78,6 +78,32 @@ TEST(cypher_lex_single_quote_string) { PASS(); } +TEST(cypher_lex_string_overflow) { + /* Build a string literal longer than 4096 bytes to verify we don't + * overflow the stack buffer in lex_string_literal. */ + const int big = 5000; + /* query: "AAAA...A" (quotes included) */ + char *query = malloc(big + 3); /* quote + big chars + quote + NUL */ + ASSERT_NOT_NULL(query); + query[0] = '"'; + memset(query + 1, 'A', big); + query[big + 1] = '"'; + query[big + 2] = '\0'; + + cbm_lex_result_t r = {0}; + int rc = cbm_lex(query, &r); + ASSERT_EQ(rc, 0); + ASSERT_NULL(r.error); + ASSERT_GTE(r.count, 1); + ASSERT_EQ(r.tokens[0].type, TOK_STRING); + /* The string should be truncated to CBM_SZ_4K - 1 (4095) characters. */ + ASSERT_EQ((int)strlen(r.tokens[0].text), 4095); + + cbm_lex_free(&r); + free(query); + PASS(); +} + TEST(cypher_lex_number) { cbm_lex_result_t r = {0}; int rc = cbm_lex("42 3.14", &r); @@ -2064,6 +2090,7 @@ SUITE(cypher) { RUN_TEST(cypher_lex_relationship); RUN_TEST(cypher_lex_string_literal); RUN_TEST(cypher_lex_single_quote_string); + RUN_TEST(cypher_lex_string_overflow); RUN_TEST(cypher_lex_number); RUN_TEST(cypher_lex_operators); RUN_TEST(cypher_lex_keywords_case_insensitive);