Skip to content

Commit 18df07b

Browse files
Minor change in the test file
1 parent c531a18 commit 18df07b

18 files changed

Lines changed: 60 additions & 739 deletions

examples/benchmarks/ort_inference_performance.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,7 @@ def run_huggingface_benchmark(model_identifier, precision='float16', batch_size=
5858

5959
logger.info(f'Running ORT inference benchmark with HuggingFace model: {model_identifier}')
6060

61-
context = BenchmarkRegistry.create_benchmark_context(
62-
'ort-inference', platform=Platform.CUDA, parameters=parameters
63-
)
61+
context = BenchmarkRegistry.create_benchmark_context('ort-inference', platform=Platform.CUDA, parameters=parameters)
6462
benchmark = BenchmarkRegistry.launch_benchmark(context)
6563
if benchmark:
6664
logger.info(
@@ -74,12 +72,14 @@ def run_huggingface_benchmark(model_identifier, precision='float16', batch_size=
7472
if __name__ == '__main__':
7573
parser = argparse.ArgumentParser(description='ORT inference benchmark')
7674
parser.add_argument(
77-
'--model_source', type=str, default='in-house', choices=['in-house', 'huggingface'],
75+
'--model_source',
76+
type=str,
77+
default='in-house',
78+
choices=['in-house', 'huggingface'],
7879
help='Source of the model: in-house (default) or huggingface'
7980
)
8081
parser.add_argument(
81-
'--model_identifier', type=str, default='bert-base-uncased',
82-
help='HuggingFace model identifier'
82+
'--model_identifier', type=str, default='bert-base-uncased', help='HuggingFace model identifier'
8383
)
8484
parser.add_argument('--precision', type=str, default='float16', choices=['float32', 'float16', 'int8'])
8585
parser.add_argument('--batch_size', type=int, default=32)

examples/benchmarks/pytorch_huggingface_models.py

Lines changed: 0 additions & 142 deletions
This file was deleted.

examples/benchmarks/tensorrt_inference_performance.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,14 @@ def run_huggingface_benchmark(model_identifier, precision='fp16', batch_size=32,
7474
if __name__ == '__main__':
7575
parser = argparse.ArgumentParser(description='TensorRT inference benchmark')
7676
parser.add_argument(
77-
'--model_source', type=str, default='in-house', choices=['in-house', 'huggingface'],
77+
'--model_source',
78+
type=str,
79+
default='in-house',
80+
choices=['in-house', 'huggingface'],
7881
help='Source of the model: in-house (default) or huggingface'
7982
)
8083
parser.add_argument(
81-
'--model_identifier', type=str, default='bert-base-uncased',
82-
help='HuggingFace model identifier'
84+
'--model_identifier', type=str, default='bert-base-uncased', help='HuggingFace model identifier'
8385
)
8486
parser.add_argument('--precision', type=str, default='fp16', choices=['fp32', 'fp16', 'int8'])
8587
parser.add_argument('--batch_size', type=int, default=32)
@@ -89,8 +91,7 @@ def run_huggingface_benchmark(model_identifier, precision='fp16', batch_size=32,
8991

9092
if args.model_source == 'huggingface':
9193
run_huggingface_benchmark(
92-
args.model_identifier, args.precision, args.batch_size,
93-
args.seq_length, args.iterations
94+
args.model_identifier, args.precision, args.batch_size, args.seq_length, args.iterations
9495
)
9596
else:
9697
run_inhouse_benchmark()

superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -336,16 +336,24 @@ def forward(self, pixel_values):
336336
return outputs[0] if isinstance(outputs, (tuple, list)) else outputs
337337

338338
wrapped_model = VisionModelWrapper(model)
339-
export_args = (dummy_input,)
339+
export_args = (dummy_input, )
340340
else:
341341
# NLP models: use input_ids and attention_mask
342342
dummy_input = torch.ones((batch_size, seq_length), dtype=torch.int64, device=device)
343343
attention_mask = torch.ones((batch_size, seq_length), dtype=torch.int64, device=device)
344344
input_names = ['input_ids', 'attention_mask']
345345
dynamic_axes = {
346-
'input_ids': {0: 'batch_size', 1: 'seq_length'},
347-
'attention_mask': {0: 'batch_size', 1: 'seq_length'},
348-
'output': {0: 'batch_size'},
346+
'input_ids': {
347+
0: 'batch_size',
348+
1: 'seq_length'
349+
},
350+
'attention_mask': {
351+
0: 'batch_size',
352+
1: 'seq_length'
353+
},
354+
'output': {
355+
0: 'batch_size'
356+
},
349357
}
350358

351359
# Wrapper for NLP models

superbench/benchmarks/micro_benchmarks/huggingface_model_loader.py

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ class HuggingFaceModelLoader:
4747
cache_dir: Directory to cache downloaded models.
4848
token: HuggingFace authentication token for private/gated models.
4949
"""
50-
5150
def __init__(self, cache_dir: Optional[str] = None, token: Optional[str] = None):
5251
"""Initialize the HuggingFace model loader.
5352
@@ -100,11 +99,7 @@ def load_model(
10099
dtype = self._get_torch_dtype(torch_dtype) if torch_dtype else None
101100

102101
# Prepare loading kwargs
103-
load_kwargs = {
104-
'cache_dir': self.cache_dir,
105-
'revision': revision,
106-
**kwargs
107-
}
102+
load_kwargs = {'cache_dir': self.cache_dir, 'revision': revision, **kwargs}
108103

109104
# Add token if available
110105
if self.token:
@@ -117,19 +112,15 @@ def load_model(
117112
# Load config (use pre-downloaded config if provided)
118113
if config is None:
119114
logger.info('Loading model configuration...')
120-
config = AutoConfig.from_pretrained(
121-
model_identifier, trust_remote_code=True, **load_kwargs
122-
)
115+
config = AutoConfig.from_pretrained(model_identifier, trust_remote_code=True, **load_kwargs)
123116
else:
124117
logger.info('Using pre-downloaded model configuration.')
125118

126119
# Load tokenizer (may fail for some models, that's ok)
127120
tokenizer = None
128121
try:
129122
logger.info('Loading tokenizer...')
130-
tokenizer = AutoTokenizer.from_pretrained(
131-
model_identifier, trust_remote_code=True, **load_kwargs
132-
)
123+
tokenizer = AutoTokenizer.from_pretrained(model_identifier, trust_remote_code=True, **load_kwargs)
133124
except Exception as e:
134125
logger.warning(f'Could not load tokenizer: {e}. Continuing without tokenizer.')
135126

@@ -179,7 +170,9 @@ def load_model(
179170
raise ModelLoadError(f"Unexpected error loading model '{model_identifier}': {e}") from e
180171

181172
def load_model_from_config(
182-
self, config: ModelSourceConfig, device: Optional[str] = None,
173+
self,
174+
config: ModelSourceConfig,
175+
device: Optional[str] = None,
183176
config_pretrained: Optional[PretrainedConfig] = None,
184177
) -> Tuple[PreTrainedModel, PretrainedConfig, AutoTokenizer]:
185178
"""Load a model using ModelSourceConfig.
@@ -197,10 +190,7 @@ def load_model_from_config(
197190
ModelLoadError: If model loading fails.
198191
"""
199192
if not config.is_huggingface():
200-
raise ValueError(
201-
f"Cannot load model with source '{config.source}'. "
202-
"Use 'huggingface' source."
203-
)
193+
raise ValueError(f"Cannot load model with source '{config.source}'. Use 'huggingface' source.")
204194

205195
# Validate config
206196
is_valid, error = config.validate()
@@ -244,10 +234,7 @@ def _get_torch_dtype(self, dtype_str: str) -> torch.dtype:
244234
}
245235

246236
if dtype_str.lower() not in dtype_map:
247-
raise ValueError(
248-
f"Invalid dtype '{dtype_str}'. "
249-
f'Must be one of {list(dtype_map.keys())}'
250-
)
237+
raise ValueError(f"Invalid dtype '{dtype_str}'.Must be one of {list(dtype_map.keys())}")
251238

252239
return dtype_map[dtype_str.lower()]
253240

@@ -289,9 +276,7 @@ def estimate_param_count_from_config(hf_config) -> Optional[int]:
289276

290277
# Embeddings: token + (optional) position
291278
max_pos = getattr(hf_config, 'max_position_embeddings', 0)
292-
has_pos_embed = getattr(hf_config, 'position_embedding_type', None) not in (
293-
'rotary', None
294-
)
279+
has_pos_embed = getattr(hf_config, 'position_embedding_type', None) not in ('rotary', None)
295280
embed_params = vocab * hidden
296281
if has_pos_embed and max_pos > 0:
297282
embed_params += max_pos * hidden
@@ -346,7 +331,7 @@ def estimate_memory(param_count, precision_str, mode='training'):
346331
precision_lower = precision_str.lower()
347332
if precision_lower in ('float16', 'fp16', 'bfloat16', 'bf16'):
348333
bytes_per_param = 2
349-
elif precision_lower in ('int8',):
334+
elif precision_lower in ('int8', ):
350335
bytes_per_param = 1
351336
else:
352337
bytes_per_param = 4
@@ -368,7 +353,7 @@ def estimate_memory(param_count, precision_str, mode='training'):
368353
except ImportError:
369354
logger.warning('psutil not installed — cannot check system memory. Skipping memory check.')
370355
return 0, 0, True
371-
max_gpu_mem = 80 * (1024 ** 3) # 80GB — largest common single-GPU memory
356+
max_gpu_mem = 80 * (1024**3) # 80GB — largest common single-GPU memory
372357
effective_mem = min(sys_mem, max_gpu_mem)
373358
fits = (estimated_bytes / effective_mem) < 0.85
374359
return estimated_bytes, effective_mem, fits

superbench/benchmarks/micro_benchmarks/model_source_config.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class ModelSourceConfig:
3333
revision: Optional[str] = None
3434
cache_dir: Optional[str] = None
3535
device_map: Optional[str] = None
36-
use_auth_token: Optional[str] = None # Deprecated
36+
use_auth_token: Optional[str] = None # Deprecated
3737
additional_kwargs: Dict[str, Any] = field(default_factory=dict)
3838

3939
def __post_init__(self):
@@ -45,18 +45,12 @@ def __post_init__(self):
4545
# Normalize and validate source
4646
self.source = self.source.lower()
4747
if self.source not in ['in-house', 'huggingface']:
48-
raise ValueError(
49-
f"Invalid model source '{self.source}'. "
50-
f"Must be 'in-house' or 'huggingface'."
51-
)
48+
raise ValueError(f"Invalid model source '{self.source}'.Must be 'in-house' or 'huggingface'.")
5249

5350
# Validate torch_dtype
5451
valid_dtypes = ['float32', 'float16', 'bfloat16', 'int8']
5552
if self.torch_dtype not in valid_dtypes:
56-
raise ValueError(
57-
f"Invalid torch_dtype '{self.torch_dtype}'. "
58-
f'Must be one of {valid_dtypes}.'
59-
)
53+
raise ValueError(f"Invalid torch_dtype '{self.torch_dtype}'.Must be one of {valid_dtypes}.")
6054

6155
# Validate identifier is provided
6256
if not self.identifier:
@@ -72,10 +66,7 @@ def validate(self) -> Tuple[bool, str]:
7266
# Check identifier is not empty for HuggingFace models
7367
if self.source == 'huggingface':
7468
if not self.identifier or not self.identifier.strip():
75-
return (
76-
False,
77-
'HuggingFace model identifier cannot be empty'
78-
)
69+
return (False, 'HuggingFace model identifier cannot be empty')
7970

8071
return (True, '')
8172

0 commit comments

Comments
 (0)