Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
364 changes: 232 additions & 132 deletions docs/assets/api/schemas.json

Large diffs are not rendered by default.

74 changes: 30 additions & 44 deletions scripts/deliberate_lab/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,6 @@ class ProlificConfig(BaseModel):
bootedRedirectCode: str


class CohortDefinition(BaseModel):
model_config = ConfigDict(
extra="forbid",
populate_by_name=True,
)
id: Annotated[str, Field(min_length=1)]
alias: Annotated[str, Field(min_length=1)]
name: Annotated[str, Field(min_length=1)]
description: str | None = None
generatedCohortId: str | None = None
maxParticipantsPerCohort: Annotated[int | None, Field(ge=1)] = None


class CohortParticipantConfig(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
Expand Down Expand Up @@ -146,6 +133,19 @@ class BalanceAcross(StrEnum):
cohort = "cohort"


class CohortDefinition(BaseModel):
model_config = ConfigDict(
extra="forbid",
populate_by_name=True,
)
id: Annotated[str, Field(min_length=1)]
alias: Annotated[str, Field(min_length=1)]
name: Annotated[str, Field(min_length=1)]
description: str | None = None
generatedCohortId: str | None = None
maxParticipantsPerCohort: Annotated[int | None, Field(ge=1)] = None


class StageTextConfig(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
Expand Down Expand Up @@ -623,15 +623,6 @@ class ReasoningLevel(StrEnum):
high = "high"


class CustomRequestBodyField(BaseModel):
model_config = ConfigDict(
extra="forbid",
populate_by_name=True,
)
name: str
value: str


class GoogleThinkingLevel(StrEnum):
minimal = "minimal"
low = "low"
Expand Down Expand Up @@ -730,6 +721,15 @@ class OllamaProviderOptions(BaseModel):
numPredict: int | None = None


class CustomRequestBodyField(BaseModel):
model_config = ConfigDict(
extra="forbid",
populate_by_name=True,
)
name: str
value: str


class StructuredOutputType(StrEnum):
NONE = "NONE"
JSON_FORMAT = "JSON_FORMAT"
Expand Down Expand Up @@ -1014,9 +1014,7 @@ class ModelGenerationConfig(BaseModel):
topP: float | None = None
frequencyPenalty: float | None = None
presencePenalty: float | None = None
reasoningLevel: Annotated[ReasoningLevel | None, Field(title="ReasoningLevel")] = (
None
)
reasoningLevel: ReasoningLevel | None = None
reasoningBudget: int | None = None
includeReasoning: bool | None = None
disableSafetyFilters: bool | None = None
Expand Down Expand Up @@ -1107,7 +1105,7 @@ class StaticVariableConfig(BaseModel):
scope: Scope
definition: VariableDefinition
value: str
cohortValues: Annotated[dict[str, str] | None, Field(title="CohortValues")] = None
cohortValues: dict[str, str] | None = None


class Object(BaseModel):
Expand Down Expand Up @@ -1240,9 +1238,7 @@ class MultipleChoiceSurveyQuestion(BaseModel):
questionTitle: str
options: list[MultipleChoiceItem]
correctAnswerId: str | None = None
displayType: Annotated[
MultipleChoiceDisplayType | None, Field(title="MultipleChoiceDisplayType")
] = None
displayType: MultipleChoiceDisplayType | None = None
condition: ComparisonCondition | ConditionGroup | None = None


Expand Down Expand Up @@ -1375,9 +1371,7 @@ class TextPromptItem(BaseModel):
)
type: Literal["TEXT"] = "TEXT"
text: str
condition: Annotated[
ComparisonCondition | ConditionGroup | None, Field(title="Condition")
] = None
condition: ComparisonCondition | ConditionGroup | None = None


class ProfileInfoPromptItem(BaseModel):
Expand All @@ -1386,9 +1380,7 @@ class ProfileInfoPromptItem(BaseModel):
populate_by_name=True,
)
type: Literal["PROFILE_INFO"] = "PROFILE_INFO"
condition: Annotated[
ComparisonCondition | ConditionGroup | None, Field(title="Condition")
] = None
condition: ComparisonCondition | ConditionGroup | None = None


class ProfileContextPromptItem(BaseModel):
Expand All @@ -1397,9 +1389,7 @@ class ProfileContextPromptItem(BaseModel):
populate_by_name=True,
)
type: Literal["PROFILE_CONTEXT"] = "PROFILE_CONTEXT"
condition: Annotated[
ComparisonCondition | ConditionGroup | None, Field(title="Condition")
] = None
condition: ComparisonCondition | ConditionGroup | None = None


class StageContextPromptItem(BaseModel):
Expand All @@ -1414,9 +1404,7 @@ class StageContextPromptItem(BaseModel):
includeHelpText: bool
includeStageDisplay: bool
includeParticipantAnswers: bool
condition: Annotated[
ComparisonCondition | ConditionGroup | None, Field(title="Condition")
] = None
condition: ComparisonCondition | ConditionGroup | None = None


