diff --git a/.cirrus.yml b/.cirrus.yml index d51417db..459bbf26 100644 --- a/.cirrus.yml +++ b/.cirrus.yml @@ -3,13 +3,20 @@ jammy_task: only_if: "$CIRRUS_PR != '' || $CIRRUS_BRANCH == 'main'" # This condition ensures the task runs only for PRs or on the main branch. timeout_in: 120m container: - image: ghcr.io/dairlab/docker-dair/jammy-dair-base:v1.42 + image: ghcr.io/dairlab/docker-dair/jammy-dair-base:v1.42.0-2 cpu: 8 memory: 24 format_script: - apt update && apt install -y clang-format - ./tools/scripts/check_format.sh test_script: + - | # For PRs, merge the base branch to ensure the latest changes are included in the CI environment. This helps catch merge conflicts early and ensures tests run against the most up-to-date code. + if [ -n "$CIRRUS_PR" ]; then + git config user.email "ci@ci.com" + git config user.name "CI" + git fetch origin $CIRRUS_BASE_BRANCH + git merge origin/$CIRRUS_BASE_BRANCH --no-edit || (echo "Merge conflict with $CIRRUS_BASE_BRANCH" && exit 1) + fi - export CC=clang-15 - export CXX=clang++-15 - apt update && apt install -y python3-venv @@ -40,13 +47,20 @@ noble_task: only_if: "$CIRRUS_PR != '' || $CIRRUS_BRANCH == 'main'" timeout_in: 120m container: - image: ghcr.io/dairlab/docker-dair/noble-dair-base:v1.42 + image: ghcr.io/dairlab/docker-dair/noble-dair-base:v1.42.0-2 cpu: 8 memory: 24 format_script: - apt update && apt install -y clang-format - ./tools/scripts/check_format.sh test_script: + - | # For PRs, merge the base branch to ensure the latest changes are included in the CI environment. This helps catch merge conflicts early and ensures tests run against the most up-to-date code. + if [ -n "$CIRRUS_PR" ]; then + git config user.email "ci@ci.com" + git config user.name "CI" + git fetch origin $CIRRUS_BASE_BRANCH + git merge origin/$CIRRUS_BASE_BRANCH --no-edit || (echo "Merge conflict with $CIRRUS_BASE_BRANCH" && exit 1) + fi - export CC=clang-15 - export CXX=clang++-15 - apt update && apt install -y python3-venv diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 6882e9d0..f2610448 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -12,7 +12,7 @@ jobs: group: ci-${{ github.ref }} cancel-in-progress: true container: - image: ghcr.io/dairlab/docker-dair/noble-dair-base:v1.42 + image: ghcr.io/dairlab/docker-dair/noble-dair-base:v1.51.1 credentials: username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} diff --git a/bindings/pyc3/BUILD.bazel b/bindings/pyc3/BUILD.bazel index 6aeb3bd1..95e9419c 100644 --- a/bindings/pyc3/BUILD.bazel +++ b/bindings/pyc3/BUILD.bazel @@ -30,9 +30,7 @@ pybind_py_library( pybind_py_library( name = "traj_eval_py", - cc_deps = ["//core:lcs", - "//core:traj_eval", - "//core:c3"], + cc_deps = ["//core:traj_eval"], cc_so_name = "traj_eval", cc_srcs = ["traj_eval_py.cc"], py_deps = [ diff --git a/bindings/pyc3/__init__.py b/bindings/pyc3/__init__.py index df08b6d4..cefe5f11 100644 --- a/bindings/pyc3/__init__.py +++ b/bindings/pyc3/__init__.py @@ -1,6 +1,5 @@ # Importing everything in this directory to this package import pydrake -from . import * from .c3 import * from .traj_eval import * from .systems import * diff --git a/bindings/pyc3/c3_multibody_py.cc b/bindings/pyc3/c3_multibody_py.cc index 1a69a9e1..30c995eb 100644 --- a/bindings/pyc3/c3_multibody_py.cc +++ b/bindings/pyc3/c3_multibody_py.cc @@ -17,14 +17,32 @@ PYBIND11_MODULE(multibody, m) { m.doc() = "C3 Multibody Utilities"; // LCSFactory Class and ContactModel enum - py::enum_(m, "ContactModel") - .value("Unknown", c3::multibody::ContactModel::kUnknown) - .value("StewartAndTrinkle", - c3::multibody::ContactModel::kStewartAndTrinkle) - .value("Anitescu", c3::multibody::ContactModel::kAnitescu) - .value("FrictionlessSpring", - c3::multibody::ContactModel::kFrictionlessSpring) - .export_values(); + auto contact_model_enum = + py::enum_(m, "ContactModel") + .value("Unknown", c3::multibody::ContactModel::kUnknown) + .value("StewartAndTrinkle", + c3::multibody::ContactModel::kStewartAndTrinkle) + .value("Anitescu", c3::multibody::ContactModel::kAnitescu) + .value("FrictionlessSpring", + c3::multibody::ContactModel::kFrictionlessSpring) + .export_values(); + + contact_model_enum.attr("__str__") = py::cpp_function( + [](c3::multibody::ContactModel + model) { // Iterate through the map to find the string for the + // given enum value + for (const auto& pair : GetContactModelMap()) { + if (pair.second == model) { + return pair.first; + } + } + return std::string("unknown"); + }, + py::name("__str__"), py::is_method(contact_model_enum)); + + // Add a binding for the map itself + m.def("GetContactModelMap", &GetContactModelMap, + "Returns a map from contact model names to enum values."); py::class_(m, "LCSContactDescription") .def(py::init<>()) @@ -40,6 +58,91 @@ PYBIND11_MODULE(multibody, m) { &c3::multibody::LCSContactDescription:: CreateSlackVariableDescription); + py::class_(m, "ContactPairConfig") + .def(py::init<>()) + .def_readwrite("body_A", &ContactPairConfig::body_A) + .def_readwrite("body_B", &ContactPairConfig::body_B) + .def_readwrite("body_A_collision_geom_indices", + &ContactPairConfig::body_A_collision_geom_indices) + .def_readwrite("body_B_collision_geom_indices", + &ContactPairConfig::body_B_collision_geom_indices) + .def_readwrite("mu", &ContactPairConfig::mu) + .def_readwrite("num_friction_directions", + &ContactPairConfig::num_friction_directions) + .def_readwrite("planar_normal_direction", + &ContactPairConfig::planar_normal_direction) + .def_readwrite("num_active_contact_pairs", + &ContactPairConfig::num_active_contact_pairs); + + py::class_(m, "LCSFactoryOptions") + .def(py::init<>()) + .def_readwrite("dt", &LCSFactoryOptions::dt) + .def_readwrite("N", &LCSFactoryOptions::N) + .def_readwrite("contact_model", &LCSFactoryOptions::contact_model) + .def_property( + "num_contacts", + [](const LCSFactoryOptions& self) { + return self.num_contacts.has_value() + ? py::cast(self.num_contacts.value()) + : py::none(); + }, + [](LCSFactoryOptions& self, py::object val) { + if (val.is_none()) { + self.num_contacts.reset(); + } else { + self.num_contacts = py::cast(val); + } + }) + .def_property( + "spring_stiffness", + [](const LCSFactoryOptions& self) { + return self.spring_stiffness.has_value() + ? py::cast(self.spring_stiffness.value()) + : py::none(); + }, + [](LCSFactoryOptions& self, py::object val) { + if (val.is_none()) { + self.spring_stiffness.reset(); + } else { + self.spring_stiffness = py::cast(val); + } + }) + .def_property( + "num_friction_directions", + [](const LCSFactoryOptions& self) { + return self.num_friction_directions.has_value() + ? py::cast(self.num_friction_directions.value()) + : py::none(); + }, + [](LCSFactoryOptions& self, py::object val) { + if (val.is_none()) { + self.num_friction_directions.reset(); + } else { + self.num_friction_directions = py::cast(val); + } + }) + .def_property( + "mu", + [](const LCSFactoryOptions& self) { + return self.mu.has_value() ? py::cast(self.mu.value()) : py::none(); + }, + [](LCSFactoryOptions& self, py::object val) { + if (val.is_none()) { + self.mu.reset(); + } else { + self.mu = py::cast(val); + } + }) + .def_readwrite("num_friction_directions_per_contact", + &LCSFactoryOptions::num_friction_directions_per_contact) + .def_readwrite("mu_per_contact", &LCSFactoryOptions::mu_per_contact) + .def_readwrite("planar_normal_direction_per_contact", + &LCSFactoryOptions::planar_normal_direction_per_contact) + .def_readwrite("planar_normal_direction", + &LCSFactoryOptions::planar_normal_direction) + .def_readwrite("contact_pair_configs", + &LCSFactoryOptions::contact_pair_configs); + py::class_(m, "LCSFactory") .def(py::init&, drake::systems::Context&, @@ -76,6 +179,24 @@ PYBIND11_MODULE(multibody, m) { py::arg("other"), py::arg("active_lambda_inds"), py::arg("inactive_lambda_inds")) // Overload the function GetNumContactVariables + .def("GetNumContactVariablesInstance", + py::overload_cast<>( + &c3::multibody::LCSFactory::GetNumContactVariables, py::const_)) + .def_static( + "GetNumContactVariables", + [](const c3::LCSFactoryOptions& options) { + return c3::multibody::LCSFactory::GetNumContactVariables(options, + nullptr); + }, + py::arg("options")) + .def_static( + "GetNumContactVariables", + [](const c3::LCSFactoryOptions& options, + const drake::multibody::MultibodyPlant* plant) { + return c3::multibody::LCSFactory::GetNumContactVariables(options, + plant); + }, + py::arg("options"), py::arg("plant")) .def_static("GetNumContactVariables", py::overload_cast( &c3::multibody::LCSFactory::GetNumContactVariables), @@ -83,73 +204,10 @@ PYBIND11_MODULE(multibody, m) { py::arg("num_friction_directions")) .def_static( "GetNumContactVariables", - py::overload_cast&, - const c3::LCSFactoryOptions&>( + py::overload_cast>( &c3::multibody::LCSFactory::GetNumContactVariables), - py::arg("plant"), py::arg("options")); - - py::class_(m, "ContactPairConfig") - .def(py::init<>()) - .def_readwrite("body_A", &ContactPairConfig::body_A) - .def_readwrite("body_B", &ContactPairConfig::body_B) - .def_readwrite("body_A_collision_geom_indices", - &ContactPairConfig::body_A_collision_geom_indices) - .def_readwrite("body_B_collision_geom_indices", - &ContactPairConfig::body_B_collision_geom_indices) - .def_readwrite("mu", &ContactPairConfig::mu) - .def_readwrite("num_friction_directions", - &ContactPairConfig::num_friction_directions) - .def_readwrite("planar_normal_direction", - &ContactPairConfig::planar_normal_direction) - .def_readwrite("num_active_contact_pairs", - &ContactPairConfig::num_active_contact_pairs); - - py::class_(m, "LCSFactoryOptions") - .def(py::init<>()) - .def_readwrite("dt", &LCSFactoryOptions::dt) - .def_readwrite("N", &LCSFactoryOptions::N) - .def_property( - "contact_model", - [](const LCSFactoryOptions& self) { - // Convert string back to enum for Python - if (self.contact_model == "stewart_and_trinkle") - return c3::multibody::ContactModel::kStewartAndTrinkle; - if (self.contact_model == "anitescu") - return c3::multibody::ContactModel::kAnitescu; - if (self.contact_model == "frictionless_spring") - return c3::multibody::ContactModel::kFrictionlessSpring; - return c3::multibody::ContactModel::kUnknown; - }, - [](LCSFactoryOptions& self, c3::multibody::ContactModel val) { - // Convert enum to the string the C++ struct expects - switch (val) { - case c3::multibody::ContactModel::kStewartAndTrinkle: - self.contact_model = "stewart_and_trinkle"; - break; - case c3::multibody::ContactModel::kAnitescu: - self.contact_model = "anitescu"; - break; - case c3::multibody::ContactModel::kFrictionlessSpring: - self.contact_model = "frictionless_spring"; - break; - default: - self.contact_model = "unknown"; - break; - } - }) - .def_readwrite("num_friction_directions", - &LCSFactoryOptions::num_friction_directions) - .def_readwrite("num_friction_directions_per_contact", - &LCSFactoryOptions::num_friction_directions_per_contact) - .def_readwrite("num_contacts", &LCSFactoryOptions::num_contacts) - .def_readwrite("spring_stiffness", &LCSFactoryOptions::spring_stiffness) - .def_readwrite("mu", &LCSFactoryOptions::mu) - .def_readwrite("planar_normal_direction", - &LCSFactoryOptions::planar_normal_direction) - .def_readwrite("planar_normal_direction_per_contact", - &LCSFactoryOptions::planar_normal_direction_per_contact) - .def_readwrite("contact_pair_configs", - &LCSFactoryOptions::contact_pair_configs); + py::arg("contact_model"), py::arg("num_contacts"), + py::arg("num_friction_directions_per_contact")); m.def("LoadLCSFactoryOptions", &LoadLCSFactoryOptions); } diff --git a/bindings/pyc3/c3_systems_py.cc b/bindings/pyc3/c3_systems_py.cc index 7a9d470b..2f660a22 100644 --- a/bindings/pyc3/c3_systems_py.cc +++ b/bindings/pyc3/c3_systems_py.cc @@ -38,7 +38,6 @@ namespace systems { namespace pyc3 { PYBIND11_MODULE(systems, m) { py::module::import("pydrake.systems.framework"); - py::module::import("multibody"); // ensure LCSFactoryOptions is registered py::class_>(m, "C3Controller") .def(py::init&, const C3::CostMatrices, C3ControllerOptions>(), diff --git a/bindings/pyc3/test/test_c3.py b/bindings/pyc3/test/test_c3.py index 8ddcba7d..91eee1e3 100644 --- a/bindings/pyc3/test/test_c3.py +++ b/bindings/pyc3/test/test_c3.py @@ -99,8 +99,6 @@ def make_cartpole_options_and_costs(lcs, N=5, c3plus=False): opts.R = R_mat opts.G = G_mat opts.U = U_mat - opts.g_vector = [0.1] * n_lambda + [0.0] * n_u - opts.u_vector = [1.0] * n_lambda + [0.0] * n_u opts.warm_start = False opts.scale_lcs = False opts.end_on_qp_step = True @@ -137,8 +135,6 @@ def make_options(n_x=4, n_u=2, n_lambda=2, is_c3plus=False): n_z = n_x + n_u + n_lambda + (n_lambda if is_c3plus else 0) opts.G = np.ones((n_z, n_z)) opts.U = np.ones((n_z, n_z)) - opts.g_vector = [1.0] * n_lambda - opts.u_vector = [1.0] * n_u opts.warm_start = False opts.scale_lcs = False opts.end_on_qp_step = False diff --git a/bindings/pyc3/test/test_multibody.py b/bindings/pyc3/test/test_multibody.py index 06046db1..0bcad14b 100644 --- a/bindings/pyc3/test/test_multibody.py +++ b/bindings/pyc3/test/test_multibody.py @@ -29,18 +29,18 @@ def test_fields(self): opts.N = 3 opts.num_contacts = 2 # mu is list[float] per binding - opts.mu = [0.5] + opts.mu = 0.5 opts.spring_stiffness = 100.0 opts.num_friction_directions = 4 self.assertAlmostEqual(opts.dt, 0.01) self.assertEqual(opts.N, 3) self.assertEqual(opts.num_contacts, 2) - self.assertAlmostEqual(opts.mu[0], 0.5) + self.assertAlmostEqual(opts.mu, 0.5) def test_contact_model(self): opts = multibody.LCSFactoryOptions() - opts.contact_model = multibody.ContactModel.StewartAndTrinkle - self.assertEqual(opts.contact_model, multibody.ContactModel.StewartAndTrinkle) + opts.contact_model = str(multibody.ContactModel.StewartAndTrinkle) + self.assertEqual(opts.contact_model, "stewart_and_trinkle") def test_contact_pair_configs(self): opts = multibody.LCSFactoryOptions() @@ -83,7 +83,7 @@ def test_with_options(self): opts = multibody.LCSFactoryOptions() opts.num_contacts = 2 opts.num_friction_directions = 4 - opts.contact_model = multibody.ContactModel.StewartAndTrinkle + opts.contact_model = "stewart_and_trinkle" n = multibody.LCSFactory.GetNumContactVariables(opts) self.assertGreater(n, 0) @@ -96,9 +96,9 @@ def test_load(self): self.assertEqual(opts.N, 10) self.assertAlmostEqual(opts.dt, 0.01) self.assertEqual(opts.num_contacts, 3) - self.assertEqual(opts.contact_model, multibody.ContactModel.StewartAndTrinkle) + self.assertEqual(opts.contact_model, "stewart_and_trinkle") self.assertEqual(opts.num_friction_directions, 1) - self.assertAlmostEqual(opts.mu[0], 0.1) + self.assertAlmostEqual(opts.mu, 0.1) self.assertEqual(len(opts.contact_pair_configs), 3) self.assertEqual(opts.contact_pair_configs[0].body_A, "cube") self.assertEqual(opts.contact_pair_configs[0].body_B, "left_finger") @@ -107,6 +107,9 @@ def test_get_num_contact_variables_from_loaded_options(self): opts = multibody.LoadLCSFactoryOptions( "multibody/test/resources/lcs_factory_pivoting_options.yaml" ) + opts.contact_pair_configs = ( + None # test that GetNumContactVariables doesn't require this field + ) n = multibody.LCSFactory.GetNumContactVariables(opts) self.assertGreater(n, 0) diff --git a/bindings/pyc3/test/test_traj_eval.py b/bindings/pyc3/test/test_traj_eval.py index 5c92352f..57886918 100644 --- a/bindings/pyc3/test/test_traj_eval.py +++ b/bindings/pyc3/test/test_traj_eval.py @@ -95,8 +95,6 @@ def test_compute_quadratic_trajectory_cost_with_c3(self): opts.R = self.R_matrix opts.G = np.eye(self.n_x + self.n_u + self.n_lambda) opts.U = np.eye(self.n_x + self.n_u + self.n_lambda) - opts.g_vector = [1.0] * self.n_lambda + [0.0] * self.n_u - opts.u_vector = [1.0] * self.n_lambda + [0.0] * self.n_u opts.warm_start = False opts.scale_lcs = False opts.end_on_qp_step = True diff --git a/bindings/pyc3/traj_eval_py.cc b/bindings/pyc3/traj_eval_py.cc index b7df69e6..9c95070b 100644 --- a/bindings/pyc3/traj_eval_py.cc +++ b/bindings/pyc3/traj_eval_py.cc @@ -3,7 +3,6 @@ #include #include -#include "core/lcs.h" #include "core/traj_eval.h" namespace py = pybind11; diff --git a/examples/python/BUILD.bazel b/examples/python/BUILD.bazel index 608fea40..3ba7d4d5 100644 --- a/examples/python/BUILD.bazel +++ b/examples/python/BUILD.bazel @@ -1,5 +1,13 @@ package(default_visibility = ["//visibility:public"]) +# Library for common systems +py_library( + name = "common_systems", + srcs = ["common_systems.py"], + deps = [ + "//bindings/pyc3:pyc3", + ], +) py_binary( name = "c3_example", @@ -18,6 +26,7 @@ py_binary( deps = [ "//bindings/pyc3:pyc3", ":c3_example", + ":common_systems", # Add this dependency ], data = [ "//examples:example_data" @@ -43,8 +52,9 @@ py_binary( deps = [ "//bindings/pyc3:pyc3", ":c3_example", + ":common_systems", # Add this dependency ], data = [ "//examples:example_data" ], -) \ No newline at end of file +) diff --git a/examples/python/c3_example.py b/examples/python/c3_example.py index 0cb9acf2..5f9da7fd 100644 --- a/examples/python/c3_example.py +++ b/examples/python/c3_example.py @@ -3,7 +3,7 @@ import matplotlib.pyplot as plt from tqdm import tqdm -from pyc3 import LCS, C3MIQP, C3ControllerOptions, CostMatrices, LoadC3Options +from pyc3 import LCS, C3MIQP, C3Options, CostMatrices, LoadC3Options def make_cartpole_with_soft_walls_dynamics(N: int) -> LCS: @@ -61,7 +61,7 @@ def make_cartpole_with_soft_walls_dynamics(N: int) -> LCS: ) -def make_cartpole_costs(lcs: LCS, options: C3ControllerOptions, N: int) -> CostMatrices: +def make_cartpole_costs(lcs: LCS, options: C3Options, N: int) -> CostMatrices: R = [options.R for _ in range(N)] Q = [options.Q for _ in range(N)] diff --git a/examples/python/lcs_factory_example.py b/examples/python/lcs_factory_example.py index 168b4178..667f3785 100644 --- a/examples/python/lcs_factory_example.py +++ b/examples/python/lcs_factory_example.py @@ -67,7 +67,7 @@ def main(): pivoting = LCS.CreatePlaceholderLCS( plant.num_positions() + plant.num_velocities(), plant.num_actuators(), - LCSFactory.GetNumContactVariables(lcs_factory_options), + LCSFactory.GetNumContactVariables(lcs_factory_options, plant), lcs_factory_options.N, lcs_factory_options.dt, ) diff --git a/multibody/lcs_factory.cc b/multibody/lcs_factory.cc index 1582ee4a..a3022277 100644 --- a/multibody/lcs_factory.cc +++ b/multibody/lcs_factory.cc @@ -816,29 +816,37 @@ int LCSFactory::GetNumContactVariables(ContactModel contact_model, } int LCSFactory::GetNumContactVariables( - const drake::multibody::MultibodyPlant& plant, - const LCSFactoryOptions& options) { + const LCSFactoryOptions& options, + const drake::multibody::MultibodyPlant* plant) { multibody::ContactModel contact_model = GetContactModelMap().at(options.contact_model); - int n_contacts = options.ResolveNumContacts(); - std::vector n_friction_directions_per_contact = - options.ResolveNumFrictionDirections(); + int n_contacts; + std::vector n_friction_directions_per_contact; + if (options.contact_pair_configs.has_value()) { // If contact pair configs are provided, they take precedence over the // options for number of contacts and friction directions. We can expand the // contact pair configs to get the actual number of contacts and friction // directions per contact. + if (plant == nullptr) { + throw std::invalid_argument( + "plant must be provided when contact_pair_configs is set."); + } + // Use default context since we only need the geometry query results to // expand the contact pair configs, and the geometry query results do not // depend on the state of the plant. - auto context = plant.CreateDefaultContext(); + auto context = plant->CreateDefaultContext(); auto expanded = ExpandContactPairConfigs( - plant, *context, options.contact_pair_configs.value()); + *plant, *context, options.contact_pair_configs.value()); n_contacts = expanded.num_contacts(); n_friction_directions_per_contact = expanded.num_friction_directions_per_contact; + } else { + n_contacts = options.ResolveNumContacts(); + n_friction_directions_per_contact = options.ResolveNumFrictionDirections(); } return GetNumContactVariables(contact_model, n_contacts, diff --git a/multibody/lcs_factory.h b/multibody/lcs_factory.h index 00b40643..f3474ec9 100644 --- a/multibody/lcs_factory.h +++ b/multibody/lcs_factory.h @@ -224,14 +224,30 @@ class LCSFactory { * This is the preferred overload as it encapsulates all contact model and * friction configuration in a single options object. * - * @param plant The MultibodyPlant to analyze for contact information. * @param options The LCS options specifying contact model and friction * properties. + * @param plant Optional pointer to the MultibodyPlant, required if contact + * pair configurations are specified in the options. Used to expand contact + * pair configurations to determine the actual number of contacts and friction + * directions. * @return int The number of contact variables. */ static int GetNumContactVariables( - const drake::multibody::MultibodyPlant& plant, - const LCSFactoryOptions& options); + const LCSFactoryOptions& options, + const drake::multibody::MultibodyPlant* plant = nullptr); + + /** + * @brief Get the Num Contact Variables object based on the internal state of + * the factory. + * + * This method returns the number of contact variables (n_lambda_) that was + * computed during the construction of the LCSFactory. This value is + * determined by the contact model and the number of contacts, and is used to + * define the size of the contact force variable in the generated LCS. + * + * @return int + */ + [[nodiscard]] int GetNumContactVariables() const { return n_lambda_; } private: /** diff --git a/multibody/test/multibody_test.cc b/multibody/test/multibody_test.cc index 08697d12..973332f2 100644 --- a/multibody/test/multibody_test.cc +++ b/multibody/test/multibody_test.cc @@ -145,22 +145,21 @@ TEST_F(LCSFactoryPivotingTest, GetNumContactVariables) { options.contact_model = "stewart_and_trinkle"; options.num_friction_directions = 4; options.num_contacts = 2; - EXPECT_EQ(LCSFactory::GetNumContactVariables(*fixture.plant, options), 20); + EXPECT_EQ(LCSFactory::GetNumContactVariables(options), 20); options.contact_model = "anitescu"; options.num_friction_directions = 2; options.num_contacts = 3; - EXPECT_EQ(LCSFactory::GetNumContactVariables(*fixture.plant, options), 12); + EXPECT_EQ(LCSFactory::GetNumContactVariables(options), 12); options.contact_model = "frictionless_spring"; options.num_friction_directions = 0; options.num_contacts = 3; - EXPECT_EQ(LCSFactory::GetNumContactVariables(*fixture.plant, options), 3); + EXPECT_EQ(LCSFactory::GetNumContactVariables(options), 3); // Test error handling for invalid contact model options.contact_model = "some_random_contact_model"; - EXPECT_THROW(LCSFactory::GetNumContactVariables(*fixture.plant, options), - std::out_of_range); + EXPECT_THROW(LCSFactory::GetNumContactVariables(options), std::out_of_range); } // Test that contact pairs can be parsed from options instead of explicit list @@ -190,8 +189,8 @@ TEST_F(LCSFactoryPivotingTest, ContactPairParsing) { EXPECT_EQ(lcs.num_states(), fixture.plant->num_positions() + fixture.plant->num_velocities()); EXPECT_EQ(lcs.num_inputs(), fixture.plant->num_actuators()); - EXPECT_EQ(lcs.num_lambdas(), LCSFactory::GetNumContactVariables( - *fixture.plant, fixture.options)); + EXPECT_EQ(lcs.num_lambdas(), + LCSFactory::GetNumContactVariables(fixture.options, fixture.plant)); } // Parameterized test fixture for testing different contact models and friction @@ -238,8 +237,8 @@ TEST_P(LCSFactoryParameterizedPivotingTest, GenerateLCS) { EXPECT_EQ(lcs.num_states(), fixture.plant->num_positions() + fixture.plant->num_velocities()); EXPECT_EQ(lcs.num_inputs(), fixture.plant->num_actuators()); - EXPECT_EQ(lcs.num_lambdas(), LCSFactory::GetNumContactVariables( - *fixture.plant, fixture.options)); + EXPECT_EQ(lcs.num_lambdas(), + LCSFactory::GetNumContactVariables(fixture.options, fixture.plant)); } // Test static linearization method for different contact models @@ -261,8 +260,8 @@ TEST_P(LCSFactoryParameterizedPivotingTest, LinearizePlantToLCS) { EXPECT_EQ(lcs.num_states(), fixture.plant->num_positions() + fixture.plant->num_velocities()); EXPECT_EQ(lcs.num_inputs(), fixture.plant->num_actuators()); - EXPECT_EQ(lcs.num_lambdas(), LCSFactory::GetNumContactVariables( - *fixture.plant, fixture.options)); + EXPECT_EQ(lcs.num_lambdas(), + LCSFactory::GetNumContactVariables(fixture.options)); } // Test that updating state and input changes contact-dependent LCS matrices diff --git a/systems/c3_controller.cc b/systems/c3_controller.cc index 1b99a584..05f7d322 100644 --- a/systems/c3_controller.cc +++ b/systems/c3_controller.cc @@ -71,7 +71,7 @@ C3Controller::C3Controller( // Determine the size of lambda based on the contact model n_lambda_ = multibody::LCSFactory::GetNumContactVariables( - plant_, controller_options_.lcs_factory_options); + controller_options_.lcs_factory_options, &plant_); // Placeholder vector for initialization VectorXd zeros = VectorXd::Zero(n_x_ + n_lambda_ + n_u_); diff --git a/systems/lcs_factory_system.cc b/systems/lcs_factory_system.cc index 85c28ec7..8ac7ec5c 100644 --- a/systems/lcs_factory_system.cc +++ b/systems/lcs_factory_system.cc @@ -53,7 +53,7 @@ void LCSFactorySystem::InitializeSystem( this->set_name("lcs_factory_system"); n_x_ = plant.num_positions() + plant.num_velocities(); - n_lambda_ = multibody::LCSFactory::GetNumContactVariables(plant, options); + n_lambda_ = multibody::LCSFactory::GetNumContactVariables(options, &plant); n_u_ = plant.num_actuators(); lcs_state_input_port_ = diff --git a/systems/test/systems_test.cc b/systems/test/systems_test.cc index 6fd56f06..8c98d9eb 100644 --- a/systems/test/systems_test.cc +++ b/systems/test/systems_test.cc @@ -358,7 +358,7 @@ TEST_F(LCSFactorySystemTest, OutputLCSIsValid) { EXPECT_EQ(lcs.num_inputs(), plant->num_actuators()); EXPECT_EQ(lcs.num_lambdas(), LCSFactory::GetNumContactVariables( - *plant, controller_options.lcs_factory_options)); + controller_options.lcs_factory_options, plant)); EXPECT_EQ(lcs.dt(), controller_options.lcs_factory_options.dt); EXPECT_EQ(lcs.N(), controller_options.lcs_factory_options.N); }