forked from aws/sagemaker-python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgit_utils.py
More file actions
415 lines (345 loc) · 17.8 KB
/
git_utils.py
File metadata and controls
415 lines (345 loc) · 17.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Utilities for cloning a Git repository and locating training/serving code inside it."""
from __future__ import absolute_import
import os
from pathlib import Path
import subprocess
import tempfile
import warnings
import six
from six.moves import urllib
import re
from pathlib import Path
from urllib.parse import urlparse
def _sanitize_git_url(repo_url):
"""Sanitize Git repository URL to prevent URL injection attacks.
Args:
repo_url (str): The Git repository URL to sanitize
Returns:
str: The sanitized URL
Raises:
ValueError: If the URL contains suspicious patterns that could indicate injection
"""
at_count = repo_url.count("@")
if repo_url.startswith("git@"):
# git@ format requires exactly one @
if at_count != 1:
raise ValueError("Invalid SSH URL format: git@ URLs must have exactly one @ symbol")
elif repo_url.startswith("ssh://"):
# ssh:// format can have 0 or 1 @ symbols
if at_count > 1:
raise ValueError("Invalid SSH URL format: multiple @ symbols detected")
elif repo_url.startswith("https://") or repo_url.startswith("http://"):
# HTTPS format allows 0 or 1 @ symbols
if at_count > 1:
raise ValueError("Invalid HTTPS URL format: multiple @ symbols detected")
# Check for invalid characters in the URL before parsing
# These characters should not appear in legitimate URLs
invalid_chars = ["<", ">", "[", "]", "{", "}", "\\", "^", "`", "|"]
for char in invalid_chars:
if char in repo_url:
raise ValueError("Invalid characters in hostname")
try:
parsed = urlparse(repo_url)
# Check for suspicious characters in hostname that could indicate injection
if parsed.hostname:
# Check for URL-encoded characters that might be used for obfuscation
suspicious_patterns = ["%25", "%40", "%2F", "%3A"] # encoded %, @, /, :
for pattern in suspicious_patterns:
if pattern in parsed.hostname.lower():
raise ValueError(f"Suspicious URL encoding detected in hostname: {pattern}")
# Validate that the hostname looks legitimate
if not re.match(r"^[a-zA-Z0-9.-]+$", parsed.hostname):
raise ValueError("Invalid characters in hostname")
except Exception as e:
if isinstance(e, ValueError):
raise
raise ValueError(f"Failed to parse URL: {str(e)}")
else:
raise ValueError(
"Unsupported URL scheme: only https://, http://, git@, and ssh:// are allowed"
)
return repo_url
def git_clone_repo(git_config, entry_point, source_dir=None, dependencies=None):
    """Clone the Git repo holding the training/serving code and resolve paths inside it.

    ``git_config`` is validated, the repo is cloned into a fresh temporary
    directory, the requested branch/commit is checked out, and ``entry_point``,
    ``source_dir`` and ``dependencies`` are checked against the clone. If any
    step fails, the temporary directory is removed before the error propagates.

    Args:
        git_config (dict[str, str]): Git configurations used for cloning files,
            including ``repo``, ``branch``, ``commit``, ``2FA_enabled``,
            ``username``, ``password`` and ``token``. Only ``repo`` is
            required. If ``branch`` is omitted no branch checkout is done, so
            whatever ``git clone`` checks out is used; if ``commit`` is omitted
            the tip of that branch is used. ``2FA_enabled`` (bool),
            ``username``, ``password`` and ``token`` are for authentication;
            2FA is treated as disabled when ``2FA_enabled`` is absent.

            For GitHub and GitHub-like repos with SSH URLs, 2FA is irrelevant;
            the SSH key must have no passphrase or be served by an ssh-agent so
            that ``git clone`` never prompts. With https URLs: if 2FA is
            disabled, the token or username+password is used (token takes
            priority); if 2FA is enabled, only the token is used. When no
            usable credentials are given, git falls back to whatever local
            credential storage is configured.

            For CodeCommit repos, ``2FA_enabled`` and ``token`` are not
            supported and are ignored with a warning. SSH URLs behave as for
            GitHub-like repos; https URLs use username+password when both are
            provided.
        entry_point (str): Relative path, inside the repo, of the Python source
            file to run as the entry point for training or hosting.
        source_dir (str): Relative path, inside the repo, of a directory with
            additional source code (default: None). When given,
            ``entry_point`` is resolved relative to this directory.
        dependencies (list[str]): Relative paths, inside the repo, of
            additional libraries to export to the container (default: None).

    Returns:
        dict: ``entry_point``, ``source_dir`` and ``dependencies`` keys. With
        ``source_dir`` set, ``source_dir`` becomes a path inside the clone and
        ``entry_point`` is returned unchanged; otherwise ``entry_point``
        becomes a path inside the clone. Each dependency becomes a path inside
        the clone.

    Raises:
        CalledProcessError: If cloning the repo, or checking out the requested
            branch or commit, fails.
        ValueError: If ``entry_point`` is None, ``git_config`` is malformed, or
            the entry point, source dir or a dependency does not exist in the
            repo.
    """
    # Function-scope import: only this function needs it, for failure cleanup.
    import shutil

    if entry_point is None:
        raise ValueError("Please provide an entry point.")
    _validate_git_config(git_config)

    # SECURITY: Sanitize the repository URL to prevent injection attacks
    git_config["repo"] = _sanitize_git_url(git_config["repo"])

    dest_dir = tempfile.mkdtemp()
    try:
        _generate_and_run_clone_command(git_config, dest_dir)
        _checkout_branch_and_commit(git_config, dest_dir)

        updated_paths = {
            "entry_point": entry_point,
            "source_dir": source_dir,
            "dependencies": dependencies,
        }

        # check if the cloned repo contains entry point, source directory and dependencies
        if source_dir:
            if not os.path.isdir(os.path.join(dest_dir, source_dir)):
                raise ValueError("Source directory does not exist in the repo.")
            if not os.path.isfile(os.path.join(dest_dir, source_dir, entry_point)):
                raise ValueError("Entry point does not exist in the repo.")
            updated_paths["source_dir"] = os.path.join(dest_dir, source_dir)
        else:
            if not os.path.isfile(os.path.join(dest_dir, entry_point)):
                raise ValueError("Entry point does not exist in the repo.")
            updated_paths["entry_point"] = os.path.join(dest_dir, entry_point)

        if dependencies is not None:
            resolved_dependencies = []
            for path in dependencies:
                if not os.path.exists(os.path.join(dest_dir, path)):
                    raise ValueError("Dependency {} does not exist in the repo.".format(path))
                resolved_dependencies.append(os.path.join(dest_dir, path))
            updated_paths["dependencies"] = resolved_dependencies

        return updated_paths
    except Exception:
        # The caller never learns dest_dir on failure, so nothing else could
        # clean it up; remove it here instead of leaking it.
        shutil.rmtree(dest_dir, ignore_errors=True)
        raise
def _validate_git_config(git_config):
"""Validates the git configuration.
Checks all configuration values except 2FA_enabled are string types. The
2FA_enabled configuration should be a boolean.
Args:
git_config: The configuration to validate.
"""
if "repo" not in git_config:
raise ValueError("Please provide a repo for git_config.")
for key in git_config:
if key == "2FA_enabled":
if not isinstance(git_config["2FA_enabled"], bool):
raise ValueError("Please enter a bool type for 2FA_enabled'.")
elif not isinstance(git_config[key], six.string_types):
raise ValueError("'{}' must be a string.".format(key))
def _generate_and_run_clone_command(git_config, dest_dir):
    """Dispatch the clone to the CodeCommit or the GitHub-like code path.

    The choice is made purely on the prefix of ``git_config["repo"]``.

    Args:
        git_config (dict[str, str]): Git configurations used for cloning files,
            including ``repo``, ``branch`` and ``commit``.
        dest_dir (str): The local directory to clone the Git repo into.

    Raises:
        CalledProcessError: If failed to clone git repo.
    """
    codecommit_prefixes = ("https://git-codecommit", "ssh://git-codecommit")
    if git_config["repo"].startswith(codecommit_prefixes):
        clone = _clone_command_for_codecommit
    else:
        clone = _clone_command_for_github_like
    clone(git_config, dest_dir)
def _clone_command_for_github_like(git_config, dest_dir):
"""Check if a git_config param representing a GitHub (or like) repo is valid.
If it is valid, create the command to git clone the repo, and run it.
Args:
git_config ((dict[str, str]): Git configurations used for cloning files,
including ``repo``, ``branch`` and ``commit``.
dest_dir (str): The local directory to clone the Git repo into.
Raises:
ValueError: If git_config['repo'] is in the wrong format.
CalledProcessError: If failed to clone git repo.
"""
is_https = git_config["repo"].startswith("https://")
is_ssh = git_config["repo"].startswith("git@") or git_config["repo"].startswith("ssh://")
if not is_https and not is_ssh:
raise ValueError("Invalid Git url provided.")
if is_ssh:
_clone_command_for_ssh(git_config, dest_dir)
elif "2FA_enabled" in git_config and git_config["2FA_enabled"] is True:
_clone_command_for_github_like_https_2fa_enabled(git_config, dest_dir)
else:
_clone_command_for_github_like_https_2fa_disabled(git_config, dest_dir)
def _clone_command_for_ssh(git_config, dest_dir):
    """Clone over SSH, warning that any credentials in ``git_config`` are unused.

    Args:
        git_config (dict[str, str]): Git configurations; only ``repo`` is used
            for the clone itself.
        dest_dir (str): The local directory to clone the Git repo into.
    """
    credential_keys = ("username", "password", "token")
    if any(key in git_config for key in credential_keys):
        warnings.warn("SSH cloning, authentication information in git config will be ignored.")
    _run_clone_command(git_config["repo"], dest_dir)
def _clone_command_for_github_like_https_2fa_disabled(git_config, dest_dir):
    """Clone an https GitHub-like repo when 2FA is disabled.

    Credentials are embedded in the URL with this priority: ``token`` first,
    then ``username`` + ``password`` (both required). A lone username or
    password is ignored with a warning.

    Args:
        git_config (dict[str, str]): Git configurations, including ``repo`` and
            optionally ``token``, ``username`` and ``password``.
        dest_dir (str): The local directory to clone the Git repo into.
    """
    has_username = "username" in git_config
    has_password = "password" in git_config
    clone_url = git_config["repo"]

    if "token" in git_config:
        if has_username or has_password:
            warnings.warn("Using token for authentication, other credentials will be ignored.")
        clone_url = _insert_token_to_repo_url(url=git_config["repo"], token=git_config["token"])
    elif has_username and has_password:
        clone_url = _insert_username_and_password_to_repo_url(
            url=git_config["repo"],
            username=git_config["username"],
            password=git_config["password"],
        )
    elif has_username or has_password:
        # Only one half of the pair was supplied, so it cannot be used.
        warnings.warn("Credentials provided in git config will be ignored.")

    _run_clone_command(clone_url, dest_dir)
def _clone_command_for_github_like_https_2fa_enabled(git_config, dest_dir):
    """Clone an https GitHub-like repo when 2FA is enabled.

    Only ``token`` is used for authentication; without one the repo URL is
    cloned as given.

    Args:
        git_config (dict[str, str]): Git configurations, including ``repo`` and
            optionally ``token``.
        dest_dir (str): The local directory to clone the Git repo into.
    """
    if "token" not in git_config:
        _run_clone_command(git_config["repo"], dest_dir)
        return

    if "username" in git_config or "password" in git_config:
        warnings.warn("Using token for authentication, other credentials will be ignored.")
    clone_url = _insert_token_to_repo_url(url=git_config["repo"], token=git_config["token"])
    _run_clone_command(clone_url, dest_dir)
def _clone_command_for_codecommit(git_config, dest_dir):
"""Check if a git_config param representing a CodeCommit repo is valid.
If it is, create the command to git clone the repo, and run it.
Args:
git_config ((dict[str, str]): Git configurations used for cloning files,
including ``repo``, ``branch`` and ``commit``.
dest_dir (str): The local directory to clone the Git repo into.
Raises:
ValueError: If git_config['repo'] is in the wrong format.
CalledProcessError: If failed to clone git repo.
"""
is_https = git_config["repo"].startswith("https://git-codecommit")
is_ssh = git_config["repo"].startswith("ssh://git-codecommit")
if not is_https and not is_ssh:
raise ValueError("Invalid Git url provided.")
if "2FA_enabled" in git_config:
warnings.warn("CodeCommit does not support 2FA, '2FA_enabled' will be ignored.")
if "token" in git_config:
warnings.warn("There are no tokens in CodeCommit, the token provided will be ignored.")
if is_ssh:
_clone_command_for_ssh(git_config, dest_dir)
else:
_clone_command_for_codecommit_https(git_config, dest_dir)
def _clone_command_for_codecommit_https(git_config, dest_dir):
    """Clone a CodeCommit repo over https.

    When both ``username`` and ``password`` are present they are embedded in
    the URL; a lone username or password is ignored with a warning.

    Args:
        git_config: The git configuration.
        dest_dir: The destination directory for the clone.
    """
    has_username = "username" in git_config
    has_password = "password" in git_config
    clone_url = git_config["repo"]

    if has_username and has_password:
        clone_url = _insert_username_and_password_to_repo_url(
            url=git_config["repo"],
            username=git_config["username"],
            password=git_config["password"],
        )
    elif has_username or has_password:
        warnings.warn("Credentials provided in git config will be ignored.")

    _run_clone_command(clone_url, dest_dir)
def _run_clone_command(repo_url, dest_dir):
"""Run the 'git clone' command with the repo url and the directory to clone the repo into.
Args:
repo_url (str): Git repo url to be cloned.
dest_dir: (str): Local path where the repo should be cloned into.
Raises:
CalledProcessError: If failed to clone git repo.
"""
my_env = os.environ.copy()
if repo_url.startswith("https://"):
my_env["GIT_TERMINAL_PROMPT"] = "0"
subprocess.check_call(["git", "clone", repo_url, dest_dir], env=my_env)
elif repo_url.startswith("git@") or repo_url.startswith("ssh://"):
try:
with tempfile.TemporaryDirectory() as tmp_dir:
custom_ssh_executable = Path(tmp_dir) / "ssh_batch"
with open(custom_ssh_executable, "w") as pipe:
print("#!/bin/sh", file=pipe)
print("ssh -oBatchMode=yes $@", file=pipe)
os.chmod(custom_ssh_executable, 0o511)
my_env["GIT_SSH"] = str(custom_ssh_executable)
subprocess.check_call(["git", "clone", repo_url, dest_dir], env=my_env)
except subprocess.CalledProcessError:
del my_env["GIT_SSH"]
subprocess.check_call(["git", "clone", repo_url, dest_dir], env=my_env)
def _insert_token_to_repo_url(url, token):
"""Insert the token to the Git repo url, to make a component of the git clone command.
This method can only be called when repo_url is an https url.
Args:
url (str): Git repo url where the token should be inserted into.
token (str): Token to be inserted.
Returns:
str: the component needed fot the git clone command.
"""
index = len("https://")
if url.find(token) == index:
return url
return url.replace("https://", "https://" + token + "@")
def _insert_username_and_password_to_repo_url(url, username, password):
"""Insert username and password to the Git repo url to make a component of git clone command.
This method can only be called when repo_url is an https url.
Args:
url (str): Git repo url where the token should be inserted into.
username (str): Username to be inserted.
password (str): Password to be inserted.
Returns:
str: the component needed for the git clone command.
"""
password = urllib.parse.quote_plus(password)
# urllib parses ' ' as '+', but what we need is '%20' here
password = password.replace("+", "%20")
index = len("https://")
return url[:index] + username + ":" + password + "@" + url[index:]
def _checkout_branch_and_commit(git_config, dest_dir):
"""Checkout the required branch and commit.
Args:
git_config (dict[str, str]): Git configurations used for cloning files,
including ``repo``, ``branch`` and ``commit``.
dest_dir (str): the directory where the repo is cloned
Raises:
CalledProcessError: If 1. failed to checkout the required branch 2.
failed to checkout the required commit
"""
if "branch" in git_config:
subprocess.check_call(args=["git", "checkout", git_config["branch"]], cwd=str(dest_dir))
if "commit" in git_config:
subprocess.check_call(args=["git", "checkout", git_config["commit"]], cwd=str(dest_dir))