Skip to content

Commit 03c69da

Browse files
authored
Merge pull request #230 from stvnkiss/master
Add From support for UpdateBuilder
2 parents 567ea02 + 2301772 commit 03c69da

3 files changed

Lines changed: 136 additions & 1 deletion

File tree

README.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,26 @@ fmt.Println(ub)
203203
// UPDATE users SET level = level + ? WHERE id = ?
204204
```
205205

206+
### Build `UPDATE ... FROM`
207+
208+
`UpdateBuilder.From` emits a `FROM` clause for PostgreSQL, SQLite, and SQLServer flavors (it is ignored by other flavors). When a CTE includes tables created with `CTETable`, those table names are emitted before any explicit `From(...)` tables.
209+
210+
```go
211+
ub := PostgreSQL.NewUpdateBuilder()
212+
ub.Update("users")
213+
ub.Set(ub.Assign("name", "Huan Du"))
214+
ub.From("people")
215+
ub.Where("users.person_id = people.id")
216+
217+
sql, args := ub.Build()
218+
fmt.Println(sql)
219+
fmt.Println(args)
220+
221+
// Output:
222+
// UPDATE users SET name = $1 FROM people WHERE users.person_id = people.id
223+
// [Huan Du]
224+
```
225+
206226
Refer to the [WhereClause](https://pkg.go.dev/github.com/huandu/go-sqlbuilder#WhereClause) examples to learn its usage.
207227

208228
### Build `ORDER BY` clause

update.go

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ const (
1515
updateMarkerAfterWith
1616
updateMarkerAfterUpdate
1717
updateMarkerAfterSet
18+
updateMarkerAfterFrom
1819
updateMarkerAfterWhere
1920
updateMarkerAfterOrderBy
2021
updateMarkerAfterLimit
@@ -71,6 +72,7 @@ type UpdateBuilder struct {
7172
cteBuilder *CTEBuilder
7273

7374
tables []string
75+
fromTables []string
7476
assignments []string
7577
orderByCols []string
7678
order string
@@ -140,6 +142,13 @@ func (ub *UpdateBuilder) SetMore(assignment ...string) *UpdateBuilder {
140142
return ub
141143
}
142144

145+
// From sets table names of FROM in UPDATE.
146+
func (ub *UpdateBuilder) From(table ...string) *UpdateBuilder {
147+
ub.fromTables = table
148+
ub.marker = updateMarkerAfterFrom
149+
return ub
150+
}
151+
143152
// Where adds expressions to the WHERE clause in UPDATE.
144153
//
145154
// Multiple calls to Where will join expressions with AND.
@@ -339,7 +348,7 @@ func (ub *UpdateBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{
339348
buf.WriteStringsPrefixed("INSERTED.", ub.returning, ", ")
340349
}
341350

342-
ub.injection.WriteTo(buf, insertMarkerAfterReturning)
351+
ub.injection.WriteTo(buf, updateMarkerAfterReturning)
343352
}
344353

345354
if flavor != MySQL {
@@ -354,6 +363,20 @@ func (ub *UpdateBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{
354363
}
355364
}
356365

366+
if flavor == PostgreSQL || flavor == SQLite || flavor == SQLServer {
367+
if len(ub.fromTables) > 0 {
368+
369+
if ub.cteBuilder == nil || len(ub.cteBuilder.tableNamesForFrom()) == 0 {
370+
buf.WriteLeadingString("FROM ")
371+
} else {
372+
buf.WriteString(", ")
373+
}
374+
375+
buf.WriteStrings(ub.fromTables, ", ")
376+
ub.injection.WriteTo(buf, updateMarkerAfterFrom)
377+
}
378+
}
379+
357380
if ub.WhereClause != nil {
358381
ub.whereClauseProxy.WhereClause = ub.WhereClause
359382
defer func() {

update_test.go

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,3 +283,95 @@ func TestUpdateBuilderClone(t *testing.T) {
283283
clone.Asc().Limit(5)
284284
a.NotEqual(ub.String(), clone.String())
285285
}
286+
287+
func TestUpdateBuilderFrom(t *testing.T) {
288+
a := assert.New(t)
289+
ub := NewUpdateBuilder()
290+
ub.Update("user")
291+
ub.Set(ub.Assign("name", "Huan Du"))
292+
ub.From("person")
293+
ub.Where(ub.Equal("id", 123))
294+
295+
sql, _ := ub.BuildWithFlavor(MySQL)
296+
a.Equal("UPDATE user SET name = ? WHERE id = ?", sql)
297+
298+
sql, _ = ub.BuildWithFlavor(PostgreSQL)
299+
a.Equal("UPDATE user SET name = $1 FROM person WHERE id = $2", sql)
300+
301+
sql, _ = ub.BuildWithFlavor(SQLite)
302+
a.Equal("UPDATE user SET name = ? FROM person WHERE id = ?", sql)
303+
304+
sql, _ = ub.BuildWithFlavor(SQLServer)
305+
a.Equal("UPDATE user SET name = @p1 FROM person WHERE id = @p2", sql)
306+
307+
sql, _ = ub.BuildWithFlavor(CQL)
308+
a.Equal("UPDATE user SET name = ? WHERE id = ?", sql)
309+
310+
sql, _ = ub.BuildWithFlavor(ClickHouse)
311+
a.Equal("UPDATE user SET name = ? WHERE id = ?", sql)
312+
313+
sql, _ = ub.BuildWithFlavor(Presto)
314+
a.Equal("UPDATE user SET name = ? WHERE id = ?", sql)
315+
316+
// Test with no from
317+
ub2 := NewUpdateBuilder()
318+
ub2.Update("user")
319+
ub2.Set(ub2.Assign("name", "Test"))
320+
ub2.From()
321+
ub2.Where(ub2.Equal("id", 1))
322+
323+
sql, _ = ub2.BuildWithFlavor(PostgreSQL)
324+
a.Equal("UPDATE user SET name = $1 WHERE id = $2", sql)
325+
326+
// Test with multiple from tables
327+
ub3 := NewUpdateBuilder()
328+
ub3.Update("user")
329+
ub3.Set(ub3.Assign("name", "Test"))
330+
ub3.From("person", "company")
331+
ub3.Where(ub3.Equal("id", 1))
332+
333+
sql, _ = ub3.BuildWithFlavor(PostgreSQL)
334+
a.Equal("UPDATE user SET name = $1 FROM person, company WHERE id = $2", sql)
335+
336+
// Test chaining
337+
ub5 := NewUpdateBuilder().Update("user").Set("status = 1").From("person").From("company")
338+
sql, _ = ub5.BuildWithFlavor(PostgreSQL)
339+
a.Equal("UPDATE user SET status = 1 FROM company", sql) // Last From call overwrites
340+
341+
// Test SQL injection after FROM
342+
ub6 := NewUpdateBuilder()
343+
ub6.Update("user")
344+
ub6.Set(ub6.Assign("name", "Test"))
345+
ub6.From("person")
346+
ub6.SQL("/* comment after from */")
347+
ub6.Where(ub6.Equal("id", 1))
348+
349+
sql, _ = ub6.BuildWithFlavor(PostgreSQL)
350+
a.Equal("UPDATE user SET name = $1 FROM person /* comment after from */ WHERE id = $2", sql)
351+
352+
// Test with CTE (WITH clause)
353+
cte := With(CTETable("temp_user").As(Select("id").From("active_users")))
354+
ub7 := cte.Update("user")
355+
ub7.Set(ub7.Assign("status", "active"))
356+
ub7.From("person")
357+
ub7.Where("user.id IN (SELECT id FROM temp_user)")
358+
359+
sql, _ = ub7.BuildWithFlavor(PostgreSQL)
360+
a.Equal("WITH temp_user AS (SELECT id FROM active_users) UPDATE user SET status = $1 FROM temp_user, person WHERE user.id IN (SELECT id FROM temp_user)", sql)
361+
362+
// Test with SQLServer Returning
363+
ub8 := ub.Clone().Returning("id", "name")
364+
sql, _ = ub8.BuildWithFlavor(SQLServer)
365+
a.Equal("UPDATE user SET name = @p1 OUTPUT INSERTED.id, INSERTED.name FROM person WHERE id = @p2", sql)
366+
367+
// Test with SQL injection after WHERE
368+
ub9 := NewUpdateBuilder()
369+
ub9.Update("user")
370+
ub9.Set(ub9.Assign("name", "Test"))
371+
ub9.From("person")
372+
ub9.Where("user.id = person.id")
373+
ub9.SQL("/* comment after where */")
374+
375+
sql, _ = ub9.BuildWithFlavor(PostgreSQL)
376+
a.Equal("UPDATE user SET name = $1 FROM person WHERE user.id = person.id /* comment after where */", sql)
377+
}

0 commit comments

Comments
 (0)