Skip to content

Commit 34f0073

Browse files
romanlutzCopilot
andauthored
MAINT: Modernize Optional/Union typing to PEP 604 syntax (microsoft#2109)
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 01f6dc2 commit 34f0073

10 files changed

Lines changed: 91 additions & 71 deletions

File tree

pyrit/executor/attack/component/conversation_manager.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT license.
33

4+
from __future__ import annotations
5+
46
import logging
57
import uuid
6-
from collections.abc import Sequence
78
from dataclasses import dataclass, field
8-
from typing import TYPE_CHECKING, Any, Optional
9+
from typing import TYPE_CHECKING, Any
910

1011
from pyrit.common.deprecation import print_deprecation_message
1112
from pyrit.common.utils import combine_dict
@@ -22,14 +23,16 @@
2223
MessagePiece,
2324
Score,
2425
)
25-
from pyrit.prompt_normalizer.prompt_converter_configuration import (
26-
PromptConverterConfiguration,
27-
)
2826
from pyrit.prompt_normalizer.prompt_normalizer import PromptNormalizer
2927
from pyrit.prompt_target import CapabilityName, PromptTarget
3028

3129
if TYPE_CHECKING:
30+
from collections.abc import Sequence
31+
3232
from pyrit.executor.attack.core import AttackContext
33+
from pyrit.prompt_normalizer.prompt_converter_configuration import (
34+
PromptConverterConfiguration,
35+
)
3336

3437
logger = logging.getLogger(__name__)
3538

