Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 88 additions & 18 deletions Sources/AnyLanguageModel/Models/MLXLanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,28 @@ import Foundation
import Tokenizers
import Hub

/// Wrapper to store ModelContext in NSCache (requires NSObject subclass).
private final class CachedContext: NSObject, @unchecked Sendable {
let context: ModelContext
init(_ context: ModelContext) { self.context = context }
/// Wrapper to store model availability state in NSCache.
///
/// `NSCache` values must be classes (the previous wrapper noted it requires an
/// `NSObject` subclass), so this reference type boxes the value-style outcome
/// of a model load — either a usable context or a failure description.
private final class CachedModelState: NSObject, @unchecked Sendable {
    /// The terminal outcome of one load attempt.
    enum Value {
        /// Load succeeded; the context is retained for reuse.
        case loaded(ModelContext)
        /// Load failed; carries a debug description of the thrown error.
        case failed(String)
    }

    // Set once in init and never mutated, which is why
    // `@unchecked Sendable` is sound for this class.
    let value: Value

    init(_ value: Value) {
        self.value = value
    }
}

/// Coordinates a bounded in-memory cache with structured, coalesced loading.
private final class ModelContextCache {
private let cache: NSCache<NSString, CachedContext>
private let inFlight = Locked<[String: Task<CachedContext, Error>]>([:])
private let cache: NSCache<NSString, CachedModelState>
private let inFlight = Locked<[String: Task<CachedModelState, Error>]>([:])

/// Creates a cache with a count-based eviction limit.
init(countLimit: Int) {
let cache = NSCache<NSString, CachedContext>()
let cache = NSCache<NSString, CachedModelState>()
cache.countLimit = countLimit
self.cache = cache
}
Expand All @@ -42,23 +50,45 @@ import Foundation
loader: @escaping @Sendable () async throws -> ModelContext
) async throws -> ModelContext {
let cacheKey = key as NSString
if let cached = cache.object(forKey: cacheKey) {
return cached.context
if let cached = cache.object(forKey: cacheKey),
case .loaded(let context) = cached.value
{
return context
}

if let task = inFlightTask(for: key) {
return try await task.value.context
let cached = try await task.value
if case .loaded(let context) = cached.value {
return context
}
throw CancellationError()
}

let task = Task { try await CachedContext(loader()) }
let task = Task {
let context = try await loader()
return CachedModelState(.loaded(context))
}
setInFlight(task, for: key)

do {
let cached = try await task.value
cache.setObject(cached, forKey: cacheKey)
clearInFlight(for: key)
return cached.context
if case .loaded(let context) = cached.value {
return context
}
throw CancellationError()
} catch {
// Don't treat cancellations as load failures.
if error is CancellationError || Task.isCancelled {
cache.removeObject(forKey: cacheKey)
clearInFlight(for: key)
throw error
}
cache.setObject(
CachedModelState(.failed(String(reflecting: error))),
forKey: cacheKey
)
clearInFlight(for: key)
throw error
}
Expand All @@ -74,6 +104,28 @@ import Foundation
cache.removeAllObjects()
}

/// Reports whether a successfully loaded context is cached for `key`.
///
/// A cached *failure* entry does not count as a hit; only `.loaded`
/// entries return `true`.
func contains(_ key: String) -> Bool {
    let entry = cache.object(forKey: key as NSString)
    switch entry?.value {
    case .loaded:
        return true
    case .failed, nil:
        return false
    }
}

/// Returns the stored error description from the most recent failed load
/// for `key`, or `nil` when nothing is cached or the cached entry is a
/// successfully loaded context.
func failureDescription(for key: String) -> String? {
    guard
        let entry = cache.object(forKey: key as NSString),
        case .failed(let message) = entry.value
    else {
        return nil
    }
    return message
}

/// Cancels in-flight work and removes cached data for the key.
func removeAndCancel(for key: String) async {
let task = removeInFlight(for: key)
Expand All @@ -88,27 +140,27 @@ import Foundation
cache.removeAllObjects()
}

private func inFlightTask(for key: String) -> Task<CachedContext, Error>? {
private func inFlightTask(for key: String) -> Task<CachedModelState, Error>? {
inFlight.withLock { $0[key] }
}

private func setInFlight(_ task: Task<CachedContext, Error>, for key: String) {
private func setInFlight(_ task: Task<CachedModelState, Error>, for key: String) {
inFlight.withLock { $0[key] = task }
}

/// Drops any in-flight load task recorded for `key` without cancelling it.
private func clearInFlight(for key: String) {
    inFlight.withLock { $0.removeValue(forKey: key) }
}

private func removeInFlight(for key: String) -> Task<CachedContext, Error>? {
private func removeInFlight(for key: String) -> Task<CachedModelState, Error>? {
inFlight.withLock {
let task = $0[key]
$0[key] = nil
return task
}
}

private func removeAllInFlight() -> [Task<CachedContext, Error>] {
private func removeAllInFlight() -> [Task<CachedModelState, Error>] {
inFlight.withLock {
let tasks = Array($0.values)
$0.removeAll()
Expand All @@ -132,8 +184,12 @@ import Foundation
/// ```
public struct MLXLanguageModel: LanguageModel {
/// The reason the model is unavailable.
/// This model is always available.
public typealias UnavailableReason = Never
/// Why the model is currently unavailable in memory.
public enum UnavailableReason: Sendable, Equatable, Hashable {
    /// The model has not been loaded into memory yet.
    case notLoaded
    /// The model failed to load and includes the underlying error details.
    ///
    /// The associated `String` is a debug description of the thrown error
    /// (produced with `String(reflecting:)` when the failure is cached).
    case failedToLoad(String)
}

/// The model identifier.
public let modelId: String
Expand All @@ -156,6 +212,20 @@ import Foundation
self.directory = directory
}

/// The current availability of this model in memory.
///
/// Derived from the shared model cache: a cached loaded context yields
/// `.available`; a cached load failure surfaces its description through
/// `.failedToLoad`; otherwise the model simply has not been loaded yet.
public var availability: Availability<UnavailableReason> {
    // Cache key: local directory URL when set, else the hub model ID
    // (presumably matching the key used on the load path — verify there).
    let cacheKey = directory?.absoluteString ?? modelId

    guard !modelCache.contains(cacheKey) else {
        return .available
    }
    if let reason = modelCache.failureDescription(for: cacheKey) {
        return .unavailable(.failedToLoad(reason))
    }
    return .unavailable(.notLoaded)
}

/// Removes this model from the shared cache and cancels any in-flight load.
///
/// Call this to free memory when the model is no longer needed.
Expand Down
34 changes: 34 additions & 0 deletions Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,20 @@ import Testing
let model = MLXLanguageModel(modelId: "mlx-community/Qwen3-0.6B-4bit")
let visionModel = MLXLanguageModel(modelId: "mlx-community/Qwen2-VL-2B-Instruct-4bit")

@Test func availabilityBecomesAvailableAfterSuccessfulLoad() async throws {
    // Start from a clean slate so the model reads as not yet loaded.
    await model.removeFromCache()
    #expect(model.availability == .unavailable(.notLoaded))
    #expect(model.isAvailable == false)

    // A successful generation forces the model into the shared cache.
    let response = try await LanguageModelSession(model: model)
        .respond(to: "Say hello")
    #expect(!response.content.isEmpty)

    #expect(model.isAvailable == true)
    #expect(model.availability == .available)
}

@Test func basicResponse() async throws {
let session = LanguageModelSession(model: model)

Expand Down Expand Up @@ -205,5 +219,25 @@ import Testing
)
#expect([Priority.low, Priority.medium, Priority.high].contains(response.content))
}

@Test func unavailableForNonexistentModel() async {
    let model = MLXLanguageModel(modelId: "mlx-community/does-not-exist-anylanguagemodel-test")
    await model.removeFromCache()
    #expect(model.availability == .unavailable(.notLoaded))
    #expect(model.isAvailable == false)

    // Requesting a response for a nonexistent model ID must throw.
    let session = LanguageModelSession(model: model)
    await #expect(throws: Error.self) {
        _ = try await session.respond(to: "Hello")
    }

    // The failure should now be cached and surfaced through availability.
    if case .unavailable(.failedToLoad(let description)) = model.availability {
        #expect(!description.isEmpty)
    } else {
        Issue.record("Expected model availability to report failedToLoad after failed request")
    }
    #expect(model.isAvailable == false)
}
}
#endif // MLX