1+ // 文生图服务下载,删除逻辑
2+ import logger from '../../logger' ;
3+ import DevClient , * as $Dev20230714 from '@alicloud/devs20230714' ;
4+ import * as $OpenApi from '@alicloud/openapi-client' ;
5+ import { MODEL_DOWNLOAD_TIMEOUT , NEW_MODEL_SERVICE_CLIENT_CONNECT_TIMEOUT , NEW_MODEL_SERVICE_CLIENT_READ_TIMEOUT } from '.' ;
6+ import { IInputs } from '../../interface' ;
7+ import { sleep } from '../../utils' ;
8+ import _ from 'lodash' ;
9+
10+ export class ArtModelService {
11+ logger = logger ;
12+ client : DevClient ;
13+ constructor ( private inputs : IInputs ) {
14+ }
15+
16+ async initClient ( ) {
17+ const {
18+ AccessKeyID : accessKeyId ,
19+ AccessKeySecret : accessKeySecret ,
20+ SecurityToken : securityToken ,
21+ } = await this . inputs . getCredential ( ) ;
22+
23+ let endpoint : string ;
24+
25+ endpoint = 'devs-pre.cn-hangzhou.aliyuncs.com' ;
26+ if ( process . env . ARTIFACT_ENDPOINT ) {
27+ endpoint = process . env . ARTIFACT_ENDPOINT ;
28+ }
29+ if ( process . env . artifact_endpoint ) {
30+ endpoint = process . env . artifact_endpoint ;
31+ }
32+
33+ const protocol = 'https' ;
34+
35+ const config = new $OpenApi . Config ( {
36+ accessKeyId,
37+ accessKeySecret,
38+ securityToken,
39+ protocol,
40+ endpoint,
41+ readTimeout : NEW_MODEL_SERVICE_CLIENT_READ_TIMEOUT ,
42+ connectTimeout : NEW_MODEL_SERVICE_CLIENT_CONNECT_TIMEOUT ,
43+ userAgent : `${ this . inputs . userAgent ||
44+ `Component:cap-model;Nodejs:${ process . version } ;OS:${ process . platform } -${ process . arch } `
45+ } `,
46+ } ) ;
47+
48+ logger . info ( `new models service init, ARTIFACT_ENDPOINT endpoint: ${ config . endpoint } ` ) ;
49+
50+ return new DevClient ( config ) ;
51+ }
52+
53+ async downloadModel ( name , params ) {
54+ const devClient = await this . initClient ( ) ;
55+ const {
56+ nasMountPoints,
57+ ossMountPoints,
58+ role,
59+ modelConfig,
60+ vpcConfig,
61+ storage,
62+ region
63+ } = params ;
64+ const files = modelConfig . files ;
65+ if ( ! _ . isEmpty ( files ) ) {
66+ // 先统一获取已有的任务列表
67+ const ListFileManagerTasksRequest = new $Dev20230714 . ListFileManagerTasksRequest ( { name } ) ;
68+ const res = await devClient . listFileManagerTasks ( ListFileManagerTasksRequest ) ;
69+ logger . debug ( 'listFileManagerTasks' , JSON . stringify ( res , null , 2 ) ) ;
70+ const existingTasks = res . body . data . tasks ;
71+
72+ // 第一步:筛选真正需要下载的文件
73+ const filesNeedPromises = files . map ( async file => {
74+ const { source, destination } = this . getSource ( storage , modelConfig . source . uri , file , nasMountPoints , ossMountPoints , region ) ;
75+
76+ const needDownload = ! existingTasks . some (
77+ task =>
78+ task . finished &&
79+ task . success &&
80+ task . progress . currentBytes === task . progress . totalBytes &&
81+ task . parameters . destination === destination &&
82+ task . parameters . source === source
83+ ) ;
84+
85+ if ( ! needDownload ) {
86+ logger . info ( `[Download-model] ${ file . source . path } Download model finished.` ) ;
87+ return null ;
88+ }
89+
90+ return {
91+ ...file ,
92+ source,
93+ fileName : file . target . path ,
94+ destination
95+ } ;
96+ } ) ;
97+
98+ const filesNeedResults = await Promise . all ( filesNeedPromises ) ;
99+ const filesNeed = filesNeedResults . filter ( Boolean ) ;
100+
101+ // 添加调试日志
102+ logger . info ( `[Download-model] Total files to check: ${ files . length } ` ) ;
103+ logger . info ( `[Download-model] Files need to download: ${ filesNeed . length } ` ) ;
104+
105+ const filesApi = filesNeed . map ( async ( file ) => {
106+ const { source, destination } = file ;
107+ const fileManagerRsyncRequest = new $Dev20230714 . FileManagerRsyncRequest ( {
108+ mountConfig : new $Dev20230714 . FileManagerMountConfig ( {
109+ name,
110+ nasMountPoints,
111+ ossMountPoints,
112+ role,
113+ region,
114+ vpcConfig
115+ } ) ,
116+ source,
117+ destination,
118+ conflictHandling : process . env . MODEL_CONFLIC_HANDLING || modelConfig . conflictResolution
119+ } )
120+ logger . debug ( 'FileManagerRsyncRequest' , JSON . stringify ( fileManagerRsyncRequest , null , 2 ) )
121+ const req = await devClient . fileManagerRsync ( fileManagerRsyncRequest )
122+ logger . debug ( `[Download-model] fileManagerRsync response for ${ file . fileName } : ${ JSON . stringify ( req . body , null , 2 ) } ` )
123+ if ( ! req . body . success ) {
124+ throw new Error ( `fileManagerRsync error: ${ JSON . stringify ( req . body , null , 2 ) } ` )
125+ }
126+
127+ const taskID = req . body . data . taskID ;
128+ logger . info ( `[Download-model] download model requestId for ${ file . fileName } : ${ req . body . requestId } , taskID: ${ taskID } ` ) ;
129+ const shouldContinue = true ;
130+ while ( shouldContinue ) {
131+ // eslint-disable-next-line no-await-in-loop
132+ const getFileManager = await devClient . getFileManagerTask ( taskID )
133+ logger . debug ( 'getFileManagerTask' , JSON . stringify ( getFileManager , null , 2 ) ) ;
134+ const modelStatus = getFileManager . body . data ;
135+ const totalBytes = modelStatus . progress . totalBytes as any - 0 ;
136+ const currentBytes = modelStatus . progress . currentBytes as any - 0 ;
137+
138+ if ( modelStatus . finished ) {
139+ // 如果存在错误信息,则抛出异常
140+ if ( modelStatus . errorMessage ) {
141+ logger . error ( `[Download-model] ${ file . fileName } : ${ modelStatus . errorMessage } ` ) ;
142+ throw new Error ( `[Download-model] ${ file . fileName } : ${ modelStatus . errorMessage } ` ) ;
143+ }
144+ // 下载成功完成
145+ this . _displayProgressComplete ( file . fileName , currentBytes , totalBytes ) ;
146+ if ( modelStatus . progress . total ) {
147+ const durationMs = modelStatus . finishedTime - modelStatus . startTime ;
148+ const durationSeconds = Math . floor ( durationMs / 1000 ) ;
149+ logger . info ( `Time taken for model download: ${ durationSeconds } s.` ) ;
150+ }
151+ logger . info ( `[Download-model] ${ file . fileName } Download model finished.` ) ;
152+ return true ;
153+ }
154+ // 显示下载进度
155+ this . _displayProgress ( file . fileName , currentBytes , totalBytes ) ;
156+
157+ if ( Date . now ( ) - modelStatus . startTime > MODEL_DOWNLOAD_TIMEOUT ) {
158+ // 清除进度条并换行
159+ process . stdout . write ( '\n' ) ;
160+ const errorMessage = `[Model-download] ${ file . fileName } Download timeout after ${ MODEL_DOWNLOAD_TIMEOUT / 1000 / 60
161+ } minutes`;
162+ throw new Error ( errorMessage ) ;
163+ }
164+
165+ // 根据文件大小调整轮询间隔
166+ let sleepTime = 2 ; // 默认2秒
167+ if ( totalBytes !== undefined && totalBytes > 1024 * 1024 * 1024 ) {
168+ // 文件大于1GB时,轮询间隔为10秒
169+ sleepTime = 10 ;
170+ }
171+
172+ await sleep ( sleepTime ) ;
173+ }
174+ } )
175+ await Promise . all ( filesApi )
176+ }
177+ }
178+
179+ async removeModel ( name , params ) {
180+ const {
181+ nasMountPoints,
182+ ossMountPoints,
183+ role,
184+ vpcConfig,
185+ modelConfig,
186+ } = params ;
187+ const devClient = await this . initClient ( ) ;
188+ const files = modelConfig . files ;
189+
190+ const filesRm = files . map ( async file => {
191+ let filepath = this . _getDestinationPath ( file . target . path , nasMountPoints , ossMountPoints )
192+
193+ const fileManagerRmRequest = new $Dev20230714 . FileManagerRmRequest ( {
194+ filepath,
195+ mountConfignew : new $Dev20230714 . FileManagerMountConfig ( {
196+ name,
197+ nasMountPoints,
198+ ossMountPoints,
199+ role,
200+ vpcConfig
201+ } ) ,
202+ skipParentDirectory : true
203+ } )
204+ await devClient . fileManagerRm ( fileManagerRmRequest ) ;
205+ } )
206+
207+ await Promise . all ( filesRm ) ;
208+ }
209+
210+ getSource ( storage , uri , file , nasMountPoints , ossMountPoints , region ) {
211+ let source , destination ;
212+ const validSourcePattern = / ^ ( m o d e l s c o p e | o s s | n a s ) : \/ \/ / ;
213+ // 处理源路径
214+ source = this . _getSourcePath ( file . source . path , uri , validSourcePattern ) ;
215+ // 处理目标路径
216+ destination = this . _getDestinationPath ( file . target . path , nasMountPoints , ossMountPoints ) ;
217+
218+ return {
219+ source,
220+ destination
221+ }
222+ }
223+
224+ private _displayProgressComplete ( filePath : string , currentBytes : number , totalBytes : number ) {
225+ if ( totalBytes && currentBytes !== undefined ) {
226+ const currentMB = ( currentBytes / 1024 / 1024 ) . toFixed ( 1 ) ;
227+ const totalMB = ( totalBytes / 1024 / 1024 ) . toFixed ( 1 ) ;
228+
229+ const totalBars = 50 ;
230+ const progressBar = '=' . repeat ( totalBars ) ;
231+
232+ process . stdout . write (
233+ `\r[Download-model] ${ filePath } [${ progressBar } ] 100.00% (${ currentMB } MB/${ totalMB } MB)\n` ,
234+ ) ;
235+ } else {
236+ process . stdout . write ( '\n' ) ;
237+ }
238+ // 清除进度条并换行
239+ process . stdout . write ( '\n' ) ;
240+ }
241+
242+ private _displayProgress ( filePath : string , currentBytes : number , totalBytes : number ) {
243+ if ( currentBytes && totalBytes ) {
244+ const percentage = ( currentBytes / totalBytes ) * 100 ;
245+ const currentMB = ( currentBytes / 1024 / 1024 ) . toFixed ( 1 ) ;
246+ const totalMB = ( totalBytes / 1024 / 1024 ) . toFixed ( 1 ) ;
247+
248+ // 每个等号代表2%,向下取整计算等号数量
249+ const totalBars = 50 ; // 总共50个字符位置
250+ const filledBars = Math . min ( totalBars , Math . floor ( percentage / 2 ) ) ; // 每个等号代表2%
251+ const emptyBars = totalBars - filledBars ;
252+
253+ const progressBar = '=' . repeat ( filledBars ) + '.' . repeat ( emptyBars ) ;
254+
255+ process . stdout . write (
256+ `\r[Download-model] ${ filePath } [${ progressBar } ] ${ percentage . toFixed (
257+ 2 ,
258+ ) } % (${ currentMB } MB/${ totalMB } MB)`,
259+ ) ;
260+ }
261+ }
262+
263+ private _getSourcePath ( filePath : string , uri : string , validSourcePattern : RegExp ) : string {
264+ if ( validSourcePattern . test ( filePath ) ) {
265+ return filePath ;
266+ } else if ( validSourcePattern . test ( uri ) ) {
267+ const downloadUri = uri . endsWith ( '/' ) && uri . length > 1 ? uri . slice ( 0 , - 1 ) : uri ;
268+ return `${ downloadUri } /${ filePath } ` ;
269+ } else {
270+ throw new Error ( `Invalid source path. Expected a valid URI starting with 'modelscope://', 'oss://', or 'nas://', but got: ${ filePath } ` ) ;
271+ }
272+ }
273+
274+ private _getDestinationPath (
275+ targetPath : string ,
276+ nasMountPoints : any [ ] ,
277+ ossMountPoints : any [ ] ,
278+ ) : string {
279+ const nasProtocolPattern = / ^ ( n a s ) : \/ \/ / ;
280+ const ossProtocolPattern = / ^ ( o s s ) : \/ \/ / ;
281+ // 通过 targetPath 是nas://或者oss://前缀来判断使用哪个路径
282+ if ( nasProtocolPattern . test ( targetPath ) ) {
283+ const pathWithoutProtocol = targetPath . replace ( nasProtocolPattern , '' ) ;
284+ return pathWithoutProtocol . startsWith ( '/' ) ? `file://${ nasMountPoints [ 0 ] . mountDir } ${ pathWithoutProtocol } ` : `file://${ pathWithoutProtocol } ` ;
285+ } else if ( ossProtocolPattern . test ( targetPath ) ) {
286+ const pathWithoutProtocol = targetPath . replace ( ossProtocolPattern , '' ) ;
287+ return pathWithoutProtocol . startsWith ( '/' ) ? `file://${ ossMountPoints [ 0 ] . mountDir } ${ pathWithoutProtocol } ` : `file://${ pathWithoutProtocol } ` ;
288+ }
289+ }
290+ }
0 commit comments