diff --git a/R/create_default_parameters.R b/R/create_default_parameters.R index b72cd2f67..943806ed6 100644 --- a/R/create_default_parameters.R +++ b/R/create_default_parameters.R @@ -304,15 +304,11 @@ create_default_Population <- function( init_naa[n_ages] <- init_naa[n_ages] / M_value # sum of infinite series # Create a list of default parameters - default <- create_default_parameters_template( - n_parameters = n_years * n_ages - ) |> + default <- create_default_parameters_template() |> # Add the module type, label, value, and estimation type dplyr::mutate( label = "log_M", value = log(M_value), - age = rep(ages, n_years), - time = rep(years, each = n_ages), estimation_type = "constant" ) |> dplyr::add_row( diff --git a/inst/include/interface/rcpp/rcpp_objects/rcpp_models.hpp b/inst/include/interface/rcpp/rcpp_objects/rcpp_models.hpp index c588e52a0..e8b848700 100644 --- a/inst/include/interface/rcpp/rcpp_objects/rcpp_models.hpp +++ b/inst/include/interface/rcpp/rcpp_objects/rcpp_models.hpp @@ -364,9 +364,18 @@ class CatchAtAgeInterface : public FisheryModelInterfaceBase { ss << " \"id\":" << population_interface->log_M.id_m << ",\n"; ss << " \"type\": \"vector\",\n"; ss << " \"dimensionality\": {\n"; - ss << " \"header\": [" << "\"n_years\", \"n_ages\"" << "],\n"; - ss << " \"dimensions\": [" << population_interface->n_years.get() << ", " - << population_interface->n_ages.get() << "]\n},\n"; + if (population_interface->log_M.size() == + static_cast(population_interface->n_years.get() * + population_interface->n_ages.get())) { + ss << " \"header\": [\"n_years\", \"n_ages\"],\n"; + ss << " \"dimensions\": [" << population_interface->n_years.get() + << ", " << population_interface->n_ages.get() << "]\n"; + } else { + ss << " \"header\": [\"scalar\"],\n"; + ss << " \"dimensions\": [" << population_interface->log_M.size() + << "]\n"; + } + ss << "},\n"; ss << " \"values\": " << population_interface->log_M << "\n\n"; ss << "},\n"; @@ -1000,13 +1009,27 @@ class CatchAtAgeInterface : public FisheryModelInterfaceBase { population->n_ages.get()}, fims::Vector{"n_years", "n_ages"}); - derived_quantities["mortality_M"] = fims::Vector( - population->n_years.get() * population->n_ages.get()); - derived_quantities_dim_info["mortality_M"] = fims_popdy::DimensionInfo( - "mortality_M", - fims::Vector{population->n_years.get(), - population->n_ages.get()}, - fims::Vector{"n_years", "n_ages"}); + derived_quantities["mortality_M"] = + fims::Vector(population->log_M.size()); + fims::Vector dim_sizes; + fims::Vector dim_names; + if (population->log_M.size() == + static_cast(population->n_years.get() * + population->n_ages.get())) { + dim_sizes.resize(2); + dim_sizes[0] = static_cast(population->n_years.get()); + dim_sizes[1] = static_cast(population->n_ages.get()); + dim_names.resize(2); + dim_names[0] = "n_years"; + dim_names[1] = "n_ages"; + } else { + dim_sizes.resize(1); + dim_sizes[0] = static_cast(population->log_M.size()); + dim_names.resize(1); + dim_names[0] = "scalar"; + } + derived_quantities_dim_info["mortality_M"] = + fims_popdy::DimensionInfo("mortality_M", dim_sizes, dim_names); derived_quantities["mortality_Z"] = fims::Vector( population->n_years.get() * population->n_ages.get()); diff --git a/inst/include/interface/rcpp/rcpp_objects/rcpp_population.hpp b/inst/include/interface/rcpp/rcpp_objects/rcpp_population.hpp index edb73f1af..2648d9c75 100644 --- a/inst/include/interface/rcpp/rcpp_objects/rcpp_population.hpp +++ b/inst/include/interface/rcpp/rcpp_objects/rcpp_population.hpp @@ -328,6 +328,17 @@ class PopulationInterface : public PopulationInterfaceBase { population->growth_id = this->growth_id.get(); population->recruitment_id = this->recruitment_id.get(); population->maturity_id = this->maturity_id.get(); + if (this->log_M.size() != + static_cast(this->n_years.get() * this->n_ages.get()) && + this->log_M.size() != 1) { + throw std::invalid_argument( + "Population log_M size mismatch. " + "Population log_M is of size " + + fims::to_string(this->log_M.size()) + + ". Population log_M can only be either of size 1 or n_years * " + + "n_ages, and the number of n_years * n_ages is " + + fims::to_string(this->n_years.get() * this->n_ages.get())); + } population->log_M.resize(this->log_M.size()); if (this->log_f_multiplier.size() == diff --git a/inst/include/models/functors/catch_at_age.hpp b/inst/include/models/functors/catch_at_age.hpp index ade4889f0..81826fbf7 100644 --- a/inst/include/models/functors/catch_at_age.hpp +++ b/inst/include/models/functors/catch_at_age.hpp @@ -203,8 +203,8 @@ class CatchAtAge : public FisheryModelBase { for (size_t age = 0; age < population->n_ages; age++) { for (size_t year = 0; year < population->n_years; year++) { size_t i_age_year = age * population->n_years + year; - population->M[i_age_year] = - fims_math::exp(population->log_M[i_age_year]); + population->M.get_force_scalar(i_age_year) = + fims_math::exp(population->log_M.get_force_scalar(i_age_year)); } } @@ -349,14 +349,15 @@ class CatchAtAge : public FisheryModelBase { // using M from previous age/year dq_["unfished_numbers_at_age"][i_age_year] = dq_["unfished_numbers_at_age"][i_agem1_yearm1] * - (fims_math::exp(-population->M[i_agem1_yearm1])); + (fims_math::exp(-population->M.get_force_scalar(i_agem1_yearm1))); // Plus group calculation if (age == (population->n_ages - 1)) { dq_["unfished_numbers_at_age"][i_age_year] = dq_["unfished_numbers_at_age"][i_age_year] + dq_["unfished_numbers_at_age"][i_agem1_yearm1 + 1] * - (fims_math::exp(-population->M[i_agem1_yearm1 + 1])); + (fims_math::exp( + -population->M.get_force_scalar(i_agem1_yearm1 + 1))); } } @@ -408,10 +409,12 @@ class CatchAtAge : public FisheryModelBase { dq_["sum_selectivity"][i_age_year] += s; } - dq_["mortality_M"][i_age_year] = population->M[i_age_year]; + dq_["mortality_M"].get_force_scalar(i_age_year) = + population->M.get_force_scalar(i_age_year); dq_["mortality_Z"][i_age_year] = - population->M[i_age_year] + dq_["mortality_F"][i_age_year]; + population->M.get_force_scalar(i_age_year) + + dq_["mortality_F"][i_age_year]; } /** @@ -577,16 +580,23 @@ class CatchAtAge : public FisheryModelBase { dq_["proportion_mature_at_age"][0] * population->growth->evaluate(0, population->ages[0]); for (size_t a = 1; a < (population->n_ages - 1); a++) { - numbers_spr[a] = numbers_spr[a - 1] * fims_math::exp(-population->M[a]); + // pull out M from the first year + size_t i_age_year = 0 * population->n_ages + a; + numbers_spr[a] = + numbers_spr[a - 1] * + fims_math::exp(-population->M.get_force_scalar(i_age_year)); phi_0 += numbers_spr[a] * population->proportion_female[a] * dq_["proportion_mature_at_age"][a] * population->growth->evaluate(0, population->ages[a]); } numbers_spr[population->n_ages - 1] = + // M will be from first year (numbers_spr[population->n_ages - 2] * - fims_math::exp(-population->M[population->n_ages - 2])) / - (1 - fims_math::exp(-population->M[population->n_ages - 1])); + fims_math::exp( + -population->M.get_force_scalar(population->n_ages - 2))) / + (1 - fims_math::exp( + -population->M.get_force_scalar(population->n_ages - 1))); phi_0 += numbers_spr[population->n_ages - 1] * population->proportion_female[population->n_ages - 1] * dq_["proportion_mature_at_age"][population->n_ages - 1] * diff --git a/tests/testthat/test-integration-fleet-log-obs-error-input.R b/tests/testthat/test-integration-fleet-log-obs-error-input.R index 2374592a8..213ce053a 100644 --- a/tests/testthat/test-integration-fleet-log-obs-error-input.R +++ b/tests/testthat/test-integration-fleet-log-obs-error-input.R @@ -183,3 +183,119 @@ test_that("`log_Fmort` returns correct error messages when wrong dimensions", { clear() }) + + +## IO correctness ---- +test_that("`log_M` output dimensions follow size rules", { + # Check scalar log_M input + parameters_4_model <- default_parameters |> + tidyr::unnest(cols = data) + + test_fit <- parameters_4_model |> + initialize_fims(data = data_4_model) |> + fit_fims(optimize = FALSE) + + output <- get_estimates(test_fit) |> + dplyr::mutate( + uncertainty_label = "se", + year = year_i, + estimate = estimated + ) + + log_m_input <- parameters_4_model |> + dplyr::filter(label == "log_M") |> + dplyr::pull(value) + + mortality_m_output <- output |> + dplyr::filter(label == "mortality_M") |> + dplyr::pull(estimated) + + log_m_output <- output |> + dplyr::filter(label == "log_M") |> + dplyr::pull(estimated) + + #' @description Test that log_M input is scalar. + expect_true(length(log_m_input) == 1) + #' @description Test that log_M output is scalar. + expect_true(length(log_m_output) == 1) + #' @description Test that mortality_M output is scalar. + expect_true(length(mortality_m_output) == 1) + + # Check log_M input with n_years * n_ages dimensions + n_years <- get_n_years(data_4_model) + n_ages <- get_n_ages(data_4_model) + + parameters_4_model <- default_parameters |> + tidyr::unnest(cols = data) + + log_m_template <- parameters_4_model |> + dplyr::filter(label == "log_M") |> + dplyr::select(-time, -age, -value) + + log_m_grid <- tidyr::expand_grid( + time = 1:n_years, + age = 1:n_ages + ) |> + dplyr::mutate(value = log_m_input[1]) |> + dplyr::bind_cols(log_m_template[rep(1, n_years * n_ages), ]) + + parameters_4_model <- parameters_4_model |> + dplyr::filter(!(module_name == "Population" & label == "log_M")) |> + dplyr::bind_rows(log_m_grid) + + test_fit <- parameters_4_model |> + initialize_fims(data = data_4_model) |> + fit_fims(optimize = FALSE) + + output <- get_estimates(test_fit) |> + dplyr::mutate( + uncertainty_label = "se", + year = year_i, + estimate = estimated + ) + + log_m_input <- parameters_4_model |> + dplyr::filter(label == "log_M") |> + dplyr::pull(value) + + mortality_m_output <- output |> + dplyr::filter(label == "mortality_M") |> + dplyr::pull(estimated) + + log_m_output <- output |> + dplyr::filter(label == "log_M") |> + dplyr::pull(estimated) + + #' @description Test that log_M input has n_years * n_ages dimensions. + expect_equal(length(log_m_input), n_years * n_ages) + #' @description Test that log_M output matches n_years * n_ages. + expect_equal(length(log_m_output), n_years * n_ages) + #' @description Test that mortality_M output matches n_years * n_ages. + expect_equal(length(mortality_m_output), n_years * n_ages) + + clear() +}) + + +## Error handling ---- +test_that("`log_M` returns correct error messages when wrong dimensions", { + #' @description Test that returns correct error message when log_M is too long. + parameters_4_model <- default_parameters |> + tidyr::unnest(cols = data) |> + dplyr::add_row( + model_family = "catch_at_age", + module_name = "Population", + estimation_type = "constant", + label = "log_M", + value = -3 + ) + + expect_error( + { + test_fit <- parameters_4_model |> + initialize_fims(data = data_4_model) |> + fit_fims(optimize = FALSE) + }, + regexp = "Population log_M size mismatch" + ) +})