Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/aws-proxy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ jobs:
LOCALSTACK_AUTH_TOKEN: ${{ secrets.LOCALSTACK_AUTH_TOKEN }}
run: |
set -e
cd aws-proxy
docker pull localstack/localstack-pro &
docker pull public.ecr.aws/lambda/python:3.8 &

Expand All @@ -49,7 +50,6 @@ jobs:
# build and install extension
localstack extensions init
(
cd aws-proxy
make install
. .venv/bin/activate
pip install --upgrade --pre localstack localstack-ext
Expand Down
38 changes: 16 additions & 22 deletions aws-proxy/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,45 +5,39 @@ VENV_RUN = . $(VENV_ACTIVATE)
TEST_PATH ?= tests
PIP_CMD ?= pip

usage: ## Show this help
usage: ## Show this help
@grep -Fh "##" $(MAKEFILE_LIST) | grep -Fv fgrep | sed -e 's/:.*##\s*/##/g' | awk -F'##' '{ printf "%-25s %s\n", $$1, $$2 }'

venv: $(VENV_ACTIVATE)

$(VENV_ACTIVATE): setup.py setup.cfg
install: ## Install dependencies
test -d .venv || $(VENV_BIN) .venv
$(VENV_RUN); pip install --upgrade pip setuptools plux wheel
$(VENV_RUN); pip install --upgrade black isort pyproject-flake8 flake8-black flake8-isort
$(VENV_RUN); pip install -e .
$(VENV_RUN); pip install -e .[test]
touch $(VENV_DIR)/bin/activate

clean:
clean: ## Clean up
rm -rf .venv/
rm -rf build/
rm -rf .eggs/
rm -rf *.egg-info/

lint:
$(VENV_RUN); python -m pflake8 --show-source

format:
$(VENV_RUN); python -m isort .; python -m black .
format: ## Run ruff to format the whole codebase
($(VENV_RUN); python -m ruff format .; python -m ruff check --output-format=full --fix .)

install: venv
$(VENV_RUN); $(PIP_CMD) install -e ".[test]"
lint: ## Run code linter to check code style
($(VENV_RUN); python -m ruff check --output-format=full . && python -m ruff format --check .)

test: venv
test: ## Run tests
$(VENV_RUN); python -m pytest $(PYTEST_ARGS) $(TEST_PATH)

dist: venv
$(VENV_RUN); python setup.py sdist bdist_wheel
entrypoints: ## Generate plugin entrypoints for Python package
$(VENV_RUN); python -m plux entrypoints

build: ## Build the extension
mkdir -p build
cp -r setup.py setup.cfg README.md aws_proxy build/
(cd build && python setup.py sdist)
build: entrypoints ## Build the extension
$(VENV_RUN); python -m build --no-isolation . --outdir build
@# make sure that the entrypoints are contained in the dist folder and are non-empty
@test -s localstack_extension_aws_proxy.egg-info/entry_points.txt || (echo "Entrypoints were not correctly created! Aborting!" && exit 1)

enable: $(wildcard ./build/dist/localstack_extension_aws_proxy-*.tar.gz) ## Enable the extension in LocalStack
enable: $(wildcard ./build/localstack_extension_aws_proxy-*.tar.gz) ## Enable the extension in LocalStack
$(VENV_RUN); \
pip uninstall --yes localstack-extension-aws-proxy; \
localstack extensions -v install file://$?
Expand Down
63 changes: 48 additions & 15 deletions aws-proxy/aws_proxy/client/auth_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
from localstack import config as localstack_config
from localstack.aws.spec import load_service
from localstack.config import external_service_url
from localstack.constants import AWS_REGION_US_EAST_1, DOCKER_IMAGE_NAME_PRO, LOCALHOST_HOSTNAME
from localstack.constants import (
AWS_REGION_US_EAST_1,
DOCKER_IMAGE_NAME_PRO,
LOCALHOST_HOSTNAME,
)
from localstack.http import Request
from localstack.pro.core.bootstrap.licensingv2 import (
ENV_LOCALSTACK_API_KEY,
Expand All @@ -25,7 +29,10 @@
from localstack.utils.bootstrap import setup_logging
from localstack.utils.collections import select_attributes
from localstack.utils.container_utils.container_client import PortMappings
from localstack.utils.docker_utils import DOCKER_CLIENT, reserve_available_container_port
from localstack.utils.docker_utils import (
DOCKER_CLIENT,
reserve_available_container_port,
)
from localstack.utils.files import new_tmp_file, save_file
from localstack.utils.functions import run_safe
from localstack.utils.net import get_docker_host_from_container, get_free_tcp_port
Expand All @@ -39,8 +46,6 @@
from aws_proxy.shared.constants import HEADER_HOST_ORIGINAL
from aws_proxy.shared.models import AddProxyRequest, ProxyConfig

from .http2_server import run_server

LOG = logging.getLogger(__name__)
LOG.setLevel(logging.INFO)
if localstack_config.DEBUG:
Expand All @@ -66,9 +71,14 @@ def __init__(self, config: ProxyConfig, port: int = None):
super().__init__(port=port)

def do_run(self):
# note: keep import here, to avoid runtime errors
from .http2_server import run_server

self.register_in_instance()
bind_host = self.config.get("bind_host") or DEFAULT_BIND_HOST
proxy = run_server(port=self.port, bind_addresses=[bind_host], handler=self.proxy_request)
proxy = run_server(
port=self.port, bind_addresses=[bind_host], handler=self.proxy_request
)
proxy.join()

def proxy_request(self, request: Request, data: bytes) -> Response:
Expand Down Expand Up @@ -109,7 +119,9 @@ def proxy_request(self, request: Request, data: bytes) -> Response:
# adjust request dict and fix certain edge cases in the request
self._adjust_request_dict(service_name, request_dict)

headers_truncated = {k: truncate(to_str(v)) for k, v in dict(aws_request.headers).items()}
headers_truncated = {
k: truncate(to_str(v)) for k, v in dict(aws_request.headers).items()
}
LOG.debug(
"Sending request for service %s to AWS: %s %s - %s - %s",
service_name,
Expand Down Expand Up @@ -138,7 +150,9 @@ def proxy_request(self, request: Request, data: bytes) -> Response:
return response
except Exception as e:
if LOG.isEnabledFor(logging.DEBUG):
LOG.exception("Error when making request to AWS service %s: %s", service_name, e)
LOG.exception(
"Error when making request to AWS service %s: %s", service_name, e
)
return requests_response("", status_code=400)

def register_in_instance(self):
Expand Down Expand Up @@ -224,7 +238,10 @@ def _adjust_request_dict(self, service_name: str, request_dict: Dict):
body_str = run_safe(lambda: to_str(req_body)) or ""

# TODO: this custom fix should not be required - investigate and remove!
if "<CreateBucketConfiguration" in body_str and "LocationConstraint" not in body_str:
if (
"<CreateBucketConfiguration" in body_str
and "LocationConstraint" not in body_str
):
region = request_dict["context"]["client_region"]
if region == AWS_REGION_US_EAST_1:
request_dict["body"] = ""
Expand All @@ -238,15 +255,19 @@ def _adjust_request_dict(self, service_name: str, request_dict: Dict):
account_id = self._query_account_id_from_aws()
if "QueueUrl" in req_body:
queue_name = req_body["QueueUrl"].split("/")[-1]
req_body["QueueUrl"] = f"https://queue.amazonaws.com/{account_id}/{queue_name}"
req_body["QueueUrl"] = (
f"https://queue.amazonaws.com/{account_id}/{queue_name}"
)
if "QueueOwnerAWSAccountId" in req_body:
req_body["QueueOwnerAWSAccountId"] = account_id
if service_name == "sqs" and request_dict.get("url"):
req_json = run_safe(lambda: json.loads(body_str)) or {}
account_id = self._query_account_id_from_aws()
queue_name = req_json.get("QueueName")
if account_id and queue_name:
request_dict["url"] = f"https://queue.amazonaws.com/{account_id}/{queue_name}"
request_dict["url"] = (
f"https://queue.amazonaws.com/{account_id}/{queue_name}"
)
req_json["QueueOwnerAWSAccountId"] = account_id
request_dict["body"] = to_bytes(json.dumps(req_json))

Expand All @@ -256,7 +277,9 @@ def _fix_headers(self, request: Request, service_name: str):
host = request.headers.get("Host") or ""
regex = r"^(https?://)?([0-9.]+|localhost)(:[0-9]+)?"
if re.match(regex, host):
request.headers["Host"] = re.sub(regex, rf"\1s3.{LOCALHOST_HOSTNAME}", host)
request.headers["Host"] = re.sub(
regex, rf"\1s3.{LOCALHOST_HOSTNAME}", host
)
request.headers.pop("Content-Length", None)
request.headers.pop("x-localstack-request-url", None)
request.headers.pop("X-Forwarded-For", None)
Expand Down Expand Up @@ -311,7 +334,9 @@ def start_aws_auth_proxy_in_container(
# should consider building pre-baked images for the extension in the future. Also,
# the new packaged CLI binary can help us gain more stability over time...

logging.getLogger("localstack.utils.container_utils.docker_cmd_client").setLevel(logging.INFO)
logging.getLogger("localstack.utils.container_utils.docker_cmd_client").setLevel(
logging.INFO
)
logging.getLogger("localstack.utils.docker_utils").setLevel(logging.INFO)
logging.getLogger("localstack.utils.run").setLevel(logging.INFO)

Expand All @@ -328,13 +353,18 @@ def start_aws_auth_proxy_in_container(
image_name = DOCKER_IMAGE_NAME_PRO
# add host mapping for localstack.cloud to localhost to prevent the health check from failing
additional_flags = (
repl_config.PROXY_DOCKER_FLAGS + " --add-host=localhost.localstack.cloud:host-gateway"
repl_config.PROXY_DOCKER_FLAGS
+ " --add-host=localhost.localstack.cloud:host-gateway"
)
DOCKER_CLIENT.create_container(
image_name,
name=container_name,
entrypoint="",
command=["bash", "-c", f"touch {CONTAINER_LOG_FILE}; tail -f {CONTAINER_LOG_FILE}"],
command=[
"bash",
"-c",
f"touch {CONTAINER_LOG_FILE}; tail -f {CONTAINER_LOG_FILE}",
],
ports=ports,
additional_flags=additional_flags,
)
Expand Down Expand Up @@ -388,7 +418,10 @@ def start_aws_auth_proxy_in_container(
command = f"{venv_activate}; localstack aws proxy -c {CONTAINER_CONFIG_FILE} -p {port} --host 0.0.0.0 > {CONTAINER_LOG_FILE} 2>&1"
if use_docker_sdk_command:
DOCKER_CLIENT.exec_in_container(
container_name, command=["bash", "-c", command], env_vars=env_vars, interactive=True
container_name,
command=["bash", "-c", command],
env_vars=env_vars,
interactive=True,
)
else:
env_vars_list = []
Expand Down
12 changes: 9 additions & 3 deletions aws-proxy/aws_proxy/client/http2_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,9 @@ def _do_create():
def _encode_headers(headers):
if RETURN_CASE_SENSITIVE_HEADERS:
return [(key.encode(), value.encode()) for key, value in headers.items()]
return [(key.lower().encode(), value.encode()) for key, value in headers.items()]
return [
(key.lower().encode(), value.encode()) for key, value in headers.items()
]

quart_asgi._encode_headers = quart_asgi.encode_headers = _encode_headers
quart_app.encode_headers = quart_utils.encode_headers = _encode_headers
Expand All @@ -116,7 +118,9 @@ def build_and_validate_headers(headers):
for name, value in headers:
if name[0] == b":"[0]:
raise ValueError("Pseudo headers are not valid")
header_name = bytes(name) if RETURN_CASE_SENSITIVE_HEADERS else bytes(name).lower()
header_name = (
bytes(name) if RETURN_CASE_SENSITIVE_HEADERS else bytes(name).lower()
)
validated_headers.append((header_name.strip(), bytes(value).strip()))
return validated_headers

Expand Down Expand Up @@ -212,7 +216,9 @@ async def index(path=None):
response.headers.pop("Content-Length", None)
result.headers.pop("Server", None)
result.headers.pop("Date", None)
headers = {k: str(v).replace("\n", r"\n") for k, v in result.headers.items()}
headers = {
k: str(v).replace("\n", r"\n") for k, v in result.headers.items()
}
response.headers.update(headers)
# set multi-value headers
multi_value_headers = getattr(result, "multi_value_headers", {})
Expand Down
48 changes: 38 additions & 10 deletions aws-proxy/aws_proxy/server/aws_request_forwarder.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ class AwsProxyHandler(Handler):
# maps port numbers to proxy instances
PROXY_INSTANCES: Dict[int, ProxyInstance] = {}

def __call__(self, chain: HandlerChain, context: RequestContext, response: Response):
def __call__(
self, chain: HandlerChain, context: RequestContext, response: Response
):
proxy = self.select_proxy(context)
if not proxy:
return
Expand Down Expand Up @@ -63,7 +65,9 @@ def select_proxy(self, context: RequestContext) -> Optional[ProxyInstance]:
proxy = self.PROXY_INSTANCES[port]
proxy_config = proxy.get("config") or {}
services = proxy_config.get("services") or {}
service_name = self._get_canonical_service_name(context.service.service_name)
service_name = self._get_canonical_service_name(
context.service.service_name
)
service_config = services.get(service_name)
if not service_config:
continue
Expand Down Expand Up @@ -100,7 +104,9 @@ def _request_matches_resource(
self, context: RequestContext, resource_name_pattern: str
) -> bool:
try:
service_name = self._get_canonical_service_name(context.service.service_name)
service_name = self._get_canonical_service_name(
context.service.service_name
)
if service_name == "s3":
bucket_name = context.service_request.get("Bucket") or ""
s3_bucket_arn = arns.s3_bucket_arn(bucket_name)
Expand All @@ -113,7 +119,9 @@ def _request_matches_resource(
queue_name,
queue_url,
sqs_queue_arn(
queue_name, account_id=context.account_id, region_name=context.region
queue_name,
account_id=context.account_id,
region_name=context.region,
),
)
for candidate in candidates:
Expand All @@ -133,12 +141,16 @@ def _request_matches_resource(
) from e
return True

def forward_request(self, context: RequestContext, proxy: ProxyInstance) -> requests.Response:
def forward_request(
self, context: RequestContext, proxy: ProxyInstance
) -> requests.Response:
"""Forward the given request to the proxy instance, and return the response."""
port = proxy["port"]
request = context.request
target_host = get_addressable_container_host(default_local_hostname=LOCALHOST)
url = f"http://{target_host}:{port}{request.path}?{to_str(request.query_string)}"
url = (
f"http://{target_host}:{port}{request.path}?{to_str(request.query_string)}"
)

# inject Auth header, to ensure we're passing the right region to the proxy (e.g., for Cognito InitiateAuth)
self._extract_region_from_domain(context)
Expand All @@ -156,10 +168,20 @@ def forward_request(self, context: RequestContext, proxy: ProxyInstance) -> requ
data = request.form
elif request.data:
data = request.data
LOG.debug("Forward request: %s %s - %s - %s", request.method, url, dict(headers), data)
LOG.debug(
"Forward request: %s %s - %s - %s",
request.method,
url,
dict(headers),
data,
)
# construct response
result = requests.request(
method=request.method, url=url, data=data, headers=dict(headers), stream=True
method=request.method,
url=url,
data=data,
headers=dict(headers),
stream=True,
)
# TODO: ugly hack for now, simply attaching an additional attribute for raw response content
result.raw_content = result.raw.read()
Expand All @@ -173,7 +195,10 @@ def forward_request(self, context: RequestContext, proxy: ProxyInstance) -> requ
)
except requests.exceptions.ConnectionError:
# remove unreachable proxy
LOG.info("Removing unreachable AWS forward proxy due to connection issue: %s", url)
LOG.info(
"Removing unreachable AWS forward proxy due to connection issue: %s",
url,
)
self.PROXY_INSTANCES.pop(port, None)
return result

Expand All @@ -186,7 +211,10 @@ def _is_read_request(self, context: RequestContext) -> bool:
if operation_name.lower().startswith(("describe", "get", "list", "query")):
return True
# service-specific rules
if context.service.service_name == "cognito-idp" and operation_name == "InitiateAuth":
if (
context.service.service_name == "cognito-idp"
and operation_name == "InitiateAuth"
):
return True
if context.service.service_name == "dynamodb" and operation_name in {
"Scan",
Expand Down
Loading