Skip to content

Commit 464aa5c

Browse files
committed
fix: Support for docker compose > v2 (5739)
1 parent 91ca011 commit 464aa5c

File tree

5 files changed

+219
-17
lines changed

5 files changed

+219
-17
lines changed

sagemaker-core/src/sagemaker/core/local/image.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def __init__(
138138
def _get_compose_cmd_prefix():
139139
"""Gets the Docker Compose command.
140140
141-
The method initially looks for 'docker compose' v2
141+
The method initially looks for 'docker compose' v2+
142142
executable, if not found looks for 'docker-compose' executable.
143143
144144
Returns:
@@ -162,10 +162,12 @@ def _get_compose_cmd_prefix():
162162
"Proceeding to check for 'docker-compose' CLI."
163163
)
164164

165-
if output and "v2" in output.strip():
166-
logger.info("'Docker Compose' found using Docker CLI.")
167-
compose_cmd_prefix.extend(["docker", "compose"])
168-
return compose_cmd_prefix
165+
if output:
166+
match = re.search(r"v(\d+)", output.strip())
167+
if match and int(match.group(1)) >= 2:
168+
logger.info("'Docker Compose' found using Docker CLI.")
169+
compose_cmd_prefix.extend(["docker", "compose"])
170+
return compose_cmd_prefix
169171

170172
if shutil.which("docker-compose") is not None:
171173
logger.info("'Docker Compose' found using Docker Compose CLI.")

sagemaker-core/src/sagemaker/core/modules/local_core/local_container.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,7 @@ def _get_data_source_local_path(self, data_source: DataSource):
593593
def _get_compose_cmd_prefix(self) -> List[str]:
594594
"""Gets the Docker Compose command.
595595
596-
The method initially looks for 'docker compose' v2
596+
The method initially looks for 'docker compose' v2+
597597
executable, if not found looks for 'docker-compose' executable.
598598
599599
Returns:
@@ -617,10 +617,12 @@ def _get_compose_cmd_prefix(self) -> List[str]:
617617
"Proceeding to check for 'docker-compose' CLI."
618618
)
619619

620-
if output and "v2" in output.strip():
621-
logger.info("'Docker Compose' found using Docker CLI.")
622-
compose_cmd_prefix.extend(["docker", "compose"])
623-
return compose_cmd_prefix
620+
if output:
621+
match = re.search(r"v(\d+)", output.strip())
622+
if match and int(match.group(1)) >= 2:
623+
logger.info("'Docker Compose' found using Docker CLI.")
624+
compose_cmd_prefix.extend(["docker", "compose"])
625+
return compose_cmd_prefix
624626

625627
if shutil.which("docker-compose") is not None:
626628
logger.info("'Docker Compose' found using Docker Compose CLI.")

