Skip to content

Commit fb41a64

Browse files
SBALAVIGNESH123Bala Vignesh S
andauthored
feat(orm): Add JOIN support to Mapper/BaseBuilder and FK auto-detection in code generator (#2491)
Co-authored-by: Bala Vignesh S <sbalavignesh123@gmail.com>
1 parent 8c6e588 commit fb41a64

5 files changed

Lines changed: 434 additions & 3 deletions

File tree

drogon_ctl/create_model.cc

Lines changed: 146 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,59 @@ bool drogon_ctl::ConvertMethod::shouldConvert(const std::string &tableName,
164164
} // endif
165165
}
166166

167+
/**
168+
* @brief Try to add an auto-detected FK relationship to the list.
169+
*
170+
* Checks for duplicates against existing relationships, creates a
171+
* Relationship object from the FK info, and appends it to the list.
172+
* User-configured relationships always take priority.
173+
*
174+
* @param allRelationships The mutable vector of relationships.
175+
* @param originalTable The table containing the FK column.
176+
* @param fkColumn The FK column name.
177+
* @param referencedTable The table referenced by the FK.
178+
* @param referencedColumn The column referenced by the FK.
179+
* @param normalizeNames If true, apply toLower() to table names.
180+
*/
181+
static void tryAddAutoRelationship(std::vector<Relationship> &allRelationships,
182+
const std::string &originalTable,
183+
const std::string &fkColumn,
184+
const std::string &referencedTable,
185+
const std::string &referencedColumn,
186+
bool normalizeNames)
187+
{
188+
for (const auto &r : allRelationships)
189+
{
190+
if (r.originalKey() == fkColumn &&
191+
r.targetTableName() == referencedTable)
192+
{
193+
return; // Already exists in user config
194+
}
195+
}
196+
Json::Value relJson;
197+
relJson["type"] = "has one";
198+
relJson["original_table_name"] =
199+
normalizeNames ? toLower(originalTable) : originalTable;
200+
relJson["original_key"] = fkColumn;
201+
relJson["target_table_name"] =
202+
normalizeNames ? toLower(referencedTable) : referencedTable;
203+
relJson["target_key"] = referencedColumn;
204+
relJson["enable_reverse"] = true;
205+
try
206+
{
207+
Relationship autoRel(relJson);
208+
allRelationships.push_back(autoRel);
209+
std::cout << " Auto-detected FK: " << originalTable << "."
210+
<< fkColumn << " -> " << referencedTable << "."
211+
<< referencedColumn << std::endl;
212+
}
213+
catch (const std::runtime_error &e)
214+
{
215+
std::cerr << "Warning: Could not create auto-relationship: " << e.what()
216+
<< std::endl;
217+
}
218+
}
219+
167220
#if USE_POSTGRESQL
168221
void create_model::createModelClassFromPG(
169222
const std::string &path,
@@ -182,8 +235,9 @@ void create_model::createModelClassFromPG(
182235
data["primaryKeyName"] = "";
183236
data["dbName"] = dbname_;
184237
data["rdbms"] = std::string("postgresql");
185-
data["relationships"] = relationships;
186238
data["convertMethods"] = convertMethods;
239+
// Start with user-configured relationships (mutable copy)
240+
std::vector<Relationship> allRelationships(relationships);
187241
if (schema != "public")
188242
{
189243
data["schema"] = schema;
@@ -397,6 +451,42 @@ void create_model::createModelClassFromPG(
397451
data["primaryKeyValNames"] = pkValNames;
398452
}
399453

454+
// Auto-detect foreign key relationships from database schema
455+
*client << "SELECT "
456+
"kcu.column_name AS fk_column, "
457+
"ccu.table_name AS referenced_table, "
458+
"ccu.column_name AS referenced_column "
459+
"FROM information_schema.key_column_usage kcu "
460+
"JOIN information_schema.referential_constraints rc "
461+
"ON kcu.constraint_name = rc.constraint_name "
462+
"AND kcu.constraint_schema = rc.constraint_schema "
463+
"JOIN information_schema.constraint_column_usage ccu "
464+
"ON rc.unique_constraint_name = ccu.constraint_name "
465+
"AND rc.unique_constraint_schema = ccu.constraint_schema "
466+
"WHERE kcu.table_name = $1 "
467+
"AND kcu.table_schema = $2"
468+
<< tableName << schema << Mode::Blocking >>
469+
[&](bool isNull,
470+
const std::string &fkColumn,
471+
const std::string &referencedTable,
472+
const std::string &referencedColumn) {
473+
if (!isNull)
474+
{
475+
tryAddAutoRelationship(allRelationships,
476+
tableName,
477+
fkColumn,
478+
referencedTable,
479+
referencedColumn,
480+
true);
481+
}
482+
} >>
483+
[](const DrogonDbException &e) {
484+
// FK detection is best-effort; don't fail if unsupported
485+
std::cerr << "Note: FK auto-detection not available: "
486+
<< e.base().what() << std::endl;
487+
};
488+
489+
data["relationships"] = allRelationships;
400490
data["columns"] = cols;
401491
std::ofstream headerFile(path + "/" + className + ".h", std::ofstream::out);
402492
std::ofstream sourceFile(path + "/" + className + ".cc",
@@ -467,8 +557,9 @@ void create_model::createModelClassFromMysql(
467557
data["primaryKeyName"] = "";
468558
data["dbName"] = dbname_;
469559
data["rdbms"] = std::string("mysql");
470-
data["relationships"] = relationships;
471560
data["convertMethods"] = convertMethods;
561+
// Start with user-configured relationships (mutable copy)
562+
std::vector<Relationship> allRelationships(relationships);
472563
std::vector<ColumnInfo> cols;
473564
int i = 0;
474565
*client << "desc `" + tableName + "`" << Mode::Blocking >>
@@ -593,6 +684,35 @@ void create_model::createModelClassFromMysql(
593684
data["primaryKeyType"] = pkTypes;
594685
data["primaryKeyValNames"] = pkValNames;
595686
}
687+
688+
// Auto-detect foreign key relationships from MySQL schema
689+
*client << "SELECT COLUMN_NAME, REFERENCED_TABLE_NAME, "
690+
"REFERENCED_COLUMN_NAME "
691+
"FROM information_schema.KEY_COLUMN_USAGE "
692+
"WHERE TABLE_SCHEMA = DATABASE() "
693+
"AND TABLE_NAME = ? "
694+
"AND REFERENCED_TABLE_NAME IS NOT NULL"
695+
<< tableName << Mode::Blocking >>
696+
[&](bool isNull,
697+
const std::string &fkColumn,
698+
const std::string &referencedTable,
699+
const std::string &referencedColumn) {
700+
if (!isNull)
701+
{
702+
tryAddAutoRelationship(allRelationships,
703+
tableName,
704+
fkColumn,
705+
referencedTable,
706+
referencedColumn,
707+
true);
708+
}
709+
} >>
710+
[](const DrogonDbException &e) {
711+
std::cerr << "Note: FK auto-detection not available: "
712+
<< e.base().what() << std::endl;
713+
};
714+
715+
data["relationships"] = allRelationships;
596716
data["columns"] = cols;
597717
std::ofstream headerFile(path + "/" + className + ".h", std::ofstream::out);
598718
std::ofstream sourceFile(path + "/" + className + ".cc",
@@ -646,8 +766,9 @@ void create_model::createModelClassFromSqlite3(
646766
data["primaryKeyName"] = "";
647767
data["dbName"] = std::string("sqlite3");
648768
data["rdbms"] = std::string("sqlite3");
649-
data["relationships"] = relationships;
650769
data["convertMethods"] = convertMethods;
770+
// Start with user-configured relationships (mutable copy)
771+
std::vector<Relationship> allRelationships(relationships);
651772
std::vector<ColumnInfo> cols;
652773
std::string sql = "PRAGMA table_info(" + tableName + ");";
653774
*client << sql << Mode::Blocking >> [&](const Result &result) {
@@ -774,6 +895,28 @@ void create_model::createModelClassFromSqlite3(
774895
data["primaryKeyType"] = pkTypes;
775896
data["primaryKeyValNames"] = pkValNames;
776897
}
898+
899+
// Auto-detect foreign key relationships from SQLite3 schema
900+
std::string fkSql = "PRAGMA foreign_key_list(\"" + tableName + "\");";
901+
*client << fkSql << Mode::Blocking >> [&](const Result &fkResult) {
902+
for (auto &fkRow : fkResult)
903+
{
904+
auto referencedTable = fkRow["table"].as<std::string>();
905+
auto fkColumn = fkRow["from"].as<std::string>();
906+
auto referencedColumn = fkRow["to"].as<std::string>();
907+
tryAddAutoRelationship(allRelationships,
908+
tableName,
909+
fkColumn,
910+
referencedTable,
911+
referencedColumn,
912+
true);
913+
}
914+
} >> [](const DrogonDbException &e) {
915+
std::cerr << "Note: FK auto-detection not available: "
916+
<< e.base().what() << std::endl;
917+
};
918+
919+
data["relationships"] = allRelationships;
777920
data["columns"] = cols;
778921
std::ofstream headerFile(path + "/" + className + ".h", std::ofstream::out);
779922
std::ofstream sourceFile(path + "/" + className + ".cc",

orm_lib/inc/drogon/orm/BaseBuilder.h

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,69 @@ struct Filter
6767
std::string value;
6868
};
6969

70+
/**
71+
* @brief Represents a SQL JOIN clause.
72+
*/
73+
enum class JoinType
74+
{
75+
InnerJoin,
76+
LeftJoin,
77+
RightJoin,
78+
FullJoin
79+
};
80+
81+
inline std::string to_join_string(JoinType type)
82+
{
83+
switch (type)
84+
{
85+
case JoinType::InnerJoin:
86+
return "INNER JOIN";
87+
case JoinType::LeftJoin:
88+
return "LEFT JOIN";
89+
case JoinType::RightJoin:
90+
return "RIGHT JOIN";
91+
case JoinType::FullJoin:
92+
return "FULL JOIN";
93+
}
94+
// Should never reach here
95+
return "INNER JOIN";
96+
}
97+
98+
struct JoinClause
99+
{
100+
JoinType type;
101+
std::string table;
102+
std::string onLeft; // e.g. "users.id"
103+
std::string onRight; // e.g. "posts.user_id"
104+
};
105+
106+
/**
107+
* @brief Validate that a string is a safe SQL identifier.
108+
*
109+
* Only allows alphanumeric characters, underscores, and dots
110+
* (for table.column notation). This prevents SQL injection when
111+
* building JOIN clauses from user-provided identifiers.
112+
*
113+
* @param identifier The identifier to validate.
114+
* @return true if the identifier is safe to use in SQL.
115+
*/
116+
inline bool isValidSqlIdentifier(const std::string &identifier)
117+
{
118+
if (identifier.empty())
119+
{
120+
return false;
121+
}
122+
for (auto c : identifier)
123+
{
124+
if (!std::isalnum(static_cast<unsigned char>(c)) && c != '_' &&
125+
c != '.')
126+
{
127+
return false;
128+
}
129+
}
130+
return true;
131+
}
132+
70133
// Forward declaration to be a friend
71134
template <typename T, bool SelectAll, bool Single = false>
72135
class TransformBuilder;
@@ -87,6 +150,7 @@ class BaseBuilder
87150
std::string from_;
88151
std::string columns_;
89152
std::vector<Filter> filters_;
153+
std::vector<JoinClause> joins_;
90154
std::optional<std::uint64_t> limit_;
91155
std::optional<std::uint64_t> offset_;
92156
// The order is important; use vector<pair> instead of unordered_map and
@@ -122,6 +186,11 @@ class BaseBuilder
122186
};
123187

124188
std::string sql = "select " + columns_ + " from " + from_;
189+
for (const auto &join : joins_)
190+
{
191+
sql += " " + to_join_string(join.type) + " " + join.table + " ON " +
192+
join.onLeft + " = " + join.onRight;
193+
}
125194
if (!filters_.empty())
126195
{
127196
sql += " where " + filters_[0].column + " " +

orm_lib/inc/drogon/orm/CoroMapper.h

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,54 @@ class CoroMapper : public Mapper<T>
211211
return *this;
212212
}
213213

214+
/**
215+
* @brief Add an INNER JOIN clause to the query.
216+
*
217+
* @param table The table to join.
218+
* @param onLeft The left side of ON (e.g. "users.id").
219+
* @param onRight The right side of ON (e.g. "posts.user_id").
220+
* @return CoroMapper<T>& The CoroMapper itself.
221+
*/
222+
CoroMapper<T> &innerJoin(const std::string &table,
223+
const std::string &onLeft,
224+
const std::string &onRight)
225+
{
226+
Mapper<T>::innerJoin(table, onLeft, onRight);
227+
return *this;
228+
}
229+
230+
/**
231+
* @brief Add a LEFT JOIN clause to the query.
232+
*
233+
* @param table The table to join.
234+
* @param onLeft The left side of ON (e.g. "users.id").
235+
* @param onRight The right side of ON (e.g. "posts.user_id").
236+
* @return CoroMapper<T>& The CoroMapper itself.
237+
*/
238+
CoroMapper<T> &leftJoin(const std::string &table,
239+
const std::string &onLeft,
240+
const std::string &onRight)
241+
{
242+
Mapper<T>::leftJoin(table, onLeft, onRight);
243+
return *this;
244+
}
245+
246+
/**
247+
* @brief Add a RIGHT JOIN clause to the query.
248+
*
249+
* @param table The table to join.
250+
* @param onLeft The left side of ON (e.g. "users.id").
251+
* @param onRight The right side of ON (e.g. "posts.user_id").
252+
* @return CoroMapper<T>& The CoroMapper itself.
253+
*/
254+
CoroMapper<T> &rightJoin(const std::string &table,
255+
const std::string &onLeft,
256+
const std::string &onRight)
257+
{
258+
Mapper<T>::rightJoin(table, onLeft, onRight);
259+
return *this;
260+
}
261+
214262
// Read api for coroutines
215263

216264
inline internal::MapperAwaiter<std::vector<T>> findAll()
@@ -225,6 +273,7 @@ class CoroMapper : public Mapper<T>
225273
ExceptPtrCallback &&errCallback) {
226274
std::string sql = "select count(*) from ";
227275
sql += T::tableName;
276+
sql += this->joinString_;
228277
if (criteria)
229278
{
230279
sql += " where ";
@@ -250,6 +299,7 @@ class CoroMapper : public Mapper<T>
250299
ExceptPtrCallback &&errCallback) {
251300
std::string sql = "select * from ";
252301
sql += T::tableName;
302+
sql += this->joinString_;
253303
bool hasParameters = false;
254304
if (criteria)
255305
{
@@ -311,6 +361,7 @@ class CoroMapper : public Mapper<T>
311361
ExceptPtrCallback &&errCallback) {
312362
std::string sql = "select * from ";
313363
sql += T::tableName;
364+
sql += this->joinString_;
314365
bool hasParameters = false;
315366
if (criteria)
316367
{

0 commit comments

Comments
 (0)