diff --git a/conjure-java-core/src/main/java/com/palantir/conjure/java/types/SafetyEvaluator.java b/conjure-java-core/src/main/java/com/palantir/conjure/java/types/SafetyEvaluator.java index 0996a3123..1a2092894 100644 --- a/conjure-java-core/src/main/java/com/palantir/conjure/java/types/SafetyEvaluator.java +++ b/conjure-java-core/src/main/java/com/palantir/conjure/java/types/SafetyEvaluator.java @@ -16,8 +16,6 @@ package com.palantir.conjure.java.types; -import com.google.common.collect.ImmutableTable; -import com.google.common.collect.Table; import com.palantir.conjure.java.util.TypeFunctions; import com.palantir.conjure.spec.AliasDefinition; import com.palantir.conjure.spec.ArgumentDefinition; @@ -40,12 +38,10 @@ import com.palantir.logsafe.Preconditions; import java.util.HashMap; import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.function.Supplier; -import java.util.stream.Stream; public final class SafetyEvaluator { @@ -68,7 +64,7 @@ public final class SafetyEvaluator { // Memoization cache shared across all evaluate() calls on this instance. Avoids redundant recursive // traversals of the type graph, which otherwise dominate generation time for large definitions. - private final Map> cache = new HashMap<>(); + private final Map cache = new HashMap<>(); public SafetyEvaluator(ConjureDefinition definition) { this(TypeFunctions.toTypesMap(definition)); @@ -79,13 +75,13 @@ public SafetyEvaluator(Map definitionMap) { } public Optional evaluate(TypeDefinition def) { - return Preconditions.checkNotNull(def, "TypeDefinition is required") - .accept(new TypeDefinitionSafetyVisitor(definitionMap, cache, new HashSet<>())); + return fromValue(Preconditions.checkNotNull(def, "TypeDefinition is required") + .accept(new TypeDefinitionSafetyVisitor(definitionMap, cache, new HashSet<>()))); } public Optional evaluate(Type type) { - return Preconditions.checkNotNull(type, "TypeDefinition is required") - .accept(new TypeDefinitionSafetyVisitor(definitionMap, cache, new HashSet<>()).fieldVisitor); + return fromValue(Preconditions.checkNotNull(type, "TypeDefinition is required") + .accept(new TypeDefinitionSafetyVisitor(definitionMap, cache, new HashSet<>()).fieldVisitor)); } public Optional evaluate(Type type, Optional declaredSafety) { @@ -128,10 +124,23 @@ public Optional getUsageTimeSafety(FieldDefinition field) { return field.getSafety(); } - private static final class TypeDefinitionSafetyVisitor implements TypeDefinition.Visitor> { - private final Map> cache; + private static Optional fromValue(LogSafety.Value value) { + return switch (value) { + case SAFE -> Optional.of(LogSafety.SAFE); + case UNSAFE -> Optional.of(LogSafety.UNSAFE); + case DO_NOT_LOG -> Optional.of(LogSafety.DO_NOT_LOG); + case UNKNOWN -> Optional.empty(); + }; + } + + private static LogSafety.Value value(Optional safety) { + return safety.map(LogSafety::get).orElse(LogSafety.Value.UNKNOWN); + } + + private static final class TypeDefinitionSafetyVisitor implements TypeDefinition.Visitor { + private final Map cache; private final Set inProgress; - private final Type.Visitor> fieldVisitor; + private final Type.Visitor fieldVisitor; // Tracks whether cycle-breaking (the SAFE fallback for back-edges) was used anywhere // in the current evaluation subtree. Used to decide whether a result is safe to cache. @@ -139,7 +148,7 @@ private static final class TypeDefinitionSafetyVisitor implements TypeDefinition private TypeDefinitionSafetyVisitor( Map definitionMap, - Map> cache, + Map cache, Set inProgress) { this.cache = cache; this.inProgress = inProgress; @@ -147,19 +156,19 @@ private TypeDefinitionSafetyVisitor( } @Override - public Optional visitAlias(AliasDefinition value) { + public LogSafety.Value visitAlias(AliasDefinition value) { return with(value.getTypeName(), () -> getSafety(value.getAlias(), value.getSafety())); } @Override - public Optional visitEnum(EnumDefinition _value) { - return ENUM_SAFETY; + public LogSafety.Value visitEnum(EnumDefinition _value) { + return ENUM_SAFETY.get().get(); } @Override - public Optional visitObject(ObjectDefinition value) { + public LogSafety.Value visitObject(ObjectDefinition value) { return with(value.getTypeName(), () -> { - Optional safety = OPTIONAL_OF_SAFE; + LogSafety.Value safety = LogSafety.SAFE.get(); for (FieldDefinition field : value.getFields()) { safety = combine(safety, getSafety(field.getType(), field.getSafety())); } @@ -168,9 +177,9 @@ public Optional visitObject(ObjectDefinition value) { } @Override - public Optional visitUnion(UnionDefinition value) { + public LogSafety.Value visitUnion(UnionDefinition value) { return with(value.getTypeName(), () -> { - Optional safety = UNKNOWN_UNION_VARINT_SAFETY; + LogSafety.Value safety = LogSafety.Value.UNKNOWN; for (FieldDefinition variant : value.getUnion()) { safety = combine(safety, getSafety(variant.getType(), variant.getSafety())); } @@ -179,15 +188,15 @@ public Optional visitUnion(UnionDefinition value) { } @Override - public Optional visitUnknown(String unknownType) { + public LogSafety.Value visitUnknown(String unknownType) { throw new IllegalStateException("Unknown type: " + unknownType); } - private Optional with(TypeName typeName, Supplier> task) { + private LogSafety.Value with(TypeName typeName, Supplier task) { // Return memoized result if this type has already been fully evaluated. // Note: cache values are Optional which may be Optional.empty(), // so we check for null (absent key) rather than emptiness. - Optional cached = cache.get(typeName); + LogSafety.Value cached = cache.get(typeName); if (cached != null) { return cached; } @@ -195,14 +204,14 @@ private Optional with(TypeName typeName, Supplier // Given recursive evaluation, we return the least restrictive type: SAFE. // Mark that this subtree's result depends on cycle-breaking. encounteredCycle = true; - return OPTIONAL_OF_SAFE; + return LogSafety.Value.SAFE; } // Save and reset cycle state so we can detect cycles within this type's subtree only. boolean previousCycleState = encounteredCycle; encounteredCycle = false; - Optional result = task.get(); + LogSafety.Value result = task.get(); boolean subtreeHadCycle = encounteredCycle; // Propagate cycle detection upward: if this subtree had a cycle, callers should know. @@ -223,174 +232,171 @@ private Optional with(TypeName typeName, Supplier return result; } - private Optional getSafety(Type type, Optional safety) { - return safety.or(() -> type.accept(fieldVisitor)); + private LogSafety.Value getSafety(Type type, Optional safety) { + return safety.map(LogSafety::get).orElseGet(() -> type.accept(fieldVisitor)); } } - private static final class FieldSafetyVisitor implements Type.Visitor> { + private static final class FieldSafetyVisitor implements Type.Visitor { private final Map definitionMap; - private final TypeDefinition.Visitor> typeDefVisitor; + private final TypeDefinition.Visitor typeDefVisitor; FieldSafetyVisitor( - Map definitionMap, - TypeDefinition.Visitor> typeDefVisitor) { + Map definitionMap, TypeDefinition.Visitor typeDefVisitor) { this.definitionMap = definitionMap; this.typeDefVisitor = typeDefVisitor; } @Override - public Optional visitPrimitive(PrimitiveType value) { + public LogSafety.Value visitPrimitive(PrimitiveType value) { return value.accept(PrimitiveTypeSafetyVisitor.INSTANCE); } @Override - public Optional visitOptional(OptionalType value) { + public LogSafety.Value visitOptional(OptionalType value) { return value.getItemType().accept(this); } @Override - public Optional visitList(ListType value) { + public LogSafety.Value visitList(ListType value) { return value.getItemType().accept(this); } @Override - public Optional visitSet(SetType value) { + public LogSafety.Value visitSet(SetType value) { return value.getItemType().accept(this); } @Override - public Optional visitMap(MapType value) { - Optional keySafety = value.getKeyType().accept(this); - Optional valueSafety = value.getValueType().accept(this); + public LogSafety.Value visitMap(MapType value) { + LogSafety.Value keySafety = value.getKeyType().accept(this); + LogSafety.Value valueSafety = value.getValueType().accept(this); return combine(keySafety, valueSafety); } @Override - public Optional visitReference(TypeName value) { + public LogSafety.Value visitReference(TypeName value) { // inProgress is handled by TypeDefinitionSafetyVisitor - return Optional.ofNullable(definitionMap.get(value)).flatMap(item -> item.accept(typeDefVisitor)); + return Optional.ofNullable(definitionMap.get(value)) + .map(item -> item.accept(typeDefVisitor)) + .orElse(LogSafety.Value.UNKNOWN); } @Override - public Optional visitExternal(ExternalReference value) { - return value.getSafety(); + public LogSafety.Value visitExternal(ExternalReference value) { + return value(value.getSafety()); } @Override - public Optional visitUnknown(String unknownType) { + public LogSafety.Value visitUnknown(String unknownType) { throw new IllegalStateException("Unknown type: " + unknownType); } } - private enum PrimitiveTypeSafetyVisitor implements PrimitiveType.Visitor> { + private enum PrimitiveTypeSafetyVisitor implements PrimitiveType.Visitor { INSTANCE; @Override - public Optional visitString() { - return Optional.empty(); + public LogSafety.Value visitString() { + return LogSafety.Value.UNKNOWN; } @Override - public Optional visitDatetime() { - return Optional.empty(); + public LogSafety.Value visitDatetime() { + return LogSafety.Value.UNKNOWN; } @Override - public Optional visitInteger() { - return Optional.empty(); + public LogSafety.Value visitInteger() { + return LogSafety.Value.UNKNOWN; } @Override - public Optional visitDouble() { - return Optional.empty(); + public LogSafety.Value visitDouble() { + return LogSafety.Value.UNKNOWN; } @Override - public Optional visitSafelong() { - return Optional.empty(); + public LogSafety.Value visitSafelong() { + return LogSafety.Value.UNKNOWN; } @Override - public Optional visitBinary() { - return Optional.empty(); + public LogSafety.Value visitBinary() { + return LogSafety.Value.UNKNOWN; } @Override - public Optional visitAny() { - return Optional.empty(); + public LogSafety.Value visitAny() { + return LogSafety.Value.UNKNOWN; } @Override - public Optional visitBoolean() { - return Optional.empty(); + public LogSafety.Value visitBoolean() { + return LogSafety.Value.UNKNOWN; } @Override - public Optional visitUuid() { - return Optional.empty(); + public LogSafety.Value visitUuid() { + return LogSafety.Value.UNKNOWN; } @Override - public Optional visitRid() { - return Optional.empty(); + public LogSafety.Value visitRid() { + return LogSafety.Value.UNKNOWN; } @Override - public Optional visitBearertoken() { - return Optional.of(LogSafety.DO_NOT_LOG); + public LogSafety.Value visitBearertoken() { + return LogSafety.Value.DO_NOT_LOG; } @Override - public Optional visitUnknown(String unknownValue) { + public LogSafety.Value visitUnknown(String unknownValue) { throw new IllegalStateException("Unknown primitive type: " + unknownValue); } } - private static final Table, Optional, Optional> COMBINE_TABLE = - computeCombineTable(); - - private static Table, Optional, Optional> computeCombineTable() { - List> allValues = Stream.concat( - Stream.of(Optional.empty()), - LogSafety.values().stream().map(Optional::of)) - .toList(); - ImmutableTable.Builder, Optional, Optional> builder = - ImmutableTable.builder(); - for (Optional left : allValues) { - for (Optional right : allValues) { - builder.put(left, right, computeCombine(left, right)); - } - } - return builder.buildOrThrow(); - } + private static final int SAFETY_VALUE_COUNT = LogSafety.Value.values().length; + private static final LogSafety.Value[] LOG_SAFETY_TABLE = computeLogSafetyTable(); - private static Optional computeCombine(Optional one, Optional two) { - if (one.isPresent() && two.isPresent()) { - return Optional.of(combine(one.get(), two.get())); + private static LogSafety.Value[] computeLogSafetyTable() { + LogSafety.Value[] table = new LogSafety.Value[SAFETY_VALUE_COUNT * SAFETY_VALUE_COUNT]; + for (LogSafety.Value left : LogSafety.Value.values()) { + for (LogSafety.Value right : LogSafety.Value.values()) { + table[index(left, right)] = combineInternal(left, right); + } } - return one.or(() -> two) - // When one value is unknown, we cannot assume the other is safe - .filter(value -> !LogSafety.SAFE.equals(value)); + return table; } - private static LogSafety combine(LogSafety one, LogSafety two) { - LogSafety.Value first = one.get(); - LogSafety.Value second = two.get(); - if (first == LogSafety.Value.UNKNOWN || second == LogSafety.Value.UNKNOWN) { - throw new IllegalStateException("Unable to compare LogSafety values: " + one + " and " + two); + private static LogSafety.Value combineInternal(LogSafety.Value left, LogSafety.Value right) { + if (left == LogSafety.Value.DO_NOT_LOG || right == LogSafety.Value.DO_NOT_LOG) { + return LogSafety.Value.DO_NOT_LOG; } - if (first == LogSafety.Value.DO_NOT_LOG || second == LogSafety.Value.DO_NOT_LOG) { - return LogSafety.DO_NOT_LOG; + if (left == LogSafety.Value.UNSAFE || right == LogSafety.Value.UNSAFE) { + return LogSafety.Value.UNSAFE; } - if (first == LogSafety.Value.UNSAFE || second == LogSafety.Value.UNSAFE) { - return LogSafety.UNSAFE; + if (left == LogSafety.Value.UNKNOWN || right == LogSafety.Value.UNKNOWN) { + return LogSafety.Value.UNKNOWN; } - return one; + return LogSafety.Value.SAFE; + } + + @SuppressWarnings("EnumOrdinal") + private static int index(LogSafety.Value left, LogSafety.Value right) { + return left.ordinal() * SAFETY_VALUE_COUNT + right.ordinal(); } + // TODO(aldexis): remove this public static Optional combine(Optional one, Optional two) { - return Preconditions.checkNotNull(COMBINE_TABLE.get(one, two), "Missing an entry in the combine table"); + LogSafety.Value combined = combine(value(one), value(two)); + + return fromValue(combined); + } + + private static LogSafety.Value combine(LogSafety.Value one, LogSafety.Value two) { + return Preconditions.checkNotNull(LOG_SAFETY_TABLE[index(one, two)], "Missing an entry in the combine table"); } public static boolean allows(Optional required, Optional given) {