Skip to content

Commit 6e66c59

Browse files
committed
refactor(dbexec): 🔧 improve error handling in QueryRow and Scan
* change `rowWithAfter` to use a pointer receiver for `Scan` * ensure callback is only called once using `sync.Once` * update tests to reflect changes in `Scan` behavior * enhance query builder to quote identifiers for safety
1 parent 83fcfff commit 6e66c59

18 files changed

Lines changed: 10490 additions & 5544 deletions

‎README.md‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
Production-ready, lightweight ORM and query builder for PostgreSQL on top of PGX v5. Ships with connection pooling, automatic migrations from struct tags, a fluent query builder, generic repository, soft delete, optimistic locking, transactions, read/write splitting, retry/backoff, a circuit breaker, and comprehensive e2e tests.
44

5+
`QueryBuilder` instances are mutable and not goroutine-safe. Create a fresh builder per query chain.
6+
57
### Features
68

79
- Fast, reliable connections via PGX v5 (`pgxpool`)
@@ -18,6 +20,8 @@ Production-ready, lightweight ORM and query builder for PostgreSQL on top of PGX
1820

1921
Note: OpenTelemetry/Prometheus integrations are not included yet.
2022

23+
Safety note: `Table`, `Join`, `OrderBy`, `Set`, and `Raw` accept SQL fragments and should only receive trusted application SQL. Identifier-oriented helpers like `TableQ`, `SelectQ`, and repository reflection paths quote identifiers automatically, but values should still be passed through placeholders.
24+
2125
### Install
2226

2327
```bash

‎bench_test.go‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ func BenchmarkQueryBuilderKeysetPredicate(b *testing.B) {
7676
OrderBy("id ASC").
7777
After("id", 100).
7878
Before("id", 1000)
79-
_ = qb.buildKeysetPredicate()
79+
_, _ = qb.buildKeysetPredicate(0)
8080
}
8181
}
8282

‎coverage.out‎

Lines changed: 10144 additions & 5485 deletions
Large diffs are not rendered by default.

‎dbexec.go‎

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package norm
22

33
import (
44
"context"
5+
"sync"
56

67
"github.com/jackc/pgx/v5"
78
"github.com/jackc/pgx/v5/pgconn"
@@ -51,7 +52,7 @@ func (b breakerExecuter) QueryRow(ctx context.Context, sql string, args ...any)
5152
return errorRow{err: err}
5253
}
5354
row := b.exec.QueryRow(ctx, sql, args...)
54-
return rowWithAfter{Row: row, after: func(err error) { br.after(err) }}
55+
return &rowWithAfter{Row: row, after: func(err error) { br.after(err) }}
5556
}
5657
return b.exec.QueryRow(ctx, sql, args...)
5758
}
@@ -65,13 +66,16 @@ func (e errorRow) Scan(dest ...any) error { return e.err }
6566
type rowWithAfter struct {
6667
pgx.Row
6768
after func(error)
69+
once sync.Once
6870
}
6971

70-
func (r rowWithAfter) Scan(dest ...any) error {
72+
func (r *rowWithAfter) Scan(dest ...any) error {
7173
err := r.Row.Scan(dest...)
72-
if r.after != nil {
73-
r.after(err)
74-
}
74+
r.once.Do(func() {
75+
if r.after != nil {
76+
r.after(err)
77+
}
78+
})
7579
return err
7680
}
7781

@@ -111,7 +115,7 @@ func (r routingExecuter) QueryRow(ctx context.Context, sql string, args ...any)
111115
return errorRow{err: err}
112116
}
113117
row := exec.QueryRow(ctx, sql, args...)
114-
return rowWithAfter{Row: row, after: func(err error) { br.after(err) }}
118+
return &rowWithAfter{Row: row, after: func(err error) { br.after(err) }}
115119
}
116120
return exec.QueryRow(ctx, sql, args...)
117121
}

‎dbexec_row_after_test.go‎

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@ func (f *fakeBaseRow) Scan(_ ...any) error { f.scanned++; return f.retErr }
1212
func TestRowWithAfter_CallsCallback(t *testing.T) {
1313
base := &fakeBaseRow{}
1414
called := 0
15-
r := rowWithAfter{Row: base, after: func(err error) { called++ }}
15+
r := &rowWithAfter{Row: base, after: func(err error) { called++ }}
1616
_ = r.Scan()
17-
if called != 1 || base.scanned != 1 {
17+
_ = r.Scan()
18+
if called != 1 || base.scanned != 2 {
1819
t.Fatalf("after not called or scan not forwarded")
1920
}
2021
}

‎docs/guides/query-builder.md‎

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
## Query Builder
22

3+
`QueryBuilder` is mutable and not goroutine-safe. Build a fresh chain per query instead of reusing the same builder across concurrent goroutines.
4+
5+
`Table`, `Join`, `OrderBy`, `Set`, and `Raw` accept SQL fragments. Do not pass untrusted user input to these methods. Use placeholders for values, and prefer `TableQ` / `SelectQ` / `SelectQI` when you need identifier quoting.
6+
37
Select:
48

59
```go
@@ -37,6 +41,8 @@ Keyset pagination helpers:
3741
_ = db.Query().Table("users").OrderBy("id ASC").After("id", 123).Limit(20).Find(ctx, &rows)
3842
```
3943

44+
`After` / `Before`, `Insert`, `Returning`, and `OnConflict` are intended for column identifiers and are quoted automatically.
45+
4046
Read routing:
4147

4248
```go

