Skip to content

Commit 94ce43f

Browse files
authored
Merge branch 'main' into batch-resumable-upload
2 parents bbd6cbc + 1c580e5 commit 94ce43f

15 files changed

Lines changed: 990 additions & 44 deletions

File tree

.github/dependabot.yml

Lines changed: 0 additions & 28 deletions
This file was deleted.

amber/src/main/scala/org/apache/texera/web/service/WorkflowService.scala

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ import org.apache.texera.amber.error.ErrorUtils.{
5050
}
5151
import org.apache.texera.dao.jooq.generated.tables.pojos.User
5252
import org.apache.texera.service.util.LargeBinaryManager
53-
import org.apache.texera.web.model.websocket.event.TexeraWebSocketEvent
53+
import org.apache.texera.web.model.websocket.event.{TexeraWebSocketEvent, WorkflowErrorEvent}
5454
import org.apache.texera.web.model.websocket.request.WorkflowExecuteRequest
5555
import org.apache.texera.web.resource.dashboard.user.workflow.WorkflowExecutionsResource
5656
import org.apache.texera.web.service.WorkflowService.mkWorkflowStateId
@@ -277,6 +277,14 @@ class WorkflowService(
277277
}
278278
}
279279
}
280+
// Once the execution is published via `executionService.onNext`, the normal
281+
// state-store path surfaces fatal errors to the UI: `errorHandler` writes
282+
// them into `executionStateStore.metadataStore`, whose diff handler (set up
283+
// in the WorkflowExecutionService constructor) emits a WorkflowErrorEvent
284+
// that `connectToExecution` forwards. Before that point, neither the emitter
285+
// nor a subscriber exists yet, so a failure in the constructor itself would
286+
// be recorded but never reach the frontend -- see the fallback in `catch`.
287+
var executionPublished = false
280288
try {
281289
val execution = new WorkflowExecutionService(
282290
controllerConf,
@@ -290,13 +298,36 @@ class WorkflowService(
290298
)
291299
lifeCycleManager.registerCleanUpOnStateChange(executionStateStore)
292300
executionService.onNext(execution)
301+
executionPublished = true
293302
execution.executeWorkflow()
294303
} catch {
295-
case e: Throwable => errorHandler(e)
304+
case e: Throwable =>
305+
errorHandler(e)
306+
// If the execution was never published, no `connectToExecution`
307+
// subscriber is bound to `executionStateStore`, so the state-store path
308+
// above cannot deliver the error. Push it directly in that pre-publish
309+
// window only; once published, the state-store path already surfaces it
310+
// (pushing here too would double-emit).
311+
if (!executionPublished) {
312+
reportFatalErrorsToSubscribers(executionStateStore)
313+
}
296314
}
297315

298316
}
299317

318+
/**
  * Emits one `WorkflowErrorEvent` on `errorSubject` carrying the fatal errors
  * that are recorded in `stateStore`'s metadata store at the time of the call.
  *
  * This is the fallback path for a failure that happens while an execution is
  * still being initialized and has not been published yet (for example, the
  * WorkflowExecutionService constructor throwing). During that window the
  * per-execution state store has neither a diff-handler emitter nor a
  * websocket subscriber, so an error that `errorHandler` already recorded
  * would be logged without ever being shown to the frontend.
  *
  * @param stateStore execution state store whose recorded fatal errors are pushed
  */
