88 "encoding/json"
99 "errors"
1010 "fmt"
11+ "github.com/tursodatabase/libsql-client-go/sqliteparserutils"
1112 "io"
1213 "net/http"
1314 net_url "net/url"
@@ -286,10 +287,131 @@ func sendPipelineRequest(ctx context.Context, msg *hrana.PipelineRequest, url st
286287 return result , false , nil
287288}
288289
290+ func (h * hranaV2Conn ) executeMsg (ctx context.Context , msg * hrana.PipelineRequest ) (* hrana.PipelineResponse , error ) {
291+ result , err := h .sendPipelineRequest (ctx , msg , false )
292+ if err != nil {
293+ return nil , err
294+ }
295+
296+ for _ , r := range result .Results {
297+ if r .Error != nil {
298+ return nil , errors .New (r .Error .Message )
299+ }
300+ if r .Response == nil {
301+ return nil , errors .New ("no response received" )
302+ }
303+ }
304+ return result , nil
305+ }
306+
307+ type chunker struct {
308+ chunk []string
309+ iterator * sqliteparserutils.StatementIterator
310+ limit int
311+ }
312+
313+ func newChunker (iterator * sqliteparserutils.StatementIterator , limit int ) * chunker {
314+ return & chunker {iterator : iterator , chunk : make ([]string , 0 , limit ), limit : limit }
315+ }
316+
317+ func isTransactionStatement (stmt string ) bool {
318+ patterns := [][]byte {[]byte ("begin" ), []byte ("commit" ), []byte ("end" ), []byte ("rollback" )}
319+ for _ , p := range patterns {
320+ if len (stmt ) >= len (p ) && bytes .Equal (bytes .ToLower ([]byte (stmt [0 :len (p )])), p ) {
321+ return true
322+ }
323+ }
324+ return false
325+ }
326+
327+ func (c * chunker ) Next () (chunk []string , isEOF bool ) {
328+ c .chunk = c .chunk [:0 ]
329+ var stmt string
330+ for ! isEOF && len (c .chunk ) < c .limit {
331+ stmt , _ , isEOF = c .iterator .Next ()
332+ // We need to skip transaction statements. Chunks run in a transaction by default.
333+ if stmt != "" && ! isTransactionStatement (stmt ) {
334+ c .chunk = append (c .chunk , stmt )
335+ }
336+ }
337+ return c .chunk , isEOF
338+ }
339+
340+ func (h * hranaV2Conn ) executeSingleStmt (ctx context.Context , stmt string , wantRows bool ) (* hrana.PipelineResponse , error ) {
341+ msg := & hrana.PipelineRequest {}
342+ executeStream , err := hrana .ExecuteStream (stmt , nil , wantRows )
343+ if err != nil {
344+ return nil , fmt .Errorf ("failed to execute SQL: %s\n %w" , stmt , err )
345+ }
346+ msg .Add (* executeStream )
347+ res , err := h .executeMsg (ctx , msg )
348+ if err != nil {
349+ return nil , fmt .Errorf ("failed to execute SQL: %s\n %w" , stmt , err )
350+ }
351+ return res , nil
352+ }
353+
354+ func (h * hranaV2Conn ) executeInChunks (ctx context.Context , query string , wantRows bool ) (* hrana.PipelineResponse , error ) {
355+ const chunkSize = 4096
356+ iterator := sqliteparserutils .CreateStatementIterator (query )
357+ chunker := newChunker (iterator , chunkSize )
358+
359+ chunk , isEOF := chunker .Next ()
360+ if isEOF && len (chunk ) == 1 {
361+ return h .executeSingleStmt (ctx , chunk [0 ], wantRows )
362+ }
363+
364+ _ , err := h .executeSingleStmt (ctx , "BEGIN" , false )
365+ if err != nil {
366+ return nil , err
367+ }
368+
369+ batch := & hrana.Batch {Steps : make ([]hrana.BatchStep , chunkSize )}
370+ msg := & hrana.PipelineRequest {}
371+ msg .Add (hrana.StreamRequest {Type : "batch" , Batch : batch })
372+ for idx := range batch .Steps {
373+ batch .Steps [idx ].Stmt .WantRows = wantRows
374+ }
375+
376+ result := & hrana.PipelineResponse {}
377+ for {
378+ for idx := range chunk {
379+ batch .Steps [idx ].Stmt .Sql = & chunk [idx ]
380+ }
381+ if len (chunk ) < chunkSize {
382+ // We can trim batch.Steps because this is the last chunk anyway.
383+ // isEOF has to be true at this point.
384+ batch .Steps = batch .Steps [:len (chunk )]
385+ }
386+ res , err := h .executeMsg (ctx , msg )
387+ if err != nil {
388+ h .closeStream ()
389+ return nil , fmt .Errorf ("failed to execute SQL:\n %w" , err )
390+ }
391+ result .Baton = res .Baton
392+ result .BaseUrl = res .BaseUrl
393+ result .Results = append (result .Results , res .Results ... )
394+ if isEOF {
395+ break
396+ }
397+ chunk , isEOF = chunker .Next ()
398+ }
399+ _ , err = h .executeSingleStmt (ctx , "COMMIT" , false )
400+ if err != nil {
401+ h .closeStream ()
402+ return nil , err
403+ }
404+ return result , nil
405+ }
406+
289407func (h * hranaV2Conn ) executeStmt (ctx context.Context , query string , args []driver.NamedValue , wantRows bool ) (* hrana.PipelineResponse , error ) {
408+ const querySizeLimitForChunking = 20 * 1024 * 1024
409+ if len (args ) == 0 && len (query ) > querySizeLimitForChunking && ! h .schemaDb {
410+ return h .executeInChunks (ctx , query , wantRows )
411+ }
290412 stmts , params , err := shared .ParseStatementAndArgs (query , args )
291413 if err != nil {
292- return nil , fmt .Errorf ("failed to execute SQL: %s \n %w" , query , err )
414+ return nil , fmt .Errorf ("failed to execute SQL:\n %w" , err )
293415 }
294416 msg := & hrana.PipelineRequest {}
295417 if len (stmts ) == 1 {
@@ -299,29 +421,22 @@ func (h *hranaV2Conn) executeStmt(ctx context.Context, query string, args []driv
299421 }
300422 executeStream , err := hrana .ExecuteStream (stmts [0 ], p , wantRows )
301423 if err != nil {
302- return nil , fmt .Errorf ("failed to execute SQL: %s \n %w" , query , err )
424+ return nil , fmt .Errorf ("failed to execute SQL:\n %w" , err )
303425 }
304426 msg .Add (* executeStream )
305427 } else {
306428 batchStream , err := hrana .BatchStream (stmts , params , wantRows , ! h .schemaDb )
307429 if err != nil {
308- return nil , fmt .Errorf ("failed to execute SQL: %s \n %w" , query , err )
430+ return nil , fmt .Errorf ("failed to execute SQL:\n %w" , err )
309431 }
310432 msg .Add (* batchStream )
311433 }
312434
313- result , err := h .sendPipelineRequest (ctx , msg , false )
435+ resp , err := h .executeMsg (ctx , msg )
314436 if err != nil {
315- return nil , fmt .Errorf ("failed to execute SQL: %s \n %w" , query , err )
437+ return nil , fmt .Errorf ("failed to execute SQL:\n %w" , err )
316438 }
317-
318- if result .Results [0 ].Error != nil {
319- return nil , fmt .Errorf ("failed to execute SQL: %s\n %s" , query , result .Results [0 ].Error .Message )
320- }
321- if result .Results [0 ].Response == nil {
322- return nil , fmt .Errorf ("failed to execute SQL: %s\n %s" , query , "no response received" )
323- }
324- return result , nil
439+ return resp , nil
325440}
326441
327442func (h * hranaV2Conn ) ExecContext (ctx context.Context , query string , args []driver.NamedValue ) (driver.Result , error ) {
@@ -477,7 +592,7 @@ func (h *hranaV2Conn) QueryContext(ctx context.Context, query string, args []dri
477592 }
478593}
479594
480- func (h * hranaV2Conn ) ResetSession ( ctx context. Context ) error {
595+ func (h * hranaV2Conn ) closeStream () {
481596 if h .baton != "" {
482597 go func (baton , url , jwt , host string ) {
483598 msg := hrana.PipelineRequest {Baton : baton }
@@ -486,5 +601,9 @@ func (h *hranaV2Conn) ResetSession(ctx context.Context) error {
486601 }(h .baton , h .url , h .jwt , h .host )
487602 h .baton = ""
488603 }
604+ }
605+
606+ func (h * hranaV2Conn ) ResetSession (ctx context.Context ) error {
607+ h .closeStream ()
489608 return nil
490609}
0 commit comments