@@ -481,30 +481,15 @@ def _get_fine_tuning_options_and_model_arn(
481481 if not recipes_with_template :
482482 raise ValueError (f"No recipes found with Smtj for technique: { customization_technique } " )
483483
484- # Select recipe based on training type
485- # Collect override_params from ALL matching recipes (standard + subscription)
486- recipe = None
487- if (isinstance (training_type , TrainingType ) and training_type == TrainingType .LORA ) or training_type == "LORA" :
488- recipe = next ((r for r in recipes_with_template if r .get ("Peft" ) and not r .get ("IsSubscriptionModel" )), None )
489- elif (isinstance (training_type , TrainingType ) and training_type == TrainingType .FULL ) or training_type == "FULL" :
490- recipe = next ((r for r in recipes_with_template if not r .get ("Peft" ) and not r .get ("IsSubscriptionModel" )), None )
491-
492- # Override recipe selection when sequence_length is explicitly requested
484+ # Filter by SequenceLength before recipe selection if sequence_length is requested
493485 if sequence_length :
494- if (isinstance (training_type , TrainingType ) and training_type == TrainingType .LORA ) or training_type == "LORA" :
495- candidates = [r for r in recipes_with_template if r .get ("Peft" ) and not r .get ("IsSubscriptionModel" )]
496- elif (isinstance (training_type , TrainingType ) and training_type == TrainingType .FULL ) or training_type == "FULL" :
497- candidates = [r for r in recipes_with_template if not r .get ("Peft" ) and not r .get ("IsSubscriptionModel" )]
498- else :
499- candidates = []
500-
501486 requested = _parse_context_length (sequence_length )
502- candidates_with_context = [r for r in candidates if r .get ("SequenceLength" )]
487+ candidates_with_context = [r for r in recipes_with_template if r .get ("SequenceLength" )]
503488 if candidates_with_context :
504489 filtered = [r for r in candidates_with_context if _parse_context_length (r .get ("SequenceLength" )) >= requested ]
505490 if filtered :
506491 filtered .sort (key = lambda r : _parse_context_length (r .get ("SequenceLength" )))
507- recipe = filtered [ 0 ]
492+ recipes_with_template = filtered
508493 else :
509494 available = sorted (set (r .get ("SequenceLength" ) for r in candidates_with_context ))
510495 raise ValueError (
@@ -517,6 +502,14 @@ def _get_fine_tuning_options_and_model_arn(
517502 f"and sequence length:{ sequence_length } "
518503 )
519504
505+ # Select recipe based on training type
506+ # Collect override_params from ALL matching recipes (standard + subscription)
507+ recipe = None
508+ if (isinstance (training_type , TrainingType ) and training_type == TrainingType .LORA ) or training_type == "LORA" :
509+ recipe = next ((r for r in recipes_with_template if r .get ("Peft" ) and not r .get ("IsSubscriptionModel" )), None )
510+ elif (isinstance (training_type , TrainingType ) and training_type == TrainingType .FULL ) or training_type == "FULL" :
511+ recipe = next ((r for r in recipes_with_template if not r .get ("Peft" ) and not r .get ("IsSubscriptionModel" )), None )
512+
520513 if not recipe :
521514 raise ValueError (f"No recipes found with Smtj for technique: { customization_technique } ,training_type:{ training_type } " )
522515
@@ -690,18 +683,15 @@ def _create_serverless_config(model_arn, customization_technique,
690683 else (training_type .value if isinstance (training_type , TrainingType ) else training_type )
691684
692685 # Create ServerlessJobConfig using shapes
693- config_kwargs = dict (
686+ serverless_config = ServerlessJobConfig (
694687 job_type = job_type ,
695688 base_model_arn = model_arn ,
696689 customization_technique = customization_technique ,
697690 peft = peft ,
698691 evaluator_arn = evaluator_arn ,
699692 accept_eula = accept_eula ,
693+ sequence_length = sequence_length ,
700694 )
701- if sequence_length is not None :
702- config_kwargs ["sequence_length" ] = sequence_length
703-
704- serverless_config = ServerlessJobConfig (** config_kwargs )
705695
706696 return serverless_config
707697
0 commit comments