@@ -30,6 +30,7 @@ import (
3030 "github.com/docker/go-connections/nat"
3131 "github.com/docker/go-units"
3232 "github.com/dstackai/dstack/runner/consts"
33+ "github.com/dstackai/dstack/runner/internal/common"
3334 "github.com/dstackai/dstack/runner/internal/log"
3435 "github.com/dstackai/dstack/runner/internal/shim/backends"
3536 "github.com/dstackai/dstack/runner/internal/shim/host"
@@ -54,7 +55,7 @@ type DockerRunner struct {
5455 dockerParams DockerParameters
5556 dockerInfo dockersystem.Info
5657 gpus []host.GpuInfo
57- gpuVendor host .GpuVendor
58+ gpuVendor common .GpuVendor
5859 gpuLock * GpuLock
5960 tasks TaskStorage
6061}
@@ -69,12 +70,12 @@ func NewDockerRunner(ctx context.Context, dockerParams DockerParameters) (*Docke
6970 return nil , tracerr .Wrap (err )
7071 }
7172
72- var gpuVendor host .GpuVendor
73+ var gpuVendor common .GpuVendor
7374 gpus := host .GetGpuInfo (ctx )
7475 if len (gpus ) > 0 {
7576 gpuVendor = gpus [0 ].Vendor
7677 } else {
77- gpuVendor = host .GpuVendorNone
78+ gpuVendor = common .GpuVendorNone
7879 }
7980 gpuLock , err := NewGpuLock (gpus )
8081 if err != nil {
@@ -134,7 +135,7 @@ func (d *DockerRunner) restoreStateFromContainers(ctx context.Context) error {
134135 log .Error (ctx , "failed to inspect container" , "id" , containerID , "task" , taskID )
135136 } else {
136137 switch d .gpuVendor {
137- case host .GpuVendorNvidia :
138+ case common .GpuVendorNvidia :
138139 deviceRequests := containerFull .HostConfig .Resources .DeviceRequests
139140 if len (deviceRequests ) == 1 {
140141 gpuIDs = deviceRequests [0 ].DeviceIDs
@@ -145,28 +146,28 @@ func (d *DockerRunner) restoreStateFromContainers(ctx context.Context) error {
145146 "id" , containerID , "task" , taskID ,
146147 )
147148 }
148- case host .GpuVendorAmd :
149+ case common .GpuVendorAmd :
149150 for _ , device := range containerFull .HostConfig .Resources .Devices {
150151 if host .IsRenderNodePath (device .PathOnHost ) {
151152 gpuIDs = append (gpuIDs , device .PathOnHost )
152153 }
153154 }
154- case host .GpuVendorTenstorrent :
155+ case common .GpuVendorTenstorrent :
155156 for _ , device := range containerFull .HostConfig .Resources .Devices {
156157 if strings .HasPrefix (device .PathOnHost , "/dev/tenstorrent/" ) {
157158 // Extract the device ID from the path
158159 deviceID := strings .TrimPrefix (device .PathOnHost , "/dev/tenstorrent/" )
159160 gpuIDs = append (gpuIDs , deviceID )
160161 }
161162 }
162- case host .GpuVendorIntel :
163+ case common .GpuVendorIntel :
163164 for _ , envVar := range containerFull .Config .Env {
164165 if indices , found := strings .CutPrefix (envVar , "HABANA_VISIBLE_DEVICES=" ); found {
165166 gpuIDs = strings .Split (indices , "," )
166167 break
167168 }
168169 }
169- case host .GpuVendorNone :
170+ case common .GpuVendorNone :
170171 gpuIDs = []string {}
171172 }
172173 ports = extractPorts (ctx , containerFull .NetworkSettings .Ports )
@@ -1014,12 +1015,12 @@ func configureGpuDevices(hostConfig *container.HostConfig, gpuDevices []GPUDevic
10141015 }
10151016}
10161017
1017- func configureGpus (config * container.Config , hostConfig * container.HostConfig , vendor host .GpuVendor , ids []string ) {
1018+ func configureGpus (config * container.Config , hostConfig * container.HostConfig , vendor common .GpuVendor , ids []string ) {
10181019 // NVIDIA: ids are identifiers reported by nvidia-smi, GPU-<UUID> strings
10191020 // AMD: ids are DRI render node paths, e.g., /dev/dri/renderD128
10201021 // Tenstorrent: ids are device indices to be used with /dev/tenstorrent/<id>
10211022 switch vendor {
1022- case host .GpuVendorNvidia :
1023+ case common .GpuVendorNvidia :
10231024 hostConfig .Resources .DeviceRequests = append (
10241025 hostConfig .Resources .DeviceRequests ,
10251026 container.DeviceRequest {
@@ -1030,7 +1031,7 @@ func configureGpus(config *container.Config, hostConfig *container.HostConfig, v
10301031 DeviceIDs : ids ,
10311032 },
10321033 )
1033- case host .GpuVendorAmd :
1034+ case common .GpuVendorAmd :
10341035 // All options are listed here: https://hub.docker.com/r/rocm/pytorch
10351036 // Only the --device options are mandatory; the others seem to be performance-related.
10361037 // --device=/dev/kfd
@@ -1060,7 +1061,7 @@ func configureGpus(config *container.Config, hostConfig *container.HostConfig, v
10601061 // --security-opt=seccomp=unconfined
10611062 hostConfig .SecurityOpt = append (hostConfig .SecurityOpt , "seccomp=unconfined" )
10621063 // TODO: in addition, for non-root user, --group-add=video, and possibly --group-add=render, are required.
1063- case host .GpuVendorTenstorrent :
1064+ case common .GpuVendorTenstorrent :
10641065 // For Tenstorrent, simply add each device
10651066 for _ , id := range ids {
10661067 devicePath := fmt .Sprintf ("/dev/tenstorrent/%s" , id )
@@ -1081,7 +1082,7 @@ func configureGpus(config *container.Config, hostConfig *container.HostConfig, v
10811082 Target : "/dev/hugepages-1G" ,
10821083 })
10831084 }
1084- case host .GpuVendorIntel :
1085+ case common .GpuVendorIntel :
10851086 // All options are listed here:
10861087 // https://docs.habana.ai/en/latest/Installation_Guide/Additional_Installation/Docker_Installation.html
10871088 // --runtime=habana
@@ -1092,7 +1093,7 @@ func configureGpus(config *container.Config, hostConfig *container.HostConfig, v
10921093 hostConfig .CapAdd = append (hostConfig .CapAdd , "SYS_NICE" )
10931094 // -e HABANA_VISIBLE_DEVICES=0,1,...
10941095 config .Env = append (config .Env , fmt .Sprintf ("HABANA_VISIBLE_DEVICES=%s" , strings .Join (ids , "," )))
1095- case host .GpuVendorNone :
1096+ case common .GpuVendorNone :
10961097 // nothing to do
10971098 }
10981099}
0 commit comments