Skip to content

Commit e2a2e98

Browse files
committed
Merge remote-tracking branch 'upstream/main'
Signed-off-by: Dushyant Behl <dushyantbehl@users.noreply.github.com>
2 parents 1836a67 + 6eba340 commit e2a2e98

21 files changed

Lines changed: 575 additions & 133 deletions

.github/dependabot.yml

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,6 @@
1+
version: 2
2+
updates:
3+
- package-ecosystem: "pip"
4+
directory: "/"
5+
schedule:
6+
interval: "daily"

.github/workflows/format.yml

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright The Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
name: Format
16+
17+
on:
18+
push:
19+
branches: [ "main" ]
20+
pull_request:
21+
branches: [ "main" ]
22+
23+
jobs:
24+
build:
25+
runs-on: ubuntu-latest
26+
steps:
27+
- uses: actions/checkout@v3
28+
- name: Set up Python 3.9
29+
uses: actions/setup-python@v4
30+
with:
31+
python-version: 3.9
32+
- name: Install dependencies
33+
run: |
34+
python -m pip install --upgrade pip
35+
python -m pip install -r setup_requirements.txt
36+
- name: Check Formatting
37+
run: tox -e fmt
38+

.gitignore

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
*.egg-info
2+
*.pyc
3+
__pycache__
4+
.coverage
5+
.coverage.*
6+
durations/*
7+
coverage*.xml
8+
dist
9+
htmlcov
10+
build
11+
test
12+
13+
# IDEs
14+
.vscode/
15+
.idea/
16+
17+
# Env files
18+
.env
19+
20+
# Virtual Env
21+
venv/
22+
.venv/
23+
24+
# Mac personalization files
25+
*.DS_Store
26+
27+
# Tox envs
28+
.tox

.isort.cfg

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
[settings]
2+
profile=black
3+
from_first=true
4+
import_heading_future=Future
5+
import_heading_stdlib=Standard
6+
import_heading_thirdparty=Third Party
7+
import_heading_firstparty=First Party
8+
import_heading_localfolder=Local
9+
known_firstparty=
10+
known_localfolder=tuning

.pre-commit-config.yaml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
repos:
2+
- repo: https://github.com/psf/black
3+
rev: 22.3.0
4+
hooks:
5+
- id: black
6+
exclude: imports
7+
- repo: https://github.com/PyCQA/isort
8+
rev: 5.11.5
9+
hooks:
10+
- id: isort
11+
exclude: imports

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ tokenizers>=0.13.3
88
tqdm
99
trl
1010
ninja
11-
peft
11+
peft>=0.8.0
1212
datasets>=2.15.0
1313
flash-attn
1414
fire

scripts/fmt.sh

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#!/usr/bin/env bash
2+
3+
pre-commit run --all-files
4+
RETURN_CODE=$?
5+
6+
function echoWarning() {
7+
LIGHT_YELLOW='\033[1;33m'
8+
NC='\033[0m' # No Color
9+
echo -e "${LIGHT_YELLOW}${1}${NC}"
10+
}
11+
12+
if [ "$RETURN_CODE" -ne 0 ]; then
13+
if [ "${CI}" != "true" ]; then
14+
echoWarning "☝️ This appears to have failed, but actually your files have been formatted."
15+
echoWarning "Make a new commit with these changes before making a pull request."
16+
else
17+
echoWarning "This test failed because your code isn't formatted correctly."
18+
echoWarning 'Locally, run `make run fmt`, it will appear to fail, but change files.'
19+
echoWarning "Add the changed files to your commit and this stage will pass."
20+
fi
21+
22+
exit $RETURN_CODE
23+
fi

scripts/run_inference.py

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,16 @@
88
99
If these things change in the future, we should consider breaking it up.
1010
"""
11+
# Standard
1112
import argparse
1213
import json
1314
import os
15+
16+
# Third Party
1417
from peft import AutoPeftModelForCausalLM
15-
import torch
1618
from tqdm import tqdm
1719
from transformers import AutoTokenizer
20+
import torch
1821

1922

2023
### Utilities
@@ -30,10 +33,13 @@ class AdapterConfigPatcher:
3033
# When loaded in this block, the config's base_model_name_or_path is "foo"
3134
peft_model = AutoPeftModelForCausalLM.from_pretrained(checkpoint_path)
3235
"""
36+
3337
def __init__(self, checkpoint_path: str, overrides: dict):
3438
self.checkpoint_path = checkpoint_path
3539
self.overrides = overrides
36-
self.config_path = AdapterConfigPatcher._locate_adapter_config(self.checkpoint_path)
40+
self.config_path = AdapterConfigPatcher._locate_adapter_config(
41+
self.checkpoint_path
42+
)
3743
# Values that we will patch later on
3844
self.patched_values = {}
3945

@@ -58,7 +64,7 @@ def _locate_adapter_config(checkpoint_path: str) -> str:
5864
def _apply_config_changes(self, overrides: dict) -> dict:
5965
"""Applies a patch to a config with some override dict, returning the values
6066
that we patched over so that they may be restored later.
61-
67+
6268
Args:
6369
overrides: dict
6470
Overrides to write into the adapter_config.json. Currently, we
@@ -99,7 +105,9 @@ def _get_old_config_values(adapter_config: dict, overrides: dict) -> dict:
99105
# For now, we only expect to patch the base model; this may change in the future,
100106
# but ensure that anything we are patching is defined in the original config
101107
if not set(overrides.keys()).issubset(set(adapter_config.keys())):
102-
raise KeyError("Adapter config overrides must be set in the config being patched")
108+
raise KeyError(
109+
"Adapter config overrides must be set in the config being patched"
110+
)
103111
return {key: adapter_config[key] for key in overrides}
104112

105113
def __enter__(self):
@@ -119,7 +127,9 @@ def __init__(self, model, tokenizer, device):
119127
self.device = device
120128

121129
@classmethod
122-
def load(cls, checkpoint_path: str, base_model_name_or_path: str=None) -> "TunedCausalLM":
130+
def load(
131+
cls, checkpoint_path: str, base_model_name_or_path: str = None
132+
) -> "TunedCausalLM":
123133
"""Loads an instance of this model.
124134
125135
Args:
@@ -138,7 +148,11 @@ def load(cls, checkpoint_path: str, base_model_name_or_path: str=None) -> "Tuned
138148
TunedCausalLM
139149
An instance of this class on which we can run inference.
140150
"""
141-
overrides = {"base_model_name_or_path": base_model_name_or_path} if base_model_name_or_path is not None else {}
151+
overrides = (
152+
{"base_model_name_or_path": base_model_name_or_path}
153+
if base_model_name_or_path is not None
154+
else {}
155+
)
142156
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
143157
# Apply the configs to the adapter config of this model; if no overrides
144158
# are provided, then the context manager doesn't have any effect.
@@ -153,7 +167,6 @@ def load(cls, checkpoint_path: str, base_model_name_or_path: str=None) -> "Tuned
153167
peft_model.to(device)
154168
return cls(peft_model, tokenizer, device)
155169

156-
157170
def run(self, text: str, *, max_new_tokens: int) -> str:
158171
"""Runs inference on an instance of this model.
159172
@@ -165,13 +178,17 @@ def run(self, text: str, *, max_new_tokens: int) -> str:
165178
166179
Returns:
167180
str
168-
Text generation result.
181+
Text generation result.
169182
"""
170183
tok_res = self.tokenizer(text, return_tensors="pt")
171184
input_ids = tok_res.input_ids.to(self.device)
172185

173-
peft_outputs = self.peft_model.generate(input_ids=input_ids, max_new_tokens=max_new_tokens)
174-
decoded_result = self.tokenizer.batch_decode(peft_outputs, skip_special_tokens=False)[0]
186+
peft_outputs = self.peft_model.generate(
187+
input_ids=input_ids, max_new_tokens=max_new_tokens
188+
)
189+
decoded_result = self.tokenizer.batch_decode(
190+
peft_outputs, skip_special_tokens=False
191+
)[0]
175192
return decoded_result
176193

177194

@@ -180,7 +197,9 @@ def main():
180197
parser = argparse.ArgumentParser(
181198
description="Loads a tuned model and runs an inference call(s) through it"
182199
)
183-
parser.add_argument("--model", help="Path to tuned model to be loaded", required=True)
200+
parser.add_argument(
201+
"--model", help="Path to tuned model to be loaded", required=True
202+
)
184203
parser.add_argument(
185204
"--out_file",
186205
help="JSON file to write results to",
@@ -189,7 +208,7 @@ def main():
189208
parser.add_argument(
190209
"--base_model_name_or_path",
191210
help="Override for base model to be used [default: value in model adapter_config.json]",
192-
default=None
211+
default=None,
193212
)
194213
parser.add_argument(
195214
"--max_new_tokens",
@@ -199,7 +218,10 @@ def main():
199218
)
200219
group = parser.add_mutually_exclusive_group(required=True)
201220
group.add_argument("--text", help="Text to run inference on")
202-
group.add_argument("--text_file", help="File to be processed where each line is a text to run inference on")
221+
group.add_argument(
222+
"--text_file",
223+
help="File to be processed where each line is a text to run inference on",
224+
)
203225
args = parser.parse_args()
204226
# If we passed a file, check if it exists before doing anything else
205227
if args.text_file and not os.path.isfile(args.text_file):
@@ -220,7 +242,10 @@ def main():
220242

221243
# TODO: we should add batch inference support
222244
results = [
223-
{"input": text, "output": loaded_model.run(text, max_new_tokens=args.max_new_tokens)}
245+
{
246+
"input": text,
247+
"output": loaded_model.run(text, max_new_tokens=args.max_new_tokens),
248+
}
224249
for text in tqdm(texts)
225250
]
226251

@@ -230,5 +255,6 @@ def main():
230255

231256
print(f"Exported results to: {args.out_file}")
232257

258+
233259
if __name__ == "__main__":
234260
main()

setup.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1+
# Third Party
12
from setuptools import find_packages, setup
23

3-
setup(
4-
name="tuning",
5-
version="0.0.1",
6-
packages=find_packages()
7-
)
4+
setup(name="tuning", version="0.0.1", packages=find_packages())

setup_requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
pre-commit>=3.0.4,<4.0
2+
pydeps>=1.12.12,<2
3+
tox>=4.4.2,<5

0 commit comments

Comments
 (0)