
Commit c9fcdfe

docs: add documentation
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
1 parent 75d1cf1 commit c9fcdfe

1 file changed

Lines changed: 177 additions & 1 deletion

File tree

docs/training.md

@@ -254,4 +254,180 @@ datasets:
    remove_columns: all
    fn_kwargs:
      conversation_column_name: "messages"
```

## Long Context Training

Long context training, for instance training on a 128k sequence length, can be performed using context parallelism.

### Model Architectures Supported

1. Hybrid attention dense models, e.g. granite-4.0-h-1b
1. Hybrid attention MoE models, e.g. ibm-granite/granite-4.0-h-small
1. SDPA attention dense models, e.g. granite-4.0-1b
1. SDPA attention MoE models, e.g. ibm-research/moe-7b-1b-active-shared-experts, Mixtral, etc.

### Parallelisms Supported with Context Parallel

1. Context Parallel + FSDP sharding
1. Context Parallel + FSDP sharding + Expert Parallel
1. Context Parallel + FSDP sharding + DP
1. Context Parallel + FSDP sharding + DP + Expert Parallel

### Usage

#### Enabling Context Parallel

FSDPv2 is required in order to use context parallel. FSDPv2 can be activated using the following accelerate config:

```
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
  fsdp_version: "2" # turn on v2 of FSDP
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_cpu_ram_efficient_loading: true
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
```

Then, context parallel can be activated using the following accelerate config:

```
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
  fsdp_version: "2" # turn on v2 of FSDP
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_cpu_ram_efficient_loading: true
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
use_parallelism_config: "true" # required to turn on the parallelism feature
parallelism_config_cp_size: 2 # context parallel degree
machine_rank: 0
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
```

When using any model with mamba attention, it is required to set the `--mcp` flag to the context parallel degree. Further, hybrid models that use a combination of mamba and SDPA attention should use both the `--mcp` flag and the `parallelism_config_cp_size` option, with both set to the same cp degree value.

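For illustration only, a minimal launch sketch: `cp_config.yaml` and `train.py` are placeholder names for the accelerate config above and the training entry point, and only the `--mcp` flag itself is taken from this documentation.

```
# Illustrative sketch: cp_config.yaml and train.py are placeholder names.
# --mcp sets the mamba context parallel degree; for hybrid mamba + SDPA models
# it should match parallelism_config_cp_size in the accelerate config (2 here).
accelerate launch --config_file cp_config.yaml train.py --mcp 2
```
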
#### Enabling Context Parallel with Data Parallel

Context parallel can be combined with data parallel using the `parallelism_config_dp_shard_size` parameter.

```
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
  fsdp_version: "2" # turn on v2 of FSDP
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_cpu_ram_efficient_loading: true
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
use_parallelism_config: "true" # required to turn on the parallelism feature
parallelism_config_cp_size: 2 # context parallel degree
parallelism_config_dp_shard_size: 4 # data parallel degree
machine_rank: 0
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
```

Note that the context parallel degree multiplied by the data parallel degree must equal the total number of GPUs being used. For example, with `num_processes: 8`, setting `parallelism_config_cp_size: 2` and `parallelism_config_dp_shard_size: 4` satisfies 2 × 4 = 8.

#### Enabling Mixed Precision

Mixed precision must be configured using the `fsdp_mixed_precision_policy` parameter only. Do not use direct flags such as `--bf16` or the `mixed_precision` accelerate config parameter.

```
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
  fsdp_version: "2" # turn on v2 of FSDP
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_cpu_ram_efficient_loading: true
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
  fsdp_mixed_precision_policy: "bf16" # mixed precision policy
use_parallelism_config: "true" # required to turn on the parallelism feature
parallelism_config_cp_size: 2 # context parallel degree
parallelism_config_dp_shard_size: 4 # data parallel degree
machine_rank: 0
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
```

#### Gradient Checkpointing

The optimal way to enable gradient checkpointing is through the accelerate config parameter `fsdp_activation_checkpointing`, as shown below:

```
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
  fsdp_version: "2" # turn on v2 of FSDP
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_cpu_ram_efficient_loading: true
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
  fsdp_mixed_precision_policy: "bf16" # mixed precision policy
  fsdp_activation_checkpointing: true
use_parallelism_config: "true" # required to turn on the parallelism feature
parallelism_config_cp_size: 2 # context parallel degree
parallelism_config_dp_shard_size: 4 # data parallel degree
machine_rank: 0
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
```

#### Enabling Context Parallel with Data Parallel and Expert Parallel

For MoE models, expert parallel with MoE kernels can be enabled using the `--fast_moe` flag along with context and data parallelism. The expert parallel degree is independent of the context parallel degree, so it can be used as described [here](./tuning-techniques.md#fms-acceleration).

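For illustration, a hedged launch sketch using the same placeholder names as above (`cp_config.yaml`, `train.py`) and assuming `--fast_moe` takes the expert parallel degree as described in the linked tuning techniques documentation:

```
# Illustrative sketch: cp_config.yaml and train.py are placeholder names.
# --fast_moe enables expert parallel with MoE kernels; its value (the expert
# parallel degree) is chosen independently of the context parallel degree.
accelerate launch --config_file cp_config.yaml train.py --mcp 2 --fast_moe 2
```
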
### Recommendations

1. Keeping context parallelism within a node is usually optimal unless extremely long sequences, such as 256k, are needed. Given that, it is optimal to choose a cp degree that is a multiple of 2, starting from 2 and going up to 8.
2. The data parallel degree multiplied by the context parallel degree should equal the total number of GPUs being used.
3. The context parallel degree determines the number of chunks a sequence is divided into and distributed across GPUs; therefore, it should be chosen as the minimum needed to accommodate the sequence length.

### Known Limitations

1. Load balancing is removed given the limited support in the mamba cp implementation. This could lead to potential throughput drops for training runs that use a causal mask.
2. Padding-free and flash attention are not supported.
