Skip to content

Commit 4eec2f7

Browse files
Minor change in the test file
1 parent c531a18 commit 4eec2f7

18 files changed

Lines changed: 69 additions & 735 deletions

examples/benchmarks/ort_inference_performance.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -58,9 +58,7 @@ def run_huggingface_benchmark(model_identifier, precision='float16', batch_size=
5858

5959
logger.info(f'Running ORT inference benchmark with HuggingFace model: {model_identifier}')
6060

61-
context = BenchmarkRegistry.create_benchmark_context(
62-
'ort-inference', platform=Platform.CUDA, parameters=parameters
63-
)
61+
context = BenchmarkRegistry.create_benchmark_context('ort-inference', platform=Platform.CUDA, parameters=parameters)
6462
benchmark = BenchmarkRegistry.launch_benchmark(context)
6563
if benchmark:
6664
logger.info(
@@ -74,12 +72,14 @@ def run_huggingface_benchmark(model_identifier, precision='float16', batch_size=
7472
if __name__ == '__main__':
7573
parser = argparse.ArgumentParser(description='ORT inference benchmark')
7674
parser.add_argument(
77-
'--model_source', type=str, default='in-house', choices=['in-house', 'huggingface'],
75+
'--model_source',
76+
type=str,
77+
default='in-house',
78+
choices=['in-house', 'huggingface'],
7879
help='Source of the model: in-house (default) or huggingface'
7980
)
8081
parser.add_argument(
81-
'--model_identifier', type=str, default='bert-base-uncased',
82-
help='HuggingFace model identifier'
82+
'--model_identifier', type=str, default='bert-base-uncased', help='HuggingFace model identifier'
8383
)
8484
parser.add_argument('--precision', type=str, default='float16', choices=['float32', 'float16', 'int8'])
8585
parser.add_argument('--batch_size', type=int, default=32)

examples/benchmarks/pytorch_huggingface_models.py

Lines changed: 0 additions & 142 deletions
This file was deleted.

examples/benchmarks/tensorrt_inference_performance.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -74,12 +74,14 @@ def run_huggingface_benchmark(model_identifier, precision='fp16', batch_size=32,
7474
if __name__ == '__main__':
7575
parser = argparse.ArgumentParser(description='TensorRT inference benchmark')
7676
parser.add_argument(
77-
'--model_source', type=str, default='in-house', choices=['in-house', 'huggingface'],
77+
'--model_source',
78+
type=str,
79+
default='in-house',
80+
choices=['in-house', 'huggingface'],
7881
help='Source of the model: in-house (default) or huggingface'
7982
)
8083
parser.add_argument(
81-
'--model_identifier', type=str, default='bert-base-uncased',
82-
help='HuggingFace model identifier'
84+
'--model_identifier', type=str, default='bert-base-uncased', help='HuggingFace model identifier'
8385
)
8486
parser.add_argument('--precision', type=str, default='fp16', choices=['fp32', 'fp16', 'int8'])
8587
parser.add_argument('--batch_size', type=int, default=32)
@@ -89,8 +91,7 @@ def run_huggingface_benchmark(model_identifier, precision='fp16', batch_size=32,
8991

9092
if args.model_source == 'huggingface':
9193
run_huggingface_benchmark(
92-
args.model_identifier, args.precision, args.batch_size,
93-
args.seq_length, args.iterations
94+
args.model_identifier, args.precision, args.batch_size, args.seq_length, args.iterations
9495
)
9596
else:
9697
run_inhouse_benchmark()

superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py

Lines changed: 15 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717

1818
class torch2onnxExporter():
1919
"""PyTorch model to ONNX exporter."""
20+
2021
def __init__(self):
2122
"""Constructor."""
2223
from transformers import BertConfig, GPT2Config, LlamaConfig
@@ -322,6 +323,7 @@ def export_huggingface_model(self, model, model_name, batch_size=1, seq_length=5
322323

323324
# Wrapper for vision models
324325
class VisionModelWrapper(torch.nn.Module):
326+
325327
def __init__(self, model):
326328
super().__init__()
327329
self.model = model
@@ -336,20 +338,29 @@ def forward(self, pixel_values):
336338
return outputs[0] if isinstance(outputs, (tuple, list)) else outputs
337339

338340
wrapped_model = VisionModelWrapper(model)
339-
export_args = (dummy_input,)
341+
export_args = (dummy_input, )
340342
else:
341343
# NLP models: use input_ids and attention_mask
342344
dummy_input = torch.ones((batch_size, seq_length), dtype=torch.int64, device=device)
343345
attention_mask = torch.ones((batch_size, seq_length), dtype=torch.int64, device=device)
344346
input_names = ['input_ids', 'attention_mask']
345347
dynamic_axes = {
346-
'input_ids': {0: 'batch_size', 1: 'seq_length'},
347-
'attention_mask': {0: 'batch_size', 1: 'seq_length'},
348-
'output': {0: 'batch_size'},
348+
'input_ids': {
349+
0: 'batch_size',
350+
1: 'seq_length'
351+
},
352+
'attention_mask': {
353+
0: 'batch_size',
354+
1: 'seq_length'
355+
},
356+
'output': {
357+
0: 'batch_size'
358+
},
349359
}
350360

351361
# Wrapper for NLP models
352362
class NLPModelWrapper(torch.nn.Module):
363+
353364
def __init__(self, model):
354365
super().__init__()
355366
self.model = model

superbench/benchmarks/micro_benchmarks/huggingface_model_loader.py

Lines changed: 13 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -100,11 +100,7 @@ def load_model(
100100
dtype = self._get_torch_dtype(torch_dtype) if torch_dtype else None
101101

102102
# Prepare loading kwargs
103-
load_kwargs = {
104-
'cache_dir': self.cache_dir,
105-
'revision': revision,
106-
**kwargs
107-
}
103+
load_kwargs = {'cache_dir': self.cache_dir, 'revision': revision, **kwargs}
108104

109105
# Add token if available
110106
if self.token:
@@ -117,19 +113,15 @@ def load_model(
117113
# Load config (use pre-downloaded config if provided)
118114
if config is None:
119115
logger.info('Loading model configuration...')
120-
config = AutoConfig.from_pretrained(
121-
model_identifier, trust_remote_code=True, **load_kwargs
122-
)
116+
config = AutoConfig.from_pretrained(model_identifier, trust_remote_code=True, **load_kwargs)
123117
else:
124118
logger.info('Using pre-downloaded model configuration.')
125119

126120
# Load tokenizer (may fail for some models, that's ok)
127121
tokenizer = None
128122
try:
129123
logger.info('Loading tokenizer...')
130-
tokenizer = AutoTokenizer.from_pretrained(
131-
model_identifier, trust_remote_code=True, **load_kwargs
132-
)
124+
tokenizer = AutoTokenizer.from_pretrained(model_identifier, trust_remote_code=True, **load_kwargs)
133125
except Exception as e:
134126
logger.warning(f'Could not load tokenizer: {e}. Continuing without tokenizer.')
135127

@@ -179,7 +171,9 @@ def load_model(
179171
raise ModelLoadError(f"Unexpected error loading model '{model_identifier}': {e}") from e
180172

181173
def load_model_from_config(
182-
self, config: ModelSourceConfig, device: Optional[str] = None,
174+
self,
175+
config: ModelSourceConfig,
176+
device: Optional[str] = None,
183177
config_pretrained: Optional[PretrainedConfig] = None,
184178
) -> Tuple[PreTrainedModel, PretrainedConfig, AutoTokenizer]:
185179
"""Load a model using ModelSourceConfig.
@@ -197,10 +191,8 @@ def load_model_from_config(
197191
ModelLoadError: If model loading fails.
198192
"""
199193
if not config.is_huggingface():
200-
raise ValueError(
201-
f"Cannot load model with source '{config.source}'. "
202-
"Use 'huggingface' source."
203-
)
194+
raise ValueError(f"Cannot load model with source '{config.source}'. "
195+
"Use 'huggingface' source.")
204196

205197
# Validate config
206198
is_valid, error = config.validate()
@@ -244,10 +236,8 @@ def _get_torch_dtype(self, dtype_str: str) -> torch.dtype:
244236
}
245237

246238
if dtype_str.lower() not in dtype_map:
247-
raise ValueError(
248-
f"Invalid dtype '{dtype_str}'. "
249-
f'Must be one of {list(dtype_map.keys())}'
250-
)
239+
raise ValueError(f"Invalid dtype '{dtype_str}'. "
240+
f'Must be one of {list(dtype_map.keys())}')
251241

252242
return dtype_map[dtype_str.lower()]
253243

@@ -289,9 +279,7 @@ def estimate_param_count_from_config(hf_config) -> Optional[int]:
289279

290280
# Embeddings: token + (optional) position
291281
max_pos = getattr(hf_config, 'max_position_embeddings', 0)
292-
has_pos_embed = getattr(hf_config, 'position_embedding_type', None) not in (
293-
'rotary', None
294-
)
282+
has_pos_embed = getattr(hf_config, 'position_embedding_type', None) not in ('rotary', None)
295283
embed_params = vocab * hidden
296284
if has_pos_embed and max_pos > 0:
297285
embed_params += max_pos * hidden
@@ -346,7 +334,7 @@ def estimate_memory(param_count, precision_str, mode='training'):
346334
precision_lower = precision_str.lower()
347335
if precision_lower in ('float16', 'fp16', 'bfloat16', 'bf16'):
348336
bytes_per_param = 2
349-
elif precision_lower in ('int8',):
337+
elif precision_lower in ('int8', ):
350338
bytes_per_param = 1
351339
else:
352340
bytes_per_param = 4
@@ -368,7 +356,7 @@ def estimate_memory(param_count, precision_str, mode='training'):
368356
except ImportError:
369357
logger.warning('psutil not installed — cannot check system memory. Skipping memory check.')
370358
return 0, 0, True
371-
max_gpu_mem = 80 * (1024 ** 3) # 80GB — largest common single-GPU memory
359+
max_gpu_mem = 80 * (1024**3) # 80GB — largest common single-GPU memory
372360
effective_mem = min(sys_mem, max_gpu_mem)
373361
fits = (estimated_bytes / effective_mem) < 0.85
374362
return estimated_bytes, effective_mem, fits

superbench/benchmarks/micro_benchmarks/model_source_config.py

Lines changed: 6 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@ class ModelSourceConfig:
3333
revision: Optional[str] = None
3434
cache_dir: Optional[str] = None
3535
device_map: Optional[str] = None
36-
use_auth_token: Optional[str] = None # Deprecated
36+
use_auth_token: Optional[str] = None # Deprecated
3737
additional_kwargs: Dict[str, Any] = field(default_factory=dict)
3838

3939
def __post_init__(self):
@@ -45,18 +45,14 @@ def __post_init__(self):
4545
# Normalize and validate source
4646
self.source = self.source.lower()
4747
if self.source not in ['in-house', 'huggingface']:
48-
raise ValueError(
49-
f"Invalid model source '{self.source}'. "
50-
f"Must be 'in-house' or 'huggingface'."
51-
)
48+
raise ValueError(f"Invalid model source '{self.source}'. "
49+
f"Must be 'in-house' or 'huggingface'.")
5250

5351
# Validate torch_dtype
5452
valid_dtypes = ['float32', 'float16', 'bfloat16', 'int8']
5553
if self.torch_dtype not in valid_dtypes:
56-
raise ValueError(
57-
f"Invalid torch_dtype '{self.torch_dtype}'. "
58-
f'Must be one of {valid_dtypes}.'
59-
)
54+
raise ValueError(f"Invalid torch_dtype '{self.torch_dtype}'. "
55+
f'Must be one of {valid_dtypes}.')
6056

6157
# Validate identifier is provided
6258
if not self.identifier:
@@ -72,10 +68,7 @@ def validate(self) -> Tuple[bool, str]:
7268
# Check identifier is not empty for HuggingFace models
7369
if self.source == 'huggingface':
7470
if not self.identifier or not self.identifier.strip():
75-
return (
76-
False,
77-
'HuggingFace model identifier cannot be empty'
78-
)
71+
return (False, 'HuggingFace model identifier cannot be empty')
7972

8073
return (True, '')
8174

0 commit comments

Comments
 (0)