sagemaker-core/tests/unit/modules/local_core/test_local_container.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,82 @@ def test_get_compose_cmd_prefix_not_found(
638638
with pytest.raises(ImportError, match="Docker Compose is not installed"):
639639
container._get_compose_cmd_prefix()
640640

641+
@patch("sagemaker.core.modules.local_core.local_container.subprocess.check_output")
642+
def test_get_compose_cmd_prefix_docker_compose_v5(
643+
self, mock_check_output, mock_session, basic_channel
644+
):
645+
"""Test _get_compose_cmd_prefix accepts Docker Compose v5"""
646+
container = _LocalContainer(
647+
training_job_name="test-job",
648+
instance_type="local",
649+
instance_count=1,
650+
image="test-image:latest",
651+
container_root="/tmp/test",
652+
input_data_config=[basic_channel],
653+
environment={},
654+
hyper_parameters={},
655+
container_entrypoint=[],
656+
container_arguments=[],
657+
sagemaker_session=mock_session,
658+
)
659+
660+
mock_check_output.return_value = "Docker Compose version v5.1.1"
661+
662+
result = container._get_compose_cmd_prefix()
663+
664+
assert result == ["docker", "compose"]
665+
666+
@patch("sagemaker.core.modules.local_core.local_container.subprocess.check_output")
667+
def test_get_compose_cmd_prefix_docker_compose_v3(
668+
self, mock_check_output, mock_session, basic_channel
669+
):
670+
"""Test _get_compose_cmd_prefix accepts Docker Compose v3"""
671+
container = _LocalContainer(
672+
training_job_name="test-job",
673+
instance_type="local",
674+
instance_count=1,
675+
image="test-image:latest",
676+
container_root="/tmp/test",
677+
input_data_config=[basic_channel],
678+
environment={},
679+
hyper_parameters={},
680+
container_entrypoint=[],
681+
container_arguments=[],
682+
sagemaker_session=mock_session,
683+
)
684+
685+
mock_check_output.return_value = "Docker Compose version v3.0.0"
686+
687+
result = container._get_compose_cmd_prefix()
688+
689+
assert result == ["docker", "compose"]
690+
691+
@patch("sagemaker.core.modules.local_core.local_container.subprocess.check_output")
692+
@patch("sagemaker.core.modules.local_core.local_container.shutil.which")
693+
def test_get_compose_cmd_prefix_docker_compose_v1_rejected(
694+
self, mock_which, mock_check_output, mock_session, basic_channel
695+
):
696+
"""Test _get_compose_cmd_prefix rejects Docker Compose v1"""
697+
container = _LocalContainer(
698+
training_job_name="test-job",
699+
instance_type="local",
700+
instance_count=1,
701+
image="test-image:latest",
702+
container_root="/tmp/test",
703+
input_data_config=[basic_channel],
704+
environment={},
705+
hyper_parameters={},
706+
container_entrypoint=[],
707+
container_arguments=[],
708+
sagemaker_session=mock_session,
709+
)
710+
711+
mock_check_output.return_value = "docker-compose version v1.29.2"
712+
mock_which.return_value = None
713+
714+
with pytest.raises(ImportError, match="Docker Compose is not installed"):
715+
container._get_compose_cmd_prefix()
716+
641717
def test_init_with_container_entrypoint(self, mock_session, basic_channel):
642718
"""Test initialization with container entrypoint"""
643719
container = _LocalContainer(

sagemaker-train/src/sagemaker/train/local/local_container.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,7 @@ def _get_data_source_local_path(self, data_source: DataSource):
601601
def _get_compose_cmd_prefix(self) -> List[str]:
602602
"""Gets the Docker Compose command.
603603
604-
The method initially looks for 'docker compose' v2
604+
The method initially looks for 'docker compose' v2+
605605
executable, if not found looks for 'docker-compose' executable.
606606
607607
Returns:
@@ -625,10 +625,12 @@ def _get_compose_cmd_prefix(self) -> List[str]:
625625
"Proceeding to check for 'docker-compose' CLI."
626626
)
627627

628-
if output and "v2" in output.strip():
629-
logger.info("'Docker Compose' found using Docker CLI.")
630-
compose_cmd_prefix.extend(["docker", "compose"])
631-
return compose_cmd_prefix
628+
if output:
629+
match = re.search(r"v(\d+)", output.strip())
630+
if match and int(match.group(1)) >= 2:
631+
logger.info("'Docker Compose' found using Docker CLI.")
632+
compose_cmd_prefix.extend(["docker", "compose"])
633+
return compose_cmd_prefix
632634

633635
if shutil.which("docker-compose") is not None:
634636
logger.info("'Docker Compose' found using Docker Compose CLI.")

sagemaker-train/tests/unit/train/local/test_local_container.py

Lines changed: 122 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,13 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
from unittest.mock import patch, call
13+
from unittest.mock import patch, call, Mock
1414
import pytest
15+
import subprocess
1516

16-
from sagemaker.train.local.local_container import _rmtree
17+
from sagemaker.train.local.local_container import _rmtree, _LocalContainer
18+
from sagemaker.core.shapes import DataSource, S3DataSource
19+
from sagemaker.core.shapes import Channel
1720

1821
IMAGE = "763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.1-cpu-py310"
1922

@@ -78,3 +81,120 @@ def test_rmtree_no_image_raises(self, mock_rmtree):
7881

7982
with pytest.raises(PermissionError):
8083
_rmtree("/tmp/test")
84+
85+
86+
@pytest.fixture
87+
def _basic_channel():
88+
"""Create a basic channel for testing."""
89+
data_source = DataSource(
90+
s3_data_source=S3DataSource(
91+
s3_uri="s3://bucket/data",
92+
s3_data_type="S3Prefix",
93+
s3_data_distribution_type="FullyReplicated",
94+
)
95+
)
96+
return Channel(channel_name="training", data_source=data_source)
97+
98+
99+
@pytest.fixture
100+
def _mock_session():
101+
"""Create a mock SageMaker session."""
102+
session = Mock()
103+
session.boto_region_name = "us-west-2"
104+
session.boto_session = Mock()
105+
session.s3_resource = Mock()
106+
session.s3_resource.meta.client._endpoint.host = "https://s3.us-west-2.amazonaws.com"
107+
return session
108+
109+
110+
def _make_container(_mock_session, _basic_channel):
111+
"""Helper to create a _LocalContainer for compose prefix tests."""
112+
return _LocalContainer(
113+
training_job_name="test-job",
114+
instance_type="local",
115+
instance_count=1,
116+
image="test-image:latest",
117+
container_root="/tmp/test",
118+
input_data_config=[_basic_channel],
119+
environment={},
120+
hyper_parameters={},
121+
container_entrypoint=[],
122+
container_arguments=[],
123+
sagemaker_session=_mock_session,
124+
)
125+
126+
127+
class TestGetComposeCmdPrefix:
128+
"""Test cases for _get_compose_cmd_prefix version detection."""
129+
130+
@patch("sagemaker.train.local.local_container.subprocess.check_output")
131+
def test_get_compose_cmd_prefix_with_docker_compose_v2(self, mock_check_output, _mock_session, _basic_channel):
132+
"""Docker Compose v2 should be accepted."""
133+
container = _make_container(_mock_session, _basic_channel)
134+
mock_check_output.return_value = "Docker Compose version v2.20.0"
135+
result = container._get_compose_cmd_prefix()
136+
assert result == ["docker", "compose"]
137+
138+
@patch("sagemaker.train.local.local_container.subprocess.check_output")
139+
def test_get_compose_cmd_prefix_with_docker_compose_v5(self, mock_check_output, _mock_session, _basic_channel):
140+
"""Docker Compose v5 should be accepted."""
141+
container = _make_container(_mock_session, _basic_channel)
142+
mock_check_output.return_value = "Docker Compose version v5.1.1"
143+
result = container._get_compose_cmd_prefix()
144+
assert result == ["docker", "compose"]
145+
146+
@patch("sagemaker.train.local.local_container.subprocess.check_output")
147+
def test_get_compose_cmd_prefix_with_docker_compose_v3(self, mock_check_output, _mock_session, _basic_channel):
148+
"""Docker Compose v3 should be accepted."""
149+
container = _make_container(_mock_session, _basic_channel)
150+
mock_check_output.return_value = "Docker Compose version v3.0.0"
151+
result = container._get_compose_cmd_prefix()
152+
assert result == ["docker", "compose"]
153+
154+
@patch("sagemaker.train.local.local_container.shutil.which")
155+
@patch("sagemaker.train.local.local_container.subprocess.check_output")
156+
def test_get_compose_cmd_prefix_with_docker_compose_v1_falls_through(
157+
self, mock_check_output, mock_which, _mock_session, _basic_channel
158+
):
159+
"""Docker Compose v1 should not be accepted; falls through to docker-compose standalone."""
160+
container = _make_container(_mock_session, _basic_channel)
161+
mock_check_output.return_value = "docker-compose version 1.29.2"
162+
mock_which.return_value = "/usr/bin/docker-compose"
163+
result = container._get_compose_cmd_prefix()
164+
assert result == ["docker-compose"]
165+
166+
@patch("sagemaker.train.local.local_container.shutil.which")
167+
@patch("sagemaker.train.local.local_container.subprocess.check_output")
168+
def test_get_compose_cmd_prefix_with_docker_compose_v1_no_standalone_raises(
169+
self, mock_check_output, mock_which, _mock_session, _basic_channel
170+
):
171+
"""Docker Compose v1 with no standalone fallback should raise ImportError."""
172+
container = _make_container(_mock_session, _basic_channel)
173+
mock_check_output.return_value = "docker-compose version v1.29.2"
174+
mock_which.return_value = None
175+
with pytest.raises(ImportError, match="Docker Compose is not installed"):
176+
container._get_compose_cmd_prefix()
177+
178+
@patch("sagemaker.train.local.local_container.shutil.which")
179+
@patch("sagemaker.train.local.local_container.subprocess.check_output")
180+
def test_get_compose_cmd_prefix_not_installed_raises(
181+
self, mock_check_output, mock_which, _mock_session, _basic_channel
182+
):
183+
"""When docker compose is not installed at all, should raise ImportError."""
184+
container = _make_container(_mock_session, _basic_channel)
185+
mock_check_output.side_effect = subprocess.CalledProcessError(1, "cmd")
186+
mock_which.return_value = None
187+
with pytest.raises(ImportError, match="Docker Compose is not installed"):
188+
container._get_compose_cmd_prefix()
189+
190+
@patch("sagemaker.train.local.local_container.shutil.which")
191+
@patch("sagemaker.train.local.local_container.subprocess.check_output")
192+
def test_get_compose_cmd_prefix_standalone_fallback(
193+
self, mock_check_output, mock_which, _mock_session, _basic_channel
194+
):
195+
"""When docker compose plugin fails, falls back to docker-compose standalone."""
196+
container = _make_container(_mock_session, _basic_channel)
197+
mock_check_output.side_effect = subprocess.CalledProcessError(1, "cmd")
198+
mock_which.return_value = "/usr/local/bin/docker-compose"
199+
result = container._get_compose_cmd_prefix()
200+
assert result == ["docker-compose"]

0 commit comments

Comments
 (0)