diff --git a/Example/KnitExample/ContentView.swift b/Example/KnitExample/ContentView.swift index d24f724d..8ab0ad91 100644 --- a/Example/KnitExample/ContentView.swift +++ b/Example/KnitExample/ContentView.swift @@ -7,7 +7,7 @@ import SwiftUI struct ContentView: View { - let resolver: Resolver + let resolver: BaseResolver var body: some View { VStack { @@ -21,7 +21,7 @@ struct ContentView: View { } struct ContentView_Previews: PreviewProvider { - static let assembler = ScopedModuleAssembler([KnitExampleAssembly()]) + static let assembler = ScopedModuleAssembler([KnitExampleAssembly()]) static var previews: some View { return ContentView(resolver: assembler.resolver) } diff --git a/Example/KnitExample/KnitExampleApp.swift b/Example/KnitExample/KnitExampleApp.swift index 6a4b99bb..eb585bc0 100644 --- a/Example/KnitExample/KnitExampleApp.swift +++ b/Example/KnitExample/KnitExampleApp.swift @@ -8,12 +8,12 @@ import Knit @main struct KnitExampleApp: App { - let assembler: ScopedModuleAssembler - var resolver: Resolver { assembler.resolver } + let assembler: ScopedModuleAssembler + var resolver: BaseResolver { assembler.resolver } @MainActor init() { - assembler = ScopedModuleAssembler( + assembler = ScopedModuleAssembler( [KnitExampleAssembly()] ) } diff --git a/Example/KnitExample/KnitExampleAssembly.swift b/Example/KnitExample/KnitExampleAssembly.swift index 7f20012e..cbfe1fcc 100644 --- a/Example/KnitExample/KnitExampleAssembly.swift +++ b/Example/KnitExample/KnitExampleAssembly.swift @@ -9,7 +9,7 @@ import KnitMacros // @knit internal final class KnitExampleAssembly: ModuleAssembly { - typealias TargetResolver = Resolver + typealias TargetResolver = BaseResolver static var dependencies: [any ModuleAssembly.Type] { [] } diff --git a/Example/KnitExample/KnitExampleUserAssembly.swift b/Example/KnitExample/KnitExampleUserAssembly.swift index 41cbe817..d7e2f74c 100644 --- a/Example/KnitExample/KnitExampleUserAssembly.swift +++ b/Example/KnitExample/KnitExampleUserAssembly.swift @@ -9,7 +9,7 @@ import Knit /// An assembly expected to be registered at the user level rather than at the app level final class KnitExampleUserAssembly: ModuleAssembly { - typealias TargetResolver = Resolver + typealias TargetResolver = BaseResolver static var dependencies: [any ModuleAssembly.Type] { [] } diff --git a/Sources/Knit/Container.swift b/Sources/Knit/Container.swift index 4b3affda..f19adc18 100644 --- a/Sources/Knit/Container.swift +++ b/Sources/Knit/Container.swift @@ -11,24 +11,11 @@ import Swinject The Knit.Container also performs the function of a weak wrapper of the `SwinjectContainer`. */ -public class Container: Knit.Resolver { +public class Container { // MARK: - Knit.Resolver - public var resolver: TargetResolver { - self as! TargetResolver - } - - /// Returns `true` if the backing container is still available in memory, otherwise `false`. - public var isAvailable: Bool { - _swinjectContainer != nil - } - - // MARK: - SwinjectResolver - - public func unsafeResolver(file: StaticString, function: StaticString, line: UInt) -> SwinjectResolver { - _unwrappedSwinjectContainer(file: file, function: function, line: line) - } + public let resolver: TargetResolver // MARK: - Private Properties @@ -39,6 +26,7 @@ public class Container: Knit.Resolver { // This should not be promoted from `fileprivate` access level. fileprivate init(_swinjectContainer: SwinjectContainer) { self._swinjectContainer = _swinjectContainer + self.resolver = TargetResolver(_swinjectContainer: _swinjectContainer) } } diff --git a/Sources/Knit/Module/ModuleAssembly.swift b/Sources/Knit/Module/ModuleAssembly.swift index d2dba943..eb64c056 100644 --- a/Sources/Knit/Module/ModuleAssembly.swift +++ b/Sources/Knit/Module/ModuleAssembly.swift @@ -7,7 +7,7 @@ import Swinject public protocol ModuleAssembly { - associatedtype TargetResolver + associatedtype TargetResolver: Knit.Resolver static var resolverType: Self.TargetResolver.Type { get } @@ -20,10 +20,6 @@ public protocol ModuleAssembly { /// A common case is a fake assembly that registers fake services matching those from the original module. static var replaces: [any ModuleAssembly.Type] { get } - /// Filter the list of dependencies down to those which match the scope of this assembly - /// This can be overridden in apps with custom Resolver hierarchies - static func scoped(_ dependencies: [any ModuleAssembly.Type]) -> [any ModuleAssembly.Type] - /// Hints about this assembly using by DependencyBuilder. Designed for internal use static var _assemblyFlags: [ModuleAssemblyFlags] { get } @@ -39,13 +35,6 @@ public extension ModuleAssembly { static var replaces: [any ModuleAssembly.Type] { [] } - static func scoped(_ dependencies: [any ModuleAssembly.Type]) -> [any ModuleAssembly.Type] { - return dependencies.filter { - // Default the scoped implementation to match types directly - return self.resolverType == $0.resolverType - } - } - static var _assemblyFlags: [ModuleAssemblyFlags] { var result: [ModuleAssemblyFlags] = [] if self is any AutoInitModuleAssembly.Type { @@ -84,6 +73,13 @@ public protocol GeneratedModuleAssembly: ModuleAssembly { extension ModuleAssembly where Self: GeneratedModuleAssembly { // Default the dependencies to using generatedDependencies scoped to those with compatible resolvers public static var dependencies: [any ModuleAssembly.Type] { scoped(generatedDependencies) } + + /// Filter the list of dependencies down to those which match the scope of this assembly + public static func scoped(_ dependencies: [any ModuleAssembly.Type]) -> [any ModuleAssembly.Type] { + return dependencies.filter { module in + return resolverType.inherits(from: module.resolverType) + } + } } /// Control the behavior of Assembly Overrides. diff --git a/Sources/Knit/Module/ScopedModuleAssembler.swift b/Sources/Knit/Module/ScopedModuleAssembler.swift index 58ddee6c..00bbe144 100644 --- a/Sources/Knit/Module/ScopedModuleAssembler.swift +++ b/Sources/Knit/Module/ScopedModuleAssembler.swift @@ -6,7 +6,7 @@ import Foundation import Swinject /// Module assembly which only allows registering assemblies which target a particular resolver type. -public final class ScopedModuleAssembler { +public final class ScopedModuleAssembler { public let internalAssembler: ModuleAssembler @@ -58,18 +58,6 @@ public final class ScopedModuleAssembler { behaviors: [Behavior] = [], postAssemble: ((Container) -> Void)? = nil ) throws { - // For provided modules, fail early if they are scoped incorrectly - for assembly in modules { - let moduleAssemblyType = type(of: assembly) - if moduleAssemblyType.resolverType != TargetResolver.self { - let scopingError = ScopedModuleAssemblerError.incorrectTargetResolver( - expected: String(describing: TargetResolver.self), - actual: String(describing: moduleAssemblyType.resolverType) - ) - - throw DependencyBuilderError.assemblyValidationFailure(moduleAssemblyType, reason: scopingError) - } - } self.internalAssembler = try ModuleAssembler( parent: parent, _modules: modules, diff --git a/Sources/Knit/Resolver.swift b/Sources/Knit/Resolver.swift index 70106c41..f048b2ed 100644 --- a/Sources/Knit/Resolver.swift +++ b/Sources/Knit/Resolver.swift @@ -13,4 +13,58 @@ public protocol Resolver: AnyObject { func unsafeResolver(file: StaticString, function: StaticString, line: UInt) -> SwinjectResolver + init(_swinjectContainer: SwinjectContainer) + + /// Resolvers require a manual implementation that matches the inheritance structure of the Resolver + /// If ResolverB inherits from ResolverA then the ResolverB inherits function should match this + /// Example: + /// public func ResolverB: ResolverA { + /// static func inherits(from resolverType: Resolver.Type) -> Bool { + /// return self == resolverType || resolverType == ResolverA.self + /// } + /// } + static func inherits(from resolverType: Resolver.Type) -> Bool + +} + +/// Default Resolver implementation. Designed to be inherited from +open class BaseResolver: Resolver { + + private weak var _swinjectContainer: SwinjectContainer? + + /// Returns `true` if the backing container is still available in memory, otherwise `false`. + public var isAvailable: Bool { + _swinjectContainer != nil + } + + // MARK: - SwinjectResolver + + public func unsafeResolver(file: StaticString, function: StaticString, line: UInt) -> SwinjectResolver { + _unwrappedSwinjectContainer(file: file, function: function, line: line) + } + + public required init(_swinjectContainer: SwinjectContainer) { + self._swinjectContainer = _swinjectContainer + } + + /// Default implementation uses pure equality + open class func inherits(from resolverType: Resolver.Type) -> Bool { + return self == resolverType + } + + // Force unwraps the weak Container + func _unwrappedSwinjectContainer( + file: StaticString = #fileID, + function: StaticString = #function, + line: UInt = #line + ) -> SwinjectContainer { + guard let _swinjectContainer else { + fatalError( + "\(function) incorrectly accessed the container for \(self) which has already been released", + file: file, + line: line + ) + } + return _swinjectContainer + } } diff --git a/Sources/Knit/ServiceCollection/Container+ServiceCollection.swift b/Sources/Knit/ServiceCollection/Container+ServiceCollection.swift index a44d1c35..0d2444a6 100644 --- a/Sources/Knit/ServiceCollection/Container+ServiceCollection.swift +++ b/Sources/Knit/ServiceCollection/Container+ServiceCollection.swift @@ -30,7 +30,7 @@ extension Container { name: makeUniqueCollectionRegistrationName(), factory: { r in MainActor.assumeIsolated { - let resolver = r.resolve(Container.self)! as! TargetResolver + let resolver = r.resolve(Container.self)!.resolver return factory(resolver) } } diff --git a/Tests/KnitMacrosTests/ResolvableTests.swift b/Tests/KnitMacrosTests/ResolvableTests.swift index 72ff5c13..570e83ea 100644 --- a/Tests/KnitMacrosTests/ResolvableTests.swift +++ b/Tests/KnitMacrosTests/ResolvableTests.swift @@ -2,367 +2,3 @@ // Copyright © Block, Inc. All rights reserved. // -import KnitMacrosImplementations -import SwiftSyntaxMacros -import SwiftSyntaxMacrosTestSupport -import XCTest - -let testMacros: [String: Macro.Type] = [ - "Resolvable": ResolvableMacro.self -] - -final class ResolvableTests: XCTestCase { - func test_macro_expansion() throws { - assertMacroExpansion( - """ - @Resolvable - init(arg1: String, arg2: Int) {} - """, - expandedSource: """ - - init(arg1: String, arg2: Int) {} - - static func make(resolver: Resolver) -> Self { - return .init( - arg1: resolver.string(), - arg2: resolver.int() - ) - } - """, - macros: testMacros - ) - } - - func test_optional_parameter() throws { - assertMacroExpansion( - """ - @Resolvable - init(arg1: String?) {} - """, - expandedSource: """ - - init(arg1: String?) {} - - static func make(resolver: Resolver) -> Self { - return .init( - arg1: resolver.string() - ) - } - """, - macros: testMacros - ) - } - - func test_closure_param() throws { - assertMacroExpansion( - """ - @Resolvable - init(closure: @escaping () -> Void) {} - """, - expandedSource: """ - - init(closure: @escaping () -> Void) {} - - static func make(resolver: CustomResolver) -> Self { - return .init( - closure: resolver.closure() - ) - } - """, - macros: testMacros - ) - } - - func test_any_protocol_param() throws { - assertMacroExpansion( - """ - @Resolvable - init(arg1: any Publisher) {} - """, - expandedSource: """ - - init(arg1: any Publisher) {} - - static func make(resolver: Resolver) -> Self { - return .init( - arg1: resolver.publisher() - ) - } - """, - macros: testMacros - ) - } - - func test_default_param() throws { - assertMacroExpansion( - """ - @Resolvable - init(@UseDefault value: Int = 5) {} - """, - expandedSource: """ - - init(@UseDefault value: Int = 5) {} - - static func make(resolver: Resolver) -> Self { - return .init( - value: 5 - ) - } - """, - macros: testMacros - ) - } - - func test_default_param_unused() throws { - assertMacroExpansion( - """ - @Resolvable - init(value: Int = 5) {} - """, - expandedSource: """ - - init(value: Int = 5) {} - - static func make(resolver: Resolver) -> Self { - return .init( - value: resolver.int() - ) - } - """, - macros: testMacros - ) - } - - func test_argument() throws { - assertMacroExpansion( - """ - @Resolvable - init(@Argument value: Int) {} - """, - expandedSource: """ - - init(@Argument value: Int) {} - - static func make(resolver: Resolver, value: Int) -> Self { - return .init( - value: value - ) - } - """, - macros: testMacros - ) - } - - func test_argument_withDefaultValue() throws { - assertMacroExpansion( - """ - @Resolvable - init(@Argument value: Int = 5) {} - """, - expandedSource: """ - - init(@Argument value: Int = 5) {} - - static func make(resolver: Resolver, value: Int = 5) -> Self { - return .init( - value: value - ) - } - """, - macros: testMacros - ) - } - - func test_named() throws { - assertMacroExpansion( - """ - @Resolvable - init(@Named("customName") value: Int) {} - """, - expandedSource: """ - - init(@Named("customName") value: Int) {} - - static func make(resolver: Resolver) -> Self { - return .init( - value: resolver.int(name: .customName) - ) - } - """, - macros: testMacros - ) - } - - func test_apply_static() throws { - assertMacroExpansion( - """ - @Resolvable - static func makeThing(value: Int) -> Thing { - Thing(value: value) - } - """, - expandedSource: """ - - static func makeThing(value: Int) -> Thing { - Thing(value: value) - } - - static func makeThing(resolver: Resolver) -> Thing { - return makeThing( - value: resolver.int() - ) - } - """, - macros: testMacros - ) - } - - func test_non_static_function() throws { - assertMacroExpansion( - """ - @Resolvable - func makeThing(value: Int) -> Thing { .init() } - """, - expandedSource: """ - - func makeThing(value: Int) -> Thing { .init() } - """, - diagnostics: [ - .init( - message: "@Resolvable can only be used on init declarations or static functions", - line: 1, - column: 1 - ), - ], - macros: testMacros - ) - } - - func test_main_actor_init() { - assertMacroExpansion( - """ - @Resolvable @MainActor - init(arg1: String, arg2: Int) {} - """, - expandedSource: """ - - @MainActor - init(arg1: String, arg2: Int) {} - - @MainActor static func make(resolver: Resolver) -> Self { - return .init( - arg1: resolver.string(), - arg2: resolver.int() - ) - } - """, - macros: testMacros - ) - } - - func test_main_actor_static_function() { - assertMacroExpansion( - """ - @Resolvable @MainActor - static func makeThing(value: Int) -> Thing { - Thing(value: value) - } - """, - expandedSource: """ - - @MainActor - static func makeThing(value: Int) -> Thing { - Thing(value: value) - } - - @MainActor static func makeThing(resolver: Resolver) -> Thing { - return makeThing( - value: resolver.int() - ) - } - """, - macros: testMacros - ) - } - - func test_publisher_type() throws { - assertMacroExpansion( - """ - @Resolvable - init(profileValueProvider: AnyCurrentValuePublisher) {} - """, - expandedSource: """ - - init(profileValueProvider: AnyCurrentValuePublisher) {} - - static func make(resolver: Resolver) -> Self { - return .init( - profileValueProvider: resolver.globalAddressPublisher() - ) - } - """, - macros: testMacros - ) - } - - func test_escaping_argument() throws { - assertMacroExpansion( - """ - @Resolvable - init(@Argument arg: @escaping () -> Void) {} - """, - expandedSource: """ - - init(@Argument arg: @escaping () -> Void) {} - - static func make(resolver: Resolver, arg: @escaping (() -> Void)) -> Self { - return .init( - arg: arg - ) - } - """, - macros: testMacros - ) - } - - func test_macro_expansion_nested_type() throws { - assertMacroExpansion( - """ - @Resolvable - init(arg1: MyType.Nested) {} - """, - expandedSource: """ - - init(arg1: MyType.Nested) {} - - static func make(resolver: Resolver) -> Self { - return .init( - arg1: resolver.nested() - ) - } - """, - macros: testMacros - ) - } - - func test_macro_expansion_nested_type_argument() throws { - assertMacroExpansion( - """ - @Resolvable - init(@Argument arg1: MyType.Nested) {} - """, - expandedSource: """ - - init(@Argument arg1: MyType.Nested) {} - - static func make(resolver: Resolver, arg1: MyType.Nested) -> Self { - return .init( - arg1: arg1 - ) - } - """, - macros: testMacros - ) - } - -} diff --git a/Tests/KnitTests/ModuleAssemblyScopingTests.swift b/Tests/KnitTests/ModuleAssemblyScopingTests.swift index 73db5969..54ff4aec 100644 --- a/Tests/KnitTests/ModuleAssemblyScopingTests.swift +++ b/Tests/KnitTests/ModuleAssemblyScopingTests.swift @@ -23,15 +23,13 @@ final class ModuleAssemblyScopingTests: XCTestCase { } -private protocol ParentResolver: Resolver {} -private protocol ChildResolver: Resolver {} -private protocol OtherResolver: Resolver {} - -extension ChildResolver { - static func contains(resolver: Resolver.Type) -> Bool { - return resolver == self || resolver == ParentResolver.self +private class ParentResolver: BaseResolver {} +private class ChildResolver: ParentResolver { + public override class func inherits(from resolverType: Resolver.Type) -> Bool { + return resolverType == self || resolverType == ParentResolver.self } } +private class OtherResolver: BaseResolver {} private struct Assembly1: GeneratedModuleAssembly { typealias TargetResolver = ParentResolver @@ -56,15 +54,3 @@ private struct Assembly4: GeneratedModuleAssembly { static var generatedDependencies: [any ModuleAssembly.Type] { [Assembly1.self] } func assemble(container: Container) {} } - -private extension ModuleAssembly { - // Override the default scoping function to allow assemblies using ParentResolver to be included in ChildResolver - static func scoped(_ dependencies: [any ModuleAssembly.Type]) -> [any ModuleAssembly.Type] { - return dependencies.filter { - if self.resolverType == ChildResolver.self && $0.resolverType == ParentResolver.self { - return true - } - return self.resolverType == $0.resolverType - } - } -} diff --git a/Tests/KnitTests/ScopedModuleAssemblerTests.swift b/Tests/KnitTests/ScopedModuleAssemblerTests.swift index 0d290cf1..621518db 100644 --- a/Tests/KnitTests/ScopedModuleAssemblerTests.swift +++ b/Tests/KnitTests/ScopedModuleAssemblerTests.swift @@ -66,7 +66,7 @@ private struct Assembly1: AutoInitModuleAssembly { func assemble(container: Knit.Container) { } } -protocol OutsideResolver: SwinjectResolver { } +class OutsideResolver: BaseResolver { } private struct Assembly2: AutoInitModuleAssembly { typealias TargetResolver = OutsideResolver diff --git a/Tests/KnitTests/ServiceCollectorTests.swift b/Tests/KnitTests/ServiceCollectorTests.swift index 34f039b4..2164fde7 100644 --- a/Tests/KnitTests/ServiceCollectorTests.swift +++ b/Tests/KnitTests/ServiceCollectorTests.swift @@ -6,11 +6,9 @@ import Swinject import XCTest -private protocol ParentResolver: Knit.Resolver {} -private protocol ChildResolver: ParentResolver {} -private protocol GrandChildResolver: ChildResolver {} - -extension Knit.Container: ParentResolver, ChildResolver, GrandChildResolver {} +private class ParentResolver: BaseResolver {} +private class ChildResolver: ParentResolver {} +private class GrandChildResolver: ChildResolver {} private protocol ServiceProtocol {} @@ -87,7 +85,7 @@ final class ServiceCollectorTests: XCTestCase { @MainActor func test_registerIntoCollection() { let swinjectContainer = SwinjectContainer() - let container = ContainerManager(swinjectContainer: swinjectContainer).register(Any.self) + let container = ContainerManager(swinjectContainer: swinjectContainer).register(ParentResolver.self) container._unwrappedSwinjectContainer().addBehavior(ServiceCollector()) // Register some services into a collection @@ -99,12 +97,12 @@ final class ServiceCollectorTests: XCTestCase { container.registerIntoCollection(CustomService.self) { _ in CustomService(name: "Custom 2") } // Resolving each collection should produce the expected services - let serviceProtocolCollection = container.resolveCollection(ServiceProtocol.self) + let serviceProtocolCollection = container.resolver.resolveCollection(ServiceProtocol.self) XCTAssertEqual(serviceProtocolCollection.entries.count, 2) XCTAssert(serviceProtocolCollection.entries.first is ServiceA) XCTAssert(serviceProtocolCollection.entries.last is ServiceB) - let customServiceCollection = container.resolveCollection(CustomService.self) + let customServiceCollection = container.resolver.resolveCollection(CustomService.self) XCTAssertEqual( customServiceCollection.entries.map(\.name), ["Custom 1", "Custom 2"] @@ -114,19 +112,19 @@ final class ServiceCollectorTests: XCTestCase { @MainActor func test_registerIntoCollection_emptyWithBehavior() { let swinjectContainer = SwinjectContainer() - let container = ContainerManager(swinjectContainer: swinjectContainer).register(Any.self) + let container = ContainerManager(swinjectContainer: swinjectContainer).register(ParentResolver.self) container._unwrappedSwinjectContainer().addBehavior(ServiceCollector()) - let collection = container.resolveCollection(ServiceProtocol.self) + let collection = container.resolver.resolveCollection(ServiceProtocol.self) XCTAssertEqual(collection.entries.count, 0) } @MainActor func test_registerIntoCollection_emptyWithoutBehavior() { let swinjectContainer = SwinjectContainer() - let container = ContainerManager(swinjectContainer: swinjectContainer).register(Any.self) + let container = ContainerManager(swinjectContainer: swinjectContainer).register(ParentResolver.self) - let collection = container.resolveCollection(ServiceProtocol.self) + let collection = container.resolver.resolveCollection(ServiceProtocol.self) XCTAssertEqual(collection.entries.count, 0) } @@ -135,7 +133,7 @@ final class ServiceCollectorTests: XCTestCase { @MainActor func test_registerIntoCollection_doesntConflictWithArray() throws { let swinjectContainer = SwinjectContainer() - let container = ContainerManager(swinjectContainer: swinjectContainer).register(Any.self) + let container = ContainerManager(swinjectContainer: swinjectContainer).register(ParentResolver.self) container._unwrappedSwinjectContainer().addBehavior(ServiceCollector()) // Register A into a collection @@ -145,7 +143,7 @@ final class ServiceCollectorTests: XCTestCase { container.register([ServiceProtocol].self) { _ in [ServiceB()] } // Resolving the collection should produce A - let collection = container.resolveCollection(ServiceProtocol.self) + let collection = container.resolver.resolveCollection(ServiceProtocol.self) XCTAssertEqual(collection.entries.count, 1) XCTAssert(collection.entries.first is ServiceA) @@ -158,7 +156,7 @@ final class ServiceCollectorTests: XCTestCase { @MainActor func test_registerIntoCollection_doesntImplicitlyAggregateInstances() throws { let swinjectContainer = SwinjectContainer() - let container = ContainerManager(swinjectContainer: swinjectContainer).register(Any.self) + let container = ContainerManager(swinjectContainer: swinjectContainer).register(ParentResolver.self) container._unwrappedSwinjectContainer().addBehavior(ServiceCollector()) // Register A and B into a collection @@ -169,7 +167,7 @@ final class ServiceCollectorTests: XCTestCase { _ = container.register(ServiceProtocol.self) { _ in ServiceB() } // Resolving the collection should produce A and B - let collection = container.resolveCollection(ServiceProtocol.self) + let collection = container.resolver.resolveCollection(ServiceProtocol.self) XCTAssertEqual(collection.entries.count, 2) XCTAssert(collection.entries.first is ServiceA) XCTAssert(collection.entries.last is ServiceB) @@ -181,7 +179,7 @@ final class ServiceCollectorTests: XCTestCase { @MainActor func test_registerIntoCollection_allowsDuplicates() { let swinjectContainer = SwinjectContainer() - let container = ContainerManager(swinjectContainer: swinjectContainer).register(Any.self) + let container = ContainerManager(swinjectContainer: swinjectContainer).register(ParentResolver.self) container._unwrappedSwinjectContainer().addBehavior(ServiceCollector()) // Register some duplicate services @@ -190,7 +188,7 @@ final class ServiceCollectorTests: XCTestCase { _ = container.registerIntoCollection(ServiceProtocol.self) { _ in CustomService(name: "Car Repair") } // Resolving the collection should produce all services - let collection = container.resolveCollection(ServiceProtocol.self) + let collection = container.resolver.resolveCollection(ServiceProtocol.self) XCTAssertEqual( collection.entries.compactMap { ($0 as? CustomService)?.name }, ["Dry Cleaning", "Car Repair", "Car Repair"] @@ -202,7 +200,7 @@ final class ServiceCollectorTests: XCTestCase { @MainActor func test_registerIntoCollection_supportsTransientScopedObjects() throws { let swinjectContainer = SwinjectContainer() - let container = ContainerManager(swinjectContainer: swinjectContainer).register(Any.self) + let container = ContainerManager(swinjectContainer: swinjectContainer).register(ParentResolver.self) container._unwrappedSwinjectContainer().addBehavior(ServiceCollector()) // Register a service with the `transient` scope. @@ -211,8 +209,8 @@ final class ServiceCollectorTests: XCTestCase { .registerIntoCollection(CustomService.self) { _ in CustomService(name: "service") } .inObjectScope(.transient) - let collection1 = container.resolveCollection(CustomService.self) - let collection2 = container.resolveCollection(CustomService.self) + let collection1 = container.resolver.resolveCollection(CustomService.self) + let collection2 = container.resolver.resolveCollection(CustomService.self) let instance1 = try XCTUnwrap(collection1.entries.first) let instance2 = try XCTUnwrap(collection2.entries.first) @@ -223,7 +221,7 @@ final class ServiceCollectorTests: XCTestCase { @MainActor func test_registerIntoCollection_supportsContainerScopedObjects() throws { let swinjectContainer = SwinjectContainer() - let container = ContainerManager(swinjectContainer: swinjectContainer).register(Any.self) + let container = ContainerManager(swinjectContainer: swinjectContainer).register(ParentResolver.self) container._unwrappedSwinjectContainer().addBehavior(ServiceCollector()) // Register a service with the `container` scope. @@ -232,8 +230,8 @@ final class ServiceCollectorTests: XCTestCase { .registerIntoCollection(CustomService.self) { _ in CustomService(name: "service") } .inObjectScope(.container) - let collection1 = container.resolveCollection(CustomService.self) - let collection2 = container.resolveCollection(CustomService.self) + let collection1 = container.resolver.resolveCollection(CustomService.self) + let collection2 = container.resolver.resolveCollection(CustomService.self) let instance1 = try XCTUnwrap(collection1.entries.first) let instance2 = try XCTUnwrap(collection2.entries.first) @@ -244,7 +242,7 @@ final class ServiceCollectorTests: XCTestCase { @MainActor func test_registerIntoCollection_supportsWeakScopedObjects() throws { let swinjectContainer = SwinjectContainer() - let container = ContainerManager(swinjectContainer: swinjectContainer).register(Any.self) + let container = ContainerManager(swinjectContainer: swinjectContainer).register(ParentResolver.self) container._unwrappedSwinjectContainer().addBehavior(ServiceCollector()) // Register a service with the `weak` scope. @@ -259,18 +257,18 @@ final class ServiceCollectorTests: XCTestCase { .inObjectScope(.weak) // Resolve the initial instance - var instance1: CustomService? = try XCTUnwrap(container.resolveCollection(CustomService.self).entries.first) + var instance1: CustomService? = try XCTUnwrap(container.resolver.resolveCollection(CustomService.self).entries.first) XCTAssertEqual(factoryCallCount, 1) // Resolving again shouldn't increase `factoryCallCount` since `instance1` is still retained. - var instance2: CustomService? = try XCTUnwrap(container.resolveCollection(CustomService.self).entries.first) + var instance2: CustomService? = try XCTUnwrap(container.resolver.resolveCollection(CustomService.self).entries.first) XCTAssertEqual(factoryCallCount, 1) XCTAssert(instance2 === instance1) // Release our instances and resolve again. This time a new instance should be created. instance1 = nil instance2 = nil - _ = container.resolveCollection(CustomService.self) + _ = container.resolver.resolveCollection(CustomService.self) XCTAssertEqual(factoryCallCount, 2) } diff --git a/Tests/KnitTests/SynchronizationTests.swift b/Tests/KnitTests/SynchronizationTests.swift index b56feace..1b92b12b 100644 --- a/Tests/KnitTests/SynchronizationTests.swift +++ b/Tests/KnitTests/SynchronizationTests.swift @@ -84,11 +84,7 @@ private final class Service2 { } } -private protocol TestScopedResolver: Knit.Resolver { - func service1() -> Service1 - func service2() -> Service2 -} -extension TestScopedResolver { +class TestScopedResolver: BaseResolver { fileprivate func service1() -> Service1 { self.unsafeResolver(file: #filePath, function: #function, line: #line).resolve(Service1.self)! } @@ -97,4 +93,3 @@ extension TestScopedResolver { self.unsafeResolver(file: #filePath, function: #function, line: #line).resolve(Service2.self)! } } -extension Container: TestScopedResolver {} diff --git a/Tests/KnitTests/TestResolver.swift b/Tests/KnitTests/TestResolver.swift index e755f196..e4f2def2 100644 --- a/Tests/KnitTests/TestResolver.swift +++ b/Tests/KnitTests/TestResolver.swift @@ -5,9 +5,7 @@ @testable import Knit import Swinject -protocol TestResolver: Knit.Resolver { } - -extension Knit.Container: TestResolver {} +class TestResolver: BaseResolver {} extension ModuleAssembly {