@@ -392,30 +392,15 @@ def _get_fine_tuning_options_and_model_arn(
392392 if not recipes_with_template :
393393 raise ValueError (f"No recipes found with Smtj for technique: { customization_technique } " )
394394
395- # Select recipe based on training type
396- # Collect override_params from ALL matching recipes (standard + subscription)
397- recipe = None
398- if (isinstance (training_type , TrainingType ) and training_type == TrainingType .LORA ) or training_type == "LORA" :
399- recipe = next ((r for r in recipes_with_template if r .get ("Peft" ) and not r .get ("IsSubscriptionModel" )), None )
400- elif (isinstance (training_type , TrainingType ) and training_type == TrainingType .FULL ) or training_type == "FULL" :
401- recipe = next ((r for r in recipes_with_template if not r .get ("Peft" ) and not r .get ("IsSubscriptionModel" )), None )
402-
403- # Override recipe selection when sequence_length is explicitly requested
395+ # Filter by SequenceLength before recipe selection if sequence_length is requested
404396 if sequence_length :
405- if (isinstance (training_type , TrainingType ) and training_type == TrainingType .LORA ) or training_type == "LORA" :
406- candidates = [r for r in recipes_with_template if r .get ("Peft" ) and not r .get ("IsSubscriptionModel" )]
407- elif (isinstance (training_type , TrainingType ) and training_type == TrainingType .FULL ) or training_type == "FULL" :
408- candidates = [r for r in recipes_with_template if not r .get ("Peft" ) and not r .get ("IsSubscriptionModel" )]
409- else :
410- candidates = []
411-
412397 requested = _parse_context_length (sequence_length )
413- candidates_with_context = [r for r in candidates if r .get ("SequenceLength" )]
398+ candidates_with_context = [r for r in recipes_with_template if r .get ("SequenceLength" )]
414399 if candidates_with_context :
415400 filtered = [r for r in candidates_with_context if _parse_context_length (r .get ("SequenceLength" )) >= requested ]
416401 if filtered :
417402 filtered .sort (key = lambda r : _parse_context_length (r .get ("SequenceLength" )))
418- recipe = filtered [ 0 ]
403+ recipes_with_template = filtered
419404 else :
420405 available = sorted (set (r .get ("SequenceLength" ) for r in candidates_with_context ))
421406 raise ValueError (
@@ -428,6 +413,14 @@ def _get_fine_tuning_options_and_model_arn(
428413 f"and sequence length:{ sequence_length } "
429414 )
430415
416+ # Select recipe based on training type
417+ # Collect override_params from ALL matching recipes (standard + subscription)
418+ recipe = None
419+ if (isinstance (training_type , TrainingType ) and training_type == TrainingType .LORA ) or training_type == "LORA" :
420+ recipe = next ((r for r in recipes_with_template if r .get ("Peft" ) and not r .get ("IsSubscriptionModel" )), None )
421+ elif (isinstance (training_type , TrainingType ) and training_type == TrainingType .FULL ) or training_type == "FULL" :
422+ recipe = next ((r for r in recipes_with_template if not r .get ("Peft" ) and not r .get ("IsSubscriptionModel" )), None )
423+
431424 if not recipe :
432425 raise ValueError (f"No recipes found with Smtj for technique: { customization_technique } ,training_type:{ training_type } " )
433426
@@ -601,18 +594,15 @@ def _create_serverless_config(model_arn, customization_technique,
601594 else (training_type .value if isinstance (training_type , TrainingType ) else training_type )
602595
603596 # Create ServerlessJobConfig using shapes
604- config_kwargs = dict (
597+ serverless_config = ServerlessJobConfig (
605598 job_type = job_type ,
606599 base_model_arn = model_arn ,
607600 customization_technique = customization_technique ,
608601 peft = peft ,
609602 evaluator_arn = evaluator_arn ,
610603 accept_eula = accept_eula ,
604+ sequence_length = sequence_length ,
611605 )
612- if sequence_length is not None :
613- config_kwargs ["sequence_length" ] = sequence_length
614-
615- serverless_config = ServerlessJobConfig (** config_kwargs )
616606
617607 return serverless_config
618608
0 commit comments