Skip to content

Commit 9e3a5ed

Browse files
committed
fix: preserve PostgreSQL :: casts in Named queries
Fixes #956 by treating :: after identifiers and after named parameters as PostgreSQL type casts instead of consuming a colon. Also covers :param::type patterns (e.g. #983).
1 parent 40876a6 commit 9e3a5ed

3 files changed

Lines changed: 83 additions & 27 deletions

File tree

named.go

Lines changed: 46 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -337,9 +337,52 @@ func compileNamedQuery(qs []byte, bindType int) (query string, names []string, e
337337
currentVar := 1
338338
name := make([]byte, 0, 10)
339339

340-
for i, b := range qs {
340+
isIdentByteBeforeCast := func(b byte) bool {
341+
return unicode.IsLetter(rune(b)) || unicode.IsDigit(rune(b)) || b == '.'
342+
}
343+
344+
appendBindvar := func(param []byte) {
345+
switch bindType {
346+
case NAMED:
347+
rebound = append(rebound, ':')
348+
rebound = append(rebound, param...)
349+
case QUESTION, UNKNOWN:
350+
rebound = append(rebound, '?')
351+
case DOLLAR:
352+
rebound = append(rebound, '$')
353+
for _, b := range strconv.Itoa(currentVar) {
354+
rebound = append(rebound, byte(b))
355+
}
356+
currentVar++
357+
case AT:
358+
rebound = append(rebound, '@', 'p')
359+
for _, b := range strconv.Itoa(currentVar) {
360+
rebound = append(rebound, byte(b))
361+
}
362+
currentVar++
363+
}
364+
}
365+
366+
for i := 0; i < len(qs); i++ {
367+
b := qs[i]
341368
// a ':' while we're in a name is an error
342369
if b == ':' {
370+
// PostgreSQL type cast after a named parameter, e.g. :boundary::jsonb
371+
if inName && len(name) > 0 && i < last && qs[i+1] == ':' {
372+
names = append(names, string(name))
373+
appendBindvar(name)
374+
name = name[:0]
375+
inName = false
376+
rebound = append(rebound, ':', ':')
377+
i++
378+
continue
379+
}
380+
// PostgreSQL type cast in an identifier, e.g. path::text
381+
if !inName && i > 0 && isIdentByteBeforeCast(qs[i-1]) && i < last && qs[i+1] == ':' {
382+
rebound = append(rebound, ':', ':')
383+
i++
384+
continue
385+
}
343386
// if this is the second ':' in a '::' escape sequence, append a ':'
344387
if inName && i > 0 && qs[i-1] == ':' {
345388
rebound = append(rebound, ':')
@@ -350,7 +393,7 @@ func compileNamedQuery(qs []byte, bindType int) (query string, names []string, e
350393
return query, names, err
351394
}
352395
inName = true
353-
name = []byte{}
396+
name = name[:0]
354397
} else if inName && i > 0 && b == '=' && len(name) == 0 {
355398
rebound = append(rebound, ':', '=')
356399
inName = false
@@ -369,27 +412,7 @@ func compileNamedQuery(qs []byte, bindType int) (query string, names []string, e
369412
}
370413
// add the string representation to the names list
371414
names = append(names, string(name))
372-
// add a proper bindvar for the bindType
373-
switch bindType {
374-
// oracle only supports named type bind vars even for positional
375-
case NAMED:
376-
rebound = append(rebound, ':')
377-
rebound = append(rebound, name...)
378-
case QUESTION, UNKNOWN:
379-
rebound = append(rebound, '?')
380-
case DOLLAR:
381-
rebound = append(rebound, '$')
382-
for _, b := range strconv.Itoa(currentVar) {
383-
rebound = append(rebound, byte(b))
384-
}
385-
currentVar++
386-
case AT:
387-
rebound = append(rebound, '@', 'p')
388-
for _, b := range strconv.Itoa(currentVar) {
389-
rebound = append(rebound, byte(b))
390-
}
391-
currentVar++
392-
}
415+
appendBindvar(name)
393416
// add this byte to string unless it was not part of the name
394417
if i != last {
395418
rebound = append(rebound, b)

named_issue956_test.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package sqlx
2+
3+
import "testing"
4+
5+
func TestNamedPostgresCastInIdentifier(t *testing.T) {
6+
query := `SELECT DISTINCT t.path::text AS catalog_path WHERE t.company_id = :company_id FROM table AS t`
7+
q, args, err := Named(query, map[string]interface{}{"company_id": 555})
8+
if err != nil {
9+
t.Fatal(err)
10+
}
11+
want := `SELECT DISTINCT t.path::text AS catalog_path WHERE t.company_id = ? FROM table AS t`
12+
if q != want {
13+
t.Fatalf("got %q want %q", q, want)
14+
}
15+
if len(args) != 1 || args[0].(int) != 555 {
16+
t.Fatalf("args = %#v", args)
17+
}
18+
}
19+
20+
func TestNamedPostgresCastAfterNamedParam(t *testing.T) {
21+
query := `SELECT :boundary::jsonb AS boundary`
22+
q, args, err := Named(query, map[string]interface{}{"boundary": `{"type":"Polygon"}`})
23+
if err != nil {
24+
t.Fatal(err)
25+
}
26+
want := `SELECT ?::jsonb AS boundary`
27+
if q != want {
28+
t.Fatalf("got %q want %q", q, want)
29+
}
30+
if len(args) != 1 {
31+
t.Fatalf("args = %#v", args)
32+
}
33+
}

named_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@ func TestCompileQuery(t *testing.T) {
3939
},
4040
{
4141
Q: `SELECT 'a::b::c' || first_name, '::::ABC::_::' FROM person WHERE first_name=:first_name AND last_name=:last_name`,
42-
R: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=? AND last_name=?`,
43-
D: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=$1 AND last_name=$2`,
44-
T: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=@p1 AND last_name=@p2`,
45-
N: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=:first_name AND last_name=:last_name`,
42+
R: `SELECT 'a::b::c' || first_name, '::ABC::_:' FROM person WHERE first_name=? AND last_name=?`,
43+
D: `SELECT 'a::b::c' || first_name, '::ABC::_:' FROM person WHERE first_name=$1 AND last_name=$2`,
44+
T: `SELECT 'a::b::c' || first_name, '::ABC::_:' FROM person WHERE first_name=@p1 AND last_name=@p2`,
45+
N: `SELECT 'a::b::c' || first_name, '::ABC::_:' FROM person WHERE first_name=:first_name AND last_name=:last_name`,
4646
V: []string{"first_name", "last_name"},
4747
},
4848
{

0 commit comments

Comments
 (0)