class PromptItemGroup(BaseModel):
Expand All @@ -1434,9 +1422,7 @@ class PromptItemGroup(BaseModel):
| PromptItemGroup
]
shuffleConfig: ShuffleConfig | None = None
condition: Annotated[
ComparisonCondition | ConditionGroup | None, Field(title="Condition")
] = None
condition: ComparisonCondition | ConditionGroup | None = None


class StructuredOutputConfig(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion utils/src/experiment.validation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ export const CohortDefinitionSchema = Type.Object(
generatedCohortId: Type.Optional(Type.String()),
maxParticipantsPerCohort: Type.Optional(Type.Integer({minimum: 1})),
},
strict,
{$id: 'CohortDefinition', ...strict},
);

export const ExperimentTemplateSchema = Type.Object(
Expand Down
6 changes: 3 additions & 3 deletions utils/src/participant.validation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ const strict = {additionalProperties: false} as const;
/** ParticipantProfileBase input validation. */
export const ParticipantProfileBaseData = Type.Object(
{
pronouns: Type.Union([Type.Null(), Type.String()]),
avatar: Type.Union([Type.Null(), Type.String()]),
name: Type.Union([Type.Null(), Type.String()]),
pronouns: Type.Optional(Type.Union([Type.Null(), Type.String()])),
avatar: Type.Optional(Type.Union([Type.Null(), Type.String()])),
name: Type.Optional(Type.Union([Type.Null(), Type.String()])),
},
{$id: 'ParticipantProfileBase', ...strict},
);
Expand Down
6 changes: 3 additions & 3 deletions utils/src/prompt.validation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ export const PromptItemTypeData = Type.Union(

/** Base prompt item fields */
const BasePromptItemFields = {
condition: Type.Optional(ConditionSchema),
condition: Type.Optional(Type.Union([Type.Null(), ConditionSchema])),
};

/** Text prompt item */
Expand Down Expand Up @@ -128,10 +128,10 @@ export const PromptItemData = Type.Union([
/** Agent chat settings - defined here to avoid circular imports with agent.validation.ts */
export const AgentChatSettingsData = Type.Object(
{
wordsPerMinute: Type.Union([Type.Number(), Type.Null()]),
wordsPerMinute: Type.Optional(Type.Union([Type.Number(), Type.Null()])),
minMessagesBeforeResponding: Type.Integer(),
canSelfTriggerCalls: Type.Boolean(),
maxResponses: Type.Union([Type.Integer(), Type.Null()]),
maxResponses: Type.Optional(Type.Union([Type.Integer(), Type.Null()])),
initialMessage: Type.String(),
},
{$id: 'AgentChatSettings', ...strict},
Expand Down
26 changes: 16 additions & 10 deletions utils/src/providers.validation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -182,29 +182,35 @@ export const CustomRequestBodyFieldData = Type.Object(
name: Type.String(),
value: Type.String(),
},
strict,
{$id: 'CustomRequestBodyField', ...strict},
);

/** Model generation config */
export const ModelGenerationConfigData = Type.Object(
{
// Universal settings
maxTokens: Type.Optional(Type.Integer()),
stopSequences: Type.Optional(Type.Array(Type.String())),
temperature: Type.Optional(Type.Number()),
topP: Type.Optional(Type.Number()),
frequencyPenalty: Type.Optional(Type.Number()),
presencePenalty: Type.Optional(Type.Number()),
maxTokens: Type.Optional(Type.Union([Type.Null(), Type.Integer()])),
stopSequences: Type.Optional(
Type.Union([Type.Null(), Type.Array(Type.String())]),
),
temperature: Type.Optional(Type.Union([Type.Null(), Type.Number()])),
topP: Type.Optional(Type.Union([Type.Null(), Type.Number()])),
frequencyPenalty: Type.Optional(Type.Union([Type.Null(), Type.Number()])),
presencePenalty: Type.Optional(Type.Union([Type.Null(), Type.Number()])),
// Reasoning settings
reasoningLevel: Type.Optional(ReasoningLevelData),
reasoningLevel: Type.Optional(
Type.Union([Type.Null(), ReasoningLevelData]),
),
reasoningBudget: Type.Optional(Type.Union([Type.Integer(), Type.Null()])),
includeReasoning: Type.Optional(Type.Boolean()),
disableSafetyFilters: Type.Optional(Type.Boolean()),
// Provider-specific options
providerOptions: Type.Optional(ProviderOptionsMapData),
providerOptions: Type.Optional(
Type.Union([Type.Null(), ProviderOptionsMapData]),
),
// Legacy
customRequestBodyFields: Type.Optional(
Type.Array(CustomRequestBodyFieldData),
Type.Union([Type.Null(), Type.Array(CustomRequestBodyFieldData)]),
),
},
{$id: 'ModelGenerationConfig', ...strict},
Expand Down
14 changes: 6 additions & 8 deletions utils/src/shared.validation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,12 @@ export const PermissionsConfigSchema = Type.Object({
/** CohortParticipantConfig input validation. */
export const CohortParticipantConfigSchema = Type.Object(
{
minParticipantsPerCohort: Type.Union([
Type.Null(),
Type.Integer({minimum: 0}),
]),
maxParticipantsPerCohort: Type.Union([
Type.Null(),
Type.Integer({minimum: 1}),
]),
minParticipantsPerCohort: Type.Optional(
Type.Union([Type.Null(), Type.Integer({minimum: 0})]),
),
maxParticipantsPerCohort: Type.Optional(
Type.Union([Type.Null(), Type.Integer({minimum: 1})]),
),
includeAllParticipantsInCohortCount: Type.Boolean(),
botProtection: Type.Boolean(),
},
Expand Down
14 changes: 6 additions & 8 deletions utils/src/stages/chat_stage.validation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,12 @@ export const ChatStageConfigData = Type.Composite(
Type.Object(
{
kind: Type.Literal(StageKind.CHAT),
timeLimitInMinutes: Type.Union([
Type.Integer({minimum: 1}),
Type.Null(),
]),
timeMinimumInMinutes: Type.Union([
Type.Integer({minimum: 1}),
Type.Null(),
]),
timeLimitInMinutes: Type.Optional(
Type.Union([Type.Integer({minimum: 1}), Type.Null()]),
),
timeMinimumInMinutes: Type.Optional(
Type.Union([Type.Integer({minimum: 1}), Type.Null()]),
),
discussions: Type.Array(ChatDiscussionData),
},
strict,
Expand Down
2 changes: 1 addition & 1 deletion utils/src/stages/info_stage.validation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ export const InfoStageConfigData = Type.Composite(
kind: Type.Literal(StageKind.INFO),
infoLines: Type.Array(Type.String()),
// Optional YouTube video ID to display
youtubeVideoId: Type.Optional(Type.String()),
youtubeVideoId: Type.Optional(Type.Union([Type.Null(), Type.String()])),
},
strict,
),
Expand Down
14 changes: 6 additions & 8 deletions utils/src/stages/private_chat_stage.validation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,12 @@ export const PrivateChatStageConfigData = Type.Composite(
kind: Type.Literal(StageKind.PRIVATE_CHAT),
// If defined, ends chat after specified time limit
// (starting from when the first message is sent)
timeLimitInMinutes: Type.Union([
Type.Integer({minimum: 1}),
Type.Null(),
]),
timeMinimumInMinutes: Type.Union([
Type.Integer({minimum: 1}),
Type.Null(),
]),
timeLimitInMinutes: Type.Optional(
Type.Union([Type.Integer({minimum: 1}), Type.Null()]),
),
timeMinimumInMinutes: Type.Optional(
Type.Union([Type.Integer({minimum: 1}), Type.Null()]),
),
// If true, requires participant to go back and forth with mediator(s)
// (rather than being able to send multiple messages at once)
isTurnBasedChat: Type.Optional(Type.Boolean()),
Expand Down
17 changes: 12 additions & 5 deletions utils/src/stages/survey_stage.validation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,14 @@ export const MultipleChoiceSurveyQuestionData = Type.Object(
kind: Type.Literal(SurveyQuestionKind.MULTIPLE_CHOICE),
questionTitle: Type.String(),
options: Type.Array(MultipleChoiceItemData),
correctAnswerId: Type.Union([Type.Null(), Type.String()]),
correctAnswerId: Type.Optional(Type.Union([Type.Null(), Type.String()])),
displayType: Type.Optional(
Type.Enum(MultipleChoiceDisplayType, {$id: 'MultipleChoiceDisplayType'}),
Type.Union([
Type.Null(),
Type.Enum(MultipleChoiceDisplayType, {
$id: 'MultipleChoiceDisplayType',
}),
]),
),
condition: Type.Optional(Type.Union([Type.Null(), ConditionSchema])),
},
Expand All @@ -72,9 +77,11 @@ export const ScaleSurveyQuestionData = Type.Object(
upperText: Type.String(),
lowerValue: Type.Number(),
lowerText: Type.String(),
middleText: Type.Optional(Type.String()),
useSlider: Type.Optional(Type.Boolean()),
stepSize: Type.Optional(Type.Number({minimum: 1})),
middleText: Type.Optional(Type.Union([Type.Null(), Type.String()])),
useSlider: Type.Optional(Type.Union([Type.Null(), Type.Boolean()])),
stepSize: Type.Optional(
Type.Union([Type.Null(), Type.Number({minimum: 1})]),
),
condition: Type.Optional(Type.Union([Type.Null(), ConditionSchema])),
},
{$id: 'ScaleSurveyQuestion', ...strict},
Expand Down
Loading
Loading