Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 13 additions & 15 deletions angelslim/compressor/cache/teacache.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,9 @@

import numpy as np
import torch
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import (
USE_PEFT_BACKEND,
is_torch_version,
scale_lora_layers,
unscale_lora_layers,
)

from ...utils import print_info
from ...utils.lazy_imports import Transformer2DModelOutput, diffusers


class TeaCache:
Expand Down Expand Up @@ -130,9 +124,9 @@ def flux_teacache_forward(
else:
lora_scale = 1.0

if USE_PEFT_BACKEND:
if diffusers.utils.USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
diffusers.utils.scale_lora_layers(self, lora_scale)
else:
if (
joint_attention_kwargs is not None
Expand Down Expand Up @@ -236,7 +230,7 @@ def custom_forward(*inputs):

ckpt_kwargs: Dict[str, Any] = (
{"use_reentrant": False}
if is_torch_version(">=", "1.11.0")
if diffusers.utils.is_torch_version(">=", "1.11.0")
else {}
)
encoder_hidden_states, hidden_states = (
Expand Down Expand Up @@ -294,7 +288,7 @@ def custom_forward(*inputs):

ckpt_kwargs: Dict[str, Any] = (
{"use_reentrant": False}
if is_torch_version(">=", "1.11.0")
if diffusers.utils.is_torch_version(">=", "1.11.0")
else {}
)
hidden_states = torch.utils.checkpoint.checkpoint(
Expand Down Expand Up @@ -342,7 +336,9 @@ def custom_forward(*inputs):
return custom_forward

ckpt_kwargs: Dict[str, Any] = (
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
{"use_reentrant": False}
if diffusers.utils.is_torch_version(">=", "1.11.0")
else {}
)
encoder_hidden_states, hidden_states = (
torch.utils.checkpoint.checkpoint(
Expand Down Expand Up @@ -398,7 +394,9 @@ def custom_forward(*inputs):
return custom_forward

ckpt_kwargs: Dict[str, Any] = (
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
{"use_reentrant": False}
if diffusers.utils.is_torch_version(">=", "1.11.0")
else {}
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
Expand Down Expand Up @@ -432,9 +430,9 @@ def custom_forward(*inputs):
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)

if USE_PEFT_BACKEND:
if diffusers.utils.USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
diffusers.utils.unscale_lora_layers(self, lora_scale)

if not return_dict:
return (output,)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
from typing import Any, Dict, Optional

import numpy as np
import ray
from fastchat.llm_judge.common import load_questions
from transformers import AutoTokenizer

from angelslim.utils.lazy_imports import fastchat, ray

from .generate_baseline_answer import get_model_answers as get_baseline_answers
from .generate_eagle_answer import get_model_answers as get_eagle_answers

Expand Down Expand Up @@ -146,7 +146,7 @@ def _run_eagle_benchmark(self):
"""Run Eagle speculative decoding benchmark"""
args = self._create_args_namespace("eagle")

questions = load_questions(
questions = fastchat.llm_judge.common.load_questions(
self._get_question_file_path(),
self.config.question_begin,
self.config.question_end,
Expand Down Expand Up @@ -186,7 +186,7 @@ def _run_baseline_benchmark(self):
"""Run baseline benchmark"""
args = self._create_args_namespace("baseline")

questions = load_questions(
questions = fastchat.llm_judge.common.load_questions(
self._get_question_file_path(),
self.config.question_begin,
self.config.question_end,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@
from typing import Any, Dict, List

import numpy as np
import ray
import shortuuid
import torch
from fastchat.llm_judge.common import load_questions
from tqdm import tqdm

from angelslim.compressor.speculative.inference.models import Eagle3Model
from angelslim.utils.lazy_imports import fastchat, ray

SYSTEM_PROMPT = {
"role": "system",
Expand Down Expand Up @@ -231,7 +230,7 @@ def get_model_answers(

def run_evaluation(config: EvaluationConfig, args: argparse.Namespace) -> None:
"""Run the evaluation with optional distributed processing"""
questions = load_questions(
questions = fastchat.llm_judge.common.load_questions(
config.question_file, args.question_begin, args.question_end
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@
from typing import Any, Dict, List

import numpy as np
import ray
import shortuuid
import torch
from fastchat.llm_judge.common import load_questions
from tqdm import tqdm

from angelslim.compressor.speculative.inference.models import Eagle3Model
from angelslim.utils.lazy_imports import fastchat, ray

SYSTEM_PROMPT = {
"role": "system",
Expand Down Expand Up @@ -237,7 +236,7 @@ def get_model_answers(

def run_evaluation(config: EvaluationConfig, args: argparse.Namespace) -> None:
"""Run the evaluation with optional distributed processing"""
questions = load_questions(
questions = fastchat.llm_judge.common.load_questions(
config.question_file, args.question_begin, args.question_end
)

Expand Down
4 changes: 2 additions & 2 deletions angelslim/data/multimodal_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@

from datasets import load_dataset
from PIL import Image
from qwen_vl_utils import process_vision_info
from tqdm import tqdm
from transformers import ProcessorMixin

from ..utils.lazy_imports import qwen_vl_utils
from .base_dataset import BaseDataset


Expand Down Expand Up @@ -108,7 +108,7 @@ def _process_and_append(self, messages: List[Dict]):
)

# Extract vision info
image_inputs, video_inputs = process_vision_info(messages)
image_inputs, video_inputs = qwen_vl_utils.process_vision_info(messages)

# Process inputs
inputs = self.processor(
Expand Down
13 changes: 8 additions & 5 deletions angelslim/models/diffusion/flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,18 @@
import numpy as np
import torch
import torch.nn as nn
from diffusers import FluxPipeline
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from safetensors.torch import load_file
from tqdm import tqdm

from ...compressor import CompressorFactory
from ...compressor.quant.core import PTQDiffusionSave, PTQOnlyScaleSave, QuantConfig
from ...compressor.quant.modules import QLinear
from ...utils.lazy_imports import (
FluxPipelineOutput,
calculate_shift,
diffusers,
retrieve_timesteps,
)
from ...utils.utils import find_layers, find_parent_layer_and_sub_name
from ..base_model import BaseDiffusionModel
from ..model_factory import SlimModelFactory
Expand Down Expand Up @@ -82,7 +85,7 @@ def from_pretrained(
[comp_name], self, slim_config=slim_config
)
else:
self.model = FluxPipeline.from_pretrained(
self.model = diffusers.FluxPipeline.from_pretrained(
model_path,
torch_dtype=torch_dtype,
cache_dir=cache_dir,
Expand Down Expand Up @@ -199,7 +202,7 @@ def model_forward(self, dataloader, **kwargs):
).images[0]


class FluxSlimPipeline(FluxPipeline):
class FluxSlimPipeline(diffusers.FluxPipeline):
def __init__(
self,
scheduler,
Expand Down
1 change: 1 addition & 0 deletions angelslim/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from .config_parser import SlimConfigParser, parse_json_full_config # noqa: F401
from .default_compress_config import * # noqa: F401 F403
from .lazy_imports import * # noqa: F401 F403
from .utils import common_prefix # noqa: F401
from .utils import find_layers # noqa: F401
from .utils import find_parent_layer_and_sub_name # noqa: F401
Expand Down
Loading