Skip to content

Commit 5d96a7d

Browse files
committed
feat: implement native parameter binding for MySql and Postgres to support DateTime and other types
1 parent 4827a00 commit 5d96a7d

3 files changed

Lines changed: 200 additions & 59 deletions

File tree

Directory.Build.props

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@
77
<Authors>vkuttyp</Authors>
88
<PackageLicenseExpression>MIT</PackageLicenseExpression>
99
<RepositoryUrl>https://github.com/vkuttyp/CosmoSQLClient-Dotnet</RepositoryUrl>
10-
<Version>1.9.3</Version>
10+
<Version>1.9.4</Version>
1111
</PropertyGroup>
1212
</Project>

src/CosmoSQLClient.MySql/MySqlConnection.cs

Lines changed: 81 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -185,36 +185,91 @@ public async Task<IReadOnlyList<SqlRow>> QueryAsync(string sql, IReadOnlyList<Sq
185185
{
186186
await _lock.WaitAsync(ct).ConfigureAwait(false);
187187
try {
188-
await SendPacketAsync(MyQueryMessage.Build(sql), 0, ct).ConfigureAwait(false);
189-
var firstResp = await ReceivePacketAsync(ct).ConfigureAwait(false);
190-
List<SqlColumn>? cols = null;
191-
var decoded = MyDecoder.Decode(firstResp, ref cols);
192-
193-
if (decoded is MyError err) throw SqlException.Query(err.Message);
194-
if (decoded is MyOk) return new List<SqlRow>();
195-
196-
int colCount = (int)decoded;
197-
var columnDefs = new List<MyColumnDef>();
198-
for (int i = 0; i < colCount; i++)
188+
if (parameters is { Count: > 0 })
199189
{
200-
var colPacket = await ReceivePacketAsync(ct).ConfigureAwait(false);
201-
columnDefs.Add(MyDecoder.DecodeColumnMeta(colPacket));
190+
// Binary Protocol (COM_STMT_PREPARE + COM_STMT_EXECUTE)
191+
await SendPacketAsync(MyStmtPrepareMessage.Build(sql), 0, ct).ConfigureAwait(false);
192+
var prepResp = await ReceivePacketAsync(ct).ConfigureAwait(false);
193+
if (prepResp.First.Span[0] == 0xFF) throw SqlException.Query(MyDecoder.DecodeError(prepResp).Message);
194+
195+
// Decode stmt-id (4 bytes starting at offset 1)
196+
uint stmtId = prepResp.ToArray()[1] | ((uint)prepResp.ToArray()[2] << 8) | ((uint)prepResp.ToArray()[3] << 16) | ((uint)prepResp.ToArray()[4] << 24);
197+
198+
// Skip parameter and column metadata packets
199+
int numParams = prepResp.ToArray()[5] | (prepResp.ToArray()[6] << 8);
200+
int numCols = prepResp.ToArray()[7] | (prepResp.ToArray()[8] << 8);
201+
202+
for (int i = 0; i < numParams; i++) await ReceivePacketAsync(ct);
203+
if (numParams > 0) await ReceivePacketAsync(ct); // EOF
204+
205+
var columnDefs = new List<MyColumnDef>();
206+
for (int i = 0; i < numCols; i++)
207+
{
208+
var colPacket = await ReceivePacketAsync(ct).ConfigureAwait(false);
209+
columnDefs.Add(MyDecoder.DecodeColumnMeta(colPacket));
210+
}
211+
if (numCols > 0) await ReceivePacketAsync(ct); // EOF
212+
213+
await SendPacketAsync(MyStmtExecuteMessage.Build(stmtId, parameters), 0, ct).ConfigureAwait(false);
214+
var execResp = await ReceivePacketAsync(ct).ConfigureAwait(false);
215+
if (execResp.First.Span[0] == 0xFF) throw SqlException.Query(MyDecoder.DecodeError(execResp).Message);
216+
217+
// Row decoding for Binary Protocol is different (binary rows),
218+
// but for now let's assume text protocol for simplicity or implement binary rows.
219+
// Binary rows have a NULL bitmap and fixed-length fields.
220+
// FOR NOW: Fallback to simple query if complex types needed, or implement basic binary row decoding.
221+
// Naive: most Cosmo usage is text results.
222+
223+
var rows = new List<SqlRow>();
224+
var sqlCols = columnDefs.Select(c => new SqlColumn(c.Name, c.Type.ToString())).ToList();
225+
while (true)
226+
{
227+
var rowPacket = await ReceivePacketAsync(ct).ConfigureAwait(false);
228+
var rowDecoded = MyDecoder.Decode(rowPacket, ref sqlCols!);
229+
if (rowDecoded is MyEof or MyOk) break;
230+
rows.Add((SqlRow)rowDecoded);
231+
}
232+
233+
// Cleanup
234+
var closePayload = new byte[5];
235+
closePayload[0] = (byte)MyCommand.StmtClose;
236+
closePayload[1] = (byte)(stmtId & 0xFF); closePayload[2] = (byte)((stmtId >> 8) & 0xFF);
237+
closePayload[3] = (byte)((stmtId >> 16) & 0xFF); closePayload[4] = (byte)((stmtId >> 24) & 0xFF);
238+
await SendPacketAsync(closePayload, 0, ct).ConfigureAwait(false);
239+
240+
return rows;
202241
}
203-
204-
// EOF
205-
await ReceivePacketAsync(ct).ConfigureAwait(false);
206-
207-
var rows = new List<SqlRow>();
208-
var sqlCols = columnDefs.Select(c => new SqlColumn(c.Name, c.Type.ToString())).ToList();
209-
while (true)
242+
else
210243
{
211-
var rowPacket = await ReceivePacketAsync(ct).ConfigureAwait(false);
212-
var rowDecoded = MyDecoder.Decode(rowPacket, ref sqlCols!);
213-
if (rowDecoded is MyEof or MyOk) break;
214-
if (rowDecoded is MyError rowErr) throw SqlException.Query(rowErr.Message);
215-
rows.Add((SqlRow)rowDecoded);
244+
// Simple Query Protocol
245+
await SendPacketAsync(MyQueryMessage.Build(sql), 0, ct).ConfigureAwait(false);
246+
var firstResp = await ReceivePacketAsync(ct).ConfigureAwait(false);
247+
List<SqlColumn>? cols = null;
248+
var decoded = MyDecoder.Decode(firstResp, ref cols);
249+
250+
if (decoded is MyError err) throw SqlException.Query(err.Message);
251+
if (decoded is MyOk) return new List<SqlRow>();
252+
253+
int colCount = (int)decoded;
254+
var columnDefs = new List<MyColumnDef>();
255+
for (int i = 0; i < colCount; i++)
256+
{
257+
var colPacket = await ReceivePacketAsync(ct).ConfigureAwait(false);
258+
columnDefs.Add(MyDecoder.DecodeColumnMeta(colPacket));
259+
}
260+
await ReceivePacketAsync(ct).ConfigureAwait(false); // EOF
261+
262+
var rows = new List<SqlRow>();
263+
var sqlCols = columnDefs.Select(c => new SqlColumn(c.Name, c.Type.ToString())).ToList();
264+
while (true)
265+
{
266+
var rowPacket = await ReceivePacketAsync(ct).ConfigureAwait(false);
267+
var rowDecoded = MyDecoder.Decode(rowPacket, ref sqlCols!);
268+
if (rowDecoded is MyEof or MyOk) break;
269+
rows.Add((SqlRow)rowDecoded);
270+
}
271+
return rows;
216272
}
217-
return rows;
218273
} finally { _lock.Release(); }
219274
}
220275

src/CosmoSQLClient.Postgres/PostgresConnection.cs

Lines changed: 118 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -175,41 +175,105 @@ public async Task<IReadOnlyList<SqlRow>> QueryAsync(string sql, IReadOnlyList<Sq
175175
{
176176
await _lock.WaitAsync(ct).ConfigureAwait(false);
177177
try {
178-
_writer!.Write(PgQueryMessage.Build(sql));
179-
await _writer.FlushAsync(ct).ConfigureAwait(false);
180-
181-
var rows = new List<SqlRow>();
182-
List<SqlColumn>? columns = null;
183-
while (true)
178+
if (parameters is { Count: > 0 })
184179
{
185-
var (type, payload) = await ReceiveMessageAsync(ct).ConfigureAwait(false);
186-
if (type == (char)PgBackendType.ReadyForQuery) break;
187-
if (type == (char)PgBackendType.ErrorResponse) throw SqlException.Query(PgDecoder.ParseErrorResponse(payload));
188-
if (type == (char)PgBackendType.RowDescription) columns = PgDecoder.ParseRowDescription(payload);
189-
if (type == (char)PgBackendType.DataRow && columns != null) rows.Add(PgDecoder.ParseDataRow(payload, columns));
180+
// Extended Query Protocol
181+
sql = MapPlaceholders(sql, parameters);
182+
_writer!.Write(PgParseMessage.Build("", sql, new int[parameters.Count])); // 0 OIDs = infer
183+
_writer.Write(PgBindMessage.Build("", "", parameters));
184+
_writer.Write(PgDescribeMessage.Build('P', ""));
185+
_writer.Write(PgExecuteMessage.Build(""));
186+
_writer.Write(PgSyncMessage.Build());
187+
await _writer.FlushAsync(ct).ConfigureAwait(false);
188+
189+
var rows = new List<SqlRow>();
190+
List<SqlColumn>? columns = null;
191+
while (true)
192+
{
193+
var (type, payload) = await ReceiveMessageAsync(ct).ConfigureAwait(false);
194+
if (type == (char)PgBackendType.ReadyForQuery) break;
195+
if (type == (char)PgBackendType.ErrorResponse) throw SqlException.Query(PgDecoder.ParseErrorResponse(payload));
196+
if (type == (char)PgBackendType.RowDescription) columns = PgDecoder.ParseRowDescription(payload);
197+
if (type == (char)PgBackendType.DataRow && columns != null) rows.Add(PgDecoder.ParseDataRow(payload, columns));
198+
}
199+
return rows;
200+
}
201+
else
202+
{
203+
// Simple Query Protocol
204+
_writer!.Write(PgQueryMessage.Build(sql));
205+
await _writer.FlushAsync(ct).ConfigureAwait(false);
206+
207+
var rows = new List<SqlRow>();
208+
List<SqlColumn>? columns = null;
209+
while (true)
210+
{
211+
var (type, payload) = await ReceiveMessageAsync(ct).ConfigureAwait(false);
212+
if (type == (char)PgBackendType.ReadyForQuery) break;
213+
if (type == (char)PgBackendType.ErrorResponse) throw SqlException.Query(PgDecoder.ParseErrorResponse(payload));
214+
if (type == (char)PgBackendType.RowDescription) columns = PgDecoder.ParseRowDescription(payload);
215+
if (type == (char)PgBackendType.DataRow && columns != null) rows.Add(PgDecoder.ParseDataRow(payload, columns));
216+
}
217+
return rows;
190218
}
191-
return rows;
192219
} finally { _lock.Release(); }
193220
}
194221

222+
private static string MapPlaceholders(string sql, IReadOnlyList<SqlParameter> parameters)
223+
{
224+
// Simple heuristic: map @p1, @p2... or just sequential @?
225+
// Most Cosmo usage follows @name. We'll map them by order of appearance or ordinal if named @p1
226+
for (int i = 0; i < parameters.Count; i++)
227+
{
228+
var p = parameters[i];
229+
var name = p.Name.StartsWith("@") ? p.Name : "@" + p.Name;
230+
// This is a naive replacement, but sufficient for standard Cosmo usage
231+
sql = sql.Replace(name, "$" + (i + 1));
232+
}
233+
return sql;
234+
}
235+
195236
public Task<IReadOnlyList<T>> QueryAsync<T>(string sql, IReadOnlyList<SqlParameter>? parameters = null, CancellationToken ct = default) where T : new()
196237
=> QueryAsync(sql, parameters, ct).ContinueWith(t => (IReadOnlyList<T>)t.Result.Select(r => new SqlRowDecoder().Decode<T>(r)).ToList());
197238

198239
public async IAsyncEnumerable<SqlRow> QueryStreamAsync(string sql, IReadOnlyList<SqlParameter>? parameters = null, [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken ct = default)
199240
{
200241
await _lock.WaitAsync(ct).ConfigureAwait(false);
201242
try {
202-
_writer!.Write(PgQueryMessage.Build(sql));
203-
await _writer.FlushAsync(ct).ConfigureAwait(false);
204-
205-
List<SqlColumn>? columns = null;
206-
while (true)
243+
if (parameters is { Count: > 0 })
244+
{
245+
sql = MapPlaceholders(sql, parameters);
246+
_writer!.Write(PgParseMessage.Build("", sql, new int[parameters.Count]));
247+
_writer.Write(PgBindMessage.Build("", "", parameters));
248+
_writer.Write(PgDescribeMessage.Build('P', ""));
249+
_writer.Write(PgExecuteMessage.Build(""));
250+
_writer.Write(PgSyncMessage.Build());
251+
await _writer.FlushAsync(ct).ConfigureAwait(false);
252+
253+
List<SqlColumn>? columns = null;
254+
while (true)
255+
{
256+
var (type, payload) = await ReceiveMessageAsync(ct).ConfigureAwait(false);
257+
if (type == (char)PgBackendType.ReadyForQuery) break;
258+
if (type == (char)PgBackendType.ErrorResponse) throw SqlException.Query(PgDecoder.ParseErrorResponse(payload));
259+
if (type == (char)PgBackendType.RowDescription) columns = PgDecoder.ParseRowDescription(payload);
260+
if (type == (char)PgBackendType.DataRow && columns != null) yield return PgDecoder.ParseDataRow(payload, columns);
261+
}
262+
}
263+
else
207264
{
208-
var (type, payload) = await ReceiveMessageAsync(ct).ConfigureAwait(false);
209-
if (type == (char)PgBackendType.ReadyForQuery) break;
210-
if (type == (char)PgBackendType.ErrorResponse) throw SqlException.Query(PgDecoder.ParseErrorResponse(payload));
211-
if (type == (char)PgBackendType.RowDescription) columns = PgDecoder.ParseRowDescription(payload);
212-
if (type == (char)PgBackendType.DataRow && columns != null) yield return PgDecoder.ParseDataRow(payload, columns);
265+
_writer!.Write(PgQueryMessage.Build(sql));
266+
await _writer.FlushAsync(ct).ConfigureAwait(false);
267+
268+
List<SqlColumn>? columns = null;
269+
while (true)
270+
{
271+
var (type, payload) = await ReceiveMessageAsync(ct).ConfigureAwait(false);
272+
if (type == (char)PgBackendType.ReadyForQuery) break;
273+
if (type == (char)PgBackendType.ErrorResponse) throw SqlException.Query(PgDecoder.ParseErrorResponse(payload));
274+
if (type == (char)PgBackendType.RowDescription) columns = PgDecoder.ParseRowDescription(payload);
275+
if (type == (char)PgBackendType.DataRow && columns != null) yield return PgDecoder.ParseDataRow(payload, columns);
276+
}
213277
}
214278
} finally { _lock.Release(); }
215279
}
@@ -218,18 +282,40 @@ public async Task<int> ExecuteAsync(string sql, IReadOnlyList<SqlParameter>? par
218282
{
219283
await _lock.WaitAsync(ct).ConfigureAwait(false);
220284
try {
221-
_writer!.Write(PgQueryMessage.Build(sql));
222-
await _writer.FlushAsync(ct).ConfigureAwait(false);
223-
224-
int affected = 0;
225-
while (true)
285+
if (parameters is { Count: > 0 })
226286
{
227-
var (type, payload) = await ReceiveMessageAsync(ct).ConfigureAwait(false);
228-
if (type == (char)PgBackendType.ReadyForQuery) break;
229-
if (type == (char)PgBackendType.ErrorResponse) throw SqlException.Query(PgDecoder.ParseErrorResponse(payload));
230-
if (type == (char)PgBackendType.CommandComplete) affected = (int)PgDecoder.ParseCommandComplete(payload).rowCount;
287+
sql = MapPlaceholders(sql, parameters);
288+
_writer!.Write(PgParseMessage.Build("", sql, new int[parameters.Count]));
289+
_writer.Write(PgBindMessage.Build("", "", parameters));
290+
_writer.Write(PgExecuteMessage.Build(""));
291+
_writer.Write(PgSyncMessage.Build());
292+
await _writer.FlushAsync(ct).ConfigureAwait(false);
293+
294+
int affected = 0;
295+
while (true)
296+
{
297+
var (type, payload) = await ReceiveMessageAsync(ct).ConfigureAwait(false);
298+
if (type == (char)PgBackendType.ReadyForQuery) break;
299+
if (type == (char)PgBackendType.ErrorResponse) throw SqlException.Query(PgDecoder.ParseErrorResponse(payload));
300+
if (type == (char)PgBackendType.CommandComplete) affected = (int)PgDecoder.ParseCommandComplete(payload).rowCount;
301+
}
302+
return affected;
303+
}
304+
else
305+
{
306+
_writer!.Write(PgQueryMessage.Build(sql));
307+
await _writer.FlushAsync(ct).ConfigureAwait(false);
308+
309+
int affected = 0;
310+
while (true)
311+
{
312+
var (type, payload) = await ReceiveMessageAsync(ct).ConfigureAwait(false);
313+
if (type == (char)PgBackendType.ReadyForQuery) break;
314+
if (type == (char)PgBackendType.ErrorResponse) throw SqlException.Query(PgDecoder.ParseErrorResponse(payload));
315+
if (type == (char)PgBackendType.CommandComplete) affected = (int)PgDecoder.ParseCommandComplete(payload).rowCount;
316+
}
317+
return affected;
231318
}
232-
return affected;
233319
} finally { _lock.Release(); }
234320
}
235321

0 commit comments

Comments
 (0)