Skip to content

Commit 0f43e56

Browse files
高魏洪qwencoder
andcommitted
refactor(model):重构模型下载逻辑,支持多种解决方案并优化资源自动部署
Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>
1 parent d2f1a3b commit 0f43e56

3 files changed

Lines changed: 683 additions & 277 deletions

File tree

Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
// 文生图服务下载,删除逻辑
2+
import logger from '../../logger';
3+
import DevClient, * as $Dev20230714 from '@alicloud/devs20230714';
4+
import * as $OpenApi from '@alicloud/openapi-client';
5+
import { MODEL_DOWNLOAD_TIMEOUT, NEW_MODEL_SERVICE_CLIENT_CONNECT_TIMEOUT, NEW_MODEL_SERVICE_CLIENT_READ_TIMEOUT } from '.';
6+
import { IInputs } from '../../interface';
7+
import { sleep } from '../../utils';
8+
import _ from 'lodash';
9+
10+
export class ArtModelService {
11+
logger = logger;
12+
client: DevClient;
13+
constructor(private inputs: IInputs) {
14+
}
15+
16+
async initClient() {
17+
const {
18+
AccessKeyID: accessKeyId,
19+
AccessKeySecret: accessKeySecret,
20+
SecurityToken: securityToken,
21+
} = await this.inputs.getCredential();
22+
23+
let endpoint: string;
24+
25+
endpoint = 'devs-pre.cn-hangzhou.aliyuncs.com';
26+
if (process.env.ARTIFACT_ENDPOINT) {
27+
endpoint = process.env.ARTIFACT_ENDPOINT;
28+
}
29+
if (process.env.artifact_endpoint) {
30+
endpoint = process.env.artifact_endpoint;
31+
}
32+
33+
const protocol = 'https';
34+
35+
const config = new $OpenApi.Config({
36+
accessKeyId,
37+
accessKeySecret,
38+
securityToken,
39+
protocol,
40+
endpoint,
41+
readTimeout: NEW_MODEL_SERVICE_CLIENT_READ_TIMEOUT,
42+
connectTimeout: NEW_MODEL_SERVICE_CLIENT_CONNECT_TIMEOUT,
43+
userAgent: `${this.inputs.userAgent ||
44+
`Component:cap-model;Nodejs:${process.version};OS:${process.platform}-${process.arch}`
45+
}`,
46+
});
47+
48+
logger.info(`new models service init, ARTIFACT_ENDPOINT endpoint: ${config.endpoint}`);
49+
50+
return new DevClient(config);
51+
}
52+
53+
async downloadModel(name, params) {
54+
const devClient = await this.initClient();
55+
const {
56+
nasMountPoints,
57+
ossMountPoints,
58+
role,
59+
modelConfig,
60+
vpcConfig,
61+
storage,
62+
region
63+
} = params;
64+
const files = modelConfig.files;
65+
if (!_.isEmpty(files)) {
66+
// 先统一获取已有的任务列表
67+
const ListFileManagerTasksRequest = new $Dev20230714.ListFileManagerTasksRequest({ name });
68+
const res = await devClient.listFileManagerTasks(ListFileManagerTasksRequest);
69+
logger.debug('listFileManagerTasks', JSON.stringify(res, null, 2));
70+
const existingTasks = res.body.data.tasks;
71+
72+
// 第一步:筛选真正需要下载的文件
73+
const filesNeedPromises = files.map(async file => {
74+
const { source, destination } = this.getSource(storage, modelConfig.source.uri, file, nasMountPoints, ossMountPoints, region);
75+
76+
const needDownload = !existingTasks.some(
77+
task =>
78+
task.finished &&
79+
task.success &&
80+
task.progress.currentBytes === task.progress.totalBytes &&
81+
task.parameters.destination === destination &&
82+
task.parameters.source === source
83+
);
84+
85+
if (!needDownload) {
86+
logger.info(`[Download-model] ${file.source.path} Download model finished.`);
87+
return null;
88+
}
89+
90+
return {
91+
...file,
92+
source,
93+
fileName: file.target.path,
94+
destination
95+
};
96+
});
97+
98+
const filesNeedResults = await Promise.all(filesNeedPromises);
99+
const filesNeed = filesNeedResults.filter(Boolean);
100+
101+
// 添加调试日志
102+
logger.info(`[Download-model] Total files to check: ${files.length}`);
103+
logger.info(`[Download-model] Files need to download: ${filesNeed.length}`);
104+
105+
const filesApi = filesNeed.map(async (file) => {
106+
const { source, destination } = file;
107+
const fileManagerRsyncRequest = new $Dev20230714.FileManagerRsyncRequest({
108+
mountConfig: new $Dev20230714.FileManagerMountConfig({
109+
name,
110+
nasMountPoints,
111+
ossMountPoints,
112+
role,
113+
region,
114+
vpcConfig
115+
}),
116+
source,
117+
destination,
118+
conflictHandling: process.env.MODEL_CONFLIC_HANDLING || modelConfig.conflictResolution
119+
})
120+
logger.debug('FileManagerRsyncRequest', JSON.stringify(fileManagerRsyncRequest, null, 2))
121+
const req = await devClient.fileManagerRsync(fileManagerRsyncRequest)
122+
logger.debug(`[Download-model] fileManagerRsync response for ${file.fileName}: ${JSON.stringify(req.body, null, 2)}`)
123+
if (!req.body.success) {
124+
throw new Error(`fileManagerRsync error: ${JSON.stringify(req.body, null, 2)}`)
125+
}
126+
127+
const taskID = req.body.data.taskID;
128+
logger.info(`[Download-model] download model requestId for ${file.fileName}: ${req.body.requestId}, taskID: ${taskID}`);
129+
const shouldContinue = true;
130+
while (shouldContinue) {
131+
// eslint-disable-next-line no-await-in-loop
132+
const getFileManager = await devClient.getFileManagerTask(taskID)
133+
logger.debug('getFileManagerTask', JSON.stringify(getFileManager, null, 2));
134+
const modelStatus = getFileManager.body.data;
135+
const totalBytes = modelStatus.progress.totalBytes as any - 0;
136+
const currentBytes = modelStatus.progress.currentBytes as any - 0;
137+
138+
if (modelStatus.finished) {
139+
// 如果存在错误信息,则抛出异常
140+
if (modelStatus.errorMessage) {
141+
logger.error(`[Download-model] ${file.fileName}: ${modelStatus.errorMessage}`);
142+
throw new Error(`[Download-model] ${file.fileName}: ${modelStatus.errorMessage}`);
143+
}
144+
// 下载成功完成
145+
this._displayProgressComplete(file.fileName, currentBytes, totalBytes);
146+
if (modelStatus.progress.total) {
147+
const durationMs = modelStatus.finishedTime - modelStatus.startTime;
148+
const durationSeconds = Math.floor(durationMs / 1000);
149+
logger.info(`Time taken for model download: ${durationSeconds}s.`);
150+
}
151+
logger.info(`[Download-model] ${file.fileName} Download model finished.`);
152+
return true;
153+
}
154+
// 显示下载进度
155+
this._displayProgress(file.fileName, currentBytes, totalBytes);
156+
157+
if (Date.now() - modelStatus.startTime > MODEL_DOWNLOAD_TIMEOUT) {
158+
// 清除进度条并换行
159+
process.stdout.write('\n');
160+
const errorMessage = `[Model-download] ${file.fileName} Download timeout after ${MODEL_DOWNLOAD_TIMEOUT / 1000 / 60
161+
} minutes`;
162+
throw new Error(errorMessage);
163+
}
164+
165+
// 根据文件大小调整轮询间隔
166+
let sleepTime = 2; // 默认2秒
167+
if (totalBytes !== undefined && totalBytes > 1024 * 1024 * 1024) {
168+
// 文件大于1GB时,轮询间隔为10秒
169+
sleepTime = 10;
170+
}
171+
172+
await sleep(sleepTime);
173+
}
174+
})
175+
await Promise.all(filesApi)
176+
}
177+
}
178+
179+
async removeModel(name, params) {
180+
const {
181+
nasMountPoints,
182+
ossMountPoints,
183+
role,
184+
vpcConfig,
185+
modelConfig,
186+
} = params;
187+
const devClient = await this.initClient();
188+
const files = modelConfig.files;
189+
190+
const filesRm = files.map(async file => {
191+
let filepath = this._getDestinationPath(file.target.path, nasMountPoints, ossMountPoints)
192+
193+
const fileManagerRmRequest = new $Dev20230714.FileManagerRmRequest({
194+
filepath,
195+
mountConfignew: new $Dev20230714.FileManagerMountConfig({
196+
name,
197+
nasMountPoints,
198+
ossMountPoints,
199+
role,
200+
vpcConfig
201+
}),
202+
skipParentDirectory: true
203+
})
204+
await devClient.fileManagerRm(fileManagerRmRequest);
205+
})
206+
207+
await Promise.all(filesRm);
208+
}
209+
210+
getSource(storage, uri, file, nasMountPoints, ossMountPoints, region) {
211+
let source, destination;
212+
const validSourcePattern = /^(modelscope|oss|nas):\/\//;
213+
// 处理源路径
214+
source = this._getSourcePath(file.source.path, uri, validSourcePattern);
215+
// 处理目标路径
216+
destination = this._getDestinationPath(file.target.path, nasMountPoints, ossMountPoints);
217+
218+
return {
219+
source,
220+
destination
221+
}
222+
}
223+
224+
private _displayProgressComplete(filePath: string, currentBytes: number, totalBytes: number) {
225+
if (totalBytes && currentBytes !== undefined) {
226+
const currentMB = (currentBytes / 1024 / 1024).toFixed(1);
227+
const totalMB = (totalBytes / 1024 / 1024).toFixed(1);
228+
229+
const totalBars = 50;
230+
const progressBar = '='.repeat(totalBars);
231+
232+
process.stdout.write(
233+
`\r[Download-model] ${filePath} [${progressBar}] 100.00% (${currentMB}MB/${totalMB}MB)\n`,
234+
);
235+
} else {
236+
process.stdout.write('\n');
237+
}
238+
// 清除进度条并换行
239+
process.stdout.write('\n');
240+
}
241+
242+
private _displayProgress(filePath: string, currentBytes: number, totalBytes: number) {
243+
if (currentBytes && totalBytes) {
244+
const percentage = (currentBytes / totalBytes) * 100;
245+
const currentMB = (currentBytes / 1024 / 1024).toFixed(1);
246+
const totalMB = (totalBytes / 1024 / 1024).toFixed(1);
247+
248+
// 每个等号代表2%,向下取整计算等号数量
249+
const totalBars = 50; // 总共50个字符位置
250+
const filledBars = Math.min(totalBars, Math.floor(percentage / 2)); // 每个等号代表2%
251+
const emptyBars = totalBars - filledBars;
252+
253+
const progressBar = '='.repeat(filledBars) + '.'.repeat(emptyBars);
254+
255+
process.stdout.write(
256+
`\r[Download-model] ${filePath} [${progressBar}] ${percentage.toFixed(
257+
2,
258+
)}% (${currentMB}MB/${totalMB}MB)`,
259+
);
260+
}
261+
}
262+
263+
private _getSourcePath(filePath: string, uri: string, validSourcePattern: RegExp): string {
264+
if (validSourcePattern.test(filePath)) {
265+
return filePath;
266+
} else if (validSourcePattern.test(uri)) {
267+
const downloadUri = uri.endsWith('/') && uri.length > 1 ? uri.slice(0, -1) : uri;
268+
return `${downloadUri}/${filePath}`;
269+
} else {
270+
throw new Error(`Invalid source path. Expected a valid URI starting with 'modelscope://', 'oss://', or 'nas://', but got: ${filePath}`);
271+
}
272+
}
273+
274+
private _getDestinationPath(
275+
targetPath: string,
276+
nasMountPoints: any[],
277+
ossMountPoints: any[],
278+
): string {
279+
const nasProtocolPattern = /^(nas):\/\//;
280+
const ossProtocolPattern = /^(oss):\/\//;
281+
// 通过 targetPath 是nas://或者oss://前缀来判断使用哪个路径
282+
if (nasProtocolPattern.test(targetPath)) {
283+
const pathWithoutProtocol = targetPath.replace(nasProtocolPattern, '');
284+
return pathWithoutProtocol.startsWith('/') ? `file://${nasMountPoints[0].mountDir}${pathWithoutProtocol}` : `file://${pathWithoutProtocol}`;
285+
} else if (ossProtocolPattern.test(targetPath)) {
286+
const pathWithoutProtocol = targetPath.replace(ossProtocolPattern, '');
287+
return pathWithoutProtocol.startsWith('/') ? `file://${ossMountPoints[0].mountDir}${pathWithoutProtocol}` : `file://${pathWithoutProtocol}`;
288+
}
289+
}
290+
}

0 commit comments

Comments
 (0)