Skip to content

Commit 4a3cb08

Browse files
authored
Merge pull request #2 from HexmosTech/rijul/cli-fixes
Added Session based caching, concurrency option, Updated metadata logic, CLI improvements
2 parents 3e52a85 + 75cd2dd commit 4a3cb08

11 files changed

Lines changed: 712 additions & 169 deletions

File tree

.dprompts.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,7 @@ port = "5432"
1010
api-endpoint = "http://localhost:11434/api/chat"
1111
model = "gemma2:2b"
1212
temperature = 0.7
13-
topP = 0.9
13+
topP = 0.9
14+
15+
[worker]
16+
concurrent_workers = 1

client.go

Lines changed: 179 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
package main
22

33
import (
4+
"bufio"
45
"context"
56
"encoding/json"
7+
"errors"
8+
"fmt"
9+
"io"
610
"os"
11+
"strings"
712

813
"github.com/jackc/pgx/v5"
914
"github.com/jackc/pgx/v5/pgxpool"
@@ -13,8 +18,8 @@ import (
1318
)
1419

1520
type BulkJob struct {
16-
Args DPromptsJobArgs `json:"args"`
17-
Metadata map[string]interface{} `json:"metadata,omitempty"`
21+
SubTasks []DPromptsSubTask `json:"sub_tasks"`
22+
BasePrompt string `json:"base_prompt,omitempty"`
1823
}
1924

2025
// RunClient enqueues a job with args and metadata as JSON strings.
@@ -40,7 +45,6 @@ func RunClient(ctx context.Context, driver *riverpgxv5.Driver, argsJSON string,
4045
log.Fatal().Err(err).Msg("Failed to parse args JSON")
4146
}
4247

43-
4448
var insertOpts *river.InsertOpts
4549
if metadataJSON != "" {
4650
var metadata map[string]interface{}
@@ -73,45 +77,197 @@ func enqueueBulkJobsFromFile(ctx context.Context, riverClient *river.Client[pgx.
7377
}
7478
defer file.Close()
7579

76-
var jobs []BulkJob
77-
if err := json.NewDecoder(file).Decode(&jobs); err != nil {
78-
return err
79-
}
80+
decoder := json.NewDecoder(bufio.NewReader(file))
8081

81-
tx, err := dbPool.Begin(ctx)
82+
// Peek first non-whitespace token
83+
tok, err := nextNonSpaceToken(decoder)
8284
if err != nil {
83-
return err
85+
return fmt.Errorf("cannot read file: %w", err)
8486
}
85-
defer tx.Rollback(ctx)
8687

87-
var jobsToInsert []river.InsertManyParams
88+
// NDJSON format (each line = JSON object)
89+
if tok != json.Delim('[') {
90+
return processNDJSON(ctx, decoder, riverClient, dbPool)
91+
}
92+
93+
return processJSONArray(ctx, decoder, riverClient, dbPool)
94+
}
95+
96+
// ------ JSON ARRAY VERSION ------
97+
func processJSONArray(ctx context.Context, decoder *json.Decoder, riverClient *river.Client[pgx.Tx], dbPool *pgxpool.Pool) error {
98+
const batchSize = 500
8899

89-
for i := range jobs {
100+
batch := make([]river.InsertManyParams, 0, batchSize)
101+
total := 0
102+
count := 0
103+
104+
for decoder.More() {
105+
var job BulkJob
106+
if err := decoder.Decode(&job); err != nil {
107+
return fmt.Errorf("decode error at item %d: %w", total, err)
108+
}
109+
110+
params, err := toInsertParams(job)
111+
if err != nil {
112+
log.Error().
113+
Int("job_index", total).
114+
Err(err).
115+
Msg("Failed to convert job to InsertManyParams")
116+
return err
117+
}
118+
119+
batch = append(batch, params)
120+
total++
121+
count++
122+
123+
if total%50 == 0 {
124+
log.Info().Msgf("Loaded %d jobs into batch...", total)
125+
}
90126

91-
var insertOpts *river.InsertOpts
92-
if jobs[i].Metadata != nil {
93-
metadataBytes, err := json.Marshal(jobs[i].Metadata)
94-
if err != nil {
127+
if count == batchSize {
128+
log.Info().Msgf("Inserting batch of %d jobs (total so far: %d)", batchSize, total)
129+
if err := insertBatch(ctx, riverClient, dbPool, batch); err != nil {
130+
log.Error().Err(err).Msg("Failed to insert batch")
95131
return err
96132
}
97-
insertOpts = &river.InsertOpts{Metadata: metadataBytes}
133+
batch = batch[:0]
134+
count = 0
98135
}
99-
jobsToInsert = append(jobsToInsert, river.InsertManyParams{
100-
Args: jobs[i].Args,
101-
InsertOpts: insertOpts,
102-
})
103136
}
104137

105-
results, err := riverClient.InsertManyTx(ctx, tx, jobsToInsert)
138+
if len(batch) > 0 {
139+
log.Info().Msgf("Inserting final batch of %d jobs (total: %d)", len(batch), total)
140+
if err := insertBatch(ctx, riverClient, dbPool, batch); err != nil {
141+
log.Error().Err(err).Msg("Failed to insert final batch")
142+
return err
143+
}
144+
}
145+
146+
log.Info().Msgf("Bulk insert complete. Total jobs inserted: %d", total)
147+
return nil
148+
}
149+
150+
// ------ NDJSON VERSION ------
151+
func processNDJSON(ctx context.Context, decoder *json.Decoder, riverClient *river.Client[pgx.Tx], dbPool *pgxpool.Pool) error {
152+
const batchSize = 500
153+
154+
batch := make([]river.InsertManyParams, 0, batchSize)
155+
total := 0
156+
count := 0
157+
158+
for {
159+
var job BulkJob
160+
if err := decoder.Decode(&job); err != nil {
161+
if errors.Is(err, io.EOF) {
162+
break
163+
}
164+
return err
165+
}
166+
167+
params, err := toInsertParams(job)
168+
if err != nil {
169+
log.Error().
170+
Int("job_index", total).
171+
Err(err).
172+
Msg("Failed to convert job to InsertManyParams")
173+
return err
174+
}
175+
176+
batch = append(batch, params)
177+
total++
178+
count++
179+
180+
if total%50 == 0 {
181+
log.Info().Msgf("Loaded %d jobs into batch...", total)
182+
}
183+
184+
if count == batchSize {
185+
log.Info().Msgf("Inserting batch of %d jobs (total so far: %d)", batchSize, total)
186+
if err := insertBatch(ctx, riverClient, dbPool, batch); err != nil {
187+
log.Error().Err(err).Msg("Failed to insert batch")
188+
return err
189+
}
190+
batch = batch[:0]
191+
count = 0
192+
}
193+
}
194+
195+
if len(batch) > 0 {
196+
log.Info().Msgf("Inserting final batch of %d jobs (total: %d)", len(batch), total)
197+
if err := insertBatch(ctx, riverClient, dbPool, batch); err != nil {
198+
log.Error().Err(err).Msg("Failed to insert final batch")
199+
return err
200+
}
201+
}
202+
203+
log.Info().Msgf("Bulk insert complete. Total jobs inserted: %d", total)
204+
return nil
205+
}
206+
207+
// Helper: Read first non-space token
208+
func nextNonSpaceToken(dec *json.Decoder) (json.Token, error) {
209+
for {
210+
t, err := dec.Token()
211+
if err != nil {
212+
return nil, err
213+
}
214+
if _, ok := t.(json.Delim); ok || t != nil {
215+
return t, nil
216+
}
217+
}
218+
}
219+
220+
func toInsertParams(job BulkJob) (river.InsertManyParams, error) {
221+
if len(job.SubTasks) == 0 {
222+
return river.InsertManyParams{}, fmt.Errorf("job has no sub_tasks")
223+
}
224+
225+
for i, st := range job.SubTasks {
226+
if strings.TrimSpace(st.Prompt) == "" {
227+
return river.InsertManyParams{}, fmt.Errorf("sub_task[%d] has empty prompt", i)
228+
}
229+
}
230+
231+
var opts *river.InsertOpts
232+
if job.SubTasks[0].Metadata != nil {
233+
metadataBytes, _ := json.Marshal(job.SubTasks[0].Metadata)
234+
opts = &river.InsertOpts{
235+
Metadata: metadataBytes,
236+
}
237+
}
238+
239+
return river.InsertManyParams{
240+
Args: DPromptsJobArgs{
241+
BasePrompt: job.BasePrompt,
242+
SubTasks: job.SubTasks,
243+
},
244+
InsertOpts: opts,
245+
}, nil
246+
}
247+
248+
func insertBatch(
249+
ctx context.Context,
250+
riverClient *river.Client[pgx.Tx],
251+
dbPool *pgxpool.Pool,
252+
batch []river.InsertManyParams,
253+
) error {
254+
255+
tx, err := dbPool.Begin(ctx)
106256
if err != nil {
107257
return err
108258
}
259+
defer func() {
260+
_ = tx.Rollback(ctx) // safe no-op if already committed
261+
}()
262+
263+
if _, err := riverClient.InsertManyTx(ctx, tx, batch); err != nil {
264+
return err
265+
}
109266

110267
if err := tx.Commit(ctx); err != nil {
111268
return err
112269
}
113270

114-
log.Info().Msgf("Successfully enqueued %d jobs", len(results))
115271
return nil
116272
}
117273

config.go

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ import (
1111
"github.com/rs/zerolog/log"
1212
)
1313

14-
1514
func LoadDBConfig(path string) (*DBConfig, error) {
1615
var conf struct {
1716
Database DBConfig
@@ -23,7 +22,6 @@ func LoadDBConfig(path string) (*DBConfig, error) {
2322
return &conf.Database, nil
2423
}
2524

26-
2725
func NewDBPool(ctx context.Context, configPath string) (*pgxpool.Pool, error) {
2826

2927
if configPath == "" {
@@ -65,4 +63,22 @@ func NewDBPool(ctx context.Context, configPath string) (*pgxpool.Pool, error) {
6563
}
6664

6765
return dbPool, nil
68-
}
66+
}
67+
68+
func LoadWorkerConfig(path string) (*WorkerConfig, error) {
69+
var conf struct {
70+
Worker WorkerConfig
71+
}
72+
73+
_, err := toml.DecodeFile(path, &conf)
74+
if err != nil {
75+
return nil, err
76+
}
77+
78+
// Default fallback
79+
if conf.Worker.ConcurrentWorkers <= 0 {
80+
conf.Worker.ConcurrentWorkers = 1
81+
}
82+
83+
return &conf.Worker, nil
84+
}

go.mod

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ require (
1212

1313
require (
1414
github.com/davecgh/go-spew v1.1.1 // indirect
15+
github.com/inconshreveable/mousetrap v1.1.0 // indirect
1516
github.com/jackc/pgpassfile v1.0.0 // indirect
1617
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
1718
github.com/jackc/puddle/v2 v2.2.2 // indirect
@@ -21,6 +22,8 @@ require (
2122
github.com/riverqueue/river/riverdriver v0.26.0 // indirect
2223
github.com/riverqueue/river/rivershared v0.26.0 // indirect
2324
github.com/riverqueue/river/rivertype v0.26.0 // indirect
25+
github.com/spf13/cobra v1.10.2 // indirect
26+
github.com/spf13/pflag v1.0.9 // indirect
2427
github.com/stretchr/testify v1.11.1 // indirect
2528
github.com/tidwall/gjson v1.18.0 // indirect
2629
github.com/tidwall/match v1.1.1 // indirect

go.sum

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
22
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
33
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
4+
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
45
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
56
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
67
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
78
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
9+
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
10+
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
811
github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 h1:Dj0L5fhJ9F82ZJyVOmBx6msDp/kfd1t9GRfny/mfJA0=
912
github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds=
1013
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
@@ -44,6 +47,11 @@ github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99
4447
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
4548
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
4649
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
50+
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
51+
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
52+
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
53+
github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY=
54+
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
4755
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
4856
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
4957
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
@@ -61,6 +69,7 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
6169
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
6270
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
6371
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
72+
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
6473
golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
6574
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
6675
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=

0 commit comments

Comments
 (0)