-
Notifications
You must be signed in to change notification settings - Fork 10
adversarial augmentation #251
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 2 commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
3c5ed18
adversarial augmentation
Tetragrammaton123 f9e2dcd
async mode, lint, typing, etc
Tetragrammaton123 24af07b
merge dev
Tetragrammaton123 7a8889b
Merge remote-tracking branch 'origin/dev' into my-fix-branch
Tetragrammaton123 04a3ef6
async update (aiometer), в гугл колабе ускорилось в раза 2-3
Tetragrammaton123 c18698e
async
Tetragrammaton123 9dfa81f
mypy fix
Tetragrammaton123 7fb32dd
mypy again and init
Tetragrammaton123 1be0fbf
fix
Tetragrammaton123 300f165
pull dev
voorhs 0735471
move to proper directory
voorhs c23860f
run formatter
voorhs 66fecba
Merge branch 'dev' of https://github.com/deeppavlov/AutoIntent into m…
Tetragrammaton123 45c8fc8
тесты для адверсариал аугментации
Tetragrammaton123 cc647ee
исправление ошибок
Tetragrammaton123 40647a1
опять ошибка
Tetragrammaton123 9d731b3
я не сдамся
Tetragrammaton123 979a81d
add disclaimer
voorhs File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| from .critic_human_like import CriticHumanLike | ||
| from .human_utterance_generator import HumanUtteranceGenerator | ||
|
|
||
# Both names are imported above; listing them here makes them the package's
# explicit public API. ``CriticHumanLike`` is needed by callers because
# ``HumanUtteranceGenerator`` takes a critic instance in its constructor.
__all__ = ["CriticHumanLike", "HumanUtteranceGenerator"]
83 changes: 83 additions & 0 deletions
83
autointent/generation/utterances/_adversarial/critic_human_like.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,83 @@ | ||
| """CriticHumanLike class for distinguishing human vs generated utterances.""" | ||
|
|
||
| from typing import Literal | ||
|
|
||
| from pydantic import BaseModel | ||
|
|
||
| from autointent.generation import Generator | ||
| from autointent.generation.chat_templates import Message, Role | ||
|
|
||
|
|
||
class CriticResponse(BaseModel):
    """Structured verdict expected back from the critic model."""

    # Free-text explanation of how the verdict was reached.
    reasoning: str
    # The verdict itself; only these two string values are accepted.
    label: Literal["human", "generated"]
|
|
||
class CriticHumanLike:
    """Critic that classifies a user utterance as either 'human' or 'generated'.

    The decision is delegated to an LLM: a binary-classification prompt is built
    for the utterance and the wrapped generator is asked for a structured
    ``CriticResponse``.
    """

    def __init__(self, generator: Generator, max_retries: int = 3) -> None:
        """Initialize the critic.

        Args:
            generator: Wrapper for the LLM API to generate classification responses.
            max_retries: Maximum number of attempts to retry classification if the response is invalid.
        """
        self.generator = generator
        self.max_retries = max_retries

    def build_classification_prompt(self, example: str, intent_name: str) -> Message:
        """Build the classification prompt for a single utterance.

        Args:
            example: The user utterance to classify.
            intent_name: The name of the intent associated with the utterance.

        Returns:
            Message: A user-role message asking for a JSON verdict with the keys
            ``reasoning`` and ``label``.
        """
        content = (
            "You are a critic that determines whether a user utterance was written by a human or "
            "generated by a language model.\n\n"
            f"Intent: {intent_name}\n"
            f'Utterance: "{example}"\n\n'
            # NOTE(review): the sample below is hard-coded, so "for this intent" is
            # only accurate for a coffee-shop-like intent -- confirm the wording.
            "Here is an example of a human-written utterance for this intent:\n"
            '"Could you please help me find the nearest coffee shop?"\n\n'
            # The key count must match the list below and the fields of CriticResponse.
            "Respond in **JSON format** with two keys:\n"
            "- `reasoning`: a short chain-of-thought where you explain your logic\n"
            "- `label`: must be either `human` or `generated`\n"
            "Example:\n"
            "{\n"
            ' "reasoning": "The phrasing includes casual contractions and natural hesitation. The utterance '
            'flows similarly to how a human would speak spontaneously.",\n'
            # No trailing comma after the last key: the example itself must be valid JSON.
            ' "label": "human"\n'
            "}"
        )
        return Message(role=Role.USER, content=content)

    def is_human(self, utterance: str, intent_name: str) -> bool:
        """Classify an utterance synchronously.

        Args:
            utterance: The utterance to evaluate.
            intent_name: The associated intent.

        Returns:
            bool: True if classified as human, False otherwise.
        """
        message = self.build_classification_prompt(utterance, intent_name)
        response = self.generator.get_structured_output_sync(
            messages=[message],
            output_model=CriticResponse,
            max_retries=self.max_retries,
        )
        return response.label == "human"

    async def is_human_async(self, utterance: str, intent_name: str) -> bool:
        """Classify an utterance asynchronously.

        Asynchronous counterpart of :meth:`is_human`; it builds the same prompt
        and awaits the generator's async structured-output call.

        Args:
            utterance: The utterance to evaluate.
            intent_name: The associated intent.

        Returns:
            bool: True if classified as human, False otherwise.
        """
        message = self.build_classification_prompt(utterance, intent_name)
        response = await self.generator.get_structured_output_async(
            messages=[message],
            output_model=CriticResponse,
            max_retries=self.max_retries,
        )
        return response.label == "human"
