Skip to content

Commit 9e83afc

Browse files
committed
Added subtasks
1 parent 0400355 commit 9e83afc

6 files changed

Lines changed: 246 additions & 141 deletions

File tree

client.go

Lines changed: 68 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"fmt"
99
"io"
1010
"os"
11+
"strings"
1112

1213
"github.com/jackc/pgx/v5"
1314
"github.com/jackc/pgx/v5/pgxpool"
@@ -16,9 +17,11 @@ import (
1617
"github.com/rs/zerolog/log"
1718
)
1819

20+
21+
1922
type BulkJob struct {
20-
Args DPromptsJobArgs `json:"args"`
21-
Metadata map[string]interface{} `json:"metadata,omitempty"`
23+
SubTasks []DPromptsSubTask `json:"sub_tasks"`
24+
BasePrompt string `json:"base_prompt,omitempty"`
2225
}
2326

2427
// RunClient enqueues a job with args and metadata as JSON strings.
@@ -109,32 +112,38 @@ func processJSONArray(ctx context.Context, decoder *json.Decoder, riverClient *r
109112

110113
params, err := toInsertParams(job)
111114
if err != nil {
115+
log.Error().
116+
Int("job_index", total).
117+
Err(err).
118+
Msg("Failed to convert job to InsertManyParams")
112119
return err
113120
}
114121

115122
batch = append(batch, params)
116123
total++
117124
count++
118125

119-
if total%100 == 0 {
120-
log.Info().Msgf("Loaded %d jobs...", total)
126+
if total%50 == 0 {
127+
log.Info().Msgf("Loaded %d jobs into batch...", total)
121128
}
122129

123130
if count == batchSize {
131+
log.Info().Msgf("Inserting batch of %d jobs (total so far: %d)", batchSize, total)
124132
if err := insertBatch(ctx, riverClient, dbPool, batch); err != nil {
133+
log.Error().Err(err).Msg("Failed to insert batch")
125134
return err
126135
}
127-
log.Info().Msgf("Inserted batch of %d jobs (total: %d)", batchSize, total)
128136
batch = batch[:0]
129137
count = 0
130138
}
131139
}
132140

133141
if len(batch) > 0 {
142+
log.Info().Msgf("Inserting final batch of %d jobs (total: %d)", len(batch), total)
134143
if err := insertBatch(ctx, riverClient, dbPool, batch); err != nil {
144+
log.Error().Err(err).Msg("Failed to insert final batch")
135145
return err
136146
}
137-
log.Info().Msgf("Inserted final batch of %d jobs (total: %d)", len(batch), total)
138147
}
139148

140149
log.Info().Msgf("Bulk insert complete. Total jobs inserted: %d", total)
@@ -160,38 +169,43 @@ func processNDJSON(ctx context.Context, decoder *json.Decoder, riverClient *rive
160169

161170
params, err := toInsertParams(job)
162171
if err != nil {
172+
log.Error().
173+
Int("job_index", total).
174+
Err(err).
175+
Msg("Failed to convert job to InsertManyParams")
163176
return err
164177
}
165178

166179
batch = append(batch, params)
167180
total++
168181
count++
169182

170-
if total%100 == 0 {
171-
log.Info().Msgf("Loaded %d jobs...", total)
183+
if total%50 == 0 {
184+
log.Info().Msgf("Loaded %d jobs into batch...", total)
172185
}
173186

174187
if count == batchSize {
188+
log.Info().Msgf("Inserting batch of %d jobs (total so far: %d)", batchSize, total)
175189
if err := insertBatch(ctx, riverClient, dbPool, batch); err != nil {
190+
log.Error().Err(err).Msg("Failed to insert batch")
176191
return err
177192
}
178-
log.Info().Msgf("Inserted batch of %d jobs (total: %d)", batchSize, total)
179193
batch = batch[:0]
180194
count = 0
181195
}
182196
}
183197

184198
if len(batch) > 0 {
199+
log.Info().Msgf("Inserting final batch of %d jobs (total: %d)", len(batch), total)
185200
if err := insertBatch(ctx, riverClient, dbPool, batch); err != nil {
201+
log.Error().Err(err).Msg("Failed to insert final batch")
186202
return err
187203
}
188-
log.Info().Msgf("Inserted final batch of %d jobs (total: %d)", len(batch), total)
189204
}
190205

191206
log.Info().Msgf("Bulk insert complete. Total jobs inserted: %d", total)
192207
return nil
193208
}
194-
195209
// Helper: Read first non-space token
196210
func nextNonSpaceToken(dec *json.Decoder) (json.Token, error) {
197211
for {
@@ -206,35 +220,67 @@ func nextNonSpaceToken(dec *json.Decoder) (json.Token, error) {
206220
}
207221

208222
func toInsertParams(job BulkJob) (river.InsertManyParams, error) {
209-
var insertOpts *river.InsertOpts
210-
if job.Metadata != nil {
211-
metaBytes, err := json.Marshal(job.Metadata)
212-
if err != nil {
213-
return river.InsertManyParams{}, err
223+
if len(job.SubTasks) == 0 {
224+
return river.InsertManyParams{}, fmt.Errorf("job has no sub_tasks")
225+
}
226+
227+
for i, st := range job.SubTasks {
228+
if strings.TrimSpace(st.Prompt) == "" {
229+
return river.InsertManyParams{}, fmt.Errorf("sub_task[%d] has empty prompt", i)
230+
}
231+
}
232+
233+
var opts *river.InsertOpts
234+
if job.SubTasks[0].Metadata != nil {
235+
metadataBytes, _ := json.Marshal(job.SubTasks[0].Metadata)
236+
opts = &river.InsertOpts{
237+
Metadata: metadataBytes,
214238
}
215-
insertOpts = &river.InsertOpts{Metadata: metaBytes}
216239
}
240+
217241
return river.InsertManyParams{
218-
Args: job.Args,
219-
InsertOpts: insertOpts,
242+
Args: DPromptsJobArgs{
243+
BasePrompt: job.BasePrompt,
244+
SubTasks: job.SubTasks,
245+
},
246+
InsertOpts: opts,
220247
}, nil
221248
}
222249

223-
func insertBatch(ctx context.Context, riverClient *river.Client[pgx.Tx], dbPool *pgxpool.Pool, batch []river.InsertManyParams) error {
250+
251+
252+
253+
254+
255+
func insertBatch(
256+
ctx context.Context,
257+
riverClient *river.Client[pgx.Tx],
258+
dbPool *pgxpool.Pool,
259+
batch []river.InsertManyParams,
260+
) error {
261+
224262
tx, err := dbPool.Begin(ctx)
225263
if err != nil {
226264
return err
227265
}
228-
defer tx.Rollback(ctx)
266+
defer func() {
267+
_ = tx.Rollback(ctx) // safe no-op if already committed
268+
}()
229269

230270
if _, err := riverClient.InsertManyTx(ctx, tx, batch); err != nil {
231271
return err
232272
}
233273

234-
return tx.Commit(ctx)
274+
if err := tx.Commit(ctx); err != nil {
275+
return err
276+
}
277+
278+
return nil
235279
}
236280

237281

282+
283+
238284
func newRiverClient(driver *riverpgxv5.Driver) (*river.Client[pgx.Tx], error) {
239285
return river.NewClient[pgx.Tx](driver, &river.Config{})
240286
}

main.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,19 @@ func main() {
1717

1818
rootCmd := &cobra.Command{
1919
Use: "dpr",
20-
Short: "dpr CLI tool for job management",
2120
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
2221
if configPath == "" {
2322
home, err := os.UserHomeDir()
2423
if err != nil {
2524
return err
2625
}
27-
configPath = home + "/.dprompt.toml"
26+
configPath = home + "/.dprompts.toml"
2827
}
2928
return nil
3029
},
3130
}
3231

33-
rootCmd.PersistentFlags().StringVar(&configPath, "config", "", "Path to config file (default: $HOME/.dprompt.toml)")
32+
rootCmd.PersistentFlags().StringVar(&configPath, "config", "", "Path to config file (default: $HOME/.dprompts.toml)")
3433

3534
// ---- Client subcommand ----
3635
var argsJSON, metadataJSON, bulkFile string

ollama.go

Lines changed: 20 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@ package main
33
import (
44
"bytes"
55
"encoding/json"
6-
"errors"
6+
"fmt"
77
"io"
88
"net/http"
99
"time"
1010

1111
"github.com/BurntSushi/toml"
12-
"github.com/rs/zerolog/log"
1312
)
1413

1514

@@ -25,27 +24,25 @@ func LoadLLMConfig(configPath string) (*LLMConfig, error) {
2524
return &conf.LLM, nil
2625
}
2726

28-
func CallOllama(prompt string, schema interface{}, configPath string, groupName string, base_prompt string) (string, error) {
29-
startTotal := time.Now()
27+
func CallOllama(
28+
prompt string,
29+
schema interface{},
30+
configPath string,
31+
basePrompt string,
32+
) (string, error) {
3033

31-
// 1️⃣ Load config
32-
start := time.Now()
34+
// Load config
3335
llmConfig, err := LoadLLMConfig(configPath)
3436
if err != nil {
3537
return "", err
3638
}
37-
log.Info().
38-
Dur("duration", time.Since(start)).
39-
Msg("LoadLLMConfig duration")
4039

41-
// 2️⃣ Build request body
42-
start = time.Now()
43-
req := map[string]interface{}{
40+
// Build request
41+
req := map[string]any{
4442
"model": llmConfig.Model,
4543
"stream": false,
46-
"session_id": groupName,
4744
"messages": []map[string]string{
48-
{"role": "system", "content": base_prompt },
45+
{"role": "system", "content": basePrompt},
4946
{"role": "user", "content": prompt},
5047
},
5148
"options": map[string]float64{
@@ -64,52 +61,36 @@ func CallOllama(prompt string, schema interface{}, configPath string, groupName
6461
if err != nil {
6562
return "", err
6663
}
67-
log.Info().
68-
Dur("duration", time.Since(start)).
69-
Msg("Marshal request duration")
7064

71-
// 3️⃣ HTTP POST
72-
start = time.Now()
65+
// HTTP call
7366
client := &http.Client{Timeout: 360 * time.Second}
74-
resp, err := client.Post(llmConfig.APIEndpoint, "application/json", bytes.NewReader(reqBody))
67+
resp, err := client.Post(
68+
llmConfig.APIEndpoint,
69+
"application/json",
70+
bytes.NewReader(reqBody),
71+
)
7572
if err != nil {
7673
return "", err
7774
}
7875
defer resp.Body.Close()
7976

80-
log.Info().
81-
Dur("duration", time.Since(start)).
82-
Msg("HTTP POST duration")
83-
8477
if resp.StatusCode != http.StatusOK {
85-
return "", errors.New("ollama API returned non-200 status: " + resp.Status)
78+
return "", fmt.Errorf("ollama API returned %s", resp.Status)
8679
}
8780

88-
// 4️⃣ Read response
89-
start = time.Now()
81+
// Read & decode response
9082
body, err := io.ReadAll(resp.Body)
9183
if err != nil {
9284
return "", err
9385
}
94-
log.Info().
95-
Dur("duration", time.Since(start)).
96-
Msg("Read response duration")
9786

98-
// 5️⃣ Unmarshal response
99-
start = time.Now()
10087
var ollamaResp OllamaResponse
10188
if err := json.Unmarshal(body, &ollamaResp); err != nil {
10289
return "", err
10390
}
104-
log.Info().
105-
Dur("duration", time.Since(start)).
106-
Msg("Unmarshal response duration")
107-
108-
log.Info().
109-
Dur("total_duration", time.Since(startTotal)).
110-
Msg("Ollama call completed")
11191

11292
return ollamaResp.Message.Content, nil
11393
}
11494

11595

96+

types.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,18 @@ type DBConfig struct {
99
Port string
1010
}
1111

12-
13-
type DPromptsJobArgs struct {
12+
type DPromptsSubTask struct {
1413
Prompt string `json:"prompt"`
1514
Schema interface{} `json:"schema,omitempty"`
16-
BasePrompt string `json:"base_prompt,omitempty"`
15+
Metadata map[string]interface{} `json:"metadata,omitempty"` // <-- new
16+
1717
}
1818

19+
type DPromptsJobArgs struct {
20+
SubTasks []DPromptsSubTask `json:"sub_tasks"`
21+
BasePrompt string `json:"base_prompt,omitempty"`
22+
23+
}
1924

2025
type DPromptsJobResult struct {
2126
Response string `json:"response"`

0 commit comments

Comments
 (0)