From d019d4342b787386d57a7b151959fcbc64a24ef5 Mon Sep 17 00:00:00 2001 From: Johannes Wolf Date: Mon, 20 Apr 2026 22:01:32 +0200 Subject: [PATCH 1/3] Basic Schema Implementation --- include/simfil/environment.h | 10 + include/simfil/expression-visitor.h | 2 + include/simfil/expression.h | 49 ++- include/simfil/model/model.h | 13 +- include/simfil/model/nodes.h | 16 +- include/simfil/model/nodes.impl.h | 24 ++ include/simfil/model/schema.h | 273 ++++++++++++++++ include/simfil/parser.h | 6 - include/simfil/value.h | 2 +- src/completion.cpp | 18 +- src/completion.h | 8 +- src/diagnostics.cpp | 7 + src/environment.cpp | 7 + src/expression-visitor.cpp | 57 +--- src/expressions.cpp | 355 +++++++++++++++++---- src/expressions.h | 160 +++++++--- src/model/model.cpp | 136 +++++++- src/model/nodes.cpp | 12 + src/parser.cpp | 9 +- src/rewrite-rules.h | 106 +++++++ src/simfil.cpp | 176 ++++++----- test/CMakeLists.txt | 1 + test/common.hpp | 8 + test/performance.cpp | 1 - test/schema.cpp | 463 ++++++++++++++++++++++++++++ test/simfil.cpp | 31 +- test/value.cpp | 3 + 27 files changed, 1688 insertions(+), 265 deletions(-) create mode 100644 include/simfil/model/schema.h create mode 100644 src/rewrite-rules.h create mode 100644 test/schema.cpp diff --git a/include/simfil/environment.h b/include/simfil/environment.h index 03838728..7251ae9a 100644 --- a/include/simfil/environment.h +++ b/include/simfil/environment.h @@ -22,6 +22,7 @@ namespace simfil class Expr; class Function; class Diagnostics; +class Schema; struct ResultFn; struct Debug; @@ -61,6 +62,8 @@ struct Trace struct Environment { public: + using QuerySchemaCallback = std::function; + /** * Construct a SIMFIL execution environment with a string cache, * which is used to map field names to short integer IDs. @@ -116,6 +119,12 @@ struct Environment [[nodiscard]] auto strings() const -> std::shared_ptr; + /** + * Query an object schema by its schema id. + * Returns nullptr if no callback is configured or the schema is unknown. + */ + auto querySchema(SchemaId schemaId) const -> const Schema*; + public: std::unique_ptr warnMtx; std::vector> warnings; @@ -129,6 +138,7 @@ struct Environment /* constant ident -> value */ std::map constants; + QuerySchemaCallback querySchemaCallback; Debug* debug = nullptr; std::shared_ptr stringPool; }; diff --git a/include/simfil/expression-visitor.h b/include/simfil/expression-visitor.h index d01ccdfd..29308ed8 100644 --- a/include/simfil/expression-visitor.h +++ b/include/simfil/expression-visitor.h @@ -21,6 +21,7 @@ class UnpackExpr; class UnaryWordOpExpr; class BinaryWordOpExpr; class FieldExpr; +class WildcardFieldExpr; class PathExpr; class AndExpr; class OrExpr; @@ -54,6 +55,7 @@ class ExprVisitor virtual void visit(const CallExpression& expr); virtual void visit(const PathExpr& expr); virtual void visit(const FieldExpr& expr); + virtual void visit(const WildcardFieldExpr& expr); virtual void visit(const UnpackExpr& expr); virtual void visit(const UnaryWordOpExpr& expr); virtual void visit(const BinaryWordOpExpr& expr); diff --git a/include/simfil/expression.h b/include/simfil/expression.h index 39692f9d..e75e293f 100644 --- a/include/simfil/expression.h +++ b/include/simfil/expression.h @@ -8,12 +8,16 @@ #include "simfil/result.h" #include +#include namespace simfil { +class Expr; class ExprVisitor; +using ExprPtr = std::unique_ptr; + class Expr { friend class AST; @@ -31,17 +35,16 @@ class Expr VALUE, }; - Expr() = delete; - explicit Expr(ExprId id) - : id_(id) - {} - explicit Expr(ExprId id, const Token& token) - : id_(id) + Expr() = default; + explicit Expr(const Token& token) { assert(token.end >= token.begin); sourceLocation_.offset = token.begin; sourceLocation_.size = token.end - token.begin; } + explicit Expr(SourceLocation location) + : sourceLocation_(location) + {} virtual ~Expr() = default; @@ -56,6 +59,28 @@ class Expr return false; } + /* Accept expression visitor */ + virtual auto accept(ExprVisitor& v) const -> void = 0; + + /* Get the number of child expressions */ + virtual auto numChildren() const -> std::size_t + { + return 0; + } + + /* Get the n-th child expression */ + virtual auto childAt(std::size_t index) -> ExprPtr& + { + if (numChildren() == 0) + throw std::out_of_range("AST Child index out of range"); + throw std::runtime_error("Missing childAt function implementation"); + } + + virtual auto childAt(std::size_t index) const -> const ExprPtr& + { + return const_cast(*this).childAt(index); + } + /* Debug */ virtual auto toString() const -> std::string = 0; @@ -90,11 +115,7 @@ class Expr return ieval(ctx, std::move(val), res); } - /* Accept expression visitor */ - virtual auto accept(ExprVisitor& v) const -> void = 0; - /* Source location the expression got parsed from */ - [[nodiscard]] auto sourceLocation() const -> SourceLocation { return sourceLocation_; @@ -110,12 +131,10 @@ class Expr return ieval(ctx, value, result); } - ExprId id_; + ExprId id_ = 0; SourceLocation sourceLocation_; }; -using ExprPtr = std::unique_ptr; - class AST { public: @@ -126,6 +145,8 @@ class AST ~AST(); + auto reenumerate() -> void; + auto expr() const -> const Expr& { return *expr_; @@ -137,6 +158,8 @@ class AST } private: + static auto reenumerate(Expr& expr, Expr::ExprId& nextId) -> void; + /* The original query string of the AST */ std::string queryString_; diff --git a/include/simfil/model/model.h b/include/simfil/model/model.h index 31951867..0ca6bcec 100644 --- a/include/simfil/model/model.h +++ b/include/simfil/model/model.h @@ -2,6 +2,7 @@ #pragma once #include "simfil/model/string-pool.h" +#include "simfil/model/schema.h" #include "simfil/byte-array.h" #include "tl/expected.hpp" #if defined(SIMFIL_WITH_MODEL_JSON) @@ -274,7 +275,9 @@ class ModelPool : public Model size_t stringDataBytes = 0; size_t stringRangeBytes = 0; size_t objectMemberBytes = 0; + size_t objectSchemaBytes = 0; size_t arrayMemberBytes = 0; + size_t arraySchemaBytes = 0; [[nodiscard]] size_t totalBytes() const { @@ -284,7 +287,9 @@ class ModelPool : public Model + stringDataBytes + stringRangeBytes + objectMemberBytes - + arrayMemberBytes; + + objectSchemaBytes + + arrayMemberBytes + + arraySchemaBytes; } }; @@ -299,12 +304,18 @@ class ModelPool : public Model struct Impl; std::unique_ptr impl_; + [[nodiscard]] SchemaId objectSchemaId(ArrayIndex members) const; + auto setObjectSchemaId(ArrayIndex members, SchemaId schemaId) -> tl::expected; + [[nodiscard]] SchemaId arraySchemaId(ArrayIndex members) const; + auto setArraySchemaId(ArrayIndex members, SchemaId schemaId) -> tl::expected; + /** * Protected object/array member storage access, * so derived ModelPools can create Object/Array-derived nodes. */ Object::Storage& objectMemberStorage(); [[nodiscard]] Object::Storage const& objectMemberStorage() const; + Array::Storage& arrayMemberStorage(); [[nodiscard]] Array::Storage const& arrayMemberStorage() const; }; diff --git a/include/simfil/model/nodes.h b/include/simfil/model/nodes.h index d25da4bb..afed34f4 100644 --- a/include/simfil/model/nodes.h +++ b/include/simfil/model/nodes.h @@ -8,6 +8,7 @@ #include #include "arena.h" +#include "schema.h" #include "string-pool.h" #include "simfil/byte-array.h" #include "simfil/error.h" @@ -56,8 +57,9 @@ enum class ValueType Bytes, TransientObject, Object, - Array - // If you add types, update TypeFlags::flags bit size! + Array, + // End + LAST_ }; using ScalarValueType = std::variant< @@ -276,6 +278,9 @@ struct ModelNode /// Get an Object model's field names [[nodiscard]] virtual StringId keyAt(int64_t i) const; + /// Get the schema id for schema-aware container nodes, or NoSchemaId otherwise. + [[nodiscard]] virtual SchemaId schema() const; + /// Get the number of children [[nodiscard]] virtual uint32_t size() const; @@ -428,6 +433,7 @@ struct ModelNodeBase : public ModelNode [[nodiscard]] ModelNode::Ptr get(const StringId&) const override; [[nodiscard]] ModelNode::Ptr at(int64_t) const override; [[nodiscard]] StringId keyAt(int64_t) const override; + [[nodiscard]] SchemaId schema() const override; [[nodiscard]] uint32_t size() const override; bool iterate(IterCallback const&) const override {return true;} // NOLINT (allow discard) @@ -544,6 +550,9 @@ struct BaseArray : public MandatoryDerivedModelNodeBase bool forEach(std::function const& callback) const; + [[nodiscard]] SchemaId schema() const override; + auto setSchema(SchemaId schemaId) -> tl::expected; + [[nodiscard]] ValueType type() const override; [[nodiscard]] ModelNode::Ptr at(int64_t) const override; [[nodiscard]] uint32_t size() const override; @@ -607,6 +616,9 @@ struct BaseObject : public MandatoryDerivedModelNodeBase return addFieldInternal(name, static_cast(value)); } + [[nodiscard]] SchemaId schema() const override; + auto setSchema(SchemaId schemaId) -> tl::expected; + [[nodiscard]] ValueType type() const override; [[nodiscard]] ModelNode::Ptr at(int64_t) const override; [[nodiscard]] uint32_t size() const override; diff --git a/include/simfil/model/nodes.impl.h b/include/simfil/model/nodes.impl.h index d8a08e7e..09adbb8d 100644 --- a/include/simfil/model/nodes.impl.h +++ b/include/simfil/model/nodes.impl.h @@ -21,6 +21,18 @@ ValueType BaseArray::type() const return ValueType::Array; } +template +SchemaId BaseArray::schema() const +{ + return model().arraySchemaId(members_); +} + +template +auto BaseArray::setSchema(SchemaId schemaId) -> tl::expected +{ + return model().setArraySchemaId(members_, schemaId); +} + template ModelNode::Ptr BaseArray::at(int64_t i) const { @@ -96,6 +108,18 @@ ValueType BaseObject::type() const return ValueType::Object; } +template +SchemaId BaseObject::schema() const +{ + return model().objectSchemaId(members_); +} + +template +auto BaseObject::setSchema(SchemaId schemaId) -> tl::expected +{ + return model().setObjectSchemaId(members_, schemaId); +} + template ModelNode::Ptr BaseObject::at(int64_t i) const { diff --git a/include/simfil/model/schema.h b/include/simfil/model/schema.h new file mode 100644 index 00000000..27b047b2 --- /dev/null +++ b/include/simfil/model/schema.h @@ -0,0 +1,273 @@ +#pragma once + +#include "simfil/model/string-pool.h" +#include +#include +#include +#include +#include +#include +#include + +namespace simfil +{ + +class Schema; + +using SchemaId = std::uint16_t; +constexpr SchemaId NoSchemaId = SchemaId{0}; +constexpr SchemaId MaxSchemaId = SchemaId{std::numeric_limits::max()}; + +/** + * Concept defining a callback to query a Schema* by SchemaId. + */ +template +concept QuerySchemaFn = requires(const Fn& fn) { + { fn(SchemaId{}) } -> std::convertible_to; +}; +template +concept QueryMutableSchemaFn = requires(const Fn& fn) { + { fn(SchemaId{}) } -> std::convertible_to; +}; + +/** + * + */ +class Schema +{ +public: + /** Schema kind */ + enum class Kind { + Object, + Array, + }; + + /** Finalization state */ + enum class State { + Dirty, + Finalizing, + Clean, + }; + + virtual ~Schema() = default; + + /** + * Return this schemas kind. + */ + virtual auto kind() const -> Kind = 0; + + /** + * Returns true if this schema or any of the schemas it refers to + * can possibly contain the given field. + * + * @param fieldId The field id to query the schema for + */ + virtual auto canHaveField(StringId fieldId) const -> bool = 0; + + /** + * Finalize this schema and all schemas it refers to. + * + * @param queryFn Schema Query callback. + */ + virtual auto finalize(const std::function& queryFn) -> State + { + return State::Clean; + } + + /** + * @return All nested field names. + */ + virtual auto nestedFields() const & -> std::span = 0; +}; + +/** + * Schema for object nodes. + * + * Stores direct fields and optional child schema ids per field. After + * `finalize()` it also caches all reachable child fields. + */ +class ObjectSchema : public Schema +{ +public: + struct FieldSummary { + StringId field = 0; + sfl::small_vector schemas; + + auto operator<=>(const FieldSummary& other) const + { + return field <=> other.field; + } + }; + + auto kind() const -> Kind override + { + return Kind::Object; + } + + auto canHaveField(StringId field) const -> bool override + { + // Be conservative if the schema has not been finalized. + if (state_ != State::Clean) + return true; + + auto iter = std::lower_bound(flatFields_.begin(), flatFields_.end(), field); + return iter != flatFields_.end() && *iter == field; + } + + /** + * Add a direct field and optional child schemas reachable through it. + */ + auto addField(StringId field, std::initializer_list schemas = {}) -> void + { + FieldSummary summary; + summary.field = field; + summary.schemas.insert(summary.schemas.end(), schemas.begin(), schemas.end()); + fields_.push_back(std::move(summary)); + state_ = State::Dirty; + } + + /** + * Recompute the cached descendant field set from this schema and all + * reachable child schemas. + */ + auto finalize(const std::function& lookup) -> State override + { + if (state_ == State::Clean || state_ == State::Finalizing) + return state_; + + state_ = State::Finalizing; + flatFields_.clear(); + auto canFinalize = true; + + for (const auto& field : fields_) { + flatFields_.push_back(field.field); + for (const auto& fieldSchemaId : field.schemas) { + if (auto* childSchema = lookup(fieldSchemaId)) { + auto childState = childSchema->finalize(lookup); + if (childState != State::Clean) { + canFinalize = false; + continue; + } + + auto childFields = childSchema->nestedFields(); + flatFields_.insert(flatFields_.end(), childFields.begin(), childFields.end()); + } + } + } + + if (!canFinalize) { + flatFields_.clear(); + state_ = State::Dirty; + return State::Dirty; + } + + std::sort(flatFields_.begin(), flatFields_.end()); + flatFields_.erase(std::unique(flatFields_.begin(), flatFields_.end()), flatFields_.end()); + state_ = State::Clean; + return State::Clean; + } + + auto fields() const & -> std::span + { + return {fields_.begin(), fields_.end()}; + } + + auto nestedFields() const & -> std::span override + { + return {flatFields_.cbegin(), flatFields_.cend()}; + } + +private: + sfl::small_vector fields_; + + std::vector flatFields_; // Ordered! + State state_ = State::Dirty; +}; + +/** + * Schema for array nodes. + * + * Stores the set of possible element schemas. After `finalize()` it caches + * all fields reachable through any element schema. + */ +class ArraySchema : public Schema +{ +public: + auto kind() const -> Kind override + { + return Kind::Array; + } + + auto canHaveField(StringId field) const -> bool override + { + if (state_ != State::Clean) + return true; + + auto iter = std::lower_bound(flatFields_.begin(), flatFields_.end(), field); + return iter != flatFields_.end() && *iter == field; + } + + /** + * Add possible schemas for elements contained in the array. + */ + auto addElementSchemas(std::initializer_list schemas) -> void + { + schemas_.insert(schemas_.end(), schemas.begin(), schemas.end()); + state_ = State::Dirty; + } + + /** + * Recompute the cached descendant field set from all possible element + * schemas. + */ + auto finalize(const std::function& lookup) -> State override + { + if (state_ == State::Clean || state_ == State::Finalizing) + return state_; + + state_ = State::Finalizing; + flatFields_.clear(); + auto canFinalize = true; + + for (const auto& schemaId : schemas_) { + if (auto* childSchema = lookup(schemaId)) { + auto childState = childSchema->finalize(lookup); + if (childState != State::Clean) { + canFinalize = false; + continue; + } + + auto childFields = childSchema->nestedFields(); + flatFields_.insert(flatFields_.end(), childFields.begin(), childFields.end()); + } + } + + if (!canFinalize) { + flatFields_.clear(); + state_ = State::Dirty; + return State::Dirty; + } + + std::sort(flatFields_.begin(), flatFields_.end()); + flatFields_.erase(std::unique(flatFields_.begin(), flatFields_.end()), flatFields_.end()); + state_ = State::Clean; + return State::Clean; + } + + auto nestedFields() const & -> std::span override + { + return {flatFields_.cbegin(), flatFields_.cend()}; + } + + auto elementSchemas() const & -> std::span + { + return {schemas_.begin(), schemas_.end()}; + } + +private: + sfl::small_vector schemas_; + std::vector flatFields_; // Ordered! + State state_ = State::Dirty; +}; + +} diff --git a/include/simfil/parser.h b/include/simfil/parser.h index 0999e500..43fe4c2b 100644 --- a/include/simfil/parser.h +++ b/include/simfil/parser.h @@ -59,7 +59,6 @@ class Parser }; struct Context { - Expr::ExprId id = 0; bool inPath = false; }; @@ -104,11 +103,6 @@ class Parser auto mode() const -> Mode; auto relaxed() const -> bool; - /** - * Get the next expression id. - */ - auto nextId() -> Expr::ExprId; - Context ctx; Environment* const env; std::unordered_map prefixParsers; diff --git a/include/simfil/value.h b/include/simfil/value.h index 3805a75d..fe35d892 100644 --- a/include/simfil/value.h +++ b/include/simfil/value.h @@ -104,7 +104,7 @@ inline auto valueType2String(ValueType t) -> const char* */ struct TypeFlags { - std::bitset<10> flags; + std::bitset(ValueType::LAST_)> flags; auto test(ValueType type) const { diff --git a/src/completion.cpp b/src/completion.cpp index 83f3be14..eefd32ce 100644 --- a/src/completion.cpp +++ b/src/completion.cpp @@ -125,8 +125,8 @@ auto completeWords(const simfil::Context& ctx, std::string_view prefix, simfil:: namespace simfil { -CompletionFieldOrWordExpr::CompletionFieldOrWordExpr(ExprId id, std::string prefix, Completion* comp, const Token& token, bool inPath) - : Expr(id, token) +CompletionFieldOrWordExpr::CompletionFieldOrWordExpr(std::string prefix, Completion* comp, const Token& token, bool inPath) + : Expr(token) , prefix_(std::move(prefix)) , comp_(comp) , inPath_(inPath) @@ -225,9 +225,8 @@ struct FindExpressionRange : ExprVisitor } -CompletionAndExpr::CompletionAndExpr(ExprId id, ExprPtr left, ExprPtr right, const Completion* comp) - : Expr(id) - , left_(std::move(left)) +CompletionAndExpr::CompletionAndExpr(ExprPtr left, ExprPtr right, const Completion* comp) + : left_(std::move(left)) , right_(std::move(right)) { FindExpressionRange leftRange; @@ -281,9 +280,8 @@ auto CompletionAndExpr::toString() const -> std::string return "(and ? ?)"; } -CompletionOrExpr::CompletionOrExpr(ExprId id, ExprPtr left, ExprPtr right, const Completion* comp) - : Expr(id) - , left_(std::move(left)) +CompletionOrExpr::CompletionOrExpr(ExprPtr left, ExprPtr right, const Completion* comp) + : left_(std::move(left)) , right_(std::move(right)) { FindExpressionRange leftRange; @@ -337,8 +335,8 @@ auto CompletionOrExpr::toString() const -> std::string return "(or ? ?)"; } -CompletionWordExpr::CompletionWordExpr(ExprId id, std::string prefix, Completion* comp, const Token& token) - : Expr(id, token) +CompletionWordExpr::CompletionWordExpr(std::string prefix, Completion* comp, const Token& token) + : Expr(token) , prefix_(std::move(prefix)) , comp_(comp) {} diff --git a/src/completion.h b/src/completion.h index 8f9c51ba..13b8ad0e 100644 --- a/src/completion.h +++ b/src/completion.h @@ -45,7 +45,7 @@ struct Completion class CompletionFieldOrWordExpr : public Expr { public: - CompletionFieldOrWordExpr(ExprId id, std::string prefix, Completion* comp, const Token& token, bool inPath); + CompletionFieldOrWordExpr(std::string prefix, Completion* comp, const Token& token, bool inPath); auto type() const -> Type override; auto ieval(Context ctx, const Value& value, const ResultFn& result) const -> tl::expected override; @@ -60,7 +60,7 @@ class CompletionFieldOrWordExpr : public Expr class CompletionAndExpr : public Expr { public: - CompletionAndExpr(ExprId id, ExprPtr left, ExprPtr right, const Completion* comp); + CompletionAndExpr(ExprPtr left, ExprPtr right, const Completion* comp); auto type() const -> Type override; auto ieval(Context ctx, const Value& val, const ResultFn& res) const -> tl::expected override; @@ -73,7 +73,7 @@ class CompletionAndExpr : public Expr class CompletionOrExpr : public Expr { public: - CompletionOrExpr(ExprId id, ExprPtr left, ExprPtr right, const Completion* comp); + CompletionOrExpr(ExprPtr left, ExprPtr right, const Completion* comp); auto type() const -> Type override; auto ieval(Context ctx, const Value& val, const ResultFn& res) const -> tl::expected override; @@ -86,7 +86,7 @@ class CompletionOrExpr : public Expr class CompletionWordExpr : public Expr { public: - CompletionWordExpr(ExprId id, std::string prefix, Completion* comp, const Token& token); + CompletionWordExpr(std::string prefix, Completion* comp, const Token& token); auto type() const -> Type override; auto constant() const -> bool override; diff --git a/src/diagnostics.cpp b/src/diagnostics.cpp index 062f4fdf..bce1fb68 100644 --- a/src/diagnostics.cpp +++ b/src/diagnostics.cpp @@ -253,6 +253,13 @@ auto Diagnostics::prepareIndices(const Expr& ast) -> void indices_[e.id()] = fieldIndex_++; } + auto visit(const WildcardFieldExpr& e) -> void override { + ExprVisitor::visit(e); + if (e.id() >= indices_.size()) + indices_.resize(e.id() + 1, Diagnostics::InvalidIndex); + indices_[e.id()] = fieldIndex_++; + } + auto visitComparisonOperator(const ComparisonExprBase& e) -> void { if (e.id() >= indices_.size()) diff --git a/src/environment.cpp b/src/environment.cpp index 88bbd07b..279a9656 100644 --- a/src/environment.cpp +++ b/src/environment.cpp @@ -59,6 +59,13 @@ auto Environment::strings() const -> std::shared_ptr { return stringPool; } +auto Environment::querySchema(SchemaId schemaId) const -> const Schema* +{ + if (!querySchemaCallback || schemaId == NoSchemaId) + return nullptr; + return querySchemaCallback(schemaId); +} + Context::Context(Environment* env, Diagnostics* diag, Context::Phase phase) : env(env) , diag(diag) diff --git a/src/expression-visitor.cpp b/src/expression-visitor.cpp index ec5a2acf..54f7589e 100644 --- a/src/expression-visitor.cpp +++ b/src/expression-visitor.cpp @@ -13,6 +13,10 @@ ExprVisitor::~ExprVisitor() = default; void ExprVisitor::visit(const Expr& e) { index_++; + + const auto count = e.numChildren(); + for (auto i = 0; i < count; ++i) + e.childAt(i)->accept(*this); } void ExprVisitor::visit(const WildcardExpr& expr) @@ -38,58 +42,31 @@ void ExprVisitor::visit(const ConstExpr& expr) void ExprVisitor::visit(const SubscriptExpr& expr) { visit(static_cast(expr)); - - if (expr.left_) - expr.left_->accept(*this); - if (expr.index_) - expr.index_->accept(*this); } void ExprVisitor::visit(const SubExpr& expr) { visit(static_cast(expr)); - - if (expr.left_) - expr.left_->accept(*this); - if (expr.sub_) - expr.sub_->accept(*this); } void ExprVisitor::visit(const AnyExpr& expr) { visit(static_cast(expr)); - - for (const auto& arg : expr.args_) - if (arg) - arg->accept(*this); } void ExprVisitor::visit(const EachExpr& expr) { visit(static_cast(expr)); - - for (const auto& arg : expr.args_) - if (arg) - arg->accept(*this); } void ExprVisitor::visit(const CallExpression& expr) { visit(static_cast(expr)); - - for (const auto& arg : expr.args_) - if (arg) - arg->accept(*this); } void ExprVisitor::visit(const PathExpr& expr) { visit(static_cast(expr)); - - if (expr.left_) - expr.left_->accept(*this); - if (expr.right_) - expr.right_->accept(*this); } void ExprVisitor::visit(const FieldExpr& expr) @@ -97,50 +74,34 @@ void ExprVisitor::visit(const FieldExpr& expr) visit(static_cast(expr)); } -void ExprVisitor::visit(const UnpackExpr& expr) +void ExprVisitor::visit(const WildcardFieldExpr& expr) { visit(static_cast(expr)); +} - if (expr.sub_) - expr.sub_->accept(*this); +void ExprVisitor::visit(const UnpackExpr& expr) +{ + visit(static_cast(expr)); } void ExprVisitor::visit(const UnaryWordOpExpr& expr) { visit(static_cast(expr)); - - if (expr.left_) - expr.left_->accept(*this); } void ExprVisitor::visit(const BinaryWordOpExpr& expr) { visit(static_cast(expr)); - - if (expr.left_) - expr.left_->accept(*this); - if (expr.right_) - expr.right_->accept(*this); } void ExprVisitor::visit(const AndExpr& expr) { visit(static_cast(expr)); - - if (expr.left_) - expr.left_->accept(*this); - if (expr.right_) - expr.right_->accept(*this); } void ExprVisitor::visit(const OrExpr& expr) { visit(static_cast(expr)); - - if (expr.left_) - expr.left_->accept(*this); - if (expr.right_) - expr.right_->accept(*this); } void ExprVisitor::visit(const BinaryExpr& e) diff --git a/src/expressions.cpp b/src/expressions.cpp index 82d77d2b..a91d6b2a 100644 --- a/src/expressions.cpp +++ b/src/expressions.cpp @@ -2,6 +2,8 @@ #include "fmt/format.h" #include "simfil/environment.h" +#include "simfil/expression.h" +#include "simfil/model/string-pool.h" #include "simfil/result.h" #include "simfil/value.h" #include "simfil/function.h" @@ -77,9 +79,7 @@ auto boolify(const Value& v) -> bool } -WildcardExpr::WildcardExpr(ExprId id) - : Expr(id) -{} +WildcardExpr::WildcardExpr() = default; auto WildcardExpr::type() const -> Type { @@ -143,9 +143,7 @@ auto WildcardExpr::toString() const -> std::string return "**"s; } -AnyChildExpr::AnyChildExpr(ExprId id) - : Expr(id) -{} +AnyChildExpr::AnyChildExpr() = default; auto AnyChildExpr::type() const -> Type { @@ -186,14 +184,12 @@ auto AnyChildExpr::toString() const -> std::string return "*"s; } -FieldExpr::FieldExpr(ExprId id, std::string name) - : Expr(id) - , name_(std::move(name)) +FieldExpr::FieldExpr(std::string name) + : name_(std::move(name)) {} -FieldExpr::FieldExpr(ExprId id, std::string name, const Token& token) - : Expr(id, token) - , name_(std::move(name)) +FieldExpr::FieldExpr(std::string name, const Token& token) + : name_(std::move(name)) {} auto FieldExpr::type() const -> Type @@ -262,14 +258,12 @@ auto FieldExpr::toString() const -> std::string return name_; } -MultiConstExpr::MultiConstExpr(ExprId id, const std::vector& vec) - : Expr(id) - , values_(vec) +MultiConstExpr::MultiConstExpr(const std::vector& vec) + : values_(vec) {} -MultiConstExpr::MultiConstExpr(ExprId id, std::vector&& vec) - : Expr(id) - , values_(std::move(vec)) +MultiConstExpr::MultiConstExpr(std::vector&& vec) + : values_(std::move(vec)) {} auto MultiConstExpr::type() const -> Type @@ -308,9 +302,8 @@ auto MultiConstExpr::toString() const -> std::string return fmt::format("{{{}}}", fmt::join(items, " ")); } -ConstExpr::ConstExpr(ExprId id, Value value) - : Expr(id) - , value_(std::move(value)) +ConstExpr::ConstExpr(Value value) + : value_(std::move(value)) {} auto ConstExpr::type() const -> Type @@ -345,9 +338,8 @@ auto ConstExpr::value() const -> const Value& return value_; } -SubscriptExpr::SubscriptExpr(ExprId id, ExprPtr left, ExprPtr index) - : Expr(id) - , left_(std::move(left)) +SubscriptExpr::SubscriptExpr(ExprPtr left, ExprPtr index) + : left_(std::move(left)) , index_(std::move(index)) {} @@ -400,14 +392,30 @@ void SubscriptExpr::accept(ExprVisitor& v) const v.visit(*this); } +auto SubscriptExpr::numChildren() const -> std::size_t +{ + return 2; +} + +auto SubscriptExpr::childAt(std::size_t index) -> ExprPtr& +{ + switch (index) { + case 0: + return left_; + case 1: + return index_; + default: + return Expr::childAt(index); + } +} + auto SubscriptExpr::toString() const -> std::string { return fmt::format("(index {} {})", left_->toString(), index_->toString()); } -SubExpr::SubExpr(ExprId id, ExprPtr left, ExprPtr sub) - : Expr(id) - , left_(std::move(left)) +SubExpr::SubExpr(ExprPtr left, ExprPtr sub) + : left_(std::move(left)) , sub_(std::move(sub)) {} @@ -453,9 +461,25 @@ void SubExpr::accept(ExprVisitor& v) const v.visit(*this); } -AnyExpr::AnyExpr(ExprId id, std::vector args) - : Expr(id) - , args_(std::move(args)) +auto SubExpr::numChildren() const -> std::size_t +{ + return 2; +} + +auto SubExpr::childAt(std::size_t index) -> ExprPtr& +{ + switch (index) { + case 0: + return left_; + case 1: + return sub_; + default: + return Expr::childAt(index); + } +} + +AnyExpr::AnyExpr(std::vector args) + : args_(std::move(args)) {} auto AnyExpr::type() const -> Type @@ -497,6 +521,16 @@ auto AnyExpr::accept(ExprVisitor& v) const -> void v.visit(*this); } +auto AnyExpr::numChildren() const -> std::size_t +{ + return args_.size(); +} + +auto AnyExpr::childAt(std::size_t index) -> ExprPtr& +{ + return args_[index]; +} + auto AnyExpr::toString() const -> std::string { if (args_.empty()) @@ -509,9 +543,8 @@ auto AnyExpr::toString() const -> std::string return fmt::format("(any {})", fmt::join(items, " ")); } -EachExpr::EachExpr(ExprId id, std::vector args) - : Expr(id) - , args_(std::move(args)) +EachExpr::EachExpr(std::vector args) + : args_(std::move(args)) {} auto EachExpr::type() const -> Type @@ -552,6 +585,16 @@ auto EachExpr::accept(ExprVisitor& v) const -> void v.visit(*this); } +auto EachExpr::numChildren() const -> std::size_t +{ + return args_.size(); +} + +auto EachExpr::childAt(std::size_t index) -> ExprPtr& +{ + return args_[index]; +} + auto EachExpr::toString() const -> std::string { if (args_.empty()) @@ -564,9 +607,8 @@ auto EachExpr::toString() const -> std::string return fmt::format("(each {})", fmt::join(items, " ")); } -CallExpression::CallExpression(ExprId id, std::string name, std::vector args) - : Expr(id) - , name_(std::move(name)) +CallExpression::CallExpression(std::string name, std::vector args) + : name_(std::move(name)) , args_(std::move(args)) {} @@ -606,6 +648,16 @@ void CallExpression::accept(ExprVisitor& v) const v.visit(*this); } +auto CallExpression::numChildren() const -> std::size_t +{ + return args_.size(); +} + +auto CallExpression::childAt(std::size_t index) -> ExprPtr& +{ + return args_[index]; +} + auto CallExpression::toString() const -> std::string { if (args_.empty()) @@ -618,9 +670,8 @@ auto CallExpression::toString() const -> std::string return fmt::format("({} {})", name_, fmt::join(items, " ")); } -PathExpr::PathExpr(ExprId id, ExprPtr left, ExprPtr right) - : Expr(id) - , left_(std::move(left)) +PathExpr::PathExpr(ExprPtr left, ExprPtr right) + : left_(std::move(left)) , right_(std::move(right)) { assert(left_.get()); @@ -667,14 +718,30 @@ void PathExpr::accept(ExprVisitor& v) const v.visit(*this); } +auto PathExpr::numChildren() const -> std::size_t +{ + return 2; +} + +auto PathExpr::childAt(std::size_t index) -> ExprPtr& +{ + switch (index) { + case 0: + return left_; + case 1: + return right_; + default: + return Expr::childAt(index); + } +} + auto PathExpr::toString() const -> std::string { return fmt::format("(. {} {})", left_->toString(), right_->toString()); } -UnpackExpr::UnpackExpr(ExprId id, ExprPtr sub) - : Expr(id) - , sub_(std::move(sub)) +UnpackExpr::UnpackExpr(ExprPtr sub) + : sub_(std::move(sub)) {} auto UnpackExpr::type() const -> Type @@ -717,14 +784,25 @@ void UnpackExpr::accept(ExprVisitor& v) const v.visit(*this); } +auto UnpackExpr::numChildren() const -> std::size_t +{ + return 1; +} + +auto UnpackExpr::childAt(std::size_t index) -> ExprPtr& +{ + if (index == 0) + return sub_; + return Expr::childAt(index); +} + auto UnpackExpr::toString() const -> std::string { return fmt::format("(... {})", sub_->toString()); } -UnaryWordOpExpr::UnaryWordOpExpr(ExprId id, std::string ident, ExprPtr left) - : Expr(id) - , ident_(std::move(ident)) +UnaryWordOpExpr::UnaryWordOpExpr(std::string ident, ExprPtr left) + : ident_(std::move(ident)) , left_(std::move(left)) {} @@ -756,14 +834,25 @@ void UnaryWordOpExpr::accept(ExprVisitor& v) const v.visit(*this); } +auto UnaryWordOpExpr::numChildren() const -> std::size_t +{ + return 1; +} + +auto UnaryWordOpExpr::childAt(std::size_t index) -> ExprPtr& +{ + if (index == 0) + return left_; + return Expr::childAt(index); +} + auto UnaryWordOpExpr::toString() const -> std::string { return fmt::format("({} {})", ident_, left_->toString()); } -BinaryWordOpExpr::BinaryWordOpExpr(ExprId id, std::string ident, ExprPtr left, ExprPtr right) - : Expr(id) - , ident_(std::move(ident)) +BinaryWordOpExpr::BinaryWordOpExpr(std::string ident, ExprPtr left, ExprPtr right) + : ident_(std::move(ident)) , left_(std::move(left)) , right_(std::move(right)) {} @@ -806,14 +895,30 @@ void BinaryWordOpExpr::accept(ExprVisitor& v) const v.visit(*this); } +auto BinaryWordOpExpr::numChildren() const -> std::size_t +{ + return 2; +} + +auto BinaryWordOpExpr::childAt(std::size_t index) -> ExprPtr& +{ + switch (index) { + case 0: + return left_; + case 1: + return right_; + default: + return Expr::childAt(index); + } +} + auto BinaryWordOpExpr::toString() const -> std::string { return fmt::format("({} {} {})", ident_, left_->toString(), right_->toString()); } -AndExpr::AndExpr(ExprId id, ExprPtr left, ExprPtr right) - : Expr(id) - , left_(std::move(left)) +AndExpr::AndExpr(ExprPtr left, ExprPtr right) + : left_(std::move(left)) , right_(std::move(right)) { assert(left_.get()); @@ -850,14 +955,30 @@ void AndExpr::accept(ExprVisitor& v) const v.visit(*this); } +auto AndExpr::numChildren() const -> std::size_t +{ + return 2; +} + +auto AndExpr::childAt(std::size_t index) -> ExprPtr& +{ + switch (index) { + case 0: + return left_; + case 1: + return right_; + default: + return Expr::childAt(index); + } +} + auto AndExpr::toString() const -> std::string { return fmt::format("(and {} {})", left_->toString(), right_->toString()); } -OrExpr::OrExpr(ExprId id, ExprPtr left, ExprPtr right) - : Expr(id) - , left_(std::move(left)) +OrExpr::OrExpr(ExprPtr left, ExprPtr right) + : left_(std::move(left)) , right_(std::move(right)) { assert(left_.get()); @@ -895,9 +1016,135 @@ void OrExpr::accept(ExprVisitor& v) const v.visit(*this); } +auto OrExpr::numChildren() const -> std::size_t +{ + return 2; +} + +auto OrExpr::childAt(std::size_t index) -> ExprPtr& +{ + switch (index) { + case 0: + return left_; + case 1: + return right_; + default: + return Expr::childAt(index); + } +} + auto OrExpr::toString() const -> std::string { return fmt::format("(or {} {})", left_->toString(), right_->toString()); } +WildcardFieldExpr::WildcardFieldExpr(std::string name, SourceLocation location) + : Expr(location) + , name_(std::move(name)) +{} + +auto WildcardFieldExpr::type() const -> Type +{ + return Type::PATH; +} + +auto WildcardFieldExpr::ieval(Context ctx, const Value& val, const ResultFn& ores) const -> tl::expected +{ + if (ctx.phase == Context::Phase::Compilation) + return ores(ctx, Value::undef()); + + CountedResultFn res(ores, ctx); + + Diagnostics::FieldExprData* diag = nullptr; + if (ctx.diag) + diag = &ctx.diag->get(*this); + + if (diag) { + diag->location = sourceLocation(); + if (diag->name.empty()) + diag->name = name_; + } + + // Querying a field not in the string-pool + // is a no-op here (not true for FieldExpr). + if (!nameId_) { + nameId_ = ctx.env->strings()->get(name_); + if (!nameId_) { + if (diag) + diag->evaluations++; + res.ensureCall(); + return {Result::Continue}; + } + } + + struct Iterate + { + Context& ctx; + ResultFn& res; + StringId field; + Diagnostics::FieldExprData* diag; + + [[nodiscard]] auto iterate(ModelNode const& val) noexcept -> tl::expected + { + if (field == StringPool::StaticStringIds::Empty) + return Result::Continue; + + if (val.type() == ValueType::Null) [[unlikely]] + return Result::Continue; + + if (auto* schema = ctx.env->querySchema(val.schema())) { + if (!schema->canHaveField(field)) + return Result::Continue; + } + + if (diag) + diag->evaluations++; + + if (auto sub = val.get(field)) { + if (diag) + diag->hits++; + + auto result = res(ctx, Value::field(*sub)); + TRY_EXPECTED(result); + if (*result == Result::Stop) [[unlikely]] + return *result; + } + + tl::expected finalResult = Result::Continue; + val.iterate(ModelNode::IterLambda([&, this](const auto& subNode) { + auto subResult = iterate(subNode); + if (!subResult) { + finalResult = std::move(subResult); + return false; + } + + if (*subResult == Result::Stop) { + finalResult = Result::Stop; + return false; + } + + return true; + })); + + return finalResult; + } + }; + + auto r = val.nodePtr() + ? Iterate{ctx, res, nameId_, diag}.iterate(**val.nodePtr()) + : tl::expected(Result::Continue); + res.ensureCall(); + return r; +} + +void WildcardFieldExpr::accept(ExprVisitor& v) const +{ + v.visit(*this); +} + +auto WildcardFieldExpr::toString() const -> std::string +{ + return fmt::format("**.{}", name_); +} + } diff --git a/src/expressions.h b/src/expressions.h index c98257ec..fa249236 100644 --- a/src/expressions.h +++ b/src/expressions.h @@ -15,7 +15,7 @@ namespace simfil class WildcardExpr : public Expr { public: - explicit WildcardExpr(ExprId); + WildcardExpr(); auto type() const -> Type override; auto ieval(Context ctx, const Value& val, const ResultFn& ores) const -> tl::expected override; @@ -29,7 +29,7 @@ class WildcardExpr : public Expr class AnyChildExpr : public Expr { public: - explicit AnyChildExpr(ExprId); + AnyChildExpr(); auto type() const -> Type override; auto ieval(Context ctx, const Value& val, const ResultFn& res) const -> tl::expected override; @@ -40,8 +40,8 @@ class AnyChildExpr : public Expr class FieldExpr : public Expr { public: - FieldExpr(ExprId id, std::string name); - FieldExpr(ExprId id, std::string name, const Token& token); + explicit FieldExpr(std::string name); + FieldExpr(std::string name, const Token& token); auto type() const -> Type override; auto ieval(Context ctx, const Value& val, const ResultFn& res) const -> tl::expected override; @@ -59,8 +59,8 @@ class MultiConstExpr : public Expr static constexpr size_t Limit = 10000; MultiConstExpr() = delete; - MultiConstExpr(ExprId id, const std::vector& vec); - MultiConstExpr(ExprId id, std::vector&& vec); + explicit MultiConstExpr(const std::vector& vec); + explicit MultiConstExpr(std::vector&& vec); auto type() const -> Type override; auto constant() const -> bool override; @@ -76,11 +76,10 @@ class ConstExpr : public Expr public: ConstExpr() = delete; template - ConstExpr(ExprId id, CType_&& value) - : Expr(id) - , value_(Value::make(std::forward(value))) + explicit ConstExpr(CType_&& value) + : value_(Value::make(std::forward(value))) {} - ConstExpr(ExprId id, Value value); + explicit ConstExpr(Value value); auto type() const -> Type override; auto constant() const -> bool override; @@ -97,11 +96,13 @@ class ConstExpr : public Expr class SubscriptExpr : public Expr { public: - SubscriptExpr(ExprId id, ExprPtr left, ExprPtr index); + SubscriptExpr(ExprPtr left, ExprPtr index); auto type() const -> Type override; auto ieval(Context ctx, const Value& val, const ResultFn& ores) const -> tl::expected override; void accept(ExprVisitor& v) const override; + auto numChildren() const -> std::size_t override; + auto childAt(std::size_t index) -> ExprPtr& override; auto toString() const -> std::string override; ExprPtr left_; @@ -111,12 +112,14 @@ class SubscriptExpr : public Expr class SubExpr : public Expr { public: - SubExpr(ExprId id, ExprPtr left, ExprPtr sub); + SubExpr(ExprPtr left, ExprPtr sub); auto type() const -> Type override; auto ieval(Context ctx, const Value& val, const ResultFn& ores) const -> tl::expected override; auto ieval(Context ctx, Value&& val, const ResultFn& ores) const -> tl::expected override; void accept(ExprVisitor& v) const override; + auto numChildren() const -> std::size_t override; + auto childAt(std::size_t index) -> ExprPtr& override; auto toString() const -> std::string override; ExprPtr left_, sub_; @@ -125,11 +128,13 @@ class SubExpr : public Expr class AnyExpr : public Expr { public: - AnyExpr(ExprId id, std::vector args); + explicit AnyExpr(std::vector args); auto type() const -> Type override; auto ieval(Context ctx, const Value& val, const ResultFn& res) const -> tl::expected override; void accept(ExprVisitor& v) const override; + auto numChildren() const -> std::size_t override; + auto childAt(std::size_t index) -> ExprPtr& override; auto toString() const -> std::string override; std::vector args_; @@ -138,11 +143,13 @@ class AnyExpr : public Expr class EachExpr : public Expr { public: - EachExpr(ExprId id, std::vector args); + explicit EachExpr(std::vector args); auto type() const -> Type override; auto ieval(Context ctx, const Value& val, const ResultFn& res) const -> tl::expected override; void accept(ExprVisitor& v) const override; + auto numChildren() const -> std::size_t override; + auto childAt(std::size_t index) -> ExprPtr& override; auto toString() const -> std::string override; std::vector args_; @@ -151,12 +158,14 @@ class EachExpr : public Expr class CallExpression : public Expr { public: - CallExpression(ExprId id, std::string name, std::vector args); + CallExpression(std::string name, std::vector args); auto type() const -> Type override; auto ieval(Context ctx, const Value& val, const ResultFn& res) const -> tl::expected override; auto ieval(Context ctx, Value&& val, const ResultFn& res) const -> tl::expected override; void accept(ExprVisitor& v) const override; + auto numChildren() const -> std::size_t override; + auto childAt(std::size_t index) -> ExprPtr& override; auto toString() const -> std::string override; std::string name_; @@ -167,12 +176,14 @@ class CallExpression : public Expr class PathExpr : public Expr { public: - PathExpr(ExprId id, ExprPtr left, ExprPtr right); + PathExpr(ExprPtr left, ExprPtr right); auto type() const -> Type override; auto ieval(Context ctx, const Value& val, const ResultFn& ores) const -> tl::expected override; auto ieval(Context ctx, Value&& val, const ResultFn& ores) const -> tl::expected override; void accept(ExprVisitor& v) const override; + auto numChildren() const -> std::size_t override; + auto childAt(std::size_t index) -> ExprPtr& override; auto toString() const -> std::string override; ExprPtr left_, right_; @@ -186,11 +197,13 @@ class PathExpr : public Expr class UnpackExpr : public Expr { public: - UnpackExpr(ExprId id, ExprPtr sub); + explicit UnpackExpr(ExprPtr sub); auto type() const -> Type override; auto ieval(Context ctx, const Value& val, const ResultFn& res) const -> tl::expected override; void accept(ExprVisitor& v) const override; + auto numChildren() const -> std::size_t override; + auto childAt(std::size_t index) -> ExprPtr& override; auto toString() const -> std::string override; ExprPtr sub_; @@ -203,9 +216,8 @@ template class UnaryExpr : public Expr { public: - UnaryExpr(ExprId id, ExprPtr sub) - : Expr(id) - , sub_(std::move(sub)) + explicit UnaryExpr(ExprPtr sub) + : sub_(std::move(sub)) {} auto type() const -> Type override @@ -223,12 +235,24 @@ class UnaryExpr : public Expr })); } - void accept(ExprVisitor& v) const override + auto accept(ExprVisitor& v) const -> void override { v.visit(*this); sub_->accept(v); } + auto numChildren() const -> std::size_t override + { + return 1; + } + + auto childAt(std::size_t index) -> ExprPtr& override + { + if (index == 0) + return sub_; + return Expr::childAt(index); + } + auto toString() const -> std::string override { return "("s + Operator::name() + " "s + sub_->toString() + ")"s; @@ -244,15 +268,13 @@ template class BinaryExpr : public Expr { public: - BinaryExpr(ExprId id, ExprPtr left, ExprPtr right) - : Expr(id) - , left_(std::move(left)) + BinaryExpr(ExprPtr left, ExprPtr right) + : left_(std::move(left)) , right_(std::move(right)) {} - BinaryExpr(ExprId id, const Token& token, ExprPtr left, ExprPtr right) - : Expr(id, token) - , left_(std::move(left)) + BinaryExpr(const Token& token, ExprPtr left, ExprPtr right) + : left_(std::move(left)) , right_(std::move(right)) {} @@ -273,11 +295,26 @@ class BinaryExpr : public Expr })); } - void accept(ExprVisitor& v) const override + auto accept(ExprVisitor& v) const -> void override { v.visit(*this); - left_->accept(v); - right_->accept(v); + } + + auto numChildren() const -> std::size_t override + { + return 2; + } + + auto childAt(std::size_t index) -> ExprPtr& override + { + switch (index) { + case 0: + return left_; + case 1: + return right_; + default: + return Expr::childAt(index); + } } auto toString() const -> std::string override @@ -292,14 +329,13 @@ class ComparisonExprBase : public Expr { public: - ComparisonExprBase(ExprId id, ExprPtr left, ExprPtr right) - : Expr(id) - , left_(std::move(left)) + ComparisonExprBase(ExprPtr left, ExprPtr right) + : left_(std::move(left)) , right_(std::move(right)) {} - ComparisonExprBase(ExprId id, const Token& token, ExprPtr left, ExprPtr right) - : Expr(id, token) + ComparisonExprBase(const Token& token, ExprPtr left, ExprPtr right) + : Expr(token) , left_(std::move(left)) , right_(std::move(right)) {} @@ -309,6 +345,23 @@ class ComparisonExprBase : public Expr return Type::VALUE; } + auto numChildren() const -> std::size_t override + { + return 2; + } + + auto childAt(std::size_t index) -> ExprPtr& override + { + switch (index) { + case 0: + return left_; + case 1: + return right_; + default: + return Expr::childAt(index); + } + } + ExprPtr left_, right_; }; @@ -352,11 +405,9 @@ class ComparisonExpr : public ComparisonExprBase })); } - void accept(ExprVisitor& v) const override + auto accept(ExprVisitor& v) const -> void override { v.visit(static_cast(*this)); - left_->accept(v); - right_->accept(v); } auto toString() const -> std::string override @@ -404,11 +455,13 @@ class BinaryExpr : public ComparisonExpr Type override; auto ieval(Context ctx, const Value& val, const ResultFn& res) const -> tl::expected override; void accept(ExprVisitor& v) const override; + auto numChildren() const -> std::size_t override; + auto childAt(std::size_t index) -> ExprPtr& override; auto toString() const -> std::string override; std::string ident_; @@ -418,11 +471,13 @@ class UnaryWordOpExpr : public Expr class BinaryWordOpExpr : public Expr { public: - BinaryWordOpExpr(ExprId id, std::string ident, ExprPtr left, ExprPtr right); + BinaryWordOpExpr(std::string ident, ExprPtr left, ExprPtr right); auto type() const -> Type override; auto ieval(Context ctx, const Value& val, const ResultFn& res) const -> tl::expected override; void accept(ExprVisitor& v) const override; + auto numChildren() const -> std::size_t override; + auto childAt(std::size_t index) -> ExprPtr& override; auto toString() const -> std::string override; std::string ident_; @@ -432,11 +487,13 @@ class BinaryWordOpExpr : public Expr class AndExpr : public Expr { public: - AndExpr(ExprId id, ExprPtr left, ExprPtr right); + AndExpr(ExprPtr left, ExprPtr right); auto type() const -> Type override; auto ieval(Context ctx, const Value& val, const ResultFn& res) const -> tl::expected override; void accept(ExprVisitor& v) const override; + auto numChildren() const -> std::size_t override; + auto childAt(std::size_t index) -> ExprPtr& override; auto toString() const -> std::string override; ExprPtr left_, right_; @@ -445,14 +502,33 @@ class AndExpr : public Expr class OrExpr : public Expr { public: - OrExpr(ExprId id, ExprPtr left, ExprPtr right); + OrExpr(ExprPtr left, ExprPtr right); auto type() const -> Type override; auto ieval(Context ctx, const Value& val, const ResultFn& res) const -> tl::expected override; void accept(ExprVisitor& v) const override; + auto numChildren() const -> std::size_t override; + auto childAt(std::size_t index) -> ExprPtr& override; auto toString() const -> std::string override; ExprPtr left_, right_; }; +/** A specialized expression for queries of the form `**.field`, that + * takes object schema information into account. + */ +class WildcardFieldExpr : public Expr +{ +public: + explicit WildcardFieldExpr(std::string name, SourceLocation location = {}); + + auto type() const -> Type override; + auto ieval(Context ctx, const Value& val, const ResultFn& res) const -> tl::expected override; + void accept(ExprVisitor& v) const override; + auto toString() const -> std::string override; + + std::string name_; + mutable StringId nameId_ = {}; +}; + } diff --git a/src/model/model.cpp b/src/model/model.cpp index 6b1ad9a1..1ebae09e 100644 --- a/src/model/model.cpp +++ b/src/model/model.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -15,11 +16,13 @@ #include #include #include +#include #include #include #include #include "../expected.h" +#include "simfil/model/schema.h" namespace simfil { @@ -86,6 +89,10 @@ struct ModelPool::Impl ModelColumn strings_; ModelColumn byteArrays_; + ModelColumn objectSchemas_; + ModelColumn objectSingletonSchemas_; + ModelColumn arraySchemas_; + ModelColumn arraySingletonSchemas_; Object::Storage objectMemberArrays_; Array::Storage arrayMemberArrays_; } columns_; @@ -99,6 +106,10 @@ struct ModelPool::Impl s.text1b(columns_.stringData_, maxColumnSize); s.object(columns_.strings_); s.object(columns_.byteArrays_); + s.object(columns_.objectSchemas_); + s.object(columns_.objectSingletonSchemas_); + s.object(columns_.arraySchemas_); + s.object(columns_.arraySingletonSchemas_); s.ext(columns_.objectMemberArrays_, bitsery::ext::ArrayArenaExt{}); s.ext(columns_.arrayMemberArrays_, bitsery::ext::ArrayArenaExt{}); @@ -119,6 +130,20 @@ ModelPool::~ModelPool() // NOLINT std::vector ModelPool::checkForErrors() const { std::vector errors; + auto objectHasSchema = [&](ArrayIndex members) { + if (Object::Storage::is_singleton_handle(members)) { + return Object::Storage::singleton_payload(members) + < impl_->columns_.objectSingletonSchemas_.size(); + } + return members < impl_->columns_.objectSchemas_.size(); + }; + auto arrayHasSchema = [&](ArrayIndex members) { + if (Array::Storage::is_singleton_handle(members)) { + return Array::Storage::singleton_payload(members) + < impl_->columns_.arraySingletonSchemas_.size(); + } + return members < impl_->columns_.arraySchemas_.size(); + }; auto validateArrayIndex = [&](auto i, auto arrType, auto const& arena) { if (!arena.valid(static_cast(i))) { @@ -142,6 +167,10 @@ std::vector ModelPool::checkForErrors() const if (node->addr().column() == Objects) if (!validateArrayIndex(node->addr().index(), "object", impl_->columns_.objectMemberArrays_)) return; + if (!objectHasSchema(node->addr().index())) { + errors.emplace_back(fmt::format("Missing object schema index {}.", node->addr().index())); + return; + } for (auto const& [fieldName, fieldValue] : node->fields()) { validatePooledString(fieldName); validateModelNode(fieldValue); @@ -151,6 +180,10 @@ std::vector ModelPool::checkForErrors() const if (node->addr().column() == Arrays) if (!validateArrayIndex(node->addr().index(), "arrays", impl_->columns_.arrayMemberArrays_)) return; + if (!arrayHasSchema(node->addr().index())) { + errors.emplace_back(fmt::format("Missing array schema index {}.", node->addr().index())); + return; + } for (auto const& member : *node) validateModelNode(member); } @@ -163,16 +196,28 @@ std::vector ModelPool::checkForErrors() const }; // Validate objects - for (auto i = 0; i < impl_->columns_.objectMemberArrays_.size(); ++i) + for (auto i = FirstRegularArrayIndex; i < impl_->columns_.objectMemberArrays_.size(); ++i) validateModelNode(ModelNode::Ptr::make( shared_from_this(), ModelNodeAddress{Objects, (uint32_t)i})); + for (auto i = 0u; i < impl_->columns_.objectMemberArrays_.singleton_handle_count(); ++i) + validateModelNode(ModelNode::Ptr::make( + shared_from_this(), + ModelNodeAddress{ + Objects, + SingletonArrayHandleMask | static_cast(i)})); // Validate arrays - for (auto i = 0; i < impl_->columns_.arrayMemberArrays_.size(); ++i) + for (auto i = FirstRegularArrayIndex; i < impl_->columns_.arrayMemberArrays_.size(); ++i) validateModelNode(ModelNode::Ptr::make( shared_from_this(), ModelNodeAddress{Arrays, (uint32_t)i})); + for (auto i = 0u; i < impl_->columns_.arrayMemberArrays_.singleton_handle_count(); ++i) + validateModelNode(ModelNode::Ptr::make( + shared_from_this(), + ModelNodeAddress{ + Arrays, + SingletonArrayHandleMask | static_cast(i)})); // Validate roots for (auto i = 0; i < numRoots(); ++i) @@ -205,6 +250,10 @@ void ModelPool::clear() clear_and_shrink(columns.strings_); clear_and_shrink(columns.stringData_); clear_and_shrink(columns.byteArrays_); + clear_and_shrink(columns.objectSchemas_); + clear_and_shrink(columns.objectSingletonSchemas_); + clear_and_shrink(columns.arraySchemas_); + clear_and_shrink(columns.arraySingletonSchemas_); clear_and_shrink(columns.objectMemberArrays_); clear_and_shrink(columns.arrayMemberArrays_); } @@ -293,12 +342,32 @@ void ModelPool::addRoot(ModelNode::Ptr const& rootNode) { model_ptr ModelPool::newObject(size_t initialFieldCapacity, bool fixedSize) { auto memberArrId = impl_->columns_.objectMemberArrays_.new_array(initialFieldCapacity, fixedSize); + if (Object::Storage::is_singleton_handle(memberArrId)) { + auto singletonIndex = Object::Storage::singleton_payload(memberArrId); + if (impl_->columns_.objectSingletonSchemas_.size() <= singletonIndex) + impl_->columns_.objectSingletonSchemas_.resize(singletonIndex + 1); + impl_->columns_.objectSingletonSchemas_[singletonIndex] = SchemaId{}; + } else { + if (impl_->columns_.objectSchemas_.size() <= memberArrId) + impl_->columns_.objectSchemas_.resize(memberArrId + 1); + impl_->columns_.objectSchemas_[memberArrId] = SchemaId{}; + } return model_ptr::make(shared_from_this(), ModelNodeAddress{Objects, (uint32_t)memberArrId}); } model_ptr ModelPool::newArray(size_t initialFieldCapacity, bool fixedSize) { auto memberArrId = impl_->columns_.arrayMemberArrays_.new_array(initialFieldCapacity, fixedSize); + if (Array::Storage::is_singleton_handle(memberArrId)) { + auto singletonIndex = Array::Storage::singleton_payload(memberArrId); + if (impl_->columns_.arraySingletonSchemas_.size() <= singletonIndex) + impl_->columns_.arraySingletonSchemas_.resize(singletonIndex + 1); + impl_->columns_.arraySingletonSchemas_[singletonIndex] = SchemaId{}; + } else { + if (impl_->columns_.arraySchemas_.size() <= memberArrId) + impl_->columns_.arraySchemas_.resize(memberArrId + 1); + impl_->columns_.arraySchemas_[memberArrId] = SchemaId{}; + } return model_ptr::make(shared_from_this(), ModelNodeAddress{Arrays, (uint32_t)memberArrId}); } @@ -434,7 +503,11 @@ ModelPool::SerializationSizeStats ModelPool::serializationSizeStats() const stats.stringRangeBytes = impl_->columns_.strings_.byte_size(); stats.stringRangeBytes += impl_->columns_.byteArrays_.byte_size(); stats.objectMemberBytes = impl_->columns_.objectMemberArrays_.byte_size(); + stats.objectSchemaBytes = impl_->columns_.objectSchemas_.byte_size() + + impl_->columns_.objectSingletonSchemas_.byte_size(); stats.arrayMemberBytes = impl_->columns_.arrayMemberArrays_.byte_size(); + stats.arraySchemaBytes = impl_->columns_.arraySchemas_.byte_size() + + impl_->columns_.arraySingletonSchemas_.byte_size(); return stats; } @@ -447,6 +520,64 @@ Object::Storage& ModelPool::objectMemberStorage() { return impl_->columns_.objectMemberArrays_; } +SchemaId ModelPool::objectSchemaId(ArrayIndex members) const +{ + if (Object::Storage::is_singleton_handle(members)) { + auto singletonIndex = Object::Storage::singleton_payload(members); + if (singletonIndex >= impl_->columns_.objectSingletonSchemas_.size()) + return {}; + return SchemaId{impl_->columns_.objectSingletonSchemas_[singletonIndex]}; + } + if (members >= impl_->columns_.objectSchemas_.size()) + return {}; + return SchemaId{impl_->columns_.objectSchemas_[members]}; +} + +auto ModelPool::setObjectSchemaId(ArrayIndex members, SchemaId schemaId) -> tl::expected +{ + if (Object::Storage::is_singleton_handle(members)) { + auto singletonIndex = Object::Storage::singleton_payload(members); + if (singletonIndex >= impl_->columns_.objectSingletonSchemas_.size()) + return tl::unexpected(Error::RuntimeError, "Object singleton schema index out of range."); + impl_->columns_.objectSingletonSchemas_[singletonIndex] = schemaId; + return {}; + } + if (members >= impl_->columns_.objectSchemas_.size()) + return tl::unexpected(Error::RuntimeError, "Object schema index out of range."); + + impl_->columns_.objectSchemas_[members] = schemaId; + return {}; +} + +SchemaId ModelPool::arraySchemaId(ArrayIndex members) const +{ + if (Array::Storage::is_singleton_handle(members)) { + auto singletonIndex = Array::Storage::singleton_payload(members); + if (singletonIndex >= impl_->columns_.arraySingletonSchemas_.size()) + return {}; + return SchemaId{impl_->columns_.arraySingletonSchemas_[singletonIndex]}; + } + if (members >= impl_->columns_.arraySchemas_.size()) + return {}; + return SchemaId{impl_->columns_.arraySchemas_[members]}; +} + +auto ModelPool::setArraySchemaId(ArrayIndex members, SchemaId schemaId) -> tl::expected +{ + if (Array::Storage::is_singleton_handle(members)) { + auto singletonIndex = Array::Storage::singleton_payload(members); + if (singletonIndex >= impl_->columns_.arraySingletonSchemas_.size()) + return tl::unexpected(Error::RuntimeError, "Array singleton schema index out of range."); + impl_->columns_.arraySingletonSchemas_[singletonIndex] = schemaId; + return {}; + } + if (members >= impl_->columns_.arraySchemas_.size()) + return tl::unexpected(Error::RuntimeError, "Array schema index out of range."); + + impl_->columns_.arraySchemas_[members] = schemaId; + return {}; +} + Object::Storage const& ModelPool::objectMemberStorage() const { return impl_->columns_.objectMemberArrays_; @@ -480,6 +611,7 @@ tl::expected ModelPool::read(const std::vector& input, con "Failed to read ModelPool: Error {}", static_cast>(s.adapter().error()))); } + return {}; } diff --git a/src/model/nodes.cpp b/src/model/nodes.cpp index 9274ea62..d15b342c 100644 --- a/src/model/nodes.cpp +++ b/src/model/nodes.cpp @@ -63,6 +63,13 @@ StringId ModelNode::keyAt(int64_t i) const { return result; } +SchemaId ModelNode::schema() const { + SchemaId result = NoSchemaId; + if (model_) + model_->resolve(*this, Model::Lambda([&](auto&& resolved) { result = resolved.schema(); })); + return result; +} + /// Get the number of children uint32_t ModelNode::size() const { uint32_t result = 0; @@ -186,6 +193,11 @@ StringId ModelNodeBase::keyAt(int64_t) const return 0; } +SchemaId ModelNodeBase::schema() const +{ + return NoSchemaId; +} + uint32_t ModelNodeBase::size() const { return 0; diff --git a/src/parser.cpp b/src/parser.cpp index ca7be97d..1211cd93 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -113,11 +113,6 @@ auto Parser::relaxed() const -> bool return mode_ == Mode::Relaxed; } -auto Parser::nextId() -> Expr::ExprId -{ - return ctx.id++; -} - auto Parser::parseInfix(expected left, int prec) -> expected { TRY_EXPECTED(left); @@ -127,7 +122,7 @@ auto Parser::parseInfix(expected left, int prec) -> expected(nextId()); + return std::make_unique(); return unexpected( Error::ParserError, @@ -176,7 +171,7 @@ auto Parser::parseTo(Token::Type type) -> expected if (!*expr) { if (relaxed()) - return std::make_unique(nextId()); + return std::make_unique(); return unexpected( Error::ParserError, diff --git a/src/rewrite-rules.h b/src/rewrite-rules.h new file mode 100644 index 00000000..a2f52298 --- /dev/null +++ b/src/rewrite-rules.h @@ -0,0 +1,106 @@ +#pragma once + +#include "simfil/value.h" +#include "simfil/expression.h" + +#include "expressions.h" + +#include + +namespace simfil +{ + +using RewriteRule = std::function; + +/** + * Apply a list of rewrite-rules top-down to an expression (sub-)tree. + */ +inline auto rewriteTopDown(ExprPtr expr, std::span rules, const RewriteRule* sourceRule = nullptr) -> ExprPtr +{ + for (const auto& rule : rules) { + // Prevent rule self-recursion. + if (&rule == sourceRule) + continue; + + auto rewrite = rule(expr); + if (rewrite && rewrite.get() != expr.get()) { + return rewriteTopDown(std::move(rewrite), rules, &rule); // NOLINT + } + } + + const auto count = expr->numChildren(); + for (auto i = 0; i < count; ++i) { + auto& child = expr->childAt(i); + child = rewriteTopDown(std::move(child), rules, nullptr); + } + + return std::move(expr); +} + +/** Rewrite `PathExpr(WildcardExpr, WildcardExpr)` -> `WildcardExpr` */ +inline auto rewriteWildcardWildcard(ExprPtr& expr) -> ExprPtr +{ + if (const auto* path = dynamic_cast(expr.get())) { + const auto* lhs = path->left_.get(); + const auto* rhs = path->right_.get(); + if (dynamic_cast(lhs) && dynamic_cast(rhs)) + return std::make_unique(); + } + + return nullptr; +} + +/** Rewrite `PathExpr(WildcardExpr, _) | PathExpr(_, WildcardExpr)` -> `WildcardExpr` */ +inline auto rewriteWildcardThis(ExprPtr& expr) -> ExprPtr +{ + auto rewrite = [](const PathExpr* path, const Expr* left, const Expr* right) -> std::unique_ptr { + const auto* lhs = dynamic_cast(left); + const auto* rhs = dynamic_cast(right); + if (lhs && rhs && rhs->name_ == "_") { + return std::make_unique(); + } + return nullptr; + }; + + if (const auto* path = dynamic_cast(expr.get())) { + if (auto replacement = rewrite(path, path->left_.get(), path->right_.get())) + return std::move(replacement); + if (auto replacement = rewrite(path, path->right_.get(), path->left_.get())) + return std::move(replacement); + } + + return nullptr; +} + +/** Rewrite `PathExpr(PathExpr(?, WildcardExpr), FieldExpr)` -> `PathExpr(?, WildcardFieldExpr(field))` */ +inline auto rewriteAnyWildcardField(ExprPtr& expr) -> ExprPtr +{ + if (auto* path = dynamic_cast(expr.get())) { + auto* lhs = dynamic_cast(path->left_.get()); + auto* rhs = dynamic_cast(path->right_.get()); + if (lhs && rhs) { + auto* lhsRhs = dynamic_cast(lhs->right_.get()); + if (lhsRhs) { + return std::make_unique(std::move(lhs->left_), + std::make_unique(rhs->name_, rhs->sourceLocation())); + } + } + } + return nullptr; +} + +/** Rewrite `PathExpr(WildcardExpr, FieldExpr)` -> `WildcardFieldExpr(field)` */ +inline auto rewriteWildcardField(ExprPtr& expr) -> ExprPtr +{ + if (auto* path = dynamic_cast(expr.get())) { + auto* lhs = dynamic_cast(path->left_.get()); + auto* rhs = dynamic_cast(path->right_.get()); + if (lhs && rhs) { + return std::make_unique(rhs->name_, rhs->sourceLocation()); + } + } + + return nullptr; +} + +} diff --git a/src/simfil.cpp b/src/simfil.cpp index ae6f6738..6154dc1b 100644 --- a/src/simfil.cpp +++ b/src/simfil.cpp @@ -15,9 +15,10 @@ #include "fmt/core.h" #include "expressions.h" -#include "expression-patterns.h" #include "completion.h" #include "expected.h" +#include "expression-patterns.h" +#include "rewrite-rules.h" #include #include @@ -47,6 +48,15 @@ static constexpr std::string_view TypenameString("string"); static constexpr std::string_view TypenameBytes("bytes"); } +static RewriteRule bottomUpRewriteRules[] = { + rewriteWildcardThis, + rewriteWildcardField, +}; + +static RewriteRule topDownRewriteRules[] = { + rewriteAnyWildcardField, +}; + /** * Parser precedence groups. */ @@ -116,7 +126,7 @@ static auto scopedNotInPath(Parser& p) { * Tries to evaluate the input expression on a stub context. * Returns the evaluated result on success, otherwise the original expression is returned. */ -static auto simplifyOrForward(Environment* env, expected expr) -> expected +static auto simplifyOrForward(const RewriteRule* currentRule, Environment* env, expected expr) -> expected { if (!expr) return expr; @@ -149,16 +159,52 @@ static auto simplifyOrForward(Environment* env, expected expr) - env->warn("Expression is always "s + values[0].toString(), (*expr)->toString()); if (values.size() == 1) - return std::make_unique((*expr)->id(), std::move(values[0])); + return std::make_unique(std::move(values[0])); if (values.size() > 1) - return std::make_unique((*expr)->id(), std::vector(std::make_move_iterator(values.begin()), - std::make_move_iterator(values.end()))); + return std::make_unique(std::vector(std::make_move_iterator(values.begin()), + std::make_move_iterator(values.end()))); + + /* Apply bottom-up rewrite rules */ + for (const auto& rule : bottomUpRewriteRules) { + /* Prevent rule self-recursion */ + if (&rule == currentRule) + continue; + + if (auto rewrite = rule(*expr)) { + /* If a rewrite rule matched we try to simplify and re-write its output again */ + return simplifyOrForward(&rule, env, std::move(rewrite)); + } + } return expr; } +static auto simplifyOrForward(Environment* env, expected expr) -> expected +{ + return simplifyOrForward(nullptr, env, std::move(expr)); +} + + AST::~AST() = default; +auto AST::reenumerate() -> void +{ + if (!expr_) + return; + + auto nextId = Expr::ExprId{0}; + reenumerate(*expr_, nextId); +} + +auto AST::reenumerate(Expr& expr, Expr::ExprId& nextId) -> void +{ + expr.id_ = nextId++; + + const auto count = expr.numChildren(); + for (auto i = 0u; i < count; ++i) + reenumerate(*expr.childAt(i), nextId); +} + /** * Parser wrapper for parsing and & or operators. * @@ -174,12 +220,10 @@ class AndOrParser : public InfixParselet return right; if (t.type == Token::OP_AND) - return simplifyOrForward(p.env, std::make_unique(p.nextId(), - std::move(left), + return simplifyOrForward(p.env, std::make_unique(std::move(left), std::move(*right))); else if (t.type == Token::OP_OR) - return simplifyOrForward(p.env, std::make_unique(p.nextId(), - std::move(left), + return simplifyOrForward(p.env, std::make_unique(std::move(left), std::move(*right))); assert(0); return nullptr; @@ -205,12 +249,10 @@ class CompletionAndOrParser : public InfixParselet return right; if (t.type == Token::OP_AND) - return simplifyOrForward(p.env, std::make_unique(p.nextId(), - std::move(left), + return simplifyOrForward(p.env, std::make_unique(std::move(left), std::move(*right), comp_)); else if (t.type == Token::OP_OR) - return simplifyOrForward(p.env, std::make_unique(p.nextId(), - std::move(left), + return simplifyOrForward(p.env, std::make_unique(std::move(left), std::move(*right), comp_)); assert(0); return nullptr; @@ -231,7 +273,7 @@ class CastParser : public InfixParselet { auto type = p.consume(); if (type.type == Token::C_NULL) - return std::make_unique(p.nextId(), Value::null()); + return std::make_unique(Value::null()); if (type.type != Token::Type::WORD) return unexpected(Error::InvalidType, fmt::format("'as' expected typename got {}", type.toString())); @@ -239,17 +281,17 @@ class CastParser : public InfixParselet auto name = std::get(type.value); return simplifyOrForward(p.env, [&]() -> expected { if (name == strings::TypenameNull) - return std::make_unique(p.nextId(), Value::null()); + return std::make_unique(Value::null()); if (name == strings::TypenameBool) - return std::make_unique>(p.nextId(), std::move(left)); + return std::make_unique>(std::move(left)); if (name == strings::TypenameInt) - return std::make_unique>(p.nextId(), std::move(left)); + return std::make_unique>(std::move(left)); if (name == strings::TypenameFloat) - return std::make_unique>(p.nextId(), std::move(left)); + return std::make_unique>(std::move(left)); if (name == strings::TypenameString) - return std::make_unique>(p.nextId(), std::move(left)); + return std::make_unique>(std::move(left)); if (name == strings::TypenameBytes) - return std::make_unique>(p.nextId(), std::move(left)); + return std::make_unique>(std::move(left)); return unexpected(Error::InvalidType, fmt::format("Invalid type name for cast '{}'", name)); }()); @@ -277,8 +319,7 @@ class BinaryOpParser : public InfixParselet if (!right) return right; - return simplifyOrForward(p.env, std::make_unique>(p.nextId(), - t, + return simplifyOrForward(p.env, std::make_unique>(t, std::move(left), std::move(*right))); } @@ -303,7 +344,7 @@ class UnaryOpParser : public PrefixParselet if (!sub) return sub; - return simplifyOrForward(p.env, std::make_unique>(p.nextId(), std::move(*sub))); + return simplifyOrForward(p.env, std::make_unique>(std::move(*sub))); } }; @@ -315,7 +356,7 @@ class UnaryPostOpParser : public InfixParselet { auto parse(Parser& p, ExprPtr left, Token t) const -> expected override { - return p.parseInfix(simplifyOrForward(p.env, std::make_unique>(p.nextId(), std::move(left))), 0); + return p.parseInfix(simplifyOrForward(p.env, std::make_unique>(std::move(left))), 0); } auto precedence() const -> int override @@ -331,7 +372,7 @@ class UnpackOpParser : public InfixParselet { auto parse(Parser& p, ExprPtr left, Token t) const -> expected override { - return p.parseInfix(simplifyOrForward(p.env, std::make_unique(p.nextId(), std::move(left))), 0); + return p.parseInfix(simplifyOrForward(p.env, std::make_unique(std::move(left))), 0); } auto precedence() const -> int override @@ -353,14 +394,12 @@ class WordOpParser : public InfixParselet return right; if (*right) - return simplifyOrForward(p.env, std::make_unique(p.nextId(), - std::get(t.value), + return simplifyOrForward(p.env, std::make_unique(std::get(t.value), std::move(left), std::move(*right))); /* Parse as unary operator */ - return p.parseInfix(simplifyOrForward(p.env, std::make_unique(p.nextId(), - std::get(t.value), + return p.parseInfix(simplifyOrForward(p.env, std::make_unique(std::get(t.value), std::move(left))), 0); } @@ -380,7 +419,7 @@ class ScalarParser : public PrefixParselet { auto parse(Parser& p, Token t) const -> expected override { - return std::make_unique(p.nextId(), std::get(t.value)); + return std::make_unique(std::get(t.value)); } }; @@ -394,7 +433,7 @@ class RegExpParser : public PrefixParselet auto parse(Parser& p, Token t) const -> expected override { auto value = ReType::Type.make(std::get(t.value)); - return std::make_unique(p.nextId(), std::move(value)); + return std::make_unique(std::move(value)); } }; @@ -415,7 +454,7 @@ class ConstParser : public PrefixParselet auto parse(Parser& p, Token t) const -> expected override { - return std::make_unique(p.nextId(), value_); + return std::make_unique(value_); } Value value_; @@ -450,10 +489,7 @@ class SubscriptParser : public PrefixParselet, public InfixParselet if (!body) return body; - auto outerId = p.nextId(); - auto innerId = p.nextId(); - return simplifyOrForward(p.env, std::make_unique(outerId, - std::make_unique(innerId, "_"), + return simplifyOrForward(p.env, std::make_unique(std::make_unique("_"), std::move(*body))); } @@ -464,8 +500,7 @@ class SubscriptParser : public PrefixParselet, public InfixParselet if (!body) return body; - return simplifyOrForward(p.env, std::make_unique(p.nextId(), - std::move(left), + return simplifyOrForward(p.env, std::make_unique(std::move(left), std::move(*body))); } @@ -491,10 +526,7 @@ class SubSelectParser : public PrefixParselet, public InfixParselet auto body = p.parseTo(Token::RBRACE); TRY_EXPECTED(body); - auto outerId = p.nextId(); - auto innerId = p.nextId(); - return simplifyOrForward(p.env, std::make_unique(outerId, - std::make_unique(innerId, "_"), + return simplifyOrForward(p.env, std::make_unique(std::make_unique("_"), std::move(*body))); } @@ -503,8 +535,7 @@ class SubSelectParser : public PrefixParselet, public InfixParselet auto _ = scopedNotInPath(p); auto body = p.parseTo(Token::RBRACE); TRY_EXPECTED(body); - return simplifyOrForward(p.env, std::make_unique(p.nextId(), - std::move(left), + return simplifyOrForward(p.env, std::make_unique(std::move(left), std::move(*body))); } @@ -526,15 +557,15 @@ class WordParser : public PrefixParselet { /* Self */ if (t.type == Token::SELF) - return std::make_unique(p.nextId(), "_", t); + return std::make_unique("_", t); /* Any Child */ if (t.type == Token::OP_TIMES) - return std::make_unique(p.nextId()); + return std::make_unique(); /* Wildcard */ if (t.type == Token::WILDCARD) - return std::make_unique(p.nextId()); + return std::make_unique(); auto word = std::get(t.value); @@ -547,25 +578,25 @@ class WordParser : public PrefixParselet TRY_EXPECTED(arguments); if (word == "any") { - return simplifyOrForward(p.env, std::make_unique(p.nextId(), std::move(*arguments))); + return simplifyOrForward(p.env, std::make_unique(std::move(*arguments))); } else if (word == "each" || word == "all") { - return simplifyOrForward(p.env, std::make_unique(p.nextId(), std::move(*arguments))); + return simplifyOrForward(p.env, std::make_unique(std::move(*arguments))); } else { - return simplifyOrForward(p.env, std::make_unique(p.nextId(), word, std::move(*arguments))); + return simplifyOrForward(p.env, std::make_unique(word, std::move(*arguments))); } } else if (!p.ctx.inPath) { /* Parse Symbols (words in upper-case) */ if (isSymbolWord(word)) { - return std::make_unique(p.nextId(), Value::make(std::move(word))); + return std::make_unique(Value::make(std::move(word))); } /* Constant */ else if (auto constant = p.env->findConstant(word)) { - return std::make_unique(p.nextId(), *constant); + return std::make_unique(*constant); } } /* Single field name */ - return std::make_unique(p.nextId(), std::move(word), t); + return simplifyOrForward(p.env, std::make_unique(std::move(word), t)); } }; @@ -583,15 +614,15 @@ class CompletionWordParser : public WordParser { /* Self */ if (t.type == Token::SELF) - return std::make_unique(p.nextId(), "_"); + return std::make_unique("_"); /* Any Child */ if (t.type == Token::OP_TIMES) - return std::make_unique(p.nextId()); + return std::make_unique(); /* Wildcard */ if (t.type == Token::WILDCARD) - return std::make_unique(p.nextId()); + return std::make_unique(); auto word = std::get(t.value); @@ -607,26 +638,26 @@ class CompletionWordParser : public WordParser auto arguments = p.parseList(Token::RPAREN); TRY_EXPECTED(arguments); - return simplifyOrForward(p.env, std::make_unique(p.nextId(), word, std::move(*arguments))); + return simplifyOrForward(p.env, std::make_unique(word, std::move(*arguments))); } else if (!p.ctx.inPath) { /* Parse Symbols (words in upper-case) */ if (isSymbolWord(word)) { if (t.containsPoint(comp_->point)) { - return std::make_unique(p.nextId(), word.substr(0, comp_->point - t.begin), comp_, t); + return std::make_unique(word.substr(0, comp_->point - t.begin), comp_, t); } - return std::make_unique(p.nextId(), Value::make(std::move(word))); + return std::make_unique(Value::make(std::move(word))); } /* Constant */ else if (auto constant = p.env->findConstant(word)) { - return std::make_unique(p.nextId(), *constant); + return std::make_unique(*constant); } } /* Single field name */ if (t.containsPoint(comp_->point)) { - return std::make_unique(p.nextId(), word.substr(0, comp_->point - t.begin), comp_, t, p.ctx.inPath); + return std::make_unique(word.substr(0, comp_->point - t.begin), comp_, t, p.ctx.inPath); } - return std::make_unique(p.nextId(), std::move(word)); + return simplifyOrForward(p.env, std::make_unique(std::move(word))); } Completion* comp_; @@ -653,7 +684,7 @@ class PathParser : public InfixParselet auto right = p.parsePrecedence(precedence()); TRY_EXPECTED(right); - return std::make_unique(p.nextId(), std::move(left), std::move(*right)); + return simplifyOrForward(p.env, std::make_unique(std::move(left), std::move(*right))); } auto precedence() const -> int override @@ -684,10 +715,10 @@ class CompletionPathParser : public PathParser if (!*right) { Token expectedWord(Token::WORD, "", t.end, t.end); - right = std::make_unique(p.nextId(), "", comp_, expectedWord, p.ctx.inPath); + right = std::make_unique("", comp_, expectedWord, p.ctx.inPath); } - return std::make_unique(p.nextId(), std::move(left), std::move(*right)); + return simplifyOrForward(p.env, std::make_unique(std::move(left), std::move(*right))); } Completion* comp_; @@ -819,10 +850,8 @@ auto compile(Environment& env, std::string_view query, bool any, bool autoWildca /* Expand a single value to `** == ` */ if (autoWildcard && *root && (*root)->constant()) { - auto outerId = p.nextId(); - auto innerId = p.nextId(); - root = std::make_unique>( - outerId, std::make_unique(innerId), std::move(*root)); + root = simplifyOrForward(p.env, std::make_unique>( + std::make_unique(), std::move(*root))); } if (!*root) @@ -831,17 +860,22 @@ auto compile(Environment& env, std::string_view query, bool any, bool autoWildca if (any) { std::vector args; args.emplace_back(std::move(*root)); - return simplifyOrForward(p.env, std::make_unique(p.nextId(), std::move(args))); + return simplifyOrForward(p.env, std::make_unique(std::move(args))); } else { return root; } }(); TRY_EXPECTED(expr); + /* Apply AST rewrite rules */ + expr = rewriteTopDown(std::move(*expr), topDownRewriteRules); + if (!p.match(Token::Type::NIL)) return unexpected(Error::ExpectedEOF, "Expected end-of-input; got "s + p.current().toString()); - return std::make_unique(std::string(query), std::move(*expr)); + auto ast = std::make_unique(std::string(query), std::move(*expr)); + ast->reenumerate(); + return ast; } auto complete(Environment& env, std::string_view query, size_t point, const ModelNode& node, const CompletionOptions& options) -> expected, Error> diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 53b8792b..96038825 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -8,6 +8,7 @@ add_executable(test.simfil common.hpp common.cpp token.cpp + schema.cpp simfil.cpp diagnostics.cpp completion.cpp diff --git a/test/common.hpp b/test/common.hpp index 3291a10f..5ad26bc4 100644 --- a/test/common.hpp +++ b/test/common.hpp @@ -15,6 +15,14 @@ #include #include +#if __has_include() +# include +#else +# define RUNNING_ON_VALGRIND false +# define CALLGRIND_START_INSTRUMENTATION (void)0 +# define CALLGRIND_STOP_INSTRUMENTATION (void)0 +#endif + using namespace simfil; static const char* const TestModel = R"json( diff --git a/test/performance.cpp b/test/performance.cpp index c1e8de01..65c4d1b5 100644 --- a/test/performance.cpp +++ b/test/performance.cpp @@ -1,6 +1,5 @@ #include "simfil/simfil.h" #include "simfil/model/model.h" - #include #include #include diff --git a/test/schema.cpp b/test/schema.cpp new file mode 100644 index 00000000..6842056b --- /dev/null +++ b/test/schema.cpp @@ -0,0 +1,463 @@ +#include "simfil/diagnostics.h" +#include "simfil/model/nodes.h" +#include "simfil/simfil.h" +#include "simfil/environment.h" +#include "simfil/model/schema.h" +#include "simfil/model/model.h" +#include "simfil/model/json.h" +#include "common.hpp" + +#include +#include +#include +#include + +using namespace simfil; + +namespace +{ + +class SchemaRegistry +{ +public: + std::map> schemas; + + // Enable schema lookup. + // + // By having this flag we do not cheat the price of + // the function call for the no-schema benchmarks instead + // of setting the environments query pointer to null. + bool enabled = true; + + auto get(SchemaId id) const -> const Schema* + { + if (!enabled) + return nullptr; + + auto i = schemas.find(id); + if (i != schemas.end()) + return i->second.get(); + return nullptr; + } + + auto finalize() -> void + { + auto& self = *this; + for (const auto& [key, value] : schemas) { + value->finalize([&self](auto id) { return self(id); }); + } + } + + auto operator()(SchemaId id) -> Schema* + { + return const_cast(const_cast(this)->get(id)); + } + + auto operator()(SchemaId id) const -> const Schema* + { + return get(id); + } + + auto asFunction() const & -> std::function + { + return [this](SchemaId id) { + return (*this)(id); + }; + } +}; + +} + +TEST_CASE("Object schema id assignment", "[model.schema]") { + auto model = std::make_shared(); + + auto obj = model->newObject(0); + REQUIRE(obj->schema() == NoSchemaId); + + obj->setSchema(SchemaId{1}); + REQUIRE(obj->schema() == SchemaId{1}); +} + +TEST_CASE("Singleton object schema id assignment", "[model.schema]") { + auto model = std::make_shared(); + + auto obj = model->newObject(1, true); + REQUIRE(obj->schema() == NoSchemaId); + + REQUIRE(obj->addField("field", int64_t{1})); + obj->setSchema(SchemaId{1}); + REQUIRE(obj->schema() == SchemaId{1}); + REQUIRE(model->validate()); +} + +TEST_CASE("Array schema id assignment", "[model.schema]") { + auto model = std::make_shared(); + + auto arr = model->newArray(0); + REQUIRE(arr->schema() == NoSchemaId); + + arr->setSchema(SchemaId{1}); + REQUIRE(arr->schema() == SchemaId{1}); +} + +TEST_CASE("Singleton array schema id assignment", "[model.schema]") { + auto model = std::make_shared(); + + auto arr = model->newArray(1, true); + REQUIRE(arr->schema() == NoSchemaId); + + arr->append(int64_t(1)); + arr->setSchema(SchemaId{1}); + REQUIRE(arr->schema() == SchemaId{1}); + REQUIRE(model->validate()); +} + +TEST_CASE("Object schema finalization", "[model.schema]") { + auto strings = std::make_shared(); + const auto a = strings->emplace("a").value(); + const auto b = strings->emplace("b").value(); + const auto c = strings->emplace("c").value(); + const auto link = strings->emplace("link").value(); + const auto back = strings->emplace("back").value(); + const auto missing = strings->emplace("missing").value(); + + SECTION("dirty schemas are conservative") { + ObjectSchema schema; + schema.addField(a); + + // No finalize() called, so canHaveField must return `true`. + REQUIRE(schema.canHaveField(a)); + REQUIRE(schema.canHaveField(missing)); + + schema.finalize([](SchemaId) { return nullptr; }); + REQUIRE(schema.canHaveField(a)); + REQUIRE(!schema.canHaveField(missing)); + } + + SECTION("acyclic schemas finalize fields") { + std::vector schemas(3); + schemas[1].addField(a, {SchemaId{2}}); + schemas[2].addField(b); + + auto lookup = [&](SchemaId schemaId) -> ObjectSchema* { + const auto index = static_cast(schemaId); + return index < schemas.size() ? &schemas[index] : nullptr; + }; + + schemas[1].finalize(lookup); + + REQUIRE(schemas[1].canHaveField(a)); + REQUIRE(schemas[1].canHaveField(b)); + REQUIRE_FALSE(schemas[1].canHaveField(c)); + } + + SECTION("cyclic schemas stay conservative") { + std::vector schemas(3); + schemas[1].addField(link, {SchemaId{2}}); + schemas[1].addField(c); + schemas[2].addField(back, {SchemaId{1}}); + + auto lookup = [&](SchemaId schemaId) -> ObjectSchema* { + const auto index = static_cast(schemaId); + return index < schemas.size() ? &schemas[index] : nullptr; + }; + + schemas[1].finalize(lookup); + + REQUIRE(schemas[1].canHaveField(missing)); + REQUIRE(schemas[2].canHaveField(missing)); + } + + SECTION("array schemas finalize element fields") { + ObjectSchema objectA; + objectA.addField(a); + + ObjectSchema objectB; + objectB.addField(b); + + ArraySchema arraySchema; + arraySchema.addElementSchemas({SchemaId{1}, SchemaId{2}}); + + auto lookup = [&](SchemaId schemaId) -> Schema* { + switch (schemaId) { + case SchemaId{1}: + return &objectA; + case SchemaId{2}: + return &objectB; + default: + return nullptr; + } + }; + + arraySchema.finalize(lookup); + + REQUIRE(arraySchema.canHaveField(a)); + REQUIRE(arraySchema.canHaveField(b)); + REQUIRE_FALSE(arraySchema.canHaveField(c)); + } +} + +TEST_CASE("Array schema serialization", "[model.schema]") { + auto model = std::make_shared(); + auto arr = model->newArray(1); + arr->append(int64_t(42)); + REQUIRE(arr->setSchema(SchemaId{7})); + model->addRoot(arr); + + std::stringstream stream; + REQUIRE(model->write(stream)); + + const auto input = std::vector(std::istreambuf_iterator(stream), {}); + auto recoveredModel = std::make_shared(); + REQUIRE(recoveredModel->read(input)); + + auto recoveredRoot = recoveredModel->root(0); + REQUIRE(recoveredRoot); + REQUIRE((*recoveredRoot)->type() == ValueType::Array); + REQUIRE((*recoveredRoot)->schema() == SchemaId{7}); +} + +// A minimal test that makes sure a field not in the schema +// is pruned if we query for it via **.field. +TEST_CASE("WildcardFieldExpr Field Pruning", "[model.schema]") +{ + auto jsonModel = R"json( + { + "field": 123 + } + )json"; + auto model = json::parse(jsonModel).value(); + auto registry = SchemaRegistry{}; + auto strings = model->strings(); + auto fieldId = strings->get("field"); + + // We need to add "noField" to the StringPool to prevent + // evaluation skipping the expression. + (void)strings->emplace("noField"); + + // Build a simple schema + auto schemaName = strings->emplace("schema1").value(); + auto schema1 = std::make_unique(); + schema1->addField(fieldId, { NoSchemaId }); + + registry.schemas[(SchemaId)schemaName] = std::move(schema1); + registry.finalize(); + + // Assign schemas to the model + auto root = model->root(0); + REQUIRE(root); + + auto rootObj = model->resolve(*root.value()); + REQUIRE(rootObj); + REQUIRE(rootObj->setSchema((SchemaId)schemaName)); + REQUIRE(rootObj->schema() == (SchemaId)schemaName); + + // Run a query and check if pruning of unknown fields works + Environment env(strings); + env.querySchemaCallback = registry.asFunction(); + + auto ast = compile(env, "**.noField", false, false); + REQUIRE(ast); + + Diagnostics diagWithPruning; + registry.enabled = true; + auto resultWithPruning = eval(env, *ast.value(), *model->root(0).value(), &diagWithPruning); + REQUIRE(resultWithPruning); + + Diagnostics diagNoPruning; + registry.enabled = false; + auto resultNoPruning = eval(env, *ast.value(), *model->root(0).value(), &diagNoPruning); + REQUIRE(resultNoPruning); + + // We compare field evaluations for both runs + auto withPruningData = diagWithPruning.fieldData_[0]; + auto noPruningData = diagNoPruning.fieldData_[0]; + REQUIRE(withPruningData.evaluations < noPruningData.evaluations); +} + +TEST_CASE("WildcardFieldExpr Array Field Pruning", "[model.schema]") +{ + auto jsonModel = R"json( + [ + { + "field": 123 + } + ] + )json"; + auto model = json::parse(jsonModel).value(); + auto registry = SchemaRegistry{}; + auto strings = model->strings(); + auto fieldId = strings->get("field"); + + (void)strings->emplace("noField"); + + constexpr auto objectSchemaId = SchemaId{1}; + constexpr auto arraySchemaId = SchemaId{2}; + + auto objectSchema = std::make_unique(); + objectSchema->addField(fieldId, { NoSchemaId }); + registry.schemas[objectSchemaId] = std::move(objectSchema); + + auto arraySchema = std::make_unique(); + arraySchema->addElementSchemas({objectSchemaId}); + registry.schemas[arraySchemaId] = std::move(arraySchema); + registry.finalize(); + + auto root = model->root(0); + REQUIRE(root); + auto rootArray = model->resolve(*root.value()); + + REQUIRE(rootArray); + REQUIRE(rootArray->setSchema(arraySchemaId)); + REQUIRE(rootArray->schema() == arraySchemaId); + + Environment env(strings); + env.querySchemaCallback = registry.asFunction(); + + auto ast = compile(env, "**.noField", false, false); + REQUIRE(ast); + + auto modelRoot = model->root(0); + REQUIRE(modelRoot); + + Diagnostics diagWithPruning; + registry.enabled = true; + auto resultWithPruning = eval(env, **ast, **modelRoot, &diagWithPruning); + REQUIRE(resultWithPruning); + + Diagnostics diagNoPruning; + registry.enabled = false; + auto resultNoPruning = eval(env, **ast, **modelRoot, &diagNoPruning); + REQUIRE(resultNoPruning); + + auto withPruningData = diagWithPruning.fieldData_[0]; + auto noPruningData = diagNoPruning.fieldData_[0]; + REQUIRE(withPruningData.evaluations < noPruningData.evaluations); +} + +TEST_CASE("Schema query performance", "[perf.schema]") { + if (RUNNING_ON_VALGRIND) { // NOLINT + SKIP("Skipping benchmarks when running under valgrind"); + } + + constexpr auto n = std::size_t{10'000}; + static_assert(n % 2 == 0, "n must be even"); + + const auto payloadASchemaId = SchemaId{1}; + const auto payloadBSchemaId = SchemaId{2}; + const auto xASchemaId = SchemaId{3}; + const auto xBSchemaId = SchemaId{4}; + const auto yASchemaId = SchemaId{5}; + const auto yBSchemaId = SchemaId{6}; + const auto rootObjASchemaId = SchemaId{7}; + const auto rootObjBSchemaId = SchemaId{8}; + const auto arraySchemaId = SchemaId{9}; + + auto strings = std::make_shared(); + auto model = std::make_shared(strings); + auto registry = SchemaRegistry{}; + + const auto aId = strings->emplace("a").value(); + const auto bId = strings->emplace("b").value(); + const auto yId = strings->emplace("y").value(); + const auto xId = strings->emplace("x").value(); + const auto payloadId = strings->emplace("payload").value(); + + auto payloadASchema = std::make_unique(); + payloadASchema->addField(xId, { xASchemaId }); + registry.schemas[payloadASchemaId] = std::move(payloadASchema); + + auto payloadBSchema = std::make_unique(); + payloadBSchema->addField(xId, { xBSchemaId }); + registry.schemas[payloadBSchemaId] = std::move(payloadBSchema); + + auto xASchema = std::make_unique(); + xASchema->addField(yId, { yASchemaId }); + registry.schemas[xASchemaId] = std::move(xASchema); + + auto xBSchema = std::make_unique(); + xBSchema->addField(yId, { yBSchemaId }); + registry.schemas[xBSchemaId] = std::move(xBSchema); + + auto yASchema = std::make_unique(); + yASchema->addField(aId); + registry.schemas[yASchemaId] = std::move(yASchema); + + auto yBSchema = std::make_unique(); + yBSchema->addField(bId); + registry.schemas[yBSchemaId] = std::move(yBSchema); + + auto rootObjASchema = std::make_unique(); + rootObjASchema->addField(payloadId, { payloadASchemaId }); + registry.schemas[rootObjASchemaId] = std::move(rootObjASchema); + + auto rootObjBSchema = std::make_unique(); + rootObjBSchema->addField(payloadId, { payloadBSchemaId }); + registry.schemas[rootObjBSchemaId] = std::move(rootObjBSchema); + + auto arraySchema = std::make_unique(); + arraySchema->addElementSchemas({ rootObjASchemaId, rootObjBSchemaId }); + registry.schemas[arraySchemaId] = std::move(arraySchema); + registry.finalize(); + + auto root = model->newArray(n); + for (auto i = 0u; i < n; ++i) { + auto obj = model->newObject(1, true); + auto payload = model->newObject(1, true); + auto x = model->newObject(1, true); + auto y = model->newObject(1, true); + + if (i % 2 == 0) { + y->addField("a", int64_t(1)); + y->setSchema(yASchemaId); + x->setSchema(xASchemaId); + payload->setSchema(payloadASchemaId); + obj->setSchema(rootObjASchemaId); + } else { + y->addField("b", int64_t(1)); + y->setSchema(yBSchemaId); + x->setSchema(xBSchemaId); + payload->setSchema(payloadBSchemaId); + obj->setSchema(rootObjBSchemaId); + } + + x->addField("y", y); + payload->addField("x", x); + obj->addField("payload", payload); + root->append(obj); + } + + REQUIRE(root->setSchema(arraySchemaId)); + model->addRoot(root); + + Environment env(strings); + env.querySchemaCallback = registry.asFunction(); + + auto ast = compile(env, "count(**.a == 1)", false, false); + REQUIRE(ast); + + auto modelRoot = model->root(0); + REQUIRE(modelRoot); + + registry.enabled = false; + BENCHMARK("Query nested field 'a' recursive without schema") { + auto res = eval(env, **ast, **modelRoot, nullptr); + REQUIRE(res); + REQUIRE(res->size() == 1); + + auto count = res->front().template as(); + REQUIRE(count == int64_t(n / 2)); + return count; + }; + + registry.enabled = true; + BENCHMARK("Query nested field 'a' recursive with schema") { + auto res = eval(env, **ast, **modelRoot, nullptr); + REQUIRE(res); + REQUIRE(res->size() == 1); + + auto count = res->front().template as(); + REQUIRE(count == int64_t(n / 2)); + return count; + }; +} diff --git a/test/simfil.cpp b/test/simfil.cpp index 7ec6511c..5aedf6fa 100644 --- a/test/simfil.cpp +++ b/test/simfil.cpp @@ -57,9 +57,10 @@ TEST_CASE("Path", "[ast.path]") { TEST_CASE("Wildcard", "[ast.wildcard]") { REQUIRE_AST("*", "*"); REQUIRE_AST("**", "**"); - REQUIRE_AST("**.a", "(. ** a)"); - REQUIRE_AST("a.**.b", "(. (. a **) b)"); - REQUIRE_AST("a.**.b.**.c", "(. (. (. (. a **) b) **) c)"); + REQUIRE_AST("**.a", "**.a"); /* Optimization rewrites this from (. ** a) to **.a */ + REQUIRE_AST("**.a.b.c", "(. (. **.a b) c)"); + REQUIRE_AST("a.**.b", "(. a **.b)"); + REQUIRE_AST("a.**.b.**.c", "(. (. a **.b) **.c)"); REQUIRE_AST("* == *", "(== * *)"); /* Do not optimize away */ REQUIRE_AST("** == **", "(== ** **)"); /* Do not optimize away */ @@ -425,6 +426,7 @@ TEST_CASE("Path Wildcard", "[yaml.path-wildcard]") { REQUIRE_RESULT("sub.*", R"(sub a|sub b|{"a":"sub sub a","b":"sub sub b"})"); REQUIRE_RESULT("sub.**", R"({"a":"sub a","b":"sub b","sub":{"a":"sub sub a","b":"sub sub b"}}|sub a|sub b|)" R"({"a":"sub sub a","b":"sub sub b"}|sub sub a|sub sub b)"); + REQUIRE_RESULT("**.a", "1|sub a|sub sub a"); REQUIRE_RESULT("(sub.*.{typeof _ != 'model'} + sub.*.{typeof _ != 'model'})._", "sub asub a|sub asub b|sub bsub a|sub bsub b"); /* . filters null */ REQUIRE_RESULT("sub.*.{typeof _ != 'model'} + sub.*.{typeof _ != 'model'}", "sub asub a|sub asub b|sub bsub a|sub bsub b"); /* {_} filters null */ REQUIRE_RESULT("count(*)", "12"); @@ -754,6 +756,13 @@ TEST_CASE("Visit AST", "[visit.ast]") visitedFieldName = expr.name_; } + + auto visit(const WildcardFieldExpr& expr) -> void override + { + ExprVisitor::visit(expr); + + visitedFieldName = expr.name_; + } }; Visitor visitor; @@ -761,3 +770,19 @@ TEST_CASE("Visit AST", "[visit.ast]") REQUIRE(visitor.visitedFieldName == "field"); } + +TEST_CASE("AST expr ids are reenumerated after rewrites", "[ast.expr-id]") +{ + auto ast = Compile("**.field = 123", false); + + std::vector ids; + const auto collectIds = [&](const auto& self, const Expr& expr) -> void { + ids.emplace_back(expr.id()); + for (auto i = 0u; i < expr.numChildren(); ++i) + self(self, *expr.childAt(i)); + }; + + collectIds(collectIds, ast->expr()); + + REQUIRE(ids == std::vector{0, 1, 2}); +} diff --git a/test/value.cpp b/test/value.cpp index c01640a3..35cc59f5 100644 --- a/test/value.cpp +++ b/test/value.cpp @@ -1,9 +1,12 @@ #include #include #include +#include +#include #include "simfil/value.h" #include "simfil/model/model.h" +#include "simfil/model/schema.h" #include "simfil/token.h" #include "simfil/transient.h" From a16e7a1ea9431791eb502b3b408645ce3ccce5ee Mon Sep 17 00:00:00 2001 From: Johannes Wolf Date: Sat, 9 May 2026 19:53:10 +0200 Subject: [PATCH 2/3] Harden ConstExpr Constructor --- src/expressions.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/expressions.h b/src/expressions.h index fa249236..ad7b136d 100644 --- a/src/expressions.h +++ b/src/expressions.h @@ -76,6 +76,7 @@ class ConstExpr : public Expr public: ConstExpr() = delete; template + requires (!std::derived_from, ConstExpr>) explicit ConstExpr(CType_&& value) : value_(Value::make(std::forward(value))) {} From 9bbff9282089f8c815c300ab5a9965483621b005 Mon Sep 17 00:00:00 2001 From: Johannes Wolf Date: Sat, 9 May 2026 20:01:23 +0200 Subject: [PATCH 3/3] Add Root Pruning Test --- test/schema.cpp | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/test/schema.cpp b/test/schema.cpp index 6842056b..36aee532 100644 --- a/test/schema.cpp +++ b/test/schema.cpp @@ -361,6 +361,7 @@ TEST_CASE("Schema query performance", "[perf.schema]") { const auto bId = strings->emplace("b").value(); const auto yId = strings->emplace("y").value(); const auto xId = strings->emplace("x").value(); + const auto missingId = strings->emplace("missing").value(); const auto payloadId = strings->emplace("payload").value(); auto payloadASchema = std::make_unique(); @@ -433,15 +434,18 @@ TEST_CASE("Schema query performance", "[perf.schema]") { Environment env(strings); env.querySchemaCallback = registry.asFunction(); - auto ast = compile(env, "count(**.a == 1)", false, false); - REQUIRE(ast); - auto modelRoot = model->root(0); REQUIRE(modelRoot); + auto aAst = compile(env, "count(**.a == 1)", false, false); + REQUIRE(aAst); + + auto missingAst = compile(env, "count(**.missing == 1)", false, false); + REQUIRE(missingAst); + registry.enabled = false; BENCHMARK("Query nested field 'a' recursive without schema") { - auto res = eval(env, **ast, **modelRoot, nullptr); + auto res = eval(env, **aAst, **modelRoot, nullptr); REQUIRE(res); REQUIRE(res->size() == 1); @@ -450,9 +454,19 @@ TEST_CASE("Schema query performance", "[perf.schema]") { return count; }; + BENCHMARK("Query missing field 'missing' without schema") { + auto res = eval(env, **missingAst, **modelRoot, nullptr); + REQUIRE(res); + REQUIRE(res->size() == 1); + + auto count = res->front().template as(); + REQUIRE(count == 0); + return count; + }; + registry.enabled = true; BENCHMARK("Query nested field 'a' recursive with schema") { - auto res = eval(env, **ast, **modelRoot, nullptr); + auto res = eval(env, **aAst, **modelRoot, nullptr); REQUIRE(res); REQUIRE(res->size() == 1); @@ -460,4 +474,14 @@ TEST_CASE("Schema query performance", "[perf.schema]") { REQUIRE(count == int64_t(n / 2)); return count; }; + + BENCHMARK("Query missing field 'missing' with schema") { + auto res = eval(env, **missingAst, **modelRoot, nullptr); + REQUIRE(res); + REQUIRE(res->size() == 1); + + auto count = res->front().template as(); + REQUIRE(count == 0); + return count; + }; }