Skip to content

Commit dd31c4a

Browse files
committed
Support nested struct db aliases in Struct (ref #233)
1 parent 86715e9 commit dd31c4a

4 files changed

Lines changed: 285 additions & 13 deletions

File tree

README.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,34 @@ type ATable struct {
283283
}
284284
```
285285

286+
If a field is itself a struct and has a `db` tag, `Struct` treats that tag as a database alias and expands the nested fields. This is useful for `JOIN` projections built from reusable structs.
287+
288+
```go
289+
type Post struct {
290+
ID string `db:"id"`
291+
Text string `db:"text"`
292+
}
293+
294+
type Comment struct {
295+
Body string `db:"body"`
296+
}
297+
298+
type PostCommentJoined struct {
299+
Post Post `db:"post"`
300+
Comment Comment `db:"comment"`
301+
}
302+
303+
joined := sqlbuilder.NewStruct(new(PostCommentJoined))
304+
sql, _ := joined.SelectFrom("posts post").
305+
Join("comments comment", "post.id = comment.post_id").
306+
Build()
307+
308+
fmt.Println(sql)
309+
310+
// Output:
311+
// SELECT post.id, post.text, comment.body FROM posts post JOIN comments comment ON post.id = comment.post_id
312+
```
313+
286314
For detailed instructions on utilizing `Struct`, refer to the [examples](https://pkg.go.dev/github.com/huandu/go-sqlbuilder#Struct).
287315

288316
Furthermore, `Struct` can be employed as a zero-configuration ORM. Unlike most ORM implementations that necessitate preliminary configurations for database connectivity, `Struct` operates without any configuration, functioning seamlessly with any SQL driver compatible with `database/sql`. `Struct` does not invoke any `database/sql` APIs; it solely generates the appropriate SQL statements with arguments for `DB#Query`/`DB#Exec` or an array of struct field addresses for `Rows#Scan`/`Row#Scan`.

struct.go

Lines changed: 71 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -376,8 +376,16 @@ func (s *Struct) updateWithTags(table string, with, without []string, value inte
376376
assignments := make([]string, 0, len(tagged.ForWrite))
377377

378378
for _, sf := range tagged.ForWrite {
379-
name := sf.Name
380-
val := v.FieldByName(name)
379+
val, ok := fieldByIndex(v, sf.Index, false)
380+
381+
if !ok || !val.IsValid() {
382+
if sf.ShouldOmitEmpty(with...) {
383+
continue
384+
}
385+
386+
assignments = append(assignments, ub.Assign(sf.Quote(s.Flavor), nil))
387+
continue
388+
}
381389

382390
if isEmptyValue(val) {
383391
if sf.ShouldOmitEmpty(with...) {
@@ -387,7 +395,10 @@ func (s *Struct) updateWithTags(table string, with, without []string, value inte
387395
val = dereferencedFieldValue(val)
388396
}
389397

390-
data := val.Interface()
398+
var data interface{}
399+
if val.IsValid() {
400+
data = val.Interface()
401+
}
391402
assignments = append(assignments, ub.Assign(sf.Quote(s.Flavor), data))
392403
}
393404

@@ -471,12 +482,16 @@ func (s *Struct) buildColsAndValuesForTag(ib *InsertBuilder, with, without []str
471482

472483
for _, sf := range tagged.ForWrite {
473484
cols = append(cols, sf.Quote(s.Flavor))
474-
name := sf.Name
475485
shouldOmitEmpty := sf.ShouldOmitEmpty(with...)
476486
nilCnt := 0
477487

478488
for i, v := range vs {
479-
val := v.FieldByName(name)
489+
val, ok := fieldByIndex(v, sf.Index, false)
490+
if !ok || !val.IsValid() {
491+
nilCnt++
492+
values[i] = append(values[i], nil)
493+
continue
494+
}
480495

481496
if isEmptyValue(val) && shouldOmitEmpty {
482497
nilCnt++
@@ -639,8 +654,12 @@ func (s *Struct) addrWithFields(fields []*structField, st interface{}) []interfa
639654
addrs := make([]interface{}, 0, len(fields))
640655

641656
for _, sf := range fields {
642-
name := sf.Name
643-
data := v.FieldByName(name).Addr().Interface()
657+
field, ok := fieldByIndex(v, sf.Index, true)
658+
if !ok || !field.IsValid() {
659+
return nil
660+
}
661+
662+
data := field.Addr().Interface()
644663
addrs = append(addrs, data)
645664
}
646665

@@ -708,8 +727,13 @@ func (s *Struct) valuesWithTags(with, without []string, value interface{}) (valu
708727
values = make([]interface{}, 0, len(tagged.ForWrite))
709728

710729
for _, sf := range tagged.ForWrite {
711-
name := sf.Name
712-
data := v.FieldByName(name).Interface()
730+
field, ok := fieldByIndex(v, sf.Index, false)
731+
if !ok || !field.IsValid() {
732+
values = append(values, nil)
733+
continue
734+
}
735+
736+
data := field.Interface()
713737
values = append(values, data)
714738
}
715739

@@ -770,15 +794,53 @@ func dereferencedFieldValue(v reflect.Value) reflect.Value {
770794
break
771795
}
772796

797+
if v.IsNil() {
798+
return reflect.Value{}
799+
}
800+
773801
v = v.Elem()
774802
}
775803

776804
return v
777805
}
778806

807+
func fieldByIndex(v reflect.Value, index []int, allocate bool) (reflect.Value, bool) {
808+
field := v
809+
810+
for i, idx := range index {
811+
for field.Kind() == reflect.Ptr || field.Kind() == reflect.Interface {
812+
if field.IsNil() {
813+
if !allocate || field.Kind() != reflect.Ptr {
814+
return reflect.Value{}, false
815+
}
816+
817+
field.Set(reflect.New(field.Type().Elem()))
818+
}
819+
820+
field = field.Elem()
821+
}
822+
823+
if field.Kind() != reflect.Struct || idx < 0 || idx >= field.NumField() {
824+
return reflect.Value{}, false
825+
}
826+
827+
field = field.Field(idx)
828+
829+
if i == len(index)-1 {
830+
return field, true
831+
}
832+
}
833+
834+
return field, true
835+
}
836+
779837
// isEmptyValue checks if v is zero.
780838
// Following code is borrowed from `IsZero` method in `reflect.Value` since Go 1.13.
781839
func isEmptyValue(v reflect.Value) bool {
840+
if !v.IsValid() {
841+
return true
842+
}
843+
782844
switch v.Kind() {
783845
case reflect.Bool:
784846
return !v.Bool()

struct_test.go

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,28 @@ type structUserForTest struct {
2121
unexported struct{}
2222
}
2323

24+
type structNestedPostForTest struct {
25+
ID string `db:"id" fieldtag:"post"`
26+
Text string `db:"text"`
27+
}
28+
29+
type structNestedCommentForTest struct {
30+
ID string `db:"id" fieldtag:"comment"`
31+
Body string `db:"body"`
32+
}
33+
34+
type structNestedJoinRowForTest struct {
35+
Post structNestedPostForTest `db:"post"`
36+
Comment *structNestedCommentForTest `db:"comment"`
37+
Rank int `db:"rank" fieldtag:"meta"`
38+
}
39+
40+
type structTimeValueForTest struct {
41+
CreatedAt time.Time `db:"created_at"`
42+
}
43+
2444
var userForTest = NewStruct(new(structUserForTest))
45+
var nestedJoinRowForTest = NewStruct(new(structNestedJoinRowForTest))
2546
var _ = new(structUserForTest).unexported // disable lint warning
2647

2748
func TestStructSelectFrom(t *testing.T) {
@@ -260,6 +281,75 @@ func TestStructColumns(t *testing.T) {
260281
a.Equal(userForTest.ColumnsForTag("invalid"), nil)
261282
}
262283

284+
func TestStructNestedAliasSelectAndColumns(t *testing.T) {
285+
a := assert.New(t)
286+
sql, args := nestedJoinRowForTest.SelectFrom("posts post").
287+
Join("comments comment", "post.id = comment.post_id").
288+
Build()
289+
290+
a.Equal(sql, "SELECT post.id, post.text, comment.id, comment.body, post.rank FROM posts post JOIN comments comment ON post.id = comment.post_id")
291+
a.Equal(args, nil)
292+
a.Equal(nestedJoinRowForTest.Columns(), []string{"post.id", "post.text", "comment.id", "comment.body", "rank"})
293+
a.Equal(nestedJoinRowForTest.WithTag("post", "meta").Columns(), []string{"rank", "post.id"})
294+
}
295+
296+
func TestStructNestedAliasAddrValuesAndWrites(t *testing.T) {
297+
a := assert.New(t)
298+
row := &structNestedJoinRowForTest{
299+
Post: structNestedPostForTest{
300+
ID: "post-1",
301+
Text: "hello",
302+
},
303+
Comment: &structNestedCommentForTest{
304+
ID: "comment-1",
305+
Body: "world",
306+
},
307+
Rank: 7,
308+
}
309+
310+
a.Equal(nestedJoinRowForTest.Values(row), []interface{}{"post-1", "hello", "comment-1", "world", 7})
311+
312+
updateSQL, updateArgs := nestedJoinRowForTest.Update("joined", row).Build()
313+
a.Equal(updateSQL, "UPDATE joined SET post.id = ?, post.text = ?, comment.id = ?, comment.body = ?, rank = ?")
314+
a.Equal(updateArgs, []interface{}{"post-1", "hello", "comment-1", "world", 7})
315+
316+
insertSQL, insertArgs := nestedJoinRowForTest.InsertInto("joined", row).Build()
317+
a.Equal(insertSQL, "INSERT INTO joined (post.id, post.text, comment.id, comment.body, rank) VALUES (?, ?, ?, ?, ?)")
318+
a.Equal(insertArgs, []interface{}{"post-1", "hello", "comment-1", "world", 7})
319+
320+
var scanned structNestedJoinRowForTest
321+
_, _ = fmt.Sscan("post-2 newer comment-2 scanned 9", nestedJoinRowForTest.Addr(&scanned)...)
322+
a.Equal(scanned.Post.ID, "post-2")
323+
a.Equal(scanned.Post.Text, "newer")
324+
a.Assert(scanned.Comment != nil)
325+
a.Equal(scanned.Comment.ID, "comment-2")
326+
a.Equal(scanned.Comment.Body, "scanned")
327+
a.Equal(scanned.Rank, 9)
328+
329+
var reordered structNestedJoinRowForTest
330+
_, _ = fmt.Sscan("comment-3 11 post-3 reverse", nestedJoinRowForTest.AddrWithCols([]string{"comment.id", "rank", "post.id", "comment.body"}, &reordered)...)
331+
a.Assert(reordered.Comment != nil)
332+
a.Equal(reordered.Comment.ID, "comment-3")
333+
a.Equal(reordered.Rank, 11)
334+
a.Equal(reordered.Post.ID, "post-3")
335+
a.Equal(reordered.Comment.Body, "reverse")
336+
337+
withNilComment := &structNestedJoinRowForTest{
338+
Post: structNestedPostForTest{ID: "post-4", Text: "nil"},
339+
Rank: 13,
340+
}
341+
a.Equal(nestedJoinRowForTest.Values(withNilComment), []interface{}{"post-4", "nil", nil, nil, 13})
342+
}
343+
344+
func TestStructTaggedTimeFieldRemainsScalar(t *testing.T) {
345+
a := assert.New(t)
346+
st := NewStruct(new(structTimeValueForTest))
347+
a.Equal(st.Columns(), []string{"created_at"})
348+
349+
sql, _ := st.SelectFrom("events e").Build()
350+
a.Equal(sql, "SELECT e.created_at FROM events e")
351+
}
352+
263353
func TestWithAndWithoutTags(t *testing.T) {
264354
type Tags struct {
265355
A int `db:"a" fieldtag:"tag1"`
@@ -418,6 +508,33 @@ func ExampleStruct_buildJOIN() {
418508
// [Huan%]
419509
}
420510

511+
func ExampleStruct_buildJOINWithNestedStructAlias() {
512+
type Post struct {
513+
ID string `db:"id"`
514+
Text string `db:"text"`
515+
}
516+
517+
type Comment struct {
518+
Body string `db:"body"`
519+
}
520+
521+
type PostCommentJoined struct {
522+
Post Post `db:"post"`
523+
Comment Comment `db:"comment"`
524+
}
525+
526+
joined := NewStruct(new(PostCommentJoined))
527+
sb := joined.SelectFrom("posts post").Join("comments comment", "post.id = comment.post_id")
528+
sql, args := sb.Build()
529+
530+
fmt.Println(sql)
531+
fmt.Println(args)
532+
533+
// Output:
534+
// SELECT post.id, post.text, comment.body FROM posts post JOIN comments comment ON post.id = comment.post_id
535+
// []
536+
}
537+
421538
var orderDB testDB = 1
422539

423540
func ExampleStruct_WithTag() {

0 commit comments

Comments
 (0)