‎docs/guides/repository.md‎

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
Generic CRUD with soft delete, pagination, upsert, and bulk insert via CopyFrom.
44

5+
`UpdatePartial`, `Upsert`, `CreateCopyFrom`, and struct-tag-driven CRUD APIs treat column names as identifiers. Keep those names static in application code; values should still come from placeholders or map values, not string-built SQL.
6+
57
```go
68
type User struct { /* fields with db/norm tags */ }
79

‎logger_more_test.go‎

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package norm
2+
3+
import (
4+
"bytes"
5+
"log"
6+
"strings"
7+
"testing"
8+
"time"
9+
)
10+
11+
func TestStdLoggerAndFormatHelpers(t *testing.T) {
12+
oldWriter := log.Writer()
13+
oldFlags := log.Flags()
14+
defer log.SetOutput(oldWriter)
15+
defer log.SetFlags(oldFlags)
16+
17+
var buf bytes.Buffer
18+
log.SetOutput(&buf)
19+
log.SetFlags(0)
20+
21+
StdLogger{}.Debug("query", Field{Key: "id", Value: 7}, Field{Key: "name", Value: "alice"})
22+
out := strings.TrimSpace(buf.String())
23+
if !strings.Contains(out, "[DEBUG] query id=7 name=alice") {
24+
t.Fatalf("unexpected log output: %q", out)
25+
}
26+
27+
buf.Reset()
28+
StdLogger{}.Info("ignored", Field{Key: "stmt", Value: "SELECT 1;"})
29+
if got := strings.TrimSpace(buf.String()); got != "SELECT 1;" {
30+
t.Fatalf("stmt shortcut mismatch: %q", got)
31+
}
32+
33+
if got := formatFields([]Field{{Key: "n", Value: 3}, {Key: "s", Value: "x"}}); got != "n=3 s=x" {
34+
t.Fatalf("formatFields=%q", got)
35+
}
36+
if got := formatFields(nil); got != "" {
37+
t.Fatalf("expected empty fields, got %q", got)
38+
}
39+
}
40+
41+
func TestInlineSQLAndSQLLiteral(t *testing.T) {
42+
ts := time.Unix(1700000000, 0).UTC()
43+
got := inlineSQL("INSERT INTO t VALUES ($1, $2, $3, $4, $5, $6)", []any{"O'Reilly", []byte{0xAB, 0xCD}, true, nil, ts, 7})
44+
checks := []string{
45+
"'O''Reilly'",
46+
"decode('ABCD','hex')",
47+
"TRUE",
48+
"NULL",
49+
ts.Format(time.RFC3339Nano),
50+
"7",
51+
}
52+
for _, check := range checks {
53+
if !strings.Contains(got, check) {
54+
t.Fatalf("inlineSQL missing %q in %q", check, got)
55+
}
56+
}
57+
if !strings.HasSuffix(got, ";") {
58+
t.Fatalf("inlineSQL should end with semicolon: %q", got)
59+
}
60+
61+
if got := inlineSQL("SELECT 1", nil); got != "SELECT 1;" {
62+
t.Fatalf("inlineSQL no-args mismatch: %q", got)
63+
}
64+
if got := sqlLiteral(false); got != "FALSE" {
65+
t.Fatalf("sqlLiteral false mismatch: %q", got)
66+
}
67+
if got := sqlLiteral(struct{ Name string }{Name: "bob"}); !strings.Contains(got, "bob") {
68+
t.Fatalf("sqlLiteral struct mismatch: %q", got)
69+
}
70+
if got := escapeSQLString("a'b"); got != "a''b" {
71+
t.Fatalf("escapeSQLString mismatch: %q", got)
72+
}
73+
}

‎metrics_expvar_test.go‎

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package norm
2+
3+
import (
4+
"expvar"
5+
"strings"
6+
"testing"
7+
"time"
8+
)
9+
10+
func TestExpvarMetricsUpdatesVars(t *testing.T) {
11+
m := ExpvarMetrics{}
12+
m.QueryDuration(12*time.Millisecond, "select 1")
13+
m.ConnectionCount(3, 4)
14+
m.ErrorCount("timeout")
15+
m.CircuitStateChanged("open")
16+
17+
if got := expvar.Get("norm_query_count").String(); got == "0" {
18+
t.Fatalf("query count not updated")
19+
}
20+
if got := expvar.Get("norm_last_query_ms").String(); got != "12" {
21+
t.Fatalf("last query ms mismatch: %s", got)
22+
}
23+
if got := expvar.Get("norm_connections_active").String(); got != "3" {
24+
t.Fatalf("active connections mismatch: %s", got)
25+
}
26+
if got := expvar.Get("norm_connections_idle").String(); got != "4" {
27+
t.Fatalf("idle connections mismatch: %s", got)
28+
}
29+
if got := expvar.Get("norm_circuit_state").String(); got != "\"open\"" {
30+
t.Fatalf("circuit state mismatch: %s", got)
31+
}
32+
if got := expvar.Get("norm_error_count").String(); !strings.Contains(got, "timeout") {
33+
t.Fatalf("error count mismatch: %s", got)
34+
}
35+
}

‎norm_logging_test.go‎

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package norm
2+
3+
import (
4+
"context"
5+
"testing"
6+
"time"
7+
8+
"github.com/jackc/pgx/v5/pgxpool"
9+
)
10+
11+
func TestMakeLogFieldsAndPoolHelpers(t *testing.T) {
12+
pool := &pgxpool.Pool{}
13+
readPool := &pgxpool.Pool{}
14+
kn := &KintsNorm{
15+
pool: pool,
16+
readPool: readPool,
17+
logContextFields: func(context.Context) []Field {
18+
return []Field{{Key: "req_id", Value: "r1"}}
19+
},
20+
}
21+
22+
fields := kn.makeLogFields(context.Background(), "SELECT $1", []any{"x"})
23+
if len(fields) != 4 {
24+
t.Fatalf("unexpected field count: %d", len(fields))
25+
}
26+
if fields[0].Key != "req_id" || fields[1].Key != "sql" || fields[2].Key != "args" || fields[3].Key != "stmt" {
27+
t.Fatalf("unexpected fields: %#v", fields)
28+
}
29+
30+
kn.maskParams = true
31+
fields = kn.makeLogFields(context.Background(), "SELECT $1", []any{"x"})
32+
if len(fields) != 3 || fields[2].Value != "[masked]" {
33+
t.Fatalf("masked fields mismatch: %#v", fields)
34+
}
35+
36+
if kn.Pool() != pool {
37+
t.Fatalf("pool accessor mismatch")
38+
}
39+
if kn.ReadPool() != readPool {
40+
t.Fatalf("read pool accessor mismatch")
41+
}
42+
if qb := kn.QueryRead(); qb == nil || qb.exec == nil {
43+
t.Fatalf("query read builder not initialized")
44+
}
45+
46+
kn2 := &KintsNorm{}
47+
if err := kn2.Close(); err != nil {
48+
t.Fatalf("close nil pools: %v", err)
49+
}
50+
if defaultIfZeroDuration(0, time.Second) != time.Second {
51+
t.Fatalf("helper regression")
52+
}
53+
}

0 commit comments

Comments
 (0)