178 changes: 178 additions & 0 deletions
178
autointent/generation/utterances/_adversarial/human_utterance_generator.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,178 @@ | ||
| import asyncio | ||
| import random | ||
| from collections import defaultdict | ||
|
|
||
| from datasets import Dataset as HFDataset | ||
| from datasets import concatenate_datasets | ||
|
|
||
| from autointent import Dataset | ||
| from autointent.custom_types import Split | ||
| from autointent.generation import Generator | ||
| from autointent.generation.chat_templates._evolution_templates_schemas import Message, Role | ||
| from autointent.schemas import Sample | ||
|
|
||
| from .critic_human_like import CriticHumanLike | ||
|
|
||
|
|
||
class HumanUtteranceGenerator:
    """Generator of human-like utterances.

    For every intent that has seed utterances in the chosen split, new utterances
    are generated from scratch with a few-shot prompt. Each candidate is shown to
    a critic; candidates the critic labels as machine-generated are fed back into
    the next prompt as rejected examples, so generation is iterative and tries to
    get past the critic.
    """

    # Upper bound on the number of seed utterances sampled into one prompt.
    _MAX_SEEDS = 3
    # Number of generate-then-criticize rounds that share one set of seeds.
    _MAX_ROUNDS = 3
    # Attempt budget per intent, expressed as a multiple of ``n_final_per_class``.
    _ATTEMPTS_FACTOR = 3

    def __init__(self, generator: Generator, critic: CriticHumanLike, async_mode: bool = False) -> None:
        """Initialize the HumanUtteranceGenerator.

        Args:
            generator: Wrapper for the LLM API used to generate utterances.
            critic: Critic to determine whether the generated utterance sounds human-like.
            async_mode: Whether to use asynchronous mode for generation.
        """
        self.generator = generator
        self.critic = critic
        self.async_mode = async_mode

    def augment(
        self,
        dataset: Dataset,
        split_name: str = Split.TRAIN,
        update_split: bool = True,
        n_final_per_class: int = 5,
    ) -> list[Sample]:
        """Generate human-like utterances for each intent by iteratively refining machine-generated candidates.

        When ``async_mode`` is set, the work is delegated to :meth:`augment_async`
        through ``asyncio.run``. ``asyncio.run`` cannot be used while another event
        loop is already running in the same thread; in that situation await
        :meth:`augment_async` directly.

        Args:
            dataset: The dataset to augment.
            split_name: The name of the split to augment (e.g., 'train').
            update_split: Whether to update the dataset split with the new utterances.
            n_final_per_class: Number of successful utterances to generate per intent.

        Returns:
            list[Sample]: List of newly generated samples.
        """
        if self.async_mode:
            return asyncio.run(
                self.augment_async(
                    dataset=dataset,
                    split_name=split_name,
                    update_split=update_split,
                    n_final_per_class=n_final_per_class,
                )
            )

        original_split = dataset[split_name]
        new_samples: list[dict[str, object]] = []
        for intent_id, intent_name, seed_utterances in self._seeded_intents(dataset, original_split):
            accepted = self._generate_for_intent(intent_name, seed_utterances, n_final_per_class)
            new_samples.extend(
                {Dataset.label_feature: intent_id, Dataset.utterance_feature: utterance} for utterance in accepted
            )
        return self._collect_results(dataset, split_name, original_split, new_samples, update_split)

    async def augment_async(
        self,
        dataset: Dataset,
        split_name: str = Split.TRAIN,
        update_split: bool = True,
        n_final_per_class: int = 5,
    ) -> list[Sample]:
        """Asynchronous counterpart of :meth:`augment`.

        Runs the same per-intent procedure, awaiting the generator and the critic
        instead of calling them synchronously. Requests are awaited one after
        another.

        Args:
            dataset: The dataset to augment.
            split_name: The name of the split to augment (e.g., 'train').
            update_split: Whether to update the dataset split with the new utterances.
            n_final_per_class: Number of successful utterances to generate per intent.

        Returns:
            list[Sample]: List of newly generated samples.
        """
        original_split = dataset[split_name]
        new_samples: list[dict[str, object]] = []
        for intent_id, intent_name, seed_utterances in self._seeded_intents(dataset, original_split):
            accepted = await self._generate_for_intent_async(intent_name, seed_utterances, n_final_per_class)
            new_samples.extend(
                {Dataset.label_feature: intent_id, Dataset.utterance_feature: utterance} for utterance in accepted
            )
        return self._collect_results(dataset, split_name, original_split, new_samples, update_split)

    @staticmethod
    def _seeded_intents(dataset: Dataset, split: HFDataset) -> list[tuple[int, str, list[str]]]:
        """Pair every named intent with the utterances of ``split`` that carry its label.

        Intents whose name is ``None`` and intents without any utterance in the
        split are left out, because there is nothing to build a prompt from.
        """
        id_to_name = {intent.id: intent.name for intent in dataset.intents}

        class_to_samples = defaultdict(list)
        for sample in split:
            class_to_samples[sample["label"]].append(sample["utterance"])

        seeded = []
        for intent_id, intent_name in id_to_name.items():
            if intent_name is None:
                continue
            seed_utterances = class_to_samples.get(intent_id, [])
            if not seed_utterances:
                continue
            seeded.append((intent_id, intent_name, seed_utterances))
        return seeded

    def _generate_for_intent(self, intent_name: str, seed_utterances: list[str], n_final_per_class: int) -> list[str]:
        """Return up to ``n_final_per_class`` utterances for one intent that the critic accepted."""
        accepted: list[str] = []
        for _ in range(n_final_per_class * self._ATTEMPTS_FACTOR):
            if len(accepted) >= n_final_per_class:
                break
            seed_examples = random.sample(seed_utterances, k=min(self._MAX_SEEDS, len(seed_utterances)))
            # Rejections are only remembered within one attempt; a new attempt
            # starts over with fresh seeds and an empty rejection list.
            rejected: list[str] = []
            for _ in range(self._MAX_ROUNDS):
                prompt = self._build_adversarial_prompt(intent_name, seed_examples, rejected)
                generated = self.generator.get_chat_completion([prompt]).strip()
                if self.critic.is_human(generated, intent_name):
                    accepted.append(generated)
                    break
                rejected.append(generated)
        return accepted

    async def _generate_for_intent_async(
        self, intent_name: str, seed_utterances: list[str], n_final_per_class: int
    ) -> list[str]:
        """Asynchronous version of :meth:`_generate_for_intent`."""
        accepted: list[str] = []
        for _ in range(n_final_per_class * self._ATTEMPTS_FACTOR):
            if len(accepted) >= n_final_per_class:
                break
            seed_examples = random.sample(seed_utterances, k=min(self._MAX_SEEDS, len(seed_utterances)))
            rejected: list[str] = []
            for _ in range(self._MAX_ROUNDS):
                prompt = self._build_adversarial_prompt(intent_name, seed_examples, rejected)
                generated = (await self.generator.get_chat_completion_async([prompt])).strip()
                if await self.critic.is_human_async(generated, intent_name):
                    accepted.append(generated)
                    break
                rejected.append(generated)
        return accepted

    @staticmethod
    def _collect_results(
        dataset: Dataset,
        split_name: str,
        original_split: HFDataset,
        new_samples: list[dict[str, object]],
        update_split: bool,
    ) -> list[Sample]:
        """Optionally append the new samples to the split and return them as ``Sample`` objects."""
        # With nothing generated there is nothing to append, so the split is
        # left untouched instead of building a dataset from an empty list.
        if update_split and new_samples:
            generated_split = HFDataset.from_list(new_samples)
            dataset[split_name] = concatenate_datasets([original_split, generated_split])
        return [Sample(**sample) for sample in new_samples]

    def _build_adversarial_prompt(self, intent_name: str, seed_examples: list[str], rejected: list[str]) -> Message:
        """Build a few-shot prompt.

        Build a few-shot prompt to guide the generator to create a new human-like utterance
        from scratch based on the intent name and example utterances.

        Args:
            intent_name: The intent of the utterance.
            seed_examples: List of 1-3 example utterances for the intent.
            rejected: List of previously rejected generations.

        Returns:
            Message: A formatted prompt instructing the generator to produce a new natural-sounding utterance.
        """
        rejected_block = "\n".join(f"- {r}" for r in rejected) if rejected else "None"
        examples_block = "\n".join(f'- "{ex}"' for ex in seed_examples)
        # NOTE(review): the closing lines speak of an "original utterance" although
        # the prompt asks for a brand-new one; wording kept as is -- confirm intent.
        content = (
            f"Your task is to generate a new user utterance that fits the intent '{intent_name}'.\n\n"
            "Here are some examples of utterances for this intent:\n"
            f"{examples_block}\n\n"
            "Preserving its original intent: "
            f"'{intent_name}'.\n\n"
            f"The following previous attempts were classified as machine-generated and rejected:\n{rejected_block}\n\n"
            "Try to write something that would pass as written by a real human. Output a single version only.\n"
            "IMPORTANT: You must modify the original utterance."
        )
        return Message(role=Role.USER, content=content)
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.