@@ -1250,3 +1250,131 @@ func TestDeployer_Wakeup(t *testing.T) {
12501250 time .Sleep (100 * time .Millisecond )
12511251 })
12521252}
1253+
1254+ // TestDeployer_serverlessDeploy_PD verifies that PD config is correctly
1255+ // propagated from DeployRequest to the database Deploy record in
1256+ // serverlessDeploy. Before the fix, field shadowing caused dr.PD to always
1257+ // be nil even when DeployExtend.PD was set.
1258+ func TestDeployer_serverlessDeploy_PD (t * testing.T ) {
1259+ t .Run ("deploy model with PD" , func (t * testing.T ) {
1260+ var oldDeploy database.Deploy
1261+ oldDeploy .ID = 1
1262+
1263+ pdConfig := & types.PDConfig {
1264+ Enabled : true ,
1265+ PrefillReplicas : 2 ,
1266+ DecodeReplicas : 2 ,
1267+ Prefill : & types.PDRoleRuntimeConfig {
1268+ TP : 2 , EP : 1 , DP : 1 , TotalGPUs : 2 ,
1269+ },
1270+ Decode : & types.PDRoleRuntimeConfig {
1271+ TP : 2 , EP : 1 , DP : 1 , TotalGPUs : 2 ,
1272+ },
1273+ }
1274+
1275+ dr := types.DeployRequest {
1276+ RepoID : 1 ,
1277+ Type : types .InferenceType ,
1278+ UserUUID : "1" ,
1279+ SKU : "1" ,
1280+ DeployExtend : types.DeployExtend {
1281+ PD : pdConfig ,
1282+ },
1283+ }
1284+
1285+ newDeploy := oldDeploy
1286+ newDeploy .UserUUID = dr .UserUUID
1287+ newDeploy .SKU = dr .SKU
1288+ newDeploy .PD = dr .PD
1289+
1290+ mockTaskStore := mockdb .NewMockDeployTaskStore (t )
1291+ mockTaskStore .EXPECT ().GetServerlessDeployByRepID (mock .Anything , dr .RepoID ).Return (& oldDeploy , nil )
1292+ mockTaskStore .EXPECT ().UpdateDeploy (mock .Anything , & newDeploy ).Return (nil )
1293+
1294+ d := & deployer {
1295+ deployTaskStore : mockTaskStore ,
1296+ }
1297+ dbdeploy , err := d .serverlessDeploy (context .TODO (), dr )
1298+ require .Nil (t , err )
1299+ require .NotNil (t , dbdeploy .PD )
1300+ require .True (t , dbdeploy .PD .Enabled )
1301+ require .Equal (t , 2 , dbdeploy .PD .PrefillReplicas )
1302+ require .Equal (t , 2 , dbdeploy .PD .DecodeReplicas )
1303+ require .Same (t , pdConfig , dbdeploy .PD )
1304+ })
1305+
1306+ t .Run ("deploy model without PD (nil)" , func (t * testing.T ) {
1307+ var oldDeploy database.Deploy
1308+ oldDeploy .ID = 1
1309+
1310+ dr := types.DeployRequest {
1311+ RepoID : 1 ,
1312+ Type : types .InferenceType ,
1313+ UserUUID : "1" ,
1314+ SKU : "1" ,
1315+ }
1316+
1317+ newDeploy := oldDeploy
1318+ newDeploy .UserUUID = dr .UserUUID
1319+ newDeploy .SKU = dr .SKU
1320+ newDeploy .PD = nil
1321+
1322+ mockTaskStore := mockdb .NewMockDeployTaskStore (t )
1323+ mockTaskStore .EXPECT ().GetServerlessDeployByRepID (mock .Anything , dr .RepoID ).Return (& oldDeploy , nil )
1324+ mockTaskStore .EXPECT ().UpdateDeploy (mock .Anything , & newDeploy ).Return (nil )
1325+
1326+ d := & deployer {
1327+ deployTaskStore : mockTaskStore ,
1328+ }
1329+ dbdeploy , err := d .serverlessDeploy (context .TODO (), dr )
1330+ require .Nil (t , err )
1331+ require .Nil (t , dbdeploy .PD )
1332+ })
1333+ }
1334+
1335+ // TestDeployer_dedicatedDeploy_PD verifies that PD config is correctly
1336+ // propagated from DeployRequest to the database Deploy record in
1337+ // dedicatedDeploy.
1338+ func TestDeployer_dedicatedDeploy_PD (t * testing.T ) {
1339+ pdConfig := & types.PDConfig {
1340+ Enabled : true ,
1341+ PrefillReplicas : 1 ,
1342+ DecodeReplicas : 1 ,
1343+ Prefill : & types.PDRoleRuntimeConfig {
1344+ TP : 2 , EP : 2 , DP : 1 , TotalGPUs : 2 ,
1345+ },
1346+ Decode : & types.PDRoleRuntimeConfig {
1347+ TP : 2 , EP : 2 , DP : 1 , TotalGPUs : 2 ,
1348+ },
1349+ }
1350+
1351+ dr := types.DeployRequest {
1352+ Path : "namespace/name" ,
1353+ Type : types .InferenceType ,
1354+ DeployExtend : types.DeployExtend {
1355+ PD : pdConfig ,
1356+ },
1357+ }
1358+
1359+ var capturedDeploy * database.Deploy
1360+ mockTaskStore := mockdb .NewMockDeployTaskStore (t )
1361+ mockTaskStore .EXPECT ().CreateDeploy (mock .Anything , mock .MatchedBy (func (d * database.Deploy ) bool {
1362+ capturedDeploy = d
1363+ return true
1364+ })).Return (nil )
1365+
1366+ node , _ := snowflake .NewNode (1 )
1367+ d := & deployer {
1368+ snowflakeNode : node ,
1369+ deployTaskStore : mockTaskStore ,
1370+ }
1371+
1372+ _ , err := d .dedicatedDeploy (context .TODO (), dr )
1373+ require .Nil (t , err )
1374+ require .NotNil (t , capturedDeploy )
1375+ require .NotNil (t , capturedDeploy .PD )
1376+ require .True (t , capturedDeploy .PD .Enabled )
1377+ require .Equal (t , 1 , capturedDeploy .PD .PrefillReplicas )
1378+ require .Equal (t , 1 , capturedDeploy .PD .DecodeReplicas )
1379+ require .Same (t , pdConfig , capturedDeploy .PD )
1380+ }
0 commit comments