Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/dsf/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,23 @@ PYBIND11_MODULE(dsf_cpp, m) {

// Bind Measurement to main module (can be used across different contexts)
pybind11::class_<dsf::Measurement<double>>(m, "Measurement")
.def(pybind11::init<double, double>(),
.def(pybind11::init<double, double, std::size_t>(),
pybind11::arg("mean"),
pybind11::arg("std"),
pybind11::arg("n"),
dsf::g_docstrings.at("dsf::Measurement::Measurement").c_str())
.def_readwrite("mean",
&dsf::Measurement<double>::mean,
dsf::g_docstrings.at("dsf::Measurement::mean").c_str())
.def_readwrite("std",
&dsf::Measurement<double>::std,
dsf::g_docstrings.at("dsf::Measurement::std").c_str());
dsf::g_docstrings.at("dsf::Measurement::std").c_str())
.def_readwrite("n",
&dsf::Measurement<double>::n,
dsf::g_docstrings.at("dsf::Measurement::n").c_str())
.def_readwrite("is_valid",
&dsf::Measurement<double>::is_valid,
dsf::g_docstrings.at("dsf::Measurement::is_valid").c_str());

// Bind mobility-related classes to mobility submodule
pybind11::class_<dsf::mobility::RoadNetwork>(mobility, "RoadNetwork")
Expand Down
28 changes: 16 additions & 12 deletions src/dsf/mobility/RoadDynamics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ namespace dsf::mobility {
///
/// - std_speed_kph: The standard deviation of the speed in kilometers per hour
///
/// - n_observations: The number of speed observations used to compute avg/std (0 if none)
///
/// - counts: The counts of the coil sensor (can be null)
///
/// - queue_length: The length of the queue on the street
Expand Down Expand Up @@ -1315,6 +1317,7 @@ namespace dsf::mobility {
"density_vpk REAL, "
"avg_speed_kph REAL, "
"std_speed_kph REAL, "
"n_observations INTEGER, "
"counts INTEGER, "
"queue_length INTEGER)");

Expand Down Expand Up @@ -1902,6 +1905,7 @@ namespace dsf::mobility {
double density;
std::optional<double> avgSpeed;
std::optional<double> stdSpeed;
std::optional<std::size_t> nObservations;
std::optional<std::size_t> counts;
std::size_t queueLength;
};
Expand All @@ -1914,9 +1918,6 @@ namespace dsf::mobility {
auto const numNodes{this->graph().nNodes()};
auto const numEdges{this->graph().nEdges()};

// Adaptive grain: if fewer active agents than threads, collapse to one block
// (effectively serial) to avoid TBB scheduling overhead on a nearly empty network.
const auto nCurrentAgents = static_cast<std::size_t>(this->nAgents());
const auto grainSize = std::max<std::size_t>(1, numNodes / n_threads);
this->m_taskArena.execute([&] {
tbb::parallel_for(
Expand Down Expand Up @@ -1974,6 +1975,7 @@ namespace dsf::mobility {
if (speedMeasure.is_valid) {
record.avgSpeed = speedMeasure.mean * 3.6; // to kph
record.stdSpeed = speedMeasure.std * 3.6;
record.nObservations = speedMeasure.n;
}
record.queueLength = pStreet->nExitingAgents();
streetDataRecords.push_back(record);
Expand Down Expand Up @@ -2007,8 +2009,9 @@ namespace dsf::mobility {
SQLite::Statement insertStmt(
*this->database(),
"INSERT INTO road_data (datetime, time_step, simulation_id, street_id, "
"coil, density_vpk, avg_speed_kph, std_speed_kph, counts, queue_length) "
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)");
"coil, density_vpk, avg_speed_kph, std_speed_kph, n_observations, counts, "
"queue_length) "
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)");

for (auto const& record : streetDataRecords) {
insertStmt.bind(1, this->strDateTime());
Expand All @@ -2028,12 +2031,13 @@ namespace dsf::mobility {
insertStmt.bind(7);
insertStmt.bind(8);
}
insertStmt.bind(9, static_cast<std::int64_t>(record.nObservations.value_or(0)));
if (record.counts.has_value()) {
insertStmt.bind(9, static_cast<std::int64_t>(record.counts.value()));
insertStmt.bind(10, static_cast<std::int64_t>(record.counts.value()));
} else {
insertStmt.bind(9);
insertStmt.bind(10);
}
insertStmt.bind(10, static_cast<std::int64_t>(record.queueLength));
insertStmt.bind(11, static_cast<std::int64_t>(record.queueLength));
insertStmt.exec();
insertStmt.reset();
}
Expand Down Expand Up @@ -2596,7 +2600,7 @@ namespace dsf::mobility {
requires(is_numeric_v<delay_t>)
Measurement<double> RoadDynamics<delay_t>::meanTravelDistance(bool clearData) {
if (m_travelDTs.empty()) {
return Measurement(0., 0.);
return Measurement<double>();
}
std::vector<double> travelDistances;
travelDistances.reserve(m_travelDTs.size());
Expand Down Expand Up @@ -2666,7 +2670,7 @@ namespace dsf::mobility {
requires(is_numeric_v<delay_t>)
Measurement<double> RoadDynamics<delay_t>::streetMeanDensity(bool normalized) const {
if (this->graph().edges().empty()) {
return Measurement(0., 0.);
return Measurement<double>();
}
std::vector<double> densities;
densities.reserve(this->graph().nEdges());
Expand All @@ -2681,10 +2685,10 @@ namespace dsf::mobility {
sum += pStreet->length();
}
if (sum == 0) {
return Measurement(0., 0.);
return Measurement<double>();
}
auto meanDensity{std::accumulate(densities.begin(), densities.end(), 0.) / sum};
return Measurement(meanDensity, 0.);
return Measurement<double>(meanDensity, 0., densities.size());
}
return Measurement<double>(densities);
}
Expand Down
13 changes: 9 additions & 4 deletions src/dsf/utility/Measurement.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,35 @@
namespace dsf {
/// @brief The Measurement struct represents the mean of a quantity and its standard deviation
/// @tparam T The type of the quantity
/// @param mean The mean
/// @param mean The mean of the sample
/// @param std The standard deviation of the sample
/// @param n The number of samples
/// @param is_valid True if the measurement is valid, false otherwise (i.e. checks if the sample is not empty)
template <typename T>
struct Measurement {
T mean = static_cast<T>(0);
T std = static_cast<T>(0);
std::size_t n = 0;
bool is_valid = false;

Measurement(T mean, T std) : mean{mean}, std{std}, is_valid{true} {}
Measurement() = default;

Check notice

Code scanning / Cppcheck (reported by Codacy)

MISRA 16.3 rule Note

MISRA 16.3 rule
Measurement(T mean, T std, std::size_t n)
: mean{mean}, std{std}, n{n}, is_valid{true} {}
template <typename TContainer>
Measurement(TContainer const& data) {
if (data.empty()) {
return;
}
n = data.size();
is_valid = true;
auto x_mean = static_cast<T>(0), x2_mean = static_cast<T>(0);

std::for_each(data.begin(), data.end(), [&x_mean, &x2_mean](auto value) -> void {
x_mean += value;
x2_mean += value * value;
});
mean = x_mean / data.size();
std = std::sqrt(x2_mean / data.size() - mean * mean);
mean = x_mean / n;
std = std::sqrt(x2_mean / n - mean * mean);

Check notice

Code scanning / Cppcheck (reported by Codacy)

MISRA 12.1 rule Note

MISRA 12.1 rule
}
};
} // namespace dsf
8 changes: 8 additions & 0 deletions test/mobility/Test_dynamics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
Measurement<float> m(data);
CHECK_EQ(m.mean, 49.5f);
CHECK_EQ(m.std, doctest::Approx(28.8661f));
CHECK_EQ(m.n, 100);
CHECK(m.is_valid);
}
SUBCASE("STL array") {
std::array<float, 100> data;
Expand All @@ -42,6 +44,8 @@
Measurement<float> m(data);
CHECK_EQ(m.mean, 49.5f);
CHECK_EQ(m.std, doctest::Approx(28.8661f));
CHECK_EQ(m.n, 100);
CHECK(m.is_valid);
}
SUBCASE("STL span") {
auto p = std::make_unique_for_overwrite<float[]>(100);
Expand All @@ -51,6 +55,8 @@
Measurement<float> m(data);
CHECK_EQ(m.mean, 49.5f);
CHECK_EQ(m.std, doctest::Approx(28.8661f));
CHECK_EQ(m.n, 100);
CHECK(m.is_valid);
}
}

Expand Down Expand Up @@ -1285,7 +1291,9 @@
CHECK(roadColumns.count("density_vpk") == 1);
CHECK(roadColumns.count("avg_speed_kph") == 1);
CHECK(roadColumns.count("std_speed_kph") == 1);
CHECK(roadColumns.count("n_observations") == 1);

Check notice

Code scanning / Cppcheck (reported by Codacy)

MISRA 10.4 rule Note test

MISRA 10.4 rule
CHECK(roadColumns.count("counts") == 1);
CHECK(roadColumns.count("queue_length") == 1);

Check notice

Code scanning / Cppcheck (reported by Codacy)

MISRA 10.4 rule Note test

MISRA 10.4 rule

// Check avg_stats table
SQLite::Statement avgQuery(db, "SELECT COUNT(*) FROM avg_stats");
Expand Down
Loading