diff --git a/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/Accessors.kt b/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/Accessors.kt index 9a8822781..39dadfff4 100644 --- a/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/Accessors.kt +++ b/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/Accessors.kt @@ -61,6 +61,8 @@ sealed interface AccessPathBase { } } +sealed interface AbstractionAlwaysUnrollNextAccessor + sealed class Accessor : Comparable { abstract fun toSuffix(): String protected abstract val accessorClassId: Int @@ -71,15 +73,16 @@ sealed class Accessor : Comparable { } return when (this) { - ElementAccessor, FinalAccessor, AnyAccessor, ValueAccessor -> 0 // Definitely equal + ElementAccessor, FinalAccessor, AnyAccessor, ValueAccessor, TypeInfoGroupAccessor -> 0 // Definitely equal is FieldAccessor -> this.compareToFieldAccessor(other as FieldAccessor) is TaintMarkAccessor -> this.compareToTaintMarkAccessor(other as TaintMarkAccessor) is ClassStaticAccessor -> this.compareToClassStaticAccessor(other as ClassStaticAccessor) + is TypeInfoAccessor -> this.compareToTypeInfoAccessor(other as TypeInfoAccessor) } } } -data class TaintMarkAccessor(val mark: String): Accessor() { +data class TaintMarkAccessor(val mark: String): Accessor(), AbstractionAlwaysUnrollNextAccessor { override fun toSuffix(): String = "![$mark]" override fun toString(): String = "![$mark]" @@ -126,7 +129,7 @@ data object ElementAccessor : Accessor() { override val accessorClassId: Int = 0 } -data object FinalAccessor : Accessor() { +data object FinalAccessor : Accessor(), AbstractionAlwaysUnrollNextAccessor { override fun toSuffix(): String = ".\$" override fun toString(): String = "\$" @@ -153,7 +156,7 @@ data class ClassStaticAccessor(val typeName: String) : Accessor() { } } -object ValueAccessor : Accessor() { +object ValueAccessor : Accessor(), AbstractionAlwaysUnrollNextAccessor { override fun toString(): String = "[value]" override fun toSuffix(): String = ".[value]" @@ -165,3 +168,21 @@ inline fun ApManager.tryAnyAccessorOrNull(accessor: Accessor, body: () if (!anyAccessorUnrollStrategy.unrollAccessor(accessor)) return null return body() } + +object TypeInfoGroupAccessor : Accessor(), AbstractionAlwaysUnrollNextAccessor { + override fun toString(): String = "[type]" + override fun toSuffix(): String = ".[type]" + + override val accessorClassId: Int = 7 +} + +data class TypeInfoAccessor(val typeName: String) : Accessor(), AbstractionAlwaysUnrollNextAccessor { + override fun toString(): String = "{$typeName}" + override fun toSuffix(): String = ".{$typeName}" + + override val accessorClassId: Int = 8 + + fun compareToTypeInfoAccessor(other: TypeInfoAccessor): Int { + return typeName.compareTo(other.typeName) + } +} diff --git a/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/TaintAnalysisUnitRunnerManager.kt b/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/TaintAnalysisUnitRunnerManager.kt index bc8366066..9a05af598 100644 --- a/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/TaintAnalysisUnitRunnerManager.kt +++ b/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/TaintAnalysisUnitRunnerManager.kt @@ -25,7 +25,6 @@ import org.opentaint.dataflow.ap.ifds.taint.CommonTaintAnalysisContext import org.opentaint.dataflow.ap.ifds.taint.TaintAnalysisUnitStorage import org.opentaint.dataflow.ap.ifds.taint.TaintSinkTracker import org.opentaint.dataflow.ap.ifds.taint.TaintSinkTracker.TaintVulnerability -import org.opentaint.dataflow.ap.ifds.taint.TaintSinkTracker.TaintVulnerabilityWithEndFactRequirement import org.opentaint.dataflow.ap.ifds.trace.ParallelProcessingContext import org.opentaint.dataflow.ap.ifds.trace.TraceResolver import org.opentaint.dataflow.ap.ifds.trace.VulnerabilityChecker @@ -219,11 +218,30 @@ class TaintAnalysisUnitRunnerManager( cancellationTimeout: Duration, ): List { val traceResolver = TraceResolver(entryPoints, this, resolverParams, cancellation) - val traceResolutionContext = object : ParallelProcessingContext( - analyzerDispatcher, name = "Trace resolution", vulnerabilities + + val states = vulnerabilities.map { TraceResolver.State.Initial(it) } + val traceResolutionContext = object : ParallelProcessingContext( + analyzerDispatcher, name = "Trace resolution", states ) { - override fun createUnprocessed(item: TaintVulnerability) = - VulnerabilityWithTrace(item, trace = null) + override fun processItem(item: TraceResolver.State): ProcessingResult { + val res = traceResolver.resolveTrace(item) + return when (res) { + is TraceResolver.TraceResolutionResult.InProgress -> { + ProcessingResult.Running(res.state) + } + + is TraceResolver.TraceResolutionResult.NoTrace -> { + ProcessingResult.Done(VulnerabilityWithTrace(res.vulnerability, trace = null)) + } + + is TraceResolver.TraceResolutionResult.Resolved -> { + ProcessingResult.Done(VulnerabilityWithTrace(res.vulnerability, res.trace)) + } + } + } + + override fun createUnprocessed(item: TraceResolver.State): VulnerabilityWithTrace = + VulnerabilityWithTrace(item.vulnerability, trace = null) private var prevStats: MethodStats? = null @@ -254,10 +272,7 @@ class TaintAnalysisUnitRunnerManager( return traceResolutionContext.processAll( progressScope, timeout, cancellationTimeout, cancellation - ) { vulnerability -> - val trace = traceResolver.resolveTrace(vulnerability) - VulnerabilityWithTrace(vulnerability, trace) - } + ) } fun confirmVulnerabilities( @@ -269,13 +284,13 @@ class TaintAnalysisUnitRunnerManager( cancellation.activate() val confirmed = mutableListOf() - val unconfirmedVulnerabilities = mutableListOf() + val unconfirmedVulnerabilities = mutableListOf() for (vulnerability in vulnerabilities) { - when (vulnerability) { - is TaintSinkTracker.TaintVulnerabilityUnconditional -> confirmed.add(vulnerability) - is TaintSinkTracker.TaintVulnerabilityWithFact -> confirmed.add(vulnerability) - is TaintVulnerabilityWithEndFactRequirement -> unconfirmedVulnerabilities.add(vulnerability) + if (VulnerabilityChecker.needVerification(vulnerability)) { + unconfirmedVulnerabilities.add(vulnerability) + } else { + confirmed.add(vulnerability) } } @@ -311,25 +326,27 @@ class TaintAnalysisUnitRunnerManager( private fun confirmVulnerabilitiesWithCancellation( entryPoints: Set, - vulnerabilities: List, + vulnerabilities: List, timeout: Duration, cancellationTimeout: Duration, ): List { val checker = VulnerabilityChecker(entryPoints, this, cancellation) val vulnConfirmationContext = - object : ParallelProcessingContext( + object : ParallelProcessingContext( analyzerDispatcher, name = "Vulnerability confirmation", vulnerabilities ) { - override fun createUnprocessed(item: TaintVulnerabilityWithEndFactRequirement): VerifiedVulnerability = - VerifiedVulnerability( - item.vulnerability, - status = VulnerabilityVerificationStatus.UNKNOWN - ) + override fun processItem(item: TaintVulnerability): ProcessingResult { + val result = checker.verifyVulnerability(item) + return ProcessingResult.Done(result) + } + + override fun createUnprocessed(item: TaintVulnerability): VerifiedVulnerability = + VerifiedVulnerability(item, status = VulnerabilityVerificationStatus.UNKNOWN) } return vulnConfirmationContext.processAll( progressScope, timeout, cancellationTimeout, cancellation - ) { checker.verifyVulnerability(it) } + ) } fun methodCallers(method: CommonMethod): Set = diff --git a/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/access/automata/AutomataFactFilter.kt b/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/access/automata/AutomataFactFilter.kt index 8924d2176..9b38d0587 100644 --- a/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/access/automata/AutomataFactFilter.kt +++ b/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/access/automata/AutomataFactFilter.kt @@ -9,6 +9,8 @@ import org.opentaint.dataflow.ap.ifds.FactTypeChecker.CompatibilityFilterResult import org.opentaint.dataflow.ap.ifds.FieldAccessor import org.opentaint.dataflow.ap.ifds.FinalAccessor import org.opentaint.dataflow.ap.ifds.TaintMarkAccessor +import org.opentaint.dataflow.ap.ifds.TypeInfoAccessor +import org.opentaint.dataflow.ap.ifds.TypeInfoGroupAccessor import org.opentaint.dataflow.ap.ifds.ValueAccessor import org.opentaint.dataflow.util.forEach import kotlin.collections.plusAssign @@ -37,6 +39,8 @@ fun AutomataApManager.createFilter( is FinalAccessor -> FactTypeChecker.AlwaysRejectFilter is TaintMarkAccessor -> OnlyFinalAccessorAllowedFilter is ValueAccessor -> OnlyTaintMarkAccessorAllowedFilter + is TypeInfoGroupAccessor -> FactTypeChecker.AlwaysAcceptFilter + is TypeInfoAccessor -> FactTypeChecker.AlwaysAcceptFilter else -> error("Unexpected single accessor") } }, @@ -63,6 +67,8 @@ private inline fun AutomataApManager.createFilter( is TaintMarkAccessor -> filters += singleAccessorFilter(accessor) is ValueAccessor -> filters += singleAccessorFilter(accessor) + is TypeInfoGroupAccessor -> filters += singleAccessorFilter(accessor) + is TypeInfoAccessor -> filters += singleAccessorFilter(accessor) is FieldAccessor, is ClassStaticAccessor -> filters += accessorListFilter(listOf(accessor)) diff --git a/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/access/automata/AutomataInitialFactAbstraction.kt b/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/access/automata/AutomataInitialFactAbstraction.kt index b96fae125..86ab75cfd 100644 --- a/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/access/automata/AutomataInitialFactAbstraction.kt +++ b/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/access/automata/AutomataInitialFactAbstraction.kt @@ -1,21 +1,28 @@ package org.opentaint.dataflow.ap.ifds.access.automata -import org.opentaint.ir.api.common.cfg.CommonInst +import it.unimi.dsi.fastutil.ints.IntArrayList import org.opentaint.dataflow.ap.ifds.ExclusionSet import org.opentaint.dataflow.ap.ifds.FactTypeChecker import org.opentaint.dataflow.ap.ifds.MethodAnalyzerEdges import org.opentaint.dataflow.ap.ifds.access.FinalFactAp import org.opentaint.dataflow.ap.ifds.access.InitialFactAbstraction import org.opentaint.dataflow.ap.ifds.access.InitialFactAp +import org.opentaint.dataflow.ap.ifds.access.util.AccessorIdx +import org.opentaint.dataflow.ap.ifds.access.util.AccessorInterner.Companion.ANY_ACCESSOR_IDX +import org.opentaint.dataflow.ap.ifds.access.util.AccessorInterner.Companion.isAlwaysUnrollNext import org.opentaint.dataflow.ap.ifds.tryAnyAccessorOrNull +import org.opentaint.dataflow.util.ConcurrentReadSafeObject2IntMap import org.opentaint.dataflow.util.contains import org.opentaint.dataflow.util.filter import org.opentaint.dataflow.util.forEach +import org.opentaint.dataflow.util.forEachIntEntry import org.opentaint.dataflow.util.getOrCreateIndex import org.opentaint.dataflow.util.getValue import org.opentaint.dataflow.util.int2ObjectMap import org.opentaint.dataflow.util.object2IntMap +import org.opentaint.dataflow.util.reversedForEachInt import org.opentaint.dataflow.util.toBitSet +import org.opentaint.ir.api.common.cfg.CommonInst import java.util.BitSet class AutomataInitialFactAbstraction(initialStatement: CommonInst) : InitialFactAbstraction { @@ -50,7 +57,7 @@ class AutomataInitialFactAbstraction(initialStatement: CommonInst) : InitialFact fact: AccessGraphInitialFactAp, typeChecker: FactTypeChecker ): List> { - val addedBasedFacts = addedFacts.find(fact.base) ?: return emptyList() + val addedBasedFacts = addedFacts.getOrCreate(fact.base) return addedBasedFacts.registerNew(fact.access, fact.exclusions, typeChecker).map { Pair( AccessGraphInitialFactAp(fact.base, it, ExclusionSet.Empty), @@ -77,7 +84,7 @@ class AutomataInitialFactAbstraction(initialStatement: CommonInst) : InitialFact private val analyzedExclusionIndex = int2ObjectMap() fun addAndAbstract(graph: AccessGraph, typeChecker: FactTypeChecker): List = with(graph.manager) { - if (added.isEmpty()) { + if (added.isEmpty() && analyzed.isEmpty()) { added.put(graph, 0) addedGraphs.add(graph) addedIndex.add(graph, 0) @@ -107,7 +114,12 @@ class AutomataInitialFactAbstraction(initialStatement: CommonInst) : InitialFact ): List = with(graph.manager) { if (exclusion !is ExclusionSet.Concrete) return emptyList() - val analyzedGraphIdx = analyzed.getValue(graph) + var analyzedGraphIdx = analyzed.getInt(graph) + if (analyzedGraphIdx == ConcurrentReadSafeObject2IntMap.NO_VALUE) { + graph.registerNewAnalyzed() + analyzedGraphIdx = analyzed.getValue(graph) + } + val analyzedGraphExclusion = analyzedExclusion[analyzedGraphIdx] val newAccessors = exclusion.set.toBitSet { it.idx }.filter { it !in analyzedGraphExclusion } @@ -127,6 +139,14 @@ class AutomataInitialFactAbstraction(initialStatement: CommonInst) : InitialFact ): List { val relevantAnalyzedGraphIndices = BitSet() addedGraph.accessors().forEach { accessor -> + if (accessor == ANY_ACCESSOR_IDX) { + analyzedExclusionIndex.forEachIntEntry { excludedAccessor, graphs -> + if (tryAnyAccessorOrNull(excludedAccessor.accessor) { true } != true) return@forEachIntEntry + relevantAnalyzedGraphIndices.or(graphs) + } + return@forEach + } + val graphsWithAccessorExcluded = analyzedExclusionIndex.get(accessor) ?: return@forEach relevantAnalyzedGraphIndices.or(graphsWithAccessorExcluded) } @@ -196,12 +216,61 @@ class AutomataInitialFactAbstraction(initialStatement: CommonInst) : InitialFact } } - val singleAccessorGraph = emptyGraph().prepend(accessor) - val newGraph = analyzedGraph.concat(singleAccessorGraph) + if (accessor.isAlwaysUnrollNext()) { + val graphs = mutableListOf() + unrollNext(graphs, IntArrayList(), delta, accessor) + + graphs.forEach { g -> + val newGraph = analyzedGraph.concat(g) + newAnalyzedGraphs.add(newGraph) + } + } else { + val singleAccessorGraph = emptyGraph().prepend(accessor) + val newGraph = analyzedGraph.concat(singleAccessorGraph) + + newAnalyzedGraphs.add(newGraph) + } + } + } + } + + private fun AutomataApManager.unrollNext( + dst: MutableList, + path: IntArrayList, + graph: AccessGraph, + accessor: AccessorIdx + ) { + if (path.contains(accessor)) return // note: we don't expect long accessor chains here + path.add(accessor) + + try { + val nextGraph = graph.read(accessor) + ?: return - newAnalyzedGraphs.add(newGraph) + if (nextGraph.initialNodeIsFinal()) { + dst += rebuildGraph(path) } + + nextGraph.stateSuccessors(nextGraph.initial).forEach { nextAccessor -> + if (nextAccessor.isAlwaysUnrollNext()) { + unrollNext(dst, path, nextGraph, nextAccessor) + } else { + path.add(nextAccessor) + dst += rebuildGraph(path) + path.removeInt(path.lastIndex) + } + } + } finally { + path.removeInt(path.lastIndex) + } + } + + private fun AutomataApManager.rebuildGraph(path: IntArrayList): AccessGraph { + var res = emptyGraph() + path.reversedForEachInt { accessor -> + res = res.prepend(accessor) } + return res } private fun AccessGraph.registerNewAnalyzed(): AccessGraph? { diff --git a/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/access/cactus/AccessCactus.kt b/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/access/cactus/AccessCactus.kt index 718d0c9cb..11daf3c86 100644 --- a/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/access/cactus/AccessCactus.kt +++ b/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/access/cactus/AccessCactus.kt @@ -11,6 +11,8 @@ import org.opentaint.dataflow.ap.ifds.FactTypeChecker import org.opentaint.dataflow.ap.ifds.FieldAccessor import org.opentaint.dataflow.ap.ifds.FinalAccessor import org.opentaint.dataflow.ap.ifds.TaintMarkAccessor +import org.opentaint.dataflow.ap.ifds.TypeInfoAccessor +import org.opentaint.dataflow.ap.ifds.TypeInfoGroupAccessor import org.opentaint.dataflow.ap.ifds.ValueAccessor import org.opentaint.dataflow.ap.ifds.access.FinalFactAp import org.opentaint.dataflow.ap.ifds.access.InitialFactAp @@ -1202,6 +1204,8 @@ class AccessCactus( FinalAccessor -> error("Unexpected FinalAccessor") AnyAccessor -> low === AnyAccessor ValueAccessor -> TODO() + is TypeInfoAccessor -> TODO() + TypeInfoGroupAccessor -> TODO() } } diff --git a/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/access/tree/AccessPath.kt b/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/access/tree/AccessPath.kt index d43e3b37c..16ea1e8b8 100644 --- a/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/access/tree/AccessPath.kt +++ b/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/access/tree/AccessPath.kt @@ -13,10 +13,12 @@ import org.opentaint.dataflow.ap.ifds.access.util.AccessorIdx import org.opentaint.dataflow.ap.ifds.access.util.AccessorInterner.Companion.ANY_ACCESSOR_IDX import org.opentaint.dataflow.ap.ifds.access.util.AccessorInterner.Companion.ELEMENT_ACCESSOR_IDX import org.opentaint.dataflow.ap.ifds.access.util.AccessorInterner.Companion.FINAL_ACCESSOR_IDX +import org.opentaint.dataflow.ap.ifds.access.util.AccessorInterner.Companion.TYPE_INFO_GROUP_ACCESSOR_IDX import org.opentaint.dataflow.ap.ifds.access.util.AccessorInterner.Companion.VALUE_ACCESSOR_IDX import org.opentaint.dataflow.ap.ifds.access.util.AccessorInterner.Companion.isFieldAccessor import org.opentaint.dataflow.ap.ifds.access.util.AccessorInterner.Companion.isStaticAccessor import org.opentaint.dataflow.ap.ifds.access.util.AccessorInterner.Companion.isTaintMarkAccessor +import org.opentaint.dataflow.ap.ifds.access.util.AccessorInterner.Companion.isTypeInfoAccessor import org.opentaint.dataflow.util.reversedForEachInt class AccessPath( @@ -291,6 +293,10 @@ class AccessPath( } accessor == ANY_ACCESSOR_IDX -> this // todo: All accessors are not supported in tree base ap + + accessor == TYPE_INFO_GROUP_ACCESSOR_IDX -> AccessNode(manager, accessor, this) + accessor.isTypeInfoAccessor() -> AccessNode(manager, accessor, this) + else -> error("Unsupported accessor $accessor") } } diff --git a/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/access/tree/AccessTree.kt b/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/access/tree/AccessTree.kt index 9e76adcb9..394e4d320 100644 --- a/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/access/tree/AccessTree.kt +++ b/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/access/tree/AccessTree.kt @@ -17,10 +17,12 @@ import org.opentaint.dataflow.ap.ifds.access.util.AccessorIdx import org.opentaint.dataflow.ap.ifds.access.util.AccessorInterner.Companion.ANY_ACCESSOR_IDX import org.opentaint.dataflow.ap.ifds.access.util.AccessorInterner.Companion.ELEMENT_ACCESSOR_IDX import org.opentaint.dataflow.ap.ifds.access.util.AccessorInterner.Companion.FINAL_ACCESSOR_IDX +import org.opentaint.dataflow.ap.ifds.access.util.AccessorInterner.Companion.TYPE_INFO_GROUP_ACCESSOR_IDX import org.opentaint.dataflow.ap.ifds.access.util.AccessorInterner.Companion.VALUE_ACCESSOR_IDX import org.opentaint.dataflow.ap.ifds.access.util.AccessorInterner.Companion.isFieldAccessor import org.opentaint.dataflow.ap.ifds.access.util.AccessorInterner.Companion.isStaticAccessor import org.opentaint.dataflow.ap.ifds.access.util.AccessorInterner.Companion.isTaintMarkAccessor +import org.opentaint.dataflow.ap.ifds.access.util.AccessorInterner.Companion.isTypeInfoAccessor import org.opentaint.dataflow.ap.ifds.serialization.SummarySerializationContext import org.opentaint.dataflow.util.Cancellation import org.opentaint.dataflow.util.forEachInt @@ -33,7 +35,7 @@ import java.util.Optional import kotlin.jvm.optionals.getOrNull class AccessTree( - private val apManager: TreeApManager, + val apManager: TreeApManager, override val base: AccessPathBase, val access: AccessNode, override val exclusions: ExclusionSet @@ -437,6 +439,10 @@ class AccessTree( } accessor == ANY_ACCESSOR_IDX -> prependAnyAccessor() + + accessor == TYPE_INFO_GROUP_ACCESSOR_IDX -> create(accessor, this) + accessor.isTypeInfoAccessor() -> create(accessor, this) + else -> error("Unsupported accessor: $accessor") } } diff --git a/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/access/tree/TreeInitialFactAbstraction.kt b/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/access/tree/TreeInitialFactAbstraction.kt index a34d9a570..044b0473e 100644 --- a/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/access/tree/TreeInitialFactAbstraction.kt +++ b/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/access/tree/TreeInitialFactAbstraction.kt @@ -17,6 +17,7 @@ import org.opentaint.dataflow.ap.ifds.access.tree.AccessTree.AccessNode.Companio import org.opentaint.dataflow.ap.ifds.access.tree.AccessTree.AccessNode.Companion.createAbstractNodeFromReversedAp import org.opentaint.dataflow.ap.ifds.access.util.AccessorIdx import org.opentaint.dataflow.ap.ifds.access.util.AccessorInterner.Companion.FINAL_ACCESSOR_IDX +import org.opentaint.dataflow.ap.ifds.access.util.AccessorInterner.Companion.isAlwaysUnrollNext import org.opentaint.dataflow.util.forEachInt import org.opentaint.dataflow.ap.ifds.access.tree.AccessTree.AccessNode as AccessTreeNode @@ -76,7 +77,6 @@ class TreeInitialFactAbstraction( ) { var concreteFactAccess = initialConcreteFact while (true) { - val unrollRequests = mutableListOf() abstractAccessPath(facts.analyzed, concreteFactAccess, unrollRequests) { abstractAccess -> apManager.cancellation.checkpoint() @@ -170,7 +170,7 @@ class TreeInitialFactAbstraction( initialAnalyzedTrieRoot: AccessPathTrieNode, initialAdded: AccessTreeNode, unrollRequests: MutableList, - createAbstractAp: (ReversedApNode?) -> Unit + crossinline createAbstractAp: (ReversedApNode?) -> Unit ) { val unprocessed = mutableListOf() unprocessed.add(AbstractionState(initialAnalyzedTrieRoot, initialAdded, currentAp = null)) @@ -208,35 +208,70 @@ class TreeInitialFactAbstraction( addedNode: AccessTreeNode, currentAp: ReversedApNode?, unprocessed: MutableList, - createAbstractAp: (ReversedApNode?) -> Unit + crossinline createAbstractAp: (ReversedApNode?) -> Unit ) { val node = analyzedTrieRoot.child(accessor) - if (node == null) { - val exclusions = analyzedTrieRoot.exclusions() - - // We have no excludes -> continue with the most abstract fact - if (exclusions == null) { - createAbstractAp(currentAp) - return + if (node != null) { + val apWithAccessor = ReversedApNode(accessor, currentAp) + if (accessor.isAlwaysUnrollNext()) { + abstractNextAccessPath(addedNode, apWithAccessor) { + createAbstractAp(it) + } + } else { + unprocessed += AbstractionState(node, addedNode, apWithAccessor) } + return + } - // Concrete: a.b.* E - // Added: a.* S - if (exclusions.contains(accessor)) { - // We have initial fact that exclude {b} and we have no a.b fact yet - // Return a.b.* {} - - createAbstractAp(ReversedApNode(accessor, currentAp)) + val exclusions = analyzedTrieRoot.exclusions() - return - } + // We have no excludes -> continue with the most abstract fact + if (exclusions == null) { + createAbstractAp(currentAp) + return + } + // Concrete: a.b.* E + // Added: a.* S + if (!exclusions.contains(accessor)) { // We have no conflict with added facts return } + // We have initial fact that exclude {b} and we have no a.b fact yet + if (!accessor.isAlwaysUnrollNext()) { + // Return a.b.* {} + createAbstractAp(ReversedApNode(accessor, currentAp)) + return + } + val apWithAccessor = ReversedApNode(accessor, currentAp) - unprocessed += AbstractionState(node, addedNode, apWithAccessor) + abstractNextAccessPath(addedNode, apWithAccessor) { + createAbstractAp(it) + } + } + + private fun abstractNextAccessPath( + addedNode: AccessTreeNode, + currentAp: ReversedApNode, + createAbstractAp: (ReversedApNode) -> Unit + ) { + if (addedNode.containsAnyAccessor()) { + TODO("Any after unroll-next is not supported yet") + } + + if (addedNode.isFinal) { + createAbstractAp(ReversedApNode(FINAL_ACCESSOR_IDX, currentAp)) + } + + addedNode.forEachAccessor { accessor, node -> + val nextAp = ReversedApNode(accessor, currentAp) + if (!accessor.isAlwaysUnrollNext()) { + createAbstractAp(nextAp) + } else { + abstractNextAccessPath(node, nextAp, createAbstractAp) + } + } } private class MethodSameMarkInitialFact( diff --git a/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/access/util/AccessorInterner.kt b/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/access/util/AccessorInterner.kt index b0452601c..ca084d9f5 100644 --- a/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/access/util/AccessorInterner.kt +++ b/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/access/util/AccessorInterner.kt @@ -7,6 +7,8 @@ import org.opentaint.dataflow.ap.ifds.ElementAccessor import org.opentaint.dataflow.ap.ifds.FieldAccessor import org.opentaint.dataflow.ap.ifds.FinalAccessor import org.opentaint.dataflow.ap.ifds.TaintMarkAccessor +import org.opentaint.dataflow.ap.ifds.TypeInfoAccessor +import org.opentaint.dataflow.ap.ifds.TypeInfoGroupAccessor import org.opentaint.dataflow.ap.ifds.ValueAccessor import org.opentaint.dataflow.util.ConcurrentReadSafeObject2IntMap import org.opentaint.dataflow.util.getOrCreateIndex @@ -38,41 +40,45 @@ class AccessorInterner { private val fields = AccessorStorage() private val statics = AccessorStorage() private val taints = AccessorStorage() - private val others = AccessorStorage() - private val storageByKind = arrayOf(fields, statics, taints, others) + private val types = AccessorStorage() + private val storageByBasicKind = arrayOf(fields, statics, taints) fun index(accessor: Accessor): AccessorIdx { val kind = when (accessor) { is FieldAccessor -> FIELD_KIND is ClassStaticAccessor -> STATIC_KIND is TaintMarkAccessor -> TAINT_KIND - else -> OTHER_KIND + is TypeInfoAccessor -> TYPES_KIND + is AnyAccessor -> return ANY_ACCESSOR_IDX + is ElementAccessor -> return ELEMENT_ACCESSOR_IDX + is FinalAccessor -> return FINAL_ACCESSOR_IDX + is TypeInfoGroupAccessor -> return TYPE_INFO_GROUP_ACCESSOR_IDX + is ValueAccessor -> return VALUE_ACCESSOR_IDX } - if (kind != OTHER_KIND) { - val storage = storageByKind[kind] + if (kind.getAccessorBasicKind() != TYPES_OR_MARKER_KIND) { + val storage = storageByBasicKind[kind] val idx = storage.index(accessor) - return setAccessorKind(idx, kind) - } - - return when (accessor) { - AnyAccessor -> ANY_ACCESSOR_IDX - ElementAccessor -> ELEMENT_ACCESSOR_IDX - FinalAccessor -> FINAL_ACCESSOR_IDX - ValueAccessor -> VALUE_ACCESSOR_IDX - else -> { - val storage = storageByKind[OTHER_KIND] - val idx = storage.index(accessor) + OTHER_ACCESSOR_START - setAccessorKind(idx, kind) - } + return setAccessorKind(idx, kind, BASIC_KIND_BITS) + } else { + check(kind == TYPES_KIND) + val idx = types.index(accessor) + return setAccessorKind(idx, TYPES_KIND, TYPES_OR_MARKER_KIND_BITS) } } fun accessor(idx: AccessorIdx): Accessor? { - val kind = idx.getAccessorKind() - if (kind != OTHER_KIND) { - val storage = storageByKind[kind] - return storage.getOrNull(idx.getAccessorIdx()) + val kind = idx.getAccessorBasicKind() + if (kind != TYPES_OR_MARKER_KIND) { + val storage = storageByBasicKind[kind] + val accessorIdx = idx.getAccessorIdx(BASIC_KIND_BITS) + return storage.getOrNull(accessorIdx) + } + + val typesOrMarkerKind = idx.getAccessorKind(TYPES_OR_MARKER_KIND_MASK) + if (typesOrMarkerKind == TYPES_KIND) { + val accessorIdx = idx.getAccessorIdx(TYPES_OR_MARKER_KIND_BITS) + return types.getOrNull(accessorIdx) } return when (idx) { @@ -80,37 +86,54 @@ class AccessorInterner { ELEMENT_ACCESSOR_IDX -> ElementAccessor VALUE_ACCESSOR_IDX -> ValueAccessor ANY_ACCESSOR_IDX -> AnyAccessor - else -> { - val storage = storageByKind[OTHER_KIND] - storage.getOrNull(idx.getAccessorIdx() - OTHER_ACCESSOR_START) - } + TYPE_INFO_GROUP_ACCESSOR_IDX -> TypeInfoGroupAccessor + else -> error("Unexpected accessor $idx") } } companion object { - const val FIELD_KIND = 0 - const val STATIC_KIND = 1 - const val TAINT_KIND = 2 - const val OTHER_KIND = 3 + const val BASIC_KIND_BITS = 2 + const val BASIC_KIND_MASK = 0b11 + const val FIELD_KIND = 0b00 + const val STATIC_KIND = 0b01 + const val TAINT_KIND = 0b10 + const val TYPES_OR_MARKER_KIND = 0b11 + + const val TYPES_OR_MARKER_KIND_BITS = 1 + BASIC_KIND_BITS + const val TYPES_OR_MARKER_KIND_MASK = (1 shl BASIC_KIND_BITS) or BASIC_KIND_MASK - const val OTHER_ACCESSOR_START = 4 - const val FINAL_ACCESSOR_IDX = (0 shl 2) or OTHER_KIND - const val ELEMENT_ACCESSOR_IDX = (1 shl 2) or OTHER_KIND - const val VALUE_ACCESSOR_IDX = (2 shl 2) or OTHER_KIND - const val ANY_ACCESSOR_IDX = (3 shl 2) or OTHER_KIND + const val MARKERS_KIND = (0 shl BASIC_KIND_BITS) or TYPES_OR_MARKER_KIND + const val TYPES_KIND = (1 shl BASIC_KIND_BITS) or TYPES_OR_MARKER_KIND + + const val FINAL_ACCESSOR_IDX = (0 shl TYPES_OR_MARKER_KIND_BITS) or MARKERS_KIND + const val ELEMENT_ACCESSOR_IDX = (1 shl TYPES_OR_MARKER_KIND_BITS) or MARKERS_KIND + const val VALUE_ACCESSOR_IDX = (2 shl TYPES_OR_MARKER_KIND_BITS) or MARKERS_KIND + const val ANY_ACCESSOR_IDX = (3 shl TYPES_OR_MARKER_KIND_BITS) or MARKERS_KIND + const val TYPE_INFO_GROUP_ACCESSOR_IDX = (4 shl TYPES_OR_MARKER_KIND_BITS) or MARKERS_KIND @Suppress("NOTHING_TO_INLINE") - inline fun AccessorIdx.getAccessorKind(): Int = this and 0x3 + inline fun AccessorIdx.getAccessorBasicKind(): Int = getAccessorKind(BASIC_KIND_MASK) @Suppress("NOTHING_TO_INLINE") - inline fun AccessorIdx.getAccessorIdx(): Int = this shr 2 + inline fun AccessorIdx.getAccessorKind(kindMask: Int): Int = this and kindMask @Suppress("NOTHING_TO_INLINE") - inline fun setAccessorKind(accessorIdx: Int, kind: Int): AccessorIdx = - (accessorIdx shl 2) or kind + inline fun AccessorIdx.getAccessorIdx(kindBits: Int): Int = this shr kindBits - fun AccessorIdx.isFieldAccessor(): Boolean = getAccessorKind() == FIELD_KIND - fun AccessorIdx.isStaticAccessor(): Boolean = getAccessorKind() == STATIC_KIND - fun AccessorIdx.isTaintMarkAccessor(): Boolean = getAccessorKind() == TAINT_KIND + @Suppress("NOTHING_TO_INLINE") + inline fun setAccessorKind(accessorIdx: Int, kind: Int, kindBits: Int): AccessorIdx = + (accessorIdx shl kindBits) or kind + + fun AccessorIdx.isFieldAccessor(): Boolean = getAccessorBasicKind() == FIELD_KIND + fun AccessorIdx.isStaticAccessor(): Boolean = getAccessorBasicKind() == STATIC_KIND + fun AccessorIdx.isTaintMarkAccessor(): Boolean = getAccessorBasicKind() == TAINT_KIND + fun AccessorIdx.isTypeInfoAccessor(): Boolean = getAccessorKind(TYPES_OR_MARKER_KIND_MASK) == TYPES_KIND + + fun AccessorIdx.isAlwaysUnrollNext(): Boolean = + isTaintMarkAccessor() + || (this == FINAL_ACCESSOR_IDX) + || (this == VALUE_ACCESSOR_IDX) + || (this == TYPE_INFO_GROUP_ACCESSOR_IDX) + || isTypeInfoAccessor() } } diff --git a/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/taint/TaintAnalysisUnitStorage.kt b/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/taint/TaintAnalysisUnitStorage.kt index 56b4f0b3b..aa4c0c4a5 100644 --- a/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/taint/TaintAnalysisUnitStorage.kt +++ b/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/taint/TaintAnalysisUnitStorage.kt @@ -3,18 +3,31 @@ package org.opentaint.dataflow.ap.ifds.taint import org.opentaint.dataflow.ap.ifds.LanguageManager import org.opentaint.dataflow.ap.ifds.MethodSummariesUnitStorage import org.opentaint.dataflow.ap.ifds.access.ApManager -import java.util.concurrent.ConcurrentLinkedQueue +import org.opentaint.ir.api.common.cfg.CommonInst +import java.util.concurrent.ConcurrentHashMap class TaintAnalysisUnitStorage(apManager: ApManager, languageManager: LanguageManager) : MethodSummariesUnitStorage(apManager, languageManager) { - private val vulnerabilities = ConcurrentLinkedQueue() + private data class VulnerabilityIdentity( + val ruleId: String, + val statement: CommonInst, + ) + + private val vulnerabilityBuckets = ConcurrentHashMap() fun addVulnerability(vulnerability: TaintSinkTracker.TaintVulnerability) { - vulnerabilities.add(vulnerability) + val identity = VulnerabilityIdentity(vulnerability.ruleId, vulnerability.statement) + val bucket = vulnerabilityBuckets.computeIfAbsent(identity) { vulnerability } + + synchronized(bucket) { + bucket.mergeAdd(vulnerability) + } } fun collectVulnerabilities(collector: MutableList) { - collector.addAll(vulnerabilities) + vulnerabilityBuckets.values.forEach { + collector.add(it) + } } } diff --git a/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/taint/TaintSinkTracker.kt b/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/taint/TaintSinkTracker.kt index 40b4fc5b6..57b1c6a9b 100644 --- a/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/taint/TaintSinkTracker.kt +++ b/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/taint/TaintSinkTracker.kt @@ -1,8 +1,11 @@ package org.opentaint.dataflow.ap.ifds.taint +import mu.KLogging import org.opentaint.dataflow.ap.ifds.MethodEntryPoint import org.opentaint.dataflow.ap.ifds.access.FinalFactAp import org.opentaint.dataflow.ap.ifds.access.InitialFactAp +import org.opentaint.dataflow.ap.ifds.taint.TaintSinkTracker.TaintVulnerabilityRuleNode.Unconditional +import org.opentaint.dataflow.ap.ifds.taint.TaintSinkTracker.TaintVulnerabilityRuleNode.WithRequirement import org.opentaint.dataflow.configuration.CommonTaintConfigurationItem import org.opentaint.dataflow.configuration.CommonTaintConfigurationSink import org.opentaint.dataflow.configuration.CommonTaintConfigurationSource @@ -12,41 +15,104 @@ import java.util.concurrent.ConcurrentHashMap class TaintSinkTracker( private val storage: TaintAnalysisUnitStorage, ) { - sealed interface TaintVulnerability { + data class TaintVulnerability( + val statement: CommonInst, + val ruleId: String, + val vulnerabilityRules: MutableMap, + ) { val rule: CommonTaintConfigurationSink - val methodEntryPoint: MethodEntryPoint - val statement: CommonInst + get() = vulnerabilityRules.keys.firstOrNull() + ?: error("Incorrect vulnerability without rules") + + fun mergeAdd(other: TaintVulnerability) { + check(other.ruleId == ruleId && other.statement == statement) { + "Unable to merge different vulnerabilities" + } + + vulnerabilityRules.addAll(other.vulnerabilityRules, TaintVulnerabilityRuleNode::merge) + } } - data class TaintVulnerabilityUnconditional( - override val rule: CommonTaintConfigurationSink, - override val methodEntryPoint: MethodEntryPoint, - override val statement: CommonInst - ) : TaintVulnerability + sealed interface TaintVulnerabilityRuleNode { + fun merge(other: TaintVulnerabilityRuleNode): TaintVulnerabilityRuleNode - enum class VulnerabilityTriggerPosition { - BEFORE_INST, AFTER_INST + data class Unconditional(val methodEntryPoint: MethodEntryPoint) : TaintVulnerabilityRuleNode { + override fun merge(other: TaintVulnerabilityRuleNode) = this + } + + data class Fact( + val vulnerabilityTriggerPosition: VulnerabilityTriggerPosition, + val facts: MutableMap, + ) : TaintVulnerabilityRuleNode { + override fun merge(other: TaintVulnerabilityRuleNode): TaintVulnerabilityRuleNode { + val otherFact: Fact = when (other) { + is Unconditional -> { + return other + } + + is WithRequirement -> { + logger.debug { "Vulnerability with requirement ignored: replaced with fact" } + return this + } + + is Fact -> other + } + + if (otherFact.vulnerabilityTriggerPosition != vulnerabilityTriggerPosition) { + logger.debug { "Vulnerability with fact ignored: position mismatch" } + return this + } + + facts.addAll(otherFact.facts, VulnerabilityFactGroups::merge) + return this + } + } + + data class WithRequirement( + val requirement: MutableMap + ) : TaintVulnerabilityRuleNode { + override fun merge(other: TaintVulnerabilityRuleNode): TaintVulnerabilityRuleNode { + val otherWithReq: WithRequirement = when (other) { + is Unconditional -> return other + is Fact -> { + logger.debug { "Vulnerability with requirement ignored: replaced with fact" } + return other + } + + is WithRequirement -> other + } + + requirement.addAll(otherWithReq.requirement, TaintVulnerabilityRuleNode::merge) + return this + } + } } - data class TaintVulnerabilityWithFact( - override val rule: CommonTaintConfigurationSink, - override val methodEntryPoint: MethodEntryPoint, - override val statement: CommonInst, - val factAp: Set, - val vulnerabilityTriggerPosition: VulnerabilityTriggerPosition, - ): TaintVulnerability + data class EndFactRequirement( + val methodEntryPoint: MethodEntryPoint, + val endFactRequirement: Set + ) + + data class VulnerabilityFacts(val facts: Set) - data class TaintVulnerabilityWithEndFactRequirement( - val vulnerability: TaintVulnerability, - val endFactRequirement: Set, - ) : TaintVulnerability by vulnerability + data class VulnerabilityFactGroups(val facts: MutableSet) { + fun merge(other: VulnerabilityFactGroups): VulnerabilityFactGroups { + facts.addAll(other.facts) + return this + } + } + + enum class VulnerabilityTriggerPosition { + BEFORE_INST, AFTER_INST + } fun addUnconditionalVulnerability( methodEntryPoint: MethodEntryPoint, statement: CommonInst, rule: CommonTaintConfigurationSink, ) { - addVulnerability(TaintVulnerabilityUnconditional(rule, methodEntryPoint, statement)) + val vuln = TaintVulnerability(statement, rule.id, hashMapOf(rule to Unconditional(methodEntryPoint))) + addVulnerability(vuln) } fun addUnconditionalVulnerabilityWithEndFactRequirement( @@ -55,12 +121,9 @@ class TaintSinkTracker( rule: CommonTaintConfigurationSink, requiredEndFacts: Set, ) { - addVulnerability( - TaintVulnerabilityWithEndFactRequirement( - TaintVulnerabilityUnconditional(rule, methodEntryPoint, statement), - requiredEndFacts, - ) - ) + val reqNode = WithRequirement(hashMapOf(EndFactRequirement(methodEntryPoint, requiredEndFacts) to Unconditional(methodEntryPoint))) + val vuln = TaintVulnerability(statement, rule.id, hashMapOf(rule to reqNode)) + addVulnerability(vuln) } fun addVulnerability( @@ -70,11 +133,23 @@ class TaintSinkTracker( rule: CommonTaintConfigurationSink, vulnerabilityTriggerPosition: VulnerabilityTriggerPosition = VulnerabilityTriggerPosition.BEFORE_INST, ) { - addVulnerability( - TaintVulnerabilityWithFact(rule, methodEntryPoint, statement, facts, vulnerabilityTriggerPosition) - ) + val factNode = createFactNode(vulnerabilityTriggerPosition, methodEntryPoint, facts) + val vuln = TaintVulnerability(statement, rule.id, hashMapOf(rule to factNode)) + addVulnerability(vuln) } + private fun createFactNode( + vulnerabilityTriggerPosition: VulnerabilityTriggerPosition, + methodEntryPoint: MethodEntryPoint, + facts: Set + ): TaintVulnerabilityRuleNode.Fact = TaintVulnerabilityRuleNode.Fact( + vulnerabilityTriggerPosition, hashMapOf( + methodEntryPoint to VulnerabilityFactGroups( + hashSetOf(VulnerabilityFacts(facts)) + ) + ) + ) + fun addVulnerabilityWithEndFactRequirement( methodEntryPoint: MethodEntryPoint, facts: Set, @@ -83,24 +158,13 @@ class TaintSinkTracker( requiredEndFacts: Set, vulnerabilityTriggerPosition: VulnerabilityTriggerPosition = VulnerabilityTriggerPosition.BEFORE_INST, ) { - addVulnerability( - TaintVulnerabilityWithEndFactRequirement( - TaintVulnerabilityWithFact(rule, methodEntryPoint, statement, facts, vulnerabilityTriggerPosition), - requiredEndFacts, - ) - ) + val factNode = createFactNode(vulnerabilityTriggerPosition, methodEntryPoint, facts) + val reqNode = WithRequirement(hashMapOf(EndFactRequirement(methodEntryPoint, requiredEndFacts) to factNode)) + val vuln = TaintVulnerability(statement, rule.id, hashMapOf(rule to reqNode)) + addVulnerability(vuln) } - private val reportedVulnerabilities = ConcurrentHashMap>() - private fun addVulnerability(vulnerability: TaintVulnerability) { - val reportedVulnerabilitiesFoRule = reportedVulnerabilities.computeIfAbsent(vulnerability.rule.id) { - ConcurrentHashMap.newKeySet() - } - - // todo: current deduplication is incompatible with traces - if (!reportedVulnerabilitiesFoRule.add(vulnerability.statement)) return - storage.addVulnerability(vulnerability) } @@ -169,4 +233,22 @@ class TaintSinkTracker( } } } + + companion object { + private val logger = object : KLogging() {}.logger + + inline fun MutableMap.addAll(other: Map, addValue: V.(V) -> V) { + for ((key, value) in other) { + val curValue = this[key] + if (curValue != null) { + val modified = curValue.addValue(value) + if (modified !== curValue) { + put(key, modified) + } + } else { + put(key, value) + } + } + } + } } diff --git a/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/trace/ParallelProcessingContext.kt b/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/trace/ParallelProcessingContext.kt index 122f6ebe8..bf57c14a6 100644 --- a/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/trace/ParallelProcessingContext.kt +++ b/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/trace/ParallelProcessingContext.kt @@ -6,6 +6,8 @@ import kotlinx.coroutines.CoroutineExceptionHandler import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Job import kotlinx.coroutines.cancelAndJoin +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.channels.trySendBlocking import kotlinx.coroutines.delay import kotlinx.coroutines.isActive import kotlinx.coroutines.joinAll @@ -14,62 +16,90 @@ import kotlinx.coroutines.runBlocking import kotlinx.coroutines.withTimeoutOrNull import mu.KotlinLogging import org.opentaint.dataflow.util.Cancellation -import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.AtomicIntegerArray +import java.util.concurrent.atomic.AtomicReferenceArray import kotlin.time.Duration import kotlin.time.Duration.Companion.seconds abstract class ParallelProcessingContext( dispatcher: CoroutineDispatcher, private val name: String, - private val data: List, + private val tasks: List, ) { - abstract fun createUnprocessed(item: T): R - - open fun reportStats() { + sealed interface ProcessingResult { + data class Done(val result: R) : ProcessingResult + data class Running(val task: T) : ProcessingResult } - fun processingResults(): List { - val processingRes = mutableListOf() - processingRes.addAll(result) + abstract fun processItem(item: T): ProcessingResult - data.mapNotNullTo(processingRes) { - if (it in processedItems) return@mapNotNullTo null - - createUnprocessed(it) - } + abstract fun createUnprocessed(item: T): R - return processingRes + open fun reportStats() { } val processed: Int get() = processedCounter.get() + private val tasksQueue = Channel(Channel.UNLIMITED) + private val workers = mutableListOf() + private val completed = CompletableDeferred() private val processedCounter = AtomicInteger() - private val result = ConcurrentLinkedQueue() - private val processedItems = ConcurrentHashMap.newKeySet() - private val jobs = mutableListOf() + + private val latestState: AtomicReferenceArray = AtomicReferenceArray(tasks.size).also { + for (i in tasks.indices) it.set(i, tasks[i]) + } + private val results: AtomicReferenceArray = AtomicReferenceArray(tasks.size) + private val terminated: AtomicIntegerArray = AtomicIntegerArray(tasks.size) + private val scope = CoroutineScope(dispatcher) private val exceptionHandler = CoroutineExceptionHandler { _, exception -> logger.error(exception) { "$name failed" } - updatedProcessed() } - fun processAllWithCompletion( - body: (T) -> R - ): CompletableDeferred { - data.mapTo(jobs) { vulnerability -> - scope.launch(exceptionHandler) { - try { - result.add(body(vulnerability)) - processedItems.add(vulnerability) - } catch (ex: Throwable) { - logger.error(ex) { "$name failed" } - } finally { - updatedProcessed() + private fun processingResults(): List { + return List(tasks.size) { i -> + results.get(i) ?: createUnprocessed(latestState.get(i)) + } + } + + private fun markTerminal(index: Int) { + if (terminated.compareAndSet(index, 0, 1)) { + if (processedCounter.incrementAndGet() == tasks.size) { + tasksQueue.close() + completed.complete(Unit) + } + } + } + + private fun processAllWithCompletion(cancellation: Cancellation): CompletableDeferred { + val workerCount = minOf(WORKER_COUNT, tasks.size) + repeat(workerCount) { + workers += scope.launch(exceptionHandler) { + for (index in tasksQueue) { + if (!cancellation.isActive()) break + + val task = latestState.get(index) + try { + when (val r = processItem(task)) { + is ProcessingResult.Done -> { + latestState.set(index, null) + results.set(index, r.result) + markTerminal(index) + } + + is ProcessingResult.Running -> { + latestState.set(index, r.task) + tasksQueue.send(index) + } + } + } catch (ex: Throwable) { + logger.error(ex) { "$name failed" } + markTerminal(index) + } } } } @@ -81,14 +111,17 @@ abstract class ParallelProcessingContext( timeout: Duration, cancellationTimeout: Duration, cancellation: Cancellation, - body: (T) -> R ): List { - val completion = processAllWithCompletion(body) + for (i in tasks.indices) { + tasksQueue.trySendBlocking(i) + } + + val completion = processAllWithCompletion(cancellation) val progress = progressScope.launch { while (isActive) { delay(10.seconds) - logger.info { "${name}: processed ${processed}/${data.size} items" } + logger.info { "${name}: processed ${processed}/${tasks.size} items" } reportStats() } } @@ -101,6 +134,7 @@ abstract class ParallelProcessingContext( withTimeoutOrNull(cancellationTimeout) { cancellation.cancel() + tasksQueue.cancel() progress.cancelAndJoin() joinCtx() @@ -108,21 +142,17 @@ abstract class ParallelProcessingContext( } return processingResults().also { result -> - logger.info { "${name}: processed ${result.size}/${data.size} items" } + logger.info { "${name}: processed ${result.size}/${tasks.size} items" } } } suspend fun joinCtx() { - jobs.joinAll() - } - - private fun updatedProcessed() { - if (processedCounter.incrementAndGet() == data.size) { - completed.complete(Unit) - } + workers.joinAll() } companion object { + private const val WORKER_COUNT = 10 + private val logger = KotlinLogging.logger {} } } diff --git a/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/trace/TraceResolver.kt b/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/trace/TraceResolver.kt index 169b53a0d..a0119a542 100644 --- a/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/trace/TraceResolver.kt +++ b/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/trace/TraceResolver.kt @@ -2,11 +2,15 @@ package org.opentaint.dataflow.ap.ifds.trace import org.opentaint.dataflow.ap.ifds.MethodEntryPoint import org.opentaint.dataflow.ap.ifds.TaintAnalysisUnitRunnerManager +import org.opentaint.dataflow.ap.ifds.access.InitialFactAp import org.opentaint.dataflow.ap.ifds.taint.TaintSinkTracker import org.opentaint.dataflow.ap.ifds.taint.TaintSinkTracker.TaintVulnerability +import org.opentaint.dataflow.ap.ifds.taint.TaintSinkTracker.TaintVulnerabilityRuleNode import org.opentaint.dataflow.ap.ifds.trace.MethodTraceResolver.TraceEntry.MethodEntry import org.opentaint.dataflow.ap.ifds.trace.MethodTraceResolver.TraceEntry.SourceStartEntry import org.opentaint.dataflow.ap.ifds.trace.MethodTraceResolver.TraceEntryAction +import org.opentaint.dataflow.ap.ifds.trace.TraceResolver.TraceResolutionResult.NoTrace +import org.opentaint.dataflow.ap.ifds.trace.TraceResolver.TraceResolutionResult.Resolved import org.opentaint.dataflow.util.Cancellation import org.opentaint.ir.api.common.CommonMethod import org.opentaint.ir.api.common.cfg.CommonInst @@ -19,8 +23,6 @@ class TraceResolver( ) { data class Params( val resolveEntryPointToStartTrace: Boolean = true, - val startToSourceTraceResolutionLimit: Int? = null, - val startToSinkTraceResolutionLimit: Int? = null, val sourceToSinkInnerTraceResolutionLimit: Int? = null, val innerCallTraceResolveStrategy: InnerCallTraceResolveStrategy = InnerCallTraceResolveStrategy.Default, ) @@ -126,43 +128,175 @@ class TraceResolver( } } - fun resolveTrace(vulnerability: TaintVulnerability): Trace { - when (vulnerability) { - is TaintSinkTracker.TaintVulnerabilityWithEndFactRequirement -> { - return resolveTrace(vulnerability.vulnerability) - } + sealed interface TraceResolutionResult { + data class Resolved(val vulnerability: TaintVulnerability, val trace: Trace) : TraceResolutionResult + data class NoTrace(val vulnerability: TaintVulnerability) : TraceResolutionResult + data class InProgress(val state: State) : TraceResolutionResult + } + + fun resolveTrace(state: State): TraceResolutionResult { + when (state) { + is State.Initial -> { + val requests = mutableListOf() + + val vulnerability = state.vulnerability + val unconditionalTrace = vulnerability.vulnerabilityRules.values.firstNotNullOfOrNull { + collectTraceResolutionRequests(requests, vulnerability.statement, it) + } + + if (unconditionalTrace != null) { + return Resolved(vulnerability, unconditionalTrace) + } - is TaintSinkTracker.TaintVulnerabilityUnconditional -> { - val node = SimpleTraceNode(vulnerability.statement, vulnerability.methodEntryPoint) - val entryPointToStart = resolveEntryPointToStartTrace(setOf(node)) - val sourceToSinkTrace = SourceToSinkTrace(setOf(node), setOf(node), emptyMap()) - return Trace(entryPointToStart, sourceToSinkTrace) + val nextState = Source2SinkTraceResolutionState( + vulnerability, + InterProceduralTraceGraphBuilder(), + requests.distinct().sorted(), + nextRequestIdx = 0, + kind = ProcessingKind.ADD_NEXT_REQUEST, + ) + + return TraceResolutionResult.InProgress(nextState) } - is TaintSinkTracker.TaintVulnerabilityWithFact -> { - val builder = InterProceduralTraceGraphBuilder() - - manager.withMethodRunner(vulnerability.methodEntryPoint) { - val traces = resolveIntraProceduralTraceSummary( - vulnerability.methodEntryPoint, - vulnerability.statement, - vulnerability.factAp, - includeStatement = when (vulnerability.vulnerabilityTriggerPosition) { - TaintSinkTracker.VulnerabilityTriggerPosition.BEFORE_INST -> false - TaintSinkTracker.VulnerabilityTriggerPosition.AFTER_INST -> true - } - ) + is Source2SinkTraceResolutionState -> when (state.kind) { + ProcessingKind.ADD_NEXT_REQUEST -> { + if (state.nextRequestIdx >= state.requests.size) { + return NoTrace(state.vulnerability) + } + + val nextState = addNextRequest(state) + return TraceResolutionResult.InProgress(nextState) + } + + ProcessingKind.PROCESS -> { + state.builder.process(limit = 100) - for (trace in traces) { - builder.createSinkNode(trace) + if (!state.builder.isEmpty()) { + return TraceResolutionResult.InProgress(state) } + + state.builder.removeUnresolvedInnerCalls() + val trace = state.builder.createSource2SinkTrace() + + if (trace.startNodes.isEmpty()) { + // trace is invalid, proceed with next request + val nextState = state.copy(kind = ProcessingKind.ADD_NEXT_REQUEST) + return TraceResolutionResult.InProgress(nextState) + } + + val nextState = Ep2StartTraceResolutionState(state.vulnerability, trace) + return TraceResolutionResult.InProgress(nextState) } + } + + is Ep2StartTraceResolutionState -> { + val entryPointToStart = resolveEntryPointToStartTrace(state.trace.startNodes) + val resultTrace = Trace(entryPointToStart, state.trace) + return Resolved(state.vulnerability, resultTrace) + } + } + } + + private fun addNextRequest(state: Source2SinkTraceResolutionState): Source2SinkTraceResolutionState { + val request = state.requests[state.nextRequestIdx] + manager.withMethodRunner(request.methodEntryPoint) { + val traces = resolveIntraProceduralTraceSummary( + request.methodEntryPoint, + state.vulnerability.statement, + request.facts, + request.includeStatement + ) + + for (trace in traces) { + state.builder.createSinkNode(trace) + } + } + + val nextState = state.copy( + nextRequestIdx = state.nextRequestIdx + 1, + kind = ProcessingKind.PROCESS + ) + return nextState + } - val sourceToSinkTrace = builder.build() + private data class TraceResolutionRequest( + val methodEntryPoint: MethodEntryPoint, + val includeStatement: Boolean, + val facts: Set + ) : Comparable { + override fun compareTo(other: TraceResolutionRequest): Int { + val factsCmp = compareFacts(other.facts) + if (factsCmp != 0) return factsCmp - val entryPointToStart = resolveEntryPointToStartTrace(sourceToSinkTrace.startNodes) - return Trace(entryPointToStart, sourceToSinkTrace) + val stmtCmp = includeStatement.compareTo(other.includeStatement) + if (stmtCmp != 0) return stmtCmp + + return methodEntryPoint.toString().compareTo(other.methodEntryPoint.toString()) + } + + private fun compareFacts(other: Set): Int { + val sizeCmp = facts.sumOf { it.size }.compareTo(other.sumOf { it.size }) + if (sizeCmp != 0) return sizeCmp + + val thisFactsStr = facts.map { it.toString() }.sorted().joinToString() + val otherFactsStr = other.map { it.toString() }.sorted().joinToString() + return thisFactsStr.compareTo(otherFactsStr) + } + } + + sealed interface State { + val vulnerability: TaintVulnerability + + data class Initial(override val vulnerability: TaintVulnerability) : State + } + + private enum class ProcessingKind { + ADD_NEXT_REQUEST, PROCESS + } + + private data class Source2SinkTraceResolutionState( + override val vulnerability: TaintVulnerability, + val builder: InterProceduralTraceGraphBuilder, + val requests: List, + val nextRequestIdx: Int, + val kind: ProcessingKind, + ): State + + private class Ep2StartTraceResolutionState( + override val vulnerability: TaintVulnerability, + val trace: SourceToSinkTrace, + ): State + + private fun collectTraceResolutionRequests( + requests: MutableList, + statement: CommonInst, + node: TaintVulnerabilityRuleNode + ) : Trace? = when (node) { + is TaintVulnerabilityRuleNode.Unconditional -> { + val node = SimpleTraceNode(statement, node.methodEntryPoint) + val entryPointToStart = resolveEntryPointToStartTrace(setOf(node)) + val sourceToSinkTrace = SourceToSinkTrace(setOf(node), setOf(node), emptyMap()) + Trace(entryPointToStart, sourceToSinkTrace) + } + + is TaintVulnerabilityRuleNode.WithRequirement -> { + node.requirement.values.firstNotNullOfOrNull { collectTraceResolutionRequests(requests, statement, it) } + } + + is TaintVulnerabilityRuleNode.Fact -> { + val includeStatement = when (node.vulnerabilityTriggerPosition) { + TaintSinkTracker.VulnerabilityTriggerPosition.BEFORE_INST -> false + TaintSinkTracker.VulnerabilityTriggerPosition.AFTER_INST -> true + } + + for ((methodEntryPoint, factGroups) in node.facts) { + for (facts in factGroups.facts) { + requests += TraceResolutionRequest(methodEntryPoint, includeStatement, facts.facts) + } } + + null } } @@ -205,20 +339,11 @@ class TraceResolver( val unprocessedCall2Sink = mutableListOf() val unprocessedInner = mutableListOf() - private var startToSourceTraceResolutionStat = 0 - private var startToSinkTraceResolutionStat = 0 - fun createSinkNode(trace: MethodTraceResolver.SummaryTrace) { val nodes = resolveNode(trace, CallKind.CallToSink, depth = 0) sinkNodes.addAll(nodes) } - fun build(): SourceToSinkTrace { - process() - removeUnresolvedInnerCalls() - return createSource2SinkTrace() - } - private fun pollUnprocessedEvent(): BuilderUnprocessedTrace? { unprocessedCall2Sink.removeLastOrNull()?.let { return it } unprocessedCall2Source.removeLastOrNull()?.let { return it } @@ -236,8 +361,12 @@ class TraceResolver( } } - private fun process() { - while (cancellation.isActive()) { + fun isEmpty(): Boolean = + unprocessedCall2Sink.isEmpty() && unprocessedCall2Source.isEmpty() && unprocessedInner.isEmpty() + + fun process(limit: Int) { + var steps = 0 + while (cancellation.isActive() && ++steps < limit) { val event = pollUnprocessedEvent() ?: break val resolvedNodes = resolveNode(event.trace, event.kind, event.depth) @@ -255,7 +384,7 @@ class TraceResolver( } } - private fun createSource2SinkTrace(): SourceToSinkTrace { + fun createSource2SinkTrace(): SourceToSinkTrace { val rootsWithReachableSources = rootNodes.filter { node -> entriesReachableFrom(successors, node, sourceNodes) { edge -> edge.takeIf { it.kind == CallKind.CallToSource }?.node @@ -273,7 +402,7 @@ class TraceResolver( return SourceToSinkTrace(rootsWithReachableSinks, sinkNodes, successors) } - private fun removeUnresolvedInnerCalls() { + fun removeUnresolvedInnerCalls() { while (cancellation.isActive()) { val unresolvedNodes = hashMapOf>() @@ -399,10 +528,6 @@ class TraceResolver( val callerTraces = resolveMethodEntry(start) for ((callerStatement, callerTrace) in callerTraces) { - if (params.startToSinkTraceResolutionLimit != null) { - if (startToSinkTraceResolutionStat++ > params.startToSinkTraceResolutionLimit) continue - } - addUnprocessedEvent( BuilderUnprocessedTrace( trace = callerTrace, @@ -445,13 +570,6 @@ class TraceResolver( return node } - if (params.startToSourceTraceResolutionLimit != null) { - if (startToSourceTraceResolutionStat++ > params.startToSourceTraceResolutionLimit) { - sourceNodes.add(node) - return node - } - } - addUnprocessedEvent( BuilderUnprocessedTrace( trace = callSummary.summaryTrace, diff --git a/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/trace/VulnerabilityChecker.kt b/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/trace/VulnerabilityChecker.kt index ce4656295..5ad049bad 100644 --- a/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/trace/VulnerabilityChecker.kt +++ b/core/opentaint-dataflow-core/opentaint-dataflow/src/main/kotlin/org/opentaint/dataflow/ap/ifds/trace/VulnerabilityChecker.kt @@ -5,11 +5,14 @@ import org.opentaint.dataflow.ap.ifds.TaintAnalysisUnitRunnerManager import org.opentaint.dataflow.ap.ifds.TaintMarkAccessor import org.opentaint.dataflow.ap.ifds.access.FinalFactAp import org.opentaint.dataflow.ap.ifds.analysis.MethodCallFlowFunction +import org.opentaint.dataflow.ap.ifds.taint.TaintSinkTracker.EndFactRequirement import org.opentaint.dataflow.ap.ifds.taint.TaintSinkTracker.TaintVulnerability -import org.opentaint.dataflow.ap.ifds.taint.TaintSinkTracker.TaintVulnerabilityWithEndFactRequirement +import org.opentaint.dataflow.ap.ifds.taint.TaintSinkTracker.TaintVulnerabilityRuleNode +import org.opentaint.dataflow.ap.ifds.taint.TaintSinkTracker.TaintVulnerabilityRuleNode.WithRequirement import org.opentaint.dataflow.ap.ifds.trace.MethodForwardTraceResolver.EdgeReason import org.opentaint.dataflow.ap.ifds.trace.MethodForwardTraceResolver.RelevantFactFilter import org.opentaint.dataflow.ap.ifds.trace.MethodForwardTraceResolver.TraceEdge +import org.opentaint.dataflow.configuration.CommonTaintConfigurationSink import org.opentaint.dataflow.util.Cancellation import org.opentaint.ir.api.common.CommonMethod import org.opentaint.ir.api.common.cfg.CommonInst @@ -28,16 +31,58 @@ class VulnerabilityChecker( val status: VulnerabilityVerificationStatus, ) - fun verifyVulnerability(vulnerability: TaintVulnerabilityWithEndFactRequirement): VerifiedVulnerability { - val requiredFacts = vulnerability.endFactRequirement - val vuln = vulnerability.vulnerability + fun verifyVulnerability(vulnerability: TaintVulnerability): VerifiedVulnerability { + val verifiedNodes = hashMapOf() + + var isUnknown = false + for ((rule, node) in vulnerability.vulnerabilityRules) { + if (node !is WithRequirement){ + verifiedNodes[rule] = node + continue + } + + val nodeRequirement = hashMapOf() + for ((requirement, nested) in node.requirement) { + val reqStatus = verifyVulnerabilityReq(vulnerability.statement, requirement) + when (reqStatus) { + VulnerabilityVerificationStatus.UNCONFIRMED -> continue + VulnerabilityVerificationStatus.UNKNOWN -> { + isUnknown = true + nodeRequirement[requirement] = nested + } + + VulnerabilityVerificationStatus.CONFIRMED -> { + nodeRequirement[requirement] = nested + } + } + } + + if (nodeRequirement.isEmpty()) continue + + verifiedNodes[rule] = WithRequirement(nodeRequirement) + } + + if (verifiedNodes.isEmpty()) { + return VerifiedVulnerability(vulnerability, VulnerabilityVerificationStatus.UNCONFIRMED) + } + + val verifiedVuln = TaintVulnerability(vulnerability.statement, vulnerability.ruleId, verifiedNodes) + val resultStatus = if (isUnknown) VulnerabilityVerificationStatus.UNKNOWN else VulnerabilityVerificationStatus.CONFIRMED + return VerifiedVulnerability(verifiedVuln, resultStatus) + } + + private fun verifyVulnerabilityReq( + statement: CommonInst, + requirement: EndFactRequirement + ): VulnerabilityVerificationStatus { + val requiredFacts = requirement.endFactRequirement if (requiredFacts.size != 1) { - return VerifiedVulnerability(vuln, status = VulnerabilityVerificationStatus.UNKNOWN) + return VulnerabilityVerificationStatus.UNKNOWN } val startFact = requiredFacts.single() val startRequest = VulnConfirmationRequest( - vulnerability.methodEntryPoint, vulnerability.statement, + requirement.methodEntryPoint, statement, fact = startFact, resolveFromMethodStartDown = false ) @@ -45,18 +90,18 @@ class VulnerabilityChecker( val visited = hashSetOf() while (unprocessed.isNotEmpty()) { if (!cancellation.isActive()) { - return VerifiedVulnerability(vuln, status = VulnerabilityVerificationStatus.UNKNOWN) + return VulnerabilityVerificationStatus.UNKNOWN } val request = unprocessed.removeLast() if (!visited.add(request)) continue if (factReachAnalysisEnd(request, unprocessed)) { - return VerifiedVulnerability(vuln, status = VulnerabilityVerificationStatus.CONFIRMED) + return VulnerabilityVerificationStatus.CONFIRMED } } - return VerifiedVulnerability(vuln, status = VulnerabilityVerificationStatus.UNCONFIRMED) + return VulnerabilityVerificationStatus.UNCONFIRMED } private data class VulnConfirmationRequest( @@ -202,4 +247,9 @@ class VulnerabilityChecker( return IntraProcCheckResult.FactCleaned } + + companion object { + fun needVerification(vuln: TaintVulnerability): Boolean = + vuln.vulnerabilityRules.values.any { it is WithRequirement } + } } diff --git a/core/opentaint-dataflow-core/opentaint-dataflow/src/test/kotlin/org/opentaint/dataflow/ap/ifds/access/InitialFactAbstractionTest.kt b/core/opentaint-dataflow-core/opentaint-dataflow/src/test/kotlin/org/opentaint/dataflow/ap/ifds/access/InitialFactAbstractionTest.kt new file mode 100644 index 000000000..d0b5c04ad --- /dev/null +++ b/core/opentaint-dataflow-core/opentaint-dataflow/src/test/kotlin/org/opentaint/dataflow/ap/ifds/access/InitialFactAbstractionTest.kt @@ -0,0 +1,551 @@ +package org.opentaint.dataflow.ap.ifds.access + +import org.opentaint.dataflow.ap.ifds.AccessPathBase +import org.opentaint.dataflow.ap.ifds.Accessor +import org.opentaint.dataflow.ap.ifds.AnyAccessor +import org.opentaint.dataflow.ap.ifds.ClassStaticAccessor +import org.opentaint.dataflow.ap.ifds.ElementAccessor +import org.opentaint.dataflow.ap.ifds.ExclusionSet +import org.opentaint.dataflow.ap.ifds.FactTypeChecker +import org.opentaint.dataflow.ap.ifds.FieldAccessor +import org.opentaint.dataflow.ap.ifds.FinalAccessor +import org.opentaint.dataflow.ap.ifds.TaintMarkAccessor +import org.opentaint.dataflow.ap.ifds.TypeInfoAccessor +import org.opentaint.dataflow.ap.ifds.TypeInfoGroupAccessor +import org.opentaint.dataflow.ap.ifds.ValueAccessor +import org.opentaint.ir.api.common.CommonMethod +import org.opentaint.ir.api.common.CommonMethodParameter +import org.opentaint.ir.api.common.CommonTypeName +import org.opentaint.ir.api.common.cfg.CommonInst +import org.opentaint.ir.api.common.cfg.CommonInstLocation +import org.opentaint.ir.api.common.cfg.ControlFlowGraph +import kotlin.test.Test +import kotlin.test.assertTrue + +abstract class InitialFactAbstractionTest { + private companion object { + const val TYPE_A = "A" + const val TYPE_B = "B" + const val TYPE_C = "C" + const val TYPE_D = "D" + const val NO_ANY_FIELD_NAME = "" + + val FIELD_A_B = FieldAccessor(TYPE_A, "b", TYPE_B) + val FIELD_B_C = FieldAccessor(TYPE_B, "c", TYPE_C) + val FIELD_B_E = FieldAccessor(TYPE_B, "e", TYPE_D) + val FIELD_C_D = FieldAccessor(TYPE_C, "d", TYPE_D) + val FIELD_NO_ANY = FieldAccessor(TYPE_B, NO_ANY_FIELD_NAME, TYPE_D) + + val MARK = TaintMarkAccessor("test-mark") + val MARK_2 = TaintMarkAccessor("test-mark-2") + val TYPE_INFO_A = TypeInfoAccessor("A") + val TYPE_INFO_B = TypeInfoAccessor("B") + } + + abstract fun mkApManager(strategy: AnyAccessorUnrollStrategy): ApManager + + private val apManager: ApManager = mkApManager(UnrollStrategy) + + private object UnrollStrategy : AnyAccessorUnrollStrategy { + override fun unrollAccessor(accessor: Accessor): Boolean = when (accessor) { + is ElementAccessor -> true + is FieldAccessor -> accessor.fieldName != NO_ANY_FIELD_NAME + is ClassStaticAccessor, + is AnyAccessor, + is FinalAccessor, + is TaintMarkAccessor, + is TypeInfoAccessor, + is TypeInfoGroupAccessor -> false + + is ValueAccessor -> error("Unexpected accessor to unroll: $accessor") + } + } + + abstract fun merge(fact: FinalFactAp, vararg facts: FinalFactAp): FinalFactAp + + private fun runScenario( + name: String, + analyzed: List, + added: FinalFactAp, + expectedFacts: List = emptyList(), + expectedEmpty: Boolean = false, + ) { + val abstraction = newAbstraction() + analyzed.forEach { analyzedFact -> + abstraction.registerNewInitialFact(analyzedFact, FactTypeChecker.Dummy) + } + + val produced = abstraction.addAbstractedInitialFact(added, FactTypeChecker.Dummy) + + if (expectedEmpty) { + val message = buildString { + appendLine("[$name] expected no produced facts") + appendLine("analyzed=$analyzed") + appendLine("added=$added") + appendLine("produced=${producedFactsToString(produced)}") + } + assertTrue(abstractionIsEmpty(produced), message) + } + + expectedFacts.forEach { expected -> + val message = buildString { + appendLine("[$name] expected fact is missing") + appendLine("analyzed=$analyzed") + appendLine("added=$added") + appendLine("expected=$expected") + appendLine("produced=${producedFactsToString(produced)}") + } + assertTrue( + produced.any { (initial, final) -> initial == expected && final.equalTo(expected) }, + message, + ) + } + } + + + + @Test + fun `scenario 1 exclusion hit on c returns a b c`() = runScenario( + "1 exclusion hit on c returns a.b.c", + listOf(initialFact(AccessPathBase.This, FIELD_A_B).exclude(FIELD_B_C)), + finalFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C, FIELD_C_D), + expectedFacts = listOf(initialFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C)) + ) + + @Test + fun `scenario 2 exclusion miss on e returns empty`() = runScenario( + "2 exclusion miss on e returns empty", + listOf(initialFact(AccessPathBase.This, FIELD_A_B).exclude(FIELD_B_E)), + finalFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C, FIELD_C_D), + expectedEmpty = true + ) + + @Test + fun `scenario 3 analyzed mark no exclusions returns empty`() = runScenario( + "3 analyzed mark no exclusions returns empty", + listOf(initialFact(AccessPathBase.This, MARK)), + finalFact(AccessPathBase.This, FIELD_A_B, ValueAccessor, MARK), + expectedEmpty = true + ) + + @Test + fun `scenario 4 no analyzed facts for this base returns most abstract`() = runScenario( + "4 no analyzed facts for this base returns most abstract", + listOf(initialFact(AccessPathBase.ClassStatic, FIELD_A_B).exclude(FIELD_B_C)), + finalFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C, FIELD_C_D), + expectedFacts = listOf(initialFact(AccessPathBase.This)) + ) + + @Test + fun `scenario 5 root exclusion on b returns a b`() = runScenario( + "5 root exclusion on b returns a.b", + listOf(initialFact(AccessPathBase.This).exclude(FIELD_A_B)), + finalFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C), + expectedFacts = listOf(initialFact(AccessPathBase.This, FIELD_A_B)) + ) + + @Test + fun `scenario 6 root non matching exclusion returns empty`() = runScenario( + "6 root non matching exclusion returns empty", + listOf(initialFact(AccessPathBase.This).exclude(FIELD_B_E)), + finalFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C), + expectedEmpty = true + ) + + @Test + fun `scenario 9 multiple analyzed paths currently produce no abstraction`() = runScenario( + "9 multiple analyzed paths currently produce no abstraction", + listOf(initialFact(AccessPathBase.This).exclude(FIELD_A_B), initialFact(AccessPathBase.This, FIELD_A_B).exclude(FIELD_B_E)), + finalFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C, FIELD_C_D), + expectedEmpty = true + ) + + @Test + fun `scenario 10 most abstract analyzed with empty exclusions returns empty`() = runScenario( + "10 most abstract analyzed with empty exclusions returns empty", + listOf(initialFact(AccessPathBase.This)), + finalFact(AccessPathBase.This, FIELD_A_B), + expectedEmpty = true + ) + + @Test + fun `scenario 11 mark exclusion at root currently returns empty`() = runScenario( + "11 mark exclusion at root currently returns empty", + listOf(initialFact(AccessPathBase.This).exclude(MARK)), + finalFact(AccessPathBase.This, FIELD_A_B, ValueAccessor, MARK), + expectedEmpty = true + ) + + @Test + fun `scenario 12 value exclusion after mark currently returns empty`() = runScenario( + "12 value exclusion after mark currently returns empty", + listOf(initialFact(AccessPathBase.This, MARK).exclude(ValueAccessor)), + finalFact(AccessPathBase.This, FIELD_A_B, ValueAccessor, MARK), + expectedEmpty = true + ) + + @Test + fun `scenario 13 type group exclusion at root currently returns empty`() = runScenario( + "13 type group exclusion at root currently returns empty", + listOf(initialFact(AccessPathBase.This).exclude(TypeInfoGroupAccessor)), + finalFact(AccessPathBase.This, FIELD_A_B, TYPE_INFO_A, TypeInfoGroupAccessor), + expectedEmpty = true + ) + + @Test + fun `scenario 14 type accessor exclusion after group currently returns empty`() = runScenario( + "14 type accessor exclusion after group currently returns empty", + listOf(initialFact(AccessPathBase.This, TypeInfoGroupAccessor).exclude(TYPE_INFO_A)), + finalFact(AccessPathBase.This, FIELD_A_B, TYPE_INFO_A, TypeInfoGroupAccessor), + expectedEmpty = true + ) + + @Test + fun `scenario 15 non matching type accessor exclusion returns empty`() = runScenario( + "15 non matching type accessor exclusion returns empty", + listOf(initialFact(AccessPathBase.This, TypeInfoGroupAccessor).exclude(TYPE_INFO_B)), + finalFact(AccessPathBase.This, FIELD_A_B, TYPE_INFO_A, TypeInfoGroupAccessor), + expectedEmpty = true + ) + + @Test + fun `scenario 16 mark2 exclusion does not match mark1 returns empty`() = runScenario( + "16 mark2 exclusion does not match mark1 returns empty", + listOf(initialFact(AccessPathBase.This).exclude(MARK_2)), + finalFact(AccessPathBase.This, FIELD_A_B, ValueAccessor, MARK), + expectedEmpty = true + ) + + @Test + fun `scenario 17 exclusion on c with short added path returns a b c`() = runScenario( + "17 exclusion on c with short added path returns a.b.c", + listOf(initialFact(AccessPathBase.This, FIELD_A_B).exclude(FIELD_B_C)), + finalFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C), + expectedFacts = listOf(initialFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C)) + ) + + @Test + fun `scenario 18 exclusion on b with short added path returns a b`() = runScenario( + "18 exclusion on b with short added path returns a.b", + listOf(initialFact(AccessPathBase.This).exclude(FIELD_A_B)), + finalFact(AccessPathBase.This, FIELD_A_B), + expectedFacts = listOf(initialFact(AccessPathBase.This, FIELD_A_B)) + ) + + @Test + fun `scenario 19 unrelated base plus matching this base exclusion uses this base result`() = runScenario( + "19 unrelated base plus matching this-base exclusion uses this-base result", + listOf( + initialFact(AccessPathBase.ClassStatic, FIELD_A_B).exclude(FIELD_B_C), + initialFact(AccessPathBase.This, FIELD_A_B).exclude(FIELD_B_C) + ), + finalFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C, FIELD_C_D), + expectedFacts = listOf(initialFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C)) + ) + + @Test + fun `scenario 20 conflicting exclusions on two levels return a b c`() = runScenario( + "20 conflicting exclusions on two levels return a.b.c", + listOf(initialFact(AccessPathBase.This).exclude(FIELD_A_B), initialFact(AccessPathBase.This, FIELD_A_B).exclude(FIELD_B_C)), + finalFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C, FIELD_C_D), + expectedFacts = listOf(initialFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C)) + ) + + @Test + fun `scenario 21 mark exclusion with only mark path currently returns empty`() = runScenario( + "21 mark exclusion with only mark path currently returns empty", + listOf(initialFact(AccessPathBase.This).exclude(MARK)), + finalFact(AccessPathBase.This, ValueAccessor, MARK), + expectedEmpty = true + ) + + @Test + fun `scenario 22 type group exclusion with only type path currently returns empty`() = runScenario( + "22 type group exclusion with only type path currently returns empty", + listOf(initialFact(AccessPathBase.This).exclude(TypeInfoGroupAccessor)), + finalFact(AccessPathBase.This, TYPE_INFO_A, TypeInfoGroupAccessor), + expectedEmpty = true + ) + + @Test + fun `scenario 23 root exclusion on mark with bare mark path returns mark final`() = runScenario( + "23 root exclusion on mark with bare mark path returns mark.final", + listOf(initialFact(AccessPathBase.This).exclude(MARK)), + finalFact(AccessPathBase.This, MARK), + expectedFacts = listOf(initialFact(AccessPathBase.This, MARK, FinalAccessor)) + ) + + @Test + fun `scenario 24 exclusion on value under mark with bare chain currently returns empty`() = runScenario( + "24 exclusion on value under mark with bare chain currently returns empty", + listOf(initialFact(AccessPathBase.This, MARK).exclude(ValueAccessor)), + finalFact(AccessPathBase.This, ValueAccessor, MARK), + expectedEmpty = true + ) + + @Test + fun `scenario 25 root exclusion on type group with bare type chain returns type group final`() = runScenario( + "25 root exclusion on type group with bare type chain returns type group.final", + listOf(initialFact(AccessPathBase.This).exclude(TypeInfoGroupAccessor)), + finalFact(AccessPathBase.This, TypeInfoGroupAccessor), + expectedFacts = listOf(initialFact(AccessPathBase.This, TypeInfoGroupAccessor, FinalAccessor)) + ) + + @Test + fun `scenario 26 exclusion on concrete type under group with bare chain currently returns empty`() = runScenario( + "26 exclusion on concrete type under group with bare chain currently returns empty", + listOf(initialFact(AccessPathBase.This, TypeInfoGroupAccessor).exclude(TYPE_INFO_A)), + finalFact(AccessPathBase.This, TYPE_INFO_A, TypeInfoGroupAccessor), + expectedEmpty = true + ) + + @Test + fun `scenario 27 root exclusion on mark with mark2 path returns empty`() = runScenario( + "27 root exclusion on mark with mark2 path returns empty", + listOf(initialFact(AccessPathBase.This).exclude(MARK)), + finalFact(AccessPathBase.This, MARK_2), + expectedEmpty = true + ) + + @Test + fun `scenario 28 exclusion on type a under group with type b path returns empty`() = runScenario( + "28 exclusion on typeA under group with typeB path returns empty", + listOf(initialFact(AccessPathBase.This, TypeInfoGroupAccessor).exclude(TYPE_INFO_A)), + finalFact(AccessPathBase.This, TYPE_INFO_B, TypeInfoGroupAccessor), + expectedEmpty = true + ) + + @Test + fun `scenario 29 merged final root exclusion on b returns a b`() = runScenario( + "29 merged final root exclusion on b returns a.b", + listOf(initialFact(AccessPathBase.This).exclude(FIELD_A_B)), + merge(finalFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C), finalFact(AccessPathBase.This, FIELD_A_B, FIELD_B_E)), + expectedFacts = listOf(initialFact(AccessPathBase.This, FIELD_A_B)) + ) + + @Test + fun `scenario 30 merged final exclusion on c under b returns a b c`() = runScenario( + "30 merged final exclusion on c under b returns a.b.c", + listOf(initialFact(AccessPathBase.This, FIELD_A_B).exclude(FIELD_B_C)), + merge(finalFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C, FIELD_C_D), finalFact(AccessPathBase.This, FIELD_A_B, FIELD_B_E)), + expectedFacts = listOf(initialFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C)) + ) + + @Test + fun `scenario 31 merged final non matching exclusion on e returns empty`() = runScenario( + "31 merged final non matching exclusion on e returns empty", + listOf(initialFact(AccessPathBase.This, FIELD_A_B).exclude(FIELD_B_E)), + merge(finalFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C, FIELD_C_D), finalFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C)), + expectedEmpty = true + ) + + @Test + fun `scenario 32 merged final mark plus field with mark exclusion returns mark final`() = runScenario( + "32 merged final mark plus field with mark exclusion returns mark.final", + listOf(initialFact(AccessPathBase.This).exclude(MARK)), + merge(finalFact(AccessPathBase.This, MARK), finalFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C)), + expectedFacts = listOf(initialFact(AccessPathBase.This, MARK, FinalAccessor)) + ) + + @Test + fun `scenario 33 merged final type group plus field exclusion returns type group final`() = runScenario( + "33 merged final type group plus field exclusion returns type group.final", + listOf(initialFact(AccessPathBase.This).exclude(TypeInfoGroupAccessor)), + merge(finalFact(AccessPathBase.This, TypeInfoGroupAccessor), finalFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C)), + expectedFacts = listOf(initialFact(AccessPathBase.This, TypeInfoGroupAccessor, FinalAccessor)) + ) + + @Test + fun `scenario 34 merged final with any branch and root exclusion on b returns a b`() = runScenario( + "34 merged final with any branch and root exclusion on b returns a.b", + listOf(initialFact(AccessPathBase.This).exclude(FIELD_A_B)), + merge(finalFact(AccessPathBase.This, AnyAccessor, FIELD_B_C), finalFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C)), + expectedFacts = listOf(initialFact(AccessPathBase.This, FIELD_A_B)) + ) + + @Test + fun `scenario 35 merged final any under b plus concrete c exclusion returns a b c`() = runScenario( + "35 merged final any under b plus concrete c exclusion returns a.b.c", + listOf(initialFact(AccessPathBase.This, FIELD_A_B).exclude(FIELD_B_C)), + merge(finalFact(AccessPathBase.This, FIELD_A_B, AnyAccessor, MARK), finalFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C, FIELD_C_D)), + expectedFacts = listOf(initialFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C)) + ) + + @Test + fun `scenario 36 merged final any under b with non matching exclusion on e returns a b e`() = runScenario( + "36 merged final any under b with non matching exclusion on e returns a.b.e", + listOf(initialFact(AccessPathBase.This, FIELD_A_B).exclude(FIELD_B_E)), + merge(finalFact(AccessPathBase.This, FIELD_A_B, AnyAccessor, MARK), finalFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C, FIELD_C_D)), + expectedFacts = listOf(initialFact(AccessPathBase.This, FIELD_A_B, FIELD_B_E)) + ) + + @Test + fun `scenario 37 merged final with root any and concrete branch exclusion on c returns empty`() = runScenario( + "37 merged final with root any and concrete branch exclusion on c returns empty", + listOf(initialFact(AccessPathBase.This, FIELD_A_B).exclude(FIELD_B_C)), + merge(finalFact(AccessPathBase.This, AnyAccessor, FIELD_B_C, FIELD_C_D), finalFact(AccessPathBase.This, FIELD_A_B, FIELD_B_E)), + expectedEmpty = true + ) + + @Test + fun `scenario 38 merged final unroll next with any no any leaf and c exclusion returns a b c`() = runScenario( + "38 merged final unroll next with any no-any leaf and c exclusion returns a.b.c", + listOf(initialFact(AccessPathBase.This, FIELD_A_B).exclude(FIELD_B_C)), + merge(finalFact(AccessPathBase.This, FIELD_A_B, AnyAccessor, FIELD_NO_ANY), finalFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C, FIELD_C_D)), + expectedFacts = listOf(initialFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C)) + ) + + @Test + fun `scenario 39 merged final mark and value chains with root mark exclusion returns mark final`() = runScenario( + "39 merged final mark and value chains with root mark exclusion returns mark.final", + listOf(initialFact(AccessPathBase.This).exclude(MARK)), + merge(finalFact(AccessPathBase.This, MARK), finalFact(AccessPathBase.This, ValueAccessor, MARK)), + expectedFacts = listOf(initialFact(AccessPathBase.This, MARK, FinalAccessor)) + ) + + @Test + fun `scenario 40 merged final type group and any typed chain with root exclusion returns type group final`() = + runScenario( + "40 merged final type group and any typed chain with root exclusion returns type group.final", + listOf(initialFact(AccessPathBase.This).exclude(TypeInfoGroupAccessor)), + merge( + finalFact(AccessPathBase.This, TypeInfoGroupAccessor), + finalFact(AccessPathBase.This, AnyAccessor, TYPE_INFO_A, TypeInfoGroupAccessor) + ), + expectedFacts = listOf(initialFact(AccessPathBase.This, TypeInfoGroupAccessor, FinalAccessor)) + ) + + @Test + fun `any accessor scenario 1 analyzed excludes b added any c under root returns this b`() = runScenario( + "any-1 analyzed excludes b, added any.c under root returns this.b", + listOf(initialFact(AccessPathBase.This).exclude(FIELD_A_B)), + finalFact(AccessPathBase.This, AnyAccessor, FIELD_B_C), + expectedFacts = listOf(initialFact(AccessPathBase.This, FIELD_A_B)) + ) + + @Test + fun `any accessor scenario 2 analyzed excludes c under b added b any mark returns this b c`() = runScenario( + "any-2 analyzed excludes c under b, added b.any.mark returns this.b.c", + listOf(initialFact(AccessPathBase.This, FIELD_A_B).exclude(FIELD_B_C)), + finalFact(AccessPathBase.This, FIELD_A_B, AnyAccessor, MARK), + expectedFacts = listOf(initialFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C)) + ) + + @Test + fun `any accessor scenario 3 analyzed excludes c under b added b any d returns this b c`() = runScenario( + "any-3 analyzed excludes c under b, added b.any.d returns this.b.c", + listOf(initialFact(AccessPathBase.This, FIELD_A_B).exclude(FIELD_B_C)), + finalFact(AccessPathBase.This, FIELD_A_B, AnyAccessor, FIELD_C_D), + expectedFacts = listOf(initialFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C)) + ) + + @Test + fun `any accessor scenario 4 analyzed excludes e under b added b any mark returns this b e`() = runScenario( + "any-4 analyzed excludes e under b, added b.any.mark returns this.b.e", + listOf(initialFact(AccessPathBase.This, FIELD_A_B).exclude(FIELD_B_E)), + finalFact(AccessPathBase.This, FIELD_A_B, AnyAccessor, MARK), + expectedFacts = listOf(initialFact(AccessPathBase.This, FIELD_A_B, FIELD_B_E)) + ) + + @Test + fun `any accessor scenario 5 analyzed excludes root b added any c`() = runScenario( + "any-5 analyzed excludes root b, added any.c", + listOf(initialFact(AccessPathBase.This).exclude(FIELD_A_B)), + finalFact(AccessPathBase.This, AnyAccessor, FIELD_B_C), + expectedFacts = listOf(initialFact(AccessPathBase.This, FIELD_A_B)) + ) + + @Test + fun `any accessor scenario 6 analyzed excludes c under b added b any with rule storage`() = runScenario( + "any-6 analyzed excludes c under b, added b.any with rule-storage", + listOf(initialFact(AccessPathBase.This, FIELD_A_B).exclude(FIELD_B_C)), + finalFact(AccessPathBase.This, FIELD_A_B, AnyAccessor, FIELD_NO_ANY), + expectedFacts = listOf(initialFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C)) + ) + + @Test + fun `any accessor scenario 7 analyzed excludes c under b added b any value`() = runScenario( + "any-7 analyzed excludes c under b, added b.any.value", + listOf(initialFact(AccessPathBase.This, FIELD_A_B).exclude(FIELD_B_C)), + finalFact(AccessPathBase.This, FIELD_A_B, AnyAccessor, ValueAccessor), + expectedFacts = listOf(initialFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C)) + ) + + @Test + fun `same conflicting fact added twice yields abstraction only once`() { + val abstraction = newAbstraction() + val analyzed = initialFact(AccessPathBase.This, FIELD_A_B).exclude(FIELD_B_C) + abstraction.registerNewInitialFact(analyzed, FactTypeChecker.Dummy) + + val added = finalFact(AccessPathBase.This, FIELD_A_B, FIELD_B_C, FIELD_C_D) + val firstProduced = abstraction.addAbstractedInitialFact(added, FactTypeChecker.Dummy) + val secondProduced = abstraction.addAbstractedInitialFact(added, FactTypeChecker.Dummy) + + assertTrue( + firstProduced.isNotEmpty(), + "Expected first add to produce abstraction; analyzed=$analyzed; added=$added; produced=${ + producedFactsToString( + firstProduced + ) + }", + ) + assertTrue( + abstractionIsEmpty(secondProduced), + "Expected second add of same fact to produce nothing; analyzed=$analyzed; added=$added; produced=${ + producedFactsToString( + secondProduced + ) + }", + ) + } + + private fun initialFact(base: AccessPathBase, vararg accessors: Accessor): InitialFactAp { + var fact = apManager.mostAbstractInitialAp(base) + accessors.reversed().forEach { accessor -> + fact = fact.prependAccessor(accessor) + } + return fact + } + + private fun finalFact(base: AccessPathBase, vararg accessors: Accessor): FinalFactAp { + var fact = apManager.createFinalAp(base, ExclusionSet.Empty) + accessors.reversed().forEach { accessor -> + fact = fact.prependAccessor(accessor) + } + return fact + } + + private fun producedFactsToString(produced: List>): String = + if (produced.isEmpty()) { + "[]" + } else { + produced.joinToString(prefix = "[", postfix = "]") { (initial, _) -> "$initial" } + } + + private fun abstractionIsEmpty(produced: List>): Boolean = + produced.isEmpty() || (produced.size == 1 && produced.single().first.size == 0) + + + private fun newAbstraction() = apManager.initialFactAbstraction(dummyInst) + + private val dummyInst = object : CommonInst { + override fun toString(): String = "dummy-inst" + override val location: CommonInstLocation = object : CommonInstLocation { + override val method: CommonMethod = object : CommonMethod { + override val name: String = "dummy" + override val parameters: List = emptyList() + override val returnType: CommonTypeName = object : CommonTypeName { + override val typeName: String = "void" + } + + override fun flowGraph(): ControlFlowGraph = object : ControlFlowGraph { + override val instructions: List = emptyList() + override val entries: List = emptyList() + override val exits: List = emptyList() + override fun successors(node: CommonInst): Set = emptySet() + override fun predecessors(node: CommonInst): Set = emptySet() + } + } + } + } +} diff --git a/core/opentaint-dataflow-core/opentaint-dataflow/src/test/kotlin/org/opentaint/dataflow/ap/ifds/access/automata/AutomataInitialFactAbstractionTest.kt b/core/opentaint-dataflow-core/opentaint-dataflow/src/test/kotlin/org/opentaint/dataflow/ap/ifds/access/automata/AutomataInitialFactAbstractionTest.kt new file mode 100644 index 000000000..c32e17622 --- /dev/null +++ b/core/opentaint-dataflow-core/opentaint-dataflow/src/test/kotlin/org/opentaint/dataflow/ap/ifds/access/automata/AutomataInitialFactAbstractionTest.kt @@ -0,0 +1,19 @@ +package org.opentaint.dataflow.ap.ifds.access.automata + +import org.opentaint.dataflow.ap.ifds.access.AnyAccessorUnrollStrategy +import org.opentaint.dataflow.ap.ifds.access.ApManager +import org.opentaint.dataflow.ap.ifds.access.FinalFactAp +import org.opentaint.dataflow.ap.ifds.access.InitialFactAbstractionTest + +class AutomataInitialFactAbstractionTest : InitialFactAbstractionTest() { + override fun mkApManager(strategy: AnyAccessorUnrollStrategy): ApManager = AutomataApManager(strategy) + + override fun merge(fact: FinalFactAp, vararg facts: FinalFactAp): FinalFactAp { + check(fact is AccessGraphFinalFactAp) + return facts.fold(fact) { acc, f -> + val graph = f as AccessGraphFinalFactAp + val access = acc.access.merge(graph.access) + AccessGraphFinalFactAp(fact.base, access, fact.exclusions) + } + } +} diff --git a/core/opentaint-dataflow-core/opentaint-dataflow/src/test/kotlin/org/opentaint/dataflow/ap/ifds/access/tree/TreeInitialFactAbstractionTest.kt b/core/opentaint-dataflow-core/opentaint-dataflow/src/test/kotlin/org/opentaint/dataflow/ap/ifds/access/tree/TreeInitialFactAbstractionTest.kt new file mode 100644 index 000000000..c7362531a --- /dev/null +++ b/core/opentaint-dataflow-core/opentaint-dataflow/src/test/kotlin/org/opentaint/dataflow/ap/ifds/access/tree/TreeInitialFactAbstractionTest.kt @@ -0,0 +1,19 @@ +package org.opentaint.dataflow.ap.ifds.access.tree + +import org.opentaint.dataflow.ap.ifds.access.AnyAccessorUnrollStrategy +import org.opentaint.dataflow.ap.ifds.access.ApManager +import org.opentaint.dataflow.ap.ifds.access.FinalFactAp +import org.opentaint.dataflow.ap.ifds.access.InitialFactAbstractionTest + +class TreeInitialFactAbstractionTest : InitialFactAbstractionTest() { + override fun mkApManager(strategy: AnyAccessorUnrollStrategy): ApManager = TreeApManager(strategy) + + override fun merge(fact: FinalFactAp, vararg facts: FinalFactAp): FinalFactAp { + check(fact is AccessTree) + return facts.fold(fact) { acc, f -> + val tree = f as AccessTree + val access = acc.access.mergeAdd(tree.access) + AccessTree(fact.apManager, fact.base, access, fact.exclusions) + } + } +} diff --git a/core/opentaint-dataflow-core/opentaint-dataflow/src/test/kotlin/org/opentaint/dataflow/ap/ifds/access/util/AccessorInternerTest.kt b/core/opentaint-dataflow-core/opentaint-dataflow/src/test/kotlin/org/opentaint/dataflow/ap/ifds/access/util/AccessorInternerTest.kt new file mode 100644 index 000000000..3397d12cc --- /dev/null +++ b/core/opentaint-dataflow-core/opentaint-dataflow/src/test/kotlin/org/opentaint/dataflow/ap/ifds/access/util/AccessorInternerTest.kt @@ -0,0 +1,104 @@ +package org.opentaint.dataflow.ap.ifds.access.util + +import org.opentaint.dataflow.ap.ifds.AbstractionAlwaysUnrollNextAccessor +import org.opentaint.dataflow.ap.ifds.Accessor +import org.opentaint.dataflow.ap.ifds.AnyAccessor +import org.opentaint.dataflow.ap.ifds.ClassStaticAccessor +import org.opentaint.dataflow.ap.ifds.ElementAccessor +import org.opentaint.dataflow.ap.ifds.FieldAccessor +import org.opentaint.dataflow.ap.ifds.FinalAccessor +import org.opentaint.dataflow.ap.ifds.TaintMarkAccessor +import org.opentaint.dataflow.ap.ifds.TypeInfoAccessor +import org.opentaint.dataflow.ap.ifds.TypeInfoGroupAccessor +import org.opentaint.dataflow.ap.ifds.ValueAccessor +import org.opentaint.dataflow.ap.ifds.access.util.AccessorInterner.Companion.isAlwaysUnrollNext +import org.opentaint.dataflow.ap.ifds.access.util.AccessorInterner.Companion.isFieldAccessor +import org.opentaint.dataflow.ap.ifds.access.util.AccessorInterner.Companion.isStaticAccessor +import org.opentaint.dataflow.ap.ifds.access.util.AccessorInterner.Companion.isTaintMarkAccessor +import org.opentaint.dataflow.ap.ifds.access.util.AccessorInterner.Companion.isTypeInfoAccessor +import kotlin.random.Random +import kotlin.test.Test +import kotlin.test.assertEquals + +class AccessorInternerTest { + private companion object { + const val RANDOM_SEED = 42L + const val RANDOM_ACCESSORS_COUNT = 500 + const val MAX_STRING_LEN = 6 + } + + private val singletonAccessors: List = listOf( + AnyAccessor, + ElementAccessor, + FinalAccessor, + ValueAccessor, + TypeInfoGroupAccessor, + ) + + private fun randomString(random: Random): String { + val length = random.nextInt(0, MAX_STRING_LEN + 1) + return buildString { + repeat(length) { append('a' + random.nextInt(26)) } + } + } + + private fun randomDataAccessor(random: Random): Accessor = when (random.nextInt(4)) { + 0 -> FieldAccessor(randomString(random), randomString(random), randomString(random)) + 1 -> ClassStaticAccessor(randomString(random)) + 2 -> TaintMarkAccessor(randomString(random)) + else -> TypeInfoAccessor(randomString(random)) + } + + private fun sampleAccessors(): List { + val random = Random(RANDOM_SEED) + return singletonAccessors + List(RANDOM_ACCESSORS_COUNT) { randomDataAccessor(random) } + } + + @Test + fun `accessor(index(a)) returns the same accessor`() { + val interner = AccessorInterner() + for (accessor in sampleAccessors()) { + val idx = interner.index(accessor) + assertEquals(accessor, interner.accessor(idx), "Round-trip failed for $accessor") + } + } + + @Test + fun `interner returns the same index for equal accessors`() { + val interner = AccessorInterner() + for (accessor in sampleAccessors()) { + val first = interner.index(accessor) + val second = interner.index(accessor) + assertEquals(first, second, "Indices differ between two index() calls for $accessor") + } + } + + @Test + fun `predicates on indices match predicates on accessors`() { + val interner = AccessorInterner() + for (accessor in sampleAccessors()) { + val idx = interner.index(accessor) + + assertEquals( + accessor is FieldAccessor, idx.isFieldAccessor(), + "isFieldAccessor mismatch for $accessor", + ) + assertEquals( + accessor is ClassStaticAccessor, idx.isStaticAccessor(), + "isStaticAccessor mismatch for $accessor", + ) + assertEquals( + accessor is TaintMarkAccessor, idx.isTaintMarkAccessor(), + "isTaintMarkAccessor mismatch for $accessor", + ) + assertEquals( + accessor is TypeInfoAccessor, idx.isTypeInfoAccessor(), + "isTypeInfoAccessor mismatch for $accessor", + ) + assertEquals( + accessor is AbstractionAlwaysUnrollNextAccessor, idx.isAlwaysUnrollNext(), + "isAlwaysUnrollNext mismatch for $accessor", + ) + } + } +} diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/JIRFactTypeChecker.kt b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/JIRFactTypeChecker.kt index d18b93d83..2cb852841 100644 --- a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/JIRFactTypeChecker.kt +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/JIRFactTypeChecker.kt @@ -15,6 +15,8 @@ import org.opentaint.dataflow.ap.ifds.FactTypeChecker.FilterResult import org.opentaint.dataflow.ap.ifds.FieldAccessor import org.opentaint.dataflow.ap.ifds.FinalAccessor import org.opentaint.dataflow.ap.ifds.TaintMarkAccessor +import org.opentaint.dataflow.ap.ifds.TypeInfoAccessor +import org.opentaint.dataflow.ap.ifds.TypeInfoGroupAccessor import org.opentaint.dataflow.ap.ifds.ValueAccessor import org.opentaint.dataflow.ap.ifds.access.FinalFactAp import org.opentaint.dataflow.jvm.util.JIRHierarchyInfo @@ -106,6 +108,9 @@ class JIRFactTypeChecker(private val cp: JIRClasspath) : FactTypeChecker { FilterResult.Reject } } + + is TypeInfoAccessor -> return FilterResult.Accept + TypeInfoGroupAccessor -> return FilterResult.Accept } } @@ -191,6 +196,7 @@ class JIRFactTypeChecker(private val cp: JIRClasspath) : FactTypeChecker { } is TaintMarkAccessor, FinalAccessor, AnyAccessor, is ClassStaticAccessor -> null + is TypeInfoAccessor, TypeInfoGroupAccessor -> null } } diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/JIRLambdaTracker.kt b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/JIRLambdaTracker.kt index e422b05e6..a022d7aa2 100644 --- a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/JIRLambdaTracker.kt +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/JIRLambdaTracker.kt @@ -1,34 +1,10 @@ package org.opentaint.dataflow.jvm.ap.ifds -import org.opentaint.ir.api.jvm.JIRMethod import org.opentaint.dataflow.jvm.ap.ifds.LambdaAnonymousClassFeature.JIRLambdaClass -import java.util.concurrent.ConcurrentHashMap - -class JIRLambdaTracker { - private val lambdaTrackers = ConcurrentHashMap() - - fun registerLambda(lambda: JIRLambdaClass) { - val methodLambdas = lambdaTrackers.computeIfAbsent(lambda.lambdaMethod) { - LambdaTracker(lambda.lambdaMethod) - } - - methodLambdas.addLambda(lambda) - } - - fun subscribeOnLambda(method: JIRMethod, subscriber: LambdaSubscriber) { - val methodLambdas = lambdaTrackers.computeIfAbsent(method) { - LambdaTracker(method) - } - - methodLambdas.addSubscriber(subscriber) - } - - fun forEachRegisteredLambda(method: JIRMethod, subscriber: LambdaSubscriber) { - val methodLambdas = lambdaTrackers[method] ?: return - methodLambdas.forEachRegisteredLambda(subscriber) - } +import org.opentaint.ir.api.jvm.JIRMethod - private class LambdaTracker(private val method: JIRMethod) { +object JIRLambdaTracker { + class LambdaTracker(val method: JIRMethod) { private val subscribers = hashSetOf() private val registeredLambdas = hashSetOf() diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/JIRSummariesFeature.kt b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/JIRSummariesFeature.kt index 5fbc629e6..e89922df0 100644 --- a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/JIRSummariesFeature.kt +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/JIRSummariesFeature.kt @@ -8,6 +8,8 @@ import org.opentaint.dataflow.ap.ifds.ElementAccessor import org.opentaint.dataflow.ap.ifds.FieldAccessor import org.opentaint.dataflow.ap.ifds.FinalAccessor import org.opentaint.dataflow.ap.ifds.TaintMarkAccessor +import org.opentaint.dataflow.ap.ifds.TypeInfoAccessor +import org.opentaint.dataflow.ap.ifds.TypeInfoGroupAccessor import org.opentaint.dataflow.ap.ifds.access.ApMode import org.opentaint.ir.api.jvm.ByteCodeIndexer import org.opentaint.ir.api.jvm.JIRClasspath @@ -159,35 +161,42 @@ class JIRSummariesFeature( FINAL_ACCESSOR_ID -> FinalAccessor ELEMENT_ACCESSOR_ID -> ElementAccessor VALUE_ACCESSOR_ID -> ValueAccessor + TYPE_INFO_GROUP_ACCESSOR_ID -> TypeInfoGroupAccessor else -> { idToAccessorCache.computeIfAbsent(id) { - val (classNameId, fieldNameId, fieldTypeId, taintMarkId, staticTypeNameId) = jIRdb.persistence.read { context -> + val ids = jIRdb.persistence.read { context -> val accessorEntity = context.txn.find(ACCESSOR_IDS_TYPE, "id", id) .singleOrNull() ?: error("Deserialization error. Unknown accessor with id: $id") - val classNameId = accessorEntity.get("classNameId") - val fieldNameId = accessorEntity.get("fieldNameId") - val fieldTypeId = accessorEntity.get("fieldTypeId") - val taintMarkId = accessorEntity.get("taintMarkId") - val staticTypeNameId = accessorEntity.get("staticTypeNameId") - arrayOf(classNameId, fieldNameId, fieldTypeId, taintMarkId, staticTypeNameId) + AccessorIds( + classNameId = accessorEntity.get("classNameId"), + fieldNameId = accessorEntity.get("fieldNameId"), + fieldTypeId = accessorEntity.get("fieldTypeId"), + taintMarkId = accessorEntity.get("taintMarkId"), + staticTypeNameId = accessorEntity.get("staticTypeNameId"), + typeInfoTypeNameId = accessorEntity.get("typeInfoTypeNameId"), + ) } - if (classNameId != null) { - checkNotNull(fieldNameId) { "Expected non-null fieldNameId" } - checkNotNull(fieldTypeId) { "Expected non-null fieldTypeId" } + if (ids.classNameId != null) { + checkNotNull(ids.fieldNameId) { "Expected non-null fieldNameId" } + checkNotNull(ids.fieldTypeId) { "Expected non-null fieldTypeId" } - val className = findSymbolName(classNameId, symbolType = "className") - val fieldName = findSymbolName(fieldNameId, symbolType = "fieldName") - val fieldType = findSymbolName(fieldTypeId, symbolType = "fieldType") + val className = findSymbolName(ids.classNameId, symbolType = "className") + val fieldName = findSymbolName(ids.fieldNameId, symbolType = "fieldName") + val fieldType = findSymbolName(ids.fieldTypeId, symbolType = "fieldType") FieldAccessor(className, fieldName, fieldType) - } else if (staticTypeNameId != null) { - val typeName = interner.findSymbolName(staticTypeNameId) + } else if (ids.staticTypeNameId != null) { + val typeName = interner.findSymbolName(ids.staticTypeNameId) ?: error("Deserialization error. Unknown typeName id: $id") ClassStaticAccessor(typeName) + } else if (ids.typeInfoTypeNameId != null) { + val typeName = interner.findSymbolName(ids.typeInfoTypeNameId) + ?: error("Deserialization error. Unknown typeName id: $id") + TypeInfoAccessor(typeName) } else { - checkNotNull(taintMarkId) { "Expected non-null taintMarkId" } + checkNotNull(ids.taintMarkId) { "Expected non-null taintMarkId" } - val taintMarkName = interner.findSymbolName(taintMarkId) + val taintMarkName = interner.findSymbolName(ids.taintMarkId) ?: error("Deserialization error. Unknown taintMark id: $id") TaintMarkAccessor(taintMarkName) } @@ -202,6 +211,7 @@ class JIRSummariesFeature( ElementAccessor -> ELEMENT_ACCESSOR_ID FinalAccessor -> FINAL_ACCESSOR_ID ValueAccessor -> VALUE_ACCESSOR_ID + TypeInfoGroupAccessor -> TYPE_INFO_GROUP_ACCESSOR_ID is FieldAccessor -> accessorToIdCache.computeIfAbsent(accessor) { val classNameId = accessor.className.asSymbolId(interner) @@ -244,6 +254,18 @@ class JIRSummariesFeature( newAccessors.add(accessor) } } + + is TypeInfoAccessor -> accessorToIdCache.computeIfAbsent(accessor) { + val typeInfoTypeNameId = accessor.typeName.asSymbolId(interner) + val accessorId = jIRdb.persistence.read { context -> + context.txn.find(ACCESSOR_IDS_TYPE, "typeInfoTypeNameId", typeInfoTypeNameId) + .singleOrNull() + ?.get("id") + } + accessorId ?: accessorIdGen.incrementAndGet().also { + newAccessors.add(accessor) + } + } } } @@ -328,6 +350,14 @@ class JIRSummariesFeature( staticAccessorId["staticTypeNameId"] = staticTypeNameId } } + } else if (accessor is TypeInfoAccessor) { + val typeInfoTypeNameId = accessor.typeName.asSymbolId(interner) + jIRdb.persistence.write { context -> + context.txn.newEntity(ACCESSOR_IDS_TYPE).also { typeInfoAccessorId -> + typeInfoAccessorId["id"] = accessorToIdCache[accessor]!! + typeInfoAccessorId["typeInfoTypeNameId"] = typeInfoTypeNameId + } + } } else { accessor as TaintMarkAccessor @@ -354,6 +384,15 @@ class JIRSummariesFeature( ?: error("Deserialization error. Unknown $symbolType id: $id") } + private data class AccessorIds( + val classNameId: Long?, + val fieldNameId: Long?, + val fieldTypeId: Long?, + val taintMarkId: Long?, + val staticTypeNameId: Long?, + val typeInfoTypeNameId: Long?, + ) + companion object { private const val METHOD_IDS_TYPE = "MethodIds" private const val ACCESSOR_IDS_TYPE = "AccessorIds" @@ -363,6 +402,7 @@ class JIRSummariesFeature( private const val FINAL_ACCESSOR_ID = 1L private const val ELEMENT_ACCESSOR_ID = 2L private const val VALUE_ACCESSOR_ID = 3L - private const val MAX_RESERVED_ACCESSOR_ID = 3L + private const val TYPE_INFO_GROUP_ACCESSOR_ID = 4L + private const val MAX_RESERVED_ACCESSOR_ID = 4L } } \ No newline at end of file diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/LambdaAnonymousClassFeature.kt b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/LambdaAnonymousClassFeature.kt index ef8733b6b..f2d1383bc 100644 --- a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/LambdaAnonymousClassFeature.kt +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/LambdaAnonymousClassFeature.kt @@ -1,5 +1,7 @@ package org.opentaint.dataflow.jvm.ap.ifds +import org.opentaint.dataflow.jvm.util.JIRInstListBuilder +import org.opentaint.dataflow.jvm.util.typeName import org.opentaint.ir.api.jvm.JIRClassOrInterface import org.opentaint.ir.api.jvm.JIRClassType import org.opentaint.ir.api.jvm.JIRClasspath @@ -51,15 +53,11 @@ import org.opentaint.ir.impl.features.classpaths.virtual.JIRVirtualParameter import org.opentaint.ir.impl.types.JIRClassTypeImpl import org.opentaint.ir.impl.types.JIRTypedFieldImpl import org.opentaint.ir.impl.types.substition.JIRSubstitutorImpl -import org.opentaint.dataflow.jvm.util.JIRInstListBuilder -import org.opentaint.dataflow.jvm.util.typeName import java.util.Objects import java.util.concurrent.ConcurrentHashMap class LambdaAnonymousClassFeature : JIRClasspathExtFeature { private val lambdaClasses = ConcurrentHashMap() - private val lambdaProxyClasses = ConcurrentHashMap() - private val lambdaProxyClassesByName = ConcurrentHashMap() override fun tryFindClass(classpath: JIRClasspath, name: String): JIRClasspathExtFeature.JIRResolvedClassResult? { val clazz = lambdaClasses[name] @@ -67,80 +65,9 @@ class LambdaAnonymousClassFeature : JIRClasspathExtFeature { return JIRResolvedClassResultImpl(name, clazz) } - val proxyClass = lambdaProxyClassesByName[name] - if (proxyClass != null) { - return JIRResolvedClassResultImpl(name, proxyClass) - } - return null } - fun getOrCreateLambdaProxy( - method: JIRMethod, - classpath: JIRClasspath, - location: RegisteredLocation - ): JIRLambdaClass = lambdaProxyClasses.computeIfAbsent(method) { - val proxyClassName = lambdaProxyClassName(method) - val proxyMethod = generateProxyMethod(method, classpath) - val lambdaClass = JIRLambdaClass( - name = proxyClassName, - fields = emptyList(), - methods = listOf(proxyMethod), - lambdaMethod = method, - lambdaInterfaceType = method.enclosingClass, - lambdaLocation = null - ) - lambdaClass.bindWithLocation(classpath, location) - proxyMethod.bind(lambdaClass) - lambdaProxyClassesByName[proxyClassName] = lambdaClass - lambdaClass - } - - private fun generateProxyMethod(originalMethod: JIRMethod, classpath: JIRClasspath): OpentaintLambdaProxyMethod { - val instructions = JIRInstListBuilder() - - val method = OpentaintLambdaProxyMethod( - originalMethod = originalMethod, - name = originalMethod.name, - returnType = originalMethod.returnType, - description = originalMethod.description, - parameters = originalMethod.parameters.map { JIRVirtualParameter(it.index, it.type) }, - instructions = instructions - ) - - val interfaceType = classpath.findType(originalMethod.enclosingClass.name) as? JIRClassType - ?: error("Cannot resolve interface type: ${originalMethod.enclosingClass.name}") - - val receiverArg = JIRThis(interfaceType) - val forwardedArgs = originalMethod.parameters.map { param -> - JIRArgument( - param.index, - param.name ?: "arg|${param.index}", - classpath.findType(param.type.typeName) - ) - } - - val typedMethod = interfaceType.lookup.method(originalMethod.name, originalMethod.description) - ?: error("Cannot resolve typed method: ${originalMethod.name}${originalMethod.description} in ${interfaceType.typeName}") - - val methodRef = VirtualMethodRefImpl.of(interfaceType, typedMethod) - val callExpr = JIRVirtualCallExpr(methodRef, receiverArg, forwardedArgs) - - val isVoidReturn = originalMethod.returnType == PredefinedPrimitives.Void.typeName() - val retVal: JIRValue? = if (isVoidReturn) { - instructions.addInstWithLocation(method) { loc -> JIRCallInst(loc, callExpr) } - null - } else { - val resultLocal = JIRLocalVar(0, "result", classpath.findType(originalMethod.returnType.typeName)) - instructions.addInstWithLocation(method) { loc -> JIRAssignInst(loc, resultLocal, callExpr) } - resultLocal - } - - instructions.addInstWithLocation(method) { loc -> JIRReturnInst(loc, retVal) } - - return method - } - fun generateLambda(location: JIRInstLocation, lambda: JIRLambdaExpr): JIRLambdaClass { val lambdaClassName = with(location) { "${method.enclosingClass.name}$${method.name}_${method.descriptionHash()}\$jIR_lambda$${index}" @@ -458,14 +385,7 @@ class LambdaAnonymousClassFeature : JIRClasspathExtFeature { } companion object { - private const val LAMBDA_PROXY_CLASS_PREFIX = "opentaint.lambda.OpentaintLambdaProxy$" - @OptIn(ExperimentalStdlibApi::class) private fun JIRMethod.descriptionHash() = description.hashCode().toHexString() - - private fun lambdaProxyClassName(method: JIRMethod): String { - val clsName = method.enclosingClass.name.replace('.', '_') - return "$LAMBDA_PROXY_CLASS_PREFIX${clsName}_${method.name}_${method.descriptionHash()}" - } } } diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRAnalysisManager.kt b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRAnalysisManager.kt index e0843d379..fc80048ba 100644 --- a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRAnalysisManager.kt +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRAnalysisManager.kt @@ -17,7 +17,6 @@ import org.opentaint.dataflow.ap.ifds.analysis.MethodEntrypointResolver import org.opentaint.dataflow.ap.ifds.analysis.MethodSequentFlowFunction import org.opentaint.dataflow.ap.ifds.analysis.MethodSideEffectSummaryHandler import org.opentaint.dataflow.ap.ifds.analysis.MethodStartFlowFunction -import org.opentaint.dataflow.jvm.ap.ifds.taint.ExternalMethodTracker import org.opentaint.dataflow.ap.ifds.taint.TaintAnalysisContext import org.opentaint.dataflow.ap.ifds.trace.MethodCallPrecondition import org.opentaint.dataflow.ap.ifds.trace.MethodSequentPrecondition @@ -26,14 +25,13 @@ import org.opentaint.dataflow.graph.MethodInstGraph import org.opentaint.dataflow.ifds.UnitResolver import org.opentaint.dataflow.jvm.ap.ifds.JIRCallResolver import org.opentaint.dataflow.jvm.ap.ifds.JIRFactTypeChecker -import org.opentaint.dataflow.jvm.ap.ifds.JIRLambdaTracker import org.opentaint.dataflow.jvm.ap.ifds.JIRLanguageManager import org.opentaint.dataflow.jvm.ap.ifds.JIRLocalAliasAnalysis import org.opentaint.dataflow.jvm.ap.ifds.JIRLocalVariableReachability import org.opentaint.dataflow.jvm.ap.ifds.JIRMethodCallFactMapper import org.opentaint.dataflow.jvm.ap.ifds.JIRMethodContextSerializer -import org.opentaint.dataflow.jvm.ap.ifds.LambdaExpressionToAnonymousClassTransformerFeature import org.opentaint.dataflow.jvm.ap.ifds.jIRDowncast +import org.opentaint.dataflow.jvm.ap.ifds.taint.ExternalMethodTracker import org.opentaint.dataflow.jvm.ap.ifds.taint.JIRTaintAnalysisContext import org.opentaint.dataflow.jvm.ap.ifds.taint.TaintRulesProvider import org.opentaint.dataflow.jvm.ap.ifds.trace.JIRMethodCallPrecondition @@ -57,12 +55,10 @@ class JIRAnalysisManager( val externalMethodTracker: ExternalMethodTracker? = null, private val params: Params = Params(), ) : JIRLanguageManager(cp), TaintAnalysisManager { - private val lambdaTracker = JIRLambdaTracker() override val factTypeChecker = JIRFactTypeChecker(cp) data class Params( val aliasAnalysisParams: JIRLocalAliasAnalysis.Params = JIRLocalAliasAnalysis.Params(), - val callResolverParams: JIRMethodCallResolver.Params = JIRMethodCallResolver.Params(), ) override fun getMethodCallResolver( @@ -74,7 +70,7 @@ class JIRAnalysisManager( jIRDowncast(unitResolver) val jIRCallResolver = JIRCallResolver(cp, unitResolver) - return JIRMethodCallResolver(lambdaTracker, jIRCallResolver, runner, params.callResolverParams) + return JIRMethodCallResolver(jIRCallResolver, runner) } override fun getMethodAnalysisContext( @@ -274,11 +270,7 @@ class JIRAnalysisManager( override val methodContextSerializer = JIRMethodContextSerializer(cp) override fun onInstructionReached(inst: CommonInst) { - jIRDowncast(inst) - val allocatedLambda = LambdaExpressionToAnonymousClassTransformerFeature.findLambdaAllocation(inst) - if (allocatedLambda != null) { - lambdaTracker.registerLambda(allocatedLambda) - } + } override fun reportLanguageSpecificRunnerProgress(logger: KLogger) { diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRMethodAnalysisContext.kt b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRMethodAnalysisContext.kt index 3a1cd3c7f..a94639d19 100644 --- a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRMethodAnalysisContext.kt +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRMethodAnalysisContext.kt @@ -1,10 +1,12 @@ package org.opentaint.dataflow.jvm.ap.ifds.analysis +import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap import org.opentaint.dataflow.ap.ifds.MethodEntryPoint import org.opentaint.dataflow.ap.ifds.TaintMarkAccessor import org.opentaint.dataflow.ap.ifds.analysis.MethodAnalysisContext import org.opentaint.dataflow.ap.ifds.analysis.MethodCallFactMapper import org.opentaint.dataflow.jvm.ap.ifds.JIRFactTypeChecker +import org.opentaint.dataflow.jvm.ap.ifds.JIRLambdaTracker import org.opentaint.dataflow.jvm.ap.ifds.JIRLocalAliasAnalysis import org.opentaint.dataflow.jvm.ap.ifds.JIRLocalVariableReachability import org.opentaint.dataflow.jvm.ap.ifds.JIRMethodCallFactMapper @@ -21,4 +23,6 @@ class JIRMethodAnalysisContext( get() = JIRMethodCallFactMapper val taintMarksAssignedOnMethodEnter = hashSetOf() + + val lambdaCallResolution = Int2ObjectOpenHashMap() } diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRMethodCallResolver.kt b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRMethodCallResolver.kt index 06b61f36e..0f58e5857 100644 --- a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRMethodCallResolver.kt +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRMethodCallResolver.kt @@ -1,11 +1,17 @@ package org.opentaint.dataflow.jvm.ap.ifds.analysis +import mu.KLogging +import org.opentaint.dataflow.ap.ifds.AccessPathBase +import org.opentaint.dataflow.ap.ifds.Accessor import org.opentaint.dataflow.ap.ifds.EmptyMethodContext import org.opentaint.dataflow.ap.ifds.MethodAnalyzer +import org.opentaint.dataflow.ap.ifds.MethodAnalyzer.MethodCallHandler import org.opentaint.dataflow.ap.ifds.MethodEntryPoint import org.opentaint.dataflow.ap.ifds.MethodWithContext import org.opentaint.dataflow.ap.ifds.TaintAnalysisUnitRunner import org.opentaint.dataflow.ap.ifds.TaintAnalysisUnitRunner.LambdaResolvedEvent +import org.opentaint.dataflow.ap.ifds.TypeInfoAccessor +import org.opentaint.dataflow.ap.ifds.TypeInfoGroupAccessor import org.opentaint.dataflow.ap.ifds.analysis.MethodAnalysisContext import org.opentaint.dataflow.ap.ifds.analysis.MethodCallResolver import org.opentaint.dataflow.ap.ifds.analysis.MethodCallResolver.MethodCallResolutionResult @@ -13,28 +19,26 @@ import org.opentaint.dataflow.jvm.ap.ifds.JIRCallResolver import org.opentaint.dataflow.jvm.ap.ifds.JIRLambdaTracker import org.opentaint.dataflow.jvm.ap.ifds.LambdaAnonymousClassFeature import org.opentaint.dataflow.jvm.ap.ifds.jIRDowncast +import org.opentaint.dataflow.util.getOrCreate import org.opentaint.ir.api.common.cfg.CommonCallExpr import org.opentaint.ir.api.common.cfg.CommonInst +import org.opentaint.ir.api.jvm.JIRClassType import org.opentaint.ir.api.jvm.JIRMethod +import org.opentaint.ir.api.jvm.cfg.JIRAssignInst import org.opentaint.ir.api.jvm.cfg.JIRCallExpr import org.opentaint.ir.api.jvm.cfg.JIRInst +import org.opentaint.ir.api.jvm.cfg.JIRNewExpr import org.opentaint.ir.api.jvm.ext.findMethodOrNull class JIRMethodCallResolver( - private val lambdaTracker: JIRLambdaTracker, val callResolver: JIRCallResolver, val runner: TaintAnalysisUnitRunner, - private val params: Params, ) : MethodCallResolver { - data class Params( - val skipUnresolvedLambda: Boolean = true, - ) - override fun resolveMethodCall( callerContext: MethodAnalysisContext, callExpr: CommonCallExpr, location: CommonInst, - handler: MethodAnalyzer.MethodCallHandler, + handler: MethodCallHandler, failureHandler: MethodAnalyzer.MethodCallResolutionFailureHandler ) { jIRDowncast(callExpr) @@ -54,23 +58,18 @@ class JIRMethodCallResolver( return resolvedJirMethodCalls(callerContext, callExpr, location) } - private val lambdaFeature by lazy { - callResolver.cp.features?.filterIsInstance()?.firstOrNull() - ?: error("No lambda feature found") - } - private fun resolveJirMethodCall( callerContext: JIRMethodAnalysisContext, callExpr: JIRCallExpr, location: JIRInst, - handler: MethodAnalyzer.MethodCallHandler, + handler: MethodCallHandler, failureHandler: MethodAnalyzer.MethodCallResolutionFailureHandler ) { val callees = callResolver.resolve(callExpr, location, callerContext) val analyzer = runner.getMethodAnalyzer(callerContext.methodEntryPoint) for (resolvedCallee in callees) { - resolveJirMethodCall(callerContext, resolvedCallee, analyzer, callExpr, failureHandler, handler) + resolveJirMethodCall(callerContext, resolvedCallee, analyzer, callExpr, location, failureHandler, handler) } } @@ -79,8 +78,9 @@ class JIRMethodCallResolver( resolvedCallee: JIRCallResolver.MethodResolutionResult, analyzer: MethodAnalyzer, callExpr: JIRCallExpr, + location: JIRInst, failureHandler: MethodAnalyzer.MethodCallResolutionFailureHandler, - handler: MethodAnalyzer.MethodCallHandler + handler: MethodCallHandler ) { when (resolvedCallee) { JIRCallResolver.MethodResolutionResult.MethodResolutionFailed -> { @@ -92,18 +92,81 @@ class JIRMethodCallResolver( } is JIRCallResolver.MethodResolutionResult.Lambda -> { - resolvedCallee.withLambdaProxy( - callerContext, - delegate = { resolveJirMethodCall(callerContext, it, analyzer, callExpr, failureHandler, handler) }, - handleLambda = { - val subscription = LambdaSubscription(runner, callerContext.methodEntryPoint, handler) - lambdaTracker.subscribeOnLambda(resolvedCallee.method, subscription) - } - ) + analyzer.handleMethodCallResolutionFailure(callExpr, failureHandler) + + val locationIdx = location.location.index + val lambdaResolver = callerContext.lambdaCallResolution.getOrCreate(locationIdx) { + JIRLambdaTracker.LambdaTracker(resolvedCallee.method) + } + + val subscription = LambdaSubscription(runner, callerContext.methodEntryPoint, handler) + lambdaResolver.addSubscriber(subscription) + + tryExtractLambdaType(lambdaResolver, handler, analyzer) } } } + private fun tryExtractLambdaType( + lambdaResolver: JIRLambdaTracker.LambdaTracker, + handler: MethodCallHandler, + analyzer: MethodAnalyzer, + ) { + val (start, fact) = when (handler) { + is MethodCallHandler.ZeroToZeroHandler, + is MethodCallHandler.NDFactToFactHandler -> return + + is MethodCallHandler.FactToFactHandler -> { + handler.startFactBase to handler.currentEdge.factAp + } + + is MethodCallHandler.ZeroToFactHandler -> { + handler.startFactBase to handler.currentEdge.factAp + } + } + + if (start != AccessPathBase.This) return + + val typeInfoGroup = fact.readAccessor(TypeInfoGroupAccessor) + if (typeInfoGroup == null) { + if (handler is MethodCallHandler.FactToFactHandler) { + val edge = handler.currentEdge + val refinedInitial = edge.initialFactAp.exclude(TypeInfoGroupAccessor) + analyzer.triggerSideEffectRequirement(refinedInitial) + } + return + } + + val typeInfos = typeInfoGroup.getStartAccessors().filterIsInstance() + typeInfos.forEach { typeInfo -> + val cls = callResolver.cp.findClassOrNull(typeInfo.typeName) + check(cls is LambdaAnonymousClassFeature.JIRLambdaClass) { + "Unexpected type info: $cls" + } + + val lambdaMethod = lambdaResolver.method + val lambdaImpl = cls.findMethodOrNull(lambdaMethod.name, lambdaMethod.description) + if (lambdaImpl == null) { + logger.debug { "Lambda class $cls has no lambda method $lambdaMethod" } + return@forEach + } + + lambdaResolver.addLambda(cls) + } + } + + data object TypeInfoSequentFlowFunction { + fun handle(inst: JIRInst, body: (List) -> Unit) { + if (inst !is JIRAssignInst) return + val allocation = inst.rhv as? JIRNewExpr ?: return + val allocatedType = allocation.type as? JIRClassType ?: return + val allocatedClass = allocatedType.jIRClass + if (allocatedClass !is LambdaAnonymousClassFeature.JIRLambdaClass) return + + body(listOf(TypeInfoGroupAccessor, TypeInfoAccessor(allocatedClass.name))) + } + } + private fun resolvedJirMethodCalls( callerContext: JIRMethodAnalysisContext, callExpr: JIRCallExpr, @@ -111,12 +174,13 @@ class JIRMethodCallResolver( ): List { val callees = callResolver.resolve(callExpr, location, callerContext) return callees.flatMap { resolvedCallee -> - resolvedJirMethodCalls(callerContext, resolvedCallee) + resolvedJirMethodCalls(callerContext, location, resolvedCallee) } } private fun resolvedJirMethodCalls( callerContext: JIRMethodAnalysisContext, + location: JIRInst, resolvedCallee: JIRCallResolver.MethodResolutionResult ): List = when (resolvedCallee) { @@ -129,32 +193,38 @@ class JIRMethodCallResolver( } is JIRCallResolver.MethodResolutionResult.Lambda -> { - resolvedCallee.withLambdaProxy(callerContext, { resolvedJirMethodCalls(callerContext, it) }) { - val resolvedLambdas = mutableListOf() - lambdaTracker.forEachRegisteredLambda( - resolvedCallee.method, - object : JIRLambdaTracker.LambdaSubscriber { - override fun newLambda( - method: JIRMethod, - lambdaClass: LambdaAnonymousClassFeature.JIRLambdaClass - ) { - val methodImpl = lambdaClass.findMethodOrNull(method.name, method.description) - ?: error("Lambda class $lambdaClass has no lambda method $method") - - resolvedLambdas += MethodCallResolutionResult.ResolvedMethod(MethodWithContext(methodImpl, EmptyMethodContext)) - } + val locationIdx = location.location.index + val lambdaResolver = callerContext.lambdaCallResolution.getOrCreate(locationIdx) { + JIRLambdaTracker.LambdaTracker(resolvedCallee.method) + } + + val resolvedLambdas = mutableListOf() + resolvedLambdas += MethodCallResolutionResult.ResolutionFailure + + lambdaResolver.forEachRegisteredLambda( + object : JIRLambdaTracker.LambdaSubscriber { + override fun newLambda( + method: JIRMethod, + lambdaClass: LambdaAnonymousClassFeature.JIRLambdaClass + ) { + val methodImpl = lambdaClass.findMethodOrNull(method.name, method.description) + ?: error("Lambda class $lambdaClass has no lambda method $method") + + resolvedLambdas += MethodCallResolutionResult.ResolvedMethod( + MethodWithContext(methodImpl, EmptyMethodContext) + ) } - ) + } + ) - resolvedLambdas.ifEmpty { listOf(MethodCallResolutionResult.ResolutionFailure) } - } + resolvedLambdas } } private data class LambdaSubscription( private val runner: TaintAnalysisUnitRunner, private val callerEntryPoint: MethodEntryPoint, - private val handler: MethodAnalyzer.MethodCallHandler + private val handler: MethodCallHandler ) : JIRLambdaTracker.LambdaSubscriber { override fun newLambda(method: JIRMethod, lambdaClass: LambdaAnonymousClassFeature.JIRLambdaClass) { val methodImpl = lambdaClass.findMethodOrNull(method.name, method.description) @@ -165,27 +235,7 @@ class JIRMethodCallResolver( } } - private inline fun JIRCallResolver.MethodResolutionResult.Lambda.withLambdaProxy( - callerContext: JIRMethodAnalysisContext, - delegate: (JIRCallResolver.MethodResolutionResult) -> T, - handleLambda: () -> T - ): T { - if (params.skipUnresolvedLambda) { - return delegate(JIRCallResolver.MethodResolutionResult.MethodResolutionFailed) - } - - val caller = callerContext.methodEntryPoint.method as JIRMethod - - if (caller is LambdaAnonymousClassFeature.OpentaintLambdaProxyMethod) { - return handleLambda() - } - - val callerLocation = caller.enclosingClass.declaration.location - - val proxy = lambdaFeature.getOrCreateLambdaProxy(method, callResolver.cp, callerLocation) - val proxyMethod = proxy.declaredMethods.first() - val proxyWithCtx = MethodWithContext(proxyMethod, EmptyMethodContext) - val concreteCall = JIRCallResolver.MethodResolutionResult.ConcreteMethod(proxyWithCtx) - return delegate(concreteCall) + companion object { + private val logger = object : KLogging() {}.logger } } diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRMethodSequentFlowFunction.kt b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRMethodSequentFlowFunction.kt index 0e0fb7315..f13269dd8 100644 --- a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRMethodSequentFlowFunction.kt +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/analysis/JIRMethodSequentFlowFunction.kt @@ -55,6 +55,15 @@ class JIRMethodSequentFlowFunction( override fun propagateZeroToZero(): Set = buildSet { add(Sequent.ZeroToZero) + if (currentInst is JIRAssignInst) { + JIRMethodCallResolver.TypeInfoSequentFlowFunction.handle(currentInst) { accessors -> + val lhv = accessPathBase(currentInst.lhv) ?: return@handle + val startFact = apManager.createFinalAp(lhv, ExclusionSet.Universe) + val fact = accessors.foldRight(startFact) { a, f -> f.prependAccessor(a) } + add(Sequent.ZeroToFact(fact, TraceInfo.Flow)) + } + } + applyUnconditionalSources() } diff --git a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/taint/FactReader.kt b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/taint/FactReader.kt index 76a8696a3..14ba1f557 100644 --- a/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/taint/FactReader.kt +++ b/core/opentaint-dataflow-core/opentaint-jvm-dataflow/src/main/kotlin/org/opentaint/dataflow/jvm/ap/ifds/taint/FactReader.kt @@ -20,7 +20,7 @@ interface FactReader { fun containsPositionWithTaintMark(position: PositionAccess, mark: TaintMark): Boolean = containsPosition(position.withSuffix(taintedPosSuffix(mark))) - private fun taintedPosSuffix(mark: TaintMark) = listOf(TaintMarkAccessor(mark.name), FinalAccessor) + private fun taintedPosSuffix(mark: TaintMark): List = listOf(TaintMarkAccessor(mark.name), FinalAccessor) } class FinalFactReader( diff --git a/core/opentaint-jvm-sast-dataflow/src/main/kotlin/org/opentaint/jvm/sast/dataflow/JIRTaintAnalyzer.kt b/core/opentaint-jvm-sast-dataflow/src/main/kotlin/org/opentaint/jvm/sast/dataflow/JIRTaintAnalyzer.kt index 6ca753a1d..2c4467607 100644 --- a/core/opentaint-jvm-sast-dataflow/src/main/kotlin/org/opentaint/jvm/sast/dataflow/JIRTaintAnalyzer.kt +++ b/core/opentaint-jvm-sast-dataflow/src/main/kotlin/org/opentaint/jvm/sast/dataflow/JIRTaintAnalyzer.kt @@ -14,6 +14,8 @@ import org.opentaint.dataflow.ap.ifds.MethodEntryPoint import org.opentaint.dataflow.ap.ifds.MethodWithContext import org.opentaint.dataflow.ap.ifds.TaintAnalysisUnitRunnerManager import org.opentaint.dataflow.ap.ifds.TaintMarkAccessor +import org.opentaint.dataflow.ap.ifds.TypeInfoAccessor +import org.opentaint.dataflow.ap.ifds.TypeInfoGroupAccessor import org.opentaint.dataflow.ap.ifds.ValueAccessor import org.opentaint.dataflow.ap.ifds.access.AnyAccessorUnrollStrategy import org.opentaint.dataflow.ap.ifds.access.ApMode @@ -86,7 +88,9 @@ class JIRTaintAnalyzer( is ClassStaticAccessor, is AnyAccessor, is FinalAccessor, - is TaintMarkAccessor -> false + is TaintMarkAccessor, + is TypeInfoAccessor, + is TypeInfoGroupAccessor -> false is ValueAccessor -> error("Unexpected accessor to unroll: $accessor") } } @@ -188,7 +192,7 @@ class JIRTaintAnalyzer( val vulnerabilitiesWithTraces = ifdsEngine.generateTraces(entryPoints, vulnerabilities, traceResolutionTimeout) .also { logger.info { "Finish trace generation" } } - val filteredVulnerabilities = vulnerabilitiesWithTraces.filterVulnWithoutTrace() + val filteredVulnerabilities = vulnerabilitiesWithTraces.filter { it.trace != null } if (filteredVulnerabilities.size != vulnerabilitiesWithTraces.size) { val delta = vulnerabilitiesWithTraces.size - filteredVulnerabilities.size logger.info { "Filter out $delta vulnerabilities without traces" } @@ -217,8 +221,6 @@ class JIRTaintAnalyzer( entryPointsSet, vulnerabilities, resolverParams = TraceResolver.Params( resolveEntryPointToStartTrace = options.symbolicExecutionEnabled, - startToSourceTraceResolutionLimit = 100, - startToSinkTraceResolutionLimit = 100, sourceToSinkInnerTraceResolutionLimit = 5, innerCallTraceResolveStrategy = InnerCallTraceResolveStrategy ), @@ -227,14 +229,6 @@ class JIRTaintAnalyzer( ) } - private fun List.filterVulnWithoutTrace(): List = - filter { it.hasValidTrace() } - - private fun VulnerabilityWithTrace.hasValidTrace(): Boolean { - val trace = trace ?: return false - return trace.sourceToSinkTrace.startNodes.isNotEmpty() - } - private val taintConfig: TaintRulesProvider by lazy { StringConcatRuleProvider(taintConfiguration) } @@ -290,21 +284,20 @@ class JIRTaintAnalyzer( ): String = buildString { data class VulnInfo(val location: String, val ruleId: String, val kind: String) - fun TaintSinkTracker.TaintVulnerability.vulnSummary(): VulnInfo = when (this) { - is TaintSinkTracker.TaintVulnerabilityWithEndFactRequirement -> { - vulnerability.vulnSummary().let { it.copy(kind = "end#${it.kind}") } - } - - is TaintSinkTracker.TaintVulnerabilityUnconditional -> { - VulnInfo("${statement.location}|${statement}", rule.id, "unconditional") + fun TaintSinkTracker.TaintVulnerabilityRuleNode.kind(): List = when (this) { + is TaintSinkTracker.TaintVulnerabilityRuleNode.Unconditional -> listOf("unconditional") + is TaintSinkTracker.TaintVulnerabilityRuleNode.Fact -> listOf("fact") + is TaintSinkTracker.TaintVulnerabilityRuleNode.WithRequirement -> requirement.values.flatMap { v -> + v.kind().map { "end#${it}" } } + } - is TaintSinkTracker.TaintVulnerabilityWithFact -> { - VulnInfo("${statement.location}|${statement}", rule.id, "fact") - } + fun TaintSinkTracker.TaintVulnerability.vulnSummary(): List { + val kinds = vulnerabilityRules.values.flatMap { it.kind() }.distinct() + return kinds.map { VulnInfo("${statement.location}|${statement}", ruleId, it) } } - val info = vulnerabilities.mapTo(mutableListOf()) { it.vulnSummary() } + val info = vulnerabilities.flatMapTo(mutableListOf()) { it.vulnSummary() } info.sortWith(compareBy { it.kind }.thenBy { it.ruleId }.thenBy { it.location }) appendLine("VULNERABILITIES:") diff --git a/core/samples/src/main/java/test/samples/LambdaDataFlowSample.java b/core/samples/src/main/java/test/samples/LambdaDataFlowSample.java index 1231def2f..41af860f0 100644 --- a/core/samples/src/main/java/test/samples/LambdaDataFlowSample.java +++ b/core/samples/src/main/java/test/samples/LambdaDataFlowSample.java @@ -30,5 +30,15 @@ private String applyTransform(String input, Function fn) { } public String source() { return "tainted"; } - public void sink(String data) { } + public String sink(String data) { return data; } + + public void lambdaCaptureFlow() { + String data = source(); + lambdaCapture(data, s -> sink(s)); + } + + private void lambdaCapture(String input, Function fn) { + Function g = (x -> fn.apply(x)); + g.apply(input); + } } diff --git a/core/src/main/kotlin/org/opentaint/jvm/sast/project/SarifWebInfoAnnotator.kt b/core/src/main/kotlin/org/opentaint/jvm/sast/project/SarifWebInfoAnnotator.kt index be513e0e2..194b3a06c 100644 --- a/core/src/main/kotlin/org/opentaint/jvm/sast/project/SarifWebInfoAnnotator.kt +++ b/core/src/main/kotlin/org/opentaint/jvm/sast/project/SarifWebInfoAnnotator.kt @@ -93,7 +93,6 @@ abstract class SarifWebInfoAnnotator( ): Set { val methods = hashSetOf() methods.add(vulnerability.statement.location.method) - methods.add(vulnerability.methodEntryPoint.method) trace?.sourceToSinkTrace?.let { collectRelevantMethods(it, methods) } trace?.entryPointToStart?.let { collectRelevantMethods(it, methods) } diff --git a/core/src/test/kotlin/org/opentaint/jvm/sast/dataflow/JavaDataFlowReachabilityTest.kt b/core/src/test/kotlin/org/opentaint/jvm/sast/dataflow/JavaDataFlowReachabilityTest.kt index 43da02c7f..611b8c516 100644 --- a/core/src/test/kotlin/org/opentaint/jvm/sast/dataflow/JavaDataFlowReachabilityTest.kt +++ b/core/src/test/kotlin/org/opentaint/jvm/sast/dataflow/JavaDataFlowReachabilityTest.kt @@ -214,6 +214,20 @@ class JavaDataFlowReachabilityTest : AnalysisTest() { ) } + @Test + fun `lambda flow - taint tracked capture lambda`() { + val testCls = "$SAMPLE_PACKAGE.LambdaDataFlowSample" + val config = lambdaConfig(testCls) + + assertReachable( + config = config, + testCls = testCls, + entryPointName = "lambdaCaptureFlow", + ruleId = LAMBDA_RULE_ID, + testName = "lambda capture flow" + ) + } + @Test fun `lambda flow - taint tracked through lambda passed to method`() { val testCls = "$SAMPLE_PACKAGE.LambdaDataFlowSample"