
Commit 8e0a8f8

Merge pull request #63 from tedhtchang/Enable-pylint
Enable pylint in the github workflow
2 parents a93e3bc + a6cfa6a commit 8e0a8f8

10 files changed: 68 additions & 54 deletions


.github/workflows/format.yml

Lines changed: 2 additions & 0 deletions
@@ -35,4 +35,6 @@ jobs:
           python -m pip install -r setup_requirements.txt
       - name: Check Formatting
         run: tox -e fmt
+      - name: Run pylint
+        run: tox -e lint

.pylintrc

Lines changed: 3 additions & 1 deletion
@@ -443,7 +443,9 @@ disable=raw-checker-failed,
         attribute-defined-outside-init,
         abstract-method,
         pointless-statement,
-        wrong-import-order
+        wrong-import-order,
+        duplicate-code,
+        unbalanced-tuple-unpacking

 # Enable the message, report, category or checker with the given id(s). You can
 # either give multiple identifier separated by comma (,) or put this option
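For context: pylint checks can be silenced at two scopes. Rules that are noisy across the whole project, like duplicate-code above, go in the .pylintrc disable list; one-off exceptions are silenced with a trailing comment on the offending line, as several files below do. A minimal sketch with a hypothetical function:

# Silenced project-wide via .pylintrc: duplicate-code, unbalanced-tuple-unpacking.
# Silenced for a single line via a trailing comment:
def main(**kwargs):  # pylint: disable=unused-argument
    """Accepts **kwargs for interface compatibility; intentionally unused."""
    return 0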

build/launch_training.py

Lines changed: 18 additions & 8 deletions
@@ -48,7 +48,7 @@ def get_highest_checkpoint(dir_path):
     for curr_dir in os.listdir(dir_path):
         if curr_dir.startswith("checkpoint"):
             if checkpoint_dir:
-                curr_dir_num = int(checkpoint_dir.split("-")[-1])
+                curr_dir_num = int(checkpoint_dir.rsplit("-", maxsplit=1)[-1])
                 new_dir_num = int(curr_dir.split("-")[-1])
                 if new_dir_num > curr_dir_num:
                     checkpoint_dir = curr_dir
@@ -87,13 +87,13 @@ def main():
         ) = parser.parse_json_file(json_path, allow_extra_keys=True)

         contents = ""
-        with open(json_path, "r") as f:
+        with open(json_path, "r", encoding="utf-8") as f:
             contents = json.load(f)
         peft_method_parsed = contents.get("peft_method")
-        logging.debug(f"Input params parsed: {contents}")
+        logging.debug("Input params parsed: %s", contents)
     elif json_env_var:
         job_config_dict = txt_to_obj(json_env_var)
-        logging.debug(f"Input params parsed: {job_config_dict}")
+        logging.debug("Input params parsed: %s", job_config_dict)

         (
             model_args,
@@ -106,7 +106,8 @@ def main():
         peft_method_parsed = job_config_dict.get("peft_method")
     else:
         raise ValueError(
-            "Must set environment variable 'SFT_TRAINER_CONFIG_JSON_PATH' or 'SFT_TRAINER_CONFIG_JSON_ENV_VAR'."
+            "Must set environment variable 'SFT_TRAINER_CONFIG_JSON_PATH' \
+            or 'SFT_TRAINER_CONFIG_JSON_ENV_VAR'."
         )

     tune_config = None
@@ -118,7 +119,12 @@ def main():
         tune_config = prompt_tuning_config

     logging.debug(
-        f"Parameters used to launch training: model_args {model_args}, data_args {data_args}, training_args {training_args}, tune_config {tune_config}"
+        "Parameters used to launch training: \
+        model_args %s, data_args %s, training_args %s, tune_config %s",
+        model_args,
+        data_args,
+        training_args,
+        tune_config,
     )

     original_output_dir = training_args.output_dir
@@ -138,7 +144,9 @@ def main():
         )

         logging.info(
-            f"Merging lora tuned checkpoint {lora_checkpoint_dir} with base model into output path: {export_path}"
+            "Merging lora tuned checkpoint %s with base model into output path: %s",
+            lora_checkpoint_dir,
+            export_path,
         )

         create_merged_model(
@@ -151,7 +159,9 @@ def main():
         # copy last checkpoint into mounted output dir
         pt_checkpoint_dir = get_highest_checkpoint(training_args.output_dir)
         logging.info(
-            f"Copying last checkpoint {pt_checkpoint_dir} into output dir {original_output_dir}"
+            "Copying last checkpoint %s into output dir %s",
+            pt_checkpoint_dir,
+            original_output_dir,
         )
         shutil.copytree(
             os.path.join(training_args.output_dir, pt_checkpoint_dir),
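The split-to-rsplit change above addresses pylint's use-maxsplit-arg check: when only the last field is needed, rsplit with maxsplit=1 performs a single split from the right instead of splitting at every separator. A minimal sketch with a hypothetical directory name:

checkpoint_dir = "checkpoint-epoch-1500"  # hypothetical name with several dashes
last = checkpoint_dir.split("-")[-1]               # "1500", but splits at every dash
last = checkpoint_dir.rsplit("-", maxsplit=1)[-1]  # "1500", splits exactly once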

scripts/run_inference.py

Lines changed: 6 additions & 5 deletions
@@ -91,11 +91,11 @@ def _apply_config_changes(self, overrides: dict) -> dict:
         # If we have no overrides, this context manager is a noop; no need to do anything
         if not overrides:
             return {}
-        with open(self.config_path, "r") as config_file:
+        with open(self.config_path, "r", encoding="utf-8") as config_file:
             adapter_config = json.load(config_file)
         overridden_values = self._get_old_config_values(adapter_config, overrides)
         adapter_config = {**adapter_config, **overrides}
-        with open(self.config_path, "w") as config_file:
+        with open(self.config_path, "w", encoding="utf-8") as config_file:
             json.dump(adapter_config, config_file, indent=4)
         return overridden_values
@@ -227,7 +227,8 @@ def main():
     )
     parser.add_argument(
         "--base_model_name_or_path",
-        help="Override for base model to be used for non-merged models [default: value in model adapter_config.json]",
+        help="Override for base model to be used for non-merged models \
+            [default: value in model adapter_config.json]",
         default=None,
     )
     parser.add_argument(
@@ -257,7 +258,7 @@ def main():
     if args.text:
         texts = [args.text]
     else:
-        with open(args.text_file, "r") as text_file:
+        with open(args.text_file, "r", encoding="utf-8") as text_file:
             texts = [line.strip() for line in text_file.readlines()]

     # TODO: we should add batch inference support
@@ -270,7 +271,7 @@ def main():
     ]

     # Export the results to a file
-    with open(args.out_file, "w") as out_file:
+    with open(args.out_file, "w", encoding="utf-8") as out_file:
         json.dump(results, out_file, sort_keys=True, indent=4)

     print(f"Exported results to: {args.out_file}")

tox.ini

Lines changed: 2 additions & 1 deletion
@@ -17,5 +17,6 @@ allowlist_externals = ./scripts/fmt.sh
 [testenv:lint]
 description = lint with pylint
 deps = pylint>=2.16.2,<=3.1.0
-commands = pylint tuning scripts/*.py
+       -r requirements.txt
+commands = pylint tuning scripts/*.py build/*.py
 allowlist_externals = pylint

tuning/config/configs.py

Lines changed: 3 additions & 2 deletions
@@ -14,7 +14,7 @@

 # Standard
 from dataclasses import dataclass, field
-from typing import Dict, Optional, Union
+from typing import Optional, Union

 # Third Party
 import torch
@@ -64,7 +64,8 @@ class TrainingArguments(transformers.TrainingArguments):
     model_max_length: int = field(
         default=DEFAULT_CONTEXT_LENGTH,
         metadata={
-            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+            "help": "Maximum sequence length. Sequences will be right padded \
+            (and possibly truncated)."
         },
     )
     packing: bool = field(
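One caveat about the backslash-wrapped help strings introduced here and in tuning/config/peft_config.py below: a line continuation inside a string literal keeps the next line's leading indentation in the string's value, so the help text gains a run of spaces. Implicit concatenation of adjacent literals avoids this; a minimal sketch:

wrapped = "Maximum sequence length. Sequences will be right padded \
    (and possibly truncated)."
concatenated = (
    "Maximum sequence length. Sequences will be right padded "
    "(and possibly truncated)."
)
print(wrapped == concatenated)  # False: 'wrapped' embeds the indentation spaces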

tuning/config/peft_config.py

Lines changed: 4 additions & 2 deletions
@@ -24,8 +24,10 @@ class LoraConfig:
     target_modules: List[str] = field(
         default_factory=lambda: ["q_proj", "v_proj"],
         metadata={
-            "help": "The names of the modules to apply LORA to. LORA selects modules which either completely match or "
-            'end with one of the strings. If the value is ["all-linear"], then LORA selects all linear and Conv1D '
+            "help": "The names of the modules to apply LORA to. LORA selects modules which either \
+            completely match or "
+            'end with one of the strings. If the value is ["all-linear"], \
+            then LORA selects all linear and Conv1D '
             "modules except for the output layer."
         },
     )

tuning/data/tokenizer_data_utils.py

Lines changed: 1 addition & 9 deletions
@@ -13,19 +13,11 @@
 # limitations under the License.

 # Standard
-from typing import Dict, Sequence
-import copy
-import json
-import logging
+from typing import Dict

 # Third Party
-from torch.utils.data import Dataset
-import torch
 import transformers

-# Local
-from tuning.config import configs
-

 def tokenizer_and_embedding_resize(
     special_tokens_dict: Dict,

tuning/sft_trainer.py

Lines changed: 27 additions & 23 deletions
@@ -17,6 +17,7 @@
 from typing import Optional, Union
 import json
 import os
+import sys

 # Third Party
 from peft.utils.other import fsdp_auto_wrap_policy
@@ -94,15 +95,15 @@ def _track_loss(self, loss_key, log_file, logs, state):
             return

         # append the current log to the jsonl file
-        with open(log_file, "a") as f:
+        with open(log_file, "a", encoding="utf-8") as f:
             f.write(f"{json.dumps(log_obj, sort_keys=True)}\n")


 def train(
     model_args: configs.ModelArguments,
     data_args: configs.DataArguments,
     train_args: configs.TrainingArguments,
-    peft_config: Optional[
+    peft_config: Optional[  # pylint: disable=redefined-outer-name
         Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]
     ] = None,
 ):
@@ -154,9 +155,7 @@ def train(
     )

     # TODO: understand if we need to hardcode these here or just use defaults in model
-    if isinstance(tokenizer, LlamaTokenizer) or isinstance(
-        tokenizer, LlamaTokenizerFast
-    ):
+    if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
         tokenizer.add_special_tokens(
             {
                 "bos_token": "<s>",
@@ -165,33 +164,36 @@ def train(
                 "pad_token": "<pad>",
             }
         )
-    elif isinstance(tokenizer, GPTNeoXTokenizerFast) or isinstance(
-        tokenizer, GPT2Tokenizer
-    ):
+    elif isinstance(tokenizer, (GPT2Tokenizer, GPTNeoXTokenizerFast)):
         tokenizer.add_special_tokens(
             {
                 "pad_token": "<pad>",
             }
         )

-    """TODO: near term - how response template ids are parsed out needs to be cleaned.
-    The [2:] here applies if response template has \n prefix, it is needed to strip \n, otherwise template is not found.
-    We will create issue to clean this out after we discuss data formats and collators we will support
-    """
+    # TODO: near term - how response template ids are parsed out needs to be cleaned.
+    # The [2:] here applies if response template has \n prefix, it is needed to strip \n,
+    # otherwise template is not found. We will create issue to clean this out after we discuss
+    # data formats and collators we will support.
     response_template_ids = tokenizer.encode(
         data_args.response_template, add_special_tokens=False
     )[2:]
-    # TODO: This is actually max_seq_length and not model_max_length. we should not override model_max_length
-    # as in current main. We need to change name of this parameter we expose to users.
+    # TODO: This is actually max_seq_length and not model_max_length. we should not override
+    # model_max_length as in current main. We need to change name of this parameter we expose
+    # to users.
     model_max_length = min(train_args.model_max_length, tokenizer.model_max_length)
-    logger.info(f"Model max length {model_max_length}")
+    logger.info("Model max length %s", model_max_length)
     if train_args.model_max_length > tokenizer.model_max_length:
         logger.warning(
-            f"model_max_length {train_args.model_max_length} exceeds tokenizer.model_max_length {tokenizer.model_max_length}, using tokenizer.model_max_length {tokenizer.model_max_length}"
+            "model_max_length %s exceeds tokenizer.model_max_length \
+            %s, using tokenizer.model_max_length %s",
+            train_args.model_max_length,
+            tokenizer.model_max_length,
+            tokenizer.model_max_length,
        )

     # TODO: we need to change this, perhaps follow what open instruct does?
-    special_tokens_dict = dict()
+    special_tokens_dict = {}
     if tokenizer.pad_token is None:
         logger.warning("PAD token set to default, missing in tokenizer")
         special_tokens_dict["pad_token"] = configs.DEFAULT_PAD_TOKEN
@@ -219,19 +221,21 @@ def train(
     if data_args.validation_data_path:
         data_files["validation"] = data_args.validation_data_path

-    format_dataset = lambda example: {
+    format_dataset = lambda example: {  # pylint: disable=unnecessary-lambda-assignment
         f"{data_args.dataset_text_field}": example[f"{data_args.dataset_text_field}"]
         + tokenizer.eos_token
     }

     json_dataset = datasets.load_dataset("json", data_files=data_files)
     formatted_train_dataset = json_dataset["train"].map(format_dataset)
-    logger.info(f"Training dataset length is {len(formatted_train_dataset)}")
+    logger.info("Training dataset length is %s", len(formatted_train_dataset))

     formatted_validation_dataset = None
     if data_args.validation_data_path:
         formatted_validation_dataset = json_dataset["validation"].map(format_dataset)
-        logger.info(f"Validation dataset length is {len(formatted_validation_dataset)}")
+        logger.info(
+            "Validation dataset length is %s", len(formatted_validation_dataset)
+        )

     aim_callback = get_aimstack_callback()
     file_logger_callback = FileLoggingCallback(logger)
@@ -248,13 +252,13 @@ def train(
         logger.error(
             "Error, response template is None, needs to be set for training"
         )
-        exit(-1)
+        sys.exit(-1)

     if data_args.dataset_text_field is None:
         logger.error(
             "Error, dataset_text_field is None, needs to be set for training"
         )
-        exit(-1)
+        sys.exit(-1)

     data_collator = DataCollatorForCompletionOnlyLM(
         response_template_ids,
@@ -284,7 +288,7 @@ def train(
     trainer.train()


-def main(**kwargs):
+def main(**kwargs):  # pylint: disable=unused-argument
     parser = transformers.HfArgumentParser(
         dataclass_types=(
             configs.ModelArguments,
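The f-string-to-%s rewrites throughout this file and build/launch_training.py follow pylint's logging-fstring-interpolation check (W1203): passing the format string and its arguments separately defers interpolation until the logging framework knows the record will actually be emitted. A minimal sketch:

import logging

logger = logging.getLogger(__name__)
model_max_length = 4096  # hypothetical value

# Eager: the f-string is formatted even when DEBUG records are filtered out.
logger.debug(f"Model max length {model_max_length}")
# Lazy: %-interpolation runs only if a handler actually emits the record.
logger.debug("Model max length %s", model_max_length)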

tuning/utils/merge_model_utils.py

Lines changed: 2 additions & 3 deletions
@@ -14,7 +14,6 @@

 # Standard
 from typing import Union
-import argparse
 import json
 import os

@@ -41,7 +40,7 @@ def create_merged_model(
     References:
     - https://github.com/huggingface/peft/issues/1040
     - https://github.com/huggingface/peft/issues/280#issuecomment-1500805831
-    - https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraModel.add_weighted_adapter
+    - https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraModel.add_weighted_adapter # pylint: disable=line-too-long

     Args:
     checkpoint_model: Union[str, list[str]]
@@ -96,7 +95,7 @@ def fetch_base_model_from_checkpoint(checkpoint_model: str) -> str:
     if not os.path.isfile(adapter_config):
         raise FileNotFoundError("Unable to locate adapter config to infer base model!")

-    with open(adapter_config, "r") as cfg:
+    with open(adapter_config, "r", encoding="utf-8") as cfg:
         adapter_dict = json.load(cfg)
     if "base_model_name_or_path" not in adapter_dict:
         raise KeyError(
