Skip to content

Commit a7ba9db

Browse files
committed
Resolve comments
1 parent d3a8653 commit a7ba9db

3 files changed

Lines changed: 22 additions & 25 deletions

File tree

sdks/python/apache_beam/examples/inference/pytorch_image_captioning.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def process(self, kv: Tuple[str, str]):
8484
with FileSystems.open(uri) as f:
8585
image_bytes = f.read()
8686
yield uri, {"image_bytes": image_bytes}
87-
except Exception as e:
87+
except OSError as e:
8888
logging.warning("Failed to read image %s: %s", uri, e)
8989
return
9090

@@ -164,6 +164,8 @@ def load_model(self):
164164
from transformers import BlipForConditionalGeneration, BlipProcessor
165165
processor = BlipProcessor.from_pretrained(self.model_name)
166166
model = BlipForConditionalGeneration.from_pretrained(self.model_name)
167+
model.to(self.device)
168+
model.eval()
167169
return (model, processor)
168170

169171
def batch_elements_kwargs(self):
@@ -173,8 +175,6 @@ def run_inference(
173175
self, batch: List[Dict[str, Any]], model_bundle, inference_args=None):
174176

175177
model, processor = model_bundle
176-
model.to(self.device)
177-
model.eval()
178178
start = now_millis()
179179

180180
images = [x["image"] for x in batch]
@@ -236,6 +236,8 @@ def load_model(self):
236236
from transformers import CLIPModel, CLIPProcessor
237237
processor = CLIPProcessor.from_pretrained(self.model_name)
238238
model = CLIPModel.from_pretrained(self.model_name)
239+
model.to(self.device)
240+
model.eval()
239241
return (model, processor)
240242

241243
def batch_elements_kwargs(self):
@@ -245,8 +247,6 @@ def run_inference(
245247
self, batch: List[Dict[str, Any]], model_bundle, inference_args=None):
246248

247249
model, processor = model_bundle
248-
model.to(self.device)
249-
model.eval()
250250
start_batch = now_millis()
251251

252252
# Flat lists for a single batched CLIP forward pass
@@ -464,15 +464,15 @@ def cleanup_pubsub_resources(
464464
try:
465465
subscriber.delete_subscription(
466466
request={"subscription": full_subscription_path})
467-
print(f"Deleted subscription: {subscription_name}")
467+
logging.info(f"Deleted subscription: {subscription_name}")
468468
except NotFound:
469-
print(f"Subscription already deleted: {subscription_name}")
469+
logging.info(f"Subscription already deleted: {subscription_name}")
470470

471471
try:
472472
publisher.delete_topic(request={"topic": full_topic_path})
473-
print(f"Deleted topic: {topic_name}")
473+
logging.info(f"Deleted topic: {topic_name}")
474474
except NotFound:
475-
print(f"Topic already deleted: {topic_name}")
475+
logging.info(f"Topic already deleted: {topic_name}")
476476

477477

478478
def override_or_add(args, flag, value):

sdks/python/apache_beam/examples/inference/pytorch_image_object_detection.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -342,15 +342,15 @@ def cleanup_pubsub_resources(
342342
try:
343343
subscriber.delete_subscription(
344344
request={"subscription": full_subscription_path})
345-
print(f"Deleted subscription: {subscription_name}")
345+
logging.info(f"Deleted subscription: {subscription_name}")
346346
except NotFound:
347-
print(f"Subscription already deleted: {subscription_name}")
347+
logging.info(f"Subscription already deleted: {subscription_name}")
348348

349349
try:
350350
publisher.delete_topic(request={"topic": full_topic_path})
351-
print(f"Deleted topic: {topic_name}")
351+
logging.info(f"Deleted topic: {topic_name}")
352352
except NotFound:
353-
print(f"Topic already deleted: {topic_name}")
353+
logging.info(f"Topic already deleted: {topic_name}")
354354

355355

356356
def override_or_add(args, flag, value):

sdks/python/apache_beam/examples/inference/pytorch_imagenet_rightfit.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -157,14 +157,10 @@ def process(self, kv: Tuple[str, PredictionResult]):
157157
if isinstance(inference_obj, dict):
158158
logits = inference_obj.get("logits", None)
159159
if logits is None:
160-
# fallback: try first value if dict shape differs
161-
try:
162-
logits = next(iter(inference_obj.values()))
163-
logging.warning(
164-
'Could not find <logits> key in model output.'
165-
'Falling back to first value in dict.')
166-
except Exception:
167-
logging.warning('Could not find <logits> key in dict.')
160+
raise ValueError(
161+
f"Unable to find 'logits' in model output. "
162+
f"Available keys: {list(inference_obj.keys())}"
163+
)
168164
else:
169165
logits = inference_obj
170166

@@ -297,15 +293,15 @@ def cleanup_pubsub_resources(
297293
try:
298294
subscriber.delete_subscription(
299295
request={"subscription": full_subscription_path})
300-
print(f"Deleted subscription: {subscription_name}")
296+
logging.info(f"Deleted subscription: {subscription_name}")
301297
except NotFound:
302-
print(f"Subscription already deleted: {subscription_name}")
298+
logging.info(f"Subscription already deleted: {subscription_name}")
303299

304300
try:
305301
publisher.delete_topic(request={"topic": full_topic_path})
306-
print(f"Deleted topic: {topic_name}")
302+
logging.info(f"Deleted topic: {topic_name}")
307303
except NotFound:
308-
print(f"Topic already deleted: {topic_name}")
304+
logging.info(f"Topic already deleted: {topic_name}")
309305

310306

311307
def override_or_add(args, flag, value):
@@ -475,6 +471,7 @@ def run(
475471

476472
predictions = (
477473
to_infer
474+
| 'Reshuffle' >> beam.Reshuffle()
478475
| 'RunInference' >> RunInference(
479476
KeyedModelHandler(model_handler)).with_resource_hints(
480477
accelerator="type:nvidia-tesla-t4;count:1;install-nvidia-driver"))

0 commit comments

Comments
 (0)