Skip to content

Commit 942d788

Browse files
support ddp (#26)
1 parent 13d31cf commit 942d788

29 files changed

Lines changed: 68 additions & 370 deletions

llmc/__main__.py

Lines changed: 13 additions & 1 deletion
@@ -9,6 +9,7 @@
 import yaml
 from easydict import EasyDict
 from loguru import logger
+from torch.distributed import destroy_process_group, init_process_group

 from llmc.compression.quantization import *
 from llmc.compression.sparsification import *
@@ -111,20 +112,29 @@ def main(config):
     llmc_start_time = time.time()
     parser = argparse.ArgumentParser()
     parser.add_argument('--config', type=str, required=True)
+    parser.add_argument('--task_id', type=str, required=True)
     args = parser.parse_args()

     with open(args.config, 'r') as file:
         config = yaml.safe_load(file)
     config = EasyDict(config)

+    init_process_group(backend='nccl')
+    torch.cuda.set_device(int(os.environ['LOCAL_RANK']))
+
+    if int(os.environ['RANK']) != 0:
+        logger.remove()
+
     check_config(config)

     logger.info(f'args: {args}')
     logger.info(f'config:\n{json.dumps(config, ensure_ascii=False, indent=4)}')

     print_important_package_version()

-    seed_all(config.base.seed)
+    logger.info(f'WORLD_SIZE : {int(os.environ["WORLD_SIZE"])}')
+
+    seed_all(config.base.seed + int(os.environ['RANK']))

     # mkdirs
     if 'save' in config:
@@ -149,6 +159,8 @@ def main(config):

     main(config)

+    destroy_process_group()
+
     llmc_end_time = time.time()
     llmc_duration_time = llmc_end_time - llmc_start_time
     logger.info(f'llmc_duration_time: {llmc_duration_time} s')

llmc/compression/quantization/awq.py

Lines changed: 4 additions & 0 deletions
@@ -1,6 +1,8 @@
 import gc
+import os

 import torch
+import torch.distributed as dist
 import torch.nn as nn
 from loguru import logger

@@ -136,6 +138,8 @@ def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs):
                 best_error = loss_mean
                 best_scales = scales_mean
         best_scales = best_scales.view(-1)
+        dist.all_reduce(best_scales, op=dist.ReduceOp.SUM)
+        best_scales /= int(os.environ['WORLD_SIZE'])
         del org_out_dict
         gc.collect()
         torch.cuda.empty_cache()

llmc/compression/quantization/base_blockwise_quantization.py

Lines changed: 10 additions & 0 deletions
@@ -1,10 +1,12 @@
 import functools
 import gc
 import json
+import os
 from collections import defaultdict
 from functools import partial

 import torch
+import torch.distributed as dist
 import torch.nn as nn
 from loguru import logger


@@ -487,6 +489,12 @@ def auto_clip(self, block, input_feat, n_sample_token):
                 n_sample_token=n_sample_token,
             )

+            dist.all_reduce(max_val, op=dist.ReduceOp.SUM)
+            max_val /= int(os.environ['WORLD_SIZE'])
+
+            dist.all_reduce(min_val, op=dist.ReduceOp.SUM)
+            min_val /= int(os.environ['WORLD_SIZE'])
+
             self.apply_clip(m, min_val, max_val, n)

     @torch.no_grad()
@@ -802,6 +810,8 @@ def contiguous_params(self):

     @torch.no_grad()
     def save_model(self, path):
+        if int(os.environ['RANK']) != 0:
+            return
         if self.online_rotate:
             self.contiguous_params()
         if self.config.model.type == 'Llava':

llmc/data/dataset/base_dataset.py

Lines changed: 5 additions & 0 deletions
@@ -1,3 +1,4 @@
+import os
 from abc import ABCMeta

 import torch
@@ -84,6 +85,10 @@ def get_calib_samples(self):

     def get_calib_dataset(self):
         samples = self.get_calib_samples()
+        logger.info(f'len(samples) all : {len(samples)}')
+        assert len(samples) % int(os.environ['WORLD_SIZE']) == 0
+        samples = samples[int(os.environ['RANK'])::int(os.environ['WORLD_SIZE'])]
+        logger.info(f'len(samples) rank : {len(samples)}')
         calib_samples = []
         if self.calib_bs < 0:
             batch = torch.cat(samples, dim=0)

scripts/export_rtn_llama.sh

Lines changed: 0 additions & 9 deletions
This file was deleted.

scripts/run_adadim_llama.sh

Lines changed: 0 additions & 15 deletions
This file was deleted.

scripts/run_awq_llama.sh

Lines changed: 0 additions & 16 deletions
This file was deleted.

scripts/run_dgq_llama.sh

Lines changed: 0 additions & 16 deletions
This file was deleted.

scripts/run_gptq_llama.sh

Lines changed: 0 additions & 15 deletions
This file was deleted.

scripts/run_gptq_owq_llama.sh

Lines changed: 0 additions & 15 deletions
This file was deleted.

0 commit comments

Comments (0)