Skip to content

Commit fe23096

Browse files
committed
Refactors task execution logic
1 parent 7bd5a1a commit fe23096

10 files changed

Lines changed: 77 additions & 145 deletions

File tree

internal/workflow/choice_task.go

Lines changed: 4 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,14 @@ func (c *ChoiceTask) GetAlternatives() []TaskId {
3232
return c.AlternativeNextTasks
3333
}
3434

35-
func (c *ChoiceTask) execute(progress *Progress, input *PartialData, r *Request) (*PartialData, *Progress, bool, error) {
36-
37-
outputData := NewPartialData(ReqId(r.Id), c.GetId(), nil) // partial initialization of outputData
38-
39-
// NOTE: we do not call task.CheckInput() as this task has no signature to match against
35+
func (c *ChoiceTask) Evaluate(input *PartialData, r *Request) (TaskId, error) {
4036

4137
// simply evaluate the Conditions and set the matching one
4238
matchedCondition := -1
4339
for i, condition := range c.Conditions {
4440
ok, err := condition.Evaluate(input.Data)
4541
if err != nil {
46-
return nil, progress, false, fmt.Errorf("error while testing condition: %v", err)
42+
return "", fmt.Errorf("error while testing condition: %v", err)
4743
}
4844
if ok {
4945
matchedCondition = i
@@ -52,64 +48,10 @@ func (c *ChoiceTask) execute(progress *Progress, input *PartialData, r *Request)
5248
}
5349

5450
if matchedCondition < 0 {
55-
return nil, progress, false, fmt.Errorf("no condition is met")
56-
}
57-
58-
nextTask := c.AlternativeNextTasks[matchedCondition]
59-
outputData.ForTask = nextTask
60-
outputData.Data = input.Data
61-
62-
// we skip all branch that will not be executed
63-
tasksToSkip := c.GetTasksToSkip(r.W, matchedCondition)
64-
for _, t := range tasksToSkip {
65-
progress.Skip(t.GetId())
66-
}
67-
68-
progress.Complete(c.GetId())
69-
err := progress.AddReadyTask(nextTask)
70-
if err != nil {
71-
return nil, progress, false, err
51+
return "", fmt.Errorf("no condition is met")
7252
}
73-
return outputData, progress, true, nil
74-
}
7553

76-
// VisitBranch returns all tasks of a branch under a choice task; branch number starts from 0
77-
func (c *ChoiceTask) VisitBranch(workflow *Workflow, branch int) []Task {
78-
branchTasks := make([]Task, 0)
79-
if len(c.AlternativeNextTasks) <= branch {
80-
fmt.Printf("fail to get branch %d\n", branch)
81-
return branchTasks
82-
}
83-
taskId := c.AlternativeNextTasks[branch]
84-
return Visit(workflow, taskId, true)
85-
}
86-
87-
// GetTasksToSkip skips all tasks that are in a branch that will not be executed.
88-
// If a skipped branch contains one or more tasks that in use by the current branch, the task
89-
// should NOT be skipped (Tested in TestParsingChoiceDagWithDataTestExpr)
90-
func (c *ChoiceTask) GetTasksToSkip(workflow *Workflow, matchedCondition int) []Task {
91-
toSkip := make([]Task, 0)
92-
93-
toNotSkip := c.VisitBranch(workflow, matchedCondition)
94-
for i := 0; i < len(c.AlternativeNextTasks); i++ {
95-
if i == matchedCondition {
96-
continue
97-
}
98-
branchTasks := c.VisitBranch(workflow, i)
99-
for _, task := range branchTasks {
100-
shouldBeSkipped := true
101-
for _, taskNotSkipped := range toNotSkip {
102-
if task.Equals(taskNotSkipped) {
103-
shouldBeSkipped = false
104-
break
105-
}
106-
}
107-
if shouldBeSkipped {
108-
toSkip = append(toSkip, task)
109-
}
110-
}
111-
}
112-
return toSkip
54+
return c.AlternativeNextTasks[matchedCondition], nil
11355
}
11456

11557
func (c *ChoiceTask) String() string {

internal/workflow/end_task.go

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,6 @@ func NewEndTask() *EndTask {
1717
}
1818
}
1919

20-
func (e *EndTask) execute(progress *Progress, partialData *PartialData) (*PartialData, *Progress, bool, error) {
21-
progress.Complete(e.Id)
22-
return partialData, progress, false, nil // false because we want to stop when reaching the end
23-
}
24-
2520
func (e *EndTask) String() string {
2621
return fmt.Sprintf("[EndTask]")
2722
}

internal/workflow/fail_task.go

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,6 @@ func NewFailureTask(error, cause string) *FailureTask {
2121
return &fail
2222
}
2323

24-
func (f *FailureTask) execute(progress *Progress, r *Request) (*PartialData, *Progress, bool, error) {
25-
26-
output := make(map[string]interface{})
27-
output[f.Error] = f.Cause
28-
outputData := NewPartialData(ReqId(r.Id), f.GetNext(), output)
29-
30-
progress.Complete(f.GetId())
31-
32-
shouldContinueExecution := f.GetType() != Fail && f.GetType() != Succeed
33-
return outputData, progress, shouldContinueExecution, nil
34-
}
35-
3624
func (f *FailureTask) GetNext() TaskId {
3725
return f.NextTask
3826
}
@@ -48,3 +36,9 @@ func (f *FailureTask) SetNext(nextTask Task) error {
4836
func (f *FailureTask) String() string {
4937
return fmt.Sprintf("[Fail: %s]", f.Error)
5038
}
39+
40+
func (f *FailureTask) execute(input *PartialData, r *Request) (map[string]interface{}, error) {
41+
output := make(map[string]interface{})
42+
output[f.Error] = f.Cause
43+
return output, nil
44+
}

internal/workflow/function_task.go

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,26 +35,18 @@ func (s *FunctionTask) SetNext(nextTask Task) error {
3535
return nil
3636
}
3737

38-
func (s *FunctionTask) execute(progress *Progress, input *PartialData, r *Request) (*PartialData, *Progress, bool, error) {
38+
func (s *FunctionTask) execute(input *PartialData, r *Request) (map[string]interface{}, error) {
3939

4040
err := s.CheckInput(input.Data)
4141
if err != nil {
42-
return nil, progress, false, err
42+
return nil, err
4343
}
4444
output, err := s.exec(r, input.Data)
4545
if err != nil {
46-
return nil, progress, false, err
46+
return nil, err
4747
}
4848

49-
nextTask := s.GetNext()
50-
outputData := NewPartialData(ReqId(r.Id), nextTask, output)
51-
52-
progress.Complete(s.Id)
53-
err = progress.AddReadyTask(nextTask)
54-
if err != nil {
55-
return nil, progress, false, err
56-
}
57-
return outputData, progress, true, nil
49+
return output, nil
5850
}
5951

6052
func (s *FunctionTask) exec(compRequest *Request, params ...map[string]interface{}) (map[string]interface{}, error) {

internal/workflow/pass_task.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,7 @@ func (p *PassTask) SetNext(nextTask Task) error {
3131
func (p *PassTask) String() string {
3232
return "[ Pass ]"
3333
}
34+
35+
func (p *PassTask) execute(input *PartialData, r *Request) (map[string]interface{}, error) {
36+
return input.Data, nil
37+
}

internal/workflow/progress.go

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -110,14 +110,6 @@ func (p *Progress) Complete(id TaskId) {
110110
}
111111
}
112112

113-
func (p *Progress) AddReadyTask(id TaskId) error {
114-
if !p.IsReady(id) {
115-
return fmt.Errorf("the task is not ready")
116-
}
117-
p.ReadyToExecute = append(p.ReadyToExecute, id)
118-
return nil
119-
}
120-
121113
// TODO: skip on cascade next nodes
122114
func (p *Progress) Skip(id TaskId) {
123115
p.Status[id] = Skipped

internal/workflow/start_task.go

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,8 @@ func (s *StartTask) SetNext(nextTask Task) error {
2929
return nil
3030
}
3131

32-
func (s *StartTask) execute(progress *Progress, partialData *PartialData) (*PartialData, *Progress, bool, error) {
33-
34-
// TODO: move this logic into workflow "handleCompletion(output)"
35-
progress.Complete(s.GetId())
36-
37-
err := progress.AddReadyTask(s.GetNext())
38-
if err != nil {
39-
return nil, progress, false, err
40-
}
41-
return partialData, progress, true, nil
32+
func (s *StartTask) execute(input *PartialData, r *Request) (map[string]interface{}, error) {
33+
return input.Data, nil
4234
}
4335

4436
func (s *StartTask) String() string {

internal/workflow/succeed_task.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,7 @@ func (s *SuccessTask) GetNext() TaskId {
3131
func (s *SuccessTask) String() string {
3232
return "[Succeed]"
3333
}
34+
35+
func (s *SuccessTask) execute(input *PartialData, r *Request) (map[string]interface{}, error) {
36+
return input.Data, nil
37+
}

internal/workflow/task.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,15 @@ type UnaryTask interface {
2525

2626
// SetNext connects the output of this task to another Task
2727
SetNext(nextTask Task) error
28+
29+
execute(data *PartialData, r *Request) (map[string]interface{}, error)
2830
}
2931

3032
type ConditionalTask interface {
3133
Task
3234
AddAlternative(nextTask Task) error
3335
GetAlternatives() []TaskId
36+
Evaluate(data *PartialData, r *Request) (TaskId, error)
3437
}
3538

3639
type baseTask struct {

internal/workflow/workflow.go

Lines changed: 50 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -169,28 +169,9 @@ func Visit(workflow *Workflow, taskId TaskId, excludeEnd bool) []Task {
169169
return tasks
170170
}
171171

172-
func (workflow *Workflow) doNothingExec(progress *Progress, input *PartialData, task UnaryTask, r *Request) (*PartialData, *Progress, bool, error) {
173-
174-
output := input.Data
175-
outputData := NewPartialData(ReqId(r.Id), task.GetNext(), output)
176-
177-
progress.Complete(task.GetId())
178-
179-
shouldContinueExecution := task.GetType() != Fail && task.GetType() != Succeed
180-
if shouldContinueExecution {
181-
err := progress.AddReadyTask(task.GetNext())
182-
if err != nil {
183-
return nil, progress, false, nil
184-
}
185-
}
186-
187-
return outputData, progress, shouldContinueExecution, nil
188-
}
189-
190172
func (workflow *Workflow) Execute(r *Request, input *PartialData, progress *Progress) (*PartialData, *Progress, bool, error) {
191-
var output *PartialData
192173
var err error
193-
shouldContinue := true
174+
var outputData *PartialData
194175

195176
var nextTasks []TaskId
196177

@@ -213,24 +194,57 @@ func (workflow *Workflow) Execute(r *Request, input *PartialData, progress *Prog
213194
}
214195

215196
switch task := n.(type) {
216-
case *FunctionTask:
217-
output, progress, shouldContinue, err = task.execute(progress, input, r)
218-
case *ChoiceTask:
219-
output, progress, shouldContinue, err = task.execute(progress, input, r)
220-
case *StartTask:
221-
output, progress, shouldContinue, err = task.execute(progress, input)
222-
case *PassTask:
223-
output, progress, shouldContinue, err = workflow.doNothingExec(progress, input, task, r)
224-
case *FailureTask:
225-
output, progress, shouldContinue, err = task.execute(progress, r)
226-
case *SuccessTask:
227-
output, progress, shouldContinue, err = workflow.doNothingExec(progress, input, task, r)
197+
case UnaryTask:
198+
output, err := task.execute(input, r)
199+
if err != nil {
200+
progress.Fail(n.GetId())
201+
return nil, progress, false, err
202+
}
203+
progress.Complete(task.GetId())
204+
205+
nextTask := task.GetNext()
206+
outputData = NewPartialData(ReqId(r.Id), nextTask, output)
207+
if progress.IsReady(nextTask) {
208+
progress.ReadyToExecute = append(progress.ReadyToExecute, nextTask)
209+
}
210+
211+
case ConditionalTask:
212+
nextTaskId, err := task.Evaluate(input, r)
213+
if err != nil {
214+
progress.Fail(n.GetId())
215+
return nil, progress, false, err
216+
}
217+
218+
// we skip all tasks that will not be executed
219+
toSkip := make([]Task, 0)
220+
toNotSkip := Visit(workflow, nextTaskId, false)
221+
for _, a := range task.GetAlternatives() {
222+
if a == nextTaskId {
223+
continue
224+
}
225+
branchTasks := Visit(workflow, a, false)
226+
for _, otherTask := range branchTasks {
227+
if !slices.Contains(toNotSkip, otherTask) {
228+
toSkip = append(toSkip, otherTask)
229+
}
230+
}
231+
}
232+
for _, t := range toSkip {
233+
progress.Skip(t.GetId())
234+
}
235+
progress.Complete(task.GetId())
236+
237+
outputData = NewPartialData(ReqId(r.Id), nextTaskId, input.Data)
238+
if progress.IsReady(nextTaskId) {
239+
progress.ReadyToExecute = append(progress.ReadyToExecute, nextTaskId)
240+
}
228241
case *EndTask:
229-
output, progress, shouldContinue, err = task.execute(progress, input)
242+
progress.Complete(task.GetId())
243+
outputData = input
230244
}
231245
if err != nil {
232246
progress.Fail(n.GetId())
233-
return output, progress, false, err
247+
return nil, progress, false, err
234248
}
235249
} else {
236250
err = SaveProgress(progress)
@@ -244,7 +258,7 @@ func (workflow *Workflow) Execute(r *Request, input *PartialData, progress *Prog
244258
return nil, progress, false, nil
245259
}
246260

247-
return output, progress, shouldContinue, nil
261+
return outputData, progress, true, nil
248262
}
249263

250264
// GetUniqueFunctions returns a list with the function names used in the Workflow. The returned function names are unique and in alphabetical order
@@ -440,7 +454,7 @@ func (workflow *Workflow) Invoke(r *Request) error {
440454
return fmt.Errorf("failed workflow execution: %v", err)
441455
}
442456

443-
if !shouldContinue && pd != nil {
457+
if len(progress.ReadyToExecute) == 0 && pd != nil {
444458
r.ExecReport.Result = pd.Data
445459
}
446460
}

0 commit comments

Comments
 (0)