private[service] def reportFatalErrorsToSubscribers(stateStore: ExecutionStateStore): Unit = {
  // Snapshot the errors first so the event carries exactly what is recorded now.
  val recordedErrors = stateStore.metadataStore.getState.fatalErrors
  errorSubject.onNext(WorkflowErrorEvent(recordedErrors))
}
330+
300331
def convertToJson(frontendVersion: String): String = {
301332
val environmentVersionMap = Map(
302333
"engine_version" -> Json.toJson(frontendVersion)
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.texera.amber.engine.architecture.messaginglayer
21+
22+
import org.apache.pekko.actor.{ActorSystem, DeadLetter, Props}
23+
import org.apache.pekko.testkit.{ImplicitSender, TestKit, TestProbe}
24+
import org.apache.texera.amber.core.tuple.{AttributeType, Schema, TupleLike}
25+
import org.apache.texera.amber.core.virtualidentity.{ActorVirtualIdentity, ChannelIdentity}
26+
import org.apache.texera.amber.engine.architecture.common.WorkflowActor.{
27+
MessageBecomesDeadLetter,
28+
NetworkMessage
29+
}
30+
import org.apache.texera.amber.engine.common.ambermessage.{DataFrame, WorkflowFIFOMessage}
31+
import org.scalatest.BeforeAndAfterAll
32+
import org.scalatest.flatspec.AnyFlatSpecLike
33+
34+
import scala.concurrent.duration.DurationInt
35+
36+
/**
  * Specification for `DeadLetterMonitorActor`.
  *
  * The cases below pin three expectations: a `DeadLetter` wrapping a
  * `NetworkMessage` results in a `MessageBecomesDeadLetter` reply to the
  * original sender, a `DeadLetter` with any other payload produces no reply,
  * and a message that is not a `DeadLetter` at all produces no reply.
  */
class DeadLetterMonitorActorSpec
    extends TestKit(ActorSystem("DeadLetterMonitorActorSpec"))
    with ImplicitSender
    with AnyFlatSpecLike
    with BeforeAndAfterAll {

  // Shut the shared actor system down once the whole suite has finished.
  override def afterAll(): Unit = TestKit.shutdownActorSystem(system)

  // Data (non-control) channel used by every message built in this suite.
  private val channelId = ChannelIdentity(
    ActorVirtualIdentity("sender"),
    ActorVirtualIdentity("receiver"),
    isControl = false
  )

  // How long a probe waits before concluding that nothing was sent to it.
  private val quietPeriod = 200.millis

  /** Starts a fresh monitor actor so that each test case is isolated. */
  private def spawnMonitor() = system.actorOf(Props(new DeadLetterMonitorActor()))

  /** Builds a minimal `NetworkMessage` wrapping a one-tuple data frame. */
  private def aNetworkMessage(): NetworkMessage = {
    val schema = Schema().add("field1", AttributeType.INTEGER)
    val frame = DataFrame(Array(TupleLike(1) enforceSchema schema))
    NetworkMessage(0, WorkflowFIFOMessage(channelId, 0, frame))
  }

  "DeadLetterMonitorActor" should "forward MessageBecomesDeadLetter to the original sender for a NetworkMessage dead letter" in {
    val monitorRef = spawnMonitor()
    val senderProbe = TestProbe()
    val recipientProbe = TestProbe()
    val undelivered = aNetworkMessage()

    monitorRef ! DeadLetter(undelivered, senderProbe.ref, recipientProbe.ref)

    // The original sender is told which of its messages was not delivered.
    senderProbe.expectMsg(MessageBecomesDeadLetter(undelivered))
  }

  it should "ignore a dead letter whose payload is not a NetworkMessage" in {
    val monitorRef = spawnMonitor()
    val senderProbe = TestProbe()
    val recipientProbe = TestProbe()

    monitorRef ! DeadLetter("not a network message", senderProbe.ref, recipientProbe.ref)

    senderProbe.expectNoMessage(quietPeriod)
  }

  it should "ignore messages that are not dead letters" in {
    val monitorRef = spawnMonitor()

    // Sent with the implicit test actor as sender, so a reply would land there.
    monitorRef ! "some unrelated message"

    expectNoMessage(quietPeriod)
  }
}
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.texera.web.service
21+
22+
import com.google.protobuf.timestamp.Timestamp
23+
import org.apache.texera.amber.core.virtualidentity.WorkflowIdentity
24+
import org.apache.texera.amber.core.workflowruntimestate.FatalErrorType.EXECUTION_FAILURE
25+
import org.apache.texera.amber.core.workflowruntimestate.WorkflowFatalError
26+
import org.apache.texera.web.model.websocket.event.{TexeraWebSocketEvent, WorkflowErrorEvent}
27+
import org.apache.texera.web.storage.ExecutionStateStore
28+
import org.scalatest.flatspec.AnyFlatSpec
29+
import org.scalatest.matchers.should.Matchers
30+
31+
import java.time.Instant
32+
import scala.collection.mutable.ArrayBuffer
33+
34+
/**
35+
* Unit tests for `WorkflowService.reportFatalErrorsToSubscribers`, the seam
36+
* that surfaces init-time fatal errors to the websocket. When execution
37+
* initialization fails, the error is recorded in the metadata store; this push
38+
* is what makes it visible to connected clients instead of only logged.
39+
*/
40+
class WorkflowServiceSpec extends AnyFlatSpec with Matchers {
41+
42+
private def fatalError(message: String): WorkflowFatalError =
43+
WorkflowFatalError(EXECUTION_FAILURE, Timestamp(Instant.now), message, "", "", "")
44+
45+
/** A WorkflowService with a subscriber collecting every event it pushes. */
46+
private def serviceWithCollector(): (WorkflowService, ArrayBuffer[TexeraWebSocketEvent]) = {
47+
val service = new WorkflowService(WorkflowIdentity(1), computingUnitId = 1, cleanUpTimeout = 30)
48+
val events = ArrayBuffer.empty[TexeraWebSocketEvent]
49+
service.connect(evt => events += evt)
50+
(service, events)
51+
}
52+
53+
private def errorEventsIn(events: ArrayBuffer[TexeraWebSocketEvent]): Seq[WorkflowErrorEvent] =
54+
events.collect { case e: WorkflowErrorEvent => e }.toSeq
55+
56+
"WorkflowService" should
57+
"push a WorkflowErrorEvent carrying the store's fatal error to connected subscribers" in {
58+
val (service, events) = serviceWithCollector()
59+
val store = new ExecutionStateStore()
60+
val err = fatalError("boom during init")
61+
store.metadataStore.updateState(_.addFatalErrors(err))
62+
63+
service.reportFatalErrorsToSubscribers(store)
64+
65+
val errorEvents = errorEventsIn(events)
66+
errorEvents should have size 1
67+
// Forwards exactly the store's fatal errors -- no more, no less.
68+
errorEvents.head.fatalErrors should contain theSameElementsAs Seq(err)
69+
}
70+
71+
it should "carry every fatal error currently recorded in the store" in {
72+
val (service, events) = serviceWithCollector()
73+
val store = new ExecutionStateStore()
74+
val first = fatalError("first")
75+
val second = fatalError("second")
76+
store.metadataStore.updateState(_.addFatalErrors(first).addFatalErrors(second))
77+
78+
service.reportFatalErrorsToSubscribers(store)
79+
80+
val errorEvents = errorEventsIn(events)
81+
errorEvents should have size 1
82+
// Exactly the two recorded errors -- no extras.
83+
errorEvents.head.fatalErrors should contain theSameElementsAs Seq(first, second)
84+
}
85+
}

common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceInferenceOpDesc.scala

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import org.apache.texera.amber.operator.huggingFace.codegen.{
3030
ImageTaskCodegen,
3131
MediaGenCodegen,
3232
PythonCodegenBase,
33+
QaRankingCodegen,
3334
TaskCodegen,
3435
TextGenCodegen
3536
}
@@ -108,6 +109,25 @@ class HuggingFaceInferenceOpDesc extends PythonOperatorDescriptor {
108109
@AutofillAttributeName
109110
var inputAudioColumn: EncodableString = ""
110111

112+
@JsonProperty(value = "contextColumn", required = false)
113+
@JsonSchemaTitle("Context Column")
114+
@JsonPropertyDescription("Column containing the context passage for question answering")
115+
@AutofillAttributeName
116+
var contextColumn: EncodableString = ""
117+
118+
@JsonProperty(value = "candidateLabels", required = false)
119+
@JsonSchemaTitle("Candidate Labels")
120+
@JsonPropertyDescription("Comma-separated candidate labels for zero-shot classification")
121+
var candidateLabels: EncodableString = ""
122+
123+
@JsonProperty(value = "sentencesColumn", required = false)
124+
@JsonSchemaTitle("Sentences Column")
125+
@JsonPropertyDescription(
126+
"Column with comma-separated sentences for sentence similarity and text ranking"
127+
)
128+
@AutofillAttributeName
129+
var sentencesColumn: EncodableString = ""
130+
111131
@JsonProperty(
112132
value = "systemPrompt",
113133
required = false,
@@ -153,6 +173,7 @@ class HuggingFaceInferenceOpDesc extends PythonOperatorDescriptor {
153173
ImageTaskCodegen.tasks.foreach(t => byTask += (t -> ImageTaskCodegen))
154174
AudioTaskCodegen.tasks.foreach(t => byTask += (t -> AudioTaskCodegen))
155175
MediaGenCodegen.tasks.foreach(t => byTask += (t -> MediaGenCodegen))
176+
QaRankingCodegen.tasks.foreach(t => byTask += (t -> QaRankingCodegen))
156177
byTask.toMap
157178
}
158179

@@ -200,6 +221,12 @@ class HuggingFaceInferenceOpDesc extends PythonOperatorDescriptor {
200221
if (audioInput == null) "" else audioInput
201222
val safeInputAudioColumn: EncodableString =
202223
if (inputAudioColumn == null) "" else inputAudioColumn
224+
val safeContextColumn: EncodableString =
225+
if (contextColumn == null) "" else contextColumn
226+
val safeCandidateLabels: EncodableString =
227+
if (candidateLabels == null) "" else candidateLabels
228+
val safeSentencesColumn: EncodableString =
229+
if (sentencesColumn == null) "" else sentencesColumn
203230

204231
val ctx = CodegenContext(
205232
hfApiToken = safeToken,
@@ -213,7 +240,10 @@ class HuggingFaceInferenceOpDesc extends PythonOperatorDescriptor {
213240
imageInput = safeImageInput,
214241
inputImageColumn = safeInputImageColumn,
215242
audioInput = safeAudioInput,
216-
inputAudioColumn = safeInputAudioColumn
243+
inputAudioColumn = safeInputAudioColumn,
244+
contextColumn = safeContextColumn,
245+
candidateLabels = safeCandidateLabels,
246+
sentencesColumn = safeSentencesColumn
217247
)
218248

219249
PythonCodegenBase.render(ctx, codegenForTask(safeTask))

common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/PythonCodegenBase.scala

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ object PythonCodegenBase {
5959
val inputImageColumn = ctx.inputImageColumn
6060
val audioInput = ctx.audioInput
6161
val inputAudioColumn = ctx.inputAudioColumn
62+
val contextColumn = ctx.contextColumn
63+
val candidateLabels = ctx.candidateLabels
64+
val sentencesColumn = ctx.sentencesColumn
6265
pyb"""import os
6366
|import re
6467
|import json
@@ -141,6 +144,9 @@ object PythonCodegenBase {
141144
| self.INPUT_IMAGE_COLUMN = $inputImageColumn
142145
| self.AUDIO_INPUT = $audioInput
143146
| self.INPUT_AUDIO_COLUMN = $inputAudioColumn
147+
| self.CONTEXT_COLUMN = $contextColumn
148+
| self.CANDIDATE_LABELS = $candidateLabels
149+
| self.SENTENCES_COLUMN = $sentencesColumn
144150
|
145151
| def _resolve_providers(self, token):
146152
| '''Query the HF Hub API for inference providers serving this model.
@@ -491,6 +497,24 @@ object PythonCodegenBase {
491497
| f"Prompt column '{prompt_col}' not found in input table. "
492498
| f"Available columns: {list(table.columns)}"
493499
| )
500+
| if task == "zero-shot-classification":
501+
| labels = [l.strip() for l in str(self.CANDIDATE_LABELS).split(",") if l.strip()]
502+
| assert labels, (
503+
| "Candidate Labels are required for zero-shot-classification. "
504+
| "Provide a comma-separated list of labels."
505+
| )
506+
| if task == "question-answering":
507+
| ctx_col = self.CONTEXT_COLUMN
508+
| assert ctx_col and ctx_col in table.columns, (
509+
| f"Context column '{ctx_col}' not found in input table. "
510+
| f"Available columns: {list(table.columns)}"
511+
| )
512+
| if task in ("sentence-similarity", "text-ranking"):
513+
| sent_col = self.SENTENCES_COLUMN
514+
| assert sent_col and sent_col in table.columns, (
515+
| f"Sentences column '{sent_col}' not found in input table. "
516+
| f"Available columns: {list(table.columns)}"
517+
| )
494518
|
495519
| # --- handle empty table ---
496520
| if table.empty:
@@ -506,6 +530,16 @@ object PythonCodegenBase {
506530
| "Authorization": f"Bearer {token}",
507531
| "Content-Type": "application/octet-stream",
508532
| }
533+
| # --- pre-compute table dict for table-question-answering ---
534+
| table_dict = None
535+
| if task == "table-question-answering":
536+
| table_dict = {}
537+
| for col in table.columns:
538+
| if col != prompt_col and col != result_col:
539+
| table_dict[col] = [
540+
| str(v) if not pd.isna(v) else "" for v in table[col].tolist()
541+
| ]
542+
|
509543
| # --- resolve image source (upload or column) for image tasks ---
510544
| has_image_upload = bool(self.IMAGE_INPUT) and bool(str(self.IMAGE_INPUT).strip())
511545
| use_image_column = not has_image_upload and bool(self.INPUT_IMAGE_COLUMN) and self.INPUT_IMAGE_COLUMN in table.columns

0 commit comments

Comments
 (0)