Skip to content

Commit 266e6f1

Browse files
simbasimba
authored andcommitted
feat(ux): add robust caching bandwidth speedometer
1 parent 3108901 commit 266e6f1

4 files changed

Lines changed: 92 additions & 7 deletions

File tree

Sources/mlx-server/Server.swift

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,63 @@ import MLXVLM
2323

2424
final class ProgressTracker {
2525
var lastUpdate: TimeInterval = 0
26+
var lastBytes: Int64 = 0
27+
var speedStr = "0.0 MB/s"
2628
var isDone = false
2729
var spinnerFrames = ["", "", "", "", "", "", "", "", "", ""]
2830
var frameIndex = 0
31+
let modelId: String
32+
33+
init(modelId: String) {
34+
self.modelId = modelId
35+
}
36+
37+
func getDownloadedBytes() -> Int64 {
38+
let home = FileManager.default.homeDirectoryForCurrentUser
39+
let folderName = "models--" + modelId.replacingOccurrences(of: "/", with: "--")
40+
let modelHubDir = home.appendingPathComponent(".cache/huggingface/hub/\(folderName)")
41+
let downloadDir = home.appendingPathComponent(".cache/huggingface/download")
42+
43+
func sumDir(_ dir: URL) -> Int64 {
44+
var total: Int64 = 0
45+
if let enumerator = FileManager.default.enumerator(at: dir, includingPropertiesForKeys: [.fileSizeKey]) {
46+
for case let file as URL in enumerator {
47+
// Quick check to skip symlinks from inflating size
48+
if let attr = try? file.resourceValues(forKeys: [.fileSizeKey, .isSymbolicLinkKey]),
49+
let size = attr.fileSize,
50+
attr.isSymbolicLink != true {
51+
total += Int64(size)
52+
}
53+
}
54+
}
55+
return total
56+
}
57+
58+
return sumDir(modelHubDir) + sumDir(downloadDir)
59+
}
2960

3061
func printProgress(_ progress: Progress) {
3162
if isDone { return }
3263
let now = Date().timeIntervalSince1970
3364
let fraction = progress.fractionCompleted
3465

35-
if lastUpdate == 0 { lastUpdate = now }
66+
if lastUpdate == 0 {
67+
lastUpdate = now
68+
lastBytes = getDownloadedBytes()
69+
}
3670
let interval = now - lastUpdate
3771

38-
if interval > 0.1 {
72+
if interval > 0.5 {
3973
frameIndex = (frameIndex + 1) % spinnerFrames.count
74+
75+
let currentBytes = getDownloadedBytes()
76+
let diff = Double(currentBytes - lastBytes)
77+
if diff >= 0 {
78+
let speedMBps = (diff / interval) / 1_048_576.0
79+
speedStr = String(format: "%.1f MB/s", speedMBps)
80+
}
81+
82+
lastBytes = currentBytes
4083
lastUpdate = now
4184
}
4285

@@ -56,12 +99,11 @@ final class ProgressTracker {
5699

57100
let pctStr = String(format: "%3d%%", pct)
58101
let spinner = spinnerFrames[frameIndex]
102+
let speedText = "| Speed: \(speedStr)"
59103

60-
// If the library properly bubbled up throughput or total bytes, we'd show it,
61-
// but swift-transformers aggregated Progress uses abstract units (e.g. 100 for total).
62-
let msg = String(format: "\r[mlx-server] Download: [%@] %@ %@", bars, pctStr, spinner)
104+
let msg = String(format: "\r[mlx-server] Download: [%@] %@ %@ %@", bars, pctStr, spinner, speedText)
63105

64-
print(msg.padding(toLength: 80, withPad: " ", startingAt: 0), terminator: "")
106+
print(msg.padding(toLength: 90, withPad: " ", startingAt: 0), terminator: "")
65107
fflush(stdout)
66108

67109
if fraction >= 1.0 {
@@ -220,7 +262,14 @@ struct MLXServer: AsyncParsableCommand {
220262

221263
let isVision = self.vision
222264
let container: ModelContainer
223-
let tracker = ProgressTracker()
265+
266+
// Handle getting the simple model ID string for the tracker
267+
let resolvedModelId: String = {
268+
if case .id(let idStr, _) = modelConfig.id { return idStr }
269+
return self.model
270+
}()
271+
let tracker = ProgressTracker(modelId: resolvedModelId)
272+
224273
if isVision {
225274
print("[mlx-server] Loading VLM (vision-language model)...")
226275
container = try await VLMModelFactory.shared.loadContainer(

test_children.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import Foundation
2+
3+
func sumBytes(_ p: Progress) -> Int64 {
4+
if p.children.isEmpty {
5+
return p.completedUnitCount
6+
}
7+
return p.children.reduce(0) { $0 + sumBytes($1) }
8+
}

test_progress.swift

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import Foundation
2+
3+
let p = Progress()
4+
print("throughput:", p.throughput ?? "nil")
5+
print("userInfo:", p.userInfo)

test_speed.swift

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import Foundation
2+
3+
func getSystemDownloadBytes() -> UInt64 {
4+
var ifaddr: UnsafeMutablePointer<ifaddrs>?
5+
guard getifaddrs(&ifaddr) == 0 else { return 0 }
6+
defer { freeifaddrs(ifaddr) }
7+
8+
var total: UInt64 = 0
9+
var ptr = ifaddr
10+
while let p = ptr {
11+
if p.pointee.ifa_addr != nil, p.pointee.ifa_addr.pointee.sa_family == UInt8(AF_LINK) {
12+
let data = p.pointee.ifa_data?.bindMemory(to: if_data.self, capacity: 1)
13+
total += UInt64(data?.pointee.ifi_ibytes ?? 0)
14+
}
15+
ptr = p.pointee.ifa_next
16+
}
17+
return total
18+
}
19+
20+
let start = getSystemDownloadBytes()
21+
Thread.sleep(forTimeInterval: 1.0)
22+
let end = getSystemDownloadBytes()
23+
print("Speed: \((Double(end - start) / 1048576.0)) MB/s")

0 commit comments

Comments
 (0)