Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions examples/sph/run_sph_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,16 @@ def H_profile(r):
setup.apply_setup(gen_disc)

model.do_vtk_dump("init_disc.vtk", True)
model.dump("tmp.sham")


ctx = shamrock.Context()
model = shamrock.get_Model_SPH(context=ctx, vector_type="f64_3", sph_kernel="M4")
model.load_from_dump("tmp.sham")

model.dump("tmp2.sham")

exit()
Comment thread
tdavidcl marked this conversation as resolved.
model.change_htolerances(coarse=1.3, fine=1.1)
model.timestep()
model.change_htolerances(coarse=1.1, fine=1.1)
Expand Down
5 changes: 5 additions & 0 deletions src/shammodels/sph/src/Model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "shamrock/patch/PatchDataLayer.hpp"
#include "shamrock/scheduler/DataInserterUtility.hpp"
#include "shamrock/scheduler/PatchScheduler.hpp"
#include "shamrock/solvergraph/ScalarEdgeSerializable.hpp"
#include "shamsys/NodeInstance.hpp"
#include "shamsys/legacy/log.hpp"
#include <functional>
Expand Down Expand Up @@ -69,6 +70,10 @@ void shammodels::sph::Model<Tvec, SPHKernel>::init() {

PatchScheduler &sched = shambase::get_check_ref(ctx.sched);

auto time_edge = sched.synchronized_data.container.register_edge(
"time", shamrock::solvergraph::ScalarEdgeSerializable<Tscal>("time", "t"));
time_edge->value = 0;

sched.add_root_patch();

shamlog_debug_ln("Sys", "build local scheduler tables");
Expand Down
28 changes: 25 additions & 3 deletions src/shamrock/include/shamrock/scheduler/PatchScheduler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,19 @@
#include "shambase/DistributedData.hpp"
#include "shambase/stacktrace.hpp"
#include "shambase/time.hpp"
#include "nlohmann/json_fwd.hpp"
#include "shamalgs/collective/distributedDataComm.hpp"
#include "shamrock/legacy/patch/utility/patch_field.hpp"
#include "shamrock/solvergraph/NodeSetEdge.hpp"
#include "shamrock/solvergraph/PatchDataLayerRefs.hpp"
#include "shamrock/solvergraph/SolverGraph.hpp"
#include <nlohmann/json.hpp>
#include <unordered_map>
#include <unordered_set>
#include <fstream>
#include <functional>
#include <memory>
#include <optional>
#include <stdexcept>
#include <tuple>
#include <vector>
Expand All @@ -47,8 +51,25 @@
#include "shamrock/scheduler/HilbertLoadBalance.hpp"
#include "shamrock/scheduler/PatchTree.hpp"
#include "shamrock/scheduler/SchedulerPatchData.hpp"
#include "shamrock/solvergraph/IEdgeNamed.hpp"
#include "shamrock/solvergraph/JsonSerializable.hpp"
#include "shamsys/legacy/sycl_handler.hpp"

inline std::unordered_map<
std::string,
std::function<std::shared_ptr<shamrock::solvergraph::IEdge>(const nlohmann::json &j)>>
deser_map = {};

/// Data stored within the scheduler that are garanteed to be in sink across all ranks
struct SynchronizedData {
shamrock::solvergraph::SolverGraph container
= shamrock::solvergraph::SolverGraph::with_constraint(
std::nullopt, shamrock::solvergraph::json_serializable_edge_constraint);

nlohmann::json to_json();

void from_json(const nlohmann::json &j);
};
struct PatchSchedulerConfig {
u64 split_load_value = 0_u64;
u64 merge_load_value = 0_u64;
Expand Down Expand Up @@ -98,9 +119,10 @@ class PatchScheduler {
u64 crit_patch_split; ///< splitting limit (if load value > crit_patch_split => patch split)
u64 crit_patch_merge; ///< merging limit (if load value < crit_patch_merge => patch merge)

SchedulerPatchList patch_list; ///< handle the list of the patches of the scheduler
SchedulerPatchData patch_data; ///< handle the data of the patches of the scheduler
PatchTree patch_tree; ///< handle the tree structure of the patches
SchedulerPatchList patch_list; ///< handle the list of the patches of the scheduler
SchedulerPatchData patch_data; ///< handle the data of the patches of the scheduler
PatchTree patch_tree; ///< handle the tree structure of the patches
SynchronizedData synchronized_data; ///< data that is synchroneous across all ranks

// using unordered set is not an issue since we use the find command after
std::unordered_set<u64> owned_patch_id; ///< list of owned patch ids updated with
Expand Down
1 change: 1 addition & 0 deletions src/shamrock/include/shamrock/solvergraph/IEdgeNamed.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ namespace shamrock::solvergraph {

virtual std::string _impl_get_dot_label() const { return name; }
virtual std::string _impl_get_tex_symbol() const { return "{" + texsymbol + "}"; }
virtual std::string get_raw_tex_symbol() const { return texsymbol; }
};

} // namespace shamrock::solvergraph
38 changes: 38 additions & 0 deletions src/shamrock/include/shamrock/solvergraph/JsonSerializable.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// -------------------------------------------------------//
//
// SHAMROCK code for hydrodynamics
// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
//
// -------------------------------------------------------//

#pragma once

/**
* @file JsonSerializable.hpp
* @author Timothée David--Cléris (tim.shamrock@proton.me)
* @brief
*/

#include "shamrock/solvergraph/IEdge.hpp"
#include <nlohmann/json.hpp>
#include <string>

namespace shamrock::solvergraph {

struct JsonSerializable {
virtual ~JsonSerializable() {};

virtual void to_json(nlohmann::json &j) = 0;
virtual void from_json(const nlohmann::json &j) = 0;

virtual std::string type_name() = 0;
};

inline bool json_serializable_edge_constraint(
const std::shared_ptr<shamrock::solvergraph::IEdge> &edge) {
// check that the edge can be cross-casted to JsonSerializable
return bool(std::dynamic_pointer_cast<JsonSerializable>(edge));
};
} // namespace shamrock::solvergraph
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// -------------------------------------------------------//
//
// SHAMROCK code for hydrodynamics
// Copyright (c) 2021-2026 Timothée David--Cléris <tim.shamrock@proton.me>
// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1
// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information
//
// -------------------------------------------------------//

#pragma once

/**
* @file ScalarEdgeSerializable.hpp
* @author Timothée David--Cléris (tim.shamrock@proton.me)
* @brief
*
*/

#include "shambase/exception.hpp"
#include "shambase/pre_main_call.hpp"
#include "shambase/string.hpp"
#include "shambase/type_name_info.hpp"
#include "nlohmann/json_fwd.hpp"
#include "shamrock/scheduler/PatchScheduler.hpp"
#include "shamrock/solvergraph/IEdge.hpp"
#include "shamrock/solvergraph/ScalarEdge.hpp"
#include <memory>
#include <stdexcept>

namespace shamrock::solvergraph {

template<class T>
class ScalarEdgeSerializable : public ScalarEdge<T>, public JsonSerializable {
public:
using ScalarEdge<T>::ScalarEdge;
using ScalarEdge<T>::value;

virtual void to_json(nlohmann::json &j) {
j = nlohmann::json{
{"type", type_name()},
{"value", value},
{"label", this->get_label()},
{"tex_symbol", this->get_raw_tex_symbol()}};
};

virtual void from_json(const nlohmann::json &j) {
std::string type = j.at("type");

if (type != type_name()) {
throw shambase::make_except_with_loc<std::runtime_error>(shambase::format(
"error when deserializing ScalarEdgeSerializable, expected type info "
"\"{}\" but got \"{}\"",
type_name(),
type));
}

value = j.at("value").get<T>();
};

inline static std::string type_name_static() {
return "ScalarEdgeSerializable<" + shambase::get_type_name<T>() + ">";
}

virtual std::string type_name() { return type_name_static(); };
};

} // namespace shamrock::solvergraph

template<class T>
void register_ctor_deser() {

auto ctor = [](const nlohmann::json &j) -> std::shared_ptr<shamrock::solvergraph::IEdge> {
std::string label = j.at("label").get<std::string>();
std::string tex_symbol = j.at("tex_symbol").get<std::string>();

return std::make_shared<shamrock::solvergraph::ScalarEdgeSerializable<T>>(
label, tex_symbol);
};

deser_map.insert({shamrock::solvergraph::ScalarEdgeSerializable<T>::type_name_static(), ctor});
}

PRE_MAIN_FUNCTION_CALL([&]() {
register_ctor_deser<f64>();
})
75 changes: 73 additions & 2 deletions src/shamrock/include/shamrock/solvergraph/SolverGraph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,40 @@
*
*/

#include "shambase/exception.hpp"
#include "shambase/memory.hpp"
#include "shamrock/solvergraph/IEdge.hpp"
#include "shamrock/solvergraph/INode.hpp"
#include <unordered_map>
#include <functional>
#include <memory>
#include <optional>
#include <stdexcept>
#include <vector>

namespace shamrock::solvergraph {

struct SolverGraphContraint {
std::optional<std::function<bool(const std::shared_ptr<INode> &)>> _validate_node;
std::optional<std::function<bool(const std::shared_ptr<IEdge> &)>> _validate_edge;

inline static SolverGraphContraint no_constraint() { return {std::nullopt, std::nullopt}; }

inline bool validate_node(const std::shared_ptr<INode> &node) {
if (_validate_node) {
return (*_validate_node)(node);
}
return true;
}

inline bool validate_edge(const std::shared_ptr<IEdge> &edge) {
if (_validate_edge) {
return (*_validate_edge)(edge);
}
return true;
}
};

/**
* @brief A graph container for managing solver nodes and edges with type-safe access.
*
Expand Down Expand Up @@ -55,16 +81,31 @@ namespace shamrock::solvergraph {
*/
class SolverGraph {
/// Registry of nodes by name
std::unordered_map<std::string, std::shared_ptr<INode>> nodes;
std::unordered_map<std::string, std::shared_ptr<INode>> nodes = {};

/// Registry of edges by name
std::unordered_map<std::string, std::shared_ptr<IEdge>> edges;
std::unordered_map<std::string, std::shared_ptr<IEdge>> edges = {};

SolverGraphContraint constraint = SolverGraphContraint::no_constraint();

public:
///////////////////////////////////////
// base getters and setters
///////////////////////////////////////

SolverGraph() = default;

SolverGraph(
std::optional<std::function<bool(const std::shared_ptr<INode> &)>> _validate_node,
std::optional<std::function<bool(const std::shared_ptr<IEdge> &)>> _validate_edge)
: constraint(SolverGraphContraint{_validate_node, _validate_edge}) {}

inline static SolverGraph with_constraint(
std::optional<std::function<bool(const std::shared_ptr<INode> &)>> _validate_node,
std::optional<std::function<bool(const std::shared_ptr<IEdge> &)>> _validate_edge) {
return SolverGraph{_validate_node, _validate_edge};
}

/**
* @brief Register a node with the graph using a shared pointer.
*
Expand All @@ -74,6 +115,12 @@ namespace shamrock::solvergraph {
*/
inline std::shared_ptr<INode> register_node_ptr_base(
const std::string &name, std::shared_ptr<INode> node) {

if (!constraint.validate_node(node)) {
throw shambase::make_except_with_loc<std::invalid_argument>(
"node validation failed under solvergraph constraint");
}

const auto [it, inserted] = nodes.try_emplace(name, std::move(node));
if (!inserted) {
shambase::throw_with_loc<std::invalid_argument>(
Expand All @@ -91,6 +138,12 @@ namespace shamrock::solvergraph {
*/
inline std::shared_ptr<IEdge> register_edge_ptr_base(
const std::string &name, std::shared_ptr<IEdge> edge) {

if (!constraint.validate_edge(edge)) {
throw shambase::make_except_with_loc<std::invalid_argument>(
"edge validation failed under solvergraph constraint");
}

const auto [it, inserted] = edges.try_emplace(name, std::move(edge));
if (!inserted) {
shambase::throw_with_loc<std::invalid_argument>(
Expand Down Expand Up @@ -336,6 +389,24 @@ namespace shamrock::solvergraph {
inline const T &get_edge_ref(const std::string &name) const {
return shambase::get_check_ref(get_edge_ptr<T>(name));
}

std::vector<std::string> get_edge_names() {
std::vector<std::string> ret{};

for (auto &[k, e] : edges) {
ret.push_back(k);
}
Comment thread
tdavidcl marked this conversation as resolved.
return ret;
}

std::vector<std::string> get_node_names() {
std::vector<std::string> ret{};

for (auto &[k, n] : nodes) {
ret.push_back(k);
}
Comment thread
tdavidcl marked this conversation as resolved.
return ret;
}
};

} // namespace shamrock::solvergraph
3 changes: 3 additions & 0 deletions src/shamrock/src/io/ShamrockDump.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ namespace shamrock {
sched.patch_list = jmeta_patch.at("patchlist").get<SchedulerPatchList>();
sched.patch_tree = jmeta_patch.at("patchtree").get<scheduler::PatchTree>();
sched.patch_data.sim_box.from_json(jmeta_patch.at("sim_box"));
if (jmeta_patch.contains("synchronized_data")) {
sched.synchronized_data.from_json(jmeta_patch.at("synchronized_data"));
}

// edit patch owner to fit in new world size, or spread if more processes now
// a bit dirty but gets the job done for now
Expand Down
Loading
Loading