@@ -280,11 +283,11 @@ def set_system_prompt(
280283
async def initialize_context_async(
281284
self,
282285
*,
283-
context: "AttackContext[Any]",
286+
context: AttackContext[Any],
284287
target: PromptTarget,
285288
conversation_id: str,
286289
request_converters: list[PromptConverterConfiguration] | None = None,
287-
prepended_conversation_config: Optional["PrependedConversationConfig"] = None,
290+
prepended_conversation_config: PrependedConversationConfig | None = None,
288291
max_turns: int | None = None,
289292
memory_labels: dict[str, str] | None = None,
290293
) -> ConversationState:
@@ -362,9 +365,9 @@ async def initialize_context_async(
362365
async def _handle_non_chat_target_async(
363366
self,
364367
*,
365-
context: "AttackContext[Any]",
368+
context: AttackContext[Any],
366369
prepended_conversation: list[Message],
367-
config: Optional["PrependedConversationConfig"],
370+
config: PrependedConversationConfig | None,
368371
) -> ConversationState:
369372
"""
370373
Handle prepended conversation for non-chat targets.
@@ -435,7 +438,7 @@ async def add_prepended_conversation_to_memory_async(
435438
prepended_conversation: list[Message],
436439
conversation_id: str,
437440
request_converters: list[PromptConverterConfiguration] | None = None,
438-
prepended_conversation_config: Optional["PrependedConversationConfig"] = None,
441+
prepended_conversation_config: PrependedConversationConfig | None = None,
439442
max_turns: int | None = None,
440443
target_identifier: ComponentIdentifier | None = None,
441444
) -> int:
@@ -518,11 +521,11 @@ async def add_prepended_conversation_to_memory_async(
518521
async def _process_prepended_for_chat_target_async(
519522
self,
520523
*,
521-
context: "AttackContext[Any]",
524+
context: AttackContext[Any],
522525
prepended_conversation: list[Message],
523526
conversation_id: str,
524527
request_converters: list[PromptConverterConfiguration] | None,
525-
prepended_conversation_config: Optional["PrependedConversationConfig"],
528+
prepended_conversation_config: PrependedConversationConfig | None,
526529
max_turns: int | None,
527530
target_identifier: ComponentIdentifier | None = None,
528531
) -> ConversationState:

pyrit/executor/attack/multi_turn/multi_prompt_sending.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT license.
33

4+
from __future__ import annotations
5+
46
import logging
57
from dataclasses import dataclass
6-
from typing import TYPE_CHECKING, Any, Optional
8+
from typing import TYPE_CHECKING, Any
79

810
from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults
911
from pyrit.common.utils import get_kwarg_param
@@ -51,13 +53,13 @@ class MultiPromptSendingAttackParameters(AttackParameters):
5153

5254
@classmethod
5355
async def from_seed_group_async(
54-
cls: type["MultiPromptSendingAttackParameters"],
56+
cls: type[MultiPromptSendingAttackParameters],
5557
seed_group: SeedAttackGroup,
5658
*,
57-
adversarial_chat: Optional["PromptTarget"] = None,
58-
objective_scorer: Optional["TrueFalseScorer"] = None,
59+
adversarial_chat: PromptTarget | None = None,
60+
objective_scorer: TrueFalseScorer | None = None,
5961
**overrides: Any,
60-
) -> "MultiPromptSendingAttackParameters":
62+
) -> MultiPromptSendingAttackParameters:
6163
"""
6264
Create parameters from a SeedGroup, extracting user messages.
6365

pyrit/executor/attack/multi_turn/tree_of_attacks.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT license.
33

4+
from __future__ import annotations
5+
46
import asyncio
57
import enum
68
import json
79
import logging
810
import uuid
911
from dataclasses import dataclass, field
10-
from pathlib import Path
11-
from typing import TYPE_CHECKING, Any, Optional, cast, overload
12+
from typing import TYPE_CHECKING, Any, cast, overload
1213

1314
from treelib.tree import Tree
1415

@@ -70,6 +71,8 @@
7071
from pyrit.score.true_false.true_false_inverter_scorer import TrueFalseInverterScorer
7172

7273
if TYPE_CHECKING:
74+
from pathlib import Path
75+
7376
from pyrit.models.literals import PromptDataType
7477

7578
logger = logging.getLogger(__name__)
@@ -171,7 +174,7 @@ class TAPAttackContext(MultiTurnAttackContext[Any]):
171174

172175
# Nodes in the attack tree
173176
# Each node represents a branch in the attack tree with its own state
174-
nodes: list["_TreeOfAttacksNode"] = field(default_factory=list)
177+
nodes: list[_TreeOfAttacksNode] = field(default_factory=list)
175178

176179
# Best conversation ID and score found during the attack
177180
best_conversation_id: str | None = None
@@ -376,7 +379,7 @@ async def initialize_with_prepended_conversation_async(
376379
self,
377380
*,
378381
prepended_conversation: list[Message],
379-
prepended_conversation_config: Optional["PrependedConversationConfig"] = None,
382+
prepended_conversation_config: PrependedConversationConfig | None = None,
380383
) -> None:
381384
"""
382385
Initialize the node with a prepended conversation history.
@@ -769,7 +772,7 @@ def _handle_unexpected_error(self, error: Exception) -> None:
769772
logger.error(f"Node {self.node_id}: Unexpected error during execution: {error}")
770773
self.error_message = f"Execution error: {str(error)}"
771774

772-
def duplicate(self) -> "_TreeOfAttacksNode":
775+
def duplicate(self) -> _TreeOfAttacksNode:
773776
"""
774777
Create a duplicate of this node for branching.
775778

pyrit/memory/storage/serializers.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def data_serializer_factory(
6060
Args:
6161
data_type (str): The type of the data (e.g., 'text', 'image_path', 'audio_path').
6262
value (str): The data value to be serialized.
63-
extension (Optional[str]): The file extension, if applicable.
63+
extension (str | None): The file extension, if applicable.
6464
category (AllowedCategories): The category or context for the data (e.g., 'seed-prompt-entries').
6565
6666
Returns:
@@ -320,10 +320,10 @@ async def get_data_filename_async(self, file_name: str | None = None) -> Path |
320320
Generate or retrieve a unique filename for the data file.
321321
322322
Args:
323-
file_name (Optional[str]): Optional file name override.
323+
file_name (str | None): Optional file name override.
324324
325325
Returns:
326-
Union[Path, str]: Full storage path for the generated data file.
326+
Path | str: Full storage path for the generated data file.
327327
328328
Raises:
329329
TypeError: If the serializer is not configured for on-disk data.
@@ -471,7 +471,7 @@ async def get_data_filename( # pyrit-async-suffix-exempt
471471
file_name: Optional file name override.
472472
473473
Returns:
474-
Union[Path, str]: Full storage path for the generated data file.
474+
Path | str: Full storage path for the generated data file.
475475
"""
476476
print_deprecation_message(
477477
old_item="pyrit.memory.storage.serializers.DataTypeSerializer.get_data_filename",
@@ -586,7 +586,7 @@ def __init__(self, *, category: str, prompt_text: str, extension: str | None = N
586586
Args:
587587
category (str): Data category folder name.
588588
prompt_text (str): URL or path value.
589-
extension (Optional[str]): Optional extension for persisted content.
589+
extension (str | None): Optional extension for persisted content.
590590
591591
"""
592592
self.data_type = "url"
@@ -615,8 +615,8 @@ def __init__(self, *, category: str, prompt_text: str | None = None, extension:
615615
616616
Args:
617617
category (str): Data category folder name.
618-
prompt_text (Optional[str]): Optional existing image path.
619-
extension (Optional[str]): Optional image extension.
618+
prompt_text (str | None): Optional existing image path.
619+
extension (str | None): Optional image extension.
620620
621621
"""
622622
self.data_type = "image_path"
@@ -652,8 +652,8 @@ def __init__(
652652
653653
Args:
654654
category (str): Data category folder name.
655-
prompt_text (Optional[str]): Optional existing audio path.
656-
extension (Optional[str]): Optional audio extension.
655+
prompt_text (str | None): Optional existing audio path.
656+
extension (str | None): Optional audio extension.
657657
658658
"""
659659
self.data_type = "audio_path"
@@ -689,8 +689,8 @@ def __init__(
689689
690690
Args:
691691
category (str): The category or context for the data.
692-
prompt_text (Optional[str]): The video path or identifier.
693-
extension (Optional[str]): The file extension, defaults to 'mp4'.
692+
prompt_text (str | None): The video path or identifier.
693+
extension (str | None): The file extension, defaults to 'mp4'.
694694
695695
"""
696696
self.data_type = "video_path"
@@ -730,8 +730,8 @@ def __init__(
730730
731731
Args:
732732
category (str): The category or context for the data.
733-
prompt_text (Optional[str]): The binary file path or identifier.
734-
extension (Optional[str]): The file extension, defaults to 'bin'.
733+
prompt_text (str | None): The binary file path or identifier.
734+
extension (str | None): The file extension, defaults to 'bin'.
735735
736736
"""
737737
self.data_type = "binary_path"

pyrit/memory/storage/storage.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ async def read_file(self, path: Path | str) -> bytes: # pyrit-async-suffix-exem
7171
Read a file from storage (deprecated alias of ``read_file_async``).
7272
7373
Args:
74-
path (Union[Path, str]): The path to the file.
74+
path (Path | str): The path to the file.
7575
7676
Returns:
7777
bytes: The content of the file.
@@ -88,7 +88,7 @@ async def write_file(self, path: Path | str, data: bytes) -> None: # pyrit-asyn
8888
Write data to storage (deprecated alias of ``write_file_async``).
8989
9090
Args:
91-
path (Union[Path, str]): The path to the file.
91+
path (Path | str): The path to the file.
9292
data (bytes): The content to write to the file.
9393
"""
9494
print_deprecation_message(
@@ -103,7 +103,7 @@ async def path_exists(self, path: Path | str) -> bool: # pyrit-async-suffix-exe
103103
Check whether a path exists (deprecated alias of ``path_exists_async``).
104104
105105
Args:
106-
path (Union[Path, str]): The path to check.
106+
path (Path | str): The path to check.
107107
108108
Returns:
109109
bool: True if the path exists, False otherwise.
@@ -120,7 +120,7 @@ async def is_file(self, path: Path | str) -> bool: # pyrit-async-suffix-exempt
120120
Check whether the given path is a file (deprecated alias of ``is_file_async``).
121121
122122
Args:
123-
path (Union[Path, str]): The path to check.
123+
path (Path | str): The path to check.
124124
125125
Returns:
126126
bool: True if the path is a file, False otherwise.
@@ -137,7 +137,7 @@ async def create_directory_if_not_exists(self, path: Path | str) -> None: # pyr
137137
Create a directory if it does not exist (deprecated alias of ``create_directory_if_not_exists_async``).
138138
139139
Args:
140-
path (Union[Path, str]): The directory path to create.
140+
path (Path | str): The directory path to create.
141141
"""
142142
print_deprecation_message(
143143
old_item="pyrit.memory.storage.storage.StorageIO.create_directory_if_not_exists",
@@ -157,7 +157,7 @@ async def read_file_async(self, path: Path | str) -> bytes:
157157
Asynchronously reads a file from the local disk.
158158
159159
Args:
160-
path (Union[Path, str]): The path to the file.
160+
path (Path | str): The path to the file.
161161
162162
Returns:
163163
bytes: The content of the file.
@@ -225,7 +225,7 @@ def _convert_to_path(self, path: Path | str) -> Path:
225225
Convert an input path to a Path object.
226226
227227
Args:
228-
path (Union[Path, str]): Input path value.
228+
path (Path | str): Input path value.
229229
230230
Returns:
231231
Path: Normalized Path instance.
@@ -250,8 +250,8 @@ def __init__(
250250
Initialize an Azure Blob Storage I/O adapter.
251251
252252
Args:
253-
container_url (Optional[str]): Azure Blob container URL.
254-
sas_token (Optional[str]): Optional SAS token.
253+
container_url (str | None): Azure Blob container URL.
254+
sas_token (str | None): Optional SAS token.
255255
blob_content_type (SupportedContentType): Blob content type for uploads.
256256
257257
Raises:
@@ -394,7 +394,7 @@ def _resolve_blob_name(self, path: Path | str) -> str:
394394
created on Windows still produce valid blob names.
395395
396396
Args:
397-
path (Union[Path, str]): Blob URL or relative blob path.
397+
path (Path | str): Blob URL or relative blob path.
398398
399399
Returns:
400400
str: The resolved blob name.
@@ -458,7 +458,7 @@ async def write_file_async(self, path: Path | str, data: bytes) -> None:
458458
If a relative path is provided, it is used as the blob name directly.
459459
460460
Args:
461-
path (Union[Path, str]): Full blob URL or relative blob path.
461+
path (Path | str): Full blob URL or relative blob path.
462462
data (bytes): The data to write.
463463
"""
464464
if not self._client_async:
@@ -477,7 +477,7 @@ async def path_exists_async(self, path: Path | str) -> bool:
477477
Check whether a given path exists in the Azure Blob Storage container.
478478
479479
Args:
480-
path (Union[Path, str]): Blob URL or path to test.
480+
path (Path | str): Blob URL or path to test.
481481
482482
Returns:
483483
bool: True when the path exists.
@@ -501,7 +501,7 @@ async def is_file_async(self, path: Path | str) -> bool:
501501
Check whether the path refers to a file (blob) in Azure Blob Storage.
502502
503503
Args:
504-
path (Union[Path, str]): Blob URL or path to test.
504+
path (Path | str): Blob URL or path to test.
505505
506506
Returns:
507507
bool: True when the blob exists and has non-zero content size.
@@ -525,7 +525,7 @@ async def create_directory_if_not_exists_async(self, directory_path: Path | str)
525525
Log a no-op directory creation for Azure Blob Storage.
526526
527527
Args:
528-
directory_path (Union[Path, str]): Requested directory path.
528+
directory_path (Path | str): Requested directory path.
529529
530530
"""
531531
logger.info(

pyrit/prompt_converter/prompt_converter.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT license.
33

4+
from __future__ import annotations
5+
46
import abc
57
import asyncio
68
import inspect
79
import re
810
from dataclasses import dataclass
9-
from typing import TYPE_CHECKING, Any, ClassVar, Optional, get_args
11+
from typing import TYPE_CHECKING, Any, ClassVar, get_args
1012

1113
from pyrit import prompt_converter
1214
from pyrit.models import ComponentIdentifier, ConverterIdentifier, Identifiable, PromptDataType
@@ -89,7 +91,7 @@ def __init_subclass__(cls, **kwargs: object) -> None:
8991
f"Declare the output modalities this converter produces."
9092
)
9193

92-
def __init__(self, *, converter_target: Optional["PromptTarget"] = None) -> None:
94+
def __init__(self, *, converter_target: PromptTarget | None = None) -> None:
9395
"""
9496
Initialize the prompt converter.
9597

0 commit comments

Comments
 (0)