Skip to content

Commit 0cc6351

Browse files
Commit 0cc6351 ("remove flash-attn2") was committed with 1 parent, 80aa547.

5 files changed

Lines changed: 3 additions & 47 deletions

File tree

examples/cosmos/README.md

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,6 @@ cd examples/cosmos
1414
pip install -r requirements.txt
1515
```
1616

17-
> [!NOTE]
18-
> `flash-attn` is required for the default `flash_attention_2` text encoder attention implementation and must be installed separately after PyTorch:
19-
> ```bash
20-
> pip install flash-attn --no-build-isolation
21-
> ```
22-
> If your hardware does not support it, pass `--text_encoder_attn_implementation sdpa` to the training and eval scripts instead.
23-
2417
## Data preparation
2518

2619
The training script expects a dataset directory with the following layout:

examples/cosmos/eval_cosmos_predict25_lora.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -102,14 +102,6 @@ def parse_args():
102102
default=None,
103103
help="Negative prompt. Defaults to the pipeline's built-in negative prompt.",
104104
)
105-
parser.add_argument(
106-
"--text_encoder_attn_implementation",
107-
type=str,
108-
default="flash_attention_2",
109-
choices=["eager", "sdpa", "flash_attention_2"],
110-
help="The attention implementation to use for the text encoder (Qwen2.5 VL).",
111-
)
112-
113105
return parser.parse_args()
114106

115107

@@ -144,7 +136,6 @@ def check_video_safety(self, video):
144136
device_map=args.device,
145137
torch_dtype=torch.bfloat16,
146138
safety_checker=MockSafetyChecker(),
147-
text_encoder_attn_implementation=args.text_encoder_attn_implementation,
148139
)
149140

150141
if args.lora_dir is not None:

examples/cosmos/requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
--extra-index-url https://download.pytorch.org/whl/cu130
2+
torch
3+
torchvision
14
accelerate>=0.31.0
25
huggingface_hub
36
imageio

examples/cosmos/train_cosmos_predict25_lora.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,6 @@ def parse_args():
7474
default=None,
7575
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
7676
)
77-
parser.add_argument(
78-
"--text_encoder_attn_implementation",
79-
type=str,
80-
default="flash_attention_2",
81-
choices=["eager", "sdpa", "flash_attention_2"],
82-
help="The attention implementation to use for the text encoder (Qwen2.5 VL).",
83-
)
8477
parser.add_argument(
8578
"--train_data_dir",
8679
type=str,
@@ -516,7 +509,6 @@ def main():
516509
args.pretrained_model_name_or_path,
517510
revision=args.revision,
518511
torch_dtype=torch.bfloat16,
519-
text_encoder_attn_implementation=args.text_encoder_attn_implementation,
520512
safety_checker=MockSafetyChecker(),
521513
)
522514

src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import os
1615
from typing import Callable
1716

1817
import numpy as np
@@ -245,28 +244,6 @@ def __init__(
245244
self.latents_mean = latents_mean
246245
self.latents_std = 1.0 / latents_std
247246

248-
@classmethod
249-
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
250-
text_encoder_attn_implementation = kwargs.pop("text_encoder_attn_implementation", "flash_attention_2")
251-
if "text_encoder" not in kwargs:
252-
load_kwargs = {
253-
"revision": kwargs.get("revision", None),
254-
"device_map": kwargs.get("device_map", None),
255-
"torch_dtype": kwargs.get("torch_dtype", None),
256-
"attn_implementation": text_encoder_attn_implementation,
257-
}
258-
259-
if os.path.isdir(pretrained_model_name_or_path):
260-
text_encoder_path = os.path.join(pretrained_model_name_or_path, "text_encoder")
261-
else:
262-
text_encoder_path = pretrained_model_name_or_path
263-
load_kwargs["subfolder"] = "text_encoder"
264-
kwargs["text_encoder"] = Qwen2_5_VLForConditionalGeneration.from_pretrained(
265-
text_encoder_path, **load_kwargs
266-
)
267-
268-
return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
269-
270247
def get_latent_shape_cthw(self, height: int, width: int, num_frames: int):
271248
C = self.vae.config.z_dim
272249
T = (num_frames - 1) // self.vae_scale_factor_temporal + 1

0 commit comments

Comments (0)