Commit 7bd5df9
committed
feat(train): Add SequenceLength support for SFT, DPO, RLVR, RLAIF trainers
Add optional sequence_length parameter to all four trainers that enables
customers to specify their desired context length for serverless training
jobs. The parameter is passed in ServerlessJobConfig for recipe filtering.
During trainer initialization, _get_fine_tuning_options_and_model_arn
filters recipes by SequenceLength field, picking the smallest recipe
with context length >= the requested value. Raises ValueError if no
sufficient recipe exists or if recipes lack SequenceLength metadata.
Changes:
- ServerlessJobConfig: add sequence_length field
- _parse_context_length: parse values like '8K' to integers
- _get_fine_tuning_options_and_model_arn: filter by SequenceLength
- _create_serverless_config: conditionally include sequence_length
- SFTTrainer, DPOTrainer, RLVRTrainer, RLAIFTrainer: accept and
thread sequence_length through init and train methods
- Unit tests for all new functionality1 parent 4374751 commit 7bd5df9
11 files changed
Lines changed: 525 additions & 68 deletions
File tree
- sagemaker-core/src/sagemaker/core/shapes
- sagemaker-train
- src/sagemaker/train
- common_utils
- tests/unit/train
- common_utils
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
9588 | 9588 | | |
9589 | 9589 | | |
9590 | 9590 | | |
| 9591 | + | |
9591 | 9592 | | |
9592 | 9593 | | |
9593 | 9594 | | |
| |||
9597 | 9598 | | |
9598 | 9599 | | |
9599 | 9600 | | |
9600 | | - | |
| 9601 | + | |
9601 | 9602 | | |
9602 | 9603 | | |
9603 | 9604 | | |
| |||
Lines changed: 72 additions & 7 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
318 | 318 | | |
319 | 319 | | |
320 | 320 | | |
321 | | - | |
322 | | - | |
| 321 | + | |
| 322 | + | |
| 323 | + | |
| 324 | + | |
| 325 | + | |
| 326 | + | |
| 327 | + | |
| 328 | + | |
| 329 | + | |
| 330 | + | |
| 331 | + | |
| 332 | + | |
| 333 | + | |
| 334 | + | |
| 335 | + | |
| 336 | + | |
| 337 | + | |
| 338 | + | |
| 339 | + | |
| 340 | + | |
| 341 | + | |
| 342 | + | |
| 343 | + | |
| 344 | + | |
| 345 | + | |
| 346 | + | |
| 347 | + | |
323 | 348 | | |
324 | 349 | | |
| 350 | + | |
| 351 | + | |
| 352 | + | |
| 353 | + | |
| 354 | + | |
| 355 | + | |
| 356 | + | |
| 357 | + | |
| 358 | + | |
325 | 359 | | |
326 | 360 | | |
327 | 361 | | |
| |||
362 | 396 | | |
363 | 397 | | |
364 | 398 | | |
365 | | - | |
| 399 | + | |
366 | 400 | | |
367 | | - | |
| 401 | + | |
| 402 | + | |
| 403 | + | |
| 404 | + | |
| 405 | + | |
| 406 | + | |
| 407 | + | |
| 408 | + | |
| 409 | + | |
| 410 | + | |
| 411 | + | |
| 412 | + | |
| 413 | + | |
| 414 | + | |
| 415 | + | |
| 416 | + | |
| 417 | + | |
| 418 | + | |
| 419 | + | |
| 420 | + | |
| 421 | + | |
| 422 | + | |
| 423 | + | |
| 424 | + | |
| 425 | + | |
| 426 | + | |
368 | 427 | | |
369 | 428 | | |
370 | 429 | | |
| |||
519 | 578 | | |
520 | 579 | | |
521 | 580 | | |
522 | | - | |
| 581 | + | |
| 582 | + | |
523 | 583 | | |
524 | 584 | | |
525 | 585 | | |
| |||
528 | 588 | | |
529 | 589 | | |
530 | 590 | | |
| 591 | + | |
531 | 592 | | |
532 | 593 | | |
533 | 594 | | |
| |||
537 | 598 | | |
538 | 599 | | |
539 | 600 | | |
540 | | - | |
| 601 | + | |
541 | 602 | | |
542 | 603 | | |
543 | 604 | | |
544 | 605 | | |
545 | 606 | | |
546 | | - | |
| 607 | + | |
547 | 608 | | |
| 609 | + | |
| 610 | + | |
| 611 | + | |
| 612 | + | |
548 | 613 | | |
549 | 614 | | |
550 | 615 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
100 | 100 | | |
101 | 101 | | |
102 | 102 | | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
103 | 107 | | |
104 | 108 | | |
105 | 109 | | |
| |||
116 | 120 | | |
117 | 121 | | |
118 | 122 | | |
| 123 | + | |
119 | 124 | | |
120 | 125 | | |
121 | 126 | | |
| |||
134 | 139 | | |
135 | 140 | | |
136 | 141 | | |
| 142 | + | |
137 | 143 | | |
138 | 144 | | |
139 | | - | |
140 | | - | |
141 | | - | |
142 | | - | |
143 | | - | |
144 | | - | |
145 | | - | |
146 | | - | |
| 145 | + | |
| 146 | + | |
| 147 | + | |
| 148 | + | |
| 149 | + | |
| 150 | + | |
| 151 | + | |
| 152 | + | |
147 | 153 | | |
148 | 154 | | |
149 | 155 | | |
| |||
227 | 233 | | |
228 | 234 | | |
229 | 235 | | |
230 | | - | |
231 | | - | |
232 | | - | |
233 | | - | |
234 | | - | |
235 | | - | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
| 241 | + | |
| 242 | + | |
| 243 | + | |
236 | 244 | | |
237 | 245 | | |
238 | 246 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
114 | 114 | | |
115 | 115 | | |
116 | 116 | | |
| 117 | + | |
| 118 | + | |
| 119 | + | |
| 120 | + | |
117 | 121 | | |
118 | 122 | | |
119 | 123 | | |
| |||
135 | 139 | | |
136 | 140 | | |
137 | 141 | | |
| 142 | + | |
138 | 143 | | |
139 | 144 | | |
140 | 145 | | |
| |||
156 | 161 | | |
157 | 162 | | |
158 | 163 | | |
| 164 | + | |
159 | 165 | | |
160 | 166 | | |
161 | | - | |
162 | | - | |
163 | | - | |
164 | | - | |
165 | | - | |
166 | | - | |
| 167 | + | |
| 168 | + | |
| 169 | + | |
| 170 | + | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
167 | 174 | | |
168 | 175 | | |
169 | 176 | | |
| |||
242 | 249 | | |
243 | 250 | | |
244 | 251 | | |
245 | | - | |
246 | | - | |
247 | | - | |
248 | | - | |
249 | | - | |
250 | | - | |
251 | | - | |
| 252 | + | |
| 253 | + | |
| 254 | + | |
| 255 | + | |
| 256 | + | |
| 257 | + | |
| 258 | + | |
| 259 | + | |
| 260 | + | |
252 | 261 | | |
253 | 262 | | |
254 | 263 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
106 | 106 | | |
107 | 107 | | |
108 | 108 | | |
| 109 | + | |
| 110 | + | |
| 111 | + | |
| 112 | + | |
109 | 113 | | |
110 | 114 | | |
111 | 115 | | |
| |||
126 | 130 | | |
127 | 131 | | |
128 | 132 | | |
| 133 | + | |
129 | 134 | | |
130 | 135 | | |
131 | 136 | | |
| |||
146 | 151 | | |
147 | 152 | | |
148 | 153 | | |
| 154 | + | |
149 | 155 | | |
150 | 156 | | |
151 | | - | |
152 | | - | |
153 | | - | |
154 | | - | |
155 | | - | |
156 | | - | |
157 | | - | |
| 157 | + | |
| 158 | + | |
| 159 | + | |
| 160 | + | |
| 161 | + | |
| 162 | + | |
| 163 | + | |
| 164 | + | |
158 | 165 | | |
159 | 166 | | |
160 | 167 | | |
| |||
233 | 240 | | |
234 | 241 | | |
235 | 242 | | |
236 | | - | |
237 | | - | |
238 | | - | |
239 | | - | |
240 | | - | |
241 | | - | |
242 | | - | |
| 243 | + | |
| 244 | + | |
| 245 | + | |
| 246 | + | |
| 247 | + | |
| 248 | + | |
| 249 | + | |
| 250 | + | |
| 251 | + | |
243 | 252 | | |
244 | 253 | | |
245 | 254 | | |
| |||
0 commit comments