diff --git a/CHANGELOG.md b/CHANGELOG.md index af023d5..a8a9d28 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,14 @@ * Added WebGPU readiness guidance covering browser capability checks, cross-origin isolation, bridge asset/version diagnostics, fallback behavior, model/configuration pressure, and the Flutter Web real-model smoke path. +* **Model download UX**: + * Added `ModelDownloadController`, a dependency-free helper that turns + `ModelDownloadManager` cache/download work into app-facing lifecycle states + for resolving, cache checks, downloads, verification, ready, failed, + cancelled, and retry flows. + * Wired the runnable chat app example through a `ModelDownloadManager` adapter + so its model-management UI demonstrates the controller while preserving the + example's multi-asset and web-cache service behavior. ## 0.6.13 diff --git a/example/chat_app/lib/screens/manage_models_screen.dart b/example/chat_app/lib/screens/manage_models_screen.dart index 796f087..f1a5602 100644 --- a/example/chat_app/lib/screens/manage_models_screen.dart +++ b/example/chat_app/lib/screens/manage_models_screen.dart @@ -12,6 +12,7 @@ import 'package:shared_preferences/shared_preferences.dart'; import '../models/downloadable_model.dart'; import '../providers/chat_provider.dart'; import '../services/hugging_face_model_discovery_service.dart'; +import '../services/model_download_controller_adapter.dart'; import '../services/model_service_base.dart'; import '../utils/backend_utils.dart'; import '../widgets/model_card.dart'; @@ -21,10 +22,19 @@ class ManageModelsScreen extends StatefulWidget { final VoidCallback? onModelActivated; final bool embeddedPanel; + // Test hooks for exercising download-controller wiring without relying on + // platform storage or the full built-in model catalog. + final ModelService? modelService; + final List? initialModels; + final bool? showModelLibraryInitially; + const ManageModelsScreen({ super.key, this.onModelActivated, this.embeddedPanel = false, + this.modelService, + this.initialModels, + this.showModelLibraryInitially, }); @override @@ -36,12 +46,10 @@ class _ManageModelsScreenState extends State static const String _customModelsPrefsKey = 'custom_hf_models_v1'; static const int _webLargeModelWarningBytes = 1900 * 1024 * 1024; - final ModelService _modelService = ModelService(); + late final ModelService _modelService; final HuggingFaceModelDiscoveryService _hfDiscoveryService = HuggingFaceModelDiscoveryService(); - final List _models = List.from( - DownloadableModel.defaultModels, - ); + final List _models = []; final List _customModels = []; final Map> @@ -49,7 +57,9 @@ class _ManageModelsScreenState extends State final Map _lastDownloadedBytes = {}; final Map _lastDownloadSampleAt = {}; final Map _smoothedDownloadRateBytesPerSec = {}; - final Map _cancelTokens = {}; + final Map _downloadControllers = {}; + final Map> + _downloadSubscriptions = {}; Set _downloadedFiles = {}; String? _modelsDir; @@ -62,10 +72,18 @@ class _ManageModelsScreenState extends State void initState() { super.initState(); WidgetsBinding.instance.addObserver(this); - _showModelLibrary = false; + _modelService = widget.modelService ?? ModelService(); + _models.addAll(_initialModelCatalog()); + _showModelLibrary = widget.showModelLibraryInitially ?? false; _initModelService(); } + List _initialModelCatalog() { + return List.from( + widget.initialModels ?? DownloadableModel.defaultModels, + ); + } + @override void didChangeAppLifecycleState(AppLifecycleState state) { if (kIsWeb) { @@ -74,7 +92,7 @@ class _ManageModelsScreenState extends State if (state == AppLifecycleState.paused || state == AppLifecycleState.hidden) { - _pauseActiveDownloads('App moved to background'); + _pauseActiveDownloads(); } } @@ -462,6 +480,23 @@ class _ManageModelsScreenState extends State return stageText; } + String? _downloadTaskLabel(ModelDownloadTaskSnapshot? task) { + if (task == null) { + return null; + } + return switch (task.stage) { + ModelDownloadTaskStage.idle => null, + ModelDownloadTaskStage.resolving => 'Resolving model', + ModelDownloadTaskStage.checkingCache => 'Checking cache', + ModelDownloadTaskStage.downloading => + kIsWeb ? 'Caching model' : 'Downloading model', + ModelDownloadTaskStage.verifying => 'Verifying model', + ModelDownloadTaskStage.ready => 'Ready', + ModelDownloadTaskStage.failed => task.errorMessage ?? 'Download failed', + ModelDownloadTaskStage.cancelled => 'Paused', + }; + } + String _downloadFailureMessage(dynamic error) { if (error is DioException) { final normalized = '${error.message ?? ''} ${error.error ?? ''}' @@ -582,7 +617,9 @@ class _ManageModelsScreenState extends State bool? isDownloading, double? progress, ModelDownloadProgress? detail, + ModelDownloadTaskSnapshot? task, bool clearDetail = false, + bool clearTask = false, bool clearProgress = false, }) { final notifier = _downloadUiStateFor(filename); @@ -591,7 +628,9 @@ class _ManageModelsScreenState extends State isDownloading: isDownloading, progress: clearProgress ? 0.0 : progress, detail: detail, + task: task, clearDetail: clearDetail, + clearTask: clearTask, ); } @@ -601,98 +640,154 @@ class _ManageModelsScreenState extends State _smoothedDownloadRateBytesPerSec.remove(filename); } - void _pauseActiveDownloads(String reason) { - final entries = _cancelTokens.entries.toList(growable: false); - for (final entry in entries) { - final token = entry.value; - if (!token.isCancelled) { - token.cancel(reason); + void _pauseActiveDownloads() { + for (final controller in _downloadControllers.values) { + if (controller.snapshot.isRunning) { + controller.cancel(); } } } + Future _disposeDownloadController( + String filename, { + ModelDownloadController? controller, + StreamSubscription? subscription, + }) async { + final currentSubscription = _downloadSubscriptions[filename]; + if (subscription == null || identical(currentSubscription, subscription)) { + await _downloadSubscriptions.remove(filename)?.cancel(); + } else { + await subscription.cancel(); + } + + final currentController = _downloadControllers[filename]; + if (controller == null || identical(currentController, controller)) { + await _downloadControllers.remove(filename)?.dispose(); + } else { + await controller.dispose(); + } + } + + void _handleDownloadSnapshot( + DownloadableModel model, + ModelDownloadTaskSnapshot snapshot, + ) { + if (!mounted) { + return; + } + _updateDownloadUiState( + model.filename, + isDownloading: snapshot.isRunning, + progress: snapshot.fraction, + task: snapshot, + ); + } + Future _downloadModel(DownloadableModel model) async { if (!kIsWeb && _modelsDir == null) { return; } + if (_downloadControllers[model.filename]?.snapshot.isRunning ?? false) { + return; + } + + await _disposeDownloadController(model.filename); + + ModelDownloadController? controller; + StreamSubscription? subscription; - final cancelToken = CancelToken(); _updateDownloadUiState( model.filename, isDownloading: true, clearDetail: true, + clearTask: true, + clearProgress: true, ); - _lastDownloadedBytes.remove(model.filename); - _lastDownloadSampleAt.remove(model.filename); - _smoothedDownloadRateBytesPerSec.remove(model.filename); - _cancelTokens[model.filename] = cancelToken; + _clearDownloadTracking(model.filename); - await _modelService.downloadModel( - model: model, - modelsDir: _modelsDir ?? '', - cancelToken: cancelToken, - onProgress: (_) {}, - onProgressDetail: (detail) { - if (!mounted) { - return; - } - _updateDownloadRate(model.filename, detail); - _updateDownloadUiState( - model.filename, - progress: detail.overallProgress, - detail: detail, - ); - }, - onSuccess: (filename) { - if (!mounted) return; - _updateDownloadUiState( - model.filename, - isDownloading: false, - clearProgress: true, - clearDetail: true, - ); + try { + final manager = ChatAppModelDownloadManager( + modelService: _modelService, + model: model, + modelsDir: _modelsDir ?? '', + onProgressDetail: (detail) { + if (!mounted) { + return; + } + _updateDownloadRate(model.filename, detail); + _updateDownloadUiState( + model.filename, + progress: detail.overallProgress, + detail: detail, + ); + }, + ); + controller = ModelDownloadController(manager: manager); + _downloadControllers[model.filename] = controller; + subscription = controller.snapshots.listen( + (snapshot) => _handleDownloadSnapshot(model, snapshot), + ); + _downloadSubscriptions[model.filename] = subscription; + + await controller.start(manager.source); + if (!mounted) { + return; + } + _updateDownloadUiState( + model.filename, + isDownloading: false, + clearProgress: true, + clearDetail: true, + clearTask: true, + ); + _clearDownloadTracking(model.filename); + setState(() { + _downloadedFiles.add(model.filename); + }); + ScaffoldMessenger.of(context).showSnackBar( + SnackBar(content: Text('${model.name} downloaded successfully.')), + ); + } catch (error) { + if (!mounted) { + return; + } + final snapshot = controller?.snapshot; + final isCancel = + snapshot?.stage == ModelDownloadTaskStage.cancelled || + (error is DioException && error.type == DioExceptionType.cancel); + _updateDownloadUiState( + model.filename, + isDownloading: false, + clearDetail: !isCancel, + clearTask: !isCancel, + clearProgress: !isCancel, + ); + if (!isCancel) { _clearDownloadTracking(model.filename); - _cancelTokens.remove(model.filename); - setState(() { - _downloadedFiles.add(filename); - }); - ScaffoldMessenger.of(context).showSnackBar( - SnackBar(content: Text('${model.name} downloaded successfully.')), - ); - }, - onError: (error) { - if (!mounted) return; - final isCancel = - error is DioException && error.type == DioExceptionType.cancel; - _updateDownloadUiState( - model.filename, - isDownloading: false, - clearDetail: !isCancel, - clearProgress: !isCancel, - ); - if (!isCancel) { - _clearDownloadTracking(model.filename); - } - _cancelTokens.remove(model.filename); - - ScaffoldMessenger.of(context).showSnackBar( - SnackBar( - content: Text( - isCancel - ? 'Download paused: ${model.name}' - : _downloadFailureMessage(error), - ), + } + + ScaffoldMessenger.of(context).showSnackBar( + SnackBar( + content: Text( + isCancel + ? 'Download paused: ${model.name}' + : error is DioException + ? _downloadFailureMessage(error) + : snapshot?.errorMessage ?? _downloadFailureMessage(error), ), - ); - }, - ); + ), + ); + } finally { + await _disposeDownloadController( + model.filename, + controller: controller, + subscription: subscription, + ); + } } void _cancelDownload(DownloadableModel model) { - final token = _cancelTokens[model.filename]; - if (token != null && !token.isCancelled) { - token.cancel(); - } + _downloadControllers[model.filename]?.cancel(); } Future _selectModel(DownloadableModel model) async { @@ -770,7 +865,7 @@ class _ManageModelsScreenState extends State clearDetail: true, ); _clearDownloadTracking(model.filename); - _cancelTokens.remove(model.filename); + await _disposeDownloadController(model.filename); setState(() { _downloadedFiles.remove(model.filename); @@ -814,11 +909,7 @@ class _ManageModelsScreenState extends State final provider = context.read(); - for (final token in _cancelTokens.values) { - if (!token.isCancelled) { - token.cancel('Bulk remove models'); - } - } + _pauseActiveDownloads(); final snapshot = List.from(_models); for (final model in snapshot) { @@ -827,7 +918,7 @@ class _ManageModelsScreenState extends State _models ..clear() - ..addAll(DownloadableModel.defaultModels); + ..addAll(_initialModelCatalog()); _customModels.clear(); for (final notifier in _downloadUiStateByFile.values) { notifier.dispose(); @@ -836,7 +927,14 @@ class _ManageModelsScreenState extends State _lastDownloadedBytes.clear(); _lastDownloadSampleAt.clear(); _smoothedDownloadRateBytesPerSec.clear(); - _cancelTokens.clear(); + for (final subscription in _downloadSubscriptions.values) { + await subscription.cancel(); + } + _downloadSubscriptions.clear(); + for (final controller in _downloadControllers.values) { + await controller.dispose(); + } + _downloadControllers.clear(); _downloadedFiles = await _modelService.getDownloadedModels(_models); await _saveCustomModels(); @@ -1231,6 +1329,9 @@ class _ManageModelsScreenState extends State valueListenable: downloadStateListenable, builder: (context, downloadState, _) { final detail = downloadState.detail; + final taskLabel = _downloadTaskLabel( + downloadState.task, + ); final card = ModelCard( model: model, @@ -1240,7 +1341,7 @@ class _ManageModelsScreenState extends State isDownloading: downloadState.isDownloading, progress: downloadState.progress, downloadStatusLabel: detail == null - ? null + ? taskLabel : _downloadStageLabel(detail), downloadTransferLabel: detail == null ? null @@ -1833,7 +1934,15 @@ class _ManageModelsScreenState extends State @override void dispose() { WidgetsBinding.instance.removeObserver(this); - _pauseActiveDownloads('Model manager disposed'); + _pauseActiveDownloads(); + for (final subscription in _downloadSubscriptions.values) { + unawaited(subscription.cancel()); + } + _downloadSubscriptions.clear(); + for (final controller in _downloadControllers.values) { + unawaited(controller.dispose()); + } + _downloadControllers.clear(); for (final notifier in _downloadUiStateByFile.values) { notifier.dispose(); } @@ -1846,18 +1955,22 @@ class _ModelDownloadUiState { final bool isDownloading; final double progress; final ModelDownloadProgress? detail; + final ModelDownloadTaskSnapshot? task; const _ModelDownloadUiState({ this.isDownloading = false, this.progress = 0.0, this.detail, + this.task, }); _ModelDownloadUiState copyWith({ bool? isDownloading, double? progress, ModelDownloadProgress? detail, + ModelDownloadTaskSnapshot? task, bool clearDetail = false, + bool clearTask = false, }) { final normalizedProgress = ((progress ?? this.progress).clamp(0.0, 1.0) as num).toDouble(); @@ -1866,6 +1979,7 @@ class _ModelDownloadUiState { isDownloading: isDownloading ?? this.isDownloading, progress: normalizedProgress, detail: clearDetail ? null : (detail ?? this.detail), + task: clearTask ? null : (task ?? this.task), ); } } diff --git a/example/chat_app/lib/services/model_download_controller_adapter.dart b/example/chat_app/lib/services/model_download_controller_adapter.dart new file mode 100644 index 0000000..cee0666 --- /dev/null +++ b/example/chat_app/lib/services/model_download_controller_adapter.dart @@ -0,0 +1,219 @@ +import 'dart:async'; + +import 'package:dio/dio.dart'; +import 'package:llamadart/llamadart.dart' as llama; +import 'package:path/path.dart' as p; + +import '../models/downloadable_model.dart'; +import 'model_service_base.dart'; + +/// Adapts the chat app's platform-specific model service to the package-level +/// [llama.ModelDownloadController] contract. +/// +/// The example app still owns its multi-asset native/web storage details, while +/// the controller owns the app-facing resolving/cache/download/verify/ready and +/// cancel/retry state machine. +class ChatAppModelDownloadManager implements llama.ModelDownloadManager { + ChatAppModelDownloadManager({ + required this.modelService, + required this.model, + required this.modelsDir, + this.onProgressDetail, + }) : source = sourceFor(model); + + final ModelService modelService; + final DownloadableModel model; + final String modelsDir; + final void Function(ModelDownloadProgress progress)? onProgressDetail; + final llama.ModelSource source; + + static llama.ModelSource sourceFor(DownloadableModel model) { + return _sourceForAsset(model.modelSource); + } + + @override + Future ensureModel( + llama.ModelSource source, { + llama.ModelLoadOptions options = llama.ModelLoadOptions.defaults, + llama.ModelDownloadProgressCallback? onProgress, + }) async { + _checkSource(source); + _rejectUnsupportedOptions(options); + + final cached = await get(source.cacheKey); + switch (options.cachePolicy) { + case llama.ModelCachePolicy.preferCached: + if (cached != null) { + return cached; + } + break; + case llama.ModelCachePolicy.cacheOnly: + if (cached != null) { + return cached; + } + throw StateError('Model is not cached: ${source.displayName}.'); + case llama.ModelCachePolicy.refresh: + case llama.ModelCachePolicy.noCache: + break; + } + + Object? failure; + String? completedFilename; + final cancelToken = CancelToken(); + final cancellationPoller = _bridgeCancellation( + options.cancelToken, + cancelToken, + ); + + try { + await modelService.downloadModel( + model: model, + modelsDir: modelsDir, + cancelToken: cancelToken, + onProgress: (progress) { + onProgress?.call(llama.ModelDownloadProgress.fraction(progress)); + }, + onProgressDetail: (detail) { + onProgressDetail?.call(detail); + onProgress?.call( + llama.ModelDownloadProgress.fraction(detail.overallProgress), + ); + }, + onSuccess: (filename) { + completedFilename = filename; + }, + onError: (error) { + failure = error; + }, + ); + } finally { + cancellationPoller?.cancel(); + } + + final error = failure; + if (error != null) { + Error.throwWithStackTrace(error, StackTrace.current); + } + if (completedFilename == null) { + throw StateError('Model download finished without success or failure.'); + } + + return _cacheEntry(source); + } + + @override + Future> list({String? cacheDirectory}) async { + final cached = await get(source.cacheKey, cacheDirectory: cacheDirectory); + return cached == null + ? const [] + : [cached]; + } + + @override + Future get( + String cacheKey, { + String? cacheDirectory, + }) async { + if (cacheKey != source.cacheKey) { + return null; + } + final downloaded = await modelService.getDownloadedModels( + [model], + ); + if (!downloaded.contains(model.filename)) { + return null; + } + return _cacheEntry(source); + } + + @override + Future remove(String cacheKey, {String? cacheDirectory}) async { + throw UnsupportedError( + 'ChatAppModelDownloadManager delegates deletion to ModelService.deleteModel.', + ); + } + + @override + Future clear({String? cacheDirectory}) async { + throw UnsupportedError( + 'ChatAppModelDownloadManager delegates deletion to the chat app UI.', + ); + } + + @override + Future> prune({ + Duration? maxAge, + int? maxBytes, + String? cacheDirectory, + }) async { + throw UnsupportedError( + 'ChatAppModelDownloadManager does not manage package cache pruning.', + ); + } + + void _checkSource(llama.ModelSource requestedSource) { + if (requestedSource.cacheKey != source.cacheKey) { + throw ArgumentError.value( + requestedSource, + 'source', + 'ChatAppModelDownloadManager is bound to ${source.displayName}.', + ); + } + } + + void _rejectUnsupportedOptions(llama.ModelLoadOptions options) { + if (options.sha256 != null) { + throw UnsupportedError( + 'ChatAppModelDownloadManager cannot verify SHA-256 checksums.', + ); + } + } + + Timer? _bridgeCancellation( + llama.ModelDownloadCancelToken? controllerToken, + CancelToken cancelToken, + ) { + if (controllerToken == null) { + return null; + } + void cancelIfNeeded() { + if (controllerToken.isCancelled && !cancelToken.isCancelled) { + cancelToken.cancel('Download cancelled.'); + } + } + + cancelIfNeeded(); + if (cancelToken.isCancelled) { + return null; + } + return Timer.periodic(const Duration(milliseconds: 100), (_) { + cancelIfNeeded(); + }); + } + + llama.ModelCacheEntry _cacheEntry(llama.ModelSource source) { + final now = DateTime.now().toUtc(); + return llama.ModelCacheEntry( + sourceCanonicalKey: source.canonicalKey, + cacheKey: source.cacheKey, + fileName: source.fileName, + filePath: source.isLocal + ? source.path! + : p.join(modelsDir, source.fileName), + bytes: model.sizeBytes > 0 ? model.sizeBytes : null, + createdAt: now, + updatedAt: now, + ); + } +} + +llama.ModelSource _sourceForAsset(ModelAssetSource source) { + if (source is LocalModelAssetSource) { + return llama.ModelSource.path(source.path); + } + final remote = source as RemoteModelAssetSource; + return llama.ModelSource.url( + Uri.parse(remote.url), + fileName: remote.filename, + ); +} diff --git a/example/chat_app/test/manage_models_screen_download_test.dart b/example/chat_app/test/manage_models_screen_download_test.dart new file mode 100644 index 0000000..543ca74 --- /dev/null +++ b/example/chat_app/test/manage_models_screen_download_test.dart @@ -0,0 +1,202 @@ +import 'dart:async'; + +import 'package:dio/dio.dart'; +import 'package:flutter/material.dart'; +import 'package:flutter_test/flutter_test.dart'; +import 'package:llamadart_chat_example/models/downloadable_model.dart'; +import 'package:llamadart_chat_example/providers/chat_provider.dart'; +import 'package:llamadart_chat_example/screens/manage_models_screen.dart'; +import 'package:llamadart_chat_example/services/model_service_base.dart'; +import 'package:provider/provider.dart'; +import 'package:shared_preferences/shared_preferences.dart'; + +import 'mocks.dart'; + +void main() { + TestWidgetsFlutterBinding.ensureInitialized(); + + group('ManageModelsScreen model download controller wiring', () { + testWidgets('pause button cancels the active controller download', ( + tester, + ) async { + SharedPreferences.setMockInitialValues({}); + final model = _remoteModel(); + final modelService = _HoldingModelService(); + + await _pumpScreen(tester, modelService: modelService, models: [model]); + + expect(find.text(model.name), findsOneWidget); + expect(find.text('Download'), findsOneWidget); + + await tester.tap(find.text('Download')); + await modelService.downloadStarted.future.timeout(_testTimeout); + await tester.pump(); + + expect(modelService.downloadCalls, 1); + expect(find.text('Downloading model'), findsOneWidget); + expect(find.text('25%'), findsOneWidget); + + await tester.tap(find.byTooltip('Pause Download')); + await tester.pump(const Duration(milliseconds: 150)); + await modelService.downloadCancelled.future.timeout(_testTimeout); + await tester.pump(); + + expect(modelService.lastCancelToken?.isCancelled, isTrue); + expect(find.text('Paused'), findsOneWidget); + expect(find.text('25%'), findsOneWidget); + expect(find.text('Resume Download'), findsOneWidget); + }); + + testWidgets('cancel and discard reports a paused cancellation', ( + tester, + ) async { + SharedPreferences.setMockInitialValues({}); + final model = _remoteModel(); + final modelService = _HoldingModelService(); + + await _pumpScreen(tester, modelService: modelService, models: [model]); + + await tester.tap(find.text('Download')); + await modelService.downloadStarted.future.timeout(_testTimeout); + await tester.pump(); + + await tester.tap(find.byTooltip('Cancel & Discard')); + await tester.pump(const Duration(milliseconds: 150)); + await modelService.downloadCancelled.future.timeout(_testTimeout); + await tester.pump(); + + expect(modelService.lastCancelToken?.isCancelled, isTrue); + expect(find.text('Download paused: ${model.name}'), findsOneWidget); + expect(find.text('Download failed. Please retry.'), findsNothing); + }); + + testWidgets('disposing the screen cancels active controller downloads', ( + tester, + ) async { + SharedPreferences.setMockInitialValues({}); + final modelService = _HoldingModelService(); + + await _pumpScreen( + tester, + modelService: modelService, + models: [_remoteModel()], + ); + + await tester.tap(find.text('Download')); + await modelService.downloadStarted.future.timeout(_testTimeout); + await tester.pump(); + + await tester.pumpWidget(const SizedBox.shrink()); + await tester.pump(const Duration(milliseconds: 150)); + await modelService.downloadCancelled.future.timeout(_testTimeout); + + expect(modelService.lastCancelToken?.isCancelled, isTrue); + }); + }); +} + +const _testTimeout = Duration(seconds: 2); + +Future _pumpScreen( + WidgetTester tester, { + required _HoldingModelService modelService, + required List models, +}) async { + final provider = ChatProvider( + chatService: MockChatService(), + settingsService: MockSettingsService(), + ); + addTearDown(provider.dispose); + + await tester.pumpWidget( + ChangeNotifierProvider.value( + value: provider, + child: MaterialApp( + home: Scaffold( + body: ManageModelsScreen( + embeddedPanel: true, + modelService: modelService, + initialModels: models, + showModelLibraryInitially: true, + ), + ), + ), + ), + ); + await tester.pumpAndSettle(); +} + +DownloadableModel _remoteModel() { + return const DownloadableModel( + name: 'Tiny Test Model', + description: 'Small fake model for screen tests.', + url: 'https://example.com/tiny.gguf', + filename: 'tiny.gguf', + sizeBytes: 10, + ); +} + +class _HoldingModelService implements ModelService { + final Completer downloadStarted = Completer(); + final Completer downloadCancelled = Completer(); + + int downloadCalls = 0; + CancelToken? lastCancelToken; + + @override + Future getModelsDirectory() async => '/models'; + + @override + Future> getDownloadedModels( + List models, + ) async { + return {}; + } + + @override + Future downloadModel({ + required DownloadableModel model, + required String modelsDir, + required CancelToken cancelToken, + required Function(double progress) onProgress, + Function(ModelDownloadProgress progress)? onProgressDetail, + required Function(String filename) onSuccess, + required Function(dynamic error) onError, + }) async { + downloadCalls += 1; + lastCancelToken = cancelToken; + onProgress(0.25); + onProgressDetail?.call( + ModelDownloadProgress( + overallProgress: 0.25, + downloadedBytes: 25, + totalBytes: 100, + stage: ModelDownloadStage.model, + stageIndex: 1, + stageCount: 1, + stageDownloadedBytes: 25, + stageTotalBytes: 100, + ), + ); + if (!downloadStarted.isCompleted) { + downloadStarted.complete(); + } + + while (!cancelToken.isCancelled) { + await Future.delayed(const Duration(milliseconds: 10)); + } + if (!downloadCancelled.isCompleted) { + downloadCancelled.complete(); + } + onError( + DioException( + requestOptions: RequestOptions(path: model.url), + type: DioExceptionType.cancel, + message: 'Download cancelled.', + ), + ); + } + + @override + Future deleteModel(String modelsDir, DownloadableModel model) async {} +} diff --git a/example/chat_app/test/model_download_controller_adapter_test.dart b/example/chat_app/test/model_download_controller_adapter_test.dart new file mode 100644 index 0000000..c7cb1d6 --- /dev/null +++ b/example/chat_app/test/model_download_controller_adapter_test.dart @@ -0,0 +1,215 @@ +import 'dart:async'; + +import 'package:dio/dio.dart'; +import 'package:flutter_test/flutter_test.dart'; +import 'package:llamadart/llamadart.dart' as llama; +import 'package:llamadart_chat_example/models/downloadable_model.dart'; +import 'package:llamadart_chat_example/services/model_download_controller_adapter.dart'; +import 'package:llamadart_chat_example/services/model_service_base.dart'; + +void main() { + group('ChatAppModelDownloadManager', () { + test( + 'returns cached entries without calling the chat app downloader', + () async { + final model = _remoteModel(); + final service = _FakeModelService(downloadedFiles: {model.filename}); + final manager = ChatAppModelDownloadManager( + modelService: service, + model: model, + modelsDir: '/models', + ); + + final entry = await manager.ensureModel(manager.source); + + expect(service.downloadCalls, 0); + expect(service.getDownloadedCalls, 1); + expect(entry.cacheKey, manager.source.cacheKey); + expect(entry.fileName, model.filename); + expect(entry.filePath, '/models/${model.filename}'); + }, + ); + + test( + 'rejects checksum options instead of reporting unverified ready', + () async { + final model = _remoteModel(); + final service = _FakeModelService(downloadedFiles: {model.filename}); + final manager = ChatAppModelDownloadManager( + modelService: service, + model: model, + modelsDir: '/models', + ); + + await expectLater( + manager.ensureModel( + manager.source, + options: llama.ModelLoadOptions(sha256: 'a' * 64), + ), + throwsA(isA()), + ); + + expect(service.downloadCalls, 0); + }, + ); + + test( + 'refresh downloads through the chat app service and forwards progress', + () async { + final model = _remoteModel(); + final detail = ModelDownloadProgress( + overallProgress: 0.5, + downloadedBytes: 5, + totalBytes: 10, + stage: ModelDownloadStage.model, + stageIndex: 1, + stageCount: 1, + stageDownloadedBytes: 5, + stageTotalBytes: 10, + ); + final service = _FakeModelService( + downloadedFiles: {model.filename}, + progressDetails: [detail], + ); + final appDetails = []; + final controllerProgress = []; + final manager = ChatAppModelDownloadManager( + modelService: service, + model: model, + modelsDir: '/models', + onProgressDetail: appDetails.add, + ); + + final entry = await manager.ensureModel( + manager.source, + options: llama.ModelLoadOptions( + cachePolicy: llama.ModelCachePolicy.refresh, + ), + onProgress: controllerProgress.add, + ); + + expect(service.downloadCalls, 1); + expect(service.lastModel, same(model)); + expect(service.lastModelsDir, '/models'); + expect(service.lastCancelToken, isNotNull); + expect(appDetails, [same(detail)]); + expect(controllerProgress.single.fraction, 0.5); + expect(entry.filePath, '/models/${model.filename}'); + }, + ); + + test( + 'bridges controller cancellation into the chat app Dio cancel token', + () async { + final model = _remoteModel(); + final service = _FakeModelService(waitForCancellation: true); + final manager = ChatAppModelDownloadManager( + modelService: service, + model: model, + modelsDir: '/models', + ); + final controller = llama.ModelDownloadController(manager: manager); + addTearDown(controller.dispose); + + final task = controller.start(manager.source); + await service.downloadStarted.future; + + controller.cancel(); + + await expectLater(task, throwsA(isA())); + expect(service.lastCancelToken?.isCancelled, isTrue); + expect( + controller.snapshot.stage, + llama.ModelDownloadTaskStage.cancelled, + ); + }, + ); + }); +} + +DownloadableModel _remoteModel() { + return const DownloadableModel( + name: 'Tiny Test Model', + description: 'Small fake model for adapter tests.', + url: 'https://example.com/tiny.gguf?token=secret', + filename: 'tiny.gguf', + sizeBytes: 10, + ); +} + +class _FakeModelService implements ModelService { + _FakeModelService({ + Set? downloadedFiles, + this.progressDetails = const [], + this.waitForCancellation = false, + }) : downloadedFiles = downloadedFiles ?? {}; + + final Set downloadedFiles; + final List progressDetails; + final bool waitForCancellation; + final Completer downloadStarted = Completer(); + + int getDownloadedCalls = 0; + int downloadCalls = 0; + DownloadableModel? lastModel; + String? lastModelsDir; + CancelToken? lastCancelToken; + + @override + Future getModelsDirectory() async => '/models'; + + @override + Future> getDownloadedModels( + List models, + ) async { + getDownloadedCalls += 1; + return models + .where((model) => downloadedFiles.contains(model.filename)) + .map((model) => model.filename) + .toSet(); + } + + @override + Future downloadModel({ + required DownloadableModel model, + required String modelsDir, + required CancelToken cancelToken, + required Function(double progress) onProgress, + Function(ModelDownloadProgress progress)? onProgressDetail, + required Function(String filename) onSuccess, + required Function(dynamic error) onError, + }) async { + downloadCalls += 1; + lastModel = model; + lastModelsDir = modelsDir; + lastCancelToken = cancelToken; + if (!downloadStarted.isCompleted) { + downloadStarted.complete(); + } + + if (waitForCancellation) { + while (!cancelToken.isCancelled) { + await Future.delayed(const Duration(milliseconds: 10)); + } + onError( + DioException( + requestOptions: RequestOptions(path: model.url), + type: DioExceptionType.cancel, + message: 'Download cancelled.', + ), + ); + return; + } + + for (final detail in progressDetails) { + onProgressDetail?.call(detail); + } + downloadedFiles.add(model.filename); + onSuccess(model.filename); + } + + @override + Future deleteModel(String modelsDir, DownloadableModel model) async { + downloadedFiles.remove(model.filename); + } +} diff --git a/lib/src/core/models/download/model_download_controller.dart b/lib/src/core/models/download/model_download_controller.dart new file mode 100644 index 0000000..e857b88 --- /dev/null +++ b/lib/src/core/models/download/model_download_controller.dart @@ -0,0 +1,380 @@ +import 'dart:async'; + +import '../../exceptions.dart'; +import '../model_load_options.dart'; +import '../model_source.dart'; +import 'model_download_manager_base.dart'; +import 'model_download_manager_stub.dart' + if (dart.library.io) '../../../platform/io/model_download_manager_io.dart'; + +/// High-level lifecycle stage for an app-facing model download task. +enum ModelDownloadTaskStage { + /// No task has started yet. + idle, + + /// The source and task options are being prepared. + resolving, + + /// The package-managed cache is being checked before network work starts. + checkingCache, + + /// Remote bytes are being downloaded or cached by the manager. + downloading, + + /// The manager is finalizing, verifying, or promoting the resolved file. + verifying, + + /// The model is available as a [ModelCacheEntry]. + ready, + + /// The task failed with an actionable, redacted [ModelDownloadTaskSnapshot.errorMessage]. + failed, + + /// The task was cancelled cooperatively. + cancelled, +} + +/// Immutable app-facing state for a [ModelDownloadController]. +class ModelDownloadTaskSnapshot { + /// Creates a model download task snapshot. + const ModelDownloadTaskSnapshot({ + required this.stage, + this.source, + this.entry, + this.progress, + this.errorMessage, + }); + + /// Initial idle snapshot. + const ModelDownloadTaskSnapshot.idle() + : stage = ModelDownloadTaskStage.idle, + source = null, + entry = null, + progress = null, + errorMessage = null; + + /// Current lifecycle stage. + final ModelDownloadTaskStage stage; + + /// Source being resolved or downloaded, when a task has started. + final ModelSource? source; + + /// Resolved cache entry after [stage] becomes [ModelDownloadTaskStage.ready]. + final ModelCacheEntry? entry; + + /// Latest byte-level progress reported by the underlying manager. + final ModelDownloadProgress? progress; + + /// Redacted user-facing failure/cancellation message, when available. + final String? errorMessage; + + /// Whether the task is actively doing asynchronous work. + bool get isRunning { + return switch (stage) { + ModelDownloadTaskStage.resolving || + ModelDownloadTaskStage.checkingCache || + ModelDownloadTaskStage.downloading || + ModelDownloadTaskStage.verifying => true, + _ => false, + }; + } + + /// Whether [ModelDownloadController.retry] can retry this snapshot's source. + bool get canRetry { + return source != null && + (stage == ModelDownloadTaskStage.failed || + stage == ModelDownloadTaskStage.cancelled); + } + + /// Best-known completion fraction, or null when unknown. + double? get fraction { + if (stage == ModelDownloadTaskStage.ready) { + return 1.0; + } + return progress?.fraction; + } +} + +/// Small, dependency-free controller for app model download/cache UX. +/// +/// The controller wraps a [ModelDownloadManager] and converts low-level cache +/// and byte progress callbacks into stable app states: resolving, cache check, +/// downloading, verifying, ready, failed, and cancelled. It intentionally uses +/// `dart:async` streams rather than Flutter types so it can be adapted to +/// `ValueNotifier`, `ChangeNotifier`, BLoC, Riverpod, or any other UI layer. +class ModelDownloadController { + /// Creates a model download controller. + /// + /// When [manager] is omitted the platform default manager is used. On + /// platforms without package-managed download support, starting a task emits a + /// failed snapshot with the manager's unsupported-operation message. + ModelDownloadController({ModelDownloadManager? manager}) + : manager = manager ?? DefaultModelDownloadManager(); + + /// Low-level manager used to inspect caches and resolve/download models. + final ModelDownloadManager manager; + + final StreamController _snapshots = + StreamController.broadcast(sync: true); + + ModelDownloadTaskSnapshot _snapshot = const ModelDownloadTaskSnapshot.idle(); + ModelDownloadCancelToken? _cancelToken; + ModelSource? _lastSource; + ModelLoadOptions _lastOptions = ModelLoadOptions.defaults; + bool _isDisposed = false; + int _generation = 0; + + /// Latest snapshot, synchronously updated before each stream event is emitted. + ModelDownloadTaskSnapshot get snapshot => _snapshot; + + /// Broadcast stream of task snapshots. + Stream get snapshots => _snapshots.stream; + + /// Starts resolving [source] with [options]. + /// + /// Throws [StateError] when another task is already running. The returned + /// future completes with the ready [ModelCacheEntry] or rethrows the manager's + /// failure after emitting a failed/cancelled snapshot. + Future start( + ModelSource source, { + ModelLoadOptions options = ModelLoadOptions.defaults, + }) { + _throwIfDisposed(); + if (options.cancelToken != null) { + throw ArgumentError.value( + options.cancelToken, + 'options.cancelToken', + 'ModelDownloadController owns cancellation; call cancel() on the controller instead.', + ); + } + if (_snapshot.isRunning) { + throw StateError('A model download task is already running.'); + } + _lastSource = source; + _lastOptions = options; + final generation = _generation + 1; + _generation = generation; + final cancelToken = ModelDownloadCancelToken(); + _cancelToken = cancelToken; + final effectiveOptions = _withCancelToken(options, cancelToken); + return _run(generation, source, options, effectiveOptions, cancelToken); + } + + /// Retries the last source with the last options passed to [start]. + Future retry() { + final source = _lastSource; + if (source == null) { + throw StateError('No model download task is available to retry.'); + } + return start(source, options: _lastOptions); + } + + /// Requests cooperative cancellation for the active task. + void cancel() { + _cancelToken?.cancel(); + } + + /// Cancels any active task and closes the snapshot stream. + Future dispose() async { + if (_isDisposed) { + return; + } + _isDisposed = true; + cancel(); + await _snapshots.close(); + } + + Future _run( + int generation, + ModelSource source, + ModelLoadOptions originalOptions, + ModelLoadOptions effectiveOptions, + ModelDownloadCancelToken cancelToken, + ) async { + ModelDownloadProgress? latestProgress; + try { + _emit( + generation, + ModelDownloadTaskSnapshot( + stage: ModelDownloadTaskStage.resolving, + source: source, + ), + ); + _throwIfCancelled(cancelToken); + + final shouldCheckCache = + source.isRemote && + originalOptions.cachePolicy != ModelCachePolicy.refresh && + originalOptions.cachePolicy != ModelCachePolicy.noCache; + var cacheHit = false; + if (shouldCheckCache) { + _emit( + generation, + ModelDownloadTaskSnapshot( + stage: ModelDownloadTaskStage.checkingCache, + source: source, + ), + ); + final cached = await manager.get( + source.cacheKey, + cacheDirectory: originalOptions.cacheDirectory, + ); + _throwIfCancelled(cancelToken); + if (cached != null) { + cacheHit = true; + } + } + + final shouldReportDownload = + source.isRemote && + !cacheHit && + originalOptions.cachePolicy != ModelCachePolicy.cacheOnly; + if (shouldReportDownload) { + _emit( + generation, + ModelDownloadTaskSnapshot( + stage: ModelDownloadTaskStage.downloading, + source: source, + ), + ); + } else { + _emit( + generation, + ModelDownloadTaskSnapshot( + stage: ModelDownloadTaskStage.verifying, + source: source, + ), + ); + } + + final entry = await manager.ensureModel( + source, + options: effectiveOptions, + onProgress: (progress) { + latestProgress = progress; + _emit( + generation, + ModelDownloadTaskSnapshot( + stage: ModelDownloadTaskStage.downloading, + source: source, + progress: progress, + ), + ); + }, + ); + _throwIfCancelled(cancelToken); + + if (_snapshot.stage != ModelDownloadTaskStage.verifying) { + _emit( + generation, + ModelDownloadTaskSnapshot( + stage: ModelDownloadTaskStage.verifying, + source: source, + progress: latestProgress, + ), + ); + } + _emit( + generation, + ModelDownloadTaskSnapshot( + stage: ModelDownloadTaskStage.ready, + source: source, + entry: entry, + progress: latestProgress, + ), + ); + return entry; + } catch (error) { + if (_isCancelledError(cancelToken)) { + _emit( + generation, + ModelDownloadTaskSnapshot( + stage: ModelDownloadTaskStage.cancelled, + source: source, + progress: latestProgress, + errorMessage: 'Download cancelled for ${source.displayName}.', + ), + ); + } else { + _emit( + generation, + ModelDownloadTaskSnapshot( + stage: ModelDownloadTaskStage.failed, + source: source, + progress: latestProgress, + errorMessage: _redactedErrorMessage(error), + ), + ); + } + rethrow; + } finally { + if (generation == _generation) { + _cancelToken = null; + } + } + } + + void _emit(int generation, ModelDownloadTaskSnapshot snapshot) { + if (_isDisposed || generation != _generation) { + return; + } + _snapshot = snapshot; + _snapshots.add(snapshot); + } + + void _throwIfDisposed() { + if (_isDisposed) { + throw StateError('ModelDownloadController has been disposed.'); + } + } +} + +ModelLoadOptions _withCancelToken( + ModelLoadOptions options, + ModelDownloadCancelToken cancelToken, +) { + return ModelLoadOptions( + cachePolicy: options.cachePolicy, + cacheDirectory: options.cacheDirectory, + sha256: options.sha256, + bearerToken: options.bearerToken, + headers: options.headers, + cancelToken: cancelToken, + resume: options.resume, + maxRetries: options.maxRetries, + ); +} + +void _throwIfCancelled(ModelDownloadCancelToken cancelToken) { + if (cancelToken.isCancelled) { + throw LlamaStateException('Model download was cancelled.'); + } +} + +bool _isCancelledError(ModelDownloadCancelToken cancelToken) { + return cancelToken.isCancelled; +} + +String _redactedErrorMessage(Object error) { + return error.toString().replaceAllMapped(_urlPattern, (match) { + final value = match.group(0)!; + final trailing = _trailingPunctuation.firstMatch(value)?.group(0) ?? ''; + final candidate = trailing.isEmpty + ? value + : value.substring(0, value.length - trailing.length); + final uri = Uri.tryParse(candidate); + if (uri == null || (uri.scheme != 'http' && uri.scheme != 'https')) { + return '$trailing'; + } + final redacted = Uri( + scheme: uri.scheme, + host: uri.host, + port: uri.hasPort ? uri.port : null, + path: uri.path, + ); + return '${redacted.toString()}$trailing'; + }); +} + +final RegExp _urlPattern = RegExp(r'https?:\/\/\S+'); +final RegExp _trailingPunctuation = RegExp(r'[.?!:\)\]\}>]+$'); diff --git a/lib/src/core/models/download/model_download_manager.dart b/lib/src/core/models/download/model_download_manager.dart index 6d7981f..acbbed1 100644 --- a/lib/src/core/models/download/model_download_manager.dart +++ b/lib/src/core/models/download/model_download_manager.dart @@ -1,3 +1,4 @@ +export 'model_download_controller.dart'; export 'model_download_manager_base.dart'; export 'model_download_manager_stub.dart' if (dart.library.io) '../../../platform/io/model_download_manager_io.dart'; diff --git a/test/unit/core/models/download/model_download_controller_test.dart b/test/unit/core/models/download/model_download_controller_test.dart new file mode 100644 index 0000000..c27ec9c --- /dev/null +++ b/test/unit/core/models/download/model_download_controller_test.dart @@ -0,0 +1,356 @@ +import 'dart:async'; + +import 'package:llamadart/llamadart.dart'; +import 'package:test/test.dart'; + +void main() { + group('ModelDownloadController', () { + test( + 'returns cached remote entries without reporting a download', + () async { + final source = ModelSource.url( + Uri.parse('https://example.com/model.gguf?token=secret'), + ); + final cached = _entryFor(source, '/cache/model.gguf'); + final manager = _FakeDownloadManager( + entry: cached, + cachedEntry: cached, + ); + final controller = ModelDownloadController(manager: manager); + addTearDown(controller.dispose); + + final stages = []; + final sub = controller.snapshots.listen( + (snapshot) => stages.add(snapshot.stage), + ); + addTearDown(sub.cancel); + + final entry = await controller.start(source); + + expect(entry, same(cached)); + expect(manager.ensureCalls, 1); + expect(stages, [ + ModelDownloadTaskStage.resolving, + ModelDownloadTaskStage.checkingCache, + ModelDownloadTaskStage.verifying, + ModelDownloadTaskStage.ready, + ]); + expect(stages, isNot(contains(ModelDownloadTaskStage.downloading))); + expect(controller.snapshot.entry, same(cached)); + expect(controller.snapshot.isRunning, isFalse); + }, + ); + + test('validates cached entries through ensureModel before ready', () async { + final source = ModelSource.url( + Uri.parse('https://example.com/model.gguf'), + ); + final cached = _entryFor(source, '/cache/model.gguf'); + final manager = _FakeDownloadManager(cachedEntry: cached) + ..error = LlamaModelException('Checksum mismatch for cached model.'); + final controller = ModelDownloadController(manager: manager); + addTearDown(controller.dispose); + + await expectLater( + controller.start( + source, + options: ModelLoadOptions( + sha256: + 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', + ), + ), + throwsA(isA()), + ); + + expect(manager.ensureCalls, 1); + expect(controller.snapshot.stage, ModelDownloadTaskStage.failed); + expect(controller.snapshot.entry, isNull); + }); + + test( + 'emits download, verification, and ready states for cache misses', + () async { + final source = ModelSource.url( + Uri.parse('https://example.com/model.gguf'), + ); + final entry = _entryFor(source, '/cache/model.gguf'); + final manager = _FakeDownloadManager(entry: entry) + ..progressEvents = const [ + ModelDownloadProgress(receivedBytes: 2, totalBytes: 10), + ModelDownloadProgress(receivedBytes: 10, totalBytes: 10), + ]; + final controller = ModelDownloadController(manager: manager); + addTearDown(controller.dispose); + + final snapshots = []; + final sub = controller.snapshots.listen(snapshots.add); + addTearDown(sub.cancel); + + final result = await controller.start(source); + + expect(result, same(entry)); + expect(manager.ensureCalls, 1); + expect(manager.lastOptions?.cancelToken, isNotNull); + expect( + snapshots.map((snapshot) => snapshot.stage), + containsAllInOrder([ + ModelDownloadTaskStage.resolving, + ModelDownloadTaskStage.checkingCache, + ModelDownloadTaskStage.downloading, + ModelDownloadTaskStage.verifying, + ModelDownloadTaskStage.ready, + ]), + ); + expect( + snapshots + .where( + (snapshot) => + snapshot.stage == ModelDownloadTaskStage.downloading, + ) + .last + .progress + ?.fraction, + 1.0, + ); + expect(controller.snapshot.entry, same(entry)); + }, + ); + + test( + 'cancel requests cooperative cancellation and emits cancelled', + () async { + final source = ModelSource.url( + Uri.parse('https://example.com/model.gguf'), + ); + final manager = _FakeDownloadManager( + entry: _entryFor(source, '/cache/model.gguf'), + ); + final controller = ModelDownloadController(manager: manager); + addTearDown(controller.dispose); + final gate = Completer(); + manager.ensureGate = gate; + + final stages = []; + final sub = controller.snapshots.listen( + (snapshot) => stages.add(snapshot.stage), + ); + addTearDown(sub.cancel); + + final task = controller.start(source); + await Future.delayed(Duration.zero); + + controller.cancel(); + gate.complete(); + + await expectLater(task, throwsA(isA())); + expect(manager.lastOptions?.cancelToken?.isCancelled, isTrue); + expect(controller.snapshot.stage, ModelDownloadTaskStage.cancelled); + expect(controller.snapshot.canRetry, isTrue); + expect(stages, contains(ModelDownloadTaskStage.cancelled)); + }, + ); + + test( + 'manager cancellation-like errors fail unless controller cancelled', + () async { + final source = ModelSource.url( + Uri.parse('https://example.com/model.gguf'), + ); + final manager = _FakeDownloadManager() + ..error = LlamaStateException('Download was cancelled by server.'); + final controller = ModelDownloadController(manager: manager); + addTearDown(controller.dispose); + + await expectLater( + controller.start(source), + throwsA(isA()), + ); + + expect(controller.snapshot.stage, ModelDownloadTaskStage.failed); + expect(controller.snapshot.canRetry, isTrue); + expect( + controller.snapshot.errorMessage, + contains('Download was cancelled by server.'), + ); + }, + ); + + test('redacts secrets from failure messages', () async { + final source = ModelSource.url( + Uri.parse('https://example.com/model.gguf?token=secret#fragment'), + ); + final manager = _FakeDownloadManager() + ..error = LlamaModelException( + 'Failed to fetch https://example.com/model.gguf?token=secret#fragment', + ); + final controller = ModelDownloadController(manager: manager); + addTearDown(controller.dispose); + + await expectLater( + controller.start(source), + throwsA(isA()), + ); + + expect(controller.snapshot.stage, ModelDownloadTaskStage.failed); + expect( + controller.snapshot.errorMessage, + contains('https://example.com/model.gguf'), + ); + expect(controller.snapshot.errorMessage, isNot(contains('secret'))); + expect(controller.snapshot.errorMessage, isNot(contains('token='))); + }); + + test( + 'redacts semicolon and comma query tails from failure messages', + () async { + final source = ModelSource.url( + Uri.parse( + 'https://example.com/model.gguf?token=secret;sig=abc,scope=all', + ), + ); + final manager = _FakeDownloadManager() + ..error = LlamaModelException( + 'Failed https://example.com/model.gguf?token=secret;sig=abc,scope=all.', + ); + final controller = ModelDownloadController(manager: manager); + addTearDown(controller.dispose); + + await expectLater( + controller.start(source), + throwsA(isA()), + ); + + expect( + controller.snapshot.errorMessage, + contains('https://example.com/model.gguf'), + ); + expect(controller.snapshot.errorMessage, isNot(contains('secret'))); + expect(controller.snapshot.errorMessage, isNot(contains('sig='))); + expect(controller.snapshot.errorMessage, isNot(contains('scope='))); + }, + ); + + test('rejects caller-supplied cancellation tokens', () { + final source = ModelSource.url( + Uri.parse('https://example.com/model.gguf'), + ); + final controller = ModelDownloadController( + manager: _FakeDownloadManager(), + ); + addTearDown(controller.dispose); + + expect( + () => controller.start( + source, + options: ModelLoadOptions(cancelToken: ModelDownloadCancelToken()), + ), + throwsA(isA()), + ); + }); + + test('retry reuses the last source and options after failure', () async { + final source = ModelSource.url( + Uri.parse('https://example.com/model.gguf'), + ); + final entry = _entryFor(source, '/cache/model.gguf'); + final manager = _FakeDownloadManager() + ..error = LlamaModelException('temporary failure'); + final controller = ModelDownloadController(manager: manager); + addTearDown(controller.dispose); + + await expectLater( + controller.start( + source, + options: ModelLoadOptions(cachePolicy: ModelCachePolicy.refresh), + ), + throwsA(isA()), + ); + expect(controller.snapshot.stage, ModelDownloadTaskStage.failed); + + manager + ..error = null + ..entry = entry; + final retried = await controller.retry(); + + expect(retried, same(entry)); + expect(manager.ensureCalls, 2); + expect(manager.lastOptions?.cachePolicy, ModelCachePolicy.refresh); + expect(controller.snapshot.stage, ModelDownloadTaskStage.ready); + }); + }); +} + +ModelCacheEntry _entryFor(ModelSource source, String filePath) { + final now = DateTime.utc(2026); + return ModelCacheEntry( + sourceCanonicalKey: source.metadataSourceKey, + cacheKey: source.cacheKey, + fileName: source.fileName, + filePath: filePath, + bytes: 10, + createdAt: now, + updatedAt: now, + ); +} + +class _FakeDownloadManager implements ModelDownloadManager { + _FakeDownloadManager({this.entry, this.cachedEntry}); + + ModelCacheEntry? entry; + ModelCacheEntry? cachedEntry; + Object? error; + Completer? ensureGate; + List progressEvents = const []; + int ensureCalls = 0; + ModelLoadOptions? lastOptions; + + @override + Future ensureModel( + ModelSource source, { + ModelLoadOptions options = ModelLoadOptions.defaults, + ModelDownloadProgressCallback? onProgress, + }) async { + ensureCalls += 1; + lastOptions = options; + await ensureGate?.future; + if (options.cancelToken?.isCancelled ?? false) { + throw LlamaStateException('Model download was cancelled.'); + } + final failure = error; + if (failure != null) { + throw failure; + } + for (final progress in progressEvents) { + onProgress?.call(progress); + } + return entry ?? _entryFor(source, '/cache/${source.fileName}'); + } + + @override + Future> list({String? cacheDirectory}) async => + cachedEntry == null + ? const [] + : [cachedEntry!]; + + @override + Future get( + String cacheKey, { + String? cacheDirectory, + }) async { + final entry = cachedEntry; + return entry != null && entry.cacheKey == cacheKey ? entry : null; + } + + @override + Future remove(String cacheKey, {String? cacheDirectory}) async {} + + @override + Future clear({String? cacheDirectory}) async {} + + @override + Future> prune({ + Duration? maxAge, + int? maxBytes, + String? cacheDirectory, + }) async => const []; +} diff --git a/website/docs/changelog/recent-releases.md b/website/docs/changelog/recent-releases.md index f06e7ce..9d30ab5 100644 --- a/website/docs/changelog/recent-releases.md +++ b/website/docs/changelog/recent-releases.md @@ -12,6 +12,13 @@ For canonical full release notes, use: - Added WebGPU readiness guidance covering browser capability checks, cross-origin isolation, bridge asset/version diagnostics, fallback behavior, model/configuration pressure, and the Flutter Web real-model smoke path. +- Added `ModelDownloadController`, a dependency-free helper that turns + `ModelDownloadManager` cache/download work into app-facing lifecycle states + for resolving, cache checks, downloads, verification, ready, failed, + cancelled, and retry flows. +- Wired the runnable chat app example through a `ModelDownloadManager` adapter + so its model-management UI demonstrates the controller while preserving the + example's multi-asset and web-cache service behavior. ## 0.6.13 diff --git a/website/docs/examples/chat-app.md b/website/docs/examples/chat-app.md index fc24191..497c279 100644 --- a/website/docs/examples/chat-app.md +++ b/website/docs/examples/chat-app.md @@ -31,6 +31,11 @@ flutter test - Real-time streaming chat UI. - Model selection and download flow. +- The runnable chat app wires `ModelDownloadController` into its model-management + flow through a small adapter, so cache checks, progress, cancel, retry, and + clear ready/failure states come from the same package helper app code can + reuse. The adapter keeps the example's platform-specific service layer for + multi-asset model + `mmproj` downloads and browser cache behavior. - Runtime backend preference and GPU layer controls. - Persistent settings and split Dart/native logging controls. - Tool-calling toggles and model capability badges. diff --git a/website/docs/guides/model-lifecycle.md b/website/docs/guides/model-lifecycle.md index 17aa9c1..cf23c37 100644 --- a/website/docs/guides/model-lifecycle.md +++ b/website/docs/guides/model-lifecycle.md @@ -186,6 +186,75 @@ await manager.prune(maxAge: const Duration(days: 30), maxBytes: 20 * 1024 * 1024 await manager.clear(); ``` +### App-friendly download controller + +Flutter apps often need more than byte callbacks: they need stable UI states, +retry/cancel controls, cache-hit handling, and a safe error string for snackbars +or banners. `ModelDownloadController` wraps any `ModelDownloadManager` and emits +that lifecycle without depending on Flutter: + +```dart +final controller = ModelDownloadController( + manager: DefaultModelDownloadManager( + defaultCacheDirectory: '/app/cache/llamadart-models', + ), +); + +final subscription = controller.snapshots.listen((snapshot) { + switch (snapshot.stage) { + case ModelDownloadTaskStage.checkingCache: + print('Checking cache for ${snapshot.source?.displayName}'); + break; + case ModelDownloadTaskStage.downloading: + final percent = snapshot.fraction == null + ? 'unknown' + : '${(snapshot.fraction! * 100).toStringAsFixed(1)}%'; + print('Downloading $percent'); + break; + case ModelDownloadTaskStage.ready: + print('Ready at ${snapshot.entry?.filePath}'); + break; + case ModelDownloadTaskStage.failed: + print(snapshot.errorMessage); + break; + case ModelDownloadTaskStage.cancelled: + print('Cancelled; retry is available: ${snapshot.canRetry}'); + break; + default: + break; + } +}); + +try { + final entry = await controller.start( + ModelSource.parse('hf://owner/repo/model-Q4_K_M.gguf'), + options: ModelLoadOptions(maxRetries: 3), + ); + await engine.loadModel(entry.filePath); +} catch (_) { + if (controller.snapshot.canRetry) { + // Wire this to a Retry button. + await controller.retry(); + } +} finally { + await subscription.cancel(); + await controller.dispose(); +} +``` + +Controller stages are `idle`, `resolving`, `checkingCache`, `downloading`, +`verifying`, `ready`, `failed`, and `cancelled`. The cache check is advisory for +UI state only; `ready` is emitted only after the manager's authoritative +`ensureModel(...)` path validates the cache entry and any caller-provided +checksum. Call `cancel()` from your UI to request cooperative cancellation; call +`retry()` after `failed` or `cancelled` to reuse the last source/options. Because +the controller owns cancellation, pass cache/auth/retry options through +`ModelLoadOptions` but call `controller.cancel()` instead of supplying +`ModelLoadOptions.cancelToken`. Error messages redact URL query strings and +fragments so signed URLs or tokens are not shown in UI logs. On web, inject a +custom manager for browser-specific storage; the default package manager remains +native/file-backed. + Downloaded files are written to `.part` files and promoted to the completed model path only after the HTTP stream and optional SHA-256 verification succeed. Stable-cache remote downloads are serialized per cache entry in-process,