diff --git a/src/main/kotlin/VerifierCli.kt b/src/main/kotlin/VerifierCli.kt index 7f9a991..6d151ba 100644 --- a/src/main/kotlin/VerifierCli.kt +++ b/src/main/kotlin/VerifierCli.kt @@ -19,15 +19,16 @@ package com.android.keyattestation.verifier import java.io.File import java.io.PrintStream import java.security.cert.X509Certificate +import java.time.Instant // Any chain shorter than this is not possibly valid Key Attestation chain. private const val MIN_CERTS_IN_VALID_CHAIN = 3 -class VerifierCli(private val output: PrintStream) { +class VerifierCli(private val output: PrintStream, private val instantSource: InstantSource) { companion object { @JvmStatic fun main(args: Array) { - VerifierCli(System.out).run(args) + VerifierCli(System.out, { Instant.now() }).run(args) } } @@ -62,7 +63,7 @@ class VerifierCli(private val output: PrintStream) { Verifier( trustAnchorsSource = GoogleTrustAnchors, revokedSerialsSource = { emptySet() }, - instantSource = { java.time.Instant.now() }, + instantSource, ) val result = verifier.verify(certs) diff --git a/src/test/kotlin/VerifierCliTest.kt b/src/test/kotlin/VerifierCliTest.kt index e161166..06002e1 100644 --- a/src/test/kotlin/VerifierCliTest.kt +++ b/src/test/kotlin/VerifierCliTest.kt @@ -22,6 +22,8 @@ import java.io.ByteArrayOutputStream import java.io.PrintStream import java.nio.charset.StandardCharsets import java.nio.file.Files +import java.nio.file.Path +import java.time.Instant import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 @@ -29,17 +31,39 @@ import org.junit.runners.JUnit4 @RunWith(JUnit4::class) class VerifierCliTest { companion object { + // For tests have no time dependency + val INCONSEQUENTIAL_TIME = Instant.EPOCH + private fun resolveTestData(path: String) = kotlin.io.path.Path("testdata/$path") + + private fun getValidTimeFromCert(path: Path): Instant { + val pem = Files.readString(path) + val certs = + """-----BEGIN CERTIFICATE-----([\s\S]*?)-----END CERTIFICATE-----""" + .toRegex() + .findAll(pem) + .map { it.value.asX509Certificate() } + .toList() + + check(certs.size >= 2) { "Certificate chain must have at least 2 certificates" } + val intermediate = certs[1] + val notBefore = intermediate.notBefore.toInstant() + val notAfter = intermediate.notAfter.toInstant() + + // Return a time in the middle of the validity period. + return notBefore.plusMillis((notAfter.toEpochMilli() - notBefore.toEpochMilli()) / 2) + } } @Test fun run_validChain_outputsSuccess() { val path = resolveTestData("tegu/sdk36/TEE_EC_2026_ROOT.pem") + val validTime = getValidTimeFromCert(path) val outputStream = ByteArrayOutputStream() val printStream = PrintStream(outputStream, true, StandardCharsets.UTF_8.name()) - VerifierCli(printStream).run(arrayOf(path.toString())) + VerifierCli(printStream, { validTime }).run(arrayOf(path.toString())) val output = outputStream.toString(StandardCharsets.UTF_8.name()) assertThat(output).contains("Verification Successful!") @@ -50,10 +74,11 @@ class VerifierCliTest { @Test fun run_invalidChain_outputsFailure() { val path = resolveTestData("invalid/tags_not_in_ascending_order.pem") + val validTime = getValidTimeFromCert(path) val outputStream = ByteArrayOutputStream() val printStream = PrintStream(outputStream, true, StandardCharsets.UTF_8.name()) - VerifierCli(printStream).run(arrayOf(path.toString())) + VerifierCli(printStream, { validTime }).run(arrayOf(path.toString())) val output = outputStream.toString(StandardCharsets.UTF_8.name()) assertThat(output).contains("Verification Failed") @@ -66,7 +91,7 @@ class VerifierCliTest { val outputStream = ByteArrayOutputStream() val printStream = PrintStream(outputStream, true, StandardCharsets.UTF_8.name()) - VerifierCli(printStream).run(arrayOf(path.toString())) + VerifierCli(printStream, { INCONSEQUENTIAL_TIME }).run(arrayOf(path.toString())) val output = outputStream.toString(StandardCharsets.UTF_8.name()) assertThat(output).contains("No certificates found in the file.") @@ -101,7 +126,7 @@ F3amnfzZkTIFYCL1rPPb6Vp9pI1xhRE5Uk21Eso= val outputStream = ByteArrayOutputStream() val printStream = PrintStream(outputStream, true, StandardCharsets.UTF_8.name()) - VerifierCli(printStream).run(arrayOf(path.toString())) + VerifierCli(printStream, { INCONSEQUENTIAL_TIME }).run(arrayOf(path.toString())) val output = outputStream.toString(StandardCharsets.UTF_8.name()) assertThat(output) @@ -154,7 +179,7 @@ ykrM9WmORlkUk9NsoQ== val outputStream = ByteArrayOutputStream() val printStream = PrintStream(outputStream, true, StandardCharsets.UTF_8.name()) - VerifierCli(printStream).run(arrayOf(path.toString())) + VerifierCli(printStream, { INCONSEQUENTIAL_TIME }).run(arrayOf(path.toString())) val output = outputStream.toString(StandardCharsets.UTF_8.name()) assertThat(output).contains("Less than 3 certificates found in the file.")