diff --git a/model-client/src/commonMain/kotlin/org/modelix/model/oauth/OAuthConfig.kt b/model-client/src/commonMain/kotlin/org/modelix/model/oauth/OAuthConfig.kt index 7767a69140..6289eba14c 100644 --- a/model-client/src/commonMain/kotlin/org/modelix/model/oauth/OAuthConfig.kt +++ b/model-client/src/commonMain/kotlin/org/modelix/model/oauth/OAuthConfig.kt @@ -44,7 +44,6 @@ data class OAuthConfig( clientId = clientId, clientSecret = clientSecret, authorizationUrl = authorizationUrl, - tokenUrl = tokenUrl, scopes = scopes, ) @@ -57,7 +56,6 @@ data class TokenCacheKey( val clientId: String?, val clientSecret: String?, val authorizationUrl: String?, - val tokenUrl: String?, val scopes: Set, ) diff --git a/model-client/src/jvmMain/kotlin/org/modelix/model/oauth/ModelixAuthClient.kt b/model-client/src/jvmMain/kotlin/org/modelix/model/oauth/ModelixAuthClient.kt index 2a80fe20f8..901183a02a 100644 --- a/model-client/src/jvmMain/kotlin/org/modelix/model/oauth/ModelixAuthClient.kt +++ b/model-client/src/jvmMain/kotlin/org/modelix/model/oauth/ModelixAuthClient.kt @@ -28,6 +28,7 @@ import io.ktor.http.buildUrl import io.ktor.http.isSecure import io.ktor.http.takeFrom import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.awaitCancellation @@ -43,103 +44,176 @@ import org.modelix.kotlin.utils.urlEncode import java.net.SocketException import java.net.SocketTimeoutException import java.util.Collections +import java.util.concurrent.ConcurrentHashMap @Suppress("UndocumentedPublicClass") // already documented in the expected declaration actual class ModelixAuthClient { companion object { private val LOG = mu.KotlinLogging.logger { } + + // Shared across all ModelixAuthClient instances so that different ModelClientV2 + // instances pointed at different branch-specific tokenUrls can reuse one refresh token + // obtained via a single PKCE login (T031/T032). + private val cachedTokens: MutableMap = + Collections.synchronizedMap(HashMap()) } - private class CachedTokens { + // One entry per (issuer, clientId): a single refresh token shared across all branches, plus a + // per-branch (keyed by token URL) cache of the last minted access-token credential. The refresh + // grant is always issued to the caller's tokenUrl, so each branch gets its own scoped access + // token from one shared refresh token (T031); those access tokens are cached and reused until + // near expiry to avoid re-minting on every request (FR-033 skew margin). + private class CachedTokens(private val ioDispatcher: CoroutineDispatcher = Dispatchers.IO) { + @Volatile private var storedRefreshToken: String? = null + private val accessTokensByTokenUrl = ConcurrentHashMap() + private val authMutex = Mutex() private val httpTransport: HttpTransport = NetHttpTransport() private val jsonFactory: JsonFactory = GsonFactory() - private var lastCredentials: Credential? = null - private val authMutex = Mutex() - fun getTokens(config: OAuthConfig): Credential? { - return lastCredentials?.takeIf { !it.isExpired() } - } + companion object { + private const val MAX_PORT_ATTEMPTS = 100 + private const val LOCAL_RECEIVER_PORT_BASE = 26815 - fun getAndMaybeRefreshTokens(config: OAuthConfig): Credential? { - return lastCredentials?.refreshIfExpired() + // Refresh a little before the access token actually expires to avoid edge-of-expiry + // races (FR-033). + private const val EXPIRY_SKEW_SECONDS = 60L } - suspend fun refreshTokensOrReauthorize(config: OAuthConfig): Credential? { - return lastCredentials?.alwaysRefresh() ?: authorize(config) - } + /** The shared refresh token for this `(issuer, clientId)`, or null if not logged in yet. */ + fun getStoredRefreshToken(): String? = storedRefreshToken - suspend fun authorize(config: OAuthConfig): Credential? { - lastCredentials = null - return withContext(Dispatchers.IO) { - authMutex.withLock { - lastCredentials?.let { return@withLock it } - val flow = AuthorizationCodeFlow.Builder( - BearerToken.authorizationHeaderAccessMethod(), - httpTransport, - jsonFactory, - GenericUrl(config.tokenUrl), - ClientParametersAuthentication(config.clientId, config.clientSecret), - config.clientId, - config.authorizationUrl, - ) - .setScopes(config.scopes) - .enablePKCE() - .build() + /** A still-valid cached access token for [config]'s branch token URL, or null. No I/O. */ + fun getCachedAccessToken(config: OAuthConfig): String? = + config.tokenUrl?.let { accessTokensByTokenUrl[it] }?.takeIf { it.isFresh() }?.accessToken - repeat(100) { n -> - val port = 26815 + n - try { - val receiver: LocalServerReceiver = LocalServerReceiver.Builder().setHost("localhost").setPort(port).build() - val tokens = cancelable({ receiver.stop() }) { - val scope = this - val browser = config.authRequestHandler?.let { - AuthorizationCodeInstalledApp.Browser { url -> - it.browse(object : IAuthRequest { - override fun getUrl(): String = url - override fun cancel() = scope.cancel() - override fun isActive() = scope.isActive - }) - } - } ?: AuthorizationCodeInstalledApp.DefaultBrowser() - AuthorizationCodeInstalledApp(flow, receiver, browser).authorize(null) - } - lastCredentials = tokens - return@withContext tokens - } catch (ex: SocketException) { - LOG.info("Port $port already in use. Trying next one.") - LOG.debug("Login failed with socket exception, which is expected, if we can not open the callback port.", ex) - } - } - throw IllegalStateException("Couldn't find an available port for the redirect URL") - } + suspend fun getAndMaybeRefreshTokens(config: OAuthConfig): Credential? { + val tokenUrl = config.tokenUrl ?: return null + // Fast path: reuse a still-valid access token for this branch without locking or I/O. + accessTokensByTokenUrl[tokenUrl]?.takeIf { it.isFresh() }?.let { return it } + storedRefreshToken ?: return null + return authMutex.withLock { + // Re-check under the lock: another waiter may have just refreshed this branch. + accessTokensByTokenUrl[tokenUrl]?.takeIf { it.isFresh() }?.let { return@withLock it } + storedRefreshToken?.let { refreshWithStoredToken(config) } } } - private fun Credential.isExpired(): Boolean { - return (expiresInSeconds ?: return false) < 60 + suspend fun refreshTokensOrReauthorize(config: OAuthConfig): Credential? { + return authMutex.withLock { + refreshWithStoredToken(config) ?: performPKCE(config) + } } - private fun Credential.refreshIfExpired(): Credential? { - return if (isExpired()) { - alwaysRefresh() - } else { - this.takeIf { it.accessToken != null } + suspend fun authorize(config: OAuthConfig): Credential? { + return authMutex.withLock { + storedRefreshToken?.let { refreshWithStoredToken(config) } ?: performPKCE(config) } } - private fun Credential.alwaysRefresh(): Credential? { - for (attempt in 1..3) { + // Performs a refresh_token grant to config.tokenUrl using the stored RT. + // On invalid_grant the stored RT is discarded (T033). + // Must be called with authMutex held. + private suspend fun refreshWithStoredToken(config: OAuthConfig): Credential? { + val rt = storedRefreshToken ?: return null + val tokenUrl = config.tokenUrl ?: return null + return withContext(ioDispatcher) { try { - val success = refreshToken() - if (success) return this + val credential = Credential.Builder(BearerToken.authorizationHeaderAccessMethod()) + .setTransport(httpTransport) + .setJsonFactory(jsonFactory) + .setTokenServerUrl(GenericUrl(tokenUrl)) + .setClientAuthentication(ClientParametersAuthentication(config.clientId, config.clientSecret)) + .build() + credential.refreshToken = rt + val success = credential.refreshToken() + if (success) { + // Chain to the latest rotated refresh token (FR-032) and cache the + // branch-scoped access token for reuse until near expiry. + storedRefreshToken = credential.refreshToken ?: rt + accessTokensByTokenUrl[tokenUrl] = credential + credential + } else { + // Anomalous (the refresh token field was set): keep the refresh token but + // surface the failure so the caller can fall back to interactive login. + LOG.warn("Token refresh returned no credential") + null + } + } catch (e: TokenResponseException) { + // The refresh token is revoked/expired/reuse-detected: discard it and the + // access tokens minted from it, forcing a fresh interactive login (FR-033/036). + if (e.details?.error == "invalid_grant") { + storedRefreshToken = null + accessTokensByTokenUrl.clear() + } + LOG.warn("Token refresh failed: ${e.details?.error}") + null } catch (e: SocketTimeoutException) { LOG.warn(e) { "Token refresh timed out" } - } catch (e: TokenResponseException) { - LOG.warn("Could not refresh the access token: ${e.details}") - break + null } } - return null + } + + // Runs the PKCE authorization_code flow. Stores the resulting refresh token. + // Must be called with authMutex held — this is deliberate: it serializes concurrent branch + // opens onto a single interactive login ("login once"). The consequence is that while the + // user is completing the browser login, other branches block on the mutex until this holder + // finishes or is cancelled; do not narrow the lock scope or two logins can run at once. + private suspend fun performPKCE(config: OAuthConfig): Credential? { + storedRefreshToken = null + accessTokensByTokenUrl.clear() + return withContext(ioDispatcher) { + val flow = AuthorizationCodeFlow.Builder( + BearerToken.authorizationHeaderAccessMethod(), + httpTransport, + jsonFactory, + GenericUrl(config.tokenUrl), + ClientParametersAuthentication(config.clientId, config.clientSecret), + config.clientId, + config.authorizationUrl, + ) + .setScopes(config.scopes) + .enablePKCE() + .build() + + repeat(MAX_PORT_ATTEMPTS) { n -> + val port = LOCAL_RECEIVER_PORT_BASE + n + try { + val receiver = LocalServerReceiver.Builder().setHost("localhost").setPort(port).build() + val tokens = cancelable({ receiver.stop() }) { + val scope = this + val browser = config.authRequestHandler?.let { + AuthorizationCodeInstalledApp.Browser { url -> + it.browse(object : IAuthRequest { + override fun getUrl(): String = url + override fun cancel() = scope.cancel() + override fun isActive() = scope.isActive + }) + } + } ?: AuthorizationCodeInstalledApp.DefaultBrowser() + AuthorizationCodeInstalledApp(flow, receiver, browser).authorize(null) + } + storedRefreshToken = tokens.refreshToken + config.tokenUrl?.let { accessTokensByTokenUrl[it] = tokens } + return@withContext tokens + } catch (ex: SocketException) { + LOG.info("Port $port already in use. Trying next one.") + LOG.debug("Login failed with socket exception, which is expected, if we can not open the callback port.", ex) + } + } + error("Couldn't find an available port for the redirect URL") + } + } + + /** + * Whether this credential has a usable access token that is not within [EXPIRY_SKEW_SECONDS] + * of expiry. A credential with no expiry information is assumed usable (its staleness is + * still caught reactively by the 401 -> refresh path). + */ + private fun Credential.isFresh(): Boolean { + accessToken ?: return false + val expiresIn = expiresInSeconds ?: return true + return expiresIn > EXPIRY_SKEW_SECONDS } /** @@ -150,7 +224,7 @@ actual class ModelixAuthClient { return coroutineScope { var cancellationEx: CancellationException? = null var blockingCallReturned = false - val cancellationHandlerJob = launch(Dispatchers.IO) { + val cancellationHandlerJob = launch(ioDispatcher) { try { awaitCancellation() } catch (ex: CancellationException) { @@ -160,7 +234,7 @@ actual class ModelixAuthClient { } } } - withContext(Dispatchers.IO) { + withContext(ioDispatcher) { try { return@withContext blockingCall() } catch (ex: Throwable) { @@ -174,25 +248,42 @@ actual class ModelixAuthClient { } } - private val cachedTokens: MutableMap = Collections.synchronizedMap(HashMap()) - private fun getCachedTokens(config: OAuthConfig) = runSynchronized(cachedTokens) { cachedTokens.getOrPut(config.getCacheKey()) { CachedTokens() } } - fun getTokens(config: OAuthConfig): Credential? { - return getCachedTokens(config).getTokens(config) - } + /** + * Returns the OAuth refresh token cached in this JVM for the authorization server identified by + * [config] (its `(issuer/authorizationUrl, clientId, clientSecret, scopes)`), or `null` if no + * interactive login has happened yet in this process. + * + * Java-friendly accessor: it does not suspend and performs no network I/O — it only reads the + * in-memory cache that every [ModelixAuthClient] in the JVM shares, so the result is independent + * of which instance it is called on. Credentials are never persisted, so a different process + * always returns `null` until it logs in itself. + */ + fun getRefreshToken(config: OAuthConfig): String? = getCachedTokens(config).getStoredRefreshToken() + + /** + * Returns a still-valid access token cached in this JVM for [config]'s branch-specific + * [OAuthConfig.tokenUrl], or `null` when none is cached or it is within the expiry skew margin. + * + * Java-friendly accessor: it does not suspend and performs no network I/O. When no token is + * cached, drive a request through the authenticated client (or call [getRefreshToken]) — the + * normal request flow runs the refresh/login and populates this cache. + */ + fun getAccessToken(config: OAuthConfig): String? = getCachedTokens(config).getCachedAccessToken(config) - fun getAndMaybeRefreshTokens(config: OAuthConfig): Credential? { + private suspend fun getAndMaybeRefreshTokens(config: OAuthConfig): Credential? { return getCachedTokens(config).getAndMaybeRefreshTokens(config) } - suspend fun refreshTokensOrReauthorize(config: OAuthConfig): Credential? { + private suspend fun refreshTokensOrReauthorize(config: OAuthConfig): Credential? { return getCachedTokens(config).refreshTokensOrReauthorize(config) } - suspend fun authorize(config: OAuthConfig): Credential? { + // Visible for ModelixAuthClientTest (cancellation behavior); not part of the public API. + internal suspend fun authorize(config: OAuthConfig): Credential? { return getCachedTokens(config).authorize(config) } @@ -233,9 +324,9 @@ actual class ModelixAuthClient { install(Auth) { bearer { loadTokens { - // A potentially expired token is already refreshed here to avoid a 401 response. - // When a 401 response is received, we always (re-)execute the PKCE flow. - getAndMaybeRefreshTokens(currentAuthConfig)?.let { BearerTokens(it.accessToken, it.refreshToken) } + // Use fillParameters() so the branch-specific tokenUrl is used for the + // refresh grant even before any 401 is received (T031). + getAndMaybeRefreshTokens(currentAuthConfig.fillParameters())?.let { BearerTokens(it.accessToken, it.refreshToken) } } refreshTokens { try { @@ -249,8 +340,6 @@ actual class ModelixAuthClient { tokenUrl = initialAuthConfig.tokenUrl ?: useSameProtocol(wwwAuthenticate.parameter("token_uri") ?: return@let null).fillParameters(), ) - val realm = wwwAuthenticate.parameter("realm") - val description = wwwAuthenticate.parameter("error_description") } if (currentAuthConfig.tokenUrl == null) { LOG.warn { "No token URL configured" } @@ -267,8 +356,6 @@ actual class ModelixAuthClient { val tokens = refreshTokensOrReauthorize(currentAuthConfig.fillParameters()) checkNotNull(tokens) { "No tokens received" } - LOG.info("Access Token: " + tokens.accessToken) - BearerTokens(tokens.accessToken, tokens.refreshToken) } catch (ex: Throwable) { LOG.error(ex) { "Token refresh failed" } @@ -280,6 +367,11 @@ actual class ModelixAuthClient { } } + /** + * Parses the `WWW-Authenticate` challenge from this message, falling back to the non-standard + * `x-amzn-remapped-www-authenticate` header that the Amazon API Gateway substitutes to suppress + * browser login popups. Returns `null` when no parameterized challenge is present. + */ fun HttpMessage.parseWWWAuthenticate(): HttpAuthHeader.Parameterized? { // The Amazon API Gateway replaces the WWW-Authenticate header with x-amzn-remapped-www-authenticate // to prevent any login popup in the browser. REST clients are expected to read this non-standard header. diff --git a/model-client/src/jvmTest/kotlin/org/modelix/model/client2/CredentialCachingTest.kt b/model-client/src/jvmTest/kotlin/org/modelix/model/client2/CredentialCachingTest.kt new file mode 100644 index 0000000000..064e172cee --- /dev/null +++ b/model-client/src/jvmTest/kotlin/org/modelix/model/client2/CredentialCachingTest.kt @@ -0,0 +1,651 @@ +package org.modelix.model.client2 + +import io.ktor.client.HttpClient +import io.ktor.client.engine.cio.CIO +import io.ktor.client.request.get +import io.ktor.http.HttpStatusCode +import io.ktor.http.Url +import io.ktor.http.auth.HttpAuthHeader +import io.ktor.http.buildUrl +import io.ktor.http.takeFrom +import io.ktor.serialization.kotlinx.json.json +import io.ktor.server.application.install +import io.ktor.server.auth.UnauthorizedResponse +import io.ktor.server.auth.parseAuthorizationHeader +import io.ktor.server.engine.embeddedServer +import io.ktor.server.netty.Netty +import io.ktor.server.plugins.contentnegotiation.ContentNegotiation +import io.ktor.server.request.receiveParameters +import io.ktor.server.response.respond +import io.ktor.server.response.respondText +import io.ktor.server.routing.get +import io.ktor.server.routing.post +import io.ktor.server.routing.routing +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.runBlocking +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import org.modelix.kotlin.utils.filterNotNullValues +import org.modelix.model.lazy.RepositoryId +import org.modelix.model.oauth.IAuthConfig +import org.modelix.model.oauth.IAuthRequest +import org.modelix.model.oauth.IAuthRequestHandler +import org.modelix.model.oauth.OAuthConfig +import java.net.BindException +import java.util.Collections +import java.util.concurrent.atomic.AtomicInteger +import kotlin.random.Random +import kotlin.test.Test +import kotlin.test.assertEquals + +private inline fun retryOnException(exceptionType: Class, body: () -> R): R { + var attempt = 0 + while (true) { + try { + return body() + } catch (ex: Throwable) { + if (++attempt >= 3 || !exceptionType.isAssignableFrom(ex::class.java)) throw ex + } + } +} + +@Serializable +private data class CachedTestTokenResponse( + @SerialName("access_token") val accessToken: String, + @SerialName("refresh_token") val refreshToken: String, +) + +class CredentialCachingTest { + + // The credential cache is process-wide (a ModelixAuthClient companion-object map) and these + // tests run in parallel with each other and with OAuthTest. Isolation is therefore achieved by + // giving every test a unique cache key — each embedded server binds a random port, so the + // discovered/observed authorizationUrl (and thus the cache key) differs per test. We must NOT + // clear the shared cache globally here: that would race with and wipe other tests' credentials. + + // T030: cache key must not include tokenUrl — two configs with the same issuer+clientId + // but different branch-specific token URLs must share a single credential cache entry. + @Test + fun `credential cache key excludes token URL so different branch URLs with same issuer and client ID share one credential`() { + val config1 = OAuthConfig( + clientId = "external-mps", + authorizationUrl = "https://idp.example.com/auth", + tokenUrl = "https://srs.example.com/v1/oauth2/modelix-token?repositoryId=repo&branchName=versions%2F1.0.0", + ) + val config2 = OAuthConfig( + clientId = "external-mps", + authorizationUrl = "https://idp.example.com/auth", + tokenUrl = "https://srs.example.com/v1/oauth2/modelix-token?repositoryId=repo&branchName=versions%2F2.0.0", + ) + assertEquals(config1.getCacheKey(), config2.getCacheKey()) + } + + // T032: two ModelClientV2 instances with the same authorizationUrl+clientId but different + // branch-specific tokenUrls must require exactly one PKCE login; the second client reuses + // the stored refresh token to obtain its own branch-scoped access token. + @Test + fun `two clients on different branch token URLs share one PKCE login via refresh token`() = runBlocking { + val expectedRepoId = "my-repo" + val pkceLoginCount = AtomicInteger(0) + val tokenSuffix = AtomicInteger(1) + val validAccessTokens: MutableSet = Collections.synchronizedSet(mutableSetOf()) + val validRefreshTokens: MutableSet = Collections.synchronizedSet(mutableSetOf()) + val expectedBranches = listOf( + RepositoryId(expectedRepoId).getBranchReference("a"), + RepositoryId(expectedRepoId).getBranchReference("b"), + ) + + fun issueTokens(): CachedTestTokenResponse { + val n = tokenSuffix.getAndIncrement() + val at = "at-$n" + val rt = "rt-$n" + validAccessTokens += at + validRefreshTokens += rt + return CachedTestTokenResponse(at, rt) + } + + suspend fun runWithServer(body: suspend (port: Int) -> Unit) { + val server = retryOnException(BindException::class.java) { + embeddedServer(Netty, port = Random.nextInt(20000, 60000)) { + install(ContentNegotiation) { json() } + routing { + post("/token/{branch}") { + val params = call.receiveParameters() + when { + params["grant_type"] == "authorization_code" && params["code"] == "abc" -> { + pkceLoginCount.incrementAndGet() + call.respond(issueTokens()) + } + params["grant_type"] == "refresh_token" && validRefreshTokens.contains(params["refresh_token"]) -> { + call.respond(issueTokens()) + } + else -> call.respond(HttpStatusCode.BadRequest) + } + } + get("/v2/repositories/{repository-id}/branches") { + val authHeader = call.request.parseAuthorizationHeader() + if (authHeader is HttpAuthHeader.Single && + authHeader.authScheme == "Bearer" && + validAccessTokens.contains(authHeader.blob) + ) { + call.respondText(expectedBranches.joinToString("\n") { it.branchName }) + } else { + val port = engine.resolvedConnectors().single().port + call.respond( + UnauthorizedResponse( + HttpAuthHeader.Parameterized( + "Bearer", + mapOf( + HttpAuthHeader.Parameters.Realm to "modelix", + "error" to "invalid_token", + "authorization_uri" to "http://localhost:$port/auth", + "token_uri" to "http://localhost:$port/token/default", + ).filterNotNullValues(), + ), + ), + ) + } + } + } + }.startSuspend() + } + try { + body(server.engine.resolvedConnectors().single().port) + } finally { + server.stop() + } + } + + runWithServer { port -> + val authHandler = object : IAuthRequestHandler { + override fun browse(request: IAuthRequest) { + val redirectUri = Url(request.getUrl()).parameters["redirect_uri"]!! + val callbackWithCode = buildUrl { + takeFrom(redirectUri) + parameters.append("code", "abc") + } + runBlocking { HttpClient(CIO).get(callbackWithCode) } + } + } + + fun makeClient(branch: String) = ModelClientV2.builder() + .url("http://localhost:$port") + .authConfig( + IAuthConfig.oauth { + clientId("external-mps") + authorizationUrl("http://localhost:$port/auth") + tokenUrl("http://localhost:$port/token/$branch") + authRequestHandler(authHandler) + }, + ) + .build() + + makeClient("branch1").use { it.listBranches(RepositoryId(expectedRepoId)) } + makeClient("branch2").use { it.listBranches(RepositoryId(expectedRepoId)) } + + assertEquals(1, pkceLoginCount.get(), "Expected one PKCE login but got ${pkceLoginCount.get()}") + } + } + + // T033: when the token endpoint returns invalid_grant, the stored refresh token must be + // discarded and the client must fall back to a new PKCE login rather than looping. + @Test + fun `invalid_grant response discards stored refresh token and triggers new PKCE login`() = runBlocking { + val expectedRepoId = "my-repo" + val pkceLoginCount = AtomicInteger(0) + val tokenSuffix = AtomicInteger(1) + val validAccessTokens: MutableSet = Collections.synchronizedSet(mutableSetOf()) + val validRefreshTokens: MutableSet = Collections.synchronizedSet(mutableSetOf()) + var rejectNextRefresh = false + val expectedBranches = listOf( + RepositoryId(expectedRepoId).getBranchReference("a"), + ) + + fun issueTokens(): CachedTestTokenResponse { + val n = tokenSuffix.getAndIncrement() + val at = "at-$n" + val rt = "rt-$n" + validAccessTokens += at + validRefreshTokens += rt + return CachedTestTokenResponse(at, rt) + } + + suspend fun runWithServer(body: suspend (port: Int) -> Unit) { + val server = retryOnException(BindException::class.java) { + embeddedServer(Netty, port = Random.nextInt(20000, 60000)) { + install(ContentNegotiation) { json() } + routing { + post("/token") { + val params = call.receiveParameters() + when { + params["grant_type"] == "authorization_code" && params["code"] == "abc" -> { + pkceLoginCount.incrementAndGet() + call.respond(issueTokens()) + } + params["grant_type"] == "refresh_token" && !rejectNextRefresh && + validRefreshTokens.contains(params["refresh_token"]) -> { + call.respond(issueTokens()) + } + params["grant_type"] == "refresh_token" -> { + call.respond( + HttpStatusCode.BadRequest, + mapOf("error" to "invalid_grant", "error_description" to "Token expired"), + ) + } + else -> call.respond(HttpStatusCode.BadRequest) + } + } + get("/v2/repositories/{repository-id}/branches") { + val authHeader = call.request.parseAuthorizationHeader() + if (authHeader is HttpAuthHeader.Single && + authHeader.authScheme == "Bearer" && + validAccessTokens.contains(authHeader.blob) + ) { + call.respondText(expectedBranches.joinToString("\n") { it.branchName }) + } else { + val port = engine.resolvedConnectors().single().port + call.respond( + UnauthorizedResponse( + HttpAuthHeader.Parameterized( + "Bearer", + mapOf( + HttpAuthHeader.Parameters.Realm to "modelix", + "error" to "invalid_token", + "authorization_uri" to "http://localhost:$port/auth", + "token_uri" to "http://localhost:$port/token", + ).filterNotNullValues(), + ), + ), + ) + } + } + } + }.startSuspend() + } + try { + body(server.engine.resolvedConnectors().single().port) + } finally { + server.stop() + } + } + + runWithServer { port -> + val authHandler = object : IAuthRequestHandler { + override fun browse(request: IAuthRequest) { + val redirectUri = Url(request.getUrl()).parameters["redirect_uri"]!! + val callbackWithCode = buildUrl { + takeFrom(redirectUri) + parameters.append("code", "abc") + } + runBlocking { HttpClient(CIO).get(callbackWithCode) } + } + } + + val client = ModelClientV2.builder() + .url("http://localhost:$port") + .authConfig( + IAuthConfig.oauth { + clientId("external-mps") + authorizationUrl("http://localhost:$port/auth") + tokenUrl("http://localhost:$port/token") + authRequestHandler(authHandler) + }, + ) + .build() + + client.use { + // First call: PKCE login, stores refresh token + it.listBranches(RepositoryId(expectedRepoId)) + assertEquals(1, pkceLoginCount.get()) + + // Invalidate the access token and mark the next refresh as invalid_grant + validAccessTokens.clear() + rejectNextRefresh = true + + // Second call: access token rejected → refresh grant returns invalid_grant + // → stored RT discarded → new PKCE login triggered + it.listBranches(RepositoryId(expectedRepoId)) + assertEquals(2, pkceLoginCount.get(), "Expected a second PKCE login after invalid_grant") + } + } + } + + // FR-032 single-flight: two clients for the same (issuer, clientId) opening *concurrently* must + // still trigger exactly one interactive PKCE login; the second waits on the shared mutex and + // reuses the refresh token. Without the mutex both would log in (and likely race on the local + // receiver port). + @Test + fun `concurrent branch opens trigger only one PKCE login`() = runBlocking { + val expectedRepoId = "my-repo" + val pkceLoginCount = AtomicInteger(0) + val tokenSuffix = AtomicInteger(1) + val validAccessTokens: MutableSet = Collections.synchronizedSet(mutableSetOf()) + val validRefreshTokens: MutableSet = Collections.synchronizedSet(mutableSetOf()) + val expectedBranches = listOf(RepositoryId(expectedRepoId).getBranchReference("a")) + + fun issueTokens(): CachedTestTokenResponse { + val n = tokenSuffix.getAndIncrement() + val at = "at-$n" + val rt = "rt-$n" + validAccessTokens += at + validRefreshTokens += rt + return CachedTestTokenResponse(at, rt) + } + + suspend fun runWithServer(body: suspend (port: Int) -> Unit) { + val server = retryOnException(BindException::class.java) { + embeddedServer(Netty, port = Random.nextInt(20000, 60000)) { + install(ContentNegotiation) { json() } + routing { + post("/token/{branch}") { + val params = call.receiveParameters() + when { + params["grant_type"] == "authorization_code" && params["code"] == "abc" -> { + pkceLoginCount.incrementAndGet() + call.respond(issueTokens()) + } + params["grant_type"] == "refresh_token" && validRefreshTokens.contains(params["refresh_token"]) -> { + call.respond(issueTokens()) + } + else -> call.respond(HttpStatusCode.BadRequest) + } + } + get("/v2/repositories/{repository-id}/branches") { + val authHeader = call.request.parseAuthorizationHeader() + if (authHeader is HttpAuthHeader.Single && + authHeader.authScheme == "Bearer" && + validAccessTokens.contains(authHeader.blob) + ) { + call.respondText(expectedBranches.joinToString("\n") { it.branchName }) + } else { + val port = engine.resolvedConnectors().single().port + call.respond( + UnauthorizedResponse( + HttpAuthHeader.Parameterized( + "Bearer", + mapOf( + HttpAuthHeader.Parameters.Realm to "modelix", + "error" to "invalid_token", + "authorization_uri" to "http://localhost:$port/auth", + "token_uri" to "http://localhost:$port/token/default", + ).filterNotNullValues(), + ), + ), + ) + } + } + } + }.startSuspend() + } + try { + body(server.engine.resolvedConnectors().single().port) + } finally { + server.stop() + } + } + + runWithServer { port -> + val authHandler = object : IAuthRequestHandler { + override fun browse(request: IAuthRequest) { + val redirectUri = Url(request.getUrl()).parameters["redirect_uri"]!! + val callbackWithCode = buildUrl { + takeFrom(redirectUri) + parameters.append("code", "abc") + } + runBlocking { HttpClient(CIO).get(callbackWithCode) } + } + } + + fun makeClient(branch: String) = ModelClientV2.builder() + .url("http://localhost:$port") + .authConfig( + IAuthConfig.oauth { + clientId("external-mps") + authorizationUrl("http://localhost:$port/auth") + tokenUrl("http://localhost:$port/token/$branch") + authRequestHandler(authHandler) + }, + ) + .build() + + val clientA = makeClient("branch1") + val clientB = makeClient("branch2") + try { + coroutineScope { + // Launch both concurrently. listBranches suspends on network + the shared + // authMutex, so the two interleave; performPKCE already offloads its blocking + // work to an IO dispatcher, so an absent mutex would still bind two receiver + // ports and produce two logins. + awaitAll( + async { clientA.listBranches(RepositoryId(expectedRepoId)) }, + async { clientB.listBranches(RepositoryId(expectedRepoId)) }, + ) + } + assertEquals(1, pkceLoginCount.get(), "Concurrent opens must share one PKCE login, got ${pkceLoginCount.get()}") + } finally { + clientA.close() + clientB.close() + } + } + } + + // FR-032 rotation chaining: with single-use refresh tokens (each refresh invalidates the + // previous RT), repeated refreshes must keep succeeding — proving the client always chains to + // the latest rotated RT. If it ever re-sent a consumed RT, the server would reject it and force + // a second PKCE login. + @Test + fun `rotating single-use refresh tokens keep working without re-login`() = runBlocking { + val expectedRepoId = "my-repo" + val pkceLoginCount = AtomicInteger(0) + val tokenSuffix = AtomicInteger(1) + val validAccessTokens: MutableSet = Collections.synchronizedSet(mutableSetOf()) + val validRefreshTokens: MutableSet = Collections.synchronizedSet(mutableSetOf()) + val expectedBranches = listOf(RepositoryId(expectedRepoId).getBranchReference("a")) + + fun issueTokens(): CachedTestTokenResponse { + val n = tokenSuffix.getAndIncrement() + val at = "at-$n" + val rt = "rt-$n" + validAccessTokens += at + validRefreshTokens += rt + return CachedTestTokenResponse(at, rt) + } + + suspend fun runWithServer(body: suspend (port: Int) -> Unit) { + val server = retryOnException(BindException::class.java) { + embeddedServer(Netty, port = Random.nextInt(20000, 60000)) { + install(ContentNegotiation) { json() } + routing { + post("/token") { + val params = call.receiveParameters() + when { + params["grant_type"] == "authorization_code" && params["code"] == "abc" -> { + pkceLoginCount.incrementAndGet() + call.respond(issueTokens()) + } + params["grant_type"] == "refresh_token" && validRefreshTokens.contains(params["refresh_token"]) -> { + // single-use: consume the presented refresh token on rotation + validRefreshTokens.remove(params["refresh_token"]) + call.respond(issueTokens()) + } + else -> + call.respond( + HttpStatusCode.BadRequest, + mapOf("error" to "invalid_grant", "error_description" to "reused or unknown refresh token"), + ) + } + } + get("/v2/repositories/{repository-id}/branches") { + val authHeader = call.request.parseAuthorizationHeader() + if (authHeader is HttpAuthHeader.Single && + authHeader.authScheme == "Bearer" && + validAccessTokens.contains(authHeader.blob) + ) { + call.respondText(expectedBranches.joinToString("\n") { it.branchName }) + } else { + val port = engine.resolvedConnectors().single().port + call.respond( + UnauthorizedResponse( + HttpAuthHeader.Parameterized( + "Bearer", + mapOf( + HttpAuthHeader.Parameters.Realm to "modelix", + "error" to "invalid_token", + "authorization_uri" to "http://localhost:$port/auth", + "token_uri" to "http://localhost:$port/token", + ).filterNotNullValues(), + ), + ), + ) + } + } + } + }.startSuspend() + } + try { + body(server.engine.resolvedConnectors().single().port) + } finally { + server.stop() + } + } + + runWithServer { port -> + val authHandler = object : IAuthRequestHandler { + override fun browse(request: IAuthRequest) { + val redirectUri = Url(request.getUrl()).parameters["redirect_uri"]!! + val callbackWithCode = buildUrl { + takeFrom(redirectUri) + parameters.append("code", "abc") + } + runBlocking { HttpClient(CIO).get(callbackWithCode) } + } + } + + ModelClientV2.builder() + .url("http://localhost:$port") + .authConfig( + IAuthConfig.oauth { + clientId("external-mps") + authorizationUrl("http://localhost:$port/auth") + tokenUrl("http://localhost:$port/token") + authRequestHandler(authHandler) + }, + ) + .build() + .use { client -> + client.listBranches(RepositoryId(expectedRepoId)) + // Force several refreshes; each rotates (and invalidates) the refresh token. + repeat(3) { + validAccessTokens.clear() + assertEquals(expectedBranches, client.listBranches(RepositoryId(expectedRepoId))) + } + assertEquals(1, pkceLoginCount.get(), "Rotating refresh tokens must not force a re-login") + } + } + } + + // FR-031/033 isolation: a different (issuer, clientId) must NOT reuse another's cached refresh + // token. Two clients with distinct client IDs against the same server must each log in once. + @Test + fun `different client IDs do not share the cached credential`() = runBlocking { + val expectedRepoId = "my-repo" + val pkceLoginCount = AtomicInteger(0) + val tokenSuffix = AtomicInteger(1) + val validAccessTokens: MutableSet = Collections.synchronizedSet(mutableSetOf()) + val validRefreshTokens: MutableSet = Collections.synchronizedSet(mutableSetOf()) + val expectedBranches = listOf(RepositoryId(expectedRepoId).getBranchReference("a")) + + fun issueTokens(): CachedTestTokenResponse { + val n = tokenSuffix.getAndIncrement() + val at = "at-$n" + val rt = "rt-$n" + validAccessTokens += at + validRefreshTokens += rt + return CachedTestTokenResponse(at, rt) + } + + suspend fun runWithServer(body: suspend (port: Int) -> Unit) { + val server = retryOnException(BindException::class.java) { + embeddedServer(Netty, port = Random.nextInt(20000, 60000)) { + install(ContentNegotiation) { json() } + routing { + post("/token") { + val params = call.receiveParameters() + when { + params["grant_type"] == "authorization_code" && params["code"] == "abc" -> { + pkceLoginCount.incrementAndGet() + call.respond(issueTokens()) + } + params["grant_type"] == "refresh_token" && validRefreshTokens.contains(params["refresh_token"]) -> { + call.respond(issueTokens()) + } + else -> call.respond(HttpStatusCode.BadRequest) + } + } + get("/v2/repositories/{repository-id}/branches") { + val authHeader = call.request.parseAuthorizationHeader() + if (authHeader is HttpAuthHeader.Single && + authHeader.authScheme == "Bearer" && + validAccessTokens.contains(authHeader.blob) + ) { + call.respondText(expectedBranches.joinToString("\n") { it.branchName }) + } else { + val port = engine.resolvedConnectors().single().port + call.respond( + UnauthorizedResponse( + HttpAuthHeader.Parameterized( + "Bearer", + mapOf( + HttpAuthHeader.Parameters.Realm to "modelix", + "error" to "invalid_token", + "authorization_uri" to "http://localhost:$port/auth", + "token_uri" to "http://localhost:$port/token", + ).filterNotNullValues(), + ), + ), + ) + } + } + } + }.startSuspend() + } + try { + body(server.engine.resolvedConnectors().single().port) + } finally { + server.stop() + } + } + + runWithServer { port -> + val authHandler = object : IAuthRequestHandler { + override fun browse(request: IAuthRequest) { + val redirectUri = Url(request.getUrl()).parameters["redirect_uri"]!! + val callbackWithCode = buildUrl { + takeFrom(redirectUri) + parameters.append("code", "abc") + } + runBlocking { HttpClient(CIO).get(callbackWithCode) } + } + } + + fun makeClient(clientId: String) = ModelClientV2.builder() + .url("http://localhost:$port") + .authConfig( + IAuthConfig.oauth { + clientId(clientId) + authorizationUrl("http://localhost:$port/auth") + tokenUrl("http://localhost:$port/token") + authRequestHandler(authHandler) + }, + ) + .build() + + makeClient("client-a").use { it.listBranches(RepositoryId(expectedRepoId)) } + makeClient("client-b").use { it.listBranches(RepositoryId(expectedRepoId)) } + + assertEquals(2, pkceLoginCount.get(), "Distinct client IDs must each log in; cache must not be shared") + } + } +}