diff --git a/usvm-core/src/main/kotlin/org/usvm/memory/HeapRefSplitting.kt b/usvm-core/src/main/kotlin/org/usvm/memory/HeapRefSplitting.kt index c73ea2ea01..7ba39b77aa 100644 --- a/usvm-core/src/main/kotlin/org/usvm/memory/HeapRefSplitting.kt +++ b/usvm-core/src/main/kotlin/org/usvm/memory/HeapRefSplitting.kt @@ -163,6 +163,25 @@ inline fun foldHeapRefWithStaticAsSymbolic( blockOnSymbolic = blockOnSymbolic ) +inline fun foldHeapRefWithStaticAsConcrete( + ref: UHeapRef, + initial: R, + initialGuard: UBoolExpr, + ignoreNullRefs: Boolean = true, + collapseHeapRefs: Boolean = true, + blockOnConcrete: (R, GuardedExpr) -> R, + blockOnSymbolic: (R, GuardedExpr) -> R, +): R = foldHeapRef( + ref, + initial, + initialGuard, + ignoreNullRefs, + collapseHeapRefs, + staticIsConcrete = true, + blockOnConcrete = blockOnConcrete, + blockOnSymbolic = blockOnSymbolic +) + inline fun foldHeapRef2( ref0: UHeapRef, ref1: UHeapRef, diff --git a/usvm-jvm/build.gradle.kts b/usvm-jvm/build.gradle.kts index 59cbffcdf3..9beb55db37 100644 --- a/usvm-jvm/build.gradle.kts +++ b/usvm-jvm/build.gradle.kts @@ -24,7 +24,7 @@ val `sample-approximations` by sourceSets.creating { val approximations by configurations.creating val approximationsRepo = "com.github.UnitTestBot.java-stdlib-approximations" -val approximationsVersion = "88c6be3469" +val approximationsVersion = "aa9e358446" dependencies { implementation(project(":usvm-core")) diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/JcApproximations.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/JcApproximations.kt index 0d01d3c196..1db286af01 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/JcApproximations.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/JcApproximations.kt @@ -76,11 +76,14 @@ import org.usvm.api.mapTypeStreamNotNull import org.usvm.api.memcpy import org.usvm.api.objectTypeEquals import org.usvm.api.objectTypeSubtype +import org.usvm.api.readArrayIndex +import org.usvm.api.readArrayLength import org.usvm.api.readField import org.usvm.api.writeField import org.usvm.collection.array.UArrayIndexLValue import org.usvm.collection.array.length.UArrayLengthLValue import org.usvm.collection.field.UFieldLValue +import org.usvm.getIntValue import org.usvm.jvm.util.allInstanceFields import org.usvm.jvm.util.javaName import org.usvm.machine.interpreter.JcExprResolver @@ -88,7 +91,10 @@ import org.usvm.machine.interpreter.JcStepScope import org.usvm.machine.mocks.mockMethod import org.usvm.machine.state.JcState import org.usvm.machine.state.newStmt +import org.usvm.machine.state.skipMethodInvocationAndBoxIfNeeded import org.usvm.machine.state.skipMethodInvocationWithValue +import org.usvm.memory.foldHeapRefWithStaticAsConcrete +import org.usvm.mkSizeExpr import org.usvm.sizeSort import org.usvm.types.first import org.usvm.types.singleOrNull @@ -571,6 +577,84 @@ class JcMethodApproximationResolver( return false } + private fun JcState.arrayContentEquals( + firstArray: UHeapRef, + secondArray: UHeapRef, + firstLength: UExpr, + secondLength: UExpr, + arrayType: JcArrayType, + ): UBoolExpr? = with(ctx) { + val arrayDesciptor = arrayDescriptorOf(arrayType) + val elementType = arrayType.elementType + val elementSort = typeToSort(elementType) + + val concreteLength = + getIntValue(firstLength) + ?: getIntValue(secondLength) + ?: return@with null + + val arrayEquals = List(concreteLength) { + val idx = mkSizeExpr(it) + val first = memory.readArrayIndex(firstArray, idx, arrayDesciptor, elementSort) + val second = + memory.readArrayIndex(secondArray, idx, arrayDesciptor, elementSort) + mkEq(first, second) + } + + return@with mkAnd(arrayEquals) + } + + private fun JcState.arrayEquals(methodCall: JcMethodCall, firstArray: UHeapRef, secondArray: UHeapRef) = with(ctx) { + val possibleElementTypes = primitiveTypes + cp.objectType + val possibleArrayTypes = possibleElementTypes.map { cp.arrayTypeOf(it) } + + val branches = mutableListOf Unit>>() + var typeDiffersConstraint: UBoolExpr = trueExpr + + val arrayRefsEqual = mkEq(firstArray, secondArray) + val oneArrayIsNull = mkOr(mkEq(firstArray, nullRef), mkEq(secondArray, nullRef)) + branches += arrayRefsEqual to { state -> + state.skipMethodInvocationAndBoxIfNeeded(methodCall, cp.boolean, trueExpr) + } + branches += mkAnd(mkNot(arrayRefsEqual), oneArrayIsNull) to { state -> + state.skipMethodInvocationAndBoxIfNeeded(methodCall, cp.boolean, falseExpr) + } + val needToCheckContent = mkAnd(mkNot(arrayRefsEqual), mkNot(oneArrayIsNull)) + for (arrayType in possibleArrayTypes) { + val typeConstraint = scope.calcOnState { + mkAnd( + memory.types.evalIsSubtype(firstArray, arrayType), + memory.types.evalIsSubtype(secondArray, arrayType) + ) + } + typeDiffersConstraint = mkAnd(typeDiffersConstraint, mkNot(typeConstraint)) + val arrayDesciptor = arrayDescriptorOf(arrayType) + val firstLength = memory.readArrayLength(firstArray, arrayDesciptor, sizeSort) + val secondLength = memory.readArrayLength(secondArray, arrayDesciptor, sizeSort) + val lengthsEqual = mkEq(firstLength, secondLength) + + branches += mkAnd(needToCheckContent, typeConstraint, mkNot(lengthsEqual)) to { state -> + state.skipMethodInvocationAndBoxIfNeeded(methodCall, cp.boolean, falseExpr) + } + + branches += mkAnd(needToCheckContent, typeConstraint, lengthsEqual) to { state -> + val checkResult = state.arrayContentEquals(firstArray, secondArray, firstLength, secondLength, arrayType) + if (checkResult == null) { + // Unable to check + state.skipMethodInvocationWithValue(methodCall, nullRef) + } else { + state.skipMethodInvocationAndBoxIfNeeded(methodCall, cp.boolean, checkResult) + } + } + } + + branches += typeDiffersConstraint to { state -> + state.skipMethodInvocationAndBoxIfNeeded(methodCall, cp.boolean, falseExpr) + } + + scope.forkMulti(branches) + } + private sealed interface StringConcatElement private data class StringConcatStrElement(val str: String) : StringConcatElement private data class StringConcatRefElement(val ref: UHeapRef) : StringConcatElement @@ -956,6 +1040,19 @@ class JcMethodApproximationResolver( val arg = it.arguments.single().asExpr(ctx.booleanSort) scope.assert(arg)?.let { ctx.voidValue } } + dispatchUsvmApiMethod(Engine::assumeSymbolic) { + val instance = it.arguments[0].asExpr(ctx.addressSort) + val condition = it.arguments[1].asExpr(ctx.booleanSort) + foldHeapRefWithStaticAsConcrete( + ref = instance, + initial = Unit, + initialGuard = ctx.trueExpr, + ignoreNullRefs = true, + collapseHeapRefs = true, + blockOnConcrete = { _, _ -> Unit }, + blockOnSymbolic = { acc, ref -> scope.assert(ctx.mkImplies(ref.guard, condition)) ?: acc } + )?.let { ctx.voidValue } + } dispatchUsvmApiMethod(Engine::makeSymbolicBoolean) { scope.calcOnState { makeSymbolicPrimitive(ctx.booleanSort) } } @@ -1098,6 +1195,12 @@ class JcMethodApproximationResolver( makeSymbolicArray(ctx.cp.objectType, sizeExpr) } } + dispatchUsvmApiMethod(Engine::arrayEquals) { + val first = it.arguments[0].asExpr(ctx.addressSort) + val second = it.arguments[1].asExpr(ctx.addressSort) + scope.doWithState { arrayEquals(it, first, second) } + null + } dispatchMkList(Engine::makeSymbolicList) { scope.calcOnState { mkSymbolicList(symbolicListType) } } diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/state/JcStateUtils.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/state/JcStateUtils.kt index ad0f6e77ce..cc6dad55fd 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/state/JcStateUtils.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/state/JcStateUtils.kt @@ -1,11 +1,15 @@ package org.usvm.machine.state +import org.jacodb.api.jvm.JcClassType import org.jacodb.api.jvm.JcMethod +import org.jacodb.api.jvm.JcPrimitiveType import org.jacodb.api.jvm.JcType import org.jacodb.api.jvm.cfg.JcArgument import org.jacodb.api.jvm.cfg.JcDynamicCallExpr import org.jacodb.api.jvm.cfg.JcInst +import org.jacodb.api.jvm.ext.autoboxIfNeeded import org.jacodb.api.jvm.ext.cfg.locals +import org.jacodb.api.jvm.ext.findType import org.usvm.UExpr import org.usvm.UHeapRef import org.usvm.USort @@ -88,3 +92,29 @@ fun JcState.skipMethodInvocationWithValue(methodCall: JcMethodCall, value: UExpr methodResult = JcMethodResult.Success(methodCall.method, value) newStmt(methodCall.returnSite) } + +fun JcState.skipMethodInvocationAndBoxIfNeeded(methodCall: JcMethodCall, valueType: JcType, value: UExpr) { + val typeSystem = ctx.typeSystem() + val methodReturnType = ctx.cp.findType(methodCall.method.returnType.typeName) + when { + valueType is JcPrimitiveType && methodReturnType is JcClassType -> { + val boxedType = valueType.autoboxIfNeeded() as JcClassType + check(typeSystem.isSupertype(methodReturnType, boxedType)) { + "skipMethodInvocationAndBoxIfNeeded: Incorrect method return type" + } + val boxMethod = boxedType.declaredMethods.first { + it.name == "valueOf" && it.isStatic && it.parameters.singleOrNull()?.type == valueType + } + methodResult = JcMethodResult.NoCall + newStmt(JcConcreteMethodCallInst(methodCall.location, boxMethod.method, listOf(value), methodCall.returnSite)) + } + + else -> { + // TODO: implement unboxing if needed + check(typeSystem.isSupertype(methodReturnType, valueType)) { + "skipMethodInvocationAndBoxIfNeeded: Incorrect method return type" + } + skipMethodInvocationWithValue(methodCall, value) + } + } +} diff --git a/usvm-jvm/usvm-jvm-api/src/main/java/org/usvm/api/Engine.java b/usvm-jvm/usvm-jvm-api/src/main/java/org/usvm/api/Engine.java index 706ab361da..d03fe027c4 100644 --- a/usvm-jvm/usvm-jvm-api/src/main/java/org/usvm/api/Engine.java +++ b/usvm-jvm/usvm-jvm-api/src/main/java/org/usvm/api/Engine.java @@ -5,6 +5,7 @@ import org.usvm.api.internal.SymbolicMapImpl; import java.lang.reflect.Array; +import java.util.Arrays; public class Engine { @@ -12,6 +13,10 @@ public static void assume(boolean expr) { assert expr; } + public static void assumeSymbolic(Object instance, boolean expr) { + assert expr; + } + @SuppressWarnings("unused") public static T makeSymbolic(Class clazz) { return null; @@ -103,6 +108,28 @@ public static double[] makeSymbolicDoubleArray(int size) { return new double[size]; } + public static Boolean arrayEquals(Object first, Object second) { + if (first instanceof byte[] && second instanceof byte[]) + return Arrays.equals((byte[])first, (byte[])second); + if (first instanceof char[] && second instanceof char[]) + return Arrays.equals((char[])first, (char[])second); + if (first instanceof int[] && second instanceof int[]) + return Arrays.equals((int[])first, (int[])second); + if (first instanceof long[] && second instanceof long[]) + return Arrays.equals((long[])first, (long[])second); + if (first instanceof boolean[] && second instanceof boolean[]) + return Arrays.equals((boolean[])first, (boolean[])second); + if (first instanceof float[] && second instanceof float[]) + return Arrays.equals((float[])first, (float[])second); + if (first instanceof double[] && second instanceof double[]) + return Arrays.equals((double[])first, (double[])second); + if (first instanceof short[] && second instanceof short[]) + return Arrays.equals((short[])first, (short[])second); + if (first instanceof Object[] && second instanceof Object[]) + return Arrays.equals((Object[])first, (Object[])second); + return false; + } + public static SymbolicList makeSymbolicList() { return new SymbolicListImpl<>(); }