diff --git a/.github/ISSUE_TEMPLATE/bug.md b/.github/ISSUE_TEMPLATE/bug.md new file mode 100644 index 00000000..1ed54bc4 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug.md @@ -0,0 +1,46 @@ +--- +name: Bug Report +about: Submit a bug report +title: "[Bug Report] Bug title" + +--- + +If you are submitting a bug report, please fill in the following details and use the tag [bug]. + +### Describe the bug + +A clear and concise description of what the bug is. + +### Steps to reproduce + +Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful. + + + +### System Info + +Describe the characteristic of your environment: + + +- Commit: [e.g. 8f3b9ca or main branch] +- OS: [e.g. Ubuntu 22.04] +- GPU: [e.g. RTX 3060] +- CUDA: [e.g. 12.8] +- GPU Driver: [e.g. 570.195.03, this can be seen by using `nvidia-smi` command.] + +### Additional context + +Add any other context about the problem here. + +### Checklist + +- [ ] I have checked that there is no similar issue in the repo (**required**) + diff --git a/.github/ISSUE_TEMPLATE/proposal.md b/.github/ISSUE_TEMPLATE/proposal.md new file mode 100644 index 00000000..8b7792a2 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/proposal.md @@ -0,0 +1,26 @@ +--- +name: Proposal +about: Propose changes that are not bug fixes +title: "[Proposal] Proposal title" +--- + + +### Proposal + +A clear and concise description of the proposal. In a few sentences, describe the feature and its core capabilities. + +### Motivation + +Please outline the motivation for the proposal. Summarize the core use cases and user problems and needs you are trying to solve. + +Is your feature request related to a problem? e.g.,"I'm always frustrated when [...]". + +If this is related to another GitHub issue, please link here too. + +### Additional context + +Add any other context or screenshots about the feature request here. + +### Checklist + +- [ ] I have checked that there is no similar issue in the repo (**required**) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 00000000..bcdc2439 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,54 @@ +# Description + + + +Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. +List any dependencies that are required for this change. + +Fixes # (issue) + + + +## Type of change + + + +- Bug fix (non-breaking change which fixes an issue) +- New feature (non-breaking change which adds functionality) +- Breaking change (existing functionality will not work without user modification) +- Documentation update + +## Screenshots + +Please attach before and after screenshots of the change if applicable. + + + +## Checklist + +- [ ] I have run the `black .` command to format the code base. +- [ ] I have made corresponding changes to the documentation +- [ ] I have added tests that prove my fix is effective or that my feature works +- [ ] Dependencies have been updated, if applicable. + + diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..4857c898 --- /dev/null +++ b/.gitignore @@ -0,0 +1,222 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +#python api built by docs +# docs/source/api + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# ignore file for vlac demo +cache + +*.filament +*.filamat +*.png +*.tiff +*.npy +*.pkl +*.hdf5 +*.pt +*.csv +materials + +!embodichain/agents/dexforce_vla/data/empty_lang_embed.pt +!resources/Arch.png +embodichain/toolkits/outputs/* +embodichain/toolkits/outputs/* + +embodichain/database/* + +3rdparty/ +Log/ +embodichain/deploy/h1/sim/unitree_h1 +embodichain/deploy/inspire/inspire_hand +embodichain/devices/camera/king_fisher/kingfisher + +# tensorboard logs +embodichain/agents/policy/runs/* +embodichain/agents/dexrdt/wandb +*.pth +outputs +test_configs/* + +wandb/ +embodichain/deploy/mobile_aloha/collect_data/data +*.mp4 + +# vscode settings +.vscode/ + +# web server backend +*/backend/datas/ +*/backend/logs/ +*/backend/__pycache__/ + +# web server frontend +*/backend/export_datas/ +*/backend/**/__pycache__/ + +# web server frontend +!web_server/frontend/**/* +!web_server/frontend/**/event* +*/frontend/node_modules +*/frontend/.DS_Store +*/frontend/*.local +*/frontend/.eslintcache +*/frontend/.stylelintcache diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 00000000..17245082 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,23 @@ +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Set the OS, Python version, and other tools you might need +build: + os: ubuntu-22.04 + tools: + python: "3.10" + +# Build documentation in the "docs/" directory with Sphinx +sphinx: + configuration: docs/source/conf.py + +# Optionally, but recommended, +# declare the Python requirements required to build your documentation +# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html +python: + install: + - requirements: docs/requirements.txt + \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 00000000..61c352a4 --- /dev/null +++ b/README.md @@ -0,0 +1,45 @@ +# EmbodiChain + +![teaser](assets/imgs/teaser.jpg) +**📘 [Documentation](http://192.168.3.120/MixedAI/docs_dev/embodichain/index.html)** +--- + +EmbodiChain is an end-to-end, GPU-accelerated framework for Embodied AI. It streamlines research and development by unifying high-performance simulation, real-to-sim data pipelines, modular model architectures, and efficient training workflows. This integration enables rapid experimentation, seamless deployment of intelligent agents, and effective Sim2Real transfer for real-world robotic systems. + +> [!NOTE] +> EmbodiChain is in Alpha and under active development: +> * More features will be continually added in the coming months. +> * Since this is an early release, we welcome feedback (bug reports, feature requests, etc.) via GitHub Issues. + + +## Key Features + +- **High-Fidelity, GPU-Accelerated Simulation**: Combines realistic physics for both rigid and deformable objects with advanced ray-traced sensor modeling, all accelerated on the GPU for high-throughput batched simulations. +- **Unified Robot Learning Environment**: Offers standardized interfaces for a wide range of robot learning tasks, including Imitation Learning and Reinforcement Learning. +- **Scalable Data Pipeline**: Features a comprehensive toolkit for automated data collection, efficient processing, and large-scale data generation to fuel your models. +- **Efficient Training & Evaluation**: Supports modern training paradigms like online data streaming for Imitation Learning and massively parallel environment rollouts for Reinforcement Learning. +- **Modular and Extensible**: Designed with modularity in mind to easily integrate new robot platforms, environments, and learning algorithms. + + +## Getting Started + +To get started with EmbodiChain, follow these steps: + +- [Installation Guide](http://192.168.3.120/MixedAI/docs_dev/embodichain/quick_start/install.html) +- [Quick Start Tutorial](http://192.168.3.120/MixedAI/docs_dev/embodichain/tutorial/index.html) +- [API Reference](http://192.168.3.120/MixedAI/docs_dev/embodichain/api_reference/index.html) + + +## Citation + +If you use EmbodiChain in your research, please cite our work: + +```bibtex +@misc{EmbodiChain, + author = {EmbodiChain Developers}, + title = {EmbodiChain: An end-to-end, GPU-accelerated, and modular platform for building generalized Embodied Intelligence.}, + month = {November}, + year = {2025}, + url = {https://github.com/DexForce/EmbodiChain} +} +``` \ No newline at end of file diff --git a/VERSION b/VERSION new file mode 100644 index 00000000..8acdd82b --- /dev/null +++ b/VERSION @@ -0,0 +1 @@ +0.0.1 diff --git a/assets/imgs/teaser.jpg b/assets/imgs/teaser.jpg new file mode 100644 index 00000000..59f12c3d Binary files /dev/null and b/assets/imgs/teaser.jpg differ diff --git a/configs/agents/rl/push_cube/gym_config.json b/configs/agents/rl/push_cube/gym_config.json new file mode 100644 index 00000000..b78e7f37 --- /dev/null +++ b/configs/agents/rl/push_cube/gym_config.json @@ -0,0 +1,129 @@ +{ + "id": "PushCubeRL", + "max_episodes": 5, + "env": { + "num_envs": 128, + "sim_steps_per_control": 4, + "events": { + "randomize_cube": { + "func": "randomize_rigid_object_pose", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "cube"}, + "position_range": [[-0.2, -0.2, 0.0], [0.2, 0.2, 0.0]], + "relative_position": true + } + } + }, + "observations": {}, + "extensions": { + "obs_mode": "state", + "episode_length": 100, + "joint_limits": 0.5, + "action_scale": 0.1, + "success_threshold": 0.1, + "reaching_reward_weight": 0.1, + "place_reward_weight": 2.0, + "place_penalty_weight": 0.5, + "action_penalty_weight": 0.01, + "success_bonus_weight": 10.0 + } + }, + "robot": { + "uid": "Manipulator", + "urdf_cfg": { + "components": [ + { + "component_type": "arm", + "urdf_path": "UniversalRobots/UR10/UR10.urdf", + "transform": [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0] + ] + }, + { + "component_type": "hand", + "urdf_path": "DH_PGI_140_80/DH_PGI_140_80.urdf", + "transform": [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0] + ] + } + ] + }, + "init_pos": [0.0, 0.0, 0.0], + "init_rot": [0.0, 0.0, 0.0], + "init_qpos": [0.0, -1.57, 1.57, -1.57, -1.57, 0.0, 0.04, 0.04], + "drive_pros": { + "drive_type": "force", + "stiffness": 100000.0, + "damping": 1000.0, + "max_velocity": 2.0, + "max_effort": 500.0 + }, + "solver_cfg": { + "arm": { + "class_type": "PytorchSolver", + "end_link_name": "ee_link", + "root_link_name": "base_link", + "tcp": [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.16], + [0.0, 0.0, 0.0, 1.0] + ] + } + }, + "control_parts": { + "arm": ["JOINT[1-6]"] + } + }, + "sensor": [], + "light": { + }, + "background": [ + { + "uid": "goal_sphere", + "shape": { + "shape_type": "Sphere", + "radius": 0.02 + }, + "body_type": "kinematic", + "init_pos": [-0.9, -0.6, 0.05], + "attrs": { + "enable_collision": false, + "mass": 0.0 + } + } + ], + "rigid_object": [ + { + "uid": "cube", + "shape": { + "shape_type": "Cube", + "size": [0.1, 0.1, 0.1] + }, + "body_type": "dynamic", + "init_pos": [-0.6, -0.4, 0.05], + "attrs": { + "mass": 10.0, + "static_friction": 3.0, + "dynamic_friction": 2.0, + "linear_damping": 2.0, + "angular_damping": 2.0, + "contact_offset": 0.003, + "rest_offset": 0.001, + "restitution": 0.1, + "max_depenetration_velocity": 10.0, + "max_linear_velocity": 1.0, + "max_angular_velocity": 1.0 + } + } + ], + "rigid_object_group": [], + "articulation": [] +} diff --git a/configs/agents/rl/push_cube/train_config.json b/configs/agents/rl/push_cube/train_config.json new file mode 100644 index 00000000..f1558fda --- /dev/null +++ b/configs/agents/rl/push_cube/train_config.json @@ -0,0 +1,64 @@ +{ + "trainer": { + "exp_name": "push_cube_ppo", + "gym_config": "configs/agents/rl/push_cube/gym_config.json", + "seed": 42, + "device": "cuda:0", + "headless": true, + "iterations": 1000, + "rollout_steps": 1024, + "eval_freq": 2, + "save_freq": 200, + "use_wandb": false, + "wandb_project_name": "embodychain-push_cube", + "events": { + "eval": { + "record_camera": { + "func": "record_camera_data_async", + "mode": "interval", + "interval_step": 1, + "params": { + "name": "main_cam", + "resolution": [640, 480], + "eye": [-1.4, 1.4, 2.0], + "target": [0, 0, 0], + "up": [0, 0, 1], + "intrinsics": [600, 600, 320, 240], + "save_path": "./outputs/videos/eval" + } + } + } + } + }, + "policy": { + "name": "actor_critic", + "actor": { + "type": "mlp", + "network_cfg": { + "hidden_sizes": [256, 256], + "activation": "relu" + } + }, + "critic": { + "type": "mlp", + "network_cfg": { + "hidden_sizes": [256, 256], + "activation": "relu" + } + } + }, + "algorithm": { + "name": "ppo", + "cfg": { + "learning_rate": 0.0001, + "n_epochs": 10, + "batch_size": 8192, + "gamma": 0.99, + "gae_lambda": 0.95, + "clip_coef": 0.2, + "ent_coef": 0.01, + "vf_coef": 0.5, + "max_grad_norm": 0.5 + } + } +} \ No newline at end of file diff --git a/configs/gym/action_bank/conf.json b/configs/gym/action_bank/conf.json new file mode 100644 index 00000000..3c5e9850 --- /dev/null +++ b/configs/gym/action_bank/conf.json @@ -0,0 +1,237 @@ +{ + "scope": { + "right_arm": { + "type": "DiGraph", + "dim": [ + 6 + ], + "init": { + "method": "given_qpos", + "kwargs": { + "given_qpos": [ + 0, + 0, + 0, + 0, + 0, + 0 + ] + }, + "init_node_name": "" + }, + "dtype": "float32" + }, + "left_arm": { + "type": "DiGraph", + "dim": [ + 6 + ], + "init": { + "method": "given_qpos", + "kwargs": { + "given_qpos": [ + 0, + 0, + 0, + 0, + 0, + 0 + ] + }, + "init_node_name": "" + }, + "dtype": "float32" + }, + "left_eef": { + "type": "DiGraph", + "dim": [ + 2 + ], + "init": { + "method": "given_qpos", + "kwargs": { + "given_qpos": [ + 0, + 0 + ] + }, + "init_node_name": "" + }, + "dtype": "float32" + }, + "right_eef": { + "type": "DiGraph", + "dim": [ + 2 + ], + "init": { + "method": "given_qpos", + "kwargs": { + "given_qpos": [ + 0, + 0 + ] + }, + "init_node_name": "" + }, + "dtype": "float32" + } + }, + "node": { + "right_arm": [ + { + "home_qpos": { + "name": "A", + "kwargs": {} + } + }, + { + "bottle_grasp": { + "name": "B", + "kwargs": {} + } + }, + { + "bottle_pre1_pose": { + "name": "C", + "kwargs": {} + } + }, + { + "bottle_pre2_pose": { + "name": "C", + "kwargs": {} + } + }, + { + "bottle_place_pose": { + "name": "D", + "kwargs": {} + } + } + ], + "left_arm": [ + { + "lhome_qpos": { + "name": "a", + "kwargs": {} + } + }, + { + "cup_monitor_pose": { + "name": "b", + "kwargs": {} + } + } + ], + "left_eef": [ + { + "open": { + "name": "aa", + "kwargs": {} + } + }, + { + "close": { + "name": "bb", + "kwargs": {} + } + } + ], + "right_eef": [ + { + "ropen": { + "name": "cc", + "kwargs": {} + } + }, + { + "rclose": { + "name": "dd", + "kwargs": {} + } + } + ] + }, + "edge": { + "right_arm": [ + { + "init_to_pre1": { + "src": "home_qpos", + "sink": "bottle_pre1_pose", + "duration": 1, + "kwargs": {} + } + }, + { + "grasp_to_move": { + "src": "bottle_pre1_pose", + "sink": "bottle_grasp", + "duration": 2, + "kwargs": {} + } + }, + { + "move_to_rotation": { + "src": "bottle_grasp", + "sink": "bottle_pre2_pose", + "duration": 3, + "kwargs": {} + } + }, + { + "rotation_back_to_move": { + "src": "bottle_pre2_pose", + "sink": "bottle_place_pose", + "duration": 4, + "kwargs": {} + } + } + ], + "left_arm": [ + { + "init_to_monitor": { + "src": "lhome_qpos", + "sink": "cup_monitor_pose", + "duration": 1, + "kwargs": {} + } + }, + { + "left_arm_go_back": { + "src": "cup_monitor_pose", + "sink": "lhome_qpos", + "duration": 2, + "kwargs": {} + } + } + ], + "left_eef": [ + { + "lopen": { + "src": "close", + "sink": "open", + "duration": 10, + "kwargs": {} + } + } + ], + "right_eef": [ + { + "ropen": { + "src": "rclose", + "sink": "ropen", + "duration": 10, + "kwargs": {} + } + } + ] + }, + "sync": { + "grasp_to_move": { + "depend_tasks": [ + "init_to_monitor" + ] + } + } +} \ No newline at end of file diff --git a/configs/gym/cobotmagic.json b/configs/gym/cobotmagic.json new file mode 100644 index 00000000..6eef0718 --- /dev/null +++ b/configs/gym/cobotmagic.json @@ -0,0 +1,120 @@ +{ + "id": "EmbodiedEnv-v1", + "max_episodes": 10, + "env": { + "events": { + "random_light": { + "func": "randomize_light", + "mode": "interval", + "interval_step": 10, + "params": { + "entity_cfg": {"uid": "light_1"}, + "position_range": [[-0.5, -0.5, 2], [0.5, 0.5, 2]], + "color_range": [[0.6, 0.6, 0.6], [1, 1, 1]], + "intensity_range": [50.0, 100.0] + } + }, + "random_material": { + "func": "randomize_visual_material", + "mode": "interval", + "interval_step": 2, + "params": { + "entity_cfg": {"uid": "table"}, + "random_texture_prob": 0.5, + "texture_path": "CocoBackground/coco", + "base_color_range": [[0.2, 0.2, 0.2], [1.0, 1.0, 1.0]] + } + }, + "random_robot": { + "func": "randomize_visual_material", + "mode": "interval", + "interval_step": 5, + "params": { + "entity_cfg": {"uid": "CobotMagic", "link_names": [".*"]}, + "random_texture_prob": 0.5, + "texture_path": "CocoBackground/coco", + "base_color_range": [[0.2, 0.2, 0.2], [1.0, 1.0, 1.0]] + } + }, + "record_camera": { + "func": "record_camera_data", + "mode": "interval", + "interval_step": 1, + "params": { + "name": "cam1", + "resolution": [320, 240], + "eye": [2, 0, 2], + "target": [0.5, 0, 1] + } + }, + "replace_fork": { + "func": "replace_assets_from_group", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "fork"}, + "folder_path": "TableWare/tableware/fork/" + } + } + } + }, + "sensor": [ + { + "sensor_type": "Camera", + "width": 640, + "height": 480, + "enable_mask": true, + "enable_depth": true, + "extrinsics": { + "eye": [0.0, 0.0, 1.0], + "target": [0.0, 0.0, 0.0] + } + } + ], + "robot": { + "robot_type": "CobotMagic", + "init_pos": [0.0, 0.3, 1.2] + }, + "light": { + "direct": [ + { + "uid": "light_1", + "light_type": "point", + "color": [1.0, 1.0, 1.0], + "intensity": 50.0, + "init_pos": [0, 0, 2], + "radius": 10.0 + } + ] + }, + "background": [ + { + "uid": "table", + "shape": { + "shape_type": "Mesh", + "fpath": "ShopTableSimple/shop_table_simple.ply" + }, + "attrs" : { + "mass": 10.0 + }, + "body_scale": [2, 1.6, 1], + "body_type": "kinematic" + } + ], + "rigid_object": [ + { + "uid": "fork", + "shape": { + "shape_type": "Mesh", + "fpath": "TableWare/tableware/fork/standard_fork_scale.ply" + }, + "body_scale": [0.75, 0.75, 1.0], + "init_pos": [0.0, 0.0, 1.0] + } + ], + "articulation": [ + { + "fpath": "SlidingBoxDrawer/SlidingBoxDrawer.urdf", + "init_pos": [0.5, 0.0, 0.85] + } + ] +} diff --git a/configs/gym/dexforce_w1.json b/configs/gym/dexforce_w1.json new file mode 100644 index 00000000..42f0f073 --- /dev/null +++ b/configs/gym/dexforce_w1.json @@ -0,0 +1,109 @@ +{ + "id": "EmbodiedEnv-v1", + "max_episodes": 10, + "env": { + "events": { + "random_light": { + "func": "randomize_light", + "mode": "interval", + "interval_step": 10, + "params": { + "entity_cfg": {"uid": "light_1"}, + "position_range": [[-0.5, -0.5, 2], [0.5, 0.5, 2]], + "color_range": [[0.6, 0.6, 0.6], [1, 1, 1]], + "intensity_range": [50.0, 100.0] + } + }, + "random_material": { + "func": "randomize_visual_material", + "mode": "interval", + "interval_step": 2, + "params": { + "entity_cfg": {"uid": "table"}, + "random_texture_prob": 0.5, + "texture_path": "CocoBackground/coco", + "base_color_range": [[0.2, 0.2, 0.2], [1.0, 1.0, 1.0]] + } + }, + "record_camera": { + "func": "record_camera_data", + "mode": "interval", + "interval_step": 1, + "params": { + "name": "cam1", + "resolution": [320, 240], + "eye": [2, 0, 2], + "target": [0.5, 0, 1] + } + }, + "replace_fork": { + "func": "replace_assets_from_group", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "fork"}, + "folder_path": "TableWare/tableware/fork/" + } + } + } + }, + "sensor": [ + { + "sensor_type": "Camera", + "width": 640, + "height": 480, + "enable_mask": true, + "enable_depth": true, + "extrinsics": { + "eye": [0.0, 0.0, 1.0], + "target": [0.0, 0.0, 0.0] + } + } + ], + "robot": { + "robot_type": "DexforceW1", + "init_pos": [0.0, 1.0, 0] + }, + "light": { + "direct": [ + { + "uid": "light_1", + "light_type": "point", + "color": [1.0, 1.0, 1.0], + "intensity": 50.0, + "init_pos": [0, 0, 2], + "radius": 10.0 + } + ] + }, + "background": [ + { + "uid": "table", + "shape": { + "shape_type": "Mesh", + "fpath": "ShopTableSimple/shop_table_simple.ply" + }, + "attrs" : { + "mass": 10.0 + }, + "body_scale": [2, 1.6, 1], + "body_type": "kinematic" + } + ], + "rigid_object": [ + { + "uid": "fork", + "shape": { + "shape_type": "Mesh", + "fpath": "TableWare/tableware/fork/standard_fork_scale.ply" + }, + "body_scale": [0.75, 0.75, 1.0], + "init_pos": [0.0, 0.0, 1.0] + } + ], + "articulation": [ + { + "fpath": "SlidingBoxDrawer/SlidingBoxDrawer.urdf", + "init_pos": [0.5, 0.0, 0.85] + } + ] +} diff --git a/configs/gym/pour_water/action_config.json b/configs/gym/pour_water/action_config.json new file mode 100644 index 00000000..42f036a8 --- /dev/null +++ b/configs/gym/pour_water/action_config.json @@ -0,0 +1,938 @@ +{ + "scope": { + "right_arm": { + "type": "DiGraph", + "dim": [ + 6 + ], + "init": { + "method": "current_qpos", + "init_node_name": "right_arm_init_qpos" + }, + "dtype": "float32" + }, + "left_arm": { + "type": "DiGraph", + "dim": [ + 6 + ], + "init": { + "method": "current_qpos", + "init_node_name": "left_arm_init_qpos" + }, + "dtype": "float32" + }, + "left_eef": { + "type": "DiGraph", + "dim": [ + 1 + ], + "init": { + "method": "given_qpos", + "kwargs": { + "given_qpos": [ + 1 + ] + }, + "init_node_name": "" + }, + "dtype": "float32" + }, + "right_eef": { + "type": "DiGraph", + "dim": [ + 1 + ], + "init": { + "method": "given_qpos", + "kwargs": { + "given_qpos": [ + 1 + ] + }, + "init_node_name": "" + }, + "dtype": "float32" + } + }, + "node": { + "right_arm": [ + { + "right_arm_init_qpos": { + "name": "generate_affordances_from_src", + "kwargs": { + "affordance_infos": [ + { + "src_key": "right_arm_init_qpos", + "dst_key": "right_arm_init_pose", + "valid_funcs_name_kwargs_proc": [ + { + "name": "no_validation", + "kwargs": {}, + "pass_processes": [ + { + "name": "get_fk_xpos", + "kwargs": { + "control_part": "right_arm", + "fk_func": "env.robot.compute_fk" + } + } + ] + } + ] + } + ] + + } + } + }, + { + "right_arm_aim_qpos": { + "name": "generate_right_arm_aim_qpos", + "kwargs": {} + } + }, + { + "bottle_grasp": { + "name": "generate_affordances_from_src", + "kwargs": { + "affordance_infos": [ + { + "src_key": "bottle_pose", + "dst_key": "bottle_grasp_pose", + "valid_funcs_name_kwargs_proc": [ + { + "name": "no_validation", + "kwargs": {}, + "pass_processes": [ + { + "name": "get_rotation_replaced_pose", + "kwargs": { + "rotation_value": "env.affordance_datas['right_arm_aim_qpos'][0]", + "rot_axis": "z", + "mode": "intrinsic" + } + }, + { + "name": "get_frame_changed_pose", + "kwargs": { + "frame_change_matrix": "env.affordance_datas['bottle_pose']", + "mode": "intrinsic", + "inverse": true + } + }, + { + "name": "get_frame_changed_pose", + "kwargs": { + "frame_change_matrix": "env.affordance_datas['bottle_grasp_pose']", + "mode": "intrinsic" + } + } + ] + } + ] + }, + { + "src_key": "bottle_grasp_pose", + "dst_key": "bottle_pre1_pose", + "valid_funcs_name_kwargs_proc": [ + { + "name": "no_validation", + "kwargs": {}, + "pass_processes": [ + { + "name": "get_offset_pose", + "kwargs": { + "offset_value": -0.05, + "direction": "z", + "mode": "intrinsic" + } + } + ] + } + ] + }, + { + "src_key": "bottle_pre1_pose", + "dst_key": "bottle_pre2_pose", + "valid_funcs_name_kwargs_proc": [ + { + "name": "no_validation", + "kwargs": {}, + "pass_processes": [ + { + "name": "get_offset_pose", + "kwargs": { + "offset_value": -0.05, + "direction": "z", + "mode": "intrinsic" + } + } + ] + } + ] + } + ] + } + } + }, + { + "bottle_pre1_qpos": { + "name": "generate_affordances_from_src", + "kwargs": { + "affordance_infos": [ + { + "src_key": "bottle_pre1_pose", + "dst_key": "bottle_pre1_qpos", + "valid_funcs_name_kwargs_proc": [ + { + "name": "get_ik_ret", + "kwargs": { + "ik_func": "env.robot.compute_ik", + "qpos_seed": "env.affordance_datas['right_arm_init_qpos']", + "control_part": "right_arm" + }, + "pass_processes": [ + { + "name": "get_ik_qpos", + "kwargs": { + "ik_func": "env.robot.compute_ik", + "qpos_seed": "env.affordance_datas['right_arm_init_qpos']", + "control_part": "right_arm" + } + } + ] + }, + { + "name": "is_qpos_flip", + "kwargs": { + "qpos_ref": "env.affordance_datas['right_arm_init_qpos']", + "qpos_ids": [3, 4], + "threshold": 3.455751918948773, + "mode": "delta", + "return_inverse": true + } + } + ] + } + ] + } + } + }, + { + "bottle_grasp_qpos": { + "name": "generate_affordances_from_src", + "kwargs": { + "affordance_infos": [ + { + "src_key": "bottle_grasp_pose", + "dst_key": "bottle_grasp_qpos", + "valid_funcs_name_kwargs_proc": [ + { + "name": "get_ik_ret", + "kwargs": { + "ik_func": "env.robot.compute_ik", + "qpos_seed": "env.affordance_datas['bottle_pre1_qpos']", + "control_part": "right_arm" + }, + "pass_processes": [ + { + "name": "get_ik_qpos", + "kwargs": { + "ik_func": "env.robot.compute_ik", + "qpos_seed": "env.affordance_datas['bottle_pre1_qpos']", + "control_part": "right_arm" + } + } + ] + }, + { + "name": "is_qpos_flip", + "kwargs": { + "qpos_ref": "env.affordance_datas['right_arm_init_qpos']", + "qpos_ids": [3, 4], + "threshold": 3.455751918948773, + "mode": "delta", + "return_inverse": true + } + } + ] + } + ] + + } + } + }, + { + "bottle_up_qpos": { + "name": "generate_affordances_from_src", + "kwargs": { + "affordance_infos": [ + { + "src_key": "bottle_grasp_qpos", + "dst_key": "bottle_up_qpos", + "valid_funcs_name_kwargs_proc": [ + { + "name": "no_validation", + "kwargs": {}, + "pass_processes": [ + { + "name": "get_offset_qpos", + "kwargs": { + "offset_value": -0.05, + "joint_list_offset": [1] + } + } + ] + }, + { + "name": "is_qpos_exceed_new", + "kwargs": { + "robot": "env.robot", + "control_part": "right_arm" + } + }, + { + "name": "is_qpos_flip", + "kwargs": { + "qpos_ref": "env.affordance_datas['right_arm_init_qpos']", + "qpos_ids": [3, 4], + "threshold": 3.455751918948773, + "mode": "delta", + "return_inverse": true + } + } + ] + } + ] + } + } + }, + { + "pour_water_start_qpos": { + "name": "generate_affordances_from_src", + "kwargs": { + "affordance_infos": [ + { + "src_key": "bottle_grasp_pose", + "dst_key": "pour_water_start_pose", + "valid_funcs_name_kwargs_proc": [ + { + "name":"no_validation", + "kwargs": {}, + "pass_processes": [ + { + "name": "get_replaced_pose", + "kwargs": { + "pose_replace_value": "env.affordance_datas['cup_pose'][:2,3]", + "axis_str_replace": "xy" + } + }, + { + "name": "get_offset_pose", + "kwargs": { + "offset_value": 0.05, + "direction": "x", + "mode": "extrinsic" + } + }, + { + "name": "get_offset_pose", + "kwargs": { + "offset_value": -0.10, + "direction": "y", + "mode": "extrinsic" + } + }, + { + "name": "get_offset_pose", + "kwargs": { + "offset_value": 0.125, + "direction": "z", + "mode": "extrinsic" + } + } + ] + } + ] + }, + { + "src_key": "pour_water_start_pose", + "dst_key": "pour_water_start_qpos", + "valid_funcs_name_kwargs_proc": [ + { + "name": "get_ik_ret", + "kwargs": { + "ik_func": "env.robot.compute_ik", + "qpos_seed": "env.affordance_datas['bottle_up_qpos']", + "control_part": "right_arm" + }, + "pass_processes": [ + { + "name": "get_ik_qpos", + "kwargs": { + "ik_func": "env.robot.compute_ik", + "qpos_seed": "env.affordance_datas['bottle_up_qpos']", + "control_part": "right_arm" + } + } + ] + }, + { + "name": "is_qpos_flip", + "kwargs": { + "qpos_ref": "env.affordance_datas['right_arm_init_qpos']", + "qpos_ids": [3, 4], + "threshold": 3.455751918948773, + "mode": "delta", + "return_inverse": true + } + } + ] + } + ] + } + } + }, + { + "pour_water_stop_pose": { + "name": "generate_affordances_from_src", + "kwargs": { + "affordance_infos": [ + { + "src_key": "pour_water_start_qpos", + "dst_key": "bottle_rotation_qpos", + "valid_funcs_name_kwargs_proc": [ + { + "name": "no_validation", + "kwargs": {}, + "pass_processes": [ + { + "name": "get_offset_qpos", + "kwargs": { + "offset_value": -75, + "joint_list_offset": [5], + "degrees": true + } + } + ] + }, + { + "name": "is_qpos_exceed_new", + "kwargs": { + "robot": "env.robot", + "control_part": "right_arm" + } + }, + { + "name": "is_qpos_flip", + "kwargs": { + "qpos_ref": "env.affordance_datas['right_arm_init_qpos']", + "qpos_ids": [3, 4], + "threshold": 3.455751918948773, + "mode": "delta", + "return_inverse": true + } + } + ] + }, + { + "src_key": "bottle_rotation_qpos", + "dst_key": "pour_water_stop_pose", + "valid_funcs_name_kwargs_proc": [ + { + "name": "no_validation", + "kwargs": {}, + "pass_processes": [ + { + "name": "get_fk_xpos", + "kwargs": { + "control_part": "right_arm", + "fk_func": "env.robot.compute_fk" + } + } + ] + } + ] + } + ] + } + } + }, + { + "bottle_place_pose": { + "name": "generate_affordances_from_src", + "kwargs": { + "affordance_infos": [ + { + "src_key": "bottle_grasp_pose", + "dst_key": "bottle_place_pose", + "valid_funcs_name_kwargs_proc": [ + { + "name":"no_validation", + "kwargs": {}, + "pass_processes": [ + { + "name": "get_replaced_pose", + "kwargs": { + "pose_replace_value": [0.7, -0.1], + "axis_str_replace": "xy" + } + } + ] + } + ] + }, + { + "src_key": "bottle_place_pose", + "dst_key": "bottle_place_qpos", + "valid_funcs_name_kwargs_proc": [ + { + "name": "get_ik_ret", + "kwargs": { + "ik_func": "env.robot.compute_ik", + "qpos_seed": "env.affordance_datas['pour_water_start_qpos']", + "control_part": "right_arm" + }, + "pass_processes": [ + { + "name": "get_ik_qpos", + "kwargs": { + "ik_func": "env.robot.compute_ik", + "qpos_seed": "env.affordance_datas['pour_water_start_qpos']", + "control_part": "right_arm" + } + } + ] + }, + { + "name": "is_qpos_flip", + "kwargs": { + "qpos_ref": "env.affordance_datas['right_arm_init_qpos']", + "qpos_ids": [3, 4], + "threshold": 3.455751918948773, + "mode": "delta", + "return_inverse": true + } + } + ] + } + ] + } + } + }, + { + "bottle_pre_place_qpos": { + "name": "generate_affordances_from_src", + "kwargs": { + "affordance_infos": [ + { + "src_key": "bottle_place_qpos", + "dst_key": "bottle_pre_place_qpos", + "valid_funcs_name_kwargs_proc": [ + { + "name": "no_validation", + "kwargs": {}, + "pass_processes": [ + { + "name": "get_offset_qpos", + "kwargs":{ + "offset_value": -0.05, + "joint_list_offset": [1] + } + } + ] + }, + { + "name": "is_qpos_exceed_new", + "kwargs": { + "robot": "env.robot", + "control_part": "right_arm" + } + }, + { + "name": "is_qpos_flip", + "kwargs": { + "qpos_ref": "env.affordance_datas['right_arm_init_qpos']", + "qpos_ids": [3, 4], + "threshold": 3.455751918948773, + "mode": "delta", + "return_inverse": true + } + } + ] + } + ] + } + } + }, + { + "compute_unoffset_for_exp":{ + "name": "compute_unoffset_for_exp", + "kwargs": { + "pose_input_output_names_changes": { + "bottle_grasp_pose": { + "output_pose_name":"bottle_grasp_pose_object_unoffset", + "pose_changes": [ + ["framechange_extrinsic_inverse", "env.affordance_datas['bottle_pose']"], + ["rotation_z_intrinsic_degrees", 90], + ["offset_-z_intrinsic", 0.025], + ["rotation_z_intrinsic_degrees", -90] + ] + }, + "bottle_place_pose": { + "output_pose_name":"bottle_place_pose_unoffset", + "pose_changes": [ + ["rotation_z_intrinsic_degrees", 90], + ["offset_-z_intrinsic", 0.025], + ["rotation_z_intrinsic_degrees", -90] + ] + }, + "pour_water_start_pose": { + "output_pose_name":"pour_water_start_pose_unoffset", + "pose_changes": [ + ["rotation_z_intrinsic_degrees", 90], + ["offset_-z_intrinsic", 0.025], + ["rotation_z_intrinsic_degrees", -90] + ] + }, + "pour_water_stop_pose": { + "output_pose_name":"pour_water_stop_pose_unoffset", + "pose_changes": [ + ["rotation_z_intrinsic_degrees", 90], + ["offset_-z_intrinsic", 0.025], + ["rotation_z_intrinsic_degrees", -90] + ] + } + } + } + } + } + ], + "left_arm": [ + { + "left_arm_init_qpos": { + "name": "generate_affordances_from_src", + "kwargs": { + "affordance_infos": [ + { + "src_key": "left_arm_init_qpos", + "dst_key": "left_arm_init_pose", + "valid_funcs_name_kwargs_proc": [ + { + "name": "no_validation", + "kwargs": {}, + "pass_processes": [ + { + "name": "get_fk_xpos", + "kwargs": { + "control_part": "left_arm", + "fk_func": "env.robot.compute_fk" + } + } + ] + } + ] + } + ] + + } + } + }, + { + "left_monitor_qpos": { + "name": "generate_affordances_from_src", + "kwargs": { + "affordance_infos": [ + { + "src_key": "left_arm_init_qpos", + "dst_key": "left_arm_monitor_qpos", + "valid_funcs_name_kwargs_proc": [ + { + "name": "no_validation", + "kwargs": {}, + "pass_processes": [ + { + "name": "get_replaced_qpos", + "kwargs": { + "replace_value": [-0.6, 1.0, -1.2, 0.0, 0.58, 0.0], + "joint_list_replace": [0, 1, 2, 3, 4, 5] + } + } + ] + } + ] + } + ] + } + } + } + ], + "left_eef": [ + { + "open": { + "name": "execute_open", + "kwargs": {} + } + }, + { + "close": { + "name": "execute_close", + "kwargs": {} + } + } + ], + "right_eef": [ + { + "ropen": { + "name": "execute_open", + "kwargs": {} + } + }, + { + "rclose": { + "name": "execute_close", + "kwargs": {} + } + } + ] + }, + "edge": { + "right_arm": [ + { + "init_to_aim": { + "src": "right_arm_init_qpos", + "sink": "right_arm_aim_qpos", + "duration": 10, + "name": "plan_trajectory", + "kwargs": { + "agent_uid": "right_arm", + "keypose_names": [ + "right_arm_init_qpos", + "right_arm_aim_qpos" + ] + } + } + }, + { + "aim_to_pre1": { + "src": "right_arm_aim_qpos", + "sink": "bottle_pre1_qpos", + "duration": 24, + "name": "plan_trajectory", + "kwargs": { + "agent_uid": "right_arm", + "keypose_names": [ + "right_arm_aim_qpos", + "bottle_pre1_qpos" + ] + } + } + }, + { + "pre1_to_grasp": { + "src": "bottle_pre1_qpos", + "sink": "bottle_grasp_qpos", + "duration": 24, + "name": "plan_trajectory", + "kwargs": { + "agent_uid": "right_arm", + "keypose_names": [ + "bottle_pre1_qpos", + "bottle_grasp_qpos" + ] + } + } + }, + { + "grasp_to_up": { + "src": "bottle_grasp_qpos", + "sink": "bottle_up_qpos", + "duration": 4, + "name": "plan_trajectory", + "kwargs": { + "agent_uid": "right_arm", + "keypose_names": [ + "bottle_grasp_qpos", + "bottle_up_qpos" + ] + } + } + }, + { + "up_to_move": { + "src": "bottle_up_qpos", + "sink": "pour_water_start_qpos", + "duration": 20, + "name": "plan_trajectory", + "kwargs": { + "agent_uid": "right_arm", + "keypose_names": [ + "bottle_up_qpos", + "pour_water_start_qpos" + ] + } + } + }, + { + "move_to_rotation": { + "src": "pour_water_start_qpos", + "sink": "bottle_rotation_qpos", + "duration": 24, + "name": "plan_trajectory", + "kwargs": { + "agent_uid": "right_arm", + "keypose_names": [ + "pour_water_start_qpos", + "bottle_rotation_qpos" + ] + } + } + }, + { + "rotation_back_to_move": { + "src": "bottle_rotation_qpos", + "sink": "pour_water_start_qpos", + "duration": 24, + "name": "plan_trajectory", + "kwargs": { + "agent_uid": "right_arm", + "keypose_names": [ + "bottle_rotation_qpos", + "pour_water_start_qpos" + ] + } + } + }, + { + "move_back_to_pre_place": { + "src": "pour_water_start_qpos", + "sink": "bottle_pre_place_qpos", + "duration": 20, + "name": "plan_trajectory", + "kwargs": { + "agent_uid": "right_arm", + "keypose_names": [ + "pour_water_start_qpos", + "bottle_pre_place_qpos" + ] + } + } + }, + { + "pre_place_back_to_place": { + "src": "bottle_pre_place_qpos", + "sink": "bottle_place_qpos", + "duration": 4, + "name": "plan_trajectory", + "kwargs": { + "agent_uid": "right_arm", + "keypose_names": [ + "bottle_pre_place_qpos", + "bottle_place_qpos" + ] + } + } + }, + { + "place_back_to_init": { + "src": "bottle_place_qpos", + "sink": "right_arm_init_qpos", + "duration": 24, + "name": "plan_trajectory", + "kwargs": { + "agent_uid": "right_arm", + "keypose_names": [ + "bottle_place_qpos", + "right_arm_init_qpos" + ] + } + } + } + ], + "left_arm": [ + { + "left_init_to_monitor": { + "src": "left_arm_init_qpos", + "sink": "left_monitor_qpos", + "duration": 15, + "name": "plan_trajectory", + "kwargs": { + "agent_uid": "left_arm", + "keypose_names": [ + "left_arm_init_qpos", + "left_arm_monitor_qpos" + ] + } + } + }, + { + "left_arm_go_back": { + "src": "left_monitor_qpos", + "sink": "left_arm_init_qpos", + "duration": 15, + "kwargs": {} + } + } + ], + "left_eef": [], + "right_eef": [ + { + "rclose0": { + "src": "ropen", + "sink": "rclose", + "duration": 11, + "name": "execute_close", + "kwargs": { + "return_action": true, + "expand": true + } + } + }, + { + "ropen0": { + "src": "rclose", + "sink": "ropen", + "duration": 11, + "name": "execute_open", + "kwargs": { + "return_action": true, + "expand": true + } + } + } + ] + }, + "sync": { + "rclose0": { + "depend_tasks": [ + "pre1_to_grasp" + ] + }, + "grasp_to_up": { + "depend_tasks": [ + "rclose0" + ] + }, + "ropen0": { + "depend_tasks": [ + "pre_place_back_to_place" + ] + }, + "place_back_to_init": { + "depend_tasks": [ + "ropen0" + ] + }, + "left_arm_go_back": { + "depend_tasks": [ + "ropen0" + ] + } + }, + "misc": { + "vis_graph": false, + "vis_gantt": false, + "warpping": true + } +} diff --git a/configs/gym/pour_water/gym_config.json b/configs/gym/pour_water/gym_config.json new file mode 100644 index 00000000..4207f438 --- /dev/null +++ b/configs/gym/pour_water/gym_config.json @@ -0,0 +1,459 @@ +{ + "id": "PourWater-v3", + "max_episodes": 5, + "env": { + "events": { + "random_light": { + "func": "randomize_light", + "mode": "interval", + "interval_step": 10, + "params": { + "entity_cfg": {"uid": "light_1"}, + "position_range": [[-0.5, -0.5, 2], [0.5, 0.5, 2]], + "color_range": [[0.6, 0.6, 0.6], [1, 1, 1]], + "intensity_range": [50.0, 100.0] + } + }, + "init_bottle_pose": { + "func": "randomize_rigid_object_pose", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "bottle"}, + "position_range": [[-0.08, -0.12, 0.0], [0.08, 0.04, 0.0]], + "relative_position": true + } + }, + "init_cup_pose": { + "func": "randomize_rigid_object_pose", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "cup"}, + "position_range": [[-0.08, -0.04, 0.0], [0.08, 0.12, 0.0]], + "relative_position": true + } + }, + "prepare_extra_attr": { + "func": "prepare_extra_attr", + "mode": "reset", + "params": { + "attrs": [ + { + "name": "object_lengths", + "mode": "callable", + "entity_uids": "all_objects", + "func_name": "compute_object_length", + "func_kwargs": { + "is_svd_frame": true, + "sample_points": 5000 + } + }, + { + "name": "grasp_pose_object", + "mode": "static", + "entity_cfg": { + "uid": "bottle" + }, + "value": [[ + [0.32243, 0.03245, 0.94604, 0.025], + [0.00706, -0.99947, 0.03188, -0.0 ], + [0.94657, -0.0036 , -0.32249, 0.0 ], + [0.0 , 0.0 , 0.0 , 1.0 ] + ]] + }, + { + "name": "left_arm_base_pose", + "mode": "callable", + "entity_cfg": { + "uid": "CobotMagic" + }, + "func_name": "get_link_pose", + "func_kwargs": { + "link_name": "left_arm_base", + "to_matrix": true + } + }, + { + "name": "right_arm_base_pose", + "mode": "callable", + "entity_cfg": { + "uid": "CobotMagic" + }, + "func_name": "get_link_pose", + "func_kwargs": { + "link_name": "right_arm_base", + "to_matrix": true + } + } + ] + } + }, + "register_info_to_env": { + "func": "register_info_to_env", + "mode": "reset", + "params": { + "registry": [ + { + "entity_cfg": { + "uid": "bottle" + }, + "pose_register_params": { + "compute_relative": false, + "compute_pose_object_to_arena": true, + "to_matrix": true + } + }, + { + "entity_cfg": { + "uid": "cup" + }, + "pose_register_params": { + "compute_relative": false, + "compute_pose_object_to_arena": true, + "to_matrix": true + } + }, + { + "entity_cfg": { + "uid": "CobotMagic", + "control_parts": ["left_arm"] + }, + "attrs": ["left_arm_base_pose"], + "pose_register_params": { + "compute_relative": "cup", + "compute_pose_object_to_arena": false, + "to_matrix": true + }, + "prefix": false + }, + { + "entity_cfg": { + "uid": "CobotMagic", + "control_parts": ["right_arm"] + }, + "attrs": ["right_arm_base_pose"], + "pose_register_params": { + "compute_relative": "bottle", + "compute_pose_object_to_arena": false, + "to_matrix": true + }, + "prefix": false + } + ], + "registration": "affordance_datas", + "sim_update": true + } + }, + "random_material": { + "func": "randomize_visual_material", + "mode": "interval", + "interval_step": 2, + "params": { + "entity_cfg": {"uid": "table"}, + "random_texture_prob": 0.5, + "texture_path": "CocoBackground/coco", + "base_color_range": [[0.2, 0.2, 0.2], [1.0, 1.0, 1.0]] + } + }, + "random_robot_material": { + "func": "randomize_visual_material", + "mode": "interval", + "interval_step": 5, + "params": { + "entity_cfg": {"uid": "CobotMagic", "link_names": [".*"]}, + "random_texture_prob": 0.5, + "texture_path": "CocoBackground/coco", + "base_color_range": [[0.2, 0.2, 0.2], [1.0, 1.0, 1.0]] + } + }, + "random_camera_intrinsics": { + "func": "randomize_camera_intrinsics", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "cam_high"}, + "focal_x_range": [-50, 50], + "focal_y_range": [-50, 50] + } + }, + "random_robot_init_eef_pose": { + "func": "randomize_robot_eef_pose", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "CobotMagic", "control_parts": ["left_arm", "right_arm"]}, + "position_range": [[-0.01, -0.01, -0.01], [0.01, 0.01, 0]] + } + }, + "record_camera": { + "func": "record_camera_data", + "mode": "interval", + "interval_step": 1, + "params": { + "name": "cam1", + "resolution": [320, 240], + "eye": [2, 0, 2], + "target": [0.5, 0, 1] + } + } + }, + "observations": { + "norm_robot_eef_joint": { + "func": "normalize_robot_joint_data", + "mode": "modify", + "name": "robot/qpos", + "params": { + "joint_ids": [12, 13, 14, 15] + } + }, + "bottle_pose": { + "func": "get_rigid_object_pose", + "mode": "add", + "name": "bottle_pose", + "params": { + "entity_cfg": {"uid": "bottle"} + } + }, + "cup_pose": { + "func": "get_rigid_object_pose", + "mode": "add", + "name": "cup_pose", + "params": { + "entity_cfg": {"uid": "cup"} + } + }, + "cam_high_semantic_mask_l": { + "func": "compute_semantic_mask", + "mode": "add", + "name": "sensor/cam_high/semantic_mask_l", + "params": { + "entity_cfg": {"uid": "cam_high"}, + "foreground_uids": ["bottle", "cup"] + } + }, + "cam_high_semantic_mask_r": { + "func": "compute_semantic_mask", + "mode": "add", + "name": "sensor/cam_high/semantic_mask_r", + "params": { + "entity_cfg": {"uid": "cam_high"}, + "foreground_uids": ["bottle", "cup"], + "is_right": true + } + }, + "cam_left_semantic_mask": { + "func": "compute_semantic_mask", + "mode": "add", + "name": "sensor/cam_left_wrist/semantic_mask_l", + "params": { + "entity_cfg": {"uid": "cam_left_wrist"}, + "foreground_uids": ["bottle", "cup"] + } + }, + "cam_right_semantic_mask": { + "func": "compute_semantic_mask", + "mode": "add", + "name": "sensor/cam_right_wrist/semantic_mask_l", + "params": { + "entity_cfg": {"uid": "cam_right_wrist"}, + "foreground_uids": ["bottle", "cup"] + } + }, + "exteroception": { + "func": "compute_exteroception", + "mode": "add", + "name": "exteroception", + "params": { + "descriptor": { + "all_sensors": [ + { + "type": "robot", + "control_part": "left_arm", + "offset": [0.0, 0.0, 0.025], + "follow_eef": true + }, + { + "type": "robot", + "control_part": "right_arm", + "offset": [0.0, 0.0, 0.025], + "follow_eef": true + }, + { + "type": "affordance", + "obj_uid": "bottle", + "key": "bottle_grasp_pose_object_unoffset", + "is_arena_coord": false, + "follow_eef": "right_eef" + }, + { + "type": "affordance", + "obj_uid": "bottle", + "key": "pour_water_start_pose_unoffset", + "is_arena_coord": true, + "follow_eef": "right_eef" + }, + { + "type": "affordance", + "obj_uid": "bottle", + "key": "pour_water_stop_pose_unoffset", + "is_arena_coord": true, + "follow_eef": "right_eef" + }, + { + "type": "affordance", + "obj_uid": "bottle", + "key": "bottle_place_pose_unoffset", + "is_arena_coord": true + } + ] + }, + "x_interval": 0.02, + "y_interval": 0.05, + "kpnts_number": 2, + "groups": 6 + } + } + }, + "dataset": { + "instruction": { + "lang": "Pour water from the bottle into the mug." + }, + "robot_meta": { + "arm_dofs": 12, + "control_freq": 25, + "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"], + "observation": { + "vision": { + "cam_high": ["mask", "exteroception"], + "cam_right_wrist": ["mask", "exteroception"], + "cam_left_wrist": ["mask", "exteroception"] + }, + "states": ["qpos"], + "exteroception": ["cam_high", "cam_right_wrist", "cam_left_wrist"] + }, + "min_len_steps": 5 + } + }, + "success_params": { + "strict": false + } + }, + "robot": { + "uid": "CobotMagic", + "robot_type": "CobotMagic", + "init_pos": [0.0, 0.0, 0.7775], + "init_qpos": [-0.3,0.3,1.0,1.0,-1.2,-1.2,0.0,0.0,0.6,0.6,0.0,0.0,0.05,0.05,0.05,0.05] + }, + "sensor": [ + { + "sensor_type": "StereoCamera", + "uid": "cam_high", + "width": 960, + "height": 540, + "enable_mask": true, + "enable_depth": true, + "left_to_right_pos": [0.059684025824163614, 0, 0], + "intrinsics": [453.851402686215, 453.8347628855552, 469.827725021235, 258.6656181845155], + "intrinsics_right": [453.4536601653505, 453.3306024582175, 499.13697412367776, 297.7176248477935], + "extrinsics": { + "eye": [0.35368482807598, 0.014695524383058989, 1.4517046071614774], + "target": [0.7186357573287919, -0.054534732904795505, 0.5232553674540066], + "up": [0.9306678549330372, -0.0005600064212467153, 0.3658647703553347] + } + }, + { + "sensor_type": "Camera", + "uid": "cam_right_wrist", + "width": 640, + "height": 480, + "enable_mask": true, + "intrinsics": [488.1665344238281, 488.1665344238281, 322.7323303222656, 213.17434692382812], + "extrinsics": { + "parent": "right_link6", + "pos": [-0.08, 0.0, 0.04], + "quat": [0.15304635, 0.69034543, -0.69034543, -0.15304635] + } + }, + { + "sensor_type": "Camera", + "uid": "cam_left_wrist", + "width": 640, + "height": 480, + "enable_mask": true, + "intrinsics": [488.1665344238281, 488.1665344238281, 322.7323303222656, 213.17434692382812], + "extrinsics": { + "parent": "left_link6", + "pos": [-0.08, 0.0, 0.04], + "quat": [0.15304635, 0.69034543, -0.69034543, -0.15304635] + } + } + ], + "light": { + "direct": [ + { + "uid": "light_1", + "light_type": "point", + "color": [1.0, 1.0, 1.0], + "intensity": 50.0, + "init_pos": [2, 0, 2], + "radius": 10.0 + } + ] + }, + "background": [ + { + "uid": "table", + "shape": { + "shape_type": "Mesh", + "fpath": "CircleTableSimple/circle_table_simple.ply" + }, + "attrs" : { + "mass": 10.0, + "static_friction": 0.95, + "dynamic_friction": 0.9, + "restitution": 0.01 + }, + "body_scale": [1, 1, 1], + "body_type": "kinematic", + "init_pos": [0.725, 0.0, 0.825], + "init_rot": [0, 90, 0] + } + ], + "rigid_object": [ + { + "uid":"cup", + "shape": { + "shape_type": "Mesh", + "fpath": "PaperCup/paper_cup.ply" + }, + "attrs" : { + "mass": 0.01, + "contact_offset": 0.003, + "rest_offset": 0.001, + "restitution": 0.01, + "max_depenetration_velocity": 1e1, + "min_position_iters": 32, + "min_velocity_iters":8 + }, + "init_pos": [0.75, 0.1, 0.9], + "body_scale":[0.75, 0.75, 1.0], + "max_convex_hull_num": 8 + }, + { + "uid":"bottle", + "shape": { + "shape_type": "Mesh", + "fpath": "ScannedBottle/kashijia_processed.ply" + }, + "attrs" : { + "mass": 0.01, + "contact_offset": 0.003, + "rest_offset": 0.001, + "restitution": 0.01, + "max_depenetration_velocity": 1e1, + "min_position_iters": 32, + "min_velocity_iters":8 + }, + "init_pos": [0.75, -0.1, 0.932], + "body_scale":[1, 1, 1], + "max_convex_hull_num": 8 + } + ] +} \ No newline at end of file diff --git a/configs/gym/real2sim/reality_config_PourWaterW1Single_v3.json b/configs/gym/real2sim/reality_config_PourWaterW1Single_v3.json new file mode 100644 index 00000000..e7801584 --- /dev/null +++ b/configs/gym/real2sim/reality_config_PourWaterW1Single_v3.json @@ -0,0 +1,582 @@ +{ + "id": "pour_water_single_real2sim", + "max_episodes": 10, + "env": { + "observations": { + "norm_robot_eef_joint": { + "func": "normalize_robot_joint_data", + "mode": "modify", + "name": "robot/qpos", + "params": { + "joint_ids": [7, 15] + } + }, + "bottle_pose": { + "func": "get_rigid_object_pose", + "mode": "add", + "name": "bottle_pose", + "params": { + "entity_cfg": {"uid": "bottle"} + } + }, + "cup_pose": { + "func": "get_rigid_object_pose", + "mode": "add", + "name": "cup_pose", + "params": { + "entity_cfg": {"uid": "cup"} + } + }, + "cam_high_semantic_mask_l": { + "func": "compute_semantic_mask", + "mode": "add", + "name": "sensor/cam_high/semantic_mask_l", + "params": { + "entity_cfg": {"uid": "cam_high"}, + "foreground_uids": ["bottle", "cup"] + } + }, + "cam_high_semantic_mask_r": { + "func": "compute_semantic_mask", + "mode": "add", + "name": "sensor/cam_high/semantic_mask_r", + "params": { + "entity_cfg": {"uid": "cam_high"}, + "foreground_uids": ["bottle", "cup"], + "is_right": true + } + }, + "cam_left_semantic_mask": { + "func": "compute_semantic_mask", + "mode": "add", + "name": "sensor/cam_left_wrist/semantic_mask_l", + "params": { + "entity_cfg": {"uid": "cam_left_wrist"}, + "foreground_uids": ["bottle", "cup"] + } + }, + "cam_right_semantic_mask": { + "func": "compute_semantic_mask", + "mode": "add", + "name": "sensor/cam_right_wrist/semantic_mask_l", + "params": { + "entity_cfg": {"uid": "cam_right_wrist"}, + "foreground_uids": ["bottle", "cup"] + } + }, + "exteroception": { + "func": "compute_exteroception", + "mode": "add", + "name": "exteroception", + "params": { + "descriptor": { + "cam_high": [ + { + "type": "robot", + "control_part": "left_arm" + }, + { + "type": "robot", + "control_part": "right_arm" + } + ] + }, + "x_interval": 0.02, + "y_interval": 0.05, + "kpnts_number": 2, + "groups": 2 + } + } + }, + "dataset": { + "dir_path": "", + "instruction": { + "lang": "Pour water from the bottle into the cup." + }, + "robot_meta": { + "arm_dofs": 14, + "control_freq": 25, + "qpos_to_control": [0, 1, 2, 3, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14], + "observation": { + "vision": { + "cam_high": ["mask"], + "cam_right_wrist": ["mask"], + "cam_left_wrist": ["mask"] + }, + "proprioception": ["qpos"], + "exteroception": ["cam_high", "cam_right_wrist", "cam_left_wrist"] + }, + "action": "qpos_with_eef_pose", + "min_len_steps": 30 + } + }, + "success_params": { + "strict": false + } + }, + "robot": { + "uid": "dexforce_w1", + "init_pos": [0.14, 0.0, 0.04], + "init_qpos": [0.0,0.0,0.0,0.0,0.0,0.0,0.0,-0.57923,0.0,-1.57003,-1.2,0.4,0.0,0.0,0.57923,0.0,1.57003,1.2,-0.4,0.0,2.5,0.0,0.0,0.0,0.0,0.0,2.5,0.0,0.0,0.0,0.0,0.0], + "solver_cfg": { + "left_arm": { + "ik_nearest_weight": [5, 5, 1, 5, 1, 1, 1] + }, + "right_arm": { + "ik_nearest_weight": [5, 5, 1, 5, 1, 1, 1], + "tcp": [[ 1.0 , 0.0 , 0.0 , 0.0 ], + [ 0.0 , 0.0 , -1.0 , -0.04], + [ 0.0 , 1.0 , 0.0 , 0.14], + [ 0.0 , 0.0 , 0.0 , 1.0 ]] + } + } + }, + "sensor": [ + { + "sensor_type": "Camera", + "uid": "cam_left_wrist", + "sensor_name": "cam_left_wrist", + "width": 640, + "height": 480, + "intrinsics": [487.39422607421875, 487.39422607421875, 320.3005676269531, 210.7530517578125], + "enable_mask": true, + "attach_link": "left_ee", + "attach_xpos": [ + [ 0.70711, -0.40558, 0.57923, 0.09 ], + [-0.0 , -0.81915, -0.57358, -0.05 ], + [ 0.70711, 0.40558, -0.57923, 0.04 ], + [ 0.0 , 0.0 , 0.0 , 1.0 ] + ], + "set_near": 0.005, + "set_far": 50 + }, + { + "sensor_type": "Camera", + "uid": "cam_right_wrist", + "sensor_name": "cam_right_wrist", + "width": 640, + "height": 480, + "intrinsics": [488.1665344238281, 488.1665344238281, 322.7323303222656, 213.17434692382812], + "enable_mask": true, + "attach_link": "right_ee", + "attach_xpos": [ + [-0.70711, -0.40558, 0.57923, 0.09 ], + [-0.0 , 0.81915, 0.57358, 0.05 ], + [-0.70711, 0.40558, -0.57923, 0.04 ], + [ 0.0 , 0.0 , 0.0 , 1.0 ] + ], + "set_near": 0.005, + "set_far": 50 + }, + { + "sensor_type": "StereoCamera", + "uid": "cam_high", + "sensor_name": "cam_high", + "width": 960, + "height": 540, + "intrinsics": [453.851402686215, 453.8347628855552, 469.827725021235, 258.6656181845155], + "intrinsics_right": [453.4536601653505, 453.3306024582175, 499.13697412367776, 297.7176248477935], + "enable_mask": true, + "enable_depth": true, + "relativate_T": [ + [ + 0.99996440327, + 0.000856048544, + 0.008394008265, + 0.059684025824163614 + ], + [ + -0.00085599875, + 0.999999633588, + -9.524764e-06, + 1.064844737251626e-05 + ], + [ + -0.008394013343, + 2.339165e-06, + 0.999964769647, + 0.0002219304982263564 + ], + [ + 0, + 0, + 0, + 1 + ] + ], + "set_near": 0.005, + "set_far": 50, + "camera_position": [ + 0.35368482807598, + 0.014695524383058989, + 1.4517046071614774 + ], + "camera_look_at": [ + 0.7186357573287919, + -0.054534732904795505, + 0.5232553674540066 + ], + "up_vector": [ + 0.9306678549330372, + -0.0005600064212467153, + 0.3658647703553347 + ] + } + ], + "light": { + "direct": [ + { + "uid": "light1", + "light_type": "point", + "color": [1.0, 1.0, 1.0], + "intensity": 100, + "init_pos": [0, 0, 2], + "radius": 10.0 + } + ] + }, + "background": [ + { + "uid": "table", + "shape": { "shape_type": "Mesh", "fpath": "static_scenes/static_scene_3/circle_table_simple.ply", "compute_uv": false }, + "body_scale": [1.0, 0.8, 0.8], + "init_local_pose": [ + [2.220446049250313e-16, 0.0, 1.0, 0.7], + [0.0, 1.0, 0.0, 0.0], + [-1.0, 0.0, 2.220446049250313e-16, 0.92], + [0.0, 0.0, 0.0, 1.0] + ], + "body_type": "kinematic", + "max_convex_hull_num": 1 + } + ], + "rigid_object": [ + { + "uid": "bottle", + "shape": { "shape_type": "Mesh", "fpath": "objects/object_6/xingbake_processed.ply" }, + "attrs": { + "mass": 0.005, + "contact_offset": 0.003, + "rest_offset": 0.001, + "restitution": 0.1, + "max_depenetration_velocity": 1e1 + }, + "body_scale": [1, 1, 1], + "init_local_pose": [ + [1.0, 0.0, 0.0, 0.59013], + [0.0, 1.0, 0.0, -0.02475], + [0.0, 0.0, 1.0, 1.06664], + [0.0, 0.0, 0.0, 1.0] + ], + "max_convex_hull_num": 8 + }, + { + "uid": "cup", + "shape": { "shape_type": "Mesh", "fpath": "objects/object_7/paper_cup.ply" }, + "attrs": { + "mass": 0.005, + "contact_offset": 0.003, + "rest_offset": 0.001, + "restitution": 0.1, + "max_depenetration_velocity": 1e1 + }, + "body_scale": [0.75, 0.75, 1.0], + "init_local_pose": [ + [0.9986, -0.00476, 0.05268, 0.66128], + [0.01116, 0.99248, -0.12189, 0.1], + [-0.0517, 0.12231, 0.99114, 1.0], + [0.0, 0.0, 0.0, 1.0] + ], + "max_convex_hull_num": 8 + } + ], + "task": { + "name": "pour_water_single", + "data": { + "0": { + "trajectory": { + "path": "demo1", + "sample_ratio": 0.1, + "scope": { + "right_arm": [ + 8, + 9, + 10, + 11, + 12, + 13, + 14 + ], + "left_arm": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6 + ], + "left_eef": [ + 7 + ], + "right_eef": [ + 15 + ] + } + }, + "node": { + "right_arm": [ + { + "0": { + "affordance_name": "grasp_qpos", + "master": "", + "slaver": "bottle", + "timestep": 122, + "mimicable": true, + "trajectory": { + "name": "load_trajectory", + "kwargs": {} + } + }, + "1": { + "affordance_name": "after_pour_qpos", + "master": "bottle", + "slaver": "", + "timestep": 436, + "trajectory": { + "name": "load_trajectory", + "kwargs": {} + } + }, + "3": { + "affordance_name": "place_qpos", + "master": "bottle", + "slaver": "", + "timestep": 502, + "trajectory": { + "name": "load_trajectory", + "kwargs": {} + } + }, + "4": { + "affordance_name": "pour_qpos", + "master": "bottle", + "slaver": "cup", + "timestep": 327, + "trajectory": { + "name": "load_trajectory", + "kwargs": {} + } + }, + "5": { + "affordance_name": "before_pour_qpos", + "master": "bottle", + "slaver": "", + "timestep": 207, + "trajectory": { + "name": "load_trajectory", + "kwargs": {} + } + } + } + ], + "left_arm": [], + "right_eef": [ + { + "0": { + "affordance_name": "right_close", + "master": "", + "slaver": "bottle", + "timestep": 123, + "mimicable": false, + "duration": 20, + "trajectory": { + "name": "execute_close", + "kwargs": {} + } + }, + "1": { + "affordance_name": "right_open", + "master": "", + "slaver": "bottle", + "timestep": 502, + "mimicable": false, + "duration": 20, + "trajectory": { + "name": "execute_open", + "kwargs": {} + } + } + + } + ], + "left_eef": [] + }, + "sync": { + "right_eefhand_init_qpos": { + "depend_tasks": [ + "grasp_qpos" + ] + }, + "grasp_qpos": { + "depend_tasks": [ + "right_close" + ] + }, + "right_close": { + "depend_tasks": [ + "place_qpos" + ] + }, + "place_qpos": { + "depend_tasks": [ + "right_open" + ] + } + } + }, + "1": { + "trajectory": { + "path": "demo2", + "sample_ratio": 0.1, + "scope": { + "right_arm": [ + 8, + 9, + 10, + 11, + 12, + 13, + 14 + ], + "left_arm": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6 + ], + "left_eef": [ + 7 + ], + "right_eef": [ + 15 + ] + } + }, + "node": { + "right_arm": [ + { + "0": { + "affordance_name": "grasp_qpos", + "master": "", + "slaver": "bottle", + "timestep": 163, + "mimicable": true, + "trajectory": { + "name": "load_trajectory", + "kwargs": {} + } + }, + "1": { + "affordance_name": "pour_qpos", + "master": "bottle", + "slaver": "cup", + "timestep": 400, + "trajectory": { + "name": "load_trajectory", + "kwargs": {} + } + }, + "3": { + "affordance_name": "place_qpos", + "master": "bottle", + "slaver": "", + "timestep": 692, + "trajectory": { + "name": "load_trajectory", + "kwargs": {} + } + }, + "4": { + "affordance_name": "after_pour_qpos", + "master": "bottle", + "slaver": "", + "timestep": 532, + "trajectory": { + "name": "load_trajectory", + "kwargs": {} + } + }, + "5": { + "affordance_name": "before_pour_qpos", + "master": "bottle", + "slaver": "", + "timestep": 232, + "trajectory": { + "name": "load_trajectory", + "kwargs": {} + } + } + } + ], + "left_arm": [], + "right_eef": [ + { + "0": { + "affordance_name": "right_close", + "master": "", + "slaver": "bottle", + "timestep": 210, + "mimicable": false, + "duration": 20, + "trajectory": { + "name": "execute_close", + "kwargs": {} + } + }, + "1": { + "affordance_name": "right_open", + "master": "", + "slaver": "bottle", + "timestep": 692, + "mimicable": false, + "duration": 20, + "trajectory": { + "name": "execute_open", + "kwargs": {} + } + } + + } + ], + "left_eef": [] + }, + "sync": { + "right_eefhand_init_qpos": { + "depend_tasks": [ + "grasp_qpos" + ] + }, + "grasp_qpos": { + "depend_tasks": [ + "right_close" + ] + }, + "right_close": { + "depend_tasks": [ + "place_qpos" + ] + }, + "place_qpos": { + "depend_tasks": [ + "right_open" + ] + } + } + } + } + } +} diff --git a/configs/gym/scoop_ice/gym_config.json b/configs/gym/scoop_ice/gym_config.json new file mode 100644 index 00000000..eed397ac --- /dev/null +++ b/configs/gym/scoop_ice/gym_config.json @@ -0,0 +1,254 @@ +{ + "id": "ScoopIce-v1", + "max_episodes": 5, + "env": { + "events": { + "drop_ice":{ + "func": "drop_rigid_object_group_sequentially", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "ice_cubes"}, + "drop_position": [0.5, -0.05, 1.0], + "position_range": [[-0.12, -0.12, 0], [0.12, 0.12, 0]], + "physics_step": 10 + } + }, + "init_scoop_pose": { + "func": "randomize_rigid_object_pose", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "scoop"}, + "position_range": [[0.45, -0.27, 1.05], [0.45, -0.27, 1.05]], + "rotation_range": [[30, -25, 180], [30, -25, 180]], + "relative_position": false, + "relative_rotation": false + } + }, + "init_cup_pose": { + "func": "randomize_rigid_object_pose", + "mode": "reset", + "params": { + "entity_cfg": {"uid": "paper_cup"}, + "position_range": [[0.455, 0.112, 1.05], [0.455, 0.112, 1.05]], + "rotation_range": [[0, 0, 0], [0, 0, 0]], + "relative_position": false, + "relative_rotation": false + } + }, + "record_camera": { + "func": "record_camera_data", + "mode": "interval", + "interval_step": 1, + "params": { + "name": "cam1", + "resolution": [320, 240], + "eye": [0, -1, 2], + "target": [0.5, 0, 1] + } + } + }, + "dataset": { + "instruction": { + "lang": "Scoop ice." + }, + "robot_meta": { + "arm_dofs": 14, + "control_freq": 25, + "qpos_to_control": [6, 8, 10, 12, 14, 16, 18, 24, 7, 9, 11, 13, 15, 17, 19, 29], + "observation": { + "vision": { + "cam_high": [], + "cam_right_wrist": [], + "cam_left_wrist": [] + }, + "states": ["qpos"] + }, + "min_len_steps": 10 + } + } + }, + "robot": { + "robot_type": "DexforceW1", + "init_pos": [0.0, 0.0, 0], + "init_qpos":[ + 0.42241, -1.11061, 0.55116, 0.01815, 0.00002, -0.43273, + -0.30339, 0.2412 , -1.20074, 0.72621, 0.40264, -0.41044, + -1.34341, 1.27664, -0.42869, -0.30873, -0.23608, 0.54272, + -0.12095, -0.16959, 0.00011, 0.00002, 0.00009, 0.00003, + 1.50006, 0.30045, 0.30028, 0.30051, 0.30037, 0.90003, + 0.23262, 0.24298, 0.22003, 0.19901, -0 , 0.59305, + 0.6033 , 0.58056, 0.55942, 0.29999 + ], + "solver_cfg": { + "left_arm": { + "class_type": "PytorchSolver", + "end_link_name": "left_ee", + "root_link_name": "left_arm_base", + "tcp": + [ + [-1.0, 0.0, 0.0, 0.012], + [0.0, 0.0, 1.0, 0.0675], + [0.0, 1.0, 0.0, 0.127], + [0.0, 0.0, 0.0, 1.0] + ] + }, + "right_arm": { + "class_type": "PytorchSolver", + "end_link_name": "right_ee", + "root_link_name": "right_arm_base", + "tcp": + [ + [1.0, 0.0, 0.0, 0.012], + [0.0, 0.0, -1.0, -0.0675], + [0.0, 1.0, 0.0, 0.127], + [0.0, 0.0, 0.0, 1.0] + ] + } + } + }, + "sensor": [ + { + "sensor_type": "StereoCamera", + "uid": "cam_high", + "width": 960, + "height": 540, + "enable_mask": true, + "enable_depth": true, + "left_to_right_pos": [0.059684025824163614, 0, 0], + "intrinsics": [453.851402686215, 453.8347628855552, 469.827725021235, 258.6656181845155], + "intrinsics_right": [453.4536601653505, 453.3306024582175, 499.13697412367776, 297.7176248477935], + "extrinsics": { + "parent": "eyes" + } + }, + { + "sensor_type": "Camera", + "uid": "cam_right_wrist", + "width": 640, + "height": 360, + "enable_mask": true, + "intrinsics": [337.0, 325.0, 320.0, 180.0], + "extrinsics": { + "parent": "right_ee", + "pos": [0.09, 0.05, 0.04], + "quat": [0.36497168, -0.11507513, 0.88111957, 0.27781593] + } + }, + { + "sensor_type": "Camera", + "uid": "cam_left_wrist", + "width": 640, + "height": 360, + "enable_mask": true, + "intrinsics": [337.0, 325.0, 320.0, 180.0], + "extrinsics": { + "parent": "left_ee", + "pos": [0.09, -0.05, 0.04], + "quat": [0.27781593, 0.88111957, -0.11507513, 0.36497168] + } + } + ], + "light": { + }, + "background": [ + { + "uid": "table", + "shape": { + "shape_type": "Mesh", + "fpath": "CircleTableSimple/circle_table_simple.ply" + }, + "attrs" : { + "mass": 1.0, + "static_friction": 0.95, + "dynamic_friction": 0.9, + "restitution": 0.05 + }, + "body_type": "kinematic", + "init_pos": [0.80, 0, 0.54], + "init_rot": [0, 90, 0] + } + ], + "rigid_object": [ + { + "uid": "scoop", + "shape": { + "shape_type": "Mesh", + "fpath": "ScoopIceNewEnv/scoop.ply" + }, + "attrs" : { + "mass": 0.5, + "static_friction": 0.95, + "dynamic_friction": 0.9, + "restitution": 0.0, + "min_position_iters": 32, + "min_velocity_iters": 8 + }, + "max_convex_hull_num": 8, + "init_pos": [0, 10, 10] + }, + { + "uid": "paper_cup", + "shape": { + "shape_type": "Mesh", + "fpath": "PaperCup/paper_cup.ply" + }, + "attrs" : { + "mass": 0.5, + "static_friction": 0.95, + "dynamic_friction": 0.9, + "restitution": 0.0, + "min_position_iters": 32, + "min_velocity_iters": 8 + }, + "max_convex_hull_num": 16, + "init_pos": [0, 10, 10] + } + ], + "rigid_object_group": [ + { + "uid": "ice_cubes", + "max_num": 300, + "folder_path": "ScoopIceNewEnv/ice_mesh_small", + "ext": ".obj", + "rigid_objects": { + "obj": { + "attrs" : { + "mass": 0.004, + "contact_offset": 0.001, + "rest_offset": 0, + "dynamic_friction": 0.05, + "static_friction": 0.1, + "restitution": 0.00, + "min_position_iters": 32, + "min_velocity_iters": 8, + "max_depenetration_velocity": 1.0 + }, + "shape": { + "shape_type": "Mesh" + }, + "init_pos": [0, 0, 2], + "body_scale": [1.0, 1.0, 1.0] + } + } + } + ], + "articulation": [ + { + "uid": "container", + "fpath": "ScoopIceNewEnv/IceContainer/ice_container.urdf", + "init_pos": [0.635, -0.04, 0.94], + "init_rot": [0, 0, -80], + "attrs": { + "mass": 1.0, + "dynamic_friction": 0.05, + "static_friction": 0.1, + "max_depenetration_velocity": 1.0 + }, + "drive_pros": { + "stiffness": 1.0, + "damping": 0.1, + "max_effort": 100.0 + } + } + ] +} \ No newline at end of file diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 00000000..864eb2a7 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,21 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @rm -rf "$(BUILDDIR)" + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 00000000..747ffb7b --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 00000000..cacbd083 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,9 @@ +sphinx>=4.0 +sphinx-book-theme>=0.3.0 +sphinx-tabs +sphinx-copybutton +myst-parser +sphinx-autosummary-accessors +sphinxcontrib-bibtex +sphinx-design +sphinx_autodoc_typehints diff --git a/docs/source/_templates/module.rst b/docs/source/_templates/module.rst new file mode 100644 index 00000000..b3160842 --- /dev/null +++ b/docs/source/_templates/module.rst @@ -0,0 +1,58 @@ +{{ fullname | escape | underline}} + +.. automodule:: {{ fullname }} + :members: + :undoc-members: + :show-inheritance: + + {% block modules %} + {% if modules %} + .. rubric:: Modules + + .. autosummary:: + :toctree: + :template: module.rst + {% for item in modules %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block classes %} + {% if classes %} + .. rubric:: Classes + + .. autosummary:: + :toctree: + :template: class.rst + {% for item in classes %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block functions %} + {% if functions %} + .. rubric:: Functions + + .. autosummary:: + :toctree: + :template: function.rst + {% for item in functions %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block exceptions %} + {% if exceptions %} + .. rubric:: Exceptions + + .. autosummary:: + :toctree: + :template: class.rst + {% for item in exceptions %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} diff --git a/docs/source/api_reference/embodichain/_autosummary/embodichain.utils.warp.kernels.reshape_tiled_image.rst b/docs/source/api_reference/embodichain/_autosummary/embodichain.utils.warp.kernels.reshape_tiled_image.rst new file mode 100644 index 00000000..510de4bd --- /dev/null +++ b/docs/source/api_reference/embodichain/_autosummary/embodichain.utils.warp.kernels.reshape_tiled_image.rst @@ -0,0 +1,6 @@ +embodichain.utils.warp.kernels.reshape\_tiled\_image +==================================================== + +.. currentmodule:: embodichain.utils.warp.kernels + +.. autodata:: reshape_tiled_image \ No newline at end of file diff --git a/docs/source/api_reference/embodichain/_autosummary/embodichain.utils.warp.kernels.rst b/docs/source/api_reference/embodichain/_autosummary/embodichain.utils.warp.kernels.rst new file mode 100644 index 00000000..fb1ba6ab --- /dev/null +++ b/docs/source/api_reference/embodichain/_autosummary/embodichain.utils.warp.kernels.rst @@ -0,0 +1,6 @@ +embodichain.utils.warp.kernels +============================== + +.. automodule:: embodichain.utils.warp.kernels + + \ No newline at end of file diff --git a/docs/source/api_reference/embodichain/_autosummary/embodichain.utils.warp.kinematics.opw_solver.rst b/docs/source/api_reference/embodichain/_autosummary/embodichain.utils.warp.kinematics.opw_solver.rst new file mode 100644 index 00000000..7fec8013 --- /dev/null +++ b/docs/source/api_reference/embodichain/_autosummary/embodichain.utils.warp.kinematics.opw_solver.rst @@ -0,0 +1,6 @@ +embodichain.utils.warp.kinematics.opw\_solver +============================================= + +.. automodule:: embodichain.utils.warp.kinematics.opw_solver + + \ No newline at end of file diff --git a/docs/source/api_reference/embodichain/_autosummary/embodichain.utils.warp.kinematics.rst b/docs/source/api_reference/embodichain/_autosummary/embodichain.utils.warp.kinematics.rst new file mode 100644 index 00000000..0e0f9b9c --- /dev/null +++ b/docs/source/api_reference/embodichain/_autosummary/embodichain.utils.warp.kinematics.rst @@ -0,0 +1,6 @@ +embodichain.utils.warp.kinematics +================================= + +.. automodule:: embodichain.utils.warp.kinematics + + \ No newline at end of file diff --git a/docs/source/api_reference/embodichain/_autosummary/embodichain.utils.warp.kinematics.warp_trajectory.rst b/docs/source/api_reference/embodichain/_autosummary/embodichain.utils.warp.kinematics.warp_trajectory.rst new file mode 100644 index 00000000..aa9760cd --- /dev/null +++ b/docs/source/api_reference/embodichain/_autosummary/embodichain.utils.warp.kinematics.warp_trajectory.rst @@ -0,0 +1,6 @@ +embodichain.utils.warp.kinematics.warp\_trajectory +================================================== + +.. automodule:: embodichain.utils.warp.kinematics.warp_trajectory + + \ No newline at end of file diff --git a/docs/source/api_reference/embodichain/embodichain.agents.rst b/docs/source/api_reference/embodichain/embodichain.agents.rst new file mode 100644 index 00000000..1db72b60 --- /dev/null +++ b/docs/source/api_reference/embodichain/embodichain.agents.rst @@ -0,0 +1,12 @@ +embodichain.agents +================== + +.. automodule:: embodichain.agents + + .. rubric:: Submodules + + .. autosummary:: + + dexforce_vla + rl + diff --git a/docs/source/api_reference/embodichain/embodichain.lab.gym.envs.managers.rst b/docs/source/api_reference/embodichain/embodichain.lab.gym.envs.managers.rst new file mode 100644 index 00000000..afc9df9a --- /dev/null +++ b/docs/source/api_reference/embodichain/embodichain.lab.gym.envs.managers.rst @@ -0,0 +1,130 @@ +embodichain.lab.gym.envs.managers +========================================== + +.. automodule:: embodichain.lab.gym.envs.managers + + .. rubric:: Submodules + + .. autosummary:: + + randomization + + .. rubric:: Classes + + .. autosummary:: + + FunctorCfg + SceneEntityCfg + EventCfg + ObservationCfg + Functor + ManagerBase + EventManager + ObservationManager + + .. rubric:: Functions + + .. autosummary:: + + observations.get_rigid_object_pose + observations.normalize_robot_joint_data + observations.compute_semantic_mask + observations.compute_exteroception + events.replace_assets_from_group + record.record_camera_data + randomization.rendering.randomize_light + randomization.rendering.randomize_camera_intrinsics + randomization.rendering.randomize_visual_material + randomization.spatial.get_random_pose + randomization.spatial.randomize_rigid_object_pose + randomization.spatial.randomize_robot_eef_pose + randomization.spatial.randomize_robot_qpos + +.. currentmodule:: embodichain.lab.gym.envs.managers + +Configuration Classes +--------------------- + +.. autoclass:: FunctorCfg + :members: + :exclude-members: __init__, class_type + +.. autoclass:: SceneEntityCfg + :members: + :exclude-members: __init__, class_type + +.. autoclass:: EventCfg + :members: + :exclude-members: __init__, class_type + +.. autoclass:: ObservationCfg + :members: + :exclude-members: __init__, class_type + +Base Classes +------------ + +.. autoclass:: Functor + :members: + :inherited-members: + :show-inheritance: + +.. autoclass:: ManagerBase + :members: + :inherited-members: + :show-inheritance: + +Managers +-------- + +.. autoclass:: EventManager + :members: + :inherited-members: + :show-inheritance: + +.. autoclass:: ObservationManager + :members: + :inherited-members: + :show-inheritance: + +Observation Functions +-------------------- + +.. automodule:: embodichain.lab.gym.envs.managers.observations + :members: + +Event Functions +-------------- + +.. automodule:: embodichain.lab.gym.envs.managers.events + :members: + +Recording Functions +------------------ + +.. automodule:: embodichain.lab.gym.envs.managers.record + :members: + +Randomization +------------- + +.. automodule:: embodichain.lab.gym.envs.managers.randomization + + .. rubric:: Submodules + + .. autosummary:: + + rendering + spatial + + Rendering + ~~~~~~~~~~~~~~~~~~~~~~~ + + .. automodule:: embodichain.lab.gym.envs.managers.randomization.rendering + :members: + + Spatial + ~~~~~~~~~~~~~~~~~~~~~ + + .. automodule:: embodichain.lab.gym.envs.managers.randomization.spatial + :members: diff --git a/docs/source/api_reference/embodichain/embodichain.lab.gym.envs.rst b/docs/source/api_reference/embodichain/embodichain.lab.gym.envs.rst new file mode 100644 index 00000000..24c568f9 --- /dev/null +++ b/docs/source/api_reference/embodichain/embodichain.lab.gym.envs.rst @@ -0,0 +1,34 @@ +embodichain.lab.gym.envs +==================================== + +.. automodule:: embodichain.lab.gym.envs + + .. rubric:: Submodules + + .. autosummary:: + managers + +.. currentmodule:: embodichain.lab.gym.envs + +Environment Classes +------------------- + +.. currentmodule:: embodichain.lab.gym.envs + +.. autoclass:: BaseEnv + :members: + :inherited-members: + :show-inheritance: + +.. autoclass:: EnvCfg + :members: + :exclude-members: __init__, class_type + +.. autoclass:: EmbodiedEnv + :members: + :inherited-members: + :show-inheritance: + +.. autoclass:: EmbodiedEnvCfg + :members: + :exclude-members: __init__, class_type diff --git a/docs/source/api_reference/embodichain/embodichain.lab.gym.rst b/docs/source/api_reference/embodichain/embodichain.lab.gym.rst new file mode 100644 index 00000000..e788b8f7 --- /dev/null +++ b/docs/source/api_reference/embodichain/embodichain.lab.gym.rst @@ -0,0 +1,153 @@ +embodichain.lab.gym +=============================== + +.. automodule:: embodichain.lab.gym + + .. rubric:: Submodules + + .. autosummary:: + :toctree: . + + envs + utils + +.. currentmodule:: embodichain.lab.gym + +Overview +-------- + +The ``gym`` module provides a comprehensive framework for creating robot learning environments. It extends the Gymnasium interface to support multi-environment parallel execution, +custom observations, and robotic-specific functionality. + +Key Features: + +* **Multi-Environment Support**: Run multiple environment instances in parallel for efficient training +* **Gymnasium Integration**: Full compatibility with the Gymnasium API and ecosystem +* **Robotic Focus**: Built-in support for robot control, sensors, and manipulation tasks +* **Extensible Architecture**: Easy to create custom environments and tasks +* **GPU Acceleration**: Leverage GPU computing for high-performance simulation + +Environments Module (envs) +--------------------------- + +.. currentmodule:: embodichain.lab.gym.envs + +Base Environment Classes +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: BaseEnv + :members: + :inherited-members: + :show-inheritance: + + The foundational environment class that provides the core functionality for all EmbodiChain RL environments. + This class extends the Gymnasium ``Env`` interface with multi-environment support and robotic-specific features. + +.. autoclass:: EnvCfg + :members: + :exclude-members: __init__, class_type + + Configuration class for basic environment settings including simulation parameters and environment count. + +Embodied Environment Classes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: EmbodiedEnv + :members: + :inherited-members: + :show-inheritance: + + An advanced environment class that provides additional features for embodied AI research, including + sophisticated observation management, event handling, and multi-modal sensor integration. + +.. autoclass:: EmbodiedEnvCfg + :members: + :exclude-members: __init__, class_type + + Configuration class for embodied environments with extended settings for lighting, observation management, + and advanced simulation features. + +Utilities Module (utils) +------------------------- + +.. currentmodule:: embodichain.lab.gym.utils + +Registration System +~~~~~~~~~~~~~~~~~~~ + +.. automodule:: embodichain.lab.gym.utils.registration + +.. autoclass:: EnvSpec + :members: + :show-inheritance: + + Specification class for environment registration, containing environment metadata and creation parameters. + +.. autofunction:: register + + Register a new environment class with the EmbodiChain environment registry. + + :param name: Unique identifier for the environment + :param cls: Environment class (must inherit from BaseEnv or BaseEnv) + :param max_episode_steps: Maximum steps per episode (optional) + :param default_kwargs: Default keyword arguments for environment creation + +.. autofunction:: register_env + + Decorator function for registering environment classes. This is the recommended way to register environments. + + :param uid: Unique identifier for the environment + :param max_episode_steps: Maximum steps per episode (optional) + :param override: Whether to override existing environment with same ID + :param kwargs: Additional registration parameters + + Example: + .. code-block:: python + + @register_env("MyEnv-v1", max_episode_steps=1000) + class MyCustomEnv(BaseEnv): + def __init__(self, **kwargs): + super().__init__(**kwargs) + +.. autofunction:: make + + Create an environment instance from a registered environment ID. + + :param env_id: Registered environment identifier + :param kwargs: Additional keyword arguments for environment creation + :returns: Environment instance + +.. autoclass:: TimeLimitWrapper + :members: + :show-inheritance: + + Gymnasium wrapper that adds episode time limits to environments. + +Action Conversion Utilities +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automodule:: embodichain.lab.gym.utils.action_conversion + + Utilities for converting between different action representations in robotic environments. + +Gymnasium Utilities +~~~~~~~~~~~~~~~~~~ + +.. automodule:: embodichain.lab.gym.utils.gym_utils + + Helper functions and utilities for Gymnasium environment integration. + +Image Utilities +~~~~~~~~~~~~~~~ + +.. automodule:: embodichain.lab.gym.utils.img_utils + + Image processing utilities for visual observations and rendering. + +Miscellaneous Utilities +~~~~~~~~~~~~~~~~~~~~~~~ + +.. automodule:: embodichain.lab.gym.utils.misc + + Miscellaneous utility functions for environment development and debugging. + diff --git a/docs/source/api_reference/embodichain/embodichain.lab.gym.utils.rst b/docs/source/api_reference/embodichain/embodichain.lab.gym.utils.rst new file mode 100644 index 00000000..a0c9f265 --- /dev/null +++ b/docs/source/api_reference/embodichain/embodichain.lab.gym.utils.rst @@ -0,0 +1,48 @@ +embodichain.lab.gym.utils +===================================== + +.. automodule:: embodichain.lab.gym.utils + +Registration System +------------------- + +.. currentmodule:: embodichain.lab.gym.utils.registration + +.. autoclass:: EnvSpec + :members: + :show-inheritance: + +.. autofunction:: register + +.. autofunction:: register_env + +.. autofunction:: make + +.. autoclass:: TimeLimitWrapper + :members: + :show-inheritance: + +Utility Modules +--------------- + +Action Conversion +~~~~~~~~~~~~~~~~~ + +.. automodule:: embodichain.lab.gym.utils.action_conversion + +Gymnasium Utilities +~~~~~~~~~~~~~~~~~~~ + +.. automodule:: embodichain.lab.gym.utils.gym_utils + +Image Utilities +~~~~~~~~~~~~~~~ + +.. automodule:: embodichain.lab.gym.utils.img_utils + +Miscellaneous +~~~~~~~~~~~~~ + +.. automodule:: embodichain.lab.gym.utils.misc + + \ No newline at end of file diff --git a/docs/source/api_reference/embodichain/embodichain.lab.rst b/docs/source/api_reference/embodichain/embodichain.lab.rst new file mode 100644 index 00000000..e8a7d565 --- /dev/null +++ b/docs/source/api_reference/embodichain/embodichain.lab.rst @@ -0,0 +1,29 @@ +embodichain.lab +===================== + +.. automodule:: embodichain.lab + + .. rubric:: Submodules + + .. autosummary:: + + devices + gym + sim + utility + +Device Management +----------------- + +.. automodule:: embodichain.lab.devices + :members: + :undoc-members: + :show-inheritance: + +Utilities +--------- + +.. automodule:: embodichain.lab.utility + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api_reference/embodichain/embodichain.lab.sim.cfg.rst b/docs/source/api_reference/embodichain/embodichain.lab.sim.cfg.rst new file mode 100644 index 00000000..dacdae33 --- /dev/null +++ b/docs/source/api_reference/embodichain/embodichain.lab.sim.cfg.rst @@ -0,0 +1,21 @@ +embodichain.lab.sim.cfg +=================================== + +.. automodule:: embodichain.lab.sim.cfg + + + .. rubric:: Classes + + .. autosummary:: + + ArticulationCfg + GPUMemoryCfg + JointDrivePropertiesCfg + LightCfg + ObjectBaseCfg + PhysicsCfg + RigidBodyAttributesCfg + RigidObjectCfg + RobotCfg + URDFCfg + \ No newline at end of file diff --git a/docs/source/api_reference/embodichain/embodichain.lab.sim.common.rst b/docs/source/api_reference/embodichain/embodichain.lab.sim.common.rst new file mode 100644 index 00000000..2d8ffc14 --- /dev/null +++ b/docs/source/api_reference/embodichain/embodichain.lab.sim.common.rst @@ -0,0 +1,12 @@ +embodichain.lab.sim.common +====================================== + +.. automodule:: embodichain.lab.sim.common + + + .. rubric:: Classes + + .. autosummary:: + + BatchEntity + \ No newline at end of file diff --git a/docs/source/api_reference/embodichain/embodichain.lab.sim.material.rst b/docs/source/api_reference/embodichain/embodichain.lab.sim.material.rst new file mode 100644 index 00000000..61e7bbdf --- /dev/null +++ b/docs/source/api_reference/embodichain/embodichain.lab.sim.material.rst @@ -0,0 +1,14 @@ +embodichain.lab.sim.material +======================================== + +.. automodule:: embodichain.lab.sim.material + + + .. rubric:: Classes + + .. autosummary:: + + VisualMaterial + VisualMaterialCfg + VisualMaterialInst + \ No newline at end of file diff --git a/docs/source/api_reference/embodichain/embodichain.lab.sim.objects.rst b/docs/source/api_reference/embodichain/embodichain.lab.sim.objects.rst new file mode 100644 index 00000000..34928803 --- /dev/null +++ b/docs/source/api_reference/embodichain/embodichain.lab.sim.objects.rst @@ -0,0 +1,111 @@ +embodichain.lab.sim.objects +========================================== + + +.. automodule:: embodichain.lab.sim.objects + + .. rubric:: Classes + + .. autosummary:: + + Light + LightCfg + RigidObject + RigidBodyData + RigidObjectCfg + RigidObjectGroup + RigidBodyGroupData + RigidObjectGroupCfg + Articulation + ArticulationData + ArticulationCfg + Robot + RobotCfg + +.. currentmodule:: embodichain.lab.sim.objects + +Light +----- + +.. autoclass:: Light + :members: + :inherited-members: + :show-inheritance: + +.. autoclass:: LightCfg + :members: + :inherited-members: + :show-inheritance: + :exclude-members: __init__, copy, replace, to_dict, validate + +Rigid Object +------------ + +.. autoclass:: RigidObject + :members: + :inherited-members: + :show-inheritance: + +.. autoclass:: RigidBodyData + :members: + :inherited-members: + :show-inheritance: + +.. autoclass:: RigidObjectCfg + :members: + :inherited-members: + :show-inheritance: + :exclude-members: __init__, copy, replace, to_dict, validate + +Rigid Object Group +------------------- + +.. autoclass:: RigidObjectGroup + :members: + :inherited-members: + :show-inheritance: + +.. autoclass:: RigidBodyGroupData + :members: + :inherited-members: + :show-inheritance: + +.. autoclass:: RigidObjectGroupCfg + :members: + :inherited-members: + :show-inheritance: + :exclude-members: __init__, copy, replace, to_dict, validate + +Articulation +------------ + +.. autoclass:: Articulation + :members: + :inherited-members: + :show-inheritance: + +.. autoclass:: ArticulationData + :members: + :inherited-members: + :show-inheritance: + +.. autoclass:: ArticulationCfg + :members: + :inherited-members: + :show-inheritance: + :exclude-members: __init__, copy, replace, to_dict, validate + +Robot +----- + +.. autoclass:: Robot + :members: + :inherited-members: + :show-inheritance: + +.. autoclass:: RobotCfg + :members: + :inherited-members: + :show-inheritance: + :exclude-members: __init__, copy, replace, to_dict, validate + \ No newline at end of file diff --git a/docs/source/api_reference/embodichain/embodichain.lab.sim.robots.rst b/docs/source/api_reference/embodichain/embodichain.lab.sim.robots.rst new file mode 100644 index 00000000..d6428af3 --- /dev/null +++ b/docs/source/api_reference/embodichain/embodichain.lab.sim.robots.rst @@ -0,0 +1,6 @@ +embodichain.lab.sim.robots +====================================== + +.. automodule:: embodichain.lab.sim.robots + + \ No newline at end of file diff --git a/docs/source/api_reference/embodichain/embodichain.lab.sim.rst b/docs/source/api_reference/embodichain/embodichain.lab.sim.rst new file mode 100644 index 00000000..87b82a41 --- /dev/null +++ b/docs/source/api_reference/embodichain/embodichain.lab.sim.rst @@ -0,0 +1,101 @@ +embodichain.lab.sim +===================== + +.. automodule:: embodichain.lab.sim + + .. rubric:: Submodules + + .. autosummary:: + :toctree: . + + sim_manager + cfg + common + material + shapes + objects + sensors + solvers + utility + +.. currentmodule:: embodichain.lab.sim + +Simulation Manager +------------------ + +.. autoclass:: SimulationManager + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: SimulationManagerCfg + :members: + :undoc-members: + :show-inheritance: + :exclude-members: __init__, copy, replace, to_dict, validate + +Configurations +------------------ + +.. automodule:: embodichain.lab.sim.cfg + :members: + :undoc-members: + :show-inheritance: + :exclude-members: __init__, copy, replace, to_dict, validate + +Common Conponents +------------------ + +.. automodule:: embodichain.lab.sim.common + :members: + :undoc-members: + :show-inheritance: + +Materials +------------------ + +.. automodule:: embodichain.lab.sim.material + :members: + :undoc-members: + :show-inheritance: + +Shapes +------------------ + +.. automodule:: embodichain.lab.sim.shapes + :members: + :undoc-members: + :show-inheritance: + :exclude-members: __init__, copy, replace, to_dict, validate + +Objects +------- + +.. toctree:: + :maxdepth: 1 + + embodichain.lab.sim.objects + +Sensors +------- + +.. toctree:: + :maxdepth: 1 + + embodichain.lab.sim.sensors + +Solvers +------- + +.. toctree:: + :maxdepth: 1 + + embodichain.lab.sim.solvers + +Utility +------- + +.. toctree:: + :maxdepth: 1 + + embodichain.lab.sim.utility \ No newline at end of file diff --git a/docs/source/api_reference/embodichain/embodichain.lab.sim.sensors.rst b/docs/source/api_reference/embodichain/embodichain.lab.sim.sensors.rst new file mode 100644 index 00000000..b52a9185 --- /dev/null +++ b/docs/source/api_reference/embodichain/embodichain.lab.sim.sensors.rst @@ -0,0 +1,54 @@ +embodichain.lab.sim.sensors +========================================== + + +.. automodule:: embodichain.lab.sim.sensors + + .. rubric:: Classes + + .. autosummary:: + SensorCfg + BaseSensor + CameraCfg + Camera + StereoCameraCfg + StereoCamera + +.. currentmodule:: embodichain.lab.sim.sensors + +Sensor +------ +.. autoclass:: BaseSensor + :members: + :inherited-members: + :show-inheritance: + +.. autoclass:: SensorCfg + :members: + :exclude-members: __init__, copy, replace, to_dict, validate + +Camera +------ +.. autoclass:: Camera + :members: + :inherited-members: + :show-inheritance: + +.. autoclass:: CameraCfg + :members: + :inherited-members: + :show-inheritance: + :exclude-members: __init__, copy, replace, to_dict, validate + +Stereo Camera +------------- +.. autoclass:: StereoCamera + :members: + :inherited-members: + :show-inheritance: + +.. autoclass:: StereoCameraCfg + :members: + :inherited-members: + :show-inheritance: + :exclude-members: __init__, copy, replace, to_dict, validate diff --git a/docs/source/api_reference/embodichain/embodichain.lab.sim.shapes.rst b/docs/source/api_reference/embodichain/embodichain.lab.sim.shapes.rst new file mode 100644 index 00000000..724f5229 --- /dev/null +++ b/docs/source/api_reference/embodichain/embodichain.lab.sim.shapes.rst @@ -0,0 +1,16 @@ +embodichain.lab.sim.shapes +====================================== + +.. automodule:: embodichain.lab.sim.shapes + + + .. rubric:: Classes + + .. autosummary:: + + CubeCfg + LoadOption + MeshCfg + ShapeCfg + SphereCfg + \ No newline at end of file diff --git a/docs/source/api_reference/embodichain/embodichain.lab.sim.sim_manager.rst b/docs/source/api_reference/embodichain/embodichain.lab.sim.sim_manager.rst new file mode 100644 index 00000000..e6f74aee --- /dev/null +++ b/docs/source/api_reference/embodichain/embodichain.lab.sim.sim_manager.rst @@ -0,0 +1,13 @@ +embodichain.lab.sim.sim_manager +========================================= + +.. automodule:: embodichain.lab.sim.sim_manager + + + .. rubric:: Classes + + .. autosummary:: + + SimulationManager + SimulationManagerCfg + \ No newline at end of file diff --git a/docs/source/api_reference/embodichain/embodichain.lab.sim.solvers.rst b/docs/source/api_reference/embodichain/embodichain.lab.sim.solvers.rst new file mode 100644 index 00000000..36fcdbe5 --- /dev/null +++ b/docs/source/api_reference/embodichain/embodichain.lab.sim.solvers.rst @@ -0,0 +1,95 @@ +embodichain.lab.sim.solvers +========================================== + + +.. automodule:: embodichain.lab.sim.solvers + + .. rubric:: Classes + + .. autosummary:: + SolverCfg + BaseSolver + PytorchSolverCfg + PytorchSolver + PinocchioSolverCfg + PinocchioSolver + PinkSolverCfg + PinkSolver + DifferentialSolverCfg + DifferentialSolver + OPWSolverCfg + OPWSolver + +.. currentmodule:: embodichain.lab.sim.solvers + +Base Solver +----------- + +.. autoclass:: SolverCfg + :members: + :exclude-members: __init__, copy, replace, to_dict, validate + +.. autoclass:: BaseSolver + :members: + :inherited-members: + :show-inheritance: + +PyTorch Solver +-------------- + +.. autoclass:: PytorchSolverCfg + :members: + :exclude-members: __init__, copy, replace, to_dict, validate + +.. autoclass:: PytorchSolver + :members: + :inherited-members: + :show-inheritance: + +Pinocchio Solver +---------------- + +.. autoclass:: PinocchioSolverCfg + :members: + :exclude-members: __init__, copy, replace, to_dict, validate + +.. autoclass:: PinocchioSolver + :members: + :inherited-members: + :show-inheritance: + +Pink Solver +----------- + +.. autoclass:: PinkSolverCfg + :members: + :exclude-members: __init__, copy, replace, to_dict, validate + +.. autoclass:: PinkSolver + :members: + :inherited-members: + :show-inheritance: + +Differential Solver +------------------- + +.. autoclass:: DifferentialSolverCfg + :members: + :exclude-members: __init__, copy, replace, to_dict, validate + +.. autoclass:: DifferentialSolver + :members: + :inherited-members: + :show-inheritance: + +OPW Solver +---------- + +.. autoclass:: OPWSolverCfg + :members: + :exclude-members: __init__, copy, replace, to_dict, validate + +.. autoclass:: OPWSolver + :members: + :inherited-members: + :show-inheritance: diff --git a/docs/source/api_reference/embodichain/embodichain.lab.sim.types.rst b/docs/source/api_reference/embodichain/embodichain.lab.sim.types.rst new file mode 100644 index 00000000..5b1c4bd8 --- /dev/null +++ b/docs/source/api_reference/embodichain/embodichain.lab.sim.types.rst @@ -0,0 +1,6 @@ +embodichain.lab.sim.types +===================================== + +.. automodule:: embodichain.lab.sim.types + + \ No newline at end of file diff --git a/docs/source/api_reference/embodichain/embodichain.lab.sim.utility.rst b/docs/source/api_reference/embodichain/embodichain.lab.sim.utility.rst new file mode 100644 index 00000000..f64d3ce3 --- /dev/null +++ b/docs/source/api_reference/embodichain/embodichain.lab.sim.utility.rst @@ -0,0 +1,31 @@ +embodichain.lab.sim.utility +========================================== + +.. automodule:: embodichain.lab.sim.utility + +Utility Functions +----------------- + +This module contains utility functions for simulation, mesh processing, and URDF handling. + +.. rubric:: Submodules + +.. autosummary:: + + sim_utils + mesh_utils + urdf_utils + +.. currentmodule:: embodichain.lab.sim.utility + +Simulation Utils +~~~~~~~~~~~~~~~~ + +.. automodule:: embodichain.lab.sim.utility.sim_utils + :members: + +Mesh Utils +~~~~~~~~~~ + +.. automodule:: embodichain.lab.sim.utility.mesh_utils + :members: diff --git a/docs/source/api_reference/embodichain/embodichain.toolkits.rst b/docs/source/api_reference/embodichain/embodichain.toolkits.rst new file mode 100644 index 00000000..cc2639d5 --- /dev/null +++ b/docs/source/api_reference/embodichain/embodichain.toolkits.rst @@ -0,0 +1,29 @@ +embodichain.toolkits +==================== + +.. automodule:: embodichain.toolkits + + .. rubric:: Submodules + + .. autosummary:: + + graspkit + urdf_assembly + + +GraspKit +-------- + +.. automodule:: embodichain.toolkits.graspkit + :members: + :undoc-members: + :show-inheritance: + + +URDF Assembly Tool +------------------- + +.. automodule:: embodichain.toolkits.urdf_assembly + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api_reference/embodichain/embodichain.utils.rst b/docs/source/api_reference/embodichain/embodichain.utils.rst new file mode 100644 index 00000000..52ad4a4f --- /dev/null +++ b/docs/source/api_reference/embodichain/embodichain.utils.rst @@ -0,0 +1,100 @@ +embodichain.utils +================= + +.. automodule:: embodichain.utils + + .. Rubric:: Submodules + + .. autosummary:: + + warp + configclass + file + heat_map + logger + math + module_utils + string + utility + visualizer + + +High Performance Computing with Warp +--------------- + +.. toctree:: + :maxdepth: 1 + + embodichain.utils.warp + +Configuration Classes +--------------------- + +.. automodule:: embodichain.utils.configclass + :members: + :undoc-members: + :show-inheritance: + +File Operations +--------------- + +.. automodule:: embodichain.utils.file + :members: + :undoc-members: + :show-inheritance: + +Heat Map Utilities +------------------ + +.. automodule:: embodichain.utils.heat_map + :members: + :undoc-members: + :show-inheritance: + +Logging +------- + +.. automodule:: embodichain.utils.logger + :members: + :undoc-members: + :show-inheritance: + +Mathematical Operations +----------------------- + +.. automodule:: embodichain.utils.math + :members: + :undoc-members: + :show-inheritance: + +Module Utilities +----------------------- + +.. automodule:: embodichain.utils.module_utils + :members: + :undoc-members: + :show-inheritance: + +String Operations +----------------- + +.. automodule:: embodichain.utils.string + :members: + :undoc-members: + :show-inheritance: + +General Utilities +----------------- + +.. automodule:: embodichain.utils.utility + :members: + :undoc-members: + :show-inheritance: + +Visualization +------------- + +.. automodule:: embodichain.utils.visualizer + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api_reference/embodichain/embodichain.utils.warp.kinematics.rst b/docs/source/api_reference/embodichain/embodichain.utils.warp.kinematics.rst new file mode 100644 index 00000000..d6b6c286 --- /dev/null +++ b/docs/source/api_reference/embodichain/embodichain.utils.warp.kinematics.rst @@ -0,0 +1,37 @@ +embodichain.utils.warp.kinematics +================================= + +Utilities for kinematics implemented with Warp (high-performance kernels). + +This subpackage provides Warp kernels and helper functions for inverse/forward +kinematics and batched trajectory warping used across EmbodiChain. The modules +documented below are the main entry points: + +- ``opw_solver``: efficient OPW-based forward/inverse kinematics kernels. +- ``warp_trajectory``: kernels to compute, interpolate, and apply trajectory offsets. + +.. automodule:: embodichain.utils.warp.kinematics + + .. Rubric:: Submodules + + .. autosummary:: + + opw_solver + warp_trajectory + +OPW Kinematics Solver +----------------------- + +.. automodule:: embodichain.utils.warp.kinematics.opw_solver + :members: + :undoc-members: + :show-inheritance: + + +Trajectory Warping Utilities +---------------------------- +.. automodule:: embodichain.utils.warp.kinematics.warp_trajectory + :members: + :undoc-members: + :show-inheritance: + \ No newline at end of file diff --git a/docs/source/api_reference/embodichain/embodichain.utils.warp.rst b/docs/source/api_reference/embodichain/embodichain.utils.warp.rst new file mode 100644 index 00000000..cdc62d20 --- /dev/null +++ b/docs/source/api_reference/embodichain/embodichain.utils.warp.rst @@ -0,0 +1,27 @@ +embodichain.utils.warp +======================= + +High-performance Warp utilities used by EmbodiChain. + +This package exposes Warp kernels and helpers for various high-performance computing tasks: + +- Image processing. +- 3D spatial computating, +- Robotics kinematics and trajectory computating + +.. automodule:: embodichain.utils.warp + + .. Rubric:: Submodules + + .. autosummary:: + + kinematics + kernels + +kernel Operators +----------------------- + +.. automodule:: embodichain.utils.warp.kernels + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/api_reference/index.rst b/docs/source/api_reference/index.rst new file mode 100644 index 00000000..fa3112ae --- /dev/null +++ b/docs/source/api_reference/index.rst @@ -0,0 +1,19 @@ +API Reference +============= + +This page provides detailed documentation for all EmbodiChain modules and classes. + +Core Framework +-------------- + +The following modules are available in the core ``embodichain`` framework: + +.. currentmodule:: embodichain + +.. autosummary:: + :toctree: embodichain + + agents + lab + toolkits + utils diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 00000000..311ddbf4 --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,66 @@ +# Configuration file for the Sphinx documentation builder. +# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +import os +import sys + + +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +package_root = os.path.join(project_root, "embodichain") +sys.path.insert(0, package_root) + + +project = "EmbodiChain" +copyright = "2025, The EmbodiChain Project Developers" +author = "The EmbodiChain Project Developers" + +# Read version from VERSION file if it exists +with open(os.path.join(os.path.dirname(__file__), "..", "..", "VERSION")) as f: + full_version = f.read().strip() + version = ".".join(full_version.split(".")[:3]) + + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + + +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.napoleon", + "sphinx.ext.viewcode", + "sphinx_autodoc_typehints", # optional, shows type hints + "sphinx_design", + "myst_parser", # if you prefer Markdown pages +] +# Napoleon settings if using Google/NumPy docstring style: +napoleon_google_docstring = True +napoleon_numpy_docstring = True + +# generate autosummary even if no references +autosummary_generate = True +autosummary_generate_overwrite = False +# default autodoc settings +autodoc_default_options = { + "autosummary": True, +} + +# If using MyST and writing .md API stubs: +myst_enable_extensions = ["colon_fence", "deflist", "html_admonition"] + + +templates_path = ["_templates"] +exclude_patterns = [] + + +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +html_theme = "sphinx_book_theme" +html_static_path = ["_static"] +html_logo = "_static/logo_e.png" diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 00000000..2058a8ba --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,50 @@ +EmbodiChain Documentation +========================= + +Welcome to the EmbodiChain! + +Table of Contents +================= + +.. toctree:: + :maxdepth: 1 + :caption: Introduction + :glob: + + introduction + +.. toctree:: + :maxdepth: 1 + :caption: Getting Started + :glob: + + quick_start/install.md + tutorial/index + quick_start/docs.md + +.. toctree:: + :maxdepth: 1 + :caption: Overview + :glob: + + overview/sim/index + overview/gym/index + overview/vla/index + overview/rl/index + +.. toctree:: + :maxdepth: 1 + :caption: Resources + :glob: + + resources/robot/index* + resources/task/index* + resources/roadmap.md + +.. toctree:: + :maxdepth: 2 + :caption: API Reference + :titlesonly: + + api_reference/index + diff --git a/docs/source/introduction.rst b/docs/source/introduction.rst new file mode 100644 index 00000000..404541da --- /dev/null +++ b/docs/source/introduction.rst @@ -0,0 +1,58 @@ +.. EmbodiChain documentation master file, created by + sphinx-quickstart on Tue Nov 19 11:00:25 2024. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +EmbodiChain +====================================== + +.. image:: ../../assets/imgs/teaser.jpg + :alt: teaser + +📘 `Documentation `_ + +--- + +EmbodiChain is an end-to-end, GPU-accelerated framework for Embodied AI. It streamlines research and development by unifying high-performance simulation, real-to-sim data pipelines, modular model architectures, and efficient training workflows. This integration enables rapid experimentation, seamless deployment of intelligent agents, and effective Sim2Real transfer for real-world robotic systems. + +.. NOTE:: + EmbodiChain is in Alpha and under active development: + + * More features will be continually added in the coming months. + * Since this is an early release, we welcome feedback (bug reports, feature requests, etc.) via GitHub Issues. + + +Key Features +------------ + +* **High-Fidelity, GPU-Accelerated Simulation**: Combines realistic physics for both rigid and deformable objects with advanced ray-traced sensor modeling, all accelerated on the GPU for high-throughput batched simulations. +* **Unified Robot Learning Environment**: Offers standardized interfaces for a wide range of robot learning tasks, including Imitation Learning and Reinforcement Learning. +* **Scalable Data Pipeline**: Features a comprehensive toolkit for automated data collection, efficient processing, and large-scale data generation to fuel your models. +* **Efficient Training & Evaluation**: Supports modern training paradigms like online data streaming for Imitation Learning and massively parallel environment rollouts for Reinforcement Learning. +* **Modular and Extensible**: Designed with modularity in mind to easily integrate new robot platforms, environments, and learning algorithms. + + +Getting Started +--------------- + +To get started with EmbodiChain, follow these steps: + +* `Installation Guide `_ +* `Quick Start Tutorial `_ +* `API Reference `_ + + +Citation +-------- + +If you use EmbodiChain in your research, please cite our work: + +.. code-block:: bibtex + + @misc{EmbodiChain, + author = {EmbodiChain Developers}, + title = {EmbodiChain: An end-to-end, GPU-accelerated, and modular platform for building generalized Embodied Intelligence.}, + month = {November}, + year = {2025}, + url = {https://github.com/DexForce/EmbodiChain} + } diff --git a/docs/source/overview/gym/index.rst b/docs/source/overview/gym/index.rst new file mode 100644 index 00000000..7781adfb --- /dev/null +++ b/docs/source/overview/gym/index.rst @@ -0,0 +1,5 @@ +Embodied Environments +================== + +Overview of the Embodied Environments: + diff --git a/docs/source/overview/rl/algorithm.md b/docs/source/overview/rl/algorithm.md new file mode 100644 index 00000000..cfc92421 --- /dev/null +++ b/docs/source/overview/rl/algorithm.md @@ -0,0 +1,67 @@ +# RL Algorithms + +This module contains the core implementations of reinforcement learning algorithms, mainly including PPO (Proximal Policy Optimization). + +## Main Classes and Functions + +### BaseAlgorithm +- Abstract base class for RL algorithms, defining common interfaces such as buffer initialization, data collection, and update. +- Key methods: + - `initialize_buffer(num_steps, num_envs, obs_dim, action_dim)`: Initialize the trajectory buffer. + - `collect_rollout(env, policy, obs, num_steps, on_step_callback)`: Collect interaction data. + - `update()`: Update the policy based on collected data. +- Designed to be algorithm-agnostic; Trainer only depends on this interface to support various RL algorithms. +- Supports multi-environment parallel collection, compatible with Gymnasium/IsaacGym environments. + +### PPO +- Mainstream on-policy algorithm, supports Generalized Advantage Estimation (GAE), policy update, and hyperparameter configuration. +- Key methods: + - `_compute_gae(rewards, values, dones)`: Generalized Advantage Estimation. + - `collect_rollout`: Collect trajectories and compute advantages/returns. + - `update`: Multi-epoch minibatch optimization, including entropy, value, and policy loss, with gradient clipping. +- Supports custom callbacks, detailed logging, and GPU acceleration. +- Typical training flow: collect rollout → compute advantage/return → multi-epoch minibatch optimization. +- Supports advantage normalization, entropy regularization, value loss weighting, etc. + +### Config Classes +- `AlgorithmCfg`, `PPOCfg`: Centralized management of learning rate, batch size, clip_coef, ent_coef, vf_coef, and other parameters. +- Supports automatic loading from JSON config files for batch experiments and parameter tuning. +- Can be extended via inheritance for multiple algorithms and tasks. + +## Code Example +```python +class BaseAlgorithm: + def initialize_buffer(self, num_steps, num_envs, obs_dim, action_dim): + ... + def collect_rollout(self, env, policy, obs, num_steps, on_step_callback=None): + ... + def update(self): + ... + +class PPO(BaseAlgorithm): + def _compute_gae(self, rewards, values, dones): + ... + def collect_rollout(self, ...): + ... + def update(self): + ... +``` + +## Usage Recommendations +- It is recommended to manage all algorithm parameters via config classes and JSON config files for reproducibility and tuning. +- Supports multi-environment parallel collection to improve sampling efficiency. +- Custom algorithm classes can be implemented to extend new RL methods. + +## Extension Notes +- Users can inherit from `BaseAlgorithm` to implement custom algorithms and flexibly integrate them into the RL framework. +- Supports multi-environment parallelism and event-driven extension. +- Typical usage: +```python +algo = PPO(cfg, policy) +buffer = algo.initialize_buffer(...) +for _ in range(num_iterations): + algo.collect_rollout(...) + algo.update() +``` + +--- diff --git a/docs/source/overview/rl/buffer.md b/docs/source/overview/rl/buffer.md new file mode 100644 index 00000000..91852074 --- /dev/null +++ b/docs/source/overview/rl/buffer.md @@ -0,0 +1,65 @@ +# Rollout Buffer + +This module implements the data buffer for RL training, responsible for storing trajectory data from agent-environment interactions. + +## Main Classes and Structure + +### RolloutBuffer +- Used for on-policy algorithms (such as PPO), efficiently stores observations, actions, rewards, dones, values, and logprobs for each step. +- Supports multi-environment parallelism (shape: [T, N, ...]), all data allocated on GPU. +- Structure fields: + - `obs`: Observation tensor, float32, shape [T, N, obs_dim] + - `actions`: Action tensor, float32, shape [T, N, action_dim] + - `rewards`: Reward tensor, float32, shape [T, N] + - `dones`: Done flags, bool, shape [T, N] + - `values`: Value estimates, float32, shape [T, N] + - `logprobs`: Action log probabilities, float32, shape [T, N] + - `_extras`: Algorithm-specific fields (e.g., advantages, returns), dict[str, Tensor] + +## Main Methods +- `add(obs, action, reward, done, value, logprob)`: Add one step of data. +- `set_extras(extras)`: Attach algorithm-related tensors (e.g., advantages, returns). +- `iterate_minibatches(batch_size)`: Randomly sample minibatches, returns dict (including all fields and extras). +- Supports efficient GPU shuffle and indexing for large-scale training. + +## Usage Example +```python +buffer = RolloutBuffer(num_steps, num_envs, obs_dim, action_dim, device) +for t in range(num_steps): + buffer.add(obs, action, reward, done, value, logprob) +buffer.set_extras({"advantages": adv, "returns": ret}) +for batch in buffer.iterate_minibatches(batch_size): + # batch["obs"], batch["actions"], batch["advantages"] ... + pass +``` + +## Design and Extension +- Supports multi-environment parallel collection, compatible with Gymnasium/IsaacGym environments. +- All data is allocated on GPU to avoid frequent CPU-GPU copying. +- The extras field can be flexibly extended to meet different algorithm needs (e.g., GAE, TD-lambda, distributional advantages). +- The iterator automatically shuffles to improve training stability. +- Compatible with various RL algorithms (PPO, A2C, SAC, etc.), custom fields and sampling logic supported. + +## Code Example +```python +class RolloutBuffer: + def __init__(self, num_steps, num_envs, obs_dim, action_dim, device): + # Initialize tensors + ... + def add(self, obs, action, reward, done, value, logprob): + # Add data + ... + def set_extras(self, extras): + # Attach algorithm-related tensors + ... + def iterate_minibatches(self, batch_size): + # Random minibatch sampling + ... +``` + +## Practical Tips +- It is recommended to call set_extras after each rollout to ensure advantage/return tensors align with main data. +- When using iterate_minibatches, set batch_size appropriately for training stability. +- Extend the extras field as needed for custom sampling and statistics. + +--- diff --git a/docs/source/overview/rl/config.md b/docs/source/overview/rl/config.md new file mode 100644 index 00000000..bf5c04df --- /dev/null +++ b/docs/source/overview/rl/config.md @@ -0,0 +1,55 @@ +# Config + +This module defines configuration classes for RL algorithms, centralizing the management of training hyperparameters and supporting automatic loading and experiment reproducibility. + +## Main Classes and Structure + +### AlgorithmCfg +- Base parameter config class for RL algorithms, supports dataclass-based automation. +- Typical fields: + - `device`: Training device (e.g., "cuda", "cpu"). + - `learning_rate`: Learning rate. + - `batch_size`: Batch size per training epoch. + - `gamma`: Discount factor. + - `gae_lambda`: GAE advantage estimation parameter. + - `max_grad_norm`: Gradient clipping threshold. +- Supports inheritance and extension (e.g., PPOCfg adds clip_coef, ent_coef, vf_coef). + +### Automatic Loading +- Supports automatic parsing of JSON config files; the main training script injects parameters automatically. +- Decouples config from code, making batch experiments and parameter tuning easier. + +## Usage Example +```python +from embodichain.agents.rl.utils import AlgorithmCfg +cfg = AlgorithmCfg(learning_rate=1e-4, batch_size=8192, gamma=0.99) +``` +Or via config file: +```json +{ + "algorithm": { + "name": "ppo", + "cfg": { + "learning_rate": 0.0001, + "batch_size": 8192, + "gamma": 0.99, + "gae_lambda": 0.95, + "clip_coef": 0.2, + "ent_coef": 0.01, + "vf_coef": 0.5, + "max_grad_norm": 0.5 + } + } +} +``` + +## Extension and Customization +- Custom algorithm parameter classes are supported for multi-algorithm and multi-task experiments. +- Config classes are seamlessly integrated with the main training script for automated experiments and reproducibility. +- Supports parameter validation, default values, and type hints. + +## Practical Tips +- It is recommended to manage all experiment parameters via JSON config files for reproducibility and tuning. +- Supports multi-algorithm config for easy comparison and automation. + +--- diff --git a/docs/source/overview/rl/index.rst b/docs/source/overview/rl/index.rst new file mode 100644 index 00000000..2d78b85b --- /dev/null +++ b/docs/source/overview/rl/index.rst @@ -0,0 +1,80 @@ +Reinforcement Learning +====================== + +This section introduces the overall architecture and submodules of the embodychain RL (Reinforcement Learning) module. The RL framework supports mainstream algorithms (such as PPO) and provides flexible components for policy, buffer, trainer, etc., making it easy to extend and customize. + +.. contents:: Table of contents + :local: + :depth: 2 + +Overview +-------- + +The embodychain RL module is used to train agents to accomplish tasks in simulation environments. It mainly includes algorithm implementations, policy networks, data buffers, training processes, and utility tools. + +Architecture Diagram Example +--------------------------- + +.. code-block:: text + + +-------------------+ + | train.py | + +-------------------+ + | + v + +-------------------+ + | Trainer | + +-------------------+ + | | | | + v v v v + Algo Policy Buffer Env + +- train.py is responsible for entry, config parsing, and module initialization. +- Trainer coordinates algorithm, policy, buffer, and environment. +- Algo/Policy/Buffer/Env are independent, making extension easy. + +Module Categories +----------------- + +- Algorithm (`algo/`): RL algorithm implementations, including `BaseAlgorithm`, `PPO`, etc. +- Buffer (`buffer/`): Trajectory data buffer, such as `RolloutBuffer`. +- Models (`models/`): Policy network modules, including `Policy`, `ActorCritic`, `MLP`. +- Trainer (`utils/trainer.py`): Main training loop and logging management. +- Config (`utils/config.py`): Algorithm config class definitions. +- Train Script (`train.py`): RL training entry script. + +Extension and Customization +--------------------------- + +- Users can customize algorithms (by inheriting `BaseAlgorithm`), policies (by inheriting `Policy`), buffers, etc. +- Supports multi-environment parallelism, event-driven extension, and flexible config management. +- It is recommended to manage all parameters via config files for reproducibility and batch experiments. + +Common Issues and Best Practices +------------------------------- +- Config files are recommended to use JSON for easy management and reproducibility. +- Parallel environment sampling can significantly improve training efficiency. +- The event-driven mechanism allows flexible insertion of custom logic (such as evaluation, saving, callbacks). +- It is recommended to use WandB/TensorBoard for training process visualization. + +Example +------- + +.. code-block:: bash + + python train.py --config configs/agents/rl/push_cube/train_config.json + +For more details, please refer to the source code and API documentation of each submodule. + +See also +-------- + +.. toctree:: + :maxdepth: 1 + + algorithm.md + buffer.md + models.md + trainer.md + config.md + train_script.md diff --git a/docs/source/overview/rl/models.md b/docs/source/overview/rl/models.md new file mode 100644 index 00000000..8bf7986e --- /dev/null +++ b/docs/source/overview/rl/models.md @@ -0,0 +1,50 @@ +# Policy Models + +This module contains RL policy networks and related model implementations, supporting various architectures and distributional extensions. + +## Main Classes and Structure + +### Policy +- Abstract base class for RL policies; all policies must inherit from it. +- Unified interface: + - `get_action(obs, deterministic=False)`: Sample or output actions. + - `get_value(obs)`: Estimate state value. + - `evaluate_actions(obs, actions)`: Evaluate action probabilities, entropy, and value. +- Supports GPU deployment and distributed training. + +### ActorCritic +- Typical actor-critic policy, includes actor (action distribution) and critic (value function). +- Supports Gaussian action distributions, learnable log_std, suitable for continuous action spaces. +- Key methods: + - `get_action`: Actor network outputs mean, samples action, returns log_prob and critic value. + - `evaluate_actions`: Used for loss calculation in PPO/SAC algorithms. +- Custom actor/critic network architectures supported (e.g., MLP/CNN/Transformer). + +### MLP +- Multi-layer perceptron, supports custom number of layers, activation functions, LayerNorm, Dropout. +- Used to build actor/critic networks. +- Supports orthogonal initialization and output reshaping. + +### Factory Functions +- `build_policy(policy_block, obs_space, action_space, device, ...)`: Automatically build policy from config. +- `build_mlp_from_cfg(module_cfg, in_dim, out_dim)`: Automatically build MLP from config. + +## Usage Example +```python +actor = build_mlp_from_cfg(actor_cfg, obs_dim, action_dim) +critic = build_mlp_from_cfg(critic_cfg, obs_dim, 1) +policy = build_policy(policy_block, obs_space, action_space, device, actor=actor, critic=critic) +action, log_prob, value = policy.get_action(obs) +``` + +## Extension and Customization +- Supports custom network architectures (e.g., CNN, Transformer) by implementing the Policy interface. +- Can extend to multi-head policies, distributional actors, hybrid action spaces, etc. +- Factory functions facilitate config management and automated experiments. + +## Practical Tips +- It is recommended to configure all network architectures and hyperparameters for reproducibility. +- Supports multi-environment parallelism and distributed training to improve sampling efficiency. +- Extend the Policy interface as needed for multi-modal input, hierarchical policies, etc. + +--- diff --git a/docs/source/overview/rl/train_script.md b/docs/source/overview/rl/train_script.md new file mode 100644 index 00000000..cc7a1568 --- /dev/null +++ b/docs/source/overview/rl/train_script.md @@ -0,0 +1,48 @@ +# Train Script + +This module provides the RL training entry script, responsible for parsing configuration, initializing modules, and starting training. It supports multi-task and automated experiments. + +## Main Structure and Flow + +### train.py +- Main training script, supports command-line arguments (such as --config), automatically loads JSON config. +- Initializes device, random seed, output directory, and logging (TensorBoard/WandB). +- Loads environment config, supports multi-environment parallelism and evaluation environments. +- Builds policy (e.g., actor-critic), algorithm (e.g., PPO), and Trainer. +- Supports event management (e.g., environment randomization, data logging, evaluation events). +- Automatically saves model checkpoints and performs periodic evaluation. + +## Argument Parsing +- Supports command-line arguments: + - `--config`: Specify the path to the config file (JSON only). +- The config file includes parameters for trainer, policy, algorithm, events, and other modules. + +## Module Initialization +- Device selection (CPU/GPU), automatic detection and setup. +- Random seed setting to ensure experiment reproducibility. +- Output directory is automatically generated, log files are managed automatically. +- Supports TensorBoard/WandB logging, automatically records the training process. + +## Training Flow +1. Load the JSON config file and parse parameters for each module. +2. Initialize environment, policy, algorithm, and Trainer. +3. Enter the main training loop: collect data, update policy, record logs. +4. Periodically evaluate and save the model. +5. Supports graceful interruption and auto-saving with KeyboardInterrupt. + +## Usage Example +```bash +python train.py --config configs/agents/rl/push_cube/train_config.json +``` + +## Extension and Customization +- Supports custom event modules for flexible training flow extension. +- Can integrate multi-task and multi-environment training. +- Config-driven management for batch experiments and parameter tuning. + +## Practical Tips +- It is recommended to manage all experiment parameters via JSON config files for reproducibility and tuning. +- Supports multi-environment and event extension to improve training flexibility. +- Logging and checkpoint management help with experiment tracking and recovery. + +--- diff --git a/docs/source/overview/rl/trainer.md b/docs/source/overview/rl/trainer.md new file mode 100644 index 00000000..1a2b0fe3 --- /dev/null +++ b/docs/source/overview/rl/trainer.md @@ -0,0 +1,53 @@ +# Trainer + +This module implements the main RL training loop, logging management, and event-driven extension. + +## Main Classes and Structure + +### Trainer +- RL training coordinator, responsible for the interaction between algorithm, environment, and policy. +- Main responsibilities: + - Manage training loop, evaluation, and model saving. + - Event-driven extension (e.g., environment randomization, data logging, evaluation events). + - Logging output (TensorBoard/WandB/console), tracking rewards, episode length, loss, etc. +- Key fields: + - `policy`: RL policy object. + - `algorithm`: RL algorithm object. + - `env`/`eval_env`: Training and evaluation environments. + - `writer`: TensorBoard logger. + - `event_manager`/`eval_event_manager`: Event managers. + - `global_step`, `ret_window`, `len_window`: Training statistics. + +## Main Methods +- `train(total_timesteps)`: Main training loop, automatically collects data, updates policy, and logs. +- `_collect_rollout()`: Collect one rollout, supports custom callback statistics. +- `_log_train(losses)`: Log training loss, reward, sampling speed, etc. +- `_eval_once()`: Periodic evaluation, records evaluation metrics. +- `save_checkpoint()`: Save model parameters and training state. + +## Event Management +- Supports custom events (e.g., environment randomization, data logging) injected via EventManager. +- Events can be executed by interval/step/trigger, enabling flexible extension. + +## Logging and Monitoring +- Supports TensorBoard and WandB logging, automatically records reward, episode length, loss, sampling speed, etc. +- Console output for training progress and statistics. + +## Usage Example +```python +trainer = Trainer(policy, env, algorithm, num_steps, batch_size, writer, ...) +trainer.train(total_steps) +trainer.save_checkpoint() +``` + +## Extension and Customization +- Custom event modules can be implemented for environment reset, data collection, evaluation, etc. +- Supports multi-environment parallelism and distributed training. +- Training process can be flexibly adjusted via config files. + +## Practical Tips +- It is recommended to perform periodic evaluation and model saving to prevent loss of progress during training. +- The event mechanism can be used for automated experiments, data collection, and environment reset. +- Logging and monitoring help analyze training progress and tune hyperparameters. + +--- diff --git a/docs/source/overview/sim/index.rst b/docs/source/overview/sim/index.rst new file mode 100644 index 00000000..fd6e56fd --- /dev/null +++ b/docs/source/overview/sim/index.rst @@ -0,0 +1,23 @@ +Simulation Framework +================== + +Overview of the Simulation Framework: + +- Architecture + +- Components + + - Simulation Manager + - Simulation Object + - Material + - Virtual Sensor + - Kinematics Solver + - Motion Generation + + +.. toctree:: + :maxdepth: 1 + :glob: + + solvers/index + planners/index diff --git a/docs/source/overview/sim/planners/index.rst b/docs/source/overview/sim/planners/index.rst new file mode 100644 index 00000000..0347197a --- /dev/null +++ b/docs/source/overview/sim/planners/index.rst @@ -0,0 +1,36 @@ +Planners +================================= + +This section documents the planners provided by the project with a focus on +planners for robotic motion: path planning, trajectory generation, +collision avoidance, and practical considerations such as smoothness and +dynamic feasibility. + +The repository contains several planner implementations — each has a dedicated +page with implementation details and examples. Use the links at the bottom of +this page to jump to a specific planner. + +.. contents:: Table of contents + :local: + :depth: 2 + +Overview +-------- + +The `embodichain` project provides a unified interface for robot trajectory planning, supporting both joint space and Cartesian space interpolation. The main planners include: + +- **MotionGenerator**: A unified trajectory planning interface that supports joint/Cartesian interpolation, automatic constraint handling, flexible planner selection, and is easily extensible for collision checking and additional planners. +- **ToppraPlanner**: A time-optimal trajectory planner based on the TOPPRA library, supporting joint trajectory generation under velocity and acceleration constraints. +- **TrajectorySampleMethod**: An enumeration for trajectory sampling strategies, supporting sampling by time, quantity, or distance. + +These tools can be used to generate smooth and dynamically feasible robot trajectories, and are extensible for future collision checking and various sampling requirements. + +See also +-------- + +.. toctree:: + :maxdepth: 1 + + motion_generator.md + toppra_planner.md + trajectory_sample_method.md diff --git a/docs/source/overview/sim/planners/motion_generator.md b/docs/source/overview/sim/planners/motion_generator.md new file mode 100644 index 00000000..5aa75a0b --- /dev/null +++ b/docs/source/overview/sim/planners/motion_generator.md @@ -0,0 +1,142 @@ +# MotionGenerator + +`MotionGenerator` provides a unified interface for robot trajectory planning, supporting both joint space and Cartesian space interpolation. It is designed to work with different planners (such as ToppraPlanner) and can be extended to support collision checking in the future. + +## Features + +* **Unified planning interface**: Supports trajectory planning with or without collision checking (collision checking is reserved for future implementation). +* **Flexible planner selection**: Allows selection of different planners (currently supports TOPPRA for time-optimal planning). +* **Automatic constraint handling**: Retrieves velocity and acceleration limits from the robot or uses user-specified/default values. +* **Supports both joint and Cartesian interpolation**: Generates discrete trajectories using either joint space or Cartesian space interpolation. +* **Convenient sampling**: Supports various sampling strategies via `TrajectorySampleMethod`. + +## Usage + +### Initialization + +```python +from embodichain.data import get_data_path +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.cfg import ( + RobotCfg, + URDFCfg, + JointDrivePropertiesCfg, +) + +from embodichain.lab.sim.planners.motion_generator import MotionGenerator +from embodichain.lab.sim.objects.robot import Robot +from embodichain.lab.sim.solvers.pink_solver import PinkSolverCfg +from embodichain.lab.sim.planners.utils import TrajectorySampleMethod + +# Configure the simulation +sim_cfg = SimulationManagerCfg( + width=1920, + height=1080, + physics_dt=1.0 / 100.0, + sim_device="cpu", +) + +sim = SimulationManager(sim_cfg) +sim.set_manual_update(True) + +# Get UR10 URDF path +urdf_path = get_data_path("UniversalRobots/UR10/UR10.urdf") + +# Create UR10 robot +robot_cfg = RobotCfg( + uid="UR10_test", + urdf_cfg=URDFCfg( + components=[{"component_type": "arm", "urdf_path": urdf_path}] + ), + control_parts={"arm": ["Joint[1-6]"]}, + solver_cfg={ + "arm": PinkSolverCfg( + urdf_path=urdf_path, + end_link_name="ee_link", + root_link_name="base_link", + pos_eps=1e-2, + rot_eps=5e-2, + max_iterations=300, + dt=0.1, + ) + }, + drive_pros=JointDrivePropertiesCfg( + stiffness={"Joint[1-6]": 1e4}, + damping={"Joint[1-6]": 1e3}, + ), +) +robot = sim.add_robot(cfg=robot_cfg) + +motion_gen = MotionGenerator( + robot=robot, + uid="arm", + planner_type="toppra", + default_velocity=0.2, + default_acceleration=0.5 +) + +``` + +### Trajectory Planning + +#### Joint Space Planning + +```python +current_state = { + "position": [0, 0, 0, 0, 0, 0], + "velocity": [0, 0, 0, 0, 0, 0], + "acceleration": [0, 0, 0, 0, 0, 0] +} +target_states = [ + {"position": [1, 1, 1, 1, 1, 1]} +] +success, positions, velocities, accelerations, times, duration = motion_gen.plan( + current_state=current_state, + target_states=target_states, + sample_method=TrajectorySampleMethod.TIME, + sample_interval=0.01 +) +``` + +#### Cartesian or Joint Interpolation + +```python +# Using joint configurations (qpos_list) +qpos_list = [ + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1] +] +out_qpos_list, out_xpos_list = motion_gen.create_discrete_trajectory( + qpos_list=qpos_list, + is_linear=False, + sample_method=TrajectorySampleMethod.QUANTITY, + sample_num=20 +) +``` + +### Estimating Trajectory Sample Count + +You can estimate the number of sampling points required for a trajectory before generating it: + +```python +# Estimate based on joint configurations (qpos_list) +qpos_list = [ + [0, 0, 0, 0, 0, 0], + [0.5, 0.5, 0.5, 0.5, 0.5, 0.5], + [1, 1, 1, 1, 1, 1] +] +sample_count = motion_gen.estimate_trajectory_sample_count( + qpos_list=qpos_list, # List of joint positions + step_size=0.01, # unit: m + angle_step=0.05, # unit: rad +) +print(f"Estimated sample count: {sample_count}") +``` + +## Notes + +* The planner type can be specified as a string or `PlannerType` enum. +* If the robot provides its own joint limits, those will be used; otherwise, default or user-specified limits are applied. +* For Cartesian interpolation, inverse kinematics (IK) is used to compute joint configurations for each interpolated pose. +* The class is designed to be extensible for additional planners and collision checking in the future. +* The sample count estimation is useful for predicting computational load and memory requirements. diff --git a/docs/source/overview/sim/planners/toppra_planner.md b/docs/source/overview/sim/planners/toppra_planner.md new file mode 100644 index 00000000..3ff756fd --- /dev/null +++ b/docs/source/overview/sim/planners/toppra_planner.md @@ -0,0 +1,58 @@ +# ToppraPlanner + +`ToppraPlanner` is a trajectory planner based on the [TOPPRA](https://toppra.readthedocs.io/) (Time-Optimal Path Parameterization via Reachability Analysis) library. It generates time-optimal joint trajectories under velocity and acceleration constraints. + +## Features + +- **Time-optimal trajectory generation**: Computes the fastest possible trajectory between waypoints, given joint velocity and acceleration limits. +- **Flexible sampling**: Supports sampling by time interval or by number of points. +- **Constraint handling**: Automatically formats velocity and acceleration constraints for the TOPPRA solver. +- **Dense and sparse waypoints**: Supports both dense and sparse waypoint interpolation. + +## Usage + +### Initialization + +```python +from embodichain.lab.sim.planners.toppra_planner import ToppraPlanner +planner = ToppraPlanner( + dofs=6, + max_constraints={ + "velocity": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + "acceleration": [2.0, 2.0, 2.0, 2.0, 2.0, 2.0] + } +) +``` + +### Planning + +```python +from embodichain.lab.sim.planners.utils import TrajectorySampleMethod +from embodichain.lab.sim.planners.toppra_planner import ToppraPlanner +success, positions, velocities, accelerations, times, duration = planner.plan( + current_state={ + "position": [0, 0, 0, 0, 0, 0], + "velocity": [0, 0, 0, 0, 0, 0], + "acceleration": [0, 0, 0, 0, 0, 0] + }, + target_states=[ + {"position": [1, 1, 1, 1, 1, 1]} + ], + sample_method=TrajectorySampleMethod.TIME, + sample_interval=0.01 +) +``` + +- `positions`, `velocities`, `accelerations` are arrays of sampled trajectory points. +- `times` is the array of time stamps. +- `duration` is the total trajectory time. + +## Notes + +- The planner requires the `toppra` library (`pip install toppra==0.6.3`). +- For dense waypoints, the default spline interpolation is used. For sparse waypoints, you may need to adjust the interpolation method. +- The number of grid points (`gridpt_min_nb_points`) is important for accurate acceleration constraint handling. + +## References + +- [TOPPRA Documentation](https://hungpham2511.github.io/toppra/index.html) diff --git a/docs/source/overview/sim/planners/trajectory_sample_method.md b/docs/source/overview/sim/planners/trajectory_sample_method.md new file mode 100644 index 00000000..0b8c16ff --- /dev/null +++ b/docs/source/overview/sim/planners/trajectory_sample_method.md @@ -0,0 +1,17 @@ +# TrajectorySampleMethod + +`TrajectorySampleMethod` is an enumeration that defines different strategies for sampling points along a trajectory. It provides meaningful names for various sampling methods, making trajectory planning code more readable and maintainable. + +## Enum Members + +- **TIME**: + Sample trajectory points based on fixed time intervals. + Example: Generate a point every 0.01 seconds. + +- **QUANTITY**: + Sample a specified number of points along the trajectory, regardless of the time interval. + Example: Generate exactly 100 points between start and end. + +- **DISTANCE**: + Sample points based on fixed distance intervals along the path. + Example: Generate a point every 1 cm along the trajectory. diff --git a/docs/source/overview/sim/solvers/differential_solver.md b/docs/source/overview/sim/solvers/differential_solver.md new file mode 100644 index 00000000..5f68cad6 --- /dev/null +++ b/docs/source/overview/sim/solvers/differential_solver.md @@ -0,0 +1,99 @@ +# DifferentialSolver + +The `DifferentialSolver` is a differential inverse kinematics (IK) controller designed for robot manipulators. It computes joint-space commands to achieve desired end-effector positions or poses using various Jacobian-based methods. + +## Key Features + +* Supports multiple IK methods: pseudo-inverse (`pinv`), singular value decomposition (`svd`), transpose (`trans`), and damped least squares (`dls`) +* Configurable for position or pose control, with absolute or relative modes +* Efficient batch computation for multiple environments +* Flexible configuration via `DifferentialSolverCfg` + +## Configuration Example + +```python +from embodichain.data import get_data_path +from embodichain.lab.sim.solvers.differential_solver import DifferentialSolver +from embodichain.lab.sim.solvers.differential_solver import DifferentialSolverCfg + +cfg = DifferentialSolverCfg( + urdf_path=get_data_path("UniversalRobots/UR5/UR5.urdf"), + joint_names=["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"], + end_link_name="ee_link", + root_link_name="base_link", + command_type="pose", + use_relative_mode=False, + ik_method="pinv", + ik_params={"k_val": 1.0} +) + +solver = DifferentialSolver(cfg) +``` + +## Main Methods + +* `get_fk(self, qpos: torch.Tensor) -> torch.Tensor` + Computes the end-effector pose (homogeneous transformation matrix) for the given joint positions. + + **Parameters:** + + `qpos` (`torch.Tensor` or `list[float]`): Joint positions, shape `(num_envs, num_joints)` or `(num_joints,)`. + + **Returns:** + + `torch.Tensor`: End-effector pose(s), shape `(num_envs, 4, 4)`. + + **Example:** + +```python + fk = solver.get_fk(qpos=[0.0, 0.0, 0.0, 1.5708, 0.0, 0.0]) + print(fk) + # Output: + # tensor([[[ 0.0, -1.0, 0.0, -0.722600], + # [ 0.0, 0.0, -1.0, -0.191450], + # [ 1.0, 0.0, 0.0, 0.079159], + # [ 0.0, 0.0, 0.0, 1.0 ]]]) +``` + +* `get_ik(self, target_xpos: torch.Tensor, qpos_seed: torch.Tensor = None, return_all_solutions: bool = False, jacobian: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]` + Computes joint positions (inverse kinematics) for the given target end-effector pose. + + **Parameters:** + + `target_xpos` (`torch.Tensor`): Target end-effector pose(s), shape `(num_envs, 4, 4)`. + + `qpos_seed` (`torch.Tensor`, optional): Initial guess for joint positions, shape `(num_envs, num_joints)`. If `None`, a default is used. + + `return_all_solutions` (`bool`, optional): If `True`, returns all possible solutions. Default is `False`. + + `jacobian` (`torch.Tensor`, optional): Custom Jacobian. Usually not required. + + **Returns:** + + `Tuple[torch.Tensor, torch.Tensor]`: + - First element: Joint positions, shape `(num_envs, num_joints)`. + - Second element: Convergence info or error for each environment. + + **Example:** + +```python + import torch + xpos = torch.tensor([[[ 0.0, -1.0, 0.0, -0.722600], + [ 0.0, 0.0, -1.0, -0.191450], + [ 1.0, 0.0, 0.0, 0.079159], + [ 0.0, 0.0, 0.0, 1.0 ]]]) + qpos_seed = torch.zeros((1, 6)) + qpos_sol, info = solver.get_ik(target_xpos=xpos) + print("IK solution:", qpos_sol) + print("Convergence info:", info) + # IK solution: tensor([True]) + # Convergence info: tensor([[0.0, -0.231429, 0.353367, 0.893100, 0.0, 0.555758]]) +``` + +> **Tip:** +> - `get_fk` is for forward kinematics (joint to end-effector), `get_ik` is for inverse kinematics (end-effector to joint). +> - For batch computation, the first dimension of `qpos` and `target_xpos` is the batch size. + +## IK Methods Supported + +* **pinv**: Jacobian pseudo-inverse +* **svd**: Singular value decomposition +* **trans**: Jacobian transpose +* **dls**: Damped least squares + +## References + +* [Isaac Sim Library](https://github.com/isaac-sim/IsaacLab) diff --git a/docs/source/overview/sim/solvers/index.rst b/docs/source/overview/sim/solvers/index.rst new file mode 100644 index 00000000..8ffa5570 --- /dev/null +++ b/docs/source/overview/sim/solvers/index.rst @@ -0,0 +1,96 @@ +Solvers +================================= + +This section documents the solvers provided by the project with a focus on +robotic kinematics: forward kinematics (FK), inverse kinematics (IK), +differential (velocity) kinematics, constraint handling and practical +considerations such as singularities and performance tuning. + +The repository contains several solver implementations — each has a dedicated +page with implementation details and examples. Use the links at the bottom of +this page to jump to a specific solver. + +.. contents:: Table of contents + :local: + :depth: 2 + +Overview +-------- + +Robotic kinematics solvers translate between joint-space and task-space. + +- Forward kinematics (FK) maps joint values q to an end-effector pose. +- Inverse kinematics (IK) finds joint values q that achieve a desired end-effector + pose. + + +Forward kinematics +------------------- + +Forward kinematics composes joint transforms according to the robot's +kinematic tree to produce the end-effector transform. Practical builders compute these transforms efficiently using the robot's +URDF or internal kinematic model. FK solvers in `embodichain` are +optimized for batch evaluation and for returning both pose and link frames. + +Inverse kinematics +------------------- + +Inverse kinematics is the core topic for robotics. There are two common +approaches implemented in the repository: + +- Analytical IK (closed-form): when the robot geometry admits a closed-form + solution (e.g., many 6-DOF industrial arms), these solvers return exact + solutions quickly and deterministically. +- Numerical IK: general-purpose methods based on the Jacobian or optimization + that work for arbitrary kinematic chains but may be slower and require + a good initial guess. + +Analytical IK +~~~~~~~~~~~~~ + +Analytical solvers (see the OPW) exploit kinematic +structure to derive algebraic inverse mappings. Benefits include: + +- very fast runtime +- exact solutions when they exist + +Limitations: + +- only available for specific robot families and joint arrangements + +Numerical IK (Jacobian-based and optimization) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Numerical IK methods iteratively update joint values q to reduce pose error. +Jacobian-based updates use the task Jacobian J(q) to relate changes in joint +space to end-effector motion. + + +Multi-chain and closed-loop kinematics +------------------------------------- + +Solvers can handle serial chains, branched kinematic trees and some closed-loop +mechanisms. Closed-loop systems commonly require constraint solvers and may +embed loop-closure constraints in the solver as equality constraints. + + +Choosing a solver +----------------- + +- Use analytic solvers (OPW for 6-DOF arms or SRS for 7-DOF arms) when available for speed and + determinism. +- Use numerical solvers (PyTorch/optimization, Differential) when you need + flexibility.. + +See also +-------- + +.. toctree:: + :maxdepth: 1 + + pytorch_solver.md + differential_solver.md + pink_solver.md + pinocchio_solver.md + opw_solver.md + srs_solver.md diff --git a/docs/source/overview/sim/solvers/opw_solver.md b/docs/source/overview/sim/solvers/opw_solver.md new file mode 100644 index 00000000..f4dd1bf1 --- /dev/null +++ b/docs/source/overview/sim/solvers/opw_solver.md @@ -0,0 +1,111 @@ +# OPWSolver + +`OPWSolver` is a specialized inverse kinematics (IK) solver for 6-DOF industrial robots using the OPW kinematic parameterization. It provides fast, analytical solutions for robots with parallel and offset axes, supporting both CPU and GPU acceleration. The solver is suitable for large-scale batch IK tasks and real-time control. + +## Key Features + +* Analytical IK for OPW-parameterized 6-DOF manipulators +* Supports both parallel and offset axes, with custom axis flipping +* Fast batch computation for multiple target poses +* Configurable for CPU (py_opw_kinematics) and GPU (warp) backends +* Flexible configuration via `OPWSolverCfg` +* Strict enforcement of joint limits +* Forward kinematics (FK) and multiple IK solution branches + +## Configuration + +The solver is configured using the `OPWSolverCfg` class, which defines OPW parameters and solver options. + +```python +import torch +import numpy as np +from embodichain.data import get_data_path +from embodichain.lab.sim.solvers.opw_solver import OPWSolver, OPWSolverCfg + +cfg = OPWSolverCfg( + urdf_path=get_data_path("CobotMagicArm/CobotMagicNoGripper.urdf"), + joint_names=["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"], + end_link_name="link6", + root_link_name="arm_base", + a1 = 0.0, + a2 = -21.984, + b = 0.0, + c1 = 123.0, + c2 = 285.03, + c3 = 250.75, + c4 = 91.0, + offsets = ( + 0.0, + 82.21350356417211 * np.pi / 180.0, + -167.21710113148163 * np.pi / 180.0, + 0.0, + 0.0, + 0.0, + ), + flip_axes = (False, False, False, False, False, False), + has_parallelogram = False, +) + +solver = OPWSolver(cfg, device="cuda") +``` + +## Main Methods + +* `get_fk(self, qpos: torch.Tensor) -> torch.Tensor` + Computes the end-effector pose (homogeneous transformation matrix) for the given joint positions. + + **Parameters:** + + `qpos` (`torch.Tensor` or `list[float]`): Joint positions, shape `(num_envs, num_joints)` or `(num_joints,)`. + + **Returns:** + + `torch.Tensor`: End-effector pose(s), shape `(num_envs, 4, 4)`. + + **Example:** + +```python + fk = solver.get_fk(qpos=[0.0, 0.0, 0.0, 1.5708, 0.0, 0.0]) + print(fk) + # Output: + # tensor([[[ 0.0, 0.087093, 0.996200, 0.056135], + # [-1.0, 0.0 , -0.0 , -0.0 ], + # [ 0.0, -0.996200, 0.087093, 0.213281], + # [ 0.0, 0.0 , 0.0 , 1.0 ]]], device=solver.device) +``` + +* `get_ik(self, target_xpos: torch.Tensor, qpos_seed: torch.Tensor = None, return_all_solutions: bool = False, jacobian: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]` + Computes joint positions (inverse kinematics) for the given target end-effector pose. + + **Parameters:** + + `target_xpos` (`torch.Tensor`): Target end-effector pose(s), shape `(num_envs, 4, 4)`. + + `qpos_seed` (`torch.Tensor`, optional): Initial guess for joint positions, shape `(num_envs, num_joints)`. If `None`, a default is used. + + `return_all_solutions` (`bool`, optional): If `True`, returns all possible solutions. Default is `False`. + + `jacobian` (`torch.Tensor`, optional): Custom Jacobian. Usually not required. + + **Returns:** + + `Tuple[torch.Tensor, torch.Tensor]`: + - First element: Joint positions, shape `(num_envs, num_joints)`. + - Second element: Convergence info or error for each environment. + + **Example:** + +```python + import torch + xpos = torch.tensor([[[ 0.0, 0.087093, 0.996200, 0.056135], + [-1.0, 0.0 , -0.0 , -0.0 ], + [ 0.0, -0.996200, 0.087093, 0.213281], + [ 0.0, 0.0 , 0.0 , 1.0 ]]], device=solver.device) + + qpos_seed = torch.zeros((1, 6)) + qpos_sol, info = solver.get_ik(target_xpos=xpos) + print("IK solution:", qpos_sol) + print("Convergence info:", info) + # IK solution: tensor([1], device='cuda:0', dtype=torch.int32) + # Convergence info: tensor([[-3.141593, 0.793811, 0.0, 0.0, 2.522188, 1.570792]], device='cuda:0') + +``` + +## References + +* [OPW Kinematics Paper](https://doi.org/10.1109/TRO.2017.2776312) +* [py_opw_kinematics Documentation](https://github.com/UM-ARM-Lab/py_opw_kinematics) +* [warp Documentation](https://github.com/NVIDIA/warp) diff --git a/docs/source/overview/sim/solvers/pink_solver.md b/docs/source/overview/sim/solvers/pink_solver.md new file mode 100644 index 00000000..c7de6d61 --- /dev/null +++ b/docs/source/overview/sim/solvers/pink_solver.md @@ -0,0 +1,103 @@ +# PinkSolver + +`PinkSolver` is an advanced inverse kinematics (IK) solver for robot manipulators, built on [Pinocchio](https://github.com/stack-of-tasks/pinocchio) and [Pink](https://github.com/stephane-caron/pink). It supports flexible task definitions, robust optimization, and null space posture control. + +## Key Features + +- Supports both position-only and full pose (position + orientation) constraints +- Configurable convergence tolerance (`pos_eps`, `rot_eps`), damping, and iteration limits +- Handles joint limits and safety checks during optimization +- Allows variable and fixed task definitions for flexible control (see `FrameTask`, `NullSpacePostureTask`) +- Integrates with Pinocchio robot models and Pink task framework +- Supports multiple solver backends: `osqp`, `clarabel`, `ecos`, `proxqp`, `scs`, `daqp` +- Provides joint mapping between simulation and solver for flexible robot integration +- Null space posture task for redundancy resolution and secondary objectives +- Torch and numpy compatible for seamless integration in simulation pipelines + +## Configuration Example + +```python +from embodichain.data import get_data_path +from embodichain.lab.sim.solvers.pink_solver import PinkSolver +from embodichain.lab.sim.solvers.pink_solver import PinkSolverCfg + +cfg = PinkSolverCfg( + urdf_path=get_data_path("UniversalRobots/UR5/UR5.urdf"), + joint_names=["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"], + end_link_name="ee_link", + root_link_name="base_link", + max_iterations=500, + pos_eps=1e-4, + rot_eps=1e-4, + dt=0.05, + damp=1e-10, + is_only_position_constraint=False, + fail_on_joint_limit_violation=True, + solver_type="osqp", + variable_input_tasks=None, + fixed_input_tasks=None, +) + +solver = PinkSolver(cfg) +``` + + +## Main Methods + +* `get_fk(self, qpos: torch.Tensor) -> torch.Tensor` + Computes the end-effector pose (homogeneous transformation matrix) for the given joint positions. + + **Parameters:** + + `qpos` (`torch.Tensor` or `list[float]`): Joint positions, shape `(num_envs, num_joints)` or `(num_joints,)`. + + **Returns:** + + `torch.Tensor`: End-effector pose(s), shape `(num_envs, 4, 4)`. + + **Example:** + +```python + fk = solver.get_fk(qpos=[0.0, 0.0, 0.0, 1.5708, 0.0, 0.0]) + print(fk) + # Output: + # tensor([[[ 0.0, -1.0, 0.0, -0.722600], + # [ 0.0, 0.0, -1.0, -0.191450], + # [ 1.0, 0.0, 0.0, 0.079159], + # [ 0.0, 0.0, 0.0, 1.0 ]]]) +``` + +* `get_ik(self, target_xpos: torch.Tensor, qpos_seed: torch.Tensor = None, return_all_solutions: bool = False, jacobian: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]` + Computes joint positions (inverse kinematics) for the given target end-effector pose. + + **Parameters:** + + `target_xpos` (`torch.Tensor`): Target end-effector pose(s), shape `(num_envs, 4, 4)`. + + `qpos_seed` (`torch.Tensor`, optional): Initial guess for joint positions, shape `(num_envs, num_joints)`. If `None`, a default is used. + + `return_all_solutions` (`bool`, optional): If `True`, returns all possible solutions. Default is `False`. + + `jacobian` (`torch.Tensor`, optional): Custom Jacobian. Usually not required. + + **Returns:** + + `Tuple[torch.Tensor, torch.Tensor]`: + - First element: Joint positions, shape `(num_envs, num_joints)`. + - Second element: Convergence info or error for each environment. + + **Example:** + +```python + import torch + xpos = torch.tensor([[[ 0.0, -1.0, 0.0, -0.722600], + [ 0.0, 0.0, -1.0, -0.191450], + [ 1.0, 0.0, 0.0, 0.079159], + [ 0.0, 0.0, 0.0, 1.0 ]]]) + qpos_seed = torch.zeros((1, 6)) + qpos_sol, info = solver.get_ik(target_xpos=xpos) + print("IK solution:", qpos_sol) + print("Convergence info:", info) + # IK solution: tensor([True]) + # Convergence info: tensor([[0.0, -0.231429, 0.353367, 0.893100, 0.0, 0.555758]]) +``` + + +## References + +- [Pinocchio Library](https://github.com/stack-of-tasks/pinocchio) +- [Pink Library](https://github.com/stephane-caron/pink) +- [Null Space Posture Task](https://github.com/stephane-caron/pink#null-space-posture-task) diff --git a/docs/source/overview/sim/solvers/pinocchio_solver.md b/docs/source/overview/sim/solvers/pinocchio_solver.md new file mode 100644 index 00000000..fa2f0dec --- /dev/null +++ b/docs/source/overview/sim/solvers/pinocchio_solver.md @@ -0,0 +1,94 @@ +# PinocchioSolver + +The `PinocchioSolver` is a high-precision inverse kinematics (IK) solver for robot manipulators, leveraging [Pinocchio](https://github.com/stack-of-tasks/pinocchio) and [CasADi](https://web.casadi.org/) for symbolic and numerical optimization. It supports both position and orientation constraints, joint limits, and smoothness regularization for robust and realistic IK solutions. + +## Key Features + +* Supports both position-only and full pose constraints +* Configurable convergence tolerance, damping, and iteration limits +* Enforces joint limits during optimization +* Uses CasADi for symbolic cost and constraint definition +* Integrates with Pinocchio robot models for accurate kinematics +* Batch sampling for robust IK seed initialization + +## Configuration Example + +```python +from embodichain.data import get_data_path +from embodichain.lab.sim.solvers.pinocchio_solver import PinocchioSolver +from embodichain.lab.sim.solvers.pinocchio_solver import PinocchioSolverCfg + +cfg = PinocchioSolverCfg( + urdf_path=get_data_path("UniversalRobots/UR5/UR5.urdf"), + joint_names=["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"], + end_link_name="ee_link", + root_link_name="base_link", + max_iterations=1000, + pos_eps=1e-4, + rot_eps=1e-4, + dt=0.05, + damp=1e-6, + num_samples=30, + is_only_position_constraint=False, +) + +solver = PinocchioSolver(cfg) +``` + +## Main Methods + +* `get_fk(self, qpos: torch.Tensor) -> torch.Tensor` + Computes the end-effector pose (homogeneous transformation matrix) for the given joint positions. + + **Parameters:** + + `qpos` (`torch.Tensor` or `list[float]`): Joint positions, shape `(num_envs, num_joints)` or `(num_joints,)`. + + **Returns:** + + `torch.Tensor`: End-effector pose(s), shape `(num_envs, 4, 4)`. + + **Example:** + +```python + fk = solver.get_fk(qpos=[0.0, 0.0, 0.0, 1.5708, 0.0, 0.0]) + print(fk) + # Output: + # tensor([[[ 0.0, -1.0, 0.0, -0.722600], + # [ 0.0, 0.0, -1.0, -0.191450], + # [ 1.0, 0.0, 0.0, 0.079159], + # [ 0.0, 0.0, 0.0, 1.0 ]]]) +``` + +* `get_ik(self, target_xpos: torch.Tensor, qpos_seed: torch.Tensor = None, return_all_solutions: bool = False, jacobian: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]` + Computes joint positions (inverse kinematics) for the given target end-effector pose. + + **Parameters:** + + `target_xpos` (`torch.Tensor`): Target end-effector pose(s), shape `(num_envs, 4, 4)`. + + `qpos_seed` (`torch.Tensor`, optional): Initial guess for joint positions, shape `(num_envs, num_joints)`. If `None`, a default is used. + + `return_all_solutions` (`bool`, optional): If `True`, returns all possible solutions. Default is `False`. + + `jacobian` (`torch.Tensor`, optional): Custom Jacobian. Usually not required. + + **Returns:** + + `Tuple[torch.Tensor, torch.Tensor]`: + - First element: Joint positions, shape `(num_envs, num_joints)`. + - Second element: Convergence info or error for each environment. + + **Example:** + +```python + import torch + xpos = torch.tensor([[[ 0.0, -1.0, 0.0, -0.722600], + [ 0.0, 0.0, -1.0, -0.191450], + [ 1.0, 0.0, 0.0, 0.079159], + [ 0.0, 0.0, 0.0, 1.0 ]]]) + qpos_seed = torch.zeros((1, 6)) + qpos_sol, info = solver.get_ik(target_xpos=xpos) + print("IK solution:", qpos_sol) + print("Convergence info:", info) + # IK solution: tensor([True]) + # Convergence info: tensor([[0.0, -0.231429, 0.353367, 0.893100, 0.0, 0.555758]]) +``` + +## References + +* [Pinocchio Documentation](https://stack-of-tasks.github.io/pinocchio/) +* [CasADi Documentation](https://web.casadi.org/) diff --git a/docs/source/overview/sim/solvers/pytorch_solver.md b/docs/source/overview/sim/solvers/pytorch_solver.md new file mode 100644 index 00000000..7677a3cb --- /dev/null +++ b/docs/source/overview/sim/solvers/pytorch_solver.md @@ -0,0 +1,112 @@ +# PytorchSolver + +`PytorchSolver` is a high-performance inverse kinematics (IK) solver for robot manipulators, leveraging [pytorch_kinematics](https://github.com/UM-ARM-Lab/pytorch_kinematics) for efficient computation and seamless integration with PyTorch workflows. It supports both position and orientation constraints, joint limits, batch sampling, and GPU acceleration, making it suitable for real-time and large-scale applications. + +## Key Features + +* Full support for position-only or full pose (position + orientation) constraints +* Configurable convergence tolerance, damping, and iteration limits +* Enforces joint limits during optimization +* Batch sampling for robust IK seed initialization and solution diversity +* Efficient batched computation for multiple target poses +* PyTorch integration for GPU acceleration and tensor-based workflows +* Flexible configuration via `PytorchSolverCfg` class + +## Configuration + +The solver is configured using the `PytorchSolverCfg` class, which allows detailed control over solver parameters and robot model setup. + +```python +from embodichain.data import get_data_path +from embodichain.lab.sim.solvers.pytorch_solver import PytorchSolver +from embodichain.lab.sim.solvers.pytorch_solver import PytorchSolverCfg +import torch + +cfg = PytorchSolverCfg( + urdf_path=get_data_path("UniversalRobots/UR5/UR5.urdf"), + joint_names=["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"], + end_link_name="ee_link", + root_link_name="base_link", + max_iterations=1000, + pos_eps=1e-4, + rot_eps=1e-4, + dt=0.05, + damp=1e-6, + num_samples=30, + is_only_position_constraint=False, +) + +solver = PytorchSolver(cfg) +``` + +### Dynamic Parameter Adjustment + +Solver parameters can be updated at runtime using `set_iteration_params` : + +```python +solver.set_iteration_params( + pos_eps=1e-5, + rot_eps=1e-5, + max_iterations=500, + num_samples=50, + damp=1e-7, +) +``` + +## Main Methods + +* `get_fk(self, qpos: torch.Tensor) -> torch.Tensor` + Computes the end-effector pose (homogeneous transformation matrix) for the given joint positions. + + **Parameters:** + + `qpos` (`torch.Tensor` or `list[float]`): Joint positions, shape `(num_envs, num_joints)` or `(num_joints,)`. + + **Returns:** + + `torch.Tensor`: End-effector pose(s), shape `(num_envs, 4, 4)`. + + **Example:** + +```python + fk = solver.get_fk(qpos=[0.0, 0.0, 0.0, 1.5708, 0.0, 0.0]) + print(fk) + # Output: + # tensor([[[ 0.0, -1.0, 0.0, -0.722600], + # [ 0.0, 0.0, -1.0, -0.191450], + # [ 1.0, 0.0, 0.0, 0.079159], + # [ 0.0, 0.0, 0.0, 1.0 ]]]) +``` + +* `get_ik(self, target_xpos: torch.Tensor, qpos_seed: torch.Tensor = None, return_all_solutions: bool = False, jacobian: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]` + Computes joint positions (inverse kinematics) for the given target end-effector pose. + + **Parameters:** + + `target_xpos` (`torch.Tensor`): Target end-effector pose(s), shape `(num_envs, 4, 4)`. + + `qpos_seed` (`torch.Tensor`, optional): Initial guess for joint positions, shape `(num_envs, num_joints)`. If `None`, a default is used. + + `return_all_solutions` (`bool`, optional): If `True`, returns all possible solutions. Default is `False`. + + `jacobian` (`torch.Tensor`, optional): Custom Jacobian. Usually not required. + + **Returns:** + + `Tuple[torch.Tensor, torch.Tensor]`: + - First element: Joint positions, shape `(num_envs, num_joints)`. + - Second element: Convergence info or error for each environment. + + **Example:** + +```python + import torch + xpos = torch.tensor([[[ 0.0, -1.0, 0.0, -0.722600], + [ 0.0, 0.0, -1.0, -0.191450], + [ 1.0, 0.0, 0.0, 0.079159], + [ 0.0, 0.0, 0.0, 1.0 ]]]) + qpos_seed = torch.zeros((1, 6)) + qpos_sol, info = solver.get_ik(target_xpos=xpos) + print("IK solution:", qpos_sol) + print("Convergence info:", info) + # IK solution: tensor([True], device='cuda:0') + # Convergence info: tensor([[0.0, -0.244575, 0.373442, 0.853886, 0.0, 0.588007]], device='cuda:0') + +``` + +## References + +* [pytorch_kinematics Documentation](https://github.com/UM-ARM-Lab/pytorch_kinematics) diff --git a/docs/source/overview/sim/solvers/srs_solver.md b/docs/source/overview/sim/solvers/srs_solver.md new file mode 100644 index 00000000..3cabb57e --- /dev/null +++ b/docs/source/overview/sim/solvers/srs_solver.md @@ -0,0 +1,133 @@ +# SRSSolver + +`SRSSolver` is a high-performance inverse kinematics (IK) solver specifically designed for 7-DOF manipulators with a Spherical-Rotational-Spherical (S-R-S) joint structure. This architecture is common in anthropomorphic and redundant industrial arms, providing high dexterity and redundancy for advanced manipulation tasks. SRSSolver supports batch computation, joint limits, GPU acceleration, and seamless integration with PyTorch workflows. + +## What is S-R-S Kinematics? + +The S-R-S (Spherical-Rotational-Spherical) structure refers to a 7-joint manipulator arrangement: + +* **Spherical (S)**: A 3-DOF spherical joint (often realized by three intersecting revolute joints), enabling arbitrary orientation in 3D space. +* **Rotational (R)**: A single revolute joint, typically located at the "elbow, " providing an extra degree of freedom for redundancy. +* **Spherical (S)**: Another 3-DOF spherical joint at the wrist, allowing full orientation control of the end-effector. + +This structure enables: +* **Redundancy**: The arm can reach the same pose with multiple joint configurations, useful for obstacle avoidance and singularity avoidance. +* **High Dexterity**: Suitable for tasks requiring complex manipulation and orientation. +* **Wide Application**: Common in humanoid robots and collaborative arms. + +## Key Features + +* Optimized for S-R-S 7-DOF kinematic chains +* Supports position and orientation constraints, joint limits +* Batch sampling and multi-solution output for robust IK +* GPU acceleration for large-scale or real-time applications +* Flexible configuration for DH parameters, joint limits, link lengths, and more + +## Configuration + +SRSSolver is configured via the `SRSSolverCfg` class, allowing detailed control over kinematic parameters and solver behavior. + +```python +from embodichain.data import get_data_path +from embodichain.lab.sim.robots.dexforce_w1.types import ( + DexforceW1ArmSide, + DexforceW1ArmKind, + DexforceW1Version, +) +from embodichain.lab.sim.robots.dexforce_w1.params import ( + W1ArmKineParams, +) +from embodichain.lab.sim.solvers.srs_solver import SRSSolver, SRSSolverCfg + +arm_params = W1ArmKineParams( + arm_side=DexforceW1ArmSide.RIGHT, + arm_kind=DexforceW1ArmKind.ANTHROPOMORPHIC, + version=DexforceW1Version.V021, +) + +cfg = SRSSolverCfg( + urdf_path=get_data_path("DexforceW1V021/DexforceW1_v02_1.urdf"), + joint_names=[f"{'RIGHT'}_J{i+1}" for i in range(7)], + end_link_name="left_ee", + root_link_name="left_arm_base", + dh_params=arm_params.dh_params, + qpos_limits=arm_params.qpos_limits, + T_e_oe=arm_params.T_e_oe, + T_b_ob=arm_params.T_b_ob, + link_lengths=arm_params.link_lengths, + rotation_directions=arm_params.rotation_directions, +) + +solver = SRSSolver(cfg, num_envs=1, device="cuda") +``` + +## Main Methods + +* `get_fk(self, qpos: torch.Tensor) -> torch.Tensor` + Computes the end-effector pose (homogeneous transformation matrix) for the given joint positions. + + **Parameters:** + + `qpos` (`torch.Tensor` or `list[float]`): Joint positions, shape `(num_envs, num_joints)` or `(num_joints,)`. + + **Returns:** + + `torch.Tensor`: End-effector pose(s), shape `(num_envs, 4, 4)`. + + **Example:** + +```python + fk = solver.get_fk(qpos=[0.0, 0.0, 0.0, 1.5708, 0.0, 0.0, 0.0]) + print(fk) + # Output: + # tensor([[[ 0.0, -1.0, 0.0, 0.0], + # [ 0.0, 0.0, -1.0, -0.33], + # [ 1.0, 0.0, 0.0, 0.3625], + # [ 0.0, 0.0, 0.0, 1.0]]]) +``` + +* `get_ik(self, target_xpos: torch.Tensor, qpos_seed: torch.Tensor = None, return_all_solutions: bool = False, jacobian: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]` + Computes joint positions (inverse kinematics) for the given target end-effector pose. + + **Parameters:** + + `target_xpos` (`torch.Tensor`): Target end-effector pose(s), shape `(num_envs, 4, 4)`. + + `qpos_seed` (`torch.Tensor`, optional): Initial guess for joint positions, shape `(num_envs, num_joints)`. If `None`, a default is used. + + `return_all_solutions` (`bool`, optional): If `True`, returns all possible solutions. Default is `False`. + + `jacobian` (`torch.Tensor`, optional): Custom Jacobian. Usually not required. + + **Returns:** + + `Tuple[torch.Tensor, torch.Tensor]`: + - First element: Joint positions, shape `(num_envs, num_joints)`. + - Second element: Convergence info or error for each environment. + + **Example:** + +```python + import torch + xpos = torch.tensor([[[ 0.0, -1.0, 0.0, -0.0], + [ 0.0, 0.0, -1.0, -0.33], + [ 1.0, 0.0, 0.0, 0.3625], + [ 0.0, 0.0, 0.0, 1.0]]]) + qpos_seed = torch.zeros((1, 7)) + qpos_sol, info = solver.get_ik(target_xpos=xpos) + print("IK solution:", qpos_sol) + print("Convergence info:", info) + # IK solution: tensor([True], device='cuda:0') + # Convergence info: tensor([[[-0.022269, 0.045214, -0.022273, -1.570796, 0.045204, -0.001007, 0.044519]]], device='cuda:0') +``` + +## References + +The following key references provide the theoretical foundation and algorithmic background for this implementation: + +* **Analytical Inverse Kinematic Computation for 7-DOF Redundant Manipulators With Joint Limits and Its Application to Redundancy Resolution** + Masayuki Shimizu, Hiromu Kakuya, Woo-Keun Yoon, Kosei Kitagaki, Kazuhiro Kosuge + *IEEE Transactions on Robotics*, 2008 + [DOI: 10.1109/TRO.2008.2003266](https://doi.org/10.1109/TRO.2008.2003266) + This paper presents an analytical approach for solving the inverse kinematics of 7-DOF redundant manipulators, including joint limit handling and redundancy resolution strategies. + +* **Position-based kinematics for 7-DoF serial manipulators with global configuration control, joint limit, and singularity avoidance** + Carlos Faria, Flora Ferreira, Wolfram Erlhagen, Sérgio Monteiro, Estela Bicho + *Mechanism and Machine Theory*, 2018 + [DOI: 10.1016/j.mechmachtheory.2017.10.025](https://doi.org/10.1016/j.mechmachtheory.2017.10.025) + This work introduces position-based kinematic algorithms for 7-DOF manipulators, focusing on global configuration control, joint limit enforcement, and singularity avoidance. + +These publications provide the mathematical models and solution strategies that underpin the SRSSolver's design and functionality. diff --git a/docs/source/overview/vla/index.rst b/docs/source/overview/vla/index.rst new file mode 100644 index 00000000..dc775db1 --- /dev/null +++ b/docs/source/overview/vla/index.rst @@ -0,0 +1,2 @@ +Vision-Language-Action Models +================== diff --git a/docs/source/quick_start/docs.md b/docs/source/quick_start/docs.md new file mode 100644 index 00000000..c62a3d71 --- /dev/null +++ b/docs/source/quick_start/docs.md @@ -0,0 +1,18 @@ +# Build Documentation + +## 1. Install the documentation dependencies + +```bash +pip install -r docs/requirements.txt +``` + +> If you have issue like `locale.Error: unsupported locale setting`, please enter `export LC_ALL=C.UTF-8; export LANG=C.UTF-8` before build the API. + +## 2. Build the HTML site + +```bash +cd docs +make html +``` + +Then you can preview the documentation in your browser at `docs/build/html/index.html`. diff --git a/docs/source/quick_start/install.md b/docs/source/quick_start/install.md new file mode 100644 index 00000000..8a365939 --- /dev/null +++ b/docs/source/quick_start/install.md @@ -0,0 +1,71 @@ +# Installation + +## System Requirements + +The following minimum system requirements are recommended to run EmbodiChain reliably. These are the tested configurations during development — other Linux distributions and versions may work but are not officially supported. + +- Operating System: Linux (x86_64) + - Recommended distributions: Ubuntu 20.04 LTS or Ubuntu 22.04 LTS + +- NVIDIA GPU and drivers: + - Hardware: NVIDIA GPU with compute capability 7.0 or higher (e.g., RTX 20 series, RTX 30 series, A100, etc.) + - NVIDIA driver: 535 or higher (recommended 570) + - CUDA Toolkit: any of 11.8 — 12.8 (we test primarily with 11.8 and 12.x) + +- Python: + - Supported Python versions: + - Python 3.9 + - Python 3.10 + - Use a virtual environment (venv, virtualenv, or conda) to isolate dependencies + +Notes: + +- Ensure your NVIDIA driver and CUDA toolkit versions are compatible with your chosen PyTorch wheel. +- We recommend installing PyTorch from the official PyTorch instructions for your CUDA version: https://pytorch.org/get-started/locally/ + +--- + +### Recommended: Install with Docker + +We strongly recommend using our pre-configured Docker environment, which contains all necessary dependencies. + +```bash +docker pull dexforce/embodichain:ubuntu22.04-cuda12.8 +``` + +--- + + +### Install EmbodiChain + +> **We strongly recommend using a virtual environment to avoid dependency conflicts.** + +Install `DexSim` manually: +```bash +# If you are using Python 3.10 +pip install http://pyp.open3dv.site:2345/packages/dexsim_engine-0.3.6-cp310-cp310-manylinux_2_31_x86_64.whl +``` + +> We are working on uploading DexSim to PyPI for easier installation. Please stay tuned! + + +Clone the EmbodiChain repository: +```bash +git clone https://github.com/DexForce/EmbodiChain.git +``` + +Install the project in development mode: + +```bash +pip install -e . +``` + + +### Verify Installation +To verify that EmbodiChain is installed correctly, run a simple demo script to create a simulation scene: + +```bash + python scripts/tutorials/sim/create_scene.py +``` +--- + diff --git a/docs/source/resources/roadmap.md b/docs/source/resources/roadmap.md new file mode 100644 index 00000000..89bf504a --- /dev/null +++ b/docs/source/resources/roadmap.md @@ -0,0 +1,25 @@ +# Roadmap + +Currently, EmbodiChain is under active development. Our plan for the feature roadmap is as follows: + +- Simulation: + - Rendering: + - Improve ray-tracing backend performance and fix some konwn issues. + - Add a high performance Hybrid rendering backend for better visual quality and speed trade-off. + - Support a more efficient real-time denoiser. + - Add a new rasterization backend for basic rendering tasks. + - Support 3DGS rendering mode (If we have enough bandwidth). + - Physics: + - Improve soft body simulation stability and add more examples and tasks. + - We are also exploring how to integrate [newton physics](https://github.com/newton-physics/newton) into EmbodiChain as an alternative physics backend. + - Sensors: + - Add contact and force sensors with examples. + - Kinematics Solvers: + - Improve the existing IK solver performance and stability (especially SRSSolver and OPWSolver). + - Motion Generation: + - Add more advanced motion generation methods and examples. + - Useful Tools: + - Add a robot workspace analysis tool for better visualization and sampling of robot accessible workspace. + - We are working on USD support for EmbodiChain to enable better scene creation and asset management. + +- Models and Training Workflows: diff --git a/docs/source/resources/robot/cobotmagic.md b/docs/source/resources/robot/cobotmagic.md new file mode 100644 index 00000000..ac7a959f --- /dev/null +++ b/docs/source/resources/robot/cobotmagic.md @@ -0,0 +1,112 @@ +# CobotMagic + +CobotMagic is a versatile dual-arm collaborative robot developed by AgileX Robotics. It is widely used in simulation, education, industry, and service scenarios. All examples in this document are based on the latest PourWater task environment. + +## Key Features + +- **Dual-arm parallel structure** supporting multiple layouts (standard, face-to-face, custom) +- **Configurable gripper models** (V70/V100) and material types (NORMAL/NEW_UV/NO_MATERIAL) +- **Flexible URDF assembly** and simulation parameter configuration +- **Compatible with SimulationManager**, supporting multi-arena parallel simulation +- **High degree of freedom**: 16 axes (dual arms + grippers, each gripper includes 1 mimic joint) +- **Customizable control groups** for flexible task decomposition and extension + +--- + +## Robot Parameters + +| Parameter | Description | +|-----------------------|------------------------------------------------------------------| +| Number of joints | 16 (dual arms + grippers, each gripper includes a mimic joint) | +| Gripper models | V70 / V100 | +| Layout types | NORMAL (standard) / FACE_TO_FACE / CUSTOM | +| Initial base height | 0.7775m (adjustable) | +| Mobile base support | **Not supported** in the current version (fixed base only) | + +> **Note:** The current version of CobotMagic does **not** support a mobile base. All examples and environments assume a fixed base configuration. + + +--- + +## Quick Initialization Example + +```python +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.robots import CobotMagicCfg + +config = SimulationManagerCfg(headless=False, sim_device="cpu") +sim = SimulationManager(config) +sim.build_multiple_arenas(2) # Supports parallel simulation in multiple arenas +sim.set_manual_update(False) + +robot = sim.add_robot(cfg=CobotMagicCfg().from_dict({})) +``` + +--- + +## Configuration Parameters + +### 1. Main Configuration Items + +- **uid**: Unique identifier for the robot, default is "CobotMagic" +- **urdf_cfg**: URDF configuration, supports multi-component assembly (e.g., dual arms) +- **control_parts**: Control groups for independent control of each arm and gripper +- **solver_cfg**: Inverse kinematics solver configuration, customizable end-effector and base +- **drive_pros**: Joint drive properties (stiffness, damping, max effort, etc.) +- **attrs**: Rigid body physical attributes (mass, friction, damping, etc.) + +### 2. Custom Usage Example + +```python +from embodichain.lab.sim.robots import CobotMagicCfg + +custom_cfg = { + "init_pos": [0.0, 0.0, 1.0], # Initial position + # Add more custom parameters as needed +} +cfg = CobotMagicCfg.from_dict(custom_cfg) +robot = sim.add_robot(cfg=cfg) +``` + +### 3. Control Group Example + +```python +control_parts = { + "left_arm": ["LEFT_JOINT1", ..., "LEFT_JOINT6"], + "left_eef": ["LEFT_JOINT7", "LEFT_JOINT8"], + "right_arm": ["RIGHT_JOINT1", ..., "RIGHT_JOINT6"], + "right_eef": ["RIGHT_JOINT7", "RIGHT_JOINT8"], +} +``` + +--- + +## Common Issues & Notes + +- **URDF Path**: Ensure the corresponding URDF files exist in the data path (e.g., `CobotMagicArm/CobotMagicWithGripperV100.urdf`). +- **Simulation Device**: Supports CPU/GPU simulation. Set `sim_device` according to your hardware. +- **Multi-arena Simulation**: Use `build_multiple_arenas(n)` to quickly create n parallel simulation environments. +- **Gripper Model Switching**: To switch gripper models, modify the URDF path in `urdf_cfg`. +- **Mobile Base**: Not supported in the current version; related parameters will be ignored. + +--- + +## References + +- [AgileX CobotMagic Product Page](https://global.agilex.ai/products/cobot-magic) +- Related URDF file path: `CobotMagicArm/` + - CobotMagicWithGripperV70.urdf + - CobotMagicWithGripperV100.urdf + - CobotMagicNoGripper.urdf +- [embodichain Simulation Platform Documentation](https://github.com/dexforce/embodichain) + +--- + +## References + +- [AgileX CobotMagic Product Page](https://global.agilex.ai/products/cobot-magic) +- Related URDF file paths (located in `CobotMagicArm/`): + - `CobotMagicWithGripperV70.urdf` + - `CobotMagicWithGripperV100.urdf` + - `CobotMagicNoGripper.urdf` +- [embodichain Simulation Platform Documentation](https://github.com/dexforce/embodichain) diff --git a/docs/source/resources/robot/dexforce_w1.md b/docs/source/resources/robot/dexforce_w1.md new file mode 100644 index 00000000..69c4860e --- /dev/null +++ b/docs/source/resources/robot/dexforce_w1.md @@ -0,0 +1,54 @@ +# Dexforce W1 + +Dexforce W1 is a versatile robot developed by DexForce Technology Co., Ltd., supporting both industrial and anthropomorphic arm types. It is suitable for various simulation and real-world application scenarios. + +## Key Features + +- Supports multiple arm types (industrial, anthropomorphic) +- Configurable left/right hand brand and version +- Flexible URDF assembly and simulation configuration +- Compatible with SimulationManager simulation environment + + +## Usage in Simulation Environment + +""" +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.robots.dexforce_w1.types import ( + DexforceW1HandBrand, DexforceW1ArmSide, DexforceW1ArmKind, DexforceW1Version +) +from embodichain.lab.sim.robots.dexforce_w1.utils import build_dexforce_w1_cfg + +config = SimulationManagerCfg(headless=False, sim_device="cpu") +sim = SimulationManager(config) +sim.build_multiple_arenas(4) +sim.set_manual_update(True) + +hand_types = { + DexforceW1ArmSide.LEFT: DexforceW1HandBrand.BRAINCO_HAND, + DexforceW1ArmSide.RIGHT: DexforceW1HandBrand.BRAINCO_HAND, +} +hand_versions = { + DexforceW1ArmSide.LEFT: DexforceW1Version.V021, + DexforceW1ArmSide.RIGHT: DexforceW1Version.V021, +} + +cfg = build_dexforce_w1_cfg( + arm_kind=DexforceW1ArmKind.ANTHROPOMORPHIC, + hand_types=hand_types, + hand_versions=hand_versions, +) + +robot = sim.add_robot(cfg=cfg) +print("DexforceW1 robot added to the simulation.") +``` + +## Type Descriptions + + +| Type | Options / Values | Description | +|-------------------------|-------------------------------------------------------|------------------------------------| +| `DexforceW1ArmKind` | `ANTHROPOMORPHIC`, `INDUSTRIAL` | Arm type | +| `DexforceW1HandBrand` | `BRAINCO_HAND`, `DH_PGC_GRIPPER`, `DH_PGC_GRIPPER_M` | Hand brand | +| `DexforceW1Version` | `V021` | Release version | +| `DexforceW1ArmSide` | `LEFT`, `RIGHT` | Left/right hand identifier | diff --git a/docs/source/resources/robot/index.rst b/docs/source/resources/robot/index.rst new file mode 100644 index 00000000..621c5294 --- /dev/null +++ b/docs/source/resources/robot/index.rst @@ -0,0 +1,9 @@ +Supported Robots +====================== + +.. toctree:: + :maxdepth: 1 + + Dexforce W1 + CobotMagic + \ No newline at end of file diff --git a/docs/source/resources/task/index.rst b/docs/source/resources/task/index.rst new file mode 100644 index 00000000..7a83855d --- /dev/null +++ b/docs/source/resources/task/index.rst @@ -0,0 +1,7 @@ +Supported Tasks +====================== + +.. toctree:: + :maxdepth: 1 + + Pour Water \ No newline at end of file diff --git a/docs/source/resources/task/pour_water.md b/docs/source/resources/task/pour_water.md new file mode 100644 index 00000000..a5be3b73 --- /dev/null +++ b/docs/source/resources/task/pour_water.md @@ -0,0 +1,3 @@ +# Pour Water + +Zhao Runyi is pouring water now, please do not disturb him... \ No newline at end of file diff --git a/docs/source/tutorial/basic_env.rst b/docs/source/tutorial/basic_env.rst new file mode 100644 index 00000000..a0b8fabf --- /dev/null +++ b/docs/source/tutorial/basic_env.rst @@ -0,0 +1,185 @@ + +.. _tutorial_create_basic_env: + +Creating a Basic Environment +============================ + +.. currentmodule:: embodichain.lab.gym + +This tutorial shows you how to create a simple robot learning environment using EmbodiChain's Gym interface. You'll learn how to inherit from the base environment class, set up robots and objects, define actions and observations, and run training scenarios. + +The Code +~~~~~~~~ + +The tutorial corresponds to the ``random_reach.py`` script in the ``scripts/tutorials/gym`` directory. + +.. dropdown:: Code for random_reach.py + :icon: code + + .. literalinclude:: ../../../scripts/tutorials/gym/random_reach.py + :language: python + :linenos: + + +The Code Explained +~~~~~~~~~~~~~~~~~~ + +This tutorial demonstrates how to create a custom RL environment by inheriting from :class:`envs.BaseEnv`. The environment implements a simple reach task where a robot arm tries to reach randomly positioned targets. + +Environment Registration +------------------------- + +First, we register the environment with the Gymnasium registry using the :func:`utils.registration.register_env` decorator: + +.. literalinclude:: ../../../scripts/tutorials/gym/random_reach.py + :language: python + :start-at: @register_env("RandomReach-v1", max_episode_steps=100, override=True) + :end-at: class RandomReachEnv(BaseEnv): + +The decorator parameters define: + +- **Environment ID**: ``"RandomReach-v1"`` - unique identifier for the environment +- **max_episode_steps**: Maximum steps per episode (100 in this case) +- **override**: Whether to override existing environment with same ID + +Environment Initialization +--------------------------- + +The ``__init__`` method configures the simulation environment and calls the parent constructor: + +.. literalinclude:: ../../../scripts/tutorials/gym/random_reach.py + :language: python + :lines: 25-46 + +Key configuration options include: + +- **num_envs**: Number of parallel environments to run +- **headless**: Whether to run without GUI (useful for training) +- **device**: Computation device ("cpu" or "cuda") + +Robot Setup +------------ + +The `_setup_robot` method loads and configures the robot for the environment: + +.. literalinclude:: ../../../scripts/tutorials/gym/random_reach.py + :language: python + :start-at: def _setup_robot(self, **kwargs) -> Robot: + :end-at: return robot + +This method demonstrates: + +1. **URDF Loading**: Using data module to access robot URDF files +2. **Robot Configuration**: Setting initial position and joint configuration +3. **Action Space Definition**: Creating action space based on joint limits + +The action space is automatically derived from the robot's joint limits, ensuring actions stay within valid ranges. + +Scene Preparation +----------------- + +The :meth:`_prepare_scene` method adds additional objects to the simulation environment: + +.. literalinclude:: ../../../scripts/tutorials/gym/random_reach.py + :language: python + :lines: 72-84 + +In this example, we add a kinematic cube that serves as a visual target. The cube is configured with: + +- **No collision**: ``enable_collision=False`` for visualization only +- **Kinematic body**: Can be moved programmatically without physics +- **Custom size**: Small 3cm cube for target visualization +- **initial position**: Initially placed at a fixed location + +State Updates +------------- + +The `_update_sim_state` method is called at each simulation step to update object states: + +.. literalinclude:: ../../../scripts/tutorials/gym/random_reach.py + :language: python + :start-at: def _update_sim_state(self, **kwargs) -> None: + :end-at: self.cube.set_local_pose(pose=pose) + +This method randomizes the cube's position. The pose is updated for all parallel environments simultaneously. + +Note that this method is called after perform action execution and simulation update but before observation collection. For more details, see :meth:`envs.BaseEnv.step`. + +Action Execution +---------------- + +The `_step_action` method applies actions to the robot: + +.. literalinclude:: ../../../scripts/tutorials/gym/random_reach.py + :language: python + :start-at: def _step_action(self, action: EnvAction) -> EnvAction: + :end-at: return action + +In this simple environment, actions directly set joint positions. More complex environments might: + +- Convert actions to joint torques or velocities +- Apply action filtering or scaling +- Implement inverse kinematics for end-effector control + +Observation Extension +--------------------- + +The default observations include the following keys: + +- `robot`: Robot proprioception data (joint positions, velocities, efforts) +- `sensor` (optional): Data from any sensors (e.g., cameras) + +The `_extend_obs` method allows you to add custom observations: + +.. literalinclude:: ../../../scripts/tutorials/gym/random_reach.py + :language: python + :start-at: def _extend_obs(self, obs: EnvObs, **kwargs) -> EnvObs: + :end-at: return obs + +While commented out in this example, you can add custom data like: + +- Object positions and orientations +- Distance calculations +- Custom sensor readings +- Task-specific state information + +The Code Execution +~~~~~~~~~~~~~~~~~~ + +To run the environment: + +.. code-block:: bash + + cd /path/to/embodichain + python scripts/tutorials/gym/random_reach.py + +You can customize the execution with command-line options: + +.. code-block:: bash + + # Run multiple parallel environments + python scripts/tutorials/gym/random_reach.py --num_envs 4 + + # Run with GPU acceleration + python scripts/tutorials/gym/random_reach.py --device cuda + + # Run in headless mode (no GUI) + python scripts/tutorials/gym/random_reach.py --headless + +The script demonstrates: + +1. **Environment Creation**: Using ``gym.make()`` with custom parameters +2. **Episode Loop**: Running multiple episodes with random actions +3. **Performance Monitoring**: Calculating frames per second (FPS) + +Key Features Demonstrated +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This tutorial showcases several important features of EmbodiChain environments: + +1. **Gymnasium Integration**: Full compatibility with the Gymnasium API +2. **Parallel Environments**: Running multiple environments simultaneously for efficient training +3. **Robot Integration**: Easy loading and control of robotic systems +4. **Custom Objects**: Adding and manipulating scene objects +5. **Flexible Actions**: Customizable action spaces and execution methods +6. **Extensible Observations**: Adding task-specific observation data diff --git a/docs/source/tutorial/create_scene.rst b/docs/source/tutorial/create_scene.rst new file mode 100644 index 00000000..4f8dd314 --- /dev/null +++ b/docs/source/tutorial/create_scene.rst @@ -0,0 +1,93 @@ +Creating a simulation scene +========================== + +.. currentmodule:: embodichain.lab.sim + +This tutorial shows how to create a basic simulation scene using SimulationManager. It covers the setup of the simulation context, adding rigid objects, and running the simulation loop. + +The Code +~~~~~~~~ + +The tutorial corresponds to the ``create_scene.py`` script in the ``scripts/tutorials/sim`` directory. + +.. dropdown:: Code for create_scene.py + :icon: code + + .. literalinclude:: ../../../scripts/tutorials/sim/create_scene.py + :language: python + :linenos: + + +The Code Explained +~~~~~~~~~~~~~~~~~~ + +Configuring the simulation +-------------------------- + +The first step is to configure the simulation environment. This is done using the :class:`SimulationManagerCfg` data class, which allows you to specify various parameters like window dimensions, headless mode, physics timestep, simulation device (CPU/GPU), and rendering options like ray tracing. + +Command-line arguments are parsed using ``argparse`` to allow for easy customization of the simulation from the terminal. + +.. literalinclude:: ../../../scripts/tutorials/sim/create_scene.py + :language: python + :start-at: # Parse command line arguments + :end-at: sim.build_multiple_arenas(args.num_envs, space=3.0) + +There are two kinds of physics mode in :class:`SimulationManager`: + +- `manual`: The physics updates only when the user calls the :meth:`SimulationManager.update` function. This mode is used for robot learning tasks where precise control over simulation steps is required. Enabled by setting :meth:`SimulationManager.set_manual_update` to True. +- `auto`: The physics updates in a standalone thread, which enable asynchronous rendering and physics stepping. This mode is suitable for visualizations and demos for digital twins applications. This is the default mode. + +If `num_envs` is greater than 1, :meth:`SimulationManager.build_multiple_arenas` should be used to create multiple simulation arenas. + +Adding objects to the scene +--------------------------- + +With the simulation context created, we can add objects. This tutorial demonstrates adding a dynamic rigid cube to the scene using the :meth:`SimulationManager.add_rigid_object` method. The object's properties, such as its shape, initial position, and physics attributes (mass, friction, restitution), are defined through a configuration object, :class:`cfg.RigidObjectCfg`. + +.. literalinclude:: ../../../scripts/tutorials/sim/create_scene.py + :language: python + :start-at: # Add objects to the scene + :end-at: init_pos=[0.0, 0.0, 1.0], + +Running the simulation +---------------------- + +The simulation is advanced through a loop in the ``run_simulation`` function. Before starting the loop, GPU physics is initialized if a CUDA device is used. + +Inside the loop, :meth:`SimulationManager.update` is called to step the physics simulation forward. The script also includes logic to calculate and print the Frames Per Second (FPS) to monitor performance. The simulation runs until it's manually stopped with ``Ctrl+C``. + +.. literalinclude:: ../../../scripts/tutorials/sim/create_scene.py + :language: python + :start-at: def run_simulation(sim: SimulationManager): + :end-at: last_step = step_count + +Exiting the simulation +---------------------- + +Upon exiting the simulation loop (e.g., by a ``KeyboardInterrupt``), it's important to clean up resources. The :meth:`SimulationManager.destroy` method is called in a ``finally`` block to ensure that the simulation is properly terminated and all allocated resources are released. + +.. literalinclude:: ../../../scripts/tutorials/sim/create_scene.py + :language: python + :start-at: except KeyboardInterrupt: + :end-at: sim.destroy() + + +The Code Execution +~~~~~~~~~~~~~~~~~~ + +To run the script and see the result, execute the following command: + +.. code-block:: bash + + python scripts/tutorials/sim/create_scene.py + +A window should appear showing a cube dropping onto a flat plane. To stop the simulation, you can either close the window or press ``Ctrl+C`` in the terminal. + +You can also pass arguments to customize the simulation. For example, to run in headless mode with `n` parallel environments using specified device: + +.. code-block:: bash + + python scripts/tutorials/sim/create_scene.py --headless --num_envs --device + +Now that we have a basic understanding of how to create a scene, let's move on to more advanced topics. diff --git a/docs/source/tutorial/create_softbody.rst b/docs/source/tutorial/create_softbody.rst new file mode 100644 index 00000000..cfaf85c8 --- /dev/null +++ b/docs/source/tutorial/create_softbody.rst @@ -0,0 +1,68 @@ +Creating a soft-body simulation +=============================== + +.. currentmodule:: embodichain.lab.sim + +This tutorial shows how to create a soft-body simulation using ``SimulationManager``. It covers the setup of the simulation context, adding a deformable mesh (soft object), and running the simulation loop. + +The Code +~~~~~~~~ + +The tutorial corresponds to the ``create_softbody.py`` script in the ``scripts/tutorials/sim`` directory. + +.. dropdown:: Code for create_softbody.py + :icon: code + + .. literalinclude:: ../../../scripts/tutorials/sim/create_softbody.py + :language: python + :linenos: + + +The Code Explained +~~~~~~~~~~~~~~~~~~ + +Configuring the simulation +-------------------------- + +The first step is to configure the simulation environment. This is done using the :class:`SimulationManagerCfg` data class, which allows you to specify parameters like window dimensions, headless mode, physics timestep, simulation device (CPU/GPU), and rendering options like ray tracing. Reminded that soft body simulation can only run on cuda deive. + + +.. literalinclude:: ../../../scripts/tutorials/sim/create_softbody.py + :language: python + :start-at: # Configure the simulation + :end-at: print("[INFO]: Scene setup complete!") + +If ``num_envs`` is greater than 1, :meth:`SimulationManager.build_multiple_arenas` should be used to create multiple simulation arenas. + +Adding a soft body to the scene +------------------------------- + +With the simulation context created, we can add a soft (deformable) object. This tutorial demonstrates adding a soft-body cow mesh to the scene using the :meth:`SimulationManager.add_soft_object` method. The object's geometry and physical parameters are defined through configuration objects: + +- :class:`cfg.MeshCfg` for the mesh shape (``cow.obj``) +- :class:`cfg.SoftbodyVoxelAttributesCfg` for voxelization and simulation mesh resolution +- :class:`cfg.SoftbodyPhysicalAttributesCfg` for material properties (Young's modulus, Poisson's ratio, density, frictions, solver iterations) + +.. literalinclude:: ../../../scripts/tutorials/sim/create_softbody.py + :language: python + :start-at: # add softbody to the scene + :end-at: print("[INFO]: Add soft object complete!") + +The Code Execution +~~~~~~~~~~~~~~~~~~ + +To run the script and see the result, execute the following command: + +.. code-block:: bash + + python scripts/tutorials/sim/create_softbody.py + +A window should appear showing a soft-body cow mesh falling onto a ground plane. To stop the simulation, you can either close the window or press ``Ctrl+C`` in the terminal. + +You can also pass arguments to customize the simulation. For example, to run in headless mode with ``n`` parallel environments using the specified device: + +.. code-block:: bash + + python scripts/tutorials/sim/create_softbody.py --headless --num_envs --device + +Now that we have a basic understanding of how to create a soft-body scene, let's move on to more advanced topics. diff --git a/docs/source/tutorial/gizmo.rst b/docs/source/tutorial/gizmo.rst new file mode 100644 index 00000000..b0d39b2c --- /dev/null +++ b/docs/source/tutorial/gizmo.rst @@ -0,0 +1,270 @@ +.. _tutorial_gizmo_robot: + +Interactive Robot Control with Gizmo +===================================== + +.. currentmodule:: embodichain.lab.sim + +This tutorial demonstrates how to use the Gizmo class for interactive robot manipulation in SimulationManager. You'll learn how to create a gizmo attached to a robot's end-effector and use it for real-time inverse kinematics (IK) control, allowing intuitive manipulation of robot poses through visual interaction. + +The Code +~~~~~~~~ + +The tutorial corresponds to the ``gizmo_robot.py`` script in the ``scripts/tutorials/sim`` directory. + +.. dropdown:: Code for gizmo_robot.py + :icon: code + + .. literalinclude:: ../../../scripts/tutorials/sim/gizmo_robot.py + :language: python + :linenos: + + +The Code Explained +~~~~~~~~~~~~~~~~~~ + + +Similar to the previous tutorial on robot simulation, we use the :class:`SimulationManager` class to set up the simulation environment. If you haven't read that tutorial yet, please refer to :doc:`robot` first. + + + +**Important:** Gizmo only supports single environment mode (`num_envs=1`). Using multiple environments will raise an exception. + +All gizmo creation, visibility, and destruction operations must be managed via the SimulationManager API: + +.. code-block:: python + + # Toggle visibility for a gizmo + sim.toggle_gizmo_visibility("ur10_gizmo_test", control_part="arm") + + # Set visibility explicitly + sim.set_gizmo_visibility("ur10_gizmo_test", visible=False, control_part="arm") + +Always use the SimulationManager API to control gizmo visibility and lifecycle. Do not operate on the Gizmo instance directly. + +What is a Gizmo? +----------------- + +A Gizmo is an interactive visual tool that allows users to manipulate simulation objects in real-time through mouse interactions. In robotics applications, gizmos are particularly useful for: + +- **Interactive Robot Control**: Drag the robot's end-effector to desired positions +- **Inverse Kinematics**: Automatically solve joint angles to reach target poses +- **Real-time Manipulation**: Provide immediate visual feedback during robot motion planning +- **Debugging and Visualization**: Test robot reachability and workspace limits + +The :class:`objects.Gizmo` class provides a unified interface for interactive control of different simulation elements including robots, rigid objects, and cameras. + +Setting up Robot Configuration +------------------------------ + +First, we configure a UR10 robot with an IK solver for end-effector control: + +.. literalinclude:: ../../../scripts/tutorials/sim/gizmo_robot.py + :language: python + :start-at: # Create UR10 robot configuration + :end-at: robot = sim.add_robot(cfg=robot_cfg) + +Key components of the robot configuration: + +- **URDF Configuration**: Loads the robot's kinematic and visual model +- **Control Parts**: Defines which joints can be controlled (``"Joint[1-6]"`` for UR10) +- **IK Solver**: :class:`solvers.PinkSolverCfg` provides inverse kinematics capabilities +- **Drive Properties**: Sets stiffness and damping for joint control + +The IK solver is crucial for gizmo functionality, as it enables the robot to automatically calculate joint angles needed to reach gizmo target positions. + +Creating and Attaching a Gizmo +------------------------------- + + + +After configuring the robot, enable the gizmo for interactive control using the SimulationManager API (supports robot, rigid object, camera; key is `uid:control_part`): + +.. code-block:: python + + # Enable gizmo for the robot's arm + sim.enable_gizmo(uid="ur10_gizmo_test", control_part="arm") + if not sim.has_gizmo("ur10_gizmo_test", control_part="arm"): + logger.log_error("Failed to enable gizmo!") + return + + + +The Gizmo instance is managed internally by SimulationManager. If you need to access it: + +.. code-block:: python + + gizmo = sim.get_gizmo("ur10_gizmo_test", control_part="arm") + + + +The Gizmo system will automatically: + +1. **Detect Target Type**: Identify that the target is a robot (vs. rigid object or camera) +2. **Find End-Effector**: Locate the robot's end-effector link (``ee_link`` for UR10) +3. **Create Proxy Object**: Generate a small invisible cube at the end-effector position +4. **Set Up IK Callback**: Configure the gizmo to trigger IK solving when moved + +How Gizmo-Robot Interaction Works +---------------------------------- + + + +The gizmo-robot interaction follows this efficient workflow: + +1. **Gizmo Callback**: When the user drags the gizmo, a callback function updates the proxy object's transform +2. **Deferred IK Solving**: Instead of solving IK immediately in the callback (which causes UI lag), the target transform is stored +3. **Update Loop**: During each simulation step, ``gizmo.update()`` solves IK and applies joint commands +4. **Robot Motion**: The robot smoothly moves to follow the gizmo position + +This design separates UI responsiveness from computational IK solving, ensuring smooth interaction even with complex robots. + +The Simulation Loop +------------------- + + + +In the main loop, simply call `sim.update_gizmos()`. There is no need to manually update any Gizmo instance. + + + +.. code-block:: python + + def run_simulation(sim: SimulationManager): + step_count = 0 + try: + last_time = time.time() + last_step = 0 + while True: + time.sleep(0.033) # 30Hz + sim.update_gizmos() # Update all gizmos + step_count += 1 + # ...performance statistics, etc... + except KeyboardInterrupt: + logger.log_info("\nStopping simulation...") + finally: + sim.destroy() # Release all resources + logger.log_info("Simulation terminated successfully") + + + +Main loop highlights: + +- **Gizmo update**: Only `sim.update_gizmos()` is needed, no `gizmo.update()` +- **Performance monitoring**: Optional FPS statistics +- **Resource cleanup**: Only `sim.destroy()` is needed, no manual Gizmo destruction +- **Graceful shutdown**: Supports Ctrl+C interruption + +Gizmo Lifecycle Management +-------------------------- + + + + +Gizmo lifecycle is managed by SimulationManager: + +- Enable: `sim.enable_gizmo(...)` +- Update: Main loop automatically calls `sim.update_gizmos()` +- Destroy/disable: `sim.disable_gizmo(...)` or `sim.destroy()` (recommended) + +There is no need to manually create or destroy Gizmo instances. All resources are managed by SimulationManager. + +Available Gizmo Methods +----------------------- + + + + +If you need to access the underlying Gizmo instance (via `sim.get_gizmo`), you can use the following methods: + +**Transform Control:** + +- ``set_world_pose(pose)``: Set gizmo world position and orientation +- ``get_world_pose()``: Get current gizmo world transform +- ``set_local_pose(pose)``: Set gizmo local transform relative to parent +- ``get_local_pose()``: Get gizmo local transform + + + +**Visual properties (strongly recommend using SimulationManager API):** + +- ``sim.toggle_gizmo_visibility(uid, control_part=None)``: Toggle gizmo visibility +- ``sim.set_gizmo_visibility(uid, visible, control_part=None)``: Set gizmo visibility + +**Hierarchy Management:** + +- ``get_parent()``: Get gizmo's parent node in scene hierarchy +- ``get_name()``: Get gizmo node name for debugging +- ``detach()``: Disconnect gizmo from current target +- ``attach(target)``: Attach gizmo to a new simulation object + +Running the Tutorial +-------------------- + +To run the gizmo robot tutorial: + +.. code-block:: bash + + cd scripts/tutorials/sim + python gizmo_robot.py --device cpu + +Command-line options: + +- ``--device cpu|cuda``: Choose simulation device +- ``--num_envs N``: Number of parallel environments +- ``--headless``: Run without GUI for automated testing +- ``--enable_rt``: Enable ray tracing for better visuals + +Once running: + +1. **Mouse Interaction**: Click and drag the gizmo (colorful axes) to move the robot +2. **Real-time IK**: Watch the robot joints automatically adjust to follow the gizmo +3. **Workspace Limits**: Observe how the robot behaves at workspace boundaries +4. **Performance**: Monitor FPS in the console output + +Tips and Best Practices +------------------------ + + + +**Performance optimization:** + +- Only call ``sim.update_gizmos()`` in the main loop, no need for ``gizmo.update()`` +- Reduce IK solver iterations for better real-time performance if needed +- Use ``set_manual_update(False)`` for smoother interaction + + + +**Debugging tips:** + +- Check console output for IK solver success/failure messages +- Use ``get_world_pose()`` to check gizmo position (if needed) +- Monitor FPS to identify performance bottlenecks + + + +**Robot compatibility:** + +- Ensure your robot is configured with a correct IK solver +- Check the end-effector (EE) link name +- Test joint limits and workspace boundaries + + + +**Visualization customization:** + +- Adjust gizmo appearance via Gizmo config (e.g., ``set_line_width()``; requires access to the instance via `sim.get_gizmo`) +- Adjust gizmo scale according to robot size +- Enable collision for debugging if needed + +Next Steps +---------- + +After mastering basic gizmo usage, you can explore: + +- **Multi-robot Gizmos**: Attach gizmos to multiple robots simultaneously +- **Custom Gizmo Callbacks**: Implement application-specific interaction logic +- **Gizmo with Rigid Objects**: Use gizmos for interactive object manipulation +- **Advanced IK Configuration**: Fine-tune solver parameters for specific robots + +For more advanced robot control and simulation features, refer to the complete :doc:`robot` tutorial and the API documentation for :class:`objects.Gizmo` and :class:`solvers.PinkSolverCfg`. diff --git a/docs/source/tutorial/index.rst b/docs/source/tutorial/index.rst new file mode 100644 index 00000000..11eaf5d7 --- /dev/null +++ b/docs/source/tutorial/index.rst @@ -0,0 +1,19 @@ +Tutorials +========= + +.. toctree:: + :maxdepth: 1 + :hidden: + + create_scene + create_softbody + rigid_object_group + robot + solver + sensor + motion_gen + gizmo + basic_env + modular_env + rl + diff --git a/docs/source/tutorial/modular_env.rst b/docs/source/tutorial/modular_env.rst new file mode 100644 index 00000000..53175e97 --- /dev/null +++ b/docs/source/tutorial/modular_env.rst @@ -0,0 +1,237 @@ +.. _tutorial_modular_env: + +Creating a Modular Environment +============================== + +.. currentmodule:: embodichain.lab.gym + +This tutorial demonstrates how to create sophisticated robotic environments using EmbodiChain's modular architecture. You'll learn how to use the advanced :class:`envs.EmbodiedEnv` class with configuration-driven setup, event managers, observation managers, and randomization systems. + +The Code +~~~~~~~~ + +The tutorial corresponds to the ``modular_env.py`` script in the ``scripts/tutorials/gym`` directory. + +.. dropdown:: Code for modular_env.py + :icon: code + + .. literalinclude:: ../../../scripts/tutorials/gym/modular_env.py + :language: python + :linenos: + + +The Code Explained +~~~~~~~~~~~~~~~~~~ + +This tutorial showcases EmbodiChain's most powerful environment creation approach using the :class:`envs.EmbodiedEnv` class. Unlike the basic environment tutorial, this approach uses declarative configuration classes and manager systems for maximum flexibility and reusability. + +Event Configuration +------------------- + +Events define automated behaviors that occur during simulation. There are three types of supported modes: + +- `startup`: triggers once when the environment is initialized +- `reset`: triggers every time the environment is reset +- `interval`: triggers at fixed step intervals during simulation + +The :class:`ExampleEventCfg` demonstrates three types of events: + +.. literalinclude:: ../../../scripts/tutorials/gym/modular_env.py + :language: python + :lines: 36-76 + +**Asset Replacement Event** + +The ``replace_obj`` event demonstrates dynamic asset swapping: + +- **Function**: :func:`envs.managers.events.replace_assets_from_group` +- **Mode**: ``"reset"`` - triggers at environment reset +- **Purpose**: Randomly selects different fork models from a folder + +**Light Randomization Event** + +The ``randomize_light`` event creates dynamic lighting conditions: + +- **Function**: :func:`envs.managers.randomization.rendering.randomize_light` +- **Mode**: ``"interval"`` - triggers every 5 steps +- **Parameters**: Randomizes position, color, and intensity within specified ranges + +**Material Randomization Event** + +The ``randomize_table_mat`` event varies visual appearance: + +- **Function**: :func:`envs.managers.randomization.rendering.randomize_visual_material` +- **Mode**: ``"interval"`` - triggers every 10 steps +- **Features**: Random textures from COCO dataset and base color variations + +for more randomization events, please refer + +Observation Configuration +------------------------- + +The default observation from :class:`envs.EmbodiedEnv` includes: +- `robot`: robot proprioceptive data (joint positions, velocities, efforts) +- `sensor`: all available sensor data (images, depth, segmentation, etc.) + +However, users always need to define some custom observation for specified learning tasks. To handle this, the observation manager system allows users to declaratively specify additional observations. + +.. literalinclude:: ../../../scripts/tutorials/gym/modular_env.py + :language: python + :lines: 79-87 + +This configuration: + +- **Function**: :func:`envs.managers.observations.get_rigid_object_pose` +- **Mode**: ``"add"`` - appends data to observation dictionary +- **Name**: Custom key for the observation data +- **Target**: Tracks the fork object's pose in the scene + +For details documentation, see :class:`envs.managers.cfg.ObservationCfg`. + +Environment Configuration +------------------------- + +The main environment configuration inherits from :class:`envs.EmbodiedEnvCfg` and defines all scene components: + +**Robot Configuration** + +.. currentmodule:: embodichain.lab.sim.robots + +.. literalinclude:: ../../../scripts/tutorials/gym/modular_env.py + :language: python + :start-at: robot: RobotCfg = DexforceW1Cfg.from_dict( + :end-at: ) + +Uses the pre-configured :class:`DexforceW1Cfg` with customizations: + +- **Version**: Specific robot variant (v021) +- **Arm Type**: Anthropomorphic configuration +- **Position**: Initial placement in the scene + +**Sensor Configuration** + +.. literalinclude:: ../../../scripts/tutorials/gym/modular_env.py + :language: python + :lines: 104-118 + +.. currentmodule:: embodichain.lab.sim.sensors + +Configures a stereo camera system using :class:`StereoCameraCfg`: + +- **Resolution**: 960x540 pixels for realistic visual input +- **Features**: Depth sensing and segmentation masks enabled +- **Stereo Setup**: 6cm baseline between left and right cameras +- **Mounting**: Attached to robot's "eyes" frame + +**Lighting Configuration** + +.. literalinclude:: ../../../scripts/tutorials/gym/modular_env.py + :language: python + :lines: 120-130 + +Defines scene illumination with controllable point lights: + +- **Type**: Point light for realistic shadows +- **Properties**: Configurable color, intensity, and position +- **UID**: Named reference for event system manipulation + +**Rigid Objects** + +.. literalinclude:: ../../../scripts/tutorials/gym/modular_env.py + :language: python + :lines: 132-157 + +Multiple objects demonstrate different physics properties: + +*Table Configuration:* + +- **Shape**: Custom PLY mesh with UV mapping +- **Physics**: Kinematic body (movable but not affected by forces) +- **Material**: Friction and restitution properties for realistic contact + +*Fork Configuration:* + +- **Shape**: Detailed mesh from asset library +- **Scale**: Proportionally scaled for scene consistency +- **Physics**: Dynamic body affected by gravity and collisions + +**Articulated Objects** + +.. literalinclude:: ../../../scripts/tutorials/gym/modular_env.py + :language: python + :lines: 159-169 + +Demonstrates complex mechanisms with moving parts: + +- **URDF**: Sliding drawer with joints and constraints +- **Positioning**: Placed on table surface for interaction + +Environment Implementation +-------------------------- + +The actual environment class is remarkably simple due to the configuration-driven approach: + +.. literalinclude:: ../../../scripts/tutorials/gym/modular_env.py + :language: python + :start-at: @register_env("ModularEnv-v1", max_episode_steps=100, override=True) + :end-at: super().__init__(cfg, **kwargs) + +The :class:`envs.EmbodiedEnv` base class automatically: + +- Loads all configured scene components +- Sets up observation and action spaces +- Initializes event and observation managers +- Handles environment lifecycle (reset, step, etc.) + +The Code Execution +~~~~~~~~~~~~~~~~~~ + +To run the modular environment: + +.. code-block:: bash + + cd /path/to/embodichain + python scripts/tutorials/gym/modular_env.py + +The script demonstrates the complete workflow: + +1. **Configuration**: Creates an instance of ``ExampleCfg`` +2. **Registration**: Uses the registered environment ID +3. **Execution**: Runs episodes with zero actions to observe automatic behaviors + + +Manager System Benefits +~~~~~~~~~~~~~~~~~~~~~~~ + +The manager-based architecture provides several key advantages: + +**Event Managers** + +- **Modularity**: Reusable event functions across environments +- **Timing Control**: Flexible scheduling (reset, interval, condition-based) +- **Parameter Binding**: Type-safe configuration with validation +- **Extensibility**: Easy to add custom event behaviors + +**Observation Managers** + +- **Flexible Data**: Any simulation data can become an observation +- **Processing Pipeline**: Built-in normalization and transformation +- **Dynamic Composition**: Runtime observation space modification +- **Performance**: Efficient data collection and GPU acceleration + + +Key Features Demonstrated +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This tutorial showcases the most advanced features of EmbodiChain environments: + +1. **Configuration-Driven Design**: Declarative environment specification +2. **Manager Systems**: Modular event and observation handling +3. **Asset Management**: Dynamic loading and randomization +4. **Sensor Integration**: Realistic camera systems with stereo vision +5. **Physics Simulation**: Complex articulated and rigid body dynamics +6. **Visual Randomization**: Automated domain randomization +7. **Extensible Architecture**: Easy customization and extension points + + +This tutorial demonstrates the full power of EmbodiChain's modular environment system, providing the foundation for creating sophisticated robotic learning scenarios. diff --git a/docs/source/tutorial/motion_gen.rst b/docs/source/tutorial/motion_gen.rst new file mode 100644 index 00000000..712a314a --- /dev/null +++ b/docs/source/tutorial/motion_gen.rst @@ -0,0 +1,134 @@ + +.. _tutorial_motion_generator: + +Motion Generator +================ + +.. currentmodule:: embodichain.lab.sim.planners.motion_generator + +Overview +~~~~~~~~ + +The ``MotionGenerator`` class in EmbodiChain provides a unified and extensible interface for robot trajectory planning. It supports time-optimal trajectory generation (currently via TOPPRA), joint/Cartesian interpolation, and is designed for easy integration with RL, imitation learning, and classical control scenarios. + +Key Features +------------ + +- **Unified API**: One interface for multiple planning strategies (time-optimal, interpolation, etc.) +- **Constraint Support**: Velocity/acceleration constraints configurable per joint +- **Flexible Input**: Supports both joint space and Cartesian space waypoints +- **Extensible**: Easy to add new planners (RRT, PRM, etc.) +- **Integration Ready**: Can be used in RL, imitation learning, or classical pipelines + +Typical Usage +~~~~~~~~~~~~~ + +.. code-block:: python + + from embodichain.lab.sim.planners.motion_generator import MotionGenerator + + # Assume you have a robot instance and uid + motion_gen = MotionGenerator( + robot=robot, + uid="arm", + default_velocity=0.2, + default_acceleration=0.5 + ) + + # Plan a joint-space trajectory + current_state = {"position": [0, 0, 0, 0, 0, 0]} + target_states = [{"position": [0.5, 0.2, 0, 0, 0, 0]}] + success, positions, velocities, accelerations, times, duration = motion_gen.plan( + current_state=current_state, + target_states=target_states + ) + + # Generate a discrete trajectory (joint or Cartesian) + qpos_list, xpos_list = motion_gen.create_discrete_trajectory( + qpos_list=[[0,0,0,0,0,0],[0.5,0.2,0,0,0,0]], + sample_num=20 + ) + +API Reference +~~~~~~~~~~~~~ + +**Initialization** + +.. code-block:: python + + MotionGenerator( + robot: Robot, + uid: str, + sim=None, + planner_type="toppra", + default_velocity=0.2, + default_acceleration=0.5, + collision_margin=0.01, + **kwargs + ) + +- ``robot``: Robot instance, must support get_joint_ids, compute_fk, compute_ik +- ``uid``: Unique robot identifier (e.g., "arm") +- ``planner_type``: Planner type (default: "toppra") +- ``default_velocity``, ``default_acceleration``: Default joint constraints + +**plan** + +.. code-block:: python + + plan( + current_state: Dict, + target_states: List[Dict], + sample_method=TrajectorySampleMethod.TIME, + sample_interval=0.01, + **kwargs + ) -> Tuple[bool, positions, velocities, accelerations, times, duration] + +- Plans a time-optimal trajectory (joint space), returns trajectory arrays and duration. + +**create_discrete_trajectory** + +.. code-block:: python + + create_discrete_trajectory( + xpos_list=None, + qpos_list=None, + is_use_current_qpos=True, + is_linear=False, + sample_method=TrajectorySampleMethod.QUANTITY, + sample_num=20, + qpos_seed=None, + **kwargs + ) -> Tuple[List[np.ndarray], List[np.ndarray]] + +- Generates a discrete trajectory between waypoints (joint or Cartesian), auto-handles FK/IK. + +**estimate_trajectory_sample_count** + +.. code-block:: python + + estimate_trajectory_sample_count( + xpos_list=None, + qpos_list=None, + step_size=0.01, + angle_step=np.pi/90, + **kwargs + ) -> int + +- Estimates the number of samples needed for a trajectory. + +**plan_with_collision** + +.. code-block:: python + + plan_with_collision(...) + +- (Reserved) Plan trajectory with collision checking (not yet implemented). + +Notes & Best Practices +~~~~~~~~~~~~~~~~~~~~~ + +- Only collision-free planning is currently supported; collision checking is a placeholder. +- Input/outputs are numpy arrays or torch tensors; ensure type consistency. +- Robot instance must implement get_joint_ids, compute_fk, compute_ik, get_proprioception, etc. +- For custom planners, extend the PlannerType Enum and _create_planner methods. diff --git a/docs/source/tutorial/rigid_object_group.rst b/docs/source/tutorial/rigid_object_group.rst new file mode 100644 index 00000000..4ed5a3a1 --- /dev/null +++ b/docs/source/tutorial/rigid_object_group.rst @@ -0,0 +1,52 @@ +Rigid object group tutorial +========================== + +.. currentmodule:: embodichain.lab.sim + +This tutorial shows how to create and use a `RigidObjectGroup` in SimulationManager. +It follows the style used in the `create_scene` tutorial and references the +example script located in ``scripts/tutorials/sim/create_rigid_object_group.py``. + +The Code +~~~~~~~~ + +The tutorial corresponds to the ``create_rigid_object_group.py`` script in the +``scripts/tutorials/sim`` directory. + +.. dropdown:: Code for create_rigid_object_group.py + :icon: code + + .. literalinclude:: ../../../scripts/tutorials/sim/create_rigid_object_group.py + :language: python + :linenos: + + +The Code Explained +~~~~~~~~~~~~~~~~~~ + + +Adding a RigidObjectGroup +------------------------- + +The key part of the tutorial demonstrates creating a ``RigidObjectGroup`` via +``sim.add_rigid_object_group``. The group is configured with a mapping of +object UIDs to ``RigidObjectCfg`` entries. Each entry defines a shape +(here ``CubeCfg``), physics attributes, and initial pose. + +.. literalinclude:: ../../../scripts/tutorials/sim/create_rigid_object_group.py + :language: python + :start-at: obj_group: RigidObjectGroup = sim.add_rigid_object_group( + :end-at: print("[INFO]: Scene setup complete!") + + +Running the tutorial +~~~~~~~~~~~~~~~~~~~~ + +To run the script from the repository root: + +.. code-block:: bash + + python scripts/tutorials/sim/create_rigid_object_group.py + +You can pass flags such as ``--headless``, ``--num_envs ``, and +``--device `` to customize the run. diff --git a/docs/source/tutorial/rl.rst b/docs/source/tutorial/rl.rst new file mode 100644 index 00000000..d3c18883 --- /dev/null +++ b/docs/source/tutorial/rl.rst @@ -0,0 +1,365 @@ +.. _tutorial_rl: + +Reinforcement Learning Training +================================ + +.. currentmodule:: embodichain.agents.rl + +This tutorial shows you how to train reinforcement learning agents using EmbodiChain's RL framework. You'll learn how to configure training via JSON, set up environments, policies, and algorithms, and launch training sessions. + +Overview +~~~~~~~~ + +The RL framework provides a modular, extensible stack for robotics tasks: + +- **Trainer**: Orchestrates the training loop (calls algorithm for data collection and updates, handles logging/eval/save) +- **Algorithm**: Controls data collection process (interacts with environment, fills buffer, computes advantages/returns) and updates the policy (e.g., PPO) +- **Policy**: Neural network models implementing a unified interface (get_action/get_value/evaluate_actions) +- **Buffer**: On-policy rollout storage and minibatch iterator (managed by algorithm) +- **Env Factory**: Build environments from a JSON config via registry + +Architecture +~~~~~~~~~~~~ + +The framework follows a clean separation of concerns: + +- **Trainer**: Orchestrates the training loop (calls algorithm for data collection and updates, handles logging/eval/save) +- **Algorithm**: Controls data collection process (interacts with environment, fills buffer, computes advantages/returns) and updates the policy (e.g., PPO) +- **Policy**: Neural network models implementing a unified interface +- **Buffer**: On-policy rollout storage and minibatch iterator (managed by algorithm) +- **Env Factory**: Build environments from a JSON config via registry + +The core components and their relationships: + +- Trainer → Policy, Env, Algorithm (via callbacks for statistics) +- Algorithm → Policy, RolloutBuffer (algorithm manages its own buffer) + +Configuration via JSON +~~~~~~~~~~~~~~~~~~~~~~ + +Training is configured via a JSON file that defines runtime settings, environment, policy, and algorithm parameters. + +Example Configuration +--------------------- + +The configuration file (e.g., ``train_config.json``) is located in ``configs/agents/rl/push_cube``: + +.. dropdown:: Example: train_config.json + :icon: code + + .. literalinclude:: ../../../configs/agents/rl/push_cube/train_config.json + :language: json + :linenos: + +Configuration Sections +--------------------- + +Runtime Settings +^^^^^^^^^^^^^^^^ + +The ``runtime`` section controls experiment setup: + +- **exp_name**: Experiment name (used for output directories) +- **seed**: Random seed for reproducibility +- **cuda**: Whether to use GPU (default: true) +- **headless**: Whether to run simulation in headless mode +- **iterations**: Number of training iterations +- **rollout_steps**: Steps per rollout (e.g., 1024) +- **eval_freq**: Frequency of evaluation (in steps) +- **save_freq**: Frequency of checkpoint saving (in steps) +- **use_wandb**: Whether to enable Weights & Biases logging (set in JSON config) +- **wandb_project_name**: Weights & Biases project name + +Environment Configuration +^^^^^^^^^^^^^^^^^^^^^^^^^ + +The ``env`` section defines the task environment: + +- **id**: Environment registry ID (e.g., "PushCubeRL") +- **cfg**: Environment-specific configuration parameters + +Example: + +.. code-block:: json + + "env": { + "id": "PushCubeRL", + "cfg": { + "num_envs": 4, + "obs_mode": "state", + "episode_length": 100, + "action_scale": 0.1, + "success_threshold": 0.1 + } + } + +Policy Configuration +^^^^^^^^^^^^^^^^^^^ + +The ``policy`` section defines the neural network policy: + +- **name**: Policy name (e.g., "actor_critic", "vla") +- **cfg**: Policy-specific hyperparameters (empty for actor_critic) +- **actor**: Actor network configuration (required for actor_critic) +- **critic**: Critic network configuration (required for actor_critic) + +Example: + +.. code-block:: json + + "policy": { + "name": "actor_critic", + "cfg": {}, + "actor": { + "type": "mlp", + "hidden_sizes": [256, 256], + "activation": "relu" + }, + "critic": { + "type": "mlp", + "hidden_sizes": [256, 256], + "activation": "relu" + } + } + +Algorithm Configuration +^^^^^^^^^^^^^^^^^^^^^^^ + +The ``algorithm`` section defines the RL algorithm: + +- **name**: Algorithm name (e.g., "ppo") +- **cfg**: Algorithm-specific hyperparameters + +Example: + +.. code-block:: json + + "algorithm": { + "name": "ppo", + "cfg": { + "learning_rate": 0.0001, + "n_epochs": 10, + "batch_size": 64, + "gamma": 0.99, + "gae_lambda": 0.95, + "clip_coef": 0.2, + "ent_coef": 0.01, + "vf_coef": 0.5, + "max_grad_norm": 0.5 + } + } + +Training Script +~~~~~~~~~~~~~~~ + +The training script (``train.py``) is located in ``embodichain/agents/rl/``: + +.. dropdown:: Code for train.py + :icon: code + + .. literalinclude:: ../../../embodichain/agents/rl/train.py + :language: python + :linenos: + +The Script Explained +-------------------- + +The training script performs the following steps: + +1. **Parse Configuration**: Loads JSON config and extracts runtime/env/policy/algorithm blocks +2. **Setup**: Initializes device, seeds, output directories, TensorBoard, and Weights & Biases +3. **Build Components**: + - Environment via ``build_env()`` factory + - Policy via ``build_policy()`` registry + - Algorithm via ``build_algo()`` factory +4. **Create Trainer**: Instantiates the ``Trainer`` with all components +5. **Train**: Runs the training loop until completion + +Launching Training +------------------ + +To start training, run: + +.. code-block:: bash + + python embodichain/agents/rl/train.py --config configs/agents/rl/push_cube/train_config.json + +Outputs +------- + +All outputs are written to ``./outputs/_/``: + +- **logs/**: TensorBoard logs +- **checkpoints/**: Model checkpoints + +Training Process +~~~~~~~~~~~~~~~ + +The training process follows this sequence: + +1. **Rollout Phase**: Algorithm collects trajectories by interacting with the environment (via ``collect_rollout``). During this phase, the trainer performs dense per-step logging of rewards and metrics from environment info. +2. **GAE Computation**: Algorithm computes advantages and returns using Generalized Advantage Estimation (internal to algorithm, stored in buffer extras) +3. **Update Phase**: Algorithm updates the policy using collected data (e.g., PPO) +4. **Logging**: Trainer logs training losses and aggregated metrics to TensorBoard and Weights & Biases +5. **Evaluation** (periodic): Trainer evaluates the current policy +6. **Checkpointing** (periodic): Trainer saves model checkpoints + +Policy Interface +~~~~~~~~~~~~~~~~ + +All policies must inherit from the ``Policy`` abstract base class: + +.. code-block:: python + + from abc import ABC, abstractmethod + import torch.nn as nn + + class Policy(nn.Module, ABC): + device: torch.device + + @abstractmethod + def get_action( + self, obs: torch.Tensor, deterministic: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Returns (action, log_prob, value)""" + raise NotImplementedError + + @abstractmethod + def get_value(self, obs: torch.Tensor) -> torch.Tensor: + """Returns value estimate""" + raise NotImplementedError + + @abstractmethod + def evaluate_actions( + self, obs: torch.Tensor, actions: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Returns (log_prob, entropy, value)""" + raise NotImplementedError + +Available Policies +------------------ + +- **ActorCritic**: MLP-based Gaussian policy with learnable log_std. Requires external ``actor`` and ``critic`` modules to be provided (defined in JSON config). +- **VLAPlaceholderPolicy**: Placeholder for Vision-Language-Action policies + +Algorithms +~~~~~~~~~~ + +Available Algorithms +-------------------- + +- **PPO**: Proximal Policy Optimization with GAE + +Adding a New Algorithm +--------------------- + +To add a new algorithm: + +1. Create a new algorithm class in ``embodichain/agents/rl/algo/`` +2. Implement ``initialize_buffer()``, ``collect_rollout()``, and ``update()`` methods +3. Register in ``algo/__init__.py``: + +.. code-block:: python + + from embodichain.agents.rl.algo import BaseAlgorithm, register_algo + from embodichain.agents.rl.buffer import RolloutBuffer + + @register_algo("my_algo") + class MyAlgorithm(BaseAlgorithm): + def __init__(self, cfg, policy): + self.cfg = cfg + self.policy = policy + self.device = torch.device(cfg.device) + self.buffer = None + + def initialize_buffer(self, num_steps, num_envs, obs_dim, action_dim): + """Initialize the algorithm's buffer.""" + self.buffer = RolloutBuffer(num_steps, num_envs, obs_dim, action_dim, self.device) + + def collect_rollout(self, env, policy, obs, num_steps, on_step_callback=None): + """Control data collection process (interact with env, fill buffer, compute advantages/returns).""" + # Collect trajectories + # Compute advantages/returns (e.g., GAE for on-policy algorithms) + # Attach extras to buffer: self.buffer.set_extras({"advantages": adv, "returns": ret}) + # Return empty dict (dense logging handled in trainer) + return {} + + def update(self): + """Update the policy using collected data.""" + # Access extras from buffer: self.buffer._extras.get("advantages") + # Use self.buffer to update policy + return {"loss": 0.0} + +Adding a New Policy +-------------------- + +To add a new policy: + +1. Create a new policy class inheriting from the ``Policy`` abstract base class +2. Register in ``models/__init__.py``: + +.. code-block:: python + + from embodichain.agents.rl.models import register_policy, Policy + + @register_policy("my_policy") + class MyPolicy(Policy): + def __init__(self, obs_space, action_space, device, config): + super().__init__() + self.device = device + # Initialize your networks here + + def get_action(self, obs, deterministic=False): + ... + def get_value(self, obs): + ... + def evaluate_actions(self, obs, actions): + ... + +Adding a New Environment +------------------------ + +To add a new RL environment: + +1. Create an environment class inheriting from ``EmbodiedEnv`` +2. Register it with the Gymnasium registry: + +.. code-block:: python + + from embodichain.lab.gym.utils.registration import register_env + + @register_env("MyTaskRL", max_episode_steps=100, override=True) + class MyTaskEnv(EmbodiedEnv): + cfg: MyTaskEnvCfg + ... + +3. Use the environment ID in your JSON config: + +.. code-block:: json + + "env": { + "id": "MyTaskRL", + "cfg": { + ... + } + } + +Best Practices +~~~~~~~~~~~~~~ + +- **Device Management**: Device is single-sourced from ``runtime.cuda``. All components (trainer/algorithm/policy/env) share the same device. + +- **Action Scaling**: Keep action scaling in the environment, not in the policy. + +- **Observation Format**: Environments should provide consistent observation shape/types (torch.float32) and a single ``done = terminated | truncated``. + +- **Algorithm Interface**: Algorithms must implement ``initialize_buffer()``, ``collect_rollout()``, and ``update()`` methods. The algorithm completely controls data collection and buffer management. + +- **Reward Components**: Organize reward components in ``info["rewards"]`` dictionary and metrics in ``info["metrics"]`` dictionary. The trainer performs dense per-step logging directly from environment info. + +- **Configuration**: Use JSON for all hyperparameters. This makes experiments reproducible and easy to track. + +- **Logging**: Metrics are automatically logged to TensorBoard and Weights & Biases. Check ``outputs//logs/`` for TensorBoard logs. + +- **Checkpoints**: Regular checkpoints are saved to ``outputs//checkpoints/``. Use these to resume training or evaluate policies. + diff --git a/docs/source/tutorial/robot.rst b/docs/source/tutorial/robot.rst new file mode 100644 index 00000000..8312ad27 --- /dev/null +++ b/docs/source/tutorial/robot.rst @@ -0,0 +1,143 @@ +.. _tutorial_simulate_robot: + +Simulating a Robot +================ + +.. currentmodule:: embodichain.lab.sim + +This tutorial shows you how to create and simulate a robot using SimulationManager. You'll learn how to load a robot from URDF files, configure control systems, and run basic robot simulation with joint control. + +The Code +~~~~~~~~ + +The tutorial corresponds to the ``create_robot.py`` script in the ``scripts/tutorials/sim`` directory. + +.. dropdown:: Code for create_robot.py + :icon: code + + .. literalinclude:: ../../../scripts/tutorials/sim/create_robot.py + :language: python + :linenos: + + +The Code Explained +~~~~~~~~~~~~~~~~~~ + +Similar to the previous tutorial on creating a simulation scene, we use the :class:`SimulationManager` class to set up the simulation environment. If you haven't read that tutorial yet, please refer to :doc:`create_scene` first. + +Loading Robot URDF +------------------- + +SimulationManager supports loading robots from URDF (Unified Robot Description Format) files. You can load either a single URDF file or compose multiple URDF components into a complete robot system. + +For a simple two-component robot (arm + hand): + +.. literalinclude:: ../../../scripts/tutorials/sim/create_robot.py + :language: python + :start-at: sr5_urdf_path = get_data_path("Rokae/SR5/SR5.urdf") + :end-at: robot: Robot = sim.add_robot(cfg=cfg) + + +The :class:`cfg.URDFCfg` allows you to compose multiple URDF files with specific transformations, enabling complex robot assemblies. + + +Configuring Control Parts +-------------------------- + +Control parts define how the robot's joints are grouped for control purposes. This is useful for organizing complex robots with multiple subsystems. + +.. literalinclude:: ../../../scripts/tutorials/sim/create_robot.py + :language: python + :start-at: # Define control parts for the robot + :end-at: } + +Joint names in control parts can use regex patterns for flexible matching. For example: + +- ``"JOINT[1-6]"`` matches JOINT1, JOINT2, ..., JOINT6 +- ``"L_.*"`` matches all joints starting with `"L_"` + +Setting Drive Properties +------------------------ + +Drive properties control how the robot's joints behave during simulation, including stiffness, damping, and force limits. + +.. literalinclude:: ../../../scripts/tutorials/sim/create_robot.py + :language: python + :start-at: drive_pros=JointDrivePropertiesCfg( + :end-at: ) + +You can set different stiffness values for different joint groups using regex patterns. More details on drive properties can be found in :class:`cfg.JointDrivePropertiesCfg`. + +For more robot configuration options, refer to :class:`cfg.RobotCfg`. + +Robot Control +------------- + +For the basic control of robot joints, you can set position targets using :meth:`objects.Robot.set_qpos`. The control action should be created as a torch.Tensor with shape (num_envs, num_joints), where `num_joints` is the total number of joints in the robot or the number of joints in a specific control part. + +- If you can control all joints, use: + + .. code-block:: python + + robot.set_qpos(qpos=target_positions) + +- If you want to control a subset of joints, specify the joint IDs: + + .. code-block:: python + + robot.set_qpos(qpos=target_positions, joint_ids=subset_joint_ids) + +Getting Robot State +-------------------- + +You can query the robot's current joint positions and velocities via :meth:`objects.Robot.get_qpos` and :meth:`objects.Robot.get_qvel`. For more robot API details, see :class:`objects.Robot`. + +The Code Execution +~~~~~~~~~~~~~~~~~~ + +To run the robot simulation script: + +.. code-block:: bash + + cd /root/sources/embodichain + python scripts/tutorials/sim/create_robot.py + +You can customize the simulation with various command-line options: + +.. code-block:: bash + + # Run with GPU physics + python scripts/tutorials/sim/create_robot.py --device cuda + + # Run multiple environments + python scripts/tutorials/sim/create_robot.py --num_envs 4 + + # Run in headless mode + python scripts/tutorials/sim/create_robot.py --headless + + # Enable ray tracing rendering + python scripts/tutorials/sim/create_robot.py --enable_rt + +The simulation will show the robot moving through different poses, demonstrating basic joint control capabilities. + +Key Features Demonstrated +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This tutorial demonstrates several key features of robot simulation in SimulationManager: + +1. **URDF Loading**: Both single-file and multi-component robot loading +2. **Control Parts**: Organizing joints into logical control groups +3. **Drive Properties**: Configuring joint stiffness and control behavior +4. **Joint Control**: Setting position targets and reading joint states +5. **Multi-Environment**: Running multiple robot instances in parallel + +Next Steps +~~~~~~~~~~ + +After mastering basic robot simulation, you can explore: + +- End-effector control and inverse kinematics +- Sensor integration (cameras, force sensors) +- Robot-object interaction scenarios + +This tutorial provides the foundation for creating sophisticated robotic simulation scenarios with SimulationManager. \ No newline at end of file diff --git a/docs/source/tutorial/sensor.rst b/docs/source/tutorial/sensor.rst new file mode 100644 index 00000000..1d5c4dc9 --- /dev/null +++ b/docs/source/tutorial/sensor.rst @@ -0,0 +1,118 @@ +.. _tutorial_simulate_sensor: + +Simulating a Camera Sensor +========================= + +.. currentmodule:: embodichain.lab.sim + +This tutorial demonstrates how to create and simulate a camera sensor attached to a robot using SimulationManager. You will learn how to configure a camera, attach it to the robot's end-effector, and visualize the sensor's output during simulation. + +Source Code +~~~~~~~~~~~ + +The code for this tutorial is in ``scripts/tutorials/sim/create_sensor.py``. + +.. dropdown:: Show code for create_sensor.py + :icon: code + + .. literalinclude:: ../../../scripts/tutorials/sim/create_sensor.py + :language: python + :linenos: + +Overview +~~~~~~~~ + +This tutorial builds on the basic robot simulation example. If you are not familiar with robot simulation in SimulationManager, please read the :doc:`robot` tutorial first. + +1. **Sensor Creation and Attachment** +------------------------------------- + +The camera sensor is created using :class:`CameraCfg` and can be attached to the robot's end-effector or placed freely in the scene. The attachment is controlled by the ``--attach_sensor`` argument. + +.. literalinclude:: ../../../scripts/tutorials/sim/create_sensor.py + :language: python + :start-at: def create_sensor + :end-at: return camera + +- The camera's intrinsics (focal lengths and principal point) and resolution are set. +- The ``extrinsics`` specify the camera's pose relative to its parent (e.g., the robot's ``ee_link`` or the world). +- The camera is added to the simulation with :meth:`sim.add_sensor`. + +2. **Visualizing Sensor Output** +-------------------------------- + +The function ``get_sensor_image`` retrieves and visualizes the camera's color, depth, mask, and normal images. In GUI mode, images are shown in a 2x2 grid using OpenCV. In headless mode, images are saved to disk. + +.. literalinclude:: ../../../scripts/tutorials/sim/create_sensor.py + :language: python + :start-at: def get_sensor_image + :end-at: plt.close(fig) + +- The camera is updated to capture the latest data. +- Four types of images are visualized: color, depth, mask, and normals. +- Images are displayed in a window or saved as PNG files depending on the mode. + +3. **Simulation Loop** +---------------------- + +The simulation loop moves the robot through different arm poses and periodically updates and visualizes the sensor output. + +.. literalinclude:: ../../../scripts/tutorials/sim/create_sensor.py + :language: python + :start-at: def run_simulation + :end-at: sim.destroy() + +- The robot alternates between two arm positions. +- After each movement, the sensor image is refreshed and visualized. + +Running the Example +~~~~~~~~~~~~~~~~~~~ + +To run the sensor simulation script: + +.. code-block:: bash + + cd /home/dex/projects/yuanhaonan/embodichain + python scripts/tutorials/sim/create_sensor.py + +You can customize the simulation with the following command-line options: + +.. code-block:: bash + + # Use GPU physics + python scripts/tutorials/sim/create_sensor.py --device cuda + + # Simulate multiple environments + python scripts/tutorials/sim/create_sensor.py --num_envs 4 + + # Run in headless mode (no GUI, images saved to disk) + python scripts/tutorials/sim/create_sensor.py --headless + + # Enable ray tracing rendering + python scripts/tutorials/sim/create_sensor.py --enable_rt + + # Attach the camera to the robot end-effector + python scripts/tutorials/sim/create_sensor.py --attach_sensor + +Key Features Demonstrated +~~~~~~~~~~~~~~~~~~~~~~~~ + +This tutorial demonstrates: + +1. **Camera sensor creation** using :class:`CameraCfg` +2. **Sensor attachment** to a robot link or placement in the scene +3. **Camera configuration** (intrinsics, extrinsics, clipping planes) +4. **Real-time visualization** of color, depth, mask, and normal images +5. **Robot-sensor integration** in a simulation loop + +Next Steps +~~~~~~~~~~ + +After completing this tutorial, you can explore: + +- Using other sensor types (e.g., stereo cameras, force sensors) +- Recording sensor data for offline analysis +- Integrating sensor feedback into robot control or learning algorithms + +This tutorial provides a foundation for integrating perception into robotic simulation scenarios with SimulationManager. +This tutorial provides the foundation for integrating perception into robotic simulation scenarios with SimulationManager. \ No newline at end of file diff --git a/docs/source/tutorial/solver.rst b/docs/source/tutorial/solver.rst new file mode 100644 index 00000000..dd190f01 --- /dev/null +++ b/docs/source/tutorial/solver.rst @@ -0,0 +1,114 @@ + +.. _tutorial_solver: + +Create a solver +=============== + +.. currentmodule:: embodichain.lab.sim.solvers + +Overview +~~~~~~~~ + +The ``solver`` module in EmbodiChain provides a unified and extensible interface for robot kinematics computation, including forward kinematics (FK), inverse kinematics (IK), and Jacobian calculation. It supports multiple solver backends (e.g., Pinocchio, OPW, SRS, PINK, PyTorch) and is designed for both simulation and real-robot applications. + +Key Features +------------ +- **Unified API**: Abstract base class (`BaseSolver`) defines a common interface for all solvers. +- **Multiple Backends**: Supports Pinocchio, OPW, SRS, PINK, PyTorch, and differential solvers. +- **Flexible Configuration**: Easily switch solver type and parameters via configuration. +- **Batch and Single Query**: Supports both batch and single FK/IK/Jacobian queries. +- **Extensible**: New solvers can be added by subclassing `BaseSolver` and implementing required methods. + +Example: Using PinkSolver +~~~~~~~~~~~~~~~~~~~~~~~~~ + + +.. code-block:: python + + from embodichain.lab.sim.solvers import PinkSolverCfg + from embodichain.lab.sim.objects.robot import Robot + + # 1. Configure PinkSolver + pink_cfg = PinkSolverCfg( + urdf_path="/path/to/robot.urdf", + joint_names=[ + "shoulder_pan_joint", "shoulder_lift_joint", "elbow_joint", + "wrist_1_joint", "wrist_2_joint", "wrist_3_joint" + ], + end_link_name="ee_link", + root_link_name="base_link" + ) + # 2. Assign solver config to robot config + robot_cfg.solver_cfg = pink_cfg + # 3. Instantiate robot (solver will be initialized automatically) + robot = Robot(cfg=robot_cfg, entities=[], device="cpu") + + # 4. Use FK/IK/Jacobian + qpos = [0.0, -1.57, 1.57, 0.0, 1.57, 0.0] # 6-DOF joint angles (radians) + ee_pose = robot.compute_fk(qpos) # Forward kinematics, returns 4x4 matrix + print("End-effector pose (FK):\n", ee_pose) + + import numpy as np + target_pose = np.array([ + [0, -1, 0, 0.5], + [1, 0, 0, 0.2], + [0, 0, 1, 0.3], + [0, 0, 0, 1.0] + ]) + success, qpos_sol = robot.compute_ik(target_pose, joint_seed=qpos) + print("IK success:", success) + print("IK solution:", qpos_sol) + + J = robot.get_solver().get_jacobian(qpos) + print("Jacobian:\n", J) + +**Note** + +- robot.compute_fk(qpos) internally calls the bound solver's get_fk method. +- robot.compute_ik(target_pose, joint_seed) internally calls the solver's get_ik method. + +API Reference +~~~~~~~~~~~~~ + +**BaseSolver** + +.. code-block:: python + + class BaseSolver: + def get_fk(self, qpos, **kwargs) -> torch.Tensor: + """Compute forward kinematics for the end-effector.""" + + def get_ik(self, target_pose, joint_seed=None, num_samples=None, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute inverse kinematics for a given pose.""" + + def get_jacobian(self, qpos, locations=None, jac_type="full") -> torch.Tensor: + """Compute the Jacobian matrix for the given joint positions.""" + +- **set_ik_nearst_weight**: Set weights for IK nearest neighbor search. +- **set_position_limits / get_position_limits**: Set or get joint position limits. +- **set_tcp / get_tcp**: Set or get the tool center point (TCP) transformation. + +**PinkSolver** + +- Implements all BaseSolver methods using the Pink library. +- Supports custom task lists, solver type selection, and joint limit handling. +- See PinkSolverCfg for all configuration options. + +Configuration +~~~~~~~~~~~~~ + +- All solvers are configured via a `SolverCfg` or its subclass (e.g., `PinkSolverCfg`). +- Key config fields: `urdf_path`, `joint_names`, `end_link_name`, `root_link_name`, `tcp`, and solver-specific parameters. +- Use `cfg.init_solver()` to instantiate the solver, or assign to `robot_cfg.solver_cfg` for automatic integration. + +Notes & Best Practices +~~~~~~~~~~~~~~~~~~~~~ +- Always ensure URDF and joint/link names match your robot model. +- For IK, providing a good `qpos_seed` improves convergence and solution quality. +- Use `set_iteration_params` (if available) to tune solver performance for your application. +- For custom robots or new algorithms, subclass `BaseSolver` and register your solver. + +See Also +~~~~~~~~ +- :ref:`tutorial_motion_generator` — Motion Generator +- :ref:`tutorial_basic_env` — Basic Environment Setup diff --git a/docs/sync_readme.py b/docs/sync_readme.py new file mode 100644 index 00000000..a3198b6e --- /dev/null +++ b/docs/sync_readme.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 +"""Sync project README.md into docs/source/introduction.md. + +Idempotent copy. Exit code 0 on success. +""" +import shutil +from pathlib import Path +import sys + + +def main() -> int: + repo_root = Path(__file__).resolve().parents[1] + readme = repo_root / "README.md" + dest = repo_root / "docs" / "source" / "introduction.md" + + if not readme.exists(): + print(f"ERROR: README not found at {readme}") + return 2 + + # Ensure destination directory exists + dest.parent.mkdir(parents=True, exist_ok=True) + + # Copy file + shutil.copyfile(readme, dest) + print(f"Copied {readme} -> {dest}") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/embodichain/__init__.py b/embodichain/__init__.py new file mode 100644 index 00000000..1b3998b6 --- /dev/null +++ b/embodichain/__init__.py @@ -0,0 +1,32 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os + +embodichain_dir = os.path.dirname(__file__) + +# Read version from VERSION file +def _get_version(): + version_file = os.path.join(os.path.dirname(embodichain_dir), "VERSION") + try: + with open(version_file, "r") as f: + return f.read().strip() + except FileNotFoundError: + print("VERSION file not found.") + return "unknown" + + +__version__ = _get_version() diff --git a/embodichain/agents/rl/algo/__init__.py b/embodichain/agents/rl/algo/__init__.py new file mode 100644 index 00000000..6aca3d5d --- /dev/null +++ b/embodichain/agents/rl/algo/__init__.py @@ -0,0 +1,52 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Dict, Tuple, Type, Any +import torch + +from .base import BaseAlgorithm +from .ppo import PPOCfg, PPO + +# name -> (CfgClass, AlgoClass) +_ALGO_REGISTRY: Dict[str, Tuple[Type[Any], Type[Any]]] = { + "ppo": (PPOCfg, PPO), +} + + +def get_registered_algo_names() -> list[str]: + return list(_ALGO_REGISTRY.keys()) + + +def build_algo(name: str, cfg_kwargs: Dict[str, float], policy, device: torch.device): + key = name.lower() + if key not in _ALGO_REGISTRY: + raise ValueError( + f"Algorithm '{name}' not found. Available: {get_registered_algo_names()}" + ) + CfgCls, AlgoCls = _ALGO_REGISTRY[key] + cfg = CfgCls(device=str(device), **cfg_kwargs) + return AlgoCls(cfg, policy) + + +__all__ = [ + "BaseAlgorithm", + "PPOCfg", + "PPO", + "get_registered_algo_names", + "build_algo", +] diff --git a/embodichain/agents/rl/algo/base.py b/embodichain/agents/rl/algo/base.py new file mode 100644 index 00000000..1cb23309 --- /dev/null +++ b/embodichain/agents/rl/algo/base.py @@ -0,0 +1,52 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Dict, Any, Optional, Callable +import torch + + +class BaseAlgorithm: + """Base class for RL algorithms. + + Algorithms must implement buffer initialization, rollout collection, and + policy update. Trainer depends only on this interface to remain + algorithm-agnostic. + """ + + device: torch.device + + def initialize_buffer( + self, num_steps: int, num_envs: int, obs_dim: int, action_dim: int + ) -> None: + """Initialize internal buffer(s) required by the algorithm.""" + raise NotImplementedError + + def collect_rollout( + self, + env, + policy, + obs: torch.Tensor, + num_steps: int, + on_step_callback: Optional[Callable] = None, + ) -> Dict[str, Any]: + """Collect trajectories and return logging info (e.g., reward components).""" + raise NotImplementedError + + def update(self) -> Dict[str, float]: + """Update policy using collected data and return training losses.""" + raise NotImplementedError diff --git a/embodichain/agents/rl/algo/ppo.py b/embodichain/agents/rl/algo/ppo.py new file mode 100644 index 00000000..2ecc195e --- /dev/null +++ b/embodichain/agents/rl/algo/ppo.py @@ -0,0 +1,184 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch +from typing import Dict, Any, Tuple, Callable, Optional + +from embodichain.agents.rl.utils import AlgorithmCfg +from embodichain.agents.rl.buffer import RolloutBuffer +from embodichain.utils import configclass +from .base import BaseAlgorithm + + +@configclass +class PPOCfg(AlgorithmCfg): + """Configuration for the PPO algorithm.""" + + n_epochs: int = 10 + clip_coef: float = 0.2 + ent_coef: float = 0.01 + vf_coef: float = 0.5 + + +class PPO(BaseAlgorithm): + """PPO algorithm operating via Policy and RolloutBuffer (algo-agnostic design).""" + + def __init__(self, cfg: PPOCfg, policy): + self.cfg = cfg + self.policy = policy + self.device = torch.device(cfg.device) + self.optimizer = torch.optim.Adam(policy.parameters(), lr=cfg.learning_rate) + self.buffer: Optional[RolloutBuffer] = None + # no per-rollout aggregation for dense logging + + def _compute_gae( + self, rewards: torch.Tensor, values: torch.Tensor, dones: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Internal method to compute GAE. Only called by collect_rollout.""" + T, N = rewards.shape + advantages = torch.zeros_like(rewards, device=self.device) + last_adv = torch.zeros(N, device=self.device) + for t in reversed(range(T)): + next_value = values[t + 1] if t < T - 1 else torch.zeros_like(values[0]) + not_done = (~dones[t]).float() + delta = rewards[t] + self.cfg.gamma * next_value * not_done - values[t] + last_adv = ( + delta + self.cfg.gamma * self.cfg.gae_lambda * not_done * last_adv + ) + advantages[t] = last_adv + returns = advantages + values + return advantages, returns + + def initialize_buffer( + self, num_steps: int, num_envs: int, obs_dim: int, action_dim: int + ): + """Initialize the rollout buffer. Called by trainer before first rollout.""" + self.buffer = RolloutBuffer( + num_steps, num_envs, obs_dim, action_dim, self.device + ) + + def collect_rollout( + self, + env, + policy, + obs: torch.Tensor, + num_steps: int, + on_step_callback: Optional[Callable] = None, + ) -> Dict[str, Any]: + """Collect a rollout. Algorithm controls the data collection process.""" + if self.buffer is None: + raise RuntimeError( + "Buffer not initialized. Call initialize_buffer() first." + ) + + policy.train() + self.buffer.step = 0 + current_obs = obs + + for t in range(num_steps): + # Get action from policy + actions, log_prob, value = policy.get_action( + current_obs, deterministic=False + ) + + # Step environment + result = env.step(actions) + next_obs, reward, terminated, truncated, env_info = result + done = terminated | truncated + # Light dtype normalization + reward = reward.float() + done = done.bool() + + # Add to buffer + self.buffer.add(current_obs, actions, reward, done, value, log_prob) + + # Dense logging is handled in Trainer.on_step via info; no aggregation here + # Call callback for statistics and logging + if on_step_callback is not None: + on_step_callback(current_obs, actions, reward, done, env_info, next_obs) + + current_obs = next_obs + + # Compute advantages/returns and attach to buffer extras + adv, ret = self._compute_gae( + self.buffer.rewards, self.buffer.values, self.buffer.dones + ) + self.buffer.set_extras({"advantages": adv, "returns": ret}) + + # No aggregated logging results; Trainer performs dense per-step logging + return {} + + def update(self) -> dict: + """Update the policy using the collected rollout buffer.""" + if self.buffer is None: + raise RuntimeError("Buffer not initialized. Call collect_rollout() first.") + + # Normalize advantages (optional, common default) + adv = self.buffer._extras.get("advantages") + adv = (adv - adv.mean()) / (adv.std() + 1e-8) + + total_actor_loss = 0.0 + total_value_loss = 0.0 + total_entropy = 0.0 + total_steps = 0 + + for _ in range(self.cfg.n_epochs): + for batch in self.buffer.iterate_minibatches(self.cfg.batch_size): + obs = batch["obs"] + actions = batch["actions"] + old_logprobs = batch["logprobs"] + returns = batch["returns"] + advantages = ( + (batch["advantages"] - adv.mean()) / (adv.std() + 1e-8) + ).detach() + + logprobs, entropy, values = self.policy.evaluate_actions(obs, actions) + ratio = (logprobs - old_logprobs).exp() + surr1 = ratio * advantages + surr2 = ( + torch.clamp( + ratio, 1.0 - self.cfg.clip_coef, 1.0 + self.cfg.clip_coef + ) + * advantages + ) + actor_loss = -torch.min(surr1, surr2).mean() + value_loss = torch.nn.functional.mse_loss(values, returns) + entropy_loss = -entropy.mean() + + loss = ( + actor_loss + + self.cfg.vf_coef * value_loss + + self.cfg.ent_coef * entropy_loss + ) + + self.optimizer.zero_grad(set_to_none=True) + loss.backward() + torch.nn.utils.clip_grad_norm_( + self.policy.parameters(), self.cfg.max_grad_norm + ) + self.optimizer.step() + + bs = obs.shape[0] + total_actor_loss += actor_loss.item() * bs + total_value_loss += value_loss.item() * bs + total_entropy += (-entropy_loss.item()) * bs + total_steps += bs + + return { + "actor_loss": total_actor_loss / max(1, total_steps), + "value_loss": total_value_loss / max(1, total_steps), + "entropy": total_entropy / max(1, total_steps), + } diff --git a/embodichain/agents/rl/buffer/__init__.py b/embodichain/agents/rl/buffer/__init__.py new file mode 100644 index 00000000..8e6f6392 --- /dev/null +++ b/embodichain/agents/rl/buffer/__init__.py @@ -0,0 +1,19 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from .rollout_buffer import RolloutBuffer + +__all__ = ["RolloutBuffer"] diff --git a/embodichain/agents/rl/buffer/rollout_buffer.py b/embodichain/agents/rl/buffer/rollout_buffer.py new file mode 100644 index 00000000..d99a8966 --- /dev/null +++ b/embodichain/agents/rl/buffer/rollout_buffer.py @@ -0,0 +1,106 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Dict, Iterator + +import torch + + +class RolloutBuffer: + """On-device rollout buffer for on-policy algorithms. + + Stores (obs, actions, rewards, dones, values, logprobs) over time. + After finalize(), exposes advantages/returns and minibatch iteration. + """ + + def __init__( + self, + num_steps: int, + num_envs: int, + obs_dim: int, + action_dim: int, + device: torch.device, + ): + self.num_steps = num_steps + self.num_envs = num_envs + self.obs_dim = obs_dim + self.action_dim = action_dim + self.device = device + + T, N = num_steps, num_envs + self.obs = torch.zeros(T, N, obs_dim, dtype=torch.float32, device=device) + self.actions = torch.zeros(T, N, action_dim, dtype=torch.float32, device=device) + self.rewards = torch.zeros(T, N, dtype=torch.float32, device=device) + self.dones = torch.zeros(T, N, dtype=torch.bool, device=device) + self.values = torch.zeros(T, N, dtype=torch.float32, device=device) + self.logprobs = torch.zeros(T, N, dtype=torch.float32, device=device) + + self.step = 0 + # Container for algorithm-specific extra fields (e.g., advantages, returns) + self._extras: dict[str, torch.Tensor] = {} + + def add( + self, + obs: torch.Tensor, + action: torch.Tensor, + reward: torch.Tensor, + done: torch.Tensor, + value: torch.Tensor, + logprob: torch.Tensor, + ) -> None: + t = self.step + self.obs[t].copy_(obs) + self.actions[t].copy_(action) + self.rewards[t].copy_(reward) + self.dones[t].copy_(done) + self.values[t].copy_(value) + self.logprobs[t].copy_(logprob) + self.step += 1 + + def set_extras(self, extras: dict[str, torch.Tensor]) -> None: + """Attach algorithm-specific tensors (shape [T, N, ...]) for batching. + + Examples: + {"advantages": adv, "returns": ret} + """ + self._extras = extras or {} + + def iterate_minibatches(self, batch_size: int) -> Iterator[Dict[str, torch.Tensor]]: + T, N = self.num_steps, self.num_envs + total = T * N + indices = torch.randperm(total, device=self.device) + for start in range(0, total, batch_size): + idx = indices[start : start + batch_size] + t_idx = idx // N + n_idx = idx % N + batch = { + "obs": self.obs[t_idx, n_idx], + "actions": self.actions[t_idx, n_idx], + "rewards": self.rewards[t_idx, n_idx], + "dones": self.dones[t_idx, n_idx], + "values": self.values[t_idx, n_idx], + "logprobs": self.logprobs[t_idx, n_idx], + } + # Slice extras if present and shape aligned to [T, N, ...] + for name, tensor in self._extras.items(): + try: + batch[name] = tensor[t_idx, n_idx] + except Exception: + # Skip misaligned extras silently + continue + yield batch diff --git a/embodichain/agents/rl/models/__init__.py b/embodichain/agents/rl/models/__init__.py new file mode 100644 index 00000000..669e2b33 --- /dev/null +++ b/embodichain/agents/rl/models/__init__.py @@ -0,0 +1,101 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Dict, Type +import torch +from gymnasium import spaces + +from .actor_critic import ActorCritic +from .policy import Policy +from .mlp import MLP + +# In-module policy registry +_POLICY_REGISTRY: Dict[str, Type[Policy]] = {} + + +def register_policy(name: str, policy_cls: Type[Policy]) -> None: + if name in _POLICY_REGISTRY: + raise ValueError(f"Policy '{name}' is already registered") + _POLICY_REGISTRY[name] = policy_cls + + +def get_registered_policy_names() -> list[str]: + return list(_POLICY_REGISTRY.keys()) + + +def get_policy_class(name: str) -> Type[Policy] | None: + return _POLICY_REGISTRY.get(name) + + +def build_policy( + policy_block: dict, + obs_space: spaces.Space, + action_space: spaces.Space, + device: torch.device, + actor: torch.nn.Module | None = None, + critic: torch.nn.Module | None = None, +) -> Policy: + """Build policy strictly from json-like block: { name: ..., cfg: {...} }""" + name = policy_block["name"].lower() + if name not in _POLICY_REGISTRY: + available = ", ".join(get_registered_policy_names()) + raise ValueError( + f"Policy '{name}' is not registered. Available policies: {available}" + ) + policy_cls = _POLICY_REGISTRY[name] + if name == "actor_critic": + if actor is None or critic is None: + raise ValueError( + "ActorCritic policy requires external 'actor' and 'critic' modules." + ) + return policy_cls(obs_space, action_space, device, actor=actor, critic=critic) + else: + return policy_cls(obs_space, action_space, device) + + +def build_mlp_from_cfg(module_cfg: Dict, in_dim: int, out_dim: int) -> MLP: + """Construct an MLP module from a minimal json-like config. + + Expected schema: + module_cfg = { + "type": "mlp", + "hidden_sizes": [256, 256], + "activation": "relu", + } + """ + if module_cfg.get("type", "").lower() != "mlp": + raise ValueError("Only 'mlp' type is supported for actor/critic in this setup.") + + hidden_sizes = module_cfg["network_cfg"]["hidden_sizes"] + activation = module_cfg["network_cfg"]["activation"] + return MLP(in_dim, out_dim, hidden_sizes, activation) + + +# default registrations +register_policy("actor_critic", ActorCritic) + +__all__ = [ + "ActorCritic", + "register_policy", + "get_registered_policy_names", + "build_policy", + "build_mlp_from_cfg", + "get_policy_class", + "Policy", + "MLP", +] diff --git a/embodichain/agents/rl/models/actor_critic.py b/embodichain/agents/rl/models/actor_critic.py new file mode 100644 index 00000000..1c40043a --- /dev/null +++ b/embodichain/agents/rl/models/actor_critic.py @@ -0,0 +1,96 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Dict, Any, Tuple + +import torch +import torch.nn as nn +from torch.distributions.normal import Normal +from .mlp import MLP +from .policy import Policy + + +class ActorCritic(Policy): + """Actor-Critic with learnable log_std for Gaussian policy. + + This is a placeholder implementation of the Policy interface that: + - Encapsulates MLP networks (actor + critic) that need to be trained by RL algorithms + - Handles internal computation: MLP output → mean + learnable log_std → Normal distribution + - Provides a uniform interface for RL algorithms (PPO, SAC, etc.) + + This allows seamless swapping with other policy implementations (e.g., VLAPolicy) + without modifying RL algorithm code. + + Implements: + - get_action(obs, deterministic=False) -> (action, log_prob, value) + - get_value(obs) + - evaluate_actions(obs, actions) -> (log_prob, entropy, value) + """ + + def __init__( + self, + obs_space, + action_space, + device: torch.device, + actor: nn.Module, + critic: nn.Module, + ): + super().__init__() + self.obs_dim = obs_space.shape[-1] + self.action_dim = action_space.shape[-1] + self.device = device + + # Require external injection of actor and critic + self.actor = actor + self.critic = critic + self.actor.to(self.device) + self.critic.to(self.device) + + # learnable log_std per action dim + self.log_std = nn.Parameter(torch.zeros(self.action_dim, device=self.device)) + self.log_std_min = -5.0 + self.log_std_max = 2.0 + + @torch.no_grad() + def get_action( + self, obs: torch.Tensor, deterministic: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + mean = self.actor(obs) + log_std = self.log_std.clamp(self.log_std_min, self.log_std_max) + std = log_std.exp().expand(mean.shape[0], -1) + dist = Normal(mean, std) + action = mean if deterministic else dist.sample() + log_prob = dist.log_prob(action).sum(dim=-1) + value = self.critic(obs).squeeze(-1) + return action, log_prob, value + + @torch.no_grad() + def get_value(self, obs: torch.Tensor) -> torch.Tensor: + return self.critic(obs).squeeze(-1) + + def evaluate_actions( + self, obs: torch.Tensor, actions: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + mean = self.actor(obs) + log_std = self.log_std.clamp(self.log_std_min, self.log_std_max) + std = log_std.exp().expand(mean.shape[0], -1) + dist = Normal(mean, std) + log_prob = dist.log_prob(actions).sum(dim=-1) + entropy = dist.entropy().sum(dim=-1) + value = self.critic(obs).squeeze(-1) + return log_prob, entropy, value diff --git a/embodichain/agents/rl/models/mlp.py b/embodichain/agents/rl/models/mlp.py new file mode 100644 index 00000000..d839f63d --- /dev/null +++ b/embodichain/agents/rl/models/mlp.py @@ -0,0 +1,121 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from functools import reduce +from typing import Iterable, List, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn + + +ActivationName = Union[str, None] + + +def _resolve_activation(name: ActivationName) -> nn.Module: + if name is None: + return nn.Identity() + name_l = str(name).lower() + if name_l in ("relu",): + return nn.ReLU() + if name_l in ("elu",): + return nn.ELU() + if name_l in ("tanh",): + return nn.Tanh() + if name_l in ("gelu",): + return nn.GELU() + if name_l in ("silu", "swish"): + return nn.SiLU() + # fallback + return nn.ReLU() + + +class MLP(nn.Sequential): + """General MLP supporting custom last activation, orthogonal init, and output reshape. + + Args: + - input_dim: input dimension + - output_dim: output dimension (int or shape tuple/list) + - hidden_dims: hidden layer sizes, e.g. [256, 256] + - activation: hidden layer activation name (relu/elu/tanh/gelu/silu) + - last_activation: last-layer activation name or None for linear + - use_layernorm: whether to add LayerNorm after each hidden linear layer + - dropout_p: dropout probability for hidden layers (0 disables) + """ + + def __init__( + self, + input_dim: int, + output_dim: Union[int, Sequence[int]], + hidden_dims: Sequence[int], + activation: ActivationName = "elu", + last_activation: ActivationName = None, + use_layernorm: bool = False, + dropout_p: float = 0.0, + ) -> None: + super().__init__() + + act = lambda: _resolve_activation(activation) + last_act = ( + _resolve_activation(last_activation) + if last_activation is not None + else None + ) + + layers: List[nn.Module] = [] + dims = [input_dim] + list(hidden_dims) + + for in_d, out_d in zip(dims[:-1], dims[1:]): + layers.append(nn.Linear(in_d, out_d)) + if use_layernorm: + layers.append(nn.LayerNorm(out_d)) + layers.append(act()) + if dropout_p and dropout_p > 0.0: + layers.append(nn.Dropout(p=dropout_p)) + + # Output layer + if isinstance(output_dim, int): + layers.append(nn.Linear(dims[-1], output_dim)) + else: + total_out = int(reduce(lambda a, b: a * b, output_dim)) + layers.append(nn.Linear(dims[-1], total_out)) + layers.append(nn.Unflatten(dim=-1, unflattened_size=tuple(output_dim))) + + if last_act is not None: + layers.append(last_act) + + for idx, layer in enumerate(layers): + self.add_module(str(idx), layer) + + def init_orthogonal(self, scales: Union[float, Sequence[float]] = 1.0) -> None: + """Orthogonal-initialize linear layers and zero the bias. + + scales: single gain value or a sequence with length equal to the + number of linear layers. + """ + + def get_scale(i: int) -> float: + if isinstance(scales, (list, tuple)): + return float(scales[i]) + return float(scales) + + lin_idx = 0 + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.orthogonal_(m.weight, gain=get_scale(lin_idx)) + nn.init.zeros_(m.bias) + lin_idx += 1 diff --git a/embodichain/agents/rl/models/policy.py b/embodichain/agents/rl/models/policy.py new file mode 100644 index 00000000..cd21d0f7 --- /dev/null +++ b/embodichain/agents/rl/models/policy.py @@ -0,0 +1,94 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Policy base class for RL algorithms. + +This module defines an abstract Policy base class that all RL policies must +inherit from. A Policy encapsulates the neural networks and exposes a uniform +interface for RL algorithms (e.g., PPO, SAC) to interact with. +""" + +from __future__ import annotations + +from typing import Tuple +from abc import ABC, abstractmethod +import torch.nn as nn + +import torch + + +class Policy(nn.Module, ABC): + """Abstract base class that all RL policies must implement. + + A Policy: + - Encapsulates neural networks that are trained by RL algorithms + - Handles internal computations (e.g., network output → distribution) + - Provides a uniform interface for algorithms (PPO, SAC, etc.) + """ + + device: torch.device + """Device where the policy parameters are located.""" + + def __init__(self) -> None: + super().__init__() + + @abstractmethod + def get_action( + self, obs: torch.Tensor, deterministic: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Sample an action from the policy. + + Args: + obs: Observation tensor of shape (batch_size, obs_dim) + deterministic: If True, return the mean action; otherwise sample + + Returns: + Tuple of (action, log_prob, value): + - action: Sampled action tensor of shape (batch_size, action_dim) + - log_prob: Log probability of the action, shape (batch_size,) + - value: Value estimate, shape (batch_size,) + """ + raise NotImplementedError + + @abstractmethod + def get_value(self, obs: torch.Tensor) -> torch.Tensor: + """Get value estimate for given observations. + + Args: + obs: Observation tensor of shape (batch_size, obs_dim) + + Returns: + Value estimate tensor of shape (batch_size,) + """ + raise NotImplementedError + + @abstractmethod + def evaluate_actions( + self, obs: torch.Tensor, actions: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Evaluate actions and compute log probabilities, entropy, and values. + + Args: + obs: Observation tensor of shape (batch_size, obs_dim) + actions: Action tensor of shape (batch_size, action_dim) + + Returns: + Tuple of (log_prob, entropy, value): + - log_prob: Log probability of actions, shape (batch_size,) + - entropy: Entropy of the action distribution, shape (batch_size,) + - value: Value estimate, shape (batch_size,) + """ + raise NotImplementedError diff --git a/embodichain/agents/rl/train.py b/embodichain/agents/rl/train.py new file mode 100644 index 00000000..e87d4629 --- /dev/null +++ b/embodichain/agents/rl/train.py @@ -0,0 +1,270 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import argparse +import os +import time +from pathlib import Path + +import numpy as np +import torch +import wandb +import json +from torch.utils.tensorboard import SummaryWriter +from copy import deepcopy + +from embodichain.agents.rl.models import build_policy, get_registered_policy_names +from embodichain.agents.rl.models import build_mlp_from_cfg +from embodichain.agents.rl.algo import build_algo, get_registered_algo_names +from embodichain.agents.rl.utils.trainer import Trainer +from embodichain.utils import logger +from embodichain.lab.gym.envs.tasks.rl import build_env +from embodichain.lab.gym.utils.gym_utils import config_to_rl_cfg +from embodichain.utils.utility import load_json +from embodichain.utils.module_utils import find_function_from_modules +from embodichain.lab.sim import SimulationManagerCfg +from embodichain.lab.gym.envs.managers.cfg import EventCfg + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, required=True, help="Path to JSON config") + args = parser.parse_args() + + with open(args.config, "r") as f: + cfg_json = json.load(f) + + trainer_cfg = cfg_json["trainer"] + policy_block = cfg_json["policy"] + algo_block = cfg_json["algorithm"] + + # Runtime + exp_name = trainer_cfg.get("exp_name", "generic_exp") + seed = int(trainer_cfg.get("seed", 1)) + device_str = trainer_cfg.get("device", "cpu") + iterations = int(trainer_cfg.get("iterations", 250)) + rollout_steps = int(trainer_cfg.get("rollout_steps", 2048)) + eval_freq = int(trainer_cfg.get("eval_freq", 10000)) + save_freq = int(trainer_cfg.get("save_freq", 50000)) + headless = bool(trainer_cfg.get("headless", True)) + wandb_project_name = trainer_cfg.get("wandb_project_name", "embodychain-generic") + + # Device + if not isinstance(device_str, str): + raise ValueError( + f"runtime.device must be a string such as 'cpu' or 'cuda:0'. Got: {device_str!r}" + ) + try: + device = torch.device(device_str) + except RuntimeError as exc: + raise ValueError( + f"Failed to parse runtime.device='{device_str}': {exc}" + ) from exc + + if device.type == "cuda": + if not torch.cuda.is_available(): + raise ValueError( + "CUDA device requested but torch.cuda.is_available() is False." + ) + index = ( + device.index if device.index is not None else torch.cuda.current_device() + ) + device_count = torch.cuda.device_count() + if index < 0 or index >= device_count: + raise ValueError( + f"CUDA device index {index} is out of range (available devices: {device_count})." + ) + torch.cuda.set_device(index) + device = torch.device(f"cuda:{index}") + elif device.type != "cpu": + raise ValueError(f"Unsupported device type: {device}") + logger.log_info(f"Device: {device}") + + # Seeds + np.random.seed(seed) + torch.manual_seed(seed) + torch.backends.cudnn.deterministic = True + if device.type == "cuda": + torch.cuda.manual_seed_all(seed) + + # Outputs + run_stamp = time.strftime("%Y%m%d_%H%M%S") + run_base = os.path.join("outputs", f"{exp_name}_{run_stamp}") + log_dir = os.path.join(run_base, "logs") + checkpoint_dir = os.path.join(run_base, "checkpoints") + os.makedirs(log_dir, exist_ok=True) + os.makedirs(checkpoint_dir, exist_ok=True) + writer = SummaryWriter(f"{log_dir}/{exp_name}") + + # Initialize Weights & Biases (optional) + use_wandb = trainer_cfg.get("use_wandb", False) + + # Initialize Weights & Biases (optional) + if use_wandb: + wandb.init(project=wandb_project_name, name=exp_name, config=cfg_json) + + gym_config_path = Path(trainer_cfg["gym_config"]) + logger.log_info(f"Current working directory: {Path.cwd()}") + + gym_config_data = load_json(str(gym_config_path)) + gym_env_cfg = config_to_rl_cfg(gym_config_data) + + # Ensure sim configuration mirrors runtime overrides + if gym_env_cfg.sim_cfg is None: + gym_env_cfg.sim_cfg = SimulationManagerCfg() + if device.type == "cuda": + gpu_index = device.index + if gpu_index is None: + gpu_index = torch.cuda.current_device() + gym_env_cfg.sim_cfg.sim_device = torch.device(f"cuda:{gpu_index}") + if hasattr(gym_env_cfg.sim_cfg, "gpu_id"): + gym_env_cfg.sim_cfg.gpu_id = gpu_index + else: + gym_env_cfg.sim_cfg.sim_device = torch.device("cpu") + gym_env_cfg.sim_cfg.headless = headless + + logger.log_info( + f"Loaded gym_config from {gym_config_path} (env_id={gym_env_cfg.env_id}, headless={gym_env_cfg.sim_cfg.headless}, sim_device={gym_env_cfg.sim_cfg.sim_device})" + ) + + env = build_env(gym_env_cfg.env_id, base_env_cfg=gym_env_cfg) + + eval_gym_env_cfg = deepcopy(gym_env_cfg) + eval_gym_env_cfg.num_envs = 4 + eval_gym_env_cfg.sim_cfg.headless = True + + eval_env = build_env(eval_gym_env_cfg.env_id, base_env_cfg=eval_gym_env_cfg) + + # Build Policy via registry + policy_name = policy_block["name"] + # Build Policy via registry (actor/critic must be explicitly defined in JSON when using actor_critic) + if policy_name.lower() == "actor_critic": + obs_dim = env.observation_space.shape[-1] + action_dim = env.action_space.shape[-1] + + actor_cfg = policy_block.get("actor") + critic_cfg = policy_block.get("critic") + if actor_cfg is None or critic_cfg is None: + raise ValueError( + "ActorCritic requires 'actor' and 'critic' definitions in JSON (policy.actor / policy.critic)." + ) + + actor = build_mlp_from_cfg(actor_cfg, obs_dim, action_dim) + critic = build_mlp_from_cfg(critic_cfg, obs_dim, 1) + + policy = build_policy( + policy_block, + env.observation_space, + env.action_space, + device, + actor=actor, + critic=critic, + ) + else: + policy = build_policy( + policy_block, env.observation_space, env.action_space, device + ) + + # Build Algorithm via factory + algo_name = algo_block["name"].lower() + algo_cfg = algo_block["cfg"] + algo = build_algo(algo_name, algo_cfg, policy, device) + + # Build Trainer + event_modules = [ + "embodichain.lab.gym.envs.managers.randomization", + "embodichain.lab.gym.envs.managers.record", + "embodichain.lab.gym.envs.managers.events", + ] + events_dict = trainer_cfg.get("events", {}) + train_event_cfg = {} + eval_event_cfg = {} + # Parse train events + for event_name, event_info in events_dict.get("train", {}).items(): + event_func_str = event_info.get("func") + mode = event_info.get("mode", "interval") + params = event_info.get("params", {}) + interval_step = event_info.get("interval_step", 1) + event_func = find_function_from_modules( + event_func_str, event_modules, raise_if_not_found=True + ) + train_event_cfg[event_name] = EventCfg( + func=event_func, + mode=mode, + params=params, + interval_step=interval_step, + ) + # Parse eval events + for event_name, event_info in events_dict.get("eval", {}).items(): + event_func_str = event_info.get("func") + mode = event_info.get("mode", "interval") + params = event_info.get("params", {}) + interval_step = event_info.get("interval_step", 1) + event_func = find_function_from_modules( + event_func_str, event_modules, raise_if_not_found=True + ) + eval_event_cfg[event_name] = EventCfg( + func=event_func, + mode=mode, + params=params, + interval_step=interval_step, + ) + trainer = Trainer( + policy=policy, + env=env, + algorithm=algo, + num_steps=rollout_steps, + batch_size=algo_cfg["batch_size"], + writer=writer, + eval_freq=eval_freq, + save_freq=save_freq, + checkpoint_dir=checkpoint_dir, + exp_name=exp_name, + use_wandb=use_wandb, + eval_env=eval_env, + event_cfg=train_event_cfg, + eval_event_cfg=eval_event_cfg, + ) + + logger.log_info("Generic training initialized") + logger.log_info(f"Task: {type(env).__name__}") + logger.log_info( + f"Policy: {policy_name} (available: {get_registered_policy_names()})" + ) + logger.log_info( + f"Algorithm: {algo_name} (available: {get_registered_algo_names()})" + ) + + total_steps = int(iterations * rollout_steps * env.num_envs) + logger.log_info(f"Total steps: {total_steps} (iterations≈{iterations})") + + try: + trainer.train(total_steps) + except KeyboardInterrupt: + logger.log_info("Training interrupted by user") + finally: + trainer.save_checkpoint() + writer.close() + if use_wandb: + try: + wandb.finish() + except Exception: + pass + logger.log_info("Training finished") + + +if __name__ == "__main__": + main() diff --git a/embodichain/agents/rl/utils/__init__.py b/embodichain/agents/rl/utils/__init__.py new file mode 100644 index 00000000..f6f9f4f9 --- /dev/null +++ b/embodichain/agents/rl/utils/__init__.py @@ -0,0 +1,21 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from .config import AlgorithmCfg + +__all__ = [ + "AlgorithmCfg", +] diff --git a/embodichain/agents/rl/utils/config.py b/embodichain/agents/rl/utils/config.py new file mode 100644 index 00000000..2a89e243 --- /dev/null +++ b/embodichain/agents/rl/utils/config.py @@ -0,0 +1,29 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from embodichain.utils import configclass + + +@configclass +class AlgorithmCfg: + """Minimal algorithm configuration shared across RL algorithms.""" + + device: str = "cuda" + learning_rate: float = 3e-4 + batch_size: int = 64 + gamma: float = 0.99 + gae_lambda: float = 0.95 + max_grad_norm: float = 0.5 diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py new file mode 100644 index 00000000..0ae1fb1e --- /dev/null +++ b/embodichain/agents/rl/utils/trainer.py @@ -0,0 +1,265 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Dict, Any, Tuple, Callable, Optional +import time +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter +from collections import deque +import wandb + +from embodichain.lab.gym.envs.managers.event_manager import EventManager + + +class Trainer: + """Algorithm-agnostic trainer that coordinates training loop, logging, and evaluation.""" + + def __init__( + self, + policy, + env, + algorithm, + num_steps: int, + batch_size: int, + writer: SummaryWriter | None, + eval_freq: int, + save_freq: int, + checkpoint_dir: str, + exp_name: str, + use_wandb: bool = True, + eval_env=None, + event_cfg=None, + eval_event_cfg=None, + ): + self.policy = policy + self.env = env + self.eval_env = eval_env + self.algorithm = algorithm + self.num_steps = num_steps + self.batch_size = batch_size + self.writer = writer + self.eval_freq = eval_freq + self.save_freq = save_freq + self.checkpoint_dir = checkpoint_dir + self.exp_name = exp_name + self.use_wandb = use_wandb + + if event_cfg is not None: + self.event_manager = EventManager(event_cfg, env=self.env) + if eval_event_cfg is not None: + self.eval_event_manager = EventManager(eval_event_cfg, env=self.eval_env) + + # Get device from algorithm + self.device = self.algorithm.device + self.global_step = 0 + self.start_time = time.time() + self.ret_window = deque(maxlen=100) + self.len_window = deque(maxlen=100) + + # initial obs (assume env returns torch tensors already on target device) + obs, _ = self.env.reset() + self.obs = obs + + # Initialize algorithm's buffer + self.observation_space = getattr(self.env, "observation_space", None) + self.action_space = getattr(self.env, "action_space", None) + obs_dim = ( + self.observation_space.shape[-1] + if self.observation_space + else self.obs.shape[-1] + ) + action_dim = self.action_space.shape[-1] if self.action_space else None + if action_dim is None: + raise RuntimeError( + "Env must expose action_space with shape for buffer initialization." + ) + num_envs = self.obs.shape[0] if self.obs.ndim == 2 else 1 + + # Algorithm manages its own buffer + self.algorithm.initialize_buffer(num_steps, num_envs, obs_dim, action_dim) + + # episode stats tracked on device to avoid repeated CPU round-trips + self.curr_ret = torch.zeros(num_envs, dtype=torch.float32, device=self.device) + self.curr_len = torch.zeros(num_envs, dtype=torch.int32, device=self.device) + + # ---- lightweight helpers for dense logging ---- + @staticmethod + def _mean_scalar(x) -> float: + if hasattr(x, "detach"): + x = x.detach().cpu().numpy() + else: + x = np.asarray(x) + return float(np.mean(x)) + + def _log_scalar_dict(self, prefix: str, data: dict): + if not self.writer or not isinstance(data, dict): + return + for k, v in data.items(): + try: + self.writer.add_scalar( + f"{prefix}/{k}", self._mean_scalar(v), self.global_step + ) + except Exception: + continue + + def _pack_log_dict(self, prefix: str, data: dict) -> dict: + if not isinstance(data, dict): + return {} + out = {} + for k, v in data.items(): + try: + out[f"{prefix}/{k}"] = self._mean_scalar(v) + except Exception: + continue + return out + + def train(self, total_timesteps: int): + print(f"Start training, total steps: {total_timesteps}") + while self.global_step < total_timesteps: + self._collect_rollout() + losses = self.algorithm.update() + self._log_train(losses) + if self.global_step % self.eval_freq == 0: + self._eval_once() + if self.global_step % self.save_freq == 0: + self.save_checkpoint() + + @torch.no_grad() + def _collect_rollout(self): + """Collect a rollout. Algorithm controls the data collection process.""" + + # Callback function for statistics and logging + def on_step(obs, actions, reward, done, info, next_obs): + """Callback called at each step during rollout collection.""" + # Episode stats (stay on device; convert only when episode ends) + self.curr_ret += reward + self.curr_len += 1 + done_idx = torch.nonzero(done, as_tuple=False).squeeze(-1) + if done_idx.numel() > 0: + finished_ret = self.curr_ret[done_idx].detach().cpu().tolist() + finished_len = self.curr_len[done_idx].detach().cpu().tolist() + self.ret_window.extend(finished_ret) + self.len_window.extend(finished_len) + self.curr_ret[done_idx] = 0 + self.curr_len[done_idx] = 0 + + # Update global step and observation + self.obs = next_obs + self.global_step += next_obs.shape[0] if next_obs.ndim == 2 else 1 + + if isinstance(info, dict): + rewards_dict = info.get("rewards") + metrics_dict = info.get("metrics") + self._log_scalar_dict("rewards", rewards_dict) + self._log_scalar_dict("metrics", metrics_dict) + log_dict = {} + log_dict.update(self._pack_log_dict("rewards", rewards_dict)) + log_dict.update(self._pack_log_dict("metrics", metrics_dict)) + if log_dict and self.use_wandb: + wandb.log(log_dict, step=self.global_step) + + # Algorithm controls data collection + result = self.algorithm.collect_rollout( + env=self.env, + policy=self.policy, + obs=self.obs, + num_steps=self.num_steps, + on_step_callback=on_step, + ) + + def _log_train(self, losses: Dict[str, float]): + if self.writer: + for k, v in losses.items(): + self.writer.add_scalar(f"train/{k}", v, self.global_step) + elapsed = max(1e-6, time.time() - self.start_time) + sps = self.global_step / elapsed + self.writer.add_scalar("charts/SPS", sps, self.global_step) + if len(self.ret_window) > 0: + self.writer.add_scalar( + "charts/episode_reward_avg_100", + float(np.mean(self.ret_window)), + self.global_step, + ) + if len(self.len_window) > 0: + self.writer.add_scalar( + "charts/episode_length_avg_100", + float(np.mean(self.len_window)), + self.global_step, + ) + # console + sps = self.global_step / max(1e-6, time.time() - self.start_time) + avgR = np.mean(self.ret_window) if len(self.ret_window) > 0 else float("nan") + avgL = np.mean(self.len_window) if len(self.len_window) > 0 else float("nan") + print( + f"[train] step={self.global_step} sps={sps:.0f} avgReward(100)={avgR:.3f} avgLength(100)={avgL:.1f}" + ) + + # wandb (mirror TB logs) + if self.use_wandb: + log_dict = {f"train/{k}": v for k, v in losses.items()} + log_dict["charts/SPS"] = sps + if not np.isnan(avgR): + log_dict["charts/episode_reward_avg_100"] = float(avgR) + if not np.isnan(avgL): + log_dict["charts/episode_length_avg_100"] = float(avgL) + wandb.log(log_dict, step=self.global_step) + + @torch.no_grad() + def _eval_once(self, num_episodes: int = 5): + self.policy.eval() + returns = [] + for _ in range(num_episodes): + obs, _ = self.eval_env.reset() + done_any = torch.zeros( + obs.shape[0] if obs.ndim == 2 else 1, + dtype=torch.bool, + device=self.device, + ) + num_envs_eval = obs.shape[0] if obs.ndim == 2 else 1 + ep_ret = torch.zeros(num_envs_eval, dtype=torch.float32, device=self.device) + while not done_any.any(): + actions, _, _ = self.policy.get_action(obs, deterministic=True) + result = self.eval_env.step(actions) + obs, reward, terminated, truncated, info = result + done = terminated | truncated + reward = reward.float() + done_any = done + ep_ret += reward + + if hasattr(self, "eval_event_manager"): + if "interval" in self.eval_event_manager.available_modes: + self.eval_event_manager.apply(mode="interval") + + returns.extend(ep_ret.detach().cpu().tolist()) + if self.writer and len(returns) > 0: + self.writer.add_scalar( + "eval/avg_reward", float(np.mean(returns)), self.global_step + ) + + def save_checkpoint(self): + # minimal model-only checkpoint; trainer/algorithm states can be added + path = f"{self.checkpoint_dir}/{self.exp_name}_step_{self.global_step}.pt" + torch.save( + { + "global_step": self.global_step, + "policy": self.policy.state_dict(), + }, + path, + ) + print(f"Checkpoint saved: {path}") diff --git a/embodichain/data/__init__.py b/embodichain/data/__init__.py new file mode 100644 index 00000000..f97508ac --- /dev/null +++ b/embodichain/data/__init__.py @@ -0,0 +1,35 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os + +database_dir = os.path.dirname(os.path.abspath(__file__)).replace("data", "database") +video_dir = os.path.join(database_dir, "video") +weights_dir = os.path.join(database_dir, "weights") +database_2d_dir = os.path.join(database_dir, "2dasset") +database_lang_dir = os.path.join(database_dir, "lang") +database_demo_dir = os.path.join(database_dir, "demostration") +database_tmp_dir = os.path.join(database_dir, "tmp") +database_train_dir = os.path.join(database_dir, "train") + + +if not os.path.exists(database_tmp_dir): + os.makedirs(database_tmp_dir, exist_ok=True) +if not os.path.exists(database_train_dir): + os.makedirs(database_train_dir, exist_ok=True) + +from . import assets +from .dataset import * diff --git a/embodichain/data/assets/__init__.py b/embodichain/data/assets/__init__.py new file mode 100644 index 00000000..875628de --- /dev/null +++ b/embodichain/data/assets/__init__.py @@ -0,0 +1,23 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from .materials import * +from .demo_assets import * +from .obj_assets import * +from .w1_assets import * +from .eef_assets import * +from .robot_assets import * +from .scene_assets import * diff --git a/embodichain/data/assets/demo_assets.py b/embodichain/data/assets/demo_assets.py new file mode 100644 index 00000000..8480d468 --- /dev/null +++ b/embodichain/data/assets/demo_assets.py @@ -0,0 +1,46 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import open3d as o3d +from pathlib import Path +from embodichain.data.dataset import EmbodiChainDataset +from embodichain.data.constants import ( + EMBODICHAIN_DOWNLOAD_PREFIX, + EMBODICHAIN_DEFAULT_DATA_ROOT, +) + + +class ScoopIceNewEnv(EmbodiChainDataset): + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/demo/ScoopIceNewEnv.zip", + "e92734a9de0f64be33a11fbda0fbd3b6", + ) + prefix = "ScoopIceNewEnv" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class MultiW1Data(EmbodiChainDataset): + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/demo/multi_w1_demo.zip", + "984e8fa3aa05cb36a1fd973a475183ed", + ) + prefix = "MultiW1Data" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + super().__init__(prefix, data_descriptor, path) diff --git a/embodichain/data/assets/eef_assets.py b/embodichain/data/assets/eef_assets.py new file mode 100644 index 00000000..fa1e9aae --- /dev/null +++ b/embodichain/data/assets/eef_assets.py @@ -0,0 +1,264 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + + +import open3d as o3d +from embodichain.data.dataset import EmbodiChainDataset +from embodichain.data.constants import ( + EMBODICHAIN_DOWNLOAD_PREFIX, + EMBODICHAIN_DEFAULT_DATA_ROOT, +) + + +class DH_PGC_140_50(EmbodiChainDataset): + """Dataset class for the DH Robotics PGC-140-50 end-effector gripper. + + Reference: + https://www.dh-robotics.com/product/pgc + + Directory structure: + DH_PGC_140_50/ + DH_PGC_140_50.urdf + + Example usage: + >>> from embodichain.data.eef_dataset import DH_PGC_140_50 + >>> dataset = DH_PGC_140_50() + or + >>> from embodichain.data import get_data_path + >>> print(get_data_path("DH_PGC_140_50/DH_PGC_140_50.urdf")) + """ + + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/eef_assets/DH_PGC_140_50.zip", + "c2a642308a76e99b1b8b7cb3a11c5df3", + ) + prefix = "DH_PGC_140_50" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class DH_PGI_140_80(EmbodiChainDataset): + """Dataset class for the DH Robotics PGI-140-80 end-effector gripper. + + Reference: + https://www.dh-robotics.com/product/pgia### + + Directory structure: + DH_PGI_140_80/ + DH_PGI_140_80.urdf + + Example usage: + >>> from embodichain.data.eef_dataset import DH_PGI_140_80 + >>> dataset = DH_PGI_140_80() + or + >>> from embodichain.data import get_data_path + >>> print(get_data_path("DH_PGI_140_80/DH_PGI_140_80.urdf")) + """ + + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/eef_assets/DH_PGI_140_80.zip", + "05a1a08b13c6250cc12affeeda3a08ba", + ) + prefix = "DH_PGI_140_80" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class DH_PGC_140_50_M(EmbodiChainDataset): + """Dataset class for the DH Robotics PGC-140-50 end-effector gripper. + DexForce modified connector and finger. + + Reference: + https://www.dh-robotics.com/product/pgc + + Directory structure: + DH_PGC_140_50_M/ + DH_PGC_140_50_M.urdf + + Example usage: + >>> from embodichain.data.eef_dataset import DH_PGC_140_50_M + >>> dataset = DH_PGC_140_50_M() + or + >>> from embodichain.data import get_data_path + >>> print(get_data_path("DH_PGC_140_50_M/DH_PGC_140_50_M.urdf")) + """ + + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/eef_assets/DH_PGC_140_50_M.zip", + "3a9ab5f32639e03afb38dc033b44bb62", + ) + prefix = "DH_PGC_140_50_M" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class ZH_CTM2F110(EmbodiChainDataset): + """Dataset class for the Zhixing Robot Technology CTM2F110 end-effector gripper. + + Reference: + https://www.changingtek.com/service + + Directory structure: + ZH_CTM2F110/ + ZH_CTM2F110.urdf + + Example usage: + >>> from embodichain.data.eef_dataset import ZH_CTM2F110 + >>> dataset = ZH_CTM2F110() + or + >>> from embodichain.data import get_data_path + >>> print(get_data_path("ZH_CTM2F110/ZH_CTM2F110.urdf")) + """ + + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/eef_assets/ZH_CTM2F110.zip", + "0e7c3310425609797fe010b2a76fe465", + ) + prefix = "ZH_CTM2F110" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class BrainCoHandRevo1(EmbodiChainDataset): + """Dataset class for the BrainCo Hand Revo 1 robotic hand. + + Reference: + https://www.brainco-hz.com/docs/revolimb-hand/revo1/parameters.html + + Directory structure: + BrainCoHandRevo1/ + BrainCoRightHand/BrainCoRightHand.urdf + BrainCoLeftHand/BrainCoLeftHand.urdf + + Example usage: + >>> from embodichain.data.eef_dataset import BrainCoHandRevo1 + >>> dataset = BrainCoHandRevo1() + or + >>> from embodichain.data import get_data_path + >>> print(get_data_path("BrainCoHandRevo1/BrainCoRightHand/BrainCoRightHand.urdf")) + >>> print(get_data_path("BrainCoHandRevo1/BrainCoLeftHand/BrainCoLeftHand.urdf")) + """ + + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/eef_assets/BrainCoHandRevo01.zip", + "ff9ac77e7e1493fd32d40c87fecbee6c", + ) + prefix = "BrainCoHandRevo1" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class InspireHand(EmbodiChainDataset): + """Dataset class for the Inspire Hand robotic hand. + + Reference: + https://en.inspire-robots.com/product/rh56bfx + + Directory structure: + InspireHand/ + InspireLeftHand/InspireLeftHand.urdf + InspireRightHand/InspireRightHand.urdf + inspire_joint_data.csv + inspire_joint_data.npy + + Example usage: + >>> from embodichain.data.eef_dataset import InspireHand + >>> dataset = InspireHand() + or + >>> from embodichain.data import get_data_path + >>> print(get_data_path("InspireHand/InspireLeftHand/InspireLeftHand.urdf")) + >>> print(get_data_path("InspireHand/InspireRightHand/InspireRightHand.urdf")) + >>> print(get_data_path("InspireHand/inspire_joint_data.csv")) + >>> print(get_data_path("InspireHand/inspire_joint_data.npy")) + """ + + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/eef_assets/InspireHand.zip", + "c60132a6f03866fb021cca5b6d72845e", + ) + prefix = "InspireHand" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class Robotiq2F85(EmbodiChainDataset): + """Dataset class for the Robotiq 2F85 robotic gripper. + + Reference: + https://robotiq.com/products/adaptive-grippers#Two-Finger-Gripper + + Directory structure: + Robotiq2F85/ + Robotiq2F85.urdf + + Example usage: + >>> from embodichain.data.eef_dataset import Robotiq2F85 + >>> dataset = Robotiq2F85() + or + >>> from embodichain.data import get_data_path + >>> print(get_data_path("Robotiq2F85/Robotiq2F85.urdf")) + """ + + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/eef_assets/Robotiq2F85.zip", + "53ecbf2c953f43f1134aa7223e592292", + ) + prefix = "Robotiq2F85" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class WheelTecFA2F(EmbodiChainDataset): + """Dataset class for the WheelTec FA 2 fingers robotic gripper. + + Reference: + https://www.wheeltec.net/ + + Directory structure: + WheelTecFA2F/ + WheelTecFA2F.urdf + + Example usage: + >>> from embodichain.data.eef_dataset import WheelTecFA2F + >>> dataset = WheelTecFA2F() + or + >>> from embodichain.data import get_data_path + >>> print(get_data_path("WheelTecFA2F/WheelTecFA2F.urdf")) + """ + + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/eef_assets/WheelTecFA2F.zip", + "feaf13f25b1c6ce58d011b1f2fa72f58", + ) + prefix = "WheelTecFA2F" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) diff --git a/embodichain/data/assets/materials.py b/embodichain/data/assets/materials.py new file mode 100644 index 00000000..784d6c6d --- /dev/null +++ b/embodichain/data/assets/materials.py @@ -0,0 +1,107 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import open3d as o3d + +from pathlib import Path +from typing import List + +from embodichain.data.dataset import EmbodiChainDataset +from embodichain.utils import logger +from embodichain.data.constants import ( + EMBODICHAIN_DOWNLOAD_PREFIX, + EMBODICHAIN_DEFAULT_DATA_ROOT, +) + + +class SimResources(EmbodiChainDataset): + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/materials/embodisim_resources.zip", + "53c054b3ae0857416dc52632eb562c12", + ) + prefix = "SimResources" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + def get_ibl_path(self, name: str) -> str: + """Get the path of the IBL resource. + + Args: + name (str): The name of the IBL resource. + + Returns: + str: The path to the IBL resource. + """ + ibl_names = self.get_ibl_list() + if name not in ibl_names: + logger.log_error( + f"Invalid IBL name: {name}. Available names are: {ibl_names}" + ) + return str(Path(self.extract_dir) / "embodysim_resources" / "IBL" / name) + + def get_ibl_list(self) -> List[str]: + """Get the names of all IBL resources. + + Returns: + List[str]: The names of all IBL resources. + """ + return [ + f.name + for f in Path(self.extract_dir).glob("embodysim_resources/IBL/*") + if f.is_dir() + ] + + def get_material_path(self, name: str) -> str: + """Get the path of the material resource. + + Args: + name (str): The name of the material resource. + + Returns: + str: The path to the material resource. + """ + material_names = self.get_material_list() + if name not in material_names: + logger.log_error( + f"Invalid material name: {name}. Available names are: {material_names}" + ) + return str(Path(self.extract_dir) / "embodysim_resources" / "materials" / name) + + def get_material_list(self) -> List[str]: + """Get the names of all material resources. + + Returns: + List[str]: The names of all material resources. + """ + return [ + f.name + for f in Path(self.extract_dir).glob("embodysim_resources/materials/*") + if f.is_dir() + ] + + +class CocoBackground(EmbodiChainDataset): + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/materials/CocoBackground.zip", + "fda82404a317281263bd5849e9eb31a1", + ) + prefix = "CocoBackground" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) diff --git a/embodichain/data/assets/obj_assets.py b/embodichain/data/assets/obj_assets.py new file mode 100644 index 00000000..6c2310f8 --- /dev/null +++ b/embodichain/data/assets/obj_assets.py @@ -0,0 +1,215 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + + +import open3d as o3d +from embodichain.data.dataset import EmbodiChainDataset +from embodichain.data.constants import ( + EMBODICHAIN_DOWNLOAD_PREFIX, + EMBODICHAIN_DEFAULT_DATA_ROOT, +) + + +class ShopTableSimple(EmbodiChainDataset): + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/obj_assets/shop_table_simple.zip", + "e3061ee024de7840f773b70140dcd43f", + ) + prefix = "ShopTableSimple" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class CircleTableSimple(EmbodiChainDataset): + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/obj_assets/circle_table_simple.zip", + "42ad2be8cd0caddcf9bfbf106b7783f3", + ) + prefix = "CircleTableSimple" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class PlasticBin(o3d.data.DownloadDataset): + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/obj_assets/plastic_bin.zip", + "21e00083689a4a3c4e4ae3fd89c61e55", + ) + prefix = "PlasticBin" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class Chair(o3d.data.DownloadDataset): + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/obj_assets/chair.zip", + "df3d7d1a05731d45fb2c678a40a39cd4", + ) + prefix = "Chair" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class ContainerMetal(EmbodiChainDataset): + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/obj_assets/container_metal.zip", + "ceafb87f8177609f87aaa6779fcbb9a3", + ) + prefix = "ContainerMetal" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class SimpleBoxDrawer(EmbodiChainDataset): + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/obj_assets/simple_box_drawer.zip", + "966b648bca16823ee91525847c183973", + ) + prefix = "SimpleBoxDrawer" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class AdrianoTable(EmbodiChainDataset): + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/obj_assets/adriano_table.zip", + "8453583a9a1a9d04d50268f8a3da554f", + ) + prefix = "AdrianoTable" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class CoffeeCup(EmbodiChainDataset): + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/obj_assets/CoffeeCup.zip", + "f05fce385826414c15e19df3b75dc886", + ) + prefix = "CoffeeCup" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class SlidingBoxDrawer(EmbodiChainDataset): + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/obj_assets/SlidingBoxDrawer.zip", + "b03d9006503d27b75ddeb06d31b2c7a5", + ) + prefix = "SlidingBoxDrawer" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class AiLiMu_BoxDrawer(EmbodiChainDataset): + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + EMBODICHAIN_DOWNLOAD_PREFIX + "AiLiMu_BoxDrawer_v3.zip", + "9a2889151a23d482f95f602cce9900c6", + ) + prefix = "AiLiMu_BoxDrawer" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class AluminumTable(o3d.data.DownloadDataset): + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + EMBODICHAIN_DOWNLOAD_PREFIX + "AluminumTable.glb", + "02991d36ca9b70f019ed330a61143aa9", + ) + prefix = "AluminumTable" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class ToyDuck(EmbodiChainDataset): + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + EMBODICHAIN_DOWNLOAD_PREFIX + "ToyDuck.zip", + "2f5c00ba487edf34ad668f7257c0264e", + ) + prefix = "ToyDuck" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class PaperCup(EmbodiChainDataset): + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + EMBODICHAIN_DOWNLOAD_PREFIX + "PaperCup.zip", + "359d13af8c5f31ad3226d8994a1a7198", + ) + prefix = "PaperCup" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class ChainRainSec(EmbodiChainDataset): + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/obj_assets/lianguijie.zip", + "2387589040a4d3f2676b622362452242", + ) + prefix = "ChainRainSec" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class TableWare(EmbodiChainDataset): + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/obj_assets/tableware.zip", + "403e340fc0e4996c002ee774f89cd236", + ) + prefix = "TableWare" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class ScannedBottle(EmbodiChainDataset): + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/obj_assets/ScannedBottle.zip", + "d2b2d4deb7b463a734af099f7624b4af", + ) + prefix = "ScannedBottle" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) diff --git a/embodichain/data/assets/robot_assets.py b/embodichain/data/assets/robot_assets.py new file mode 100644 index 00000000..60a33678 --- /dev/null +++ b/embodichain/data/assets/robot_assets.py @@ -0,0 +1,463 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + + +import open3d as o3d +from embodichain.data.dataset import EmbodiChainDataset +from embodichain.data.constants import ( + EMBODICHAIN_DOWNLOAD_PREFIX, + EMBODICHAIN_DEFAULT_DATA_ROOT, +) + + +class CobotMagicArm(EmbodiChainDataset): + """Dataset class for the Cobot Magic Arm robot. + + Reference: + https://global.agilex.ai/products/cobot-magic + + Directory structure: + CobotMagicArm/ + CobotMagicNoGripper.urdf + CobotMagicWithGripperV70.urdf + CobotMagicWithGripperV70NewUV.urdf + CobotMagicWithGripperV70NoMaterial.urdf + CobotMagicWithGripperV100.urdf + CobotMagicWithGripperV100NewUV.urdf + CobotMagicWithGripperV100NoMaterial.urdf + + Example usage: + >>> from embodichain.data.robot_dataset import CobotMagicArm + >>> dataset = CobotMagicArm() + or + >>> from embodichain.data import get_data_path + >>> print(get_data_path("CobotMagicArm/CobotMagicWithGripperV100.urdf")) + """ + + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/robot_assets/CobotMagicArmV2.zip", + "14af3e84b74193680899a59fc74e8337", + ) + prefix = "CobotMagicArm" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class RidgeBack(EmbodiChainDataset): + """Dataset class for the RidgeBack wheeled robot. + + Reference: + https://clearpathrobotics.com/ridgeback-indoor-robot-platform/ + + Directory structure: + RidgeBack/ + RidgeBack.urdf + + Example usage: + >>> from embodichain.data.robot_dataset import RidgeBack + >>> dataset = RidgeBack() + or + >>> from embodichain.data import get_data_path + >>> print(get_data_path("RidgeBack/RidgeBack.urdf")) + """ + + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/robot_assets/RidgeBack.zip", + "f03e1a6f4c781ad8957a88bdb010e9b6", + ) + prefix = "RidgeBack" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class UnitreeH1(EmbodiChainDataset): + """Dataset class for the Unitree H1 robot. + + Reference: + https://www.unitree.com/h1/ + + Directory structure: + UnitreeH1/ + UnitreeH1.urdf + UnitreeH1WithWrist.urdf + + Example usage: + >>> from embodichain.data.robot_dataset import UnitreeH1 + >>> dataset = UnitreeH1() + or + >>> from embodichain.data import get_data_path + >>> print(get_data_path("UnitreeH1/UnitreeH1.urdf")) + >>> print(get_data_path("UnitreeH1/UnitreeH1WithWrist.urdf")) + """ + + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/robot_assets/UnitreeH1.zip", + "339417cef5051a912693f3c64d29dddc", + ) + prefix = "UnitreeH1" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class ABB(EmbodiChainDataset): + """Dataset class for the ABB robot. + + Reference: + https://global.abb/ + + Directory structure: + ABB/ + IRB1200_5_90/IRB1200_5_90.urdf + IRB2600_12_165/IRB2600_12_165.urdf + + Example usage: + >>> from embodichain.data.robot_dataset import ABB + >>> dataset = ABB() + or + >>> from embodichain.data import get_data_path + >>> print(get_data_path("ABB/IRB1200_5_90/IRB1200_5_90.urdf")) + >>> print(get_data_path("ABB/IRB2600_12_165/IRB2600_12_165.urdf")) + """ + + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/robot_assets/ABB.zip", + "ea6df4983982606c43387783e5fb8c05", + ) + prefix = "ABB" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class Motoman(EmbodiChainDataset): + """Dataset class for the Motoman robot. + + Reference: + https://www.motoman.com/en-us + + Directory structure: + Motoman/ + GP7/GP7.urdf + GP12/GP12.urdf + + Example usage: + >>> from embodichain.data.robot_dataset import Motoman + >>> dataset = Motoman() + or + >>> from embodichain.data import get_data_path + >>> print(get_data_path("Motoman/GP7/GP7.urdf")) + >>> print(get_data_path("Motoman/GP12/GP12.urdf")) + """ + + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/robot_assets/Motoman.zip", + "ee5f16cfce34d8e2cb996fcff8a25986", + ) + prefix = "Motoman" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class KUKA(EmbodiChainDataset): + """Dataset class for the KUKA robot. + + Reference: + https://www.kuka.com/ + + Directory structure: + KUKA/ + KUKA/KR6_R700_sixx/KR6_R700_sixx.urdf + KUKA/KR6_R900_sixx/KR6_R900_sixx.urdf + + Example usage: + >>> from embodichain.data.robot_dataset import ABB + >>> dataset = ABB() + or + >>> from embodichain.data import get_data_path + >>> print(get_data_path("KUKA/KR6_R700_sixx/KR6_R700_sixx.urdf")) + >>> print(get_data_path("KUKA/KR6_R900_sixx/KR6_R900_sixx.urdf")) + """ + + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/robot_assets/KUKA.zip", + "da7a2dfd0db3f486e407f038d25c7537", + ) + prefix = "KUKA" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class Fanuc(EmbodiChainDataset): + """Dataset class for the Fanuc robot. + + Reference: + https://www.fanuc.com/ + + Directory structure: + Fanuc/ + M_20iA/M_20iA.urdf + R_2000iC_165F/R_2000iC_165F.urdf + + Example usage: + >>> from embodichain.data.robot_dataset import Fanuc + >>> dataset = Fanuc() + or + >>> from embodichain.data import get_data_path + >>> print(get_data_path("Fanuc/KR6_R700_sixx/KR6_R700_sixx.urdf")) + >>> print(get_data_path("Fanuc/KR6_R900_sixx/KR6_R900_sixx.urdf")) + """ + + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/robot_assets/Fanuc.zip", + "0a1c562f4719f7cdc1b24545fec4a301", + ) + prefix = "Fanuc" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class UniversalRobots(EmbodiChainDataset): + """Dataset class for the Universal Robots. + + Reference: + https://www.universal-robots.com/products/ur-series/ + + Directory structure: + UniversalRobots/ + UR3/UR3.urdf + UR3e/UR3e.urdf + UR5/UR5.urdf + UR5e/UR5e.urdf + UR10/UR10.urdf + UR10e/UR10e.urdf + + Example usage: + >>> from embodichain.data.robot_dataset import UniversalRobots + >>> dataset = UniversalRobots() + or + >>> from embodichain.data import get_data_path + >>> print(get_data_path("UniversalRobots/UR3/UR3.urdf")) + >>> print(get_data_path("UniversalRobots/UR3e/UR3e.urdf")) + >>> print(get_data_path("UniversalRobots/UR5/UR5.urdf")) + >>> print(get_data_path("UniversalRobots/UR5e/UR5e.urdf")) + >>> print(get_data_path("UniversalRobots/UR10/UR10.urdf")) + >>> print(get_data_path("UniversalRobots/UR10e/UR10e.urdf")) + """ + + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/robot_assets/UniversalRobots.zip", + "dbd12f7e36cef4e5025b82f748233b80", + ) + prefix = "UniversalRobots" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class Rokae(EmbodiChainDataset): + """Dataset class for the Rokae robots. + + Reference: + https://www.rokae.com/en/product/show/349/SR-Cobots.html + + Directory structure: + Rokae/ + SR3/SR3.urdf + SR5/SR5.urdf + + Example usage: + >>> from embodichain.data.robot_dataset import Rokae + >>> dataset = Rokae() + or + >>> from embodichain.data import get_data_path + >>> print(get_data_path("Rokae/SR3/SR3.urdf")) + >>> print(get_data_path("Rokae/SR5/SR5.urdf")) + """ + + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/robot_assets/Rokae.zip", + "fbfb852d6139e94b7c422771542f988f", + ) + prefix = "Rokae" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class Franka(EmbodiChainDataset): + """Dataset class for the Franka robots. + + Reference: + https://franka.de/franka-research-3 + + Directory structure: + Franka/ + Panda/Panda.urdf + PandaHand/PandaHand.urdf + PandaWithHand/PandaWithHand.urdf + + Example usage: + >>> from embodichain.data.robot_dataset import Franka + >>> dataset = Franka() + or + >>> from embodichain.data import get_data_path + >>> print(get_data_path("Franka/Panda/Panda.urdf")) + >>> print(get_data_path("Franka/PandaHand/PandaHand.urdf")) + >>> print(get_data_path("Franka/PandaWithHand/PandaWithHand.urdf")) + """ + + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/robot_assets/Franka.zip", + "c2de367fe1da02eeb45a8129f903d0b6", + ) + prefix = "Franka" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class Agile(EmbodiChainDataset): + """Dataset class for the Agile robots. + + Reference: + https://www.agile-robots.com/en/solutions/diana-7/ + + Directory structure: + Agile/ + Diana7/Diana7.urdf + + Example usage: + >>> from embodichain.data.robot_dataset import Agile + >>> dataset = Agile() + or + >>> from embodichain.data import get_data_path + >>> print(get_data_path("Agile/Diana7/Diana7.urdf")) + """ + + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/robot_assets/Agile.zip", + "fd47d7ab8a4d13960fd76e59544ba836", + ) + prefix = "Agile" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class Hans(EmbodiChainDataset): + """Dataset class for the Hans robots. + + Reference: + https://www.huayan-robotics.com/elfin + + Directory structure: + Hans/ + E05/E05.urdf + E10/E10.urdf + + Example usage: + >>> from embodichain.data.robot_dataset import Hans + >>> dataset = Hans() + or + >>> from embodichain.data import get_data_path + >>> print(get_data_path("Hans/E05/E05.urdf")) + >>> print(get_data_path("Hans/E10/E10.urdf")) + """ + + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/robot_assets/Hans.zip", + "c867c406e3dffd6982fd0a15e7dc7e29", + ) + prefix = "Hans" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class Aubo(EmbodiChainDataset): + """Dataset class for the Aubo robots. + + Reference: + https://www.aubo-robotics.cn/ + + Directory structure: + Aubo/ + i5/i5.urdf + + Example usage: + >>> from embodichain.data.robot_dataset import Aubo + >>> dataset = Aubo() + or + >>> from embodichain.data import get_data_path + >>> print(get_data_path("Aubo/i5/i5.urdf")) + """ + + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/robot_assets/Aubo.zip", + "2574649cd199c11267cc0f4aeac65557", + ) + prefix = "Aubo" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class RainbowY1(EmbodiChainDataset): + """Dataset class for the Aubo robots. + + Reference: + https://www.rainbow-robotics.com/en_rby1 + + Directory structure: + RainbowY1/ + RainbowY1.urdf + + Example usage: + >>> from embodichain.data.robot_dataset import RainbowY1 + >>> dataset = RainbowY1() + or + >>> from embodichain.data import get_data_path + >>> print(get_data_path("RainbowY1/RainbowY1.urdf")) + """ + + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/robot_assets/RainbowY1.zip", + "5979a3aaadb5de6488b13765d523564f", + ) + prefix = "RainbowY1" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) diff --git a/embodichain/data/assets/scene_assets.py b/embodichain/data/assets/scene_assets.py new file mode 100644 index 00000000..bd6e30a3 --- /dev/null +++ b/embodichain/data/assets/scene_assets.py @@ -0,0 +1,63 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import open3d as o3d +from pathlib import Path +from embodichain.data.dataset import EmbodiChainDataset +from embodichain.data.constants import ( + EMBODICHAIN_DOWNLOAD_PREFIX, + EMBODICHAIN_DEFAULT_DATA_ROOT, +) + + +class SceneData(EmbodiChainDataset): + """Dataset class for the Scene. + + Directory structure: + SceneData/ + factory.glb + kitchen.gltf + office.glb + + Example usage: + >>> from embodichain.data.assets.scene_assets import SceneData + >>> data = SceneData() + or + >>> from embodichain.data import get_data_path + >>> print(get_data_path("Scenedata/factory.glb")) + """ + + def __init__(self, data_root: str = None): + + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/scene_assets/SceneData.zip", + "fb46e4694cc88886fc785704e891a68a", + ) + prefix = "SceneData" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + super().__init__(prefix, data_descriptor, path) + + +class EmptyRoom(o3d.data.DownloadDataset): + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/scene_assets/empty_room.zip", + "612ffead4fac95114bec2e3812469f96", + ) + prefix = "EmptyRoom" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) diff --git a/embodichain/data/assets/w1_assets.py b/embodichain/data/assets/w1_assets.py new file mode 100644 index 00000000..b46e807f --- /dev/null +++ b/embodichain/data/assets/w1_assets.py @@ -0,0 +1,210 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import open3d as o3d +from embodichain.data.dataset import EmbodiChainDataset +from embodichain.data.constants import ( + EMBODICHAIN_DOWNLOAD_PREFIX, + EMBODICHAIN_DEFAULT_DATA_ROOT, +) + +# ================= Dexforce W1 Asset Dataset Overview ================= +# This file provides dataset classes for the Dexforce W1 humanoid robot +# and its individual components. +# +# Main Asset: +# - DexforceW1V021: +# Represents the complete humanoid robot asset, +# including both industrial arms and anthropomorphic arms. +# +# Component Assets: +# - DexforceW1ChassisV021: Chassis component +# - DexforceW1TorsoV021: Torso component +# - DexforceW1EyesV021: Eyes component +# - DexforceW1HeadV021: Head component +# +# Arm Assets: +# - DexforceW1LeftArm1V021 / DexforceW1RightArm1V021: +# Anthropomorphic (human-like) arms, left and right. +# - DexforceW1LeftArm2V021 / DexforceW1RightArm2V021: +# Industrial arms, left and right. +# +# All classes inherit from EmbodiChainDataset and are responsible for +# downloading and managing the data resources for their respective components. +# ====================================================================== + + +class DexforceW1V021(EmbodiChainDataset): + """Dataset class for the Dexforce W1 V021. + + Directory structure: + DexforceW1V021/DexforceW1V021.urdf + + Example usage: + >>> from embodichain.data import get_data_path + >>> print(get_data_path("DexforceW1V021/DexforceW1_v02_1.urdf")) + >>> print(get_data_path("DexforceW1V021/DexforceW1_v02_2.urdf")) + """ + + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/dexforce_w1/DexforceW1V021.zip", + "3cc3a0bfd1c50ebed5bee9dadeee6756", + ) + prefix = "DexforceW1V021" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class DexforceW1V021_INDUSTRIAL_DH_PGC_GRIPPER_M(EmbodiChainDataset): + """Dataset class for the industrial Dexforce W1 V021 with DH_PGC_gripper. + + Directory structure: + DexforceW1V021_INDUSTRIAL_DH_PGC_GRIPPER_M/DexforceW1V021.urdf + + Example usage: + >>> from embodichain.data import get_data_path + >>> print(get_data_path("DexforceW1V021_INDUSTRIAL_DH_PGC_GRIPPER_M/DexforceW1V021.urdf")) + """ + + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/dexforce_w1/DexforceW1V021_INDUSTRIAL_DH_PGC_GRIPPER_M.zip", + "06ec5dfa76dc69160d7ff9bc537a6a7b", + ) + prefix = "DexforceW1V021_INDUSTRIAL_DH_PGC_GRIPPER_M" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class DexforceW1V021_ANTHROPOMORPHIC_BRAINCO_HAND_REVO1(EmbodiChainDataset): + """Dataset class for the anthropomorphic Dexforce W1 V021 with BrainCo_hand_revo_1. + + Directory structure: + DexforceW1V021_ANTHROPOMORPHIC_BRAINCO_HAND_REVO1/DexforceW1V021.urdf + + Example usage: + >>> from embodichain.data import get_data_path + >>> print(get_data_path("DexforceW1V021_ANTHROPOMORPHIC_BRAINCO_HAND_REVO1/DexforceW1V021.urdf")) + """ + + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/dexforce_w1/DexforceW1V021_ANTHROPOMORPHIC_BRAINCO_HAND_REVO1.zip", + "ef19d247799e79233863b558c47b32cd", + ) + prefix = "DexforceW1V021_ANTHROPOMORPHIC_BRAINCO_HAND_REVO1" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class DexforceW1ChassisV021(EmbodiChainDataset): + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/dexforce_w1/W1_Chassis_v021.zip", + "6b0517a4d92a572988641d46269d063f", + ) + prefix = "DexforceW1ChassisV021" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class DexforceW1TorsoV021(EmbodiChainDataset): + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/dexforce_w1/W1_Torso_v021.zip", + "4f762a3ae6ef2acbe484c915cf80da7b", + ) + prefix = "DexforceW1TorsoV021" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class DexforceW1EyesV021(EmbodiChainDataset): + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/dexforce_w1/W1_Eyes_v021.zip", + "80e0b86ef2e934f439c99b79074f6f3c", + ) + prefix = "DexforceW1EyesV021" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class DexforceW1HeadV021(EmbodiChainDataset): + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/dexforce_w1/W1_Head_v021.zip", + "ba72805828c5fd62ad55d6a1458893d0", + ) + prefix = "DexforceW1HeadV021" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class DexforceW1LeftArm1V021(EmbodiChainDataset): + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/dexforce_w1/W1_LeftArm_1_v021.zip", + "c3cacda7bd36389ed98620047bff6216", + ) + prefix = "DexforceW1LeftArm1V021" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class DexforceW1RightArm1V021(EmbodiChainDataset): + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/dexforce_w1/W1_RightArm_1_v021.zip", + "456c9495748171003246a3f6626bb0db", + ) + prefix = "DexforceW1RightArm2V021" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class DexforceW1LeftArm2V021(EmbodiChainDataset): + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/dexforce_w1/W1_LeftArm_2_v021.zip", + "b99bd0587cc9a36fed3cdaa4f9fd62e7", + ) + prefix = "DexforceW1LeftArm2V021" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) + + +class DexforceW1RightArm2V021(EmbodiChainDataset): + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + "https://huggingface.co/datasets/dexforce/embodichain_data/resolve/main/dexforce_w1/W1_RightArm_2_v021.zip", + "d9f25b2d5244ca5a859040327273a99e", + ) + prefix = "DexforceW1RightArm1V021" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) diff --git a/embodichain/data/constants.py b/embodichain/data/constants.py new file mode 100644 index 00000000..3a4aa144 --- /dev/null +++ b/embodichain/data/constants.py @@ -0,0 +1,21 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + + +from pathlib import Path + +EMBODICHAIN_DOWNLOAD_PREFIX = "http://192.168.3.120/CoreEngine/Data/embodychain_data/" +EMBODICHAIN_DEFAULT_DATA_ROOT = str(Path.home() / ".cache" / "embodichain_data") diff --git a/embodichain/data/data_engine/__init__.py b/embodichain/data/data_engine/__init__.py new file mode 100644 index 00000000..e4655620 --- /dev/null +++ b/embodichain/data/data_engine/__init__.py @@ -0,0 +1,15 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- diff --git a/embodichain/data/data_engine/data_dict_extractor.py b/embodichain/data/data_engine/data_dict_extractor.py new file mode 100644 index 00000000..d56a5a94 --- /dev/null +++ b/embodichain/data/data_engine/data_dict_extractor.py @@ -0,0 +1,1135 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from embodichain.utils.logger import log_warning, log_error + +try: + import h5ffmpeg as hf + + has_h5ffmpeg = True +except Exception as e: + has_h5ffmpeg = False + log_warning("Fail to import h5ffmpeg.") + +import h5py +import os +import random +import torch +import numpy as np + +from functools import cached_property +from typing import Dict, Any, List, Union, Optional +from embodichain.data.enum import ( + Modality, + PrivilegeType, + JointType, +) +from embodichain.lab.sim.objects import Robot +from embodichain.lab.sim.sensors import StereoCamera +from embodichain.lab.gym.envs import BaseEnv, EmbodiedEnv +from embodichain.lab.gym.utils.gym_utils import map_qpos_to_eef_pose +from embodichain.utils.utility import get_right_name +from embodichain.lab.gym.utils.misc import _data_key_to_control_part +from embodichain.utils import logger +from embodichain.data.data_engine.unified_state import ( + StateUnifier, +) +from embodichain.data.enum import ( + SUPPORTED_PROPRIO_TYPES, + SUPPORTED_ACTION_TYPES, + SUPPORTED_EXTRA_VISION_TYPES, +) +from tqdm import tqdm +from copy import deepcopy + +SCALE_FACTOR = 4e3 # Scale factor for depth data +FAR_CLIP = 4.0 # m + +DATA_FORMATS = { + "observations": { + Modality.IMAGES.value: {}, + Modality.GEOMAP.value: {}, + PrivilegeType.MASK.value: {}, + PrivilegeType.EXTEROCEPTION.value: {}, + Modality.STATES.value: {}, + }, + Modality.ACTIONS.value: {}, +} + + +class CompressedVideoHDF5: + def __init__(self, save_path: str, chunks: int = 20) -> None: + """ + Initializes the data dictionary extractor with the specified save path and number of chunks. + Attempts to configure video encoding settings based on the detected GPU model using the h5ffmpeg library. + Supported GPUs include NVIDIA A800 and NVIDIA GeForce RTX 3060, with specific encoding configurations for each. + If the GPU is unsupported or an error occurs during initialization, a warning is logged and default configuration is used. + + Args: + save_path (str): Path where extracted data will be saved. + chunks (int, optional): Number of chunks to split the data into. Defaults to 20. + + Raises: + ValueError: If the detected GPU is not supported. + """ + self.save_path = save_path + self.chunks = chunks + + try: + import h5ffmpeg as hf + import torch + + name = torch.cuda.get_device_name() + + if "A800" in name or name == "NVIDIA A800-SXM4-80GB": + self.conf = { + Modality.GEOMAP.value: hf.x264( + preset="veryfast", tune="fastdecode" + ), + Modality.IMAGES.value: hf.x264( + preset="veryfast", tune="fastdecode" + ), + PrivilegeType.MASK.value: hf.x264( + preset="veryslow", tune="ssim", crf=0 + ), + } + elif "3060" in name or name == "NVIDIA GeForce RTX 3060": + self.conf = { + Modality.GEOMAP.value: hf.h264_nvenc(), + Modality.IMAGES.value: hf.h264_nvenc(), + PrivilegeType.MASK.value: hf.h264_nvenc(), + } + elif "3090" in name or name == "NVIDIA GeForce RTX 3090": + self.conf = { + Modality.GEOMAP.value: hf.x264( + preset="veryfast", tune="fastdecode" + ), + Modality.IMAGES.value: hf.x264( + preset="veryfast", tune="fastdecode" + ), + PrivilegeType.MASK.value: hf.x264( + preset="veryslow", tune="ssim", crf=0 + ), + } + elif "4090" in name or name == "NVIDIA GeForce RTX 4090": + self.conf = { + Modality.GEOMAP.value: hf.x264( + preset="veryfast", tune="fastdecode" + ), + Modality.IMAGES.value: hf.x264( + preset="veryfast", tune="fastdecode" + ), + PrivilegeType.MASK.value: hf.x264( + preset="veryslow", tune="ssim", crf=0 + ), + } + else: + raise ValueError("Unsupported GPU: {}".format(name)) + + except Exception as e: + log_warning( + "{}. Please make sure h5ffmpeg is successfully installed.".format(e) + ) + self.conf = {} + + @staticmethod + def is_compressed_hdf5(data: Dict) -> bool: + images_group = data.get("observations", {}).get(Modality.IMAGES.value, {}) + has_compressed_keys = any( + (isinstance(k, str) and ("index" in k or "start" in k)) + for k in images_group.keys() + ) + return has_compressed_keys + + @staticmethod + def get_chunk_name(name: str, id: Union[int, str]) -> str: + """ + Generates a chunk name by concatenating the given name with the provided id, separated by an underscore. + Args: + name (str): The base name for the chunk. + id (Union[int, str]): The identifier to append to the name. + Returns: + str: The resulting chunk name in the format 'name_id'. + """ + + return name + "_{}".format(id) + + @staticmethod + def video_save( + f, + chunks: int, + data: Dict[str, np.ndarray], + key: str, + dtype=np.uint8, + conf: Dict = None, + ): + """ + Saves video data from multiple cameras into an HDF5 file, splitting the data into chunks for efficient storage. + Args: + f: An open HDF5 file handle where the video data will be saved. + data (Dict[str, np.ndarray]): Dictionary mapping camera names to their corresponding video data arrays. + key (str): Key under "observations" group in the HDF5 file to store the video data. + dtype (type, optional): Data type to convert the video frames to before saving (default: np.uint8). + conf (Dict, optional): Additional configuration parameters for HDF5 dataset creation. + Notes: + - Video data for each camera is processed and split into the specified number of chunks. + - Index and start datasets are created for each camera to map frame indices to chunk IDs and chunk start indices. + - Uses CompressedVideoHDF5 utility functions for data formatting and conversion. + - Progress is displayed using tqdm for each chunk being saved. + """ + import h5ffmpeg as hf + + f_images = f["observations"].create_group(key) + + for cam_name in data.keys(): + data_ = data[cam_name] + if len(data_) != 0: + data_ = CompressedVideoHDF5.to_bhw(data_) + + if dtype == np.uint16: + data_ = CompressedVideoHDF5.uint16_depth(data_) + else: + data_ = data_.astype(dtype) + + data_chunks = np.array_split(data_, chunks, axis=0) + data_chunk_ids = np.arange(data_.shape[0]) + data_chunk_ids_ = np.array_split(data_chunk_ids, chunks) + idtochunkid = np.zeros((data_.shape[0])) + chunkid2startid = np.zeros((chunks,)) + for chunkid, temp in enumerate(data_chunk_ids_): + chunkid2startid[chunkid] = min(temp) + for tempi in temp: + idtochunkid[tempi] = chunkid + _ = f_images.create_dataset( + CompressedVideoHDF5.get_chunk_name(cam_name, "index"), + data=idtochunkid, + ) + _ = f_images.create_dataset( + CompressedVideoHDF5.get_chunk_name(cam_name, "start"), + data=chunkid2startid, + ) + + for t, data_chunk in enumerate(tqdm(data_chunks)): + _ = f_images.create_dataset( + "{}/{}".format(cam_name, t), + data=data_chunk, + chunks=data_chunk.shape, + **conf, + ) + + @staticmethod + def uint16_depth( + data: np.ndarray, scale_factor: float = SCALE_FACTOR, far_clip: float = FAR_CLIP + ) -> np.ndarray: + """ + Converts a depth data array to a uint16 format after applying scaling and clipping. + Args: + data (np.ndarray): The input depth data as a NumPy array. + scale_factor (float, optional): The factor by which to scale the depth data. + Defaults to SCALE_FACTOR. + far_clip (float, optional): The maximum depth value (far clipping plane) + before scaling. Defaults to FAR_CLIP. + Returns: + np.ndarray: The scaled and clipped depth data as a NumPy array of type uint16. + """ + return (np.clip(data * scale_factor, 0, far_clip * scale_factor)).astype( + np.uint16 + ) + + @staticmethod + def float32_depth( + data: np.ndarray, scale_factor: float = SCALE_FACTOR, far_clip: float = FAR_CLIP + ) -> np.ndarray: + """ + Converts depth data to float32 and scales it by the given scale factor. + Args: + data (np.ndarray): The input depth data array. + scale_factor (float, optional): The factor by which to scale the depth values. Defaults to SCALE_FACTOR. + far_clip (float, optional): The far clipping distance (unused in this function). Defaults to FAR_CLIP. + Returns: + np.ndarray: The scaled depth data as a float32 numpy array. + """ + + return data.astype(np.float32) / scale_factor + + @staticmethod + def to_bhw(data: np.ndarray) -> np.ndarray: + """ + Reshapes a 4D numpy array from (vdepth, height, width, channels) to (vdepth, height, width * channels). + If the input is already a 3D array, returns it unchanged. + Args: + data (np.ndarray): Input array of shape (vdepth, height, width, channels) or (vdepth, height, width). + Returns: + np.ndarray: Reshaped array of shape (vdepth, height, width * channels) or the original array if 3D. + Raises: + Logs an error if the input array does not have 3 or 4 dimensions. + """ + + if len(data.shape) == 4: + vdepth, h, w, channels = ( + data.shape[0], + data.shape[1], + data.shape[2], + data.shape[3], + ) + return data.reshape(vdepth, h, w * channels) + elif len(data.shape) == 3: + return data + else: + log_error("Unsupported data shape: {}".format(data.shape)) + + @staticmethod + def to_bhwc(data: np.ndarray): + """ + Converts a numpy array to BHWC (Batch, Height, Width, Channels) format. + If the input array has 3 dimensions, it reshapes the array to have a channel dimension of size 3. + If the input array already has 4 dimensions, it returns the array unchanged. + Otherwise, logs an error for unsupported shapes. + Args: + data (np.ndarray): Input numpy array to be converted. + Returns: + np.ndarray: Array in BHWC format. + Raises: + Logs an error if the input array shape is not supported. + """ + + if len(data.shape) == 3: + vdepth, h, w = data.shape + return data.reshape(vdepth, h, -1, 3) + elif len(data.shape) == 4: + return data + else: + log_error("Unsupported data shape: {}".format(data.shape)) + + def dump( + self, + ret: Dict, + video_names: List[str] = [ + Modality.IMAGES.value, + PrivilegeType.MASK.value, + Modality.GEOMAP.value, + ], + dtypes: List = [np.uint8, np.uint8, np.uint16], + ): + """ + Dumps the provided data into an HDF5 file, saving specific video data with + compression and specified data types. + Args: + ret (Dict): The data dictionary containing observations and other metadata. + video_names (List[str], optional): A list of video names to extract from + the observations. Defaults to [Modality.IMAGES.value, PrivilegeType.MASK.value, Modality.GEOMAP.value]. + dtypes (List, optional): A list of data types corresponding to each video + name. Defaults to [np.uint8, np.uint8, np.uint16]. + Raises: + AssertionError: If the lengths of `video_names` and `dtypes` are not equal. + RuntimeError: If the configuration (`self.conf`) is empty, indicating that + `h5ffmpeg` is not installed or configured properly. + Notes: + - The method modifies the `ret` dictionary by temporarily removing the + specified video data during the HDF5 file creation process and then + restoring it afterward. + - The `hdfdict.dump` function is used to save the remaining data in the + dictionary, while the `CompressedVideoHDF5.video_save` function handles + the saving of video data with compression. + """ + + assert len(video_names) == len( + dtypes + ), "Inequal length of video names {} and dtypes {}.".format(video_names, dtypes) + import hdfdict + + if self.conf == {}: + raise RuntimeError( + "Please make sure h5ffmpeg is successfully installed before using `dump`." + ) + + pop_ret = {} + for video_name, dtype in zip(video_names, dtypes): + video_data = ret["observations"].pop(video_name) + pop_ret[video_name] = video_data + + # Open the file once and pass the open file object to hdfdict.dump so + # h5py doesn't try to truncate the same path while it is already open. + with h5py.File(self.save_path, "w") as f: + hdfdict.dump(ret, f) + for video_name, dtype in zip(video_names, dtypes): + CompressedVideoHDF5.video_save( + f, + self.chunks, + pop_ret[video_name], + video_name, + dtype=dtype, + conf=self.conf[video_name], + ) + + ret["observations"].update(pop_ret) + + @staticmethod + def decode_resources( + f: Dict, + ret: Dict, + name: str, + slice_id: int, + condition: callable, + function: callable, + padding: bool = True, + chunk_id: int = None, + ): + """ + Decodes and processes resources from a hierarchical data structure, applying + a condition and transformation function to the data, and optionally adding + zero-padding. + Args: + f (Dict): The input data dictionary containing observations and metadata. + ret (Dict): The output data dictionary where processed data will be stored. + name (str): The key name under "observations" to access specific data. + slice_id (int): The slice index used to retrieve the corresponding chunk ID. + condition (callable): A function that takes the data as input and returns + a boolean indicating whether the transformation function should be applied. + function (callable): A function to transform the data if the condition is met. + padding (bool, optional): Whether to add zero-padding to the data. Defaults to True. + chunk_id (int, optional): The chunk ID to use instead of deriving it from the slice ID. + Defaults to None. + Returns: + None: The function modifies the `ret` dictionary in place. + """ + + import time + + images = f["observations"][name] + + for cam_name in images.keys(): + if "index" in cam_name: + continue + if "start" in cam_name: + continue + + start_time = time.time() + sliceid2chunkid = images[ + CompressedVideoHDF5.get_chunk_name(cam_name, "index") + ][:] + chunkid = int(sliceid2chunkid[slice_id]) if chunk_id is None else chunk_id + data_ = images[cam_name][str(chunkid)][:] + # log_warning("".format(time.time() - start_time) + if condition(data_): + data_ = function(data_) + + if padding: + chunkid2startid = images[ + CompressedVideoHDF5.get_chunk_name(cam_name, "start") + ][:] + start_idx = chunkid2startid[chunkid] + zero_padding = np.zeros_like(data_)[0:1] + zero_padding = np.repeat(zero_padding, repeats=start_idx, axis=0) + ret["observations"][name][cam_name] = np.concatenate( + [zero_padding, data_], 0 + ) + else: + if ret["observations"][name][cam_name] is None: + ret["observations"][name][cam_name] = data_ + else: + ret["observations"][name][cam_name] = np.concatenate( + [ret["observations"][name][cam_name], data_], 0 + ) + + def safe_filter(self, f: Dict, slice_id: int = None) -> Dict: + """ + Filters and processes the input data dictionary based on the configuration + and specified slice ID. + Args: + f (Dict): The input data dictionary containing observations, including + images, masks, and geomap. + slice_id (int, optional): The specific slice ID to process. If None, + processes all chunks. Defaults to None. + Returns: + Dict: The filtered and processed data dictionary with updated + observations for images, masks, and geomap. + Notes: + - The method filters out camera names containing "index" or "start". + - It initializes the return dictionary with None values for images, + masks, and geomap for the filtered camera names. + - Depending on the `slice_id`, it either processes all chunks or a + specific slice using the `CompressedVideoHDF5.decode_resources` + method. + - The processed observations are updated in the input dictionary `f`. + """ + + if self.conf is {}: + return f + + cam_names = [] + for cam_name in f["observations"][Modality.IMAGES.value].keys(): + if "index" in cam_name: + continue + if "start" in cam_name: + continue + cam_names.append(cam_name) + + # Only build return structure for actually present modalities, avoid errors when real data lacks mask/geomap + present_modalities = [] + if Modality.IMAGES.value in f["observations"]: + present_modalities.append(Modality.IMAGES.value) + if PrivilegeType.MASK.value in f["observations"]: + present_modalities.append(PrivilegeType.MASK.value) + if Modality.GEOMAP.value in f["observations"]: + present_modalities.append(Modality.GEOMAP.value) + + ret = {"observations": {}} + for modality_key in present_modalities: + ret["observations"][modality_key] = { + cam_name: None for cam_name in cam_names + } + + if slice_id == None: + # For all chunks + for chunk_id_ in range(self.chunks): + if Modality.IMAGES.value in present_modalities: + CompressedVideoHDF5.decode_resources( + f, + ret, + Modality.IMAGES.value, + None, + lambda x: len(x.shape) == 3, + self.to_bhwc, + chunk_id=chunk_id_, + padding=False, + ) + if PrivilegeType.MASK.value in present_modalities: + CompressedVideoHDF5.decode_resources( + f, + ret, + PrivilegeType.MASK.value, + None, + lambda x: len(x.shape) == 3, + self.to_bhwc, + chunk_id=chunk_id_, + padding=False, + ) + if Modality.GEOMAP.value in present_modalities: + CompressedVideoHDF5.decode_resources( + f, + ret, + Modality.GEOMAP.value, + None, + lambda x: x.dtype == np.uint16 and len(x) != 0, + self.float32_depth, + chunk_id=chunk_id_, + padding=False, + ) + + else: + if Modality.IMAGES.value in present_modalities: + CompressedVideoHDF5.decode_resources( + f, + ret, + Modality.IMAGES.value, + slice_id, + lambda x: len(x.shape) == 3, + self.to_bhwc, + ) + if PrivilegeType.MASK.value in present_modalities: + CompressedVideoHDF5.decode_resources( + f, + ret, + PrivilegeType.MASK.value, + slice_id, + lambda x: len(x.shape) == 3, + self.to_bhwc, + ) + if Modality.GEOMAP.value in present_modalities: + CompressedVideoHDF5.decode_resources( + f, + ret, + Modality.GEOMAP.value, + slice_id, + lambda x: x.dtype == np.uint16 and len(x) != 0, + self.float32_depth, + ) + if Modality.IMAGES.value in present_modalities: + f["observations"][Modality.IMAGES.value] = ret["observations"][ + Modality.IMAGES.value + ] + if PrivilegeType.MASK.value in present_modalities: + f["observations"][PrivilegeType.MASK.value] = ret["observations"][ + PrivilegeType.MASK.value + ] + if Modality.GEOMAP.value in present_modalities: + f["observations"][Modality.GEOMAP.value] = ret["observations"][ + Modality.GEOMAP.value + ] + + return f + + +class ActStateStatistic: + def __init__(self, data_dict: Dict, min_len_steps: int) -> None: + self.data_dict = data_dict + self.min_len_steps = min_len_steps + + def prepare_state_and_action( + self, + ): + proprio = self.data_dict["observations"][Modality.STATES.value][:] + num_steps = proprio.shape[0] + # [Optional] We drop too-short episode + if num_steps < self.min_len_steps: + return False, None + # [Optional] We skip the first few still steps + EPS = 1e-2 + # Get the idx of the first qpos whose delta exceeds the threshold + proprio_delta = np.abs(proprio - proprio[0:1]) + indices = np.where(np.any(proprio_delta > EPS, axis=1))[0] + if len(indices) > 0: + first_idx = indices[0] + else: + raise ValueError("Found no qpos that exceeds the threshold.") + target_actions = self.data_dict[Modality.ACTIONS.value][:] + # Parse the state and action + state = proprio[first_idx - 1 :] + action = target_actions[first_idx - 1 :] + # Return the resulting sample + + return True, {Modality.STATES.value: state, Modality.ACTIONS.value: action} + + def statistic( + self, + ) -> Dict: + EPS = 1e-8 + episode_cnt = 0 + state_sum = 0 + state_sum_sq = 0 + z_state_sum = 0 + z_state_sum_sq = 0 + state_cnt = 0 + nz_state_cnt = None + state_max = None + state_min = None + _, episode = self.prepare_state_and_action() + episode_cnt += 1 + + states = episode[Modality.STATES.value] + + # Zero the values that are close to zero + z_states = states.copy() + z_states[np.abs(states) <= EPS] = 0 + # Compute the non-zero count + if nz_state_cnt is None: + nz_state_cnt = np.zeros(states.shape[1]) + nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0) + + # Update statistics + state_sum += np.sum(states, axis=0) + state_sum_sq += np.sum(states**2, axis=0) + z_state_sum += np.sum(z_states, axis=0) + z_state_sum_sq += np.sum(z_states**2, axis=0) + state_cnt += states.shape[0] + if state_max is None: + state_max = np.max(states, axis=0) + state_min = np.min(states, axis=0) + else: + state_max = np.maximum(state_max, np.max(states, axis=0)) + state_min = np.minimum(state_min, np.min(states, axis=0)) + + # Add one to avoid division by zero + nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt)) + + result = { + "state_mean": (state_sum / state_cnt).tolist(), + "state_std": np.sqrt( + np.maximum( + (z_state_sum_sq / nz_state_cnt) + - (z_state_sum / state_cnt) ** 2 * (state_cnt / nz_state_cnt), + np.zeros_like(state_sum_sq), + ) + ).tolist(), + "state_min": state_min.tolist(), + "state_max": state_max.tolist(), + } + + return result + + +class DataDictExtractor: + def __init__( + self, + env: Union[BaseEnv, EmbodiedEnv], + save_path: str = None, + compression_opts: int = 9, + ): + self.env = env + self.save_path = save_path + self.data = {} + + # save all supported proprio and action types. + robot_meta_config = deepcopy(self.env.metadata["dataset"]["robot_meta"]) + robot_meta_config["observation"][ + Modality.STATES.value + ] = SUPPORTED_PROPRIO_TYPES + robot_meta_config[Modality.ACTIONS.value] = SUPPORTED_ACTION_TYPES + + self.state_unifier = StateUnifier(robot_meta=robot_meta_config) + self.compression_opts = compression_opts + + @cached_property + def robot_control_parts(self) -> List[str]: + """Get the robot's control parts. + + Note: + If control_parts is specified in the robot metadata, return those parts. + Otherwise, return all control parts. + + Returns: + List[str]: The robot's control parts. + """ + robot_meta_config = self.env.metadata["dataset"]["robot_meta"] + control_parts = robot_meta_config.get("control_parts", None) + if control_parts is None: + return [] + else: + return control_parts + + def _get_arm_control_parts(self) -> List[str]: + control_parts = self.robot_control_parts + arm_control_parts = [] + for part in control_parts: + if "arm" in part: + arm_control_parts.append(part) + return arm_control_parts + + def _has_exteroception(self) -> bool: + robot_meta_config = self.env.metadata["dataset"]["robot_meta"] + return PrivilegeType.EXTEROCEPTION.value in robot_meta_config["observation"] + + def extract( + self, + obs_list: List[Dict[str, Any]], + action_list: List[Dict[str, Any]], + data_dict: Dict = DATA_FORMATS, + save: bool = True, + ): + if save: + assert ( + self.save_path is not None + ), "Please provide a save path for the dataset." + data_dict = deepcopy(data_dict) + + self._init_data(data_dict) + + ret = {} + robot_meta_config = self.env.metadata["dataset"]["robot_meta"] + + for i, (obs, action) in enumerate(zip(obs_list, action_list)): + self._extract_vision_obs(obs, data_dict) + self._extract_proprioception(obs, data_dict) + self._extract_action(action, data_dict) + + action = self._collate_action(data_dict) + proprio = self._collate_proprio(data_dict) + robot_meta = self._collate_metainfo() + + extra_vision_config = robot_meta_config["observation"]["vision"] + obs = {"observations": {}} + images = self.collate_sub_anns( + data_dict, extra_vision_config, Modality.IMAGES.value + ) + obs["observations"].update(proprio) + obs["observations"].update(images) + + extra_vision_names = list( + set([name for list in extra_vision_config.values() for name in list]) + ) + for extra_vision_name in extra_vision_names: + extra_vision_obs = self.collate_sub_anns( + data_dict, extra_vision_config, extra_vision_name + ) + obs["observations"].update(extra_vision_obs) + + ret.update(robot_meta) + ret.update(obs) + ret.update(action) + + statistics = ActStateStatistic( + ret, self.env.metadata["dataset"]["robot_meta"]["min_len_steps"] + ).statistic() + ret.update(statistics) + + if save: + if has_h5ffmpeg: + cvhdf5 = CompressedVideoHDF5(self.save_path) + all_video_names = [Modality.IMAGES.value] + [ + name + for name in extra_vision_names + if name != PrivilegeType.EXTEROCEPTION.value + ] + all_dtypes = [ + np.uint16 if name == Modality.GEOMAP.value else np.uint8 + for name in all_video_names + ] + cvhdf5.dump(ret, video_names=all_video_names, dtypes=all_dtypes) + else: + logger.log_info( + "h5ffmpeg is not installed, saving dataset without compression." + ) + import hdfdict + + # Open the file once and pass the file object to hdfdict.dump to + # avoid opening/truncating the same file path twice which causes + # "unable to truncate a file which is already open" errors on + # some platforms and HDF5 builds. + with h5py.File(self.save_path, "w") as f: + hdfdict.dump(ret, f) + + return ret + + def _init_data(self, data_dict: Dict): + robot_meta_config = self.env.metadata["dataset"]["robot_meta"] + extra_vision_config = robot_meta_config["observation"]["vision"] + + for proprio_name in SUPPORTED_PROPRIO_TYPES: + data_dict["observations"][Modality.STATES.value][proprio_name] = [] + for action_name in SUPPORTED_ACTION_TYPES: + data_dict[Modality.ACTIONS.value][action_name] = [] + + for camera_name, extra_vision_list in extra_vision_config.items(): + is_stereo = isinstance(self.env.get_sensor(camera_name), StereoCamera) + + data_dict["observations"][Modality.IMAGES.value][camera_name] = [] + if is_stereo: + data_dict["observations"][Modality.IMAGES.value][ + get_right_name(camera_name) + ] = [] + + for extra_vision_name in extra_vision_list: + if extra_vision_name in SUPPORTED_EXTRA_VISION_TYPES: + data_dict["observations"][extra_vision_name][camera_name] = [] + else: + log_error( + f"Extra vision observation name {extra_vision_name} is not in SUPPORTED_EXTRA_VISION_TYPES {SUPPORTED_EXTRA_VISION_TYPES}, please check again." + ) + if is_stereo: + data_dict["observations"][extra_vision_name][ + get_right_name(camera_name) + ] = [] + + def _extract_vision_obs(self, obs: Dict[str, Any], data_dict: Dict): + robot_meta_config = self.env.metadata["dataset"]["robot_meta"] + extra_vision_config = robot_meta_config["observation"]["vision"] + + for camera_name, extra_vision_list in extra_vision_config.items(): + if camera_name in obs["sensor"]: + is_stereo = isinstance(self.env.get_sensor(camera_name), StereoCamera) + + data_dict["observations"][Modality.IMAGES.value][camera_name].append( + obs["sensor"][camera_name]["color"] + .squeeze(0)[:, :, :3] + .cpu() + .numpy() + ) + if is_stereo: + # save rgb right + data_dict["observations"][Modality.IMAGES.value][ + get_right_name(camera_name) + ].append( + obs["sensor"][camera_name]["color_right"] + .squeeze_(0)[:, :, :3] + .cpu() + .numpy() + ) + + for extra_vision_name in extra_vision_list: + if extra_vision_name in SUPPORTED_EXTRA_VISION_TYPES: + if extra_vision_name == PrivilegeType.EXTEROCEPTION.value: + if is_stereo: + data_dict["observations"][extra_vision_name][ + camera_name + ].append( + obs[extra_vision_name][camera_name]["l"] + .cpu() + .numpy() + ) + data_dict["observations"][extra_vision_name][ + get_right_name(camera_name) + ].append( + obs[extra_vision_name][camera_name]["r"] + .cpu() + .numpy() + ) + elif camera_name in obs.get(extra_vision_name, {}): + data_dict["observations"][extra_vision_name][ + camera_name + ].append( + obs[extra_vision_name][camera_name].cpu().numpy() + ) + elif extra_vision_name == PrivilegeType.MASK.value: + # save semantic mask for monocular cameras + data_dict["observations"][extra_vision_name][ + camera_name + ].append( + obs["sensor"][camera_name]["semantic_mask_l"] + .squeeze_(0) + .numpy() + .astype(np.uint8) + ) + if is_stereo: + data_dict["observations"][extra_vision_name][ + get_right_name(camera_name) + ].append( + obs["sensor"][camera_name]["semantic_mask_r"] + .squeeze_(0) + .numpy() + .astype(np.uint8) + ) + elif extra_vision_name == Modality.GEOMAP.value: + if not is_stereo: + log_error( + f"Camera {camera_name} is not stereo, while '{extra_vision_name}' is in gym_config.dataset.robot_meta.vision, please check again." + ) + if "depth" in obs["sensor"][camera_name]: + data_dict["observations"][extra_vision_name][ + camera_name + ].append( + obs["sensor"][camera_name]["depth"] + .squeeze_() + .numpy() + ) + else: + log_error( + f"obs['sensor'][{camera_name}] has no key named 'depth' while it's required in gym_config.dataset.robot_meta.vision, please check again." + ) + else: + log_error( + f"Extra vision observation name {extra_vision_name} is not in SUPPORTED_EXTRA_VISION_TYPES {SUPPORTED_EXTRA_VISION_TYPES}, please check again." + ) + else: + logger.log_error( + f"Camera {camera_name} not found in observations, please check your sensor configuration in gym_config.json" + ) + + def _extract_action( + self, + action: torch.Tensor, + data_dict: Dict, + ): + robot: Robot = self.env.robot + + for key in data_dict[Modality.ACTIONS.value].keys(): + part = _data_key_to_control_part( + robot=robot, + control_parts=self.env.metadata["dataset"]["robot_meta"].get( + "control_parts", [] + ), + data_key=key, + ) + if part is None: + continue + indices = robot.get_joint_ids(part, remove_mimic=True) + data_dict[Modality.ACTIONS.value][key].append( + action[0, indices].cpu().numpy() + if isinstance(action, torch.Tensor) + else action[0, indices] + ) + + eef_pose_dict = map_qpos_to_eef_pose( + robot, action, control_parts=self._get_arm_control_parts() + ) + for key, val in eef_pose_dict.items(): + data_dict[Modality.ACTIONS.value][key].append( + val.squeeze_(0).cpu().numpy() + if isinstance(val, torch.Tensor) + else val.squeeze_(0) + ) + + def _extract_proprioception( + self, + obs: Dict[str, Any], + data_dict: Dict, + ): + robot: Robot = self.env.robot + + qpos = obs["robot"][JointType.QPOS.value] + for key in data_dict["observations"][Modality.STATES.value].keys(): + part = _data_key_to_control_part( + robot=robot, + control_parts=self.env.metadata["dataset"]["robot_meta"].get( + "control_parts", [] + ), + data_key=key, + ) + if part is None: + continue + indices = robot.get_joint_ids(part, remove_mimic=True) + data_dict["observations"][Modality.STATES.value][key].append( + qpos[0][indices].cpu().numpy() + ) + + eef_pose_dict = map_qpos_to_eef_pose( + robot, qpos, control_parts=self._get_arm_control_parts() + ) + for key, val in eef_pose_dict.items(): + data_dict["observations"][Modality.STATES.value][key].append( + val.squeeze_(0).cpu().numpy() + ) + + def _collate_proprio(self, data_dict: Dict) -> Dict: + proprio_dict = {} + for proprio_name in self.state_unifier.proprio_meta: + proprio = np.array( + data_dict["observations"][Modality.STATES.value][proprio_name] + ) + proprio_dict[proprio_name] = proprio + proprios = self.state_unifier.fill_in_state(proprio_dict) + return {Modality.STATES.value: proprios} + + def _collate_metainfo( + self, + ) -> Dict: + meta_info = { + "arm_dofs": self.env.metadata["dataset"]["robot_meta"].get("arm_dofs", 12), + "observation": self.env.metadata["dataset"]["robot_meta"].get( + "observation", {} + ), + "min_len_steps": self.env.metadata["dataset"]["robot_meta"].get( + "min_len_steps", 125 + ), + } + return { + "robot_meta": meta_info, + "instruction": { + "lang": self.env.metadata["dataset"]["instruction"].get("lang", "") + }, + } + + def _collate_action(self, data_dict: Dict) -> Dict: + action_data_dict = data_dict[Modality.ACTIONS.value] + for k, v in action_data_dict.items(): + action_data_dict[k] = np.array(v) + + action_dict = {} + action_dict.update(action_data_dict) + action = self.state_unifier.fill_in_action(action_dict) + return {Modality.ACTIONS.value: action} + + @staticmethod + def collate_sub_anns( + data_dict: Dict, + extra_vision_config: Dict, + key: str = Modality.IMAGES.value, + ) -> Dict: + ret = {key: {}} + for camera_name in extra_vision_config: + images_list = data_dict["observations"][key].pop(camera_name, None) + if images_list is None: + continue + if len(images_list) > 0: + ret[key][camera_name] = np.empty( + (len(images_list),) + images_list[0].shape, + dtype=images_list[0].dtype, + ) + for idx, image in enumerate(images_list): + ret[key][camera_name][idx] = image + else: + ret[key][camera_name] = np.array([]) + + del images_list + if get_right_name(camera_name) in data_dict["observations"][key]: + images_right_list = data_dict["observations"][key].pop( + get_right_name(camera_name), None + ) + if images_right_list is None: + continue + if len(images_right_list) > 0: + ret[key][get_right_name(camera_name)] = np.empty( + (len(images_right_list),) + images_right_list[0].shape, + dtype=images_right_list[0].dtype, + ) + for idx, image in enumerate(images_right_list): + ret[key][get_right_name(camera_name)][idx] = image + else: + ret[key][get_right_name(camera_name)] = np.array([]) + del images_right_list + + return ret + + +def fetch_imitation_dataset( + env: BaseEnv, + obs_list: List[Dict[str, Any]], + action_list: List[Dict[str, Any]], + id: str, + folder_name: str, +) -> Dict: + """ + Save imitation dataset for a single episode. + + Args: + env (BaseEnv): Environment instance. + obs_list (List[Dict]): List of observation dicts. + action_list (List[Dict]): List of action dicts. + id (str): Unique identifier for the episode. + folder_name (str): Folder name for saving the dataset. + + Returns: + dict: Contains data_path, id, current_episode, and extracted data. + """ + # Get dataset save path + dataset_path = env.metadata["dataset"].get("save_path", None) + if dataset_path is None: + from embodichain.data import database_demo_dir + + dataset_path = database_demo_dir + + # Create folder if first episode + dataset_save_path = os.path.join(dataset_path, folder_name) + if env.curr_episode == 0 and id: + os.makedirs(dataset_save_path, exist_ok=True) + + # Check robot dof validity + try: + robot: Robot = env.robot + assert ( + env.metadata["dataset"]["robot_meta"]["arm_dofs"] <= robot.dof + ), f"Control dof {env.metadata['dataset']['robot_meta']['arm_dofs']} must be less than {robot.dof}." + except Exception as e: + logger.log_error(f"Robot DOF check failed: {e}") + return None + + # Select data format + data_format = DATA_FORMATS + + # Extract and save data + if id is None: + ret = DataDictExtractor(env).extract( + obs_list, action_list, save=False, data_dict=data_format + ) + save_path = None + else: + save_path = os.path.join(dataset_save_path, id + ".hdf5") + logger.log_info(f"Save episode {env.curr_episode} to '{save_path}'") + ret = DataDictExtractor(env, save_path).extract( + obs_list, action_list, save=True, data_dict=data_format + ) + + # Update episode count + env.curr_episode += 1 + + # Return result dict + return { + "data_path": dataset_save_path, + "id": id, + "current_episode": env.curr_episode, + "data": ret, + "save_path": save_path, + } diff --git a/embodichain/data/data_engine/datasets/vla_datasets.py b/embodichain/data/data_engine/datasets/vla_datasets.py new file mode 100644 index 00000000..a3edb854 --- /dev/null +++ b/embodichain/data/data_engine/datasets/vla_datasets.py @@ -0,0 +1,521 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os +import fnmatch +from embodichain.utils.logger import log_warning, log_info + +try: + import h5ffmpeg as hf +except Exception as e: + log_warning("Fail to import h5ffmpeg.") +import h5py +import numpy as np +from typing import Dict, Callable, List, Tuple +from embodichain.utils.utility import get_right_name, pad_to_chunk, convert_bytes +from embodichain.utils.logger import log_warning, log_info +from embodichain.data.enum import Proprioception, Image, Exteroception, ModalInput +from copy import deepcopy +from typing import Dict +from embodichain.data.enum import ( + Modality, + PrivilegeType, + ActionMode, + JointType, + EefType, + CameraName, + TeleoperationData, +) +from embodichain.data.global_mapping import GlobalMapping +from scipy.ndimage import gaussian_filter1d +from embodichain.data.data_engine.unified_state import ActionIndicesGenerator + + +class VLADataset: + """ + This class is used to sample episodes from the embododiment dataset + stored in HDF5. + """ + + def __init__( + self, + data_path: str, + batch_size: int, + chunk_size: int, + state: List, + output: List, + img_history_size: int, + state_history_len: int, + precomp_lang_embed: bool = True, + online_config: Dict = None, + camera_used: List[str] = None, + indices_generator=None, + ) -> None: + # [Modify] The path to the HDF5 dataset directory + # Each HDF5 file contains one episode + + self.precomp_lang_embed = precomp_lang_embed + self.batch_size = batch_size + self.chunk_size = chunk_size + self.state = state + self.output = output + self.img_history_size = img_history_size + self.state_history_len = state_history_len + self.online_config = online_config + self.camera_used = camera_used + self.indices_generator = indices_generator + if self.camera_used is not None: + for cam in CameraName: + if cam.value not in camera_used: + log_warning( + "{} does not exist in {}".format(cam.value, camera_used) + ) + + if self.online_config is not None: + from embodichain.data.data_engine.online.engine import OnlineEngine + + log_info("Init online vla dataset.", color="purple") + self.engine = OnlineEngine(**self.online_config) + self.DATASET_NAME = "online_whatever" + else: + log_info("Init offline vla dataset.", color="purple") + self.engine = None + self.data_path = data_path + assert os.path.exists(self.data_path), "{} does not exist.".format( + self.data_path + ) + if os.path.isabs(self.data_path) is False: + self.data_path = os.path.join(os.getcwd(), self.data_path) + self.DATASET_NAME = os.path.basename(self.data_path) + self.file_paths = [] + for root, _, files in os.walk(self.data_path): + for filename in fnmatch.filter(files, "*.hdf5"): + file_path = os.path.join(root, filename) + self.file_paths.append(file_path) + log_info( + f"Init dataset with size of: {len(self.file_paths)}", color="purple" + ) + + def update_data_size(self): + """Interface for update validation dataset size generated on the fly.""" + self.file_paths = [] + for root, _, files in os.walk(self.data_path): + for filename in fnmatch.filter(files, "*.hdf5"): + file_path = os.path.join(root, filename) + self.file_paths.append(file_path) + log_info(f"Update dataset with size of: {len(self.file_paths)}", color="purple") + + def __len__(self): + return ( + len(self.file_paths) + if self.online_config is None + else np.maximum(self.engine.episode_limit, self.batch_size) + ) + + def get_item(self, index: int = None, chunk_size: int = None): + """Get a training sample at a random timestep. + + Args: + index (int, optional): the index of the episode. + If not provided, a random episode will be selected. + state_only (bool, optional): Whether to return only the state. + In this way, the sample will contain a complete trajectory rather + than a single timestep. Defaults to False. + + Returns: + sample (dict): a dictionary containing the training sample. + """ + chunk_size = self.chunk_size if chunk_size is None else chunk_size + while True: + if self.online_config is None: + # offline + if index is None: + file_path = np.random.choice(self.file_paths) + else: + file_path = self.file_paths[index] + valid, sample = self.parse_hdf5_file(file_path, chunk_size) + else: + data_dict = self.engine.sample_data() + valid, sample = self.parse_dict(data_dict, chunk_size) + + if valid: + return sample + else: + if self.online_config is None: + index = np.random.randint(0, len(self.file_paths)) + + @staticmethod + def parse_exteroception( + file: Dict, + step_id: int, + chunk_size: int, + camera_used: List[str] = [], + ) -> Exteroception: + exteroception = [] + for cam in camera_used: + exteroception_full = file["observations"][ + PrivilegeType.EXTEROCEPTION.value + ][cam] + exteroception.append(exteroception_full[step_id : step_id + chunk_size]) + + exteroception = np.concatenate(exteroception, 1) + _, cs, kn, _ = exteroception.shape + exteroception = pad_to_chunk(exteroception, chunk_size) + return Exteroception( + data=exteroception.reshape(chunk_size, cs, kn, 2).transpose( + 1, 0, 2, 3 + ) # cs, chunk_size, kn, 2 + ) + + @staticmethod + # Parse the images + def parse_img( + file: Dict, + step_id: int, + first_idx: int, + cam: str, + chunk_size: int, + key: str = Modality.IMAGES.value, + camera_used: List[str] = [], + np_ops: Callable = lambda x: x, + ) -> Image: + valid_len = min(step_id - (first_idx - 1) + 1, chunk_size) + cam_mask = np.array([False] * (chunk_size - valid_len) + [True] * valid_len) + if cam in camera_used: + temp = file["observations"][key][cam][0] + imgs = np.zeros((valid_len,) + temp.shape, dtype=temp.dtype) + for t, i in enumerate(range(max(step_id - chunk_size + 1, 0), step_id + 1)): + img = file["observations"][key][cam][i] + imgs[t] = img + imgs = np_ops(imgs) + imgs = pad_to_chunk(imgs, chunk_size=chunk_size) + mask = cam_mask.copy() + else: + imgs = np.zeros((chunk_size, 0, 0, 0)) + mask = np.zeros((chunk_size,), dtype=bool) + return Image(data=imgs, mask=mask, name=cam) + + def parse_hdf5_file(self, file_path, chunk_size: int) -> Dict[str, ModalInput]: + import hdfdict + from embodichain.data.data_engine.data_dict_extractor import ( + CompressedVideoHDF5, + ) + + with h5py.File(file_path, "r") as f: + data = hdfdict.load(f) + keyname = ( + JointType.QPOS.value + if VLADataset.is_real_datasets(data) + else Modality.STATES.value + ) + step_id = VLADataset.random_step_id(data, chunk_size, keyname) + if not VLADataset.is_real_datasets(data): + data = CompressedVideoHDF5(file_path, chunks=None).safe_filter( + data, step_id + ) + else: + # Real data: if compressed structure is detected (containing *_index/*_start), also perform decoding filtering + try: + if CompressedVideoHDF5.is_compressed_hdf5(data): + data = CompressedVideoHDF5(file_path, chunks=None).safe_filter( + data, step_id + ) + except Exception: + pass + + ret = self.parse_dict(data, chunk_size, step_id) + + return ret + + @staticmethod + def random_step_id( + f: Dict, chunk_size: int, key: str = Modality.STATES.value + ) -> int: + obs = f["observations"] + proprio = obs[key][:] + num_steps = proprio.shape[0] + # We randomly sample a timestep + first_idx = 1 + step_id = np.random.randint( + first_idx, np.maximum(first_idx + 1, num_steps - 1 - chunk_size) + ) + return step_id + + @staticmethod + def is_real_datasets(f: Dict): + return "robot_meta" not in f.keys() + + def parse_dict( + self, f: Dict, chunk_size: int, step_id: int = None + ) -> Dict[str, ModalInput]: + if not VLADataset.is_real_datasets(f): + log_warning("Using simulation hdf5 datasets.") + return self.parse_sim_dict(f, chunk_size, step_id) + else: + log_warning("Using real world offline hdf5 datasets.") + return self.parse_real_dict(f, chunk_size, step_id) + + def parse_real_dict( + self, f: Dict, chunk_size: int, step_id: int = None + ) -> Dict[str, ModalInput]: + + from embodichain.data.data_engine.unified_state import ( + StateUnifier, + ) + from embodichain.data.enum import ( + ControlParts, + EndEffector, + JointType, + ) + + if step_id is None: + step_id = VLADataset.random_step_id(f, chunk_size, "qpos") + obs = f["observations"] + first_idx = 1 + proprio = obs["qpos"][:] + num_steps = proprio.shape[0] + camera_used_in_real = list(obs[Modality.IMAGES.value].keys()) + camera_used_from_real_to_dualsys = { + "cam_hand_left": CameraName.LEFT_WRIST.value, + "cam_hand_right": CameraName.RIGHT_WRIST.value, + "cam_high_left": CameraName.HEAD.value, + } + camera_used_from_dualsys_to_real = { + val: key for key, val in camera_used_from_real_to_dualsys.items() + } + # Now assume it is from W1. + camera_used = [ + camera_used_from_real_to_dualsys[cam] + for cam in camera_used_in_real + if cam in camera_used_from_real_to_dualsys + ] + + # Assemble the meta + meta = { + "dataset_name": self.DATASET_NAME, + "#steps": num_steps, + "step_id": step_id, + "camera_used": camera_used, + "instruction": "", + } + # save all supported proprio and action types. + robot_meta_config = {"arm_dofs": 14, "observation": {}} + + REAL_SUPPORTED_PROPRIO_TYPES = [ + ControlParts.LEFT_ARM.value + JointType.QPOS.value, + ControlParts.RIGHT_ARM.value + JointType.QPOS.value, + ControlParts.HEAD.value + JointType.QPOS.value, + ControlParts.WAIST.value + JointType.QPOS.value, + ControlParts.LEFT_EEF.value + EndEffector.DEXTROUSHAND.value, + ControlParts.RIGHT_EEF.value + EndEffector.DEXTROUSHAND.value, + ] + REAL_SUPPORTED_ACTION_TYPES = REAL_SUPPORTED_PROPRIO_TYPES + + robot_meta_config["observation"][ + Modality.STATES.value + ] = REAL_SUPPORTED_PROPRIO_TYPES + robot_meta_config[Modality.ACTIONS.value] = REAL_SUPPORTED_ACTION_TYPES + state_unifier = StateUnifier(robot_meta=robot_meta_config) + + qpos_index_dict = { + ControlParts.LEFT_ARM.value + + JointType.QPOS.value: TeleoperationData.LEFT_ARM_QPOS_INDICES.value, + ControlParts.RIGHT_ARM.value + + JointType.QPOS.value: TeleoperationData.RIGHT_ARM_QPOS_INDICES.value, + ControlParts.LEFT_EEF.value + + EndEffector.DEXTROUSHAND.value: TeleoperationData.LEFT_EEF_DEXTROUSHAND_INDICES.value, + ControlParts.RIGHT_EEF.value + + EndEffector.DEXTROUSHAND.value: TeleoperationData.RIGHT_EEF_DEXTROUSHAND_INDICES.value, + ControlParts.HEAD.value + + JointType.QPOS.value: TeleoperationData.HEAD_QPOS_INDICES.value, + ControlParts.WAIST.value + + JointType.QPOS.value: TeleoperationData.WAIST_QPOS_INDICES.value, + } + qpos_dict = {} + for key, indices in qpos_index_dict.items(): + qpos_dict[key] = proprio[:, indices] + + actions = state_unifier.fill_in_action(qpos_dict) + proprio = state_unifier.fill_in_state(qpos_dict) + parse_dict = self.parse_core(proprio, actions, step_id, chunk_size) + parse_dict.update({"meta": meta}) + for cam in camera_used: + parse_dict[cam] = VLADataset.parse_img( + f, + step_id, + first_idx, + camera_used_from_dualsys_to_real[cam], + self.img_history_size, + Modality.IMAGES.value, + camera_used=camera_used_from_dualsys_to_real[cam], + ) + return True, parse_dict + + def parse_sim_dict( + self, f: Dict, chunk_size: int, step_id: int = None + ) -> Dict[str, ModalInput]: + + if step_id is None: + step_id = VLADataset.random_step_id(f, chunk_size) + + obs = f["observations"] + metadata = dict(f["robot_meta"]) + first_idx = 1 + + proprio = obs[Modality.STATES.value][:] + num_steps = proprio.shape[0] + min_len_step = metadata["min_len_steps"] + # [Optional] We drop too-short episode + if num_steps < min_len_step: + return False, None + + # We randomly sample a timestep + + camera_used = ( + convert_bytes(list(metadata["observation"]["vision"].keys())) + if self.camera_used is None + else self.camera_used + ) + + # Assemble the meta + meta = { + "dataset_name": self.DATASET_NAME, + "#steps": num_steps, + "step_id": step_id, + "instruction": "", + "camera_used": camera_used, + } + + assert ( + self.indices_generator.dof == metadata["arm_dofs"] + ), "Train dof {} but dataset dof {}.".format( + self.indices_generator.dof, metadata["arm_dofs"] + ) + parse_dict = self.parse_core( + proprio, f[Modality.ACTIONS.value], step_id, chunk_size + ) + parse_dict.update({"meta": meta}) + + for cam in camera_used: + cam_r = get_right_name(cam) + if cam_r in obs[Modality.IMAGES.value] and cam_r not in camera_used: + # insert camera name after cam + camera_used.insert(camera_used.index(cam) + 1, cam_r) + + for cam in camera_used: + parse_dict[cam] = VLADataset.parse_img( + f, + step_id, + first_idx, + cam, + self.img_history_size, + Modality.IMAGES.value, + camera_used=camera_used, + ) + + if PrivilegeType.MASK.value in obs: + parse_dict[ + cam + "_{}".format(PrivilegeType.MASK.value) + ] = VLADataset.parse_img( + f, + step_id, + first_idx, + cam, + self.img_history_size, + PrivilegeType.MASK.value, + camera_used=camera_used, + ) + if PrivilegeType.EXTEROCEPTION.value in obs: + if obs[PrivilegeType.EXTEROCEPTION.value][camera_used[0]].shape[0] != 0: + parse_dict[ + PrivilegeType.EXTEROCEPTION.value + ] = VLADataset.parse_exteroception( + f, + step_id, + chunk_size, + camera_used=camera_used, + ) + + if Modality.GEOMAP.value in obs: + if ( + hasattr(obs[Modality.GEOMAP.value][camera_used[0]], "shape") + and obs[Modality.GEOMAP.value][camera_used[0]].shape[0] != 0 + ): + parse_dict[Modality.GEOMAP.value] = VLADataset.parse_img( + f, + step_id, + first_idx, + CameraName.HEAD.value, + self.img_history_size, + Modality.GEOMAP.value, + camera_used=camera_used, + np_ops=lambda x: np.tile(np.expand_dims(x, -1), [1, 1, 1, 3]), + ) + + # Return the resulting sample + # For unavailable images, return zero-shape arrays, i.e., (IMG_HISORY_SIZE, 0, 0, 0) + # E.g., return np.zeros((self.img_history_size, 0, 0, 0)) for the key "cam_left_wrist", + # if the left-wrist camera is unavailable on your robot + return True, parse_dict + + def parse_core( + self, proprio: np.ndarray, actions: np.ndarray, step_id: int, chunk_size: int + ): + # Parse the state and action + state = proprio[np.maximum(step_id - self.state_history_len, 0) : step_id] + state = np.concatenate( + [np.tile(state[0:1], [self.state_history_len - state.shape[0], 1]), state], + 0, + ) + self.indices_generator: ActionIndicesGenerator + global_mapping = self.indices_generator.global_mapping + state_indices = global_mapping.get_indices( + convert_bytes(self.state), + ) + state_indicator = np.zeros_like(state, dtype=np.int8) + state_indicator[:, state_indices] = 1 + state *= state_indicator + proprio *= state_indicator[0:1] + state_std = np.std(proprio, axis=0) + state_mean = np.mean(proprio, axis=0) + state_norm = np.sqrt(np.mean(proprio**2, axis=0)) + action_indices = self.indices_generator.get( + self.output, + ) + actions = deepcopy(actions[step_id : step_id + chunk_size]) + # FIXME: handness injection + delta_qpos_indices = self.indices_generator.get_all_delta_qpos() + qpos_indices = self.indices_generator.get_all_qpos() + # NOTE: Ops `cumsum` equal to action[:horizon]-action[0:1]. + # TODO: action = action_chunk - current_obs. + actions[:, delta_qpos_indices] = ( + actions[:, qpos_indices] - state[-1:, qpos_indices] + ) + + actions = pad_to_chunk(actions, chunk_size=chunk_size) + action_indicator = np.zeros_like(actions, dtype=np.int8) + action_indicator[:, action_indices] = 1 + actions *= action_indicator[0:1] + + parse_dict = { + "state_std": state_std, + "state_mean": state_mean, + "state_norm": state_norm, + Modality.STATES.value: Proprioception(data=state, mask=state_indicator), + Modality.ACTIONS.value: Proprioception(data=actions, mask=action_indicator), + PrivilegeType.PROGRESS.value: step_id / proprio.shape[0], + } + return parse_dict diff --git a/embodichain/data/data_engine/online/engine.py b/embodichain/data/data_engine/online/engine.py new file mode 100644 index 00000000..8b49e391 --- /dev/null +++ b/embodichain/data/data_engine/online/engine.py @@ -0,0 +1,525 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import time +import sys +import numpy as np +from threading import Thread +from typing import Dict, Tuple, Any, List, Callable, Optional +from copy import deepcopy +import threading +from embodichain.data.data_engine.online.enum import ( + ConsumerTeleEnum, + ProducerTeleEnum, +) + +import torch +import torch.multiprocessing as mp +import copy +from embodichain.utils.logger import ( + log_info, + log_warning, + decorate_str_color, + log_debug, +) + +# Must call cuda init to prevent cuda error in subprocess. +torch._C._cuda_init() + +from dexsim.utility import NumpyRNG + +import threading +from multiprocessing import shared_memory +import pickle +from datetime import datetime +import zmq + +__all__ = ["MaiDataEngine"] + +rng = NumpyRNG.get_rng() + +log_info_produce = lambda x: log_info(decorate_str_color(x, "cyan")) +log_info_consume = lambda x: log_info(decorate_str_color(x, "orange")) + +MAX_LOOP_TIMES = 40000 + + +def init_context(port): + context = zmq.Context() + socket = context.socket(zmq.REQ) + socket.connect("tcp://localhost:{}".format(port)) + return socket + + +class DataPoolCont: + data: Any + count: int = 0 + tag: str + + @staticmethod + def from_list(data_pool: List[Dict]) -> List["DataPoolCont"]: + ret = [] + for data in data_pool: + dcnt = DataPoolCont() + dcnt.data = data + dcnt.count = 0 + dcnt.tag = str(datetime.now()).split(".")[0] + ret.append(dcnt) + return ret + + @staticmethod + def clean_data_pool_in_place( + data_pool: List["DataPoolCont"], clean_indices: List[int] + ): + if clean_indices is None: + data_pool = [] + else: + if len(clean_indices) > 0: + log_debug( + "Clean data pool with data indices {}, counts {}.".format( + clean_indices, + [data_pool[index].count for index in clean_indices], + ), + color="purple", + ) + for i in list(np.sort(clean_indices)[::-1]): + data_pool.pop(i) + + +def fetch_data( + queue_data: mp.Queue, data_pool: List[DataPoolCont], worker_info, debug: bool = True +) -> bool: + start_time = time.time() + try: + existing_shm = queue_data.get(timeout=5) + except Exception as error: + log_debug("Timeout! {}.".format(str(error)), color="red") + return False + log_debug( + "[Thread {}][Worker {}][Get] Cost {}s.".format( + threading.current_thread().ident, + worker_info.id, + time.time() - start_time, + ) + ) + start_time = time.time() + scene_data = pickle.loads(existing_shm.buf[:]) + log_debug( + "[Thread {}][Worker {}][Pickle] Cost {}s.".format( + threading.current_thread().ident, + worker_info.id, + time.time() - start_time, + ) + ) + + if np.random.random() > 0.5 or queue_data.qsize() == 0: + start_time = time.time() + queue_data.put(existing_shm) # put back + log_debug( + "[Thread {}][Worker {}][Put] Cost {}s.".format( + threading.current_thread().ident, + worker_info.id, + time.time() - start_time, + ) + ) + + assert isinstance(scene_data, list), "Invalid data format {}.".format( + type(scene_data) + ) + start_time = time.time() + data = DataPoolCont.from_list(scene_data) + data_pool.extend(data) + + log_debug( + "[Thread {}][Worker {}][Other] Cost {}s.".format( + threading.current_thread().ident, + worker_info.id, + time.time() - start_time, + ) + ) + return True + + +class RestockCriterion: + def __init__(self, data_pool_limit: int, buffer_size: int, max_sample_num: int): + self.data_pool_limit = data_pool_limit + self.buffer_size = buffer_size + self.max_sample_num = max_sample_num + + def restock_condition(self, data_pool: List, queue: mp.Queue) -> bool: + return len(data_pool) < self.data_pool_limit + + def expired_condition( + self, data_pool: List[DataPoolCont], inverse: bool = False + ) -> List[bool]: + + if len(data_pool) == 0: + return [] + + if inverse: + return [data.count <= self.max_sample_num for data in data_pool] + else: + return [data.count > self.max_sample_num for data in data_pool] + + +class OnlineEngine: + """Data manager for online data production and training. + + The objectives of this class are: + - Manage the fetch data in a separate thread. + - Perform data synchronization between the data production process and + the training process (main process). + - Provide data sampling interface for the training process, which is designed + to return a batch of synthetic data with the different scene id. + - Data lifecycle management. + + To achieve the above objectives, the following functions should be implemented: + - from_shm_thread (static method) + + Args: + insight_config (List[CfgNode]): The config of insight pipeline. + episode_limit (int, optional): The maximum number of frames in the data pool. Defaults to 24. + max_sample_num (int, optional): The maximum number of times that a data can be sampled. + Defaults to 2. + target_device (torch.device, optional): The target device of the data. Defaults to torch.device('cpu'). + annos_param (Dict[str, Any], optional): The parameters of the annotations. Defaults to None. + data_gen_func (Callable, optional): The data generation function. Defaults to None. + unique_scene_frame (int, optional): The number of unique scene frame to be sampled. Defaults to None. + port (int, optional): The ZeroMQ socket port. Defaults to 5555. + buffer_size(int, optional): The number of max data queue size. Defaults to 10. + """ + + def __init__( + self, + episode_limit: int = 24, + max_sample_num: int = 2, + port: int = 5555, + buffer_size: int = 10, + multiprocess: bool = False, + **kwargs, + ) -> None: + + self.episode_limit = episode_limit + self._max_sample_num = max_sample_num + self.port = port + + self._data_pool = [] + + self._duration = 0.01 + + self._context = mp.get_context("forkserver") + + self._queue_data = self._context.Queue() + self._queue_data.cancel_join_thread() + + self.buffer_size = buffer_size + + self._data_gen_proc = None + self._fetch_data_thread = None + self._restock_data_pool = None + + self._is_started = False + self._is_restocked = False + self._socket = init_context(port + 1 if multiprocess else port) + + self._restock_criterion = RestockCriterion( + data_pool_limit=episode_limit, + buffer_size=buffer_size, + max_sample_num=max_sample_num, + ) + self._lock = threading.RLock() + + def start( + self, + ) -> None: + """Start the data production process and the data synchronization thread. + + Args: + wait_for_limit (bool, optional): Whether to wait for the data pool to reach + the frame limit. Defaults to False. + """ + + self._signal_gen = self._context.Value("b", True) + self._signal_fetch = self._context.Value("b", True) + + self._fetch_data_thread = Thread( + target=self.from_shm_thread, + args=( + self._socket, + self._queue_data, + self._duration, + self.buffer_size, + ), + daemon=True, + ) + self._fetch_data_thread.start() + self._is_started = True + log_info( + "Now start the thread to fetch data from share memory.", color="purple" + ) + + def start_restock(self, static: bool = False): + if static: + self._restock_data_pool = Thread( + target=self.restock_data_pool_static, + args=( + self._data_pool, + self._queue_data, + self._duration, + self._restock_criterion, + self._context, + self._lock, + ), + daemon=True, + ) + else: + self._restock_data_pool = Thread( + target=self.restock_data_pool, + daemon=True, + ) + + self._restock_data_pool.start() + self._is_restocked = True + + def stop(self) -> None: + if self.is_started: + self._is_started = False + self._signal_fetch.value = 2 + self._fetch_data_thread.join() + self.empty_queue(self._queue_data, self._context) + self.clean_data_pool_in_place() + self._signal_gen.value = 2 + else: + log_info( + "The data generation process has not been started.", color="purple" + ) + + @property + def is_started(self) -> bool: + return self._is_started + + @property + def data_size(self) -> int: + with self._lock: + return len(self._data_pool) + + @property + def queue_size(self) -> int: + return self._queue.qsize() + + @property + def unique_scene_frame(self) -> int: + return self._unique_scene_frame + + @staticmethod + def empty_queue(queue: mp.Queue, context: mp) -> None: + while queue.qsize() > 0: + try: + queue.get() + except Exception as e: + log_info("queue put invaild data format") + queue.close() + queue.join_thread() + queue = context.Queue() + break + return queue + + @staticmethod + def empty_share_memory(queue: mp.Queue) -> None: + while queue.qsize() > 0: + shm_name = queue.get() + shm = shared_memory.SharedMemory(shm_name) + shm.close() + shm.unlink() + + def restock_data_pool(self): + return OnlineEngine.restock_data_pool_static( + self._data_pool, + self._queue_data, + self._duration, + self._restock_criterion, + self._context, + self._lock, + ) + + @staticmethod + def restock_data_pool_static( + data_pool: List[DataPoolCont], + queue_data: mp.Queue, + duration: float, + restock_criterion: RestockCriterion, + context, + thread_lock, + ): + counts = 0 + + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + + class FakeWorkerInfo: + num_workers = 1 + id = 0 + + worker_info = FakeWorkerInfo() + + while True: + time.sleep(duration) + # always clean the data pool first. + + start_time = time.time() + with thread_lock: + # delete + clean_indices = list( + np.argwhere(restock_criterion.expired_condition(data_pool)).reshape( + -1 + ) + ) + DataPoolCont.clean_data_pool_in_place( + data_pool, + clean_indices, + ) + if len(clean_indices) > 0: + log_debug( + "[Thread {}][Delete][Cost {}s]".format( + threading.current_thread().ident, time.time() - start_time + ) + ) + + # after clean, we check whether to restock data. + while restock_criterion.restock_condition(data_pool, queue_data): + + prev_data_size = len(data_pool) + should_fetch = False + for i in range(worker_info.num_workers): + if queue_data.qsize() > 0 and worker_info.id == i: + should_fetch = True + if should_fetch: + start_time = time.time() + with thread_lock: + # add + fetch_data( + data_pool=data_pool, + queue_data=queue_data, + worker_info=worker_info, + ) + log_debug( + "[Thread {}][Worker {}][ToDataPool] Produce data: {}->{}. Cost {}s.".format( + threading.current_thread().ident, + worker_info.id, + prev_data_size, + len(data_pool), + time.time() - start_time, + ) + ) + counts = 0 + else: + counts += 1 + + if counts % MAX_LOOP_TIMES == 0 and counts != 0: + log_info("Can not find the shm after {} times.".format(counts)) + # queue_data = OnlineEngine.empty_queue(queue_data, context) + + @staticmethod + def from_shm_thread( + socket, + queue_data: mp.Queue, + duration: float = 0.001, + buffer_size: int = 10, + ) -> None: + """The data fetching thread for data synchronization. + + The queue_data_size is used to control the data fetching thread. + If queue_data_size < buffer_size, the data fetching thread will fetch data from the queue. + If queue_data_size >= buffer_size, the data fetching thread will stop fetch data. + + Args: + socket (zmq.Context): The socket send signal for connect fetch and generator. + queue_data (mp.Queue): This queue contains information about shared memory. + duration (float, optional): _description_. Defaults to 0.001. + port (int, optional): The ZeroMQ socket port. Defaults to 5555. + buffer_size(int, optional): The number of max data queue size. Defaults to 10. + """ + counts = 0 + while True: + time.sleep(duration) + counts += 1 + if queue_data.qsize() < buffer_size: + socket.send_string(ConsumerTeleEnum.SHAKEHAND.value) + message = socket.recv() + try: + message_str = message.decode() + except Exception as e: + log_debug(str(e), color="red") + message_str = "" + if message_str != ProducerTeleEnum.NOREADY.value: + log_debug("Receive data.", color="purple") + shm_name = pickle.loads(message).popleft() + existing_shm = shared_memory.SharedMemory(name=shm_name) + queue_data.put(existing_shm) + log_debug( + "[FromShmThread] Produce queue: {}->{};".format( + queue_data.qsize() - 1, queue_data.qsize() + ) + ) + else: + if counts % MAX_LOOP_TIMES == 0: + log_debug("Queue is full. Skip this stage.", "purple") + + def sample_data( + self, + ): + + if self._is_restocked: + pass + else: + log_debug("Now start the thread to restock data.", color="purple") + self.start_restock(static=False) + + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + + class FakeWorkerInfo: + num_workers = 1 + id = 0 + + worker_info = FakeWorkerInfo() + + counts = 0 + while True: + time.sleep(self._duration) + if len(self._data_pool) > 0: + start_time = time.time() + with self._lock: + index = rng.integers(0, len(self._data_pool)) + data = self._data_pool[index] + self._data_pool[index].count += 1 + log_debug( + "[SampleData, worker {}] Consume data {}: index {}; times: {}->{}; Show queue size: {}; Cost time: {}s.".format( + worker_info.id, + data.tag, + index, + data.count, + data.count + 1, + self._queue_data.qsize(), + np.round(time.time() - start_time, 4), + ) + ) + counts = 0 + return data.data + else: + counts += 1 + if counts % MAX_LOOP_TIMES == 0: + log_info("Data pool is always empty after {} times.".format(counts)) diff --git a/embodichain/data/data_engine/online/enum.py b/embodichain/data/data_engine/online/enum.py new file mode 100644 index 00000000..a93b00d6 --- /dev/null +++ b/embodichain/data/data_engine/online/enum.py @@ -0,0 +1,34 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from enum import Enum + + +class ConsumerTeleEnum(Enum): + SHAKEHAND = "Data is ready?" + CONSUME = "Fetch data!" + NOCONSUME = "Data_pool is full." + GOTDATA = "Feched data!" + NOGOTDATA = "Not fetching data." + + +class ProducerTeleEnum(Enum): + READY = "Yes" + NOREADY = "No ready" + FULL = "Data_pool is full" + FAIL = "Failed" + SEND = "Send!" + EMPTYSTR = "Empty String." diff --git a/embodichain/data/data_engine/online/online_generator.py b/embodichain/data/data_engine/online/online_generator.py new file mode 100644 index 00000000..979f3543 --- /dev/null +++ b/embodichain/data/data_engine/online/online_generator.py @@ -0,0 +1,191 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch +import time +import zmq +import random +from multiprocessing import shared_memory +import pickle +from collections import deque +from typing import List +from threading import Thread +import multiprocessing as mp +import traceback +from embodichain.utils.logger import log_info, log_warning, log_error, log_debug +from embodichain.data.data_engine.online.enum import ( + ConsumerTeleEnum, + ProducerTeleEnum, +) + +torch._C._cuda_init() + + +class OnlineGenerator: + """Callback collection for online training mode.""" + + def __init__( + self, port: int, max_limit_gb: int = 50, multiprocess: bool = False, **kwargs + ) -> None: + self.shm_val = None + max_limit = max_limit_gb * 1024**3 + self._context = mp.get_context("forkserver") + self.port = port + self.socket = self.init_context(self.port, multiprocess) + self._duration = 0.01 + self.queue = deque() + self.queue_memroy = deque() + self.max_limit = max_limit + + self.validation_config = kwargs.get("validation", {}) + + def get_validation_config(self): + return self.validation_config + + def init_context(self, port, multiprocess: bool = False): + context = zmq.Context() + socket = context.socket(zmq.REP) + if multiprocess: + socket.connect(f"tcp://127.0.0.1:{port}") + else: + socket.bind(f"tcp://*:{port}") + + return socket + + def generator(self, generate_func, loop_times: int = -1, **kwargs): + self.signal = self._context.Value("b", True) + + self._zmq_send = Thread( + target=self.zmq_send, args=(self.queue, self.signal), daemon=True + ) + self._zmq_send.start() + log_debug("Start zmq sending.") + scene_id = 0 + + # -1 means infinite loop + while scene_id < loop_times or loop_times == -1: + if self.signal.value == 1: + first_time = True + try: + t0 = time.time() + return_list = generate_func( + time_id=scene_id, **self.validation_config + ) + + # TODO: support multiple trajectories for each scene. + if len(return_list) > 1: + log_error( + "Only support one trajectory for each scene in online generation mode." + ) + + data_dict_list = [return_list[0]["data"]] + + if ( + scene_id == 0 + and self.validation_config.get("num_samples", 0) > 0 + and "data_path" in return_list[0] + ): + # create shared memory to store the validation dataset path, which will be accessed by training process. + import sys + + data_path = return_list[0]["data_path"] + + shared_name = self.validation_config.get( + "dataset_name", "val_data_path" + ) + log_info( + f"Create shared memory for validation data path: {shared_name}", + color="green", + ) + self.shm_val = shared_memory.SharedMemory( + name=shared_name, + create=True, + size=len(data_path.encode()) + sys.getsizeof(""), + ) + self.shm_val.buf[: len(data_path.encode())] = data_path.encode() + log_info( + f"Craete shared memory for validation data path: {data_path}" + ) + + log_info( + f"Generate scene {scene_id + 1} time cost: {time.time() - t0}" + ) + serialized_data = pickle.dumps(data_dict_list) + shm = shared_memory.SharedMemory( + create=True, size=len(serialized_data) + ) + self.queue.append(shm.name) + self.queue_memroy.append( + {"name": shm.name, "size": len(serialized_data)} + ) + shm.buf[: len(serialized_data)] = serialized_data + except Exception as e: + log_error(f"Error in data generation process: {e}.") + traceback.print_exc() + self._zmq_send.join() + break + scene_id += 1 + self.empty_memory() + elif self.signal.value == 0: + if first_time: + log_warning("zmq recive full signal, wait generator signal") + first_time = False + log_warning("Signal value is 0.") + time.sleep(self._duration) + continue + else: + log_error("Unknown signal, data generator stop") + break + + def zmq_send(self, queue, signal): + while True: + try: + message = self.socket.recv_string() + if message == ConsumerTeleEnum.SHAKEHAND.value: + if len(queue) > 0: + log_warning( + "Recieve {} and send [data] to consumer.".format(message) + ) + self.socket.send(pickle.dumps(queue)) + queue.clear() + else: + self.socket.send(ProducerTeleEnum.NOREADY.value.encode()) + signal.value = 1 + except Exception as e: + print(e) + traceback.print_exc() + break + + def empty_memory(self): + total_size = sum([x["size"] for x in self.queue_memroy]) + log_info(f"share memory size is {total_size/(1024**3)} GB") + while total_size >= self.max_limit: + shm_name = self.queue_memroy.popleft() + if shm_name["name"] in self.queue: + log_info(f"remove {shm_name['name']} from queue") + self.queue.remove(shm_name["name"]) + try: + shm = shared_memory.SharedMemory(shm_name["name"]) + except: + continue + shm.close() + shm.unlink() + total_size = sum([x["size"] for x in self.queue_memroy]) + + def __del__(self): + if self.shm_val: + self.shm_val.close() + self.shm_val.unlink() diff --git a/embodichain/data/data_engine/unified_state.py b/embodichain/data/data_engine/unified_state.py new file mode 100644 index 00000000..fc430def --- /dev/null +++ b/embodichain/data/data_engine/unified_state.py @@ -0,0 +1,381 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from embodichain.data.global_indices import ( + GLOBAL_INDICES, + STATE_VEC_LEN, +) +from embodichain.data.global_mapping import GlobalMapping +import numpy as np +from typing import List, Dict, Tuple, Union +from embodichain.data.enum import ( + ArmEnum, + Modality, + JointType, + ActionMode, + EefType, + ControlParts, + EndEffector, + Modality, +) +from embodichain.utils.logger import log_info, log_warning + +DEFAULT_EMPTY_STATE = -1 + + +"""Unified state utilities for EmbodiChain. + +This module provides helpers to construct and query a unified state/action +vector representation used across EmbodiChain environments and agents. + +Classes: + StateUnifier: Fill sparse per-modality state/action dictionaries into a + fixed-length unified state vector where unspecified entries are set + to a sentinel value (DEFAULT_EMPTY_STATE). + + ActionIndicesGenerator: Query index ranges in the unified vector for + common action/state groups (e.g. qpos, delta qpos, end-effector pose). + +Constants: + DEFAULT_EMPTY_STATE (int): Sentinel value used to mark unspecified + entries in the unified vector. +""" + + +class StateUnifier: + """Convert per-modality state/action arrays into a unified vector. + + The StateUnifier is constructed with ``robot_meta`` (the robot's + metadata) which should contain an ``observation`` mapping with keys for + modalities (e.g. ``Modality.STATES``) and an ``actions`` specification. + + Attributes: + metadata (dict): Robot metadata passed at construction. + arm_dofs (int): Degrees of freedom for the arm (default: 12). + indices_generator (ActionIndicesGenerator): Helper for action indices. + proprio_meta: Metadata list for proprioceptive modalities. + global_mapping (GlobalMapping): Mapping from names to unified indices. + output: Action output specification from metadata. + state_dim (int): Fixed length of the unified state vector. + """ + + def __init__(self, robot_meta: Dict) -> None: + assert "arm_dofs" in robot_meta + assert "observation" in robot_meta + assert Modality.ACTIONS.value in robot_meta + + self.arm_dofs = robot_meta["arm_dofs"] + self.indices_generator = ActionIndicesGenerator(self.arm_dofs) + self.proprio_meta = robot_meta["observation"][Modality.STATES.value] + self.global_mapping = GlobalMapping(self.arm_dofs) + self.output = robot_meta[Modality.ACTIONS.value] + + self.state_dim = STATE_VEC_LEN + + def fill_in_state( + self, values: Union[np.ndarray, Dict[str, np.ndarray]] + ) -> np.ndarray: + """Fill a unified state vector from given values. + + Args: + values (np.ndarray or dict): If ``values`` is a numpy array it is + assumed to already be aligned to the unified layout and will + be placed into the output container. If it is a ``dict``, + keys should match entries from the robot metadata + ``observation[Modality.STATES]`` and values are numpy arrays + with a trailing dimension matching each state's width. + + Returns: + np.ndarray: An array with shape ``(..., STATE_VEC_LEN)`` containing + the unified state with unspecified entries set to + ``DEFAULT_EMPTY_STATE``. + """ + if isinstance(values, np.ndarray): + UNI_STATE_INDICES = self.global_mapping.get_indices(self.proprio_meta) + uni_vec = ( + np.ones(values.shape[:-1] + (self.state_dim,)) * DEFAULT_EMPTY_STATE + ) + uni_vec[..., UNI_STATE_INDICES] = values + return uni_vec + else: + shape_tuple_list = [] + for val in values.values(): + shape_tuple = val.shape[:-1] + if val.size != 0: + shape_tuple_list.append(shape_tuple) + + shape_tuple = list(set(shape_tuple_list)) + assert len(shape_tuple) == 1, "shape tuple {} is not unique.".format( + shape_tuple + ) + uni_vec = np.ones(shape_tuple[0] + (self.state_dim,)) * DEFAULT_EMPTY_STATE + for state_name in self.proprio_meta: + state_indices = self.global_mapping.get_indices([state_name]) + if values[state_name].size != 0: + uni_vec[..., state_indices] = values[state_name] + + return uni_vec + + def fill_in_action( + self, values: Union[np.ndarray, Dict[str, np.ndarray]] + ) -> np.ndarray: + """Fill a unified action vector from given action values. + + This mirrors :meth:`fill_in_state` but uses the metadata's action + output specification to determine which named outputs map into the + unified vector. + + Args: + values (np.ndarray or dict): Action values aligned to the unified + layout or a mapping from output names to numpy arrays. + + Returns: + np.ndarray: Unified vector shaped ``(..., STATE_VEC_LEN)`` with + unspecified entries filled with ``DEFAULT_EMPTY_STATE``. + """ + if isinstance(values, np.ndarray): + UNI_STATE_INDICES = self.indices_generator.get(self.output) + uni_vec = ( + np.ones(values.shape[:-1] + (self.state_dim,)) * DEFAULT_EMPTY_STATE + ) + uni_vec[..., UNI_STATE_INDICES] = values + return uni_vec + else: + shape_tuple_list = [] + for key, val in values.items(): + + shape_tuple = val.shape[:-1] + if val.size != 0: + shape_tuple_list.append(shape_tuple) + + shape_tuple = list(set(shape_tuple_list)) + assert len(shape_tuple) == 1, "shape tuple {} is not unique.".format( + shape_tuple + ) + + uni_vec = np.ones(shape_tuple[0] + (self.state_dim,)) * DEFAULT_EMPTY_STATE + for out_name in self.output: + state_indices = self.global_mapping.get_indices([out_name]) + if out_name in values and values[out_name].size != 0: + uni_vec[..., state_indices] = values[out_name] + return uni_vec + + +class ActionIndicesGenerator: + """Utility for generating index lists for action/state groups. + + The ActionIndicesGenerator wraps :class:`GlobalMapping` to provide + common queries like retrieving indices for all joint positions (qpos), + delta qpos (relative mode), end-effector transforms/poses, and + hand-specific selections (left/right/both). + + Args: + dof (int, optional): If provided, a :class:`GlobalMapping` is + constructed and reused for queries. + """ + + def __init__(self, dof: int = None): + self.global_mapping = None + self.dof = dof + if dof is not None: + self.global_mapping = GlobalMapping(dof) + + def get_all_qpos( + self, dof: int = None, handness: str = ArmEnum.DUAL_ARM.value + ) -> List[int]: + """Return indices covering all joint position entries. + + Args: + dof (int, optional): Degrees of freedom to construct a temporary + :class:`GlobalMapping` if the generator was not initialized + with a ``dof``. + handness (str): One of values from :class:`ArmEnum` specifying + which arm(s) to include. + + Returns: + List[int]: Ordered list of indices in the unified vector + corresponding to qpos entries for the requested arm + selection. + """ + qpos_name = JointType.QPOS.value + delta_qpos_name = ActionMode.RELATIVE.value + qpos_name + global_mapping = self.get_mapping(dof) + + all_names = list(global_mapping.mapping_from_name_to_indices.keys()) + if handness == ArmEnum.DUAL_ARM.value: + return self.get(all_names, dof, [qpos_name], [delta_qpos_name]) + elif handness == ArmEnum.LEFT_ARM_ONLY.value: + handness = ControlParts.LEFT_ARM.value + return self.get( + all_names, dof, [handness + qpos_name], [handness + delta_qpos_name] + ) + elif handness == ArmEnum.RIGHT_ARM_ONLY.value: + handness = ControlParts.RIGHT_ARM.value + return self.get( + all_names, dof, [handness + qpos_name], [handness + delta_qpos_name] + ) + + def get_all_delta_qpos( + self, dof: int = None, handness: str = ArmEnum.DUAL_ARM.value + ) -> List[int]: + """Return indices for delta (relative) joint position entries. + + Args and return are the same as :meth:`get_all_qpos` but select the + ``ActionMode.RELATIVE`` named entries. + """ + qpos_name = JointType.QPOS.value + delta_qpos_name = ActionMode.RELATIVE.value + qpos_name + global_mapping = self.get_mapping(dof) + + all_names = list(global_mapping.mapping_from_name_to_indices.keys()) + if handness == ArmEnum.DUAL_ARM.value: + return self.get(all_names, dof, [delta_qpos_name], []) + elif handness == ArmEnum.LEFT_ARM_ONLY.value: + handness = ControlParts.LEFT_ARM.value + return self.get(all_names, dof, [handness + delta_qpos_name], []) + elif handness == ArmEnum.RIGHT_ARM_ONLY.value: + handness = ControlParts.RIGHT_ARM.value + return self.get(all_names, dof, [handness + delta_qpos_name], []) + + def get_all_eef( + self, dof: int = None, handness: str = ArmEnum.DUAL_ARM.value + ) -> List[int]: + """Return indices covering end-effector (EEF) related entries. + + Args: + dof (int, optional): Degrees of freedom for mapping lookup. + handness (str): Which arm(s) to include (left/right/both). + + Returns: + List[int]: Indices corresponding to EEF-related entries. + """ + global_mapping = self.get_mapping(dof) + all_names = list(global_mapping.mapping_from_name_to_indices.keys()) + if handness == ArmEnum.DUAL_ARM.value: + return self.get( + all_names, + dof, + [ControlParts.LEFT_EEF.value, ControlParts.RIGHT_EEF.value], + [], + ) + elif handness == ArmEnum.LEFT_ARM_ONLY.value: + handness = ControlParts.LEFT_EEF.value + return self.get( + all_names, + dof, + [handness], + [], + ) + elif handness == ArmEnum.RIGHT_ARM_ONLY.value: + handness = ControlParts.RIGHT_EEF.value + return self.get( + all_names, + dof, + [handness], + [], + ) + + def get_all_eef_pose( + self, dof: int = None, handness: str = ArmEnum.DUAL_ARM.value + ) -> List[int]: + """Return indices specifically for EEF pose entries. + + Args: + dof (int, optional): Degrees of freedom for mapping lookup. + handness (str): Which arm(s) to include (left/right/both). + + Returns: + List[int]: Indices corresponding to EEF poses. + """ + global_mapping = self.get_mapping(dof) + all_names = list(global_mapping.mapping_from_name_to_indices.keys()) + + if handness == ArmEnum.DUAL_ARM.value: + return self.get(all_names, dof, [EefType.POSE.value], []) + elif handness == ArmEnum.LEFT_ARM_ONLY.value: + handness = ControlParts.LEFT_ARM.value + return self.get(all_names, dof, [handness + EefType.POSE.value], []) + elif handness == ArmEnum.RIGHT_ARM_ONLY.value: + handness = ControlParts.RIGHT_ARM.value + return self.get(all_names, dof, [handness + EefType.POSE.value], []) + + def get_mapping(self, dof: int = None): + """Return the :class:`GlobalMapping` used by this generator. + + If a mapping was created during initialization (because ``dof`` was + provided), ensure any provided ``dof`` argument matches it. Otherwise + construct and return a temporary :class:`GlobalMapping` for the + requested ``dof``. + + Args: + dof (int, optional): Degrees of freedom to construct a mapping + if one was not provided at initialization. + + Returns: + GlobalMapping: Mapping instance for name->index lookups. + """ + if self.global_mapping is not None: + assert dof is None or dof == self.dof + global_mapping = self.global_mapping + else: + assert ( + dof is not None + ), "Dof must be set when dof is not provided in initialization." + global_mapping = GlobalMapping(dof) + return global_mapping + + def get( + self, + output: List[str], + dof: int = None, + white_list: List[str] = None, + black_list: List[str] = None, + ) -> List[int]: + """Select and return indices from ``output`` names applying optional + white/black list filters. + + Args: + output (List[str]): Names (keys) in a :class:`GlobalMapping` + whose indices should be collected. + dof (int, optional): Degrees of freedom used to construct a + temporary :class:`GlobalMapping` if needed. + white_list (List[str], optional): If provided, only include names + that contain any of these substrings. + black_list (List[str], optional): If provided, exclude names + that contain any of these substrings. + + Returns: + List[int]: Ordered list of unified-vector indices for the + selected names. + """ + + action_indices = [] + global_mapping = self.get_mapping(dof) + + for action_type in output: + if isinstance(white_list, list) and isinstance(black_list, list): + if any([temp in action_type for temp in white_list]) and all( + [temp not in action_type for temp in black_list] + ): + action_indices += global_mapping.mapping_from_name_to_indices[ + action_type + ] + else: + action_indices += global_mapping.mapping_from_name_to_indices[ + action_type + ] + + return action_indices # keep order. diff --git a/embodichain/data/dataset.py b/embodichain/data/dataset.py new file mode 100644 index 00000000..f5c775fd --- /dev/null +++ b/embodichain/data/dataset.py @@ -0,0 +1,170 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os +import sys +import shutil +import hashlib +import open3d as o3d + + +from embodichain.utils import logger + + +class EmbodiChainDataset(o3d.data.DownloadDataset): + def __init__(self, prefix, data_descriptor, path): + # Perform the zip file and extracted contents check + # If the zip was not valid, the zip file would have been removed + # and the parent class would download and extract it again + self.check_zip(prefix, data_descriptor, path) + # Call the parent class constructor + super().__init__(prefix, data_descriptor, path) + + def check_zip(self, prefix, data_descriptor, path): + """Check the integrity of the zip file and its extracted contents.""" + # Path to the downloaded zip file + zip_file_name = os.path.split(data_descriptor.urls[0])[1] + zip_dir_path = os.path.join(path, "download", f"{prefix}") + zip_path = os.path.join(path, "download", f"{prefix}", f"{zip_file_name}") + # Path to the extracted directory + extracted_path = os.path.join(path, "extract", prefix) + + def is_safe_path(path_to_check): + """Verify if the path is within safe directory boundaries""" + return ( + "embodichain_data/download" in path_to_check + or "embodichain_data/extract" in path_to_check + ) + + def safe_remove_directory(dir_path): + """Safely remove a directory after path validation""" + if not is_safe_path(dir_path): + logger.log_warning( + f"Safety check failed, refusing to delete directory: {dir_path}" + ) + return False + + if os.path.exists(dir_path): + try: + shutil.rmtree(dir_path) + logger.log_info(f"Successfully removed directory: {dir_path}") + return True + except OSError as e: + logger.log_warning(f"Error while removing directory: {e}") + return False + return True + + # Check if the file already exists + if os.path.exists(zip_path): + # Calculate MD5 checksum of the existing file + md5_existing = self.calculate_md5(zip_path) + # Compare with the expected MD5 checksum + if md5_existing != data_descriptor.md5: + # If checksums do not match, delete the existing file + os.remove(zip_path) + # Ensure the extracted directory is removed if it exists + safe_remove_directory(extracted_path) + logger.log_warning( + f"Invalid MD5 checksum detected:\n" + f" - File: {zip_path}\n" + f" - Expected MD5: {data_descriptor.md5}\n" + f" - Actual MD5: {md5_existing}\n" + f"Cleaned up invalid files and directories for fresh download." + ) + return + else: + safe_remove_directory(zip_dir_path) + safe_remove_directory(extracted_path) + logger.log_info( + f"ZIP file not found at {zip_path}." + f"Cleaning up related directories for fresh download." + ) + return + + # Check if the extracted directory exists and is not empty + if not os.path.exists(extracted_path) or not os.listdir(extracted_path): + # Remove the zip file to trigger Open3D's automatic download mechanism + # Open3D will re-download and extract when the zip file is missing + if os.path.exists(zip_path): + os.remove(zip_path) + + # Clean up any existing empty extraction directory + # This ensures a clean state for the upcoming extraction process + safe_remove_directory(extracted_path) + logger.log_info( + f"Removed zip file {zip_path} and extracted path {extracted_path} to trigger Open3D download and extract. " + f"Reason: {'Missing extraction directory.' if not os.path.exists(extracted_path) else 'Empty extraction directory.'}" + ) + return + + def calculate_md5(self, file_path, chunk_size=8192): + """Calculate the MD5 checksum of a file.""" + hash_md5 = hashlib.md5() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(chunk_size), b""): + hash_md5.update(chunk) + return hash_md5.hexdigest() + + +def get_data_class(dataset_name: str): + """Retrieve the dataset class from the available modules. + + Args: + dataset_name (str): The name of the dataset class. + + Returns: + type: The dataset class. + + Raises: + AttributeError: If the dataset class is not found in any module. + """ + module_names = [ + "embodichain.data", + "embodichain.data.assets", + __name__, + ] + + for module_name in module_names: + try: + return getattr(sys.modules[module_name], dataset_name) + except AttributeError: + continue + + raise AttributeError(f"Dataset class '{dataset_name}' not found in any module.") + + +def get_data_path(data_path_in_config: str) -> str: + """Get the absolute path of the data file. + + Args: + data_path_in_config (str): The dataset path in the format "${dataset_name}/subpath". + + Returns: + str: The absolute path of the data file. + """ + if os.path.isabs(data_path_in_config): + return data_path_in_config + + split_str = data_path_in_config.split("/") + dataset_name = split_str[0] + sub_path = os.path.join(*split_str[1:]) + + # Use the optimized get_data_class function + data_class = get_data_class(dataset_name) + data_obj = data_class() + data_dir = data_obj.extract_dir + data_path = os.path.join(data_dir, sub_path) + return data_path diff --git a/embodichain/data/enum.py b/embodichain/data/enum.py new file mode 100644 index 00000000..5c31ed9e --- /dev/null +++ b/embodichain/data/enum.py @@ -0,0 +1,349 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch +import numpy as np + +from typing import List, Tuple, Union, Dict +from enum import Enum, IntEnum +from itertools import product +from aenum import Enum as AEnum +from aenum import NoAlias + +from embodichain.utils.utility import get_right_name + + +class ModalInput: + def __init__( + self, + data: Union[torch.Tensor, np.ndarray] = None, + mask: Union[torch.Tensor, np.ndarray] = None, + name: str = "", + ): + self.data = data + self.mask = mask # indicator mask for the data, e.g., which part is valid. + self.name = name + + +class Privilege: + def __init__( + self, + data: Union[torch.Tensor, np.ndarray] = None, + mask: Union[torch.Tensor, np.ndarray] = None, + name: str = "", + ): + self.data = data + self.mask = mask # indicator mask for the data, e.g., which part is valid. + self.name = name + + +class Mask(Privilege): + pass + + +class Exteroception(Privilege): + pass + + +class State(Privilege): + pass + + +class Proprioception(ModalInput): + pass + + +class Image(ModalInput): + pass + + +class GeoMap(ModalInput): + pass + + +class Lang(ModalInput): + pass + + +class Modality(Enum): + STATES = "states" + STATE_INDICATOR = "state_indicator" + ACTIONS = "actions" + ACTION_INDICATOR = "action_indicator" + IMAGES = "images" + LANG = "lang" + LANG_INDICATOR = "lang_indicator" + GEOMAP = "geomap" # e.g., depth, point cloud, etc. + VISION_LANGUAGE = "vision_language" # e.g., image + lang + + +class JointType(Enum): + QPOS = "qpos" + + +class EefType(Enum): + POSE = "eef_pose" + + +class ActionMode(Enum): + ABSOLUTE = "" + RELATIVE = "delta_" # This indicates the action is relative change with respect to last state. + + +class EndEffector(Enum): + GRIPPER = "gripper" + DEXTROUSHAND = "hand" + + +class EefExecute(Enum): + OPEN = "execute_open" + CLOSE = "execute_close" + + +class CameraName(Enum): + HEAD = "cam_high" + HEAD_RIGHT = get_right_name("cam_high") + RIGHT_WRIST = "cam_right_wrist" + LEFT_WRIST = "cam_left_wrist" + + +class ControlParts(Enum): + LEFT_ARM = "left_arm" + RIGHT_ARM = "right_arm" + LEFT_EEF = "left_eef" + RIGHT_EEF = "right_eef" + HEAD = "head" + WAIST = "waist" + + +class TeleoperationData(Enum): + """Enum for teleoperation data conversion script specific string constants""" + + # Camera types + HEAD_CAMERA = "head" + HAND_CAMERA = "hand" + + # Camera positions + LEFT_PLACE = "left" + RIGHT_PLACE = "right" + + # Camera name prefixes + CAM_HIGH_PREFIX = "cam_high" + CAM_HAND_PREFIX = "cam_hand" + + # File names and patterns + METADATA_FILE = "metadata.jsonl" + QPOS_PATTERN = "pose_record_*.json" + IMAGE_PATH_KEY = "image_path" + TIMESTAMP_KEY = "timestamp" + CAMERA_TYPE_KEY = "camera_type" + + # Data structure keys + OBSERVATIONS = "observations" + IMAGES = "images" + QPOS = "qpos" + ACTION = "action" + FRAMES = "frames" + DATA = "data" + + # Joint keys (common ones) + LEFT_GRIPPER = "LEFT_GRIPPER" + RIGHT_GRIPPER = "RIGHT_GRIPPER" + LEFT_HAND_PREFIX = "LEFT_HAND" + RIGHT_HAND_PREFIX = "RIGHT_HAND" + # Joint index mapping for real robot data + LEFT_ARM_QPOS_INDICES = [6, 7, 8, 9, 10, 11, 12] + RIGHT_ARM_QPOS_INDICES = [14, 15, 16, 17, 18, 19, 20] + LEFT_EEF_DEXTROUSHAND_INDICES = [22, 23, 24, 25, 26, 27] + RIGHT_EEF_DEXTROUSHAND_INDICES = [28, 29, 30, 31, 32, 33] + WAIST_QPOS_INDICES = [ + 3, + ] + HEAD_QPOS_INDICES = [4, 5] + + +class Hints(Enum): + EEF = ( + ControlParts.LEFT_EEF.value, + ControlParts.RIGHT_EEF.value, + EndEffector.GRIPPER.value, + EndEffector.DEXTROUSHAND.value, + ) + ARM = (ControlParts.LEFT_ARM.value, ControlParts.RIGHT_ARM.value) + + +class CameraLoc(AEnum): + # The difference between CameraLoc and CameraName is that CameraLoc allows duplicate values. + # And the value is used to indicate the sub-network ids, e.g. LEFT_WRIST and RIGHT_WRIST share the same sub-network feature extraction. + _settings_ = NoAlias + HEAD = 0 + RIGHT_WRIST = 1 + LEFT_WRIST = 1 + + +class CameraOrder(IntEnum): + # This is used to indicate the order of camera inputs, for both simulation and real deployment, training and inference. + # For dual system, the order is HEAD, RIGHT_WRIST, LEFT_WRIST. + HEAD = 0 + RIGHT_WRIST = 1 + LEFT_WRIST = 2 + + +DEFAULT_CAMERA_ORDER = {tmp.value: CameraName[tmp.name].value for tmp in CameraOrder} +DEFAULT_CAMERA_LOC = {CameraName[tmp.name].value: tmp.value for tmp in CameraLoc} + + +def link_type(*args) -> str: + l = len(args) + if l == 0: + return "" + elif l == 1: + return args[0] + elif l >= 2: + ret_str = "[{}]".format(args[0]) + for i in range(1, l): + ret_str += "_[{}]".format(args[i]) + return ret_str + + +combined_members = { + link_type(a.name + b.name, c.name + d.name, e.name): link_type( + a.value + b.value, c.value + d.value, e.value + ) + for a, b, c, d, e in product( + ActionMode, JointType, ActionMode, EefType, EndEffector + ) +} +ActionType = Enum("ActionType", combined_members) +combined_proprio_members = { + link_type(a.name, b.name, c.name): link_type(a.value, b.value, c.value) + for a, b, c in product(JointType, EefType, EndEffector) +} +ProprioType = Enum("ProprioType", combined_proprio_members) + + +def parse_action_type(action_type: str) -> Tuple[str, str, str]: + splits = action_type.split("[") + assert len(splits) == 3, "{} must contain 3-[].".format(action_type) + proprio_type = splits[0].split("]")[0] + eef_type = splits[1].split("]")[0] + end_effector = splits[2].split("]")[0] + return proprio_type, eef_type, end_effector + + +def parse_proprio_type(proprio_type: str) -> Tuple[str, str, str]: + return parse_action_type(proprio_type) + + +class PrivilegeType(Enum): + EXTEROCEPTION = "exteroception" + MASK = "mask" + STATE = "state" + PROGRESS = "progress" + + +SUPPORTED_PROPRIO_TYPES = [ + ControlParts.LEFT_ARM.value + EefType.POSE.value, + ControlParts.RIGHT_ARM.value + EefType.POSE.value, + ControlParts.LEFT_ARM.value + JointType.QPOS.value, + ControlParts.RIGHT_ARM.value + JointType.QPOS.value, + ControlParts.LEFT_EEF.value + EndEffector.DEXTROUSHAND.value, + ControlParts.RIGHT_EEF.value + EndEffector.DEXTROUSHAND.value, + ControlParts.LEFT_EEF.value + EndEffector.GRIPPER.value, + ControlParts.RIGHT_EEF.value + EndEffector.GRIPPER.value, +] +SUPPORTED_ACTION_TYPES = SUPPORTED_PROPRIO_TYPES + [ + ControlParts.LEFT_ARM.value + ActionMode.RELATIVE.value + JointType.QPOS.value, + ControlParts.RIGHT_ARM.value + ActionMode.RELATIVE.value + JointType.QPOS.value, +] +SUPPORTED_EXTRA_VISION_TYPES = [ + Modality.GEOMAP.value, + PrivilegeType.EXTEROCEPTION.value, + PrivilegeType.MASK.value, +] + + +def search_sub_str(sub_str: str, refs: List[str]): + ret = [False for _ in refs] + for i, ref in enumerate(refs): + ret[i] = sub_str in ref + return ret + + +class ArmEnum(IntEnum): + LEFT_ARM_ONLY = 1 + RIGHT_ARM_ONLY = 2 + DUAL_ARM = 3 + + +class ArmName(Enum): + LEFT_ARM_ONLY = "left_arm" + RIGHT_ARM_ONLY = "right_arm" + + +class SemanticMask(IntEnum): + BACKGROUND = 0 + FOREGROUND = 1 + ROBOT = 2 + + +def get_all_cond(suffix: str = "_cond") -> Dict[str, str]: + cond_dict = {} + for modality in Modality: + cond_dict[modality.value] = modality.value + suffix + for privilege in PrivilegeType: + cond_dict[privilege.value] = privilege.value + suffix + return cond_dict + + +def is_dual_arms(dofs: int) -> bool: + return dofs > 10 + + +from collections import deque + + +class HistoryChunks: + def __init__(self, history_len: int = 2) -> None: + self.deque = deque(maxlen=history_len) + self.history_len = history_len + + def inqueue(self, data: ModalInput) -> None: + self.deque.append(data) + + def __getitem__( + self, + index: int, + ) -> ModalInput: + return self.deque[index] + + def __len__( + self, + ) -> int: + return len(self.deque) + + def isfull( + self, + ) -> bool: + return len(self) == self.history_len + + def get_list( + self, + ) -> List[ModalInput]: + return list(self.deque) + + def clean(self) -> None: + self.deque.clear() diff --git a/embodichain/data/global_indices.py b/embodichain/data/global_indices.py new file mode 100644 index 00000000..518a2e4c --- /dev/null +++ b/embodichain/data/global_indices.py @@ -0,0 +1,132 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import numpy as np + +GLOBAL_INDICES = { + # [0, 10): right arm joint positions + **{"arm_joint_{}_pos".format(i): i for i in range(10)}, + **{"right_arm_joint_{}_pos".format(i): i for i in range(10)}, + # [10, 15): right gripper joint positions + **{"gripper_joint_{}_pos".format(i): i + 10 for i in range(5)}, + **{"right_gripper_joint_{}_pos".format(i): i + 10 for i in range(5)}, + "gripper_open": 10, # alias of right_gripper_joint_0_pos + "right_gripper_open": 10, + # [15, 25): right arm joint velocities + **{"arm_joint_{}_vel".format(i): i + 15 for i in range(10)}, + **{"right_arm_joint_{}_vel".format(i): i + 15 for i in range(10)}, + # [25, 30): right gripper joint velocities + **{"gripper_joint_{}_vel".format(i): i + 25 for i in range(5)}, + **{"right_gripper_joint_{}_vel".format(i): i + 25 for i in range(5)}, + "gripper_open_vel": 25, # alias of right_gripper_joint_0_vel + "right_gripper_open_vel": 25, + # [30, 33): right end effector positions + "eef_pos_x": 30, + "right_eef_pos_x": 30, + "eef_pos_y": 31, + "right_eef_pos_y": 31, + "eef_pos_z": 32, + "right_eef_pos_z": 32, + # [33, 39): right end effector 6D pose + "eef_angle_0": 33, + "right_eef_angle_0": 33, + "eef_angle_1": 34, + "right_eef_angle_1": 34, + "eef_angle_2": 35, + "right_eef_angle_2": 35, + "eef_angle_3": 36, + "right_eef_angle_3": 36, + "eef_angle_4": 37, + "right_eef_angle_4": 37, + "eef_angle_5": 38, + "right_eef_angle_5": 38, + # [39, 42): right end effector velocities + "eef_vel_x": 39, + "right_eef_vel_x": 39, + "eef_vel_y": 40, + "right_eef_vel_y": 40, + "eef_vel_z": 41, + "right_eef_vel_z": 41, + # [42, 45): right end effector angular velocities + "eef_angular_vel_roll": 42, + "right_eef_angular_vel_roll": 42, + "eef_angular_vel_pitch": 43, + "right_eef_angular_vel_pitch": 43, + "eef_angular_vel_yaw": 44, + "right_eef_angular_vel_yaw": 44, + # [45, 50): reserved + # [50, 60): left arm joint positions + **{"left_arm_joint_{}_pos".format(i): i + 50 for i in range(10)}, + # [60, 65): left gripper joint positions + **{"left_gripper_joint_{}_pos".format(i): i + 60 for i in range(5)}, + "left_gripper_open": 60, # alias of left_gripper_joint_0_pos + # [65, 75): left arm joint velocities + **{"left_arm_joint_{}_vel".format(i): i + 65 for i in range(10)}, + # [75, 80): left gripper joint velocities + **{"left_gripper_joint_{}_vel".format(i): i + 75 for i in range(5)}, + "left_gripper_open_vel": 75, # alias of left_gripper_joint_0_vel + # [80, 83): left end effector positions + "left_eef_pos_x": 80, + "left_eef_pos_y": 81, + "left_eef_pos_z": 82, + # [83, 89): left end effector 6D pose + "left_eef_angle_0": 83, + "left_eef_angle_1": 84, + "left_eef_angle_2": 85, + "left_eef_angle_3": 86, + "left_eef_angle_4": 87, + "left_eef_angle_5": 88, + # [89, 92): left end effector velocities + "left_eef_vel_x": 89, + "left_eef_vel_y": 90, + "left_eef_vel_z": 91, + # [92, 95): left end effector angular velocities + "left_eef_angular_vel_roll": 92, + "left_eef_angular_vel_pitch": 93, + "left_eef_angular_vel_yaw": 94, + # [95, 100): reserved + # [100, 102): base linear velocities + "base_vel_x": 100, + "base_vel_y": 101, + # [102, 103): base angular velocities + "base_angular_vel": 102, + # [103, 115): dextrous hand joint positions + **{"left_hand_joint_{}_pos".format(i): i + 103 for i in range(6)}, + **{"right_hand_joint_{}_pos".format(i): i + 109 for i in range(6)}, + # [115, 119): torso joint positions + **{"torso_joint_{}_pos".format(i): i + 115 for i in range(4)}, + # [119, 121): head joint positions + **{"head_joint_{}_pos".format(i): i + 119 for i in range(2)}, + "waist": 115, + # [121, 123): head joint velocities + **{"head_joint_{}_vel".format(i): i + 121 for i in range(2)}, + "waist_vel": 113, + # [124, 128): reserved +} + + +STATE_VEC_LEN = 128 + + +def get_all_left_related_indices(including_end: bool = True): + if including_end: + return np.arange(50, 128, step=1) + else: + return np.arange(50, 100) + + +def get_all_right_related_indices(): + return np.arange(0, 50) diff --git a/embodichain/data/global_mapping.py b/embodichain/data/global_mapping.py new file mode 100644 index 00000000..24ebcb34 --- /dev/null +++ b/embodichain/data/global_mapping.py @@ -0,0 +1,161 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from embodichain.data.enum import ( + ControlParts, + ActionMode, + EndEffector, + JointType, + EefType, + is_dual_arms, +) +from embodichain.data.global_indices import GLOBAL_INDICES +import numpy as np +from typing import List + + +class GlobalMapping: + def __init__(self, dof: int): + self_attrs = GlobalMapping.__dict__ + num_arm = 2 if is_dual_arms(dofs=dof) else 1 + single_dof = dof // num_arm + function_dict = {} + for k, v in self_attrs.items(): + if isinstance(v, staticmethod) and "__" not in k: + function_dict.update(v.__func__(dof=single_dof, num_arm=num_arm)) + self.mapping_from_name_to_indices = function_dict + + @staticmethod + def get_qpos_indices(dof: int, num_arm, **kwrags): + + return { + ControlParts.LEFT_ARM.value + + JointType.QPOS.value: [ + GLOBAL_INDICES[f"left_arm_joint_{i}_pos"] for i in range(dof) + ], + ControlParts.RIGHT_ARM.value + + JointType.QPOS.value: [ + GLOBAL_INDICES[f"right_arm_joint_{i}_pos"] for i in range(dof) + ], + ControlParts.HEAD.value + + JointType.QPOS.value: [ + GLOBAL_INDICES["head_joint_{}_pos".format(i)] for i in range(2) + ], + ControlParts.WAIST.value + JointType.QPOS.value: [GLOBAL_INDICES["waist"]], + } + + @staticmethod + def get_gripper_open_state_indices(num_arm, **kwrags): + return { + ControlParts.LEFT_EEF.value + + EndEffector.GRIPPER.value: [GLOBAL_INDICES["left_gripper_open"]], + ControlParts.RIGHT_EEF.value + + EndEffector.GRIPPER.value: [GLOBAL_INDICES["right_gripper_open"]], + } + + @staticmethod + def get_hand_qpos_indices(num_arm: int, hand_dof: int = 6, **kwrags): + return { + ControlParts.LEFT_EEF.value + + EndEffector.DEXTROUSHAND.value: [ + GLOBAL_INDICES[f"left_hand_joint_{i}_pos"] for i in range(hand_dof) + ], + ControlParts.RIGHT_EEF.value + + EndEffector.DEXTROUSHAND.value: [ + GLOBAL_INDICES[f"right_hand_joint_{i}_pos"] for i in range(hand_dof) + ], + } + + @staticmethod + def get_gripper_open_vel_indices(num_arm, **kwrags): + return { + ControlParts.LEFT_EEF.value + + ActionMode.RELATIVE.value + + EndEffector.GRIPPER.value: [GLOBAL_INDICES["left_gripper_open_vel"]], + ControlParts.RIGHT_EEF.value + + ActionMode.RELATIVE.value + + EndEffector.GRIPPER.value: [GLOBAL_INDICES["right_gripper_open_vel"]], + } + + @staticmethod + def get_delta_qpos_indices(dof: int, num_arm, **kwrags): + return { + ControlParts.LEFT_ARM.value + + ActionMode.RELATIVE.value + + JointType.QPOS.value: [ + GLOBAL_INDICES[f"left_arm_joint_{i}_vel"] for i in range(dof) + ], + ControlParts.RIGHT_ARM.value + + ActionMode.RELATIVE.value + + JointType.QPOS.value: [ + GLOBAL_INDICES[f"right_arm_joint_{i}_vel"] for i in range(dof) + ], + ControlParts.HEAD.value + + ActionMode.RELATIVE.value + + JointType.QPOS.value: [ + GLOBAL_INDICES["head_joint_{}_vel".format(i)] for i in range(2) + ], + ControlParts.WAIST.value + + ActionMode.RELATIVE.value + + JointType.QPOS.value: [GLOBAL_INDICES["waist_vel"]], + } + + @staticmethod + def get_eef_pose_indices(num_arm, **kwrags): + return { + ControlParts.LEFT_ARM.value + + EefType.POSE.value: [ + GLOBAL_INDICES["left_eef_pos_x"], + GLOBAL_INDICES["left_eef_pos_y"], + GLOBAL_INDICES["left_eef_pos_z"], + GLOBAL_INDICES["left_eef_angle_0"], + GLOBAL_INDICES["left_eef_angle_1"], + GLOBAL_INDICES["left_eef_angle_2"], + GLOBAL_INDICES["left_eef_angle_3"], + GLOBAL_INDICES["left_eef_angle_4"], + GLOBAL_INDICES["left_eef_angle_5"], + ], + ControlParts.RIGHT_ARM.value + + EefType.POSE.value: [ + GLOBAL_INDICES["right_eef_pos_x"], + GLOBAL_INDICES["right_eef_pos_y"], + GLOBAL_INDICES["right_eef_pos_z"], + GLOBAL_INDICES["right_eef_angle_0"], + GLOBAL_INDICES["right_eef_angle_1"], + GLOBAL_INDICES["right_eef_angle_2"], + GLOBAL_INDICES["right_eef_angle_3"], + GLOBAL_INDICES["right_eef_angle_4"], + GLOBAL_INDICES["right_eef_angle_5"], + ], + } + + def get_indices(self, state_meta: List[str]): + state_indices = [] + + for proprio_name in state_meta: + state_indices += self.mapping_from_name_to_indices[proprio_name] + + return state_indices + + def ret_all_state( + self, + ): + state_indices = [] + + for val in self.mapping_from_name_to_indices.values(): + state_indices += val + + return state_indices diff --git a/embodichain/lab/__init__.py b/embodichain/lab/__init__.py new file mode 100644 index 00000000..e66036ce --- /dev/null +++ b/embodichain/lab/__init__.py @@ -0,0 +1,19 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from . import sim +from . import gym +from . import devices diff --git a/embodichain/lab/devices/__init__.py b/embodichain/lab/devices/__init__.py new file mode 100644 index 00000000..29525240 --- /dev/null +++ b/embodichain/lab/devices/__init__.py @@ -0,0 +1,17 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from .device import Device diff --git a/embodichain/lab/devices/device.py b/embodichain/lab/devices/device.py new file mode 100644 index 00000000..4e75cd9c --- /dev/null +++ b/embodichain/lab/devices/device.py @@ -0,0 +1,44 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import abc # for abstract base class definitions + + +class Device(metaclass=abc.ABCMeta): + """ + Base class for all robot controllers. + Defines basic interface for all controllers to adhere to. + """ + + @abc.abstractmethod + def start_control(self): + """ + Method that should be called externally before controller can + start receiving commands. + """ + raise NotImplementedError + + @abc.abstractmethod + def stop_control(self): + """ + Method that should be called externally to stop the controller. + """ + raise NotImplementedError + + @abc.abstractmethod + def get_controller_state(self): + """Returns the current state of the device, a dictionary of pos, orn, grasp, and reset.""" + raise NotImplementedError diff --git a/embodichain/lab/gym/__init__.py b/embodichain/lab/gym/__init__.py new file mode 100644 index 00000000..eae2638c --- /dev/null +++ b/embodichain/lab/gym/__init__.py @@ -0,0 +1,18 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from . import envs +from . import utils diff --git a/embodichain/lab/gym/envs/__init__.py b/embodichain/lab/gym/envs/__init__.py new file mode 100644 index 00000000..286d7824 --- /dev/null +++ b/embodichain/lab/gym/envs/__init__.py @@ -0,0 +1,31 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from .base_env import * +from .embodied_env import * +from .tasks import * +from .wrapper import * + +from embodichain.lab.gym.envs.embodied_env import EmbodiedEnv + +# Specific task environments +from embodichain.lab.gym.envs.tasks.tableware.pour_water.pour_water import ( + PourWaterEnv, +) +from embodichain.lab.gym.envs.tasks.tableware.scoop_ice import ScoopIce + +# Reinforcement learning environments +from embodichain.lab.gym.envs.tasks.rl.push_cube import PushCubeEnv diff --git a/embodichain/lab/gym/envs/action_bank/configurable_action.py b/embodichain/lab/gym/envs/action_bank/configurable_action.py new file mode 100644 index 00000000..9c285aa5 --- /dev/null +++ b/embodichain/lab/gym/envs/action_bank/configurable_action.py @@ -0,0 +1,1472 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import numpy as np +import networkx as nx +from copy import deepcopy +from typing import Dict, Tuple, Union, List, Callable, Any, Optional +from tqdm import tqdm + +import torch +import matplotlib.pyplot as plt +from functools import partial +from embodichain.utils.math import pose_inv +from embodichain.utils.logger import log_info, log_warning, log_error +from embodichain.lab.sim.cfg import MarkerCfg +from embodichain.lab.gym.utils.misc import resolve_env_params, _data_key_to_control_part +from embodichain.data.enum import Hints, EefExecute +from .utils import generate_affordance_from_src, get_init_affordance +import functools + + +# https://stackoverflow.com/questions/41834530/how-to-make-python-decorators-work-like-a-tag-to-make-function-calls-by-tag +class TagDecorator(object): + def __init__(self, tagName): + self.functions = {} + self.tagName = tagName + + def __str__(self): + return "".format(tagName=self.tagName) + + def __call__(self, f): + # class_key = f"{f.__module__}.{f.__qualname__.rsplit('.', 1)[0]}" + class_name = f.__qualname__.split(".")[0] + if class_name in self.functions.keys(): + self.functions[class_name].update({f.__name__: f}) + else: + self.functions.update({class_name: {f.__name__: f}}) + return f + + +@functools.lru_cache(maxsize=None) # memoization +def get_func_tag(tagName): + return TagDecorator(tagName) + + +tag_node = get_func_tag("node") +tag_edge = get_func_tag("edge") + + +class ActionBank: + _function_type: Dict[str, Callable] + + def __init__(self, conf: Dict): + # 轨迹可以是qpos或者xpos,不同方法检查一下shape,每个姿态qpos是1维(16维),xpos是2维(4x4的矩阵) + self.conf = conf + + @property + def vis_gantt(self): + return self.conf.get("misc", {}).get("vis_gantt", False) + + @property + def vis_graph(self): + return self.conf.get("misc", {}).get("vis_graph", False) + + @property + def warpping(self): + return self.conf.get("misc", {}).get("warpping", True) + + @staticmethod + def get_function_name(input: Dict) -> str: + """ + Retrieve the function name from the input dictionary. + + This method assumes that the input dictionary contains exactly one key, + which represents the function name. If the dictionary contains more than + one key, a ValueError is raised. + + Args: + input (Dict): A dictionary with a single key representing the function name. + + Returns: + str: The function name extracted from the dictionary. + + Raises: + ValueError: If the input dictionary contains zero or more than one key. + """ + if len(list(input.keys())) != 1: + raise ValueError( + "The input dict {} has invalid keys {}.".format( + input, list(input.keys()) + ) + ) + + return list(input.keys())[0] + + def get_scope_names( + self, + ) -> List[str]: + return list(self.conf["scope"].keys()) + + def get_node_names(self, bool_attr_name: str = None) -> Dict[str, List[str]]: + scopes = self.get_scope_names() + nodes = self.conf["node"] + node_names = {} + for scope in scopes: + node_names[scope] = [] + for node in nodes[scope]: + if bool_attr_name is not None: + if node[self.get_function_name(node)].get(bool_attr_name, False): + function_name = ActionBank.get_function_name(node) + node_names[scope].append(function_name) + return node_names + + def graph2id(self, type: str = "node") -> Dict[str, Dict[str, str]]: + scopes = self.get_scope_names() + nodes = self.conf[type] + graph_2_id = {} + for scope in scopes: + graph_2_id[scope] = {} + for i, node in enumerate(nodes[scope]): + function_name = ActionBank.get_function_name(node) + graph_2_id[scope].update({function_name: i}) + return graph_2_id + + def get_edge_names(self, node_name: str = None) -> Dict[str, List[Dict[str, str]]]: + scopes = self.get_scope_names() + edges = self.conf["edge"] + edge_names = {} + for i, key in enumerate(scopes): + edge_names[key] = [] + for edge in edges[key]: + function_name = ActionBank.get_function_name(edge) + src = edge[function_name]["src"] + sink = edge[function_name]["sink"] + temp = {"name": function_name, "src": src, "sink": sink} + edge_names[key].append(temp) + if node_name is None: + return edge_names + else: + filtered_edge_names = {} + for scope, edge_list in edge_names.items(): + filtered_edge_names[scope] = [ + edge + for edge in edge_list + if edge["src"] == node_name or edge["sink"] == node_name + ] + return filtered_edge_names + + def _infer_fill_type(self, scope: str, label: str, edge_cfg: Dict) -> str: + # 1) explicit in config + ft = edge_cfg.get("fill_type", None) + # 2) built-in eef rules + fn = edge_cfg.get("name", None) + if fn in {EefExecute.OPEN.value, EefExecute.CLOSE.value}: + return "still" + # 3) explicit wins; otherwise default + return ft if ft in ("still", "scalable") else "still" + + def _get_unit_pairs(self, legends: List[str]) -> Dict[str, str]: + """ + Return a symmetric map executor -> partner within the same unit (arm+eef). + Priority: + 1) explicit config: self.conf["misc"]["unit_pairs"] = [["right_arm","right_eefhand"], ["left_arm","left_eefhand"], ...] + 2) heuristic by side-prefix and name hints ('arm' vs 'eef'/'hand'/'gripper') + """ + pairs: Dict[str, str] = {} + + # 1) explicit mapping if provided + explicit = self.conf.get("misc", {}).get("unit_pairs", None) + if explicit: + for a, b in explicit: + if a in legends and b in legends: + pairs[a] = b + pairs[b] = a + + # 2) heuristic fallback + def side_key(name: str) -> str: + # prefer token before '_' or '-', else prefix match + if "_" in name: + return name.split("_", 1)[0] + if "-" in name: + return name.split("-", 1)[0] + for pref in ("left", "right", "L", "R"): + if name.lower().startswith(pref.lower()): + return pref + return "" + + eef_hints = Hints.EEF.value + arm_hints = Hints.ARM.value + + from collections import defaultdict + + by_side = defaultdict(list) + for n in legends: + by_side[side_key(n)].append(n) + + for _, names in by_side.items(): + arms = [n for n in names if any(h in n.lower() for h in arm_hints)] + eefs = [n for n in names if any(h in n.lower() for h in eef_hints)] + # pair the first unmatched arm with the first unmatched eef + for a in arms: + if a in pairs: + continue + partner = next((e for e in eefs if e not in pairs), None) + if partner: + pairs[a] = partner + pairs[partner] = a + + return pairs + + def _apply_bubble_filling(self, packages, taskkey2index): + from collections import defaultdict + + if not packages: + return packages + + # group by executor + per_legend = defaultdict(list) + for p in packages: + per_legend[p["legend"]].append(p) + for lg in per_legend: + per_legend[lg].sort(key=lambda x: (x["start"], x["end"])) + + legends = list(per_legend.keys()) + unit_pairs = self._get_unit_pairs(legends) + + # fill_type lookup + fill_type = {} + for scope, scope_edges in self.conf.get("edge", {}).items(): + for edge in scope_edges: + lbl = list(edge.keys())[0] + fill_type[lbl] = edge[lbl].get("fill_type", "still") + + label2pkg = {p["label"]: p for p in packages} + global_end = max(p["end"] for p in packages) + + def first_start_at_or_after(seq, t): + for pkg in seq: + if pkg["start"] >= t: + return pkg["start"] + return global_end + + # optional sync boundary (unchanged) + dep_of = defaultdict(list) + for e_label, s in self.conf.get("sync", {}).items(): + for d in s.get("depend_tasks", []): + dep_of[d].append(e_label) + + def sync_boundary_for(lbl): + deps = dep_of.get(lbl, []) + if not deps: + return None + starts = [label2pkg[d]["start"] for d in deps if d in label2pkg] + return max(starts) if starts else None + + # unit-aware filling + for lg, seq in per_legend.items(): + partner = unit_pairs.get(lg, None) + partner_seq = per_legend.get(partner, []) if partner else [] + + # middle gaps + for i in range(len(seq) - 1): + curr, nxt = seq[i], seq[i + 1] + if curr["end"] < nxt["start"]: + cap_local = nxt["start"] + cap_partner = ( + first_start_at_or_after(partner_seq, curr["end"]) + if partner_seq + else global_end + ) + cap_sync = sync_boundary_for(curr["label"]) + cap = ( + min(cap_local, cap_partner, cap_sync) + if cap_sync is not None + else min(cap_local, cap_partner) + ) + if cap > curr["end"]: + curr["end"] = cap + curr.setdefault( + "fill_type", fill_type.get(curr["label"], "still") + ) + + # tail gap: cap by partner’s next (≥ end), not global + if seq: + last = seq[-1] + cap_partner = ( + first_start_at_or_after(partner_seq, last["end"]) + if partner_seq + else last["end"] + ) + if cap_partner > last["end"]: + last["end"] = cap_partner + last.setdefault("fill_type", fill_type.get(last["label"], "still")) + + return packages + + def parse_network( + self, + node_functions: Dict[str, Callable], + edge_functions: Dict[str, Callable], + vis_graph: bool = False, + ) -> Tuple[nx.DiGraph, Dict[str, List], Dict[str, Tuple[int, int]]]: + """Construct a graph with self.conf["node"]&["edge"], and node_functions, edge_functions be its node generator and edge linker. + + Return the constructed nx.DiGraph graph_compose, + + and tasks_data = {"scope name" : [(scope_id=task_id, skill_duration_{i})]}, + + and taskkey2index = {"edge_name": (scope_id=task_id, edge_id=skill_id)} + + Args: + node_functions (Dict[str, Callable]): A Dict consists of key-value pair that key be all nodes (affordance) name and value be its generating functions + edge_functions (Dict[str, Callable]): A Dict consists of key-value pair that key be all edges (skill, a part of a trajectory) name and value be its linker functions + vis_graph (bool, optional): Whether to show the graph or not. Defaults to True. + + Returns: + graph_compose: A composed nx.DiGraph representing the graph defined in self.conf, while the node generators and edge linkers prepared, + tasks_data: A Dict consists of key-value pair that key be all scopes' names, and value be List(Tuple=(scope_id=task_id, skill_duration)) + taskkey2index: A Dict consists of key-value pair that key be all edges' names, and value be Tuple=(scope_id=task_id, edge_id=skill_id) + """ + nodes = self.conf.get("node", {}) + edges = self.conf.get("edge", {}) + graph_type = self.conf.get("scope", {}) + + graphs = {key: nx.DiGraph() for key in graph_type.keys()} + disjoint_names = {} + tasks_data = {} + taskkey2index = {} + + # key2index = {} + edges_flatten = {} + for i, key in enumerate(graphs.keys()): + # key2index[key] = i + for j, edge in enumerate(edges[key]): + edge = deepcopy(edge) + taskkey2index[ActionBank.get_function_name(edge)] = (i, j) + edge["type"] = key + edges_flatten.update(edge) + for i, key in enumerate(graphs.keys()): + tasks_data[key] = [] + for edge in edges[key]: + label = ActionBank.get_function_name(edge) # edge label in config + cfg = edge[label] + src = cfg["src"] + sink = cfg["sink"] + kwargs = cfg.get("kwargs", {}) + duration = cfg.get("duration", 0) + if not isinstance(duration, int): + raise TypeError("Duration must be an integer.") + + # function to call + fn_name = cfg.get("name", label) + # normalize and persist fill_type (default + built-in rules) + fill_type = self._infer_fill_type(key, label, cfg) + + graphs[key].add_edge( + src, + sink, + linker=partial( + edge_functions[fn_name], **kwargs, duration=duration + ), + duration=duration, + fill_type=fill_type, + edge_label=label, + scope=key, + ) + tasks_data[key].append((i, duration)) + + for node in nodes[key]: + function_name = ActionBank.get_function_name(node) + if function_name in disjoint_names.keys(): + error_msg = f"Function {function_name} is already defined in {disjoint_names[function_name]} but re-defined in {key} again." + log_error(error_msg) + disjoint_names.update({function_name: key}) + graphs[key].add_node( + function_name, + generator=partial( + node_functions[node[function_name]["name"]], + **node[function_name]["kwargs"], + ), + ) + + graph_compose = nx.DiGraph() + for key, graph in graphs.items(): + if self.vis_graph or vis_graph: + nx.draw(graph, with_labels=True) + plt.show() + if graph_type[key]["type"] == "tree": + assert nx.is_tree( + graph.to_undirected() + ), "{} graph is not tree.".format(key) + + graph_compose = nx.compose(graph_compose, graph) + + if self.vis_graph or vis_graph: + nx.draw(graph_compose, with_labels=True) + plt.show() + + return graph_compose, tasks_data, taskkey2index + + def gantt( + self, + tasks_data: Dict[str, List], + taskkey2index: Dict[str, int], + vis: bool = False, + ) -> Dict[str, Any]: + """Given tasks on different machines and skills within tasks that takes a specific duration, try to minimize the max length among task, while respecting: + Constraint 1: For skills of a same task, which occupied a same machine, do not overlap with each other. + Constraint 2: For skills of a same task, the start time of skill should not surpass the end time of the before skill + Constraint 3: For sync edges define in self.conf["sync"], which defined a skill, its start time should not surpass the end time of the depend skill. + with a set of start and end time of all skills. Then draw the gantt with the solution, return the solution packages with start and end time of each edge. + + Args: + tasks_data (Dict[str, List]): A Dict consists of key-value pair that key be all scopes' names, and value be List(Tuple=(scope_id=task_id, skill_duration)) + taskkey2index (Dict[str, int]): A Dict consists of key-value pair that key be all edges' names, and value be Tuple=(scope_id=task_id, edge_id=skill_id) + vis (bool, optional): Whether to visualize the gantt or not. Defaults to False. + + Returns: + packages Dict[str, Any]: + { + "total_num": int(solver.objective_value)=max length of a task (among tasks = scopes = machines = executors), + "packages": packages=List(Dict=assigned_task={ + "labels": edge_name, + "start" : edge.start + "end": edge.start + edge.duration + "legend": task_name + "color": color representing the task_id, the former the bluer + } + ) + } + """ + import collections + + # https://developers.google.com/optimization/scheduling/task_shop?hl=zh-cn + from ortools.sat.python import cp_model + + machines_count = len(list(tasks_data.keys())) + all_machines = range(machines_count) + # Computes horizon dynamically as the sum of all durations. + id2key = {i: key for i, key in enumerate(tasks_data.keys())} + tasks_data = list(tasks_data.values()) + + horizon = sum(skill[1] for task in tasks_data for skill in task) + model = cp_model.CpModel() + # Named tuple to store information about created variables. + skill_type = collections.namedtuple("skill_type", "start end interval") + + # Creates task intervals and add to the corresponding machine lists. + all_skills = {} + machine_to_intervals = collections.defaultdict(list) + + for task_id, task in enumerate(tasks_data): + for skill_id, skill in enumerate(task): + machine, duration = skill + suffix = f"_{task_id}_{skill_id}" + start_var = model.new_int_var(0, horizon, "start" + suffix) + end_var = model.new_int_var(0, horizon, "end" + suffix) + interval_var = model.new_interval_var( + start_var, duration, end_var, "interval" + suffix + ) + all_skills[task_id, skill_id] = skill_type( + start=start_var, end=end_var, interval=interval_var + ) + machine_to_intervals[machine].append(interval_var) + + # Create and add disjunctive constraints for each machine. + for machine in all_machines: + model.add_no_overlap(machine_to_intervals[machine]) + + # Precedences inside a task. + for task_id, task in enumerate(tasks_data): + for skill_id in range(len(task) - 1): + model.add( + all_skills[task_id, skill_id + 1].start + >= all_skills[task_id, skill_id].end + ) + + sync_edges = self.conf["sync"] + for edge_name in sync_edges.keys(): + task_id, skill_id = taskkey2index[edge_name] + for depend_task in sync_edges[edge_name]["depend_tasks"]: + before_task_id, before_skill_id = taskkey2index[depend_task] + model.add( + all_skills[task_id, skill_id].start + >= all_skills[before_task_id, before_skill_id].end + ) + + # Makespan objective. + obj_var = model.new_int_var(0, horizon, "makespan") + + max_equality = [] + for task_id, task in enumerate(tasks_data): + if len(task) != 0: + max_equality.append(all_skills[task_id, len(task) - 1].end) + + model.add_max_equality(obj_var, max_equality) + model.minimize(obj_var) + solver = cp_model.CpSolver() + status = solver.solve(model) + + packages = [] + if status == cp_model.OPTIMAL or status == cp_model.FEASIBLE: + # Create one list of assigned skills per machine. + n = len(list(id2key.keys())) + color = [(1 - i / n, 0, i / n) for i in range(n)] + keys_ = list(taskkey2index.keys()) + values_ = list(taskkey2index.values()) + for task_id, task in enumerate(tasks_data): + for skill_id, skill in enumerate(task): + machine = skill[0] + duration = skill[1] + start = solver.value(all_skills[task_id, skill_id].start) + + assigned_task = {} + assigned_task["label"] = keys_[values_.index((task_id, skill_id))] + assigned_task["start"] = start + assigned_task["end"] = start + duration + assigned_task["legend"] = id2key[task_id] + assigned_task["color"] = color[task_id] + packages.append(assigned_task) + + # Finally print the solution found. + log_warning(f"Optimal Schedule Length: {solver.objective_value}") + else: + log_error("No solution found.") + + packages = self._apply_bubble_filling(packages, taskkey2index) + + new_total = max((p["end"] for p in packages), default=0) + + if self.vis_gantt or vis: + from embodichain.utils.visualizer import Gantt + + draw_gantt_data = { + "title": " Sample GANTT", + "xlabel": "Trajectory (steps)", + "packages": packages, + } + g = Gantt(draw_gantt_data) + g.render() + g.show() + + return {"total_num": int(new_total), "packages": packages} + + def initialize_action_list( + self, env, action_list, executor: str, executor_init_info: Dict + ) -> np.ndarray: + """ + Initialize the action list for a specific executor. + + This method initializes the action trajectory for the given executor based on the provided initialization information. + The initialization can be done using predefined qpos values or the current qpos of the executor. + + Args: + self (ActionBank): The ActionBank instance. + env (object): The environment instance containing executor and affordance data. + action_list (np.ndarray): A numpy array of shape (T, qpos_dim), representing the uninitialized action trajectory for the executor. + executor (str): The name of the executor (e.g., "left_arm", "right_arm"). + executor_init_info (Dict): A dictionary containing initialization information for the executor, such as method and parameters. + + Returns: + np.ndarray: The initialized action list for the executor. + """ + + def initialize_with_given_qpos(action_list, executor, executor_init_info, env): + given_qpos = executor_init_info.get("kwargs", {}).get("given_qpos", None) + if given_qpos is None: + log_warning( + "No given_qpos is provided for initialize_with_given_qpos. Using {}.".format( + get_init_affordance(executor) + ) + ) + given_qpos = env.affordance_datas[get_init_affordance(executor)] + + executor_qpos_dim = action_list[executor].shape[0] + given_qpos = np.asarray(given_qpos) + if len(given_qpos.shape) != 1: + log_warning( + f"Shape of given init qpos should be (1,), but got {given_qpos.shape} with length {len(given_qpos.shape)}. Using 0-th element with {given_qpos.shape[-1]}." + ) + last_ids = (0,) * (given_qpos.ndim - 1) + (Ellipsis,) + given_qpos = given_qpos[last_ids] + + if given_qpos.shape[0] != executor_qpos_dim: + log_error( + f"Shape of given init qpos should be {(executor_qpos_dim,)}, but got {given_qpos.shape[0]}." + ) + + init_node_name = executor_init_info.get( + "init_node_name", f"{executor}_init_qpos" + ) + if ( + len(init_node_name) > 0 + ): # so if you don't need to inject it, just assign the "init_node_name" to be "" in action_config + env.affordance_datas[init_node_name] = given_qpos + action_list[executor][:, 0] = given_qpos + + return action_list + + def initialize_with_current_qpos( + action_list, executor, executor_init_info, env + ): + # TODO: Hard to get current qpos for multi-agent env + current_qpos = env.robot.get_qpos() + joint_ids = env.robot.get_joint_ids(name=get_control_part(env, executor)) + if current_qpos.ndim == 2 and current_qpos.shape[0] == 1: + current_qpos = current_qpos[0] + current_qpos = current_qpos[joint_ids].cpu() + + executor_qpos_dim = action_list[executor].shape[0] + + # NOTE: hard code! + current_qpos = current_qpos[:executor_qpos_dim] + + if current_qpos.shape[0] != executor_qpos_dim: + log_error( + f"Shape of given init qpos should be {(executor_qpos_dim,)}, but got {current_qpos.shape[0]}." + ) + + init_node_name = executor_init_info.get( + "init_node_name", f"{executor}_init_qpos" + ) + if ( + len(init_node_name) > 0 + ): # so if you don't need to inject it, just assign the "init_node_name" to be "" in action_config + env.affordance_datas[init_node_name] = current_qpos + action_list[executor][:, 0] = current_qpos + + return action_list + + INIT_METHOD_MAPPING = { + "given_qpos": initialize_with_given_qpos, + "current_qpos": initialize_with_current_qpos, + } + + init_method = executor_init_info.get("method", "current_qpos") + + if init_method is None: + log_warning( + f"No init method provided in action config for executor {executor}, please check. Skipping.." + ) + return action_list + if init_method not in INIT_METHOD_MAPPING: + log_warning( + f"Provided init method action config for executor {executor}: {init_method} is not accomplished yet, please check. Skipping.." + ) + return action_list + else: + init_func = INIT_METHOD_MAPPING[init_method] + action_list = init_func(action_list, executor, executor_init_info, env) + + return action_list + + @staticmethod + @tag_node + @resolve_env_params + def generate_affordances_from_src(env, affordance_infos: List[Dict]) -> bool: + for affordance_info in affordance_infos: + src_key = affordance_info["src_key"] + dst_key = affordance_info["dst_key"] + valid_funcs_name_kwargs_proc = affordance_info[ + "valid_funcs_name_kwargs_proc" + ] + to_array = env.action_bank.warpping + + ret = generate_affordance_from_src( + env, src_key, dst_key, valid_funcs_name_kwargs_proc, to_array + ) + if not ret: + return False + return True + + def _prepare_warpping(self, env): + if hasattr(env, "affordance_datas"): + for affordance_name, affordance_value in env.affordance_datas.items(): + # NOTE: take only first arena's affordance data + if affordance_value.ndim == 3: + affordance_value = affordance_value[0] + if isinstance(affordance_value, torch.Tensor): + affordance_value = np.asarray(affordance_value.cpu()) + env.affordance_datas[affordance_name] = affordance_value + else: + log_warning("No env.affordance_datas, skip _prepare_warpping..") + + def create_action_list( + self, env, graph_compose: nx.DiGraph, packages: List[Dict], **kwargs + ) -> Dict: + """Create an action list based on the given environment, graph, and packages. + + Args: + env (embodichain.lab.gym.envs.BaseEnv): The environment instance. + graph_compose (nx.DiGraph): The composed graph containing nodes and edges. + packages (List[Dict]): The task packages with scheduling information. + + Returns: + Dict: The generated action list for all executors. + """ + + def initialize_action_list( + scope: Dict, total_num: int + ) -> Tuple[Dict, Dict, Dict]: + """Initialize action list and related variables.""" + action_list = {} + end_time = {} + in_working = {} + + for executor in scope.keys(): + end_time[executor] = 0 + in_working[executor] = False + + action_list[executor] = np.zeros( + tuple(scope[executor]["dim"]) + (total_num,), + dtype=getattr(np, scope[executor]["dtype"]), + ) + + init_info = scope[executor].get("init", {}) + action_list = self.initialize_action_list( + env, action_list, executor, init_info + ) + + return action_list, end_time, in_working + + def generate_nodes(graph_compose: nx.DiGraph, nodes: Dict) -> bool: + """Generate nodes using the graph's node generators.""" + node_generators = nx.get_node_attributes(graph_compose, "generator") + + failed_nodes = [] + log_info("Action bank start node generation for action graph...") + for node_dict_list in nodes.values(): + for node in node_dict_list: + node_name = list(node.keys())[0] + try: + log_info(f"\tGenerating node '{node_name}' .") + ret = node_generators[node_name](env, **kwargs) + if not ret: + log_warning(f"Node '{node_name}' generation fails.") + failed_nodes.append(node_name) + except KeyError as e: + log_warning( + f"[KeyError] '{node_name}': {e}. Node generator might be missing or invalid." + ) + failed_nodes.append(node_name) + except AttributeError as e: + log_warning( + f"[AttributeError] '{node_name}': {e}. Missing required attributes in environment." + ) + failed_nodes.append(node_name) + except TypeError as e: + log_warning( + f"[TypeError] '{node_name}': {e}. Check input data types." + ) + failed_nodes.append(node_name) + except ValueError as e: + log_warning( + f"[ValueError] '{node_name}': {e}. Check input values." + ) + failed_nodes.append(node_name) + except Exception as e: + log_warning( + f"[UnexpectedError] '{node_name}': {e}. Debug dependencies or implementation." + ) + failed_nodes.append(node_name) + if failed_nodes: + log_warning(f"Failed to generate the following nodes: {failed_nodes}") + return False + + log_info( + f"Node generation is finished. Total nodes generated: {sum(len(v) for v in nodes.values())}." + ) + return True + + def generate_edges( + total_num: int, + all_executors: List[str], + edges_flatten: Dict, + node_linkers: Dict, + ) -> None: + """ + Generate edges and populate the action list for all executors. + + Args: + total_num (int): The total number of time steps for the action list. + all_executors (List[str]): A list of executor names (e.g., "left_arm", "right_arm"). + edges_flatten (Dict[str, Dict]): A flattened dictionary of edges, where keys are edge labels + and values are dictionaries containing edge details (e.g., "src", "sink"). + node_linkers (Dict[Tuple[str, str], Callable]): A dictionary mapping edge (source, sink) pairs + to their corresponding linker functions. + + Returns: + None: This function modifies the `action_list` in place. + """ + + def get_task_in_time(tasks, time): + """Get the task that is active at the given time.""" + return next( + (task for task in tasks if task["start"] <= time < task["end"]), + None, + ) + + for i in tqdm(range(total_num), desc="Generating edges"): + for executor in all_executors: + if end_time[executor] == i: + in_working[executor] = False + + if not in_working[executor]: + pkg = get_task_in_time( + [ + pkg + for pkg in packages["packages"] + if pkg["legend"] == executor + ], + i, + ) + if pkg is None: + if i >= 1: + action_list[executor][..., i] = action_list[executor][ + ..., i - 1 + ] + else: + end_time[executor] = pkg["end"] + skill_idx = ( + edges_flatten[pkg["label"]]["src"], + edges_flatten[pkg["label"]]["sink"], + ) + ret = node_linkers[skill_idx](env) + if not isinstance(ret, np.ndarray): + + if isinstance(ret, torch.Tensor): + ret = ret.cpu().numpy() + else: + raise TypeError( + "The return value of the linker {} must be a numpy array, but a {}.".format( + skill_idx, type(ret) + ) + ) + + start_idx = pkg["start"] + end_idx = pkg["end"] + + T_need = end_idx - start_idx + T_orig = ret.shape[1] + + # fill_type of this edge + ft = edges_flatten[pkg["label"]].get( + "fill_type", pkg.get("fill_type", "still") + ) + + def _resample_time(x, new_T): + if new_T == x.shape[1]: + return x + if x.shape[1] <= 1: + return np.repeat(x, new_T, axis=1)[:, :new_T] + t_old = np.linspace(0.0, 1.0, x.shape[1]) + t_new = np.linspace(0.0, 1.0, new_T) + out = np.empty((x.shape[0], new_T), dtype=x.dtype) + for d in range(x.shape[0]): + out[d] = np.interp(t_new, t_old, x[d]) + return out + + def _pad_or_trim_last(x, new_T): + if new_T <= x.shape[1]: + return x[:, :new_T] + pad = np.repeat(x[:, -1:], new_T - x.shape[1], axis=1) + return np.concatenate([x, pad], axis=1) + + if T_need != T_orig: + if ft == "scalable": + ret = _resample_time(ret, T_need) + else: # "still" + ret = _pad_or_trim_last(ret, T_need) + + action_list[executor][..., start_idx:end_idx] = ret + in_working[executor] = True + + # Main logic + scope = self.conf["scope"] + total_num = packages["total_num"] + all_executors = list(scope.keys()) + edges_flatten = { + k: v + for edges in self.conf["edge"].values() + for edge in edges + for k, v in edge.items() + } + node_linkers = nx.get_edge_attributes(graph_compose, "linker") + + action_list, end_time, in_working = initialize_action_list(scope, total_num) + + if self.warpping: + self._prepare_warpping(env) + + if not generate_nodes(graph_compose, self.conf["node"]): + return None + + # After node initialization, check if env.affordance_datas contains updated initial value for each executor. + for executor in scope.keys(): + init_node_name = get_init_affordance(executor) + if ( + not hasattr(env, "affordance_datas") + or init_node_name not in env.affordance_datas + ): + log_warning( + f"Executor '{executor}': init_node_name '{init_node_name}' not found in env.affordance_datas. Skipping initial value update." + ) + continue + affordance_init = env.affordance_datas[init_node_name] + affordance_init = np.asarray(affordance_init) + action_init_slice = action_list[executor][:, 0] + if affordance_init.shape != action_init_slice.shape: + log_warning( + f"Executor '{executor}': affordance_init shape {affordance_init.shape} does not match action_list[executor][:, 0] shape {action_init_slice.shape}. Skipping initial value update." + ) + continue + if not np.allclose(action_init_slice, affordance_init): + log_info( + f"Updated initial value for executor '{executor}' in action_list from affordance_datas['{init_node_name}']." + ) + action_list[executor][:, 0] = affordance_init + + generate_edges(total_num, all_executors, edges_flatten, node_linkers) + + return action_list + + +def attach_node_and_edge( + cls: ActionBank, functions_dict: Dict[str, Dict[str, Callable]] +) -> ActionBank: + for tag, funcs in functions_dict.items(): + tag_function = get_func_tag(tag) + for func_name, func in funcs.items(): + setattr(cls, func_name, staticmethod(func)) + + class_name = cls.__name__ + if class_name in tag_function.functions.keys(): + tag_function.functions[class_name].update({func_name: func}) + else: + tag_function.functions.update({class_name: {func_name: func}}) + return cls + + +def attach_action_bank(cls, action_bank: ActionBank, **kwargs): + def set_attr_for_cls(cls, attr_name: str, attr_value: Any): + if hasattr(cls, attr_name): + getattr(cls, attr_name).append(attr_value) + else: + setattr(cls, attr_name, [attr_value]) + + action_config = kwargs.get("action_config", None) + if action_config is None: + log_error( + f"The action config is None, but it's needed for Env: {type(cls).__name__}, Task Type: {cls.metadata['task_type']}." + ) + set_attr_for_cls(cls, "action_banks", action_bank(action_config)) + + vis_graph = kwargs.get("vis", False) + graph_compose, jobs_data, jobkey2index = cls.action_banks[-1].parse_network( + get_func_tag("node").functions[cls.action_banks[-1].__class__.__name__], + get_func_tag("edge").functions[cls.action_banks[-1].__class__.__name__], + vis_graph=vis_graph, + ) + + vis_gantt = kwargs.get("vis", False) + package = cls.action_banks[-1].gantt(jobs_data, jobkey2index, vis=vis_gantt) + + set_attr_for_cls(cls, "packages", package) + set_attr_for_cls(cls, "graph_composes", graph_compose) + + return cls + + +def get_xpos_name(affordance_name: str) -> str: + if affordance_name.find("qpos") == -1: + affordance_xpos_name = affordance_name + "_xpos" + else: + affordance_xpos_name = affordance_name.replace("qpos", "xpos") + return affordance_xpos_name + + +def get_control_part(env, agent_uid): + control_parts = env.metadata["dataset"]["robot_meta"].get("control_parts", []) + + if agent_uid in control_parts: + return agent_uid + else: + return _data_key_to_control_part( + robot=env.robot, + control_parts=control_parts, + data_key=agent_uid, + ) + + +def generate_trajectory_qpos( + env, + agent_uid: str, + trajectory: Dict[str, np.ndarray], + trajectory_id: str, + gather_index: List[int], + trajectory_index: int, + affordance_name: str, + slaver: str = "", + canonical_trajectory: List[float] = None, + canonical_trajectory_index: int = None, + canonical_pose: List[float] = [], + vis: bool = False, +) -> bool: + affordance_xpos_name = get_xpos_name(affordance_name) + + current_qpos = torch.as_tensor(trajectory[trajectory_id])[trajectory_index][ + None, gather_index + ] # TODO: only for 1 env + affordance_xpos = env.robot.compute_fk( + torch.as_tensor(current_qpos), + get_control_part(env, agent_uid), + to_matrix=True, + ) + if slaver != "": + assert canonical_trajectory is not None + assert canonical_trajectory_index is not None + assert ( + len(canonical_pose) == 4 + ), f"canonical_pose should be a 4x4 matrix, but got {len(canonical_pose)} elements." + canonical_pose = torch.as_tensor( + canonical_pose, + device=affordance_xpos.device, + dtype=affordance_xpos.dtype, + ).reshape(1, 4, 4) + can_affordance_xpos = env.robot.compute_fk( + torch.as_tensor(canonical_trajectory)[canonical_trajectory_index][ + gather_index + ], + get_control_part(env, agent_uid), + to_matrix=True, + ) + can_obj_xpos = canonical_pose + obj_xpos = env.sim.get_asset(slaver).get_local_pose(to_matrix=True) + affordance_xpos = torch.bmm( + obj_xpos, torch.bmm(pose_inv(can_obj_xpos), can_affordance_xpos) + ) + control_part = get_control_part(env, agent_uid) + qpos_seed = env.robot.get_qpos()[:, env.robot.get_joint_ids(name=control_part)] + ret, current_qpos = env.robot.compute_ik( + affordance_xpos, qpos_seed, control_part + ) + ret = ret.all().item() + if not ret: + log_warning( + f"IK failed for slaver {slaver} with xpos {affordance_xpos}. Using the previous qpos instead." + ) + return False + + if vis: + env.sim.draw_marker( + cfg=MarkerCfg( + marker_type="axis", + axis_xpos=affordance_xpos, + axis_size=0.002, + axis_len=0.005, + ) + ) + # TODO: only support 1 env numpy now + current_qpos = current_qpos.squeeze(0).cpu().numpy() + affordance_xpos = affordance_xpos.squeeze(0).cpu().numpy() + + env.affordance_datas[affordance_name] = current_qpos + env.affordance_datas[affordance_xpos_name] = affordance_xpos + return True + + +def modify_action_config_edges( + action_config: Dict, + duration_updates: Dict[str, int] = None, + trajectory_updates: Dict[str, List] = None, + analytic_planner: bool = False, +) -> Dict: + """ + Modify the action configuration by updating the duration and trajectory of edges. + + This function iterates through all edges in the action configuration and applies updates to their + duration and trajectory based on the provided mappings. If `analytic_planner` is enabled, the edge + name is set to "plan_trajectory". + + Args: + action_config (Dict): The original action configuration. + duration_updates (Dict[str, int], optional): A mapping of edge names to their new durations. + trajectory_updates (Dict[str, List], optional): A mapping of edge names to their new trajectories. + analytic_planner (bool, optional): If True, sets the edge name to "plan_trajectory". Defaults to False. + + Returns: + Dict: The modified action configuration. + """ + modified_config = deepcopy(action_config) + + # Iterate through all scopes in the action configuration + for scope_name, scope_edges in modified_config["edge"].items(): + for edge_config in scope_edges: + edge_name = list(edge_config.keys())[0] + edge_data = edge_config[edge_name] + # If analytic_planner is enabled, set the edge name to "plan_trajectory" + if analytic_planner: + edge_data["name"] = "plan_trajectory" + + # Update the duration if a mapping is provided + if duration_updates and edge_name in duration_updates: + edge_data["duration"] = duration_updates[edge_name] + + # Update the trajectory if a mapping is provided + if trajectory_updates and edge_name in trajectory_updates: + edge_data.setdefault("kwargs", {}) # Ensure "kwargs" exists + edge_data["kwargs"]["trajectory"] = trajectory_updates[edge_name] + + return modified_config + + +def to_affordance_name(name: str) -> str: + return name.replace("generate_", "") + + +def to_affordance_node_func(name: str) -> str: + return "generate_" + name + + +class GeneralActionBank(ActionBank): + @staticmethod + @tag_edge + def load_trajectory( + env, + trajectory_id: str, + gather_index: List[int], + keypose_timesteps: Tuple[int, int], + raw_duration: int, + duration: int, + **kwargs, + ): + from scipy import interpolate + + f = {} + start_t, end_t = keypose_timesteps[0], keypose_timesteps[1] + trajectory = np.asarray(env.trajectory[trajectory_id])[:, gather_index] + sub_trajectory = trajectory[start_t:end_t, :] + ds_sub_trajectory = np.zeros((duration, sub_trajectory.shape[1])) + for i in range(sub_trajectory.shape[1]): + x = np.arange(sub_trajectory.shape[0]) + f[i] = interpolate.interp1d(x, sub_trajectory[:, i], axis=-1) + ds_sub_trajectory[:, i] = f[i](np.linspace(0, raw_duration - 1, duration)) + + return ds_sub_trajectory.T # (D, T) + + @staticmethod + @tag_edge + def mimic_trajectory( + env, + agent_uid: str, + raw_edge: Dict, + raw_affordance: Dict, + target_edge: Dict, + vis: bool = False, + **kwargs, + ): + + GeneralActionBank.generate_trajectory_qpos = generate_trajectory_qpos + if isinstance(raw_affordance, dict): + aff_kwargs = deepcopy(raw_affordance.get("kwargs", {})) + aff_kwargs.pop("trajectory", {}) + aff_kwargs.pop("canonical_trajectory", {}) + getattr(GeneralActionBank, raw_affordance["name"])( + env, + **aff_kwargs, + trajectory=env.trajectory, + canonical_trajectory=env.canonical_trajectory, + ) + xpos = env.affordance_datas[ + get_xpos_name(to_affordance_name(raw_affordance["name"])) + ] + else: + log_warning( + f"raw_affordance is not a dict, but {type(raw_affordance)} and {raw_affordance}. Using it as a string name directly." + ) + xpos = env.affordance_datas[get_xpos_name(raw_affordance)] + + # raw_trajectory = getattr(GeneralActionBank, raw_edge["name"])( + # env, **raw_edge.get("kwargs", {}) + # ) + # base_pose = env.agent.get_base_xpos(agent_uid) + + # import time + # for t, temp in enumerate([env.agent.get_fk(raw_trajectory[:, i], uid=agent_uid) + # for i in range(raw_trajectory.shape[1]) + # ]): + # if t % 10 ==0: + + # print(temp) + # env.scene.draw_marker(cfg=MarkerCfg( + # marker_type="axis", + # axis_xpos=env.agent.get_base_xpos(agent_uid) @ temp, + # axis_size=0.002, + # axis_len=0.005 + # )) + # time.sleep(0.01) + + # trans = np.linalg.inv(base_pose) @ new_xpos @ np.linalg.inv(xpos) @ base_pose + # ref_poses = [ + # trans @ env.agent.get_fk(raw_trajectory[:, i], uid=agent_uid) + # for i in range(raw_trajectory.shape[1]) + # ] + # if vis: + # env.scene.draw_marker(cfg=MarkerCfg( + # marker_type="axis", + # axis_xpos=xpos, + # axis_size=0.002, + # axis_len=0.005 + # )) + + # for t, temp in enumerate(ref_poses): + # print(temp) + # if t % 10 ==0: + # env.scene.draw_marker(cfg=MarkerCfg( + # marker_type="axis", + # axis_xpos=env.agent.get_base_xpos(agent_uid) @ temp, + # axis_size=0.002, + # axis_len=0.005 + # )) + # time.sleep(0.01) + + ref_poses = [] + target_edge["name"] = "plan_trajectory" + return getattr(GeneralActionBank, target_edge["name"])( + env, + ref_poses=ref_poses, + duration=target_edge["duration"], + vis=vis, + **target_edge.get("kwargs", {}), + ) + + @staticmethod + @tag_edge + def plan_trajectory( + env, + agent_uid: str, + keypose_names: List[str], + duration: int, + ref_poses: List[np.ndarray] = [], + vis: bool = False, + **kwargs, + ) -> np.ndarray: + from embodichain.lab.gym.motion_generation.action.arm_action import ( + ArmAction, + ) + + # Retrieve the start and end positions + start_qpos = env.affordance_datas[keypose_names[0]] + + control_part = get_control_part(env, agent_uid) + start_qpos = torch.as_tensor(env.affordance_datas[keypose_names[0]])[None] + start_xpos = torch.bmm( + env.robot.get_control_part_base_pose(control_part, to_matrix=True), + env.robot.compute_fk(start_qpos, control_part, to_matrix=True), + ) + + end_qpos = torch.as_tensor(env.affordance_datas[keypose_names[-1]]) + end_xpos = torch.bmm( + env.robot.get_control_part_base_pose(control_part, to_matrix=True), + env.robot.compute_fk(end_qpos, control_part, to_matrix=True), + ) + + # TODO: only 1 env + start_qpos = start_qpos.squeeze(0).cpu().numpy() + start_xpos = start_xpos.squeeze(0).cpu().numpy() + end_qpos = end_qpos.squeeze(0).cpu().numpy() + end_xpos = end_xpos.squeeze(0).cpu().numpy() + + if vis: + env.sim.draw_marker( + cfg=MarkerCfg( + marker_type="axis", + axis_xpos=start_xpos, + axis_size=0.002, + axis_len=0.005, + ) + ) + + env.sim.draw_marker( + cfg=MarkerCfg( + marker_type="axis", + axis_xpos=end_xpos, + axis_size=0.002, + axis_len=0.005, + ) + ) + + filtered_keyposes = [start_qpos, end_qpos] + if "eef" in agent_uid: + filtered_keyposes = [start_qpos] + + if len(filtered_keyposes) == 1 and len(ref_poses) == 0: + # 只有一个点,返回静止轨迹 + ret = np.array([filtered_keyposes[0]] * duration) + else: + # 生成轨迹 + if len(ref_poses) == 0: + ret, _ = ArmAction.create_discrete_trajectory( + agent=env.robot, + uid=get_control_part(env, agent_uid), + qpos_list=filtered_keyposes, + sample_num=duration, + qpos_seed=filtered_keyposes[0], + is_use_current_qpos=False, + **getattr(env, "planning_config", {}), + ) + else: + ret, _ = ArmAction.create_discrete_trajectory( + agent=env.robot, + uid=get_control_part(env, agent_uid), + xpos_list=[start_xpos] + ref_poses + [end_xpos], + sample_num=duration, + is_use_current_qpos=False, + **getattr(env, "planning_config", {}), + ) + if isinstance(ret, list): + print(ret) + return ret.T + + @staticmethod + @tag_edge + def execute_open(env, **kwargs): + from embodichain.lab.gym.utils.misc import ( + mul_linear_expand, + ) + + duration = kwargs.get("duration", 1) + expand = kwargs.get("expand", True) + if expand: + action = mul_linear_expand(np.array([[1.0], [0.0]]), [duration - 1]) + action = np.concatenate([action, np.array([[0.0]])]).transpose() + else: + action = np.zeros((1, duration)) + return action + + @staticmethod + @tag_edge + def execute_close(env, **kwargs): + from embodichain.lab.gym.utils.misc import ( + mul_linear_expand, + ) + + duration = kwargs.get("duration", 1) + expand = kwargs.get("expand", True) + if expand: + action = mul_linear_expand(np.array([[0.0], [1.0]]), [duration - 1]) + action = np.concatenate([action, np.array([[1.0]])]).transpose() + else: + action = np.ones((1, duration)) + return action + + +class ActionBankMimic: + def __init__(self, action_banks: List[ActionBank], prob: float = 0.5) -> None: + self.action_banks = action_banks + self.keyword = "mimicable" + self.prob = prob + + def mimic(self, id=None) -> ActionBank: + if len(self.action_banks) == 1: + return self.action_banks[0] + + if id is None: + id = np.random.randint(len(self.action_banks)) + assert id < len( + self.action_banks + ), f"Invalid id {id}, should be less than {len(self.action_banks)}" + + acb = self.action_banks[id] + ret_acb = deepcopy(acb) + node_names = acb.get_node_names(bool_attr_name=self.keyword) + + ret_node_grap2id = ret_acb.graph2id() + edge_need_modify = {} + for scope in node_names.keys(): + edge_need_modify[scope] = [] + # if np.random.random() < self.prob: + # continue + for node in node_names[scope]: + mimic_id = np.random.randint(len(self.action_banks)) + action_bank_mimic_for_this_node = self.action_banks[mimic_id] + mimic_node_names = action_bank_mimic_for_this_node.get_node_names( + bool_attr_name=self.keyword + ) + temp_graph2id = action_bank_mimic_for_this_node.graph2id() + + if node in mimic_node_names[scope]: + ret_acb.conf["node"][scope][ret_node_grap2id[scope][node]][ + node + ] = deepcopy( + action_bank_mimic_for_this_node.conf["node"][scope][ + temp_graph2id[scope][node] + ][node] + ) + + edges = ret_acb.get_edge_names(node_name=node) + for edge in edges[scope]: + edge.update({"mimic_id": mimic_id}) + edge_need_modify[scope].extend(edges[scope]) + else: + log_warning( + f"Node {node} in scope {scope} not found in action bank {mimic_id} [{mimic_node_names[scope]}] to mimic." + ) + edge_need_modify[scope] = { + v["name"]: v for v in edge_need_modify[scope] + }.values() + + raw_edge_grap2id = acb.graph2id("edge") + raw_node_grap2id = acb.graph2id("node") + ret_edge_grap2id = ret_acb.graph2id("edge") + for scope in node_names.keys(): + for edge in edge_need_modify[scope]: + edge_name = edge["name"] + mimic_id = edge.pop("mimic_id") + mimic_grap2id = self.action_banks[mimic_id].graph2id("edge") + ret_acb.conf["edge"][scope][ret_edge_grap2id[scope][edge_name]][ + edge_name + ]["name"] = "mimic_trajectory" + ret_acb.conf["edge"][scope][ret_edge_grap2id[scope][edge_name]][ + edge_name + ]["duration"] = deepcopy( + self.action_banks[mimic_id].conf["edge"][scope][ + mimic_grap2id[scope][edge_name] + ][edge_name]["duration"] + ) + + ret_acb.conf["edge"][scope][ret_edge_grap2id[scope][edge_name]][ + edge_name + ]["kwargs"] = {} + ret_acb.conf["edge"][scope][ret_edge_grap2id[scope][edge_name]][ + edge_name + ]["kwargs"]["agent_uid"] = scope + ret_acb.conf["edge"][scope][ret_edge_grap2id[scope][edge_name]][ + edge_name + ]["kwargs"]["raw_edge"] = deepcopy( + acb.conf["edge"][scope][raw_edge_grap2id[scope][edge_name]][ + edge_name + ] + ) + ret_acb.conf["edge"][scope][ret_edge_grap2id[scope][edge_name]][ + edge_name + ]["kwargs"]["raw_affordance"] = ( + deepcopy( + acb.conf["node"][scope][raw_node_grap2id[scope][node]][node] + ) + if node in raw_node_grap2id[scope] + else node + ) + ret_acb.conf["edge"][scope][ret_edge_grap2id[scope][edge_name]][ + edge_name + ]["kwargs"]["target_edge"] = deepcopy( + self.action_banks[mimic_id].conf["edge"][scope][ + mimic_grap2id[scope][edge_name] + ][edge_name] + ) + + return ret_acb diff --git a/embodichain/lab/gym/envs/action_bank/utils.py b/embodichain/lab/gym/envs/action_bank/utils.py new file mode 100644 index 00000000..52fb7c08 --- /dev/null +++ b/embodichain/lab/gym/envs/action_bank/utils.py @@ -0,0 +1,69 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- +import numpy as np + +from copy import deepcopy +from typing import List, Union, Optional + +from embodichain.utils import logger +from embodichain.lab.gym.utils.misc import validation_with_process_from_name + + +"""Node Generation Utils""" + + +def get_init_affordance(scope: str, tag: str = "init") -> str: + return "{}_{}_qpos".format(scope, tag) + + +def generate_affordance_from_src( + env, + src_key: str, + dst_key: str, + valid_funcs_name_kwargs_proc: Optional[List] = None, + to_array: bool = True, +) -> bool: + """Generate a new affordance entry in env.affordance_datas by applying a validation and processing + pipeline to an existing source affordance. + + Args: + env: The environment object containing affordance data. + src_key (str): The key of the source affordance in env.affordance_datas. + dst_key (str): The key to store the generated affordance in env.affordance_datas. + valid_funcs_name_kwargs_proc (Optional[List]): A list of validation or processing functions (with kwargs) + to apply to the source affordance. Defaults to an empty list. + to_array (bool): Whether to convert the result to a numpy array before storing. Defaults to True. + + Returns: + bool: True if the affordance was successfully generated and stored, False otherwise. + """ + if valid_funcs_name_kwargs_proc is None: + valid_funcs_name_kwargs_proc = [] + try: + result = validation_with_process_from_name( + env, + deepcopy(env.affordance_datas[src_key]), + valid_funcs_name_kwargs_proc, + ) + if result is None: + logger.log_warning(f"Failed to generate {dst_key} from {src_key}") + return False + + env.affordance_datas[dst_key] = np.asarray(result) if to_array else result + return True + except Exception as e: + logger.log_error(f"Affordance generation error: {e}") + return False diff --git a/embodichain/lab/gym/envs/base_env.py b/embodichain/lab/gym/envs/base_env.py new file mode 100644 index 00000000..123c2c04 --- /dev/null +++ b/embodichain/lab/gym/envs/base_env.py @@ -0,0 +1,506 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch +import gymnasium as gym + +from typing import Dict, List, Union, Tuple, Any, Optional, Sequence +from functools import cached_property + +from embodichain.lab.sim.types import EnvObs, EnvAction +from embodichain.lab.sim import SimulationManagerCfg, SimulationManager +from embodichain.lab.sim.objects import Robot +from embodichain.lab.sim.sensors import BaseSensor +from embodichain.lab.gym.utils import gym_utils +from embodichain.utils import configclass +from embodichain.utils import logger, set_seed + +__all__ = ["BaseEnv", "EnvCfg"] + + +@configclass +class EnvCfg: + """Configuration for an Robot Learning Environment.""" + + num_envs: int = 1 + """The number of sub environments (arena in dexsim context) to be simulated in parallel.""" + + sim_cfg: SimulationManagerCfg = SimulationManagerCfg() + """Simulation configuration for the environment.""" + + seed: Optional[int] = None + """The seed for the random number generator. Defaults to -1, in which case the seed is not set. + + Note: + The seed is set at the beginning of the environment initialization. This ensures that the environment + creation is deterministic and behaves similarly across different runs. + """ + + sim_steps_per_control: int = 4 + """Number of simulation steps per control (env) step. + + For instance, if the simulation dt is 0.01s and the control dt is 0.1s, then the `sim_steps_per_control` is 10. + This means that the control action is updated every 10 simulation steps. + """ + + ignore_terminations: bool = False + """Whether to ignore terminations when deciding when to auto reset. Terminations can be caused by + the task reaching a success or fail state as defined in a task's evaluation function. + + If set to False, meaning there is early stop in episode rollouts. + If set to True, this would generally for situations where you may want to model a task as infinite horizon where a task + stops only due to the timelimit. + """ + + +class BaseEnv(gym.Env): + """Base environment for robot learning. + + Args: + cfg (EnvCfg): The environment configuration. + **kwargs: Additional keyword arguments. + """ + + # placeholder contains any meta information about the environment. + metadata: Dict = {} + + # The simulator manager instance. + sim: SimulationManager = None + + # TODO: May be support multiple robots in the future. + # The robot agent instance. + robot: Robot = None + + # The sensors used in the environment. + sensors: Dict[str, BaseSensor] = {} + + # The action space is determined by the robot agent and the task the environment is used for. + action_space: gym.spaces.Space = None + # The observation space is determined by the sensors used in the environment and the task the environment is used for. + observation_space: gym.spaces.Space = None + + single_action_space: gym.spaces.Space = None + single_observation_space: gym.spaces.Space = None + + def __init__( + self, + cfg: EnvCfg, + **kwargs, + ): + self.cfg = cfg + + # the number of envs to be simulated in parallel. + self.num_envs = self.cfg.num_envs + + if self.cfg.sim_cfg is None: + self.sim_cfg = SimulationManagerCfg(headless=True) + else: + self.sim_cfg = self.cfg.sim_cfg + + if self.cfg.seed is not None: + self.cfg.seed = set_seed(self.cfg.seed) + else: + logger.log_info(f"No seed is set for the environment.") + + self.sim_freq = int(1 / self.sim_cfg.physics_dt) + self.control_freq = self.sim_freq // self.cfg.sim_steps_per_control + + self._setup_scene(**kwargs) + + # TODO: To be removed. + if self.device.type == "cuda": + self.sim.init_gpu_physics() + + if not self.sim_cfg.headless: + self.sim.open_window() + + self._elapsed_steps = torch.zeros( + self.num_envs, dtype=torch.int32, device=self.sim_cfg.sim_device + ) + + self._init_sim_state(**kwargs) + + self._init_raw_obs: Dict = self.get_obs(**kwargs) + + logger.log_info("[INFO]: Initialized environment:") + logger.log_info(f"\tEnvironment device : {self.sim.device}") + logger.log_info(f"\tNumber of environments: {self.num_envs}") + logger.log_info(f"\tEnvironment seed : {self.cfg.seed}") + logger.log_info(f"\tPhysics dt : {self.sim_cfg.physics_dt}") + logger.log_info( + f"\tEnvironment dt : {self.sim_cfg.physics_dt * self.cfg.sim_steps_per_control}" + ) + + @property + def device(self) -> torch.Tensor: + """Return the device used by the environment.""" + return self.sim.device + + @cached_property + def single_observation_space(self) -> gym.spaces.Space: + if self.num_envs == 1: + return gym_utils.convert_observation_to_space(self._init_raw_obs) + else: + return gym_utils.convert_observation_to_space( + self._init_raw_obs, unbatched=True + ) + + @cached_property + def observation_space(self) -> gym.spaces.Space: + if self.num_envs == 1: + return self.single_observation_space + else: + return gym.vector.utils.batch_space( + self.single_observation_space, n=self.num_envs + ) + + @cached_property + def action_space(self) -> gym.spaces.Space: + if self.num_envs == 1: + return self.single_action_space + else: + return gym.vector.utils.batch_space( + self.single_action_space, n=self.num_envs + ) + + @property + def elapsed_steps(self) -> Union[int, torch.Tensor]: + return self._elapsed_steps + + def get_sensor(self, name: str, **kwargs) -> BaseSensor: + """Get the sensor instance by name. + + Args: + name: The name of the sensor. + kwargs: Additional keyword arguments. + + Returns: + The sensor instance. + """ + if name not in self.sensors: + logger.log_error( + f"Sensor '{name}' not found in the environment. Available sensors: {list(self.sensors.keys())}" + ) + + return self.sensors[name] + + def _setup_scene(self, **kwargs): + # Init sim manager. + # we want to open gui window when the scene is setup, so init sim manager in headless mode first. + headless = self.sim_cfg.headless + self.sim_cfg.headless = True + self.sim = SimulationManager(self.sim_cfg) + self.sim_cfg.headless = headless + self.sim.set_manual_update(True) + + logger.log_info( + f"Initializing {self.num_envs} environments on {self.sim_cfg.sim_device}." + ) + if self.num_envs > 1: + self.sim.build_multiple_arenas(self.num_envs) + + self.robot = self._setup_robot(**kwargs) + if self.robot is None: + logger.log_error( + f"The robot instance must be initialized in :meth:`_setup_robot` function." + ) + if self.single_action_space is None: + logger.log_error( + f":attr:`single_action_space` must be defined in the :meth:`_setup_robot` function." + ) + + self._prepare_scene(**kwargs) + + self.sensors = self._setup_sensors(**kwargs) + + def _setup_robot(self, **kwargs) -> Robot: + """Load the robot agent, setup the controller and action space. + + Note: + 1. The fuction must return the robot instance. + 2. The self.single_action_space should be defined. + """ + + # TODO: single_action_space may be configured in config? + pass + + def _prepare_scene(self, **kwargs) -> None: + """Prepare the scene assets into the environment. + + This function can be customized to performed different scene creation ways, such as loading from file. + """ + pass + + def _setup_sensors(self, **kwargs) -> Dict[str, BaseSensor]: + """Setup the sensors used in the environment. + + The sensors to be setup could be binding to the robot or the environment. + + Note: + If the function is overridden, it must return a dictionary of sensors with the sensor name as the key + and the sensor instance as the value. + """ + return {} + + def _init_sim_state(self, **kwargs): + """Initialize the simulation state at the beginning of scene creation.""" + pass + + def _update_sim_state(self, **kwargs): + """Update the simulation state at each step. + + The function is called internally by the environment in :meth:`step` after update the physics simulation. + + Note: + Currently, the interface is designed to perform randomization of lighting, textures at each simulation step. + + Args: + **kwargs: Additional keyword arguments to be passed to the :meth:`_update_sim_state` function. + """ + # TODO: Add randomization event here. + pass + + def _initialize_episode(self, env_ids: Optional[Sequence[int]] = None, **kwargs): + """Initialize the simulation assets before each episode. Randomization can be performed at this stage. + + Args: + env_ids: The environment IDs to be initialized. If None, all environments are initialized. + This is useful for vectorized environments to reset only the specified environments. + **kwargs: Additional keyword arguments to be passed to the :meth:`_initialize_episode` function. + """ + pass + + def _get_sensor_obs(self, **kwargs) -> Dict[str, any]: + """Get the sensor observation from the environment. + + Args: + **kwargs: Additional keyword arguments to be passed to the :meth:`_get_sensor_obs` function. + + Returns: + The sensor observation dictionary. + """ + obs = {} + + fetch_only = False + if self.sim.is_rt_enabled: + fetch_only = True + self.sim.render_camera_group() + + for sensor_name, sensor in self.sensors.items(): + sensor.update(fetch_only=fetch_only) + obs[sensor_name] = sensor.get_data() + return obs + + def _extend_obs(self, obs: EnvObs, **kwargs) -> EnvObs: + """Extend the observation dictionary. + + Overwrite this function to extend or modify extra observation to the existing keys (robot, sensor, extra). + + Args: + obs: The observation dictionary. + **kwargs: Additional keyword arguments to be passed to the :meth:`_extend_obs` function. + + Returns: + The extended observation dictionary. + """ + return obs + + def get_obs(self, **kwargs) -> EnvObs: + """Get the observation from the robot agent and the environment. + + The default observation are: + - robot: the robot proprioception. + - sensor (optional): the sensor readings. + - extra (optional): any extra information. + + Note: + If self.num_envs == 1, return the observation in single_observation_space format. + If self.num_envs > 1, return the observation in observation_space format. + + Args: + **kwargs: Additional keyword arguments to be passed to the :meth:`_get_sensor_obs` functions. + + Returns: + The observation dictionary. + """ + obs = None + + obs = dict(robot=self.robot.get_proprioception()) + + sensor_obs = self._get_sensor_obs(**kwargs) + if sensor_obs: + obs["sensor"] = sensor_obs + + obs = self._extend_obs(obs=obs, **kwargs) + + return obs + + def evaluate(self, **kwargs) -> Dict[str, Any]: + """ + Evaluate whether the environment is currently in a success state by returning a dictionary with a "success" key or + a failure state via a "fail" key + + This function may also return additional data that has been computed (e.g. is the robot grasping some object) that may be + reused when generating observations and rewards. + + By default if not overridden, this function returns an empty dictionary + + Args: + **kwargs: Additional keyword arguments to be passed to the :meth:`evaluate` function. + + Returns: + The evaluation dictionary. + """ + return dict() + + def get_info(self, **kwargs) -> Dict[str, Any]: + """Get info about the current environment state, include elapsed steps, success, fail, etc. + + The returned info dictionary must contain at the success and fail status of the current step. + + Args: + **kwargs: Additional keyword arguments to be passed to the :meth:`get_info` function. + + Returns: + The info dictionary. + """ + info = dict(elapsed_steps=self._elapsed_steps) + + info.update(self.evaluate(**kwargs)) + return info + + def check_truncated(self, obs: EnvObs, info: Dict[str, Any]) -> bool: + """Check if the episode is truncated. + + Args: + obs: The observation from the environment. + info: The info dictionary. + + Returns: + True if the episode is truncated, False otherwise. + """ + return torch.zeros(self.num_envs, dtype=torch.bool, device=self.device) + + def get_reward( + self, + obs: EnvObs, + action: EnvAction, + info: Dict[str, Any], + ) -> float: + """Get the reward for the current step. + + Each SimulationManager env must implement its own get_reward function to define the reward function for the task, If the + env is considered for RL/IL training. + + Args: + obs: The observation from the environment. + action: The action applied to the robot agent. + info: The info dictionary. + + Returns: + The reward for the current step. + """ + + return torch.zeros(self.num_envs, dtype=torch.float32, device=self.device) + + def _step_action(self, action: EnvAction) -> EnvAction: + """Set action control command into simulation. + + Args: + action: The action applied to the robot agent. + + Returns: + The action return. + """ + pass + + def reset( + self, seed: Optional[int] = None, options: Optional[Dict] = None + ) -> Tuple[EnvObs, Dict]: + """Reset the SimulationManager environment and return the observation and info. + + Args: + seed: The seed for the random number generator. Defaults to None, in which case the seed is not set. + options: Additional options for resetting the environment. This can include: + + Returns: + A tuple containing the observations and infos. + """ + if seed is not None: + torch.manual_seed(seed) + + if options is None: + options = dict() + + reset_ids = options.get( + "reset_ids", + torch.arange(self.num_envs, dtype=torch.int32, device=self.device), + ) + self.sim.reset_objects_state(env_ids=reset_ids) + self._elapsed_steps[reset_ids] = 0 + + # Reset hook for user to perform any custom reset logic. + self._initialize_episode(reset_ids, **options) + + return self.get_obs(**options), self.get_info(**options) + + def step( + self, action: EnvAction, **kwargs + ) -> Tuple[EnvObs, torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, Any]]: + """Step the environment with the given action. + + Args: + action: The action applied to the robot agent. + + Returns: + A tuple contraining the observation, reward, terminated, truncated, and info dictionary. + """ + self._elapsed_steps += 1 + + # TODO: may be add hook for action preprocessing. + action = self._step_action(action=action) + self.sim.update(self.sim_cfg.physics_dt, self.cfg.sim_steps_per_control) + self._update_sim_state(**kwargs) + + obs = self.get_obs(**kwargs) + info = self.get_info(**kwargs) + rewards = self.get_reward(obs=obs, action=action, info=info) + + terminateds = torch.logical_or( + info.get( + "success", + torch.zeros(self.num_envs, dtype=torch.bool, device=self.device), + ), + info.get( + "fail", torch.zeros(self.num_envs, dtype=torch.bool, device=self.device) + ), + ) + truncateds = self.check_truncated(obs=obs, info=info) + if self.cfg.ignore_terminations: + terminateds[:] = False + + dones = torch.logical_or(terminateds, truncateds) + reset_env_ids = dones.nonzero(as_tuple=False).squeeze(-1) + if len(reset_env_ids) > 0: + obs, _ = self.reset(options={"reset_ids": reset_env_ids}) + + # TODO: may be add hook for observation postprocessing. + + return obs, rewards, terminateds, truncateds, info + + def close(self) -> None: + """Close the environment and release resources.""" + self.sim.destroy() diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py new file mode 100644 index 00000000..4dcd2412 --- /dev/null +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -0,0 +1,506 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os +import torch +import numpy as np +import gymnasium as gym + +from dataclasses import MISSING +from typing import Dict, Union, Optional, Sequence, Tuple, Any, List + +from embodichain.lab.sim.cfg import ( + RobotCfg, + RigidObjectCfg, + RigidObjectGroupCfg, + ArticulationCfg, + LightCfg, +) +from embodichain.lab.gym.envs.action_bank.configurable_action import ( + get_func_tag, +) +from embodichain.lab.gym.envs.action_bank.configurable_action import ( + ActionBank, +) +from embodichain.lab.sim.objects import Robot +from embodichain.lab.sim.sensors import BaseSensor, SensorCfg +from embodichain.lab.sim.types import EnvObs, EnvAction +from embodichain.lab.gym.envs import BaseEnv, EnvCfg +from embodichain.lab.gym.envs.managers import ( + EventManager, + ObservationManager, +) +from embodichain.lab.gym.utils.registration import register_env +from embodichain.utils import configclass, logger + + +__all__ = ["EmbodiedEnvCfg", "EmbodiedEnv"] + + +@configclass +class EmbodiedEnvCfg(EnvCfg): + """Configuration class for the Embodied Environment. Inherits from EnvCfg and can be extended + with additional parameters if needed. + """ + + @configclass + class EnvLightCfg: + direct: List[LightCfg] = [] + + # TODO: support more types of indirect light in the future. + # indirect: Dict[str, Any] | None = None + + robot: RobotCfg = MISSING + + sensor: List[SensorCfg] = [] + + light: EnvLightCfg = EnvLightCfg() + + background: List[RigidObjectCfg] = [] + + rigid_object: List[RigidObjectCfg] = [] + + rigid_object_group: List[RigidObjectGroupCfg] = [] + + articulation: List[ArticulationCfg] = [] + + events: Union[object, None] = None + """Event settings. Defaults to None, in which case no events are applied through the event manager. + + Please refer to the :class:`embodichain.lab.gym.managers.EventManager` class for more details. + """ + + observations: Union[object, None] = None + """Observation settings. Defaults to None, in which case no additional observations are applied through + the observation manager. + + Please refer to the :class:`embodichain.lab.gym.managers.ObservationManager` class for more details. + """ + + # TODO: This would be changed to a more generic data pipeline configuration. + dataset: Union[Dict[str, Any], None] = None + """Data pipeline configuration. Defaults to None. + """ + + # Some helper attributes + filter_visual_rand: bool = False + """Whether to filter out visual randomization + + This is useful when we want to disable visual randomization for debug motion and physics issues. + """ + + +@register_env("EmbodiedEnv-v1") +class EmbodiedEnv(BaseEnv): + """Embodied AI environment that is used to simulate the Embodied AI tasks. + + Core simulation components for Embodied AI environments. + - sensor: The sensors used to perceive the environment, which could be attached to the agent or the environment. + - robot: The robot which will be used to interact with the environment. + - light: The lights in the environment, which could be used to illuminate the environment. + - indirect: the indirect light sources, such as ambient light, IBL, etc. + The indirect light sources are used for global illumination which affects the entire scene. + - direct: The direct light sources, such as point light, spot light, etc. + The direct light sources are used for local illumination which mainly affects the arena in the scene. + - background: Kinematic or Static rigid objects, such as obstacles or landmarks. + - rigid_object: Dynamic objects that can be interacted with. + - rigid_object_group: Groups of rigid objects that can be interacted with. + - deformable_object(TODO: supported in the future): Deformable volumes or surfaces (cloth) that can be interacted with. + - articulation: Articulated objects that can be manipulated, such as doors, drawers, etc. + - event manager: The event manager is used to manage the events in the environment, such as randomization, + perturbation, etc. + - observation manager: The observation manager is used to manage the observations in the environment, + such as depth, segmentation, etc. + - action bank: The action bank is used to manage the actions in the environment, such as action composition, action graph, etc. + - affordance_datas: The affordance data that can be used to store the intermediate results or information + """ + + def __init__(self, cfg: EmbodiedEnvCfg, **kwargs): + self.affordance_datas = {} + self.action_bank = None + + extensions = getattr(cfg, "extensions", {}) or {} + + defaults = { + "obs_mode": "state", + "episode_length": 50, + "joint_limits": 0.5, + "action_scale": 0.1, + } + + for name, default in defaults.items(): + value = extensions.get(name, getattr(cfg, name, default)) + setattr(cfg, name, value) + setattr(self, name, getattr(cfg, name)) + + super().__init__(cfg, **kwargs) + + def _init_sim_state(self, **kwargs): + """Initialize the simulation state at the beginning of scene creation.""" + + self._apply_functor_filter() + + # create event manager + self.cfg: EmbodiedEnvCfg + if self.cfg.events: + self.event_manager = EventManager(self.cfg.events, self) + + # perform events at the start of the simulation + if "startup" in self.event_manager.available_modes: + self.event_manager.apply(mode="startup") + + if self.cfg.observations: + self.observation_manager = ObservationManager(self.cfg.observations, self) + + # TODO: A workaround for handling dataset saving, which need history data of obs-action pairs. + # We may improve this by implementing a data manager to handle data saving and online streaming. + if self.cfg.dataset is not None: + self.metadata["dataset"] = self.cfg.dataset + self.episode_obs_list = [] + self.episode_action_list = [] + + self.curr_episode = 0 + + def _apply_functor_filter(self) -> None: + """Apply functor filters to the environment components based on configuration. + + This method is used to filter out certain components of the environment, such as visual randomization, + based on the configuration settings. For example, if `filter_visual_rand` is set to True in the configuration, + all visual randomization functors will be removed from the event manager. + """ + from embodichain.utils.module_utils import get_all_exported_items_from_module + from embodichain.lab.gym.envs.managers.cfg import EventCfg + + functors_to_remove = get_all_exported_items_from_module( + "embodichain.lab.gym.envs.managers.randomization.rendering" + ) + if self.cfg.filter_visual_rand and self.cfg.events: + # Iterate through all attributes of the events object + for attr_name in dir(self.cfg.events): + attr = getattr(self.cfg.events, attr_name) + if isinstance(attr, EventCfg): + if attr.func.__name__ in functors_to_remove: + logger.log_info( + f"Filtering out visual randomization functor: {attr.func.__name__}" + ) + setattr(self.cfg.events, attr_name, None) + + def _init_action_bank( + self, action_bank_cls: ActionBank, action_config: Dict[str, Any] + ): + """ + Initialize action bank and parse action graph structure. + + Args: + action_bank_cls: The ActionBank class for this environment. + action_config: The configuration dict for the action bank. + """ + self.action_bank = action_bank_cls(action_config) + misc_cfg = action_config.get("misc", {}) + try: + this_class_name = self.action_bank.__class__.__name__ + node_func = {} + edge_func = {} + for class_name in [this_class_name, ActionBank.__name__]: + node_func.update(get_func_tag("node").functions.get(class_name, {})) + edge_func.update(get_func_tag("edge").functions.get(class_name, {})) + except KeyError as e: + raise KeyError( + f"Function tag for {e} not found in action bank function registry." + ) + + self.graph_compose, jobs_data, jobkey2index = self.action_bank.parse_network( + node_functions=node_func, edge_functions=edge_func, vis_graph=False + ) + self.packages = self.action_bank.gantt( + tasks_data=jobs_data, taskkey2index=jobkey2index, vis=False + ) + + def set_affordance(self, key: str, value: Any): + """ + Set an affordance value by key. + + Args: + key (str): The affordance key. + value (Any): The affordance value. + """ + self.affordance_datas[key] = value + + def get_affordance(self, key: str, default: Any = None): + """ + Get an affordance value by key. + + Args: + key (str): The affordance key. + default (Any, optional): Default value if key not found. + + Returns: + Any: The affordance value or default. + """ + return self.affordance_datas.get(key, default) + + def reset( + self, seed: Optional[int] = None, options: Optional[Dict] = None + ) -> Tuple[EnvObs, Dict]: + obs, info = super().reset(seed=seed, options=options) + + if hasattr(self, "episode_obs_list"): + self.episode_obs_list = [obs] + self.episode_action_list = [] + + return obs, info + + def step( + self, action: EnvAction, **kwargs + ) -> Tuple[EnvObs, torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, Any]]: + # TODO: Maybe add action preprocessing manager and its functors. + obs, reward, done, truncated, info = super().step(action, **kwargs) + + if hasattr(self, "episode_action_list"): + + self.episode_obs_list.append(obs) + self.episode_action_list.append(action) + + return obs, reward, done, truncated, info + + def _extend_obs(self, obs: EnvObs, **kwargs) -> EnvObs: + if self.observation_manager: + obs = self.observation_manager.compute(obs) + return obs + + def _prepare_scene(self, **kwargs) -> None: + self._setup_lights() + self._setup_background() + self._setup_interactive_objects() + + def _update_sim_state(self, **kwargs) -> None: + """Perform the simulation step and apply events if configured. + + The events manager applies its functors after physics simulation and rendering, + and before the observation and reward computation (if applicable). + """ + if self.cfg.events: + if "interval" in self.event_manager.available_modes: + self.event_manager.apply(mode="interval") + + def _initialize_episode( + self, env_ids: Optional[Sequence[int]] = None, **kwargs + ) -> None: + # apply events such as randomization for environments that need a reset + if self.cfg.events: + if "reset" in self.event_manager.available_modes: + self.event_manager.apply(mode="reset", env_ids=env_ids) + + def _step_action(self, action: EnvAction) -> EnvAction: + """Set action control command into simulation. + + Args: + action: The action applied to the robot agent. + + Returns: + The action return. + """ + # TODO: Support data structure action input such as struct. + qpos = action + + self.robot.set_qpos(qpos=qpos) + + return action + + def _setup_robot(self, **kwargs) -> Robot: + """Setup the robot in the environment. + + Currently, only joint position control is supported. Would be extended to support joint velocity and torque + control in the future. + + Returns: + Robot: The robot instance added to the scene. + """ + if self.cfg.robot is None: + logger.error("Robot configuration is not provided.") + + # Initialize the robot based on the configuration. + robot: Robot = self.sim.add_robot(self.cfg.robot) + + robot.build_pk_serial_chain() + + # TODO: we may need control parts to group actual controlled joints ids. + # In this way, the action pass to env should be a dict or struct to store the + # joint ids as well. + qpos_limits = robot.body_data.qpos_limits[0].cpu().numpy() + self.single_action_space = gym.spaces.Box( + low=qpos_limits[:, 0], high=qpos_limits[:, 1], dtype=np.float32 + ) + return robot + + def _setup_sensors(self, **kwargs) -> Dict[str, BaseSensor]: + """Setup the sensors in the environment. + + Returns: + Dict[str, BaseSensor]: A dictionary mapping sensor UIDs to sensor instances. + """ + + # TODO: support sensor attachment to the robot. + + sensors = {} + for cfg in self.cfg.sensor: + sensor = self.sim.add_sensor(cfg) + sensors[cfg.uid] = sensor + return sensors + + def _setup_lights(self) -> None: + """Setup the lights in the environment.""" + for cfg in self.cfg.light.direct: + self.sim.add_light(cfg=cfg) + + def _setup_background(self) -> None: + """Setup the static rigid objects in the environment.""" + for cfg in self.cfg.background: + if cfg.body_type == "dynamic": + logger.log_error( + f"Background object must be kinematic or static rigid object." + ) + self.sim.add_rigid_object(cfg=cfg) + + def _setup_interactive_objects(self) -> None: + """Setup the interactive objects in the environment.""" + + for cfg in self.cfg.articulation: + self.sim.add_articulation(cfg=cfg) + + for cfg in self.cfg.rigid_object: + if cfg.body_type != "dynamic": + logger.log_error( + f"Interactive rigid object must be dynamic rigid object." + ) + self.sim.add_rigid_object(cfg=cfg) + + for cfg in self.cfg.rigid_object_group: + self.sim.add_rigid_object_group(cfg=cfg) + + def preview_sensor_data( + self, name: str, data_type: str = "color", env_ids: int = 0, method: str = "plt" + ) -> None: + """Preview the sensor data by matplotlib + + Note: + Currently only support RGB image preview. + + Args: + name (str): name of the sensor to preview. + data_type (str): type of the sensor data to preview. + env_ids (int): index of the arena to preview. Defaults to 0. + method (str): method to preview the sensor data. Currently support "plt" and "cv2". Defaults to "plt". + """ + # TODO: this function need to be improved to support more sensor types and data types. + + sensor = self.get_sensor(name=name) + + if data_type not in sensor.SUPPORTED_DATA_TYPES: + logger.error( + f"Data type '{data_type}' not supported by sensor '{name}'. Supported types: {sensor.SUPPORTED_DATA_TYPES}" + ) + + sensor.update() + + data = sensor.get_data() + + # TODO: maybe put the preview (visualization) method to the sensor class. + if sensor.cfg.sensor_type == "StereoCamera": + view = data[data_type][env_ids].cpu().numpy() + view_right = data[f"{data_type}_right"][env_ids].cpu().numpy() + view = np.concatenate((view, view_right), axis=1) + else: + view = data[data_type][env_ids].cpu().numpy() + + if method == "cv2": + import cv2 + + cv2.imshow( + f"sensor_data_{data_type}", cv2.cvtColor(view, cv2.COLOR_RGB2BGR) + ) + cv2.waitKey(0) + elif method == "plt": + from matplotlib import pyplot as plt + + plt.imshow(view) + plt.savefig(f"sensor_data_{data_type}.png") + + def create_demo_action_list(self, *args, **kwargs) -> Optional[Sequence[EnvAction]]: + """Create a demonstration action list for the environment. + + This function should be implemented in subclasses to generate a sequence of actions + that demonstrate a specific task or behavior within the environment. + + Returns: + Optional[Sequence[EnvAction]]: A list of actions if a demonstration is available, otherwise None. + """ + raise NotImplementedError( + "The method 'create_demo_action_list' must be implemented in subclasses." + ) + + def to_dataset( + self, id: str, save_path: str = None, folder_name: str = None + ) -> Optional[str]: + """ + Convert the recorded episode data to a dataset format and save to disk. + + Args: + id (str): Unique identifier for the dataset. + save_path (str, optional): Path to save the dataset. If None, use config or default. + folder_name (str, optional): Folder name for saving. If None, auto-generate. + + Returns: + Optional[str]: The path to the saved dataset, or None if failed. + """ + # TODO: To be refactor data pipeline into more modularized and extendable way. + from embodichain.data.data_engine.data_dict_extractor import ( + fetch_imitation_dataset, + ) + from embodichain.lab.gym.utils.misc import camel_to_snake + + save_path = self.cfg.dataset.get("save_path", None) + if save_path is None: + from embodichain.data import database_demo_dir + + save_path = database_demo_dir + + if self.curr_episode == 0: + self.folder_name = f"{camel_to_snake(self.__class__.__name__)}_{camel_to_snake(self.robot.cfg.uid)}" + if os.path.exists(os.path.join(save_path, self.folder_name)): + self.folder_name = f"{self.folder_name}_{np.random.randint(0, 1000)}" + + dataset_path = fetch_imitation_dataset( + self, + self.episode_obs_list[:-1], + self.episode_action_list, + id, + self.folder_name, + ) + return dataset_path + + def is_task_success(self, **kwargs) -> torch.Tensor: + """Determine if the task is successfully completed. This is mainly used in the data generation process + of the imitation learning. + + Args: + **kwargs: Additional arguments for task-specific success criteria. + + Returns: + torch.Tensor: A boolean tensor indicating success for each environment in the batch. + """ + + return torch.ones(self.num_envs, dtype=torch.bool, device=self.device) diff --git a/embodichain/lab/gym/envs/managers/__init__.py b/embodichain/lab/gym/envs/managers/__init__.py new file mode 100644 index 00000000..946165a8 --- /dev/null +++ b/embodichain/lab/gym/envs/managers/__init__.py @@ -0,0 +1,20 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from .cfg import FunctorCfg, SceneEntityCfg, EventCfg, ObservationCfg +from .manager_base import Functor, ManagerBase +from .event_manager import EventManager +from .observation_manager import ObservationManager diff --git a/embodichain/lab/gym/envs/managers/cfg.py b/embodichain/lab/gym/envs/managers/cfg.py new file mode 100644 index 00000000..3f5c8da6 --- /dev/null +++ b/embodichain/lab/gym/envs/managers/cfg.py @@ -0,0 +1,311 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import torch + +from collections.abc import Callable +from dataclasses import MISSING +from typing import TYPE_CHECKING, Any, Literal + +from embodichain.lab.sim.objects import Articulation, RigidObject +from embodichain.utils import configclass + +if TYPE_CHECKING: + from embodichain.lab.sim import SimulationManager + from .manager_base import Functor + + # from .recorder_manager import RecorderTerm + + +@configclass +class FunctorCfg: + """Configuration for a functor.""" + + func: Callable | Functor = MISSING + """The function or class to be called for the functor. + + The function must take the environment object as the first argument. + The remaining arguments are specified in the :attr:`params` attribute. + + It also supports `callable classes`_, i.e. classes that implement the :meth:`__call__` + method. In this case, the class should inherit from the :class:`Functor` class + and implement the required methods. + + .. _`callable classes`: https://docs.python.org/3/reference/datamodel.html#object.__call__ + """ + + params: dict[str, Any | SceneEntityCfg] = dict() + """The parameters to be passed to the function as keyword arguments. Defaults to an empty dict. + + .. note:: + If the value is a :class:`SceneEntityCfg` object, the manager will query the scene entity + from the :class:`SimulationManager` and process the entity's joints and bodies as specified + in the :class:`SceneEntityCfg` object. + """ + + +@configclass +class EventCfg(FunctorCfg): + """Configuration for a event functor. + + The event functor is used to trigger events in the environment at specific times or under specific conditions. + The `mode` attribute determines when the functor is applied. + - `startup`: The functor is applied when the environment is started. + - `interval`: The functor is applied at each env step. + - `reset`: The functor is applied when the environment is reset. + """ + + mode: Literal["startup", "interval", "reset"] = "reset" + """The mode in which the event functor is applied. + + Note: + The mode name ``"interval"`` is a special mode that is handled by the + manager Hence, its name is reserved and cannot be used for other modes. + """ + + # TODO: maybe support simulation time-based events (time = step * (physics_dt * sim_steps_per_control)) + interval_step: int = 10 + """The number of environment step after which the functor is applied. Defaults to 4.""" + + is_global: bool = False + """Whether the event should be tracked on a per-environment basis. Defaults to False. + + If True, the same interval step is used for all the environment instances. + If False, the interval step is sampled independently for each environment instance + and the functor is applied when the current step hits the interval step for that instance. + + Note: + This is only used if the mode is ``"interval"``. + """ + + +@configclass +class ObservationCfg(FunctorCfg): + """Configuration for an observation functor. + + The observation functor is used to compute observations for the environment. The `mode` attribute + determines whether the observation is already present in the observation space or not. + """ + + mode: Literal["modify", "add"] = "modify" + """The mode for the observation computation. + + - `modify`: The observation is already present in the observation space, updated the value in-place. + - `add`: The observation is not present in the observation space, add a new entry to the observation space. + """ + + name: str = MISSING + """The name of the observation. + + The name can be a new key to observation space, eg: + - `object_position`: shape of (num_envs, 3) + - `robot/eef_pose`: shape of (num_envs, 7) or (num_envs, 4, 4) + - `sensor/cam_high/mask`: shape of (num_envs, H, W) + or a existing key to modify, eg: + - `robot/qpos`: shape of (num_envs, num_dofs) + `/` is used to separate different levels of hierarchy in the observation dictionary. + """ + + +@configclass +class SceneEntityCfg: + """Configuration for a scene entity that is used by the manager's functor. + + This class is used to specify the name of the scene entity that is queried from the + :class:`SimulationManager` and passed to the manager's functor. + """ + + uid: str = MISSING + """The name of the scene entity. + + This is the name defined in the scene configuration file. See the :class:`SimulationManagerCfg` + class for more details. + """ + + joint_names: str | list[str] | None = None + """The names of the joints from the scene entity. Defaults to None. + + The names can be either joint names or a regular expression matching the joint names. + + These are converted to joint indices on initialization of the manager and passed to the functor + as a list of joint indices under :attr:`joint_ids`. + """ + + joint_ids: list[int] | slice = slice(None) + """The indices of the joints from the asset required by the functor. Defaults to slice(None), which means + all the joints in the asset (if present). + + If :attr:`joint_names` is specified, this is filled in automatically on initialization of the + manager. + """ + + link_names: str | list[str] | None = None + """The names of the links from the asset required by the functor. Defaults to None. + + The names can be either link names or a regular expression matching the link names. + """ + + control_parts: str | list[str] | None = None + """The names of the control parts from the asset(only support for robot) required by the functor. Defaults to None. + """ + + # TODO: Maybe support tendon names and ids in the future. + + body_names: str | list[str] | None = None + """The names of the bodies from the asset required by the functor. Defaults to None. + + The names can be either body names or a regular expression matching the body names. + + These are converted to body indices on initialization of the manager and passed to the functor + function as a list of body indices under :attr:`body_ids`. + """ + + body_ids: list[int] | slice = slice(None) + """The indices of the bodies from the asset required by the functor. Defaults to slice(None), which means + all the bodies in the asset. + + If :attr:`body_names` is specified, this is filled in automatically on initialization of the + manager. + """ + + # TODO: Maybe support object collection (same as IsaacLab definitions). + + preserve_order: bool = False + """Whether to preserve indices ordering to match with that in the specified joint, body, or object collection names. + Defaults to False. + + If False, the ordering of the indices are sorted in ascending order (i.e. the ordering in the entity's joints, + bodies, or object in the object collection). Otherwise, the indices are preserved in the order of the specified + joint, body, or object collection names. + + For more details, see the :meth:`isaaclab.utils.string.resolve_matching_names` function. + + .. note:: + This attribute is only used when :attr:`joint_names`, :attr:`body_names` are specified. + + """ + + def resolve(self, scene: SimulationManager): + """Resolves the scene entity and converts the joint and body names to indices. + + This function examines the scene entity from the :class:`SimulationManager` and resolves the indices + and names of the joints and bodies. It is an expensive operation as it resolves regular expressions + and should be called only once. + + Args: + scene: The interactive scene instance. + + Raises: + ValueError: If the scene entity is not found. + ValueError: If both ``joint_names`` and ``joint_ids`` are specified and are not consistent. + ValueError: If both ``body_names`` and ``body_ids`` are specified and are not consistent. + """ + # check if the entity is valid + asset_uids = scene.asset_uids + if self.uid not in asset_uids: + raise ValueError( + f"The scene entity '{self.uid}' does not exist. Available entities: {asset_uids}." + ) + + # convert joint names to indices based on regex + self._resolve_joint_names(scene) + + # convert body names to indices based on regex + self._resolve_body_names(scene) + + def _resolve_joint_names(self, scene: SimulationManager): + # convert joint names to indices based on regex + if self.joint_names is not None or self.joint_ids != slice(None): + entity: Articulation = scene[self.uid] + # -- if both are not their default values, check if they are valid + if self.joint_names is not None and self.joint_ids != slice(None): + if isinstance(self.joint_names, str): + self.joint_names = [self.joint_names] + if isinstance(self.joint_ids, int): + self.joint_ids = [self.joint_ids] + joint_ids, _ = entity.find_joints( + self.joint_names, preserve_order=self.preserve_order + ) + joint_names = [entity.joint_names[i] for i in self.joint_ids] + if joint_ids != self.joint_ids or joint_names != self.joint_names: + raise ValueError( + "Both 'joint_names' and 'joint_ids' are specified, and are not consistent." + f"\n\tfrom joint names: {self.joint_names} [{joint_ids}]" + f"\n\tfrom joint ids: {joint_names} [{self.joint_ids}]" + "\nHint: Use either 'joint_names' or 'joint_ids' to avoid confusion." + ) + # -- from joint names to joint indices + elif self.joint_names is not None: + if isinstance(self.joint_names, str): + self.joint_names = [self.joint_names] + self.joint_ids, _ = entity.find_joints( + self.joint_names, preserve_order=self.preserve_order + ) + # performance optimization (slice offers faster indexing than list of indices) + # only all joint in the entity order are selected + if ( + len(self.joint_ids) == entity.num_joints + and self.joint_names == entity.joint_names + ): + self.joint_ids = slice(None) + # -- from joint indices to joint names + elif self.joint_ids != slice(None): + if isinstance(self.joint_ids, int): + self.joint_ids = [self.joint_ids] + self.joint_names = [entity.joint_names[i] for i in self.joint_ids] + + def _resolve_body_names(self, scene: SimulationManager): + # convert body names to indices based on regex + if self.body_names is not None or self.body_ids != slice(None): + entity: RigidObject = scene[self.uid] + # -- if both are not their default values, check if they are valid + if self.body_names is not None and self.body_ids != slice(None): + if isinstance(self.body_names, str): + self.body_names = [self.body_names] + if isinstance(self.body_ids, int): + self.body_ids = [self.body_ids] + body_ids, _ = entity.find_bodies( + self.body_names, preserve_order=self.preserve_order + ) + body_names = [entity.body_names[i] for i in self.body_ids] + if body_ids != self.body_ids or body_names != self.body_names: + raise ValueError( + "Both 'body_names' and 'body_ids' are specified, and are not consistent." + f"\n\tfrom body names: {self.body_names} [{body_ids}]" + f"\n\tfrom body ids: {body_names} [{self.body_ids}]" + "\nHint: Use either 'body_names' or 'body_ids' to avoid confusion." + ) + # -- from body names to body indices + elif self.body_names is not None: + if isinstance(self.body_names, str): + self.body_names = [self.body_names] + self.body_ids, _ = entity.find_bodies( + self.body_names, preserve_order=self.preserve_order + ) + # performance optimization (slice offers faster indexing than list of indices) + # only all bodies in the entity order are selected + if ( + len(self.body_ids) == entity.num_bodies + and self.body_names == entity.body_names + ): + self.body_ids = slice(None) + # -- from body indices to body names + elif self.body_ids != slice(None): + if isinstance(self.body_ids, int): + self.body_ids = [self.body_ids] + self.body_names = [entity.body_names[i] for i in self.body_ids] diff --git a/embodichain/lab/gym/envs/managers/event_manager.py b/embodichain/lab/gym/envs/managers/event_manager.py new file mode 100644 index 00000000..fa06b9d8 --- /dev/null +++ b/embodichain/lab/gym/envs/managers/event_manager.py @@ -0,0 +1,353 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. + +# All rights reserved. +# +# This file incorporates code from the Isaac Lab Project +# Copyright (c) 2022-2025, The Isaac Lab Project Developers +# (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause +# ---------------------------------------------------------------------------- + +"""Event manager for orchestrating operations based on different simulation events.""" + +from __future__ import annotations + +import inspect +import torch +from collections.abc import Sequence +from prettytable import PrettyTable +from typing import TYPE_CHECKING, Union + +from embodichain.utils import logger +from .manager_base import ManagerBase +from .cfg import EventCfg + +if TYPE_CHECKING: + from embodichain.lab.gym.envs import EmbodiedEnv + + +class EventManager(ManagerBase): + """Manager for orchestrating operations based on different simulation events. + + The event manager applies operations to the environment based on different simulation events. For example, + changing the masses of objects or their friction coefficients during initialization/ reset, or applying random + pushes to the robot at a fixed interval of steps. The user can specify several modes of events to fine-tune the + behavior based on when to apply the event. + + The event functors are parsed from a config class containing the manager's settings and each functor's + parameters. Each event functor should instantiate the :class:`EventCfg` class. + + Event functors can be grouped by their mode. The mode is a user-defined string that specifies when + the event functor should be applied. This provides the user complete control over when event + functors should be applied. + + For a typical training process, you may want to apply events in the following modes: + + - "prestartup": Event is applied once at the beginning of the training before the simulation starts. + This is used to randomize USD-level properties of the simulation stage. + - "startup": Event is applied once at the beginning of the training once simulation is started. + - "reset": Event is applied at every reset. + - "interval": Event is applied at pre-specified intervals of time. + + However, you can also define your own modes and use them in the training process as you see fit. + For this you will need to add the triggering of that mode in the environment implementation as well. + + .. note:: + + The triggering of operations corresponding to the mode ``"interval"`` are the only mode that are + directly handled by the manager itself. The other modes are handled by the environment implementation. + + """ + + _env: EmbodiedEnv + """The environment instance.""" + + def __init__(self, cfg: object, env: EmbodiedEnv): + """Initialize the event manager. + + Args: + cfg: A configuration object or dictionary (``dict[str, EventCfg]``). + env: An environment object. + """ + # create buffers to parse and store functors + self._mode_functor_names: dict[str, list[str]] = dict() + self._mode_functor_cfgs: dict[str, list[EventCfg]] = dict() + self._mode_class_functor_cfgs: dict[str, list[EventCfg]] = dict() + + # call the base class (this will parse the functors config) + super().__init__(cfg, env) + + def __str__(self) -> str: + """Returns: A string representation for event manager.""" + functor_num = sum(len(v) for v in self._mode_functor_names.values()) + msg = f" contains {functor_num} active functors.\n" + + # add info on each mode + for mode in self._mode_functor_names: + # create table for functor information + table = PrettyTable() + table.title = f"Active Event Functors in Mode: '{mode}'" + # add table headers based on mode + if mode == "interval": + table.field_names = ["Index", "Name", "Interval step"] + table.align["Name"] = "l" + for index, (name, cfg) in enumerate( + zip(self._mode_functor_names[mode], self._mode_functor_cfgs[mode]) + ): + table.add_row([index, name, cfg.interval_step]) + else: + table.field_names = ["Index", "Name"] + table.align["Name"] = "l" + for index, name in enumerate(self._mode_functor_names[mode]): + table.add_row([index, name]) + # convert table to string + msg += table.get_string() + msg += "\n" + + return msg + + """ + Properties. + """ + + @property + def active_functors(self) -> dict[str, list[str]]: + """Name of active event functors. + + The keys are the modes of event and the values are the names of the event functors. + """ + return self._mode_functor_names + + @property + def available_modes(self) -> list[str]: + """Modes of events.""" + return list(self._mode_functor_names.keys()) + + """ + Operations. + """ + + def reset(self, env_ids: Union[Sequence[int], None] = None) -> dict[str, float]: + # call all functors that are classes + for mode_cfg in self._mode_class_functor_cfgs.values(): + for functor_cfg in mode_cfg: + functor_cfg.func.reset(env_ids=env_ids) + + # resolve number of environments + if env_ids is None: + num_envs = self._env.num_envs + else: + num_envs = len(env_ids) + + # May be add more useful reset logic later. + + # nothing to log here + return {} + + def apply( + self, + mode: str, + env_ids: Union[Sequence[int], None] = None, + ): + """Calls each event functor in the specified mode. + + This function iterates over all the event functors in the specified mode and calls the function + corresponding to the functor. The function is called with the environment instance and the environment + indices to apply the event to. + + For the "interval" mode, the function is called when the time interval has passed. This requires + specifying the time step of the environment. + + For the "reset" mode, the function is called when the mode is "reset" and the total number of environment + steps that have happened since the last trigger of the function is equal to its configured parameter for + the number of environment steps between resets. + + Args: + mode: The mode of event. + env_ids: The indices of the environments to apply the event to. + Defaults to None, in which case the event is applied to all environments when applicable. + + Raises: + ValueError: If the mode is ``"interval"`` and the environment indices are provided. This is an undefined + behavior as the environment indices are computed based on the time left for each environment. + ValueError: If the mode is ``"reset"`` and the total number of environment steps that have happened + is not provided. + """ + # check if mode is valid + if mode not in self._mode_functor_names: + logger.log_warning(f"Event mode '{mode}' is not defined. Skipping event.") + return + + if mode == "interval" and env_ids is not None: + logger.log_error( + f"Event mode '{mode}' does not require environment indices. This is an undefined behavior" + " as the environment indices are computed based on the time left for each environment." + ) + + # iterate over all the event functors + for index, functor_cfg in enumerate(self._mode_functor_cfgs[mode]): + functor_cfg: EventCfg + if mode == "interval": + self._interval_functor_step_count[index] += 1 + + # check if the interval has passed and sample a new interval + # note: we compare with a small value to handle floating point errors + if ( + functor_cfg.is_global + and self._interval_functor_step_count[index] + % functor_cfg.interval_step + == 0 + ): + + # call the event functor (with None for env_ids) + functor_cfg.func(self._env, None, **functor_cfg.params) + else: + valid_env_ids = ( + ( + self._interval_functor_step_count[index] + % functor_cfg.interval_step + == 0 + ) + .nonzero() + .flatten() + ) + if len(valid_env_ids) > 0: + # call the event functor + functor_cfg.func(self._env, valid_env_ids, **functor_cfg.params) + elif mode == "reset": + # resolve the environment indices + if env_ids is None: + env_ids = slice(None) + + functor_cfg.func(self._env, env_ids, **functor_cfg.params) + else: + # call the event functor + functor_cfg.func(self._env, env_ids, **functor_cfg.params) + + """ + Operations - Functor settings. + """ + + def set_functor_cfg(self, functor_name: str, cfg: EventCfg): + """Sets the configuration of the specified functor into the manager. + + The method finds the functor by name by searching through all the modes. + It then updates the configuration of the functor with the first matching name. + + Args: + functor_name: The name of the event functor. + cfg: The configuration for the event functor. + + Raises: + ValueError: If the functor name is not found. + """ + functor_found = False + for mode, functors in self._mode_functor_names.items(): + if functor_name in functors: + self._mode_functor_cfgs[mode][functors.index(functor_name)] = cfg + functor_found = True + break + if not functor_found: + logger.log_error(f"Event functor '{functor_name}' not found.") + + def get_functor_cfg(self, functor_name: str) -> EventCfg: + """Gets the configuration for the specified functor. + + The method finds the functor by name by searching through all the modes. + It then returns the configuration of the functor with the first matching name. + + Args: + functor_name: The name of the event functor. + + Returns: + The configuration of the event functor. + + Raises: + ValueError: If the functor name is not found. + """ + for mode, functors in self._mode_functor_names.items(): + if functor_name in functors: + return self._mode_functor_cfgs[mode][functors.index(functor_name)] + logger.log_error(f"Event functor '{functor_name}' not found.") + + """ + Operations - Visit functor. + """ + + def get_functor(self, functor_name: str): + """ + Retrieve a functor from the configuration by its name. + + Args: + functor_name (str): The name of the functor to retrieve. + + Returns: + The functor if it exists in the configuration, otherwise None. + """ + if hasattr(self.cfg, functor_name): + functor = getattr(self.cfg, functor_name).func + return functor + else: + logger.log_warning( + f"Got no functor {functor_name} in event_manager, please check again." + ) + return None + + """ + Helper functions. + """ + + def _prepare_functors(self): + # buffer to store the time left for "interval" mode + # if interval is global, then it is a single value, otherwise it is per environment + self._interval_functor_step_count: list[torch.Tensor] = list() + + # check if config is dict already + if isinstance(self.cfg, dict): + cfg_items = self.cfg.items() + else: + cfg_items = self.cfg.__dict__.items() + # iterate over all the functors + for functor_name, functor_cfg in cfg_items: + # check for non config + if functor_cfg is None: + continue + # check for valid config type + if not isinstance(functor_cfg, EventCfg): + raise TypeError( + f"Configuration for the functor '{functor_name}' is not of type EventCfg." + f" Received: '{type(functor_cfg)}'." + ) + + # resolve common parameters + self._resolve_common_functor_cfg(functor_name, functor_cfg, min_argc=2) + + # check if mode is a new mode + if functor_cfg.mode not in self._mode_functor_names: + # add new mode + self._mode_functor_names[functor_cfg.mode] = list() + self._mode_functor_cfgs[functor_cfg.mode] = list() + self._mode_class_functor_cfgs[functor_cfg.mode] = list() + # add functor name and parameters + self._mode_functor_names[functor_cfg.mode].append(functor_name) + self._mode_functor_cfgs[functor_cfg.mode].append(functor_cfg) + + # check if the functor is a class + if inspect.isclass(functor_cfg.func): + self._mode_class_functor_cfgs[functor_cfg.mode].append(functor_cfg) + + # resolve the mode of the events + # -- interval mode + if functor_cfg.mode == "interval": + # sample the time left for global + if functor_cfg.is_global: + count = torch.zeros(1, dtype=torch.int32, device=self.device) + self._interval_functor_step_count.append(count) + else: + count = torch.zeros( + self.num_envs, dtype=torch.int32, device=self.device + ) + self._interval_functor_step_count.append(count) diff --git a/embodichain/lab/gym/envs/managers/events.py b/embodichain/lab/gym/envs/managers/events.py new file mode 100644 index 00000000..1e089434 --- /dev/null +++ b/embodichain/lab/gym/envs/managers/events.py @@ -0,0 +1,610 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import torch +import os +import random + +from copy import deepcopy +from typing import TYPE_CHECKING, List, Union, Tuple, Dict + +from embodichain.lab.sim.objects import ( + Light, + RigidObject, + RigidObjectGroup, + Articulation, + Robot, +) +from embodichain.lab.sim.cfg import RigidObjectCfg, ArticulationCfg +from embodichain.lab.sim.shapes import MeshCfg +from embodichain.lab.gym.envs.managers.cfg import SceneEntityCfg +from embodichain.lab.gym.envs.managers import Functor, FunctorCfg +from embodichain.utils.module_utils import find_function_from_modules +from embodichain.utils.string import remove_regex_chars, resolve_matching_names +from embodichain.utils.file import get_all_files_in_directory +from embodichain.utils.math import ( + sample_uniform, + pose_inv, + xyz_quat_to_4x4_matrix, + trans_matrix_to_xyz_quat, +) +from embodichain.utils import logger +from embodichain.data import get_data_path + +if TYPE_CHECKING: + from embodichain.lab.gym.envs import EmbodiedEnv + + +class replace_assets_from_group(Functor): + """Replace assets in the environment from a specified group of assets. + + The group of assets can be defined in the following ways: + - A directory containing multiple asset files. + - A json file listing multiple assets with their properties. (not supported yet) + - ... (other methods can be added in the future) + """ + + def __init__(self, cfg: FunctorCfg, env: EmbodiedEnv): + """Initialize the term. + + Args: + cfg: The configuration of the functor. + env: The environment instance. + + Raises: + ValueError: If the asset is not a RigidObject or an Articulation. + """ + super().__init__(cfg, env) + + # extract the used quantities (to enable type-hinting) + entity_cfg: SceneEntityCfg = cfg.params["entity_cfg"] + asset = env.sim.get_asset(entity_cfg.uid) + if asset is None: + logger.log_error( + f"Asset with UID '{entity_cfg.uid}' not found in the simulation." + ) + + if ( + isinstance(asset, RigidObject) + and isinstance(asset.cfg.shape, MeshCfg) is False + ): + logger.log_error( + "Only mesh-based RigidObject assets are supported for replacement." + ) + + self.asset_cfg = asset.cfg + self.asset_type = type(asset) + + if isinstance(asset, Articulation): + logger.log_error("Replacing articulation assets is not supported yet.") + + self._asset_group_path: list[str] = [] + + # The following block of code only handle rigid object assets. + # If we want to support articulation assets, the group path format + # should be changed into list of folder (each folder contains a urdf file + # and its associated resources) + folder_path = cfg.params.get("folder_path", None) + + if folder_path is None: + logger.log_error( + "folder_path must be specified in the functor configuration." + ) + + if folder_path.endswith("/") is False: + folder_path, patterns = os.path.split(folder_path) + + # remove regular expression from patterns + patterns = remove_regex_chars(patterns) + full_path = get_data_path(f"{folder_path}/") + self._asset_group_path = get_all_files_in_directory( + full_path, patterns=patterns + ) + else: + full_path = get_data_path(folder_path) + self._asset_group_path = get_all_files_in_directory(full_path) + + def __call__( + self, + env: EmbodiedEnv, + env_ids: Union[torch.Tensor, None], + entity_cfg: SceneEntityCfg, + folder_path: str, + ) -> None: + + env.sim.remove_asset(entity_cfg.uid) + asset_path = random.choice(self._asset_group_path) + self.asset_cfg.shape.fpath = asset_path + if self.asset_type == RigidObject: + new_asset = env.sim.add_rigid_object(cfg=self.asset_cfg) + else: + logger.log_error("Only RigidObject assets are supported for replacement.") + + +class prepare_extra_attr(Functor): + def __init__(self, cfg: FunctorCfg, env: EmbodiedEnv): + """ + Initializes the event manager with the given configuration and environment. + + Args: + cfg (FunctorCfg): The configuration object for the functor. + env (EmbodiedEnv): The embodied environment instance. + + Attributes: + extra_attrs (dict): A dictionary to hold additional attributes. + """ + super().__init__(cfg, env) + + self.extra_attrs = {} + + def __call__( + self, env: EmbodiedEnv, env_ids: Union[torch.Tensor, None], attrs: List[Dict] + ) -> None: + """ + Processes extra attributes for the given environment. + + This method iterates over a list of attributes, validates them, and updates + the `extra_attrs` dictionary based on the specified modes and values. It handles + both static and callable attributes, logging warnings for any issues encountered. + + Args: + env (EmbodiedEnv): The environment instance to which the attributes are applied. + env_ids (Union[torch.Tensor, None]): Optional tensor of environment IDs (not used in this method). + attrs (List[Dict]): A list of dictionaries containing attribute configurations. + Each dictionary must contain a 'name', and may contain 'entity_cfg', 'entities', + 'mode', 'value', 'func_name', and 'func_kwargs'. + + Returns: + None: This method does not return a value. + """ + for attr_idx, attr in enumerate(attrs): + attr_name = attr.get("name", None) + if attr_name is None: + logger.log_warning( + f"{attr_idx}-th extra attribute got no name, skipping.." + ) + continue + if attr.get("entity_cfg", None) is not None: + entity_cfgs = [SceneEntityCfg(**attr["entity_cfg"])] + elif attr.get("entity_uids", None) is not None: + entity_uids = attr["entity_uids"] + if isinstance(entity_uids, (str, list)): + entity_uids = resolve_uids(env, entity_uids) + if entity_uids is None: + logger.log_warning( + f"Entities string {entity_uids} is not supported, skipping.." + ) + continue + else: + logger.log_warning( + f"Entities type {type(entity_uids)} is not supported, skipping.." + ) + continue + entity_cfgs = [SceneEntityCfg(uid=uid) for uid in entity_uids] + else: + logger.log_warning( + f"'entity_cfg' or 'entity_uids' must be provieded, skipping.." + ) + continue + + attr_mode = attr.get("mode", None) + if attr_mode is None: + logger.log_info( + f"Extra attribute {attr_name} got no mode, setting mode to default 'static'.", + color="green", + ) + attr_mode = "static" + + if attr_mode == "static": + attr_value = attr.get("value", None) + if attr_value is None: + logger.log_warning( + f"Extra attribute {attr_name} got mode 'static' but no value, skipping.." + ) + continue + for cfg in entity_cfgs: + if cfg.uid not in self.extra_attrs: + self.extra_attrs[cfg.uid] = {} + self.extra_attrs[cfg.uid].update({attr_name: attr_value}) + + elif attr_mode == "callable": + attr_func_name = attr.get("func_name", None) + if attr_func_name is None: + logger.log_info( + f"Extra attribute {attr_name} got mode 'callable' but no 'func_name', skipping..", + color="green", + ) + continue + + attr_func_kwargs = attr.get("func_kwargs", None) + if attr_func_name is None: + logger.log_info( + f"Extra attribute {attr_name} got no func_kwargs, setting func_kwargs to default empty dict..", + color="green", + ) + attr_func_kwargs = {} + + is_global_func = True + ASSET_MODULES = [ + "embodichain.lab.gym.envs.object", + "embodichain.lab.gym.utils.misc", + ] + global_func = find_function_from_modules( + attr_func_name, modules=ASSET_MODULES, raise_if_not_found=False + ) + if global_func is None: + is_global_func = False + for cfg in entity_cfgs: + if cfg.uid not in self.extra_attrs: + self.extra_attrs[cfg.uid] = {} + if not is_global_func: + asset = env.sim.get_asset(cfg.uid) + if callable((attr_func := getattr(asset, attr_func_name))): + attr_func_ret = attr_func(**attr_func_kwargs) + else: + logger.log_warning( + f"Extra attribute {attr_name} got no attr_func_name '{attr_func_name}', skipping.." + ) + continue + else: + attr_func_kwargs.update( + {"env": env, "env_ids": env_ids, "entity_cfg": cfg} + ) + attr_func_ret = global_func(**attr_func_kwargs) + self.extra_attrs[cfg.uid].update({attr_name: attr_func_ret}) + + +def register_entity_attrs( + env: EmbodiedEnv, + env_ids: torch.Tensor, + entity_cfg: SceneEntityCfg, + registration: str = "affordance_datas", + attrs: List[str] = [], + prefix: bool = True, +): + """Register the atrributes of an entity to the `env.registration` dict. + + TODO: Currently this method only support 1 env or multi-envs that reset() together, + + as it's behavior is to update a overall dict every time it's called. + + In the future, asynchronously reset mode shall be supported. + + Args: + env (EmbodiedEnv): The environment the entity is in. + env_ids (Union[torch.Tensor, None]): The ids of the envs that the entity should be registered. + entity_cfg (SceneEntityCfg): The config of the entity. + attrs (List[str]): The list of entity attributes that asked to be registered. + registration (str, optional): The env's registration string where the attributes should be injected to. + """ + entity = env.sim.get_asset(entity_cfg.uid) + + if not hasattr(env, registration): + logger.log_warning( + f"Environment has no atrtribute {registration} for registration, please check again." + ) + return + else: + registration_dict = getattr(env, registration, None) + if not isinstance(registration_dict, Dict): + logger.log_warning( + f"Got registration env.{registration} with type {type(registration_dict)}, please check again." + ) + return + + for attr in attrs: + attr_key = f"{entity_cfg.uid}_{attr}" if prefix else attr + if (attr_val := getattr(entity, attr_key, None)) is not None: + registration_dict.update({attr_key: attr_val}) + elif ( + attr_val := getattr( + env.event_manager.get_functor("prepare_extra_attr"), "extra_attrs", {} + ) + .get(entity_cfg.uid, {}) + .get(attr) + ) is not None: + registration_dict.update({attr_key: attr_val}) + else: + logger.log_warning( + f"Attr {attr} for entity {entity_cfg.uid} has neither been found in entity attrbutes nor prepare_extra_attrs functor, skipping.." + ) + + +def register_entity_pose( + env: EmbodiedEnv, + env_ids: torch.Tensor, + entity_cfg: SceneEntityCfg, + registration: str = "affordance_datas", + compute_relative: Union[bool, List, str] = "all_robots", + compute_pose_object_to_arena: bool = True, + to_matrix: bool = True, +): + update_registration_dict = {} + if not hasattr(env, registration): + logger.log_warning( + f"Environment has no atrtribute {registration} for registration, please check again." + ) + return + else: + registration_dict = getattr(env, registration, None) + if not isinstance(registration_dict, Dict): + logger.log_warning( + f"Got registration env.{registration} with type {type(registration_dict)}, please check again." + ) + return + + entity_pose_name, entity_pose = get_pose( + env, env_ids, entity_cfg, return_name=True, to_matrix=True + ) + update_registration_dict.update({entity_pose_name: entity_pose}) + + if compute_relative: + # transform other entity's pose to entity frame + relative_poses = {} + if compute_relative == True: + entity_uids = ( + env.sim.get_articulation_uid_list() + + env.sim.get_rigid_object_uid_list() + + env.sim.get_robot_uid_list() + ) + elif isinstance(compute_relative, (str, list)): + entity_uids = resolve_uids(env, compute_relative) + else: + logger.log_warning( + f"Compute relative pose option with type {type(compute_relative)} is not supported, using empty list for skipping.." + ) + entity_uids = [] + + for other_entity_uid in entity_uids: + if other_entity_uid != entity_cfg.uid: + # TODO: this is only for asset + other_entity_pose = env.sim.get_asset(other_entity_uid).get_local_pose( + to_matrix=True + )[env_ids, :] + relative_pose = torch.bmm(pose_inv(entity_pose), other_entity_pose) + relative_poses.update( + { + f"{other_entity_uid}_pose_{entity_pose_name.replace('_pose', '')}": relative_pose + } + ) + + update_registration_dict.update(relative_poses) + + entity = env.sim.get_asset(entity_cfg.uid) + if isinstance(entity, RigidObject): + extra_attr_functor = env.event_manager.get_functor("prepare_extra_attr") + entity_extra_attrs = getattr(extra_attr_functor, "extra_attrs", {}).get( + entity_cfg.uid, {} + ) + for ( + entity_extra_attr_key, + entity_extra_attr_val, + ) in entity_extra_attrs.items(): + if entity_extra_attr_key.endswith("_pose_object"): + entity_extra_attr_val = torch.as_tensor( + entity_extra_attr_val, device=env.device + ) + if entity_extra_attr_val.ndim < 3: + logger.log_info( + f"Got xyz_quat pose {entity_extra_attr_key}: {entity_extra_attr_val}, transforming it to matrix.", + color="green", + ) + entity_extra_attr_val = xyz_quat_to_4x4_matrix( + entity_extra_attr_val + ) + update_registration_dict.update( + { + entity_cfg.uid + + "_" + + (entity_extra_attr_key): entity_extra_attr_val + } + ) + if compute_pose_object_to_arena: + pose_arena = torch.bmm(entity_pose, entity_extra_attr_val) + update_registration_dict.update( + { + entity_cfg.uid + + "_" + + ( + entity_extra_attr_key.replace("_pose_object", "_pose") + ): pose_arena + } + ) + else: + logger.log_warning( + f"Now compute_pose_object_to_arena only support RigidObject type entity, skipping.." + ) + + if not to_matrix: + for key, val in update_registration_dict.items(): + update_registration_dict[key] = trans_matrix_to_xyz_quat(val) + + registration_dict = getattr(env, registration, None) + if not isinstance(registration_dict, Dict): + logger.log_warning( + f"Got registration env.{registration} with type {type(registration_dict)}, please check again." + ) + return + registration_dict.update(update_registration_dict) + + +def register_info_to_env( + env: EmbodiedEnv, + env_ids: Union[torch.Tensor, None], + registry: List[Dict], + registration: str = "affordance_datas", + sim_update: bool = True, +): + if env_ids is None: + env_ids = torch.arange(env.num_envs, device=env.device) + if sim_update: + logger.log_info( + "Calling env.sim.update(100) for after-physics-applied object attributes..", + color="green", + ) + env.sim.update(100) + for entity_registry in registry: + entity_cfg = SceneEntityCfg(**entity_registry["entity_cfg"]) + logger.log_info(f"Registering {entity_cfg.uid}..", color="green") + if (entity_attrs := entity_registry.get("attrs")) is not None: + prefix = entity_registry.get("prefix", True) + register_entity_attrs( + env, env_ids, entity_cfg, registration, entity_attrs, prefix + ) + if ( + pose_register_params := entity_registry.get("pose_register_params") + ) is not None: + register_entity_pose( + env, env_ids, entity_cfg, registration, **pose_register_params + ) + + +"""Helper Function""" + + +def resolve_uids(env: EmbodiedEnv, entity_uids: Union[List[str], str]) -> List[str]: + if isinstance(entity_uids, str): + if entity_uids == "all_objects": + entity_uids = ( + env.sim.get_rigid_object_uid_list() + + env.sim.get_articulation_uid_list() + ) + elif entity_uids == "all_robots": + entity_uids = env.sim.get_robot_uid_list() + elif entity_uids == "all_sensors": + entity_uids = env.sim.get_sensor_uid_list() + else: + # logger.log_warning(f"Entity uids {entity_uids} not supported in ['all_objects', 'all_robots', 'all_sensors'], wrapping it as a list..") + entity_uids = [entity_uids] + elif isinstance(entity_uids, (list, set, tuple)): + entity_uids = list(entity_uids) + else: + logger.log_error( + f"Entity uids {entity_uids} with type {type(entity_uids)} not supported in [List[str], str], please check again." + ) + return entity_uids + + +def resolve_dict(env: EmbodiedEnv, entity_dict: Dict): + for entity_key in list(entity_dict.keys()): + entity_val = entity_dict.pop(entity_key) + entity_uids = resolve_uids(env, entity_key) + for entity_uid in entity_uids: + entity_dict.update({entity_uid: deepcopy(entity_val)}) + return entity_dict + + +EntityWithPose = Union[RigidObject, Robot] + + +def get_pose( + env: EmbodiedEnv, + env_ids: torch.Tensor, + entity_cfg: SceneEntityCfg, + return_name: bool = True, + to_matrix: bool = True, +): + entity = env.sim.get_asset(entity_cfg.uid) + + if isinstance(entity, RigidObject): + entity_pose = entity.get_local_pose(to_matrix=to_matrix)[env_ids, :] + entity_pose_register_name = entity_cfg.uid + "_pose" + elif isinstance(entity, Robot): + _, control_parts = resolve_matching_names( + entity_cfg.control_parts, list(entity.control_parts.keys()) + ) + if len(control_parts) != 1: + logger.log_warning( + "Only 1 control part can be assigned for computing the robot pose, please check again. Skipping" + ) + return None + entity_cfg.control_parts = control_parts + control_part = control_parts[0] + control_part_qpos = entity.get_qpos()[ + env_ids, entity.get_joint_ids(control_part) + ] + entity_pose = entity.compute_fk( + control_part_qpos, name=control_part, to_matrix=to_matrix + ) # NOTE: now compute_fk returns arena pose + entity_pose_register_name = control_part + "_pose" + else: + logger.log_warning( + f"Entity with tyope {type(entity)} is not supported, please check again." + ) + return None + + if return_name: + return entity_pose_register_name, entity_pose + else: + return entity_pose + + +def drop_rigid_object_group_sequentially( + env: EmbodiedEnv, + env_ids: Union[torch.Tensor, None], + entity_cfg: SceneEntityCfg, + drop_position: List[float] = [0.0, 0.0, 1.0], + position_range: Tuple[List[float], List[float]] = ( + [-0.1, -0.1, 0.0], + [0.1, 0.1, 0.0], + ), + physics_step: int = 2, +) -> None: + """Drop rigid object group from a specified height sequentially in the environment. + + Args: + env (EmbodiedEnv): The environment instance. + env_ids (Union[torch.Tensor, None]): The environment IDs to apply the randomization. + entity_cfg (SceneEntityCfg): The configuration of the scene entity to randomize. + drop_position (List[float]): The base position from which to drop the objects. Default is [0.0, 0.0, 1.0]. + position_range (Tuple[List[float], List[float]]): The range for randomizing the drop position around the base position. + physics_step (int): The number of physics steps to simulate after dropping the objects. Default is 2. + """ + + obj_group: RigidObjectGroup = env.sim.get_rigid_object_group(entity_cfg.uid) + + if obj_group is None: + logger.log_error( + f"RigidObjectGroup with UID '{entity_cfg.uid}' not found in the simulation." + ) + + num_instance = len(env_ids) + num_objects = obj_group.num_objects + + range_low = torch.tensor(position_range[0], device=env.device) + range_high = torch.tensor(position_range[1], device=env.device) + drop_pos = ( + torch.tensor(drop_position, device=env.device) + .unsqueeze_(0) + .repeat(num_instance, 1) + ) + drop_pose = torch.zeros((num_instance, 7), device=env.device) + drop_pose[:, 3] = 1.0 # w component of quaternion + drop_pose[:, :3] = drop_pos + for i in range(num_objects): + random_offset = sample_uniform( + lower=range_low, + upper=range_high, + size=(num_instance, 3), + ) + drop_pose_i = drop_pose.unsqueeze(1) + drop_pose_i[:, 0, :3] = drop_pos + random_offset + + obj_group.set_local_pose(pose=drop_pose_i, env_ids=env_ids, obj_ids=[i]) + + env.sim.update(step=physics_step) diff --git a/embodichain/lab/gym/envs/managers/manager_base.py b/embodichain/lab/gym/envs/managers/manager_base.py new file mode 100644 index 00000000..d89503bc --- /dev/null +++ b/embodichain/lab/gym/envs/managers/manager_base.py @@ -0,0 +1,408 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# All rights reserved. +# +# This file incorporates code from the Isaac Lab Project +# Copyright (c) 2022-2025, The Isaac Lab Project Developers +# (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import copy +import inspect +from abc import ABC, abstractmethod +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, Union + +from embodichain.utils.string import string_to_callable, resolve_matching_names +from embodichain.utils.utility import class_to_dict +from embodichain.utils import logger + +from .cfg import FunctorCfg, SceneEntityCfg + +if TYPE_CHECKING: + from embodichain.lab.gym.envs import EmbodiedEnv + + +class Functor(ABC): + """Base class for Functor. + + Functor implementations can be functions or classes. If the functor is a class, it should + inherit from this base class and implement the required methods. + + Each manager is implemented as a class that inherits from the :class:`ManagerBase` class. Each manager + class should also have a corresponding configuration class that defines the configuration functors for the + manager. Each functor should the :class:`FunctorCfg` class or its subclass. + + Example pseudo-code for creating a manager: + + .. code-block:: python + + from embodichain.utils import configclass + from embodichain.lab.gym.managers import ManagerBase + from embodichain.lab.gym.managers FunctorCfg + + @configclass + class MyManagerCfg: + + functor1: FunctorCfg = FunctorCfg(...) + functor2: FunctorCfg = FunctorCfg(...) + functor3: FunctorCfg = FunctorCfg(...) + + # define manager instance + my_manager = ManagerBase(cfg=ManagerCfg(), env=env) + + """ + + def __init__(self, cfg: FunctorCfg, env: EmbodiedEnv): + """Initialize the functor. + + Args: + cfg: The configuration object. + env: The environment instance. + """ + # store the inputs + self.cfg = cfg + self._env = env + + """ + Properties. + """ + + @property + def num_envs(self) -> int: + """Number of environments.""" + return self._env.num_envs + + @property + def device(self) -> str: + """Device on which to perform computations.""" + return self._env.device + + @property + def __name__(self) -> str: + """Return the name of the class or subclass.""" + return self.__class__.__name__ + + """ + Operations. + """ + + def reset(self, env_ids: Union[Sequence[int], None] = None) -> None: + """Resets the functor. + + Args: + env_ids: The environment ids. Defaults to None, in which case + all environments are considered. + """ + pass + + def serialize(self) -> dict: + """General serialization call. Includes the configuration dict.""" + return {"cfg": class_to_dict(self.cfg)} + + def __call__(self, *args) -> Any: + """Returns the value of the functor required by the manager. + + In case of a class implementation, this function is called by the manager + to get the value of the functor. The arguments passed to this function are + the ones specified in the functor configuration (see :attr:`FunctorCfg.params`). + + .. attention:: + To be consistent with memory-less implementation of functors with functions, it is + recommended to ensure that the returned mutable quantities are cloned before + returning them. For instance, if the functor returns a tensor, it is recommended + to ensure that the returned tensor is a clone of the original tensor. This prevents + the manager from storing references to the tensors and altering the original tensors. + + Args: + *args: Variable length argument list. + + Returns: + The value of the functor. + """ + raise NotImplementedError( + "The method '__call__' should be implemented by the subclass." + ) + + +class ManagerBase(ABC): + """Base class for all managers.""" + + def __init__(self, cfg: object, env: EmbodiedEnv): + """Initialize the manager. + + This function is responsible for parsing the configuration object and creating the functors. + + If the simulation is not playing, the scene entities are not resolved immediately. + Instead, the resolution is deferred until the simulation starts. This is done to ensure + that the scene entities are resolved even if the manager is created after the simulation + has already started. + + Args: + cfg: The configuration object. If None, the manager is initialized without any functors. + env: The environment instance. + """ + # store the inputs + self.cfg = copy.deepcopy(cfg) + self._env = env + + # parse config to create functors information + if self.cfg: + self._prepare_functors() + + def __repr__(self) -> str: + return self.__str__() + + """ + Properties. + """ + + @property + def num_envs(self) -> int: + """Number of environments.""" + return self._env.num_envs + + @property + def device(self) -> str: + """Device on which to perform computations.""" + return self._env.device + + @property + @abstractmethod + def active_functors(self) -> Union[list[str], dict[str, list[str]]]: + """Name of active functors.""" + raise NotImplementedError + + """ + Operations. + """ + + def reset(self, env_ids: Union[Sequence[int], None] = None) -> dict[str, float]: + """Resets the manager and returns logging information for the current time-step. + + Args: + env_ids: The environment ids for which to log data. + Defaults None, which logs data for all environments. + + Returns: + Dictionary containing the logging information. + """ + return {} + + def find_functors(self, name_keys: Union[str, Sequence[str]]) -> list[str]: + """Find functors in the manager based on the names. + + This function searches the manager for functors based on the names. The names can be + specified as regular expressions or a list of regular expressions. The search is + performed on the active functors in the manager. + + Please check the :meth:`~embodichain.utils.string.resolve_matching_names` function for more + information on the name matching. + + Args: + name_keys: A regular expression or a list of regular expressions to match the functor names. + + Returns: + A list of functor names that match the input keys. + """ + # resolve search keys + if isinstance(self.active_functors, dict): + list_of_strings = [] + for names in self.active_functors.values(): + list_of_strings.extend(names) + else: + list_of_strings = self.active_functors + + # return the matching names + return resolve_matching_names(name_keys, list_of_strings)[1] + + def get_active_iterable_functors( + self, env_idx: int + ) -> Sequence[tuple[str, Sequence[float]]]: + """Returns the active functors as iterable sequence of tuples. + + The first element of the tuple is the name of the functor and the second element is the raw value(s) of the functor. + + Returns: + The active functors. + """ + raise NotImplementedError + + """ + Implementation specific. + """ + + @abstractmethod + def _prepare_functors(self): + """Prepare functors information from the configuration object.""" + raise NotImplementedError + + """ + Internal callbacks. + """ + + def _resolve_functors_callback(self, event): + """Resolve configurations of functors once the simulation starts. + + Please check the :meth:`_process_functor_cfg_at_play` method for more information. + """ + # check if scene entities have been resolved + if self._is_scene_entities_resolved: + return + # check if config is dict already + if isinstance(self.cfg, dict): + cfg_items = self.cfg.items() + else: + cfg_items = self.cfg.__dict__.items() + + # iterate over all the functors + for functor_name, functor_cfg in cfg_items: + # check for non config + if functor_cfg is None: + continue + # process attributes at runtime + # these properties are only resolvable once the simulation starts playing + self._process_functor_cfg_at_play(functor_name, functor_cfg) + + # set the flag + self._is_scene_entities_resolved = True + + """ + Internal functions. + """ + + def _resolve_common_functor_cfg( + self, functor_name: str, functor_cfg: FunctorCfg, min_argc: int = 1 + ): + """Resolve common attributes of the functor configuration. + + Usually, called by the :meth:`_prepare_functors` method to resolve common attributes of the functor + configuration. These include: + + * Resolving the functor function and checking if it is callable. + * Checking if the functor function's arguments are matched by the parameters. + * Resolving special attributes of the functor configuration like ``entity_cfg``, ``sensor_cfg``, etc. + * Initializing the functor if it is a class. + + By default, all functor functions are expected to have at least one argument, which is the + environment object. Some other managers may expect functions to take more arguments, for + instance, the environment indices as the second argument. In such cases, the + ``min_argc`` argument can be used to specify the minimum number of arguments + required by the functor function to be called correctly by the manager. + + Args: + functor_name: The name of the functor. + functor_cfg: The functor configuration. + min_argc: The minimum number of arguments required by the functor function to be called correctly + by the manager. + + Raises: + TypeError: If the functor configuration is not of type :class:`FunctorCfg`. + ValueError: If the scene entity defined in the functor configuration does not exist. + AttributeError: If the functor function is not callable. + ValueError: If the functor function's arguments are not matched by the parameters. + """ + # check if the functor is a valid functor config + if not isinstance(functor_cfg, FunctorCfg): + raise TypeError( + f"Configuration for the functor '{functor_name}' is not of type FunctorCfg." + f" Received: '{type(functor_cfg)}'." + ) + + # get the corresponding function or functional class + if isinstance(functor_cfg.func, str): + functor_cfg.func = string_to_callable(functor_cfg.func) + # check if function is callable + if not callable(functor_cfg.func): + raise AttributeError( + f"The functor '{functor_name}' is not callable. Received: {functor_cfg.func}" + ) + + # check if the functor is a class of valid type + if inspect.isclass(functor_cfg.func): + if not issubclass(functor_cfg.func, Functor): + raise TypeError( + f"Configuration for the functor '{functor_name}' is not of type ManagerTermBase." + f" Received: '{type(functor_cfg.func)}'." + ) + func_static = functor_cfg.func.__call__ + min_argc += 1 # forward by 1 to account for 'self' argument + else: + func_static = functor_cfg.func + # check if function is callable + if not callable(func_static): + raise AttributeError( + f"The functor '{functor_name}' is not callable. Received: {functor_cfg.func}" + ) + + # check statically if the functor's arguments are matched by params + functor_params = list(functor_cfg.params.keys()) + args = inspect.signature(func_static).parameters + args_with_defaults = [ + arg for arg in args if args[arg].default is not inspect.Parameter.empty + ] + args_without_defaults = [ + arg for arg in args if args[arg].default is inspect.Parameter.empty + ] + args = args_without_defaults + args_with_defaults + # ignore first two arguments for env and env_ids + # Think: Check for cases when kwargs are set inside the function? + if len(args) > min_argc: + if set(args[min_argc:]) != set(functor_params + args_with_defaults): + raise ValueError( + f"The functor '{functor_name}' expects mandatory parameters: {args_without_defaults[min_argc:]}" + f" and optional parameters: {args_with_defaults}, but received: {functor_params}." + ) + + # process attributes at runtime + # these properties are only resolvable once the simulation starts playing + self._process_functor_cfg_at_play(functor_name, functor_cfg) + + def _process_functor_cfg_at_play(self, functor_name: str, functor_cfg: FunctorCfg): + """Process the functor configuration at runtime. + + This function is called when the simulation starts playing. It is used to process the functor + configuration at runtime. This includes: + + * Resolving the scene entity configuration for the functor. + * Initializing the functor if it is a class. + + Since the above steps rely on PhysX to parse over the simulation scene, they are deferred + until the simulation starts playing. + + Args: + functor_name: The name of the functor. + functor_cfg: The functor configuration. + """ + for key, value in functor_cfg.params.items(): + if isinstance(value, SceneEntityCfg): + # load the entity + try: + value.resolve(self._env.sim) + except ValueError as e: + raise ValueError(f"Error while parsing '{functor_name}:{key}'. {e}") + # log the entity for checking later + msg = f"[{functor_cfg.__class__.__name__}:{functor_name}] Found entity '{value.uid}'." + if value.joint_ids is not None: + msg += f"\n\tJoint names: {value.joint_names} [{value.joint_ids}]" + if value.body_ids is not None: + msg += f"\n\tBody names: {value.body_names} [{value.body_ids}]" + # print the information + print(f"[INFO]: {msg}") + # store the entity + functor_cfg.params[key] = value + + # initialize the functor if it is a class + if inspect.isclass(functor_cfg.func): + try: + logger.log_info( + f"Initializing functor '{functor_name}' with class '{functor_cfg.func.__name__}'." + ) + functor_cfg.func = functor_cfg.func(cfg=functor_cfg, env=self._env) + except Exception as e: + logger.log_error(f"Failed to initialize functor '{functor_name}': {e}") diff --git a/embodichain/lab/gym/envs/managers/observation_manager.py b/embodichain/lab/gym/envs/managers/observation_manager.py new file mode 100644 index 00000000..cd1dc70a --- /dev/null +++ b/embodichain/lab/gym/envs/managers/observation_manager.py @@ -0,0 +1,213 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Observation manager for orchestrating operations based on different simulation observations.""" + +from __future__ import annotations + +import inspect +import torch +from collections.abc import Sequence +from prettytable import PrettyTable +from typing import TYPE_CHECKING, Union + +from embodichain.utils import logger +from embodichain.lab.sim.types import EnvObs +from embodichain.lab.gym.utils.gym_utils import ( + fetch_data_from_dict, + assign_data_to_dict, +) +from .manager_base import ManagerBase +from .cfg import ObservationCfg + +if TYPE_CHECKING: + from embodichain.lab.gym.envs import EmbodiedEnv + + +class ObservationManager(ManagerBase): + """Manager for orchestrating operations based on different simulation observations. + + The default observation space will contain two observation groups: + - `robot`: Contains the default observations related to the robot. + - `qpos`: The joint positions of the robot. + - `qvel`: The joint velocities of the robot. + - `qf`: The joint forces of the robot. + - `sensor`: Contains the observations related to the sensors which are enabled in the environment. + + The observation manager offers two modes of operation: + - `modify`: This mode perform data fetching and modification on existing observation data. + - `add`: This mode perform new observation computation and add new observation data to the observation space. + """ + + _env: EmbodiedEnv + """The environment instance.""" + + def __init__(self, cfg: object, env: EmbodiedEnv): + """Initialize the observation manager. + + Args: + cfg: A configuration object or dictionary (``dict[str, ObservationCfg]``). + env: An environment object. + """ + + self._mode_functor_names: dict[str, list[str]] = dict() + self._mode_functor_cfgs: dict[str, list[ObservationCfg]] = dict() + self._mode_class_functor_cfgs: dict[str, list[ObservationCfg]] = dict() + + # call the base class (this will parse the functors config) + super().__init__(cfg, env) + + def __str__(self) -> str: + """Returns: A string representation for observation manager.""" + functor_num = sum(len(v) for v in self._mode_functor_names.values()) + msg = f" contains {functor_num} active functors.\n" + + # add info on each mode + for mode in self._mode_functor_names: + # create table for functor information + table = PrettyTable() + table.title = f"Active Observation Functors in Mode: '{mode}'" + + table.field_names = ["Index", "Name"] + table.align["Name"] = "l" + for index, name in enumerate(self._mode_functor_names[mode]): + table.add_row([index, name]) + + # convert table to string + msg += table.get_string() + msg += "\n" + + return msg + + """ + Properties. + """ + + @property + def active_functors(self) -> dict[str, list[str]]: + """Name of active observation functors. + + The keys are the modes of observation and the values are the names of the observation functors. + """ + return self._mode_functor_names + + """ + Operations. + """ + + def reset(self, env_ids: Union[Sequence[int], None] = None) -> dict[str, float]: + # call all functors that are classes + for mode_cfg in self._mode_class_functor_cfgs.values(): + for functor_cfg in mode_cfg: + functor_cfg.func.reset(env_ids=env_ids) + + # nothing to log here + return {} + + def compute( + self, + obs: EnvObs, + ) -> EnvObs: + """Calls each observation functor in the specified mode. + + This function iterates over all the observation functors in the specified mode and calls the function + corresponding to the functor. The function is called with the environment instance and the environment + indices to apply the observation to. + + Args: + obs: The observation data to apply the observation to. + + Returns: + The modified observation data. + + Raises: + ValueError: If the mode is not supported. + """ + + # iterate over all the observation functors + for mode, functor_cfgs in self._mode_functor_cfgs.items(): + for functor_cfg in functor_cfgs: + functor_cfg: ObservationCfg + + if mode == "modify": + data = fetch_data_from_dict(obs, functor_cfg.name) + data = functor_cfg.func(self._env, data, **functor_cfg.params) + elif mode == "add": + data = functor_cfg.func(self._env, obs, **functor_cfg.params) + assign_data_to_dict(obs, functor_cfg.name, data) + else: + logger.log_error(f"Unsupported observation mode '{mode}'.") + + return obs + + def get_functor_cfg(self, functor_name: str) -> ObservationCfg: + """Gets the configuration for the specified functor. + + The method finds the functor by name by searching through all the modes. + It then returns the configuration of the functor with the first matching name. + + Args: + functor_name: The name of the observation functor. + + Returns: + The configuration of the observation functor. + + Raises: + ValueError: If the functor name is not found. + """ + for mode, functors in self._mode_functor_names.items(): + if functor_name in functors: + return self._mode_functor_cfgs[mode][functors.index(functor_name)] + logger.log_error(f"observation functor '{functor_name}' not found.") + + """ + Helper functions. + """ + + def _prepare_functors(self): + # check if config is dict already + if isinstance(self.cfg, dict): + cfg_items = self.cfg.items() + else: + cfg_items = self.cfg.__dict__.items() + # iterate over all the functors + for functor_name, functor_cfg in cfg_items: + # check for non config + if functor_cfg is None: + continue + # check for valid config type + if not isinstance(functor_cfg, ObservationCfg): + raise TypeError( + f"Configuration for the functor '{functor_name}' is not of type ObservationCfg." + f" Received: '{type(functor_cfg)}'." + ) + + # resolve common parameters + self._resolve_common_functor_cfg(functor_name, functor_cfg, min_argc=2) + + # check if mode is a new mode + if functor_cfg.mode not in self._mode_functor_names: + # add new mode + self._mode_functor_names[functor_cfg.mode] = list() + self._mode_functor_cfgs[functor_cfg.mode] = list() + self._mode_class_functor_cfgs[functor_cfg.mode] = list() + # add functor name and parameters + self._mode_functor_names[functor_cfg.mode].append(functor_name) + self._mode_functor_cfgs[functor_cfg.mode].append(functor_cfg) + + # check if the functor is a class + if inspect.isclass(functor_cfg.func): + self._mode_class_functor_cfgs[functor_cfg.mode].append(functor_cfg) diff --git a/embodichain/lab/gym/envs/managers/observations.py b/embodichain/lab/gym/envs/managers/observations.py new file mode 100644 index 00000000..c628384c --- /dev/null +++ b/embodichain/lab/gym/envs/managers/observations.py @@ -0,0 +1,615 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import torch +import os +import random +from typing import TYPE_CHECKING, Literal, Union, Optional, List, Dict, Sequence + +from embodichain.lab.sim.objects import RigidObject, Articulation, Robot +from embodichain.lab.sim.sensors import Camera, StereoCamera +from embodichain.lab.sim.types import EnvObs +from embodichain.lab.gym.envs.managers.cfg import SceneEntityCfg +from embodichain.lab.gym.envs.managers.events import resolve_dict +from embodichain.lab.gym.envs.managers import Functor, FunctorCfg +from embodichain.utils import logger + +if TYPE_CHECKING: + from embodichain.lab.gym.envs import EmbodiedEnv + + +def get_rigid_object_pose( + env: EmbodiedEnv, + obs: EnvObs, + entity_cfg: SceneEntityCfg, +) -> torch.Tensor: + """Get the world poses of the rigid objects in the environment. + + Args: + env: The environment instance. + obs: The observation dictionary. + entity_cfg: The configuration of the scene entity. + + Returns: + A tensor of shape (num_envs, 4, 4) representing the world poses of the rigid objects. + """ + + obj = env.sim.get_rigid_object(entity_cfg.uid) + + return obj.get_local_pose(to_matrix=True) + + +def normalize_robot_joint_data( + env: EmbodiedEnv, + data: torch.Tensor, + joint_ids: Sequence[int], + limit: Literal["qpos_limits", "qvel_limits"] = "qpos_limits", +) -> torch.Tensor: + """Normalize the robot joint positions to the range of [0, 1] based on the joint limits. + + Args: + env: The environment instance. + obs: The observation dictionary. + joint_ids: The indices of the joints to be normalized. + limit: The type of joint limits to be used for normalization. Options are: + - `qpos_limits`: Use the joint position limits for normalization. + - `qvel_limits`: Use the joint velocity limits for normalization. + """ + + robot = env.robot + + # shape of target_limits: (num_envs, len(joint_ids), 2) + target_limits = getattr(robot.body_data, limit)[:, joint_ids, :] + + # normalize the joint data to the range of [0, 1] + data[:, joint_ids] = (data[:, joint_ids] - target_limits[:, :, 0]) / ( + target_limits[:, :, 1] - target_limits[:, :, 0] + ) + + return data + + +def compute_semantic_mask( + env: EmbodiedEnv, + obs: EnvObs, + entity_cfg: SceneEntityCfg, + foreground_uids: Sequence[str], + is_right: bool = False, +) -> torch.Tensor: + """Compute the semantic mask for the specified scene entity. + + Note: + The semantic mask is defined as (B, H, W, 3) where the three channels represents: + - robot channel: the instance id of the robot is set to 1 (0 if not robot) + - background channel: the instance id of the background is set to 1 (0 if not background) + - foreground channel: the instance id of the foreground objects is set to 1 (0 if not foreground) + + Args: + env: The environment instance. + obs: The observation dictionary. + entity_cfg: The configuration of the scene entity. + foreground_uids: The list of uids for the foreground objects. + is_right: Whether to use the right camera for stereo cameras. Default is False. + Only applicable if the sensor is a StereoCamera. + + Returns: + A tensor of shape (num_envs, height, width) representing the semantic mask. + """ + + sensor: Union[Camera, StereoCamera] = env.sim.get_sensor(entity_cfg.uid) + if sensor.cfg.enable_mask is False: + logger.log_error( + f"Sensor '{entity_cfg.uid}' does not have mask enabled. Please enable the mask in the sensor configuration." + ) + + if isinstance(sensor, StereoCamera) and is_right: + mask = obs["sensor"][entity_cfg.uid]["mask_right"] + else: + mask = obs["sensor"][entity_cfg.uid]["mask"] + + robot_uids = env.robot.get_user_ids() + + mask_exp = mask.unsqueeze(-1) + + robot_uids_exp = robot_uids.unsqueeze_(1).unsqueeze_(1) + + robot_mask = (mask_exp == robot_uids_exp).any(-1).squeeze_(-1) + + foreground_assets = [env.sim.get_asset(uid) for uid in foreground_uids] + + # cat assets uid (num_envs, n) into dim 1 + foreground_uids = torch.cat( + [ + asset.get_user_ids().unsqueeze(1) + if asset.get_user_ids().dim() == 1 + else asset.get_user_ids() + for asset in foreground_assets + ], + dim=1, + ) + + foreground_uids_exp = foreground_uids.unsqueeze_(1).unsqueeze_(1) + + foreground_mask = (mask_exp == foreground_uids_exp).any(-1).squeeze_(-1) + + background_mask = ~(robot_mask | foreground_mask).squeeze_(-1) + + return torch.stack([robot_mask, background_mask, foreground_mask], dim=-1) + + +class compute_exteroception(Functor): + """Compute the exteroception for the observation space. + + The exteroception is currently defined as a set of keypoints around a reference pose, which are prjected from 3D + space to 2D image plane. + The reference pose can derive from the following sources: + - Pose from robot control part (e.g., end-effector, usually tcp pose) + - Object affordance pose (e.g., handle pose of a mug or a pick pose of a cube) + + Therefore, the exteroception are defined in the camera-like sensor, for example. + descriptor = { + "cam_high": [ + { + "type": "affordance", + "obj_uid": "obj1", + "key": "grasp_pose", + "is_arena_coord": True + }, + { + "type": "affordance", + "obj_uid": "obj1", + "key": "place_pose", + }, + { + "type": "robot", + "control_part": "left_arm", + }, + { + "type": "robot", + "control_part": "right_arm", + } + ], + ... + } + + Explanation of the parameters: + - The key of the dictionary is the sensor uid. + - The value is another dictionary, where the key is the source type, and the value is a dictionary of parameters. + - For `affordance` source type, the parameters are: + - `obj_uid`: The uid of the object to get the affordance pose from. + - `key`: The key of the affordance pose in the affordance data. + - `is_arena_coord`: Whether the affordance pose is in the arena coordinate system. Default is False. + - For `robot` source type, the parameters are: + - `control_part`: The control part of the robot to get the pose from. + """ + + def __init__( + self, + cfg: FunctorCfg, + env: EmbodiedEnv, + ): + super().__init__(cfg, env) + + if self._env.num_envs != 1: + logger.log_error( + f"Exteroception functor only supported env with 'num_envs=1' but got 'num_envs={self._env.num_envs}'. Please check again." + ) + + self._valid_source = ["robot", "affordance"] + + @staticmethod + def shift_pose(pose: torch.Tensor, axis: int, shift: float) -> torch.Tensor: + """Shift the pose along the specified axis by the given amount. + + Args: + pose: The original pose tensor of shape (B, 4, 4). + axis: The axis along which to shift (0 for x, 1 for y, 2 for z). + shift: The amount to shift along the specified axis. + """ + shift_pose = torch.linalg.inv(pose) + shift_pose[:, axis, -1] += shift + shift_pose = torch.linalg.inv(shift_pose) + return shift_pose + + @staticmethod + def expand_pose( + pose: torch.Tensor, + x_interval: float, + y_interval: float, + kpnts_number: int, + ref_pose: torch.Tensor = None, + ) -> torch.Tensor: + """Expand pose with keypoints along x and y axes. + + Args: + pose: The original pose tensor of shape (B, 4, 4). + x_interval: The interval for expanding along x-axis. + y_interval: The interval for expanding along y-axis. + kpnts_number: Number of keypoints to generate for each axis. + ref_pose: Reference pose tensor of shape (B, 4, 4). If None, uses identity matrix. + + Returns: + Expanded poses tensor of shape (B, 1 + 2*kpnts_number, 4, 4). + """ + batch_size = pose.shape[0] + device = pose.device + + # Create default reference pose if not provided + if ref_pose is None: + ref_pose = ( + torch.eye(4, device=device).unsqueeze_(0).repeat(batch_size, 1, 1) + ) + + # Start with the original pose transformed by ref_pose + ret = [ref_pose @ pose] + + # Generate x-axis offsets and expand poses + # TODO: only support 1 env + xoffset = torch.linspace(-x_interval, x_interval, kpnts_number, device=device) + for x_shift in xoffset: + shifted_pose = compute_exteroception.shift_pose(pose, 0, x_shift.item()) + x_expanded = ref_pose @ shifted_pose + ret.append(x_expanded) + + # Generate y-axis offsets and expand poses + # TODO: only support 1 env + yoffset = torch.linspace(-y_interval, y_interval, kpnts_number, device=device) + for y_shift in yoffset: + shifted_pose = compute_exteroception.shift_pose(pose, 1, y_shift.item()) + y_expanded = ref_pose @ shifted_pose + ret.append(y_expanded) + + # Stack all poses along a new dimension + return torch.stack(ret, dim=1) + + @staticmethod + def _project_3d_to_2d( + cam_pose: torch.Tensor, + intrinsics: torch.Tensor, + height: int, + width: int, + target_poses: torch.Tensor, + normalize: bool = True, + ) -> torch.Tensor: + """Project 3D poses to 2D image plane. + + Args: + cam_pose: Camera pose of in arena frame of shape (B, 4, 4). + intrinsics: Camera intrinsic matrix of shape (B, 3, 3). + height: Image height. + width: Image width. + target_poses: 3D poses of shape (B, N, 4, 4). + normalize: Whether to normalize the projected points to [0, 1] range. + + Returns: + Projected 2D points of shape (B, N, 2). + """ + batch_size, num_poses = target_poses.shape[:2] + + # Convert to opencv coordinate system + cam_pose[:, :3, 1] = -cam_pose[:, :3, 1] + cam_pose[:, :3, 2] = -cam_pose[:, :3, 2] + + # Expand cam_pose_inv and intrinsics to match target_poses batch dimension + cam_pose_inv = torch.linalg.inv(cam_pose) # (B, 4, 4) + cam_pose_inv_expanded = cam_pose_inv.unsqueeze(1).expand( + -1, num_poses, -1, -1 + ) # (B, N, 4, 4) + cam_pose_inv_reshaped = cam_pose_inv_expanded.reshape(-1, 4, 4) # (B*N, 4, 4) + + intrinsics_expanded = intrinsics.unsqueeze(1).expand( + -1, num_poses, -1, -1 + ) # (B, N, 3, 3) + intrinsics_reshaped = intrinsics_expanded.reshape(-1, 3, 3) # (B*N, 3, 3) + + # Reshape target_poses to (B*N, 4, 4) + target_poses_reshaped = target_poses.reshape(-1, 4, 4) # (B*N, 4, 4) + + # Transform 3D points to camera coordinates in parallel + # Extract translation part (position) from target poses: (B*N, 4, 1) + target_positions = target_poses_reshaped[:, :, 3:4] # (B*N, 4, 1) + + # Transform to camera coordinates: (B*N, 4, 1) + cam_positions = cam_pose_inv_reshaped.bmm(target_positions) # (B*N, 4, 1) + cam_positions_3d = cam_positions[:, :3, 0] # (B*N, 3) + + # Project to 2D using intrinsics in parallel + # Add small epsilon to avoid division by zero + eps = 1e-8 + z_safe = torch.clamp(cam_positions_3d[:, 2], min=eps) # (B*N,) + + # Normalize by depth + normalized_points = cam_positions_3d[:, :2] / z_safe.unsqueeze(-1) # (B*N, 2) + + # Convert to homogeneous coordinates and apply intrinsics + normalized_homogeneous = torch.cat( + [normalized_points, torch.ones_like(normalized_points[:, :1])], dim=-1 + ) # (B*N, 3) + pixel_coords = intrinsics_reshaped.bmm( + normalized_homogeneous.unsqueeze(-1) + ).squeeze( + -1 + ) # (B*N, 3) + + # Extract 2D coordinates + points_2d_flat = pixel_coords[:, :2] # (B*N, 2) + + # Reshape back to (B, N, 2) + points_2d = points_2d_flat.reshape(batch_size, num_poses, 2) + + # clip to range [0, width] and [0, height] + points_2d[..., 0] = torch.clamp(points_2d[..., 0], 0, width - 1) + points_2d[..., 1] = torch.clamp(points_2d[..., 1], 0, height - 1) + + if normalize: + # Normalize to [0, 1] range + points_2d[..., 0] /= width + points_2d[..., 1] /= height + + return points_2d + + def _get_gripper_ratio( + self, control_part: str, gripper_qpos: Optional[torch.Tensor] = None + ): + robot: Robot = self._env.robot + gripper_max_limit = robot.body_data.qpos_limits[ + :, robot.get_joint_ids(control_part) + ][:, 0, 1] + + if gripper_qpos is None: + gripper_qpos = robot.get_qpos()[:, robot.get_joint_ids(control_part)][:, 0] + + return gripper_qpos / gripper_max_limit + + def _get_robot_exteroception( + self, + control_part: Optional[str] = None, + x_interval: float = 0.02, + y_interval: float = 0.02, + kpnts_number: int = 12, + offset: Optional[Union[List, torch.Tensor]] = None, + follow_eef: bool = False, + ) -> torch.Tensor: + """Get the robot exteroception poses. + + Args: + control_part: The part of the robot to use as reference. If None, uses the base. + x_interval: The interval for expanding along x-axis. + y_interval: The interval for expanding along y-axis. + kpnts_number: Number of keypoints to generate for each axis. + offset: Intrinsic offset that need to be substracted. + follow_eef: Whether to follow the gripper or not. + + Returns: + A tensor of shape (num_envs, 1 + 2*kpnts_number, 4, 4) representing the exteroception poses. + """ + robot: Robot = self._env.robot + if control_part is not None: + current_qpos = robot.get_qpos()[:, robot.get_joint_ids(control_part)] + robot_pose = robot.compute_fk( + current_qpos, name=control_part, to_matrix=True + ) + if follow_eef: + gripper_ratio = self._get_gripper_ratio( + control_part.replace("_arm", "_eef") + ) # TODO: "_eef" hardcode + # TODO: only support 1 env + y_interval = (y_interval * gripper_ratio)[0].item() + else: + logger.log_error("Not supported Robot without control part yet.") + + if offset is not None: + offset = torch.as_tensor( + offset, dtype=torch.float32, device=self._env.device + ) + + if (offset.ndim > 2) or (offset.shape[-1] != 3): + logger.log_error( + f"Only (N, 3) shaped xyz-intrinsic offset supported, got shape {offset.shape}" + ) + elif offset.ndim == 1: + offset = offset[None] + # TODO: This operation may be slow when large scale Parallelization, but when small (num_envs=1) this operation is faster + robot_pose[:, :3, 3] = robot_pose[:, :3, 3] - torch.einsum( + "bij,bj->bi", robot_pose[:, :3, :3], offset + ) + + return compute_exteroception.expand_pose( + robot_pose, + x_interval, + y_interval, + kpnts_number, + ) + + def _get_object_exteroception( + self, + uid: str, + affordance_key: str, + x_interval: float = 0.02, + y_interval: float = 0.02, + kpnts_number: int = 12, + is_arena_coord: bool = False, + follow_eef: Optional[str] = None, + ) -> torch.Tensor: + """Get the rigid object exteroception poses. + + Args: + uid: The UID of the object. + affordance_key: The key of the affordance to use for the object pose. + x_interval: The interval for expanding along x-axis. + y_interval: The interval for expanding along y-axis. + kpnts_number: Number of keypoints to generate for each axis. + is_arena_coord: Whether to use the arena coordinate system. Default is False. + + Returns: + A tensor of shape (num_envs, 1 + 2*kpnts_number, 4, 4) representing the exteroception poses. + """ + + obj: RigidObject = self._env.sim.get_rigid_object(uid) + if obj is None: + logger.log_error( + f"Rigid object with UID '{uid}' not found in the simulation." + ) + + if hasattr(self._env, "affordance_datas") is False: + logger.log_error( + "Affordance data is not available in the environment. We cannot compute object exteroception." + ) + + if affordance_key not in self._env.affordance_datas: + # TODO: should this default behavior be warned? + # logger.log_warning( + # f"Affordance key '{affordance_key}' not found in the affordance data, using identity pose.." + # ) + pass + + affordance_pose = torch.as_tensor( + self._env.affordance_datas.get( + affordance_key, torch.eye(4).repeat(self._env.num_envs, 1, 1) + ), + dtype=torch.float32, + ) + if affordance_pose.ndim < 3: + affordance_pose = affordance_pose.repeat(self._env.num_envs, 1, 1) + + ref_pose = None if is_arena_coord else obj.get_local_pose(to_matrix=True) + + if follow_eef is not None: + gripper_ratio = self._get_gripper_ratio(control_part=follow_eef) + # TODO: only support 1 env + y_interval = (y_interval * gripper_ratio)[0].item() + + return compute_exteroception.expand_pose( + affordance_pose, + x_interval, + y_interval, + kpnts_number, + ref_pose=ref_pose, + ) + + def _check_source_valid(self, source: str) -> bool: + if source not in self._valid_source: + logger.log_error( + f"Invalid exteroception source '{source}'. Supported sources are {self._valid_source}." + ) + return True + + def __call__( + self, + env: EmbodiedEnv, + obs: EnvObs, + descriptor: Dict[str, Dict[str, str]], + x_interval: float = 0.02, + y_interval: float = 0.02, + kpnts_number: int = 12, + groups: int = 6, + ) -> Dict[str, Dict[str, torch.Tensor]]: + """Compute the exteroception poses based on the asset type. + + Args: + descriptor: The observation dictionary. + + Returns: + A dictionary containing the exteroception poses with key 'exteroception'. + """ + + exteroception = {} + descriptor = resolve_dict(self._env, descriptor) + for sensor_uid, sources in descriptor.items(): + sensor: Union[Camera, StereoCamera] = self._env.sim.get_sensor(sensor_uid) + if sensor is None: + logger.log_error( + f"Sensor with UID '{sensor_uid}' not found in the simulation." + ) + + if not isinstance(sensor, (Camera, StereoCamera)): + logger.log_error( + f"Sensor with UID '{sensor_uid}' is not a Camera or StereoCamera." + ) + + height, width = sensor.cfg.height, sensor.cfg.width + + exteroception[sensor_uid] = {} + taget_pose_list = [] + for source in sources: + source_type = source["type"] + self._check_source_valid(source_type) + + if source_type == "robot": + target_pose = self._get_robot_exteroception( + control_part=source["control_part"], + x_interval=x_interval, + y_interval=y_interval, + kpnts_number=kpnts_number, + offset=source.get("offset", None), + follow_eef=source.get("follow_eef", False), + ) + elif source_type == "affordance": + target_pose = self._get_object_exteroception( + uid=source["obj_uid"], + affordance_key=source["key"], + x_interval=x_interval, + y_interval=y_interval, + kpnts_number=kpnts_number, + is_arena_coord=source["is_arena_coord"], + follow_eef=source.get("follow_eef", None), + ) + else: + logger.log_error( + f"Unsupported exteroception source '{source_type}'. Supported sources are 'robot' and 'affordance." + ) + taget_pose_list.append(target_pose) + + target_poses = torch.cat(taget_pose_list, dim=1) + if target_poses.shape[1] / (2 * kpnts_number + 1) != groups: + logger.log_error( + f"Exteroception groups number mismatch. Expected {groups}, but got {int(target_poses.shape[1] / (2 * kpnts_number + 1))}." + ) + + if isinstance(sensor, StereoCamera): + intrinsics, right_intrinsics = sensor.get_intrinsics() + left_arena_pose, right_arena_pose = sensor.get_left_right_arena_pose() + projected_kpnts = compute_exteroception._project_3d_to_2d( + left_arena_pose, + intrinsics, + height, + width, + target_poses, + ) + exteroception[sensor_uid]["l"] = projected_kpnts + + projected_kpnts = compute_exteroception._project_3d_to_2d( + right_arena_pose, + right_intrinsics, + height, + width, + target_poses, + ) + exteroception[sensor_uid]["r"] = projected_kpnts + else: + intrinsics = sensor.get_intrinsics() + projected_kpnts = compute_exteroception._project_3d_to_2d( + sensor.get_arena_pose(to_matrix=True), + intrinsics, + height, + width, + target_poses, + ) + exteroception[sensor_uid] = projected_kpnts + + return exteroception diff --git a/embodichain/lab/gym/envs/managers/randomization/__init__.py b/embodichain/lab/gym/envs/managers/randomization/__init__.py new file mode 100644 index 00000000..6483181f --- /dev/null +++ b/embodichain/lab/gym/envs/managers/randomization/__init__.py @@ -0,0 +1,22 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from .rendering import * +from .spatial import * + +""" +Randomization are all implemented as Event functors. +""" diff --git a/embodichain/lab/gym/envs/managers/randomization/rendering.py b/embodichain/lab/gym/envs/managers/randomization/rendering.py new file mode 100644 index 00000000..7c0aed82 --- /dev/null +++ b/embodichain/lab/gym/envs/managers/randomization/rendering.py @@ -0,0 +1,469 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import torch +import os +import random +from typing import TYPE_CHECKING, Literal, Union, Optional, Dict + +from embodichain.lab.sim.objects import Light, RigidObject, Articulation +from embodichain.lab.sim.sensors import Camera, StereoCamera +from embodichain.lab.gym.envs.managers.cfg import SceneEntityCfg +from embodichain.lab.gym.envs.managers import Functor, FunctorCfg +from embodichain.lab.sim import ( + VisualMaterial, + VisualMaterialInst, + VisualMaterialCfg, +) +from embodichain.utils.string import resolve_matching_names +from embodichain.utils.math import sample_uniform +from embodichain.utils import logger +from embodichain.data import get_data_path + +if TYPE_CHECKING: + from embodichain.lab.gym.envs import EmbodiedEnv + + +__all__ = [ + "randomize_light", + "randomize_camera_intrinsics", + "randomize_visual_material", +] + + +def randomize_light( + env: EmbodiedEnv, + env_ids: Union[torch.Tensor, None], + entity_cfg: SceneEntityCfg, + position_range: Optional[tuple[list[float], list[float]]] = None, + color_range: Optional[tuple[list[float], list[float]]] = None, + intensity_range: Optional[tuple[float, float]] = None, +) -> None: + """Randomize light properties by adding, scaling, or setting random values. + + This function allows randomizing light properties in the scene. The function samples random values from the + given distribution parameters and adds, scales, or sets the values into the physics simulation based on the + operation. + + The distribution parameters are lists of two elements each, representing the lower and upper bounds of the + distribution for the x, y, and z components of the light properties. The function samples random values for each + component independently. + + .. attention:: + This function applied the same light properties for all the environments. + + position_range is the x, y, z value added into light's cfg.init_pos. + color_range is the absolute r, g, b value set to the light object. + intensity_range is the value added into light's cfg.intensity. + + .. tip:: + This function uses CPU tensors to assign light properties. + + Args: + env (EmbodiedEnv): The environment instance. + env_ids (Union[torch.Tensor, None]): The environment IDs to apply the randomization. + entity_cfg (SceneEntityCfg): The configuration of the scene entity to randomize. + position_range (Optional[tuple[list[float], list[float]]]): The range for the position randomization. + color_range (Optional[tuple[list[float], list[float]]]): The range for the color randomization. + intensity_range (Optional[tuple[float, float]]): The range for the intensity randomization. + """ + + light: Light = env.sim.get_light(entity_cfg.uid) + num_instance = len(env_ids) + + if position_range: + init_pos = light.cfg.init_pos + new_pos = ( + torch.tensor(init_pos, dtype=torch.float32) + .unsqueeze_(0) + .repeat(num_instance, 1) + ) + random_value = sample_uniform( + lower=torch.tensor(position_range[0]), + upper=torch.tensor(position_range[1]), + size=new_pos.shape, + ) + new_pos += random_value + light.set_local_pose(new_pos, env_ids=env_ids) + + if color_range: + color = torch.zeros((num_instance, 3), dtype=torch.float32) + random_value = sample_uniform( + lower=torch.tensor(color_range[0]), + upper=torch.tensor(color_range[1]), + size=color.shape, + ) + color += random_value + light.set_color(color, env_ids=env_ids) + + if intensity_range: + init_intensity = light.cfg.intensity + new_intensity = ( + torch.tensor(init_intensity, dtype=torch.float32) + .unsqueeze_(0) + .repeat(num_instance, 1) + ) + random_value = sample_uniform( + lower=torch.tensor(intensity_range[0]), + upper=torch.tensor(intensity_range[1]), + size=new_intensity.shape, + ) + new_intensity += random_value + new_intensity.squeeze_(1) + light.set_intensity(new_intensity, env_ids=env_ids) + + +def randomize_camera_intrinsics( + env: EmbodiedEnv, + env_ids: Union[torch.Tensor, None], + entity_cfg: SceneEntityCfg, + focal_x_range: Optional[tuple[float, float]] = None, + focal_y_range: Optional[tuple[float, float]] = None, + cx_range: Optional[tuple[float, float]] = None, + cy_range: Optional[tuple[float, float]] = None, +) -> None: + """Randomize camera intrinsic properties by adding, scaling, or setting random values. + + This function allows randomizing camera intrinsic parameters in the scene. The function samples random values + from the given distribution parameters and adds, scales, or sets the values into the physics simulation based + on the operation. + + The distribution parameters are tuples of two elements each, representing the lower and upper bounds of the + distribution for the focal length (fx, fy) and principal point (cx, cy) components of the camera intrinsics. + The function samples random values for each component independently. + + .. attention:: + This function applies the same intrinsic properties for all the environments. + + focal_x_range and focal_y_range are values added to the camera's current fx and fy values. + focal_xy_range is a combined range for both fx and fy, where the range is specified as + [[fx_min, fy_min], [fx_max, fy_max]]. + cx_range and cy_range are values added to the camera's current cx and cy values. + + .. tip:: + This function uses CPU tensors to assign camera intrinsic properties. + + Args: + env (EmbodiedEnv): The environment instance. + env_ids (Union[torch.Tensor, None]): The environment IDs to apply the randomization. + entity_cfg (SceneEntityCfg): The configuration of the scene entity to randomize. + focal_x_range (Optional[tuple[float, float]]): The range for the focal length x randomization. + focal_y_range (Optional[tuple[float, float]]): The range for the focal length y randomization. + cx_range (Optional[tuple[float, float]]): The range for the principal point x randomization. + cy_range (Optional[tuple[float, float]]): The range for the principal point y randomization. + """ + + camera: Union[Camera, StereoCamera] = env.sim.get_sensor(entity_cfg.uid) + num_instance = len(env_ids) + + # Get current intrinsics as baseline + current_intrinsics = camera.cfg.intrinsics # (fx, fy, cx, cy) + + # Create new intrinsics tensor for all instances + new_intrinsics = ( + torch.tensor(current_intrinsics, dtype=torch.float32) + .unsqueeze(0) + .repeat(num_instance, 1) + ) + + # Randomize focal length x (fx) + if focal_x_range: + random_value = sample_uniform( + lower=torch.tensor(focal_x_range[0]), + upper=torch.tensor(focal_x_range[1]), + size=(num_instance,), + ) + new_intrinsics[:, 0] += random_value + + # Randomize focal length y (fy) + if focal_y_range: + random_value = sample_uniform( + lower=torch.tensor(focal_y_range[0]), + upper=torch.tensor(focal_y_range[1]), + size=(num_instance,), + ) + new_intrinsics[:, 1] += random_value + + # Randomize principal point x (cx) + if cx_range: + random_value = sample_uniform( + lower=torch.tensor(cx_range[0]), + upper=torch.tensor(cx_range[1]), + size=(num_instance,), + ) + new_intrinsics[:, 2] += random_value + + # Randomize principal point y (cy) + if cy_range: + random_value = sample_uniform( + lower=torch.tensor(cy_range[0]), + upper=torch.tensor(cy_range[1]), + size=(num_instance,), + ) + new_intrinsics[:, 3] += random_value + + camera.set_intrinsics(new_intrinsics, env_ids=env_ids) + + +class randomize_visual_material(Functor): + """Randomize the the visual material properties of a RigidObject or an Articulation. + + Note: + 1. Currently supported randomized properties include: + - base_color: RGB color of the material. Value should be in [0, 1], shape of (3,) + - base_color_texture: Texture image for the base color of the material. + The textures will be preloaded from the given texture_path during initialization. + - metallic: Metallic property of the material. Value should be in [0, 1]. + - roughness: Roughness property of the material. Value should be in [0, 1]. + - ior: Index of Refraction of the material (only supported in ray tracing mode). + 2. The default ground plane can also be randomized by setting entity_cfg.uid to "default_plane". + """ + + def __init__(self, cfg: FunctorCfg, env: EmbodiedEnv): + """Initialize the term. + + Args: + cfg: The configuration of the functor. + env: The environment instance. + + Raises: + ValueError: If the asset is not a RigidObject or an Articulation. + """ + super().__init__(cfg, env) + + self.entity_cfg: SceneEntityCfg = cfg.params["entity_cfg"] + + # special case: default ground plane. + if self.entity_cfg.uid == "default_plane": + pass + else: + self.entity: Union[RigidObject, Articulation] = env.sim.get_asset( + self.entity_cfg.uid + ) + + if not isinstance(self.entity, (RigidObject, Articulation)): + raise ValueError( + f"Randomization functor 'randomize_visual_material' not supported for asset: '{self.entity_cfg.uid}'" + f" with type: '{type(self.entity)}'." + ) + + # TODO: Maybe need to consider two cases: + # 1. the texture folder is very large, and we don't want to load all the textures into memory. + # 2. the texture is generated on the fly. + + # Preload textures (currently only base color textures are supported) + self.textures = [] + texture_path = get_data_path(cfg.params.get("texture_path", None)) + if texture_path is not None: + from embodichain.utils.utility import read_all_folder_images + + texture_key = os.path.basename(texture_path) + # check if the texture group is already loaded in the global texture cache + if texture_key in env.sim.get_texture_cache(): + logger.log_info( + f"Texture group '{texture_key}' is already loaded in the global texture cache." + ) + self.textures = env.sim.get_texture_cache(texture_key) + else: + self.textures = read_all_folder_images(texture_path) + + # padding the texture with alpha channel if not exist + for i in range(len(self.textures)): + if self.textures[i].shape[2] == 3: + data = torch.as_tensor(self.textures[i]) + alpha_channel = ( + torch.ones( + (data.shape[0], data.shape[1], 1), dtype=data.dtype + ) + * 255 + ) + data = torch.cat((data, alpha_channel), dim=2) + self.textures[i] = data + + env.sim.set_texture_cache(texture_key, self.textures) + + if self.entity_cfg.uid == "default_plane": + pass + + else: + # TODO: we may need to get the default material instance from the asset itself. + mat: VisualMaterial = env.sim.create_visual_material( + cfg=VisualMaterialCfg( + base_color=[1.0, 1.0, 1.0, 1.0], + uid=f"{self.entity_cfg.uid}_random_mat", + ) + ) + if isinstance(self.entity, RigidObject): + self.entity.set_visual_material(mat) + elif isinstance(self.entity, Articulation): + _, link_names = resolve_matching_names( + self.entity_cfg.link_names, self.entity.link_names + ) + self.entity_cfg.link_names = link_names + self.entity.set_visual_material(mat, link_names=link_names) + + @staticmethod + def gen_random_base_color_texture(width: int, height: int) -> torch.Tensor: + """Generate a random base color texture. + + Args: + width: The width of the texture. + height: The height of the texture. + + Returns: + A torch tensor representing the random base color texture with shape (height, width, 4). + """ + # Generate random RGB values + rgb = torch.ones((height, width, 3), dtype=torch.float32) + rgb *= torch.rand((1, 1, 3), dtype=torch.float32) + rgba = torch.cat((rgb, torch.ones((height, width, 1))), dim=2) + rgba = (rgba * 255).to(torch.uint8) + return rgba + + def _randomize_texture(self, mat_inst: VisualMaterialInst) -> None: + if len(self.textures) > 0: + # Randomly select a texture from the preloaded textures + texture_idx = torch.randint(0, len(self.textures), (1,)).item() + mat_inst.set_base_color_texture(texture_data=self.textures[texture_idx]) + + def _randomize_mat_inst( + self, + mat_inst: VisualMaterialInst, + plan: Dict[str, torch.Tensor], + random_texture_prob: float, + idx: int = 0, + ) -> None: + # randomize the material instance pbr properties based on the plan. + for key, value in plan.items(): + if key == "base_color": + mat_inst.set_base_color(value[idx].tolist()) + else: + getattr(mat_inst, f"set_{key}")(value[idx].item()) + + # randomize texture or base color based on the probability. + if random_texture_prob <= 0.0 or len(self.textures) == 0: + return + if random.random() < random_texture_prob: + self._randomize_texture(mat_inst) + else: + # set a random base color instead. + random_color = torch.rand(3).tolist() + random_color.append(1.0) # alpha + mat_inst.set_base_color(random_color) + + def __call__( + self, + env: EmbodiedEnv, + env_ids: Union[torch.Tensor, None], + entity_cfg: SceneEntityCfg, + random_texture_prob: float = 0.5, + texture_path: Optional[str] = None, + base_color_range: Optional[tuple[list[float], list[float]]] = None, + metallic_range: Optional[tuple[float, float]] = None, + roughness_range: Optional[tuple[float, float]] = None, + ior_range: Optional[tuple[float, float]] = None, + ): + from embodichain.lab.sim.utility import is_rt_enabled + + # resolve environment ids + if env_ids is None: + env_ids = torch.arange(env.num_envs, device="cpu") + else: + env_ids = env_ids.cpu() + + if self.entity_cfg.uid == "default_plane": + env_ids = [0] + + randomize_plan = {} + if base_color_range: + base_color = sample_uniform( + lower=torch.tensor(base_color_range[0], dtype=torch.float32), + upper=torch.tensor(base_color_range[1], dtype=torch.float32), + size=(len(env_ids), 3), # RGB + ) + # append alpha channel + alpha_channel = torch.ones((len(env_ids), 1), dtype=torch.float32) + base_color = torch.cat((base_color, alpha_channel), dim=1) + randomize_plan["base_color"] = base_color + + if metallic_range: + metallic = sample_uniform( + lower=torch.tensor(metallic_range[0], dtype=torch.float32), + upper=torch.tensor(metallic_range[1], dtype=torch.float32), + size=(len(env_ids), 1), + ) + randomize_plan["metallic"] = metallic + + if roughness_range: + roughness = sample_uniform( + lower=torch.tensor(roughness_range[0], dtype=torch.float32), + upper=torch.tensor(roughness_range[1], dtype=torch.float32), + size=(len(env_ids), 1), + ) + randomize_plan["roughness"] = roughness + + if ior_range and is_rt_enabled(): + ior = sample_uniform( + lower=torch.tensor(ior_range[0], dtype=torch.float32), + upper=torch.tensor(ior_range[1], dtype=torch.float32), + size=(len(env_ids), 1), + ) + randomize_plan["ior"] = ior + + # ground plane only has one instance. + mat_insts = None + if self.entity_cfg.uid == "default_plane": + mat_inst = env.sim.get_visual_material("plane_mat").get_default_instance() + self._randomize_mat_inst( + mat_inst=mat_inst, + plan=randomize_plan, + random_texture_prob=random_texture_prob, + idx=0, + ) + return + elif isinstance(self.entity, RigidObject): + mat_insts = self.entity.get_visual_material_inst(env_ids=env_ids) + elif isinstance(self.entity, Articulation): + mat_insts = self.entity.get_visual_material_inst( + env_ids=env_ids, + link_names=self.entity_cfg.link_names, + ) + + for i, data in enumerate(mat_insts): + if isinstance(self.entity, RigidObject): + # For RigidObject, data is the material instance directly + mat: VisualMaterialInst = data + elif isinstance(self.entity, Articulation): + # For Articulation, data is the key-value pair of link name and material instance + mat: Dict[str, VisualMaterialInst] = data + + if isinstance(self.entity, RigidObject): + self._randomize_mat_inst( + mat_inst=mat, + plan=randomize_plan, + random_texture_prob=random_texture_prob, + idx=i, + ) + else: + for name, mat_inst in mat.items(): + self._randomize_mat_inst( + mat_inst=mat_inst, + plan=randomize_plan, + random_texture_prob=random_texture_prob, + idx=i, + ) diff --git a/embodichain/lab/gym/envs/managers/randomization/spatial.py b/embodichain/lab/gym/envs/managers/randomization/spatial.py new file mode 100644 index 00000000..70bcabe8 --- /dev/null +++ b/embodichain/lab/gym/envs/managers/randomization/spatial.py @@ -0,0 +1,261 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import torch +from typing import TYPE_CHECKING, Literal, Union, Optional, List + +from embodichain.lab.sim.objects import RigidObject, Robot +from embodichain.lab.gym.envs.managers.cfg import SceneEntityCfg +from embodichain.utils.math import sample_uniform, matrix_from_euler +from embodichain.utils import logger + + +if TYPE_CHECKING: + from embodichain.lab.gym.envs import EmbodiedEnv + + +def get_random_pose( + init_pos: torch.Tensor, + init_rot: torch.Tensor, + position_range: Optional[tuple[list[float], list[float]]] = None, + rotation_range: Optional[tuple[list[float], list[float]]] = None, + relative_position: bool = True, + relative_rotation: bool = False, +) -> torch.Tensor: + """Generate a random pose based on the initial position and rotation. + + Args: + init_pos (torch.Tensor): The initial position tensor of shape (num_instance, 3). + init_rot (torch.Tensor): The initial rotation tensor of shape (num_instance, 3, 3). + position_range (Optional[tuple[list[float], list[float]]]): The range for the position randomization. + rotation_range (Optional[tuple[list[float], list[float]]]): The range for the rotation randomization. + The rotation is represented as Euler angles (roll, pitch, yaw) in degree. + relative_position (bool): Whether to randomize the position relative to the initial position. Default is True. + relative_rotation (bool): Whether to randomize the rotation relative to the initial rotation. Default is False. + + Returns: + torch.Tensor: The generated random pose tensor of shape (num_instance, 4, 4). + """ + + num_instance = init_pos.shape[0] + pose = ( + torch.eye(4, dtype=torch.float32, device=init_pos.device) + .unsqueeze_(0) + .repeat(num_instance, 1, 1) + ) + pose[:, :3, :3] = init_rot + pose[:, :3, 3] = init_pos + + if position_range: + + pos_low = torch.tensor(position_range[0], device=init_pos.device) + pos_high = torch.tensor(position_range[1], device=init_pos.device) + + random_value = sample_uniform( + lower=pos_low, + upper=pos_high, + size=(num_instance, 3), + ) + if relative_position: + random_value += init_pos + + pose[:, :3, 3] = random_value + + if rotation_range: + + rot_low = torch.tensor(rotation_range[0], device=init_pos.device) + rot_high = torch.tensor(rotation_range[1], device=init_pos.device) + + random_value = ( + sample_uniform( + lower=rot_low, + upper=rot_high, + size=(num_instance, 3), + ) + * torch.pi + / 180.0 + ) + rot = matrix_from_euler(random_value) + + if relative_rotation: + rot = torch.bmm(init_rot, rot) + pose[:, :3, :3] = rot + + return pose + + +def randomize_rigid_object_pose( + env: EmbodiedEnv, + env_ids: Union[torch.Tensor, None], + entity_cfg: SceneEntityCfg, + position_range: Optional[tuple[list[float], list[float]]] = None, + rotation_range: Optional[tuple[list[float], list[float]]] = None, + relative_position: bool = True, + relative_rotation: bool = False, +) -> None: + """Randomize the pose of a rigid object in the environment. + + Args: + env (EmbodiedEnv): The environment instance. + env_ids (Union[torch.Tensor, None]): The environment IDs to apply the randomization. + entity_cfg (SceneEntityCfg): The configuration of the scene entity to randomize. + position_range (Optional[tuple[list[float], list[float]]]): The range for the position randomization. + rotation_range (Optional[tuple[list[float], list[float]]]): The range for the rotation randomization. + The rotation is represented as Euler angles (roll, pitch, yaw) in degree. + relative_position (bool): Whether to randomize the position relative to the object's initial position. Default is True. + relative_rotation (bool): Whether to randomize the rotation relative to the object's initial rotation. Default is False. + """ + + rigid_object: RigidObject = env.sim.get_rigid_object(entity_cfg.uid) + num_instance = len(env_ids) + + init_pos = ( + torch.tensor(rigid_object.cfg.init_pos, dtype=torch.float32, device=env.device) + .unsqueeze_(0) + .repeat(num_instance, 1) + ) + init_rot = ( + torch.tensor(rigid_object.cfg.init_rot, dtype=torch.float32, device=env.device) + * torch.pi + / 180.0 + ) + init_rot = init_rot.unsqueeze_(0).repeat(num_instance, 1) + init_rot = matrix_from_euler(init_rot) + + pose = get_random_pose( + init_pos=init_pos, + init_rot=init_rot, + position_range=position_range, + rotation_range=rotation_range, + relative_position=relative_position, + relative_rotation=relative_rotation, + ) + + rigid_object.set_local_pose(pose, env_ids=env_ids) + rigid_object.clear_dynamics() + + +def randomize_robot_eef_pose( + env: EmbodiedEnv, + env_ids: Union[torch.Tensor, None], + entity_cfg: SceneEntityCfg, + position_range: Optional[tuple[list[float], list[float]]] = None, + rotation_range: Optional[tuple[list[float], list[float]]] = None, +) -> None: + """Randomize the initial end-effector pose of a robot in the environment. + + Note: + - The position and rotation are performed randomization in a relative manner. + - The current state of eef pose is computed based on the current joint positions of the robot. + + Args: + env (EmbodiedEnv): The environment instance. + env_ids (Union[torch.Tensor, None]): The environment IDs to apply the randomization. + robot_name (str): The name of the robot. + entity_cfg (SceneEntityCfg): The configuration of the scene entity to randomize. + position_range (Optional[tuple[list[float], list[float]]]): The range for the position randomization. + rotation_range (Optional[tuple[list[float], list[float]]]): The range for the rotation randomization. + The rotation is represented as Euler angles (roll, pitch, yaw) in degree. + """ + + def set_random_eef_pose(joint_ids: List[int], robot: Robot) -> None: + current_qpos = robot.get_qpos()[env_ids][:, joint_ids] + if current_qpos.dim() == 1: + current_qpos = current_qpos.unsqueeze_(0) + + current_eef_pose = robot.compute_fk( + name=part, qpos=current_qpos, to_matrix=True + ) + + new_eef_pose = get_random_pose( + init_pos=current_eef_pose[:, :3, 3], + init_rot=current_eef_pose[:, :3, :3], + position_range=position_range, + rotation_range=rotation_range, + relative_position=True, + relative_rotation=True, + ) + + ret, new_qpos = robot.compute_ik( + pose=new_eef_pose, name=part, joint_seed=current_qpos + ) + + new_qpos[ret == False] = current_qpos[ret == False] + robot.set_qpos(new_qpos, env_ids=env_ids, joint_ids=joint_ids) + + robot = env.sim.get_robot(entity_cfg.uid) + + control_parts = entity_cfg.control_parts + if control_parts is None: + joint_ids = robot.get_joint_ids() + set_random_eef_pose(joint_ids, robot) + else: + for part in control_parts: + joint_ids = robot.get_joint_ids(part) + set_random_eef_pose(joint_ids, robot) + + # simulate 10 steps to let the robot reach the target pose. + env.sim.update(step=10) + + +def randomize_robot_qpos( + env: EmbodiedEnv, + env_ids: Union[torch.Tensor, None], + entity_cfg: SceneEntityCfg, + qpos_range: Optional[tuple[list[float], list[float]]] = None, + relative_qpos: bool = True, + joint_ids: Optional[List[int]] = None, +) -> None: + """Randomize the initial joint positions of a robot in the environment. + + Args: + env (EmbodiedEnv): The environment instance. + env_ids (Union[torch.Tensor, None]): The environment IDs to apply the randomization. + entity_cfg (SceneEntityCfg): The configuration of the scene entity to randomize. + qpos_range (Optional[tuple[list[float], list[float]]]): The range for the joint position randomization. + relative_qpos (bool): Whether to randomize the joint positions relative to the current joint positions. Default is True. + joint_ids (Optional[List[int]]): The list of joint IDs to randomize. If None, all joints will be randomized. + """ + if qpos_range is None: + return + + num_instance = len(env_ids) + + robot = env.sim.get_robot(entity_cfg.uid) + + if joint_ids is None: + if len(qpos_range[0]) != robot.dof: + logger.log_error( + f"The length of qpos_range {len(qpos_range[0])} does not match the robot dof {robot.dof}." + ) + joint_ids = robot.get_joint_ids() + + qpos = sample_uniform( + lower=torch.tensor(qpos_range[0], device=env.device), + upper=torch.tensor(qpos_range[1], device=env.device), + size=(num_instance, len(joint_ids)), + ) + + if relative_qpos: + current_qpos = robot.get_qpos()[env_ids][:, joint_ids] + current_qpos += qpos + else: + current_qpos = qpos + + robot.set_qpos(qpos=current_qpos, env_ids=env_ids, joint_ids=joint_ids) + env.sim.update(step=100) diff --git a/embodichain/lab/gym/envs/managers/record.py b/embodichain/lab/gym/envs/managers/record.py new file mode 100644 index 00000000..157064a7 --- /dev/null +++ b/embodichain/lab/gym/envs/managers/record.py @@ -0,0 +1,237 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import torch +import os +import random +import numpy as np +from typing import TYPE_CHECKING, Literal, Union, List + +from dexsim.utility import images_to_video +from embodichain.lab.gym.envs.managers import Functor, FunctorCfg +from embodichain.lab.sim.sensors.camera import CameraCfg + +if TYPE_CHECKING: + from embodichain.lab.gym.envs import EmbodiedEnv + + +class record_camera_data(Functor): + """Record camera data in the environment. The camera is usually setup with third-person view, and + is used to record the scene during the episode. It is helpful for debugging and visualization. + + Note: + Currently, the functor is implemented in `interval' mode such that, it can only save the + recorded frames when in :meth:`env.step()` function call. For example: + ```python + env.step() + # perform multiple steps in the same episode + env.reset() + env.step() # the video of the first episode will be saved here. + ``` + The final episode frames will not be saved in the current implementation. + We may improve it in the future. + """ + + def __init__(self, cfg: FunctorCfg, env: EmbodiedEnv): + """Initialize the functor. + + Args: + cfg: The configuration of the functor. + env: The environment instance. + + Raises: + ValueError: If the asset is not a RigidObject or an Articulation. + """ + super().__init__(cfg, env) + + # extract the used quantities (to enable type-hinting) + self._name = cfg.params.get("name", "default") + resolution = cfg.params.get("resolution", (640, 480)) + eye = cfg.params.get("eye", (0, 0, 2)) + target = cfg.params.get("target", (0, 0, 0)) + up = cfg.params.get("up", (0, 0, 1)) + intrinsics = cfg.params.get( + "intrinsics", (600, 600, int(resolution[0] / 2), int(resolution[1] / 2)) + ) + + self.camera = env.sim.add_sensor( + sensor_cfg=CameraCfg( + uid=self._name, + width=resolution[0], + height=resolution[1], + extrinsics=CameraCfg.ExtrinsicsCfg(eye=eye, target=target, up=up), + intrinsics=intrinsics, + ) + ) + + self._current_episode = 0 + self._frames: List[np.ndarray] = [] + + def _draw_frames_into_one_image(self, frames: torch.Tensor) -> torch.Tensor: + """ + Concatenate multiple frames into a single image with nearly square arrangement. + + Args: + frames: Tensor with shape (B, H, W, 4) where B is batch size + + Returns: + Single concatenated image tensor with shape (grid_h * H, grid_w * W, 4) + """ + if frames.numel() == 0: + return frames + + B, H, W, C = frames.shape + + # Calculate grid dimensions for nearly square arrangement + grid_w = int(torch.ceil(torch.sqrt(torch.tensor(B, dtype=torch.float32)))) + grid_h = int(torch.ceil(torch.tensor(B, dtype=torch.float32) / grid_w)) + + # Create empty grid to hold all frames + result = torch.zeros( + (grid_h * H, grid_w * W, C), dtype=frames.dtype, device=frames.device + ) + + # Fill the grid with frames + for i in range(B): + row = i // grid_w + col = i % grid_w + + start_h = row * H + end_h = start_h + H + start_w = col * W + end_w = start_w + W + + result[start_h:end_h, start_w:end_w] = frames[i] + + return result + + def __call__( + self, + env: EmbodiedEnv, + env_ids: Union[torch.Tensor, None], + name: str, + resolution: tuple[int, int] = (640, 480), + eye: tuple[float, float, float] = (0, 0, 2), + target: tuple[float, float, float] = (0, 0, 0), + up: tuple[float, float, float] = (0, 0, 1), + intrinsics: tuple[float, float, float, float] = ( + 600, + 600, + 320, + 240, + ), + max_env_num: int = 16, + save_path: str = "./outputs/videos", + ): + # TODO: the current implementation will lost the final episode frames recording. + # Check if the frames should be saved for the current episode + if env.elapsed_steps.sum().item() == len(env_ids) and len(self._frames) > 0: + video_name = f"episode_{self._current_episode}_{self._name}" + images_to_video(self._frames, save_path, video_name, fps=20) + + self._current_episode += 1 + self._frames = [] + + self.camera.update(fetch_only=self.camera.is_rt_enabled) + data = self.camera.get_data() + rgb = data["color"] + + num_frames = max(rgb.shape[0], max_env_num) + rgb = rgb[:num_frames] + rgb = self._draw_frames_into_one_image(rgb)[..., :3].cpu().numpy() + self._frames.append(rgb) + + +class record_camera_data_async(record_camera_data): + """Record camera data for multiple environments, merge and save as a single video at episode end.""" + + def __init__(self, cfg: FunctorCfg, env: EmbodiedEnv): + super().__init__(cfg, env) + self._num_envs = min(4, getattr(env, "num_envs", 1)) + self._frames_list = [[] for _ in range(self._num_envs)] + self._ep_idx = [0 for _ in range(self._num_envs)] + + def __call__( + self, + env: EmbodiedEnv, + env_ids: Union[torch.Tensor, None], + name: str, + resolution: tuple[int, int] = (640, 480), + eye: tuple[float, float, float] = (0, 0, 2), + target: tuple[float, float, float] = (0, 0, 0), + up: tuple[float, float, float] = (0, 0, 1), + intrinsics: tuple[float, float, float, float] = ( + 600, + 600, + 320, + 240, + ), + max_env_num: int = 16, + save_path: str = "./outputs/videos", + ): + self.camera.update(fetch_only=self.camera.is_rt_enabled) + data = self.camera.get_data() + rgb = data["color"] # shape: (num_envs, H, W, 4) + if isinstance(rgb, torch.Tensor): + rgb_np = rgb.cpu().numpy() + else: + rgb_np = rgb + # Only collect frames for the first 4 environments + for i in range(self._num_envs): + self._frames_list[i].append(rgb_np[i][..., :]) + + # Check if elapsed_steps==1 (just reset) + elapsed = env.elapsed_steps + if isinstance(elapsed, torch.Tensor): + elapsed_np = elapsed.cpu().numpy() + else: + elapsed_np = elapsed + # Only check reset for the first 4 environments + ready_envs = [ + i + for i in range(self._num_envs) + if elapsed_np[i] == 1 and len(self._frames_list[i]) > 1 + ] + # Used to temporarily store episode frames for each env + if not hasattr(self, "_pending_env_episodes"): + self._pending_env_episodes = {} + for i in ready_envs: + if i not in self._pending_env_episodes: + self._pending_env_episodes[i] = self._frames_list[i][:-1] + self._frames_list[i] = [ + self._frames_list[i][-1] + ] # Only keep the first frame after reset + self._ep_idx[i] += 1 + # If all specified envs have collected frames, concatenate and save + if len(self._pending_env_episodes) == self._num_envs: + min_len = min(len(frames) for frames in self._pending_env_episodes.values()) + big_frames = [] + for j in range(min_len): + frames = [ + self._pending_env_episodes[i][j] for i in range(self._num_envs) + ] + frames_tensor = torch.from_numpy(np.stack(frames)).to(torch.uint8) + big_frame = ( + self._draw_frames_into_one_image(frames_tensor)[..., :3] + .cpu() + .numpy() + ) + big_frames.append(big_frame) + video_name = f"ep{self._ep_idx[0]-1}_{self._name}_allenvs" + images_to_video(big_frames, save_path, video_name, fps=20) + self._pending_env_episodes.clear() diff --git a/embodichain/lab/gym/envs/object/__init__.py b/embodichain/lab/gym/envs/object/__init__.py new file mode 100644 index 00000000..851d86e2 --- /dev/null +++ b/embodichain/lab/gym/envs/object/__init__.py @@ -0,0 +1,17 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from .geometry import * diff --git a/embodichain/lab/gym/envs/object/geometry.py b/embodichain/lab/gym/envs/object/geometry.py new file mode 100644 index 00000000..b95faf3c --- /dev/null +++ b/embodichain/lab/gym/envs/object/geometry.py @@ -0,0 +1,138 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch +import numpy as np +import open3d as o3d + +from typing import Union + +from dexsim.models import MeshObject +from embodichain.utils import logger +from embodichain.lab.sim.objects import RigidObject +from embodichain.lab.gym.envs.managers.cfg import SceneEntityCfg +from embodichain.utils.utility import inv_transform + + +def get_pc_svd_frame(pc: Union[np.ndarray, torch.Tensor]) -> np.ndarray: + """ + Computes the pose of a point cloud using Singular Value Decomposition (SVD). + + This function centers the point cloud, performs SVD to obtain the rotation, + and constructs a 4x4 transformation matrix representing the pose of the point cloud. + + Args: + pc (np.ndarray): A 2D numpy array of shape (N, 3) representing the point cloud, + where N is the number of points. + + Returns: + np.ndarray: A 4x4 transformation matrix that includes the rotation and translation + of the point cloud. + """ + if pc.ndim != 2: + logger.log_error( + f"get_pc_svd_frame only support the pc of 1 object, which means that pc.ndim==2, but got {pc.ndim}" + ) + pc_center = pc.mean(axis=0) + pc_centered = pc - pc_center + u, s, vt = torch.linalg.svd(pc_centered) + rotation = vt.T + pc_pose = torch.eye(4, dtype=torch.float32) + pc_pose[:3, :3] = rotation + pc_pose[:3, 3] = pc_center + return pc_pose + + +def apply_svd_transfer_pc( + geometry: Union[ + np.ndarray, + torch.Tensor, + o3d.cuda.pybind.geometry.TriangleMesh, + MeshObject, + RigidObject, + ], + sample_points: int = 1000, +) -> np.ndarray: + """ + Applies Singular Value Decomposition (SVD) transfer to a point cloud represented by geometry. + + Parameters: + geometry (Union[np.ndarray, MeshObject]): The input geometry, which can be a NumPy array of vertices + or a MeshObject containing vertex data. + sample_points (int): The number of sample points to consider (default is 1000). + + Returns: + np.ndarray: The transformed vertices in standard position after applying SVD. + """ + if isinstance(geometry, (RigidObject, MeshObject)): + verts = torch.as_tensor(geometry.get_vertices()) + elif isinstance(geometry, (np.ndarray, torch.Tensor)): + verts = torch.as_tensor(geometry) + elif isinstance(geometry, o3d.cuda.pybind.geometry.TriangleMesh): + verts = torch.as_tensor(geometry.vertices) + else: + logger.log_error( + f"Unsupported geometry type: {type(geometry)}. Expected np.ndarray, torch.Tensor, MeshObject, or RigidObject." + ) + + if verts.ndim < 3: + verts = verts[None] + + sample_ids = ( + np.random.choice(verts.shape[1], sample_points) + if isinstance(verts, np.ndarray) + else torch.randint(0, verts.shape[1], (sample_points,)) + ) + verts = verts[:, sample_ids, :] + + standard_verts = [] + for object_verts in verts: + pc_svd_frame = get_pc_svd_frame(object_verts) + inv_svd_frame = inv_transform(pc_svd_frame) + standard_object_verts = ( + object_verts @ inv_svd_frame[:3, :3].T + inv_svd_frame[:3, 3] + ) + standard_verts.append(standard_object_verts) + + return torch.stack(standard_verts) + + +def compute_object_length( + env, + env_ids: Union[torch.Tensor, None], + entity_cfg: SceneEntityCfg, + sample_points: int, + is_svd_frame: bool = True, +): + rigid_object: RigidObject = env.sim.get_rigid_object(entity_cfg.uid) + object_lengths = {} + for axis in ["x", "y", "z"]: + object_lengths.update( + {axis: torch.zeros((env.num_envs,), dtype=torch.float32, device=env.device)} + ) + pcs = rigid_object.get_vertices(env_ids) + body_scale = rigid_object.get_body_scale(env_ids) + scaled_pcs = pcs * body_scale + + if is_svd_frame: + scaled_pcs = apply_svd_transfer_pc(scaled_pcs, sample_points) + + for axis, idx in zip(["x", "y", "z"], [0, 1, 2]): + scaled_pos = scaled_pcs[..., idx] # (num_envs, sample_points) + length = scaled_pos.max(dim=1)[0] - scaled_pos.min(dim=1)[0] + object_lengths.update({axis: length}) + + return object_lengths diff --git a/embodichain/lab/gym/envs/rl_env_cfg.py b/embodichain/lab/gym/envs/rl_env_cfg.py new file mode 100644 index 00000000..34448747 --- /dev/null +++ b/embodichain/lab/gym/envs/rl_env_cfg.py @@ -0,0 +1,33 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from typing import Any, Dict + +from embodichain.lab.gym.envs.embodied_env import EmbodiedEnvCfg +from embodichain.utils import configclass + + +@configclass +class RLEnvCfg(EmbodiedEnvCfg): + """Extended configuration for RL environments built from gym-style specs.""" + + env_id: str = "" + extensions: Dict[str, Any] = {} + + @classmethod + def from_dict(cls, d): + """Create an instance from a dictionary.""" + return cls(**d) diff --git a/embodichain/lab/gym/envs/tasks/rl/__init__.py b/embodichain/lab/gym/envs/tasks/rl/__init__.py new file mode 100644 index 00000000..f8cf3031 --- /dev/null +++ b/embodichain/lab/gym/envs/tasks/rl/__init__.py @@ -0,0 +1,32 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from copy import deepcopy +from embodichain.lab.gym.utils import registration as env_registry +from embodichain.lab.gym.envs.rl_env_cfg import RLEnvCfg + + +def build_env(env_id: str, base_env_cfg: RLEnvCfg): + """Create env from registry id, auto-inferring cfg class (EnvName -> EnvNameCfg).""" + env = env_registry.make(env_id, cfg=deepcopy(base_env_cfg)) + return env + + +__all__ = [ + "build_env", +] diff --git a/embodichain/lab/gym/envs/tasks/rl/push_cube.py b/embodichain/lab/gym/envs/tasks/rl/push_cube.py new file mode 100644 index 00000000..412877e7 --- /dev/null +++ b/embodichain/lab/gym/envs/tasks/rl/push_cube.py @@ -0,0 +1,227 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch +import numpy as np +from typing import Dict, Any, Optional, Sequence +from gymnasium import spaces + +from embodichain.lab.gym.utils.registration import register_env +from embodichain.lab.gym.envs import EmbodiedEnv, EmbodiedEnvCfg +from embodichain.lab.sim.cfg import MarkerCfg +from embodichain.lab.sim.types import EnvObs, EnvAction +from embodichain.utils import logger + + +@register_env("PushCubeRL", max_episode_steps=50, override=True) +class PushCubeEnv(EmbodiedEnv): + """Push cube task for reinforcement learning. + + The task involves pushing a cube to a target goal position using a robotic arm. + The reward consists of reaching reward, placing reward, action penalty, and success bonus. + """ + + def __init__(self, cfg=None, **kwargs): + if cfg is None: + cfg = EmbodiedEnvCfg() + + extensions = getattr(cfg, "extensions", {}) or {} + + # cfg.sim_cfg.enable_rt = True + + defaults = { + "success_threshold": 0.1, + "reaching_reward_weight": 0.1, + "place_reward_weight": 2.0, + "place_penalty_weight": 0.5, + "action_penalty_weight": 0.01, + "success_bonus_weight": 10.0, + } + for name, default in defaults.items(): + value = extensions.get(name, getattr(cfg, name, default)) + setattr(cfg, name, value) + setattr(self, name, getattr(cfg, name)) + + self.last_cube_goal_dist = None + + super().__init__(cfg, **kwargs) + + def _draw_goal_marker(self): + """Draw axis marker at goal position for visualization.""" + goal_sphere = self.sim.get_rigid_object("goal_sphere") + if goal_sphere is None: + return + + num_envs = self.cfg.num_envs + + # Get actual goal positions from each arena + goal_poses = goal_sphere.get_local_pose(to_matrix=True) # (num_envs, 4, 4) + + # Draw marker for each arena separately + for arena_idx in range(num_envs): + marker_name = f"goal_marker_{arena_idx}" + + self.sim.remove_marker(marker_name) + + goal_pose = goal_poses[arena_idx].detach().cpu().numpy() + marker_cfg = MarkerCfg( + name=marker_name, + marker_type="axis", + axis_xpos=[goal_pose], + axis_size=0.003, + axis_len=0.02, + arena_index=arena_idx, + ) + self.sim.draw_marker(cfg=marker_cfg) + + def _init_sim_state(self, **kwargs): + super()._init_sim_state(**kwargs) + self.single_action_space = spaces.Box( + low=-self.joint_limits, + high=self.joint_limits, + shape=(6,), + dtype=np.float32, + ) + if self.obs_mode == "state": + self.single_observation_space = spaces.Box( + low=-np.inf, high=np.inf, shape=(15,), dtype=np.float32 + ) + + def _initialize_episode( + self, env_ids: Optional[Sequence[int]] = None, **kwargs + ) -> None: + super()._initialize_episode(env_ids=env_ids, **kwargs) + cube = self.sim.get_rigid_object("cube") + + # Calculate previous distance (for incremental reward) based on current (possibly randomized) pose + cube_pos = cube.body_data.pose[:, :3] + goal_sphere = self.sim.get_rigid_object("goal_sphere") + goal_pos = goal_sphere.body_data.pose[ + :, :3 + ] # Get actual goal positions for each environment + self.last_cube_goal_dist = torch.norm(cube_pos[:, :2] - goal_pos[:, :2], dim=1) + + # Draw marker at goal position + # self._draw_goal_marker() + + def _step_action(self, action: EnvAction) -> EnvAction: + scaled_action = action * self.action_scale + scaled_action = torch.clamp( + scaled_action, -self.joint_limits, self.joint_limits + ) + current_qpos = self.robot.body_data.qpos + target_qpos = current_qpos.clone() + target_qpos[:, :6] += scaled_action[:, :6] + self.robot.set_qpos(qpos=target_qpos) + return scaled_action + + def get_obs(self, **kwargs) -> EnvObs: + qpos_all = self.robot.body_data.qpos[:, :6] + ee_pose_matrix = self.robot.compute_fk( + name="arm", qpos=qpos_all, to_matrix=True + ) + ee_pos_all = ee_pose_matrix[:, :3, 3] + cube = self.sim.get_rigid_object("cube") + cube_pos_all = cube.body_data.pose[:, :3] + # Get actual goal positions for each environment + goal_sphere = self.sim.get_rigid_object("goal_sphere") + goal_pos_all = goal_sphere.body_data.pose[:, :3] + if self.obs_mode == "state": + return torch.cat([qpos_all, ee_pos_all, cube_pos_all, goal_pos_all], dim=1) + return { + "robot": {"qpos": qpos_all, "ee_pos": ee_pos_all}, + "object": {"cube_pos": cube_pos_all, "goal_pos": goal_pos_all}, + } + + def get_reward( + self, obs: EnvObs, action: EnvAction, info: Dict[str, Any] + ) -> torch.Tensor: + if self.obs_mode == "state": + ee_pos = obs[:, 6:9] + cube_pos = obs[:, 9:12] + goal_pos = obs[:, 12:15] + else: + ee_pos = obs["robot"]["ee_pos"] + cube_pos = obs["object"]["cube_pos"] + goal_pos = obs["object"]["goal_pos"] + push_direction = goal_pos - cube_pos + push_dir_norm = torch.norm(push_direction, dim=1, keepdim=True) + 1e-6 + push_dir_normalized = push_direction / push_dir_norm + push_pose = ( + cube_pos + - 0.015 * push_dir_normalized + + torch.tensor([0, 0, 0.015], device=self.device, dtype=torch.float32) + ) + ee_to_push_dist = torch.norm(ee_pos - push_pose, dim=1) + reaching_reward_raw = 1.0 - torch.tanh(5.0 * ee_to_push_dist) + reaching_reward = self.reaching_reward_weight * reaching_reward_raw + cube_to_goal_dist = torch.norm(cube_pos[:, :2] - goal_pos[:, :2], dim=1) + distance_delta = 10.0 * (self.last_cube_goal_dist - cube_to_goal_dist) + distance_delta_normalized = torch.tanh(distance_delta) + place_reward = torch.where( + distance_delta_normalized >= 0, + self.place_reward_weight * distance_delta_normalized, + self.place_penalty_weight * distance_delta_normalized, + ) + self.last_cube_goal_dist = cube_to_goal_dist + action_magnitude = torch.norm(action, dim=1) + action_penalty = -self.action_penalty_weight * action_magnitude + success_bonus_raw = info["success"].float() + success_bonus = self.success_bonus_weight * success_bonus_raw + reward = reaching_reward + place_reward + action_penalty + success_bonus + # Organize reward components in a dedicated "rewards" dict + # This allows trainer to easily identify and log reward components + if "rewards" not in info: + info["rewards"] = {} + info["rewards"]["reaching_reward"] = reaching_reward + info["rewards"]["place_reward"] = place_reward + info["rewards"]["action_penalty"] = action_penalty + info["rewards"]["success_bonus"] = success_bonus + return reward + + def get_info(self, **kwargs) -> Dict[str, Any]: + cube = self.sim.get_rigid_object("cube") + cube_pos = cube.body_data.pose[:, :3] + # Get actual goal positions for each environment + goal_sphere = self.sim.get_rigid_object("goal_sphere") + goal_pos = goal_sphere.body_data.pose[:, :3] + xy_distance = torch.norm(cube_pos[:, :2] - goal_pos[:, :2], dim=1) + is_success = xy_distance < self.success_threshold + info = { + "success": is_success, + "fail": torch.zeros( + self.cfg.num_envs, device=self.device, dtype=torch.bool + ), + "elapsed_steps": self._elapsed_steps, + } + info["metrics"] = { + "distance_to_goal": xy_distance, + } + return info + + def check_truncated(self, obs: EnvObs, info: Dict[str, Any]) -> torch.Tensor: + is_timeout = self._elapsed_steps >= self.episode_length + cube = self.sim.get_rigid_object("cube") + cube_pos = cube.body_data.pose[:, :3] + is_fallen = cube_pos[:, 2] < -0.1 + return is_timeout | is_fallen + + def evaluate(self, **kwargs) -> Dict[str, Any]: + info = self.get_info(**kwargs) + return { + "success": info["success"][0].item(), + "distance_to_goal": info["distance_to_goal"], + } diff --git a/embodichain/lab/gym/envs/tasks/tableware/__init__.py b/embodichain/lab/gym/envs/tasks/tableware/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/embodichain/lab/gym/envs/tasks/tableware/pour_water/__init__.py b/embodichain/lab/gym/envs/tasks/tableware/pour_water/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/embodichain/lab/gym/envs/tasks/tableware/pour_water/action_bank.py b/embodichain/lab/gym/envs/tasks/tableware/pour_water/action_bank.py new file mode 100644 index 00000000..8938aeb6 --- /dev/null +++ b/embodichain/lab/gym/envs/tasks/tableware/pour_water/action_bank.py @@ -0,0 +1,233 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch +import numpy as np +from copy import deepcopy +from typing import Dict, Tuple, Union, List, Any, Optional, Callable +from embodichain.lab.gym.envs.action_bank.configurable_action import ( + ActionBank, + tag_node, + tag_edge, +) + +from embodichain.lab.gym.utils.misc import ( + resolve_env_params, + mul_linear_expand, + get_offset_pose_list, + get_changed_pose, +) + +from embodichain.lab.sim.planners.motion_generator import MotionGenerator +from embodichain.utils import logger + + +__all__ = ["PourWaterActionBank"] + + +class PourWaterActionBank(ActionBank): + @staticmethod + @tag_node + @resolve_env_params + def generate_left_arm_aim_qpos( + env, + valid_funcs_name_kwargs_proc: Optional[List] = None, + ): + # FIXME FIXME FIXME FIXME + logger.log_warning( + f"CAUTION=============================THIS FUNC generate_left_arm_aim_qpos IS WRONG!!!! PLEASE FIX IT!!!!" + ) + left_aim_horizontal_angle = np.arctan2( + *( + ( + env.affordance_datas["cup_pose"][:2, 3] + - env.affordance_datas["left_arm_base_pose"][:2, 3] + )[1::-1] + ) + ) + left_arm_aim_qpos = deepcopy(env.affordance_datas["left_arm_init_qpos"]) + left_arm_aim_qpos[0] = left_aim_horizontal_angle + env.affordance_datas["left_arm_aim_qpos"] = left_arm_aim_qpos + return True + + @staticmethod + @tag_node + @resolve_env_params + # DONE: valid & process qpos & fk + def generate_right_arm_aim_qpos( + env, + valid_funcs_name_kwargs_proc: Optional[List] = None, + ): + # FIXME FIXME FIXME FIXME + logger.log_warning( + f"CAUTION=============================THIS FUNC generate_right_arm_aim_qpos IS WRONG!!!! PLEASE FIX IT!!!!" + ) + right_aim_horizontal_angle = np.arctan2( + *( + ( + env.affordance_datas["bottle_pose"][:2, 3] + - env.affordance_datas["right_arm_base_pose"][:2, 3] + )[1::-1] + ) + ) + right_arm_aim_qpos = deepcopy(env.affordance_datas["right_arm_init_qpos"]) + right_arm_aim_qpos[0] = right_aim_horizontal_angle + env.affordance_datas["right_arm_aim_qpos"] = right_arm_aim_qpos + return True + + @staticmethod + @tag_node + @resolve_env_params + def compute_unoffset_for_exp(env, pose_input_output_names_changes: Dict = {}): + env.affordance_datas["bottle_grasp_unoffset_matrix_object"] = np.eye( + 4 + ) # For the overall transform matrix calculation + for input_pose_name, change_params in pose_input_output_names_changes.items(): + output_pose_name = change_params["output_pose_name"] + pose_changes = change_params["pose_changes"] + env.affordance_datas[output_pose_name] = get_changed_pose( + env.affordance_datas[input_pose_name], pose_changes + ) + + return True + + @staticmethod + @tag_edge + @tag_node + # TODO: Got the dimension from the scope + def execute_open(env, return_action: bool = False, **kwargs): + if return_action: + duration = kwargs.get("duration", 1) + expand = kwargs.get("expand", False) + if expand: + action = mul_linear_expand(np.array([[0.0], [1.0]]), [duration - 1]) + action = np.concatenate([action, np.array([[1.0]])]).transpose() + else: + action = np.ones((1, duration)) + return action + else: + return True + + @staticmethod + @tag_edge + @tag_node + def execute_close(env, return_action: bool = False, **kwargs): + + if return_action: + duration = kwargs.get("duration", 1) + expand = kwargs.get("expand", False) + if expand: + action = mul_linear_expand(np.array([[1.0], [0.0]]), [duration - 1]) + action = np.concatenate([action, np.array([[0.0]])]).transpose() + else: + action = np.zeros((1, duration)) + return action + else: + return True + + @staticmethod + @tag_edge + def plan_trajectory( + env, + agent_uid: str, + keypose_names: List[str], + duration: int, + edge_name: str = "", + ): + keyposes = [ + env.affordance_datas[keypose_name] for keypose_name in keypose_names + ] + + keyposes = [ + kp.cpu().numpy() if hasattr(kp, "cpu") and hasattr(kp, "numpy") else kp + for kp in keyposes + ] + + if all( + np.linalg.norm(former - latter).sum() <= 1e-3 + for former, latter in zip(keyposes, keyposes[1:]) + ): + logger.log_warning( + f"Applying plan_trajectory to two very close qpos! Using stand_still." + ) + keyposes = [keyposes[0]] * 2 + ret_transposed = PourWaterActionBank.stand_still( + env, + agent_uid, + keypose_names, + duration, + ) + + return ret_transposed + + else: + mo_gen = MotionGenerator(robot=env.robot, uid=agent_uid) + ret, _ = mo_gen.create_discrete_trajectory( + qpos_list=keyposes, + sample_num=duration, + qpos_seed=keyposes[0], + is_use_current_qpos=False, + ) + + return ret.T + + @staticmethod + @tag_edge + def stand_still( + env, + agent_uid: str, + keypose_names: List[str], + duration: int, + ): + keyposes = [ + env.affordance_datas[keypose_name] for keypose_name in keypose_names + ] + + stand_still_qpos = keyposes[0] + + if ( + stand_still_qpos.shape + != np.asarray(env.robot.get_joint_ids("left_arm")).shape + ): + logger.log_error( + f"The shape of stand_still qpos is different from {agent_uid}'s setting." + ) + + if any( + np.linalg.norm(former - latter).sum() > 1e-6 + for former, latter in zip(keyposes, keyposes[1:]) + ): + logger.log_warning( + f"Applying stand still to two different qpos! Using the first qpos {stand_still_qpos}" + ) + keyposes = [stand_still_qpos] * 2 + + ret = np.asarray([stand_still_qpos] * duration) + + return ret.T + + @staticmethod + @tag_edge + def left_arm_go_back(env, duration: int): + left_arm_monitor_qpos, left_arm_init_qpos = ( + env.affordance_datas["left_arm_monitor_qpos"], + env.affordance_datas["left_arm_init_qpos"], + ) + left_home_sample_num = duration + qpos_expand_left = np.array([left_arm_monitor_qpos, left_arm_init_qpos]) + qpos_expand_left = mul_linear_expand(qpos_expand_left, [left_home_sample_num]) + ret = np.array(qpos_expand_left).T + return ret diff --git a/embodichain/lab/gym/envs/tasks/tableware/pour_water/pour_water.py b/embodichain/lab/gym/envs/tasks/tableware/pour_water/pour_water.py new file mode 100644 index 00000000..433b76de --- /dev/null +++ b/embodichain/lab/gym/envs/tasks/tableware/pour_water/pour_water.py @@ -0,0 +1,147 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch + +from embodichain.lab.gym.envs import EmbodiedEnv, EmbodiedEnvCfg +from embodichain.lab.gym.utils.registration import register_env +from embodichain.utils import logger + +from embodichain.lab.gym.envs.tasks.tableware.pour_water.action_bank import ( + PourWaterActionBank, +) + +__all__ = ["PourWaterEnv"] + + +@register_env("PourWater-v3", max_episode_steps=600) +class PourWaterEnv(EmbodiedEnv): + def __init__(self, cfg: EmbodiedEnvCfg = None, **kwargs): + super().__init__(cfg, **kwargs) + + action_config = kwargs.get("action_config", None) + if action_config is not None: + self.action_config = action_config + + def create_demo_action_list(self, *args, **kwargs): + """ + Create a demonstration action list for the current task. + + Returns: + list: A list of demo actions generated by the task. + """ + logger.log_info("Create demo action list for PourWaterTask.") + + if getattr(self, "action_config") is not None: + self._init_action_bank(PourWaterActionBank, self.action_config) + action_list = self.create_expert_demo_action_list(*args, **kwargs) + else: + logger.log_error("No action_config found in env, please check again.") + + if action_list is None: + return action_list + + logger.log_info( + f"Demo action list created with {len(action_list)} steps.", color="green" + ) + return action_list + + def create_expert_demo_action_list(self, **kwargs): + """ + Create an expert demonstration action list using the action bank. + + This function generates a trajectory based on expert knowledge, mapping joint and end-effector + states to the required action format for the environment and robot type. + + Args: + **kwargs: Additional keyword arguments. + + Returns: + list: A list of actions, each containing joint positions ("qpos"). + """ + + if hasattr(self, "action_bank") is False or self.action_bank is None: + logger.log_error( + "Action bank is not initialized. Cannot create expert demo action list." + ) + + ret = self.action_bank.create_action_list( + self, self.graph_compose, self.packages + ) + + if ret is None: + logger.log_warning("Failed to generate expert demo action list.") + return None + + # TODO: to be removed, need a unified interface in robot class + left_arm_joints = self.robot.get_joint_ids(name="left_arm") + right_arm_joints = self.robot.get_joint_ids(name="right_arm") + left_eef_joints = self.robot.get_joint_ids(name="left_eef") + right_eef_joints = self.robot.get_joint_ids(name="right_eef") + + total_traj_num = ret[list(ret.keys())[0]].shape[-1] + actions = torch.zeros( + (total_traj_num, self.num_envs, self.robot.dof), dtype=torch.float32 + ) + + for key, joints in [ + ("left_arm", left_arm_joints), + ("right_arm", right_arm_joints), + ("left_eef", left_eef_joints), + ("right_eef", right_eef_joints), + ]: + if key in ret: + # TODO: only 1 env supported now + actions[:, 0, joints] = torch.as_tensor(ret[key].T, dtype=torch.float32) + + return actions + + def is_task_success(self, **kwargs) -> torch.Tensor: + """Determine if the task is successfully completed. This is mainly used in the data generation process + of the imitation learning. + + Args: + **kwargs: Additional arguments for task-specific success criteria. + + Returns: + torch.Tensor: A boolean tensor indicating success for each environment in the batch. + """ + + bottle = self.sim.get_rigid_object("bottle") + cup = self.sim.get_rigid_object("cup") + + bottle_final_xpos = bottle.get_local_pose(to_matrix=True) + cup_final_xpos = cup.get_local_pose(to_matrix=True) + + bottle_ret = self._is_fall(bottle_final_xpos) + cup_ret = self._is_fall(cup_final_xpos) + + return ~(bottle_ret | cup_ret) + + def _is_fall(self, pose: torch.Tensor) -> torch.Tensor: + # Extract z-axis from rotation matrix (last column, first 3 elements) + pose_rz = pose[:, :3, 2] + world_z_axis = torch.tensor([0, 0, 1], dtype=pose.dtype, device=pose.device) + + # Compute dot product for each batch element + dot_product = torch.sum(pose_rz * world_z_axis, dim=-1) # Shape: (batch_size,) + + # Clamp to avoid numerical issues with arccos + dot_product = torch.clamp(dot_product, -1.0, 1.0) + + # Compute angle and check if fallen + angle = torch.arccos(dot_product) + return angle >= torch.pi / 4 diff --git a/embodichain/lab/gym/envs/tasks/tableware/scoop_ice.py b/embodichain/lab/gym/envs/tasks/tableware/scoop_ice.py new file mode 100644 index 00000000..46f6afad --- /dev/null +++ b/embodichain/lab/gym/envs/tasks/tableware/scoop_ice.py @@ -0,0 +1,214 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os +import torch +import numpy as np +import pickle + +from copy import deepcopy +from typing import Optional, Sequence +from scipy.spatial.transform import Rotation as R + +from embodichain.lab.gym.envs import EmbodiedEnv, EmbodiedEnvCfg +from embodichain.lab.gym.utils.registration import register_env +from embodichain.data import get_data_path +from embodichain.utils import logger +from tqdm import tqdm + + +@register_env("ScoopIce-v1", max_episode_steps=600) +class ScoopIce(EmbodiedEnv): + def __init__(self, cfg: EmbodiedEnvCfg = None, **kwargs): + super().__init__(cfg, **kwargs) + + self.affordance_datas = {} + + # TODO: hardcode code, should be implemented as functor way. + self.trajectory = pickle.load( + open( + get_data_path("ScoopIceNewEnv/pose_record_20250919_184544.pkl"), + "rb", + ) + ) + self.trajectory_sample_rate = 2 + + def set_scoop_pose(self, xyzrxryrz): + scoop = self.sim.get_rigid_object("scoop") + pose = np.eye(4) + pose[:3, 3] = xyzrxryrz[:3] + pose[:3, :3] = R.from_euler("XYZ", xyzrxryrz[3:], degrees=True).as_matrix() + n_env = self.sim.num_envs + pose_t = torch.tensor( + pose[None, :, :].repeat(n_env, axis=0), + dtype=torch.float32, + device=self.device, + ) + scoop.set_local_pose(pose_t) + + def set_cup_pose(self, xyzrxryrz): + cup = self.sim.get_rigid_object("paper_cup") + pose = np.eye(4) + pose[:3, 3] = xyzrxryrz[:3] + pose[:3, :3] = R.from_euler("XYZ", xyzrxryrz[3:], degrees=True).as_matrix() + n_env = self.sim.num_envs + pose_t = torch.tensor( + pose[None, :, :].repeat(n_env, axis=0), + dtype=torch.float32, + device=self.device, + ) + cup.set_local_pose(pose_t) + + def add_xpos_offset(self, arm_qpos: np.ndarray, offset: np.ndarray, is_left: bool): + """Add offset to arm qposes along end-effector x axis. + + Args: + arm_qposes (np.ndarray): [waypoint_num, dof] + """ + waypoint_num = arm_qpos.shape[0] + dof = arm_qpos.shape[1] + offset_t = torch.tensor(offset, dtype=torch.float32, device=self.device) + control_part = "left_arm" if is_left else "right_arm" + + arm_qpos_batch = torch.tensor( + arm_qpos[None, :, :], dtype=torch.float32, device=self.device + ) + + arm_xpos_batch = self.robot.compute_batch_fk( + qpos=arm_qpos_batch, name=control_part, to_matrix=True + ) + arm_xpos_batch[:, :, :3, 3] += offset_t + ret, arm_qpos_offset_batch = self.robot.compute_batch_ik( + pose=arm_xpos_batch, + joint_seed=arm_qpos_batch, + name=control_part, + ) + return arm_qpos_offset_batch[0].to("cpu").numpy() + + def pack_qpos(self): + self.num_envs = self.sim.num_envs + left_arm_qpos = self.trajectory["left_arm"] # [waypoint_num, dof] + logger.log_info("Adding x and z offset to left arm trajectory...") + left_arm_qpos = self.add_xpos_offset( + arm_qpos=left_arm_qpos, offset=np.array([-0.018, 0.0, -0.01]), is_left=True + ) + right_arm_qpos = self.trajectory["right_arm"] # [waypoint_num, dof] + # TODO: add z offset to right arm + logger.log_info("Adding z offset to right arm trajectory...") + right_arm_qpos = self.add_xpos_offset( + arm_qpos=right_arm_qpos, offset=np.array([0.00, 0.0, 0.02]), is_left=False + ) + left_eef_qpos = self.trajectory["left_eef"] # [waypoint_num, hand_dof] + right_eef_qpos = self.trajectory["right_eef"] + torso_qpos = self.trajectory["torso"] + # TODO: need head qpos. + + left_arm_qpos_expand = left_arm_qpos[None, :, :].repeat(self.num_envs, axis=0) + right_arm_qpos_expand = right_arm_qpos[None, :, :].repeat(self.num_envs, axis=0) + left_eef_qpos_expand = left_eef_qpos[None, :, :].repeat(self.num_envs, axis=0) + right_eef_qpos_expand = right_eef_qpos[None, :, :].repeat(self.num_envs, axis=0) + torso_qpos_expand = torso_qpos[None, :, :].repeat(self.num_envs, axis=0) + all_qpos = np.concatenate( + [ + left_arm_qpos_expand, + right_arm_qpos_expand, + left_eef_qpos_expand, + right_eef_qpos_expand, + torso_qpos_expand, + ], + axis=2, + ) + return all_qpos + + def _initialize_episode( + self, env_ids: Optional[Sequence[int]] = None, **kwargs + ) -> None: + + left_arm_ids = self.robot.get_joint_ids(name="left_arm") + right_arm_ids = self.robot.get_joint_ids(name="right_arm") + left_eef_ids = self.robot.get_joint_ids(name="left_eef") + right_eef_ids = self.robot.get_joint_ids(name="right_eef") + torso_ids = self.robot.get_joint_ids(name="torso") + all_ids = np.hstack( + [left_arm_ids, right_arm_ids, left_eef_ids, right_eef_ids, torso_ids] + ) + + # TODO: read xy random range from config + xy_random_range = np.array([[-0.01, -0.01], [0.01, 0.01]]) + xy_random_offset = np.zeros(shape=(self.num_envs, 2)) + for arena_id in range(self.num_envs): + xy_random_offset[arena_id] = np.random.uniform( + low=xy_random_range[0], high=xy_random_range[1], size=(2,) + ) + # TODO: apply warping to container pose + + all_qpos = self.pack_qpos() + all_qpos_t = torch.tensor(all_qpos, dtype=torch.float32, device=self.device) + + # to initial qpos + left_open_qpos = np.array([0.06, 1.5, 0.2, 0.2, 0.2, 0.2]) + left_close_qpos = np.array([0.13, 1.5, 0.5, 0.5, 0.5, 0.5]) + right_open_qpos = np.array([0.3, 1.5, 0.3, 0.3, 0.3, 0.3]) + right_close_qpos = np.array([0.6, 1.5, 0.7, 0.5, 0.7, 0.6]) + + all_qpos_t[:, :, 14:20] = torch.tensor( + left_close_qpos, dtype=torch.float32, device=self.device + ) + all_qpos_t[:, :, 20:26] = torch.tensor( + right_close_qpos, dtype=torch.float32, device=self.device + ) + + first_close_qpos = all_qpos_t[:, 0, :].to("cpu").numpy() + first_open_qpos = deepcopy(first_close_qpos) + + # to first open pose + first_open_qpos[:, 14:20] = left_open_qpos + first_open_qpos[:, 20:26] = right_open_qpos + self.robot.set_qpos( + torch.tensor(first_open_qpos, dtype=torch.float32, device=self.device), + joint_ids=all_ids, + ) + self.sim.update(step=200) + # save warp trajectory as demo action list + waypoint_num = self.trajectory["left_arm"].shape[0] + current_qpos = self.robot.get_qpos() + self.demo_action_list = [] + for waypoint_idx in range(waypoint_num): + action = current_qpos.clone() + action[:, all_ids] = all_qpos_t[:, waypoint_idx, :] + # TODO: sample in trajectory + self.demo_action_list.append(action) + + # TODO: tricky implementation. Hold the first joint state for a while. + if waypoint_idx == 0: + for _ in range(20): + self.demo_action_list.append(action) + + self.sim.update(step=100) + + # apply events such as randomization for environments that need a reset + if self.cfg.events: + if "reset" in self.event_manager.available_modes: + self.event_manager.apply(mode="reset", env_ids=env_ids) + + def create_demo_action_list(self, *args, **kwargs): + logger.log_info( + f"The original demo action list length: {len(self.demo_action_list)}" + ) + logger.log_info( + f"Downsample the demo action list by self.trajectory_sample_rate5 times." + ) + return self.demo_action_list[:: self.trajectory_sample_rate] diff --git a/embodichain/lab/gym/envs/wrapper/__init__.py b/embodichain/lab/gym/envs/wrapper/__init__.py new file mode 100644 index 00000000..3e39714b --- /dev/null +++ b/embodichain/lab/gym/envs/wrapper/__init__.py @@ -0,0 +1,17 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from .no_fail import NoFailWrapper diff --git a/embodichain/lab/gym/envs/wrapper/no_fail.py b/embodichain/lab/gym/envs/wrapper/no_fail.py new file mode 100644 index 00000000..e5806d11 --- /dev/null +++ b/embodichain/lab/gym/envs/wrapper/no_fail.py @@ -0,0 +1,31 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import gymnasium as gym + + +class NoFailWrapper(gym.Wrapper): + """A wrapper that alter the env's is_task_success method to make sure all the is_task_success determination return True. + + Args: + env (gym.Env): the environment to wrap. + """ + + def __init__(self, env: gym.Env): + super().__init__(env) + + def is_task_success(self, *args, **kwargs): + return True diff --git a/embodichain/lab/gym/utils/__init__.py b/embodichain/lab/gym/utils/__init__.py new file mode 100644 index 00000000..e4655620 --- /dev/null +++ b/embodichain/lab/gym/utils/__init__.py @@ -0,0 +1,15 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- diff --git a/embodichain/lab/gym/utils/gym_utils.py b/embodichain/lab/gym/utils/gym_utils.py new file mode 100644 index 00000000..42692190 --- /dev/null +++ b/embodichain/lab/gym/utils/gym_utils.py @@ -0,0 +1,617 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import numpy as np +import torch +import dexsim + +from typing import Dict, Any, List, Tuple, Union, Sequence, Optional +from gymnasium import spaces +from copy import deepcopy + +from embodichain.lab.sim.types import Device, Array +from embodichain.lab.sim.objects import Robot +from embodichain.utils.module_utils import find_function_from_modules +from embodichain.utils.utility import get_class_instance +from dexsim.utility import log_debug, log_error + + +def get_dtype_bounds(dtype: np.dtype): + """Gets the min and max values of a given numpy type""" + if np.issubdtype(dtype, np.floating): + info = np.finfo(dtype) + return info.min, info.max + elif np.issubdtype(dtype, np.integer): + info = np.iinfo(dtype) + return info.min, info.max + elif np.issubdtype(dtype, np.bool_): + return 0, 1 + else: + raise TypeError(dtype) + + +def convert_observation_to_space( + observation: Any, prefix: str = "", unbatched: bool = False +) -> spaces.Space: + """Convert observation to OpenAI gym observation space (recursively). + Modified from `gym.envs.mujoco_env` + """ + if isinstance(observation, (dict)): + # CATUION: Explicitly create a list of key-value tuples + # Otherwise, spaces.Dict will sort keys if a dict is provided + space = spaces.Dict( + [ + ( + k, + convert_observation_to_space( + v, prefix + "/" + k, unbatched=unbatched + ), + ) + for k, v in observation.items() + ] + ) + elif isinstance(observation, (list, tuple)): + array = np.array(observation) + dtype = array.dtype + space = spaces.Box(-np.inf, np.inf, shape=array.shape, dtype=dtype) + elif isinstance(observation, np.ndarray): + if unbatched: + shape = observation.shape[1:] + else: + shape = observation.shape + dtype = observation.dtype + low, high = get_dtype_bounds(dtype) + if np.issubdtype(dtype, np.floating): + low, high = -np.inf, np.inf + space = spaces.Box(low, high, shape=shape, dtype=dtype) + elif isinstance(observation, (float, np.float32, np.float64)): + log_debug(f"The observation ({prefix}) is a (float) scalar") + space = spaces.Box(-np.inf, np.inf, shape=[1], dtype=np.float32) + elif isinstance(observation, (int, np.int32, np.int64)): + log_debug(f"The observation ({prefix}) is a (integer) scalar") + space = spaces.Box(-np.inf, np.inf, shape=[1], dtype=int) + elif isinstance(observation, (bool, np.bool_)): + log_debug(f"The observation ({prefix}) is a (bool) scalar") + space = spaces.Box(0, 1, shape=[1], dtype=np.bool_) + else: + raise NotImplementedError(type(observation), observation) + + return space + + +def _batch(array: Union[np.ndarray, Sequence]): + if isinstance(array, (dict)): + return {k: _batch(v) for k, v in array.items()} + if isinstance(array, str): + return array + if isinstance(array, np.ndarray): + if array.shape == (): + return array.reshape(1, 1) + return array[None, :] + if isinstance(array, list): + if len(array) == 1: + return [array] + if ( + isinstance(array, float) + or isinstance(array, int) + or isinstance(array, bool) + or isinstance(array, np.bool_) + ): + return np.array([[array]]) + return array + + +def batch(*args: Tuple[Union[np.ndarray, Dict]]): + """Adds one dimension in front of everything. If given a dictionary, every leaf in the dictionary + has a new dimension. If given a tuple, returns the same tuple with each element batched + """ + x = [_batch(x) for x in args] + if len(args) == 1: + return x[0] + return tuple(x) + + +def to_tensor(array: Array, device: Optional[Device] = None): + """ + Maps any given sequence to a torch tensor on the CPU/GPU. If physx gpu is not enabled then we use CPU, otherwise GPU, unless specified + by the device argument + + Args: + array: The data to map to a tensor + device: The device to put the tensor on. By default this is None and to_tensor will put the device on the GPU if physx is enabled + and CPU otherwise + + """ + if isinstance(array, (dict)): + return {k: to_tensor(v) for k, v in array.items()} + if torch.cuda.is_available(): + if isinstance(array, np.ndarray): + if array.dtype == np.uint16: + array = array.astype(np.int32) + ret = torch.from_numpy(array) + if ret.dtype == torch.float64: + ret = ret.float() + elif isinstance(array, torch.Tensor): + ret = array + else: + ret = torch.tensor(array) + if device is None: + if ret.device.type == "cpu": + return ret.cuda() + # keep same device if already on GPU + return ret + else: + return ret.to(device) + else: + if isinstance(array, np.ndarray): + if array.dtype == np.uint16: + array = array.astype(np.int32) + if array.dtype == np.uint32: + array = array.astype(np.int64) + ret = torch.from_numpy(array) + if ret.dtype == torch.float64: + ret = ret.float() + elif isinstance(array, list) and isinstance(array[0], np.ndarray): + ret = torch.from_numpy(np.array(array)) + if ret.dtype == torch.float64: + ret = ret.float() + elif np.iterable(array): + ret = torch.Tensor(array) + else: + ret = torch.Tensor(array) + if device is None: + return ret + else: + return ret.to(device) + + +def to_cpu_tensor(array: Array): + """ + Maps any given sequence to a torch tensor on the CPU. + """ + if isinstance(array, (dict)): + return {k: to_tensor(v) for k, v in array.items()} + if isinstance(array, np.ndarray): + ret = torch.from_numpy(array) + if ret.dtype == torch.float64: + ret = ret.float() + return ret + elif isinstance(array, torch.Tensor): + return array.cpu() + else: + return torch.tensor(array).cpu() + + +def flatten_state_dict( + state_dict: dict, use_torch=False, device: Device = None +) -> Array: + """Flatten a dictionary containing states recursively. Expects all data to be either torch or numpy + + Args: + state_dict: a dictionary containing scalars or 1-dim vectors. + use_torch (bool): Whether to convert the data to torch tensors. + + Raises: + AssertionError: If a value of @state_dict is an ndarray with ndim > 2. + + Returns: + np.ndarray | torch.Tensor: flattened states. + + Notes: + The input is recommended to be ordered (e.g. dict). + However, since python 3.7, dictionary order is guaranteed to be insertion order. + """ + states = [] + + for key, value in state_dict.items(): + if isinstance(value, dict): + state = flatten_state_dict(value, use_torch=use_torch) + if state.size == 0: + state = None + if use_torch: + state = to_tensor(state) + elif isinstance(value, (tuple, list)): + state = None if len(value) == 0 else value + if use_torch: + state = to_tensor(state) + elif isinstance(value, (bool, np.bool_, int, np.int32, np.int64)): + # x = np.array(1) > 0 is np.bool_ instead of ndarray + state = int(value) + if use_torch: + state = to_tensor(state) + elif isinstance(value, (float, np.float32, np.float64)): + state = np.float32(value) + if use_torch: + state = to_tensor(state) + elif isinstance(value, np.ndarray): + if value.ndim > 2: + raise AssertionError( + "The dimension of {} should not be more than 2.".format(key) + ) + state = value if value.size > 0 else None + if use_torch: + state = to_tensor(state) + + elif isinstance(value, torch.Tensor): + state = value + if len(state.shape) == 1: + state = state[:, None] + else: + raise TypeError("Unsupported type: {}".format(type(value))) + if state is not None: + states.append(state) + + if use_torch: + if len(states) == 0: + return torch.empty(0, device=device) + else: + return torch.hstack(states) + else: + if len(states) == 0: + return np.empty(0) + else: + return np.hstack(states) + + +def clip_and_scale_action( + action: Union[np.ndarray, torch.Tensor], low: float, high: float +): + """Clip action to [-1, 1] and scale according to a range [low, high].""" + if isinstance(action, np.ndarray): + action = np.clip(action, -1, 1) + elif isinstance(action, torch.Tensor): + action = torch.clip(action, -1, 1) + else: + log_error("Unsupported type: {}".format(type(action))) + return 0.5 * (high + low) + 0.5 * (high - low) * action + + +def dict_array_to_torch_inplace( + data: Dict[str, Any], device: Union[str, torch.device] = "cpu" +) -> None: + """ + Convert arrays in a dictionary to torch tensors in-place. + + Args: + data (Dict[str, Any]): Dictionary to modify in-place + device (Union[str, torch.device]): Device to place the tensors on + """ + for key, value in data.items(): + if isinstance(value, np.ndarray): + item: torch.Tensor = torch.from_numpy(value).to(device) + if len(item.shape) == 1: + item.unsqueeze_(0) + data[key] = item + elif isinstance(value, dict): + dict_array_to_torch_inplace(value, device) + + +def cat_tensor_with_ids( + tensors: List[torch.Tensor], ids: List[List[int]], dim: int +) -> torch.Tensor: + """ + Concatenate tensors along a new dimension specified by `dim`, using the provided `ids` to index into the tensors. + + Args: + tensors (List[torch.Tensor]): List of tensors to concatenate. + ids (List[List[int]]): List of lists, where each inner list contains the indices to select from the corresponding tensor. + dim (int): The dimension along which to concatenate the tensors. + + Returns: + torch.Tensor: The concatenated tensor. + """ + out = torch.zeros( + (tensors[0].shape[0], dim), dtype=tensors[0].dtype, device=tensors[0].device + ) + + for i, tensor in enumerate(tensors): + out[:, ids[i]] = tensor + + return out + + +def config_to_rl_cfg(config: dict) -> "RLEnvCfg": + """Parse gym-style configuration dict into an RL-ready config object.""" + + from embodichain.lab.gym.envs.rl_env_cfg import RLEnvCfg + + # Use config_to_cfg to parse shared fields + env_cfg = config_to_cfg(config) + # Convert to RLEnvCfg if needed + if not isinstance(env_cfg, RLEnvCfg): + env_cfg = RLEnvCfg.from_dict(env_cfg.__dict__) + # RL-specific fields + env_cfg.env_id = config.get("id") + env_cfg.num_envs = config["env"].get("num_envs", env_cfg.num_envs) + env_cfg.extensions = deepcopy(config.get("env", {}).get("extensions", {})) + # Add any RL-specific parsing here + return env_cfg + + +def config_to_cfg(config: dict) -> "EmbodiedEnvCfg": + """Parser configuration file into cfgs for env initialization. + + Args: + config (dict): The configuration dictionary containing robot, sensor, light, background, and interactive objects. + + Returns: + EmbodiedEnvCfg: A configuration object for initializing the environment. + """ + + from embodichain.lab.sim.cfg import ( + RobotCfg, + RigidObjectCfg, + RigidObjectGroupCfg, + ArticulationCfg, + LightCfg, + ) + from embodichain.lab.gym.envs import EmbodiedEnvCfg + from embodichain.lab.sim.sensors import SensorCfg + from embodichain.lab.gym.envs.managers import ( + SceneEntityCfg, + EventCfg, + ObservationCfg, + ) + from embodichain.utils import configclass + from embodichain.data import get_data_path + + @configclass + class ComponentCfg: + """Configuration for env events. + + This class is used to define various events that can occur in the environment, + """ + + pass + + env_cfg = EmbodiedEnvCfg() + + # check all necessary keys + required_keys = ["id", "max_episodes", "env", "robot"] + for key in required_keys: + if key not in config: + log_error(f"Missing required config key: {key}") + + # parser robot config + # TODO: support multiple robots cfg initialization from config, eg, cobotmagic, dexforce_w1, etc. + if "robot_type" in config["robot"]: + robot_cfg = get_class_instance( + "embodichain.lab.sim.robots", + config["robot"]["robot_type"] + "Cfg", + ) + config["robot"].pop("robot_type") + robot_cfg = robot_cfg.from_dict(config["robot"]) + else: + robot_cfg = RobotCfg.from_dict(config["robot"]) + + env_cfg.robot = robot_cfg + + # parser sensor config + env_cfg.sensor = [SensorCfg.from_dict(s) for s in config.get("sensor", [])] + + # parser light config + if "light" in config: + env_cfg.light = EmbodiedEnvCfg.EnvLightCfg() + env_cfg.light.direct = [ + LightCfg.from_dict(l) for l in config["light"].get("direct", []) + ] + + # parser background objects config + if "background" in config: + for obj_dict in config["background"]: + shape_type = obj_dict["shape"]["shape_type"] + if shape_type == "Mesh": + obj_dict["shape"]["fpath"] = get_data_path(obj_dict["shape"]["fpath"]) + # Set to static object if not specified. + obj_dict["body_type"] = ( + "static" if "body_type" not in obj_dict else obj_dict["body_type"] + ) + cfg = RigidObjectCfg.from_dict(obj_dict) + env_cfg.background.append(cfg) + + # parser scene objects config + if "rigid_object" in config: + for obj_dict in config["rigid_object"]: + shape_type = obj_dict["shape"]["shape_type"] + if shape_type == "Mesh": + obj_dict["shape"]["fpath"] = get_data_path(obj_dict["shape"]["fpath"]) + cfg = RigidObjectCfg.from_dict(obj_dict) + env_cfg.rigid_object.append(cfg) + + if "rigid_object_group" in config: + for obj_dict in config["rigid_object_group"]: + if "folder_path" in obj_dict: + obj_dict["folder_path"] = get_data_path(obj_dict["folder_path"]) + for rigid_obj in obj_dict["rigid_objects"].values(): + shape_type = rigid_obj["shape"]["shape_type"] + if shape_type == "Mesh" and "fpath" in rigid_obj["shape"]: + rigid_obj["shape"]["fpath"] = get_data_path( + rigid_obj["shape"]["fpath"] + ) + cfg = RigidObjectGroupCfg.from_dict(obj_dict) + env_cfg.rigid_object_group.append(cfg) + + if "articulation" in config: + for obj_dict in config["articulation"]: + obj_dict["fpath"] = get_data_path(obj_dict["fpath"]) + cfg = ArticulationCfg.from_dict(obj_dict) + env_cfg.articulation.append(cfg) + + env_cfg.sim_steps_per_control = config["env"].get("sim_steps_per_control", 4) + + # load dataset config + env_cfg.dataset = config["env"].get("dataset", None) + + # TODO: support more env events, eg, grasp pose generation, mesh preprocessing, etc. + + env_cfg.events = ComponentCfg() + if "events" in config["env"]: + # Define modules to search for event functions + event_modules = [ + "embodichain.lab.gym.envs.managers.randomization", + "embodichain.lab.gym.envs.managers.record", + "embodichain.lab.gym.envs.managers.events", + ] + + # parser env events config + for event_name, event_params in config["env"]["events"].items(): + event_params_modified = deepcopy(event_params) + if "entity_cfg" in event_params["params"]: + entity_cfg = SceneEntityCfg( + **event_params_modified["params"]["entity_cfg"] + ) + event_params_modified["params"]["entity_cfg"] = entity_cfg + + # Find the function from multiple modules using the utility function + event_func = find_function_from_modules( + event_params["func"], event_modules, raise_if_not_found=True + ) + interval_step = event_params_modified.get("interval_step", 10) + + event = EventCfg( + func=event_func, + mode=event_params_modified["mode"], + params=event_params_modified["params"], + interval_step=interval_step, + ) + setattr(env_cfg.events, event_name, event) + + env_cfg.observations = ComponentCfg() + if "observations" in config["env"]: + # Define modules to search for observation functions + observation_modules = [ + "embodichain.lab.gym.envs.managers.observations", + ] + + for obs_name, obs_params in config["env"]["observations"].items(): + obs_params_modified = deepcopy(obs_params) + + if "entity_cfg" in obs_params["params"]: + entity_cfg = SceneEntityCfg( + **obs_params_modified["params"]["entity_cfg"] + ) + obs_params_modified["params"]["entity_cfg"] = entity_cfg + + # Find the function from multiple modules using the utility function + obs_func = find_function_from_modules( + obs_params["func"], + observation_modules, + raise_if_not_found=True, + ) + + observation = ObservationCfg( + func=obs_func, + mode=obs_params_modified["mode"], + name=obs_params_modified["name"], + params=obs_params_modified["params"], + ) + + setattr(env_cfg.observations, obs_name, observation) + + return env_cfg + + +def map_qpos_to_eef_pose( + robot: Robot, qpos: torch.Tensor, control_parts: List[str] +) -> Dict[str, torch.Tensor]: + """Map qpos to end-effector pose. + + Note: + The computed eef pose will be in the base frame of the control part. + + Args: + robot (Robot): The robot instance. + qpos (torch.Tensor): The qpos tensor of shape (N, num_joints). + control_parts (List[str]): List of control part names. + to_dict (bool): Whether to return the result as a dictionary. + + Returns: + Dict[str, torch.Tensor]: A dictionary containing the end-effector poses for each control part. + """ + from embodichain.data.enum import EefType + + eef_pose_dict = {} + for i, name in enumerate(control_parts): + eef_pose = torch.zeros( + (qpos.shape[0], 9), dtype=torch.float32, device=qpos.device + ) + + # TODO: need to be configurable. + control_ids = robot.get_joint_ids(name) + current_qpos = qpos[:, control_ids] + part_eef_pose = ( + robot.pk_serial_chain[name] + .forward_kinematics(current_qpos, end_only=True) + .get_matrix() + ) + + eef_pose[:, :3] = part_eef_pose[:, :3, 3] + eef_pose[:, 3:6] = part_eef_pose[:, :3, 0] + eef_pose[:, 6:9] = part_eef_pose[:, :3, 1] + + eef_pose_dict[name + EefType.POSE.value] = eef_pose + + return eef_pose_dict + + +def fetch_data_from_dict( + data_dict: Dict[str, Union[Any, Dict[str, Any]]], name: str +) -> Any: + """Fetch data from a nested dictionary using a '/' separated key. + + Args: + data_dict (Dict[str, Union[Any, Dict[str, Any]]]): The nested dictionary to fetch data from. + name (str): The '/' separated key string. + + Returns: + Any: The fetched data. + + Raises: + KeyError: If the specified key does not exist in the dictionary. + """ + keys = name.split("/") + current_data = data_dict + + for key in keys: + if key in current_data: + current_data = current_data[key] + else: + raise KeyError(f"Key '{key}' not found in the dictionary.") + + return current_data + + +def assign_data_to_dict( + data_dict: Dict[str, Union[Any, Dict[str, Any]]], name: str, value: Any +) -> None: + """Assign data to a nested dictionary using a '/' separated key. + Missing intermediate dictionaries will be created automatically. + + Args: + data_dict (Dict[str, Union[Any, Dict[str, Any]]]): The nested dictionary to assign data to. + name (str): The '/' separated key string. + value (Any): The value to assign. + """ + keys = name.split("/") + current_data = data_dict + + for key in keys[:-1]: + if key not in current_data or not isinstance(current_data[key], dict): + current_data[key] = {} # create intermediate dict if missing + current_data = current_data[key] + + last_key = keys[-1] + current_data[last_key] = value diff --git a/embodichain/lab/gym/utils/misc.py b/embodichain/lab/gym/utils/misc.py new file mode 100644 index 00000000..92098663 --- /dev/null +++ b/embodichain/lab/gym/utils/misc.py @@ -0,0 +1,1564 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import re +import os +import ast +import cv2 +import h5py +import torch +import inspect +import open3d as o3d + +from copy import deepcopy +from functools import partial, wraps, lru_cache +from collections import OrderedDict +from importlib import import_module +from scipy.spatial.transform import Rotation as R +from typing import Any, Dict, List, Tuple, Union, Sequence, Callable, Optional, Mapping + +import numpy as np + +from embodichain.lab.sim.objects import Robot +from embodichain.utils.utility import inv_transform +from embodichain.utils.logger import log_info, log_warning, log_error + + +def no_validation(*args, **kwargs): + return True + + +def add_xy_random_offset( + pose: np.ndarray, max_offset: Union[float, List[float]] +) -> np.ndarray: + """ + Add a random offset to the x and y translation of a pose. + + Args: + pose (np.ndarray): 4x4 pose matrix. + max_offset (float or List[float]): If float, uniform in [-max_offset, max_offset] for both axes. + If list, [x_min, x_max, y_min, y_max]. + + Returns: + np.ndarray: Pose with random xy offset. + """ + shift_pose = deepcopy(pose) + if isinstance(max_offset, float): + xy_shift = np.random.uniform(-max_offset, max_offset, size=2) + else: + x_shift = np.random.uniform(max_offset[0], max_offset[1]) + y_shift = np.random.uniform(max_offset[2], max_offset[3]) + xy_shift = np.array([x_shift, y_shift]) + shift_pose[:2, 3] += xy_shift + return shift_pose + + +def mul_linear_expand( + arr: np.ndarray, expand_times: Union[int, List[int]], is_interp: bool = True +) -> np.ndarray: + """ + Linearly interpolate or repeat between points in an array. + + Args: + arr (np.ndarray): Input array of shape (N, D). + expand_times (int or List[int]): Number of samples between each pair. + is_interp (bool): If True, interpolate; else, repeat. + + Returns: + np.ndarray: Expanded/interpolated array. + """ + arr = np.asarray(arr) + arr_len, dim = arr.shape + if isinstance(expand_times, int): + interp_path = np.zeros(shape=(arr_len * expand_times, dim), dtype=float) + else: + assert len(expand_times) == arr_len - 1, "Invalid expand_times size." + interp_path = np.zeros(shape=(sum(expand_times), dim), dtype=float) + + idx = 0 + for i in range(arr_len - 1): + sample_times = ( + expand_times if isinstance(expand_times, int) else expand_times[i] + ) + for k in range(sample_times): + if is_interp: + alpha = k / sample_times + v = (1 - alpha) * arr[i] + alpha * arr[i + 1] + else: + v = arr[i] + interp_path[idx] = v + idx += 1 + interp_path = interp_path[:idx] + return interp_path + + +def axis_idx(k: str) -> int: + return {"x": 0, "y": 1, "z": 2}.get(k, None) + + +def axis_str_to_list(s: str): + if any(c not in "xyz" for c in s): + return None + return ["xyz".index(c) for c in s] + + +def is_pose_axis_align( + pose: List, + vector: List, + axis_str: str, + mode: str, + cos_threshold: float = None, + degree_threshold: float = None, +): + """Check if the given `axis` of a `pose` is aligned with given vector, under given mode and cos_threshold. + i.e. the cosine value of the angle between the pose's `axis` and `vector` is respecting (leq or geq) the `cos_threhold` or not. + + Args: + pose (List): The pose to be checked. + vector (List): The vector to be aligning to. + axis_str (str): The string of the axis. + mode (str): leq or geq. + cos_threshold (float): The threshold of the cosine value between the pose's axis and vector. + degree_threshold (float): The threshold of the degree value between the pose's axis and vector, only functions when cos_threshold is not given. + """ + pose = np.asarray(pose) + vector = np.asarray(vector) + + if cos_threshold is None: + if degree_threshold is None: + log_error( + 'cos_threshold & angle_threshold are both None, illegal for "is_pose_axis_align".' + ) + else: + cos_threshold = np.cos(np.deg2rad(degree_threshold)) + + axis_id = axis_idx(axis_str) + axis = pose[:3, axis_id] + cos_value = np.dot(axis, vector) / np.linalg.norm(vector) + + if "abs" in mode: + cos_value = abs(cos_value) + + if "leq" in mode: + return cos_value <= cos_threshold + elif "geq" in mode: + return cos_value >= cos_threshold + + +def is_pose_flip( + pose: list, ref_pose: list, axis_str: str = "y", return_inverse: bool = False +): + pose = np.asarray(pose) + ref_pose = np.asarray(ref_pose) + axis_idx = axis_idx(axis_str) + if axis_idx is None: + log_error(f'Axis {axis_str} is not among ["x", "y", "z"]') + relative_angle = np.abs(np.arccos(pose[:3, axis_idx].dot(ref_pose[:3, axis_idx]))) + valid_ret = relative_angle > np.pi / 2 + + if return_inverse: + valid_ret = not valid_ret + + return valid_ret + + +def is_qpos_exceed(qpos: np.ndarray, agent, uid: str): + return not ( + any(qpos < agent.get_joint_limits(uid)[:, 0]) + or any(qpos > agent.get_joint_limits(uid)[:, 1]) + ) + + +def is_qpos_exceed_new( + qpos: Union[np.ndarray, torch.Tensor], robot: Robot, control_part: str +) -> bool: + """ + Check if the given qpos exceeds the joint limits of the specified control part. + Supports both numpy and torch tensor inputs. + + Args: + qpos (Union[np.ndarray, torch.Tensor]): The joint positions to check. + robot (Robot): The robot object containing joint limits. + control_part (str): The name of the control part to check. + + Returns: + bool: True if qpos exceeds joint limits, False otherwise. + """ + joint_limits = robot.body_data.qpos_limits[0][ + robot.get_joint_ids(name=control_part) + ] + # Convert joint_limits to tensor if qpos is tensor, else to numpy + if isinstance(qpos, torch.Tensor): + joint_limits = torch.as_tensor( + joint_limits, dtype=qpos.dtype, device=qpos.device + ) + exceed = torch.any(qpos < joint_limits[:, 0]) or torch.any( + qpos > joint_limits[:, 1] + ) + return not exceed + else: + qpos = np.asarray(qpos) + # 保证 joint_limits 是 numpy 类型 + if isinstance(joint_limits, torch.Tensor): + joint_limits = joint_limits.cpu().numpy() + exceed = np.any(qpos < joint_limits[:, 0]) or np.any(qpos > joint_limits[:, 1]) + return not exceed + + +def is_qpos_flip( + qpos: Union[np.ndarray, torch.Tensor], + qpos_ref: Union[np.ndarray, torch.Tensor], + qpos_ids: Union[List, np.ndarray], + threshold: float = 1.1 * np.pi, + mode: str = "delta", + return_inverse: bool = False, +): + """ + Check whether the joint positions (qpos) are flipped compared to a reference (qpos_ref). + Supports both numpy and torch tensor inputs. + + Args: + qpos (Union[np.ndarray, torch.Tensor]): The joint positions to check. + qpos_ref (Union[np.ndarray, torch.Tensor]): The reference joint positions. + qpos_ids (Union[List, np.ndarray]): Indices of joints to compare. + threshold (float, optional): Threshold for delta mode. Defaults to 1.1 * np.pi. + mode (str, optional): "delta" for norm difference, "sign" for sign difference. Defaults to "delta". + return_inverse (bool, optional): If True, returns the inverse result. Defaults to False. + + Returns: + bool: True if flipped, False otherwise. + """ + # Ensure qpos_ids is numpy array for indexing + if isinstance(qpos_ids, torch.Tensor): + qpos_ids = qpos_ids.cpu().numpy() + # If either input is torch.Tensor, convert both to tensor for comparison + if isinstance(qpos, torch.Tensor) or isinstance(qpos_ref, torch.Tensor): + if not isinstance(qpos, torch.Tensor): + qpos = torch.from_numpy(qpos) + if not isinstance(qpos_ref, torch.Tensor): + qpos_ref = torch.from_numpy(qpos_ref) + qpos_ids_tensor = torch.as_tensor(qpos_ids, dtype=torch.long) + if mode == "delta": + # Compute norm difference for selected joints + qpos_diff = torch.norm(qpos[qpos_ids_tensor] - qpos_ref[qpos_ids_tensor]) + valid_ret = qpos_diff > threshold + elif mode == "sign": + # Check sign difference for selected joints + valid_ret = (qpos[qpos_ids_tensor] * qpos_ref[qpos_ids_tensor]) < 0 + else: + log_error(f"The qpos flip mode {mode} has not been implemented yet.") + # Convert torch scalar to Python bool + if isinstance(valid_ret, torch.Tensor): + valid_ret = valid_ret.item() if valid_ret.numel() == 1 else bool(valid_ret) + else: + qpos_ids = np.asarray(qpos_ids) + if mode == "delta": + qpos_diff = np.linalg.norm(qpos[qpos_ids] - qpos_ref[qpos_ids]) + valid_ret = qpos_diff > threshold + elif mode == "sign": + valid_ret = (qpos[qpos_ids] * qpos_ref[qpos_ids]) < 0 + else: + log_error(f"The qpos flip mode {mode} has not been implemented yet.") + + if return_inverse: + valid_ret = not valid_ret + + return valid_ret + + +def get_replaced_pose( + pose_to_change: np.ndarray, + pose_replace_value: Union[float, List], + axis_str_replace: str, +) -> np.ndarray: + """ + Replace specific axes of a pose with new values. + + Args: + pose_to_change (np.ndarray): The pose to be modified (4x4 matrix). + pose_replace_value (Union[float, List]): The values to replace the specified axes. + axis_str_replace (str): A string specifying the axes to replace (e.g., "xy"). + + Returns: + np.ndarray: The modified pose. + + Raises: + ValueError: If the lengths of `pose_replace_value` and `axis_str_replace` do not match. + """ + axis_list_replace = axis_str_to_list(axis_str_replace) + if axis_list_replace is None: + raise ValueError(f"Invalid axis string: {axis_str_replace}") + + if isinstance(pose_replace_value, (Sequence, np.ndarray)): + pose_replace_value_length = len(pose_replace_value) + else: + pose_replace_value_length = 1 + pose_replace_value = [pose_replace_value] + + if pose_replace_value_length != len(axis_list_replace): + log_error( + f'The axis asked to be raplaced is "{axis_str_replace}", but got {pose_replace_value_length} changes quantity.' + ) + for axis, replace_quantity in zip(axis_list_replace, pose_replace_value): + pose_to_change[axis, 3] = replace_quantity + return pose_to_change + + +def get_offset_pose( + pose_to_change: np.ndarray, + offset_value: Union[float, List[float]], + direction: Union[str, List] = "z", + mode: str = "extrinsic", +) -> np.ndarray: + """Offset the `pose_to_change` given the `offset_value`, `direction` & `mode`. Returns the offset pose. + + Args: + pose_to_change (np.ndarray): The pose to be offset. + offset_value (Union[float, List[float]]): The offset. + direction (Union[str, List], optional): String as "x", "y" or, "z" and 3-dim np.ndarray indicating the offset directions. Defaults to "z". + mode (str, optional): String "extrinsic" or "intrinsic", indicating which system frame should each offset shall be done. Defaults to "extrinsic". + + Returns: + np.ndarray: The resulting 4x4 offset pose. + + Raises: + ValueError: If inputs are invalid or incompatible. + """ + if isinstance(direction, str): + minus = "-" in direction + direction = direction.removeprefix("-") + direction = np.isin(np.arange(3), axis_str_to_list(direction)).astype(int) * ( + -1 if minus else 1 + ) + + direction = np.asarray(direction) + direction = direction / np.linalg.norm(direction) + offset_matrix = np.eye(4) + offset_matrix[:3, 3] = offset_value * direction + if mode == "extrinsic": + offset_pose = offset_matrix @ pose_to_change + elif mode == "intrinsic": + offset_pose = pose_to_change @ offset_matrix + else: + log_error(f"Mode {mode} illegal.") + return offset_pose + + +# TODO: This one is not for work +def get_offset_pose_list( + pose_to_change: np.ndarray, + offsets: Union[float, List[float]], + directions: Union[str, np.ndarray, List[str], List[np.ndarray]] = [], + modes: Union[str, List[str]] = [], +): + """Offset the `pose_to_change` given the `offsets`, `directions` & `modes`. Returns the offset poses. + + Args: + pose_to_change (np.ndarray): The pose to be offset. + offsets (Union[float, List[float]]): The offset or the offset list. + directions (Union[str, np.ndarray, List[str], List[np.ndarray]], optional): String as "x", "y" or, "z" and 3-dim np.ndarray indicating the offset directions, together with their list that have same size of `offsets` . Defaults to []. + modes (Union[str, List[str]], optional): String "extrinsic" or "intrinsic", and its list, indicating which system frame should each offset shall be done. Defaults to []. + return_single_pose: Whether return the single pose or not + """ + num_offset_pose = len(offsets) if isinstance(offsets, list) else 1 + num_offset_direction = len(directions) if isinstance(directions, list) else 1 + num_offset_mode = len(modes) if isinstance(modes, list) else 1 + if num_offset_direction == 0: + directions = ["z"] * num_offset_pose + num_offset_direction = num_offset_pose + if num_offset_mode == 0: + modes = ["extrinsic"] * num_offset_direction + num_offset_mode = num_offset_direction + if num_offset_direction != num_offset_pose: + log_error( + f"The offsets {offsets} have a different length {num_offset_pose} other than directions {directions}'s {num_offset_direction}." + ) + if num_offset_mode != num_offset_direction: + log_warning( + f"The directions {directions} have a different length {num_offset_direction} other than modes {modes}'s {num_offset_mode}." + ) + if num_offset_direction == 1 and not isinstance(directions, list): + directions = [directions] + if num_offset_mode == 1 and not isinstance(modes, list): + modes = [modes] + + offset_poses = [] + for idx, (offset, (direction, mode)) in enumerate( + zip(offsets, zip(directions, modes)) + ): + offset_pose = get_offset_pose(pose_to_change, offset, direction, mode) + offset_poses.append(offset_pose) + + return offset_poses + + +def get_rotated_pose( + pose_to_change: np.ndarray, + rot_angle: float, + rot_axis: Union[str, List] = "z", + mode: str = "extrinsic", + degrees: Union[bool, str] = None, +): + """Rotate the `pose_to_change` given the `rot_angel`, `rot_axis` & `mode`. Returns the rotate pose. + + Args: + pose_to_change (np.ndarray): The pose to be rotated. + rot_angle (float): The rotation angle. + rot_axis (Union[str, List], optional): String as "x", "y" or, "z" and 3-dim np.ndarray indicating the rotation axis. Defaults to "z". + mode (str, optional): String "extrinsic" or "intrinsic", and its list, indicating which system frame should each rotation shall be done. Defaults to "extrinsic". + degrees (str): If it's "degrees" then the input rotation angle is degree, then it's not degrees but radians. + """ + if isinstance(rot_axis, str): + rot_axis = np.isin(np.arange(3), axis_str_to_list(rot_axis)).astype(int) + rot_axis = np.asarray(rot_axis) + rot_axis = rot_axis / np.linalg.norm(rot_axis) + + if degrees == "degrees" or degrees == True: + rot_angle = np.deg2rad(rot_angle) + + rotation_matrix = np.eye(4) + rotation_matrix[:3, :3] = R.from_rotvec(rot_axis * rot_angle).as_matrix() + if mode == "extrinsic": + rotated_pose = rotation_matrix @ pose_to_change + elif mode == "intrinsic": + rotated_pose = pose_to_change @ rotation_matrix + else: + log_error(f"Mode {mode} illegal.") + return rotated_pose + + +def get_rotation_replaced_pose( + pose_to_change: np.ndarray, + rotation_value: Union[float, List], + rot_axis: Union[str, List] = "z", + mode: str = "extrinsic", + degrees: Union[bool, str] = None, +): + if isinstance(rotation_value, (float, int, np.number)): + replaced_rotation_matrix = get_rotated_pose( + np.eye(4), rotation_value, rot_axis, mode, degrees + )[:3, :3] + elif isinstance(rotation_value, list): + rotation_value = np.asarray(rotation_value) + if rotation_value.shape == (3, 3): + replaced_rotation_matrix == rotation_value + elif rotation_value.shape == (3,): + log_warning( + f'Getting shape (3,) rotation_value {rotation_value} for "rotreplace", make sure it\'s rpy.' + ) + replaced_rotation_matrix = R.from_euler("xyz", rotation_value).as_matrix() + elif rotation_value.shape == (4,): + log_warning( + f'Getting shape (4,) rotation_value {rotation_value} for "rotreplace", make sure it\'s xyzw quaternion.' + ) + replaced_rotation_matrix = R.from_quat(rotation_value).as_matrix() + else: + log_error( + f'rotation_value has shape {rotation_value.shape}, not suppoorted by "rotreplace".' + ) + else: + log_error( + f'rotation_value has type {type(rotation_value)}, not suppoorted by "rotreplace".' + ) + rotation_replaced_pose = deepcopy(pose_to_change) + rotation_replaced_pose[:3, :3] = replaced_rotation_matrix + return rotation_replaced_pose + + +def get_frame_changed_pose( + pose_to_change: np.ndarray, + frame_change_matrix: Union[List, np.ndarray], + mode: bool = "extrinsic", + inverse: Union[bool, str] = False, +): + if isinstance(frame_change_matrix, list): + frame_change_matrix = np.asarray(frame_change_matrix) + if not isinstance(frame_change_matrix, np.ndarray): + log_error( + f'frame_change_matrix has type{type(frame_change_matrix)} other than np.ndarray, not suppoorted by "get_frame_changed_pose".' + ) + else: + if frame_change_matrix.shape != (4, 4): + log_error( + f'frame_change_matrix has shape {frame_change_matrix.shape} other than (4,4), not suppoorted by "get_frame_changed_pose".' + ) + + if inverse == "inverse" or inverse == True: + frame_change_matrix = inv_transform(frame_change_matrix) + + if mode == "extrinsic": + pose_to_change = frame_change_matrix @ pose_to_change + elif mode == "intrinsic": + pose_to_change = pose_to_change @ frame_change_matrix + else: + log_error(f"Mode {mode} illegal.") + + return pose_to_change + + +def get_aligned_pose( + pose_to_change: np.ndarray, + align_vector: List, + pose_axis: str = "z", +): + align_vector = np.asarray(align_vector) + pose_axis = axis_idx(pose_axis) + rotation_axis = np.cross(pose_to_change[:3, pose_axis], align_vector) + rotation_axis_norm = np.linalg.norm(rotation_axis) + if rotation_axis_norm >= 1e-5: + rotation_axis = rotation_axis / rotation_axis_norm + rotation_angle = np.arccos(pose_to_change[:3, 2].dot(align_vector)) + pose_to_change[:3, :3] = ( + R.from_rotvec(rotation_axis * rotation_angle).as_matrix() + @ pose_to_change[:3, :3] + ) + return pose_to_change + + +# TODO: automatically routing,given kwargs automatically find the mode. +def get_changed_pose( + pose_to_change: np.ndarray, pose_changes: List[Tuple[str, Any]] = [] +): + """Change the single pose given the `pose_changes` that indicates how to change the pose. + + Args: + pose_to_change (np.ndarray): The pose to be changed. + pose_changes (List[Tuple[str, Any]], optional): The list contains tuples that [0] refer to pose change name that indicates the change mode and parameters, split by "_", e.g. "offset_${np.array([0.05, -0.10, 0.125])}". And [1] be the change value, e.g. "${env.affordance_datas[\"cup_move_pose\"][:2,3]}". Defaults to []. + + Returns: + pose_to_change (np.ndarray): The changed pose. + """ + for pose_change_name, pose_change_value in pose_changes: + change_partition = pose_change_name.split("_") + change_mode = change_partition[0] + if change_mode == "replace": + pose_to_change = get_replaced_pose( + pose_to_change, + pose_replace_value=pose_change_value, + axis_str_replace=change_partition[1], + ) + elif change_mode == "offset": + pose_to_change = get_offset_pose( + pose_to_change, [pose_change_value], *change_partition[1:] + ) + elif change_mode == "rotation": + pose_to_change = get_rotated_pose( + pose_to_change, + pose_change_value, + *change_partition[1:], + ) + elif change_mode == "rotreplace": + pose_to_change = get_rotation_replaced_pose( + pose_to_change, pose_change_value, *change_partition[1:] + ) + elif change_mode == "framechange": + pose_to_change = get_frame_changed_pose( + pose_to_change, pose_change_value, *change_partition[1:] + ) + elif change_mode == "align": + get_aligned_pose(pose_to_change, pose_change_value, change_partition[1]) + else: + # TODO + log_error(f"The {change_mode} change mode haven't realized yet!") + return pose_to_change + + +def get_replaced_qpos( + qpos_to_change: Union[np.ndarray, torch.Tensor], + replace_value: Union[float, List[float]], + joint_list_replace: List, +): + if not isinstance(replace_value, Sequence): + replace_value = [replace_value] + for joint, replace_quantity in zip(joint_list_replace, replace_value): + qpos_to_change[joint] = float(replace_quantity) + return qpos_to_change + + +def get_offset_qpos( + qpos_to_change: np.ndarray, + offset_value: Union[float, List[float]], + joint_list_offset: List, + degrees: Union[str, bool, List[int]] = None, +) -> np.ndarray: + if not isinstance(offset_value, Sequence): + offset_value = [offset_value] + + degrees_joint_list = [] + if degrees is not None: + if isinstance(degrees, str): + degrees_all = degrees == "degrees" + + if degrees_all: + degrees_joint_list = joint_list_offset + else: + degrees_joint_str = degrees.split("degrees")[1] + for degerees_joint_idx_str in degrees_joint_str: + degrees_joint_list.append(int(degerees_joint_idx_str)) + elif isinstance(degrees, bool): + degrees_all = degrees == True + if degrees_all: + degrees_joint_list = joint_list_offset + elif isinstance(degrees, list): + degrees_joint_list = degrees + + if not set(degrees_joint_list).issubset(set(joint_list_offset)): + log_error( + f"degrees_joint_list {degrees_joint_list}, not subset to joint_list_offset {joint_list_offset}." + ) + + for joint, offset_quantity in zip(joint_list_offset, offset_value): + if joint in degrees_joint_list: + offset_quantity = np.deg2rad(offset_quantity) + qpos_to_change[joint] += offset_quantity + + return qpos_to_change + + +def get_changed_qpos( + qpos_to_change: np.ndarray, qpos_changes: List[Tuple[str, Any]] = [], frame=None +): + """Change the single qpos given the `qpos_changes` that indicates how to change the qpos. + + Args: + qpos_to_change (np.ndarray): The qpos to be changed. + qpos_changes (Dict[str, Any], optional): The list contains tuples that [0] refer to pose change name that indicates the change mode and parameters, split by "_", e.g. "offset_123". And [1] be the change value, e.g. "[0.5, 0.6, 0.7]". Defaults to [] + + Returns: + qpos_to_change (np.ndarray): The changed qpos. + """ + if isinstance(qpos_to_change, torch.tensor): + qpos_to_change = np.asarray(qpos_to_change) + + for qpos_change_name, qpos_change_value in qpos_changes: + change_partition = qpos_change_name.split("_") + change_mode = change_partition[0] + if not isinstance(qpos_change_value, Sequence): + qpos_change_value = [qpos_change_value] + + joint_str_change = change_partition[1] + joint_list_change = [] + for joint_idx_str in joint_str_change: + joint_list_change.append(int(joint_idx_str)) + if len(qpos_change_value) != len(joint_list_change): + log_error( + f'The joints asked to be raplaced is "{joint_str_change}", but got {len(qpos_change_value)} changes quantity.' + ) + + if change_mode == "replace": + qpos_to_change = get_offset_qpos( + qpos_to_change, + replace_value=qpos_change_value, + replace_joint_list=joint_list_change, + ) + elif change_mode == "offset": + qpos_to_change = get_offset_qpos( + qpos_to_change, + offset_value=qpos_change_value, + offset_joint_list=joint_list_change, + degrees=change_partition[2], + ) + else: + log_error(f"The {change_mode} change mode haven't realized yet!") + return qpos_to_change + + +def pose_shift(pose_in_cam: np.ndarray, axis: int, shift: float) -> np.ndarray: + shift_pose = np.copy(pose_in_cam) + shift_pose = np.linalg.inv(shift_pose) + shift_pose[axis, -1] += shift + shift_pose = np.linalg.inv(shift_pose) + return shift_pose + + +def expand_pose( + pose: np.ndarray, + x_interval: float, + y_interval: float, + kpnts_number: int, + grab_ratios: List = None, + ref_pose: np.ndarray = np.eye(4), +): + ret = [ref_pose @ pose] + + xoffset = np.linspace(-x_interval, x_interval, kpnts_number) + + if grab_ratios is None: + yoffset = np.linspace(-y_interval, y_interval, kpnts_number) + else: + yoffset = np.linspace( + -y_interval * grab_ratios[0], + y_interval * grab_ratios[1], + kpnts_number, + ) + + x_expand = [ref_pose @ pose_shift(pose, 0, x) for x in list(xoffset)] + y_expand = [ref_pose @ pose_shift(pose, 1, y) for y in list(yoffset)] + return ret + x_expand + y_expand + + +def parse_mask_by_uuids( + mask: np.ndarray, uuids: Dict[str, int] +) -> Dict[str, np.ndarray]: + from embodichain.data.enum import SemanticMask + + if len(uuids.keys()) == 0: + return {} + + robot_mask = np.zeros_like(mask, dtype=np.bool_) + robot_uuids = uuids["robot"] + for uuid in robot_uuids: + robot_mask[mask == uuid] = True + + object_uuids = uuids.get("object", []) + foreground_mask = np.zeros_like(mask, dtype=np.bool_) + for uuid in object_uuids: + foreground_mask[mask == uuid] = True + + background_mask = np.logical_not(np.logical_or(robot_mask, foreground_mask)) + + ret_masks = np.tile( + np.expand_dims(np.zeros_like(robot_mask), -1), [1, 1, len(SemanticMask)] + ) + ret_masks[:, :, SemanticMask.ROBOT.value] = robot_mask + ret_masks[:, :, SemanticMask.BACKGROUND.value] = background_mask + ret_masks[:, :, SemanticMask.FOREGROUND.value] = foreground_mask + + return ret_masks + + +def project_3d_to_2d( + cam, poses: List[np.ndarray], normalize: bool = True +) -> np.ndarray: + import dexsim.utility as dexutils + + cam_pose = cam._camera.get_world_pose() + cam_pose = dexutils.change_coordinate_stystem(cam_pose, ["X", "-Y", "-Z"]) + + intrinsic = cam._camera.get_intrinsic() + height = cam._camera.get_height() + width = cam._camera.get_width() + keypoints = [] + for pose in poses: + pose_ = np.matmul(np.linalg.inv(cam_pose), pose) + rvec, tvec = cv2.Rodrigues(np.eye(3))[0], np.zeros((3, 1)) + image_points, _ = cv2.projectPoints( + pose_[:3, -1], + rvec, + tvec, + np.array(intrinsic).reshape(3, 3), + np.zeros( + 5, + ), + ) + keypoints.append(image_points.reshape(1, 2)) + + keypoints = np.concatenate(keypoints, 0) + keypoints[:, 0] = np.clip(keypoints[:, 0], 0, width - 1) + keypoints[:, 1] = np.clip(keypoints[:, 1], 0, height - 1) + + if normalize: + keypoints[:, 0] = keypoints[:, 0] / width + keypoints[:, 1] = keypoints[:, 1] / height + return keypoints + + +def camel_to_snake(name): + # Insert underscores before each uppercase letter and convert to lowercase + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + snake_case = re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() + return snake_case + + +def print_keys_recursively(h5file: h5py.Group, path="/"): + """ + Recursively prints the keys in the HDF5 file. + + :param h5file: An open h5py File object. + :param path: The current path in the HDF5 file. + """ + for key in h5file[path].keys(): + print(f"{path}{key}") + if isinstance(h5file[path + key], h5py.Group): + print_keys_recursively(h5file, path + key + "/") + + +def hdf5_to_dict(h5file: h5py.Group): + def recursive_dict(group): + result = {} + for key, item in group.items(): + if isinstance(item, h5py.Dataset): + result[key] = item[()] + elif isinstance(item, h5py.Group): + result[key] = recursive_dict(item) + return result + + return recursive_dict(h5file) + + +def extract_keys_hierarchically(d, assign_data_type: str = "list"): + result = {} + + for key, value in d.items(): + # If the value is a dictionary, recursively extract its keys + if isinstance(value, dict): + result[key] = extract_keys_hierarchically(value) + else: + if assign_data_type == "null": + result[key] = None + elif assign_data_type == "list": + result[key] = [] + else: + raise ValueError(f"Invalid assign_data_type: {assign_data_type}") + + return result + + +def get_file_list(path: str, ext: str): + file_list = [] + for root, dirs, files in os.walk(path): + for file in files: + if file.endswith(ext): + file_list.append(os.path.join(root, file)) + + return file_list + + +# Pattern to recognize an attribute or indexer sequence, e.g., foo, ["bar"], [0] +_TOKEN_RE = re.compile( + r"""( + (?P[A-Za-z_]\w*) # attribute name + | \[\s*(?P"[^"]*"|'[^']*'|\d+)\s*\] # bracket indexer with quoted key or integer +)""", + re.VERBOSE, +) + + +def resolve_env_attr(obj: Any, env: Any) -> Any: + """ + Recursively replace any string of the form 'env:...' by evaluating it as a Python expression on the given `env` object. + Other containers (mappings, sequences) will be traversed and resolved element-wise. + + Supports: + - Arbitrary attribute access (e.g. env.x.y.z) + - Arbitrary indexing and slicing (e.g. env.x["key"][1:4]) + - Any valid Python expression after the 'env:' prefix. + + Args: + obj: The object to resolve. If it's: + - A dict-like Mapping: each value is passed back into resolve_env_attr. + - A Sequence (list/tuple/etc.) but not str: each element is resolved. + - A str starting with 'env:': the suffix is treated as a Python + expression relative to `env` and eval'ed. + - Anything else: returned unchanged. + env: An object whose attributes, methods, indices, etc. may be + referenced in the 'env:' expressions. + + Returns: + The resolved object, with 'env:' strings replaced by their eval results. + """ + # 1) If it's a mapping, recurse into its values + if isinstance(obj, Mapping): + return {k: resolve_env_attr(v, env) for k, v in obj.items()} + + # 2) If it's a non-str sequence, recurse into its elements + if isinstance(obj, Sequence) and not isinstance(obj, str): + return type(obj)(resolve_env_attr(item, env) for item in obj) + + # 3) If it's a string starting with "env.", eval it directly + if isinstance(obj, str) and obj.startswith("env."): + return eval(obj, {}, {"env": env}) + + # 4) Everything else passes through unchanged + return obj + + +_EXPR = re.compile(r"\$\{([^}]+)\}") # For searching ${...} marker + + +def resolve_formatted_string(obj, local_vars=None, global_vars=None): + """Given a dict carrys "${...}"-like strings , `eval` the "${...}$" values while keep the dict structure. + + Args: + obj (Union[Dict, Sequence]): The original "Grand" dict or the iterables in it. + local_vars (_type_, optional): _description_. Defaults to None. + global_vars (_type_, optional): _description_. Defaults to None. + + Returns: + _type_: _description_ + """ + # Gut the caller's locals & globals + if local_vars is None or global_vars is None: + frame = inspect.currentframe().f_back # caller frame + local_vars = local_vars or frame.f_locals + global_vars = global_vars or frame.f_globals + + # 1) dict + if isinstance(obj, Mapping): + return { + k: resolve_formatted_string(v, local_vars, global_vars) + for k, v in obj.items() + } + + # 2) list/tuple + if isinstance(obj, Sequence) and not isinstance(obj, str): + return type(obj)( + resolve_formatted_string(v, local_vars, global_vars) for v in obj + ) + + # 3) str + if isinstance(obj, str): + full = _EXPR.fullmatch(obj.strip()) + if full: + # the whole string is ${expr} -> return eval(expr) + return eval( + full.group(1), + {"__builtins__": None}, # eval with given locals & globals + {**global_vars, **local_vars}, + ) + # par tof the string is ${expr}:replace ...${expr}.. -> str(...eval(expr)..) + def _sub(m): + return str( + eval(m.group(1), {"__builtins__": None}, {**global_vars, **local_vars}) + ) + + return _EXPR.sub(_sub, obj) + + # 4) other type just return + return obj + + +def resolve_params(resolve_func): + """ + Decorator factory that applies `resolve_func` to each argument of the + decorated function, with optional per-decorator `exclude` names. + + If `resolve_func`'s signature is: resolve_func(obj) + then we call: resolve_func(val) + + If its signature is: resolve_func(obj, x, y, ...) + then for each argument `val` of the decorated function we call: resolve_func(val, x=bound['x'], y=bound['y'], ...) pulling `x`, `y`, etc. by name from the decorated function's bound args. + + Usage patterns: + + # 1) create a decorator + resolve_formatted_params = resolve_params(resolve_formatted_string) and use without exclude: + @resolve_formatted_params + def generate_func(a, b, c): ... + + # 2) use the same decorator with an exclude list: + @resolve_formatted_params(exclude=['c']) + def generate_func(a, b, c): ... + + # 3) or inline: + @resolve_params(resolve_env_attr, exclude=['env']) + def generate_func(env, path, mode): ... + + Args: + resolve_func: function whose first parameter is the value to transform. Any additional parameters will be looked up by name in the decorated function's arguments. + + Returns: + A decorator which can be used either as: + @decorator + or: + @decorator(exclude=[...]) + """ + resolve_sig = inspect.signature(resolve_func) + resolve_param_names = list(resolve_sig.parameters.keys()) + + def decorator_factory(*, exclude=()): + exclude = set(exclude) + + def decorator(func): + func_sig = inspect.signature(func) + + @wraps(func) + def wrapper(*args, **kwargs): + bound = func_sig.bind_partial(*args, **kwargs) + bound.apply_defaults() + + # Resolve each argument except those in exclude + resolved = {} + for name, val in bound.arguments.items(): + if name in exclude: + resolved[name] = val + continue + + try: + if len(resolve_param_names) == 1: + # single-arg resolver + resolved_val = resolve_func(val) + else: + # multi-arg resolver: gather extra args by name + extra_kwargs = { + pname: bound.arguments[pname] + for pname in resolve_param_names[1:] + } + resolved_val = resolve_func(val, **extra_kwargs) + resolved[name] = resolved_val + except Exception as e: + log_error(f"{e}") + resolved[name] = val + + # Rebuild positional and keyword args in original order + args_to_pass = [] + kwargs_to_pass = {} + for param in func_sig.parameters.values(): + if param.kind in ( + param.POSITIONAL_ONLY, + param.POSITIONAL_OR_KEYWORD, + ): + if param.name in resolved: + args_to_pass.append(resolved.pop(param.name)) + elif param.kind is param.VAR_POSITIONAL: + args_to_pass.extend(resolved.pop(param.name, ())) + elif param.kind is param.KEYWORD_ONLY: + if param.name in resolved: + kwargs_to_pass[param.name] = resolved.pop(param.name) + elif param.kind is param.VAR_KEYWORD: + kwargs_to_pass.update(resolved.pop(param.name, {})) + + return func(*args_to_pass, **kwargs_to_pass) + + return wrapper + + return decorator + + def decorator_or_factory(func=None, *, exclude=()): + # @decorator + if func is not None and callable(func): + return decorator_factory(exclude=())(func) + # @decorator(exclude=[...]) + return decorator_factory(exclude=exclude) + + return decorator_or_factory + + +resolve_formatted_params = resolve_params(resolve_formatted_string) +resolve_env_params = resolve_params(resolve_env_attr) + + +def transfer_str_to_lambda( + lambda_string: str, locals_dict: Dict = {}, globals_dict: Dict = {} +): + """Transfer the string represented lambda function into a real lambda function. + + Args: + lambda_string (str): The lambda string to be transfer + locals_dict (dict): Read-only dict that carrys local variables for lambda function to use. Defaults to be {}. + globals_dict (dict): Read-only dict that carrys global variables for lambda function to use. Defaults to be {}. + + Returns: + lambda_function: The lambda function + """ + # AST analyze + node = ast.parse(lambda_string, mode="eval") + # Assure the top to be a lambda + if not isinstance(node.body, ast.Lambda): + log_error(f'The lambda string "{lambda_string}" is not illegal.') + # Compile to be a function object + code = compile(node, filename="", mode="eval") + # eval to be a real lambda function + return eval(code, locals_dict, globals_dict) + + +def find_function( + func_name: Union[str, Callable[..., Any]], + instances: List = [], + module_names: List[str] = [], +): + """ + Finds and returns a function by its name from a list of instances or module names. + + Args: + func_name (Union[str, Callable[..., Any]]): The name of the function to find, + or the function itself. + instances (List, optional): A list of instances to search for the function. + Defaults to an empty list. + module_names (List[str], optional): A list of module names to search for the function. + Defaults to an empty list. + + Returns: + Callable[..., Any] or bool: The found function if it exists, otherwise False. + """ + if isinstance(func_name, str): + if any(hasattr(instance := inst, func_name) for inst in instances): + func = getattr(instance, func_name) + elif any( + hasattr((module := import_module(module_name)), func_name) + for module_name in module_names + ): + func = getattr(module, func_name) + else: + return False + else: + func = func_name + return func + + +def find_funcs_with_kwargs( + funcs_name_kwargs_proc: List[Dict[str, Any]], + instances: List, + module_names: List, +): + for func_name_kwargs_proc in funcs_name_kwargs_proc: + func_name = func_name_kwargs_proc["name"] + func = find_function( + func_name, + instances=instances, + module_names=module_names, + ) + if func != False: + func_name_kwargs_proc.update({"func": func}) + else: + log_warning(f"Function {func_name} not found, skipping...") + + return funcs_name_kwargs_proc + + +def validate_with_process( + env, + input: Any, + valid_funcs_kwargs_proc: List[Dict[str, Any]], +): + for valid_func_kwargs_proc in valid_funcs_kwargs_proc: + validation_func = valid_func_kwargs_proc["func"] + kwargs = valid_func_kwargs_proc["kwargs"] + rejected_processes = valid_func_kwargs_proc.get("rejected_processes", None) + pass_processes = valid_func_kwargs_proc.get("pass_processes", None) + + ret = validation_func(input, **kwargs) + if not ret: + log_warning( + f"Validation function {validation_func.__name__} returns False." + ) + if rejected_processes is not None: + log_warning("Processing with rejected_processes..") + for rejected_process in rejected_processes: + rejected_process_func_name = rejected_process["name"] + rejected_process_kwargs = rejected_process.get("kwargs", {}) + + rejected_process_func = find_function( + rejected_process_func_name, + instances=[env], + module_names=[ + __name__, + ], + ) + if rejected_process_func != False: + input = rejected_process_func(input, **rejected_process_kwargs) + else: + log_error( + f"rejected_process_func {rejected_process_func_name} after validation_func {validation_func.__name__} not found." + ) + else: + log_warning("Skipping..") + return None + + if pass_processes is not None: + for pass_process in pass_processes: + pass_process_func_name = pass_process["name"] + pass_process_kwargs = pass_process.get("kwargs", {}) + + pass_process_func = find_function( + pass_process_func_name, + instances=[env], + module_names=[ + __name__, + ], + ) + if pass_process_func != False: + input = pass_process_func(input, **pass_process_kwargs) + else: + log_error( + f"pass_process_func {pass_process_func_name} after validation_func {validation_func.__name__} not found." + ) + + return input + + +def validation_with_process_from_name( + env, + input: List[np.ndarray], + valid_funcs_name_kwargs_proc: List[Dict[str, Any]], + module_names: Optional[List[str]] = None, +): + """Apply a sequence of validation and processing functions (by name) to the input data. + + Args: + env: The environment object, used for method lookup. + input_data: The data to be validated and processed. + valid_funcs_name_kwargs_proc: List of dicts, each specifying a function name and kwargs. + module_names: List of module names to search for functions. Defaults to [__name__]. + + Returns: + The processed data if all validations pass, otherwise None. + """ + if valid_funcs_name_kwargs_proc is None: + valid_funcs_name_kwargs_proc = [] + if module_names is None: + module_names = [__name__] + + valid_funcs_kwargs_proc = find_funcs_with_kwargs( + valid_funcs_name_kwargs_proc, + instances=[env], + module_names=[__name__], + ) + valid_output = validate_with_process(env, input, valid_funcs_kwargs_proc) + return valid_output + + +def _get_valid_grasp( + env, + grasp_list: List[np.ndarray], + valid_funcs_name_kwargs_proc: List[Union[str, Dict[str, Any]]], +) -> np.ndarray: + """TODO 懒狗了,总而言之言而总之就是一个函数,可以集成一堆validation function,检验grasp_pose是否valid,也可以再特定的alidatrion function后面跟一堆process + + Args: + env: TODO + grasp_list (List[np.ndarray]): TODO + validation_func_names_kwargs (Dict[str, dict]): TODO + validation_func_names_process (Optional[Dict[str, Dict[str, dict]]], optional): TODO. Defaults to None. + + Returns: + np.ndarray: TODO + """ + valid_func_kwargs_proc = find_funcs_with_kwargs( + valid_funcs_name_kwargs_proc, instances=[env], module_names=[__name__] + ) + + for grasp in grasp_list: + grasp_pose = grasp.pose # TODO: be a func? + grasp_pose = validate_with_process(env, grasp_pose, valid_func_kwargs_proc) + # The loop is broken as ONE validation results is False + if grasp_pose is None: + continue + # All validation results are True in the loop + else: + return grasp_pose + return None + + +def lru_cache_n(maxsize: int = 10, max_count: int = 2) -> Callable: + """ + Decorator to provide an LRU cache with a maximum call count per key. + After a key is accessed `max_count` times, the result will be recomputed. + + Args: + maxsize: Maximum number of cache entries. + max_count: Number of times a cached result can be returned before recomputation. + + Returns: + Decorator for caching function results. + """ + + def decorator(func): + cache = OrderedDict() + + def _make_hashable(x): + try: + hash(x) + return x + except TypeError: + if isinstance(x, np.ndarray): + return (x.shape, str(x.dtype), x.tobytes()) + if isinstance(x, dict): + return tuple(sorted((k, _make_hashable(v)) for k, v in x.items())) + if isinstance(x, set): + return tuple(sorted(_make_hashable(i) for i in x)) + if isinstance(x, (list, tuple)): + return tuple(_make_hashable(i) for i in x) + raise TypeError(f"Unhashable type in cache key: {type(x)}") + + @wraps(func) + def wrapper(*args, **kwargs): + key = ( + tuple(_make_hashable(a) for a in args), + tuple(sorted((k, _make_hashable(v)) for k, v in kwargs.items())), + ) + res, cnt = cache.pop(key, (None, max_count)) + if cnt == max_count: + res, cnt = func(*args, **kwargs), 0 + cache[key] = (res, cnt + 1) + if len(cache) > maxsize: + cache.popitem(last=False) + return res + + return wrapper + + return decorator + + +def multi_output_factory_function( + func_name: Union[str, Callable], + instances: Optional[List] = None, + module_names: Optional[List[str]] = None, + output_num: int = 1, +) -> Callable: + """ + Factory to create a cached version of a function that may have multiple outputs. + + Args: + func_name: Function name (str) or function object. + instances: The instances that may carrys the method that match the func_name. + Usually be environment that carrys the methods wherein the function may be found. + module_names: The list of modules that the function may be found. Defaults to []. + output_num: Number of outputs expected from the function. + + Returns: + Cached function callable. + """ + if instances is None: + instances = [] + if module_names is None: + module_names = [] + func = find_function(func_name, instances, module_names) + if not callable(func): + raise ValueError(f"Function {func_name} not found or not callable.") + + max_count = max(1, output_num - 1) + + @lru_cache_n(max_count=max_count) + def cached_func(*args, **kwargs): + return func(*args, **kwargs) + + return cached_func + + +def cached_ik( + target_xpos: np.ndarray, + ik_func: Union[str, Callable], + control_part: str, + is_left: bool, + qpos_seed: np.ndarray, + instances: list = [], + module_names: list = [], +) -> tuple: + """ + Call the inverse kinematics (IK) function with caching for efficiency. + + Args: + target_xpos: The target end-effector position (usually a numpy array). + ik_func: The IK function or function name to be called. + control_part: String of the cotrol part for IK computing. + is_left: Whether the control part is on the left side. + qpos_seed: The initial guess for the joint positions. + instances: The instances that may carrys the method that match the func_name. + Usually be environment that carrys the methods wherein the function may be found. + module_names: The list of modules that the function may be found. Defaults to []. + + Returns: + Tuple: (ik_result, qpos_result), where ik_result is the IK status and qpos_result is the joint solution. + """ + # cached_ik_func = multi_output_factory_function("_get_arm_ik", instances=[env], module_names=[__name__], output_num=2) + cached_ik_func = multi_output_factory_function( + ik_func, instances=instances, module_names=module_names, output_num=2 + ) + if control_part == "none": + return cached_ik_func(target_xpos, is_left, qpos_seed) + + ret, qpos = cached_ik_func(torch.as_tensor(target_xpos), qpos_seed, control_part) + if isinstance(ret, torch.Tensor): + ret = ret.all().item() + return ret, qpos.squeeze(0).cpu().numpy() + + +def get_ik_ret( + target_xpos: np.ndarray, + ik_func: Union[str, Callable], + qpos_seed: np.ndarray, + control_part: str = "none", + is_left: bool = True, + instances: list = [], + module_names: list = [], +) -> bool: + """ + Get the first return value from the cached IK function, typically the IK status or result flag. + + Args: + target_xpos: The target end-effector position. + ik_func: The IK function or function name to be called. + control_part: String of the cotrol part for IK computing. + qpos_seed: The initial guess for the joint positions. + instances: The instances that may carrys the method that match the func_name. + Usually be environment that carrys the methods wherein the function may be found. + module_names: The list of modules that the function may be found. Defaults to []. + + Returns: + The first output of the IK function (e.g., success flag or status). + """ + ret = cached_ik( + target_xpos, + ik_func, + control_part, + is_left, + qpos_seed, + instances=instances, + module_names=module_names, + )[0] + return ret + + +def get_ik_qpos( + target_xpos: np.ndarray, + ik_func: Union[str, Callable], + qpos_seed: np.ndarray, + control_part: str = "none", + is_left: bool = True, + instances: list = [], + module_names: list = [], +) -> np.ndarray: + """ + Get the second return value from the cached IK function, typically the joint positions. + + Args: + target_xpos: The target end-effector position. + ik_func: The IK function or function name to be called. + control_part: String of the control part for IK computing. + is_left: Whether the control part is on the left side. Defaults to True. + qpos_seed: The initial guess for the joint positions. + instances: The instances that may carrys the method that match the func_name. + Usually be environment that carrys the methods wherein the function may be found. + module_names: The list of modules that the function may be found. Defaults to []. + + Returns: + The second output of the IK function (e.g., the joint position solution). + """ + qpos = cached_ik( + target_xpos, + ik_func, + control_part, + is_left, + qpos_seed, + instances=instances, + module_names=module_names, + )[1] + return qpos + + +def get_fk_xpos( + target_qpos: np.ndarray, + control_part: str, + fk_func: Union[str, Callable], +) -> np.ndarray: + xpos = fk_func(name=control_part, qpos=torch.as_tensor(target_qpos), to_matrix=True) + + # the xpos computed from robot is in the local arena frame, which is equivalent to world frame of the + # old version. + return xpos.squeeze(0).cpu().numpy() + + +# FIXME: remove +def _data_key_to_control_part(robot, control_parts, data_key: str) -> Optional[str]: + # TODO: Temporary workaround, should be removed after refactoring data dict extractor. + # @lru_cache(max_size=None) # NOTE: no way to pass a hashable parameter + def is_eef_hand(robot, control_parts) -> bool: + # TODO: This is a temporary workaround, should be used a more general method to check + # whether the end-effector is a hand. + for part in control_parts: + if "eef" in part: + joint_ids = robot.get_joint_ids(part, remove_mimic=True) + return len(joint_ids) >= 2 + return False + + if "left_arm" in data_key: + if "qpos" in data_key: + return "left_arm" + if "hand" in data_key and is_eef_hand(robot, control_parts): + return "left_eef" + if "gripper" in data_key and is_eef_hand(robot, control_parts) is False: + return "left_eef" + return None + + if "right_arm" in data_key: + if "qpos" in data_key: + return "right_arm" + if "hand" in data_key and is_eef_hand(robot, control_parts): + return "right_eef" + if "gripper" in data_key and is_eef_hand(robot, control_parts) is False: + return "right_eef" + return None + + +# FIXME: only for v3 W1 +def map_ee_state_to_env_actions( + robot, ee_state: np.ndarray, env_actions: np.ndarray +) -> np.ndarray: + """ + Map end-effector (gripper) state to environment joint actions. + + Args: + ee_state (np.ndarray): Normalized gripper state, shape (batch, 2). + env_actions (np.ndarray): Environment joint actions to be updated. + + Returns: + np.ndarray: Updated environment joint actions with gripper positions. + """ + from embodichain.data.enum import ControlParts, JointType, EndEffector + + left_eef_limits = ( + robot.body_data.qpos_limits.squeeze(0) + .cpu() + .numpy()[robot.get_joint_ids(name=ControlParts.LEFT_EEF.value)] + ) + right_eef_limits = ( + robot.body_data.qpos_limits.squeeze(0) + .cpu() + .numpy()[robot.get_joint_ids(name=ControlParts.RIGHT_EEF.value)] + ) + + def w1_gripper_mapping(normalized_state, eef_limits): + # Define normalized open/close positions for the gripper (range 0-1) + open_state_normalized = np.array([90.0, 0.0, 0.0, 0.0, 0.0, 0.0]) / 100.0 + close_state_normalized = np.array([90.0, 55.0, 30.0, 30.0, 30.0, 30.0]) / 100.0 + + # Convert normalized values to actual joint angles + open_state_actual = eef_limits[:, 0] + open_state_normalized * ( + eef_limits[:, 1] - eef_limits[:, 0] + ) + close_state_actual = eef_limits[:, 0] + close_state_normalized * ( + eef_limits[:, 1] - eef_limits[:, 0] + ) + + # Interpolate between open and close joint angles + if isinstance(normalized_state, np.ndarray) and normalized_state.ndim > 0: + return ( + open_state_actual * (1 - normalized_state[:, None]) + + close_state_actual * normalized_state[:, None] + ) + else: + return ( + open_state_actual * (1 - normalized_state) + + close_state_actual * normalized_state + ) + + if ee_state.ndim == 1: + ee_state = ee_state.reshape(1, -1) + if env_actions.ndim == 1: + env_actions = env_actions.reshape(1, -1) + + # Map normalized gripper state to actual joint positions + left_hand_qpos = w1_gripper_mapping(ee_state[:, 0], left_eef_limits) + right_hand_qpos = w1_gripper_mapping(ee_state[:, 1], right_eef_limits) + + # Get indices for left and right end-effector joints + left_eef_ids = robot.get_joint_ids(name=ControlParts.LEFT_EEF.value) + right_eef_ids = robot.get_joint_ids(name=ControlParts.RIGHT_EEF.value) + + env_actions[:, left_eef_ids] = left_hand_qpos + env_actions[:, right_eef_ids] = right_hand_qpos + return env_actions diff --git a/embodichain/lab/gym/utils/registration.py b/embodichain/lab/gym/utils/registration.py new file mode 100644 index 00000000..a3e9c61f --- /dev/null +++ b/embodichain/lab/gym/utils/registration.py @@ -0,0 +1,204 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import json +import sys +from copy import deepcopy +from functools import partial +from typing import TYPE_CHECKING, Dict, Type + +import gymnasium as gym +from gymnasium.envs.registration import EnvSpec as GymEnvSpec +from gymnasium.envs.registration import WrapperSpec + +from dexsim.utility import log_warning + +if TYPE_CHECKING: + from embodichain.lab.gym.envs import BaseEnv + + +class EnvSpec: + def __init__( + self, + uid: str, + cls: Type[BaseEnv], + max_episode_steps=None, + default_kwargs: dict = None, + ): + """A specification for a Embodied environment.""" + self.uid = uid + self.cls = cls + self.max_episode_steps = max_episode_steps + self.default_kwargs = {} if default_kwargs is None else default_kwargs + + def make(self, **kwargs): + _kwargs = self.default_kwargs.copy() + _kwargs.update(kwargs) + return self.cls(**_kwargs) + + @property + def gym_spec(self): + """Return a gym EnvSpec for this env""" + entry_point = self.cls.__module__ + ":" + self.cls.__name__ + return GymEnvSpec( + self.uid, + entry_point, + max_episode_steps=self.max_episode_steps, + kwargs=self.default_kwargs, + ) + + +REGISTERED_ENVS: Dict[str, EnvSpec] = {} + + +def register( + name: str, cls: Type[BaseEnv], max_episode_steps=None, default_kwargs: dict = None +): + """Register a Embodied environment.""" + + # hacky way to avoid circular import errors when users inherit a task in DexSim and try to register it themselves + from embodichain.lab.gym.envs import BaseEnv, BaseEnv + + if name in REGISTERED_ENVS: + log_warning(f"Env {name} already registered") + if not (issubclass(cls, BaseEnv) or issubclass(cls, BaseEnv)): + raise TypeError(f"Env {name} must inherit from BaseEnv or BaseEnv") + REGISTERED_ENVS[name] = EnvSpec( + name, cls, max_episode_steps=max_episode_steps, default_kwargs=default_kwargs + ) + + +class TimeLimitWrapper(gym.Wrapper): + """like the standard gymnasium timelimit wrapper but fixes truncated variable to be a batched array""" + + def __init__(self, env: gym.Env, max_episode_steps: int): + super().__init__(env) + prev_frame_locals = sys._getframe(1).f_locals + frame = sys._getframe(1) + # check for user supplied max_episode_steps during gym.make calls + if frame.f_code.co_name == "make" and "max_episode_steps" in prev_frame_locals: + if prev_frame_locals["max_episode_steps"] is not None: + max_episode_steps = prev_frame_locals["max_episode_steps"] + # do some wrapper surgery to remove the previous timelimit wrapper + # with gymnasium 0.29.1, this will remove the timelimit wrapper and nothing else. + curr_env = env + while curr_env is not None: + if isinstance(curr_env, gym.wrappers.TimeLimit): + self.env = curr_env.env + break + self._max_episode_steps = max_episode_steps + + @property + def base_env(self) -> BaseEnv: + return self.env.unwrapped + + def step(self, action): + observation, reward, terminated, truncated, info = self.env.step(action) + truncated = truncated | (self.base_env.elapsed_steps >= self._max_episode_steps) + return observation, reward, terminated, truncated, info + + +def make(env_id, **kwargs): + """Instantiate a Embodied environment. + + Args: + env_id (str): Environment ID. + as_gym (bool, optional): Add TimeLimit wrapper as gym. + **kwargs: Keyword arguments to pass to the environment. + """ + if env_id not in REGISTERED_ENVS: + raise KeyError("Env {} not found in registry".format(env_id)) + env_spec = REGISTERED_ENVS[env_id] + + env = env_spec.make(**kwargs) + return env + + +def make_vec(env_id, **kwargs): + env = gym.make(env_id, **kwargs) + return env + + +def register_env(uid: str, max_episode_steps=None, override=False, **kwargs): + """A decorator to register Embodied environments. + + Args: + uid (str): unique id of the environment. + max_episode_steps (int): maximum number of steps in an episode. + override (bool): whether to override the environment if it is already registered. + + Notes: + - `max_episode_steps` is processed differently from other keyword arguments in gym. + `gym.make` wraps the env with `gym.wrappers.TimeLimit` to limit the maximum number of steps. + - `gym.EnvSpec` uses kwargs instead of **kwargs! + """ + try: + json.dumps(kwargs) + except TypeError: + raise RuntimeError( + f"You cannot register_env with non json dumpable kwargs, e.g. classes or types. If you really need to do this, it is recommended to create a mapping of string to the unjsonable data and to pass the string in the kwarg and during env creation find the data you need" + ) + + def _register_env(cls): + cls = register_env_function(cls, uid, override, max_episode_steps, **kwargs) + return cls + + return _register_env + + +def register_env_function(cls, uid, override=False, max_episode_steps=None, **kwargs): + if uid in REGISTERED_ENVS: + if override: + from gymnasium.envs.registration import registry + + log_warning(f"Override registered env {uid}") + REGISTERED_ENVS.pop(uid) + registry.pop(uid) + else: + log_warning(f"Env {uid} is already registered. Skip registration.") + return cls + + # Register for ManiSkil2 + register( + uid, + cls, + max_episode_steps=max_episode_steps, + default_kwargs=deepcopy(kwargs), + ) + + # Register for gym + gym.register( + uid, + entry_point=partial(make, env_id=uid), + vector_entry_point=partial(make_vec, env_id=uid), + max_episode_steps=max_episode_steps, + disable_env_checker=True, # Temporary solution as we allow empty observation spaces + kwargs=deepcopy(kwargs), + additional_wrappers=[ + WrapperSpec( + "MSTimeLimit", + entry_point="embodichain.lab.gym.utils.registration:TimeLimitWrapper", + kwargs=dict(max_episode_steps=max_episode_steps) + if max_episode_steps is not None + else {}, + ) + ] + if max_episode_steps is not None + else [], + ) + return cls diff --git a/embodichain/lab/scripts/generate_video.py b/embodichain/lab/scripts/generate_video.py new file mode 100644 index 00000000..4b51b3dd --- /dev/null +++ b/embodichain/lab/scripts/generate_video.py @@ -0,0 +1,162 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from embodichain.utils.logger import log_info, log_warning + +import h5py +import argparse +import numpy as np +import os + +from tqdm import tqdm +from dexsim.utility import images_to_video +from typing import Dict, Callable, Tuple +from embodichain.utils.visualizer import draw_keypoints, draw_action_distribution +from embodichain.data.enum import EefType, JointType, Modality, PrivilegeType +from embodichain.data.data_engine.unified_state import ActionIndicesGenerator + + +class VideoCreator: + def __init__(self) -> None: + pass + + @staticmethod + def _sub_function( + images, + output_path, + video_key, + exteroceptions: Dict = None, + multiplier: int = 1, + drawer: Callable = lambda x: x, + ): + for key in images.keys(): + imgs = images[key] + if imgs is None: + log_warning(f"No images found for key: {key}. Skipping.") + continue + img_list = [] + for i in tqdm(range(imgs.shape[0])): + image_i = drawer(imgs[i] * multiplier) + if exteroceptions is not None and len(exteroceptions[key]) != 0: + image_i = draw_keypoints( + image_i, exteroceptions[key][i].reshape(-1, 2) + ) + img_list.append(image_i) + + images_to_video(img_list, output_path, f"{key}_{video_key}") + + @staticmethod + def monocular_save( + observations: Dict, + video_key: str, + output_path: str, + multiplier: int = 1, + drawer: Callable = lambda x: x, + draw_exteroception: bool = True, + ): + images = observations[video_key] + if ( + PrivilegeType.EXTEROCEPTION.value in observations.keys() + and draw_exteroception + ): + exteroceptions = observations[PrivilegeType.EXTEROCEPTION.value] + else: + exteroceptions = None + VideoCreator._sub_function( + images, + output_path, + video_key, + exteroceptions, + multiplier, + drawer, + ) + + +def visualize_data_dict(f: Dict, output_path: str): + observations = f["observations"] + + if PrivilegeType.MASK.value in observations.keys(): + VideoCreator.monocular_save( + observations, + PrivilegeType.MASK.value, + output_path, + 255, + draw_exteroception=False, + ) + + if Modality.GEOMAP.value in observations.keys(): + from embodichain.utils.img_utils import gen_disp_colormap + + VideoCreator.monocular_save( + observations, + Modality.GEOMAP.value, + output_path, + 1, + lambda x: (gen_disp_colormap(x).transpose(1, 2, 0) * 255).astype(np.uint8), + draw_exteroception=False, + ) + + VideoCreator.monocular_save(observations, Modality.IMAGES.value, output_path) + + +def main(args): + + data_path = args.data_path + output_path = args.output_path + assert data_path.endswith(".hdf5"), "Data path must have format of .hdf5" + with h5py.File(data_path, "r") as f: + from embodichain.data.data_engine.data_dict_extractor import ( + CompressedVideoHDF5, + ) + import hdfdict + + data = hdfdict.load(data_path) + data = CompressedVideoHDF5(output_path).safe_filter(data) + + visualize_data_dict(data, output_path) + robot_meta = data["robot_meta"] + arm_dofs = robot_meta["arm_dofs"][()] + indices_generator = ActionIndicesGenerator(arm_dofs) + + actions = f[Modality.ACTIONS.value][()] + key_names = indices_generator.global_mapping.mapping_from_name_to_indices.keys() + log_info(f"Arm dofs: {arm_dofs}", color="green") + indices_dict = {} + for key_name in key_names: + indices_dict[key_name] = indices_generator.get([key_name]) + draw_action_distribution(actions, indices_dict, output_path, smooth=args.smooth) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--data_path", type=str, help="Path to the data file.") + parser.add_argument( + "--output_path", + type=str, + help="Path to the output video file.", + default="./outputs", + ) + parser.add_argument( + "--smooth", + action="store_true", + default=False, + help="whether smooth joints.", + ) + args = parser.parse_args() + + main(args) diff --git a/embodichain/lab/scripts/preview_env.py b/embodichain/lab/scripts/preview_env.py new file mode 100644 index 00000000..9357584d --- /dev/null +++ b/embodichain/lab/scripts/preview_env.py @@ -0,0 +1,149 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch +import gymnasium +import argparse +import numpy as np + +from embodichain.lab.sim import SimulationManagerCfg +from embodichain.lab.gym.envs import EmbodiedEnvCfg +from embodichain.lab.gym.utils.gym_utils import ( + config_to_cfg, +) +from embodichain.utils.utility import load_json +from embodichain.utils import logger + +if __name__ == "__main__": + np.set_printoptions(precision=5, suppress=True) + torch.set_printoptions(precision=5, sci_mode=False) + + parser = argparse.ArgumentParser() + parser.add_argument( + "--num_envs", + help="The number of environments to run in parallel.", + default=1, + type=int, + ) + parser.add_argument( + "--device", + type=str, + default="cpu", + help="device to run the environment on, e.g., 'cpu' or 'cuda'", + ) + parser.add_argument( + "--headless", + help="Whether to perform the simulation in headless mode.", + default=False, + action="store_true", + ) + parser.add_argument( + "--enable_rt", + help="Whether to use RTX rendering backend for the simulation.", + default=False, + action="store_true", + ) + parser.add_argument( + "--gpu_id", + help="The GPU ID to use for the simulation.", + default=0, + type=int, + ) + parser.add_argument( + "--enable_sensors_in_step", + help="Whether to enable sensors in each step of the simulation.", + default=False, + action="store_true", + ) + parser.add_argument("--gym_config", type=str, help="gym_config", default="") + parser.add_argument( + "--action_config", + type=str, + help="Path to the action configuration file.", + default=None, + ) + parser.add_argument( + "--filter_visual_rand", + help="Whether to filter out visual randomization.", + default=False, + action="store_true", + ) + + args = parser.parse_args() + + """ + TODO: Currently, this file is only used to preview the template.json config file. + We may add more features to support more general case parsing from config files. + """ + + ############################################################################################## + # load gym config + gym_config = load_json(args.gym_config) + cfg: EmbodiedEnvCfg = config_to_cfg(gym_config) + cfg.filter_visual_rand = args.filter_visual_rand + + action_config = {} + if args.action_config is not None: + action_config = load_json(args.action_config) + action_config["action_config"] = action_config + + cfg.num_envs = args.num_envs + cfg.sim_cfg = SimulationManagerCfg( + headless=args.headless, + sim_device=args.device, + enable_rt=args.enable_rt, + gpu_id=args.gpu_id, + ) + + env = gymnasium.make(id=gym_config["id"], cfg=cfg, **action_config) + + if args.enable_sensors_in_step is False: + pass + + obs, info = env.reset() + + """ + Run the following code to create a demonstration and perform env steps. + + ``` + # Demo version of environment rollout + for i in range(10): + qpos = env.robot.get_qpos() + + obs, reward, done, truncated, info = env.step(qpos) + + # reset the environment + env.reset() + ``` + + Run the following code to preview the sensor observations. + + ``` + env.preview_sensor_data("camera") + ``` + """ + + end = False + while end is False: + print("Press `p` to into embed mode to interact with the environment.") + print("Press `q` to quit the simulation.") + txt = input() + if txt == "p": + from IPython import embed + + embed() + elif txt == "q": + end = True diff --git a/embodichain/lab/scripts/run_env.py b/embodichain/lab/scripts/run_env.py new file mode 100644 index 00000000..5fcb07fa --- /dev/null +++ b/embodichain/lab/scripts/run_env.py @@ -0,0 +1,301 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import gymnasium +import numpy as np +import argparse +import os +import torch + +from threading import Thread + +from embodichain.utils.utility import load_json +from embodichain.lab.sim import SimulationManagerCfg +from embodichain.lab.gym.envs import EmbodiedEnvCfg +from embodichain.lab.gym.utils.gym_utils import ( + config_to_cfg, +) +from embodichain.lab.scripts.generate_video import visualize_data_dict +from embodichain.data.data_engine.online.online_generator import ( + OnlineGenerator, +) +from embodichain.utils.logger import log_warning, log_info, log_error +from embodichain.lab.sim.cfg import MarkerCfg + + +def generate_and_execute_action_list(env, idx, debug_mode): + + action_list = env.create_demo_action_list(action_sentence=idx) + + # TODO: To be modified. + # if debug_mode: + # env.visual_action(action_list) + + if action_list is None or len(action_list) == 0: + log_warning("Action is invalid. Skip to next generation.") + return False + + for action in action_list: + # Step the environment with the current action + obs, reward, terminated, truncated, info = env.step(action) + + # TODO: To be modified. + # if debug_mode: + # xpos_dict = env.agent.get_debug_xpos_dict() + + # for key, val in xpos_dict.items(): + # env.scene.draw_marker(cfg=MarkerCfg( + # marker_type="axis", + # axis_xpos=val, + # axis_size=0.002, + # axis_len=0.005 + # )) + + # for key, val in xpos_dict.items(): + # env.scene.remove_fixed_actor(key) + + # TODO: we may assume in export demonstration rollout, there is no truncation from the env. + # but truncation is useful to improve the generation efficiency. + + return True + + +def generate_function( + env, + obj_num, + time_id: int = 0, + online_training: bool = False, + save_path: str = "", + save_video: bool = False, + debug_mode: bool = False, + **kwargs, +): + """ + Generate and execute a sequence of actions in the environment. + + This function resets the environment, generates and executes action trajectories, + collects data, and optionally saves videos of the episodes. It supports both online + and offline data generation modes. + + Args: + env: The environment instance. + obj_num (int): Number of trajectories to generate per episode. + time_id (int, optional): Identifier for the current time step or episode. + online_training (bool, optional): Whether to use online data generation. + save_path (str, optional): Path to save generated videos. + save_video (bool, optional): Whether to save episode videos. + debug_mode (bool, optional): Enable debug mode for visualization and logging. + **kwargs: Additional keyword arguments for data generation. + + Returns: + list or bool: Returns a list of data dicts if online_training is True, + otherwise returns True if generation is successful. + """ + + def wait_for_threads(threads): + for t in threads: + t.join() + + vis_threads = [] + valid = True + while True: + _, _ = env.reset() + + ret = [] + for trajectory_idx in range(obj_num): + valid = generate_and_execute_action_list(env, trajectory_idx, debug_mode) + + if not valid: + log_warning("Invalid action, skipping trajectory.") + break + + if not debug_mode and env.is_task_success().item(): + # Create a unique identifier for the dataset entry + dataset_id = f"time_{time_id}_trajectory_{trajectory_idx}" + if online_training: + dataset_id += "_online_generated" + num_samples = kwargs.get("num_samples", 0) + is_save_dataset = time_id < num_samples + + data_dict = env.to_dataset( + id=dataset_id if is_save_dataset else None, + ) + + ret.append(data_dict) + else: + data_dict = env.to_dataset( + id=dataset_id, + ) + + episode = getattr(env, "get_current_episode", lambda: time_id)() + + if save_video: + video_path = os.path.join(save_path, f"episode_{episode}") + if online_training: + vis_thread = Thread( + target=visualize_data_dict, + args=(data_dict["data"], video_path), + daemon=True, + ) + vis_thread.start() + vis_threads.append(vis_thread) + else: + visualize_data_dict(data_dict["data"], video_path) + + else: + log_warning(f"Task fail, Skip to next generation.") + valid = False + break + + if valid: + break + else: + log_warning("Reset valid flag to True.") + valid = True + + wait_for_threads(vis_threads) + return ret if online_training else True + + +def main(args, env, gym_config): + is_online_training = os.path.exists(args.online_config) + if is_online_training: + + log_info("Start online data generation.", color="green") + assert os.path.exists(args.online_config), "{} does not exist.".format( + args.online_config + ) + + online_config = load_json(args.online_config) + online_callback = OnlineGenerator(**online_config) + + obj_num = 1 + generator_func = lambda time_id, **kwargs: generate_function( + env, + obj_num, + time_id, + online_training=is_online_training, + save_path=args.save_path, + save_video=args.save_video, + headless=args.headless, + **kwargs, + ) + online_callback.generator(generator_func, **online_config) + else: + log_info("Start offline data generation.", color="green") + obj_num = 1 + for i in range(gym_config["max_episodes"]): + generate_function( + env, + obj_num, + i, + online_training=is_online_training, + save_path=args.save_path, + save_video=args.save_video, + debug_mode=args.debug_mode, + ) + + if args.headless: + env.reset(options={"final": True}) + + +if __name__ == "__main__": + np.set_printoptions(5, suppress=True) + torch.set_printoptions(precision=5, sci_mode=False) + + parser = argparse.ArgumentParser() + # parser.add_argument("--task_type", help="Type of task to perform.") + # parser.add_argument("--robot_name", help="Name of the robot.") + parser.add_argument( + "--num_envs", + help="The number of environments to run in parallel.", + default=1, + type=int, + ) + parser.add_argument( + "--device", + type=str, + default="cpu", + help="device to run the environment on, e.g., 'cpu' or 'cuda'", + ) + parser.add_argument( + "--headless", + help="Whether to perform the simulation in headless mode.", + default=False, + action="store_true", + ) + parser.add_argument( + "--enable_rt", + help="Whether to use RTX rendering backend for the simulation.", + default=False, + action="store_true", + ) + parser.add_argument( + "--gpu_id", + help="The GPU ID to use for the simulation.", + default=0, + type=int, + ) + parser.add_argument( + "--save_video", + help="Whether to save data as video.", + default=False, + action="store_true", + ) + parser.add_argument( + "--save_path", help="path", default="./outputs/thirdviewvideo", type=str + ) + parser.add_argument( + "--debug_mode", + help="Enable debug mode.", + default=False, + action="store_true", + ) + parser.add_argument( + "--filter_visual_rand", + help="Whether to filter out visual randomization.", + default=False, + action="store_true", + ) + + parser.add_argument("--online_config", type=str, help="online_config", default="") + parser.add_argument("--gym_config", type=str, help="gym_config", default="") + parser.add_argument("--action_config", type=str, help="action_config", default=None) + + args = parser.parse_args() + + if args.num_envs != 1: + log_error(f"Currently only support num_envs=1, but got {args.num_envs}.") + + gym_config = load_json(args.gym_config) + cfg: EmbodiedEnvCfg = config_to_cfg(gym_config) + cfg.filter_visual_rand = args.filter_visual_rand + + action_config = {} + if args.action_config is not None: + action_config = load_json(args.action_config) + action_config["action_config"] = action_config + + cfg.num_envs = args.num_envs + cfg.sim_cfg = SimulationManagerCfg( + headless=args.headless, + sim_device=args.device, + enable_rt=args.enable_rt, + gpu_id=args.gpu_id, + ) + + env = gymnasium.make(id=gym_config["id"], cfg=cfg, **action_config) + main(args, env, gym_config) diff --git a/embodichain/lab/sim/__init__.py b/embodichain/lab/sim/__init__.py new file mode 100644 index 00000000..9ec8105a --- /dev/null +++ b/embodichain/lab/sim/__init__.py @@ -0,0 +1,19 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from .material import VisualMaterialCfg, VisualMaterial, VisualMaterialInst +from .common import BatchEntity +from .sim_manager import * diff --git a/embodichain/lab/sim/cfg.py b/embodichain/lab/sim/cfg.py new file mode 100644 index 00000000..396c71f3 --- /dev/null +++ b/embodichain/lab/sim/cfg.py @@ -0,0 +1,1051 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations +import os +import numpy as np +import torch + +from typing import Sequence, Union, Dict, Literal, List, Optional, Any +from dataclasses import field, MISSING + +from dexsim.types import ( + PhysicalAttr, + ActorType, + AxisArrowType, + AxisCornerType, + VoxelConfig, + SoftBodyAttr, + SoftBodyMaterialModel, +) +from embodichain.utils import configclass, is_configclass +from embodichain.data.constants import EMBODICHAIN_DEFAULT_DATA_ROOT +from embodichain.data import get_data_path +from embodichain.utils import logger +from embodichain.utils.utility import key_in_nested_dict + +from .shapes import ShapeCfg, MeshCfg + + +@configclass +class PhysicsCfg: + gravity: np.ndarray = field(default_factory=lambda: np.array([0, 0, -9.81])) + bounce_threshold: float = 2.0 + enable_pcm: bool = True + enable_tgs: bool = True + enable_ccd: bool = False + enable_enhanced_determinism: bool = False + enable_friction_every_iteration: bool = True + + length_tolerance: float = 0.05 + """The length tolerance for the simulation. + + Note: the larger the tolerance, the faster the simulation will be. + """ + speed_tolerance: float = 0.25 + """The speed tolerance for the simulation. + + Note: the larger the tolerance, the faster the simulation will be. + """ + + def to_dexsim_args(self) -> Dict[str, Any]: + """Convert to dexsim physics args dictionary.""" + args = { + "gravity": self.gravity.tolist(), + "bounce_threshold": self.bounce_threshold, + "enable_pcm": self.enable_pcm, + "enable_tgs": self.enable_tgs, + "enable_ccd": self.enable_ccd, + "enable_enhanced_determinism": self.enable_enhanced_determinism, + "enable_friction_every_iteration": self.enable_friction_every_iteration, + } + return args + + +@configclass +class MarkerCfg: + """Configuration for visual markers in the simulation. + + This class defines properties for creating visual markers such as coordinate frames, + lines, and points that can be used for debugging, visualization, or reference purposes + in the simulation environment. + """ + + name: str = "empty-mesh" + """Name of the marker for identification purposes.""" + marker_type: Literal["axis", "line", "point"] = "axis" + """Type of marker to display. Can be 'axis' (3D coordinate frame), 'line', or 'point'. (only axis supported now)""" + axis_xpos: List[np.ndarray] = None + """List of 4x4 transformation matrices defining the position and orientation of each axis marker.""" + axis_size: float = 0.002 + """Thickness/size of the axis lines in meters.""" + axis_len: float = 0.005 + """Length of each axis arm in meters.""" + line_color: List[float] = [1, 1, 0, 1.0] + """RGBA color values for the marker lines. Values should be between 0.0 and 1.0.""" + arrow_type: AxisArrowType = AxisArrowType.CONE + """Type of arrow head for axis markers (e.g., CONE, ARROW, etc.).""" + corner_type: AxisCornerType = AxisCornerType.SPHERE + """Type of corner/joint visualization for axis markers (e.g., SPHERE, CUBE, etc.).""" + arena_index: int = -1 + """Index of the arena where the marker should be placed. -1 means all arenas.""" + + +@configclass +class GPUMemoryCfg: + """A gpu memory configuration dataclass that neatly holds all parameters that configure physics GPU memory for simulation""" + + temp_buffer_capacity: int = 2**24 + """Increase this if you get 'PxgPinnedHostLinearMemoryAllocator: overflowing initial allocation size, increase capacity to at least %.' """ + max_rigid_contact_count: int = 2**19 + """Increase this if you get 'Contact buffer overflow detected'""" + max_rigid_patch_count: int = ( + 2**18 + ) # 81920 is DexSim default but most tasks work with 2**18 + """Increase this if you get 'Patch buffer overflow detected'""" + heap_capacity: int = 2**26 + found_lost_pairs_capacity: int = ( + 2**25 + ) # 262144 is DexSim default but most tasks work with 2**25 + found_lost_aggregate_pairs_capacity: int = 2**10 + total_aggregate_pairs_capacity: int = 2**10 + + +@configclass +class RigidBodyAttributesCfg: + """Physical attributes for rigid bodies. + + There are three parts of attributes that can be set: + 1. The dynamic properties, such as mass, damping, etc. + 2. The collision properties. + 3. The physics material properties. + """ + + mass: float = 1.0 + # set mass to 0 will use density to calculate mass. + density: float = 1000.0 + + angular_damping: float = 0.7 + linear_damping: float = 0.7 + max_depenetration_velocity: float = 10.0 + sleep_threshold: float = 0.001 + min_position_iters: int = 4 + min_velocity_iters: int = 1 + + max_linear_velocity: float = 1e2 + max_angular_velocity: float = 1e2 + + # collision properties. + enable_ccd: bool = False + contact_offset: float = 0.002 + rest_offset: float = 0.001 + enable_collision: bool = True + + # physics material properties. + restitution: float = 0.0 + dynamic_friction: float = 0.5 + static_friction: float = 0.5 + + def attr(self) -> PhysicalAttr: + """Convert to dexsim PhysicalAttr""" + attr = PhysicalAttr() + attr.mass = self.mass + attr.contact_offset = self.contact_offset + attr.rest_offset = self.rest_offset + attr.enable_collision = self.enable_collision + attr.dynamic_friction = self.dynamic_friction + attr.static_friction = self.static_friction + attr.angular_damping = self.angular_damping + attr.linear_damping = self.linear_damping + attr.sleep_threshold = self.sleep_threshold + attr.restitution = self.restitution + attr.enable_ccd = self.enable_ccd + attr.max_depenetration_velocity = self.max_depenetration_velocity + attr.min_position_iters = self.min_position_iters + attr.min_velocity_iters = self.min_velocity_iters + return attr + + @classmethod + def from_dict( + cls, init_dict: Dict[str, Union[str, float, int]] + ) -> RigidBodyAttributesCfg: + """Initialize the configuration from a dictionary.""" + cfg = cls() + for key, value in init_dict.items(): + if hasattr(cfg, key): + setattr(cfg, key, value) + else: + logger.log_warning( + f"Key '{key}' not found in {cfg.__class__.__name__}." + ) + return cfg + + +@configclass +class SoftbodyVoxelAttributesCfg: + # voxel config + triangle_remesh_resolution: int = 8 + """Resolution to remesh the softbody mesh before building physx collision mesh.""" + + triangle_simplify_target: int = 0 + """Simplify mesh faces to target value. Do nothing if this value is zero.""" + + # TODO: this value will be automatically computed with simulation_mesh_resolution and mesh scale. + maximal_edge_length: float = 0 + # """To shorten edges that are too long, additional points get inserted at their center leading to a subdivision of the input mesh. Do nothing if this value is zero.""" + + simulation_mesh_resolution: int = 8 + """Resolution to build simulation voxelize textra mesh. This value must be greater than 0.""" + + simulation_mesh_output_obj: bool = False + """Whether to output the simulation mesh as an obj file for debugging.""" + + def attr(self) -> VoxelConfig: + """Convert to dexsim VoxelConfig""" + attr = VoxelConfig() + attr.triangle_remesh_resolution = self.triangle_remesh_resolution + attr.maximal_edge_length = self.maximal_edge_length + attr.simulation_mesh_resolution = self.simulation_mesh_resolution + attr.triangle_simplify_target = self.triangle_simplify_target + return attr + + +@configclass +class SoftbodyPhysicalAttributesCfg: + # material properties + youngs: float = 1e6 + """Young's modulus (higher = stiffer).""" + + poissons: float = 0.45 + """Poisson's ratio (higher = closer to incompressible).""" + + dynamic_friction: float = 0.0 + """Dynamic friction coefficient.""" + + elasticity_damping: float = 0.0 + """Elasticity damping factor.""" + + # soft body properties + material_model: SoftBodyMaterialModel = SoftBodyMaterialModel.CO_ROTATIONAL + """Material constitutive model.""" + + # --- Mode / collision switches --- + enable_kinematic: bool = False + """If True, (partially) kinematic behavior is enabled.""" + + enable_ccd: bool = False + """Enable continuous collision detection (CCD).""" + + enable_self_collision: bool = False + """Enable self-collision handling.""" + + has_gravity: bool = True + """Whether the soft body is affected by gravity.""" + + # --- Self-collision & simplification parameters --- + self_collision_stress_tolerance: float = 0.9 + """Stress tolerance threshold for self-collision constraints.""" + + collision_mesh_simplification: bool = True + """Whether to simplify the collision mesh for self-collision.""" + + self_collision_filter_distance: float = 0.1 + """Distance threshold below which vertex pairs may be filtered from self-collision checks.""" + + # --- Damping, sleep & settling --- + vertex_velocity_damping: float = 0.005 + """Per-vertex velocity damping.""" + + linear_damping: float = 0.0 + """Global linear damping applied to the soft body.""" + + sleep_threshold: float = 0.05 + """Velocity/energy threshold below which the soft body can go to sleep.""" + + settling_threshold: float = 0.1 + """Threshold used to decide convergence/settling state.""" + + settling_damping: float = 10.0 + """Additional damping applied during settling phase.""" + + # --- Mass / density & velocity limits --- + mass: float = -1.0 + """Total mass of the soft body. If set to a negative value, density will be used to compute mass.""" + + density: float = 1000.0 + """Material density in kg/m^3.""" + + max_depenetration_velocity: float = 1e6 + """Maximum velocity used to resolve penetrations. Must be larger than zero.""" + + max_velocity: float = 100 + """Clamp for linear (or vertex) velocity. If set to zero, the limit is ignored.""" + + # --- Solver iteration counts --- + min_position_iters: int = 4 + """Minimum solver iterations for position correction.""" + + min_velocity_iters: int = 1 + """Minimum solver iterations for velocity updates.""" + + def attr(self) -> SoftBodyAttr: + attr = SoftBodyAttr() + attr.youngs = self.youngs + attr.poissons = self.poissons + attr.dynamic_friction = self.dynamic_friction + attr.elasticity_damping = self.elasticity_damping + attr.material_model = self.material_model + attr.enable_kinematic = self.enable_kinematic + attr.enable_ccd = self.enable_ccd + attr.enable_self_collision = self.enable_self_collision + attr.has_gravity = self.has_gravity + attr.self_collision_stress_tolerance = self.self_collision_stress_tolerance + attr.collision_mesh_simplification = self.collision_mesh_simplification + attr.vertex_velocity_damping = self.vertex_velocity_damping + attr.mass = self.mass + attr.density = self.density + attr.max_depenetration_velocity = self.max_depenetration_velocity + attr.max_velocity = self.max_velocity + attr.self_collision_filter_distance = self.self_collision_filter_distance + attr.linear_damping = self.linear_damping + attr.sleep_threshold = self.sleep_threshold + attr.settling_threshold = self.settling_threshold + attr.settling_damping = self.settling_damping + attr.min_position_iters = self.min_position_iters + attr.min_velocity_iters = self.min_velocity_iters + return attr + + +@configclass +class JointDrivePropertiesCfg: + """Properties to define the drive mechanism of a joint.""" + + drive_type: Literal["force", "acceleration"] = "force" + """Joint drive type to apply. + + If the drive type is "force", then the joint is driven by a force and the acceleration is computed based on the force applied. + If the drive type is "acceleration", then the joint is driven by an acceleration and the force is computed based on the acceleration applied. + """ + + stiffness: Union[Dict[str, float], float] = 1e3 + """Stiffness of the joint drive. + + The unit depends on the joint model: + + * For linear joints, the unit is kg-m/s^2 (N/m). + * For angular joints, the unit is kg-m^2/s^2/rad (N-m/rad). + """ + + damping: Union[Dict[str, float], float] = 1e2 + """Damping of the joint drive. + + The unit depends on the joint model: + + * For linear joints, the unit is kg-m/s (N-s/m). + * For angular joints, the unit is kg-m^2/s/rad (N-m-s/rad). + """ + + max_effort: Union[Dict[str, float], float] = 1e10 + """Maximum effort that can be applied to the joint (in kg-m^2/s^2).""" + + max_velocity: Union[Dict[str, float], float] = 1e10 + """Maximum velocity that the joint can reach (in rad/s or m/s). + + For linear joints, this is the maximum linear velocity with unit m/s. + For angular joints, this is the maximum angular velocity with unit rad/s. + """ + + friction: Union[Dict[str, float], float] = 0.0 + """Friction coefficient of the joint""" + + @classmethod + def from_dict( + cls, init_dict: Dict[str, Union[str, float, int]] + ) -> JointDrivePropertiesCfg: + """Initialize the configuration from a dictionary.""" + cfg = cls() + for key, value in init_dict.items(): + if hasattr(cfg, key): + setattr(cfg, key, value) + else: + logger.log_warning( + f"Key '{key}' not found in {cfg.__class__.__name__}." + ) + return cfg + + +@configclass +class ObjectBaseCfg: + """Base configuration for an asset in the simulation. + + This class defines the basic properties of an asset, such as its type, initial state, and collision group. + It is used as a base class for specific asset configurations. + """ + + uid: Union[str, None] = None + + init_pos: tuple[float, float, float] = (0.0, 0.0, 0.0) + """Position of the root in simulation world frame. Defaults to (0.0, 0.0, 0.0).""" + + init_rot: tuple[float, float, float] = (0.0, 0.0, 0.0) + """Euler angles (in degree) of the root in simulation world frame. Defaults to (0.0, 0.0, 0.0).""" + + init_local_pose: Optional[np.ndarray] = None + """4x4 transformation matrix of the root in local frame. If specified, it will override init_pos and init_rot.""" + + @classmethod + def from_dict(cls, init_dict: Dict[str, Union[str, float, tuple]]) -> ObjectBaseCfg: + """Initialize the configuration from a dictionary.""" + cfg = cls() # Create a new instance of the class (cls) + for key, value in init_dict.items(): + if hasattr(cfg, key): + attr = getattr(cfg, key) + if is_configclass(attr): + setattr( + cfg, key, attr.from_dict(value) + ) # Call from_dict on the attribute + else: + setattr(cfg, key, value) + else: + logger.log_warning( + f"Key '{key}' not found in {cfg.__class__.__name__}." + ) + + # Automatically infer init_local_pose if not provided + if cfg.init_local_pose is None: + # If only init_pos or init_rot are provided, generate the 4x4 pose matrix + from scipy.spatial.transform import Rotation as R + + T = np.eye(4) + T[:3, 3] = np.array(cfg.init_pos) + T[:3, :3] = R.from_euler("xyz", np.deg2rad(cfg.init_rot)).as_matrix() + cfg.init_local_pose = T + else: + # If only init_local_pose is provided, extract init_pos and init_rot + from scipy.spatial.transform import Rotation as R + + T = np.array(cfg.init_local_pose) + cfg.init_pos = tuple(T[:3, 3]) + cfg.init_rot = tuple(R.from_matrix(T[:3, :3]).as_euler("xyz", degrees=True)) + + return cfg + + +@configclass +class LightCfg(ObjectBaseCfg): + """Configuration for a light asset in the simulation. + + This class extends the base asset configuration to include specific properties for lights, + """ + + # TODO: to be added more light type, such as spot, sun, etc. + light_type: Literal["point"] = "point" + + color: tuple[float, float, float] = (1.0, 1.0, 1.0) + + intensity: float = 50.0 + """Intensity of the light source with unit of watts/m^2.""" + + radius: float = 1e2 + """Falloff of the light, only used for point light.""" + + +@configclass +class RigidObjectCfg(ObjectBaseCfg): + """Configuration for a rigid body asset in the simulation. + + This class extends the base asset configuration to include specific properties for rigid bodies, + such as physical attributes and collision group. + """ + + shape: ShapeCfg = ShapeCfg() + """Shape configuration for the rigid body. """ + + # TODO: supoort basic primitive shapes, such as box, sphere, etc cfg and spawn method. + + attrs: RigidBodyAttributesCfg = RigidBodyAttributesCfg() + + body_type: Literal["dynamic", "kinematic", "static"] = "dynamic" + + max_convex_hull_num: int = 1 + """The maximum number of convex hulls that will be created for the rigid body. + + If `max_convex_hull_num` is set to larger than 1, the rigid body will be decomposed into multiple convex hulls using coacd alogorithm. + Reference: https://github.com/SarahWeiii/CoACD + """ + + body_scale: Union[tuple, list] = (1.0, 1.0, 1.0) + """Scale of the rigid body in the simulation world frame.""" + + def to_dexsim_body_type(self) -> ActorType: + """Convert the body type to dexsim ActorType.""" + if self.body_type == "dynamic": + return ActorType.DYNAMIC + elif self.body_type == "kinematic": + return ActorType.KINEMATIC + elif self.body_type == "static": + return ActorType.STATIC + else: + logger.log_error( + f"Invalid body type '{self.body_type}' specified. Must be one of 'dynamic', 'kinematic', or 'static'." + ) + + +@configclass +class SoftObjectCfg(ObjectBaseCfg): + """Configuration for a soft body asset in the simulation. + + This class extends the base asset configuration to include specific properties for soft bodies, + such as physical attributes and collision group. + """ + + voxel_attr: SoftbodyVoxelAttributesCfg = SoftbodyVoxelAttributesCfg() + """Tetra mesh voxelization attributes for the soft body.""" + + physical_attr: SoftbodyPhysicalAttributesCfg = SoftbodyPhysicalAttributesCfg() + """Physical attributes for the soft body.""" + + shape: MeshCfg = MeshCfg() + """Mesh configuration for the soft body.""" + + +@configclass +class RigidObjectGroupCfg: + """Configuration for a rigid object group asset in the simulation. + + Rigid object groups can be initialized from multiple rigid object configurations specified in a folder. + If `folder_path` is specified, user should provide a RigidObjectCfg in `rigid_objects` as a template configuration for + all objects in the group. + + For example: + ```python + rigid_object_group: RigidObjectGroupCfg( + folder_path="path/to/folder", + max_num=5, + rigid_objects={ + "template_obj": RigidObjectCfg( + shape=MeshCfg( + fpath="", # fpath will be ignored when folder_path is specified + ), + body_type="dynamic", + ) + } + ) + """ + + uid: Union[str, None] = None + + rigid_objects: Dict[str, RigidObjectCfg] = MISSING + """Configuration for the rigid objects in the group.""" + + body_type: Literal["dynamic", "kinematic"] = "dynamic" + """Body type for all rigid objects in the group. """ + + folder_path: Optional[str] = None + """Path to the folder containing the rigid object assets. + + This is used to initialize multiple rigid object configurations from a folder. + """ + + max_num: int = 1 + """Maximum number of rigid objects to initialize from the folder. + + This is only used when `folder_path` is specified. + """ + + ext: str = ".obj" + """File extension for the rigid object assets. + + This is only used when `folder_path` is specified. + """ + + @classmethod + def from_dict(cls, init_dict: Dict[str, Any]) -> RigidObjectGroupCfg: + """Initialize the configuration from a dictionary.""" + cfg = cls() + for key, value in init_dict.items(): + if hasattr(cfg, key): + attr = getattr(cfg, key) + if is_configclass(attr): + setattr( + cfg, key, attr.from_dict(value) + ) # Call from_dict on the attribute + elif key == "rigid_objects" and "folder_path" not in init_dict: + rigid_objects_cfg = {} + for obj_name, obj_cfg in value.items(): + rigid_objects_cfg[obj_name] = RigidObjectCfg.from_dict(obj_cfg) + setattr(cfg, key, rigid_objects_cfg) + elif key == "rigid_objects" and "folder_path" in init_dict: + folder_path = init_dict["folder_path"] + max_num = init_dict.get("max_num", 1) + rigid_objects_cfg = {} + if os.path.exists(folder_path) and os.path.isdir(folder_path): + files = os.listdir(folder_path) + files = [f for f in files if f.endswith(cfg.ext)] + # select files up to max_num + n_file = len(files) + select_files = [] + for i in range(max_num): + select_files.append(files[i % n_file]) + + for i, file_name in enumerate(select_files): + file_path = os.path.join(folder_path, file_name) + rigid_obj_cfg: RigidObjectCfg = RigidObjectCfg.from_dict( + list(init_dict["rigid_objects"].values())[0] + ) + rigid_obj_cfg.uid = f"{cfg.uid}_obj_{i}" + rigid_obj_cfg.shape.fpath = file_path + rigid_objects_cfg[rigid_obj_cfg.uid] = rigid_obj_cfg + setattr(cfg, "rigid_objects", rigid_objects_cfg) + else: + logger.log_error( + f"Folder '{folder_path}' does not exist or is not a directory." + ) + else: + setattr(cfg, key, value) + else: + logger.log_warning( + f"Key '{key}' not found in {cfg.__class__.__name__}." + ) + return cfg + + +@configclass +class URDFCfg: + """Standalone configuration class for URDF assembly.""" + + components: Dict[str, Dict[str, Union[str, Dict, np.ndarray]]] = field( + default_factory=dict + ) + """Dictionary of robot components to be assembled.""" + + sensors: Dict[str, Dict[str, Union[str, np.ndarray]]] = field(default_factory=dict) + """Dictionary of sensors to be attached to the robot.""" + + use_signature_check: bool = True + """Whether to use signature check when merging URDFs.""" + + base_link_name: str = "base_link" + """Name of the base link in the assembled robot.""" + + fpath: Optional[str] = None + """Full output file path for the assembled URDF. If specified, overrides fname and fpath_prefix.""" + + fname: Optional[str] = None + """Name used for output file and directory. If not specified, auto-generated from component names.""" + + fpath_prefix: str = EMBODICHAIN_DEFAULT_DATA_ROOT + "/assembled" + """Output directory prefix for the assembled URDF file.""" + + def __init__( + self, + components: Optional[List[Dict[str, Union[str, np.ndarray]]]] = None, + sensors: Optional[Dict[str, Dict[str, Union[str, np.ndarray]]]] = None, + fpath: Optional[str] = None, + fname: Optional[str] = None, + fpath_prefix: str = EMBODICHAIN_DEFAULT_DATA_ROOT + "/assembled", + use_signature_check: bool = True, + base_link_name: str = "base_link", + ): + """ + Initialize URDFCfg with optional list of components and output path settings. + + Args: + components (Optional[List[Dict]]): List of component configurations. Each dict should contain: + - 'component_type' (str): The type/name of the component (e.g., 'chassis', 'arm', 'hand'). + - 'urdf_path' (str): Path to the component's URDF file. + - 'transform' (Optional[np.ndarray]): 4x4 transformation matrix (optional). + - Additional params can be included as extra keys. + sensors (Optional[Dict]): Sensor configurations for the robot. + fpath (Optional[str]): Full output file path for the assembled URDF. If specified, overrides fname and fpath_prefix. + fname (Optional[str]): Name used for output file and directory. If not specified, auto-generated from component names. + fpath_prefix (str): Output directory prefix for the assembled URDF file. + use_signature_check (bool): Whether to use signature check when merging URDFs. + base_link_name (str): Name of the base link in the assembled robot. + """ + self.components = {} + self.sensors = sensors or {} + self.fpath = fpath + self.use_signature_check = use_signature_check + self.base_link_name = base_link_name + self.fname = fname + self.fpath_prefix = fpath_prefix + + # Auto-add components if provided + if components: + for comp_config in components: + if not isinstance(comp_config, dict): + logger.log_error( + f"Component configuration must be a dict, got {type(comp_config)}" + ) + continue + + # Extract required fields + component_type = comp_config.get("component_type") + urdf_path = comp_config.get("urdf_path") + + if not component_type or not urdf_path: + logger.log_error( + f"Component configuration must contain 'component_type' and 'urdf_path', got {comp_config}" + ) + continue + + # Extract optional fields + transform = comp_config.get("transform", np.eye(4)) + + # Extract additional params (exclude known keys) + params = { + k: v + for k, v in comp_config.items() + if k not in ["component_type", "urdf_path", "transform"] + } + + # Add the component + self.add_component(component_type, urdf_path, transform, **params) + + if sensors is not None: + if not isinstance(sensors, list): + logger.log_error( + f"sensors must be a list of dicts, got {type(sensors)}" + ) + self.sensors = [] + else: + # Optionally check each sensor dict + valid_sensors = [] + for sensor_config in sensors: + if not isinstance(sensor_config, dict): + logger.log_error( + f"Sensor configuration must be a dict, got {type(sensor_config)}" + ) + continue + sensor_name = sensor_config.get("sensor_name") + if not sensor_name: + logger.log_error( + f"Sensor configuration must contain 'sensor_name', got {sensor_config}" + ) + continue + valid_sensors.append(sensor_config) + self.sensors = valid_sensors + + def set_urdf(self, urdf_path: str) -> "URDFCfg": + """Directly specify a single URDF file for the robot, compatible with the single-URDF robot case. + + Args: + urdf_path (str): Path to the robot's URDF file. + + Returns: + URDFCfg: Returns self to allow method chaining. + """ + self.components.clear() + urdf_file = os.path.splitext(os.path.basename(urdf_path))[0] + self.components[urdf_file] = { + "urdf_path": urdf_path, + "transform": None, + "params": {}, + } + self.fpath = urdf_path + return self + + def add_component( + self, + component_type: str, + urdf_path: str, + transform: Optional[np.ndarray] = None, + **params, + ) -> URDFCfg: + """Add a robot component to the assembly configuration. + + Args: + component_type (str): The type/name of the component. Should be one of SUPPORTED_COMPONENTS + (e.g., 'chassis', 'torso', 'head', 'left_arm', 'right_hand', 'arm', 'hand', etc.). + urdf_path (str): Path to the component's URDF file. + transform (Optional[np.ndarray]): 4x4 transformation matrix for the component in the robot frame (default: None). + **params: Additional keyword parameters for the component (e.g., color, material, etc.). + + Returns: + URDFCfg: Returns self to allow method chaining. + """ + if urdf_path: + if not os.path.exists(urdf_path): + urdf_path_candidate = get_data_path(urdf_path) + if os.path.exists(urdf_path_candidate): + urdf_path = urdf_path_candidate + else: + logger.log_error(f"URDF path '{urdf_path}' does not exist.") + raise FileNotFoundError(f"URDF path '{urdf_path}' does not exist.") + + self.components[component_type] = { + "urdf_path": urdf_path, + "transform": np.array(transform), + "params": params, + } + + if self.fname: + self.fpath = f"{self.fpath_prefix}/{self.fname}/{self.fname}.urdf" + else: + # Update output_path to use all component urdf file names joined by underscores as directory + if len(self.components) == 1: + # Only one component, use its urdf file name + urdf_file = os.path.splitext(os.path.basename(urdf_path))[0] + name = urdf_file + else: + # Multiple components, join all urdf file names + urdf_files = [ + os.path.splitext(os.path.basename(v["urdf_path"]))[0] + for v in self.components.values() + ] + name = "_".join(urdf_files) + self.fpath = f"{self.fpath_prefix}/{name}/{name}.urdf" + + return self + + def add_sensor(self, sensor_name: str, **sensor_config) -> URDFCfg: + """Add a sensor to the robot configuration. + + Args: + sensor_name (str): The name of the sensor. + **sensor_config: Additional configuration parameters for the sensor. + + Returns: + URDFCfg: Returns self to allow method chaining. + """ + self.sensors.append({"sensor_name": sensor_name, **sensor_config}) + return self + + def assemble_urdf(self) -> str: + """Assemble URDF files for the robot based on the configuration. + + Returns: + str: The path to the resulting (possibly merged) URDF file. + """ + components = list(self.components.items()) + # If there is only one component, return its URDF path directly. + if len(components) == 1: + _, comp_config = components[0] + return comp_config["urdf_path"] + + from embodichain.toolkits.urdf_assembly import URDFAssemblyManager + + # If there are multiple components, merge them into a single URDF file. + manager = URDFAssemblyManager() + manager.base_link_name = self.base_link_name + for comp_type, comp_config in components: + params = comp_config.get("params", {}) + success = manager.add_component( + comp_type, + comp_config["urdf_path"], + comp_config.get("transform"), + **params, + ) + if not success: + logger.log_error( + f"Failed to add component '{comp_type}' with config: {comp_config}" + ) + + for sensor in self.sensors: + manager.attach_sensor( + sensor_name=sensor.get("sensor_name"), + sensor_source=sensor.get("sensor_source"), + parent_component=sensor.get("parent_component"), + parent_link=sensor.get("parent_link"), + sensor_type=sensor.get("sensor_type"), + **{ + k: v + for k, v in sensor.items() + if k + not in [ + "sensor_name", + "sensor_source", + "parent_component", + "parent_link", + "sensor_type", + ] + }, + ) + + try: + # Merge all added components into a single URDF file at the specified output path. + merged_urdf_xml = manager.merge_urdfs(self.fpath, self.use_signature_check) + except Exception as e: + logger.log_error(f"URDF merge failed: {e}") + + return self.fpath + + @classmethod + def from_dict(cls, init_dict: Dict) -> "URDFCfg": + if isinstance(init_dict, cls): + return init_dict + components = init_dict.get("components", None) + if isinstance(components, dict): + components = [{"component_type": k, **v} for k, v in components.items()] + sensors = init_dict.get("sensors", None) + fpath = init_dict.get("fpath", None) + use_signature_check = init_dict.get("use_signature_check", True) + base_link_name = init_dict.get("base_link_name", "base_link") + return cls( + components=components, + sensors=sensors, + fpath=fpath, + use_signature_check=use_signature_check, + base_link_name=base_link_name, + ) + + +@configclass +class ArticulationCfg(ObjectBaseCfg): + """Configuration for an articulation asset in the simulation. + + This class extends the base asset configuration to include specific properties for articulations, + such as joint drive properties, physical attributes. + """ + + fpath: str = None + """Path to the articulation asset file.""" + + drive_pros: JointDrivePropertiesCfg = JointDrivePropertiesCfg() + """Properties to define the drive mechanism of a joint.""" + + attrs: RigidBodyAttributesCfg = RigidBodyAttributesCfg() + """Physical attributes for all links . """ + + fix_base: bool = True + """Whether to fix the base of the articulation. + + Set to True for articulations that should not move, such as a fixed base robot arm or a door. + Set to False for articulations that should move freely, such as a mobile robot or a humanoid robot. + """ + + disable_self_collision: bool = True + """Whether to enable or disable self-collisions.""" + + init_qpos: Union[torch.Tensor, np.ndarray, Sequence[float]] = None + """Initial joint positions of the articulation. + + If None, the joint positions will be set to zero. + If provided, it should be a array of shape (num_joints,). + """ + + sleep_threshold: float = 0.005 + """Energy below which the articulation may go to sleep. Range: [0, max_float32]""" + + min_position_iters: int = 4 + """Number of position iterations the solver should perform for this articulation. Range: [1,255].""" + + min_velocity_iters: int = 1 + """Number of velocity iterations the solver should perform for this articulation. Range: [0,255].""" + + +@configclass +class RobotCfg(ArticulationCfg): + from embodichain.lab.sim.solvers import SolverCfg + + """Configuration for a robot asset in the simulation. + + # TODO: solver and motion planner may not be configurable inside the robot. + # But currently we put them here and could be moved if necessary. + """ + + control_parts: Union[Dict[str, List[str]], None] = None + """Control parts is the mapping from part name to joint names. + + For example, {'left_arm': ['joint1', 'joint2'], 'right_arm': ['joint3', 'joint4']} + If no control part is specified, the robot will use all joints as a single control part. + + Note: + - if `control_parts` is specified, `solver_cfg` must be a dict with part names as + keys corresponding to the control parts name. + - The joint names in the control parts support regular expressions, e.g., 'joint[1-6]'. + After initialization of robot, the names will be expanded to a list of full joint names. + """ + + urdf_cfg: Optional[URDFCfg] = None + """URDF assembly configuration which allows for assembling a robot from multiple URDF components. + """ + + # TODO: how to support one solver for multiple parts? + solver_cfg: Union[SolverCfg, Dict[str, SolverCfg], None] = None + """Solver is used to compute forward and inverse kinematics for the robot. + """ + + @classmethod + def from_dict(cls, init_dict: Dict[str, Union[str, float, tuple]]) -> RobotCfg: + """Initialize the configuration from a dictionary.""" + if isinstance(init_dict, cls): + return init_dict + + import importlib + + solver_module = importlib.import_module("embodichain.lab.sim.solvers") + + cfg = cls() # Create a new instance of the class (cls) + for key, value in init_dict.items(): + if hasattr(cfg, key): + attr = getattr(cfg, key) + if key == "urdf_cfg": + from embodichain.lab.sim.cfg import URDFCfg + + setattr(cfg, key, URDFCfg.from_dict(value)) + elif is_configclass(attr): + setattr( + cfg, key, attr.from_dict(value) + ) # Call from_dict on the attribute + elif "class_type" in value: + setattr( + cfg, + key, + getattr(solver_module, f"{value['class_type']}Cfg").from_dict( + value + ), + ) + elif isinstance(value, dict) and key_in_nested_dict( + value, "class_type" + ): + setattr( + cfg, + key, + { + k: getattr( + solver_module, f"{v['class_type']}Cfg" + ).from_dict(v) + for k, v in value.items() + }, + ) + + else: + setattr(cfg, key, value) + else: + logger.log_warning( + f"Key '{key}' not found in {cfg.__class__.__name__}." + ) + return cfg + + def build_pk_serial_chain( + self, device: torch.device = torch.device("cpu"), **kwargs + ) -> Dict[str, "pk.SerialChain"]: + """Build the serial chain from the URDF file. + + Note: + This method is usually used in imitation dataset saving (compute eef pose from qpos using FK) + and model training (provide a differentiable FK layer or loss computation). + + Args: + device (torch.device): The device to which the chain will be moved. Defaults to CPU. + **kwargs: Additional arguments for building the serial chain. + + Returns: + Dict[str, pk.SerialChain]: The serial chain of the robot for specified control part. + """ + return {} diff --git a/embodichain/lab/sim/common.py b/embodichain/lab/sim/common.py new file mode 100644 index 00000000..d129b5a6 --- /dev/null +++ b/embodichain/lab/sim/common.py @@ -0,0 +1,105 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import torch + +from dataclasses import dataclass +from abc import ABC, abstractmethod +from typing import List, TypeVar, Sequence, Optional +from functools import cached_property + +from embodichain.lab.sim.cfg import ObjectBaseCfg +from embodichain.utils import logger + +T = TypeVar("T") + + +@dataclass +class BatchEntity(ABC): + """Abstract base class for batch entity in the simulation engine. + + This class defines the interfaces for managing and manipulating a batch of entity. + A single entity could be one of the following assets: + - actor (eg. rigid object) + - articulation (eg. robot) + - camera + - light + - sensor (eg. force sensor) + + """ + + uid: Optional[str] = None + cfg: ObjectBaseCfg = None + _entities: List[T] = None + device: torch.device = None + + def __init__( + self, + cfg: ObjectBaseCfg, + entities: List[T] = None, + device: torch.device = torch.device("cpu"), + ) -> None: + + if entities is None or len(entities) == 0: + logger.log_error("Invalid entities list: must not be empty.") + + self.cfg = cfg.copy() + self.uid = self.cfg.uid + if self.uid is None: + logger.log_error("UID must be set in the configuration.") + self._entities = entities + self.device = device + + self.reset() + + def __str__(self) -> str: + return f"{self.__class__}: managing {self.num_instances} {self._entities[0].__class__} objects | uid: {self.uid} | device: {self.device}" + + def __repr__(self) -> str: + return self.__str__() + + @property + def num_instances(self) -> int: + return len(self._entities) + + @abstractmethod + def set_local_pose( + self, pose: torch.Tensor, env_ids: Optional[Sequence[int]] = None + ) -> None: + pass + + @abstractmethod + def get_local_pose(self, to_matrix: bool = False) -> torch.Tensor: + pass + + @property + def pose(self) -> torch.Tensor: + return self.get_local_pose(to_matrix=False) + + @abstractmethod + def reset(self, env_ids: Optional[Sequence[int]] = None) -> None: + """Reset the entity to its initial state. + + Args: + env_ids (Optional[Sequence[int]]): The environment IDs to reset. If None, reset all environments. + """ + pass + + def destroy(self) -> None: + """Destroy all entities managed by this batch entity.""" + pass diff --git a/embodichain/lab/sim/material.py b/embodichain/lab/sim/material.py new file mode 100644 index 00000000..fe6ad0db --- /dev/null +++ b/embodichain/lab/sim/material.py @@ -0,0 +1,386 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import copy +import torch +import dexsim +import numpy as np + +from typing import Optional, Dict, Union +from functools import cached_property + +from dexsim.engine import MaterialInst, Material +from embodichain.lab.sim.utility import is_rt_enabled +from embodichain.utils import configclass, logger + + +@configclass +class VisualMaterialCfg: + """Configuration for visual material with PBR properties for rasterization and ray tracing.""" + + uid: str = "default_mat" + + # Basic PBR properties + base_color: list = [0.5, 0.5, 0.5, 1.0] + """Base color/diffuse color (RGBA)""" + + metallic: float = 0.0 + """Metallic factor (0.0 = dielectric, 1.0 = metallic)""" + + roughness: float = 0.5 + """Surface roughness (0.0 = smooth, 1.0 = rough)""" + + # Additional PBR properties + emissive: list = [0.0, 0.0, 0.0] # Emissive color (RGB) + emissive_intensity: float = 1.0 # Emissive intensity multiplier + + # Texture maps + base_color_texture: str = None + """Base color texture map""" + + metallic_texture: str = None + """Metallic map""" + + roughness_texture: str = None + """Roughness map""" + + normal_texture: str = None + """Normal map""" + + ao_texture: str = None + """Ambient occlusion map""" + + # Ray tracing specific properties + ior: float = 1.5 + """Index of refraction for ray tracing materials""" + + rt_material_type: str = "BRDF_GGX_SMITH" + """Ray tracing material type. Options: 'BRDF_GGX_SMITH', 'BTDF_GGX_SMITH', 'BSDF_GGX_SMITH'""" + + # Currently disabled properties + # subsurface: float = 0.0 # Subsurface scattering factor + # subsurface_color: list = [1.0, 1.0, 1.0] # Subsurface scattering color + + @classmethod + def from_dict(cls, cfg_dict: dict) -> VisualMaterialCfg: + base = cls() + for k, v in cfg_dict.items(): + if hasattr(base, k): + setattr(base, k, v) + else: + logger.log_warning(f"Unknown field '{k}' in VisualMaterialCfg.") + return base + + +class VisualMaterial: + """Visual material definition in the simulation environment. + + A visual material is actually a material template from which material instances can be created. + It holds multiple material instances, which is used to assign to different objects in the environment. + """ + + RT_MATERIAL_TYPES = [ + "BRDF_GGX_SMITH", + "BTDF_GGX_SMITH", + "BSDF_GGX_SMITH", + ] + + def __init__(self, cfg: VisualMaterialCfg, mat: Material): + self.uid = cfg.uid + self.cfg = copy.deepcopy(cfg) + self._mat = mat + + self._default_mat_inst = self.create_instance(self.uid) + + @cached_property + def is_rt_enabled(self) -> bool: + return is_rt_enabled() + + @property + def mat(self) -> Material: + return self._mat + + def set_default_properties( + self, mat_inst: VisualMaterialInst, cfg: VisualMaterialCfg + ) -> None: + mat_inst.set_base_color(cfg.base_color) + mat_inst.set_metallic(cfg.metallic) + mat_inst.set_roughness(cfg.roughness) + mat_inst.set_emissive(cfg.emissive) + # mat_inst.set_emissive_intensity(self.cfg.emissive_intensity) # Unimplemented + + mat_inst.set_base_color_texture(cfg.base_color_texture) + mat_inst.set_metallic_texture(cfg.metallic_texture) + mat_inst.set_roughness_texture(cfg.roughness_texture) + mat_inst.set_normal_texture(cfg.normal_texture) + mat_inst.set_ao_texture(cfg.ao_texture) + + if self.is_rt_enabled: + mat_inst.set_ior(cfg.ior) + mat_inst.mat.update_pbr_material_type(cfg.rt_material_type) + + def create_instance(self, uid: str) -> VisualMaterialInst: + """Create a new material instance from this material template. + + Note: + - If the uid already exists, the existing instance will be returned. + + Args: + uid (str): Unique identifier for the material instance. + + Returns: + VisualMaterialInst: The created material instance. + """ + inst = VisualMaterialInst(uid, self._mat) + # TODO: Support change default properties for material. + # This will improve the instance creation efficiency. + self.set_default_properties(inst, self.cfg) + return inst + + def get_default_instance(self) -> VisualMaterialInst: + """Get the default material instance created with the same uid as the material template. + + Returns: + VisualMaterialInst: The default material instance. + """ + return self._default_mat_inst + + def get_instance(self, uid: str) -> VisualMaterialInst: + """Get an existing material instance by its uid. + + Args: + uid (str): Unique identifier for the material instance. + + Returns: + VisualMaterialInst: The material instance. + """ + return VisualMaterialInst(uid, self._mat) + + +class VisualMaterialInst: + """Instance of a visual material in the simulation environment.""" + + def __init__(self, uid: str, mat: Material): + self.uid = uid + self._mat = mat + + # Init properties with default values + self.base_color = [0.5, 0.5, 0.5, 1.0] + self.metallic = 0.0 + self.roughness = 0.5 + self.emissive = [0.0, 0.0, 0.0] + self.emissive_intensity = 1.0 + self.base_color_texture = None + self.metallic_texture = None + self.roughness_texture = None + self.normal_texture = None + self.ao_texture = None + self.ior = 1.5 + # self.subsurface = 0.0 + + @property + def mat(self) -> MaterialInst: + return self._mat.get_inst(self.uid) + + def set_base_color(self, color: list) -> None: + """Set base color/diffuse color.""" + self.base_color = color + self.mat.set_base_color(color) + + def set_metallic(self, metallic: float) -> None: + """Set metallic factor.""" + self.metallic = metallic + inst = self._mat.get_inst(self.uid) + inst.set_metallic(metallic) + + def set_roughness(self, roughness: float) -> None: + """Set surface roughness.""" + self.roughness = roughness + inst = self._mat.get_inst(self.uid) + inst.set_roughness(roughness) + + def set_emissive(self, emissive: list) -> None: + """Set emissive color.""" + self.emissive = emissive + value = np.zeros(4) + value[0:3] = emissive + inst = self._mat.get_inst(self.uid) + inst.set_emissive(value) + + def set_emissive_intensity(self, intensity: float) -> None: + """Set emissive intensity multiplier.""" + logger.log_error("Unimplemented: set_emissive_intensity") + + def set_base_color_texture( + self, texture_path: str = None, texture_data: Optional[torch.Tensor] = None + ) -> None: + """Set base color texture from file path or texture data. + + Args: + texture_path: Path to texture file + texture_data: Texture data as a torch.Tensor + """ + if texture_path is not None and texture_data is not None: + logger.log_warning( + "Both texture_path and texture_data are provided. Using texture_path." + ) + + if texture_path is not None: + self.base_color_texture = texture_path + inst = self._mat.get_inst(self.uid) + inst.set_base_color_map(texture_path) + elif texture_data is not None: + self.base_color_texture = texture_data + inst = self._mat.get_inst(self.uid) + + # TODO: Optimize texture creation method. + world = dexsim.default_world() + env = world.get_env() + color_texture = env.create_color_texture( + texture_data.cpu().numpy(), has_alpha=True + ) + inst.set_base_color_map(color_texture) + + def set_metallic_texture( + self, texture_path: str = None, texture_data: Optional[torch.Tensor] = None + ) -> None: + """Set metallic texture from file path or texture data. + + Args: + texture_path: Path to texture file + texture_data: Texture data as a torch.Tensor + """ + if texture_path is not None and texture_data is not None: + logger.log_warning( + "Both texture_path and texture_data are provided. Using texture_path." + ) + + if texture_path is not None: + self.metallic_texture = texture_path + inst = self._mat.get_inst(self.uid) + inst.set_metallic_map(texture_path) + elif texture_data is not None: + self.metallic_texture = texture_data + inst = self._mat.get_inst(self.uid) + + # TODO: Optimize texture creation method. + world = dexsim.default_world() + env = world.get_env() + metallic_texture = env.create_color_texture( + texture_data.cpu().numpy(), has_alpha=False + ) + inst.set_metallic_map(metallic_texture) + + def set_roughness_texture( + self, texture_path: str = None, texture_data: Optional[torch.Tensor] = None + ) -> None: + """Set roughness texture from file path or texture data. + + Args: + texture_path: Path to texture file + texture_data: Texture data as a torch.Tensor + """ + if texture_path is not None and texture_data is not None: + logger.log_warning( + "Both texture_path and texture_data are provided. Using texture_path." + ) + + if texture_path is not None: + self.roughness_texture = texture_path + inst = self._mat.get_inst(self.uid) + inst.set_roughness_map(texture_path) + elif texture_data is not None: + self.roughness_texture = texture_data + inst = self._mat.get_inst(self.uid) + + # TODO: Optimize texture creation method. + world = dexsim.default_world() + env = world.get_env() + roughness_texture = env.create_color_texture( + texture_data.cpu().numpy(), has_alpha=False + ) + inst.set_roughness_map(roughness_texture) + + def set_normal_texture( + self, texture_path: str = None, texture_data: Optional[torch.Tensor] = None + ) -> None: + """Set normal texture from file path or texture data. + + Args: + texture_path: Path to texture file + texture_data: Texture data as a torch.Tensor + """ + if texture_path is not None and texture_data is not None: + logger.log_warning( + "Both texture_path and texture_data are provided. Using texture_path." + ) + + if texture_path is not None: + self.normal_texture = texture_path + inst = self._mat.get_inst(self.uid) + inst.set_normal_map(texture_path) + elif texture_data is not None: + self.normal_texture = texture_data + inst = self._mat.get_inst(self.uid) + + # TODO: Optimize texture creation method. + world = dexsim.default_world() + env = world.get_env() + normal_texture = env.create_color_texture( + texture_data.cpu().numpy(), has_alpha=False + ) + inst.set_normal_map(normal_texture) + + def set_ao_texture( + self, texture_path: str = None, texture_data: Optional[torch.Tensor] = None + ) -> None: + """Set ambient occlusion texture from file path or texture data. + + Args: + texture_path: Path to texture file + texture_data: Texture data as a torch.Tensor + """ + if texture_path is not None and texture_data is not None: + logger.log_warning( + "Both texture_path and texture_data are provided. Using texture_path." + ) + + if texture_path is not None: + self.ao_texture = texture_path + inst = self._mat.get_inst(self.uid) + inst.set_ao_map(texture_path) + elif texture_data is not None: + self.ao_texture = texture_data + inst = self._mat.get_inst(self.uid) + + # TODO: Optimize texture creation method. + world = dexsim.default_world() + env = world.get_env() + ao_texture = env.create_color_texture( + texture_data.cpu().numpy(), has_alpha=False + ) + inst.set_ao_map(ao_texture) + + def set_ior(self, ior: float) -> None: + """Set index of refraction.""" + if is_rt_enabled() is False: + logger.log_debug("Ray Tracing rendering not enabled, ignoring IOR setting.") + return + self.ior = ior + inst = self._mat.get_inst(self.uid) + inst.set_rt_param("ior", ior) diff --git a/embodichain/lab/sim/objects/__init__.py b/embodichain/lab/sim/objects/__init__.py new file mode 100644 index 00000000..9c4ba945 --- /dev/null +++ b/embodichain/lab/sim/objects/__init__.py @@ -0,0 +1,28 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from ..common import BatchEntity +from .rigid_object import RigidObject, RigidBodyData, RigidObjectCfg +from .rigid_object_group import ( + RigidObjectGroup, + RigidBodyGroupData, + RigidObjectGroupCfg, +) +from .soft_object import SoftObject, SoftBodyData, SoftObjectCfg +from .articulation import Articulation, ArticulationData, ArticulationCfg +from .robot import Robot, RobotCfg +from .light import Light, LightCfg +from .gizmo import Gizmo diff --git a/embodichain/lab/sim/objects/articulation.py b/embodichain/lab/sim/objects/articulation.py new file mode 100644 index 00000000..c9f56090 --- /dev/null +++ b/embodichain/lab/sim/objects/articulation.py @@ -0,0 +1,1487 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch +import dexsim +import numpy as np + +from dataclasses import dataclass +from functools import cached_property +from typing import List, Sequence, Optional, Dict, Union + +from dexsim.engine import Articulation as _Articulation +from dexsim.types import ( + ArticulationGPUAPIWriteType, + ArticulationGPUAPIReadType, +) +from dexsim.engine import CudaArray, PhysicsScene + +from embodichain.lab.sim import VisualMaterialInst, VisualMaterial +from embodichain.lab.sim.cfg import ArticulationCfg, JointDrivePropertiesCfg +from embodichain.lab.sim.common import BatchEntity +from embodichain.utils.math import ( + matrix_from_quat, + quat_from_matrix, + convert_quat, + matrix_from_euler, +) +from embodichain.lab.sim.utility.sim_utils import ( + get_dexsim_drive_type, + set_dexsim_articulation_cfg, + is_rt_enabled, +) +from embodichain.lab.sim.utility.solver_utils import ( + create_pk_chain, + create_pk_serial_chain, +) +from embodichain.utils import logger + + +@dataclass +class ArticulationData: + """GPU data manager for articulation.""" + + def __init__( + self, entities: List[_Articulation], ps: PhysicsScene, device: torch.device + ) -> None: + """Initialize the ArticulationData. + + Args: + entities (List[_Articulation]): List of DexSim Articulation objects. + ps (PhysicsScene): The physics scene. + device (torch.device): The device to use for the articulation data. + """ + self.entities = entities + self.ps = ps + self.num_instances = len(entities) + self.device = device + + # get gpu indices for the entities. + # only meaningful when using GPU physics. + self.gpu_indices = torch.as_tensor( + [np.int32(entity.get_gpu_index()) for entity in self.entities], + dtype=torch.int32, + device=self.device, + ) + + self.dof = self.entities[0].get_dof() + self.num_links = self.entities[0].get_links_num() + self.link_names = self.entities[0].get_link_names() + + self._root_pose = torch.zeros( + (self.num_instances, 7), dtype=torch.float32, device=self.device + ) + self._root_lin_vel = torch.zeros( + (self.num_instances, 3), dtype=torch.float32, device=self.device + ) + self._root_ang_vel = torch.zeros( + (self.num_instances, 3), dtype=torch.float32, device=self.device + ) + + max_num_links = ( + self.ps.gpu_get_articulation_max_link_count() + if self.device.type == "cuda" + else self.num_links + ) + self._body_link_pose = torch.zeros( + (self.num_instances, max_num_links, 7), + dtype=torch.float32, + device=self.device, + ) + self._body_link_vel = torch.zeros( + (self.num_instances, max_num_links, 6), + dtype=torch.float32, + device=self.device, + ) + + self._body_link_lin_vel = torch.zeros( + (self.num_instances, max_num_links, 3), + dtype=torch.float32, + device=self.device, + ) + self._body_link_ang_vel = torch.zeros( + (self.num_instances, max_num_links, 3), + dtype=torch.float32, + device=self.device, + ) + + max_dof = ( + self.ps.gpu_get_articulation_max_dof() + if self.device.type == "cuda" + else self.dof + ) + self._qpos = torch.zeros( + (self.num_instances, max_dof), dtype=torch.float32, device=self.device + ) + self._qvel = torch.zeros( + (self.num_instances, max_dof), dtype=torch.float32, device=self.device + ) + self._qacc = torch.zeros( + (self.num_instances, max_dof), dtype=torch.float32, device=self.device + ) + self._qf = torch.zeros( + (self.num_instances, max_dof), dtype=torch.float32, device=self.device + ) + + @property + def root_pose(self) -> torch.Tensor: + """Get the root pose of the articulation. + + Returns: + torch.Tensor: The root pose of the articulation with shape of (num_instances, 7). + """ + if self.device.type == "cpu": + # Fetch pose from CPU entities + root_pose = torch.as_tensor( + np.array([entity.get_local_pose() for entity in self.entities]), + dtype=torch.float32, + device=self.device, + ) + xyzs = root_pose[:, :3, 3] + quats = quat_from_matrix(root_pose[:, :3, :3]) + return torch.cat((xyzs, quats), dim=-1) + else: + self.ps.gpu_fetch_root_data( + data=self._root_pose, + gpu_indices=self.gpu_indices, + data_type=ArticulationGPUAPIReadType.ROOT_GLOBAL_POSE, + ) + self._root_pose[:, :4] = convert_quat(self._root_pose[:, :4], to="wxyz") + return self._root_pose[:, [4, 5, 6, 0, 1, 2, 3]] + + @property + def root_lin_vel(self) -> torch.Tensor: + """Get the linear velocity of the root link of the articulation. + + Returns: + torch.Tensor: The linear velocity of the root link with shape of (num_instances, 3). + """ + if self.device.type == "cpu": + # Fetch linear velocity from CPU entities + return torch.as_tensor( + np.array( + [entity.get_root_link_velocity()[:3] for entity in self.entities] + ), + dtype=torch.float32, + device=self.device, + ) + else: + self.ps.gpu_fetch_root_data( + data=self._root_lin_vel, + gpu_indices=self.gpu_indices, + data_type=ArticulationGPUAPIReadType.ROOT_LINEAR_VELOCITY, + ) + return self._root_lin_vel.clone() + + @property + def root_ang_vel(self) -> torch.Tensor: + """Get the angular velocity of the root link of the articulation. + + Returns: + torch.Tensor: The angular velocity of the root link with shape of (num_instances, 3). + """ + if self.device.type == "cpu": + # Fetch angular velocity from CPU entities + return torch.as_tensor( + np.array( + [entity.get_root_link_velocity()[3:] for entity in self.entities] + ), + dtype=torch.float32, + device=self.device, + ) + else: + self.ps.gpu_fetch_root_data( + data=self._root_ang_vel, + gpu_indices=self.gpu_indices, + data_type=ArticulationGPUAPIReadType.ROOT_ANGULAR_VELOCITY, + ) + return self._root_ang_vel.clone() + + @property + def root_vel(self) -> torch.Tensor: + """Get the velocity of the root link of the articulation. + + Returns: + torch.Tensor: The velocity of the root link, concatenating linear and angular velocities. + """ + return torch.cat((self.root_lin_vel, self.root_ang_vel), dim=-1) + + @property + def qpos(self) -> torch.Tensor: + """Get the current positions (qpos) of the articulation. + + Returns: + torch.Tensor: The current positions of the articulation with shape of (num_instances, dof). + """ + if self.device.type == "cpu": + # Fetch qpos from CPU entities + return torch.as_tensor( + np.array( + [entity.get_current_qpos() for entity in self.entities], + ), + dtype=torch.float32, + device=self.device, + ) + else: + self.ps.gpu_fetch_joint_data( + data=self._qpos, + gpu_indices=self.gpu_indices, + data_type=ArticulationGPUAPIReadType.JOINT_POSITION, + ) + return self._qpos[:, : self.dof].clone() + + @property + def qvel(self) -> torch.Tensor: + """Get the current velocities (qvel) of the articulation. + + Returns: + torch.Tensor: The current velocities of the articulation with shape of (num_instances, dof). + """ + if self.device.type == "cpu": + # Fetch qvel from CPU entities + return torch.as_tensor( + np.array([entity.get_current_qvel() for entity in self.entities]), + dtype=torch.float32, + device=self.device, + ) + else: + self.ps.gpu_fetch_joint_data( + data=self._qvel, + gpu_indices=self.gpu_indices, + data_type=ArticulationGPUAPIReadType.JOINT_VELOCITY, + ) + return self._qvel[:, : self.dof].clone() + + @property + def qacc(self) -> torch.Tensor: + """Get the current accelerations (qacc) of the articulation. + + Returns: + torch.Tensor: The current accelerations of the articulation with shape of (num_instances, dof). + """ + if self.device.type == "cpu": + # Fetch qacc from CPU entities + return torch.as_tensor( + np.array([entity.get_current_qacc() for entity in self.entities]), + dtype=torch.float32, + device=self.device, + ) + else: + self.ps.gpu_fetch_joint_data( + data=self._qacc, + gpu_indices=self.gpu_indices, + data_type=ArticulationGPUAPIReadType.JOINT_ACCELERATION, + ) + return self._qacc[:, : self.dof].clone() + + @property + def qf(self) -> torch.Tensor: + """Get the current forces (qf) of the articulation. + + Returns: + torch.Tensor: The current forces of the articulation with shape of (num_instances, dof). + """ + if self.device.type == "cpu": + # Fetch qf from CPU entities + return torch.as_tensor( + np.array([entity.get_current_qf() for entity in self.entities]), + dtype=torch.float32, + device=self.device, + ) + else: + self.ps.gpu_fetch_joint_data( + data=self._qf, + gpu_indices=self.gpu_indices, + data_type=ArticulationGPUAPIReadType.JOINT_FORCE, + ) + return self._qf[:, : self.dof].clone() + + @property + def body_link_pose(self) -> torch.Tensor: + """Get the pose of all links in the articulation. + + Returns: + torch.Tensor: The poses of the links in the articulation with shape (N, num_links, 7). + """ + if self.device.type == "cpu": + from embodichain.lab.sim.utility import get_dexsim_arenas + + arenas = get_dexsim_arenas() + for j, entity in enumerate(self.entities): + + link_pose = np.zeros((self.num_links, 4, 4), dtype=np.float32) + for i, link_name in enumerate(self.link_names): + pose = entity.get_link_pose(link_name) + arena_pose = arenas[j].get_root_node().get_local_pose() + pose[:2, 3] -= arena_pose[:2, 3] + link_pose[i] = pose + + link_pose = torch.from_numpy(link_pose) + xyz = link_pose[:, :3, 3] + quat = quat_from_matrix(link_pose[:, :3, :3]) + self._body_link_pose[j][: self.num_links, :] = torch.cat( + (xyz, quat), dim=-1 + ) + return self._body_link_pose[:, : self.num_links, :] + else: + self.ps.gpu_fetch_link_data( + data=self._body_link_pose, + gpu_indices=self.gpu_indices, + data_type=ArticulationGPUAPIReadType.LINK_GLOBAL_POSE, + ) + quat = convert_quat(self._body_link_pose[..., :4], to="wxyz") + return torch.cat((self._body_link_pose[..., 4:], quat), dim=-1) + + @property + def body_link_vel(self) -> torch.Tensor: + """Get the velocities of all links in the articulation. + + Returns: + torch.Tensor: The poses of the links in the articulation with shape (N, num_links, 6). + """ + if self.device.type == "cpu": + for i, entity in enumerate(self.entities): + self._body_link_vel[i][: self.num_links] = torch.from_numpy( + entity.get_link_general_velocities() + ) + return self._body_link_vel[:, : self.num_links, :] + else: + self.ps.gpu_fetch_link_data( + data=self._body_link_lin_vel, + gpu_indices=self.gpu_indices, + data_type=ArticulationGPUAPIReadType.LINK_LINEAR_VELOCITY, + ) + self.ps.gpu_fetch_link_data( + data=self._body_link_ang_vel, + gpu_indices=self.gpu_indices, + data_type=ArticulationGPUAPIReadType.LINK_ANGULAR_VELOCITY, + ) + self._body_link_vel[..., :3] = self._body_link_lin_vel + self._body_link_vel[..., 3:] = self._body_link_ang_vel + return self._body_link_vel[:, : self.num_links, :] + + @property + def joint_stiffness(self) -> torch.Tensor: + """Get the joint stiffness of the articulation. + + Returns: + torch.Tensor: The joint stiffness of the articulation with shape (N, dof). + """ + return torch.as_tensor( + np.array([entity.get_drive()[0] for entity in self.entities]), + dtype=torch.float32, + device=self.device, + ) + + @property + def joint_damping(self) -> torch.Tensor: + """Get the joint damping of the articulation. + + Returns: + torch.Tensor: The joint damping of the articulation with shape (N, dof). + """ + return torch.as_tensor( + np.array([entity.get_drive()[1] for entity in self.entities]), + dtype=torch.float32, + device=self.device, + ) + + @property + def joint_friction(self) -> torch.Tensor: + """Get the joint friction of the articulation. + + Returns: + torch.Tensor: The joint friction of the articulation with shape (N, dof). + """ + return torch.as_tensor( + np.array([entity.get_drive()[4] for entity in self.entities]), + dtype=torch.float32, + device=self.device, + ) + + @cached_property + def qpos_limits(self) -> torch.Tensor: + """Get the joint position limits of the articulation. + + Returns: + torch.Tensor: The joint position limits of the articulation with shape (N, dof, 2). + """ + return torch.as_tensor( + np.array([entity.get_joint_limits() for entity in self.entities]), + dtype=torch.float32, + device=self.device, + ) + + @cached_property + def qvel_limits(self) -> torch.Tensor: + """Get the joint velocity limits of the articulation. + + Returns: + torch.Tensor: The joint velocity limits of the articulation with shape (N, dof). + """ + # TODO: get joint velocity limits always returns zero? + return torch.as_tensor( + np.array( + [entity.get_drive()[3] for entity in self.entities], + ), + dtype=torch.float32, + device=self.device, + ) + + @cached_property + def qf_limits(self) -> torch.Tensor: + """Get the joint effort limits of the articulation. + + Returns: + torch.Tensor: The joint effort limits of the articulation with shape (N, dof). + """ + return torch.as_tensor( + np.array( + [entity.get_drive()[2] for entity in self.entities], + ), + dtype=torch.float32, + device=self.device, + ) + + +class Articulation(BatchEntity): + """Articulation represents a batch of articulations in the simulation. + + An articulation is a collection of rigid bodies connected by joints. The joints can be either + fixed or actuated. The joints can be of different types, such as revolute or prismatic. + + For fixed-base articulation, it can be a robot arm, door, etc. + For floating-base articulation, it can be a humanoid, drawer, etc. + + Args: + cfg (ArticulationCfg): Configuration for the articulation. + entities (List[_Articulation], optional): List of articulation entities. + device (torch.device, optional): Device to use (CPU or CUDA). + """ + + def __init__( + self, + cfg: ArticulationCfg, + entities: List[_Articulation] = None, + device: torch.device = torch.device("cpu"), + ) -> None: + # Initialize world and physics scene + self._world = dexsim.default_world() + self._ps = self._world.get_physics_scene() + + self.cfg = cfg + self._entities = entities + self.device = device + + # Store all indices for batch operations + self._all_indices = torch.arange( + len(entities), dtype=torch.int32, device=device + ) + + if device.type == "cuda": + self._world.update(0.001) + + self._data = ArticulationData(entities=entities, ps=self._ps, device=device) + + self.cfg: ArticulationCfg + if self.cfg.init_qpos is None: + self.cfg.init_qpos = torch.zeros(self.dof, dtype=torch.float32) + + # Set articulation configuration in DexSim + set_dexsim_articulation_cfg(entities, self.cfg) + + # Init joint drive parameters. + num_entities = len(entities) + dof = self._data.dof + default_cfg = JointDrivePropertiesCfg() + self.default_joint_damping = torch.full( + (num_entities, dof), default_cfg.damping, dtype=torch.float32, device=device + ) + self.default_joint_stiffness = torch.full( + (num_entities, dof), + default_cfg.stiffness, + dtype=torch.float32, + device=device, + ) + self.default_joint_max_effort = torch.full( + (num_entities, dof), + default_cfg.max_effort, + dtype=torch.float32, + device=device, + ) + self.default_joint_max_velocity = torch.full( + (num_entities, dof), + default_cfg.max_velocity, + dtype=torch.float32, + device=device, + ) + self.default_joint_friction = torch.full( + (num_entities, dof), + default_cfg.friction, + dtype=torch.float32, + device=device, + ) + self._set_default_joint_drive() + + self.pk_chain = create_pk_chain(urdf_path=self.cfg.fpath, device=self.device) + + # For rendering purposes, each articulation can have multiple material instances associated with its links. + self._visual_material: List[Dict[str, VisualMaterialInst]] = [ + {} for _ in range(len(entities)) + ] + + # Stores mimic information for joints. + self._mimic_info = entities[0].get_mimic_info() + + # TODO: very weird that we must call update here to make sure the GPU indices are valid. + if device.type == "cuda": + self._world.update(0.001) + + super().__init__(cfg, entities, device) + + # set default collision filter + self._set_default_collision_filter() + + def __str__(self) -> str: + parent_str = super().__str__() + return parent_str + f" | dof: {self.dof} | num_links: {self.num_links}" + + @property + def dof(self) -> int: + """Get the degree of freedom of the articulation. + + Returns: + int: The degree of freedom of the articulation. + """ + return self._data.dof + + @property + def num_links(self) -> int: + """Get the number of links in the articulation. + + Returns: + int: The number of links in the articulation. + """ + return self._data.num_links + + @property + def link_names(self) -> List[str]: + """Get the names of the links in the articulation. + + Returns: + List[str]: The names of the links in the articulation. + """ + return self._data.link_names + + @property + def root_link_name(self) -> str: + """Get the name of the root link of the articulation. + + Returns: + str: The name of the root link. + """ + return self.entities[0].get_root_link_name() + + @property + def joint_names(self) -> List[str]: + """Get the names of the actived joints in the articulation. + + Returns: + List[str]: The names of the actived joints in the articulation. + """ + return self._entities[0].get_actived_joint_names() + + @property + def all_joint_names(self) -> List[str]: + """Get the names of the joints in the articulation. + + Returns: + List[str]: The names of the joints in the articulation. + """ + return self._entities[0].get_joint_names() + + @property + def body_data(self) -> ArticulationData: + """Get the rigid body data manager for this rigid object. + + Returns: + RigidBodyData: The rigid body data manager. + """ + return self._data + + @property + def root_state(self) -> torch.Tensor: + """Get the root state of the articulation. + + Returns: + torch.Tensor: The root state of the articulation with shape (N, 13). + """ + root_pose = self.body_data.root_pose + root_lin_vel = self.body_data.root_lin_vel + root_ang_vel = self.body_data.root_ang_vel + return torch.cat((root_pose, root_lin_vel, root_ang_vel), dim=-1) + + @property + def body_state(self) -> torch.Tensor: + """Get the body state of the articulation. + + Returns: + torch.Tensor: The body state of the articulation with shape (N, num_links, 13). + """ + body_pose = self.body_data.body_link_pose + body_vel = self.body_data.body_link_vel + return torch.cat((body_pose, body_vel), dim=-1) + + @property + def mimic_ids(self) -> List[Optional[int]]: + """Get the mimic joint ids for the articulation. + + Returns: + List[Optional[int]]: The mimic joint ids. + """ + return self._mimic_info.mimic_id.tolist() + + @property + def mimic_parents(self) -> List[Optional[int]]: + """Get the mimic joint parent ids for the articulation. + + Returns: + List[Optional[int]]: The mimic joint parent ids. + """ + return self._mimic_info.mimic_parent.tolist() + + @property + def mimic_multipliers(self) -> List[float]: + """Get the mimic joint multipliers for the articulation. + + Returns: + List[float]: The mimic joint multipliers. + """ + return self._mimic_info.mimic_multiplier.tolist() + + @property + def mimic_offsets(self) -> List[float]: + """Get the mimic joint offsets for the articulation. + + Returns: + List[float]: The mimic joint offsets. + """ + return self._mimic_info.mimic_offset.tolist() + + def _set_default_collision_filter(self) -> None: + collision_filter_data = torch.zeros( + size=(self.num_instances, 4), dtype=torch.int32 + ) + for i in range(self.num_instances): + collision_filter_data[i, 0] = i + collision_filter_data[i, 1] = 1 + self.set_collision_filter(collision_filter_data) + + def set_collision_filter( + self, filter_data: torch.Tensor, env_ids: Optional[Sequence[int]] = None + ) -> None: + """set collision filter data for the rigid object. + + Args: + filter_data (torch.Tensor): [N, 4] of int. + First element of each object is arena id. + If 2nd element is 0, the object will collision with all other objects in world. + 3rd and 4th elements are not used currently. + + env_ids (Optional[Sequence[int]], optional): Environment indices. If None, then all indices are used. Defaults to None. + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + + if len(local_env_ids) != len(filter_data): + logger.log_error( + f"Length of env_ids {len(local_env_ids)} does not match pose length {len(filter_data)}." + ) + + filter_data_np = filter_data.cpu().numpy().astype(np.uint32) + for i, env_idx in enumerate(local_env_ids): + self._entities[env_idx].set_collision_filter_data(filter_data_np[i]) + + def set_local_pose( + self, pose: torch.Tensor, env_ids: Optional[Sequence[int]] = None + ) -> None: + """Set local pose of the articulation. + + Args: + pose (torch.Tensor): The local pose of the articulation with shape (N, 7) or (N, 4, 4). + env_ids (Optional[Sequence[int]], optional): Environment indices. If None, then all indices are used. + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + + if len(local_env_ids) != len(pose): + logger.log_error( + f"Length of env_ids {len(local_env_ids)} does not match pose length {len(pose)}." + ) + + if self.device.type == "cpu": + pose = pose.cpu() + if pose.dim() == 2 and pose.shape[1] == 7: + pose_matrix = torch.eye(4).unsqueeze(0).repeat(pose.shape[0], 1, 1) + pose_matrix[:, :3, 3] = pose[:, :3] + pose_matrix[:, :3, :3] = matrix_from_quat(pose[:, 3:7]) + for i, env_idx in enumerate(local_env_ids): + self._entities[env_idx].set_local_pose(pose_matrix[i]) + elif pose.dim() == 3 and pose.shape[1:] == (4, 4): + for i, env_idx in enumerate(local_env_ids): + self._entities[env_idx].set_local_pose(pose[i]) + else: + logger.log_error( + f"Invalid pose shape {pose.shape}. Expected (N, 7) or (N, 4, 4)." + ) + + # TODO: in manual physics mode, the update should be explicitly called after + # setting the pose to synchronize the state to renderer. + self._world.update(0.001) + + else: + if pose.dim() == 2 and pose.shape[1] == 7: + xyz = pose[:, :3] + quat = convert_quat(pose[:, 3:7], to="xyzw") + elif pose.dim() == 3 and pose.shape[1:] == (4, 4): + xyz = pose[:, :3, 3] + quat = quat_from_matrix(pose[:, :3, :3]) + quat = convert_quat(quat, to="xyzw") + else: + logger.log_error( + f"Invalid pose shape {pose.shape}. Expected (N, 7) or (N, 4, 4)." + ) + + # we should keep `pose_` life cycle to the end of the function. + pose_ = torch.cat((quat, xyz), dim=-1) + indices = self.body_data.gpu_indices[local_env_ids] + self._ps.gpu_apply_root_data( + data=pose_, + gpu_indices=indices, + data_type=ArticulationGPUAPIWriteType.ROOT_GLOBAL_POSE, + ) + self._ps.gpu_compute_articulation_kinematic(gpu_indices=indices) + + # TODO: To be removed when gpu articulation data sync is supported. + if is_rt_enabled() is False: + self.body_data.body_link_pose + link_pose = self.body_data._body_link_pose[local_env_ids] + self._world.sync_poses_gpu_to_cpu( + link_pose=CudaArray(link_pose), + articulation_gpu_indices=CudaArray(indices), + ) + + def get_local_pose(self, to_matrix=False) -> torch.Tensor: + """Get local pose (root link pose) of the articulation. + + Args: + to_matrix (bool, optional): If True, return the pose as a 4x4 matrix. If False, return as (x, y, z, qw, qx, qy, qz). Defaults to False. + + Returns: + torch.Tensor: The local pose of the articulation with shape (N, 7) or (N, 4, 4) depending on `to_matrix`. + """ + pose = self.body_data.root_pose + if to_matrix: + xyz = pose[:, :3] + mat = matrix_from_quat(pose[:, 3:7]) + pose = ( + torch.eye(4, dtype=torch.float32, device=self.device) + .unsqueeze(0) + .repeat(pose.shape[0], 1, 1) + ) + pose[:, :3, 3] = xyz + pose[:, :3, :3] = mat + return pose + + def get_link_pose( + self, link_name: str, env_ids: Optional[Sequence[int]] = None, to_matrix=False + ) -> torch.Tensor: + """Get the pose of a specific link in the articulation. + + Args: + link_name (str): The name of the link. + env_ids (Optional[Sequence[int]], optional): Environment indices. If None, then all indices are used. + to_matrix (bool, optional): If True, return the pose as a 4x4 matrix. If False, return as (x, y, z, qw, qx, qy, qz). Defaults to False. + + Returns: + torch.Tensor: The pose of the specified link with shape (N, 7) or (N, 4, 4) depending on `to_matrix`. + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + + if link_name not in self.link_names: + logger.log_error( + f"Link name {link_name} not found in {self.__class__.__name__}. Available links: {self.link_names}" + ) + + link_idx = self.link_names.index(link_name) + link_pose = self.body_data.body_link_pose[local_env_ids, link_idx, :] + + if to_matrix: + xyz = link_pose[:, :3] + mat = matrix_from_quat(link_pose[:, 3:7]) + link_pose = ( + torch.eye(4, dtype=torch.float32, device=self.device) + .unsqueeze(0) + .repeat(link_pose.shape[0], 1, 1) + ) + link_pose[:, :3, 3] = xyz + link_pose[:, :3, :3] = mat + return link_pose + + def get_qpos(self) -> torch.Tensor: + """Get the current positions (qpos) of the articulation.""" + return self.body_data.qpos + + def set_qpos( + self, + qpos: torch.Tensor, + joint_ids: Optional[Sequence[int]] = None, + env_ids: Optional[Sequence[int]] = None, + target: bool = True, + ) -> None: + """Set the joint positions (qpos) or target positions for the articulation. + + Args: + qpos (torch.Tensor): Joint positions with shape (N, dof), where N is the number of environments. + joint_ids (Optional[Sequence[int]], optional): Joint indices to apply the positions. If None, applies to all joints. + env_ids (Optional[Sequence[int]]): Environment indices to apply the positions. Defaults to all environments. + target (bool): If True, sets target positions for simulation. If False, updates current positions directly. + + Raises: + ValueError: If the length of `env_ids` does not match the length of `qpos`. + """ + if not isinstance(qpos, torch.Tensor): + qpos = torch.as_tensor(qpos, dtype=torch.float32, device=self.device) + + if joint_ids is None: + local_joint_ids = torch.arange( + self.dof, device=self.device, dtype=torch.int32 + ) + elif not isinstance(joint_ids, torch.Tensor): + local_joint_ids = torch.as_tensor( + joint_ids, dtype=torch.int32, device=self.device + ) + else: + local_joint_ids = joint_ids + + local_env_ids = self._all_indices if env_ids is None else env_ids + + # TODO: Refactor this part to use a more generic and extensible approach, + # such as a class decorator that can automatically convert ndarray to torch.Tensor + # and handle dimension padding for specified member functions. + # This will make the codebase cleaner and reduce repetitive type checks/conversions. + # (e.g., support specifying which methods should be decorated for auto-conversion.) + qpos = torch.as_tensor(qpos, dtype=torch.float32, device=self.device) + + if self.device.type == "cuda": + limits = self.body_data.qpos_limits[0].T + # clamp qpos to limits + lower_limits = limits[0][local_joint_ids] + upper_limits = limits[1][local_joint_ids] + qpos = qpos.clamp(lower_limits, upper_limits) + + # Make sure qpos is 2D tensor + if qpos.dim() == 1: + qpos = qpos.unsqueeze(0) + # If only one qpos is provided, repeat it for all envs + if len(qpos) == 1 and len(local_env_ids) > 1: + qpos = qpos.repeat(len(local_env_ids), 1) + + if len(local_env_ids) != len(qpos): + logger.log_error( + f"Length of env_ids {len(local_env_ids)} does not match qpos length {len(qpos)}. " + f"env_ids: {local_env_ids}, qpos.shape: {qpos.shape}" + ) + + data_type = ( + ArticulationGPUAPIWriteType.JOINT_TARGET_POSITION + if target + else ArticulationGPUAPIWriteType.JOINT_POSITION + ) + + if self.device.type == "cpu": + for i, env_idx in enumerate(local_env_ids): + setter = ( + self._entities[env_idx].set_current_qpos + if target + else self._entities[env_idx].set_qpos + ) + setter(qpos[i].numpy(), local_joint_ids.numpy()) + else: + # TODO: trigger qpos getter to sync data, otherwise crash + if joint_ids is not None: + self.body_data.qpos + + indices = self.body_data.gpu_indices[local_env_ids] + qpos_set = self.body_data._qpos[local_env_ids] + qpos_set[:, local_joint_ids] = qpos + self._ps.gpu_apply_joint_data( + data=qpos_set, + gpu_indices=indices, + data_type=data_type, + ) + + def get_qvel(self) -> torch.Tensor: + """Get the current velocities (qvel) of the articulation. + + Returns: + torch.Tensor: The current velocities of the articulation. + """ + return self.body_data.qvel + + def set_qvel( + self, + qvel: torch.Tensor, + joint_ids: Optional[Sequence[int]] = None, + env_ids: Optional[Sequence[int]] = None, + target: bool = True, + ) -> None: + """Set the velocities (qvel) or target velocities of the articulation. + + Args: + qvel (torch.Tensor): The velocities with shape (N, dof). + joint_ids (Optional[Sequence[int]], optional): Joint indices to apply the velocities. If None, applies to all joints. + env_ids (Optional[Sequence[int]], optional): Environment indices. Defaults to all indices. + If True, sets target positions for simulation. If False, updates current positions directly. + + Raises: + ValueError: If the length of `env_ids` does not match the length of `qvel`. + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + + if len(local_env_ids) != len(qvel): + logger.log_error( + f"Length of env_ids {len(local_env_ids)} does not match qvel length {len(qvel)}." + ) + + data_type = ( + ArticulationGPUAPIWriteType.JOINT_TARGET_VELOCITY + if target + else ArticulationGPUAPIWriteType.JOINT_VELOCITY + ) + + if self.device.type == "cpu": + local_joint_ids = np.arange(self.dof) if joint_ids is None else joint_ids + for i, env_idx in enumerate(local_env_ids): + setter = ( + self._entities[env_idx].set_current_qvel + if target + else self._entities[env_idx].set_qvel + ) + setter(qvel[i].numpy(), local_joint_ids) + else: + indices = self.body_data.gpu_indices[local_env_ids] + if joint_ids is None: + qvel_set = self.body_data._qvel[local_env_ids] + qvel_set[:, : self.dof] = qvel + else: + self.body_data.qvel + qvel_set = self.body_data._qvel[local_env_ids] + qvel_set[:, joint_ids] = qvel + self._ps.gpu_apply_joint_data( + data=qvel_set, + gpu_indices=indices, + data_type=data_type, + ) + + def set_qf( + self, + qf: torch.Tensor, + joint_ids: Optional[Sequence[int]] = None, + env_ids: Optional[Sequence[int]] = None, + ) -> None: + """Set the generalized efforts (qf) of the articulation. + + Args: + qf (torch.Tensor): The generalized efforts with shape (N, dof). + joint_ids (Optional[Sequence[int]], optional): Joint indices to apply the efforts. If None, applies to all joints. + env_ids (Optional[Sequence[int]], optional): Environment indices. Defaults to all indices. + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + + if len(local_env_ids) != len(qf): + logger.log_error( + f"Length of env_ids {len(local_env_ids)} does not match qf length {len(qf)}." + ) + + if self.device.type == "cpu": + local_joint_ids = np.arange(self.dof) if joint_ids is None else joint_ids + for i, env_idx in enumerate(local_env_ids): + setter = self._entities[env_idx].set_current_qf + setter(qf[i].numpy(), local_joint_ids) + else: + indices = self.body_data.gpu_indices[local_env_ids] + if joint_ids is None: + qf_set = self.body_data._qf[local_env_ids] + qf_set[:, : self.dof] = qf + else: + self.body_data.qf + qf_set = self.body_data._qf[local_env_ids] + qf_set[:, joint_ids] = qf + self._ps.gpu_apply_joint_data( + data=qf_set, + gpu_indices=indices, + data_type=ArticulationGPUAPIWriteType.JOINT_FORCE, + ) + + def set_drive( + self, + stiffness: Optional[torch.Tensor] = None, + damping: Optional[torch.Tensor] = None, + max_effort: Optional[torch.Tensor] = None, + max_velocity: Optional[torch.Tensor] = None, + friction: Optional[torch.Tensor] = None, + drive_type: str = "force", + joint_ids: Optional[Sequence[int]] = None, + env_ids: Optional[Sequence[int]] = None, + ) -> None: + """Set the drive properties for the articulation. + + Args: + stiffness (torch.Tensor): The stiffness of the joint drive with shape (len(env_ids), len(joint_ids)). + damping (torch.Tensor): The damping of the joint drive with shape (len(env_ids), len(joint_ids)). + max_effort (torch.Tensor): The maximum effort of the joint drive with shape (len(env_ids), len(joint_ids)). + max_velocity (torch.Tensor): The maximum velocity of the joint drive with shape (len(env_ids), len(joint_ids)). + friction (torch.Tensor): The joint friction coefficient with shape (len(env_ids), len(joint_ids)). + drive_type (str, optional): The type of drive to apply. Defaults to "force". + joint_ids (Optional[Sequence[int]], optional): The joint indices to apply the drive to. If None, applies to all joints. Defaults to None. + env_ids (Optional[Sequence[int]], optional): The environment indices to apply the drive to. If None, applies to all environments. Defaults to None. + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + local_joint_ids = np.arange(self.dof) if joint_ids is None else joint_ids + + for i, env_idx in enumerate(local_env_ids): + drive_args = { + "drive_type": get_dexsim_drive_type(drive_type), + "joint_ids": local_joint_ids, + } + if stiffness is not None: + drive_args["stiffness"] = stiffness[i].cpu().numpy() + if damping is not None: + drive_args["damping"] = damping[i].cpu().numpy() + if max_effort is not None: + drive_args["max_force"] = max_effort[i].cpu().numpy() + if max_velocity is not None: + drive_args["max_velocity"] = max_velocity[i].cpu().numpy() + if friction is not None: + drive_args["joint_friction"] = friction[i].cpu().numpy() + self._entities[env_idx].set_drive(**drive_args) + + def get_user_ids(self) -> torch.Tensor: + """Get the user ids of the articulation. + + Returns: + torch.Tensor: The user ids of the articulation with shape (N, num_link). + """ + return torch.as_tensor( + np.array( + [entity.get_user_ids() for entity in self._entities], + ), + dtype=torch.int32, + device=self.device, + ) + + def clear_dynamics(self, env_ids: Optional[Sequence[int]] = None) -> None: + """Clear the dynamics of the articulation. + + Args: + env_ids (Optional[Sequence[int]]): Environment indices. If None, then all indices are used. + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + if self.device.type == "cpu": + zero_joint_data = np.zeros((len(local_env_ids), self.dof), dtype=np.float32) + for i, env_idx in enumerate(local_env_ids): + self._entities[env_idx].set_qvel(zero_joint_data[i]) + self._entities[env_idx].set_current_qf(zero_joint_data[i]) + else: + zeros = torch.zeros( + (len(local_env_ids), self.dof), dtype=torch.float32, device=self.device + ) + indices = self.body_data.gpu_indices[local_env_ids] + self._ps.gpu_apply_joint_data( + data=zeros, + gpu_indices=indices, + data_type=ArticulationGPUAPIWriteType.JOINT_VELOCITY, + ) + self._ps.gpu_apply_joint_data( + data=zeros, + gpu_indices=indices, + data_type=ArticulationGPUAPIWriteType.JOINT_FORCE, + ) + + def reallocate_body_data(self) -> None: + """Reallocate body data tensors to match the current articulation state in the GPU physics scene.""" + if self.device.type == "cpu": + logger.log_warning(f"Reallocating body data on CPU is not supported.") + return + + max_dof = self._ps.gpu_get_articulation_max_dof() + max_num_links = self._ps.gpu_get_articulation_max_link_count() + self._data._qpos = torch.zeros( + (self.num_instances, max_dof), dtype=torch.float32, device=self.device + ) + self._data._qvel = torch.zeros( + (self.num_instances, max_dof), dtype=torch.float32, device=self.device + ) + self._data._qacc = torch.zeros( + (self.num_instances, max_dof), dtype=torch.float32, device=self.device + ) + self._data._qf = torch.zeros( + (self.num_instances, max_dof), dtype=torch.float32, device=self.device + ) + self._data._body_link_pose = torch.zeros( + (self.num_instances, max_num_links, 7), + dtype=torch.float32, + device=self.device, + ) + self._data._body_link_vel = torch.zeros( + (self.num_instances, max_num_links, 6), + dtype=torch.float32, + device=self.device, + ) + + self._data._body_link_lin_vel = torch.zeros( + (self.num_instances, max_num_links, 3), + dtype=torch.float32, + device=self.device, + ) + self._data._body_link_ang_vel = torch.zeros( + (self.num_instances, max_num_links, 3), + dtype=torch.float32, + device=self.device, + ) + + def reset(self, env_ids: Optional[Sequence[int]] = None) -> None: + local_env_ids = self._all_indices if env_ids is None else env_ids + num_instances = len(local_env_ids) + self.cfg: ArticulationCfg + pos = torch.as_tensor( + self.cfg.init_pos, dtype=torch.float32, device=self.device + ) + rot = ( + torch.as_tensor(self.cfg.init_rot, dtype=torch.float32, device=self.device) + * torch.pi + / 180.0 + ) + pos = pos.unsqueeze(0).repeat(num_instances, 1) + rot = rot.unsqueeze(0).repeat(num_instances, 1) + mat = matrix_from_euler(rot, "XYZ") + pose = ( + torch.eye(4, dtype=torch.float32, device=self.device) + .unsqueeze(0) + .repeat(num_instances, 1, 1) + ) + pose[:, :3, 3] = pos + pose[:, :3, :3] = mat + self.set_local_pose(pose, env_ids=local_env_ids) + + qpos = torch.as_tensor( + self.cfg.init_qpos, dtype=torch.float32, device=self.device + ) + qpos = qpos.unsqueeze(0).repeat(num_instances, 1) + self.set_qpos(qpos, target=False, env_ids=local_env_ids) + # Set drive target to hold position. + self.set_qpos(qpos, target=True, env_ids=local_env_ids) + + self.clear_dynamics(env_ids=local_env_ids) + + if self.device.type == "cuda": + self._ps.gpu_compute_articulation_kinematic( + gpu_indices=self.body_data.gpu_indices[local_env_ids] + ) + + # TODO: To be removed when gpu articulation data sync is supported. + if is_rt_enabled() is False: + self.body_data.body_link_pose + link_pose = self.body_data._body_link_pose[local_env_ids] + indices = self.body_data.gpu_indices[local_env_ids] + self._world.sync_poses_gpu_to_cpu( + link_pose=CudaArray(link_pose), + articulation_gpu_indices=CudaArray(indices), + ) + else: + self._world.update(0.001) + + def _set_default_joint_drive(self) -> None: + """Set default joint drive parameters based on the configuration.""" + import numbers + from embodichain.utils.string import resolve_matching_names_values + + drive_props = [ + ("damping", self.default_joint_damping), + ("stiffness", self.default_joint_stiffness), + ("max_effort", self.default_joint_max_effort), + ("max_velocity", self.default_joint_max_velocity), + ("friction", self.default_joint_friction), + ] + + for prop_name, default_array in drive_props: + value = getattr(self.cfg.drive_pros, prop_name, None) + if value is None: + continue + if isinstance(value, numbers.Number): + default_array[:] = value + else: + try: + indices, _, values = resolve_matching_names_values( + value, self.joint_names + ) + default_array[:, indices] = torch.as_tensor( + values, dtype=torch.float32, device=self.device + ) + except Exception as e: + logger.log_error(f"Failed to set {prop_name}: {e}") + + drive_pros = self.cfg.drive_pros + if isinstance(drive_pros, dict): + drive_type = drive_pros.get("drive_type", None) + else: + drive_type = getattr(drive_pros, "drive_type", None) + + # Apply drive parameters to all articulations in the batch + self.set_drive( + stiffness=self.default_joint_stiffness, + damping=self.default_joint_damping, + max_effort=self.default_joint_max_effort, + max_velocity=self.default_joint_max_velocity, + friction=self.default_joint_friction, + drive_type=drive_type, + ) + + def compute_fk( + self, + qpos: Optional[Union[torch.tensor, np.ndarray]], + link_names: Optional[Union[str, list[str], tuple[str]]] = None, + end_link_name: Optional[str] = None, + root_link_name: Optional[str] = None, + to_dict: bool = False, + **kwargs, + ) -> Union[torch.tensor, dict[str, "pk.Transform3d"]]: + """Compute the forward kinematics (FK) for the given joint positions. + + Args: + qpos (torch.Tensor): Joint positions. Shape can be (dof,) for a single configuration or + (batch_size, dof) for batched configurations. + link_names (Union[str, list[str], tuple[str]], optional): Names of the links for which FK is computed. + If None, all links are considered. + end_link_name (str, optional): Name of the end link for which FK is computed. If None, all links are considered. + root_link_name (str, optional): Name of the root link for which FK is computed. Defaults to None. + to_dict (bool, optional): If True, returns the FK result as a dictionary of Transform3d objects. Defaults to False. + **kwargs: Additional keyword arguments for customization. + + Returns: + torch.Tensor: The homogeneous transformation matrix/matrices for the specified links. + Shape is (batch_size, 4, 4) for batched input or (4, 4) for single input. + If `to_dict` is True, returns a dictionary of Transform3d objects instead. + """ + frame_indices = None + + # Adapt link_names to work with get_frame_indices + if link_names is not None: + if isinstance(link_names, str): + # Single link name + frame_indices = self.pk_chain.get_frame_indices(link_names) + elif isinstance(link_names, (list, tuple)): + # Multiple link names + frame_indices = self.pk_chain.get_frame_indices(*link_names) + else: + raise TypeError( + f"Invalid type for link_names: {type(link_names)}. Expected str, list, or tuple." + ) + + if end_link_name is None and root_link_name is None: + result = self.pk_chain.forward_kinematics( + th=qpos, frame_indices=frame_indices + ) + else: + pk_serial_chain = create_pk_serial_chain( + chain=self.pk_chain, + root_link_name=root_link_name, + end_link_name=end_link_name, + ) + result = pk_serial_chain.forward_kinematics(th=qpos, end_only=True) + + if to_dict: + return result + + # Extract transformation matrices + if isinstance(result, dict): + if link_names: + matrices = torch.stack( + [result[name].get_matrix() for name in link_names], dim=0 + ) + else: + link_name = end_link_name if end_link_name else list(result.keys())[-1] + matrices = result[link_name].get_matrix() + elif isinstance(result, list): + matrices = torch.stack( + [xpos.get_matrix().squeeze() for xpos in result], dim=0 + ) + else: + matrices = result.get_matrix() + + # Ensure batch format + if matrices.dim() == 2: + matrices = matrices.unsqueeze(0) + + # Create result tensor with proper homogeneous coordinates + if matrices.dim() == 4: # Multiple links + num_links, batch_size, _, _ = matrices.shape + result = ( + torch.eye(4, device=self.device) + .expand(num_links, batch_size, 4, 4) + .clone() + ) + result[:, :, :3, :] = matrices[:, :, :3, :] + result = result.permute(1, 0, 2, 3) # (batch_size, num_links, 4, 4) + elif matrices.dim() == 3: # Single link + batch_size, _, _ = matrices.shape + result = torch.eye(4, device=self.device).expand(batch_size, 4, 4).clone() + result[:, :3, :] = matrices[:, :3, :] + else: + raise ValueError(f"Unexpected matrices shape: {matrices.shape}") + + return result + + def compute_jacobian( + self, + qpos: Optional[Union[torch.Tensor, np.ndarray]], + end_link_name: str = None, + root_link_name: str = None, + locations: Optional[Union[torch.Tensor, np.ndarray]] = None, + jac_type: str = "full", + ) -> torch.Tensor: + """Compute the Jacobian matrix for the given joint positions using the pk_serial_chain. + + Args: + qpos (torch.Tensor): The joint positions. Shape can be (dof,) for a single configuration + or (batch_size, dof) for batched configurations. + end_link_name (str, optional): The name of the end link for which the Jacobian is computed. + Defaults to the last link in the chain. + root_link_name (str, optional): The name of the root link for which the Jacobian is computed. + Defaults to the first link in the chain. + locations (Union[torch.Tensor, np.ndarray], optional): Offset points relative to the end-effector + frame for which the Jacobian is computed. + Shape can be (batch_size, 3) or (3,) for a single offset. + Defaults to None (origin of the end-effector frame). + jac_type (str, optional): Specifies the part of the Jacobian to return: + - 'full': Returns the full Jacobian (6, dof) or (batch_size, 6, dof). + - 'trans': Returns only the translational part (3, dof) or (batch_size, 3, dof). + - 'rot': Returns only the rotational part (3, dof) or (batch_size, 3, dof). + Defaults to 'full'. + + Returns: + torch.Tensor: The Jacobian matrix. Shape depends on the input: + - For a single link: (6, dof) or (batch_size, 6, dof). + - For multiple links: (num_links, 6, dof) or (num_links, batch_size, 6, dof). + The shape also depends on the `jac_type` parameter. + """ + if qpos is None: + qpos = torch.zeros(self.dof, device=self.device) + + # Ensure qpos is a tensor on the correct device + qpos = torch.as_tensor(qpos, dtype=torch.float32, device=self.device) + + # Default root and end link names if not provided + frame_names = self.pk_chain.get_frame_names() + if root_link_name is None: + root_link_name = frame_names[0] # Default to the first frame + if end_link_name is None: + end_link_name = frame_names[-1] # Default to the last frame + + # Create pk_serial_chain + pk_serial_chain = create_pk_serial_chain( + chain=self.pk_chain, + root_link_name=root_link_name, + end_link_name=end_link_name, + ) + + # Compute the Jacobian using the kinematics chain + J = pk_serial_chain.jacobian(th=qpos, locations=locations) + + # Handle jac_type to return the desired part of the Jacobian + if jac_type == "trans": + return J[:, :3, :] if J.dim() == 3 else J[:3, :] + elif jac_type == "rot": + return J[:, 3:, :] if J.dim() == 3 else J[3:, :] + elif jac_type == "full": + return J + else: + raise ValueError( + f"Invalid jac_type '{jac_type}'. Must be 'full', 'trans', or 'rot'." + ) + + def set_visual_material( + self, + mat: VisualMaterial, + env_ids: Optional[Sequence[int]] = None, + link_names: Optional[List[str]] = None, + ) -> None: + """Set visual material for the rigid object. + + Args: + mat (VisualMaterial): The material to set. + env_ids (Optional[Sequence[int]], optional): Environment indices. If None, then all indices are used. + link_names (Optional[List[str]], optional): List of link names to apply the material to. If None, applies to all links. + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + link_names = self.link_names if link_names is None else link_names + + for i, env_idx in enumerate(local_env_ids): + for link_name in link_names: + mat_inst = mat.create_instance( + f"{mat.uid}_{self.uid}_{link_name}_{env_idx}" + ) + self._entities[env_idx].set_material(link_name, mat_inst.mat) + self._visual_material[env_idx][link_name] = mat_inst + + def get_visual_material_inst( + self, + env_ids: Optional[Sequence[int]] = None, + link_names: Optional[List[str]] = None, + ) -> List[Dict[str, VisualMaterialInst]]: + """Get visual material instances for the rigid object. + + Args: + env_ids (Optional[Sequence[int]], optional): Environment indices. If None, then all indices are used. + link_names (Optional[List[str]], optional): List of link names to filter materials. If None, returns materials for all links. + Returns: + List[Dict[str, VisualMaterialInst]]: A list where each element corresponds to an environment and contains a dictionary mapping link names to their VisualMaterialInst. + """ + if env_ids is None and link_names is None: + return self._visual_material + + local_env_ids = self._all_indices if env_ids is None else env_ids + link_names = self.link_names if link_names is None else link_names + + result = [] + for i, env_idx in enumerate(local_env_ids): + if link_names is None: + result.append(self._visual_material[env_idx]) + else: + mat_dict = { + link_name: self._visual_material[env_idx][link_name] + for link_name in link_names + if link_name in self._visual_material[env_idx] + } + result.append(mat_dict) + return result + + def destroy(self) -> None: + env = self._world.get_env() + arenas = env.get_all_arenas() + if len(arenas) == 0: + arenas = [env] + for i, entity in enumerate(self._entities): + arenas[i].remove_articulation(entity) diff --git a/embodichain/lab/sim/objects/gizmo.py b/embodichain/lab/sim/objects/gizmo.py new file mode 100644 index 00000000..137fe20d --- /dev/null +++ b/embodichain/lab/sim/objects/gizmo.py @@ -0,0 +1,545 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- +""" +Gizmo: A reusable controller for interactive manipulation of simulation elements (object, robot, camera, etc.) +""" + + +import numpy as np +import torch +import dexsim +from typing import Callable, Optional +from scipy.spatial.transform import Rotation as R + +from embodichain.lab.sim.common import BatchEntity +from embodichain.lab.sim.objects import RigidObject, Robot +from embodichain.lab.sim.sensors import Camera +from embodichain.utils import configclass, logger + +from dexsim.types import ( + AxisOption, + RotationRingsOption, + AxisArrowType, + AxisCornerType, + AxisTagType, + TransformMask, + ActorType, + RigidBodyShape, + PhysicalAttr, +) +from dexsim.render import GizmoController + +from embodichain.lab.sim.utility.gizmo_utils import create_gizmo_callback + + +@configclass +class GizmoCfg: + """Configuration class for Gizmo parameters. + + This class defines the visual and interaction parameters for gizmo controllers, + including axis appearance and rotation rings settings. + """ + + # Axis configuration + axis_length_x: float = 0.2 + """Length of X-axis arrow.""" + axis_length_y: float = 0.2 + """Length of Y-axis arrow.""" + axis_length_z: float = 0.2 + """Length of Z-axis arrow.""" + axis_size: float = 0.01 + """Thickness of axis lines.""" + arrow_type: AxisArrowType = AxisArrowType.CONE + """Type of arrow head.""" + corner_type: AxisCornerType = AxisCornerType.SPHERE + """Type of axis corner.""" + tag_type: AxisTagType = AxisTagType.PLANE + """Type of axis label.""" + + # Rotation rings configuration + rings_radius: float = 0.15 + """Radius of rotation rings.""" + rings_size: float = 0.01 + """Thickness of rotation rings.""" + + def to_options_dict(self) -> dict: + """Convert configuration to options dictionary format expected by gizmo creation. + + Returns: + Dictionary containing AxisOption and RotationRingsOption objects. + """ + return { + "axis": AxisOption( + lx=self.axis_length_x, + ly=self.axis_length_y, + lz=self.axis_length_z, + size=self.axis_size, + arrow_type=self.arrow_type, + corner_type=self.corner_type, + tag_type=self.tag_type, + ), + "rings": RotationRingsOption( + radius=self.rings_radius, size=self.rings_size + ), + } + + +class Gizmo: + """ + Generic Gizmo controller for simulation elements. + Supports RigidObject, Robot, and Camera with type-specific handling. + + Note: + Gizmo can only be used in single environment mode (num_envs=1). + Will raise RuntimeError if used with multiple environments. + """ + + def __init__( + self, + target: BatchEntity, + cfg: Optional[GizmoCfg] = None, + control_part: Optional[str] = "arm", + ): + """ + Args: + target: The simulation element to control (RigidObject, Robot, or Camera) + cfg: Gizmo configuration parameters (optional, uses default if None) + control_part: For robots, specifies which control part to use (optional, default: "arm") + """ + self.target = target + self._target_type = self._detect_target_type(target) + self._control_part = control_part + self._env = dexsim.default_world().get_env() + self._windows = dexsim.default_world().get_windows() + + # Check if running in single environment (num_env must be 1) + num_envs = dexsim.get_world_num() + if num_envs > 1: + raise RuntimeError( + f"Gizmo can only be used in single environment mode (num_env=1), " + f"but current num_envs={num_envs}. Please create simulation with num_envs=1." + ) + + # Use provided config or get default + if cfg is None: + cfg = self._get_default_cfg() + self.cfg = cfg + self._gizmo = self._create_gizmo(self.cfg) + self._callback = None + self._state = "active" + self._setup_gizmo_follow() + + def _detect_target_type(self, target: BatchEntity) -> str: + """Detect target type: 'rigidobject', 'robot', or 'camera' using isinstance only.""" + if Robot is not None and isinstance(target, Robot): + return "robot" + if Camera is not None and isinstance(target, Camera): + return "camera" + if RigidObject is not None and isinstance(target, RigidObject): + return "rigidobject" + + raise ValueError( + f"Unsupported target type: {type(target)}. Only RigidObject, Robot, and Camera are supported." + ) + + def _get_default_cfg(self) -> GizmoCfg: + """Get default gizmo configuration (same for all target types)""" + return GizmoCfg() + + def _create_gizmo(self, cfg: GizmoCfg): + """Create gizmo using configuration object""" + options = cfg.to_options_dict() + axis = options["axis"] + rings = options["rings"] + return self._env.create_gizmo(axis, rings) + + def _compute_ee_pose_fk(self): + """Compute end-effector pose using forward kinematics""" + # Get current joint positions for this arm + proprioception = self.target.get_proprioception() + current_qpos_full = proprioception["qpos"] + current_joint_ids = self.target.get_joint_ids(self._robot_arm_name) + + joint_positions = current_qpos_full[:, current_joint_ids] + if joint_positions.dim() > 1: + joint_positions = joint_positions[0] + + # Compute forward kinematics + ee_pose = self.target.compute_fk( + joint_positions, name=self._control_part, to_matrix=True + ) + + return ee_pose + + def _create_proxy_cube( + self, position: np.ndarray, rotation_matrix: np.ndarray, name: str + ): + """Create a proxy cube for gizmo tracking""" + # Convert rotation matrix to euler angles + euler = R.from_matrix(rotation_matrix).as_euler("xyz", degrees=False) + + # Create small proxy cube at specified position + proxy_cube = self._env.create_cube(0.02, 0.02, 0.02) # 2cm cube + proxy_cube.set_location(position[0], position[1], position[2]) + proxy_cube.set_rotation_euler(euler[0], euler[1], euler[2]) + + # Add kinematic physics to proxy cube + attr = PhysicalAttr() + attr.mass = 0.05 + proxy_cube.add_rigidbody(ActorType.KINEMATIC, RigidBodyShape.CONVEX, attr) + + # Connect gizmo to proxy cube + self._gizmo.node.update_gizmo_follow(proxy_cube.node) + + logger.log_info(f"{name} gizmo proxy created at position: {position}") + return proxy_cube + + def _setup_camera_gizmo(self): + """Setup gizmo for Camera by creating a proxy RigidObject at camera position""" + # Get current camera pose + camera_pose = self.target.get_local_pose(to_matrix=True)[0] # Get first camera + camera_pos = camera_pose[:3, 3].cpu().numpy() + camera_rot_matrix = camera_pose[:3, :3].cpu().numpy() + + # Create proxy cube and set callback + self._proxy_cube = self._create_proxy_cube( + camera_pos, camera_rot_matrix, "Camera" + ) + self._gizmo.node.set_flush_transform_callback(self._proxy_gizmo_callback) + + def _proxy_gizmo_callback(self, node, translation, rotation, flag): + """Generic callback for proxy-based gizmo: only updates proxy cube transform, defers actual updates""" + if node is None: + return + + # Check if proxy cube still exists (not destroyed) + if not hasattr(self, "_proxy_cube") or self._proxy_cube is None: + return + + # Update proxy cube transform + if flag == (TransformMask.TRANSFORM_LOCAL | TransformMask.TRANSFORM_T): + node.set_translation(translation) + elif flag == (TransformMask.TRANSFORM_LOCAL | TransformMask.TRANSFORM_R): + node.set_rotation_rpy(rotation) + + # Mark that target needs to be updated, save target transform + proxy_pos = self._proxy_cube.get_location() + proxy_rot = self._proxy_cube.get_rotation_euler() + target_transform = torch.eye(4, dtype=torch.float32) + target_transform[:3, 3] = torch.tensor( + [proxy_pos[0], proxy_pos[1], proxy_pos[2]], dtype=torch.float32 + ) + target_transform[:3, :3] = torch.tensor( + R.from_euler("xyz", proxy_rot).as_matrix(), dtype=torch.float32 + ) + self._pending_target_transform = target_transform + + def _update_camera_pose(self, target_transform: torch.Tensor): + """Update camera pose to match target transform""" + try: + # Set camera pose using set_local_pose method + # target_transform shape: (4, 4), but set_local_pose expects (N, 4, 4) + target_transform_batch = target_transform.unsqueeze( + 0 + ) # Add batch dimension + self.target.set_local_pose(target_transform_batch) + return True + except Exception as e: + logger.log_error(f"Error updating camera pose: {e}") + return False + + def _setup_robot_gizmo(self): + """Setup gizmo for Robot by creating a proxy RigidObject at end-effector""" + # Get end-effector pose using specified control part + if self.target.cfg.solver_cfg is None: + raise ValueError( + "Robot has no solver configured for IK/FK computations for gizmo" + ) + + arm_names = list(self.target.control_parts.keys()) + if not arm_names: + raise ValueError("Robot has no control parts defined") + + # Use specified control part or fall back to first available + if self._control_part and self._control_part in arm_names: + self._robot_arm_name = self._control_part + else: + logger.log_error(f"Control part '{self._control_part}' not found.") + + logger.log_info(f"Using control part: {self._robot_arm_name}") + + # Get end-effector pose using forward kinematics + ee_pose = self._compute_ee_pose_fk()[0] # remove batch dimension + + ee_pos = ee_pose[:3, 3].cpu().numpy() + ee_rot_matrix = ee_pose[:3, :3].cpu().numpy() + + # Create proxy cube and set callback + self._proxy_cube = self._create_proxy_cube(ee_pos, ee_rot_matrix, "Robot") + self._gizmo.node.set_flush_transform_callback(self._proxy_gizmo_callback) + + def _update_robot_ik(self, target_transform: torch.Tensor): + """Update robot joints using IK to reach target transform""" + try: + # Get robot solver for the arm + solver = self.target.get_solver(self._robot_arm_name) + if solver is None: + logger.log_warning(f"No solver found for arm: {self._robot_arm_name}") + return False + + # Get current joint positions as seed using proprioception + current_qpos_full = self.target.get_qpos() + + # Get joint IDs for this arm + current_joint_ids = self.target.get_joint_ids(self._robot_arm_name) + + # Extract joint positions for this specific arm + if len(current_joint_ids) > 0: + joint_seed = current_qpos_full[ + :, current_joint_ids + ] # Select arm joints + if joint_seed.dim() > 1: + joint_seed = joint_seed[0] # Take first batch element + else: + logger.log_warning( + f"No joint IDs found for arm: {self._robot_arm_name}" + ) + return False + + # Solve IK + ik_success, new_qpos = solver.get_ik( + target_xpos=target_transform, joint_seed=joint_seed + ) + + if ik_success: + # Ensure correct dimensions for setting qpos + if new_qpos.dim() == 1: + new_qpos = new_qpos.unsqueeze(0) + elif new_qpos.dim() == 3: + new_qpos = new_qpos[:, 0, :] + + # Update robot joint positions + self.target.set_qpos(qpos=new_qpos, joint_ids=current_joint_ids) + return True + else: + logger.log_warning("IK solution not found") + return False + + except Exception as e: + logger.log_error(f"Error in robot IK: {e}") + return False + + def _setup_gizmo_follow(self): + """Setup gizmo based on target type""" + if self._target_type == "rigidobject": + # RigidObject: direct node access through MeshObject + self._gizmo.node.update_gizmo_follow(self.target._entities[0].node) + self._gizmo.node.set_flush_transform_callback(create_gizmo_callback()) + elif self._target_type == "robot": + # Robot: create proxy object at end-effector position + self._setup_robot_gizmo() + elif self._target_type == "camera": + # Camera: create proxy object at camera position + self._setup_camera_gizmo() + + def attach(self, target: BatchEntity): + """Attach gizmo to a new simulation element.""" + self.target = target + self._target_type = self._detect_target_type(target) + self._setup_gizmo_follow() + + def detach(self): + """Detach gizmo from current element.""" + self.target = None + # Use detach_parent to properly disconnect gizmo + try: + self._gizmo.node.detach_parent() + except Exception as e: + logger.log_warning(f"Failed to detach gizmo parent: {e}") + + def set_transform_callback(self, callback: Callable): + """Set callback for gizmo transform events (translation/rotation).""" + self._callback = callback + self._gizmo.node.set_flush_transform_callback(callback) + + def set_world_pose(self, pose): + """Set gizmo's world pose.""" + self._gizmo.node.set_world_pose(pose) + + def set_local_pose(self, pose): + """Set gizmo's local pose.""" + self._gizmo.node.set_local_pose(pose) + + def set_line_width(self, width: float): + """Set gizmo line width.""" + self._gizmo.node.set_line_width(width) + + def enable_collision(self, enabled: bool): + """Enable or disable gizmo collision.""" + self._gizmo.node.enable_collision(enabled) + + def get_world_pose(self): + """Get gizmo's world pose.""" + return self._gizmo.node.get_world_pose() + + def get_local_pose(self): + """Get gizmo's local pose.""" + return self._gizmo.node.get_local_pose() + + def get_name(self): + """Get gizmo node name.""" + return self._gizmo.node.get_name() + + def get_parent(self): + """Get gizmo's parent node.""" + return self._gizmo.node.get_parent() + + def toggle_visibility(self) -> bool: + """ + Toggle the visibility of the gizmo. + + Returns: + bool: The new visibility state (True = visible, False = hidden) + """ + if not hasattr(self, "_is_visible"): + self._is_visible = True # Default to visible + + # Toggle the state + self._is_visible = not self._is_visible + + # Apply the visibility setting to the gizmo node + if self._gizmo and hasattr(self._gizmo, "node"): + self._gizmo.node.set_physical_visible(self._is_visible, self._is_visible) + + return self._is_visible + + def set_visibility(self, visible: bool): + """ + Set the visibility of the gizmo. + + Args: + visible (bool): True to show, False to hide the gizmo + """ + self._is_visible = visible + + # Apply the visibility setting to the gizmo node + if self._gizmo and hasattr(self._gizmo, "node"): + self._gizmo.node.set_physical_visible(self._is_visible, self._is_visible) + + def is_visible(self) -> bool: + """ + Check if the gizmo is currently visible. + + Returns: + bool: True if visible, False if hidden + """ + return getattr(self, "_is_visible", True) + + def update(self): + """Synchronize gizmo with target's current transform, and handle IK solving here.""" + if self._target_type == "rigidobject": + self._gizmo.node.update_gizmo_follow(self.target._entities[0].node) + elif self._target_type == "robot": + # If there is a pending target, solve IK and clear it + if ( + hasattr(self, "_pending_target_transform") + and self._pending_target_transform is not None + ): + self._update_robot_ik(self._pending_target_transform) + self._pending_target_transform = None + elif self._target_type == "camera": + # Update proxy cube position to match current camera pose + if hasattr(self, "_proxy_cube") and self._proxy_cube: + camera_pose = self.target.get_local_pose(to_matrix=True)[0] + camera_pos = camera_pose[:3, 3].cpu().numpy() + self._proxy_cube.set_location( + camera_pos[0], camera_pos[1], camera_pos[2] + ) + + # If there is a pending camera target, update camera pose and clear it + if ( + hasattr(self, "_pending_target_transform") + and self._pending_target_transform is not None + ): + self._update_camera_pose(self._pending_target_transform) + self._pending_target_transform = None + + def apply_transform(self, translation, rotation): + """Apply transform based on target type""" + if self._target_type == "rigidobject": + self.target.set_location(*translation) + self.target.set_rotation_euler(*rotation) + elif self._target_type == "robot": + # Robot transforms are handled by IK in the gizmo callback + if hasattr(self, "_proxy_cube") and self._proxy_cube: + self._proxy_cube.set_location(*translation) + self._proxy_cube.set_rotation_euler(*rotation) + elif self._target_type == "camera": + # Camera transforms are handled by pose update in the gizmo callback + if hasattr(self, "_proxy_cube") and self._proxy_cube: + self._proxy_cube.set_location(*translation) + self._proxy_cube.set_rotation_euler(*rotation) + else: + # Other target types + pass + + def destroy(self): + """Clean up gizmo resources and release references.""" + # Clear transform callback first to avoid bad_function_call + if hasattr(self, "_gizmo") and self._gizmo and hasattr(self._gizmo, "node"): + try: + # Clear transform callback before any other cleanup + self._gizmo.node.set_flush_transform_callback(None) + logger.log_info("Cleared gizmo transform callback") + except Exception as e: + logger.log_warning(f"Failed to clear gizmo callback: {e}") + + # Remove proxy cube if exists (before detaching gizmo) + if hasattr(self, "_proxy_cube") and self._proxy_cube: + try: + # Detach gizmo from proxy cube first + if ( + hasattr(self, "_gizmo") + and self._gizmo + and hasattr(self._gizmo, "node") + ): + self._gizmo.node.detach_parent() + # Then remove the proxy cube + self._env.remove_actor(self._proxy_cube) + logger.log_info("Successfully removed proxy cube from environment") + except Exception as e: + logger.log_warning(f"Failed to remove proxy cube: {e}") + self._proxy_cube = None + + # Final gizmo cleanup + if hasattr(self, "_gizmo") and self._gizmo and hasattr(self._gizmo, "node"): + try: + # Ensure detach_parent is called if not done above + if self._target_type in ["robot", "camera"]: + pass # Already detached above + else: + self._gizmo.node.detach_parent() + logger.log_info("Successfully cleaned up gizmo node") + except Exception as e: + logger.log_warning(f"Failed to cleanup gizmo node: {e}") + + # Clear pending transform + if hasattr(self, "_pending_target_transform"): + self._pending_target_transform = None + + # Directly release references + self._gizmo = None + self.target = None diff --git a/embodichain/lab/sim/objects/light.py b/embodichain/lab/sim/objects/light.py new file mode 100644 index 00000000..fb085e88 --- /dev/null +++ b/embodichain/lab/sim/objects/light.py @@ -0,0 +1,281 @@ +import torch +import numpy as np +from typing import List, Optional, Sequence +from dexsim.render import Light as _Light +from embodichain.lab.sim.cfg import LightCfg +from embodichain.lab.sim.common import BatchEntity +from embodichain.utils import logger + + +class Light(BatchEntity): + """Light represents a batch of lights in the simulation. + + Each light supports the following properties: + - Color (3 floats) + - Intensity (1 float) + - Falloff (1 float) + - Location (3 floats) + """ + + def __init__( + self, + cfg: LightCfg, + entities: List[_Light] = None, + device: torch.device = torch.device("cpu"), + ) -> None: + + super().__init__(cfg, entities, device) + + def set_color( + self, colors: torch.Tensor, env_ids: Optional[Sequence[int]] = None + ) -> None: + """Set color for one or more lights. + + Args: + colors (torch.Tensor): Tensor of shape (M, 3) or (3,), representing RGB values. + - If shape is (3,), the same color is applied to all targeted instances. + - If shape is (M, 3), M must match the number of targeted instances. + env_ids (Optional[Sequence[int]]): Indices of instances to set. If None: + - For colors.shape == (3,), applies to all instances. + - For colors.shape == (M, 3), M must equal num_instances, applies per-instance. + """ + self._apply_vector3(colors, env_ids, "set_color") + + def set_intensity( + self, intensities: torch.Tensor, env_ids: Optional[Sequence[int]] = None + ) -> None: + """Set intensity for one or more lights. + + Args: + intensities (torch.Tensor): Tensor of shape (M,), (1,), or scalar (0-dim). + - If scalar or shape (1,), the same intensity is applied to all targeted instances. + - If shape (M,), M must match the number of targeted instances. + env_ids (Optional[Sequence[int]]): Indices of instances to set. If None: + - For scalar/shape (1,), applies to all instances. + - For shape (M,), M must equal num_instances, applies per-instance. + """ + self._apply_scalar(intensities, env_ids, "set_intensity") + + def set_falloff( + self, falloffs: torch.Tensor, env_ids: Optional[Sequence[int]] = None + ) -> None: + """Set falloff (radius) for one or more lights. + + Args: + falloffs (torch.Tensor): Tensor of shape (M,), (1,), or scalar (0-dim). + - If scalar or shape (1,), the same falloff is applied to all targeted instances. + - If shape (M,), M must match the number of targeted instances. + env_ids (Optional[Sequence[int]]): Indices of instances to set. If None: + - For scalar/shape (1,), applies to all instances. + - For shape (M,), M must equal num_instances, applies per-instance. + """ + self._apply_scalar(falloffs, env_ids, "set_falloff") + + def set_local_pose( + self, + pose: torch.Tensor, + env_ids: Optional[Sequence[int]] = None, + to_matrix: bool = False, + ) -> None: + """Set local pose (translation) for one or more lights. + + Args: + pose (torch.Tensor): + - If to_matrix=False: shape (3,) or (M, 3), representing (x, y, z). + - If to_matrix=True: shape (4, 4) or (M, 4, 4); translation extracted automatically. + env_ids (Optional[Sequence[int]]): Indices to set. If None: + - For vector input (3,) broadcast to all, or (M,3) with M == num_instances. + - For matrix input (4,4) broadcast to all, or (M,4,4) with M == num_instances. + to_matrix (bool): Interpret `pose` as full 4x4 matrix if True, else as vector(s). + """ + if not torch.is_tensor(pose): + logger.log_error( + f"set_local_pose requires a torch.Tensor, got {type(pose)}" + ) + return + + cpu = pose.detach().cpu() + if to_matrix: + if cpu.ndim == 2 and cpu.shape == (4, 4): + trans = cpu[:3, 3] + elif cpu.ndim == 3 and cpu.shape[1:] == (4, 4): + trans = cpu[..., 0:3, 3] + else: + logger.log_error( + f"set_local_pose matrix: expected (4,4) or (N,4,4), got {tuple(cpu.shape)}" + ) + return + else: + trans = cpu # expect (3,) or (M,3) + + try: + self._apply_vector3(trans, env_ids, setter_name="set_location") + except Exception as e: + logger.log_error(f"set_local_pose: error while applying translation: {e}") + + def get_local_pose(self, to_matrix: bool = False) -> torch.Tensor: + """Get local pose of each light, either as full matrix or translation vector. + + Args: + to_matrix (bool, optional): If True, return poses as 4×4 matrices. + If False, return translations only as (x, y, z). Defaults to False. + Returns: + torch.Tensor: + - If to_matrix=True: Tensor of shape (N, 4, 4), where N == num_instances. + - If to_matrix=False: Tensor of shape (N, 3), containing translations. + On error or empty instances, returns an empty tensor with shape (0, 4, 4) or (0, 3) respectively, and logs via logger.log_error. + """ + mats = [] + for i in range(self.num_instances): + try: + mat = self._entities[i].get_local_pose() # expect numpy (4,4) + arr = np.array(mat, dtype=np.float32) + if arr.shape != (4, 4): + logger.log_error( + f"get_local_pose: unexpected shape {arr.shape} for instance {i}" + ) + return torch.empty( + (0, 4, 4) if to_matrix else (0, 3), dtype=torch.float32 + ) + mats.append(arr) + except Exception as e: + logger.log_error(f"get_local_pose: error for instance {i}: {e}") + return torch.empty( + (0, 4, 4) if to_matrix else (0, 3), dtype=torch.float32 + ) + + if not mats: + return torch.empty((0, 4, 4) if to_matrix else (0, 3), dtype=torch.float32) + + stacked = np.stack(mats, axis=0) # (N,4,4) + tensor4 = torch.from_numpy(stacked) + if to_matrix: + return tensor4 + # else return translations + return tensor4[:, 0:3, 3].clone() + + def _apply_vector3( + self, + tensor: torch.Tensor, + env_ids: Optional[Sequence[int]], + setter_name: str, + ) -> None: + """ + Generic helper for 3-element vectors (color, location). + Expects tensor shape: (3,), or (M,3) with M == num_instances or M == len(env_ids). + """ + # Validate tensor type + if not torch.is_tensor(tensor): + logger.log_error( + f"{setter_name} requires a torch.Tensor, got {type(tensor)}" + ) + return + + cpu = tensor.detach().cpu() + # Determine target indices + if env_ids is None: + all_ids = list(range(self.num_instances)) + else: + all_ids = list(env_ids) + + # Cases: + # 1) cpu.ndim == 1 and size == 3: broadcast to all_ids + if cpu.ndim == 1 and cpu.shape[0] == 3: + arr = cpu.numpy() + for i in all_ids: + getattr(self._entities[i], setter_name)( + float(arr[0]), float(arr[1]), float(arr[2]) + ) + return + + # 2) cpu.ndim == 2 and cpu.shape == (num_instances, 3), env_ids None or full + if cpu.ndim == 2 and cpu.shape == (self.num_instances, 3) and env_ids is None: + arr_all = cpu.numpy() + for i in range(self.num_instances): + getattr(self._entities[i], setter_name)( + float(arr_all[i, 0]), float(arr_all[i, 1]), float(arr_all[i, 2]) + ) + return + + # 3) cpu.ndim == 2 and env_ids provided, cpu.shape == (len(env_ids), 3) + if ( + cpu.ndim == 2 + and env_ids is not None + and cpu.shape[0] == len(all_ids) + and cpu.shape[1] == 3 + ): + arr_sel = cpu.numpy() + for idx, i in enumerate(all_ids): + getattr(self._entities[i], setter_name)( + float(arr_sel[idx, 0]), + float(arr_sel[idx, 1]), + float(arr_sel[idx, 2]), + ) + return + + logger.log_error( + f"{setter_name}: tensor shape {tuple(cpu.shape)} is invalid for broadcasting " + f"(expected (3,) or ({self.num_instances},3) or ({len(all_ids)},3))." + ) + + def _apply_scalar( + self, + tensor: torch.Tensor, + env_ids: Optional[Sequence[int]], + setter_name: str, + ) -> None: + """ + Generic helper for scalar floats (intensity, falloff). + Accepts tensor shape: () (0-dim), (1,), or (M,) with M == num_instances or M == len(env_ids). + """ + if not torch.is_tensor(tensor): + logger.log_error( + f"{setter_name} requires a torch.Tensor, got {type(tensor)}" + ) + return + + cpu = tensor.detach().cpu() + if env_ids is None: + all_ids = list(range(self.num_instances)) + else: + all_ids = list(env_ids) + + # 1) scalar tensor: broadcast + if cpu.ndim == 0: + val = float(cpu.item()) + for i in all_ids: + getattr(self._entities[i], setter_name)(val) + return + + # 2) 1D tensor: + if cpu.ndim == 1: + length = cpu.shape[0] + arr = cpu.numpy() + # a) length == num_instances and env_ids None: map one-to-one + if length == self.num_instances and env_ids is None: + for i in range(self.num_instances): + getattr(self._entities[i], setter_name)(float(arr[i])) + return + # b) length == len(env_ids) when env_ids provided: map one-to-one + if env_ids is not None and length == len(all_ids): + for idx, i in enumerate(all_ids): + getattr(self._entities[i], setter_name)(float(arr[idx])) + return + # c) length == 1: broadcast + if length == 1: + val = float(arr[0]) + for i in all_ids: + getattr(self._entities[i], setter_name)(val) + return + + logger.log_error( + f"{setter_name}: tensor shape {tuple(cpu.shape)} is invalid for broadcasting " + f"(expected scalar, (1,), ({self.num_instances},) or ({len(all_ids)},))." + ) + + def reset(self, env_ids: Optional[Sequence[int]] = None) -> None: + self.cfg: LightCfg + self.set_color(torch.as_tensor(self.cfg.color), env_ids=env_ids) + self.set_intensity(torch.as_tensor(self.cfg.intensity), env_ids=env_ids) + self.set_falloff(torch.as_tensor(self.cfg.radius), env_ids=env_ids) + self.set_local_pose(torch.as_tensor(self.cfg.init_pos), env_ids=env_ids) diff --git a/embodichain/lab/sim/objects/rigid_object.py b/embodichain/lab/sim/objects/rigid_object.py new file mode 100644 index 00000000..d5b4bb6e --- /dev/null +++ b/embodichain/lab/sim/objects/rigid_object.py @@ -0,0 +1,667 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch +import dexsim +import numpy as np + +from dataclasses import dataclass +from typing import List, Sequence, Optional, Union + +from dexsim.models import MeshObject +from dexsim.types import RigidBodyGPUAPIReadType, RigidBodyGPUAPIWriteType +from dexsim.engine import CudaArray, PhysicsScene +from embodichain.lab.sim.cfg import RigidObjectCfg, RigidBodyAttributesCfg +from embodichain.lab.sim import ( + VisualMaterial, + VisualMaterialInst, + BatchEntity, +) +from embodichain.lab.sim.utility import is_rt_enabled +from embodichain.utils.math import convert_quat +from embodichain.utils.math import matrix_from_quat, quat_from_matrix, matrix_from_euler +from embodichain.utils import logger + + +@dataclass +class RigidBodyData: + """Data manager for rigid body with body type of dynamic or kinematic. + + Note: + 1. The pose data managed by dexsim is in the format of (qx, qy, qz, qw, x, y, z), but in SimulationManager, we use (x, y, z, qw, qx, qy, qz) format. + """ + + def __init__( + self, entities: List[MeshObject], ps: PhysicsScene, device: torch.device + ) -> None: + """Initialize the RigidBodyData. + + Args: + entities (List[MeshObject]): List of MeshObjects representing the rigid bodies. + ps (PhysicsScene): The physics scene. + device (torch.device): The device to use for the rigid body data. + """ + self.entities = entities + self.ps = ps + self.num_instances = len(entities) + self.device = device + + # get gpu indices for the entities. + self.gpu_indices = torch.as_tensor( + [entity.get_gpu_index() for entity in self.entities], + dtype=torch.int32, + device=self.device, + ) + + # Initialize rigid body data. + self._pose = torch.zeros( + (self.num_instances, 7), dtype=torch.float32, device=self.device + ) + self._lin_vel = torch.zeros( + (self.num_instances, 3), dtype=torch.float32, device=self.device + ) + self._ang_vel = torch.zeros( + (self.num_instances, 3), dtype=torch.float32, device=self.device + ) + + @property + def pose(self) -> torch.Tensor: + if self.device.type == "cpu": + # Fetch pose from CPU entities + xyzs = torch.as_tensor( + np.array([entity.get_location() for entity in self.entities]), + dtype=torch.float32, + device=self.device, + ) + quats = torch.as_tensor( + np.array( + [entity.get_rotation_quat() for entity in self.entities], + ), + dtype=torch.float32, + device=self.device, + ) + quats = convert_quat(quats, to="wxyz") + self._pose = torch.cat((xyzs, quats), dim=-1) + else: + self.ps.gpu_fetch_rigid_body_data( + data=self._pose, + gpu_indices=self.gpu_indices, + data_type=RigidBodyGPUAPIReadType.POSE, + ) + self._pose[:, :4] = convert_quat(self._pose[:, :4], to="wxyz") + self._pose = self._pose[:, [4, 5, 6, 0, 1, 2, 3]] + return self._pose + + @property + def lin_vel(self) -> torch.Tensor: + if self.device.type == "cpu": + # Fetch linear velocity from CPU entities + self._lin_vel = torch.as_tensor( + np.array([entity.get_linear_velocity() for entity in self.entities]), + dtype=torch.float32, + device=self.device, + ) + else: + self.ps.gpu_fetch_rigid_body_data( + data=self._lin_vel, + gpu_indices=self.gpu_indices, + data_type=RigidBodyGPUAPIReadType.LINEAR_VELOCITY, + ) + return self._lin_vel + + @property + def ang_vel(self) -> torch.Tensor: + if self.device.type == "cpu": + # Fetch angular velocity from CPU entities + self._ang_vel = torch.as_tensor( + np.array( + [entity.get_angular_velocity() for entity in self.entities], + ), + dtype=torch.float32, + device=self.device, + ) + else: + self.ps.gpu_fetch_rigid_body_data( + data=self._ang_vel, + gpu_indices=self.gpu_indices, + data_type=RigidBodyGPUAPIReadType.ANGULAR_VELOCITY, + ) + return self._ang_vel + + @property + def vel(self) -> torch.Tensor: + """Get the linear and angular velocities of the rigid bodies. + + Returns: + torch.Tensor: The linear and angular velocities concatenated, with shape (N, 6). + """ + return torch.cat((self.lin_vel, self.ang_vel), dim=-1) + + +class RigidObject(BatchEntity): + """RigidObject represents a batch of rigid body in the simulation. + + There are three types of rigid body: + - Static: Actors that do not move and are used as the environment. + - Dynamic: Actors that can move and are affected by physics. + - Kinematic: Actors that can move but are not affected by physics. + + """ + + def __init__( + self, + cfg: RigidObjectCfg, + entities: List[MeshObject] = None, + device: torch.device = torch.device("cpu"), + ) -> None: + self.body_type = cfg.body_type + + self._world = dexsim.default_world() + self._ps = self._world.get_physics_scene() + + self._all_indices = torch.arange( + len(entities), dtype=torch.int32, device=device + ) + + # data for managing body data (only for dynamic and kinematic bodies) on GPU. + self._data: Optional[RigidBodyData] = None + if self.is_static is False: + self._data = RigidBodyData(entities=entities, ps=self._ps, device=device) + + # For rendering purposes, each instance can have its own material. + self._visual_material: List[VisualMaterialInst] = [None] * len(entities) + + for entity in entities: + entity.set_body_scale(*cfg.body_scale) + entity.set_physical_attr(cfg.attrs.attr()) + + if device.type == "cuda": + self._world.update(0.001) + + super().__init__(cfg, entities, device) + + # set default collision filter + self._set_default_collision_filter() + + def __str__(self) -> str: + parent_str = super().__str__() + return ( + parent_str + + f" | body type: {self.body_type} | max_convex_hull_num: {self.cfg.max_convex_hull_num}" + ) + + @property + def body_data(self) -> Optional[RigidBodyData]: + """Get the rigid body data manager for this rigid object. + + Returns: + RigidBodyData: The rigid body data manager. + """ + if self.is_static: + logger.log_warning("Static rigid object has no body data.") + return None + + return self._data + + @property + def body_state(self) -> torch.Tensor: + """Get the body state of the rigid object. + + The body state of a rigid object is represented as a tensor with the following format: + [x, y, z, qw, qx, qy, qz, lin_x, lin_y, lin_z, ang_x, ang_y, ang_z] + + If the rigid object is static, linear and angular velocities will be zero. + + Returns: + torch.Tensor: The body state of the rigid object with shape (N, 13), where N is the number of instances. + """ + if self.is_static: + # For static bodies, we return the state with zero velocities. + zero_velocity = torch.zeros((self.num_instances, 6), device=self.device) + return torch.cat((self.pose, zero_velocity), dim=-1) + + return torch.cat( + (self.body_data.pose, self.body_data.lin_vel, self.body_data.ang_vel), + dim=-1, + ) + + @property + def is_static(self) -> bool: + """Check if the rigid object is static. + + Returns: + bool: True if the rigid object is static, False otherwise. + """ + return self.body_type == "static" + + @property + def is_non_dynamic(self) -> bool: + """Check if the rigid object is non-dynamic (static or kinematic). + + Returns: + bool: True if the rigid object is non-dynamic, False otherwise. + """ + return self.body_type in ("static", "kinematic") + + def _set_default_collision_filter(self) -> None: + collision_filter_data = torch.zeros( + size=(self.num_instances, 4), dtype=torch.int32 + ) + for i in range(self.num_instances): + collision_filter_data[i, 0] = i + collision_filter_data[i, 1] = 1 + self.set_collision_filter(collision_filter_data) + + def set_collision_filter( + self, filter_data: torch.Tensor, env_ids: Optional[Sequence[int]] = None + ) -> None: + """set collision filter data for the rigid object. + + Args: + filter_data (torch.Tensor): [N, 4] of int. + First element of each object is arena id. + If 2nd element is 0, the object will collision with all other objects in world. + 3rd and 4th elements are not used currently. + + env_ids (Optional[Sequence[int]], optional): Environment indices. If None, then all indices are used. Defaults to None. + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + + if len(local_env_ids) != len(filter_data): + logger.log_error( + f"Length of env_ids {len(local_env_ids)} does not match pose length {len(filter_data)}." + ) + + filter_data_np = filter_data.cpu().numpy().astype(np.uint32) + for i, env_idx in enumerate(local_env_ids): + self._entities[env_idx].get_physical_body().set_collision_filter_data( + filter_data_np[i] + ) + + def set_local_pose( + self, pose: torch.Tensor, env_ids: Optional[Sequence[int]] = None + ) -> None: + """Set local pose of the rigid object. + + Args: + pose (torch.Tensor): The local pose of the rigid object with shape (N, 7) or (N, 4, 4). + env_ids (Optional[Sequence[int]], optional): Environment indices. If None, then all indices are used. + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + + if len(local_env_ids) != len(pose): + logger.log_error( + f"Length of env_ids {len(local_env_ids)} does not match pose length {len(pose)}." + ) + + if self.device.type == "cpu" or self.is_static: + pose = pose.cpu() + if pose.dim() == 2 and pose.shape[1] == 7: + pose_matrix = torch.eye(4).unsqueeze(0).repeat(pose.shape[0], 1, 1) + pose_matrix[:, :3, 3] = pose[:, :3] + pose_matrix[:, :3, :3] = matrix_from_quat(pose[:, 3:7]) + for i, env_idx in enumerate(local_env_ids): + self._entities[env_idx].set_local_pose(pose_matrix[i]) + elif pose.dim() == 3 and pose.shape[1:] == (4, 4): + for i, env_idx in enumerate(local_env_ids): + self._entities[env_idx].set_local_pose(pose[i]) + else: + logger.log_error( + f"Invalid pose shape {pose.shape}. Expected (N, 7) or (N, 4, 4)." + ) + + else: + if pose.dim() == 2 and pose.shape[1] == 7: + xyz = pose[:, :3] + quat = convert_quat(pose[:, 3:7], to="xyzw") + elif pose.dim() == 3 and pose.shape[1:] == (4, 4): + xyz = pose[:, :3, 3] + quat = quat_from_matrix(pose[:, :3, :3]) + quat = convert_quat(quat, to="xyzw") + else: + logger.log_error( + f"Invalid pose shape {pose.shape}. Expected (N, 7) or (N, 4, 4)." + ) + + # we should keep `pose_` life cycle to the end of the function. + pose = torch.cat((quat, xyz), dim=-1) + indices = self.body_data.gpu_indices[local_env_ids] + self._ps.gpu_apply_rigid_body_data( + data=pose.clone(), + gpu_indices=indices, + data_type=RigidBodyGPUAPIWriteType.POSE, + ) + if is_rt_enabled() is False: + self._world.sync_poses_gpu_to_cpu( + rigid_pose=CudaArray(pose), rigid_gpu_indices=CudaArray(indices) + ) + + def get_local_pose(self, to_matrix: bool = False) -> torch.Tensor: + """Get local pose of the rigid object. + + Args: + to_matrix (bool, optional): If True, return the pose as a 4x4 matrix. If False, return as (x, y, z, qw, qx, qy, qz). Defaults to False. + + Returns: + torch.Tensor: The local pose of the rigid object with shape (N, 7) or (N, 4, 4) depending on `to_matrix`. + """ + + def get_local_pose_cpu( + entities: List[MeshObject], to_matrix: bool + ) -> torch.Tensor: + """Helper function to get local pose on CPU.""" + if to_matrix: + pose = torch.as_tensor( + [entity.get_local_pose() for entity in entities], + ) + else: + xyzs = torch.as_tensor([entity.get_location() for entity in entities]) + quats = torch.as_tensor( + [entity.get_rotation_quat() for entity in entities] + ) + quats = convert_quat(quats, to="wxyz") + pose = torch.cat((xyzs, quats), dim=-1) + + return pose + + if self.is_static: + return get_local_pose_cpu(self._entities, to_matrix).to(self.device) + + pose = self.body_data.pose + if to_matrix: + xyz = pose[:, :3] + mat = matrix_from_quat(pose[:, 3:7]) + pose = ( + torch.eye(4, dtype=torch.float32, device=self.device) + .unsqueeze(0) + .repeat(pose.shape[0], 1, 1) + ) + pose[:, :3, 3] = xyz + pose[:, :3, :3] = mat + return pose + + def add_force_torque( + self, + force: Optional[torch.Tensor] = None, + torque: Optional[torch.Tensor] = None, + pos: Optional[torch.Tensor] = None, + env_ids: Optional[Sequence[int]] = None, + ) -> None: + """Add force and/or torque to the rigid object. + + TODO: Currently, apply force at position `pos` is not supported. + + Note: there are a few different ways to apply force and torque: + - If `pos` is specified, the force is applied at that position. + - if not `pos` is specified, the force and torque are applied at the center of mass of the rigid body. + + Args: + force (Optional[torch.Tensor] = None): The force to add with shape (N, 3). Defaults to None. + torque (Optional[torch.Tensor], optional): The torque to add with shape (N, 3). Defaults to None. + pos (Optional[torch.Tensor], optional): The position to apply the force at with shape (N, 3). Defaults to None. + env_ids (Optional[Sequence[int]], optional): Environment indices. If None, then all indices are used. + """ + if force is None and torque is None: + logger.log_warning( + "Both force and torque are None. No force or torque will be applied." + ) + return + + if self.is_non_dynamic: + logger.log_warning( + "Cannot apply force or torque to non-dynamic rigid body." + ) + return + + local_env_ids = self._all_indices if env_ids is None else env_ids + + if force is not None and len(local_env_ids) != len(force): + logger.log_error( + f"Length of env_ids {len(local_env_ids)} does not match force length {len(force)}." + ) + + if torque is not None and len(local_env_ids) != len(torque): + logger.log_error( + f"Length of env_ids {len(local_env_ids)} does not match torque length {len(torque)}." + ) + + if self.device.type == "cpu": + for i, env_idx in enumerate(local_env_ids): + if force is not None: + self._entities[env_idx].add_force(force[i].cpu().numpy()) + if torque is not None: + self._entities[env_idx].add_torque(torque[i].cpu().numpy()) + + else: + indices = self.body_data.gpu_indices[local_env_ids] + if force is not None: + self._ps.gpu_apply_rigid_body_data( + data=force, + gpu_indices=indices, + data_type=RigidBodyGPUAPIWriteType.FORCE, + ) + if torque is not None: + self._ps.gpu_apply_rigid_body_data( + data=torque, + gpu_indices=indices, + data_type=RigidBodyGPUAPIWriteType.TORQUE, + ) + + def set_attrs( + self, + attrs: Union[RigidBodyAttributesCfg, List[RigidBodyAttributesCfg]], + env_ids: Optional[Sequence[int]] = None, + ) -> None: + """Set physical attributes for the rigid object. + + Args: + attrs (Union[RigidBodyAttributesCfg, List[RigidBodyAttributesCfg]]): The physical attributes to set. + env_ids (Optional[Sequence[int]], optional): Environment indices. If None, then all indices are used. + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + + if isinstance(attrs, List) and len(local_env_ids) != len(attrs): + logger.log_error( + f"Length of env_ids {len(local_env_ids)} does not match attrs length {len(attrs)}." + ) + + # TODO: maybe need to improve the physical attributes setter efficiency. + if isinstance(attrs, RigidBodyAttributesCfg): + for i, env_idx in enumerate(local_env_ids): + self._entities[env_idx].set_physical_attr(attrs.attr()) + else: + for i, env_idx in enumerate(local_env_ids): + self._entities[env_idx].set_physical_attr(attrs[i].attr()) + + def set_visual_material( + self, mat: VisualMaterial, env_ids: Optional[Sequence[int]] = None + ) -> None: + """Set visual material for the rigid object. + + Args: + mat (VisualMaterial): The material to set. + env_ids (Optional[Sequence[int]], optional): Environment indices. If None, then all indices are used. + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + + for i, env_idx in enumerate(local_env_ids): + mat_inst = mat.create_instance(f"{mat.uid}_{self.uid}_{env_idx}") + self._entities[env_idx].set_material(mat_inst.mat) + self._visual_material[env_idx] = mat_inst + + def get_visual_material_inst( + self, env_ids: Optional[Sequence[int]] = None + ) -> List[VisualMaterialInst]: + """Get material instances for the rigid object. + + Args: + env_ids (Optional[Sequence[int]], optional): Environment indices. If None, then all indices are used. + + Returns: + List[MaterialInst]: List of material instances. + """ + ids = env_ids if env_ids is not None else range(self.num_instances) + return [self._visual_material[i] for i in ids] + + def get_body_scale(self, env_ids: Optional[Sequence[int]] = None) -> torch.Tensor: + """ + Retrieve the body scale for specified environment instances. + + Args: + env_ids (Optional[Sequence[int]]): A sequence of environment instance IDs. + If None, retrieves the body scale for all instances. + + Returns: + torch.Tensor: A tensor containing the body scales of the specified instances, + with shape (N, 3) dtype int32 and located on the specified device. + """ + ids = env_ids if env_ids is not None else range(self.num_instances) + return torch.as_tensor( + [self._entities[id].get_body_scale() for id in ids], + dtype=torch.float32, + device=self.device, + ) + + def set_body_scale( + self, scale: torch.Tensor, env_ids: Optional[Sequence[int]] = None + ) -> None: + """Set the scale of the rigid body. + + Args: + scale (torch.Tensor): The scale to set with shape (N, 3). + env_ids (Optional[Sequence[int]], optional): Environment indices. If None, then all indices are used. + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + + if len(local_env_ids) != len(scale): + logger.log_error( + f"Length of env_ids {len(local_env_ids)} does not match scale length {len(scale)}." + ) + + if self.device.type == "cpu": + for i, env_idx in enumerate(local_env_ids): + scale = scale[i].cpu().numpy() + self._entities[env_idx].set_body_scale(*scale) + else: + logger.log_error(f"Setting body scale on GPU is not supported yet.") + + def get_vertices(self, env_ids: Optional[Sequence[int]] = None) -> torch.Tensor: + """ + Retrieve the vertices of the rigid objects. + + Args: + env_ids (Optional[Sequence[int]]): A sequence of environment IDs for which to retrieve vertices. + If None, retrieves vertices for all instances. + + Returns: + torch.Tensor: A tensor containing the user IDs of the specified rigid objects with shape (N, num_verts, 3). + """ + ids = env_ids if env_ids is not None else range(self.num_instances) + return torch.as_tensor( + np.array( + [self._entities[id].get_vertices() for id in ids], + ), + dtype=torch.float32, + device=self.device, + ) + + def get_user_ids(self) -> torch.Tensor: + """Get the user ids of the rigid bodies. + + Returns: + torch.Tensor: A tensor of shape (num_envs,) representing the user ids of the rigid bodies. + """ + return torch.as_tensor( + [entity.get_user_id() for entity in self._entities], + dtype=torch.int32, + device=self.device, + ) + + def clear_dynamics(self, env_ids: Optional[Sequence[int]] = None) -> None: + """Clear the dynamics of the rigid bodies by resetting velocities and applying zero forces and torques. + + Args: + env_ids (Optional[Sequence[int]]): Environment indices. If None, then all indices are used. + """ + if self.is_non_dynamic: + return + + local_env_ids = self._all_indices if env_ids is None else env_ids + + if self.device.type == "cpu": + for env_idx in local_env_ids: + self._entities[env_idx].clear_dynamics() + else: + # Apply zero force and torque to the rigid bodies. + zeros = torch.zeros( + (len(local_env_ids), 3), dtype=torch.float32, device=self.device + ) + indices = self.body_data.gpu_indices[local_env_ids] + self._ps.gpu_apply_rigid_body_data( + data=zeros, + gpu_indices=indices, + data_type=RigidBodyGPUAPIWriteType.LINEAR_VELOCITY, + ) + self._ps.gpu_apply_rigid_body_data( + data=zeros, + gpu_indices=indices, + data_type=RigidBodyGPUAPIWriteType.ANGULAR_VELOCITY, + ) + self._ps.gpu_apply_rigid_body_data( + data=zeros, + gpu_indices=indices, + data_type=RigidBodyGPUAPIWriteType.FORCE, + ) + self._ps.gpu_apply_rigid_body_data( + data=zeros, + gpu_indices=indices, + data_type=RigidBodyGPUAPIWriteType.TORQUE, + ) + + def reset(self, env_ids: Optional[Sequence[int]] = None) -> None: + local_env_ids = self._all_indices if env_ids is None else env_ids + num_instances = len(local_env_ids) + self.set_attrs(self.cfg.attrs, env_ids=local_env_ids) + + pos = torch.as_tensor( + self.cfg.init_pos, dtype=torch.float32, device=self.device + ) + rot = ( + torch.as_tensor(self.cfg.init_rot, dtype=torch.float32, device=self.device) + * torch.pi + / 180.0 + ) + pos = pos.unsqueeze(0).repeat(num_instances, 1) + rot = rot.unsqueeze(0).repeat(num_instances, 1) + mat = matrix_from_euler(rot, "XYZ") + pose = ( + torch.eye(4, dtype=torch.float32, device=self.device) + .unsqueeze(0) + .repeat(num_instances, 1, 1) + ) + pose[:, :3, 3] = pos + pose[:, :3, :3] = mat + self.set_local_pose(pose, env_ids=local_env_ids) + + self.clear_dynamics(env_ids=local_env_ids) + + def destroy(self) -> None: + env = self._world.get_env() + arenas = env.get_all_arenas() + if len(arenas) == 0: + arenas = [env] + for i, entity in enumerate(self._entities): + arenas[i].remove_actor(entity) diff --git a/embodichain/lab/sim/objects/rigid_object_group.py b/embodichain/lab/sim/objects/rigid_object_group.py new file mode 100644 index 00000000..2db3053f --- /dev/null +++ b/embodichain/lab/sim/objects/rigid_object_group.py @@ -0,0 +1,513 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch +import dexsim +import numpy as np + +from dataclasses import dataclass +from typing import List, Sequence, Optional, Union + +from dexsim.models import MeshObject +from dexsim.types import RigidBodyGPUAPIReadType, RigidBodyGPUAPIWriteType +from dexsim.engine import CudaArray, PhysicsScene +from embodichain.lab.sim.cfg import ( + RigidObjectGroupCfg, + RigidBodyAttributesCfg, +) +from embodichain.lab.sim import ( + BatchEntity, +) +from embodichain.lab.sim.material import VisualMaterial, VisualMaterialInst +from embodichain.utils.math import convert_quat +from embodichain.utils.math import matrix_from_quat, quat_from_matrix, matrix_from_euler +from embodichain.utils import logger + + +@dataclass +class RigidBodyGroupData: + """Data manager for rigid body group with body type of dynamic or kinematic.""" + + def __init__( + self, entities: List[List[MeshObject]], ps: PhysicsScene, device: torch.device + ) -> None: + """Initialize the RigidBodyGroupData. + + Args: + entities (List[List[MeshObject]]): List of List MeshObjects representing the rigid body group. + ps (PhysicsScene): The physics scene. + device (torch.device): The device to use for the rigid body group data. + """ + self.entities = entities + self.ps = ps + self.num_instances = len(entities) + self.num_objects = len(entities[0]) + self.device = device + + # get gpu indices for the rigid bodies with shape of (num_instances, num_objects) + self.gpu_indices = torch.as_tensor( + [[entity.get_gpu_index() for entity in instance] for instance in entities], + dtype=torch.int32, + device=self.device, + ) + + # Initialize rigid body group data tensors. Shape of (num_instances, num_objects, data_dim) + self._pose = torch.zeros( + (self.num_instances, self.num_objects, 7), + dtype=torch.float32, + device=self.device, + ) + self._lin_vel = torch.zeros( + (self.num_instances, self.num_objects, 3), + dtype=torch.float32, + device=self.device, + ) + self._ang_vel = torch.zeros( + (self.num_instances, self.num_objects, 3), + dtype=torch.float32, + device=self.device, + ) + + @property + def pose(self) -> torch.Tensor: + if self.device.type == "cpu": + # Fetch pose from CPU entities + xyzs = torch.as_tensor( + [ + [entity.get_location() for entity in instance] + for instance in self.entities + ], + device=self.device, + ) + quats = torch.as_tensor( + [ + [entity.get_rotation_quat() for entity in instance] + for instance in self.entities + ], + device=self.device, + ) + quats = convert_quat(quats.reshape(-1, 4), to="wxyz").reshape( + -1, self.num_objects, 4 + ) + return torch.cat((xyzs, quats), dim=-1) + else: + pose = self._pose.reshape(-1, 7) + self.ps.gpu_fetch_rigid_body_data( + data=pose, + gpu_indices=self.gpu_indices.flatten(), + data_type=RigidBodyGPUAPIReadType.POSE, + ) + pose = convert_quat(pose[:, :4], to="wxyz") + pose = pose[:, [4, 5, 6, 0, 1, 2, 3]] + return self._pose + + @property + def lin_vel(self) -> torch.Tensor: + if self.device.type == "cpu": + # Fetch linear velocity from CPU entities + self._lin_vel = torch.as_tensor( + [ + [entity.get_linear_velocity() for entity in instance] + for instance in self.entities + ], + dtype=torch.float32, + device=self.device, + ) + else: + lin_vel = self._lin_vel.reshape(-1, 3) + self.ps.gpu_fetch_rigid_body_data( + data=lin_vel, + gpu_indices=self.gpu_indices.flatten(), + data_type=RigidBodyGPUAPIReadType.LINEAR_VELOCITY, + ) + return self._lin_vel + + @property + def ang_vel(self) -> torch.Tensor: + if self.device.type == "cpu": + # Fetch angular velocity from CPU entities + self._ang_vel = torch.as_tensor( + [ + [entity.get_linear_velocity() for entity in instance] + for instance in self.entities + ], + dtype=torch.float32, + device=self.device, + ) + else: + ang_vel = self._ang_vel.reshape(-1, 3) + self.ps.gpu_fetch_rigid_body_data( + data=ang_vel, + gpu_indices=self.gpu_indices.flatten(), + data_type=RigidBodyGPUAPIReadType.ANGULAR_VELOCITY, + ) + return self._ang_vel + + @property + def vel(self) -> torch.Tensor: + """Get the linear and angular velocities of the rigid bodies. + + Returns: + torch.Tensor: The linear and angular velocities concatenated, with shape (num_instances, num_objects, 6). + """ + return torch.cat((self.lin_vel, self.ang_vel), dim=-1) + + +class RigidObjectGroup(BatchEntity): + """RigidObjectGroup represents a batch of rigid bodies in the simulation.""" + + def __init__( + self, + cfg: RigidObjectGroupCfg, + entities: List[List[MeshObject]] = None, + device: torch.device = torch.device("cpu"), + ) -> None: + self.body_type = cfg.body_type + + self._world = dexsim.default_world() + self._ps = self._world.get_physics_scene() + + self._all_indices = torch.arange( + len(entities), dtype=torch.int32, device=device + ) + self._all_obj_indices = torch.arange( + len(entities[0]), dtype=torch.int32, device=device + ) + + # data for managing body data (only for dynamic and kinematic bodies) on GPU. + self._data = RigidBodyGroupData(entities=entities, ps=self._ps, device=device) + + body_cfgs = list(cfg.rigid_objects.values()) + for instance in entities: + for i, body in enumerate(instance): + body.set_body_scale(*body_cfgs[i].body_scale) + body.set_physical_attr(body_cfgs[i].attrs.attr()) + + if device.type == "cuda": + self._world.update(0.001) + + super().__init__(cfg, entities, device) + + # set default collision filter + self._set_default_collision_filter() + + def __str__(self) -> str: + parent_str = super().__str__() + return ( + parent_str + + f" | body type: {self.body_type} | num_objects: {self.num_objects}" + ) + + @property + def num_objects(self) -> int: + """Get the number of objects in each rigid body instance. + + Returns: + int: The number of objects in each rigid body instance. + """ + return self._data.num_objects + + @property + def body_data(self) -> RigidBodyGroupData: + """Get the rigid body data manager for this rigid object. + + Returns: + RigidBodyGroupData: The rigid body data manager. + """ + return self._data + + @property + def body_state(self) -> torch.Tensor: + """Get the body state of the rigid object. + + The body state of a rigid object is represented as a tensor with the following format: + [x, y, z, qw, qx, qy, qz, lin_x, lin_y, lin_z, ang_x, ang_y, ang_z] + + If the rigid object is static, linear and angular velocities will be zero. + + Returns: + torch.Tensor: The body state of the rigid object with shape (num_instances, num_objects, 13), + where N is the number of instances. + """ + return torch.cat( + (self.body_data.pose, self.body_data.lin_vel, self.body_data.ang_vel), + dim=-1, + ) + + @property + def is_non_dynamic(self) -> bool: + """Check if the rigid object is non-dynamic (static or kinematic). + + Returns: + bool: True if the rigid object is non-dynamic, False otherwise. + """ + return self.body_type in ("static", "kinematic") + + def _set_default_collision_filter(self) -> None: + collision_filter_data = torch.zeros( + size=(self.num_instances, 4), dtype=torch.int32 + ) + for i in range(self.num_instances): + collision_filter_data[i, 0] = i + collision_filter_data[i, 1] = 1 + self.set_collision_filter(collision_filter_data) + + def set_collision_filter( + self, filter_data: torch.Tensor, env_ids: Optional[Sequence[int]] = None + ) -> None: + """set collision filter data for the rigid object group. + + Args: + filter_data (torch.Tensor): [N, 4] of int. + First element of each object is arena id. + If 2nd element is 0, the object will collision with all other objects in world. + 3rd and 4th elements are not used currently. + + env_ids (Optional[Sequence[int]], optional): Environment indices. If None, then all indices are used. Defaults to None. + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + + if len(local_env_ids) != len(filter_data): + logger.log_error( + f"Length of env_ids {len(local_env_ids)} does not match pose length {len(filter_data)}." + ) + + filter_data_np = filter_data.cpu().numpy().astype(np.uint32) + for i, env_idx in enumerate(local_env_ids): + for entity in self._entities[env_idx]: + entity.get_physical_body().set_collision_filter_data(filter_data_np[i]) + + def set_local_pose( + self, + pose: torch.Tensor, + env_ids: Optional[Sequence[int]] = None, + obj_ids: Optional[Sequence[int]] = None, + ) -> None: + """Set local pose of the rigid object group. + + Args: + pose (torch.Tensor): The local pose of the rigid object group with shape (num_instances, num_objects, 7) or + (num_instances, num_objects, 4, 4). + env_ids (Optional[Sequence[int]], optional): Environment indices. If None, then all indices are used. + obj_ids (Optional[Sequence[int]], optional): Object indices within the group. If None, all objects are set. Defaults to None. + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + local_obj_ids = self._all_obj_indices if obj_ids is None else obj_ids + + if len(local_env_ids) != len(pose): + logger.log_error( + f"Length of env_ids {len(local_env_ids)} does not match pose length {len(pose)}." + ) + + if self.device.type == "cpu": + pose = pose.cpu() + if pose.dim() == 3 and pose.shape[2] == 7: + reshape_pose = pose.reshape(-1, 7) + pose_matrix = ( + torch.eye(4).unsqueeze(0).repeat(reshape_pose.shape[0], 1, 1) + ) + pose_matrix[:, :3, 3] = reshape_pose[:, :3] + pose_matrix[:, :3, :3] = matrix_from_quat(reshape_pose[:, 3:7]) + pose = pose_matrix.reshape(-1, len(local_obj_ids), 4, 4) + elif pose.dim() == 4 and pose.shape[2:] == (4, 4): + pass + else: + logger.log_error( + f"Invalid pose shape {pose.shape}. Expected (num_instances, num_objects, 7) or (num_instances, num_objects, 4, 4)." + ) + + for i, env_idx in enumerate(local_env_ids): + for j, obj_idx in enumerate(local_obj_ids): + self._entities[env_idx][obj_idx].set_local_pose(pose[i, j]) + + else: + if pose.dim() == 3 and pose.shape[2] == 7: + xyz = pose[..., :3].reshape(-1, 3) + quat = pose[..., 3:7].reshape(-1, 4) + quat = convert_quat(quat, to="xyzw") + elif pose.dim() == 4 and pose.shape[2:] == (4, 4): + xyz = pose[..., :3, 3].reshape(-1, 3) + mat = pose[..., :3, :3].reshape(-1, 3, 3) + quat = quat_from_matrix(mat) + quat = convert_quat(quat, to="xyzw") + else: + logger.log_error( + f"Invalid pose shape {pose.shape}. Expected (N, 7) or (N, 4, 4)." + ) + + # we should keep `pose_` life cycle to the end of the function. + pose = torch.cat((quat, xyz), dim=-1) + indices = self.body_data.gpu_indices[local_env_ids][ + :, local_obj_ids + ].flatten() + self._ps.gpu_apply_rigid_body_data( + data=pose.clone(), + gpu_indices=indices, + data_type=RigidBodyGPUAPIWriteType.POSE, + ) + self._world.sync_poses_gpu_to_cpu( + rigid_pose=CudaArray(pose), rigid_gpu_indices=CudaArray(indices) + ) + + def get_local_pose(self, to_matrix: bool = False) -> torch.Tensor: + """Get local pose of the rigid object group. + + Args: + to_matrix (bool, optional): If True, return the pose as a 4x4 matrix. If False, return as (x, y, z, qw, qx, qy, qz). Defaults to False. + + Returns: + torch.Tensor: The local pose of the rigid object with shape (num_instances, num_objects, 7) or (num_instances, num_objects, 4, 4) depending on `to_matrix`. + """ + pose = self.body_data.pose + if to_matrix: + pose = pose.reshape(-1, 7) + xyz = pose[:, :3] + mat = matrix_from_quat(pose[:, 3:7]) + pose = ( + torch.eye(4, dtype=torch.float32, device=self.device) + .unsqueeze(0) + .repeat(self.num_instances * self.num_objects, 1, 1) + ) + pose[:, :3, 3] = xyz + pose[:, :3, :3] = mat + pose = pose.reshape(self.num_instances, self.num_objects, 4, 4) + return pose + + def get_user_ids(self) -> torch.Tensor: + """Get the user ids of the rigid body group. + + Returns: + torch.Tensor: A tensor of shape (num_envs, num_objects) representing the user ids of the rigid body group. + """ + return torch.as_tensor( + [ + [entity.get_user_id() for entity in instance] + for instance in self._entities + ], + dtype=torch.int32, + device=self.device, + ) + + def clear_dynamics(self, env_ids: Optional[Sequence[int]] = None) -> None: + """Clear the dynamics of the rigid bodies by resetting velocities and applying zero forces and torques. + + Args: + env_ids (Optional[Sequence[int]]): Environment indices. If None, then all indices are used. + """ + if self.is_non_dynamic: + return + + local_env_ids = self._all_indices if env_ids is None else env_ids + + if self.device.type == "cpu": + for env_idx in local_env_ids: + for entity in self._entities[env_idx]: + entity.clear_dynamics() + else: + # Apply zero force and torque to the rigid bodies. + zeros = torch.zeros( + (len(local_env_ids) * self.num_objects, 3), + dtype=torch.float32, + device=self.device, + ) + indices = self.body_data.gpu_indices[local_env_ids].flatten() + self._ps.gpu_apply_rigid_body_data( + data=zeros, + gpu_indices=indices, + data_type=RigidBodyGPUAPIWriteType.LINEAR_VELOCITY, + ) + self._ps.gpu_apply_rigid_body_data( + data=zeros, + gpu_indices=indices, + data_type=RigidBodyGPUAPIWriteType.ANGULAR_VELOCITY, + ) + self._ps.gpu_apply_rigid_body_data( + data=zeros, + gpu_indices=indices, + data_type=RigidBodyGPUAPIWriteType.FORCE, + ) + self._ps.gpu_apply_rigid_body_data( + data=zeros, + gpu_indices=indices, + data_type=RigidBodyGPUAPIWriteType.TORQUE, + ) + + def set_visual_material( + self, mat: VisualMaterial, env_ids: Optional[Sequence[int]] = None + ) -> None: + """Set visual material for the rigid object group. + + Args: + mat (VisualMaterial): The material to set. + env_ids (Optional[Sequence[int]], optional): Environment indices. If None, then all indices are used. + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + + for i, env_idx in enumerate(local_env_ids): + for j, entity in enumerate(self._entities[env_idx]): + mat_inst = mat.create_instance(f"{mat.uid}_{self.uid}_{env_idx}_{j}") + entity.set_material(mat_inst.mat) + + # Note: The rigid object group is not supported to change the visual material once created. + # If needed, we should create a visual material dict to store the material instances, and + # implement a get_visual_material method to retrieve the material instances. + + def reset(self, env_ids: Optional[Sequence[int]] = None) -> None: + local_env_ids = self._all_indices if env_ids is None else env_ids + num_instances = len(local_env_ids) + + self.cfg: RigidObjectGroupCfg + body_cfgs = list(self.cfg.rigid_objects.values()) + + init_pos = [] + init_rot = [] + for cfg in body_cfgs: + init_pos.append(cfg.init_pos) + init_rot.append(cfg.init_rot) + + # (num_objects, 3) + pos = torch.as_tensor(init_pos, dtype=torch.float32, device=self.device) + rot = ( + torch.as_tensor(init_rot, dtype=torch.float32, device=self.device) + * torch.pi + / 180.0 + ) + # Convert pos and rot to shape (num_instances, num_objects, dim) + pos = pos.unsqueeze_(0).repeat(num_instances, 1, 1) + rot = rot.unsqueeze_(0).repeat(num_instances, 1, 1) + + mat = matrix_from_euler(rot.reshape(-1, 3), "XYZ") + # Init pose with shape (num_instances, num_objects, 4, 4) + pose = ( + torch.eye(4, dtype=torch.float32, device=self.device) + .unsqueeze_(0) + .repeat(num_instances * self.num_objects, 1, 1) + ) + pose[:, :3, 3] = pos.reshape(-1, 3) + pose[:, :3, :3] = mat + pose = pose.reshape(num_instances, self.num_objects, 4, 4) + self.set_local_pose(pose, env_ids=local_env_ids) + + self.clear_dynamics(env_ids=local_env_ids) + + def destroy(self) -> None: + env = self._world.get_env() + arenas = env.get_all_arenas() + if len(arenas) == 0: + arenas = [env] + for i, instance in enumerate(self._entities): + for entity in instance: + arenas[i].remove_actor(entity) diff --git a/embodichain/lab/sim/objects/robot.py b/embodichain/lab/sim/objects/robot.py new file mode 100644 index 00000000..d15f50fb --- /dev/null +++ b/embodichain/lab/sim/objects/robot.py @@ -0,0 +1,653 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch +import numpy as np + +from typing import List, Dict, Optional, Tuple, Union, Sequence +from dataclasses import dataclass, field + +from dexsim.engine import Articulation as _Articulation +from embodichain.lab.sim.cfg import RobotCfg +from embodichain.lab.sim.solvers import SolverCfg, BaseSolver +from embodichain.lab.sim.objects import Articulation +from embodichain.lab.sim.utility.tensor import to_tensor +from embodichain.utils.math import quat_from_matrix +from embodichain.utils.string import ( + is_regular_expression, + resolve_matching_names_values, +) +from embodichain.utils import logger + + +@dataclass +class ControlGroup: + r"""Represents a group of controllable joints in a robot. + + Attributes: + joint_names (List[str]): Names of the joints in this control group. + joint_ids (List[int]): IDs corresponding to the joints in this control group. + link_names (List[str]): Names of child links associated with the joints. + """ + + joint_names: List[str] = field(default_factory=list) + joint_ids: List[int] = field(default_factory=list) + link_names: List[str] = field(default_factory=list) + + def __post_init__(self): + pass + + +class Robot(Articulation): + """A class representing a batch of robots in the simulation environment. + + Robot is a specific type of articulation that can have additional properties or methods. + - `control_parts`: Specify the parts that can be controlled in a different manner. Different part may have + different joint ids, drive properties, pyhsical attributes, kinematic solvers or motion planners. + - `solvers`: Specify the kinematic solvers for the robot. + - `planners`: Specify the motion planner for the robot. + """ + + def __init__( + self, + cfg: RobotCfg, + entities: List[_Articulation], + device: torch.device = torch.device("cpu"), + ) -> None: + super().__init__(cfg, entities, device) + + self._solvers = {} + + # Initialize joint ids for control parts. + self._joint_ids: Dict[str, List[int]] = {} + + self._control_groups: Dict[str, ControlGroup] = {} + + if self.cfg.control_parts: + self._init_control_parts(self.cfg.control_parts) + + if self.cfg.solver_cfg: + self.init_solver(self.cfg.solver_cfg) + + def __str__(self) -> str: + parent_str = super().__str__() + return ( + parent_str + + f" | control_parts: {self.control_parts}, solvers: {self._solvers}" + ) + + @property + def control_parts(self) -> Union[Dict[str, List[str]], None]: + """Get the control parts of the robot.""" + return self.cfg.control_parts + + def get_joint_ids( + self, name: Optional[str] = None, remove_mimic: bool = False + ) -> List[int]: + """Get the joint ids of the robot for a specific control part. + + Args: + name (str, optional): The name of the control part to get the joint ids for. If None, the default part is used. + remove_mimic (bool, optional): If True, mimic joints will be excluded from the returned joint ids. Defaults to False. + + Returns: + List[int]: The joint ids of the robot for the specified control part. + """ + if not self.control_parts or name is None: + return ( + torch.arange(self.dof, dtype=torch.int32).tolist() + if not remove_mimic + else [i for i in range(self.dof) if i not in self.mimic_ids] + ) + + if name not in self.control_parts: + logger.log_error( + f"The control part '{name}' does not exist in the robot's control parts." + ) + return ( + self._joint_ids[name] + if not remove_mimic + else [i for i in self._joint_ids[name] if i not in self.mimic_ids] + ) + + def get_proprioception(self) -> Dict[str, torch.Tensor]: + """Gets robot proprioception information, primarily for agent state representation in robot learning scenarios. + + The default proprioception information includes: + - qpos: Joint positions. + - qvel: Joint velocities. + - qf: Joint efforts. + + Returns: + Dict[str, torch.Tensor]: A dictionary containing the robot's proprioception information + """ + + return dict( + qpos=self.body_data.qpos, qvel=self.body_data.qvel, qf=self.body_data.qf + ) + + def compute_fk( + self, + qpos: Optional[Union[torch.tensor, np.ndarray]], + name: Optional[str] = None, + link_names: Optional[List[str]] = None, + end_link_name: Optional[str] = None, + root_link_name: Optional[str] = None, + env_ids: Optional[Sequence[int]] = None, + to_matrix: bool = False, + ) -> torch.Tensor: + """Compute the forward kinematics of the robot given joint positions and optionally a specific part name. + The output pose will be in the local arena frame. + + Args: + qpos (Optional[Union[torch.tensor, np.ndarray]]): Joint positions of the robot, (n_envs, num_joints). + name (str, optional): The name of the control part to compute the FK for. If None, the default part is used. + link_names (List[str], optional): The names of the links to compute the FK for. If None, all links are used. + end_link_name (str, optional): The name of the end link to compute the FK for. If None, the default end link is used. + root_link_name (str, optional): The name of the root link to compute the FK for. If None, the default root link is used. + env_ids (Sequence[int], optional): The environment ids to compute the FK for. If None, all environments are used. + to_matrix (bool, optional): If True, returns the transformation in the form of a 4x4 matrix. + + Returns: + torch.Tensor: The forward kinematics result with shape (n_envs, 7) or (n_envs, 4, 4) if `to_matrix` is True. + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + + if name is None and hasattr(super(), "compute_fk"): + return super().compute_fk( + qpos=qpos, + link_names=link_names, + end_link_name=end_link_name, + root_link_name=root_link_name, + ) + + if not self._solvers: + logger.log_error( + "No solvers are defined for the robot. Please ensure that the robot has solvers configured." + ) + + solver = self._solvers.get(name if name is not None else "default", None) + if solver is None: + logger.log_error( + f"The control part '{name}' does not have an associated solver. Please ensure that a valid control part with an available solver is provided." + ) + return None + + if qpos.dim() == 1: + qpos = qpos.unsqueeze(0) + + if qpos.shape[0] != len(local_env_ids): + logger.log_error( + f"Joint positions batch size mismatch. Expected {len(local_env_ids)} but got {qpos.shape[0]}." + ) + + if qpos.shape[1] != solver.dof: + logger.log_error( + f"Joint positions shape mismatch. Expected {solver.dof} joints, got {qpos.shape[1]}." + ) + + result_matrix = solver.get_fk(qpos=qpos) + + base_pose = self.get_link_pose( + link_name=solver.root_link_name, env_ids=local_env_ids, to_matrix=True + ) + result_matrix = torch.bmm(base_pose, result_matrix) + + if to_matrix: + return result_matrix + else: + pos = result_matrix[:, :3, 3] + quat = quat_from_matrix(result_matrix[:, :3, :3]) + return torch.cat((pos, quat), dim=-1) + + def compute_ik( + self, + pose: Union[torch.Tensor, np.ndarray], + joint_seed: Optional[Union[torch.Tensor, np.ndarray]] = None, + name: Optional[str] = None, + env_ids: Optional[Sequence[int]] = None, + return_all_solutions: bool = False, + ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: + """Compute the inverse kinematics of the robot given joint positions and optionally a specific part name. + The input pose should be in the local arena frame. + + Args: + pose (torch.Tensor): The end effector pose of the robot, (n_envs, 7) or (n_envs, 4, 4). + joint_seed (torch.Tensor, optional): The joint positions to use as a seed for the IK computation, (n_envs, dof). + If None, the zero joint positions will be used as the seed. + name (str, optional): The name of the control part to compute the IK for. If None, the default part is used. + env_ids (Optional[Sequence[int]]): Environment indices to apply the positions. Defaults to all environments. + return_all_solutions (bool, optional): Whether to return all IK solutions or just the best one. Defaults to False. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The success Tensor with shape (n_envs, ) and qpos Tensor with shape (n_envs, max_results, dof). + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + + solver = self._solvers.get(name if name is not None else "default", None) + if solver is None: + logger.log_error( + f"The control part '{name}' does not have an associated solver. Please ensure that a valid control part with an available solver is provided." + ) + return None + + pose = to_tensor(pose, device=self.device) + if (pose.dim() == 1 and pose.shape[1] == 7) or ( + pose.dim() == 2 and pose.shape[1] == 4 + ): + pose = pose.unsqueeze(0) + + if pose.shape[0] != len(local_env_ids): + logger.log_error( + f"Pose batch size mismatch. Expected {len(local_env_ids)} but got {pose.shape[0]}." + ) + + if joint_seed is not None: + joint_seed = to_tensor(joint_seed, device=self.device) + if joint_seed.dim() == 1: + joint_seed = joint_seed.unsqueeze(0) + + if joint_seed.shape[0] != len(local_env_ids): + logger.log_error( + f"Joint seed batch size mismatch. Expected {len(local_env_ids)} but got {joint_seed.shape[0]}." + ) + + if pose.shape[-1] == 7 and pose.dim() == 2: + # Convert pose from (batch, 7) to (batch, 4, 4) + pose = torch.cat( + ( + pose[:, :3].unsqueeze(-1), # Position + quat_from_matrix(pose[:, 3:]).unsqueeze(-1), # Quaternion + ), + dim=-1, + ) + pose = torch.cat( + ( + pose, + torch.tensor([[0, 0, 0, 1]], device=pose.device).expand( + pose.shape[0], -1, -1 + ), + ), + dim=1, + ) + + base_pose = self.get_link_pose( + link_name=solver.root_link_name, env_ids=local_env_ids, to_matrix=True + ) + pose = torch.bmm(torch.inverse(base_pose), pose) + + ret, qpos = solver.get_ik( + target_xpos=pose, + qpos_seed=joint_seed, + return_all_solutions=return_all_solutions, + ) + dof = qpos.shape[-1] + if not return_all_solutions: + qpos = qpos.reshape(-1, dof) + + return ret.to(self.device), qpos.to(self.device) + + def compute_batch_fk( + self, + qpos: torch.tensor, + name: str, + env_ids: Optional[Sequence[int]] = None, + to_matrix: bool = False, + ): + """Compute the forward kinematics of the robot given joint positions and optionally a specific part name. + The output pose will be in the local arena frame. + + Args: + qpos (Optional[Union[torch.tensor, np.ndarray]]): Joint positions of the robot, (n_envs, n_batch, num_joints). + name (str, optional): The name of the control part to compute the FK for. If None, the default part is used. + env_ids (Sequence[int], optional): The environment ids to compute the FK for. If None, all environments are used. + to_matrix (bool, optional): If True, returns the transformation in the form of a 4x4 matrix. + + Returns: + torch.Tensor: The forward kinematics result with shape (n_envs, batch, 7) or (n_envs, batch, 4, 4) if `to_matrix` is True. + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + if not self._solvers: + logger.log_error( + "No solvers are defined for the robot. Please ensure that the robot has solvers configured." + ) + + solver = self._solvers.get(name if name is not None else "default", None) + if solver is None: + logger.log_error( + f"The control part '{name}' does not have an associated solver. Please ensure that a valid control part with an available solver is provided." + ) + return None + + if qpos.shape[0] != len(local_env_ids): + logger.log_error( + f"Joint positions batch size mismatch. Expected {len(local_env_ids)} but got {qpos.shape[0]}." + ) + + if qpos.shape[2] != solver.dof: + logger.log_error( + f"Joint positions shape mismatch. Expected {solver.dof} joints, got {qpos.shape[1]}." + ) + + n_batch = qpos.shape[1] + qpos_batch = qpos.reshape(-1, solver.dof) + xpos_batch = solver.get_fk(qpos=qpos_batch) + + # get xpos from link root + base_xpos_n_envs = self.get_link_pose( + link_name=solver.root_link_name, env_ids=local_env_ids, to_matrix=True + ) + base_xpos_batch = ( + base_xpos_n_envs[:, None, :, :].repeat(1, n_batch, 1, 1).reshape(-1, 4, 4) + ) + result_matrix = torch.bmm(base_xpos_batch, xpos_batch) + + if to_matrix: + result_matrix = result_matrix.reshape(len(local_env_ids), n_batch, 4, 4) + return result_matrix + else: + pos = result_matrix[:, :3, 3] + quat = quat_from_matrix(result_matrix[:, :3, :3]) + result = torch.cat((pos, quat), dim=-1) + result = result.reshape(len(local_env_ids), n_batch, 7) + return result + + def compute_batch_ik( + self, + pose: Union[torch.Tensor, np.ndarray], + joint_seed: Optional[Union[torch.Tensor, np.ndarray]], + name: str, + env_ids: Optional[Sequence[int]] = None, + ): + """Compute the inverse kinematics of the robot given joint positions and optionally a specific part name. + The input pose should be in the local arena frame. + + Args: + pose (torch.Tensor): The end effector pose of the robot, (n_envs, n_batch, 7) or (n_envs, n_batch, 4, 4). + joint_seed (torch.Tensor, optional): The joint positions to use as a seed for the IK computation, (n_envs, n_batch, dof). If None, the zero joint positions will be used as the seed. + name (str): The name of the control part to compute the IK for. If None, the default part is used. + env_ids (Optional[Sequence[int]]): Environment indices to apply the positions. Defaults to all environments. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + Success Tensor with shape (n_envs, n_batch) + Qpos Tensor with shape (n_envs, n_batch, dof). + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + + solver = self._solvers.get(name if name is not None else "default", None) + if solver is None: + logger.log_error( + f"The control part '{name}' does not have an associated solver. Please ensure that a valid control part with an available solver is provided." + ) + return None + pose = to_tensor(pose, device=self.device) + + if pose.shape[0] != len(local_env_ids): + logger.log_error( + f"Pose batch size mismatch. Expected {len(local_env_ids)} but got {pose.shape[0]}." + ) + + n_batch = pose.shape[1] + n_dof = solver.dof + if joint_seed is None: + joint_seed = torch.zeros( + (len(local_env_ids), n_batch, n_dof), + dtype=torch.float32, + device=self.device, + ) + + if joint_seed.shape[0] != len(local_env_ids): + logger.log_error( + f"Joint seed env size mismatch. Expected {len(local_env_ids)} but got {joint_seed.shape[0]}." + ) + + if joint_seed.shape[1] != n_batch: + logger.log_error( + f"Joint seed batch size mismatch. Expected {n_batch} but got {joint_seed.shape[1]}." + ) + + if joint_seed.shape[-1] != n_dof: + logger.log_error( + f"Joint seed dof size mismatch. Expected {n_batch} but got {joint_seed.shape[-1]}." + ) + + if pose.shape[-1] == 7 and pose.dim() == 3: + # Convert pose from (n_envs, n_batch, 7) to (n_envs * n_batch, 4, 4) + pose_batch = torch.reshape(-1, 7) + pose_batch = torch.cat( + ( + pose_batch[:, :3].unsqueeze(-1), # Position + quat_from_matrix(pose_batch[:, 3:]).unsqueeze(-1), # Quaternion + ), + dim=-1, + ) + pose_batch = torch.cat( + ( + pose_batch, + torch.tensor([[0, 0, 0, 1]], device=pose_batch.device).expand( + pose_batch.shape[0], -1, -1 + ), + ), + dim=1, + ) + else: + # Convert pose from (n_envs, n_batch, 4, 4) to (n_envs * n_batch, 4, 4) + pose_batch = pose.reshape(-1, 4, 4) + + # get xpos from link root + base_xpos_n_envs = self.get_link_pose( + link_name=solver.root_link_name, env_ids=local_env_ids, to_matrix=True + ) + base_inv_xpos_n_envs = torch.inverse(base_xpos_n_envs) + base_inv_xpos_batch = ( + base_inv_xpos_n_envs[:, None, :, :] + .repeat(1, n_batch, 1, 1) + .reshape(-1, 4, 4) + ) + pose_batch = torch.bmm(base_inv_xpos_batch, pose_batch) + + joint_seed_batch = joint_seed.reshape(-1, n_dof) + ret, qpos_batch = solver.get_ik( + target_xpos=pose_batch, + qpos_seed=joint_seed_batch, + return_all_solutions=False, + ) + ret = ret.reshape(len(local_env_ids), n_batch) + qpos = qpos_batch.reshape(len(local_env_ids), n_batch, n_dof) + return ret, qpos + + def _init_control_parts(self, control_parts: Dict[str, List[str]]) -> None: + """Initialize the control parts of the robot. + + Args: + control_parts (Dict[str, List[str]]): A dictionary where keys are control part names and values are lists of + joint names or regular expressions that match joint names. + """ + joint_name_to_ids = {name: i for i, name in enumerate(self.joint_names)} + for name, joint_names in control_parts.items(): + # convert joint_names which is a regular expression to a list of joint names + joint_names_expanded = [] + for jn in joint_names: + if is_regular_expression(jn): + _, names, _ = resolve_matching_names_values( + {jn: None}, self.joint_names + ) + joint_names_expanded.extend(names) + else: + joint_names_expanded.append(jn) + + self._joint_ids[name] = [ + joint_name_to_ids[joint_name] + for joint_name in joint_names_expanded + if joint_name in joint_name_to_ids + ] + if len(self._joint_ids[name]) != len(joint_names_expanded): + logger.log_error( + f"joint names in control part '{name}' do not match the robot's joint names. The full joint names are: {self.joint_names}." + ) + self.cfg.control_parts[name] = joint_names_expanded + + # Initialize control groups + self._control_groups = self._extract_control_groups() + + def init_solver(self, cfg: Union[SolverCfg, Dict[str, SolverCfg]]) -> None: + """Initialize the kinematic solver for the robot. + + Args: + cfg (Union[SolverCfg, Dict[str, SolverCfg]]): The configuration for the kinematic solver. + """ + self.cfg: RobotCfg + + if isinstance(cfg, SolverCfg): + if self.control_parts: + logger.log_error( + "Control parts are defined in the robot configuration, solver_cfg must be a dictionary." + ) + + if cfg.urdf_path is None: + cfg.urdf_path = self.cfg.fpath + self._solvers["default"] = cfg.init_solver(device=self.device) + elif isinstance(cfg, Dict): + if isinstance(self.cfg.control_parts, Dict) is False: + logger.log_error( + "When `solver_cfg` is a dictionary, `control_parts` must also be a dictionary." + ) + + # If solver_cfg is a dictionary, iterate through it to create solvers + for name, solver_cfg in cfg.items(): + if solver_cfg.urdf_path is None: + solver_cfg.urdf_path = self.cfg.fpath + _, part_names, value = resolve_matching_names_values( + {name: solver_cfg}, self.cfg.control_parts.keys() + ) + for part_name in part_names: + if ( + not hasattr(solver_cfg, "joint_names") + or solver_cfg.joint_names is None + ): + solver_cfg.joint_names = self.cfg.control_parts[part_name] + self._solvers[name] = solver_cfg.init_solver(device=self.device) + + def get_solver(self, name: Optional[str] = None) -> Optional[BaseSolver]: + """Get the kinematic solver for a specific control part. + + Args: + name (str, optional): The name of the control part to get the solver for. If None, the default part is used. + + Returns: + Optional[BaseSolver]: The kinematic solver for the specified control part, or None if not found. + """ + + if not self._solvers: + logger.log_error( + "No solvers are defined for the robot. Please ensure that the robot has solvers configured." + ) + return None + + return self._solvers.get(name if name is not None else "default", None) + + def get_control_part_base_pose( + self, + name: Optional[str] = None, + env_ids: Optional[Sequence[int]] = None, + to_matrix: bool = False, + ) -> torch.Tensor: + """Retrieves the base pose of the control part for a specified robot. + + Args: + name (Optional[str]): The name of the control part the solver adhere to. If None, the default solver is used. + env_ids (Optional[Sequence[int]]): A sequence of environment IDs to specify the environments. + If None, all indices are used. + to_matrix (bool): If True, returns the pose in the form of a 4x4 matrix. + + Returns: + The pose of the specified link in the form of a matrix. + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + + root_link_name = None + if name in self._control_groups: + root_link_name = self._control_groups[name].link_names[0] + + return self.get_link_pose( + link_name=root_link_name, env_ids=local_env_ids, to_matrix=to_matrix + ) + + def _extract_control_groups(self) -> Dict[str, ControlGroup]: + r"""Extract control groups from the active joint names. + + This method creates a dictionary of control groups where each control + group is associated with its corresponding joint names. It utilizes + the `_extract_control_group` method to populate the control groups. + + Returns: + Dict[str, ControlGroup]: A dictionary mapping control group names + to their corresponding ControlGroup instances. + """ + if not self.control_parts: + return {} + + control_groups = { + control_group_name: self._extract_control_group(joint_names) + for control_group_name, joint_names in self.control_parts.items() + } + + return control_groups + + def _extract_control_group(self, joint_names: List[str]) -> ControlGroup: + r"""Extract a control group from the given list of joint names. + + Args: + joint_names (List[str]): A list of joint names + to be included in the control group. + + Returns: + ControlGroup: An instance of ControlGroup containing the specified joints + and their associated links. + """ + control_group = ControlGroup() + joint_id_list = [] + + for joint_name in joint_names: + if joint_name in self.joint_names: + joint_index = self.joint_names.index(joint_name) + joint_id_list.append(joint_index) + control_group.joint_names.append(joint_name) + + # Set root link for first joint + if len(control_group.link_names) == 0: + parent_names = self._entities[0].get_ancestral_link_names( + joint_index + ) + control_group.link_names.extend(parent_names) + + child_name = self._entities[0].get_child_link_name(joint_index) + control_group.link_names.append(child_name) + + control_group.joint_ids = joint_id_list + return control_group + + def build_pk_serial_chain(self) -> None: + """Build the kinematic serial chain for the robot. + + This method is mainly used for robot learning scenarios, for example: + - Imitation learning dataset generation. + """ + self.pk_serial_chain = self.cfg.build_pk_serial_chain(device=self.device) + + def destroy(self) -> None: + return super().destroy() diff --git a/embodichain/lab/sim/objects/soft_object.py b/embodichain/lab/sim/objects/soft_object.py new file mode 100644 index 00000000..fcbf1c95 --- /dev/null +++ b/embodichain/lab/sim/objects/soft_object.py @@ -0,0 +1,376 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch +import dexsim +import numpy as np +from functools import cached_property + +from dataclasses import dataclass +from typing import List, Sequence, Optional, Union + +from dexsim.models import MeshObject +from dexsim.engine import PhysicsScene +from dexsim.types import SoftBodyGPUAPIReadWriteType +from embodichain.lab.sim.common import ( + BatchEntity, +) +from embodichain.utils.math import ( + matrix_from_euler, +) +from embodichain.utils import logger +from embodichain.lab.sim.cfg import ( + SoftObjectCfg, +) +from embodichain.utils.math import xyz_quat_to_4x4_matrix + + +@dataclass +class SoftBodyData: + """Data manager for soft body + + Note: + 1. The pose data managed by dexsim is in the format of (qx, qy, qz, qw, x, y, z), but in EmbodySim, we use (x, y, z, qw, qx, qy, qz) format. + """ + + def __init__( + self, entities: List[MeshObject], ps: PhysicsScene, device: torch.device + ) -> None: + """Initialize the SoftBodyData. + + Args: + entities (List[MeshObject]): List of MeshObjects representing the soft bodies. + ps (PhysicsScene): The physics scene. + device (torch.device): The device to use for the soft body data. + """ + self.entities = entities + # TODO: soft body data can only be stored in cuda device for now. + self.device = device + # TODO: inorder to retrieve arena position, we need to access the node of each entity. + self._arena_positions = self._get_arena_position() + self.ps = ps + self.num_instances = len(entities) + + softbodies = [ + self.entities[i].get_physical_body() for i in range(self.num_instances) + ] + self.n_collision_vertices = softbodies[0].get_num_vertices() + self.n_sim_vertices = softbodies[0].get_num_sim_vertices() + + self._rest_position_buffer = torch.empty( + (self.num_instances, self.n_collision_vertices, 4), + device=self.device, + dtype=torch.float32, + ) + for i, softbody in enumerate(softbodies): + self._rest_position_buffer[i] = softbody.get_position_inv_mass_buffer() + + self._rest_sim_position_buffer = torch.empty( + (self.num_instances, self.n_sim_vertices, 4), + device=self.device, + dtype=torch.float32, + ) + + for i, softbody in enumerate(softbodies): + self._rest_sim_position_buffer[ + i + ] = softbody.get_sim_position_inv_mass_buffer() + + self._collision_position_buffer = torch.zeros( + (self.num_instances, self.n_collision_vertices, 4), + device=self.device, + dtype=torch.float32, + ) + self._sim_vertex_velocity_buffer = torch.zeros( + (self.num_instances, self.n_sim_vertices, 4), + device=self.device, + dtype=torch.float32, + ) + self._sim_vertex_position_buffer = torch.zeros( + (self.num_instances, self.n_sim_vertices, 4), + device=self.device, + dtype=torch.float32, + ) + + def _get_arena_position(self): + n_env = len(self.entities) + arena_positions = torch.empty( + (n_env, 3), device=self.device, dtype=torch.float32 + ) + for i, entity in enumerate(self.entities): + arena = entity.node.get_parent() + arena_position = arena.get_world_pose()[:3, 3] + arena_positions[i] = torch.as_tensor( + arena_position, device=self.device, dtype=torch.float32 + ) + return arena_positions + + @property + def rest_collision_vertices(self): + """Get the rest position buffer of the soft bodies.""" + return self._rest_position_buffer[:, :, :3].clone() + + @property + def rest_sim_vertices(self): + """Get the rest sim position buffer of the soft bodies.""" + return self._rest_sim_position_buffer[:, :, :3].clone() + + @property + def collision_position_buffer(self): + """Get the current vertex position buffer of the soft bodies.""" + for i, softbody in enumerate(self.soft_bodies): + self._collision_position_buffer[i] = softbody.get_position_inv_mass_buffer() + return self._collision_position_buffer.clone() + + @property + def sim_vertex_position_buffer(self): + """Get the current sim vertex position buffer of the soft bodies.""" + for i, softbody in enumerate(self.soft_bodies): + self._sim_vertex_position_buffer[ + i + ] = softbody.get_sim_position_inv_mass_buffer() + return self._sim_vertex_position_buffer.clone() + + @property + def sim_vertex_velocity_buffer(self): + """Get the current vertex velocity buffer of the soft bodies.""" + for i, softbody in enumerate(self.soft_bodies): + self._sim_vertex_velocity_buffer[ + i + ] = softbody.get_sim_position_inv_mass_buffer() + return self._sim_vertex_velocity_buffer.clone() + + +class SoftObject(BatchEntity): + """SoftObject represents a batch of rigid body in the simulation.""" + + def __init__( + self, + cfg: SoftObjectCfg, + entities: List[MeshObject] = None, + device: torch.device = torch.device("cpu"), + ) -> None: + self._world = dexsim.default_world() + self._ps = self._world.get_physics_scene() + self._all_indices = torch.arange( + len(entities), dtype=torch.int32, device=device + ) + + self._data = SoftBodyData(entities=entities, ps=self._ps, device=device) + + # TODO: soft body physical attribute is already set in soft body creation(embodichain/lab/sim/utility/sim_utils.py load_soft_object_from_cfg) + self._world.update(0.001) + + super().__init__(cfg=cfg, entities=entities, device=device) + + # set default collision filter + self._set_default_collision_filter() + + def _set_default_collision_filter(self) -> None: + collision_filter_data = torch.zeros( + size=(self.num_instances, 4), dtype=torch.int32 + ) + for i in range(self.num_instances): + collision_filter_data[i, 0] = i + collision_filter_data[i, 1] = 1 + self.set_collision_filter(collision_filter_data) + + def set_collision_filter( + self, filter_data: torch.Tensor, env_ids: Optional[Sequence[int]] = None + ) -> None: + """set collision filter data for the rigid object. + + Args: + filter_data (torch.Tensor): [N, 4] of int. + First element of each object is arena id. + If 2nd element is 0, the object will collision with all other objects in world. + 3rd and 4th elements are not used currently. + + env_ids (Optional[Sequence[int]], optional): Environment indices. If None, then all indices are used. Defaults to None. + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + + if len(local_env_ids) != len(filter_data): + logger.log_error( + f"Length of env_ids {len(local_env_ids)} does not match pose length {len(filter_data)}." + ) + + filter_data_np = filter_data.cpu().numpy().astype(np.uint32) + for i, env_idx in enumerate(local_env_ids): + self._entities[env_idx].get_physical_body().set_collision_filter_data( + filter_data_np[i] + ) + + @property + def body_data(self) -> Optional[SoftBodyData]: + """Get the soft body data manager for this rigid object. + + Returns: + SoftBodyData: The rigid body data manager. + """ + return self._data + + def set_local_pose( + self, pose: torch.Tensor, env_ids: Optional[Sequence[int]] = None + ) -> None: + """Set local pose of the rigid object. + + Args: + pose (torch.Tensor): The local pose of the rigid object with shape (N, 7) or (N, 4, 4). + env_ids (Optional[Sequence[int]], optional): Environment indices. If None, then all indices are used. + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + + if len(local_env_ids) != len(pose): + logger.log_error( + f"Length of env_ids {len(local_env_ids)} does not match pose length {len(pose)}." + ) + + if pose.dim() == 2 and pose.shape[1] == 7: + pose4x4 = xyz_quat_to_4x4_matrix(pose) + elif pose.dim() == 3 and pose.shape[1:3] == (4, 4): + pose4x4 = pose + else: + logger.log_error( + f"Invalid pose shape {pose.shape}. Expected (N, 7) or (N, 4, 4)." + ) + + for i, env_idx in enumerate(local_env_ids): + # TODO: soft body cannot directly set by `set_local_pose` currently. + rest_collision_vertices = self.body_data.rest_collision_vertices[i] + rest_sim_vertices = self.body_data.rest_sim_vertices[i] + rotation = pose4x4[i][:3, :3] + translation = pose4x4[i][:3, 3] + + # apply transformation to local rest vertices and back + rest_collision_vertices_local = ( + rest_collision_vertices - self._data._arena_positions[i] + ) + transformed_collision_vertices = ( + rest_collision_vertices_local @ rotation.T + translation + ) + transformed_collision_vertices = ( + transformed_collision_vertices + self._data._arena_positions[i] + ) + + rest_sim_vertices_local = rest_sim_vertices - self._data._arena_positions[i] + transformed_sim_vertices = ( + rest_sim_vertices_local @ rotation.T + translation + ) + transformed_sim_vertices = ( + transformed_sim_vertices + self._data._arena_positions[i] + ) + + # apply vertices to soft body + soft_body = self._entities[env_idx].get_physical_body() + collision_position_buffer = soft_body.get_position_inv_mass_buffer() + sim_position_buffer = soft_body.get_sim_position_inv_mass_buffer() + sim_velocity_buffer = soft_body.get_sim_velocity_buffer() + + collision_position_buffer[:, :3] = transformed_collision_vertices + sim_position_buffer[:, :3] = transformed_sim_vertices + sim_velocity_buffer[:, :3] = 0.0 + + soft_body.mark_dirty(SoftBodyGPUAPIReadWriteType.ALL) + # TODO: currently soft body has no wake up interface, use set_wake_counter and pass in a positive value to wake it up + soft_body.set_wake_counter(0.4) + + def get_rest_collision_vertices(self) -> torch.Tensor: + """Get the rest collision vertices of the soft object. + + Returns: + torch.Tensor: The rest collision vertices with shape (N, num_collision_vertices, 3). + """ + return self.body_data.rest_collision_vertices + + def get_rest_sim_vertices(self) -> torch.Tensor: + """Get the rest sim vertices of the soft object. + + Returns: + torch.Tensor: The rest sim vertices with shape (N, num_sim_vertices, 3). + """ + return self.body_data.rest_sim_vertices + + def get_current_collision_vertices(self) -> torch.Tensor: + """Get the current collision vertices of the soft object. + + Returns: + torch.Tensor: The current collision vertices with shape (N, num_collision_vertices, 3). + """ + return self.body_data.collision_position_buffer + + def get_current_sim_vertices(self) -> torch.Tensor: + """Get the current sim vertices of the soft object. + + Returns: + torch.Tensor: The current sim vertices with shape (N, num_sim_vertices, 3). + """ + return self.body_data.sim_vertex_position_buffer + + def get_current_sim_vertex_velocities(self) -> torch.Tensor: + """Get the current sim vertex velocities of the soft object. + + Returns: + torch.Tensor: The current sim vertex velocities with shape (N, num_sim_vertices, 3). + """ + return self.body_data.sim_vertex_velocity_buffer + + def get_local_pose(self, to_matrix: bool = False) -> torch.Tensor: + """Get local pose of the rigid object. + + Args: + to_matrix (bool, optional): If True, return the pose as a 4x4 matrix. If False, return as (x, y, z, qw, qx, qy, qz). Defaults to False. + + Returns: + torch.Tensor: The local pose of the rigid object with shape (N, 7) or (N, 4, 4) depending on `to_matrix`. + """ + raise NotImplementedError("Getting local pose for SoftObject is not supported.") + + def reset(self, env_ids: Optional[Sequence[int]] = None) -> None: + local_env_ids = self._all_indices if env_ids is None else env_ids + num_instances = len(local_env_ids) + + # TODO: set attr for soft body after loading in physx scene + + # rest soft body to init_pos + pos = torch.as_tensor( + self.cfg.init_pos, dtype=torch.float32, device=self.device + ) + rot = ( + torch.as_tensor(self.cfg.init_rot, dtype=torch.float32, device=self.device) + * torch.pi + / 180.0 + ) + pos = pos.unsqueeze(0).repeat(num_instances, 1) + rot = rot.unsqueeze(0).repeat(num_instances, 1) + mat = matrix_from_euler(rot, "XYZ") + pose = ( + torch.eye(4, dtype=torch.float32, device=self.device) + .unsqueeze(0) + .repeat(num_instances, 1, 1) + ) + pose[:, :3, 3] = pos + pose[:, :3, :3] = mat + self.set_local_pose(pose, env_ids=local_env_ids) + + def destroy(self) -> None: + # TODO: not tested yet + env = self._world.get_env() + arenas = env.get_all_arenas() + if len(arenas) == 0: + arenas = [env] + for i, entity in enumerate(self._entities): + arenas[i].remove_actor(entity) diff --git a/embodichain/lab/sim/planners/base_planner.py b/embodichain/lab/sim/planners/base_planner.py new file mode 100644 index 00000000..d85e026d --- /dev/null +++ b/embodichain/lab/sim/planners/base_planner.py @@ -0,0 +1,201 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import numpy as np +from abc import ABC, abstractmethod +from typing import Dict, List, Tuple, Union, Optional +import matplotlib.pyplot as plt + +from embodichain.lab.sim.planners.utils import TrajectorySampleMethod +from embodichain.utils import logger + + +class BasePlanner(ABC): + r"""Base class for trajectory planners. + + This class provides common functionality that can be shared across different + planner implementations, such as constraint checking and trajectory visualization. + + Args: + dofs: Number of degrees of freedom + max_constraints: Dictionary containing 'velocity' and 'acceleration' constraints + """ + + def __init__(self, dofs: int, max_constraints: Dict[str, List[float]]): + self.dofs = dofs + self.max_constraints = max_constraints + + @abstractmethod + def plan( + self, + current_state: Dict, + target_states: List[Dict], + **kwargs, + ) -> Tuple[ + bool, + Optional[np.ndarray], + Optional[np.ndarray], + Optional[np.ndarray], + Optional[np.ndarray], + float, + ]: + r"""Execute trajectory planning. + + This method must be implemented by subclasses to provide the specific + planning algorithm. + + Args: + current_state: Dictionary containing 'position', 'velocity', 'acceleration' for current state + target_states: List of dictionaries containing target states + + Returns: + Tuple of (success, positions, velocities, accelerations, times, duration): + - success: bool, whether planning succeeded + - positions: np.ndarray (N, DOF), joint positions along trajectory + - velocities: np.ndarray (N, DOF), joint velocities along trajectory + - accelerations: np.ndarray (N, DOF), joint accelerations along trajectory + - times: np.ndarray (N,), time stamps for each point + - duration: float, total trajectory duration + """ + logger.log_error("Subclasses must implement plan() method", NotImplementedError) + + def is_satisfied_constraint( + self, velocities: np.ndarray, accelerations: np.ndarray + ) -> bool: + r"""Check if the trajectory satisfies velocity and acceleration constraints. + + This method checks whether the given velocities and accelerations satisfy + the constraints defined in max_constraints. It allows for some tolerance + to account for numerical errors in dense waypoint scenarios. + + Args: + velocities: Velocity array (N, DOF) where N is the number of trajectory points + accelerations: Acceleration array (N, DOF) where N is the number of trajectory points + + Returns: + bool: True if all constraints are satisfied, False otherwise + + Note: + - Allows 10% tolerance for velocity constraints + - Allows 25% tolerance for acceleration constraints + - Prints exceed information if constraints are violated + - Assumes symmetric constraints (velocities and accelerations can be positive or negative) + """ + # Convert max_constraints to symmetric format for constraint checking + # This assumes symmetric constraints (common for most planners) + vlims = np.array([[-v, v] for v in self.max_constraints["velocity"]]) + alims = np.array([[-a, a] for a in self.max_constraints["acceleration"]]) + + vel_check = np.all((velocities >= vlims[:, 0]) & (velocities <= vlims[:, 1])) + acc_check = np.all( + (accelerations >= alims[:, 0]) & (accelerations <= alims[:, 1]) + ) + + # 超限情况 + if not vel_check: + vel_exceed_info = [] + min_vel = np.min(velocities, axis=0) + max_vel = np.max(velocities, axis=0) + for i in range(self.dofs): + exceed_percentage = 0 + max_vel_limit = self.max_constraints["velocity"][i] + if min_vel[i] < -max_vel_limit: + exceed_percentage = (min_vel[i] + max_vel_limit) / max_vel_limit + if max_vel[i] > max_vel_limit: + temp = (max_vel[i] - max_vel_limit) / max_vel_limit + if temp > exceed_percentage: + exceed_percentage = temp + vel_exceed_info.append(exceed_percentage * 100) + logger.log_info(f"Velocity exceed info: {vel_exceed_info} percentage") + + if not acc_check: + acc_exceed_info = [] + min_acc = np.min(accelerations, axis=0) + max_acc = np.max(accelerations, axis=0) + for i in range(self.dofs): + exceed_percentage = 0 + max_acc_limit = self.max_constraints["acceleration"][i] + if min_acc[i] < -max_acc_limit: + exceed_percentage = (min_acc[i] + max_acc_limit) / max_acc_limit + if max_acc[i] > max_acc_limit: + temp = (max_acc[i] - max_acc_limit) / max_acc_limit + if temp > exceed_percentage: + exceed_percentage = temp + acc_exceed_info.append(exceed_percentage * 100) + logger.log_info(f"Acceleration exceed info: {acc_exceed_info} percentage") + + return vel_check and acc_check + + def plot_trajectory( + self, positions: np.ndarray, velocities: np.ndarray, accelerations: np.ndarray + ) -> None: + r"""Plot trajectory data. + + This method visualizes the trajectory by plotting position, velocity, and + acceleration curves for each joint over time. It also displays the constraint + limits for reference. + + Args: + positions: Position array (N, DOF) where N is the number of trajectory points + velocities: Velocity array (N, DOF) where N is the number of trajectory points + accelerations: Acceleration array (N, DOF) where N is the number of trajectory points + + Note: + - Creates a 3-subplot figure (position, velocity, acceleration) + - Shows constraint limits as dashed lines + - Requires matplotlib to be installed + """ + time_step = 0.01 + time_steps = np.arange(positions.shape[0]) * time_step + fig, axs = plt.subplots(3, 1, figsize=(10, 8)) + + for i in range(self.dofs): + axs[0].plot(time_steps, positions[:, i], label=f"Joint {i+1}") + axs[1].plot(time_steps, velocities[:, i], label=f"Joint {i+1}") + axs[2].plot(time_steps, accelerations[:, i], label=f"Joint {i+1}") + + # Plot velocity constraints (only for first joint to avoid clutter) + # Convert max_constraints to symmetric format for visualization + if self.dofs > 0: + max_vel = self.max_constraints["velocity"][0] + max_acc = self.max_constraints["acceleration"][0] + axs[1].plot( + time_steps, + [-max_vel] * len(time_steps), + "k--", + label="Max Velocity", + ) + axs[1].plot(time_steps, [max_vel] * len(time_steps), "k--") + # Plot acceleration constraints (only for first joint to avoid clutter) + axs[2].plot( + time_steps, + [-max_acc] * len(time_steps), + "k--", + label="Max Accleration", + ) + axs[2].plot(time_steps, [max_acc] * len(time_steps), "k--") + + axs[0].set_title("Position") + axs[1].set_title("Velocity") + axs[2].set_title("Acceleration") + + for ax in axs: + ax.set_xlabel("Time [s]") + ax.legend() + ax.grid() + + plt.tight_layout() + plt.show() diff --git a/embodichain/lab/sim/planners/motion_generator.py b/embodichain/lab/sim/planners/motion_generator.py new file mode 100644 index 00000000..165844f7 --- /dev/null +++ b/embodichain/lab/sim/planners/motion_generator.py @@ -0,0 +1,604 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch +import numpy as np +from typing import Dict, List, Tuple, Union, Optional, Any +from enum import Enum +from scipy.spatial.transform import Rotation, Slerp + +from embodichain.lab.sim.planners.toppra_planner import ToppraPlanner +from embodichain.lab.sim.planners.utils import TrajectorySampleMethod +from embodichain.lab.sim.objects.robot import Robot +from embodichain.utils import logger + + +class PlannerType(Enum): + r"""Enumeration for different planner types.""" + TOPPRA = "toppra" + """TOPPRA planner for time-optimal trajectory planning.""" + + +class MotionGenerator: + r"""Unified motion generator for robot trajectory planning. + + This class provides a unified interface for trajectory planning with and without + collision checking. It supports V3 environment interfaces and can use different + types of planners (ToppraPlanner, RRT, PRM, etc.) for trajectory generation. + + Args: + robot: Robot agent object (must support compute_fk, compute_ik, dof, get_joint_ids) + uid: Unique identifier for the robot (optional) + sim: Simulation environment object (optional, reserved for future collision checking) + planner_type: Type of planner to use (default: "toppra") + default_velocity: Default velocity limits for each joint (rad/s) + default_acceleration: Default acceleration limits for each joint (rad/s²) + collision_margin: Safety margin for collision checking (meters, reserved for future use) + **kwargs: Additional arguments passed to planner initialization + """ + + def __init__( + self, + robot: Robot, + uid: str, + sim=None, + planner_type: Union[str, PlannerType] = "toppra", + default_velocity: float = 0.2, + default_acceleration: float = 0.5, + collision_margin: float = 0.01, + **kwargs, + ): + self.robot = robot + self.sim = sim + self.collision_margin = collision_margin + self.uid = uid + + # Get robot DOF using get_joint_ids for specified control part (None for whole body) + self.dof = len(robot.get_joint_ids(uid)) + + # Create planner based on planner_type + self.planner_type = self._parse_planner_type(planner_type) + self.planner = self._create_planner( + self.planner_type, default_velocity, default_acceleration, **kwargs + ) + + def _parse_planner_type(self, planner_type: Union[str, PlannerType]) -> str: + r"""Parse planner type from string or enum. + + Args: + planner_type: Planner type as string or PlannerType enum + + Returns: + Planner type as string + """ + if isinstance(planner_type, PlannerType): + return planner_type.value + elif isinstance(planner_type, str): + planner_type_lower = planner_type.lower() + # Validate planner type + valid_types = [e.value for e in PlannerType] + if planner_type_lower not in valid_types: + logger.log_warning( + f"Unknown planner type '{planner_type}', using 'toppra'. " + f"Valid types: {valid_types}" + ) + return "toppra" + return planner_type_lower + else: + logger.log_error( + f"planner_type must be str or PlannerType, got {type(planner_type)}", + TypeError, + ) + + def _create_planner( + self, + planner_type: str, + default_velocity: float, + default_acceleration: float, + **kwargs, + ) -> Any: + r"""Create planner instance based on planner type. + + Args: + planner_type: Type of planner to create + default_velocity: Default velocity limit + default_acceleration: Default acceleration limit + **kwargs: Additional arguments for planner initialization + + Returns: + Planner instance + """ + # Get constraints from robot or use defaults + max_constraints = self._get_constraints( + default_velocity, default_acceleration, **kwargs + ) + + if planner_type == "toppra": + return ToppraPlanner(self.dof, max_constraints) + else: + logger.log_error( + f"Unknown planner type '{planner_type}'. " + f"Supported types: {[e.value for e in PlannerType]}", + ValueError, + ) + + def _get_constraints( + self, default_velocity: float, default_acceleration: float, **kwargs + ) -> Dict[str, List[float]]: + r"""Get velocity and acceleration constraints for the robot. + + Priority: + 1. kwargs['max_constraints'] if provided + 2. Robot's built-in constraints (if available) + 3. Default values + + Args: + default_velocity: Default velocity limit + default_acceleration: Default acceleration limit + **kwargs: Additional arguments + + Returns: + Dictionary with 'velocity' and 'acceleration' constraints + """ + # Check if constraints are provided in kwargs + if "max_constraints" in kwargs and kwargs["max_constraints"] is not None: + constraints = kwargs["max_constraints"] + if isinstance(constraints, dict) and "velocity" in constraints: + return constraints + + # Try to get constraints from robot (if available) + # TODO: Add robot.get_joint_limits() or similar if available in future + + # Use default constraints + return { + "velocity": [default_velocity] * self.dof, + "acceleration": [default_acceleration] * self.dof, + } + + def _create_state_dict( + self, position: np.ndarray, velocity: Optional[np.ndarray] = None + ) -> Dict: + r"""Create a state dictionary for trajectory planning. + + Args: + position: Joint positions + velocity: Joint velocities (optional, defaults to zeros) + acceleration: Joint accelerations (optional, defaults to zeros) + + Returns: + State dictionary with 'position', 'velocity', 'acceleration' + """ + if velocity is None: + velocity = np.zeros(self.dof) + + if isinstance(position, torch.Tensor) | isinstance(position, np.ndarray): + position = position.squeeze() + + return { + "position": position.tolist() + if isinstance(position, np.ndarray) + else position, + "velocity": velocity.tolist() + if isinstance(velocity, np.ndarray) + else velocity, + "acceleration": [0.0] * self.dof, + } + + def plan( + self, + current_state: Dict, + target_states: List[Dict], + sample_method: TrajectorySampleMethod = TrajectorySampleMethod.TIME, + sample_interval: Union[float, int] = 0.01, + **kwargs, + ) -> Tuple[ + bool, + Optional[np.ndarray], + Optional[np.ndarray], + Optional[np.ndarray], + Optional[np.ndarray], + float, + ]: + r"""Plan trajectory without collision checking. + + This method generates a smooth trajectory using the selected planner that satisfies + velocity and acceleration constraints, but does not check for collisions. + + Args: + current_state: Dictionary containing current state: + - "position": Current joint positions (required) + - "velocity": Current joint velocities (optional, defaults to zeros) + - "acceleration": Current joint accelerations (optional, defaults to zeros) + target_states: List of target state dictionaries, each with same format as current_state + sample_method: Sampling method (TIME or QUANTITY) + sample_interval: Sampling interval (time in seconds for TIME method, or number of points for QUANTITY) + **kwargs: Additional arguments + + Returns: + Tuple of (success, positions, velocities, accelerations, times, duration): + - success: bool, whether planning succeeded + - positions: np.ndarray (N, DOF), joint positions along trajectory + - velocities: np.ndarray (N, DOF), joint velocities along trajectory + - accelerations: np.ndarray (N, DOF), joint accelerations along trajectory + - times: np.ndarray (N,), time stamps for each point + - duration: float, total trajectory duration + """ + # Validate inputs + if len(current_state["position"]) != self.dof: + logger.log_warning( + f"Current state position dimension {len(current_state['position'])} " + f"does not match robot DOF {self.dof}" + ) + return False, None, None, None, None, 0.0 + + for i, target in enumerate(target_states): + if len(target["position"]) != self.dof: + logger.log_warning( + f"Target state {i} position dimension {len(target['position'])} " + f"does not match robot DOF {self.dof}" + ) + return False, None, None, None, None, 0.0 + + # Plan trajectory using selected planner + ( + success, + positions, + velocities, + accelerations, + times, + duration, + ) = self.planner.plan( + current_state=current_state, + target_states=target_states, + sample_method=sample_method, + sample_interval=sample_interval, + ) + + return success, positions, velocities, accelerations, times, duration + + def plan_with_collision( + self, + current_state: Dict, + target_states: List[Dict], + sample_method: TrajectorySampleMethod = TrajectorySampleMethod.TIME, + sample_interval: Union[float, int] = 0.01, + collision_check_interval: float = 0.01, + **kwargs, + ) -> None: + r"""Plan trajectory with collision checking. + + TODO: This method is not yet implemented. It should: + 1. Generate a trajectory using the selected planner + 2. Check for collisions along the trajectory + 3. Return failure if collisions are detected + """ + pass + + def create_discrete_trajectory( + self, + xpos_list: Optional[List[np.ndarray]] = None, + qpos_list: Optional[List[np.ndarray]] = None, + is_use_current_qpos: bool = True, + is_linear: bool = False, + sample_method: TrajectorySampleMethod = TrajectorySampleMethod.QUANTITY, + sample_num: Union[float, int] = 20, + qpos_seed: Optional[np.ndarray] = None, + **kwargs, + ) -> Tuple[List[np.ndarray], List[np.ndarray]]: + r"""Generate a discrete trajectory between waypoints using cartesian or joint space interpolation. + + This method supports two trajectory planning approaches: + 1. Linear interpolation: Fast, uniform spacing, no dynamics constraints + 2. Planner-based: Smooth, considers velocity/acceleration limits, realistic motion + + Args: + xpos_list: List of waypoints as 4x4 transformation matrices (optional) + qpos_list: List of joint configurations (optional) + is_use_current_qpos: Whether to use current joint angles as starting point + is_linear: If True, use cartesian linear interpolation, else joint space + sample_method: Sampling method (QUANTITY or TIME) + sample_num: Number of interpolated points for final trajectory + qpos_seed: Initial joint configuration for IK solving + **kwargs: Additional arguments + + Returns: + A tuple containing: + - List[np.ndarray]: Joint space trajectory as a list of joint configurations + - List[np.ndarray]: Cartesian space trajectory as a list of 4x4 matrices + """ + + def interpolate_xpos( + current_xpos: np.ndarray, target_xpos: np.ndarray, num_samples: int + ) -> List[np.ndarray]: + """Interpolate between two poses using Slerp for rotation and linear for translation.""" + if num_samples < 2: + num_samples = 2 + + slerp = Slerp( + [0, 1], + Rotation.from_matrix([current_xpos[:3, :3], target_xpos[:3, :3]]), + ) + interpolated_poses = [] + for s in np.linspace(0, 1, num_samples): + interp_rot = slerp(s).as_matrix() + interp_trans = (1 - s) * current_xpos[:3, 3] + s * target_xpos[:3, 3] + interp_pose = np.eye(4) + interp_pose[:3, :3] = interp_rot + interp_pose[:3, 3] = interp_trans + interpolated_poses.append(interp_pose) + return interpolated_poses + + def calculate_point_allocations( + xpos_list: List[np.ndarray], + step_size: float = 0.002, + angle_step: float = np.pi / 90, + ) -> List[int]: + """Calculate number of interpolation points between each pair of waypoints.""" + point_allocations = [] + + for i in range(len(xpos_list) - 1): + start_pose = xpos_list[i] + end_pose = xpos_list[i + 1] + + if isinstance(start_pose, torch.Tensor): + start_pose = start_pose.squeeze().cpu().numpy() + if isinstance(end_pose, torch.Tensor): + end_pose = end_pose.squeeze().cpu().numpy() + + pos_dist = np.linalg.norm(end_pose[:3, 3] - start_pose[:3, 3]) + pos_points = max(1, int(pos_dist / step_size)) + + angle_diff = Rotation.from_matrix( + start_pose[:3, :3].T @ end_pose[:3, :3] + ) + angle = abs(angle_diff.as_rotvec()).max() + rot_points = max(1, int(angle / angle_step)) + + num_points = max(pos_points, rot_points) + point_allocations.append(num_points) + + return point_allocations + + # Handle input arguments + if qpos_list is not None: + qpos_list = np.asarray(qpos_list) + qpos_tensor = ( + torch.tensor(qpos_list) + if not isinstance(qpos_list, torch.Tensor) + else qpos_list + ) + xpos_list = [ + self.robot.compute_fk(qpos=q, name=self.uid, to_matrix=True) + .squeeze(0) + .cpu() + .numpy() + for q in qpos_tensor + ] + + if xpos_list is None: + logger.log_warning("Either xpos_list or qpos_list must be provided") + return [], [] + + # Get current position if needed + if is_use_current_qpos: + joint_ids = self.robot.get_joint_ids(self.uid) + qpos_tensor = self.robot.get_qpos() + # qpos_tensor shape: (batch, dof), usually batch=1 + current_qpos = qpos_tensor[0, joint_ids] + + current_xpos = ( + self.robot.compute_fk(qpos=current_qpos, name=self.uid, to_matrix=True) + .squeeze(0) + .cpu() + .numpy() + ) + + # Check if current position is significantly different from first waypoint + pos_diff = np.linalg.norm(current_xpos[:3, 3] - xpos_list[0][:3, 3]) + rot_diff = np.linalg.norm(current_xpos[:3, :3] - xpos_list[0][:3, :3]) + + if pos_diff > 0.001 or rot_diff > 0.01: + xpos_list = np.concatenate( + [current_xpos[None, :, :], xpos_list], axis=0 + ) + if qpos_list is not None: + qpos_list = np.concatenate( + [current_qpos[None, :], qpos_list], axis=0 + ) + + if qpos_seed is None and qpos_list is not None: + qpos_seed = qpos_list[0] + + # Input validation + if len(xpos_list) < 2: + logger.log_warning("xpos_list must contain at least 2 points") + return [], [] + + # Calculate point allocations for interpolation + interpolated_point_allocations = calculate_point_allocations( + xpos_list, step_size=0.002, angle_step=np.pi / 90 + ) + + # Generate trajectory + interpolate_qpos_list = [] + if is_linear or qpos_list is None: + # Linear cartesian interpolation + for i in range(len(xpos_list) - 1): + interpolated_poses = interpolate_xpos( + xpos_list[i], xpos_list[i + 1], interpolated_point_allocations[i] + ) + for xpos in interpolated_poses: + success, qpos = self.robot.compute_ik( + pose=xpos, joint_seed=qpos_seed, name=self.uid + ) + + if isinstance(success, torch.Tensor): + is_success = bool(success.all()) + elif isinstance(success, np.ndarray): + is_success = bool(np.all(success)) + elif isinstance(success, (list, tuple)): + is_success = all(success) + else: + is_success = bool(success) + + if isinstance(qpos, torch.Tensor): + has_nan = torch.isnan(qpos).any().item() + else: + has_nan = np.isnan(qpos).any() + + if not is_success or qpos is None or has_nan: + logger.log_debug( + f"IK failed or returned nan at pose, skipping this point." + ) + continue + + interpolate_qpos_list.append( + qpos[0] if isinstance(qpos, (np.ndarray, list)) else qpos + ) + qpos_seed = ( + qpos[0] if isinstance(qpos, (np.ndarray, list)) else qpos + ) + else: + # Joint space interpolation + interpolate_qpos_list = ( + qpos_list.tolist() if isinstance(qpos_list, np.ndarray) else qpos_list + ) + + if len(interpolate_qpos_list) < 2: + logger.log_error("Need at least 2 waypoints for trajectory planning") + + # Create trajectory dictionary + current_state = self._create_state_dict(interpolate_qpos_list[0]) + target_states = [ + self._create_state_dict(pos) for pos in interpolate_qpos_list[1:] + ] + + # Plan trajectory using internal plan method + success, positions, velocities, accelerations, times, duration = self.plan( + current_state=current_state, + target_states=target_states, + sample_method=sample_method, + sample_interval=sample_num, + **kwargs, + ) + + if not success or positions is None: + logger.log_error("Failed to plan trajectory") + + # Convert positions to list + out_qpos_list = ( + positions.tolist() if isinstance(positions, np.ndarray) else positions + ) + + out_qpos_list = ( + torch.tensor(out_qpos_list) + if not isinstance(out_qpos_list, torch.Tensor) + else out_qpos_list + ) + out_xpos_list = [ + self.robot.compute_fk(qpos=q.unsqueeze(0), name=self.uid, to_matrix=True) + .squeeze(0) + .cpu() + .numpy() + for q in out_qpos_list + ] + + return out_qpos_list, out_xpos_list + + def estimate_trajectory_sample_count( + self, + xpos_list: List[np.ndarray] = None, + qpos_list: List[np.ndarray] = None, + step_size: float = 0.01, + angle_step: float = np.pi / 90, + **kwargs, + ) -> int: + """Estimate the number of trajectory sampling points required. + + This function estimates the total number of sampling points needed to generate + a trajectory based on the given waypoints and sampling parameters. It can be + used to predict computational load and memory requirements before actual + trajectory generation. + + Args: + xpos_list: List of 4x4 transformation matrices representing waypoints + qpos_list: List of joint positions (optional) + is_linear: Whether to use linear interpolation + step_size: Maximum allowed distance between consecutive points (in meters) + angle_step: Maximum allowed angular difference between consecutive points (in radians) + **kwargs: Additional parameters for further customization + + Returns: + int: Estimated number of trajectory sampling points + """ + + def rotation_matrix_to_angle(self, rot_matrix: np.ndarray) -> float: + cos_angle = (np.trace(rot_matrix) - 1) / 2 + cos_angle = np.clip(cos_angle, -1.0, 1.0) + return np.arccos(cos_angle) + + # Input validation + if xpos_list is None and qpos_list is None: + return 0 + + # If joint position list is provided but end effector position list is not, + # convert through forward kinematics + if qpos_list is not None and xpos_list is None: + if len(qpos_list) < 2: + return 1 if len(qpos_list) == 1 else 1 + xpos_list = [ + self.robot.compute_fk( + qpos=torch.tensor(q, dtype=torch.float32), + name=self.uid, + to_matrix=True, + ) + .squeeze(0) + .cpu() + .numpy() + for q in qpos_list + ] + + if xpos_list is None or len(xpos_list) == 0: + return 1 + + if len(xpos_list) == 1: + return 1 + + total_samples = 1 # Starting point + + total_pos_dist = 0.0 + total_angle = 0.0 + + for i in range(len(xpos_list) - 1): + start_pose = xpos_list[i] + end_pose = xpos_list[i + 1] + + pos_diff = end_pose[:3, 3] - start_pose[:3, 3] + total_pos_dist += np.linalg.norm(pos_diff) + + try: + rot_matrix = start_pose[:3, :3].T @ end_pose[:3, :3] + angle = rotation_matrix_to_angle(rot_matrix) + total_angle += angle + except Exception: + pass + + pos_samples = max(1, int(total_pos_dist / step_size)) + rot_samples = max(1, int(total_angle / angle_step)) + + total_samples = max(pos_samples, rot_samples) + + return max(2, total_samples) diff --git a/embodichain/lab/sim/planners/toppra_planner.py b/embodichain/lab/sim/planners/toppra_planner.py new file mode 100644 index 00000000..0a37b783 --- /dev/null +++ b/embodichain/lab/sim/planners/toppra_planner.py @@ -0,0 +1,172 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import numpy as np +from embodichain.utils import logger +from embodichain.lab.sim.planners.utils import TrajectorySampleMethod +from embodichain.lab.sim.planners.base_planner import BasePlanner + +from typing import TYPE_CHECKING, Union, Optional, Tuple + +try: + import toppra as ta + import toppra.constraint as constraint +except ImportError: + logger.log_error( + "toppra not installed. Install with `pip install toppra==0.6.3`", ImportError + ) + +ta.setup_logging(level="WARN") + + +class ToppraPlanner(BasePlanner): + def __init__(self, dofs, max_constraints): + r"""Initialize the TOPPRA trajectory planner. + + Args: + dofs: Number of degrees of freedom + max_constraints: Dictionary containing 'velocity' and 'acceleration' constraints + """ + super().__init__(dofs, max_constraints) + + # Create TOPPRA-specific constraint arrays (symmetric format) + # This format is required by TOPPRA library + self.vlims = np.array([[-v, v] for v in max_constraints["velocity"]]) + self.alims = np.array([[-a, a] for a in max_constraints["acceleration"]]) + + def plan( + self, + current_state: dict, + target_states: list[dict], + **kwargs, + ): + r"""Execute trajectory planning. + + Args: + current_state: Dictionary containing 'position', 'velocity', 'acceleration' for current state + target_states: List of dictionaries containing target states + + Returns: + Tuple of (success, positions, velocities, accelerations, times, duration) + """ + sample_method = kwargs.get("sample_method", TrajectorySampleMethod.TIME) + sample_interval = kwargs.get("sample_interval", 0.01) + if not isinstance(sample_interval, (float, int)): + logger.log_error( + f"sample_interval must be float/int, got {type(sample_interval)}", + TypeError, + ) + if sample_method == TrajectorySampleMethod.TIME and sample_interval <= 0: + logger.log_error("Time interval must be positive", ValueError) + elif sample_method == TrajectorySampleMethod.QUANTITY and sample_interval < 2: + logger.log_error("At least 2 sample points required", ValueError) + + # Check waypoints + if len(current_state["position"]) != self.dofs: + logger.log_info("Current wayponit does not align") + return False, None, None, None, None, None + for target in target_states: + if len(target["position"]) != self.dofs: + logger.log_info("Target Wayponits does not align") + return False, None, None, None, None, None + + if ( + len(target_states) == 1 + and np.sum( + np.abs( + np.array(target_states[0]["position"]) + - np.array(current_state["position"]) + ) + ) + < 1e-3 + ): + logger.log_info("Only two same waypoints, do not plan") + return ( + True, + np.array([current_state["position"], target_states[0]["position"]]), + np.array([[0.0] * self.dofs, [0.0] * self.dofs]), + np.array([[0.0] * self.dofs, [0.0] * self.dofs]), + 0, + 0, + ) + + # Build waypoints + waypoints = [np.array(current_state["position"])] + for target in target_states: + waypoints.append(np.array(target["position"])) + waypoints = np.array(waypoints) + + # Create spline interpolation + # NOTE: Suitable for dense waypoints + ss = np.linspace(0, 1, len(waypoints)) + + # NOTE: Suitable for sparse waypoints; for dense waypoints, CubicSpline may fail strict monotonicity requirement + # len_total = 0 + # len_from_start = [0] + # for i in range(len(waypoints)-1): + # len_total += np.sum(np.abs(waypoints[i+1] - waypoints[i])) + # len_from_start.append(len_total) + # ss = np.array([cur/len_total for cur in len_from_start]) + + path = ta.SplineInterpolator(ss, waypoints) + + # Set constraints + pc_vel = constraint.JointVelocityConstraint(self.vlims) + pc_acc = constraint.JointAccelerationConstraint(self.alims) + + # Create TOPPRA instance + instance = ta.algorithm.TOPPRA( + [pc_vel, pc_acc], + path, + parametrizer="ParametrizeConstAccel", + gridpt_min_nb_points=max(100, 10 * len(waypoints)), + ) + # NOTES:合理设置gridpt_min_nb_points对加速度约束很重要 + + # Compute parameterized trajectory + jnt_traj = instance.compute_trajectory() + if jnt_traj is None: + # raise RuntimeError("Unable to find feasible trajectory") + logger.log_info("Unable to find feasible trajectory") + return False, None, None, None, None, None + + duration = jnt_traj.duration + # Sample trajectory points + if duration <= 0: + logger.log_error(f"Duration must be positive, got {duration}", ValueError) + if sample_method == TrajectorySampleMethod.TIME: + n_points = max(2, int(np.ceil(duration / sample_interval)) + 1) + ts = np.linspace(0, duration, n_points) + else: + ts = np.linspace(0, duration, num=int(sample_interval)) + + positions = [] + velocities = [] + accelerations = [] + + for t in ts: + positions.append(jnt_traj.eval(t)) + velocities.append(jnt_traj.evald(t)) + accelerations.append(jnt_traj.evaldd(t)) + + return ( + True, + np.array(positions), + np.array(velocities), + np.array(accelerations), + ts, + duration, + ) diff --git a/embodichain/lab/sim/planners/utils.py b/embodichain/lab/sim/planners/utils.py new file mode 100644 index 00000000..154e8529 --- /dev/null +++ b/embodichain/lab/sim/planners/utils.py @@ -0,0 +1,54 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from enum import Enum +from typing import Union +from embodichain.utils import logger + + +class TrajectorySampleMethod(Enum): + r"""Enumeration for different trajectory sampling methods. + + This enum defines various methods for sampling trajectories, + providing meaningful names for different sampling strategies. + """ + TIME = "time" + """Sample based on time intervals.""" + + QUANTITY = "quantity" + """Sample based on a specified number of points.""" + + DISTANCE = "distance" + """Sample based on distance intervals.""" + + @classmethod + def from_str( + cls, value: Union[str, "TrajectorySampleMethod"] + ) -> "TrajectorySampleMethod": + if isinstance(value, cls): + return value + try: + return cls[value.upper()] + except KeyError: + valid_values = [e.name for e in cls] + logger.log_error( + f"Invalid version '{value}'. Valid values are: {valid_values}", + ValueError, + ) + + def __str__(self): + """Override string representation for better readability.""" + return self.value.capitalize() diff --git a/embodichain/lab/sim/robots/__init__.py b/embodichain/lab/sim/robots/__init__.py new file mode 100644 index 00000000..de4c08aa --- /dev/null +++ b/embodichain/lab/sim/robots/__init__.py @@ -0,0 +1,18 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from .dexforce_w1 import * +from .cobotmagic import CobotMagicCfg diff --git a/embodichain/lab/sim/robots/cobotmagic.py b/embodichain/lab/sim/robots/cobotmagic.py new file mode 100644 index 00000000..93308f46 --- /dev/null +++ b/embodichain/lab/sim/robots/cobotmagic.py @@ -0,0 +1,227 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import torch +import numpy as np + +from typing import Dict, List, Optional, Any, Union + +from embodichain.lab.sim.cfg import ( + RobotCfg, + URDFCfg, + JointDrivePropertiesCfg, + RigidBodyAttributesCfg, +) +from embodichain.lab.sim.solvers import SolverCfg, OPWSolverCfg +from embodichain.data import get_data_path +from embodichain.utils import configclass +from embodichain.utils import logger + + +@configclass +class CobotMagicCfg(RobotCfg): + urdf_cfg: URDFCfg = None + control_parts: Optional[Dict[str, List[str]]] = None + solver_cfg: Optional[Dict[str, "SolverCfg"]] = None + + @classmethod + def from_dict(cls, init_dict: Dict[str, Union[str, float, int]]) -> CobotMagicCfg: + from embodichain.lab.sim.solvers import merge_solver_cfg + + cfg = cls() + default_cfgs = cls()._build_default_cfgs() + for key, value in default_cfgs.items(): + setattr(cfg, key, value) + + robot_cfg = RobotCfg.from_dict(init_dict) + + # set attrs into cfg from the robot_cfg + for key, value in init_dict.items(): + if key == "solver_cfg": + # merge provided solver_cfg values into default solver config + provided_solver_cfg = init_dict.get("solver_cfg") + if provided_solver_cfg: + for part, item in provided_solver_cfg.items(): + if "class_type" in provided_solver_cfg[part]: + cfg.solver_cfg[part] = robot_cfg.solver_cfg[part] + else: + try: + merged = merge_solver_cfg( + cfg.solver_cfg, provided_solver_cfg + ) + cfg.solver_cfg = merged + except Exception: + logger.log_error( + f"Failed to merge solver_cfg, using provided config outright." + ) + else: + setattr(cfg, key, getattr(robot_cfg, key)) + + return cfg + + @staticmethod + def _build_default_cfgs() -> Dict[str, Any]: + arm_urdf = get_data_path("CobotMagicArm/CobotMagicWithGripperV100.urdf") + left_arm_xpos = np.array( + [ + [1.0, 0.0, 0.0, 0.233], + [0.0, 1.0, 0.0, 0.300], + [0.0, 0.0, 1.0, 0.000], + [0.0, 0.0, 0.0, 1.000], + ] + ) + right_arm_xpos = np.array( + [ + [1.0, 0.0, 0.0, 0.233], + [0.0, 1.0, 0.0, -0.300], + [0.0, 0.0, 1.0, 0.000], + [0.0, 0.0, 0.0, 1.000], + ] + ) + urdf_cfg = URDFCfg( + components=[ + { + "component_type": "left_arm", + "urdf_path": arm_urdf, + "transform": left_arm_xpos, + }, + { + "component_type": "right_arm", + "urdf_path": arm_urdf, + "transform": right_arm_xpos, + }, + ] + ) + return { + "uid": "CobotMagic", + "urdf_cfg": urdf_cfg, + "control_parts": { + "left_arm": [ + "LEFT_JOINT1", + "LEFT_JOINT2", + "LEFT_JOINT3", + "LEFT_JOINT4", + "LEFT_JOINT5", + "LEFT_JOINT6", + ], + "left_eef": ["LEFT_JOINT7", "LEFT_JOINT8"], + "right_arm": [ + "RIGHT_JOINT1", + "RIGHT_JOINT2", + "RIGHT_JOINT3", + "RIGHT_JOINT4", + "RIGHT_JOINT5", + "RIGHT_JOINT6", + ], + "right_eef": ["RIGHT_JOINT7", "RIGHT_JOINT8"], + }, + "solver_cfg": { + "left_arm": OPWSolverCfg( + end_link_name="left_link6", + root_link_name="left_arm_base", + tcp=np.array( + [[-1, 0, 0, 0], [0, -1, 0, 0], [0, 0, 1, 0.143], [0, 0, 0, 1]] + ), + ), + "right_arm": OPWSolverCfg( + end_link_name="right_link6", + root_link_name="right_arm_base", + tcp=np.array( + [[-1, 0, 0, 0], [0, -1, 0, 0], [0, 0, 1, 0.143], [0, 0, 0, 1]] + ), + ), + }, + "min_position_iters": 8, + "min_velocity_iters": 2, + "drive_pros": JointDrivePropertiesCfg( + stiffness={ + "LEFT_JOINT[1-6]": 7e4, + "RIGHT_JOINT[1-6]": 7e4, + "LEFT_JOINT[7-8]": 3e2, + "RIGHT_JOINT[7-8]": 3e2, + }, + damping={ + "LEFT_JOINT[1-6]": 1e3, + "RIGHT_JOINT[1-6]": 1e3, + "LEFT_JOINT[7-8]": 3e1, + "RIGHT_JOINT[7-8]": 3e1, + }, + max_effort={ + "LEFT_JOINT[1-6]": 3e6, + "RIGHT_JOINT[1-6]": 3e6, + "LEFT_JOINT[7-8]": 3e3, + "RIGHT_JOINT[7-8]": 3e3, + }, + ), + "attrs": RigidBodyAttributesCfg( + mass=0.1, + static_friction=0.95, + dynamic_friction=0.9, + linear_damping=0.7, + angular_damping=0.7, + contact_offset=0.005, + rest_offset=0.001, + restitution=0.01, + max_depenetration_velocity=1e1, + ), + } + + def build_pk_serial_chain( + self, device: torch.device = torch.device("cpu"), **kwargs + ) -> Dict[str, "pk.SerialChain"]: + from embodichain.lab.sim.utility.solver_utils import ( + create_pk_chain, + create_pk_serial_chain, + ) + + urdf_path = get_data_path("CobotMagicArm/CobotMagicNoGripper.urdf") + chain = create_pk_chain(urdf_path, device) + + left_arm_chain = create_pk_serial_chain( + chain=chain, end_link_name="link6", root_link_name="base_link" + ).to(device=device) + right_arm_chain = create_pk_serial_chain( + chain=chain, end_link_name="link6", root_link_name="base_link" + ).to(device=device) + return {"left_arm": left_arm_chain, "right_arm": right_arm_chain} + + +if __name__ == "__main__": + from embodichain.lab.sim import SimulationManager, SimulationManagerCfg + from embodichain.lab.sim.robots import CobotMagicCfg + + torch.set_printoptions(precision=5, sci_mode=False) + + config = SimulationManagerCfg(headless=False, sim_device="cuda") + sim = SimulationManager(config) + sim.build_multiple_arenas(2) + sim.set_manual_update(True) + + config = { + "init_pos": [0.0, 0.0, 1.0], + } + + cfg = CobotMagicCfg.from_dict(config) + robot = sim.add_robot(cfg=cfg) + + sim.init_gpu_physics() + print("CobotMagic added to the simulation.") + + from IPython import embed + + embed() diff --git a/embodichain/lab/sim/robots/dexforce_w1/__init__.py b/embodichain/lab/sim/robots/dexforce_w1/__init__.py new file mode 100644 index 00000000..5aee21db --- /dev/null +++ b/embodichain/lab/sim/robots/dexforce_w1/__init__.py @@ -0,0 +1,17 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from .cfg import DexforceW1Cfg diff --git a/embodichain/lab/sim/robots/dexforce_w1/cfg.py b/embodichain/lab/sim/robots/dexforce_w1/cfg.py new file mode 100644 index 00000000..c049a071 --- /dev/null +++ b/embodichain/lab/sim/robots/dexforce_w1/cfg.py @@ -0,0 +1,341 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import enum +import json +import numpy as np +import typing +import torch + +from typing import Dict + +from embodichain.lab.sim.robots.dexforce_w1.types import ( + DexforceW1HandBrand, + DexforceW1ArmSide, + DexforceW1ArmKind, + DexforceW1Version, +) +from embodichain.lab.sim.robots.dexforce_w1.utils import ( + build_dexforce_w1_cfg, +) +from embodichain.lab.sim.solvers import SolverCfg +from embodichain.lab.sim.cfg import ( + RobotCfg, + JointDrivePropertiesCfg, + RigidBodyAttributesCfg, +) +from embodichain.data import get_data_path +from embodichain.utils import configclass, logger + + +@configclass +class DexforceW1Cfg(RobotCfg): + """DexforceW1 specific configuration, inherits from RobotCfg and allows custom parameters.""" + + version: DexforceW1Version = DexforceW1Version.V021 + arm_kind: DexforceW1ArmKind = DexforceW1ArmKind.INDUSTRIAL + + @classmethod + def from_dict( + cls, init_dict: typing.Dict[str, typing.Union[str, float, tuple]] + ) -> DexforceW1Cfg: + """Initialize DexforceW1Cfg from a dictionary. + + Args: + init_dict (Dict[str, Union[str, float, tuple]]): Dictionary of configuration parameters. + + Returns: + DexforceW1Cfg: An instance of DexforceW1Cfg with parameters set. + """ + from embodichain.lab.sim.solvers import merge_solver_cfg + + init_dict_m = init_dict.copy() + version = init_dict_m.get("version", "v021") + arm_kind = init_dict_m.get("arm_kind", "anthropomorphic") + init_dict_m.pop("version", None) + init_dict_m.pop("arm_kind", None) + cfg: DexforceW1Cfg = cls()._build_default_cfg( + version=version, arm_kind=arm_kind + ) + + default_physics_cfgs = cls()._build_default_physics_cfgs() + for key, value in default_physics_cfgs.items(): + setattr(cfg, key, value) + + default_solver_cfg = cls()._build_default_solver_cfg( + is_industrial=(arm_kind == "industrial") + ) + cfg.solver_cfg = default_solver_cfg + + # override default values with those provided in init_dict. + robot_cfg = RobotCfg.from_dict(init_dict_m) + + # set attrs into cfg from the robot_cfg, but merge solver_cfg specially + for key, value in init_dict_m.items(): + if key == "solver_cfg": + # merge provided solver_cfg values into default solver config + provided_solver_cfg = init_dict_m.get("solver_cfg") + if provided_solver_cfg: + for part, item in provided_solver_cfg.items(): + if "class_type" in provided_solver_cfg[part]: + cfg.solver_cfg[part] = robot_cfg.solver_cfg[part] + else: + try: + merged = merge_solver_cfg( + cfg.solver_cfg, provided_solver_cfg + ) + cfg.solver_cfg = merged + except Exception: + logger.log_error( + f"Failed to merge solver_cfg, using provided config outright." + ) + else: + setattr(cfg, key, getattr(robot_cfg, key)) + + return cfg + + @staticmethod + def _build_default_solver_cfg(is_industrial: bool) -> SolverCfg: + # TODO: maybe change default solver for DexforceW1 + from embodichain.lab.sim.solvers import PytorchSolverCfg + from embodichain.lab.sim.solvers import SRSSolverCfg + from embodichain.lab.sim.robots.dexforce_w1.params import ( + W1ArmKineParams, + ) + + if is_industrial: + w1_left_arm_params = W1ArmKineParams( + arm_side=DexforceW1ArmSide.LEFT, + arm_kind=DexforceW1ArmKind.INDUSTRIAL, + version=DexforceW1Version.V021, + ) + w1_right_arm_params = W1ArmKineParams( + arm_side=DexforceW1ArmSide.RIGHT, + arm_kind=DexforceW1ArmKind.INDUSTRIAL, + version=DexforceW1Version.V021, + ) + else: + w1_left_arm_params = W1ArmKineParams( + arm_side=DexforceW1ArmSide.LEFT, + arm_kind=DexforceW1ArmKind.ANTHROPOMORPHIC, + version=DexforceW1Version.V021, + ) + w1_right_arm_params = W1ArmKineParams( + arm_side=DexforceW1ArmSide.RIGHT, + arm_kind=DexforceW1ArmKind.ANTHROPOMORPHIC, + version=DexforceW1Version.V021, + ) + + return { + "right_arm": SRSSolverCfg( + end_link_name="right_ee", + root_link_name="right_arm_base", + dh_params=w1_right_arm_params.dh_params, + qpos_limits=w1_right_arm_params.qpos_limits, + T_e_oe=w1_right_arm_params.T_e_oe, + T_b_ob=w1_right_arm_params.T_b_ob, + link_lengths=w1_right_arm_params.link_lengths, + rotation_directions=w1_right_arm_params.rotation_directions, + tcp=np.array( + [ + [1.0, 0.0, 0.0, 0.012], + [0.0, 0.0, -1.0, -0.0675], + [0.0, 1.0, 0.0, 0.127], + [0.0, 0.0, 0.0, 1.0], + ] + ), + ), + "left_arm": SRSSolverCfg( + end_link_name="left_ee", + root_link_name="left_arm_base", + dh_params=w1_left_arm_params.dh_params, + qpos_limits=w1_left_arm_params.qpos_limits, + T_e_oe=w1_left_arm_params.T_e_oe, + T_b_ob=w1_left_arm_params.T_b_ob, + link_lengths=w1_left_arm_params.link_lengths, + rotation_directions=w1_left_arm_params.rotation_directions, + tcp=np.array( + [ + [-1.0, 0.0, 0.0, 0.012], + [0.0, 0.0, 1.0, 0.0675], + [0.0, 1.0, 0.0, 0.127], + [0.0, 0.0, 0.0, 1.0], + ] + ), + ), + } + + @staticmethod + def _build_default_physics_cfgs() -> typing.Dict[str, any]: + return { + "min_position_iters": 32, + "min_velocity_iters": 8, + "drive_pros": JointDrivePropertiesCfg( + stiffness={ + "(RIGHT|LEFT)_J[0-9]": 1e4, + "(RIGHT|LEFT)_[A-Z|_]+": 1e2, + "(ANKLE|KNEE|BUTTOCK|WAIST)": 1e7, + }, + damping={ + "(RIGHT|LEFT)_J[0-2]": 1e3, + "(RIGHT|LEFT)_[A-Z|_]+": 1e1, + "(ANKLE|KNEE|BUTTOCK|WAIST)": 1e4, + }, + max_effort={ + "(RIGHT|LEFT)_J[0-9]": 1e5, + "(RIGHT|LEFT)_[A-Z|_]+": 1e3, + "(ANKLE|KNEE|BUTTOCK|WAIST)": 1e10, + }, + ), + # TODO: we may use the some properties from URDF as default values + # eg. mass, friction, damping, etc. + "attrs": RigidBodyAttributesCfg( + mass=1.0, + static_friction=0.95, + dynamic_friction=0.9, + linear_damping=0.7, + angular_damping=0.7, + contact_offset=0.005, + rest_offset=0.001, + restitution=0.05, + max_depenetration_velocity=10.0, + ), + } + + @staticmethod + def _build_default_cfg( + version: str = "v021", arm_kind: str = "anthropomorphic" + ) -> DexforceW1Cfg: + hand_types = { + DexforceW1ArmSide.LEFT: DexforceW1HandBrand.BRAINCO_HAND, + DexforceW1ArmSide.RIGHT: DexforceW1HandBrand.BRAINCO_HAND, + } + hand_versions = { + DexforceW1ArmSide.LEFT: DexforceW1Version(version), + DexforceW1ArmSide.RIGHT: DexforceW1Version(version), + } + + cfg = build_dexforce_w1_cfg( + arm_kind=DexforceW1ArmKind(arm_kind), + hand_types=hand_types, + hand_versions=hand_versions, + ) + cfg.version = DexforceW1Version(version) + cfg.arm_kind = DexforceW1ArmKind(arm_kind) + + return cfg + + def to_dict(self): + """Convert config to a Python dict, handling enums and numpy arrays.""" + + def serialize(obj, _visited=None): + if _visited is None: + _visited = set() + # Only skip recursion for mutable objects (dict, custom class) + if isinstance(obj, (dict, object)) and not isinstance( + obj, (str, int, float, bool, type(None)) + ): + obj_id = id(obj) + if obj_id in _visited: + return None # Prevent infinite recursion + _visited.add(obj_id) + + if isinstance(obj, enum.Enum): + return obj.value + if isinstance(obj, np.ndarray): + return obj.tolist() + if isinstance(obj, dict): + # Only serialize values, keep keys as str/int/float/bool/None + return {str(k): serialize(v, _visited) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [serialize(v, _visited) for v in obj] + if hasattr(obj, "to_dict") and obj is not self: + return serialize(obj.to_dict(), _visited) + if hasattr(obj, "__dict__"): + return {k: serialize(v, _visited) for k, v in obj.__dict__.items()} + return obj + + return serialize(self) + + def to_string(self): + """Return config as a JSON string.""" + return json.dumps(self.to_dict(), indent=2) + + def save_to_file(self, filepath): + """Save config to a local file as JSON.""" + with open(filepath, "w") as f: + f.write(self.to_string()) + + def build_pk_serial_chain( + self, device: torch.device = torch.device("cpu"), **kwargs + ) -> Dict[str, "pk.SerialChain"]: + from embodichain.lab.sim.utility.solver_utils import ( + create_pk_chain, + create_pk_serial_chain, + ) + + if DexforceW1ArmKind.INDUSTRIAL == self.arm_kind: + urdf_path = get_data_path("DexforceW1V021/DexforceW1_v02_2.urdf") + elif DexforceW1ArmKind.ANTHROPOMORPHIC == self.arm_kind: + urdf_path = get_data_path("DexforceW1V021/DexforceW1_v02_1.urdf") + + chain = create_pk_chain(urdf_path, device) + + left_arm_chain = create_pk_serial_chain( + chain=chain, end_link_name="left_ee", root_link_name="left_arm_base" + ).to(device=device) + right_arm_chain = create_pk_serial_chain( + chain=chain, end_link_name="right_ee", root_link_name="right_arm_base" + ).to(device=device) + + return { + "left_arm": left_arm_chain, + "right_arm": right_arm_chain, + } + + +if __name__ == "__main__": + # Example usage + import numpy as np + + np.set_printoptions(precision=5, suppress=True) + from embodichain.lab.sim import SimulationManager, SimulationManagerCfg + from embodichain.lab.sim.robots.dexforce_w1.types import ( + DexforceW1ArmKind, + ) + + config = SimulationManagerCfg(headless=True, sim_device="cpu") + sim = SimulationManager(config) + sim.build_multiple_arenas(1) + sim.set_manual_update(True) + + cfg = DexforceW1Cfg.from_dict( + { + "uid": "dexforce_w1", + "version": "v021", + "arm_kind": "anthropomorphic", + } + ) + + robot = sim.add_robot(cfg=cfg) + sim.update(step=1) + print("DexforceW1 robot added to the simulation.") + + from IPython import embed + + embed() diff --git a/embodichain/lab/sim/robots/dexforce_w1/params.py b/embodichain/lab/sim/robots/dexforce_w1/params.py new file mode 100644 index 00000000..6dfdf349 --- /dev/null +++ b/embodichain/lab/sim/robots/dexforce_w1/params.py @@ -0,0 +1,269 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch +import numpy as np +from typing import Optional +from dataclasses import dataclass, field +from embodichain.lab.sim.robots.dexforce_w1.types import ( + DexforceW1HandBrand, + DexforceW1ArmSide, + DexforceW1ArmKind, + DexforceW1Version, +) + + +@dataclass +class W1ArmKineParams: + """Kinematics parameters for W1 arm variants. + + - arm_kind and W1Version enum types expected to be defined elsewhere. + - dh_params stored as numpy array of shape (7,4). + - qpos_limits stored as numpy array of shape (7,2) in radians. + """ + + arm_side: "DexforceW1ArmSide" + arm_kind: "DexforceW1ArmKind" + version: "DexforceW1Version" = field(default_factory=lambda: DexforceW1Version.V021) + + # (initialized in __post_init__) + # physical constants + d_list: list[float] = field(init=False, default_factory=list) + link_lengths: list[float] = field(init=False, default_factory=list) + rotation_directions: list[float] = field(init=False, default_factory=list) + + # transforms + T_b_ob: np.ndarray = field(init=False) + T_e_oe: np.ndarray = field(init=False) + + # kinematic parameters + dh_params: np.ndarray = field(init=False) + qpos_limits: np.ndarray = field(init=False) + + def __post_init__(self): + if self.version == DexforceW1Version.V021: + self.d_list = np.array([0.0, 0.0, 0.260, 0.0, 0.166, 0.098, 0.0]) + self.link_lengths = np.array( + [ + self.d_list[0] + self.d_list[1], + self.d_list[2] + self.d_list[3], + self.d_list[4] + self.d_list[5], + self.d_list[6], + ] + ) + else: + raise ValueError(f"W1Version {self.version} are not supported.") + + # helpers: create DH rows and clamp limits + def dh_row(d, alpha, a, theta): + return [d, alpha, a, theta] + + def deg2rad_list(list_of_pairs): + return np.deg2rad(np.array(list_of_pairs, dtype=float)) + + T_b_ob = np.array( + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.1025], + [0.0, 0.0, 0.0, 1.0], + ] + ) + # Build parameters per arm_kind and side, minimizing duplication + if self.arm_kind == DexforceW1ArmKind.INDUSTRIAL: + # default tcp for industrial + T_e_oe = np.array( + [ + [-1.0, 0.0, 0.0, 0.0], + [0.0, -1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.066], + [0.0, 0.0, 0.0, 1.0], + ] + ) + + # fmt: off + dh = [ + dh_row(self.link_lengths[0], -np.pi / 2, 0.0, 0.0), + dh_row(0.0, np.pi / 2, 0.0, 0.0), + dh_row(self.link_lengths[1], np.pi / 2, 0.0, np.pi / 2), + dh_row(0.0, -np.pi / 2, 0.0, 0.0), + dh_row(self.link_lengths[2], -np.pi / 2, 0.0, 0.0), + dh_row(0.0, np.pi / 2, 0.0, 0.0), + dh_row(self.link_lengths[3], 0.0, 0.0, 0.0), + ] + + # fmt: on + if self.arm_side == DexforceW1ArmSide.LEFT: + limits = [ + [-170.0, 170.0], + [-120.0, 90.0], + [-170.0, 170.0], + [-135.0, 90.0], + [-170.0, 170.0], + [-90.0, 90.0], + [-170.0, 170.0], + ] + rotation_directions = np.array([1, 1, 1, 1, 1, -1, 1]) + else: + limits = [ + [-170.0, 170.0], + [-90.0, 120.0], + [-170.0, 170.0], + [-90.0, 135.0], + [-170.0, 170.0], + [-90.0, 90.0], + [-170.0, 170.0], + ] + rotation_directions = np.array([1, 1, 1, -1, 1, 1, 1]) + + self.T_e_oe = T_e_oe + + elif self.arm_kind == DexforceW1ArmKind.ANTHROPOMORPHIC: + T_e_oe = np.array( + [ + [0.0, 0.0, -1.0, -0.066], + [0.0, 1.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + ] + ) + # fmt: off + dh = [ + dh_row(self.link_lengths[0], -np.pi / 2, 0.0, 0.0), + dh_row(0.0, np.pi / 2, 0.0, 0.0), + dh_row(self.link_lengths[1], np.pi / 2, 0.0, np.pi / 2), + dh_row(0.0, -np.pi / 2, 0.0, 0.0), + dh_row(self.link_lengths[2], -np.pi / 2, 0.0, 0.0), + dh_row(0.0, np.pi / 2, 0.0, np.pi / 2), + dh_row(self.link_lengths[3], 0.0, 0.0, 0.0), + ] + # fmt: on + + if self.arm_side == DexforceW1ArmSide.LEFT: + limits = [ + [-170.0, 170.0], + [-120.0, 90.0], + [-170.0, 170.0], + [-135.0, 90.0], + [-170.0, 170.0], + [-45.0, 45.0], + [-90.0, 60.0], + ] + rotation_directions = np.array([1, 1, 1, 1, 1, -1, 1]) + else: + limits = [ + [-170.0, 170.0], + [-90.0, 120.0], + [-170.0, 170.0], + [-90.0, 135.0], + [-170.0, 170.0], + [-45.0, 45.0], + [-60.0, 90.0], + ] + rotation_directions = np.array([1, 1, 1, -1, 1, 1, 1]) + else: + raise ValueError(f"Unsupported arm_kind: {self.arm_kind}") + + self.T_b_ob = T_b_ob + self.T_e_oe = T_e_oe + + # finalize arrays + self.dh_params = np.array(dh, dtype=float) + self.qpos_limits = deg2rad_list(limits) + self.rotation_directions = rotation_directions + + # sanity checks + assert self.dh_params.shape == (7, 4), "dh_params must be shape (7,4)" + assert self.qpos_limits.shape == (7, 2), "qpos_limits must be shape (7,2)" + + def as_dict(self) -> dict: + return { + "arm_side": self.arm_side.name, + "arm_kind": self.arm_kind.name, + "version": self.version.name, + "link_lengths": self.link_lengths.tolist(), + "T_b_ob": self.T_b_ob.tolist(), + "T_e_oe": self.T_e_oe.tolist(), + "dh_params": self.dh_params.tolist(), + "qpos_limits": self.qpos_limits.tolist(), + "rotation_directions": self.rotation_directions.tolist(), + } + + @classmethod + def from_dict(cls, data: dict) -> "W1ArmKineParams": + arm_side = ( + DexforceW1ArmSide[data["arm_side"]] + if isinstance(data.get("arm_side"), str) + else data.get("arm_side") + ) + + arm_kind = ( + DexforceW1ArmKind[data["arm_kind"]] + if isinstance(data.get("arm_kind"), str) + else data.get("arm_kind") + ) + version = ( + DexforceW1Version[data["version"]] + if isinstance(data.get("version"), str) + else data.get("version", DexforceW1Version.V021) + ) + inst = cls(arm_side=arm_side, arm_kind=arm_kind, version=version) + + # allow overriding computed arrays if provided + if "dh_params" in data: + object.__setattr__( + inst, "dh_params", np.array(data["dh_params"], dtype=float) + ) + if "qpos_limits" in data: + object.__setattr__( + inst, + "qpos_limits", + np.deg2rad(np.array(data["qpos_limits"], dtype=float)), + ) if np.max( + np.abs(np.array(data["qpos_limits"])) + ) > 2 * np.pi else object.__setattr__( + inst, "qpos_limits", np.array(data["qpos_limits"], dtype=float) + ) + if "link_lengths" in data: + object.__setattr__( + inst, "link_lengths", np.array(data["link_lengths"], dtype=float) + ) + if "T_b_ob" in data: + object.__setattr__(inst, "T_b_ob", np.array(data["T_b_ob"], dtype=float)) + if "T_e_oe" in data: + object.__setattr__(inst, "T_e_oe", np.array(data["T_e_oe"], dtype=float)) + if "rotation_directions" in data: + object.__setattr__( + inst, + "rotation_directions", + np.array(data["rotation_directions"], dtype=float), + ) + inst.validate() + return inst + + def to_torch( + self, device: Optional[torch.device] = None, dtype: torch.dtype = torch.float32 + ) -> dict: + dev = torch.device("cpu") if device is None else device + return { + "dh_params": torch.tensor(self.dh_params, dtype=dtype, device=dev), + "qpos_limits": torch.tensor(self.qpos_limits, dtype=dtype, device=dev), + "T_b_ob": torch.tensor(self.T_b_ob, dtype=dtype, device=dev), + "T_e_oe": torch.tensor(self.T_e_oe, dtype=dtype, device=dev), + "rotation_directions": torch.tensor( + self.rotation_directions, dtype=dtype, device=dev + ), + } diff --git a/embodichain/lab/sim/robots/dexforce_w1/types.py b/embodichain/lab/sim/robots/dexforce_w1/types.py new file mode 100644 index 00000000..939dde31 --- /dev/null +++ b/embodichain/lab/sim/robots/dexforce_w1/types.py @@ -0,0 +1,66 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import enum + +all = [ + "DexforceW1Version", + "DexforceW1ArmKind", + "DexforceW1ArmSide", + "DexforceW1Type", +] + + +class DexforceW1Version(enum.Enum): + """Versioning for DexforceW1 components.""" + + V021 = "v021" + + +class DexforceW1ArmKind(enum.Enum): + """Arm type for DexforceW1: anthropomorphic or industrial.""" + + ANTHROPOMORPHIC = "anthropomorphic" + INDUSTRIAL = "industrial" + + +class DexforceW1ArmSide(enum.Enum): + """Arm side for DexforceW1: left or right.""" + + LEFT = "left" + RIGHT = "right" + + +class DexforceW1Type(enum.Enum): + """Component type for DexforceW1.""" + + CHASSIS = "chassis" + TORSO = "torso" + EYES = "eyes" + HEAD = "head" + LEFT_ARM1 = "left_arm" # Anthropomorphic left arm + RIGHT_ARM1 = "right_arm" # Anthropomorphic right arm + LEFT_ARM2 = "left_arm2" # Industrial left arm + RIGHT_ARM2 = "right_arm2" # Industrial right arm + LEFT_HAND = "left_hand" + RIGHT_HAND = "right_hand" + FULL_BODY = "full_body" # Full robot + + +class DexforceW1HandBrand(enum.Enum): + BRAINCO_HAND = "BRAINCO_HAND" + DH_PGC_GRIPPER = "DH_PGC_GRIPPER" + DH_PGC_GRIPPER_M = "DH_PGC_GRIPPER_M" diff --git a/embodichain/lab/sim/robots/dexforce_w1/utils.py b/embodichain/lab/sim/robots/dexforce_w1/utils.py new file mode 100644 index 00000000..1bfbeb6d --- /dev/null +++ b/embodichain/lab/sim/robots/dexforce_w1/utils.py @@ -0,0 +1,746 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- +import numpy as np +from scipy.spatial.transform import Rotation as R +from typing import List, Dict, Optional + +from embodichain.lab.sim.robots.dexforce_w1.types import ( + DexforceW1ArmKind, + DexforceW1Type, + DexforceW1ArmSide, + DexforceW1Version, + DexforceW1HandBrand, +) +from embodichain.data import get_data_path +from embodichain.lab.sim.solvers import SolverCfg +from embodichain.lab.sim.cfg import RobotCfg, URDFCfg + + +all = [ + "ChassisManager", + "TorsoManager", + "HeadManager", + "ArmManager", + "HandManager", + "EyesManager", + "build_dexforce_w1_assembly_urdf_cfg", + "build_dexforce_w1_cfg", +] + + +class ChassisManager: + def __init__(self): + self.urdf_paths = { + DexforceW1Version.V021: get_data_path("DexforceW1ChassisV021/chassis.urdf"), + } + + def get_urdf(self, version=DexforceW1Version.V021): + return self.urdf_paths[version] + + def get_config(self, version=DexforceW1Version.V021): + return { + "urdf_path": self.get_urdf(version), + "joint_names": [], + "end_link_name": "base_link", + "root_link_name": "base_link", + } + + +class TorsoManager: + def __init__(self): + self.urdf_paths = { + DexforceW1Version.V021: get_data_path("DexforceW1TorsoV021/torso.urdf"), + } + self.joint_names = ["ANKLE", "KNEE", "BUTTOCK", "WAIST"] + + def get_urdf(self, version=DexforceW1Version.V021): + return self.urdf_paths[version] + + def get_config(self, version=DexforceW1Version.V021): + return { + "urdf_path": self.get_urdf(version), + "joint_names": self.joint_names, + "end_link_name": "waist", + "root_link_name": "base_link", + } + + +class HeadManager: + def __init__(self): + self.urdf_paths = { + DexforceW1Version.V021: get_data_path("DexforceW1HeadV021/head.urdf"), + } + self.joint_names = ["NECK1", "NECK2"] + + def get_urdf(self, version=DexforceW1Version.V021): + return self.urdf_paths[version] + + def get_config(self, version=DexforceW1Version.V021): + return { + "urdf_path": self.get_urdf(version), + "joint_names": self.joint_names, + "end_link_name": "neck2", + "root_link_name": "neck1", + } + + +class EyesManager: + def __init__(self): + self.urdf_paths = { + DexforceW1Version.V021: get_data_path("DexforceW1EyesV021/eyes.urdf"), + } + + def get_urdf(self, version=DexforceW1Version.V021): + return self.urdf_paths[version] + + def get_config(self, version=DexforceW1Version.V021): + return { + "urdf_path": self.get_urdf(version), + "joint_names": [], + "end_link_name": "eyes", + "root_link_name": "base_link", + } + + +class ArmManager: + def __init__(self): + self.urdf_paths = { + ( + DexforceW1ArmKind.ANTHROPOMORPHIC, + DexforceW1ArmSide.LEFT, + DexforceW1Version.V021, + ): get_data_path("DexforceW1LeftArm1V021/left_arm.urdf"), + ( + DexforceW1ArmKind.ANTHROPOMORPHIC, + DexforceW1ArmSide.RIGHT, + DexforceW1Version.V021, + ): get_data_path("DexforceW1RightArm1V021/right_arm.urdf"), + ( + DexforceW1ArmKind.INDUSTRIAL, + DexforceW1ArmSide.LEFT, + DexforceW1Version.V021, + ): get_data_path("DexforceW1LeftArm2V021/left_arm.urdf"), + ( + DexforceW1ArmKind.INDUSTRIAL, + DexforceW1ArmSide.RIGHT, + DexforceW1Version.V021, + ): get_data_path("DexforceW1RightArm2V021/right_arm.urdf"), + } + + def get_urdf(self, kind, side, version=DexforceW1Version.V021): + return self.urdf_paths[(kind, side, version)] + + def get_config(self, kind, side, version=DexforceW1Version.V021): + prefix = "LEFT" if side == DexforceW1ArmSide.LEFT else "RIGHT" + return { + "urdf_path": self.get_urdf(kind, side, version), + "joint_names": [f"{prefix}_J{i}" for i in range(1, 8)], + "end_link_name": f"{prefix.lower()}_ee", + "root_link_name": f"{prefix.lower()}_arm_base", + } + + +class HandManager: + def __init__(self): + self.urdf_paths = { + ( + DexforceW1HandBrand.BRAINCO_HAND, + DexforceW1ArmSide.LEFT, + DexforceW1Version.V021, + ): get_data_path("BrainCoHandRevo1/BrainCoLeftHand/BrainCoLeftHand.urdf"), + ( + DexforceW1HandBrand.BRAINCO_HAND, + DexforceW1ArmSide.RIGHT, + DexforceW1Version.V021, + ): get_data_path("BrainCoHandRevo1/BrainCoRightHand/BrainCoRightHand.urdf"), + ( + DexforceW1HandBrand.DH_PGC_GRIPPER, + DexforceW1ArmSide.LEFT, + DexforceW1Version.V021, + ): get_data_path("DH_PGC_140_50/DH_PGC_140_50.urdf"), + ( + DexforceW1HandBrand.DH_PGC_GRIPPER, + DexforceW1ArmSide.RIGHT, + DexforceW1Version.V021, + ): get_data_path("DH_PGC_140_50/DH_PGC_140_50.urdf"), + ( + DexforceW1HandBrand.DH_PGC_GRIPPER_M, + DexforceW1ArmSide.LEFT, + DexforceW1Version.V021, + ): get_data_path("DH_PGC_140_50_M/DH_PGC_140_50_M.urdf"), + ( + DexforceW1HandBrand.DH_PGC_GRIPPER_M, + DexforceW1ArmSide.RIGHT, + DexforceW1Version.V021, + ): get_data_path("DH_PGC_140_50_M/DH_PGC_140_50_M.urdf"), + } + + def get_config( + self, + brand: DexforceW1HandBrand, + side: DexforceW1ArmSide, + version: DexforceW1Version = DexforceW1Version.V021, + ): + prefix = "LEFT" if side == DexforceW1ArmSide.LEFT else "RIGHT" + if brand == DexforceW1HandBrand.BRAINCO_HAND: + if side == DexforceW1ArmSide.LEFT: + base_link_name = f"{prefix.lower()}_hand_base" + root_link_name = f"{prefix.lower()}_thumb_dist" + joint_names = [ + f"{prefix}_HAND_THUMB1", # Left thumb flexion + f"{prefix}_HAND_THUMB2", # Left thumb abduction/adduction + f"{prefix}_HAND_INDEX", # Left index finger flexion + f"{prefix}_HAND_MIDDLE", # Left middle finger flexion + f"{prefix}_HAND_RING", # Left ring finger flexion + f"{prefix}_HAND_PINKY", # Left pinky finger flexion + ] + else: + base_link_name = f"{prefix.lower()}_hand_base" + root_link_name = f"{prefix.lower()}_thumb_dist" + joint_names = [ + f"{prefix}_HAND_THUMB1", # Right thumb flexion + f"{prefix}_HAND_THUMB2", # Right thumb abduction/adduction + f"{prefix}_HAND_INDEX", # Right index finger flexion + f"{prefix}_HAND_MIDDLE", # Right middle finger flexion + f"{prefix}_HAND_RING", # Right ring finger flexion + f"{prefix}_HAND_PINKY", # Right pinky finger flexion + ] + elif brand == DexforceW1HandBrand.DH_PGC_GRIPPER: + base_link_name = f"{prefix.lower()}_base_link_1" + root_link_name = (f"{prefix.lower()}_finger2_link",) + joint_names = [f"{prefix}_FINGER1_JOINT", f"{prefix}_FINGER2_JOINT"] + elif brand == DexforceW1HandBrand.DH_PGC_GRIPPER_M: + base_link_name = f"{prefix.lower()}_base_link_1" + root_link_name = (f"{prefix.lower()}_finger2",) + joint_names = [f"{prefix}_FINGER1", f"{prefix}_FINGER2"] + else: + raise ValueError(f"Unknown hand brand: {brand}") + + return { + "urdf_path": self.get_urdf(brand, side, version), + "joint_names": joint_names, + "end_link_name": base_link_name, + "root_link_name": root_link_name, + } + + def get_urdf( + self, + brand: DexforceW1HandBrand, + side: DexforceW1ArmSide, + version: DexforceW1Version = DexforceW1Version.V021, + ): + return self.urdf_paths[(brand, side, version)] + + def get_attach_xpos( + self, + brand: DexforceW1HandBrand, + arm_kind: DexforceW1ArmKind = DexforceW1ArmKind.INDUSTRIAL, + is_left: bool = True, + ): + if brand == DexforceW1HandBrand.BRAINCO_HAND: + rot_params = { + (DexforceW1ArmKind.INDUSTRIAL, True): [90, 0, 0], + (DexforceW1ArmKind.INDUSTRIAL, False): [90, 0, 180], + (DexforceW1ArmKind.ANTHROPOMORPHIC, True): [90, 0, 180], + (DexforceW1ArmKind.ANTHROPOMORPHIC, False): [90, 0, 0], + } + attach_xpos = np.eye(4) + rot = R.from_euler("xyz", rot_params[(arm_kind, is_left)], degrees=True) + attach_xpos[:3, :3] = rot.as_matrix() + attach_xpos[2, 3] = 0.0 + return attach_xpos + elif brand == DexforceW1HandBrand.DH_PGC_GRIPPER: + attach_xpos = np.array( + [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0.015], [0, 0, 0, 1]] + ) + attach_xpos[:3, :3] = ( + attach_xpos[:3, :3] + @ R.from_rotvec([0, 0, 90], degrees=True).as_matrix() + ) + return attach_xpos + elif brand == DexforceW1HandBrand.DH_PGC_GRIPPER_M: + attach_xpos = np.eye(4) + attach_xpos[:3, :3] = ( + attach_xpos[:3, :3] + @ R.from_rotvec([0, 0, 90], degrees=True).as_matrix() + ) + return attach_xpos + else: + raise ValueError(f"Unknown brand: {brand}") + + +eyes_manager = EyesManager() +chassis_manager = ChassisManager() +torso_manager = TorsoManager() +head_manager = HeadManager() +arm_manager = ArmManager() +hand_manager = HandManager() + + +def build_dexforce_w1_assembly_urdf_cfg( + arm_kind: DexforceW1ArmKind, + arm_sides: List[DexforceW1ArmSide] = [ + DexforceW1ArmSide.LEFT, + DexforceW1ArmSide.RIGHT, + ], + fname: Optional[str] = "DexforceW1V021", + hand_types: Optional[Dict[DexforceW1ArmSide, DexforceW1HandBrand]] = None, + hand_versions: Optional[Dict[DexforceW1ArmSide, DexforceW1Version]] = None, + hand_attach_xposes: Optional[Dict[DexforceW1ArmSide, np.ndarray]] = None, + include_chassis: bool = True, + include_torso: bool = True, + include_head: bool = True, + include_hand: bool = True, + include_eyes: bool = True, + include_wrist_cameras: bool = True, + component_versions: Optional[Dict[DexforceW1Type, DexforceW1Version]] = None, +) -> URDFCfg: + """ + Assemble DexforceW1 robot urdf configuration. + + Args: + arm_kind: Arm type (anthropomorphic or industrial). + arm_sides: List of arm sides to include (left/right). Default both sides. + fname: Output configuration name. Default "DexforceW1V021". + hand_types: Dict specifying hand brand (DexforceW1HandBrand) for each arm side. Default None, which uses the default brand. + hand_versions: Dict specifying hand version for each arm side. Default None, which uses the default version. + hand_attach_xposes: Dict specifying hand attachment pose for each arm side. Default None, which uses the default attachment pose. + include_chassis: Whether to include chassis. Default True. + include_torso: Whether to include torso. Default True. + include_head: Whether to include head. Default True. + include_hand: Whether to include hand. Default True. + include_wrist_cameras: Whether to include wrist cameras. Default True. + component_versions: Dict specifying version for each robot component. Default all V021. + + Returns: + URDFCfg: Assembled URDF configuration. + """ + + def get_version(t, default=DexforceW1Version.V021): + return (component_versions or {}).get(t, default) + + components = [] + if include_chassis: + components.append( + { + "component_type": "chassis", + "urdf_path": chassis_manager.get_urdf( + get_version(DexforceW1Type.CHASSIS) + ), + } + ) + if include_torso: + components.append( + { + "component_type": "torso", + "urdf_path": torso_manager.get_urdf(get_version(DexforceW1Type.TORSO)), + } + ) + if include_head: + components.append( + { + "component_type": "head", + "urdf_path": head_manager.get_urdf(get_version(DexforceW1Type.HEAD)), + } + ) + + sensors = [] + + if include_eyes: + # TODO: Support user-defined eye transforms + import xml.etree.ElementTree as ET + + attach_xpos = np.array( + [ + [-0.0, 0.25959, -0.96572, 0.091], + [0.0, -0.96572, -0.25959, -0.051], + [-1.0, -0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + ] + ) + + joint_xml = """ + + + + + + """ + + link_xml = """ + + + + + + + + """ + + joint_elem = ET.fromstring(joint_xml) + link_elem = ET.fromstring(link_xml) + + sensors.append( + { + "sensor_name": "eyes", + "sensor_source": ([link_elem], [joint_elem]), # eyes_manager.get_urdf() + "parent_component": "head", + "parent_link": "neck2", + "transform": attach_xpos, + "sensor_type": "camera", + } + ) + if include_wrist_cameras: + for arm_side in arm_sides: + # TODO: Support user-defined eye transforms + import xml.etree.ElementTree as ET + + if arm_side == DexforceW1ArmSide.LEFT: + rpy = [2.79252648, 0.0, 1.57079633] + xyz = [0.08, 0.0, 0.06] + tf_xpos = np.eye(4) + tf_xpos[:3, :3] = R.from_rotvec([0, 0, -90], degrees=True).as_matrix() + else: + rpy = [2.79252648, 0.0, 1.57079633] + xyz = [0.08, 0.0, 0.06] + tf_xpos = np.eye(4) + tf_xpos[:3, :3] = R.from_rotvec([0, 0, 90], degrees=True).as_matrix() + + attach_xpos = np.eye(4) + attach_xpos[:3, :3] = R.from_euler("xyz", rpy).as_matrix() + attach_xpos[:3, 3] = xyz + attach_xpos = tf_xpos @ attach_xpos + + joint_xml = f""" + + + + + + """ + + link_xml = f""" + + + + + + + + """ + + joint_elem = ET.fromstring(joint_xml) + link_elem = ET.fromstring(link_xml) + sensors.append( + { + "sensor_name": f"{arm_side.value.lower()}_wrist_camera", + "sensor_source": ([link_elem], [joint_elem]), + "parent_component": f"{arm_side.value}_arm", + "parent_link": f"{arm_side.value}_ee", + "transform": attach_xpos, + "sensor_type": "camera", + } + ) + + for arm_side in arm_sides: + if arm_kind == DexforceW1ArmKind.ANTHROPOMORPHIC: + arm_type = ( + DexforceW1Type.LEFT_ARM1 + if arm_side == DexforceW1ArmSide.LEFT + else DexforceW1Type.RIGHT_ARM1 + ) + else: + arm_type = ( + DexforceW1Type.LEFT_ARM2 + if arm_side == DexforceW1ArmSide.LEFT + else DexforceW1Type.RIGHT_ARM2 + ) + arm_version = get_version(arm_type) + arm_cfg = arm_manager.get_config(arm_kind, arm_side, arm_version) + components.append( + { + "component_type": f"{arm_side.value}_arm", + "urdf_path": arm_cfg["urdf_path"], + } + ) + + if include_hand: + for arm_side in arm_sides: + # hand_brand: DexforceW1HandBrand + hand_brand = (hand_types or {}).get( + arm_side, DexforceW1HandBrand.BRAINCO_HAND + ) + hand_version = (hand_versions or {}).get( + arm_side, + get_version( + DexforceW1Type.LEFT_HAND + if arm_side == DexforceW1ArmSide.LEFT + else DexforceW1Type.RIGHT_HAND + ), + ) + urdf_path = hand_manager.get_urdf(hand_brand, arm_side, hand_version) + + attach_xpos = (hand_attach_xposes or {}).get( + arm_side, + hand_manager.get_attach_xpos( + hand_brand, arm_kind, arm_side == DexforceW1ArmSide.LEFT + ), + ) + components.append( + { + "component_type": f"{arm_side.value}_hand", + "urdf_path": urdf_path, + "transform": attach_xpos, + } + ) + return URDFCfg(components=components, sensors=sensors, fname=fname) + + +def build_dexforce_w1_solver_cfg( + arm_kind: DexforceW1ArmKind, + arm_sides: List[DexforceW1ArmSide] = [ + DexforceW1ArmSide.LEFT, + DexforceW1ArmSide.RIGHT, + ], + component_versions: Optional[Dict[DexforceW1Type, DexforceW1Version]] = None, + urdf_cfg: Optional[URDFCfg] = None, +) -> Dict[DexforceW1Type, SolverCfg]: + """ + Build DexforceW1 solver configuration dict. + + Args: + arm_kind: Arm type. + arm_sides: Included arm sides. Optional, default both sides. + component_versions: Component version dict. Optional, default all V021. + urdf_cfg: Optional, URDFCfg object from build_dexforce_w1_assembly_urdf_cfg. + + Returns: + Dict[DexforceW1Type, SolverCfg] + """ + + def get_version(t, default=DexforceW1Version.V021): + return (component_versions or {}).get(t, default) + + solver_cfg = {} + + for arm_side in arm_sides: + if arm_kind == DexforceW1ArmKind.ANTHROPOMORPHIC: + arm_type = ( + DexforceW1Type.LEFT_ARM1 + if arm_side == DexforceW1ArmSide.LEFT + else DexforceW1Type.RIGHT_ARM1 + ) + else: + arm_type = ( + DexforceW1Type.LEFT_ARM2 + if arm_side == DexforceW1ArmSide.LEFT + else DexforceW1Type.RIGHT_ARM2 + ) + arm_version = get_version(arm_type) + arm_cfg = arm_manager.get_config(arm_kind, arm_side, arm_version) + solver_cfg[arm_type] = SolverCfg.from_dict( + { + "class_type": "PytorchSolver", + "urdf_path": arm_cfg["urdf_path"], + "joint_names": arm_cfg["joint_names"], + "end_link_name": arm_cfg["end_link_name"], + "root_link_name": arm_cfg["root_link_name"], + } + ) + + # Use urdf_cfg.fname if provided, otherwise fallback to default path + full_body_urdf_path = ( + urdf_cfg.fname + if urdf_cfg is not None + else get_data_path("DexforceW1FullBodyV021/full_body.urdf") + ) + + solver_cfg[DexforceW1Type.FULL_BODY] = SolverCfg.from_dict( + { + "class_type": "PytorchSolver", + "urdf_path": full_body_urdf_path, + "joint_names": [ + "ANKLE", + "KNEE", + "BUTTOCK", + "WAIST", + "NECK1", + "NECK2", + "LEFT_J1", + "LEFT_J2", + "LEFT_J3", + "LEFT_J4", + "LEFT_J5", + "LEFT_J6", + "LEFT_J7", + "RIGHT_J1", + "RIGHT_J2", + "RIGHT_J3", + "RIGHT_J4", + "RIGHT_J5", + "RIGHT_J6", + "RIGHT_J7", + ], + "end_link_name": "right_ee", + "root_link_name": "base_link", + } + ) + return solver_cfg + + +def build_dexforce_w1_cfg( + arm_kind: DexforceW1ArmKind, + arm_sides: List[DexforceW1ArmSide] = [ + DexforceW1ArmSide.LEFT, + DexforceW1ArmSide.RIGHT, + ], + hand_types: Optional[Dict[DexforceW1ArmSide, DexforceW1HandBrand]] = None, + hand_versions: Optional[Dict[DexforceW1ArmSide, DexforceW1Version]] = None, + hand_attach_xposes: Optional[Dict[DexforceW1ArmSide, np.ndarray]] = None, + include_chassis: bool = True, + include_torso: bool = True, + include_head: bool = True, + include_hand: bool = True, + component_versions: Optional[Dict[DexforceW1Type, DexforceW1Version]] = None, + solver_cfg: Optional[Dict[DexforceW1Type, SolverCfg]] = None, +) -> "DexforceW1Cfg": + """ + Build DexforceW1 robot configuration object. + + Args: + arm_kind: Arm type (anthropomorphic or industrial). + arm_sides: List of arm sides to include (left/right). Default both sides. + hand_types: Dict specifying hand brand (DexforceW1HandBrand) for each arm side. Default None, which uses the default brand. + hand_versions: Dict specifying hand version for each arm side. Default None, which uses the default version. + hand_attach_xposes: Dict specifying hand attachment pose for each arm side. Default None, which uses the default attachment pose. + include_chassis: Whether to include chassis. Optional, default True. + include_torso: Whether to include torso. Optional, default True. + include_head: Whether to include head. Optional, default True. + include_hand: Whether to include hand. Optional, default True. + include_wrist_cameras: Whether to include wrist cameras. Optional, default True. + component_versions: Dict specifying version for each robot component. + solver_cfg: Optional, pre-defined solver configuration dict. + + Returns: + DexforceW1Cfg: Robot configuration object. + """ + urdf_cfg = build_dexforce_w1_assembly_urdf_cfg( + arm_kind=arm_kind, + arm_sides=arm_sides, + hand_types=hand_types, + hand_versions=hand_versions, + hand_attach_xposes=hand_attach_xposes, + include_chassis=include_chassis, + include_torso=include_torso, + include_head=include_head, + component_versions=component_versions, + ) + + left_arm_joints = [] + right_arm_joints = [] + for arm_side in arm_sides: + if arm_kind == DexforceW1ArmKind.ANTHROPOMORPHIC: + arm_type = ( + DexforceW1Type.LEFT_ARM1 + if arm_side == DexforceW1ArmSide.LEFT + else DexforceW1Type.RIGHT_ARM1 + ) + else: + arm_type = ( + DexforceW1Type.LEFT_ARM2 + if arm_side == DexforceW1ArmSide.LEFT + else DexforceW1Type.RIGHT_ARM2 + ) + arm_version = (component_versions or {}).get(arm_type, DexforceW1Version.V021) + arm_cfg = arm_manager.get_config(arm_kind, arm_side, arm_version) + if arm_side == DexforceW1ArmSide.LEFT: + left_arm_joints = arm_cfg["joint_names"] + elif arm_side == DexforceW1ArmSide.RIGHT: + right_arm_joints = arm_cfg["joint_names"] + + torso_joints = [] + head_joints = [] + left_hand_joints = [] + right_hand_joints = [] + + if include_torso: + torso_joints = torso_manager.get_config()["joint_names"] + if include_head: + head_joints = head_manager.get_config()["joint_names"] + if include_hand: + if DexforceW1ArmSide.LEFT in arm_sides: + left_hand_brand = (hand_types or {}).get( + DexforceW1ArmSide.LEFT, DexforceW1HandBrand.BRAINCO_HAND + ) + left_hand_version = (hand_versions or {}).get( + DexforceW1ArmSide.LEFT, DexforceW1Version.V021 + ) + left_hand_cfg = hand_manager.get_config( + left_hand_brand, DexforceW1ArmSide.LEFT, left_hand_version + ) + left_hand_joints = left_hand_cfg["joint_names"] + if DexforceW1ArmSide.RIGHT in arm_sides: + right_hand_brand = (hand_types or {}).get( + DexforceW1ArmSide.RIGHT, DexforceW1HandBrand.BRAINCO_HAND + ) + right_hand_version = (hand_versions or {}).get( + DexforceW1ArmSide.RIGHT, DexforceW1Version.V021 + ) + right_hand_cfg = hand_manager.get_config( + right_hand_brand, DexforceW1ArmSide.RIGHT, right_hand_version + ) + right_hand_joints = right_hand_cfg["joint_names"] + + control_parts = {} + + if torso_joints: + control_parts["torso"] = torso_joints + if head_joints: + control_parts["head"] = head_joints + if left_arm_joints: + control_parts["left_arm"] = left_arm_joints + if right_arm_joints: + control_parts["right_arm"] = right_arm_joints + if left_arm_joints and right_arm_joints: + control_parts["dual_arm"] = left_arm_joints + right_arm_joints + if left_hand_joints: + control_parts["left_eef"] = left_hand_joints + if right_hand_joints: + control_parts["right_eef"] = right_hand_joints + + if torso_joints and head_joints and left_arm_joints and right_arm_joints: + control_parts["full_body"] = ( + torso_joints + head_joints + left_arm_joints + right_arm_joints + ) + + from embodichain.lab.sim.robots.dexforce_w1.cfg import DexforceW1Cfg + + cfg = DexforceW1Cfg() + cfg.arm_kind = arm_kind + cfg.urdf_cfg = urdf_cfg + cfg.control_parts = control_parts + + if solver_cfg is not None: + cfg.solver_cfg = solver_cfg + else: + cfg.solver_cfg = build_dexforce_w1_solver_cfg( + arm_kind=arm_kind, + arm_sides=arm_sides, + component_versions=component_versions, + urdf_cfg=urdf_cfg, + ) + + return cfg diff --git a/embodichain/lab/sim/sensors/__init__.py b/embodichain/lab/sim/sensors/__init__.py new file mode 100644 index 00000000..36c4d0e6 --- /dev/null +++ b/embodichain/lab/sim/sensors/__init__.py @@ -0,0 +1,19 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from .base_sensor import BaseSensor, SensorCfg +from .camera import Camera, CameraCfg +from .stereo import StereoCamera, StereoCameraCfg diff --git a/embodichain/lab/sim/sensors/base_sensor.py b/embodichain/lab/sim/sensors/base_sensor.py new file mode 100644 index 00000000..70f78c6e --- /dev/null +++ b/embodichain/lab/sim/sensors/base_sensor.py @@ -0,0 +1,175 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import torch + +from abc import abstractmethod +from typing import Dict, List, Any, Optional, Sequence, Tuple, Union +from embodichain.lab.sim.cfg import ObjectBaseCfg +from embodichain.lab.sim.common import BatchEntity +from embodichain.utils.math import matrix_from_quat +from embodichain.lab.sim.utility import get_dexsim_arena_num +from embodichain.utils import configclass, is_configclass, logger + + +@configclass +class SensorCfg(ObjectBaseCfg): + """Configuration class for sensors. + + This class can be extended to include specific sensor configurations. + """ + + @configclass + class OffsetCfg: + """Configuration of the sensor offset relative to the parent frame.""" + + pos: Tuple[float, float, float] = (0.0, 0.0, 0.0) + """Position of the sensor in the parent frame. Defaults to (0.0, 0.0, 0.0).""" + quat: Tuple[float, float, float, float] = (1.0, 0.0, 0.0, 0.0) + """Orientation of the sensor in the parent frame as a quaternion (w, x, y, z). Defaults to (1.0, 0.0, 0.0, 0.0).""" + + parent: Optional[str] = None + """Name of the parent frame. If not specified, the sensor will be placed in the arena frame. + + This is usually the case when the sensor is not attached to any specific object, eg, link of a robot arm. + """ + + @property + def transformation(self) -> torch.Tensor: + pos = torch.tensor(self.pos, dtype=torch.float32) + quat = torch.tensor(self.quat, dtype=torch.float32) + rot = matrix_from_quat(quat.unsqueeze(0)).squeeze(0) + T = torch.eye(4, dtype=torch.float32) + T[:3, :3] = rot + T[:3, 3] = pos + return T + + @classmethod + def from_dict(cls, init_dict: dict) -> SensorCfg.OffsetCfg: + cfg = cls() + for key, value in init_dict.items(): + if hasattr(cfg, key): + setattr(cfg, key, value) + else: + logger.log_warning(f"Key '{key}' not found in {cls.__name__}.") + return cfg + + @abstractmethod + def get_data_types(self) -> List[str]: + """Get the data types supported by this sensor configuration. + + Returns: + A list of data types that this sensor configuration supports. + """ + return [] + + sensor_type: str = "BaseSensor" + + @classmethod + def from_dict(cls, init_dict: Dict[str, Any]) -> "SensorCfg": + """Initialize the configuration from a dictionary.""" + from embodichain.utils.utility import get_class_instance + + cfg = get_class_instance( + "embodichain.lab.sim.sensors", init_dict["sensor_type"] + "Cfg" + )() + for key, value in init_dict.items(): + if hasattr(cfg, key): + attr = getattr(cfg, key) + if is_configclass(attr): + setattr( + cfg, key, attr.from_dict(value) + ) # Call from_dict on the attribute + else: + setattr(cfg, key, value) + else: + logger.log_warning( + f"Key '{key}' not found in {cfg.__class__.__name__}." + ) + return cfg + + +class BaseSensor(BatchEntity): + """Base class for sensor abstraction in the simulation engine. + + Sensors should inherit from this class and implement the `update` and `get_data` methods. + """ + + SUPPORTED_DATA_TYPES = [] + + def __init__( + self, config: SensorCfg, device: torch.device = torch.device("cpu") + ) -> None: + + self._data_buffer: Dict[str, torch.Tensor] = {} + + self._entities = [None for _ in range(get_dexsim_arena_num())] + self._build_sensor_from_config(config, device=device) + + super().__init__(config, self._entities, device) + + @abstractmethod + def _build_sensor_from_config( + self, config: SensorCfg, device: torch.device + ) -> None: + """Build the sensor from the provided configuration. + + Args: + config: The configuration for the sensor. + device: The device of the sensor + """ + pass + + @abstractmethod + def update(self, **kwargs) -> None: + """Update the sensor state based on the current simulation state. + + This method is called periodically to ensure the sensor data is up-to-date. + + Args: + **kwargs: Additional keyword arguments for sensor update. + """ + pass + + @abstractmethod + def get_arena_pose(self, to_matrix: bool = False) -> torch.Tensor: + """Get the pose of the sensor in the arena frame. + + Args: + to_matrix: If True, return the pose as a 4x4 transformation matrix. + + Returns: + A tensor representing the pose of the sensor in the arena frame. + """ + logger.log_error("Not implemented yet.") + + def get_data(self, copy: bool = True) -> Dict[str, torch.Tensor]: + """Retrieve data from the sensor. + + Args: + copy: If True, return a copy of the data buffer. Defaults to True. + + Returns: + The data collected by the sensor. + """ + if copy: + return {key: value.clone() for key, value in self._data_buffer.items()} + return self._data_buffer + + def reset(self, env_ids: Optional[Sequence[int]] = None) -> None: + return super().reset(env_ids) diff --git a/embodichain/lab/sim/sensors/camera.py b/embodichain/lab/sim/sensors/camera.py new file mode 100644 index 00000000..11f05bbb --- /dev/null +++ b/embodichain/lab/sim/sensors/camera.py @@ -0,0 +1,537 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import dexsim +import math +import torch +import dexsim.render as dr +import warp as wp + +from functools import cached_property +from typing import Union, Tuple, Optional, Sequence, List + +from embodichain.lab.sim.sensors import BaseSensor, SensorCfg +from embodichain.utils.math import matrix_from_quat, quat_from_matrix, look_at_to_pose +from embodichain.utils.warp.kernels import reshape_tiled_image +from embodichain.utils import logger, configclass +from embodichain.lab.sim.utility.sim_utils import is_rt_enabled + + +@configclass +class CameraCfg(SensorCfg): + """Configuration class for Camera.""" + + @configclass + class ExtrinsicsCfg(SensorCfg.OffsetCfg): + """Configuration class for camera extrinsics. + + The extrinsics define the position and orientation of the camera in the 3D world. + If eye, target, and up are provided, they will be used to compute the extrinsics. + Otherwise, the position and orientation will be set to the defaults. + """ + + eye: Union[Tuple[float, float, float], None] = None + target: Union[Tuple[float, float, float], None] = None + up: Union[Tuple[float, float, float], None] = None + """Alternative way to specify the camera extrinsics using eye, target, and up vectors.""" + + @property + def transformation(self) -> torch.Tensor: + if self.eye: + self.up = (0.0, 0.0, 1.0) if self.up is None else self.up + return look_at_to_pose(self.eye, self.target, self.up).squeeze(0) + else: + return super().transformation + + sensor_type: str = "Camera" + + # Camera parameters + width: int = 640 + height: int = 480 + near: float = 0.005 + far: float = 100.0 + + # The camera intrinsics are defined as (fx, fy, cx, cy) + intrinsics: Tuple[float, float, float, float] = (600, 600, 320.0, 240.0) + extrinsics: ExtrinsicsCfg = ExtrinsicsCfg() + + enable_color: bool = True + enable_depth: bool = False + enable_mask: bool = False + enable_normal: bool = False + enable_position: bool = False + + fx: float = intrinsics[0] + fy: float = intrinsics[1] + cx: float = intrinsics[2] + cy: float = intrinsics[3] + + def get_view_attrib(self) -> dr.ViewFlags: + """Get the view attributes for the camera. + + The camera view whcich is used to render the scene + Default view attributes for the camera are: [COLOR, DEPTH, MASK] + The supported view attributes are: + - COLOR: RGBA images + - DEPTH: Depth images + - MASK: Instance segmentation masks + - NORMAL: Normal images + - POSITION: Position images with 3D coordinates. + + Returns: + The view attributes for the camera. + """ + view_attrib: dr.ViewFlags = dr.ViewFlags.COLOR + # TODO: change for fast-rt renderer backend. + if self.enable_color: + view_attrib |= dr.ViewFlags.COLOR + if self.enable_depth: + if is_rt_enabled() is False: + view_attrib |= dr.ViewFlags.NORMAL + view_attrib |= dr.ViewFlags.DEPTH + if self.enable_mask: + view_attrib |= dr.ViewFlags.MASK + if is_rt_enabled() is False: + view_attrib |= dr.ViewFlags.DEPTH + if self.enable_normal: + view_attrib |= dr.ViewFlags.NORMAL + if self.enable_position: + view_attrib |= dr.ViewFlags.POSITION + return view_attrib + + def get_data_types(self) -> List[str]: + data_types = [] + if self.enable_color: + data_types.append("color") + if self.enable_depth: + data_types.append("depth") + if self.enable_mask: + data_types.append("mask") + if self.enable_normal: + data_types.append("normal") + if self.enable_position: + data_types.append("position") + return data_types + + +class Camera(BaseSensor): + """Base class for sensor abstraction in the simulation engine. + + Sensors should inherit from this class and implement the `update` and `get_data` methods. + """ + + SUPPORTED_DATA_TYPES = ["color", "depth", "mask", "normal", "position"] + + def __init__( + self, config: CameraCfg, device: torch.device = torch.device("cpu") + ) -> None: + super().__init__(config, device) + + def _build_sensor_from_config( + self, config: CameraCfg, device: torch.device + ) -> None: + self._world = dexsim.default_world() + env = self._world.get_env() + arenas = env.get_all_arenas() + if len(arenas) == 0: + arenas = [env] + num_instances = len(arenas) + + if self.is_rt_enabled: + self._frame_buffer = self._world.create_camera_group( + [config.width, config.height], num_instances, True + ) + + view_attrib = config.get_view_attrib() + for i, arena in enumerate(arenas): + view_name = f"{self.uid}_view{i + 1}" + view = arena.create_camera( + view_name, + config.width, + config.height, + True, + view_attrib, + self._frame_buffer, + ) + view.set_intrinsic(config.intrinsics) + view.set_near(config.near) + view.set_far(config.far) + self._entities[i] = view + + else: + self._grid_size = math.ceil(math.sqrt(num_instances)) + frame_width = self._grid_size * config.width + frame_height = self._grid_size * config.height + view_attrib = config.get_view_attrib() + # Create the data frame + self._frame_buffer = self._world.create_frame_buffer( + [frame_width, frame_height], view_attrib, True + ) + self._frame_buffer.set_read_able(view_attrib) + + # Create camera views + for i, arena in enumerate(arenas): + col = i // self._grid_size + row = i % self._grid_size + x = row * config.width + y = col * config.height + view_name = f"{self.uid}_view{i + 1}" + + view = arena.create_camera_view( + view_name, (x, y), (config.width, config.height), self._frame_buffer + ) + view.set_intrinsic(config.intrinsics) + view.set_near(config.near) + view.set_far(config.far) + view.enable_postprocessing(True) + + self._entities[i] = view + + # Define a mapping of data types to their respective shapes and dtypes + buffer_specs = { + "color": ( + (self.num_instances, config.height, config.width, 4), + torch.uint8, + ), + "depth": ( + (self.num_instances, config.height, config.width), + torch.float32, + ), + "mask": ( + (self.num_instances, config.height, config.width), + torch.int32, + ), + "normal": ( + (self.num_instances, config.height, config.width, 3), + torch.float32, + ), + "position": ( + (self.num_instances, config.height, config.width, 3), + torch.float32, + ), + } + data_types = config.get_data_types() + + # Iterate through enabled data types and initialize buffers + for data_type in data_types: + if getattr(config, f"enable_{data_type}", False): + shape, dtype = buffer_specs[data_type] + self._data_buffer[data_type] = torch.empty( + shape, dtype=dtype, device=device + ) + + self.cfg: CameraCfg = config + if self.cfg.extrinsics.parent is not None: + self._attach_to_entity() + + @cached_property + def is_rt_enabled(self) -> bool: + """Check if Ray Tracing rendering backend is enabled in the default dexsim world. + + Returns: + bool: True if Ray Tracing rendering is enabled, False otherwise. + """ + return is_rt_enabled() + + def update(self, **kwargs) -> None: + """Update the sensor data. + + The supported data types are: + - color: RGB images with shape (B, H, W, 4) and dtype torch.uint8 + - depth: Depth images with shape (B, H, W) and dtype torch.float32 + - mask: Instance segmentation masks with shape (B, H, W) and dtype torch.int32 + - normal: Normal images with shape (B, H, W, 3) and dtype torch.float32 + - position: Position images with shape (B, H, W, 3) and dtype torch.float32 + + Args: + **kwargs: Additional keyword arguments for sensor update. + - fetch_only (bool): If True, only fetch the data from dexsim internal frame buffer without performing rendering. + """ + fetch_only = kwargs.get("fetch_only", False) + if not fetch_only: + if self.is_rt_enabled: + self._frame_buffer.apply() + else: + self._frame_buffer.apply_frame() + + self.cfg: CameraCfg + # TODO: support fetch data from gpu buffer directly. + if self.cfg.enable_color: + if self.is_rt_enabled: + self._data_buffer["color"] = self._frame_buffer.get_rgb_gpu_buffer().to( + self.device + ) + else: + data = self._frame_buffer.get_color_gpu_buffer().to(self.device) + self._update_buffer_impl(data, self._data_buffer["color"]) + + if self.cfg.enable_depth: + data = self._frame_buffer.get_depth_gpu_buffer().to(self.device) + if self.is_rt_enabled: + self._data_buffer["depth"] = data + else: + self._update_buffer_impl( + data, self._data_buffer["depth"].unsqueeze_(-1) + ) + self._data_buffer["depth"].squeeze_(-1) + + if self.cfg.enable_mask: + if self.is_rt_enabled: + data = self._frame_buffer.get_visible_mask_gpu_buffer().to( + self.device, torch.int32 + ) + self._data_buffer["mask"] = data + else: + data = self._frame_buffer.get_visible_gpu_buffer().to( + self.device, torch.int32 + ) + self._update_buffer_impl(data, self._data_buffer["mask"].unsqueeze_(-1)) + self._data_buffer["mask"].squeeze_(-1) + + if self.cfg.enable_normal: + data = self._frame_buffer.get_normal_gpu_buffer().to(self.device) + if self.is_rt_enabled: + self._data_buffer["normal"] = data + else: + self._update_buffer_impl(data, self._data_buffer["normal"]) + + if self.cfg.enable_position: + data = self._frame_buffer.get_position_gpu_buffer().to(self.device) + if self.is_rt_enabled: + self._data_buffer["position"] = data + else: + self._update_buffer_impl(data, self._data_buffer["position"]) + + def _update_buffer_impl( + self, data_buffer: torch.Tensor, data_buffer_out: torch.Tensor + ) -> None: + device = str(self.device) + channel = data_buffer.shape[-1] if data_buffer.dim() >= 3 else 1 + wp.launch( + kernel=reshape_tiled_image, + dim=(self.num_instances, self.cfg.height, self.cfg.width), + inputs=[ + wp.from_torch(data_buffer).flatten(), + wp.from_torch(data_buffer_out), + self.cfg.height, + self.cfg.width, + channel, + self._grid_size, + ], + device="cuda:0" if device == "cuda" else device, + ) + + def _attach_to_entity(self) -> None: + """Attach the sensor to the parent entity in each environment.""" + env = self._world.get_env() + for i, entity in enumerate(self._entities): + + parent = None + if i == 0: + parent = env.find_node(f"{self.cfg.extrinsics.parent}") + else: + parent = env.find_node(f"{self.cfg.extrinsics.parent}.{i-1}") + if parent is None: + logger.log_error( + f"Failed to find parent entity {self.cfg.extrinsics.parent} for sensor {self.cfg.uid}." + ) + + entity.attach_node(parent) + + def set_local_pose( + self, pose: torch.Tensor, env_ids: Optional[Sequence[int]] = None + ) -> None: + """Set the local pose of the camera. + + Note: The pose should be in the OpenGL coordinate system, which means the Y is up and Z is forward. + + Args: + pose (torch.Tensor): The local pose to set, should be a 4x4 transformation matrix. + env_ids (Optional[Sequence[int]]): The environment IDs to set the pose for. If None, set for all environments. + """ + if env_ids is None: + local_env_ids = range(len(self._entities)) + else: + local_env_ids = env_ids + + pose = pose.cpu() + if pose.dim() == 2 and pose.shape[1] == 7: + pose_matrix = torch.eye(4).unsqueeze(0).repeat(pose.shape[0], 1, 1) + pose_matrix[:, :3, 3] = pose[:, :3] + pose_matrix[:, :3, :3] = matrix_from_quat(pose[:, 3:7]) + for i, env_idx in enumerate(local_env_ids): + self._entities[env_idx].set_local_pose(pose_matrix[i].numpy()) + elif pose.dim() == 3 and pose.shape[1:] == (4, 4): + for i, env_idx in enumerate(local_env_ids): + self._entities[env_idx].set_local_pose(pose[i].numpy()) + else: + logger.log_error( + f"Invalid pose shape {pose.shape}. Expected (N, 7) or (N, 4, 4)." + ) + + def get_local_pose(self, to_matrix: bool = False) -> torch.Tensor: + """Get the local pose of the camera. + + Args: + to_matrix (bool): If True, return the pose as a 4x4 matrix. If False, return as a quaternion. + + Returns: + torch.Tensor: The local pose of the camera. + """ + poses = [] + for entity in self._entities: + pose = entity.get_local_pose() + poses.append(torch.as_tensor(pose, dtype=torch.float32)) + + poses = torch.stack(poses, dim=0).to(self.device) + if to_matrix is False: + xyz = poses[:, :3, 3] + quat = quat_from_matrix(poses[:, :3, :3]) + return torch.cat((xyz, quat), dim=-1) + return poses + + def get_arena_pose(self, to_matrix: bool = False) -> torch.Tensor: + """Get the pose of the sensor in the arena frame. + + Args: + to_matrix (bool): If True, return the pose as a 4x4 transformation matrix. + + Returns: + A tensor representing the pose of the sensor in the arena frame. + """ + from embodichain.lab.sim.utility import get_dexsim_arenas + + arenas = get_dexsim_arenas() + + poses = [] + for i, entity in enumerate(self._entities): + pose = entity.get_world_pose() + pose[:2, 3] -= arenas[i].get_root_node().get_local_pose()[:2, 3] + poses.append(torch.as_tensor(pose, dtype=torch.float32)) + + poses = torch.stack(poses, dim=0).to(self.device) + if to_matrix is False: + xyz = poses[:, :3, 3] + quat = quat_from_matrix(poses[:, :3, :3]) + return torch.cat((xyz, quat), dim=-1) + return poses + + def look_at( + self, + eye: torch.Tensor, + target: torch.Tensor, + up: Optional[torch.Tensor] = None, + env_ids: Optional[Sequence[int]] = None, + ) -> None: + """Set the camera to look at a target point. + + Args: + eye (torch.Tensor): The position of the camera (eye) with shape (N, 3). + target (torch.Tensor): The point the camera should look at (target) with shape (N, 3). + up (Optional[torch.Tensor]): The up direction vector. If None, defaults to [0, 0, 1]. + env_ids (Optional[Sequence[int]]): The environment IDs to set the look at for. If None, set for all environments. + """ + if up is None: + up = torch.tensor([[0.0, 0.0, 1.0]]).repeat(eye.shape[0], 1) + + pose = look_at_to_pose(eye, target, up) + # To opengl coordinate system. + pose[:, :3, 1] = -pose[:, :3, 1] + pose[:, :3, 2] = -pose[:, :3, 2] + self.set_local_pose(pose, env_ids=env_ids) + + def set_intrinsics( + self, + intrinsics: torch.Tensor, + env_ids: Optional[Sequence[int]] = None, + ) -> None: + """ + Set the camera intrinsics for both left and right cameras. + + Args: + intrinsics (torch.Tensor): The intrinsics for the left camera with shape (4,) / (3, 3) or (N, 4) / (N, 3, 3). + env_ids (Optional[Sequence[int]], optional): The environment ids to set the intrinsics. + If None, set for all environments. Defaults to None. + """ + ids = env_ids if env_ids is not None else range(self.num_instances) + + if intrinsics.dim() == 2 and intrinsics.shape[1] == 3: + intrinsics = intrinsics.unsqueeze(0).repeat(len(ids), 1, 1) + + if intrinsics.dim() == 1: + intrinsics = intrinsics.unsqueeze(0).repeat(len(ids), 1) + + if len(ids) != intrinsics.shape[0]: + logger.log_error( + f"Invalid intrinsics shape {intrinsics.shape} for {len(ids)} environments." + ) + + for i, env_id in enumerate(ids): + entity = self._entities[env_id] + if intrinsics.shape[1] == 3: + entity.set_intrinsic(intrinsics[i].cpu().numpy()) + else: + entity.set_intrinsic(intrinsics[i].cpu().tolist()) + + def get_intrinsics(self) -> torch.Tensor: + """ + Get the camera intrinsics for both left and right cameras. + + Returns: + torch.Tensor: The intrinsics for the left camera with shape (N, 3, 3). + """ + intrinsics = [] + for entity in self._entities: + intrinsics.append( + torch.as_tensor(entity.get_intrinsic(), dtype=torch.float32) + ) + + return torch.stack(intrinsics, dim=0).to(self.device) + + def reset(self, env_ids: Optional[Sequence[int]] = None) -> None: + self.cfg: CameraCfg + + if self.cfg.extrinsics.eye is not None: + eye = ( + torch.tensor(self.cfg.extrinsics.eye, dtype=torch.float32) + .squeeze_(0) + .repeat(self.num_instances, 1) + ) + target = ( + torch.tensor(self.cfg.extrinsics.target, dtype=torch.float32) + .squeeze_(0) + .repeat(self.num_instances, 1) + ) + up = ( + torch.tensor(self.cfg.extrinsics.up, dtype=torch.float32) + .squeeze_(0) + .repeat(self.num_instances, 1) + if self.cfg.extrinsics.up is not None + else None + ) + self.look_at(eye, target, up, env_ids=env_ids) + else: + pose = self.cfg.extrinsics.transformation + pose = pose.unsqueeze_(0).repeat(self.num_instances, 1, 1) + + if self.cfg.extrinsics.parent is None: + # To opengl coordinate system. + pose[:, :3, 1] = -pose[:, :3, 1] + pose[:, :3, 2] = -pose[:, :3, 2] + + self.set_local_pose(pose, env_ids=env_ids) diff --git a/embodichain/lab/sim/sensors/stereo.py b/embodichain/lab/sim/sensors/stereo.py new file mode 100644 index 00000000..d8b8e51d --- /dev/null +++ b/embodichain/lab/sim/sensors/stereo.py @@ -0,0 +1,538 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import dexsim +import math +import torch +import numpy as np +import warp as wp +import dexsim.render as dr + +from typing import Dict, Tuple, List, Optional, Sequence + +from dexsim.utility import inv_transform +from embodichain.lab.sim.sensors import Camera, CameraCfg +from embodichain.utils.warp.kernels import reshape_tiled_image +from embodichain.utils.math import matrix_from_euler +from embodichain.utils import logger, configclass +from embodichain.lab.sim.utility.sim_utils import is_rt_enabled + + +@configclass +class StereoCameraCfg(CameraCfg): + """Configuration class for StereoCamera.""" + + sensor_type: str = "StereoCamera" + + # The camera intrinsics of the right camera. + # The default camera is the left camera. + intrinsics_right: Tuple[float, float, float, float] = (600, 600, 320.0, 240.0) + + left_to_right_pos: Tuple[float, float, float] = (0.05, 0.0, 0.0) + # The rotation from left camera to right camera in degrees. + left_to_right_rot: Tuple[float, float, float] = (0.0, 0.0, 0.0) + + enable_disparity: bool = False + + fx_r: float = intrinsics_right[0] + fy_r: float = intrinsics_right[1] + cx_r: float = intrinsics_right[2] + cy_r: float = intrinsics_right[3] + + @property + def left_to_right(self) -> torch.Tensor: + """Get the transformation matrix from left camera to right camera.""" + left_to_right = torch.eye(4, dtype=torch.float32) + left_to_right[:3, 3] = torch.tensor(self.left_to_right_pos, dtype=torch.float32) + rot = torch.tensor(self.left_to_right_rot, dtype=torch.float32) + left_to_right[:3, :3] = matrix_from_euler(rot.unsqueeze(0)).squeeze(0) + return left_to_right + + @property + def right_to_left(self) -> torch.Tensor: + """Get the transformation matrix from right camera to left camera.""" + return torch.inverse(self.left_to_right) + + def get_data_types(self) -> List[str]: + data_types = [] + if self.enable_color: + data_types.append("color") + data_types.append("color_right") + if self.enable_depth: + data_types.append("depth") + data_types.append("depth_right") + if self.enable_mask: + data_types.append("mask") + data_types.append("mask_right") + if self.enable_normal: + data_types.append("normal") + data_types.append("normal_right") + if self.enable_position: + data_types.append("position") + data_types.append("position_right") + if self.enable_disparity: + data_types.append("disparity") + return data_types + + +class PairCameraView: + def __init__( + self, + left_view: dr.CameraView, + right_view: dr.CameraView, + left_to_right: np.ndarray, + ) -> PairCameraView: + self._left_view = left_view + self._right_view = right_view + self._left_to_right = left_to_right + + self._left_to_center = np.eye(4, dtype=np.float32) + self._left_to_center[:3, 3] = left_to_right[:3, 3] * -0.5 + + self._right_to_center = np.eye(4, dtype=np.float32) + self._right_to_center[:3, 3] = left_to_right[:3, 3] * 0.5 + + def set_local_pose(self, pose: np.ndarray) -> None: + left_pose = pose @ self._left_to_center + right_pose = pose @ self._right_to_center + self._left_view.set_local_pose(left_pose) + self._right_view.set_local_pose(right_pose) + + def get_local_pose(self) -> np.ndarray: + left_pose = self._left_view.get_local_pose() + return left_pose @ inv_transform(self._left_to_center) + + def get_node(self) -> dexsim.engine.Node: + return self._left_view.get_node() + + def attach_node(self, parent: dexsim.engine.Node) -> None: + self._left_view.attach_node(parent) + self._right_view.attach_node(parent) + + +class StereoCamera(Camera): + """Base class for sensor abstraction in the simulation engine. + + Sensors should inherit from this class and implement the `update` and `get_data` methods. + """ + + SUPPORTED_DATA_TYPES = [ + "color", + "depth", + "mask", + "normal", + "position", + "color_right", + "depth_right", + "mask_right", + "normal_right", + "position_right", + "disparity", + ] + + def __init__( + self, + config: StereoCameraCfg, + device: torch.device = torch.device("cpu"), + ) -> None: + super().__init__(config, device) + + # check valid config + if self.cfg.enable_disparity and not self.cfg.enable_depth: + logger.log_error("Disparity can only be enabled when depth is enabled.") + + def _build_sensor_from_config( + self, config: StereoCameraCfg, device: torch.device + ) -> None: + self._world = dexsim.default_world() + env = self._world.get_env() + arenas = env.get_all_arenas() + if len(arenas) == 0: + arenas = [env] + num_instances = len(arenas) + + if self.is_rt_enabled: + self._frame_buffer = self._world.create_camera_group( + [config.width, config.height], num_instances * 2, True + ) + view_attrib = config.get_view_attrib() + left_list = [] + right_list = [] + for i, arena in enumerate(arenas): + left_view_name = f"{self.uid}_left_view{i + 1}" + left_view = arena.create_camera( + left_view_name, + config.width, + config.height, + True, + view_attrib, + self._frame_buffer, + ) + left_view.set_intrinsic(config.intrinsics) + left_view.set_near(config.near) + left_view.set_far(config.far) + left_list.append(left_view) + + for i, arena in enumerate(arenas): + right_view_name = f"{self.uid}_right_view{i + 1}" + right_view = arena.create_camera( + right_view_name, + config.width, + config.height, + True, + view_attrib, + self._frame_buffer, + ) + right_view.set_intrinsic(config.intrinsics_right) + right_view.set_near(config.near) + right_view.set_far(config.far) + right_list.append(right_view) + + for i in range(num_instances): + self._entities[i] = PairCameraView( + left_list[i], right_list[i], config.left_to_right.cpu().numpy() + ) + + else: + self._grid_size = math.ceil(math.sqrt(num_instances)) + + # stereo camera has two views, we append the right camera to the left camera's view list + frame_width = self._grid_size * config.width * 2 + frame_height = self._grid_size * config.height + view_attrib = config.get_view_attrib() + + # Create the data frame + self._frame_buffer = self._world.create_frame_buffer( + [frame_width, frame_height], view_attrib, True + ) + self._frame_buffer.set_read_able(view_attrib) + + # Create camera views + for i, arena in enumerate(arenas): + col = i // self._grid_size + row = i % self._grid_size + x = row * config.width * 2 + y = col * config.height + left_view_name = f"{self.uid}_left_view{i + 1}" + + left_view = arena.create_camera_view( + left_view_name, + (x, y), + (config.width, config.height), + self._frame_buffer, + ) + + left_view.set_intrinsic(config.intrinsics) + left_view.set_near(config.near) + left_view.set_far(config.far) + left_view.enable_postprocessing(True) + + right_view_name = f"{self.uid}_right_view{i + 1}" + right_view = arena.create_camera_view( + right_view_name, + (x + config.width, y), + (config.width, config.height), + self._frame_buffer, + ) + right_view.set_intrinsic(config.intrinsics_right) + right_view.set_near(config.near) + right_view.set_far(config.far) + right_view.enable_postprocessing(True) + + self._entities[i] = PairCameraView( + left_view, right_view, config.left_to_right.cpu().numpy() + ) + + # Define a mapping of data types to their respective shapes and dtypes + buffer_specs = { + "color": ( + (self.num_instances, config.height, config.width, 4), + torch.uint8, + ), + "depth": ( + (self.num_instances, config.height, config.width, 1), + torch.float32, + ), + "mask": ( + (self.num_instances, config.height, config.width, 1), + torch.int32, + ), + "normal": ( + (self.num_instances, config.height, config.width, 3), + torch.float32, + ), + "position": ( + (self.num_instances, config.height, config.width, 3), + torch.float32, + ), + "disparity": ( + (self.num_instances, config.height, config.width, 1), + torch.float32, + ), + } + buffer_specs.update( + { + f"{data_type}_right": buffer_specs[data_type] + for data_type in ["color", "depth", "mask", "normal", "position"] + } + ) + data_types = config.get_data_types() + + # stereo buffer to store data for left and right cameras + # the data in `_data_buffer` is shared with the data in `_data_buffer_stereo`. + self._data_buffer_stereo: Dict[str, torch.Tensor] = {} + + # Iterate through enabled data types and initialize buffers + for data_type in data_types: + if "right" in data_type: + continue + if getattr(config, f"enable_{data_type}", False): + shape, dtype = buffer_specs[data_type] + if data_type == "disparity": + self._data_buffer[data_type] = torch.empty( + shape, dtype=dtype, device=device + ) + + # create new shape with width * 2 for stereo camera + shape_ = (shape[0], shape[1], shape[2] * 2, shape[3]) + + self._data_buffer_stereo[data_type] = torch.empty( + shape_, dtype=dtype, device=device + ) + self._data_buffer[data_type] = self._data_buffer_stereo[data_type][ + :, :, : config.width, : + ] + self._data_buffer[f"{data_type}_right"] = self._data_buffer_stereo[ + data_type + ][:, :, config.width :, :] + + self.cfg: CameraCfg = config + if self.cfg.extrinsics.parent is not None: + self._attach_to_entity() + + def update(self, **kwargs) -> None: + """Update the sensor data. + + The supported data types are: + - color: RGB images with shape (B, H, W, 4) and dtype torch.uint8 + - depth: Depth images with shape (B, H, W, 1) and dtype torch.float32 + - mask: Instance segmentation masks with shape (B, H, W, 1) and dtype torch.int32 + - normal: Normal images with shape (B, H, W, 3) and dtype torch.float32 + - position: Position images with shape (B, H, W, 3) and dtype torch.float32 + - disparity: Disparity images with shape (B, H, W, 1) and dtype torch.float32 + Args: + **kwargs: Additional keyword arguments for sensor update. + - fetch_only (bool): If True, only fetch the data from dexsim internal frame buffer without performing rendering. + """ + + fetch_only = kwargs.get("fetch_only", False) + if not fetch_only: + if self.is_rt_enabled: + self._frame_buffer.apply() + else: + self._frame_buffer.apply_frame() + + self.cfg: StereoCameraCfg + if self.cfg.enable_color: + if self.is_rt_enabled: + data = self._frame_buffer.get_rgb_gpu_buffer().to(self.device) + self._data_buffer["color"] = data[: self.num_instances, ...] + self._data_buffer[f"color_right"] = data[self.num_instances :, ...] + else: + data = self._frame_buffer.get_color_gpu_buffer().to(self.device) + self._update_buffer_impl(data, self._data_buffer_stereo["color"]) + if self.cfg.enable_depth: + data = self._frame_buffer.get_depth_gpu_buffer().to(self.device) + if self.is_rt_enabled: + self._data_buffer["depth"] = data[: self.num_instances, ...].unsqueeze_( + -1 + ) + self._data_buffer[f"depth_right"] = data[ + self.num_instances :, ... + ].unsqueeze_(-1) + else: + self._update_buffer_impl(data, self._data_buffer_stereo["depth"]) + if self.cfg.enable_mask: + if self.is_rt_enabled: + data = self._frame_buffer.get_visible_mask_gpu_buffer().to( + self.device, torch.int32 + ) + self._data_buffer["mask"] = data[: self.num_instances, ...].unsqueeze_( + -1 + ) + self._data_buffer[f"mask_right"] = data[ + self.num_instances :, ... + ].unsqueeze_(-1) + else: + data = self._frame_buffer.get_visible_gpu_buffer().to( + self.device, torch.int32 + ) + self._update_buffer_impl(data, self._data_buffer_stereo["mask"]) + if self.cfg.enable_normal: + data = self._frame_buffer.get_normal_gpu_buffer().to(self.device) + if self.is_rt_enabled: + self._data_buffer["normal"] = data[: self.num_instances, ...] + self._data_buffer[f"normal_right"] = data[self.num_instances :, ...] + else: + self._update_buffer_impl(data, self._data_buffer_stereo["normal"]) + if self.cfg.enable_position: + data = self._frame_buffer.get_position_gpu_buffer().to(self.device) + if self.is_rt_enabled: + self._data_buffer["position"] = data[: self.num_instances, ...] + self._data_buffer[f"position_right"] = data[self.num_instances :, ...] + else: + self._update_buffer_impl(data, self._data_buffer_stereo["position"]) + if self.cfg.enable_disparity: + disparity = self._data_buffer["disparity"] + disparity.fill_(0.0) + distance = torch.sqrt( + torch.sum(torch.square(self.cfg.left_to_right[:3, 3])) + ) + # Compute disparity only for non-zero depth values + depth = self._data_buffer["depth"] + valid_depth_mask = depth > 0 + disparity[valid_depth_mask] = ( + self.cfg.fx * distance / depth[valid_depth_mask] + ) + + def _update_buffer_impl( + self, data_buffer: torch.Tensor, data_buffer_out: torch.Tensor + ) -> None: + device = str(self.device) + channel = data_buffer.shape[-1] if data_buffer.dim() >= 3 else 1 + wp.launch( + kernel=reshape_tiled_image, + dim=(self.num_instances, self.cfg.height, self.cfg.width * 2), + inputs=[ + wp.from_torch(data_buffer).flatten(), + wp.from_torch(data_buffer_out), + self.cfg.height, + self.cfg.width * 2, + channel, + self._grid_size, + ], + device="cuda:0" if device == "cuda" else device, + ) + + def get_left_right_arena_pose(self) -> torch.Tensor: + """Get the local pose of the left and right cameras. + + Returns: + torch.Tensor: The local pose of the left camera with shape (num_envs, 4, 4). + """ + from embodichain.lab.sim.utility import get_dexsim_arenas + + arenas = get_dexsim_arenas() + + left_poses = [] + right_poses = [] + for i, entity in enumerate(self._entities): + arena_pose = arenas[i].get_root_node().get_local_pose() + left_pose = entity._left_view.get_world_pose() + left_pose[:2, 3] -= arena_pose[:2, 3] + left_poses.append( + torch.as_tensor( + left_pose, + dtype=torch.float32, + ) + ) + right_pose = entity._right_view.get_world_pose() + right_pose[:2, 3] -= arena_pose[:2, 3] + right_poses.append( + torch.as_tensor( + right_pose, + dtype=torch.float32, + ) + ) + return torch.stack(left_poses, dim=0).to(self.device), torch.stack( + right_poses, dim=0 + ).to(self.device) + + def set_intrinsics( + self, + intrinsics: torch.Tensor, + right_intrinsics: Optional[torch.Tensor] = None, + env_ids: Optional[Sequence[int]] = None, + ) -> None: + """ + Set the camera intrinsics for both left and right cameras. + + Args: + intrinsics (torch.Tensor): The intrinsics for the left camera with shape (4,) / (3, 3) or (B, 4) / (B, 3, 3). + right_intrinsics (Optional[torch.Tensor], optional): The intrinsics for the right camera with shape 4,) / (3, 3) or (B, 4) / (B, 3, 3). + If None, use the same intrinsics as the left camera. Defaults to None. + env_ids (Optional[Sequence[int]], optional): The environment ids to set the intrinsics. If None, set for all environments. + Defaults to None. + """ + ids = env_ids if env_ids is not None else range(self.num_instances) + + if intrinsics.dim() == 2 and intrinsics.shape[1] == 3: + intrinsics = intrinsics.unsqueeze(0).repeat(len(ids), 1, 1) + + if intrinsics.dim() == 1: + intrinsics = intrinsics.unsqueeze(0).repeat(len(ids), 1) + + if len(ids) != intrinsics.shape[0]: + logger.log_error( + f"Intrinsics shape {intrinsics.shape} does not match env_ids length {len(ids)}" + ) + + if right_intrinsics is None: + right_intrinsics = intrinsics + else: + if right_intrinsics.dim() == 2 and right_intrinsics.shape[1] == 3: + right_intrinsics = right_intrinsics.unsqueeze(0).repeat(len(ids), 1, 1) + + if right_intrinsics.dim() == 1: + right_intrinsics = right_intrinsics.unsqueeze(0).repeat(len(ids), 1) + + if len(ids) != right_intrinsics.shape[0]: + logger.log_error( + f"Right intrinsics shape {right_intrinsics.shape} does not match env_ids length {len(ids)}" + ) + + for i, env_id in enumerate(ids): + entity = self._entities[env_id] + if intrinsics.shape[1] == 3: + entity._left_view.set_intrinsic(intrinsics[i].cpu().numpy()) + entity._right_view.set_intrinsic(right_intrinsics[i].cpu().numpy()) + else: + entity._left_view.set_intrinsic(intrinsics[i].cpu().tolist()) + entity._right_view.set_intrinsic(right_intrinsics[i].cpu().tolist()) + + def get_intrinsics(self) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get the camera intrinsics for both left and right cameras. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The intrinsics for the left and right cameras with shape (B, 3, 3). + """ + intrinsics_left = [] + intrinsics_right = [] + for entity in self._entities: + intrinsics_left.append( + torch.as_tensor(entity._left_view.get_intrinsic(), dtype=torch.float32) + ) + intrinsics_right.append( + torch.as_tensor(entity._right_view.get_intrinsic(), dtype=torch.float32) + ) + + return ( + torch.stack(intrinsics_left, dim=0).to(self.device), + torch.stack(intrinsics_right, dim=0).to(self.device), + ) diff --git a/embodichain/lab/sim/shapes.py b/embodichain/lab/sim/shapes.py new file mode 100644 index 00000000..8135943b --- /dev/null +++ b/embodichain/lab/sim/shapes.py @@ -0,0 +1,144 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Optional, List, Dict, Union, TYPE_CHECKING, Any +from dataclasses import MISSING +from embodichain.utils import configclass, is_configclass, logger + +if TYPE_CHECKING: + from embodichain.lab.sim.material import VisualMaterialCfg + + +@configclass +class LoadOption: + + rebuild_normals: bool = False + """Whether to rebuild normals for the shape. Defaults to False.""" + + rebuild_tangent: bool = False + """Whether to rebuild tangents for the shape. Defaults to False.""" + + rebuild_3rdnormal: bool = False + """Whether to rebuild the normal for the shape using 3rd party library. Defaults to False.""" + + rebuild_3rdtangent: bool = False + """Whether to rebuild the tangent for the shape using 3rd party library. Defaults to False.""" + + smooth: float = -1.0 + """Angle threshold (in degrees) for smoothing normals. Defaults to -1.0 (no smoothing).""" + + @classmethod + def from_dict(cls, init_dict: Dict[str, Any]) -> LoadOption: + """Initialize the configuration from a dictionary.""" + cfg = cls() + for key, value in init_dict.items(): + if hasattr(cfg, key): + setattr(cfg, key, value) + else: + logger.log_warning( + f"Key '{key}' not found in {cfg.__class__.__name__}." + ) + return cfg + + +@configclass +class ShapeCfg: + + shape_type: str = MISSING + """Type of the shape. Must be specified in subclasses.""" + + visual_material: Optional[VisualMaterialCfg] = None + """Configuration parameters for the visual material of the shape. Defaults to None.""" + + @classmethod + def from_dict(cls, init_dict: Dict[str, Any]) -> ShapeCfg: + """Initialize the configuration from a dictionary.""" + from embodichain.utils.utility import get_class_instance + + if "shape_type" not in init_dict: + logger.log_error("shape type must be specified in the configuration.") + + cfg = get_class_instance( + "embodichain.lab.sim.shapes", init_dict["shape_type"] + "Cfg" + )() + for key, value in init_dict.items(): + if hasattr(cfg, key): + attr = getattr(cfg, key) + if key == "visual_material" and isinstance(value, dict): + setattr( + cfg, + key, + VisualMaterialCfg.from_dict(value), + ) + elif is_configclass(attr): + setattr(cfg, key, attr.from_dict(value)) + else: + setattr(cfg, key, value) + else: + logger.log_warning( + f"Key '{key}' not found in {cfg.__class__.__name__}." + ) + return cfg + + +@configclass +class MeshCfg(ShapeCfg): + """Configuration parameters for a triangle mesh shape.""" + + shape_type: str = "Mesh" + + fpath: str = MISSING + """File path to the shape mesh file.""" + + load_option: LoadOption = LoadOption() + """Options for loading and processing the shape. + + Please refer to dexsim.types.LoadOption for more details: http://192.168.3.120/MixedAI/docs_dev/dexsim/tutorial/basics/physics/actor.html + """ + + compute_uv: bool = False + """Whether to compute UV coordinates for the shape. Defaults to False. + + If the shape already has UV coordinates, setting this to True will recompute and overwrite them. + """ + + project_direction: List[float] = [1.0, 1.0, 1.0] + """Direction to project the UV coordinates. Defaults to [1.0, 1.0, 1.0].""" + + +@configclass +class CubeCfg(ShapeCfg): + """Configuration parameters for a cube shape.""" + + shape_type: str = "Cube" + + size: List[float] = [1.0, 1.0, 1.0] + """Size of the cube (in m) as [length, width, height].""" + + +@configclass +class SphereCfg(ShapeCfg): + """Configuration parameters for a sphere shape.""" + + shape_type: str = "Sphere" + + radius: float = 1.0 + """Radius of the sphere (in m).""" + + resolution: int = 20 + """Resolution of the sphere mesh. Defaults to 20.""" diff --git a/embodichain/lab/sim/sim_manager.py b/embodichain/lab/sim/sim_manager.py new file mode 100644 index 00000000..7d2e216e --- /dev/null +++ b/embodichain/lab/sim/sim_manager.py @@ -0,0 +1,1486 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os +import sys +import dexsim +import torch +import numpy as np +import warp as wp + +from tqdm import tqdm +from pathlib import Path +from copy import deepcopy +from functools import cached_property +from typing import List, Union, Optional, Dict, Tuple, Union, Sequence +from dataclasses import dataclass, asdict, field, MISSING + +# Global cache directories +SIM_CACHE_DIR = Path.home() / ".cache" / "embodichain_cache" +MATERIAL_CACHE_DIR = SIM_CACHE_DIR / "mat_cache" +CONVEX_DECOMP_DIR = SIM_CACHE_DIR / "convex_decomposition" +REACHABLE_XPOS_DIR = SIM_CACHE_DIR / "robot_reachable_xpos" + +from dexsim.types import ( + Backend, + ThreadMode, + PhysicalAttr, + ActorType, + RigidBodyShape, + RigidBodyGPUAPIReadType, + ArticulationGPUAPIReadType, +) +from dexsim.engine import CudaArray, Material +from dexsim.models import MeshObject +from dexsim.render import Light as _Light, LightType +from dexsim.render import GizmoController + +from embodichain.lab.sim.objects import ( + RigidObject, + RigidObjectGroup, + SoftObject, + Articulation, + Robot, + Light, +) +from embodichain.lab.sim.objects.gizmo import Gizmo +from embodichain.lab.sim.sensors import ( + SensorCfg, + BaseSensor, + Camera, + StereoCamera, +) +from embodichain.lab.sim.cfg import ( + PhysicsCfg, + MarkerCfg, + GPUMemoryCfg, + LightCfg, + RigidObjectCfg, + SoftObjectCfg, + RigidObjectGroupCfg, + ArticulationCfg, + RobotCfg, +) +from embodichain.lab.sim import VisualMaterial, VisualMaterialCfg +from embodichain.data.assets import SimResources +from embodichain.utils import configclass, logger + +__all__ = [ + "SimulationManager", + "SimulationManagerCfg", + "SIM_CACHE_DIR", + "MATERIAL_CACHE_DIR", + "CONVEX_DECOMP_DIR", + "REACHABLE_XPOS_DIR", +] + + +@configclass +class SimulationManagerCfg: + """Global robot simulation configuration.""" + + width: int = 1920 + """The width of the simulation window.""" + + height: int = 1080 + """The height of the simulation window.""" + + headless: bool = False + """Whether to run the simulation in headless mode (no Window).""" + + enable_rt: bool = False + """Whether to enable ray tracing rendering.""" + + enable_denoiser: bool = True + """Whether to enable denoising for ray tracing rendering.""" + + spp: int = 64 + """Samples per pixel for ray tracing rendering. This parameter is only valid when ray tracing is enabled and enable_denoiser is False.""" + + gpu_id: int = 0 + """The gpu index that the simulation engine will be used. + + Note: it will affect the gpu physics device if using gpu physics. + """ + + thread_mode: ThreadMode = ThreadMode.RENDER_SHARE_ENGINE + """The threading mode for the simulation engine. + + - RENDER_SHARE_ENGINE: The rendering thread shares the same thread with the simulation engine. + - RENDER_SCENE_SHARE_ENGINE: The rendering thread and scene update thread share the same thread with the simulation engine. + """ + + arena_space: float = 5.0 + """The distance between each arena when building multiple arenas.""" + + physics_dt: float = 1.0 / 100.0 + """The time step for the physics simulation.""" + + sim_device: Union[str, torch.device] = "cpu" + """The device for the simulation engine. Can be 'cpu', 'cuda', or a torch.device object.""" + + physics_config: PhysicsCfg = field(default_factory=PhysicsCfg) + """The physics configuration parameters.""" + gpu_memory_config: GPUMemoryCfg = field(default_factory=GPUMemoryCfg) + """The GPU memory configuration parameters.""" + + +class SimulationManager: + r"""Global Embodied AI simulation manager. + + This class is used to manage the global simulation environment and simulated assets. + - assets loading, creation, modification and deletion. + - assets include robots, fixed actors, dynamic actors and background. + - manager the scenes and the simulation environment. + - parallel scenes simulation on both CPU and GPU. + - sensors arrangement + - lighting and indirect lighting + - physics simulation parameters control + - ... + + Note: + 1. The arena is used as a standalone space for robots to simulate in. When :meth:`build_multiple_arenas` is called, + it will create multiple arenas in a grid pattern. Meanwhile, each simulation assets adding interface will + take an additional parameter `arena_index` to specify which arena to place the asset. The name of the asset to + be added will be appended with the arena index to avoid name conflict. + 2. In GUI mode, the physics will be set to a fps (or a wait time for manual mode) for better visualization. + + + Args: + sim_config (SimulationManagerCfg, optional): simulation configuration. Defaults to SimulationManagerCfg(). + """ + + SUPPORTED_SENSOR_TYPES = {"Camera": Camera, "StereoCamera": StereoCamera} + + def __init__( + self, sim_config: SimulationManagerCfg = SimulationManagerCfg() + ) -> None: + # Cache paths + self._sim_cache_dir = SIM_CACHE_DIR + self._material_cache_dir = MATERIAL_CACHE_DIR + self._convex_decomp_dir = CONVEX_DECOMP_DIR + self._reachable_xpos_dir = REACHABLE_XPOS_DIR + + # Setup cache file path. + for path in [ + self._sim_cache_dir, + self._material_cache_dir, + self._convex_decomp_dir, + self._reachable_xpos_dir, + ]: + os.makedirs(path, exist_ok=True) + + self.sim_config = sim_config + self.device = torch.device("cpu") + + world_config = self._convert_sim_config(sim_config) + + # Initialize warp runtime context before creating the world. + wp.init() + self._world = dexsim.World(world_config) + + fps = int(1.0 / sim_config.physics_dt) + self._world.set_physics_fps(fps) + + self._world.set_time_scale(1.0) + self._world.set_delta_time(sim_config.physics_dt) + self._world.show_coordinate_axis(False) + + if sys.platform == "linux": + dexsim.set_physics_config(**sim_config.physics_config.to_dexsim_args()) + dexsim.set_physics_gpu_memory_config( + **sim_config.gpu_memory_config.to_dict() + ) + + self._is_initialized_gpu_physics = False + self._ps = self._world.get_physics_scene() + + # activate physics + self.enable_physics(True) + + self._env = self._world.get_env() + + self._default_resources = SimResources() + + # set unique material path to accelerate material creation. + if self.sim_config.enable_rt is False: + self._env.set_unique_mat_path( + os.path.join(self._material_cache_dir, "dexsim_mat") + ) + + # arena is used as a standalone space for robots to simulate in. + self._arenas: List[dexsim.environment.Arena] = [] + + # gizmo management + self._gizmos: Dict[str, object] = dict() # Store active gizmos + + # marker management + self._markers: Dict[str, MeshObject] = dict() + + self._rigid_objects: Dict[str, RigidObject] = dict() + self._rigid_object_groups: Dict[str, RigidObjectGroup] = dict() + self._soft_objects: Dict[str, SoftObject] = dict() + self._articulations: Dict[str, Articulation] = dict() + self._robots: Dict[str, Robot] = dict() + + self._sensors: Dict[str, BaseSensor] = dict() + self._lights: Dict[str, _Light] = dict() + + # material placeholder. + self._visual_materials: Dict[str, VisualMaterial] = dict() + + # Global texture cache for material creation or randomization. + # The structure is keys to the loaded texture data. The keys represent the texture group. + self._texture_cache: Dict[str, Union[torch.Tensor, List[torch.Tensor]]] = dict() + + # TODO: maybe need to add some interface to interact with background and layouts. + # background and layouts are 3d assets that can has only render body for visualization. + + self._create_default_plane() + self.set_default_background() + + def _convert_sim_config( + self, sim_config: SimulationManagerCfg + ) -> dexsim.WorldConfig: + world_config = dexsim.WorldConfig() + win_config = dexsim.WindowsConfig() + win_config.width = sim_config.width + win_config.height = sim_config.height + world_config.win_config = win_config + world_config.open_windows = not sim_config.headless + self.is_window_opened = not sim_config.headless + world_config.backend = Backend.VULKAN + world_config.thread_mode = sim_config.thread_mode + world_config.cache_path = str(self._material_cache_dir) + world_config.length_tolerance = sim_config.physics_config.length_tolerance + world_config.speed_tolerance = sim_config.physics_config.speed_tolerance + + if sim_config.enable_rt: + world_config.renderer = dexsim.types.Renderer.FASTRT + if sim_config.enable_denoiser is False: + world_config.raytrace_config.spp = sim_config.spp + world_config.raytrace_config.open_denoise = False + + if type(sim_config.sim_device) is str: + self.device = torch.device(sim_config.sim_device) + else: + self.device = sim_config.sim_device + + if self.device.type == "cuda": + world_config.enable_gpu_sim = True + world_config.direct_gpu_api = True + + if self.device.index is not None and sim_config.gpu_id != self.device.index: + logger.log_warning( + f"Conflict gpu_id {sim_config.gpu_id} and device index {self.device.index}. Using device index." + ) + sim_config.gpu_id = self.device.index + + self.device = torch.device(f"cuda:{sim_config.gpu_id}") + + world_config.gpu_id = sim_config.gpu_id + + return world_config + + def get_default_resources(self) -> SimResources: + """Get the default resources instance. + + Returns: + SimResources: The default resources path. + """ + return self._default_resources + + @property + def num_envs(self) -> int: + """Get the number of arenas in the simulation. + + Returns: + int: number of arenas. + """ + return len(self._arenas) if len(self._arenas) > 0 else 1 + + @cached_property + def is_use_gpu_physics(self) -> bool: + """Check if the physics simulation is using GPU.""" + world_config = dexsim.get_world_config() + return self.device.type == "cuda" and world_config.enable_gpu_sim + + @property + def is_rt_enabled(self) -> bool: + """Check if Ray Tracing rendering backend is enabled.""" + return self.sim_config.enable_rt + + @property + def is_physics_manually_update(self) -> bool: + return self._world.is_physics_manually_update() + + @property + def asset_uids(self) -> List[str]: + """Get all assets uid in the simulation. + + The assets include lights, sensors, robots, rigid objects and articulations. + + Returns: + List[str]: list of all assets uid. + """ + uid_list = ["default_plane"] + uid_list.extend(list(self._lights.keys())) + uid_list.extend(list(self._sensors.keys())) + uid_list.extend(list(self._robots.keys())) + uid_list.extend(list(self._rigid_objects.keys())) + uid_list.extend(list(self._rigid_object_groups.keys())) + uid_list.extend(list(self._articulations.keys())) + return uid_list + + def enable_physics(self, enable: bool) -> None: + """Enable or disable physics simulation. + + Args: + enable (bool): whether to enable physics simulation. + """ + self._world.enable_physics(enable) + + def set_manual_update(self, enable: bool) -> None: + """Set manual update for physics simulation. + + If enable is True, the physics simulation will be updated manually by calling :meth:`update`. + If enable is False, the physics simulation will be updated automatically by the engine thread loop. + + Args: + enable (bool): whether to enable manual update. + """ + self._world.set_manual_update(enable) + + def init_gpu_physics(self) -> None: + """Initialize the GPU physics simulation.""" + if self.device.type != "cuda": + logger.log_warning( + "The simulation device is not cuda, cannot initialize GPU physics." + ) + return + + if self._is_initialized_gpu_physics: + return + + # init rigid body. + rigid_body_num = ( + 0 + if self._get_non_static_rigid_obj_num() == 0 + else len(self._ps.gpu_rigid_indices) + ) + self._rigid_body_pose = torch.zeros( + (rigid_body_num, 7), dtype=torch.float32, device=self.device + ) + + # init articulation. + articulation_num = ( + 0 + if len(self._articulations) == 0 and len(self._robots) == 0 + else len(self._ps.gpu_articulation_indices) + ) + max_link_count = self._ps.gpu_get_articulation_max_link_count() + self._link_pose = torch.zeros( + (articulation_num, max_link_count, 7), + dtype=torch.float32, + device=self.device, + ) + for art in self._articulations.values(): + art.reallocate_body_data() + for robot in self._robots.values(): + robot.reallocate_body_data() + + # We do not perform reallocate body data for robot. + + self._is_initialized_gpu_physics = True + + def render_camera_group(self) -> None: + """Render all camera group in the simulation. + + Note: This interface is only valid when Ray Tracing rendering backend is enabled. + """ + + if self.is_rt_enabled: + self._world.render_camera_group() + else: + logger.log_warning( + "This interface is only valid when Ray Tracing rendering backend is enabled." + ) + + def update(self, physics_dt: Optional[float] = None, step: int = 10) -> None: + """Update the physics. + + Args: + physics_dt (Optional[float], optional): the time step for physics simulation. Defaults to None. + step (int, optional): the number of steps to update physics. Defaults to 10. + """ + if self.is_use_gpu_physics and not self._is_initialized_gpu_physics: + logger.log_warning( + f"Using GPU physics, but not initialized yet. Forcing initialization." + ) + self.init_gpu_physics() + + if self.is_physics_manually_update: + if physics_dt is None: + physics_dt = self.sim_config.physics_dt + for i in range(step): + self._world.update(physics_dt) + + if self.sim_config.enable_rt is False: + self._sync_gpu_data() + + else: + logger.log_warning("Physics simulation is not manually updated.") + + def _sync_gpu_data(self) -> None: + if not self.is_use_gpu_physics: + return + + if not self._is_initialized_gpu_physics: + logger.log_warning( + "GPU physics is not initialized. Skipping GPU data synchronization." + ) + return + + if self.is_window_opened or self._sensors: + if len(self._rigid_body_pose) > 0: + self._ps.gpu_fetch_rigid_body_data( + data=CudaArray(self._rigid_body_pose), + gpu_indices=self._ps.gpu_rigid_indices, + data_type=RigidBodyGPUAPIReadType.POSE, + ) + + if len(self._link_pose) > 0: + self._ps.gpu_fetch_link_data( + data=CudaArray(self._link_pose), + gpu_indices=self._ps.gpu_articulation_indices, + data_type=ArticulationGPUAPIReadType.LINK_GLOBAL_POSE, + ) + + # TODO: might be optimized. + self._world.sync_poses_gpu_to_cpu( + rigid_pose=CudaArray(self._rigid_body_pose), + link_pose=CudaArray(self._link_pose), + ) + + def get_env(self, arena_index: int = -1) -> dexsim.environment.Arena: + """Get the arena or env by index. + + If arena_index is -1, return the global env. + If arena_index is valid, return the corresponding arena. + + Args: + arena_index (int, optional): the index of arena to get, -1 for global env. Defaults to -1. + + Returns: + dexsim.environment.Arena: The arena or global env. + """ + if arena_index >= 0: + if arena_index > len(self._arenas) - 1: + logger.log_error( + f"Invalid arena index: {arena_index}. Current number of arenas: {len(self._arenas)}" + ) + return self._arenas[arena_index] + else: + return self._env + + def get_world(self) -> dexsim.World: + return self._world + + def open_window(self) -> None: + """Open the simulation window.""" + self._world.open_window() + self.is_window_opened = True + + def close_window(self) -> None: + """Close the simulation window.""" + self._world.close_window() + self.is_window_opened = False + + def build_multiple_arenas(self, num: int, space: Optional[float] = None) -> None: + """Build multiple arenas in a grid pattern. + + This interface is used for vectorized simulation. + + Args: + num (int): number of arenas to build. + space (float, optional): The distance between each arena. Defaults to the arena_space in sim_config. + """ + + if space is None: + space = self.sim_config.arena_space + + if num <= 0: + logger.log_warning("Number of arenas must be greater than 0.") + return + + scene_grid_length = int(np.ceil(np.sqrt(num))) + + for i in range(num): + arena = self._env.add_arena(f"arena_{i}") + + id_x, id_y = i % scene_grid_length, i // scene_grid_length + arena.set_root_node_position([id_x * space, id_y * space, 0]) + self._arenas.append(arena) + + def set_indirect_lighting(self, name: str) -> None: + """Set indirect lighting. + + Args: + name (str): name of path of the indirect lighting. + """ + if name.startswith("/") is False: + ibl_path = self._default_resources.get_ibl_path(name) + logger.log_info(f"Set IBL {name} from sim default resources.") + else: + ibl_path = name + logger.log_info(f"Set IBL {name} from custom path.") + + self._env.set_IBL(ibl_path) + + def set_emission_light( + self, color: Optional[Sequence[float]] = None, intensity: Optional[float] = None + ) -> None: + """Set environment emission light. + + Args: + color (Sequence[float]): color of the light. + intensity (float): intensity of the light. + """ + if color is None: + self._env.set_env_light_emission(color) + if intensity is None: + self._env.set_env_light_intensity(intensity) + + def _create_default_plane(self): + default_length = 1000 + repeat_uv_size = int(default_length / 2) + self._default_plane = self._env.create_plane( + 0, default_length, repeat_uv_size, repeat_uv_size + ) + self._default_plane.set_name("default_plane") + plane_collision = self._env.create_cube( + default_length, default_length, default_length / 10 + ) + plane_collision_pose = np.eye(4, dtype=float) + plane_collision_pose[2, 3] = -default_length / 20 - 0.001 + plane_collision.set_local_pose(plane_collision_pose) + plane_collision.add_rigidbody(ActorType.KINEMATIC, RigidBodyShape.CONVEX) + + # TODO: add default physics attributes for the plane. + + def set_default_background(self) -> None: + """Set default background.""" + + mat_name = "plane_mat" + mat = None + mat_path = self._default_resources.get_material_path("PlaneDark") + color_texture = os.path.join(mat_path, "PlaneDark_2K_Color.jpg") + roughness_texture = os.path.join(mat_path, "PlaneDark_2K_Roughness.jpg") + mat = self.create_visual_material( + cfg=VisualMaterialCfg( + uid=mat_name, + base_color_texture=color_texture, + roughness_texture=roughness_texture, + ) + ) + + if self.sim_config.enable_rt: + self.set_emission_light([0.1, 0.1, 0.1], 10.0) + else: + self.set_indirect_lighting("lab_day") + + self._default_plane.set_material(mat.get_instance("plane_mat").mat) + self._visual_materials[mat_name] = mat + + def set_texture_cache( + self, key: str, texture: Union[torch.Tensor, List[torch.Tensor]] + ) -> None: + """Set the texture to the global texture cache. + + Args: + key (str): The key of the texture. + texture (Union[torch.Tensor, List[torch.Tensor]]): The texture data. + """ + self._texture_cache[key] = texture + + def get_texture_cache( + self, key: Optional[str] = None + ) -> Optional[Union[torch.Tensor, List[torch.Tensor]]]: + """Get the texture from the global texture cache. + + Args: + key (str, optional): The key of the texture. If None, return None. Defaults to None. + + Returns: + Optional[Union[torch.Tensor, List[torch.Tensor]]]: The texture if found, otherwise None. + """ + if key is None: + return self._texture_cache + + if key not in self._texture_cache: + logger.log_warning(f"Texture {key} not found in global texture cache.") + return None + return self._texture_cache[key] + + def get_asset( + self, uid: str + ) -> Optional[Union[Light, BaseSensor, Robot, RigidObject, Articulation]]: + """Get an asset by its UID. + + The asset can be a light, sensor, robot, rigid object or articulation. + + Args: + uid (str): The UID of the asset. + + Returns: + Light | BaseSensor | Robot | RigidObject | Articulation | None: The asset instance if found, otherwise None. + """ + if uid in self._lights: + return self._lights[uid] + if uid in self._sensors: + return self._sensors[uid] + if uid in self._robots: + return self._robots[uid] + if uid in self._rigid_objects: + return self._rigid_objects[uid] + if uid in self._rigid_object_groups: + return self._rigid_object_groups[uid] + if uid in self._articulations: + return self._articulations[uid] + + logger.log_warning(f"Asset {uid} not found.") + return None + + def add_light(self, cfg: LightCfg) -> Light: + """Create a light in the scene. + + Args: + cfg (LightCfg): Configuration for the light, including type, color, intensity, and radius. + + Returns: + Light: The created light instance. + """ + if cfg.uid is None: + uid = "light" + cfg.uid = uid + else: + uid = cfg.uid + + if uid in self._lights: + logger.log_error(f"Light {uid} already exists.") + + light_type = cfg.light_type + if light_type == "point": + light_type = LightType.POINT + else: + logger.log_error( + f"Unsupported light type: {light_type}. Supported types: point." + ) + + env_list = [self._env] if len(self._arenas) == 0 else self._arenas + light_list = [] + for i, env in enumerate(env_list): + light_name = f"{uid}_{i}" + light = env.create_light(light_name, light_type) + light_list.append(light) + + batch_lights = Light(cfg=cfg, entities=light_list) + + self._lights[uid] = batch_lights + + return batch_lights + + def get_light(self, uid: str) -> Optional[Light]: + """Get a light by its UID. + + Args: + uid (str): The UID of the light. + + Returns: + Light | None: The light instance if found, otherwise None. + """ + if uid not in self._lights: + logger.log_warning(f"Light {uid} not found.") + return None + return self._lights[uid] + + def add_rigid_object( + self, + cfg: RigidObjectCfg, + ) -> RigidObject: + """Add a rigid object to the scene. + + Args: + cfg (RigidObjectCfg): Configuration for the rigid object. + + Returns: + RigidObject: The added rigid object instance handle. + """ + from embodichain.lab.sim.utility.sim_utils import ( + load_mesh_objects_from_cfg, + ) + + uid = cfg.uid + if uid is None: + logger.log_error("Rigid object uid must be specified.") + if uid in self._rigid_objects: + logger.log_error(f"Rigid object {uid} already exists.") + + env_list = [self._env] if len(self._arenas) == 0 else self._arenas + obj_list = load_mesh_objects_from_cfg( + cfg=cfg, + env_list=env_list, + cache_dir=self._convex_decomp_dir, + ) + + rigid_obj = RigidObject(cfg=cfg, entities=obj_list, device=self.device) + + if cfg.shape.visual_material: + mat = self.create_visual_material(cfg.shape.visual_material) + rigid_obj.set_visual_material(mat) + + self._rigid_objects[uid] = rigid_obj + + return rigid_obj + + def add_soft_object(self, cfg: SoftObjectCfg) -> SoftObject: + """Add a soft object to the scene. + + Args: + cfg (SoftObjectCfg): Configuration for the soft object. + + Returns: + SoftObject: The added soft object instance handle. + """ + if not self.is_use_gpu_physics: + logger.log_error("Soft object requires GPU physics to be enabled.") + + from embodichain.lab.sim.utility import ( + load_soft_object_from_cfg, + ) + + uid = cfg.uid + if uid is None: + logger.log_error("Soft object uid must be specified.") + + env_list = [self._env] if len(self._arenas) == 0 else self._arenas + obj_list = load_soft_object_from_cfg( + cfg=cfg, + env_list=env_list, + ) + + soft_obj = SoftObject(cfg=cfg, entities=obj_list, device=self.device) + self._soft_objects[uid] = soft_obj + return soft_obj + + def get_rigid_object(self, uid: str) -> Optional[RigidObject]: + """Get a rigid object by its unique ID. + + Args: + uid (str): The unique ID of the rigid object. + + Returns: + Optional[RigidObject]: The rigid object instance if found, otherwise None. + """ + if uid not in self._rigid_objects: + logger.log_warning(f"Rigid object {uid} not found.") + return None + return self._rigid_objects[uid] + + def get_rigid_object_uid_list(self) -> List[str]: + """Get current rigid body uid list + + Returns: + List[str]: list of rigid body uid. + """ + return list(self._rigid_objects.keys()) + + def add_rigid_object_group(self, cfg: RigidObjectGroupCfg) -> RigidObjectGroup: + """Add a rigid object group to the scene. + + Args: + cfg (RigidObjectGroupCfg): Configuration for the rigid object group. + """ + from embodichain.lab.sim.utility.sim_utils import ( + load_mesh_objects_from_cfg, + ) + + uid = cfg.uid + if uid is None: + logger.log_error("Rigid object group uid must be specified.") + if uid in self._rigid_object_groups: + logger.log_error(f"Rigid object group {uid} already exists.") + + if cfg.body_type == "static": + logger.log_error("Rigid object group cannot be static.") + + env_list = [self._env] if len(self._arenas) == 0 else self._arenas + + obj_group_list = [] + for key, rigid_cfg in tqdm( + cfg.rigid_objects.items(), desc="Loading rigid objects" + ): + obj_list = load_mesh_objects_from_cfg( + cfg=rigid_cfg, + env_list=env_list, + cache_dir=self._convex_decomp_dir, + ) + obj_group_list.append(obj_list) + + # Convert [a1, a2, ...], [b1, b2, ...] to [(a1, b1, ...), (a2, b2, ...), ...] + obj_group_list = list(zip(*obj_group_list)) + rigid_obj_group = RigidObjectGroup( + cfg=cfg, entities=obj_group_list, device=self.device + ) + + self._rigid_object_groups[uid] = rigid_obj_group + + return rigid_obj_group + + def get_rigid_object_group(self, uid: str) -> Optional[RigidObjectGroup]: + """Get a rigid object group by its unique ID. + + Args: + uid (str): The unique ID of the rigid object group. + + Returns: + Optional[RigidObjectGroup]: The rigid object group instance if found, otherwise None. + """ + if uid not in self._rigid_object_groups: + logger.log_warning(f"Rigid object group {uid} not found.") + return None + return self._rigid_object_groups[uid] + + def _get_non_static_rigid_obj_num(self) -> int: + """Get the number of non-static rigid objects in the scene. + + Returns: + int: The number of non-static rigid objects. + """ + count = 0 + for obj in self._rigid_objects.values(): + if obj.cfg.body_type != "static": + count += 1 + return count + + def add_articulation( + self, + cfg: ArticulationCfg, + ) -> Articulation: + """Add an articulation to the scene. + + Args: + cfg (ArticulationCfg): Configuration for the articulation. + + Returns: + Articulation: The added articulation instance handle. + """ + + uid = cfg.uid + if uid is None: + uid = os.path.splitext(os.path.basename(cfg.fpath))[0] + cfg.uid = uid + if uid in self._articulations: + logger.log_error(f"Articulation {uid} already exists.") + + env_list = [self._env] if len(self._arenas) == 0 else self._arenas + obj_list = [] + + for env in env_list: + art = env.load_urdf(cfg.fpath) + obj_list.append(art) + + articulation = Articulation(cfg=cfg, entities=obj_list, device=self.device) + + self._articulations[uid] = articulation + + return articulation + + def get_articulation(self, uid: str) -> Optional[Articulation]: + """Get an articulation by its unique ID. + + Args: + uid (str): The unique ID of the articulation. + + Returns: + Optional[Articulation]: The articulation instance if found, otherwise None. + """ + if uid not in self._articulations: + logger.log_warning(f"Articulation {uid} not found.") + return None + return self._articulations[uid] + + def get_articulation_uid_list(self) -> List[str]: + """Get current articulation uid list + + Returns: + List[str]: list of articulation uid. + """ + return list(self._articulations.keys()) + + def add_robot(self, cfg: RobotCfg) -> Optional[Robot]: + """Add a Robot to the scene. + + Args: + cfg (RobotCfg): Configuration for the robot. + + Returns: + Optional[Robot]: The added robot instance handle, or None if failed. + """ + + uid = cfg.uid + if cfg.fpath is None: + if cfg.urdf_cfg is None: + logger.log_error( + "Robot configuration must have a valid fpath or urdf_cfg." + ) + return None + + cfg.fpath = cfg.urdf_cfg.assemble_urdf() + + if uid is None: + uid = os.path.splitext(os.path.basename(cfg.fpath))[0] + cfg.uid = uid + if uid in self._robots: + logger.log_error(f"Robot {uid} already exists.") + return self._robots[uid] + + env_list = [self._env] if len(self._arenas) == 0 else self._arenas + obj_list = [] + + for env in env_list: + art = env.load_urdf(cfg.fpath) + obj_list.append(art) + + robot = Robot(cfg=cfg, entities=obj_list, device=self.device) + + self._robots[uid] = robot + + return robot + + def get_robot(self, uid: str) -> Optional[Robot]: + """Get a Robot by its unique ID. + + Args: + uid (str): The unique ID of the robot. + + Returns: + Optional[Robot]: The robot instance if found, otherwise None. + """ + if uid not in self._robots: + logger.log_warning(f"Robot {uid} not found.") + return None + return self._robots[uid] + + def get_robot_uid_list(self) -> List[str]: + """ + Retrieves a list of unique identifiers (UIDs) for all robots in the V2 system. + + Returns: + list: A list containing the UIDs of the robots. + """ + return list(self._robots.keys()) + + def enable_gizmo( + self, uid: str, control_part: Optional[str] = None, gizmo_cfg: object = None + ) -> None: + """Enable gizmo control for any simulation object (Robot, RigidObject, Camera, etc.). + + Args: + uid (str): UID of the object to attach gizmo to (searches in robots, rigid_objects, sensors, etc.) + control_part (Optional[str], optional): Control part name for robots. Defaults to "arm". + gizmo_cfg (object, optional): Gizmo configuration object. Defaults to None. + """ + # Create gizmo key combining uid and control_part + gizmo_key = f"{uid}:{control_part}" if control_part else uid + + # Check if gizmo already exists + if gizmo_key in self._gizmos: + logger.log_warning( + f"Gizmo for '{uid}' with control_part '{control_part}' already exists." + ) + return + + # Search for target object in different collections + target = None + object_type = None + + if uid in self._robots: + target = self._robots[uid] + object_type = "robot" + elif uid in self._rigid_objects: + target = self._rigid_objects[uid] + object_type = "rigid_object" + elif uid in self._sensors: + target = self._sensors[uid] + object_type = "sensor" + + else: + logger.log_error( + f"Object with uid '{uid}' not found in any collection (robots, rigid_objects, sensors, articulations)." + ) + return + + try: + gizmo = Gizmo(target, gizmo_cfg, control_part) + self._gizmos[gizmo_key] = gizmo + logger.log_info( + f"Gizmo enabled for {object_type} '{uid}' with control_part '{control_part}'" + ) + + # Initialize GizmoController if not already done. + if not hasattr(self, "_gizmo_controller") or self._gizmo_controller is None: + windows = ( + self._world.get_windows() + if hasattr(self._world, "get_windows") + else None + ) + self._gizmo_controller = GizmoController(windows) + print("GizmoController attributes and methods:") + print(dir(self._gizmo_controller)) + + except Exception as e: + logger.log_error( + f"Failed to create gizmo for {object_type} '{uid}' with control_part '{control_part}': {e}" + ) + + def disable_gizmo(self, uid: str, control_part: Optional[str] = None) -> None: + """Disable and remove gizmo for a robot. + + Args: + uid (str): Object UID to disable gizmo for + control_part (Optional[str], optional): Control part name for robots. Defaults to None. + """ + # Create gizmo key combining uid and control_part + gizmo_key = f"{uid}:{control_part}" if control_part else uid + + if gizmo_key not in self._gizmos: + from embodichain.utils import logger + + logger.log_warning( + f"No gizmo found for '{uid}' with control_part '{control_part}'." + ) + return + + try: + gizmo = self._gizmos[gizmo_key] + if gizmo is not None: + gizmo.destroy() + del self._gizmos[gizmo_key] + + from embodichain.utils import logger + + logger.log_info( + f"Gizmo disabled for '{uid}' with control_part '{control_part}'" + ) + + except Exception as e: + from embodichain.utils import logger + + logger.log_error( + f"Failed to disable gizmo for '{uid}' with control_part '{control_part}': {e}" + ) + + def get_gizmo(self, uid: str, control_part: Optional[str] = None) -> object: + """Get gizmo instance for a robot. + + Args: + uid (str): Object UID + control_part (Optional[str], optional): Control part name for robots. Defaults to None. + + Returns: + object: Gizmo instance if found, None otherwise. + """ + # Create gizmo key combining uid and control_part + gizmo_key = f"{uid}:{control_part}" if control_part else uid + return self._gizmos.get(gizmo_key, None) + + def has_gizmo(self, uid: str, control_part: Optional[str] = None) -> bool: + """Check if a gizmo exists for the given UID and control part. + + Args: + uid (str): Object UID to check + control_part (Optional[str], optional): Control part name for robots. Defaults to None. + + Returns: + bool: True if gizmo exists, False otherwise. + """ + # Create gizmo key combining uid and control_part + gizmo_key = f"{uid}:{control_part}" if control_part else uid + return gizmo_key in self._gizmos + + def list_gizmos(self) -> Dict[str, bool]: + """List all active gizmos and their status. + + Returns: + Dict[str, bool]: Dictionary mapping gizmo keys (uid:control_part) to gizmo active status. + """ + return { + gizmo_key: (gizmo is not None) for gizmo_key, gizmo in self._gizmos.items() + } + + def update_gizmos(self): + """Update all active gizmos.""" + for gizmo_key, gizmo in list( + self._gizmos.items() + ): # Use list() to avoid modification during iteration + if gizmo is not None: + try: + gizmo.update() + except Exception as e: + from embodichain.utils import logger + + logger.log_error(f"Error updating gizmo '{gizmo_key}': {e}") + + def toggle_gizmo_visibility( + self, uid: str, control_part: Optional[str] = None + ) -> bool: + """ + Toggle the visibility of a gizmo by uid and optional control_part. + Returns the new visibility state (True=visible, False=hidden), or None if not found. + """ + gizmo = self.get_gizmo(uid, control_part) + if gizmo is not None: + return gizmo.toggle_visibility() + return None + + def set_gizmo_visibility( + self, uid: str, visible: bool, control_part: Optional[str] = None + ) -> None: + """ + Set the visibility of a gizmo by uid and optional control_part. + """ + gizmo = self.get_gizmo(uid, control_part) + if gizmo is not None: + gizmo.set_visibility(visible) + + def add_sensor(self, sensor_cfg: SensorCfg) -> BaseSensor: + """General interface to add a sensor to the scene and returns a handle. + + Args: + sensor_cfg (SensorCfg): configuration for the sensor. + + Returns: + BaseSensor: The added sensor instance handle. + """ + sensor_type = sensor_cfg.sensor_type + if sensor_type not in self.SUPPORTED_SENSOR_TYPES: + logger.log_warning(f"Unsupported sensor type: {sensor_type}") + return None + + sensor_uid = sensor_cfg.uid + if sensor_uid is None: + sensor_uid = f"{sensor_type.lower()}_{len(self._sensors)}" + sensor_cfg.uid = sensor_uid + + if sensor_uid in self._sensors: + logger.log_warning(f"Sensor {sensor_uid} already exists.") + return None + + sensor = self.SUPPORTED_SENSOR_TYPES[sensor_type](sensor_cfg, self.device) + + self._sensors[sensor_uid] = sensor + + # Check if the sensor needs to change the parent frame. + + return sensor + + def get_sensor(self, uid: str) -> Optional[BaseSensor]: + """Get a sensor by its UID. + + Args: + uid (str): The UID of the sensor. + + Returns: + BaseSensor | None: The sensor instance if found, otherwise None. + """ + if uid not in self._sensors: + logger.log_warning(f"Sensor {uid} not found.") + return None + return self._sensors[uid] + + def get_sensor_uid_list(self) -> List[str]: + """Get current sensor uid list + + Returns: + List[str]: list of sensor uid. + """ + return list(self._sensors.keys()) + + def remove_asset(self, uid: str) -> bool: + """Remove an asset by its UID. + + The asset can be a light, sensor, robot, rigid object or articulation. + + Note: + Currently, lights and sensors are not supported to be removed. + + Args: + uid (str): The UID of the asset. + Returns: + bool: True if the asset is removed successfully, otherwise False. + """ + if uid in self._rigid_objects: + obj = self._rigid_objects.pop(uid) + obj.destroy() + return True + + if uid in self._soft_objects: + obj = self._soft_objects.pop(uid) + obj.destroy() + return True + + if uid in self._rigid_object_groups: + group = self._rigid_object_groups.pop(uid) + group.destroy() + return True + + if uid in self._articulations: + art = self._articulations.pop(uid) + art.destroy() + return True + + if uid in self._robots: + robot = self._robots.pop(uid) + robot.destroy() + return True + + return False + + def get_asset( + self, uid: str + ) -> Optional[Union[RigidObject, Articulation, Robot, Light, BaseSensor]]: + """Get an asset by its UID. + + The asset can be a rigid object, articulation or robot. + + Args: + uid (str): The UID of the asset. + """ + if uid in self._rigid_objects: + return self._rigid_objects[uid] + + if uid in self._articulations: + return self._articulations[uid] + + if uid in self._robots: + return self._robots[uid] + + if uid in self._lights: + return self._lights[uid] + + if uid in self._sensors: + return self._sensors[uid] + + logger.log_warning(f"Asset {uid} not found.") + return None + + def draw_marker( + self, + cfg: MarkerCfg, + ) -> MeshObject: + """Draw visual markers in the simulation scene for debugging and visualization. + + Args: + cfg (MarkerCfg): Marker configuration with the following key parameters: + - name (str): Unique identifier for the marker group + - marker_type (str): Type of marker ("axis" currently supported) + - axis_xpos (np.ndarray | List[np.ndarray]): 4x4 transformation matrices + for marker positions and orientations + - axis_size (float): Thickness of axis arrows + - axis_len (float): Length of axis arrows + - arena_index (int): Arena index for placement (-1 for global) + + Returns: + List[MeshObject]: List of created marker handles, False if invalid input, + None if no poses provided. + + Example: + ```python + cfg = MarkerCfg(name="test_axis", marker_type="axis", axis_xpos=np.eye(4)) + markers = sim.draw_marker(cfg) + ``` + """ + # Validate marker type + if cfg.marker_type != "axis": + logger.log_error( + f"Unsupported marker type '{cfg.marker_type}'. Currently only 'axis' is supported." + ) + return False + + draw_xpos = deepcopy(cfg.axis_xpos) + draw_xpos = np.array(draw_xpos) + if draw_xpos.ndim == 2: + if draw_xpos.shape == (4, 4): + draw_xpos = np.expand_dims(draw_xpos, axis=0) + else: + logger.log_error( + f"axis_xpos must be of shape (N, 4, 4), got {draw_xpos.shape}." + ) + return False + elif draw_xpos.ndim != 3 or draw_xpos.shape[1:] != (4, 4): + logger.log_error( + f"axis_xpos must be of shape (N, 4, 4), got {draw_xpos.shape}." + ) + return False + + original_name = cfg.name + name = original_name + count = 0 + + while name in self._markers: + count += 1 + name = f"{original_name}_{count}" + if count > 0: + logger.log_warning( + f"Marker name '{original_name}' already exists. Using '{name}'." + ) + + marker_num = len(draw_xpos) + if marker_num == 0: + logger.log_warning(f"No marker poses provided.") + return None + + if cfg.arena_index >= 0: + name = f"{name}_{cfg.arena_index}" + + env = self.get_env(cfg.arena_index) + + # Create markers based on marker type + marker_handles = [] + + if cfg.marker_type == "axis": + # Create coordinate axes + axis_option = dexsim.types.AxisOption( + lx=cfg.axis_len, + ly=cfg.axis_len, + lz=cfg.axis_len, + size=cfg.axis_size, + arrow_type=cfg.arrow_type, + corner_type=cfg.corner_type, + tag_type=dexsim.types.AxisTagType.NONE, + ) + + for i, pose in enumerate(draw_xpos): + axis_handle = env.create_axis(axis_option) + axis_handle.set_local_pose(pose) + marker_handles.append(axis_handle) + + # TODO: Add support for other marker types in the future + # elif cfg.marker_type == "line": + # # Create line markers + # pass + # elif cfg.marker_type == "point": + # # Create point markers + # pass + + self._markers[name] = (marker_handles, cfg.arena_index) + + if self.is_physics_manually_update: + self.update(step=1) + + return marker_handles + + def remove_marker(self, name: str) -> bool: + """Remove markers (including axis) with the given name. + + Args: + name (str): The name of the marker to remove. + Returns: + bool: True if the marker was removed successfully, False otherwise. + """ + if name not in self._markers: + logger.log_warning(f"Marker {name} not found.") + return False + try: + env = self.get_env(self._markers[name][1]) + marker_handles, arena_index = self._markers[name] + for marker_handle in marker_handles: + if marker_handle is not None: + env.remove_actor(marker_handle.get_name()) + self._markers.pop(name) + return True + except Exception as e: + logger.log_warning(f"Failed to remove marker {name}: {str(e)}") + return False + + def create_visual_material(self, cfg: VisualMaterialCfg) -> VisualMaterial: + """Create a visual material with given configuration. + + Args: + cfg (VisualMaterialCfg): configuration for the visual material. + + Returns: + VisualMaterial: the created visual material instance handle. + """ + + if cfg.uid in self._visual_materials: + logger.log_warning( + f"Visual material {cfg.uid} already exists. Returning the existing one." + ) + return self._visual_materials[cfg.uid] + + mat: Material = self._env.create_pbr_material(cfg.uid, True) + visual_mat = VisualMaterial(cfg, mat) + + self._visual_materials[cfg.uid] = visual_mat + return visual_mat + + def get_visual_material(self, uid: str) -> VisualMaterial: + """Get visual material by UID. + + Args: + uid (str): uid of visual material. + """ + if uid not in self._visual_materials: + logger.log_warning(f"Visual material {uid} not found.") + return None + + return self._visual_materials[uid] + + def clean_materials(self): + self._visual_materials = {} + self._env.clean_materials() + + def reset_objects_state(self, env_ids: Optional[Sequence[int]] = None) -> None: + """Reset the state of all objects in the scene. + + Args: + env_ids: The environment IDs to reset. If None, reset all environments. + """ + for robot in self._robots.values(): + robot.reset(env_ids) + for articulation in self._articulations.values(): + articulation.reset(env_ids) + for rigid_obj in self._rigid_objects.values(): + rigid_obj.reset(env_ids) + for rigid_obj_group in self._rigid_object_groups.values(): + rigid_obj_group.reset(env_ids) + for light in self._lights.values(): + light.reset(env_ids) + for sensor in self._sensors.values(): + sensor.reset(env_ids) + + def destroy(self) -> None: + """Destroy all simulated assets and release resources.""" + # Clean up all gizmos before destroying the simulation + for uid in list(self._gizmos.keys()): + self.disable_gizmo(uid) + + self.clean_materials() + + self._env.clean() + self._world.quit() diff --git a/embodichain/lab/sim/solvers/__init__.py b/embodichain/lab/sim/solvers/__init__.py new file mode 100644 index 00000000..6f6123fa --- /dev/null +++ b/embodichain/lab/sim/solvers/__init__.py @@ -0,0 +1,23 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from .base_solver import SolverCfg, BaseSolver, merge_solver_cfg +from .pytorch_solver import PytorchSolverCfg, PytorchSolver +from .pinocchio_solver import PinocchioSolverCfg, PinocchioSolver +from .differential_solver import DifferentialSolverCfg, DifferentialSolver +from .pink_solver import PinkSolverCfg, PinkSolver +from .opw_solver import OPWSolverCfg, OPWSolver +from .srs_solver import SRSSolverCfg, SRSSolver diff --git a/embodichain/lab/sim/solvers/base_solver.py b/embodichain/lab/sim/solvers/base_solver.py new file mode 100644 index 00000000..af7433cc --- /dev/null +++ b/embodichain/lab/sim/solvers/base_solver.py @@ -0,0 +1,457 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch +import numpy as np +from typing import List, Optional, Dict, Any, Union, TYPE_CHECKING, Tuple +from abc import abstractmethod, ABCMeta + +from embodichain.utils import configclass, logger + +if TYPE_CHECKING: + from typing import Self + +from embodichain.lab.sim.utility.solver_utils import create_pk_serial_chain + + +@configclass +class SolverCfg: + """Configuration for the kinematic solver used in the robot simulation.""" + + class_type: str = "BaseSolver" + """The class type of the solver to be used.""" + + urdf_path: Optional[str] = None + """The file path to the URDF model of the robot.""" + + joint_names: Optional[list[str]] = None + """List of joint names for the solver. + + If None, all joints in the URDF will be used. + If specified, only these named joints will be included in the kinematic chain. + """ + + end_link_name: str = None + """The name of the end-effector link for the solver. + + This defines the target link for forward/inverse kinematics calculations. + Must match a link name in the URDF file. + """ + + root_link_name: str = None + """The name of the root/base link for the solver. + + This defines the starting point of the kinematic chain. + Must match a link name in the URDF file. + """ + + # TODO: may be support pos and rot separately for easier manipulation. + tcp: Union[torch.Tensor, np.ndarray] = np.eye(4) + """The tool center point (TCP) position as a 4x4 homogeneous matrix. + + This represents the position and orientation of the tool in the robot's end-effector frame. + """ + + ik_nearest_weight: Optional[List[float]] = None + """Weights for the inverse kinematics nearest calculation. + + The weights influence how the solver prioritizes closeness to the seed position + when multiple solutions are available. + """ + + @abstractmethod + def init_solver(self, device: torch.device, **kwargs) -> "BaseSolver": + pass + + @classmethod + def from_dict(cls, init_dict: Dict[str, Any]) -> "SolverCfg": + """Initialize the configuration from a dictionary.""" + from embodichain.utils.utility import get_class_instance + + if "class_type" not in init_dict: + logger.log_error("class type must be specified in the configuration.") + + cfg = get_class_instance( + "embodichain.lab.sim.solvers", init_dict["class_type"] + "Cfg" + )() + for key, value in init_dict.items(): + if hasattr(cfg, key): + setattr(cfg, key, value) + else: + logger.log_warning( + f"Key '{key}' not found in {cfg.__class__.__name__}." + ) + return cfg + + +class BaseSolver(metaclass=ABCMeta): + def __init__(self, cfg: SolverCfg = None, device: str = None, **kwargs): + r"""Initializes the kinematics solver with a robot model. + + Args: + cfg (SolverCfg): The configuration for the solver. + device (str or torch.device, optional): The device to run the solver on. Defaults to "cuda" if available, otherwise "cpu". + **kwargs: Additional keyword arguments for customization. + """ + self.cfg = cfg + + if device is None: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + elif isinstance(device, str): + self.device = torch.device(device) + else: + self.device = device + + self.urdf_path = cfg.urdf_path + + self.joint_names = cfg.joint_names + + self.end_link_name = cfg.end_link_name + + self.root_link_name = cfg.root_link_name + + # TODO: Check whether the joint name is revolute or prismatic + # Degrees of freedom of robot joints + self.dof = len(self.joint_names) if self.joint_names else 0 + + # Weight for nearest neighbor search in IK (Inverse Kinematics) algorithms + if cfg.ik_nearest_weight is not None: + if len(cfg.ik_nearest_weight) != self.dof: + logger.log_error( + f"Length of ik_nearest_weight ({len(cfg.ik_nearest_weight)}) does not match the number of DOF ({self.dof})." + ) + self.ik_nearest_weight = torch.tensor( + cfg.ik_nearest_weight, dtype=torch.float32, device=self.device + ) + else: + self.ik_nearest_weight = torch.ones( + self.dof, dtype=torch.float32, device=self.device + ) + + self.tcp_xpos = np.eye(4) + + self.pk_serial_chain = kwargs.get("pk_serial_chain", None) + if self.pk_serial_chain is None: + self.pk_serial_chain = create_pk_serial_chain( + urdf_path=self.urdf_path, + end_link_name=self.end_link_name, + root_link_name=self.root_link_name, + device=self.device, + ) + + def set_ik_nearest_weight( + self, ik_weight: np.ndarray, joint_ids: np.ndarray = None + ) -> bool: + r"""Sets the inverse kinematics nearest weight. + + Args: + ik_weight (np.ndarray): A numpy array representing the nearest weights for inverse kinematics. + joint_ids (np.ndarray, optional): A numpy array representing the indices of the joints to which the weights apply. + If None, defaults to all joint indices. + + Returns: + bool: True if the weights are set successfully, False otherwise. + """ + ik_weight = np.array(ik_weight) + + # Set joint_ids to all joint indices if it is None + if joint_ids is None: + joint_ids = np.arange(self.dof) + + joint_ids = np.array(joint_ids) + + # Check if joint_ids has valid indices + if np.any(joint_ids >= self.dof) or np.any(joint_ids < 0): + logger.log_warning( + "joint_ids must contain valid indices between 0 and {}.".format( + self.dof - 1 + ) + ) + return False + + # Check if ik_weight and joint_ids have the same length + if ik_weight.shape[0] != joint_ids.shape[0]: + logger.log_warning("ik_weight and joint_ids must have the same length.") + return False + + # Initialize the weights + if self.ik_nearest_weight is None: + # If ik_nearest_weight is None, set all weights to 1 + self.ik_nearest_weight = np.ones(self.dof) + + # Set specific weights for joint_ids to the provided ik_weight + for i, joint_id in enumerate(joint_ids): + self.ik_nearest_weight[joint_id] = ik_weight[i] + else: + # If ik_nearest_weight is not None, only fill joint_ids + for i, joint_id in enumerate(joint_ids): + self.ik_nearest_weight[joint_id] = ik_weight[i] + + return True + + def get_ik_nearest_weight(self): + r"""Gets the inverse kinematics nearest weight. + + Returns: + np.ndarray: A numpy array representing the nearest weights for inverse kinematics. + """ + return self.ik_nearest_weight + + def set_position_limits( + self, + lower_position_limits: List[float], + upper_position_limits: List[float], + ) -> bool: + r"""Sets the upper and lower joint position limits. + + Parameters: + lower_position_limits (List[float]): A list of lower limits for each joint. + upper_position_limits (List[float]): A list of upper limits for each joint. + + Returns: + bool: True if limits are successfully set, False if the input is invalid. + """ + if ( + len(lower_position_limits) != self.model.nq + or len(upper_position_limits) != self.model.nq + ): + logger.log_warning("Length of limits must match the number of joints.") + return False + + if any( + lower > upper + for lower, upper in zip(lower_position_limits, upper_position_limits) + ): + logger.log_warning( + "Each lower limit must be less than or equal to the corresponding upper limit." + ) + return False + + self.lower_position_limits = np.array(lower_position_limits) + self.upper_position_limits = np.array(upper_position_limits) + return True + + def get_position_limits(self) -> dict: + r"""Returns the current joint position limits. + + Returns: + dict: A dictionary containing: + - lower_position_limits (List[float]): The current lower limits for each joint. + - upper_position_limits (List[float]): The current upper limits for each joint. + """ + return { + "lower_position_limits": self.lower_position_limits.tolist(), + "upper_position_limits": self.upper_position_limits.tolist(), + } + + def set_tcp(self, xpos: np.ndarray): + r"""Sets the TCP position with the given 4x4 homogeneous matrix. + + Args: + xpos (np.ndarray): The 4x4 homogeneous matrix to be set as the TCP position. + + Raises: + ValueError: If the input is not a 4x4 numpy array. + """ + xpos = np.array(xpos) + if xpos.shape != (4, 4): + raise ValueError("Input must be a 4x4 homogeneous matrix") + self.tcp_xpos = xpos + + def get_tcp(self) -> np.ndarray: + r"""Returns the current TCP position. + + Returns: + np.ndarray: The current TCP position. + + Raises: + ValueError: If the TCP position has not been set. + """ + return self.tcp_xpos + + @abstractmethod + def get_ik( + self, + target_pose: torch.Tensor, + joint_seed: Optional[torch.Tensor] = None, + num_samples: Optional[int] = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Computes the inverse kinematics for a given target pose. + + This method generates random joint configurations within the specified limits, + including the provided joint_seed, and attempts to find valid inverse kinematics solutions. + It then identifies the joint position that is closest to the joint_seed. + + Args: + target_pose (torch.Tensor): The target pose represented as a 4x4 transformation matrix. + joint_seed (Optional[torch.Tensor]): The initial joint positions used as a seed. + num_samples (Optional[int]): The number of random joint seeds to generate. + **kwargs: Additional keyword arguments for customization. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - success (torch.Tensor): Boolean tensor indicating IK solution validity for each environment, shape (num_envs,). + - target_joints (torch.Tensor): Computed target joint positions, shape (num_envs, num_joints). + """ + pass + + def get_fk(self, qpos: torch.tensor, **kwargs) -> torch.Tensor: + r""" + Computes the forward kinematics for the end-effector link. + + Args: + qpos (torch.Tensor): Joint positions. Can be a single configuration (dof,) or a batch (batch_size, dof). + **kwargs: Additional keyword arguments for customization. + + Returns: + torch.Tensor: The homogeneous transformation matrix of the end link with TCP applied. + Shape is (4, 4) for single input, or (batch_size, 4, 4) for batch input. + """ + tcp_xpos = torch.as_tensor( + self.tcp_xpos, device=self.device, dtype=torch.float32 + ) + qpos = torch.as_tensor(qpos, dtype=torch.float32, device=self.device) + + # Compute forward kinematics + result = self.pk_serial_chain.forward_kinematics( + qpos, end_only=(self.end_link_name is None) + ) + + # Extract transformation matrices + if isinstance(result, dict): + matrices = result[self.end_link_name].get_matrix() + elif isinstance(result, list): + matrices = torch.stack([xpos.get_matrix().squeeze() for xpos in result]) + else: + matrices = result.get_matrix() + + # Ensure batch format + if matrices.dim() == 2: + matrices = matrices.unsqueeze(0) + + # Create result tensor with proper homogeneous coordinates + result = ( + torch.eye(4, device=self.device).expand(matrices.shape[0], 4, 4).clone() + ) + result[:, :3, :] = matrices[:, :3, :] + + # Ensure batch format for TCP + batch_size = result.shape[0] + tcp_xpos_batch = tcp_xpos.unsqueeze(0).expand(batch_size, -1, -1) + + # Apply TCP transformation + return torch.bmm(result, tcp_xpos_batch) + + def get_jacobian( + self, + qpos: torch.Tensor, + locations: Optional[Union[torch.Tensor, np.ndarray]] = None, + jac_type: str = "full", + ) -> torch.Tensor: + r"""Compute the Jacobian matrix for the given joint positions. + + Args: + qpos (torch.Tensor): The joint positions. Shape: (dof,) or (batch_size, dof). + locations (Optional[torch.Tensor]): The offset points (relative to the end-effector coordinate system). + Shape: (batch_size, 3) or (3,) for a single offset. + jac_type (str, optional): 'full', 'trans', or 'rot' for full, translational, or rotational Jacobian. + Defaults to 'full'. + + Returns: + torch.Tensor: The Jacobian matrix. Shape: + - (batch_size, 6, dof) for 'full' + - (batch_size, 3, dof) for 'trans' or 'rot' + """ + if qpos is None: + qpos = torch.zeros(self.dof, device=self.device) + + # Ensure qpos is a tensor + qpos = torch.as_tensor(qpos, dtype=torch.float32, device=self.device) + + # Ensure locations is a tensor if provided + if locations is not None: + locations = torch.as_tensor( + locations, dtype=torch.float32, device=self.device + ) + + # Compute the Jacobian using the kinematics chain + J = self.pk_serial_chain.jacobian(th=qpos, locations=locations) + + # Handle jac_type to return the desired part of the Jacobian + if jac_type == "trans": + return J[:, :3, :] if J.dim() == 3 else J[:3, :] + elif jac_type == "rot": + return J[:, 3:, :] if J.dim() == 3 else J[3:, :] + elif jac_type == "full": + return J + else: + raise ValueError( + f"Invalid jac_type '{jac_type}'. Must be 'full', 'trans', or 'rot'." + ) + + +def merge_solver_cfg( + default: Dict[str, SolverCfg], provided: Dict[str, Any] +) -> Dict[str, SolverCfg]: + """Merge provided solver configuration into the default solver config. + + Rules: + - For each arm key in provided, if the key exists in default, update fields provided. + - If a provided value is a dict, update attributes on the SolverCfg-like object (or dict) by setting keys. + - Primitive values or arrays/lists replace the target value. + - Unknown keys in provided create new entries in the result. + """ + + result = {} + # copy defaults shallowly + for k, v in default.items(): + result[k] = v + + for k, v in provided.items(): + if k in result: + target = result[k] + # if target has __dict__ or is a dataclass-like, set attrs + if hasattr(target, "__dict__") or isinstance(target, dict): + # if provided is a dict, set/override attributes + if isinstance(v, dict): + for sub_k, sub_v in v.items(): + # try to set attribute if possible, otherwise assign into dict + if hasattr(target, sub_k): + try: + setattr(target, sub_k, sub_v) + except Exception: + # fallback to dict assignment if object doesn't accept + try: + target[sub_k] = sub_v + except Exception: + pass + else: + try: + target[sub_k] = sub_v + except Exception: + setattr(target, sub_k, sub_v) + else: + # non-dict provided value replaces the target entirely + result[k] = v + else: + # target is a primitive, replace + result[k] = v + else: + # new solver entry provided; include as-is + result[k] = v + + return result diff --git a/embodichain/lab/sim/solvers/differential_solver.py b/embodichain/lab/sim/solvers/differential_solver.py new file mode 100644 index 00000000..5bbb8028 --- /dev/null +++ b/embodichain/lab/sim/solvers/differential_solver.py @@ -0,0 +1,424 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch +from copy import deepcopy +from typing import Optional, Union, Tuple, Any, Literal, TYPE_CHECKING +from scipy.spatial.transform import Rotation + +from embodichain.utils import configclass, logger +from embodichain.lab.sim.solvers import SolverCfg, BaseSolver +from embodichain.utils.math import ( + apply_delta_pose, + compute_pose_error, +) + + +if TYPE_CHECKING: + from typing import Self + + +@configclass +class DifferentialSolverCfg(SolverCfg): + """Configuration for differential inverse kinematics controller.""" + + class_type: str = "DifferentialSolver" + + pos_eps: float = 5e-4 # Tolerance for convergence for position + rot_eps: float = 5e-4 # Tolerance for convergence for rotation + max_iterations: int = 1000 # Maximum number of iterations for the solver + + # Constraint configuration + is_only_position_constraint: bool = ( + False # Whether to only consider position constraints + ) + + # Type of task-space command to control the articulation's body. + command_type: Literal["position", "pose"] = "pose" + + # Whether to use relative mode for the controller. + use_relative_mode: bool = False + + # Method for computing inverse of Jacobian.""" + ik_method: Literal["pinv", "svd", "trans", "dls"] = "pinv" + + # Parameters for the inverse-kinematics method. + ik_params: Optional[dict] = None + + def __post_init__(self): + # Default parameters for different inverse kinematics approaches + default_ik_params = { + "pinv": {"k_val": 1.0}, + "svd": {"k_val": 1.0, "min_singular_value": 1e-5}, + "trans": {"k_val": 1.0}, + "dls": {"lambda_val": 0.01}, + } + + # Update parameters for IK-method if not provided + params = self.ik_params or {} + self.ik_params = {**default_ik_params[self.ik_method], **params} + + def init_solver( + self, num_envs: int = 1, device: torch.device = torch.device("cpu"), **kwargs + ) -> "DifferentialSolver": + """Initialize the solver with the configuration. + + Args: + device (torch.device): The device to use for the solver. Defaults to CPU. + num_envs (int): The number of environments for which the solver is initialized. + **kwargs: Additional keyword arguments that may be used for solver initialization. + + Returns: + DifferentialSolver: An initialized solver instance. + """ + + solver = DifferentialSolver( + cfg=self, num_envs=num_envs, device=device, **kwargs + ) + + # Set the Tool Center Point (TCP) for the solver + if isinstance(self.tcp, torch.Tensor): + tcp = self.tcp.cpu().numpy() + else: + tcp = self.tcp + solver.set_tcp(tcp) + + return solver + + +class DifferentialSolver(BaseSolver): + r"""Differential inverse kinematics (IK) controller. + + This controller implements differential inverse kinematics using various methods for + computing the inverse of the Jacobian matrix. + """ + + def __init__( + self, + cfg: DifferentialSolverCfg, + num_envs: int = 1, + device: str = "cpu", + **kwargs, + ): + r"""Initializes the differential kinematics solver. + + This constructor sets up the kinematics solver using differential methods, + allowing for efficient computation of robot kinematics based on + the specified URDF model. + + Args: + cfg: The configuration for the solver. + num_envs (int): The number of environments for the solver. Defaults to 1. + device (str, optional): The device to use for the solver (e.g., "cpu" or "cuda"). Defaults to "cpu". + **kwargs: Additional keyword arguments passed to the base solver. + + """ + super().__init__(cfg=cfg, num_envs=num_envs, device=device, **kwargs) + + # Initialize buffers + self.ee_pos_des = torch.zeros(num_envs, 3, device=device) + self.ee_quat_des = torch.zeros(num_envs, 4, device=device) + self._command = torch.zeros(num_envs, self.action_dim, device=device) + + @property + def action_dim(self) -> int: + """Returns the dimension of the controller's input command. + + Returns: + int: The dimension of the input command. + """ + if self.cfg.command_type == "position": + return 3 # (x, y, z) + elif self.cfg.command_type == "pose" and self.cfg.use_relative_mode: + return 6 # (dx, dy, dz, droll, dpitch, dyaw) + else: + return 7 # (x, y, z, qw, qx, qy, qz) + + def reset(self, env_ids: Optional[torch.Tensor] = None): + """Reset the internal buffers for the specified environments. + + Args: + env_ids (Optional[torch.Tensor]): The environment indices to reset. If None, reset all. + """ + if env_ids is None: + env_ids = torch.arange(self.num_envs, device=self.device) + + self.ee_pos_des[env_ids] = 0 + self.ee_quat_des[env_ids] = torch.tensor([1.0, 0, 0, 0], device=self.device) + self._command[env_ids] = 0 + + def set_command( + self, + command: torch.Tensor, + ee_pos: Optional[torch.Tensor] = None, + ee_quat: Optional[torch.Tensor] = None, + ) -> bool: + """Set the target end-effector pose command. + + Args: + command (torch.Tensor): The command tensor. + ee_pos (Optional[torch.Tensor]): Current end-effector position (for relative mode). + ee_quat (Optional[torch.Tensor]): Current end-effector quaternion (for relative mode). + + Returns: + bool: True if the command was set successfully, False otherwise. + """ + # TODO: Init solver with correct batch size + batch_size = command.shape[0] + if self._command.shape[0] != batch_size: + device = command.device + self._command = torch.zeros(batch_size, self.action_dim, device=device) + self.ee_pos_des = torch.zeros(batch_size, 3, device=device) + self.ee_quat_des = torch.zeros(batch_size, 4, device=device) + self._command[:] = command + + if self.cfg.command_type == "position": + if ee_quat is None: + logger.log_warning( + "End-effector orientation cannot be None for position control" + ) + return False + + if self.cfg.use_relative_mode: + if ee_pos is None: + logger.log_warning("Current position required for relative mode") + return False + self.ee_pos_des[:] = ee_pos + self._command + self.ee_quat_des[:] = ee_quat + else: + self.ee_pos_des[:] = self._command + self.ee_quat_des[:] = ee_quat + else: + if self.cfg.use_relative_mode: + if ee_pos is None or ee_quat is None: + logger.log_warning("Current pose required for relative mode") + return False + self.ee_pos_des, self.ee_quat_des = apply_delta_pose( + ee_pos, ee_quat, self._command + ) + else: + self.ee_pos_des = self._command[:, 0:3] + self.ee_quat_des = self._command[:, 3:7] + + return True + + def get_ik( + self, + target_xpos: torch.Tensor, + qpos_seed: torch.Tensor = None, + return_all_solutions: bool = False, + jacobian: torch.Tensor = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute target joint positions using differential inverse kinematics. + + Args: + target_xpos (torch.Tensor): Current end-effector position, shape (num_envs, 3). + qpos_seed (torch.Tensor): Current joint positions, shape (num_envs, num_joints). Defaults to zeros. + return_all_solutions (bool, optional): Whether to return all IK solutions or just the best one. Defaults to False. + jacobian (torch.Tensor): Jacobian matrix, shape (num_envs, 6, num_joints). + **kwargs: Additional keyword arguments for future extensions. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - success (torch.Tensor): Boolean tensor indicating IK solution validity for each environment, shape (num_envs,). + - target_joints (torch.Tensor): Computed target joint positions, shape (num_envs, num_joints). + """ + if qpos_seed is None: + qpos_seed = torch.zeros(self.dof, device=self.device) + + if jacobian is None: + jacobian = self.get_jacobian(qpos_seed) + current_xpos = self.get_fk(qpos_seed, to_matrix=True) + + # Transform target_xpos by TCP + tcp_xpos = torch.as_tensor( + deepcopy(self.tcp_xpos), device=self.device, dtype=torch.float32 + ) + current_xpos = current_xpos @ torch.inverse(tcp_xpos) + compute_xpos = target_xpos @ torch.inverse(tcp_xpos) + + # Ensure compute_xpos is a batch of matrices + if current_xpos.dim() == 2 and current_xpos.shape == (4, 4): + current_xpos = current_xpos.unsqueeze(0) + + # Ensure compute_xpos is a batch of matrices + if compute_xpos.dim() == 2 and compute_xpos.shape == (4, 4): + compute_xpos = compute_xpos.unsqueeze(0) + + compute_pose = self._matrix_to_pos_quat(compute_xpos) + self.set_command(command=compute_pose) + + qpos = qpos_seed + num_iter = 1 if self.cfg.max_iterations == 1 else self.cfg.max_iterations + for i in range(num_iter): + current_pose = self._matrix_to_pos_quat(current_xpos) + ee_pos = current_pose[:, :3] + ee_quat = current_pose[:, 3:] + + if self.cfg.command_type == "position": + position_error = self.ee_pos_des - ee_pos + jacobian_pos = jacobian[:, :3] + delta_joint_pos = self._compute_delta_joint_pos( + delta_pose=position_error, jacobian=jacobian_pos + ) + else: + pos_error, rot_error = compute_pose_error( + ee_pos, ee_quat, self.ee_pos_des, self.ee_quat_des + ) + pose_error = torch.cat((pos_error, rot_error), dim=1) + delta_joint_pos = self._compute_delta_joint_pos( + delta_pose=pose_error, jacobian=jacobian + ) + + qpos = qpos + delta_joint_pos + current_xpos = self.get_fk(qpos) + + # Ensure current_xpos and target_xpos are batches of matrices + if current_xpos.dim() == 2 and current_xpos.shape == (4, 4): + current_xpos = current_xpos.unsqueeze(0) + + if target_xpos.dim() == 2 and target_xpos.shape == (4, 4): + target_xpos = target_xpos.unsqueeze(0) + + pos_converged = ( + torch.norm(current_xpos[:, :3, 3] - target_xpos[:, :3, 3], dim=1) + < self.cfg.pos_eps + ) + rot_converged = ( + torch.norm(current_xpos[:, :3, :3] - target_xpos[:, :3, :3], dim=(1, 2)) + < self.cfg.rot_eps + ) + + if self.cfg.is_only_position_constraint: + if pos_converged.all(): + break + else: + if (pos_converged & rot_converged).all(): + break + + if return_all_solutions: + logger.log_warning( + "return_all_solutions=True is not supported in DifferentialSolver. Returning the best solution only." + ) + + if self.cfg.is_only_position_constraint: + success = pos_converged + else: + success = pos_converged & rot_converged + + return success, qpos + + # Helper functions + def _compute_delta_joint_pos( + self, delta_pose: torch.Tensor, jacobian: torch.Tensor + ) -> torch.Tensor: + """Compute joint-space delta using the specified IK method. + + Args: + delta_pose (torch.Tensor): The pose error tensor. + jacobian (torch.Tensor): The Jacobian matrix. + + Returns: + torch.Tensor: The joint-space delta tensor. + """ + method = self.cfg.ik_method + params = self.cfg.ik_params + + # compute the delta in joint-space + if method == "pinv": # Jacobian pseudo-inverse + # params + k_val = params["k_val"] + # compute + jacobian_pinv = torch.linalg.pinv(jacobian) + delta_joint_pos = k_val * ( + jacobian_pinv @ delta_pose.unsqueeze(-1) + ).squeeze(-1) + elif method == "svd": + # params + k_val = params["k_val"] + min_singular_value = params["min_singular_value"] + # compute + # U: 6xd, S: dxd, V: d x num-joint + U, S, Vh = torch.linalg.svd(jacobian, full_matrices=False) + S_inv = 1.0 / S + S_inv = torch.where(S > min_singular_value, S_inv, torch.zeros_like(S_inv)) + jacobian_pinv = ( + torch.transpose(Vh, 1, 2)[:, :, :6] + @ torch.diag_embed(S_inv) + @ torch.transpose(U, 1, 2) + ) + delta_joint_pos = k_val * ( + jacobian_pinv @ delta_pose.unsqueeze(-1) + ).squeeze(-1) + elif method == "trans": + # params + k_val = params["k_val"] + # compute + jacobian_T = torch.transpose(jacobian, 1, 2) + delta_joint_pos = params["k_val"] * ( + jacobian_T @ delta_pose.unsqueeze(-1) + ).squeeze(-1) + elif method == "dls": + # params + lambda_val = self.cfg.ik_params["lambda_val"] + # compute + jacobian_T = torch.transpose(jacobian, 1, 2) + lambda_matrix = (lambda_val**2) * torch.eye( + jacobian.shape[1], device=self.device + ) + delta_joint_pos = ( + jacobian_T + @ torch.linalg.solve( + jacobian @ jacobian_T + lambda_matrix, delta_pose.unsqueeze(-1) + ) + ).squeeze(-1) + else: + raise ValueError(f"Unsupported IK method: {method}") + + return delta_joint_pos + + @staticmethod + def _matrix_to_pos_quat(mat): + """Convert a transformation matrix to position and quaternion. + + Args: + mat (torch.Tensor): Transformation matrix tensor of shape (N, 4, 4). + + Returns: + torch.Tensor: Concatenated position and quaternion tensor of shape (N, 7). + """ + # Ensure mat is a batch of matrices + if mat.dim() == 2 and mat.shape == (4, 4): + mat = mat.unsqueeze(0) # Convert (4, 4) to (1, 4, 4) + elif mat.dim() != 3 or mat.shape[1:] != (4, 4): + raise ValueError( + f"Expected mat to have shape (N, 4, 4), but got {mat.shape}" + ) + + # Extract position + pos = mat[:, :3, 3] + + # Extract rotation matrix and convert to quaternion + rot_matrices = mat[:, :3, :3].cpu().numpy() # Convert to NumPy for scipy + quats = Rotation.from_matrix(rot_matrices).as_quat() # (N, 4), [x, y, z, w] + + # Convert quaternion back to torch.Tensor and reorder to [w, x, y, z] + quats = torch.tensor(quats, device=mat.device, dtype=mat.dtype) # (N, 4) + quats = quats[:, [3, 0, 1, 2]] # Reorder to [w, x, y, z] + + # Concatenate position and quaternion + return torch.cat([pos, quats], dim=1) diff --git a/embodichain/lab/sim/solvers/null_space_posture_task.py b/embodichain/lab/sim/solvers/null_space_posture_task.py new file mode 100644 index 00000000..13d4e590 --- /dev/null +++ b/embodichain/lab/sim/solvers/null_space_posture_task.py @@ -0,0 +1,270 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + + +import numpy as np +from typing import List, Optional, Union, TYPE_CHECKING +from embodichain.utils import logger + +try: + import pinocchio as pin +except ImportError: + logger.log_warning("pinocchio not installed. Install with `pip install pin==2.7.0`") + +try: + from pink.configuration import Configuration + from pink.tasks import Task +except ImportError: + logger.log_warning( + "pin-pink not installed. Install with `pip install pin-pink==3.4.0`" + ) + + +class NullSpacePostureTask(Task): + r"""Pink-based task that adds a posture objective that is in the null space projection of other tasks. + + This task implements posture control in the null space of higher priority tasks + (typically end-effector pose tasks) within the Pink inverse kinematics framework. + + **Mathematical Formulation:** + + For details on Pink Inverse Kinematics optimization formulation visit: https://github.com/stephane-caron/pink + + **Null Space Posture Task Implementation:** + + This task consists of two components: + + 1. **Error Function**: The posture error is computed as: + + .. math:: + + \mathbf{e}(\mathbf{q}) = \mathbf{M} \cdot (\mathbf{q}^* - \mathbf{q}) + + where: + - :math:`\mathbf{q}^*` is the target joint configuration + - :math:`\mathbf{q}` is the current joint configuration + - :math:`\mathbf{M}` is a joint selection mask matrix + + 2. **Jacobian Matrix**: The task Jacobian is the null space projector: + + .. math:: + + \mathbf{J}_{\text{posture}}(\mathbf{q}) = \mathbf{N}(\mathbf{q}) = \mathbf{I} - \mathbf{J}_{\text{primary}}^+ \mathbf{J}_{\text{primary}} + + where: + - :math:`\mathbf{J}_{\text{primary}}` is the combined Jacobian of all higher priority tasks + - :math:`\mathbf{J}_{\text{primary}}^+` is the pseudoinverse of the primary task Jacobian + - :math:`\mathbf{N}(\mathbf{q})` is the null space projector matrix + + For example, if there are two frame tasks (e.g., controlling the pose of two end-effectors), the combined Jacobian + :math:`\mathbf{J}_{\text{primary}}` is constructed by stacking the individual Jacobians for each frame vertically: + + .. math:: + + \mathbf{J}_{\text{primary}} = + \begin{bmatrix} + \mathbf{J}_1(\mathbf{q}) \\ + \mathbf{J}_2(\mathbf{q}) + \end{bmatrix} + + where :math:`\mathbf{J}_1(\mathbf{q})` and :math:`\mathbf{J}_2(\mathbf{q})` are the Jacobians for the first and second frame tasks, respectively. + + The null space projector ensures that joint velocities in the null space produce zero velocity + for the primary tasks: :math:`\mathbf{J}_{\text{primary}} \cdot \dot{\mathbf{q}}_{\text{null}} = \mathbf{0}`. + + **Task Integration:** + + When integrated into the Pink framework, this task contributes to the optimization as: + + .. math:: + + \left\| \mathbf{N}(\mathbf{q}) \mathbf{v} + \mathbf{M} \cdot (\mathbf{q}^* - \mathbf{q}) \right\|_{W_{\text{posture}}}^2 + + This formulation allows the robot to maintain a desired posture while respecting the constraints + imposed by higher priority tasks (e.g., end-effector positioning). + + """ + + def __init__( + self, + cost: float, + lm_damping: float = 0.0, + gain: float = 1.0, + controlled_frames: Optional[List[str]] = None, + controlled_joints: Optional[List[str]] = None, + ) -> None: + r"""Initialize the null space posture task. + + This task maintains a desired joint posture in the null space of higher-priority + frame tasks. Joint selection allows excluding specific joints (e.g., wrist joints + in humanoid manipulation) to prevent large rotational ranges from overwhelming + errors in critical joints like shoulders and waist. + + Args: + cost: Task weighting factor in the optimization objective. + Units: :math:`[\text{cost}] / [\text{rad}]`. + lm_damping: Levenberg-Marquardt regularization scale (unitless). Defaults to 0.0. + gain: Task gain :math:`\alpha \in [0, 1]` for low-pass filtering. + Defaults to 1.0 (no filtering). + controlled_frames: Frame names whose Jacobians define the primary tasks for + null space projection. If None or empty, no projection is applied. + controlled_joints: Joint names to control in the posture task. If None or + empty, all actuated joints are controlled. + """ + super().__init__(cost=cost, gain=gain, lm_damping=lm_damping) + self.target_q: np.ndarray | None = None + self.controlled_frames: list[str] = controlled_frames or [] + self.controlled_joints: list[str] = controlled_joints or [] + self._joint_mask: np.ndarray | None = None + self._frame_names: list[str] | None = None + + def __repr__(self) -> str: + """Human-readable representation of the task.""" + return ( + f"NullSpacePostureTask(cost={self.cost}, gain={self.gain}, lm_damping={self.lm_damping}," + f" controlled_frames={self.controlled_frames}, controlled_joints={self.controlled_joints})" + ) + + def _build_joint_mapping(self, configuration: Configuration) -> None: + """Build joint mask and cache frequently used values. + + Creates a binary mask that selects which joints should be controlled + in the posture task. + + Args: + configuration: Robot configuration containing the model and joint information. + """ + # Create joint mask for full configuration size + self._joint_mask = np.zeros(configuration.model.nq) + + # Create dictionary for joint names to indices (exclude root joint) + joint_names = configuration.model.names.tolist()[1:] + + # Build joint mask efficiently + for i, joint_name in enumerate(joint_names): + if joint_name in self.controlled_joints: + self._joint_mask[i] = 1.0 + + # Cache frame names for performance + self._frame_names = list(self.controlled_frames) + + def set_target(self, target_q: np.ndarray) -> None: + """Set target posture configuration. + + Args: + target_q: Target vector in the configuration space. If the model + has a floating base, then this vector should include + floating-base coordinates (although they have no effect on the + posture task since only actuated joints are controlled). + """ + self.target_q = target_q.copy() + + def set_target_from_configuration(self, configuration: Configuration) -> None: + """Set target posture from a robot configuration. + + Args: + configuration: Robot configuration whose joint angles will be used + as the target posture. + """ + self.set_target(configuration.q) + + def compute_error(self, configuration: Configuration) -> np.ndarray: + r"""Compute posture task error. + + The error computation follows: + + .. math:: + + \mathbf{e}(\mathbf{q}) = \mathbf{M} \cdot (\mathbf{q}^* - \mathbf{q}) + + where :math:`\mathbf{M}` is the joint selection mask and :math:`\mathbf{q}^* - \mathbf{q}` + is computed using Pinocchio's difference function to handle joint angle wrapping. + + Args: + configuration: Robot configuration :math:`\mathbf{q}`. + + Returns: + Posture task error :math:`\mathbf{e}(\mathbf{q})` with the same dimension + as the configuration vector, but with zeros for non-controlled joints. + + Raises: + ValueError: If no posture target has been set. + """ + if self.target_q is None: + raise ValueError("No posture target has been set. Call set_target() first.") + + # Initialize joint mapping if needed + if self._joint_mask is None: + self._build_joint_mapping(configuration) + + # Compute configuration difference using Pinocchio's difference function + # This handles joint angle wrapping correctly + err = pin.difference( + configuration.model, + self.target_q, + configuration.q, + ) + + # Apply pre-computed joint mask to select only controlled joints + return self._joint_mask * err + + def compute_jacobian(self, configuration: Configuration) -> np.ndarray: + r"""Compute the null space projector Jacobian. + + The null space projector is defined as: + + .. math:: + + \mathbf{N}(\mathbf{q}) = \mathbf{I} - \mathbf{J}_{\text{primary}}^+ \mathbf{J}_{\text{primary}} + + where: + - :math:`\mathbf{J}_{\text{primary}}` is the combined Jacobian of all controlled frames + - :math:`\mathbf{J}_{\text{primary}}^+` is the pseudoinverse of the primary task Jacobian + - :math:`\mathbf{I}` is the identity matrix + + The null space projector ensures that joint velocities in the null space produce + zero velocity for the primary tasks: :math:`\mathbf{J}_{\text{primary}} \cdot \dot{\mathbf{q}}_{\text{null}} = \mathbf{0}`. + + If no controlled frames are specified, returns the identity matrix. + + Args: + configuration: Robot configuration :math:`\mathbf{q}`. + + Returns: + Null space projector matrix :math:`\mathbf{N}(\mathbf{q})` with dimensions + :math:`n_q \times n_q` where :math:`n_q` is the number of configuration variables. + """ + # Initialize joint mapping if needed + if self._frame_names is None: + self._build_joint_mapping(configuration) + + # If no frame tasks are defined, return identity matrix (no null space projection) + if not self._frame_names: + return np.eye(configuration.model.nq) + + # Get Jacobians for all frame tasks and combine them + J_frame_tasks = [ + configuration.get_frame_jacobian(frame_name) + for frame_name in self._frame_names + ] + J_combined = np.concatenate(J_frame_tasks, axis=0) + + # Compute null space projector: N = I - J^+ * J + N_combined = ( + np.eye(J_combined.shape[1]) - np.linalg.pinv(J_combined) @ J_combined + ) + + return N_combined diff --git a/embodichain/lab/sim/solvers/opw_solver.py b/embodichain/lab/sim/solvers/opw_solver.py new file mode 100644 index 00000000..1dfd751e --- /dev/null +++ b/embodichain/lab/sim/solvers/opw_solver.py @@ -0,0 +1,734 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch +import numpy as np +from itertools import product +from typing import Optional, Union, Tuple, Any, Literal, TYPE_CHECKING +from scipy.spatial.transform import Rotation + +from embodichain.utils import configclass, logger +from embodichain.lab.sim.solvers import SolverCfg, BaseSolver +from embodichain.utils.warp.kinematics.opw_solver import ( + OPWparam, + opw_fk_kernel, + opw_ik_kernel, + opw_best_ik_kernel, + wp_vec6f, +) +from embodichain.utils.device_utils import standardize_device_string +import warp as wp +import polars as pl + +try: + from py_opw_kinematics import KinematicModel, Robot, EulerConvention +except ImportError: + raise ImportError( + "py_opw_kinematics not installed. Install with `pip install py_opw_kinematics==0.1.6`" + ) + + +if TYPE_CHECKING: + from typing import Self + + +def normalize_to_pi(angle): + angle = (angle + np.pi) % (2.0 * np.pi) - np.pi + return angle + + +@configclass +class OPWSolverCfg(SolverCfg): + """Configuration for OPW inverse kinematics controller.""" + + class_type: str = "OPWSolver" + + # OPW-specific parameters + a1 = 0.0 + a2 = -21.984 + b = 0.0 + c1 = 123.0 + c2 = 285.03 + c3 = 250.75 + c4 = 91.0 + offsets = ( + 0.0, + 82.21350356417211 * np.pi / 180.0, + -167.21710113148163 * np.pi / 180.0, + 0.0, + 0.0, + 0.0, + ) + flip_axes = (False, False, False, False, False, False) + has_parallelogram = False + + # Parameters for the inverse-kinematics method. + ik_params: Optional[dict] = None + + def init_solver( + self, device: torch.device = torch.device("cpu"), **kwargs + ) -> "OPWSolver": + """Initialize the solver with the configuration. + + Args: + device (torch.device): The device to use for the solver. Defaults to CPU. + n_sample (int): The number of environments for which the solver is initialized. + **kwargs: Additional keyword arguments that may be used for solver initialization. + + Returns: + OPWSolver: An initialized solver instance. + """ + + solver = OPWSolver(cfg=self, device=device, **kwargs) + + # Set the Tool Center Point (TCP) for the solver + if isinstance(self.tcp, torch.Tensor): + tcp = self.tcp.cpu().numpy() + else: + tcp = self.tcp + solver.set_tcp(tcp) + + return solver + + +class OPWSolver(BaseSolver): + r"""OPW inverse kinematics (IK) controller. + + This controller implements OPW inverse kinematics using various methods for + computing the inverse of the Jacobian matrix. + """ + + def __init__(self, cfg: OPWSolverCfg, device: str = "cpu", **kwargs): + r"""Initializes the OPW kinematics solver. + + This constructor sets up the kinematics solver using OPW methods, + allowing for efficient computation of robot kinematics based on + the specified URDF model. + + Args: + cfg: The configuration for the solver. + device (str, optional): The device to use for the solver (e.g., "cpu" or "cuda"). Defaults to "cpu". + **kwargs: Additional keyword arguments passed to the base solver. + + """ + super().__init__(cfg=cfg, device=device, **kwargs) + if self.device.type == "cpu": + self._init_py_opw_kinematics_solver(cfg, **kwargs) + else: + self._init_warp_solver(cfg, **kwargs) + self.set_tcp(np.eye(4)) + + def _init_py_opw_kinematics_solver(self, cfg: OPWSolverCfg, **kwargs) -> None: + self.kinematic_model = KinematicModel( + a1=cfg.a1, + a2=cfg.a2, + b=cfg.b, + c1=cfg.c1, + c2=cfg.c2, + c3=cfg.c3, + c4=cfg.c4, + offsets=cfg.offsets, + flip_axes=cfg.flip_axes, + has_parallelogram=cfg.has_parallelogram, + ) + self.euler_convention = EulerConvention("ZYX", extrinsic=False, degrees=False) + self.opw_robot = Robot( + self.kinematic_model, self.euler_convention, ee_rotation=(0, 0, 0) + ) + if self.pk_serial_chain != "": + fk_dict = self.pk_serial_chain.forward_kinematics( + th=np.zeros(6), end_only=False + ) + root_tf = fk_dict[list(fk_dict.keys())[0]] + + self.root_base_xpos = root_tf.get_matrix().cpu().numpy() + + def set_tcp(self, xpos: np.ndarray): + super().set_tcp(xpos) + if self.device.type != "cpu": + self._tcp_warp = wp.mat44f(self.tcp_xpos) + tcp_inv = np.eye(4, dtype=float) + tcp_inv[:3, :3] = self.tcp_xpos[:3, :3].T + tcp_inv[:3, 3] = -tcp_inv[:3, :3].T @ self.tcp_xpos[:3, 3] + self._tcp_inv_warp = wp.mat44f(tcp_inv) + + def _init_warp_solver(self, cfg: OPWSolverCfg, **kwargs): + self.params = OPWparam() + # convert unit from mm to m, increate precision + self.params.a1 = cfg.a1 / 1000.0 + self.params.a2 = cfg.a2 / 1000.0 + self.params.b = cfg.b / 1000.0 + self.params.c1 = cfg.c1 / 1000.0 + self.params.c2 = cfg.c2 / 1000.0 + self.params.c3 = cfg.c3 / 1000.0 + self.params.c4 = cfg.c4 / 1000.0 + self.offsets = wp.array( + cfg.offsets, dtype=float, device=standardize_device_string(self.device) + ) + self.sign_corrections = wp.array( + [-1.0 if flip else 1.0 for flip in cfg.flip_axes], + dtype=float, + device=standardize_device_string(self.device), + ) + + def get_fk(self, qpos: torch.tensor, **kwargs) -> torch.tensor: + r""" + Computes the forward kinematics for the end-effector link. + + Args: + qpos (torch.Tensor): Joint positions. Can be a single configuration (dof,) or a batch (batch_size, dof). + **kwargs: Additional keyword arguments for customization. + + Returns: + torch.Tensor: The homogeneous transformation matrix of the end link with TCP applied. + Shape is (4, 4) for single input, or (batch_size, 4, 4) for batch input. + """ + if standardize_device_string(self.device) == "cpu": + return super().get_fk(qpos, **kwargs) + else: + return self.get_fk_warp(qpos, **kwargs) + + def get_fk_warp(self, qpos: torch.tensor, **kwargs) -> torch.tensor: + r""" + Computes the forward kinematics for the end-effector link. + + Args: + qpos (torch.Tensor): Joint positions. Can be a single configuration (dof,) or a batch (batch_size, dof). + **kwargs: Additional keyword arguments for customization. + + Returns: + torch.Tensor: The homogeneous transformation matrix of the end link with TCP applied. + Shape is (4, 4) for single input, or (batch_size, 4, 4) for batch input. + """ + if qpos.shape == (6,): + qpos_batch = qpos[None, :] + else: + qpos_batch = qpos + n_sample = qpos_batch.shape[0] + qpos_wp = wp.from_torch(qpos_batch.reshape(-1)) # dtype=float, device="cuda") + # qpos_wp = wp.array(qpos_batch.detach().cpu().numpy().reshape(-1), dtype=float, device=self.device) + xpos_wp = wp.zeros( + n_sample * 16, dtype=float, device=standardize_device_string(self.device) + ) + wp.launch( + kernel=opw_fk_kernel, + dim=(n_sample), + inputs=[ + qpos_wp, + self._tcp_warp, + self.params, + self.offsets, + self.sign_corrections, + ], + outputs=[xpos_wp], + device=standardize_device_string(self.device), + ) + xpos = wp.to_torch(xpos_wp).reshape(n_sample, 4, 4) + return xpos + + def get_ik( + self, + target_xpos: torch.Tensor, + qpos_seed: torch.Tensor = None, + return_all_solutions: bool = False, + **kwargs, + ): + """Compute target joint positions using OPW inverse kinematics. + + Args: + target_xpos (torch.Tensor): Current end-effector pose, shape (n_sample, 4, 4). + qpos_seed (torch.Tensor): Current joint positions, shape (n_sample, num_joints). Defaults to None. + return_all_solutions (bool, optional): Whether to return all IK solutions or just the best one. Defaults to False. + **kwargs: Additional keyword arguments for future extensions. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - target_joints (torch.Tensor): Computed target joint positions, shape (n_sample, num_joints). + - success (torch.Tensor): Boolean tensor indicating IK solution validity for each environment, shape (n_sample,). + """ + if self.device.type == "cpu": + return self.get_ik_py_opw( + target_xpos, qpos_seed, return_all_solutions, **kwargs + ) + else: + return self.get_ik_warp( + target_xpos, qpos_seed, return_all_solutions, **kwargs + ) + + def get_ik_warp( + self, + target_xpos: torch.Tensor, + qpos_seed: torch.Tensor, + return_all_solutions: bool = False, + **kwargs, + ): + """Compute target joint positions using OPW inverse kinematics. + + Args: + target_xpos (torch.Tensor): Current end-effector pose, shape (n_sample, 4, 4). + qpos_seed (torch.Tensor): Current joint positions, shape (n_sample, num_joints). + return_all_solutions (bool, optional): Whether to return all IK solutions or just the best one. Defaults to False. + **kwargs: Additional keyword arguments for future extensions. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - target_joints (torch.Tensor): Computed target joint positions, shape (n_sample, n_solution, num_joints). + - success (torch.Tensor): Boolean tensor indicating IK solution validity for each environment, shape (n_sample,). + """ + N_SOL = 8 + DOF = 6 + n_sample = target_xpos.shape[0] + + if target_xpos.shape == (4, 4): + target_xpos_batch = target_xpos[None, :, :] + else: + target_xpos_batch = target_xpos + target_xpos_wp = wp.from_torch(target_xpos_batch.reshape(-1)) + + all_qpos_wp = wp.zeros( + n_sample * N_SOL * DOF, + dtype=float, + device=standardize_device_string(self.device), + ) + all_ik_valid_wp = wp.zeros( + n_sample * N_SOL, dtype=int, device=standardize_device_string(self.device) + ) + + # TODO: whether require gradient + wp.launch( + kernel=opw_ik_kernel, + dim=(n_sample), + inputs=( + target_xpos_wp, + self._tcp_inv_warp, + self.params, + self.offsets, + self.sign_corrections, + ), + outputs=[all_qpos_wp, all_ik_valid_wp], + device=standardize_device_string(self.device), + ) + + if return_all_solutions: + all_qpos = wp.to_torch(all_qpos_wp).reshape(n_sample, N_SOL, DOF) + all_ik_valid = wp.to_torch(all_ik_valid_wp).reshape(n_sample, N_SOL) + return all_ik_valid, all_qpos + + if qpos_seed is not None: + qpos_seed_wp = wp.from_torch(qpos_seed.reshape(-1)) + else: + qpos_seed_wp = wp.zeros( + n_sample * DOF, + dtype=float, + device=standardize_device_string(self.device), + ) + joint_weight = kwargs.get("joint_weight", torch.zeros(size=(DOF,), dtype=float)) + joint_weight_wp = wp_vec6f( + joint_weight[0], + joint_weight[1], + joint_weight[2], + joint_weight[3], + joint_weight[4], + joint_weight[5], + ) + best_ik_result_wp = wp.zeros( + n_sample * 6, dtype=float, device=standardize_device_string(self.device) + ) + best_ik_valid_wp = wp.zeros( + n_sample, dtype=int, device=standardize_device_string(self.device) + ) + wp.launch( + kernel=opw_best_ik_kernel, + dim=(n_sample), + inputs=[ + all_qpos_wp, + all_ik_valid_wp, + qpos_seed_wp, + joint_weight_wp, + ], + outputs=[best_ik_result_wp, best_ik_valid_wp], + device=standardize_device_string(self.device), + ) + best_ik_result = wp.to_torch(best_ik_result_wp).reshape(n_sample, 1, 6) + best_ik_valid = wp.to_torch(best_ik_valid_wp) + return best_ik_valid, best_ik_result + + def get_ik_py_opw( + self, + target_xpos: torch.Tensor, + qpos_seed: torch.Tensor, + return_all_solutions: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute target joint positions using OPW inverse kinematics. + + Args: + target_xpos (torch.Tensor): Current end-effector position, shape (n_sample, 3). + qpos_seed (torch.Tensor): Current joint positions, shape (n_sample, num_joints). + return_all_solutions (bool, optional): Whether to return all IK solutions or just the best one. Defaults to False. + **kwargs: Additional keyword arguments for future extensions. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - target_joints (torch.Tensor): Computed target joint positions, shape (n_sample, num_joints). + - success (torch.Tensor): Boolean tensor indicating IK solution validity for each environment, shape (n_sample,). + """ + # TODO: opw solver can only get one solution at a time + DOF = 6 + if qpos_seed is not None: + if isinstance(qpos_seed, torch.Tensor): + qpos_seed_np = qpos_seed.detach().cpu().numpy() + else: + qpos_seed_np = np.array(qpos_seed) + else: + qpos_seed_np = np.zeros(DOF) + + if isinstance(target_xpos, torch.Tensor): + target_xpos = target_xpos.detach().cpu().numpy() + + if target_xpos.shape == (4, 4): + target_xpos_batch = target_xpos[None, :, :] + else: + target_xpos_batch = target_xpos + + # TODO: support root base transform + # target_xpos = self.root_base_xpos @ target_xpos + # compute_xpos = target_xpos @ np.linalg.inv(self.tcp_xpos) + + # TODO: single version + # if target_xpos.ndim == 3: + # target_xpos = target_xpos[0] + # position = np.array(compute_xpos[:3, 3]) * 1000 + # rotation = Rotation.from_matrix(compute_xpos[:3, :3]) + # rotation = rotation.as_euler("ZYX") + # solutions = self.opw_robot.inverse((position, rotation)) + # if len(solutions) == 0: + # logger.log_warning("OPWSolver failed: No solutions found.") + # if return_all_solutions: + # return torch.tensor([False]), torch.zeros((1, 1, 6)) + # else: + # return torch.tensor([False]), torch.zeros((1, 6)) + + # ret, qpos = self._select_optimal_solution( + # qpos_seed_np, solutions, weights=None, return_all_valid=return_all_solutions + # ) + # if not ret or len(qpos) == 0: + # logger.log_warning("No valid solutions found within joint limits.") + # if return_all_solutions: + # return torch.tensor([False]), torch.zeros((1, 1, 6)) + # else: + # return torch.tensor([False]), torch.zeros((1, 6)) + + # if return_all_solutions: + # # qpos: (N, 6) -> (1, N, 6) + # qpos_tensor = torch.from_numpy(qpos).float().unsqueeze(0) + # else: + # # qpos: (6,) -> (1, 6) + # qpos_tensor = torch.from_numpy(qpos).float().reshape(1, 6) + + x_list = [] + y_list = [] + z_list = [] + a_list = [] + b_list = [] + c_list = [] + for xpos in target_xpos_batch: + compute_xpos = xpos @ np.linalg.inv(self.tcp_xpos) + position = np.array(compute_xpos[:3, 3]) * 1000 + rotation = Rotation.from_matrix(compute_xpos[:3, :3]) + rotation = rotation.as_euler("ZYX") + x_list.append(position[0]) + y_list.append(position[1]) + z_list.append(position[2]) + a_list.append(rotation[0]) + b_list.append(rotation[1]) + c_list.append(rotation[2]) + poses = pl.DataFrame( + { + "X": x_list, + "Y": y_list, + "Z": z_list, + "A": a_list, + "B": b_list, + "C": c_list, + } + ) + qpos_seed_np = qpos_seed_np.reshape(-1)[:DOF] + res = self.opw_robot.batch_inverse(current_joints=qpos_seed_np, poses=poses) + solutions = res.to_numpy().copy() + is_success = np.any(np.logical_not(np.isnan(solutions)), axis=1) + for i in range(solutions.shape[0]): + for j in range(solutions.shape[1]): + solutions[i, j] = normalize_to_pi(solutions[i, j]) + + if return_all_solutions: + logger.log_warning( + "return_all_solutions=True is not supported in OPWSolverCPUMode. Returning the best solution only." + ) + qpos_tensor = torch.tensor(solutions, dtype=torch.float32, device=self.device) + qpos_tensor = qpos_tensor.reshape(-1, 1, DOF) + return torch.tensor(is_success), qpos_tensor + + def _calculate_dynamic_weights( + self, current_joints, joint_limits, base_weights=None + ) -> np.ndarray: + r"""Calculate dynamic joint weights based on proximity to joint limits. + + This function increases the weight of joints that are close to their limits, making the IK solver + penalize solutions that move joints near their boundaries. The closer a joint is to its limit, + the higher its weight will be, encouraging safer joint configurations. + + Args: + current_joints (np.ndarray): Current joint positions, shape (6,). + joint_limits (list or np.ndarray): List of (min, max) tuples for each joint, shape (6, 2). + base_weights (np.ndarray, optional): Base weights for each joint, shape (6,). Defaults to ones. + + Returns: + np.ndarray: Dynamic weights for each joint, shape (6,). + """ + if base_weights is None: + base_weights = np.ones(6) + + dynamic_weights = np.copy(base_weights) + for i in range(6): + cj = current_joints[i] + if isinstance(cj, np.ndarray): + if cj.size == 1: + cj = float(cj) + else: + cj = float(cj.flat[0]) + jl_min = joint_limits[i][0] + jl_max = joint_limits[i][1] + range_size = jl_max - jl_min + distance_to_min = cj - jl_min + distance_to_max = jl_max - cj + + min_ratio = distance_to_min / range_size + max_ratio = distance_to_max / range_size + danger_ratio = min(float(min_ratio), float(max_ratio)) + if danger_ratio < 0.2: + dynamic_weights[i] *= 5.0 + elif danger_ratio < 0.4: + dynamic_weights[i] *= 2.0 + + return dynamic_weights + + def _select_optimal_solution( + self, + current_joints, + all_solutions, + joint_limits=None, + weights=None, + prev_joints=None, + return_all_valid=False, + ) -> Tuple[bool, np.ndarray]: + r"""Select the optimal IK solution based on joint limits and weighted differences. + + Args: + current_joints (np.ndarray): Current joint positions in radians, shape=(6,) + all_solutions (List[np.ndarray]): List of all possible IK solutions, each solution has shape=(6,) + joint_limits (List[Tuple], optional): Joint limits list [(min1,max1),...,(min6,max6)]. Defaults to None. + weights (np.ndarray, optional): Weight coefficients for each joint, shape=(6,). Defaults to None. + prev_joints (np.ndarray, optional): Previous joint positions in radians, shape=(6,). Defaults to None. + return_all_valid (bool, optional): If True, return all valid solutions instead of just the optimal one. Defaults to False. + + Returns: + Tuple[bool, np.ndarray]: A tuple containing: + - Success flag (True if solution found) + - Joint angles of the optimal solution (single solution) or all valid solutions (if return_all_valid=True) + """ + # Input validation + if current_joints is None or all_solutions is None: + return False, np.array([]) + + # Convert inputs to numpy arrays + current_joints = np.asarray(current_joints).reshape(-1) + all_solutions = [(np.asarray(sol)) for sol in all_solutions] + + # Set default joint limits if none provided + if joint_limits is None: + joint_limits = [ + (-np.pi, np.pi), # joint 1 + (-np.pi, np.pi), # joint 2 + (-np.pi, np.pi), # joint 3 + (-np.pi, np.pi), # joint 4 + (-np.pi, np.pi), # joint 5 + (-np.pi, np.pi), # joint 6 + ] + + # TODO: support funciton to setting safty margin + # SAFETY_MARGIN = np.radians(5.0) + # joint_limits = [ + # (-2.618 + SAFETY_MARGIN, 2.618 - SAFETY_MARGIN), # 约(-145.5°+5°, 124.2°-5°) + # (0.0 + SAFETY_MARGIN, 3.14 - SAFETY_MARGIN), # 约(0°+5°, 180°-5°) + # (-2.967 + SAFETY_MARGIN, 0.0 - SAFETY_MARGIN), # 约(-170°+5°, 0°-5°) + # (-1.745 + SAFETY_MARGIN, 1.745 - SAFETY_MARGIN), # 约(-100°+5°, 100°-5°) + # (-1.22 + SAFETY_MARGIN, 1.22 - SAFETY_MARGIN), # 约(-70°+5°, 70°-5°) + # (-2.0944 + SAFETY_MARGIN, 2.0944 - SAFETY_MARGIN), # 约(-120°+5°, 120°-5°) + # ] + + # Handle empty solution case + if len(all_solutions) == 0: + logger.log_warning("No available solutions found.") + return None, np.array([]) + + # Set default weights if none provided + if weights is None: + weights = np.ones(6) + else: + weights = np.asarray(weights) + + # Initialize previous joints if not provided + if prev_joints is None: + prev_joints = current_joints + else: + prev_joints = np.asarray(prev_joints) + + # Ensure we only work with first 6 joints + current_joints = current_joints[:6] + prev_joints = prev_joints[:6] + + # Calculate dynamic weights considering joint limits + dynamic_weights = self._calculate_dynamic_weights( + current_joints, joint_limits, weights + ) + + # Initialize variables for tracking best solution and all valid solutions with scores + best_score = float("inf") + best_qpos = None + all_valid_solutions = [] # List of (solution, score) tuples for sorting + + # Evaluate each IK solution + for q in all_solutions: + possible_arrays = [] + valid_solution = True + + # Generate possible joint values considering 2π periodicity + for i in range(6): + current_possible_values = [] + # Determine previous movement direction + prev_move = current_joints[i] - prev_joints[i] + + # Prefer offsets in the same direction as previous movement + preferred_offsets = range(0, 3) if prev_move >= 0 else range(-2, 1) + for offset in preferred_offsets: + adjusted_value = q[i] + offset * (2 * np.pi) + if joint_limits[i][0] <= adjusted_value <= joint_limits[i][1]: + current_possible_values.append(adjusted_value) + + # If no values found in preferred direction, try all directions + if not current_possible_values: + for offset in range(-2, 3): + adjusted_value = q[i] + offset * (2 * np.pi) + if joint_limits[i][0] <= adjusted_value <= joint_limits[i][1]: + current_possible_values.append(adjusted_value) + + # If still no valid values, mark solution as invalid + if not current_possible_values: + valid_solution = False + break + + possible_arrays.append(current_possible_values) + + # Skip invalid solutions + if not valid_solution: + continue + + # Helper function to safely normalize weights + def safe_normalize(weights): + max_weight = np.max(weights) + if max_weight > 0: + return weights / max_weight + return np.zeros_like(weights) + + # Evaluate all combinations of possible joint values + for combination in product(*possible_arrays): + solution = np.array(combination) + if solution.size != 6: + continue + + solution = solution.reshape(current_joints.shape) + + # Calculate optimization score for this solution + # 1. Position difference penalty (weighted squared difference) + pos_diff = np.sum((solution - current_joints) ** 2 * dynamic_weights) + + # 2. Joint limit proximity penalty + limit_penalty = 0 + for i in range(6): + margin = 0.05 # 5% safety margin + # Calculate safe operating range + lower = joint_limits[i][0] + margin * ( + joint_limits[i][1] - joint_limits[i][0] + ) + upper = joint_limits[i][1] - margin * ( + joint_limits[i][1] - joint_limits[i][0] + ) + + # Apply penalty if near joint limits + if solution[i] < lower or solution[i] > upper: + normalized_weights = safe_normalize(dynamic_weights) + limit_penalty += 5.0 * (1 - normalized_weights[i]) + + # 3. Direction change penalty (for avoiding sign flips) + direction_penalty = 0 + for i in range(6): + prev_move = current_joints[i] - prev_joints[i] + current_move = solution[i] - current_joints[i] + # Penalize direction reversals + if prev_move * current_move < 0: # Opposite signs + direction_penalty += ( + abs(current_move) * dynamic_weights[i] * 2.0 + ) + + # 4. Velocity continuity penalty (for smooth motion) + velocity_penalty = 0 + for i in range(6): + prev_vel = current_joints[i] - prev_joints[i] + current_vel = solution[i] - current_joints[i] + accel = abs(current_vel - prev_vel) + # Penalize excessive acceleration (>30°/s²) + if accel > np.radians(30): + velocity_penalty += accel * dynamic_weights[i] * 5.0 + + # Combine all penalty terms + total_score = ( + pos_diff + limit_penalty + direction_penalty + velocity_penalty + ) + + # Add to valid solutions list with score (for sorting when return_all_valid=True) + all_valid_solutions.append((solution.copy(), total_score)) + + # Update best solution if current one is better (for single solution return) + if total_score < best_score: + best_score = total_score + best_qpos = solution.copy() + + # Return results based on what was requested + if return_all_valid: + if len(all_valid_solutions) == 0: + return False, np.array([]) + + # Sort solutions by score (ascending order - lower score is better) + all_valid_solutions.sort(key=lambda x: x[1]) + + # Extract only the solutions (remove scores) + sorted_solutions = np.array([sol[0] for sol in all_valid_solutions]) + return True, sorted_solutions + else: + if best_qpos is None: + return False, np.array([]) + return True, best_qpos diff --git a/embodichain/lab/sim/solvers/pink_solver.py b/embodichain/lab/sim/solvers/pink_solver.py new file mode 100644 index 00000000..b9d74a5f --- /dev/null +++ b/embodichain/lab/sim/solvers/pink_solver.py @@ -0,0 +1,417 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os +import torch +import numpy as np +from typing import List, Optional, Tuple, Union, TYPE_CHECKING +from embodichain.utils import logger + +from embodichain.lab.sim.utility.import_utils import ( + lazy_import_pinocchio, + lazy_import_pink, +) + +from embodichain.utils import configclass, logger +from embodichain.lab.sim.solvers import SolverCfg, BaseSolver + +from embodichain.utils.string import ( + is_regular_expression, + resolve_matching_names_values, +) + +if TYPE_CHECKING: + from typing import Self + + +@configclass +class PinkSolverCfg(SolverCfg): + """Configuration for Pink IK Solver.""" + + class_type: str = "PinkSolver" + + # Solver iteration parameters + pos_eps: float = 5e-4 # Tolerance for convergence for position + rot_eps: float = 5e-4 # Tolerance for convergence for rotation + max_iterations: int = 1000 # Maximum number of iterations for the solver + dt: float = 0.1 # Time step for numerical integration + damp: float = 1e-6 # Damping factor to prevent numerical instability + + # Constraint configuration + is_only_position_constraint: bool = ( + False # Whether to only consider position constraints + ) + + # Path to the mesh files associated with the robot. These files are also loaded by Pinocchio's `robot_wrapper.BuildFromURDF`. + mesh_path: Optional[str] = None + + # A list of tasks for the Pink IK controller. These tasks are controllable by the env action. + # These tasks can be used to control the pose of a frame or the angles of joints. + # For more details, visit: https://github.com/stephane-caron/pink + variable_input_tasks: List["pink.tasks.FrameTask"] = None + + # A list of tasks for the Pink IK controller. These tasks are fixed and not controllable by the env action. + # These tasks can be used to fix the pose of a frame or the angles of joints to a desired configuration. + # For more details, visit: https://github.com/stephane-caron/pink + fixed_input_tasks: List["pink.tasks.FrameTask"] = None + + # Show warning if IK solver fails to find a solution. + show_ik_warnings: bool = True + + # If True, the Pink IK solver will fail and raise an error if any joint limit is violated during optimization. + # PinkSolver will handle the error by setting the last joint positions. + # If False, the solver will ignore joint limit violations and return the closest solution found. + fail_on_joint_limit_violation: bool = True + + # Solver options: + # "clarabel": High-performance SOCP solver written in Rust. + # - Suitable for large-scale problems. + # - Fast and supports sparse matrices. + # + # "ecos": Efficient SOCP solver for small to medium-scale problems. + # - Fast and memory-efficient. + # + # "osqp": Quadratic programming solver based on ADMM. + # - Ideal for sparse and large-scale QP problems. + # - Numerically stable and widely used in robotics/control. + # + # "proxqp": C++ solver for dense and sparse QP problems. + # - Optimized for real-time applications. + # + # "scs": Solver for linear cone programming and SOCP. + # - Suitable for large-scale problems with low precision requirements. + # + # "daqp": Specialized QP solver for real-time and embedded systems. + # - Designed for fast and reliable quadratic programming. + solver_type = "osqp" + + def init_solver(self, **kwargs) -> "PinkSolver": + """Initialize the solver with the configuration. + + Args: + **kwargs: Additional keyword arguments that may be used for solver initialization. + + Returns: + PinkSolver: An initialized solver instance. + """ + + solver = PinkSolver(cfg=self, **kwargs) + + # Set the Tool Center Point (TCP) for the solver + if isinstance(self.tcp, torch.Tensor): + tcp = self.tcp.cpu().numpy() + else: + tcp = self.tcp + + solver.set_tcp(tcp) + + return solver + + +class PinkSolver(BaseSolver): + """Standalone implementation of Pink IK Solver.""" + + def __init__(self, cfg: PinkSolverCfg, **kwargs): + """Initialize the solver with the configuration. + + Args: + **kwargs: Additional keyword arguments that may be used for solver initialization. + + Returns: + PinkSolver: An initialized solver instance. + """ + super().__init__(cfg=cfg, **kwargs) + + self.pin = lazy_import_pinocchio() + self.pink = lazy_import_pink() + + from embodichain.lab.sim.solvers.null_space_posture_task import ( + NullSpacePostureTask, + ) + + self.tcp = cfg.tcp + + if cfg.mesh_path is None: + urdf_dir = os.path.dirname(cfg.urdf_path) + cfg.mesh_path = urdf_dir + + # Initialize robot model + self.entire_robot = self.pin.RobotWrapper.BuildFromURDF( + self.cfg.urdf_path, self.cfg.mesh_path, root_joint=None + ) + + self.pink_joint_names = self.entire_robot.model.names.tolist()[ + 1: + ] # Exclude 'universe' joint + + self.pink_dof = ( + self.entire_robot.model.njoints - 1 + ) # Degrees of freedom of robot joints + + # Get reduced robot model + self.robot = self._get_reduce_robot() + + # Initialize Pink configuration + self.pink_cfg = self.pink.configuration.Configuration( + self.robot.model, self.robot.data, self.robot.q0 + ) + + if self.cfg.variable_input_tasks is None: + self.cfg.variable_input_tasks = [ + self.pink.tasks.FrameTask( + frame=self.cfg.end_link_name, # Frame name (use actual frame name from URDF) + position_cost=1.0, # Position cost weight + orientation_cost=1.0, # Orientation cost weight + ) + ] + + if self.cfg.fixed_input_tasks is None: + self.cfg.fixed_input_tasks = [] + + # Set default targets for tasks + for task in self.cfg.variable_input_tasks: + if isinstance(task, NullSpacePostureTask): + task.set_target(self.init_qpos) + continue + task.set_target_from_configuration(self.pink_cfg) + for task in self.cfg.fixed_input_tasks: + task.set_target_from_configuration(self.pink_cfg) + + # Create joint name mappings if provided + if self.cfg.joint_names: + pink_joint_names = self.robot.model.names.tolist()[ + 1: + ] # Exclude 'universe' joint + self.dexsim_to_pink_ordering = [ + self.cfg.joint_names.index(pink_joint) + for pink_joint in pink_joint_names + ] + self.pink_to_dexsim_ordering = [ + pink_joint_names.index(isaac_joint) + for isaac_joint in self.cfg.joint_names + ] + else: + self.dexsim_to_pink_ordering = None + self.pink_to_dexsim_ordering = None + + def _get_reduce_robot(self) -> "pin.RobotWrapper": + """Build a reduced robot model by locking all joints except those in self.joint_names. + + Returns: + pin.RobotWrapper: The reduced robot model with specified joints unlocked. + """ + pink_joint_names = self.entire_robot.model.names.tolist() + + # Lock all joints except those in self.joint_names and 'universe' + fixed_joint_names = [ + name + for name in pink_joint_names + if name not in self.joint_names and name != "universe" + ] + + reduced_robot = self.entire_robot.buildReducedRobot( + list_of_joints_to_lock=fixed_joint_names + ) + return reduced_robot + + def reorder_array( + self, input_array: List[float], reordering_array: List[int] + ) -> List[float]: + """Reorder array elements based on provided indices. + + Args: + input_array: Array to reorder + reordering_array: Indices for reordering + + Returns: + Reordered array + """ + return [input_array[i] for i in reordering_array] + + def update_null_space_joint_targets(self, current_qpos: np.ndarray): + """Update the null space joint targets. + + This method updates the target joint positions for null space posture tasks based on the current + joint configuration. This is useful for maintaining desired joint configurations when the primary + task allows redundancy. + + Args: + current_qpos: The current joint positions of shape (num_joints,). + """ + from embodichain.lab.sim.solvers.null_space_posture_task import ( + NullSpacePostureTask, + ) + + for task in self.cfg.variable_input_tasks: + if isinstance(task, NullSpacePostureTask): + task.set_target(current_qpos) + + def get_ik( + self, + target_xpos: Optional[Union[torch.Tensor, np.ndarray]], + qpos_seed: Optional[Union[torch.Tensor, np.ndarray]] = None, + return_all_solutions: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute target joint positions using inverse kinematics. + + Args: + target_pose (Optional[Union[torch.Tensor, np.ndarray]]): Target end-effector pose + qpos_seed (Optional[Union[torch.Tensor, np.ndarray]]): Seed joint positions + return_all_solutions (bool, optional): Whether to return all IK solutions or just the best one. Defaults to False. + **kwargs: Additional keyword arguments for future extensions. + + Returns: + Target joint positions. (n_sample, 1, dof) of float. + """ + if qpos_seed is None: + qpos_seed = np.zeros(self.dof) + + if isinstance(qpos_seed, torch.Tensor): + qpos_seed = qpos_seed.detach().cpu().numpy() + if qpos_seed.ndim > 1: + qpos_seed = qpos_seed.flatten() + + if target_xpos.ndim == 2: + target_xpos = target_xpos.unsqueeze(0) + if isinstance(target_xpos, torch.Tensor): + target_xpos = target_xpos.detach().cpu().numpy() + + if target_xpos.shape == (1, 4, 4): + target_xpos = target_xpos[0] + + if target_xpos.shape == (4, 4): + xpos = self.pin.SE3(target_xpos) + else: + raise ValueError( + f"target_xpos shape {target_xpos.shape} not supported for SE3 construction." + ) + + self.cfg.variable_input_tasks[0].set_target(xpos) + + # Handle joint ordering if mapping provided + if self.dexsim_to_pink_ordering: + qpos_pink = np.array( + self.reorder_array(qpos_seed, self.dexsim_to_pink_ordering) + ) + else: + qpos_pink = np.array(qpos_seed) + + # Update configuration with current joint positions + self.pink_cfg.update(qpos_pink) + + tasks = self.cfg.variable_input_tasks + self.cfg.fixed_input_tasks + + try: + num_iter = 1 if self.cfg.max_iterations == 1 else self.cfg.max_iterations + + for i in range(num_iter): + # Solve IK to get joint velocities + velocity = self.pink.solve_ik( + configuration=self.pink_cfg, + tasks=tasks, + damping=self.cfg.damp, + dt=self.cfg.dt, + solver=self.cfg.solver_type, + safety_break=self.cfg.fail_on_joint_limit_violation, + ) + self.pink_cfg.integrate_inplace(velocity, self.cfg.dt) + err = self.cfg.variable_input_tasks[0].compute_error(self.pink_cfg) + + # Compute joint position changes + # Update joint positions + # delta_q = velocity * self.cfg.dt + # self.pink_cfg.update(delta_q) + # logger.log_warning(f"Iteration {i}, error: {err}, delta_q: {delta_q}") + pos_achieved = np.linalg.norm(err[:3]) <= self.cfg.pos_eps + + if self.cfg.is_only_position_constraint: + if pos_achieved: + break + else: + ori_achieved = np.linalg.norm(err[3:]) <= self.cfg.rot_eps + if pos_achieved and ori_achieved: + break + + # except NoSolutionFound as e: + except (AssertionError, Exception) as e: + # Print warning and return the current joint positions as the target + # Not using omni.log since its not available in CI during docs build + if self.cfg.show_ik_warnings: + logger.log_warning( + "Warning: IK quadratic solver could not find a solution! Did not update the target joint" + f" positions.\nError: {e}" + ) + return torch.tensor(False, dtype=torch.bool), torch.tensor( + qpos_seed, device=self.device, dtype=torch.float32 + ) + + qpos = torch.tensor( + self.pink_cfg.q[self.pink_to_dexsim_ordering], + device=self.device, + dtype=torch.float32, + ) + + if return_all_solutions: + logger.log_warning( + "return_all_solutions=True is not supported in DifferentialSolver. Returning the best solution only." + ) + + # Add the velocity changes to the current joint positions to get the target joint positions + # target_qpos = torch.add( + # qvel_dexsim, + # torch.tensor(joint_seed, device=self.device, dtype=torch.float32), + # ) + dof = qpos.shape[-1] + qpos = qpos.reshape(-1, 1, dof) + return torch.tensor(True, dtype=torch.bool), qpos + + def _get_fk( + self, + qpos: Optional[Union[torch.Tensor, np.ndarray]], + **kwargs, + ) -> torch.tensor: + """Compute the forward kinematics for the robot given joint positions. + + Args: + qpos (torch.Tensor or np.ndarray): Joint positions, shape should be (nq,). + **kwargs: Additional keyword arguments (not used). + + Returns: + torch.Tensor: The homogeneous transformation matrix (4x4) of the end-effector (after applying TCP). + """ + if isinstance(qpos, torch.Tensor): + qpos_np = qpos.detach().cpu().numpy() + else: + qpos_np = np.array(qpos) + + qpos_np = np.squeeze(qpos_np) + if qpos_np.ndim != 1: + raise ValueError(f"qpos shape must be (nq,), but got {qpos_np.shape}") + + self.pin.forwardKinematics(self.robot.model, self.robot.data, qpos_np) + + # Retrieve the pose of the specified link + frame_index = self.robot.model.getFrameId(self.end_link_name) + joint_index = self.robot.model.frames[frame_index].parent + xpos_se3 = self.robot.data.oMi.tolist()[joint_index] + + xpos = np.eye(4) + xpos[:3, :3] = xpos_se3.rotation + xpos[:3, 3] = xpos_se3.translation.T + + result = np.dot(xpos, self.tcp_xpos) + return torch.from_numpy(result) diff --git a/embodichain/lab/sim/solvers/pinocchio_solver.py b/embodichain/lab/sim/solvers/pinocchio_solver.py new file mode 100644 index 00000000..719fabbe --- /dev/null +++ b/embodichain/lab/sim/solvers/pinocchio_solver.py @@ -0,0 +1,644 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os +import torch +import numpy as np +from typing import Optional, Union, Tuple, Any, List, TYPE_CHECKING +from itertools import product +from copy import deepcopy + +from embodichain.utils import configclass, logger +from embodichain.lab.sim.solvers import SolverCfg, BaseSolver + +from embodichain.lab.sim.utility.import_utils import ( + lazy_import_pinocchio, + lazy_import_casadi, + # lazy_import_pinocchio_casadi, +) + + +if TYPE_CHECKING: + from typing import Self + + +@configclass +class PinocchioSolverCfg(SolverCfg): + + class_type: str = "PinocchioSolver" + + mesh_path: str = None + + # Solver iteration parameters + pos_eps: float = 5e-4 # Tolerance for convergence for position + rot_eps: float = 5e-4 # Tolerance for convergence for rotation + max_iterations: int = 1000 # Maximum number of iterations for the solver + dt: float = 0.1 # Time step for numerical integration + damp: float = 1e-6 # Damping factor to prevent numerical instability + + # Constraint configuration + is_only_position_constraint: bool = ( + False # Whether to only consider position constraints + ) + + # Sampling configuration + num_samples: int = ( + 30 # Number of samples to generate different joint seeds for IK iterations + ) + + def init_solver(self, **kwargs) -> "PinocchioSolver": + """Initialize the solver with the configuration. + + Args: + **kwargs: Additional keyword arguments that may be used for solver initialization. + + Returns: + PinocchioSolver: An initialized solver instance. + """ + + solver = PinocchioSolver(cfg=self, **kwargs) + + # Set the Tool Center Point (TCP) for the solver + if isinstance(self.tcp, torch.Tensor): + tcp = self.tcp.cpu().numpy() + else: + tcp = self.tcp + + solver.set_tcp(tcp) + + return solver + + +class PinocchioSolver(BaseSolver): + def __init__(self, cfg: PinocchioSolverCfg, **kwargs): + super().__init__(cfg=cfg, **kwargs) + + self.pin = lazy_import_pinocchio() + self.casadi = lazy_import_casadi() + # self.cpin = lazy_import_pinocchio_casadi() + + # Set Tool Center Point (TCP) + self.tcp = cfg.tcp + + # Set IK solver parameters + self.pos_eps = cfg.pos_eps + self.rot_eps = cfg.rot_eps + self.max_iterations = cfg.max_iterations + self.dt = cfg.dt + self.damp = cfg.damp + self.is_only_position_constraint = cfg.is_only_position_constraint + self.num_samples = cfg.num_samples + + # Set mesh path if not provided + if cfg.mesh_path is None: + urdf_dir = os.path.dirname(cfg.urdf_path) + cfg.mesh_path = urdf_dir + + # Load full robot model from URDF + self.entire_robot = self.pin.RobotWrapper.BuildFromURDF( + cfg.urdf_path, cfg.mesh_path, root_joint=None + ) + + # Get all joint names and degrees of freedom (excluding 'universe') + self.all_joint_names = self.entire_robot.model.names.tolist()[ + 1: + ] # Exclude 'universe' joint + self.all_dof = ( + self.entire_robot.model.njoints - 1 + ) # Degrees of freedom of robot joints + + # Build reduced robot model (only relevant joints unlocked) + self.robot = self._get_reduce_robot() + self.joint_names = self.robot.model.names.tolist()[ + 1: + ] # Exclude 'universe' joint + self.dof = ( + self.robot.model.njoints - 1 + ) # Degrees of freedom of reduced robot joints + + self.upper_position_limits = self.robot.model.upperPositionLimit + self.lower_position_limits = self.robot.model.lowerPositionLimit + + self.ik_nearest_weight = np.ones(self.dof) + + # TODO: The Casadi-based solver is currently disabled due to stability issues. + # Note: Casadi-based optimization is currently prone to divergence and requires further debugging and optimization. + if __debug__ and False: + # Creating Casadi models and data for symbolic computing + self.cmodel = self.cpin.Model(self.robot.model) + self.cdata = self.cmodel.createData() + self.cq = self.casadi.SX.sym("q", self.robot.model.nq, 1) + self.cTf = self.casadi.SX.sym("Tf", 4, 4) + self.cpin.framesForwardKinematics(self.cmodel, self.cdata, self.cq) + self.ee_id = self.robot.model.getFrameId(self.end_link_name) + + # Define error functions for position and orientation + self.translational_error = self.casadi.Function( + "translational_error", + [self.cq, self.cTf], + [self.cdata.oMf[self.ee_id].translation - self.cTf[:3, 3]], + ) + self.rotational_error = self.casadi.Function( + "rotational_error", + [self.cq, self.cTf], + [ + self.cpin.log3( + self.cdata.oMf[self.ee_id].rotation @ self.cTf[:3, :3].T + ) + ], + ) + + # Set up CasADi optimization problem + self.opti = self.casadi.Opti() + self.var_q = self.opti.variable(self.robot.model.nq) + self.var_q_last = self.opti.parameter(self.robot.model.nq) + self.param_tf = self.opti.parameter(4, 4) + self.translational_cost = self.casadi.sumsqr( + self.translational_error(self.var_q, self.param_tf) + ) + self.rotation_cost = self.casadi.sumsqr( + self.rotational_error(self.var_q, self.param_tf) + ) + self.regularization_cost = self.casadi.sumsqr(self.var_q) + self.smooth_cost = self.casadi.sumsqr(self.var_q - self.var_q_last) + + # Add joint position constraints to ensure the solution stays within physical joint limits. + self.opti.subject_to( + self.opti.bounded( + self.robot.model.lowerPositionLimit, + self.var_q, + self.robot.model.upperPositionLimit, + ) + ) + + # Define the objective function for IK optimization: + # - Prioritize end-effector position accuracy (high weight) + # - Include orientation accuracy + # - Add regularization to avoid extreme joint values + # - Encourage smoothness between consecutive solutions + self.opti.minimize( + 100 * self.translational_cost + + 50 * self.rotation_cost + + 0.02 * self.regularization_cost + + 0.1 * self.smooth_cost + ) + + # Set solver options for IPOPT + opts = { + "ipopt": { + "print_level": 0, + "max_iter": self.max_iterations, + "tol": self.pos_eps, + }, + "print_time": False, + "calc_lam_p": True, + } + self.opti.solver("ipopt", opts) + + # Initialize joint positions to zero + self.init_qpos = np.zeros(self.robot.model.nq) + + # Perform forward kinematics with zero configuration + self.pin.forwardKinematics(self.robot.model, self.robot.data, self.init_qpos) + + # Retrieve the pose of the specified root link + frame_index = self.robot.model.getFrameId(self.root_link_name) + root_base_pose = self.robot.model.frames[frame_index].placement + self.root_base_xpos = np.eye(4) + self.root_base_xpos[:3, :3] = root_base_pose.rotation + self.root_base_xpos[:3, 3] = root_base_pose.translation.T + + def _get_reduce_robot(self) -> "pin.RobotWrapper": + """Build a reduced robot model by locking all joints except those in self.joint_names. + + Returns: + pin.RobotWrapper: The reduced robot model with specified joints unlocked. + """ + all_joint_names = self.entire_robot.model.names.tolist() + + # Lock all joints except those in self.joint_names and 'universe' + fixed_joint_names = [ + name + for name in all_joint_names + if name not in self.joint_names and name != "universe" + ] + + reduced_robot = self.entire_robot.buildReducedRobot( + list_of_joints_to_lock=fixed_joint_names + ) + return reduced_robot + + def set_tcp(self, tcp: np.ndarray): + self.tcp = tcp + + def get_iteration_params(self) -> dict: + r"""Returns the current iteration parameters. + + Returns: + dict: A dictionary containing the current values of: + - pos_eps (float): Pos convergence threshold + - rot_eps (float): Rot convergence threshold + - max_iterations (int): Maximum number of iterations. + - dt (float): Time step size. + - damp (float): Damping factor. + - num_samples (int): Number of samples. + - is_only_position_constraint (bool): Flag to indicate whether the solver should only consider position constraints. + """ + return { + "pos_eps": self._pos_eps, + "rot_eps": self._rot_eps, + "max_iterations": self._max_iterations, + "dt": self._dt, + "damp": self._damp, + "num_samples": self._num_samples, + } + + def set_iteration_params( + self, + pos_eps: float = 5e-4, + rot_eps: float = 5e-4, + max_iterations: int = 1000, + dt: float = 0.1, + damp: float = 1e-6, + num_samples: int = 30, + is_only_position_constraint: bool = False, + ) -> bool: + r"""Sets the iteration parameters for the kinematics solver. + + Args: + pos_eps (float): Pos convergence threshold, must be positive. + rot_eps (float): Rot convergence threshold, must be positive. + max_iterations (int): Maximum number of iterations, must be positive. + dt (float): Time step size, must be positive. + damp (float): Damping factor, must be non-negative. + num_samples (int): Number of samples, must be positive. + is_only_position_constraint (bool): Flag to indicate whether the solver should only consider position constraints. + + Returns: + bool: True if all parameters are valid and set, False otherwise. + """ + # TODO: Check which parameters are no longer needed. + # Validate parameters + if pos_eps <= 0: + logger.log_warning("Pos epsilon must be positive.") + return False + if rot_eps <= 0: + logger.log_warning("Rot epsilon must be positive.") + return False + if max_iterations <= 0: + logger.log_warning("Max iterations must be positive.") + return False + if dt <= 0: + logger.log_warning("Time step must be positive.") + return False + if damp < 0: + logger.log_warning("Damping factor must be non-negative.") + return False + if num_samples <= 0: + logger.log_warning("Number of samples must be positive.") + return False + + # Set parameters if all are valid + self.pos_eps = pos_eps + self.rot_eps = rot_eps + self.max_iterations = max_iterations + self.dt = dt + self.damp = damp + self.num_samples = num_samples + self.is_only_position_constraint = is_only_position_constraint + + if False: + opts = { + "ipopt": { + "print_level": 0, + "max_iter": self.max_iterations, + "tol": self.pos_eps, + }, + "print_time": False, + "calc_lam_p": False, + } + self.opti.solver("ipopt", opts) + + return True + + def qpos_to_limits( + self, + q: np.ndarray, + joint_seed: np.ndarray, + ): + """Adjusts the joint positions (q) to be within specified limits and as close as possible to the joint seed, + while minimizing the total weighted difference. + + Args: + q (np.ndarray): The original joint positions. + joint_seed (np.ndarray): The desired (seed) joint positions. + + Returns: + np.ndarray: The adjusted joint positions within the specified limits. + """ + best_qpos_limit = np.copy(q) + best_total_q_diff = float("inf") + + # Initialize a list for possible values for each joint + possible_arrays = [] + + if self.ik_nearest_weight is None: + self.ik_nearest_weight = np.ones_like(best_qpos_limit) + + # Generate possible values for each joint + dof_num = len(q) + for i in range(dof_num): + current_possible_values = [] + + # Calculate how many 2π fits into the adjustment to the limits + lower_adjustment = (q[i] - self.lower_position_limits[i]) // (2 * np.pi) + upper_adjustment = (self.upper_position_limits[i] - q[i]) // (2 * np.pi) + + # Consider the current value and its periodic adjustments + for offset in range( + int(lower_adjustment) - 1, int(upper_adjustment) + 2 + ): # Adjust by calculated limits + adjusted_value = q[i] + offset * (2 * np.pi) + + # Check if the adjusted value is within limits + if ( + self.lower_position_limits[i] + <= adjusted_value + <= self.upper_position_limits[i] + ): + current_possible_values.append(adjusted_value) + + # Also check the original value + if self.lower_position_limits[i] <= q[i] <= self.upper_position_limits[i]: + current_possible_values.append(q[i]) + + if not current_possible_values: + return [] # If no possible values for an active joint + possible_arrays.append(current_possible_values) + + # Generate all possible combinations + all_possible_combinations = product(*possible_arrays) + + # Check each combination and calculate the absolute difference sum + for combination in all_possible_combinations: + total_q_diff = np.sum( + np.abs(np.array(combination) - joint_seed) * self.ik_nearest_weight + ) + + # If a smaller difference sum is found, update the best solution + if total_q_diff < best_total_q_diff: + best_total_q_diff = total_q_diff + best_qpos_limit = np.array(combination) + + return best_qpos_limit + + def get_ik( + self, + target_xpos: Optional[Union[torch.Tensor, np.ndarray]], + qpos_seed: np.ndarray = None, + qvel_seed: np.ndarray = None, + return_all_solutions: bool = False, + **kwargs, + ) -> Tuple[bool, np.ndarray]: + """Solve inverse kinematics (IK) for the robot to achieve the specified end-effector pose. + + Args: + target_xpos (torch.Tensor or np.ndarray): Desired end-effector pose as a (4, 4) homogeneous transformation matrix. + qpos_seed (np.ndarray, optional): Initial joint positions used as the seed for optimization. If None, uses zero configuration. + qvel_seed (np.ndarray, optional): Initial joint velocities (not used in current implementation). + return_all_solutions (bool, optional): If True, return all valid IK solutions found; otherwise, return only the best solution. Default is False. + **kwargs: Additional keyword arguments for future extensions. + + Returns: + Tuple[bool, np.ndarray]: + - success (bool or torch.BoolTensor): True if a valid solution is found, False otherwise. + - qpos (np.ndarray or torch.Tensor): Joint positions that achieve the target pose. If no solution, returns the seed joint positions. + """ + if qpos_seed is not None: + if isinstance(qpos_seed, torch.Tensor): + self.init_qpos = qpos_seed.detach().cpu().numpy() + else: + self.init_qpos = np.array(qpos_seed) + + if isinstance(target_xpos, torch.Tensor): + target_xpos = target_xpos.detach().cpu().numpy() + + if target_xpos.ndim == 3: + target_xpos = target_xpos[0] + + target_xpos = self.root_base_xpos @ target_xpos + compute_xpos = target_xpos @ np.linalg.inv(self.tcp_xpos) + + frame_index = self.robot.model.getFrameId(self.end_link_name) + joint_index = self.robot.model.frames[frame_index].parent + + l2w = self.pin.SE3() + l2w.translation[:] = compute_xpos[:3, 3] + l2w.rotation[:] = compute_xpos[:3, :3] + l2j = self.robot.model.frames[frame_index].placement + oMdes = l2w * l2j.inverse() + + # Deep copy joint seed to avoid modifying the original seed + q = deepcopy(self.init_qpos).astype(np.float64).flatten() + + for i in range(self.max_iterations): + # Perform forward kinematics to compute the current pose + self.pin.forwardKinematics(self.robot.model, self.robot.data, q) + current_pose_se3 = self.robot.data.oMi[joint_index] + + if self.is_only_position_constraint: + # Fix the rotation part of the pose + fixed_pose = np.eye(4) + fixed_pose[:3, :3] = compute_xpos[:3, :3] # Use target rotation + fixed_pose[ + :3, 3 + ] = current_pose_se3.translation.T # Use current position + fixed_pose_SE3 = self.pin.SE3(fixed_pose) + current_pose_se3 = self.pin.SE3(fixed_pose_SE3) + + iMd = current_pose_se3.actInv(oMdes) # Calculate the pose error + err = self.pin.log6(iMd).vector # Get the error vector + + # Check position convergence + pos_converged = np.linalg.norm(err[:3]) < self.pos_eps + + if self.is_only_position_constraint: + if pos_converged: + # Convergence achieved, apply joint limits + q = self.qpos_to_limits(q, self.init_qpos) + if 0 == len(q): + continue + return torch.tensor([True], dtype=torch.bool), torch.from_numpy( + q + ).to(dtype=torch.float32) + else: + # Check rotation convergence + rot_converged = np.linalg.norm(err[3:]) < self.rot_eps + + # Check for overall convergence + if pos_converged and rot_converged: + # Convergence achieved, apply joint limits + q = self.qpos_to_limits(q, self.init_qpos) + if 0 == len(q): + continue + return torch.tensor([True], dtype=torch.bool), torch.from_numpy( + q + ).to(dtype=torch.float32) + + # Compute the Jacobian + J = self.pin.computeJointJacobian( + self.robot.model, self.robot.data, q, joint_index + ) + Jlog = self.pin.Jlog6(iMd.inverse()) + J = -Jlog @ J + + # Damped least squares + JJt = J @ J.T + JJt[np.diag_indices_from(JJt)] += self.damp + # Compute the velocity update + v = -(J.T @ np.linalg.solve(JJt, err)) + + # Update joint positions + new_q = self.pin.integrate(self.robot.model, q, v * self.dt) + q = new_q + + # Return failure and the last computed joint positions + return torch.tensor([False], dtype=torch.bool), torch.from_numpy( + np.array(q) + ).to(dtype=torch.float32) + + # TODO: The Casadi-based solver is currently disabled due to stability issues. + # Note: Casadi-based optimization is currently prone to divergence and requires further debugging and optimization. + if __debug__ and False: + self.opti.set_initial(self.var_q, self.init_qpos) + + self.opti.set_value(self.param_tf, compute_xpos) + + try: + num_iter = 1 if self.max_iterations == 1 else self.max_iterations + + for i in range(num_iter): + self.opti.set_value(self.var_q_last, self.init_qpos) + sol = self.opti.solve() + sol_q = self.opti.value(self.var_q) + # self.smooth_filter.add_data(sol_q) + # sol_q = self.smooth_filter.filtered_data + self.init_qpos = sol_q + + # if qvel_seed is not None: + # v = qvel_seed * 0.0 + # else: + # v = (sol_q - self.init_qpos) * 0.0 + # sol_tauff = pin.rnea( + # self.robot.model, + # self.robot.data, + # sol_q, + # v, + # np.zeros(self.robot.model.nv), + # ) + + temp_xpos = self._get_fk(sol_q) + err = temp_xpos - target_xpos + pos_converged = np.linalg.norm(err[:3]) < self.pos_eps + print(f"Iter {i}: pos_err={np.linalg.norm(err[:3])}") + + if self.is_only_position_constraint: + if pos_converged: + break + else: + rot_converged = np.linalg.norm(err[:3, :3]) < self.rot_eps + if pos_converged and rot_converged: + break + + if return_all_solutions: + logger.log_warning( + "return_all_solutions=True is not supported in DifferentialSolver. Returning the best solution only." + ) + + return torch.tensor(True, dtype=torch.bool), torch.from_numpy(sol_q).to( + dtype=torch.float32 + ) + + except Exception as e: + logger.log_warning(f"IK solver failed to converge. Debug info: {e}") + + sol_q = self.opti.debug.value(self.var_q) + # self.smooth_filter.add_data(sol_q) + # sol_q = self.smooth_filter.filtered_data + self.init_qpos = sol_q + + # if qvel_seed is not None: + # v = qvel_seed * 0.0 + # else: + # v = (sol_q - self.init_qpos) * 0.0 + + # sol_tauff = pin.rnea( + # self.robot.model, + # self.robot.data, + # sol_q, + # v, + # np.zeros(self.robot.model.nv), + # ) + + logger.log_debug( + f"sol_q:{sol_q} \nmotorstate: \n{qpos_seed} \nwrist_pose: \n{target_xpos}" + ) + + if return_all_solutions: + logger.log_warning( + "return_all_solutions=True is not supported in DifferentialSolver. Returning the best solution only." + ) + + return torch.tensor(False, dtype=torch.bool), torch.from_numpy( + np.array(qpos_seed) + ).to(dtype=torch.float32) + + def _get_fk( + self, + qpos: Optional[Union[torch.Tensor, np.ndarray]], + **kwargs, + ) -> np.ndarray: + """Compute the forward kinematics for the robot given joint positions. + + Args: + qpos (torch.Tensor or np.ndarray): Joint positions, shape should be (nq,). + **kwargs: Additional keyword arguments (not used). + + Returns: + np.ndarray: The resulting end-effector pose as a (4, 4) homogeneous transformation matrix. + """ + if isinstance(qpos, torch.Tensor): + qpos_np = qpos.detach().cpu().numpy() + else: + qpos_np = np.array(qpos) + + qpos_np = np.squeeze(qpos_np) + if qpos_np.ndim != 1: + raise ValueError(f"qpos shape must be (nq,), but got {qpos_np.shape}") + + self.pin.forwardKinematics(self.robot.model, self.robot.data, qpos_np) + + # Retrieve the pose of the specified link + frame_index = self.robot.model.getFrameId(self.end_link_name) + joint_index = self.robot.model.frames[frame_index].parent + xpos_se3 = self.robot.data.oMi.tolist()[joint_index] + + xpos = np.eye(4) + xpos[:3, :3] = xpos_se3.rotation + xpos[:3, 3] = xpos_se3.translation.T + + result = np.dot(xpos, self.tcp_xpos) + return result diff --git a/embodichain/lab/sim/solvers/pytorch_solver.py b/embodichain/lab/sim/solvers/pytorch_solver.py new file mode 100644 index 00000000..238e8912 --- /dev/null +++ b/embodichain/lab/sim/solvers/pytorch_solver.py @@ -0,0 +1,603 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch + +from typing import Optional, Union, Tuple, List, TYPE_CHECKING +from dataclasses import MISSING +from copy import deepcopy + +from embodichain.utils import configclass, logger +from embodichain.lab.sim.solvers import SolverCfg, BaseSolver +from embodichain.lab.sim.solvers.qpos_seed_sampler import QposSeedSampler + +if TYPE_CHECKING: + from typing import Self + +from embodichain.lab.sim.utility.import_utils import ( + lazy_import_pytorch_kinematics, +) + + +@configclass +class PytorchSolverCfg(SolverCfg): + """Configuration for the pytorch kinematics solver used in the robot simulation. + + This configuration includes properties related to the solver setup, such as the URDF path, + the end link name, and the root link name, along with the Tool Center Point (TCP). + """ + + class_type: str = "PytorchSolver" + + # Solver iteration parameters + pos_eps: float = 5e-4 + """Tolerance for convergence for position""" + + rot_eps: float = 5e-4 + """Tolerance for convergence for rotation""" + + max_iterations: int = 500 + """Maximum number of iterations for the solver""" + + dt: float = 0.1 + """Time step for numerical integration""" + + damp: float = 1e-6 + """Damping factor to prevent numerical instability""" + + is_only_position_constraint: bool = False + """Flag to indicate whether the solver should only consider position constraints.""" + + num_samples: int = 5 + """Number of samples to generate different joint seeds for IK iterations. + + A higher number of samples increases the chances of finding a valid solution + """ + + ik_nearest_weight: Optional[List[float]] = None + """Weights for the inverse kinematics nearest calculation. + + The weights influence how the solver prioritizes closeness to the seed position + when multiple solutions are available. + """ + + def init_solver( + self, device: torch.device = torch.device("cpu"), **kwargs + ) -> "PytorchSolver": + """Initialize the solver with the configuration. + + Args: + device (torch.device): The device to use for the solver. Defaults to CPU. + **kwargs: Additional keyword arguments that may be used for solver initialization. + + Returns: + PytorchSolver: An initialized solver instance. + """ + + solver = PytorchSolver(cfg=self, device=device, **kwargs) + + # Set the Tool Center Point (TCP) for the solver + if isinstance(self.tcp, torch.Tensor): + tcp = self.tcp.cpu().numpy() + else: + tcp = self.tcp + solver.set_tcp(tcp) + + return solver + + +def ensure_pose_shape(func): + """ + Decorator to ensure the input target_pose is of shape (n, 4, 4). + If input is (4, 4), it will be converted to (1, 4, 4). + Raises ValueError if shape is invalid. + """ + + def wrapper(self, target_xpos, *args, **kwargs): + target_xpos = torch.as_tensor( + target_xpos, device=self.device, dtype=torch.float32 + ) + if target_xpos.dim() == 2: + if target_xpos.shape != (4, 4): + raise ValueError("target_xpos must be of shape (4, 4) or (n, 4, 4).") + target_xpos = target_xpos.unsqueeze(0) + elif target_xpos.dim() == 3: + if target_xpos.shape[1:] != (4, 4): + raise ValueError("target_xpos must be of shape (4, 4) or (n, 4, 4).") + else: + raise ValueError( + "target_xpos must be a tensor of shape (4, 4) or (n, 4, 4)." + ) + return func(self, target_xpos, *args, **kwargs) + + return wrapper + + +class PytorchSolver(BaseSolver): + def __init__( + self, + cfg: PytorchSolverCfg, + device: str = None, + **kwargs, + ): + r"""Initializes the PyTorch kinematics solver. + + This constructor sets up the kinematics solver using PyTorch, + allowing for efficient computation of robot kinematics based on + the specified URDF model. + + Args: + cfg: The configuration for the solver. + device (str, optional): The device to use for the solver (e.g., "cpu" or "cuda"). + **kwargs: Additional keyword arguments passed to the base solver. + + """ + super().__init__(cfg=cfg, device=device, **kwargs) + + self.pk = lazy_import_pytorch_kinematics() + + # Initialize solver parameters from configuration + self._pos_eps = cfg.pos_eps + self._rot_eps = cfg.rot_eps + self._max_iterations = cfg.max_iterations + self._dt = cfg.dt + self._damp = cfg.damp + self._is_only_position_constraint = cfg.is_only_position_constraint + self._num_samples = cfg.num_samples + + # Get agent joint limits. + self.lim = torch.tensor( + self.pk_serial_chain.get_joint_limits(), device=self.device + ) + + # Inverse kinematics is available via damped least squares (iterative steps with Jacobian pseudo-inverse damped to avoid oscillation near singularlities). + self.pik = self.pk.PseudoInverseIK( + self.pk_serial_chain, + pos_tolerance=self._pos_eps, + rot_tolerance=self._rot_eps, + joint_limits=self.lim.T, + early_stopping_any_converged=True, + max_iterations=self._max_iterations, + lr=self._dt, + num_retries=1, + ) + + self.dof = self.pk_serial_chain.n_joints + + self.upper_position_limits = self.pk_serial_chain.high + self.lower_position_limits = self.pk_serial_chain.low + + def get_iteration_params(self) -> dict: + r"""Returns the current iteration parameters. + + Returns: + dict: A dictionary containing the current values of: + - pos_eps (float): Pos convergence threshold + - rot_eps (float): Rot convergence threshold + - max_iterations (int): Maximum number of iterations. + - dt (float): Time step size. + - damp (float): Damping factor. + - num_samples (int): Number of samples. + - is_only_position_constraint (bool): Flag to indicate whether the solver should only consider position constraints. + """ + return { + "pos_eps": self._pos_eps, + "rot_eps": self._rot_eps, + "max_iterations": self._max_iterations, + "dt": self._dt, + "damp": self._damp, + "num_samples": self._num_samples, + } + + def set_iteration_params( + self, + pos_eps: float = 5e-4, + rot_eps: float = 5e-4, + max_iterations: int = 1000, + dt: float = 0.1, + damp: float = 1e-6, + num_samples: int = 30, + is_only_position_constraint: bool = False, + ) -> bool: + r"""Sets the iteration parameters for the kinematics solver. + + Args: + pos_eps (float): Pos convergence threshold, must be positive. + rot_eps (float): Rot convergence threshold, must be positive. + max_iterations (int): Maximum number of iterations, must be positive. + dt (float): Time step size, must be positive. + damp (float): Damping factor, must be non-negative. + num_samples (int): Number of samples, must be positive. + is_only_position_constraint (bool): Flag to indicate whether the solver should only consider position constraints. + + Returns: + bool: True if all parameters are valid and set, False otherwise. + """ + # Validate parameters + if pos_eps <= 0: + logger.log_warning("Pos epsilon must be positive.") + return False + if rot_eps <= 0: + logger.log_warning("Rot epsilon must be positive.") + return False + if max_iterations <= 0: + logger.log_warning("Max iterations must be positive.") + return False + if dt <= 0: + logger.log_warning("Time step must be positive.") + return False + if damp < 0: + logger.log_warning("Damping factor must be non-negative.") + return False + if num_samples <= 0: + logger.log_warning("Number of samples must be positive.") + return False + + # Set parameters if all are valid + self._pos_eps = pos_eps + self._rot_eps = rot_eps + self._max_iterations = max_iterations + self._dt = dt + self._damp = damp + self._num_samples = num_samples + self._is_only_position_constraint = is_only_position_constraint + + self.pik = self.pk.PseudoInverseIK( + self.pk_serial_chain, + pos_tolerance=self._pos_eps, + rot_tolerance=self._rot_eps, + joint_limits=self.lim.T, + early_stopping_any_converged=True, + max_iterations=self._max_iterations, + lr=self._dt, + num_retries=1, + ) + + return True + + def _compute_inverse_kinematics( + self, target_pose: torch.Tensor, joint_seed: torch.Tensor + ) -> Tuple[Union[bool, torch.Tensor], torch.Tensor]: + r"""Computes the inverse kinematics solutions for the given target poses and joint seeds. + + Args: + target_pose (torch.Tensor): The target poses represented as a (batch_size, 4, 4) tensor. + joint_seed (torch.Tensor): The initial joint positions used as a seed. It can be either a 1D tensor of shape (dof,) or a 2D tensor of shape (batch_size, dof). + + Returns: + Tuple[Union[bool, torch.Tensor], torch.Tensor]: + - First element: + - If solutions exist: torch.BoolTensor of shape (batch_size,) indicating convergence per pose + - If no solutions: Python False + - Second element: + - If solutions exist: torch.Tensor of shape (batch_size, dof) containing joint solutions + - If no solutions: Empty torch.Tensor + """ + target_pose = target_pose.to(self.device).float() + joint_seed = joint_seed.to(self.device).float() + + # Extract translation and rotation parts + pos = target_pose[:, :3, 3] + rot = target_pose[:, :3, :3] + + tf = self.pk.Transform3d( + pos=pos, + rot=rot, + device=self.device, + ) + self.pik.initial_config = joint_seed + + result = self.pik.solve(tf) + + if result.converged_any.any().item(): + return result.converged_any, result.solutions[:, 0, :].squeeze(0) + + return False, torch.empty(0) + + @staticmethod + def _qpos_to_limits_single( + q: torch.Tensor, + joint_seed: torch.Tensor, + lower_position_limits: torch.Tensor, + upper_position_limits: torch.Tensor, + ik_nearest_weight: torch.Tensor, + periodic_mask: torch.Tensor = None, # Optional mask for periodic joints + ) -> torch.Tensor: + """ + Adjusts the given joint positions (q) to fit within the specified limits while minimizing the difference to the seed position. + + Args: + q (torch.Tensor): The initial joint positions. + joint_seed (torch.Tensor): The seed joint positions for comparison. + lower_position_limits (torch.Tensor): The lower bounds for the joint positions. + upper_position_limits (torch.Tensor): The upper bounds for the joint positions. + ik_nearest_weight (torch.Tensor): The weights for the inverse kinematics nearest calculation. + periodic_mask (torch.Tensor, optional): Boolean mask indicating which joints are periodic. + + Returns: + torch.Tensor: The adjusted joint positions that fit within the limits. + """ + device = q.device + joint_seed = joint_seed.to(device) + lower = lower_position_limits.to(device) + upper = upper_position_limits.to(device) + weight = ik_nearest_weight.to(device) + + # If periodic_mask is not provided, assume all joints are periodic + if periodic_mask is None: + periodic_mask = torch.ones_like(q, dtype=torch.bool, device=device) + + # Only enumerate [-2π, 0, 2π] for periodic joints, single value for non-periodic + offsets = torch.tensor([-2 * torch.pi, 0, 2 * torch.pi], device=device) + candidate_list = [] + for i in range(q.size(0)): + if periodic_mask[i]: + candidate_list.append(q[i] + offsets) + else: + candidate_list.append(q[i].unsqueeze(0)) + # Generate all possible combinations + mesh = torch.meshgrid(*candidate_list, indexing="ij") + candidates = torch.stack([m.reshape(-1) for m in mesh], dim=1) + # Filter candidates that are out of limits + mask = (candidates >= lower) & (candidates <= upper) + valid_mask = mask.all(dim=1) + valid_candidates = candidates[valid_mask] + if valid_candidates.shape[0] == 0: + return torch.tensor([]).to(device) + # Compute weighted distance to seed and select the closest + diffs = torch.abs(valid_candidates - joint_seed) * weight + distances = torch.sum(diffs, dim=1) + min_idx = torch.argmin(distances) + return valid_candidates[min_idx] + + def _qpos_to_limits( + self, qpos_list_split: torch.Tensor, joint_seed: torch.Tensor + ) -> torch.Tensor: + r"""Adjusts a batch of joint positions to fit within joint limits and minimize the weighted distance to the seed position. + + Args: + qpos_list_split (torch.Tensor): Batch of candidate joint positions, shape (N, dof). + joint_seed (torch.Tensor): The reference joint positions for comparison, shape (dof,). + + Returns: + torch.Tensor: Batch of adjusted joint positions that fit within the limits, shape (M, dof), + where M <= N (invalid candidates are filtered out). + """ + + periodic_mask = torch.ones_like( + qpos_list_split[0], dtype=torch.bool, device=self.device + ) + + adjusted_qpos_list = [ + self._qpos_to_limits_single( + q, + joint_seed, + self.lower_position_limits, + self.upper_position_limits, + self.ik_nearest_weight, + periodic_mask, + ) + for q in qpos_list_split + ] + + # Filter out empty results + adjusted_qpos_list = [q for q in adjusted_qpos_list if q.numel() > 0] + + return ( + torch.stack(adjusted_qpos_list).to(qpos_list_split.device) + if adjusted_qpos_list + else torch.tensor([], device=self.device) + ) + + @ensure_pose_shape + def get_ik( + self, + target_xpos: torch.Tensor, + qpos_seed: torch.Tensor = None, + num_samples: int = None, + return_all_solutions: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Computes the inverse kinematics for given target poses. + + This function generates random joint configurations within the specified limits, + including the provided joint_seed, and attempts to find valid inverse kinematics solutions. + It then identifies the joint positions that are closest to the joint_seed. + + Args: + target_xpos (torch.Tensor): A tensor representing the target positions. It can be of shape + (batch_size, 3) for multiple positions or (3,) for a single position. + qpos_seed (torch.Tensor, optional): Initial joint positions used as seed for IK solving. + Can be: + - 1D tensor of shape (dof,): Single seed for all target positions + - 2D tensor of shape (batch_size, dof): Individual seed per position + If None, defaults to zero configuration. Defaults to None. + num_samples (int, optional): The number of random samples to generate. Must be positive. + Defaults to None. + return_all_solutions (bool, optional): If True, returns all valid solutions found. + **kwargs: Additional arguments for future extensions. + + Returns: + Tuple[List[bool], torch.Tensor]: A tuple containing: + - A tensor of booleans indicating whether valid solutions were found for each target pose. (Shape: (batch_size,)) + - A tensor of shape (batch_size, 1, dof) containing joint positions for + each target pose, or an empty tensor if no valid solutions were found. + """ + # Convert target_pose to tensor and ensure correct device and dtype + target_xpos = torch.as_tensor( + target_xpos, device=self.device, dtype=torch.float32 + ) + if num_samples is not None: + self._num_samples = num_samples + + # Prepare qpos_seed + if qpos_seed is None: + qpos_seed = torch.zeros(self.dof, device=self.device) + else: + qpos_seed = torch.as_tensor(qpos_seed, device=self.device) + + # Check qpos_seed dimensions + if qpos_seed.dim() == 1: + qpos_seed = qpos_seed.unsqueeze(0) + qpos_seed_ndim = 1 + elif qpos_seed.dim() == 2: + qpos_seed_ndim = 2 + if qpos_seed.shape[0] != target_xpos.shape[0]: + raise ValueError( + "Batch size of qpos_seed must match batch size of target_xpos when qpos_seed is a 2D tensor." + ) + else: + raise ValueError("`qpos_seed` must be a tensor of shape (n,) or (n, n).") + + # Transform target_xpos by TCP + tcp_xpos = torch.as_tensor( + deepcopy(self.tcp_xpos), device=self.device, dtype=torch.float32 + ) + target_xpos = target_xpos @ torch.inverse(tcp_xpos) + + # Get joint limits and ensure shape matches dof + upper_limits = self.upper_position_limits.float() + lower_limits = self.lower_position_limits.float() + + batch_size = target_xpos.shape[0] + + sampler = QposSeedSampler( + num_samples=self._num_samples, dof=self.dof, device=self.device + ) + random_qpos_seeds = sampler.sample( + qpos_seed, lower_limits, upper_limits, batch_size + ) + target_xpos_repeated = sampler.repeat_target_xpos( + target_xpos, self._num_samples + ) + + # Compute IK solutions for all samples + res_list, qpos_list = self._compute_inverse_kinematics( + target_xpos_repeated, random_qpos_seeds + ) + + if not isinstance(res_list, torch.Tensor) or not res_list.any(): + logger.log_warning( + "Pk: No valid solutions found for the given target poses and joint seeds." + ) + return torch.zeros( + batch_size, dtype=torch.bool, device=self.device + ), torch.zeros((batch_size, self.dof), device=self.device) + + # Split res_list and qpos_list according to self._num_samples + res_list_split = torch.split(res_list, self._num_samples) + qpos_list_split = torch.split(qpos_list, self._num_samples) + + # Initialize the final results and the closest joint positions + final_results = [] + final_qpos = [] + + # For each batch, select the closest valid solution to qpos_seed + for i in range(batch_size): + target_qpos_seed = qpos_seed[i] if qpos_seed_ndim == 2 else qpos_seed + + if not res_list_split[i].any(): + final_results.append(False) + final_qpos.append(torch.zeros((1, self.dof), device=self.device)) + continue + + result_qpos_limit = self._qpos_to_limits( + qpos_list_split[i], target_qpos_seed + ) + + if result_qpos_limit.shape[0] == 0: + final_results.append(False) + final_qpos.append(torch.zeros((self.dof), device=self.device)) + continue + + distances = torch.norm(result_qpos_limit - target_qpos_seed, dim=1) + sorted_indices = torch.argsort(distances) + # shape: (N, dof) + sorted_qpos_array = result_qpos_limit[sorted_indices] + final_qpos.append(sorted_qpos_array) + final_results.append(True) + + # Pad all batches to the same number of solutions for stacking + max_solutions = max([q.shape[0] for q in final_qpos]) if final_qpos else 1 + final_qpos_tensor = torch.zeros( + (batch_size, max_solutions, self.dof), device=self.device + ) + for i, q in enumerate(final_qpos): + n = q.shape[0] + final_qpos_tensor[i, :n, :] = q + + final_results = torch.tensor( + final_results, dtype=torch.bool, device=self.device + ) + + if return_all_solutions: + # Return all sorted solutions for each batch (shape: batch_size, max_solutions, dof) + return final_results, final_qpos_tensor + + # Only return the closest solution for each batch (shape: batch_size, 1, dof) + # If multiple solutions, take the first (closest) + final_qpos_tensor = final_qpos_tensor[:, :1, :] + return final_results, final_qpos_tensor + + def get_all_fk(self, qpos: torch.tensor) -> torch.tensor: + r"""Get the forward kinematics for all links from root to end link. + + Args: + qpos (torch.Tensor): The joint positions. + + Returns: + list: A list of 4x4 homogeneous transformation matrices representing the poses of all links from root to end link. + """ + qpos = torch.as_tensor(qpos) + qpos = qpos.to(self.device) + + ret = self.pk_serial_chain.forward_kinematics(qpos, end_only=False) + link_names = list(ret.keys()) + + if self.root_link_name is not None: + try: + start_index = link_names.index(self.root_link_name) + except ValueError: + raise KeyError( + f"Root link name '{self.root_link_name}' not found in the kinematic chain" + ) + else: + start_index = 0 + + if self.end_link_name is not None: + try: + end_index = link_names.index(self.end_link_name) + 1 + except ValueError: + raise KeyError( + f"End link name '{self.end_link_name}' not found in the kinematic chain" + ) + else: + end_index = len(link_names) + + poses = [] + for link_name in link_names[start_index:end_index]: + xpos = ret[link_name] + if not hasattr(xpos, "get_matrix"): + raise AttributeError( + f"The result for link '{link_name}' must have 'get_matrix' attributes." + ) + xpos_t = torch.eye(4, device=xpos.get_matrix().device) + m = xpos.get_matrix() + xpos_t[:3, 3] = m[:, :3, 3] + xpos_t[:3, :3] = m[:, :3, :3] + poses.append(xpos_t) + + return poses diff --git a/embodichain/lab/sim/solvers/qpos_seed_sampler.py b/embodichain/lab/sim/solvers/qpos_seed_sampler.py new file mode 100644 index 00000000..91f7f3c6 --- /dev/null +++ b/embodichain/lab/sim/solvers/qpos_seed_sampler.py @@ -0,0 +1,88 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch + + +class QposSeedSampler: + """ + Standard joint seed sampler for IK solving. + + Generates joint seed samples for each target pose in a batch, including the provided seed and random samples within joint limits. + + Args: + num_samples (int): Number of samples per batch (including the seed). + dof (int): Degrees of freedom. + device (torch.device): Target device. + """ + + def __init__(self, num_samples: int, dof: int, device: torch.device): + self.num_samples = num_samples + self.dof = dof + self.device = device + + def sample( + self, + qpos_seed: torch.Tensor, + lower_limits: torch.Tensor, + upper_limits: torch.Tensor, + batch_size: int, + ) -> torch.Tensor: + """Generate joint seed samples for IK solving. + + Args: + qpos_seed (torch.Tensor): (batch_size, dof) or (1, dof) initial seed. + lower_limits (torch.Tensor): (dof,) lower joint limits. + upper_limits (torch.Tensor): (dof,) upper joint limits. + batch_size (int): Batch size. + + Returns: + torch.Tensor: (batch_size * num_samples, dof) joint seeds. + """ + joint_seeds_list = [] + for i in range(batch_size): + current_seed = ( + qpos_seed[i].unsqueeze(0) + if qpos_seed.shape[0] == batch_size + else qpos_seed + ) + if self.num_samples > 1: + rand_part = lower_limits + (upper_limits - lower_limits) * torch.rand( + (self.num_samples - 1, self.dof), device=self.device + ) + else: + rand_part = torch.empty((0, self.dof), device=self.device) + seeds = torch.cat([current_seed, rand_part], dim=0) + joint_seeds_list.append(seeds) + return torch.cat(joint_seeds_list, dim=0) + + def repeat_target_xpos( + self, target_xpos: torch.Tensor, num_samples: int + ) -> torch.Tensor: + """Repeat each target pose num_samples times for batch processing. + + Args: + target_xpos (torch.Tensor): (batch_size, 4, 4) or (batch_size, 3, 3) target poses. + num_samples (int): Number of repeats per batch. + + Returns: + torch.Tensor: (batch_size * num_samples, 4, 4) or (batch_size * num_samples, 3, 3) + """ + repeated_list = [ + target_xpos[i].unsqueeze(0).repeat(num_samples, 1, 1) + for i in range(target_xpos.shape[0]) + ] + return torch.cat(repeated_list, dim=0) diff --git a/embodichain/lab/sim/solvers/srs_solver.py b/embodichain/lab/sim/solvers/srs_solver.py new file mode 100644 index 00000000..44fa894f --- /dev/null +++ b/embodichain/lab/sim/solvers/srs_solver.py @@ -0,0 +1,1222 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch +import numpy as np +import warp as wp +from itertools import product +from typing import Optional, Union, Tuple, Any, Literal, TYPE_CHECKING +from embodichain.utils import configclass, logger +from embodichain.lab.sim.solvers import SolverCfg, BaseSolver + +from embodichain.utils.warp.kinematics.srs_solver import ( + transform_pose_kernel, + compute_ik_kernel, + sort_ik_kernel, + nearest_ik_kernel, + check_success_kernel, +) +from embodichain.utils.device_utils import standardize_device_string + +if TYPE_CHECKING: + from typing import Self + from embodichain.lab.sim.robots.dexforce_w1.params import W1ArmKineParams + + +all = ["SRSSolver", "SRSSolverCfg"] + + +@configclass +class SRSSolverCfg(SolverCfg): + """Configuration for SRS inverse kinematics controller.""" + + class_type: str = "SRSSolver" + """Type of the solver class.""" + + # kine_params: "W1ArmKineParams" + # SRS-specific parameters + dh_params = [] + """Denavit-Hartenberg parameters for the robot's kinematic chain.""" + + qpos_limits = [] + """Joint position limits for the robot.""" + + T_b_ob = np.eye(4) + """Base to observed base transform.""" + + T_e_oe = np.eye(4) + """End-effector to observed end-effector transform.""" + + link_lengths = [] + """Link lengths of the robot arm.""" + + rotation_directions = [] + """Rotation directions for each joint.""" + + num_samples: int = 100 + """Number of samples for elbow angle during IK computation.""" + + sort_ik: bool = True + """Whether to sort IK solutions based on proximity to seed joint positions.""" + + # TODO: Each target pose may have multiple IK solutions; weights can help select the best one. + ik_nearest_weight: np.array = np.ones(7) + """Weights for each joint when finding the nearest IK solution.""" + + def init_solver( + self, num_envs: int = 1, device: torch.device = torch.device("cpu"), **kwargs + ) -> "SRSSolver": + """Initialize the solver with the configuration. + + Args: + device (torch.device): The device to use for the solver. Defaults to CPU. + num_envs (int): The number of environments for which the solver is initialized. + **kwargs: Additional keyword arguments that may be used for solver initialization. + + Returns: + SRSSolver: An initialized solver instance. + """ + + solver = SRSSolver(cfg=self, num_envs=num_envs, device=device, **kwargs) + + # Set the Tool Center Point (TCP) for the solver + if isinstance(self.tcp, torch.Tensor): + tcp = self.tcp.cpu().numpy() + else: + tcp = self.tcp + solver.set_tcp(tcp) + + return solver + + +class _BaseSRSSolverImpl: + """Base implementation for the SRS inverse kinematics solver.""" + + def __init__(self, cfg: SRSSolverCfg, device: torch.device): + # Initialize configuration and device + self.cfg = cfg + self.device = device + self.dofs = 7 + self.dh_params = cfg.dh_params + self.qpos_limits = cfg.qpos_limits + self.tcp_xpos = np.eye(4) + + # Initialize transformation matrices + self._parse_params() + + def _parse_params(self): + # Compute the inverse transformation matrices for TCP, end-effector, and base. + self.tcp_xpos = self.cfg.tcp + self.tcp_inv_np = np.linalg.inv(self.tcp_xpos) + self.T_e_oe_inv_np = np.linalg.inv(self.cfg.T_e_oe) + self.T_b_ob_inv_np = np.linalg.inv(self.cfg.T_b_ob) + + # Convert configuration parameters to numpy arrays for efficient computation. + self.dh_params_np = np.asarray(self.cfg.dh_params) + self.qpos_limits_np = np.asarray(self.cfg.qpos_limits) + self.link_lengths_np = np.asarray(self.cfg.link_lengths) + self.rotation_directions_np = np.asarray(self.cfg.rotation_directions) + + +class _CPUSRSSolverImpl(_BaseSRSSolverImpl): + """CPU implementation of the SRS inverse kinematics solver.""" + + def __init__(self, cfg: SRSSolverCfg, device: torch.device): + super().__init__(cfg, device) + + def _parse_params(self): + super()._parse_params() + + # Generate all possible configuration combinations for shoulder, elbow, and wrist. + # Each configuration is represented by a vector of three elements, each being +1 or -1. + # This covers all 8 possible sign combinations for the three joints. + self.configs = [ + np.array([x, y, z]) for x, y, z in product([1.0, -1.0], repeat=3) + ] + + # Generate a set of elbow angles sampled uniformly from -π to π. + # The number of samples is determined by self.cfg.num_samples. + # These angles are used for searching possible IK solutions. + self.elbow_angles = torch.linspace( + -torch.pi, torch.pi, self.cfg.num_samples, device=self.device + ) + + # Convert ik_nearest_weight to a tensor for efficient computation. + self.ik_nearest_weight_tensor = torch.tensor( + self.cfg.ik_nearest_weight, dtype=torch.float32, device=self.device + ) + + def _get_fk(self, target_joint: np.ndarray) -> np.ndarray: + """ + Compute the forward kinematics (FK) for a given joint state. + + Args: + target_joint (np.ndarray): Joint angles (shape: [7,]). + + Returns: + np.ndarray: 4x4 transformation matrix representing the end-effector pose. + """ + # Initialize pose as identity matrix + pose = np.eye(4) + + # Iterate through the DH parameters and compute the transformation + for i in range(self.dh_params.shape[0]): + d = self.dh_params[i, 0] + alpha = self.dh_params[i, 1] + a = self.dh_params[i, 2] + theta = self.dh_params[i, 3] + + # Add joint angle contribution if within bounds + if i < target_joint.size: + theta += target_joint[i] * self.cfg.rotation_directions[i] + + # Compute the transformation matrix for this joint + T = self._dh_transform(d, alpha, a, theta) + pose = pose @ T + + # Apply additional transformations: user frame, base, and tool center point (TCP) + pose = ( + self.cfg.T_b_ob + @ pose + @ self.cfg.T_e_oe # End-effector-to-observed-end-effector transform + @ self.tcp_xpos # Tool center point transform + ) + + return pose + + def _calculate_arm_joint_angles( + self, + P26: np.ndarray, + elbow_config: int, + joints: np.ndarray, + link_lengths: np.ndarray, + ) -> bool: + """ + Calculate joint angles based on the position vector P26. + + Args: + P26 (np.ndarray): Vector from shoulder to wrist. + elbow_config (int): Elbow configuration (+1 or -1). + joints (np.ndarray): Joint angles to be updated. + link_lengths (np.ndarray): Link lengths of the robot. + + Returns: + bool: True if successful, False otherwise. + """ + d_bs, d_se, d_ew = link_lengths[:3] + + norm_P26 = np.linalg.norm(P26) + if norm_P26 < np.abs(d_bs + d_ew): + logger.log_warning("Specified pose outside reachable workspace.") + return False + + elbow_cos_angle = (norm_P26**2 - d_se**2 - d_ew**2) / (2 * d_se * d_ew) + if abs(elbow_cos_angle) > 1.0: + logger.log_debug("Elbow singularity. End effector at limit.") + return False + + joints[3] = elbow_config * np.arccos(elbow_cos_angle) + + if abs(P26[2]) > 1e-6: + joints[0] = np.arctan2(P26[1], P26[0]) + else: + joints[0] = 0 + + euclidean_norm = np.hypot(P26[0], P26[1]) + angle_phi = np.arccos( + (d_se**2 + norm_P26**2 - d_ew**2) / (2 * d_se * norm_P26) + ) + joints[1] = np.arctan2(euclidean_norm, P26[2]) + elbow_config * angle_phi + + return True + + def _dh_transform( + self, d: float, alpha: float, a: float, theta: float + ) -> np.ndarray: + """ + Compute the Denavit-Hartenberg transformation matrix. + + Args: + d (float): Link offset. + alpha (float): Link twist. + a (float): Link length. + theta (float): Joint angle. + + Returns: + np.ndarray: 4x4 transformation matrix. + """ + cos_theta, sin_theta = np.cos(theta), np.sin(theta) + cos_alpha, sin_alpha = np.cos(alpha), np.sin(alpha) + + # fmt: off + return np.array( + [ + [cos_theta, -sin_theta * cos_alpha, sin_theta * sin_alpha, a * cos_theta], + [sin_theta, cos_theta * cos_alpha, -cos_theta * sin_alpha, a * sin_theta], + [0, sin_alpha, cos_alpha, d], + [0, 0, 0, 1], + ] + ) + # fmt: on + + def _skew(self, vector: np.ndarray) -> np.ndarray: + """ + Compute the skew-symmetric matrix of a vector. + + Args: + vector (np.ndarray): Input vector (3,). + + Returns: + np.ndarray: Skew-symmetric matrix (3x3). + """ + return np.array( + [ + [0, -vector[2], vector[1]], + [vector[2], 0, -vector[0]], + [-vector[1], vector[0], 0], + ] + ) + + def _compute_reference_plane( + self, target_pose: np.ndarray, elbow_config: int + ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: + """ + Calculate the reference plane vector, rotation matrix, and joint values. + + Args: + target_pose (np.ndarray): Transformed target pose (4x4). + elbow_config (int): Elbow configuration (+1 or -1). + + Returns: + tuple: (plane_normal, base_to_elbow_rotation, joint_angles) or (None, None, None) if failed. + """ + dh_params = self.dh_params + link_lengths = self.cfg.link_lengths + + P_target = target_pose[:3, 3] + P02 = np.array([0, 0, link_lengths[0]]) + P67 = np.array([0, 0, dh_params[6, 0]]) + P06 = P_target - target_pose[:3, :3] @ P67 + P26 = P06 - P02 + + joint_angles = np.zeros(7) + if not self._calculate_arm_joint_angles( + P26, elbow_config, joint_angles, link_lengths + ): + return None, None, None + + T34_v = self._dh_transform( + dh_params[3, 0], dh_params[3, 1], dh_params[3, 2], joint_angles[3] + ) + P34_v = T34_v[:3, 3] + + norm_P34_P02 = np.linalg.norm(P34_v - P02) + if norm_P34_P02 > 1e-6: + v1 = (P34_v - P02) / norm_P34_P02 + else: + v1 = np.zeros_like(P34_v - P02) + v2 = (P06 - P02) / np.linalg.norm(P06 - P02) + plane_normal = np.cross(v1, v2) + + base_to_elbow_rotation = np.eye(3) + for i in range(3): + T = self._dh_transform( + dh_params[i, 0], dh_params[i, 1], dh_params[i, 2], joint_angles[i] + ) + base_to_elbow_rotation = base_to_elbow_rotation @ T[:3, :3] + + return plane_normal, base_to_elbow_rotation, joint_angles + + def _process_all_solutions( + self, + ik_qpos_tensor: torch.Tensor, + qpos_seed: torch.Tensor, + valid_mask: torch.Tensor, + success_tensor: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Returns all valid IK solutions (optionally sorted). + + Args: + ik_qpos_tensor (torch.Tensor): The IK joint position tensor. + qpos_seed (torch.Tensor): The seed joint position tensor. + valid_mask (torch.Tensor): The mask indicating valid solutions. + success_tensor (torch.Tensor): The tensor indicating success of IK solutions. + + Returns: + torch.Tensor: The success tensor. + torch.Tensor: The IK solutions tensor (sorted if specified). + """ + if self.cfg.sort_ik: + weighted_diff = ( + ik_qpos_tensor - qpos_seed.unsqueeze(1) + ) * self.ik_nearest_weight_tensor + distances = torch.norm(weighted_diff, dim=2) + distances[~valid_mask] = float("inf") + sorted_indices = torch.argsort(distances, dim=1) + sorted_ik_qpos_tensor = torch.gather( + ik_qpos_tensor, 1, sorted_indices.unsqueeze(-1).expand(-1, -1, 7) + ) + return success_tensor, sorted_ik_qpos_tensor + else: + return success_tensor, ik_qpos_tensor + + def _process_single_solution( + self, + ik_qpos_tensor: torch.Tensor, + qpos_seed: torch.Tensor, + valid_mask: torch.Tensor, + success_tensor: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Returns the nearest valid IK solution (optionally sorted). + + Args: + ik_qpos_tensor (torch.Tensor): The IK joint position tensor. + qpos_seed (torch.Tensor): The seed joint position tensor. + valid_mask (torch.Tensor): The mask indicating valid solutions. + success_tensor (torch.Tensor): The tensor indicating success of IK solutions. + + Returns: + torch.Tensor: The success tensor. + torch.Tensor: The nearest valid IK solution tensor. + """ + num_targets = ik_qpos_tensor.shape[0] + if self.cfg.sort_ik: + weighted_diff = ( + ik_qpos_tensor - qpos_seed.unsqueeze(1) + ) * self.ik_nearest_weight_tensor + distances = torch.norm(weighted_diff, dim=2) + mask = success_tensor.unsqueeze(1) & valid_mask + distances[~mask] = float("inf") + nearest_indices = torch.argmin(distances, dim=1) + nearest_solutions = torch.zeros( + (num_targets, 7), dtype=qpos_seed.dtype, device=self.device + ) + has_solution = distances.min(dim=1).values != float("inf") + if has_solution.any(): + nearest_solutions[has_solution] = ik_qpos_tensor[ + torch.arange(num_targets)[has_solution], + nearest_indices[has_solution], + ] + return success_tensor, nearest_solutions.unsqueeze(1) + else: + # Return first solution only + return success_tensor, ik_qpos_tensor[:, :1, :] + + def _get_each_ik( + self, target_pose: np.ndarray, nsparam: float, config: np.ndarray + ) -> Tuple[bool, Optional[np.ndarray]]: + """ + Computes the inverse kinematics for a given target pose, normalization parameter, and configuration. + + Args: + target_pose (np.ndarray): 4x4 target pose matrix. + nsparam (float): Normalization parameter (angle). + config (np.ndarray): Configuration index. + + Returns: + bool: Success flag. + np.ndarray: List of joint solutions (7) or None if no solution is found. + """ + # Validate the target pose matrix + target_pose = np.array(target_pose) + if target_pose.ndim == 3 and target_pose.shape[0] == 1: + target_pose = target_pose[0] # Extract the first matrix + if target_pose.shape != (4, 4): + logger.log_error( + f"Invalid xpos shape: {target_pose.shape}, expected (4,4)." + ) + return False, None + + shoulder_config, elbow_config, wrist_config = config[0], config[1], config[2] + + dof = self.dofs + joints_output = np.zeros(dof) + + # Extract parameters + dh_params = self.dh_params + link_lengths = self.cfg.link_lengths + rotation_directions = self.cfg.rotation_directions + + # Transform target pose + target_xpos = ( + self.T_b_ob_inv_np @ target_pose @ self.tcp_inv_np @ self.T_e_oe_inv_np + ) + P_target = target_xpos[:3, 3] + R_target = target_xpos[:3, :3] + P02 = np.array([0, 0, link_lengths[0]]) # Base to shoulder + P67 = np.array([0, 0, dh_params[6, 0]]) # Hand to end-effector + P06 = P_target - R_target @ P67 + P26 = P06 - P02 + + # Calculate joint angles + joints = np.zeros(dof) + if not self._calculate_arm_joint_angles( + P26, elbow_config, joints, link_lengths + ): + return False, None + + # Calculate transformations + T34 = self._dh_transform( + dh_params[3, 0], dh_params[3, 1], dh_params[3, 2], joints[3] + ) + R34 = T34[:3, :3] + + # Calculate reference plane + V_v_to_sew, R03_o, joint_v = self._compute_reference_plane( + target_xpos, config[1] + ) + if V_v_to_sew is None: + return False, None + + # Calculate shoulder joint rotation matrices + usw = P26 / np.linalg.norm(P26) + skew_usw = self._skew(usw) + angle_psi = nsparam + s_psi = wp.sin(angle_psi) + c_psi = wp.cos(angle_psi) + + # Calculate rotation matrix R03 + A_s = skew_usw @ R03_o + B_s = -skew_usw @ skew_usw @ R03_o + C_s = (usw[:, None] @ usw[None, :]) @ R03_o + R03 = A_s * s_psi + B_s * c_psi + C_s + + # Calculate shoulder joint angles + angle1 = np.arctan2(R03[1, 1] * shoulder_config, R03[0, 1] * shoulder_config) + angle2 = np.arccos(R03[2, 1]) * shoulder_config + angle3 = np.arctan2(-R03[2, 2] * shoulder_config, -R03[2, 0] * shoulder_config) + + # Calculate wrist joint angles + A_w = R34.T @ A_s.T @ R_target + B_w = R34.T @ B_s.T @ R_target + C_w = R34.T @ C_s.T @ R_target + R47 = A_w * s_psi + B_w * c_psi + C_w + + angle5 = np.arctan2(R47[1, 2] * wrist_config, R47[0, 2] * wrist_config) + angle6 = np.arccos(R47[2, 2]) * wrist_config + angle7 = np.arctan2(R47[2, 1] * wrist_config, -R47[2, 0] * wrist_config) + + joints_output[0] = (angle1 - dh_params[0, 3]) * rotation_directions[0] + joints_output[1] = (angle2 - dh_params[1, 3]) * rotation_directions[1] + joints_output[2] = (angle3 - dh_params[2, 3]) * rotation_directions[2] + joints_output[3] = (joints[3] - dh_params[3, 3]) * rotation_directions[3] + joints_output[4] = (angle5 - dh_params[4, 3]) * rotation_directions[4] + joints_output[5] = (angle6 - dh_params[5, 3]) * rotation_directions[5] + joints_output[6] = (angle7 - dh_params[6, 3]) * rotation_directions[6] + + return True, joints_output + + def get_ik( + self, + target_xpos: torch.Tensor, + qpos_seed: torch.Tensor, + return_all_solutions: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute inverse kinematics (IK) for the given target pose using CPU. + + Args: + target_xpos: Target end-effector pose (4x4). + qpos_seed: Initial joint positions (rad). + return_all_solutions: Whether to return all solutions. Default is False. + kwargs: Additional keyword arguments. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Success flag and joint positions. + """ + num_targets = target_xpos.shape[0] + # Validate and normalize qpos_seed + if qpos_seed is None: + qpos_seed = torch.zeros( + (target_xpos.shape[0], 7), dtype=torch.float32, device=self.device + ) + + # Prepare to collect results + max_possible_solutions = len(self.elbow_angles) * len(self.configs) + all_solutions = np.zeros( + (num_targets, max_possible_solutions, 7), dtype=np.float32 + ) + solution_counts = np.zeros(num_targets, dtype=np.int32) + + # Iterate over target poses + for target_idx, xpos in enumerate(target_xpos): + sol_idx = 0 + for psi in self.elbow_angles: + for config in self.configs: + success, qpos = self._get_each_ik(xpos, psi.item(), config) + if success: + fk_xpos = self._get_fk(qpos) + if np.allclose(fk_xpos, xpos, atol=1e-4): + all_solutions[target_idx, sol_idx, :] = qpos + sol_idx += 1 + solution_counts[target_idx] = sol_idx + + has_solution = solution_counts > 0 + if not any(has_solution): + logger.log_warning( + f"Failed to calculate IK solutions.\n" + f"Target pose: {target_xpos}\nSeed: {qpos_seed}" + ) + return ( + torch.zeros(num_targets, dtype=torch.bool, device=self.device), + torch.zeros( + (num_targets, num_targets, 7), + dtype=qpos_seed.dtype, + device=self.device, + ), + ) + max_solutions = solution_counts.max() + + # Convert results to tensors + ik_qpos_tensor = torch.zeros( + (num_targets, max_solutions, 7), + dtype=qpos_seed.dtype, + device=self.device, + ) + for target_idx in range(num_targets): + count = solution_counts[target_idx] + if count > 0: + ik_qpos_tensor[target_idx, :count] = torch.from_numpy( + all_solutions[target_idx, :count] + ).to(self.device, dtype=qpos_seed.dtype) + + valid_mask = ik_qpos_tensor.abs().sum(dim=2) > 0 # (num_targets, max_solutions) + success_tensor = torch.from_numpy(has_solution).to(self.device) + if return_all_solutions: + return self._process_all_solutions( + ik_qpos_tensor, qpos_seed, valid_mask, success_tensor + ) + else: + return self._process_single_solution( + ik_qpos_tensor, qpos_seed, valid_mask, success_tensor + ) + + +class _CUDASRSSolverImpl(_BaseSRSSolverImpl): + """CUDA implementation of the SRS inverse kinematics solver.""" + + def __init__(self, cfg: SRSSolverCfg, device: torch.device): + super().__init__(cfg, device) + + def _parse_params(self): + super()._parse_params() + + # Convert numpy transformation matrices to Warp mat44 format for CUDA computation. + self.tcp_inv_wp = wp.mat44(*self.tcp_inv_np.flatten()) + self.T_b_ob_inv_wp = wp.mat44(*self.T_b_ob_inv_np.flatten()) + self.T_e_oe_inv_wp = wp.mat44(*self.T_e_oe_inv_np.flatten()) + + # Convert DH parameters, joint limits, link lengths, and rotation directions to Warp arrays. + self.dh_params_wp = wp.array( + self.dh_params_np.flatten(), + dtype=float, + device=standardize_device_string(self.device), + ) + self.qpos_limits_wp = wp.array( + self.qpos_limits_np, + dtype=wp.vec2, + device=standardize_device_string(self.device), + ) + self.link_lengths_wp = wp.array( + self.link_lengths_np.flatten(), + dtype=float, + device=standardize_device_string(self.device), + ) + self.rotation_directions_wp = wp.array( + self.rotation_directions_np.flatten(), + dtype=float, + device=standardize_device_string(self.device), + ) + + # Generate all possible configuration combinations for shoulder, elbow, and wrist. + # Each configuration is represented by a vector of three elements, each being +1 or -1. + # This covers all 8 possible sign combinations for the three joints. + self.configs = [wp.vec3(x, y, z) for x, y, z in product([1.0, -1.0], repeat=3)] + self.configs_wp = wp.array( + self.configs, dtype=wp.vec3, device=standardize_device_string(self.device) + ) + + # Generate a set of elbow angles sampled uniformly from -π to π. + # The number of samples is determined by self.cfg.num_samples. + # These angles are used for searching possible IK solutions. + joint_reference_limits = [-wp.pi, wp.pi] + self.elbow_angles = np.linspace( + joint_reference_limits[0], joint_reference_limits[1], self.cfg.num_samples + ).tolist() + + # Convert elbow angles to Warp array for CUDA computation. + self.elbow_angles_wp = wp.array( + self.elbow_angles, + dtype=float, + device=standardize_device_string(self.device), + ) + + def _sort_ik_solutions( + self, qpos_out_wp, success_wp, qpos_seed, num_targets, num_configs, num_angles + ): + """ + Sort IK solutions based on weighted distance. + + Args: + qpos_out_wp: Warp array of IK solutions (shape: [num_targets * num_configs * num_angles, 7]). + success_wp: Warp array of validity flags (shape: [num_targets * num_configs * num_angles]). + qpos_seed: Warp array of seed positions (shape: [num_targets, 7]). + num_targets: Number of targets. + num_configs: Number of configurations. + num_angles: Number of angles. + + Returns: + Tuple[wp.array, wp.array]: Sorted IK solutions and their validity flags. + """ + N = num_targets + N_SOL = num_configs * num_angles + DOF = 7 + + sorted_ik_solutions = wp.zeros( + N * N_SOL * DOF, dtype=float, device=standardize_device_string(self.device) + ) + sorted_ik_valid_flags = wp.zeros( + N * N_SOL, dtype=int, device=standardize_device_string(self.device) + ) + distances = wp.zeros( + N * N_SOL, dtype=float, device=standardize_device_string(self.device) + ) + indices = wp.zeros( + N * N_SOL, dtype=int, device=standardize_device_string(self.device) + ) + + wp.launch( + kernel=sort_ik_kernel, + dim=num_targets, + inputs=[ + qpos_out_wp, + success_wp, + qpos_seed, + wp.array( + self.cfg.ik_nearest_weight, + dtype=float, + device=standardize_device_string(self.device), + ), + distances, + indices, + N_SOL, + ], + outputs=[ + sorted_ik_solutions, + sorted_ik_valid_flags, + ], + device=standardize_device_string(self.device), + ) + return sorted_ik_solutions, sorted_ik_valid_flags + + def _nearest_ik_solution( + self, qpos_out_wp, success_wp, qpos_seed, num_targets, num_configs, num_angles + ): + """ + Find the nearest valid IK solution for each target pose. + + Selects the IK solution closest to the seed configuration among all valid solutions. + + Args: + qpos_out_wp: IK solutions array of shape [num_targets * num_configs * num_angles, 7] + success_wp: Validity flags array of shape [num_targets * num_configs * num_angles] + qpos_seed: Seed configurations array of shape [num_targets, 7] + num_targets: Number of target poses + num_configs: Number of IK configurations + num_angles: Number of sampling angles + + Returns: + Tuple[wp.array, wp.array]: + - Nearest IK solutions array of shape [num_targets, 7] + - Validity flags array of shape [num_targets] indicating solution feasibility + """ + N = num_targets + N_SOL = num_configs * num_angles + DOF = 7 + + nearest_ik_solutions = wp.zeros( + N * DOF, dtype=float, device=standardize_device_string(self.device) + ) + nearest_ik_valid_flags = wp.zeros( + N, dtype=int, device=standardize_device_string(self.device) + ) + + wp.launch( + kernel=nearest_ik_kernel, + dim=num_targets, + inputs=[ + qpos_out_wp, + success_wp, + qpos_seed.flatten(), + wp.array( + self.cfg.ik_nearest_weight, + dtype=float, + device=standardize_device_string(self.device), + ), + N_SOL, + ], + outputs=[ + nearest_ik_solutions, + nearest_ik_valid_flags, + ], + device=standardize_device_string(self.device), + ) + return nearest_ik_solutions, nearest_ik_valid_flags + + def _process_all_solutions( + self, + qpos_out_wp: wp.array, + success_wp: wp.array, + qpos_seed: wp.array, + num_targets: int, + num_configs: int, + num_angles: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Process and return all valid IK solutions. + + Args: + qpos_out_wp: Warp array of IK solutions. + success_wp: Warp array of success flags. + qpos_seed: Seed joint positions. + num_targets: Number of target poses. + num_configs: Number of configurations. + num_angles: Number of angles. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Success flags and all valid joint positions. + """ + num_per_target = num_configs * num_angles + + if self.cfg.sort_ik: + sorted_ik_solutions, sorted_ik_valid_flags = self._sort_ik_solutions( + qpos_out_wp, + success_wp, + qpos_seed.flatten(), + num_targets, + num_configs, + num_angles, + ) + + ik_solutions_tensor = wp.to_torch(sorted_ik_solutions).view( + num_targets, num_per_target, 7 + ) + ik_valid_flags_tensor = ( + wp.to_torch(sorted_ik_valid_flags) + .view(num_targets, num_per_target) + .bool() + ) + else: + ik_solutions_tensor = wp.to_torch(qpos_out_wp).view( + num_targets, num_per_target, 7 + ) + ik_valid_flags_tensor = ( + wp.to_torch(success_wp).view(num_targets, num_per_target).bool() + ) + + success_flags = ik_valid_flags_tensor.any(dim=1) + + valid_qpos_list = [ + ik_solutions_tensor[i][ik_valid_flags_tensor[i]] for i in range(num_targets) + ] + max_solutions = max(q.shape[0] for q in valid_qpos_list) + valid_qpos_tensor = torch.zeros( + (num_targets, max_solutions, 7), + dtype=torch.float32, + device=self.device, + ) + for i, q in enumerate(valid_qpos_list): + valid_qpos_tensor[i, : q.shape[0]] = q.to(self.device) + + return success_flags.to(self.device), valid_qpos_tensor + + def _process_single_solution( + self, + qpos_out_wp: wp.array, + success_wp: wp.array, + qpos_seed: wp.array, + num_targets: int, + num_configs: int, + num_angles: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Process and return the nearest valid IK solution for each target. + + Args: + qpos_out_wp: Warp array of IK solutions. + success_wp: Warp array of success flags. + qpos_seed: Seed joint positions. + num_targets: Number of target poses. + num_configs: Number of configurations. + num_angles: Number of angles. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Success flags and nearest valid joint positions. + """ + num_per_target = num_configs * num_angles + + if self.cfg.sort_ik: + nearest_ik_solutions, nearest_ik_valid_flags = self._nearest_ik_solution( + qpos_out_wp, + success_wp, + qpos_seed, + num_targets, + num_configs, + num_angles, + ) + + nearest_ik_solutions_tensor = wp.to_torch(nearest_ik_solutions).view( + num_targets, 7 + ) + nearest_ik_valid_flags_tensor = ( + wp.to_torch(nearest_ik_valid_flags).view(num_targets).bool() + ) + + first_valid_qpos = torch.zeros( + (num_targets, 1, 7), dtype=torch.float32, device=self.device + ) + for i in range(num_targets): + if nearest_ik_valid_flags_tensor[i]: + first_valid_qpos[i, 0] = nearest_ik_solutions_tensor[i].to( + self.device + ) + + return nearest_ik_valid_flags_tensor.to(self.device), first_valid_qpos + else: + ik_solutions_tensor = wp.to_torch(qpos_out_wp).view( + num_targets, num_per_target, 7 + ) + ik_valid_flags_tensor = ( + wp.to_torch(success_wp).view(num_targets, num_per_target).bool() + ) + + first_valid_qpos = torch.zeros( + (num_targets, 1, 7), dtype=torch.float32, device=self.device + ) + valid_flags = torch.zeros(num_targets, dtype=torch.bool, device=self.device) + for i in range(num_targets): + valid_indices = torch.where(ik_valid_flags_tensor[i])[0] + if len(valid_indices) > 0: + first_valid_qpos[i, 0] = ik_solutions_tensor[ + i, valid_indices[0] + ].to(self.device) + valid_flags[i] = True + + return valid_flags, first_valid_qpos + + def _check_success_flags( + self, + success_wp: wp.array, + num_targets: int, + num_configs: int, + num_angles: int, + ) -> torch.Tensor: + """ + Check success flags for IK solutions. + + Args: + success_wp: Warp array of success flags. + num_targets: Number of target poses. + num_configs: Number of configurations. + num_angles: Number of angles. + + Returns: + torch.Tensor: Success flags as a boolean tensor. + """ + num_solutions = num_configs * num_angles + success_flags_wp = wp.empty( + num_targets, dtype=int, device=standardize_device_string(self.device) + ) + wp.launch( + kernel=check_success_kernel, + dim=num_targets, + inputs=[ + success_wp, + num_solutions, + ], + outputs=[ + success_flags_wp, + ], + device=standardize_device_string(self.device), + ) + return wp.to_torch(success_flags_wp).bool().to(self.device) + + def _compute_ik_solutions( + self, + combinations_wp: wp.array, + xpos_wp: wp.array, + qpos_out_wp: wp.array, + success_wp: wp.array, + num_combinations: int, + ) -> None: + """ + Compute IK solutions using the provided combinations. + + Args: + combinations_wp: Warp array of combinations for parallel processing. + xpos_wp: Transformed target poses. + qpos_out_wp: Output array for joint positions. + success_wp: Output array for success flags. + num_combinations: Total number of combinations to process. + """ + # Temporary arrays + res_arm_angles = wp.zeros( + num_combinations, dtype=int, device=standardize_device_string(self.device) + ) + joints_arm = wp.zeros( + num_combinations, + dtype=wp.vec4, + device=standardize_device_string(self.device), + ) + res_plane_normal = wp.zeros( + num_combinations, dtype=int, device=standardize_device_string(self.device) + ) + plane_normal = wp.zeros( + num_combinations, + dtype=wp.vec3, + device=standardize_device_string(self.device), + ) + base_to_elbow_rotation = wp.zeros( + num_combinations, + dtype=wp.mat33, + device=standardize_device_string(self.device), + ) + joints_plane = wp.zeros( + num_combinations, + dtype=wp.vec4, + device=standardize_device_string(self.device), + ) + + # Launch kernel to compute IK solutions + wp.launch( + kernel=compute_ik_kernel, + dim=num_combinations, + inputs=( + combinations_wp, + xpos_wp, + self.elbow_angles_wp, + self.qpos_limits_wp, + self.configs_wp, + self.dh_params_wp, + self.link_lengths_wp, + self.rotation_directions_wp, + res_arm_angles, + joints_arm, + res_plane_normal, + plane_normal, + base_to_elbow_rotation, + joints_plane, + ), + outputs=[success_wp, qpos_out_wp], + device=standardize_device_string(self.device), + ) + + def get_ik( + self, + target_xpos: torch.Tensor, + qpos_seed: torch.Tensor, + return_all_solutions: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute inverse kinematics (IK) for the given target pose. + + Args: + target_xpos: Target end-effector pose (4x4). + qpos_seed: Initial joint positions (rad). + return_all_solutions: Whether to return all solutions. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Success flag and joint positions. + """ + # Prepare inputs + target_xpos = target_xpos.to(self.device) + target_xpos = target_xpos.view(-1, 4, 4) + target_xpos_wp = wp.from_torch(target_xpos, dtype=wp.mat44) + + # transform pose + xpos_wp = wp.zeros( + target_xpos_wp.shape[0], + dtype=wp.mat44, + device=standardize_device_string(self.device), + ) + wp.launch( + kernel=transform_pose_kernel, + dim=target_xpos_wp.shape[0], + inputs=[ + target_xpos_wp, + self.T_b_ob_inv_wp, + self.T_e_oe_inv_wp, + self.tcp_inv_wp, + ], + outputs=[xpos_wp], + device=standardize_device_string(self.device), + ) + + # Define configurations and angles + if qpos_seed is None: + qpos_seed = wp.zeros( + (target_xpos.shape[0], 7), + dtype=float, + device=standardize_device_string(self.device), + ) + # TODO: Currently, full-space sampling is used to temporarily address situations + # where joint space discontinuities or solution failures occur in different user scenarios. + # Future plans include reducing the sampling space and adjusting the configuration. + # + # self.configs = [wp.vec3(*np.sign(qpos_seed[[1, 3, 5]].cpu().numpy()))] + + # Prepare output arrays + num_targets = target_xpos_wp.shape[0] + num_configs = len(self.configs) + num_angles = len(self.elbow_angles) + # num_solutions = num_configs * num_angles + num_combinations = num_targets * num_configs * num_angles + + # Generate combinations for parallel processing + combinations_np = np.stack( + np.meshgrid( + np.arange(num_targets), + np.arange(num_configs), + np.arange(num_angles), + indexing="ij", + ), + axis=-1, + ).reshape(-1, 3) + combinations_wp = wp.array( + combinations_np, + dtype=wp.vec3, + device=standardize_device_string(self.device), + ) + + # Output arrays + qpos_out_wp = wp.zeros( + num_combinations * 7, + dtype=float, + device=standardize_device_string(self.device), + ) + success_wp = wp.zeros( + num_combinations, dtype=int, device=standardize_device_string(self.device) + ) + + # Compute IK solutions + self._compute_ik_solutions( + combinations_wp, xpos_wp, qpos_out_wp, success_wp, num_combinations + ) + + # Check for successful solutions + success_flags_tensor = self._check_success_flags( + success_wp, num_targets, num_configs, num_angles + ) + + if success_flags_tensor.any(): + if return_all_solutions: + return self._process_all_solutions( + qpos_out_wp, + success_wp, + qpos_seed, + num_targets, + num_configs, + num_angles, + ) + else: + return self._process_single_solution( + qpos_out_wp, + success_wp, + qpos_seed, + num_targets, + num_configs, + num_angles, + ) + else: + return ( + torch.zeros(num_targets, dtype=torch.bool, device=self.device), + torch.zeros( + (num_targets, num_targets, 7), + dtype=torch.float32, + device=self.device, + ), + ) + + +class SRSSolver(BaseSolver): + r"""SRS inverse kinematics (IK) controller. + + This controller implements SRS inverse kinematics using various methods for + computing the inverse of the Jacobian matrix. + """ + + def __init__(self, cfg: SRSSolverCfg, num_envs: int, device: str, **kwargs): + r"""Initializes the SRS kinematics solver. + + This constructor sets up the kinematics solver using SRS methods, + allowing for efficient computation of robot kinematics based on + the specified URDF model. + + Args: + cfg: The configuration for the solver. + num_envs (int): The number of environments for the solver. + device (str, optional): The device to use for the solver (e.g., "cpu" or "cuda"). + **kwargs: Additional keyword arguments passed to the base solver. + + """ + super().__init__(cfg=cfg, num_envs=num_envs, device=device, **kwargs) + + # Degrees of freedom + self.dofs = 7 + + # Tool Center Point (TCP) position + self.tcp_xpos = np.eye(4) + + # Compute root base transform + fk_dict = self.pk_serial_chain.forward_kinematics( + th=np.zeros(7), end_only=False + ) + root_tf = fk_dict[list(fk_dict.keys())[0]] + self.root_base_xpos = root_tf.get_matrix().cpu().numpy() + + # Initialize implementation based on device + if self.device.type == "cuda": + self.impl = _CUDASRSSolverImpl(cfg, self.device) + else: + self.impl = _CPUSRSSolverImpl(cfg, self.device) + + def get_ik( + self, + target_xpos: torch.Tensor, + qpos_seed: torch.Tensor = None, + return_all_solutions: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute inverse kinematics (IK) for the given target pose. + + Args: + target_xpos: Target end-effector pose (4x4). + qpos_seed: Initial joint positions (rad). Default is None. + return_all_solutions: Whether to return all solutions. Default is False. + kwargs: Additional keyword arguments. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Success flag and joint positions. + """ + return self.impl.get_ik( + target_xpos=target_xpos, + qpos_seed=qpos_seed, + return_all_solutions=return_all_solutions, + **kwargs, + ) diff --git a/embodichain/lab/sim/types.py b/embodichain/lab/sim/types.py new file mode 100644 index 00000000..3e080f4e --- /dev/null +++ b/embodichain/lab/sim/types.py @@ -0,0 +1,28 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import numpy as np +import torch + +from typing import Sequence, Union, Dict, Literal + + +Array = Union[torch.Tensor, np.ndarray, Sequence] +Device = Union[str, torch.device] + +EnvObs = Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]] + +EnvAction = Union[torch.Tensor, Dict[str, torch.Tensor]] diff --git a/embodichain/lab/sim/utility/__init__.py b/embodichain/lab/sim/utility/__init__.py new file mode 100644 index 00000000..94cbbc59 --- /dev/null +++ b/embodichain/lab/sim/utility/__init__.py @@ -0,0 +1,19 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from .sim_utils import * +from .mesh_utils import * +from .gizmo_utils import * diff --git a/embodichain/lab/sim/utility/action_utils.py b/embodichain/lab/sim/utility/action_utils.py new file mode 100644 index 00000000..2fc70cee --- /dev/null +++ b/embodichain/lab/sim/utility/action_utils.py @@ -0,0 +1,337 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import numpy as np +import torch +import warp as wp + +from typing import Tuple + +from embodichain.utils.utility import inv_transform +from embodichain.utils.warp import ( + trajectory_get_diff_kernel, + trajectory_interpolate_kernel, + trajectory_add_origin_kernel, + get_offset_qpos_kernel, + pairwise_distances, + cumsum_distances, + repeat_first_point, + interpolate_along_distance, +) +from embodichain.lab.sim.solvers.base_solver import BaseSolver +from embodichain.utils.device_utils import standardize_device_string + + +def compute_pose_offset_related_to_first(full_pose: torch.Tensor) -> torch.Tensor: + """Compute pose offset relative to the first pose. + + Args: + full_pose (torch.Tensor): The full pose tensor of shape (N, 4, 4). + + Returns: + torch.Tensor: The pose offset tensor of shape (N, 4, 4). + """ + inv_pose0_np = inv_transform(full_pose[0].to("cpu").numpy()) + inv_pose0 = torch.tensor(inv_pose0_np, device=full_pose.device) + inv_pose0_repeat = inv_pose0[None, :, :].repeat(full_pose.shape[0], 1, 1) + return torch.bmm(inv_pose0_repeat, full_pose) + + +def sort_and_padding_key_frame( + trajectory: np.ndarray, key_indices: np.ndarray, key_frames_batch: np.ndarray +) -> Tuple[np.ndarray, np.ndarray]: + """sort and padding key frames for warping trajectory + + Args: + trajectory (torch.Tensor): raw trajectory. [n_waypoint, dof] of float. + key_indices (torch.Tensor): key frame waypoint indices. [n_keyframe,] of int. + key_frames_batch (torch.Tensor): batch key frames. [n_batch, n_keyframe, dof] of float. + + Returns: + key_indices_ascending (np.ndarray): padded and sorted key frame indices. [n_keyframe_new,] of int. + key_frames_ascending (np.ndarray): padded and sorted batch key frames. [n_batch, n_keyframe_new, dof] of float. + """ + sort_ids = np.argsort(key_indices) + key_indices_ascending = key_indices[sort_ids] + key_frames_ascending = key_frames_batch[:, sort_ids, :] + n_batch = key_frames_batch.shape[0] + if key_indices_ascending[0] != 0: + key_indices_ascending = np.hstack([0, key_indices_ascending]) + padding_frame = trajectory[0][None, None, :].repeat(n_batch, axis=0) + key_frames_ascending = np.concatenate( + [padding_frame, key_frames_ascending], axis=1 + ) + if key_indices_ascending[-1] != trajectory.shape[0] - 1: + key_indices_ascending = np.hstack( + [key_indices_ascending, trajectory.shape[0] - 1] + ) + padding_frame = trajectory[trajectory.shape[0] - 1][None, None, :].repeat( + n_batch, axis=0 + ) + key_frames_ascending = np.concatenate( + [key_frames_ascending, padding_frame], axis=1 + ) + return key_indices_ascending, key_frames_ascending + + +def warp_trajectory_qpos( + trajectory: torch.Tensor, + key_indices: torch.Tensor, + key_frames_batch: torch.Tensor, + device: str = "cuda", +) -> torch.Tensor: + """warp trajectory + + Args: + trajectory (torch.Tensor): raw trajectory. [n_waypoint, dof] of float. + key_indices (torch.Tensor): key frame waypoint indices. [n_keyframe,] of int. + key_frames_batch (torch.Tensor): batch key frames. [n_batch, n_keyframe, dof] of float. + device (str, optional): torch tensor device. Defaults to "cuda". + + Returns: + torch.Tensor: warped trajectory. [n_batch, n_waypoint, dof] of float. + """ + # sort and pad key frames + trajectory_np = trajectory.to("cpu").numpy().astype(np.float32) + key_indices_np = key_indices.to("cpu").numpy().astype(np.int32) + key_frames_batch_np = key_frames_batch.to("cpu").numpy().astype(np.float32) + + key_indices_padded, key_frames_padded = sort_and_padding_key_frame( + trajectory_np, key_indices_np, key_frames_batch_np + ) + + # allocate cuda memory + n_batch = key_frames_padded.shape[0] + n_keyframe = key_indices_padded.shape[0] + n_waypoint, dof = trajectory_np.shape + wp_in_trajectory = wp.array( + trajectory_np.flatten(), dtype=float, device=standardize_device_string(device) + ) + out_trajectory = np.zeros((n_batch, n_waypoint, dof), dtype=np.float32) + wp_out_trajectory = wp.array( + out_trajectory.flatten(), dtype=float, device=standardize_device_string(device) + ) + wp_key_indices = wp.array( + key_indices_padded, dtype=int, device=standardize_device_string(device) + ) + wp_key_frames = wp.array( + key_frames_padded.flatten(), + dtype=float, + device=standardize_device_string(device), + ) + + # calcuate + wp.launch( + kernel=trajectory_get_diff_kernel, + dim=(n_batch, dof), + inputs=[ + wp_in_trajectory, + wp_key_indices, + wp_key_frames, + n_waypoint, + dof, + n_keyframe, + ], + outputs=[ + wp_out_trajectory, + ], + device=standardize_device_string(device), + ) + wp.launch( + kernel=trajectory_interpolate_kernel, + dim=(n_batch, n_waypoint, dof), + inputs=[wp_key_indices, n_waypoint, dof, n_keyframe], + outputs=[ + wp_out_trajectory, + ], + device=standardize_device_string(device), + ) + wp.launch( + kernel=trajectory_add_origin_kernel, + dim=(n_batch, n_waypoint, dof), + inputs=[wp_in_trajectory, n_waypoint, dof], + outputs=[ + wp_out_trajectory, + ], + device=standardize_device_string(device), + ) + warp_traj = ( + wp.to_torch(wp_out_trajectory) + .reshape(n_batch, n_waypoint, dof) + .to(torch.device(device)) + ) + return warp_traj + + +def get_trajectory_object_offset_qpos( + trajectory: torch.Tensor, + key_indices: torch.Tensor, + key_obj_indices: torch.Tensor, + obj_offset: torch.Tensor, + solver: BaseSolver, + base_xpos: torch.Tensor, + device=torch.device("cuda"), +): + """warp trajectory according to object pose offset + + Args: + trajectory (torch.Tensor): raw trajectory. [n_waypoint, dof] of float, joint positions. + key_indices (torch.Tensor): key frame waypoint indices. [n_keyframe,] of int. + key_obj_indices (torch.Tensor): key frame belong to which object index. [n_keyframe,] of int. + obj_offset (torch.Tensor): each object pose offset. [obj_num, n_batch, 4, 4] of float. + solver (BaseSolver): robot kinematic solver. + base_xpos (torch.Tensor): solver root link pose in world coordinate. [4, 4] of float. + device (str, optional): torch tensor device. Defaults to "cuda". + + Returns: + torch.Tensor: warped trajectory. [n_batch, n_waypoint, dof] of float. + """ + assert key_indices.shape[0] == key_obj_indices.shape[0] + dof = trajectory.shape[1] + key_qpos = trajectory[key_indices] # [n_keyframe, DOF] + n_batch = obj_offset.shape[1] # batch num, aws arena num + n_keyframe = key_qpos.shape[0] + key_xpos = solver.get_fk(key_qpos) # [n_keyframe, 4, 4] + + base_xpos_repeat = base_xpos[None, :, :].repeat(n_keyframe, 1, 1) + key_xpos = torch.bmm(base_xpos_repeat, key_xpos) + + base_xpos_inv_np = inv_transform(base_xpos.to("cpu").numpy()) + base_xpos_inv_wp = wp.mat44f(base_xpos_inv_np) + key_obj_indices_wp = wp.from_torch(key_obj_indices.reshape(-1)) + obj_offset_wp = wp.from_torch(obj_offset.reshape(-1)) + key_xpos_wp = wp.from_torch(key_xpos.reshape(-1)) + key_obj_offset_wp = wp.zeros( + n_batch * n_keyframe * 16, dtype=float, device=standardize_device_string(device) + ) + + wp.launch( + kernel=get_offset_qpos_kernel, + dim=(n_batch, n_keyframe), + inputs=[ + key_obj_indices_wp, + obj_offset_wp, + key_xpos_wp, + base_xpos_inv_wp, + n_batch, + n_keyframe, + ], + outputs=[ + key_obj_offset_wp, + ], + device=standardize_device_string(device), + ) + key_xpos_offset = wp.to_torch(key_obj_offset_wp).reshape(n_batch * n_keyframe, 4, 4) + key_qpos_batch = key_qpos[None, :, :].repeat(n_batch, 1, 1).reshape(-1, dof) + # for pytorch solver, ik use qpos seed but not joint seed + is_success, key_qpos_offset = solver.get_ik( + target_xpos=key_xpos_offset, + qpos_seed=key_qpos_batch, + ) + key_qpos_offset = key_qpos_offset.reshape(n_batch, n_keyframe, -1) + return is_success, key_qpos_offset + + +def interpolate_with_distance_warp( + trajectory: torch.Tensor, # expected shape [B, N, M], float or convertible to float + interp_num: int, # T + device=torch.device("cuda"), +) -> torch.Tensor: + """ + Resample a batch of trajectories of shape [B, N, M] into [B, T, M] by + piecewise-linear interpolation over cumulative Euclidean distance + along the N dimension, handling each batch independently. + + Args: + trajectory: Torch.Tensor of shape [B, N, M]. + interp_num: Target number of samples T. + device: Warp device string ('cpu', 'cuda', 'cuda:0', ...). + dtype: Working dtype (wp.float32 or wp.float64). Defaults to wp.float32. + + Returns: + Torch.Tensor of shape [B, T, M] with interpolated trajectories. + """ + # Flatten input trajectory for warp kernels (avoid multi-dimensional wp.array bugs) + trajectory_flat = trajectory.contiguous().to(device).view(-1) + points = wp.from_torch(trajectory_flat) + + B, N, M = trajectory.shape # original shape components + T = int(interp_num) + + if T < 0: + raise ValueError("`interp_num` must be non-negative.") + + # Handle degenerate T + out = ( + wp.empty( + (B * T * M,), dtype=wp.float32, device=standardize_device_string(device) + ) + if T > 0 + else wp.empty((0,), dtype=wp.float32, device=standardize_device_string(device)) + ) + + # Handle N < 2 + if N < 2: + if N == 1 and T > 0: + # Repeat the single point across T (kernel expects flattened arrays) + wp.launch( + kernel=repeat_first_point, + dim=B * T, + inputs=[points, out, B, T, M, N], + device=standardize_device_string(device), + ) + # N == 0 -> return empty (out already allocated) + interp_trajectory = ( + wp.to_torch(out).view(B, T, M) if T > 0 else wp.to_torch(out).view(B, 0, M) + ) + return interp_trajectory + + if T == 0: + return out # nothing to do + + # 1) pairwise distances along N + dists = wp.empty( + (B * (N - 1),), dtype=wp.float32, device=standardize_device_string(device) + ) + wp.launch( + kernel=pairwise_distances, + dim=B * (N - 1), + inputs=[points, dists, B, N, M], + device=standardize_device_string(device), + ) + + # 2) cumulative distances per batch + cumulative = wp.empty( + (B * N,), dtype=wp.float32, device=standardize_device_string(device) + ) + wp.launch( + kernel=cumsum_distances, + dim=B, + inputs=[dists, cumulative, B, N], + device=standardize_device_string(device), + ) + + # 3) interpolation per (b, t) + wp.launch( + kernel=interpolate_along_distance, + dim=B * T, + inputs=[points, cumulative, out, B, N, M, T], + device=standardize_device_string(device), + ) + + # wp.synchronize_device(device) + interp_trajectory = wp.to_torch(out).view(B, T, M) + return interp_trajectory diff --git a/embodichain/lab/sim/utility/gizmo_utils.py b/embodichain/lab/sim/utility/gizmo_utils.py new file mode 100644 index 00000000..5595a825 --- /dev/null +++ b/embodichain/lab/sim/utility/gizmo_utils.py @@ -0,0 +1,46 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +""" +Gizmo utility functions for EmbodiSim. + +This module provides utility functions for creating gizmo transform callbacks. +""" + +from typing import Callable +from dexsim.types import TransformMask + + +def create_gizmo_callback() -> Callable: + """Create a standard gizmo transform callback function. + + This callback handles basic translation and rotation operations for gizmo controls. + It applies transformations directly to the node when gizmo controls are manipulated. + + Returns: + Callable: A callback function that can be used with gizmo.node.set_flush_transform_callback() + """ + + def gizmo_transform_callback(node, translation, rotation, flag): + if node is not None: + if flag == (TransformMask.TRANSFORM_LOCAL | TransformMask.TRANSFORM_T): + # Handle translation changes + node.set_translation(translation) + elif flag == (TransformMask.TRANSFORM_LOCAL | TransformMask.TRANSFORM_R): + # Handle rotation changes + node.set_rotation_rpy(rotation) + + return gizmo_transform_callback diff --git a/embodichain/lab/sim/utility/import_utils.py b/embodichain/lab/sim/utility/import_utils.py new file mode 100644 index 00000000..0630a48e --- /dev/null +++ b/embodichain/lab/sim/utility/import_utils.py @@ -0,0 +1,125 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from embodichain.utils import logger + + +def lazy_import_pytorch_kinematics(): + """ + Lazily import pytorch_kinematics and return the module. + + Returns: + module: The pytorch_kinematics module if available. + + Raises: + ImportError: If the module is not installed. + """ + try: + import pytorch_kinematics as pk + + return pk + except ImportError as e: + logger.log_warning( + "pytorch_kinematics not installed. Install with `pip install pytorch_kinematics==0.7.5`" + ) + raise e + + +def lazy_import_pinocchio(): + """ + Lazily import pinocchio and return the module. + + Returns: + module: The pinocchio module if available. + + Raises: + ImportError: If the module is not installed. + """ + try: + import pinocchio as pin + + return pin + except ImportError as e: + logger.log_warning( + "pinocchio not installed. Install with `conda install pinocchio==3.1.0 -c conda-forge`" + ) + raise e + + +def lazy_import_casadi(): + """ + Lazily import casadi and return the module. + + Returns: + module: The casadi module if available. + + Raises: + ImportError: If the module is not installed. + """ + try: + import casadi + + return casadi + except ImportError as e: + logger.log_warning( + "casadi not installed. Install with `pip install casadi==3.6.7`" + ) + raise e + + +def lazy_import_pinocchio_casadi(): + """ + Lazily import pinocchio.casadi and return the module. + + Returns: + module: The pinocchio.casadi module if available. + + Raises: + ImportError: If the module is not installed. + """ + try: + from pinocchio import casadi as cpin + + return cpin + except ImportError as e: + logger.log_warning( + f"Failed to import pinocchio.casadi: {e}. Install with `conda install pinocchio-casadi -c conda-forge` first." + ) + raise e + + +def lazy_import_pink(): + """ + Lazily import pin-pink and return its components. + + Returns: + tuple: The solve_ik, Configuration, and FrameTask components. + + Raises: + ImportError: If the module is not installed. + """ + try: + from pink import solve_ik + from pink.configuration import Configuration + from pink.tasks import FrameTask + import pink + + return pink + except ImportError as e: + logger.log_warning( + "Failed to import 'pin-pink'. Please install it using `pip install pin-pink==3.4.0`." + ) + raise ImportError("pin-pink is required but not installed.") from e diff --git a/embodichain/lab/sim/utility/io_utils.py b/embodichain/lab/sim/utility/io_utils.py new file mode 100644 index 00000000..b4d7fbd6 --- /dev/null +++ b/embodichain/lab/sim/utility/io_utils.py @@ -0,0 +1,26 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from os import devnull +from contextlib import contextmanager, redirect_stderr, redirect_stdout + + +@contextmanager +def suppress_stdout_stderr(): + """A context manager that redirects stdout and stderr to devnull""" + with open(devnull, "w") as fnull: + with redirect_stderr(fnull) as err, redirect_stdout(fnull) as out: + yield (err, out) diff --git a/embodichain/lab/sim/utility/mesh_utils.py b/embodichain/lab/sim/utility/mesh_utils.py new file mode 100644 index 00000000..4a30bd7c --- /dev/null +++ b/embodichain/lab/sim/utility/mesh_utils.py @@ -0,0 +1,208 @@ +import os +import dexsim.engine +import numpy as np +import open3d as o3d +import trimesh +from typing import Tuple, List, Dict, Any, Optional, Union + +import dexsim +from embodichain.utils import logger +from embodichain.data import get_data_path + + +def process_meshes( + mesh_config: Union[List[Dict], Dict], processor_config: Dict = None +) -> List[str]: + r"""Process a list of mesh files using the specified processor configuration. + + Args: + mesh_config (list): A list of dictionaries containing mesh file paths. + processor_config (dict): A dictionary containing the processor configuration. + + Returns: + list: A list of processed mesh file paths. + """ + from embodichain.toolkits.processor.function.mesh_processor import ( + build_mesh_processors, + ) + from embodichain.toolkits.processor.component import TriangleComponent + from embodichain.toolkits.processor.entity import MeshEntity + + processors, replace = None, False + if processor_config is not None: + if "replace" in processor_config: + replace = processor_config.pop("replace") + processors = build_mesh_processors(processor_config) + + if isinstance(mesh_config, dict): + mesh_config_list = list(mesh_config.values()) + else: + mesh_config_list = mesh_config + batch_meshes, batch_index = [], [] + for idx, config in enumerate(mesh_config_list): + if "mesh_file" not in config and "mesh_path" not in config: + logger.log_error("Config must contain 'mesh_file' and 'mesh_path' keys.") + key = "mesh_file" if "mesh_file" in config else "mesh_path" + mesh_fpath = config[key] + mesh_fpath = get_data_path(mesh_fpath) + if not os.path.exists(mesh_fpath): + logger.log_error(f"Mesh file not found at path: {mesh_fpath}") + config[key] = mesh_fpath + save_fpath = ( + os.path.dirname(config[key]) + + "/mesh_processed_" + + os.path.basename(config[key]) + ) + + if processors is None and "mesh_processor" not in config: + # No processors specified, so just return + continue + elif os.path.exists(save_fpath) and not replace: + config[key] = save_fpath + continue + elif "mesh_processor" in config: + # Process the mesh file with the specified processor + mesh_processor = build_mesh_processors(config["mesh_processor"]) + tri_component = TriangleComponent.from_fpath(mesh_fpath) + mesh_entity = MeshEntity("mesh", tri_component) + mesh = mesh_processor.apply([mesh_entity])[0] + mesh.save_mesh(save_fpath) + # Update the mesh file path in the config + config[key] = save_fpath + else: + tri_component = TriangleComponent.from_fpath(mesh_fpath) + mesh_entity = MeshEntity("mesh", tri_component) + batch_meshes.append(mesh_entity) + batch_index.append(idx) + + # Process the batch of meshes with the default processors + if batch_meshes and processors is not None: + meshes = processors.apply(batch_meshes) + for idx, config in enumerate(mesh_config_list): + if idx in batch_index: + save_fpath = ( + os.path.dirname(config[key]) + + "/mesh_processed_" + + os.path.basename(config[key]) + ) + meshes[batch_index.index(idx)].save_mesh(save_fpath) + config[key] = save_fpath + if isinstance(mesh_config, dict): + mesh_config = {k: v for k, v in zip(mesh_config.keys(), mesh_config_list)} + return mesh_config + + +def export_articulation_mesh( + articulation: Union[dexsim.engine.Articulation, list], + output_path: str = "./articulation.obj", + link_names: Optional[Union[List[str], Dict[Any, List[str]]]] = None, + base_xpos: Optional[np.ndarray] = None, + base_link_name: Optional[str] = None, + **kwargs: Any, +) -> o3d.geometry.TriangleMesh: + r"""Export a combined mesh from all links of one or more articulations to a mesh file format. + + This function retrieves the link geometries and poses from the given articulation(s), + transforms each link mesh to its world pose, merges them into a single mesh, and + exports the result to the specified file path. The export format is inferred from + the file extension (e.g., .obj, .ply, .stl, .glb, .gltf). + + Args: + articulation (dexsim.engine.Articulation or list): The articulation object or list of articulations. + output_path (str): The output file path including the file name and extension. + Supported extensions: .obj, .ply, .stl, .glb, .gltf. + link_names (list[str] or dict[Any, list[str]], optional): + Specify which links to export. If None, export all links. + base_xpos (np.ndarray, optional): 4x4 homogeneous transformation matrix. + All meshes will be transformed into this base pose coordinate system. + base_link_name (str, optional): If specified, use the pose of this link as the base pose. + The link will be searched from all link_names of all articulations. + + Returns: + o3d.geometry.TriangleMesh: The combined Open3D mesh object of all articulations. + """ + output_path = os.path.abspath(output_path) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + combined_mesh = o3d.geometry.TriangleMesh() + articulations = ( + articulation if isinstance(articulation, (list, tuple)) else [articulation] + ) + + # Determine base transform: priority base_xpos > base_link_name > identity + base_inv = None + if base_xpos is not None: + base_inv = np.linalg.inv(base_xpos) + elif base_link_name is not None: + # Search base_link_name from all link_names of all articulations + found = False + for art in articulations: + # Get all possible link names for this articulation + if link_names is None: + cur_link_names = art.get_link_names() + elif isinstance(link_names, dict): + cur_link_names = link_names.get(art, art.get_link_names()) + else: + cur_link_names = link_names + if base_link_name in cur_link_names: + base_pose = art.get_link_pose(base_link_name) + base_inv = np.linalg.inv(base_pose) + found = True + break + if not found: + logger.log_warning( + f"base_link_name '{base_link_name}' not found in any articulation, using identity." + ) + base_inv = np.eye(4) + else: + base_inv = np.eye(4) + + for art in articulations: + if link_names is None: + cur_link_names = art.get_link_names() + elif isinstance(link_names, dict): + cur_link_names = link_names.get(art, art.get_link_names()) + else: + cur_link_names = link_names + + link_poses = [art.get_link_pose(name) for name in cur_link_names] + + for i, link_name in enumerate(cur_link_names): + verts, faces = art.get_link_vert_face(link_name) + logger.log_debug( + f"Link '{link_name}' has {verts.shape[0]} vertices, {verts.shape[1]} faces." + ) + if verts.shape[0] == 0: + continue + + mesh = o3d.geometry.TriangleMesh( + o3d.utility.Vector3dVector(verts), o3d.utility.Vector3iVector(faces) + ) + mesh.compute_vertex_normals() + mesh.transform(link_poses[i]) + mesh.transform(base_inv) + combined_mesh += mesh + + combined_mesh.compute_vertex_normals() + + ext = os.path.splitext(output_path)[1].lower() + + if ext in [".obj", ".ply", ".stl"]: + o3d.io.write_triangle_mesh(output_path, combined_mesh) + logger.log_info(f"Mesh exported using Open3D to: {output_path}") + + elif ext in [".glb", ".gltf"]: + mesh_trimesh = trimesh.Trimesh( + vertices=np.asarray(combined_mesh.vertices), + faces=np.asarray(combined_mesh.triangles), + vertex_normals=np.asarray(combined_mesh.vertex_normals), + ) + mesh_trimesh.export(output_path) + logger.log_info(f"Mesh exported using trimesh to: {output_path}") + + else: + raise ValueError( + f"Unsupported file format: '{ext}'. Supported: obj, ply, stl, glb, gltf" + ) + + return combined_mesh diff --git a/embodichain/lab/sim/utility/sim_utils.py b/embodichain/lab/sim/utility/sim_utils.py new file mode 100644 index 00000000..adcc53e4 --- /dev/null +++ b/embodichain/lab/sim/utility/sim_utils.py @@ -0,0 +1,286 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os +import dexsim +import open3d as o3d + +from typing import List, Union, Optional + +from dexsim.types import DriveType, ArticulationFlag, LoadOption, RigidBodyShape +from dexsim.engine import Articulation +from dexsim.environment import Env, Arena +from dexsim.models import MeshObject + +from embodichain.lab.sim.cfg import ArticulationCfg, RigidObjectCfg, SoftObjectCfg +from embodichain.lab.sim.shapes import MeshCfg, CubeCfg, SphereCfg +from embodichain.utils import logger +from dexsim.kit.meshproc import get_mesh_auto_uv +import numpy as np + + +def get_dexsim_arenas() -> List[dexsim.environment.Arena]: + """Get all arenas in the default dexsim world. + + Returns: + List[dexsim.environment.Arena]: A list of arenas in the default world, or an empty list if no world is found. + """ + world = dexsim.default_world() + if world is None: + logger.log_warning(f"No default world found. Returning empty arena list.") + return [] + + env = world.get_env() + arenas = env.get_all_arenas() + if len(arenas) == 0: + return [env] + return arenas + + +def get_dexsim_arena_num() -> int: + """Get the number of arenas in the default dexsim world. + + Returns: + int: The number of arenas in the default world, or 0 if no world is found. + """ + arenas = get_dexsim_arenas() + return len(arenas) + + +def get_dexsim_drive_type(drive_type: str) -> DriveType: + """Get the dexsim drive type from a string. + + Args: + drive_type (str): The drive type as a string. + + Returns: + DriveType: The corresponding DriveType enum. + """ + if drive_type == "force": + return DriveType.FORCE + elif drive_type == "acceleration": + return DriveType.ACCELERATION + else: + logger.error(f"Invalid dexsim drive type: {drive_type}") + + +def set_dexsim_articulation_cfg(arts: List[Articulation], cfg: ArticulationCfg) -> None: + """Set articulation configuration for a list of dexsim articulations. + + Args: + arts (List[Articulation]): List of dexsim articulations to configure. + cfg (ArticulationCfg): Configuration object containing articulation settings. + """ + + def get_drive_type(drive_pros): + if isinstance(drive_pros, dict): + return drive_pros.get("drive_type", None) + return getattr(drive_pros, "drive_type", None) + + drive_pros = getattr(cfg, "drive_pros", None) + drive_type = get_drive_type(drive_pros) if drive_pros is not None else None + + if drive_type == "force": + drive_type = DriveType.FORCE + elif drive_type == "target": + drive_type == DriveType.FORCE + else: + logger.log_error(f"Unknow drive type {drive_type}") + + for art in arts: + art.set_physical_attr(cfg.attrs.attr()) + art.set_articulation_flag(ArticulationFlag.FIX_BASE, cfg.fix_base) + art.set_articulation_flag( + ArticulationFlag.DISABLE_SELF_COLLISION, cfg.disable_self_collision + ) + art.set_solver_iteration_counts( + min_position_iters=cfg.min_position_iters, + min_velocity_iters=cfg.min_velocity_iters, + ) + link_names = art.get_link_names() + for name in link_names: + physical_body = art.get_physical_body(name) + inertia = physical_body.get_mass_space_inertia_tensor() + inertia = np.maximum(inertia, 1e-4) + physical_body.set_mass_space_inertia_tensor(inertia) + + +def is_rt_enabled() -> bool: + """Check if Ray Tracing rendering backend is enabled in the default dexsim world. + + Returns: + bool: True if Ray Tracing rendering is enabled, False otherwise. + """ + config = dexsim.get_world_config() + + return config.renderer == dexsim.types.Renderer.FASTRT + + +def create_cube( + envs: List[Union[Env, Arena]], size: List[float], uid: str = "cube" +) -> List[MeshObject]: + """Create cube objects in the specified environments or arenas. + + Args: + envs (List[Union[Env, Arena]]): List of environments or arenas to create cubes in. + size (List[float]): Size of the cube as [length, width, height] in meters. + uid (str, optional): Unique identifier for the cube objects. Defaults to "cube". + + Returns: + List[MeshObject]: List of created cube mesh objects. + """ + cubes = [] + for i, env in enumerate(envs): + cube = env.create_cube(size[0], size[1], size[2]) + cube.set_name(f"{uid}_{i}") + cubes.append(cube) + return cubes + + +def create_sphere( + envs: List[Union[Env, Arena]], + radius: float, + resolution: int = 20, + uid: str = "sphere", +) -> List[MeshObject]: + """Create sphere objects in the specified environments or arenas. + + Args: + envs (List[Union[Env, Arena]]): List of environments or arenas to create spheres in. + radius (float): Radius of the sphere in meters. + resolution (int, optional): Resolution of the sphere mesh. Defaults to 20. + uid (str, optional): Unique identifier for the sphere objects. Defaults to "sphere". + + Returns: + List[MeshObject]: List of created sphere mesh objects. + """ + spheres = [] + for i, env in enumerate(envs): + sphere = env.create_sphere(radius, resolution) + sphere.set_name(f"{uid}_{i}") + spheres.append(sphere) + return spheres + + +def load_mesh_objects_from_cfg( + cfg: RigidObjectCfg, env_list: List[Arena], cache_dir: Optional[str] = None +) -> List[MeshObject]: + """Load mesh objects from configuration. + + Args: + cfg (RigidObjectCfg): Configuration for the rigid object. + env_list (List[Arena]): List of arenas to load the objects into. + + cache_dir (Optional[str], optional): Directory for caching convex decomposition files. Defaults to None + Returns: + List[MeshObject]: List of loaded mesh objects. + """ + obj_list = [] + body_type = cfg.to_dexsim_body_type() + if isinstance(cfg.shape, MeshCfg): + + option = LoadOption() + option.rebuild_normals = cfg.shape.load_option.rebuild_normals + option.rebuild_tangent = cfg.shape.load_option.rebuild_tangent + option.rebuild_3rdnormal = cfg.shape.load_option.rebuild_3rdnormal + option.rebuild_3rdtangent = cfg.shape.load_option.rebuild_3rdtangent + option.smooth = cfg.shape.load_option.smooth + + cfg: RigidObjectCfg + max_convex_hull_num = cfg.max_convex_hull_num + fpath = cfg.shape.fpath + + compute_uv = cfg.shape.compute_uv + + for i, env in enumerate(env_list): + if max_convex_hull_num > 1: + obj = env.load_actor_with_coacd( + fpath, + duplicate=True, + attach_scene=True, + option=option, + cache_path=cache_dir, + actor_type=body_type, + max_convex_hull_num=max_convex_hull_num, + ) + else: + obj = env.load_actor( + fpath, duplicate=True, attach_scene=True, option=option + ) + obj.add_rigidbody(body_type, RigidBodyShape.CONVEX) + obj.set_name(f"{cfg.uid}_{i}") + obj_list.append(obj) + + if compute_uv: + vertices = obj.get_vertices() + triangles = obj.get_triangles() + + o3d_mesh = o3d.t.geometry.TriangleMesh(vertices, triangles) + _, uvs = get_mesh_auto_uv( + o3d_mesh, np.array(cfg.shape.project_direction) + ) + obj.set_uv_mapping(uvs) + + elif isinstance(cfg.shape, CubeCfg): + from embodichain.lab.sim.utility.sim_utils import create_cube + + obj_list = create_cube(env_list, cfg.shape.size, uid=cfg.uid) + for obj in obj_list: + obj.add_rigidbody(body_type, RigidBodyShape.BOX) + + elif isinstance(cfg.shape, SphereCfg): + from embodichain.lab.sim.utility.sim_utils import create_sphere + + obj_list = create_sphere( + env_list, cfg.shape.radius, cfg.shape.resolution, uid=cfg.uid + ) + for obj in obj_list: + obj.add_rigidbody(body_type, RigidBodyShape.SPHERE) + else: + logger.log_error( + f"Unsupported rigid object shape type: {type(cfg.shape)}. Supported types: MeshCfg, CubeCfg, SphereCfg." + ) + return obj_list + + +def load_soft_object_from_cfg( + cfg: SoftObjectCfg, env_list: List[Arena] +) -> List[MeshObject]: + obj_list = [] + + option = LoadOption() + option.rebuild_normals = cfg.shape.load_option.rebuild_normals + option.rebuild_tangent = cfg.shape.load_option.rebuild_tangent + option.rebuild_3rdnormal = cfg.shape.load_option.rebuild_3rdnormal + option.rebuild_3rdtangent = cfg.shape.load_option.rebuild_3rdtangent + option.smooth = cfg.shape.load_option.smooth + option.share_mesh = False + + for i, env in enumerate(env_list): + obj = env.load_actor( + fpath=cfg.shape.fpath, duplicate=True, attach_scene=True, option=option + ) + obj.add_softbody(cfg.voxel_attr.attr(), cfg.physical_attr.attr()) + if cfg.shape.compute_uv: + vertices = obj.get_vertices() + triangles = obj.get_triangles() + + o3d_mesh = o3d.t.geometry.TriangleMesh(vertices, triangles) + _, uvs = get_mesh_auto_uv(o3d_mesh, cfg.shape.project_direction) + obj.set_uv_mapping(uvs) + obj.set_name(f"{cfg.uid}_{i}") + obj_list.append(obj) + return obj_list diff --git a/embodichain/lab/sim/utility/solver_utils.py b/embodichain/lab/sim/utility/solver_utils.py new file mode 100644 index 00000000..976462f2 --- /dev/null +++ b/embodichain/lab/sim/utility/solver_utils.py @@ -0,0 +1,111 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch +from embodichain.lab.sim.utility.io_utils import suppress_stdout_stderr + +from typing import Optional, Union, Tuple, Any, TYPE_CHECKING +from copy import deepcopy + +from embodichain.utils import configclass, logger + +if TYPE_CHECKING: + from typing import Self + +from embodichain.lab.sim.utility.import_utils import ( + lazy_import_pytorch_kinematics, +) + + +def create_pk_chain( + urdf_path: str, + device: torch.device, + **kwargs, +) -> "pk.SerialChain": + """ + Factory method to create a pk.SerialChain object from a URDF file. + + Args: + urdf_path (str): Path to the URDF file. + end_link_name (str): Name of the end-effector link. + root_link_name (Optional[str]): Name of the root link. If None, the chain starts from the base. + device (torch.device): The device to which the chain will be moved. + is_serial (bool): Whether the chain is serial or not. + + Returns: + pk.SerialChain: The created serial chain object. + """ + pk = lazy_import_pytorch_kinematics() + with open(urdf_path, "rb") as f: + urdf_str = f.read() + + with suppress_stdout_stderr(): + return pk.build_chain_from_urdf(urdf_str).to(device=device) + + +def create_pk_serial_chain( + urdf_path: str = None, + device: torch.device = None, + end_link_name: str = None, + root_link_name: Optional[Union[str, None]] = None, + chain: Optional["pk.SerialChain"] = None, + **kwargs, +) -> "pk.SerialChain": + """ + Factory method to create a pk.SerialChain object from a URDF file. + + Args: + urdf_path (str): Path to the URDF file. + end_link_name (str): Name of the end-effector link. + root_link_name (Optional[str]): Name of the root link. If None, the chain starts from the base. + device (torch.device): The device to which the chain will be moved. + is_serial (bool): Whether the chain is serial or not. + + Returns: + pk.SerialChain: The created serial chain object. + """ + if urdf_path is None and chain is None: + raise ValueError("Either `urdf_path` or `chain` must be provided.") + if urdf_path and chain: + raise ValueError("`urdf_path` and `chain` cannot be provided at the same time.") + + pk = lazy_import_pytorch_kinematics() + + if chain is None: + try: + with open(urdf_path, "rb") as f: + urdf_str = f.read() + except FileNotFoundError: + raise ValueError(f"URDF file not found at path: {urdf_path}") + except IOError as e: + raise ValueError(f"Failed to read URDF file: {e}") + + with suppress_stdout_stderr(): + if root_link_name is None: + return pk.build_serial_chain_from_urdf( + urdf_str, + end_link_name=end_link_name, + ).to(device=device) + else: + return pk.build_serial_chain_from_urdf( + urdf_str, + end_link_name=end_link_name, + root_link_name=root_link_name, + ).to(device=device) + else: + return pk.SerialChain( + chain=chain, end_frame_name=end_link_name, root_frame_name=root_link_name + ) diff --git a/embodichain/lab/sim/utility/tensor.py b/embodichain/lab/sim/utility/tensor.py new file mode 100644 index 00000000..3b8553cf --- /dev/null +++ b/embodichain/lab/sim/utility/tensor.py @@ -0,0 +1,55 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch +import numpy as np + +from typing import Union, Optional + + +def to_tensor( + arr: Union[torch.Tensor, np.ndarray, list], + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, +) -> torch.Tensor: + """Convert input to torch.Tensor with specified dtype and device. + + Supports torch.Tensor, np.ndarray, and list. + + Args: + arr (Union[torch.Tensor, np.ndarray, list]): Input array. + dtype (torch.dtype, optional): Desired tensor dtype. Defaults to torch.float32. + device (torch.device, optional): Desired device. If None, uses current device. + + Returns: + torch.Tensor: Converted tensor. + """ + if isinstance(arr, torch.Tensor): + return arr.to(dtype=dtype, device=device) if device else arr.to(dtype=dtype) + elif isinstance(arr, np.ndarray): + return ( + torch.from_numpy(arr).to(dtype=dtype, device=device) + if device + else torch.from_numpy(arr).to(dtype=dtype) + ) + elif isinstance(arr, list): + return ( + torch.tensor(arr, dtype=dtype, device=device) + if device + else torch.tensor(arr, dtype=dtype) + ) + else: + raise TypeError("Input must be a torch.Tensor, np.ndarray, or list.") diff --git a/embodichain/toolkits/__init__.py b/embodichain/toolkits/__init__.py new file mode 100644 index 00000000..e4655620 --- /dev/null +++ b/embodichain/toolkits/__init__.py @@ -0,0 +1,15 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- diff --git a/embodichain/toolkits/graspkit/pg_grasp/__init__.py b/embodichain/toolkits/graspkit/pg_grasp/__init__.py new file mode 100644 index 00000000..4409079f --- /dev/null +++ b/embodichain/toolkits/graspkit/pg_grasp/__init__.py @@ -0,0 +1,21 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from .antipodal import AntipodalGenerator, GraspSelectMethod + +__all__ = ["AntipodalGenerator", "GraspSelectMethod"] + +del antipodal diff --git a/embodichain/toolkits/graspkit/pg_grasp/antipodal.py b/embodichain/toolkits/graspkit/pg_grasp/antipodal.py new file mode 100644 index 00000000..56239749 --- /dev/null +++ b/embodichain/toolkits/graspkit/pg_grasp/antipodal.py @@ -0,0 +1,671 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import open3d as o3d +import numpy as np +import time +import pathlib +import pickle +import os + +from enum import Enum +from copy import deepcopy +from typing import List + +from .cone_sampler import ConeSampler +from embodichain.utils.utility import get_mesh_md5 +from embodichain.utils import logger + + +class GraspSelectMethod(Enum): + NORMAL_SCORE = 0 + NEAR_APPROACH = 1 + CENTER = 2 + + +class AntipodalGrasp: + def __init__(self, pose: np.ndarray, open_len: float, score: float) -> None: + self.pose = pose # [4, 4] of float grasp pose + self.open_len = open_len # gripper open length + self.score = score # grasp pose score + + def grasp_pose_visual_mesh(self, gripper_open_length: float = None): + if gripper_open_length is None: + gripper_open_length = self.open_len + open_ratio = 1.0 + else: + open_ratio = self.open_len / gripper_open_length + open_ratio = max(1e-4, open_ratio) + gripper_finger = o3d.geometry.TriangleMesh.create_box( + gripper_open_length * 0.04, + gripper_open_length * 0.04, + gripper_open_length * 0.5, + ) + gripper_finger.translate( + np.array( + [ + -gripper_open_length * 0.02, + -gripper_open_length * 0.02, + -gripper_open_length * 0.25, + ] + ) + ) + gripper_left = deepcopy(gripper_finger) + gripper_left = gripper_left.translate( + np.array( + [ + -gripper_open_length * open_ratio * 0.5, + 0, + -gripper_open_length * 0.25, + ] + ) + ) + + gripper_right = deepcopy(gripper_finger) + gripper_right = gripper_right.translate( + np.array( + [gripper_open_length * open_ratio * 0.5, 0, -gripper_open_length * 0.25] + ) + ) + + gripper_root1 = o3d.geometry.TriangleMesh.create_box( + gripper_open_length * open_ratio, + gripper_open_length * 0.04, + gripper_open_length * 0.04, + ) + gripper_root1.translate( + np.array( + [ + gripper_open_length * open_ratio * -0.5, + gripper_open_length * -0.02, + gripper_open_length * -0.02, + ] + ) + ) + gripper_root1.translate( + np.array( + [ + 0, + 0, + gripper_open_length * -0.5, + ] + ) + ) + + gripper_root2 = o3d.geometry.TriangleMesh.create_box( + gripper_open_length * 0.04, + gripper_open_length * 0.04, + gripper_open_length * 0.5, + ) + gripper_root2.translate( + np.array( + [ + gripper_open_length * -0.02, + gripper_open_length * -0.02, + gripper_open_length * -0.25, + ] + ) + ) + gripper_root2.translate( + np.array( + [ + 0, + 0, + gripper_open_length * -0.75, + ] + ) + ) + + gripper_visual = gripper_left + gripper_right + gripper_root1 + gripper_root2 + gripper_visual.compute_vertex_normals() + return gripper_visual + + +class Antipodal: + def __init__( + self, + point_a: np.ndarray, + point_b: np.ndarray, + normal_a: np.ndarray, + normal_b: np.ndarray, + ) -> None: + """antipodal contact pair + + Args: + point_a (np.ndarray): position in point a + point_b (np.ndarray): position in point b + normal_a (np.ndarray): normal in point a + normal_b (np.ndarray): normal in point b + """ + self.point_a = point_a + self.point_b = point_b + self.normal_a = normal_a + self.normal_b = normal_b + self.dis = np.linalg.norm(point_a - point_b) + self.angle_cos = self.normal_a.dot(-self.normal_b) + self.score = self._get_score() + self._canonical_pose = self._get_canonical_pose() + + def _get_score(self): + # TODO: only normal angle is taken into account + return self.angle_cos + + def get_dis(self, another) -> float: + """get distance acoording to another antipodal + + Args: + other (Antipodal): another antipodal + + Returns: + float: distance + """ + aa_dis = np.linalg.norm(self.point_a - another.point_a) + bb_dis = np.linalg.norm(self.point_b - another.point_b) + ab_dis = np.linalg.norm(self.point_a - another.point_b) + ba_dis = np.linalg.norm(self.point_b - another.point_a) + return min(aa_dis, bb_dis, ab_dis, ba_dis) + + def get_dis_arr(self, others) -> np.ndarray: + """get distance acoording to others antipodals + + Args: + others (List[Antipodal]): other antipodals + + Returns: + np.ndarray: distance array + """ + other_num = len(others) + other_a = np.empty(shape=(other_num, 3), dtype=float) + other_b = np.empty(shape=(other_num, 3), dtype=float) + for i in range(other_num): + other_a[i] = others[i].point_a + other_b[i] = others[i].point_b + aa_dis = np.linalg.norm(other_a - self.point_a, axis=1) + ab_dis = np.linalg.norm(other_a - self.point_b, axis=1) + ba_dis = np.linalg.norm(other_b - self.point_a, axis=1) + bb_dis = np.linalg.norm(other_b - self.point_b, axis=1) + dis_arr = np.vstack([aa_dis, ab_dis, ba_dis, bb_dis]).min(axis=0) + return dis_arr + + def _get_canonical_pose(self) -> np.ndarray: + """get canonical pose of antipodal contact pair + + Returns: + np.ndarray: canonical pose + """ + # assume gripper closing along x_axis + x_d = self.point_a - self.point_b + x_d = x_d / np.linalg.norm(x_d) + y_d = np.cross(np.array([0, 0, 1.0], dtype=float), x_d) + if np.linalg.norm(y_d) < 1e-4: + y_d = np.cross(np.array([1, 0, 0.0], dtype=float), x_d) + y_d = y_d / np.linalg.norm(y_d) + z_d = np.cross(x_d, y_d) + pose = np.eye(4, dtype=float) + pose[:3, 0] = x_d # rotation x + pose[:3, 1] = y_d # rotation y + pose[:3, 2] = z_d # rotation z + pose[:3, 3] = (self.point_a + self.point_b) / 2 # position + return pose + + def sample_pose(self, sample_num: int = 36) -> np.ndarray: + """sample parallel gripper grasp poses given antipodal contact pairs + + Args: + sample_num (int, optional): sample number. Defaults to 36. + + Returns: + np.ndarray: [sample_num, 4, 4] of float. Sample poses. + """ + # assume gripper closing along x_axis + x_d = self._canonical_pose[:3, 0] + y_d = self._canonical_pose[:3, 1] + z_d = self._canonical_pose[:3, 2] + position = self._canonical_pose[:3, 3] + beta_list = np.linspace(2 * np.pi / sample_num, 2 * np.pi, sample_num) + grasp_poses = np.empty(shape=(sample_num, 4, 4), dtype=float) + for i in range(sample_num): + sample_z = np.sin(beta_list[i]) * y_d + np.cos(beta_list[i]) * z_d + sample_z = sample_z / np.linalg.norm(sample_z) + sample_y = np.cross(sample_z, x_d) + pose = np.eye(4, dtype=float) + pose[:3, 0] = x_d # rotation x + pose[:3, 1] = sample_y # rotation y + pose[:3, 2] = sample_z # rotation z + pose[:3, 3] = position # position + grasp_poses[i] = pose + return grasp_poses + + +class AntipodalGenerator: + def __init__( + self, + mesh_o3dt: o3d.t.geometry.TriangleMesh, + open_length: float, + min_open_length: float = 0.002, + max_angle: float = np.pi / 10, + surface_sample_num: int = 5000, + layer_num: int = 4, + sample_each_layer: int = 6, + nms_ratio: float = 0.02, + antipodal_sample_num: int = 16, + unique_id: str = None, + cache_dir: str = None, + ): + """antipodal grasp pose generator + + Args: + mesh_o3dt (o3d.t.geometry.TriangleMesh): input mesh + open_length (float): gripper maximum open length + max_angle (float, optional): maximum grasp direction with surface normal. Defaults to np.pi/10. + surface_sample_num (int, optional): contact sample number in mesh surface. Defaults to 5000. + layer_num (int, optional): cone sample layer number . Defaults to 4. + sample_each_layer (int, optional): cone sample number in each layer. Defaults to 6. + nms_ratio (float, optional): nms distance ratio. Defaults to 0.02. + antipodal_sample_num (int, optional): grasp poses sample on each antipodal contact pair. Defaults to 16. + cache_dir (str, optional): file cache directory. Defaults to None. + """ + self._antipodal_max_angle = max_angle + self._open_length = open_length + self._min_open_length = min_open_length + self._mesh_o3dt = mesh_o3dt + verts = mesh_o3dt.vertex.positions.numpy() + self._center_of_mass = verts.mean(axis=0) + if unique_id is None: + unique_file_name = self._get_unique_id( + mesh_o3dt, open_length, max_angle, surface_sample_num + ) + else: + unique_file_name = f"{unique_id}_{str(open_length)}_{str(max_angle)}_{str(surface_sample_num)}" + if cache_dir is None: + cache_dir = os.path.join(pathlib.Path.home(), "grasp_pose") + logger.log_debug(f"Set cache directory to {cache_dir}.") + if not os.path.isdir(cache_dir): + os.mkdir(cache_dir) + cache_file = os.path.join(cache_dir, f"{unique_file_name}.pickle") + if not os.path.isfile(cache_file): + # generate cache + grasp_list = self._generate_cache( + cache_file, + mesh_o3dt=mesh_o3dt, + max_angle=max_angle, + surface_sample_num=surface_sample_num, + layer_num=layer_num, + sample_each_layer=sample_each_layer, + nms_ratio=nms_ratio, + antipodal_sample_num=antipodal_sample_num, + ) + else: + # load cache + grasp_list = self._load_cache(cache_file) + self._grasp_list = grasp_list + + def _get_unique_id( + self, + mesh_o3dt: o3d.t.geometry.TriangleMesh, + open_length: float, + max_angle: float, + surface_sample_num: int, + ) -> str: + mesh_md5 = get_mesh_md5(mesh_o3dt) + return ( + f"{mesh_md5}_{str(open_length)}_{str(max_angle)}_{str(surface_sample_num)}" + ) + + def _generate_cache( + self, + cache_file: str, + mesh_o3dt: o3d.t.geometry.TriangleMesh, + max_angle: float = np.pi / 10, + surface_sample_num: int = 5000, + layer_num: int = 4, + sample_each_layer: int = 6, + nms_ratio: float = 0.02, + antipodal_sample_num: int = 16, + ): + self._mesh_o3dt = mesh_o3dt + self._cone_sampler = ConeSampler( + max_angle=max_angle, + layer_num=layer_num, + sample_each_layer=sample_each_layer, + ) + mesh_o3dt = mesh_o3dt.compute_vertex_normals() + assert 1e-4 < max_angle < np.pi / 2 + self._mesh_len = self._get_pc_size(mesh_o3dt.vertex.positions.numpy()).max() + start_time = time.time() + antipodal_list = self._antipodal_generate(mesh_o3dt, surface_sample_num) + logger.log_debug( + f"Antipodal sampling cost {(time.time() - start_time) * 1000} ms." + ) + logger.log_debug(f"Find {len(antipodal_list)} initial antipodal pairs.") + + valid_antipodal_list = self._antipodal_clean(antipodal_list) + + start_time = time.time() + nms_antipodal_list = self._antipodal_nms( + valid_antipodal_list, nms_ratio=nms_ratio + ) + logger.log_debug(f"NMS cost {(time.time() - start_time) * 1000} ms.") + logger.log_debug( + f"There are {len(nms_antipodal_list)} antipodal pair after nms." + ) + # all poses + start_time = time.time() + grasp_poses, score, open_length = self._sample_grasp_pose( + nms_antipodal_list, antipodal_sample_num + ) + logger.log_debug(f"Pose sampling cost {(time.time() - start_time) * 1000} ms.") + logger.log_debug( + f"There are {grasp_poses.shape[0]} poses after grasp pose sampling." + ) + # write data + data_dict = { + "grasp_poses": grasp_poses, + "score": score, + "open_length": open_length, + } + pickle.dump(data_dict, open(cache_file, "wb")) + # TODO: contact pair visualization + # self.antipodal_visual(nms_antipodal_list) + grasp_num = grasp_poses.shape[0] + logger.log_debug(f"Write {grasp_num} poses to pickle file {cache_file}.") + grasp_list = [None for i in range(grasp_num)] + for i in range(grasp_num): + grasp_list[i] = AntipodalGrasp(grasp_poses[i], open_length[i], score[i]) + return grasp_list + + def _load_cache(self, cache_file: str): + data_dict = pickle.load(open(cache_file, "rb")) + grasp_num = data_dict["grasp_poses"].shape[0] + logger.log_debug(f"Load {grasp_num} poses from pickle file {cache_file}.") + grasp_list = [None for i in range(grasp_num)] + for i in range(grasp_num): + grasp_list[i] = AntipodalGrasp( + data_dict["grasp_poses"][i], + data_dict["open_length"][i], + data_dict["score"][i], + ) + return grasp_list + + def _get_pc_size(self, vertices): + return np.array( + [ + vertices[:, 0].max() - vertices[:, 0].min(), + vertices[:, 1].max() - vertices[:, 1].min(), + vertices[:, 2].max() - vertices[:, 2].min(), + ] + ) + + def _antipodal_generate( + self, mesh_o3dt: o3d.t.geometry.TriangleMesh, surface_sample_num: int = 5000 + ): + surface_pcd = mesh_o3dt.to_legacy().sample_points_uniformly(surface_sample_num) + points = np.array(surface_pcd.points) + normals = np.array(surface_pcd.normals) + point_num = points.shape[0] + scene = o3d.t.geometry.RaycastingScene() + scene.add_triangles(mesh_o3dt) + # raycast + ray_n = self._cone_sampler._ray_num + ray_num = point_num * ray_n + ray_begins = np.empty(shape=(ray_num, 3), dtype=float) + ray_direcs = np.empty(shape=(ray_num, 3), dtype=float) + origin_normals = np.empty(shape=(ray_num, 3), dtype=float) + origin_points = np.empty(shape=(ray_num, 3), dtype=float) + start_time = time.time() + for i in range(point_num): + ray_direc = self._cone_sampler.cone_sample_direc( + normals[i], is_visual=False + ) + # raycast from outside of object + ray_begin = points[i] - 2 * self._mesh_len * ray_direc + ray_direcs[i * ray_n : (i + 1) * ray_n, :] = ray_direc + ray_begins[i * ray_n : (i + 1) * ray_n, :] = ray_begin + origin_normals[i * ray_n : (i + 1) * ray_n, :] = normals[i] + origin_points[i * ray_n : (i + 1) * ray_n, :] = points[i] + logger.log_debug(f"Cone sampling cost {(time.time() - start_time) * 1000} ms.") + + start_time = time.time() + rays = o3d.core.Tensor( + np.hstack([ray_begins, ray_direcs]), dtype=o3d.core.Dtype.Float32 + ) + logger.log_debug(f"Open3d raycast {(time.time() - start_time) * 1000} ms.") + + raycast_rtn = scene.cast_rays(rays) + hit_dis_all = raycast_rtn["t_hit"].numpy() + hit_normal_all = raycast_rtn["primitive_normals"].numpy() + + # max_angle_cos = np.cos(self._antipodal_max_angle) + antipodal_list = [] + # get antipodal points + start_time = time.time() + for i in range(ray_num): + hit_dis = hit_dis_all[i] + hit_normal = hit_normal_all[i] + hit_point = ray_begins[i] + ray_direcs[i] * hit_dis + antipodal_dis = np.linalg.norm(origin_points[i] - hit_point) + if ( + antipodal_dis > self._min_open_length + and antipodal_dis < self._open_length + ): + # reject thin close object + antipodal = Antipodal( + origin_points[i], hit_point, origin_normals[i], hit_normal + ) + antipodal_list.append(antipodal) + logger.log_debug( + f"Antipodal initialize cost {(time.time() - start_time) * 1000} ms." + ) + return antipodal_list + + def _sample_grasp_pose( + self, antipodal_list: List[Antipodal], antipodal_sample_num: int = 36 + ): + antipodal_num = len(antipodal_list) + grasp_poses = np.empty( + shape=(antipodal_sample_num * antipodal_num, 4, 4), dtype=float + ) + scores = np.empty(shape=(antipodal_sample_num * antipodal_num,), dtype=float) + open_length = np.empty( + shape=(antipodal_sample_num * antipodal_num,), dtype=float + ) + for i in range(antipodal_num): + grasp_poses[ + i * antipodal_sample_num : (i + 1) * antipodal_sample_num + ] = antipodal_list[i].sample_pose(antipodal_sample_num) + scores[ + i * antipodal_sample_num : (i + 1) * antipodal_sample_num + ] = antipodal_list[i].score + open_length[ + i * antipodal_sample_num : (i + 1) * antipodal_sample_num + ] = antipodal_list[i].dis + return grasp_poses, scores, open_length + + def get_all_grasp(self) -> List[AntipodalGrasp]: + """get all grasp + + Returns: + List[AntipodalGrasp]: list of grasp + """ + return self._grasp_list + + def select_grasp( + self, + approach_direction: np.ndarray, + select_num: int = 20, + max_angle: float = np.pi / 10, + select_method: GraspSelectMethod = GraspSelectMethod.NORMAL_SCORE, + ) -> List[AntipodalGrasp]: + """Select grasps. Masked by max_angle and sort by grasp score + + Args: + approach_direction (np.ndarray): gripper approach direction + select_num (int, optional): select grasp number. Defaults to 10. + max_angle (float, optional): max angle threshold (angle with surface normal). Defaults to np.pi/10. + select_method (select_method, optional) + Return: + List[AntipodalGrasp]: list of grasp + """ + grasp_num = len(self._grasp_list) + all_idx = np.arange(grasp_num) + grasp_poses = np.empty(shape=(grasp_num, 4, 4), dtype=float) + scores = np.empty(shape=(grasp_num,), dtype=float) + position = grasp_poses[:, :3, 3] + + for i in range(grasp_num): + grasp_poses[i] = self._grasp_list[i].pose + scores[i] = self._grasp_list[i].score + + # mask acoording to table up direction + grasp_z = grasp_poses[:, :3, 2] + direc_dot = (grasp_z * approach_direction).sum(axis=1) + valid_mask = direc_dot > np.cos(max_angle) + valid_id = all_idx[valid_mask] + + # sort acoording to different grasp score + if select_method == GraspSelectMethod.NORMAL_SCORE: + valid_score = scores[valid_id] + sort_valid_idx = np.argsort(valid_score)[::-1] # large is better + elif select_method == GraspSelectMethod.NEAR_APPROACH: + position_dot = (position * approach_direction).sum(axis=1) + valid_height = position_dot[valid_id] + sort_valid_idx = np.argsort(valid_height) + elif select_method == GraspSelectMethod.CENTER: + center_dis = np.linalg.norm(position - self._center_of_mass, axis=1) + valid_center_dis = center_dis[valid_id] + sort_valid_idx = np.argsort(valid_center_dis) + else: + logger.log_warning(f"select_method {select_method.name} not implemented.") + # return all grasp + return self._grasp_list + + # get best score sample index + result_num = min(len(sort_valid_idx), select_num) + best_valid_idx = sort_valid_idx[:result_num] + best_idx = valid_id[best_valid_idx] + + result_grasp_list = [] + for idx in best_idx: + result_grasp_list.append(self._grasp_list[idx]) + return result_grasp_list + + def _antipodal_nms( + self, antipodal_list: List[Antipodal], nms_ratio: float = 0.02 + ) -> List[Antipodal]: + antipodal_num = len(antipodal_list) + nms_mask = np.empty(shape=(antipodal_num,), dtype=bool) + nms_mask.fill(True) + score_list = np.empty(shape=(antipodal_num,), dtype=float) + + for i in range(antipodal_num): + score_list[i] = antipodal_list[i].score + + sort_idx = np.argsort(score_list)[::-1] + + dis_th = self._mesh_len * nms_ratio + for i in range(antipodal_num): + if not nms_mask[sort_idx[i]]: + continue + antipodal_max = antipodal_list[sort_idx[i]] + other_antipodal = [] + other_idx = [] + for j in range(i + 1, antipodal_num): + if nms_mask[sort_idx[j]]: + other_antipodal.append(antipodal_list[sort_idx[j]]) + other_idx.append(sort_idx[j]) + dis_arr = antipodal_max.get_dis_arr(other_antipodal) + invalid_mask = dis_arr < dis_th + for j, flag in enumerate(invalid_mask): + if flag: + nms_mask[other_idx[j]] = False + nms_antipodal_list = [] + for i in range(antipodal_num): + if nms_mask[sort_idx[i]]: + nms_antipodal_list.append(antipodal_list[sort_idx[i]]) + + # TODO: nms validation check. remove in future + # antipodal_num = len(nms_antipodal_list) + # for i in range(antipodal_num): + # for j in range(i + 1, antipodal_num): + # antipodal_dis = nms_antipodal_list[i].get_dis(nms_antipodal_list[j]) + # if antipodal_dis < dis_th: + # logger.log_warning(f"find near antipodal {i} and {j} with dis {antipodal_dis}") + return nms_antipodal_list + + def _antipodal_clean(self, antipodal_list: List[Antipodal]): + # TODO: need collision checker + + valid_antipodal = [] + max_angle_cos = np.cos(self._antipodal_max_angle) + for antipodal in antipodal_list: + if ( + 1e-5 < antipodal.dis < self._open_length + and antipodal.angle_cos > max_angle_cos + ): + valid_antipodal.append(antipodal) + return valid_antipodal + + def antipodal_visual(self, antipodal_list): + mesh_visual = self._mesh_o3dt.to_legacy() + antipodal_num = len(antipodal_list) + draw_points = np.empty(shape=(antipodal_num * 2, 3), dtype=float) + draw_lines = np.empty(shape=(antipodal_num, 2), dtype=int) + for i in range(antipodal_num): + direc = antipodal_list[i].point_b - antipodal_list[i].point_a + direc = direc / np.linalg.norm(direc) + anti_begin = antipodal_list[i].point_a - direc * 0.005 + anti_end = antipodal_list[i].point_b + direc * 0.005 + draw_points[i * 2] = anti_begin + draw_points[i * 2 + 1] = anti_end + draw_lines[i] = np.array([i * 2, i * 2 + 1], dtype=int) + + line_set = o3d.geometry.LineSet( + points=o3d.utility.Vector3dVector(draw_points), + lines=o3d.utility.Vector2iVector(draw_lines), + ) + draw_colors = np.empty(shape=(antipodal_num, 3), dtype=float) + draw_colors[:] = np.array([0.0, 1.0, 1.0]) + line_set.colors = o3d.utility.Vector3dVector(draw_colors) + o3d.visualization.draw_geometries([line_set, mesh_visual]) + + def grasp_pose_visual( + self, grasp_list: List[AntipodalGrasp] + ) -> List[o3d.t.geometry.TriangleMesh]: + """visualize grasp pose + + Args: + grasp_list (List[AntipodalGrasp]): list of grasp + + Returns: + List[o3d.t.geometry.TriangleMesh]: list of visualization mesh + """ + pose_num = len(grasp_list) + visual_mesh_list = [self._mesh_o3dt.compute_vertex_normals()] + + max_angle_cos = np.cos(self._antipodal_max_angle) + + for i in range(pose_num): + grasp_mesh = grasp_list[i].grasp_pose_visual_mesh( + gripper_open_length=self._open_length + ) + grasp_mesh.transform(grasp_list[i].pose) + # low score: red | hight score: blue + score_ratio = (grasp_list[i].score - max_angle_cos) / (1 - max_angle_cos) + score_ratio = min(1.0, score_ratio) + score_ratio = max(0.0, score_ratio) + grasp_mesh.paint_uniform_color(np.array([1 - score_ratio, 0, score_ratio])) + visual_mesh_list.append(o3d.t.geometry.TriangleMesh.from_legacy(grasp_mesh)) + return visual_mesh_list diff --git a/embodichain/toolkits/graspkit/pg_grasp/cone_sampler.py b/embodichain/toolkits/graspkit/pg_grasp/cone_sampler.py new file mode 100644 index 00000000..5de52f4c --- /dev/null +++ b/embodichain/toolkits/graspkit/pg_grasp/cone_sampler.py @@ -0,0 +1,121 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import open3d as o3d +import numpy as np +from scipy.spatial.transform import Rotation as R + + +def rotate_to_ref(direc: np.ndarray, rotate_ref: np.ndarray): + assert direc.shape == (3,) + direc_len = np.linalg.norm(direc) + assert direc_len > 1e-5 + direc_unit = direc / direc_len + + assert rotate_ref.shape == (3,) + rotate_ref_len = np.linalg.norm(rotate_ref) + assert rotate_ref_len > 1e-5 + rotate_ref_unit = rotate_ref / rotate_ref_len + + rotate_axis = np.cross(rotate_ref_unit, direc_unit) + rotate_axis_len = np.linalg.norm(rotate_axis) + if rotate_axis_len < 1e-5: + # co axis, no need to do rotation + dot_res = direc_unit.dot(rotate_ref_unit) + if dot_res > 0: + # identity rotation + return np.eye(3, dtype=float) + else: + # negative, rotate 180 degree + # rotate with a perpendicular axis + random_axis = np.random.random(size=(3,)) + perpendicular_axis = np.cross(random_axis, rotate_ref_unit) + perpendicular_axis = perpendicular_axis / np.linalg.norm(perpendicular_axis) + ref_rotation = R.from_rotvec(perpendicular_axis * np.pi).as_matrix() + return ref_rotation + else: + rotate_axis = rotate_axis / rotate_axis_len + angle = np.arccos(direc_unit.dot(rotate_ref_unit)) + ref_rotation = R.from_rotvec(angle * rotate_axis, degrees=False).as_matrix() + return ref_rotation + + +class ConeSampler: + def __init__( + self, max_angle: float, layer_num: int = 2, sample_each_layer: int = 4 + ) -> None: + """cone ray sampler + + Args: + max_angle (float): maximum ray angle to surface normal + layer_num (int, optional): circle layer. Defaults to 2. + sample_each_layer (int, optional): ray samples in each circle layer. Defaults to 4. + """ + self._max_angle = max_angle + self._layer_num = layer_num + self._ray_num = layer_num * sample_each_layer + 1 + alpha_list = np.linspace(max_angle / layer_num, max_angle, layer_num) + beta_list = np.linspace( + 2 * np.pi / sample_each_layer, 2 * np.pi, sample_each_layer + ) + self._direc_ref = np.array([0, 0, 1]) + + rotation_list = np.empty(shape=(self._ray_num, 3, 3), dtype=float) + + for i, alpha in enumerate(alpha_list): + for j, beta in enumerate(beta_list): + x_rotation = R.from_euler( + seq="XYZ", angles=np.array([alpha, 0, 0]), degrees=False + ).as_matrix() + z_rotation = R.from_euler( + seq="XYZ", angles=np.array([0, 0, beta]), degrees=False + ).as_matrix() + rotation_list[i * sample_each_layer + j + 1] = z_rotation @ x_rotation + # original direction + rotation_list[0] = np.eye(3) + self._sample_direc = rotation_list[:, :3, 2] # z-axis + + def cone_sample_direc(self, direc: np.ndarray, is_visual: bool = False): + """sample cone directly + + Args: + direc (np.ndarray): direction to sample a cone + is_visual (bool, optional): use visualization or not. Defaults to False. + + Returns: + np.ndarray: [_ray_num, 3] of float, cone direction list + """ + ref_rotation = rotate_to_ref(direc, self._direc_ref) + cone_direc_list = self._sample_direc @ ref_rotation.T + if is_visual: + self._visual(cone_direc_list) + return cone_direc_list + + def _visual(self, cone_direc_list: np.ndarray): + drawer = o3d.geometry.TriangleMesh.create_coordinate_frame(0.5) + for cone_direc in cone_direc_list: + arrow = o3d.geometry.TriangleMesh.create_arrow( + cylinder_radius=0.02, + cone_radius=0.03, + cylinder_height=0.9, + cone_height=0.1, + ) + arrow.compute_vertex_normals() + arrow_rotation = rotate_to_ref(cone_direc, self._direc_ref) + arrow.rotate(arrow_rotation, center=(0, 0, 0)) + arrow.paint_uniform_color(np.array([0.5, 0.5, 0.5])) + drawer += arrow + o3d.visualization.draw_geometries([drawer]) diff --git a/embodichain/toolkits/urdf_assembly/__init__.py b/embodichain/toolkits/urdf_assembly/__init__.py new file mode 100644 index 00000000..f0866a6f --- /dev/null +++ b/embodichain/toolkits/urdf_assembly/__init__.py @@ -0,0 +1,18 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- +from .urdf_assembly_manager import URDFAssemblyManager + +__all__ = ["URDFAssemblyManager"] diff --git a/embodichain/toolkits/urdf_assembly/component.py b/embodichain/toolkits/urdf_assembly/component.py new file mode 100644 index 00000000..603225a3 --- /dev/null +++ b/embodichain/toolkits/urdf_assembly/component.py @@ -0,0 +1,342 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import traceback +import numpy as np +from pathlib import Path +from dataclasses import dataclass +import xml.etree.ElementTree as ET +from typing import Dict, Optional + +from embodichain.toolkits.urdf_assembly.logging_utils import ( + URDFAssemblyLogger, +) +from embodichain.toolkits.urdf_assembly.mesh import URDFMeshManager + + +__all__ = ["ComponentRegistry", "URDFComponent", "URDFComponentManager"] + + +class ComponentRegistry: + r"""Registry for storing and retrieving URDFComponent objects.""" + + def __init__(self): + self._components = {} + + def add(self, component_type: str, component_obj): + self._components[component_type] = component_obj + + def get(self, component_type: str): + return self._components.get(component_type) + + def all(self): + return self._components + + def remove(self, component_type: str): + if component_type in self._components: + del self._components[component_type] + + +@dataclass +class URDFComponent: + r"""Represents a URDF component with its configuration and transformation. + + This dataclass encapsulates all the information needed to process and integrate + a URDF component into the robot assembly, including file path, attachment + configuration, parameters, and optional spatial transformation. + """ + + urdf_path: str # Path to the URDF file for this component + default_attach_link: str = ( + "base_link" # Default link name for attachment (usually the first link) + ) + params: Dict = None # Component-specific parameters (e.g., wheel_type for chassis) + transform: Optional[ + np.ndarray + ] = None # Optional 4x4 transformation matrix for positioning + + def __post_init__(self): + # Convert path to Path object for better path handling + if isinstance(self.urdf_path, str): + self.urdf_path = Path(self.urdf_path) + + # Validate transformation matrix dimensions and type + if self.transform is not None: + if not isinstance(self.transform, np.ndarray) or self.transform.shape != ( + 4, + 4, + ): + raise ValueError("Transform must be a 4x4 numpy array") + + +class URDFComponentManager: + """Responsible for loading, renaming, and processing meshes for a single component.""" + + def __init__(self, mesh_manager: URDFMeshManager): + self.mesh_manager = mesh_manager + self.logger = URDFAssemblyLogger.get_logger("component_manager") + + def process_component( + self, + comp: str, + prefix: str, + comp_obj, + name_mapping: dict, + base_points: dict, + links: list, + joints: list, + ): + r"""Process a single URDF component by renaming elements and handling meshes. + + Args: + comp (str): Component name (e.g., 'chassis', 'left_arm', 'hand'). + prefix (str): Prefix to add to component elements for uniqueness (e.g., 'left_'). + None means no prefix will be applied. + comp_obj: URDFComponent object containing the component's URDF path and parameters. + name_mapping (dict): Dictionary mapping (component, original_name) tuples to new names. + Used for resolving cross-references between components. + base_points (dict): Dictionary mapping component names to their base connection link names. + Used for establishing parent-child relationships. + links (list): Global list to collect all processed link elements from all components. + joints (list): Global list to collect all processed joint elements from all components. + """ + + try: + urdf_root = ET.parse(comp_obj.urdf_path).getroot() + + # Safe way to get link and joint names, handling None values + global_link_names = { + link.get("name").lower() + for link in links + if link.get("name") is not None + } + global_joint_names = { + joint.get("name").upper() + for joint in joints + if joint.get("name") is not None + } + + first_link_flag = True + joint_name_mapping = {} + + # Process links first + for link in urdf_root.findall("link"): + orig_name = link.get("name") + if orig_name is None: + self.logger.warning( + f"Found link with no name in component {comp}, skipping" + ) + continue + + # Generate unique name + if prefix: + new_name = self._generate_unique_name( + orig_name, prefix, global_link_names + ).lower() + else: + # For components without prefix, ensure names are unique + if orig_name.lower() in global_link_names: + new_name = f"{comp}_{orig_name}".lower() + else: + new_name = orig_name.lower() + + global_link_names.add(new_name) + + # Set first link as base point + if first_link_flag: + base_points[comp] = new_name + first_link_flag = False + + # Update link name mapping and set link name to lowercase + name_mapping[(comp, orig_name)] = new_name + link.set("name", new_name) + links.append(link) + + self._process_meshes(link, comp_obj.urdf_path, comp) + + # Process joints: Build mapping table AND process all at once + joints_to_process = [] + + # First collect all joints and build complete mapping + for joint in urdf_root.findall("joint"): + orig_joint_name = joint.get("name") + if orig_joint_name is None: + continue + + new_joint_name = self._generate_unique_name( + orig_joint_name, prefix, global_joint_names + ).upper() + global_joint_names.add(new_joint_name) + + # Build the complete mapping table + joint_name_mapping[orig_joint_name] = new_joint_name + joints_to_process.append((joint, orig_joint_name, new_joint_name)) + + self.logger.debug(f"Joint name mapping for [{comp}]: {joint_name_mapping}") + + # Now process all joints with complete mapping available + for joint, orig_joint_name, new_joint_name in joints_to_process: + # Set the new joint name + joint.set("name", new_joint_name) + + # Update parent and child links to lowercase - with None checks + parent_elem = joint.find("parent") + child_elem = joint.find("child") + + if parent_elem is not None: + parent = parent_elem.get("link") + if parent is not None: + new_parent_name = name_mapping.get( + (comp, parent), parent + ).lower() + parent_elem.set("link", new_parent_name) + else: + self.logger.warning( + f"Found parent element with no link attribute in joint {orig_joint_name}" + ) + + if child_elem is not None: + child = child_elem.get("link") + if child is not None: + new_child_name = name_mapping.get((comp, child), child).lower() + child_elem.set("link", new_child_name) + else: + self.logger.warning( + f"Found child element with no link attribute in joint {orig_joint_name}" + ) + + # Process mimic joint references using the complete mapping table + mimic_elem = joint.find("mimic") + if mimic_elem is not None: + mimic_joint = mimic_elem.get("joint") + if mimic_joint is not None: + self.logger.info( + f"Processing mimic joint reference: ({mimic_joint}) in joint ({orig_joint_name})" + ) + # Look up the corresponding new joint name in the mapping table + new_mimic_joint = joint_name_mapping.get(mimic_joint) + if new_mimic_joint: + # Update the mimic element to reference the renamed joint + mimic_elem.set("joint", new_mimic_joint) + self.logger.info( + f"✓ Updated mimic joint reference: ({mimic_joint}) -> ({new_mimic_joint})" + ) + else: + self.logger.warning( + f"✗ Could not find mapping for mimic joint: ({mimic_joint})" + ) + self.logger.warning( + f"Available mappings: {list(joint_name_mapping.keys())}" + ) + + joints.append(joint) + + self.logger.debug( + f"Processed component: [{comp}], links: {len(urdf_root.findall('link'))}, joints: {len(urdf_root.findall('joint'))}" + ) + + except Exception as e: + self.logger.error( + f"Failed to process component [{comp}]: {e}", exc_info=True + ) + self.logger.error(f"Traceback: {traceback.format_exc()}") + + def _generate_unique_name( + self, orig_name: str, prefix: str, existing_names: set + ) -> str: + r"""Generate a unique name by adding a prefix and ensuring no conflicts. + + Args: + orig_name (str): The original name to modify. + prefix (str): The prefix to add to the name. + existing_names (set): A set of existing names to check for conflicts. + + Returns: + str: A unique name derived from the original name. + """ + if orig_name is None: + orig_name = "unnamed" + + if prefix and not orig_name.lower().startswith(prefix.lower()): + new_name = f"{prefix}{orig_name}".lower() + else: + new_name = orig_name.lower() + + # Ensure the new name is unique + if new_name in existing_names: + counter = 1 + unique_name = f"{new_name}_{counter}" + while unique_name in existing_names: + counter += 1 + unique_name = f"{new_name}_{counter}" + new_name = unique_name + + return new_name + + def _process_meshes(self, link: ET.Element, base_urdf_path: str, comp_name: str): + r"""Process visual and collision meshes for a link. + + Args: + link (ET.Element): The URDF link element to process. + base_urdf_path (str): The base path for the URDF files. + comp_name (str): The name of the component being processed. + """ + try: + for visual in link.findall("visual"): + geometry = visual.find("geometry") + if geometry is not None: + mesh = geometry.find("mesh") + if mesh is not None: + filename = mesh.get("filename") + if filename is not None: + self.logger.debug(f"Processing visual mesh: {filename}") + new_mesh_filename = ( + self.mesh_manager.copy_and_modify_mesh_file( + base_urdf_path, + filename, + "Visual", + comp_name, + ) + ) + self.logger.debug( + f"Updated visual mesh filename: {new_mesh_filename}" + ) + mesh.set("filename", new_mesh_filename) + + for collision in link.findall("collision"): + geometry = collision.find("geometry") + if geometry is not None: + mesh = geometry.find("mesh") + if mesh is not None: + filename = mesh.get("filename") + if filename is not None: + self.logger.debug(f"Processing collision mesh: {filename}") + new_mesh_filename = ( + self.mesh_manager.copy_and_modify_mesh_file( + base_urdf_path, + filename, + "Collision", + comp_name, + ) + ) + self.logger.debug( + f"Updated collision mesh filename: {new_mesh_filename}" + ) + mesh.set("filename", new_mesh_filename) + except Exception as e: + self.logger.error( + f"Failed to process meshes for component {comp_name}: {e}" + ) diff --git a/embodichain/toolkits/urdf_assembly/connection.py b/embodichain/toolkits/urdf_assembly/connection.py new file mode 100644 index 00000000..0e871b83 --- /dev/null +++ b/embodichain/toolkits/urdf_assembly/connection.py @@ -0,0 +1,212 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import xml.etree.ElementTree as ET + +from scipy.spatial.transform import Rotation as R + +from embodichain.toolkits.urdf_assembly.logging_utils import ( + URDFAssemblyLogger, +) + +__all__ = ["URDFConnectionManager"] + + +class URDFConnectionManager: + r""" + Responsible for managing connection rules between components and sensor attachments. + """ + + def __init__(self, base_link_name: str): + r"""Initialize the URDFConnectionManager. + + Args: + base_link_name (str): The name of the base link to which the chassis or other components may be attached. + """ + self.base_link_name = base_link_name + self.logger = URDFAssemblyLogger.get_logger("connection_manager") + + def add_connections( + self, + joints: list, + base_points: dict, + parent_attach_points: dict, + connection_rules: list, + component_transforms: dict = None, + ): + r"""Add connection joints between robot components according to the specified rules. + + Args: + joints (list): A list to collect joint elements. + base_points (dict): A mapping from component names to their child connection link names. + parent_attach_points (dict): A mapping from component names to their parent connection link names. + connection_rules (list): A list of (parent, child) tuples specifying connection relationships. + component_transforms (dict): Optional mapping from component names to their transform matrices. + """ + chassis_component = "chassis" + component_transforms = component_transforms or {} + + existing_joint_names = { + joint.get("name") for joint in joints if hasattr(joint, "get") + } + + # chassis is always attached to base_link (no transform applied to this connection) + if chassis_component in base_points: + chassis_first_link = base_points[chassis_component] + joint_name = f"BASE_LINK_TO_{chassis_component.upper()}_CONNECTOR" + if joint_name not in existing_joint_names: + joint = ET.Element("joint", name=joint_name, type="fixed") + ET.SubElement(joint, "origin", xyz="0 0 0", rpy="0 0 0") + ET.SubElement(joint, "parent", link=self.base_link_name) + ET.SubElement(joint, "child", link=chassis_first_link) + joints.append(joint) + existing_joint_names.add(joint_name) + self.logger.info( + f"[{chassis_component.capitalize()}] connected to [base_link] via ({chassis_first_link})" + ) + else: + # If no chassis, connect components directly to base_link with their transforms + self.logger.info( + "No chassis found, connecting components directly to base_link" + ) + + # Find components that don't have parents in connection_rules + components_with_parents = {child for parent, child in connection_rules} + orphan_components = [ + comp + for comp in base_points.keys() + if comp not in components_with_parents + ] + + for comp in orphan_components: + comp_first_link = base_points[comp] + joint_name = f"BASE_TO_{comp.upper()}_CONNECTOR" + + if joint_name not in existing_joint_names: + joint = ET.Element("joint", name=joint_name, type="fixed") + + # Apply transform to this specific connection if the component has one + if comp in component_transforms: + transform = component_transforms[comp] + xyz = transform[:3, 3] # Extract translation + rotation = R.from_matrix(transform[:3, :3]) + rpy = rotation.as_euler("xyz") + + ET.SubElement( + joint, + "origin", + xyz=f"{xyz[0]} {xyz[1]} {xyz[2]}", + rpy=f"{rpy[0]} {rpy[1]} {rpy[2]}", + ) + self.logger.info( + f"Applied transform to base connection {comp}: xyz={xyz}, rpy={rpy}" + ) + else: + ET.SubElement(joint, "origin", xyz="0 0 0", rpy="0 0 0") + + ET.SubElement(joint, "parent", link=self.base_link_name) + ET.SubElement(joint, "child", link=comp_first_link) + joints.append(joint) + existing_joint_names.add(joint_name) + + self.logger.info( + f"[{comp.capitalize()}] connected to [base_link] via ({comp_first_link})" + ) + + # Process other connection relationships + for parent, child in connection_rules: + if parent in parent_attach_points and child in base_points: + parent_connect_link = parent_attach_points[parent].lower() + child_connect_link = base_points[child].lower() + + self.logger.info( + f"Connecting [{parent}]-({parent_connect_link}) to [{child}]-({child_connect_link})" + ) + + # Create a unique joint name + base_joint_name = f"{parent.upper()}_TO_{child.upper()}_CONNECTOR" + if base_joint_name not in existing_joint_names: + joint = ET.Element("joint", name=base_joint_name, type="fixed") + + # Apply transform to this specific connection if the child component has one + if child in component_transforms: + transform = component_transforms[child] + xyz = transform[:3, 3] # Extract translation + rotation = R.from_matrix(transform[:3, :3]) + rpy = rotation.as_euler("xyz") + + ET.SubElement( + joint, + "origin", + xyz=f"{xyz[0]} {xyz[1]} {xyz[2]}", + rpy=f"{rpy[0]} {rpy[1]} {rpy[2]}", + ) + self.logger.info( + f"Applied transform to connection {parent} -> {child}: xyz={xyz}, rpy={rpy}" + ) + else: + ET.SubElement(joint, "origin", xyz="0 0 0", rpy="0 0 0") + + ET.SubElement(joint, "parent", link=parent_connect_link) + ET.SubElement(joint, "child", link=child_connect_link) + joints.append(joint) + existing_joint_names.add(base_joint_name) + else: + self.logger.warning( + f"Duplicate connection rule: {parent} -> {child}" + ) + else: + self.logger.error(f"Invalid connection rule: {parent} -> {child}") + + def add_sensor_attachments( + self, joints: list, attach_dict: dict, base_points: dict + ): + r"""Attach sensors to the robot by creating fixed joints.""" + for sensor_name, attach in attach_dict.items(): + sensor_urdf = ET.parse(attach.sensor_urdf).getroot() + + # Add sensor links and joints to the main lists + for link in sensor_urdf.findall("link"): + # Ensure sensor link names are lowercase + link.set("name", link.get("name").lower()) + joints.append(link) # This should be added to links list instead + + for joint in sensor_urdf.findall("joint"): + # Ensure sensor joint names are uppercase and link references are lowercase + joint.set("name", joint.get("name").upper()) + parent_elem = joint.find("parent") + child_elem = joint.find("child") + if parent_elem is not None: + parent_elem.set("link", parent_elem.get("link").lower()) + if child_elem is not None: + child_elem.set("link", child_elem.get("link").lower()) + joints.append(joint) + + parent_link = base_points.get( + attach.parent_component, attach.parent_component + ).lower() # Ensure lowercase + + # Create connection joint with uppercase name + joint_name = ( + f"{attach.parent_component.upper()}_TO_{sensor_name.upper()}_CONNECTOR" + ) + joint = ET.Element("joint", name=joint_name, type="fixed") + ET.SubElement(joint, "origin", xyz="0 0 0", rpy="0 0 0") + ET.SubElement(joint, "parent", link=parent_link) + ET.SubElement( + joint, "child", link=sensor_urdf.find("link").get("name").lower() + ) + joints.append(joint) diff --git a/embodichain/toolkits/urdf_assembly/file_writer.py b/embodichain/toolkits/urdf_assembly/file_writer.py new file mode 100644 index 00000000..18d6a24e --- /dev/null +++ b/embodichain/toolkits/urdf_assembly/file_writer.py @@ -0,0 +1,211 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from datetime import datetime + +import xml.etree.ElementTree as ET + +from embodichain.toolkits.urdf_assembly.logging_utils import ( + URDFAssemblyLogger, +) + +__all__ = ["URDFFileWriter"] + + +class URDFFileWriter: + r"""Responsible for formatting XML and writing URDF files with proper headers.""" + + def __init__(self, module_names: list = None): + r"""Initialize the URDFFileWriter. + + Args: + module_names (list): List of module names to include in the header. + """ + self.module_names = module_names or [] + self.logger = URDFAssemblyLogger.get_logger("file_writer") + + def create_section_comment( + self, content: str, comment_type: str = "section" + ) -> ET.Comment: + r"""Create standardized section comments for URDF organization. + + Args: + content (str): The content of the comment. + comment_type (str): Type of comment - "section", "start", "end", "empty". + + Returns: + ET.Comment: XML comment element. + """ + if comment_type == "empty": + return ET.Comment("") + elif comment_type == "start": + return ET.Comment(f" Start of ({content.lower()}) ") + elif comment_type == "end": + return ET.Comment(f" End of ({content.lower()}) ") + else: + return ET.Comment(f" {content} ") + + def add_section_comments( + self, elements_list: list, part_name: str, section_type: str + ): + r"""Add standardized section comments to elements list. + + Args: + elements_list (list): List to add comments to (links or joints). + part_name (str): Name of the component part. + section_type (str): Type of section ("Links" or "Joints"). + """ + elements_list.append(self.create_section_comment("", "empty")) + elements_list.append( + self.create_section_comment( + f"{section_type} for part: {part_name}", "start" + ) + ) + + def add_section_end_comments( + self, elements_list: list, part_name: str, section_type: str + ): + r"""Add standardized section end comments to elements list. + + Args: + elements_list (list): List to add comments to (links or joints). + part_name (str): Name of the component part. + section_type (str): Type of section ("Links" or "Joints"). + """ + elements_list.append( + self.create_section_comment(f"{section_type} for part: {part_name}", "end") + ) + elements_list.append(self.create_section_comment("", "empty")) + + def make_comment_line(self, content: str, width: int = 80) -> str: + r"""Create a properly formatted comment line with centered content. + + Args: + content (str): The content to be centered in the comment. + width (int): Total width of the comment line (default is 80). + + Returns: + str: A formatted XML comment line. + """ + content = content.strip() + pad_total = width - 7 - len(content) + pad_left = pad_total // 2 + pad_right = pad_total - pad_left + if pad_total < 0: + pad_left = 0 + pad_right = 0 + return f"" + + def generate_header( + self, module_names: list = None, assembly_signature: str = None + ) -> str: + r"""Generate a standard header for URDF files with assembly signature. + + Args: + module_names (list): List of module names to include in the header. + assembly_signature (str): MD5 signature of the assembly configuration. + + Returns: + str: Formatted header string. + """ + if module_names is None: + module_names = self.module_names + + now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + # Calculate proper spacing for centered content + header_width = 80 + separator_line = "" + + def center_comment(text: str) -> str: + """Center text within comment brackets with proper padding.""" + content_width = header_width - 8 # Account for + text_len = len(text) + if text_len >= content_width: + return f"" + + padding = content_width - text_len + left_pad = padding // 2 + right_pad = padding - left_pad + return f"" + + header_lines = [ + '', + separator_line, + center_comment("Robot URDF Model Generation Report"), + center_comment(f"Generation Time: {now}"), + center_comment("Tool Version: DexForce URDF Composer V1.0"), + center_comment(f"Included Modules: {' + '.join(module_names)}"), + ] + + # Add assembly signature if provided + if assembly_signature: + header_lines.append( + center_comment(f"配置签名 ASSEMBLY_SIGNATURE: {assembly_signature}") + ) + + header_lines.append(separator_line) + + return "\n".join(header_lines) + "\n" + + def prettify(self, elem: ET.Element, level: int = 0) -> ET.Element: + r"""Format an XML element by adding newlines and indentation. + + Args: + elem (ET.Element): The XML element to format. + level (int): The current indentation level (default is 0). + + Returns: + ET.Element: The formatted XML element. + """ + indent = "\n" + " " * level # Create indentation string based on level + if len(elem): # If the element has children + if not elem.text or not elem.text.strip(): + elem.text = indent + " " # Add indentation if no text + if not elem.tail or not elem.tail.strip(): + elem.tail = indent # Add indentation after the element + for child in elem: + self.prettify(child, level + 1) # Recursive call for children + if not child.tail or not child.tail.strip(): + child.tail = indent # Ensure the last child has proper tail indentation + else: # If the element has no children + if level and (not elem.tail or not elem.tail.strip()): + elem.tail = indent # Add indentation for elements at a non-zero level + return elem + + def write_urdf( + self, + merged_urdf: ET.Element, + output_path: str, + module_names: list = None, + assembly_signature: str = None, + ): + r"""Write the merged URDF to file with proper formatting and header including signature. + + Args: + merged_urdf (ET.Element): The merged URDF XML element. + output_path (str): Path where the URDF file will be written. + module_names (list): Optional list of module names for the header. + assembly_signature (str): Optional assembly signature to include in header. + """ + header = self.generate_header(module_names, assembly_signature) + xml_str = ET.tostring(self.prettify(merged_urdf), encoding="unicode") + + with open(output_path, "w", encoding="utf-8") as f: + f.write(header) + f.write(xml_str) + + self.logger.info(f"URDF file written to: {output_path}") diff --git a/embodichain/toolkits/urdf_assembly/logging_utils.py b/embodichain/toolkits/urdf_assembly/logging_utils.py new file mode 100644 index 00000000..63c17d0c --- /dev/null +++ b/embodichain/toolkits/urdf_assembly/logging_utils.py @@ -0,0 +1,131 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import logging +from typing import Optional + +__all__ = ["URDFAssemblyLogger"] + + +class URDFColorFormatter(logging.Formatter): + r"""Color formatter for URDF assembly logging""" + + COLORS = { + "DEBUG": "\033[36m", # Cyan + "INFO": "\033[32m", # Green + "WARNING": "\033[33m", # Yellow + "ERROR": "\033[31m", # Red + "CRITICAL": "\033[41m", # Red background + } + # Symbol colors + BRACKET_COLOR = "\033[94m" # Bright blue for [] + PAREN_COLOR = "\033[95m" # Magenta for () + RESET = "\033[0m" + + def format(self, record): + color = self.COLORS.get(record.levelname, self.RESET) + message = super().format(record) + + # Apply symbol coloring first + message = self._colorize_symbols(message, color) + + return f"{color}{message}{self.RESET}" + + def _colorize_symbols(self, message, base_color): + r"""Add colors to brackets and parentheses while preserving base color""" + import re + + # Color square brackets and their content, then restore base color + message = re.sub( + r"\[([^\]]+)\]", + f"{self.BRACKET_COLOR}[\\1]{self.RESET}{base_color}", + message, + ) + + # Color parentheses and their content, then restore base color + message = re.sub( + r"\(([^)]+)\)", f"{self.PAREN_COLOR}(\\1){self.RESET}{base_color}", message + ) + + return message + + +class URDFAssemblyLogger: + r"""URDF Assembly module-specific logger manager""" + + _loggers = {} # Cache for created loggers + _initialized = False + + @classmethod + def get_logger(cls, name: Optional[str] = None) -> logging.Logger: + r"""Get or create a URDF assembly-specific logger + + Args: + name: Logger name, defaults to calling module name + + Returns: + Configured logger instance + """ + if name is None: + # Get caller's module name + import inspect + + frame = inspect.currentframe().f_back + module_name = frame.f_globals.get("__name__", "unknown") + if module_name == "__main__": + name = "urdf_assembly.main" + else: + name = f'urdf_assembly.{module_name.split(".")[-1]}' + else: + # Ensure using urdf_assembly prefix + if not name.startswith("urdf_assembly."): + name = f"urdf_assembly.{name}" + + # Return cached logger or create new one + if name not in cls._loggers: + logger = logging.getLogger(name) + + # Avoid duplicate handlers + if not logger.handlers: + handler = logging.StreamHandler() + formatter = URDFColorFormatter("[%(levelname)s] %(name)s: %(message)s") + handler.setFormatter(formatter) + logger.addHandler(handler) + logger.setLevel(logging.INFO) + logger.propagate = False # Don't propagate to root logger + + cls._loggers[name] = logger + + return cls._loggers[name] + + @classmethod + def set_level(cls, level): + r"""Set log level for all URDF assembly loggers""" + for logger in cls._loggers.values(): + logger.setLevel(level) + + @classmethod + def disable_other_loggers(cls): + r"""Disable output from other non-URDF loggers""" + logging.getLogger().setLevel(logging.CRITICAL) + + +# Remove original setup_logger function, replace with URDF-specific initialization +def setup_urdf_logging(): + """Initialize URDF assembly logging system""" + # Optional: disable other logger outputs + URDFAssemblyLogger.disable_other_loggers() + return URDFAssemblyLogger.get_logger("urdf_assembly.main") diff --git a/embodichain/toolkits/urdf_assembly/mesh.py b/embodichain/toolkits/urdf_assembly/mesh.py new file mode 100644 index 00000000..c03a28c8 --- /dev/null +++ b/embodichain/toolkits/urdf_assembly/mesh.py @@ -0,0 +1,190 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os +import shutil +import xml.etree.ElementTree as ET + +from embodichain.toolkits.urdf_assembly.logging_utils import ( + URDFAssemblyLogger, +) + +__all__ = ["URDFMeshManager"] + + +class URDFMeshManager: + r"""Responsible for copying, renaming, and handling dependencies of mesh files.""" + + def __init__(self, output_dir: str): + r"""Initialize the URDFMeshManager with output directory configuration. + + Args: + output_dir (str): Base directory where mesh files will be organized. + Creates subdirectories 'Visual' and 'Collision' for + different mesh types. + """ + self.output_dir = output_dir + self.logger = URDFAssemblyLogger.get_logger("mesh_manager") + + def ensure_dirs(self): + r"""Ensure that the output directory contains 'Collision' and 'Visual' subdirectories. + Creates them if they do not exist. + + Returns: + tuple: Paths to the 'Collision' and 'Visual' directories. + """ + collision_dir = os.path.join(self.output_dir, "Collision") + visual_dir = os.path.join(self.output_dir, "Visual") + os.makedirs(collision_dir, exist_ok=True) + os.makedirs(visual_dir, exist_ok=True) + return collision_dir, visual_dir + + def copy_and_modify_mesh_file( + self, base_urdf_path: str, mesh_file_name: str, sub_folder: str, comp_name: str + ): + r"""Copy a mesh file to the output directory and handle dependencies. + + Args: + base_urdf_path (str): Path to the base URDF file. + mesh_file_name (str): Name of the mesh file to copy. + sub_folder (str): 'Visual' or 'Collision'. + comp_name (str): Component name, e.g. 'chassis', 'left_arm'. + + Returns: + str: Relative path to the new mesh file for URDF reference. + """ + # New mesh path format: output_dir/{sub_folder}/{comp_name}/{original_filename} + target_dir = os.path.join(self.output_dir, sub_folder, comp_name) + os.makedirs(target_dir, exist_ok=True) + + # Get URDF directory + urdf_dir = os.path.dirname(base_urdf_path) + + # Handle different path types + if os.path.isabs(mesh_file_name): + # Absolute path + original_mesh_path = mesh_file_name + else: + # Relative path - join with URDF directory and normalize + original_mesh_path = os.path.join(urdf_dir, mesh_file_name) + original_mesh_path = os.path.normpath(original_mesh_path) + + # Debug information + self.logger.debug(f"Processing mesh file:") + self.logger.debug(f" URDF path: {base_urdf_path}") + self.logger.debug(f" URDF dir: {urdf_dir}") + self.logger.debug(f" Mesh reference: {mesh_file_name}") + self.logger.debug(f" Resolved path: {original_mesh_path}") + + # Check if file exists + if not os.path.exists(original_mesh_path): + # Try some common alternative patterns + alternatives = [] + + # Try removing '../' and looking in same directory as URDF + if mesh_file_name.startswith("../"): + alt_path = os.path.join(urdf_dir, mesh_file_name[3:]) + alternatives.append(alt_path) + + # Try looking in parent directory structure + parent_dir = os.path.dirname(urdf_dir) + if mesh_file_name.startswith("../"): + alt_path = os.path.join(parent_dir, mesh_file_name[3:]) + alternatives.append(alt_path) + else: + alt_path = os.path.join(parent_dir, mesh_file_name) + alternatives.append(alt_path) + + # Try looking directly in the mesh filename as basename + basename = os.path.basename(mesh_file_name) + alt_path = os.path.join(urdf_dir, basename) + alternatives.append(alt_path) + + # Check alternatives + found_alternative = None + for alt in alternatives: + alt_normalized = os.path.normpath(alt) + if os.path.exists(alt_normalized): + found_alternative = alt_normalized + self.logger.debug( + f"Found mesh file at alternative location: {alt_normalized}" + ) + break + + if found_alternative: + original_mesh_path = found_alternative + else: + self.logger.error(f"Mesh file not found: {original_mesh_path}") + self.logger.debug(f" Tried alternatives: {alternatives}") + # Return original path to keep existing URDF reference + return mesh_file_name + + new_mesh_path = os.path.join(target_dir, os.path.basename(mesh_file_name)) + + try: + shutil.copyfile(original_mesh_path, new_mesh_path) + self.logger.debug(f"Copied mesh: {original_mesh_path} -> {new_mesh_path}") + except Exception as e: + self.logger.error(f"Failed to copy mesh file: {e}", exc_info=True) + return mesh_file_name + + # Handle OBJ's mtl dependency + if mesh_file_name.lower().endswith(".obj"): + mtl_filename = os.path.splitext(mesh_file_name)[0] + ".mtl" + original_mtl_path = os.path.join( + os.path.dirname(original_mesh_path), mtl_filename + ) + if os.path.exists(original_mtl_path): + new_mtl_path = os.path.join(target_dir, os.path.basename(mtl_filename)) + shutil.copyfile(original_mtl_path, new_mtl_path) + # Fix mtllib path in obj file to reference local filename + with open(new_mesh_path, "r") as f: + obj_content = f.read() + obj_content = obj_content.replace( + f"mtllib {mtl_filename}", f"mtllib {os.path.basename(mtl_filename)}" + ) + with open(new_mesh_path, "w") as f: + f.write(obj_content) + + # Handle DAE's texture dependency + if mesh_file_name.lower().endswith(".dae"): + try: + dae_tree = ET.parse(original_mesh_path) + dae_root = dae_tree.getroot() + ns = {} + if "}" in dae_root.tag: + ns["c"] = dae_root.tag.split("}")[0].strip("{") + image_tags = dae_root.findall(".//c:image", ns) + else: + image_tags = dae_root.findall(".//image") + for image in image_tags: + init_from = ( + image.find("c:init_from", ns) if ns else image.find("init_from") + ) + if init_from is not None and init_from.text: + tex_filename = os.path.basename(init_from.text) + original_tex_path = os.path.join( + os.path.dirname(original_mesh_path), tex_filename + ) + if os.path.exists(original_tex_path): + new_tex_path = os.path.join(target_dir, tex_filename) + shutil.copyfile(original_tex_path, new_tex_path) + except Exception as e: + self.logger.warning( + f"Failed to parse DAE texture dependency: {e}", exc_info=True + ) + + return os.path.join(sub_folder, comp_name, os.path.basename(mesh_file_name)) diff --git a/embodichain/toolkits/urdf_assembly/sensor.py b/embodichain/toolkits/urdf_assembly/sensor.py new file mode 100644 index 00000000..6cf87829 --- /dev/null +++ b/embodichain/toolkits/urdf_assembly/sensor.py @@ -0,0 +1,957 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os +import copy +import traceback +import numpy as np +from dataclasses import dataclass +import xml.etree.ElementTree as ET + +from scipy.spatial.transform import Rotation as R +from typing import Dict, List, Optional, Union, Tuple + +from embodichain.toolkits.urdf_assembly.logging_utils import ( + URDFAssemblyLogger, +) +from embodichain.toolkits.urdf_assembly.mesh import URDFMeshManager + +__all__ = ["SensorRegistry", "SensorAttachment", "URDFSensorManager"] + + +class SensorRegistry: + """Registry for storing and retrieving SensorAttachment objects.""" + + def __init__(self): + self._sensors = {} + + def add(self, sensor_name: str, sensor_obj): + self._sensors[sensor_name] = sensor_obj + + def get(self, sensor_name: str): + return self._sensors.get(sensor_name) + + def all(self): + return self._sensors + + def remove(self, sensor_name: str): + if sensor_name in self._sensors: + del self._sensors[sensor_name] + + +@dataclass +class SensorAttachment: + r"""Represents a sensor attachment configuration to a robot component. + + This dataclass defines how a sensor should be attached to a specific component + and link within the robot assembly, including optional spatial transformation + to position the sensor correctly relative to the attachment point. + """ + + sensor_urdf: str # Path to the sensor's URDF file + parent_component: str # Name of the component to which the sensor is attached + parent_link: str # Specific link name within the parent component for attachment + transform: Optional[ + np.ndarray + ] = None # Optional 4x4 transformation matrix for sensor positioning + sensor_type: Optional[str] = None # Optional sensor type field + + +class URDFSensorManager: + r"""Responsible for loading, processing, and managing sensor attachments.""" + + def __init__(self, mesh_manager: URDFMeshManager): + r"""Initialize the URDFSensorManager. + + Args: + mesh_manager (URDFMeshManager): Manager for handling mesh files. + """ + self.mesh_manager = mesh_manager + self.logger = URDFAssemblyLogger.get_logger("sensor_manager") + self.attached_sensors = {} # Maps sensor_name to processed sensor data + + def attach_sensor( + self, + sensor_name: str, + sensor_source: Union[ + str, ET.Element, Dict, Tuple[List[ET.Element], List[ET.Element]] + ], + parent_component: str, + parent_link: str, + transform: Optional[np.ndarray] = None, + sensor_type: Optional[str] = None, + extract_links: Optional[List[str]] = None, + extract_joints: Optional[List[str]] = None, + ) -> bool: + r"""Attach a sensor to a specific component and link with multiple input format support. + + Args: + sensor_name (str): Unique identifier for the sensor attachment. + sensor_source: Sensor definition source (multiple formats supported). + parent_component (str): Target component name for sensor attachment. + parent_link (str): Specific link within parent component for attachment. + transform (Optional[np.ndarray]): Optional 4x4 homogeneous transformation matrix. + sensor_type (Optional[str]): Sensor type classification. + extract_links (Optional[List[str]]): Specific link names to extract from URDF. + extract_joints (Optional[List[str]]): Specific joint names to extract from URDF. + + Returns: + bool: True if sensor attachment successful, False on failure. + """ + try: + # Phase 1: Input validation + if not self._validate_sensor_params( + sensor_name, sensor_source, parent_component, parent_link, transform + ): + return False + + # Phase 2: Process sensor source based on input type + sensor_elements = self._process_sensor_source( + sensor_source, extract_links, extract_joints, sensor_name + ) + + if not sensor_elements: + self.logger.error("Failed to process sensor source") + return False + + sensor_links, sensor_joints, sensor_urdf_path = sensor_elements + + # Phase 3: Validate sensor elements + if not self._validate_sensor_elements(sensor_links, sensor_joints): + return False + + # Phase 4: Process and rename sensor elements to avoid conflicts + processed_elements = self._process_and_rename_sensor_elements( + sensor_links, sensor_joints, sensor_name + ) + + if not processed_elements: + self.logger.error("Failed to process sensor elements") + return False + + processed_links, processed_joints = processed_elements + + # Phase 5: Create sensor attachment data (compatible with existing SensorAttachment) + sensor_attachment = SensorAttachment( + sensor_urdf=sensor_urdf_path, + parent_component=parent_component, + parent_link=parent_link, + transform=transform, + ) + + # Store processed sensor data + self.attached_sensors[sensor_name] = { + "attachment": sensor_attachment, + "links": processed_links, + "joints": processed_joints, + "sensor_type": sensor_type, + } + + self.logger.debug( + f"Successfully attached sensor [{sensor_name}] " + f"({sensor_type or 'unspecified'}) with {len(processed_links)} links " + f"and {len(processed_joints)} joints to component ({parent_component}) " + f"at link ({parent_link})" + ) + return True + + except Exception as e: + self.logger.error(f"Sensor attachment failed for [{sensor_name}]: {str(e)}") + self.logger.debug(f"Traceback: {traceback.format_exc()}") + return False + + def _validate_sensor_params( + self, + sensor_name: str, + sensor_source, + parent_component: str, + parent_link: str, + transform: Optional[np.ndarray], + ) -> bool: + r"""Validate input parameters for sensor attachment. + + Args: + sensor_name: Sensor identifier to validate + sensor_source: Sensor source to validate + parent_component: Parent component name to validate + parent_link: Parent link name to validate + transform: Transformation matrix to validate + + Returns: + bool: True if all parameters are valid, False otherwise + """ + # Validate sensor name + if not sensor_name or not isinstance(sensor_name, str): + self.logger.error("Sensor name must be a non-empty string") + return False + + if not sensor_name.replace("_", "").replace("-", "").isalnum(): + self.logger.error( + "Sensor name must contain only alphanumeric characters, underscores, and hyphens" + ) + return False + + # Validate sensor source + if sensor_source is None: + self.logger.error("Sensor source cannot be None") + return False + + # Validate parent component and link + if not parent_component or not isinstance(parent_component, str): + self.logger.error("Parent component must be a non-empty string") + return False + + if not parent_link or not isinstance(parent_link, str): + self.logger.error("Parent link must be a non-empty string") + return False + + # Validate transformation matrix if provided + if transform is not None: + if not isinstance(transform, np.ndarray): + self.logger.error("Transform must be a numpy array") + return False + + if transform.shape != (4, 4): + self.logger.error( + f"Transform must be 4x4 matrix, got shape {transform.shape}" + ) + return False + + if not self._is_valid_homogeneous_transform(transform): + self.logger.error( + "Transform is not a valid homogeneous transformation matrix" + ) + return False + + return True + + def _process_sensor_source( + self, + sensor_source, + extract_links: Optional[List[str]], + extract_joints: Optional[List[str]], + sensor_name: str, + ) -> Optional[Tuple[List[ET.Element], List[ET.Element], str]]: + r"""Process sensor source based on input type and extract relevant elements. + + Args: + sensor_source: Input sensor source in various formats + extract_links: Optional list of specific link names to extract + extract_joints: Optional list of specific joint names to extract + sensor_name: Sensor name for path generation + + Returns: + Optional tuple of (links, joints, urdf_path) or None on failure + """ + try: + if isinstance(sensor_source, str): + # Handle URDF file path + return self._process_urdf_file_source( + sensor_source, extract_links, extract_joints + ) + + elif isinstance(sensor_source, ET.Element): + # Handle pre-loaded URDF element + return self._process_urdf_element_source( + sensor_source, extract_links, extract_joints, sensor_name + ) + + elif isinstance(sensor_source, dict): + # Handle configuration dictionary + return self._process_config_dict_source(sensor_source, sensor_name) + + elif isinstance(sensor_source, tuple) and len(sensor_source) == 2: + # Handle direct (links, joints) tuple + return self._process_element_tuple_source(sensor_source, sensor_name) + + else: + self.logger.error( + f"Unsupported sensor source type: {type(sensor_source)}" + ) + return None + + except Exception as e: + self.logger.error(f"Error processing sensor source: {str(e)}") + return None + + def _process_urdf_file_source( + self, + file_path: str, + extract_links: Optional[List[str]], + extract_joints: Optional[List[str]], + ) -> Optional[Tuple[List[ET.Element], List[ET.Element], str]]: + r"""Process URDF file source and extract specified elements. + + Args: + file_path: Path to URDF file + extract_links: Optional list of link names to extract + extract_joints: Optional list of joint names to extract + + Returns: + Tuple of (links, joints, file_path) or None on failure + """ + if not os.path.exists(file_path): + self.logger.error(f"Sensor URDF file not found: {file_path}") + return None + + try: + urdf_element = ET.parse(file_path).getroot() + links, joints = self._extract_elements_from_urdf( + urdf_element, extract_links, extract_joints + ) + return links, joints, file_path + + except ET.ParseError as e: + self.logger.error(f"Failed to parse URDF file {file_path}: {str(e)}") + return None + + def _process_urdf_element_source( + self, + urdf_element: ET.Element, + extract_links: Optional[List[str]], + extract_joints: Optional[List[str]], + sensor_name: str, + ) -> Tuple[List[ET.Element], List[ET.Element], str]: + r"""Process pre-loaded URDF element source. + + Args: + urdf_element: Pre-loaded URDF root element + extract_links: Optional list of link names to extract + extract_joints: Optional list of joint names to extract + sensor_name: Sensor name for path generation + + Returns: + Tuple of (links, joints, generated_path) + """ + links, joints = self._extract_elements_from_urdf( + urdf_element, extract_links, extract_joints + ) + generated_path = f"" + return links, joints, generated_path + + def _process_config_dict_source( + self, config: Dict, sensor_name: str + ) -> Tuple[List[ET.Element], List[ET.Element], str]: + r"""Process configuration dictionary source and create URDF elements. + + Args: + config: Configuration dictionary for sensor creation + sensor_name: Sensor name for element generation + + Returns: + Tuple of (links, joints, generated_path) + """ + urdf_element = self._create_sensor_from_config(config, sensor_name) + links = urdf_element.findall("link") + joints = urdf_element.findall("joint") + generated_path = f"" + return links, joints, generated_path + + def _process_element_tuple_source( + self, element_tuple: Tuple, sensor_name: str + ) -> Optional[Tuple[List[ET.Element], List[ET.Element], str]]: + r"""Process direct element tuple source. + + Args: + element_tuple: Tuple containing (links_list, joints_list) + sensor_name: Sensor name for path generation + + Returns: + Tuple of (links, joints, generated_path) or None on failure + """ + links_list, joints_list = element_tuple + + if not isinstance(links_list, list) or not isinstance(joints_list, list): + self.logger.error( + "Element tuple must contain (List[ET.Element], List[ET.Element])" + ) + return None + + # Validate that all elements are actually ET.Element instances + for i, link in enumerate(links_list): + if not isinstance(link, ET.Element): + self.logger.error(f"Links list item {i} is not an ET.Element") + return None + + for i, joint in enumerate(joints_list): + if not isinstance(joint, ET.Element): + self.logger.error(f"Joints list item {i} is not an ET.Element") + return None + + generated_path = f"" + return links_list, joints_list, generated_path + + def _extract_elements_from_urdf( + self, + urdf_element: ET.Element, + extract_links: Optional[List[str]] = None, + extract_joints: Optional[List[str]] = None, + ) -> Tuple[List[ET.Element], List[ET.Element]]: + r"""Extract specified links and joints from URDF element. + + Args: + urdf_element: URDF root element to extract from + extract_links: Optional list of specific link names to extract + extract_joints: Optional list of specific joint names to extract + + Returns: + Tuple of (extracted_links, extracted_joints) + """ + # Extract links + all_links = urdf_element.findall("link") + if extract_links: + links = [] + for link_name in extract_links: + link = urdf_element.find(f".//link[@name='{link_name}']") + if link is not None: + links.append(link) + self.logger.debug(f"Extracted link: {link_name}") + else: + self.logger.warning(f"Link '{link_name}' not found in URDF") + else: + links = all_links + self.logger.debug(f"Extracted all {len(links)} links from URDF") + + # Extract joints + all_joints = urdf_element.findall("joint") + if extract_joints: + joints = [] + for joint_name in extract_joints: + joint = urdf_element.find(f".//joint[@name='{joint_name}']") + if joint is not None: + joints.append(joint) + self.logger.debug(f"Extracted joint: {joint_name}") + else: + self.logger.warning(f"Joint '{joint_name}' not found in URDF") + else: + joints = all_joints + self.logger.debug(f"Extracted all {len(joints)} joints from URDF") + + return links, joints + + def _validate_sensor_elements( + self, sensor_links: List[ET.Element], sensor_joints: List[ET.Element] + ) -> bool: + r"""Validate extracted sensor elements for completeness and consistency. + + Args: + sensor_links: List of sensor link elements + sensor_joints: List of sensor joint elements + + Returns: + bool: True if elements are valid, False otherwise + """ + if not sensor_links: + self.logger.error("No links found in sensor definition") + return False + + # Validate link elements + for i, link in enumerate(sensor_links): + if not isinstance(link, ET.Element): + self.logger.error(f"Invalid link element at index {i}") + return False + + link_name = link.get("name") + if not link_name: + self.logger.error(f"Link at index {i} has no name attribute") + return False + + # Validate joint elements + for i, joint in enumerate(sensor_joints): + if not isinstance(joint, ET.Element): + self.logger.error(f"Invalid joint element at index {i}") + return False + + joint_name = joint.get("name") + if not joint_name: + self.logger.error(f"Joint at index {i} has no name attribute") + return False + + self.logger.debug( + f"Validated {len(sensor_links)} links and {len(sensor_joints)} joints" + ) + return True + + def _is_valid_homogeneous_transform(self, transform: np.ndarray) -> bool: + """ + Validate that a 4x4 matrix is a plausible homogeneous transformation matrix. + Only warn if not strictly valid, but still return True. + + Args: + transform: 4x4 transformation matrix to validate + + Returns: + bool: Always True, but warns if not strictly valid + """ + try: + # Check shape + if transform.shape != (4, 4): + self.logger.warning("Transform matrix is not 4x4.") + return False + + # Check bottom row + expected_bottom_row = np.array([0, 0, 0, 1]) + if not np.allclose(transform[3, :], expected_bottom_row, atol=1e-6): + self.logger.warning("Transform bottom row is not [0, 0, 0, 1].") + + # Check rotation matrix orthogonality + rotation_matrix = transform[:3, :3] + should_be_identity = np.dot(rotation_matrix, rotation_matrix.T) + if not np.allclose(should_be_identity, np.eye(3), atol=1e-6): + self.logger.warning("Rotation part of transform is not orthogonal.") + + # Check determinant + if not np.isclose(np.linalg.det(rotation_matrix), 1.0, atol=1e-6): + self.logger.warning("Rotation matrix determinant is not close to 1.") + + # Always return True, just warn + return True + + except Exception as e: + self.logger.warning(f"Transform validation exception: {e}") + return True + + def _process_and_rename_sensor_elements( + self, + sensor_links: List[ET.Element], + sensor_joints: List[ET.Element], + sensor_name: str, + ) -> Optional[Tuple[List[ET.Element], List[ET.Element]]]: + r"""Process and rename sensor link and joint elements to avoid name conflicts. + + Args: + sensor_links (List[ET.Element]): List of sensor link XML elements. + sensor_joints (List[ET.Element]): List of sensor joint XML elements. + sensor_name (str): The sensor's name, used as a prefix. + + Returns: + Optional[Tuple[List[ET.Element], List[ET.Element]]]: Tuple of processed (links, joints), + or None if processing fails. + """ + try: + processed_links = [] + processed_joints = [] + sensor_prefix = f"{sensor_name}_" + sensor_name_lower = sensor_name.lower() + link_name_mapping = {} + + # Process links: add prefix if needed and build mapping + for link in sensor_links: + original_name = link.get("name") + # If the name already contains the sensor name (case-insensitive), do not add prefix + if sensor_name_lower in original_name.lower(): + new_name = original_name + else: + new_name = f"{sensor_prefix}{original_name}" + link_name_mapping[original_name] = new_name + new_link = copy.deepcopy(link) + new_link.set("name", new_name) + processed_links.append(new_link) + + # Process joints: add prefix if needed and update parent/child references + for joint in sensor_joints: + original_name = joint.get("name") + if sensor_name_lower in original_name.lower(): + new_name = original_name + else: + new_name = f"{sensor_prefix}{original_name}" + new_joint = copy.deepcopy(joint) + new_joint.set("name", new_name) + + # Update parent link reference + parent_elem = new_joint.find("parent") + if parent_elem is not None: + parent_link_name = parent_elem.get("link") + parent_elem.set( + "link", + link_name_mapping.get(parent_link_name, parent_link_name), + ) + # Update child link reference + child_elem = new_joint.find("child") + if child_elem is not None: + child_link_name = child_elem.get("link") + child_elem.set( + "link", link_name_mapping.get(child_link_name, child_link_name) + ) + + processed_joints.append(new_joint) + + return processed_links, processed_joints + except Exception as e: + self.logger.error(f"Failed to process sensor elements: {str(e)}") + return None + + def _create_sensor_from_config(self, config: Dict, sensor_name: str) -> ET.Element: + r"""Create sensor URDF element from configuration dictionary. + + Args: + config: Configuration dictionary containing sensor specifications + sensor_name: Name for the generated sensor + + Returns: + ET.Element: Root element of generated sensor URDF + """ + # Create root robot element + robot = ET.Element("robot", name=f"sensor_{sensor_name}") + + # Create main sensor link + link_name = config.get("link_name", f"{sensor_name}_link") + link = ET.SubElement(robot, "link", name=link_name) + + # Add visual element if specified + if "visual" in config: + visual_config = config["visual"] + visual = ET.SubElement(link, "visual") + + # Add origin if specified + if "origin" in visual_config: + origin_data = visual_config["origin"] + ET.SubElement( + visual, + "origin", + xyz=origin_data.get("xyz", "0 0 0"), + rpy=origin_data.get("rpy", "0 0 0"), + ) + + # Add geometry + geometry = ET.SubElement(visual, "geometry") + geom_type = visual_config.get("type", "box") + + if geom_type == "box": + size = visual_config.get("size", "0.1 0.1 0.1") + ET.SubElement(geometry, "box", size=size) + + elif geom_type == "cylinder": + radius = str(visual_config.get("radius", 0.05)) + length = str(visual_config.get("length", 0.1)) + ET.SubElement(geometry, "cylinder", radius=radius, length=length) + + elif geom_type == "sphere": + radius = str(visual_config.get("radius", 0.05)) + ET.SubElement(geometry, "sphere", radius=radius) + + elif geom_type == "mesh": + filename = visual_config.get("filename", "") + if filename: + mesh_elem = ET.SubElement(geometry, "mesh", filename=filename) + if "scale" in visual_config: + mesh_elem.set("scale", visual_config["scale"]) + + # Add material/color if specified + if "color" in visual_config: + material = ET.SubElement( + visual, "material", name=f"{sensor_name}_material" + ) + ET.SubElement(material, "color", rgba=visual_config["color"]) + + # Add collision element if specified + if "collision" in config: + collision_config = config["collision"] + collision = ET.SubElement(link, "collision") + + # Add origin if specified + if "origin" in collision_config: + origin_data = collision_config["origin"] + ET.SubElement( + collision, + "origin", + xyz=origin_data.get("xyz", "0 0 0"), + rpy=origin_data.get("rpy", "0 0 0"), + ) + + # Add geometry (similar to visual) + geometry = ET.SubElement(collision, "geometry") + geom_type = collision_config.get("type", "box") + + if geom_type == "box": + size = collision_config.get("size", "0.1 0.1 0.1") + ET.SubElement(geometry, "box", size=size) + + elif geom_type == "cylinder": + radius = str(collision_config.get("radius", 0.05)) + length = str(collision_config.get("length", 0.1)) + ET.SubElement(geometry, "cylinder", radius=radius, length=length) + + elif geom_type == "sphere": + radius = str(collision_config.get("radius", 0.05)) + ET.SubElement(geometry, "sphere", radius=radius) + + # Add inertial properties if specified + if "inertial" in config: + inertial_config = config["inertial"] + inertial = ET.SubElement(link, "inertial") + + # Add origin if specified + if "origin" in inertial_config: + origin_data = inertial_config["origin"] + ET.SubElement( + inertial, + "origin", + xyz=origin_data.get("xyz", "0 0 0"), + rpy=origin_data.get("rpy", "0 0 0"), + ) + + # Add mass + mass_value = str(inertial_config.get("mass", 0.1)) + ET.SubElement(inertial, "mass", value=mass_value) + + # Add inertia tensor + inertia_elem = ET.SubElement(inertial, "inertia") + inertia_properties = { + "ixx": "ixx", + "iyy": "iyy", + "izz": "izz", + "ixy": "ixy", + "ixz": "ixz", + "iyz": "iyz", + } + + for attr, config_key in inertia_properties.items(): + value = str(inertial_config.get(config_key, 0.0)) + inertia_elem.set(attr, value) + + # Add any additional joints if specified in config + if "joints" in config: + for joint_config in config["joints"]: + joint = ET.SubElement( + robot, + "joint", + name=joint_config.get("name", f"{sensor_name}_joint"), + type=joint_config.get("type", "fixed"), + ) + + # Add origin + if "origin" in joint_config: + origin_data = joint_config["origin"] + ET.SubElement( + joint, + "origin", + xyz=origin_data.get("xyz", "0 0 0"), + rpy=origin_data.get("rpy", "0 0 0"), + ) + + # Add parent and child links + if "parent" in joint_config: + ET.SubElement(joint, "parent", link=joint_config["parent"]) + if "child" in joint_config: + ET.SubElement(joint, "child", link=joint_config["child"]) + + # Add axis for revolute/prismatic joints + if ( + joint_config.get("type") in ["revolute", "prismatic"] + and "axis" in joint_config + ): + ET.SubElement(joint, "axis", xyz=joint_config["axis"]) + + # Add limits for revolute/prismatic joints + if "limits" in joint_config: + limits_data = joint_config["limits"] + limit_elem = ET.SubElement(joint, "limit") + for attr in ["lower", "upper", "effort", "velocity"]: + if attr in limits_data: + limit_elem.set(attr, str(limits_data[attr])) + + self.logger.debug(f"Generated sensor URDF from config for '{sensor_name}'") + return robot + + def process_sensor_attachments( + self, + links: list, + joints: list, + base_points: dict, + existing_link_names: set, + existing_joint_names: set, + ): + r"""Process all attached sensors by adding their link and joint elements to the robot. + + Args: + links (list): Global list to collect sensor link elements. + joints (list): Global list to collect sensor joint elements. + base_points (dict): Mapping from component names to their base link names. + existing_link_names (set): Set of existing link names to avoid conflicts. + existing_joint_names (set): Set of existing joint names to avoid conflicts. + """ + for sensor_name, sensor_data in self.attached_sensors.items(): + try: + attachment = sensor_data["attachment"] + sensor_links = sensor_data["links"] + sensor_joints = sensor_data["joints"] + sensor_type = sensor_data.get("sensor_type", "unknown") + + self.logger.debug( + f"Processing sensor attachment: {sensor_name} ({sensor_type})" + ) + + # Process sensor links: ensure names are lowercase and prefixed + for link in sensor_links: + link_name = link.get("name") + if link_name: + # Get original and sensor type strings + original_name = link_name.lower() + sensor_type_str = ( + str(sensor_type).lower() if sensor_type else "" + ) + # Add prefix only if not already present + if sensor_type_str and sensor_type_str not in original_name: + formatted_name = f"{original_name}_{sensor_type_str}" + else: + formatted_name = original_name + + # Ensure unique link names + unique_name = formatted_name + count = 1 + while unique_name in existing_link_names: + unique_name = f"{formatted_name}_{count}" + self.logger.warning( + f"Link name '{unique_name}' already exists. Trying a new name '{unique_name}' with suffix: '{count}'" + ) + formatted_name = unique_name + count += 1 + + link.set("name", formatted_name) + + # Track link names and add to global list + existing_link_names.add(formatted_name) + links.append(link) + + # Process meshes for this sensor link + self._process_sensor_meshes( + link, attachment.sensor_urdf, sensor_name + ) + + self.logger.debug(f"Added sensor link: {formatted_name}") + + # Process sensor joints: ensure names are UPPERCASE and follow PARENT_TO_CHILD format + for joint in sensor_joints: + joint_name = joint.get("name") + if joint_name: + parent_elem = joint.find("parent") + child_elem = joint.find("child") + parent_link = ( + parent_elem.get("link").lower() + if parent_elem is not None + else "" + ) + child_link = ( + child_elem.get("link").lower() + if child_elem is not None + else "" + ) + + # Format joint name as PARENT_TO_CHILD in uppercase + formatted_name = f"{parent_link}_to_{child_link}".upper() + joint.set("name", formatted_name) + + # Ensure parent/child link references are lowercase + if parent_elem is not None: + parent_elem.set("link", parent_link) + if child_elem is not None: + child_elem.set("link", child_link) + + if attachment.transform is not None: + transform = attachment.transform + xyz = transform[:3, 3] + rotation = R.from_matrix(transform[:3, :3]) + rpy = rotation.as_euler("xyz") + + origin_elem = joint.find("origin") + if origin_elem is None: + origin_elem = ET.SubElement(joint, "origin") + origin_elem.set("xyz", f"{xyz[0]} {xyz[1]} {xyz[2]}") + origin_elem.set("rpy", f"{rpy[0]} {rpy[1]} {rpy[2]}") + + self.logger.info( + f"Applied transform to sensor joint {joint.get('name')}: xyz={xyz}, rpy={rpy}" + ) + + existing_joint_names.add(formatted_name) + joints.append(joint) + self.logger.debug(f"Added sensor joint: {formatted_name}") + + except Exception as e: + self.logger.error( + f"Failed to process sensor attachment {sensor_name}: {str(e)}" + ) + self.logger.debug(f"Traceback: {traceback.format_exc()}") + + def _process_sensor_meshes( + self, link: ET.Element, base_urdf_path: str, sensor_name: str + ): + r"""Process visual and collision meshes for a sensor link. + + Args: + link (ET.Element): The URDF link element to process. + base_urdf_path (str): The base path for the URDF files. + sensor_name (str): The name of the sensor being processed. + """ + try: + # Process visual meshes + for visual in link.findall("visual"): + geometry = visual.find("geometry") + if geometry is not None: + mesh = geometry.find("mesh") + if mesh is not None: + filename = mesh.get("filename") + if filename is not None: + self.logger.debug( + f"Processing sensor visual mesh: {filename}" + ) + new_mesh_filename = self.mesh_manager.copy_and_modify_mesh_file( + base_urdf_path, + filename, + "Visual", + f"sensor_{sensor_name}", # Use sensor prefix for organization + ) + self.logger.debug( + f"Updated sensor visual mesh filename: {new_mesh_filename}" + ) + mesh.set("filename", new_mesh_filename) + + # Process collision meshes + for collision in link.findall("collision"): + geometry = collision.find("geometry") + if geometry is not None: + mesh = geometry.find("mesh") + if mesh is not None: + filename = mesh.get("filename") + if filename is not None: + self.logger.debug( + f"Processing sensor collision mesh: {filename}" + ) + new_mesh_filename = self.mesh_manager.copy_and_modify_mesh_file( + base_urdf_path, + filename, + "Collision", + f"sensor_{sensor_name}", # Use sensor prefix for organization + ) + self.logger.debug( + f"Updated sensor collision mesh filename: {new_mesh_filename}" + ) + mesh.set("filename", new_mesh_filename) + + except Exception as e: + self.logger.error(f"Failed to process meshes for sensor {sensor_name}: {e}") + + def get_attached_sensors(self) -> Dict: + r"""Get all attached sensors with processed data.""" + return self.attached_sensors + + def convert_to_legacy_format(self) -> Dict: + r"""Convert processed sensors to legacy attach_dict format for compatibility.""" + legacy_dict = {} + for sensor_name, sensor_data in self.attached_sensors.items(): + legacy_dict[sensor_name] = sensor_data["attachment"] + return legacy_dict diff --git a/embodichain/toolkits/urdf_assembly/signature.py b/embodichain/toolkits/urdf_assembly/signature.py new file mode 100644 index 00000000..d0b5de9e --- /dev/null +++ b/embodichain/toolkits/urdf_assembly/signature.py @@ -0,0 +1,204 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os +import json +import hashlib +from pathlib import Path +import xml.etree.ElementTree as ET + +from embodichain.toolkits.urdf_assembly.logging_utils import ( + URDFAssemblyLogger, +) + +__all__ = ["URDFAssemblySignatureManager"] + + +class URDFAssemblySignatureManager: + r"""Simple MD5-based signature manager for URDF assemblies without persistent cache.""" + + def __init__(self): + self.logger = URDFAssemblyLogger.get_logger("signature_manager") + + def _calculate_file_md5(self, file_path: str) -> str: + r"""Calculate MD5 hash of a file.""" + hash_md5 = hashlib.md5() + try: + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + return hash_md5.hexdigest() + except Exception as e: + self.logger.error(f"Error calculating MD5 for {file_path}: {e}") + return "" + + def _calculate_string_md5(self, content: str) -> str: + r"""Calculate MD5 hash of a string.""" + return hashlib.md5(content.encode("utf-8")).hexdigest() + + def calculate_assembly_signature(self, urdf_dict: dict, output_path: str) -> str: + r"""Calculate a unique signature for the assembly configuration. + + Args: + urdf_dict (dict): Dictionary of components and their configurations + output_path (str): Target output path for the assembly + + Returns: + str: MD5 hash representing the assembly configuration + """ + signature_data = { + "output_filename": os.path.basename(output_path), + "components": {}, + } + + def to_serializable(obj): + r"""Recursively convert objects to types that are JSON serializable. + + Args: + obj: The object to convert (could be Path, dict, list, or other types). + + Returns: + The converted object, ready for JSON serialization. + - Path objects are converted to strings. + - dict and list are recursively processed. + - Other types are returned as-is. + """ + if isinstance(obj, Path): + return str(obj) + elif isinstance(obj, dict): + return {k: to_serializable(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [to_serializable(i) for i in obj] + else: + return obj + + # Process each component + for comp_type, comp_obj in urdf_dict.items(): + if comp_obj is None: + continue + + # Calculate file MD5 + file_md5 = self._calculate_file_md5(comp_obj.urdf_path) + if not file_md5: + self.logger.warning(f"Could not calculate MD5 for {comp_obj.urdf_path}") + continue + + # Include component configuration + comp_data = { + "urdf_path": str(comp_obj.urdf_path), + "file_md5": file_md5, + "params": to_serializable(comp_obj.params or {}), + "transform": comp_obj.transform.tolist() + if comp_obj.transform is not None + else None, + } + + signature_data["components"][comp_type] = comp_data + + # Convert to JSON string for consistent hashing + signature_json = json.dumps(signature_data, sort_keys=True, ensure_ascii=False) + assembly_md5 = self._calculate_string_md5(signature_json) + + self.logger.info(f"Assembly signature calculated: [{assembly_md5}]") + self.logger.debug(f"Signature data: {signature_json}") + + return assembly_md5 + + def extract_signature_from_urdf(self, urdf_file_path: str) -> str: + r"""Extract signature from existing URDF file's header comment. + + Args: + urdf_file_path (str): Path to existing URDF file + + Returns: + str: Extracted signature or empty string if not found + """ + try: + with open(urdf_file_path, "r", encoding="utf-8") as f: + content = f.read() + + # Look for signature in comment + import re + + # 1. + # 2. + patterns = [ + r"", + r"", + ] + + for pattern in patterns: + match = re.search(pattern, content) + if match: + signature = match.group(1) + self.logger.info( + f"Found existing signature in ({urdf_file_path}): [{signature}]" + ) + return signature + + self.logger.debug(f"No signature found in {urdf_file_path}") + return "" + + except Exception as e: + self.logger.warning( + f"Failed to extract signature from {urdf_file_path}: {e}", exc_info=True + ) + return "" + + def is_assembly_up_to_date(self, current_signature: str, output_path: str) -> bool: + r"""Check if the assembly at output_path has the same signature as current configuration. + + Args: + current_signature (str): MD5 signature of current assembly configuration + output_path (str): Path to existing URDF file + + Returns: + bool: True if signatures match and file exists + """ + if not os.path.exists(output_path): + self.logger.info(f"Output file does not exist: {output_path}") + return False + + # Verify file is not empty and is valid URDF + try: + if os.path.getsize(output_path) == 0: + self.logger.warning(f"Output file is empty: {output_path}") + return False + + # Try to parse as XML to ensure it's valid + ET.parse(output_path) + except Exception as e: + self.logger.warning( + f"Output file is invalid: {output_path}, error: {e}", exc_info=True + ) + return False + + # Extract signature from existing file + existing_signature = self.extract_signature_from_urdf(output_path) + + if existing_signature == current_signature: + self.logger.info( + f"✅ Assembly is up-to-date. Signature: {current_signature}" + ) + return True + else: + if existing_signature: + self.logger.info(f"Assembly signatures differ:") + self.logger.info(f" Current: {current_signature}") + self.logger.info(f" Existing: {existing_signature}") + else: + self.logger.info(f"No signature found in existing file: {output_path}") + return False diff --git a/embodichain/toolkits/urdf_assembly/urdf_assembly_manager.py b/embodichain/toolkits/urdf_assembly/urdf_assembly_manager.py new file mode 100644 index 00000000..80bcd90f --- /dev/null +++ b/embodichain/toolkits/urdf_assembly/urdf_assembly_manager.py @@ -0,0 +1,783 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os +import time +import logging +import numpy as np +from pathlib import Path +from functools import wraps +import xml.etree.ElementTree as ET + +from scipy.spatial.transform import Rotation as R +from typing import Dict, List, Optional, Union, Tuple + +from embodichain.toolkits.urdf_assembly.logging_utils import ( + setup_urdf_logging, +) +from embodichain.toolkits.urdf_assembly.signature import ( + URDFAssemblySignatureManager, +) +from embodichain.toolkits.urdf_assembly.component import ( + URDFComponent, + ComponentRegistry, + URDFComponentManager, +) +from embodichain.toolkits.urdf_assembly.sensor import ( + SensorAttachment, + SensorRegistry, + URDFSensorManager, +) +from embodichain.toolkits.urdf_assembly.connection import ( + URDFConnectionManager, +) +from embodichain.toolkits.urdf_assembly.mesh import URDFMeshManager +from embodichain.toolkits.urdf_assembly.file_writer import ( + URDFFileWriter, +) +from embodichain.toolkits.urdf_assembly.utils import ( + ensure_directory_exists, +) + +__all__ = ["URDFAssemblyManager"] + + +def performance_monitor(func): + r"""Performance monitoring decorator for tracking function execution time""" + + @wraps(func) + def wrapper(self, *args, **kwargs): + start_time = time.time() + try: + result = func(self, *args, **kwargs) + duration = time.time() - start_time + self.logger.debug(f"{func.__name__} completed in {duration:.3f}s") + return result + except Exception as e: + duration = time.time() - start_time + self.logger.error(f"{func.__name__} failed after {duration:.3f}s: {e}") + raise + + return wrapper + + +class URDFAssemblyManager: + r""" + A class to manage the assembly of URDF files and their components. + """ + # Supported wheel types for chassis components + SUPPORTED_WHEEL_TYPES = [ + "omni", + "differential", + "tracked", + ] + + # Supported robot component types + SUPPORTED_COMPONENTS = [ + "chassis", + "legs", + "torso", + "head", + "left_arm", + "right_arm", + "left_hand", + "right_hand", + "arm", + "hand", + ] + + # Supported sensor types for attachment + SUPPORTED_SENSORS = [ + "camera", + "lidar", + "imu", + "gps", + "force", + ] + + # Supported mesh file formats + SUPPORTED_MESH_TYPES = [ + "stl", + "obj", + "ply", + "dae", + "glb", + ] + + def __init__( + self, + component_registry: ComponentRegistry = None, + sensor_registry: SensorRegistry = None, + mesh_manager: URDFMeshManager = None, + component_manager: "URDFComponentManager" = None, + sensor_manager: "URDFSensorManager" = None, + ): + self.logger = setup_urdf_logging() + + # Use registries for components and sensors + self.component_registry = component_registry or ComponentRegistry() + self.sensor_registry = sensor_registry or SensorRegistry() + + # Initialize mesh manager + self.mesh_manager = mesh_manager or URDFMeshManager(output_dir=".") + + # Initialize managers for components and sensors + self.component_manager = component_manager or URDFComponentManager( + self.mesh_manager + ) + self.sensor_manager = sensor_manager or URDFSensorManager(self.mesh_manager) + + # Processing order for components with their name prefixes + # Tuple format: (component_name, prefix) + self.component_order = [ + ("chassis", None), + ("legs", None), + ("torso", None), + ("head", None), + ("left_arm", "left_"), + ("right_arm", "right_"), + ("left_hand", "left_"), + ("right_hand", "right_"), + ("arm", None), + ("hand", None), + ] + + # Attachment position indices for component connections. + # This dictionary defines which link of each component should be used as the connection point + # when attaching to other components: + # 0 : use the first link in the component's URDF (typically for child connections) + # -1 : use the last link in the component's URDF (typically for parent connections) + # For example, 'chassis': 0 means the first link of the chassis is used for child attachments; + # 'torso': -1 means the last link of the torso is used for child attachments, etc. + self.attach_positions = { + "chassis": 0, # Use first link of chassis for child connections + "legs": -1, # Use last link of legs for child connections + "torso": -1, # Use last link of torso for child connections + "head": 0, # Use first link of head for connections + "left_arm": -1, # Use last link of left_arm for hand attachment + "right_arm": -1, # Use last link of right_arm for hand attachment + "left_hand": 0, # Use first link of left_hand for connections + "right_hand": 0, # Use first link of right_hand for connections + "arm": -1, # Use last link of arm for hand attachment + "hand": 0, # Use first link of hand for connections + } + + # Connection rules defining parent-child relationships between components + self.connection_rules = [ + ("chassis", "legs"), + ("legs", "torso"), + ("chassis", "torso"), + ("chassis", "left_arm"), + ("chassis", "right_arm"), + ("chassis", "arm"), + ("torso", "head"), + ("torso", "left_arm"), + ("torso", "right_arm"), + ("torso", "arm"), + ("left_arm", "left_hand"), + ("right_arm", "right_hand"), + ("arm", "hand"), + ] + + # Configure logging + logging.basicConfig(level=logging.INFO) + + # Name of the base link for the robot + self.base_link_name = "base_link" + + # Initialize the URDF file writer for output formatting + self.file_writer = URDFFileWriter() + + # Initialize signature manager instead of cache manager + self.signature_manager = URDFAssemblySignatureManager() + + def add_component( + self, + component_type: str, + urdf_path: Union[str, Path], + transform: Optional[np.ndarray] = None, + **params, + ) -> bool: + r"""Add a URDF component to the component registry. + + This method creates a URDFComponent object and registers it in the component registry. + + Args: + component_type (str): The type/name of the component (e.g., 'chassis', 'head'). + urdf_path (str or Path): Path to the URDF file for this component. + transform (np.ndarray, optional): 4x4 transformation matrix for positioning the component. + **params: Additional component-specific parameters (e.g., wheel_type for chassis). + + Returns: + bool: True if component added successfully, False otherwise. + """ + try: + if not isinstance(component_type, str): + raise ValueError("component_type must be a string") + if not isinstance(urdf_path, (str, Path)): + raise ValueError("urdf_path must be a string or Path") + + component = URDFComponent( + urdf_path=urdf_path, params=params, transform=transform + ) + self.component_registry.add(component_type, component) + self.logger.info( + f"Added component: [{component_type}], URDF: ({urdf_path})" + ) + return True + except Exception as e: + self.logger.error(f"Failed to add component [{component_type}]: {e}") + return False + + def attach_sensor( + self, + sensor_name: str, + sensor_source, + parent_component: str, + parent_link: str, + transform: Optional[np.ndarray] = None, + **kwargs, + ) -> bool: + r"""Attach a sensor to a specific component and link, and register it in the sensor registry. + + This method creates a SensorAttachment object and registers it in the sensor registry. + + Args: + sensor_name (str): Unique name for the sensor (e.g., 'camera'). + sensor_source (str or ET.Element): Path to the sensor's URDF file or an XML element. + parent_component (str): Name of the component to which the sensor is attached. + parent_link (str): Name of the link within the parent component for attachment. + **kwargs: Additional keyword arguments (e.g., transform, sensor_type). + + Returns: + bool: True if sensor attached successfully, False otherwise. + """ + try: + sensor = SensorAttachment( + sensor_urdf=sensor_source, + parent_component=parent_component, + parent_link=parent_link, + transform=transform, + **kwargs, + ) + self.sensor_registry.add(sensor_name, sensor) + urdf_info = ( + f"\n\tURDF: ({sensor.sensor_urdf})" + if sensor.sensor_urdf + else ", URDF: [N/A]" + ) + self.logger.info( + f"Attached sensor: [{sensor_name}] " + f"to [{parent_component}] at link [{parent_link}]{urdf_info}" + ) + return True + except Exception as e: + self.logger.error(f"Failed to attach sensor [{sensor_name}]: {e}") + return False + + def get_component(self, component_type: str): + r"""Retrieve a component from the registry by its type/name. + + Args: + component_type (str): The type/name of the component to retrieve. + + Returns: + URDFComponent or None: The registered component object, or None if not found. + """ + return self.component_registry.get(component_type) + + def get_attached_sensors(self): + r"""Get all attached sensors from the sensor registry. + + Returns: + dict: A dictionary mapping sensor names to SensorAttachment objects. + """ + return self.sensor_registry.all() + + def _load_urdf(self, urdf_path: str) -> Optional[ET.Element]: + r"""Load a URDF file and return its root element. + + Args: + urdf_path (str): Path to the URDF file. + + Returns: + ET.Element: The root element of the parsed URDF XML. + """ + try: + tree = ET.parse(urdf_path) + return tree.getroot() + except Exception as e: + self.logger.error(f"Failed to load URDF {urdf_path}: {e}") + return None + + def _apply_transformation( + self, urdf: ET.Element, transform: np.ndarray, link_name: str + ): + r"""Applies a transformation matrix to the 'xyz' attributes of the origins of the specified link and its first joint in the URDF. + + Args: + urdf (ET.Element): The root element of the URDF to transform. + transform (np.ndarray): A 4x4 transformation matrix to apply. + link_name (str): The name of the link to apply the transformation to. + """ + # Now handle the first joint connected to this link + for joint in urdf.findall("joint"): + origin = joint.find("origin") + if origin is not None: + # Check if the joint connects to the specified link + child_link = joint.find("child").get("link") + if child_link == link_name: + xyz = np.array([float(val) for val in origin.get("xyz").split()]) + transformed_xyz = np.dot(transform[:3, :3], xyz) + transform[:3, 3] + origin.set("xyz", " ".join(map(str, transformed_xyz))) + + # Apply transformation to rpy + if "rpy" in origin.attrib: + rpy = np.array( + [float(val) for val in origin.get("rpy").split()] + ) + rotation = R.from_euler("xyz", rpy) + transformed_rotation = ( + R.from_matrix(transform[:3, :3]) * rotation + ) + transformed_rpy = transformed_rotation.as_euler("xyz") + origin.set("rpy", " ".join(map(str, transformed_rpy))) + elif "quat" in origin.attrib: + quat = np.array( + [float(val) for val in origin.get("quat").split()] + ) + rotation = R.from_euler("xyz", quat) + transformed_rotation = ( + R.from_matrix(transform[:3, :3]) * rotation + ) + transformed_rpy = transformed_rotation.as_euler("xyz") + origin.set("rpy", " ".join(map(str, transformed_rpy))) + + break # Stop after processing the first joint + + def _create_base_link(self) -> ET.Element: + r"""Creates a base link and returns it. + + Returns: + ET.Element: The base link element. + """ + base_link = ET.Element("link", name=self.base_link_name) + + return base_link + + def _validate_urdf_file(self, urdf_path: str) -> bool: + r"""Validate URDF file integrity and format compliance + + Args: + urdf_path (str): Path to the URDF file to validate + + Returns: + bool: True if file is valid, False otherwise + """ + try: + # Check if file exists + if not os.path.exists(urdf_path): + self.logger.error(f"URDF file not found: {urdf_path}") + return False + + # Check file size to ensure it's not empty + if os.path.getsize(urdf_path) == 0: + self.logger.error(f"URDF file is empty: {urdf_path}") + return False + + # Attempt to parse XML structure + root = ET.parse(urdf_path).getroot() + if root.tag != "robot": + self.logger.error(f"Invalid URDF root element: {root.tag}") + return False + + # Check for presence of basic link elements + if not root.findall("link"): + self.logger.error(f"No links found in URDF: {urdf_path}") + return False + + # Check robot name attribute + robot_name = root.get("name") + if not robot_name: + self.logger.warning(f"URDF robot has no name attribute: {urdf_path}") + + self.logger.debug(f"URDF file validation passed: {urdf_path}") + return True + + except ET.ParseError as e: + self.logger.error(f"XML parse error in {urdf_path}: {e}") + return False + except Exception as e: + self.logger.error(f"Validation error for {urdf_path}: {e}") + return False + + def _generate_connection_rules(self) -> list: + r"""Dynamically generate connection rules based on available components. + + Returns: + list: A list of (parent, child) tuples specifying connection relationships. + """ + connection_rules = [] + + # Filter components that exist in urdf_dict + existing_components = [ + comp + for comp in self.SUPPORTED_COMPONENTS + if self.component_registry.get(comp) + ] + + self.logger.debug(f"Existing components: {existing_components}") + + # Define explicit connection rules - only meaningful relationships + # Rule 1: chassis connects to torso (if both exist) + if "chassis" in existing_components and "legs" in existing_components: + connection_rules.append(("chassis", "legs")) + if "torso" in existing_components: + connection_rules.append(("legs", "torso")) + elif "chassis" in existing_components and "torso" in existing_components: + # If there are no legs, chassis connects directly to torso + connection_rules.append(("chassis", "torso")) + + # Rule 2: torso connects to head (if both exist) + if "torso" in existing_components and "head" in existing_components: + connection_rules.append(("torso", "head")) + + # Rule 3: torso connects to arms (if they exist) + if "torso" in existing_components: + if "left_arm" in existing_components: + connection_rules.append(("torso", "left_arm")) + if "right_arm" in existing_components: + connection_rules.append(("torso", "right_arm")) + if "arm" in existing_components: + connection_rules.append(("torso", "arm")) + + # Rule 4: arms connect to hands (if both exist) + if "left_arm" in existing_components and "left_hand" in existing_components: + connection_rules.append(("left_arm", "left_hand")) + if "right_arm" in existing_components and "right_hand" in existing_components: + connection_rules.append(("right_arm", "right_hand")) + + # Rule 5: single arm connects to hand + if "arm" in existing_components and "hand" in existing_components: + connection_rules.append(("arm", "hand")) + + # Rule 6: If no torso, chassis can directly connect to head and arms + if "chassis" in existing_components and "torso" not in existing_components: + if "head" in existing_components: + connection_rules.append(("chassis", "head")) + if "left_arm" in existing_components: + connection_rules.append(("chassis", "left_arm")) + if "right_arm" in existing_components: + connection_rules.append(("chassis", "right_arm")) + # Connect single arm directly to chassis (no torso scenario) + if "arm" in existing_components: + connection_rules.append(("chassis", "arm")) + + connection_rules = list(set(connection_rules)) + + self.logger.info( + f"Generated connection rules: {connection_rules}, total {len(connection_rules)} rules" + ) + + return connection_rules + + def _find_end_link( + self, component: str, base_points: dict, joints: list + ) -> Union[str, None]: + """Find the end link of a component by traversing the joint chain downward. + + Args: + component (str): Component name to find the end link for. + base_points (dict): Mapping from component names to their base link names. + joints (list): List of joint elements to traverse. + + Returns: + Union[str, None]: Name of the end link, or None if component not found. + """ + current_link = base_points.get(component) + if not current_link: + return None + + visited_links = set() # Prevent infinite loops in joint chains + while True: + visited_links.add(current_link) + found = False + for joint in joints: + if hasattr(joint, "find"): # Ensure it's an XML element, not a comment + parent = joint.find("parent") + child = joint.find("child") + if parent is not None and parent.get("link") == current_link: + if child is not None: + next_link = child.get("link") + if next_link not in visited_links: # Avoid revisiting links + current_link = next_link + found = True + break + if not found: + break # No further links found in the chain + return current_link + + @performance_monitor + def merge_urdfs( + self, + output_path: str = "./assembly_robot.urdf", + use_signature_check: bool = True, + ) -> ET.Element: + """Merge URDF files according to single base link, connection point naming, + and type compatibility matrix rules. + + Args: + output_path (str): Path where the merged URDF file will be saved. + use_signature_check (bool): Whether to check signatures to avoid redundant processing. + + Returns: + ET.Element: The root element of the merged URDF. + """ + output_path = os.path.abspath(output_path) + assembly_signature = None + + # Log components to be merged + available_components = [ + comp + for comp, obj in self.component_registry.all().items() + if obj is not None + ] + self.logger.info(f"🔧 Preparing to merge components: {available_components}") + + for comp in available_components: + comp_obj = self.component_registry.get(comp) + self.logger.info(f" [{comp}]: {comp_obj.urdf_path}") + if comp_obj.params: + self.logger.debug(f" Parameters: {comp_obj.params}") + if comp_obj.transform is not None: + self.logger.debug(f" Transform: applied") + + if use_signature_check: + # Calculate current assembly signature + assembly_signature = self.signature_manager.calculate_assembly_signature( + self.component_registry.all(), output_path + ) + + self.logger.info(f"Current assembly signature: [{assembly_signature}]") + self.logger.debug(f"Target output path: ({output_path})") + + # Check if assembly is up-to-date + if self.signature_manager.is_assembly_up_to_date( + assembly_signature, output_path + ): + self.logger.info( + f"✅ URDF assembly is up-to-date: ({output_path}), skipping rebuild." + ) + return ET.parse(output_path).getroot() + else: + self.logger.info( + "Assembly configuration has changed or file doesn't exist, rebuilding..." + ) + + # Perform normal assembly process + self.logger.info("🔄 Building new URDF assembly...") + + # 1. Generate standard header with module information + module_names = [ + os.path.splitext(os.path.basename(obj.urdf_path))[0] + for comp, obj in self.component_registry.all().items() + if obj + ] + + robot_name = os.path.splitext(os.path.basename(output_path))[0] + merged_urdf = ET.Element("robot", name=robot_name) + + # 2. Create single base link for the entire robot + base_link = ET.Element("link", name=self.base_link_name) + # Store links and joints separately for proper ordering + links = [base_link] + joints = [] + + # Mapping tables for component processing + name_mapping = {} # Maps (component, original_name) to new_name + base_points = {} # Maps component to its base connection link + parent_attach_points = {} # Maps component to its parent connection link + + # Initialize managers for mesh handling and component processing + output_dir = os.path.dirname(output_path) or "." + ensure_directory_exists(output_dir, self.logger) + mesh_manager = URDFMeshManager(output_dir) + mesh_manager.ensure_dirs() + component_manager = URDFComponentManager(mesh_manager) + connection_manager = URDFConnectionManager(self.base_link_name) + + # Initialize sensor manager with mesh_manager + sensor_manager = URDFSensorManager(mesh_manager) + + # Process any pending enhanced sensors + if hasattr(self, "_pending_sensors"): + for sensor_name, params in self._pending_sensors.items(): + success = sensor_manager.attach_sensor( + sensor_name=sensor_name, **params + ) + if success: + # Sync to legacy attach_dict for backward compatibility + self.attach_dict.update(sensor_manager.convert_to_legacy_format()) + + # 3. Process all components in defined order + connection_rules = self._generate_connection_rules() + + # Collect component transforms for connection joints + component_transforms = {} + for comp, comp_obj in self.component_registry.all().items(): + if comp_obj and comp_obj.transform is not None: + component_transforms[comp] = comp_obj.transform + + for comp, prefix in self.component_order: + comp_obj = self.component_registry.get(comp) + if not comp_obj: + continue + + # Add section comments using file writer + self.file_writer.add_section_comments(links, comp, "Links") + self.file_writer.add_section_comments(joints, comp, "Joints") + + # Parse component URDF to analyze its structure + urdf_root = ET.parse(comp_obj.urdf_path).getroot() + + # Determine parent component and attachment point for current component + parent_component = None + parent_attach_link = None + + # Find parent component based on connection rules + for parent, child in connection_rules: + if child == comp and parent in base_points: + parent_component = parent + # Use base connection point for chassis + if parent == "chassis": + parent_attach_link = base_points[parent] + else: + # For other components, find their end link + parent_attach_link = self._find_end_link( + parent, base_points, joints + ) + break + + if parent_component and parent_attach_link: + self.logger.debug( + f"Component [{comp}] will connect to parent [{parent_component}] at link: ({parent_attach_link})" + ) + else: + self.logger.debug( + f"Component [{comp}] has no parent component (likely chassis or standalone)" + ) + + # Process the component using the component manager + component_manager.process_component( + comp, prefix, comp_obj, name_mapping, base_points, links, joints + ) + + # Determine attachment positions for current component + original_links = urdf_root.findall("link") + + if original_links: + # Set base connection point (always first link for child connections) + first_original_name = original_links[0].get("name") + first_mapped_name = name_mapping.get( + (comp, first_original_name), first_original_name + ) + base_points[comp] = first_mapped_name + + self.logger.debug( + f"Set base_points[{comp}] = ({first_mapped_name}) .first link for child connection, original: ({first_original_name})" + ) + + # Set parent connection point based on attach_positions configuration + index = self.attach_positions.get(comp, 0) + try: + if 0 <= index < len(original_links): + original_attach_name = original_links[index].get("name") + elif index == -1 and original_links: + original_attach_name = original_links[-1].get("name") + else: + original_attach_name = ( + original_links[0].get("name") if original_links else None + ) + + # Find mapped name for the attachment point + mapped_attach_name = name_mapping.get( + (comp, original_attach_name), original_attach_name + ) + parent_attach_points[comp] = mapped_attach_name + + self.logger.debug( + f"Set parent_attach_points[{comp}] = ({mapped_attach_name}). Index {index} for parent connection, original: ({original_attach_name})" + ) + except IndexError: + # Fall back to first link if index is out of range + parent_attach_points[comp] = first_mapped_name + self.logger.warning( + f"Index {index} out of range for component {comp}, using first link: {first_mapped_name}" + ) + + # Add section end comments using file writer + self.file_writer.add_section_end_comments(links, comp, "Links") + self.file_writer.add_section_end_comments(joints, comp, "Joints") + + # 4. Create connection joints between components using transforms + connection_manager.add_connections( + joints, + base_points, + parent_attach_points, + connection_rules, + component_transforms, + ) + + # Track existing names for sensor processing + existing_link_names = { + link.get("name").lower() for link in links if link.get("name") + } + existing_joint_names = { + joint.get("name").upper() for joint in joints if joint.get("name") + } + + # 5. Process sensor attachments using the new sensor manager + for sensor_name, sensor_attach in self.sensor_registry.all().items(): + sensor_manager.attach_sensor( + sensor_name=sensor_name, + sensor_source=sensor_attach.sensor_urdf, + parent_component=sensor_attach.parent_component, + parent_link=sensor_attach.parent_link, + transform=sensor_attach.transform, + ) + + sensor_manager.process_sensor_attachments( + links, joints, base_points, existing_link_names, existing_joint_names + ) + + # 6. Add all links and joints to merged URDF in proper order + for link in links: + merged_urdf.append(link) + for joint in joints: + merged_urdf.append(joint) + + # 7. Write the final URDF file with proper formatting, header and signature + if use_signature_check and assembly_signature: + self.file_writer.write_urdf( + merged_urdf, output_path, module_names, assembly_signature + ) + self.logger.info( + f"✅ URDF assembly written with signature: {assembly_signature}" + ) + else: + self.file_writer.write_urdf(merged_urdf, output_path, module_names) + self.logger.info("✅ URDF assembly written without signature.") + return merged_urdf diff --git a/embodichain/toolkits/urdf_assembly/utils.py b/embodichain/toolkits/urdf_assembly/utils.py new file mode 100644 index 00000000..68ed896c --- /dev/null +++ b/embodichain/toolkits/urdf_assembly/utils.py @@ -0,0 +1,30 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from pathlib import Path +import logging + + +def ensure_directory_exists(path: str, logger: logging.Logger = None): + """Ensure the directory exists, create if not.""" + try: + path_obj = Path(path) + path_obj.mkdir(parents=True, exist_ok=True) + except Exception as e: + if logger: + logger.error(f"Failed to create directory {path}: {e}") + else: + print(f"Failed to create directory {path}: {e}") diff --git a/embodichain/utils/__init__.py b/embodichain/utils/__init__.py new file mode 100644 index 00000000..f30c7bcc --- /dev/null +++ b/embodichain/utils/__init__.py @@ -0,0 +1,59 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from .configclass import configclass, is_configclass + + +GLOBAL_SEED = 1024 + + +def set_seed(seed: int, deterministic: bool = False) -> int: + """Set the random seed for reproducibility. + + Args: + seed (int): The seed value to set. If -1, a random seed will be generated. + deterministic (bool): If True, sets the environment to deterministic mode for reproducibility. + """ + import random + import numpy as np + import torch + import os + import warp as wp + + if seed == -1 and deterministic: + seed = GLOBAL_SEED + elif seed == -1: + seed = np.random.randint(0, 10000) + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + wp.rand_init(seed) + + if deterministic: + # refer to https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.use_deterministic_algorithms(True) + else: + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + + return seed diff --git a/embodichain/utils/cfg.py b/embodichain/utils/cfg.py new file mode 100644 index 00000000..47faa48b --- /dev/null +++ b/embodichain/utils/cfg.py @@ -0,0 +1,598 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import functools +import inspect +import logging + +from copy import deepcopy +from pathlib import Path +from typing import Dict, List, Optional, Union + +from fvcore.common.config import BASE_KEY +from fvcore.common.config import CfgNode as _CfgNode +from iopath.common.file_io import PathManager as PathManagerBase +from yacs.config import _VALID_TYPES, _assert_with_logging + +sep = "." +prefix = "" +PathManager = PathManagerBase() +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +def _flatten_dict( + src: Dict, + prefix: Optional[str] = prefix, + sep: Optional[str] = sep, + dct: Optional[Dict] = {}, +) -> Dict: + """Traverse a dictionary and return all keys including nested ones. + + Args: + src (Dict): an instance of :class:`Dict`. + prefix (Optional[str], optional): [description]. Defaults to prefix. + sep (Optional[str], optional): [description]. Defaults to sep. + dct (Optional[Dict], optional): [description]. Defaults to {}. + + Returns: + Dict: flatten dictionary with all keys. + """ + items = [] + for k, v in src.items(): + new_key = prefix + sep + k if prefix else k + if isinstance(v, dict): + items.extend(_flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def _dict_depth(d: Dict | CfgNode) -> int: + """Calculate the maximal depth of dictionary + + Args: + d (Dict): an instance of :class:`Dict`. + + Returns: + int: maximal depth. + """ + if isinstance(d, dict): + # 如果d是空dict就直接给0 + return 1 + (max(map(_dict_depth, d.values())) if d else 0) + # return 1 + (max(map(_dict_depth, d.values())) if d else 0) + else: + return 0 + # 无限递归最后肯定不是dict,也就是说肯定会raise error,这是不合理的 + # TypeError("Expected type is dict but {} is received".format( + # type(d).__name__)) + + +# NOTE: given the new config system +# (https://detectron2.readthedocs.io/en/latest/tutorials/lazyconfigs.html), +# they will stop adding new functionalities to default CfgNode. + + +# NOTE: maybe someday one require save config orderly, I have tried and find it not easy. +# there is a method making yaml.load() output ordered dict: https://tendcode.com/article/yaml_order/ , +# but yacs.config.CfgNode is a subclass of :class:`Dict`, so it may hard to make a dict +# subclass has ordered key when initialize. +class CfgNode(_CfgNode): + # counter records user visits of every attributes and is used in self.unvisited_keys() + COUNTER = "__COUNTER__" + CACHED_NAMES = "__CACHED_NAMES__" + + def __init__(self, init_dict=None, key_list=None, new_allowed=False): + """ + Args: + init_dict (dict): the possibly-nested dictionary to initailize the CfgNode. + key_list (list[str]): a list of names which index this CfgNode from the root. + Currently only used for logging purposes. + new_allowed (bool): whether adding new key is allowed when merging with + other configs. + """ + super(CfgNode, self).__init__(init_dict) + # when self.load_cfg_from_file(), it consequently goes to cls(cfg_as_dict), where `init_dict` is not None + # counter dict only contain flattened leaf node of a CfgNode rather than direct child node, + # for example, counter dict of node + # TOPKEYA: + # KEYA: "value1" + # KEYB: + # SUBKEYA: 1000 + # SUBKEYB: 2000 + # has key ['TOPKEYA.KEYA', 'TOPKEYA.KEYB.SUBKEYA', 'TOPKEYA.KEYB.SUBKEYB'], but has no 'TOPKEYA' or 'TOPKEA.KEYB', + # and the counter dict of node TOPKEYA has key ['KEYA', 'KEYB.SUBKEYA', 'KEYB.SUBKEYB'], but has no 'KEYB'. + if init_dict is not None: + self.__dict__[CfgNode.COUNTER] = _flatten_dict(init_dict) + for key in self.__dict__[CfgNode.COUNTER].keys(): + self.__dict__[CfgNode.COUNTER][key] = 0 + else: + self.__dict__[CfgNode.COUNTER] = {} + self.__dict__[CfgNode.CACHED_NAMES] = [] + + self.set_new_allowed(new_allowed) + + def __getattr__(self, name): + if name in self: + self.__dict__[CfgNode.CACHED_NAMES].append(name) + concated_name = sep.join(self.__dict__[CfgNode.CACHED_NAMES]) + if concated_name in self.__dict__[CfgNode.COUNTER]: + # only parent node of leaf CfgNode can reach here, and top level node can't + self.__dict__[CfgNode.COUNTER][concated_name] += 1 + self.__dict__[CfgNode.CACHED_NAMES] = [] + return self[name] + else: + raise AttributeError(name) + + # TODO: overload __setattr__ to use `new_allowed` to avoid user manually add key by `cfg["key"]=value`. + # Or is it necessary to do that? Because neither yacs and detectron2 make this feature. + + # TODO: When adding a new key, COUNTER does not contain an entry for the newly added key + + @classmethod + def _open_cfg(cls, filename): + return PathManager.open(filename, "r", encoding="utf-8") + + @classmethod + def load_cfg_from_file( + cls, + filename_or_str_content: Union[str, Path], + new_allowed: bool = True, + root_path: str = None, + ) -> CfgNode: + """load configration from a yaml file. + Modified from function load_yaml_with_base() of fvcore.common.config.CfgNode. + The original one do not support `NEW_ALLOWED` key, but I think sometime it will + be needed, so we had better add it. + + Args: + filename_or_str_content (Union[str, Path]): a yaml filename or yaml content string + new_allowed (bool): whether adding new key is allowed when merging with + other configs. + root_path (str): Parent directory of `_BASE_` config. Usually _BASE_ is written + as a relative path, the result will change if the path executing command change, + and we directly use `root_path` as the actual parent directory of `_BASE_` config file + to avoid this confusion. + + Returns: + cfg: a :class:`CfgNode` instance. + """ + is_file = PathManager.isfile(filename_or_str_content) + if len(str(filename_or_str_content)) < 256 and str( + filename_or_str_content + ).endswith(".yaml"): + # We assume if input is a yaml file path, it will not longer than 256 + # and it should ends with '.yaml' + if is_file: + with cls._open_cfg(filename_or_str_content) as file: + # load_cfg use yaml.safe_load() to prevent malicious code (see https://zhuanlan.zhihu.com/p/54332357); + # fvcore supports yaml.unsafe_load(), but I don't see any code use it both in detectron2 and fvcore, + # so I think use original load_cfg() in yacs is enough. + cfg = cls.load_cfg(file) + else: + msg = ( + f"CfgNode: Input string: '{filename_or_str_content}' looks like" + " a yaml file path, but the file is not found on disk!" + ) + logger.error(msg) + raise FileNotFoundError(msg) + else: + # Otherwise the input is a yaml-format string + cfg = cls.load_cfg(filename_or_str_content) + + if root_path is not None and hasattr(cfg, "_BASE_"): + path = Path(root_path) / cfg._BASE_ + if not path.exists(): + raise ValueError("Path {} does not exist.".format(path)) + cfg._BASE_ = str(path) + + def _load_with_base(base_cfg_file: str) -> CfgNode: + if base_cfg_file.startswith("~"): + base_cfg_file = Path(base_cfg_file).expanduser() + if not any(map(base_cfg_file.startswith, ["/", "https://", "http://"])): + if is_file: + # the path to base cfg is relative to the config file itself. + base_cfg_file = Path(filename_or_str_content).parent / base_cfg_file + return cls.load_cfg_from_file(base_cfg_file, new_allowed=new_allowed) + + if BASE_KEY in cfg: + if isinstance(cfg[BASE_KEY], list): + base_cfg = cls(new_allowed=new_allowed) + base_cfg_files = cfg[BASE_KEY] + # NOTE: `new_allowed` of the new added key is default False, so after a "new_allowed" merge new keys from other config, + # the new key is not `new_allowed`, which is unreasonable, so we manually update `new_allowed` of merged new keys + for base_cfg_file in base_cfg_files: + base_cfg.merge_from_other_cfg(_load_with_base(base_cfg_file)) + base_cfg.set_new_allowed(new_allowed) + else: + base_cfg_file = cfg[BASE_KEY] + base_cfg = _load_with_base(base_cfg_file) + del cfg[BASE_KEY] + + base_cfg.merge_from_other_cfg(cfg) + return base_cfg + + cfg.set_new_allowed(new_allowed) + return cfg + + def merge_from_other_cfg(self, cfg_other): + """Merge `cfg_other` into this CfgNode.""" + _merge_a_into_b(cfg_other, self, self, []) + other_counter = cfg_other.__dict__[CfgNode.COUNTER] + self.__dict__[CfgNode.COUNTER] = { + **self.__dict__[CfgNode.COUNTER], + **other_counter, + } + + def dict(self): + # NOTE: Without deepcopy, if value is a list, cfg.dict() will use a shallow copy of this list, + # then change this list of cfg.dict() will lead to unexpected changeing of original cfg + result = {} + for key, value in deepcopy(self).items(): + if isinstance(value, CfgNode): + result[key] = value.dict() + else: + result[key] = value + + return result + + def diff(self, other: CfgNode): + """Show the difference between self and other `CfgNode`, helping user + find Help users quickly identify the difference between them. + + Args: + other (CfgNode): Another `CfgNode`. + + Returns: + DeepDiff: A class containing difference, include adding, deleting and modifing. + """ + from deepdiff import DeepDiff + + return DeepDiff(self, other) + + def dump(self, *args, **kwargs): + """ + At present dump() can only ensure original CfgNode == the one after dump and reload, + but can not ensure the order of their keys is consistent. + + Returns: + str: a yaml string representation of the config + """ + # to make it show up in docs + return super().dump(*args, **kwargs) + + def save(self, filepath): + with open(filepath, "w", encoding="utf-8") as fp: + # set sort_key=False to keep writing order the same as original + # input file rather than ordered by alphabetically; + # set default_flow_style=None to keep list element written in one line + # allow_unicode=True to support Chinese input + self.dump( + stream=fp, sort_keys=False, default_flow_style=None, allow_unicode=True + ) + + def depth(self): + return _dict_depth(self) + + def unvisited_keys(self, inverse: Optional[bool] = False) -> List[str]: + """Return all unvisited keys. + + Args: + inverse (Optional[bool], optional): return all visited keys if `inverse` is True. Defaults to False. + + Returns: + List[str]: list of all unvisited/visited keys. + """ + self.__update_counter(self) + condition = lambda x: x == 0 if not inverse else x > 0 + return [ + key + for key, value in self.__dict__[CfgNode.COUNTER].items() + if condition(value) + ] + + def __update_counter(self, root: CfgNode, prefix=""): + """Internal methods to recursively update counter for each keys. + + Args: + root (CfgNode): Parent node of current CfgNode. + prefix (str, optional): Concatenation of parent, grandparent and so on. + For root CfgNode `prefix` is "", for a SUBKEY `prefix` may be "TOPKEYA.KEYB". + """ + for key, kid_node in self.items(): + new_key = prefix + sep + key if prefix else key + if isinstance(kid_node, dict) and _dict_depth(kid_node) > 0: + kid_node.__update_counter(root, new_key) + else: + # a new_key of value "TOPKEYA.KEYB.SUBKEYA" lead to a1 slice_key + # of value "['KEYB.SUBKEYA', 'TOPKEYA.KEYB.SUBKEYA']", which contain all parent keys + sliced_keys = [ + ".".join(new_key.split(".")[-k:]) + for k in range(2, 1 + len(new_key.split("."))) + ] + # `self` is the father of `key`, and `root` is the father of `self` + for root_key in root.__dict__[CfgNode.COUNTER].keys(): + matched = any( + [sliced_key in root_key for sliced_key in sliced_keys] + ) + if matched: + root.__dict__[CfgNode.COUNTER][root_key] = self.__dict__[ + CfgNode.COUNTER + ][key] + + +def _check_and_coerce_cfg_value_type(replacement, original, key, full_key): + """Checks that `replacement`, which is intended to replace `original` is of + the right type. The type is correct if it matches exactly or is one of a few + cases in which the type can be easily coerced. + """ + original_type = type(original) + replacement_type = type(replacement) + + # The types must match (with some exceptions) + if replacement_type == original_type or issubclass(original_type, replacement_type): + return replacement + + # If either of them is None, allow type conversion to one of the valid types + if (replacement_type == type(None) and original_type in _VALID_TYPES) or ( + original_type == type(None) and replacement_type in _VALID_TYPES + ): + return replacement + + # Cast replacement from from_type to to_type if the replacement and original + # types match from_type and to_type + def conditional_cast(from_type, to_type): + if replacement_type == from_type and original_type == to_type: + return True, to_type(replacement) + else: + return False, None + + # Conditionally casts + # list <-> tuple + casts = [(tuple, list), (list, tuple)] + # For py2: allow converting from str (bytes) to a unicode string + try: + casts.append((str, unicode)) # noqa: F821 + except Exception: + pass + + for from_type, to_type in casts: + converted, converted_value = conditional_cast(from_type, to_type) + if converted: + return converted_value + + raise ValueError( + f"Key type mismatchs during merging config! Key: {full_key}, original: {original} of type {original_type}, new: {replacement} of type {replacement_type}." + ) + + +def _merge_a_into_b(a, b, root, key_list): + """Merge config dictionary a into config dictionary b, clobbering the + options in b whenever they are also specified in a. + """ + _assert_with_logging( + isinstance(a, CfgNode), + "`a` (cur type {}) must be an instance of {}".format(type(a), CfgNode), + ) + _assert_with_logging( + isinstance(b, CfgNode), + "`b` (cur type {}) must be an instance of {}".format(type(b), CfgNode), + ) + + for k, v_ in a.items(): + full_key = ".".join(key_list + [k]) + + v = deepcopy(v_) + v = b._decode_cfg_value(v) + + if k in b: + v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key) + # Recursively merge dicts + if isinstance(v, CfgNode): + try: + _merge_a_into_b(v, b[k], root, key_list + [k]) + except BaseException: + raise + else: + b[k] = v + elif b.is_new_allowed() or isinstance(b, MutableCfgNode): + b[k] = v + else: + if root.key_is_deprecated(full_key): + continue + elif root.key_is_renamed(full_key): + root.raise_key_rename_error(full_key) + else: + raise KeyError("Non-existent config key: {}".format(full_key)) + + +class MutableCfgNode(CfgNode): + def __init__(self, init_dict=None, key_list=None, new_allowed=False): + super().__init__(init_dict, key_list, new_allowed) + self.set_new_allowed(new_allowed) + + +def _get_args_from_config(from_config_func, *args, **kwargs): + """ + Use `from_config` to obtain explicit arguments. + Returns: + dict: arguments to be used for cls.__init__ + """ + # inspect.signature() obtains parameter list of function, such as (a, b=0, *c, d, e=1, **f) + signature = inspect.signature(from_config_func) + # cfg should be passed as the first parameter, whether it is a positional or keyword argument + if list(signature.parameters.keys())[0] != "cfg": + if inspect.isfunction(from_config_func): + name = from_config_func.__name__ + else: + name = f"{from_config_func.__self__}.from_config" + raise TypeError(f"{name} must take 'cfg' as the first argument!") + support_var_arg = any( + param.kind in [param.VAR_POSITIONAL, param.VAR_KEYWORD] + for param in signature.parameters.values() + ) + if ( + support_var_arg + ): # forward all arguments to from_config, if from_config accepts them + ret = from_config_func(*args, **kwargs) + else: + # forward supported arguments to from_config + supported_arg_names = set(signature.parameters.keys()) + extra_kwargs = {} + for name in list(kwargs.keys()): + if name not in supported_arg_names: + extra_kwargs[name] = kwargs.pop(name) + ret = from_config_func(*args, **kwargs) + # forward the other arguments to __init__ + ret.update(extra_kwargs) + return ret + + +def _called_with_cfg(*args, **kwargs): + """ + Returns: + bool: whether the arguments contain CfgNode and should be considered + forwarded to from_config. + """ + from omegaconf import DictConfig + + if len(args) and isinstance(args[0], (_CfgNode, DictConfig)): + return True + if isinstance(kwargs.pop("cfg", None), (_CfgNode, DictConfig)): + return True + # `from_config`'s first argument is forced to be "cfg". + # So the above check covers all cases. + return False + + +def configurable(init_func=None, *, from_config=None): + """ + Decorate a function or a class's method so that it can be called + with a :class:`CfgNode` object using a :func:`from_config` function that translates + :class:`CfgNode` to arguments. + Examples: + :: + # Usage 1: Decorator on __init__: + class A: + @configurable + def __init__(self, a, b=2, c=3): + pass + @classmethod + def from_config(cls, cfg): # 'cfg' must be the first argument + # Returns kwargs to be passed to __init__ + return {"a": cfg.A, "b": cfg.B} + a1 = A(a=1, b=2) # regular construction + a2 = A(cfg) # construct with a cfg + a3 = A(cfg, b=3, c=4) # construct with extra overwrite + + # Usage 2: Decorator on any function. Needs an extra from_config argument: + @configurable(from_config=lambda cfg: {"a": cfg.A, "b": cfg.B}) + def a_func(a, b=2, c=3): + pass + a1 = a_func(a=1, b=2) # regular call + a2 = a_func(cfg) # call with a cfg + a3 = a_func(cfg, b=3, c=4) # call with extra overwrite + + # Usage 3: Decorator on any method of class. Needs an extra from_config argument: + class A: + @configurable(from_config=lambda cfg: {"a": cfg.A, "b": cfg.B}) + def a_func(self, a, b=2, c=3): + pass + insA = A() + cfg = CfgNode.load_cfg('{"A": "2", "B": "3"}') + a1 = insA.a_func(a=1, b=2) # regular call + a2 = insA.a_func(cfg) # call with a cfg + a3 = insA.a_func(cfg, b=3, c=4) # call with extra overwrite + + Args: + init_func (callable): a class's ``__init__`` method in usage 1. The + class must have a ``from_config`` classmethod which takes `cfg` as + the first argument. + from_config (callable): the from_config function in usage 2 and 3. It must take `cfg` + as its first argument. + """ + if init_func is not None: + assert ( + inspect.isfunction(init_func) + and from_config is None + and init_func.__name__ == "__init__" + ), "Incorrect use of @configurable. Check API documentation for examples." + + @functools.wraps(init_func) + def wrapped(self, *args, **kwargs): + try: + from_config_func = type(self).from_config + except AttributeError as e: + raise AttributeError( + "Class with @configurable must have a 'from_config' classmethod." + ) from e + if not inspect.ismethod(from_config_func): + raise TypeError( + "Class with @configurable must have a 'from_config' classmethod." + ) + + if _called_with_cfg(*args, **kwargs): + explicit_args = _get_args_from_config(from_config_func, *args, **kwargs) + init_func(self, **explicit_args) + else: + init_func(self, *args, **kwargs) + + return wrapped + + else: + if from_config is None: + return configurable # @configurable() is made equivalent to @configurable + assert inspect.isfunction( + from_config + ), "from_config argument of configurable must be a function!" + + def wrapper(orig_func): + params = inspect.signature(orig_func).parameters + if "self" in params or "cls" in params: # classmethod or instancemethod + + @functools.wraps(orig_func) + def wrapped( + self, *args, **kwargs + ): # here `self` means actual `self` or `cls` + if _called_with_cfg(*args, **kwargs): + explicit_args = _get_args_from_config( + from_config, *args, **kwargs + ) + return orig_func(self, **explicit_args) + else: + return orig_func(self, *args, **kwargs) + + wrapped.from_config = from_config + return wrapped + + else: # function or staticmethod + + @functools.wraps(orig_func) + def wrapped(*args, **kwargs): + if _called_with_cfg(*args, **kwargs): + explicit_args = _get_args_from_config( + from_config, *args, **kwargs + ) + return orig_func(**explicit_args) + else: + return orig_func(*args, **kwargs) + + wrapped.from_config = from_config + return wrapped + + return wrapper diff --git a/embodichain/utils/configclass.py b/embodichain/utils/configclass.py new file mode 100644 index 00000000..61727d60 --- /dev/null +++ b/embodichain/utils/configclass.py @@ -0,0 +1,616 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# All rights reserved. +# +# This file incorporates code from the Isaac Lab Project +# Copyright (c) 2022-2025, The Isaac Lab Project Developers +# (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause +# ---------------------------------------------------------------------------- + + +import torch +import inspect +import types +from collections.abc import Callable, Mapping, Iterable, Sized +from copy import deepcopy +from dataclasses import MISSING, Field, dataclass, field, replace +from typing import Any, ClassVar, Optional +from .string import callable_to_string, string_to_callable + + +_CONFIGCLASS_METHODS = ["to_dict", "replace", "copy", "validate"] +"""List of class methods added at runtime to dataclass.""" + +""" +Wrapper around dataclass. +""" + + +def __dataclass_transform__(): + """Add annotations decorator for PyLance.""" + return lambda a: a + + +def is_configclass(cls: Any) -> bool: + """Check if a class is a configclass. + + Args: + cls: The class to check. + + Returns: + True if the class is a configclass, False otherwise. + """ + return hasattr(cls, "validate") + + +@__dataclass_transform__() +def configclass(cls, **kwargs): + """Wrapper around `dataclass` functionality to add extra checks and utilities. + + As of Python 3.7, the standard dataclasses have two main issues which makes them non-generic for + configuration use-cases. These include: + + 1. Requiring a type annotation for all its members. + 2. Requiring explicit usage of :meth:`field(default_factory=...)` to reinitialize mutable variables. + + This function provides a decorator that wraps around Python's `dataclass`_ utility to deal with + the above two issues. It also provides additional helper functions for dictionary <-> class + conversion and easily copying class instances. + + Usage: + + .. code-block:: python + + from dataclasses import MISSING + + from isaaclab.utils.configclass import configclass + + + @configclass + class ViewerCfg: + eye: list = [7.5, 7.5, 7.5] # field missing on purpose + lookat: list = field(default_factory=[0.0, 0.0, 0.0]) + + + @configclass + class EnvCfg: + num_envs: int = MISSING + episode_length: int = 2000 + viewer: ViewerCfg = ViewerCfg() + + # create configuration instance + env_cfg = EnvCfg(num_envs=24) + + # print information as a dictionary + print(env_cfg.to_dict()) + + # create a copy of the configuration + env_cfg_copy = env_cfg.copy() + + # replace arbitrary fields using keyword arguments + env_cfg_copy = env_cfg_copy.replace(num_envs=32) + + Args: + cls: The class to wrap around. + **kwargs: Additional arguments to pass to :func:`dataclass`. + + Returns: + The wrapped class. + + .. _dataclass: https://docs.python.org/3/library/dataclasses.html + + Reference: + https://github.com/isaac-sim/IsaacLab/blob/main/source/isaaclab/isaaclab/utils/configclass.py + """ + # add type annotations + _add_annotation_types(cls) + # add field factory + _process_mutable_types(cls) + # copy mutable members + # note: we check if user defined __post_init__ function exists and augment it with our own + if hasattr(cls, "__post_init__"): + setattr( + cls, "__post_init__", combined_function(cls.__post_init__, custom_post_init) + ) + else: + setattr(cls, "__post_init__", custom_post_init) + # add helper functions for dictionary conversion + setattr(cls, "to_dict", class_to_dict) + # setattr(cls, "from_dict", update_class_from_dict) + setattr(cls, "replace", _replace_class_with_kwargs) + setattr(cls, "copy", _replace_class_with_kwargs) + setattr(cls, "validate", _validate) + # wrap around dataclass + cls = dataclass(cls, **kwargs) + # return wrapped class + return cls + + +def combined_function(f1: Callable, f2: Callable) -> Callable: + """Combine two functions into one. + + Args: + f1: The first function. + f2: The second function. + + Returns: + The combined function. + """ + + def _combined(*args, **kwargs): + # call both functions + f1(*args, **kwargs) + f2(*args, **kwargs) + + return _combined + + +def custom_post_init(obj): + """Deepcopy all elements to avoid shared memory issues for mutable objects in dataclasses initialization. + + This function is called explicitly instead of as a part of :func:`_process_mutable_types()` to prevent mapping + proxy type i.e. a read only proxy for mapping objects. The error is thrown when using hierarchical data-classes + for configuration. + """ + for key in dir(obj): + # skip dunder members + if key.startswith("__"): + continue + # get data member + value = getattr(obj, key) + # check annotation + ann = obj.__class__.__dict__.get(key) + # duplicate data members that are mutable + if not callable(value) and not isinstance(ann, property): + try: + setattr(obj, key, deepcopy(value)) + except AttributeError as e: + from IPython import embed + + embed() + + +def class_to_dict(obj: object) -> dict[str, Any]: + """Convert an object into dictionary recursively. + + Note: + Ignores all names starting with "__" (i.e. built-in methods). + + Args: + obj: An instance of a class to convert. + + Raises: + ValueError: When input argument is not an object. + + Returns: + Converted dictionary mapping. + """ + # check that input data is class instance + if not hasattr(obj, "__class__"): + raise ValueError(f"Expected a class instance. Received: {type(obj)}.") + # convert object to dictionary + if isinstance(obj, dict): + obj_dict = obj + elif isinstance(obj, torch.Tensor): + # We have to treat torch tensors specially because `torch.tensor.__dict__` returns an empty + # dict, which would mean that a torch.tensor would be stored as an empty dict. Instead we + # want to store it directly as the tensor. + return obj + elif hasattr(obj, "__dict__"): + obj_dict = obj.__dict__ + else: + return obj + + # convert to dictionary + data = dict() + for key, value in obj_dict.items(): + # disregard builtin attributes + if key.startswith("__"): + continue + # check if attribute is callable -- function + if callable(value): + data[key] = callable_to_string(value) + # check if attribute is a dictionary + elif hasattr(value, "__dict__") or isinstance(value, dict): + data[key] = class_to_dict(value) + # check if attribute is a list or tuple + elif isinstance(value, (list, tuple)): + data[key] = type(value)([class_to_dict(v) for v in value]) + else: + data[key] = value + return data + + +def update_class_from_dict(obj, data: dict[str, Any], _ns: str = "") -> None: + """Reads a dictionary and sets object variables recursively. + + This function performs in-place update of the class member attributes. + + Args: + obj: An instance of a class to update. + data: Input dictionary to update from. + _ns: Namespace of the current object. This is useful for nested configuration + classes or dictionaries. Defaults to "". + + Raises: + TypeError: When input is not a dictionary. + ValueError: When dictionary has a value that does not match default config type. + KeyError: When dictionary has a key that does not exist in the default config type. + """ + for key, value in data.items(): + # key_ns is the full namespace of the key + key_ns = _ns + "/" + key + + # -- A) if key is present in the object ------------------------------------ + if hasattr(obj, key) or (isinstance(obj, dict) and key in obj): + obj_mem = obj[key] if isinstance(obj, dict) else getattr(obj, key) + + # -- 1) nested mapping → recurse --------------------------- + if isinstance(value, Mapping): + # recursively call if it is a dictionary + update_class_from_dict(obj_mem, value, _ns=key_ns) + continue + + # -- 2) iterable (list / tuple / etc.) --------------------- + if isinstance(value, Iterable) and not isinstance(value, str): + + # ---- 2a) flat iterable → replace wholesale ---------- + if all(not isinstance(el, Mapping) for el in value): + out_val = tuple(value) if isinstance(obj_mem, tuple) else value + if isinstance(obj, dict): + obj[key] = out_val + else: + setattr(obj, key, out_val) + continue + + # ---- 2b) existing value is None → abort ------------- + if obj_mem is None: + raise ValueError( + f"[Config]: Cannot merge list under namespace: {key_ns} because the existing value is None." + ) + + # ---- 2c) length mismatch → abort ------------------- + if ( + isinstance(obj_mem, Sized) + and isinstance(value, Sized) + and len(obj_mem) != len(value) + ): + raise ValueError( + f"[Config]: Incorrect length under namespace: {key_ns}." + f" Expected: {len(obj_mem)}, Received: {len(value)}." + ) + + # ---- 2d) keep tuple/list parity & recurse ---------- + if isinstance(obj_mem, tuple): + value = tuple(value) + else: + set_obj = True + # recursively call if iterable contains Mappings + for i in range(len(obj_mem)): + if isinstance(value[i], Mapping): + update_class_from_dict(obj_mem[i], value[i], _ns=key_ns) + set_obj = False + # do not set value to obj, otherwise it overwrites the cfg class with the dict + if not set_obj: + continue + + # -- 3) callable attribute → resolve string -------------- + elif callable(obj_mem): + # update function name + value = string_to_callable(value) + + # -- 4) simple scalar / explicit None --------------------- + elif value is None or isinstance(value, type(obj_mem)): + pass + + # -- 5) type mismatch → abort ----------------------------- + else: + raise ValueError( + f"[Config]: Incorrect type under namespace: {key_ns}." + f" Expected: {type(obj_mem)}, Received: {type(value)}." + ) + + # -- 6) final assignment --------------------------------- + if isinstance(obj, dict): + obj[key] = value + else: + setattr(obj, key, value) + + # -- B) if key is not present ------------------------------------ + else: + raise KeyError(f"[Config]: Key not found under namespace: {key_ns}.") + + +def _replace_class_with_kwargs(obj: object, **kwargs) -> object: + """Return a new object replacing specified fields with new values. + + This is especially useful for frozen classes. Example usage: + + .. code-block:: python + + @configclass(frozen=True) + class C: + x: int + y: int + + c = C(1, 2) + c1 = c.replace(x=3) + assert c1.x == 3 and c1.y == 2 + + Args: + obj: The object to replace. + **kwargs: The fields to replace and their new values. + + Returns: + The new object. + """ + return replace(obj, **kwargs) + + +def _validate(obj: object, prefix: str = "") -> list[str]: + """Check the validity of configclass object. + + This function checks if the object is a valid configclass object. A valid configclass object contains no MISSING + entries. + + Args: + obj: The object to check. + prefix: The prefix to add to the missing fields. Defaults to ''. + + Returns: + A list of missing fields. + + Raises: + TypeError: When the object is not a valid configuration object. + """ + missing_fields = [] + + if type(obj) is type(MISSING): + missing_fields.append(prefix) + return missing_fields + elif isinstance(obj, (list, tuple)): + for index, item in enumerate(obj): + current_path = f"{prefix}[{index}]" + missing_fields.extend(_validate(item, prefix=current_path)) + return missing_fields + elif isinstance(obj, dict): + obj_dict = obj + elif hasattr(obj, "__dict__"): + obj_dict = obj.__dict__ + else: + return missing_fields + + for key, value in obj_dict.items(): + # disregard builtin attributes + if key.startswith("__"): + continue + current_path = f"{prefix}.{key}" if prefix else key + missing_fields.extend(_validate(value, prefix=current_path)) + + # raise an error only once at the top-level call + if prefix == "" and missing_fields: + formatted_message = "\n".join(f" - {field}" for field in missing_fields) + raise TypeError( + f"Missing values detected in object {obj.__class__.__name__} for the following" + f" fields:\n{formatted_message}\n" + ) + return missing_fields + + +def _add_annotation_types(cls): + """Add annotations to all elements in the dataclass. + + By definition in Python, a field is defined as a class variable that has a type annotation. + + In case type annotations are not provided, dataclass ignores those members when :func:`__dict__()` is called. + This function adds these annotations to the class variable to prevent any issues in case the user forgets to + specify the type annotation. + + This makes the following a feasible operation: + + @dataclass + class State: + pos = (0.0, 0.0, 0.0) + ^^ + If the function is NOT used, the following type-error is returned: + TypeError: 'pos' is a field but has no type annotation + """ + # get type hints + hints = {} + # iterate over class inheritance + # we add annotations from base classes first + for base in reversed(cls.__mro__): + # check if base is object + if base is object: + continue + # get base class annotations + ann = base.__dict__.get("__annotations__", {}) + # directly add all annotations from base class + hints.update(ann) + # iterate over base class members + # Note: Do not change this to dir(base) since it orders the members alphabetically. + # This is not desirable since the order of the members is important in some cases. + for key in base.__dict__: + # get class member + value = getattr(base, key) + # skip members + if _skippable_class_member(key, value, hints): + continue + # add type annotations for members that don't have explicit type annotations + # for these, we deduce the type from the default value + if not isinstance(value, type): + if key not in hints: + # check if var type is not MISSING + # we cannot deduce type from MISSING! + if value is MISSING: + raise TypeError( + f"Missing type annotation for '{key}' in class '{cls.__name__}'." + " Please add a type annotation or set a default value." + ) + # add type annotation + hints[key] = type(value) + elif key != value.__name__: + # note: we don't want to add type annotations for nested configclass. Thus, we check if + # the name of the type matches the name of the variable. + # since Python 3.10, type hints are stored as strings + hints[key] = f"type[{value.__name__}]" + + # Note: Do not change this line. `cls.__dict__.get("__annotations__", {})` is different from + # `cls.__annotations__` because of inheritance. + cls.__annotations__ = cls.__dict__.get("__annotations__", {}) + cls.__annotations__ = hints + + +def _process_mutable_types(cls): + """Initialize all mutable elements through :obj:`dataclasses.Field` to avoid unnecessary complaints. + + By default, dataclass requires usage of :obj:`field(default_factory=...)` to reinitialize mutable objects every time a new + class instance is created. If a member has a mutable type and it is created without specifying the `field(default_factory=...)`, + then Python throws an error requiring the usage of `default_factory`. + + Additionally, Python only explicitly checks for field specification when the type is a list, set or dict. This misses the + use-case where the type is class itself. Thus, the code silently carries a bug with it which can lead to undesirable effects. + + This function deals with this issue + + This makes the following a feasible operation: + + @dataclass + class State: + pos: list = [0.0, 0.0, 0.0] + ^^ + If the function is NOT used, the following value-error is returned: + ValueError: mutable default for field pos is not allowed: use default_factory + """ + # note: Need to set this up in the same order as annotations. Otherwise, it + # complains about missing positional arguments. + ann = cls.__dict__.get("__annotations__", {}) + + # iterate over all class members and store them in a dictionary + class_members = {} + for base in reversed(cls.__mro__): + # check if base is object + if base is object: + continue + # iterate over base class members + for key in base.__dict__: + # get class member + f = getattr(base, key) + # skip members + if _skippable_class_member(key, f): + continue + # store class member if it is not a type or if it is already present in annotations + if not isinstance(f, type) or key in ann: + class_members[key] = f + # iterate over base class data fields + # in previous call, things that became a dataclass field were removed from class members + # so we need to add them back here as a dataclass field directly + for key, f in base.__dict__.get("__dataclass_fields__", {}).items(): + # store class member + if not isinstance(f, type): + class_members[key] = f + + # check that all annotations are present in class members + # note: mainly for debugging purposes + if len(class_members) != len(ann): + raise ValueError( + f"In class '{cls.__name__}', number of annotations ({len(ann)}) does not match number of class members" + f" ({len(class_members)}). Please check that all class members have type annotations and/or a default" + " value. If you don't want to specify a default value, please use the literal `dataclasses.MISSING`." + ) + # iterate over annotations and add field factory for mutable types + for key in ann: + # find matching field in class + value = class_members.get(key, MISSING) + # check if key belongs to ClassVar + # in that case, we cannot use default_factory! + origin = getattr(ann[key], "__origin__", None) + if origin is ClassVar: + continue + # check if f is MISSING + # note: commented out for now since it causes issue with inheritance + # of dataclasses when parent have some positional and some keyword arguments. + # Ref: https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses + # TODO: check if this is fixed in Python 3.10 + # if f is MISSING: + # continue + if isinstance(value, Field): + setattr(cls, key, value) + elif not isinstance(value, type): + # create field factory for mutable types + value = field(default_factory=_return_f(value)) + setattr(cls, key, value) + + +def _skippable_class_member(key: str, value: Any, hints: Optional[dict] = None) -> bool: + """Check if the class member should be skipped in configclass processing. + + The following members are skipped: + + * Dunder members: ``__name__``, ``__module__``, ``__qualname__``, ``__annotations__``, ``__dict__``. + * Manually-added special class functions: From :obj:`_CONFIGCLASS_METHODS`. + * Members that are already present in the type annotations. + * Functions bounded to class object or class. + * Properties bounded to class object. + + Args: + key: The class member name. + value: The class member value. + hints: The type hints for the class. Defaults to None, in which case, the + members existence in type hints are not checked. + + Returns: + True if the class member should be skipped, False otherwise. + """ + # skip dunder members + if key.startswith("__"): + return True + # skip manually-added special class functions + if key in _CONFIGCLASS_METHODS: + return True + # check if key is already present + if hints is not None and key in hints: + return True + # skip functions bounded to class + if callable(value): + # FIXME: This doesn't yet work for static methods because they are essentially seen as function types. + # check for class methods + if isinstance(value, types.MethodType): + return True + # check for instance methods + signature = inspect.signature(value) + if "self" in signature.parameters or "cls" in signature.parameters: + return True + # skip property methods + if isinstance(value, property): + return True + # Otherwise, don't skip + return False + + +def _return_f(f: Any) -> Callable[[], Any]: + """Returns default factory function for creating mutable/immutable variables. + + This function should be used to create default factory functions for variables. + + Example: + + .. code-block:: python + + value = field(default_factory=_return_f(value)) + setattr(cls, key, value) + """ + + def _wrap(): + if isinstance(f, Field): + if f.default_factory is MISSING: + return deepcopy(f.default) + else: + return f.default_factory + else: + return deepcopy(f) + + return _wrap diff --git a/embodichain/utils/device_utils.py b/embodichain/utils/device_utils.py new file mode 100644 index 00000000..198ca84e --- /dev/null +++ b/embodichain/utils/device_utils.py @@ -0,0 +1,39 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch + +from typing import Union + + +def standardize_device_string(device: Union[str, torch.device]) -> str: + """Standardize the device string for Warp compatibility. + + Args: + device (Union[str, torch.device]): The device specification. + + Returns: + str: The standardized device string. + """ + if isinstance(device, str): + device_str = device + else: + device_str = str(device) + + if device_str.startswith("cuda"): + device_str = "cuda:0" + + return device_str diff --git a/embodichain/utils/file.py b/embodichain/utils/file.py new file mode 100644 index 00000000..cd089491 --- /dev/null +++ b/embodichain/utils/file.py @@ -0,0 +1,54 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os +import re + +from typing import Optional, List + + +def get_all_files_in_directory( + directory: str, + exts: Optional[List[str]] = None, + patterns: Optional[List[str]] = None, +) -> List[str]: + """Get all files in a directory with optional filtering by extensions or regex patterns. + + Args: + directory (str): The directory to search for files. + exts (Optional[List[str]]): List of file extensions to filter by. If None, all files are returned. + patterns (Optional[List[str]]): List of regex patterns to match file names. If None, no pattern matching is applied. + + Returns: + List[str]: List of file paths in the directory matching the specified extensions or patterns. + """ + all_files = [] + compiled_patterns = ( + [re.compile(pattern) for pattern in patterns] if patterns else [] + ) + + for root, _, files in os.walk(directory): + for file in files: + match_ext = exts is None or any( + file.lower().endswith(ext.lower()) for ext in exts + ) + match_pattern = not compiled_patterns or any( + pattern.search(file) for pattern in compiled_patterns + ) + + if match_ext and match_pattern: + all_files.append(os.path.join(root, file)) + return all_files diff --git a/embodichain/utils/img_utils.py b/embodichain/utils/img_utils.py new file mode 100644 index 00000000..14593c49 --- /dev/null +++ b/embodichain/utils/img_utils.py @@ -0,0 +1,154 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch + + +def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: + """Convert binary masks to bounding boxes. + + Args: + masks (torch.Tensor): A tensor of shape (..., H, W) containing binary masks + where non-zero values indicate the presence of the object. + + Returns: + torch.Tensor: A tensor of shape (..., 4) containing the bounding boxes + in XYXY format. + """ + # torch.max below raises an error on empty inputs, just skip in this case + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to CxHxW + shape = masks.shape + h, w = shape[-2:] + if len(shape) > 2: + masks = masks.flatten(0, -3) + else: + masks = masks.unsqueeze(0) + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + h * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + w * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + if len(shape) > 2: + out = out.reshape(*shape[:-2], 4) + else: + out = out[0] + + return out + + +def gen_disp_colormap(inputs, normalize=True, torch_transpose=True): + """ + Generate an RGB visualization using the "plasma" colormap for 2D/3D/4D scalar image inputs. + + This utility maps scalar image(s) to an RGB colormap suitable for display or further processing. + It accepts either a NumPy array or a torch.Tensor (torch tensors are detached, moved to CPU and + converted to NumPy). The matplotlib "plasma" colormap with 256 entries is used. + + Parameters + - inputs (numpy.ndarray or torch.Tensor): + Scalar image data with one of the following dimensionalities: + * 2D: (H, W) -> a single image + * 3D: (N, H, W) -> a batch of N single-channel images + * 4D: (N, C, H, W) -> a batch with channel dimension; expected C==1 (first channel used) + The function will convert torch.Tensor input to numpy internally. + - normalize (bool, default True): + If True, input values are linearly scaled to [0, 1] using (x - min) / (max - min). + If the input is constant (min == max), a small divisor (1e5) is used to avoid division + by zero, which effectively maps values near 0. If False, values are assumed to already be + in the [0, 1] range (no scaling is performed). + - torch_transpose (bool, default True): + Controls the output channel ordering to match common PyTorch conventions: + * If True: outputs are transposed to channel-first form: + - 2D input -> (3, H, W) + - 3D input -> (N, 3, H, W) + - 4D input -> (N, 3, H, W) (uses the first channel) + * If False: outputs keep channel-last ordering: + - 2D input -> (H, W, 3) + - 3D input -> (N, H, W, 3) + - 4D input -> (N, H, W, 3) + + Returns + - numpy.ndarray: + RGB image(s) with float values in [0, 1]. The exact output shape depends on the input + dimensionality and the value of torch_transpose (see above). The alpha channel produced by + the colormap is discarded; only the RGB channels are returned. + + Notes and behavior + - The function uses matplotlib.pyplot.get_cmap("plasma", 256). + - For 4D inputs the code selects the first channel (index 0) before applying the colormap. + - Inputs with dimensionality other than 2, 3, or 4 are not supported and will likely raise + an error or produce unintended results. + - This function is non-destructive: it returns a new NumPy array and does not modify the input. + - Typical use cases: visualizing depth maps, single-channel activation maps, or other scalar + images as colored RGB images for inspection or logging. + + Examples + - 2D array (H, W) -> returns (3, H, W) if torch_transpose=True + - 3D array (N, H, W) -> returns (N, 3, H, W) if torch_transpose=True + - 4D array (N, 1, H, W) -> returns (N, 3, H, W) if torch_transpose=True + """ + import matplotlib.pyplot as plt + import torch + + _DEPTH_COLORMAP = plt.get_cmap("plasma", 256) # for plotting + if isinstance(inputs, torch.Tensor): + inputs = inputs.detach().cpu().numpy() + + vis = inputs + if normalize: + ma = float(vis.max()) + mi = float(vis.min()) + d = ma - mi if ma != mi else 1e5 + vis = (vis - mi) / d + + if vis.ndim == 4: + vis = vis.transpose([0, 2, 3, 1]) + vis = _DEPTH_COLORMAP(vis) + vis = vis[:, :, :, 0, :3] + if torch_transpose: + vis = vis.transpose(0, 3, 1, 2) + elif vis.ndim == 3: + vis = _DEPTH_COLORMAP(vis) + vis = vis[:, :, :3] + if torch_transpose: + vis = vis.transpose(0, 3, 1, 2) + elif vis.ndim == 2: + vis = _DEPTH_COLORMAP(vis) + vis = vis[..., :3] + if torch_transpose: + vis = vis.transpose(2, 0, 1) + + return vis diff --git a/embodichain/utils/logger.py b/embodichain/utils/logger.py new file mode 100644 index 00000000..b2dd22eb --- /dev/null +++ b/embodichain/utils/logger.py @@ -0,0 +1,74 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import logging + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s %(message)s", datefmt="%H:%M:%S" +) + +# Create a custom logger +logger = logging.getLogger(__name__) + +# Set the default log level +logger.setLevel(logging.INFO) + + +def decorate_str_color(msg: str, color: str): + """Decorate a string with a specific color.""" + color_map = { + "red": "\033[91m", + "green": "\033[92m", + "yellow": "\033[93m", + "blue": "\033[94m", + "purple": "\033[95m", + "cyan": "\033[96m", + "orange": "\033[33m", + "white": "\033[97m", + } + return f"{color_map.get(color, '')}{msg}\033[0m" if color else msg + + +def set_log_level(level: str): + """Set the logging level.""" + level = level.upper() + assert level in ["DEBUG", "INFO", "WARNING", "ERROR"], "Invalid log level" + logger.setLevel(getattr(logging, level)) + + +def format_message(level: str, message: str): + """Format the log message with a consistent prefix.""" + return f"[EmbodiChain {level}]: {message}" + + +def log_info(message, color=None): + """Log an info message.""" + logger.info(decorate_str_color(format_message("INFO", message), color)) + + +def log_debug(message, color="blue"): + """Log a debug message.""" + logger.debug(decorate_str_color(format_message("DEBUG", message), color)) + + +def log_warning(message): + """Log a warning message.""" + logger.warning(decorate_str_color(format_message("WARNING", message), "purple")) + + +def log_error(message, error_type=RuntimeError): + """Log an error message.""" + raise error_type(decorate_str_color(format_message("ERROR", message), "red")) diff --git a/embodichain/utils/math.py b/embodichain/utils/math.py new file mode 100644 index 00000000..c42f4374 --- /dev/null +++ b/embodichain/utils/math.py @@ -0,0 +1,2269 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +# needed to import for allowing type-hinting: Union[torch.Tensor, np.ndarray] +from __future__ import annotations + +import math +import torch +import numpy as np +import torch.nn.functional +from typing import Literal, Optional, Union + + +def look_at_to_pose( + eye: Union[torch.Tensor, list], + target: Union[torch.Tensor, list], + up: Union[torch.Tensor, list] = [0, 0, 1], +) -> torch.Tensor: + """Get the camera pose from eye to target with up direction, supporting batch processing. + + Args: + eye (Union[torch.Tensor, list]): Camera positions with shape (N, 3). + target (Union[torch.Tensor, list]): Target positions with shape (N, 3). + up (Union[torch.Tensor, list], optional): Up directions with shape (N, 3) or (3,). Defaults to [0, 0, 1]. + + Returns: + torch.Tensor: Camera pose matrices with shape (N, 4, 4). + """ + eye = ( + torch.tensor(eye, dtype=torch.float32) + if not isinstance(eye, torch.Tensor) + else eye + ) + target = ( + torch.tensor(target, dtype=torch.float32) + if not isinstance(target, torch.Tensor) + else target + ) + up = ( + torch.tensor(up, dtype=torch.float32) + if not isinstance(up, torch.Tensor) + else up + ) + + if eye.ndim == 1: + eye = eye.unsqueeze(0) + + if target.ndim == 1: + target = target.unsqueeze(0) + + if up.ndim == 1: + up = up.unsqueeze(0).repeat( + eye.shape[0], 1 + ) # Broadcast up vector to batch size + + assert ( + eye.shape[-1] == 3 and target.shape[-1] == 3 and up.shape[-1] == 3 + ), "Inputs must have shape (N, 3)." + + # Compute camera axes + camera_z = target - eye + camera_z = camera_z / torch.norm( + camera_z, dim=1, keepdim=True + ) # Normalize camera_z + camera_x = torch.cross(camera_z, up, dim=1) + camera_x_norm = torch.norm(camera_x, dim=1, keepdim=True) + if torch.any(camera_x_norm < 1e-6): # Handle degenerate cases + up = ( + torch.tensor([0, 1, 0], dtype=torch.float32) + .unsqueeze(0) + .repeat(eye.shape[0], 1) + ) + camera_x = torch.cross(up, camera_z, dim=1) + camera_x = camera_x / torch.norm( + camera_x, dim=1, keepdim=True + ) # Normalize camera_x + camera_y = torch.cross(camera_z, camera_x, dim=1) # Compute camera_y + + # Construct camera pose matrices + camera_pose = ( + torch.eye(4, dtype=torch.float32).unsqueeze(0).repeat(eye.shape[0], 1, 1) + ) # (N, 4, 4) + camera_pose[:, :3, 0] = camera_x + camera_pose[:, :3, 1] = camera_y + camera_pose[:, :3, 2] = camera_z + camera_pose[:, :3, 3] = eye + + return camera_pose + + +@torch.jit.script +def scale_transform( + x: torch.Tensor, lower: torch.Tensor, upper: torch.Tensor +) -> torch.Tensor: + """Normalizes a given input tensor to a range of [-1, 1]. + + .. note:: + It uses pytorch broadcasting functionality to deal with batched input. + + Args: + x: Input tensor of shape (N, dims). + lower: The minimum value of the tensor. Shape is (N, dims) or (dims,). + upper: The maximum value of the tensor. Shape is (N, dims) or (dims,). + + Returns: + Normalized transform of the tensor. Shape is (N, dims). + """ + # default value of center + offset = (lower + upper) * 0.5 + # return normalized tensor + return 2 * (x - offset) / (upper - lower) + + +@torch.jit.script +def unscale_transform( + x: torch.Tensor, lower: torch.Tensor, upper: torch.Tensor +) -> torch.Tensor: + """De-normalizes a given input tensor from range of [-1, 1] to (lower, upper). + + .. note:: + It uses pytorch broadcasting functionality to deal with batched input. + + Args: + x: Input tensor of shape (N, dims). + lower: The minimum value of the tensor. Shape is (N, dims) or (dims,). + upper: The maximum value of the tensor. Shape is (N, dims) or (dims,). + + Returns: + De-normalized transform of the tensor. Shape is (N, dims). + """ + # default value of center + offset = (lower + upper) * 0.5 + # return normalized tensor + return x * (upper - lower) * 0.5 + offset + + +@torch.jit.script +def saturate(x: torch.Tensor, lower: torch.Tensor, upper: torch.Tensor) -> torch.Tensor: + """Clamps a given input tensor to (lower, upper). + + It uses pytorch broadcasting functionality to deal with batched input. + + Args: + x: Input tensor of shape (N, dims). + lower: The minimum value of the tensor. Shape is (N, dims) or (dims,). + upper: The maximum value of the tensor. Shape is (N, dims) or (dims,). + + Returns: + Clamped transform of the tensor. Shape is (N, dims). + """ + return torch.max(torch.min(x, upper), lower) + + +@torch.jit.script +def normalize(x: torch.Tensor, eps: float = 1e-9) -> torch.Tensor: + """Normalizes a given input tensor to unit length. + + Args: + x: Input tensor of shape (N, dims). + eps: A small value to avoid division by zero. Defaults to 1e-9. + + Returns: + Normalized tensor of shape (N, dims). + """ + return x / x.norm(p=2, dim=-1).clamp(min=eps, max=None).unsqueeze(-1) + + +@torch.jit.script +def wrap_to_pi(angles: torch.Tensor) -> torch.Tensor: + r"""Wraps input angles (in radians) to the range :math:`[-\pi, \pi]`. + + This function wraps angles in radians to the range :math:`[-\pi, \pi]`, such that + :math:`\pi` maps to :math:`\pi`, and :math:`-\pi` maps to :math:`-\pi`. In general, + odd positive multiples of :math:`\pi` are mapped to :math:`\pi`, and odd negative + multiples of :math:`\pi` are mapped to :math:`-\pi`. + + The function behaves similar to MATLAB's `wrapToPi `_ + function. + + Args: + angles: Input angles of any shape. + + Returns: + Angles in the range :math:`[-\pi, \pi]`. + """ + # wrap to [0, 2*pi) + wrapped_angle = (angles + torch.pi) % (2 * torch.pi) + # map to [-pi, pi] + # we check for zero in wrapped angle to make it go to pi when input angle is odd multiple of pi + return torch.where( + (wrapped_angle == 0) & (angles > 0), torch.pi, wrapped_angle - torch.pi + ) + + +@torch.jit.script +def copysign(mag: float, other: torch.Tensor) -> torch.Tensor: + """Create a new floating-point tensor with the magnitude of input and the sign of other, element-wise. + + Note: + The implementation follows from `torch.copysign`. The function allows a scalar magnitude. + + Args: + mag: The magnitude scalar. + other: The tensor containing values whose signbits are applied to magnitude. + + Returns: + The output tensor. + """ + mag_torch = abs(mag) * torch.ones_like(other) + return torch.copysign(mag_torch, other) + + +""" +Rotation +""" + + +@torch.jit.script +def quat_unique(q: torch.Tensor) -> torch.Tensor: + """Convert a unit quaternion to a standard form where the real part is non-negative. + + Quaternion representations have a singularity since ``q`` and ``-q`` represent the same + rotation. This function ensures the real part of the quaternion is non-negative. + + Args: + q: The quaternion orientation in (w, x, y, z). Shape is (..., 4). + + Returns: + Standardized quaternions. Shape is (..., 4). + """ + return torch.where(q[..., 0:1] < 0, -q, q) + + +@torch.jit.script +def matrix_from_quat(quaternions: torch.Tensor) -> torch.Tensor: + """Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: The quaternion orientation in (w, x, y, z). Shape is (..., 4). + + Returns: + Rotation matrices. The shape is (..., 3, 3). + + Reference: + https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L41-L70 + """ + r, i, j, k = torch.unbind(quaternions, -1) + # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def convert_quat( + quat: Union[torch.Tensor, np.ndarray], to: Literal["xyzw", "wxyz"] = "xyzw" +) -> Union[torch.Tensor, np.ndarray]: + """Converts quaternion from one convention to another. + + The convention to convert TO is specified as an optional argument. If to == 'xyzw', + then the input is in 'wxyz' format, and vice-versa. + + Args: + quat: The quaternion of shape (..., 4). + to: Convention to convert quaternion to.. Defaults to "xyzw". + + Returns: + The converted quaternion in specified convention. + + Raises: + ValueError: Invalid input argument `to`, i.e. not "xyzw" or "wxyz". + ValueError: Invalid shape of input `quat`, i.e. not (..., 4,). + """ + # check input is correct + if quat.shape[-1] != 4: + msg = f"Expected input quaternion shape mismatch: {quat.shape} != (..., 4)." + raise ValueError(msg) + if to not in ["xyzw", "wxyz"]: + msg = f"Expected input argument `to` to be 'xyzw' or 'wxyz'. Received: {to}." + raise ValueError(msg) + # check if input is numpy array (we support this backend since some classes use numpy) + if isinstance(quat, np.ndarray): + # use numpy functions + if to == "xyzw": + # wxyz -> xyzw + return np.roll(quat, -1, axis=-1) + else: + # xyzw -> wxyz + return np.roll(quat, 1, axis=-1) + else: + # convert to torch (sanity check) + if not isinstance(quat, torch.Tensor): + quat = torch.tensor(quat, dtype=float) + # convert to specified quaternion type + if to == "xyzw": + # wxyz -> xyzw + return quat.roll(-1, dims=-1) + else: + # xyzw -> wxyz + return quat.roll(1, dims=-1) + + +@torch.jit.script +def quat_conjugate(q: torch.Tensor) -> torch.Tensor: + """Computes the conjugate of a quaternion. + + Args: + q: The quaternion orientation in (w, x, y, z). Shape is (..., 4). + + Returns: + The conjugate quaternion in (w, x, y, z). Shape is (..., 4). + """ + shape = q.shape + q = q.reshape(-1, 4) + return torch.cat((q[..., 0:1], -q[..., 1:]), dim=-1).view(shape) + + +@torch.jit.script +def quat_inv(q: torch.Tensor, eps: float = 1e-9) -> torch.Tensor: + """Computes the inverse of a quaternion. + + Args: + q: The quaternion orientation in (w, x, y, z). Shape is (N, 4). + eps: A small value to avoid division by zero. Defaults to 1e-9. + + Returns: + The inverse quaternion in (w, x, y, z). Shape is (N, 4). + """ + return quat_conjugate(q) / q.pow(2).sum(dim=-1, keepdim=True).clamp(min=eps) + + +@torch.jit.script +def quat_from_euler_xyz( + roll: torch.Tensor, pitch: torch.Tensor, yaw: torch.Tensor +) -> torch.Tensor: + """Convert rotations given as Euler angles in radians to Quaternions. + + Note: + The euler angles are assumed in XYZ convention. + + Args: + roll: Rotation around x-axis (in radians). Shape is (N,). + pitch: Rotation around y-axis (in radians). Shape is (N,). + yaw: Rotation around z-axis (in radians). Shape is (N,). + + Returns: + The quaternion in (w, x, y, z). Shape is (N, 4). + """ + cy = torch.cos(yaw * 0.5) + sy = torch.sin(yaw * 0.5) + cr = torch.cos(roll * 0.5) + sr = torch.sin(roll * 0.5) + cp = torch.cos(pitch * 0.5) + sp = torch.sin(pitch * 0.5) + # compute quaternion + qw = cy * cr * cp + sy * sr * sp + qx = cy * sr * cp - sy * cr * sp + qy = cy * cr * sp + sy * sr * cp + qz = sy * cr * cp - cy * sr * sp + + return torch.stack([qw, qx, qy, qz], dim=-1) + + +@torch.jit.script +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """Returns torch.sqrt(torch.max(0, x)) but with a zero sub-gradient where x is 0. + + Reference: + https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L91-L99 + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + + +@torch.jit.script +def quat_from_matrix(matrix: torch.Tensor) -> torch.Tensor: + """Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: The rotation matrices. Shape is (..., 3, 3). + + Returns: + The quaternion in (w, x, y, z). Shape is (..., 4). + + Reference: + https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L102-L161 + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( + matrix.reshape(batch_dim + (9,)), dim=-1 + ) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_rijk = torch.stack( + [ + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`. + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`. + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`. + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`. + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + # We floor here at 0.1 but the exact level is not important; if q_abs is small, + # the candidate won't be picked. + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + # if not for numerical problems, quat_candidates[i] should be same (up to a sign), + # forall i; we pick the best-conditioned one (with the largest denominator) + return quat_candidates[ + torch.nn.functional.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : + ].reshape(batch_dim + (4,)) + + +def xyz_quat_to_4x4_matrix(xyz_quat: torch.Tensor) -> torch.Tensor: + """Convert a 7D pose vector (x, y, z, qw, qx, qy, qz) to a 4x4 transformation matrix. + + Args: + xyz_quat: The pose vector in (x, y, z, qw, qx, qy, qz). Shape is (..., 7). + + Returns: + The transformation matrix. Shape is (..., 4, 4). + """ + if xyz_quat.shape[-1] != 7: + raise ValueError(f"Invalid input shape {xyz_quat.shape}, expected (..., 7).") + + # get rotation matrix from quaternion + rot_mat = matrix_from_quat(xyz_quat[..., 3:7]) # (..., 3, 3) + + # create transformation + trans = ( + torch.eye(4, dtype=xyz_quat.dtype, device=xyz_quat.device) + .unsqueeze_(0) + .repeat(xyz_quat.shape[0], 1, 1) + ) + trans[..., :3, 3] = xyz_quat[..., :3] + trans[..., :3, :3] = rot_mat + return trans + + +def trans_matrix_to_xyz_quat(matrix: torch.Tensor) -> torch.Tensor: + """Convert a (4, 4) pose transformation matrix ((R, t), (0, 1)) to a 7D pose vector. + + Args: + matrix: The pose transformation matrix in ((R, t), (0, 1)). Shape is (..., 4, 4). + + Returns: + The pose vector in (x, y, z, qw, qx, qy, qz). Shape is (..., 7). + """ + if matrix.shape[-2:] != (4, 4): + raise ValueError(f"Invalid input shape {matrix.shape}, expected (..., 4, 4).") + + # get rotation matrix from quaternion + quat = quat_from_matrix(matrix[..., :3, :3]) # (..., 4) + + # create vector + vec = torch.concatenate([matrix[..., :3, 3], quat], dim=-1).to( + dtype=matrix.dtype, device=matrix.device + ) + return vec + + +@torch.jit.script +def quat_from_euler_xyz( + roll: torch.Tensor, pitch: torch.Tensor, yaw: torch.Tensor +) -> torch.Tensor: + """Convert rotations given as Euler angles in radians to Quaternions. + + Note: + The euler angles are assumed in XYZ convention. + + Args: + roll: Rotation around x-axis (in radians). Shape is (N,). + pitch: Rotation around y-axis (in radians). Shape is (N,). + yaw: Rotation around z-axis (in radians). Shape is (N,). + + Returns: + The quaternion in (w, x, y, z). Shape is (N, 4). + """ + cy = torch.cos(yaw * 0.5) + sy = torch.sin(yaw * 0.5) + cr = torch.cos(roll * 0.5) + sr = torch.sin(roll * 0.5) + cp = torch.cos(pitch * 0.5) + sp = torch.sin(pitch * 0.5) + # compute quaternion + qw = cy * cr * cp + sy * sr * sp + qx = cy * sr * cp - sy * cr * sp + qy = cy * cr * sp + sy * sr * cp + qz = sy * cr * cp - cy * sr * sp + + return torch.stack([qw, qx, qy, qz], dim=-1) + + +def _axis_angle_rotation( + axis: Literal["X", "Y", "Z"], angle: torch.Tensor +) -> torch.Tensor: + """Return the rotation matrices for one of the rotations about an axis of which Euler angles describe, + for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: Euler angles in radians of any shape. + + Returns: + Rotation matrices. Shape is (..., 3, 3). + + Reference: + https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L164-L191 + """ + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + elif axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + elif axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + else: + raise ValueError("letter must be either X, Y or Z.") + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def matrix_from_euler( + euler_angles: torch.Tensor, convention: str = "XYZ" +) -> torch.Tensor: + """ + Convert rotations given as Euler angles (intrinsic) in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians. Shape is (..., 3). + convention: Convention string of three uppercase letters from {"X", "Y", and "Z"}. + For example, "XYZ" means that the rotations should be applied first about x, + then y, then z. Defaults to "XYZ". + + Returns: + Rotation matrices. Shape is (..., 3, 3). + + Reference: + https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L194-L220 + """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = [ + _axis_angle_rotation(c, e) + for c, e in zip(convention, torch.unbind(euler_angles, -1)) + ] + # return functools.reduce(torch.matmul, matrices) + return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2]) + + +@torch.jit.script +def euler_xyz_from_quat( + quat: torch.Tensor, wrap_to_2pi: bool = False +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Convert rotations given as quaternions to Euler angles in radians. + + Note: + The euler angles are assumed in XYZ extrinsic convention. + + Args: + quat: The quaternion orientation in (w, x, y, z). Shape is (N, 4). + wrap_to_2pi (bool): Whether to wrap output Euler angles into [0, 2π). If + False, angles are returned in the default range (−π, π]. Defaults to + False. + + Returns: + A tuple containing roll-pitch-yaw. Each element is a tensor of shape (N,). + + Reference: + https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles + """ + q_w, q_x, q_y, q_z = quat[:, 0], quat[:, 1], quat[:, 2], quat[:, 3] + # roll (x-axis rotation) + sin_roll = 2.0 * (q_w * q_x + q_y * q_z) + cos_roll = 1 - 2 * (q_x * q_x + q_y * q_y) + roll = torch.atan2(sin_roll, cos_roll) + + # pitch (y-axis rotation) + sin_pitch = 2.0 * (q_w * q_y - q_z * q_x) + pitch = torch.where( + torch.abs(sin_pitch) >= 1, + copysign(torch.pi / 2.0, sin_pitch), + torch.asin(sin_pitch), + ) + + # yaw (z-axis rotation) + sin_yaw = 2.0 * (q_w * q_z + q_x * q_y) + cos_yaw = 1 - 2 * (q_y * q_y + q_z * q_z) + yaw = torch.atan2(sin_yaw, cos_yaw) + + if wrap_to_2pi: + return roll % (2 * torch.pi), pitch % (2 * torch.pi), yaw % (2 * torch.pi) + return roll, pitch, yaw + + +@torch.jit.script +def axis_angle_from_quat(quat: torch.Tensor, eps: float = 1.0e-6) -> torch.Tensor: + """Convert rotations given as quaternions to axis/angle. + + Args: + quat: The quaternion orientation in (w, x, y, z). Shape is (..., 4). + eps: The tolerance for Taylor approximation. Defaults to 1.0e-6. + + Returns: + Rotations given as a vector in axis angle form. Shape is (..., 3). + The vector's magnitude is the angle turned anti-clockwise in radians around the vector's direction. + + Reference: + https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py#L526-L554 + """ + # Modified to take in quat as [q_w, q_x, q_y, q_z] + # Quaternion is [q_w, q_x, q_y, q_z] = [cos(theta/2), n_x * sin(theta/2), n_y * sin(theta/2), n_z * sin(theta/2)] + # Axis-angle is [a_x, a_y, a_z] = [theta * n_x, theta * n_y, theta * n_z] + # Thus, axis-angle is [q_x, q_y, q_z] / (sin(theta/2) / theta) + # When theta = 0, (sin(theta/2) / theta) is undefined + # However, as theta --> 0, we can use the Taylor approximation 1/2 - theta^2 / 48 + quat = quat * (1.0 - 2.0 * (quat[..., 0:1] < 0.0)) + mag = torch.linalg.norm(quat[..., 1:], dim=-1) + half_angle = torch.atan2(mag, quat[..., 0]) + angle = 2.0 * half_angle + # check whether to apply Taylor approximation + sin_half_angles_over_angles = torch.where( + angle.abs() > eps, torch.sin(half_angle) / angle, 0.5 - angle * angle / 48 + ) + return quat[..., 1:4] / sin_half_angles_over_angles.unsqueeze(-1) + + +@torch.jit.script +def quat_from_angle_axis(angle: torch.Tensor, axis: torch.Tensor) -> torch.Tensor: + """Convert rotations given as angle-axis to quaternions. + + Args: + angle: The angle turned anti-clockwise in radians around the vector's direction. Shape is (N,). + axis: The axis of rotation. Shape is (N, 3). + + Returns: + The quaternion in (w, x, y, z). Shape is (N, 4). + """ + theta = (angle / 2).unsqueeze(-1) + xyz = normalize(axis) * theta.sin() + w = theta.cos() + return normalize(torch.cat([w, xyz], dim=-1)) + + +@torch.jit.script +def quat_mul(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor: + """Multiply two quaternions together. + + Args: + q1: The first quaternion in (w, x, y, z). Shape is (..., 4). + q2: The second quaternion in (w, x, y, z). Shape is (..., 4). + + Returns: + The product of the two quaternions in (w, x, y, z). Shape is (..., 4). + + Raises: + ValueError: Input shapes of ``q1`` and ``q2`` are not matching. + """ + # check input is correct + if q1.shape != q2.shape: + msg = f"Expected input quaternion shape mismatch: {q1.shape} != {q2.shape}." + raise ValueError(msg) + # reshape to (N, 4) for multiplication + shape = q1.shape + q1 = q1.reshape(-1, 4) + q2 = q2.reshape(-1, 4) + # extract components from quaternions + w1, x1, y1, z1 = q1[:, 0], q1[:, 1], q1[:, 2], q1[:, 3] + w2, x2, y2, z2 = q2[:, 0], q2[:, 1], q2[:, 2], q2[:, 3] + # perform multiplication + ww = (z1 + x1) * (x2 + y2) + yy = (w1 - y1) * (w2 + z2) + zz = (w1 + y1) * (w2 - z2) + xx = ww + yy + zz + qq = 0.5 * (xx + (z1 - x1) * (x2 - y2)) + w = qq - ww + (z1 - y1) * (y2 - z2) + x = qq - xx + (x1 + w1) * (x2 + w2) + y = qq - yy + (w1 - x1) * (y2 + z2) + z = qq - zz + (z1 + y1) * (w2 - x2) + + return torch.stack([w, x, y, z], dim=-1).view(shape) + + +@torch.jit.script +def yaw_quat(quat: torch.Tensor) -> torch.Tensor: + """Extract the yaw component of a quaternion. + + Args: + quat: The orientation in (w, x, y, z). Shape is (..., 4) + + Returns: + A quaternion with only yaw component. + """ + shape = quat.shape + quat_yaw = quat.view(-1, 4) + qw = quat_yaw[:, 0] + qx = quat_yaw[:, 1] + qy = quat_yaw[:, 2] + qz = quat_yaw[:, 3] + yaw = torch.atan2(2 * (qw * qz + qx * qy), 1 - 2 * (qy * qy + qz * qz)) + quat_yaw = torch.zeros_like(quat_yaw) + quat_yaw[:, 3] = torch.sin(yaw / 2) + quat_yaw[:, 0] = torch.cos(yaw / 2) + quat_yaw = normalize(quat_yaw) + return quat_yaw.view(shape) + + +@torch.jit.script +def quat_box_minus(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor: + """The box-minus operator (quaternion difference) between two quaternions. + + Args: + q1: The first quaternion in (w, x, y, z). Shape is (N, 4). + q2: The second quaternion in (w, x, y, z). Shape is (N, 4). + + Returns: + The difference between the two quaternions. Shape is (N, 3). + + Reference: + https://github.com/ANYbotics/kindr/blob/master/doc/cheatsheet/cheatsheet_latest.pdf + """ + quat_diff = quat_mul(q1, quat_conjugate(q2)) # q1 * q2^-1 + return axis_angle_from_quat(quat_diff) # log(qd) + + +@torch.jit.script +def quat_box_plus( + q: torch.Tensor, delta: torch.Tensor, eps: float = 1.0e-6 +) -> torch.Tensor: + """The box-plus operator (quaternion update) to apply an increment to a quaternion. + + Args: + q: The initial quaternion in (w, x, y, z). Shape is (N, 4). + delta: The axis-angle perturbation. Shape is (N, 3). + eps: A small value to avoid division by zero. Defaults to 1e-6. + + Returns: + The updated quaternion after applying the perturbation. Shape is (N, 4). + + Reference: + https://github.com/ANYbotics/kindr/blob/master/doc/cheatsheet/cheatsheet_latest.pdf + """ + delta_norm = torch.clamp_min( + torch.linalg.norm(delta, dim=-1, keepdim=True), min=eps + ) + delta_quat = quat_from_angle_axis( + delta_norm.squeeze(-1), delta / delta_norm + ) # exp(dq) + new_quat = quat_mul(delta_quat, q) # Apply perturbation + return quat_unique(new_quat) + + +@torch.jit.script +def quat_apply(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor: + """Apply a quaternion rotation to a vector. + + Args: + quat: The quaternion in (w, x, y, z). Shape is (..., 4). + vec: The vector in (x, y, z). Shape is (..., 3). + + Returns: + The rotated vector in (x, y, z). Shape is (..., 3). + """ + # store shape + shape = vec.shape + # reshape to (N, 3) for multiplication + quat = quat.reshape(-1, 4) + vec = vec.reshape(-1, 3) + # extract components from quaternions + xyz = quat[:, 1:] + t = xyz.cross(vec, dim=-1) * 2 + return (vec + quat[:, 0:1] * t + xyz.cross(t, dim=-1)).view(shape) + + +@torch.jit.script +def quat_apply_inverse(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor: + """Apply an inverse quaternion rotation to a vector. + + Args: + quat: The quaternion in (w, x, y, z). Shape is (..., 4). + vec: The vector in (x, y, z). Shape is (..., 3). + + Returns: + The rotated vector in (x, y, z). Shape is (..., 3). + """ + # store shape + shape = vec.shape + # reshape to (N, 3) for multiplication + quat = quat.reshape(-1, 4) + vec = vec.reshape(-1, 3) + # extract components from quaternions + xyz = quat[:, 1:] + t = xyz.cross(vec, dim=-1) * 2 + return (vec - quat[:, 0:1] * t + xyz.cross(t, dim=-1)).view(shape) + + +@torch.jit.script +def quat_apply_yaw(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor: + """Rotate a vector only around the yaw-direction. + + Args: + quat: The orientation in (w, x, y, z). Shape is (N, 4). + vec: The vector in (x, y, z). Shape is (N, 3). + + Returns: + The rotated vector in (x, y, z). Shape is (N, 3). + """ + quat_yaw = yaw_quat(quat) + return quat_apply(quat_yaw, vec) + + +def quat_rotate(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """Rotate a vector by a quaternion along the last dimension of q and v. + .. deprecated v2.1.0: + This function will be removed in a future release in favor of the faster implementation :meth:`quat_apply`. + + Args: + q: The quaternion in (w, x, y, z). Shape is (..., 4). + v: The vector in (x, y, z). Shape is (..., 3). + + Returns: + The rotated vector in (x, y, z). Shape is (..., 3). + """ + # deprecation + omni.log.warn( + "The function 'quat_rotate' will be deprecated in favor of the faster method 'quat_apply'." + " Please use 'quat_apply' instead...." + ) + return quat_apply(q, v) + + +def quat_rotate_inverse(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """Rotate a vector by the inverse of a quaternion along the last dimension of q and v. + + .. deprecated v2.1.0: + This function will be removed in a future release in favor of the faster implementation :meth:`quat_apply_inverse`. + Args: + q: The quaternion in (w, x, y, z). Shape is (..., 4). + v: The vector in (x, y, z). Shape is (..., 3). + + Returns: + The rotated vector in (x, y, z). Shape is (..., 3). + """ + omni.log.warn( + "The function 'quat_rotate_inverse' will be deprecated in favor of the faster method 'quat_apply_inverse'." + " Please use 'quat_apply_inverse' instead...." + ) + return quat_apply_inverse(q, v) + + +@torch.jit.script +def quat_error_magnitude(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor: + """Computes the rotation difference between two quaternions. + + Args: + q1: The first quaternion in (w, x, y, z). Shape is (..., 4). + q2: The second quaternion in (w, x, y, z). Shape is (..., 4). + + Returns: + Angular error between input quaternions in radians. + """ + axis_angle_error = quat_box_minus(q1, q2) + return torch.norm(axis_angle_error, dim=-1) + + +@torch.jit.script +def skew_symmetric_matrix(vec: torch.Tensor) -> torch.Tensor: + """Computes the skew-symmetric matrix of a vector. + + Args: + vec: The input vector. Shape is (3,) or (N, 3). + + Returns: + The skew-symmetric matrix. Shape is (1, 3, 3) or (N, 3, 3). + + Raises: + ValueError: If input tensor is not of shape (..., 3). + """ + # check input is correct + if vec.shape[-1] != 3: + raise ValueError( + f"Expected input vector shape mismatch: {vec.shape} != (..., 3)." + ) + # unsqueeze the last dimension + if vec.ndim == 1: + vec = vec.unsqueeze(0) + # create a skew-symmetric matrix + skew_sym_mat = torch.zeros(vec.shape[0], 3, 3, device=vec.device, dtype=vec.dtype) + skew_sym_mat[:, 0, 1] = -vec[:, 2] + skew_sym_mat[:, 0, 2] = vec[:, 1] + skew_sym_mat[:, 1, 2] = -vec[:, 0] + skew_sym_mat[:, 1, 0] = vec[:, 2] + skew_sym_mat[:, 2, 0] = -vec[:, 1] + skew_sym_mat[:, 2, 1] = vec[:, 0] + + return skew_sym_mat + + +""" +Transformations +""" + + +def is_identity_pose(pos: torch.tensor, rot: torch.tensor) -> bool: + """Checks if input poses are identity transforms. + + The function checks if the input position and orientation are close to zero and + identity respectively using L2-norm. It does NOT check the error in the orientation. + + Args: + pos: The cartesian position. Shape is (N, 3). + rot: The quaternion in (w, x, y, z). Shape is (N, 4). + + Returns: + True if all the input poses result in identity transform. Otherwise, False. + """ + # create identity transformations + pos_identity = torch.zeros_like(pos) + rot_identity = torch.zeros_like(rot) + rot_identity[..., 0] = 1 + # compare input to identity + return torch.allclose(pos, pos_identity) and torch.allclose(rot, rot_identity) + + +@torch.jit.script +def combine_frame_transforms( + t01: torch.Tensor, + q01: torch.Tensor, + t12: Optional[torch.Tensor] = None, + q12: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + r"""Combine transformations between two reference frames into a stationary frame. + + It performs the following transformation operation: :math:`T_{02} = T_{01} \times T_{12}`, + where :math:`T_{AB}` is the homogeneous transformation matrix from frame A to B. + + Args: + t01: Position of frame 1 w.r.t. frame 0. Shape is (N, 3). + q01: Quaternion orientation of frame 1 w.r.t. frame 0 in (w, x, y, z). Shape is (N, 4). + t12: Position of frame 2 w.r.t. frame 1. Shape is (N, 3). + Defaults to None, in which case the position is assumed to be zero. + q12: Quaternion orientation of frame 2 w.r.t. frame 1 in (w, x, y, z). Shape is (N, 4). + Defaults to None, in which case the orientation is assumed to be identity. + + Returns: + A tuple containing the position and orientation of frame 2 w.r.t. frame 0. + Shape of the tensors are (N, 3) and (N, 4) respectively. + """ + # compute orientation + if q12 is not None: + q02 = quat_mul(q01, q12) + else: + q02 = q01 + # compute translation + if t12 is not None: + t02 = t01 + quat_apply(q01, t12) + else: + t02 = t01 + + return t02, q02 + + +def rigid_body_twist_transform( + v0: torch.Tensor, w0: torch.Tensor, t01: torch.Tensor, q01: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + r"""Transform the linear and angular velocity of a rigid body between reference frames. + + Given the twist of 0 relative to frame 0, this function computes the twist of 1 relative to frame 1 + from the position and orientation of frame 1 relative to frame 0. The transformation follows the + equations: + + .. math:: + + w_11 = R_{10} w_00 = R_{01}^{-1} w_00 + v_11 = R_{10} v_00 + R_{10} (w_00 \times t_01) = R_{01}^{-1} (v_00 + (w_00 \times t_01)) + + where + + - :math:`R_{01}` is the rotation matrix from frame 0 to frame 1 derived from quaternion :math:`q_{01}`. + - :math:`t_{01}` is the position of frame 1 relative to frame 0 expressed in frame 0 + - :math:`w_0` is the angular velocity of 0 in frame 0 + - :math:`v_0` is the linear velocity of 0 in frame 0 + + Args: + v0: Linear velocity of 0 in frame 0. Shape is (N, 3). + w0: Angular velocity of 0 in frame 0. Shape is (N, 3). + t01: Position of frame 1 w.r.t. frame 0. Shape is (N, 3). + q01: Quaternion orientation of frame 1 w.r.t. frame 0 in (w, x, y, z). Shape is (N, 4). + + Returns: + A tuple containing: + - The transformed linear velocity in frame 1. Shape is (N, 3). + - The transformed angular velocity in frame 1. Shape is (N, 3). + """ + w1 = quat_rotate_inverse(q01, w0) + v1 = quat_rotate_inverse(q01, v0 + torch.cross(w0, t01, dim=-1)) + return v1, w1 + + +# @torch.jit.script +def subtract_frame_transforms( + t01: torch.Tensor, + q01: torch.Tensor, + t02: Union[torch.Tensor, None] = None, + q02: Union[torch.Tensor, None] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + r"""Subtract transformations between two reference frames into a stationary frame. + + It performs the following transformation operation: :math:`T_{12} = T_{01}^{-1} \times T_{02}`, + where :math:`T_{AB}` is the homogeneous transformation matrix from frame A to B. + + Args: + t01: Position of frame 1 w.r.t. frame 0. Shape is (N, 3). + q01: Quaternion orientation of frame 1 w.r.t. frame 0 in (w, x, y, z). Shape is (N, 4). + t02: Position of frame 2 w.r.t. frame 0. Shape is (N, 3). + Defaults to None, in which case the position is assumed to be zero. + q02: Quaternion orientation of frame 2 w.r.t. frame 0 in (w, x, y, z). Shape is (N, 4). + Defaults to None, in which case the orientation is assumed to be identity. + + Returns: + A tuple containing the position and orientation of frame 2 w.r.t. frame 1. + Shape of the tensors are (N, 3) and (N, 4) respectively. + """ + # compute orientation + q10 = quat_inv(q01) + if q02 is not None: + q12 = quat_mul(q10, q02) + else: + q12 = q10 + # compute translation + if t02 is not None: + t12 = quat_apply(q10, t02 - t01) + else: + t12 = quat_apply(q10, -t01) + return t12, q12 + + +# @torch.jit.script +def compute_pose_error( + t01: torch.Tensor, + q01: torch.Tensor, + t02: torch.Tensor, + q02: torch.Tensor, + rot_error_type: Literal["quat", "axis_angle"] = "axis_angle", +) -> tuple[torch.Tensor, torch.Tensor]: + """Compute the position and orientation error between source and target frames. + + Args: + t01: Position of source frame. Shape is (N, 3). + q01: Quaternion orientation of source frame in (w, x, y, z). Shape is (N, 4). + t02: Position of target frame. Shape is (N, 3). + q02: Quaternion orientation of target frame in (w, x, y, z). Shape is (N, 4). + rot_error_type: The rotation error type to return: "quat", "axis_angle". + Defaults to "axis_angle". + + Returns: + A tuple containing position and orientation error. Shape of position error is (N, 3). + Shape of orientation error depends on the value of :attr:`rot_error_type`: + + - If :attr:`rot_error_type` is "quat", the orientation error is returned + as a quaternion. Shape is (N, 4). + - If :attr:`rot_error_type` is "axis_angle", the orientation error is + returned as an axis-angle vector. Shape is (N, 3). + + Raises: + ValueError: Invalid rotation error type. + """ + # Compute quaternion error (i.e., difference quaternion) + # Reference: https://personal.utdallas.edu/~sxb027100/dock/quaternion.html + # q_current_norm = q_current * q_current_conj + source_quat_norm = quat_mul(q01, quat_conjugate(q01))[:, 0] + # q_current_inv = q_current_conj / q_current_norm + source_quat_inv = quat_conjugate(q01) / source_quat_norm.unsqueeze(-1) + # q_error = q_target * q_current_inv + quat_error = quat_mul(q02, source_quat_inv) + + # Compute position error + pos_error = t02 - t01 + + # return error based on specified type + if rot_error_type == "quat": + return pos_error, quat_error + elif rot_error_type == "axis_angle": + # Convert to axis-angle error + axis_angle_error = axis_angle_from_quat(quat_error) + return pos_error, axis_angle_error + else: + raise ValueError( + f"Unsupported orientation error type: {rot_error_type}. Valid: 'quat', 'axis_angle'." + ) + + +@torch.jit.script +def apply_delta_pose( + source_pos: torch.Tensor, + source_rot: torch.Tensor, + delta_pose: torch.Tensor, + eps: float = 1.0e-6, +) -> tuple[torch.Tensor, torch.Tensor]: + """Applies delta pose transformation on source pose. + + The first three elements of `delta_pose` are interpreted as cartesian position displacement. + The remaining three elements of `delta_pose` are interpreted as orientation displacement + in the angle-axis format. + + Args: + source_pos: Position of source frame. Shape is (N, 3). + source_rot: Quaternion orientation of source frame in (w, x, y, z). Shape is (N, 4).. + delta_pose: Position and orientation displacements. Shape is (N, 6). + eps: The tolerance to consider orientation displacement as zero. Defaults to 1.0e-6. + + Returns: + A tuple containing the displaced position and orientation frames. + Shape of the tensors are (N, 3) and (N, 4) respectively. + """ + # number of poses given + num_poses = source_pos.shape[0] + device = source_pos.device + + # interpret delta_pose[:, 0:3] as target position displacements + target_pos = source_pos + delta_pose[:, 0:3] + # interpret delta_pose[:, 3:6] as target rotation displacements + rot_actions = delta_pose[:, 3:6] + angle = torch.linalg.vector_norm(rot_actions, dim=1) + axis = rot_actions / angle.unsqueeze(-1) + # change from axis-angle to quat convention + identity_quat = torch.tensor([1.0, 0.0, 0.0, 0.0], device=device).repeat( + num_poses, 1 + ) + rot_delta_quat = torch.where( + angle.unsqueeze(-1).repeat(1, 4) > eps, + quat_from_angle_axis(angle, axis), + identity_quat, + ) + # TODO: Check if this is the correct order for this multiplication. + target_rot = quat_mul(rot_delta_quat, source_rot) + + return target_pos, target_rot + + +# @torch.jit.script +def transform_points( + points: torch.Tensor, + pos: Union[torch.Tensor, None] = None, + quat: Union[torch.Tensor, None] = None, +) -> torch.Tensor: + r"""Transform input points in a given frame to a target frame. + + This function transform points from a source frame to a target frame. The transformation is defined by the + position :math:`t` and orientation :math:`R` of the target frame in the source frame. + + .. math:: + p_{target} = R_{target} \times p_{source} + t_{target} + + If the input `points` is a batch of points, the inputs `pos` and `quat` must be either a batch of + positions and quaternions or a single position and quaternion. If the inputs `pos` and `quat` are + a single position and quaternion, the same transformation is applied to all points in the batch. + + If either the inputs :attr:`pos` and :attr:`quat` are None, the corresponding transformation is not applied. + + Args: + points: Points to transform. Shape is (N, P, 3) or (P, 3). + pos: Position of the target frame. Shape is (N, 3) or (3,). + Defaults to None, in which case the position is assumed to be zero. + quat: Quaternion orientation of the target frame in (w, x, y, z). Shape is (N, 4) or (4,). + Defaults to None, in which case the orientation is assumed to be identity. + + Returns: + Transformed points in the target frame. Shape is (N, P, 3) or (P, 3). + + Raises: + ValueError: If the inputs `points` is not of shape (N, P, 3) or (P, 3). + ValueError: If the inputs `pos` is not of shape (N, 3) or (3,). + ValueError: If the inputs `quat` is not of shape (N, 4) or (4,). + """ + points_batch = points.clone() + # check if inputs are batched + is_batched = points_batch.dim() == 3 + # -- check inputs + if points_batch.dim() == 2: + points_batch = points_batch[None] # (P, 3) -> (1, P, 3) + if points_batch.dim() != 3: + raise ValueError( + f"Expected points to have dim = 2 or dim = 3: got shape {points.shape}" + ) + if not (pos is None or pos.dim() == 1 or pos.dim() == 2): + raise ValueError( + f"Expected pos to have dim = 1 or dim = 2: got shape {pos.shape}" + ) + if not (quat is None or quat.dim() == 1 or quat.dim() == 2): + raise ValueError( + f"Expected quat to have dim = 1 or dim = 2: got shape {quat.shape}" + ) + # -- rotation + if quat is not None: + # convert to batched rotation matrix + rot_mat = matrix_from_quat(quat) + if rot_mat.dim() == 2: + rot_mat = rot_mat[None] # (3, 3) -> (1, 3, 3) + # convert points to matching batch size (N, P, 3) -> (N, 3, P) + # and apply rotation + points_batch = torch.matmul(rot_mat, points_batch.transpose_(1, 2)) + # (N, 3, P) -> (N, P, 3) + points_batch = points_batch.transpose_(1, 2) + # -- translation + if pos is not None: + # convert to batched translation vector + if pos.dim() == 1: + pos = pos[None, None, :] # (3,) -> (1, 1, 3) + else: + pos = pos[:, None, :] # (N, 3) -> (N, 1, 3) + # apply translation + points_batch += pos + # -- return points in same shape as input + if not is_batched: + points_batch = points_batch.squeeze(0) # (1, P, 3) -> (P, 3) + + return points_batch + + +""" +Projection operations. +""" + + +@torch.jit.script +def orthogonalize_perspective_depth( + depth: torch.Tensor, intrinsics: torch.Tensor +) -> torch.Tensor: + """Converts perspective depth image to orthogonal depth image. + + Perspective depth images contain distances measured from the camera's optical center. + Meanwhile, orthogonal depth images provide the distance from the camera's image plane. + This method uses the camera geometry to convert perspective depth to orthogonal depth image. + + The function assumes that the width and height are both greater than 1. + + Args: + depth: The perspective depth images. Shape is (H, W) or or (H, W, 1) or (N, H, W) or (N, H, W, 1). + intrinsics: The camera's calibration matrix. If a single matrix is provided, the same + calibration matrix is used across all the depth images in the batch. + Shape is (3, 3) or (N, 3, 3). + + Returns: + The orthogonal depth images. Shape matches the input shape of depth images. + + Raises: + ValueError: When depth is not of shape (H, W) or (H, W, 1) or (N, H, W) or (N, H, W, 1). + ValueError: When intrinsics is not of shape (3, 3) or (N, 3, 3). + """ + # Clone inputs to avoid in-place modifications + perspective_depth_batch = depth.clone() + intrinsics_batch = intrinsics.clone() + + # Check if inputs are batched + is_batched = perspective_depth_batch.dim() == 4 or ( + perspective_depth_batch.dim() == 3 and perspective_depth_batch.shape[-1] != 1 + ) + + # Track whether the last dimension was singleton + add_last_dim = False + if perspective_depth_batch.dim() == 4 and perspective_depth_batch.shape[-1] == 1: + add_last_dim = True + perspective_depth_batch = perspective_depth_batch.squeeze( + dim=3 + ) # (N, H, W, 1) -> (N, H, W) + if perspective_depth_batch.dim() == 3 and perspective_depth_batch.shape[-1] == 1: + add_last_dim = True + perspective_depth_batch = perspective_depth_batch.squeeze( + dim=2 + ) # (H, W, 1) -> (H, W) + + if perspective_depth_batch.dim() == 2: + perspective_depth_batch = perspective_depth_batch[None] # (H, W) -> (1, H, W) + + if intrinsics_batch.dim() == 2: + intrinsics_batch = intrinsics_batch[None] # (3, 3) -> (1, 3, 3) + + if is_batched and intrinsics_batch.shape[0] == 1: + intrinsics_batch = intrinsics_batch.expand( + perspective_depth_batch.shape[0], -1, -1 + ) # (1, 3, 3) -> (N, 3, 3) + + # Validate input shapes + if perspective_depth_batch.dim() != 3: + raise ValueError( + f"Expected depth images to have 2, 3, or 4 dimensions; got {depth.shape}." + ) + if intrinsics_batch.dim() != 3: + raise ValueError( + f"Expected intrinsics to have shape (3, 3) or (N, 3, 3); got {intrinsics.shape}." + ) + + # Image dimensions + im_height, im_width = perspective_depth_batch.shape[1:] + + # Get the intrinsics parameters + fx = intrinsics_batch[:, 0, 0].view(-1, 1, 1) + fy = intrinsics_batch[:, 1, 1].view(-1, 1, 1) + cx = intrinsics_batch[:, 0, 2].view(-1, 1, 1) + cy = intrinsics_batch[:, 1, 2].view(-1, 1, 1) + + # Create meshgrid of pixel coordinates + u_grid = torch.arange(im_width, device=depth.device, dtype=depth.dtype) + v_grid = torch.arange(im_height, device=depth.device, dtype=depth.dtype) + u_grid, v_grid = torch.meshgrid(u_grid, v_grid, indexing="xy") + + # Expand the grids for batch processing + u_grid = u_grid.unsqueeze(0).expand(perspective_depth_batch.shape[0], -1, -1) + v_grid = v_grid.unsqueeze(0).expand(perspective_depth_batch.shape[0], -1, -1) + + # Compute the squared terms for efficiency + x_term = ((u_grid - cx) / fx) ** 2 + y_term = ((v_grid - cy) / fy) ** 2 + + # Calculate the orthogonal (normal) depth + orthogonal_depth = perspective_depth_batch / torch.sqrt(1 + x_term + y_term) + + # Restore the last dimension if it was present in the input + if add_last_dim: + orthogonal_depth = orthogonal_depth.unsqueeze(-1) + + # Return to original shape if input was not batched + if not is_batched: + orthogonal_depth = orthogonal_depth.squeeze(0) + + return orthogonal_depth + + +@torch.jit.script +def unproject_depth( + depth: torch.Tensor, intrinsics: torch.Tensor, is_ortho: bool = True +) -> torch.Tensor: + r"""Un-project depth image into a pointcloud. + + This function converts orthogonal or perspective depth images into points given the calibration matrix + of the camera. It uses the following transformation based on camera geometry: + + .. math:: + p_{3D} = K^{-1} \times [u, v, 1]^T \times d + + where :math:`p_{3D}` is the 3D point, :math:`d` is the depth value (measured from the image plane), + :math:`u` and :math:`v` are the pixel coordinates and :math:`K` is the intrinsic matrix. + + The function assumes that the width and height are both greater than 1. This makes the function + deal with many possible shapes of depth images and intrinsics matrices. + + .. note:: + If :attr:`is_ortho` is False, the input depth images are transformed to orthogonal depth images + by using the :meth:`orthogonalize_perspective_depth` method. + + Args: + depth: The depth measurement. Shape is (H, W) or or (H, W, 1) or (N, H, W) or (N, H, W, 1). + intrinsics: The camera's calibration matrix. If a single matrix is provided, the same + calibration matrix is used across all the depth images in the batch. + Shape is (3, 3) or (N, 3, 3). + is_ortho: Whether the input depth image is orthogonal or perspective depth image. If True, the input + depth image is considered as the *orthogonal* type, where the measurements are from the camera's + image plane. If False, the depth image is considered as the *perspective* type, where the + measurements are from the camera's optical center. Defaults to True. + + Returns: + The 3D coordinates of points. Shape is (P, 3) or (N, P, 3). + + Raises: + ValueError: When depth is not of shape (H, W) or (H, W, 1) or (N, H, W) or (N, H, W, 1). + ValueError: When intrinsics is not of shape (3, 3) or (N, 3, 3). + """ + # clone inputs to avoid in-place modifications + intrinsics_batch = intrinsics.clone() + # convert depth image to orthogonal if needed + if not is_ortho: + depth_batch = orthogonalize_perspective_depth(depth, intrinsics) + else: + depth_batch = depth.clone() + + # check if inputs are batched + is_batched = depth_batch.dim() == 4 or ( + depth_batch.dim() == 3 and depth_batch.shape[-1] != 1 + ) + # make sure inputs are batched + if depth_batch.dim() == 3 and depth_batch.shape[-1] == 1: + depth_batch = depth_batch.squeeze(dim=2) # (H, W, 1) -> (H, W) + if depth_batch.dim() == 2: + depth_batch = depth_batch[None] # (H, W) -> (1, H, W) + if depth_batch.dim() == 4 and depth_batch.shape[-1] == 1: + depth_batch = depth_batch.squeeze(dim=3) # (N, H, W, 1) -> (N, H, W) + if intrinsics_batch.dim() == 2: + intrinsics_batch = intrinsics_batch[None] # (3, 3) -> (1, 3, 3) + # check shape of inputs + if depth_batch.dim() != 3: + raise ValueError( + f"Expected depth images to have dim = 2 or 3 or 4: got shape {depth.shape}" + ) + if intrinsics_batch.dim() != 3: + raise ValueError( + f"Expected intrinsics to have shape (3, 3) or (N, 3, 3): got shape {intrinsics.shape}" + ) + + # get image height and width + im_height, im_width = depth_batch.shape[1:] + # create image points in homogeneous coordinates (3, H x W) + indices_u = torch.arange(im_width, device=depth.device, dtype=depth.dtype) + indices_v = torch.arange(im_height, device=depth.device, dtype=depth.dtype) + img_indices = torch.stack( + torch.meshgrid([indices_u, indices_v], indexing="ij"), dim=0 + ).reshape(2, -1) + pixels = torch.nn.functional.pad( + img_indices, (0, 0, 0, 1), mode="constant", value=1.0 + ) + pixels = pixels.unsqueeze(0) # (3, H x W) -> (1, 3, H x W) + + # unproject points into 3D space + points = torch.matmul(torch.inverse(intrinsics_batch), pixels) # (N, 3, H x W) + points = points / points[:, -1, :].unsqueeze(1) # normalize by last coordinate + # flatten depth image (N, H, W) -> (N, H x W) + depth_batch = ( + depth_batch.transpose_(1, 2).reshape(depth_batch.shape[0], -1).unsqueeze(2) + ) + depth_batch = depth_batch.expand(-1, -1, 3) + # scale points by depth + points_xyz = points.transpose_(1, 2) * depth_batch # (N, H x W, 3) + + # return points in same shape as input + if not is_batched: + points_xyz = points_xyz.squeeze(0) + + return points_xyz + + +@torch.jit.script +def project_points(points: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor: + r"""Projects 3D points into 2D image plane. + + This project 3D points into a 2D image plane. The transformation is defined by the intrinsic + matrix of the camera. + + .. math:: + + \begin{align} + p &= K \times p_{3D} = \\ + p_{2D} &= \begin{pmatrix} u \\ v \\ d \end{pmatrix} + = \begin{pmatrix} p[0] / p[2] \\ p[1] / p[2] \\ Z \end{pmatrix} + \end{align} + + where :math:`p_{2D} = (u, v, d)` is the projected 3D point, :math:`p_{3D} = (X, Y, Z)` is the + 3D point and :math:`K \in \mathbb{R}^{3 \times 3}` is the intrinsic matrix. + + If `points` is a batch of 3D points and `intrinsics` is a single intrinsic matrix, the same + calibration matrix is applied to all points in the batch. + + Args: + points: The 3D coordinates of points. Shape is (P, 3) or (N, P, 3). + intrinsics: Camera's calibration matrix. Shape is (3, 3) or (N, 3, 3). + + Returns: + Projected 3D coordinates of points. Shape is (P, 3) or (N, P, 3). + """ + # clone the inputs to avoid in-place operations modifying the original data + points_batch = points.clone() + intrinsics_batch = intrinsics.clone() + + # check if inputs are batched + is_batched = points_batch.dim() == 2 + # make sure inputs are batched + if points_batch.dim() == 2: + points_batch = points_batch[None] # (P, 3) -> (1, P, 3) + if intrinsics_batch.dim() == 2: + intrinsics_batch = intrinsics_batch[None] # (3, 3) -> (1, 3, 3) + # check shape of inputs + if points_batch.dim() != 3: + raise ValueError(f"Expected points to have dim = 3: got shape {points.shape}.") + if intrinsics_batch.dim() != 3: + raise ValueError( + f"Expected intrinsics to have shape (3, 3) or (N, 3, 3): got shape {intrinsics.shape}." + ) + + # project points into 2D image plane + points_2d = torch.matmul(intrinsics_batch, points_batch.transpose(1, 2)) + points_2d = points_2d / points_2d[:, -1, :].unsqueeze( + 1 + ) # normalize by last coordinate + points_2d = points_2d.transpose_(1, 2) # (N, 3, P) -> (N, P, 3) + # replace last coordinate with depth + points_2d[:, :, -1] = points_batch[:, :, -1] + + # return points in same shape as input + if not is_batched: + points_2d = points_2d.squeeze(0) # (1, 3, P) -> (3, P) + + return points_2d + + +""" +Sampling +""" + + +@torch.jit.script +def default_orientation(num: int, device: str) -> torch.Tensor: + """Returns identity rotation transform. + + Args: + num: The number of rotations to sample. + device: Device to create tensor on. + + Returns: + Identity quaternion in (w, x, y, z). Shape is (num, 4). + """ + quat = torch.zeros((num, 4), dtype=torch.float32, device=device) + quat[..., 0] = 1.0 + + return quat + + +@torch.jit.script +def random_orientation(num: int, device: str) -> torch.Tensor: + """Returns sampled rotation in 3D as quaternion. + + Args: + num: The number of rotations to sample. + device: Device to create tensor on. + + Returns: + Sampled quaternion in (w, x, y, z). Shape is (num, 4). + + Reference: + https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.transform.Rotation.random.html + """ + # sample random orientation from normal distribution + quat = torch.randn((num, 4), dtype=torch.float32, device=device) + # normalize the quaternion + return torch.nn.functional.normalize(quat, p=2.0, dim=-1, eps=1e-12) + + +@torch.jit.script +def random_yaw_orientation(num: int, device: str) -> torch.Tensor: + """Returns sampled rotation around z-axis. + + Args: + num: The number of rotations to sample. + device: Device to create tensor on. + + Returns: + Sampled quaternion in (w, x, y, z). Shape is (num, 4). + """ + roll = torch.zeros(num, dtype=torch.float32, device=device) + pitch = torch.zeros(num, dtype=torch.float32, device=device) + yaw = 2 * torch.pi * torch.rand(num, dtype=torch.float32, device=device) + + return quat_from_euler_xyz(roll, pitch, yaw) + + +def sample_triangle( + lower: float, upper: float, size: Union[int, tuple[int, ...]], device: str +) -> torch.Tensor: + """Randomly samples tensor from a triangular distribution. + + Args: + lower: The lower range of the sampled tensor. + upper: The upper range of the sampled tensor. + size: The shape of the tensor. + device: Device to create tensor on. + + Returns: + Sampled tensor. Shape is based on :attr:`size`. + """ + # convert to tuple + if isinstance(size, int): + size = (size,) + # create random tensor in the range [-1, 1] + r = 2 * torch.rand(*size, device=device) - 1 + # convert to triangular distribution + r = torch.where(r < 0.0, -torch.sqrt(-r), torch.sqrt(r)) + # rescale back to [0, 1] + r = (r + 1.0) / 2.0 + # rescale to range [lower, upper] + return (upper - lower) * r + lower + + +def sample_uniform( + lower: Union[torch.Tensor, float], + upper: Union[torch.Tensor, float], + size: Union[int, tuple[int, ...]], +) -> torch.Tensor: + """Sample uniformly within a range. + + Args: + lower: Lower bound of uniform range. + upper: Upper bound of uniform range. + size: The shape of the tensor. + device: Device to create tensor on. + + Returns: + Sampled tensor. Shape is based on :attr:`size`. + """ + # convert to tuple + if isinstance(size, int): + size = (size,) + # return tensor + return torch.rand(*size, device=lower.device) * (upper - lower) + lower + + +def sample_log_uniform( + lower: Union[torch.Tensor, float], + upper: Union[torch.Tensor, float], + size: Union[int, tuple[int, ...]], + device: str, +) -> torch.Tensor: + r"""Sample using log-uniform distribution within a range. + + The log-uniform distribution is defined as a uniform distribution in the log-space. It + is useful for sampling values that span several orders of magnitude. The sampled values + are uniformly distributed in the log-space and then exponentiated to get the final values. + + .. math:: + + x = \exp(\text{uniform}(\log(\text{lower}), \log(\text{upper}))) + + Args: + lower: Lower bound of uniform range. + upper: Upper bound of uniform range. + size: The shape of the tensor. + device: Device to create tensor on. + + Returns: + Sampled tensor. Shape is based on :attr:`size`. + """ + # cast to tensor if not already + if not isinstance(lower, torch.Tensor): + lower = torch.tensor(lower, dtype=torch.float32, device=device) + if not isinstance(upper, torch.Tensor): + upper = torch.tensor(upper, dtype=torch.float32, device=device) + # sample in log-space and exponentiate + return torch.exp(sample_uniform(torch.log(lower), torch.log(upper), size, device)) + + +def sample_gaussian( + mean: Union[torch.Tensor, float], + std: Union[torch.Tensor, float], + size: Union[int, tuple[int, ...]], + device: str, +) -> torch.Tensor: + """Sample using gaussian distribution. + + Args: + mean: Mean of the gaussian. + std: Std of the gaussian. + size: The shape of the tensor. + device: Device to create tensor on. + + Returns: + Sampled tensor. + """ + if isinstance(mean, float): + if isinstance(size, int): + size = (size,) + return torch.normal(mean=mean, std=std, size=size).to(device=device) + else: + return torch.normal(mean=mean, std=std).to(device=device) + + +def sample_cylinder( + radius: float, + h_range: tuple[float, float], + size: Union[int, tuple[int, ...]], + device: str, +) -> torch.Tensor: + """Sample 3D points uniformly on a cylinder's surface. + + The cylinder is centered at the origin and aligned with the z-axis. The height of the cylinder is + sampled uniformly from the range :obj:`h_range`, while the radius is fixed to :obj:`radius`. + + The sampled points are returned as a tensor of shape :obj:`(*size, 3)`, i.e. the last dimension + contains the x, y, and z coordinates of the sampled points. + + Args: + radius: The radius of the cylinder. + h_range: The minimum and maximum height of the cylinder. + size: The shape of the tensor. + device: Device to create tensor on. + + Returns: + Sampled tensor. Shape is :obj:`(*size, 3)`. + """ + # sample angles + angles = (torch.rand(size, device=device) * 2 - 1) * torch.pi + h_min, h_max = h_range + # add shape + if isinstance(size, int): + size = (size, 3) + else: + size += (3,) + # allocate a tensor + xyz = torch.zeros(size, device=device) + xyz[..., 0] = radius * torch.cos(angles) + xyz[..., 1] = radius * torch.sin(angles) + xyz[..., 2].uniform_(h_min, h_max) + # return positions + return xyz + + +""" +Orientation Conversions +""" + + +def convert_camera_frame_orientation_convention( + orientation: torch.Tensor, + origin: Literal["opengl", "ros", "world"] = "opengl", + target: Literal["opengl", "ros", "world"] = "ros", +) -> torch.Tensor: + r"""Converts a quaternion representing a rotation from one convention to another. + + In USD, the camera follows the ``"opengl"`` convention. Thus, it is always in **Y up** convention. + This means that the camera is looking down the -Z axis with the +Y axis pointing up , and +X axis pointing right. + However, in ROS, the camera is looking down the +Z axis with the +Y axis pointing down, and +X axis pointing right. + Thus, the camera needs to be rotated by :math:`180^{\circ}` around the X axis to follow the ROS convention. + + .. math:: + + T_{ROS} = \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & -1 & 0 & 0 \\ 0 & 0 & -1 & 0 \\ 0 & 0 & 0 & 1 \end{bmatrix} T_{USD} + + On the other hand, the typical world coordinate system is with +X pointing forward, +Y pointing left, + and +Z pointing up. The camera can also be set in this convention by rotating the camera by :math:`90^{\circ}` + around the X axis and :math:`-90^{\circ}` around the Y axis. + + .. math:: + + T_{WORLD} = \begin{bmatrix} 0 & 0 & -1 & 0 \\ -1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 0 & 1 \end{bmatrix} T_{USD} + + Thus, based on their application, cameras follow different conventions for their orientation. This function + converts a quaternion from one convention to another. + + Possible conventions are: + + - :obj:`"opengl"` - forward axis: -Z - up axis +Y - Offset is applied in the OpenGL (Usd.Camera) convention + - :obj:`"ros"` - forward axis: +Z - up axis -Y - Offset is applied in the ROS convention + - :obj:`"world"` - forward axis: +X - up axis +Z - Offset is applied in the World Frame convention + + Args: + orientation: Quaternion of form `(w, x, y, z)` with shape (..., 4) in source convention. + origin: Convention to convert from. Defaults to "opengl". + target: Convention to convert to. Defaults to "ros". + + Returns: + Quaternion of form `(w, x, y, z)` with shape (..., 4) in target convention + """ + if target == origin: + return orientation.clone() + + # -- unify input type + if origin == "ros": + # convert from ros to opengl convention + rotm = matrix_from_quat(orientation) + rotm[:, :, 2] = -rotm[:, :, 2] + rotm[:, :, 1] = -rotm[:, :, 1] + # convert to opengl convention + quat_gl = quat_from_matrix(rotm) + elif origin == "world": + # convert from world (x forward and z up) to opengl convention + rotm = matrix_from_quat(orientation) + rotm = torch.matmul( + rotm, + matrix_from_euler( + torch.tensor([math.pi / 2, -math.pi / 2, 0], device=orientation.device), + "XYZ", + ), + ) + # convert to isaac-sim convention + quat_gl = quat_from_matrix(rotm) + else: + quat_gl = orientation + + # -- convert to target convention + if target == "ros": + # convert from opengl to ros convention + rotm = matrix_from_quat(quat_gl) + rotm[:, :, 2] = -rotm[:, :, 2] + rotm[:, :, 1] = -rotm[:, :, 1] + return quat_from_matrix(rotm) + elif target == "world": + # convert from opengl to world (x forward and z up) convention + rotm = matrix_from_quat(quat_gl) + rotm = torch.matmul( + rotm, + matrix_from_euler( + torch.tensor([math.pi / 2, -math.pi / 2, 0], device=orientation.device), + "XYZ", + ).T, + ) + return quat_from_matrix(rotm) + else: + return quat_gl.clone() + + +def create_rotation_matrix_from_view( + eyes: torch.Tensor, + targets: torch.Tensor, + up_axis: Literal["Y", "Z"] = "Z", + device: str = "cpu", +) -> torch.Tensor: + """Compute the rotation matrix from world to view coordinates. + + This function takes a vector ''eyes'' which specifies the location + of the camera in world coordinates and the vector ''targets'' which + indicate the position of the object. + The output is a rotation matrix representing the transformation + from world coordinates -> view coordinates. + + The inputs eyes and targets can each be a + - 3 element tuple/list + - torch tensor of shape (1, 3) + - torch tensor of shape (N, 3) + + Args: + eyes: Position of the camera in world coordinates. + targets: Position of the object in world coordinates. + up_axis: The up axis of the camera. Defaults to "Z". + device: The device to create torch tensors on. Defaults to "cpu". + + The vectors are broadcast against each other so they all have shape (N, 3). + + Returns: + R: (N, 3, 3) batched rotation matrices + + Reference: + Based on PyTorch3D (https://github.com/facebookresearch/pytorch3d/blob/eaf0709d6af0025fe94d1ee7cec454bc3054826a/pytorch3d/renderer/cameras.py#L1635-L1685) + """ + if up_axis == "Y": + up_axis_vec = torch.tensor( + (0, 1, 0), device=device, dtype=torch.float32 + ).repeat(eyes.shape[0], 1) + elif up_axis == "Z": + up_axis_vec = torch.tensor( + (0, 0, 1), device=device, dtype=torch.float32 + ).repeat(eyes.shape[0], 1) + else: + raise ValueError(f"Invalid up axis: {up_axis}. Valid options are 'Y' and 'Z'.") + + # get rotation matrix in opengl format (-Z forward, +Y up) + z_axis = -torch.nn.functional.normalize(targets - eyes, eps=1e-5) + x_axis = torch.nn.functional.normalize( + torch.cross(up_axis_vec, z_axis, dim=1), eps=1e-5 + ) + y_axis = torch.nn.functional.normalize(torch.cross(z_axis, x_axis, dim=1), eps=1e-5) + is_close = torch.isclose(x_axis, torch.tensor(0.0), atol=5e-3).all( + dim=1, keepdim=True + ) + if is_close.any(): + replacement = torch.nn.functional.normalize( + torch.cross(y_axis, z_axis, dim=1), eps=1e-5 + ) + x_axis = torch.where(is_close, replacement, x_axis) + R = torch.cat((x_axis[:, None, :], y_axis[:, None, :], z_axis[:, None, :]), dim=1) + return R.transpose(1, 2) + + +def make_pose(pos: torch.Tensor, rot: torch.Tensor) -> torch.Tensor: + """Creates transformation matrices from positions and rotation matrices. + + Args: + pos: Batch of position vectors with last dimension of 3. + rot: Batch of rotation matrices with last 2 dimensions of (3, 3). + + Returns: + Batch of pose matrices with last 2 dimensions of (4, 4). + """ + assert isinstance(pos, torch.Tensor), "Input must be a torch tensor" + assert isinstance(rot, torch.Tensor), "Input must be a torch tensor" + assert pos.shape[:-1] == rot.shape[:-2] + assert pos.shape[-1] == rot.shape[-2] == rot.shape[-1] == 3 + pose = torch.zeros(pos.shape[:-1] + (4, 4), dtype=pos.dtype, device=pos.device) + pose[..., :3, :3] = rot + pose[..., :3, 3] = pos + pose[..., 3, 3] = 1.0 + return pose + + +def unmake_pose(pose: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Splits transformation matrices into positions and rotation matrices. + + Args: + pose: Batch of pose matrices with last 2 dimensions of (4, 4). + + Returns: + Tuple containing: + - Batch of position vectors with last dimension of 3. + - Batch of rotation matrices with last 2 dimensions of (3, 3). + """ + assert isinstance(pose, torch.Tensor), "Input must be a torch tensor" + return pose[..., :3, 3], pose[..., :3, :3] + + +def pose_inv(pose: torch.Tensor) -> torch.Tensor: + """Computes the inverse of transformation matrices. + + The inverse of a pose matrix [R t; 0 1] is [R.T -R.T*t; 0 1]. + + Args: + pose: Batch of pose matrices with last 2 dimensions of (4, 4). + + Returns: + Batch of inverse pose matrices with last 2 dimensions of (4, 4). + """ + assert isinstance(pose, torch.Tensor), "Input must be a torch tensor" + num_axes = len(pose.shape) + assert num_axes >= 2 + + inv_pose = torch.zeros_like(pose) + + # Take transpose of last 2 dimensions + inv_pose[..., :3, :3] = pose[..., :3, :3].transpose(-1, -2) + + # note: PyTorch matmul wants shapes [..., 3, 3] x [..., 3, 1] -> [..., 3, 1] so we add a dimension and take it away after + inv_pose[..., :3, 3] = torch.matmul(-inv_pose[..., :3, :3], pose[..., :3, 3:4])[ + ..., 0 + ] + inv_pose[..., 3, 3] = 1.0 + return inv_pose + + +def pose_in_A_to_pose_in_B( + pose_in_A: torch.Tensor, pose_A_in_B: torch.Tensor +) -> torch.Tensor: + """Converts poses from one coordinate frame to another. + + Transforms matrices representing point C in frame A + to matrices representing the same point C in frame B. + + Example usage: + + frame_C_in_B = pose_in_A_to_pose_in_B(frame_C_in_A, frame_A_in_B) + + Args: + pose_in_A: Batch of transformation matrices of point C in frame A. + pose_A_in_B: Batch of transformation matrices of frame A in frame B. + + Returns: + Batch of transformation matrices of point C in frame B. + """ + assert isinstance(pose_in_A, torch.Tensor), "Input must be a torch tensor" + assert isinstance(pose_A_in_B, torch.Tensor), "Input must be a torch tensor" + return torch.matmul(pose_A_in_B, pose_in_A) + + +def quat_slerp(q1: torch.Tensor, q2: torch.Tensor, tau: float) -> torch.Tensor: + """Performs spherical linear interpolation (SLERP) between two quaternions. + + This function does not support batch processing. + + Args: + q1: First quaternion in (w, x, y, z) format. + q2: Second quaternion in (w, x, y, z) format. + tau: Interpolation coefficient between 0 (q1) and 1 (q2). + + Returns: + Interpolated quaternion in (w, x, y, z) format. + """ + assert isinstance(q1, torch.Tensor), "Input must be a torch tensor" + assert isinstance(q2, torch.Tensor), "Input must be a torch tensor" + if tau == 0.0: + return q1 + elif tau == 1.0: + return q2 + d = torch.dot(q1, q2) + if abs(abs(d) - 1.0) < torch.finfo(q1.dtype).eps * 4.0: + return q1 + if d < 0.0: + # Invert rotation + d = -d + q2 *= -1.0 + angle = torch.acos(torch.clamp(d, -1, 1)) + if abs(angle) < torch.finfo(q1.dtype).eps * 4.0: + return q1 + isin = 1.0 / torch.sin(angle) + q1 = q1 * torch.sin((1.0 - tau) * angle) * isin + q2 = q2 * torch.sin(tau * angle) * isin + q1 = q1 + q2 + return q1 + + +def interpolate_rotations( + R1: torch.Tensor, R2: torch.Tensor, num_steps: int, axis_angle: bool = True +) -> torch.Tensor: + """Interpolates between two rotation matrices. + + Args: + R1: First rotation matrix. (4x4). + R2: Second rotation matrix. (4x4). + num_steps: Number of desired interpolated rotations (excluding start and end). + axis_angle: If True, interpolate in axis-angle representation; + otherwise use slerp. Defaults to True. + + Returns: + Stack of interpolated rotation matrices of shape (num_steps + 1, 4, 4), + including the start and end rotations. + """ + assert isinstance(R1, torch.Tensor), "Input must be a torch tensor" + assert isinstance(R2, torch.Tensor), "Input must be a torch tensor" + if axis_angle: + # Delta rotation expressed as axis-angle + delta_rot_mat = torch.matmul(R2, R1.transpose(-1, -2)) + delta_quat = quat_from_matrix(delta_rot_mat) + delta_axis_angle = axis_angle_from_quat(delta_quat) + + # Grab angle + delta_angle = torch.linalg.norm(delta_axis_angle) + + # Fix the axis, and chunk the angle up into steps + rot_step_size = delta_angle / num_steps + + # Convert into delta rotation matrices, and then convert to absolute rotations + if delta_angle < 0.05: + # Small angle - don't bother with interpolation + rot_steps = torch.stack([R2 for _ in range(num_steps)]) + else: + # Make sure that axis is a unit vector + delta_axis = delta_axis_angle / delta_angle + delta_rot_steps = [ + matrix_from_quat(quat_from_angle_axis(i * rot_step_size, delta_axis)) + for i in range(num_steps) + ] + rot_steps = torch.stack( + [torch.matmul(delta_rot_steps[i], R1) for i in range(num_steps)] + ) + else: + q1 = quat_from_matrix(R1) + q2 = quat_from_matrix(R2) + rot_steps = torch.stack( + [ + matrix_from_quat(quat_slerp(q1, q2, tau=float(i) / num_steps)) + for i in range(num_steps) + ] + ) + + # Add in endpoint + rot_steps = torch.cat([rot_steps, R2[None]], dim=0) + + return rot_steps + + +def interpolate_poses( + pose_1: torch.Tensor, + pose_2: torch.Tensor, + num_steps: int = None, + step_size: float = None, + perturb: bool = False, +) -> tuple[torch.Tensor, int]: + """Performs linear interpolation between two poses. + + Args: + pose_1: 4x4 start pose. + pose_2: 4x4 end pose. + num_steps: If provided, specifies the number of desired interpolated points. + Passing 0 corresponds to no interpolation. If None, step_size must be provided. + step_size: If provided, determines number of steps based on distance between poses. + perturb: If True, randomly perturbs interpolated position points. + + Returns: + Tuple containing: + - Array of shape (N + 2, 4, 4) corresponding to the interpolated pose path. + - Number of interpolated points (N) in the path. + """ + assert isinstance(pose_1, torch.Tensor), "Input must be a torch tensor" + assert isinstance(pose_2, torch.Tensor), "Input must be a torch tensor" + assert step_size is None or num_steps is None + + pos1, rot1 = unmake_pose(pose_1) + pos2, rot2 = unmake_pose(pose_2) + + if num_steps == 0: + # Skip interpolation + return ( + torch.cat([pos1[None], pos2[None]], dim=0), + torch.cat([rot1[None], rot2[None]], dim=0), + num_steps, + ) + + delta_pos = pos2 - pos1 + if num_steps is None: + assert torch.norm(delta_pos) > 0 + num_steps = math.ceil(torch.norm(delta_pos) / step_size) + + num_steps += 1 # Include starting pose + assert num_steps >= 2 + + # Linear interpolation of positions + pos_step_size = delta_pos / num_steps + grid = torch.arange(num_steps, dtype=torch.float32) + if perturb: + # Move interpolation grid points by up to half-size forward or backward + perturbations = torch.rand(num_steps - 2) - 0.5 + grid[1:-1] += perturbations + pos_steps = torch.stack([pos1 + grid[i] * pos_step_size for i in range(num_steps)]) + + # Add in endpoint + pos_steps = torch.cat([pos_steps, pos2[None]], dim=0) + + # Interpolate rotations + rot_steps = interpolate_rotations( + R1=rot1, R2=rot2, num_steps=num_steps, axis_angle=True + ) + + pose_steps = make_pose(pos_steps, rot_steps) + return pose_steps, num_steps - 1 + + +def transform_poses_from_frame_A_to_frame_B( + src_poses: torch.Tensor, frame_A: torch.Tensor, frame_B: torch.Tensor +) -> torch.Tensor: + """Transforms poses from one coordinate frame to another preserving relative poses. + + Args: + src_poses: Input pose sequence (shape [T, 4, 4]) from source demonstration. + frame_A: 4x4 frame A pose. + frame_B: 4x4 frame B pose. + + Returns: + Transformed pose sequence of shape [T, 4, 4]. + """ + # Transform source end effector poses to be relative to source object frame + src_poses_rel_frame_B = pose_in_A_to_pose_in_B( + pose_in_A=src_poses, + pose_A_in_B=pose_inv(frame_B[None]), + ) + + # Apply relative poses to current object frame to obtain new target eef poses + transformed_poses = pose_in_A_to_pose_in_B( + pose_in_A=src_poses_rel_frame_B, + pose_A_in_B=frame_A[None], + ) + return transformed_poses + + +def generate_random_rotation(rot_boundary: float = (2 * math.pi)) -> torch.Tensor: + """Generates a random rotation matrix using Euler angles. + + Args: + rot_boundary: Range for random rotation angles around each axis (x, y, z). + + Returns: + 3x3 rotation matrix. + """ + angles = torch.rand(3) * rot_boundary + Rx = torch.tensor( + [ + [1, 0, 0], + [0, torch.cos(angles[0]), -torch.sin(angles[0])], + [0, torch.sin(angles[0]), torch.cos(angles[0])], + ] + ) + + Ry = torch.tensor( + [ + [torch.cos(angles[1]), 0, torch.sin(angles[1])], + [0, 1, 0], + [-torch.sin(angles[1]), 0, torch.cos(angles[1])], + ] + ) + + Rz = torch.tensor( + [ + [torch.cos(angles[2]), -torch.sin(angles[2]), 0], + [torch.sin(angles[2]), torch.cos(angles[2]), 0], + [0, 0, 1], + ] + ) + + # Combined rotation matrix + R = torch.matmul(torch.matmul(Rz, Ry), Rx) + return R + + +def generate_random_translation(pos_boundary: float = 1) -> torch.Tensor: + """Generates a random translation vector. + + Args: + pos_boundary: Range for random translation values in 3D space. + + Returns: + 3-element translation vector. + """ + return ( + torch.rand(3) * 2 * pos_boundary - pos_boundary + ) # Random translation in 3D space + + +def generate_random_transformation_matrix( + pos_boundary: float = 1, rot_boundary: float = (2 * math.pi) +) -> torch.Tensor: + """Generates a random transformation matrix combining rotation and translation. + + Args: + pos_boundary: Range for random translation values. + rot_boundary: Range for random rotation angles. + + Returns: + 4x4 transformation matrix. + """ + R = generate_random_rotation(rot_boundary) + translation = generate_random_translation(pos_boundary) + + # Create the transformation matrix + T = torch.eye(4) + T[:3, :3] = R + T[:3, 3] = translation + + return T diff --git a/embodichain/utils/module_utils.py b/embodichain/utils/module_utils.py new file mode 100644 index 00000000..a62eb41c --- /dev/null +++ b/embodichain/utils/module_utils.py @@ -0,0 +1,217 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import importlib +from typing import List, Union, Optional, Callable, Any + + +def find_function_from_modules( + function_name: str, modules: List[Union[str, Any]], raise_if_not_found: bool = True +) -> Optional[Callable]: + """ + Find a function from multiple Python modules. + + Args: + function_name (str): Name of the function to find + modules (List[Union[str, Any]]): List of module names (strings) or module objects + raise_if_not_found (bool): Whether to raise an exception if function is not found + + Returns: + Optional[Callable]: The function if found, None otherwise + + Raises: + AttributeError: If function is not found and raise_if_not_found is True + ImportError: If a module cannot be imported + """ + for module in modules: + try: + # Handle both module names (strings) and module objects + if isinstance(module, str): + mod = importlib.import_module(module) + else: + mod = module + + # Check if the function exists in this module + if hasattr(mod, function_name): + return getattr(mod, function_name) + + except ImportError as e: + print(f"Warning: Could not import module {module}: {e}") + continue + + if raise_if_not_found: + raise AttributeError( + f"Function '{function_name}' not found in any of the provided modules: {modules}" + ) + + return None + + +def find_class_from_modules( + class_name: str, modules: List[Union[str, Any]], raise_if_not_found: bool = True +) -> Optional[type]: + """ + Find a class from multiple Python modules. + + Args: + class_name (str): Name of the class to find + modules (List[Union[str, Any]]): List of module names (strings) or module objects + raise_if_not_found (bool): Whether to raise an exception if class is not found + + Returns: + Optional[type]: The class if found, None otherwise + + Raises: + AttributeError: If class is not found and raise_if_not_found is True + ImportError: If a module cannot be imported + """ + for module in modules: + try: + # Handle both module names (strings) and module objects + if isinstance(module, str): + mod = importlib.import_module(module) + else: + mod = module + + # Check if the class exists in this module + if hasattr(mod, class_name): + return getattr(mod, class_name) + + except ImportError as e: + print(f"Warning: Could not import module {module}: {e}") + continue + + if raise_if_not_found: + raise AttributeError( + f"Class '{class_name}' not found in any of the provided modules: {modules}" + ) + + return None + + +def get_all_functions_from_module(module: Union[str, Any]) -> dict: + """ + Get all functions from a module. + + Args: + module (Union[str, Any]): Module name (string) or module object + + Returns: + dict: Dictionary mapping function names to function objects + """ + import inspect + + if isinstance(module, str): + mod = importlib.import_module(module) + else: + mod = module + + functions = {} + for name, obj in inspect.getmembers(mod): + if inspect.isfunction(obj): + functions[name] = obj + + return functions + + +def find_function_by_pattern( + pattern: str, modules: List[Union[str, Any]], case_sensitive: bool = True +) -> dict: + """ + Find functions matching a pattern from multiple modules. + + Args: + pattern (str): Pattern to match (supports wildcards * and ?) + modules (List[Union[str, Any]]): List of module names or module objects + case_sensitive (bool): Whether the search should be case sensitive + + Returns: + dict: Dictionary mapping module names to dictionaries of matching functions + """ + import fnmatch + + results = {} + + for module in modules: + try: + if isinstance(module, str): + mod = importlib.import_module(module) + module_name = module + else: + mod = module + module_name = mod.__name__ + + module_functions = get_all_functions_from_module(mod) + matching_functions = {} + + for func_name, func_obj in module_functions.items(): + if case_sensitive: + if fnmatch.fnmatch(func_name, pattern): + matching_functions[func_name] = func_obj + else: + if fnmatch.fnmatch(func_name.lower(), pattern.lower()): + matching_functions[func_name] = func_obj + + if matching_functions: + results[module_name] = matching_functions + + except ImportError as e: + print(f"Warning: Could not import module {module}: {e}") + continue + + return results + + +def get_all_exported_items_from_module(module: Union[str, Any]) -> List[str]: + """ + Get all exported items from a module by checking its __all__ attribute. + + Args: + module (Union[str, Any]): Module name (string) or module object + + Returns: + List[str]: List of exported item names + """ + if isinstance(module, str): + mod = importlib.import_module(module) + else: + mod = module + + if hasattr(mod, "__all__"): + return list(mod.__all__) + else: + # If __all__ is not defined, return all public attributes (not starting with _) + return [name for name in dir(mod) if not name.startswith("_")] + + +# Example usage and test functions +if __name__ == "__main__": + # Example 1: Find a specific function from multiple modules + modules_to_search = ["math", "os", "sys"] + + # Find 'sqrt' function + sqrt_func = find_function_from_modules( + "sqrt", modules_to_search, raise_if_not_found=False + ) + if sqrt_func: + print(f"Found sqrt function: {sqrt_func}") + print(f"sqrt(16) = {sqrt_func(16)}") + + # Example 2: Find functions by pattern + pattern_results = find_function_by_pattern( + "*path*", ["os", "os.path"], case_sensitive=False + ) + print(f"Functions matching '*path*': {pattern_results}") diff --git a/embodichain/utils/string.py b/embodichain/utils/string.py new file mode 100644 index 00000000..ba794f75 --- /dev/null +++ b/embodichain/utils/string.py @@ -0,0 +1,336 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# All rights reserved. +# +# This file incorporates code from the Isaac Lab Project +# Copyright (c) 2022-2025, The Isaac Lab Project Developers +# (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause +# ---------------------------------------------------------------------------- + + +import ast +import importlib +import inspect +import re +from collections.abc import Callable, Sequence +from typing import Any, Union + + +def callable_to_string(value: Callable) -> str: + """Converts a callable object to a string. + + Args: + value: A callable object. + + Raises: + ValueError: When the input argument is not a callable object. + + Returns: + A string representation of the callable object. + """ + # check if callable + if not callable(value): + raise ValueError(f"The input argument is not callable: {value}.") + # check if lambda function + if value.__name__ == "": + # we resolve the lambda expression by checking the source code and extracting the line with lambda expression + # we also remove any comments from the line + lambda_line = ( + inspect.getsourcelines(value)[0][0] + .strip() + .split("lambda")[1] + .strip() + .split(",")[0] + ) + lambda_line = re.sub(r"#.*$", "", lambda_line).rstrip() + return f"lambda {lambda_line}" + else: + # get the module and function name + module_name = value.__module__ + function_name = value.__name__ + # return the string + return f"{module_name}:{function_name}" + + +def string_to_callable(name: str) -> Callable: + """Resolves the module and function names to return the function. + + Args: + name: The function name. The format should be 'module:attribute_name' or a + lambda expression of format: 'lambda x: x'. + + Raises: + ValueError: When the resolved attribute is not a function. + ValueError: When the module cannot be found. + + Returns: + Callable: The function loaded from the module. + """ + try: + if is_lambda_expression(name): + callable_object = eval(name) + else: + mod_name, attr_name = name.split(":") + mod = importlib.import_module(mod_name) + callable_object = getattr(mod, attr_name) + # check if attribute is callable + if callable(callable_object): + return callable_object + else: + raise AttributeError(f"The imported object is not callable: '{name}'") + except (ValueError, ModuleNotFoundError) as e: + msg = ( + f"Could not resolve the input string '{name}' into callable object." + " The format of input should be 'module:attribute_name'.\n" + f"Received the error:\n {e}." + ) + raise ValueError(msg) + + +def is_regular_expression(pattern: str) -> bool: + """Checks if the input string is a valid regular expression. + Args: + pattern: The input string to check. + + Returns: + bool: True if the input string is a valid regular expression, False otherwise. + """ + try: + re.compile(pattern) + return True + except re.error: + return False + + +def remove_regex_chars(pattern: str) -> str: + """Remove common regex metacharacters from the input pattern. + Args: + pattern: The input string pattern. + + Returns: + The cleaned pattern with regex metacharacters removed. + """ + # Remove common regex metacharacters + regex_chars = r"[\.\*\+\?\[\]\(\)\{\}\^\$\|\\]" + return re.sub(regex_chars, "", pattern) + + +def is_lambda_expression(name: str) -> bool: + """Checks if the input string is a lambda expression. + + Args: + name: The input string. + + Returns: + Whether the input string is a lambda expression. + """ + try: + ast.parse(name) + return isinstance(ast.parse(name).body[0], ast.Expr) and isinstance( + ast.parse(name).body[0].value, ast.Lambda + ) + except SyntaxError: + return False + + +def resolve_matching_names( + keys: Union[str, Sequence[str]], + list_of_strings: Sequence[str], + preserve_order: bool = False, +) -> tuple[list[int], list[str]]: + """Match a list of query regular expressions against a list of strings and return the matched indices and names. + + When a list of query regular expressions is provided, the function checks each target string against each + query regular expression and returns the indices of the matched strings and the matched strings. + + If the :attr:`preserve_order` is True, the ordering of the matched indices and names is the same as the order + of the provided list of strings. This means that the ordering is dictated by the order of the target strings + and not the order of the query regular expressions. + + If the :attr:`preserve_order` is False, the ordering of the matched indices and names is the same as the order + of the provided list of query regular expressions. + + For example, consider the list of strings is ['a', 'b', 'c', 'd', 'e'] and the regular expressions are ['a|c', 'b']. + If :attr:`preserve_order` is False, then the function will return the indices of the matched strings and the + strings as: ([0, 1, 2], ['a', 'b', 'c']). When :attr:`preserve_order` is True, it will return them as: + ([0, 2, 1], ['a', 'c', 'b']). + + Note: + The function does not sort the indices. It returns the indices in the order they are found. + + Args: + keys: A regular expression or a list of regular expressions to match the strings in the list. + list_of_strings: A list of strings to match. + preserve_order: Whether to preserve the order of the query keys in the returned values. Defaults to False. + + Returns: + A tuple of lists containing the matched indices and names. + + Raises: + ValueError: When multiple matches are found for a string in the list. + ValueError: When not all regular expressions are matched. + """ + # resolve name keys + if isinstance(keys, str): + keys = [keys] + # find matching patterns + index_list = [] + names_list = [] + key_idx_list = [] + # book-keeping to check that we always have a one-to-one mapping + # i.e. each target string should match only one regular expression + target_strings_match_found = [None for _ in range(len(list_of_strings))] + keys_match_found = [[] for _ in range(len(keys))] + # loop over all target strings + for target_index, potential_match_string in enumerate(list_of_strings): + for key_index, re_key in enumerate(keys): + if re.fullmatch(re_key, potential_match_string): + # check if match already found + if target_strings_match_found[target_index]: + raise ValueError( + f"Multiple matches for '{potential_match_string}':" + f" '{target_strings_match_found[target_index]}' and '{re_key}'!" + ) + # add to list + target_strings_match_found[target_index] = re_key + index_list.append(target_index) + names_list.append(potential_match_string) + key_idx_list.append(key_index) + # add for regex key + keys_match_found[key_index].append(potential_match_string) + # reorder keys if they should be returned in order of the query keys + if preserve_order: + reordered_index_list = [None] * len(index_list) + global_index = 0 + for key_index in range(len(keys)): + for key_idx_position, key_idx_entry in enumerate(key_idx_list): + if key_idx_entry == key_index: + reordered_index_list[key_idx_position] = global_index + global_index += 1 + # reorder index and names list + index_list_reorder = [None] * len(index_list) + names_list_reorder = [None] * len(index_list) + for idx, reorder_idx in enumerate(reordered_index_list): + index_list_reorder[reorder_idx] = index_list[idx] + names_list_reorder[reorder_idx] = names_list[idx] + # update + index_list = index_list_reorder + names_list = names_list_reorder + # check that all regular expressions are matched + if not all(keys_match_found): + # make this print nicely aligned for debugging + msg = "\n" + for key, value in zip(keys, keys_match_found): + msg += f"\t{key}: {value}\n" + msg += f"Available strings: {list_of_strings}\n" + # raise error + raise ValueError( + f"Not all regular expressions are matched! Please check that the regular expressions are correct: {msg}" + ) + # return + return index_list, names_list + + +def resolve_matching_names_values( + data: dict[str, Any], list_of_strings: Sequence[str], preserve_order: bool = False +) -> tuple[list[int], list[str], list[Any]]: + """Match a list of regular expressions in a dictionary against a list of strings and return + the matched indices, names, and values. + + If the :attr:`preserve_order` is True, the ordering of the matched indices and names is the same as the order + of the provided list of strings. This means that the ordering is dictated by the order of the target strings + and not the order of the query regular expressions. + + If the :attr:`preserve_order` is False, the ordering of the matched indices and names is the same as the order + of the provided list of query regular expressions. + + For example, consider the dictionary is {"a|d|e": 1, "b|c": 2}, the list of strings is ['a', 'b', 'c', 'd', 'e']. + If :attr:`preserve_order` is False, then the function will return the indices of the matched strings, the + matched strings, and the values as: ([0, 1, 2, 3, 4], ['a', 'b', 'c', 'd', 'e'], [1, 2, 2, 1, 1]). When + :attr:`preserve_order` is True, it will return them as: ([0, 3, 4, 1, 2], ['a', 'd', 'e', 'b', 'c'], [1, 1, 1, 2, 2]). + + Args: + data: A dictionary of regular expressions and values to match the strings in the list. + list_of_strings: A list of strings to match. + preserve_order: Whether to preserve the order of the query keys in the returned values. Defaults to False. + + Returns: + A tuple of lists containing the matched indices, names, and values. + + Raises: + TypeError: When the input argument :attr:`data` is not a dictionary. + ValueError: When multiple matches are found for a string in the dictionary. + ValueError: When not all regular expressions in the data keys are matched. + + Reference: + https://github.com/isaac-sim/IsaacLab/blob/main/source/isaaclab/isaaclab/utils/string.py#L178-L271 + """ + # check valid input + if not isinstance(data, dict): + raise TypeError( + f"Input argument `data` should be a dictionary. Received: {data}" + ) + # find matching patterns + index_list = [] + names_list = [] + values_list = [] + key_idx_list = [] + # book-keeping to check that we always have a one-to-one mapping + # i.e. each target string should match only one regular expression + target_strings_match_found = [None for _ in range(len(list_of_strings))] + keys_match_found = [[] for _ in range(len(data))] + # loop over all target strings + for target_index, potential_match_string in enumerate(list_of_strings): + for key_index, (re_key, value) in enumerate(data.items()): + if re.fullmatch(re_key, potential_match_string): + # check if match already found + if target_strings_match_found[target_index]: + raise ValueError( + f"Multiple matches for '{potential_match_string}':" + f" '{target_strings_match_found[target_index]}' and '{re_key}'!" + ) + # add to list + target_strings_match_found[target_index] = re_key + index_list.append(target_index) + names_list.append(potential_match_string) + values_list.append(value) + key_idx_list.append(key_index) + # add for regex key + keys_match_found[key_index].append(potential_match_string) + # reorder keys if they should be returned in order of the query keys + if preserve_order: + reordered_index_list = [None] * len(index_list) + global_index = 0 + for key_index in range(len(data)): + for key_idx_position, key_idx_entry in enumerate(key_idx_list): + if key_idx_entry == key_index: + reordered_index_list[key_idx_position] = global_index + global_index += 1 + # reorder index and names list + index_list_reorder = [None] * len(index_list) + names_list_reorder = [None] * len(index_list) + values_list_reorder = [None] * len(index_list) + for idx, reorder_idx in enumerate(reordered_index_list): + index_list_reorder[reorder_idx] = index_list[idx] + names_list_reorder[reorder_idx] = names_list[idx] + values_list_reorder[reorder_idx] = values_list[idx] + # update + index_list = index_list_reorder + names_list = names_list_reorder + values_list = values_list_reorder + # check that all regular expressions are matched + if not all(keys_match_found): + # make this print nicely aligned for debugging + msg = "\n" + for key, value in zip(data.keys(), keys_match_found): + msg += f"\t{key}: {value}\n" + msg += f"Available strings: {list_of_strings}\n" + # raise error + raise ValueError( + f"Not all regular expressions are matched! Please check that the regular expressions are correct: {msg}" + ) + # return + return index_list, names_list, values_list diff --git a/embodichain/utils/utility.py b/embodichain/utils/utility.py new file mode 100644 index 00000000..d5506df6 --- /dev/null +++ b/embodichain/utils/utility.py @@ -0,0 +1,650 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + + +import cv2 +import pickle +import argparse +import time +import torch +import functools +import open3d as o3d +import numpy as np + +from tqdm import tqdm +from PIL import Image +from functools import wraps +from typing import Dict, List, Tuple, Optional, Callable, Any + +from embodichain.utils.string import callable_to_string + + +@functools.lru_cache(maxsize=None) # memoization +def get_func_tag(tagName): + return TagDecorator(tagName) + + +# https://stackoverflow.com/questions/41834530/how-to-make-python-decorators-work-like-a-tag-to-make-function-calls-by-tag +class TagDecorator(object): + def __init__(self, tagName): + self.functions = {} + self.tagName = tagName + + def __str__(self): + return "".format(tagName=self.tagName) + + def __call__(self, f): + class_name = f.__qualname__.split(".")[0] + if class_name in self.functions.keys(): + self.functions[class_name].update({f.__name__: f}) + else: + self.functions.update({class_name: {f.__name__: f}}) + return f + + +def set_attributes_for_class(self, params=None): + if params: + for k, v in params.items(): + if k != "self" and not k.startswith("_"): + setattr(self, k, v) + + +def timer(func): + @wraps(func) + def wrapper(*args, **kwargs): + start_time = time.time() # 记录开始时间 + result = func(*args, **kwargs) # 执行被装饰的函数 + end_time = time.time() # 记录结束时间 + elapsed_time = end_time - start_time # 计算耗时 + # log_warning( + # f"Function '{func.__name__}' executed in {elapsed_time:.4f} seconds" + # ) + return result # 返回被装饰函数的执行结果 + + return wrapper + + +from embodichain.utils.logger import log_warning, log_error + + +def snake_to_camel(name): + import re + + name = re.sub("_([a-zA-Z])", lambda m: (m.group(1).upper()), name) + name = re.sub("-+", "_", name) + return name + + +def convert_bytes(d): + if isinstance(d, dict): + return {convert_bytes(k): convert_bytes(v) for k, v in d.items()} + if isinstance(d, list): + return [convert_bytes(i) for i in d] + if isinstance(d, bytes): + return d.decode("UTF-8") + return d + + +def pad_to_chunk(x: np.ndarray, chunk_size: int) -> np.ndarray: + if x.shape[0] < chunk_size: + + if len(x.shape) <= 2: + x = np.concatenate( + [ + x, + np.tile( + x[-1:], + (chunk_size - x.shape[0], 1), + ), + ], + axis=0, + ) + elif len(x.shape) == 3 or len(x.shape) == 4: + x = np.concatenate( + [ + x, + np.tile( + x[-1:], + ( + (chunk_size - x.shape[0], 1, 1, 1) + if len(x[:1].shape) == 4 + else (chunk_size - x.shape[0], 1, 1) + ), + ), + ], + axis=0, + ) + else: + raise ValueError("Unsupported shape {}.".format(x.shape)) + + assert x.shape[0] == chunk_size, "shape {} vs chunk_size {}.".format( + x.shape, chunk_size + ) + return x + + +def dict2args(d: Dict) -> argparse.ArgumentParser: + args = argparse.Namespace(**d) + return args + + +def parser2dict(args) -> Dict: + return vars(args) + + +def change_nested_dict(dict, keys, mode: str = "update", value=None): + """ + Update or delete a nested dictionary at a specific key. + + Args: + dict (dict): The dictionary to update. + keys (tuple): Tuple of keys to the target value. + mode (str): Whether to delete or remove the given key-value pair. + value: The new value to set. + + Returns: + dict: The updated dictionary. + """ + if mode == "update": + if value is None: + log_error("The value to be updated is None, please check.") + else: + if len(keys) == 1: + dict[keys[0]] = value + else: + change_nested_dict(dict[keys[0]], keys[1:], "update", value) + elif mode == "delete": + if value is not None: + log_warning( + f"Under mode 'delete' only the keys to be removed need to be provided. But got a not-None vlaue {value}." + ) + if len(keys) == 1: + del dict[keys[0]] + else: + change_nested_dict(dict[keys[0]], keys[1:], "delete") + else: + log_error(f"Mode '{mode}; is noet realized yet.") + + return dict + + +def set_texture_to_material(material, texture: np.ndarray, env, type: str = "color"): + if type == "color": + # TODO: Currently, create texture for base color map without alpha has error. + # should be fixed in the future. + if texture.shape[-1] == 3: + texture = np.concatenate( + [texture, np.ones_like(texture[..., :1]) * 255], axis=-1 + ) + + color_texture = env.create_color_texture(texture, has_alpha=True) + if color_texture: + material.get_inst().set_base_color_map(color_texture) + else: + log_error(f"Unsupported texture type: {type}. Only 'color' is supported.") + + +def get_random_real_image(base_path: str, read: bool = True) -> np.ndarray: + import os, random + + # 随机选择一个子文件夹 + subfolders = [f.path for f in os.scandir(base_path) if f.is_dir()] + selected_subfolder = random.choice(subfolders) + + # 随机选择一个图片文件 + image_files = [ + f.path + for f in os.scandir(selected_subfolder) + if f.is_file() and f.path.endswith((".png", ".jpg", ".jpeg")) + ] + selected_image_file = random.choice(image_files) + + # 读取图片 + if read: + real_image = cv2.imread(selected_image_file) + return real_image + else: + return selected_image_file + + +def read_all_folder_images(base_path: str) -> List[np.ndarray]: + """Read all images from all subfolders under the base path. + + Args: + base_path (str): The base directory containing subfolders with images. + + Returns: + List[np.ndarray]: A list of images read from the subfolders. + """ + import os + + images = [] + # 遍历所有子文件夹 + # First, collect all image files + image_files = [] + for subdir, _, files in os.walk(base_path): + for file in files: + if file.endswith((".png", ".jpg", ".jpeg")): + image_files.append(os.path.join(subdir, file)) + + # Then process with progress bar + for image_path in tqdm(image_files, desc="Loading images"): + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + if image is not None: + images.append(image) + return images + + +def reset_all_seeds(seed: int = 0): + import torch + import random + import open3d as o3d + + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + torch.backends.cudnn.deterministic = True + o3d.utility.random.seed(seed) + + +def do_process_decorator( + pre_process: Optional[bool] = True, post_process: Optional[bool] = True +): + """A decorator to decorate :meth:`inference`. Usage and example is comming soon. + + Args: + pre_process (Optional[bool], optional): whether do pre-process. Defaults to True. + post_process (Optional[bool], optional): whether do post-process. Defaults to True. + """ + + def inner_decorator(func: Callable): + def main_wrapper(self, *args, **kwargs): + if pre_process: + input = getattr(self, "pre_process")(*args, **kwargs) + if isinstance(input, dict): + ret = func(self, input) + else: + ret = func(self, *input) + if post_process: + output = getattr(self, "post_process")(*ret) + return output + + return main_wrapper + + return inner_decorator + + +def pad_img_list(img_list, max_len): + while len(img_list) < max_len: + img_list.append(None) + + +def get_right_name(name: str): + return name + "_r" + + +def read_video(video_path: str): + video = cv2.VideoCapture(video_path) + total_frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) + length = total_frame_count + fps = video.get(cv2.CAP_PROP_FPS) + return video, fps, length + + +def create_video_writer( + video_path: str, resolution: Tuple[int, int], fps: int +) -> cv2.VideoWriter: + fourcc = cv2.VideoWriter_fourcc(*"mp4v") # 用于mp4格式的生成 + video_vis = cv2.VideoWriter( + video_path, + fourcc, + fps, + (resolution[1], resolution[0]), + ) + return video_vis + + +def update_array( + mat: np.ndarray, vec: np.ndarray, first_is_latest: bool = True +) -> np.ndarray: + if first_is_latest: + mat[1:, :] = mat[:-1, :] + mat[0, :] = vec + return mat + else: + mat[:-1, :] = mat[1:, :] + mat[-1, :] = vec + return mat + + +def save_pkl(path: str, content): + with open(path, "wb") as f: # open a text file + pickle.dump(content, f) # serialize the list + + +def load_pkl( + path: str, +): + with open(path, "rb") as f: + content = pickle.load(f) + return content + + +def save_json(path: str, data): + import json + + with open(path, "w") as f: + json.dump(data, f, indent=4) + + +def save_json(path: str, data): + import json + + with open(path, "w") as f: + json.dump(data, f, indent=4) + + +def load_json(path: str) -> Dict: + import json + + with open(path) as f: + config = json.load(f) + return config + + +def load_txt(path: str) -> str: + with open(path, "r") as f: + contents = f.read().strip() + return contents + + +def encode_image(image: np.ndarray, format: str = "png"): + import base64 + + image_encode = cv2.imencode(f".{format}", image)[1] + base64_image = base64.b64encode(image_encode).decode("utf-8") + return base64_image + + +def inv_transform(transform: np.ndarray) -> np.ndarray: + """inverse transformation + + Args: + transform (np.array): [np.array of size [4 x 4]] + + Returns: + np.array: [np.array of size [4 x 4]] + """ + r = transform[:3, :3] + t = transform[:3, 3].T + inv_r = r.T + inv_t = -inv_r @ t + inv_pose = np.eye(4, dtype=np.float32) + inv_pose[:3, :3] = inv_r + inv_pose[:3, 3] = inv_t + return inv_pose + + +def scale_image(image, scale=0.5): + import cv2 + + h, w = image.shape[:2] + if image.dtype == np.uint8: + return cv2.resize( + image, + ( + int(w * scale), + int(h * scale), + ), + ) + elif image.dtype == np.bool_: + + image = image.astype(np.uint8) + image = cv2.resize( + image, + ( + int(w * scale), + int(h * scale), + ), + ) + return image.astype(np.bool_) + + +def padding_by_longest_edge(img: np.ndarray) -> np.ndarray: + w, h, c = img.shape[:3] + e = np.maximum(w, h) + ret = np.zeros((e, e, c)).astype(img.dtype) + ret[:w, :h] = img + return ret + + +def center_crop(img: np.ndarray, dim: Tuple[int, int]) -> np.ndarray: + """Returns center cropped image + Args: + img: image to be center cropped + dim: dimensions (width, height) to be cropped + """ + width, height = img.shape[1], img.shape[0] + + # process crop width and height for max available dimension + crop_width = dim[0] if dim[0] < img.shape[1] else img.shape[1] + crop_height = dim[1] if dim[1] < img.shape[0] else img.shape[0] + mid_x, mid_y = int(width / 2), int(height / 2) + cw2, ch2 = int(crop_width / 2), int(crop_height / 2) + crop_img = img[mid_y - ch2 : mid_y + ch2, mid_x - cw2 : mid_x + cw2] + return crop_img + + +def postprocess_small_regions( + masks: np.ndarray, + min_area: int, + max_area: int, +) -> List[int]: + keep_idx = [] + n = len(masks) if isinstance(masks, list) else masks.shape[0] + for i in range(n): + area = masks[i].astype(np.uint8).sum() + keep = area > min_area and area <= max_area + if keep: + keep_idx.append(i) + return keep_idx + + +def mask_to_box(mask: np.ndarray) -> np.ndarray: + from torchvision.ops import masks_to_boxes + import torch + + bbox = ( + masks_to_boxes(torch.from_numpy(mask).unsqueeze(0)) + .squeeze(0) + .numpy() + .astype(np.int16) + ) + return bbox + + +def postprocess_small_regions( + masks: np.ndarray, + min_area: int, + max_area: int, +) -> List[int]: + keep_idx = [] + n = len(masks) if isinstance(masks, list) else masks.shape[0] + for i in range(n): + area = masks[i].astype(np.uint8).sum() + keep = area > min_area and area <= max_area + if keep: + keep_idx.append(i) + return keep_idx + + +def remove_overlap_mask( + masks: List[np.ndarray], keep_inner_threshold: float = 0.5, eps: float = 1e-5 +) -> List[int]: + keep_ids = [] + + areas = [mask.astype(np.uint8).sum() for mask in masks] + + for i, maskA in enumerate(masks): + keep = True + for j, maskB in enumerate(masks): + if i == j: + # 同一个mask,跳过 + continue + if areas[i] > areas[j]: + # 大的包裹mask不能被过滤 + continue + + # 计算交集 + intersection = (maskA * maskB).sum() + # 计算maskA的覆盖比例 + overlap_ratio = intersection / (areas[i] + eps) + # maskA被maskB覆盖的面积比例达到threshold,不保留 + if overlap_ratio >= keep_inner_threshold: + keep = False + break + + if keep: + keep_ids.append(i) + + return keep_ids + + +def encode_image_from_path(image_path: str): + import base64 + + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") + + +def check_shared_memory_exists(name): + from multiprocessing import shared_memory + + try: + shm = shared_memory.SharedMemory(name=name) + return True + except FileNotFoundError: + return False + + +def get_class_instance(module_name, class_name, *args, **kwargs): + """Get an instance of a class from a module. + + Args: + module_name (str): The name of the module to import. + class_name (str): The name of the class to instantiate. + + Returns: + object: An instance of the specified class. + """ + import importlib + + # Import the module + module = importlib.import_module(module_name) + # Get the class from the module + cls = getattr(module, class_name) + return cls + + +def key_in_nested_dict(d: Dict, key: str) -> bool: + """Check if a key exists in a nested dictionary. + + Args: + d (Dict): A dictionary that may contain nested dictionaries. + key (str): The key to search for in the dictionary. + + Returns: + bool: True if the key exists in the dictionary or any of its nested dictionaries, False otherwise. + """ + if key in d: + return True + for value in d.values(): + if isinstance(value, dict): # Check if the value is a nested dictionary + if key_in_nested_dict( + value, key + ): # Recursively check the nested dictionary + return True + return False + + +def class_to_dict(obj: object) -> dict[str, Any]: + """Convert an object into dictionary recursively. + + Note: + Ignores all names starting with "__" (i.e. built-in methods). + + Args: + obj: An instance of a class to convert. + + Raises: + ValueError: When input argument is not an object. + + Returns: + Converted dictionary mapping. + """ + # check that input data is class instance + if not hasattr(obj, "__class__"): + raise ValueError(f"Expected a class instance. Received: {type(obj)}.") + # convert object to dictionary + if isinstance(obj, dict): + obj_dict = obj + elif isinstance(obj, torch.Tensor): + # We have to treat torch tensors specially because `torch.tensor.__dict__` returns an empty + # dict, which would mean that a torch.tensor would be stored as an empty dict. Instead we + # want to store it directly as the tensor. + return obj + elif hasattr(obj, "__dict__"): + obj_dict = obj.__dict__ + else: + return obj + + # convert to dictionary + data = dict() + for key, value in obj_dict.items(): + # disregard builtin attributes + if key.startswith("__"): + continue + # check if attribute is callable -- function + if callable(value): + data[key] = callable_to_string(value) + # check if attribute is a dictionary + elif hasattr(value, "__dict__") or isinstance(value, dict): + data[key] = class_to_dict(value) + # check if attribute is a list or tuple + elif isinstance(value, (list, tuple)): + data[key] = type(value)([class_to_dict(v) for v in value]) + else: + data[key] = value + return data + + +def get_mesh_md5(mesh: o3d.t.geometry.TriangleMesh) -> str: + """get mesh md5 unique key + + Args: + mesh (o3d.geometry.TriangleMesh): mesh + + Returns: + str: mesh md5 value. + """ + import hashlib + + vert = np.array(mesh.vertex.positions.numpy(), dtype=float) + face = np.array(mesh.triangle.indices.numpy(), dtype=float) + mix = np.vstack([vert, face]) + return hashlib.md5(np.array2string(mix).encode()).hexdigest() diff --git a/embodichain/utils/visualizer.py b/embodichain/utils/visualizer.py new file mode 100644 index 00000000..0b24a362 --- /dev/null +++ b/embodichain/utils/visualizer.py @@ -0,0 +1,568 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os +import platform +import numpy as np +import matplotlib.pyplot as plt + +from matplotlib.colors import LogNorm +from matplotlib.patches import Circle, Rectangle +from matplotlib import colors as mcolors +from embodichain.utils.logger import log_error +from typing import Dict, List +from operator import sub + +from matplotlib import rc + +x_min, x_max = 0.275, 1.125 +y_min, y_max = -0.425, 0.425 +bins = 100 + + +def draw_keypoints( + rgb: np.ndarray, keypoints_2d: np.ndarray, color_dict: dict = None +) -> np.ndarray: + import cv2 + + keypoints_2d = np.nan_to_num(keypoints_2d, nan=0) + assert ( + keypoints_2d.max(0)[0] <= 1 and keypoints_2d.max(0)[1] <= 1 + ), keypoints_2d.max(0) + assert ( + keypoints_2d.min(0)[0] >= 0 and keypoints_2d.min(0)[1] >= 0 + ), keypoints_2d.min(0) + n = keypoints_2d.shape[0] + color = [(255 - i / n * 255, 0, i / n * 255) for i in range(n)] + height, width = rgb.shape[:2] + + rgb = np.copy(rgb) + + for i in range(n): + assigned_color = False + if color_dict is not None: + for key_ids, color_str in color_dict.items(): + if i in key_ids: + color[i] = tuple( + int(chl * 255) for chl in mcolors.to_rgb(color_str)[::-1] + ) + assigned_color = True + break + if not assigned_color: + log_error( + f"Once color_dict is provided, all the keypoints ought to be colored, but got {i} not colored." + ) + + # Draw the keypoint + rgb = cv2.circle( + rgb.copy(), + (int(keypoints_2d[i][0] * width), int(keypoints_2d[i][1] * height)), + 2, + color[i], + 2, + ) + + return rgb + + +def draw_action_distribution( + actions: Dict[str, np.ndarray], + indices: Dict[str, List[int]] = None, + output_path: str = None, + smooth: bool = False, + return_data: bool = False, +): + import matplotlib.pyplot as plt + from scipy.ndimage import gaussian_filter1d + + key_names = indices.keys() if indices is not None else actions.keys() + data = {} + for key_name in key_names: + qpos = ( + actions[ + :, + indices[key_name], + ] + if indices is not None + else actions[key_name] + ) + num_dim = qpos.shape[1] + min_square = int(np.ceil(np.sqrt(num_dim))) + rowcol = (min_square, min_square) + + fig, axs = plt.subplots(rowcol[0], rowcol[1], figsize=(20, 20)) + for i in range(num_dim): + row = i // rowcol[0] + col = i % rowcol[1] + ax_i = axs[row, col] if min_square != 1 else axs + ax_i.plot( + ( + qpos[:, i] + if not smooth + else gaussian_filter1d(qpos[:, i], sigma=3, axis=0, mode="nearest") + ), + marker="o", + ms=2, + ) + ax_i.set_title(f"{key_name}_{i}") + + plt.tight_layout() + data[key_name] = fig + if output_path is not None and os.path.exists(output_path): + plt.savefig( + os.path.join(output_path, "action_distribution_{}.png".format(key_name)) + ) + + if return_data: + return data + + +def draw_feature( + feature_list: List[np.ndarray], vis_images: List[np.ndarray] +) -> List[np.ndarray]: + import cv2 + from copy import deepcopy + + vis_features = [] + for feature, image in zip(feature_list, vis_images): + feature_ = cv2.resize( + feature, + (image.shape[1], image.shape[0]), + interpolation=cv2.INTER_LINEAR, + ) + image = cv2.addWeighted(deepcopy(image), 0.5, feature_, 0.5, 0) + vis_features.append(image) + return vis_features + + +class HeatMapEnv: + def __init__(self, is_success): + """Initialize the drawing environment and static elements""" + self.points = [] + self.b_fail_points = [] + self.c_fail_points = [] + self.fig, self.ax = plt.subplots(figsize=(10, 8)) + self.ax.set_aspect("equal") + + circle1 = Circle( + (0.7, 0), + radius=0.425, + fill=False, + edgecolor="red", + linewidth=2, + linestyle="--", + label="Circle Zone", + ) + circle2 = Circle( + (0.233, 0.3), + radius=0.08, + fill=False, + edgecolor="red", + linewidth=2, + linestyle="--", + label="Circle Zone", + ) + circle3 = Circle( + (0.233, -0.3), + radius=0.08, + fill=False, + edgecolor="red", + linewidth=2, + linestyle="--", + label="Circle Zone", + ) + + rectangle1 = Rectangle( + (0.67, -0.22), + 0.16, + 0.16, + angle=0, + fill=False, + edgecolor="blue", + linewidth=2, + linestyle="-.", + label="Rect Zone", + ) + + rectangle2 = Rectangle( + (0.67, 0.06), + 0.16, + 0.16, + angle=0, + fill=False, + edgecolor="green", + linewidth=2, + linestyle="-.", + label="Rect Zone", + ) + + for patch in [circle1, circle2, circle3, rectangle1, rectangle2]: + self.ax.add_patch(patch) + + self.ax.set( + xlim=(x_min, x_max), + ylim=(y_min, y_max), + xticks=np.arange(x_min, x_max + 0.01, 0.04), + yticks=np.arange(y_min, y_max + 0.01, 0.04), + ) + self.ax.grid(True, linestyle="--", alpha=0.3) + self.ax.set_title("Real-time Heatmap") + self.ax.set_xlabel("X") + self.ax.set_ylabel("Y") + + hist = np.zeros((bins, bins)) + self.im = self.ax.imshow( + hist.T, + origin="lower", + extent=[x_min, x_max, y_min, y_max], + cmap="Greys", + norm=LogNorm(vmin=0.1, vmax=10), + ) + + self.cbar = self.fig.colorbar(self.im) + self.cbar.set_label("Density") + self.is_success = is_success + if self.is_success: + text = "Success_Points_Pair: 0" + else: + text = "Bottle_Fail_Points: 0\nCup_Fail_Points: 0" + self.text_label = self.ax.text( + 0.95, + 0.95, + text, + transform=self.ax.transAxes, + fontsize=14, + color="red", + ha="right", + ) + + plt.ion() + plt.show(block=False) + plt.tight_layout() + + def update_heatmap(self, new_point, new_fail): + if self.is_success: + self.points.append(new_point) + x_coords = [p[0] for p in self.points] + y_coords = [p[1] for p in self.points] + else: + if new_fail == 0: + self.b_fail_points.append(new_point) + x_coords = [p[0] for p in self.b_fail_points] + y_coords = [p[1] for p in self.b_fail_points] + else: + self.c_fail_points.append(new_point) + x_coords = [p[0] for p in self.c_fail_points] + y_coords = [p[1] for p in self.c_fail_points] + + hist, x_edges, y_edges = np.histogram2d( + x_coords, y_coords, bins=bins, range=[[x_min, x_max], [y_min, y_max]] + ) + + self.im.set_data(hist.T) + + if self.is_success: + self.text_label.set_text(f"Success_Points_Pair: {len(self.points)/2}") + else: + if new_fail == 0: + self.text_label.set_text( + f"Bottle_Fail_Points: {len(self.b_fail_points)}\nCup_Fail_Points: {len(self.c_fail_points)}" + ) + else: + self.text_label.set_text( + f"Bottle_Fail_Points: {len(self.b_fail_points)}\nCup_Fail_Points: {len(self.c_fail_points)}" + ) + # im.autoscale() + + self.fig.canvas.draw_idle() + self.fig.canvas.flush_events() + + def save_map(self): + if self.is_success: + plt.savefig("./outputs/success_heatmap.png") + else: + plt.savefig("./outputs/fail_heatmap.png") + + +# TeX support: on Linux assume TeX in /usr/bin, on OSX check for texlive +if (platform.system() == "Darwin") and "tex" in os.getenv("PATH"): + LATEX = True +elif (platform.system() == "Linux") and os.path.isfile("/usr/bin/latex"): + LATEX = True +else: + LATEX = False + +# setup pyplot w/ tex support +if LATEX: + rc("text", usetex=True) + + +class Package: + """Encapsulation of a work package + + A work package is instantiated from a dictionary. It **has to have** + a label, astart and an end. Optionally it may contain milestones + and a color + + :arg str pkg: dictionary w/ package data name + """ + + def __init__(self, pkg): + + DEFCOLOR = "#32AEE0" + + self.label = pkg["label"] + self.start = pkg["start"] + self.end = pkg["end"] + + if self.start < 0 or self.end < 0: + raise ValueError("Package cannot begin at t < 0") + if self.start > self.end: + raise ValueError("Cannot end before started") + + try: + self.milestones = pkg["milestones"] + except KeyError: + pass + + try: + self.color = pkg["color"] + except KeyError: + self.color = DEFCOLOR + + try: + self.legend = pkg["legend"] + except KeyError: + self.legend = None + + +# https://github.com/stefanSchinkel/gantt/tree/master +class Gantt: + """Gantt + Class to render a simple Gantt chart, with optional milestones + """ + + def __init__(self, dict: Dict): + """Instantiation + + Create a new Gantt using the data in the file provided + or the sample data that came along with the script + + :arg str dataFile: file holding Gantt data + """ + + # some lists needed + self.packages = [] + self.labels = [] + + self._loadData(dict) + self._procData() + + def _loadData(self, data): + """Load data from a JSON file that has to have the keys: + packages & title. Packages is an array of objects with + a label, start and end property and optional milesstones + and color specs. + """ + + # must-haves + self.title = data["title"] + + for pkg in data["packages"]: + self.packages.append(Package(pkg)) + + self.labels = [pkg["label"] for pkg in data["packages"]] + + # optionals + self.milestones = {} + for pkg in self.packages: + try: + self.milestones[pkg.label] = pkg.milestones + except AttributeError: + pass + + try: + self.xlabel = data["xlabel"] + except KeyError: + self.xlabel = "" + try: + self.xticks = data["xticks"] + except KeyError: + self.xticks = "" + + def _procData(self): + """Process data to have all values needed for plotting""" + # parameters for bars + self.nPackages = len(self.labels) + self.start = [None] * self.nPackages + self.end = [None] * self.nPackages + + for pkg in self.packages: + idx = self.labels.index(pkg.label) + self.start[idx] = pkg.start + self.end[idx] = pkg.end + + self.durations = map(sub, self.end, self.start) + self.yPos = np.arange(self.nPackages, 0, -1) + + def format(self): + """Format various aspect of the plot, such as labels,ticks, BBox + :todo: Refactor to use a settings object + """ + # format axis + plt.tick_params( + axis="both", # format x and y + which="both", # major and minor ticks affected + bottom="on", # bottom edge ticks are on + top="off", # top, left and right edge ticks are off + left="off", + right="off", + ) + + # tighten axis but give a little room from bar height + plt.xlim(0, max(self.end)) + plt.ylim(0.5, self.nPackages + 0.5) + + # add title and package names + plt.yticks(self.yPos, [label.replace("qpos", "") for label in self.labels]) + plt.title(self.title) + + if self.xlabel: + plt.xlabel(self.xlabel) + + if self.xticks: + plt.xticks(self.xticks, map(str, self.xticks)) + + def add_milestones(self): + """Add milestones to GANTT chart. + The milestones are simple yellow diamonds + """ + + if not self.milestones: + return + + x = [] + y = [] + for key in self.milestones.keys(): + for value in self.milestones[key]: + y += [self.yPos[self.labels.index(key)]] + x += [value] + + plt.scatter( + x, y, s=120, marker="D", color="yellow", edgecolor="black", zorder=3 + ) + + def add_legend(self): + """Add a legend to the plot iff there are legend entries in + the package definitions + """ + cnt = 0 + legends = [] + for pkg in self.packages: + if pkg.legend not in legends: + cnt += 1 + idx = self.labels.index(pkg.label) + self.barlist[idx].set_label(pkg.legend) + legends.append(pkg.legend) + + if cnt > 0: + self.legend = self.ax.legend(shadow=False, ncol=3, fontsize="medium") + + def render(self): + """Prepare data for plotting""" + + # init figure + self.fig, self.ax = plt.subplots() + self.ax.yaxis.grid(False) + self.ax.xaxis.grid(True) + + # assemble colors + colors = [] + for pkg in self.packages: + colors.append(pkg.color) + + self.barlist = plt.barh( + self.yPos, + list(self.durations), + left=self.start, + align="center", + height=0.5, + alpha=1, + color=colors, + ) + + # format plot + self.format() + self.add_milestones() + self.add_legend() + + @staticmethod + def show(): + """Show the plot""" + plt.show() + + @staticmethod + def save(saveFile="img/GANTT.png"): + """Save the plot to a file. It defaults to `img/GANTT.png`. + + :arg str saveFile: file to save to + """ + plt.savefig(saveFile, bbox_inches="tight") + + +def visualize_trajectory(poses: np.ndarray): + """Visualizes a 3D trajectory and its z-axis directions. + + This function takes a series of 4x4 transformation matrices representing + poses in 3D space and visualizes the trajectory along with the z-axis + directions at each pose. + + Args: + poses (np.ndarray): A numpy array of shape (N, 4, 4), where N is the + number of poses. Each pose is a 4x4 transformation matrix. + + """ + fig = plt.figure() + ax = fig.add_subplot(111, projection="3d") + + # positions of the trajectory + positions = poses[:, :3, 3] + ax.plot( + positions[:, 0], positions[:, 1], positions[:, 2], "r-", linewidth=3, label="轨迹" + ) + + # direction of z-axis + for i in range(len(poses)): + R = poses[i, :3, :3] + t = poses[i, :3, 3] + + z_axis = R[:, 2] * 0.01 + ax.quiver( + t[0], + t[1], + t[2], + z_axis[0], + z_axis[1], + z_axis[2], + color="blue", + arrow_length_ratio=0.2, + ) + + ax.set_xlabel("X") + ax.set_ylabel("Y") + ax.set_zlabel("Z") + ax.set_box_aspect([1, 1, 1]) + plt.show() diff --git a/embodichain/utils/warp/__init__.py b/embodichain/utils/warp/__init__.py new file mode 100644 index 00000000..d3e93d7e --- /dev/null +++ b/embodichain/utils/warp/__init__.py @@ -0,0 +1,32 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from .kernels import reshape_tiled_image +from . import kinematics +from .kinematics.opw_solver import opw_fk_kernel, opw_ik_kernel +from .kinematics.warp_trajectory import ( + trajectory_get_diff_kernel, + trajectory_interpolate_kernel, + trajectory_add_origin_kernel, + get_offset_qpos_kernel, +) + +from .kinematics.interpolate import ( + pairwise_distances, + cumsum_distances, + repeat_first_point, + interpolate_along_distance, +) diff --git a/embodichain/utils/warp/kernels.py b/embodichain/utils/warp/kernels.py new file mode 100644 index 00000000..65ab8eaf --- /dev/null +++ b/embodichain/utils/warp/kernels.py @@ -0,0 +1,93 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import warp as wp +from typing import Any + + +@wp.kernel(enable_backward=False) +def reshape_tiled_image( + tiled_image_buffer: Any, + batched_image: Any, + image_height: int, + image_width: int, + num_channels: int, + num_tiles_x: int, +): + """Reshapes a tiled image into a batch of images. + + This function reshapes the input tiled image buffer into a batch of images. The input image buffer + is assumed to be tiled in the x and y directions. The output image is a batch of images with the + specified height, width, and number of channels. + + Args: + tiled_image_buffer: The input image buffer. Shape is (height * width * num_channels * num_cameras,). + batched_image: The output image. Shape is (num_cameras, height, width, num_channels). + image_width: The width of the image. + image_height: The height of the image. + num_channels: The number of channels in the image. + num_tiles_x: The number of tiles in x-direction. + """ + # get the thread id + camera_id, height_id, width_id = wp.tid() + + # resolve the tile indices + tile_x_id = camera_id % num_tiles_x + # TODO: Currently, the tiles arranged in the bottom-to-top order, which should be changed. + tile_y_id = ( + num_tiles_x - 1 - (camera_id // num_tiles_x) + ) # Adjust for bottom-to-top tiling + # compute the start index of the pixel in the tiled image buffer + pixel_start = ( + num_channels + * num_tiles_x + * image_width + * (image_height * tile_y_id + height_id) + + num_channels * tile_x_id * image_width + + num_channels * width_id + ) + + # copy the pixel values into the batched image + for i in range(num_channels): + batched_image[camera_id, height_id, width_id, i] = batched_image.dtype( + tiled_image_buffer[pixel_start + i] + ) + + +# uint32 -> int32 conversion is required for non-colored segmentation annotators +wp.overload( + reshape_tiled_image, + { + "tiled_image_buffer": wp.array(dtype=wp.uint32), + "batched_image": wp.array(dtype=wp.uint32, ndim=4), + }, +) +# uint8 is used for 4 channel annotators +wp.overload( + reshape_tiled_image, + { + "tiled_image_buffer": wp.array(dtype=wp.uint8), + "batched_image": wp.array(dtype=wp.uint8, ndim=4), + }, +) +# float32 is used for single channel annotators +wp.overload( + reshape_tiled_image, + { + "tiled_image_buffer": wp.array(dtype=wp.float32), + "batched_image": wp.array(dtype=wp.float32, ndim=4), + }, +) diff --git a/embodichain/utils/warp/kinematics/__init__.py b/embodichain/utils/warp/kinematics/__init__.py new file mode 100644 index 00000000..8fae6605 --- /dev/null +++ b/embodichain/utils/warp/kinematics/__init__.py @@ -0,0 +1,19 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from . import interpolate +from . import opw_solver +from . import warp_trajectory diff --git a/embodichain/utils/warp/kinematics/interpolate.py b/embodichain/utils/warp/kinematics/interpolate.py new file mode 100644 index 00000000..e11b088a --- /dev/null +++ b/embodichain/utils/warp/kinematics/interpolate.py @@ -0,0 +1,172 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- +import warp as wp + + +@wp.kernel +def pairwise_distances( + points: wp.array(dtype=float), # flattened: length B*N*M + distances: wp.array(dtype=float), # flattened: length B*(N-1) + B: int, + N: int, + M: int, +): + """Compute Euclidean distances between consecutive points along N using 1D flattened storage. + + Memory layout (row-major): + points(b, i, j) => b*N*M + i*M + j + distances(b, i) => b*(N-1) + i + Result: + distances[b,i] = ||points[b,i+1,:] - points[b,i,:]||_2 + """ + tid = wp.tid() + total = B * (N - 1) + if tid >= total: + return + + b = tid // (N - 1) + i = tid - b * (N - 1) + + base_points = b * N * M + s = float(0.0) + for j in range(M): + p0 = points[base_points + i * M + j] + p1 = points[base_points + (i + 1) * M + j] + d = p1 - p0 + s = s + d * d + distances[b * (N - 1) + i] = wp.sqrt(s) + + +@wp.kernel +def cumsum_distances( + distances: wp.array(dtype=float), # flattened: length B*(N-1) + cumulative: wp.array(dtype=float), # flattened: length B*N + B: int, + N: int, +): + """Compute per-batch cumulative distances with flattened indexing. + + Layout: + distances(b,i) => b*(N-1) + i + cumulative(b,i) => b*N + i + Definition: + cumulative[b,0] = 0 + cumulative[b,i] = sum_{k=0}^{i-1} distances[b,k] + """ + b = wp.tid() + if b >= B: + return + + cumulative[b * N + 0] = float(0.0) + acc = float(0.0) + for i in range(N - 1): + acc = acc + distances[b * (N - 1) + i] + cumulative[b * N + (i + 1)] = acc + + +@wp.kernel +def repeat_first_point( + points: wp.array(dtype=float), # flattened: length B*N*M (N may be 1) + out: wp.array(dtype=float), # flattened: length B*T*M + B: int, + T: int, + M: int, + N: int, +): + """Repeat the first waypoint of each batch across T samples (used when N==1). + + First point (b,j): b*N*M + j (i=0) + Output (b,t,j): b*T*M + t*M + j + """ + tid = wp.tid() + total = B * T + if tid >= total: + return + + b = tid // T + t = tid - b * T + + base_in = b * N * M # N expected 1 in usage + base_out = b * T * M + t * M + for j in range(M): + out[base_out + j] = points[base_in + j] + + +@wp.kernel +def interpolate_along_distance( + points: wp.array(dtype=float), # flattened B*N*M + cumulative: wp.array(dtype=float), # flattened B*N + out: wp.array(dtype=float), # flattened B*T*M + B: int, + N: int, + M: int, + T: int, +): + """Piecewise-linear interpolation at uniformly spaced cumulative-distance samples. + + Indexing (flattened): + points(b,i,j) => b*N*M + i*M + j + cumulative(b,i) => b*N + i + out(b,t,j) => b*T*M + t*M + j + Steps: + 1. Compute target distance new_d in [0, total_len]. + 2. Binary search cumulative to find segment [lo, hi]. + 3. Linear interpolate each dimension. + """ + tid = wp.tid() + total_threads = B * T + if tid >= total_threads: + return + + b = tid // T + t = tid - b * T + + # total path length for batch b + total_len = cumulative[b * N + (N - 1)] + + # evenly spaced target distance + new_d = float(0.0) + if T > 1: + new_d = total_len * float(t) / float(T - 1) + else: + new_d = float(0.0) + + # binary search for segment boundaries + lo = int(0) + hi = N - 1 + while (lo + 1) < hi: + mid = (lo + hi) // 2 + if cumulative[b * N + mid] <= new_d: + lo = mid + else: + hi = mid + + c_lo = cumulative[b * N + lo] + c_hi = cumulative[b * N + hi] + denom = c_hi - c_lo + + alpha = float(0.0) + if denom > float(0.0): + alpha = (new_d - c_lo) / denom + + base_points = b * N * M + base_out = b * T * M + t * M + p_lo_offset = base_points + lo * M + p_hi_offset = base_points + hi * M + for j in range(M): + p_lo = points[p_lo_offset + j] + p_hi = points[p_hi_offset + j] + out[base_out + j] = p_lo + alpha * (p_hi - p_lo) diff --git a/embodichain/utils/warp/kinematics/opw_solver.py b/embodichain/utils/warp/kinematics/opw_solver.py new file mode 100644 index 00000000..d0d7213b --- /dev/null +++ b/embodichain/utils/warp/kinematics/opw_solver.py @@ -0,0 +1,500 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import warp as wp +import numpy as np +from typing import Tuple + + +wp_vec48f = wp.types.vector(length=48, dtype=float) +wp_vec6f = wp.types.vector(length=6, dtype=float) + + +@wp.func +def normalize_to_pi(angle: float) -> float: + angle = (angle + wp.pi) % (2.0 * wp.pi) - wp.pi + return angle + + +@wp.func +def safe_acos(x: float) -> float: + return wp.acos(wp.clamp(x, -1.0, 1.0)) + + +@wp.func +def th4_th6_for_branch( + i: int, + r_: wp.mat33f, + sin1: wp.vec4f, + cos1: wp.vec4f, + s23: wp.vec4f, + c23: wp.vec4f, +) -> Tuple[float, float]: + th4_y = r_[1, 2] * cos1[i] - r_[0, 2] * sin1[i] + th4_x = ( + r_[0, 2] * c23[i] * cos1[i] + r_[1, 2] * c23[i] * sin1[i] - r_[2, 2] * s23[i] + ) + th4 = wp.atan2(th4_y, th4_x) + + th6_y = ( + r_[0, 1] * s23[i] * cos1[i] + r_[1, 1] * s23[i] * sin1[i] + r_[2, 1] * c23[i] + ) + th6_x = ( + -r_[0, 0] * s23[i] * cos1[i] - r_[1, 0] * s23[i] * sin1[i] - r_[2, 0] * c23[i] + ) + th6 = wp.atan2(th6_y, th6_x) + return th4, th6 + + +@wp.struct +class OPWparam: + a1: float + a2: float + b: float + c1: float + c2: float + c3: float + c4: float + + +@wp.func +def get_transform_err( + transform1: wp.mat44f, transform2: wp.mat44f +) -> Tuple[float, float]: + t_diff = wp.vec3f( + transform1[0, 3] - transform2[0, 3], + transform1[1, 3] - transform2[1, 3], + transform1[2, 3] - transform2[2, 3], + ) + t_err = wp.length(t_diff) + r1 = wp.mat33f( + transform1[0, 0], + transform1[0, 1], + transform1[0, 2], + transform1[1, 0], + transform1[1, 1], + transform1[1, 2], + transform1[2, 0], + transform1[2, 1], + transform1[2, 2], + ) + r2 = wp.mat33f( + transform2[0, 0], + transform2[0, 1], + transform2[0, 2], + transform2[1, 0], + transform2[1, 1], + transform2[1, 2], + transform2[2, 0], + transform2[2, 1], + transform2[2, 2], + ) + r_diff = wp.transpose(r1) * r2 + cos_value = 0.5 * (wp.trace(r_diff) - 1.0) + r_err = wp.abs(safe_acos(cos_value)) + return t_err, r_err + + +@wp.func +def opw_single_fk( + q1: float, q2: float, q3: float, q4: float, q5: float, q6: float, params: OPWparam +): + psi3 = wp.atan2(params.a2, params.c3) + k = wp.sqrt(params.a2 * params.a2 + params.c3 * params.c3) + + # Precompute q23_psi3 for better readability and reuse + q23_psi3 = q2 + q3 + psi3 + sin_q23_psi3 = wp.sin(q23_psi3) + cos_q23_psi3 = wp.cos(q23_psi3) + + cx1 = params.c2 * wp.sin(q2) + k * sin_q23_psi3 + params.a1 + cy1 = params.b + cz1 = params.c2 * wp.cos(q2) + k * cos_q23_psi3 + + cx0 = cx1 * wp.cos(q1) - cy1 * wp.sin(q1) + cy0 = cx1 * wp.sin(q1) + cy1 * wp.cos(q1) + cz0 = cz1 + params.c1 + + s1, c1 = wp.sin(q1), wp.cos(q1) + s2, c2 = wp.sin(q2), wp.cos(q2) + s3, c3 = wp.sin(q3), wp.cos(q3) + s4, c4 = wp.sin(q4), wp.cos(q4) + s5, c5 = wp.sin(q5), wp.cos(q5) + s6, c6 = wp.sin(q6), wp.cos(q6) + + r_0c = wp.mat33f( + c1 * c2 * c3 - c1 * s2 * s3, + -s1, + c1 * c2 * s3 + c1 * s2 * c3, + s1 * c2 * c3 - s1 * s2 * s3, + c1, + s1 * c2 * s3 + s1 * s2 * c3, + -s2 * c3 - c2 * s3, + 0.0, + -s2 * s3 + c2 * c3, + ) + r_ce = wp.mat33f( + c4 * c5 * c6 - s4 * s6, + -c4 * c5 * s6 - s4 * c6, + c4 * s5, + s4 * c5 * c6 + c4 * s6, + -s4 * c5 * s6 + c4 * c6, + s4 * s5, + -s5 * c6, + s5 * s6, + c5, + ) + + r_0e = r_0c * r_ce + t_0e = wp.vec3f( + cx0 + params.c4 * r_0e[0, 2], + cy0 + params.c4 * r_0e[1, 2], + cz0 + params.c4 * r_0e[2, 2], + ) + + return wp.mat44f( + r_0e[0, 0], + r_0e[0, 1], + r_0e[0, 2], + t_0e[0], + r_0e[1, 0], + r_0e[1, 1], + r_0e[1, 2], + t_0e[1], + r_0e[2, 0], + r_0e[2, 1], + r_0e[2, 2], + t_0e[2], + 0.0, + 0.0, + 0.0, + 1.0, + ) + + +@wp.kernel +def opw_fk_kernel( + qpos: wp.array(dtype=float), + ee_pose: wp.mat44f, + params: OPWparam, + offsets: wp.array(dtype=float), + sign_corrections: wp.array(dtype=float), + xpos: wp.array(dtype=float), +): + i = wp.tid() + dof = 6 + q1 = qpos[0 + i * dof] * sign_corrections[0] - offsets[0] + q2 = qpos[1 + i * dof] * sign_corrections[1] - offsets[1] + q3 = qpos[2 + i * dof] * sign_corrections[2] - offsets[2] + q4 = qpos[3 + i * dof] * sign_corrections[3] - offsets[3] + q5 = qpos[4 + i * dof] * sign_corrections[4] - offsets[4] + q6 = qpos[5 + i * dof] * sign_corrections[5] - offsets[5] + + p_0e = opw_single_fk(q1, q2, q3, q4, q5, q6, params) + result = p_0e * ee_pose + + # assign to result + for t in range(16): + xpos[t + i * 16] = result[t // 4, t % 4] + + +@wp.kernel +def opw_ik_kernel( + xpos: wp.array(dtype=float), + ee_pose_inv: wp.mat44f, + params: OPWparam, + offsets: wp.array(dtype=float), + sign_corrections: wp.array(dtype=float), + qpos: wp.array(dtype=float), + ik_valid: wp.array(dtype=int), +): + i = wp.tid() + # TODO: warp slice ? + ee_pose = ( + wp.mat44f( + xpos[i * 16 + 0], + xpos[i * 16 + 1], + xpos[i * 16 + 2], + xpos[i * 16 + 3], + xpos[i * 16 + 4], + xpos[i * 16 + 5], + xpos[i * 16 + 6], + xpos[i * 16 + 7], + xpos[i * 16 + 8], + xpos[i * 16 + 9], + xpos[i * 16 + 10], + xpos[i * 16 + 11], + xpos[i * 16 + 12], + xpos[i * 16 + 13], + xpos[i * 16 + 14], + xpos[i * 16 + 15], + ) + * ee_pose_inv + ) + r_ = wp.mat33f( + ee_pose[0, 0], + ee_pose[0, 1], + ee_pose[0, 2], + ee_pose[1, 0], + ee_pose[1, 1], + ee_pose[1, 2], + ee_pose[2, 0], + ee_pose[2, 1], + ee_pose[2, 2], + ) + rz_ = wp.vec3f(ee_pose[0, 2], ee_pose[1, 2], ee_pose[2, 2]) + t_ = wp.vec3f(ee_pose[0, 3], ee_pose[1, 3], ee_pose[2, 3]) + + # to wrist center position + c = t_ - params.c4 * rz_ + + r_xy2 = c[0] * c[0] + c[1] * c[1] + nx1_sqrt_arg = r_xy2 - params.b * params.b + nx1 = wp.sqrt(nx1_sqrt_arg) - params.a1 + + tmp1 = wp.atan2(c[1], c[0]) + tmp2 = wp.atan2(params.b, nx1 + params.a1) + theta1_i = tmp1 - tmp2 + theta1_ii = tmp1 + tmp2 - wp.pi + + tmp3 = c[2] - params.c1 + s1_2 = nx1 * nx1 + tmp3 * tmp3 + + tmp4 = nx1 + 2.0 * params.a1 + s2_2 = tmp4 * tmp4 + tmp3 * tmp3 + kappa_2 = params.a2 * params.a2 + params.c3 * params.c3 + + c2_2 = params.c2 * params.c2 + + tmp5 = s1_2 + c2_2 - kappa_2 + s1 = wp.sqrt(s1_2) + s2 = wp.sqrt(s2_2) + + # theta2 + tmp13 = safe_acos(tmp5 / (2.0 * s1 * params.c2)) + tmp14 = wp.atan2(nx1, c[2] - params.c1) + theta2_i = -tmp13 + tmp14 + theta2_ii = tmp13 + tmp14 + + tmp6 = s2_2 + c2_2 - kappa_2 + tmp15 = safe_acos(tmp6 / (2.0 * s2 * params.c2)) + tmp16 = wp.atan2(nx1 + 2.0 * params.a1, c[2] - params.c1) + theta2_iii = -tmp15 - tmp16 + theta2_iv = tmp15 - tmp16 + + # theta3 + tmp7 = s1_2 - c2_2 - kappa_2 + tmp8 = s2_2 - c2_2 - kappa_2 + tmp9 = 2.0 * params.c2 * wp.sqrt(kappa_2) + tmp10 = wp.atan2(params.a2, params.c3) + + tmp11 = safe_acos(tmp7 / tmp9) + theta3_i = tmp11 - tmp10 + theta3_ii = -tmp11 - tmp10 + + tmp12 = safe_acos(tmp8 / tmp9) + theta3_iii = tmp12 - tmp10 + theta3_iv = -tmp12 - tmp10 + + # precompute sin/cos(theta1) + theta1_i_sin = wp.sin(theta1_i) + theta1_i_cos = wp.cos(theta1_i) + theta1_ii_sin = wp.sin(theta1_ii) + theta1_ii_cos = wp.cos(theta1_ii) + + sin1 = wp.vec4f(theta1_i_sin, theta1_i_sin, theta1_ii_sin, theta1_ii_sin) + cos1 = wp.vec4f(theta1_i_cos, theta1_i_cos, theta1_ii_cos, theta1_ii_cos) + s23 = wp.vec4f( + wp.sin(theta2_i + theta3_i), + wp.sin(theta2_ii + theta3_ii), + wp.sin(theta2_iii + theta3_iii), + wp.sin(theta2_iv + theta3_iv), + ) + c23 = wp.vec4f( + wp.cos(theta2_i + theta3_i), + wp.cos(theta2_ii + theta3_ii), + wp.cos(theta2_iii + theta3_iii), + wp.cos(theta2_iv + theta3_iv), + ) + + # m for theta5 + m = wp.vec4f( + r_[0, 2] * s23[0] * cos1[0] + r_[1, 2] * s23[0] * sin1[0] + r_[2, 2] * c23[0], + r_[0, 2] * s23[1] * cos1[1] + r_[1, 2] * s23[1] * sin1[1] + r_[2, 2] * c23[1], + r_[0, 2] * s23[2] * cos1[2] + r_[1, 2] * s23[2] * sin1[2] + r_[2, 2] * c23[2], + r_[0, 2] * s23[3] * cos1[3] + r_[1, 2] * s23[3] * sin1[3] + r_[2, 2] * c23[3], + ) + theta5 = wp.vec4f( + wp.atan2(wp.sqrt(wp.clamp(1.0 - m[0] * m[0], 0.0, 1.0)), m[0]), + wp.atan2(wp.sqrt(wp.clamp(1.0 - m[1] * m[1], 0.0, 1.0)), m[1]), + wp.atan2(wp.sqrt(wp.clamp(1.0 - m[2] * m[2], 0.0, 1.0)), m[2]), + wp.atan2(wp.sqrt(wp.clamp(1.0 - m[3] * m[3], 0.0, 1.0)), m[3]), + ) + + theta4_i, theta6_i = th4_th6_for_branch(0, r_, sin1, cos1, s23, c23) + theta4_ii, theta6_ii = th4_th6_for_branch(1, r_, sin1, cos1, s23, c23) + theta4_iii, theta6_iii = th4_th6_for_branch(2, r_, sin1, cos1, s23, c23) + theta4_iv, theta6_iv = th4_th6_for_branch(3, r_, sin1, cos1, s23, c23) + theta5_i, theta5_ii, theta5_iii, theta5_iv = ( + theta5[0], + theta5[1], + theta5[2], + theta5[3], + ) + theta5_v, theta5_vi, theta5_vii, theta5_viii = ( + -theta5_i, + -theta5_ii, + -theta5_iii, + -theta5_iv, + ) + + theta4_v, theta4_vi, theta4_vii, theta4_viii = ( + theta4_i + wp.pi, + theta4_ii + wp.pi, + theta4_iii + wp.pi, + theta4_iv + wp.pi, + ) + theta6_v, theta6_vi, theta6_vii, theta6_viii = ( + theta6_i - wp.pi, + theta6_ii - wp.pi, + theta6_iii - wp.pi, + theta6_iv - wp.pi, + ) + # combine all 8 solutions + theta = wp_vec48f( + theta1_i, + theta2_i, + theta3_i, + theta4_i, + theta5_i, + theta6_i, + theta1_i, + theta2_ii, + theta3_ii, + theta4_ii, + theta5_ii, + theta6_ii, + theta1_ii, + theta2_iii, + theta3_iii, + theta4_iii, + theta5_iii, + theta6_iii, + theta1_ii, + theta2_iv, + theta3_iv, + theta4_iv, + theta5_iv, + theta6_iv, + theta1_i, + theta2_i, + theta3_i, + theta4_v, + theta5_v, + theta6_v, + theta1_i, + theta2_ii, + theta3_ii, + theta4_vi, + theta5_vi, + theta6_vi, + theta1_ii, + theta2_iii, + theta3_iii, + theta4_vii, + theta5_vii, + theta6_vii, + theta1_ii, + theta2_iv, + theta3_iv, + theta4_viii, + theta5_viii, + theta6_viii, + ) + DOF = 6 + N_SOL = 8 + # apply sign correction and offsets, and write to qpos + for j in range(N_SOL): + qpos_start = i * DOF * N_SOL + j * DOF + + for k in range(DOF): + idx = j * DOF + k + qpos[qpos_start + k] = normalize_to_pi( + (theta[idx] + offsets[k]) * sign_corrections[k] + ) + + # filter invalid solutions + check_ee_pose = opw_single_fk( + theta[j * DOF + 0], + theta[j * DOF + 1], + theta[j * DOF + 2], + theta[j * DOF + 3], + theta[j * DOF + 4], + theta[j * DOF + 5], + params, + ) + t_err, r_err = get_transform_err(check_ee_pose, ee_pose) + # mark invalid solutions (cannot pass ik check) + if t_err > 1e-2 or r_err > 1e-1: + ik_valid[i * N_SOL + j] = 0 + else: + ik_valid[i * N_SOL + j] = 1 + + +@wp.kernel +def opw_best_ik_kernel( + full_ik_result: wp.array(dtype=float), + full_ik_valid: wp.array(dtype=int), + qpos_seed: wp.array(dtype=float), + joint_weights: wp_vec6f, + best_ik_result: wp.array(dtype=float), + best_ik_valid: wp.array(dtype=int), +): + i = wp.tid() + DOF = 6 + N_SOL = 8 + + best_weighted_dis = float(1e10) + best_ids = int(-1) + for j in range(N_SOL): + is_full_valid = full_ik_valid[i * N_SOL + j] + if is_full_valid == 0: + # invalid ik result + continue + weighted_dis = 0.0 + for t in range(DOF): + weighted_dis += ( + (full_ik_result[i * N_SOL * DOF + j * DOF + t] - qpos_seed[i * DOF + t]) + * joint_weights[0] + * ( + full_ik_result[i * N_SOL * DOF + j * DOF + t] + - qpos_seed[i * DOF + t] + ) + * joint_weights[0] + ) + if weighted_dis < best_weighted_dis: + best_weighted_dis = weighted_dis + best_ids = j + if best_ids != -1: + # found best solution + best_ik_valid[i] = 1 + for k in range(DOF): + best_ik_result[i * DOF + k] = full_ik_result[ + i * N_SOL * DOF + best_ids * DOF + k + ] + else: + # no valid solution + best_ik_valid[i] = 0 diff --git a/embodichain/utils/warp/kinematics/srs_solver.py b/embodichain/utils/warp/kinematics/srs_solver.py new file mode 100644 index 00000000..4e960b09 --- /dev/null +++ b/embodichain/utils/warp/kinematics/srs_solver.py @@ -0,0 +1,789 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import warp as wp + + +@wp.func +def identity_mat44() -> wp.mat44: + # fmt: off + return wp.mat44( + 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 0.0, 1.0 + ) + # fmt: on + + +@wp.func +def identity_mat33() -> wp.mat33: + # fmt: off + return wp.mat33( + 1.0, 0.0, 0.0, + 0.0, 1.0, 0.0, + 0.0, 0.0, 1.0 + ) + # fmt: on + + +@wp.func +def safe_acos(x: float) -> float: + return wp.acos(wp.clamp(x, -0.999999, 0.999999)) + + +@wp.func +def safe_division(numerator: float, denominator: float, eps: float = 1e-10) -> float: + if wp.abs(denominator) < eps: + return 0.0 + return numerator / denominator + + +@wp.func +def skew(vec: wp.vec3) -> wp.mat33: + """ + Calculate the skew-symmetric matrix of a vector. + + Args: + vec (wp.vec3): Input vector. + + Returns: + wp.mat33: Skew-symmetric matrix. + """ + # fmt: off + return wp.mat33( + 0.0, -vec[2], vec[1], + vec[2], 0.0, -vec[0], + -vec[1], vec[0], 0.0, + ) + # fmt: on + + +@wp.func +def dh_transform(d: float, alpha: float, a: float, theta: float) -> wp.mat44: + """ + Compute the Denavit-Hartenberg transformation matrix. + + Args: + d (float): Link offset. + alpha (float): Link twist. + a (float): Link length. + theta (float): Joint angle. + + Returns: + wp.mat44: The resulting transformation matrix. + """ + ct, st = wp.cos(theta), wp.sin(theta) + ca, sa = wp.cos(alpha), wp.sin(alpha) + # fmt: off + return wp.mat44( + ct, -st * ca, st * sa, a * ct, + st, ct * ca, -ct * sa, a * st, + 0.0, sa, ca, d, + 0.0, 0.0, 0.0, 1.0 + ) + # fmt: on + + +@wp.func +def transform_pose( + target_xpos: wp.mat44, + T_b_ob_inv: wp.mat44, + T_e_oe_inv: wp.mat44, + tcp_inv: wp.mat44, +) -> wp.mat44: + """ + Transform the target pose to the TCP frame. + Args: + target_xpos (wp.mat44): The target pose matrix. + T_b_ob_inv (wp.mat44): Inverse base-to-object transform. + tcp_inv (wp.mat44): Inverse TCP transform. + T_e_oe_inv (wp.mat44): Inverse end-effector transform. + Returns: + wp.mat44: Transformed pose in TCP frame. + """ + return T_b_ob_inv @ target_xpos @ tcp_inv @ T_e_oe_inv + + +@wp.kernel +def transform_pose_kernel( + target_xpos: wp.array(dtype=wp.mat44), + T_b_ob_inv: wp.mat44, + T_e_oe_inv: wp.mat44, + tcp_inv: wp.mat44, + output: wp.array(dtype=wp.mat44), +): + """ + Transform a batch of target poses to the TCP frame. + + Args: + target_xpos (wp.array): Batch of target pose matrices. + T_b_ob_inv (wp.mat44): Inverse base-to-object transform. + tcp_inv (wp.mat44): Inverse TCP transform. + T_e_oe_inv (wp.mat44): Inverse end-effector transform. + output (wp.array): Output array for transformed poses. + """ + tid = wp.tid() + output[tid] = T_b_ob_inv @ target_xpos[tid] @ tcp_inv @ T_e_oe_inv + + +@wp.func +def calculate_arm_joint_angles( + P_s_to_w: wp.vec3, + elbow_GC4: float, + link_lengths: wp.array(dtype=float), + res: wp.array(dtype=int), + joints: wp.array(dtype=wp.vec4), + tid: int, +): + """ + Compute joint angles for a 3-DOF arm given the shoulder-to-wrist vector. + + Args: + P_s_to_w (wp.vec3): Shoulder-to-wrist vector. + elbow_GC4 (float): Elbow configuration, typically ±1. + link_lengths (wp.array): [d_bs, d_se, d_ew] for each segment length. + res (wp.array): Output success flag. + joints (wp.array): Output joint angles. + tid (int): Thread index. + """ + d_bs = link_lengths[0] + d_se = link_lengths[1] + d_ew = link_lengths[2] + + # Extract components + x, y, z = P_s_to_w.x, P_s_to_w.y, P_s_to_w.z + horizontal_distance = wp.length(wp.vec2(x, y)) + shoulder_to_wrist_length = wp.length(P_s_to_w) + + # Initialize joint values + joints_val = wp.vec4() + + # Check reachability + if shoulder_to_wrist_length < wp.abs(d_bs + d_ew): + res[tid] = 0 + joints[tid] = joints_val + return + + # Compute elbow angle + elbow_cos_angle = ( + wp.pow(shoulder_to_wrist_length, 2.0) - wp.pow(d_se, 2.0) - wp.pow(d_ew, 2.0) + ) / (2.0 * d_se * d_ew) + if wp.abs(elbow_cos_angle) > 1.0: + res[tid] = 0 + joints[tid] = joints_val + return + + joints_val[3] = elbow_GC4 * safe_acos(elbow_cos_angle) + + # Compute shoulder angle + joints_val[0] = wp.atan2(y, x) if wp.abs(z) > 1e-6 else 0.0 + + # Compute joint 2 angle + angle_phi = safe_acos( + (wp.pow(d_se, 2.0) + wp.pow(shoulder_to_wrist_length, 2.0) - wp.pow(d_ew, 2.0)) + / (2.0 * d_se * shoulder_to_wrist_length) + ) + joints_val[1] = wp.atan2(horizontal_distance, z) + elbow_GC4 * angle_phi + + # Set success flag and output joint values + res[tid] = 1 + joints[tid] = joints_val + + +@wp.func +def compute_reference_plane( + pose: wp.mat44, + elbow_GC4: float, + link_lengths: wp.array(dtype=float), + dh_params: wp.array(dtype=float), + res: wp.array(dtype=int), + plane_normal: wp.array(dtype=wp.vec3), + base_to_elbow_rotation: wp.array(dtype=wp.mat33), + joints: wp.array(dtype=wp.vec4), + tid: int, +): + """ + Compute the reference plane normal, base-to-elbow rotation, and joint angles. + + Args: + pose (wp.mat44): Target pose matrix (4x4). + elbow_GC4 (float): Elbow configuration, typically ±1. + link_lengths (wp.array): Link lengths, at least [d_bs, d_se, d_ew, d_hand]. + dh_params (wp.array): DH parameters, shape [num_joints * 4]. + res (wp.array): Output success flag. + plane_normal (wp.array): Output plane normal vector. + base_to_elbow_rotation (wp.array): Output base-to-elbow rotation matrix. + joints (wp.array): Output joint angles. + tid (int): Thread index. + """ + # Extract position and rotation + P_target = wp.vec3(pose[0, 3], pose[1, 3], pose[2, 3]) + # fmt: off + R_target = wp.mat33( + pose[0, 0], pose[0, 1], pose[0, 2], + pose[1, 0], pose[1, 1], pose[1, 2], + pose[2, 0], pose[2, 1], pose[2, 2], + ) + # fmt: on + + # Base to shoulder + P02 = wp.vec3(0.0, 0.0, link_lengths[0]) + P67 = wp.vec3(0.0, 0.0, dh_params[6 * 4 + 0]) + + # Wrist position + P06 = P_target - R_target @ P67 + # Shoulder to wrist + P26 = P06 - P02 + + # Calculate joint angles + calculate_arm_joint_angles(P26, elbow_GC4, link_lengths, res, joints, tid) + if res[tid] == 0: + plane_normal[tid] = wp.vec3() + base_to_elbow_rotation[tid] = identity_mat33() + joints[tid] = wp.vec4() + return + + # Lower arm transformation (joint 4) + T34 = dh_transform( + dh_params[3 * 4 + 0], dh_params[3 * 4 + 1], dh_params[3 * 4 + 2], 0.0 + ) + P34 = wp.vec3(T34[0, 3], T34[1, 3], T34[2, 3]) + + # Reference plane normal + v1 = wp.normalize(P34 - P02) + v2 = wp.normalize(P06 - P02) + plane_normal[tid] = wp.cross(v1, v2) + + # Compute base-to-elbow rotation + base_to_elbow_rotation[tid] = identity_mat33() + for i in range(3): + base_idx = i * 4 + T = dh_transform( + dh_params[base_idx + 0], + dh_params[base_idx + 1], + dh_params[base_idx + 2], + joints[tid][i], + ) + # fmt: off + base_to_elbow_rotation[tid] = base_to_elbow_rotation[tid] @ wp.mat33( + T[0, 0], T[0, 1], T[0, 2], + T[1, 0], T[1, 1], T[1, 2], + T[2, 0], T[2, 1], T[2, 2], + ) + # fmt: on + + res[tid] = 1 + + +@wp.kernel +def compute_fk_kernel( + joint_angles: wp.array(dtype=float), + dh_params: wp.array(dtype=float), + rotation_directions: wp.array(dtype=float), + T_b_ob: wp.mat44, + T_oe_e: wp.mat44, + tcp_transform: wp.mat44, + pose_out: wp.array(dtype=wp.mat44), + success: wp.array(dtype=int), +): + """ + Compute forward kinematics (FK) for a batch of joint states. + + Args: + joint_angles (wp.array): Array of joint angles for each target ([N * num_joints]). + dh_params (wp.array): Denavit-Hartenberg parameters for the robot + ([num_joints * 4], where each joint has [d, alpha, a, theta]). + rotation_directions (wp.array): Array of rotation direction multipliers for each joint ([num_joints]). + T_b_ob (wp.mat44): Base-to-object transformation matrix. + T_oe_e (wp.mat44): End-effector-to-object transformation matrix. + tcp_transform (wp.mat44): Tool center point (TCP) transformation matrix. + pose_out (wp.array): Output array for computed poses ([N, 4x4]). + success (wp.array): Output array indicating whether FK computation was successful ([N]). + """ + tid = wp.tid() + num_joints = rotation_directions.shape[0] + + # Initialize pose as identity matrix + pose = identity_mat44() + + # Loop through each joint and apply DH transformation + for i in range(num_joints): + base_idx = i * 4 + d = dh_params[base_idx + 0] + alpha = dh_params[base_idx + 1] + a = dh_params[base_idx + 2] + theta = dh_params[base_idx + 3] + theta += joint_angles[tid * num_joints + i] * rotation_directions[i] + T = dh_transform(d, alpha, d, theta) + pose = pose @ T + + # Apply additional transforms: base, end-effector, TCP + pose = T_b_ob @ pose @ T_oe_e @ tcp_transform + + # Output pose and set success flag + pose_out[tid] = pose + success[tid] = 1 + + +@wp.func +def frobenius_norm(mat: wp.mat44) -> float: + """ + Compute the Frobenius norm of a 4x4 matrix. + + Args: + mat (wp.mat44): Input matrix. + + Returns: + float: Frobenius norm of the matrix. + """ + norm = 0.0 + for i in range(4): + for j in range(4): + norm += wp.pow(mat[i, j], 2.0) + return wp.sqrt(norm) + + +@wp.func +def validate_fk_with_target( + q1: float, + q2: float, + q3: float, + q4: float, + q5: float, + q6: float, + q7: float, + dh_params: wp.array(dtype=float), + rotation_directions: wp.array(dtype=float), + target_xpos: wp.mat44, + tolerance: float, +) -> int: + """ + Validate if the FK result matches the target pose within a given tolerance. + + Args: + joint_angles (wp.array): Joint angles for FK computation. + dh_params (wp.array): Denavit-Hartenberg parameters. + rotation_directions (wp.array): Rotation direction multipliers for each joint. + target_xpos (wp.mat44): Target pose matrix. + tolerance (float): Allowed error tolerance for validation. + + Returns: + int: 1 if FK result matches the target pose within tolerance, 0 otherwise. + """ + num_joints = wp.int32(rotation_directions.shape[0]) + + # Initialize pose as identity matrix + pose = identity_mat44() + + # Compute FK + for i in range(num_joints): + d = dh_params[i * 4 + 0] + alpha = dh_params[i * 4 + 1] + a = dh_params[i * 4 + 2] + theta = dh_params[i * 4 + 3] + # Apply joint angle with rotation direction + if i == 0: + joint_angle = q1 + elif i == 1: + joint_angle = q2 + elif i == 2: + joint_angle = q3 + elif i == 3: + joint_angle = q4 + elif i == 4: + joint_angle = q5 + elif i == 5: + joint_angle = q6 + elif i == 6: + joint_angle = q7 + + theta += joint_angle * rotation_directions[i] + T = dh_transform(d, alpha, a, theta) + pose = pose @ T + + # Compute the Frobenius norm of the difference + pose_diff = pose - target_xpos + pose_error = frobenius_norm(pose_diff) + + # Validate against tolerance + return 1 if pose_error <= tolerance else 0 + + +# TODO: automatic gradient support +@wp.kernel +def compute_ik_kernel( + combinations: wp.array(dtype=wp.vec3), + target_xpos_list: wp.array(dtype=wp.mat44), + angles_list: wp.array(dtype=float), + qpos_limits: wp.array(dtype=wp.vec2), + configs: wp.array(dtype=wp.vec3), + dh_params: wp.array(dtype=float), + link_lengths: wp.array(dtype=float), + rotation_directions: wp.array(dtype=float), + res_arm_angles: wp.array(dtype=int), + joints_arm: wp.array(dtype=wp.vec4), + res_plane_normal: wp.array(dtype=int), + plane_normal: wp.array(dtype=wp.vec3), + base_to_elbow_rotation: wp.array(dtype=wp.mat33), + joints_plane: wp.array(dtype=wp.vec4), + success: wp.array(dtype=int), + qpos_out: wp.array(dtype=float), +): + """ + Compute inverse kinematics (IK) in parallel for multiple target poses. + + Args: + combinations (wp.array): Array of combinations, where each entry specifies + the indices of the target pose, configuration, and reference angle. + target_xpos_list (wp.array): Array of target poses (4x4 transformation matrices). + angles_list (wp.array): Array of reference angles for IK computation. + qpos_limits (wp.array): Array of joint position limits (min, max) for each joint. + configs (wp.array): Array of configuration vectors (shoulder, elbow, wrist). + dh_params (wp.array): Denavit-Hartenberg parameters for the robot. + link_lengths (wp.array): Array of link lengths for the robot arm. + rotation_directions (wp.array): Array of rotation direction multipliers for each joint. + res_arm_angles (wp.array): Output array for arm joint angle computation results. + joints_arm (wp.array): Output array for computed arm joint angles. + res_plane_normal (wp.array): Output array for plane normal computation results. + plane_normal (wp.array): Output array for computed plane normal vectors. + base_to_elbow_rotation (wp.array): Output array for base-to-elbow rotation matrices. + joints_plane (wp.array): Output array for computed joint angles in the plane. + success (wp.array): Output array indicating whether IK computation was successful. + qpos_out (wp.array): Output array for computed joint positions. + + Notes: + This kernel computes the inverse kinematics for a batch of target poses in parallel. + It validates the computed joint positions against joint limits and the target pose. + Successful solutions are stored in the output arrays. + """ + tid = wp.tid() # Thread ID (for batch processing, if needed) + + # Extract indices + target_idx = int(combinations[tid][0]) + config_idx = int(combinations[tid][1]) + angle_idx = int(combinations[tid][2]) + + # Load inputs + target_xpos = target_xpos_list[target_idx] + config = configs[config_idx] + angle_ref = angles_list[angle_idx] + + # Extract shoulder, elbow, wrist configurations + shoulder_config, elbow_config, wrist_config = config.x, config.y, config.z + + # Transform target pose (xpos_ = target_xpos @ tcp_inv @ T_e_oe_inv) + # fmt: off + P_target = wp.vec3(target_xpos[0, 3], target_xpos[1, 3], target_xpos[2, 3]) + R_target = wp.mat33( + target_xpos[0, 0], target_xpos[0, 1], target_xpos[0, 2], + target_xpos[1, 0], target_xpos[1, 1], target_xpos[1, 2], + target_xpos[2, 0], target_xpos[2, 1], target_xpos[2, 2], + ) + # fmt: on + + # Compute shoulder-to-wrist vector + P02 = wp.vec3(0.0, 0.0, link_lengths[0]) + P67 = wp.vec3(0.0, 0.0, dh_params[12]) + P06 = P_target - R_target @ P67 + P26 = P06 - P02 + + calculate_arm_joint_angles( + P26, elbow_config, link_lengths, res_arm_angles, joints_arm, tid + ) + if res_arm_angles[tid] == 0: + success[tid] = 0 + return + joints_v = joints_arm[tid] + + # fmt: off + # Calculate transformations + T34 = dh_transform( + dh_params[12], + dh_params[13], + dh_params[14], + joints_v[3], + ) + R34 = wp.mat33( + T34[0, 0], T34[0, 1], T34[0, 2], + T34[1, 0], T34[1, 1], T34[1, 2], + T34[2, 0], T34[2, 1], T34[2, 2], + ) + # fmt: on + + # Calculate reference joint angles + compute_reference_plane( + target_xpos, + elbow_config, + link_lengths, + dh_params, + res_plane_normal, + plane_normal, + base_to_elbow_rotation, + joints_plane, + tid, + ) + if res_plane_normal[tid] == 0: + success[tid] = 0 + return + + R03_o = base_to_elbow_rotation[tid] + + usw = wp.normalize(P26) + skew_usw = skew(usw) + s_psi = wp.sin(angle_ref) + c_psi = wp.cos(angle_ref) + + # Calculate shoulder joint angles (q1, q2, q3) + As = skew_usw @ R03_o + Bs = -skew_usw @ skew_usw @ R03_o + Cs = wp.outer(usw, usw) @ R03_o + R03 = ( + (skew_usw @ R03_o) * s_psi + + (-skew_usw @ skew_usw @ R03_o) * c_psi + + (wp.outer(usw, usw) @ R03_o) + ) + + # TODO: judgment shoulder singularity + q1 = wp.atan2(R03[1, 1] * shoulder_config, R03[0, 1] * shoulder_config) + q2 = safe_acos(R03[2, 1]) * shoulder_config + q3 = wp.atan2(-R03[2, 2] * shoulder_config, -R03[2, 0] * shoulder_config) + + # Calculate wrist joint angles (q5, q6, q7) + Aw = wp.transpose(R34) @ wp.transpose(As) @ R_target + Bw = wp.transpose(R34) @ wp.transpose(Bs) @ R_target + Cw = wp.transpose(R34) @ wp.transpose(Cs) @ R_target + R47 = Aw * s_psi + Bw * c_psi + Cw + + q4 = joints_v[3] + # TODO: judgment wrist singularity + q5 = wp.atan2(R47[1, 2] * wrist_config, R47[0, 2] * wrist_config) + q6 = safe_acos(R47[2, 2]) * wrist_config + q7 = wp.atan2(R47[2, 1] * wrist_config, -R47[2, 0] * wrist_config) + + out_of_limits = int(0) + + q1_val = (q1 - dh_params[3]) * rotation_directions[0] + q2_val = (q2 - dh_params[7]) * rotation_directions[1] + q3_val = (q3 - dh_params[11]) * rotation_directions[2] + q4_val = (q4 - dh_params[15]) * rotation_directions[3] + q5_val = (q5 - dh_params[19]) * rotation_directions[4] + q6_val = (q6 - dh_params[23]) * rotation_directions[5] + q7_val = (q7 - dh_params[27]) * rotation_directions[6] + + out_of_limits = int(0) + out_of_limits = out_of_limits | ( + 1 if (q1_val < qpos_limits[0][0] or q1_val > qpos_limits[0][1]) else 0 + ) + out_of_limits = out_of_limits | ( + 1 if (q2_val < qpos_limits[1][0] or q2_val > qpos_limits[1][1]) else 0 + ) + out_of_limits = out_of_limits | ( + 1 if (q3_val < qpos_limits[2][0] or q3_val > qpos_limits[2][1]) else 0 + ) + out_of_limits = out_of_limits | ( + 1 if (q4_val < qpos_limits[3][0] or q4_val > qpos_limits[3][1]) else 0 + ) + out_of_limits = out_of_limits | ( + 1 if (q5_val < qpos_limits[4][0] or q5_val > qpos_limits[4][1]) else 0 + ) + out_of_limits = out_of_limits | ( + 1 if (q6_val < qpos_limits[5][0] or q6_val > qpos_limits[5][1]) else 0 + ) + out_of_limits = out_of_limits | ( + 1 if (q7_val < qpos_limits[6][0] or q7_val > qpos_limits[6][1]) else 0 + ) + + # Check joint limits + if out_of_limits == 1: + success[tid] = 0 + return + + is_valid = validate_fk_with_target( + q1=q1_val, + q2=q2_val, + q3=q3_val, + q4=q4_val, + q5=q5_val, + q6=q6_val, + q7=q7_val, + dh_params=dh_params, + rotation_directions=rotation_directions, + target_xpos=target_xpos, + tolerance=1e-4, + ) + + # Save joint angles only if valid + if is_valid: + qpos_out[tid * 7] = q1_val + qpos_out[tid * 7 + 1] = q2_val + qpos_out[tid * 7 + 2] = q3_val + qpos_out[tid * 7 + 3] = q4_val + qpos_out[tid * 7 + 4] = q5_val + qpos_out[tid * 7 + 5] = q6_val + qpos_out[tid * 7 + 6] = q7_val + success[tid] = 1 # Mark as successful + else: + success[tid] = 0 # Mark as failed + + +@wp.kernel +def sort_ik_kernel( + qpos_out: wp.array(dtype=float), # [N * N_SOL, 7] + success: wp.array(dtype=int), # [N * N_SOL] + qpos_seed: wp.array(dtype=float), # [N, 7] + ik_weight: wp.array(dtype=float), # [7] + distances: wp.array(dtype=float), # [N, N_SOL] + indices: wp.array(dtype=int), # [N, N_SOL] + N_SOL: int, + sorted_qpos: wp.array(dtype=float), # [N, N_SOL, 7] + sorted_valid: wp.array(dtype=int), # [N, N_SOL] +): + """ + Sort inverse kinematics (IK) solutions for multiple targets based on their distances + to a seed configuration. + + Args: + qpos_out (wp.array): Array of computed joint positions for all solutions + ([N * N_SOL, 7]). + success (wp.array): Array indicating whether each solution is valid ([N * N_SOL]). + qpos_seed (wp.array): Array of seed joint positions for each target ([N, 7]). + ik_weight (wp.array): Array of weights for each joint to compute distance ([7]). + distances (wp.array): Output array to store computed distances ([N, N_SOL]). + indices (wp.array): Output array to store sorted indices ([N, N_SOL]). + N_SOL (int): Number of solutions per target. + sorted_qpos (wp.array): Output array for sorted joint positions ([N, N_SOL, 7]). + sorted_valid (wp.array): Output array for sorted validity flags ([N, N_SOL]). + """ + tid = wp.tid() # target index + + # 1. compute distances + for i in range(N_SOL): + idx = tid * N_SOL + i + valid = success[idx] + dist = 0.0 + if valid: + for j in range(7): + diff = qpos_out[idx * 7 + j] - qpos_seed[tid * 7 + j] + dist += ik_weight[j] * diff * diff + else: + dist = 1e10 + + distances[idx] = dist + indices[idx] = i + + # 2. bubble sort (only sort the N_SOL solutions for the current target) + for i in range(N_SOL): + min_idx = i + for j in range(i + 1, N_SOL): + idx_a = tid * N_SOL + min_idx + idx_b = tid * N_SOL + j + if distances[idx_b] < distances[idx_a]: + min_idx = j + # Swap + if min_idx != i: + idx_i = tid * N_SOL + i + idx_min = tid * N_SOL + min_idx + tmp_dist = distances[idx_i] + distances[idx_i] = distances[idx_min] + distances[idx_min] = tmp_dist + tmp_idx = indices[idx_i] + indices[idx_i] = indices[idx_min] + indices[idx_min] = tmp_idx + + # 3. reorder qpos_out and success according to sorted indices + for i in range(N_SOL): + src_idx = tid * N_SOL + indices[tid * N_SOL + i] + for j in range(7): + sorted_qpos[(tid * N_SOL + i) * 7 + j] = qpos_out[src_idx * 7 + j] + sorted_valid[tid * N_SOL + i] = success[src_idx] + + +@wp.kernel +def nearest_ik_kernel( + qpos_out: wp.array(dtype=float), # [N * N_SOL * 7] + success: wp.array(dtype=int), # [N * N_SOL] + qpos_seed: wp.array(dtype=float), # [N * 7] + ik_weight: wp.array(dtype=float), # [7] + N_SOL: int, + nearest_qpos: wp.array(dtype=float), # [N * 7] + nearest_valid: wp.array(dtype=int), # [N] +): + """ + Find the nearest valid inverse kinematics (IK) solution for each target. + + Args: + qpos_out (wp.array): Array of computed joint positions for all solutions + ([N * N_SOL, 7]). + success (wp.array): Array indicating whether each solution is valid ([N * N_SOL]). + qpos_seed (wp.array): Array of seed joint positions for each target ([N, 7]). + ik_weight (wp.array): Array of weights for each joint to compute distance ([7]). + N_SOL (int): Number of solutions per target. + nearest_qpos (wp.array): Output array for the nearest joint positions ([N, 7]). + nearest_valid (wp.array): Output array indicating whether a valid solution was found ([N]). + """ + + tid = wp.tid() # target index + + min_dist = float(1e20) + nearest_idx = int(-1) + + for i in range(N_SOL): + idx = tid * N_SOL + i + if success[idx]: + dist = 0.0 + for j in range(7): + diff = qpos_out[idx * 7 + j] - qpos_seed[tid * 7 + j] + dist += ik_weight[j] * diff * diff + if dist < min_dist: + min_dist = dist + nearest_idx = idx + + if nearest_idx >= 0: + for j in range(7): + nearest_qpos[tid * 7 + j] = qpos_out[nearest_idx * 7 + j] + nearest_valid[tid] = 1 + else: + for j in range(7): + nearest_qpos[tid * 7 + j] = 0.0 + nearest_valid[tid] = 0 + + +@wp.kernel +def check_success_kernel( + success_wp: wp.array(dtype=int), + num_solutions: int, + success_counts: wp.array(dtype=int), +): + """ + Count the number of successful inverse kinematics (IK) solutions for each target. + + Args: + success_wp (wp.array): Array indicating whether each solution is valid + ([N * num_solutions], where N is the number of targets). + num_solutions (int): Number of solutions per target. + success_counts (wp.array): Output array to store the count of valid solutions + for each target ([N]). + """ + tid = wp.tid() # target index + count = int(0) + + for i in range(num_solutions): + idx = tid * num_solutions + i + if success_wp[idx]: + count += 1 + + success_counts[tid] = count diff --git a/embodichain/utils/warp/kinematics/warp_trajectory.py b/embodichain/utils/warp/kinematics/warp_trajectory.py new file mode 100644 index 00000000..535136b3 --- /dev/null +++ b/embodichain/utils/warp/kinematics/warp_trajectory.py @@ -0,0 +1,180 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- +import warp as wp + + +@wp.kernel +def trajectory_get_diff_kernel( + in_trajectory: wp.array(dtype=float), + key_indices: wp.array(dtype=int), + key_frames: wp.array(dtype=float), + waypoint_num: int, + dof: int, + key_frame_num: int, + warp_trajectory: wp.array(dtype=float), +): + """warp trajectory get diff kernel + + Args: + in_trajectory (wp.array, optional): (waypoint_num * dof) of float. Input trajectory. + key_indices (wp.array, optional): (key_frame_num,) of int. Key frame indices. + key_frames (wp.array, optional): (bn * key_frame_num * dof) of float. Batch key frames. + waypoint_num (int): number of waypoints. + dof (int): number of degrees of freedom. + key_frame_num (int): number of key frames. + warp_trajectory (wp.array, optional): (bn * waypoint_num * dof) of float. Output warp trajectory. + """ + arena_id, dim = wp.tid() + + # write_diff + for i in range(key_frame_num): + key_id = key_indices[i] + warp_id = arena_id * waypoint_num * dof + key_id * dof + dim + key_frame_id = arena_id * key_frame_num * dof + i * dof + dim + in_trajectory_id = key_id * dof + dim + warp_trajectory[warp_id] = ( + key_frames[key_frame_id] - in_trajectory[in_trajectory_id] + ) + + +@wp.kernel +def trajectory_interpolate_kernel( + key_indices: wp.array(dtype=int), + waypoint_num: int, + dof: int, + key_frame_num: int, + warp_trajectory: wp.array(dtype=float), +): + """warp trajectory interpolate kernel + + Args: + key_indices (wp.array, optional): (key_frame_num,) of int. Key frame indices. + waypoint_num (int): number of waypoints. + dof (int): number of degrees of freedom. + key_frame_num (int): number of key frames. + warp_trajectory (wp.array, optional): (bn * waypoint_num * dof) of float. Output warp trajectory. + """ + arena_id, waypoint_id, dim = wp.tid() + inter_warp_id = arena_id * waypoint_num * dof + waypoint_id * dof + dim + + start_id = int(-1) + end_id = int(-1) + # find start id and end id + # assume key_indices is sorted, start from 0, end at waypoint_num - 1 + for i in range(key_frame_num): + key_id = key_indices[i] + # to the final one + if waypoint_id >= key_id: + start_id = key_id + end_id = key_indices[i + 1] + + if waypoint_id == end_id or waypoint_id == start_id: + # start | final key frame, only add to interp id + return + + if start_id == -1 or end_id == -1: + # invalid, do nothing + return + + alpha = float(waypoint_id - start_id) / float(end_id - start_id) + start_warp_id = arena_id * waypoint_num * dof + start_id * dof + dim + end_warp_id = arena_id * waypoint_num * dof + end_id * dof + dim + + warp_trajectory[inter_warp_id] = (1.0 - alpha) * warp_trajectory[ + start_warp_id + ] + alpha * warp_trajectory[end_warp_id] + + +@wp.kernel +def trajectory_add_origin_kernel( + in_trajectory: wp.array(dtype=float), + waypoint_num: int, + dof: int, + warp_trajectory: wp.array(dtype=float), +): + arena_id, waypoint_id, dim = wp.tid() + inter_warp_id = arena_id * waypoint_num * dof + waypoint_id * dof + dim + in_trajectory_id = waypoint_id * dof + dim + warp_trajectory[inter_warp_id] += in_trajectory[in_trajectory_id] + + +@wp.kernel +def get_offset_qpos_kernel( + key_obj_indices: wp.array(dtype=int), + obj_offset: wp.array(dtype=float), + key_xpos: wp.array(dtype=float), + base_xpos_inv: wp.mat44f, + n_batch: int, + n_keyframe: int, + key_xpos_offset: wp.array(dtype=float), +): + batch_id, key_id = wp.tid() + obj_idx = key_obj_indices[key_id] + obj_offset_idx = n_batch * obj_idx + batch_id + obj_offset_pose = wp.mat44f( + obj_offset[obj_offset_idx * 16 + 0], + obj_offset[obj_offset_idx * 16 + 1], + obj_offset[obj_offset_idx * 16 + 2], + obj_offset[obj_offset_idx * 16 + 3], + obj_offset[obj_offset_idx * 16 + 4], + obj_offset[obj_offset_idx * 16 + 5], + obj_offset[obj_offset_idx * 16 + 6], + obj_offset[obj_offset_idx * 16 + 7], + obj_offset[obj_offset_idx * 16 + 8], + obj_offset[obj_offset_idx * 16 + 9], + obj_offset[obj_offset_idx * 16 + 10], + obj_offset[obj_offset_idx * 16 + 11], + obj_offset[obj_offset_idx * 16 + 12], + obj_offset[obj_offset_idx * 16 + 13], + obj_offset[obj_offset_idx * 16 + 14], + obj_offset[obj_offset_idx * 16 + 15], + ) + key_xpos_single = wp.mat44f( + key_xpos[key_id * 16 + 0], + key_xpos[key_id * 16 + 1], + key_xpos[key_id * 16 + 2], + key_xpos[key_id * 16 + 3], + key_xpos[key_id * 16 + 4], + key_xpos[key_id * 16 + 5], + key_xpos[key_id * 16 + 6], + key_xpos[key_id * 16 + 7], + key_xpos[key_id * 16 + 8], + key_xpos[key_id * 16 + 9], + key_xpos[key_id * 16 + 10], + key_xpos[key_id * 16 + 11], + key_xpos[key_id * 16 + 12], + key_xpos[key_id * 16 + 13], + key_xpos[key_id * 16 + 14], + key_xpos[key_id * 16 + 15], + ) + key_xpos_offset_i = base_xpos_inv * key_xpos_single * obj_offset_pose + key_xpos_offset_idx = batch_id * n_keyframe + key_id + key_xpos_offset[key_xpos_offset_idx * 16 + 0] = key_xpos_offset_i[0][0] + key_xpos_offset[key_xpos_offset_idx * 16 + 1] = key_xpos_offset_i[0][1] + key_xpos_offset[key_xpos_offset_idx * 16 + 2] = key_xpos_offset_i[0][2] + key_xpos_offset[key_xpos_offset_idx * 16 + 3] = key_xpos_offset_i[0][3] + key_xpos_offset[key_xpos_offset_idx * 16 + 4] = key_xpos_offset_i[1][0] + key_xpos_offset[key_xpos_offset_idx * 16 + 5] = key_xpos_offset_i[1][1] + key_xpos_offset[key_xpos_offset_idx * 16 + 6] = key_xpos_offset_i[1][2] + key_xpos_offset[key_xpos_offset_idx * 16 + 7] = key_xpos_offset_i[1][3] + key_xpos_offset[key_xpos_offset_idx * 16 + 8] = key_xpos_offset_i[2][0] + key_xpos_offset[key_xpos_offset_idx * 16 + 9] = key_xpos_offset_i[2][1] + key_xpos_offset[key_xpos_offset_idx * 16 + 10] = key_xpos_offset_i[2][2] + key_xpos_offset[key_xpos_offset_idx * 16 + 11] = key_xpos_offset_i[2][3] + key_xpos_offset[key_xpos_offset_idx * 16 + 12] = key_xpos_offset_i[3][0] + key_xpos_offset[key_xpos_offset_idx * 16 + 13] = key_xpos_offset_i[3][1] + key_xpos_offset[key_xpos_offset_idx * 16 + 14] = key_xpos_offset_i[3][2] + key_xpos_offset[key_xpos_offset_idx * 16 + 15] = key_xpos_offset_i[3][3] diff --git a/examples/gym/pour_water.sh b/examples/gym/pour_water.sh new file mode 100755 index 00000000..49485f16 --- /dev/null +++ b/examples/gym/pour_water.sh @@ -0,0 +1,3 @@ +python -m embodichain.lab.scripts.run_env --gym_config configs/gym/pour_water/gym_config.json \ + --action_config configs/gym/pour_water/action_config.json \ + --filter_visual_rand diff --git a/examples/gym/run_scoop_ice.sh b/examples/gym/run_scoop_ice.sh new file mode 100755 index 00000000..269f4797 --- /dev/null +++ b/examples/gym/run_scoop_ice.sh @@ -0,0 +1 @@ +python -m embodichain.lab.scripts.run_env --gym_config configs/gym/scoop_ice/gym_config.json diff --git a/examples/sim/demo/grasp_cup_to_caffe.py b/examples/sim/demo/grasp_cup_to_caffe.py new file mode 100644 index 00000000..c1073c3f --- /dev/null +++ b/examples/sim/demo/grasp_cup_to_caffe.py @@ -0,0 +1,469 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +""" +This script demonstrates the creation and simulation of a robot with dexterous hands, +and performs a scoop ice task in a simulated environment. +""" + +import argparse +import numpy as np +import torch +from tqdm import tqdm +from typing import Union +from scipy.spatial.transform import Rotation as R +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.objects import Robot, RigidObject +from embodichain.lab.sim.cfg import ( + LightCfg, + JointDrivePropertiesCfg, + RigidObjectCfg, + RigidBodyAttributesCfg, + ArticulationCfg, +) +from embodichain.lab.sim.utility.action_utils import interpolate_with_distance_warp +from embodichain.lab.sim.shapes import MeshCfg +from embodichain.data import get_data_path +from embodichain.utils import logger + +from embodichain.lab.sim.robots.dexforce_w1.cfg import DexforceW1Cfg + + +def parse_arguments(): + """ + Parse command-line arguments to configure the simulation. + + Returns: + argparse.Namespace: Parsed arguments including number of environments and rendering options. + """ + parser = argparse.ArgumentParser( + description="Create and simulate a robot in SimulationManager" + ) + parser.add_argument( + "--num_envs", type=int, default=9, help="Number of parallel environments" + ) + parser.add_argument( + "--enable_rt", action="store_true", help="Enable ray tracing rendering" + ) + parser.add_argument("--headless", action="store_true", help="Enable headless mode") + parser.add_argument( + "--device", + type=str, + default="cpu", + help="device to run the environment on, e.g., 'cpu' or 'cuda'", + ) + return parser.parse_args() + + +def initialize_simulation(args) -> SimulationManager: + """ + Initialize the simulation environment based on the provided arguments. + + Args: + args (argparse.Namespace): Parsed command-line arguments. + + Returns: + SimulationManager: Configured simulation manager instance. + """ + config = SimulationManagerCfg( + headless=True, + sim_device=args.device, + enable_rt=args.enable_rt, + physics_dt=1.0 / 100.0, + ) + sim = SimulationManager(config) + + sim.build_multiple_arenas(args.num_envs, space=2.5) + # Set manual physics update for precise control + sim.set_manual_update(True) + + if args.enable_rt: + light = sim.add_light( + cfg=LightCfg( + uid="main_light", + color=(0.6, 0.6, 0.6), + intensity=30.0, + init_pos=(1.0, 0, 3.0), + ) + ) + + return sim + + +def create_robot(sim: SimulationManager) -> Robot: + """ + Create and configure a robot with an arm and a dexterous hand in the simulation. + + Args: + sim (SimulationManager): The simulation manager instance. + + Returns: + Robot: The configured robot instance added to the simulation. + """ + cfg = DexforceW1Cfg.from_dict( + { + "uid": "dexforce_w1", + "init_pos": [0.4, -0.5, 0.0], + } + ) + cfg.solver_cfg["left_arm"].tcp = np.array( + [ + [1.0, 0.0, 0.0, 0.012], + [0.0, 1.0, 0.0, 0.04], + [0.0, 0.0, 1.0, 0.11], + [0.0, 0.0, 0.0, 1.0], + ] + ) + cfg.solver_cfg["right_arm"].tcp = np.array( + [ + [1.0, 0.0, 0.0, 0.012], + [0.0, 1.0, 0.0, -0.04], + [0.0, 0.0, 1.0, 0.11], + [0.0, 0.0, 0.0, 1.0], + ] + ) + + cfg.init_qpos = [ + 1.0000e00, + -2.0000e00, + 1.0000e00, + 0.0000e00, + -2.6921e-05, + -2.6514e-03, + -1.5708e00, + 1.4575e00, + -7.8540e-01, + 1.2834e-01, + 1.5708e00, + -2.2310e00, + -7.8540e-01, + 1.4461e00, + -1.5708e00, + 1.6716e00, + 7.8540e-01, + 7.6745e-01, + 0.0000e00, + 3.8108e-01, + 0.0000e00, + 0.0000e00, + 0.0000e00, + 0.0000e00, + 1.5000e00, + 0.0000e00, + 0.0000e00, + 0.0000e00, + 0.0000e00, + 1.5000e00, + 6.9974e-02, + 7.3950e-02, + 6.6574e-02, + 6.0923e-02, + 0.0000e00, + 6.7342e-02, + 7.0862e-02, + 6.3684e-02, + 5.7822e-02, + 0.0000e00, + ] + return sim.add_robot(cfg=cfg) + + +def create_table(sim: SimulationManager) -> RigidObject: + """ + Create a table rigid object in the simulation. + + Args: + sim (SimulationManager): The simulation manager instance. + + Returns: + RigidObject: The table object added to the simulation. + """ + scoop_cfg = RigidObjectCfg( + uid="table", + shape=MeshCfg( + fpath=get_data_path("MultiW1Data/table_a.obj"), + ), + attrs=RigidBodyAttributesCfg( + mass=0.5, + ), + max_convex_hull_num=8, + body_type="kinematic", + init_pos=[1.1, -0.5, 0.08], + init_rot=[0.0, 0.0, 0.0], + ) + scoop = sim.add_rigid_object(cfg=scoop_cfg) + return scoop + + +def create_caffe(sim: SimulationManager) -> Robot: + """ + Create a caffe (container) articulated object in the simulation. + + Args: + sim (SimulationManager): The simulation manager instance. + + Returns: + Robot: The caffe object added to the simulation. + """ + container_cfg = ArticulationCfg( + uid="caffe", + fpath=get_data_path("MultiW1Data/cafe/cafe.urdf"), + init_pos=[1.05, -0.5, 0.79], + init_rot=[0, 0, -30], + attrs=RigidBodyAttributesCfg( + mass=1.0, + ), + drive_pros=JointDrivePropertiesCfg( + stiffness=1.0, damping=0.1, max_effort=100.0, drive_type="force" + ), + ) + container = sim.add_articulation(cfg=container_cfg) + return container + + +def create_cup(sim: SimulationManager) -> RigidObject: + """ + Create a cup rigid object in the simulation. + + Args: + sim (SimulationManager): The simulation manager instance. + + Returns: + RigidObject: The cup object added to the simulation. + """ + scoop_cfg = RigidObjectCfg( + uid="cup", + shape=MeshCfg( + fpath=get_data_path("MultiW1Data/paper_cup_2.obj"), + ), + attrs=RigidBodyAttributesCfg( + mass=0.3, + ), + max_convex_hull_num=1, + body_type="dynamic", + init_pos=[0.86, -0.76, 0.841], + init_rot=[0.0, 0.0, 0.0], + ) + scoop = sim.add_rigid_object(cfg=scoop_cfg) + return scoop + + +def create_trajectory( + sim: SimulationManager, robot: Robot, cup: RigidObject, caffe: Robot +) -> torch.Tensor: + """ + Generate a trajectory for the right arm to grasp the cup and move it to the caffe. + + Args: + sim (SimulationManager): The simulation manager instance. + robot (Robot): The robot instance. + cup (RigidObject): The cup object. + caffe (Robot): The caffe object. + + Returns: + torch.Tensor: Interpolated trajectory of shape [n_envs, n_waypoint, dof]. + """ + right_arm_ids = robot.get_joint_ids("right_arm") + hand_open_qpos = torch.tensor( + [0.0, 1.5, 0.0, 0.0, 0.0, 0.0], + dtype=torch.float32, + device=sim.device, + ) + hand_close_qpos = torch.tensor( + [0.1, 1.5, 0.3, 0.2, 0.3, 0.3], + dtype=torch.float32, + device=sim.device, + ) + + cup_position = cup.get_local_pose(to_matrix=True)[:, :3, 3] + + # grasp cup waypoint generation + rest_right_qpos = robot.get_qpos()[:, right_arm_ids] # [n_envs, dof] + right_arm_xpos = robot.compute_fk( + qpos=rest_right_qpos, name="right_arm", to_matrix=True + ) + approach_cup_relative_position = torch.tensor( + [-0.05, -0.06, 0.025], dtype=torch.float32, device=sim.device + ) + pick_cup_relative_position = torch.tensor( + [-0.03, -0.028, 0.021], dtype=torch.float32, device=sim.device + ) + + approach_xpos = right_arm_xpos.clone() + approach_xpos[:, :3, 3] = cup_position + approach_cup_relative_position + + pick_xpos = right_arm_xpos.clone() + pick_xpos[:, :3, 3] = cup_position + pick_cup_relative_position + + lift_xpos = pick_xpos.clone() + lift_xpos[:, 2, 3] += 0.07 + + # place cup to caffe waypoint generation + caffe_position = caffe.get_local_pose(to_matrix=True)[:, :3, 3] + place_cup_up_relative_position = torch.tensor( + [-0.14, -0.18, 0.13], dtype=torch.float32, device=sim.device + ) + place_cup_down_relative_position = torch.tensor( + [-0.14, -0.18, 0.09], dtype=torch.float32, device=sim.device + ) + + place_cup_up_pose = lift_xpos.clone() + place_cup_up_pose[:, :3, 3] = caffe_position + place_cup_up_relative_position + place_down_pose = lift_xpos.clone() + place_down_pose[:, :3, 3] = caffe_position + place_cup_down_relative_position + # compute ik for each waypoint + is_success, approach_qpos = robot.compute_ik( + pose=approach_xpos, joint_seed=rest_right_qpos, name="right_arm" + ) + is_success, pick_qpos = robot.compute_ik( + pose=pick_xpos, joint_seed=approach_qpos, name="right_arm" + ) + is_success, lift_qpos = robot.compute_ik( + pose=lift_xpos, joint_seed=pick_qpos, name="right_arm" + ) + is_success, place_up_qpos = robot.compute_ik( + pose=place_cup_up_pose, joint_seed=lift_qpos, name="right_arm" + ) + is_success, place_down_qpos = robot.compute_ik( + pose=place_down_pose, joint_seed=place_up_qpos, name="right_arm" + ) + + n_envs = sim.num_envs + + # combine hand and arm trajectory + arm_trajectory = torch.cat( + [ + rest_right_qpos[:, None, :], + approach_qpos[:, None, :], + pick_qpos[:, None, :], + pick_qpos[:, None, :], + lift_qpos[:, None, :], + place_up_qpos[:, None, :], + place_down_qpos[:, None, :], + place_down_qpos[:, None, :], + lift_qpos[:, None, :], + rest_right_qpos[:, None, :], + ], + dim=1, + ) + hand_trajectory = torch.cat( + [ + hand_open_qpos[None, None, :].repeat(n_envs, 1, 1), + hand_open_qpos[None, None, :].repeat(n_envs, 1, 1), + hand_open_qpos[None, None, :].repeat(n_envs, 1, 1), + hand_close_qpos[None, None, :].repeat(n_envs, 1, 1), + hand_close_qpos[None, None, :].repeat(n_envs, 1, 1), + hand_close_qpos[None, None, :].repeat(n_envs, 1, 1), + hand_close_qpos[None, None, :].repeat(n_envs, 1, 1), + hand_open_qpos[None, None, :].repeat(n_envs, 1, 1), + hand_open_qpos[None, None, :].repeat(n_envs, 1, 1), + hand_open_qpos[None, None, :].repeat(n_envs, 1, 1), + ], + dim=1, + ) + all_trajectory = torch.cat([arm_trajectory, hand_trajectory], dim=-1) + # trajetory with shape [n_envs, n_waypoint, dof] + interp_trajectory = interpolate_with_distance_warp( + trajectory=all_trajectory, interp_num=150, device=sim.device + ) + return interp_trajectory + + +def run_simulation( + sim: SimulationManager, robot: Robot, cup: RigidObject, caffe: Robot +): + """ + Execute the generated trajectory to drive the robot to complete the grasp and place task. + + Args: + sim (SimulationManager): The simulation manager instance. + robot (Robot): The robot instance. + cup (RigidObject): The cup object. + caffe (Robot): The caffe object. + """ + # [n_envs, n_waypoint, dof] + interp_trajectory = create_trajectory(sim, robot, cup, caffe) + + right_arm_ids = robot.get_joint_ids("right_arm") + right_hand_ids = robot.get_joint_ids("right_eef") + combine_ids = np.concatenate([right_arm_ids, right_hand_ids]) + n_waypoints = interp_trajectory.shape[1] + logger.log_info(f"Executing trajectory...") + for i in tqdm(range(n_waypoints)): + robot.set_qpos(interp_trajectory[:, i, :], joint_ids=combine_ids) + sim.update(step=10) + + +def apply_random_xy_perturbation( + item: Union[RigidObject, Robot], max_perturbation: float = 0.02 +): + """ + Apply random perturbation to the object's XY position. + + Args: + item (Union[RigidObject, Robot]): The object to perturb. + max_perturbation (float): Maximum perturbation magnitude. + """ + item_pose = item.get_local_pose(to_matrix=True) + item_xy = item_pose[:, :2, 3].to("cpu").numpy() + perturbation = np.random.uniform( + low=-max_perturbation, high=max_perturbation, size=item_xy.shape + ) + new_xy = item_xy + perturbation + item_pose[:, :2, 3] = torch.tensor( + new_xy, dtype=torch.float32, device=item_pose.device + ) + item.set_local_pose(item_pose) + + +def main(): + """ + Main function to demonstrate robot simulation. + + Initializes the simulation, creates the robot and objects, and performs the grasp and place task. + """ + args = parse_arguments() + sim = initialize_simulation(args) + + robot = create_robot(sim) + table = create_table(sim) + caffe = create_caffe(sim) + cup = create_cup(sim) + + # apply random perturbation + apply_random_xy_perturbation(cup, max_perturbation=0.05) + apply_random_xy_perturbation(caffe, max_perturbation=0.05) + + if not args.headless: + sim.open_window() + + run_simulation(sim, robot, cup, caffe) + + logger.log_info("\n Press Ctrl+C to exit simulation loop.") + try: + counter = 0 + while True: + counter += 1 + sim.update(step=10) + if counter % 10 == 0: + pass + + except KeyboardInterrupt: + logger.log_info("\n Exit") + + +if __name__ == "__main__": + main() diff --git a/examples/sim/demo/press_softbody.py b/examples/sim/demo/press_softbody.py new file mode 100644 index 00000000..2d353aad --- /dev/null +++ b/examples/sim/demo/press_softbody.py @@ -0,0 +1,261 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +""" +This script demonstrates the creation and simulation of a robot with dexterous hands, +and performs a scoop ice task in a simulated environment. +""" + +import argparse +import numpy as np +import time +import torch +from tqdm import tqdm +from scipy.spatial.transform import Rotation as R + +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.objects import Robot, RigidObject, RigidObjectGroup +from embodichain.lab.sim.cfg import ( + JointDrivePropertiesCfg, + RobotCfg, + URDFCfg, + RigidObjectCfg, + RigidBodyAttributesCfg, + ArticulationCfg, + RigidObjectGroupCfg, + LightCfg, +) +from embodichain.lab.sim.material import VisualMaterialCfg +from embodichain.lab.sim.utility.action_utils import interpolate_with_distance_warp +from embodichain.lab.sim.shapes import MeshCfg, CubeCfg +from embodichain.lab.sim.solvers import PytorchSolverCfg +from embodichain.data import get_data_path +from embodichain.utils import logger +from dexsim.utility.path import get_resources_data_path +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.cfg import ( + RigidBodyAttributesCfg, + SoftbodyVoxelAttributesCfg, + SoftbodyPhysicalAttributesCfg, +) +from embodichain.lab.sim.shapes import CubeCfg, MeshCfg +from embodichain.lab.sim.objects import ( + RigidObject, + RigidObjectCfg, + SoftObject, + SoftObjectCfg, +) + + +def parse_arguments(): + """ + Parse command-line arguments to configure the simulation. + + Returns: + argparse.Namespace: Parsed arguments including number of environments, device, and rendering options. + """ + parser = argparse.ArgumentParser( + description="Create and simulate a robot in SimulationManager" + ) + parser.add_argument( + "--enable_rt", action="store_true", help="Enable ray tracing rendering" + ) + return parser.parse_args() + + +def initialize_simulation(args): + """ + Initialize the simulation environment based on the provided arguments. + + Args: + args (argparse.Namespace): Parsed command-line arguments. + + Returns: + SimulationManager: Configured simulation manager instance. + """ + config = SimulationManagerCfg( + headless=True, + sim_device="cuda", + enable_rt=args.enable_rt, + physics_dt=1.0 / 100.0, + ) + sim = SimulationManager(config) + + light = sim.add_light( + cfg=LightCfg(uid="main_light", intensity=50.0, init_pos=(0, 0, 2.0)) + ) + + # Set manual physics update for precise control + sim.set_manual_update(True) + return sim + + +def create_robot(sim: SimulationManager): + """ + Create and configure a robot with an arm and a dexterous hand in the simulation. + + Args: + sim (SimulationManager): The simulation manager instance. + + Returns: + Robot: The configured robot instance added to the simulation. + """ + # Retrieve URDF paths for the robot arm and hand + ur10_urdf_path = get_data_path("UniversalRobots/UR10/UR10.urdf") + hand_urdf_path = get_data_path( + "BrainCoHandRevo1/BrainCoLeftHand/BrainCoLeftHand.urdf" + ) + + # Define transformation for attaching the hand to the arm + hand_attach_xpos = np.eye(4) + hand_attach_xpos[:3, :3] = R.from_rotvec([90, 0, 0], degrees=True).as_matrix() + + # Configure the robot with its components and control properties + cfg = RobotCfg( + uid="ur10_with_brainco", + urdf_cfg=URDFCfg( + components=[ + {"component_type": "arm", "urdf_path": ur10_urdf_path}, + ] + ), + control_parts={ + "arm": ["Joint[0-9]"], + }, + drive_pros=JointDrivePropertiesCfg( + stiffness={ + "Joint[0-9]": 1e4, + }, + damping={ + "Joint[0-9]": 1e3, + }, + max_effort={ + "Joint[0-9]": 1e5, + }, + drive_type="force", + ), + solver_cfg={ + "arm": PytorchSolverCfg( + end_link_name="ee_link", + root_link_name="base_link", + tcp=np.eye(4), + ) + }, + init_qpos=[ + 0.0, + -np.pi / 2, + -np.pi / 2, + np.pi / 2, + -np.pi / 2, + 0.0, + ], + ) + return sim.add_robot(cfg=cfg) + + +def create_soft_cow(sim: SimulationManager) -> SoftObject: + """create soft cow object in the simulation + + Args: + sim (SimulationManager): The simulation manager instance. + + Returns: + SoftObject: soft cow object + """ + cow: SoftObject = sim.add_soft_object( + cfg=SoftObjectCfg( + uid="cow", + shape=MeshCfg( + fpath=get_resources_data_path("Model", "cow", "cow2.obj"), + ), + init_pos=[0.5, 0.0, 0.3], + voxel_attr=SoftbodyVoxelAttributesCfg( + simulation_mesh_resolution=8, + maximal_edge_length=0.5, + ), + physical_attr=SoftbodyPhysicalAttributesCfg( + youngs=1e4, + poissons=0.45, + density=100, + dynamic_friction=0.1, + min_position_iters=30, + ), + ), + ) + return cow + + +def press_cow(sim: SimulationManager, robot: Robot): + """robot press cow softbody with its end link + + Args: + sim (SimulationManager): The simulation manager instance. + robot (Robot): The robot instance to be controlled. + """ + start_qpos = robot.get_qpos() + arm_ids = robot.get_joint_ids("arm") + arm_start_qpos = start_qpos[:, arm_ids] + + arm_start_xpos = robot.compute_fk(arm_start_qpos, name="arm", to_matrix=True) + press_xpos = arm_start_xpos.clone() + press_xpos[:, :3, 3] = torch.tensor([0.5, -0.1, 0.01], device=press_xpos.device) + + approach_xpos = press_xpos.clone() + approach_xpos[:, 2, 3] += 0.05 + + is_success, approach_qpos = robot.compute_ik( + approach_xpos, joint_seed=arm_start_qpos, name="arm" + ) + is_success, press_qpos = robot.compute_ik( + approach_xpos, joint_seed=arm_start_qpos, name="arm" + ) + + arm_trajectory = torch.concatenate([arm_start_qpos, approach_qpos, press_qpos]) + interp_trajectory = interpolate_with_distance_warp( + trajectory=arm_trajectory[None, :, :], interp_num=50, device=sim.device + ) + interp_trajectory = interp_trajectory[0] + for qpos in interp_trajectory: + robot.set_qpos(qpos.unsqueeze(0), joint_ids=arm_ids) + sim.update(step=5) + + +def main(): + """ + Main function to demonstrate robot simulation. + + This function initializes the simulation, creates the robot and other objects, + and performs the press softbody task. + """ + args = parse_arguments() + sim = initialize_simulation(args) + + robot = create_robot(sim) + soft_cow = create_soft_cow(sim) + sim.init_gpu_physics() + sim.open_window() + + press_cow(sim, robot) + + logger.log_info("\n Press Ctrl+C to exit simulation loop.") + try: + while True: + sim.update(step=10) + except KeyboardInterrupt: + logger.log_info("\n Exit") + + +if __name__ == "__main__": + main() diff --git a/examples/sim/demo/scoop_ice.py b/examples/sim/demo/scoop_ice.py new file mode 100644 index 00000000..941f6d6d --- /dev/null +++ b/examples/sim/demo/scoop_ice.py @@ -0,0 +1,569 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +""" +This script demonstrates the creation and simulation of a robot with dexterous hands, +and performs a scoop ice task in a simulated environment. +""" + +import argparse +import numpy as np +import time +import torch +from tqdm import tqdm +from scipy.spatial.transform import Rotation as R + +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.objects import Robot, RigidObject, RigidObjectGroup +from embodichain.lab.sim.cfg import ( + JointDrivePropertiesCfg, + RobotCfg, + URDFCfg, + RigidObjectCfg, + RigidBodyAttributesCfg, + ArticulationCfg, + RigidObjectGroupCfg, + LightCfg, +) +from embodichain.lab.sim.material import VisualMaterialCfg +from embodichain.lab.sim.utility.action_utils import interpolate_with_distance_warp +from embodichain.lab.sim.shapes import MeshCfg, CubeCfg +from embodichain.lab.sim.solvers import PytorchSolverCfg +from embodichain.data import get_data_path +from embodichain.utils import logger + + +def initialize_simulation(): + """ + Initialize the simulation environment based on the provided arguments. + + Args: + args (argparse.Namespace): Parsed command-line arguments. + + Returns: + SimulationManager: Configured simulation manager instance. + """ + config = SimulationManagerCfg( + headless=True, + sim_device="cpu", + enable_rt=True, + physics_dt=1.0 / 100.0, + ) + sim = SimulationManager(config) + + light = sim.add_light( + cfg=LightCfg(uid="main_light", intensity=50.0, init_pos=(0, 0, 2.0)) + ) + + # Set manual physics update for precise control + sim.set_manual_update(True) + return sim + + +def randomize_ice_positions(sim, ice_cubes): + """ + Randomly drop ice cubes into the container within a specified range. + + Args: + sim (SimulationManager): The simulation manager instance. + ice_cubes (RigidObjectGroup): Group of ice cube objects to be randomized. + """ + num_objs = ice_cubes.num_objects + position_low = np.array([0.65, -0.45, 0.5]) + position_high = np.array([0.55, -0.35, 0.5]) + position_random = np.random.uniform( + low=position_low, high=position_high, size=(num_objs, 3) + ) + random_drop_pose_np = np.eye(4)[None, :, :].repeat(num_objs, axis=0) + random_drop_pose_np[:, :3, 3] = position_random + + # Assign random positions to each ice cube + for i in tqdm(range(num_objs), desc="Dropping ice cubes"): + ice_cubes.set_local_pose( + pose=torch.tensor( + random_drop_pose_np[i][None, None, :, :], + dtype=torch.float32, + device=sim.device, + ), + obj_ids=[i], + ) + sim.update(step=10) + + +def create_robot(sim): + """ + Create and configure a robot with an arm and a dexterous hand in the simulation. + + Args: + sim (SimulationManager): The simulation manager instance. + + Returns: + Robot: The configured robot instance added to the simulation. + """ + # Retrieve URDF paths for the robot arm and hand + ur10_urdf_path = get_data_path("UniversalRobots/UR10/UR10.urdf") + hand_urdf_path = get_data_path( + "BrainCoHandRevo1/BrainCoLeftHand/BrainCoLeftHand.urdf" + ) + + # Define transformation for attaching the hand to the arm + hand_attach_xpos = np.eye(4) + hand_attach_xpos[:3, :3] = R.from_rotvec([90, 0, 0], degrees=True).as_matrix() + + # Configure the robot with its components and control properties + cfg = RobotCfg( + uid="ur10_with_brainco", + urdf_cfg=URDFCfg( + components=[ + {"component_type": "arm", "urdf_path": ur10_urdf_path}, + { + "component_type": "hand", + "urdf_path": hand_urdf_path, + "transform": hand_attach_xpos, + }, + ] + ), + control_parts={ + "arm": ["JOINT[0-9]"], + "hand": [ + "LEFT_HAND_THUMB1", + "LEFT_HAND_THUMB2", + "LEFT_HAND_INDEX", + "LEFT_HAND_MIDDLE", + "LEFT_HAND_RING", + "LEFT_HAND_PINKY", + ], + }, + drive_pros=JointDrivePropertiesCfg( + stiffness={"JOINT[0-9]": 1e4, "LEFT_[A-Z|_]+[0-9]?": 1e2}, + damping={"JOINT[0-9]": 1e3, "LEFT_[A-Z|_]+[0-9]?": 1e1}, + max_effort={"JOINT[0-9]": 1e5, "LEFT_[A-Z|_]+[0-9]?": 1e3}, + drive_type="force", + ), + solver_cfg={ + "arm": PytorchSolverCfg( + end_link_name="ee_link", + root_link_name="base_link", + tcp=np.eye(4), + ) + }, + init_qpos=[ + 0.0, + -np.pi / 2, + -np.pi / 2, + 2.5, + -np.pi / 2, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 1.5, + -0.00016, + -0.00010, + -0.00013, + -0.00009, + 0.0, + ], + ) + + return sim.add_robot(cfg=cfg) + + +def create_scoop(sim: SimulationManager): + """Create a scoop rigid object in the simulation.""" + scoop_cfg = RigidObjectCfg( + uid="scoop", + shape=MeshCfg( + fpath=get_data_path("ScoopIceNewEnv/scoop.ply"), + ), + attrs=RigidBodyAttributesCfg( + mass=0.5, + static_friction=0.95, + dynamic_friction=0.9, + restitution=0.01, + min_position_iters=32, + min_velocity_iters=8, + ), + max_convex_hull_num=12, + body_type="dynamic", + init_pos=[0.6, 0.0, 0.09], + init_rot=[0.0, 0.0, 0.0], + ) + scoop = sim.add_rigid_object(cfg=scoop_cfg) + return scoop + + +def create_heave_ice(sim: SimulationManager): + """Create a heave ice rigid object in the simulation. Make sure that""" + heave_ice_cfg = RigidObjectCfg( + uid="heave_ice", + shape=MeshCfg( + fpath=get_data_path("ScoopIceNewEnv/ice_mesh_small/ice_000.obj"), + ), + attrs=RigidBodyAttributesCfg( + mass=0.5, + static_friction=0.95, + dynamic_friction=0.9, + restitution=0.01, + min_position_iters=32, + min_velocity_iters=8, + ), + body_type="dynamic", + init_pos=[10, 10, 0.08], + init_rot=[0.0, 0.0, 0.0], + ) + heave_ice = sim.add_rigid_object(cfg=heave_ice_cfg) + return heave_ice + + +def create_padding_box(sim: SimulationManager): + padding_box_cfg = RigidObjectCfg( + uid="padding_box", + shape=CubeCfg( + size=[0.1, 0.16, 0.05], + ), + attrs=RigidBodyAttributesCfg( + mass=1.0, + static_friction=0.95, + dynamic_friction=0.9, + restitution=0.01, + min_position_iters=32, + min_velocity_iters=8, + ), + body_type="kinematic", + init_pos=[0.6, 0.15, 0.025], + init_rot=[0.0, 0.0, 0.0], + ) + heave_ice = sim.add_rigid_object(cfg=padding_box_cfg) + return heave_ice + + +def create_container(sim: SimulationManager): + container_cfg = ArticulationCfg( + uid="container", + fpath=get_data_path("ScoopIceNewEnv/IceContainer/ice_container.urdf"), + init_pos=[0.7, -0.4, 0.21], + init_rot=[0, 0, -90], + attrs=RigidBodyAttributesCfg( + mass=1.0, + static_friction=0.95, + dynamic_friction=0.9, + restitution=0.01, + min_position_iters=32, + min_velocity_iters=8, + ), + drive_pros=JointDrivePropertiesCfg( + stiffness=1.0, damping=0.1, max_effort=100.0, drive_type="force" + ), + ) + container = sim.add_articulation(cfg=container_cfg) + return container + + +def create_ice_cubes(sim: SimulationManager): + ice_cubes_path = get_data_path("ScoopIceNewEnv/ice_mesh_small") + cfg_dict = { + "uid": "ice_cubes", + "max_num": 300, + "folder_path": ice_cubes_path, + "ext": ".obj", + "rigid_objects": { + "obj": { + "attrs": { + "mass": 0.003, + "contact_offset": 0.001, + "rest_offset": 0, + "dynamic_friction": 0.05, + "static_friction": 0.1, + "restitution": 0.01, + "min_position_iters": 32, + "min_velocity_iters": 4, + "max_depenetration_velocity": 1.0, + }, + "shape": {"shape_type": "Mesh"}, + "init_pos": [20.0, 0, 1.0], + } + }, + } + + ice_cubes_cfg = RigidObjectGroupCfg.from_dict(cfg_dict) + ice_cubes: RigidObjectGroup = sim.add_rigid_object_group(cfg=ice_cubes_cfg) + + # Set visual material for ice cubes. + # The material below only works for ray tracing backend. + # Set ior to 1.31 and material type to "BSDF_GGX_SMITH" for better ice appearance. + ice_mat = sim.create_visual_material( + cfg=VisualMaterialCfg( + base_color=[1.0, 1.0, 1.0, 1.0], + ior=1.31, + roughness=0.05, + rt_material_type="BSDF_GGX_SMITH", + ) + ) + ice_cubes.set_visual_material(mat=ice_mat) + + return ice_cubes + + +def scoop_grasp( + sim: SimulationManager, + robot: Robot, + scoop: RigidObject, + heave_ice: RigidObject, + padding_box: RigidObject, +): + """ + Control the robot to grasp the scoop object and position the heave ice for scooping. + + Args: + sim (SimulationManager): The simulation manager instance. + robot (Robot): The robot instance to be controlled. + scoop (RigidObject): The scoop object to be grasped. + heave_ice (RigidObject): The heave ice object to be positioned. + padding_box (RigidObject): The padding box object used as a reference for positioning. + """ + rest_qpos = robot.get_qpos() + arm_ids = robot.get_joint_ids("arm") + hand_ids = robot.get_joint_ids("hand") + hand_open_qpos = torch.tensor([0.0, 1.5, 0.4, 0.4, 0.4, 0.4]) + hand_close_qpos = torch.tensor([0.4, 1.5, 1.0, 1.1, 1.1, 0.9]) + arm_rest_qpos = rest_qpos[:, arm_ids] + + # Calculate and set the drop pose for the scoop object + padding_box_pose = padding_box.get_local_pose(to_matrix=True) + scoop_drop_relative_pose = torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.115], + [0.0, 0.0, 1.0, 0.065], + [0.0, 0.0, 0.0, 1.0], + ], + dtype=torch.float32, + device=sim.device, + ) + scoop_drop_pose = torch.bmm( + padding_box_pose, + scoop_drop_relative_pose[None, :, :].repeat(sim.num_envs, 1, 1), + ) + scoop.set_local_pose(scoop_drop_pose) + + scoop_pose = scoop.get_local_pose(to_matrix=True) + + # tricky implementation + heave_ice_relative = torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, -0.13], + [0.0, 0.0, 1.0, 0.04], + [0.0, 0.0, 0.0, 1.0], + ], + dtype=torch.float32, + device=sim.device, + )[None, :, :].repeat(sim.num_envs, 1, 1) + heave_ice_pose = torch.bmm(scoop_pose, heave_ice_relative) + heave_ice.set_local_pose(heave_ice_pose) + sim.update(step=200) + + # move hand to grasp scoop + scoop_pose = scoop.get_local_pose(to_matrix=True) + grasp_scoop_pose_relative = torch.tensor( + [ + [0.00522967, 0.6788424, 0.7342653, -0.05885637], + [0.99054945, 0.0971214, -0.09684561, 0.0301468], + [-0.13705578, 0.72783256, -0.6719191, 0.1040391], + [0.0, 0.0, 0.0, 1.0], + ], + dtype=torch.float32, + device=sim.device, + )[None, :, :].repeat(sim.num_envs, 1, 1) + + grasp_scoop_pose = torch.bmm(scoop_pose, grasp_scoop_pose_relative) + pregrasp_scoop_pose = grasp_scoop_pose.clone() + pregrasp_scoop_pose[:, 2, 3] += 0.1 + is_success, pre_grasp_scoop_qpos = robot.compute_ik( + pregrasp_scoop_pose, joint_seed=arm_rest_qpos, name="arm" + ) + + is_success, grasp_scoop_qpos = robot.compute_ik( + grasp_scoop_pose, joint_seed=arm_rest_qpos, name="arm" + ) + robot.set_qpos(pre_grasp_scoop_qpos, joint_ids=arm_ids) + sim.update(step=100) + robot.set_qpos(grasp_scoop_qpos, joint_ids=arm_ids) + sim.update(step=100) + + # close hand + robot.set_qpos(hand_close_qpos[None, :].repeat(sim.num_envs, 1), joint_ids=hand_ids) + sim.update(step=100) + + # remove heave ice + remove_heave_ice_pose = torch.tensor( + [ + [1.0, 0.0, 0.0, 10.0], + [0.0, 1.0, 0.0, 10.0], + [0.0, 0.0, 1.0, 0.04], + [0.0, 0.0, 0.0, 1.0], + ], + dtype=torch.float32, + device=sim.device, + ) + heave_ice.set_local_pose(remove_heave_ice_pose[None, :, :]) + + +def scoop_ice(sim: SimulationManager, robot: Robot, scoop: RigidObject): + """ + Control the robot to perform the scoop ice task, including lifting, scooping, + and placing the ice. + + Args: + sim (SimulationManager): The simulation manager instance. + robot (Robot): The robot instance to be controlled. + scoop (RigidObject): The scoop object used for scooping ice. + """ + start_qpos = robot.get_qpos() + arm_ids = robot.get_joint_ids("arm") + hand_ids = robot.get_joint_ids("hand") + hand_open_qpos = torch.tensor([0.0, 1.5, 0.4, 0.4, 0.4, 0.4]) + hand_close_qpos = torch.tensor([0.4, 1.5, 1.0, 1.1, 1.1, 0.9]) + arm_start_qpos = start_qpos[:, arm_ids] + + # lift + arm_start_xpos = robot.compute_fk(arm_start_qpos, name="arm", to_matrix=True) + arm_lift_xpos = arm_start_xpos.clone() + arm_lift_xpos[:, 2, 3] += 0.45 + is_success, arm_lift_qpos = robot.compute_ik( + arm_lift_xpos, joint_seed=arm_start_qpos, name="arm" + ) + + # apply 45 degree wrist rotation + wrist_rotation = R.from_euler("X", 45, degrees=True).as_matrix() + arm_lift_rotation = arm_lift_xpos[0, :3, :3].to("cpu").numpy() + new_rotation = wrist_rotation @ arm_lift_rotation + arm_lift_xpos_rotated = arm_lift_xpos.clone() + arm_lift_xpos_rotated[:, :3, :3] = torch.tensor( + new_rotation, dtype=torch.float32, device=sim.device + ) + arm_lift_xpos_rotated[:, :3, 3] = torch.tensor( + [0.5, -0.2, 0.55], dtype=torch.float32, device=sim.device + ) + is_success, arm_lift_qpos_rotated = robot.compute_ik( + arm_lift_xpos_rotated, joint_seed=arm_lift_qpos, name="arm" + ) + + # into container + scoop_dis = 0.252 + scoop_offset = scoop_dis * torch.tensor( + [0.0, -0.58123819, -0.81373347], dtype=torch.float32, device=sim.device + ) + arm_into_container_xpos = arm_lift_xpos_rotated.clone() + arm_into_container_xpos[:, :3, 3] = arm_into_container_xpos[:, :3, 3] + scoop_offset + is_success, arm_into_container_qpos = robot.compute_ik( + arm_into_container_xpos, joint_seed=arm_lift_qpos_rotated, name="arm" + ) + + # apply -60 degree wrist rotation + arm_into_container_rotation = arm_into_container_xpos[0, :3, :3].to("cpu").numpy() + wrist_rotation = R.from_euler("X", -60, degrees=True).as_matrix() + new_rotation = wrist_rotation @ arm_into_container_rotation + arm_scoop_xpos = arm_into_container_xpos.clone() + arm_scoop_xpos[:, :3, :3] = torch.tensor( + new_rotation, dtype=torch.float32, device=sim.device + ) + is_success, arm_scoop_qpos = robot.compute_ik( + arm_scoop_xpos, joint_seed=arm_into_container_qpos, name="arm" + ) + + # minor lift + arm_scoop_xpos[:, 2, 3] += 0.15 + is_success, arm_scoop_lift_qpos = robot.compute_ik( + arm_scoop_xpos, joint_seed=arm_scoop_qpos, name="arm" + ) + + # pack arm and hand trajectory + arm_trajectory = torch.concatenate( + [ + arm_start_qpos, + arm_lift_qpos, + arm_lift_qpos_rotated, + arm_into_container_qpos, + arm_scoop_qpos, + arm_scoop_lift_qpos, + ] + ) + + hand_trajectory = torch.vstack( + [ + hand_close_qpos, + hand_close_qpos, + hand_close_qpos, + hand_close_qpos, + hand_close_qpos, + hand_close_qpos, + ] + ) + + all_trajectory = torch.hstack([arm_trajectory, hand_trajectory]) + interp_trajectory = interpolate_with_distance_warp( + trajectory=all_trajectory[None, :, :], interp_num=200, device=sim.device + ) + interp_trajectory = interp_trajectory[0] + # run trajectory + arm_ids = robot.get_joint_ids("arm") + hand_ids = robot.get_joint_ids("hand") + combine_ids = np.concatenate([arm_ids, hand_ids]) + for qpos in interp_trajectory: + robot.set_qpos(qpos.unsqueeze(0), joint_ids=combine_ids) + sim.update(step=10) + + +def main(): + """ + Main function to demonstrate robot simulation. + + This function initializes the simulation, creates the robot and other objects, + and performs the scoop ice task. + """ + sim = initialize_simulation() + + # Create simulation objects + robot = create_robot(sim) + container = create_container(sim) + padding_box = create_padding_box(sim) + scoop = create_scoop(sim) + heave_ice = create_heave_ice(sim) + ice_cubes = create_ice_cubes(sim) + + sim.open_window() + + # Randomize ice positions + randomize_ice_positions(sim, ice_cubes) + + # Perform tasks + scoop_grasp(sim, robot, scoop, heave_ice, padding_box) + scoop_ice(sim, robot, scoop) + + logger.log_info("\n Press Ctrl+C to exit simulation loop.") + try: + while True: + # sim.update(step=10) + time.sleep(1e-2) + except KeyboardInterrupt: + logger.log_info("\n Exit") + + +if __name__ == "__main__": + main() diff --git a/examples/sim/gizmo/gizmo_camera.py b/examples/sim/gizmo/gizmo_camera.py new file mode 100644 index 00000000..2b782a08 --- /dev/null +++ b/examples/sim/gizmo/gizmo_camera.py @@ -0,0 +1,236 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- +""" +This script demonstrates how to use the Gizmo class for interactive camera control. +It shows how to create a gizmo attached to a camera for real-time pose manipulation. +""" + +import argparse +import cv2 +import numpy as np +import time +import torch + +torch.set_printoptions(precision=4, sci_mode=False) + +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.sensors import Camera, CameraCfg +from embodichain.lab.sim.cfg import RigidObjectCfg, RigidBodyAttributesCfg +from embodichain.lab.sim.shapes import CubeCfg +from embodichain.utils import logger + + +def main(): + """Main function to demonstrate camera gizmo manipulation.""" + + # Parse command line arguments + parser = argparse.ArgumentParser( + description="Create and simulate a camera with gizmo in SimulationManager" + ) + parser.add_argument( + "--num_envs", type=int, default=1, help="Number of environments to simulate" + ) + parser.add_argument( + "--device", + type=str, + default="cpu", + choices=["cpu", "cuda"], + help="Device to run simulation on", + ) + parser.add_argument("--headless", action="store_true", help="Run in headless mode") + parser.add_argument( + "--enable_rt", + action="store_true", + default=False, + help="Enable ray tracing for better visuals", + ) + args = parser.parse_args() + + # Configure the simulation + sim_cfg = SimulationManagerCfg( + width=1920, + height=1080, + physics_dt=1.0 / 100.0, + sim_device=args.device, + enable_rt=args.enable_rt, + ) + + # Create simulation context + sim = SimulationManager(sim_cfg) + sim.set_manual_update(False) + + # Build multiple arenas if requested + if args.num_envs > 1: + sim.build_multiple_arenas(args.num_envs, space=3.0) + + # Add some objects to the scene for camera to observe + for i in range(5): + cube_cfg = RigidObjectCfg( + uid=f"cube_{i}", + shape=CubeCfg(size=[0.1, 0.1, 0.1]), + body_type="dynamic", + attrs=RigidBodyAttributesCfg( + mass=1.0, + dynamic_friction=0.5, + static_friction=0.5, + restitution=0.3, + ), + init_pos=[0.5 + i * 0.3, 0.0, 0.5], + ) + sim.add_rigid_object(cfg=cube_cfg) + + # Create camera configuration + camera_cfg = CameraCfg( + uid="gizmo_camera", + width=640, + height=480, + intrinsics=(320, 320, 320, 240), # fx, fy, cx, cy + near=0.1, + far=10.0, + enable_color=True, + enable_depth=True, + extrinsics=CameraCfg.ExtrinsicsCfg( + eye=(2.0, 2.0, 2.0), + target=(0.0, 0.0, 0.0), + up=(0.0, 0.0, 1.0), + ), + ) + + # Add camera to simulation + camera = sim.add_sensor(sensor_cfg=camera_cfg) + + # Wait for initialization + time.sleep(0.2) + + # Enable gizmo for interactive camera control using the new unified API + sim.enable_gizmo(uid="gizmo_camera") + if not sim.has_gizmo("gizmo_camera"): + logger.log_error("Failed to enable gizmo for camera!") + return + + # Open simulation window (if not headless) + if not args.headless: + sim.open_window() + + logger.log_info("Gizmo-Camera tutorial started!") + logger.log_info( + "Use the gizmo to interactively control the camera position and orientation" + ) + logger.log_info( + "The camera will follow the gizmo pose for dynamic viewpoint control" + ) + logger.log_info("Press Ctrl+C to stop the simulation") + + # Run simulation loop + run_simulation(sim, camera) + + +def run_simulation(sim, camera): + """Run the simulation loop with gizmo updates.""" + step_count = 0 + last_time = time.time() + last_step = 0 + + logger.log_info("Camera view window will open. Press Ctrl+C or 'q' to exit") + logger.log_info( + "Use the gizmo in the 3D view to control camera position and orientation" + ) + + try: + while True: + # Update all gizmos managed by sim (including camera gizmo) + sim.update_gizmos() + + # Update camera to get latest sensor data + camera.update() + + # Refresh camera data if method available + if hasattr(camera, "refresh"): + camera.refresh() + + step_count += 1 + + # Display camera view in separate window + if step_count % 5 == 0: # Update display every 5 steps for performance + data = camera.get_data() + if "color" in data: + # Get RGB image and convert for OpenCV display + rgb_image = data["color"].cpu().numpy()[0, :, :, :3] # (H, W, 3) + # Convert RGB to BGR for OpenCV + bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR) + + # Add text overlay + cv2.putText( + bgr_image, + "Press 'h' to toggle camera gizmo visibility", + (10, 30), + cv2.FONT_HERSHEY_SIMPLEX, + 0.6, + (0, 255, 0), + 2, + ) + + # Display the image + cv2.imshow("Gizmo Camera View", bgr_image) + + # Check for key press + key = cv2.waitKey(1) & 0xFF + if key == ord("h"): + # Toggle the camera gizmo visibility using SimulationManager API + sim.toggle_gizmo_visibility("gizmo_camera") + + # Example: Destroy gizmo after certain steps to test cleanup + if step_count == 30000 and sim.has_gizmo("gizmo_camera"): + logger.log_info("Disabling gizmo at step 30000 (demonstration)") + sim.disable_gizmo("gizmo_camera") + + # Print simulation statistics and camera info + if step_count % 1000 == 0: + current_time = time.time() + elapsed = current_time - last_time + fps = ( + sim.num_envs * (step_count - last_step) / elapsed + if elapsed > 0 + else 0 + ) + + # Get camera pose for debugging + if sim.has_gizmo("gizmo_camera"): + camera_pose = camera.get_local_pose(to_matrix=True)[0] + camera_pos = camera_pose[:3, 3] + logger.log_info( + f"Step: {step_count}, FPS: {fps:.2f}, Camera pos: [{camera_pos[0]:.2f}, {camera_pos[1]:.2f}, {camera_pos[2]:.2f}]" + ) + else: + logger.log_info(f"Step: {step_count}, FPS: {fps:.2f}") + + last_time = current_time + last_step = step_count + + except KeyboardInterrupt: + logger.log_info("\nStopping simulation...") + finally: + # Clean up resources + cv2.destroyAllWindows() + # Disable gizmo if it exists + if sim.has_gizmo("gizmo_camera"): + sim.disable_gizmo("gizmo_camera") + sim.destroy() + logger.log_info("Simulation terminated successfully") + + +if __name__ == "__main__": + main() diff --git a/examples/sim/gizmo/gizmo_object.py b/examples/sim/gizmo/gizmo_object.py new file mode 100644 index 00000000..31967277 --- /dev/null +++ b/examples/sim/gizmo/gizmo_object.py @@ -0,0 +1,175 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +""" +This script demonstrates how to create a simulation scene using SimulationManager. +It shows the basic setup of simulation context, adding objects, and sensors. +""" + +import argparse +import time + +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.cfg import RigidBodyAttributesCfg +from embodichain.lab.sim.shapes import CubeCfg + +from embodichain.lab.sim.objects import RigidObject, RigidObjectCfg +from embodichain.utils import logger + + +def main(): + """Main function to create and run the simulation scene.""" + + # Parse command line arguments + parser = argparse.ArgumentParser( + description="Create a simulation scene with SimulationManager" + ) + parser.add_argument( + "--headless", + action="store_true", + default=False, + help="Run simulation in headless mode", + ) + parser.add_argument( + "--num_envs", type=int, default=1, help="Number of parallel environments" + ) + parser.add_argument( + "--device", type=str, default="cpu", help="Simulation device (cuda or cpu)" + ) + parser.add_argument( + "--enable_rt", + action="store_true", + default=False, + help="Enable ray tracing for better visuals", + ) + args = parser.parse_args() + + # Configure the simulation + sim_cfg = SimulationManagerCfg( + width=1920, + height=1080, + headless=args.headless, + physics_dt=1.0 / 100.0, # Physics timestep (100 Hz) + sim_device=args.device, + enable_rt=args.enable_rt, # Enable ray tracing for better visuals + ) + + # Create the simulation instance + sim = SimulationManager(sim_cfg) + + # Enable manual physics update for precise control + sim.set_manual_update(True) + + # Build multiple arenas if requested + if args.num_envs > 1: + sim.build_multiple_arenas(args.num_envs, space=3.0) + + # Add two cubes to the scene + cube1: RigidObject = sim.add_rigid_object( + cfg=RigidObjectCfg( + uid="cube1", + shape=CubeCfg(size=[0.1, 0.1, 0.1]), + body_type="kinematic", + attrs=RigidBodyAttributesCfg( + mass=1.0, + dynamic_friction=0.5, + static_friction=0.5, + restitution=0.1, + ), + init_pos=[0.0, 0.0, 1.0], + ) + ) + cube2: RigidObject = sim.add_rigid_object( + cfg=RigidObjectCfg( + uid="cube2", + shape=CubeCfg(size=[0.1, 0.1, 0.1]), + body_type="kinematic", + attrs=RigidBodyAttributesCfg( + mass=1.0, + dynamic_friction=0.5, + static_friction=0.5, + restitution=0.1, + ), + init_pos=[0.3, 0.0, 1.0], + ) + ) + + # Enable Gizmo for both cubes using the new API (only in window mode) + if not args.headless: + sim.enable_gizmo(uid="cube1") + sim.enable_gizmo(uid="cube2") + + logger.log_info("Scene setup complete!") + logger.log_info(f"Running simulation with {args.num_envs} environment(s)") + if not args.headless: + if sim.has_gizmo("cube1"): + logger.log_info("Gizmo enabled for cube1 - you can drag it around!") + if sim.has_gizmo("cube2"): + logger.log_info("Gizmo enabled for cube2 - you can drag it around!") + logger.log_info("Press Ctrl+C to stop the simulation") + + # Open window when the scene has been set up + if not args.headless: + sim.open_window() + + # Run the simulation + run_simulation(sim) + + +def run_simulation(sim: SimulationManager): + """Run the simulation loop.""" + if sim.is_use_gpu_physics: + sim.init_gpu_physics() + + step_count = 0 + try: + last_time = time.time() + last_step = 0 + while True: + sim.update(step=1) + + # Update all gizmos if any are enabled + sim.update_gizmos() + + step_count += 1 + + # Disable gizmo after 200000 steps (example) + if step_count == 200000 and gizmo_enabled: + logger.log_info("Disabling gizmo at step 200000") + sim.disable_gizmo("cube") + gizmo_enabled = False + + # Print FPS every second + if step_count % 1000 == 0: + current_time = time.time() + elapsed = current_time - last_time + fps = ( + sim.num_envs * (step_count - last_step) / elapsed + if elapsed > 0 + else 0 + ) + logger.log_info(f"Simulation step: {step_count}, FPS: {fps:.2f}") + last_time = current_time + last_step = step_count + except KeyboardInterrupt: + logger.log_info("\nStopping simulation...") + finally: + sim.destroy() + logger.log_info("Simulation terminated successfully") + + +if __name__ == "__main__": + main() diff --git a/examples/sim/gizmo/gizmo_robot.py b/examples/sim/gizmo/gizmo_robot.py new file mode 100644 index 00000000..fae79962 --- /dev/null +++ b/examples/sim/gizmo/gizmo_robot.py @@ -0,0 +1,158 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- +""" +Gizmo-Robot Example: Test Gizmo class on a robot (UR10) +""" + +import time +import torch +import numpy as np +import argparse + +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.cfg import ( + RobotCfg, + URDFCfg, + JointDrivePropertiesCfg, +) + +from embodichain.lab.sim.solvers import PinkSolverCfg +from embodichain.data import get_data_path +from embodichain.utils import logger + + +def main(): + """Main function to create and run the simulation scene.""" + + # Parse command line arguments + parser = argparse.ArgumentParser( + description="Create a simulation scene with SimulationManager" + ) + parser.add_argument( + "--num_envs", type=int, default=1, help="Number of parallel environments" + ) + parser.add_argument( + "--device", type=str, default="cpu", help="Simulation device (cuda or cpu)" + ) + parser.add_argument( + "--enable_rt", + action="store_true", + default=False, + help="Enable ray tracing for better visuals", + ) + args = parser.parse_args() + + # Configure the simulation + sim_cfg = SimulationManagerCfg( + width=1920, + height=1080, + physics_dt=1.0 / 100.0, + sim_device=args.device, + enable_rt=args.enable_rt, + ) + + sim = SimulationManager(sim_cfg) + sim.set_manual_update(False) + + # Build multiple arenas if requested + if args.num_envs > 1: + sim.build_multiple_arenas(args.num_envs, space=3.0) + + # Get UR10 URDF path + urdf_path = get_data_path("UniversalRobots/UR10/UR10.urdf") + + # Create UR10 robot + robot_cfg = RobotCfg( + uid="ur10_gizmo_test", + urdf_cfg=URDFCfg( + components=[{"component_type": "arm", "urdf_path": urdf_path}] + ), + control_parts={"arm": ["Joint[1-6]"]}, + solver_cfg={ + "arm": PinkSolverCfg( + urdf_path=urdf_path, + end_link_name="ee_link", + root_link_name="base_link", + pos_eps=1e-2, + rot_eps=5e-2, + max_iterations=300, + dt=0.1, + ) + }, + drive_pros=JointDrivePropertiesCfg( + stiffness={"Joint[1-6]": 1e4}, + damping={"Joint[1-6]": 1e3}, + ), + ) + robot = sim.add_robot(cfg=robot_cfg) + + # Set initial joint positions + initial_qpos = torch.tensor( + [[0, -np.pi / 2, np.pi / 2, 0.0, np.pi / 2, 0.0]], + dtype=torch.float32, + device="cpu", + ) + joint_ids = robot.get_joint_ids("arm") + robot.set_qpos(qpos=initial_qpos, joint_ids=joint_ids) + + time.sleep(0.2) # Wait for a moment to ensure everything is set up + + # Enable gizmo using the new API + sim.enable_gizmo(uid="ur10_gizmo_test", control_part="arm") + if not sim.has_gizmo("ur10_gizmo_test", control_part="arm"): + logger.log_error("Failed to enable gizmo!") + return + + sim.open_window() + + logger.log_info("Gizmo-Robot example started!") + logger.log_info("Use the gizmo to drag the robot end-effector (EE)") + logger.log_info("Press Ctrl+C to stop the simulation") + + run_simulation(sim) + + +def run_simulation(sim: SimulationManager): + step_count = 0 + try: + last_time = time.time() + last_step = 0 + while True: + time.sleep(0.033) # 30Hz + # Update all gizmos managed by sim + sim.update_gizmos() + step_count += 1 + + if step_count % 100 == 0: + current_time = time.time() + elapsed = current_time - last_time + fps = ( + sim.num_envs * (step_count - last_step) / elapsed + if elapsed > 0 + else 0 + ) + logger.log_info(f"Simulation step: {step_count}, FPS: {fps:.2f}") + last_time = current_time + last_step = step_count + except KeyboardInterrupt: + logger.log_info("\nStopping simulation...") + finally: + sim.destroy() + logger.log_info("Simulation terminated successfully") + + +if __name__ == "__main__": + main() diff --git a/examples/sim/gizmo/gizmo_scene.py b/examples/sim/gizmo/gizmo_scene.py new file mode 100644 index 00000000..31c7d233 --- /dev/null +++ b/examples/sim/gizmo/gizmo_scene.py @@ -0,0 +1,263 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- +""" +Gizmo Scene Example: Interactive scene with both robot and rigid object gizmos + +This example demonstrates how to create an interactive simulation scene with: +- A UR10 robot with gizmo control for end-effector manipulation +- A rigid object (cube) with gizmo control for direct manipulation +Both objects can be interactively controlled through their respective gizmos. +""" + +import time +import torch +import numpy as np +import argparse +import cv2 + +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.cfg import ( + RobotCfg, + URDFCfg, + JointDrivePropertiesCfg, + RigidObjectCfg, + RigidBodyAttributesCfg, +) +from embodichain.lab.sim.shapes import CubeCfg +from embodichain.lab.sim.sensors import CameraCfg +from embodichain.lab.sim.solvers import PinkSolverCfg +from embodichain.data import get_data_path +from embodichain.utils import logger + + +def main(): + """Main function to create and run the simulation scene.""" + + parser = argparse.ArgumentParser( + description="Create a simulation scene with SimulationManager" + ) + parser.add_argument( + "--num_envs", type=int, default=1, help="Number of parallel environments" + ) + parser.add_argument( + "--device", type=str, default="cpu", help="Simulation device (cuda or cpu)" + ) + parser.add_argument( + "--enable_rt", + action="store_true", + default=False, + help="Enable ray tracing for better visuals", + ) + args = parser.parse_args() + + # Configure the simulation + sim_cfg = SimulationManagerCfg( + width=1920, + height=1080, + physics_dt=1.0 / 100.0, + sim_device=args.device, + enable_rt=args.enable_rt, + ) + + sim = SimulationManager(sim_cfg) + sim.set_manual_update(False) + + # Build multiple arenas if requested + if args.num_envs > 1: + sim.build_multiple_arenas(args.num_envs, space=3.0) + + # Get DexForce W1 URDF path + urdf_path = get_data_path( + "DexforceW1V021_INDUSTRIAL_DH_PGC_GRIPPER_M/DexforceW1V021.urdf" + ) + + # Create DexForce W1 robot + robot_cfg = RobotCfg( + uid="w1_gizmo_test", + urdf_cfg=URDFCfg( + components=[{"component_type": "humanoid", "urdf_path": urdf_path}] + ), + control_parts={"left_arm": ["LEFT_J[1-7]"], "right_arm": ["RIGHT_J[1-7]"]}, + solver_cfg={ + "left_arm": PinkSolverCfg( + urdf_path=urdf_path, + end_link_name="left_ee", + root_link_name="left_arm_base", + pos_eps=1e-2, + rot_eps=5e-2, + max_iterations=300, + dt=0.1, + ), + "right_arm": PinkSolverCfg( + urdf_path=urdf_path, + end_link_name="right_ee", + root_link_name="right_arm_base", + pos_eps=1e-2, + rot_eps=5e-2, + max_iterations=300, + dt=0.1, + ), + }, + drive_pros=JointDrivePropertiesCfg( + stiffness={"LEFT_J[1-7]": 1e4, "RIGHT_J[1-7]": 1e4}, + damping={"LEFT_J[1-7]": 1e3, "RIGHT_J[1-7]": 1e3}, + ), + ) + robot = sim.add_robot(cfg=robot_cfg) + + # Set initial joint positions for both arms + left_arm_qpos = torch.tensor( + [ + [0, 0, -np.pi / 4, np.pi / 4, -np.pi / 2, 0.0, np.pi / 4, 0.0] + ], # WAIST + LEFT_J[1-7] + dtype=torch.float32, + device="cpu", + ) + right_arm_qpos = torch.tensor( + [ + [0, 0, np.pi / 4, -np.pi / 4, np.pi / 2, 0.0, -np.pi / 4, 0.0] + ], # WAIST + RIGHT_J[1-7] + dtype=torch.float32, + device="cpu", + ) + + left_joint_ids = robot.get_joint_ids("left_arm") + right_joint_ids = robot.get_joint_ids("right_arm") + + robot.set_qpos(qpos=left_arm_qpos, joint_ids=left_joint_ids) + robot.set_qpos(qpos=right_arm_qpos, joint_ids=right_joint_ids) + + # Create a rigid object (cube) positioned to the side of the robot + cube_cfg = RigidObjectCfg( + uid="interactive_cube", + shape=CubeCfg(size=[0.1, 0.1, 0.1]), + body_type="kinematic", + attrs=RigidBodyAttributesCfg( + mass=1.0, + dynamic_friction=0.5, + static_friction=0.5, + restitution=0.1, + ), + init_pos=[1.0, 0.0, 0.5], # Position to the side of the robot + ) + cube = sim.add_rigid_object(cube_cfg) + + camera_cfg = CameraCfg( + uid="scene_camera", + width=640, + height=480, + intrinsics=(320, 320, 320, 240), # fx, fy, cx, cy + near=0.1, + far=10.0, + enable_color=True, + enable_depth=True, + extrinsics=CameraCfg.ExtrinsicsCfg( + eye=(2.0, 2.0, 2.0), + target=(0.0, 0.0, 0.0), + up=(0.0, 0.0, 1.0), + ), + ) + camera = sim.add_sensor(sensor_cfg=camera_cfg) + + # Enable gizmo for all assets after all are created and initialized + sim.enable_gizmo(uid="w1_gizmo_test", control_part="left_arm") + if not sim.has_gizmo("w1_gizmo_test", control_part="left_arm"): + logger.log_error("Failed to enable left arm gizmo!") + return + + sim.enable_gizmo(uid="w1_gizmo_test", control_part="right_arm") + if not sim.has_gizmo("w1_gizmo_test", control_part="right_arm"): + logger.log_error("Failed to enable right arm gizmo!") + return + + sim.enable_gizmo(uid="interactive_cube") + if not sim.has_gizmo("interactive_cube"): + logger.log_error("Failed to enable gizmo for cube!") + return + + sim.enable_gizmo(uid="scene_camera") + if not sim.has_gizmo("scene_camera"): + logger.log_error("Failed to enable gizmo for camera!") + return + + sim.open_window() + + logger.log_info("Gizmo Scene example started!") + logger.log_info("Four gizmos are active in the scene:") + logger.log_info("1. Left arm gizmo - Use to drag the left arm end-effector (EE)") + logger.log_info("2. Right arm gizmo - Use to drag the right arm end-effector (EE)") + logger.log_info("3. Cube gizmo - Use to drag and position the cube") + logger.log_info("4. Camera gizmo - Use to drag and orient the camera") + logger.log_info("Press Ctrl+C to stop the simulation") + + run_simulation(sim) + + +def run_simulation(sim: SimulationManager): + step_count = 0 + # Get the camera instance by uid + camera = sim.get_sensor("scene_camera") + try: + last_time = time.time() + last_step = 0 + while True: + time.sleep(0.033) # 30Hz + sim.update_gizmos() + step_count += 1 + + # Display camera view in a window every 5 steps + if camera is not None and step_count % 5 == 0: + camera.update() + data = camera.get_data() + if "color" in data: + rgb_image = data["color"].cpu().numpy()[0, :, :, :3] + bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR) + cv2.putText( + bgr_image, + "Press 'h' to toggle camera gizmo visibility", + (10, 30), + cv2.FONT_HERSHEY_SIMPLEX, + 0.6, + (0, 255, 0), + 2, + ) + cv2.imshow("Camera Sensor View", bgr_image) + key = cv2.waitKey(1) & 0xFF + if key == ord("h"): + # Toggle the camera gizmo visibility using SimulationManager API + sim.toggle_gizmo_visibility("scene_camera") + + if step_count % 100 == 0: + current_time = time.time() + elapsed = current_time - last_time + fps = ( + sim.num_envs * (step_count - last_step) / elapsed + if elapsed > 0 + else 0 + ) + logger.log_info(f"Simulation step: {step_count}, FPS: {fps:.2f}") + last_time = current_time + last_step = step_count + except KeyboardInterrupt: + logger.log_info("\nStopping simulation...") + finally: + cv2.destroyAllWindows() + sim.destroy() + logger.log_info("Simulation terminated successfully") + + +if __name__ == "__main__": + main() diff --git a/examples/sim/gizmo/gizmo_w1.py b/examples/sim/gizmo/gizmo_w1.py new file mode 100644 index 00000000..791582a0 --- /dev/null +++ b/examples/sim/gizmo/gizmo_w1.py @@ -0,0 +1,188 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- +""" +Gizmo-Robot Example: Test Gizmo class on a robot (UR10) +""" + +import time +import torch +import numpy as np +import argparse + +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.cfg import ( + RobotCfg, + URDFCfg, + JointDrivePropertiesCfg, +) + +from embodichain.lab.sim.solvers import PinkSolverCfg +from embodichain.data import get_data_path +from embodichain.utils import logger + + +def main(): + """Main function to create and run the simulation scene.""" + + # Parse command line arguments + parser = argparse.ArgumentParser( + description="Create a simulation scene with SimulationManager" + ) + parser.add_argument( + "--num_envs", type=int, default=1, help="Number of parallel environments" + ) + parser.add_argument( + "--device", type=str, default="cpu", help="Simulation device (cuda or cpu)" + ) + parser.add_argument( + "--enable_rt", + action="store_true", + default=False, + help="Enable ray tracing for better visuals", + ) + args = parser.parse_args() + + # Configure the simulation + sim_cfg = SimulationManagerCfg( + width=1920, + height=1080, + physics_dt=1.0 / 100.0, + sim_device=args.device, + enable_rt=args.enable_rt, + ) + + sim = SimulationManager(sim_cfg) + sim.set_manual_update(False) + + # Build multiple arenas if requested + if args.num_envs > 1: + sim.build_multiple_arenas(args.num_envs, space=3.0) + + # Get DexForce W1 URDF path + urdf_path = get_data_path( + "DexforceW1V021_INDUSTRIAL_DH_PGC_GRIPPER_M/DexforceW1V021.urdf" + ) + + # Create DexForce W1 robot + robot_cfg = RobotCfg( + uid="w1_gizmo_test", + urdf_cfg=URDFCfg( + components=[{"component_type": "humanoid", "urdf_path": urdf_path}] + ), + control_parts={"left_arm": ["LEFT_J[1-7]"], "right_arm": ["RIGHT_J[1-7]"]}, + solver_cfg={ + "left_arm": PinkSolverCfg( + urdf_path=urdf_path, + end_link_name="left_ee", + root_link_name="left_arm_base", + pos_eps=1e-2, + rot_eps=5e-2, + max_iterations=300, + dt=0.1, + ), + "right_arm": PinkSolverCfg( + urdf_path=urdf_path, + end_link_name="right_ee", + root_link_name="right_arm_base", + pos_eps=1e-2, + rot_eps=5e-2, + max_iterations=300, + dt=0.1, + ), + }, + drive_pros=JointDrivePropertiesCfg( + stiffness={"LEFT_J[1-7]": 1e4, "RIGHT_J[1-7]": 1e4}, + damping={"LEFT_J[1-7]": 1e3, "RIGHT_J[1-7]": 1e3}, + ), + ) + robot = sim.add_robot(cfg=robot_cfg) + + # Set initial joint positions for both arms + # Left arm: 8 joints (WAIST + 7 LEFT_J), Right arm: 8 joints (WAIST + 7 RIGHT_J) + left_arm_qpos = torch.tensor( + [ + [0, 0, -np.pi / 4, np.pi / 4, -np.pi / 2, 0.0, np.pi / 4, 0.0] + ], # WAIST + LEFT_J[1-7] + dtype=torch.float32, + device="cpu", + ) + right_arm_qpos = torch.tensor( + [ + [0, 0, np.pi / 4, -np.pi / 4, np.pi / 2, 0.0, -np.pi / 4, 0.0] + ], # WAIST + RIGHT_J[1-7] + dtype=torch.float32, + device="cpu", + ) + + left_joint_ids = robot.get_joint_ids("left_arm") + right_joint_ids = robot.get_joint_ids("right_arm") + + robot.set_qpos(qpos=left_arm_qpos, joint_ids=left_joint_ids) + robot.set_qpos(qpos=right_arm_qpos, joint_ids=right_joint_ids) + + time.sleep(0.2) # Wait for a moment to ensure everything is set up + + # Enable gizmo for both arms using the new API + sim.enable_gizmo(uid="w1_gizmo_test", control_part="left_arm") + if not sim.has_gizmo("w1_gizmo_test", control_part="left_arm"): + logger.log_error("Failed to enable left arm gizmo!") + return + + sim.enable_gizmo(uid="w1_gizmo_test", control_part="right_arm") + if not sim.has_gizmo("w1_gizmo_test", control_part="right_arm"): + logger.log_error("Failed to enable right arm gizmo!") + return + + sim.open_window() + + logger.log_info("Gizmo-DexForce W1 example started!") + logger.log_info("Use the gizmos to drag both robot arms' end-effectors") + logger.log_info("Press Ctrl+C to stop the simulation") + + run_simulation(sim) + + +def run_simulation(sim: SimulationManager): + step_count = 0 + try: + last_time = time.time() + last_step = 0 + while True: + time.sleep(0.033) # 30Hz + # Update all gizmos managed by sim + sim.update_gizmos() + step_count += 1 + + if step_count % 100 == 0: + current_time = time.time() + elapsed = current_time - last_time + fps = ( + sim.num_envs * (step_count - last_step) / elapsed + if elapsed > 0 + else 0 + ) + logger.log_info(f"Simulation step: {step_count}, FPS: {fps:.2f}") + last_time = current_time + last_step = step_count + except KeyboardInterrupt: + logger.log_info("\nStopping simulation...") + finally: + sim.destroy() + logger.log_info("Simulation terminated successfully") + + +if __name__ == "__main__": + main() diff --git a/examples/sim/planners/motion_generator.py b/examples/sim/planners/motion_generator.py new file mode 100644 index 00000000..4557f6d1 --- /dev/null +++ b/examples/sim/planners/motion_generator.py @@ -0,0 +1,146 @@ +import time +import torch +import numpy as np +from copy import deepcopy +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.objects import Robot +from embodichain.lab.sim.robots import CobotMagicCfg +from embodichain.lab.sim.planners.utils import TrajectorySampleMethod +from embodichain.lab.sim.planners.motion_generator import MotionGenerator + + +def move_robot_along_trajectory( + robot: Robot, arm_name: str, qpos_list: list[torch.Tensor], delay: float = 0.1 +): + """ + Set the robot joint positions sequentially along the given joint trajectory. + Args: + robot: Robot instance. + arm_name: Name of the robot arm. + qpos_list: List of joint positions (torch.Tensor). + delay: Time delay between each step (seconds). + """ + for q in qpos_list: + robot.set_qpos(qpos=q.unsqueeze(0), joint_ids=robot.get_joint_ids(arm_name)) + time.sleep(delay) + + +def create_demo_trajectory( + robot: Robot, arm_name: str +) -> tuple[list[torch.Tensor], list[np.ndarray]]: + """ + Generate a three-point trajectory (start, middle, end) for demonstration. + Args: + robot: Robot instance. + arm_name: Name of the robot arm. + Returns: + qpos_list: List of joint positions (torch.Tensor). + xpos_list: List of end-effector poses (numpy arrays). + """ + qpos_fk = torch.tensor( + [[0.0, np.pi / 4, -np.pi / 4, 0.0, np.pi / 4, 0.0]], dtype=torch.float32 + ) + xpos_begin = robot.compute_fk(name=arm_name, qpos=qpos_fk, to_matrix=True) + xpos_mid = deepcopy(xpos_begin) + xpos_mid[0, 2, 3] -= 0.1 # Move down by 0.1m in Z direction + xpos_final = deepcopy(xpos_mid) + xpos_final[0, 0, 3] += 0.2 # Move forward by 0.2m in X direction + + qpos_begin = robot.compute_ik(pose=xpos_begin, name=arm_name)[1][0] + qpos_mid = robot.compute_ik(pose=xpos_mid, name=arm_name)[1][0] + qpos_final = robot.compute_ik(pose=xpos_final, name=arm_name)[1][0] + return [qpos_begin, qpos_mid, qpos_final], [ + xpos_begin[0], + xpos_mid[0], + xpos_final[0], + ] + + +def main(interactive=False): + np.set_printoptions(precision=5, suppress=True) + torch.set_printoptions(precision=5, sci_mode=False) + + # Initialize simulation + sim = SimulationManager(SimulationManagerCfg(headless=False, sim_device="cpu")) + sim.build_multiple_arenas(1) + sim.set_manual_update(False) + + # Robot configuration + cfg_dict = { + "uid": "CobotMagic", + "init_pos": [0.0, 0.0, 0.7775], + "init_qpos": [ + -0.3, + 0.3, + 1.0, + 1.0, + -1.2, + -1.2, + 0.0, + 0.0, + 0.6, + 0.6, + 0.0, + 0.0, + 0.05, + 0.05, + 0.05, + 0.05, + ], + "solver_cfg": { + "left_arm": { + "class_type": "OPWSolver", + "end_link_name": "left_link6", + "root_link_name": "left_arm_base", + "tcp": [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0.143], [0, 0, 0, 1]], + }, + "right_arm": { + "class_type": "OPWSolver", + "end_link_name": "right_link6", + "root_link_name": "right_arm_base", + "tcp": [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0.143], [0, 0, 0, 1]], + }, + }, + } + robot: Robot = sim.add_robot(cfg=CobotMagicCfg.from_dict(cfg_dict)) + arm_name = "left_arm" + + # Generate trajectory points + qpos_list, xpos_list = create_demo_trajectory(robot, arm_name) + + # Initialize motion generator + motion_generator = MotionGenerator( + robot=robot, + uid=arm_name, + planner_type="toppra", + default_velocity=0.2, + default_acceleration=0.5, + ) + + # Joint space trajectory + out_qpos_list, _ = motion_generator.create_discrete_trajectory( + qpos_list=[q.numpy() for q in qpos_list], + is_linear=False, + sample_method=TrajectorySampleMethod.QUANTITY, + sample_num=20, + ) + move_robot_along_trajectory(robot, arm_name, out_qpos_list) + + # Cartesian space trajectory + out_qpos_list, _ = motion_generator.create_discrete_trajectory( + xpos_list=[x.numpy() for x in xpos_list], + is_linear=True, + sample_method=TrajectorySampleMethod.QUANTITY, + sample_num=20, + ) + move_robot_along_trajectory(robot, arm_name, out_qpos_list) + + if interactive: + # Enter IPython interactive shell if needed + from IPython import embed + + embed() + + +if __name__ == "__main__": + main() diff --git a/examples/sim/scene/scene_demo.py b/examples/sim/scene/scene_demo.py new file mode 100644 index 00000000..6efa3b13 --- /dev/null +++ b/examples/sim/scene/scene_demo.py @@ -0,0 +1,195 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- +""" +This script demonstrates how to create a simulation scene using SimulationManager. +It supports loading kitchen/factory/office scenes via EmbodiChainDataset. +""" + +import argparse +import time +from pathlib import Path +import math +import embodichain.utils.logger as logger +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.cfg import RigidBodyAttributesCfg, LightCfg, RobotCfg, URDFCfg +from embodichain.lab.sim.shapes import MeshCfg +from embodichain.lab.sim.objects import RigidObject, RigidObjectCfg, Robot +from embodichain.data.assets.scene_assets import SceneData +from embodichain.data.constants import EMBODICHAIN_DEFAULT_DATA_ROOT + + +def resolve_asset_path(scene_name: str) -> str: + """ + Resolve the asset path for a given scene (.glb/.gltf), + downloading if needed using EmbodiChainData. + """ + + current_dir = Path(__file__).parent + local_glb = current_dir / f"{scene_name}.glb" + local_gltf = current_dir / f"{scene_name}.gltf" + if local_glb.exists(): + logger.log_info(f"Using local asset: {local_glb}") + return str(local_glb) + if local_gltf.exists(): + logger.log_info(f"Using local asset: {local_gltf}") + return str(local_gltf) + + scene_data = SceneData() + + extracted_dir = Path(EMBODICHAIN_DEFAULT_DATA_ROOT) / "extract" / "SceneData" + glb_path = extracted_dir / f"{scene_name}.glb" + gltf_path = extracted_dir / f"{scene_name}.gltf" + + if glb_path.exists(): + logger.log_info(f"Using downloaded asset: {glb_path}") + return str(glb_path) + if gltf_path.exists(): + logger.log_info(f"Using downloaded asset: {gltf_path}") + return str(gltf_path) + + raise FileNotFoundError( + f"No .glb or .gltf found in extracted folder: {extracted_dir}" + ) + + +def run_simulation(sim: SimulationManager): + """Run the simulation loop.""" + if sim.is_use_gpu_physics: + sim.init_gpu_physics() + + try: + while True: + time.sleep(0.01) + except KeyboardInterrupt: + logger.log_info("\n Stopping simulation...") + finally: + sim.destroy() + logger.log_info("Simulation terminated successfully.") + + +def main(): + parser = argparse.ArgumentParser( + description="Create a simulation scene with SimulationManager" + ) + parser.add_argument( + "--scene", + type=str, + default="kitchen", + choices=["kitchen", "factory", "office", "local"], + help="Choose which scene to load", + ) + parser.add_argument( + "--num_envs", type=int, default=1, help="Number of parallel environments" + ) + parser.add_argument( + "--device", type=str, default="cpu", help="Simulation device (cuda or cpu)" + ) + parser.add_argument( + "--disable_rt", + action="store_true", + default=False, + help="Disable ray tracing for better visuals", + ) + args = parser.parse_args() + + logger.log_info(f"Initializing scene '{args.scene}'") + + logger.log_info(f"Resolving and downloading scene '{args.scene}' if needed...") + try: + asset_path = resolve_asset_path(args.scene) + logger.log_info(f"Scene asset ready at: {asset_path}") + except Exception as e: + print(f"Failed to download or resolve scene asset: {e}") + return + + sim_cfg = SimulationManagerCfg( + width=1920, + height=1080, + headless=True, + physics_dt=1.0 / 100.0, + sim_device=args.device, + enable_rt=not args.disable_rt, + ) + sim = SimulationManager(sim_cfg) + sim.set_manual_update(True) + + if args.num_envs > 1: + sim.build_multiple_arenas(args.num_envs, space=10.0) + + num_lights = 8 + radius = 5 + height = 8 + intensity = 200 + lights = [] + + for i in range(num_lights): + angle = 2 * math.pi * i / num_lights + x = radius * math.cos(angle) + y = radius * math.sin(angle) + z = height + uid = f"l{i+1}" + cfg = LightCfg(uid=uid, intensity=intensity, radius=600, init_pos=[x, y, z]) + lights.append(sim.add_light(cfg)) + + physics_attrs = RigidBodyAttributesCfg( + mass=10, + dynamic_friction=0.5, + static_friction=0.5, + restitution=0.1, + ) + + try: + logger.log_info(f"Loading scene asset into simulation: {asset_path}") + scene_obj: RigidObject = sim.add_rigid_object( + cfg=RigidObjectCfg( + uid=args.scene, + shape=MeshCfg(fpath=asset_path), + body_type="static", + init_pos=[0.0, 0.0, 0.1], + init_rot=[90, 180, 0], + attrs=physics_attrs, + ) + ) + if args.scene == "factory": + from embodichain.lab.sim.robots.dexforce_w1.cfg import DexforceW1Cfg + + w1_robot: Robot = sim.add_robot( + cfg=DexforceW1Cfg.from_dict( + { + "uid": "dexforce_w1", + "version": "v021", + "arm_kind": "anthropomorphic", + "init_pos": [-1, -0.5, 0], + "init_rot": [0, 0, 90], + } + ), + ) + + except Exception as e: + logger.log_info(f"Failed to load scene asset: {e}") + return + + logger.log_info(f"Scene '{args.scene}' setup complete!") + logger.log_info(f"Running simulation with {args.num_envs} environment(s)") + logger.log_info("Press Ctrl+C to stop the simulation") + + sim.open_window() + + run_simulation(sim) + + +if __name__ == "__main__": + main() diff --git a/examples/sim/sensors/batch_camera.py b/examples/sim/sensors/batch_camera.py new file mode 100644 index 00000000..efbd17fb --- /dev/null +++ b/examples/sim/sensors/batch_camera.py @@ -0,0 +1,145 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import time +import numpy as np +import matplotlib.pyplot as plt + +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.cfg import RigidObjectCfg, LightCfg +from embodichain.lab.sim.shapes import MeshCfg +from embodichain.lab.sim.objects import RigidObject, Light +from embodichain.lab.sim.sensors import ( + Camera, + StereoCamera, + CameraCfg, + StereoCameraCfg, +) +from embodichain.data import get_data_path + + +def main(args): + config = SimulationManagerCfg( + headless=True, sim_device=args.device, arena_space=2, enable_rt=args.enable_rt + ) + sim = SimulationManager(config) + sim.build_multiple_arenas(args.num_envs) + + rigid_obj: RigidObject = sim.add_rigid_object( + cfg=RigidObjectCfg( + uid="obj", + shape=MeshCfg(fpath=get_data_path("Chair/chair.glb")), + init_pos=(0, 0, 0.2), + ) + ) + light: Light = sim.add_light( + cfg=LightCfg(light_type="point", init_pos=(0, 0, 2), intensity=50) + ) + + if sim.is_use_gpu_physics: + sim.init_gpu_physics() + + if args.headless is False: + sim.open_window() + + import torch + + torch.set_printoptions(precision=4, sci_mode=False) + + eye = (0.0, 0, 2.0) + target = (0.0, 0.0, 0.0) + if args.sensor_type == "stereo": + camera: StereoCamera = sim.add_sensor( + sensor_cfg=StereoCameraCfg( + width=640, + height=480, + extrinsics=CameraCfg.ExtrinsicsCfg(eye=eye, target=target), + ) + ) + else: + camera: Camera = sim.add_sensor( + sensor_cfg=CameraCfg( + width=640, + height=480, + extrinsics=CameraCfg.ExtrinsicsCfg(eye=eye, target=target), + ) + ) + + # TODO: To be removed + sim.reset_objects_state() + + t0 = time.time() + camera.update() + print(f"Camera update time: {time.time() - t0:.4f} seconds") + + data_frame = camera.get_data() + + t0 = time.time() + rgba = data_frame["color"].cpu().numpy() + if args.sensor_type == "stereo": + rgba_right = data_frame["color_right"].cpu().numpy() + + # plot rgba into a grid of images + grid_x = np.ceil(np.sqrt(args.num_envs)).astype(int) + grid_y = np.ceil(args.num_envs / grid_x).astype(int) + fig, axs = plt.subplots(grid_x, grid_y, figsize=(12, 6)) + axs = axs.flatten() + for i in range(args.num_envs): + + if args.sensor_type == "stereo": + image = np.concatenate((rgba[i], rgba_right[i]), axis=1) + else: + image = rgba[i] + axs[i].imshow(image) + axs[i].axis("off") + axs[i].set_title(f"Env {i}") + + if args.headless: + plt.savefig(f"camera_data.png") + else: + plt.show() + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Run the batch robot simulation.") + parser.add_argument( + "--num_envs", type=int, default=4, help="Number of environments to simulate." + ) + parser.add_argument( + "--device", + type=str, + default="cpu", + choices=["cpu", "cuda"], + help="Device to run the simulation on.", + ) + parser.add_argument( + "--headless", action="store_true", help="Run the simulation in headless mode." + ) + parser.add_argument( + "--enable_rt", action="store_true", help="Enable ray tracing rendering." + ) + parser.add_argument( + "--sensor_type", + type=str, + default="camera", + choices=["stereo", "camera"], + help="Type of camera sensor to use.", + ) + + args = parser.parse_args() + main(args) diff --git a/examples/sim/solvers/differential_solver.py b/examples/sim/solvers/differential_solver.py new file mode 100644 index 00000000..1f056575 --- /dev/null +++ b/examples/sim/solvers/differential_solver.py @@ -0,0 +1,250 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- +import os +import time +import numpy as np +import torch +from IPython import embed + +from embodichain.data import get_data_path +from embodichain.lab.sim.cfg import RobotCfg +from embodichain.lab.sim.objects import Robot +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.cfg import MarkerCfg + + +def main(visualize: bool = True): + np.set_printoptions(precision=5, suppress=True) + torch.set_printoptions(precision=5, sci_mode=False) + + # Set up simulation with specified device (CPU or CUDA) + sim_device = "cpu" + num_envs = 9 # Number of parallel arenas/environments + config = SimulationManagerCfg( + headless=False, sim_device=sim_device, arena_space=1.5 + ) + sim = SimulationManager(config) + sim.build_multiple_arenas(num_envs) + sim.set_manual_update(False) + + # Load robot URDF file + urdf = get_data_path("Rokae/SR5/SR5.urdf") + assert os.path.isfile(urdf) + + # Robot configuration + cfg_dict = { + "fpath": urdf, + "control_parts": { + "main_arm": [ + "joint1", + "joint2", + "joint3", + "joint4", + "joint5", + "joint6", + ], + }, + "solver_cfg": { + "main_arm": { + "class_type": "DifferentialSolver", + "end_link_name": "ee_link", + "root_link_name": "base_link", + }, + }, + } + + robot: Robot = sim.add_robot(cfg=RobotCfg.from_dict(cfg_dict)) + + # Prepare initial joint positions for all environments + rad = torch.deg2rad(torch.tensor(45.0)) + arm_name = "main_arm" + fk_qpos = torch.full((num_envs, 6), rad, dtype=torch.float32, device="cpu") + # All envs start with the same qpos (can be randomized) + qpos = torch.from_numpy(np.array([0.0, 0.0, np.pi / 2, 0.0, np.pi / 2, 0.0])).to( + fk_qpos.device + ) + qpos = qpos.unsqueeze(0).repeat(num_envs, 1) + robot.set_qpos(qpos=qpos, joint_ids=robot.get_joint_ids(arm_name)) + + time.sleep(3.0) + fk_xpos = robot.compute_fk( + qpos=qpos, name=arm_name, to_matrix=True + ) # (num_envs, 4, 4) + + # Prepare batch start and end poses for all envs + start_pose = fk_xpos.clone() # (num_envs, 4, 4) + end_pose = fk_xpos.clone() + move_vecs = torch.tensor( + [ + [0.3, 0.0, 0.0], + [0.2, -0.2, 0.0], + [0.0, 0.0, 0.2], + [0.2, 0.0, 0.2], + [-0.3, 0.0, 0.0], + [-0.2, 0.2, 0.0], + [0.0, 0.0, -0.2], + [-0.2, 0.0, -0.2], + [0.1, 0.1, -0.1], + ], + dtype=end_pose.dtype, + device=end_pose.device, + ) + end_pose[ + :, :3, 3 + ] += move_vecs # Move each env's end-effector in a different direction + + num_steps = 100 + # Interpolate poses for each env + interpolated_poses = torch.stack( + [torch.lerp(start_pose, end_pose, t) for t in np.linspace(0, 1, num_steps)], + dim=1, + ) # (num_envs, num_steps, 4, 4) + + # Initial joint positions for all envs + ik_qpos = qpos.clone() + + ik_qpos_results = [] + ik_success_flags = [] + + ik_compute_begin = time.time() + # Batch IK solving for each step + for step in range(num_steps): + poses = interpolated_poses[:, step, :, :] # (num_envs, 4, 4) + if poses.shape[0] != num_envs: + poses = poses.expand(num_envs, *poses.shape[1:]) + if ik_qpos.shape[0] != num_envs: + ik_qpos = ik_qpos.expand(num_envs, *ik_qpos.shape[1:]) + assert ( + poses.shape[0] == num_envs + ), f"poses batch mismatch: {poses.shape[0]} vs {num_envs}" + assert ( + ik_qpos.shape[0] == num_envs + ), f"ik_qpos batch mismatch: {ik_qpos.shape[0]} vs {num_envs}" + + # Parallel batch IK solving + res, ik_qpos_new = robot.compute_ik( + pose=poses, joint_seed=ik_qpos, name=arm_name + ) + ik_qpos_results.append(ik_qpos_new.clone()) + ik_success_flags.append(res) + ik_qpos = ik_qpos_new # Update joint seed + ik_compute_end = time.time() + print( + f"IK compute time for {num_steps} steps and {num_envs} envs: {ik_compute_end - ik_compute_begin:.4f} seconds" + ) + + # Collect visualization data for all steps and environments + if visualize: + draw_data = [[] for _ in range(num_envs)] + for env_id in range(num_envs): + for step in range(num_steps): + ik_qpos_new = ik_qpos_results[step] + ik_xpos = robot.compute_fk(qpos=ik_qpos_new, name=arm_name, to_matrix=True) + local_pose = robot._entities[env_id].get_world_pose() + if visualize: + fk_axis = local_pose @ end_pose[env_id].cpu().numpy() + ik_axis = local_pose @ ik_xpos[env_id].cpu().numpy() + local_axis = local_pose @ ik_xpos[env_id].cpu().numpy() + + draw_data[env_id].append( + { + "step": step, + "fk_axis": fk_axis, + "ik_axis": ik_axis, + "local_axis": local_axis, + } + ) + + if visualize: + # Batch draw all steps' data for each environment + for env_id in range(num_envs): + # Only draw fk_axis and ik_axis once per env (first step) + fk_axis = draw_data[env_id][0]["fk_axis"] + ik_axis = draw_data[env_id][0]["ik_axis"] + + sim.draw_marker( + cfg=MarkerCfg( + name=f"fk_axis_env{env_id}", + marker_type="axis", + axis_xpos=fk_axis, + axis_size=0.002, + axis_len=0.005, + arena_index=env_id, + ) + ) + + sim.draw_marker( + cfg=MarkerCfg( + name=f"ik_axis_env{env_id}", + marker_type="axis", + axis_xpos=ik_axis, + axis_size=0.002, + axis_len=0.005, + arena_index=env_id, + ) + ) + + # Draw the whole local_axis trajectory as a single call + local_axes = np.stack( + [item["local_axis"] for item in draw_data[env_id]], axis=0 + ) # (num_steps, 4, 4) or (num_steps, 3) + + sim.draw_marker( + cfg=MarkerCfg( + name=f"local_axis_env{env_id}_trajectory", + marker_type="axis", + axis_xpos=local_axes, + axis_size=0.002, + axis_len=0.005, + arena_index=env_id, + ) + ) + + # Optionally, set qpos for each step (replay or animate) + for step in range(num_steps): + ik_qpos_new = ik_qpos_results[step] + res = ik_success_flags[step] + # Only set qpos for successful IK results + if isinstance(res, (list, np.ndarray, torch.Tensor)): + for env_id, success in enumerate(res): + if success: + q = ( + ik_qpos_new[env_id] + if ik_qpos_new.dim() == 3 + else ik_qpos_new[env_id] + ) + robot.set_qpos( + qpos=q, + joint_ids=robot.get_joint_ids(arm_name), + env_ids=[env_id], + ) + else: + # fallback: set all + if ik_qpos_new.dim() == 3: + robot.set_qpos( + qpos=ik_qpos_new[:, 0, :], joint_ids=robot.get_joint_ids(arm_name) + ) + else: + robot.set_qpos( + qpos=ik_qpos_new, joint_ids=robot.get_joint_ids(arm_name) + ) + time.sleep(0.005) + + embed(header="Test DifferentialSolver example. Press Ctrl-D to exit.") + + +if __name__ == "__main__": + main(visualize=True) diff --git a/examples/sim/solvers/opw_solver.py b/examples/sim/solvers/opw_solver.py new file mode 100644 index 00000000..aa47eae8 --- /dev/null +++ b/examples/sim/solvers/opw_solver.py @@ -0,0 +1,195 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- +import os +import time +import torch +import numpy as np +from IPython import embed + +from embodichain.lab.sim.objects import Robot +from embodichain.lab.sim.robots import CobotMagicCfg +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.cfg import MarkerCfg + + +def main(): + # Set print options for better readability + np.set_printoptions(precision=5, suppress=True) + torch.set_printoptions(precision=5, sci_mode=False) + + # Initialize simulation + sim_device = "cpu" + config = SimulationManagerCfg(headless=False, sim_device=sim_device) + sim = SimulationManager(config) + sim.build_multiple_arenas(1) + sim.set_manual_update(False) + + # Robot configuration dictionary + cfg_dict = { + "uid": "CobotMagic", + "init_pos": [0.0, 0.0, 0.7775], + "init_qpos": [ + -0.3, + 0.3, + 1.0, + 1.0, + -1.2, + -1.2, + 0.0, + 0.0, + 0.6, + 0.6, + 0.0, + 0.0, + 0.05, + 0.05, + 0.05, + 0.05, + ], + "solver_cfg": { + "left_arm": { + "class_type": "OPWSolver", + "end_link_name": "left_link6", + "root_link_name": "left_arm_base", + "tcp": [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0.143], [0, 0, 0, 1]], + }, + "right_arm": { + "class_type": "OPWSolver", + "end_link_name": "right_link6", + "root_link_name": "right_arm_base", + "tcp": [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0.143], [0, 0, 0, 1]], + }, + }, + } + + # Add robot to simulation + robot: Robot = sim.add_robot(cfg=CobotMagicCfg.from_dict(cfg_dict)) + + # Left arm control + arm_name = "left_arm" + # Set initial joint positions for left arm + qpos_seed = torch.tensor([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype=torch.float32) + qpos_fk = torch.tensor( + [[0.0, np.pi / 4, -np.pi / 4, 0.0, np.pi / 4, 0.0]], dtype=torch.float32 + ) + fk_xpos = robot.compute_fk(qpos=qpos_fk, name=arm_name, to_matrix=True) + + link_pose = robot._entities[0].get_link_pose("left_base_link") + link_pose_tensor = torch.from_numpy(link_pose).to( + fk_xpos.device, dtype=fk_xpos.dtype + ) + + # Solve IK for the left arm + res, ik_qpos = robot.compute_ik(pose=fk_xpos, name=arm_name, joint_seed=qpos_seed) + + # Measure IK computation time and visualize result + a = time.time() + if ik_qpos.dim() == 3: + ik_xpos = robot.compute_fk(qpos=ik_qpos[0][0], name=arm_name, to_matrix=True) + else: + ik_xpos = robot.compute_fk(qpos=ik_qpos, name=arm_name, to_matrix=True) + b = time.time() + print(f"Left arm IK computation time: {b-a:.6f} seconds") + + # Visualize the result in simulation + sim.draw_marker( + cfg=MarkerCfg( + name="fk_xpos_left", + marker_type="axis", + axis_xpos=np.array(fk_xpos.tolist()), + axis_size=0.002, + axis_len=0.005, + ) + ) + + sim.draw_marker( + cfg=MarkerCfg( + name="ik_xpos_left", + marker_type="axis", + axis_xpos=np.array(ik_xpos.tolist()), + axis_size=0.002, + axis_len=0.005, + ) + ) + + # Move robot to IK result joint positions + if ik_qpos.dim() == 3: + robot.set_qpos(qpos=ik_qpos[0][0], joint_ids=robot.get_joint_ids(arm_name)) + else: + robot.set_qpos(qpos=ik_qpos, joint_ids=robot.get_joint_ids(arm_name)) + + # Right arm control + arm_name_r = "right_arm" + # Set initial joint positions for right arm + qpos_seed_r = torch.tensor([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype=torch.float32) + qpos_fk_r = torch.tensor( + [[0.0, np.pi / 4, -np.pi / 4, 0.0, np.pi / 4, 0.0]], dtype=torch.float32 + ) + fk_xpos_r = robot.compute_fk(qpos=qpos_fk_r, name=arm_name_r, to_matrix=True) + + link_pose_r = robot._entities[0].get_link_pose("right_base_link") + link_pose_tensor_r = torch.from_numpy(link_pose_r).to( + fk_xpos_r.device, dtype=fk_xpos_r.dtype + ) + + # Solve IK for the right arm + res_r, ik_qpos_r = robot.compute_ik( + pose=fk_xpos_r, name=arm_name_r, joint_seed=qpos_seed_r + ) + + # Measure IK computation time and visualize result + a_r = time.time() + if ik_qpos_r.dim() == 3: + ik_xpos_r = robot.compute_fk( + qpos=ik_qpos_r[0][0], name=arm_name_r, to_matrix=True + ) + else: + ik_xpos_r = robot.compute_fk(qpos=ik_qpos_r, name=arm_name_r, to_matrix=True) + b_r = time.time() + print(f"Right arm IK computation time: {b_r-a_r:.6f} seconds") + + # Visualize the result in simulation + sim.draw_marker( + cfg=MarkerCfg( + name="fk_xpos_right", + marker_type="axis", + axis_xpos=np.array(fk_xpos_r.tolist()), + axis_size=0.002, + axis_len=0.005, + ) + ) + + sim.draw_marker( + cfg=MarkerCfg( + name="ik_xpos_right", + marker_type="axis", + axis_xpos=np.array(ik_xpos_r.tolist()), + axis_size=0.002, + axis_len=0.005, + ) + ) + + # Move robot to IK result joint positions + if ik_qpos_r.dim() == 3: + robot.set_qpos(qpos=ik_qpos_r[0][0], joint_ids=robot.get_joint_ids(arm_name_r)) + else: + robot.set_qpos(qpos=ik_qpos_r, joint_ids=robot.get_joint_ids(arm_name_r)) + + embed(header="Test OPWSolver example. Press Ctrl-D to exit.") + + +if __name__ == "__main__": + main() diff --git a/examples/sim/solvers/pink_solver.py b/examples/sim/solvers/pink_solver.py new file mode 100644 index 00000000..8ad325fe --- /dev/null +++ b/examples/sim/solvers/pink_solver.py @@ -0,0 +1,177 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- +import os +import time +import numpy as np +import torch +from IPython import embed + +from embodichain.data import get_data_path +from embodichain.lab.sim.cfg import RobotCfg +from embodichain.lab.sim.objects import Robot +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.cfg import MarkerCfg + + +def main(): + np.set_printoptions(precision=5, suppress=True) + torch.set_printoptions(precision=5, sci_mode=False) + + # Set up simulation with specified device (CPU or CUDA) + sim_device = "cpu" + config = SimulationManagerCfg(headless=False, sim_device=sim_device) + sim = SimulationManager(config) + sim.build_multiple_arenas(1) + sim.set_manual_update(False) + + # Load robot URDF file + urdf = get_data_path("Rokae/SR5/SR5.urdf") + + assert os.path.isfile(urdf) + + cfg_dict = { + "fpath": urdf, + "control_parts": { + "main_arm": [ + "joint1", + "joint2", + "joint3", + "joint4", + "joint5", + "joint6", + ], + }, + "solver_cfg": { + "main_arm": { + "class_type": "PinkSolver", + "end_link_name": "ee_link", + "root_link_name": "base_link", + }, + }, + } + + robot: Robot = sim.add_robot(cfg=RobotCfg.from_dict(cfg_dict)) + + # Define a sample target pose as a 1x4x4 homogeneous matrix + rad = torch.deg2rad(torch.tensor(45.0)) + + arm_name = "main_arm" + fk_qpos = torch.full((1, 6), rad, dtype=torch.float32, device="cpu") + + # Set initial joint positions + qpos = torch.from_numpy(np.array([0.0, 0.0, np.pi / 2, 0.0, np.pi / 2, 0.0])).to( + fk_qpos.device + ) + qpos = qpos.unsqueeze(0) + robot.set_qpos(qpos=qpos, joint_ids=robot.get_joint_ids("main_arm")) + import time + + time.sleep(3.0) + fk_xpos = robot.compute_fk(qpos=qpos, name=arm_name, to_matrix=True) + print(f"fk_xpos: {fk_xpos}") + start_pose = fk_xpos.clone()[0] # Start pose + end_pose = fk_xpos.clone()[0] # End pose + + end_pose[:3, 3] = end_pose[:3, 3][:3] + torch.tensor( + [0.0, 0.4, 0.0], device=fk_xpos.device + ) + + num_steps = 100 + + # Interpolate poses between start and end + interpolated_poses = [ + torch.lerp(start_pose, end_pose, t) for t in np.linspace(0, 1, num_steps) + ] + + ik_qpos = qpos + + qpos = ik_qpos + res, ik_qpos = robot.compute_ik(pose=end_pose, joint_seed=qpos, name=arm_name) + import time + + a = time.time() + if ik_qpos.dim() == 3: + ik_xpos = robot.compute_fk(qpos=ik_qpos[0][0], name=arm_name, to_matrix=True) + else: + ik_xpos = robot.compute_fk(qpos=ik_qpos, name=arm_name, to_matrix=True) + b = time.time() + print(f"ik time: {b-a}") + + ik_xpos = ik_xpos + + sim.draw_marker( + cfg=MarkerCfg( + name="fk_xpos", + marker_type="axis", + axis_xpos=np.array(end_pose.tolist()), + axis_size=0.002, + axis_len=0.005, + ) + ) + + sim.draw_marker( + cfg=MarkerCfg( + name="ik_xpos", + marker_type="axis", + axis_xpos=np.array(ik_xpos.tolist()), + axis_size=0.002, + axis_len=0.005, + ) + ) + + for i, pose in enumerate(interpolated_poses): + print(f"Step {i}: Moving to pose:\n{pose}") + start_time = time.time() + res, ik_qpos = robot.compute_ik(pose=pose, joint_seed=ik_qpos, name=arm_name) + end_time = time.time() + compute_time = end_time - start_time + print(f"Step {i}: IK computation time: {compute_time:.6f} seconds") + + print(f"IK result: {res}, ik_qpos: {ik_qpos}") + if not res: + print(f"Step {i}: IK failed for pose:\n{pose}") + continue + + # Set robot joint positions + if ik_qpos.dim() == 3: + robot.set_qpos(qpos=ik_qpos[0][0], joint_ids=robot.get_joint_ids(arm_name)) + else: + robot.set_qpos( + qpos=ik_qpos.unsqueeze(0), joint_ids=robot.get_joint_ids(arm_name) + ) + + # Visualize current pose + ik_xpos = robot.compute_fk(qpos=ik_qpos, name=arm_name, to_matrix=True) + ik_xpos = ik_xpos + + sim.draw_marker( + cfg=MarkerCfg( + name=f"ik_xpos_step_{i}", + marker_type="axis", + axis_xpos=np.array(ik_xpos.tolist()), + axis_size=0.002, + axis_len=0.005, + ) + ) + + # Add delay to simulate motion + time.sleep(0.005) + + embed(header="Test PinkSolver example. Press Ctrl-D to exit.") + + +if __name__ == "__main__": + main() diff --git a/examples/sim/solvers/pinocchio_solver.py b/examples/sim/solvers/pinocchio_solver.py new file mode 100644 index 00000000..bfd3730b --- /dev/null +++ b/examples/sim/solvers/pinocchio_solver.py @@ -0,0 +1,127 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- +import os +import time +import numpy as np +import torch +from IPython import embed + +from embodichain.data import get_data_path +from embodichain.lab.sim.cfg import RobotCfg +from embodichain.lab.sim.objects import Robot +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.cfg import MarkerCfg + + +def main(): + # Set print options for better readability + np.set_printoptions(precision=5, suppress=True) + torch.set_printoptions(precision=5, sci_mode=False) + + # Initialize simulation + sim_device = "cpu" + config = SimulationManagerCfg(headless=False, sim_device=sim_device) + sim = SimulationManager(config) + sim.build_multiple_arenas(1) + sim.set_manual_update(False) + + # Load robot URDF file + urdf = get_data_path("DexforceW1V021/DexforceW1_v02_1.urdf") + assert os.path.isfile(urdf) + + # Robot configuration dictionary + cfg_dict = { + "fpath": urdf, + "control_parts": { + "left_arm": [f"LEFT_J{i+1}" for i in range(7)], + "right_arm": [f"RIGHT_J{i+1}" for i in range(7)], + }, + "solver_cfg": { + "left_arm": { + "class_type": "PinocchioSolver", + "end_link_name": "left_ee", + "root_link_name": "left_arm_base", + }, + "right_arm": { + "class_type": "PinocchioSolver", + "end_link_name": "right_ee", + "root_link_name": "right_arm_base", + }, + }, + } + + robot: Robot = sim.add_robot(cfg=RobotCfg.from_dict(cfg_dict)) + arm_name = "left_arm" + # Set initial joint positions for left arm + qpos_seed = torch.tensor( + [[0.0, 0.1, 0.0, -np.pi / 4, 0.0, 0.0, 0.0]], dtype=torch.float32 + ) + qpos_fk = torch.tensor( + [[0.0, 0.0, 0.0, -np.pi / 4, 0.0, 0.0, 0.0]], dtype=torch.float32 + ) + fk_xpos = robot.compute_fk(qpos=qpos_fk, name=arm_name, to_matrix=True) + link_pose = robot._entities[0].get_link_pose("left_arm_base") + link_pose_tensor = torch.from_numpy(link_pose).to( + fk_xpos.device, dtype=fk_xpos.dtype + ) + + # Solve IK for the left arm + res, ik_qpos = robot.compute_ik(pose=fk_xpos, name=arm_name, joint_seed=qpos_seed) + + # Measure IK computation time and visualize result + a = time.time() + if ik_qpos.dim() == 3: + ik_xpos = robot.compute_fk(qpos=ik_qpos[0][0], name=arm_name, to_matrix=True) + else: + ik_xpos = robot.compute_fk(qpos=ik_qpos, name=arm_name, to_matrix=True) + b = time.time() + print(f"IK computation time: {b-a:.6f} seconds") + + fk_xpos = link_pose_tensor @ fk_xpos + ik_xpos = link_pose_tensor @ ik_xpos + + # Visualize the result in simulation + sim.draw_marker( + cfg=MarkerCfg( + name="fk_xpos", + marker_type="axis", + axis_xpos=np.array(fk_xpos.tolist()), + axis_size=0.002, + axis_len=0.005, + ) + ) + + sim.draw_marker( + cfg=MarkerCfg( + name="ik_xpos", + marker_type="axis", + axis_xpos=np.array(ik_xpos.tolist()), + axis_size=0.002, + axis_len=0.005, + ) + ) + + # Move robot to IK result joint positions + if ik_qpos.dim() == 3: + robot.set_qpos(qpos=ik_qpos[0][0], joint_ids=robot.get_joint_ids(arm_name)) + else: + robot.set_qpos(qpos=ik_qpos, joint_ids=robot.get_joint_ids(arm_name)) + + embed(header="Test PinocchioSolver example. Press Ctrl-D to exit.") + + +if __name__ == "__main__": + main() diff --git a/examples/sim/solvers/pytorch_solver.py b/examples/sim/solvers/pytorch_solver.py new file mode 100644 index 00000000..822eac52 --- /dev/null +++ b/examples/sim/solvers/pytorch_solver.py @@ -0,0 +1,226 @@ +import os +import time +import numpy as np +import torch +from IPython import embed + +from embodichain.data import get_data_path +from embodichain.lab.sim.cfg import RobotCfg +from embodichain.lab.sim.objects import Robot +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.cfg import MarkerCfg + + +def main(): + # Set numpy and torch print options for better readability + np.set_printoptions(precision=5, suppress=True) + torch.set_printoptions(precision=5, sci_mode=False) + + # Initialize simulation environment (CPU or CUDA) + sim_device = "cpu" + num_envs = 9 # Number of parallel environments + config = SimulationManagerCfg( + headless=False, sim_device=sim_device, arena_space=2.0 + ) + sim = SimulationManager(config) + sim.build_multiple_arenas(num_envs) + sim.set_manual_update(False) + + # Load robot URDF file + urdf = get_data_path("DexforceW1V021/DexforceW1_v02_1.urdf") + assert os.path.isfile(urdf) + + # Robot configuration dictionary + cfg_dict = { + "fpath": urdf, + "control_parts": { + "left_arm": [ + "LEFT_J1", + "LEFT_J2", + "LEFT_J3", + "LEFT_J4", + "LEFT_J5", + "LEFT_J6", + "LEFT_J7", + ], + }, + "solver_cfg": { + "left_arm": { + "class_type": "PytorchSolver", + "end_link_name": "left_ee", + "root_link_name": "left_arm_base", + }, + }, + } + + # Add robot to simulation + robot: Robot = sim.add_robot(cfg=RobotCfg.from_dict(cfg_dict)) + + # Prepare initial joint positions for all environments + arm_name = "left_arm" + qpos = ( + torch.tensor([0.0, 0.0, 0.0, -np.pi / 2, 0.0, 0.0, 0.0], dtype=torch.float32) + .unsqueeze(0) + .repeat(num_envs, 1) + ) + robot.set_qpos(qpos=qpos, joint_ids=robot.get_joint_ids(arm_name)) + + time.sleep(2.0) + fk_xpos = robot.compute_fk( + qpos=qpos, name=arm_name, to_matrix=True + ) # (num_envs, 4, 4) + + # Prepare batch start and end poses for all envs + start_pose = fk_xpos.clone() + end_pose = fk_xpos.clone() + move_vecs = torch.tensor( + [ + [0.2, 0.0, 0.0], + [0.0, 0.2, 0.0], + [0.0, -0.2, -0.5], + [-0.2, 0.0, 0.0], + [-0.2, 0.0, 0.0], + [0.0, -0.2, 0.0], + [0.0, 0.0, -0.5], + [-0.2, 0.2, 0.0], + [0.0, 0.2, -0.5], + ], + dtype=end_pose.dtype, + device=end_pose.device, + ) + end_pose[:, :3, 3] += move_vecs + + num_steps = 50 + # Interpolate poses for each env + interpolated_poses = torch.stack( + [torch.lerp(start_pose, end_pose, t) for t in np.linspace(0, 1, num_steps)], + dim=1, + ) # (num_envs, num_steps, 4, 4) + + # Initial joint positions for all envs + ik_qpos = qpos.clone() + ik_qpos_results = [] + ik_success_flags = [] + + # Batch IK solving for each step + ik_compute_begin = time.time() + for step in range(num_steps): + poses = interpolated_poses[:, step, :, :] # (num_envs, 4, 4) + if poses.shape[0] != num_envs: + poses = poses.expand(num_envs, *poses.shape[1:]) + if ik_qpos.shape[0] != num_envs: + ik_qpos = ik_qpos.expand(num_envs, *ik_qpos.shape[1:]) + assert poses.shape[0] == num_envs + assert ik_qpos.shape[0] == num_envs + + # Parallel batch IK solving + res, ik_qpos_new = robot.compute_ik( + pose=poses, joint_seed=ik_qpos, name=arm_name + ) + ik_qpos_results.append(ik_qpos_new.clone()) + ik_success_flags.append(res) + ik_qpos = ik_qpos_new # Update joint seed + ik_compute_end = time.time() + print( + f"IK compute time for {num_steps} steps and {num_envs} envs: {ik_compute_end - ik_compute_begin:.4f} seconds" + ) + + # Collect visualization data for all steps and environments + draw_data = [[] for _ in range(num_envs)] + for env_id in range(num_envs): + for step in range(num_steps): + ik_qpos_new = ik_qpos_results[step] + ik_xpos = robot.compute_fk(qpos=ik_qpos_new, name=arm_name, to_matrix=True) + local_pose = robot._entities[env_id].get_link_pose("left_arm_base") + if isinstance(local_pose, np.ndarray): + local_pose = torch.from_numpy(local_pose).to( + ik_xpos.device, dtype=ik_xpos.dtype + ) + fk_axis = (local_pose @ end_pose[env_id]).cpu().numpy() + ik_axis = (local_pose @ ik_xpos[env_id]).cpu().numpy() + local_axis = (local_pose @ ik_xpos[env_id]).cpu().numpy() + draw_data[env_id].append( + { + "step": step, + "fk_axis": fk_axis, + "ik_axis": ik_axis, + "local_axis": local_axis, + } + ) + + # Batch draw: only draw fk_axis and ik_axis once per env, draw local_axis trajectory for all steps + for env_id in range(num_envs): + fk_axis = draw_data[env_id][0]["fk_axis"] + ik_axis = draw_data[env_id][0]["ik_axis"] + + sim.draw_marker( + cfg=MarkerCfg( + name=f"fk_axis_env{env_id}", + marker_type="axis", + axis_xpos=fk_axis, + axis_size=0.002, + axis_len=0.005, + arena_index=env_id, + ) + ) + + sim.draw_marker( + cfg=MarkerCfg( + name=f"ik_axis_env{env_id}", + marker_type="axis", + axis_xpos=ik_axis, + axis_size=0.002, + axis_len=0.005, + arena_index=env_id, + ) + ) + + # Draw the whole local_axis trajectory as a single call (if supported) + local_axes = np.stack( + [item["local_axis"] for item in draw_data[env_id]], axis=0 + ) + + sim.draw_marker( + cfg=MarkerCfg( + name=f"local_axis_env{env_id}_trajectory", + marker_type="axis", + axis_xpos=local_axes, + axis_size=0.002, + axis_len=0.005, + arena_index=env_id, + ) + ) + + # Optionally, set qpos for each step (replay or animate) + for step in range(num_steps): + ik_qpos_new = ik_qpos_results[step] + res = ik_success_flags[step] + if isinstance(res, (list, np.ndarray, torch.Tensor)): + for env_id, success in enumerate(res): + if success: + q = ( + ik_qpos_new[env_id] + if ik_qpos_new.dim() == 3 + else ik_qpos_new[env_id] + ) + robot.set_qpos( + qpos=q, + joint_ids=robot.get_joint_ids(arm_name), + env_ids=[env_id], + ) + else: + if ik_qpos_new.dim() == 3: + robot.set_qpos( + qpos=ik_qpos_new[:, 0, :], joint_ids=robot.get_joint_ids(arm_name) + ) + else: + robot.set_qpos( + qpos=ik_qpos_new, joint_ids=robot.get_joint_ids(arm_name) + ) + time.sleep(0.005) + + embed(header="Test PytorchSolver batch example. Press Ctrl-D to exit.") + + +if __name__ == "__main__": + main() diff --git a/examples/sim/solvers/srs_solver.py b/examples/sim/solvers/srs_solver.py new file mode 100644 index 00000000..5d9922b3 --- /dev/null +++ b/examples/sim/solvers/srs_solver.py @@ -0,0 +1,85 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import time +import numpy as np +import torch + +from IPython import embed + +from embodichain.lab.sim.objects import Robot +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.robots import DexforceW1Cfg + + +def main(): + # Set print options for better readability + np.set_printoptions(precision=5, suppress=True) + torch.set_printoptions(precision=5, sci_mode=False) + + # Initialize simulation + sim_device = "cpu" + sim = SimulationManager( + SimulationManagerCfg( + headless=False, sim_device=sim_device, width=2200, height=1200 + ) + ) + + sim.build_multiple_arenas(1) + sim.set_manual_update(False) + + robot: Robot = sim.add_robot(cfg=DexforceW1Cfg.from_dict({"uid": "dexforce_w1"})) + arm_name = "left_arm" + # Set initial joint positions for left arm + qpos_fk_list = [ + torch.tensor([[0.0, 0.0, 0.0, -np.pi / 2, 0.0, 0.0, 0.0]], dtype=torch.float32), + ] + robot.set_qpos(qpos_fk_list[0], joint_ids=robot.get_joint_ids(arm_name)) + + time.sleep(0.5) + + fk_xpos_batch = torch.cat(qpos_fk_list, dim=0) + + fk_xpos_list = robot.compute_fk(qpos=fk_xpos_batch, name=arm_name, to_matrix=True) + + start_time = time.time() + res, ik_qpos = robot.compute_ik( + pose=fk_xpos_list, + name=arm_name, + # joint_seed=qpos_fk_list[0], + return_all_solutions=True, + ) + end_time = time.time() + print( + f"Batch IK computation time for {len(fk_xpos_list)} poses: {end_time - start_time:.6f} seconds" + ) + + if ik_qpos.dim() == 3: + first_solutions = ik_qpos[:, 0, :] + else: + first_solutions = ik_qpos + robot.set_qpos(first_solutions, joint_ids=robot.get_joint_ids(arm_name)) + + ik_xpos_list = robot.compute_fk(qpos=first_solutions, name=arm_name, to_matrix=True) + + print("fk_xpos_list: ", fk_xpos_list) + print("ik_xpos_list: ", ik_xpos_list) + + embed(header="Test SRSSolver example. Press Ctrl-D to exit.") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..83b28a47 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,74 @@ +[build-system] +# Use the legacy setuptools backend so the existing `setup.py` (which dynamically +# detects torch and builds C/CUDA extensions via torch.utils.cpp_extension) +# is still executed during build. This is a minimal, non-opinionated pyproject +# that modernizes the package while preserving the custom extension build steps. +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta:__legacy__" + + +# Note: +# - Building the torch extensions requires a working PyTorch installation +# available at build time (or FORCE_CUDA=1 for cross-compilation). We keep +# the extension-building logic in `setup.py` because it needs to import +# `torch` at build time to select CUDA/Cpp extension classes. +# - If you prefer a pure PEP 517 build without a legacy setup.py, we can +# refactor the extension-building logic into a custom build backend plugin. + +[project] +name = "embodichain" +version = "0.0.1" +description = "A modular platform for building generalized embodied intelligence." +readme = "README.md" +authors = [ { name = "Dexforce" } ] +requires-python = ">=3.9" + + +# Core install dependencies (kept from requirements.txt). Some VCS links are +# specified using PEP 508 direct references where present. +dependencies = [ + "setuptools==69.5.1", + "gymnasium==0.29.1", + "langchain==0.2.14", + "langchain-openai==0.1.22", + "pillow==9.5.0", + "ffmpeg-python==0.2.0", + "pytransform3d", + "uvicorn", + "fastapi", + "casadi==3.7.1", + "pin==2.7.0", + "toppra==0.6.3", + "qpsolvers==4.8.1", + "pin-pink==3.4.0", + "PyYAML>=6.0", + "transformers==4.48.0", + "diffusers==0.32.1", + "balanced-loss", + "accelerate==1.2.1", + "wandb==0.20.1", + "tensorboard", + "pydantic==2.7.1", + "deepspeed==0.16.2", + "py_opw_kinematics==0.1.6", + "pytorch_kinematics==0.7.5", + "polars==1.31.0", + "cvxpy==1.4.0", + "ortools", + "prettytable", + "transforms3d==0.4.2", + "hdfdict@git+http://69.235.177.182:8081/externalrepo/hdfdict", + "OmegaConf", + "dill", + "black==22.3", + "aenum", + "h5py", + "dacite==1.9.2", + "zmq", +] + +[tool.setuptools.packages.find] +where = ["."] +exclude = ["docs"] + +[tool.black] diff --git a/scripts/benchmark/opw_solver.py b/scripts/benchmark/opw_solver.py new file mode 100644 index 00000000..e3112ab0 --- /dev/null +++ b/scripts/benchmark/opw_solver.py @@ -0,0 +1,155 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch +import numpy as np +import warp as wp +from scipy.spatial.transform import Rotation +from embodichain.lab.sim.solvers.opw_solver import OPWSolver, OPWSolverCfg +from typing import Tuple, List +import time + + +def get_pose_err(matrix_a: np.ndarray, matrix_b: np.ndarray) -> Tuple[float, float]: + t_err = np.linalg.norm(matrix_a[:3, 3] - matrix_b[:3, 3]) + relative_rot = matrix_a[:3, :3].T @ matrix_b[:3, :3] + cos_angle = (np.trace(relative_rot) - 1) / 2.0 + cos_angle = np.clip(cos_angle, -1.0, 1.0) + r_err = np.arccos(cos_angle) + return t_err, r_err + + +def get_poses_err( + matrix_a_list: List[np.ndarray], matrix_b_list: List[np.ndarray] +) -> Tuple[float, float]: + t_errs = [] + r_errs = [] + for mat_a, mat_b in zip(matrix_a_list, matrix_b_list): + t_err, r_err = get_pose_err(mat_a, mat_b) + t_errs.append(t_err) + r_errs.append(r_err) + return np.mean(t_errs), np.mean(r_errs) + + +def check_opw_solver(solver_warp, solver_py_opw, n_samples=1000): + DOF = 6 + qpos_np = np.random.uniform(low=-np.pi, high=np.pi, size=(n_samples, DOF)).astype( + float + ) + qpos = torch.tensor(qpos_np, device=torch.device("cuda"), dtype=torch.float32) + xpos = solver_warp.get_fk(qpos) + qpos_seed = torch.tensor( + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + device=torch.device("cuda"), + dtype=torch.float32, + ) + + warp_ik_start_time = time.time() + warp_ik_success, warp_ik_qpos = solver_warp.get_ik( + xpos, + qpos_seed=qpos_seed, + initial_guess=qpos, + # return_all_solutions=True, + ) + warp_cost_time = time.time() - warp_ik_start_time + + # TODO: debug code + # warp_ik_success_np = warp_ik_success.cpu().numpy() + # warp_ik_failure_indices = np.where(warp_ik_success_np == False)[0] + # if len(warp_ik_failure_indices) > 0: + # failure_qpos = qpos_np[warp_ik_failure_indices] + # failure_xpos = xpos.cpu().numpy()[warp_ik_failure_indices] + # print("=====warp_ik_failure_qpos:\n", repr(failure_qpos)) + # print("=====warp_ik_failure_xpos:\n", repr(failure_xpos)) + + # print("=====xpos:\n", repr(xpos.cpu().numpy())) + # print("=====warp_ik_qpos:\n", repr(warp_ik_qpos.cpu().numpy())) + # print("=====warp_ik_success:\n", repr(warp_ik_success.cpu().numpy())) + + check_xpos = solver_warp.get_fk(warp_ik_qpos) + warp_t_mean_err, warp_r_mean_err = get_poses_err( + [x.cpu().numpy() for x in xpos], + [x.cpu().numpy() for x in check_xpos], + ) + + py_opw_ik_start_time = time.time() + py_opw_ik_success, py_opw_ik_qpos = solver_py_opw.get_ik( + xpos, qpos_seed=qpos_seed, initial_guess=qpos + ) + py_opw_cost_time = time.time() - py_opw_ik_start_time + + check_xpos = solver_warp.get_fk(py_opw_ik_qpos.to(torch.device("cuda"))) + py_opw_t_mean_err, py_opw_r_mean_err = get_poses_err( + [x.cpu().numpy() for x in xpos], + [x.cpu().numpy() for x in check_xpos], + ) + + return ( + warp_cost_time, + warp_t_mean_err, + warp_r_mean_err, + py_opw_cost_time, + py_opw_t_mean_err, + py_opw_r_mean_err, + ) + + +def benchmark_opw_solver(): + cfg = OPWSolverCfg() + cfg.a1 = 400.333 + cfg.a2 = -251.449 + cfg.b = 0.0 + cfg.c1 = 830 + cfg.c2 = 1177.556 + cfg.c3 = 1443.593 + cfg.c4 = 230 + cfg.offsets = ( + 0.0, + 82.21350356417211 * np.pi / 180.0, + -167.21710113148163 * np.pi / 180.0, + 0.0, + 0.0, + 0.0, + ) + cfg.flip_axes = (True, False, True, True, False, True) + cfg.has_parallelogram = False + + # TODO: ignore pk_serial_chain for OPW + solver_warp = cfg.init_solver(device=torch.device("cuda"), pk_serial_chain="") + solver_py_opw = cfg.init_solver(device=torch.device("cpu"), pk_serial_chain="") + n_samples = [100, 1000, 10000, 100000] + # n_samples = [100] + for n_sample in n_samples: + # check_opw_solver(solver_warp, solver_py_opw, device=device, n_samples=n_sample) + ( + warp_cost_time, + warp_t_mean_err, + warp_r_mean_err, + py_opw_cost_time, + py_opw_t_mean_err, + py_opw_r_mean_err, + ) = check_opw_solver(solver_warp, solver_py_opw, n_samples=n_sample) + print(f"===warp OPW Solver FK/IK test over {n_sample} samples:") + print(f" Warp IK time: {warp_cost_time * 1000:.6f} ms") + print(f"Translation mean error: {warp_t_mean_err*1000:.6f} mm") + print(f"Rotation mean error: {warp_r_mean_err*180/np.pi:.6f} degrees") + print(f"===Py OPW IK time: {py_opw_cost_time * 1000:.6f} ms") + print(f"Translation mean error: {py_opw_t_mean_err*1000:.6f} mm") + print(f"Rotation mean error: {py_opw_r_mean_err*180/np.pi:.6f} degrees") + + +if __name__ == "__main__": + benchmark_opw_solver() diff --git a/scripts/tutorials/gym/modular_env.py b/scripts/tutorials/gym/modular_env.py new file mode 100644 index 00000000..21b2e893 --- /dev/null +++ b/scripts/tutorials/gym/modular_env.py @@ -0,0 +1,216 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch + +from typing import List, Dict, Any + +import embodichain.lab.gym.envs.managers.randomization as rand +import embodichain.lab.gym.envs.managers.events as events +import embodichain.lab.gym.envs.managers.observations as obs + +from embodichain.lab.gym.envs.managers import ( + EventCfg, + SceneEntityCfg, + ObservationCfg, +) +from embodichain.lab.gym.envs import EmbodiedEnv, EmbodiedEnvCfg +from embodichain.lab.gym.utils.registration import register_env +from embodichain.lab.sim.robots import DexforceW1Cfg +from embodichain.lab.sim.sensors import StereoCameraCfg, SensorCfg +from embodichain.lab.sim.shapes import MeshCfg +from embodichain.lab.sim.cfg import ( + LightCfg, + ArticulationCfg, + RobotCfg, + RigidObjectCfg, + RigidBodyAttributesCfg, +) +from embodichain.data import get_data_path +from embodichain.utils import configclass + + +@configclass +class ExampleEventCfg: + + replace_obj: EventCfg = EventCfg( + func=events.replace_assets_from_group, + mode="reset", + params={ + "entity_cfg": SceneEntityCfg( + uid="fork", + ), + "folder_path": get_data_path("TableWare/tableware/fork/"), + }, + ) + + randomize_light: EventCfg = EventCfg( + func=rand.randomize_light, + mode="interval", + interval_step=5, + params={ + "entity_cfg": SceneEntityCfg( + uid="point", + ), + "position_range": [[-0.5, -0.5, 2], [0.5, 0.5, 2]], + "color_range": [[0.6, 0.6, 0.6], [1, 1, 1]], + "intensity_range": [50.0, 100.0], + }, + ) + + randomize_table_mat: EventCfg = EventCfg( + func=rand.randomize_visual_material, + mode="interval", + interval_step=10, + params={ + "entity_cfg": SceneEntityCfg( + uid="table", + ), + "random_texture_prob": 0.5, + "texture_path": get_data_path("CocoBackground/coco"), + "base_color_range": [[0.2, 0.2, 0.2], [1.0, 1.0, 1.0]], + }, + ) + + +@configclass +class ObsCfg: + + obj_pose: ObservationCfg = ObservationCfg( + func=obs.get_rigid_object_pose, + mode="add", + name="fork_pose", + params={"entity_cfg": SceneEntityCfg(uid="fork")}, + ) + + +@configclass +class ExampleCfg(EmbodiedEnvCfg): + + # Define the robot configuration using DexforceW1Cfg + robot: RobotCfg = DexforceW1Cfg.from_dict( + { + "uid": "dexforce_w1", + "version": "v021", + "arm_kind": "anthropomorphic", + "init_pos": [0.0, 0, 0.0], + } + ) + + # Define the sensor configuration using StereoCameraCfg + sensor: List[SensorCfg] = [ + StereoCameraCfg( + uid="eye_in_head", + width=960, + height=540, + enable_mask=True, + enable_depth=True, + left_to_right_pos=(0.06, 0, 0), + intrinsics=(450, 450, 480, 270), + intrinsics_right=(450, 450, 480, 270), + extrinsics=StereoCameraCfg.ExtrinsicsCfg( + parent="eyes", + ), + ) + ] + + light: EmbodiedEnvCfg.EnvLightCfg = EmbodiedEnvCfg.EnvLightCfg( + direct=[ + LightCfg( + uid="point", + light_type="point", + color=(1.0, 1.0, 1.0), + intensity=50.0, + init_pos=(0, 0, 2), + ) + ] + ) + + background: List[RigidObjectCfg] = [ + RigidObjectCfg( + uid="table", + shape=MeshCfg( + fpath=get_data_path("CircleTableSimple/circle_table_simple.ply"), + compute_uv=True, + ), + attrs=RigidBodyAttributesCfg( + mass=10.0, + static_friction=0.95, + dynamic_friction=0.85, + restitution=0.01, + ), + body_type="kinematic", + init_pos=(0.80, 0, 0.8), + init_rot=(0, 90, 0), + ), + ] + + rigid_object: List[RigidObjectCfg] = [ + RigidObjectCfg( + uid="fork", + shape=MeshCfg( + fpath=get_data_path("TableWare/tableware/fork/standard_fork_scale.ply"), + ), + body_scale=(0.75, 0.75, 1.0), + init_pos=(0.8, 0, 1.0), + ), + ] + + articulation_cfg: List[ArticulationCfg] = [ + ArticulationCfg( + uid="drawer", + fpath="SlidingBoxDrawer/SlidingBoxDrawer.urdf", + init_pos=(0.5, 0.0, 0.85), + ) + ] + + events = ExampleEventCfg() + + observations = ObsCfg() + + +@register_env("ModularEnv-v1", max_episode_steps=100, override=True) +class ModularEnv(EmbodiedEnv): + """ + An example of a modular environment that inherits from EmbodiedEnv + and uses custom event and observation managers. + """ + + def __init__(self, cfg: EmbodiedEnvCfg, **kwargs): + super().__init__(cfg, **kwargs) + + +if __name__ == "__main__": + import gymnasium as gym + import argparse + + from embodichain.lab.sim import SimulationManagerCfg + + parser = argparse.ArgumentParser() + parser.add_argument("--enable_rt", action="store_true", help="Enable ray tracing") + args = parser.parse_args() + + env_cfg = ExampleCfg(sim_cfg=SimulationManagerCfg(enable_rt=args.enable_rt)) + + # Create the Gym environment + env = gym.make("ModularEnv-v1", cfg=env_cfg) + + while True: + obs, info = env.reset() + + for i in range(100): + action = torch.zeros(env.action_space.shape, dtype=torch.float32) + obs, reward, done, truncated, info = env.step(action) diff --git a/scripts/tutorials/gym/random_reach.py b/scripts/tutorials/gym/random_reach.py new file mode 100644 index 00000000..a6e4ed0a --- /dev/null +++ b/scripts/tutorials/gym/random_reach.py @@ -0,0 +1,170 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch +import numpy as np +import gymnasium as gym + +from embodichain.lab.gym.envs import BaseEnv, EnvCfg +from embodichain.lab.sim import SimulationManagerCfg +from embodichain.lab.sim.types import EnvAction, EnvObs +from embodichain.lab.sim.shapes import CubeCfg +from embodichain.lab.sim.objects import RigidObject, Robot +from embodichain.lab.sim.cfg import ( + RobotCfg, + RigidObjectCfg, + RigidBodyAttributesCfg, +) +from embodichain.lab.gym.utils.registration import register_env + + +@register_env("RandomReach-v1", max_episode_steps=100, override=True) +class RandomReachEnv(BaseEnv): + + robot_init_qpos = np.array( + [1.57079, -1.57079, 1.57079, -1.57079, -1.57079, -3.14159] + ) + + def __init__( + self, + num_envs=1, + headless=False, + device="cpu", + **kwargs, + ): + env_cfg = EnvCfg( + sim_cfg=SimulationManagerCfg( + headless=headless, arena_space=2.0, sim_device=device + ), + num_envs=num_envs, + ) + + super().__init__( + cfg=env_cfg, + **kwargs, + ) + + def _setup_robot(self, **kwargs) -> Robot: + from embodichain.data import get_data_path + + file_path = get_data_path("UniversalRobots/UR10/UR10.urdf") + + robot: Robot = self.sim.add_robot( + cfg=RobotCfg( + uid="ur10", + fpath=file_path, + init_pos=(0, 0, 1), + init_qpos=self.robot_init_qpos, + ) + ) + + qpos_limits = robot.body_data.qpos_limits[0].cpu().numpy() + self.single_action_space = gym.spaces.Box( + low=qpos_limits[:, 0], high=qpos_limits[:, 1], dtype=np.float32 + ) + + return robot + + def _prepare_scene(self, **kwargs) -> None: + size = 0.03 + # Create a kinematic cube object without collision. + # Currently, we use this workaround for visualization purposes. + self.cube: RigidObject = self.sim.add_rigid_object( + cfg=RigidObjectCfg( + uid="cube", + shape=CubeCfg(size=[size, size, size]), + attrs=RigidBodyAttributesCfg(enable_collision=False), + init_pos=(0.0, 0.0, 0.5), + body_type="kinematic", + ), + ) + + def _update_sim_state(self, **kwargs) -> None: + pose = torch.eye(4, device=self.device) + pose = pose.unsqueeze_(0).repeat(self.num_envs, 1, 1) + pose[:, :3, 3] += torch.rand(self.num_envs, 3, device=self.device) * 0.5 - 0.25 + self.cube.set_local_pose(pose=pose) + + def _step_action(self, action: EnvAction) -> EnvAction: + self.robot.set_qpos(qpos=action) + return action + + def _extend_obs(self, obs: EnvObs, **kwargs) -> EnvObs: + # You can also use `cube = self.sim.get_rigid_object("cube")` to access obj. + # obs["cube_position"] = self.cube.get_local_pose()[:, :3] + return obs + + +if __name__ == "__main__": + import argparse + import time + + parser = argparse.ArgumentParser( + description="Demo for running a random reach environment." + ) + parser.add_argument( + "--num_envs", type=int, default=1, help="number of environments to run" + ) + parser.add_argument( + "--device", + type=str, + default="cpu", + help="device to run the environment on, e.g., 'cpu' or 'cuda'", + ) + parser.add_argument("--headless", action="store_true", help="run in headless mode") + args = parser.parse_args() + + env = gym.make( + "RandomReach-v1", + num_envs=args.num_envs, + headless=args.headless, + device=args.device, + ) + + for episode in range(10): + print("Episode:", episode) + env.reset() + start_time = time.time() + total_steps = 0 + + for i in range(100): + action = env.action_space.sample() + action = torch.as_tensor(action, dtype=torch.float32, device=env.device) + + init_pose = env.robot_init_qpos + init_pose = ( + torch.as_tensor(init_pose, dtype=torch.float32, device=env.device) + .unsqueeze_(0) + .repeat(env.num_envs, 1) + ) + action = ( + init_pose + + torch.rand_like(action, dtype=torch.float32, device=env.device) * 0.2 + - 0.1 + ) + + obs, reward, done, truncated, info = env.step(action) + total_steps += env.num_envs + + end_time = time.time() + elapsed_time = end_time - start_time + if elapsed_time > 0: + fps = total_steps / elapsed_time + print(f"Total steps: {total_steps}") + print(f"Elapsed time: {elapsed_time:.2f} seconds") + print(f"FPS: {fps:.2f}") + else: + print("Elapsed time is too short to calculate FPS.") diff --git a/scripts/tutorials/sim/create_rigid_object_group.py b/scripts/tutorials/sim/create_rigid_object_group.py new file mode 100644 index 00000000..f8ae262f --- /dev/null +++ b/scripts/tutorials/sim/create_rigid_object_group.py @@ -0,0 +1,170 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +""" +This script demonstrates how to create a rigid object group using SimulationManager. +""" + +import argparse +import time + +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.cfg import RigidBodyAttributesCfg +from embodichain.lab.sim.shapes import CubeCfg +from embodichain.lab.sim.objects import ( + RigidObjectGroup, + RigidObjectGroupCfg, + RigidObjectCfg, +) + + +def main(): + """Main function to create and run the simulation scene.""" + + # Parse command line arguments + parser = argparse.ArgumentParser( + description="Create a simulation scene with SimulationManager" + ) + parser.add_argument( + "--headless", + action="store_true", + default=False, + help="Run simulation in headless mode", + ) + parser.add_argument( + "--num_envs", type=int, default=1, help="Number of parallel environments" + ) + parser.add_argument( + "--device", type=str, default="cpu", help="Simulation device (cuda or cpu)" + ) + parser.add_argument( + "--enable_rt", + action="store_true", + default=False, + help="Enable ray tracing for better visuals", + ) + args = parser.parse_args() + + # Configure the simulation + sim_cfg = SimulationManagerCfg( + width=1920, + height=1080, + headless=True, + physics_dt=1.0 / 100.0, # Physics timestep (100 Hz) + sim_device=args.device, + enable_rt=args.enable_rt, # Enable ray tracing for better visuals + ) + + # Create the simulation instance + sim = SimulationManager(sim_cfg) + + # Enable manual physics update for precise control + sim.set_manual_update(True) + + # Build multiple arenas if requested + if args.num_envs > 1: + sim.build_multiple_arenas(args.num_envs, space=3.0) + + physics_attrs = RigidBodyAttributesCfg( + mass=1.0, + dynamic_friction=0.5, + static_friction=0.5, + restitution=0.1, + ) + + # Add objects to the scene + obj_group: RigidObjectGroup = sim.add_rigid_object_group( + cfg=RigidObjectGroupCfg( + uid="obj_group", + rigid_objects={ + "cube_1": RigidObjectCfg( + uid="cube_1", + shape=CubeCfg(size=[0.1, 0.1, 0.1]), + attrs=physics_attrs, + init_pos=[0.0, 0.0, 1.0], + ), + "cube_2": RigidObjectCfg( + uid="cube_2", + shape=CubeCfg(size=[0.2, 0.2, 0.2]), + attrs=physics_attrs, + init_pos=[0.5, 0.0, 1.0], + ), + "cube_3": RigidObjectCfg( + uid="cube_3", + shape=CubeCfg(size=[0.3, 0.3, 0.3]), + attrs=physics_attrs, + init_pos=[-0.5, 0.0, 1.0], + ), + }, + ) + ) + + print("[INFO]: Scene setup complete!") + print(f"[INFO]: Running simulation with {args.num_envs} environment(s)") + print("[INFO]: Press Ctrl+C to stop the simulation") + + # Open window when the scene has been set up + if not args.headless: + sim.open_window() + + # Run the simulation + run_simulation(sim) + + +def run_simulation(sim: SimulationManager): + """Run the simulation loop. + + Args: + sim: The SimulationManager instance to run + """ + + # Initialize GPU physics if using CUDA + if sim.is_use_gpu_physics: + sim.init_gpu_physics() + + step_count = 0 + + try: + last_time = time.time() + last_step = 0 + while True: + # Update physics simulation + sim.update(step=1) + step_count += 1 + + # Print FPS every second + if step_count % 100 == 0: + current_time = time.time() + elapsed = current_time - last_time + fps = ( + sim.num_envs * (step_count - last_step) / elapsed + if elapsed > 0 + else 0 + ) + print(f"[INFO]: Simulation step: {step_count}, FPS: {fps:.2f}") + last_time = current_time + last_step = step_count + + except KeyboardInterrupt: + print("\n[INFO]: Stopping simulation...") + finally: + # Clean up resources + sim.destroy() + print("[INFO]: Simulation terminated successfully") + + +if __name__ == "__main__": + main() diff --git a/scripts/tutorials/sim/create_robot.py b/scripts/tutorials/sim/create_robot.py new file mode 100644 index 00000000..6a504f70 --- /dev/null +++ b/scripts/tutorials/sim/create_robot.py @@ -0,0 +1,218 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +""" +This script demonstrates how to create and simulate a robot using SimulationManager. +It shows how to load a robot from URDF, set up control parts, and run basic simulation. +""" + +import argparse +import numpy as np +import time +import torch + +torch.set_printoptions(precision=4, sci_mode=False) + +from scipy.spatial.transform import Rotation as R + +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.objects import Robot +from embodichain.lab.sim.cfg import ( + JointDrivePropertiesCfg, + RobotCfg, + URDFCfg, +) +from embodichain.data import get_data_path + + +def main(): + """Main function to demonstrate robot simulation.""" + + # Parse command line arguments + parser = argparse.ArgumentParser( + description="Create and simulate a robot in SimulationManager" + ) + parser.add_argument( + "--num_envs", type=int, default=1, help="Number of environments to simulate" + ) + parser.add_argument( + "--device", + type=str, + default="cpu", + choices=["cpu", "cuda"], + help="Device to run simulation on", + ) + parser.add_argument("--headless", action="store_true", help="Run in headless mode") + parser.add_argument( + "--enable_rt", action="store_true", help="Enable ray tracing rendering" + ) + args = parser.parse_args() + + # Initialize simulation + print("Creating simulation...") + config = SimulationManagerCfg( + headless=True, + sim_device=args.device, + arena_space=3.0, + enable_rt=args.enable_rt, + physics_dt=1.0 / 100.0, + ) + sim = SimulationManager(config) + + # Build multiple environments if requested + if args.num_envs > 1: + sim.build_multiple_arenas(args.num_envs) + + # Set manual physics update for precise control + sim.set_manual_update(True) + + # Create robot configuration + robot = create_robot(sim) + + # Initialize GPU physics if using CUDA + if sim.is_use_gpu_physics: + sim.init_gpu_physics() + + # Open visualization window if not headless + if not args.headless: + sim.open_window() + + # Run simulation loop + run_simulation(sim, robot) + + +def create_robot(sim): + """Create and configure a robot in the simulation.""" + + print("Loading robot...") + + # Get SR5 arm URDF path + sr5_urdf_path = get_data_path("Rokae/SR5/SR5.urdf") + + # Get hand URDF path + hand_urdf_path = get_data_path( + "BrainCoHandRevo1/BrainCoLeftHand/BrainCoLeftHand.urdf" + ) + + # Define control parts for the robot + # Joint names in control_parts can be regex patterns + CONTROL_PARTS = { + "arm": [ + "JOINT[1-6]", # Matches JOINT1, JOINT2, ..., JOINT6 + ], + "hand": ["LEFT_.*"], # Matches all joints starting with L_ + } + + # Define transformation for hand attachment + hand_attach_xpos = np.eye(4) + hand_attach_xpos[:3, :3] = R.from_rotvec([90, 0, 0], degrees=True).as_matrix() + hand_attach_xpos[2, 3] = 0.02 + + cfg = RobotCfg( + uid="sr5_with_brainco", + urdf_cfg=URDFCfg( + components=[ + { + "component_type": "arm", + "urdf_path": sr5_urdf_path, + }, + { + "component_type": "hand", + "urdf_path": hand_urdf_path, + "transform": hand_attach_xpos, + }, + ] + ), + control_parts=CONTROL_PARTS, + drive_pros=JointDrivePropertiesCfg( + stiffness={"JOINT[1-6]": 1e4, "LEFT_.*": 1e3}, + damping={"JOINT[1-6]": 1e3, "LEFT_.*": 1e2}, + ), + ) + + # Add robot to simulation + robot: Robot = sim.add_robot(cfg=cfg) + + print(f"Robot created successfully with {robot.dof} joints") + + return robot + + +def run_simulation(sim: SimulationManager, robot: Robot): + """Run the simulation loop with robot control.""" + + print("Starting simulation...") + print("Robot will move through different poses") + print("Press Ctrl+C to stop") + + step_count = 0 + + arm_joint_ids = robot.get_joint_ids("arm") + # Define some target joint positions for demonstration + arm_position1 = ( + torch.tensor( + [0.0, -0.5, 0.5, -1.0, 0.5, 0.0], dtype=torch.float32, device=sim.device + ) + .unsqueeze_(0) + .repeat(sim.num_envs, 1) + ) + + arm_position2 = ( + torch.tensor( + [0.5, 0.0, -0.5, 0.5, -0.5, 0.5], dtype=torch.float32, device=sim.device + ) + .unsqueeze_(0) + .repeat(sim.num_envs, 1) + ) + + # Get joint IDs for the hand. + hand_joint_ids = robot.get_joint_ids("hand") + # Define hand open and close positions based on joint limits. + hand_position_open = robot.body_data.qpos_limits[:, hand_joint_ids, 1] + hand_position_close = robot.body_data.qpos_limits[:, hand_joint_ids, 0] + + try: + while True: + # Update physics + sim.update(step=1) + + if step_count % 4000 == 0: + robot.set_qpos(qpos=arm_position1, joint_ids=arm_joint_ids) + print(f"Moving to arm position 1") + + if step_count % 4000 == 1000: + robot.set_qpos(qpos=arm_position2, joint_ids=arm_joint_ids) + print(f"Moving to arm position 2") + + if step_count % 4000 == 2000: + robot.set_qpos(qpos=hand_position_close, joint_ids=hand_joint_ids) + print(f"Closing hand") + + if step_count % 4000 == 3000: + robot.set_qpos(qpos=hand_position_open, joint_ids=hand_joint_ids) + print(f"Opening hand") + + step_count += 1 + + except KeyboardInterrupt: + print("Stopping simulation...") + finally: + print("Cleaning up...") + sim.destroy() + + +if __name__ == "__main__": + main() diff --git a/scripts/tutorials/sim/create_scene.py b/scripts/tutorials/sim/create_scene.py new file mode 100644 index 00000000..3f1082e6 --- /dev/null +++ b/scripts/tutorials/sim/create_scene.py @@ -0,0 +1,149 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +""" +This script demonstrates how to create a simulation scene using SimulationManager. +It shows the basic setup of simulation context, adding objects, and sensors. +""" + +import argparse +import time + +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.cfg import RigidBodyAttributesCfg +from embodichain.lab.sim.shapes import CubeCfg +from embodichain.lab.sim.objects import RigidObject, RigidObjectCfg + + +def main(): + """Main function to create and run the simulation scene.""" + + # Parse command line arguments + parser = argparse.ArgumentParser( + description="Create a simulation scene with SimulationManager" + ) + parser.add_argument( + "--headless", + action="store_true", + default=False, + help="Run simulation in headless mode", + ) + parser.add_argument( + "--num_envs", type=int, default=1, help="Number of parallel environments" + ) + parser.add_argument( + "--device", type=str, default="cpu", help="Simulation device (cuda or cpu)" + ) + parser.add_argument( + "--enable_rt", + action="store_true", + default=False, + help="Enable ray tracing for better visuals", + ) + args = parser.parse_args() + + # Configure the simulation + sim_cfg = SimulationManagerCfg( + width=1920, + height=1080, + headless=True, + physics_dt=1.0 / 100.0, # Physics timestep (100 Hz) + sim_device=args.device, + enable_rt=args.enable_rt, # Enable ray tracing for better visuals + ) + + # Create the simulation instance + sim = SimulationManager(sim_cfg) + + # Enable manual physics update for precise control + sim.set_manual_update(True) + + # Build multiple arenas if requested + if args.num_envs > 1: + sim.build_multiple_arenas(args.num_envs, space=3.0) + + # Add objects to the scene + cube: RigidObject = sim.add_rigid_object( + cfg=RigidObjectCfg( + uid="cube", + shape=CubeCfg(size=[0.1, 0.1, 0.1]), + body_type="dynamic", + attrs=RigidBodyAttributesCfg( + mass=1.0, + dynamic_friction=0.5, + static_friction=0.5, + restitution=0.1, + ), + init_pos=[0.0, 0.0, 1.0], + ) + ) + + print("[INFO]: Scene setup complete!") + print(f"[INFO]: Running simulation with {args.num_envs} environment(s)") + print("[INFO]: Press Ctrl+C to stop the simulation") + + # Open window when the scene has been set up + if not args.headless: + sim.open_window() + + # Run the simulation + run_simulation(sim) + + +def run_simulation(sim: SimulationManager): + """Run the simulation loop. + + Args: + sim: The SimulationManager instance to run + """ + + # Initialize GPU physics if using CUDA + if sim.is_use_gpu_physics: + sim.init_gpu_physics() + + step_count = 0 + + try: + last_time = time.time() + last_step = 0 + while True: + # Update physics simulation + sim.update(step=1) + step_count += 1 + + # Print FPS every second + if step_count % 100 == 0: + current_time = time.time() + elapsed = current_time - last_time + fps = ( + sim.num_envs * (step_count - last_step) / elapsed + if elapsed > 0 + else 0 + ) + print(f"[INFO]: Simulation step: {step_count}, FPS: {fps:.2f}") + last_time = current_time + last_step = step_count + + except KeyboardInterrupt: + print("\n[INFO]: Stopping simulation...") + finally: + # Clean up resources + sim.destroy() + print("[INFO]: Simulation terminated successfully") + + +if __name__ == "__main__": + main() diff --git a/scripts/tutorials/sim/create_sensor.py b/scripts/tutorials/sim/create_sensor.py new file mode 100644 index 00000000..380852a5 --- /dev/null +++ b/scripts/tutorials/sim/create_sensor.py @@ -0,0 +1,346 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +""" +This script demonstrates how to create and simulate a camera sensor attached to a robot using SimulationManager. +It shows how to configure a camera sensor, attach it to the robot's end-effector, and visualize the sensor's output during simulation. +""" + +import argparse +import cv2 +import numpy as np +import time +import torch + +torch.set_printoptions(precision=4, sci_mode=False) + +from scipy.spatial.transform import Rotation as R + +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.sensors import Camera, CameraCfg +from embodichain.lab.sim.objects import Robot +from embodichain.lab.sim.cfg import ( + JointDrivePropertiesCfg, + RobotCfg, + URDFCfg, + RigidObjectCfg, +) +from embodichain.lab.sim.shapes import CubeCfg +from embodichain.data import get_data_path + + +def mask_to_color_map(mask, user_ids, fix_seed=True): + """ + Convert instance mask into color map. + :param mask: Instance mask map. + :param user_ids: List of unique user IDs in the mask. + :return: Color map. + """ + # Create a blank RGB image + color_map = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8) + + # Generate deterministic colors based on user_id values + colors = [] + for user_id in user_ids: + # Use the user_id as seed to generate deterministic color + np.random.seed(user_id) + color = np.random.choice(range(256), size=3) + colors.append(color) + + for idx, color in enumerate(colors): + # Assign color to the instances of each class + color_map[mask == user_ids[idx]] = color + + return color_map + + +def main(): + """Main function to demonstrate robot sensor simulation.""" + + # Parse command line arguments + parser = argparse.ArgumentParser( + description="Create and simulate a robot in SimulationManager" + ) + parser.add_argument( + "--num_envs", type=int, default=1, help="Number of environments to simulate" + ) + parser.add_argument( + "--device", + type=str, + default="cpu", + choices=["cpu", "cuda"], + help="Device to run simulation on", + ) + parser.add_argument("--headless", action="store_true", help="Run in headless mode") + parser.add_argument( + "--enable_rt", action="store_true", help="Enable ray tracing rendering" + ) + parser.add_argument( + "--attach_sensor", + action="store_true", + help="Attach sensor to robot end-effector", + ) + args = parser.parse_args() + + # Initialize simulation + print("Creating simulation...") + config = SimulationManagerCfg( + headless=True, + sim_device=args.device, + arena_space=3.0, + enable_rt=args.enable_rt, + physics_dt=1.0 / 100.0, + ) + sim = SimulationManager(config) + + # Build multiple environments if requested + if args.num_envs > 1: + sim.build_multiple_arenas(args.num_envs) + + # Set manual physics update for precise control + sim.set_manual_update(True) + + # Create robot configuration + robot = create_robot(sim) + + sensor = create_sensor(sim, args) + + # Add a cube to the scene + cube_cfg = RigidObjectCfg( + uid="cube", + shape=CubeCfg(size=[0.05, 0.05, 0.05]), # Use CubeCfg for a cube + init_pos=[1.2, -0.2, 0.1], + init_rot=[0, 0, 0], + ) + sim.add_rigid_object(cfg=cube_cfg) + + # Initialize GPU physics if using CUDA + if sim.is_use_gpu_physics: + sim.init_gpu_physics() + + # Open visualization window if not headless + if not args.headless: + sim.open_window() + + # Run simulation loop + run_simulation(sim, robot, sensor) + + +def create_sensor(sim: SimulationManager, args): + # intrinsics params + intrinsics = (600, 600, 320.0, 240.0) + width = 640 + height = 480 + + # extrinsics params + pos = [0.09, 0.05, 0.04] + quat = R.from_euler("xyz", [-35, 135, 0], degrees=True).as_quat().tolist() + + # If attach_sensor is True, attach to robot end-effector; otherwise, place it in the scene + if args.attach_sensor: + parent = "ee_link" + else: + parent = None + pos = [1.2, -0.2, 1.5] + quat = R.from_euler("xyz", [0, 180, 0], degrees=True).as_quat().tolist() + quat = [quat[3], quat[0], quat[1], quat[2]] # Convert to (w, x, y, z) + + # create camera sensor and attach to robot end-effector + camera: Camera = sim.add_sensor( + sensor_cfg=CameraCfg( + width=width, + height=height, + intrinsics=intrinsics, + extrinsics=CameraCfg.ExtrinsicsCfg( + parent=parent, + pos=pos, + quat=quat, + ), + near=0.01, + far=10.0, + enable_color=True, + enable_depth=True, + enable_mask=True, + enable_normal=True, + ) + ) + return camera + + +def create_robot(sim): + """Create and configure a robot in the simulation.""" + + print("Loading robot...") + + # Get SR5 URDF path + sr5_urdf_path = get_data_path("Rokae/SR5/SR5.urdf") + + # Get hand URDF path + hand_urdf_path = get_data_path( + "BrainCoHandRevo1/BrainCoLeftHand/BrainCoLeftHand.urdf" + ) + + # Define control parts for the robot + # Joint names in control_parts can be regex patterns + CONTROL_PARTS = { + "arm": [ + "JOINT[1-6]", # Matches JOINT1, JOINT2, ..., JOINT6 + ], + "hand": ["LEFT_.*"], # Matches all joints starting with L_ + } + + # Define transformation for hand attachment + hand_attach_xpos = np.eye(4) + hand_attach_xpos[:3, :3] = R.from_rotvec([90, 0, 0], degrees=True).as_matrix() + hand_attach_xpos[2, 3] = 0.02 + + cfg = RobotCfg( + uid="sr5_with_brainco", + urdf_cfg=URDFCfg( + components=[ + { + "component_type": "arm", + "urdf_path": sr5_urdf_path, + }, + { + "component_type": "hand", + "urdf_path": hand_urdf_path, + "transform": hand_attach_xpos, + }, + ] + ), + control_parts=CONTROL_PARTS, + drive_pros=JointDrivePropertiesCfg( + stiffness={"JOINT[1-6]": 1e4, "LEFT_.*": 1e3}, + damping={"JOINT[1-6]": 1e3, "LEFT_.*": 1e2}, + ), + ) + + # Add robot to simulation + robot: Robot = sim.add_robot(cfg=cfg) + + print(f"Robot created successfully with {robot.dof} joints") + + return robot + + +def get_sensor_image(camera: Camera, headless=False, step_count=0): + """ + Get color, depth, mask, and normals views from the camera, + and visualize them in a 2x2 grid (or save if headless). + """ + import matplotlib.pyplot as plt + + camera.update() + data = camera.get_data() + # Get four views + rgba = data["color"].cpu().numpy()[0, :, :, :3] # (H, W, 3) + depth = data["depth"].squeeze_().cpu().numpy() # (H, W) + mask = data["mask"].squeeze_().cpu().numpy() # (H, W) + normals = data["normal"].cpu().numpy()[0] # (H, W, 3) + + # Normalize for visualization + depth_vis = (depth - depth.min()) / (depth.ptp() + 1e-8) + depth_vis = (depth_vis * 255).astype("uint8") + mask_vis = mask_to_color_map(mask, user_ids=np.unique(mask)) + normals_vis = ((normals + 1) / 2 * 255).astype("uint8") + + # Prepare titles and images for display + titles = ["Color", "Depth", "Mask", "Normals"] + images = [ + cv2.cvtColor(rgba, cv2.COLOR_RGB2BGR), + cv2.cvtColor(depth_vis, cv2.COLOR_GRAY2BGR), + mask_vis, + cv2.cvtColor(normals_vis, cv2.COLOR_RGB2BGR), + ] + + if not headless: + # Concatenate images for 2x2 grid display using OpenCV + top = np.hstack([images[0], images[1]]) + bottom = np.hstack([images[2], images[3]]) + grid = np.vstack([top, bottom]) + cv2.imshow("Sensor Views (Color / Depth / Mask / Normals)", grid) + cv2.waitKey(1) + else: + # Save the 2x2 grid as an image using matplotlib + fig, axs = plt.subplots(2, 2, figsize=(10, 8)) + for ax, img, title in zip(axs.flatten(), images, titles): + ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + ax.set_title(title) + ax.axis("off") + plt.tight_layout() + plt.savefig(f"sensor_views_{step_count}.png") + plt.close(fig) + + +def run_simulation(sim: SimulationManager, robot: Robot, camera: Camera): + """Run the simulation loop with robot and camera sensor control.""" + + print("Starting simulation...") + print("Robot will move through different poses") + print("Press Ctrl+C to stop") + + step_count = 0 + + arm_joint_ids = robot.get_joint_ids("arm") + # Define some target joint positions for demonstration + + arm_position1 = ( + torch.tensor( + [0.0, 0.5, -1.5, 0.3, -0.5, 0], dtype=torch.float32, device=sim.device + ) + .unsqueeze_(0) + .repeat(sim.num_envs, 1) + ) + + arm_position2 = ( + torch.tensor( + [0.0, 0.5, -1.5, -0.3, -0.5, 0], dtype=torch.float32, device=sim.device + ) + .unsqueeze_(0) + .repeat(sim.num_envs, 1) + ) + + try: + while True: + # Update physics + sim.update(step=1) + + if step_count % 1001 == 0: + robot.set_qpos(qpos=arm_position1, joint_ids=arm_joint_ids) + print(f"Moving to arm position 1") + + # Refresh and get image from sensor + get_sensor_image(camera) + + if step_count % 2003 == 0: + robot.set_qpos(qpos=arm_position2, joint_ids=arm_joint_ids) + print(f"Moving to arm position 2") + + # Refresh and get image from sensor + get_sensor_image(camera) + + step_count += 1 + + except KeyboardInterrupt: + print("Stopping simulation...") + finally: + print("Cleaning up...") + sim.destroy() + + +if __name__ == "__main__": + main() diff --git a/scripts/tutorials/sim/create_softbody.py b/scripts/tutorials/sim/create_softbody.py new file mode 100644 index 00000000..af6b5cc8 --- /dev/null +++ b/scripts/tutorials/sim/create_softbody.py @@ -0,0 +1,161 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +""" +This script demonstrates how to create a simulation scene using SimulationManager. +It shows the basic setup of simulation context, adding objects, lighting, and sensors. +""" + +import argparse +import time +from dexsim.utility.path import get_resources_data_path +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.cfg import ( + SoftbodyVoxelAttributesCfg, + SoftbodyPhysicalAttributesCfg, +) +from embodichain.lab.sim.shapes import MeshCfg +from embodichain.lab.sim.objects import ( + SoftObject, + SoftObjectCfg, +) + + +def main(): + """Main function to create and run the simulation scene.""" + + # Parse command line arguments + parser = argparse.ArgumentParser( + description="Create a simulation scene with SimulationManager" + ) + parser.add_argument( + "--headless", + action="store_true", + default=False, + help="Run simulation in headless mode", + ) + parser.add_argument( + "--num_envs", type=int, default=4, help="Number of parallel environments" + ) + parser.add_argument( + "--enable_rt", + action="store_true", + default=False, + help="Enable ray tracing for better visuals", + ) + args = parser.parse_args() + + # Configure the simulation + sim_cfg = SimulationManagerCfg( + width=1920, + height=1080, + headless=True, + physics_dt=1.0 / 100.0, # Physics timestep (100 Hz) + sim_device="cuda", # soft simulation only supports cuda device + enable_rt=args.enable_rt, # Enable ray tracing for better visuals + ) + + # Create the simulation instance + sim = SimulationManager(sim_cfg) + + # Enable manual physics update for precise control + sim.set_manual_update(True) + + # Build multiple arenas if requested + if args.num_envs > 1: + sim.build_multiple_arenas(args.num_envs, space=3.0) + print("[INFO]: Scene setup complete!") + + # add softbody to the scene + cow: SoftObject = sim.add_soft_object( + cfg=SoftObjectCfg( + uid="cow", + shape=MeshCfg( + fpath=get_resources_data_path("Model", "cow", "cow.obj"), + ), + init_pos=[0.0, 0.0, 3.0], + voxel_attr=SoftbodyVoxelAttributesCfg( + simulation_mesh_resolution=8, + maximal_edge_length=0.5, + ), + physical_attr=SoftbodyPhysicalAttributesCfg( + youngs=1e6, + poissons=0.45, + density=100, + dynamic_friction=0.1, + min_position_iters=30, + ), + ), + ) + print("[INFO]: Add soft object complete!") + + # Open window when the scene has been set up + if not args.headless: + sim.open_window() + + print(f"[INFO]: Running simulation with {args.num_envs} environment(s)") + print("[INFO]: Press Ctrl+C to stop the simulation") + + # Run the simulation + run_simulation(sim, cow) + + +def run_simulation(sim: SimulationManager, soft_obj: SoftObject) -> None: + """Run the simulation loop. + + Args: + sim: The SimulationManager instance to run + soft_obj: soft object + """ + + # Initialize GPU physics + sim.init_gpu_physics() + + step_count = 0 + + try: + last_time = time.time() + last_step = 0 + while True: + # Update physics simulation + sim.update(step=1) + step_count += 1 + + # Print FPS every second + if step_count % 100 == 0: + current_time = time.time() + elapsed = current_time - last_time + fps = ( + sim.num_envs * (step_count - last_step) / elapsed + if elapsed > 0 + else 0 + ) + print(f"[INFO]: Simulation step: {step_count}, FPS: {fps:.2f}") + last_time = current_time + last_step = step_count + if step_count % 500 == 0: + soft_obj.reset() + + except KeyboardInterrupt: + print("\n[INFO]: Stopping simulation...") + finally: + # Clean up resources + sim.destroy() + print("[INFO]: Simulation terminated successfully") + + +if __name__ == "__main__": + main() diff --git a/scripts/tutorials/sim/gizmo_robot.py b/scripts/tutorials/sim/gizmo_robot.py new file mode 100644 index 00000000..fae79962 --- /dev/null +++ b/scripts/tutorials/sim/gizmo_robot.py @@ -0,0 +1,158 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- +""" +Gizmo-Robot Example: Test Gizmo class on a robot (UR10) +""" + +import time +import torch +import numpy as np +import argparse + +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.cfg import ( + RobotCfg, + URDFCfg, + JointDrivePropertiesCfg, +) + +from embodichain.lab.sim.solvers import PinkSolverCfg +from embodichain.data import get_data_path +from embodichain.utils import logger + + +def main(): + """Main function to create and run the simulation scene.""" + + # Parse command line arguments + parser = argparse.ArgumentParser( + description="Create a simulation scene with SimulationManager" + ) + parser.add_argument( + "--num_envs", type=int, default=1, help="Number of parallel environments" + ) + parser.add_argument( + "--device", type=str, default="cpu", help="Simulation device (cuda or cpu)" + ) + parser.add_argument( + "--enable_rt", + action="store_true", + default=False, + help="Enable ray tracing for better visuals", + ) + args = parser.parse_args() + + # Configure the simulation + sim_cfg = SimulationManagerCfg( + width=1920, + height=1080, + physics_dt=1.0 / 100.0, + sim_device=args.device, + enable_rt=args.enable_rt, + ) + + sim = SimulationManager(sim_cfg) + sim.set_manual_update(False) + + # Build multiple arenas if requested + if args.num_envs > 1: + sim.build_multiple_arenas(args.num_envs, space=3.0) + + # Get UR10 URDF path + urdf_path = get_data_path("UniversalRobots/UR10/UR10.urdf") + + # Create UR10 robot + robot_cfg = RobotCfg( + uid="ur10_gizmo_test", + urdf_cfg=URDFCfg( + components=[{"component_type": "arm", "urdf_path": urdf_path}] + ), + control_parts={"arm": ["Joint[1-6]"]}, + solver_cfg={ + "arm": PinkSolverCfg( + urdf_path=urdf_path, + end_link_name="ee_link", + root_link_name="base_link", + pos_eps=1e-2, + rot_eps=5e-2, + max_iterations=300, + dt=0.1, + ) + }, + drive_pros=JointDrivePropertiesCfg( + stiffness={"Joint[1-6]": 1e4}, + damping={"Joint[1-6]": 1e3}, + ), + ) + robot = sim.add_robot(cfg=robot_cfg) + + # Set initial joint positions + initial_qpos = torch.tensor( + [[0, -np.pi / 2, np.pi / 2, 0.0, np.pi / 2, 0.0]], + dtype=torch.float32, + device="cpu", + ) + joint_ids = robot.get_joint_ids("arm") + robot.set_qpos(qpos=initial_qpos, joint_ids=joint_ids) + + time.sleep(0.2) # Wait for a moment to ensure everything is set up + + # Enable gizmo using the new API + sim.enable_gizmo(uid="ur10_gizmo_test", control_part="arm") + if not sim.has_gizmo("ur10_gizmo_test", control_part="arm"): + logger.log_error("Failed to enable gizmo!") + return + + sim.open_window() + + logger.log_info("Gizmo-Robot example started!") + logger.log_info("Use the gizmo to drag the robot end-effector (EE)") + logger.log_info("Press Ctrl+C to stop the simulation") + + run_simulation(sim) + + +def run_simulation(sim: SimulationManager): + step_count = 0 + try: + last_time = time.time() + last_step = 0 + while True: + time.sleep(0.033) # 30Hz + # Update all gizmos managed by sim + sim.update_gizmos() + step_count += 1 + + if step_count % 100 == 0: + current_time = time.time() + elapsed = current_time - last_time + fps = ( + sim.num_envs * (step_count - last_step) / elapsed + if elapsed > 0 + else 0 + ) + logger.log_info(f"Simulation step: {step_count}, FPS: {fps:.2f}") + last_time = current_time + last_step = step_count + except KeyboardInterrupt: + logger.log_info("\nStopping simulation...") + finally: + sim.destroy() + logger.log_info("Simulation terminated successfully") + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..ae041d57 --- /dev/null +++ b/setup.py @@ -0,0 +1,123 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import glob +import logging +import os +import shutil +import sys +from os import path as osp +from pathlib import Path + +from setuptools import Command, find_packages, setup + +logging.basicConfig(stream=sys.stderr, level=logging.INFO) +logger = logging.getLogger() + +THIS_DIR = Path(__file__).resolve().parent + +# Defer importing torch until it's actually needed (when building extensions). +# This prevents `setup.py` from failing at import time in environments where +# torch isn't available or isn't on the same interpreter. +BuildExtension = None +CppExtension = None +CUDAExtension = None + + +class CleanCommand(Command): + description = "Delete build, dist, *.egg-info and all __pycache__ directories." + user_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + for d in ["build", "dist", "embodichain.egg-info"]: + rm_path = THIS_DIR / d + if not rm_path.exists(): + continue + try: + shutil.rmtree(rm_path, ignore_errors=True) + logger.info(f"removed '{rm_path}'") + except: + pass + + for pdir, sdirs, filenames in os.walk(THIS_DIR): + for sdir in sdirs: + if sdir == "__pycache__": + rm_path = Path(pdir) / sdir + shutil.rmtree(str(rm_path), ignore_errors=True) + logger.info(f"removed '{rm_path}'") + for filename in filenames: + if filename.endswith(".so"): + rm_path = Path(pdir) / filename + rm_path.unlink() + logger.info(f"removed '{rm_path}'") + + +def get_data_files_of_a_directory(source_dir, target_dir=None, ignore_py=False): + if target_dir is None: + target_dir = source_dir + + base_dir = os.sep + "embodichain" + os.sep + + filelist = [] + for parent_dir, dirnames, filenames in os.walk(source_dir): + for filename in filenames: + if ignore_py and filename.endswith(".py"): + continue + filelist.append( + ( + os.path.join( + base_dir, parent_dir.replace(source_dir, target_dir, 1) + ), + [os.path.join(parent_dir, filename)], + ) + ) + + return filelist + + +# Extract version +here = osp.abspath(osp.dirname(__file__)) +version = None +with open(os.path.join(os.path.dirname(__file__), "VERSION")) as f: + full_version = f.read().strip() + version = ".".join(full_version.split(".")[:3]) + +ignore_py = sys.argv[1] == "bdist_nuitka" if len(sys.argv) > 1 else False +data_files = [] +data_files += get_data_files_of_a_directory("embodichain", ignore_py=ignore_py) + +cmdclass = {"clean": CleanCommand} +if BuildExtension is not None: + cmdclass["build_ext"] = BuildExtension.with_options(no_python_abi_suffix=True) + +setup( + name="embodichain", + version=version, + url="https://github.com/DexForce/EmbodiChain", + author="EmbodiChain Developers", + description="An end-to-end, GPU-accelerated, and modular platform for building generalized Embodied Intelligence.", + packages=find_packages(exclude=["docs"]), + data_files=data_files, + entry_points={}, + cmdclass=cmdclass, + include_package_data=True, +) diff --git a/tests/common.py b/tests/common.py new file mode 100644 index 00000000..9d80e5c6 --- /dev/null +++ b/tests/common.py @@ -0,0 +1,63 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- +from unittest import TestLoader +from fnmatch import fnmatchcase + + +__all__ = ["UnittestMetaclass", "OrderedTestLoader"] + + +# to learn about the usage of metaclass here: https://www.liaoxuefeng.com/wiki/1016959663602400/1017592449371072 +class UnittestMetaclass(type): + def __new__(cls, name, bases, attrs): + # add 'attrs_by_writing_order' attribute containing writing order of all attributes and functions + attrs["attrs_by_writing_order"] = list(attrs.keys()) + return super().__new__(cls, name, bases, attrs) + + +# By default, TestLoader runs tests in alphabetical order. However, some tests +# need to be executed in the order they are written. This custom loader overrides +# the default sorting behavior to run tests sequentially based on the writing order. +# Note that when both errors and failures occur, errors will be logged first, +# which may differ from the execution order. This is acceptable as it prioritizes +# highlighting errors. +class OrderedTestLoader(TestLoader): + """This TestLoader will load testFnNames in the code writing order""" + + # copied from getTestCaseNames() of TestLoader and make some modification + def getTestCaseNames(self, testCaseClass): + """Return a sorted sequence of method names found within testCaseClass""" + + def shouldIncludeMethod(attrname): + if not attrname.startswith(self.testMethodPrefix): + return False + testFunc = getattr(testCaseClass, attrname) + if not callable(testFunc): + return False + fullName = f"%s.%s.%s" % ( + testCaseClass.__module__, + testCaseClass.__qualname__, + attrname, + ) + return self.testNamePatterns is None or any( + fnmatchcase(fullName, pattern) for pattern in self.testNamePatterns + ) + + testFnNames = list( + filter(shouldIncludeMethod, testCaseClass.attrs_by_writing_order) + ) + + return testFnNames diff --git a/tests/datasets/run_pourwater_env_offline.py b/tests/datasets/run_pourwater_env_offline.py new file mode 100644 index 00000000..80da0d4e --- /dev/null +++ b/tests/datasets/run_pourwater_env_offline.py @@ -0,0 +1,107 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import unittest +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent)) +from common import UnittestMetaclass, OrderedTestLoader + +import os +import tempfile +import gymnasium +from pathlib import Path + +from embodichain.utils.utility import dict2args +from embodichain.utils.utility import load_json +from embodichain.lab.sim import SimulationManagerCfg +from embodichain.lab.gym.envs import EmbodiedEnvCfg +from embodichain.lab.gym.utils.gym_utils import config_to_cfg + +os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" + + +class TestPourWaterv3OfflineRunEnv(unittest.TestCase, metaclass=UnittestMetaclass): + datacenter_backup = Path("/tmp/datacenter_test") + + def setUp(self) -> None: + pass + + def tearDown(self) -> None: + pass + + def test_offline_run_env(self): + from embodichain.lab.scripts.run_env import main + import os + + with tempfile.TemporaryDirectory(prefix=self.__class__.__name__) as temp_dir: + gym_conf_path = os.path.join( + "configs", + "gym", + "pour_water", + "gym_config.json", + ) + action_conf_path = os.path.join( + "configs", + "gym", + "pour_water", + "action_config.json", + ) + input_dict = { + "num_envs": 1, # TODO: change it to >1 as v3 supports it. but now CobotMagic use cpu-OPWSolver. Wait @Chenjian for gpu version. + "device": "cpu", # TODO: test both cpu and cuda device + "headless": True, + "enable_rt": False, + "gpu_id": 0, + "save_video": False, + "save_path": temp_dir, + "debug_mode": False, + "filter_visual_rand": False, + "online_config": "", + "gym_config": gym_conf_path, + "action_config": action_conf_path, + } + args = dict2args(input_dict) + gym_config = load_json(args.gym_config) + gym_config["env"]["dataset"]["save_path"] = temp_dir + gym_config["max_episodes"] = 1 + + cfg: EmbodiedEnvCfg = config_to_cfg(gym_config) + cfg.filter_visual_rand = args.filter_visual_rand + + action_config = {} + if args.action_config is not None: + action_config = load_json(args.action_config) + action_config["action_config"] = action_config + + cfg.num_envs = args.num_envs + cfg.sim_cfg = SimulationManagerCfg( + headless=args.headless, + sim_device=args.device, + enable_rt=args.enable_rt, + gpu_id=args.gpu_id, + ) + + env = gymnasium.make(id=gym_config["id"], cfg=cfg, **action_config) + main(args, env, gym_config) + + +if __name__ == "__main__": + # `unittest.main()` is the standard usage to start testing, here we use a customed + # TestLoader to keep executing order of functions the same as their writing order + + unittest.main(testLoader=OrderedTestLoader()) diff --git a/tests/datasets/test_configurable_action.py b/tests/datasets/test_configurable_action.py new file mode 100644 index 00000000..16593ca0 --- /dev/null +++ b/tests/datasets/test_configurable_action.py @@ -0,0 +1,246 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from embodichain.lab.gym.envs.action_bank.configurable_action import ( + ActionBank, + tag_node, + tag_edge, + get_func_tag, +) +import numpy as np +import os +from typing import Dict, Tuple, Union, List, Callable +import unittest +from embodichain.utils.utility import load_json +import inspect + +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent)) +from common import UnittestMetaclass, OrderedTestLoader + + +class FakePourwaterEnv: + def __init__(self) -> None: + pass + + +class FakePourwaterActionBank(ActionBank): + @staticmethod + @tag_node + def A(env: FakePourwaterEnv): + env.A = "A" + return True + + @staticmethod + @tag_node + def B(env: FakePourwaterEnv): + env.B = env.A + return True + + @staticmethod + @tag_node + def C(env: FakePourwaterEnv): + return True + + @staticmethod + @tag_node + def D(env: FakePourwaterEnv): + return True + + @staticmethod + @tag_node + def a(env: FakePourwaterEnv): + return True + + @staticmethod + @tag_node + def b(env: FakePourwaterEnv): + return True + + @staticmethod + @tag_node + def aa(env: FakePourwaterEnv): + return True + + @staticmethod + @tag_node + def bb(env: FakePourwaterEnv): + return True + + @staticmethod + @tag_node + def cc(env: FakePourwaterEnv): + return True + + @staticmethod + @tag_node + def dd(env: FakePourwaterEnv): + return True + + @staticmethod + @tag_edge + def init_to_pre1(env: FakePourwaterEnv, **kwargs): + return np.random.rand(6, 1) + + @staticmethod + @tag_edge + def grasp_to_move(env: FakePourwaterEnv, **kwargs): + return np.random.rand(6, 2) + + @staticmethod + @tag_edge + def move_to_rotation(env: FakePourwaterEnv, **kwargs): + env.move_to_rotation = np.random.rand(6, 3) + return env.move_to_rotation + + @staticmethod + @tag_edge + def rotation_back_to_move(env: FakePourwaterEnv, **kwargs): + return np.random.rand(6, 4) + + @staticmethod + @tag_edge + def init_to_monitor(env: FakePourwaterEnv, **kwargs): + return np.random.rand(6, 1) + + @staticmethod + @tag_edge + def left_arm_go_back(env: FakePourwaterEnv, **kwargs): + return np.random.rand(6, 2) + + @staticmethod + @tag_edge + def lopen(env: FakePourwaterEnv, **kwargs): + return np.random.rand(1, 10) + + @staticmethod + @tag_edge + def ropen(env: FakePourwaterEnv, **kwargs) -> np.ndarray: + return np.random.rand(1, 10) + + +class TestActionBank(unittest.TestCase, metaclass=UnittestMetaclass): + def setUp(self) -> None: + pass + + def tearDown(self) -> None: + pass + + def test_simple(self): + class FakeFunctions: + def __init__( + self, + ) -> None: + self.dummy_function = lambda: 1 + + def get_functions(self, names: List[str]) -> Dict[str, Callable]: + """ + Returns a dictionary of dummy functions for the given names. + """ + return {name: self.dummy_function for name in names} + + funcs = FakeFunctions() + conf = load_json(os.path.join("configs", "gym", "action_bank", "conf.json")) + action_bank = ActionBank(conf) + action_bank.parse_network( + funcs.get_functions( + ["A", "B", "C", "D", "a", "b", "aa", "bb", "cc", "dd"], + ), + funcs.get_functions( + [ + "init_to_pre1", + "grasp_to_move", + "move_to_rotation", + "rotation_back_to_move", + "move_back_to_grasp", + "grasp_back_to_pre1", + "init_to_monitor", + "left_arm_go_back", + "lopen", + "ropen", + ], + ), + vis_graph=False, + ) + + def test_hook_and_gantt(self): + conf = load_json(os.path.join("configs", "gym", "action_bank", "conf.json")) + action_bank = FakePourwaterActionBank(conf) + print(get_func_tag("node").functions[action_bank.__class__.__name__]) + _, jobs_data, jobkey2index = action_bank.parse_network( + get_func_tag("node").functions[action_bank.__class__.__name__], + get_func_tag("edge").functions[action_bank.__class__.__name__], + vis_graph=False, + ) + + action_bank.gantt(jobs_data, jobkey2index, vis=False) + + def test_create_action_list(self): + np.random.seed(0) + conf = load_json(os.path.join("configs", "gym", "action_bank", "conf.json")) + action_bank = FakePourwaterActionBank(conf) + graph_compose, jobs_data, jobkey2index = action_bank.parse_network( + get_func_tag("node").functions[action_bank.__class__.__name__], + get_func_tag("edge").functions[action_bank.__class__.__name__], + vis_graph=False, + ) + env = FakePourwaterEnv() + packages = action_bank.gantt(jobs_data, jobkey2index, vis=False) + ret = action_bank.create_action_list(env, graph_compose, packages) + + assert ( + np.linalg.norm(ret["left_arm"][:, 3:10] - ret["left_arm"][:, 3:4]) <= 1e-6 + ) # padding. + assert ( + np.linalg.norm(ret["right_arm"][:, 3:6] - env.move_to_rotation) <= 1e-6 + ) # rotation_back_to_move + + def test_bad_conf(self): + np.random.seed(0) + conf = load_json(os.path.join("configs", "gym", "action_bank", "conf.json")) + conf["node"]["right_arm"] = [ + { + "init_to_pre1": { + "src": "home_qpos", + "sink": "bottle_pre1_pose", + "duration": 1, + "kwargs": {}, + }, + "grasp_to_move": { + "src": "bottle_pre1_pose", + "sink": "bottle_grasp", + "duration": 2, + "kwargs": {}, + }, + } + ] + action_bank = FakePourwaterActionBank(conf) + self.assertRaises( + ValueError, + action_bank.parse_network, + get_func_tag("node").functions[action_bank.__class__.__name__], + get_func_tag("edge").functions[action_bank.__class__.__name__], + vis_graph=False, + ) + + +if __name__ == "__main__": + # `unittest.main()` is the standard usage to start testing, here we use a customed + # TestLoader to keep executing order of functions the same as their writing order + + unittest.main(testLoader=OrderedTestLoader()) diff --git a/tests/datasets/test_online_training.py b/tests/datasets/test_online_training.py new file mode 100644 index 00000000..1e05daa4 --- /dev/null +++ b/tests/datasets/test_online_training.py @@ -0,0 +1,102 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import unittest +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent)) +from common import UnittestMetaclass, OrderedTestLoader +from embodichain.data.data_engine.online.engine import OnlineEngine +import numpy as np +from tqdm import tqdm +import time + + +class TestDataDictExtractor(unittest.TestCase, metaclass=UnittestMetaclass): + datacenter_backup = Path("/tmp/datacenter_test") + base_url = "http://192.168.3.120/MixedAI/" + + def setUp(self) -> None: + pass + + def tearDown(self) -> None: + pass + + def test_online_generation( + self, + ): + from embodichain.utils.logger import log_warning + from embodichain.data.data_engine.online.online_generator import ( + OnlineGenerator, + ) + + log_warning("Start online data generation.") + + online_config = { + "episode_limit": 4, + "max_sample_num": 100, + "port": 5566, + "buffer_size": 4, + "max_limit_gb": 5, + } + online_callback = OnlineGenerator(**online_config) + generator_func = lambda **kwargs: [{"data": np.random.randn(1000, 1000)}] + online_callback.generator(generator_func, loop_times=2) + online_callback.empty_memory() + + def test_sample_data(self): + + from embodichain.utils.logger import log_warning + import threading + from embodichain.data.data_engine.online.online_generator import ( + OnlineGenerator, + ) + + log_warning("Start online data generation.") + + online_config = { + "episode_limit": 4, + "max_sample_num": 100, + "port": 7788, + "buffer_size": 4, + "max_limit_gb": 5, + } + online_callback = OnlineGenerator(**online_config) + data_o = np.random.randn(1000, 1000) + generator_func = lambda **kwargs: [{"data": data_o}] + + thread = threading.Thread( + target=online_callback.generator, + kwargs={"generate_func": generator_func, "loop_times": 2}, + daemon=True, + ) + thread.start() + time.sleep(1.0) + + callback = OnlineEngine(**online_config) + callback.start() + time.sleep(1.0) + for i in tqdm(range(5)): + data = callback.sample_data() + assert data.sum() == data_o.sum() + + +if __name__ == "__main__": + # `unittest.main()` is the standard usage to start testing, here we use a customed + # TestLoader to keep executing order of functions the same as their writing order + + unittest.main(testLoader=OrderedTestLoader()) diff --git a/tests/gym/envs/test_base_env.py b/tests/gym/envs/test_base_env.py new file mode 100644 index 00000000..beab6a89 --- /dev/null +++ b/tests/gym/envs/test_base_env.py @@ -0,0 +1,182 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os +import torch +import pytest +import numpy as np +import gymnasium as gym + +from embodichain.lab.gym.envs import BaseEnv, EnvCfg +from embodichain.lab.sim.objects import RigidObject, Robot +from embodichain.lab.sim.shapes import CubeCfg +from embodichain.lab.sim.cfg import ( + RobotCfg, + JointDrivePropertiesCfg, + RigidObjectCfg, + RigidBodyAttributesCfg, +) +from embodichain.lab.gym.utils.registration import register_env +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.data import get_data_path + +NUM_ENVS = 10 + + +@register_env("RandomReach-v1", max_episode_steps=100, override=True) +class RandomReachEnv(BaseEnv): + + robot_init_qpos = np.array( + [1.57079, -1.57079, 1.57079, -1.57079, -1.57079, -3.14159] + ) + + def __init__( + self, + num_envs=1, + drive_type="force", + headless=False, + device="cpu", + **kwargs, + ): + self.drive_type = drive_type + + env_cfg = EnvCfg( + sim_cfg=SimulationManagerCfg( + headless=headless, arena_space=2.0, sim_device=device + ), + num_envs=num_envs, + ) + + super().__init__( + cfg=env_cfg, + **kwargs, + ) + + def _setup_robot(self, **kwargs): + file_path = get_data_path("UniversalRobots/UR10/UR10.urdf") + + robot: Robot = self.sim.add_robot( + cfg=RobotCfg( + uid="UR10", + fpath=file_path, + init_pos=(0, 0, 1), + init_qpos=self.robot_init_qpos, + drive_pros=JointDrivePropertiesCfg(drive_type=self.drive_type), + ) + ) + + qpos_limits = robot.body_data.qpos_limits[0].cpu().numpy() + self.single_action_space = gym.spaces.Box( + low=qpos_limits[:, 0], high=qpos_limits[:, 1], dtype=np.float32 + ) + + return robot + + def _prepare_scene(self, **kwargs): + size = 0.03 + # Create a kinematic cube object without collision. + # Currently, we use this workaround for visualization purposes. + self.cube: RigidObject = self.sim.add_rigid_object( + cfg=RigidObjectCfg( + uid="cube", + shape=CubeCfg(size=[size, size, size]), + attrs=RigidBodyAttributesCfg(enable_collision=False), + init_pos=(0.0, 0.0, 0.5), + body_type="kinematic", + ), + ) + + def _update_sim_state(self, **kwargs): + pose = torch.eye(4, device=self.device) + pose = pose.unsqueeze_(0).repeat(self.num_envs, 1, 1) + pose[:, :3, 3] += torch.rand(self.num_envs, 3, device=self.device) * 0.5 - 0.25 + self.cube.set_local_pose(pose=pose) + + def _step_action(self, action): + self.robot.set_qpos(qpos=action) + return action + + def _extend_obs(self, obs, **kwargs): + obs["cube_position"] = self.cube.get_local_pose()[:, :3] + return obs + + +class BaseEnvTest: + """Shared test logic for CPU and CUDA.""" + + def setup_simulation(self, sim_device): + self.env = gym.make( + "RandomReach-v1", + num_envs=NUM_ENVS, + headless=True, + device=sim_device, + ) + + def test_env_rollout(self): + """Test environment rollout.""" + for episode in range(2): + print("Episode:", episode) + obs, info = self.env.reset() + + for i in range(2): + action = self.env.action_space.sample() + action = torch.as_tensor( + action, dtype=torch.float32, device=self.env.device + ) + + init_pose = self.env.robot_init_qpos + init_pose = ( + torch.as_tensor( + init_pose, dtype=torch.float32, device=self.env.device + ) + .unsqueeze_(0) + .repeat(self.env.num_envs, 1) + ) + action = ( + init_pose + + torch.rand_like( + action, dtype=torch.float32, device=self.env.device + ) + * 0.2 + - 0.1 + ) + + obs, reward, done, truncated, info = self.env.step(action) + + assert reward.shape == ( + self.env.num_envs, + ), f"Expected reward shape ({self.env.num_envs},), got {reward.shape}" + assert done.shape == ( + self.env.num_envs, + ), f"Expected done shape ({self.env.num_envs},), got {done.shape}" + assert truncated.shape == ( + self.env.num_envs, + ), f"Expected truncated shape ({self.env.num_envs},), got {truncated.shape}" + assert ( + obs.get("cube_position") is not None + ), "Expected 'cube_position' in the obs dict" + assert obs.get("robot") is not None, "Expected 'robot' in the obs dict" + + +class TestBaseEnvCPU(BaseEnvTest): + def setup_method(self): + self.setup_simulation("cpu") + + +@pytest.mark.skip(reason="Skipping CUDA tests temporarily") +class TestBaseEnvCUDA(BaseEnvTest): + def setup_method(self): + self.setup_simulation("cuda") diff --git a/tests/gym/envs/test_embodied_env.py b/tests/gym/envs/test_embodied_env.py new file mode 100644 index 00000000..feb4bc5b --- /dev/null +++ b/tests/gym/envs/test_embodied_env.py @@ -0,0 +1,163 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os +import torch +import pytest +import numpy as np +import gymnasium as gym + +from embodichain.lab.gym.envs import EmbodiedEnvCfg +from embodichain.lab.sim.objects import RigidObject, Robot +from embodichain.lab.gym.utils.gym_utils import config_to_cfg +from embodichain.lab.gym.utils.registration import register_env +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.data import get_data_path + +NUM_ENVS = 10 + +urdf_path = get_data_path("UniversalRobots/UR5/UR5.urdf") +METADATA = { + "id": "EmbodiedEnv-v1", + "max_episodes": 1, + "env": { + "events": { + "random_light": { + "func": "randomize_light", + "mode": "interval", + "interval_step": 10, + "params": { + "entity_cfg": {"uid": "light_1"}, + "position_range": [[-0.5, -0.5, 2], [0.5, 0.5, 2]], + "color_range": [[0.6, 0.6, 0.6], [1, 1, 1]], + "intensity_range": [500000.0, 1500000.0], + }, + } + } + }, + "sensor": [ + { + "sensor_type": "Camera", + "width": 640, + "height": 480, + "enable_mask": True, + "enable_depth": True, + "extrinsics": { + "eye": [0.0, 0.0, 1.0], + "target": [0.0, 0.0, 0.0], + }, + } + ], + "robot": { + "fpath": urdf_path, + "drive_pros": {"stiffness": {"joint[1-6]": 200.0}}, + "solver_cfg": { + "class_type": "PytorchSolver", + "end_link_name": "ee_link", + "root_link_name": "base_link", + }, + "init_pos": [0.0, 0.3, 1.0], + }, + "light": { + "direct": [ + { + "uid": "light_1", + "light_type": "point", + "color": [1.0, 1.0, 1.0], + "intensity": 1000000.0, + "init_pos": [0, 0, 2], + "radius": 10.0, + } + ] + }, + "background": [ + { + "uid": "shop_table", + "shape": { + "shape_type": "Mesh", + "fpath": "ShopTableSimple/shop_table_simple.ply", + }, + "max_convex_hull_num": 2, + "attrs": {"mass": 10.0}, + "body_scale": (2, 1.6, 1), + } + ], + "rigid_object": [ + { + "uid": "paper_cup", + "shape": { + "shape_type": "Mesh", + "fpath": "PaperCup/paper_cup.ply", + }, + "body_scale": (0.75, 0.75, 1.0), + "init_pos": (0.0, 0.0, 1.0), + } + ], + "articulation": [ + { + "uid": "sliding_box_drawer", + "fpath": "SlidingBoxDrawer/SlidingBoxDrawer.urdf", + "init_pos": (0.5, 0.0, 0.5), + } + ], +} + + +class EmbodiedEnvTest: + """Shared test logic for CPU and CUDA.""" + + def setup_simulation(self, sim_device): + cfg: EmbodiedEnvCfg = config_to_cfg(METADATA) + cfg.num_envs = NUM_ENVS + cfg.sim_cfg = SimulationManagerCfg(headless=True, sim_device=sim_device) + + self.env = gym.make(id=METADATA["id"], cfg=cfg) + + def test_env_rollout(self): + """Test environment rollout.""" + for episode in range(2): + print("Episode:", episode) + obs, info = self.env.reset() + + for i in range(2): + action = self.env.action_space.sample() + action = torch.as_tensor( + action, dtype=torch.float32, device=self.env.device + ) + + obs, reward, done, truncated, info = self.env.step(action) + + assert reward.shape == ( + self.env.num_envs, + ), f"Expected reward shape ({self.env.num_envs},), got {reward.shape}" + assert done.shape == ( + self.env.num_envs, + ), f"Expected done shape ({self.env.num_envs},), got {done.shape}" + assert truncated.shape == ( + self.env.num_envs, + ), f"Expected truncated shape ({self.env.num_envs},), got {truncated.shape}" + assert obs.get("robot") is not None, "Expected 'robot' info in the info dict" + + +class TestCPU(EmbodiedEnvTest): + def setup_method(self): + self.setup_simulation("cpu") + + +@pytest.mark.skip(reason="Skipping CUDA tests temporarily") +class TestCUDA(EmbodiedEnvTest): + def setup_method(self): + self.setup_simulation("cuda") diff --git a/tests/sim/objects/test_articulation.py b/tests/sim/objects/test_articulation.py new file mode 100644 index 00000000..d233f8f9 --- /dev/null +++ b/tests/sim/objects/test_articulation.py @@ -0,0 +1,198 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os +from numpy import mat +import torch +import pytest + +from embodichain.lab.sim import ( + SimulationManager, + SimulationManagerCfg, + VisualMaterialCfg, +) +from embodichain.lab.sim.objects import Articulation +from embodichain.lab.sim.cfg import ArticulationCfg +from embodichain.data import get_data_path +from dexsim.types import ActorType + +ART_PATH = "AiLiMu_BoxDrawer/AiLiMu_BoxDrawer.urdf" +NUM_ARENAS = 10 + + +class BaseArticulationTest: + """Shared test logic for CPU and CUDA.""" + + def setup_simulation(self, sim_device): + config = SimulationManagerCfg(headless=True, sim_device=sim_device) + self.sim = SimulationManager(config) + self.sim.build_multiple_arenas(NUM_ARENAS) + self.sim.set_manual_update(True) + + art_path = get_data_path(ART_PATH) + assert os.path.isfile(art_path) + + cfg_dict = {"fpath": art_path} + self.art: Articulation = self.sim.add_articulation( + cfg=ArticulationCfg.from_dict(cfg_dict) + ) + + if sim_device == "cuda" and getattr(self.sim, "is_use_gpu_physics", False): + self.sim.init_gpu_physics() + + def test_local_pose_behavior(self): + """Test set_local_pose and get_local_pose: + - Drawer pose is correctly set + """ + + # Set initial poses + pose = torch.eye(4, device=self.sim.device) + pose[2, 3] = 1.0 + pose = pose.unsqueeze(0).repeat(NUM_ARENAS, 1, 1) + + self.art.set_local_pose(pose, env_ids=None) + + # --- Check poses immediately after setting + xyz = self.art.get_local_pose()[0, :3] + + expected_pos = torch.tensor( + [0.0, 0.0, 1.0], device=self.sim.device, dtype=torch.float32 + ) + assert torch.allclose( + xyz, expected_pos, atol=1e-5 + ), f"FAIL: Drawer pose not set correctly: {xyz.tolist()}" + + def test_control_api(self): + """Test control API for setting and getting joint positions.""" + # Set initial joint positions + qpos_zero = torch.zeros( + (NUM_ARENAS, self.art.dof), dtype=torch.float32, device=self.sim.device + ) + qpos = qpos_zero.clone() + qpos[:, -1] = 0.1 + + # Test setting joint positions directly. + self.art.set_qpos(qpos, env_ids=None, target=False) + target_qpos = self.art.body_data.qpos + assert torch.allclose( + target_qpos, qpos, atol=1e-5 + ), f"FAIL: Joint positions not set correctly: {target_qpos.tolist()}" + + self.art.set_qpos(qpos=qpos_zero, env_ids=None, target=False) + + # Test setting joint positions with target=True + self.art.set_qpos(qpos, env_ids=None, target=True) + self.sim.update(step=100) + target_qpos = self.art.body_data.qpos + assert torch.allclose( + target_qpos, qpos, atol=1e-5 + ), f"FAIL: Joint positions not set correctly with target=True: {target_qpos.tolist()}" + + self.art.set_qpos(qpos=qpos_zero, env_ids=None, target=False) + self.art.clear_dynamics() + + # Test setting joint forces + qf = torch.ones( + (NUM_ARENAS, self.art.dof), dtype=torch.float32, device=self.sim.device + ) + self.art.set_qf(qf, env_ids=None) + target_qf = self.art.body_data.qf + assert torch.allclose( + target_qf, qf, atol=1e-5 + ), f"FAIL: Joint forces not set correctly: {target_qf.tolist()}" + print("Applying joint forces...") + print(f"qpos before applying force: {qpos_zero.tolist()}") + print(f"qf before applying force: {qf.tolist()}") + + self.sim.update(step=100) + target_qpos = self.art.body_data.qpos + print(f"target_qpos: {target_qpos}") + print(f"qpos_zero: {qpos_zero}") + print("qpos diff:", target_qpos - qpos_zero) + # check target_qpos is greater than qpos + assert torch.any( + (target_qpos - qpos_zero).abs() > 1e-4 + ), f"FAIL: Target qpos did not change after applying force: {target_qpos.tolist()}" + + def test_set_visual_material(self): + """Test setting visual material properties.""" + # Create blue material + blue_mat = self.sim.create_visual_material( + cfg=VisualMaterialCfg(base_color=[0.0, 0.0, 1.0, 1.0]) + ) + + self.art.set_visual_material(blue_mat, link_names=["outer_box", "handle_xpos"]) + + mat_insts = self.art.get_visual_material_inst() + + assert ( + len(mat_insts) == 10 + ), f"FAIL: Expected 10 material instances, got {len(mat_insts)}" + assert ( + "outer_box" in mat_insts[0] + ), "FAIL: 'outer_box' not in material instances" + assert ( + "handle_xpos" in mat_insts[0] + ), "FAIL: 'handle_xpos' not in material instances" + assert mat_insts[0]["outer_box"].base_color == [ + 0.0, + 0.0, + 1.0, + 1.0, + ], f"FAIL: 'outer_box' base color not set correctly: {mat_insts[0]['outer_box'].base_color}" + assert mat_insts[0]["handle_xpos"].base_color == [ + 0.0, + 0.0, + 1.0, + 1.0, + ], f"FAIL: 'handle_xpos' base color not set correctly: {mat_insts[0]['handle_xpos'].base_color}" + + # TODO: Open this test will cause segfault in CI env + # def test_get_link_pose(self): + # """Test getting link poses.""" + # poses = self.art.get_link_pose(link_name="handle_xpos", to_matrix=False) + # assert poses.shape == ( + # NUM_ARENAS, + # 7, + # ), f"FAIL: Expected poses shape {(NUM_ARENAS, 7)}, got {poses.shape}" + + def test_remove_articulation(self): + """Test removing an articulation from the simulation.""" + self.sim.remove_asset(self.art.uid) + assert ( + self.art.uid not in self.sim.asset_uids + ), "FAIL: Articulation UID still present after removal" + + def teardown_method(self): + """Clean up resources after each test method.""" + self.sim.destroy() + + +class TestArticulationCPU(BaseArticulationTest): + def setup_method(self): + self.setup_simulation("cpu") + + +@pytest.mark.skip(reason="Skipping CUDA tests temporarily") +class TestArticulationCUDA(BaseArticulationTest): + def setup_method(self): + self.setup_simulation("cuda") + + +if __name__ == "__main__": + test = TestArticulationCPU() + test.setup_method() + test.test_set_visual_material() diff --git a/tests/sim/objects/test_light.py b/tests/sim/objects/test_light.py new file mode 100644 index 00000000..8f61fce3 --- /dev/null +++ b/tests/sim/objects/test_light.py @@ -0,0 +1,155 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import pytest +import torch +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.cfg import LightCfg + + +class TestLight: + def setup_method(self): + # Setup SimulationManager + config = SimulationManagerCfg(headless=True, sim_device="cpu") + self.sim = SimulationManager(config) + self.sim.build_multiple_arenas(10) + # sim.set_manual_update(True) + # Create batch of lights + cfg_dict = { + "light_type": "point", + "color": [0.1, 0.1, 0.1], + "radius": 10.0, + "position": [0.0, 0.0, 2.0], + "uid": "point_light", + } + self.light = self.sim.add_light(cfg=LightCfg.from_dict(cfg_dict)) + + def test_set_color_with_env_ids(self): + """Test set_color with and without env_ids.""" + base_color = torch.tensor([0.1, 0.1, 0.1], device=self.sim.device) + + # Set for all environments + try: + self.light.set_color(base_color) + except Exception as e: + pytest.fail(f"Failed to set color for all envs: {e}") + + # Set for specific envs + env_ids = [1, 3, 5] + new_color = torch.tensor([0.9, 0.8, 0.7], device=self.sim.device) + try: + self.light.set_color(new_color, env_ids=env_ids) + except Exception as e: + pytest.fail(f"Failed to set color for env_ids={env_ids}: {e}") + + def test_set_falloff_with_env_ids(self): + """Test set_falloff with and without env_ids.""" + base_val = torch.tensor(100.0, device=self.sim.device) + + # Set for all + try: + self.light.set_falloff(base_val) + except Exception as e: + pytest.fail(f"Failed to set falloff for all envs: {e}") + + env_ids = [0, 7, 9] + new_vals = torch.tensor([200.0, 300.0, 400.0], device=self.sim.device) + try: + self.light.set_falloff(new_vals, env_ids=env_ids) + except Exception as e: + pytest.fail(f"Failed to set falloff for env_ids={env_ids}: {e}") + + def test_set_and_get_local_pose_matrix_and_vector(self): + """ + Test setting and getting local pose in both matrix and vector forms. + + 1. Set all lights to identity pose (4x4 matrix) + 2. Overwrite subset of lights (env_ids) with custom pose + 3. Check both vector and matrix results match expectations + """ + + # ---------------------------- + # 1. Set all lights to identity matrix + # ---------------------------- + pose_matrix = torch.eye(4, device=self.sim.device) + try: + self.light.set_local_pose(pose_matrix, to_matrix=True) + except Exception as e: + pytest.fail(f"Failed to set pose matrix for all envs: {e}") + + result_matrix = self.light.get_local_pose(to_matrix=True) + assert result_matrix.shape == ( + 10, + 4, + 4, + ), "Unexpected shape from get_local_pose(to_matrix=True)" + for i, mat in enumerate(result_matrix): + assert torch.allclose( + mat, pose_matrix, atol=1e-5 + ), f"Initial matrix pose mismatch at env {i}" + + # ---------------------------- + # 2. Set translation via matrix for selected env_ids + # ---------------------------- + env_ids = [2, 4, 6] + pose_matrix_2 = ( + torch.eye(4, device=self.sim.device).unsqueeze(0).repeat(len(env_ids), 1, 1) + ) + pose_matrix_2[:, 0, 3] = 1.0 + pose_matrix_2[:, 1, 3] = 2.0 + pose_matrix_2[:, 2, 3] = 3.0 + + try: + self.light.set_local_pose(pose_matrix_2, env_ids=env_ids, to_matrix=True) + except Exception as e: + pytest.fail(f"Failed to set pose matrix for env_ids={env_ids}: {e}") + + # ---------------------------- + # 3. Check vector form after env_ids modification + # ---------------------------- + result_vec = self.light.get_local_pose(to_matrix=False) + assert result_vec.shape == ( + 10, + 3, + ), "Unexpected shape from get_local_pose(to_matrix=False)" + + for i in range(10): + expected = ( + torch.tensor([1.0, 2.0, 3.0], device=self.sim.device) + if i in env_ids + else torch.tensor([0.0, 0.0, 0.0], device=self.sim.device) + ) + assert torch.allclose( + result_vec[i], expected, atol=1e-5 + ), f"Translation vector mismatch at env {i}" + + # ---------------------------- + # 4. Verify matrix form translation field + # ---------------------------- + result_matrix = self.light.get_local_pose(to_matrix=True) + for i in range(10): + expected = ( + torch.tensor([1.0, 2.0, 3.0], device=self.sim.device) + if i in env_ids + else torch.tensor([0.0, 0.0, 0.0], device=self.sim.device) + ) + assert torch.allclose( + result_matrix[i][:3, 3], expected, atol=1e-5 + ), f"Translation matrix mismatch at env {i}" + + def teardown_method(self): + """Clean up resources after each test method.""" + self.sim.destroy() diff --git a/tests/sim/objects/test_rigid_object.py b/tests/sim/objects/test_rigid_object.py new file mode 100644 index 00000000..f57f8a91 --- /dev/null +++ b/tests/sim/objects/test_rigid_object.py @@ -0,0 +1,280 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os +import torch +import pytest + +from embodichain.lab.sim import ( + SimulationManager, + SimulationManagerCfg, + VisualMaterialCfg, +) +from embodichain.lab.sim.objects import RigidObject +from embodichain.lab.sim.cfg import RigidObjectCfg +from embodichain.lab.sim.shapes import MeshCfg +from embodichain.data import get_data_path +from dexsim.types import ActorType + +DUCK_PATH = "ToyDuck/toy_duck.glb" +TABLE_PATH = "ShopTableSimple/shop_table_simple.ply" +CHAIR_PATH = "Chair/chair.glb" +NUM_ARENAS = 2 +Z_TRANSLATION = 2.0 + + +class BaseRigidObjectTest: + """Shared test logic for CPU and CUDA.""" + + def setup_simulation(self, sim_device): + config = SimulationManagerCfg(headless=True, sim_device=sim_device) + self.sim = SimulationManager(config) + self.sim.build_multiple_arenas(NUM_ARENAS) + self.sim.set_manual_update(True) + + duck_path = get_data_path(DUCK_PATH) + assert os.path.isfile(duck_path) + table_path = get_data_path(TABLE_PATH) + assert os.path.isfile(table_path) + chair_path = get_data_path(CHAIR_PATH) + assert os.path.isfile(chair_path) + + cfg_dict = { + "uid": "duck", + "shape": { + "shape_type": "Mesh", + "fpath": duck_path, + }, + "body_type": "dynamic", + } + self.duck: RigidObject = self.sim.add_rigid_object( + cfg=RigidObjectCfg.from_dict(cfg_dict), + ) + self.table: RigidObject = self.sim.add_rigid_object( + cfg=RigidObjectCfg( + uid="table", shape=MeshCfg(fpath=table_path), body_type="static" + ), + ) + self.chair: RigidObject = self.sim.add_rigid_object( + cfg=RigidObjectCfg( + uid="chair", shape=MeshCfg(fpath=chair_path), body_type="kinematic" + ), + ) + + if sim_device == "cuda" and getattr(self.sim, "is_use_gpu_physics", False): + self.sim.init_gpu_physics() + + self.sim.enable_physics(True) + + def test_is_static(self): + """Test the is_static() method of duck, table, and chair objects.""" + assert not self.duck.is_static, "Duck should be dynamic but is marked static" + assert self.table.is_static, "Table should be static but is marked dynamic" + assert ( + not self.chair.is_static + ), "Chair should be kinematic but is marked static" + + def test_local_pose_behavior(self): + """Test set_local_pose and get_local_pose: + - duck pose is correctly set + - duck falls after physics update + - table stays in place throughout + - chair is kinematic and does not move + """ + + # Set initial poses + pose_duck = torch.eye(4, device=self.sim.device) + pose_duck[2, 3] = Z_TRANSLATION + pose_duck = pose_duck.unsqueeze(0).repeat(NUM_ARENAS, 1, 1) + + pose_table = torch.eye(4, device=self.sim.device) + pose_table = pose_table.unsqueeze(0).repeat(NUM_ARENAS, 1, 1) + + pose_chair = torch.eye(4, device=self.sim.device) + pose_chair[0, 3] = 1.0 + pose_chair[1, 3] = 2.0 + pose_chair = pose_chair.unsqueeze(0).repeat(NUM_ARENAS, 1, 1) + + self.duck.set_local_pose(pose_duck) + self.table.set_local_pose(pose_table) + self.chair.set_local_pose(pose_chair) + + # --- Check poses immediately after setting + duck_xyz = self.duck.get_local_pose()[0, :3] + table_xyz = self.table.get_local_pose()[0, :3] + chair_xyz = self.chair.get_local_pose()[0, :3] + + expected_duck_pos = torch.tensor( + [0.0, 0.0, Z_TRANSLATION], device=self.sim.device, dtype=torch.float32 + ) + expected_table_pos = torch.tensor( + [0.0, 0.0, 0.0], device=self.sim.device, dtype=torch.float32 + ) + expected_chair_pos = torch.tensor( + [1.0, 2.0, 0.0], device=self.sim.device, dtype=torch.float32 + ) + + assert torch.allclose( + duck_xyz, expected_duck_pos, atol=1e-5 + ), f"FAIL: Duck pose not set correctly: {duck_xyz.tolist()}" + assert torch.allclose( + table_xyz, expected_table_pos, atol=1e-5 + ), f"FAIL: Table pose not set correctly: {table_xyz.tolist()}" + assert torch.allclose( + chair_xyz, expected_chair_pos, atol=1e-5 + ), f"FAIL: Chair pose not set correctly: {chair_xyz.tolist()}" + + # --- Step simulation + for _ in range(10): + self.sim.update(0.01) + + # --- Post-update checks + duck_z_after = self.duck.get_local_pose()[0, 2].item() + table_xyz_after = self.table.get_local_pose()[0, :3].tolist() + chair_xyz_after = self.chair.get_local_pose()[0, :3] + + assert ( + duck_z_after < Z_TRANSLATION + ), f"FAIL: Duck did not fall: z = {duck_z_after:.3f}" + assert all( + abs(x) < 1e-5 for x in table_xyz_after + ), f"FAIL: Table moved unexpectedly: {table_xyz_after}" + assert torch.allclose( + chair_xyz_after, expected_chair_pos, atol=1e-5 + ), f"FAIL: Chair pose changed unexpectedly: {chair_xyz_after.tolist()}" + + def test_add_force_torque(self): + """Test that add_force applies force correctly to the duck object.""" + + pose_before = self.duck.get_local_pose() + + force = ( + torch.tensor([10.0, 0.0, 0], device=self.sim.device) + .unsqueeze(0) + .repeat(NUM_ARENAS, 1) + ) + self.duck.add_force_torque(force) + + # Update simulation to apply the force + self.sim.update(0.01) + + # Check if the duck's z position has changed + pose_after = self.duck.get_local_pose() + assert not torch.allclose( + pose_before, pose_after + ), "FAIL: Duck pose did not change after applying force" + + pose_before = self.duck.get_local_pose() + torque = ( + torch.tensor([0.0, 10.0, 0.0], device=self.sim.device) + .unsqueeze(0) + .repeat(NUM_ARENAS, 1) + ) + self.duck.add_force_torque(None, torque=torque) + + # Update simulation to apply the torque + self.sim.update(0.01) + + pose_after = self.duck.get_local_pose() + assert not torch.allclose( + pose_before, pose_after + ), "FAIL: Duck pose did not change after applying torque" + + # Test clear dynamics + self.duck.clear_dynamics() + + def test_set_visual_material(self): + """Test that set_material correctly assigns the material to the duck.""" + + # Create blue material + blue_mat = self.sim.create_visual_material( + cfg=VisualMaterialCfg(base_color=[0.0, 0.0, 1.0, 1.0]) + ) + + # Set it to the duck + self.duck.set_visual_material(blue_mat) + + # # # Get material instances + material_list = self.duck.get_visual_material_inst() + + # # Check correctness + assert isinstance(material_list, list), "get_material() did not return a list" + assert ( + len(material_list) == NUM_ARENAS + ), f"Expected {NUM_ARENAS} materials, got {len(material_list)}" + for mat_inst in material_list: + assert mat_inst.base_color == [ + 0.0, + 0.0, + 1.0, + 1.0, + ], f"Material base color incorrect: {mat_inst.base_color}" + + def test_add_cube(self): + cfg_dict = { + "uid": "cube", + "shape": { + "shape_type": "Cube", + }, + "body_type": "dynamic", + } + cube: RigidObject = self.sim.add_rigid_object( + cfg=RigidObjectCfg.from_dict(cfg_dict), + ) + + def test_add_sphere(self): + cfg_dict = { + "uid": "sphere", + "shape": { + "shape_type": "Sphere", + }, + "body_type": "dynamic", + } + sphere: RigidObject = self.sim.add_rigid_object( + cfg=RigidObjectCfg.from_dict(cfg_dict), + ) + + def test_remove(self): + self.sim.remove_asset(self.duck.uid) + + assert ( + self.duck.uid not in self.sim.asset_uids + ), "Duck UID still present after removal" + + def teardown_method(self): + """Clean up resources after each test method.""" + self.sim.destroy() + + +class TestRigidObjectCPU(BaseRigidObjectTest): + def setup_method(self): + self.setup_simulation("cpu") + + +@pytest.mark.skip(reason="Skipping CUDA tests temporarily") +class TestRigidObjectCUDA(BaseRigidObjectTest): + def setup_method(self): + self.setup_simulation("cuda") + + +if __name__ == "__main__": + # pytest.main(["-s", __file__]) + test = TestRigidObjectCPU() + test.setup_method() + test.test_set_visual_material() + from IPython import embed + + embed() diff --git a/tests/sim/objects/test_rigid_object_group.py b/tests/sim/objects/test_rigid_object_group.py new file mode 100644 index 00000000..7ca12fb3 --- /dev/null +++ b/tests/sim/objects/test_rigid_object_group.py @@ -0,0 +1,132 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os +import torch +import pytest + +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.objects import RigidObjectGroup +from embodichain.lab.sim.cfg import RigidObjectGroupCfg, RigidObjectCfg +from embodichain.lab.sim.shapes import MeshCfg +from embodichain.data import get_data_path +from dexsim.types import ActorType + +DUCK_PATH = "ToyDuck/toy_duck.glb" +TABLE_PATH = "ShopTableSimple/shop_table_simple.ply" +NUM_ARENAS = 4 +Z_TRANSLATION = 2.0 + + +class BaseRigidObjectGroupTest: + """Shared test logic for CPU and CUDA.""" + + def setup_simulation(self, sim_device): + config = SimulationManagerCfg(headless=True, sim_device=sim_device) + self.sim = SimulationManager(config) + self.sim.build_multiple_arenas(NUM_ARENAS) + self.sim.set_manual_update(True) + + duck_path = get_data_path(DUCK_PATH) + assert os.path.isfile(duck_path) + table_path = get_data_path(TABLE_PATH) + assert os.path.isfile(table_path) + + cfg_dict = { + "uid": "group", + "rigid_objects": { + "duck1": { + "shape": { + "shape_type": "Mesh", + "fpath": duck_path, + }, + }, + "duck2": { + "shape": { + "shape_type": "Mesh", + "fpath": duck_path, + }, + }, + }, + } + self.obj_group: RigidObjectGroup = self.sim.add_rigid_object_group( + cfg=RigidObjectGroupCfg.from_dict(cfg_dict) + ) + + if sim_device == "cuda" and self.sim.is_use_gpu_physics: + self.sim.init_gpu_physics() + + self.sim.enable_physics(True) + + def test_local_pose_behavior(self): + + # Set initial poses + pose_duck1 = torch.eye(4, device=self.sim.device) + pose_duck1[2, 3] = Z_TRANSLATION + pose_duck1 = pose_duck1.unsqueeze(0).repeat(NUM_ARENAS, 1, 1) + + pose_duck2 = torch.eye(4, device=self.sim.device) + pose_duck2[2, 3] = Z_TRANSLATION + pose_duck2 = pose_duck2.unsqueeze(0).repeat(NUM_ARENAS, 1, 1) + + combined_pose = torch.stack([pose_duck1, pose_duck2], dim=1) + + self.obj_group.set_local_pose(combined_pose) + group_pos = self.obj_group.get_local_pose()[..., :3] + assert torch.allclose( + group_pos, + combined_pose[..., :3, 3], + atol=1e-5, + ), "FAIL: Local poses do not match after setting." + + def test_get_user_ids(self): + """Test get_user_ids method.""" + user_ids = self.obj_group.get_user_ids() + + assert user_ids.shape == (NUM_ARENAS, self.obj_group.num_objects), ( + f"Unexpected user_ids shape: {user_ids.shape}, " + f"expected ({NUM_ARENAS}, {self.obj_group.num_objects})" + ) + + def test_remove(self): + self.sim.remove_asset(self.obj_group.uid) + + assert ( + self.obj_group.uid not in self.sim.asset_uids + ), "Object group UID still present after removal" + + def teardown_method(self): + """Clean up resources after each test method.""" + self.sim.destroy() + + +class TestRigidObjectGroupCPU(BaseRigidObjectGroupTest): + def setup_method(self): + self.setup_simulation("cpu") + + +# TODO: Fix CUDA tests issue. +@pytest.mark.skip(reason="Skipping CUDA tests temporarily") +class TestRigidObjectGroupCUDA(BaseRigidObjectGroupTest): + def setup_method(self): + self.setup_simulation("cuda") + + +if __name__ == "__main__": + # pytest.main(["-s", __file__]) + test = TestRigidObjectGroupCPU() + test.setup_method() + test.test_local_pose_behavior() diff --git a/tests/sim/objects/test_robot.py b/tests/sim/objects/test_robot.py new file mode 100644 index 00000000..78a5ac3a --- /dev/null +++ b/tests/sim/objects/test_robot.py @@ -0,0 +1,264 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os +import torch +import pytest +import numpy as np + +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.objects import Robot +from embodichain.lab.sim.robots.dexforce_w1 import DexforceW1Cfg +from embodichain.data import get_data_path + + +# Define control parts +CONTROL_PARTS = { + "left_arm": [ + "LEFT_J1", + "LEFT_J2", + "LEFT_J3", + "LEFT_J4", + "LEFT_J5", + "LEFT_J6", + "LEFT_J7", + ], + "right_arm": [ + "RIGHT_J1", + "RIGHT_J2", + "RIGHT_J3", + "RIGHT_J4", + "RIGHT_J5", + "RIGHT_J6", + "RIGHT_J7", + ], +} + +# Base test class for CPU and CUDA +class BaseRobotTest: + def setup_simulation(self, sim_device): + # Set up simulation with specified device (CPU or CUDA) + config = SimulationManagerCfg(headless=True, sim_device=sim_device) + self.sim = SimulationManager(config) + self.sim.build_multiple_arenas(10) # NUM_ARENAS = 10 + self.sim.set_manual_update(True) + + cfg = DexforceW1Cfg.from_dict( + { + "uid": "dexforce_w1", + "version": "v021", + "arm_kind": "anthropomorphic", + } + ) + + self.robot: Robot = self.sim.add_robot(cfg=cfg) + + # Initialize GPU physics if needed + if sim_device == "cuda" and getattr(self.sim, "is_use_gpu_physics", False): + self.sim.init_gpu_physics() + + def test_get_joint_ids(self): + left_joint_ids = self.robot.get_joint_ids("left_arm") + right_joint_ids = self.robot.get_joint_ids("right_arm") + + assert left_joint_ids == [ + 6, + 8, + 10, + 12, + 14, + 16, + 18, + ], f"Unexpected left arm joint IDs: {left_joint_ids}" + assert right_joint_ids == [ + 7, + 9, + 11, + 13, + 15, + 17, + 19, + ], f"Unexpected right arm joint IDs: {right_joint_ids}" + + @pytest.mark.parametrize("arm_name", ["left_arm", "right_arm"]) + def test_fk(self, arm_name: str): + # Test forward kinematics (FK) for both to_matrix=True and to_matrix=False + + qpos = torch.randn(10, 7, device=self.sim.device) # Random joint positions + + # Test with to_matrix=False (6D result: translation + Euler angles) + result_7d = self.robot.compute_fk(qpos=qpos, name=arm_name, to_matrix=False) + + # Check result shape for 6D output (batch, 6) + assert result_7d.shape == ( + 10, + 7, + ), f"Expected shape (10, 7), got {result_7d.shape}" + + # Test with to_matrix=True (4x4 matrix result) + result_matrix = self.robot.compute_fk(qpos=qpos, name=arm_name, to_matrix=True) + print("result_matrix:", result_matrix) + # Check result shape for matrix output (batch, 4, 4) + assert result_matrix.shape == ( + 10, + 4, + 4, + ), f"Expected shape (10, 4, 4), got {result_matrix.shape}" + + def test_compute_fk(self): + torch.set_printoptions(precision=6, sci_mode=False) + qpos = np.zeros(40) + result = self.robot.compute_fk(qpos=qpos, link_names=["left_ee", "right_ee"]) + + # Additional checks for specific values (if known) + expected_values = torch.tensor( + [ + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.791], + [0.0, -1.0, 0.0, 1.3648], + [0.0, 0.0, 0.0, 1.0], + ], + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, -1.0, -0.791], + [0.0, 1.0, 0.0, 1.3648], + [0.0, 0.0, 0.0, 1.0], + ], + ], + dtype=torch.float32, + ).unsqueeze_(0) + + assert torch.allclose( + result, expected_values, atol=1e-4, rtol=1e-4 + ), f"FK result does not match expected values. Got {result}, expected {expected_values}." + + def test_compute_jacobian(self): + qpos = np.full(7, 10 * np.pi / 180) + + left_ee_jacobian = self.robot.compute_jacobian( + qpos=qpos, end_link_name="left_ee", root_link_name="left_arm_base" + ) + right_ee_jacobian = self.robot.compute_jacobian( + qpos=qpos, end_link_name="right_ee", root_link_name="right_arm_base" + ) + + assert left_ee_jacobian.shape == ( + 1, + 6, + 7, + ), f"Expected shape (1, 6, 7) for left EE Jacobian, got {left_ee_jacobian.shape}" + assert right_ee_jacobian.shape == ( + 1, + 6, + 7, + ), f"Expected shape (1, 6, 7) for right EE Jacobian, got {right_ee_jacobian.shape}" + + @pytest.mark.parametrize("arm_name", ["left_arm", "right_arm"]) + def test_ik(self, arm_name: str): + # Test inverse kinematics (IK) with a 1x4x4 homogeneous matrix pose and a joint_seed + + # Define a sample target pose as a 1x4x4 homogeneous matrix + target_pose = torch.tensor( + [ + [-0.3490, -0.6369, -0.6874, -0.4502], + [0.2168, -0.7685, 0.6020, -0.0639], + [-0.9117, 0.0611, 0.4063, 0.3361], + [0.0000, 0.0000, 0.0000, 1.0000], + ], + dtype=torch.float32, + device=self.sim.device, + ).unsqueeze(0) + + # Define joint_seed as a tensor of ones with shape (1, 7) for initialization + joint_seed = torch.ones(1, 7, device=self.sim.device) + success_tensor, qpos_tensor = self.robot.compute_ik( + pose=target_pose, name=arm_name, joint_seed=joint_seed, env_ids=[0] + ) + print(f"Success: {success_tensor}, Qpos: {qpos_tensor}") + + # Check output shapes robustly + assert success_tensor.shape == ( + 1, + ), f"Expected shape (1,), got {success_tensor.shape}" + assert isinstance( + qpos_tensor, torch.Tensor + ), "qpos_tensor should be a torch.Tensor" + # Accept both (1, 7) and (1, N, 7) shapes + if qpos_tensor.ndim == 2: + assert qpos_tensor.shape == ( + 1, + 7, + ), f"Expected shape (1, 7), got {qpos_tensor.shape}" + elif qpos_tensor.ndim == 3: + assert ( + qpos_tensor.shape[2] == 7 + ), f"Expected dof 7, got {qpos_tensor.shape[2]}" + assert ( + qpos_tensor.shape[0] == 1 + ), f"Expected batch size 1, got {qpos_tensor.shape[0]}" + assert ( + qpos_tensor.shape[1] >= 1 + ), f"Expected at least one solution, got {qpos_tensor.shape[1]}" + else: + raise AssertionError(f"Unexpected qpos_tensor shape: {qpos_tensor.shape}") + + # If success, check qpos is not all zeros + if success_tensor.item(): + assert not torch.all( + qpos_tensor == 0 + ), "IK returned all zeros for valid solution" + + def test_mimic(self): + + assert ( + len(self.robot.mimic_ids) == 8 + ), f"Expected 8 mimic IDs, got {len(self.robot.mimic_ids)}" + + left_eef_ids_without_mimic = self.robot.get_joint_ids( + "left_eef", remove_mimic=False + ) + right_eef_ids_without_mimic = self.robot.get_joint_ids( + "right_eef", remove_mimic=False + ) + assert ( + len(left_eef_ids_without_mimic) == 6 + ), f"Expected 6 left eef joint IDs without mimic, got {len(left_eef_ids_without_mimic)}" + assert ( + len(right_eef_ids_without_mimic) == 6 + ), f"Expected 6 right eef joint IDs without mimic, got {len(right_eef_ids_without_mimic)}" + + def teardown_method(self): + """Clean up resources after each test method.""" + self.sim.destroy() + + +class TestRobotCPU(BaseRobotTest): + def setup_method(self): + self.setup_simulation("cpu") + + +@pytest.mark.skip(reason="Skipping CUDA tests temporarily") +class TestRobotCUDA(BaseRobotTest): + def setup_method(self): + self.setup_simulation("cuda") + + +if __name__ == "__main__": + # Run tests directly + test_cpu = TestRobotCPU() + test_cpu.setup_method() + test_cpu.test_fk("left_arm") diff --git a/tests/sim/objects/test_soft_object.py b/tests/sim/objects/test_soft_object.py new file mode 100644 index 00000000..9cbc31af --- /dev/null +++ b/tests/sim/objects/test_soft_object.py @@ -0,0 +1,105 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os +from dexsim.utility.path import get_resources_data_path +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.cfg import ( + SoftbodyVoxelAttributesCfg, + SoftbodyPhysicalAttributesCfg, +) +from embodichain.lab.sim.shapes import MeshCfg +from embodichain.lab.sim.objects import ( + SoftObject, + SoftObjectCfg, +) +import pytest + +COW_PATH = get_resources_data_path("Model", "cow", "cow.obj") + + +class BaseSoftObjectTest: + def setup_simulation(self): + sim_cfg = SimulationManagerCfg( + width=1920, + height=1080, + headless=True, + physics_dt=1.0 / 100.0, # Physics timestep (100 Hz) + sim_device="cuda", + enable_rt=False, # Enable ray tracing for better visuals + ) + + # Create the simulation instance + self.sim = SimulationManager(sim_cfg) + + assert os.path.isfile(COW_PATH) + + # Enable manual physics update for precise control + self.sim.set_manual_update(True) + self.n_envs = 4 + # Build multiple arenas if requested + self.sim.build_multiple_arenas(self.n_envs, space=3.0) + # add softbody to the scene + self.cow: SoftObject = self.sim.add_soft_object( + cfg=SoftObjectCfg( + uid="cow", + shape=MeshCfg( + fpath=get_resources_data_path("Model", "cow", "cow.obj"), + ), + init_pos=[0.0, 0.0, 3.0], + voxel_attr=SoftbodyVoxelAttributesCfg( + simulation_mesh_resolution=8, + maximal_edge_length=0.5, + ), + physical_attr=SoftbodyPhysicalAttributesCfg( + youngs=1e6, + poissons=0.45, + density=100, + dynamic_friction=0.1, + min_position_iters=30, + ), + ), + ) + + def test_run_simulation(self): + self.sim.init_gpu_physics() + for _ in range(100): + self.sim.update(step=1) + self.cow.reset() + for _ in range(100): + self.sim.update(step=1) + + def test_remove(self): + self.sim.remove_asset(self.cow.uid) + assert ( + self.cow.uid not in self.sim._soft_objects + ), "Cow UID still present after removal" + + def teardown_method(self): + """Clean up resources after each test method.""" + self.sim.destroy() + + +@pytest.mark.skip(reason="Skipping SoftObject test now") +class TestSoftObjectCUDA(BaseSoftObjectTest): + def setup_method(self): + self.setup_simulation() + + +if __name__ == "__main__": + test = TestSoftObjectCUDA() + test.setup_method() + test.test_run_simulation() diff --git a/tests/sim/planners/test_motion_generator.py b/tests/sim/planners/test_motion_generator.py new file mode 100644 index 00000000..67e17dcf --- /dev/null +++ b/tests/sim/planners/test_motion_generator.py @@ -0,0 +1,251 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- +import time +import torch +import pytest +import numpy as np +from copy import deepcopy +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.objects import Robot +from embodichain.lab.sim.robots import CobotMagicCfg + +from embodichain.lab.sim.planners.utils import TrajectorySampleMethod +from embodichain.lab.sim.planners.motion_generator import MotionGenerator + + +def to_numpy(tensor): + if isinstance(tensor, torch.Tensor): + tensor = tensor.detach().cpu() + if tensor.ndim == 3 and tensor.shape[0] == 1: + tensor = tensor[0] + return tensor.numpy() + return np.array(tensor) + + +class BaseTestMotionGenerator(object): + @classmethod + def setup_class(cls): + cls.config = SimulationManagerCfg(headless=True, sim_device="cpu") + cls.robot_sim = SimulationManager(cls.config) + cls.robot_sim.build_multiple_arenas(1) + cls.robot_sim.set_manual_update(False) + + cfg_dict = { + "uid": "CobotMagic", + "init_pos": [0.0, 0.0, 0.7775], + "init_qpos": [ + -0.3, + 0.3, + 1.0, + 1.0, + -1.2, + -1.2, + 0.0, + 0.0, + 0.6, + 0.6, + 0.0, + 0.0, + 0.05, + 0.05, + 0.05, + 0.05, + ], + "solver_cfg": { + "left_arm": { + "class_type": "OPWSolver", + "end_link_name": "left_link6", + "root_link_name": "left_arm_base", + "tcp": [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0.143], [0, 0, 0, 1]], + }, + "right_arm": { + "class_type": "OPWSolver", + "end_link_name": "right_link6", + "root_link_name": "right_arm_base", + "tcp": [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0.143], [0, 0, 0, 1]], + }, + }, + } + + cls.robot: Robot = cls.robot_sim.add_robot( + cfg=CobotMagicCfg.from_dict(cfg_dict) + ) + + cls.arm_name = "left_arm" + + cls.motion_gen = MotionGenerator( + robot=cls.robot, + uid=cls.arm_name, + planner_type="toppra", + default_velocity=0.2, + default_acceleration=0.5, + ) + + # Test data for trajectory generation + qpos_fk = torch.tensor( + [[0.0, np.pi / 4, -np.pi / 4, 0.0, np.pi / 4, 0.0]], dtype=torch.float32 + ) + xpos_begin = cls.robot.compute_fk( + name=cls.arm_name, qpos=qpos_fk, to_matrix=True + ) + xpos_mid = deepcopy(xpos_begin) + xpos_mid[0, 2, 3] -= 0.1 # Move down by 0.1m in Z direction + xpos_final = deepcopy(xpos_mid) + xpos_final[0, 0, 3] += 0.2 # Move forward by 0.2m in X direction + + qpos_begin = cls.robot.compute_ik(pose=xpos_begin, name=cls.arm_name)[1][0] + qpos_mid = cls.robot.compute_ik(pose=xpos_mid, name=cls.arm_name)[1][0] + qpos_final = cls.robot.compute_ik(pose=xpos_final, name=cls.arm_name)[1][0] + + cls.qpos_list = [qpos_begin, qpos_mid, qpos_final] + cls.xpos_list = [ + xpos_begin[0].numpy(), + xpos_mid[0].numpy(), + xpos_final[0].numpy(), + ] + + cls.sample_num = 20 + + def get_joint_ids(self): + return self.robot.get_joint_ids(self.arm_name) + + def get_current_qpos(self): + qpos_tensor = self.robot.get_qpos() + if qpos_tensor.ndim == 2 and qpos_tensor.shape[0] == 1: + qpos_tensor = qpos_tensor[0] + return qpos_tensor[self.get_joint_ids()].cpu() + + def verify_final_xpos(self, expected_xpos, decimal=5e-3): + final_xpos = self.robot.compute_fk( + qpos=self.get_current_qpos(), name=self.arm_name, to_matrix=True + ) + np.testing.assert_array_almost_equal( + to_numpy(final_xpos)[:3, 3], + to_numpy(expected_xpos)[:3, 3], + decimal=decimal, + err_msg=f"Expected: {to_numpy(expected_xpos)[:3, 3]}, Got: {to_numpy(final_xpos)[:3, 3]}", + ) + + def _execute_trajectory(self, qpos_list, forward=True, delay=0.01): + indices = ( + range(len(qpos_list)) if forward else range(len(qpos_list) - 1, -1, -1) + ) + for i in indices: + self.robot.set_qpos(qpos=qpos_list[i], joint_ids=self.get_joint_ids()) + time.sleep(delay) + time.sleep(delay * 2) + + @classmethod + def teardown_class(cls): + try: + cls.robot_sim.destroy() + print("robot_sim destroyed successfully") + except Exception as e: + print(f"Error during robot_sim.destroy(): {e}") + + def _execute_forward_trajectory(self, robot, qpos_list, delay=0.1): + """Helper method to execute trajectory""" + # Forward + for q in qpos_list: + robot.set_qpos(qpos=q, joint_ids=self.robot.get_joint_ids(self.arm_name)) + time.sleep(delay) + time.sleep(delay * 5) + + def _execute_backward_trajectory(self, robot, qpos_list, delay=0.1): + """Helper method to execute trajectory""" + # Backward + for q in qpos_list[::-1]: + robot.set_qpos(qpos=q, joint_ids=self.robot.get_joint_ids(self.arm_name)) + time.sleep(delay) + time.sleep(delay * 5) + + +class TestMotionGenerator(BaseTestMotionGenerator): + """Test suite for MotionGenerator trajectory generation""" + + @pytest.mark.parametrize("is_linear", [True, False]) + def test_create_trajectory_with_xpos(self, is_linear): + """Test trajectory generation with cartesian positions""" + self.robot.set_qpos(qpos=self.qpos_list[0], joint_ids=self.get_joint_ids()) + time.sleep(0.2) + out_qpos_list, out_xpos_list = self.motion_gen.create_discrete_trajectory( + xpos_list=self.xpos_list, + is_use_current_qpos=True, + sample_num=self.sample_num, + is_linear=is_linear, + sample_method=TrajectorySampleMethod.QUANTITY, + qpos_seed=self.qpos_list[0], + ) + out_qpos_list = to_numpy(out_qpos_list) + assert ( + len(out_qpos_list) == self.sample_num + ), f"Sample number mismatch: {len(out_qpos_list)} != {self.sample_num}" + np.testing.assert_array_almost_equal( + out_xpos_list[-1], self.xpos_list[-1], decimal=3 + ) + self._execute_trajectory(out_qpos_list, forward=True) + self.verify_final_xpos(self.xpos_list[-1]) + self._execute_trajectory(out_qpos_list, forward=False) + self.verify_final_xpos(self.xpos_list[0]) + + @pytest.mark.parametrize("is_linear", [True, False]) + def test_create_trajectory_with_qpos(self, is_linear): + """Test trajectory generation with joint positions""" + self.robot.set_qpos(qpos=self.qpos_list[0], joint_ids=self.get_joint_ids()) + time.sleep(0.05) + qpos_list_in = [qpos.to("cpu").numpy() for qpos in self.qpos_list] + out_qpos_list, out_xpos_list = self.motion_gen.create_discrete_trajectory( + qpos_list=qpos_list_in, + sample_num=self.sample_num, + is_linear=False, + sample_method=TrajectorySampleMethod.QUANTITY, + qpos_seed=self.qpos_list[0], + ) + out_qpos_list = to_numpy(out_qpos_list) + assert ( + len(out_qpos_list) == self.sample_num + ), f"Sample number mismatch: {len(out_qpos_list)} != {self.sample_num}" + np.testing.assert_array_almost_equal( + out_qpos_list[-1], self.qpos_list[-1], decimal=3 + ) + self._execute_trajectory(out_qpos_list, forward=True) + self.verify_final_xpos(self.xpos_list[-1]) + self._execute_trajectory(out_qpos_list, forward=False) + self.verify_final_xpos(self.xpos_list[0]) + + @pytest.mark.parametrize("xpos_or_qpos", ["xpos", "qpos"]) + def test_estimate_trajectory_sample_count(self, xpos_or_qpos: str): + """Test estimation of trajectory sample count""" + if xpos_or_qpos == "xpos": + estimated_num = self.motion_gen.estimate_trajectory_sample_count( + xpos_list=self.xpos_list, + step_size=0.01, + angle_step=np.pi / 90, + ) + else: + estimated_num = self.motion_gen.estimate_trajectory_sample_count( + qpos_list=self.qpos_list, + step_size=0.01, + angle_step=np.pi / 90, + ) + assert (estimated_num - 30) < 2, "Estimated sample count failed" + + +if __name__ == "__main__": + np.set_printoptions(precision=5, suppress=True) + torch.set_printoptions(precision=5, sci_mode=False) + pytest_args = ["-v", "-s", __file__] + pytest.main(pytest_args) diff --git a/tests/sim/sensors/test_camera.py b/tests/sim/sensors/test_camera.py new file mode 100644 index 00000000..8e578177 --- /dev/null +++ b/tests/sim/sensors/test_camera.py @@ -0,0 +1,153 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ----------------------------------------------------------------------------, + +import pytest +import torch +import os + +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.sensors import Camera, SensorCfg, CameraCfg +from embodichain.lab.sim.objects import Articulation +from embodichain.lab.sim.cfg import ArticulationCfg +from embodichain.data import get_data_path + + +NUM_ENVS = 4 +ART_PATH = "AiLiMu_BoxDrawer/AiLiMu_BoxDrawer.urdf" + + +class CameraTest: + def setup_simulation(self, sim_device, enable_rt): + # Setup SimulationManager + config = SimulationManagerCfg( + headless=True, sim_device=sim_device, enable_rt=enable_rt + ) + self.sim = SimulationManager(config) + self.sim.build_multiple_arenas(NUM_ENVS) + # Create batch of cameras + cfg_dict = { + "sensor_type": "Camera", + "width": 640, + "height": 480, + "enable_mask": True, + "enable_depth": True, + "enable_normal": True, + "enable_position": True, + } + cfg = SensorCfg.from_dict(cfg_dict) + self.camera: Camera = self.sim.add_sensor(cfg) + + def test_get_data(self): + + self.camera.update() + + # Get data from the camera + data = self.camera.get_data() + + # Check if data is a dictionary + assert isinstance(data, dict), "Camera data should be a dictionary" + + # Check if all expected keys are present + for key in self.camera.SUPPORTED_DATA_TYPES: + assert key in data, f"Missing key in camera data: {key}" + + # Check if the data shape matches the expected shape + assert data["color"].shape == (NUM_ENVS, 480, 640, 4), "RGB data shape mismatch" + assert data["depth"].shape == ( + NUM_ENVS, + 480, + 640, + ), "Depth data shape mismatch" + assert data["normal"].shape == ( + NUM_ENVS, + 480, + 640, + 3, + ), "Normal data shape mismatch" + assert data["position"].shape == ( + NUM_ENVS, + 480, + 640, + 3, + ), "Position data shape mismatch" + assert data["mask"].shape == (NUM_ENVS, 480, 640), "Mask data shape mismatch" + + # Check if the data types are correct + assert data["color"].dtype == torch.uint8, "Color data type mismatch" + assert data["depth"].dtype == torch.float32, "Depth data type mismatch" + assert data["normal"].dtype == torch.float32, "Normal data type mismatch" + assert data["position"].dtype == torch.float32, "Position data type mismatch" + assert data["mask"].dtype == torch.int32, "Mask data type mismatch" + + def test_local_pose_with_env_ids(self): + env_ids = [0, 1, 2] + + pose = ( + torch.eye(4, device=self.sim.device).unsqueeze(0).repeat(len(env_ids), 1, 1) + ) + pose[:, 2, 3] = 2.0 + + self.camera.set_local_pose(pose, env_ids=env_ids) + + # Verify the local pose for specified env_ids + assert torch.allclose(self.camera.get_local_pose(to_matrix=True)[env_ids], pose) + + def test_attach_to_parent(self): + art_path = get_data_path(ART_PATH) + assert os.path.isfile(art_path) + + cfg_dict = {"fpath": art_path} + self.art: Articulation = self.sim.add_articulation( + cfg=ArticulationCfg.from_dict(cfg_dict) + ) + self.camera: Camera = self.sim.add_sensor( + sensor_cfg=CameraCfg( + uid="test", extrinsics=CameraCfg.ExtrinsicsCfg(parent="handle_xpos") + ) + ) + + def test_set_intrinsics(self): + # Define new intrinsic parameters + new_intrinsics = ( + torch.tensor( + [500.0, 500.0, 320.0, 240.0], + device=self.sim.device, + ) + .unsqueeze(0) + .repeat(NUM_ENVS, 1) + ) + + # Set new intrinsic parameters for all environments + self.camera.set_intrinsics(new_intrinsics) + + def teardown_method(self): + """Clean up resources after each test method.""" + self.sim.destroy() + + +class TestCameraRaster(CameraTest): + def setup_method(self): + self.setup_simulation("cpu", enable_rt=False) + + +class TestCameraFastRT(CameraTest): + def setup_method(self): + self.setup_simulation("cpu", enable_rt=True) + + +if __name__ == "__main__": + test = CameraTest() + test.setup_simulation("cpu", enable_rt=False) diff --git a/tests/sim/sensors/test_stereo.py b/tests/sim/sensors/test_stereo.py new file mode 100644 index 00000000..737bae6a --- /dev/null +++ b/tests/sim/sensors/test_stereo.py @@ -0,0 +1,155 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ----------------------------------------------------------------------------, + +import pytest +import torch +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.sensors import StereoCamera, SensorCfg + + +NUM_ENVS = 4 + + +class StereoCameraTest: + def setup_simulation(self, sim_device, enable_rt): + # Setup SimulationManager + config = SimulationManagerCfg( + headless=True, sim_device=sim_device, enable_rt=enable_rt + ) + self.sim = SimulationManager(config) + self.sim.build_multiple_arenas(4) + # Create batch of cameras + cfg_dict = { + "sensor_type": "StereoCamera", + "width": 640, + "height": 480, + "enable_mask": True, + "enable_depth": True, + "enable_normal": True, + "enable_position": True, + "enable_disparity": True, + "left_to_right_pos": (0.1, 0.0, 0.0), + } + cfg = SensorCfg.from_dict(cfg_dict) + self.camera: StereoCamera = self.sim.add_sensor(cfg) + + def test_get_data(self): + + self.camera.update() + + # Get data from the camera + data = self.camera.get_data() + + # Check if data is a dictionary + assert isinstance(data, dict), "Camera data should be a dictionary" + + # Check if all expected keys are present + for key in self.camera.SUPPORTED_DATA_TYPES: + assert key in data, f"Missing key in camera data: {key}" + + # Check if the data shape matches the expected shape + assert data["color"].shape == (NUM_ENVS, 480, 640, 4), "RGB data shape mismatch" + assert data["depth"].shape == ( + NUM_ENVS, + 480, + 640, + 1, + ), "Depth data shape mismatch" + assert data["normal"].shape == ( + NUM_ENVS, + 480, + 640, + 3, + ), "Normal data shape mismatch" + assert data["position"].shape == ( + NUM_ENVS, + 480, + 640, + 3, + ), "Position data shape mismatch" + assert data["mask"].shape == (NUM_ENVS, 480, 640, 1), "Mask data shape mismatch" + assert data["disparity"].shape == ( + NUM_ENVS, + 480, + 640, + 1, + ), "Disparity data shape mismatch" + + # Check if the data types are correct + assert data["color"].dtype == torch.uint8, "Color data type mismatch" + assert data["depth"].dtype == torch.float32, "Depth data type mismatch" + assert data["normal"].dtype == torch.float32, "Normal data type mismatch" + assert data["position"].dtype == torch.float32, "Position data type mismatch" + assert data["mask"].dtype == torch.int32, "Mask data type mismatch" + assert data["disparity"].dtype == torch.float32, "Disparity data type mismatch" + + def test_local_pose_with_env_ids(self): + env_ids = [0, 1, 2] + + pose = ( + torch.eye(4, device=self.sim.device).unsqueeze(0).repeat(len(env_ids), 1, 1) + ) + pose[:, 2, 3] = 2.0 + + self.camera.set_local_pose(pose, env_ids=env_ids) + + # Verify the local pose for specified env_ids + assert torch.allclose(self.camera.get_local_pose(to_matrix=True)[env_ids], pose) + + def test_set_intrinsics(self): + # Define new intrinsic parameters + new_intrinsics = ( + torch.tensor( + [500.0, 500.0, 320.0, 240.0], + device=self.sim.device, + ) + .unsqueeze(0) + .repeat(NUM_ENVS, 1) + ) + + # Set new intrinsic parameters for all environments + self.camera.set_intrinsics(new_intrinsics) + + right_intrinsics = ( + torch.tensor( + [520.0, 520.0, 315.0, 235.0], + device=self.sim.device, + ) + .unsqueeze(0) + .repeat(NUM_ENVS, 1) + ) + + self.camera.set_intrinsics(new_intrinsics, right_intrinsics=right_intrinsics) + + new_intrinsics = torch.tensor( + [500.0, 500.0, 320.0, 240.0], + device=self.sim.device, + ) + self.camera.set_intrinsics(new_intrinsics) + + def teardown_method(self): + """Clean up resources after each test method.""" + self.sim.destroy() + + +class TestStereoCameraRaster(StereoCameraTest): + def setup_method(self): + self.setup_simulation("cpu", enable_rt=False) + + +class TestStereoCameraFastRT(StereoCameraTest): + def setup_method(self): + self.setup_simulation("cpu", enable_rt=True) diff --git a/tests/sim/solvers/test_differential_solver.py b/tests/sim/solvers/test_differential_solver.py new file mode 100644 index 00000000..c9adc24e --- /dev/null +++ b/tests/sim/solvers/test_differential_solver.py @@ -0,0 +1,142 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os +import torch +import pytest +import numpy as np + +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.objects import Robot +from embodichain.lab.sim.cfg import RobotCfg +from embodichain.data import get_data_path + + +# Base test class for differential solver +class BaseSolverTest: + sim = None # Define as a class attribute + + def setup_simulation(self, solver_type: str): + # Set up simulation with specified device (CPU or CUDA) + config = SimulationManagerCfg(headless=True, sim_device="cpu") + self.sim = SimulationManager(config) + self.sim.build_multiple_arenas(1) + self.sim.set_manual_update(True) + + # Load robot URDF file + urdf = get_data_path("DexforceW1V021/DexforceW1_v02_1.urdf") + assert os.path.isfile(urdf) + + cfg_dict = { + "fpath": urdf, + "control_parts": { + "left_arm": [f"LEFT_J{i+1}" for i in range(7)], + "right_arm": [f"RIGHT_J{i+1}" for i in range(7)], + }, + "solver_cfg": { + "left_arm": { + "class_type": solver_type, + "end_link_name": "left_ee", + "root_link_name": "left_arm_base", + }, + "right_arm": { + "class_type": solver_type, + "end_link_name": "right_ee", + "root_link_name": "right_arm_base", + }, + }, + } + + self.robot: Robot = self.sim.add_robot(cfg=RobotCfg.from_dict(cfg_dict)) + + @pytest.mark.parametrize("arm_name", ["left_arm", "right_arm"]) + def test_differential_solver(self, arm_name: str): + # Test differential solver with a 1x4x4 homogeneous matrix pose and a joint_seed + + qpos_fk = torch.tensor( + [[0.0, 0.0, 0.0, -np.pi / 2, 0.0, 0.0, 0.0]], + dtype=torch.float32, + device=self.robot.device, + ) + + fk_xpos = self.robot.compute_fk(qpos=qpos_fk, name=arm_name, to_matrix=True) + + # Define start and end poses + start_pose = fk_xpos.clone()[0] + end_pose = fk_xpos.clone()[0] + end_pose[:3, 3] += torch.tensor( + [0.0, 0.0, -0.02], dtype=torch.float32, device=self.robot.device + ) + + # Interpolate poses + num_steps = 5 + interpolated_poses = [ + torch.lerp(start_pose, end_pose, t) for t in np.linspace(0, 1, num_steps) + ] + + ik_qpos = qpos_fk + + for i, pose in enumerate(interpolated_poses): + res, ik_qpos = self.robot.compute_ik( + pose=pose, joint_seed=ik_qpos, name=arm_name + ) + assert res, f"IK failed for step {i} with pose:\n{pose}" + + # Verify forward kinematics matches the target pose + ik_xpos = self.robot.compute_fk(qpos=ik_qpos, name=arm_name, to_matrix=True) + assert torch.allclose( + pose, ik_xpos, atol=5e-3, rtol=5e-3 + ), f"FK result does not match target pose at step {i}." + + # test for failed xpos + invalid_pose = torch.tensor( + [ + [ + [1.0, 0.0, 0.0, 10.0], + [0.0, 1.0, 0.0, 10.0], + [0.0, 0.0, 1.0, 10.0], + [0.0, 0.0, 0.0, 1.0], + ] + ], + dtype=torch.float32, + device=self.robot.device, + ) + res, ik_qpos = self.robot.compute_ik( + pose=invalid_pose, joint_seed=ik_qpos, name=arm_name + ) + dof = ik_qpos.shape[-1] + assert res[0] == False + assert ik_qpos.shape == (1, dof) + + @classmethod + def teardown_class(cls): + if cls.sim is not None: + try: + cls.sim.destroy() + print("sim destroyed successfully") + except Exception as e: + print(f"Error during sim.destroy(): {e}") + + +class TestDifferentialSolver(BaseSolverTest): + def setup_method(self): + self.setup_simulation(solver_type="DifferentialSolver") + + +if __name__ == "__main__": + torch.set_printoptions(precision=5, sci_mode=False) + pytest_args = ["-v", __file__] + pytest.main(pytest_args) diff --git a/tests/sim/solvers/test_opw_solver.py b/tests/sim/solvers/test_opw_solver.py new file mode 100644 index 00000000..6df71a18 --- /dev/null +++ b/tests/sim/solvers/test_opw_solver.py @@ -0,0 +1,139 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch +import pytest +import numpy as np + +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.objects import Robot +from embodichain.lab.sim.robots import CobotMagicCfg + + +# Base test class for OPWSolver +class BaseSolverTest: + sim = None # Define as a class attribute + + def setup_simulation(self): + config = SimulationManagerCfg(headless=True, sim_device="cpu") + self.sim = SimulationManager(config) + self.sim.build_multiple_arenas(1) + self.sim.set_manual_update(False) + + cfg_dict = { + "uid": "CobotMagic", + "init_pos": [0.0, 0.0, 0.7775], + "init_qpos": [ + -0.3, + 0.3, + 1.0, + 1.0, + -1.2, + -1.2, + 0.0, + 0.0, + 0.6, + 0.6, + 0.0, + 0.0, + 0.05, + 0.05, + 0.05, + 0.05, + ], + "solver_cfg": { + "left_arm": { + "class_type": "OPWSolver", + "end_link_name": "left_link6", + "root_link_name": "left_arm_base", + "tcp": [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0.143], [0, 0, 0, 1]], + }, + "right_arm": { + "class_type": "OPWSolver", + "end_link_name": "right_link6", + "root_link_name": "right_arm_base", + "tcp": [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0.143], [0, 0, 0, 1]], + }, + }, + } + + self.robot: Robot = self.sim.add_robot(cfg=CobotMagicCfg.from_dict(cfg_dict)) + + @pytest.mark.parametrize("arm_name", ["left_arm", "right_arm"]) + def test_ik(self, arm_name: str): + # Test inverse kinematics (IK) with a 1x4x4 homogeneous matrix pose and a joint_seed + + qpos_fk = torch.tensor( + [[0.0, np.pi / 4, -np.pi / 4, 0.0, np.pi / 4, 0.0]], dtype=torch.float32 + ) + + fk_xpos = self.robot.compute_fk(qpos=qpos_fk, name=arm_name, to_matrix=True) + + res, ik_qpos = self.robot.compute_ik( + pose=fk_xpos, joint_seed=qpos_fk, name=arm_name + ) + + res, ik_qpos = self.robot.compute_ik(pose=fk_xpos, name=arm_name) + + if ik_qpos.dim() == 3: + ik_xpos = self.robot.compute_fk( + qpos=ik_qpos[0][0], name=arm_name, to_matrix=True + ) + else: + ik_xpos = self.robot.compute_fk(qpos=ik_qpos, name=arm_name, to_matrix=True) + + assert torch.allclose( + fk_xpos, ik_xpos, atol=5e-3, rtol=5e-3 + ), f"FK and IK results do not match for {arm_name}" + # test for failed xpos + invalid_pose = torch.tensor( + [ + [ + [1.0, 0.0, 0.0, 10.0], + [0.0, 1.0, 0.0, 10.0], + [0.0, 0.0, 1.0, 10.0], + [0.0, 0.0, 0.0, 1.0], + ] + ], + dtype=torch.float32, + device=self.robot.device, + ) + res, ik_qpos = self.robot.compute_ik( + pose=invalid_pose, joint_seed=ik_qpos, name=arm_name + ) + dof = ik_qpos.shape[-1] + assert res[0] == False + assert ik_qpos.shape == (1, dof) + + @classmethod + def teardown_class(cls): + if cls.sim is not None: + try: + cls.sim.destroy() + print("sim destroyed successfully") + except Exception as e: + print(f"Error during sim.destroy(): {e}") + + +class TestOPWSolver(BaseSolverTest): + def setup_method(self): + self.setup_simulation() + + +if __name__ == "__main__": + np.set_printoptions(precision=5, suppress=True) + pytest_args = ["-v", "-s", __file__] + pytest.main(pytest_args) diff --git a/tests/sim/solvers/test_pink_solver.py b/tests/sim/solvers/test_pink_solver.py new file mode 100644 index 00000000..590ed6dc --- /dev/null +++ b/tests/sim/solvers/test_pink_solver.py @@ -0,0 +1,146 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os +import torch +import pytest +import numpy as np + +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.objects import Robot +from embodichain.lab.sim.cfg import RobotCfg +from embodichain.data import get_data_path + + +# Base test class for differential solver +class BaseSolverTest: + sim = None # Define as a class attribute + + def setup_simulation(self, solver_type: str): + # Set up simulation with specified device (CPU or CUDA) + config = SimulationManagerCfg(headless=True, sim_device="cpu") + self.sim = SimulationManager(config) + self.sim.build_multiple_arenas(1) + self.sim.set_manual_update(False) + + # Load robot URDF file + urdf = get_data_path("Rokae/SR5/SR5.urdf") + + assert os.path.isfile(urdf) + + cfg_dict = { + "fpath": urdf, + "control_parts": { + "main_arm": [ + "joint1", + "joint2", + "joint3", + "joint4", + "joint5", + "joint6", + ], + }, + "solver_cfg": { + "main_arm": { + "class_type": "PinkSolver", + "end_link_name": "ee_link", + "root_link_name": "base_link", + }, + }, + } + + self.robot: Robot = self.sim.add_robot(cfg=RobotCfg.from_dict(cfg_dict)) + + def test_differential_solver(self): + # Test differential solver with a 1x4x4 homogeneous matrix pose and a joint_seed + arm_name = "main_arm" + + qpos_fk = torch.tensor( + [[0.0, 0.0, np.pi / 2, 0.0, np.pi / 2, 0.0]], dtype=torch.float32 + ) + + fk_xpos = self.robot.compute_fk(qpos=qpos_fk, name=arm_name, to_matrix=True) + + # Define start and end poses + start_pose = fk_xpos.clone()[0] + end_pose = fk_xpos.clone()[0] + end_pose[:3, 3] += torch.tensor([0.0, 0.4, 0.0], dtype=torch.float32) + + # Interpolate poses + num_steps = 100 + interpolated_poses = [ + torch.lerp(start_pose, end_pose, t) for t in np.linspace(0, 1, num_steps) + ] + + ik_qpos = qpos_fk + + for i, pose in enumerate(interpolated_poses): + res, ik_qpos = self.robot.compute_ik( + pose=pose, joint_seed=ik_qpos, name=arm_name + ) + assert res, f"IK failed for step {i} with pose:\n{pose}" + + # Verify forward kinematics matches the target pose + ik_xpos = self.robot.compute_fk(qpos=ik_qpos, name=arm_name, to_matrix=True) + assert torch.allclose( + pose, ik_xpos, atol=1e-3, rtol=1e-3 + ), f"FK result does not match target pose at step {i}." + + # Set robot joint positions + self.robot.set_qpos( + qpos=ik_qpos, joint_ids=self.robot.get_joint_ids(arm_name) + ) + + # test for failed xpos + invalid_pose = torch.tensor( + [ + [ + [1.0, 0.0, 0.0, 10.0], + [0.0, 1.0, 0.0, 10.0], + [0.0, 0.0, 1.0, 10.0], + [0.0, 0.0, 0.0, 1.0], + ] + ], + dtype=torch.float32, + device=self.robot.device, + ) + res, ik_qpos = self.robot.compute_ik( + pose=invalid_pose, joint_seed=ik_qpos, name=arm_name + ) + dof = ik_qpos.shape[-1] + assert res[0] == False + assert ik_qpos.shape == (1, dof) + + @classmethod + def teardown_class(cls): + if cls.sim is not None: + try: + cls.sim.destroy() + print("sim destroyed successfully") + except Exception as e: + print(f"Error during sim.destroy(): {e}") + + +@pytest.mark.skip(reason="Skipping Pink tests temporarily") +class TestPinkSolver(BaseSolverTest): + def setup_method(self): + self.setup_simulation(solver_type="PinkSolver") + + +if __name__ == "__main__": + np.set_printoptions(precision=5, suppress=True) + pytest_args = ["-v", __file__] + pytest.main(pytest_args) diff --git a/tests/sim/solvers/test_pinocchio_solver.py b/tests/sim/solvers/test_pinocchio_solver.py new file mode 100644 index 00000000..db9dd360 --- /dev/null +++ b/tests/sim/solvers/test_pinocchio_solver.py @@ -0,0 +1,126 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os +import torch +import pytest +import numpy as np + +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.objects import Robot +from embodichain.lab.sim.cfg import RobotCfg +from embodichain.data import get_data_path + + +# Base test class for CPU and CUDA +class BaseSolverTest: + sim = None # Define as a class attribute + + def setup_simulation(self, solver_type: str): + # Set up simulation with specified device (CPU or CUDA) + config = SimulationManagerCfg(headless=True, sim_device="cpu") + self.sim = SimulationManager(config) + self.sim.build_multiple_arenas(1) + self.sim.set_manual_update(False) + + # Load robot URDF file + urdf = get_data_path("DexforceW1V021/DexforceW1_v02_1.urdf") + assert os.path.isfile(urdf) + + cfg_dict = { + "fpath": urdf, + "control_parts": { + "left_arm": [f"LEFT_J{i+1}" for i in range(7)], + "right_arm": [f"RIGHT_J{i+1}" for i in range(7)], + }, + "solver_cfg": { + "left_arm": { + "class_type": solver_type, + "end_link_name": "left_ee", + "root_link_name": "left_arm_base", + }, + "right_arm": { + "class_type": solver_type, + "end_link_name": "right_ee", + "root_link_name": "right_arm_base", + }, + }, + } + + self.robot: Robot = self.sim.add_robot(cfg=RobotCfg.from_dict(cfg_dict)) + + @pytest.mark.parametrize("arm_name", ["left_arm", "right_arm"]) + def test_ik(self, arm_name: str): + # Test inverse kinematics (IK) with a 1x4x4 homogeneous matrix pose and a joint_seed + + qpos_fk = torch.tensor( + [[0.0, 0.0, 0.0, -np.pi / 4, 0.0, 0.0, 0.0]], dtype=torch.float32 + ) + + fk_xpos = self.robot.compute_fk(qpos=qpos_fk, name=arm_name, to_matrix=True) + + res, ik_qpos = self.robot.compute_ik(pose=fk_xpos, name=arm_name) + + if ik_qpos.dim() == 3: + ik_xpos = self.robot.compute_fk( + qpos=ik_qpos[0][0], name=arm_name, to_matrix=True + ) + else: + ik_xpos = self.robot.compute_fk(qpos=ik_qpos, name=arm_name, to_matrix=True) + + assert torch.allclose( + fk_xpos, ik_xpos, atol=5e-3, rtol=5e-3 + ), f"FK and IK results do not match for {arm_name}" + + # test for failed xpos + invalid_pose = torch.tensor( + [ + [ + [1.0, 0.0, 0.0, 10.0], + [0.0, 1.0, 0.0, 10.0], + [0.0, 0.0, 1.0, 10.0], + [0.0, 0.0, 0.0, 1.0], + ] + ], + dtype=torch.float32, + device=self.robot.device, + ) + res, ik_qpos = self.robot.compute_ik( + pose=invalid_pose, joint_seed=ik_qpos, name=arm_name + ) + dof = ik_qpos.shape[-1] + assert res[0] == False + assert ik_qpos.shape == (1, dof) + + @classmethod + def teardown_class(cls): + if cls.sim is not None: + try: + cls.sim.destroy() + print("sim destroyed successfully") + except Exception as e: + print(f"Error during sim.destroy(): {e}") + + +class TestPinocchioSolver(BaseSolverTest): + def setup_method(self): + self.setup_simulation(solver_type="PinocchioSolver") + + +if __name__ == "__main__": + np.set_printoptions(precision=5, suppress=True) + pytest_args = ["-v", __file__] + pytest.main(pytest_args) diff --git a/tests/sim/solvers/test_pytorch_solver.py b/tests/sim/solvers/test_pytorch_solver.py new file mode 100644 index 00000000..75129f10 --- /dev/null +++ b/tests/sim/solvers/test_pytorch_solver.py @@ -0,0 +1,130 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os +import torch +import pytest +import numpy as np + +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.objects import Robot +from embodichain.lab.sim.cfg import RobotCfg +from embodichain.data import get_data_path + + +# Base test class for CPU and CUDA +class BaseSolverTest: + sim = None # Define as a class attribute + + def setup_simulation(self, solver_type: str): + # Set up simulation with specified device (CPU or CUDA) + config = SimulationManagerCfg(headless=True, sim_device="cpu") + self.sim = SimulationManager(config) + self.sim.build_multiple_arenas(1) + self.sim.set_manual_update(True) + + # Load robot URDF file + urdf = get_data_path("DexforceW1V021/DexforceW1_v02_1.urdf") + assert os.path.isfile(urdf) + + cfg_dict = { + "fpath": urdf, + "control_parts": { + "left_arm": [f"LEFT_J{i+1}" for i in range(7)], + "right_arm": [f"RIGHT_J{i+1}" for i in range(7)], + }, + "solver_cfg": { + "left_arm": { + "class_type": solver_type, + "end_link_name": "left_ee", + "root_link_name": "left_arm_base", + "ik_nearest_weight": [1.0, 1.0, 1.0, 0.9, 0.9, 0.1, 0.1], + }, + "right_arm": { + "class_type": solver_type, + "end_link_name": "right_ee", + "root_link_name": "right_arm_base", + }, + }, + } + + self.robot: Robot = self.sim.add_robot(cfg=RobotCfg.from_dict(cfg_dict)) + + # Wait for robot to stabilize. + self.sim.update(step=100) + + @pytest.mark.parametrize("arm_name", ["left_arm", "right_arm"]) + def test_ik(self, arm_name: str): + # Test inverse kinematics (IK) with a 1x4x4 homogeneous matrix pose and a joint_seed + + qpos_fk = torch.tensor( + [[0.0, 0.0, 0.0, -np.pi / 4, 0.0, 0.0, 0.0]], dtype=torch.float32 + ) + + fk_xpos = self.robot.compute_fk(qpos=qpos_fk, name=arm_name, to_matrix=True) + + res, ik_qpos = self.robot.compute_ik(pose=fk_xpos, name=arm_name) + + if ik_qpos.dim() == 3: + ik_xpos = self.robot.compute_fk( + qpos=ik_qpos[0][0], name=arm_name, to_matrix=True + ) + else: + ik_xpos = self.robot.compute_fk(qpos=ik_qpos, name=arm_name, to_matrix=True) + + assert torch.allclose( + fk_xpos, ik_xpos, atol=1e-2, rtol=1e-2 + ), f"FK and IK results do not match for {arm_name}" + + # test for failed xpos + invalid_pose = torch.tensor( + [ + [ + [1.0, 0.0, 0.0, 10.0], + [0.0, 1.0, 0.0, 10.0], + [0.0, 0.0, 1.0, 10.0], + [0.0, 0.0, 0.0, 1.0], + ] + ], + dtype=torch.float32, + device=self.robot.device, + ) + res, ik_qpos = self.robot.compute_ik( + pose=invalid_pose, joint_seed=ik_qpos, name=arm_name + ) + dof = ik_qpos.shape[-1] + assert res[0] == False + assert ik_qpos.shape == (1, dof) + + @classmethod + def teardown_class(cls): + if cls.sim is not None: + try: + cls.sim.destroy() + print("sim destroyed successfully") + except Exception as e: + print(f"Error during sim.destroy(): {e}") + + +class TestPytorchSolver(BaseSolverTest): + def setup_method(self): + self.setup_simulation(solver_type="PytorchSolver") + + +if __name__ == "__main__": + np.set_printoptions(precision=5, suppress=True) + test_solver = TestPytorchSolver() + test_solver.setup_method() diff --git a/tests/sim/solvers/test_srs_solver.py b/tests/sim/solvers/test_srs_solver.py new file mode 100644 index 00000000..13860485 --- /dev/null +++ b/tests/sim/solvers/test_srs_solver.py @@ -0,0 +1,307 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os +import torch +import pytest +import numpy as np + +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from embodichain.lab.sim.objects import Robot +from embodichain.lab.sim.cfg import RobotCfg +from embodichain.data import get_data_path + +from embodichain.lab.sim.solvers.srs_solver import SRSSolver, SRSSolverCfg +from embodichain.lab.sim.robots.dexforce_w1.types import ( + DexforceW1ArmSide, + DexforceW1ArmKind, + DexforceW1Version, +) +from embodichain.lab.sim.robots.dexforce_w1.params import ( + W1ArmKineParams, +) + + +class BaseSolverTest: + solver = {} + + def get_arm_config(self): + return [ + (DexforceW1ArmSide.LEFT, DexforceW1ArmKind.ANTHROPOMORPHIC, "left_arm"), + (DexforceW1ArmSide.RIGHT, DexforceW1ArmKind.ANTHROPOMORPHIC, "right_arm"), + (DexforceW1ArmSide.LEFT, DexforceW1ArmKind.INDUSTRIAL, "left_arm"), + (DexforceW1ArmSide.RIGHT, DexforceW1ArmKind.INDUSTRIAL, "right_arm"), + ] + + def setup_solver(self, solver_type: str, device: str = "cpu"): + for arm_side, arm_kind, arm_name in self.get_arm_config(): + arm_params = W1ArmKineParams( + arm_side=arm_side, + arm_kind=arm_kind, + version=DexforceW1Version.V021, + ) + if arm_kind == DexforceW1ArmKind.ANTHROPOMORPHIC: + urdf = get_data_path("DexforceW1V021/DexforceW1_v02_1.urdf") + else: + urdf = get_data_path("DexforceW1V021/DexforceW1_v02_2.urdf") + + cfg = SRSSolverCfg() + cfg.joint_names = [ + f"{'LEFT' if arm_side == DexforceW1ArmSide.LEFT else 'RIGHT'}_J{i+1}" + for i in range(7) + ] + cfg.end_link_name = ( + "left_ee" if arm_side == DexforceW1ArmSide.LEFT else "right_ee" + ) + cfg.root_link_name = ( + "left_arm_base" + if arm_side == DexforceW1ArmSide.LEFT + else "right_arm_base" + ) + cfg.urdf_path = urdf + cfg.dh_params = arm_params.dh_params + cfg.qpos_limits = arm_params.qpos_limits + cfg.T_e_oe = arm_params.T_e_oe + cfg.T_b_ob = arm_params.T_b_ob + cfg.link_lengths = arm_params.link_lengths + cfg.rotation_directions = arm_params.rotation_directions + + solver_key = f"{arm_name}_{arm_kind.name}" + self.solver[solver_key] = SRSSolver(cfg=cfg, num_envs=1, device=device) + + @pytest.mark.parametrize( + "arm_side, arm_kind, arm_name", + [ + (DexforceW1ArmSide.LEFT, DexforceW1ArmKind.ANTHROPOMORPHIC, "left_arm"), + (DexforceW1ArmSide.RIGHT, DexforceW1ArmKind.ANTHROPOMORPHIC, "right_arm"), + (DexforceW1ArmSide.LEFT, DexforceW1ArmKind.INDUSTRIAL, "left_arm"), + (DexforceW1ArmSide.RIGHT, DexforceW1ArmKind.INDUSTRIAL, "right_arm"), + ], + ) + def test_ik( + self, arm_side: DexforceW1ArmSide, arm_kind: DexforceW1ArmKind, arm_name: str + ): + # Test inverse kinematics (IK) with a 1x4x4 homogeneous matrix pose and a joint_seed + solver_key = f"{arm_name}_{arm_kind.name}" + device = self.solver[solver_key].device + + qpos_fk = torch.tensor( + [[0.0, 0.0, 0.0, -np.pi / 4, 0.0, 0.0, 0.0]], + dtype=torch.float32, + device=device, + ) + + fk_xpos = self.solver[solver_key].get_fk(qpos=qpos_fk) + + _, ik_qpos = self.solver[solver_key].get_ik(fk_xpos, return_all_solutions=False) + + ik_xpos = self.solver[solver_key].get_fk(qpos=ik_qpos[:, 0, :]) + + assert torch.allclose( + fk_xpos, ik_xpos, atol=1e-3, rtol=1e-3 + ), f"FK and IK results do not match for {solver_key}" + + @classmethod + def teardown_class(cls): + if cls.solver is not None: + try: + del cls.solver + print("solver destroyed successfully") + except Exception as e: + print(f"Error during solver destruction: {e}") + + +# Base test class for CPU and CUDA +class BaseRobotSolverTest: + sim = None # Define as a class attribute + + def setup_simulation(self, solver_type: str, device: str = "cpu"): + # Set up simulation with specified device (CPU or CUDA) + config = SimulationManagerCfg(headless=True, sim_device=device) + self.sim = SimulationManager(config) + self.sim.build_multiple_arenas(1) + self.sim.set_manual_update(True) + + # Load robot URDF file + urdf = get_data_path("DexforceW1V021/DexforceW1_v02_1.urdf") + assert os.path.isfile(urdf) + + w1_left_arm_params = W1ArmKineParams( + arm_side=DexforceW1ArmSide.LEFT, + arm_kind=DexforceW1ArmKind.ANTHROPOMORPHIC, + version=DexforceW1Version.V021, + ) + w1_right_arm_params = W1ArmKineParams( + arm_side=DexforceW1ArmSide.RIGHT, + arm_kind=DexforceW1ArmKind.ANTHROPOMORPHIC, + version=DexforceW1Version.V021, + ) + + # Robot configuration dictionary + cfg_dict = { + "fpath": urdf, + "control_parts": { + "left_arm": [f"LEFT_J{i+1}" for i in range(7)], + "right_arm": [f"RIGHT_J{i+1}" for i in range(7)], + "torso": ["ANKLE", "KNEE", "BUTTOCK", "WAIST"], + "head": [f"NECK{i+1}" for i in range(2)], + }, + "drive_pros": { + "stiffness": { + "LEFT_J[1-7]": 1e4, + "RIGHT_J[1-7]": 1e4, + "ANKLE": 1e7, + "KNEE": 1e7, + "BUTTOCK": 1e7, + "WAIST": 1e7, + }, + "damping": { + "LEFT_J[1-7]": 1e3, + "RIGHT_J[1-7]": 1e3, + "ANKLE": 1e4, + "KNEE": 1e4, + "BUTTOCK": 1e4, + "WAIST": 1e4, + }, + "max_effort": { + "LEFT_J[1-7]": 1e5, + "RIGHT_J[1-7]": 1e5, + "ANKLE": 1e10, + "KNEE": 1e10, + "BUTTOCK": 1e10, + "WAIST": 1e10, + }, + }, + "attrs": { + "mass": 1e-1, + "static_friction": 0.95, + "dynamic_friction": 0.9, + "linear_damping": 0.7, + "angular_damping": 0.7, + "max_depenetration_velocity": 10.0, + "min_position_iters": 32, + "min_velocity_iters": 8, + }, + "solver_cfg": { + "left_arm": { + "class_type": solver_type, + "end_link_name": "left_ee", + "root_link_name": "left_arm_base", + "dh_params": w1_left_arm_params.dh_params, + "qpos_limits": w1_left_arm_params.qpos_limits, + "T_b_ob": w1_right_arm_params.T_b_ob, + "T_e_oe": w1_left_arm_params.T_e_oe, + "link_lengths": w1_left_arm_params.link_lengths, + "rotation_directions": w1_left_arm_params.rotation_directions, + }, + "right_arm": { + "class_type": solver_type, + "end_link_name": "right_ee", + "root_link_name": "right_arm_base", + "dh_params": w1_right_arm_params.dh_params, + "qpos_limits": w1_right_arm_params.qpos_limits, + "T_b_ob": w1_right_arm_params.T_b_ob, + "T_e_oe": w1_right_arm_params.T_e_oe, + "link_lengths": w1_right_arm_params.link_lengths, + "rotation_directions": w1_right_arm_params.rotation_directions, + }, + }, + } + + self.robot: Robot = self.sim.add_robot(cfg=RobotCfg.from_dict(cfg_dict)) + + # Wait for robot to stabilize. + self.sim.update(step=100) + + @pytest.mark.parametrize("arm_name", ["left_arm", "right_arm"]) + def test_robot_ik(self, arm_name: str): + # Test inverse kinematics (IK) with a 1x4x4 homogeneous matrix pose and a joint_seed + + qpos_fk = torch.tensor( + [[0.0, 0.0, 0.0, -np.pi / 4, 0.0, 0.0, 0.0]], + dtype=torch.float32, + device=self.robot.device, + ) + + fk_xpos = self.robot.compute_fk(qpos=qpos_fk, name=arm_name, to_matrix=True) + + res, ik_qpos = self.robot.compute_ik(pose=fk_xpos, name=arm_name) + + if ik_qpos.dim() == 3: + ik_xpos = self.robot.compute_fk( + qpos=ik_qpos[0][0], name=arm_name, to_matrix=True + ) + else: + ik_xpos = self.robot.compute_fk(qpos=ik_qpos, name=arm_name, to_matrix=True) + + assert torch.allclose( + fk_xpos, ik_xpos, atol=1e-4, rtol=1e-4 + ), f"FK and IK results do not match for {arm_name}" + + # test for failed xpos + invalid_pose = torch.tensor( + [ + [ + [1.0, 0.0, 0.0, 10.0], + [0.0, 1.0, 0.0, 10.0], + [0.0, 0.0, 1.0, 10.0], + [0.0, 0.0, 0.0, 1.0], + ] + ], + dtype=torch.float32, + device=self.robot.device, + ) + res, ik_qpos = self.robot.compute_ik( + pose=invalid_pose, joint_seed=ik_qpos, name=arm_name + ) + dof = ik_qpos.shape[-1] + assert res[0] == False + assert ik_qpos.shape == (1, dof) + + @classmethod + def teardown_class(cls): + if cls.sim is not None: + try: + cls.sim.destroy() + print("sim destroyed successfully") + except Exception as e: + print(f"Error during sim.destroy(): {e}") + + +class TestSRSCPUSolver(BaseSolverTest): + def setup_method(self): + self.setup_solver(solver_type="SRSSolver", device="cpu") + + +class TestSRSCUDASolver(BaseSolverTest): + def setup_method(self): + self.setup_solver(solver_type="SRSSolver", device="cuda") + + +class TestSRSCPURobotSolver(BaseRobotSolverTest): + def setup_method(self): + self.setup_simulation(solver_type="SRSSolver", device="cpu") + + +class TestSRSCUDARobotSolver(BaseRobotSolverTest): + def setup_method(self): + self.setup_simulation(solver_type="SRSSolver", device="cuda") + + +if __name__ == "__main__": + np.set_printoptions(precision=5, suppress=True) + pytest_args = ["-v", __file__] + pytest.main(pytest_args) diff --git a/tests/sim/test_sim_manager.py b/tests/sim/test_sim_manager.py new file mode 100644 index 00000000..72b2a254 --- /dev/null +++ b/tests/sim/test_sim_manager.py @@ -0,0 +1,63 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import os +import dexsim.environment +import numpy as np +import dexsim +from embodichain.lab.sim import SimulationManager, SimulationManagerCfg +from dexsim.utility.path import get_resources_data_path + + +def test_sim_init(): + config = SimulationManagerCfg() + config.headless = True + + sim = SimulationManager(config) + sim.get_env().clean() + + assert isinstance(sim.get_env(), dexsim.environment.Env) + assert isinstance(sim.get_world(), dexsim.World) + + # test add_sensor + intrinsic = np.array([[600, 0, 320], [0, 600, 240], [0, 0, 1]]) + cam1 = sim.add_sensor( + "MonocularCam", sensor_uid="cam1", resolution=(640, 480), intrinsic=intrinsic + ) + assert sim.get_sensor("cam1") == cam1 + assert len(sim.get_sensor_uid_list()) == 1 + assert sim.get_sensor_uid_list()[0] == "cam1" + + # TODO: test add_robot + + # test_add_fixed_actor. + model_path = get_resources_data_path("Model", "lego", "lego.ply") + + actor = sim.add_fixed_actor(fpath=model_path, init_pose=np.eye(4)) + assert sim.get_fixed_actor_uid_list() == ["lego.ply"] + assert sim.get_fixed_actor("lego.ply") == actor + + sim.remove_fixed_actor("lego.ply") + assert sim.get_fixed_actor_uid_list() == [] + + # test add_dynamic_actor + actor = sim.add_dynamic_actor(fpath=model_path, init_pose=np.eye(4)) + assert sim.get_dynamic_actor_uid_list() == ["lego.ply"] + assert sim.get_dynamic_actor("lego.ply") == actor + + +if __name__ == "__main__": + test_sim_init() diff --git a/tests/toolkits/test_pg_grasp.py b/tests/toolkits/test_pg_grasp.py new file mode 100644 index 00000000..dd9f79e5 --- /dev/null +++ b/tests/toolkits/test_pg_grasp.py @@ -0,0 +1,96 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import open3d as o3d +import numpy as np +import os +from embodichain.toolkits.graspkit.pg_grasp import ( + AntipodalGenerator, + GraspSelectMethod, +) +from embodichain.data import get_data_path + + +def test_antipodal_score_selector(is_visual: bool = False): + mesh_path = get_data_path("ChainRainSec/mesh.ply") + mesh_o3dt = o3d.t.io.read_triangle_mesh(mesh_path) + generator = AntipodalGenerator( + mesh_o3dt=mesh_o3dt, + open_length=0.1, + max_angle=np.pi / 6, + surface_sample_num=5000, + cache_dir=None, + ) + grasp_list = generator.select_grasp( + approach_direction=np.array([0, 0, -1]), + select_num=5, + select_method=GraspSelectMethod.NORMAL_SCORE, + ) + assert len(grasp_list) == 5 + if is_visual: + visual_mesh_list = generator.grasp_pose_visual(grasp_list) + visual_mesh_list = [visual_mesh.to_legacy() for visual_mesh in visual_mesh_list] + o3d.visualization.draw_geometries(visual_mesh_list) + + +def test_antipodal_position_selector(is_visual: bool = False): + mesh_path = get_data_path("ChainRainSec/mesh.ply") + mesh_o3dt = o3d.t.io.read_triangle_mesh(mesh_path) + generator = AntipodalGenerator( + mesh_o3dt=mesh_o3dt, + open_length=0.1, + max_angle=np.pi / 6, + surface_sample_num=5000, + cache_dir=None, + ) + grasp_list = generator.select_grasp( + approach_direction=np.array([0, 0, -1]), + select_num=5, + select_method=GraspSelectMethod.NEAR_APPROACH, + ) + assert len(grasp_list) == 5 + if is_visual: + visual_mesh_list = generator.grasp_pose_visual(grasp_list) + visual_mesh_list = [visual_mesh.to_legacy() for visual_mesh in visual_mesh_list] + o3d.visualization.draw_geometries(visual_mesh_list) + + +def test_antipodal_center_selector(is_visual: bool = False): + mesh_path = get_data_path("ChainRainSec/mesh.ply") + mesh_o3dt = o3d.t.io.read_triangle_mesh(mesh_path) + generator = AntipodalGenerator( + mesh_o3dt=mesh_o3dt, + open_length=0.1, + max_angle=np.pi / 6, + surface_sample_num=5000, + cache_dir=None, + ) + grasp_list = generator.select_grasp( + approach_direction=np.array([0, 0, -1]), + select_num=5, + select_method=GraspSelectMethod.CENTER, + ) + assert len(grasp_list) == 5 + if is_visual: + visual_mesh_list = generator.grasp_pose_visual(grasp_list) + visual_mesh_list = [visual_mesh.to_legacy() for visual_mesh in visual_mesh_list] + o3d.visualization.draw_geometries(visual_mesh_list) + + +if __name__ == "__main__": + test_antipodal_score_selector(True) + test_antipodal_position_selector(True) + test_antipodal_center_selector(True) diff --git a/tests/unitest.sh b/tests/unitest.sh new file mode 100755 index 00000000..d827e1cb --- /dev/null +++ b/tests/unitest.sh @@ -0,0 +1,93 @@ +echo "exec python ..." + +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +log_info() { + echo -e "${BLUE}[INFO] $1${NC}" +} + +log_success() { + echo -e "${GREEN}[SUCCESS] $1${NC}" +} + +log_warn() { + echo -e "${YELLOW}[WARN] $1${NC}" +} + +log_error() { + echo -e "${RED}[ERROR] $1${NC}" +} + +run_python_script() { + local script_name="$1" + local script_path="$2" + + log_info "Running: ${script_name}" + + if python "$script_path"; then + log_success "${script_name} succeeded" + return 0 + else + log_error "${script_name} failed" + return 1 + fi +} + +run_pytest() { + local test_path="$1" + log_info "Running pytest: ${test_path}" + + local pytest_args=( + "--durations=1000" # record the slowest 1000 tests + "--tb=long" # long traceback + "-vv" # verbose + "--disable-warnings" # disable warnings + "--color=yes" # enable color output + ) + + pytest "${pytest_args[@]}" "$test_path" + local status=$? + + # Check if no tests were collected + if pytest --collect-only "$test_path" | grep -q "collected 0 items"; then + log_warn "No tests collected: ${test_path}" + return 0 + fi + + # Check pytest return code + if [ $status -ne 0 ]; then + log_error "pytest failed: ${test_path}" + exit 1 + fi +} + +main() { + echo "Starting scripts..." + + run_python_script "pourwater_offline_test" "tests/datasets/run_pourwater_env_offline.py" || exit 1 + + echo "Starting pytest unit tests..." + + for test_dir in tests/*/; do + # Check whether the directory contains any recursive test_*.py files; do not skip if top-level has none (supports subdirectory structures) + if ! find "$test_dir" -type f -name 'test_*.py' | grep -q .; then + log_warn "Skipping empty directory or no test files: ${test_dir}" + continue + fi + run_pytest "$test_dir" + done + + log_success "All test scripts and unit tests completed!" +} + +main "$@" + +# # demo +# ./demo/origin_demo.sh CI +# echo -e "\e[32morigin_demo executed successfully\e[0m"