diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml deleted file mode 100644 index 3d5026c497..0000000000 --- a/.github/workflows/ci.yml +++ /dev/null @@ -1,190 +0,0 @@ -name: CI - -on: - push: - branches: [ "main" ] - pull_request: - branches: [ "main" ] - workflow_dispatch: - -env: - JAVA: 11 - JAVA_DISTRIBUTION: zulu - -jobs: - ci-core: - runs-on: ubuntu-24.04 - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - - name: Setup Java JDK - uses: actions/setup-java@v4 - with: - java-version: ${{ env.JAVA }} - distribution: ${{ env.JAVA_DISTRIBUTION }} - - - name: Setup Gradle - uses: gradle/actions/setup-gradle@v4 - - - name: Run core tests - run: | - ./gradlew :usvm-core:check :usvm-dataflow:check :usvm-util:check :usvm-sample-language:check - - - name: Upload Gradle reports - if: (!cancelled()) - uses: actions/upload-artifact@v4 - with: - name: gradle-reports-core - path: '**/build/reports/' - retention-days: 1 - - ci-jvm: - runs-on: ubuntu-24.04 - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - - name: Setup Java JDK - uses: actions/setup-java@v4 - with: - java-version: ${{ env.JAVA }} - distribution: ${{ env.JAVA_DISTRIBUTION }} - - - name: Setup Gradle - uses: gradle/actions/setup-gradle@v4 - - - name: Run JVM tests - run: ./gradlew :usvm-jvm:check :usvm-jvm:usvm-jvm-api:check :usvm-jvm:usvm-jvm-test-api:check :usvm-jvm:usvm-jvm-util:check :usvm-jvm-dataflow:check :usvm-jvm-instrumentation:check - - - name: Upload Gradle reports - if: (!cancelled()) - uses: actions/upload-artifact@v4 - with: - name: gradle-reports-jvm - path: '**/build/reports/' - retention-days: 1 - - ci-python: - runs-on: ubuntu-24.04 - steps: - - name: Checkout repository - uses: actions/checkout@v4 - with: - # 'usvm-python/cpythonadapter/cpython' is a submodule - submodules: true - - - name: Setup Java JDK - uses: actions/setup-java@v4 - with: - java-version: ${{ env.JAVA }} - distribution: ${{ env.JAVA_DISTRIBUTION }} - - - name: Setup Gradle - uses: gradle/actions/setup-gradle@v4 - - - name: Install CPython optional dependencies - run: | - sudo apt-get update - sudo apt-get install -y -q \ - libssl-dev \ - libffi-dev - - - name: Run Python tests - run: ./gradlew -PcpythonActivated=true :usvm-python:check - - - name: Upload Gradle reports - if: (!cancelled()) - uses: actions/upload-artifact@v4 - with: - name: gradle-reports-python - path: '**/build/reports/' - retention-days: 1 - - ci-ts: - runs-on: ubuntu-24.04 - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - - name: Setup Java JDK - uses: actions/setup-java@v4 - with: - java-version: ${{ env.JAVA }} - distribution: ${{ env.JAVA_DISTRIBUTION }} - - - name: Setup Gradle - uses: gradle/actions/setup-gradle@v4 - - - name: Set up Node - uses: actions/setup-node@v4 - with: - node-version: 22 - - - name: Configure /etc/hosts - run: cat .github/extra/hosts | sudo tee -a /etc/hosts - - - name: Set up ArkAnalyzer - run: | - REPO_URL="https://gitcode.com/Lipen/arkanalyzer.git" - DEST_DIR="arkanalyzer" - MAX_RETRIES=10 - RETRY_DELAY=3 # Delay between retries in seconds - BRANCH="neo/2025-06-24" - - for ((i=1; i<=MAX_RETRIES; i++)); do - git clone --depth=1 --branch $BRANCH $REPO_URL $DEST_DIR && break - echo "Clone failed, retrying in $RETRY_DELAY seconds..." - sleep "$RETRY_DELAY" - done - - if [[ $i -gt $MAX_RETRIES ]]; then - echo "Failed to clone the repository after $MAX_RETRIES attempts." - exit 1 - else - echo "Repository cloned successfully." - fi - - echo "ARKANALYZER_DIR=$(realpath $DEST_DIR)" >> $GITHUB_ENV - cd $DEST_DIR - - npm install - npm run build - - - name: Run TS tests - run: ./gradlew :usvm-ts:check :usvm-ts-dataflow:check - - - name: Upload Gradle reports - if: (!cancelled()) - uses: actions/upload-artifact@v4 - with: - name: gradle-reports-ts - path: '**/build/reports/' - retention-days: 1 - - lint: - runs-on: ubuntu-latest - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - - name: Setup Java JDK - uses: actions/setup-java@v4 - with: - distribution: temurin - java-version: 21 - - - name: Setup Gradle - uses: gradle/actions/setup-gradle@v4 - - - name: Validate Project List - run: ./gradlew validateProjectList - - - name: Run Detekt - run: ./gradlew detektMain detektTest - - - name: Upload Detekt SARIF report - uses: github/codeql-action/upload-sarif@v3 - if: success() || failure() - with: - sarif_file: build/reports/detekt/detekt.sarif diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 0000000000..ce7d9fa20d --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,26 @@ +name: Deploy Dokka Docs + +on: + push: + branches: [ mocks2 ] + +jobs: + deploy-docs: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: '21' + + - name: Build Dokka + run: ./gradlew dokkaHtml + + - name: Deploy to GitHub Pages + uses: peaceiris/actions-gh-pages@v3 + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + publish_dir: usvm-jvm-mocks/build/dokka/html diff --git a/.github/workflows/gradle-publish.yml b/.github/workflows/gradle-publish.yml deleted file mode 100644 index f59c5a66be..0000000000 --- a/.github/workflows/gradle-publish.yml +++ /dev/null @@ -1,33 +0,0 @@ -name: Gradle Package - -on: - workflow_dispatch: - inputs: - version: - description: Release version - type: string - required: true - -jobs: - build: - runs-on: ubuntu-20.04 - permissions: - contents: read - packages: write - steps: - - uses: actions/checkout@v3 - - name: Set up JDK 17 - uses: actions/setup-java@v3 - with: - java-version: '17' - distribution: 'corretto' - server-id: github # Value of the distributionManagement/repository/id field of the pom.xml - settings-path: ${{ github.workspace }} # location for the settings.xml file - - - name: Publish to GitHub Packages - uses: gradle/gradle-build-action@v2 - with: - arguments: publishAllPublicationsToGitHubPackagesRepository -Pversion=${{ inputs.version }} - env: - GITHUB_ACTOR: ${{ github.actor }} - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/mocksCI.yml b/.github/workflows/mocksCI.yml new file mode 100644 index 0000000000..5742f51003 --- /dev/null +++ b/.github/workflows/mocksCI.yml @@ -0,0 +1,43 @@ +name: Mocks configurations CI + +on: + push: + branches: [ mocks2 ] + pull_request: + branches: [ main ] + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v3 + + - name: Cache Gradle packages + uses: actions/cache@v4 + with: + path: | + ~/.gradle/caches + ~/.gradle/wrapper + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle*', '**/gradle-wrapper.properties') }} + restore-keys: | + ${{ runner.os }}-gradle- + + - name: Grant execute permission to Gradle wrapper + run: chmod +x ./gradlew + + - name: Set up JDK 21 + uses: actions/setup-java@v3 + with: + distribution: temurin + java-version: 21 + + - name: Build project without tests + run: ./gradlew build -x test + + - name: Run JVM mocks tests + run: ./gradlew :usvm-jvm-mocks:test + + - name: Run ktlint check + run: ./gradlew ktlintCheck diff --git a/.github/workflows/python-runner-publish.yml b/.github/workflows/python-runner-publish.yml deleted file mode 100644 index 4f39d8d896..0000000000 --- a/.github/workflows/python-runner-publish.yml +++ /dev/null @@ -1,35 +0,0 @@ -name: Publish Package `usvm-python-runner` - -on: - workflow_dispatch: - inputs: - version: - description: Release version - type: string - required: true - -jobs: - build: - runs-on: ubuntu-20.04 - permissions: - contents: read - packages: write - steps: - - uses: actions/checkout@v3 - - name: Set up JDK 17 - uses: actions/setup-java@v3 - with: - java-version: '17' - distribution: 'corretto' - server-id: github # Value of the distributionManagement/repository/id field of the pom.xml - settings-path: ${{ github.workspace }} # location for the settings.xml file - - # The USERNAME and TOKEN need to correspond to the credentials environment variables used in - # the publishing section of your build.gradle - - name: Publish usvm-python-runner to GitHub Packages - uses: gradle/gradle-build-action@v2 - with: - arguments: :usvm-python:usvm-python-runner:publishAllPublicationsToGitHubPackagesRepository :usvm-python:usvm-python-common:publishAllPublicationsToGitHubPackagesRepository -Pversion=${{ inputs.version }} - env: - GITHUB_ACTOR: ${{ github.actor }} - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/build.gradle.kts b/build.gradle.kts index 8223024eeb..14a0d4a7e4 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -22,6 +22,7 @@ tasks.register("validateProjectList") { project(":usvm-python"), project(":usvm-ts"), project(":usvm-ts-dataflow"), + project(":usvm-jvm-rendering"), ) // Gather the actual subprojects from the current root project. diff --git a/buildSrc/src/main/kotlin/Dependencies.kt b/buildSrc/src/main/kotlin/Dependencies.kt index 0492a78a95..245d26aeae 100644 --- a/buildSrc/src/main/kotlin/Dependencies.kt +++ b/buildSrc/src/main/kotlin/Dependencies.kt @@ -17,6 +17,7 @@ object Versions { const val ksmt = "0.5.26" const val logback = "1.4.8" const val mockk = "1.13.4" + const val mockito = "5.4.0" const val rd = "2023.2.0" const val sarif4k = "0.5.0" const val shadow = "8.3.3" @@ -191,6 +192,12 @@ object Libs { version = Versions.mockk ) + val mockito = dep( + group = "org.mockito", + name = "mockito-core", + version = Versions.mockito + ) + // https://github.com/UnitTestBot/juliet-java-test-suite val juliet_support = dep( group = "com.github.UnitTestBot.juliet-java-test-suite", diff --git a/settings.gradle.kts b/settings.gradle.kts index b6b98f4490..eef3310317 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -51,6 +51,8 @@ include("usvm-python:usvm-python-runner") findProject(":usvm-python:usvm-python-runner")?.name = "usvm-python-runner" include("usvm-python:usvm-python-commons") findProject(":usvm-python:usvm-python-commons")?.name = "usvm-python-commons" +include("usvm-jvm-mocks") +include("usvm-jvm-rendering") // Actually, `includeBuild("../jacodb")` is enough, but there is a bug in IDEA when path is a symlink. // As a workaround, we convert it to a real absolute path. diff --git a/usvm-core/src/main/kotlin/org/usvm/Machine.kt b/usvm-core/src/main/kotlin/org/usvm/Machine.kt index da43d71e42..484a33db68 100644 --- a/usvm-core/src/main/kotlin/org/usvm/Machine.kt +++ b/usvm-core/src/main/kotlin/org/usvm/Machine.kt @@ -27,7 +27,7 @@ abstract class UMachine> : AutoCloseable { * @param stopStrategy is called on every step, before peeking a next state from the path selector. * Returning `true` aborts analysis. */ - protected fun run( + protected open fun run( interpreter: UInterpreter, pathSelector: UPathSelector, observer: UMachineObserver, diff --git a/usvm-core/src/main/kotlin/org/usvm/statistics/TimeStatistics.kt b/usvm-core/src/main/kotlin/org/usvm/statistics/TimeStatistics.kt index e767489bfb..f0894bc6c5 100644 --- a/usvm-core/src/main/kotlin/org/usvm/statistics/TimeStatistics.kt +++ b/usvm-core/src/main/kotlin/org/usvm/statistics/TimeStatistics.kt @@ -3,6 +3,7 @@ package org.usvm.statistics import org.usvm.UState import org.usvm.util.RealTimeStopwatch import kotlin.time.Duration +import kotlin.time.Duration.Companion.seconds /** * Maintains information about time spent on machine processes. @@ -16,7 +17,7 @@ class TimeStatistics> : UMachi /** * Total machine running time. */ - val runningTime get() = stopwatch.elapsed + val runningTime get() = Duration.INFINITE /** * Returns time spent by machine on [method]. diff --git a/usvm-jvm-instrumentation/src/main/kotlin/org/usvm/instrumentation/serializer/UTestInstSerializer.kt b/usvm-jvm-instrumentation/src/main/kotlin/org/usvm/instrumentation/serializer/UTestInstSerializer.kt index 264da7ba31..af05118184 100644 --- a/usvm-jvm-instrumentation/src/main/kotlin/org/usvm/instrumentation/serializer/UTestInstSerializer.kt +++ b/usvm-jvm-instrumentation/src/main/kotlin/org/usvm/instrumentation/serializer/UTestInstSerializer.kt @@ -20,41 +20,8 @@ import org.jacodb.api.jvm.ext.float import org.jacodb.api.jvm.ext.int import org.jacodb.api.jvm.ext.long import org.jacodb.api.jvm.ext.short -import org.usvm.test.api.ArithmeticOperationType -import org.usvm.test.api.ConditionType -import org.usvm.test.api.UTestAllocateMemoryCall -import org.usvm.test.api.UTestArithmeticExpression -import org.usvm.test.api.UTestArrayGetExpression -import org.usvm.test.api.UTestArrayLengthExpression -import org.usvm.test.api.UTestArraySetStatement -import org.usvm.test.api.UTestBinaryConditionExpression -import org.usvm.test.api.UTestBinaryConditionStatement -import org.usvm.test.api.UTestBooleanExpression -import org.usvm.test.api.UTestByteExpression -import org.usvm.test.api.UTestCastExpression -import org.usvm.test.api.UTestCharExpression -import org.usvm.test.api.UTestClassExpression -import org.usvm.test.api.UTestConstructorCall -import org.usvm.test.api.UTestCreateArrayExpression -import org.usvm.test.api.UTestDoubleExpression -import org.usvm.test.api.UTestExpression -import org.usvm.test.api.UTestFloatExpression -import org.usvm.test.api.UTestGetFieldExpression -import org.usvm.test.api.UTestGetStaticFieldExpression -import org.usvm.test.api.UTestGlobalMock -import org.usvm.test.api.UTestInst -import org.usvm.test.api.UTestIntExpression -import org.usvm.test.api.UTestLongExpression -import org.usvm.test.api.UTestMethodCall -import org.usvm.test.api.UTestMockObject -import org.usvm.test.api.UTestNullExpression -import org.usvm.test.api.UTestSetFieldStatement -import org.usvm.test.api.UTestSetStaticFieldStatement -import org.usvm.test.api.UTestShortExpression -import org.usvm.test.api.UTestStatement -import org.usvm.test.api.UTestStaticMethodCall -import org.usvm.test.api.UTestStringExpression import org.usvm.jvm.util.stringType +import org.usvm.test.api.* class UTestInstSerializer(private val ctx: SerializationContext) { @@ -100,6 +67,10 @@ class UTestInstSerializer(private val ctx: SerializationContext) { is UTestShortExpression -> serialize(uTestInst) is UTestArithmeticExpression -> serialize(uTestInst) is UTestClassExpression -> serialize(uTestInst) + is UTestAssertEqualsCall -> TODO() + is UTestAssertThrowsCall -> TODO() + is UTestInstList -> TODO() + is UTestMockInst -> TODO() } } diff --git a/usvm-jvm-instrumentation/src/main/kotlin/org/usvm/instrumentation/testcase/executor/UTestExpressionExecutor.kt b/usvm-jvm-instrumentation/src/main/kotlin/org/usvm/instrumentation/testcase/executor/UTestExpressionExecutor.kt index 1a1c9d061b..85c09308e1 100644 --- a/usvm-jvm-instrumentation/src/main/kotlin/org/usvm/instrumentation/testcase/executor/UTestExpressionExecutor.kt +++ b/usvm-jvm-instrumentation/src/main/kotlin/org/usvm/instrumentation/testcase/executor/UTestExpressionExecutor.kt @@ -27,29 +27,7 @@ import org.usvm.jvm.util.toJavaClass import org.usvm.jvm.util.toJavaConstructor import org.usvm.jvm.util.toJavaField import org.usvm.jvm.util.toJavaMethod -import org.usvm.test.api.ArithmeticOperationType -import org.usvm.test.api.ConditionType -import org.usvm.test.api.UTestAllocateMemoryCall -import org.usvm.test.api.UTestArithmeticExpression -import org.usvm.test.api.UTestArrayGetExpression -import org.usvm.test.api.UTestArrayLengthExpression -import org.usvm.test.api.UTestArraySetStatement -import org.usvm.test.api.UTestBinaryConditionExpression -import org.usvm.test.api.UTestBinaryConditionStatement -import org.usvm.test.api.UTestCastExpression -import org.usvm.test.api.UTestClassExpression -import org.usvm.test.api.UTestConstExpression -import org.usvm.test.api.UTestConstructorCall -import org.usvm.test.api.UTestCreateArrayExpression -import org.usvm.test.api.UTestGetFieldExpression -import org.usvm.test.api.UTestGetStaticFieldExpression -import org.usvm.test.api.UTestGlobalMock -import org.usvm.test.api.UTestInst -import org.usvm.test.api.UTestMethodCall -import org.usvm.test.api.UTestMock -import org.usvm.test.api.UTestSetFieldStatement -import org.usvm.test.api.UTestSetStaticFieldStatement -import org.usvm.test.api.UTestStaticMethodCall +import org.usvm.test.api.* class UTestExpressionExecutor( private val workerClassLoader: WorkerClassLoader, @@ -112,6 +90,10 @@ class UTestExpressionExecutor( is UTestSetStaticFieldStatement -> executeUTestSetStaticFieldStatement(uTestExpression) is UTestArithmeticExpression -> executeUTestArithmeticExpression(uTestExpression) is UTestClassExpression -> executeUTestClassExpression(uTestExpression) + is UTestAssertEqualsCall -> TODO() + is UTestAssertThrowsCall -> TODO() + is UTestInstList -> TODO() + is UTestMockInst -> TODO() } }.also { it?.let { diff --git a/usvm-jvm-mocks/README_MOCKS.md b/usvm-jvm-mocks/README_MOCKS.md new file mode 100644 index 0000000000..2c560d09cc --- /dev/null +++ b/usvm-jvm-mocks/README_MOCKS.md @@ -0,0 +1,68 @@ +# Automatic mock configuration +`usvm-jvm-mocks` is a module for the [USVM](https://github.com/MchKosticyn/usvm) symbolic virtual machine. +It automates the configuration of mocked methods in tests, allowing mock values to be generated automatically instead of manually. + +## Features ✨ +- Detects and tracks mocked objects in tests +- Generates mock values for mocks' methods automatically using symbolic execution +- Produces ready-to-use mock configurations for test methods and even integrates fully into the original test (don't worry, it's a copy in case you don't like the result) +- Integrates with the existing USVM infrastructure (`usvm-jvm`) + +## Requirements + +- JDK 17+ +- Kotlin, Java +- Gradle + +## Getting started +Clone the repo: +```bash +git clone git@github.com:sofyak0zyreva/usvm.git +``` +Switch to dev's branch: +```bash +git switch mocks2 +``` +Cause *second time's a charm*✨ You're all set! + +## Usage +1. Place sample `.java` files --- your tests --- [here](./src/samples/java/org/usvm/samples/) +2. Create corresponding test file (yes, test file for your test file) [here](./src/test/kotlin/org/usvm/samples/): +### Example +```kotlin +package org.usvm.samples + +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.usvm.test.util.checkers.ignoreNumberOfAnalysisResults + +class ExampleTest : MocksTestRunner() { + @BeforeEach + fun reset() { + cleanUp() + } + + @Test + fun test() { + checkDiscoveredPropertiesWithExceptions( + ExampleClass::exampleMethodUnderTest, + ignoreNumberOfAnalysisResults, + { _, _, r -> r.getOrNull() == null } + ) + } +} +``` +`_` must be one greater than the number of arguments in the function under test (just trust me on this one). + +3. Run tests: +```bash +./gradlew :usvm-jvm-mocks:test +``` +or just a specific test if you so prefer. That's it! The result is waiting for you in the [testOutput](./src/samples/java/org/usvm/samples/testOutput) directory. Enjoy! There will be improvements so stay in touch✨ + +## Contacts +[sofyak0zyreva](https://github.com/sofyak0zyreva) (tg @soffque) in case you have questions + +## License +The product is distributed under MIT license. See [`LICENSE`](LICENSE) for details. + diff --git a/usvm-jvm-mocks/build.gradle.kts b/usvm-jvm-mocks/build.gradle.kts new file mode 100644 index 0000000000..8b8a48d0eb --- /dev/null +++ b/usvm-jvm-mocks/build.gradle.kts @@ -0,0 +1,127 @@ +plugins { + id("usvm.kotlin-conventions") + id("org.jlleitschuh.gradle.ktlint") version "11.6.0" + id("org.jetbrains.dokka") version "1.9.10" +} + +dependencies { + implementation(project(":usvm-jvm")) + implementation(project(":usvm-core")) + implementation(project(":usvm-jvm-rendering")) + + implementation("com.github.javaparser:javaparser-symbol-solver-core:3.26.3") + + implementation(project(":usvm-jvm:usvm-jvm-test-api")) + implementation(project(":usvm-jvm:usvm-jvm-util")) + implementation(project(":usvm-jvm:usvm-jvm-api")) + + implementation(Libs.jacodb_api_jvm) + implementation(Libs.jacodb_core) + implementation(Libs.jacodb_approximations) + + implementation("it.unimi.dsi:fastutil-core:8.5.13") + testImplementation(project(":usvm-jvm")) +} + +val samples by sourceSets.creating { + java { + srcDir("src/samples/java") + } +} + +repositories { + mavenLocal() + mavenCentral() +} + +val approximations by configurations.creating +val approximationsRepo = "com.github.UnitTestBot.java-stdlib-approximations" +val approximationsVersion = "607384f1a7" + +dependencies { + testImplementation(Libs.mockk) + testImplementation(Libs.junit_jupiter_params) + testImplementation(Libs.logback) + + testImplementation(samples.output) + testImplementation(project(":usvm-jvm").dependencyProject.project.sourceSets.getByName("test").output) + + // https://mvnrepository.com/artifact/org.burningwave/core + // Use it to export all modules to all + testImplementation("org.burningwave:core:12.62.7") + + approximations(approximationsRepo, "approximations", approximationsVersion) + testImplementation(approximationsRepo, "tests", approximationsVersion) +} + +val samplesImplementation: Configuration by configurations.getting + +dependencies { + samplesImplementation(Libs.mockito) + samplesImplementation(Libs.mockk) + samplesImplementation("org.projectlombok:lombok:${Versions.Samples.lombok}") + samplesImplementation("org.slf4j:slf4j-api:${Versions.Samples.slf4j}") + samplesImplementation("javax.validation:validation-api:${Versions.Samples.javaxValidation}") + samplesImplementation("com.github.stephenc.findbugs:findbugs-annotations:${Versions.Samples.findBugs}") + samplesImplementation("org.jetbrains:annotations:${Versions.Samples.jetbrainsAnnotations}") + + // Use usvm-api in samples for makeSymbolic, assume, etc. + samplesImplementation(project(":usvm-jvm:usvm-jvm-api")) + + testImplementation(project(":usvm-jvm-instrumentation")) +} + +val testSamples by configurations.creating +val testSamplesWithApproximations by configurations.creating + +dependencies { + testSamples(samples.output) + testSamples(project(":usvm-jvm:usvm-jvm-api")) + testSamples("org.mockito:mockito-core:5.4.0") + + testSamplesWithApproximations(samples.output) + testSamplesWithApproximations(project(":usvm-jvm:usvm-jvm-api")) + testSamplesWithApproximations(approximationsRepo, "tests", approximationsVersion) +} + +val usvmApiJarConfiguration by configurations.creating +dependencies { + implementation(project(":usvm-jvm")) + usvmApiJarConfiguration(project(":usvm-jvm:usvm-jvm-api")) +} + +tasks.withType { + val usvmApiJarPath = usvmApiJarConfiguration.resolvedConfiguration.files.single() + val usvmApproximationJarPath = approximations.resolvedConfiguration.files.single() + + environment("usvm.jvm.api.jar.path", usvmApiJarPath.absolutePath) + environment("usvm.jvm.approximations.jar.path", usvmApproximationJarPath.absolutePath) + + environment("usvm.jvm.test.samples", testSamples.asPath) + environment("usvm.jvm.test.samples.approximations", testSamplesWithApproximations.asPath) + + environment( + "usvm-jvm-instrumentation-jar", + project(":usvm-jvm-instrumentation") + .layout + .buildDirectory + .file("libs/usvm-jvm-instrumentation-runner.jar") + .get().asFile.absolutePath + ) + environment( + "usvm-jvm-collectors-jar", + project(":usvm-jvm-instrumentation") + .layout + .buildDirectory + .file("libs/usvm-jvm-instrumentation-collectors.jar") + .get().asFile.absolutePath + ) +} + +publishing { + publications { + create("maven") { + from(components["java"]) + } + } +} diff --git a/usvm-jvm-mocks/src/main/kotlin/machine/JcMocksComponents.kt b/usvm-jvm-mocks/src/main/kotlin/machine/JcMocksComponents.kt new file mode 100644 index 0000000000..c2627eaa13 --- /dev/null +++ b/usvm-jvm-mocks/src/main/kotlin/machine/JcMocksComponents.kt @@ -0,0 +1,38 @@ +package machine + +import org.jacodb.api.jvm.JcType +import org.usvm.UComposer +import org.usvm.UContext +import org.usvm.UMachineOptions +import org.usvm.collections.immutable.internal.MutabilityOwnership +import org.usvm.machine.JcComponents +import org.usvm.machine.JcTypeSystem +import org.usvm.machine.USizeSort +import org.usvm.memory.UReadOnlyMemory +import org.usvm.model.ULazyModelDecoder +import org.usvm.solver.UExprTranslator +import org.usvm.solver.USoftConstraintsProvider + +/** + * JcMocksComponents simply applies changes made for usvm-jvm-mocks memory to the execution. + */ +class JcMocksComponents( + typeSystem: JcTypeSystem, + options: UMachineOptions +) : JcComponents(typeSystem, options) { + override fun > mkComposer( + ctx: Context + ): (UReadOnlyMemory, MutabilityOwnership) -> UComposer = + { memory: UReadOnlyMemory, ownership: MutabilityOwnership -> JcMocksComposer(ctx, memory, ownership) } + + override fun > buildTranslatorAndLazyDecoder( + ctx: Context + ): Pair, ULazyModelDecoder> { + val translator = JcMocksExprTranslator(ctx) + val decoder: ULazyModelDecoder = ULazyModelDecoder(translator) + return translator to decoder + } + override fun > mkSoftConstraintsProvider( + ctx: Context + ): USoftConstraintsProvider = JcMocksSoftConstraintsProvider(ctx) +} diff --git a/usvm-jvm-mocks/src/main/kotlin/machine/JcMocksInterpreter.kt b/usvm-jvm-mocks/src/main/kotlin/machine/JcMocksInterpreter.kt new file mode 100644 index 0000000000..888afbccc4 --- /dev/null +++ b/usvm-jvm-mocks/src/main/kotlin/machine/JcMocksInterpreter.kt @@ -0,0 +1,134 @@ +package machine + +import io.ksmt.utils.asExpr +import machine.memory.JcMockedMethod +import machine.memory.JcMockedMethodsReading +import machine.memory.JcMockedMethodsRegion +import machine.memory.JcMockedMethodsRegionId +import machine.memory.JcMockedMethodsValue +import org.jacodb.api.jvm.cfg.JcAssignInst +import org.jacodb.api.jvm.cfg.JcStaticCallExpr +import org.jacodb.api.jvm.ext.toType +import org.usvm.UConcreteHeapRef +import org.usvm.UExpr +import org.usvm.USort +import org.usvm.collection.field.UFieldLValue +import org.usvm.machine.JcApplicationGraph +import org.usvm.machine.JcConcreteMethodCallInst +import org.usvm.machine.JcContext +import org.usvm.machine.JcInterpreterObserver +import org.usvm.machine.JcMachineOptions +import org.usvm.machine.JcMethodCallBaseInst +import org.usvm.machine.interpreter.JcExprResolver +import org.usvm.machine.interpreter.JcInterpreter +import org.usvm.machine.interpreter.JcStepScope +import org.usvm.machine.state.skipMethodInvocationWithValue +import java.io.File + +/** + * mocksMap tracks addresses assigned to mock objects during initialization and its metadata. + */ +val mocksMap: MutableMap = HashMap() + +/** + * mockedMethods simply tracks which methods were called by mock objects. + */ +val mockedMethods: MutableSet> = mutableSetOf() + +/** + * mockedMethodsValues is a map that contains methods from memory and their current return values. + */ +val mockedMethodsValues: MutableMap, UExpr> = HashMap() + +/** + * varnamesMap saves variable names and lines where they were defined. + */ +val varnamesMap: MutableMap = HashMap() + +/** + * Given a line with metadata about mocked method, it returns the line where mock object that called it was initialized. + */ +fun getLineNumber(line: String): Int { + val num = line.substringAfter("(line:").substringBefore(")") + return num.toInt() +} + +/** + * JcMocksInterpreter is responsible for handling two special cases: + * 1. a mock was initialized; + * 2. its method was called. + * It handles return values, reads from memory and saving results to global maps. + */ +open class JcMocksInterpreter( + val file: String, + ctx: JcContext, + applicationGraph: JcApplicationGraph, + options: JcMachineOptions, + observer: JcInterpreterObserver? = null +) : JcInterpreter(ctx, applicationGraph, options, observer) { + fun updateVarName(lineNumber: Int, lines: List) { + val str = lines.getOrNull(lineNumber - 1) ?: throw IllegalArgumentException("no such line in file") + val res = str.trim().substringBefore("=").substringAfter(" ").substringBefore(" ") + varnamesMap[lineNumber] = res + } + + override fun callMethod( + scope: JcStepScope, + stmt: JcMethodCallBaseInst, + exprResolver: JcExprResolver + ) { + when (stmt) { + is JcConcreteMethodCallInst -> { + val method = stmt.method + val methodName = method.name + val retStmt = stmt.returnSite + if (retStmt !is JcAssignInst) { return } + val isAppCode = method.declaration.relativePath.startsWith("org.usvm.samples.") + if (stmt.arguments.isNotEmpty() && isAppCode) { + val refToMock = stmt.arguments[0] + val value = mocksMap[refToMock] + if (value != null || refToMock is JcMockedMethodsReading) { + val retType = retStmt.lhv.type + val newSymbolicRef: UExpr + val lineNumber = retStmt.lineNumber + val lines = File(file).readLines() + val enclosingClass = if (refToMock is JcMockedMethodsReading) { + "mock:" + method.enclosingClass.toType().typeName + "(" + refToMock.mockedMethod.method.substringAfter("(") + "::" + } else { value!! } + val num = getLineNumber(enclosingClass) + updateVarName(num, lines) + val mockedMethod = JcMockedMethod("$methodName(line:$lineNumber)", enclosingClass) + + scope.doWithState { + val retSort = ctx.typeToSort(retType) + val memoryRegion = memory.getRegion(JcMockedMethodsRegionId(retSort, retType)) as JcMockedMethodsRegion + val mockedMethodValue = JcMockedMethodsValue(mockedMethod, retSort, retType, method) + newSymbolicRef = memoryRegion.read(mockedMethodValue.key) + skipMethodInvocationWithValue(stmt, newSymbolicRef) + mockedMethods.add(mockedMethodValue) + } + return + } + } + + val mockCall = retStmt.rhv + if (methodName == "mock" && mockCall is JcStaticCallExpr && mockCall.args.size == 1) { + scope.doWithState { + val classRef = stmt.arguments[0].asExpr(ctx.addressSort) + val classRefTypeRepresentative = + memory.read(UFieldLValue(ctx.addressSort, classRef, ctx.classTypeSyntheticField)) + classRefTypeRepresentative as UConcreteHeapRef + val classType = memory.types.typeOf(classRefTypeRepresentative.address) + val ref = memory.allocConcrete(classType) + skipMethodInvocationWithValue(stmt, ref) + val lineNumber = stmt.returnSite.lineNumber + mocksMap[ref] = "mock:" + classType.typeName + "(line:" + lineNumber + ")::" + } + return + } + super.callMethod(scope, stmt, exprResolver) + } + else -> super.callMethod(scope, stmt, exprResolver) + } + } +} diff --git a/usvm-jvm-mocks/src/main/kotlin/machine/JcMocksMachine.kt b/usvm-jvm-mocks/src/main/kotlin/machine/JcMocksMachine.kt new file mode 100644 index 0000000000..e841827b20 --- /dev/null +++ b/usvm-jvm-mocks/src/main/kotlin/machine/JcMocksMachine.kt @@ -0,0 +1,126 @@ +package machine + +import org.jacodb.api.jvm.JcClasspath +import org.jacodb.api.jvm.JcType +import org.usvm.UInterpreter +import org.usvm.UMachineOptions +import org.usvm.UPathSelector +import org.usvm.logger +import org.usvm.machine.JcInterpreterObserver +import org.usvm.machine.JcMachine +import org.usvm.machine.JcMachineOptions +import org.usvm.machine.interpreter.JcInterpreter +import org.usvm.machine.state.JcState +import org.usvm.model.UModelBase +import org.usvm.solver.USatResult +import org.usvm.solver.USolverResult +import org.usvm.statistics.UMachineObserver +import org.usvm.stopstrategies.StopStrategy +import org.usvm.util.bracket +import org.usvm.util.debug + +private fun JcState.isSat(): Boolean { + if (models.isNotEmpty()) { + return true + } + + return verify() is USatResult +} + +private fun JcState.verify(): USolverResult> { + val solver = ctx.solver() + val solverResult = solver.check(pathConstraints) + + if (solverResult is USatResult) { + models = listOf(solverResult.model) + } + + return solverResult +} + +/** + * Symbolic machine responsible for handling mocked method execution. + * + * This machine extends the default USVM JVM execution model by + * intercepting mock initialization and method calls. + */ +open class JcMocksMachine( + private val file: String, + cp: JcClasspath, + options: UMachineOptions, + jcMachineOptions: JcMachineOptions = JcMachineOptions(), + interpreterObserver: JcInterpreterObserver? = null +) : JcMachine(cp, options, jcMachineOptions, interpreterObserver) { + override val components = JcMocksComponents(typeSystem, options) + override fun createInterpreter(): JcInterpreter { + return JcMocksInterpreter( + file, + ctx, + applicationGraph, + jcMachineOptions, + interpreterObserver + ) + } + override fun run( + interpreter: UInterpreter, + pathSelector: UPathSelector, + observer: UMachineObserver, + isStateTerminated: (JcState) -> Boolean, + stopStrategy: StopStrategy + ) { + logger.debug().bracket("$this.run($interpreter, ${pathSelector::class.simpleName})") { + observer.onMachineStarted() + try { + while (!pathSelector.isEmpty() && !stopStrategy.shouldStop()) { + val state = pathSelector.peek() + observer.onStatePeeked(state) + + val (forkedStates, stateAlive) = try { + interpreter.step(state) + } catch (e: Throwable) { + logger.error(e) { "Step failed" } + observer.onState(state, forks = emptySequence()) + pathSelector.remove(state) + observer.onStateTerminated(state, stateReachable = false) + continue + } + + observer.onState(state, forkedStates) + + val originalStateAlive = stateAlive && !isStateTerminated(state) + val aliveForkedStates = mutableListOf() + for (forkedState in forkedStates) { + if (!isStateTerminated(forkedState)) { + aliveForkedStates.add(forkedState) + } else { + // TODO: distinguish between states terminated by exception (runtime or user) and + // those which just exited + if (forkedState.isSat()) { + observer.onStateTerminated(forkedState, stateReachable = true) + } + } + } + if (originalStateAlive) { + pathSelector.update(state) + } else { + pathSelector.remove(state) + if (state.isSat()) { + observer.onStateTerminated(state, stateReachable = stateAlive) + } + } + + if (aliveForkedStates.isNotEmpty()) { + pathSelector.add(aliveForkedStates) + } + } + } finally { + observer.onMachineStopped() + } + + if (!pathSelector.isEmpty()) { + val stopReason = stopStrategy.stopReason() + logger.debug { stopReason } + } + } + } +} diff --git a/usvm-jvm-mocks/src/main/kotlin/machine/JcMocksTransformer.kt b/usvm-jvm-mocks/src/main/kotlin/machine/JcMocksTransformer.kt new file mode 100644 index 0000000000..8cec2084df --- /dev/null +++ b/usvm-jvm-mocks/src/main/kotlin/machine/JcMocksTransformer.kt @@ -0,0 +1,113 @@ +package machine + +import io.ksmt.expr.KExpr +import io.ksmt.sort.KBoolSort +import io.ksmt.utils.mkConst +import machine.memory.JcMockedMethod +import machine.memory.JcMockedMethodsReading +import machine.memory.JcMockedMethodsRegionId +import machine.memory.JcMockedMethodsValue +import org.jacodb.api.jvm.JcType +import org.usvm.UComposer +import org.usvm.UContext +import org.usvm.UExpr +import org.usvm.USort +import org.usvm.collections.immutable.internal.MutabilityOwnership +import org.usvm.machine.JcComposer +import org.usvm.machine.JcExprTranslator +import org.usvm.machine.JcSoftConstraintsProvider +import org.usvm.machine.JcTransformer +import org.usvm.machine.USizeSort +import org.usvm.machine.interpreter.statics.JcStaticFieldReading +import org.usvm.machine.jctx +import org.usvm.memory.UReadOnlyMemory +import org.usvm.memory.UReadOnlyMemoryRegion +import org.usvm.model.UModelEvaluator +import org.usvm.solver.UExprTranslator +import org.usvm.solver.URegionDecoder +import org.usvm.solver.USoftConstraintsProvider + +/** + * JcMocksTransformer and the following classes are helpers so that JcMockedMemory works like a proper memory region. + */ +interface JcMocksTransformer : JcTransformer { + fun transform(expr: JcMockedMethodsReading): UExpr +} + +class JcMocksComposer( + ctx: UContext, + memory: UReadOnlyMemory, + ownership: MutabilityOwnership +) : UComposer(ctx, memory, ownership), JcMocksTransformer { + override fun transform(expr: JcMockedMethodsReading): UExpr { + val ret = memory.read(JcMockedMethodsValue(expr.mockedMethod, expr.sort, expr.type, expr.method)) + for (key in mockedMethods) { + val newValue = memory.read(key.key) + mockedMethodsValues[key] = newValue + } + return ret + } + + private val jcComposer = JcComposer(ctx, memory, ownership) + override fun transform(expr: JcStaticFieldReading): UExpr { + return jcComposer.transform(expr) + } +} + +class JcMocksExprTranslator(ctx: UContext) : UExprTranslator(ctx), JcMocksTransformer { + override fun transform(expr: JcMockedMethodsReading): UExpr = + getOrPutRegionDecoder(expr.regionId) { + JcMockedMethodsDecoder(expr.regionId, this) + }.translate(expr) + + private val jcExprTranslator = JcExprTranslator(ctx) + override fun transform(expr: JcStaticFieldReading): UExpr { + return jcExprTranslator.transform(expr) + } +} + +class JcMockedMethodsDecoder( + private val regionId: JcMockedMethodsRegionId, + private val translator: UExprTranslator<*, *> +) : URegionDecoder, Sort> { + private val translated = mutableMapOf>() + + fun translate(expr: JcMockedMethodsReading): UExpr = + translated.getOrPut(expr.mockedMethod) { + expr.sort.mkConst("${expr.mockedMethod.enclosingClass}_${regionId.sort}_${expr.mockedMethod.method}") + } + + override fun decodeLazyRegion( + model: UModelEvaluator<*>, + assertions: List> + ): UReadOnlyMemoryRegion, Sort> = + JcMockedMethodsModel(model, translated, translator) +} + +class JcMockedMethodsModel( + private val model: UModelEvaluator<*>, + private val translatedMockedMethods: Map>, + private val translator: UExprTranslator<*, *> +) : UReadOnlyMemoryRegion, Sort> { + override fun read(key: JcMockedMethodsValue): UExpr { + val t = translatedMockedMethods[key.mockedMethod] + val translated = t + ?: translator.translate( + JcMockedMethodsReading(key.sort.jctx, key.memoryRegionId as JcMockedMethodsRegionId, key.mockedMethod, key.type, key.method, key.sort) + ) + return model.evalAndComplete(translated) + } +} + +class JcMocksSoftConstraintsProvider( + ctx: UContext +) : USoftConstraintsProvider(ctx), JcMocksTransformer { + override fun transform( + expr: JcMockedMethodsReading + ): UExpr = transformExpr(expr) + + private val jcSoftConstraintsProvider = JcSoftConstraintsProvider(ctx) + override fun transform(expr: JcStaticFieldReading): UExpr { + return jcSoftConstraintsProvider.transform(expr) + } +} diff --git a/usvm-jvm-mocks/src/main/kotlin/machine/instructions/Builder.kt b/usvm-jvm-mocks/src/main/kotlin/machine/instructions/Builder.kt new file mode 100644 index 0000000000..50c0119082 --- /dev/null +++ b/usvm-jvm-mocks/src/main/kotlin/machine/instructions/Builder.kt @@ -0,0 +1,71 @@ +package machine.instructions + +import machine.memory.JcMockedMethodsValue +import machine.mockedMethodsValues +import org.jacodb.api.jvm.JcClassType +import org.jacodb.api.jvm.JcType +import org.jacodb.api.jvm.JcTypedMethod +import org.usvm.UExpr +import org.usvm.USort +import org.usvm.api.util.JcTestStateResolver +import org.usvm.jvm.util.toTypedMethod +import org.usvm.machine.JcContext +import org.usvm.machine.state.JcState +import org.usvm.memory.UReadOnlyMemory +import org.usvm.model.UModelBase +import org.usvm.test.api.JcTestExecutorDecoderApi +import org.usvm.test.api.UTestAllocateMemoryCall +import org.usvm.test.api.UTestExpression +import org.usvm.test.api.UTestInst +import org.usvm.test.api.UTestMockInst + +/** + * createUTestMockConfigInfo() simply combines lists of test instructions into a single one by flattening it + * and creates UTestMockConfigInfo for render. + */ +fun createUTestMockConfigInfo(list: List>>): UTestMockConfigInfo { + val instructions = list.flatten() + return UTestMockConfigInfo(instructions) +} + +/** + * createUTestInstructions() takes key and state, passes information on to MemoryScope, + * which handles resolving values and returns UTestMockInst with additional meta info. + */ +fun createUTestInstructions( + key: JcMockedMethodsValue, + state: JcState +): List> { + val model = state.models.first() + val ctx = state.ctx + val memoryScope = MemoryScope(ctx, model, state.memory, key.method.toTypedMethod) + + return memoryScope.createUTestInstructions(key) +} + +private class MemoryScope( + ctx: JcContext, + model: UModelBase, + finalStateMemory: UReadOnlyMemory, + method: JcTypedMethod +) : JcTestStateResolver(ctx, model, finalStateMemory, method) { + override val decoderApi = JcTestExecutorDecoderApi(ctx.cp) + override fun allocateClassInstance(type: JcClassType): UTestExpression = + UTestAllocateMemoryCall(type.jcClass) + fun createUTestInstructions(key: JcMockedMethodsValue): List> { + val newMap: Map, UExpr> = mockedMethodsValues.toMap() + return withMode(ResolveMode.CURRENT) { + val list = mutableListOf>() + val parameters = resolveParameters() + val m = newMap[key] + val resolved = resolveExpr(m as UExpr, key.type) + val pairToAdd = (Pair(UTestMockInst(resolved, method.method, parameters), key.mockedMethod.str)) + val initStmts = this@MemoryScope.decoderApi.initializerInstructions() + for (initStmt in initStmts) { + list.add(Pair(initStmt, "")) + } + list.add(pairToAdd) + list + } + } +} diff --git a/usvm-jvm-mocks/src/main/kotlin/machine/instructions/UTestMockConfigInfo.kt b/usvm-jvm-mocks/src/main/kotlin/machine/instructions/UTestMockConfigInfo.kt new file mode 100644 index 0000000000..4270960007 --- /dev/null +++ b/usvm-jvm-mocks/src/main/kotlin/machine/instructions/UTestMockConfigInfo.kt @@ -0,0 +1,8 @@ +package machine.instructions + +import org.usvm.test.api.UTestInst + +/** + * This class simply represents instructions that render should transform into Java-code. + */ +class UTestMockConfigInfo(val instructions: List>) diff --git a/usvm-jvm-mocks/src/main/kotlin/machine/memory/JcMockedMethodsReading.kt b/usvm-jvm-mocks/src/main/kotlin/machine/memory/JcMockedMethodsReading.kt new file mode 100644 index 0000000000..94a5f7134c --- /dev/null +++ b/usvm-jvm-mocks/src/main/kotlin/machine/memory/JcMockedMethodsReading.kt @@ -0,0 +1,47 @@ +package machine.memory + +import io.ksmt.cache.hash +import io.ksmt.cache.structurallyEqual +import io.ksmt.expr.printer.ExpressionPrinter +import io.ksmt.expr.transformer.KTransformerBase +import machine.JcMocksTransformer +import org.jacodb.api.jvm.JcMethod +import org.jacodb.api.jvm.JcType +import org.usvm.UContext +import org.usvm.UExpr +import org.usvm.USort +import org.usvm.USymbol + +/** + * JcMockedMethodsReading represents an expression that is returned by the read() + * if there's no required JcMockedMethodsValue. + */ +class JcMockedMethodsReading internal constructor( + ctx: UContext<*>, + val regionId: JcMockedMethodsRegionId, + val mockedMethod: JcMockedMethod, + val type: JcType, + val method: JcMethod, + override val sort: Sort +) : USymbol(ctx) { + override fun accept(transformer: KTransformerBase): UExpr { + require(transformer is JcMocksTransformer) { "Expected a JcMocksTransformer, but got: $transformer" } + return transformer.transform(this) + } + + override fun internEquals(other: Any): Boolean = structurallyEqual( + other, + { regionId }, + { mockedMethod }, + { sort } + ) + + override fun internHashCode(): Int = hash(regionId, mockedMethod, sort) + + override fun print(printer: ExpressionPrinter) { + printer.append(regionId.toString()) + printer.append("[") + printer.append(mockedMethod.toString()) + printer.append("]") + } +} diff --git a/usvm-jvm-mocks/src/main/kotlin/machine/memory/JcMockedMethodsRegion.kt b/usvm-jvm-mocks/src/main/kotlin/machine/memory/JcMockedMethodsRegion.kt new file mode 100644 index 0000000000..7d81d462ce --- /dev/null +++ b/usvm-jvm-mocks/src/main/kotlin/machine/memory/JcMockedMethodsRegion.kt @@ -0,0 +1,91 @@ +package machine.memory + +import org.jacodb.api.jvm.JcMethod +import org.jacodb.api.jvm.JcType +import org.usvm.UBoolExpr +import org.usvm.UExpr +import org.usvm.USort +import org.usvm.collections.immutable.getOrDefault +import org.usvm.collections.immutable.implementations.immutableMap.UPersistentHashMap +import org.usvm.collections.immutable.internal.MutabilityOwnership +import org.usvm.collections.immutable.persistentHashMapOf +import org.usvm.machine.jctx +import org.usvm.memory.ULValue +import org.usvm.memory.UMemoryRegion +import org.usvm.memory.UMemoryRegionId +import org.usvm.memory.guardedWrite +import org.usvm.sampleUValue + +/** + * JcMockedMethod represents a mocked method's metadata that is kept inside memory region. + */ +class JcMockedMethod( + val method: String, + val enclosingClass: String +) { + val str = enclosingClass + method +} + +/** + * JcMockedMethodsValue contains all the information about the method that is stored in the memory region. + */ +data class JcMockedMethodsValue( + val mockedMethod: JcMockedMethod, + override val sort: Sort, + val type: JcType, + val method: JcMethod +) : ULValue, Sort> { + + override val memoryRegionId: UMemoryRegionId, Sort> = JcMockedMethodsRegionId(sort, type) + + override val key: JcMockedMethodsValue + get() = this +} + +/** + * JcMockedMethodsRegionId allows to quickly differentiate between memory regions. + */ +data class JcMockedMethodsRegionId( + override val sort: Sort, + val type: JcType +) : UMemoryRegionId, Sort> { + override fun emptyRegion(): UMemoryRegion, Sort> = JcMockedMethodsRegion(sort, type) +} + +/** + * JcMockedMethodsRegion represents a slice of memory that tracks mocked methods' return values. + */ +open class JcMockedMethodsRegion( + private val sort: Sort, + private val type: JcType, + private val mockedMethods: UPersistentHashMap>> = persistentHashMapOf() +) : UMemoryRegion, Sort> { + /** + * read() allows to read JcMockedMethodsValue from memory. + * If there's no such JcMockedMethodsValue, it returns a new symbolic expression for it. + */ + override fun read(key: JcMockedMethodsValue): UExpr { + val mockedMethod = key.mockedMethod + val field = mockedMethod.method + val ret = mockedMethods[mockedMethod.enclosingClass]?.get(field) + return ret ?: JcMockedMethodsReading(sort.jctx, key.memoryRegionId as JcMockedMethodsRegionId, mockedMethod, type, key.method, sort) + } + + /** + * write() allows to record a specific value. It is not used in the code but is required to be overridden. + */ + override fun write( + key: JcMockedMethodsValue, + value: UExpr, + guard: UBoolExpr, + ownership: MutabilityOwnership + ): UMemoryRegion, Sort> { + val mockedMethod = key.mockedMethod + val enclosingClass = mockedMethod.enclosingClass + val classFields = mockedMethods.getOrDefault(enclosingClass, persistentHashMapOf()) + + val newFieldValues = classFields.guardedWrite(key.mockedMethod.method, value, guard, ownership) { key.sort.sampleUValue() } + val newFieldsByClass = mockedMethods.put(enclosingClass, newFieldValues, ownership) + return JcMockedMethodsRegion(sort, type, newFieldsByClass) + } +} diff --git a/usvm-jvm-mocks/src/main/kotlin/machine/render/ConfigInfoRenderer.kt b/usvm-jvm-mocks/src/main/kotlin/machine/render/ConfigInfoRenderer.kt new file mode 100644 index 0000000000..cffe0e8254 --- /dev/null +++ b/usvm-jvm-mocks/src/main/kotlin/machine/render/ConfigInfoRenderer.kt @@ -0,0 +1,109 @@ +package machine.render + +import com.github.javaparser.ast.body.MethodDeclaration +import com.github.javaparser.ast.expr.AnnotationExpr +import com.github.javaparser.ast.expr.SimpleName +import machine.instructions.UTestMockConfigInfo +import org.jacodb.api.jvm.JcClasspath +import org.usvm.jvm.rendering.baseRenderer.JcIdentifiersManager +import org.usvm.jvm.rendering.baseRenderer.JcImportManager +import org.usvm.jvm.rendering.testRenderer.JcTestVisitor +import org.usvm.jvm.rendering.unsafeRenderer.JcUnsafeTestClassRenderer +import org.usvm.jvm.rendering.unsafeRenderer.JcUnsafeTestRenderer +import org.usvm.jvm.rendering.unsafeRenderer.JcUnsafeUtilsRenderer +import org.usvm.test.api.UTest +import org.usvm.test.api.UTestAllocateMemoryCall +import org.usvm.test.api.UTestExpression +import org.usvm.test.api.UTestInst +import org.usvm.test.api.UTestMockInst + +/** + * ConfigInfoRenderer inherits JcUnsafeTestRenderer and handles rendering Mockito.when(..).thenReturn(..) Java code. + * It uses JcExprUsageVisitor for the purpose of creating a variable if some value is repeatedly used in the configuration. + */ +class ConfigInfoRenderer( + private val mockConfigInfo: UTestMockConfigInfo, + test: UTest, + classRenderer: JcUnsafeTestClassRenderer, + importManager: JcImportManager, + identifiersManager: JcIdentifiersManager, + cp: JcClasspath, + name: SimpleName, + annotations: List, + unsafeUtilsRenderer: JcUnsafeUtilsRenderer +) : JcUnsafeTestRenderer( + test, + classRenderer, + importManager, + identifiersManager, + cp, + name, + annotations, + unsafeUtilsRenderer +) { + inner class JcExprUsageVisitor : JcTestVisitor() { + private fun shouldDeclareVarCheck(expr: UTestExpression): Boolean { + return !preventVarDeclarationOf(expr) && isVisited(expr) || requireVarDeclarationOf(expr) + } + override fun visitExpr(expr: UTestExpression) { + if (shouldDeclareVarCheck(expr)) { + shouldDeclareVar.add(expr) + } + + super.visitExpr(expr) + } + fun visit(instructions: List) { + for (inst in instructions) { + visit(inst) + } + } + } + + init { + val instructions = mockConfigInfo.instructions.map { it.first } + JcExprUsageVisitor().visit(instructions) + } + + private fun getVarsNum(): Set { + return shouldDeclareVar + } + + /** + * Renders every instruction (every mock configuration) from UTestMockInst. + */ + override fun renderInternal(): MethodDeclaration { + val instructions = mockConfigInfo.instructions + for (inst in instructions) { + body.renderInst(inst.first) + } + return super.renderInternal() + } + + /** + * Detects if method under test throws Exception. + */ + fun ifThrowsInstantiationException(): Boolean { + for (inst in mockConfigInfo.instructions) { + if (inst.first is UTestMockInst && (inst.first as UTestMockInst).instance is UTestAllocateMemoryCall) { + return true + } + } + return false + } + + /** + * Renders comments for configuration with metadata. + */ + fun renderConfigInfo(): List { + val instructions = mockConfigInfo.instructions + val vars = getVarsNum() + val lines = mutableListOf() + for (inst in instructions) { + if (inst.first is UTestMockInst && (inst.first as UTestMockInst).instance in vars) { + lines.add("\n") + } + lines.add(inst.second + "\n") + } + return lines + } +} diff --git a/usvm-jvm-mocks/src/main/kotlin/machine/render/RenderMockConfigInfo.kt b/usvm-jvm-mocks/src/main/kotlin/machine/render/RenderMockConfigInfo.kt new file mode 100644 index 0000000000..cafab0a37a --- /dev/null +++ b/usvm-jvm-mocks/src/main/kotlin/machine/render/RenderMockConfigInfo.kt @@ -0,0 +1,98 @@ +package machine.render + +import com.github.javaparser.printer.DefaultPrettyPrinter +import machine.getLineNumber +import machine.instructions.UTestMockConfigInfo +import machine.varnamesMap +import org.jacodb.api.jvm.JcClasspath +import org.usvm.jvm.rendering.ReflectionUtilsInlineStrategy +import org.usvm.jvm.rendering.baseRenderer.JcIdentifiersManager +import org.usvm.jvm.rendering.baseRenderer.JcImportManager +import org.usvm.jvm.rendering.unsafeRenderer.JcUnsafeTestClassRenderer +import org.usvm.jvm.rendering.unsafeRenderer.JcUnsafeUtilsRenderer +import org.usvm.test.api.UTest + +/** + * renderMockConfigInfo() takes UTestMockConfigInfo, creates an instance of ConfigInfoRenderer and handles rendering, + * returning Java-code of configurations and an indicator whether method under test throws Instantiation Exception. + */ +fun renderMockConfigInfo(cp: JcClasspath, test: UTest, mockConfigInfo: UTestMockConfigInfo): Pair { + val importManager = JcImportManager() + val identifiersManager = JcIdentifiersManager() + val strategy = ReflectionUtilsInlineStrategy.NoInline() + val utilsRenderer = JcUnsafeUtilsRenderer(importManager, strategy) + val testClassRenderer = JcUnsafeTestClassRenderer( + "name", + importManager, + identifiersManager, + cp, + utilsRenderer + ) + val testRenderer = ConfigInfoRenderer(mockConfigInfo, test, testClassRenderer, importManager, identifiersManager, cp, identifiersManager["test"], listOf(), utilsRenderer) + val res = testRenderer.render() + val printer = DefaultPrettyPrinter() + val text = printer.print(res) + val comments = testRenderer.renderConfigInfo() + val isExceptional = testRenderer.ifThrowsInstantiationException() + return Pair(modifyText(text, reorder(comments)), isExceptional) +} + +/** + * Here we combine mocks configuration code respectively with the comments contains metadata. + */ +fun modifyText(text: String, comments: List): String { + val lines = text.split("\n") as MutableList + lines.removeAt(lines.lastIndex) + lines.removeAt(lines.lastIndex) + lines.removeAt(0) + var finalText = "" + var varname = "" + for (i in 0 until lines.size) { + if (comments[i] != "\n") { + val lineNumber = getLineNumber(comments[i]) + varname = varnamesMap[lineNumber].toString() + val alteredLine = alter(lines[i], varname) + finalText = finalText + "//" + comments[i] + alteredLine + "\n" + } else { + val alteredLine = alter(lines[i], varname) + finalText = finalText + alteredLine + "\n" + } + } + return finalText +} + +/** + * This is a helper function for reordering comment lines to achieve a proper text order. + */ +fun reorder(lines: List): List { + val result = mutableListOf() + val newlines = mutableListOf() + + for (line in lines) { + if (line == "\n") { + newlines.add(line) + } else { + result.add(line) + result.addAll(newlines) + newlines.clear() + } + } + result.addAll(newlines) + return result +} + +/** + * In rendered code there's a slight problem: + * Mockito.when(classname.method(args) -- it should it fact be a variable name, not a classname. + * This function fixes the problem, given the line and the variable name. + */ +fun alter(line: String, varname: String): String { + if (line.contains("Mockito.when")) { + val retValue = line.substringAfter(".thenReturn") + val str = line.substringBefore(").thenReturn").substringAfter("Mockito.when(") + val method = str.substringAfter(".class") + val res = "Mockito.when($varname$method).thenReturn$retValue" + return res + } + return line.trim() +} diff --git a/usvm-jvm-mocks/src/main/kotlin/machine/render/TestOutputHandler.kt b/usvm-jvm-mocks/src/main/kotlin/machine/render/TestOutputHandler.kt new file mode 100644 index 0000000000..8006e04c5d --- /dev/null +++ b/usvm-jvm-mocks/src/main/kotlin/machine/render/TestOutputHandler.kt @@ -0,0 +1,68 @@ +package machine.render + +import java.io.File + +/** + * Saves Java code as a .java file to testOutput directory. + */ +fun renderTestFile(ifThrows: Boolean, name: String, type: String, pathToFile: String, mockConfigInfo: String) { + val fileLines = File(pathToFile).readLines().toMutableList() + fileLines[0] = "package org.usvm.samples.testOutput;" + var methodUnderTestLine = -1 + if (ifThrows) { + for (i in 0 until fileLines.size) { + if (fileLines[i].contains(type) && fileLines[i].contains(name)) { methodUnderTestLine = i } + } + } + val fileText = renderFinalFileText(methodUnderTestLine, fileLines, mockConfigInfo) + val fileName = pathToFile.substringAfterLast("/") + val path = pathToFile.substringBeforeLast("/") + val newPath = "$path/testOutput/$fileName" + File(newPath).writeText(fileText) +} + +/** + * Combines configurations and initial test file, returning ready-to-use Java code for the test. + */ +fun renderFinalFileText(methodLineNum: Int, fileLines: List, mockConfigInfo: String): String { + var fileText = "" + var firstLine = 0 + val configBlocks = mockConfigInfo.split("//") as MutableList + configBlocks.removeAt(0) + val data = mutableListOf>() + for (block in configBlocks) { + val lineNumber = getMethodLineNumber(block) + data.add(Pair(block, lineNumber)) + } + val sortedBlocks = (data.sortedBy { it.second }).map { it.first } + for (block in sortedBlocks) { + val lineNumber = getMethodLineNumber(block) + for (i in firstLine until lineNumber - 1) { + if (i == methodLineNum) { + val before = fileLines[i].substringBefore(")") + val after = fileLines[i].substringAfter(")") + fileText += "$before) throws InstantiationException$after\n" + continue + } + fileText += fileLines[i] + "\n" + } + firstLine = lineNumber - 1 + val prefix = fileLines[lineNumber - 1].takeWhile { it == ' ' || it == '\t' } + val blockLines = block.split("\n") + for (i in 1 until blockLines.size) { + fileText = fileText + prefix + blockLines[i] + "\n" + } + } + for (i in firstLine until fileLines.size) { + fileText += fileLines[i] + "\n" + } + return fileText +} + +/** + * Given a line with metadata, it parses it and finds the line, on which mock object called a certain method. + */ +fun getMethodLineNumber(str: String): Int { + val commentLine = str.substringBefore("\n") + return commentLine.substringAfterLast("(line:").substringBeforeLast(")").toInt() +} diff --git a/usvm-jvm-mocks/src/samples/java/org/usvm/samples/TestABC.java b/usvm-jvm-mocks/src/samples/java/org/usvm/samples/TestABC.java new file mode 100644 index 0000000000..3a52c6dac5 --- /dev/null +++ b/usvm-jvm-mocks/src/samples/java/org/usvm/samples/TestABC.java @@ -0,0 +1,33 @@ +package org.usvm.samples; + +import org.mockito.Mockito; + +interface A { + int foo(int x); +} +interface B { + int bar(int x, int y); +} +interface C { + int baz(int x); +} + +public class TestABC { + void compute(int a, int b, int c) { + A mA = Mockito.mock(A.class); + B mB = Mockito.mock(B.class); + C mC = Mockito.mock(C.class); + + int r1 = mA.foo(a); + int r2 = mA.foo(r1 + b); + + int r3 = mB.bar(r1, r2); + int r4 = mC.baz(r3 + c); + + + assert(r1 == 3); + assert(r2 == r1 + b + 2); + assert(r3 == r1 * r2); + assert(r4 == r3 - 5); + } +} diff --git a/usvm-jvm-mocks/src/samples/java/org/usvm/samples/TestAdder.java b/usvm-jvm-mocks/src/samples/java/org/usvm/samples/TestAdder.java new file mode 100644 index 0000000000..5ecdb05192 --- /dev/null +++ b/usvm-jvm-mocks/src/samples/java/org/usvm/samples/TestAdder.java @@ -0,0 +1,18 @@ +package org.usvm.samples; + +import org.mockito.Mockito; + +interface Adder { + int add(int a, int b); +} +public class TestAdder { + public void compute(int a, int b) { + Adder adder = Mockito.mock(Adder.class); + Adder adder2 = Mockito.mock(Adder.class); + + int res = adder.add(a, b); + int res2 = adder2.add(a, b); + assert (res == 3); + assert (res2 == 4); + } +} diff --git a/usvm-jvm-mocks/src/samples/java/org/usvm/samples/TestCalc.java b/usvm-jvm-mocks/src/samples/java/org/usvm/samples/TestCalc.java new file mode 100644 index 0000000000..d623c6535f --- /dev/null +++ b/usvm-jvm-mocks/src/samples/java/org/usvm/samples/TestCalc.java @@ -0,0 +1,34 @@ +package org.usvm.samples; + +import org.mockito.Mockito; + +interface Count { + public int f(int x); + public int g(int x); +} + +class TestCalc { + void compute(int a, int b) { + Count s1 = Mockito.mock(Count.class); // g -> 16, f -> 5 + Count s2 = Mockito.mock(Count.class); // f -> 0, g -> 7 + + int r1 = s1.f(a); + int r2 = s2.g(b); + + int sum = r1 + r2; + + if (a > 0) { + int r3 = s1.g(sum); + assert(r3 == sum + 10); + } else { + int r4 = s2.g(sum); + assert(r4 == sum - 5); + } + + assert(sum == 12); + assert(r1 == 5); + assert(r2 == 7); + } +} + + diff --git a/usvm-jvm-mocks/src/samples/java/org/usvm/samples/TestData.java b/usvm-jvm-mocks/src/samples/java/org/usvm/samples/TestData.java new file mode 100644 index 0000000000..d5e0612e7f --- /dev/null +++ b/usvm-jvm-mocks/src/samples/java/org/usvm/samples/TestData.java @@ -0,0 +1,43 @@ +package org.usvm.samples; + +import org.mockito.Mockito; + +interface Cache { + Data load(String key); + boolean exists(String key); +} + +class Data { + public final int value; + public final String meta; + public int getValue() { + return value; + } + public String getMeta() { + return meta; + } + public Data(int value, String meta) { + this.value = value; + this.meta = meta; + } +} + +public class TestData { + public void compute(String k1, String k2) { + Cache c1 = Mockito.mock(Cache.class); + Cache c2 = Mockito.mock(Cache.class); + + Data d1 = c1.load(k1); + Data d2 = c2.load(k2); + + boolean e1 = c1.exists(k1); + boolean e2 = c2.exists(k2); + + assert d1.getValue() + d2.getValue() == 10; + assert e1 != e2; + assert d1.getMeta().equals("x"); + assert d2.getMeta().equals("y"); +} + +} + diff --git a/usvm-jvm-mocks/src/samples/java/org/usvm/samples/TestService.java b/usvm-jvm-mocks/src/samples/java/org/usvm/samples/TestService.java new file mode 100644 index 0000000000..119df578a6 --- /dev/null +++ b/usvm-jvm-mocks/src/samples/java/org/usvm/samples/TestService.java @@ -0,0 +1,50 @@ +package org.usvm.samples; + +import org.mockito.Mockito; + +import java.util.Arrays; +import java.util.Objects; + + +interface Service { + String getName(int id); + Profile getProfile(int id); + String [] getTags(); + int getAge(int id); +} + +class Profile { + public final String name; + public final int ssn; + public int getSsn() { + return ssn; + } + public String getName() { + return name; + } + + public Profile(String name, int ssn) { + this.name = name; + this.ssn = ssn; + } +} + +public class TestService { + public void compute(int id) { + Service s = Mockito.mock(Service.class); + + String [] tags = s.getTags(); + Profile profile = s.getProfile(id); + String name = s.getName(id); + String name2 = s.getName(2); + int age = s.getAge(id); + int ssn = profile.getSsn(); + + assert name2.equals("Alice"); + String [] tags1 = {"a", "b", "c"}; + assert Arrays.equals(tags, tags1); + assert age > 18 && age < 45; + assert ssn > 9999 && ssn < 1000000; + assert Objects.equals(name, "Bob"); + } +} diff --git a/usvm-jvm-mocks/src/samples/java/org/usvm/samples/testOutput/ReflectionUtils.java b/usvm-jvm-mocks/src/samples/java/org/usvm/samples/testOutput/ReflectionUtils.java new file mode 100644 index 0000000000..c420e9aabe --- /dev/null +++ b/usvm-jvm-mocks/src/samples/java/org/usvm/samples/testOutput/ReflectionUtils.java @@ -0,0 +1,288 @@ +package org.usvm.samples.testOutput; + +import sun.misc.Unsafe; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.ArrayList; +import java.util.List; + +@SuppressWarnings({"removal", "deprecation"}) +public class ReflectionUtils { + private static final Unsafe UNSAFE; + + static { + try { + Field uns = Unsafe.class.getDeclaredField("theUnsafe"); + uns.setAccessible(true); + UNSAFE = (Unsafe) uns.get(null); + } catch (Throwable e) { + throw new RuntimeException(); + } + } + + //region Fields Interaction + + private static long getOffsetOf(Field field) { + return isStatic(field) ? UNSAFE.staticFieldOffset(field) : UNSAFE.objectFieldOffset(field); + } + + private static boolean isStatic(Field field) { + return (field.getModifiers() & Modifier.STATIC) > 0; + } + + private static List getInstanceFields(Class type) { + ArrayList fields = new ArrayList<>(); + for (Field field : type.getDeclaredFields()) { + if (!Modifier.isStatic(field.getModifiers())) + fields.add(field); + } + + return fields; + } + + private static List getStaticFields(Class type) { + ArrayList fields = new ArrayList<>(); + for (Field field : type.getDeclaredFields()) { + if (Modifier.isStatic(field.getModifiers())) + fields.add(field); + } + + return fields; + } + + private static Field getField(Object instance, String fieldName) { + Class type = instance.getClass(); + Class currentClass = type; + while (currentClass != Object.class && currentClass != null) { + for (Field field : getInstanceFields(currentClass)) { + if (field.getName().equals(fieldName)) + return field; + } + currentClass = currentClass.getSuperclass(); + } + + throw new IllegalArgumentException("Could not find field " + fieldName + " in " + type); + } + + private static Field getStaticField(Class type, String fieldName) { + for (Field field : getStaticFields(type)) { + if (field.getName().equals(fieldName)) + return field; + } + + throw new IllegalArgumentException("Could not find static field " + fieldName + " in " + type); + } + + private static Object getFieldValue(Object fixedInstance, Field field) { + long fieldOffset = getOffsetOf(field); + + if (!field.getType().isPrimitive()) { + return UNSAFE.getObject(fixedInstance, fieldOffset); + } + + if (field.getType() == boolean.class) { + return UNSAFE.getBoolean(fixedInstance, fieldOffset); + } else if (field.getType() == byte.class) { + return UNSAFE.getByte(fixedInstance, fieldOffset); + } else if (field.getType() == char.class) { + return UNSAFE.getChar(fixedInstance, fieldOffset); + } else if (field.getType() == short.class) { + return UNSAFE.getShort(fixedInstance, fieldOffset); + } else if (field.getType() == int.class) { + return UNSAFE.getInt(fixedInstance, fieldOffset); + } else if (field.getType() == long.class) { + return UNSAFE.getLong(fixedInstance, fieldOffset); + } else if (field.getType() == float.class) { + return UNSAFE.getFloat(fixedInstance, fieldOffset); + } else if (field.getType() == double.class) { + return UNSAFE.getDouble(fixedInstance, fieldOffset); + } + + throw new IllegalStateException("unexpected primitive type"); + } + + @SuppressWarnings("unchecked") + public static T getFieldValue(Object instance, String fieldName) { + Field field = getField(instance, fieldName); + return (T) getFieldValue(instance, field); + } + + @SuppressWarnings("unchecked") + public static T getStaticFieldValue(Class type, String fieldName) { + Field field = getStaticField(type, fieldName); + return (T) getFieldValue(UNSAFE.staticFieldBase(field), field); + } + + private static void setFieldValue(Object fixedInstance, Field field, Object value) { + long fieldOffset = getOffsetOf(field); + + if (!field.getType().isPrimitive()) { + UNSAFE.putObject(fixedInstance, fieldOffset, value); + return; + } + + if (field.getType() == boolean.class) { + UNSAFE.putBoolean(fixedInstance, fieldOffset, value != null && ((boolean) value)); + } else if (field.getType() == byte.class) { + UNSAFE.putByte(fixedInstance, fieldOffset, value != null ? (byte) value : 0); + } else if (field.getType() == char.class) { + UNSAFE.putChar(fixedInstance, fieldOffset, value != null ? (char) value : '\u0000'); + } else if (field.getType() == short.class) { + UNSAFE.putShort(fixedInstance, fieldOffset, value != null ? (short) value : 0); + } else if (field.getType() == int.class) { + UNSAFE.putInt(fixedInstance, fieldOffset, value != null ? (int) value : 0); + } else if (field.getType() == long.class) { + UNSAFE.putLong(fixedInstance, fieldOffset, value != null ? (long) value : 0); + } else if (field.getType() == float.class) { + UNSAFE.putFloat(fixedInstance, fieldOffset, value != null ? (float) value : 0.0f); + } else if (field.getType() == double.class) { + UNSAFE.putDouble(fixedInstance, fieldOffset, value != null ? (double) value : 0.0); + } + } + + public static void setFieldValue(Object instance, String fieldName, Object value) { + Field field = getField(instance, fieldName); + setFieldValue(instance, field, value); + } + + public static void setStaticFieldValue(Class type, String fieldName, Object value) { + Field field = getStaticField(type, fieldName); + setFieldValue(UNSAFE.staticFieldBase(field), field, value); + } + + public static Throwable getRootCause(Throwable exception) { + Throwable result = exception; + while (true) { + Throwable cause = result.getCause(); + if (cause == null || cause == result) return result; + result = cause; + } + } + + //endregion + + //region Allocation + + @SuppressWarnings("unchecked") + public static T allocateInstance(Class clazz) throws InstantiationException { + return (T) UNSAFE.allocateInstance(clazz); + } + + //endregion + + //region Methods Interaction + + private static String parameterTypesSignature(Class[] types) { + if (types.length == 0) + return ""; + + StringBuilder sb = new StringBuilder(); + for (Class type : types) { + sb.append(type.getTypeName()); + sb.append(";"); + } + + return sb.toString(); + } + + private static String methodSignature(Method method) { + String parametersSig = parameterTypesSignature(method.getParameterTypes()); + String returnTypeSig = method.getReturnType().getTypeName(); + return method.getName() + "(" + parametersSig + ")" + returnTypeSig + ";"; + } + + private static String methodSignature(Constructor method) { + String parametersSig = parameterTypesSignature(method.getParameterTypes()); + return "" + "(" + parametersSig + ")" + "void" + ";"; + } + + private static List getInstanceMethods(Class type) { + ArrayList methods = new ArrayList<>(); + for (Method method : type.getDeclaredMethods()) { + if (!Modifier.isStatic(method.getModifiers())) + methods.add(method); + } + + return methods; + } + + private static List getStaticMethods(Class type) { + ArrayList methods = new ArrayList<>(); + for (Method method : type.getDeclaredMethods()) { + if (Modifier.isStatic(method.getModifiers())) + methods.add(method); + } + + return methods; + } + + private static Method getMethod(Object instance, String methodSig) { + Class type = instance.getClass(); + Class currentClass = type; + while (currentClass != Object.class && currentClass != null) { + for (Method method : getInstanceMethods(currentClass)) { + if (methodSignature(method).equals(methodSig)) + return method; + } + currentClass = currentClass.getSuperclass(); + } + + throw new IllegalArgumentException("Could not find method " + methodSig + " in " + type); + } + + private static Method getStaticMethod(Class type, String methodSig) { + for (Method method : getStaticMethods(type)) { + if (methodSignature(method).equals(methodSig)) + return method; + } + + throw new IllegalArgumentException("Could not find static method " + methodSig + " in " + type); + } + + private static Constructor getConstructor(Class type, String ctorSig) { + for (Constructor ctor : type.getDeclaredConstructors()) { + if (methodSignature(ctor).equals(ctorSig)) + return ctor; + } + + throw new IllegalArgumentException("Could not find constructor " + ctorSig + " in " + type); + } + + private static Object callMethod(Object instance, Method method, Object... args) throws Throwable { + Object[] checkedArgs = (args == null) ? new Object[] { null } : args; + try { + method.setAccessible(true); + return method.invoke(instance, checkedArgs); + } catch (InvocationTargetException e) { + throw e.getTargetException(); + } + } + + @SuppressWarnings("unchecked") + public static T callMethod(Object instance, String methodSig, Object... args) throws Throwable { + return (T) callMethod(instance, getMethod(instance, methodSig), args); + } + + @SuppressWarnings("unchecked") + public static T callStaticMethod(Class type, String methodSig, Object... args) throws Throwable { + return (T) callMethod(null, getStaticMethod(type, methodSig), args); + } + + @SuppressWarnings("unchecked") + public static T callConstructor(Class type, String ctorSig, Object... args) throws Throwable { + Constructor ctor = getConstructor(type, ctorSig); + Object[] checkedArgs = (args == null) ? new Object[] { null } : args; + try { + ctor.setAccessible(true); + return (T) ctor.newInstance(checkedArgs); + } catch (InvocationTargetException e) { + throw e.getTargetException(); + } + } + + //endregion +} diff --git a/usvm-jvm-mocks/src/samples/java/org/usvm/samples/testOutput/TestABC.java b/usvm-jvm-mocks/src/samples/java/org/usvm/samples/testOutput/TestABC.java new file mode 100644 index 0000000000..a82b04e489 --- /dev/null +++ b/usvm-jvm-mocks/src/samples/java/org/usvm/samples/testOutput/TestABC.java @@ -0,0 +1,41 @@ +package org.usvm.samples.testOutput; + +import org.mockito.Mockito; + +interface A { + int foo(int x); +} +interface B { + int bar(int x, int y); +} +interface C { + int baz(int x); +} + +public class TestABC { + void compute(int a, int b, int c) { + A mA = Mockito.mock(A.class); + B mB = Mockito.mock(B.class); + C mC = Mockito.mock(C.class); + + Mockito.when(mA.foo(0)).thenReturn(3); + + int r1 = mA.foo(a); + Mockito.when(mA.foo(0)).thenReturn(5); + + int r2 = mA.foo(r1 + b); + + Mockito.when(mB.bar(0, 0)).thenReturn(15); + + int r3 = mB.bar(r1, r2); + Mockito.when(mC.baz(0)).thenReturn(10); + + int r4 = mC.baz(r3 + c); + + + assert(r1 == 3); + assert(r2 == r1 + b + 2); + assert(r3 == r1 * r2); + assert(r4 == r3 - 5); + } +} diff --git a/usvm-jvm-mocks/src/samples/java/org/usvm/samples/testOutput/TestAdder.java b/usvm-jvm-mocks/src/samples/java/org/usvm/samples/testOutput/TestAdder.java new file mode 100644 index 0000000000..9e58d5108b --- /dev/null +++ b/usvm-jvm-mocks/src/samples/java/org/usvm/samples/testOutput/TestAdder.java @@ -0,0 +1,22 @@ +package org.usvm.samples.testOutput; + +import org.mockito.Mockito; + +interface Adder { + int add(int a, int b); +} +public class TestAdder { + public void compute(int a, int b) { + Adder adder = Mockito.mock(Adder.class); + Adder adder2 = Mockito.mock(Adder.class); + + Mockito.when(adder.add(0, 0)).thenReturn(3); + + int res = adder.add(a, b); + Mockito.when(adder2.add(0, 0)).thenReturn(4); + + int res2 = adder2.add(a, b); + assert (res == 3); + assert (res2 == 4); + } +} diff --git a/usvm-jvm-mocks/src/samples/java/org/usvm/samples/testOutput/TestCalc.java b/usvm-jvm-mocks/src/samples/java/org/usvm/samples/testOutput/TestCalc.java new file mode 100644 index 0000000000..2c764aad2d --- /dev/null +++ b/usvm-jvm-mocks/src/samples/java/org/usvm/samples/testOutput/TestCalc.java @@ -0,0 +1,42 @@ +package org.usvm.samples.testOutput; + +import org.mockito.Mockito; + +interface Count { + public int f(int x); + public int g(int x); +} + +class TestCalc { + void compute(int a, int b) { + Count s1 = Mockito.mock(Count.class); // g -> 16, f -> 5 + Count s2 = Mockito.mock(Count.class); // f -> 0, g -> 7 + + Mockito.when(s1.f(1)).thenReturn(5); + + int r1 = s1.f(a); + Mockito.when(s2.g(1)).thenReturn(7); + + int r2 = s2.g(b); + + int sum = r1 + r2; + + if (a > 0) { + Mockito.when(s1.g(1)).thenReturn(22); + + int r3 = s1.g(sum); + assert(r3 == sum + 10); + } else { + Mockito.when(s2.g(1)).thenReturn(0); + + int r4 = s2.g(sum); + assert(r4 == sum - 5); + } + + assert(sum == 12); + assert(r1 == 5); + assert(r2 == 7); + } +} + + diff --git a/usvm-jvm-mocks/src/samples/java/org/usvm/samples/testOutput/TestData.java b/usvm-jvm-mocks/src/samples/java/org/usvm/samples/testOutput/TestData.java new file mode 100644 index 0000000000..223eda050c --- /dev/null +++ b/usvm-jvm-mocks/src/samples/java/org/usvm/samples/testOutput/TestData.java @@ -0,0 +1,65 @@ +package org.usvm.samples.testOutput; + +import org.mockito.Mockito; + +interface Cache { + Data load(String key); + boolean exists(String key); +} + +class Data { + public final int value; + public final String meta; + public int getValue() { + return value; + } + public String getMeta() { + return meta; + } + public Data(int value, String meta) { + this.value = value; + this.meta = meta; + } +} + +public class TestData { + public void compute(String k1, String k2) throws InstantiationException { + Cache c1 = Mockito.mock(Cache.class); + Cache c2 = Mockito.mock(Cache.class); + + Data v = ReflectionUtils.allocateInstance(Data.class); + ReflectionUtils.setFieldValue(v, "value", 0); + ReflectionUtils.setFieldValue(v, "meta", null); + Mockito.when(c1.load("")).thenReturn(v); + + Data d1 = c1.load(k1); + Data v1 = ReflectionUtils.allocateInstance(Data.class); + ReflectionUtils.setFieldValue(v1, "value", 0); + ReflectionUtils.setFieldValue(v1, "meta", null); + Mockito.when(c2.load("")).thenReturn(v1); + + Data d2 = c2.load(k2); + + Mockito.when(c1.exists("")).thenReturn(false); + + boolean e1 = c1.exists(k1); + Mockito.when(c2.exists("")).thenReturn(true); + + boolean e2 = c2.exists(k2); + + Mockito.when(d1.getValue()).thenReturn(15); + + Mockito.when(d2.getValue()).thenReturn(-5); + + assert d1.getValue() + d2.getValue() == 10; + assert e1 != e2; + Mockito.when(d1.getMeta()).thenReturn("x"); + + assert d1.getMeta().equals("x"); + Mockito.when(d2.getMeta()).thenReturn("y"); + + assert d2.getMeta().equals("y"); +} + +} + diff --git a/usvm-jvm-mocks/src/samples/java/org/usvm/samples/testOutput/TestService.java b/usvm-jvm-mocks/src/samples/java/org/usvm/samples/testOutput/TestService.java new file mode 100644 index 0000000000..c0955fd7cc --- /dev/null +++ b/usvm-jvm-mocks/src/samples/java/org/usvm/samples/testOutput/TestService.java @@ -0,0 +1,69 @@ +package org.usvm.samples.testOutput; + +import org.mockito.Mockito; + +import java.util.Arrays; +import java.util.Objects; + + +interface Service { + String getName(int id); + Profile getProfile(int id); + String [] getTags(); + int getAge(int id); +} + +class Profile { + public final String name; + public final int ssn; + public int getSsn() { + return ssn; + } + public String getName() { + return name; + } + + public Profile(String name, int ssn) { + this.name = name; + this.ssn = ssn; + } +} + +public class TestService { + public void compute(int id) throws InstantiationException { + Service s = Mockito.mock(Service.class); + + String[] v = new String[3]; + v[0] = "a"; + v[1] = "b"; + v[2] = "c"; + Mockito.when(s.getTags()).thenReturn(v); + + String [] tags = s.getTags(); + Profile v1 = ReflectionUtils.allocateInstance(Profile.class); + ReflectionUtils.setFieldValue(v1, "name", null); + ReflectionUtils.setFieldValue(v1, "ssn", 0); + Mockito.when(s.getProfile(0)).thenReturn(v1); + + Profile profile = s.getProfile(id); + Mockito.when(s.getName(0)).thenReturn("Bob"); + + String name = s.getName(id); + Mockito.when(s.getName(0)).thenReturn("Alice"); + + String name2 = s.getName(2); + Mockito.when(s.getAge(0)).thenReturn(29); + + int age = s.getAge(id); + Mockito.when(profile.getSsn()).thenReturn(475776); + + int ssn = profile.getSsn(); + + assert name2.equals("Alice"); + String [] tags1 = {"a", "b", "c"}; + assert Arrays.equals(tags, tags1); + assert age > 18 && age < 45; + assert ssn > 9999 && ssn < 1000000; + assert Objects.equals(name, "Bob"); + } +} diff --git a/usvm-jvm-mocks/src/test/kotlin/org/usvm/samples/ABCTest.kt b/usvm-jvm-mocks/src/test/kotlin/org/usvm/samples/ABCTest.kt new file mode 100644 index 0000000000..5fa5236664 --- /dev/null +++ b/usvm-jvm-mocks/src/test/kotlin/org/usvm/samples/ABCTest.kt @@ -0,0 +1,21 @@ +package org.usvm.samples + +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.usvm.test.util.checkers.ignoreNumberOfAnalysisResults + +class ABCTest : MocksTestRunner() { + @BeforeEach + fun reset() { + cleanUp() + } + + @Test + fun testCompute() { + checkDiscoveredPropertiesWithExceptions( + TestABC::compute, + ignoreNumberOfAnalysisResults, + { _, _, _, _, r -> r.getOrNull() == null } + ) + } +} diff --git a/usvm-jvm-mocks/src/test/kotlin/org/usvm/samples/AdderTest.kt b/usvm-jvm-mocks/src/test/kotlin/org/usvm/samples/AdderTest.kt new file mode 100644 index 0000000000..8c9346513e --- /dev/null +++ b/usvm-jvm-mocks/src/test/kotlin/org/usvm/samples/AdderTest.kt @@ -0,0 +1,21 @@ +package org.usvm.samples + +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.usvm.test.util.checkers.ignoreNumberOfAnalysisResults + +class AdderTest : MocksTestRunner() { + @BeforeEach + fun reset() { + cleanUp() + } + + @Test + fun testCompute() { + checkDiscoveredPropertiesWithExceptions( + TestAdder::compute, + ignoreNumberOfAnalysisResults, + { _, _, _, r -> r.getOrNull() == null } + ) + } +} diff --git a/usvm-jvm-mocks/src/test/kotlin/org/usvm/samples/CalcTest.kt b/usvm-jvm-mocks/src/test/kotlin/org/usvm/samples/CalcTest.kt new file mode 100644 index 0000000000..4fb844604f --- /dev/null +++ b/usvm-jvm-mocks/src/test/kotlin/org/usvm/samples/CalcTest.kt @@ -0,0 +1,21 @@ +package org.usvm.samples + +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.usvm.test.util.checkers.ignoreNumberOfAnalysisResults + +class CalcTest : MocksTestRunner() { + @BeforeEach + fun reset() { + cleanUp() + } + + @Test + fun testCalc() { + checkDiscoveredPropertiesWithExceptions( + TestCalc::compute, + ignoreNumberOfAnalysisResults, + { _, _, _, r -> r.getOrNull() == null } + ) + } +} diff --git a/usvm-jvm-mocks/src/test/kotlin/org/usvm/samples/CleanUp.kt b/usvm-jvm-mocks/src/test/kotlin/org/usvm/samples/CleanUp.kt new file mode 100644 index 0000000000..bb6088e844 --- /dev/null +++ b/usvm-jvm-mocks/src/test/kotlin/org/usvm/samples/CleanUp.kt @@ -0,0 +1,13 @@ +package org.usvm.samples + +import machine.mockedMethods +import machine.mockedMethodsValues +import machine.mocksMap +import machine.varnamesMap + +fun cleanUp() { + mocksMap.clear() + mockedMethods.clear() + mockedMethodsValues.clear() + varnamesMap.clear() +} diff --git a/usvm-jvm-mocks/src/test/kotlin/org/usvm/samples/DataTest.kt b/usvm-jvm-mocks/src/test/kotlin/org/usvm/samples/DataTest.kt new file mode 100644 index 0000000000..15371f0124 --- /dev/null +++ b/usvm-jvm-mocks/src/test/kotlin/org/usvm/samples/DataTest.kt @@ -0,0 +1,21 @@ +package org.usvm.samples + +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.usvm.test.util.checkers.ignoreNumberOfAnalysisResults + +class DataTest : MocksTestRunner() { + @BeforeEach + fun reset() { + cleanUp() + } + + @Test + fun testCalc() { + checkDiscoveredPropertiesWithExceptions( + TestData::compute, + ignoreNumberOfAnalysisResults, + { _, _, _, r -> r.getOrNull() == null } + ) + } +} diff --git a/usvm-jvm-mocks/src/test/kotlin/org/usvm/samples/MocksTestRunner.kt b/usvm-jvm-mocks/src/test/kotlin/org/usvm/samples/MocksTestRunner.kt new file mode 100644 index 0000000000..68eee007f5 --- /dev/null +++ b/usvm-jvm-mocks/src/test/kotlin/org/usvm/samples/MocksTestRunner.kt @@ -0,0 +1,898 @@ +package org.usvm.samples + +import machine.JcMocksMachine +import machine.instructions.createUTestInstructions +import machine.instructions.createUTestMockConfigInfo +import machine.mockedMethods +import machine.mockedMethodsValues +import machine.render.renderMockConfigInfo +import machine.render.renderTestFile +import org.jacodb.api.jvm.JcClassOrInterface +import org.jacodb.api.jvm.JcClasspath +import org.jacodb.api.jvm.cfg.JcInst +import org.jacodb.api.jvm.cfg.JcReturnInst +import org.junit.jupiter.api.TestInstance +import org.junit.jupiter.api.extension.ExtendWith +import org.usvm.CoverageZone +import org.usvm.PathSelectionStrategy +import org.usvm.UMachineOptions +import org.usvm.api.JcClassCoverage +import org.usvm.api.JcParametersState +import org.usvm.api.JcTest +import org.usvm.api.StaticFieldValue +import org.usvm.api.createUTest +import org.usvm.api.targets.JcTarget +import org.usvm.api.util.JcTestInterpreter +import org.usvm.api.util.JcTestResolver +import org.usvm.machine.JcInterpreterObserver +import org.usvm.test.api.UTestInst +import org.usvm.test.util.TestRunner +import org.usvm.test.util.checkers.AnalysisResultsNumberMatcher +import org.usvm.test.util.checkers.ignoreNumberOfAnalysisResults +import org.usvm.util.JcTestExecutor +import org.usvm.util.JcTestResolverType +import org.usvm.util.UTestRunnerController +import org.usvm.util.getJcMethodByName +import org.usvm.util.loadClasspathFromEnv +import java.io.File +import kotlin.reflect.KClass +import kotlin.reflect.KFunction +import kotlin.reflect.KFunction1 +import kotlin.reflect.KFunction2 +import kotlin.reflect.KFunction3 +import kotlin.reflect.KFunction4 +import kotlin.reflect.full.instanceParameter +import kotlin.reflect.jvm.javaConstructor +import kotlin.time.Duration + +class MocksTarget(override val location: JcInst) : JcTarget(location) + +/** + * MocksTestRunner is responsible for symbolic execution of a method under test. + * It also handles reading mocked methods' values from memory, creating UTestInstructions and UTestMockConfigInfo, + * which is later rendered and result is saved as a .java file in testOutput directory. + */ +@ExtendWith(UTestRunnerController::class) +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +open class MocksTestRunner : TestRunner, KClass<*>?, JcClassCoverage>() { + + private var targets: List = emptyList() + private var interpreterObserver: JcInterpreterObserver? = null + + /** + * Sets JcTargets to run JcMachine with in the scope of [action]. + */ + protected fun withTargets( + targets: List, + interpreterObserver: JcInterpreterObserver, + action: () -> T + ): T { + val prevTargets = this.targets + val prevInterpreterObserver = this.interpreterObserver + try { + this.targets = targets + this.interpreterObserver = interpreterObserver + return action() + } finally { + this.targets = prevTargets + this.interpreterObserver = prevInterpreterObserver + } + } + + // region Default checkers + + protected inline fun checkExecutionBranches( + method: KFunction1, + vararg analysisResultsMatchers: (T, R?) -> Boolean, + invariants: Array<(T, R?) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true } // TODO remove it + ) { + checkMatches( + method, + ignoreNumberOfAnalysisResults, + *analysisResultsMatchers, + invariants = invariants, + coverageChecker = coverageChecker, + checkMode = CheckMode.MATCH_EXECUTIONS + ) + } + + protected inline fun checkDiscoveredProperties( + method: KFunction1, + analysisResultsNumberMatcher: AnalysisResultsNumberMatcher, + vararg analysisResultsMatchers: (T, R?) -> Boolean, + invariants: Array<(T, R?) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true } // TODO remove it + ) { + checkMatches( + method, + analysisResultsNumberMatcher, + analysisResultsMatchers = analysisResultsMatchers, + invariants = invariants, + coverageChecker = coverageChecker, + checkMode = CheckMode.MATCH_PROPERTIES + ) + } + + protected inline fun checkDiscoveredPropertiesWithStatics( + method: KFunction1, + analysisResultsNumberMatcher: AnalysisResultsNumberMatcher, + vararg analysisResultsMatchers: (T, R?, StaticsType, StaticsType) -> Boolean, + invariants: Array<(T, R?) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true }, // TODO remove it + checkMode: CheckMode + ) { + internalCheck( + target = method, + analysisResultsNumberMatcher, + analysisResultsMatchers, + invariants = invariants, + extractValuesToCheck = { test: JcTest -> test.takeParametersBeforeAndAllStaticsWithResult(method) }, + expectedTypesForExtractedValues = arrayOf(T::class, R::class, Map::class, Map::class), + checkMode = checkMode, + coverageChecker + ) + } + + protected inline fun checkMatches( + method: KFunction1, + analysisResultsNumberMatcher: AnalysisResultsNumberMatcher, + vararg analysisResultsMatchers: (T, R?) -> Boolean, + invariants: Array<(T, R?) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true }, // TODO remove it + checkMode: CheckMode + ) { + internalCheck( + target = method, + analysisResultsNumberMatcher, + analysisResultsMatchers, + invariants = invariants, + extractValuesToCheck = { test: JcTest -> test.takeAllParametersBeforeWithResult(method) }, + expectedTypesForExtractedValues = arrayOf(T::class, R::class), + checkMode = checkMode, + coverageChecker + ) + } + + protected inline fun checkThisAndParamsMutations( + method: KFunction1, + analysisResultsNumberMatcher: AnalysisResultsNumberMatcher, + vararg paramsMutationsMatchers: (T, T, R?) -> Boolean, + invariants: Array<(T, T, R?) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true }, // TODO remove it + checkMode: CheckMode + ) { + internalCheck( + target = method, + analysisResultsNumberMatcher, + paramsMutationsMatchers, + invariants = invariants, + extractValuesToCheck = { test: JcTest -> test.takeAllParametersBeforeAndAfterWithResult(method) }, + expectedTypesForExtractedValues = arrayOf(T::class, T::class, R::class), + checkMode = checkMode, + coverageChecker + ) + } + + protected inline fun checkExecutionBranches( + method: KFunction2, + vararg analysisResultsMatchers: (T, A0, R?) -> Boolean, + invariants: Array<(T, A0, R?) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true } // TODO remove it + ) { + checkMatches( + method, + ignoreNumberOfAnalysisResults, + *analysisResultsMatchers, + invariants = invariants, + coverageChecker = coverageChecker, + checkMode = CheckMode.MATCH_EXECUTIONS + ) + } + + protected inline fun checkDiscoveredProperties( + method: KFunction2, + analysisResultsNumberMatcher: AnalysisResultsNumberMatcher, + vararg analysisResultsMatchers: (T, A0, R?) -> Boolean, + invariants: Array<(T, A0, R?) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true } // TODO remove it + ) { + checkMatches( + method, + analysisResultsNumberMatcher, + analysisResultsMatchers = analysisResultsMatchers, + invariants = invariants, + coverageChecker = coverageChecker, + checkMode = CheckMode.MATCH_PROPERTIES + ) + } + + protected inline fun checkMatches( + method: KFunction2, + analysisResultsNumberMatcher: AnalysisResultsNumberMatcher, + vararg analysisResultsMatchers: (T, A0, R?) -> Boolean, + invariants: Array<(T, A0, R?) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true }, // TODO remove it + checkMode: CheckMode + ) { + internalCheck( + target = method, + analysisResultsNumberMatcher, + analysisResultsMatchers, + invariants = invariants, + extractValuesToCheck = { test: JcTest -> test.takeAllParametersBeforeWithResult(method) }, + expectedTypesForExtractedValues = arrayOf(T::class, A0::class, R::class), + checkMode = checkMode, + coverageChecker + ) + } + + protected inline fun checkThisAndParamsMutations( + method: KFunction2, + analysisResultsNumberMatcher: AnalysisResultsNumberMatcher, + vararg paramsMutationsMatchers: (T, A0, T, A0, R?) -> Boolean, + invariants: Array<(T, A0, T, A0, R?) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true }, // TODO remove it + checkMode: CheckMode + ) { + internalCheck( + target = method, + analysisResultsNumberMatcher, + paramsMutationsMatchers, + invariants = invariants, + extractValuesToCheck = { test: JcTest -> test.takeAllParametersBeforeAndAfterWithResult(method) }, + expectedTypesForExtractedValues = arrayOf(T::class, A0::class, T::class, A0::class, R::class), + checkMode = checkMode, + coverageChecker + ) + } + + protected inline fun checkExecutionBranches( + method: KFunction3, + vararg analysisResultsMatchers: (T, A0, A1, R?) -> Boolean, + invariants: Array<(T, A0, A1, R?) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true } // TODO remove it + ) { + checkMatches( + method, + ignoreNumberOfAnalysisResults, + *analysisResultsMatchers, + invariants = invariants, + coverageChecker = coverageChecker, + checkMode = CheckMode.MATCH_EXECUTIONS + ) + } + + protected inline fun checkDiscoveredProperties( + method: KFunction3, + analysisResultsNumberMatcher: AnalysisResultsNumberMatcher, + vararg analysisResultsMatchers: (T, A0, A1, R?) -> Boolean, + invariants: Array<(T, A0, A1, R?) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true } // TODO remove it + ) { + checkMatches( + method, + analysisResultsNumberMatcher, + analysisResultsMatchers = analysisResultsMatchers, + invariants = invariants, + coverageChecker = coverageChecker, + checkMode = CheckMode.MATCH_PROPERTIES + ) + } + + protected inline fun checkMatches( + method: KFunction3, + analysisResultsNumberMatcher: AnalysisResultsNumberMatcher, + vararg analysisResultsMatchers: (T, A0, A1, R?) -> Boolean, + invariants: Array<(T, A0, A1, R?) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true }, // TODO remove it + checkMode: CheckMode + ) { + internalCheck( + target = method, + analysisResultsNumberMatcher, + analysisResultsMatchers, + invariants = invariants, + extractValuesToCheck = { test: JcTest -> test.takeAllParametersBeforeWithResult(method) }, + expectedTypesForExtractedValues = arrayOf(T::class, A0::class, A1::class, R::class), + checkMode = checkMode, + coverageChecker + ) + } + + protected inline fun checkThisAndParamsMutations( + method: KFunction3, + analysisResultsNumberMatcher: AnalysisResultsNumberMatcher, + vararg paramsMutationsMatchers: (T, A0, A1, T, A0, A1, R?) -> Boolean, + invariants: Array<(T, A0, A1, T, A0, A1, R?) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true }, // TODO remove it + checkMode: CheckMode + ) { + internalCheck( + target = method, + analysisResultsNumberMatcher, + paramsMutationsMatchers, + invariants = invariants, + extractValuesToCheck = { test: JcTest -> test.takeAllParametersBeforeAndAfterWithResult(method) }, + expectedTypesForExtractedValues = arrayOf(T::class, A0::class, A1::class, T::class, A0::class, A1::class, R::class), + checkMode = checkMode, + coverageChecker + ) + } + + protected inline fun checkExecutionBranches( + method: KFunction4, + vararg analysisResultsMatchers: (T, A0, A1, A2, R?) -> Boolean, + invariants: Array<(T, A0, A1, A2, R?) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true } // TODO remove it + ) { + checkMatches( + method, + ignoreNumberOfAnalysisResults, + *analysisResultsMatchers, + invariants = invariants, + coverageChecker = coverageChecker, + checkMode = CheckMode.MATCH_EXECUTIONS + ) + } + + protected inline fun checkDiscoveredProperties( + method: KFunction4, + analysisResultsNumberMatcher: AnalysisResultsNumberMatcher, + vararg analysisResultsMatchers: (T, A0, A1, A2, R?) -> Boolean, + invariants: Array<(T, A0, A1, A2, R?) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true } // TODO remove it + ) { + checkMatches( + method, + analysisResultsNumberMatcher, + analysisResultsMatchers = analysisResultsMatchers, + invariants = invariants, + coverageChecker = coverageChecker, + checkMode = CheckMode.MATCH_PROPERTIES + ) + } + + protected inline fun checkMatches( + method: KFunction4, + analysisResultsNumberMatcher: AnalysisResultsNumberMatcher, + vararg analysisResultsMatchers: (T, A0, A1, A2, R?) -> Boolean, + invariants: Array> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true }, // TODO remove it + checkMode: CheckMode + ) { + internalCheck( + target = method, + analysisResultsNumberMatcher, + analysisResultsMatchers, + invariants = invariants, + extractValuesToCheck = { test: JcTest -> test.takeAllParametersBeforeWithResult(method) }, + expectedTypesForExtractedValues = arrayOf(T::class, A0::class, A1::class, A2::class, R::class), + checkMode = checkMode, + coverageChecker + ) + } + + protected inline fun checkThisAndParamsMutations( + method: KFunction4, + analysisResultsNumberMatcher: AnalysisResultsNumberMatcher, + vararg paramsMutationsMatchers: (T, A0, A1, A2, T, A0, A1, A2, R?) -> Boolean, + invariants: Array<(T, A0, A1, A2, T, A0, A1, A2, R?) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true }, // TODO remove it + checkMode: CheckMode + ) { + internalCheck( + target = method, + analysisResultsNumberMatcher, + paramsMutationsMatchers, + invariants = invariants, + extractValuesToCheck = { test: JcTest -> test.takeAllParametersBeforeAndAfterWithResult(method) }, + expectedTypesForExtractedValues = arrayOf( + T::class, + A0::class, + A1::class, + A2::class, + T::class, + A0::class, + A1::class, + A2::class, + R::class + ), + checkMode = checkMode, + coverageChecker + ) + } + + // endregion + + // region Default checkers with exceptions + + protected inline fun checkExecutionBranchesWithExceptions( + method: KFunction1, + vararg analysisResultsMatchers: (T, Result) -> Boolean, + invariants: Array<(T, Result) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true } // TODO remove it + ) { + checkMatchesWithExceptions( + method, + ignoreNumberOfAnalysisResults, + *analysisResultsMatchers, + invariants = invariants, + coverageChecker = coverageChecker, + checkMode = CheckMode.MATCH_EXECUTIONS + ) + } + + protected inline fun checkDiscoveredPropertiesWithExceptions( + method: KFunction1, + analysisResultsNumberMatcher: AnalysisResultsNumberMatcher, + vararg analysisResultsMatchers: (T, Result) -> Boolean, + invariants: Array<(T, Result) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true } // TODO remove it + ) { + checkMatchesWithExceptions( + method, + analysisResultsNumberMatcher, + analysisResultsMatchers = analysisResultsMatchers, + invariants = invariants, + coverageChecker = coverageChecker, + checkMode = CheckMode.MATCH_PROPERTIES + ) + } + + protected inline fun checkMatchesWithExceptions( + method: KFunction1, + analysisResultsNumberMatcher: AnalysisResultsNumberMatcher, + vararg analysisResultsMatchers: (T, Result) -> Boolean, + invariants: Array<(T, Result) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true }, // TODO remove it + checkMode: CheckMode + ) { + internalCheck( + target = method, + analysisResultsNumberMatcher, + analysisResultsMatchers, + invariants = invariants, + extractValuesToCheck = { test: JcTest -> + val values = test.takeAllParametersBefore(method) + test.result.let { values += it } + values + }, + expectedTypesForExtractedValues = arrayOf(T::class), // We don't check type for the result here + checkMode = checkMode, + coverageChecker + ) + } + + protected inline fun checkThisAndParamsMutationsWithExceptions( + method: KFunction1, + analysisResultsNumberMatcher: AnalysisResultsNumberMatcher, + vararg paramsMutationsMatchers: (T, T, Result) -> Boolean, + invariants: Array<(T, T, Result) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true }, // TODO remove it + checkMode: CheckMode + ) { + internalCheck( + target = method, + analysisResultsNumberMatcher, + paramsMutationsMatchers, + invariants = invariants, + extractValuesToCheck = { test: JcTest -> test.takeAllParametersBeforeAndAfterWithResult(method) }, + // We don't check type for the result here + expectedTypesForExtractedValues = arrayOf(T::class, T::class), + checkMode = checkMode, + coverageChecker + ) + } + + protected inline fun checkExecutionBranchesWithExceptions( + method: KFunction2, + vararg analysisResultsMatchers: (T, A0, Result) -> Boolean, + invariants: Array<(T, A0, Result) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true } // TODO remove it + ) { + checkMatchesWithExceptions( + method, + ignoreNumberOfAnalysisResults, + *analysisResultsMatchers, + invariants = invariants, + coverageChecker = coverageChecker, + checkMode = CheckMode.MATCH_EXECUTIONS + ) + } + + protected inline fun checkDiscoveredPropertiesWithExceptions( + method: KFunction2, + analysisResultsNumberMatcher: AnalysisResultsNumberMatcher, + vararg analysisResultsMatchers: (T, A0, Result) -> Boolean, + invariants: Array<(T, A0, Result) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true } // TODO remove it + ) { + checkMatchesWithExceptions( + method, + analysisResultsNumberMatcher, + analysisResultsMatchers = analysisResultsMatchers, + invariants = invariants, + coverageChecker = coverageChecker, + checkMode = CheckMode.MATCH_PROPERTIES + ) + } + + protected inline fun checkMatchesWithExceptions( + method: KFunction2, + analysisResultsNumberMatcher: AnalysisResultsNumberMatcher, + vararg analysisResultsMatchers: (T, A0, Result) -> Boolean, + invariants: Array<(T, A0, Result) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true }, // TODO remove it + checkMode: CheckMode + ) { + internalCheck( + target = method, + analysisResultsNumberMatcher, + analysisResultsMatchers, + invariants = invariants, + extractValuesToCheck = { test: JcTest -> + val values = test.takeAllParametersBefore(method) + test.result.let { values += it } + values + }, + expectedTypesForExtractedValues = arrayOf(T::class, A0::class), // We don't check type for the result here + checkMode = checkMode, + coverageChecker + ) + } + + protected inline fun checkThisAndParamsMutationsWithExceptions( + method: KFunction2, + analysisResultsNumberMatcher: AnalysisResultsNumberMatcher, + vararg paramsMutationsMatchers: (T, A0, T, A0, Result) -> Boolean, + invariants: Array<(T, A0, T, A0, Result) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true }, // TODO remove it + checkMode: CheckMode + ) { + internalCheck( + target = method, + analysisResultsNumberMatcher, + paramsMutationsMatchers, + invariants = invariants, + extractValuesToCheck = { test: JcTest -> test.takeAllParametersBeforeAndAfterWithResult(method) }, + // We don't check type for the result here + expectedTypesForExtractedValues = arrayOf(T::class, A0::class, T::class, A0::class), + checkMode = checkMode, + coverageChecker + ) + } + + protected inline fun checkExecutionBranchesWithExceptions( + method: KFunction3, + vararg analysisResultsMatchers: (T, A0, A1, Result) -> Boolean, + invariants: Array<(T, A0, A1, Result) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true } // TODO remove it + ) { + checkMatchesWithExceptions( + method, + ignoreNumberOfAnalysisResults, + *analysisResultsMatchers, + invariants = invariants, + coverageChecker = coverageChecker, + checkMode = CheckMode.MATCH_EXECUTIONS + ) + } + + protected inline fun checkDiscoveredPropertiesWithExceptions( + method: KFunction3, + analysisResultsNumberMatcher: AnalysisResultsNumberMatcher, + vararg analysisResultsMatchers: (T, A0, A1, Result) -> Boolean, + invariants: Array<(T, A0, A1, Result) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true } // TODO remove it + ) { + checkMatchesWithExceptions( + method, + analysisResultsNumberMatcher, + analysisResultsMatchers = analysisResultsMatchers, + invariants = invariants, + coverageChecker = coverageChecker, + checkMode = CheckMode.MATCH_PROPERTIES + ) + } + + protected inline fun checkMatchesWithExceptions( + method: KFunction3, + analysisResultsNumberMatcher: AnalysisResultsNumberMatcher, + vararg analysisResultsMatchers: (T, A0, A1, Result) -> Boolean, + invariants: Array<(T, A0, A1, Result) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true }, // TODO remove it + checkMode: CheckMode + ) { + internalCheck( + target = method, + analysisResultsNumberMatcher, + analysisResultsMatchers, + invariants = invariants, + extractValuesToCheck = { test: JcTest -> + val values = test.takeAllParametersBefore(method) + test.result.let { values += it } + values + }, + // We don't check type for the result here + expectedTypesForExtractedValues = arrayOf(T::class, A0::class, A1::class), + checkMode = checkMode, + coverageChecker + ) + } + + protected inline fun checkThisAndParamsMutationsWithExceptions( + method: KFunction3, + analysisResultsNumberMatcher: AnalysisResultsNumberMatcher, + vararg paramsMutationsMatchers: (T, A0, A1, T, A0, A1, Result) -> Boolean, + invariants: Array<(T, A0, A1, T, A0, A1, Result) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true }, // TODO remove it + checkMode: CheckMode + ) { + internalCheck( + target = method, + analysisResultsNumberMatcher, + paramsMutationsMatchers, + invariants = invariants, + extractValuesToCheck = { test: JcTest -> test.takeAllParametersBeforeAndAfterWithResult(method) }, + // We don't check type for the result here + expectedTypesForExtractedValues = arrayOf(T::class, A0::class, A1::class, T::class, A0::class, A1::class), + checkMode = checkMode, + coverageChecker + ) + } + + protected inline fun checkExecutionBranchesWithExceptions( + method: KFunction4, + vararg analysisResultsMatchers: (T, A0, A1, A2, Result) -> Boolean, + invariants: Array<(T, A0, A1, A2, Result) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true } // TODO remove it + ) { + checkMatchesWithExceptions( + method, + ignoreNumberOfAnalysisResults, + *analysisResultsMatchers, + invariants = invariants, + coverageChecker = coverageChecker, + checkMode = CheckMode.MATCH_EXECUTIONS + ) + } + + protected inline fun checkDiscoveredPropertiesWithExceptions( + method: KFunction4, + analysisResultsNumberMatcher: AnalysisResultsNumberMatcher, + vararg analysisResultsMatchers: (T, A0, A1, A2, Result) -> Boolean, + invariants: Array<(T, A0, A1, A2, Result) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true } // TODO remove it + ) { + checkMatchesWithExceptions( + method, + analysisResultsNumberMatcher, + analysisResultsMatchers = analysisResultsMatchers, + invariants = invariants, + coverageChecker = coverageChecker, + checkMode = CheckMode.MATCH_PROPERTIES + ) + } + + protected inline fun checkMatchesWithExceptions( + method: KFunction4, + analysisResultsNumberMatcher: AnalysisResultsNumberMatcher, + vararg analysisResultsMatchers: (T, A0, A1, A2, Result) -> Boolean, + invariants: Array<(T, A0, A1, A2, Result) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true }, // TODO remove it + checkMode: CheckMode + ) { + internalCheck( + target = method, + analysisResultsNumberMatcher, + analysisResultsMatchers, + invariants = invariants, + extractValuesToCheck = { test: JcTest -> + val values = test.takeAllParametersBefore(method) + test.result.let { values += it } + values + }, + // We don't check type for the result here + expectedTypesForExtractedValues = arrayOf(T::class, A0::class, A1::class, A2::class), + checkMode = checkMode, + coverageChecker + ) + } + + protected inline fun checkThisAndParamsMutationsWithExceptions( + method: KFunction4, + analysisResultsNumberMatcher: AnalysisResultsNumberMatcher, + vararg paramsMutationsMatchers: (T, A0, A1, A2, T, A0, A1, A2, Result) -> Boolean, + invariants: Array<(T, A0, A1, A2, T, A0, A1, A2, Result) -> Boolean> = emptyArray(), + noinline coverageChecker: (JcClassCoverage) -> Boolean = { _ -> true }, // TODO remove it + checkMode: CheckMode + ) { + internalCheck( + target = method, + analysisResultsNumberMatcher, + paramsMutationsMatchers, + invariants = invariants, + extractValuesToCheck = { test: JcTest -> test.takeAllParametersBeforeAndAfterWithResult(method) }, + // We don't check type for the result here + expectedTypesForExtractedValues = arrayOf( + T::class, + A0::class, + A1::class, + A2::class, + T::class, + A0::class, + A1::class, + A2::class + ), + checkMode = checkMode, + coverageChecker + ) + } + + // endregion + + protected fun JcTest.takeAllParametersBefore(method: KFunction<*>): MutableList = + before.takeAllParameters(method) + + protected fun JcTest.takeAllParametersBeforeWithResult(method: KFunction<*>): MutableList { + val values = before.takeAllParameters(method) + result.let { values += it.getOrNull() } + + return values + } + + protected fun JcTest.takeAllParametersAfter(method: KFunction<*>): MutableList = + after.takeAllParameters(method) + + protected fun JcTest.takeAllParametersAfterWithResult(method: KFunction<*>): MutableList { + val values = after.takeAllParameters(method) + result.let { values += it.getOrNull() } + + return values + } + + private fun JcTest.takeAllParametersBeforeAndAfter(method: KFunction<*>): MutableList { + val parameters = before.takeAllParameters(method) + parameters.addAll(after.takeAllParameters(method)) + + return parameters + } + + protected fun JcTest.takeAllParametersBeforeAndAfterWithResult(method: KFunction<*>): MutableList { + val values = takeAllParametersBeforeAndAfter(method) + values += result.getOrNull() + + return values + } + + private fun JcTest.takeStaticsBefore(): Map> = before.statics + private fun JcTest.takeStaticsAfter(): Map> = after.statics + + protected fun JcTest.takeParametersBeforeAndAllStaticsWithResult(method: KFunction<*>): MutableList { + val values = takeAllParametersBefore(method) + + values += result.getOrNull() + values += takeStaticsBefore() + values += takeStaticsAfter() + + return values + } + + private fun JcParametersState.takeAllParameters( + method: KFunction<*> + ): MutableList { + val values = mutableListOf() + if (method.instanceParameter != null && method.javaConstructor == null) { + values += requireNotNull(thisInstance) + } else { + // Note that for constructors we have thisInstance in such as case, in contrast to simple methods + require(thisInstance == null || method.javaConstructor != null) + } + + values.addAll(parameters) // add remaining arguments + return values + } + + protected open val jacodbCpKey: String + get() = samplesKey + + protected open val classpath: List + get() = samplesClasspath + + protected open val cp by lazy { + JacoDBContainer(jacodbCpKey, classpath).cp + } + + protected open val resolverType: JcTestResolverType = JcTestResolverType.INTERPRETER + + private val testResolver: JcTestResolver + get() = when (resolverType) { + JcTestResolverType.INTERPRETER -> JcTestInterpreter() + JcTestResolverType.CONCRETE_EXECUTOR -> JcTestExecutor(classpath = cp) + } + + override val typeTransformer: (Any?) -> KClass<*>? = { value -> value?.let { it::class } } + + override val checkType: (KClass<*>?, KClass<*>?) -> Boolean = + { expected, actual -> actual == null || expected != null && expected.java.isAssignableFrom(actual.java) } + + override var options: UMachineOptions = UMachineOptions( + pathSelectionStrategies = listOf(PathSelectionStrategy.TARGETED), + coverageZone = CoverageZone.TRANSITIVE, + exceptionsPropagation = false, +// timeout = 60_000.milliseconds, + stepsFromLastCovered = 3500L, + solverTimeout = Duration.INFINITE, // we do not need the timeout for a solver in tests + typeOperationsTimeout = Duration.INFINITE // we do not need the timeout for type operations in testsretu + ) + + open fun createMachine( + file: String, + cp: JcClasspath, + options: UMachineOptions, + interpreterObserver: JcInterpreterObserver? + ): JcMocksMachine { + return JcMocksMachine(file, cp, options, interpreterObserver = interpreterObserver) + } + + override val runner: (KFunction<*>, UMachineOptions) -> List = { method, options -> + val jcMethod = cp.getJcMethodByName(method) + val instList = jcMethod.method.instList + val localTargets = mutableListOf() + for (inst in instList) { + if (inst is JcReturnInst) { + val newTarget = MocksTarget(inst) + localTargets.add(newTarget) + } + } + targets = localTargets + + val fileReference = jcMethod.method.declaration.relativePath.substringBefore("#") + val pathToFile = fileReference.replaceFirst(".", "/") + .replaceFirst(".", "/") + .replaceFirst(".", "/") + val pathToDir = System.getProperty("user.dir") + val file = "$pathToDir/src/samples/java/$pathToFile.java" + + createMachine(file, cp, options, interpreterObserver).use { machine -> + val states = machine.analyze(jcMethod.method, targets) + val state = states.last() + val dummyTest = createUTest(jcMethod, state) + val testInstructions = mutableListOf>>() + val memory = state.memory + for (key in mockedMethods) { + val newValue = memory.read(key.key) + mockedMethodsValues[key] = newValue + testInstructions.add(createUTestInstructions(key, state)) + } + val testMockConfigInfo = createUTestMockConfigInfo(testInstructions) + val (mockConfigInfo, ifThrows) = renderMockConfigInfo(cp, dummyTest, testMockConfigInfo) + val methodUnderTestName = jcMethod.name + val methodUnderTestRetType = jcMethod.returnType.typeName + renderTestFile(ifThrows, methodUnderTestName, methodUnderTestRetType, file, mockConfigInfo) + states.map { + testResolver.resolve(jcMethod, it) + } + } + } + + override val coverageRunner: (List) -> JcClassCoverage = { _ -> + JcClassCoverage(visitedStmts = emptySet()) + } + + companion object { + private const val ENV_TEST_SAMPLES = "usvm.jvm.test.samples" + + val samplesClasspath by lazy { + loadClasspathFromEnv(ENV_TEST_SAMPLES) + } + + init { + // See https://dzone.com/articles/how-to-export-all-modules-to-all-modules-at-runtime-in-java?preview=true + org.burningwave.core.assembler.StaticComponentContainer.Modules.exportAllToAll() + } + } +} + +private typealias StaticsType = Map> diff --git a/usvm-jvm-mocks/src/test/kotlin/org/usvm/samples/ServiceTest.kt b/usvm-jvm-mocks/src/test/kotlin/org/usvm/samples/ServiceTest.kt new file mode 100644 index 0000000000..ac0cafc072 --- /dev/null +++ b/usvm-jvm-mocks/src/test/kotlin/org/usvm/samples/ServiceTest.kt @@ -0,0 +1,21 @@ +package org.usvm.samples + +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.usvm.test.util.checkers.ignoreNumberOfAnalysisResults + +class ServiceTest : MocksTestRunner() { + @BeforeEach + fun reset() { + cleanUp() + } + + @Test + fun testCalc() { + checkDiscoveredPropertiesWithExceptions( + TestService::compute, + ignoreNumberOfAnalysisResults, + { _, _, r -> r.getOrNull() == null } + ) + } +} diff --git a/usvm-jvm-rendering/build.gradle.kts b/usvm-jvm-rendering/build.gradle.kts new file mode 100644 index 0000000000..6b51d9fac7 --- /dev/null +++ b/usvm-jvm-rendering/build.gradle.kts @@ -0,0 +1,27 @@ +plugins { + id("usvm.kotlin-conventions") +} + +dependencies { + implementation(Libs.jacodb_api_jvm) + implementation(Libs.jacodb_core) + implementation(project(":usvm-jvm:usvm-jvm-util")) + implementation(project(":usvm-jvm:usvm-jvm-test-api")) + implementation("com.github.javaparser:javaparser-symbol-solver-core:3.26.3") +} + +tasks.withType { + val reflectionUtils = project.sourceSets.main.get().java.find { file -> + file.name == "ReflectionUtils.java" + } + + from(reflectionUtils) +} + +publishing { + publications { + create("maven") { + from(components["java"]) + } + } +} diff --git a/usvm-jvm-rendering/src/main/java/org/usvm/jvm/rendering/ReflectionUtils.java b/usvm-jvm-rendering/src/main/java/org/usvm/jvm/rendering/ReflectionUtils.java new file mode 100644 index 0000000000..b002ee0ead --- /dev/null +++ b/usvm-jvm-rendering/src/main/java/org/usvm/jvm/rendering/ReflectionUtils.java @@ -0,0 +1,288 @@ +package org.usvm.jvm.rendering; + +import sun.misc.Unsafe; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.ArrayList; +import java.util.List; + +@SuppressWarnings({"removal", "deprecation"}) +public class ReflectionUtils { + private static final Unsafe UNSAFE; + + static { + try { + Field uns = Unsafe.class.getDeclaredField("theUnsafe"); + uns.setAccessible(true); + UNSAFE = (Unsafe) uns.get(null); + } catch (Throwable e) { + throw new RuntimeException(); + } + } + + //region Fields Interaction + + private static long getOffsetOf(Field field) { + return isStatic(field) ? UNSAFE.staticFieldOffset(field) : UNSAFE.objectFieldOffset(field); + } + + private static boolean isStatic(Field field) { + return (field.getModifiers() & Modifier.STATIC) > 0; + } + + private static List getInstanceFields(Class type) { + ArrayList fields = new ArrayList<>(); + for (Field field : type.getDeclaredFields()) { + if (!Modifier.isStatic(field.getModifiers())) + fields.add(field); + } + + return fields; + } + + private static List getStaticFields(Class type) { + ArrayList fields = new ArrayList<>(); + for (Field field : type.getDeclaredFields()) { + if (Modifier.isStatic(field.getModifiers())) + fields.add(field); + } + + return fields; + } + + private static Field getField(Object instance, String fieldName) { + Class type = instance.getClass(); + Class currentClass = type; + while (currentClass != Object.class && currentClass != null) { + for (Field field : getInstanceFields(currentClass)) { + if (field.getName().equals(fieldName)) + return field; + } + currentClass = currentClass.getSuperclass(); + } + + throw new IllegalArgumentException("Could not find field " + fieldName + " in " + type); + } + + private static Field getStaticField(Class type, String fieldName) { + for (Field field : getStaticFields(type)) { + if (field.getName().equals(fieldName)) + return field; + } + + throw new IllegalArgumentException("Could not find static field " + fieldName + " in " + type); + } + + private static Object getFieldValue(Object fixedInstance, Field field) { + long fieldOffset = getOffsetOf(field); + + if (!field.getType().isPrimitive()) { + return UNSAFE.getObject(fixedInstance, fieldOffset); + } + + if (field.getType() == boolean.class) { + return UNSAFE.getBoolean(fixedInstance, fieldOffset); + } else if (field.getType() == byte.class) { + return UNSAFE.getByte(fixedInstance, fieldOffset); + } else if (field.getType() == char.class) { + return UNSAFE.getChar(fixedInstance, fieldOffset); + } else if (field.getType() == short.class) { + return UNSAFE.getShort(fixedInstance, fieldOffset); + } else if (field.getType() == int.class) { + return UNSAFE.getInt(fixedInstance, fieldOffset); + } else if (field.getType() == long.class) { + return UNSAFE.getLong(fixedInstance, fieldOffset); + } else if (field.getType() == float.class) { + return UNSAFE.getFloat(fixedInstance, fieldOffset); + } else if (field.getType() == double.class) { + return UNSAFE.getDouble(fixedInstance, fieldOffset); + } + + throw new IllegalStateException("unexpected primitive type"); + } + + @SuppressWarnings("unchecked") + public static T getFieldValue(Object instance, String fieldName) { + Field field = getField(instance, fieldName); + return (T) getFieldValue(instance, field); + } + + @SuppressWarnings("unchecked") + public static T getStaticFieldValue(Class type, String fieldName) { + Field field = getStaticField(type, fieldName); + return (T) getFieldValue(UNSAFE.staticFieldBase(field), field); + } + + private static void setFieldValue(Object fixedInstance, Field field, Object value) { + long fieldOffset = getOffsetOf(field); + + if (!field.getType().isPrimitive()) { + UNSAFE.putObject(fixedInstance, fieldOffset, value); + return; + } + + if (field.getType() == boolean.class) { + UNSAFE.putBoolean(fixedInstance, fieldOffset, value != null && ((boolean) value)); + } else if (field.getType() == byte.class) { + UNSAFE.putByte(fixedInstance, fieldOffset, value != null ? (byte) value : 0); + } else if (field.getType() == char.class) { + UNSAFE.putChar(fixedInstance, fieldOffset, value != null ? (char) value : '\u0000'); + } else if (field.getType() == short.class) { + UNSAFE.putShort(fixedInstance, fieldOffset, value != null ? (short) value : 0); + } else if (field.getType() == int.class) { + UNSAFE.putInt(fixedInstance, fieldOffset, value != null ? (int) value : 0); + } else if (field.getType() == long.class) { + UNSAFE.putLong(fixedInstance, fieldOffset, value != null ? (long) value : 0); + } else if (field.getType() == float.class) { + UNSAFE.putFloat(fixedInstance, fieldOffset, value != null ? (float) value : 0.0f); + } else if (field.getType() == double.class) { + UNSAFE.putDouble(fixedInstance, fieldOffset, value != null ? (double) value : 0.0); + } + } + + public static void setFieldValue(Object instance, String fieldName, Object value) { + Field field = getField(instance, fieldName); + setFieldValue(instance, field, value); + } + + public static void setStaticFieldValue(Class type, String fieldName, Object value) { + Field field = getStaticField(type, fieldName); + setFieldValue(UNSAFE.staticFieldBase(field), field, value); + } + + public static Throwable getRootCause(Throwable exception) { + Throwable result = exception; + while (true) { + Throwable cause = result.getCause(); + if (cause == null || cause == result) return result; + result = cause; + } + } + + //endregion + + //region Allocation + + @SuppressWarnings("unchecked") + public static T allocateInstance(Class clazz) throws InstantiationException { + return (T) UNSAFE.allocateInstance(clazz); + } + + //endregion + + //region Methods Interaction + + private static String parameterTypesSignature(Class[] types) { + if (types.length == 0) + return ""; + + StringBuilder sb = new StringBuilder(); + for (Class type : types) { + sb.append(type.getTypeName()); + sb.append(";"); + } + + return sb.toString(); + } + + private static String methodSignature(Method method) { + String parametersSig = parameterTypesSignature(method.getParameterTypes()); + String returnTypeSig = method.getReturnType().getTypeName(); + return method.getName() + "(" + parametersSig + ")" + returnTypeSig + ";"; + } + + private static String methodSignature(Constructor method) { + String parametersSig = parameterTypesSignature(method.getParameterTypes()); + return "" + "(" + parametersSig + ")" + "void" + ";"; + } + + private static List getInstanceMethods(Class type) { + ArrayList methods = new ArrayList<>(); + for (Method method : type.getDeclaredMethods()) { + if (!Modifier.isStatic(method.getModifiers())) + methods.add(method); + } + + return methods; + } + + private static List getStaticMethods(Class type) { + ArrayList methods = new ArrayList<>(); + for (Method method : type.getDeclaredMethods()) { + if (Modifier.isStatic(method.getModifiers())) + methods.add(method); + } + + return methods; + } + + private static Method getMethod(Object instance, String methodSig) { + Class type = instance.getClass(); + Class currentClass = type; + while (currentClass != Object.class && currentClass != null) { + for (Method method : getInstanceMethods(currentClass)) { + if (methodSignature(method).equals(methodSig)) + return method; + } + currentClass = currentClass.getSuperclass(); + } + + throw new IllegalArgumentException("Could not find method " + methodSig + " in " + type); + } + + private static Method getStaticMethod(Class type, String methodSig) { + for (Method method : getStaticMethods(type)) { + if (methodSignature(method).equals(methodSig)) + return method; + } + + throw new IllegalArgumentException("Could not find static method " + methodSig + " in " + type); + } + + private static Constructor getConstructor(Class type, String ctorSig) { + for (Constructor ctor : type.getDeclaredConstructors()) { + if (methodSignature(ctor).equals(ctorSig)) + return ctor; + } + + throw new IllegalArgumentException("Could not find constructor " + ctorSig + " in " + type); + } + + private static Object callMethod(Object instance, Method method, Object... args) throws Throwable { + Object[] checkedArgs = (args == null) ? new Object[] { null } : args; + try { + method.setAccessible(true); + return method.invoke(instance, checkedArgs); + } catch (InvocationTargetException e) { + throw e.getTargetException(); + } + } + + @SuppressWarnings("unchecked") + public static T callMethod(Object instance, String methodSig, Object... args) throws Throwable { + return (T) callMethod(instance, getMethod(instance, methodSig), args); + } + + @SuppressWarnings("unchecked") + public static T callStaticMethod(Class type, String methodSig, Object... args) throws Throwable { + return (T) callMethod(null, getStaticMethod(type, methodSig), args); + } + + @SuppressWarnings("unchecked") + public static T callConstructor(Class type, String ctorSig, Object... args) throws Throwable { + Constructor ctor = getConstructor(type, ctorSig); + Object[] checkedArgs = (args == null) ? new Object[] { null } : args; + try { + ctor.setAccessible(true); + return (T) ctor.newInstance(checkedArgs); + } catch (InvocationTargetException e) { + throw e.getTargetException(); + } + } + + //endregion +} diff --git a/usvm-jvm-rendering/src/main/kotlin/Utils.kt b/usvm-jvm-rendering/src/main/kotlin/Utils.kt new file mode 100644 index 0000000000..89d10e8724 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/Utils.kt @@ -0,0 +1,13 @@ +fun Map.partitionByKey(predicate: (K) -> Boolean): Pair, Map> { + val isTrue = mutableMapOf() + val isFalse = mutableMapOf() + + for ((k, v) in entries) { + if (predicate(k)) { + isTrue.put(k, v) + } else { + isFalse.put(k, v) + } + } + return Pair(isTrue, isFalse) +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/JcFileRendererFactory.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/JcFileRendererFactory.kt new file mode 100644 index 0000000000..ab475abbdb --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/JcFileRendererFactory.kt @@ -0,0 +1,149 @@ +package org.usvm.jvm.rendering + +import com.github.javaparser.ast.CompilationUnit +import java.nio.file.Path +import org.jacodb.api.jvm.JcClassOrInterface +import org.jacodb.api.jvm.JcClasspath +import org.jacodb.api.jvm.ext.toType +import org.usvm.jvm.rendering.spring.unitTestRenderer.JcSpringUnitTestFileRenderer +import org.usvm.jvm.rendering.spring.unitTestRenderer.JcSpringUnitTestInfo +import org.usvm.jvm.rendering.spring.webMvcTestRenderer.JcSpringMvcTestFileRenderer +import org.usvm.jvm.rendering.spring.webMvcTestRenderer.JcSpringMvcTestInfo +import org.usvm.jvm.rendering.testRenderer.JcTestFileRenderer +import org.usvm.jvm.rendering.testRenderer.JcTestInfo +import org.usvm.jvm.rendering.unsafeRenderer.JcUnsafeTestFileRenderer +import org.usvm.jvm.rendering.unsafeRenderer.JcUnsafeTestInfo + +sealed class JcTestClassInfo( + val clazz: JcClassOrInterface, + protected val filePath: Path?, + protected val packageName: String?, + protected val className: String? +) { + + private val defaultTestClassName: String by lazy { "${clazz.simpleName}Tests" } + + val testClassName: String get() = className ?: defaultTestClassName + + val testFilePath: Path? get() = filePath + + val testPackageName: String? get() = packageName + + class Base(clazz: JcClassOrInterface, testFilePath: Path?, packageName: String?, testClassName: String?) : + JcTestClassInfo(clazz, testFilePath, packageName, testClassName) + + class Unsafe(clazz: JcClassOrInterface, testFilePath: Path?, packageName: String?, testClassName: String?) : + JcTestClassInfo(clazz, testFilePath, packageName, testClassName) + + class SpringUnit(clazz: JcClassOrInterface, testFilePath: Path?, packageName: String?, testClassName: String?) : + JcTestClassInfo(clazz, testFilePath, packageName, testClassName) + + class SpringMvc(clazz: JcClassOrInterface, testFilePath: Path?, packageName: String?, testClassName: String?) : + JcTestClassInfo(clazz, testFilePath, packageName, testClassName) + + companion object { + fun from(testInfo: JcTestInfo) = when (testInfo) { + is JcSpringMvcTestInfo -> SpringMvc( + testInfo.method.enclosingClass, + testInfo.testFilePath, + testInfo.testPackageName, + testInfo.testClassName + ) + + is JcSpringUnitTestInfo -> SpringUnit( + testInfo.method.enclosingClass, + testInfo.testFilePath, + testInfo.testPackageName, + testInfo.testClassName + ) + + is JcUnsafeTestInfo -> Unsafe( + testInfo.method.enclosingClass, + testInfo.testFilePath, + testInfo.testPackageName, + testInfo.testClassName + ) + + else -> Base( + testInfo.method.enclosingClass, + testInfo.testFilePath, + testInfo.testPackageName, + testInfo.testClassName + ) + } + } + + override fun equals(other: Any?): Boolean { + return other is JcTestClassInfo && clazz == other.clazz + } + + override fun hashCode(): Int { + return clazz.hashCode() + } +} + +object JcTestFileRendererFactory { + fun testFileRendererFor( + cu: CompilationUnit, + cp: JcClasspath, + testClassInfo: JcTestClassInfo, + reflectionUtilsInlineStrategy: ReflectionUtilsInlineStrategy, + isAccessibleFromTestClass: ((JcClassOrInterface) -> Boolean)? = null + ): JcTestFileRenderer { + check(testClassInfo !is JcTestClassInfo.SpringUnit && testClassInfo !is JcTestClassInfo.SpringMvc || isAccessibleFromTestClass != null) { + "isAccessible predicate should not be null for spring renderer" + } + + return when (testClassInfo) { + is JcTestClassInfo.SpringMvc -> JcSpringMvcTestFileRenderer( + testClassInfo.clazz.toType(), + cu, + cp, + reflectionUtilsInlineStrategy, + isAccessibleFromTestClass!! + ) + + is JcTestClassInfo.SpringUnit -> JcSpringUnitTestFileRenderer( + cu, + cp, + reflectionUtilsInlineStrategy, + isAccessibleFromTestClass!! + ) + + is JcTestClassInfo.Unsafe -> JcUnsafeTestFileRenderer(cu, cp, reflectionUtilsInlineStrategy) + is JcTestClassInfo.Base -> JcTestFileRenderer(cu, cp) + } + } + + fun testFileRendererFor( + packageName: String?, + cp: JcClasspath, + testClassInfo: JcTestClassInfo, + reflectionUtilsInlineStrategy: ReflectionUtilsInlineStrategy, + isAccessibleFromTestClass: ((JcClassOrInterface) -> Boolean)? = null + ): JcTestFileRenderer { + check(testClassInfo !is JcTestClassInfo.SpringUnit && testClassInfo !is JcTestClassInfo.SpringMvc || isAccessibleFromTestClass != null) { + "isAccessible predicate should not be null for spring renderer" + } + + return when (testClassInfo) { + is JcTestClassInfo.SpringMvc -> JcSpringMvcTestFileRenderer( + testClassInfo.clazz.toType(), + packageName, + cp, + reflectionUtilsInlineStrategy, + isAccessibleFromTestClass!! + ) + + is JcTestClassInfo.SpringUnit -> JcSpringUnitTestFileRenderer( + packageName, + cp, + reflectionUtilsInlineStrategy, + isAccessibleFromTestClass!! + ) + + is JcTestClassInfo.Unsafe -> JcUnsafeTestFileRenderer(packageName, cp, reflectionUtilsInlineStrategy) + is JcTestClassInfo.Base -> JcTestFileRenderer(packageName, cp) + } + } +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/JcTestsRenderer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/JcTestsRenderer.kt new file mode 100644 index 0000000000..d8474965a2 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/JcTestsRenderer.kt @@ -0,0 +1,72 @@ +package org.usvm.jvm.rendering + +import com.github.javaparser.StaticJavaParser +import com.github.javaparser.printer.DefaultPrettyPrinter +import org.jacodb.api.jvm.JcClassOrInterface +import org.jacodb.api.jvm.JcClasspath +import org.usvm.jvm.rendering.testRenderer.JcTestInfo +import org.usvm.jvm.rendering.testTransformers.JcCallCtorTransformer +import org.usvm.jvm.rendering.testTransformers.JcPrimitiveWrapperTransformer +import org.usvm.jvm.rendering.testTransformers.JcTestTransformer +import org.usvm.jvm.rendering.testTransformers.JcDeadCodeTransformer +import org.usvm.jvm.rendering.testTransformers.JcOuterThisTransformer +import org.usvm.test.api.UTest + +class JcTestsRenderer { + private val transformers: List = listOf( + JcOuterThisTransformer(), + JcPrimitiveWrapperTransformer(), + JcCallCtorTransformer(), + JcDeadCodeTransformer() + ) + + fun renderTests( + cp: JcClasspath, + tests: List>, + reflectionUtilsInlineStrategy: ReflectionUtilsInlineStrategy = ReflectionUtilsInlineStrategy.NoInline(), + isAccessibleFromTestClass: ((JcClassOrInterface) -> Boolean)? = null + ): Map { + val renderedFiles = mutableMapOf() + val testClasses = tests.groupBy { (_, info) -> JcTestClassInfo.from(info) } + val printer = DefaultPrettyPrinter() + + for ((testClassInfo, testsToRender) in testClasses) { + + val testFile = testClassInfo.testFilePath + val fileRenderer = when { + testFile != null -> { + JcTestFileRendererFactory.testFileRendererFor( + StaticJavaParser.parse(testFile), + cp, + testClassInfo, + reflectionUtilsInlineStrategy, + isAccessibleFromTestClass + ) + } + else -> { + JcTestFileRendererFactory.testFileRendererFor( + testClassInfo.testPackageName, + cp, + testClassInfo, + reflectionUtilsInlineStrategy, + isAccessibleFromTestClass + ) + } + } + + val testClassRenderer = fileRenderer.getOrAddClass(testClassInfo.testClassName) + + for ((test, testInfo) in testsToRender) { + val transformedTest = transformers.fold(test) { currentTest, transformer -> + transformer.transform(currentTest) + } + testClassRenderer.addTest(transformedTest, testInfo.testNamePrefix) + } + + val renderedCu = fileRenderer.render() + + renderedFiles[testClassInfo] = printer.print(renderedCu) + } + return renderedFiles + } +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/ReflectionUtilsInlineStrategy.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/ReflectionUtilsInlineStrategy.kt new file mode 100644 index 0000000000..800c126619 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/ReflectionUtilsInlineStrategy.kt @@ -0,0 +1,197 @@ +package org.usvm.jvm.rendering + +import com.github.javaparser.StaticJavaParser +import com.github.javaparser.ast.CompilationUnit +import com.github.javaparser.ast.Modifier +import com.github.javaparser.ast.NodeList +import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration +import com.github.javaparser.ast.expr.SimpleName +import kotlin.jvm.optionals.getOrNull +import org.usvm.jvm.rendering.baseRenderer.JcImportManager + +sealed class ReflectionUtilsInlineStrategy(val inTestClassFile: Boolean) { + + abstract fun addReflectionUtils(importManager: JcImportManager, cu: CompilationUnit): CompilationUnit + + class NoInline : ReflectionUtilsInlineStrategy(inTestClassFile = false) { + override fun addReflectionUtils( + importManager: JcImportManager, + cu: CompilationUnit + ): CompilationUnit { + return cu + } + } + + class Inline : ReflectionUtilsInlineStrategy(inTestClassFile = true) { + + override fun addReflectionUtils( + importManager: JcImportManager, + cu: CompilationUnit + ): CompilationUnit { + val testClass = cu.types.singleOrNull() + + check(testClass != null) { + "exactly one test class expected" + } + + check(!testClass.getFieldByName("UNSAFE").isPresent) { + "field and init blocks merge not yet supported" + } + + val filteredUtilCu = filterReflectionUtilCu() ?: return cu + + val requiredUtilMembers = filteredUtilCu.getClassByName("ReflectionUtils").get().members + + for (member in requiredUtilMembers) { + if (member.isMethodDeclaration) { + member.asMethodDeclaration().setModifiers(Modifier.Keyword.PRIVATE, Modifier.Keyword.STATIC) + } + } + + testClass.members.addAll(requiredUtilMembers) + + cu.imports.addAll(filteredUtilCu.imports) + + return cu + } + } + + class NestedClass : ReflectionUtilsInlineStrategy(inTestClassFile = true) { + + override fun addReflectionUtils( + importManager: JcImportManager, + cu: CompilationUnit + ): CompilationUnit { + val testClass = cu.types.singleOrNull() + + check(testClass != null) { + "exactly one test class expected" + } + + val filteredUtilCu = filterReflectionUtilCu() ?: return cu + + var utilsClass = + testClass.members.firstOrNull { + it is ClassOrInterfaceDeclaration && it.isNestedType && it.name == SimpleName("ReflectionUtils") + } as? ClassOrInterfaceDeclaration + + val requiredUtils = filteredUtilCu.getClassByName("ReflectionUtils").get() + + if (utilsClass != null) { + mergeUtilClass(utilsClass, requiredUtils) + } else { + utilsClass = requiredUtils + testClass.addMember(utilsClass) + } + + cu.imports.addAll(filteredUtilCu.imports) + + utilsClass.setModifiers(Modifier.Keyword.PRIVATE, Modifier.Keyword.STATIC) + + return cu + } + } + + class OuterClass : ReflectionUtilsInlineStrategy(inTestClassFile = true) { + + override fun addReflectionUtils( + importManager: JcImportManager, + cu: CompilationUnit + ): CompilationUnit { + val filteredUtilCu = filterReflectionUtilCu() + if (filteredUtilCu == null) return cu + + var currentUtilsClass = cu.getClassByName("ReflectionUtils").getOrNull() + val requiredUtilsClass = filteredUtilCu.getClassByName("ReflectionUtils").get() + + if (currentUtilsClass != null) { + mergeUtilClass(currentUtilsClass, requiredUtilsClass) + } else { + currentUtilsClass = requiredUtilsClass + cu.addType(currentUtilsClass) + } + + cu.imports.addAll(filteredUtilCu.imports) + + currentUtilsClass.modifiers = NodeList() + + return cu + } + } + + private val utilsCu: CompilationUnit by lazy { + this::class.java.classLoader.getResourceAsStream("ReflectionUtils.java").use { stream -> + StaticJavaParser.parse(stream) + } + } + + fun useUsvmReflectionMethod(name: String) { + usvmUtilMethodCollector.add(name) + } + + protected fun filterReflectionUtilCu(): CompilationUnit? { + val usedMethods = extractUsedUsvmUtilMethods() + if (usedMethods.isEmpty()) return null + + val cu = utilsCu.clone() + val utilsClass = cu.getClassByName("ReflectionUtils").get() + + utilsClass.members.removeIf { it.isMethodDeclaration && (it.asMethodDeclaration().name.asString() !in usedMethods) } + cu.allContainedComments.forEach { it.remove() } + + return cu + } + + protected fun mergeUtilClass(prev: ClassOrInterfaceDeclaration, extra: ClassOrInterfaceDeclaration) { + val declaredMembersNames = + prev.members.mapNotNull { if (it.isMethodDeclaration) it.asMethodDeclaration().name else null } + + extra.members.filter { it.isMethodDeclaration }.forEach { declaration -> + if (declaration.asMethodDeclaration().name !in declaredMembersNames) { + prev.addMember(declaration) + } + } + } + + + private val usvmUtilMethodCollector: MutableSet = mutableSetOf() + + private val usvmUtilRequiredMethodsMapping = mapOf( + "callConstructor" to listOf("getConstructor", "methodSignature", "parameterTypesSignature"), + "callMethod" to listOf("getMethod", "getInstanceMethods", "methodSignature", "parameterTypesSignature"), + "callStaticMethod" to listOf( + "callMethod", + "getMethod", + "getStaticMethod", + "getStaticMethods", + "getInstanceMethods", + "methodSignature", + "parameterTypesSignature" + ), + "getStaticFieldValue" to listOf( + "getStaticField", + "getFieldValue", + "getOffsetOf", + "isStatic", + "getStaticFields" + ), + "getFieldValue" to listOf("getOffsetOf", "isStatic"), + "setStaticFieldValue" to listOf( + "getStaticField", + "getStaticFields", + "setFieldValue", + "getOffsetOf", + "isStatic" + ), + "setFieldValue" to listOf("getField", "getInstanceFields", "getOffsetOf", "isStatic"), + "allocateInstance" to listOf() + ) + + private fun extractUsedUsvmUtilMethods(): Set { + val usedMethodsTransitive = usvmUtilMethodCollector.flatMap { method -> + usvmUtilRequiredMethodsMapping[method]!! + method + } + + return usedMethodsTransitive.toSet() + } +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/Utils.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/Utils.kt new file mode 100644 index 0000000000..8f77f48cd1 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/Utils.kt @@ -0,0 +1,16 @@ +package org.usvm.jvm.rendering + +import org.jacodb.api.jvm.JcMethod +import org.objectweb.asm.Opcodes + +internal val String.normalized: String + get() = this.replace("<", "").replace(">", "").replace("$", "") + +internal val String.decapitalized: String + get() = replaceFirstChar { it.lowercase() } + +internal val String.capitalized: String + get() = replaceFirstChar { it.titlecase() } + +internal val JcMethod.isVararg: Boolean + get() = access.and(Opcodes.ACC_VARARGS) == Opcodes.ACC_VARARGS \ No newline at end of file diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/baseRenderer/JcBlockRenderer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/baseRenderer/JcBlockRenderer.kt new file mode 100644 index 0000000000..e9cd50f4d6 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/baseRenderer/JcBlockRenderer.kt @@ -0,0 +1,158 @@ +package org.usvm.jvm.rendering.baseRenderer + +import com.github.javaparser.ast.NodeList +import com.github.javaparser.ast.body.VariableDeclarator +import com.github.javaparser.ast.expr.Expression +import com.github.javaparser.ast.expr.NameExpr +import com.github.javaparser.ast.expr.VariableDeclarationExpr +import com.github.javaparser.ast.stmt.BlockStmt +import com.github.javaparser.ast.stmt.ExpressionStmt +import com.github.javaparser.ast.stmt.IfStmt +import com.github.javaparser.ast.stmt.Statement +import com.github.javaparser.ast.type.ReferenceType +import com.github.javaparser.ast.type.Type +import org.jacodb.api.jvm.JcClassType +import org.jacodb.api.jvm.JcClasspath +import org.jacodb.api.jvm.JcField +import org.jacodb.api.jvm.JcMethod +import org.jacodb.api.jvm.JcType + +open class JcBlockRenderer private constructor( + protected open val methodRenderer: JcMethodRenderer, + importManager: JcImportManager, + identifiersManager: JcIdentifiersManager, + cp: JcClasspath, + protected val thrownExceptions: HashSet, + private val vars: HashSet +) : JcCodeRenderer(importManager, identifiersManager, cp) { + + protected constructor( + methodRenderer: JcMethodRenderer, + importManager: JcImportManager, + identifiersManager: JcIdentifiersManager, + cp: JcClasspath, + thrownExceptions: HashSet + ) : this(methodRenderer, importManager, identifiersManager, cp, thrownExceptions, HashSet()) + + constructor( + methodRenderer: JcMethodRenderer, + importManager: JcImportManager, + identifiersManager: JcIdentifiersManager, + cp: JcClasspath + ) : this(methodRenderer, importManager, identifiersManager, cp, HashSet()) + + private val statements = NodeList() + + protected open val classRenderer get() = methodRenderer.classRenderer + + override fun renderInternal(): BlockStmt { + return BlockStmt(statements) + } + + fun getThrownExceptions(): NodeList { + return NodeList(thrownExceptions) + } + + open fun newInnerBlock(): JcBlockRenderer { + val innerIdManager = JcIdentifiersManager(identifiersManager) + val innerVars = HashSet(vars) + return JcBlockRenderer( + methodRenderer, + importManager, + innerIdManager, + cp, + thrownExceptions, + innerVars + ) + } + + fun addExpression(expr: Expression) { + if (expr is NameExpr) { + check(expr in vars) { + "addExpression call on raw NameExpr $expr" + } + return + } + + addStatement(ExpressionStmt(expr)) + } + + fun addStatement(stmt: Statement) { + statements.add(stmt) + } + + fun renderVarDeclaration(type: JcType, expr: Expression? = null, namePrefix: String? = null): NameExpr { + val renderedType = renderType(type) + return renderVarDeclaration(renderedType, expr, namePrefix) + } + + protected fun renderVarDeclaration(type: Type, expr: Expression? = null, namePrefix: String? = null): NameExpr { + if (expr is NameExpr && expr in vars) + return expr + + val name = identifiersManager[namePrefix ?: "v"] + val declarator = VariableDeclarator(type, name, expr) + addExpression(VariableDeclarationExpr(declarator)) + + val freshVar = NameExpr(name) + vars.add(freshVar) + + return freshVar + } + + fun renderIfStatement( + condition: Expression, + initThenBody: (JcBlockRenderer) -> Unit, + initElseBody: (JcBlockRenderer) -> Unit + ) { + val thenBlockRenderer = newInnerBlock() + initThenBody(thenBlockRenderer) + val elseBlockRenderer = newInnerBlock() + initElseBody(elseBlockRenderer) + val thenBlock = thenBlockRenderer.render() + val thenStmt = thenBlock.statements.singleOrNull() ?: thenBlock + val elseBlock = elseBlockRenderer.render() + val elseStmt = elseBlock.statements.singleOrNull() ?: elseBlock + statements.add(IfStmt(condition, thenStmt, elseStmt)) + } + + fun renderArraySetStatement(array: Expression, index: Expression, value: Expression) { + addExpression(renderArraySet(array, index, value)) + } + + fun renderSetFieldStatement(instance: Expression, field: JcField, value: Expression) { + addExpression(renderSetField(instance, field, value)) + } + + fun renderSetStaticFieldStatement(field: JcField, value: Expression) { + addExpression(renderSetStaticField(field, value)) + } + + fun addThrownExceptions(method: JcMethod) { + addThrownExceptions(method.exceptions.map { it.typeName }) + } + + fun addThrownExceptions(exceptionsTypeNames: List) { + exceptionsTypeNames.forEach { addThrownException(it) } + } + + fun addThrownException(typeName: String) { + val thrown = renderClass(typeName) + thrownExceptions.add(thrown) + } + + override fun renderConstructorCall(ctor: JcMethod, type: JcClassType, args: List, inlinesVarargs: Boolean): Expression { + addThrownExceptions(ctor) + return super.renderConstructorCall(ctor, type, args, inlinesVarargs) + } + + override fun renderMethodCall(method: JcMethod, instance: Expression, args: List, inlinesVarargs: Boolean): Expression { + addThrownExceptions(method) + return super.renderMethodCall(method, instance, args, inlinesVarargs) + } + + override fun renderStaticMethodCall(method: JcMethod, args: List, inlinesVarargs: Boolean): Expression { + addThrownExceptions(method) + return super.renderStaticMethodCall(method, args, inlinesVarargs) + } +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/baseRenderer/JcClassRenderer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/baseRenderer/JcClassRenderer.kt new file mode 100644 index 0000000000..7fa608fd59 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/baseRenderer/JcClassRenderer.kt @@ -0,0 +1,140 @@ +package org.usvm.jvm.rendering.baseRenderer + +import com.github.javaparser.ast.Modifier +import com.github.javaparser.ast.NodeList +import com.github.javaparser.ast.body.BodyDeclaration +import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration +import com.github.javaparser.ast.body.FieldDeclaration +import com.github.javaparser.ast.body.VariableDeclarator +import com.github.javaparser.ast.expr.AnnotationExpr +import com.github.javaparser.ast.expr.Expression +import com.github.javaparser.ast.expr.SimpleName +import com.github.javaparser.ast.type.Type +import org.jacodb.api.jvm.JcAnnotation +import org.jacodb.api.jvm.JcClasspath +import org.jacodb.api.jvm.JcField + +open class JcClassRenderer : JcCodeRenderer { + + internal val name: SimpleName + private val modifiers: NodeList + private val annotations: NodeList + private val members: NodeList> + + private val renderingMethods: MutableList = mutableListOf() + + constructor( + decl: ClassOrInterfaceDeclaration, + importManager: JcImportManager, + identifiersManager: JcIdentifiersManager, + cp: JcClasspath + ) : super(importManager, identifiersManager.extendedWith(decl), cp) + { + this.name = decl.name + this.modifiers = decl.modifiers + this.annotations = decl.annotations + this.members = decl.members + } + + constructor( + name: String, + importManager: JcImportManager, + identifiersManager: JcIdentifiersManager, + cp: JcClasspath + ): super(importManager, identifiersManager, cp) { + this.name = identifiersManager.generateIdentifier(name) + this.modifiers = NodeList() + this.annotations = NodeList() + this.members = NodeList() + } + + protected fun addRenderingMethod(render: JcMethodRenderer) { + renderingMethods.add(render) + } + + fun getOrCreateField( + field: JcField, + initializer: Expression? = null + ): SimpleName { + val modifiers: List = getModifiers(field) + val annotations = field.annotations.map { renderAnnotation(it) } + + val fieldType = cp.findTypeOrNull(field.type.typeName) + ?: error("Field type ${field.type.typeName} not found in classpath") + return getOrCreateField( + renderType(fieldType), + field.name, + NodeList(modifiers), + NodeList(annotations), + initializer + ) + } + + fun getOrCreateField( + type: Type, + name: String, + modifiers: NodeList = NodeList(), + annotations: NodeList = NodeList(), + initializer: Expression? = null + ): SimpleName { + val fieldExists = members.any { + it is FieldDeclaration && it.variables.any { variable -> variable.name.asString() == name } + } + + if (fieldExists) + return SimpleName(name) + + return addField(type, name, modifiers, annotations, initializer) + } + + fun addField( + type: Type, + name: String, + modifiers: NodeList = NodeList(), + annotations: NodeList = NodeList(), + initializer: Expression? = null + ): SimpleName { + val fieldName = identifiersManager.generateIdentifier(name) + val declarator = VariableDeclarator(type, fieldName, initializer) + val decl = FieldDeclaration(modifiers, annotations, NodeList(declarator)) + members.add(decl) + + return fieldName + } + + fun addAnnotation(annotation: AnnotationExpr) { + if (!annotations.contains(annotation)) { + annotations.add(annotation) + } + } + + fun addAnnotation(annotation: JcAnnotation) { + val renderedAnnotation = renderAnnotation(annotation) + addAnnotation(renderedAnnotation) + } + + override fun renderInternal(): ClassOrInterfaceDeclaration { + val renderedMembers = mutableListOf>() + for (renderer in renderingMethods) { + try { + renderedMembers.add(renderer.render()) + } catch (e: Throwable) { + println("Renderer failed to render method: ${e.message}") + println("with ${e.stackTraceToString()}") + } + } + val allMembers = NodeList(renderedMembers) + allMembers.addAll(members) + return ClassOrInterfaceDeclaration( + modifiers, + annotations, + false, + name, + NodeList(), + NodeList(), + NodeList(), + NodeList(), + allMembers + ) + } +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/baseRenderer/JcCodeRenderer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/baseRenderer/JcCodeRenderer.kt new file mode 100644 index 0000000000..783fab2264 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/baseRenderer/JcCodeRenderer.kt @@ -0,0 +1,760 @@ +package org.usvm.jvm.rendering.baseRenderer + +import com.github.javaparser.StaticJavaParser +import com.github.javaparser.ast.Modifier +import com.github.javaparser.ast.Node +import com.github.javaparser.ast.NodeList +import com.github.javaparser.ast.body.Parameter +import com.github.javaparser.ast.expr.AnnotationExpr +import com.github.javaparser.ast.expr.ArrayAccessExpr +import com.github.javaparser.ast.expr.ArrayInitializerExpr +import com.github.javaparser.ast.expr.AssignExpr +import com.github.javaparser.ast.expr.BooleanLiteralExpr +import com.github.javaparser.ast.expr.CastExpr +import com.github.javaparser.ast.expr.CharLiteralExpr +import com.github.javaparser.ast.expr.ClassExpr +import com.github.javaparser.ast.expr.DoubleLiteralExpr +import com.github.javaparser.ast.expr.Expression +import com.github.javaparser.ast.expr.FieldAccessExpr +import com.github.javaparser.ast.expr.IntegerLiteralExpr +import com.github.javaparser.ast.expr.LambdaExpr +import com.github.javaparser.ast.expr.LongLiteralExpr +import com.github.javaparser.ast.expr.MarkerAnnotationExpr +import com.github.javaparser.ast.expr.MemberValuePair +import com.github.javaparser.ast.expr.MethodCallExpr +import com.github.javaparser.ast.expr.Name +import com.github.javaparser.ast.expr.NormalAnnotationExpr +import com.github.javaparser.ast.expr.NullLiteralExpr +import com.github.javaparser.ast.expr.ObjectCreationExpr +import com.github.javaparser.ast.expr.StringLiteralExpr +import com.github.javaparser.ast.expr.TypeExpr +import com.github.javaparser.ast.stmt.BlockStmt +import com.github.javaparser.ast.type.ArrayType +import com.github.javaparser.ast.type.ClassOrInterfaceType +import com.github.javaparser.ast.type.PrimitiveType +import com.github.javaparser.ast.type.PrimitiveType.Primitive +import com.github.javaparser.ast.type.Type +import com.github.javaparser.ast.type.VoidType +import com.github.javaparser.ast.type.WildcardType +import kotlin.jvm.optionals.getOrNull +import kotlin.math.max +import org.jacodb.api.jvm.JcAccessible +import org.jacodb.api.jvm.JcAnnotation +import org.jacodb.api.jvm.JcArrayType +import org.jacodb.api.jvm.JcBoundedWildcard +import org.jacodb.api.jvm.JcClassOrInterface +import org.jacodb.api.jvm.JcClassType +import org.jacodb.api.jvm.JcClasspath +import org.jacodb.api.jvm.JcField +import org.jacodb.api.jvm.JcMethod +import org.jacodb.api.jvm.JcPrimitiveType +import org.jacodb.api.jvm.JcType +import org.jacodb.api.jvm.JcTypeVariable +import org.jacodb.api.jvm.JcUnboundWildcard +import org.jacodb.api.jvm.PredefinedPrimitives +import org.jacodb.api.jvm.ext.objectType +import org.jacodb.api.jvm.ext.packageName +import org.jacodb.api.jvm.ext.toType +import org.jacodb.impl.features.classpaths.virtual.JcVirtualField +import org.jacodb.impl.types.JcTypeVariableImpl +import org.jacodb.impl.types.TypeNameImpl +import org.usvm.jvm.rendering.baseRenderer.JcTypeVariableExt.isRecursive +import org.usvm.jvm.rendering.isVararg +import org.usvm.jvm.util.toTypedMethod + +abstract class JcCodeRenderer( + open val importManager: JcImportManager, + internal val identifiersManager: JcIdentifiersManager, + protected val cp: JcClasspath, + private val packagePrivateAsPublic: Boolean = true +) { + + private var rendered: T? = null + + protected abstract fun renderInternal(): T + + fun render(): T { + if (rendered != null) + return rendered!! + + rendered = renderInternal() + return rendered!! + } + + companion object { + val voidType by lazy { VoidType() } + } + + val objectType by lazy { renderClass("java.lang.Object") } + + //region Types + + protected fun qualifiedName(typeName: String): String = typeName.replace("$", ".") + + fun renderType(type: JcType, includeGenericArgs: Boolean = true): Type = when (type) { + is JcPrimitiveType -> { + val fromTypeName = Primitive.byTypeName(type.typeName).getOrNull() + when { + fromTypeName != null -> PrimitiveType(fromTypeName) + type.typeName == PredefinedPrimitives.Void -> VoidType() + else -> error("cannot render primitive ${type.typeName}") + } + } + is JcArrayType -> ArrayType(renderType(type.elementType, includeGenericArgs)) + is JcClassType -> renderClass(type, includeGenericArgs) + is JcTypeVariable -> renderTypeVariable(type, includeGenericArgs) + is JcBoundedWildcard -> renderBoundedWildcardType(type, includeGenericArgs) + is JcUnboundWildcard -> WildcardType() + else -> error("unexpected type ${type.typeName}") + } + + private fun renderTypeVariable(variable: JcTypeVariable, includeGenericArgs: Boolean): Type { + val isRecursive = (variable as? JcTypeVariableImpl)?.isRecursive + return renderType(variable.bounds.firstOrNull() ?: cp.objectType, includeGenericArgs && isRecursive == false) + } + + private fun renderBoundedWildcardType(type: JcBoundedWildcard, includeGenericArgs: Boolean): Type { + val bound = (type.lowerBounds + type.upperBounds).first() + return renderType(bound, includeGenericArgs) + } + + fun renderClass(typeName: String, includeGenericArgs: Boolean = true): ClassOrInterfaceType { + check(!typeName.contains('<') && !typeName.contains('>')) { + "hardcoded generics not supported" + } + + val type = cp.findTypeOrNull(typeName) as? JcClassType + if (type != null) + return renderClass(type, includeGenericArgs) + var classOrInterface = StaticJavaParser.parseClassOrInterfaceType(typeName) + if (importManager.add(classOrInterface.nameWithScope)) + classOrInterface = classOrInterface.removeScope() + return classOrInterface + } + + fun shouldRenderClassAsPrivate(type: JcClassType): Boolean { + return !(type.isPublic || packagePrivateAsPublic && type.isPackagePrivate) + } + + fun renderClass(type: JcClassType, includeGenericArgs: Boolean = true): ClassOrInterfaceType { + check(!shouldRenderClassAsPrivate(type)) { + "Rendering private classes is not supported. Cannot render ${type.typeName}" + } + check(!type.jcClass.isAnonymous) { "Rendering anonymous classes is not supported" } + + return when { + type.outerType == null -> renderClassOuter(type, includeGenericArgs) + type.isStatic -> renderClassInnerStatic(type, includeGenericArgs) + else -> renderClassInner(type, includeGenericArgs) + } + } + + private fun renderClassInnerStatic(type: JcClassType, includeGenericArgs: Boolean): ClassOrInterfaceType { + val simpleNameParts = qualifiedName(type.jcClass.simpleName).split(".") + val renderName = qualifiedName(type.jcClass.name) + val importPackage = type.jcClass.packageName + "." + simpleNameParts.dropLast(1).joinToString(".") + val importName = simpleNameParts.last() + if (importManager.addStatic(importPackage, importName)) { + return StaticJavaParser.parseClassOrInterfaceType(simpleNameParts.last()) + .setTypeArgsIfNeeded(includeGenericArgs, type) + } + return ClassOrInterfaceType(renderClass(type.outerType!!, false), renderName) + .setTypeArgsIfNeeded(includeGenericArgs, type) + } + + private fun ClassOrInterfaceType.setTypeArgsIfNeeded( + includeGenericArgs: Boolean, + type: JcClassType + ): ClassOrInterfaceType { + if (includeGenericArgs && type.typeArguments.isNotEmpty()) + return setTypeArguments(NodeList(type.typeArguments.map { renderType(it) })) + + return this + } + + private fun renderClassInner(type: JcClassType, includeGenericArgs: Boolean): ClassOrInterfaceType { + return ClassOrInterfaceType( + renderClass(type.outerType!!, includeGenericArgs), + qualifiedName(type.jcClass.simpleName).split(".").last() + ).setTypeArgsIfNeeded(includeGenericArgs, type) + } + + private fun renderClassOuter(type: JcClassType, includeGenericArgs: Boolean): ClassOrInterfaceType { + val renderName = when { + importManager.add(type.jcClass.packageName, type.jcClass.simpleName) -> qualifiedName(type.jcClass.simpleName) + else -> qualifiedName(type.jcClass.name) + } + + val renderedType = StaticJavaParser.parseClassOrInterfaceType(qualifiedName(renderName)) + val argTypes = type.typeArguments.zip(type.typeParameters).map { (a, p) -> + when (a) { + is JcUnboundWildcard -> p.bounds.firstOrNull() ?: type.classpath.objectType + is JcBoundedWildcard -> (a.lowerBounds + a.upperBounds).first() + else -> a + } + } + + if (!includeGenericArgs || argTypes.any { (it is JcTypeVariableImpl && it.isRecursive) }) + return renderedType.removeTypeArguments() + + if (argTypes.isEmpty()) + return renderedType + + val renderedTypeArguments = argTypes.map { renderType(it, includeGenericArgs) } + return renderedType.setTypeArguments(NodeList(renderedTypeArguments)) + } + + fun renderClass(jcClass: JcClassOrInterface, includeGenericArgs: Boolean = true): ClassOrInterfaceType { + return renderClass(jcClass.toType(), includeGenericArgs) + } + + fun renderClassExpression(type: JcClassOrInterface): Expression = + ClassExpr(renderClass(type, false)) + + fun renderClassExpression(type: JcType): Expression = + ClassExpr(renderType(type, false)) + + //endregion + + //region Methods + + //region Test framework methods + + val assertionsClass: ClassOrInterfaceType get() = renderClass(JcTestFrameworkProvider.assertionsClassName) + + fun assertThrowsCall(exceptionClassExpr: Expression, observedLambda: Expression): MethodCallExpr { + check(JcTestFrameworkProvider.assertionsClassName != JcRendererTestFramework.JUNIT_4.assertionsClassName) { + "not yet supported" + } + + return MethodCallExpr( + TypeExpr(assertionsClass), + "assertThrows", + NodeList(exceptionClassExpr, observedLambda) + ) + } + + fun assertEqualsCall(expected: Expression, actual: Expression): MethodCallExpr { + val operands = mutableListOf(expected, actual) + if (JcTestFrameworkProvider.requireActualExpectedEqualsOrder) + operands.reverse() + + return MethodCallExpr( + TypeExpr(assertionsClass), + "assertEquals", + NodeList(operands) + ) + } + + //endregion + + //region Mockito methods + + val mockitoClass: ClassOrInterfaceType by lazy { renderClass("org.mockito.Mockito") } + + val mockitoCallsRealMethod by lazy { FieldAccessExpr(TypeExpr(mockitoClass), "CALLS_REAL_METHODS") } + + protected val JcField.isSpy: Boolean + get() = this is JcVirtualField && + name == "\$isSpyGenerated239" && + type == TypeNameImpl.fromTypeName("java.lang.Object") + + fun mockitoMockMethodCall(classToMock: JcClassType): MethodCallExpr { + return MethodCallExpr( + TypeExpr(mockitoClass), + "mock", + NodeList(renderClassExpression(classToMock)) + ) + } + + fun mockitoSpyInstanceMethodCall(instanceToSpy: Expression): MethodCallExpr { + return MethodCallExpr( + TypeExpr(mockitoClass), + "spy", + NodeList(instanceToSpy) + ) + } + + fun mockitoSpyClassMethodCall(classToSpy: JcClassType): MethodCallExpr { + return MethodCallExpr( + TypeExpr(mockitoClass), + "spy", + NodeList(renderClassExpression(classToSpy)) + ) + } + + fun mockitoMockStaticMethodCall(mockedClass: JcClassOrInterface): MethodCallExpr { + val mockito = TypeExpr(mockitoClass) + return MethodCallExpr( + mockito, + "mockStatic", + NodeList(renderClassExpression(mockedClass), mockitoCallsRealMethod) + ) + } + + fun mockitoWhenMethodCall(receiver: Expression, methodCall: Expression): MethodCallExpr { + return MethodCallExpr( + receiver, + "when", + NodeList(methodCall) + ) + } + + fun mockitoThenThrowMethodCall(methodMock: Expression, exception: Expression): MethodCallExpr { + return MethodCallExpr( + methodMock, + "thenThrow", + NodeList(exception) + ) + } + + fun mockitoThenReturnMethodCall(methodMock: Expression, methodValue: Expression): MethodCallExpr { + return MethodCallExpr( + methodMock, + "thenReturn", + NodeList(methodValue) + ) + } + + fun mockitoDoAnswerMethodCall(receiver: Expression, doAnswerLambda: LambdaExpr): MethodCallExpr { + return MethodCallExpr( + receiver, + "doAnswer", + NodeList(doAnswerLambda) + ) + } + + fun mockitoThenAnswerMethodCall(receiver: Expression, thenAnswerLambda: LambdaExpr): MethodCallExpr { + return MethodCallExpr( + receiver, + "thenAnswer", + NodeList(thenAnswerLambda) + ) + } + + fun mockitoAnyBooleanMethodCall(): MethodCallExpr { + return MethodCallExpr( + TypeExpr(mockitoClass), + "anyBoolean", + NodeList() + ) + } + + fun mockitoAnyByteMethodCall(): MethodCallExpr { + return MethodCallExpr( + TypeExpr(mockitoClass), + "anyByte", + NodeList() + ) + } + + fun mockitoAnyCharMethodCall(): MethodCallExpr { + return MethodCallExpr( + TypeExpr(mockitoClass), + "anyChar", + NodeList() + ) + } + + fun mockitoAnyIntMethodCall(): MethodCallExpr { + return MethodCallExpr( + TypeExpr(mockitoClass), + "anyInt", + NodeList() + ) + } + + fun mockitoAnyLongMethodCall(): MethodCallExpr { + return MethodCallExpr( + TypeExpr(mockitoClass), + "anyLong", + NodeList() + ) + } + + fun mockitoAnyFloatMethodCall(): MethodCallExpr { + return MethodCallExpr( + TypeExpr(mockitoClass), + "anyFloat", + NodeList() + ) + } + + fun mockitoAnyDoubleMethodCall(): MethodCallExpr { + return MethodCallExpr( + TypeExpr(mockitoClass), + "anyDouble", + NodeList() + ) + } + + fun mockitoAnyShortMethodCall(): MethodCallExpr { + return MethodCallExpr( + TypeExpr(mockitoClass), + "anyShort", + NodeList() + ) + } + + fun mockitoAnyMethodCall(type: JcType): MethodCallExpr { + return MethodCallExpr( + TypeExpr(mockitoClass), + NodeList(renderType(type)), + "any", + NodeList() + ) + } + + //endregion + + fun shouldRenderMethodCallAsPrivate(method: JcMethod): Boolean { + return !method.isPublic + } + + open fun renderPrivateCtorCall(ctor: JcMethod, type: JcClassType, args: List, inlinesVarargs: Boolean): Expression { + error("Rendering private methods is not supported") + } + + open fun renderConstructorCall(ctor: JcMethod, type: JcClassType, args: List, inlinesVarargs: Boolean): Expression { + check(ctor.isConstructor) { + "not a constructor in renderConstructorCall" + } + if (shouldRenderMethodCallAsPrivate(ctor)) + return renderPrivateCtorCall(ctor, type, args, inlinesVarargs) + + val castedArgs = callArgsWithGenericsCasted(ctor, args, inlinesVarargs) + + return when { + type.outerType == null || type.isStatic -> { + ObjectCreationExpr(null, renderClass(type), NodeList(castedArgs)) + } + + else -> { + val ctorTypeName = qualifiedName(type.jcClass.name).split(".").last() + val ctorType = StaticJavaParser.parseClassOrInterfaceType(ctorTypeName) + .setTypeArgsIfNeeded(true, type) + ObjectCreationExpr(castedArgs.first(), ctorType, NodeList(castedArgs.drop(1))) + } + } + } + + open fun renderPrivateMethodCall(method: JcMethod, instance: Expression, args: List, inlinesVarargs: Boolean): Expression { + error("Rendering private methods is not supported") + } + + open fun renderMethodCall(method: JcMethod, instance: Expression, args: List, inlinesVarargs: Boolean): Expression { + check(!method.isStatic) { + "cannot render static methods in renderMethodCall" + } + + if (shouldRenderMethodCallAsPrivate(method)) + return renderPrivateMethodCall(method, instance, args, inlinesVarargs) + + val castedArgs = callArgsWithGenericsCasted(method, args, inlinesVarargs) + + return MethodCallExpr( + instance, + method.name, + NodeList(castedArgs) + ) + } + + open fun renderPrivateStaticMethodCall(method: JcMethod, args: List, inlinesVarargs: Boolean): Expression { + error("Rendering private methods is not supported") + } + + open fun renderStaticMethodCall(method: JcMethod, args: List, inlinesVarargs: Boolean): Expression { + check(method.isStatic) { + "cannot render instance method in renderStaticMethodCall" + } + + if (shouldRenderMethodCallAsPrivate(method)) + return renderPrivateStaticMethodCall(method, args, inlinesVarargs) + + val castedArgs = callArgsWithGenericsCasted(method, args, inlinesVarargs) + + return MethodCallExpr( + renderStaticMethodCallScope(method, false), + method.name, + NodeList(castedArgs) + ) + } + + protected fun callArgsWithGenericsCasted(method: JcMethod, args: List, hasInlinedVarArgs: Boolean): List { + val typedParams = method.toTypedMethod.parameters.map { parameter -> parameter.type }.toMutableList() + + if (hasInlinedVarArgs) { + check(method.isVararg) { + "cannot inline non-vararg args" + } + + val varargParamType = typedParams.removeLast() + check(varargParamType is JcArrayType) { + "vararg param expected to be of array type" + } + + val extraArgType = varargParamType.elementType + val extraParamCount = max(args.size - typedParams.size, 0) + typedParams.addAll(List(extraParamCount) { extraArgType }) + } + + return args.zip(typedParams).map { (arg, paramType) -> + exprWithGenericsCasted(paramType, arg) + } + } + + protected fun exprWithGenericsCasted(type: JcType, expr: Expression): Expression { + if (type !is JcClassType || type.typeArguments.isEmpty()) return expr + val asObj = CastExpr(objectType, expr) + val asTargetType = CastExpr(renderType(type), asObj) + return asTargetType + } + + @Suppress("SameParameterValue") + private fun renderStaticMethodCallScope(method: JcMethod, allowStaticImport: Boolean): TypeExpr? { + val callType = method.enclosingClass.toType() + val useClassName = !allowStaticImport || !importManager.addStatic(callType.jcClass.name, method.name) + return if (useClassName) TypeExpr(renderClass(callType, includeGenericArgs = false)) else null + } + + protected open fun renderLambdaExpression(params: List, body: BlockStmt): Expression { + return LambdaExpr(NodeList(params), body) + } + + protected fun renderMethodParameter(type: JcType, name: String? = null): Parameter { + return renderMethodParameter(type.typeName, name) + } + + protected fun renderMethodParameter(clazz: JcClassOrInterface, name: String? = null): Parameter { + return renderMethodParameter(clazz.name, name) + } + + protected fun renderMethodParameter(typeName: String, name: String? = null): Parameter { + val paramName = identifiersManager.generateIdentifier(name ?: "param") + val renderedClass = renderClass(typeName) + return Parameter(renderedClass, paramName) + } + + //endregion + + //region Fields + + protected open fun shouldRenderGetFieldAsPrivate(field: JcField): Boolean { + return !field.isPublic + } + + protected open fun shouldRenderSetFieldAsPrivate(field: JcField): Boolean { + return !field.isPublic || field.isFinal + } + + open fun renderGetPrivateStaticField(field: JcField): Expression { + error("Rendering private fields is not supported") + } + + open fun renderGetStaticField(field: JcField): Expression { + check(field.isStatic) { + "cannot render instance field in renderGetStaticField" + } + + if (shouldRenderGetFieldAsPrivate(field)) + return renderGetPrivateStaticField(field) + + return FieldAccessExpr( + TypeExpr(renderClass(field.enclosingClass)), + field.name + ) + } + + open fun renderGetPrivateField(instance: Expression, field: JcField): Expression { + error("Rendering private fields is not supported") + } + + open fun renderGetField(instance: Expression, field: JcField): Expression { + check(!field.isStatic) { + "cannot render static field in renderGetField" + } + + if (shouldRenderGetFieldAsPrivate(field)) + return renderGetPrivateField(instance, field) + + return FieldAccessExpr( + instance, + field.name + ) + } + + open fun renderSetPrivateStaticField(field: JcField, value: Expression): Expression { + error("Rendering private fields is not supported") + } + + fun renderAssign(lhv: Expression, rhv: Expression): Expression { + return AssignExpr(lhv, rhv, AssignExpr.Operator.ASSIGN) + } + + fun renderSetStaticField(field: JcField, value: Expression): Expression { + check(field.isStatic) { + "cannot render instance field in renderSetStaticField" + } + + if (shouldRenderSetFieldAsPrivate(field)) + return renderSetPrivateStaticField(field, value) + + return renderAssign( + FieldAccessExpr(TypeExpr(renderClass(field.enclosingClass)), field.name), + value + ) + } + + open fun renderSetPrivateField(instance: Expression, field: JcField, value: Expression): Expression { + error("Rendering private fields is not supported") + } + + fun renderSetField(instance: Expression, field: JcField, value: Expression): Expression { + check(!field.isStatic) { + "cannot render static field in renderSetField" + } + + if (shouldRenderSetFieldAsPrivate(field)) + return renderSetPrivateField(instance, field, value) + + return renderAssign( + FieldAccessExpr(instance, field.name), + value + ) + } + + //endregion + + //region Arrays + + fun renderArraySet(array: Expression, index: Expression, value: Expression): Expression { + return renderAssign( + ArrayAccessExpr(array, index), + value + ) + } + + //endregion + + //region Modifiers + + fun getModifiers(accessible: JcAccessible): List { + val modifiers = mutableListOf() + if (accessible.isPublic) + modifiers.add(Modifier.publicModifier()) + if (accessible.isStatic) + modifiers.add(Modifier.staticModifier()) + if (accessible.isFinal) + modifiers.add(Modifier.finalModifier()) + if (accessible.isPrivate) + modifiers.add(Modifier.privateModifier()) + if (accessible.isAbstract) + modifiers.add(Modifier.abstractModifier()) + if (accessible.isProtected) + modifiers.add(Modifier.protectedModifier()) + + return modifiers + } + + //endregion + + //region Primitives + + fun renderBooleanPrimitive(value: Boolean): Expression = BooleanLiteralExpr(value) + + fun renderCharPrimitive(value: Char): Expression = CharLiteralExpr(value) + + fun renderStringPrimitive(value: String): Expression = StringLiteralExpr(value) + + fun renderBytePrimitive(value: Byte): Expression = + CastExpr(PrimitiveType.byteType(), IntegerLiteralExpr(value.toString())) + + fun renderShortPrimitive(value: Short): Expression = + CastExpr(PrimitiveType.shortType(), IntegerLiteralExpr(value.toString())) + + fun renderIntPrimitive(value: Int): Expression = IntegerLiteralExpr(value.toString()) + + fun renderLongPrimitive(value: Long): Expression = LongLiteralExpr(value.toString() + "L") + + fun renderFloatPrimitive(value: Float): Expression = DoubleLiteralExpr(value.toString() + "f") + + fun renderDoublePrimitive(value: Double): Expression = DoubleLiteralExpr(value.toString()) + + //endregion + + //region Annotations + + fun renderAnnotation(annotation: JcAnnotation): AnnotationExpr { + val annotationClass = annotation.jcClass + + check(annotationClass != null) { + "annotation class is null" + } + + val annotationName = Name(renderClass(annotationClass).nameWithScope) + + if (annotation.values.isEmpty()) { + return MarkerAnnotationExpr(annotationName) + } + + val annotationValues = renderAnnotationValues(annotation.values) + return NormalAnnotationExpr(annotationName, annotationValues) + } + + private fun renderAnnotationValues(rawValues: Map): NodeList { + val result = rawValues.map { (name, value) -> + val renderedValue = renderSingleAnnotationValue(value) + MemberValuePair(name, renderedValue) + } + return NodeList(result) + } + + private fun renderSingleAnnotationValue(value: Any?): Expression { + return when (value) { + null -> NullLiteralExpr() + + is String -> renderStringPrimitive(value) + + is Boolean -> renderBooleanPrimitive(value) + + is Char -> renderCharPrimitive(value) + + is Byte -> renderBytePrimitive(value) + + is Short-> renderShortPrimitive(value) + + is Int -> renderIntPrimitive(value) + + is Long -> renderLongPrimitive(value) + + is Float -> renderFloatPrimitive(value) + + is Double -> renderDoublePrimitive(value) + + is JcClassOrInterface -> { + renderClassExpression(value) + } + + is JcField -> { + check(value.isStatic) { + "enum value should be a static field" + } + + renderGetStaticField(value) + } + + is JcAnnotation -> { + renderAnnotation(value) + } + + is List<*> -> { + val renderedValues = NodeList(value.map { renderSingleAnnotationValue(it) }) + ArrayInitializerExpr(renderedValues) + } + + else -> error("unsupported annotation value kind $value") + } + } + + //endregion +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/baseRenderer/JcFileRenderer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/baseRenderer/JcFileRenderer.kt new file mode 100644 index 0000000000..0dc76975b2 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/baseRenderer/JcFileRenderer.kt @@ -0,0 +1,122 @@ +package org.usvm.jvm.rendering.baseRenderer + +import com.github.javaparser.StaticJavaParser +import com.github.javaparser.ast.CompilationUnit +import com.github.javaparser.ast.NodeList +import com.github.javaparser.ast.PackageDeclaration +import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration +import com.github.javaparser.ast.body.TypeDeclaration +import com.github.javaparser.ast.expr.SimpleName +import kotlin.jvm.optionals.getOrNull +import org.jacodb.api.jvm.JcClasspath + +open class JcFileRenderer : JcCodeRenderer { + + companion object { + private fun resolvePackageDeclarationFrom(packageName: String?, cu: CompilationUnit?): PackageDeclaration? { + val existingPackageDecl = cu?.packageDeclaration?.getOrNull() + + if (existingPackageDecl != null) + return existingPackageDecl + + return if (packageName.isNullOrBlank()) + null + else + PackageDeclaration(StaticJavaParser.parseName(packageName)) + } + } + + private constructor( + packageName: String?, + cu: CompilationUnit?, + importManager: JcImportManager, + cp: JcClasspath + ) : super( + importManager, + JcIdentifiersManager(cu), + cp + ) { + this.packageDeclaration = resolvePackageDeclarationFrom(packageName, cu) + this.existingMembers.addAll(cu?.types?.filterIsInstance().orEmpty()) + } + + protected constructor( + cu: CompilationUnit, + importManager: JcImportManager, + cp: JcClasspath + ) : this( + null, + cu, + importManager, + cp + ) + + protected constructor( + packageName: String?, + importManager: JcImportManager, + cp: JcClasspath + ) : this( + packageName, + null, + importManager, + cp + ) + + protected constructor(cu: CompilationUnit, cp: JcClasspath) : this(cu, JcImportManager(cu), cp) + + protected constructor(packageName: String?, cp: JcClasspath): this(packageName, JcImportManager(), cp) + + protected val packageDeclaration: PackageDeclaration? + + private val existingMembers: MutableList = mutableListOf() + private val renderingClasses: MutableList = mutableListOf() + + override fun renderInternal(): CompilationUnit { + val renderedClasses = mutableListOf>() + + for (renderer in renderingClasses) { + try { + val classRender = renderer.render() + renderedClasses.add(classRender) + } catch (e: Throwable) { + println("Renderer failed to render class: ${e.message}") + println("with ${e.stackTraceToString()}") + } + } + + val importDeclarations = importManager.render() + + val classEntries = NodeList(renderedClasses) + classEntries.addAll(existingMembers) + + return CompilationUnit( + packageDeclaration, + importDeclarations, + classEntries, + null + ) + } + + protected open fun classRendererFor(declaration: ClassOrInterfaceDeclaration): JcClassRenderer { + return JcClassRenderer(declaration, importManager, identifiersManager, cp) + } + + protected fun findRenderingClass(name: SimpleName): JcClassRenderer? { + val existingRenderer = renderingClasses.find { it.name == name } + if (existingRenderer != null) return existingRenderer + + val existingDecl = existingMembers.find { decl -> decl.name == name } + + if (existingDecl == null) return null + + existingMembers.removeIf { decl -> decl == existingDecl } + val renderer = classRendererFor(existingDecl) + addRenderingClass(renderer) + + return renderer + } + + fun addRenderingClass(render: JcClassRenderer) { + renderingClasses.add(render) + } +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/baseRenderer/JcIdentifiersManager.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/baseRenderer/JcIdentifiersManager.kt new file mode 100644 index 0000000000..59f04b9df7 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/baseRenderer/JcIdentifiersManager.kt @@ -0,0 +1,49 @@ +package org.usvm.jvm.rendering.baseRenderer + +import com.github.javaparser.ast.CompilationUnit +import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration +import com.github.javaparser.ast.expr.SimpleName +import org.usvm.jvm.rendering.normalized +import kotlin.math.max + +class JcIdentifiersManager private constructor( + private val prefixIndexer: MutableMap +) { + companion object { + private fun MutableMap.addAll(values: List) { + values.forEach { value -> + val index = value.takeLastWhile { it.isDigit() }.toIntOrNull() ?: 0 + val prefix = value.dropLastWhile { it.isDigit() } + compute(prefix) { pref, curMax -> max(curMax ?: 0, index) } + } + } + } + + constructor(manager: JcIdentifiersManager) : this(manager.prefixIndexer.toMutableMap()) + + constructor(cu: CompilationUnit? = null): this(mutableMapOf()) { + val classes = cu?.types?.map { it.name.asString() }.orEmpty() + prefixIndexer.addAll(classes) + } + + fun extendedWith(declaration: ClassOrInterfaceDeclaration): JcIdentifiersManager { + val newManager = JcIdentifiersManager(this) + + val fields = declaration.fields.flatMap { field -> field.variables.map { v -> v.name.asString() } } + newManager.prefixIndexer.addAll(fields) + + val methods = declaration.methods.map { it.name.asString() } + newManager.prefixIndexer.addAll(methods) + + return newManager + } + + fun generateIdentifier(prefix: String): SimpleName { + val normalizedPrefix = prefix.normalized + val id = prefixIndexer.merge(normalizedPrefix, 0) { a, _ -> a + 1 } + val suffix = if (id == 0) "" else id.toString() + return SimpleName("$normalizedPrefix$suffix") + } + + operator fun get(prefix: String): SimpleName = generateIdentifier(prefix) +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/baseRenderer/JcImportManager.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/baseRenderer/JcImportManager.kt new file mode 100644 index 0000000000..1580a2e3dc --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/baseRenderer/JcImportManager.kt @@ -0,0 +1,77 @@ +package org.usvm.jvm.rendering.baseRenderer + +import com.github.javaparser.ast.CompilationUnit +import com.github.javaparser.ast.ImportDeclaration +import com.github.javaparser.ast.NodeList + +open class JcImportManager(cu: CompilationUnit? = null) { + protected val names: MutableSet + protected val staticNames: MutableSet + protected val packages: MutableSet + protected val staticPackages: MutableSet + + private val simpleToPackage: MutableMap + + init { + if (cu != null) { + names = cu.imports.filter { !it.isAsterisk && !it.isStatic }.map { it.nameAsString }.toMutableSet() + staticNames = cu.imports.filter { !it.isAsterisk && it.isStatic }.map { it.nameAsString }.toMutableSet() + packages = cu.imports.filter { it.isAsterisk && !it.isStatic }.map { it.nameAsString }.toMutableSet() + staticPackages = cu.imports.filter { it.isAsterisk && it.isStatic }.map { it.nameAsString }.toMutableSet() + + simpleToPackage = (names + staticNames).associateBy { it.split(".").last() }.toMutableMap() + } else { + names = mutableSetOf() + simpleToPackage = mutableMapOf() + staticNames = mutableSetOf() + packages = mutableSetOf() + staticPackages = mutableSetOf() + } + + packages.add("java.lang") + } + + protected open fun add( + packageName: String, + simpleName: String, + packages: MutableSet, + names: MutableSet + ): Boolean { + if (packageName in packages) return true + val fullName = "$packageName.$simpleName".trimStart('.') + if (fullName in names) return true + if (simpleToPackage.putIfAbsent(simpleName, fullName) != null) + return false + + names.add(fullName) + + return true + } + + fun add(import: String): Boolean { + return add(import.substringBeforeLast('.', ""), import.substringAfterLast('.')) + } + + fun add(packageName: String, simpleName: String): Boolean { + return add(packageName, simpleName, packages, names) + } + + fun addStatic(import: String): Boolean { + val tokens = import.split(".") + return addStatic(tokens.dropLast(1).joinToString("."), tokens.last()) + } + + fun addStatic(packageName: String, simpleName: String): Boolean { + return add(packageName, simpleName, staticPackages, staticNames) + } + + fun render(): NodeList { + val declarations = buildList { + addAll(names.map { ImportDeclaration(it, false, false) }) + addAll(packages.map { ImportDeclaration(it, false, true) }) + addAll(staticNames.map { ImportDeclaration(it, true, false) }) + addAll(staticPackages.map { ImportDeclaration(it, true, true) }) + } + return NodeList(declarations) + } +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/baseRenderer/JcMethodRenderer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/baseRenderer/JcMethodRenderer.kt new file mode 100644 index 0000000000..840cbad855 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/baseRenderer/JcMethodRenderer.kt @@ -0,0 +1,43 @@ +package org.usvm.jvm.rendering.baseRenderer + +import com.github.javaparser.ast.Modifier +import com.github.javaparser.ast.NodeList +import com.github.javaparser.ast.body.MethodDeclaration +import com.github.javaparser.ast.body.Parameter +import com.github.javaparser.ast.expr.AnnotationExpr +import com.github.javaparser.ast.expr.SimpleName +import com.github.javaparser.ast.type.Type +import org.jacodb.api.jvm.JcClasspath + +open class JcMethodRenderer( + importManager: JcImportManager, + identifiersManager: JcIdentifiersManager, + cp: JcClasspath, + internal open val classRenderer: JcClassRenderer, + private val name: SimpleName, + private val modifiers: NodeList, + private val annotations: NodeList, + private val parameters: NodeList, + private val returnType: Type +): JcCodeRenderer(importManager, identifiersManager, cp) { + + protected open val body: JcBlockRenderer = JcBlockRenderer( + this, + importManager, + JcIdentifiersManager(identifiersManager), + cp + ) + + override fun renderInternal(): MethodDeclaration { + return MethodDeclaration( + modifiers, + annotations, + NodeList(), + returnType, + name, + parameters, + body.getThrownExceptions(), + body.render() + ) + } +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/baseRenderer/JcTestFrameworkProvider.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/baseRenderer/JcTestFrameworkProvider.kt new file mode 100644 index 0000000000..11881022b2 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/baseRenderer/JcTestFrameworkProvider.kt @@ -0,0 +1,21 @@ +package org.usvm.jvm.rendering.baseRenderer + +enum class JcRendererTestFramework(val testAnnotationClassName: String, val assertionsClassName: String) { + JUNIT_5("org.junit.jupiter.api.Test", "org.junit.jupiter.api.Assertions"), + JUNIT_4("org.junit.Test", ""), + TEST_NG("org.testng.annotations.Test", "org.testng.Assert") +} + +object JcTestFrameworkProvider { + private var framework: JcRendererTestFramework = JcRendererTestFramework.JUNIT_5 + + fun setTestFramework(newFramework: JcRendererTestFramework) { + framework = newFramework + } + + val assertionsClassName: String get() = framework.assertionsClassName + + val testAnnotationClassName: String get() = framework.testAnnotationClassName + + val requireActualExpectedEqualsOrder: Boolean get() = framework == JcRendererTestFramework.TEST_NG +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/baseRenderer/JcTypeVariableExt.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/baseRenderer/JcTypeVariableExt.kt new file mode 100644 index 0000000000..33cd221986 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/baseRenderer/JcTypeVariableExt.kt @@ -0,0 +1,96 @@ +package org.usvm.jvm.rendering.baseRenderer + +import org.jacodb.api.jvm.JcClassType +import org.jacodb.api.jvm.JcClasspath +import org.jacodb.api.jvm.JcTypeVariable +import org.jacodb.api.jvm.JvmType +import org.jacodb.api.jvm.ext.findType +import org.jacodb.impl.types.JcTypeVariableDeclarationImpl +import org.jacodb.impl.types.JcTypeVariableImpl +import org.jacodb.impl.types.signature.JvmBoundWildcard +import org.jacodb.impl.types.signature.JvmClassRefType +import org.jacodb.impl.types.signature.JvmParameterizedType +import org.jacodb.impl.types.signature.JvmRefType +import org.jacodb.impl.types.signature.JvmTypeParameterDeclarationImpl +import org.jacodb.impl.types.signature.JvmTypeVariable +import org.jacodb.impl.types.signature.JvmUnboundWildcard + +/* + * TODO: JcTypeVariable require info about declaration + */ +object JcTypeVariableExt { + + val JcTypeVariableImpl.isRecursive: Boolean get() { + val declaration = (declaration as? JcTypeVariableDeclarationImpl) ?: return false + val jvmBounds = declaration.jvmBounds + + val existingSymbols = HashSet() + + existingSymbols.add(this.toJvmTypeOrNull()!!) + + return jvmBounds.any { bound -> bound.presentInOrRecursive(existingSymbols, this.classpath) } + } + + private fun JcTypeVariable.toJvmTypeOrNull(): JvmType? { + this as? JcTypeVariableImpl ?: return null + val declaration = (declaration as? JcTypeVariableDeclarationImpl) ?: return null + val thisTypeDecl = JvmTypeParameterDeclarationImpl(declaration.symbol, declaration.owner, declaration.jvmBounds) + val thisType = JvmTypeVariable(thisTypeDecl, this.nullable, this.annotations) + return thisType + } + + private fun JvmType.presentInOrRecursive(existingSymbols: HashSet, cp: JcClasspath): Boolean { + return when (this) { + + is JvmTypeVariable -> { + if (!existingSymbols.add(this)) return true + + val bounds = this.declaration?.bounds ?: return false + + bounds.any { bound -> + bound.presentInOrRecursive(HashSet(existingSymbols), cp) + } + } + + is JvmBoundWildcard -> { + if (!existingSymbols.add(bound)) return true + + bound.presentInOrRecursive(existingSymbols, cp) + } + + is JvmParameterizedType -> { + if (!existingSymbols.add(this)) return true + val decl = cp.findType(this.name) as JcClassType + + val params = decl.typeParameters.zip(this.parameterTypes).mapNotNull { (declParam, realParam) -> + if (realParam !is JvmUnboundWildcard) { + realParam + } + else { + check(declParam is JcTypeVariableDeclarationImpl) { + "JcClassType declaration expected not to have ?" + } + declParam.jvmBounds.firstOrNull() + } + } + + return params.any { bound -> + bound.presentInOrRecursive(HashSet(existingSymbols), cp) + } + } + + is JvmRefType -> { + if (!existingSymbols.add(this)) return true + this as? JvmClassRefType ?: return false + (cp.findType(this.name) as JcClassType).typeParameters.any { declParam -> + check(declParam is JcTypeVariableDeclarationImpl) { + "JcClassType declaration expected not to have ?" + } + declParam.jvmBounds.firstOrNull()?.presentInOrRecursive(existingSymbols, cp) == true + } + } + + else -> false + } + } +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/JcSpringReflectionUtilsRenderer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/JcSpringReflectionUtilsRenderer.kt new file mode 100644 index 0000000000..721e9af29e --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/JcSpringReflectionUtilsRenderer.kt @@ -0,0 +1,249 @@ +package org.usvm.jvm.rendering.spring + +import com.github.javaparser.ast.NodeList +import com.github.javaparser.ast.expr.CastExpr +import com.github.javaparser.ast.expr.ClassExpr +import com.github.javaparser.ast.expr.Expression +import com.github.javaparser.ast.expr.MethodCallExpr +import com.github.javaparser.ast.expr.NameExpr +import com.github.javaparser.ast.expr.StringLiteralExpr +import org.jacodb.api.jvm.JcClassOrInterface +import org.jacodb.api.jvm.JcClassType +import org.jacodb.api.jvm.JcField +import org.jacodb.api.jvm.JcMethod +import org.jacodb.api.jvm.ext.findType +import org.objectweb.asm.Opcodes +import org.usvm.jvm.rendering.ReflectionUtilsInlineStrategy +import org.usvm.jvm.rendering.baseRenderer.JcImportManager +import org.usvm.jvm.rendering.unsafeRenderer.JcUnsafeTestBlockRenderer +import org.usvm.jvm.rendering.unsafeRenderer.JcUnsafeUtilsRenderer + +class JcSpringReflectionUtilsRenderer( + importManager: JcImportManager, + utilsInlineStrategy: ReflectionUtilsInlineStrategy, + private val isAccessibleFromTestClass: (JcClassOrInterface) -> Boolean, +) : JcUnsafeUtilsRenderer(importManager, utilsInlineStrategy) { + + companion object { + private const val SPRING_TEST = "org.springframework.test.util.ReflectionTestUtils" + private const val SPRING_TEST_SIMPLE = "ReflectionTestUtils" + private const val SPRING_INTERNAL = "org.springframework.util.ReflectionUtils" + private const val SPRING_INTERNAL_SIMPLE = "ReflectionUtils" + } + + private val springTestUtilsName: Expression by lazy { + NameExpr(if (importManager.add(SPRING_TEST)) SPRING_TEST_SIMPLE else SPRING_TEST) + } + + private val springInternalUtilsName: Expression by lazy { + NameExpr(if (importManager.add(SPRING_INTERNAL)) SPRING_INTERNAL_SIMPLE else SPRING_INTERNAL) + } + + override fun renderCtorCall( + blockRenderer: JcUnsafeTestBlockRenderer, + ctor: JcMethod, + type: JcClassType, + args: List, + inlinesVarargs: Boolean + ): Expression { + return if (isAccessibleFromTestClass(type.jcClass)) + springCtorCall(blockRenderer, ctor, type, args, inlinesVarargs) + else + super.renderCtorCall(blockRenderer, ctor, type, args, inlinesVarargs) + } + + @Suppress("unused") + private fun springCtorCall( + blockRenderer: JcUnsafeTestBlockRenderer, + ctor: JcMethod, + type: JcClassType, + args: List, + inlinesVarargs: Boolean + ): Expression { + blockRenderer.addThrownExceptions( + listOf( + "java.lang.reflect.InvocationTargetException", + "java.lang.NoSuchMethodException", + "java.lang.InstantiationException", + "java.lang.IllegalAccessException" + ) + ) + val cp = ctor.enclosingClass.classpath + val ctorParametersTypes = ctor.parameters.map { cp.findType(it.type.typeName) } + val instanceType = blockRenderer.renderClass(type, includeGenericArgs = false) + val accessibleCtorArgs = listOf(ClassExpr(instanceType)) + ctorParametersTypes.map { + ClassExpr(blockRenderer.renderType(it, false)) + } + + val accessibleCtor = MethodCallExpr( + springInternalUtilsName, + NodeList(instanceType), + "accessibleConstructor", + NodeList(accessibleCtorArgs), + ) + val newInstCall = MethodCallExpr(accessibleCtor, "newInstance", NodeList(args)) + return newInstCall + } + + override fun renderInstanceMethodCall( + blockRenderer: JcUnsafeTestBlockRenderer, + method: JcMethod, + instance: Expression, + args: List, + inlinesVarargs: Boolean + ): Expression { + return if (isAccessibleFromTestClass(method.enclosingClass)) + springInstanceMethodCall(blockRenderer, method, instance, args, inlinesVarargs) + else + super.renderInstanceMethodCall(blockRenderer, method, instance, args, inlinesVarargs) + } + + @Suppress("unused") + private fun springInstanceMethodCall( + blockRenderer: JcUnsafeTestBlockRenderer, + method: JcMethod, + instance: Expression, + args: List, + inlinesVarargs: Boolean + ): Expression { + val allArgs = listOf(instance, StringLiteralExpr(method.name)) + args + return MethodCallExpr( + springTestUtilsName, + listTypeArgsFor(method, blockRenderer), + "invokeMethod", + NodeList(allArgs), + ) + } + + override fun renderStaticMethodCall( + blockRenderer: JcUnsafeTestBlockRenderer, + method: JcMethod, + args: List, + inlinesVarargs: Boolean + ): Expression { + return if (isAccessibleFromTestClass(method.enclosingClass)) + springStaticMethodCall(blockRenderer, method, args, inlinesVarargs) + else + super.renderStaticMethodCall(blockRenderer, method, args, inlinesVarargs) + } + + @Suppress("unused") + private fun springStaticMethodCall( + blockRenderer: JcUnsafeTestBlockRenderer, + method: JcMethod, + args: List, + inlinesVarargs: Boolean + ): Expression { + blockRenderer.addThrownException("java.lang.Throwable") + val enclosingClass = method.enclosingClass + val invokeMethodArgs = listOf( + blockRenderer.renderClassExpression(enclosingClass), + StringLiteralExpr(method.name) + ) + args + + return MethodCallExpr( + springTestUtilsName, + listTypeArgsFor(method, blockRenderer), + "invokeMethod", + NodeList(invokeMethodArgs), + ) + } + + override fun renderGetInstanceField( + blockRenderer: JcUnsafeTestBlockRenderer, + instance: Expression, + field: JcField + ): Expression { + return if (isAccessibleFromTestClass(field.enclosingClass)) + springGetInstanceField(blockRenderer, instance, field) + else + super.renderGetInstanceField(blockRenderer, instance, field) + } + + private fun springGetInstanceField( + blockRenderer: JcUnsafeTestBlockRenderer, + instance: Expression, + field: JcField + ): Expression { + val call = MethodCallExpr( + springTestUtilsName, + "getField", + NodeList(instance, StringLiteralExpr(field.name)) + ) + + return CastExpr(blockRenderer.renderType(fieldType(field)), call) + } + + override fun renderGetStaticField(blockRenderer: JcUnsafeTestBlockRenderer, field: JcField): Expression { + return if (isAccessibleFromTestClass(field.enclosingClass)) + springGetStaticField(blockRenderer, field) + else + super.renderGetStaticField(blockRenderer, field) + } + + private fun springGetStaticField(blockRenderer: JcUnsafeTestBlockRenderer, field: JcField): Expression { + val call = MethodCallExpr( + springTestUtilsName, + "getField", + NodeList( + blockRenderer.renderClassExpression(field.enclosingClass), + StringLiteralExpr(field.name) + ), + ) + return CastExpr(blockRenderer.renderType(fieldType(field)), call) + } + + override fun renderSetInstanceField( + blockRenderer: JcUnsafeTestBlockRenderer, + instance: Expression, + field: JcField, + value: Expression + ): Expression { + return if (isAccessibleFromTestClass(field.enclosingClass) && !field.enclosingClass.isRecord) + springSetInstanceField(blockRenderer, instance, field, value) + else + super.renderSetInstanceField(blockRenderer, instance, field, value) + } + + private val JcClassOrInterface.isRecord: Boolean get() = (access and Opcodes.ACC_RECORD) != 0 + + private fun springSetInstanceField( + blockRenderer: JcUnsafeTestBlockRenderer, + instance: Expression, + field: JcField, + value: Expression + ): Expression { + return MethodCallExpr( + springTestUtilsName, + "setField", + NodeList(instance, StringLiteralExpr(field.name), value), + ) + } + + override fun renderSetStaticField( + blockRenderer: JcUnsafeTestBlockRenderer, + field: JcField, + value: Expression + ): Expression { + return if (isAccessibleFromTestClass(field.enclosingClass) && !field.isFinal) + springSetStaticField(blockRenderer, field, value) + else + super.renderSetStaticField(blockRenderer, field, value) + } + + private fun springSetStaticField( + blockRenderer: JcUnsafeTestBlockRenderer, + field: JcField, + value: Expression + ): Expression { + return MethodCallExpr( + springTestUtilsName, + "setField", + NodeList( + blockRenderer.renderClassExpression(field.enclosingClass), + StringLiteralExpr(field.name), + value + ), + ) + } +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/unitTestRenderer/JcSpringUnitTestBlockRenderer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/unitTestRenderer/JcSpringUnitTestBlockRenderer.kt new file mode 100644 index 0000000000..220adbab23 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/unitTestRenderer/JcSpringUnitTestBlockRenderer.kt @@ -0,0 +1,63 @@ +package org.usvm.jvm.rendering.spring.unitTestRenderer + +import com.github.javaparser.ast.expr.Expression +import com.github.javaparser.ast.type.ReferenceType +import java.util.IdentityHashMap +import org.jacodb.api.jvm.JcClasspath +import org.usvm.jvm.rendering.baseRenderer.JcIdentifiersManager +import org.usvm.jvm.rendering.baseRenderer.JcImportManager +import org.usvm.jvm.rendering.spring.JcSpringReflectionUtilsRenderer +import org.usvm.jvm.rendering.unsafeRenderer.JcUnsafeTestBlockRenderer +import org.usvm.test.api.UTestExpression + +open class JcSpringUnitTestBlockRenderer protected constructor( + override val methodRenderer: JcSpringUnitTestRenderer, + override val importManager: JcImportManager, + identifiersManager: JcIdentifiersManager, + cp: JcClasspath, + shouldDeclareVar: Set, + exprCache: IdentityHashMap, + thrownExceptions: HashSet, + unsafeUtilsRenderer: JcSpringReflectionUtilsRenderer +) : JcUnsafeTestBlockRenderer( + methodRenderer, + importManager, + identifiersManager, + cp, + shouldDeclareVar, + exprCache, + thrownExceptions, + unsafeUtilsRenderer +) { + + constructor( + methodRenderer: JcSpringUnitTestRenderer, + importManager: JcImportManager, + identifiersManager: JcIdentifiersManager, + cp: JcClasspath, + shouldDeclareVar: Set, + unsafeUtilsRenderer: JcSpringReflectionUtilsRenderer + ) : this( + methodRenderer, + importManager, + identifiersManager, + cp, + shouldDeclareVar, + IdentityHashMap(), + HashSet(), + unsafeUtilsRenderer + ) + + override fun newInnerBlock(): JcSpringUnitTestBlockRenderer { + return JcSpringUnitTestBlockRenderer( + methodRenderer, + importManager, + JcIdentifiersManager(identifiersManager), + cp, + shouldDeclareVar, + IdentityHashMap(exprCache), + thrownExceptions, + unsafeUtilsRenderer as JcSpringReflectionUtilsRenderer + ) + } +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/unitTestRenderer/JcSpringUnitTestClassRenderer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/unitTestRenderer/JcSpringUnitTestClassRenderer.kt new file mode 100644 index 0000000000..e54b7d23a0 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/unitTestRenderer/JcSpringUnitTestClassRenderer.kt @@ -0,0 +1,50 @@ +package org.usvm.jvm.rendering.spring.unitTestRenderer + +import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration +import com.github.javaparser.ast.expr.AnnotationExpr +import com.github.javaparser.ast.expr.SimpleName +import org.jacodb.api.jvm.JcClasspath +import org.usvm.jvm.rendering.baseRenderer.JcIdentifiersManager +import org.usvm.jvm.rendering.baseRenderer.JcImportManager +import org.usvm.jvm.rendering.spring.JcSpringReflectionUtilsRenderer +import org.usvm.jvm.rendering.testRenderer.JcTestRenderer +import org.usvm.jvm.rendering.unsafeRenderer.JcUnsafeTestClassRenderer +import org.usvm.test.api.UTest + +open class JcSpringUnitTestClassRenderer : JcUnsafeTestClassRenderer { + + constructor( + name: String, + importManager: JcImportManager, + identifiersManager: JcIdentifiersManager, + cp: JcClasspath, + reflectionUtilsRenderer: JcSpringReflectionUtilsRenderer + ) : super(name, importManager, identifiersManager, cp, reflectionUtilsRenderer) + + constructor( + decl: ClassOrInterfaceDeclaration, + importManager: JcImportManager, + identifiersManager: JcIdentifiersManager, + cp: JcClasspath, + reflectionUtilsRenderer: JcSpringReflectionUtilsRenderer + ) : super(decl, importManager, identifiersManager, cp, reflectionUtilsRenderer) + + + override fun createTestRenderer( + test: UTest, + identifiersManager: JcIdentifiersManager, + name: SimpleName, + annotations: List, + ): JcTestRenderer { + return JcSpringUnitTestRenderer( + test, + this, + importManager, + JcIdentifiersManager(identifiersManager), + cp, + name, + annotations, + unsafeUtilsRenderer as JcSpringReflectionUtilsRenderer + ) + } +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/unitTestRenderer/JcSpringUnitTestFileRenderer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/unitTestRenderer/JcSpringUnitTestFileRenderer.kt new file mode 100644 index 0000000000..2a6ada91d9 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/unitTestRenderer/JcSpringUnitTestFileRenderer.kt @@ -0,0 +1,73 @@ +package org.usvm.jvm.rendering.spring.unitTestRenderer + +import com.github.javaparser.ast.CompilationUnit +import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration +import org.jacodb.api.jvm.JcClassOrInterface +import org.jacodb.api.jvm.JcClasspath +import org.usvm.jvm.rendering.ReflectionUtilsInlineStrategy +import org.usvm.jvm.rendering.baseRenderer.JcImportManager +import org.usvm.jvm.rendering.spring.JcSpringReflectionUtilsRenderer +import org.usvm.jvm.rendering.unsafeRenderer.JcUnsafeTestFileRenderer + +open class JcSpringUnitTestFileRenderer: JcUnsafeTestFileRenderer { + + private val isAccessibleFromTestClass: (JcClassOrInterface) -> Boolean + + override val unsafeUtilsRenderer: JcSpringReflectionUtilsRenderer by lazy { + JcSpringReflectionUtilsRenderer(importManager, reflectionUtilsInlineStrategy, isAccessibleFromTestClass) + } + + protected constructor( + cu: CompilationUnit, + importManager: JcImportManager, + cp: JcClasspath, + reflectionUtilsInlineStrategy: ReflectionUtilsInlineStrategy, + isAccessibleFromTestClass: (JcClassOrInterface) -> Boolean + ) : super(cu, importManager, cp, reflectionUtilsInlineStrategy) { + this.isAccessibleFromTestClass = isAccessibleFromTestClass + } + + protected constructor( + packageName: String?, + importManager: JcImportManager, + cp: JcClasspath, + reflectionUtilsInlineStrategy: ReflectionUtilsInlineStrategy, + isAccessibleFromTestClass: (JcClassOrInterface) -> Boolean + ) : super(packageName, importManager, cp, reflectionUtilsInlineStrategy) { + this.isAccessibleFromTestClass = isAccessibleFromTestClass + } + + constructor( + cu: CompilationUnit, + cp: JcClasspath, + reflectionUtilsInlineStrategy: ReflectionUtilsInlineStrategy, + isAccessibleFromTestClass: (JcClassOrInterface) -> Boolean + ) : this( + cu, + JcImportManager(cu), + cp, + reflectionUtilsInlineStrategy, + isAccessibleFromTestClass + ) + + constructor( + packageName: String?, + cp: JcClasspath, + reflectionUtilsInlineStrategy: ReflectionUtilsInlineStrategy, + isAccessibleFromTestClass: (JcClassOrInterface) -> Boolean + ) : this( + packageName, + JcImportManager(null), + cp, + reflectionUtilsInlineStrategy, + isAccessibleFromTestClass + ) + + override fun classRendererFor(declaration: ClassOrInterfaceDeclaration): JcSpringUnitTestClassRenderer { + return JcSpringUnitTestClassRenderer(declaration, importManager, identifiersManager, cp, unsafeUtilsRenderer) + } + + override fun classRendererFor(name: String): JcSpringUnitTestClassRenderer { + return JcSpringUnitTestClassRenderer(name, importManager, identifiersManager, cp, unsafeUtilsRenderer) + } +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/unitTestRenderer/JcSpringUnitTestInfo.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/unitTestRenderer/JcSpringUnitTestInfo.kt new file mode 100644 index 0000000000..5c9572a741 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/unitTestRenderer/JcSpringUnitTestInfo.kt @@ -0,0 +1,14 @@ +package org.usvm.jvm.rendering.spring.unitTestRenderer + +import java.nio.file.Path +import org.jacodb.api.jvm.JcMethod +import org.usvm.jvm.rendering.unsafeRenderer.JcUnsafeTestInfo + +open class JcSpringUnitTestInfo( + method: JcMethod, + isExceptional: Boolean, + testFilePath: Path? = null, + testPackageName: String? = null, + testClassName: String? = null, + testName: String? = null +) : JcUnsafeTestInfo(method, isExceptional, testFilePath, testPackageName, testClassName, testName) \ No newline at end of file diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/unitTestRenderer/JcSpringUnitTestRenderer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/unitTestRenderer/JcSpringUnitTestRenderer.kt new file mode 100644 index 0000000000..d404cb4b3a --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/unitTestRenderer/JcSpringUnitTestRenderer.kt @@ -0,0 +1,40 @@ +package org.usvm.jvm.rendering.spring.unitTestRenderer + +import com.github.javaparser.ast.expr.AnnotationExpr +import com.github.javaparser.ast.expr.SimpleName +import org.jacodb.api.jvm.JcClasspath +import org.usvm.jvm.rendering.baseRenderer.JcIdentifiersManager +import org.usvm.jvm.rendering.baseRenderer.JcImportManager +import org.usvm.jvm.rendering.spring.JcSpringReflectionUtilsRenderer +import org.usvm.jvm.rendering.unsafeRenderer.JcUnsafeTestRenderer +import org.usvm.test.api.UTest + +open class JcSpringUnitTestRenderer( + test: UTest, + classRenderer: JcSpringUnitTestClassRenderer, + importManager: JcImportManager, + identifiersManager: JcIdentifiersManager, + cp: JcClasspath, + name: SimpleName, + annotations: List, + unsafeUtilsRenderer: JcSpringReflectionUtilsRenderer +): JcUnsafeTestRenderer( + test, + classRenderer, + importManager, + identifiersManager, + cp, + name, + annotations, + unsafeUtilsRenderer +) { + + override val body: JcSpringUnitTestBlockRenderer = JcSpringUnitTestBlockRenderer( + this, + importManager, + JcIdentifiersManager(identifiersManager), + cp, + shouldDeclareVar, + unsafeUtilsRenderer + ) +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/webMvcTestRenderer/JcSpringMvcTestBlockRenderer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/webMvcTestRenderer/JcSpringMvcTestBlockRenderer.kt new file mode 100644 index 0000000000..75ef2f7ee5 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/webMvcTestRenderer/JcSpringMvcTestBlockRenderer.kt @@ -0,0 +1,82 @@ +package org.usvm.jvm.rendering.spring.webMvcTestRenderer + +import com.github.javaparser.ast.NodeList +import com.github.javaparser.ast.expr.Expression +import com.github.javaparser.ast.expr.FieldAccessExpr +import com.github.javaparser.ast.expr.ThisExpr +import com.github.javaparser.ast.type.ReferenceType +import org.jacodb.api.jvm.JcClassOrInterface +import org.jacodb.api.jvm.JcClasspath +import org.usvm.jvm.rendering.baseRenderer.JcIdentifiersManager +import org.usvm.jvm.rendering.spring.unitTestRenderer.JcSpringUnitTestBlockRenderer +import org.usvm.test.api.UTestExpression +import org.usvm.test.api.UTestGetFieldExpression +import java.util.IdentityHashMap +import org.usvm.jvm.rendering.baseRenderer.JcImportManager +import org.usvm.jvm.rendering.spring.JcSpringReflectionUtilsRenderer + +open class JcSpringMvcTestBlockRenderer protected constructor( + override val methodRenderer: JcSpringMvcTestRenderer, + importManager: JcImportManager, + identifiersManager: JcIdentifiersManager, + cp: JcClasspath, + shouldDeclareVar: Set, + exprCache: IdentityHashMap, + thrownExceptions: HashSet, + private val mvcTestClass: JcClassOrInterface, + reflectionUtilsRenderer: JcSpringReflectionUtilsRenderer +) : JcSpringUnitTestBlockRenderer( + methodRenderer, + importManager, + identifiersManager, + cp, + shouldDeclareVar, + exprCache, + thrownExceptions, + reflectionUtilsRenderer +) { + + constructor( + methodRenderer: JcSpringMvcTestRenderer, + importManager: JcImportManager, + identifiersManager: JcIdentifiersManager, + cp: JcClasspath, + shouldDeclareVar: Set, + mvcTestClass: JcClassOrInterface, + reflectionUtilsRenderer: JcSpringReflectionUtilsRenderer + ) : this( + methodRenderer, + importManager, + identifiersManager, + cp, + shouldDeclareVar, + IdentityHashMap(), + HashSet(), + mvcTestClass, + reflectionUtilsRenderer + ) + + override fun newInnerBlock(): JcSpringMvcTestBlockRenderer { + return JcSpringMvcTestBlockRenderer( + methodRenderer, + importManager, + JcIdentifiersManager(identifiersManager), + cp, + shouldDeclareVar, + IdentityHashMap(exprCache), + thrownExceptions, + mvcTestClass, + unsafeUtilsRenderer as JcSpringReflectionUtilsRenderer + ) + } + + override fun renderGetFieldExpression(expr: UTestGetFieldExpression): Expression { + val field = expr.field + if (expr.field.enclosingClass == mvcTestClass) { + val testClassField = classRenderer.getOrCreateField(field) + return FieldAccessExpr(ThisExpr(), NodeList(), testClassField) + } + + return super.renderGetFieldExpression(expr) + } +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/webMvcTestRenderer/JcSpringMvcTestClassRenderer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/webMvcTestRenderer/JcSpringMvcTestClassRenderer.kt new file mode 100644 index 0000000000..7a33ecbd57 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/webMvcTestRenderer/JcSpringMvcTestClassRenderer.kt @@ -0,0 +1,96 @@ +package org.usvm.jvm.rendering.spring.webMvcTestRenderer + +import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration +import com.github.javaparser.ast.expr.AnnotationExpr +import com.github.javaparser.ast.expr.SimpleName +import org.jacodb.api.jvm.JcAnnotation +import org.jacodb.api.jvm.JcClassOrInterface +import org.jacodb.api.jvm.JcClassType +import org.jacodb.api.jvm.JcClasspath +import org.jacodb.api.jvm.JcMethod +import org.jacodb.api.jvm.PredefinedPrimitives +import org.usvm.jvm.rendering.baseRenderer.JcIdentifiersManager +import org.usvm.jvm.rendering.baseRenderer.JcImportManager +import org.usvm.jvm.rendering.spring.JcSpringReflectionUtilsRenderer +import org.usvm.jvm.rendering.testRenderer.JcTestRenderer +import org.usvm.jvm.rendering.spring.unitTestRenderer.JcSpringUnitTestClassRenderer +import org.usvm.jvm.rendering.testTransformers.JcSpringMvcTestTransformer +import org.usvm.test.api.UTest + +class JcSpringMvcTestClassRenderer : JcSpringUnitTestClassRenderer { + + private val controller: JcClassType + + private lateinit var testClass: JcClassOrInterface + + constructor( + controller: JcClassType, + name: String, + importManager: JcImportManager, + identifiersManager: JcIdentifiersManager, + cp: JcClasspath, + reflectionUtilsRenderer: JcSpringReflectionUtilsRenderer + ) : super(name, importManager, identifiersManager, cp, reflectionUtilsRenderer) { + this.controller = controller + } + + constructor( + controller: JcClassType, + decl: ClassOrInterfaceDeclaration, + importManager: JcImportManager, + identifiersManager: JcIdentifiersManager, + cp: JcClasspath, + reflectionUtilsRenderer: JcSpringReflectionUtilsRenderer + ) : super(decl, importManager, identifiersManager, cp, reflectionUtilsRenderer) { + this.controller = controller + } + + override fun createTestRenderer( + test: UTest, + identifiersManager: JcIdentifiersManager, + name: SimpleName, + annotations: List, + ): JcTestRenderer { + val mvcTransformer = JcSpringMvcTestTransformer() + val transformedTest = mvcTransformer.transform(test) + + if (!this::testClass.isInitialized) { + testClass = mvcTransformer.testClass + } else { + check(testClass == mvcTransformer.testClass) { + "only one test class expected for class renderer" + } + } + + val testMethodAnnotations = testMethodAnnotationsFrom(testClass).map { renderAnnotation(it) } + + return JcSpringMvcTestRenderer( + transformedTest, + this, + importManager, + JcIdentifiersManager(identifiersManager), + cp, + name, + annotations + testMethodAnnotations, + mvcTransformer.testClass, + unsafeUtilsRenderer as JcSpringReflectionUtilsRenderer + ) + } + + override fun renderInternal(): ClassOrInterfaceDeclaration { + check(this::testClass.isInitialized) { + "test class expected in class renderer" + } + + testClass.annotations.forEach { annotation -> addAnnotation(annotation) } + + return super.renderInternal() + } + + private fun testMethodAnnotationsFrom(stubClass: JcClassOrInterface): List { + return stubClass.declaredMethods.single { it.isFakeTest }.annotations + } + + private val JcMethod.isFakeTest: Boolean + get() = name == "fakeTest" && returnType.typeName == PredefinedPrimitives.Void && parameters.isEmpty() +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/webMvcTestRenderer/JcSpringMvcTestFileRenderer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/webMvcTestRenderer/JcSpringMvcTestFileRenderer.kt new file mode 100644 index 0000000000..9ec8f04f93 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/webMvcTestRenderer/JcSpringMvcTestFileRenderer.kt @@ -0,0 +1,74 @@ +package org.usvm.jvm.rendering.spring.webMvcTestRenderer + +import com.github.javaparser.ast.CompilationUnit +import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration +import org.jacodb.api.jvm.JcClassOrInterface +import org.jacodb.api.jvm.JcClassType +import org.jacodb.api.jvm.JcClasspath +import org.usvm.jvm.rendering.ReflectionUtilsInlineStrategy +import org.usvm.jvm.rendering.baseRenderer.JcImportManager +import org.usvm.jvm.rendering.spring.unitTestRenderer.JcSpringUnitTestFileRenderer + +class JcSpringMvcTestFileRenderer : JcSpringUnitTestFileRenderer { + private constructor( + controller: JcClassType, + cu: CompilationUnit, + importManager: JcImportManager, + cp: JcClasspath, + reflectionUtilsInlineStrategy: ReflectionUtilsInlineStrategy, + isAccessibleFromTestClass: (JcClassOrInterface) -> Boolean + ) : super(cu, importManager, cp, reflectionUtilsInlineStrategy, isAccessibleFromTestClass) { + this.controller = controller + } + + private constructor( + controller: JcClassType, + packageName: String?, + importManager: JcImportManager, + cp: JcClasspath, + reflectionUtilsInlineStrategy: ReflectionUtilsInlineStrategy, + isAccessibleFromTestClass: (JcClassOrInterface) -> Boolean + ) : super(packageName, importManager, cp, reflectionUtilsInlineStrategy, isAccessibleFromTestClass) { + this.controller = controller + } + + constructor( + controller: JcClassType, + cu: CompilationUnit, + cp: JcClasspath, + reflectionUtilsInlineStrategy: ReflectionUtilsInlineStrategy, + isAccessibleFromTestClass: (JcClassOrInterface) -> Boolean + ) : this( + controller, + cu, + JcImportManager(cu), + cp, + reflectionUtilsInlineStrategy, + isAccessibleFromTestClass + ) + + constructor( + controller: JcClassType, + packageName: String?, + cp: JcClasspath, + reflectionUtilsInlineStrategy: ReflectionUtilsInlineStrategy, + isAccessibleFromTestClass: (JcClassOrInterface) -> Boolean + ) : this( + controller, + packageName, + JcImportManager(null), + cp, + reflectionUtilsInlineStrategy, + isAccessibleFromTestClass + ) + + private val controller: JcClassType + + override fun classRendererFor(declaration: ClassOrInterfaceDeclaration): JcSpringMvcTestClassRenderer { + return JcSpringMvcTestClassRenderer(controller, declaration, importManager, identifiersManager, cp, unsafeUtilsRenderer) + } + + override fun classRendererFor(name: String): JcSpringMvcTestClassRenderer { + return JcSpringMvcTestClassRenderer(controller, name, importManager, identifiersManager, cp, unsafeUtilsRenderer) + } +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/webMvcTestRenderer/JcSpringMvcTestInfo.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/webMvcTestRenderer/JcSpringMvcTestInfo.kt new file mode 100644 index 0000000000..796ce3b90c --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/webMvcTestRenderer/JcSpringMvcTestInfo.kt @@ -0,0 +1,27 @@ +package org.usvm.jvm.rendering.spring.webMvcTestRenderer + +import java.nio.file.Path +import org.jacodb.api.jvm.JcMethod +import org.jacodb.api.jvm.ext.toType +import org.usvm.jvm.rendering.spring.unitTestRenderer.JcSpringUnitTestInfo + +class JcSpringMvcTestInfo( + method: JcMethod, + isExceptional: Boolean, + testFilePath: Path? = null, + testPackageName: String? = null, + testClassName: String? = null, + testName: String? = null +) : JcSpringUnitTestInfo(method, isExceptional, testFilePath, testPackageName, testClassName, testName) { + + val controller by lazy { method.enclosingClass.toType() } + + override fun hashCode(): Int { + return method.enclosingClass.hashCode() + } + + override fun equals(other: Any?): Boolean { + if (other == null || other !is JcSpringMvcTestInfo) return false + return controller == other.controller && testFilePath == other.testFilePath && testClassName == other.testClassName + } +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/webMvcTestRenderer/JcSpringMvcTestRenderer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/webMvcTestRenderer/JcSpringMvcTestRenderer.kt new file mode 100644 index 0000000000..296ae81231 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/spring/webMvcTestRenderer/JcSpringMvcTestRenderer.kt @@ -0,0 +1,43 @@ +package org.usvm.jvm.rendering.spring.webMvcTestRenderer + +import com.github.javaparser.ast.expr.AnnotationExpr +import com.github.javaparser.ast.expr.SimpleName +import org.jacodb.api.jvm.JcClassOrInterface +import org.jacodb.api.jvm.JcClasspath +import org.usvm.jvm.rendering.baseRenderer.JcIdentifiersManager +import org.usvm.jvm.rendering.baseRenderer.JcImportManager +import org.usvm.jvm.rendering.spring.JcSpringReflectionUtilsRenderer +import org.usvm.jvm.rendering.spring.unitTestRenderer.JcSpringUnitTestRenderer +import org.usvm.test.api.UTest + +open class JcSpringMvcTestRenderer( + test: UTest, + override val classRenderer: JcSpringMvcTestClassRenderer, + importManager: JcImportManager, + identifiersManager: JcIdentifiersManager, + cp: JcClasspath, + name: SimpleName, + annotations: List, + mvcTestClass: JcClassOrInterface, + reflectionUtilsRenderer: JcSpringReflectionUtilsRenderer +): JcSpringUnitTestRenderer( + test, + classRenderer, + importManager, + identifiersManager, + cp, + name, + annotations, + reflectionUtilsRenderer +) { + + override val body: JcSpringMvcTestBlockRenderer = JcSpringMvcTestBlockRenderer( + this, + importManager, + JcIdentifiersManager(identifiersManager), + cp, + shouldDeclareVar, + mvcTestClass, + reflectionUtilsRenderer + ) +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testRenderer/JcTestBlockRenderer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testRenderer/JcTestBlockRenderer.kt new file mode 100644 index 0000000000..85ced33713 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testRenderer/JcTestBlockRenderer.kt @@ -0,0 +1,776 @@ +package org.usvm.jvm.rendering.testRenderer + +import com.github.javaparser.ast.ArrayCreationLevel +import com.github.javaparser.ast.NodeList +import com.github.javaparser.ast.body.Parameter +import com.github.javaparser.ast.expr.ArrayAccessExpr +import com.github.javaparser.ast.expr.ArrayCreationExpr +import com.github.javaparser.ast.expr.BinaryExpr +import com.github.javaparser.ast.expr.CastExpr +import com.github.javaparser.ast.expr.ClassExpr +import com.github.javaparser.ast.expr.Expression +import com.github.javaparser.ast.expr.FieldAccessExpr +import com.github.javaparser.ast.expr.IntegerLiteralExpr +import com.github.javaparser.ast.expr.LambdaExpr +import com.github.javaparser.ast.expr.MethodCallExpr +import com.github.javaparser.ast.expr.NameExpr +import com.github.javaparser.ast.expr.NullLiteralExpr +import com.github.javaparser.ast.expr.StringLiteralExpr +import com.github.javaparser.ast.expr.TypeExpr +import com.github.javaparser.ast.stmt.ReturnStmt +import com.github.javaparser.ast.type.ReferenceType +import org.jacodb.api.jvm.JcClassType +import org.jacodb.api.jvm.PredefinedPrimitives +import org.usvm.jvm.rendering.baseRenderer.JcBlockRenderer +import org.usvm.jvm.rendering.baseRenderer.JcImportManager +import org.usvm.jvm.rendering.baseRenderer.JcIdentifiersManager +import java.util.IdentityHashMap +import kotlin.collections.filter +import org.jacodb.api.jvm.JcClassOrInterface +import org.jacodb.api.jvm.JcClasspath +import org.jacodb.api.jvm.JcField +import org.jacodb.api.jvm.JcMethod +import org.jacodb.api.jvm.JcTypedMethod +import org.jacodb.api.jvm.ext.findType +import org.jacodb.api.jvm.ext.isAssignable +import org.jacodb.api.jvm.ext.jcdbSignature +import org.jacodb.impl.features.classpaths.virtual.JcVirtualMethod +import org.usvm.jvm.rendering.isVararg +import org.usvm.jvm.util.toTypedMethod +import org.usvm.test.api.* +import partitionByKey + +open class JcTestBlockRenderer protected constructor( + override val methodRenderer: JcTestRenderer, + importManager: JcImportManager, + identifiersManager: JcIdentifiersManager, + cp: JcClasspath, + protected val shouldDeclareVar: Set, + protected val exprCache: IdentityHashMap, + thrownExceptions: HashSet +) : JcBlockRenderer(methodRenderer, importManager, identifiersManager, cp, thrownExceptions) { + + constructor( + methodRenderer: JcTestRenderer, + importManager: JcImportManager, + identifiersManager: JcIdentifiersManager, + cp: JcClasspath, + shouldDeclareVar: Set + ) : this(methodRenderer, importManager, identifiersManager, cp, shouldDeclareVar, IdentityHashMap(), HashSet()) + + override fun newInnerBlock(): JcTestBlockRenderer { + return JcTestBlockRenderer( + methodRenderer, + importManager, + JcIdentifiersManager(identifiersManager), + cp, + shouldDeclareVar, + IdentityHashMap(exprCache), + thrownExceptions + ) + } + + fun renderInst(inst: UTestInst) = when (inst) { + is UTestStatement -> renderStatement(inst) + is UTestExpression -> addExpression(renderExpression(inst)) + } + + protected fun renderStatement(stmt: UTestStatement) { + return when (stmt) { + is UTestArraySetStatement -> renderArraySetStatement(stmt) + is UTestBinaryConditionStatement -> renderBinaryConditionStatement(stmt) + is UTestSetFieldStatement -> renderSetFieldStatement(stmt) + is UTestSetStaticFieldStatement -> renderSetStaticFieldStatement(stmt) + } + } + + protected fun renderExpression(expr: UTestExpression): Expression = + exprCache.getOrPut(expr) { + val rendered = doRenderExpression(expr) + if (shouldDeclareVar.contains(expr)) { + check(expr.type?.typeName != PredefinedPrimitives.Void) { + "void cannot be rendered as var" + } + renderVarDeclaration(expr.type!!, rendered) + } + else { + rendered + } + } + + private fun doRenderExpression(expr: UTestExpression): Expression { + return when (expr) { + is UTestArithmeticExpression -> renderArithmeticExpression(expr) + is UTestArrayGetExpression -> renderArrayGetExpression(expr) + is UTestArrayLengthExpression -> renderArrayLengthExpression(expr) + is UTestBinaryConditionExpression -> renderBinaryConditionExpression(expr) + is UTestAllocateMemoryCall -> renderAllocateMemoryCall(expr) + is UTestConstructorCall -> renderConstructorCall(expr) + is UTestMethodCall -> renderMethodCall(expr) + is UTestStaticMethodCall -> renderStaticMethodCall(expr) + is UTestAssertThrowsCall -> renderAssertThrowCall(expr) + is UTestAssertEqualsCall -> renderAssertEqualsCall(expr) + is UTestCastExpression -> renderCastExpression(expr) + is UTestClassExpression -> renderClassExpression(expr) + is UTestCreateArrayExpression -> renderCreateArrayExpression(expr) + is UTestGetFieldExpression -> renderGetFieldExpression(expr) + is UTestGetStaticFieldExpression -> renderGetStaticFieldExpression(expr) + is UTestGlobalMock -> renderGlobalMock(expr) + is UTestMockObject -> renderMockObject(expr) + is UTestConstExpression<*> -> renderConstExpression(expr) + is UTestMockInst -> renderMockInst(expr) + is UTestInstList -> error("UTestInstList should not be rendered") + } + } + + fun renderConstExpression(expr: UTestConstExpression<*>): Expression = when (expr) { + is UTestBooleanExpression -> renderBooleanExpression(expr) + is UTestByteExpression -> renderByteExpression(expr) + is UTestCharExpression -> renderCharExpression(expr) + is UTestDoubleExpression -> renderDoubleExpression(expr) + is UTestFloatExpression -> renderFloatExpression(expr) + is UTestIntExpression -> renderIntExpression(expr) + is UTestLongExpression -> renderLongExpression(expr) + is UTestNullExpression -> renderNullExpression(expr) + is UTestShortExpression -> renderShortExpression(expr) + is UTestStringExpression -> renderStringExpression(expr) + } + + open fun renderArraySetStatement(stmt: UTestArraySetStatement) { + val array = renderExpression(stmt.arrayInstance) + val index = renderExpression(stmt.index) + val value = renderExpression(stmt.setValueExpression) +// val e = renderArraySet(array, index, value) + renderArraySetStatement(array, index, value) +// return e + } + + open fun renderBinaryConditionStatement(stmt: UTestBinaryConditionStatement) { + val condition = renderBinaryCondition(stmt.conditionType, stmt.lhv, stmt.rhv) + renderIfStatement( + condition = condition, + initThenBody = { + it as JcTestBlockRenderer + for (thenStmt in stmt.trueBranch) { + it.renderInst(thenStmt) + } + }, + initElseBody = { + it as JcTestBlockRenderer + for (thenStmt in stmt.trueBranch) { + it.renderInst(thenStmt) + } + } + ) + } + + open fun renderSetFieldStatement(stmt: UTestSetFieldStatement) { + val instance = renderExpression(stmt.instance) + val field = stmt.field + val value = renderExpression(stmt.value) + renderSetFieldStatement(instance, field, value) +// return renderSetField(instance, field, value) + } + + open fun renderSetStaticFieldStatement(stmt: UTestSetStaticFieldStatement) { + val field = stmt.field + val value = renderExpression(stmt.value) + renderSetStaticFieldStatement(field, value) +// return renderSetStaticField(field, value) + } + + open fun renderArithmeticExpression(expr: UTestArithmeticExpression): Expression { + val operation = when (expr.operationType) { + ArithmeticOperationType.AND -> BinaryExpr.Operator.AND + ArithmeticOperationType.PLUS -> BinaryExpr.Operator.PLUS + ArithmeticOperationType.SUB -> BinaryExpr.Operator.MINUS + ArithmeticOperationType.MUL -> BinaryExpr.Operator.MULTIPLY + ArithmeticOperationType.DIV -> BinaryExpr.Operator.DIVIDE + ArithmeticOperationType.REM -> BinaryExpr.Operator.REMAINDER + ArithmeticOperationType.EQ -> BinaryExpr.Operator.EQUALS + ArithmeticOperationType.NEQ -> BinaryExpr.Operator.NOT_EQUALS + ArithmeticOperationType.GT -> BinaryExpr.Operator.GREATER + ArithmeticOperationType.GEQ -> BinaryExpr.Operator.GREATER_EQUALS + ArithmeticOperationType.LT -> BinaryExpr.Operator.LESS + ArithmeticOperationType.LEQ -> BinaryExpr.Operator.LESS_EQUALS + ArithmeticOperationType.OR -> BinaryExpr.Operator.OR + ArithmeticOperationType.XOR -> BinaryExpr.Operator.XOR + } + + return BinaryExpr( + renderExpression(expr.lhv), + renderExpression(expr.rhv), + operation + ) + } + + open fun renderArrayGetExpression(expr: UTestArrayGetExpression): Expression = + ArrayAccessExpr(renderExpression(expr.arrayInstance), renderExpression(expr.index)) + + open fun renderArrayLengthExpression(expr: UTestArrayLengthExpression): Expression = + FieldAccessExpr(renderExpression(expr.arrayInstance), "length") + + protected fun renderBinaryCondition( + conditionType: ConditionType, + lhv: UTestExpression, + rhv: UTestExpression + ): Expression { + + val operation = when (conditionType) { + ConditionType.EQ -> BinaryExpr.Operator.EQUALS + ConditionType.NEQ -> BinaryExpr.Operator.NOT_EQUALS + ConditionType.GEQ -> BinaryExpr.Operator.GREATER_EQUALS + ConditionType.GT -> BinaryExpr.Operator.GREATER + } + + return BinaryExpr( + renderExpression(lhv), + renderExpression(rhv), + operation + ) + } + + open fun renderBinaryConditionExpression(expr: UTestBinaryConditionExpression): Expression { + val varExpr = renderVarDeclaration(expr.type!!) + val condition = renderBinaryCondition(expr.conditionType, expr.lhv, expr.rhv) + renderIfStatement( + condition = condition, + initThenBody = { + it.addExpression(renderAssign(varExpr, renderExpression(expr.trueBranch))) + }, + initElseBody = { + it.addExpression(renderAssign(varExpr, renderExpression(expr.elseBranch))) + } + ) + + return varExpr + } + + open fun renderAllocateMemoryCall(expr: UTestAllocateMemoryCall): Expression = + error("Unsafe is not supported") + + open fun renderConstructorCall(expr: UTestConstructorCall): Expression { + val type = expr.type as JcClassType + val (args, inlinesVararg) = renderCallArgs(expr.method, expr.args) + return renderConstructorCall(expr.method, type, args, inlinesVararg) + } + + open fun renderMockInst(expr: UTestMockInst): Expression { + val (args, inlinesVararg) = renderCallArgs(expr.method, expr.args) + val classExpr = ClassExpr(renderClass(expr.method.enclosingClass, false)) + val call = renderMethodCall( + expr.method, + classExpr, + args, + inlinesVararg + ) + val instance = renderExpression(expr.instance) + val whenPart = MethodCallExpr(TypeExpr(mockitoClass), "when", NodeList(call)) + val fullExpr = MethodCallExpr(whenPart, "thenReturn", NodeList(instance)) + return fullExpr +// return renderAssign(call, instance) + } + + open fun renderMethodCall(expr: UTestMethodCall): Expression { + val (args, inlinesVararg) = renderCallArgs(expr.method, expr.args) + return renderMethodCall( + expr.method, + renderExpression(expr.instance), + args, + inlinesVararg + ) + } + + open fun renderStaticMethodCall(expr: UTestStaticMethodCall): Expression { + val (args, inlinesVararg) = renderCallArgs(expr.method, expr.args) + return renderStaticMethodCall(expr.method, args, inlinesVararg) + } + + open fun renderLambdaExpression(params: List, body: List): Expression { + val lambdaBodyRenderer = newInnerBlock() + + val renderedParams = params.map { + lambdaBodyRenderer.renderMethodParameter(it.type ?: error("untyped lambda parameter")) + } + body.forEach { inst -> lambdaBodyRenderer.renderInst(inst) } + + val lambdaBody = lambdaBodyRenderer.render() + return LambdaExpr(NodeList(renderedParams), lambdaBody) + } + + open fun renderAssertThrowCall(expr: UTestAssertThrowsCall): Expression { + val exceptionType = renderClassExpression(expr.exceptionClass) + val observedLambda = renderLambdaExpression(listOf(), expr.instList) + return assertThrowsCall(exceptionType, observedLambda) + } + + open fun renderAssertEqualsCall(expr: UTestAssertEqualsCall): Expression { + val lhs = renderExpression(expr.expected) + val rhs = renderExpression(expr.actual) + + return assertEqualsCall(lhs, rhs) + } + + private fun renderCallArgs(method: JcMethod, args: List): Pair, Boolean> { + val typedParams = method.toTypedMethod.parameters + + check(args.size == typedParams.size) { + "args size != params size in call ${method.name} with ${args.joinToString(" ")}" + } + + val filteredArgs = args.toMutableList() + var inlinesVarargs = false + + if (method.isVararg) { + val vararg = args.last() + check(vararg is UTestCreateArrayExpression) { + "vararg arg expected to be array" + } + + val varargSize = vararg.size + check(varargSize is UTestIntExpression) { + "vararg size expected to be int" + } + + if (varargSize.value == 0) { + filteredArgs.removeLast() + inlinesVarargs = true + } + } + + return filteredArgs.map { arg -> renderExpression(arg) } to inlinesVarargs + } + + open fun renderCastExpression(expr: UTestCastExpression): Expression = CastExpr( + renderType(expr.type), + renderExpression(expr.expr) + ) + + open fun renderClassExpression(expr: UTestClassExpression): Expression = + renderClassExpression(expr.type) + + open fun renderBooleanExpression(expr: UTestBooleanExpression): Expression = renderBooleanPrimitive(expr.value) + + open fun renderByteExpression(expr: UTestByteExpression): Expression = renderBytePrimitive(expr.value) + + open fun renderCharExpression(expr: UTestCharExpression): Expression = renderCharPrimitive(expr.value) + + open fun renderDoubleExpression(expr: UTestDoubleExpression): Expression = renderDoublePrimitive(expr.value) + + open fun renderFloatExpression(expr: UTestFloatExpression): Expression = renderFloatPrimitive(expr.value) + + open fun renderIntExpression(expr: UTestIntExpression): Expression = renderIntPrimitive(expr.value) + + open fun renderLongExpression(expr: UTestLongExpression): Expression = renderLongPrimitive(expr.value) + + open fun renderNullExpression(expr: UTestNullExpression): Expression = NullLiteralExpr() + + open fun renderShortExpression(expr: UTestShortExpression): Expression = renderShortPrimitive(expr.value) + + open fun renderStringExpression(expr: UTestStringExpression): Expression { + val literal = StringLiteralExpr() + return literal.setString(expr.value) + } + + open fun renderCreateArrayExpression(expr: UTestCreateArrayExpression): Expression = + ArrayCreationExpr( + renderType(expr.elementType, false), + NodeList(ArrayCreationLevel(renderExpression(expr.size))), + null + ) + + open fun renderGetFieldExpression(expr: UTestGetFieldExpression): Expression { + return renderGetField(renderExpression(expr.instance), expr.field) + } + + open fun renderGetStaticFieldExpression(expr: UTestGetStaticFieldExpression): Expression { + return renderGetStaticField(expr.field) + } + + open fun renderGlobalMock(expr: UTestGlobalMock): Expression = TODO("global mocks not yet supported") + + private fun fetchDoAnswerFields( + fields: Map + ): Pair>, Map> { + val doAnswerFields = mutableMapOf>() + val commonFields = mutableMapOf() + + for ((field, value) in fields) { + val doAnswerArgDescriptor = DoAnswerArgDescriptor.fromMockFieldOrNull(field, value) + if (doAnswerArgDescriptor == null) { + commonFields.put(field, value) + } else { + doAnswerFields.getOrPut(doAnswerArgDescriptor.signature) { mutableListOf() }.add(doAnswerArgDescriptor) + } + } + + return doAnswerFields to commonFields + } + + private fun fetchDoAnswerMethods( + doAnswerFields: Map>, + methods: Map> + ): Pair>, Map>> { + val doAnswerMethods = mutableMapOf>() + val commonMethods = mutableMapOf>() + + for ((method, insts) in methods) { + val descriptor = DoAnswerInvocationDescriptor.fromMockMethodOrNull(methods, doAnswerFields, method, insts) + + if (descriptor != null) { + doAnswerMethods.getOrPut(descriptor.method) { mutableListOf() }.add(descriptor) + } else if (method.jcdbSignature !in doAnswerFields) { + commonMethods.put(method, insts) + } + } + + return doAnswerMethods to commonMethods + } + + open fun renderMockObject(expr: UTestMockObject): Expression { + val type = expr.type as JcClassType + val (spyFields, otherFields) = expr.fields.partitionByKey { it.isSpy } + + check(spyFields.size <= 1) { + "multiple spy fields found" + } + + val instanceUnderSpy = spyFields.entries.singleOrNull()?.value + + if (otherFields.isEmpty() && expr.methods.isEmpty()) { + return renderInstanceMockCreationExpressions(type, instanceUnderSpy) + } + + val (doAnswerFields, instanceFields) = fetchDoAnswerFields(otherFields) + + val (doAnswerMethods, otherMethods) = fetchDoAnswerMethods(doAnswerFields, expr.methods) + val (staticMethods, instanceMethods) = otherMethods.partitionByKey { it.isStatic } + + val mockExpr = renderInstanceMockCreationExpressions(type, instanceUnderSpy) + val mockVarNamePrefix = if (instanceUnderSpy != null) "spy" else "mocked" + + val shouldCreateInstanceMock = instanceFields.isNotEmpty() || + instanceMethods.isNotEmpty() || + doAnswerMethods.keys.any { method -> method !is JcVirtualMethod && !method.isStatic } + val mockVar: NameExpr? = + if (shouldCreateInstanceMock) + renderVarDeclaration(type, mockExpr, mockVarNamePrefix) + else + null + + if (mockVar != null) + exprCache[expr] = mockVar + + val staticMock = renderMockStaticInitializer( + type, + staticMethods, + doAnswerMethods.keys.any { it !is JcVirtualMethod && it.isStatic } + ) + + if (mockVar != null) { + renderMockInstanceFields(mockVar, instanceFields) + renderMockObjectMethods(mockVar, instanceMethods) + } + + renderMockObjectDoAnswerMethods(mockVar, staticMock, doAnswerMethods) + + return (mockVar ?: staticMock)!! + } + + private fun renderMockStaticInitializer( + type: JcClassType, + staticMethods: Map>, + hasDoAnswerMethods: Boolean + ): NameExpr? { + if (staticMethods.isEmpty() && !hasDoAnswerMethods) return null + + val staticMock = renderMockedStaticVarDeclaration(type.jcClass) + renderMockObjectMethods(staticMock, staticMethods) + return staticMock + } + + private fun renderMockInstanceFields(mockVar: NameExpr, instanceFields: Map) { + for ((field, fieldValue) in instanceFields) { + val renderedFieldValue = renderExpression(fieldValue) + renderSetFieldStatement(mockVar, field, renderedFieldValue) + } + } + + private fun renderMockObjectMethods(mockVar: NameExpr, methods: Map>) { + for ((method, mockValues) in methods) { + if (mockValues.isEmpty()) + continue + + val mockInitialization = renderSingleMockObjectMethod(mockVar, method, mockValues) + + addExpression(mockInitialization) + } + } + + private data class DoAnswerInvocationDescriptor( + val method: JcMethod, + val index: Int, + val args: Map, + val instList: UTestInstList, + val returnValue: UTestExpression? + ) { + companion object { + fun fromMockMethodOrNull( + allMockMethods: Map>, + sigToArgs: Map>, + method: JcMethod, + insts: List + ): DoAnswerInvocationDescriptor? { + val sigAndInvokeIdx = method.name.split("$\$_invocation_") + if (sigAndInvokeIdx.size != 2) return null + + val signature = sigAndInvokeIdx.first() + + val method = allMockMethods.entries.single { (method, _) -> method.jcdbSignature == signature } + val invocationIndex = sigAndInvokeIdx.last().toInt() + + check(insts.first() is UTestInstList) { + "bad doAnswer effect descriptor" + } + + val instList = insts.first() as UTestInstList + val retVal = method.value.getOrNull(invocationIndex) + + val args = sigToArgs[signature]?.filter { descriptor -> + descriptor.invocation == invocationIndex + }?.associate { descriptor -> + descriptor.expr to descriptor.position + } ?: emptyMap() + + return DoAnswerInvocationDescriptor(method.key, invocationIndex, args, instList, retVal) + } + } + } + + private data class DoAnswerArgDescriptor( + val signature: String, + val position: Int, + val invocation: Int, + val expr: UTestAllocateMemoryCall + ) { + companion object { + fun fromMockFieldOrNull(field: JcField, value: UTestExpression): DoAnswerArgDescriptor? { + val rawTokens = field.name.split("_method_$$") + if (rawTokens.size != 2 || !rawTokens.first().startsWith("arg_")) return null + + val idx = rawTokens.first().drop(4).takeWhile { it.isDigit() }.toInt() + + val sigAndInvocation = rawTokens.last().split("$\$_invocation_") + if (sigAndInvocation.size != 2) return null + + val signature = sigAndInvocation.first() + val invocationIdx = sigAndInvocation.last().toInt() + return DoAnswerArgDescriptor(signature, idx, invocationIdx, value as UTestAllocateMemoryCall) + } + } + } + + private fun renderMockObjectDoAnswerMethods( + mockVar: NameExpr?, + mockStaticUtil: NameExpr?, + invocations: Map> + ) { + check(mockVar != null || mockStaticUtil != null) { + "either instance or static mock should not be null" + } + + val invocationOnMockType = renderClass("org.mockito.invocation.InvocationOnMock") + + for ((method, invokesList) in invocations) { + + val args = mockMethodMatchersList(method.toTypedMethod) + val initReceiver = renderInitialDoAnswerReceiver(method, mockStaticUtil, args) + + val doAnswerChain = invokesList.sortedBy { descriptor -> + descriptor.index + }.fold(initReceiver) { currentReceiver, invokeDescriptor -> + val lambda = renderDoAnswerLambda(invokeDescriptor, invocationOnMockType) + chainDoAnswerCalls(method, currentReceiver, lambda) + } + + val res = renderFinalDoAnswerExpression(method, doAnswerChain, mockVar, args) + addExpression(res) + } + } + + private fun renderInitialDoAnswerReceiver( + method: JcMethod, + mockStaticUtil: NameExpr?, + args: List + ): Expression = when { + method.isStatic -> renderMockObjectStaticMethodWhenCall(mockStaticUtil!!, method, args) + else -> TypeExpr(mockitoClass) + } + + /* + * TODO: replace lambda rendering with more general approach + * example: invocation -> { return invocation.getArgument(0) + invocation.getArgument(1); } + * relies on "return value is always generated outside of lambda" + */ + private fun renderDoAnswerLambda( + invokeDescriptor: DoAnswerInvocationDescriptor, + invocationOnMockType: ReferenceType + ): LambdaExpr { + val retExpr = invokeDescriptor.returnValue?.let { renderExpression(it) } ?: NullLiteralExpr() + + val lambdaBlock = newInnerBlock() + val lambdaVarName = lambdaBlock.identifiersManager["invocationOnMock"] + + for ((expr, argPos) in invokeDescriptor.args) { + + val argInitializer = MethodCallExpr( + NameExpr(lambdaVarName), + "getArgument", + NodeList(IntegerLiteralExpr(argPos.toString())) + ) + + lambdaBlock.exprCache[expr] = lambdaBlock.renderVarDeclaration(expr.type, argInitializer) + } + + for (inst in invokeDescriptor.instList.instList) { + lambdaBlock.renderInst(inst) + } + + lambdaBlock.addStatement(ReturnStmt(retExpr)) + + return LambdaExpr(NodeList(Parameter(invocationOnMockType, lambdaVarName)), lambdaBlock.render(), true) + } + + private fun chainDoAnswerCalls( + method: JcMethod, + currentReceiver: Expression, + lambda: LambdaExpr + ): Expression = when { + method.isStatic -> mockitoThenAnswerMethodCall(currentReceiver, lambda) + else -> mockitoDoAnswerMethodCall(currentReceiver, lambda) + } + + private fun renderFinalDoAnswerExpression( + method: JcMethod, + receiver: Expression, + mockVar: NameExpr?, + args: List + ): Expression = when { + method.isStatic -> receiver + + else -> { + check(mockVar != null) { + "instance mock cannot be null for instance doAnswer" + } + + val whenCall = mockitoWhenMethodCall(receiver, mockVar) + renderMethodCall(method, whenCall, args, method.isVararg) + } + } + + private fun renderInstanceMockCreationExpressions(type: JcClassType, instanceUnderSpy: UTestExpression?): Expression { + return when (instanceUnderSpy) { + null -> { + mockitoMockMethodCall(type) + } + + is UTestNullExpression -> { + mockitoSpyClassMethodCall(type) + } + + else -> { + val instanceToSpy = renderExpression(instanceUnderSpy) + mockitoSpyInstanceMethodCall(instanceToSpy) + } + } + } + + private fun mockMethodMatchersList(typedMethod: JcTypedMethod): List { + val args = typedMethod.parameters.map { param -> + when (param.type.typeName) { + PredefinedPrimitives.Boolean -> mockitoAnyBooleanMethodCall() + PredefinedPrimitives.Byte -> mockitoAnyByteMethodCall() + PredefinedPrimitives.Char -> mockitoAnyCharMethodCall() + PredefinedPrimitives.Short -> mockitoAnyShortMethodCall() + PredefinedPrimitives.Int -> mockitoAnyIntMethodCall() + PredefinedPrimitives.Long -> mockitoAnyLongMethodCall() + PredefinedPrimitives.Float -> mockitoAnyFloatMethodCall() + PredefinedPrimitives.Double -> mockitoAnyDoubleMethodCall() + else -> mockitoAnyMethodCall(param.type) + } + } + return args + } + + private fun isExceptionalMockMethodArgs(mockValues: List): Boolean = mockValues.all { value -> + value is UTestClassExpression && value.type.isAssignable(cp.findType("java.lang.Throwable")) + } + + private fun renderSingleMockObjectMethod( + mockVar: NameExpr, + method: JcMethod, + mockValues: List + ): Expression { + check(method.returnType.typeName != PredefinedPrimitives.Void || isExceptionalMockMethodArgs(mockValues)) { + "non-exceptional void mock" + } + + val typedMethod = method.toTypedMethod + val args = mockMethodMatchersList(typedMethod) + + val mockWhenCall = + if (method.isStatic) + renderMockObjectStaticMethodWhenCall(mockVar, method, args) + else + renderMockObjectInstanceMethodWhenCall(mockVar, method, args) + + val methodReturnType = typedMethod.returnType + val renderedReturnType = renderType(methodReturnType) + val mockedMethod = mockValues.fold(mockWhenCall) { mock, nextReturnValue -> + var renderedMockValue = exprWithGenericsCasted(methodReturnType, renderExpression(nextReturnValue)) + + /* + * TODO: fresh var required when mocked method M of class T use another method from T + * require optimisations + */ + if (method.isStatic) { + renderedMockValue = renderVarDeclaration(renderedReturnType, renderedMockValue) + } + + if (nextReturnValue is UTestClassExpression && nextReturnValue.type.isAssignable(cp.findType("java.lang.Throwable"))) + mockitoThenThrowMethodCall(mock, renderedMockValue) + else + mockitoThenReturnMethodCall(mock, renderedMockValue) + } + + return mockedMethod + } + + private fun renderMockObjectInstanceMethodWhenCall( + mockVar: NameExpr, + method: JcMethod, + args: List + ): Expression { + val methodCall = renderMethodCall(method, mockVar, args, false) + return mockitoWhenMethodCall(TypeExpr(mockitoClass), methodCall) + } + + private fun renderMockObjectStaticMethodWhenCall(mockVar: NameExpr, method: JcMethod, args: List): Expression { + val mockedMethodRef = LambdaExpr(NodeList(), renderStaticMethodCall(method, args, false)) + return mockitoWhenMethodCall(mockVar, mockedMethodRef) + } + + private fun renderMockedStaticVarDeclaration(mockedClass: JcClassOrInterface): NameExpr { + val mockMethodDeclType = renderClass(mockedClass) + + val mockedStaticType = renderClass("org.mockito.MockedStatic").setTypeArguments(mockMethodDeclType) + + val mockStaticCall = mockitoMockStaticMethodCall(mockedClass) + + val mockStaticUtil = renderVarDeclaration(mockedStaticType, mockStaticCall, "staticMockUtil") + + val mockStaticDefferedClose = MethodCallExpr(mockStaticUtil, "close") + methodRenderer.trailingExpressions.add(mockStaticDefferedClose) + return mockStaticUtil + } +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testRenderer/JcTestClassRenderer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testRenderer/JcTestClassRenderer.kt new file mode 100644 index 0000000000..0b92bc445d --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testRenderer/JcTestClassRenderer.kt @@ -0,0 +1,63 @@ +package org.usvm.jvm.rendering.testRenderer + +import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration +import com.github.javaparser.ast.expr.AnnotationExpr +import com.github.javaparser.ast.expr.MarkerAnnotationExpr +import com.github.javaparser.ast.expr.SimpleName +import org.jacodb.api.jvm.JcClasspath +import org.usvm.jvm.rendering.baseRenderer.JcClassRenderer +import org.usvm.jvm.rendering.baseRenderer.JcIdentifiersManager +import org.usvm.jvm.rendering.baseRenderer.JcImportManager +import org.usvm.jvm.rendering.baseRenderer.JcTestFrameworkProvider +import org.usvm.test.api.UTest + +open class JcTestClassRenderer : JcClassRenderer { + + constructor( + name: String, + importManager: JcImportManager, + identifiersManager: JcIdentifiersManager, + cp: JcClasspath + ) : super(name, importManager, identifiersManager, cp) + + constructor( + decl: ClassOrInterfaceDeclaration, + importManager: JcImportManager, + identifiersManager: JcIdentifiersManager, + cp: JcClasspath, + ): super(decl, importManager, identifiersManager, cp) + + protected val testAnnotation: AnnotationExpr by lazy { + val annotationName = renderClass(JcTestFrameworkProvider.testAnnotationClassName) + MarkerAnnotationExpr(annotationName.nameWithScope) + } + + protected open fun createTestRenderer( + test: UTest, + identifiersManager: JcIdentifiersManager, + name: SimpleName, + annotations: List, + ): JcTestRenderer { + return JcTestRenderer( + test, + this, + importManager, + JcIdentifiersManager(identifiersManager), + cp, + name, + annotations + ) + } + + fun addTest(test: UTest, namePrefix: String? = null): JcTestRenderer { + val renderer = createTestRenderer( + test, + JcIdentifiersManager(identifiersManager), + identifiersManager[namePrefix ?: "test"], + listOf(testAnnotation) + ) + + addRenderingMethod(renderer) + return renderer + } +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testRenderer/JcTestFileRenderer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testRenderer/JcTestFileRenderer.kt new file mode 100644 index 0000000000..9eb0b144a7 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testRenderer/JcTestFileRenderer.kt @@ -0,0 +1,40 @@ +package org.usvm.jvm.rendering.testRenderer + +import com.github.javaparser.ast.CompilationUnit +import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration +import com.github.javaparser.ast.expr.SimpleName +import org.jacodb.api.jvm.JcClasspath +import org.usvm.jvm.rendering.baseRenderer.JcFileRenderer +import org.usvm.jvm.rendering.baseRenderer.JcImportManager + +open class JcTestFileRenderer : JcFileRenderer { + protected constructor(cu: CompilationUnit, importManager: JcImportManager, cp: JcClasspath) : super(cu, importManager, cp) + + protected constructor(packageName: String?, importManager: JcImportManager, cp: JcClasspath) : super(packageName, importManager, cp) + + constructor(cu: CompilationUnit, cp: JcClasspath) : this(cu, JcImportManager(cu), cp) + + constructor(packageName: String?, cp: JcClasspath) : this(packageName, JcImportManager(), cp) + + override fun classRendererFor(declaration: ClassOrInterfaceDeclaration): JcTestClassRenderer { + return JcTestClassRenderer(declaration, importManager, identifiersManager, cp) + } + + protected open fun classRendererFor(name: String): JcTestClassRenderer = + JcTestClassRenderer( + name, + importManager, + identifiersManager, + cp + ) + + fun getOrAddClass(testClassName: String): JcTestClassRenderer { + val existing = findRenderingClass(SimpleName(testClassName)) + if (existing != null) return existing as JcTestClassRenderer + + val renderer = classRendererFor(testClassName) + + addRenderingClass(renderer) + return renderer + } +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testRenderer/JcTestInfo.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testRenderer/JcTestInfo.kt new file mode 100644 index 0000000000..a66fd7d1a0 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testRenderer/JcTestInfo.kt @@ -0,0 +1,33 @@ +package org.usvm.jvm.rendering.testRenderer + +import java.nio.file.Path +import org.jacodb.api.jvm.JcMethod +import org.usvm.jvm.rendering.normalized + +abstract class JcTestInfo( + val method: JcMethod, + val isExceptional: Boolean? = null, + val testFilePath: Path?, + val testPackageName: String?, + val testClassName: String?, + val testName: String? +) { + private val defaultNamePrefix: String get() = "${method.name}$isExceptionalSuffix".normalized + + val testNamePrefix: String get() = testName ?: defaultNamePrefix + + private val isExceptionalSuffix: String + get() = when (isExceptional) { + true -> "Exceptional" + else -> "" + } + + override fun hashCode(): Int { + return method.hashCode() + } + + override fun equals(other: Any?): Boolean { + if (other == null || other !is JcTestInfo) return false + return method == other.method && testFilePath == other.testFilePath && testClassName == other.testClassName + } +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testRenderer/JcTestRenderer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testRenderer/JcTestRenderer.kt new file mode 100644 index 0000000000..abfde36f95 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testRenderer/JcTestRenderer.kt @@ -0,0 +1,82 @@ +package org.usvm.jvm.rendering.testRenderer + +import com.github.javaparser.ast.NodeList +import com.github.javaparser.ast.body.MethodDeclaration +import com.github.javaparser.ast.expr.AnnotationExpr +import com.github.javaparser.ast.expr.Expression +import com.github.javaparser.ast.expr.SimpleName +import java.util.Collections +import org.usvm.jvm.rendering.baseRenderer.JcImportManager +import org.usvm.jvm.rendering.baseRenderer.JcMethodRenderer +import org.usvm.jvm.rendering.baseRenderer.JcIdentifiersManager +import org.usvm.test.api.UTest +import org.usvm.test.api.UTestConstExpression +import org.usvm.test.api.UTestExpression +import java.util.IdentityHashMap +import org.jacodb.api.jvm.JcClasspath + +open class JcTestRenderer( + private val test: UTest, + override val classRenderer: JcTestClassRenderer, + importManager: JcImportManager, + identifiersManager: JcIdentifiersManager, + cp: JcClasspath, + name: SimpleName, + annotations: List, +): JcMethodRenderer( + importManager, + identifiersManager, + cp, + classRenderer, + name, + NodeList(), + NodeList(annotations), + NodeList(), + voidType +) { + + protected val shouldDeclareVar: MutableSet = Collections.newSetFromMap(IdentityHashMap()) + + internal val trailingExpressions: MutableList = mutableListOf() + + override val body: JcTestBlockRenderer = JcTestBlockRenderer( + this, + importManager, + JcIdentifiersManager(identifiersManager), + cp, + shouldDeclareVar + ) + + open fun requireVarDeclarationOf(expr: UTestExpression): Boolean = false + + open fun preventVarDeclarationOf(expr: UTestExpression): Boolean = expr is UTestConstExpression<*> + + inner class JcExprUsageVisitor: JcTestVisitor() { + + private fun shouldDeclareVarCheck(expr: UTestExpression): Boolean { + return !preventVarDeclarationOf(expr) && isVisited(expr) || requireVarDeclarationOf(expr) + } + + override fun visitExpr(expr: UTestExpression) { + if (shouldDeclareVarCheck(expr)) + shouldDeclareVar.add(expr) + + super.visitExpr(expr) + } + } + + init { + JcExprUsageVisitor().visit(test) + } + + override fun renderInternal(): MethodDeclaration { + for (inst in test.initStatements) + body.renderInst(inst) + + body.renderInst(test.callMethodExpression) + + trailingExpressions.forEach { expr -> body.addExpression(expr) } + + return super.renderInternal() + } +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testRenderer/JcTestVisitor.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testRenderer/JcTestVisitor.kt new file mode 100644 index 0000000000..1ab160746d --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testRenderer/JcTestVisitor.kt @@ -0,0 +1,241 @@ +package org.usvm.jvm.rendering.testRenderer + +import org.usvm.test.api.* +import java.util.Collections +import java.util.IdentityHashMap + +open class JcTestVisitor { + + private val cache: MutableSet = Collections.newSetFromMap(IdentityHashMap()) + + protected fun isVisited(inst: UTestInst) = cache.contains(inst) + + protected fun clearVisited() = cache.clear() + + fun visit(test: UTest) { + for (inst in test.initStatements) + visit(inst) + + visit(test.callMethodExpression as UTestInst) + } + + open fun visit(inst: UTestInst) { + when (inst) { + is UTestExpression -> visitExpr(inst) + is UTestStatement -> visitStmt(inst) + } + } + + protected open fun visitExpr(expr: UTestExpression) { + if (!cache.add(expr)) + return + + visit(expr) + } + + open fun visit(expr: UTestExpression) { + when (expr) { + is UTestArithmeticExpression -> visit(expr) + is UTestArrayGetExpression -> visit(expr) + is UTestArrayLengthExpression -> visit(expr) + is UTestBinaryConditionExpression -> visit(expr) + is UTestCastExpression -> visit(expr) + is UTestCreateArrayExpression -> visit(expr) + is UTestGetFieldExpression -> visit(expr) + is UTestCall -> visit(expr) + is UTestClassExpression -> visit(expr) + is UTestBooleanExpression -> visit(expr) + is UTestByteExpression -> visit(expr) + is UTestCharExpression -> visit(expr) + is UTestDoubleExpression -> visit(expr) + is UTestFloatExpression -> visit(expr) + is UTestIntExpression -> visit(expr) + is UTestLongExpression -> visit(expr) + is UTestNullExpression -> visit(expr) + is UTestShortExpression -> visit(expr) + is UTestStringExpression -> visit(expr) + is UTestGetStaticFieldExpression -> visit(expr) + is UTestGlobalMock -> visit(expr) + is UTestMockObject -> visit(expr) + is UTestInstList -> visit(expr) + is UTestMockInst -> visit(expr) + } + } + + protected open fun visitStmt(stmt: UTestStatement) { + if (!cache.add(stmt)) + return + + visit(stmt) + } + + open fun visit(stmt: UTestStatement) { + when (stmt) { + is UTestArraySetStatement -> visit(stmt) + is UTestBinaryConditionStatement -> visit(stmt) + is UTestSetFieldStatement -> visit(stmt) + is UTestSetStaticFieldStatement -> visit(stmt) + } + } + + //region Expressions + + open fun visit(expr: UTestArithmeticExpression) { + visitExpr(expr.lhv) + visitExpr(expr.rhv) + } + + open fun visit(expr: UTestArrayGetExpression) { + visitExpr(expr.arrayInstance) + visitExpr(expr.index) + } + + open fun visit(expr: UTestArrayLengthExpression) { + visitExpr(expr.arrayInstance) + } + + open fun visit(expr: UTestBinaryConditionExpression) { + visitExpr(expr.lhv) + visitExpr(expr.rhv) + visitExpr(expr.trueBranch) + visitExpr(expr.elseBranch) + } + + open fun visit(expr: UTestCastExpression) { + visitExpr(expr.expr) + } + + open fun visit(expr: UTestCreateArrayExpression) { + visitExpr(expr.size) + } + + open fun visit(expr: UTestGetFieldExpression) { + visitExpr(expr.instance) + } + + open fun visit(expr: UTestGetStaticFieldExpression) { } + + open fun visit(expr: UTestClassExpression) { } + + open fun visit(expr: UTestBooleanExpression) { } + open fun visit(expr: UTestByteExpression) { } + open fun visit(expr: UTestCharExpression) { } + open fun visit(expr: UTestDoubleExpression) { } + open fun visit(expr: UTestFloatExpression) { } + open fun visit(expr: UTestIntExpression) { } + open fun visit(expr: UTestLongExpression) { } + open fun visit(expr: UTestNullExpression) { } + open fun visit(expr: UTestShortExpression) { } + open fun visit(expr: UTestStringExpression) { } + open fun visit(expr: UTestMockInst) { } + + //endregion + + //region Mocks + + open fun visit(expr: UTestGlobalMock) { + for (fieldValue in expr.fields.values) { + visitExpr(fieldValue) + } + + for (methodValues in expr.methods.values) { + for (value in methodValues) { + visitExpr(value) + } + } + } + + open fun visit(expr: UTestMockObject) { + for (fieldValue in expr.fields.values) { + visitExpr(fieldValue) + } + + for (methodValues in expr.methods.values) { + for (value in methodValues) { + visitExpr(value) + } + } + } + + open fun visit(expr: UTestInstList) { + for (inst in expr.instList) { + visit(inst) + } + } + + //endregion + + //region Calls + + open fun visit(call: UTestCall) { + when (call) { + is UTestConstructorCall -> visit(call) + is UTestStaticMethodCall -> visit(call) + is UTestMethodCall -> visit(call) + is UTestAllocateMemoryCall -> visit(call) + is UTestAssertThrowsCall -> visit(call) + is UTestAssertEqualsCall -> visit(call) + } + } + + open fun visit(call: UTestConstructorCall) { + for (arg in call.args) + visitExpr(arg) + } + + open fun visit(call: UTestStaticMethodCall) { + for (arg in call.args) + visitExpr(arg) + } + + open fun visit(call: UTestMethodCall) { + visitExpr(call.instance) + for (arg in call.args) + visitExpr(arg) + } + + open fun visit(call: UTestAllocateMemoryCall) { } + + open fun visit(call: UTestAssertThrowsCall) { + for (inst in call.instList) { + visit(inst) + } + } + + open fun visit(call: UTestAssertEqualsCall) { + visit(call.expected) + visit(call.actual) + } + + //endregion + + //region Statements + + open fun visit(stmt: UTestArraySetStatement) { + visitExpr(stmt.arrayInstance) + visitExpr(stmt.index) + visitExpr(stmt.setValueExpression) + } + + open fun visit(stmt: UTestBinaryConditionStatement) { + visitExpr(stmt.lhv) + visitExpr(stmt.rhv) + + for (thenStatement in stmt.trueBranch) + visitStmt(thenStatement) + + for (elseStatement in stmt.elseBranch) + visitStmt(elseStatement) + } + + open fun visit(stmt: UTestSetFieldStatement) { + visitExpr(stmt.instance) + visitExpr(stmt.value) + } + + open fun visit(stmt: UTestSetStaticFieldStatement) { + visitExpr(stmt.value) + } + + //endregion +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testTransformers/JcCallCtorTransformer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testTransformers/JcCallCtorTransformer.kt new file mode 100644 index 0000000000..25bffeebae --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testTransformers/JcCallCtorTransformer.kt @@ -0,0 +1,16 @@ +package org.usvm.jvm.rendering.testTransformers + +import org.usvm.test.api.UTestCall +import org.usvm.test.api.UTestConstructorCall +import org.usvm.test.api.UTestMethodCall + +class JcCallCtorTransformer: JcTestTransformer() { + + override fun transform(call: UTestMethodCall): UTestCall? { + if (!call.method.isConstructor) + return super.transform(call) + + val args = call.args.map { transformExpr(it) ?: return null } + return UTestConstructorCall(call.method, args) + } +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testTransformers/JcDeadCodeTransformer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testTransformers/JcDeadCodeTransformer.kt new file mode 100644 index 0000000000..b059f84c23 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testTransformers/JcDeadCodeTransformer.kt @@ -0,0 +1,189 @@ +package org.usvm.jvm.rendering.testTransformers + +import org.usvm.jvm.rendering.testRenderer.JcTestVisitor +import java.util.* +import org.jacodb.api.jvm.ext.objectType +import org.usvm.test.api.UTest +import org.usvm.test.api.UTestAllocateMemoryCall +import org.usvm.test.api.UTestArraySetStatement +import org.usvm.test.api.UTestCall +import org.usvm.test.api.UTestConstructorCall +import org.usvm.test.api.UTestExpression +import org.usvm.test.api.UTestGlobalMock +import org.usvm.test.api.UTestInst +import org.usvm.test.api.UTestMockObject +import org.usvm.test.api.UTestSetFieldStatement +import org.usvm.test.api.UTestSetStaticFieldStatement +import org.usvm.test.api.UTestStatement + +class JcDeadCodeTransformer: JcTestTransformer() { + private val roots: MutableSet = Collections.newSetFromMap(IdentityHashMap()) + private val reachable: MutableSet = Collections.newSetFromMap(IdentityHashMap()) + + private val rootFetcher get() = ReachabilityRootFetcher(roots) + + private class ReachabilityRootFetcher(val roots: MutableSet): JcTestVisitor() { + val fetched: MutableSet = Collections.newSetFromMap(IdentityHashMap()) + + fun fetchFrom(expr: UTestExpression): MutableSet = fetchFrom(listOf(expr)) + fun fetchFrom(exprs: List): MutableSet { + + for (expr in exprs) { + visitExpr(expr) + } + + return fetched + } + + override fun visit(expr: UTestExpression) { + if (expr in roots) { + fetched.add(expr) + return + } + super.visit(expr) + } + } + + private class ReachabilityCollector( + val roots: MutableSet, + val reachable: MutableSet + ) : JcTestVisitor() { + private var marker = false + + override fun visit(inst: UTestInst) { + marker = false + super.clearVisited() + super.visit(inst) + } + + override fun visit(expr: UTestExpression) { + if (expr in reachable) return + + if (marker) { + reachable.add(expr) + withReachable { super.visit(expr) } + return + } + + super.visit(expr) + } + + override fun visit(stmt: UTestSetStaticFieldStatement) { + if (stmt.value in roots) return + + roots.add(stmt.value) + reachable.add(stmt.value) + withReachable { super.visit(stmt) } + } + + override fun visit(stmt: UTestSetFieldStatement) { + if (stmt.instance in reachable || stmt.value in reachable) { + withReachable { super.visit(stmt) } + return + } + super.visit(stmt) + } + + override fun visit(stmt: UTestArraySetStatement) { + if (stmt.arrayInstance in reachable || stmt.setValueExpression in reachable) { + withReachable { super.visit(stmt) } + return + } + super.visit(stmt) + } + + override fun visit(call: UTestCall) { + if (call is UTestAllocateMemoryCall || call in roots) return + + if (call !is UTestConstructorCall || call.type != call.type.classpath.objectType) { + roots.add(call) + } + + reachable.add(call) + withReachable { super.visit(call) } + } + + override fun visit(expr: UTestMockObject) { + if (expr in roots) return + + roots.add(expr) + reachable.add(expr) + withReachable { super.visit(expr) } + } + + override fun visit(expr: UTestGlobalMock) { + if (expr in roots) return + + roots.add(expr) + reachable.add(expr) + withReachable { super.visit(expr) } + } + + private fun withReachable(block: () -> Unit) { + val prevReachability = marker + marker = true + try { + block() + } finally { + marker = prevReachability + } + } + } + + override fun transform(test: UTest): UTest { + roots.clear() + reachable.clear() + + propagateReachabilityIn(test) + + val filteredInitInstList = test.initStatements.flatMap { + transformInstProxy(it) + }.filterNotNull() + + val callMethodExpression = transformCall(test.callMethodExpression) ?: error("call must be present in UTest") + + return UTest(filteredInitInstList, callMethodExpression) + } + + private fun propagateReachabilityIn(test: UTest) { + val clone: MutableSet = Collections.newSetFromMap(IdentityHashMap()) + + do { + clone.clear() + clone.addAll(reachable) + check(clone.size == reachable.size) { + "clone is not the size of reachable" + } + ReachabilityCollector(roots, reachable).visit(test) + } while (clone.size != reachable.size) + } + + private fun transformExprs(expr: UTestExpression): List{ + if (expr in reachable) + return listOf(super.transformExpr(expr)) + + return rootFetcher.fetchFrom(expr).map { super.transformExpr(it) }.toList() + } + + private fun transformInstProxy(inst: UTestInst): List { + return when (inst) { + is UTestArraySetStatement -> transformArraySet(inst) + is UTestSetFieldStatement -> transformFieldSet(inst) + is UTestExpression -> transformExprs(inst) + else -> listOf(super.transformInst(inst)) + } + } + + private fun transformArraySet(stmt: UTestArraySetStatement): List { + return keepStatementOrFetchRoots(stmt, stmt.arrayInstance, listOf(stmt.arrayInstance, stmt.index, stmt.setValueExpression)) + } + + private fun transformFieldSet(stmt: UTestSetFieldStatement): List { + return keepStatementOrFetchRoots(stmt, stmt.instance, listOf(stmt.instance, stmt.value)) + } + + private fun keepStatementOrFetchRoots(stmt: UTestStatement, instance: UTestExpression, targets: List): List { + if (instance in reachable) return listOf(super.transformStmt(stmt)) + return rootFetcher.fetchFrom(targets).map { super.transformExpr(it) }.toList() + } +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testTransformers/JcInstDuplicationTransformer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testTransformers/JcInstDuplicationTransformer.kt new file mode 100644 index 0000000000..47c01c4141 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testTransformers/JcInstDuplicationTransformer.kt @@ -0,0 +1,33 @@ +package org.usvm.jvm.rendering.testTransformers + +import java.util.Collections +import java.util.IdentityHashMap +import org.usvm.test.api.UTest +import org.usvm.test.api.UTestExpression +import org.usvm.test.api.UTestInst +import org.usvm.test.api.UTestStatement + +class JcInstDuplicationTransformer: JcTestTransformer() { + private val visited: MutableSet = Collections.newSetFromMap(IdentityHashMap()) + + override fun transform(test: UTest): UTest { + val initInstList = test.initStatements.mapNotNull { + transformInstProxy(it) + } + return UTest(initInstList, test.callMethodExpression) + } + + private fun transformInstProxy(inst: UTestInst): UTestInst? = when(inst) { + in visited -> null + else -> super.transformInst(inst) + } + override fun transform(stmt: UTestStatement): UTestStatement? { + visited.add(stmt) + return super.transform(stmt) + } + + override fun transform(expr: UTestExpression): UTestExpression? { + visited.add(expr) + return super.transform(expr) + } +} \ No newline at end of file diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testTransformers/JcOuterThisTransformer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testTransformers/JcOuterThisTransformer.kt new file mode 100644 index 0000000000..9e068a10aa --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testTransformers/JcOuterThisTransformer.kt @@ -0,0 +1,19 @@ +package org.usvm.jvm.rendering.testTransformers + +import org.jacodb.api.jvm.JcField +import org.usvm.test.api.UTestNullExpression +import org.usvm.test.api.UTestSetFieldStatement +import org.usvm.test.api.UTestStatement + +class JcOuterThisTransformer: JcTestTransformer() { + + private val JcField.isOuterThisField: Boolean + get() = name == "this$0" + + override fun transform(stmt: UTestSetFieldStatement): UTestStatement? { + if (stmt.field.isOuterThisField && stmt.value is UTestNullExpression) + return null + + return super.transform(stmt) + } +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testTransformers/JcPrimitiveWrapperTransformer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testTransformers/JcPrimitiveWrapperTransformer.kt new file mode 100644 index 0000000000..8f01c5529a --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testTransformers/JcPrimitiveWrapperTransformer.kt @@ -0,0 +1,41 @@ +package org.usvm.jvm.rendering.testTransformers + +import org.jacodb.api.jvm.JcField +import org.usvm.test.api.UTestExpression +import org.usvm.test.api.UTestSetFieldStatement +import org.usvm.test.api.UTestStatement +import java.util.IdentityHashMap + +class JcPrimitiveWrapperTransformer: JcTestTransformer() { + + private val toReplace = IdentityHashMap() + + private val primitiveWrappers = setOf( + "java.lang.Boolean", + "java.lang.Short", + "java.lang.Integer", + "java.lang.Long", + "java.lang.Float", + "java.lang.Double", + "java.lang.Byte", + "java.lang.Character" + ) + + private val JcField.isWrapperValueField: Boolean + get() = name == "value" && primitiveWrappers.contains(enclosingClass.name) + + override fun transform(stmt: UTestSetFieldStatement): UTestStatement? { + if (!stmt.field.isWrapperValueField) + return super.transform(stmt) + + val oldValue = toReplace.put(stmt.instance, transformExpr(stmt.value)) + check(oldValue == null) { + "old primitive value is not null" + } + return null + } + + override fun transform(expr: UTestExpression): UTestExpression? { + return toReplace[expr] ?: return super.transform(expr) + } +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testTransformers/JcSpringMvcTestTransformer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testTransformers/JcSpringMvcTestTransformer.kt new file mode 100644 index 0000000000..ea0e49e162 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testTransformers/JcSpringMvcTestTransformer.kt @@ -0,0 +1,82 @@ +package org.usvm.jvm.rendering.testTransformers + +import org.jacodb.api.jvm.JcClassOrInterface +import org.jacodb.api.jvm.JcClassType +import org.jacodb.api.jvm.JcMethod +import org.usvm.test.api.UTestCall +import org.usvm.test.api.UTestClassExpression +import org.usvm.test.api.UTestConstructorCall +import org.usvm.test.api.UTestMethodCall +import org.usvm.test.api.UTestStaticMethodCall + +class JcSpringMvcTestTransformer: JcTestTransformer() { + + private var mvcTestClass: JcClassOrInterface? = null + + private val testContextManagerName = "org.springframework.test.context.TestContextManager" + + private val JcClassOrInterface.isTestContextManager: Boolean get() { + return name == testContextManagerName + } + + private val JcMethod.isIgnoreResultMethod: Boolean get() { + return name == "ignoreResult" && enclosingClass.name == mvcTestClass?.name + } + + private val JcMethod.isPrepareInstanceMethod: Boolean get() { + return name == "prepareTestInstance" && enclosingClass.isTestContextManager + } + + private val JcMethod.isBeforeTestClass: Boolean get() { + return name == "beforeTestClass" && enclosingClass.isTestContextManager + } + + private val JcMethod.isAfterTestMethod: Boolean get() { + return name == "afterTestMethod" && enclosingClass.isTestContextManager + } + + private val JcMethod.isBeforeTestMethod: Boolean get() { + return name == "beforeTestMethod" && enclosingClass.isTestContextManager + } + + private val JcMethod.isInitTestCtxMethod: Boolean get() { + return isPrepareInstanceMethod || isBeforeTestClass || isAfterTestMethod || isBeforeTestMethod + } + + override fun transform(call: UTestMethodCall): UTestCall? { + val method = call.method + + if (method.isPrepareInstanceMethod) { + val instance = call.instance + check(instance is UTestConstructorCall && instance.method.enclosingClass.name == testContextManagerName) { + "isPrepareInstanceMethod instance fail" + } + + val arg = instance.args.singleOrNull() as? UTestClassExpression + check(arg != null) { + "isPrepareInstanceMethod arg fail" + } + + mvcTestClass = (arg.type as JcClassType).jcClass + + return null + } + + if (method.isInitTestCtxMethod) { + return null + } + + return super.transform(call) + } + + override fun transform(call: UTestStaticMethodCall): UTestCall? { + val method = call.method + + if (method.isIgnoreResultMethod) + return super.transform(call.args.single() as UTestMethodCall) + + return super.transform(call) + } + + val testClass get() = mvcTestClass ?: error("testClass not found") +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testTransformers/JcTestTransformer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testTransformers/JcTestTransformer.kt new file mode 100644 index 0000000000..35e68336b0 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/testTransformers/JcTestTransformer.kt @@ -0,0 +1,268 @@ +package org.usvm.jvm.rendering.testTransformers + +import org.jacodb.api.jvm.JcField +import org.jacodb.api.jvm.JcMethod +import org.usvm.test.api.* +import java.util.IdentityHashMap + +// TODO: add CompositeTransformer +abstract class JcTestTransformer { + + private val cache = IdentityHashMap() + + protected fun transformed(expr: UTestExpression): UTestExpression? = cache[expr] + + open fun transform(test: UTest): UTest { + val initStatements = test.initStatements.mapNotNull { transformInst(it) } + val callMethodExpression = transformCall(test.callMethodExpression) + ?: error("Transformers should not delete result method call") + + return UTest(initStatements, callMethodExpression) + } + + protected fun transformInst(inst: UTestInst): UTestInst? { + return when (inst) { + is UTestExpression -> transformExpr(inst) + is UTestStatement -> transformStmt(inst) + } + } + + protected fun transformExpr(expr: UTestExpression): UTestExpression? { + return cache.getOrPut(expr) { transform(expr) } + } + + protected fun transformCall(call: UTestCall): UTestCall? { + return cache.getOrPut(call) { transform(call) } as? UTestCall + } + + protected fun transformStmt(stmt: UTestStatement): UTestStatement? { + return transform(stmt) + } + + //region Expression transformers + + open fun transform(expr: UTestExpression): UTestExpression? { + return when (expr) { + is UTestArithmeticExpression -> transform(expr) + is UTestArrayGetExpression -> transform(expr) + is UTestArrayLengthExpression -> transform(expr) + is UTestBinaryConditionExpression -> transform(expr) + is UTestCastExpression -> transform(expr) + is UTestCreateArrayExpression -> transform(expr) + is UTestGetFieldExpression -> transform(expr) + is UTestCall -> transformCall(expr) + is UTestClassExpression -> transform(expr) + is UTestBooleanExpression -> transform(expr) + is UTestByteExpression -> transform(expr) + is UTestCharExpression -> transform(expr) + is UTestDoubleExpression -> transform(expr) + is UTestFloatExpression -> transform(expr) + is UTestIntExpression -> transform(expr) + is UTestLongExpression -> transform(expr) + is UTestNullExpression -> transform(expr) + is UTestShortExpression -> transform(expr) + is UTestStringExpression -> transform(expr) + is UTestGetStaticFieldExpression -> transform(expr) + is UTestGlobalMock -> transform(expr) + is UTestMockObject -> transform(expr) + is UTestInstList -> transform(expr) + is UTestMockInst -> transform(expr) + } + } + + open fun transform(expr: UTestArithmeticExpression): UTestExpression? { + val lhv = transformExpr(expr.lhv) ?: return null + val rhv = transformExpr(expr.rhv) ?: return null + return UTestArithmeticExpression(expr.operationType, lhv, rhv, expr.type) + } + + open fun transform(expr: UTestArrayGetExpression): UTestExpression? { + val array = transformExpr(expr.arrayInstance) ?: return null + val index = transformExpr(expr.index) ?: return null + return UTestArrayGetExpression(array, index) + } + + open fun transform(expr: UTestArrayLengthExpression): UTestExpression? { + val array = transformExpr(expr.arrayInstance) ?: return null + return UTestArrayLengthExpression(array) + } + + open fun transform(expr: UTestBinaryConditionExpression): UTestExpression? { + val lhv = transformExpr(expr.lhv) ?: return null + val rhv = transformExpr(expr.rhv) ?: return null + val trueBranch = transformExpr(expr.trueBranch) ?: return null + val elseBranch = transformExpr(expr.elseBranch) ?: return null + return UTestBinaryConditionExpression(expr.conditionType, lhv, rhv, trueBranch, elseBranch) + } + + open fun transform(expr: UTestCastExpression): UTestExpression? { + val toCastExpr = transformExpr(expr.expr) ?: return null + return UTestCastExpression(toCastExpr, expr.type) + } + + open fun transform(expr: UTestCreateArrayExpression): UTestExpression? { + val size = transformExpr(expr.size) ?: return null + return UTestCreateArrayExpression(expr.elementType, size) + } + + open fun transform(expr: UTestGetFieldExpression): UTestExpression? { + val instance = transformExpr(expr.instance) ?: return null + return UTestGetFieldExpression(instance, expr.field) + } + + open fun transform(expr: UTestGetStaticFieldExpression): UTestExpression? = expr + + open fun transform(expr: UTestClassExpression): UTestExpression? = expr + + open fun transform(expr: UTestBooleanExpression): UTestExpression? = expr + open fun transform(expr: UTestByteExpression): UTestExpression? = expr + open fun transform(expr: UTestCharExpression): UTestExpression? = expr + open fun transform(expr: UTestDoubleExpression): UTestExpression? = expr + open fun transform(expr: UTestFloatExpression): UTestExpression? = expr + open fun transform(expr: UTestIntExpression): UTestExpression? = expr + open fun transform(expr: UTestLongExpression): UTestExpression? = expr + open fun transform(expr: UTestNullExpression): UTestExpression? = expr + open fun transform(expr: UTestShortExpression): UTestExpression? = expr + open fun transform(expr: UTestStringExpression): UTestExpression? = expr + + //region Mock transformers + + private fun transformMockContents( + expr: UTestMock + ): Pair, Map>> { + val transformedFields = expr.fields.mapNotNull { (field, fieldValue) -> + val value = transformExpr(fieldValue) ?: return@mapNotNull null + field to value + }.toMap() + + val transformedMethods = expr.methods.mapNotNull { (method, values) -> + val transformedValues = values.mapNotNull { transformExpr(it) } + if (transformedValues.isEmpty()) + return@mapNotNull null + method to transformedValues + }.toMap() + + return transformedFields to transformedMethods + } + + open fun transform(expr: UTestGlobalMock): UTestExpression? { + val fields = mutableMapOf() + val methods = mutableMapOf>() + val mock = UTestGlobalMock(expr.type, fields, methods) + cache[expr] = mock + + val (transformedFields, transformedMethods) = transformMockContents(expr) + + fields.putAll(transformedFields) + methods.putAll(transformedMethods) + return mock + } + + open fun transform(expr: UTestMockObject): UTestExpression? { + val fields = mutableMapOf() + val methods = mutableMapOf>() + val mock = UTestMockObject(expr.type, fields, methods) + cache[expr] = mock + + val (transformedFields, transformedMethods) = transformMockContents(expr) + + fields.putAll(transformedFields) + methods.putAll(transformedMethods) + return mock + } + + open fun transform(expr: UTestInstList): UTestExpression? { + val instList = expr.instList.mapNotNull { inst -> transformInst(inst) } + return UTestInstList(instList) + } + + //endregion + + //region Call transformers + + open fun transform(call: UTestCall): UTestExpression? { + return when (call) { + is UTestConstructorCall -> transform(call) + is UTestStaticMethodCall -> transform(call) + is UTestMethodCall -> transform(call) + is UTestAllocateMemoryCall -> transform(call) + is UTestAssertThrowsCall -> transform(call) + is UTestAssertEqualsCall -> transform(call) + } + } + + open fun transform(call: UTestConstructorCall): UTestCall? { + val args = call.args.map { transformExpr(it) ?: return null } + return UTestConstructorCall(call.method, args) + } + + open fun transform(call: UTestStaticMethodCall): UTestCall? { + val args = call.args.map { transformExpr(it) ?: return null } + return UTestStaticMethodCall(call.method, args) + } + + open fun transform(call: UTestMethodCall): UTestCall? { + val instance = transformExpr(call.instance) ?: return null + val args = call.args.map { transformExpr(it) ?: return null } + return UTestMethodCall(instance, call.method, args) + } + + open fun transform(call: UTestAllocateMemoryCall): UTestCall? = call + + open fun transform(call: UTestAssertThrowsCall): UTestCall? { + val instListTransformed = call.instList.mapNotNull { transformInst(it) } + + if (instListTransformed.isEmpty()) + return null + + return UTestAssertThrowsCall(call.exceptionClass, instListTransformed) + } + + open fun transform(call: UTestAssertEqualsCall): UTestCall? { + val expected = transform(call.expected) ?: return null + val actual = transform(call.actual) ?: return null + + return UTestAssertEqualsCall(expected, actual) + } + + //endregion + + //region Statement transformers + + open fun transform(stmt: UTestStatement): UTestStatement? { + return when (stmt) { + is UTestArraySetStatement -> transform(stmt) + is UTestBinaryConditionStatement -> transform(stmt) + is UTestSetFieldStatement -> transform(stmt) + is UTestSetStaticFieldStatement -> transform(stmt) + } + } + + open fun transform(stmt: UTestArraySetStatement): UTestStatement? { + val array = transformExpr(stmt.arrayInstance) ?: return null + val index = transformExpr(stmt.index) ?: return null + val value = transformExpr(stmt.setValueExpression) ?: return null + return UTestArraySetStatement(array, index, value) + } + + open fun transform(stmt: UTestBinaryConditionStatement): UTestStatement? { + val lhv = transformExpr(stmt.lhv) ?: return null + val rhv = transformExpr(stmt.rhv) ?: return null + val trueBranch = stmt.trueBranch.mapNotNull { transformStmt(it) } + val elseBranch = stmt.elseBranch.mapNotNull { transformStmt(it) } + return UTestBinaryConditionStatement(stmt.conditionType, lhv, rhv, trueBranch, elseBranch) + } + + open fun transform(stmt: UTestSetFieldStatement): UTestStatement? { + val instance = transformExpr(stmt.instance) ?: return null + val value = transformExpr(stmt.value) ?: return null + return UTestSetFieldStatement(instance, stmt.field, value) + } + + open fun transform(stmt: UTestSetStaticFieldStatement): UTestStatement? { + val value = transformExpr(stmt.value) ?: return null + return UTestSetStaticFieldStatement(stmt.field, value) + } + + //endregion +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/unsafeRenderer/JcUnsafeTestBlockRenderer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/unsafeRenderer/JcUnsafeTestBlockRenderer.kt new file mode 100644 index 0000000000..6f87e13c1a --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/unsafeRenderer/JcUnsafeTestBlockRenderer.kt @@ -0,0 +1,139 @@ +package org.usvm.jvm.rendering.unsafeRenderer + +import com.github.javaparser.ast.NodeList +import com.github.javaparser.ast.expr.Expression +import com.github.javaparser.ast.expr.MethodCallExpr +import com.github.javaparser.ast.expr.NameExpr +import com.github.javaparser.ast.type.ReferenceType +import org.jacodb.api.jvm.JcClassType +import org.jacodb.api.jvm.JcClasspath +import org.jacodb.api.jvm.JcField +import org.jacodb.api.jvm.JcMethod +import org.usvm.jvm.rendering.baseRenderer.JcIdentifiersManager +import org.usvm.jvm.rendering.baseRenderer.JcImportManager +import org.usvm.jvm.rendering.testRenderer.JcTestBlockRenderer +import org.usvm.test.api.UTestAllocateMemoryCall +import org.usvm.test.api.UTestExpression +import org.usvm.test.api.UTestStaticMethodCall +import java.util.IdentityHashMap + +open class JcUnsafeTestBlockRenderer protected constructor( + override val methodRenderer: JcUnsafeTestRenderer, + override val importManager: JcImportManager, + identifiersManager: JcIdentifiersManager, + cp: JcClasspath, + shouldDeclareVar: Set, + exprCache: IdentityHashMap, + thrownExceptions: HashSet, + protected open val unsafeUtilsRenderer: JcUnsafeUtilsRenderer +) : JcTestBlockRenderer( + methodRenderer, + importManager, + identifiersManager, + cp, + shouldDeclareVar, + exprCache, + thrownExceptions +) { + + constructor( + methodRenderer: JcUnsafeTestRenderer, + importManager: JcImportManager, + identifiersManager: JcIdentifiersManager, + cp: JcClasspath, + shouldDeclareVar: Set, + unsafeUtilsRenderer: JcUnsafeUtilsRenderer + ) : this( + methodRenderer, + importManager, + identifiersManager, + cp, + shouldDeclareVar, + IdentityHashMap(), + HashSet(), + unsafeUtilsRenderer + ) + + override fun newInnerBlock(): JcUnsafeTestBlockRenderer { + return JcUnsafeTestBlockRenderer( + methodRenderer, + importManager, + JcIdentifiersManager(identifiersManager), + cp, + shouldDeclareVar, + IdentityHashMap(exprCache), + thrownExceptions, + unsafeUtilsRenderer + ) + } + + //region Private Methods + + // TODO: remove special case for getRootCause method + override fun renderStaticMethodCall(expr: UTestStaticMethodCall): Expression { + if (expr.method.name == "getRootCause" && expr.method.enclosingClass.name == "ReflectionUtils") { + return MethodCallExpr( + NameExpr("ReflectionUtils"), + "getRootCause", + NodeList(renderExpression(expr.args.single())) + ) + } + return super.renderStaticMethodCall(expr) + } + + override fun renderPrivateCtorCall( + ctor: JcMethod, + type: JcClassType, + args: List, + inlinesVarargs: Boolean + ): Expression { + return unsafeUtilsRenderer.renderCtorCall(this, ctor, type, args, inlinesVarargs) + } + + override fun renderPrivateMethodCall( + method: JcMethod, + instance: Expression, + args: List, + inlinesVarargs: Boolean + ): Expression { + return unsafeUtilsRenderer.renderInstanceMethodCall(this, method, instance, args, inlinesVarargs) + } + + override fun renderPrivateStaticMethodCall( + method: JcMethod, + args: List, + inlinesVarargs: Boolean + ): Expression { + return unsafeUtilsRenderer.renderStaticMethodCall(this, method, args, inlinesVarargs) + } + + //endregion + + //region Private Fields + + override fun renderGetPrivateStaticField(field: JcField): Expression { + return unsafeUtilsRenderer.renderGetStaticField(this, field) + } + + override fun renderGetPrivateField(instance: Expression, field: JcField): Expression { + return unsafeUtilsRenderer.renderGetInstanceField(this, instance, field) + } + + override fun renderSetPrivateStaticField(field: JcField, value: Expression): Expression { + return unsafeUtilsRenderer.renderSetStaticField(this, field, value) + } + + override fun renderSetPrivateField(instance: Expression, field: JcField, value: Expression): Expression { + return unsafeUtilsRenderer.renderSetInstanceField(this, instance, field, value) + } + + //endregion + + //region Allocation + + override fun renderAllocateMemoryCall(expr: UTestAllocateMemoryCall): Expression { + return unsafeUtilsRenderer.renderAllocateInstance(this, expr.clazz) + } + + //endregion +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/unsafeRenderer/JcUnsafeTestClassRenderer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/unsafeRenderer/JcUnsafeTestClassRenderer.kt new file mode 100644 index 0000000000..97fc388032 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/unsafeRenderer/JcUnsafeTestClassRenderer.kt @@ -0,0 +1,55 @@ +package org.usvm.jvm.rendering.unsafeRenderer + +import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration +import com.github.javaparser.ast.expr.AnnotationExpr +import com.github.javaparser.ast.expr.SimpleName +import org.jacodb.api.jvm.JcClasspath +import org.usvm.jvm.rendering.baseRenderer.JcIdentifiersManager +import org.usvm.jvm.rendering.baseRenderer.JcImportManager +import org.usvm.jvm.rendering.testRenderer.JcTestClassRenderer +import org.usvm.jvm.rendering.testRenderer.JcTestRenderer +import org.usvm.test.api.UTest + +open class JcUnsafeTestClassRenderer : JcTestClassRenderer { + + val unsafeUtilsRenderer: JcUnsafeUtilsRenderer + + constructor( + name: String, + importManager: JcImportManager, + identifiersManager: JcIdentifiersManager, + cp: JcClasspath, + unsafeUtilsRenderer: JcUnsafeUtilsRenderer + ) : super(name, importManager, identifiersManager, cp) { + this.unsafeUtilsRenderer = unsafeUtilsRenderer + } + + constructor( + decl: ClassOrInterfaceDeclaration, + importManager: JcImportManager, + identifiersManager: JcIdentifiersManager, + cp: JcClasspath, + unsafeUtilsRenderer: JcUnsafeUtilsRenderer + ) : super(decl, importManager, identifiersManager, cp) { + this.unsafeUtilsRenderer = unsafeUtilsRenderer + } + + override fun createTestRenderer( + test: UTest, + identifiersManager: JcIdentifiersManager, + name: SimpleName, + annotations: List, + ): JcTestRenderer { + + return JcUnsafeTestRenderer( + test, + this, + importManager, + JcIdentifiersManager(identifiersManager), + cp, + name, + annotations, + unsafeUtilsRenderer + ) + } +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/unsafeRenderer/JcUnsafeTestFileRenderer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/unsafeRenderer/JcUnsafeTestFileRenderer.kt new file mode 100644 index 0000000000..57fe97ed9d --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/unsafeRenderer/JcUnsafeTestFileRenderer.kt @@ -0,0 +1,69 @@ +package org.usvm.jvm.rendering.unsafeRenderer + +import com.github.javaparser.ast.CompilationUnit +import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration +import org.jacodb.api.jvm.JcClasspath +import org.usvm.jvm.rendering.ReflectionUtilsInlineStrategy +import org.usvm.jvm.rendering.baseRenderer.JcImportManager +import org.usvm.jvm.rendering.testRenderer.JcTestFileRenderer + +open class JcUnsafeTestFileRenderer : JcTestFileRenderer { + + protected val reflectionUtilsInlineStrategy: ReflectionUtilsInlineStrategy + + protected open val unsafeUtilsRenderer: JcUnsafeUtilsRenderer by lazy { + JcUnsafeUtilsRenderer(importManager, reflectionUtilsInlineStrategy) + } + + protected constructor( + cu: CompilationUnit, + importManager: JcImportManager, + cp: JcClasspath, + reflectionUtilsInlineStrategy: ReflectionUtilsInlineStrategy + ) : super(cu, importManager, cp) { + this.reflectionUtilsInlineStrategy = reflectionUtilsInlineStrategy + } + + protected constructor( + packageName: String?, + importManager: JcImportManager, + cp: JcClasspath, + reflectionUtilsInlineStrategy: ReflectionUtilsInlineStrategy + ) : super(packageName, importManager, cp) { + this.reflectionUtilsInlineStrategy = reflectionUtilsInlineStrategy + } + + constructor( + cu: CompilationUnit, + cp: JcClasspath, + reflectionUtilsInlineStrategy: ReflectionUtilsInlineStrategy + ) : this( + cu, + JcImportManager(cu), + cp, + reflectionUtilsInlineStrategy + ) + + constructor( + packageName: String?, + cp: JcClasspath, + reflectionUtilsInlineStrategy: ReflectionUtilsInlineStrategy + ) : this( + packageName, + JcImportManager(null), + cp, + reflectionUtilsInlineStrategy + ) + + override fun classRendererFor(declaration: ClassOrInterfaceDeclaration): JcUnsafeTestClassRenderer { + return JcUnsafeTestClassRenderer(declaration, importManager, identifiersManager, cp, unsafeUtilsRenderer) + } + + override fun classRendererFor(name: String): JcUnsafeTestClassRenderer = + JcUnsafeTestClassRenderer(name, importManager, identifiersManager, cp, unsafeUtilsRenderer) + + + override fun renderInternal(): CompilationUnit { + return unsafeUtilsRenderer.addReflectionUtils(importManager, super.renderInternal()) + } +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/unsafeRenderer/JcUnsafeTestInfo.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/unsafeRenderer/JcUnsafeTestInfo.kt new file mode 100644 index 0000000000..82de2e8233 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/unsafeRenderer/JcUnsafeTestInfo.kt @@ -0,0 +1,14 @@ +package org.usvm.jvm.rendering.unsafeRenderer + +import java.nio.file.Path +import org.jacodb.api.jvm.JcMethod +import org.usvm.jvm.rendering.testRenderer.JcTestInfo + +open class JcUnsafeTestInfo( + method: JcMethod, + isExceptional: Boolean, + testFilePath: Path? = null, + testPackageName: String? = null, + testClassName: String? = null, + testName: String? = null +) : JcTestInfo(method, isExceptional, testFilePath, testPackageName, testClassName, testName) \ No newline at end of file diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/unsafeRenderer/JcUnsafeTestRenderer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/unsafeRenderer/JcUnsafeTestRenderer.kt new file mode 100644 index 0000000000..70d796209e --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/unsafeRenderer/JcUnsafeTestRenderer.kt @@ -0,0 +1,38 @@ +package org.usvm.jvm.rendering.unsafeRenderer + +import com.github.javaparser.ast.expr.AnnotationExpr +import com.github.javaparser.ast.expr.SimpleName +import org.jacodb.api.jvm.JcClasspath +import org.usvm.jvm.rendering.baseRenderer.JcIdentifiersManager +import org.usvm.jvm.rendering.baseRenderer.JcImportManager +import org.usvm.jvm.rendering.testRenderer.JcTestRenderer +import org.usvm.test.api.UTest + +open class JcUnsafeTestRenderer( + test: UTest, + classRenderer: JcUnsafeTestClassRenderer, + importManager: JcImportManager, + identifiersManager: JcIdentifiersManager, + cp: JcClasspath, + name: SimpleName, + annotations: List, + unsafeUtilsRenderer: JcUnsafeUtilsRenderer +): JcTestRenderer( + test, + classRenderer, + importManager, + identifiersManager, + cp, + name, + annotations, +) { + + override val body: JcUnsafeTestBlockRenderer = JcUnsafeTestBlockRenderer( + this, + importManager, + JcIdentifiersManager(identifiersManager), + cp, + shouldDeclareVar, + unsafeUtilsRenderer + ) +} diff --git a/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/unsafeRenderer/JcUnsafeUtilsRenderer.kt b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/unsafeRenderer/JcUnsafeUtilsRenderer.kt new file mode 100644 index 0000000000..b7e9ece1c6 --- /dev/null +++ b/usvm-jvm-rendering/src/main/kotlin/org/usvm/jvm/rendering/unsafeRenderer/JcUnsafeUtilsRenderer.kt @@ -0,0 +1,190 @@ +package org.usvm.jvm.rendering.unsafeRenderer + +import com.github.javaparser.ast.CompilationUnit +import com.github.javaparser.ast.NodeList +import com.github.javaparser.ast.expr.Expression +import com.github.javaparser.ast.expr.MethodCallExpr +import com.github.javaparser.ast.expr.NameExpr +import com.github.javaparser.ast.expr.StringLiteralExpr +import com.github.javaparser.ast.type.Type +import org.jacodb.api.jvm.JcClassOrInterface +import org.jacodb.api.jvm.JcClassType +import org.jacodb.api.jvm.JcField +import org.jacodb.api.jvm.JcMethod +import org.jacodb.api.jvm.JcRefType +import org.jacodb.api.jvm.JcType +import org.jacodb.api.jvm.ext.autoboxIfNeeded +import org.jacodb.api.jvm.ext.findType +import org.jacodb.api.jvm.ext.jcdbSignature +import org.jacodb.api.jvm.ext.nullType +import org.jacodb.api.jvm.ext.void +import org.usvm.jvm.rendering.ReflectionUtilsInlineStrategy +import org.usvm.jvm.rendering.baseRenderer.JcImportManager + +open class JcUnsafeUtilsRenderer( + protected val importManager: JcImportManager, + protected val reflectionUtilsInlineStrategy: ReflectionUtilsInlineStrategy +) { + + companion object { + private const val USVM = "org.usvm.jvm.rendering.ReflectionUtils" + private const val USVM_SIMPLE = "ReflectionUtils" + } + + private val utilsName: Expression by lazy { + NameExpr( + if (reflectionUtilsInlineStrategy.inTestClassFile || importManager.add(USVM)) + USVM_SIMPLE + else + USVM + ) + } + + // TODO: rewrite properly after inline strategy redesign + fun addReflectionUtils(importManager: JcImportManager, cu: CompilationUnit): CompilationUnit { + return reflectionUtilsInlineStrategy.addReflectionUtils(importManager, cu) + } + + open fun renderCtorCall( + blockRenderer: JcUnsafeTestBlockRenderer, + ctor: JcMethod, + type: JcClassType, + args: List, + inlinesVarargs: Boolean + ): Expression { + blockRenderer.addThrownException("java.lang.Throwable") + reflectionUtilsInlineStrategy.useUsvmReflectionMethod("callConstructor") + val allArgs = listOf(blockRenderer.renderClassExpression(type), StringLiteralExpr(ctor.jcdbSignature)) + args + return MethodCallExpr( + utilsName, + NodeList(blockRenderer.renderClass(type)), + "callConstructor", + NodeList(allArgs), + ) + } + + open fun renderInstanceMethodCall( + blockRenderer: JcUnsafeTestBlockRenderer, + method: JcMethod, + instance: Expression, + args: List, + inlinesVarargs: Boolean + ): Expression { + blockRenderer.addThrownException("java.lang.Throwable") + reflectionUtilsInlineStrategy.useUsvmReflectionMethod("callMethod") + val allArgs = listOf(instance, StringLiteralExpr(method.jcdbSignature)) + args + return MethodCallExpr( + utilsName, + listTypeArgsFor(method, blockRenderer), + "callMethod", + NodeList(allArgs), + ) + } + + open fun renderStaticMethodCall( + blockRenderer: JcUnsafeTestBlockRenderer, + method: JcMethod, + args: List, + inlinesVarargs: Boolean + ): Expression { + blockRenderer.addThrownException("java.lang.Throwable") + reflectionUtilsInlineStrategy.useUsvmReflectionMethod("callStaticMethod") + val enclosingClass = method.enclosingClass + val allArgs = + listOf(blockRenderer.renderClassExpression(enclosingClass), StringLiteralExpr(method.jcdbSignature)) + args + return MethodCallExpr( + utilsName, + listTypeArgsFor(method, blockRenderer), + "callStaticMethod", + NodeList(allArgs), + ) + } + + open fun renderGetInstanceField( + blockRenderer: JcUnsafeTestBlockRenderer, + instance: Expression, + field: JcField + ): Expression { + reflectionUtilsInlineStrategy.useUsvmReflectionMethod("getStaticFieldValue") + return MethodCallExpr( + utilsName, + listTypeArgsFor(field, blockRenderer), + "getStaticFieldValue", + NodeList(blockRenderer.renderClassExpression(field.enclosingClass), StringLiteralExpr(field.name)), + ) + } + + open fun renderGetStaticField(blockRenderer: JcUnsafeTestBlockRenderer, field: JcField): Expression { + reflectionUtilsInlineStrategy.useUsvmReflectionMethod("getStaticFieldValue") + return MethodCallExpr( + utilsName, + listTypeArgsFor(field, blockRenderer), + "getStaticFieldValue", + NodeList(blockRenderer.renderClassExpression(field.enclosingClass), StringLiteralExpr(field.name)), + ) + } + + open fun renderSetInstanceField( + blockRenderer: JcUnsafeTestBlockRenderer, + instance: Expression, + field: JcField, + value: Expression + ): Expression { + reflectionUtilsInlineStrategy.useUsvmReflectionMethod("setFieldValue") + return MethodCallExpr( + utilsName, + "setFieldValue", + NodeList(instance, StringLiteralExpr(field.name), value), + ) + } + + open fun renderSetStaticField( + blockRenderer: JcUnsafeTestBlockRenderer, + field: JcField, + value: Expression + ): Expression { + reflectionUtilsInlineStrategy.useUsvmReflectionMethod("setStaticFieldValue") + return MethodCallExpr( + utilsName, + "setStaticFieldValue", + NodeList(blockRenderer.renderClassExpression(field.enclosingClass), StringLiteralExpr(field.name), value), + ) + } + + open fun renderAllocateInstance(blockRenderer: JcUnsafeTestBlockRenderer, clazz: JcClassOrInterface): Expression { + blockRenderer.addThrownException("java.lang.InstantiationException") + reflectionUtilsInlineStrategy.useUsvmReflectionMethod("allocateInstance") + return MethodCallExpr( + utilsName, + NodeList(blockRenderer.renderClass(clazz)), + "allocateInstance", + NodeList(blockRenderer.renderClassExpression(clazz)), + ) + } + + private fun listTypeArgsFor(type: JcType, blockRenderer: JcUnsafeTestBlockRenderer): NodeList? { + val cp = type.classpath + return when (type) { + is JcRefType -> NodeList(blockRenderer.renderType(type)) + cp.void, cp.nullType -> null + else -> NodeList(blockRenderer.renderType(type.autoboxIfNeeded())) + } + } + + protected fun listTypeArgsFor(method: JcMethod, blockRenderer: JcUnsafeTestBlockRenderer,): NodeList? { + val cp = method.enclosingClass.classpath + val resultTypeName = method.returnType.typeName + val resultType = cp.findType(resultTypeName) + return listTypeArgsFor(resultType, blockRenderer) + } + + protected fun listTypeArgsFor(field: JcField, blockRenderer: JcUnsafeTestBlockRenderer): NodeList? { + return listTypeArgsFor(fieldType(field), blockRenderer) + } + + protected fun fieldType(field: JcField): JcType { + val cp = field.enclosingClass.classpath + val fieldTypeName = field.type.typeName + return cp.findType(fieldTypeName) + } +} diff --git a/usvm-jvm/build.gradle.kts b/usvm-jvm/build.gradle.kts index 59cbffcdf3..7078b853e9 100644 --- a/usvm-jvm/build.gradle.kts +++ b/usvm-jvm/build.gradle.kts @@ -24,7 +24,7 @@ val `sample-approximations` by sourceSets.creating { val approximations by configurations.creating val approximationsRepo = "com.github.UnitTestBot.java-stdlib-approximations" -val approximationsVersion = "88c6be3469" +val approximationsVersion = "607384f1a7" dependencies { implementation(project(":usvm-core")) diff --git a/usvm-jvm/src/main/kotlin/org/usvm/api/util/JcTestStateResolver.kt b/usvm-jvm/src/main/kotlin/org/usvm/api/util/JcTestStateResolver.kt index 6d2c5299c7..2be7dd1749 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/api/util/JcTestStateResolver.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/api/util/JcTestStateResolver.kt @@ -229,7 +229,7 @@ abstract class JcTestStateResolver( val evaluatedType = typeSelector.firstOrNull(typeStream, type.jcClass) ?: return decoderApi.createNullConst(type) - // We check for the type stream emptiness firsly and only then for the resolved cache, + // We check for the type stream emptiness first and only then for the resolved cache, // because even if the object is already resolved, it could be incompatible with the [type], if it // is an element of an array of the wrong type. diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/JcComponents.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/JcComponents.kt index 42344b579e..61e424d133 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/JcComponents.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/JcComponents.kt @@ -15,7 +15,7 @@ import org.usvm.solver.USoftConstraintsProvider import org.usvm.solver.USolverBase import org.usvm.solver.UTypeSolver -class JcComponents( +open class JcComponents( private val typeSystem: JcTypeSystem, // TODO specific JcMachineOptions should be here private val options: UMachineOptions, diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/JcContext.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/JcContext.kt index c41566c540..25c78dea4c 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/JcContext.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/JcContext.kt @@ -30,7 +30,7 @@ import org.usvm.machine.interpreter.statics.JcStaticFieldReading import org.usvm.machine.interpreter.statics.JcStaticFieldRegionId import org.usvm.util.extractJcRefType -internal typealias USizeSort = UBv32Sort +typealias USizeSort = UBv32Sort class JcContext( val cp: JcClasspath, diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/JcMachine.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/JcMachine.kt index 6138db8534..7ac12ae01f 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/JcMachine.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/JcMachine.kt @@ -40,24 +40,32 @@ import org.usvm.util.originalInst val logger = object : KLogging() {}.logger -class JcMachine( +open class JcMachine( cp: JcClasspath, private val options: UMachineOptions, - private val jcMachineOptions: JcMachineOptions = JcMachineOptions(), - private val interpreterObserver: JcInterpreterObserver? = null, + protected val jcMachineOptions: JcMachineOptions = JcMachineOptions(), + protected val interpreterObserver: JcInterpreterObserver? = null, ) : UMachine() { - private val applicationGraph = JcApplicationGraph(cp) + protected val applicationGraph = JcApplicationGraph(cp) - private val typeSystem = JcTypeSystem(cp, options.typeOperationsTimeout) - private val components = JcComponents(typeSystem, options) - private val ctx = JcContext(cp, components) - - private val interpreter = JcInterpreter(ctx, applicationGraph, jcMachineOptions, interpreterObserver) + protected val typeSystem = JcTypeSystem(cp, options.typeOperationsTimeout) + protected open val components = JcComponents(typeSystem, options) + protected val ctx by lazy { createContext(cp, components) } + protected open fun createInterpreter(): JcInterpreter { + return JcInterpreter(ctx, applicationGraph, jcMachineOptions, interpreterObserver) + } private val cfgStatistics = CfgStatisticsImpl(applicationGraph) + protected open fun createContext( + cp: JcClasspath, + components: JcComponents, + ): JcContext { + return JcContext(cp, components) + } fun analyze(methods: List, targets: List = emptyList()): List { logger.debug("{}.analyze({})", this, methods) + val interpreter = createInterpreter() val initialStates = mutableMapOf() methods.forEach { initialStates[it] = interpreter.getInitialState(it, targets) diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcInterpreter.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcInterpreter.kt index 0265d80424..893810c8f8 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcInterpreter.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcInterpreter.kt @@ -82,9 +82,9 @@ typealias JcStepScope = StepScope /** * A JacoDB interpreter. */ -class JcInterpreter( - private val ctx: JcContext, - private val applicationGraph: JcApplicationGraph, +open class JcInterpreter( + protected val ctx: JcContext, + protected val applicationGraph: JcApplicationGraph, private val options: JcMachineOptions, private val observer: JcInterpreterObserver? = null, var forkBlackList: UForkBlackList = UForkBlackList.createDefault(), @@ -227,7 +227,7 @@ class JcInterpreter( private val typeSelector = JcFixedInheritorsNumberTypeSelector() - private fun callMethod( + protected open fun callMethod( scope: JcStepScope, stmt: JcMethodCallBaseInst, exprResolver: JcExprResolver @@ -235,12 +235,17 @@ class JcInterpreter( val simpleValueResolver = exprResolver.simpleValueResolver val method = stmt.method when (stmt) { - is JcMethodEntrypointInst -> { - observer?.onEntryPoint(simpleValueResolver, stmt, scope) + is JcConcreteMethodCallInst -> { + observer?.onMethodCallWithResolvedArguments(simpleValueResolver, stmt, scope) + if (approximateMethod(scope, stmt)) { + return + } - // Run static initializer for all enum arguments of the entrypoint - for ((type, ref) in stmt.entrypointArguments) { - exprResolver.ensureExprCorrectness(ref, type) ?: return + val entryPoint = applicationGraph.entryPoints(method).singleOrNull() + + if (method.isNative || entryPoint == null) { + mockMethod(scope, stmt, applicationGraph) + return } handleInnerClassMethodCall( @@ -249,33 +254,24 @@ class JcInterpreter( method, outerClassInstanceConstructorArgument = { // Implicit first argument is `this`, an instance of the outer class would be second - stmt.entrypointArguments[1].second + stmt.arguments[1].asExpr(ctx.addressSort) }, thisInstanceMethodArgument = { // For methods, we need to extract `this` - stmt.entrypointArguments.first().second + stmt.arguments.first().asExpr(ctx.addressSort) }, ) - val entryPoint = applicationGraph.entryPoints(method).singleOrNull() - ?: error("Entrypoint method $method has no entry points") - scope.doWithState { - newStmt(entryPoint) + addNewMethodCall(stmt, entryPoint) } } + is JcMethodEntrypointInst -> { + observer?.onEntryPoint(simpleValueResolver, stmt, scope) - is JcConcreteMethodCallInst -> { - observer?.onMethodCallWithResolvedArguments(simpleValueResolver, stmt, scope) - if (approximateMethod(scope, stmt)) { - return - } - - val entryPoint = applicationGraph.entryPoints(method).singleOrNull() - - if (method.isNative || entryPoint == null) { - mockMethod(scope, stmt, applicationGraph) - return + // Run static initializer for all enum arguments of the entrypoint + for ((type, ref) in stmt.entrypointArguments) { + exprResolver.ensureExprCorrectness(ref, type) ?: return } handleInnerClassMethodCall( @@ -284,16 +280,19 @@ class JcInterpreter( method, outerClassInstanceConstructorArgument = { // Implicit first argument is `this`, an instance of the outer class would be second - stmt.arguments[1].asExpr(ctx.addressSort) + stmt.entrypointArguments[1].second }, thisInstanceMethodArgument = { // For methods, we need to extract `this` - stmt.arguments.first().asExpr(ctx.addressSort) + stmt.entrypointArguments.first().second }, ) + val entryPoint = applicationGraph.entryPoints(method).singleOrNull() + ?: error("Entrypoint method $method has no entry points") + scope.doWithState { - addNewMethodCall(stmt, entryPoint) + newStmt(entryPoint) } } diff --git a/usvm-jvm/src/test/kotlin/org/usvm/samples/JavaMethodTestRunner.kt b/usvm-jvm/src/test/kotlin/org/usvm/samples/JavaMethodTestRunner.kt index 22fc918c39..1a652ff3e1 100644 --- a/usvm-jvm/src/test/kotlin/org/usvm/samples/JavaMethodTestRunner.kt +++ b/usvm-jvm/src/test/kotlin/org/usvm/samples/JavaMethodTestRunner.kt @@ -806,7 +806,7 @@ open class JavaMethodTestRunner : TestRunner, KClass<*>?, J pathSelectionStrategies = listOf(PathSelectionStrategy.FORK_DEPTH), coverageZone = CoverageZone.TRANSITIVE, exceptionsPropagation = true, - timeout = 60_000.milliseconds, +// timeout = 60_000.milliseconds, stepsFromLastCovered = 3500L, solverTimeout = Duration.INFINITE, // we do not need the timeout for a solver in tests typeOperationsTimeout = Duration.INFINITE, // we do not need the timeout for type operations in tests diff --git a/usvm-jvm/usvm-jvm-test-api/src/main/kotlin/org/usvm/test/api/Api.kt b/usvm-jvm/usvm-jvm-test-api/src/main/kotlin/org/usvm/test/api/Api.kt index b4580b29fa..70f44cea30 100644 --- a/usvm-jvm/usvm-jvm-test-api/src/main/kotlin/org/usvm/test/api/Api.kt +++ b/usvm-jvm/usvm-jvm-test-api/src/main/kotlin/org/usvm/test/api/Api.kt @@ -14,6 +14,14 @@ sealed interface UTestExpression: UTestInst { val type: JcType? } +class UTestMockInst( + val instance: UTestExpression, + val method: JcMethod, + val args: List, +): UTestExpression { + override val type: JcType? = method.enclosingClass.classpath.findTypeOrNull(method.returnType) +} + sealed class UTestMock( override val type: JcType, open val fields: Map, @@ -28,6 +36,14 @@ class UTestMockObject( override val methods: Map> ) : UTestMock(type, fields, methods) +/* + * TODO: remove when there will be proper doAnswer-like support + * in UTestMockObject + */ +class UTestInstList(val instList: List): UTestExpression { + override val type: JcType? = null +} + /** * Mock for all objects of type */ @@ -83,6 +99,32 @@ class UTestAllocateMemoryCall( override val type: JcType = clazz.toType() } +class UTestAssertThrowsCall( + val exceptionClass: JcClassOrInterface, + val instList: List +) : UTestCall { + override val instance: UTestExpression? = null + override val method: JcMethod? = null + override val args: List = emptyList() + override val type: JcType = exceptionClass.toType() +} + +class UTestAssertEqualsCall( + val expected: UTestExpression, + val actual: UTestExpression +) : UTestCall { + init { + check(expected.type != null && actual.type != null) { + "operand types expected" + } + } + + override val instance: UTestExpression? = null + override val method: JcMethod? = null + override val args: List = emptyList() + override val type: JcType = expected.type!!.classpath.boolean +} + sealed interface UTestStatement : UTestInst class UTestSetFieldStatement( @@ -247,4 +289,4 @@ enum class ArithmeticOperationType { //Bitwise OR, AND, XOR -} +} \ No newline at end of file diff --git a/usvm-jvm/usvm-jvm-test-api/src/main/kotlin/org/usvm/test/api/JcTestExecutorDecoderApi.kt b/usvm-jvm/usvm-jvm-test-api/src/main/kotlin/org/usvm/test/api/JcTestExecutorDecoderApi.kt index ae4c85cbbc..d46a0281c3 100644 --- a/usvm-jvm/usvm-jvm-test-api/src/main/kotlin/org/usvm/test/api/JcTestExecutorDecoderApi.kt +++ b/usvm-jvm/usvm-jvm-test-api/src/main/kotlin/org/usvm/test/api/JcTestExecutorDecoderApi.kt @@ -98,4 +98,4 @@ open class JcTestExecutorDecoderApi( } internal val JcClasspath.stringType: JcType - get() = findClassOrNull("java.lang.String")!!.toType() + get() = findClassOrNull("java.lang.String")!!.toType() \ No newline at end of file diff --git a/usvm-util/src/main/kotlin/org/usvm/UMachineOptions.kt b/usvm-util/src/main/kotlin/org/usvm/UMachineOptions.kt index b2989d2685..7e13d38a8f 100644 --- a/usvm-util/src/main/kotlin/org/usvm/UMachineOptions.kt +++ b/usvm-util/src/main/kotlin/org/usvm/UMachineOptions.kt @@ -200,7 +200,7 @@ data class UMachineOptions( /** * Timeout to stop execution on. Use [Duration.INFINITE] for no timeout. */ - val timeout: Duration = 20_000.milliseconds, + val timeout: Duration = Duration.INFINITE, /** * A number of steps from the last terminated state. */