Commit a54e8cd

Enable pylint in the github workflow
Signed-off-by: ted chang <htchang@us.ibm.com>
1 parent 0e60ecd commit a54e8cd

9 files changed: 52 additions & 49 deletions

.github/workflows/format.yml

Lines changed: 2 additions & 0 deletions
@@ -35,4 +35,6 @@ jobs:
           python -m pip install -r setup_requirements.txt
       - name: Check Formatting
         run: tox -e fmt
+      - name: Run pylint
+        run: tox -e lint

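For reference, the step added above just invokes the tox `lint` environment, whose command is `pylint tuning scripts/*.py` (see the tox.ini change below). The same check can be reproduced outside tox through pylint's programmatic entry point; this is a minimal sketch, assuming pylint and the project's requirements are installed in the current environment, with one explicit script path standing in for the shell glob.

# Minimal local sketch of the check the new CI step runs; assumes pylint is
# installed and the repository root is the working directory.
from pylint.lint import Run

# Mirrors `commands = pylint tuning scripts/*.py` from tox.ini; the single
# script listed here is illustrative, not the expanded glob.
result = Run(["tuning", "scripts/run_inference.py"], exit=False)
print("pylint message status:", result.linter.msg_status)
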
scripts/run_inference.py

Lines changed: 6 additions & 5 deletions
@@ -77,11 +77,11 @@ def _apply_config_changes(self, overrides: dict) -> dict:
         # If we have no overrides, this context manager is a noop; no need to do anything
         if not overrides:
             return {}
-        with open(self.config_path, "r") as config_file:
+        with open(self.config_path, "r", encoding="utf-8") as config_file:
             adapter_config = json.load(config_file)
         overridden_values = self._get_old_config_values(adapter_config, overrides)
         adapter_config = {**adapter_config, **overrides}
-        with open(self.config_path, "w") as config_file:
+        with open(self.config_path, "w", encoding="utf-8") as config_file:
             json.dump(adapter_config, config_file, indent=4)
         return overridden_values

@@ -213,7 +213,8 @@ def main():
     )
     parser.add_argument(
         "--base_model_name_or_path",
-        help="Override for base model to be used for non-merged models [default: value in model adapter_config.json]",
+        help="Override for base model to be used for non-merged models \
+            [default: value in model adapter_config.json]",
         default=None,
     )
     parser.add_argument(

@@ -243,7 +244,7 @@ def main():
     if args.text:
         texts = [args.text]
     else:
-        with open(args.text_file, "r") as text_file:
+        with open(args.text_file, "r", encoding="utf-8") as text_file:
             texts = [line.strip() for line in text_file.readlines()]

     # TODO: we should add batch inference support

@@ -256,7 +257,7 @@ def main():
     ]

     # Export the results to a file
-    with open(args.out_file, "w") as out_file:
+    with open(args.out_file, "w", encoding="utf-8") as out_file:
         json.dump(results, out_file, sort_keys=True, indent=4)

     print(f"Exported results to: {args.out_file}")

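The `encoding="utf-8"` arguments added throughout this file address pylint's `unspecified-encoding` check (W1514), which flags `open()` calls on text files that rely on the platform default encoding. A minimal sketch of the pattern, using a hypothetical file name:

import json

# Hypothetical file used only for illustration.
config_path = "adapter_config.json"

# Without an explicit encoding, open() falls back to the locale's preferred
# encoding, which varies across platforms; that is what W1514 warns about.
with open(config_path, "w", encoding="utf-8") as config_file:
    json.dump({"base_model_name_or_path": "example-model"}, config_file, indent=4)

with open(config_path, "r", encoding="utf-8") as config_file:
    adapter_config = json.load(config_file)

print(adapter_config)
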
tox.ini

Lines changed: 1 addition & 0 deletions
@@ -17,5 +17,6 @@ allowlist_externals = ./scripts/fmt.sh
 [testenv:lint]
 description = lint with pylint
 deps = pylint>=2.16.2,<=3.1.0
+       -r requirements.txt
 commands = pylint tuning scripts/*.py
 allowlist_externals = pylint

tuning/config/configs.py

Lines changed: 3 additions & 2 deletions
@@ -1,6 +1,6 @@
 # Standard
 from dataclasses import dataclass, field
-from typing import Dict, Optional, Union
+from typing import Optional, Union

 # Third Party
 import torch

@@ -50,7 +50,8 @@ class TrainingArguments(transformers.TrainingArguments):
     model_max_length: int = field(
         default=DEFAULT_CONTEXT_LENGTH,
         metadata={
-            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+            "help": "Maximum sequence length. Sequences will be right padded \
+            (and possibly truncated)."
         },
     )
     packing: bool = field(

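Wrapping the help string with a backslash inside the literal, as above, satisfies pylint's line-length check but makes the continuation line's leading whitespace part of the metadata value. Implicit string concatenation is a common alternative that avoids the embedded spaces; a small sketch of that alternative (not what this commit does):

# The backslash continuation keeps the next line's indentation inside the value.
wrapped = "Maximum sequence length. Sequences will be right padded \
    (and possibly truncated)."

# Adjacent literals are joined at compile time, so the value stays single-spaced.
joined = (
    "Maximum sequence length. Sequences will be right padded "
    "(and possibly truncated)."
)

print(repr(wrapped))  # run of spaces before "(and possibly truncated)"
print(repr(joined))   # clean single space
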
tuning/config/peft_config.py

Lines changed: 4 additions & 2 deletions
@@ -10,8 +10,10 @@ class LoraConfig:
     target_modules: List[str] = field(
         default_factory=lambda: ["q_proj", "v_proj"],
         metadata={
-            "help": "The names of the modules to apply LORA to. LORA selects modules which either completely match or "
-            'end with one of the strings. If the value is ["all-linear"], then LORA selects all linear and Conv1D '
+            "help": "The names of the modules to apply LORA to. LORA selects modules which either \
+            completely match or "
+            'end with one of the strings. If the value is ["all-linear"], \
+            then LORA selects all linear and Conv1D '
             "modules except for the output layer."
         },
     )

tuning/data/tokenizer_data_utils.py

Lines changed: 1 addition & 9 deletions
@@ -1,17 +1,9 @@
 # Standard
-from typing import Dict, Sequence
-import copy
-import json
-import logging
+from typing import Dict

 # Third Party
-from torch.utils.data import Dataset
-import torch
 import transformers

-# Local
-from tuning.config import configs
-

 def tokenizer_and_embedding_resize(
     special_tokens_dict: Dict,

tuning/sft_trainer.py

Lines changed: 32 additions & 28 deletions
@@ -3,6 +3,7 @@
 from typing import Optional, Union
 import json
 import os
+import sys

 # Third Party
 from peft.utils.other import fsdp_auto_wrap_policy

@@ -80,15 +81,15 @@ def _track_loss(self, loss_key, log_file, logs, state):
             return

         # append the current log to the jsonl file
-        with open(log_file, "a") as f:
+        with open(log_file, "a", encoding="utf-8") as f:
             f.write(f"{json.dumps(log_obj, sort_keys=True)}\n")


 def train(
     model_args: configs.ModelArguments,
     data_args: configs.DataArguments,
     train_args: configs.TrainingArguments,
-    peft_config: Optional[
+    peft_configs: Optional[
         Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]
     ] = None,
 ):

@@ -98,7 +99,7 @@ def train(
         model_args: tuning.config.configs.ModelArguments
         data_args: tuning.config.configs.DataArguments
         train_args: tuning.config.configs.TrainingArguments
-        peft_config: peft_config.LoraConfig for Lora tuning | \
+        peft_configs: peft_config.LoraConfig for Lora tuning | \
             peft_config.PromptTuningConfig for prompt tuning | \
             None for fine tuning
             The peft configuration to pass to trainer

@@ -130,7 +131,7 @@ def train(
         use_flash_attention_2=model_args.use_flash_attn,
     )

-    peft_config = get_hf_peft_config(task_type, peft_config)
+    peft_configs = get_hf_peft_config(task_type, peft_configs)

     model.gradient_checkpointing_enable()
@@ -140,9 +141,7 @@ def train(
     )

     # TODO: understand if we need to hardcode these here or just use defaults in model
-    if isinstance(tokenizer, LlamaTokenizer) or isinstance(
-        tokenizer, LlamaTokenizerFast
-    ):
+    if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
         tokenizer.add_special_tokens(
             {
                 "bos_token": "<s>",
@@ -151,33 +150,36 @@ def train(
                 "pad_token": "<pad>",
             }
         )
-    elif isinstance(tokenizer, GPTNeoXTokenizerFast) or isinstance(
-        tokenizer, GPT2Tokenizer
-    ):
+    elif isinstance(tokenizer, (GPT2Tokenizer, GPTNeoXTokenizerFast)):
         tokenizer.add_special_tokens(
             {
                 "pad_token": "<pad>",
             }
         )

-    """TODO: near term - how response template ids are parsed out needs to be cleaned.
-    The [2:] here applies if response template has \n prefix, it is needed to strip \n, otherwise template is not found.
-    We will create issue to clean this out after we discuss data formats and collators we will support
-    """
+    # TODO: near term - how response template ids are parsed out needs to be cleaned.
+    # The [2:] here applies if response template has \n prefix, it is needed to strip \n,
+    # otherwise template is not found. We will create issue to clean this out after we discuss
+    # data formats and collators we will support.
     response_template_ids = tokenizer.encode(
         data_args.response_template, add_special_tokens=False
     )[2:]
-    # TODO: This is actually max_seq_length and not model_max_length. we should not override model_max_length
-    # as in current main. We need to change name of this parameter we expose to users.
+    # TODO: This is actually max_seq_length and not model_max_length. we should not override
+    # model_max_length as in current main. We need to change name of this parameter we expose
+    # to users.
     model_max_length = min(train_args.model_max_length, tokenizer.model_max_length)
-    logger.info(f"Model max length {model_max_length}")
+    logger.info("Model max length %s", model_max_length)
     if train_args.model_max_length > tokenizer.model_max_length:
         logger.warning(
-            f"model_max_length {train_args.model_max_length} exceeds tokenizer.model_max_length {tokenizer.model_max_length}, using tokenizer.model_max_length {tokenizer.model_max_length}"
+            "model_max_length %s exceeds tokenizer.model_max_length \
+            %s, using tokenizer.model_max_length %s",
+            train_args.model_max_length,
+            tokenizer.model_max_length,
+            tokenizer.model_max_length,
         )

     # TODO: we need to change this, perhaps follow what open instruct does?
-    special_tokens_dict = dict()
+    special_tokens_dict = {}
     if tokenizer.pad_token is None:
         logger.warning("PAD token set to default, missing in tokenizer")
         special_tokens_dict["pad_token"] = configs.DEFAULT_PAD_TOKEN
@@ -205,19 +207,21 @@ def train(
     if data_args.validation_data_path:
         data_files["validation"] = data_args.validation_data_path

-    format_dataset = lambda example: {
+    format_dataset = lambda example: { # pylint: disable=unnecessary-lambda-assignment
         f"{data_args.dataset_text_field}": example[f"{data_args.dataset_text_field}"]
         + tokenizer.eos_token
     }

     json_dataset = datasets.load_dataset("json", data_files=data_files)
     formatted_train_dataset = json_dataset["train"].map(format_dataset)
-    logger.info(f"Training dataset length is {len(formatted_train_dataset)}")
+    logger.info("Training dataset length is %s", len(formatted_train_dataset))

     formatted_validation_dataset = None
     if data_args.validation_data_path:
         formatted_validation_dataset = json_dataset["validation"].map(format_dataset)
-        logger.info(f"Validation dataset length is {len(formatted_validation_dataset)}")
+        logger.info(
+            "Validation dataset length is %s", len(formatted_validation_dataset)
+        )

     aim_callback = get_aimstack_callback()
     file_logger_callback = FileLoggingCallback(logger)

@@ -234,13 +238,13 @@ def train(
         logger.error(
             "Error, response template is None, needs to be set for training"
         )
-        exit(-1)
+        sys.exit(-1)

     if data_args.dataset_text_field is None:
         logger.error(
             "Error, dataset_text_field is None, needs to be set for training"
         )
-        exit(-1)
+        sys.exit(-1)

     data_collator = DataCollatorForCompletionOnlyLM(
         response_template_ids,

@@ -260,17 +264,17 @@ def train(
         args=train_args,
         max_seq_length=model_max_length,
         callbacks=callbacks,
-        peft_config=peft_config,
+        peft_config=peft_configs,
     )

-    if run_distributed and peft_config is not None:
+    if run_distributed and peft_configs is not None:
         trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(
             model
         )
     trainer.train()


-def main(**kwargs):
+def main(**kwargs): # pylint: disable=unused-argument
     parser = transformers.HfArgumentParser(
         dataclass_types=(
             configs.ModelArguments,

@@ -286,7 +290,7 @@ def main(**kwargs):
         choices=["pt", "lora", None, "none"],
         default="pt",
     )
-    (
+    ( # pylint: disable=unbalanced-tuple-unpacking
         model_args,
         data_args,
         training_args,

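Most of the logging changes in this file follow pylint's `logging-fstring-interpolation` (W1203) guidance: pass values as arguments so the `%s` interpolation is deferred until a handler actually emits the record, instead of building an f-string on every call. A minimal, self-contained sketch of the lazy form, with illustrative values:

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

model_max_length = 4096       # illustrative values only
tokenizer_max_length = 2048

# Deferred interpolation: the message is only formatted if the record is emitted.
logger.info("Model max length %s", model_max_length)

if model_max_length > tokenizer_max_length:
    logger.warning(
        "model_max_length %s exceeds tokenizer.model_max_length %s, "
        "using tokenizer.model_max_length %s",
        model_max_length,
        tokenizer_max_length,
        tokenizer_max_length,
    )
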
tuning/utils/data_type_utils.py

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 # Standard
 from typing import Union
+import sys

 # Third Party
 from transformers.utils import logging

tuning/utils/merge_model_utils.py

Lines changed: 2 additions & 3 deletions
@@ -1,6 +1,5 @@
 # Standard
 from typing import Union
-import argparse
 import json
 import os

@@ -27,7 +26,7 @@ def create_merged_model(
     References:
     - https://github.com/huggingface/peft/issues/1040
     - https://github.com/huggingface/peft/issues/280#issuecomment-1500805831
-    - https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraModel.add_weighted_adapter
+    - https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraModel.add_weighted_adapter # pylint: disable=line-too-long

     Args:
         checkpoint_model: Union[str, list[str]]

@@ -82,7 +81,7 @@ def fetch_base_model_from_checkpoint(checkpoint_model: str) -> str:
     if not os.path.isfile(adapter_config):
         raise FileNotFoundError("Unable to locate adapter config to infer base model!")

-    with open(adapter_config, "r") as cfg:
+    with open(adapter_config, "r", encoding="utf-8") as cfg:
         adapter_dict = json.load(cfg)
     if "base_model_name_or_path" not in adapter_dict:
         raise KeyError(
